import json
import os
import subprocess
import time

import cv2
import numpy as np
import torch
from ultralytics import YOLO

from middleware.read_srt import parse_srt_file
from touying.ImageReproject_python.cal_func import red_line_reproject
from touying.ImageReproject_python.img_types import Point
from yolo.cv_multi_model_back_video import is_point_in_polygon

# Configuration
local_video_path = r"/mnt/mydisk1/dj/ai2/middleware/DJI_20250912152112_0001_V.mp4"
srt_path = r"/mnt/mydisk1/dj/ai2/middleware/DJI_20250912152112_0001_V.mp4.srt"
invade_file_path = r"/mnt/mydisk1/dj/ai2/test/高压线-0826.geojson"
# invade_file_path = r"C:\Users\14867\xwechat_files\wxid_lqq76m3bwzja21_94e1\msg\file\2025-08\营业线-0826.geojson"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_paths = [
    # r"D:\project\AI-PYTHON\Ai_tottle\pt\best.pt",  # model 1
    r"/mnt/mydisk1/dj/ai2/pt/GDCL.pt"  # model 2 (more can be added)
]
output_dir = r"/mnt/mydisk1/dj/ai2/test/test_pic_1"
os.makedirs(output_dir, exist_ok=True)


def load_models(model_paths, device):
    """Load multiple YOLO models."""
    models = []
    for path in model_paths:
        model = YOLO(path).to(device)
        models.append(model)
    return models


def merge_results(results_list, conf_threshold=0.5, iou_threshold=0.5):
    """Merge the inference results of multiple models into a single box set."""
    all_boxes = []
    for results in results_list:
        for result in results:
            for box in result.boxes:
                if box.conf[0] > conf_threshold:
                    x1, y1, x2, y2 = map(float, box.xyxy[0].tolist())
                    conf = float(box.conf[0])
                    class_id = int(box.cls[0])
                    all_boxes.append([x1, y1, x2, y2, conf, class_id])

    if not all_boxes:
        return np.empty((0, 6))  # return an empty array instead of None

    # Convert to a NumPy array of shape (N, 6)
    boxes = np.array(all_boxes)
    try:
        import torchvision
        boxes_tensor = torch.tensor(boxes[:, :4])  # coordinates
        scores = torch.tensor(boxes[:, 4])         # confidence scores
        keep = torchvision.ops.nms(boxes_tensor, scores, iou_threshold)
        merged_boxes = boxes[keep.numpy()]
    except ImportError:
        print("Warning: torchvision not installed, skipping NMS.")
        merged_boxes = boxes

    return merged_boxes  # always a 2-D array


def cal():
    # Load polygon data from the GeoJSON intrusion file
    list_points = []
    with open(invade_file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    features = data.get('features', [])
    for polygon in features:
        geometry = polygon.get('geometry', {})
        coordinates = geometry.get('coordinates', [])
        points = [Point(coord[1], coord[0], coord[2], key)
                  for key, coord in enumerate(coordinates[0])]
        list_points.append(points)

    # # Extract the video's subtitle track with ffmpeg (disabled, kept for reference)
    # dir_name = os.path.dirname(local_video_path)
    # srt_name = os.path.basename(local_video_path) + ".srt"
    # srt_path = os.path.join(dir_name, srt_name)
    # command = ["ffmpeg", "-i", local_video_path, "-map", "0:s:0", "-c:s", "srt", srt_path]
    # try:
    #     subprocess.run(command, check=True)
    # except Exception as e:
    #     print(f"Subtitle extraction failed: {e}")

    srt_list = parse_srt_file(srt_path)

    cap = cv2.VideoCapture(local_video_path)
    if not cap.isOpened():
        print("Error: cannot open the video")
        return

    # Load the detection models
    models = load_models(model_paths, device)

    frame_count = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        img_height, img_width = frame.shape[:2]
        frame_count += 1

        # Read camera pose from the SRT telemetry for this frame
        if frame_count < len(srt_list):
            dji_srt = srt_list[frame_count]
            gimbal_pitch = dji_srt.gb_pitch
            gb_yaw = dji_srt.gb_yaw
            air_height = dji_srt.abs_alt
            cam_longitude = dji_srt.longitude
            cam_latitude = dji_srt.latitude

        # Run inference with every model
        all_results = []
        for model in models:
            try:
                results = model(frame)
                all_results.append(results)
            except Exception as e:
                print(f"Model inference error: {e}")

        # Merge the per-model results
        merged_boxes = merge_results(all_results)
        # Ensure merged_boxes is a 2-D array
        if len(merged_boxes.shape) == 1:
            merged_boxes = np.expand_dims(merged_boxes, 0)

        # Reproject each polygon onto the image, draw it, and test detections against it
        for points in list_points:
            red_result = red_line_reproject(gb_yaw, gimbal_pitch, dji_srt.gb_roll, air_height,
                                            cam_longitude, cam_latitude, img_width, img_height, points)
            point_list = [[int(p["u"]), int(p["v"])] for p in red_result]
            cv2.polylines(frame, [np.array(point_list, dtype=np.int32)],
                          isClosed=True, color=(0, 0, 255), thickness=2)

            message_point = []
            message_out_point = []

            # Check every merged detection against the reprojected polygon
            for box in merged_boxes:
                # each box is [x1, y1, x2, y2, conf, class_id]
                x1, y1, x2, y2 = map(int, box[:4])
                conf = float(box[4])
                class_id = int(box[5])
                # cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                # cv2.putText(frame, f"ID:{class_id} {conf:.2f}", (x1, y1 - 10),
                #             cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
                point_x = (x1 + x2) / 2
                point_y = (y1 + y2) / 2
                is_invade = is_point_in_polygon(point_x, point_y, red_result)
                if is_invade:
                    message_point.append({
                        # "u": point_x,
                        # "v": point_y,
                        # "class_name": class_name,
                        "box": [x1, y1, x2, y2]
                    })
                else:
                    message_out_point.append({
                        # "u": point_x,
                        # "v": point_y,
                        # "class_name": class_name,
                        "box": [x1, y1, x2, y2]
                    })

            # Highlight detections that fall inside the polygon
            for point in message_point:
                cv2.rectangle(frame, (point["box"][0], point["box"][1]),
                              (point["box"][2], point["box"][3]), (0, 255, 255), 2)
            # for point in message_out_point:
            #     cv2.rectangle(frame, (point["box"][0], point["box"][1]),
            #                   (point["box"][2], point["box"][3]), (0, 255, 0), 2)

        # Save (and optionally display) the annotated frame
        # cv2.imshow("Multi-Model Detection", frame)
        cv2.imwrite(os.path.join(output_dir, f"{frame_count}.png"), frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()


if __name__ == '__main__':
    cal()