# ai_project_v1/video_invade.py


import json
import os
import subprocess
import time
import cv2
import numpy as np
import torch
from ultralytics import YOLO
from middleware.read_srt import parse_srt_file
from touying.ImageReproject_python.cal_func import red_line_reproject
from touying.ImageReproject_python.img_types import Point
from yolo.cv_multi_model_back_video import is_point_in_polygon
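# NOTE: middleware, touying and yolo above are project-local packages, assumed to be
# importable from the project root (they are not PyPI packages).
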
# Configuration parameters
local_video_path = r"/mnt/mydisk1/dj/ai2/middleware/DJI_20250912152112_0001_V.mp4"
srt_path = r"/mnt/mydisk1/dj/ai2/middleware/DJI_20250912152112_0001_V.mp4.srt"
invade_file_path = r"/mnt/mydisk1/dj/ai2/test/高压线-0826.geojson"
# invade_file_path = r"C:\Users\14867\xwechat_files\wxid_lqq76m3bwzja21_94e1\msg\file\2025-08\营业线-0826.geojson"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_paths = [
    # r"D:\project\AI-PYTHON\Ai_tottle\pt\best.pt",  # model 1
    r"/mnt/mydisk1/dj/ai2/pt/GDCL.pt"  # model 2; more can be added
]
output_dir = r"/mnt/mydisk1/dj/ai2/test/test_pic_1"
os.makedirs(output_dir, exist_ok=True)


def load_models(model_paths, device):
    """Load multiple YOLO models onto the given device."""
    models = []
    for path in model_paths:
        model = YOLO(path).to(device)
        models.append(model)
    return models


def merge_results(results_list, conf_threshold=0.5, iou_threshold=0.5):
    """Merge the inference results of multiple models into one box array."""
    all_boxes = []
    for results in results_list:
        for result in results:
            for box in result.boxes:
                if box.conf[0] > conf_threshold:
                    x1, y1, x2, y2 = map(float, box.xyxy[0].tolist())
                    conf = float(box.conf[0])
                    class_id = int(box.cls[0])
                    all_boxes.append([x1, y1, x2, y2, conf, class_id])
    if not all_boxes:
        return np.empty((0, 6))  # return an empty array instead of None
    # Convert to a NumPy array
    boxes = np.array(all_boxes)
    try:
        import torchvision
        boxes_tensor = torch.tensor(boxes[:, :4])  # coordinates
        scores = torch.tensor(boxes[:, 4])         # confidences
        keep = torchvision.ops.nms(boxes_tensor, scores, iou_threshold)
        merged_boxes = boxes[keep]
    except ImportError:
        print("Warning: torchvision not installed, skipping NMS.")
        merged_boxes = boxes
    return merged_boxes  # make sure a 2-D array is returned
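

# Example usage (a minimal sketch, not part of the pipeline below): run one frame
# through every loaded model and merge the detections across models.
#   models = load_models(model_paths, device)
#   results_list = [m(frame) for m in models]
#   boxes = merge_results(results_list, conf_threshold=0.5, iou_threshold=0.5)
#   # boxes -> (N, 6) array of [x1, y1, x2, y2, conf, class_id]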


def cal():
    # Load the polygon (restricted-area) data from the GeoJSON file
    list_points = []
    with open(invade_file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
        features = data.get('features', [])
        for polygon in features:
            geometry = polygon.get('geometry', {})
            coordinates = geometry.get('coordinates', [])
            # First ring of the polygon; GeoJSON stores [lon, lat, alt]
            points = [Point(coord[1], coord[0], coord[2], key) for key, coord in enumerate(coordinates[0])]
            list_points.append(points)

    # # Extract the video subtitles (omitted; original logic kept)
    # dir_name = os.path.dirname(local_video_path)
    # srt_name = os.path.basename(local_video_path) + ".srt"
    # srt_path = os.path.join(dir_name, srt_name)
    # command = ["ffmpeg", "-i", local_video_path, "-map", "0:s:0", "-c:s", "srt", srt_path]
    # try:
    #     subprocess.run(command, check=True)
    # except Exception as e:
    #     print(f"Subtitle extraction failed: {e}")
    srt_list = parse_srt_file(srt_path)

    cap = cv2.VideoCapture(local_video_path)
    if not cap.isOpened():
        print("Error: failed to open the video")
        return

    # Load the models
    models = load_models(model_paths, device)
    frame_count = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        img_height, img_width = frame.shape[:2]
        frame_count += 1

        # Read the matching SRT record (original logic kept; assumes the SRT
        # covers every frame so the gimbal/camera values below are defined)
        if frame_count < len(srt_list):
            dji_srt = srt_list[frame_count]
            gimbal_pitch = dji_srt.gb_pitch
            gb_yaw = dji_srt.gb_yaw
            air_height = dji_srt.abs_alt
            cam_longitude = dji_srt.longitude
            cam_latitude = dji_srt.latitude

        # Multi-model inference
        all_results = []
        for model in models:
            try:
                results = model(frame)
                all_results.append(results)
            except Exception as e:
                print(f"Model inference error: {e}")

        # Merge the per-model results
        merged_boxes = merge_results(all_results)
        # Make sure merged_boxes is a 2-D array
        if len(merged_boxes.shape) == 1:
            merged_boxes = np.expand_dims(merged_boxes, 0)
        # Draw the polygons (original logic kept)
        for points in list_points:
            red_result = red_line_reproject(gb_yaw, gimbal_pitch, dji_srt.gb_roll, air_height,
                                            cam_longitude, cam_latitude, img_width, img_height, points)
            point_list = [[int(p["u"]), int(p["v"])] for p in red_result]
            cv2.polylines(frame, [np.array(point_list, dtype=np.int32)], isClosed=True, color=(0, 0, 255), thickness=2)

        message_point = []
        message_out_point = []
        # Visualize the merged results
        # NOTE: red_result is the reprojection of the last polygon in list_points;
        # if the GeoJSON held several polygons, only that last one would be checked
        # here (original logic kept).
        for box in merged_boxes:
            # box is a length-6 array: [x1, y1, x2, y2, conf, class_id]
            x1, y1, x2, y2 = map(int, box[:4])
            conf = float(box[4])
            class_id = int(box[5])
            # cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            # cv2.putText(frame, f"ID:{class_id} {conf:.2f}", (x1, y1 - 10),
            #             cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
            point_x = (x1 + x2) / 2
            point_y = (y1 + y2) / 2
            is_invade = is_point_in_polygon(point_x, point_y, red_result)
            if is_invade:
                message_point.append({
                    # "u": point_x,
                    # "v": point_y,
                    # "class_name": class_name
                    "box": [x1, y1, x2, y2]
                })
            else:
                message_out_point.append({
                    # "u": point_x,
                    # "v": point_y,
                    # "class_name": class_name
                    "box": [x1, y1, x2, y2]
                })

        # Highlight intruding targets in yellow
        for point in message_point:
            cv2.rectangle(frame, (point["box"][0], point["box"][1]), (point["box"][2], point["box"][3]),
                          (0, 255, 255), 2)
        # for point in message_out_point:
        #     cv2.rectangle(frame, (point["box"][0], point["box"][1]), (point["box"][2], point["box"][3]),
        #                   (0, 255, 0), 2)

        # Display and save the result
        # cv2.imshow("Multi-Model Detection", frame)
        cv2.imwrite(os.path.join(output_dir, f"{frame_count}.png"), frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()
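

# The intrusion test above relies on is_point_in_polygon imported from
# yolo.cv_multi_model_back_video, whose implementation is not shown in this file.
# The ray-casting sketch below only illustrates the kind of check it is assumed
# to perform on the reprojected polygon (a list of dicts with "u"/"v" pixel
# coordinates, as used by red_line_reproject above); it is not the project's
# actual helper and is never called by cal().
def _point_in_polygon_sketch(x, y, polygon):
    """Return True if pixel (x, y) lies inside polygon = [{"u": ..., "v": ...}, ...]."""
    inside = False
    n = len(polygon)
    for i in range(n):
        x1, y1 = polygon[i]["u"], polygon[i]["v"]
        x2, y2 = polygon[(i + 1) % n]["u"], polygon[(i + 1) % n]["v"]
        # Toggle on every edge crossed by a horizontal ray cast to the right of (x, y)
        if (y1 > y) != (y2 > y):
            x_cross = x1 + (y - y1) * (x2 - x1) / (y2 - y1)
            if x < x_cross:
                inside = not inside
    return inside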


if __name__ == '__main__':
    cal()