import asyncio import json import cv2 import subprocess from threading import Thread, Lock, Event import time import queue import numpy as np import datetime import os from ultralytics import YOLO # 导入 Ultralytics YOLO 模型 from func_aimodellist.func_1x import func_100004, func_100000, func_100021 from middleware.minio_util import upload_file # from Ai_tottle.minio_oss import upload_file from mqtt_pub import MQTTClient from ultralytics import solutions import torch from PIL import Image, ImageDraw, ImageFont from yolo.yolo_multi_model import YOLOModel, convert_result_to_multi_result, multi_model_process_frame import av # 全局变量 ifAI = {'status': False} deskLock = Lock() frame_queue = queue.Queue(maxsize=60) # 增加帧缓冲队列大小 processed_frame_queue = queue.Queue(maxsize=30) # 处理后的帧队列 list_track_id = [] # 缓存track_id 进而方便做事件判断 stop_event = Event() mqtt_client = None # MQTT 代理地址和端口 # broker = "112.44.103.230" # 公共 MQTT 代理(免费) broker = "8.137.54.85" # 公共 MQTT 代理(免费) port = 1883 # MQTT 默认端口 # 主题 topic = "thing/product/ai/events" local_drc_message = None CHINESE_LABELS = {0: "人", 1: "车", 2: "车", 3: "猫", 4: "购", 5: "猫", 6: "猫", 7: "猫", 8: "猫", 9: "猫", 10: "猫"} fps = 30 # 创建 RTMP 输出容器(显式指定 format="flv") # container = av.open(rtmp_url, mode="w", format="flv") # 初始不设置宽高,等待第一帧获取尺寸 stream = None container = None first_frame_received = False # font_path="config/SIMSUN.TTC" 为中文配置文件路径,否则标签对中文支持不好,linux上要做特殊处理 def put_chinese_text(img, text, position, font_path="config/SIMSUN.TTC", font_size=20, color=(0, 255, 0)): # def put_chinese_text(img, text, position, font_path="simhei.ttf", font_size=20, color=(0, 255, 0)): """使用PIL库在图片上绘制中文""" img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) draw = ImageDraw.Draw(img_pil) try: font = ImageFont.truetype(font_path, font_size) except: font = ImageFont.load_default() draw.text(position, text, font=font, fill=color) return cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR) # 针对tracker的结果做统计输出 def cal_tricker_results(frame, results, class_names, func_id, local_func_cache, para, model_cls, chinese_label,int_func_id): global list_track_id """Draw detection results on the frame.""" list_func_id=[int_func_id] annotated_frame = frame.copy() result_boxes = results.boxes result_clss = results.clss result_conf = results.confs # print(f"result_clss {len(result_boxes)} {len(result_clss)} {len(result_conf)}") # print("cal_t 1") for i, box in enumerate(result_boxes): class_id = result_clss[i] conf = result_conf[i] if isinstance(box, (torch.Tensor, np.ndarray)): x1, y1, x2, y2 = map(int, box[:4]) else: x1, y1, x2, y2 = box[0], box[1], box[2], box[3] label = chinese_label.get(class_id, str(class_id)) text = f"{label} {conf:.2f}" # 绘制边界框 cv2.rectangle(annotated_frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2) # 计算文本位置(确保在框上方显示) text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0] text_width, text_height = text_size[0], text_size[1] # 调整文本位置到框上方 text_x = x1 text_y = y1 - 5 # 稍微向上移动5像素 # 如果文本超出图像顶部,则放在框内部下方 if text_y < 0: text_y = y2 + text_height + 5 # 使用PIL绘制中文标签(先绘制背景矩形增强可读性) # 创建背景矩形 bg_color = (0, 0, 0) # 黑色背景 padding = 2 annotated_frame = cv2.rectangle( annotated_frame, (int(text_x) - padding, int(text_y) - text_height - padding), (int(text_x) + text_width + padding, int(text_y) + padding), bg_color, -1 ) # 再绘制文本(使用PIL绘制中文) temp_img = annotated_frame.copy() annotated_frame = put_chinese_text( temp_img, text, (text_x, text_y - text_height), # PIL的坐标系不同,需要调整 font_size=15, color=(0, 255, 0) ) box_result = [] # print("cal_t 2") # 当前是伪代码,针对人车识别模型 # human_pt={0: 'pedestrian', 1: 'people', 2: 'bicycle', 3: 'car', 4: 'van', 5: 'truck', 6: 'tricycle', 7: 'awning-tricycle', 8: 'bus', 9: 'motor'} # human_pt = {0: "行人", 1: "人", 2: "自行车", 3: "车", 4: "面包车", 5: "货车", 6: "三轮车", 7: "带篷三轮车", # 8: "巴士", 9: "摩托车"} # 接下来部分是功能筛选代码,针对100001、100002 功能代码,做部分逻辑判断 # 发现人 即报警 # 伪代码,暂时写死 if 100014 in list_func_id: list_func_id.append(100000) func_id_100000 = 100000 if func_id_100000 in list_func_id: # index = [0,1] # type_name_list=[] # for i in model_cls: # type_name_list.append(human_pt.get(i)) result_100001 = func_100000(results, model_cls, chinese_label, func_id_100000, list_track_id,func_id) if result_100001 is not None: # print(f"result_100001result_100001 {result_100001}") box_result.append(result_100001) # box_result["result_100001"] = result_100001 # print("cal_t 3") func_id_100002 = 100002 if func_id_100002 in list_func_id: # index = [2,3,4,5,6,7,8,9] # type_name_list=[] # for i in index: # type_name_list.append(human_pt.get(i)) # print("c 1") result_100002 = func_100000(results, model_cls, chinese_label, func_id_100002, list_track_id) # print("c 2") if result_100002 is not None: # print(f"result_100001result_100001 {result_100002}") box_result.append(result_100002) # box_result["result_100001"] = result_100001 func_id_100004 = 100004 if func_id_100004 in list_func_id: # index = [0,1] # type_name_list=[] local_cache = local_func_cache["func_100004"] # for i in index: # type_name_list.append(human_pt.get(i)) result_100004, local_cache = func_100004(results, model_cls, chinese_label, func_id_100004, list_track_id, local_cache) if result_100004 is not None: local_func_cache["func_100004"] = local_cache # 更新全局缓存 box_result.append(result_100004) # print(f"result_100004 {result_100004}") # print("cal_t 4") func_id_100006 = 100006 if func_id_100006 in list_func_id: # index = [2,3,4,5,6,7,8,9] # type_name_list=[] local_cache = local_func_cache["func_100006"] # for i in index: # type_name_list.append(human_pt.get(i)) result_100006, local_cache = func_100004(results, model_cls, chinese_label, func_id_100006, list_track_id, local_cache) if result_100006 is not None: local_func_cache["func_100006"] = local_cache # 更新全局缓存 box_result.append(result_100006) # print(f"result_100006 {result_100006}") # print("cal_t 5") func_id_100021 = 100021 if func_id_100021 in list_func_id: # index = [0,1] # type_name_list=[] number_n = para["N"] # rest接口中的N值,做人员聚集、车辆聚集的判断参数 # local_cache=local_func_cache["func_id_100021"] # for i in index: # type_name_list.append(human_pt.get(i)) result_100021 = func_100021(results, model_cls, chinese_label, func_id_100021, number_n) if result_100021 is not None: box_result.append(result_100021) # print(f"result_100021 {result_100021}") func_id_100023 = 100023 if func_id_100023 in list_func_id: # index = [2,3,4,5,6,7,8,9] # type_name_list=[] number_n = para["N"] # rest接口中的N值,做人员聚集、车辆聚集的判断参数 # local_cache=local_func_cache["func_100006"] # for i in index: # type_name_list.append(human_pt.get(i)) result_100023 = func_100021(results, model_cls, chinese_label, func_id_100023, number_n) if result_100023 is not None: # local_func_cache["func_100006"]=local_cache #更新全局缓存 box_result.append(result_100023) # print(f"result_100023 {result_100023}") # bridge_pt={0:"蜂窝",1:"剥落",2:"空腔",3:"锈蚀",4:"钢筋裸露",5:"裂缝"} func_id_100031 = 100031 if func_id_100031 in list_func_id: # index = [0,1,2,3,4,5] # type_name_list=[] # for i in index: # type_name_list.append(bridge_pt.get(i)) result_100031 = func_100000(results, model_cls, chinese_label, func_id_100031, list_track_id) if result_100031 is not None: # print(f"result_100031result_100031 {result_100031}") box_result.append(result_100031) # smoke_pt={0:"烟雾"} func_id_100041 = 100041 if func_id_100041 in list_func_id: # index = [0] # type_name_list=[] # for i in index: # type_name_list.append(smoke_pt.get(i)) result_100041 = func_100000(results, model_cls, chinese_label, func_id_100041, list_track_id) if result_100041 is not None: # print(f"result_100041result_100041 {result_100041}") box_result.append(result_100041) # hwrc_pt={0:"人",1:"卡车",2:"汽车",3:"自行车"} func_id_100061 = 100061 if func_id_100061 in list_func_id: # index = [0,1,2,3] # type_name_list=[] # for i in index: # type_name_list.append(hwrc_pt.get(i)) result_100061 = func_100000(results, model_cls, chinese_label, func_id_100061, list_track_id) if result_100061 is not None: # print(f"result_100061result_100061 {result_100061}") box_result.append(result_100061) # hwgf_pt={0:"单一热班",1:"大面积热斑",2:"单一热班&异常低温",3:"大面积热班&异常低温",4:"异常低温",5:"单一热斑&遮挡",6:"异常低温&热斑",7:"二极管短路",8:"阳光反射",9:"二极管短路&异常低温"} func_id_100081 = 100081 if func_id_100081 in list_func_id: # index = [0,1,2,3,4,5,6,7,8,9] # type_name_list=[] # for i in index: # type_name_list.append(hwgf_pt.get(i)) result_100081 = func_100000(results, model_cls, chinese_label, func_id_100081, list_track_id) if result_100081 is not None: # print(f"result_100081result_100081 {result_100081}") box_result.append(result_100081) # # # 缩放一下,好展示 # resized_img = cv2.resize(annotated_frame, (0, 0), fx=0.5, fy=0.5) # 按比例缩放 # cv2.imshow("annotated_frame",resized_img) # cv2.waitKey(0) # 等待任意按键,参数为延迟毫秒(0表示无限等待) # cv2.destroyAllWindows() # 关闭窗口 func_id_100008 = 100008 if func_id_100008 in list_func_id: # index = [0,1] # type_name_list=[] # for i in model_cls: # type_name_list.append(human_pt.get(i)) result_100008 = func_100000(results, model_cls, chinese_label, func_id_100008, list_track_id) if result_100008 is not None: # print(f"result_100001result_100001 {result_100001}") box_result.append(result_100008) func_id_100052 = 100052 if func_id_100052 in list_func_id: # index = [0,1] # type_name_list=[] # for i in model_cls: # type_name_list.append(human_pt.get(i)) result_100052 = func_100000(results, model_cls, chinese_label, func_id_100052, list_track_id) if result_100052 is not None: # print(f"result_100001result_100001 {result_100001}") box_result.append(result_100052) func_id_100091 = 100091 if func_id_100091 in list_func_id: # index = [0,1] # type_name_list=[] # for i in model_cls: # type_name_list.append(human_pt.get(i)) result_100091 = func_100000(results, model_cls, chinese_label, func_id_100091, list_track_id) if result_100091 is not None: # print(f"result_100001result_100001 {result_100001}") box_result.append(result_100091) func_id_100092 = 100092 if func_id_100092 in list_func_id: # index = [0,1] # type_name_list=[] # for i in model_cls: # type_name_list.append(human_pt.get(i)) result_100092 = func_100000(results, model_cls, chinese_label, func_id_100092, list_track_id) if result_100092 is not None: # print(f"result_100001result_100001 {result_100001}") box_result.append(result_100092) return annotated_frame, box_result # 输入为检测结果,输出为检测分类,统计后方便前端使用 def extract_box_details(last_results, model_id): """ 从 YOLO 检测结果中提取边界框的详细信息 :param last_results: YOLO 检测结果对象 :param model_id: 模型id :return: 包含边界框事件和数量的字典 """ # box_event = [] box_count = [] count_dict = {} human_pt = {0: "行人", 1: "人", 2: "自行车", 3: "车", 4: "面包车", 5: "货车", 6: "三轮车", 7: "带篷三轮车", 8: "巴士", 9: "摩托车"} for box in last_results.boxes: cls_index = int(box.cls[0]) # 类别索引 if cls_index in count_dict: count_dict[cls_index] += 1 else: count_dict[cls_index] = 1 for key, value in count_dict.items(): type_name = human_pt.get(key) box_count.append({ "type": key, "type_name": type_name, # 伪代码,后续需要修改 "count": value }) box_detail = { "model_id": model_id, "box_count": box_count } return box_detail def setIfAI(pb1): deskLock.acquire() ifAI['status'] = pb1 deskLock.release() def getIfAI(): return ifAI['status'] def stopAIVideo(): print("正在停止AI视频处理...") setIfAI(False) stop_event.set() # 等待足够长的时间确保资源释放 wait_count = 0 max_wait = 5 # 减少最大等待时间到5秒 while stop_event.is_set() and wait_count < max_wait: time.sleep(0.5) wait_count += 1 if wait_count >= max_wait: print("警告: 停止AI视频处理超时,强制终止") # 不使用_thread._interrupt_main(),改用其他方式强制终止 try: # 尝试终止可能运行的进程 import os import signal import psutil # 查找并终止可能的FFmpeg进程 current_process = psutil.Process(os.getpid()) for child in current_process.children(recursive=True): try: child_name = child.name().lower() if 'ffmpeg' in child_name: print(f"正在终止子进程: {child.pid} ({child_name})") child.send_signal(signal.SIGTERM) except: pass except: pass else: print("AI视频处理已停止") def startBackAIVideo(task_id, video_path,push_url, m1, model_cls, chinese_label, list_func_id, confidence, para, mqtt_ip, mqtt_port, mqtt_topic): if ifAI['status']: stopAIVideo() time.sleep(1) stop_event.clear() thread = Thread(target=startBackAIVideo2, args=( task_id, video_path, push_url,m1, list_func_id, confidence, para, mqtt_ip, mqtt_port, mqtt_topic, model_cls, chinese_label)) # cls2_thread = Thread(target=cls2_find, args=(video_path,m1, cls, confidence)) # cls2_thread.daemon = True # 守护线程,主程序退出时线程也会退出 thread.daemon = True # 守护线程,主程序退出时线程也会退出 thread.start() def read_frames(cap,push_url, frame_queue): """优化的帧读取线程""" global first_frame_received,stream,container if container is None: container = av.open(push_url, mode="w", format="flv") frame_count = 0 last_time = time.time() last_fps_time = time.time() # 减小目标帧间隔时间,提高读取帧率 target_time_per_frame = 1.0 / 60.0 # 目标帧间隔时间(提高到60fps) # 添加连接断开检测 connection_error_count = 0 max_connection_errors = 10 # 最多允许连续10次连接错误 last_successful_read = time.time() max_read_wait = 60.0 # 30秒无法读取则认为连接断开 # 预先丢弃几帧,确保从新帧开始处理 for _ in range(5): cap.grab() retry_count = 0 max_retries = 3 while not stop_event.is_set() and retry_count < max_retries: try: # 确保container使用正确的push_url if container is None: try: container = av.open(push_url, mode="w", format="flv") print(f"成功打开RTMP推流: {push_url}") except Exception as e: print(f"无法打开RTMP推流地址: {push_url}, 错误: {e}") stop_event.set() return current_time = time.time() elapsed_time = current_time - last_time # 检查是否长时间无法读取帧 if current_time - last_successful_read > max_read_wait: print(f"超时警告: {max_read_wait}秒内未能读取到有效帧,可能连接已断开") stop_event.set() break # 帧率控制,但更积极地读取 if elapsed_time < target_time_per_frame: time.sleep(target_time_per_frame - elapsed_time) continue # print(f"4 {frame_queue.qsize()} {frame_queue.maxsize * 0.8}") # 当队列快满时,跳过一些帧以避免延迟累积 if frame_queue.qsize() > frame_queue.maxsize * 0.8: # 跳过一些帧 cap.grab() last_time = time.time() # 清空队列 while not frame_queue.empty(): try: frame_queue.get_nowait() # 非阻塞地获取并移除队列中的元素 except queue.Empty: break # 防止在极端情况下出现的竞争条件 # print(f"5 {frame_queue.qsize()} {frame_queue.maxsize * 0.8}") # print(f"引入solutions.ObjectCounter,单帧处理速度变慢,队列堆积后,直接清空") continue # print(6) ret, frame = cap.read() if not ret: print("拉流错误:无法读取帧") connection_error_count += 1 if connection_error_count >= max_connection_errors: print(f"连续{max_connection_errors}次无法读取帧,可能连接已断开,正在停止流程...") stop_event.set() break time.sleep(0.5) # 短暂等待后重试 continue # 成功读取了帧,重置错误计数 connection_error_count = 0 last_successful_read = time.time() frame_count += 1 if frame_count % 60 == 0: # 每60帧计算一次FPS current_fps_time = time.time() fps = 60 / (current_fps_time - last_fps_time) # print(f"拉流FPS: {fps:.2f}") last_fps_time = current_fps_time last_time = time.time() frame_queue.put((frame, time.time())) # 添加时间戳 height, width, channels = frame.shape # 如果是第一帧,初始化流参数 if not first_frame_received: stream = container.add_stream("libx264", rate=30) stream.width = width stream.height = height stream.pix_fmt = "yuv420p" stream.options = { "tune": "zerolatency", "preset": "ultrafast", "crf": "23", } first_frame_received = True print(f"Initialized stream with resolution: {width}x{height}") if first_frame_received: # 将 NumPy 数组转换为 AV 帧 av_frame = av.VideoFrame.from_ndarray(frame, format="bgr24") # 编码并推送帧 for packet in stream.encode(av_frame): container.mux(packet) except av.error.BrokenPipeError: print("RTMP 连接断开,尝试重连...") retry_count += 1 if container: container.close() container = None time.sleep(2) # 等待2秒后重试 except Exception as e: print(f"读取帧时发生错误: {e}") break def process_frames_tricker(push_url,frame_queue, processed_frame_queue, model_path, list_func_id, confidence, para, model_cls, chinese_label, use_fp16=False): """Count specific classes of objects in a video with dynamic processing adjustments.""" global first_frame_received,stream,container if container is None: container = av.open(push_url, mode="w", format="flv") # solutions默认会在图上绘制一个方框,很烦 line_points = [(-10, -10), (-10, -10)] # 伪代码,后续识别类型需要跟ai_model_list 的id关联 # cls = [0, 1, 2, 3, 4] # 这里参数show=True,一旦通过rest接口重启程序,会导致线程锁住 # counter = solutions.ObjectCounter(show=False, region=line_points, model=model_path, classes=cls) counter = solutions.ObjectCounter(show=False, region=line_points, model=model_path, classes=model_cls) # Model processing setup device = 'cuda:0' if torch.cuda.is_available() else 'cpu' try: counter.model.to(device) if hasattr(counter.model, 'args') and hasattr(counter.model.args, 'batch'): counter.model.args.batch = 1 # 启用半精度 if use_fp16 and device.startswith('cuda') and hasattr(counter.model, 'model'): try: counter.model.model = counter.model.model.half() print("✅ 启用半精度 FP16 模式") except Exception as half_err: print(f"⚠️ 半精度转换失败: {half_err}") else: print("ℹ️ 半精度模式未启用") except Exception as e: print(f"⚠️ 模型设备配置警告: {e}") frame_count = 0 process_times = [] last_results = None last_counter = None skip_counter = 0 max_skip = 2 error_count = 0 # Get class names from the model class_names = counter.model.names if hasattr(counter.model, 'names') else {i: str(i) for i in range(1000)} local_func_cache = { "func_100000": None, "func_100004": None, # 存储缓存,缓存人员track_id "func_100006": None # 存储缓存,缓存车辆track_id } while not stop_event.is_set(): try: if processed_frame_queue.qsize() >= processed_frame_queue.maxsize * 0.8: time.sleep(0.01) continue frame, timestamp = frame_queue.get(timeout=0.2) frame_copy = frame.copy() if time.time() - timestamp > 0.3: continue frame_count += 1 if skip_counter > 0 and last_counter is not None: skip_counter -= 1 annotated_frame, box_result = cal_tricker_results(frame, counter, class_names, list_func_id, local_func_cache, para, model_cls, chinese_label) continue process_start = time.time() resize_scale = 1.0 qsize = frame_queue.qsize() maxsize = frame_queue.maxsize if qsize > maxsize * 0.7: resize_scale = 0.4 elif qsize > maxsize * 0.5: resize_scale = 0.6 elif qsize > maxsize * 0.3: resize_scale = 0.8 if resize_scale < 1.0: process_frame = cv2.resize(frame, (0, 0), fx=resize_scale, fy=resize_scale) else: process_frame = frame if process_frame is None or process_frame.size == 0: print("警告: 输入帧无效") continue try: # Run object counting/tracking # 获取当前时间的时间戳,包含毫秒 # timestamp = time.time() # # 获取13位长度的时间戳(毫秒级) # time_start = int(timestamp * 1000) # frame_copy=process_frame # 暂存frame,方便进行结果重绘 timestart = time.time() results = counter(process_frame) timeend = time.time() print(f"timestart-timeend {timeend-timestart}") if results is None: # 处理返回None的情况 print("警告: 检测结果为None,使用上一帧结果") # results = getattr(counter, 'last_results', None) # if results is None: continue # timestamp = time.time() # # 获取13位长度的时间戳(毫秒级) # time_end = int(timestamp * 1000) # print(f"time_end - time_start {time_end - time_start}") # print("1") last_counter = counter # 基于计数的方法,识别后的结果存储在counter中 # print("2") annotated_frame, box_result = cal_tricker_results( frame_copy if resize_scale == 1.0 else cv2.resize(frame_copy, (frame.shape[1], frame.shape[0])), counter, class_names, list_func_id, local_func_cache, para, model_cls, chinese_label) # if annotated_frame is None: # print("annotated_frame is None") # # if box_result is None: # print("box_result is None") # print("3") processed_frame_queue.put((annotated_frame, timestamp, box_result)) # print_box_details(results) # print("=4") if qsize > maxsize * 0.5: skip_counter = max_skip except Exception as infer_err: print(f"推理错误1: {infer_err}") if last_counter is not None: annotated_frame, box_result = cal_tricker_results(frame, counter, class_names, list_func_id, local_func_cache, para, model_cls, chinese_label) else: annotated_frame = frame.copy() process_end = time.time() process_times.append(process_end - process_start) if len(process_times) > 30: process_times.pop(0) if frame_count % 30 == 0: avg_process_time = sum(process_times) / len(process_times) fps = 1.0 / avg_process_time if avg_process_time > 0 else 0 print( f"模型处理FPS: {fps:.2f}, 平均处理时间: {avg_process_time * 1000:.2f}ms, 队列大小: {qsize}, 缩放比例: {resize_scale:.2f}") processed_frame_queue.put((annotated_frame, timestamp, box_result)) error_count = 0 except queue.Empty: continue except Exception as e: print(f"处理帧错误: {e}") error_count += 1 if error_count >= 5: print(f"连续处理错误达到5次 ,正在停止处理...") break # finally: # # 清理资源 # for packet in stream.encode(): # container.mux(packet) # container.close() def multi_model_process_frames(frame_queue, processed_frame_queue, model_path, list_func_id, confidence, para, model_cls, chinese_label, use_fp16=False): """Count specific classes of objects in a video with dynamic processing adjustments.""" models = [ YOLOModel('pt/best.pt', cls_map={'pedestrian': '行人', 'pedestrian': '行人'}, allowed_classes=['pedestrian', 'people']) , YOLOModel('pt/best.pt', cls_map={'car': '车'}, allowed_classes=['car']) ] # 用于存储处理结果的字典 result_dict = {} frame_count = 0 processing_times = [] # 启动所有模型的工作线程 for model in models: model.start() local_func_cache = { "func_100000": None, "func_100004": None, # 存储缓存,缓存人员track_id "func_100006": None # 存储缓存,缓存车辆track_id } max_skip = 2 error_count = 0 while not stop_event.is_set(): try: if processed_frame_queue.qsize() >= processed_frame_queue.maxsize * 0.8: time.sleep(0.01) continue frame, timestamp = frame_queue.get(timeout=0.2) frame_copy = frame.copy() if time.time() - timestamp > 0.3: continue process_start = time.time() resize_scale = 1.0 qsize = frame_queue.qsize() maxsize = frame_queue.maxsize # if qsize > maxsize * 0.7: # resize_scale = 0.4 # elif qsize > maxsize * 0.5: # resize_scale = 0.6 # elif qsize > maxsize * 0.3: # resize_scale = 0.8 # # if resize_scale < 1.0: # process_frame = cv2.resize(frame, (0, 0), fx=resize_scale, fy=resize_scale) # else: # process_frame = frame # if process_frame is None or process_frame.size == 0: # print("警告: 输入帧无效") # continue # 初始化多个YOLO模型 try: timestart = time.time() # 处理帧(多模型并行) multi_model_process_frame(frame, models, result_dict) timeend1 = time.time() print(f"timeend1 {timeend1 - timestart}") # 获取处理结果 frame_id = id(frame) if frame_id in result_dict: results = result_dict.pop(frame_id) multi_result = convert_result_to_multi_result(results) timeend2 = time.time() print(f"timeend2 {timeend2 - timestart}") annotated_frame, box_result = cal_tricker_results(frame, multi_result, None, list_func_id,local_func_cache, para, model_cls, chinese_label) timeend3 = time.time() print(f"timeend3 {timeend3 - timestart}") time.sleep(1) timeend4 = time.time() print(f"timeend4 {timeend4 - timeend3}") # cv2.imshow('annotated_frame Frame', frame) # cv2.imwrite('output.jpg', frame) # # # # results = counter(process_frame) # # if results is None: # 处理返回None的情况 # # print("警告: 检测结果为None,使用上一帧结果") # # # results = getattr(counter, 'last_results', None) # # # if results is None: # # continue # # timestamp = time.time() # # # 获取13位长度的时间戳(毫秒级) # # time_end = int(timestamp * 1000) # # print(f"time_end - time_start {time_end - time_start}") # # print("1") # last_counter = counter # 基于计数的方法,识别后的结果存储在counter中 # # print("2") # annotated_frame, box_result = cal_tricker_results( # frame_copy if resize_scale == 1.0 else cv2.resize(frame_copy, (frame.shape[1], frame.shape[0])), # counter, class_names, list_func_id,local_func_cache,para,model_cls,chinese_label) # # # if annotated_frame is None: # # print("annotated_frame is None") # # # # if box_result is None: # # print("box_result is None") # print("3") # processed_frame_queue.put((annotated_frame, timestamp, box_result)) # print_box_details(results) # print("=4") if qsize > maxsize * 0.5: skip_counter = max_skip except Exception as infer_err: print(f"推理错误1: {infer_err}") # if last_counter is not None: # annotated_frame, box_result = cal_tricker_results(frame, counter, class_names, list_func_id,local_func_cache,para,model_cls,chinese_label) # else: # annotated_frame = frame.copy() # process_end = time.time() # process_times.append(process_end - process_start) # if len(process_times) > 30: # process_times.pop(0) # if frame_count % 30 == 0: # avg_process_time = sum(process_times) / len(process_times) # fps = 1.0 / avg_process_time if avg_process_time > 0 else 0 # print( # f"模型处理FPS: {fps:.2f}, 平均处理时间: {avg_process_time * 1000:.2f}ms, 队列大小: {qsize}, 缩放比例: {resize_scale:.2f}") # # processed_frame_queue.put((annotated_frame, timestamp, box_result)) error_count = 0 except queue.Empty: continue except Exception as e: print(f"处理帧错误: {e}") error_count += 1 if error_count >= 5: print(f"连续处理错误达到5次 ,正在停止处理...") break # # def process_frames(model_id, frame_queue, processed_frame_queue, ov_model, cls, confidence, use_fp16=False): # """处理帧的线程,添加帧率控制,支持半精度FP16""" # # import torch # import time # import queue # import cv2 # # error_count = 0 # 添加错误计数器 # max_errors = 5 # 最大容许错误次数 # frame_count = 0 # process_times = [] # 用于计算平均处理时间 # # # 设置YOLO模型配置,提高性能 # ov_model.conf = confidence # 设置置信度阈值 # # # 将模型移到设备(GPU或CPU) # device = 'cuda:0' if torch.cuda.is_available() else 'cpu' # try: # ov_model.to(device) # # 调整批处理大小为1,减少内存占用 # if hasattr(ov_model, 'args') and hasattr(ov_model.args, 'batch'): # ov_model.args.batch = 1 # # 启用半精度 # if use_fp16 and device.startswith('cuda') and hasattr(ov_model, 'model'): # try: # ov_model.model = ov_model.model.half() # print("✅ 启用半精度 FP16 模式") # except Exception as half_err: # print(f"⚠️ 半精度转换失败: {half_err}") # else: # print("ℹ️ 半精度模式未启用") # except Exception as e: # print(f"⚠️ 模型设备配置警告: {e}") # # # 缓存先前的检测结果,用于提高稳定性 # last_results = None # skip_counter = 0 # max_skip = 2 # 最多跳过几帧不处理 # # while not stop_event.is_set(): # try: # if processed_frame_queue.qsize() >= processed_frame_queue.maxsize * 0.8: # time.sleep(0.01) # continue # # frame, timestamp = frame_queue.get(timeout=0.2) # # if time.time() - timestamp > 0.3: # continue # # frame_count += 1 # # if skip_counter > 0 and last_results is not None: # skip_counter -= 1 # annotated_frame = last_results.plot(conf=False, line_width=1, font_size=1.5) # # box_detail = extract_box_details(last_results, model_id) # # processed_frame_queue.put((annotated_frame, timestamp, box_detail)) # continue # # process_start = time.time() # # resize_scale = 1.0 # qsize = frame_queue.qsize() # maxsize = frame_queue.maxsize # # if qsize > maxsize * 0.7: # resize_scale = 0.4 # elif qsize > maxsize * 0.5: # resize_scale = 0.6 # elif qsize > maxsize * 0.3: # resize_scale = 0.8 # # if resize_scale < 1.0: # process_frame = cv2.resize(frame, (0, 0), fx=resize_scale, fy=resize_scale) # else: # process_frame = frame # # try: # results = ov_model(process_frame, classes=cls, show=False) # last_results = results[0] # # if resize_scale < 1.0: # annotated_frame = cv2.resize(results[0].plot(conf=False, line_width=1, font_size=1.5), # (frame.shape[1], frame.shape[0])) # # print("results[0].boxes.cls ", results[0].boxes.cls) # # else: # annotated_frame = results[0].plot(conf=False, line_width=1, font_size=1.5) # # box_detail = extract_box_details(last_results, model_id) # # if qsize > maxsize * 0.5: # skip_counter = max_skip # except Exception as infer_err: # print(f"推理错误: {infer_err}") # box_detail = None # if last_results is not None: # annotated_frame = last_results.plot(conf=False, line_width=1, font_size=1.5) # else: # annotated_frame = frame.copy() # # process_end = time.time() # process_times.append(process_end - process_start) # if len(process_times) > 30: # process_times.pop(0) # # if frame_count % 30 == 0: # avg_process_time = sum(process_times) / len(process_times) # fps = 1.0 / avg_process_time if avg_process_time > 0 else 0 # print( # f"模型处理FPS: {fps:.2f}, 平均处理时间: {avg_process_time * 1000:.2f}ms, 队列大小: {qsize}, 缩放比例: {resize_scale:.2f}") # # processed_frame_queue.put((annotated_frame, timestamp, box_detail)) # error_count = 0 # except queue.Empty: # continue # except Exception as e: # error_count += 1 # print(f"处理帧错误: {e}") # if error_count >= max_errors: # print(f"连续处理错误达到{max_errors}次,正在停止处理...") # stop_event.set() # break # # 读取mqtt的drc消息,并且在本地维持一个最新消息 # def read_drc_mqtt(mqtt_client): # global local_drc_message # while not stop_event.is_set(): # time.sleep(0.05) # # 更新全局缓存,如果要同时适配多个飞机,这里要改为队列 # mess = mqtt_client.get_messages(timeout=0.5) # if mess is not None: # # local_drc_message = mess # local_drc_message = json.loads(mess) # 解析为字典 # def read_drc_mqtt(mqtt_client): # global local_drc_message # while not stop_event.is_set(): # time.sleep(0.05) # try: # # 确保 mess 是字符串 # mess = mqtt_client.get_messages(timeout=0.5) # if mess is not None: # try: # local_drc_message = json.loads(mess) # 确保 mess 是字符串 # except json.JSONDecodeError as e: # print(f"JSON 解析错误: {e}") # except Exception as e: # print(f"读取 DRC 消息错误: {e}") def get_local_drc_message(): global local_drc_message if local_drc_message is not None: mess =local_drc_message local_drc_message=None return mess return None from asyncio_mqtt import Client async def async_read_drc_mqtt(mqtt_ip, mqtt_port, mqtt_topic): global local_drc_message async with Client(mqtt_ip, mqtt_port) as client: async with client.messages() as messages: await client.subscribe(mqtt_topic) async for message in messages: if message.topic.matches(mqtt_topic): try: data = json.loads(message.payload) if (data and isinstance(data.get("data"), dict) and "attitude_head" in data["data"]): local_drc_message = data except Exception as e: logger.error(f"Error processing MQTT message: {e}") def read_drc_mqtt(mqtt_client): global local_drc_message while not stop_event.is_set(): time.sleep(0.05) # 更新全局缓存,如果要同时适配多个飞机,这里要改为队列 mess = mqtt_client.get_messages(timeout=0.5) if mess is not None: local_drc_message = mess def write_frames(processed_frame_queue, pipe, size, mqtt_client, task_id): # # 测试用做,输出识别后的图片到窗口 # cv2.namedWindow("RTMP Stream with YOLO", cv2.WINDOW_NORMAL) # cv2.resizeWindow("RTMP Stream with YOLO", 900, 600) global local_drc_message """写入帧的线程,添加平滑处理""" last_write_time = time.time() target_time_per_frame = 1.0 / 30.0 # 30fps pipe_error_count = 0 # 添加错误计数 max_pipe_errors = 3 # 最大容忍错误数 frame_count = 0 last_fps_time = time.time() skipped_frames = 0 # 使用队列存储最近几帧,用于在需要时进行插值 recent_frames = [] max_recent_frames = 5 # 增加缓存帧数量,提高平滑性 # 使用双缓冲机制提高写入速度 buffer1 = bytearray(size[0] * size[1] * 3) buffer2 = bytearray(size[0] * size[1] * 3) current_buffer = buffer1 # 帧率控制参数 min_frame_interval = target_time_per_frame * 0.5 # 允许的最小帧间隔 max_frame_interval = target_time_per_frame * 2.0 # 允许的最大帧间隔 while not stop_event.is_set(): try: # 获取处理后的帧,超时时间较短以便更平滑地处理 frame, timestamp, box_detail = processed_frame_queue.get(timeout=0.05) # 存储最近的帧用于插值 recent_frames.append(frame) if len(recent_frames) > max_recent_frames: recent_frames.pop(0) current_time = time.time() elapsed = current_time - last_write_time # 如果两帧间隔太短,考虑合并或跳过 if elapsed < min_frame_interval and len(recent_frames) >= 2: skipped_frames += 1 continue # 如果两帧间隔太长,进行插值平滑 if elapsed > max_frame_interval and len(recent_frames) >= 2: # 计算需要插入的帧数 frames_to_insert = min(int(elapsed / target_time_per_frame), 3) for i in range(frames_to_insert): # 创建插值帧 weight = (i + 1) / (frames_to_insert + 1) interpolated_frame = cv2.addWeighted(recent_frames[-2], 1 - weight, recent_frames[-1], weight, 0) # 切换双缓冲 current_buffer = buffer2 if current_buffer is buffer1 else buffer1 # 高效调整大小并写入 interpolated_resized = cv2.resize(interpolated_frame, size, interpolation=cv2.INTER_LINEAR) # 在窗口输出图像 # cv2.imshow("RTMP Stream with YOLO", interpolated_resized) # img_bytes = interpolated_resized.tobytes() # # # 写入管道 # pipe.stdin.write(img_bytes) # pipe.stdin.flush() # 切换双缓冲 current_buffer = buffer2 if current_buffer is buffer1 else buffer1 # 正常写入当前帧 - 使用高效的调整大小方法 resized_frame = cv2.resize(frame, size, interpolation=cv2.INTER_LINEAR) # cv2.imshow() # cv2.imshow("RTMP Stream with YOLO", resized_frame) frame_time = time.time_ns() date_str = datetime.datetime.now().strftime("%Y%m%d") # 避免变量覆盖 local_path = f"saveImg/output-{frame_time}.jpg" frame_count = frame_count + 1 # # 用作本地测试,保存图像;测试逻辑,随机暂时几张 if frame_count % 10 == 0: success = cv2.imwrite(local_path, resized_frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90]) if success: print(f"图像已保存到: {frame_count}") # 上传文件到 Minio minio_path, file_type = upload_file(local_path, date_str) minio_path, file_type = upload_file(local_path, date_str) message = { "task_id": task_id, "minio": { "minio_path": minio_path, "file_type": file_type }, "box_detail": box_detail, "uav_location": local_drc_message } message_json = json.dumps(message, indent=4, ensure_ascii=False) # mqtt_client.put_message(str(message)) mqtt_client.publish_message(message_json) else: print("错误: 图像保存失败") if os.path.exists(local_path): try: os.remove(local_path) print(f"文件 {local_path} 已删除") except Exception as e: print(f"删除失败: {e}") # img_bytes = resized_frame.tobytes() # pipe.stdin.write(img_bytes) # pipe.stdin.flush() # # frame_count += 1 if frame_count % 30 == 0: current_fps_time = time.time() fps = 30 / (current_fps_time - last_fps_time) print(f"推流FPS: {fps:.2f}, 跳过的帧: {skipped_frames}, 队列大小: {processed_frame_queue.qsize()}") last_fps_time = current_fps_time skipped_frames = 0 last_write_time = time.time() pipe_error_count = 0 # 成功写入后重置错误计数 except queue.Empty: # 队列为空且有足够的最近帧时,考虑生成插值帧以保持流畅 if len(recent_frames) >= 2 and time.time() - last_write_time > target_time_per_frame: try: # 创建插值帧 interpolated_frame = cv2.addWeighted(recent_frames[-2], 0.5, recent_frames[-1], 0.5, 0) # 切换双缓冲 current_buffer = buffer2 if current_buffer is buffer1 else buffer1 resized_frame = cv2.resize(interpolated_frame, size, interpolation=cv2.INTER_LINEAR) # img_bytes = resized_frame.tobytes() # pipe.stdin.write(img_bytes) # pipe.stdin.flush() last_write_time = time.time() except Exception: pass continue except Exception as e: print(f"写入帧错误: {e}") pipe_error_count += 1 if pipe_error_count >= max_pipe_errors: print("FFmpeg管道错误过多,正在终止进程...") stop_event.set() # 主动结束所有线程 break # # def cls2_find(video_path, m1, cls, confidence): # try: # ov_model = YOLO(m1) # # # ---------------------------------MinIO 存储路径(用于后续上传)------------------------------------ # minio_path = "AIResults" # # # -------------------------------获取当前日期,用于存储图像目录-------------------------------------- # date_str = datetime.datetime.now().strftime("%Y%m%d") # save_dir = f"{date_str}" # if not os.path.exists(save_dir): # os.makedirs(save_dir) # # # 打开视频流 # cap = cv2.VideoCapture(video_path) # if not cap.isOpened(): # print("Error: Could not open video.") # return # # # 获取视频的帧率 (fps) # fps = cap.get(cv2.CAP_PROP_FPS) # # # -----------------------------根据模型设置类别------------------------------------------------- # if m1 == "gdaq.pt": # 仅当使用 gdaq.pt 时,保存类别 2 和 4 # cls2 = [2, 4] # elif m1 == "best.pt": # 仅当使用 best.pt 时,保存类别 0 # cls2 = [0] # else: # 其它模型不保存 # cls2 = [] # # # ------------------------------------------cls2检测-------------------------------------------- # skip_frames = int(fps * 10) # 设置跳过帧数为 10 秒 # # while cap.isOpened() and not stop_event.is_set(): # if skip_frames > 0: # skip_frames -= 1 # 逐帧减少 # cap.grab() # 仅抓取帧,不进行解码 # continue # 跳过处理 # # ret, frame = cap.read() # if not ret: # break # 无法读取帧时退出 # # # 目标检测 # results = ov_model(frame, conf=confidence, classes=cls, show=False) # # for result in results: # for box in result.boxes: # cls_index = int(box.cls[0]) # 获取类别索引 # # # 如果检测到的类别在 cls2 里,跳过 10 秒 # if cls_index in cls2: # skip_frames = int(fps * 10) # 设置跳过帧数为 10 秒 # # upload_and_insert_to_db(frame, ov_model, cls_index, save_dir, minio_path) # filename = f"{save_dir}/frame_{int(cap.get(cv2.CAP_PROP_POS_FRAMES))}_cls2.jpg" # cv2.imwrite(filename, frame) # print(f"保存图像: {filename}") # # # 绘制检测框 # x1, y1, x2, y2 = map(int, box.xyxy[0]) # label = f"{result.names[cls_index]} {box.conf[0]:.2f}" # cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 0), 2) # cv2.putText(frame, label, (x1, y1 - 10), # cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2) # except Exception as e: # print(f"cls2_find错误: {e}") # finally: # if 'cap' in locals() and cap is not None: # cap.release() # 后台计算的主要逻辑,识别图像,进而将事件通过mqtt转发、minio存储 # 1、识别视频流中的事件,进而将事件的前后帧做存储 # 2、事件基于mqtt 做转发,mqtt可以包含事件,也可以不包含事件 def startBackAIVideo2(task_id, video_path,push_url, m1, list_func_id, confidence, para, mqtt_ip, mqtt_port, mqtt_topic, model_cls, chinese_label): global mqtt_client, local_drc_message setIfAI(True) cap = None read_thread = None process_thread = None write_thread = None read_mqtt_drc_message = None ov_model = None local_drc_message = None # 清空数据 try: # 设置环境变量,提高YOLO性能 os.environ["OMP_NUM_THREADS"] = "4" os.environ["CUDA_VISIBLE_DEVICES"] = "0" # 导入必要的库 import torch print(f"PyTorch 可用: {torch.__version__}, CUDA可用: {torch.cuda.is_available()}") if torch.cuda.is_available(): print(f"GPU: {torch.cuda.get_device_name(0)}") torch.cuda.empty_cache() # 预加载YOLO模型 print("预加载YOLO模型...") max_retries = 3 retry_count = 0 model_params = {} # 尝试检测YOLO版本以适配不同版本的参数 try: # 先简单尝试加载模型,不带任何参数 test_model = YOLO(m1) # 如果成功,检查可用的参数 if hasattr(test_model, "task"): model_params["task"] = "detect" # 指定任务类型 # 检查是否支持half精度 if torch.cuda.is_available(): model_params["half"] = True # 检查是否支持verbose参数 import inspect if "verbose" in inspect.signature(YOLO.__init__).parameters: model_params["verbose"] = False print(f"检测到支持的YOLO参数: {model_params}") except Exception as e: print(f"参数检测失败,将使用默认参数: {e}") model_params = {} while retry_count < max_retries: try: ov_model = YOLO(m1, **model_params) dummy_frame = np.zeros((1080, 1920, 3), dtype=np.uint8) for _ in range(3): # 伪代码,后续需要关联ai_model_list 的id cls = [0] ov_model(dummy_frame, classes=cls, conf=confidence, show=False) print("YOLO模型加载成功并预热完成") break except Exception as e: retry_count += 1 print(f"YOLO模型加载失败 (尝试 {retry_count}/{max_retries}): {e}") if "got an unexpected keyword argument" in str(e) and model_params: param_name = str(e).split("'")[-2] if param_name in model_params: del model_params[param_name] time.sleep(2) if ov_model is None: raise Exception("无法加载YOLO模型,达到最大重试次数") # 优化模型推理参数 ov_model.conf = confidence # 打开视频流 print(f"正在连接视频流: {video_path}") cap = cv2.VideoCapture(video_path, cv2.CAP_FFMPEG) cap.set(cv2.CAP_PROP_OPEN_TIMEOUT_MSEC, 60000) cap.set(cv2.CAP_PROP_READ_TIMEOUT_MSEC, 50000) if not cap.isOpened(): print("无法打开视频流,尝试强制释放资源...") cap.release() time.sleep(2) raise Exception(f"无法打开视频流: {video_path}") # 设置全局帧队列大小 global frame_queue, processed_frame_queue frame_queue = queue.Queue(maxsize=80) processed_frame_queue = queue.Queue(maxsize=40) # 声明mqtt 对象 mqtt_client = MQTTClient(broker, port, topic) # 与机场维持消息通讯,进而读取最新一条drc消息 mqtt_client_drc = MQTTClient(mqtt_ip, mqtt_port, mqtt_topic) # 创建并启动线程 read_thread = Thread(target=read_frames, args=(cap,push_url, frame_queue)) # process_frames_tricker # 读取队列中的frame,并且对frame做二次计算 process_thread = Thread(target=process_frames_tricker, args=(push_url,frame_queue, processed_frame_queue, ov_model, list_func_id, confidence,para,model_cls,chinese_label)) # process_thread = Thread(target=multi_model_process_frames, # args=( # frame_queue, processed_frame_queue, ov_model, list_func_id, confidence, para, # model_cls, # chinese_label)) # process_thread = Thread(target=process_frames, # args=(model_id, frame_queue, processed_frame_queue, ov_model, cls, confidence)) write_thread = Thread(target=write_frames, args=(processed_frame_queue, None, (1280, 720), mqtt_client, task_id)) # pipe=None read_mqtt_drc_message = Thread(target=read_drc_mqtt, args=(mqtt_client_drc,)) # pipe=None # consumer = Thread(target=consumer_thread, args=(mqtt_client,)) read_thread.daemon = True process_thread.daemon = True write_thread.daemon = True read_mqtt_drc_message.daemon = True # consumer.daemon = True print("开始处理视频流...") read_thread.start() process_thread.start() write_thread.start() read_mqtt_drc_message.start() # consumer.start() # 等待线程结束 while getIfAI() and not stop_event.is_set(): if not (read_thread.is_alive() and process_thread.is_alive() and write_thread.is_alive() and read_mqtt_drc_message.is_alive() # and consumer.is_alive() ): print(f"read_thread.is_alive() {read_thread.is_alive()} " f"process_thread.is_alive() {process_thread.is_alive()} " f"write_thread.is_alive() {write_thread.is_alive()} ") # f"consumer.is_alive() {consumer.is_alive()} ") print("检测到某个线程已停止运行,正在终止所有线程...") stop_event.set() break time.sleep(0.1) except Exception as e: print(f"发生错误: {e}") finally: print("正在清理资源...") stop_event.set() setIfAI(False) # 安全关闭MQTT客户端 if mqtt_client is not None: try: mqtt_client.close() except Exception as e: print(f"关闭MQTT客户端时出错: {str(e)}") if (read_thread and process_thread and write_thread ): timeout = 3 start_time = time.time() while (read_thread.is_alive() or process_thread.is_alive() or write_thread.is_alive() or read_mqtt_drc_message.is_alive() # or consumer.is_alive() ) and time.time() - start_time < timeout: time.sleep(0.1) if cap is not None: try: cap.release() print("视频捕获资源已释放") except: pass try: cv2.destroyAllWindows() except: pass print("所有资源已清理完毕") if __name__ == '__main__': sn = "1581F6QAD243C00BP71E" # video_path = f"rtmp://112.44.103.230:1935/live/{sn}" video_path = "rtmp://112.44.103.230:1935/live/123456" # FFmpeg 推流地址 # rtmp = f"rtmp://112.44.103.230:1935/live/{sn}ai" mq_ts = time.time_ns() # 时间戳,标识唯一消息 try: startBackAIVideo(mq_ts, video_path, "best.pt", [0, 1, 2, 3, 4], 0.4) time.sleep(60) except KeyboardInterrupt: print("程序被用户中断") stopAIVideo() except Exception as e: print(f"程序异常: {e}") # import cv2 # # print(cv2.getBuildInformation())