import ctypes
import datetime
import inspect
import json
import os
import time

import cv2
import numpy as np
import torch
from ultralytics import YOLO, solutions

from cv_back_video import extract_box_details, cal_tricker_results
from middleware.minio_util import upload_file, downFile
from middleware.read_pic_metadata import read_dji_exif_to_dict
from mqtt_pub import MQTTClient
from yolo.detect.yolo11_det_pic_trt import YoLo11TRT

# Module-level MQTT client; (re)created by each detection entry point below.
mqtt_client = None

# MQTT broker address and port (an alternative broker used previously: 112.44.103.230).
broker = "8.137.54.85"
port = 1883  # default MQTT port
# Topic for AI event messages.
topic = "thing/product/ai/events"

# Directory where annotated images are written before being uploaded.
_OUTPUT_DIR = "output_frames"
# pic_detect_func stops after this many consecutive per-image processing failures.
_MAX_CONSECUTIVE_ERRORS = 5


def _to_jsonable(value):
    """Recursively convert numpy/torch containers and scalars into JSON-native Python types.

    Plain Python values (dict/list/str/int/float/bool/None) pass through unchanged,
    tuples become lists (exactly as ``json.dumps`` would render them).
    """
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().tolist()
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, dict):
        # numpy scalar keys are not accepted by json.dumps, so unwrap them too.
        return {
            (key.item() if isinstance(key, np.generic) else key): _to_jsonable(item)
            for key, item in value.items()
        }
    if isinstance(value, (list, tuple)):
        return [_to_jsonable(item) for item in value]
    return value


def _yolo_init_kwargs():
    """Return the optional keyword arguments that this ultralytics ``YOLO`` constructor accepts.

    Only constructor parameters are considered; inference options such as ``half``
    belong to predict calls and are deliberately not passed here.
    """
    try:
        accepted = inspect.signature(YOLO.__init__).parameters
    except (TypeError, ValueError) as e:
        print(f"参数检测失败,将使用默认参数: {e}")
        return {}
    candidates = {"task": "detect", "verbose": False}
    model_params = {name: val for name, val in candidates.items() if name in accepted}
    print(f"检测到支持的YOLO参数: {model_params}")
    return model_params


def _load_yolo_model(model_path, classes, confidence, max_retries=3, retry_delay=2):
    """Load a YOLO model from ``model_path`` and warm it up with a few dummy inferences.

    A model is returned only if both loading and warm-up succeed.

    Raises:
        RuntimeError: if no attempt out of ``max_retries`` succeeds.
    """
    model_params = _yolo_init_kwargs()
    for attempt in range(1, max_retries + 1):
        try:
            model = YOLO(model_path, **model_params)
            dummy_frame = np.zeros((1080, 1920, 3), dtype=np.uint8)
            for _ in range(3):
                model(dummy_frame, classes=classes, conf=confidence, show=False)
            print("YOLO模型加载成功并预热完成")
            return model
        except Exception as e:
            print(f"YOLO模型加载失败 (尝试 {attempt}/{max_retries}): {e}")
            # Fallback for ultralytics versions that reject one of our kwargs: drop it and retry.
            if "got an unexpected keyword argument" in str(e) and model_params:
                parts = str(e).split("'")
                if len(parts) >= 2:
                    model_params.pop(parts[-2], None)
            if attempt < max_retries:
                time.sleep(retry_delay)
    raise RuntimeError("无法加载YOLO模型,达到最大重试次数")


def _configure_model_device(model, device, use_fp16):
    """Move ``model`` to ``device``, set batch size 1 and optionally convert it to FP16.

    Best effort: failures are logged and otherwise ignored, so inference can still
    proceed with the default configuration.
    """
    try:
        model.to(device)
        # Batch size 1 keeps memory usage low.
        if hasattr(model, 'args') and hasattr(model.args, 'batch'):
            model.args.batch = 1
        if use_fp16 and device.startswith('cuda') and hasattr(model, 'model'):
            try:
                model.model = model.model.half()
                print("✅ 启用半精度 FP16 模式")
            except Exception as half_err:
                print(f"⚠️ 半精度转换失败: {half_err}")
        else:
            print("ℹ️ 半精度模式未启用")
    except Exception as e:
        print(f"⚠️ 模型设备配置警告: {e}")


def _save_upload_and_publish(client, task_id, image_path, annotated_frame, box_detail, extra_fields=None):
    """Save ``annotated_frame``, upload it with the original image and publish an MQTT message.

    The published JSON has the keys ``task_id``, ``minio`` (uploaded paths and file
    type), ``box_detail`` and then any ``extra_fields``.

    Returns:
        True if the message was published, False if the annotated image could not be saved
        (in which case nothing is uploaded or published).
    """
    os.makedirs(_OUTPUT_DIR, exist_ok=True)
    local_path = f"{_OUTPUT_DIR}/{os.path.basename(image_path)}"
    if not cv2.imwrite(local_path, annotated_frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90]):
        print("错误: 图像保存失败")
        return False
    print(f"图像已保存到: {local_path}")

    date_str = str(datetime.datetime.now().timestamp())
    try:
        minio_path_before, _file_type_before = upload_file(image_path, None)
        minio_path, file_type = upload_file(local_path, date_str)
    finally:
        # The annotated copy is only needed for the upload; never leave it behind.
        if os.path.exists(local_path):
            os.remove(local_path)

    message = {
        "task_id": task_id,
        "minio": {
            "minio_path": minio_path,
            "file_type": file_type,
            "original_path": minio_path_before,
        },
        "box_detail": _to_jsonable(box_detail),
    }
    if extra_fields:
        message.update(_to_jsonable(extra_fields))

    # Double-quoted JSON that keeps Chinese characters readable.
    json_message = json.dumps(message, indent=4, ensure_ascii=False)
    client.publish_message(json_message)
    time.sleep(0.1)

    # NOTE(review): upload_file's returned paths are checked against the local filesystem,
    # which suggests it may leave local temp copies behind — confirm with minio_util.
    for leftover in (minio_path, minio_path_before):
        if leftover and os.path.exists(leftover):
            os.remove(leftover)
    return True


def pic_detect_func(mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic, task_id, model_func_id_list, para, s3_url_list,
                    m1, cls, chinese_label, confidence, use_fp16=False):
    """Run YOLO (ultralytics) detection on each image URL and publish per-image results over MQTT.

    For every URL in ``s3_url_list`` the image is downloaded, its DJI EXIF metadata is read,
    detections are computed through an ``ObjectCounter`` plus ``cal_tricker_results``, and the
    annotated image is uploaded and announced on ``mqtt_pub_topic``.

    Args:
        mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic: MQTT connection settings for the publisher.
        task_id: identifier echoed in every published message.
        model_func_id_list: function ids forwarded to ``cal_tricker_results``; its first
            element is also passed separately.
        para: extra parameters forwarded to ``cal_tricker_results``.
        s3_url_list: image URLs to process (fetched with ``downFile``).
        m1: model path/identifier passed to ``YOLO``.
        cls: class indices the model and counter are restricted to.
        chinese_label: label mapping forwarded to ``cal_tricker_results``.
        confidence: detection confidence threshold.
        use_fp16: convert the model to half precision when running on CUDA.

    Errors are logged, never raised. A failing image is skipped. Processing stops after
    ``_MAX_CONSECUTIVE_ERRORS`` consecutive inference failures.
    """
    global mqtt_client
    try:
        # Environment tuning for inference performance.
        os.environ["OMP_NUM_THREADS"] = "4"
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"

        print(f"PyTorch 可用: {torch.__version__}, CUDA可用: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"GPU: {torch.cuda.get_device_name(0)}")
            torch.cuda.empty_cache()

        print("预加载YOLO模型...")
        ov_model = _load_yolo_model(m1, cls, confidence)
        ov_model.conf = confidence

        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        _configure_model_device(ov_model, device, use_fp16)

        mqtt_client = MQTTClient(mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic)

        # ObjectCounter draws its counting region on the image; a degenerate off-image
        # region keeps that overlay out of the way.
        line_points = [(-10, -10), (-10, -10)]
        # Keep show=False: with show=True, restarting the program via the REST interface
        # left the thread hung.
        counter = solutions.ObjectCounter(show=False, region=line_points, model=ov_model, classes=cls)
        _configure_model_device(counter.model, device, use_fp16)

        class_names = counter.model.names if hasattr(counter.model, 'names') else {i: str(i) for i in range(1000)}
        local_func_cache = {
            "func_100000": None,
            "func_100004": None,  # cache of person track_ids
            "func_100006": None,  # cache of vehicle track_ids
        }

        consecutive_errors = 0
        for pic_url in s3_url_list:
            try:
                image_path = downFile(pic_url)
                frame = cv2.imread(image_path)
                # Photo metadata, including GPS position and flight attitude.
                image_exif_dict = read_dji_exif_to_dict(image_path)
            except Exception as load_err:
                print(f"推理错误: {load_err}")
                continue

            try:
                if frame is None:
                    raise ValueError(f"cv2.imread could not decode image: {image_path}")
                frame_copy = frame.copy()
                counter(frame)
                annotated_frame, box_result = cal_tricker_results(
                    frame_copy, counter, class_names, model_func_id_list, local_func_cache,
                    para, cls, chinese_label, model_func_id_list[0])
            except Exception as e:
                print(f"处理帧错误: {e}")
                consecutive_errors += 1
                if consecutive_errors >= _MAX_CONSECUTIVE_ERRORS:
                    print(f"连续处理错误达到{_MAX_CONSECUTIVE_ERRORS}次 ,正在停止处理...")
                    break
                # Skip this image; never publish the previous image's results under its name.
                continue
            consecutive_errors = 0

            try:
                _save_upload_and_publish(mqtt_client, task_id, image_path, annotated_frame, box_result,
                                         {"image_exif_dict": image_exif_dict})
            except Exception as publish_err:
                print(f"推理错误: {publish_err}")
    except Exception as infer_err:
        print(f"pic_detect 推理错误: {infer_err}")


def pic_detect_func_trt(mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic, task_id, s3_url_list, engine_file_path,
                        PLUGIN_LIBRARY, confidence, categories
                        ):
    """Run TensorRT YOLO11 detection on each image URL and publish per-image results over MQTT.

    Args:
        mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic: MQTT connection settings for the publisher.
        task_id: identifier echoed in every published message.
        s3_url_list: image URLs to process (fetched with ``downFile``).
        engine_file_path: serialized TensorRT engine loaded by ``YoLo11TRT``.
        PLUGIN_LIBRARY: shared library with the engine's custom plugins, loaded via ctypes.
        confidence: detection confidence threshold passed to ``infer``.
        categories: category information passed to ``infer``.

    Errors are logged, never raised. A failing image is skipped. The TensorRT wrapper is
    always destroyed once processing ends.
    """
    global mqtt_client
    try:
        print(f"mqtt_client {mqtt_pub_ip} {mqtt_pub_port} {mqtt_pub_topic}")
        mqtt_client = MQTTClient(mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic)

        # Custom TensorRT plugins must be loaded before the engine is deserialized.
        ctypes.CDLL(PLUGIN_LIBRARY)
        print(f"PLUGIN_LIBRARY{PLUGIN_LIBRARY}")

        yolo11_wrapper = YoLo11TRT(engine_file_path)
        print(f"engine_file_path{engine_file_path} {len(s3_url_list)}")
        try:
            for pic_url in s3_url_list:
                try:
                    image_path = downFile(pic_url)
                    frame = cv2.imread(image_path)
                    if frame is None:
                        raise ValueError(f"cv2.imread could not decode image: {image_path}")

                    batch_image_raw, result_box_list, result_scores_list, result_classid_list, _infer_time = \
                        yolo11_wrapper.infer([frame], confidence, categories)

                    # _save_upload_and_publish converts any numpy arrays here to plain lists.
                    box_detail = {
                        "result_box_list": result_box_list,
                        "result_scores_list": result_scores_list,
                        "result_classid_list": result_classid_list,
                    }
                    _save_upload_and_publish(mqtt_client, task_id, image_path, batch_image_raw[0], box_detail)
                except Exception as infer_err:
                    print(f"推理错误: {infer_err}")
        finally:
            yolo11_wrapper.destroy()
    except Exception as infer_err:
        print(f"pic_detect 推理错误: {infer_err}")