diff --git a/middleware/entity/up_osd_info_push.py b/middleware/entity/up_osd_info_push.py index c128854..cfaa0e6 100644 --- a/middleware/entity/up_osd_info_push.py +++ b/middleware/entity/up_osd_info_push.py @@ -1,7 +1,11 @@ from dataclasses import dataclass -import json -from typing import Optional + import asyncio +from typing import Union, Optional +from typing import Any + + + @dataclass @@ -25,9 +29,23 @@ class OSDInfo: wind_speed: float +@dataclass +class OSDInfo_v1: + attitude_head: float + latitude: float + longitude: float + height: float + speed_x: float + speed_y: float + speed_z: float + gimbal_pitch: float + gimbal_roll: float + gimbal_yaw: float + + @dataclass class OSDMessage: - data: OSDInfo + data: Any # 可以是两种类型之一 method: str seq: int timestamp: int @@ -36,19 +54,47 @@ class OSDMessage: def parse_osd_message(json_str: Optional[str]) -> Optional[OSDMessage]: if not json_str: return None - + data = json_str try: - data=json_str osd_info = OSDInfo(**data["data"]) - return OSDMessage( - data=osd_info, - method=data["method"], - seq=data["seq"], - timestamp=data["timestamp"] - ) - except Exception as e: - print(f"Error parsing OSD message: {e}") - return None + data_seq=data["seq"] + except (TypeError, KeyError) as e: + # 如果OSDInfo格式失败,尝试使用OSDInfo_v1格式 + try: + osd_info = OSDInfo_v1(**data["data"]) + data_seq = 0 + except Exception as e2: + print(f"Error parsing OSD message with both formats: {e2}") + return None + return OSDMessage( + data=osd_info, + method=data["method"], + seq=data_seq, + timestamp=data["timestamp"] + ) + # try: + # data=json_str + # # osd_info = OSDInfo(**data["data"]) + # # if osd_info is None: + # # osd_info = OSDInfo_v1(**data["data"]) #适配东西湖区一代飞机的格式 + # try: + # osd_info = OSDInfo(**data["data"]) + # except (TypeError, KeyError) as e: + # # 如果OSDInfo格式失败,尝试使用OSDInfo_v1格式 + # try: + # osd_info = OSDInfo_v1(**data["data"]) + # except Exception as e2: + # print(f"Error parsing OSD message with both formats: {e2}") + # return None + # return OSDMessage( + # data=osd_info, + # method=data["method"], + # seq=data["seq"], + # timestamp=data["timestamp"] + # ) + # except Exception as e: + # print(f"Error parsing OSD message: {e}") + # return None async def main(): @@ -92,4 +138,4 @@ if __name__ == '__main__': except: pass - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/yolo/cv_multi_model_back_video.py b/yolo/cv_multi_model_back_video.py index a18d5f5..500fb10 100644 --- a/yolo/cv_multi_model_back_video.py +++ b/yolo/cv_multi_model_back_video.py @@ -2558,9 +2558,9 @@ async def write_results_to_rtmp(task_id: str, output_url: str = None, input_fps: results = [] results_list = [] - invade_switch_enable = True # 侵限施工 + invade_switch_enable = True # 侵限施工 if invade_switch > 0: - invade_switch_enable = False # 超限施工 + invade_switch_enable = False # 超限施工 # 启用侵限且拿到了飞机的姿态信息,再绘制红线 if invade_state and osd_info: gimbal_yaw = osd_info.gimbal_yaw @@ -2616,7 +2616,7 @@ async def write_results_to_rtmp(task_id: str, output_url: str = None, input_fps: model_func_id = model_para[0]["func_id"] invade_point = [] message_point = [] - invade_point_message_point=[] # 超限使能,统计侵限,方便画图 + invade_point_message_point = [] # 超限使能,统计侵限,方便画图 target_point = [] # 存储满足条件的图像坐标,方便后续经纬度转换 cls_count = 0 @@ -2652,9 +2652,9 @@ async def write_results_to_rtmp(task_id: str, output_url: str = None, input_fps: # is_invade = is_point_in_polygon(point_x, point_y, results) # print(f"is_invadeis_invadeis_invade {is_invade} {len(results)}") - if invade_switch_enable:#只关注侵限 - if is_invade: #只关注侵限且实际发生侵限 - # if invade_switch_enable: # 只关注侵限 + if invade_switch_enable: # 只关注侵限 + if is_invade: # 只关注侵限且实际发生侵限 + # if invade_switch_enable: # 只关注侵限 cls_count += 1 # invade_point.append({ # "u": point_x, @@ -2681,7 +2681,8 @@ async def write_results_to_rtmp(task_id: str, output_url: str = None, input_fps: label = f"{en_name}:{confidence:.2f}:{track_id}" label_name = f"{en_name}" # 计算文本位置 - text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, fontScale=8, thickness=4)[ + text_size = \ + cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, fontScale=8, thickness=4)[ 0] text_width, text_height = text_size[0], text_size[1] text_x = x1 @@ -2698,17 +2699,17 @@ async def write_results_to_rtmp(task_id: str, output_url: str = None, input_fps: label_name, # 仅显示汉字 (text_x, text_y - 40), ) - else: #只关注超限 - if is_invade: #只关注超限的情况下,发生了侵限行为,只在图像上展示侵限,不做行为记录 - # if invade_switch_enable: # 只关注侵限 - # cls_count += 1 - # target_point.append({ - # "u": point_x, - # "v": point_y, - # "cls_id": cls_id, - # "track_id": track_id, - # "new_track_id": new_track_id - # }) # 对于侵限,只存储侵限目标 + else: # 只关注超限 + if is_invade: # 只关注超限的情况下,发生了侵限行为,只在图像上展示侵限,不做行为记录 + # if invade_switch_enable: # 只关注侵限 + # cls_count += 1 + # target_point.append({ + # "u": point_x, + # "v": point_y, + # "cls_id": cls_id, + # "track_id": track_id, + # "new_track_id": new_track_id + # }) # 对于侵限,只存储侵限目标 # model_list_func_id = model_para[0]["model_list_func_id"] # model_func_id = model_para[0]["func_id"] @@ -2722,7 +2723,8 @@ async def write_results_to_rtmp(task_id: str, output_url: str = None, input_fps: label = f"{en_name}:{confidence:.2f}:{track_id}" label_name = f"{en_name}" # 计算文本位置 - text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, fontScale=8, thickness=4)[ + text_size = \ + cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, fontScale=8, thickness=4)[ 0] text_width, text_height = text_size[0], text_size[1] text_x = x1 @@ -2739,16 +2741,16 @@ async def write_results_to_rtmp(task_id: str, output_url: str = None, input_fps: label_name, # 仅显示汉字 (text_x, text_y - 40), ) - else: # 超限使能且识别到了超限发生 + else: # 超限使能且识别到了超限发生 print("超限使能且识别到了超限发生") cls_count += 1 target_point.append({ - "u": point_x, - "v": point_y, - "cls_id": cls_id, - "track_id": track_id, - "new_track_id": new_track_id - }) # 对于侵限,只存储侵限目标 + "u": point_x, + "v": point_y, + "cls_id": cls_id, + "track_id": track_id, + "new_track_id": new_track_id + }) # 对于侵限,只存储侵限目标 # model_list_func_id = model_para[0]["model_list_func_id"] # model_func_id = model_para[0]["func_id"] @@ -2763,8 +2765,8 @@ async def write_results_to_rtmp(task_id: str, output_url: str = None, input_fps: label_name = f"{en_name}" # 计算文本位置 text_size = \ - cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, fontScale=8, thickness=4)[ - 0] + cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, fontScale=8, thickness=4)[ + 0] text_width, text_height = text_size[0], text_size[1] text_x = x1 text_y = y1 - 5 @@ -2828,7 +2830,7 @@ async def write_results_to_rtmp(task_id: str, output_url: str = None, input_fps: (point["box"][2], point["box"][3]), (0, 255, 255), 2) - if not invade_switch_enable:# 侵限使能,只关注超限的情况下,将侵限画另一个颜色 + if not invade_switch_enable: # 侵限使能,只关注超限的情况下,将侵限画另一个颜色 for point in invade_point_message_point: cv2.rectangle(frame_copy, (point["box"][0], point["box"][1]), (point["box"][2], point["box"][3]), @@ -3118,7 +3120,8 @@ async def cal_des_invade(loop, invade_executor, task_id: str, mqtt, mqtt_publish list_points: list[list[any]], camera_para: Camera_Para, model_count: int, cancel_flag: asyncio.Event = None, invade_switch: int = 0, invade_queue: asyncio.Queue = None, event_queue: asyncio.Queue = None, - device_height: float = float(200), repeat_dis: float = -1, repeat_time: float = -1): + device_height: float = float(200), repeat_dis: float = -1, repeat_time: float = -1, + high_count_warn: float = -1): # loop = asyncio.get_running_loop() # upload_executor = ThreadPoolExecutor(max_workers=Config.UPLOAD_WORKERS) pic_count_hongxian = 0 @@ -3129,7 +3132,13 @@ async def cal_des_invade(loop, invade_executor, task_id: str, mqtt, mqtt_publish report_interval = 8 target_location_back = [] # 本地缓存,用作位置重复计算 current_time_second = int(time.time()) + high_count_warn_status = False + high_count_warn_num = 0 while not cancel_flag.is_set(): + # 在循环开始处初始化变量 + high_count_warn_status = False + high_count_warn_num = 0 + # 检查队列长度,避免堆积 if invade_queue.qsize() > Config.PROCESSED_QUEUE_SIZE // 2: print(f"警告:invade_queue 积压(当前长度={invade_queue.qsize()}),清空队列") @@ -3191,7 +3200,8 @@ async def cal_des_invade(loop, invade_executor, task_id: str, mqtt, mqtt_publish height = air_alti.height cam_longitude = air_alti.longitude cam_latitude = air_alti.latitude - + high_count_warn_status = False # 侵限情况下的超数报警,针对东西湖 + high_count_warn_num = 0 # 侵限情况下的超数报警,针对东西湖 try: current_time = time.time() h = device_height @@ -3244,7 +3254,13 @@ async def cal_des_invade(loop, invade_executor, task_id: str, mqtt, mqtt_publish show_des = 0 str_loca = "" des_location_result = [] + if repeat_dis > 0: # ai_model_list repeat_dis 字段大于零,才启用去重 + if target_location_back is not None: # 触发计数报警 + if 0 < cls_count < high_count_warn: + high_count_warn_num = cls_count + high_count_warn_status = True + if len(target_location_back) > 0: # 当前逻辑并不严谨,只是比较了第一个位置信息 des1_back = target_location_back[0] des1_back_longitude = des1_back[0] @@ -3329,13 +3345,18 @@ async def cal_des_invade(loop, invade_executor, task_id: str, mqtt, mqtt_publish "minio": {"minio_path": minio_path, "minio_origin_path": minio_origin_path, "file_type": file_type}, - "invade_switch":invade_switch, + "invade_switch": invade_switch, "box_detail": [{ "model_id": model_func_id, "cls_count": cls_count, "box_count": [message_point], # 特殊处理 "location_results": location_results # 增加经纬度信息 }], + "high_count_warn": { + "high_count_warn_status": high_count_warn_status, + "high_count_warn_num": high_count_warn_num, + "high_count_warn": high_count_warn + }, "osd_location": { "longitude": cam_longitude, "latitude": cam_latitude @@ -3484,9 +3505,10 @@ async def send_frame_to_s3_mq(loop, upload_executor, task_id, mqtt, mqtt_topic, high_count_warn_status = False high_count_warn_num = 0 - if target_point is not None and 0 < high_count_warn < len(target_point): # 触发计数报警 - high_count_warn_num = len(target_point) - high_count_warn_status = True + if target_point is not None: # 触发计数报警 + if 0 < high_count_warn < len(target_point): + high_count_warn_num = len(target_point) + high_count_warn_status = True for item in target_point: # # 跳过无效的track_id @@ -3661,7 +3683,7 @@ async def send_frame_to_s3_mq(loop, upload_executor, task_id, mqtt, mqtt_topic, "longitude": cam_longitude, "latitude": cam_latitude }, - "invade_switch": 0, #默认 + "invade_switch": 0, # 默认 "count_message": count_message, "high_count_warn": { "high_count_warn_status": high_count_warn_status, @@ -4251,7 +4273,7 @@ async def start_rtmp_processing(video_url: str, task_id: str, model_configs: Lis invade_queue, event_queue, device_height, - repeat_dis, repeat_time + repeat_dis, repeat_time, high_count_warn ), name="cal_des_invade" ) @@ -4552,7 +4574,7 @@ async def start_video_processing(minio_path: str, task_id: str, model_configs: L invade_queue, event_queue, device_height, - repeat_dis, repeat_time + repeat_dis, repeat_time, -1 ), name="cal_des_invade" ) @@ -4562,7 +4584,8 @@ async def start_video_processing(minio_path: str, task_id: str, model_configs: L for _ in range(2): upload_task = asyncio.create_task( send_frame_to_s3_mq(task_id, mqtt, mqtt_topic, - cancel_flag, cv_frame_queue, event_queue, device_height, repeat_dis, repeat_time), + cancel_flag, cv_frame_queue, event_queue, device_height, repeat_dis, repeat_time, + -1), name=f"send_frame_to_s3_mq_{_}" ) upload_tasks.append(upload_task) diff --git a/yolo_api.py b/yolo_api.py index 9b88e4e..f21d435 100644 --- a/yolo_api.py +++ b/yolo_api.py @@ -718,9 +718,14 @@ async def run_back_Multi_Detect_async(request, request_json, stop_event: asyncio invade = request_json.content_body.invade invade_file = invade["invade_file"] camera_para_url = invade["camera_para_url"] - invade_switch = 0 - if invade["invade_switch"] is not None: + + if high_count_warn is None: + high_count_warn=0 + + if "invade_switch" in invade: invade_switch = invade["invade_switch"] + else: + invade_switch = 0 # 或其他默认值 # dao.get_mqtt_config_by_orgcode(org_code,) str_request = str(request) + "&" + str(request.socket) # 待测试,看看公网能不能捕获到请求端ip dao.insert_request_log(task_id, sn, org_code, str(request.body), str_request) @@ -954,9 +959,10 @@ async def run_back_Video_Multi_Detect_async(request, request_json): invade = request_json.content_body.invade invade_file = invade["invade_file"] camera_para_url = invade["camera_para_url"] - invade_switch = 0 - if invade["invade_switch"] is not None: + if "invade_switch" in invade: invade_switch = invade["invade_switch"] + else: + invade_switch = 0 # 或其他默认值 await start_video_processing(minio_file_path, task_id, model_configs, mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic, push_url, invade_enable,invade_switch, invade_file, camera_para_url, device_height, repeat_dis,