From 0b283f6b8c752338bf75bd7260e92ca272d859ee Mon Sep 17 00:00:00 2001 From: martin <1486756632@qq.com> Date: Wed, 15 Apr 2026 16:12:24 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=87=8F=E6=95=B0=E9=87=8F=E8=B6=85?= =?UTF-8?q?=E9=99=90=E6=8A=A5=E8=AD=A6=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- middleware/query_model.py | 3 ++ yolo/cv_multi_model_back_video.py | 27 +++++++++++---- ...ulti_yolo_trt_detect_track_trt10_yolo11.py | 34 +++++++++---------- yolo_api.py | 6 ++-- 4 files changed, 44 insertions(+), 26 deletions(-) diff --git a/middleware/query_model.py b/middleware/query_model.py index 8536864..3c64ef1 100644 --- a/middleware/query_model.py +++ b/middleware/query_model.py @@ -46,6 +46,7 @@ class ModelData: so_path: str repeat_dis: float repeat_time: float + high_count_warn: float func_description: Optional[str] filter_indices: List[int] class_indices: List[int] @@ -250,6 +251,7 @@ class ModelConfigDAO: aml.py_func, aml.repeat_dis, aml.repeat_time, + aml.high_count_warn, am.scope, am.yolo_version, am.PATH, @@ -572,6 +574,7 @@ WHERE filter_indices=filter_indices, repeat_dis=repeat_dis, repeat_time=row.get('repeat_time'), + high_count_warn=row.get('high_count_warn'), class_indices=row['cls_index'], conf=conf, classes=classes, diff --git a/yolo/cv_multi_model_back_video.py b/yolo/cv_multi_model_back_video.py index 561583f..656bc88 100644 --- a/yolo/cv_multi_model_back_video.py +++ b/yolo/cv_multi_model_back_video.py @@ -10,6 +10,7 @@ from dataclasses import dataclass import json import time +from sympy import false from ultralytics import YOLO import torch @@ -2365,7 +2366,7 @@ async def process_frames(detector: MultiYoloTrtDetectorTrackId, cancel_flag: asy time_pr_start = time.time_ns() detections, detections_list, model_para = await detector.predict(frame) time_pr_end = time.time_ns() - print(f"time_pr_starttime_pr_start {(time_pr_end - time_pr_start) / 1000000}") + # print(f"time_pr_starttime_pr_start {(time_pr_end - time_pr_start) / 1000000}") predict_state = True if detections: print("检测到任何目标") @@ -2635,8 +2636,8 @@ async def write_results_to_rtmp(task_id: str, output_url: str = None, input_fps: # 如果开起侵限功能,就只显示侵限内的框 point_x = (x1 + x2) / 2 point_y = (y1 + y2) / 2 - print(f"class_name--{class_name}") - print(f"model_class_names: {model_para[0]['model_class_names']}") + # print(f"class_name--{class_name}") + # print(f"model_class_names: {model_para[0]['model_class_names']}") if class_name not in model_para[0]["model_class_names"]: continue @@ -3298,7 +3299,7 @@ invade_cache_lock = Lock() # 用于保护共享变量的锁 async def send_frame_to_s3_mq(loop,upload_executor,task_id, mqtt, mqtt_topic, cancel_flag: asyncio.Event, cv_frame_queue: asyncio.Queue, event_queue: asyncio.Queue = None, - device_height: float = float(200), repeat_dis: float = -1, repeat_time: float = -1): + device_height: float = float(200), repeat_dis: float = -1, repeat_time: float = -1,high_count_warn: float = -1): global stats start_time = time.time() # executor = ThreadPoolExecutor(max_workers=Config.MAX_WORKERS) @@ -3383,7 +3384,15 @@ async def send_frame_to_s3_mq(loop,upload_executor,task_id, mqtt, mqtt_topic, ca should_report = True print(f"target_pointtarget_point {len(target_point)}") count_item = 0 - des_location_result=[] + des_location_result=[] + + high_count_warn_status=False + high_count_warn_num=0 + + if target_point is not None and 0 < high_count_warn < len(target_point):# 触发计数报警 + high_count_warn_num=len(target_point) + high_count_warn_status=True + for item in target_point: # # 跳过无效的track_id # 检查是否应该上报该track_id @@ -3560,6 +3569,10 @@ async def send_frame_to_s3_mq(loop,upload_executor,task_id, mqtt, mqtt_topic, ca "latitude": cam_latitude }, "count_message":count_message, + "high_count_warn":{ + "high_count_warn_status":high_count_warn_status, + "high_count_warn_num":high_count_warn_num + }, "des_location":des_location_result } await event_queue.put({ @@ -3988,7 +4001,7 @@ async def start_rtmp_processing(video_url: str, task_id: str, model_configs: Lis mqtt_sub_ip: str, mqtt_sub_port: int, mqtt_sub_topic: str, output_rtmp_url: str, invade_enable: bool, invade_file: str, camera_para_url: str, - device_height: float, repeat_dis: float, repeat_time: float): + device_height: float, repeat_dis: float, repeat_time: float,high_count_warn: float): # 初始化资源 # await initialize_resources() logger.info(f"拉流地址{video_url}") @@ -4153,7 +4166,7 @@ async def start_rtmp_processing(video_url: str, task_id: str, model_configs: Lis upload_task = asyncio.create_task( send_frame_to_s3_mq(loop, upload_executor, task_id, mqtt, mqtt_pub_topic, cancel_flag, cv_frame_queue, event_queue, device_height, repeat_dis, - repeat_time), + repeat_time,high_count_warn), name=f"send_frame_to_s3_mq_{_}" ) upload_tasks.append(upload_task) diff --git a/yolo/detect/multi_yolo_trt_detect_track_trt10_yolo11.py b/yolo/detect/multi_yolo_trt_detect_track_trt10_yolo11.py index 4326f6f..26f30ff 100644 --- a/yolo/detect/multi_yolo_trt_detect_track_trt10_yolo11.py +++ b/yolo/detect/multi_yolo_trt_detect_track_trt10_yolo11.py @@ -470,13 +470,13 @@ class YoLo11TRT(object): # 重塑为二维数组 pred = np.reshape(output[1:], (-1, num_values_per_detection))[:num, :] - # 调试:打印输出信息 - if num > 0: - logger.info(f"检测到 {num} 个目标,每个目标 {num_values_per_detection} 个值") - logger.info(f"第一个检测框原始值: center_x={pred[0, 0]:.2f}, center_y={pred[0, 1]:.2f}, " - f"width={pred[0, 2]:.2f}, height={pred[0, 3]:.2f}, conf={pred[0, 4]:.4f}") - logger.info(f"原始图片尺寸: {origin_w}x{origin_h}") - logger.info(f"缩放比例: r_w={self.input_w / origin_w:.4f}, r_h={self.input_h / origin_h:.4f}") + # # 调试:打印输出信息 + # if num > 0: + # logger.info(f"检测到 {num} 个目标,每个目标 {num_values_per_detection} 个值") + # logger.info(f"第一个检测框原始值: center_x={pred[0, 0]:.2f}, center_y={pred[0, 1]:.2f}, " + # f"width={pred[0, 2]:.2f}, height={pred[0, 3]:.2f}, conf={pred[0, 4]:.4f}") + # logger.info(f"原始图片尺寸: {origin_w}x{origin_h}") + # logger.info(f"缩放比例: r_w={self.input_w / origin_w:.4f}, r_h={self.input_h / origin_h:.4f}") # 执行NMS boxes = self.non_max_suppression(pred, origin_h, origin_w, conf_thres=CONF_THRESH, nms_thres=IOU_THRESHOLD) @@ -521,20 +521,20 @@ class YoLo11TRT(object): return np.array([]) # 调试:打印转换前的坐标 - if len(boxes) > 0: - logger.info(f"NMS前第一个框: center_x={boxes[0, 0]:.2f}, center_y={boxes[0, 1]:.2f}, " - f"width={boxes[0, 2]:.2f}, height={boxes[0, 3]:.2f}, conf={boxes[0, 4]:.4f}") + # if len(boxes) > 0: + # logger.info(f"NMS前第一个框: center_x={boxes[0, 0]:.2f}, center_y={boxes[0, 1]:.2f}, " + # f"width={boxes[0, 2]:.2f}, height={boxes[0, 3]:.2f}, conf={boxes[0, 4]:.4f}") # 关键修复:使用正确的xywh2xyxy函数 boxes[:, :4] = self.xywh2xyxy(origin_h, origin_w, boxes[:, :4]) - # 调试:打印转换后的坐标 - if len(boxes) > 0: - logger.info(f"坐标转换后第一个框: x1={boxes[0, 0]:.2f}, y1={boxes[0, 1]:.2f}, " - f"x2={boxes[0, 2]:.2f}, y2={boxes[0, 3]:.2f}") - logger.info(f"框尺寸: {boxes[0, 2] - boxes[0, 0]:.1f}x{boxes[0, 3] - boxes[0, 1]:.1f}") - logger.info( - f"框占图片比例: 宽度={100 * (boxes[0, 2] - boxes[0, 0]) / origin_w:.1f}%, 高度={100 * (boxes[0, 3] - boxes[0, 1]) / origin_h:.1f}%") + # # 调试:打印转换后的坐标 + # if len(boxes) > 0: + # logger.info(f"坐标转换后第一个框: x1={boxes[0, 0]:.2f}, y1={boxes[0, 1]:.2f}, " + # f"x2={boxes[0, 2]:.2f}, y2={boxes[0, 3]:.2f}") + # logger.info(f"框尺寸: {boxes[0, 2] - boxes[0, 0]:.1f}x{boxes[0, 3] - boxes[0, 1]:.1f}") + # logger.info( + # f"框占图片比例: 宽度={100 * (boxes[0, 2] - boxes[0, 0]) / origin_w:.1f}%, 高度={100 * (boxes[0, 3] - boxes[0, 1]) / origin_h:.1f}%") # 裁剪坐标到图像边界内 boxes[:, 0] = np.clip(boxes[:, 0], 0, origin_w - 1) diff --git a/yolo_api.py b/yolo_api.py index 9575907..1fc8d4c 100644 --- a/yolo_api.py +++ b/yolo_api.py @@ -664,6 +664,8 @@ async def run_back_Multi_Detect_async(request, request_json, stop_event: asyncio print("去重的距离:", config.repeat_dis) repeat_dis = config.repeat_dis repeat_time = config.repeat_time + high_count_warn = config.high_count_warn + print(f"config.high_count_warn {config.high_count_warn}") model_configs.append( { @@ -760,7 +762,7 @@ async def run_back_Multi_Detect_async(request, request_json, stop_event: asyncio mqtt_sub_ip, mqtt_sub_port, mqtt_sub_topic, push_url, invade_enable, invade_file, camera_para_url, - device_height, repeat_dis, repeat_time + device_height, repeat_dis, repeat_time,high_count_warn ) except Exception as e: logger.error(f"处理流程异常: {e}") @@ -1846,5 +1848,5 @@ if __name__ == "__main__": print("正在安装psutil库...") subprocess.check_call([sys.executable, "-m", "pip", "install", "psutil"]) - app.run(host="0.0.0.0", port=12315, debug=False, access_log=True) + app.run(host="192.168.110.103", port=12315, debug=False, access_log=True) # app.run(host="0.0.0.0", workers=3, port=12315)