189 lines
6.8 KiB
Python
189 lines
6.8 KiB
Python
|
|
import json
|
|||
|
|
import os
|
|||
|
|
import subprocess
|
|||
|
|
import time
|
|||
|
|
|
|||
|
|
import cv2
|
|||
|
|
import numpy as np
|
|||
|
|
import torch
|
|||
|
|
from ultralytics import YOLO
|
|||
|
|
|
|||
|
|
from middleware.read_srt import parse_srt_file
|
|||
|
|
from touying.ImageReproject_python.cal_func import red_line_reproject
|
|||
|
|
from touying.ImageReproject_python.img_types import Point
|
|||
|
|
from yolo.cv_multi_model_back_video import is_point_in_polygon
|
|||
|
|
|
|||
|
|
# Configuration parameters.

# Input video and the SRT file that is parsed for per-frame telemetry.
local_video_path = r"/mnt/mydisk1/dj/ai2/middleware/DJI_20250912152112_0001_V.mp4"
srt_path = r"/mnt/mydisk1/dj/ai2/middleware/DJI_20250912152112_0001_V.mp4.srt"

# GeoJSON file whose features provide the polygons used for the intrusion check.
invade_file_path = r"/mnt/mydisk1/dj/ai2/test/高压线-0826.geojson"
# Alternative polygon file (Windows path):
# invade_file_path = r"C:\Users\14867\xwechat_files\wxid_lqq76m3bwzja21_94e1\msg\file\2025-08\营业线-0826.geojson"

# Run inference on the first CUDA device when one is available, else on CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Weight files of the detectors to run on every frame.
model_paths = [
    # r"D:\project\AI-PYTHON\Ai_tottle\pt\best.pt",  # model 1
    r"/mnt/mydisk1/dj/ai2/pt/GDCL.pt",  # model 2 (more can be added)
]

# Annotated frames are written here; the directory is created at import time.
output_dir = r"/mnt/mydisk1/dj/ai2/test/test_pic_1"
os.makedirs(output_dir, exist_ok=True)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def load_models(model_paths, device):
    """Load one YOLO model per weight file and move each onto ``device``.

    Args:
        model_paths: Iterable of weight-file paths, each passed to ``YOLO``.
        device: Device specifier handed to every model's ``.to()`` call.

    Returns:
        A list with the loaded models, in the same order as ``model_paths``.
    """
    return [YOLO(weights).to(device) for weights in model_paths]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def merge_results(results_list, conf_threshold=0.5, iou_threshold=0.5):
    """Pool the detections of several models and de-duplicate them with NMS.

    Args:
        results_list: One entry per model; each entry is an iterable of result
            objects whose ``boxes`` yield items exposing ``xyxy``, ``conf`` and
            ``cls`` (each read at index ``[0]``).
        conf_threshold: Boxes whose confidence is not strictly greater than
            this value are discarded.
        iou_threshold: IoU threshold passed to ``torchvision.ops.nms``.

    Returns:
        A 2-D NumPy array of shape ``(k, 6)`` whose rows are
        ``[x1, y1, x2, y2, conf, class_id]``; ``k`` is 0 when nothing passes
        the confidence filter. NMS is run over all pooled boxes together,
        without regard to class id. If ``torchvision`` cannot be imported a
        warning is printed and the unfiltered boxes are returned.
    """
    pooled = []
    for results in results_list:
        for result in results:
            for box in result.boxes:
                if box.conf[0] > conf_threshold:
                    x1, y1, x2, y2 = map(float, box.xyxy[0].tolist())
                    pooled.append(
                        [x1, y1, x2, y2, float(box.conf[0]), int(box.cls[0])]
                    )

    if not pooled:
        # Return an empty 2-D array rather than None so callers can iterate.
        return np.empty((0, 6))

    boxes = np.array(pooled)

    # Only the import is guarded: a failure inside NMS itself should surface.
    try:
        import torchvision
    except ImportError:
        print("Warning: torchvision not installed, skipping NMS.")
        return boxes

    keep = torchvision.ops.nms(
        torch.tensor(boxes[:, :4]),  # coordinates
        torch.tensor(boxes[:, 4]),   # confidences
        iou_threshold,
    )
    # Index with a NumPy integer array, not the tensor itself: a one-element
    # tensor is treated by NumPy as a scalar index, which would collapse the
    # result to a 1-D row instead of shape (1, 6).
    return boxes[keep.cpu().numpy()]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _load_polygons(geojson_path):
    """Read a GeoJSON file and return one list of ``Point`` per feature."""
    with open(geojson_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    polygons = []
    for feature in data.get('features', []):
        geometry = feature.get('geometry') or {}
        coordinates = geometry.get('coordinates') or []
        if not coordinates:
            # A feature without geometry/coordinates has nothing to project.
            continue
        # Only the first entry of ``coordinates`` is used. Each coordinate is
        # passed to Point as (coord[1], coord[0], coord[2], index).
        # NOTE(review): presumably GeoJSON [lon, lat, alt] -> Point(lat, lon,
        # alt, idx); confirm against the Point definition.
        polygons.append(
            [Point(coord[1], coord[0], coord[2], idx)
             for idx, coord in enumerate(coordinates[0])]
        )
    return polygons


def _project_polygons(polygons, telemetry, img_width, img_height):
    """Reproject every polygon into the image using one telemetry record."""
    return [
        red_line_reproject(telemetry.gb_yaw, telemetry.gb_pitch, telemetry.gb_roll,
                           telemetry.abs_alt, telemetry.longitude, telemetry.latitude,
                           img_width, img_height, points)
        for points in polygons
    ]


def cal():
    """Run all detectors on the video and save frames with intrusions marked.

    For every frame the function:
      1. picks the telemetry record for the frame from the parsed SRT file,
      2. runs each loaded model and merges the detections,
      3. reprojects the GeoJSON polygons into the frame and draws them in red,
      4. draws a yellow rectangle around every detection whose box centre lies
         inside any projected polygon,
      5. writes the frame to ``output_dir`` as ``<frame_count>.png``.

    Reads the module-level ``invade_file_path``, ``srt_path``,
    ``local_video_path``, ``model_paths``, ``device`` and ``output_dir``.
    Returns ``None``; prints an error and returns early if the video cannot
    be opened.
    """
    # Load the polygon data.
    list_points = _load_polygons(invade_file_path)

    # Telemetry comes from the already-extracted SRT file at ``srt_path``.
    # It can be produced from the video's subtitle stream with ffmpeg:
    #   ffmpeg -i <video> -map 0:s:0 -c:s srt <video>.srt
    srt_list = parse_srt_file(srt_path)

    cap = cv2.VideoCapture(local_video_path)
    if not cap.isOpened():
        print("Error: 无法打开视频")
        cap.release()
        return

    try:
        # Load all models once, before the frame loop.
        models = load_models(model_paths, device)
        frame_count = 0
        dji_srt = None  # most recent telemetry record; None until one is read

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            img_height, img_width = frame.shape[:2]
            frame_count += 1

            # When the SRT list is shorter than the video, the last record
            # read keeps being used for the remaining frames.
            # NOTE(review): frame_count is incremented before this lookup, so
            # srt_list[0] is never used -- confirm whether the list returned
            # by parse_srt_file is 0- or 1-based.
            if frame_count < len(srt_list):
                dji_srt = srt_list[frame_count]

            # Multi-model inference; a failing model is reported and skipped
            # so the remaining models still contribute to this frame.
            all_results = []
            for model in models:
                try:
                    all_results.append(model(frame))
                except Exception as e:
                    print(f"模型推理错误: {e}")

            # Guard against a 1-D result so the loop below always sees rows.
            merged_boxes = np.atleast_2d(merge_results(all_results))

            # Project and draw the polygons. Without any telemetry there is
            # nothing to project, so no polygon is drawn or tested.
            if dji_srt is not None:
                projected = _project_polygons(list_points, dji_srt, img_width, img_height)
            else:
                projected = []
            for red_result in projected:
                point_list = [[int(p["u"]), int(p["v"])] for p in red_result]
                cv2.polylines(frame, [np.array(point_list, dtype=np.int32)],
                              isClosed=True, color=(0, 0, 255), thickness=2)

            # Mark detections whose box centre falls inside ANY polygon;
            # detections outside every polygon are left undrawn.
            for box in merged_boxes:
                x1, y1, x2, y2 = (int(v) for v in box[:4])
                point_x = (x1 + x2) / 2
                point_y = (y1 + y2) / 2
                if any(is_point_in_polygon(point_x, point_y, poly) for poly in projected):
                    cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 255), 2)

            # Save the annotated frame (on-screen display is disabled).
            cv2.imwrite(os.path.join(output_dir, f"{frame_count}.png"), frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Always free the capture handle, even if a frame failed to process.
        cap.release()
        cv2.destroyAllWindows()
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Script entry point: run the pipeline only when this file is executed directly.
if __name__ == "__main__":
    cal()
|