# ai_project_v1/yolo/cv_multi_model_back_video.py0921
import logging
import os.path
import subprocess
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
import json
import time
from ultralytics import YOLO, solutions
import torch
# Assumes the project modules below exist and import cleanly
from cv_back_video import cal_tricker_results, get_local_drc_message, read_drc_mqtt, async_read_drc_mqtt
from middleware.MQTTService import MQTTService, MQTTDevice
from middleware.entity.air_attitude import Air_Attitude
from middleware.entity.up_drc_camera_osd_info_push import parse_camera_osd_info
from middleware.entity.up_osd_info_push import parse_osd_message
from middleware.minio_util import upload_file_from_buffer, upload_frame_buff_from_buffer, downFile, \
    downBigFile
import av
import asyncio
import cv2
import numpy as np
from middleware.read_srt import parse_srt_file
from touying.ImageReproject_python.cal_func import cal_canv_location_by_osd, red_line_reproject
from touying.ImageReproject_python.img_types import Point
# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Global configuration
class Config:
    MAX_WORKERS = 8  # Thread-pool size for inference
    FRAME_QUEUE_SIZE = 120  # Raw frame queue size
    PROCESSED_QUEUE_SIZE = 60  # Processed frame queue size
    RETRY_COUNT = 3  # Model-load retry count
    UPLOAD_WORKERS = 2  # Dedicated upload thread-pool size
    MAX_QUEUE_WARN_SIZE = 15  # Queue-size warning threshold
    MODEL_INPUT_SIZE = (640, 640)  # YOLO input size; must be divisible by 32
    DEFAULT_DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"  # Default to the first GPU
    TARGET_FPS = 25  # Target streaming frame rate
    MAX_FRAME_DROP = 5  # Max frames to drop (avoids queue backlog)
# Model cache
model_cache = {}
# Performance statistics
stats = {
    'processed': 0,
    'last_time': time.time(),
    'avg_fps': 0.0,
    'dropped_frames': 0
}
# Created and torn down on every run instead of at import time
async def initialize_resources():
    """Initialize resources (thread pools, queues, stop event)."""
    global upload_executor, executor, frame_queue, processed_queue, invade_queue, cv_frame_queue, stop_event
    upload_executor = ThreadPoolExecutor(max_workers=Config.UPLOAD_WORKERS)
    executor = ThreadPoolExecutor(max_workers=Config.MAX_WORKERS)
    frame_queue = asyncio.Queue(maxsize=Config.FRAME_QUEUE_SIZE)
    processed_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
    invade_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
    cv_frame_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
    stop_event = asyncio.Event()
async def cleanup_resources():
    """Clean up resources."""
    global upload_executor, executor, frame_queue, processed_queue, invade_queue, cv_frame_queue, stop_event
    stop_event.set()
    # Give in-flight frames a moment to finish
    await asyncio.sleep(1.0)
    # Drain the queues
    try:
        while not frame_queue.empty():
            await frame_queue.get()
        while not processed_queue.empty():
            await processed_queue.get()
        while not invade_queue.empty():
            await invade_queue.get()
        while not cv_frame_queue.empty():
            await cv_frame_queue.get()
    except Exception as e:
        logger.warning(f"Error while draining queues: {e}")
    # Shut down the thread pools
    if 'executor' in globals() and executor is not None:
        executor.shutdown(wait=True, cancel_futures=True)
    if 'upload_executor' in globals() and upload_executor is not None:
        upload_executor.shutdown(wait=True, cancel_futures=True)
    # Close all output-stream containers
    for container in stream_containers.values():
        try:
            if 'container' in container:
                container['container'].close()
        except Exception as e:
            logger.warning(f"Error while closing stream container: {e}")
    stream_containers.clear()
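# A minimal lifecycle sketch (illustrative only; this demo helper is not called
# anywhere in this module): initialize_resources() must run before any queue or
# executor above is touched, and cleanup_resources() should run once processing
# stops, so the usual shape of a caller is:
async def _demo_resource_lifecycle():
    await initialize_resources()
    try:
        pass  # start reader / detector / writer tasks here
    finally:
        await cleanup_resources()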
class MultiDetectionResults:
    def __init__(self):
        self.boxes = []  # Bounding boxes
        self.clss = []  # Class IDs
        self.cls_names = []  # Class names
        self.cls_en_names = []  # English class names
        self.confs = []  # Confidences
        self.track_ids = []  # Track IDs
@dataclass
class DetectionResult:
    bbox: List[int]  # [x1, y1, x2, y2]
    class_id: int
    class_name: str
    confidence: float
    track_id: int
@dataclass
class DetectionResultList:
    boxes: List[Any]  # each entry is [x1, y1, x2, y2]
    clss: List[int]
    clss_name: List[str]
    confs: List[float]
    track_ids: List[Any]
    def __iter__(self):
        """Yield (bbox, cls_id, cls_name, conf, track_id) for each detection."""
        for i in range(len(self.boxes)):
            yield (
                self.boxes[i],
                self.clss[i],
                self.clss_name[i],
                self.confs[i],
                self.track_ids[i]
            )
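# Usage sketch with hypothetical values: DetectionResultList is iterable and
# yields one (bbox, cls_id, cls_name, conf, track_id) tuple per detection.
def _demo_detection_result_list():
    dets = DetectionResultList(
        boxes=[[10, 20, 110, 220]], clss=[0], clss_name=["person"],
        confs=[0.87], track_ids=[3],
    )
    for bbox, cls_id, cls_name, conf, track_id in dets:
        print(cls_name, conf, bbox, track_id)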
class MultiYOLODetector:
    """Multi-model parallel detector; works around multi-GPU device mismatch issues."""
    def __init__(self, model_configs: List[Dict]):
        self.models = []
        self.class_maps = []
        self.allowed_classes = []
        self.model_cls = []
        self.chinese_label = []
        self.list_func_id = []
        self.func_id = -1
        self.list_class_names = []
        self.list_para_invade_enable = []
        self.input_size = Config.MODEL_INPUT_SIZE  # Unified input size
        self.device = Config.DEFAULT_DEVICE  # Force the default device
        self._setup_multi_gpu()
        for config in model_configs:
            model_path = config['path']
            cls_map = config.get('cls_map', {})
            allowed = config.get('allowed_classes', None)
            tracking = config.get('tracking', True)  # Whether to enable tracking
            model_cls_index = config.get('cls_index', [])
            model_chinese_labe = config.get('chinese_label', [])
            model_list_func_id = config.get('list_func_id', [])
            func_id = config.get('func_id', -1)
            class_names = config.get('class_names', True)
            para_invade_enable = config.get('para_invade_enable', False)
            # Load the model and make sure it sits on the right device
            model = self._load_model(model_path, tracking)
            self.models.append(model)
            self.class_maps.append(cls_map)
            self.allowed_classes.append(allowed)
            self.model_cls.append(model_cls_index)
            self.chinese_label.append(model_chinese_labe)
            self.list_func_id.append(model_list_func_id)
            self.func_id = func_id
            self.list_class_names.append(class_names)
            self.list_para_invade_enable.append(para_invade_enable)
    def _setup_multi_gpu(self):
        """Configure the multi-GPU environment."""
        if torch.cuda.is_available():
            gpu_count = torch.cuda.device_count()
            print(f"Detected {gpu_count} GPU(s); using default device: {self.device}")
            # Force all operations onto the default device
            torch.cuda.set_device(self.device)
    def _load_model(self, model_path: str, tracking: bool) -> torch.nn.Module:
        """Load a model onto the target device; model fusion is disabled to avoid device conflicts."""
        try:
            # Load the model
            model = YOLO(model_path)
            # Move it to the target device
            model = model.to(self.device)
            # Enable tracking (only effective for YOLOv8)
            if hasattr(model, 'tracker'):
                model.tracker = "botsort.yaml"  # or "bytetrack.yaml"
                print(f"Tracking enabled for model {model_path}")
            else:
                print(f"Warning: model {model_path} does not support tracking")
            # For multi-GPU, use DataParallel only when needed and keep all parameters on one device
            if torch.cuda.is_available() and torch.cuda.device_count() > 1:
                # Make sure the parameters live on the default device
                model = model.to(self.device)
                model = torch.nn.DataParallel(model, device_ids=[self.device])
                # Disable model fusion to avoid multi-device problems
                model.fuse = lambda verbose=False: model  # replace the fuse method
            return model
        except Exception as e:
            print(f"Model load failed: {e}")
            raise
    @staticmethod
    def preprocess_frame(frame: np.ndarray, input_size: tuple, device: str) -> torch.Tensor:
        """Preprocess a frame: resize, convert color, and normalize."""
        # Resize
        resized = cv2.resize(frame, input_size)
        # Convert to RGB
        rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
        # Convert to tensor and normalize
        tensor = torch.from_numpy(rgb).permute(2, 0, 1).float() / 255.0
        # Add a batch dimension and move to the target device
        tensor = tensor.unsqueeze(0).to(device)
        return tensor
    @staticmethod
    def scale_bbox(bbox: List[int], original_size: tuple, input_size: tuple) -> List[int]:
        """Scale bounding-box coordinates from model input size back to the original image size."""
        oh, ow = original_size
        iw, ih = input_size
        # Scaling factors
        scale_w = ow / iw
        scale_h = oh / ih
        # Scale the box
        x1, y1, x2, y2 = bbox
        return [
            int(x1 * scale_w),
            int(y1 * scale_h),
            int(x2 * scale_w),
            int(y2 * scale_h)
        ]
    async def predict(self, frame: np.ndarray) -> Tuple[List[DetectionResult], List, List]:
        """Run all models on one frame asynchronously."""
        loop = asyncio.get_running_loop()
        original_size = (frame.shape[0], frame.shape[1])  # (height, width)
        def _predict(model_idx: int, frame: np.ndarray) -> Tuple[List[DetectionResult], Any, Any]:
            model = self.models[model_idx]
            cls_map = self.class_maps[model_idx]
            allowed = self.allowed_classes[model_idx]
            model_cls_index = self.model_cls[model_idx]
            model_chinese_labe = self.chinese_label[model_idx]
            model_list_func_id = self.list_func_id[model_idx]
            func_id = self.func_id
            model_class_names = self.list_class_names[model_idx]
            para_invade_enable = self.list_para_invade_enable[model_idx]
            # Preprocess the frame and make sure it is on the right device
            input_tensor = self.preprocess_frame(frame, self.input_size, self.device)
            timestart = time.time()
            with torch.no_grad():
                # Keep input and model on the same device
                if input_tensor.device != next(model.parameters()).device:
                    input_tensor = input_tensor.to(next(model.parameters()).device)
                results = model(input_tensor, verbose=False)
            timeend = time.time()
            # print(f"Model {model_idx} inference took {timeend - timestart:.4f}s (device: {self.device})")
            detections = []
            model_para = {
                "cls_map": cls_map,
                "model_chinese_labe": model_chinese_labe,
                "model_cls_index": model_cls_index,
                "model_list_func_id": model_list_func_id,
                "func_id": func_id,
                "model_class_names": model_class_names,
                "para_invade_enable": para_invade_enable,
                "results": results
            }
            detection_result_list = DetectionResultList([], [], [], [], [])
            for result in results:
                boxes = result.boxes.cpu().numpy()  # move to CPU for processing
                tracks = result.boxes.id.cpu().numpy() if result.boxes.id is not None else None
                for i, box in enumerate(boxes):
                    cls_id = int(box.cls[0])
                    # Unwrap DataParallel to reach the underlying model
                    original_model = model.module if isinstance(model, torch.nn.DataParallel) else model
                    cls_name = original_model.names[cls_id]
                    # Filter out disallowed classes
                    if allowed and cls_name not in allowed:
                        continue
                    # Model boxes are relative to the input size
                    x1, y1, x2, y2 = box.xyxy[0].astype(int)
                    # Scale back to the original image size
                    scaled_bbox = self.scale_bbox(
                        [x1, y1, x2, y2],
                        original_size,
                        self.input_size
                    )
                    conf = float(box.conf[0])
                    # Temporary hard-coded threshold; should be pulled out as a parameter
                    if conf < 0.45:
                        continue
                    en_name = cls_map.get(cls_name, "unknown")
                    # Track ID if present; this class does not track, so -1 is stored
                    track_id = int(tracks[i]) if tracks is not None else None
                    detection_result_list.boxes.append(scaled_bbox)
                    detection_result_list.clss.append(cls_id)
                    detection_result_list.clss_name.append(cls_name)
                    detection_result_list.confs.append(conf)
                    detection_result_list.track_ids.append(-1)
                    detections.append(DetectionResult(
                        bbox=scaled_bbox,
                        class_id=cls_id,
                        class_name=cls_name,
                        confidence=conf,
                        track_id=-1
                    ))
            return detections, detection_result_list, model_para
        # Run all models in parallel
        futures = [
            loop.run_in_executor(
                executor,
                _predict,
                model_idx,
                frame.copy()
            )
            for model_idx in range(len(self.models))
        ]
        results = await asyncio.gather(*futures)  # List[Tuple[List[DetectionResult], DetectionResultList, Dict]]
        # Merge detection results
        all_detections = []
        detections_list = []  # same layout as results, for downstream re-processing
        # Collect per-model parameter dicts (one entry per model, avoiding key clashes)
        all_model_paras = []
        for model_idx, (detections, detection_result_list, model_para) in enumerate(results):
            all_detections.extend(detections)
            all_model_paras.append(model_para)
            detections_list.append(detection_result_list)
        return all_detections, detections_list, all_model_paras
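# Two small sketches with hypothetical numbers for the helpers above.
# preprocess_frame turns an HxWx3 BGR uint8 frame into a 1x3x640x640 float
# tensor in [0, 1]; scale_bbox maps a box predicted on the 640x640 input back
# to a 1920x1080 frame (original_size is (height, width) for this class).
def _demo_preprocess_and_scale():
    frame = np.zeros((1080, 1920, 3), dtype=np.uint8)
    t = MultiYOLODetector.preprocess_frame(frame, (640, 640), "cpu")
    print(t.shape, t.dtype)  # torch.Size([1, 3, 640, 640]) torch.float32
    # scale_w = 1920/640 = 3.0, scale_h = 1080/640 = 1.6875
    box = MultiYOLODetector.scale_bbox([100, 100, 200, 200], (1080, 1920), (640, 640))
    print(box)  # [300, 168, 600, 337]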
class MultiYOLODetector_TrackId:
    """Multi-model parallel detector with tracking; works around multi-GPU device mismatch issues."""
    def __init__(self, model_configs: List[Dict]):
        self.models = []
        self.class_maps = []
        self.allowed_classes = []
        self.model_cls = []
        self.chinese_label = []
        self.list_func_id = []
        self.func_id = -1
        self.list_class_names = []
        self.list_para_invade_enable = []
        self.input_size = Config.MODEL_INPUT_SIZE  # Unified input size
        self.device = Config.DEFAULT_DEVICE  # Force the default device
        self._setup_multi_gpu()
        # One ObjectCounter instance per model
        self.object_counters = []
        for config in model_configs:
            model_path = config['path']
            cls_map = config.get('cls_map', {})
            allowed = config.get('allowed_classes', None)
            tracking = config.get('tracking', True)  # Whether to enable tracking
            model_cls_index = config.get('cls_index', [])
            model_chinese_labe = config.get('chinese_label', [])
            model_list_func_id = config.get('list_func_id', [])
            func_id = config.get('func_id', -1)
            class_names = config.get('class_names', True)
            para_invade_enable = config.get('para_invade_enable', False)
            # Load the model and make sure it sits on the right device
            model = self._load_model(model_path, tracking)
            self.models.append(model)
            self.class_maps.append(cls_map)
            self.allowed_classes.append(allowed)
            self.model_cls.append(model_cls_index)
            self.chinese_label.append(model_chinese_labe)
            self.list_func_id.append(model_list_func_id)
            self.func_id = func_id
            self.list_class_names.append(class_names)
            self.list_para_invade_enable.append(para_invade_enable)
            # Create an ObjectCounter per model; the arguments are the same
            # whether or not the model supports tracking
            self.object_counters.append(
                solutions.ObjectCounter(
                    show=False,
                    region=None,  # a detection region can be set here if needed
                    model=model_path,
                    classes=model_cls_index if allowed else None
                )
            )
    def _setup_multi_gpu(self):
        """Configure the multi-GPU environment."""
        if torch.cuda.is_available():
            gpu_count = torch.cuda.device_count()
            print(f"Detected {gpu_count} GPU(s); using default device: {self.device}")
            # Force all operations onto the default device
            torch.cuda.set_device(self.device)
    def _load_model(self, model_path: str, tracking: bool) -> torch.nn.Module:
        """Load a model onto the target device; model fusion is disabled to avoid device conflicts."""
        try:
            # Load the model
            model = YOLO(model_path)
            # Move it to the target device
            model = model.to(self.device)
            # Enable tracking (only effective for YOLOv8)
            if hasattr(model, 'tracker') and tracking:
                model.tracker = "botsort.yaml"  # or "bytetrack.yaml"
                print(f"Tracking enabled for model {model_path}")
            else:
                print(f"Warning: model {model_path} does not support tracking")
            # For multi-GPU, use DataParallel only when needed and keep all parameters on one device
            if torch.cuda.is_available() and torch.cuda.device_count() > 1:
                # Make sure the parameters live on the default device
                model = model.to(self.device)
                model = torch.nn.DataParallel(model, device_ids=[self.device])
                # Disable model fusion to avoid multi-device problems
                model.fuse = lambda verbose=False: model  # replace the fuse method
            return model
        except Exception as e:
            print(f"Model load failed: {e}")
            raise
    @staticmethod
    def preprocess_frame(frame: np.ndarray, input_size: tuple, device: str) -> torch.Tensor:
        """Preprocess a frame: resize, convert color, and normalize."""
        # Resize; cv2.resize takes (width, height)
        resized = cv2.resize(frame, (input_size[0], input_size[1]))
        print(f"Preprocessed frame shape: {resized.shape}")  # expect (input_size[1], input_size[0], 3)
        # Convert to RGB
        rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
        # Convert to tensor and normalize
        tensor = torch.from_numpy(rgb).permute(2, 0, 1).float() / 255.0
        # Add a batch dimension and move to the target device
        tensor = tensor.unsqueeze(0).to(device)
        return tensor
    @staticmethod
    def scale_bbox(bbox, original_size, input_size):
        """Scale bounding-box coordinates from model input size back to the original image size."""
        ow, oh = original_size  # (width, height)
        iw, ih = input_size  # (width, height)
        # Scaling factors (original size over input size)
        scale_w = ow / iw
        scale_h = oh / ih
        x1, y1, x2, y2 = bbox
        # Scale and clamp to the frame
        x1 = int(max(0, min(x1 * scale_w, ow - 1)))
        y1 = int(max(0, min(y1 * scale_h, oh - 1)))
        x2 = int(max(0, min(x2 * scale_w, ow - 1)))
        y2 = int(max(0, min(y2 * scale_h, oh - 1)))
        return [x1, y1, x2, y2]
    async def predict(self, frame: np.ndarray) -> Tuple[List[DetectionResult], List, List]:
        """Run all models on one frame asynchronously."""
        loop = asyncio.get_running_loop()
        original_size = (frame.shape[1], frame.shape[0])  # (width, height)
        def _predict(model_idx: int, frame: np.ndarray) -> Tuple[List[DetectionResult], Any, Any]:
            try:
                model = self.models[model_idx]
                object_counter = self.object_counters[model_idx]
                cls_map = self.class_maps[model_idx]
                allowed = self.allowed_classes[model_idx]
                model_cls_index = self.model_cls[model_idx]
                model_chinese_labe = self.chinese_label[model_idx]
                model_list_func_id = self.list_func_id[model_idx]
                func_id = self.func_id
                model_class_names = self.list_class_names[model_idx]
                para_invade_enable = self.list_para_invade_enable[model_idx]
                # Preprocess the frame (kept for the device check below)
                input_tensor = self.preprocess_frame(frame, self.input_size, self.device)
                timestart = time.time()
                with torch.no_grad():
                    if input_tensor.device != next(model.parameters()).device:
                        input_tensor = input_tensor.to(next(model.parameters()).device)
                    # Detect and track via ObjectCounter on the full-resolution frame
                    results = object_counter(frame)
                # Debug: inspect the result structure
                print(f"Model {model_idx} results type: {type(results)}")
                if hasattr(results, 'boxes'):
                    print(
                        f"Boxes type: {type(results.boxes)}, length: {len(results.boxes) if hasattr(results.boxes, '__len__') else 'N/A'}")
                timeend = time.time()
                print(f"Model {model_idx} inference took {timeend - timestart:.4f}s")
                detections = []
                model_para = {
                    "cls_map": cls_map,
                    "model_chinese_labe": model_chinese_labe,
                    "model_cls_index": model_cls_index,
                    "model_list_func_id": model_list_func_id,
                    "func_id": func_id,
                    "model_class_names": model_class_names,
                    "para_invade_enable": para_invade_enable,
                    "results": object_counter  # keep the raw counter for debugging
                }
                results = object_counter
                # Collect detection results
                detection_result_list = DetectionResultList([], [], [], [], [])
                boxes = []
                tracks = results.track_ids
                confs = results.confs
                clss = results.clss
                names = results.names
                result_boxes = results.boxes
                # Debug: raw boxes from the model (the patched YOLO11 counter
                # already returns a plain list)
                print(f"Raw model boxes: {result_boxes}")
                if isinstance(result_boxes, torch.Tensor):
                    print(f"Boxes tensor shape: {result_boxes.shape}")
                if isinstance(result_boxes, list) and len(result_boxes) == 0:
                    return [], DetectionResultList([], [], [], [], []), {}
                if isinstance(result_boxes, list) and len(result_boxes) > 0:
                    boxes = result_boxes
                else:
                    boxes = result_boxes.tolist()
                # Process each detection
                for i, box in enumerate(boxes):
                    try:
                        # Make sure the box is iterable
                        if not hasattr(box, '__iter__'):
                            print(f"Invalid bounding-box format: {box}")
                            continue
                        # Convert to a list if it is not one already
                        box = list(box) if not isinstance(box, list) else box
                        # Require all four coordinates
                        if len(box) < 4:
                            print(f"Incomplete bounding-box coordinates: {box}")
                            continue
                        x1, y1, x2, y2 = box[:4]
                        cls_id = int(results.clss[i])
                        ind = int(results.clss[i])
                        cls_name = names[ind]
                        conf = float(confs[i]) if i < len(confs) else 0.0
                        # Hard-coded threshold; should be pulled out as a parameter
                        if conf < 0.45:
                            continue
                        track_id = int(tracks[i]) if i < len(tracks) and tracks[i] is not None else None
                        # Filter out disallowed classes
                        if allowed and cls_name not in allowed:
                            continue
                        # ObjectCounter ran on the original frame, so the coordinates
                        # are already in original-image scale; no rescaling needed
                        scaled_bbox = [int(x1), int(y1), int(x2), int(y2)]
                        detection_result_list.boxes.append([x1, y1, x2, y2])
                        detection_result_list.clss.append(cls_id)
                        detection_result_list.clss_name.append(cls_name)
                        detection_result_list.confs.append(conf)
                        detection_result_list.track_ids.append(track_id)
                        detections.append(DetectionResult(
                            bbox=scaled_bbox,
                            class_id=cls_id,
                            class_name=cls_name,
                            confidence=conf,
                            track_id=track_id
                        ))
                    except Exception as e:
                        print(f"Error while handling a single detection: {str(e)}")
                        continue
                return detections, detection_result_list, model_para
            except Exception as e:
                print(f"Model {model_idx} prediction failed: {str(e)}")
                import traceback
                traceback.print_exc()
                return [], DetectionResultList([], [], [], [], []), {}
        # Run all models in parallel
        futures = [
            loop.run_in_executor(
                executor,
                _predict,
                model_idx,
                frame.copy()
            )
            for model_idx in range(len(self.models))
        ]
        model_results = await asyncio.gather(*futures)  # List[Tuple[List[DetectionResult], DetectionResultList, Dict]]
        # Merge detection results
        all_detections = []
        detections_list = []  # same layout as results, for downstream re-processing
        # Collect per-model parameter dicts (one entry per model, avoiding key clashes)
        all_model_paras = []
        for model_idx, (detections, detection_result_list, model_para) in enumerate(model_results):
            all_detections.extend(detections)
            all_model_paras.append(model_para)
            detections_list.append(detection_result_list)
        return all_detections, detections_list, all_model_paras
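# Worked example with hypothetical numbers: this class's scale_bbox takes
# original_size as (width, height) -- the opposite convention from
# MultiYOLODetector -- and clamps the result to the frame.
def _demo_scale_bbox_clamped():
    # scale_w = 3.0, scale_h = 1.6875; x2 = 700 * 3.0 = 2100 clamps to 1919
    box = MultiYOLODetector_TrackId.scale_bbox([100, 100, 700, 200], (1920, 1080), (640, 640))
    print(box)  # [300, 168, 1919, 337]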
async def read_video_frames(task_id, mqtt, mqtt_publish_topic, local_video_path: str, srt_path: str):
    # Open the video file (streamed read; not loaded into memory)
    stop_event.clear()  # clear the stop event
    srt_list = parse_srt_file(srt_path)
    cap = cv2.VideoCapture(local_video_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        return
    prev_time = time.time()
    frame_count = 0
    fps = 0
    orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Temporary mechanism to report recording-processing status (started / finished);
    # sending the whole file in one message may be preferable later
    async def publist_status(status_to_publish):
        try:
            message = {"task_id": task_id, "video_status": status_to_publish}
            await mqtt.publish(mqtt_publish_topic, json.dumps(message, ensure_ascii=False))
        except Exception as e:
            logger.error(f"Failed to publish status: {e}")
            raise
    await publist_status(1)
    # Pace the reads so the queue does not back up and drop frames en masse
    video_fps = cap.get(cv2.CAP_PROP_FPS)  # native FPS of the video
    if video_fps <= 0:
        video_fps = Config.TARGET_FPS  # guard against missing FPS metadata
    frame_interval = 1.0 / video_fps  # time budget per frame
    last_read_time = time.time()
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:  # end of video, or a read error
            await publist_status(2)
            stop_event.set()
            break
        # Throttle: keep at least frame_interval between reads
        current_time = time.time()
        if current_time - last_read_time < frame_interval:
            await asyncio.sleep(frame_interval - (current_time - last_read_time))
        last_read_time = current_time
        frame_count += 1
        if frame_count % 10 == 0:
            curr_time = time.time()
            fps = 10 / (curr_time - prev_time)
            prev_time = curr_time
            print(f" ({orig_width}x{orig_height}, FPS: {fps:.2f})")
        # Optional: print progress every N frames
        if frame_count % 100 == 0:
            print(f"Processed {frame_count} frames")
        if not frame_queue.full():
            if frame_count < len(srt_list) and len(srt_list) > 0:
                dji_srt = srt_list[frame_count]
                gimbal_pitch = dji_srt.gb_pitch
                gimbal_roll = dji_srt.gb_roll
                gb_yaw = dji_srt.gb_yaw
                air_height = dji_srt.abs_alt
                cam_longitude = dji_srt.longitude
                cam_latitude = dji_srt.latitude
                art_tit = Air_Attitude(gimbal_pitch, gimbal_roll, gb_yaw, air_height, cam_latitude, cam_longitude)
                await frame_queue.put((frame, art_tit, time.time()))
        else:
            await asyncio.sleep(0.01)  # brief wait while the queue is full
    cap.release()
    # Remove the local files when done (was temporarily commented out during testing)
    if os.path.exists(srt_path):
        os.remove(srt_path)
    if os.path.exists(local_video_path):
        os.remove(local_video_path)
async def read_rtmp_frames(
        video_url: str,
        device: MQTTDevice = None,  # optional; pass None if OSD info is not needed
        topic_camera_osd: str = None,
        method_camera_osd: str = None,
        topic_osd_info: str = None,
        method_osd_info: str = None
):
    """
    Asynchronously read frames from an RTMP stream (optimized version).
    :param video_url: RTMP stream URL
    :param device: MQTTDevice instance (optional, used to fetch OSD info)
    :param topic_osd_info: MQTT topic carrying OSD info (optional)
    :param method_osd_info: MQTT method for OSD info (optional)
    """
    max_retries = 5  # maximum connection attempts
    retry_delay = 2  # delay between attempts (seconds)
    frame_timeout = 1.0  # per-frame timeout (seconds)
    for attempt in range(max_retries):
        container = None
        try:
            print(f"Connecting to RTMP stream (attempt {attempt + 1}/{max_retries}): {video_url}")
            # Open the stream directly; a read timeout would keep the process from
            # hanging if the stream drops, but none is configured here yet
            container = av.open(video_url)
            video_stream = next(s for s in container.streams if s.type == 'video')
            print(f"Connected to RTMP stream: {video_url} ({video_stream.width}x{video_stream.height})")
            # Optional OpenCV debug window
            # cv2.namedWindow('RTMP Stream', cv2.WINDOW_NORMAL)
            for frame in container.decode(video=0):
                # Check the stop event
                if stop_event.is_set():
                    print("Stop signal received; ending RTMP read")
                    break
                try:
                    # Convert to OpenCV format
                    img = frame.to_ndarray(format='bgr24')
                    # Fetch OSD info if requested
                    osd_info = None
                    if device and topic_osd_info and method_osd_info:
                        osd_msg = device.get_latest_message(topic=topic_osd_info, method=method_osd_info)
                        if osd_msg and hasattr(osd_msg, 'data'):
                            osd_info = Air_Attitude(
                                gimbal_pitch=osd_msg.data.gimbal_pitch,
                                gimbal_roll=osd_msg.data.gimbal_roll,
                                gimbal_yaw=osd_msg.data.gimbal_yaw,
                                height=osd_msg.data.height,
                                latitude=osd_msg.data.latitude,
                                longitude=osd_msg.data.longitude
                            )
                    # Optional debug display
                    # cv2.imshow('RTMP Stream', img)
                    # if cv2.waitKey(1) & 0xFF == ord('q'):
                    #     stop_event.set()
                    #     break
                    # Enqueue the frame
                    if not frame_queue.full():
                        await frame_queue.put((img, osd_info, time.time()))
                    else:
                        await asyncio.sleep(0.001)  # brief wait while the queue is full
                except Exception as frame_error:
                    print(f"Error while handling a frame: {frame_error}")
                    continue
            # The loop exiting normally means the stream ended
            print("RTMP stream ended")
            break
        except (av.AVError, IOError) as e:
            print(f"RTMP stream error (attempt {attempt + 1}/{max_retries}): {e}")
            if attempt < max_retries - 1:
                await asyncio.sleep(retry_delay)
            else:
                raise RuntimeError(f"Failed to connect to RTMP stream after {max_retries} attempts: {video_url}")
        except Exception as e:
            print(f"Unexpected error: {e}")
            if attempt < max_retries - 1:
                await asyncio.sleep(retry_delay)
            else:
                raise
        finally:
            # Make sure the container is closed
            if container:
                container.close()
            # cv2.destroyAllWindows()
async def process_frames(detector: MultiYOLODetector_TrackId):
    """Consume the frame queue and run detection."""
    while not stop_event.is_set():
        try:
            frame, osd_info, timestamp = await asyncio.wait_for(
                frame_queue.get(),
                timeout=0.5  # longer timeout to accommodate processing latency
            )
            try:
                detections, detections_list, model_para = await detector.predict(frame)
                predict_state = True
                if not detections:
                    predict_state = False
                    logger.debug("No targets detected")
                    continue
                processed_data = {
                    'frame': frame,
                    'osd_info': osd_info,
                    'detections': detections,
                    'detections_list': detections_list,
                    'timestamp': timestamp,
                    'model_para': model_para,
                    'predict_state': predict_state  # lets the RTMP writer branch on detection state
                }
                if not processed_queue.full():
                    await processed_queue.put(processed_data)
                else:
                    logger.warning("Processed queue full; dropping frame")
                    stats['dropped_frames'] += 1
            except Exception as e:
                logger.error(f"Error while processing frame: {e}", exc_info=True)
                stats['dropped_frames'] += 1
                await asyncio.sleep(0.1)
        except asyncio.TimeoutError:
            continue
        except Exception as e:
            logger.error(f"Unexpected error while fetching frame: {e}", exc_info=True)
            await asyncio.sleep(0.1)
from PIL import Image, ImageDraw, ImageFont
chinese_font = None
try:
    chinese_font = ImageFont.truetype("config/SIMSUN.TTC", 20)
except Exception:
    chinese_font = ImageFont.load_default()
def put_chinese_text(img, text, position, font=chinese_font, color=(0, 255, 0)):
    """Draw Chinese text using the preloaded font (OpenCV cannot render CJK glyphs)."""
    img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    draw = ImageDraw.Draw(img_pil)
    draw.text(position, text, font=font, fill=color)
    return cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
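# Usage sketch with a hypothetical frame: draw a CJK label onto a BGR image via
# PIL, since cv2.putText cannot render Chinese with the built-in Hershey fonts.
def _demo_put_chinese_text():
    canvas = np.zeros((120, 320, 3), dtype=np.uint8)
    labeled = put_chinese_text(canvas, "安全帽: 0.92", (10, 40), color=(0, 255, 0))
    cv2.imwrite("demo_label.jpg", labeled)  # hypothetical output path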
# Global registry of active output-stream containers
stream_containers: Dict[str, Any] = {}
from collections import defaultdict
class TrackIDEventFilter:
    def __init__(self, max_inactive_time: float = 5.0):
        """
        Initialize the track-ID filter.
        :param max_inactive_time: maximum inactive time in seconds; track_ids unseen for longer are removed
        """
        self.track_status = defaultdict(int)  # consecutive-sighting counter per track_id
        self.last_active_time = defaultdict(float)  # last time each track_id was seen
        self.max_inactive_time = max_inactive_time  # expiry threshold
    def should_report(self, track_id: int) -> bool:
        """
        Decide whether an event should be reported.
        :param track_id: tracking ID
        :return: True to report, False to filter out
        """
        current_time = time.time()
        # 1. Drop track_ids that have been inactive too long
        inactive_tracks = [
            tid for tid, last_time in self.last_active_time.items()
            if (current_time - last_time) > self.max_inactive_time
        ]
        for tid in inactive_tracks:
            self.track_status.pop(tid, None)
            self.last_active_time.pop(tid, None)
        # 2. Update the last-active time
        self.last_active_time[track_id] = current_time
        # 3. Filtering logic
        if self.track_status[track_id] == 0:
            # First sighting: record but do not report
            self.track_status[track_id] = 1
            return False
        elif self.track_status[track_id] == 1:
            # Second consecutive sighting: report and reset
            # (two more consecutive sightings are needed to trigger again)
            self.track_status[track_id] = 0
            return True
        else:
            # Any other case (e.g. first sighting after expiry): restart the count
            self.track_status[track_id] = 1
            return False
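# Behavior sketch: the filter reports a track_id on every second consecutive
# sighting, and ids unseen for max_inactive_time seconds are forgotten and
# start counting again from zero.
def _demo_track_id_filter():
    f = TrackIDEventFilter(max_inactive_time=5.0)
    print([f.should_report(7) for _ in range(4)])  # [False, True, False, True]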
async def write_results_to_rtmp(task_id: str, output_url: str = None, input_fps: float = None,
                                list_points: list[list[any]] = None, invade_state: bool = False):
    global stream_containers
    print(f"Output stream URL: {output_url}")
    # Encoder options for the output stream
    options = {
        'preset': 'veryfast',
        'tune': 'zerolatency',
        'crf': '23',
        'g': '50',  # keyframe interval
        'threads': '2',  # cap encoder threads
    }
    codec_name = 'libx264'
    # Video-output state
    output_video_path = None
    video_writer = None
    frame_width, frame_height = None, None
    fps = input_fps or Config.TARGET_FPS  # prefer the input frame rate
    # Optionally save the processed video locally
    if output_url:  # could also be driven by a config flag
        # Output path for the saved video (test path; adjust as needed)
        output_video_path = r"D:\project\AI-PYTHON\Ai_tottle\save_pic\a12.mp4"
        print(f"Saving processed video to: {output_video_path}")
    # Frame pacing
    frame_interval = 1.0 / fps
    last_frame_time = time.time()
    count_p = 0
    while not stop_event.is_set():
        try:
            # First layer of pacing: wait until the frame interval has elapsed
            # (wait rather than skip)
            current_time = time.time()
            time_diff = frame_interval - (current_time - last_frame_time)
            if time_diff > 0:
                await asyncio.sleep(time_diff)
            last_frame_time = current_time
            processed_data = await asyncio.wait_for(processed_queue.get(), timeout=10)
            # processed_data must be a dict
            if not isinstance(processed_data, dict):
                print(f"❌ Error: processed_data is not a dict but {type(processed_data)}")
                continue
            frame = processed_data['frame']
            predict_state = processed_data['predict_state']
            if predict_state:
                # Test code for inspecting inference output: initialize the
                # video writer on first use
                if video_writer is None and output_video_path:
                    frame_width = frame.shape[1]
                    frame_height = frame.shape[0]
                    # Codec and output file
                    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # or 'avc1', etc.
                    video_writer = cv2.VideoWriter(
                        output_video_path,
                        fourcc,
                        fps,
                        (frame_width, frame_height)
                    )
                    print(f"Video writer initialized, resolution: {frame_width}x{frame_height}, FPS: {fps}")
                osd_info = processed_data['osd_info']
                detections = processed_data['detections']
                detections_list = processed_data['detections_list']
                model_para = processed_data['model_para']
                class_names = model_para[0]["model_class_names"]
                chinese_label = model_para[0]["model_chinese_labe"]
                cls_map = model_para[0]["cls_map"]
                para_invade_enable = model_para[0]["para_invade_enable"]
                img_height, img_width = frame.shape[:2]
                results = []
                invade_point = []
                message_point = []
                cls_count = 0
                if osd_info:
                    gimbal_yaw = osd_info.gimbal_yaw
                    gimbal_pitch = osd_info.gimbal_pitch
                    gimbal_roll = osd_info.gimbal_roll
                    height = osd_info.height
                    cam_longitude = osd_info.longitude
                    cam_latitude = osd_info.latitude
                    # list_points is 2-D but currently holds a single polygon;
                    # this may change with business needs
                    for points in list_points:
                        # Reproject the polygon to pixel coordinates in bulk
                        results = red_line_reproject(gimbal_yaw, gimbal_pitch, gimbal_roll, height, cam_longitude,
                                                     cam_latitude,
                                                     img_width,
                                                     img_height, points)
                # Draw detection results
                frame_copy = frame.copy()
                # Per-class counters
                class_stats = defaultdict(int)
                reversed_dict = {value: 0 for value in cls_map.values()}
                bg_color = (173, 216, 230)  # light-blue text background
                text_color = (0, 255, 0)  # green
                for det in detections:
                    x1, y1, x2, y2 = map(int, det.bbox)  # integer coordinates
                    cls_id = det.class_id
                    class_name = det.class_name
                    confidence = det.confidence
                    track_id = det.track_id
                    # Update the per-class counter
                    class_stats[cls_id] += 1
                    # With intrusion detection on, only boxes inside the red line are shown
                    if invade_state:
                        point_x = (x1 + x2) / 2
                        point_y = (y1 + y2) / 2
                        is_invade = is_point_in_polygon(point_x, point_y, results)
                        if is_invade:
                            cls_count += 1
                            invade_point.append({
                                "u": point_x,
                                "v": point_y,
                                "class_name": class_name
                            })
                            model_list_func_id = model_para[0]["model_list_func_id"]
                            model_func_id = model_para[0]["func_id"]
                            en_name = model_para[0]["model_chinese_labe"][
                                model_para[0]["model_class_names"].index(class_name)]
                            message_point.append({
                                "confidence": confidence,
                                "cls_id": cls_id,
                                "type_name": en_name,
                                "box": [x1, y1, x2, y2]
                            })
                            label = f"{confidence:.2f}:{track_id}"
                            # Text placement
                            text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0]
                            text_width, text_height = text_size[0], text_size[1]
                            text_x = x1
                            text_y = y1 - 5
                            # If the text would run off the top, draw it inside, below the box
                            if text_y < 0:
                                text_y = y2 + text_height + 5
                            temp_img = frame_copy.copy()
                            frame_copy = put_chinese_text(
                                temp_img,
                                label,
                                (text_x, text_y - text_height),
                            )
                    else:
                        # Draw the bounding box
                        cv2.rectangle(frame_copy, (x1, y1), (x2, y2), (0, 255, 255), 2)
                        # Label text
                        label = f"{chinese_label.get(cls_id, class_name)}: {confidence:.2f}:{track_id}"
                        # Text placement
                        text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0]
                        text_width, text_height = text_size[0], text_size[1]
                        text_x = x1
                        text_y = y1 - 5
                        # If the text would run off the top, draw it inside, below the box
                        if text_y < 0:
                            text_y = y2 + text_height + 5
                        # Text background padding
                        padding = 2
                        temp_img = frame_copy.copy()
                        frame_copy = put_chinese_text(
                            temp_img,
                            label,
                            (text_x, text_y - text_height),
                        )
                if invade_state:
                    for point in message_point:
                        cv2.rectangle(frame_copy, (point["box"][0], point["box"][1]),
                                      (point["box"][2], point["box"][3]),
                                      (0, 255, 255), 2)
                    # Draw the red line
                    point_list = []
                    if results:
                        for point in results:
                            point_list.append([point["u"], point["v"]])
                        cv2.polylines(frame_copy, [np.array(point_list, dtype=np.int32)], isClosed=True,
                                      color=(0, 0, 255), thickness=2)
                # Per-class statistics in the top-left corner
                stats_text = []
                for cls_id, count in class_stats.items():
                    cls_name = chinese_label.get(cls_id,
                                                 class_names[cls_id] if class_names and cls_id < len(
                                                     class_names) else str(
                                                     cls_id))
                    reversed_dict[cls_name] = count
                for key, value in reversed_dict.items():
                    stats_text.append(f"{key}: {value}")
                if stats_text:
                    # Total height of the statistics block
                    text_height = cv2.getTextSize("Test", cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0][1]
                    total_height = len(stats_text) * (text_height + 5)  # 5 px line spacing
                    # Anchor at the top-left
                    start_x = 50
                    start_y = 20
                    # Background for the statistics block
                    frame_copy = cv2.rectangle(
                        frame_copy,
                        (start_x - 10, start_y - 10),
                        (200, start_y + total_height + 10),
                        bg_color,
                        -1
                    )
                    # Draw the statistics line by line
                    for i, text in enumerate(stats_text):
                        y_pos = start_y + i * (text_height + 5)
                        temp_img = frame_copy.copy()
                        frame_copy = put_chinese_text(
                            temp_img,
                            text,
                            (start_x, y_pos),
                            color=text_color
                        )
                new_data = {
                    'frame_copy': frame_copy,
                    'frame': frame,
                    "osd_info": osd_info,
                    'detections': detections,
                    'timestamp': processed_data.get('timestamp'),
                    "detections_list": detections_list,
                    "model_para": model_para
                }
                # Temporary routing; the RTMP and intrusion logic still needs rework
                if invade_state:
                    # Intrusion checks run only when para_invade_enable is set in para_list
                    if not invade_queue.full():
                        await invade_queue.put(new_data)
                else:
                    if not cv_frame_queue.full():
                        await cv_frame_queue.put(new_data)
                count_p = count_p + 1
                # cv2.imwrite(f"save_pic/rtmp/test-{count_p}.jpg", frame_copy)
                # video_writer.write(frame_copy)
            else:
                frame_copy = frame
            # Convert color space
            rgb_frame = cv2.cvtColor(frame_copy, cv2.COLOR_BGR2RGB)
            # Initialize the output container on first use
            if output_url and output_url not in stream_containers:
                try:
                    container = av.open(output_url, mode='w', format='flv')
                    stream = container.add_stream(codec_name, rate=Config.TARGET_FPS)
                    stream.width = frame.shape[1]
                    stream.height = frame.shape[0]
                    stream.pix_fmt = 'yuv420p'
                    stream.options = options
                    stream_containers[output_url] = {
                        'container': container,
                        'stream': stream,
                        'last_frame_time': time.time(),
                        'frame_count': 0
                    }
                    print(f"✅ Output stream initialized: {output_url}")
                except Exception as e:
                    print(f"❌ Output stream initialization failed: {e}")
                    if 'container' in locals():
                        try:
                            container.close()
                        except Exception:
                            pass
                    await asyncio.sleep(1.0)
                    continue
            # Push the frame
            if output_url and output_url in stream_containers:
                try:
                    container_info = stream_containers[output_url]
                    stream = container_info['stream']
                    container = container_info['container']
                    if rgb_frame.dtype == np.uint8:
                        av_frame = av.VideoFrame.from_ndarray(rgb_frame, format='rgb24')
                        packets = stream.encode(av_frame)  # synchronous; the encoder may buffer frames
                        print(f"📦 encode produced {len(packets)} packet(s)")
                        if packets:
                            for packet in packets:
                                container.mux(packet)
                            container_info['last_frame_time'] = time.time()  # update only after a real mux
                            container_info['frame_count'] += 1
                        else:
                            # Encoder is still warming up; leave last_frame_time alone
                            pass
                        # Status line every 100 frames
                        if container_info['frame_count'] % 100 == 0:
                            print(f" Pushed {container_info['frame_count']} frames to {output_url}")
                    else:
                        print(f"⚠️ Invalid frame format: {rgb_frame.dtype}")
                except Exception as e:
                    print(f"❌ Streaming error: {e}")
                    try:
                        container_info['container'].close()
                    except Exception:
                        pass
                    stream_containers.pop(output_url, None)
                    await asyncio.sleep(1.0)
        except asyncio.TimeoutError:
            # Check the health of existing output streams
            for url, info in list(stream_containers.items()):
                if time.time() - info['last_frame_time'] > 5:  # restart after 5 s without frames
                    print(f"⚠️ Output stream stalled; restarting {url}")
                    try:
                        info['container'].close()
                    except Exception:
                        pass
                    stream_containers.pop(url)
            continue
        except Exception as e:
            print(f"❌ Streaming loop error: {e}")
            await asyncio.sleep(0.1)
# Ray-casting test: is the target inside the red-line polygon?
def is_point_in_polygon(point_x, point_y, polygon):
    x, y = point_x, point_y
    n = len(polygon)
    inside = False
    for i in range(n):
        p1 = polygon[i]
        p2 = polygon[(i + 1) % n]
        x1, y1 = p1["u"], p1["v"]
        x2, y2 = p2["u"], p2["v"]
        # Check whether the point lies on a vertex or an edge
        if (x == x1 and y == y1) or (x == x2 and y == y2):
            return True  # on a vertex
        if (x - x1) * (y2 - y1) == (y - y1) * (x2 - x1):  # collinear with the edge
            if min(x1, x2) <= x <= max(x1, x2) and min(y1, y2) <= y <= max(y1, y2):
                return True  # on the edge
        # Core ray-casting step
        if (y1 > y) != (y2 > y):  # the edge spans the point vertically
            x_intersect = (x2 - x1) * (y - y1) / (y2 - y1) + x1
            if x <= x_intersect:
                inside = not inside
    return inside
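# Ray-casting sketch with a hypothetical square: polygon points are dicts with
# pixel keys "u"/"v", matching the output of red_line_reproject().
def _demo_point_in_polygon():
    square = [{"u": 0, "v": 0}, {"u": 100, "v": 0},
              {"u": 100, "v": 100}, {"u": 0, "v": 100}]
    print(is_point_in_polygon(50, 50, square))   # True (inside)
    print(is_point_in_polygon(150, 50, square))  # False (outside)
    print(is_point_in_polygon(0, 0, square))     # True (on a vertex)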
async def cal_des_invade(task_id: str, mqtt, mqtt_publish_topic,
                         list_points: list[list[any]], model_count: int):
    loop = asyncio.get_running_loop()
    pic_count = 0
    pic_count_hongxian = 0
    track_id_filters = []
    for i in range(model_count):
        track_id_filters.append(TrackIDEventFilter(max_inactive_time=5.0))  # expire ids unseen for 5 s
    # Track IDs already reported, with their report times
    reported_track_ids = defaultdict(float)
    # Minimum interval between reports for the same id (seconds)
    report_interval = 2.0
    while not stop_event.is_set():
        cv_frame = None
        # Check queue depth to avoid backlog
        if invade_queue.qsize() > Config.PROCESSED_QUEUE_SIZE // 2:
            print(f"Warning: invade_queue backlog (size={invade_queue.qsize()}); draining queue")
            while not invade_queue.empty():
                await invade_queue.get()
            print("Queue drained; waiting for new data...")
            await asyncio.sleep(0.1)
            continue
        # Fetch from the queue
        try:
            cv_frame = await asyncio.wait_for(invade_queue.get(), timeout=0.8)
            # Sanity-check the payload type
            if not isinstance(cv_frame, dict):
                print(f"⚠️ Warning: cv_frame is not a dict but {type(cv_frame)}")
                continue
        except asyncio.TimeoutError:
            print("cal_des_invade TimeoutError")
            continue
        if cv_frame is None:
            continue
        frame_copy = cv_frame['frame_copy']
        frame = cv_frame['frame']
        air_alti = cv_frame['osd_info']
        detections = cv_frame['detections']
        detections_list = cv_frame['detections_list']
        model_para_list = cv_frame.get('model_para', {})  # default to empty
        if not isinstance(frame, np.ndarray) or frame.size == 0:
            print("⚠️ Warning: frame is not a valid numpy array")
            continue
        # Frame dimensions
        img_height, img_width = frame.shape[:2]
        if not air_alti:
            continue
        gimbal_yaw = air_alti.gimbal_yaw
        gimbal_pitch = air_alti.gimbal_pitch
        gimbal_roll = air_alti.gimbal_roll
        height = air_alti.height
        cam_longitude = air_alti.longitude
        cam_latitude = air_alti.latitude
        # list_points is 2-D by design (one inner list per polygon)
        for points in list_points:
            # Reproject the polygon to pixel coordinates in bulk
            results = red_line_reproject(gimbal_yaw, gimbal_pitch, gimbal_roll, height, cam_longitude, cam_latitude,
                                         img_width,
                                         img_height, points)
        try:
            invade_point = []
            message_point = []
            cls_count = 0
            model_id = 101101
            current_time = time.time()
            for idx, model_para in enumerate(model_para_list):
                if not isinstance(detections_list, list) or idx >= len(detections_list):
                    continue
                detections = detections_list[idx]  # detections for the current model
                if detections is None or len(detections.boxes) < 1:
                    continue
                if idx >= len(track_id_filters):
                    track_id_filters.append(TrackIDEventFilter(max_inactive_time=5.0))
                track_filter = track_id_filters[idx]
                for bbox, class_id, class_name, confidence, track_id in detections:
                    # Skip invalid track_ids
                    if track_id is None:
                        continue
                    # Decide whether this track_id should be reported
                    should_report = True
                    # If the id was reported before, enforce the report interval
                    if track_id in reported_track_ids:
                        last_report_time = reported_track_ids[track_id]
                        if current_time - last_report_time < report_interval:
                            print("Event deduplicated by track_id")
                            should_report = False
                    # Let the TrackIDEventFilter decide next
                    if should_report:
                        should_report = track_filter.should_report(track_id)
                    if track_id < 0:  # MultiYOLODetector has no tracking and stores track_id = -1
                        should_report = True
                    # Report when the filters let it through
                    if should_report:
                        x1, y1, x2, y2 = map(int, bbox)  # integer coordinates
                        point_x = (x1 + x2) / 2
                        point_y = (y1 + y2) / 2
                        is_invade = is_point_in_polygon(point_x, point_y, results)
                        if is_invade:
                            cls_count += 1
                            invade_point.append({
                                "u": point_x,
                                "v": point_y,
                                "class_name": class_name
                            })
                            model_list_func_id = model_para["model_list_func_id"]
                            model_func_id = model_para["func_id"]
                            en_name = model_para["model_chinese_labe"][
                                model_para["model_class_names"].index(class_name)]
                            message_point.append({
                                "confidence": confidence,
                                "cls_id": class_id,
                                "type_name": en_name,
                                "box": [x1, y1, x2, y2],
                                "track_id": track_id  # include the track_id in the report
                            })
                            # Record the report time
                            reported_track_ids[track_id] = current_time
                            label = f"{confidence:.2f}:{track_id}"
                            text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0]
                            text_x = x1
                            text_y = y1 - 5
                            if text_y < 0:
                                text_y = y2 + text_size[1] + 5
                            temp_img = frame_copy.copy()
                            frame_copy = put_chinese_text(
                                temp_img,
                                label,
                                (text_x, text_y - text_size[1]),
                            )
            point_list = []  # red-line pixel points
            # Draw the red line for inspection
            for point in results:
                point_list.append([point["u"], point["v"]])
            for point in message_point:
                cv2.rectangle(frame_copy, (point["box"][0], point["box"][1]),
                              (point["box"][2], point["box"][3]), (0, 255, 255), 2)
            cv2.polylines(frame_copy, [np.array(point_list, dtype=np.int32)],
                          isClosed=True, color=(0, 0, 255), thickness=2)
            pic_count_hongxian = pic_count_hongxian + 1
            # cv2.imwrite(f"save_pic/invade-hongxian/hongxianongly-{pic_count_hongxian}.jpg", frame_copy)
            if len(invade_point) > 0:
                for point in results:
                    point_list.append([point["u"], point["v"]])
                for point in message_point:
                    cv2.rectangle(frame, (point["box"][0], point["box"][1]),
                                  (point["box"][2], point["box"][3]), (0, 255, 255), 2)
                # Draw the red line
                cv2.polylines(frame, [np.array(point_list, dtype=np.int32)],
                              isClosed=True, color=(0, 0, 255), thickness=2)
                print("Red-line intrusion detected")
                pic_count = pic_count + 1
                # cv2.imwrite(f"save_pic/invade/hongxian-{pic_count}.jpg", frame)
                drawn_frame = frame.copy()  # key fix: deep-copy the drawn frame
                # Encode the image
                def encode_frame():
                    success, buffer = cv2.imencode(".jpg", drawn_frame)
                    return buffer.tobytes() if success else None
                buffer_bytes = await loop.run_in_executor(upload_executor, encode_frame)
                if not buffer_bytes:
                    continue
                # Upload and publish in parallel
                async def upload_and_publish():
                    # Upload to MinIO
                    def upload_minio():
                        return upload_frame_buff_from_buffer(buffer_bytes, None)
                    minio_path, file_type = await loop.run_in_executor(
                        upload_executor, upload_minio
                    )
                    # Build the message
                    message = {
                        "task_id": task_id,
                        "minio": {"minio_path": minio_path, "file_type": file_type},
                        "box_detail": [{
                            "model_id": model_list_func_id,
                            "cls_count": cls_count,
                            "box_count": [message_point]
                        }]
                    }
                    print(f"Uploading red-line intrusion report: {message}")
                    message_json = json.dumps(message, ensure_ascii=False)
                    await mqtt.publish(mqtt_publish_topic, message_json)
                # Lock around the shared counter to skip frames; otherwise reports fire too often
                async with invade_cache_lock:
                    send_count = shared_local_cache["invade_send_count"]
                    send_count = send_count + 1
                    shared_local_cache["invade_send_count"] = send_count
                    if send_count > 1:
                        # Run upload-and-publish as an independent task
                        shared_local_cache["invade_send_count"] = 0
                        asyncio.create_task(upload_and_publish())
        except Exception as e:
            print(f"cal_des_invade error: {e}")
            import traceback
            traceback.print_exc()
            await asyncio.sleep(0.1)
# Globally shared state
shared_local_cache = {
    "send_count": 0,
    "invade_send_count": 0,
    "video_process_status": 0  # 0 / 1 / 2: recording recognition not started / running / finished
}
from asyncio import Lock
cache_lock = Lock()  # guards the shared counters
invade_cache_lock = Lock()  # guards the intrusion counters
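# Minimal sketch of the guarded-counter pattern used by the senders below:
# increment under the lock and let only every n-th frame through, so uploads
# and MQTT reports are throttled. _demo_should_send is illustrative only.
async def _demo_should_send(n: int = 30) -> bool:
    async with cache_lock:
        shared_local_cache["send_count"] += 1
        if shared_local_cache["send_count"] > n:
            shared_local_cache["send_count"] = 0
            return True  # caller should upload/publish this frame
    return False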
async def send_frame_to_s3_mq(task_id, mqtt, mqtt_topic):
    global stats
    count_pic = 0
    loop = asyncio.get_running_loop()
    local_func_cache = {
        "func_100000": None,
        "func_100004": None,  # cached person track_ids
        "func_100006": None  # cached vehicle track_ids
    }
    para = {
        "category": 3
    }
    while not stop_event.is_set():
        try:
            # Check queue depth to avoid backlog
            if cv_frame_queue.qsize() > Config.PROCESSED_QUEUE_SIZE // 2:
                print(f"Warning: cv_frame_queue backlog (size={cv_frame_queue.qsize()}); draining queue")
                while not cv_frame_queue.empty():
                    await cv_frame_queue.get()
                print("Queue drained; waiting for new data...")
                await asyncio.sleep(0.1)
                continue
            # Fetch from the queue
            try:
                cv_frame = await asyncio.wait_for(cv_frame_queue.get(), timeout=0.05)
                # Sanity-check the payload type
                if not isinstance(cv_frame, dict):
                    print(f"⚠️ Warning: cv_frame is not a dict but {type(cv_frame)}")
                    continue
            except asyncio.TimeoutError:
                continue
            # Unpack the payload
            frame_copy = cv_frame['frame_copy']
            frame = cv_frame['frame']
            detections = cv_frame['detections']
            detections_list = cv_frame['detections_list']
            model_para_list = cv_frame.get('model_para', {})  # default to empty
            # Defaults
            frame11 = frame_copy  # fall back to the annotated frame
            box_detail = []  # default empty
            for idx, model_para in enumerate(model_para_list):
                if not isinstance(detections_list, list) or idx >= len(detections_list):
                    continue
                detections = detections_list[idx]  # detections for the current model
                chinese_label = model_para.get("model_chinese_labe", {})
                model_cls = model_para.get("model_cls_index", {})
                list_func_id = model_para.get("model_list_func_id", -11)
                func_id = model_para.get("func_id", [])
                # Fetch the DRC message (synchronous call, run in the thread pool)
                local_drc_message = await loop.run_in_executor(executor, get_local_drc_message)
                if detections is None or len(detections.boxes) < 1:
                    continue
                try:
                    # Image processing and result computation
                    result = await loop.run_in_executor(
                        executor,
                        cal_tricker_results,
                        frame_copy, detections, None,
                        func_id, local_func_cache, para, model_cls, chinese_label, list_func_id
                    )
                    # The result must be a (frame, box_detail) tuple
                    if isinstance(result, tuple) and len(result) == 2:
                        frame11, box_detail = result
                    else:
                        print(f"⚠️ Warning: cal_tricker_results returned an unexpected result: {type(result)}")
                        continue
                except Exception as e:
                    print(f"Error while processing frame: {str(e)}")
                    import traceback
                    traceback.print_exc()
                    continue
count_pic = count_pic + 1
# 图像编码
def encode_frame():
success, buffer = cv2.imencode(".jpg", frame11)
return buffer.tobytes() if success else None
def encode_origin_frame():
success, buffer = cv2.imencode(".jpg", frame)
return buffer.tobytes() if success else None
buffer_bytes = await loop.run_in_executor(upload_executor, encode_frame)
buffer_origin_bytes = await loop.run_in_executor(upload_executor, encode_origin_frame)
if not buffer_bytes:
continue
# 并行处理上传和MQTT发布
async def upload_and_publish():
# 上传到MinIO
def upload_minio():
minio_path, file_type = upload_frame_buff_from_buffer(buffer_bytes, None)
minio_origin_path, file_type = upload_frame_buff_from_buffer(buffer_origin_bytes, None)
return minio_path, minio_origin_path, file_type
minio_path, minio_origin_path, file_type = await loop.run_in_executor(
upload_executor, upload_minio
)
# 构造消息
message = {
"task_id": task_id,
"minio": {"minio_path": minio_path, "minio_origin_path": minio_origin_path,
"file_type": file_type},
"box_detail": box_detail,
"uav_location": local_drc_message
}
message_json = json.dumps(message, ensure_ascii=False)
await mqtt.publish(mqtt_topic, message_json)
            # Take the lock before touching the shared counter, then skip frames so
            # events are not reported too frequently
            async with cache_lock:
                send_count = shared_local_cache["send_count"] + 1
                shared_local_cache["send_count"] = send_count
                # Helmets (func 100014) are currently hard to detect, so report more often for them
                if model_para_list and model_para_list[0].get("model_list_func_id") == 100014:
                    if send_count > 2:
                        # Fire-and-forget: upload and publish in an independent task
                        shared_local_cache["send_count"] = 0
                        asyncio.create_task(upload_and_publish())
                else:
                    if send_count > 30:
                        # Fire-and-forget: upload and publish in an independent task
                        shared_local_cache["send_count"] = 0
                        asyncio.create_task(upload_and_publish())
        except Exception as e:
            print(f"send_frame_to_s3_mq error: {e}")
            import traceback
            traceback.print_exc()
            await asyncio.sleep(0.1)
        # Update throughput statistics
        stats['processed'] += 1
        if time.time() - stats['last_time'] >= 1.0:
            stats['avg_fps'] = stats['processed'] / (time.time() - stats['last_time'])
            print(f"Throughput: {stats['avg_fps']:.2f} FPS")
            stats['processed'] = 0
            stats['last_time'] = time.time()
async def start_rtmp_processing(video_url: str, task_id: str, model_configs: List[Dict],
mqtt_pub_ip: str, mqtt_pub_port: int, mqtt_pub_topic: str,
mqtt_sub_ip: str, mqtt_sub_port: int, mqtt_sub_topic: str,
output_rtmp_url: str,
invade_enable: bool, invade_file: str):
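    # Pipeline layout: one producer reads RTMP frames plus OSD telemetry from MQTT,
    # one task runs the detectors, one task writes annotated frames back out over
    # RTMP; depending on configuration, either an invasion-detection task or two
    # upload consumers report events over MQTT.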
    await initialize_resources()  # set up thread pools, queues and the stop event
    try:
        list_points = []  # list of polygons; each inner list holds one polygon's points
        if invade_file:
            invade_file_path = downFile(invade_file)
            if invade_file_path is None:
                print(f"invade_file is None task_id:{task_id}")
                return
            with open(invade_file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # Extract polygon coordinates from the GeoJSON feature collection
            features = data.get('features', [])
            if not features:
                print("No valid feature data found")
                return
            # Walk every polygon feature
            for polygon in features:
                geometry = polygon.get('geometry', {})
                coordinates = geometry.get('coordinates', [])
                points = []
                for key, coord in enumerate(coordinates[0]):  # outer ring only
                    # GeoJSON stores [lon, lat, alt]; pass (lat, lon) here, otherwise
                    # convert_points_to_utm fails later
                    points.append(Point(coord[1], coord[0], coord[2], key))
                list_points.append(points)
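        # For reference, the invade_file parsed above is assumed to be a GeoJSON
        # FeatureCollection shaped roughly like (illustrative sketch):
        # {
        #     "features": [
        #         {"geometry": {"coordinates": [[[lon, lat, alt], [lon, lat, alt], ...]]}}
        #     ]
        # }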
        mqtt = MQTTService(mqtt_pub_ip, port=mqtt_pub_port)
        # Initialize the multi-model detector (track-id variant)
        # detector = MultiYOLODetector(model_configs)
        detector = MultiYOLODetector_TrackId(model_configs)
        mqtt_publish_topic = mqtt_pub_topic
        device = MQTTDevice(
            ip=mqtt_sub_ip,           # MQTT broker IP
            port=mqtt_sub_port,       # MQTT broker port
            topics=[mqtt_sub_topic],  # subscribed topics
            queue_size=50             # max message queue length per method
        )
        topic_camera_osd = mqtt_sub_topic
        method_camera_osd = "drc_camera_osd_info_push"  # matches the "method" field in the message
        topic_osd_info = mqtt_sub_topic
        method_osd_info = "osd_info_push"  # matches the "method" field in the message
        await device.start()
        await asyncio.sleep(10)  # give the MQTT device 10 seconds to start receiving messages
        # Spawn the pipeline tasks
        invade_state = False
        if len(list_points) > 0 and invade_enable:
            invade_state = True
        tasks = [
            asyncio.create_task(read_rtmp_frames(video_url, device, topic_camera_osd, method_camera_osd,
                                                 topic_osd_info, method_osd_info)),
            asyncio.create_task(process_frames(detector)),
            asyncio.create_task(write_results_to_rtmp(task_id, output_rtmp_url, None, list_points, invade_state)),
        ]
        # If an invasion (boundary) file is configured, events are reported by the
        # invasion task; otherwise (and only when not re-streaming) detections are
        # reported directly over MQTT by the upload consumers.
        model_count = len(model_configs)
        if len(list_points) > 0 and invade_enable:
            tasks.append(asyncio.create_task(
                cal_des_invade(task_id, mqtt, mqtt_publish_topic, list_points, model_count)))
        elif output_rtmp_url is None or len(output_rtmp_url) < 1:
            for _ in range(2):  # two upload consumers
                tasks.append(asyncio.create_task(send_frame_to_s3_mq(task_id, mqtt, mqtt_pub_topic)))
        # Wait for every task to finish
        await asyncio.gather(*tasks, return_exceptions=True)
    finally:
        await cleanup_resources()  # release thread pools, queues and the stop event
async def start_video_processing(minio_path: str, task_id: str, model_configs: List[Dict],
mqtt_ip: str, mqtt_port: int, mqtt_topic: str, output_rtmp_url: str,
invade_enable: bool, invade_file: str):
    await initialize_resources()  # set up thread pools, queues and the stop event
    print("Starting recorded-video recognition")
try:
download_path = downBigFile(minio_path)
# download_path = r"E:\hami\gdaq\aqm_0829-55-7_cut_output.mp4"
# download_path = r"E:\hami\gdaq\aqm_0829-55-7_cut.mp4"
# download_path = r"E:\hami\aqm_0829-55-7_cut.mp4"
# # 近距离 1/3
# download_path = r"C:\Users\14867\Downloads\DJI_20250912140036_0001_V.mp4"
# srt_path = r"C:\Users\14867\Downloads\DJI_20250912140036_0001_V.mp4.srt"
# 近距离 1/5
# download_path = r"C:\Users\14867\Downloads\DJI_20250912165859_0001_V.mp4"
# download_path = r"C:\Users\14867\Downloads\DJI_20250912202543_0001_V.mp4"
except Exception as e:
logger.error(f"Unexpected error downloading big file: {e}", exc_info=True)
return None
    if not download_path:
        logger.error("downloaded video file path is empty")
        return
print(f"download_path路径 {os.path.abspath(download_path)}")
dir_name = os.path.dirname(download_path)
srt_name = os.path.basename(download_path) + ".srt"
srt_path = os.path.join(dir_name, srt_name)
if os.path.exists(srt_path):
os.remove(srt_path)
    command = [
        "ffmpeg",
        "-i", download_path,
        "-map", "0:s:0",  # select the first subtitle stream
        "-c:s", "srt",    # force SRT output format (optional)
        srt_path
    ]
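    # The list above is equivalent to running:
    #   ffmpeg -i <video.mp4> -map 0:s:0 -c:s srt <video.mp4.srt>
    # i.e. extract the first subtitle stream of the recording (DJI recordings carry
    # per-frame telemetry there) and save it as an SRT file next to the video. If
    # the file has no subtitle stream, ffmpeg exits non-zero and the
    # CalledProcessError branch below returns early.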
    try:
        subprocess.run(command, check=True)
        print(f"start_video_processing: subtitle extraction succeeded, saved to: {srt_path}")
    except subprocess.CalledProcessError as e:
        print(f"start_video_processing error: FFmpeg failed - {e}")
        return
    except FileNotFoundError:
        print("start_video_processing error: FFmpeg is not installed; install it and add it to PATH.")
        return
    print(f"srt_path: {os.path.abspath(srt_path)}")
    try:
        list_points = []  # list of polygons; each inner list holds one polygon's points
        if invade_file:
            invade_file_path = downFile(invade_file)
            if invade_file_path is None:
                print(f"invade_file is None task_id:{task_id}")
                return
            with open(invade_file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # Extract polygon coordinates from the GeoJSON feature collection
            features = data.get('features', [])
            if not features:
                print("No valid feature data found")
                return
            # Walk every polygon feature
            for polygon in features:
                geometry = polygon.get('geometry', {})
                coordinates = geometry.get('coordinates', [])
                points = []
                for key, coord in enumerate(coordinates[0]):  # outer ring only
                    # GeoJSON stores [lon, lat, alt]; pass (lat, lon) here, otherwise
                    # convert_points_to_utm fails later
                    points.append(Point(coord[1], coord[0], coord[2], key))
                list_points.append(points)
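        # Same GeoJSON parsing as in start_rtmp_processing: one list of
        # (lat, lon, alt) points per polygon feature, outer ring only.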
        mqtt = MQTTService(mqtt_ip, port=mqtt_port)
        # Initialize the multi-model detector (track-id variant)
        # detector = MultiYOLODetector(model_configs)
        detector = MultiYOLODetector_TrackId(model_configs)
        mqtt_publish_topic = mqtt_topic
        # Probe the video for its frame rate (used to pace the RTMP writer)
        cap = cv2.VideoCapture(download_path)
        if not cap.isOpened():
            print("Error: Could not open video.")
            return
        fps = cap.get(cv2.CAP_PROP_FPS)
        cap.release()
invade_state = False
if len(list_points) > 0 and invade_enable:
invade_state = True
model_count = len(model_configs)
tasks = [
asyncio.create_task(read_video_frames(task_id, mqtt, mqtt_publish_topic, download_path, srt_path)),
asyncio.create_task(process_frames(detector)),
asyncio.create_task(write_results_to_rtmp(task_id, output_rtmp_url, fps, list_points, invade_state)),
]
        # If an invasion (boundary) file is configured, also report invasion events
        if len(list_points) > 0 and invade_enable:
            print("Starting invasion detection on recorded video")
tasks.append(asyncio.create_task(
cal_des_invade(task_id, mqtt, mqtt_publish_topic, list_points, model_count)))
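        # Note: unlike start_rtmp_processing, the upload consumers below run
        # unconditionally, so detection events are published even while invasion
        # detection is active.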
        # Spawn parallel upload consumers
        for _ in range(2):  # two upload consumers
            tasks.append(asyncio.create_task(send_frame_to_s3_mq(task_id, mqtt, mqtt_topic)))
        # Wait for every task to finish
        await asyncio.gather(*tasks, return_exceptions=True)
    finally:
        await cleanup_resources()  # release thread pools, queues and the stop event