ai_project_v1/yolo/cv_multi_model_back_video.py

3166 lines
136 KiB
Python
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import math
import os.path
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from time import sleep
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
import json
import time
from ultralytics import YOLO
import torch
# 假设其他导入模块正确存在
from cv_back_video import cal_tricker_results, get_local_drc_message, read_drc_mqtt, async_read_drc_mqtt
from middleware.MQTTService import MQTTService, MQTTDevice, empty_osd_callback
from middleware.TaskManager import task_manager
from middleware.entity.air_attitude import Air_Attitude
from middleware.entity.camera_para import read_camera_params, Camera_Para
from middleware.entity.detection import DetectionResult, DetectionResultList
from middleware.entity.timestamp_queue import TimestampedQueue
from middleware.minio_util import upload_file_from_buffer, upload_frame_buff_from_buffer, downFile, \
downBigFile, upload_video_buff_from_buffer
import av
import asyncio
import cv2
import numpy as np
from typing import Dict, Any
from ultralytics import solutions
from middleware.read_srt import parse_srt_file
from touying.ImageReproject_python.cal_func import cal_canv_location_by_osd, red_line_reproject
from touying.ImageReproject_python.img_types import Point
from yolo.detect.multi_yolo_trt_detect_track import MultiYoloTrtDetectorTrackId
# Configure root logging once at import time (INFO level, timestamped format)
# and create the module-level logger used throughout this file.
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Global configuration
class Config:
    """Static tuning constants shared by the detection / streaming pipeline."""

    # --- thread-pool and queue sizing ---
    MAX_WORKERS = 8  # thread-pool size for image recognition
    FRAME_QUEUE_SIZE = 5  # capacity of the raw frame queue
    PROCESSED_QUEUE_SIZE = 5  # capacity of the processed frame queue
    RETRY_COUNT = 3  # number of retries when loading a model
    READ_RTMP_WORKERS = 2  # dedicated pool: stream reading
    PROCESS_FRAME_WORKERS = 2  # dedicated pool: inference
    UPLOAD_WORKERS = 2  # dedicated pool: uploads
    WRITE_FRAME_WORKERS = 2  # dedicated pool: video push streaming
    INVADE_WORKERS = 2  # dedicated pool: intrusion computation
    EVENT_VIDEO_WORKERS = 2  # dedicated pool: event video recording
    # Queue warning threshold.
    # NOTE(review): the original comment called this "half of
    # PROCESSED_QUEUE_SIZE", which does not match the values (15 vs 5) - confirm.
    MAX_QUEUE_WARN_SIZE = 15

    # --- model / device ---
    MODEL_INPUT_SIZE = (640, 640)  # YOLO input size, must be divisible by 32
    # Use the first GPU when CUDA is available, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        DEFAULT_DEVICE = "cuda:0"
    else:
        DEFAULT_DEVICE = "cpu"

    # --- streaming / diagnostics ---
    TARGET_FPS = 25  # target push-stream frame rate
    MAX_FRAME_DROP = 5  # maximum tolerated dropped frames (avoids queue backlog)
    PERF_LOG_INTERVAL = 5.0  # seconds between performance summaries
# Global performance counters.
# Each stage name maps to {'count': number of timed calls,
# 'total_time': accumulated seconds}; 'last_log_time' is the wall-clock time
# of the most recent summary printed by log_perf().
perf_stats = {
'read_rtmp_frames': {'count': 0, 'total_time': 0.0},
'process_frames': {'count': 0, 'total_time': 0.0},
'write_results_to_rtmp': {'count': 0, 'total_time': 0.0},
'cal_des_invade': {'count': 0, 'total_time': 0.0},
'send_frame_to_s3_mq': {'count': 0, 'total_time': 0.0},
'last_log_time': time.time()
}
def log_perf(func_name, start_time):
    """Record one timed call of ``func_name`` and periodically print a summary.

    Args:
        func_name: Stage name used as the key in ``perf_stats``. Names that are
            not registered yet are added on first use instead of raising
            ``KeyError``.
        start_time: ``time.time()`` value captured when the stage started.

    Raises:
        ValueError: if ``func_name`` is the reserved ``'last_log_time'`` key.
    """
    if func_name == 'last_log_time':
        # That slot stores a float timestamp, not a counter dict.
        raise ValueError("'last_log_time' is reserved and cannot be used as a stage name")

    now = time.time()
    entry = perf_stats.setdefault(func_name, {'count': 0, 'total_time': 0.0})
    entry['count'] += 1
    entry['total_time'] += now - start_time

    # Print the aggregated summary at most once per PERF_LOG_INTERVAL seconds.
    if now - perf_stats['last_log_time'] < Config.PERF_LOG_INTERVAL:
        return
    print("\n=== 性能统计 ===")
    for name, counters in perf_stats.items():
        if name == 'last_log_time':
            continue
        avg_time = counters['total_time'] / counters['count'] if counters['count'] > 0 else 0
        print(f"{name}: 调用次数={counters['count']}, 平均耗时={avg_time:.4f}s")
    print("===============\n")
    perf_stats['last_log_time'] = now
# # 创建专用线程池
# upload_executor = ThreadPoolExecutor(max_workers=Config.UPLOAD_WORKERS)
# frame_queue = asyncio.Queue(maxsize=Config.FRAME_QUEUE_SIZE)
# processed_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
# invade_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
# cv_frame_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
# executor = ThreadPoolExecutor(max_workers=Config.MAX_WORKERS)
# stop_event = asyncio.Event()
# Model cache (not read or written anywhere in the visible part of this file).
model_cache = {}
# Throughput statistics; 'dropped_frames' is incremented by process_frames()
# whenever a frame is discarded or its processing fails.
stats = {
'processed': 0,
'last_time': time.time(),
'avg_fps': 0.0,
'dropped_frames': 0
}
# Tear down the pipeline: signal stop, drain the queues, shut the executors down.
async def cleanup_resources(upload_executor, executor, frame_queue,
                            processed_queue, invade_queue, cv_frame_queue,
                            stop_event):
    """Release the resources owned by one pipeline run.

    Sets ``stop_event`` first so the worker coroutines stop producing, drains
    the four queues (bounded by a 5 second timeout) and finally shuts both
    thread pools down.

    The executors are shut down in a ``finally`` block, so a drain timeout can
    no longer leave the pools running; the timeout error still propagates to
    the caller as before.

    Args:
        upload_executor: Upload thread pool, or ``None``.
        executor: Inference thread pool, or ``None``.
        frame_queue, processed_queue, invade_queue, cv_frame_queue: Queues to
            drain.
        stop_event: Event that tells the worker coroutines to stop.

    Raises:
        asyncio.TimeoutError: if the queues cannot be drained within 5 seconds.
    """
    print("cleanup_resources 1")
    stop_event.set()  # set the cancel flag before touching anything else
    print("cleanup_resources 2")
    try:
        # Drain all queues; the timeout guards against a stuck producer/consumer.
        await asyncio.wait_for(asyncio.gather(
            _empty_queue(frame_queue),
            _empty_queue(processed_queue),
            _empty_queue(invade_queue),
            _empty_queue(cv_frame_queue)
        ), timeout=5.0)
        print("cleanup_resources 8")
    finally:
        # The executors are parameters, not module globals, so the former
        # "'executor' in globals()" guards skipped the shutdown entirely.
        if executor is not None:
            print("cleanup_resources 9")
            executor.shutdown(wait=True, cancel_futures=True)
        if upload_executor is not None:
            print("cleanup_resources 10")
            upload_executor.shutdown(wait=True, cancel_futures=True)
            print("cleanup_resources 11")
    print("cleanup_resources 14")
async def _empty_queue(queue):
    """Drain ``queue`` until it reports empty.

    Any exception raised while draining is logged as a warning and swallowed.
    """
    try:
        while True:
            if queue.empty():
                break
            await queue.get()
    except Exception as e:
        logger.warning(f"清空队列时出错: {e}")
class MultiDetectionResults:
    """Plain container of parallel lists describing one set of detections."""

    def __init__(self):
        # Geometry and class information (one entry per detection).
        self.boxes, self.clss = [], []  # bounding boxes / class IDs
        # Human-readable class names.
        self.cls_names, self.cls_en_names = [], []  # class names / English class names
        # Scores and tracker identifiers.
        self.confs, self.track_ids = [], []  # confidences / track IDs
class MultiYOLODetector:
    """Run several YOLO models on the same frame in parallel, on one device.

    Every model is loaded onto ``Config.DEFAULT_DEVICE``. ``predict`` resizes
    the frame to ``Config.MODEL_INPUT_SIZE``, runs each model in the instance's
    thread pool and maps the boxes back to the original frame size. This
    variant always reports ``track_id == -1``.
    """

    def __init__(self, model_configs: List[Dict]):
        """Load every configured model.

        Args:
            model_configs: One dict per model. ``path`` is required; the
                optional keys read here are ``cls_map``, ``allowed_classes``,
                ``tracking``, ``cls_index``, ``chinese_label``,
                ``list_func_id``, ``func_id``, ``class_names``,
                ``para_invade_enable`` and ``config_conf``.
        """
        self.models = []
        self.class_maps = []
        self.executor = ThreadPoolExecutor(max_workers=Config.MAX_WORKERS)
        self.allowed_classes = []
        self.model_cls = []
        self.chinese_label = []
        self.list_func_id = []
        self.conf = 0.4
        self.func_id = -1
        self.list_class_names = []
        self.list_para_invade_enable = []
        self.input_size = Config.MODEL_INPUT_SIZE  # one input size for all models
        self.device = Config.DEFAULT_DEVICE  # force the default device
        self._setup_multi_gpu()

        for config in model_configs:
            model_path = config['path']
            tracking = config.get('tracking', True)  # whether tracking is requested
            # Load the model and make sure it lives on the target device.
            model = self._load_model(model_path, tracking)

            self.models.append(model)
            self.class_maps.append(config.get('cls_map', {}))
            self.allowed_classes.append(config.get('allowed_classes', None))
            self.model_cls.append(config.get('cls_index', []))
            self.chinese_label.append(config.get('chinese_label', []))
            self.list_func_id.append(config.get('list_func_id', []))
            self.list_class_names.append(config.get('class_names', True))
            self.list_para_invade_enable.append(config.get('para_invade_enable', False))
            # NOTE(review): func_id and conf are single attributes, so the
            # values of the LAST config win for every model - confirm intent.
            self.func_id = config.get('func_id', -1)
            self.conf = config.get('config_conf', 0.4)

    def _setup_multi_gpu(self):
        """Pin CUDA work to the default device when CUDA is available."""
        if torch.cuda.is_available():
            gpu_count = torch.cuda.device_count()
            print(f"检测到 {gpu_count} 个GPU将使用默认设备: {self.device}")
            # Force all subsequent CUDA operations onto the default device.
            torch.cuda.set_device(self.device)

    def _load_model(self, model_path: str, tracking: bool) -> torch.nn.Module:
        """Load one YOLO model onto ``self.device``.

        Args:
            model_path: Path of the weights file handed to ``YOLO``.
            tracking: When true and the model exposes a ``tracker`` attribute,
                the tracker is set to ``"botsort.yaml"``. The flag used to be
                ignored here; it is now honoured like in
                ``MultiYOLODetector_TrackId``.

        Returns:
            The loaded model, wrapped in ``DataParallel`` on multi-GPU hosts.

        Raises:
            Exception: whatever loading raised, re-raised after printing.
        """
        try:
            model = YOLO(model_path)
            model = model.to(self.device)

            if tracking and hasattr(model, 'tracker'):
                model.tracker = "botsort.yaml"  # or "bytetrack.yaml"
                print(f"已为模型 {model_path} 启用跟踪")
            else:
                print(f"警告:模型 {model_path} 不支持跟踪功能")

            # With several GPUs, wrap in DataParallel restricted to the
            # default device so all parameters stay on one device.
            if torch.cuda.is_available() and torch.cuda.device_count() > 1:
                model = model.to(self.device)
                model = torch.nn.DataParallel(model, device_ids=[self.device])

            # Replace fuse() with a no-op that returns the model itself.
            # NOTE(review): applied unconditionally, as the original docstring
            # describes - confirm it was not meant for the multi-GPU branch only.
            model.fuse = lambda verbose=False: model
            return model
        except Exception as e:
            print(f"模型加载失败: {e}")
            raise

    @staticmethod
    def preprocess_frame(frame: np.ndarray, input_size: tuple, device: str) -> torch.Tensor:
        """Resize a BGR frame to ``input_size`` and return a normalised tensor.

        The frame is stretched (aspect ratio is not preserved), converted to
        RGB, scaled to ``[0, 1]`` and returned as a ``(1, 3, H, W)`` float
        tensor on ``device``.
        """
        resized = cv2.resize(frame, input_size)
        rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
        tensor = torch.from_numpy(rgb).permute(2, 0, 1).float() / 255.0
        # Add the batch dimension and move to the target device.
        return tensor.unsqueeze(0).to(device)

    @staticmethod
    def scale_bbox(bbox: List[int], original_size: tuple, input_size: tuple) -> List[int]:
        """Scale a box from model-input coordinates back to the original frame.

        Args:
            bbox: ``[x1, y1, x2, y2]`` in model-input pixels.
            original_size: ``(height, width)`` of the original frame.
            input_size: ``(width, height)`` of the model input.

        Returns:
            The box in original-frame pixels, truncated to ints.
        """
        oh, ow = original_size
        iw, ih = input_size
        scale_w = ow / iw
        scale_h = oh / ih
        x1, y1, x2, y2 = bbox
        return [
            int(x1 * scale_w),
            int(y1 * scale_h),
            int(x2 * scale_w),
            int(y2 * scale_h)
        ]

    async def predict(self, frame: np.ndarray) -> "Tuple[List[DetectionResult], List, List]":
        """Run every model on ``frame`` concurrently.

        Returns:
            A tuple ``(all_detections, detections_list, all_model_paras)``:
            the detections of all models flattened into one list, one
            ``DetectionResultList`` per model, and one parameter dict per
            model (which also carries the raw ``results``).
        """
        loop = asyncio.get_running_loop()
        original_size = (frame.shape[0], frame.shape[1])  # (height, width)

        def _predict(model_idx: int, frame: np.ndarray):
            model = self.models[model_idx]
            allowed = self.allowed_classes[model_idx]

            # Preprocess and make sure input and model share a device.
            input_tensor = self.preprocess_frame(frame, self.input_size, self.device)
            with torch.no_grad():
                model_device = next(model.parameters()).device
                if input_tensor.device != model_device:
                    input_tensor = input_tensor.to(model_device)
                results = model(input_tensor, verbose=False)

            model_para = {
                "cls_map": self.class_maps[model_idx],
                "model_chinese_labe": self.chinese_label[model_idx],
                "model_cls_index": self.model_cls[model_idx],
                "model_list_func_id": self.list_func_id[model_idx],
                "func_id": self.func_id,
                "model_class_names": self.list_class_names[model_idx],
                "para_invade_enable": self.list_para_invade_enable[model_idx],
                "results": results
            }

            # Unwrap DataParallel once to reach the class-name table.
            original_model = model.module if isinstance(model, torch.nn.DataParallel) else model
            names = original_model.names

            detections = []
            detection_result_list = DetectionResultList([], [], [], [], [])
            for result in results:
                boxes = result.boxes.cpu().numpy()  # move to the CPU for post-processing
                for box in boxes:
                    cls_id = int(box.cls[0])
                    cls_name = names[cls_id]
                    # Drop classes that are not allowed for this model.
                    if allowed and cls_name not in allowed:
                        continue
                    conf = float(box.conf[0])
                    # Confidence threshold (temporary; should become a parameter).
                    if conf < self.conf:
                        continue
                    # Boxes are relative to the model input; scale them back.
                    x1, y1, x2, y2 = box.xyxy[0].astype(int)
                    scaled_bbox = self.scale_bbox([x1, y1, x2, y2], original_size, self.input_size)

                    detection_result_list.boxes.append(scaled_bbox)
                    detection_result_list.clss.append(cls_id)
                    detection_result_list.clss_name.append(cls_name)
                    detection_result_list.confs.append(conf)
                    detection_result_list.track_ids.append(-1)  # this detector reports no track IDs
                    detections.append(DetectionResult(
                        bbox=scaled_bbox,
                        class_id=cls_id,
                        class_name=cls_name,
                        confidence=conf,
                        track_id=-1
                    ))
            return detections, detection_result_list, model_para

        # Fan out: one executor job per model, each on its own copy of the frame.
        futures = [
            loop.run_in_executor(self.executor, _predict, model_idx, frame.copy())
            for model_idx in range(len(self.models))
        ]
        results = await asyncio.gather(*futures)

        # Merge the per-model results.
        all_detections = []
        detections_list = []
        all_model_paras = []
        for detections, detection_result_list, model_para in results:
            all_detections.extend(detections)
            all_model_paras.append(model_para)
            detections_list.append(detection_result_list)
        return all_detections, detections_list, all_model_paras
class MultiYOLODetector_TrackId:
    """Multi-model parallel detector that also reports track IDs.

    Each configured model gets its own ``solutions.ObjectCounter``; ``predict``
    feeds the raw frame to every counter in the instance's thread pool and
    reads boxes, classes, confidences and track IDs back from the counter.
    """

    def __init__(self, model_configs: List[Dict]):
        """Load every configured model and create its ``ObjectCounter``.

        Args:
            model_configs: One dict per model. ``path`` is required; the
                optional keys read here are ``cls_map``, ``allowed_classes``,
                ``tracking``, ``cls_index``, ``chinese_label``,
                ``list_func_id``, ``func_id``, ``class_names``,
                ``para_invade_enable`` and ``config_conf``.
        """
        self.models = []
        self.class_maps = []
        self.executor = ThreadPoolExecutor(max_workers=Config.MAX_WORKERS)
        self.allowed_classes = []
        self.conf = 0.4
        self.model_cls = []
        self.chinese_label = []
        self.list_func_id = []
        self.func_id = -1
        self.list_class_names = []
        self.list_para_invade_enable = []
        self.input_size = Config.MODEL_INPUT_SIZE  # one input size for all models
        self.device = Config.DEFAULT_DEVICE  # force the default device
        self._setup_multi_gpu()
        # One ObjectCounter per model, index-aligned with self.models.
        self.object_counters = []

        for config in model_configs:
            model_path = config['path']
            allowed = config.get('allowed_classes', None)
            tracking = config.get('tracking', True)  # whether tracking is requested
            # NOTE(review): the fallback values True / -11 / True below look
            # like placeholders (the sibling detector uses [] / [] / -1); they
            # are kept unchanged because downstream code may rely on them.
            model_cls_index = config.get('cls_index', True)
            model_list_func_id = config.get('list_func_id', -11)
            func_id = config.get('func_id', True)

            # Load the model and make sure it lives on the target device.
            model = self._load_model(model_path, tracking)

            self.models.append(model)
            self.class_maps.append(config.get('cls_map', {}))
            self.allowed_classes.append(allowed)
            self.model_cls.append(model_cls_index)
            self.chinese_label.append(config.get('chinese_label', {}))
            self.list_func_id.append(model_list_func_id)
            self.list_class_names.append(config.get('class_names', True))
            self.list_para_invade_enable.append(config.get('para_invade_enable', False))
            # NOTE(review): single attributes - the LAST config wins for all models.
            self.func_id = func_id
            self.conf = config.get('config_conf', 0.4)

            # The former "supports tracking" / "does not" branches built the
            # counter with identical arguments, so one construction suffices.
            self.object_counters.append(
                solutions.ObjectCounter(
                    show=False,
                    region=None,  # a detection region may be configured here
                    model=model_path,
                    classes=model_cls_index if allowed else None,
                    verbose=False  # disable per-call performance logging
                )
            )

    def _setup_multi_gpu(self):
        """Pin CUDA work to the default device when CUDA is available."""
        if torch.cuda.is_available():
            gpu_count = torch.cuda.device_count()
            print(f"检测到 {gpu_count} 个GPU将使用默认设备: {self.device}")
            # Force all subsequent CUDA operations onto the default device.
            torch.cuda.set_device(self.device)

    def _load_model(self, model_path: str, tracking: bool) -> torch.nn.Module:
        """Load one YOLO model onto ``self.device``.

        Args:
            model_path: Path of the weights file handed to ``YOLO``.
            tracking: When true and the model exposes a ``tracker`` attribute,
                the tracker is set to ``"botsort.yaml"``.

        Returns:
            The loaded model, wrapped in ``DataParallel`` on multi-GPU hosts.

        Raises:
            Exception: whatever loading raised, re-raised after printing.
        """
        try:
            model = YOLO(model_path)
            model = model.to(self.device)

            if hasattr(model, 'tracker') and tracking:
                model.tracker = "botsort.yaml"  # or "bytetrack.yaml"
                print(f"已为模型 {model_path} 启用跟踪")
            else:
                print(f"警告:模型 {model_path} 不支持跟踪功能")

            # With several GPUs, wrap in DataParallel restricted to the
            # default device so all parameters stay on one device.
            if torch.cuda.is_available() and torch.cuda.device_count() > 1:
                model = model.to(self.device)
                model = torch.nn.DataParallel(model, device_ids=[self.device])

            # Replace fuse() with a no-op that returns the model itself.
            # NOTE(review): applied unconditionally, as the original docstring
            # describes - confirm it was not meant for the multi-GPU branch only.
            model.fuse = lambda verbose=False: model
            return model
        except Exception as e:
            print(f"模型加载失败: {e}")
            raise

    @staticmethod
    def preprocess_frame(frame: np.ndarray, input_size: tuple, device: str) -> torch.Tensor:
        """Resize ``frame`` keeping its aspect ratio and return a tensor.

        The resized frame is pasted into the top-left corner of a black
        ``input_size`` canvas, converted to RGB, scaled to ``[0, 1]`` and
        returned as a ``(1, 3, H, W)`` float tensor on ``device``.
        ``predict`` does not call this helper.
        """
        original_h, original_w = frame.shape[:2]
        # Uniform scale factor that fits the frame inside the input size.
        scale = min(input_size[0] / original_w, input_size[1] / original_h)
        new_w = int(original_w * scale)
        new_h = int(original_h * scale)
        resized = cv2.resize(frame, (new_w, new_h))
        # Black canvas at the model input size; the frame goes top-left.
        canvas = np.zeros((input_size[1], input_size[0], 3), dtype=np.uint8)
        canvas[:new_h, :new_w] = resized
        rgb = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
        tensor = torch.from_numpy(rgb).permute(2, 0, 1).float() / 255.0
        return tensor.unsqueeze(0).to(device)

    @staticmethod
    def scale_bbox(bbox, original_size, input_size):
        """Scale a box to the original frame size and clamp it to the frame.

        Args:
            bbox: ``[x1, y1, x2, y2]`` in model-input pixels.
            original_size: ``(width, height)`` of the original frame.
            input_size: ``(width, height)`` of the model input.

        Returns:
            ``[x1, y1, x2, y2]`` as ints, each clamped to
            ``[0, width - 1]`` / ``[0, height - 1]``.
        """
        ow, oh = original_size
        iw, ih = input_size
        scale_w = ow / iw
        scale_h = oh / ih
        x1, y1, x2, y2 = bbox
        x1 = int(max(0, min(x1 * scale_w, ow - 1)))
        y1 = int(max(0, min(y1 * scale_h, oh - 1)))
        x2 = int(max(0, min(x2 * scale_w, ow - 1)))
        y2 = int(max(0, min(y2 * scale_h, oh - 1)))
        return [x1, y1, x2, y2]

    async def predict(self, frame: np.ndarray) -> "Tuple[List[DetectionResult], List, List]":
        """Run every model's ``ObjectCounter`` on ``frame`` concurrently.

        Returns:
            A tuple ``(all_detections, detections_list, all_model_paras)``:
            the detections of all models flattened into one list, one
            ``DetectionResultList`` per model, and one parameter dict per
            model. A model that produced no boxes, or that failed,
            contributes an empty list, an empty ``DetectionResultList`` and
            an empty dict.
        """
        loop = asyncio.get_running_loop()

        def _predict(model_idx: int, frame: np.ndarray):
            try:
                object_counter = self.object_counters[model_idx]
                allowed = self.allowed_classes[model_idx]

                # The counter consumes the raw frame, so no input tensor is
                # built here (the previous per-frame resize + device copy was
                # never used).
                with torch.no_grad():
                    call_result = object_counter(frame)
                # Debug output of the call result's structure.
                if hasattr(call_result, 'boxes'):
                    print(
                        f"Boxes type: {type(call_result.boxes)}, length: {len(call_result.boxes) if hasattr(call_result.boxes, '__len__') else 'N/A'}")

                model_para = {
                    "cls_map": self.class_maps[model_idx],
                    "model_chinese_labe": self.chinese_label[model_idx],
                    "model_cls_index": self.model_cls[model_idx],
                    "model_list_func_id": self.list_func_id[model_idx],
                    "func_id": self.func_id,
                    "model_class_names": self.list_class_names[model_idx],
                    "para_invade_enable": self.list_para_invade_enable[model_idx],
                    "results": object_counter  # the counter itself, kept for debugging
                }

                # Detection data is read from the counter's attributes, not
                # from the value returned by the call.
                tracks = object_counter.track_ids
                confs = object_counter.confs
                clss = object_counter.clss
                names = object_counter.names
                result_boxes = object_counter.boxes

                if isinstance(result_boxes, list):
                    if len(result_boxes) == 0:
                        return [], DetectionResultList([], [], [], [], []), {}
                    boxes = result_boxes
                else:
                    boxes = result_boxes.tolist()

                detections = []
                detection_result_list = DetectionResultList([], [], [], [], [])
                for i, box in enumerate(boxes):
                    try:
                        # The box must be iterable and carry four coordinates.
                        if not hasattr(box, '__iter__'):
                            print(f"无效的边界框格式: {box}")
                            continue
                        box = list(box) if not isinstance(box, list) else box
                        if len(box) < 4:
                            print(f"不完整的边界框坐标: {box}")
                            continue
                        x1, y1, x2, y2 = box[:4]

                        cls_id = int(clss[i])
                        cls_name = names[cls_id]
                        conf = float(confs[i]) if i < len(confs) else 0.0
                        # Confidence threshold (to be extracted into a parameter).
                        if conf < self.conf:
                            continue
                        track_id = int(tracks[i]) if i < len(tracks) and tracks[i] is not None else None
                        # Drop classes that are not allowed for this model.
                        if allowed and cls_name not in allowed:
                            continue

                        int_bbox = [int(x1), int(y1), int(x2), int(y2)]
                        # The list keeps the raw coordinates; DetectionResult
                        # gets the integer box.
                        detection_result_list.boxes.append([x1, y1, x2, y2])
                        detection_result_list.clss.append(cls_id)
                        detection_result_list.clss_name.append(cls_name)
                        detection_result_list.confs.append(conf)
                        detection_result_list.track_ids.append(track_id)
                        detections.append(DetectionResult(
                            bbox=int_bbox,
                            class_id=cls_id,
                            class_name=cls_name,
                            confidence=conf,
                            track_id=track_id
                        ))
                    except Exception as e:
                        # One bad detection must not discard the others.
                        print(f"处理单个检测结果时出错: {str(e)}")
                        continue
                return detections, detection_result_list, model_para
            except Exception as e:
                print(f"模型 {model_idx} 预测过程中发生错误: {str(e)}")
                import traceback
                traceback.print_exc()
                return [], DetectionResultList([], [], [], [], []), {}

        # Fan out: one executor job per model, each on its own copy of the frame.
        futures = [
            loop.run_in_executor(self.executor, _predict, model_idx, frame.copy())
            for model_idx in range(len(self.models))
        ]
        per_model = await asyncio.gather(*futures)

        # Merge the per-model results.
        all_detections = []
        detections_list = []
        all_model_paras = []
        for detections, detection_result_list, model_para in per_model:
            all_detections.extend(detections)
            all_model_paras.append(model_para)
            detections_list.append(detection_result_list)
        return all_detections, detections_list, all_model_paras
async def read_video_frames(task_id, mqtt, mqtt_publish_topic,
                            local_video_path: str, srt_path: str, frame_queue: asyncio.Queue,
                            timestamp_frame_queue: "TimestampedQueue",
                            cancel_flag: asyncio.Event = None, ):
    """Replay a local video file into ``frame_queue`` at its native frame rate.

    Each frame is paired with the telemetry entry of the same index parsed
    from ``srt_path`` and pushed as ``(frame, Air_Attitude, time_ns)``; the
    frame is also appended to ``timestamp_frame_queue``. Status ``1`` is
    published over MQTT before reading starts. When the video ends, status
    ``2`` is published and the task is removed from ``task_manager``. After
    the loop exits normally, the SRT and video files are deleted.

    Args:
        task_id: Identifier reported in the status messages and removed from
            ``task_manager`` at end of video.
        mqtt: Client exposing an awaitable ``publish(topic, payload)``.
        mqtt_publish_topic: Topic for the status messages.
        local_video_path: Path of the video file to read.
        srt_path: Path of the subtitle/telemetry file to parse.
        frame_queue: Destination queue; a frame is dropped when it is full.
        timestamp_frame_queue: Buffer receiving ``{"timestamp", "frame"}``.
        cancel_flag: Stop signal; a fresh, unset event is used when ``None``.

    Raises:
        Exception: re-raised from ``mqtt.publish`` after logging.
    """
    if cancel_flag is None:
        # The default used to crash on ``cancel_flag.is_set()``.
        cancel_flag = asyncio.Event()

    srt_list = parse_srt_file(srt_path)
    # Open the capture exactly once (it used to be opened twice, leaking the
    # first handle).
    cap = cv2.VideoCapture(local_video_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        cap.release()
        return

    # Temporary mechanism: report whether processing has started (1) or ended (2).
    async def publist_status(status_to_publish):
        try:
            message = {"task_id": task_id, "video_status": status_to_publish}
            await mqtt.publish(mqtt_publish_topic, json.dumps(message, ensure_ascii=False))
        except Exception as e:
            logger.error(f"Failed to publish status: {e}")
            raise

    try:
        await publist_status(1)

        # Pace reading to the source frame rate to avoid queue backlog.
        video_fps = cap.get(cv2.CAP_PROP_FPS)
        if not video_fps or video_fps <= 0 or math.isnan(video_fps):
            # Unknown FPS in the container metadata: avoid dividing by zero.
            video_fps = Config.TARGET_FPS
        frame_interval = 1.0 / video_fps
        last_read_time = time.time()
        frame_count = 0

        while not cancel_flag.is_set() and cap.isOpened():
            ret, frame = cap.read()
            if not ret:  # end of video or read error
                await publist_status(2)
                # The recording task is stopped by the recording itself.
                await task_manager.remove_task(task_id)
                break

            # Keep at least frame_interval between consecutive frames.
            elapsed = time.time() - last_read_time
            if elapsed < frame_interval:
                await asyncio.sleep(frame_interval - elapsed)
            # Take the timestamp AFTER sleeping so the sleep itself is not
            # counted as elapsed time for the next frame.
            last_read_time = time.time()

            frame_count += 1
            if frame_queue.full():
                await asyncio.sleep(0.01)  # queue full: wait briefly, drop this frame
                continue

            # Frame N (1-based) pairs with srt_list[N - 1]; indexing with the
            # 1-based counter skipped the first telemetry entry.
            srt_index = frame_count - 1
            if srt_index < len(srt_list):
                dji_srt = srt_list[srt_index]
                art_tit = Air_Attitude(dji_srt.gb_pitch, dji_srt.gb_roll, dji_srt.gb_yaw,
                                       dji_srt.abs_alt, dji_srt.latitude, dji_srt.longitude)
                time_ns = time.time_ns()
                await frame_queue.put((frame, art_tit, time_ns))
                timestamp_frame_queue.append({
                    "timestamp": time_ns,
                    "frame": frame
                })
    finally:
        # Release the capture on cancellation and errors as well.
        cap.release()

    # Remove the consumed inputs (comment out while testing).
    if os.path.exists(srt_path):
        os.remove(srt_path)
    if os.path.exists(local_video_path):
        os.remove(local_video_path)
async def read_rtmp_frames(
        loop,
        read_rtmp_frames_executor: ThreadPoolExecutor,
        video_url: str,
        device: "Optional[MQTTDevice]" = None,
        topic_camera_osd: Optional[str] = None,
        method_camera_osd: Optional[str] = None,
        topic_osd_info: Optional[str] = None,
        method_osd_info: Optional[str] = None,
        cancel_flag: Optional[asyncio.Event] = None,
        frame_queue: asyncio.Queue = None,
        timestamp_frame_queue: "TimestampedQueue" = None
):
    """Read frames from an RTMP stream and feed them to ``frame_queue``.

    Opening the stream and decoding every frame run in
    ``read_rtmp_frames_executor`` so the event loop is not blocked. Each
    decoded BGR frame is paired with the latest OSD message taken from
    ``device`` and queued as ``(frame, Air_Attitude, time_ns)``; the frame is
    also appended to ``timestamp_frame_queue``. Frames without OSD data are
    not queued, and frames arriving while the queue is full are dropped. The
    connection is (re)opened up to 20 times with a 2 second delay after
    errors.

    Args:
        loop: Event loop used for ``run_in_executor``.
        read_rtmp_frames_executor: Thread pool for opening and decoding.
        video_url: RTMP URL handed to ``av.open``.
        device: MQTT device providing ``get_latest_message``; optional.
        topic_camera_osd, method_camera_osd: Not used in this function.
        topic_osd_info, method_osd_info: Topic/method of the OSD messages.
        cancel_flag: Stop signal; a fresh, unset event is used when ``None``.
        frame_queue: Destination queue for frames.
        timestamp_frame_queue: Buffer receiving ``{"timestamp", "frame"}``.

    Raises:
        RuntimeError: when the last attempt fails with an AV/IO error.
        asyncio.CancelledError: re-raised when the coroutine is cancelled.
    """
    max_retries = 20
    retry_delay = 2
    attempt = 0
    frame_count = 0  # total frames accepted since the function started
    pic_count = 0  # frames accepted in the current one-second window
    # Keep the overall start separate from the window start; the final
    # average FPS used to be computed against the last window reset.
    read_start_ns = time.time_ns()
    window_start_ns = read_start_ns
    if cancel_flag is None:
        cancel_flag = asyncio.Event()

    print(f"开始读取RTMP流: {video_url}")
    while not cancel_flag.is_set() and attempt < max_retries:
        attempt += 1
        container = None
        try:
            logger.info(f"尝试连接 RTMP 流 (尝试 {attempt}/{max_retries}): {video_url}")
            # 1. av.open and stream lookup are synchronous: run them in the pool.
            container = await loop.run_in_executor(read_rtmp_frames_executor, av.open, video_url)
            video_stream = await loop.run_in_executor(read_rtmp_frames_executor, next,
                                                      (s for s in container.streams if s.type == 'video'))
            logger.info(f"成功连接到 RTMP 流: {video_url} ({video_stream.width}x{video_stream.height})")

            # 2. Fetch one OSD message up front to check that MQTT delivers data.
            if device and topic_osd_info and method_osd_info:
                osd_msg = device.get_latest_message(topic=topic_osd_info, method=method_osd_info)
                if osd_msg:
                    logger.info(f"初始OSD消息获取成功: 高度={osd_msg.data.height}")
                else:
                    logger.warning("初始OSD消息为空可能MQTT尚未收到消息")

            # 3. Decode in a worker thread and hand frames to the event loop.
            async def async_frame_generator():
                """Yield BGR frames decoded by the synchronous iterator below."""

                def sync_frame_iter():
                    try:
                        for frame in container.decode(video=0):
                            # Checked per frame so the worker thread can exit.
                            if cancel_flag.is_set():
                                logger.info("后台线程检测到取消信号,停止帧迭代")
                                break
                            if len(frame.planes) == 1:
                                # Single-plane frames are treated as grayscale
                                # and expanded to 3-channel BGR (size unchanged).
                                gray = frame.to_ndarray(format='gray')
                                yield cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
                            else:
                                # Keep the original size; only convert the format.
                                yield frame.to_ndarray(format='bgr24')
                    except Exception as e:
                        logger.error(f"同步帧迭代出错: {e}")
                    finally:
                        if container:
                            container.close()
                            logger.info("RTMP容器已关闭")

                gen = sync_frame_iter()
                while not cancel_flag.is_set():
                    try:
                        # One executor job per frame keeps each blocking call short.
                        frame = await loop.run_in_executor(read_rtmp_frames_executor, next, gen, None)
                        if frame is None:  # iterator exhausted
                            break
                        yield frame
                    except StopIteration:
                        logger.info("RTMP流帧迭代结束")
                        break
                    except Exception as e:
                        logger.error(f"异步获取帧出错: {e}")
                        break

            # 4. Consume frames without blocking the event loop.
            async for frame in async_frame_generator():
                if cancel_flag.is_set():
                    logger.info("检测到取消信号,停止读取帧")
                    break
                try:
                    img = frame.copy()  # never mutate the decoded frame
                    osd_info = None
                    # 5. Pair the frame with the most recent cached OSD message.
                    if device and topic_osd_info and method_osd_info:
                        osd_msg = device.get_latest_message(topic=topic_osd_info, method=method_osd_info)
                        if osd_msg and hasattr(osd_msg, 'data'):
                            osd_info = Air_Attitude(
                                gimbal_pitch=osd_msg.data.gimbal_pitch,
                                gimbal_roll=osd_msg.data.gimbal_roll,
                                gimbal_yaw=osd_msg.data.gimbal_yaw,
                                height=osd_msg.data.height,
                                latitude=osd_msg.data.latitude,
                                longitude=osd_msg.data.longitude
                            )

                    # 6. Queue the frame unless the queue is full.
                    if not frame_queue.full():
                        pic_count += 1
                        frame_count += 1
                        time_ns = time.time_ns()
                        # Print the per-second frame count once per second.
                        if time_ns - window_start_ns > 1000000000:
                            print(f"readFrames {pic_count}")
                            pic_count = 0
                            window_start_ns = time_ns
                        if img is not None and osd_info is not None:
                            await frame_queue.put((img, osd_info, time_ns))
                            timestamp_frame_queue.append({
                                "timestamp": time_ns,
                                "frame": img
                            })
                            logger.debug(
                                f"已放入帧队列,累计帧数: {frame_count},队列剩余空间: {frame_queue.maxsize - frame_queue.qsize()}")
                    else:
                        logger.warning("帧队列已满等待1ms后重试")
                        await asyncio.sleep(0.001)
                except Exception as frame_error:
                    logger.error(f"处理单帧时出错: {frame_error}", exc_info=True)
                    continue
        except (av.AVError, IOError) as e:
            logger.error(f"RTMP 流错误 (尝试 {attempt}/{max_retries}): {e}")
            if attempt < max_retries:
                await asyncio.sleep(retry_delay)
            else:
                raise RuntimeError(f"无法连接 RTMP 流 (尝试 {max_retries} 次后失败): {video_url}")
        except asyncio.CancelledError:
            logger.info("read_rtmp_frames 收到取消信号")
            raise
        except Exception as e:
            logger.error(f"未知错误: {e}", exc_info=True)
            if attempt < max_retries:
                await asyncio.sleep(retry_delay)
        finally:
            # Belt and braces: make sure the container is closed.
            # NOTE(review): ``closed`` is read defensively because it is not
            # clear that every PyAV container exposes it - confirm.
            if container and not getattr(container, 'closed', False):
                await loop.run_in_executor(None, container.close)
                logger.info("RTMP容器在finally中关闭")

    # Final statistics, measured from the very first start.
    if frame_count > 0:
        total_time = (time.time_ns() - read_start_ns) / 1e9
        avg_fps = frame_count / total_time if total_time > 0 else 0
        print(f"RTMP流读取完成总帧数: {frame_count}, 总时间: {total_time:.2f}秒, 平均FPS: {avg_fps:.2f}")
    else:
        print("RTMP流读取失败未获取到任何帧")
    logger.info(f"RTMP 流已结束或被取消,累计处理帧数: {frame_count}")
# async def process_frames(detector: MultiYOLODetector):
# async def process_frames(detector: MultiYOLODetector_TrackId, cancel_flag: asyncio.Event,
# frame_queue: asyncio.Queue, processed_queue: asyncio.Queue):
async def process_frames(detector: "MultiYoloTrtDetectorTrackId", cancel_flag: asyncio.Event,
                         frame_queue: asyncio.Queue, processed_queue: asyncio.Queue):
    """Consume frames, run detection and forward the results.

    Until ``cancel_flag`` is set, takes ``(frame, osd_info, timestamp)`` from
    ``frame_queue`` (0.5 s poll timeout), awaits ``detector.predict(frame)``
    and puts a result dict on ``processed_queue``. When that queue is full the
    frame is dropped and ``stats['dropped_frames']`` is incremented; the same
    counter is incremented when processing a frame raises.

    Raises:
        asyncio.CancelledError: re-raised when the coroutine is cancelled.
    """
    window_start_ns = time.time_ns()
    frames_in_window = 0
    while not cancel_flag.is_set():
        try:
            # Poll with a timeout so the cancel flag is re-checked regularly.
            frame, osd_info, timestamp = await asyncio.wait_for(frame_queue.get(), timeout=0.5)
            try:
                predict_begin_ns = time.time_ns()
                detections, detections_list, model_para = await detector.predict(frame)
                predict_end_ns = time.time_ns()
                print(f"time_pr_starttime_pr_start {(predict_end_ns - predict_begin_ns) / 1000000}")

                if detections:
                    predict_state = True
                    print("检测到任何目标")
                else:
                    predict_state = False
                    logger.debug("未检测到任何目标")

                processed_data = dict(
                    frame=frame,
                    osd_info=osd_info,
                    detections=detections,
                    detections_list=detections_list,
                    timestamp=timestamp,
                    model_para=model_para,
                    # Lets the RTMP writer tell whether anything was detected.
                    predict_state=predict_state,
                )

                if processed_queue.full():
                    logger.warning("处理队列已满,丢弃帧")
                    stats['dropped_frames'] += 1
                else:
                    now_ns = time.time_ns()
                    frames_in_window += 1
                    # Print the per-second throughput once per second.
                    if now_ns - window_start_ns > 1000000000:
                        print(f"processframes {frames_in_window}")
                        frames_in_window = 0
                        window_start_ns = now_ns
                    await processed_queue.put(processed_data)
            except Exception as e:
                logger.error(f"处理帧时出错: {e}", exc_info=True)
                stats['dropped_frames'] += 1
                await asyncio.sleep(0.1)
        except asyncio.TimeoutError:
            continue
        except asyncio.CancelledError:
            print("process_frames 收到取消信号")
            raise
        except Exception as e:
            logger.error(f"获取帧时发生意外错误: {e}", exc_info=True)
            await asyncio.sleep(0.1)
    print("process_frames读取线程已停止")
from PIL import Image, ImageDraw, ImageFont
# Pre-load the font used for drawing Chinese text once at import time; fall
# back to PIL's built-in default font when the TTC file cannot be loaded.
chinese_font = None
try:
    chinese_font = ImageFont.truetype("config/SIMSUN.TTC", 60)
except Exception as font_error:
    # Was a bare ``except:``, which also swallowed KeyboardInterrupt/SystemExit.
    logger.warning(f"Failed to load config/SIMSUN.TTC, using PIL default font: {font_error}")
    chinese_font = ImageFont.load_default()
def put_chinese_text(img, text, position, font=chinese_font, color=(0, 255, 0)):
    """Draw ``text`` onto a BGR image with PIL and return a new BGR image.

    The image is converted to RGB for PIL, the text is drawn at ``position``
    with ``font`` (the pre-loaded module font by default) and ``color`` as the
    fill, and the result is converted back to BGR.
    """
    rgb_pixels = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    canvas = Image.fromarray(rgb_pixels)
    ImageDraw.Draw(canvas).text(position, text, font=font, fill=color)
    bgr_pixels = cv2.cvtColor(np.array(canvas), cv2.COLOR_RGB2BGR)
    return bgr_pixels
# # 全局变量存储推流容器
# stream_containers: Dict[str, Any] = {}
from collections import defaultdict
import time
from typing import Dict, Any
class TrackIDEventFilter:
    """Report each track ID once, until it has been inactive for too long."""

    def __init__(self, max_inactive_time: float = 5.0):
        """Create the filter.

        :param max_inactive_time: seconds after which a track ID that has not
            been seen is forgotten (and will be reported again when it returns)
        """
        self.track_status = defaultdict(int)  # per track ID: number of times seen
        self.last_active_time = defaultdict(float)  # per track ID: time last seen
        self.max_inactive_time = max_inactive_time  # expiry threshold in seconds

    def should_report(self, track_id: int) -> bool:
        """Decide whether an event for ``track_id`` should be reported.

        :param track_id: tracker ID of the detected object
        :return: True the first time the ID is seen (or the first time after
            it expired), False on every later sighting
        """
        now = time.time()

        # 1. Forget every track ID that has been inactive for too long.
        expired = []
        for tid, seen_at in self.last_active_time.items():
            if (now - seen_at) > self.max_inactive_time:
                expired.append(tid)
        for tid in expired:
            print(f"inactive_tracksinactive_tracksinactive_tracks{tid}")
            self.track_status.pop(tid, None)
            self.last_active_time.pop(tid, None)

        # 2. Refresh the activity timestamp of the current ID.
        self.last_active_time[track_id] = now
        seen_count = self.track_status[track_id]
        print(f"show_countshow_countshow_count {track_id} {seen_count}")

        # 3. First sighting is reported; later sightings only bump the counter.
        if seen_count == 0:
            self.track_status[track_id] = 1
            print(f"track_status[track_id] = 1 {self.track_status[track_id]} : {track_id}")
            return True
        if seen_count > 0:
            self.track_status[track_id] = seen_count + 1
            print(f"track_status[track_id] =show_count + 1 {self.track_status[track_id]} : {track_id}")
        return False
async def write_results_to_rtmp(task_id: str, output_url: str = None, input_fps: float = None,
list_points: list[list[any]] = None, camera_para: Camera_Para = None,
invade_state: bool = False, cancel_flag: asyncio.Event = None,
processed_queue: asyncio.Queue = None, invade_queue: asyncio.Queue = None,
cv_frame_queue: asyncio.Queue = None, stream_containers: Dict[str, Any] = None):
"""Consume processed frames, draw overlays, hand detections on, and push frames to an FLV output.

Runs until ``cancel_flag`` is set. Every item read from ``processed_queue``
must be a dict with the keys 'frame', 'osd_info' and 'predict_state'; when
'predict_state' is truthy the keys 'detections', 'detections_list' and
'model_para' are read as well. For each item the coroutine:

1. optionally projects the configured regions onto the image with
   ``red_line_reproject`` and draws them as closed red polylines;
2. when inference results exist, collects per-detection data, draws boxes and
   puts a ``new_data`` dict on ``invade_queue`` or ``cv_frame_queue``
   (skipped when the chosen queue is full);
3. encodes the annotated frame with the stream stored in
   ``stream_containers[output_url]`` (opened on first use) and muxes the packets.

NOTE(review): the leading indentation of this function is not visible in this
view, so the comments below describe individual statements and do not assert
how branches are nested.

:param task_id: task identifier; only referenced by commented-out debug code here
:param output_url: URL opened with ``av.open(mode='w', format='flv')``; no stream is opened when falsy
:param input_fps: frame rate used for ``fps`` (falls back to ``Config.TARGET_FPS``);
    the output stream itself is created with ``rate=Config.TARGET_FPS``
:param list_points: iterable of regions, each passed to ``red_line_reproject``
:param camera_para: camera parameters passed to ``red_line_reproject``
:param invade_state: enables the red-line drawing and the intrusion filter
:param cancel_flag: event that stops the loop
:param processed_queue: input queue
:param invade_queue: output queue used when ``invade_state`` and ``para_invade_enable`` are set
:param cv_frame_queue: output queue for the non-intrusion path
:param stream_containers: dict keyed by output URL holding container/stream state;
    every container in it is closed and the dict is cleared in the ``finally`` block
"""
# global stream_containers, count_pic
start_time = time.time()
time_start = time.time_ns()
pic_count = 0
# Encoder options applied to the output stream
options = {
'preset': 'veryfast',
'tune': 'zerolatency',
'crf': '23',
'g': '50', # keyframe interval
'threads': '2', # limit encoder threads
}
codec_name = 'libx264'
max_retries = 3
retry_delay = 2.0
# Optional local video output; output_video_path is never reassigned in this
# function, so the cv2.VideoWriter below is never created.
output_video_path = None
video_writer = None
frame_width, frame_height = None, None
fps = input_fps or Config.TARGET_FPS
last_frame_time = time.time() - 1
frame_interval = 1.0 / fps
try:
while not cancel_flag.is_set():
frame_start = time.time()
try:
# First-level frame-rate control (disabled)
# current_time = time.time()
#
# time_diff = frame_interval - (current_time - last_frame_time)
# if time_diff > 0:
# await asyncio.sleep(time_diff)
# last_frame_time = current_time
processed_data = await asyncio.wait_for(
processed_queue.get(),
timeout=1
)
# Skip anything that is not a dict
if not isinstance(processed_data, dict):
print(f"❌ 错误processed_data 不是字典,而是 {type(processed_data)}")
continue
frame = processed_data['frame']
# Draw on a copy so the original frame stays untouched
frame_copy = frame.copy()
predict_state = processed_data['predict_state']
osd_info = processed_data['osd_info']
img_height, img_width = frame.shape[:2]
results = []
results_list = []
# Draw the red line only when intrusion checking is enabled and aircraft attitude (OSD) data is available
if invade_state and osd_info:
gimbal_yaw = osd_info.gimbal_yaw
gimbal_pitch = osd_info.gimbal_pitch
gimbal_roll = osd_info.gimbal_roll
height = osd_info.height
# print(f"heightheightheight {height}")
cam_longitude = osd_info.longitude
cam_latitude = osd_info.latitude
# list_points is two-dimensional but currently holds a single region; may change with business needs
for points in list_points:
# Project the region's points to image pixel coordinates in one call
point_list = []
results = red_line_reproject(gimbal_yaw, gimbal_pitch, gimbal_roll, height, cam_longitude,
cam_latitude,
img_width,
img_height, points, camera_para)
if results:
results_list.append(results) # supports two zones: high-voltage intrusion and operating-line intrusion
for point in results:
point_list.append([point["u"], point["v"]])
cv2.polylines(frame_copy, [np.array(point_list, dtype=np.int64)], isClosed=True,
color=(0, 0, 255),
thickness=2)
# print(f"predict_statepredict_state {predict_state}")
# The model produced inference results
if predict_state:
# Test code for inspecting inference output: lazily initialise the video writer
if video_writer is None and output_video_path:
frame_width = frame.shape[1]
frame_height = frame.shape[0]
# Define the video codec and output file
fourcc = cv2.VideoWriter_fourcc(*'mp4v') # or another codec such as 'avc1'
# fps = Config.TARGET_FPS # use the configured target frame rate
video_writer = cv2.VideoWriter(
output_video_path,
fourcc,
fps,
(frame_width, frame_height)
)
print(f"视频写入器已初始化,分辨率: {frame_width}x{frame_height}, FPS: {fps}")
detections = processed_data['detections']
detections_list = processed_data['detections_list']
model_para = processed_data['model_para']
# Only the first entry of model_para is consulted below
class_names = model_para[0]["model_class_names"]
chinese_label = model_para[0]["model_chinese_labe"]
cls_map = model_para[0]["cls_map"]
para_invade_enable = model_para[0]["para_invade_enable"]
model_list_func_id = model_para[0]["model_list_func_id"]
model_func_id = model_para[0]["func_id"]
invade_point = []
message_point = []
target_point = [] # image coordinates of qualifying targets, for later longitude/latitude conversion
cls_count = 0
# Per-class detection counters
class_stats = defaultdict(int)
reversed_dict = {value: 0 for value in cls_map.values()}
bg_color = (173, 216, 230) # light blue background for the text box
text_color = (0, 255, 0) # green
for det in detections:
x1, y1, x2, y2 = map(int, det.bbox) # ensure integer coordinates
cls_id = det.class_id # assumes the detection object has a class_id attribute
class_name = det.class_name
confidence = det.confidence
track_id = det.track_id
new_track_id = track_id * 100 + cls_id # class id is below 100 or negative
# Update the per-class count
class_stats[cls_id] += 1
# With intrusion checking enabled, only boxes inside the red line are shown
# (point_x, point_y) is the centre of the bounding box
point_x = (x1 + x2) / 2
point_y = (y1 + y2) / 2
print(f"class_name--{class_name}")
print(f"model_class_names: {model_para[0]['model_class_names']}")
# Detections whose class is not configured for the first model are ignored
if class_name not in model_para[0]["model_class_names"]:
continue
# Label stored at the same index as class_name in model_class_names
en_name = model_para[0]["model_chinese_labe"][
model_para[0]["model_class_names"].index(class_name)]
if invade_state:
# Intrusion test against multiple regions
is_invade = is_point_in_polygonlist(point_x, point_y, results_list)
# is_invade = is_point_in_polygon(point_x, point_y, results)
# print(f"is_invadeis_invadeis_invade {is_invade} {len(results)}")
if is_invade:
cls_count += 1
invade_point.append({
"u": point_x,
"v": point_y,
"class_name": class_name
})
target_point.append({
"u": point_x,
"v": point_y,
"cls_id": cls_id,
"track_id": track_id,
"new_track_id": new_track_id
}) # for intrusion, only intruding targets are stored
# model_list_func_id = model_para[0]["model_list_func_id"]
# model_func_id = model_para[0]["func_id"]
message_point.append({
"confidence": float(confidence),
"cls_id": cls_id,
"type_name": en_name,
"track_id": track_id,
"box": [x1, y1, x2, y2]
})
label = f"{en_name}:{confidence:.2f}:{track_id}"
# Compute the text position
text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, fontScale=8, thickness=4)[
0]
text_width, text_height = text_size[0], text_size[1]
text_x = x1
text_y = y1 - 5
# If the text would leave the top of the image, place it below the box instead
if text_y < 0:
text_y = y2 + text_height + 5
temp_img = frame_copy.copy()
# The text passed here is the empty string, so no label is actually rendered
frame_copy = put_chinese_text(
temp_img,
# label, # confidence/class label, used for testing
"", # Chinese text disabled
(text_x, text_y - text_height),
)
else:
cls_count += 1
# Draw the bounding box
cv2.rectangle(frame_copy, (x1, y1), (x2, y2), (0, 255, 255), 2)
message_point.append({
"confidence": float(confidence),
"cls_id": cls_id,
"type_name": en_name,
"track_id": track_id,
"box": [x1, y1, x2, y2]
})
target_point.append({
"u": point_x,
"v": point_y,
"cls_id": cls_id,
"track_id": track_id,
"new_track_id": new_track_id
}) # for intrusion, only intruding targets are stored
# Prepare the label text
# label = f"{chinese_label.get(cls_id, class_name)}: {confidence:.2f}:{track_id}"
label = f"{confidence:.2f}:{track_id}"
# Compute the text position
text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, fontScale=8, thickness=8)[0]
text_width, text_height = text_size[0], text_size[1]
text_x = x1
text_y = y1 - 5
# If the text would leave the top of the image, place it below the box instead
if text_y < 0:
text_y = y2 + text_height + 5
# Draw the text background (only `padding` is assigned; it is not used afterwards)
padding = 2
temp_img = frame_copy.copy()
# The text passed here is the empty string, so no label is actually rendered
frame_copy = put_chinese_text(
temp_img,
# label, # confidence/class label, used for testing
"", # Chinese text disabled
(text_x, text_y - text_height),
)
# Intrusion mode: draw the boxes collected in message_point
if invade_state:
for point in message_point:
cv2.rectangle(frame_copy, (point["box"][0], point["box"][1]),
(point["box"][2], point["box"][3]),
(0, 255, 255), 2)
# Draw the red line
# Show the statistics in the top-left corner (the drawing code further down is commented out)
stats_text = []
for cls_id, count in class_stats.items():
# NOTE(review): chinese_label is indexed as a list above but used with .get() here — confirm its type
cls_name = chinese_label.get(cls_id,
class_names[cls_id] if class_names and cls_id < len(
class_names) else str(
cls_id))
reversed_dict[cls_name] = count
for key, value in reversed_dict.items():
stats_text.append(f"{key}: {value}")
if stats_text:
# Total height of the statistics text
text_height = cv2.getTextSize("Test", cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.5, thickness=1)[0][
1]
total_height = len(stats_text) * (text_height + 18) # original note: "5 is the line spacing" (the code adds 18)
# Start position of the statistics text (top-left corner)
start_x = 50 # original note: leave 200 px of width
start_y = 20 # start from the top
# # Draw the statistics background
# # bg_color = (0, 0, 0)
# frame_copy = cv2.rectangle(
# frame_copy,
# (start_x - 10, start_y - 10),
# (400, start_y + total_height + 20),
# bg_color,
# -1
# )
# # Draw the statistics text line by line
# for i, text in enumerate(stats_text):
# y_pos = start_y + i * (text_height + 30)
# temp_img = frame_copy.copy()
# frame_copy = put_chinese_text(
# temp_img,
# text,
# (start_x, y_pos),
# color=text_color
# )
# Payload handed to the downstream consumers (cal_des_invade / send_frame_to_s3_mq read these keys)
new_data = {
'frame_copy': frame_copy,
'frame': frame,
"osd_info": osd_info,
'detections': detections,
"message_point": message_point,
"cls_count": cls_count,
"target_point": target_point,
"model_list_func_id": model_list_func_id,
"model_func_id": model_func_id,
# func_id of the first model: results are merged now, so individual models can no longer be told apart
'timestamp': processed_data.get('timestamp'),
"detections_list": detections_list,
"model_para": model_para
# 'model_para': processed_data.get('model_para', {}) # make sure model_para exists
}
# Temporary code: the RTMP and intrusion logic needs rework
if invade_state:
# Intrusion handling runs only when para_invade_enable is set in the model parameters
if para_invade_enable:
# The payload is dropped silently when the queue is full
if not invade_queue.full():
await invade_queue.put(new_data)
else:
if not cv_frame_queue.full():
await cv_frame_queue.put(new_data)
# Once per second, print how many frames were handled and reset the counter
time_end = time.time_ns()
pic_count = pic_count + 1
if time_end - time_start > 1000000000:
print(f"writeFrames {pic_count}")
pic_count = 0
time_start = time_end
#
# count_p = count_p + 1
# cv2.imwrite(f"save_pic/rtmp/test-{count_p}.jpg", frame_copy)
# video_writer.write(frame_copy)
# else:
# frame_copy = frame
# Convert BGR to RGB for the encoder
rgb_frame = cv2.cvtColor(frame_copy, cv2.COLOR_BGR2RGB)
# # # # # Show the frame for debugging (optional)
# cv2.imshow(f"write_results_to_rtmp-{task_id}", frame_copy)
# if cv2.waitKey(1) & 0xFF == ord('q'):
# stop_event.set()
# break
# Lazily open the output container for this URL
if output_url and output_url not in stream_containers:
try:
# container = av.open(output_url, mode='w', format='flv')
container = av.open(
output_url,
mode='w',
format='flv',
options={
# RTMP connect timeout (5 s) and data timeout (10 s)
'rtmp_connect_timeout': '5000000', # microseconds (5 s)
'rtmp_timeout': '10000000', # microseconds (10 s)
'stimeout': '5000000' # low-level socket timeout
}
)
stream = container.add_stream(codec_name, rate=Config.TARGET_FPS)
# stream.time_base = f"1/{Config.TARGET_FPS}"
stream.width = frame.shape[1]
stream.height = frame.shape[0]
stream.pix_fmt = 'yuv420p'
stream.options = options
stream_containers[output_url] = {
'container': container,
'stream': stream,
'last_frame_time': time.time(),
'frame_count': 0,
'retry_count': 0
}
print(f"✅ 推流初始化成功: {output_url}")
except Exception as e:
print(f"❌ 推流初始化失败: {e}")
# NOTE(review): `container` may still hold a container from an earlier iteration here
if 'container' in locals():
try:
container.close()
except:
pass
await asyncio.sleep(1.0)
continue
# Encode and mux the frame
if output_url and output_url in stream_containers:
try:
container_info = stream_containers[output_url]
stream = container_info['stream']
container = container_info['container']
if rgb_frame.dtype == np.uint8:
av_frame = av.VideoFrame.from_ndarray(rgb_frame, format='rgb24')
packets = stream.encode(av_frame) # called synchronously (not awaited)
# print(f"📦 encode 生成 {len(packets)} 个 packet")
if packets:
for packet in packets:
try:
container.mux(packet)
container_info['last_frame_time'] = time.time()
container_info['frame_count'] += 1
except Exception as e:
logger.warning(f"推流数据包错误: {e}")
# retry_count is only ever incremented; more than max_retries failures re-raise
container_info['retry_count'] += 1
if container_info['retry_count'] > max_retries:
raise
else:
# The encoder produced no packets; last_frame_time is not updated
pass
# Print status every 100 frames
# if container_info['frame_count'] % 100 == 0:
# print(f" 已推送 {container_info['frame_count']} 帧到 {output_url}")
if 'frame' in locals():
frame_time = time.time() - frame_start
print(f"推流帧耗时: {frame_time:.4f}s")
else:
print(f"⚠️ 无效帧格式: {rgb_frame.dtype}")
except Exception as e:
logger.error(f"❌ 推流错误: {e}")
# Drop the cached container so a later iteration re-initialises the stream
if output_url in stream_containers:
try:
stream_containers[output_url]['container'].close()
except:
pass
del stream_containers[output_url]
await asyncio.sleep(retry_delay)
continue
except asyncio.TimeoutError:
# No item arrived within the 1 s timeout
# if stop_event.is_set():
# break
continue
except asyncio.CancelledError:
print("write_results_to_rtmp 收到取消信号")
raise
except Exception as e:
logger.error(f"推流处理异常: {e}", exc_info=True)
await asyncio.sleep(0.1)
finally:
# log_perf('write_results_to_rtmp', start_time)
# Close every cached output container
for url, info in list(stream_containers.items()):
try:
if 'container' in info:
info['container'].close()
except Exception as e:
logger.warning(f"关闭推流容器时出错: {e}")
stream_containers.clear()
# Release the video writer
if video_writer is not None:
video_writer.release()
print("write_results_to_rtmp读取线程已停止")
# Ray-casting helpers that decide whether a target lies inside the red line.
def is_point_in_polygonlist(point_x, point_y, polygon_list):
    """Return True if the point lies inside (or on the border of) ANY polygon.

    The previous implementation reset its ``inside`` flag for every polygon
    and returned a single flag after the loop, so a point inside an earlier
    polygon but outside a later one was reported as outside. Each polygon is
    now tested on its own and the first hit wins.

    :param point_x: x (``u``) coordinate of the point
    :param point_y: y (``v``) coordinate of the point
    :param polygon_list: iterable of polygons, each a sequence of dicts with
        the keys ``"u"`` and ``"v"``; ``None`` or empty yields ``False``
    :return: ``True`` when at least one polygon contains the point
    """
    if not polygon_list:
        return False
    return any(is_point_in_polygon(point_x, point_y, polygon) for polygon in polygon_list)


def is_point_in_polygon(point_x, point_y, polygon):
    """Ray-casting point-in-polygon test; vertices and edges count as inside.

    :param point_x: x (``u``) coordinate of the point
    :param point_y: y (``v``) coordinate of the point
    :param polygon: sequence of vertices, each a dict with the keys ``"u"``
        and ``"v"``; consecutive vertices (wrapping around) form the edges
    :return: ``True`` when the point is inside the polygon or on its border
    """
    x, y = point_x, point_y
    n = len(polygon)
    inside = False
    for i, p1 in enumerate(polygon):
        p2 = polygon[(i + 1) % n]
        x1, y1 = p1["u"], p1["v"]
        x2, y2 = p2["u"], p2["v"]
        # The point coincides with one of the edge's end points.
        if (x == x1 and y == y1) or (x == x2 and y == y2):
            return True
        # The point is collinear with the edge and within its bounding box,
        # i.e. it lies on the edge itself.
        if (x - x1) * (y2 - y1) == (y - y1) * (x2 - x1):
            if min(x1, x2) <= x <= max(x1, x2) and min(y1, y2) <= y <= max(y1, y2):
                return True
        # Core of the ray casting: a horizontal ray towards +x toggles the
        # flag whenever it crosses an edge that spans the point's y value.
        # The span check also guarantees y2 != y1, so the division is safe.
        if (y1 > y) != (y2 > y):
            x_intersect = (x2 - x1) * (y - y1) / (y2 - y1) + x1
            if x <= x_intersect:
                inside = not inside
    return inside
# Distance between two points given by longitude/latitude.
def haversine(lon1, lat1, lon2, lat2, radius=6371000):
    """Return the great-circle distance between two points in metres.

    Uses the haversine formula on a sphere.

    :param lon1: longitude of the first point in degrees
    :param lat1: latitude of the first point in degrees
    :param lon2: longitude of the second point in degrees
    :param lat2: latitude of the second point in degrees
    :param radius: sphere radius; defaults to the mean Earth radius in metres,
        so the result is in metres
    :return: distance along the sphere surface, in the unit of ``radius``
    """
    lon1, lat1, lon2, lat2 = map(math.radians, (lon1, lat1, lon2, lat2))
    dlon = lon2 - lon1
    dlat = lat2 - lat1
    a = math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2
    # Rounding can push `a` marginally outside [0, 1] for (near-)antipodal
    # points, which would make math.asin raise ValueError; clamp it.
    a = min(1.0, max(0.0, a))
    c = 2 * math.asin(math.sqrt(a))
    return c * radius
async def cal_des_invade(loop,invade_executor, task_id: str, mqtt, mqtt_publish_topic,
list_points: list[list[any]], camera_para: Camera_Para, model_count: int,
cancel_flag: asyncio.Event = None, invade_queue: asyncio.Queue = None,
event_queue: asyncio.Queue = None,
device_height: float = float(200), repeat_dis: float = -1, repeat_time: float = -1):
"""Consume intrusion payloads, geolocate the targets and report new events.

Runs until ``cancel_flag`` is set. Each payload read from ``invade_queue`` is
the dict built by ``write_results_to_rtmp`` (keys 'frame_copy', 'frame',
'target_point', 'message_point', 'cls_count', 'model_list_func_id',
'model_func_id', 'osd_info', 'detections', 'detections_list', and optionally
'model_para' / 'timestamp'). For a payload that passes the track-id filter
and the optional time/distance de-duplication, both frames are JPEG-encoded
and uploaded with ``upload_frame_buff_from_buffer``, an event timestamp is
put on ``event_queue`` and a JSON message is published through ``mqtt``.

NOTE(review): the leading indentation of this function is not visible in this
view, so the comments below describe individual statements and do not assert
how branches are nested.

:param loop: event loop whose ``run_in_executor`` is used
:param invade_executor: executor that runs JPEG encoding and the uploads
:param task_id: task identifier placed in the published message
:param mqtt: object with an awaitable ``publish(topic, payload)``
:param mqtt_publish_topic: topic the message is published to
:param list_points: not referenced in this function body
:param camera_para: not referenced in this function body
:param model_count: not referenced in this function body
:param cancel_flag: event that stops the loop
:param invade_queue: input queue
:param event_queue: receives ``{"timestamp": ...}`` for each reported event
:param device_height: height value passed for every target to ``cal_canv_location_by_osd``
:param repeat_dis: when > 0, an event whose first location is closer than this
    (``haversine`` distance) to the cached first location is treated as a duplicate
:param repeat_time: when > 0, payloads arriving less than this many seconds
    after the last accepted one are skipped
"""
# loop = asyncio.get_running_loop()
# upload_executor = ThreadPoolExecutor(max_workers=Config.UPLOAD_WORKERS)
pic_count_hongxian = 0
track_filter = TrackIDEventFilter(max_inactive_time=8.0)
# Reported track ids and their report time (only referenced by commented-out code below)
reported_track_ids = defaultdict(float)
# Report interval in seconds (likewise only referenced by commented-out code)
report_interval = 8
target_location_back = [] # local cache of the last accepted locations, used for distance-based de-duplication
current_time_second = int(time.time())
while not cancel_flag.is_set():
# Guard against backlog: when the queue is more than half of PROCESSED_QUEUE_SIZE, discard everything
if invade_queue.qsize() > Config.PROCESSED_QUEUE_SIZE // 2:
print(f"警告invade_queue 积压(当前长度={invade_queue.qsize()}),清空队列")
while not invade_queue.empty():
await invade_queue.get()
print("队列已清空,等待新数据...")
await asyncio.sleep(0.1)
continue
# Fetch the next payload
try:
cv_frame = await asyncio.wait_for(invade_queue.get(), timeout=0.8)
# Check the payload type
if not isinstance(cv_frame, dict):
print(f"⚠️ 警告cv_frame 不是字典,而是 {type(cv_frame)}")
continue
except asyncio.TimeoutError:
print(f"cal_des_invade TimeoutError")
continue
except asyncio.CancelledError:
print("cal_des_invade 收到取消信号")
raise
if cv_frame is None:
continue
if repeat_time > 0: # time-based de-duplication enabled?
read_cv_time_second = int(time.time())
if read_cv_time_second - current_time_second < repeat_time:
continue
else:
current_time_second = read_cv_time_second
# print("cal_des_invade inside")
frame_copy = cv_frame['frame_copy']
frame = cv_frame['frame']
target_point = cv_frame['target_point']
message_point = cv_frame['message_point']
cls_count = cv_frame['cls_count']
model_list_func_id = cv_frame['model_list_func_id']
model_func_id = cv_frame['model_func_id']
air_alti = cv_frame['osd_info']
detections = cv_frame['detections']
detections_list = cv_frame['detections_list']
model_para_list = cv_frame.get('model_para', {}) # defaults to an empty dict
timestamp = cv_frame.get('timestamp', time.time_ns()) # defaults to the current time in ns
if not isinstance(frame, np.ndarray) or frame.size == 0:
print("⚠️ 警告frame不是有效的numpy数组")
continue
# Image width and height
img_height, img_width = frame.shape[:2]
# Without attitude (OSD) data no location can be computed
if not air_alti:
continue
gimbal_yaw = air_alti.gimbal_yaw
gimbal_pitch = air_alti.gimbal_pitch
gimbal_roll = air_alti.gimbal_roll
height = air_alti.height
cam_longitude = air_alti.longitude
cam_latitude = air_alti.latitude
try:
current_time = time.time()
h = device_height
us = []
vs = []
heights = []
should_report = True
for item in target_point:
# # Skip invalid track ids
# Check whether this track id should be reported
u = item["u"]
v = item["v"]
cls_id = item["cls_id"]
track_id = item["track_id"]
new_track_id = item["new_track_id"]
# # If this track_id was already reported, check whether the report interval has elapsed
# # if new_track_id in reported_track_ids:
# if track_id in reported_track_ids:
# last_report_time = reported_track_ids[track_id]
# if current_time - last_report_time < report_interval:
# print(f"基于track_id触发去重事件{track_id}")
# should_report = False
print(f"方法 cal_des_invade")
# A single filtered track id clears should_report and stops the scan
if not track_filter.should_report(track_id):
should_report = False
break
# if track_id < 0: # for MultiYOLODetector, which has no tracking and defaults track_id to -1
# should_report = True
# The track filter allows reporting: collect the pixel coordinates
if should_report:
us.append(u)
vs.append(v)
heights.append(h)
if should_report:
print("进行侵限计算")
location_results = cal_canv_location_by_osd(us, vs, gimbal_pitch, gimbal_yaw, gimbal_roll,
height, cam_longitude, cam_latitude, img_width, img_height,
heights)
if not location_results:
continue
# for point in location_results:
# target_location.append({point[1], point[2]})
# point_list = [] # collect the red-line points
repeat_state = False
show_des = 0
str_loca = ""
des_location_result=[]
if repeat_dis > 0: # de-duplication is enabled only when the ai_model_list repeat_dis field is > 0
if len(target_location_back) > 0: # not rigorous: only the first location is compared
# Indices 0/1/2 of a location entry are read as longitude/latitude/height
des1_back = target_location_back[0]
des1_back_longitude = des1_back[0]
des1_back_latitude = des1_back[1]
des1_back_height = des1_back[2]
des1 = location_results[0]
des1_longitude = des1[0]
des1_latitude = des1[1]
des_location_result.append({"longitude": des1_longitude,
"latitude": des1_latitude})
des1_height = des1[2]
str_loca = f"{des1_back_longitude}:{des1_back_latitude}---{des1_longitude}{des1_latitude}"
des = haversine(des1_back_longitude, des1_back_latitude, des1_longitude, des1_latitude)
show_des = des
if des < repeat_dis:
print(f"触发基于坐标判断重复,坐标距离:{des}")
repeat_state = True
else:
target_location_back = location_results # not a duplicate: refresh the local location cache
else:
target_location_back = location_results # no cached location to compare against: cache the current one
# Red-line test code
# for point in results:
# point_list.append([point["u"], point["v"]])
# for point in message_point:
# cv2.rectangle(frame_copy, (point["box"][0], point["box"][1]),
# (point["box"][2], point["box"][3]), (0, 255, 255), 2)
# cv2.polylines(frame_copy, [np.array(point_list, dtype=np.int32)],
# isClosed=True, color=(0, 0, 255), thickness=2)
pic_count_hongxian = pic_count_hongxian + 1
# cv2.imwrite(f"save_pic\invade-hongxian\hongxianongly-{pic_count_hongxian}.jpg", frame_copy)
# if len(invade_point) > 0:
# print("hongxianhongxianhongxianhongxian")
# pic_count = pic_count + 1
# # cv2.imwrite(f"save_pic\invade\hongxian-{pic_count}.jpg", frame)
# drawn_frame = frame_copy # key fix: deep-copy the drawn frame
# Image encoding
if not repeat_state: # not a duplicate: run the image upload
print(f"未发现重复,现在上传{str_loca}======{show_des}")
# JPEG-encode the untouched frame; returns None when encoding fails
def encode_origin_frame():
success, buffer = cv2.imencode(".jpg", frame)
return buffer.tobytes() if success else None
# JPEG-encode the annotated frame; returns None when encoding fails
def encode_frame():
success, buffer = cv2.imencode(".jpg", frame_copy)
# if task_id == "2a5d7a80-109a-4cd6-aa95-0e8c9aab6b3f-1": # test: safety helmet
# cv2.imwrite(f"save_pic\qm\hongxian-encode-{pic_count_hongxian}.jpg", frame)
#
# if task_id == "7eecadd6-001f-488c-bed9-1086079c3450-1": # test: construction-site vehicles
# cv2.imwrite(f"save_pic\gdcl\hongxian-encode-{pic_count_hongxian}.jpg", frame_copy)
return buffer.tobytes() if success else None
buffer_bytes = await loop.run_in_executor(invade_executor, encode_frame)
buffer_origin_bytes = await loop.run_in_executor(invade_executor, encode_origin_frame)
# Only the annotated buffer is checked; buffer_origin_bytes may still be None
if not buffer_bytes:
continue
# Upload and MQTT publish run in a separate task
# NOTE(review): this closure reads enclosing variables (buffer_bytes, timestamp, message_point,
# location_results, ...) after awaiting; a later loop iteration may have rebound them by then — confirm
async def upload_and_publish():
# Upload both images to MinIO
def upload_minio():
minio_path, file_type = upload_frame_buff_from_buffer(buffer_bytes, None)
minio_origin_path, file_type = upload_frame_buff_from_buffer(buffer_origin_bytes, None)
return minio_path, minio_origin_path, file_type
# return upload_frame_buff_from_buffer(buffer_bytes, None)
minio_path, minio_origin_path, file_type = await loop.run_in_executor(
invade_executor, upload_minio
)
print(f"minio_pathminio_pathminio_pathminio_path {minio_path}")
# Build the message
message = {
"task_id": task_id,
"time": str(datetime.now()),
"detection_id": timestamp,
"minio": {"minio_path": minio_path,
"minio_origin_path": minio_origin_path,
"file_type": file_type},
"box_detail": [{
"model_id": model_func_id,
"cls_count": cls_count,
"box_count": [message_point], # special handling: wrapped in an extra list
"location_results": location_results # adds the longitude/latitude results
}],
"osd_location": {
"longitude": cam_longitude,
"latitude": cam_latitude
},
"des_location":des_location_result
}
if not event_queue.full():
await event_queue.put({
"timestamp": timestamp # moment the event fired, used for video production
})
else:
# NOTE(review): the log text announces a retry, but the put is not attempted again after the sleep
logger.warning("event_queue 帧队列已满等待1ms后重试")
await asyncio.sleep(0.001)
print(f"hongxianhongxianhongxianhongxian上传 {message}:{event_queue.qsize()} ")
message_json = json.dumps(message, ensure_ascii=False)
await mqtt.publish(mqtt_publish_topic, message_json)
# The task reference is not kept
asyncio.create_task(upload_and_publish())
# Lock the shared variable and skip frames, otherwise reports are too frequent
# async with invade_cache_lock:
# # read or modify the shared variable
# send_count = shared_local_cache["invade_send_count"]
# send_count = send_count + 1
# shared_local_cache["invade_send_count"] = send_count
# if send_count > 1:
# # create an independent task for upload and publish
# shared_local_cache["invade_send_count"] = 0
# asyncio.create_task(upload_and_publish())
except Exception as e:
print(f"cal_des_invade 错误: {e}")
import traceback
traceback.print_exc()
await asyncio.sleep(0.1)
print("cal_des_invade读取线程已停止")
# Module-level state shared between coroutines.
shared_local_cache = dict(
    send_count=0,
    invade_send_count=0,
    # 0, 1, 2 are the three states of recorded-video recognition:
    # not started, started, finished.
    video_process_status=0,
)
from asyncio import Lock, AbstractEventLoop
# Two independent locks intended to guard access to the shared variables above.
cache_lock, invade_cache_lock = Lock(), Lock()
async def send_frame_to_s3_mq(loop,upload_executor,task_id, mqtt, mqtt_topic, cancel_flag: asyncio.Event,
cv_frame_queue: asyncio.Queue,
event_queue: asyncio.Queue = None,
device_height: float = float(200), repeat_dis: float = -1, repeat_time: float = -1):
"""Consume detection payloads, upload result images and publish detection events.

Runs until ``cancel_flag`` is set. Each payload read from ``cv_frame_queue``
is the dict built by ``write_results_to_rtmp`` (keys 'frame_copy', 'frame',
'target_point', 'detections', 'detections_list', 'osd_info', and optionally
'model_para' / 'timestamp'). Targets that pass the track-id filter are
geolocated with ``cal_canv_location_by_osd``; unless the optional
time/distance de-duplication suppresses the event, every model's detections
are post-processed with ``cal_tricker_results``, both frames are JPEG-encoded
and uploaded with ``upload_frame_buff_from_buffer``, an event timestamp is
put on ``event_queue`` and a JSON message is published through ``mqtt``.
The module-level ``stats`` dict is updated with a frames-per-second figure.

NOTE(review): the leading indentation of this function is not visible in this
view, so the comments below describe individual statements and do not assert
how branches are nested.

:param loop: event loop whose ``run_in_executor`` is used
:param upload_executor: executor for DRC lookup, post-processing, encoding and uploads
:param task_id: task identifier placed in the published message
:param mqtt: object with an awaitable ``publish(topic, payload)``
:param mqtt_topic: topic the message is published to
:param cancel_flag: event that stops the loop
:param cv_frame_queue: input queue
:param event_queue: receives ``{"timestamp": ...}`` for each reported event
:param device_height: height value passed for every target to ``cal_canv_location_by_osd``
:param repeat_dis: when > 0, an event whose first location is closer than this
    (``haversine`` distance) to the cached first location is treated as a duplicate
:param repeat_time: when > 0, payloads arriving less than this many seconds after
    the last accepted one are skipped (the very first payload is always accepted)
"""
global stats
start_time = time.time()
# executor = ThreadPoolExecutor(max_workers=Config.MAX_WORKERS)
# upload_executor = ThreadPoolExecutor(max_workers=Config.UPLOAD_WORKERS)
track_filter = TrackIDEventFilter(max_inactive_time=10)
# reported_track_ids / report_interval are only referenced by commented-out code below
reported_track_ids = defaultdict(float)
count_pic = 0
report_interval = 8
# loop = asyncio.get_running_loop()
# Cache object handed to cal_tricker_results on every call
local_func_cache = {
"func_100000": None,
"func_100004": None, # cached person track_id
"func_100006": None # cached vehicle track_id
}
# Fixed parameter dict handed to cal_tricker_results
para = {
"category": 3
}
target_location_back = [] # local cache of the last accepted locations, used for distance-based de-duplication
current_time_second = int(time.time())
repeat_time_count = 0
while not cancel_flag.is_set():
upload_start = time.time()
try:
# Guard against backlog: when the queue is more than half of PROCESSED_QUEUE_SIZE, discard everything
if cv_frame_queue.qsize() > Config.PROCESSED_QUEUE_SIZE // 2:
print(f"警告cv_frame_queue 积压(当前长度={cv_frame_queue.qsize()}),清空队列")
while not cv_frame_queue.empty():
await cv_frame_queue.get()
print("队列已清空,等待新数据...")
await asyncio.sleep(0.1)
continue
# Fetch the next payload
try:
cv_frame = await asyncio.wait_for(cv_frame_queue.get(), timeout=0.05)
# Check the payload type
if not isinstance(cv_frame, dict):
print(f"⚠️ 警告cv_frame 不是字典,而是 {type(cv_frame)}")
continue
except asyncio.TimeoutError:
continue
if repeat_time > 0: # time-based de-duplication enabled?
read_cv_time_second = int(time.time())
# print(f"read_cv_time_secondread_cv_time_second {read_cv_time_second}")
if read_cv_time_second - current_time_second < repeat_time and repeat_time_count > 0:
# print("触发事件去重")
continue
else:
print("没有触发事件去重")
repeat_time_count += 1 # so the very first target is not dropped
current_time_second = read_cv_time_second
# Unpack the payload
frame_copy = cv_frame['frame_copy']
frame = cv_frame['frame']
target_point = cv_frame['target_point']
detections = cv_frame['detections']
detections_list = cv_frame['detections_list']
model_para_list = cv_frame.get('model_para', {}) # defaults to an empty dict
timestamp = cv_frame.get('timestamp', time.time_ns()) # defaults to the current time in ns
air_alti = cv_frame['osd_info']
img_height, img_width = frame.shape[:2]
# NOTE(review): unlike cal_des_invade there is no None check on osd_info before these attribute reads
gimbal_yaw = air_alti.gimbal_yaw
gimbal_pitch = air_alti.gimbal_pitch
gimbal_roll = air_alti.gimbal_roll
height = air_alti.height
cam_longitude = air_alti.longitude
cam_latitude = air_alti.latitude
# Initialise defaults
frame11 = frame_copy # default: frame_copy
box_detail = [] # defaults to an empty list
current_time = time.time()
h = device_height
us = []
vs = []
heights = []
should_report = True
print(f"target_pointtarget_point {len(target_point)}")
count_item = 0
des_location_result=[]
for item in target_point:
# # Skip invalid track ids
# Check whether this track id should be reported
u = item["u"]
v = item["v"]
cls_id = item["cls_id"]
track_id = item["track_id"]
new_track_id = item["new_track_id"]
# should_report = True
# # If this track_id was already reported, check whether the report interval has elapsed
# if new_track_id in reported_track_ids:
# last_report_time = reported_track_ids[new_track_id]
# if current_time - last_report_time < report_interval:
# print(f"基于track_id触发去重事件{new_track_id}")
# should_report = False
# print(f"new_track_idnew_track_id {new_track_id} {track_id}")
# if should_report and track_filter.should_report(new_track_id):
# should_report = True
print(f"方法 send_frame——to {count_item}")
count_item += 1
# A single filtered track id clears should_report and stops the scan
if not track_filter.should_report(track_id):
should_report = False
break
# if track_id < 0: # for MultiYOLODetector, which has no tracking and defaults track_id to -1
# should_report = True
# The track filter allows reporting: collect the pixel coordinates
if should_report:
us.append(u)
vs.append(v)
heights.append(h)
# location_results = cal_canv_location_by_osd(us, vs, gimbal_pitch, gimbal_yaw, gimbal_roll,
# height, cam_longitude, cam_latitude, img_width, img_height,
# heights)
#
# if not location_results:
# print("location_results is null")
# continue
if len(us) > 0:
print("进行侵限计算")
location_results = cal_canv_location_by_osd(us, vs, gimbal_pitch, gimbal_yaw, gimbal_roll,
height, cam_longitude, cam_latitude, img_width, img_height,
heights)
if not location_results:
continue
if location_results:
# Indices 0/1/2 of a location entry are read as longitude/latitude/height
des1 = location_results[0]
des1_longitude = des1[0]
des1_latitude = des1[1]
des1_height = des1[2]
des_location_result.append({"longitude": des1_longitude,
"latitude": des1_latitude})
repeat_state = False
show_des = 0
str_loca = ""
if repeat_dis > 0: # de-duplication is enabled only when the ai_model_list repeat_dis field is > 0
if len(target_location_back) > 0: # not rigorous: only the first location is compared
des1_back = target_location_back[0]
des1_back_longitude = des1_back[0]
des1_back_latitude = des1_back[1]
des1_back_height = des1_back[2]
des1 = location_results[0]
des1_longitude = des1[0]
des1_latitude = des1[1]
des1_height = des1[2]
str_loca = f"{des1_back_longitude}:{des1_back_latitude}---{des1_longitude}{des1_latitude}"
des = haversine(des1_back_longitude, des1_back_latitude, des1_longitude, des1_latitude)
show_des = des
if des < repeat_dis:
print(f"触发基于坐标判断重复,坐标距离:{des}")
repeat_state = True
else:
target_location_back = location_results # not a duplicate: refresh the local location cache
else:
target_location_back = location_results # no cached location to compare against: cache the current one
if not repeat_state: # not a duplicate: run the image upload
for idx, model_para in enumerate(model_para_list):
if not isinstance(detections_list, list) or idx >= len(detections_list):
continue
detections = detections_list[idx] # detections of the current model
chinese_label = model_para.get("model_chinese_labe", {})
model_cls = model_para.get("model_cls_index", {})
list_func_id = model_para.get("model_list_func_id", -11)
func_id = model_para.get("func_id", [])
# Fetch the DRC message; the synchronous call runs in the executor
local_drc_message = await loop.run_in_executor(upload_executor, get_local_drc_message)
if detections is None or len(detections.boxes) < 1:
continue
try:
# Image processing and result computation
frame11, box_detail1 = await loop.run_in_executor(
upload_executor,
cal_tricker_results,
frame_copy, detections, None,
func_id, local_func_cache, para, model_cls, chinese_label, list_func_id
)
except Exception as e:
print(f"处理帧时出错: {str(e)}")
import traceback
traceback.print_exc()
continue
count_pic = count_pic + 1
# Image encoding
# JPEG-encode the processed frame; returns None when encoding fails
def encode_frame():
success, buffer = cv2.imencode(".jpg", frame11)
return buffer.tobytes() if success else None
# JPEG-encode the untouched frame; returns None when encoding fails
def encode_origin_frame():
success, buffer = cv2.imencode(".jpg", frame)
return buffer.tobytes() if success else None
buffer_bytes = await loop.run_in_executor(upload_executor, encode_frame)
buffer_origin_bytes = await loop.run_in_executor(upload_executor, encode_origin_frame)
# Only the processed buffer is checked; buffer_origin_bytes may still be None
if not buffer_bytes:
continue
# Upload and MQTT publish run in a separate task
# NOTE(review): this closure reads enclosing variables (buffer_bytes, box_detail1, timestamp, ...)
# after awaiting; a later loop iteration may have rebound them by then — confirm
async def upload_and_publish():
# Upload both images to MinIO
def upload_minio():
minio_path, file_type = upload_frame_buff_from_buffer(buffer_bytes, None)
minio_origin_path, file_type = upload_frame_buff_from_buffer(buffer_origin_bytes, None)
return minio_path, minio_origin_path, file_type
minio_path, minio_origin_path, file_type = await loop.run_in_executor(
upload_executor, upload_minio
)
# Build the message
message = {
"task_id": task_id,
"time": str(datetime.now()),
"detection_id": timestamp,
"minio": {"minio_path": minio_path, "minio_origin_path": minio_origin_path,
"file_type": file_type},
"box_detail": box_detail1,
"uav_location": local_drc_message,
"osd_location": {
"longitude": cam_longitude,
"latitude": cam_latitude
},
"des_location":des_location_result
}
# NOTE(review): no full() check here (cal_des_invade has one); event_queue also defaults to None
await event_queue.put({
"timestamp": timestamp # moment the event fired, used for video production
})
message_json = json.dumps(message, ensure_ascii=False)
await mqtt.publish(mqtt_topic, message_json)
# The task reference is not kept
asyncio.create_task(upload_and_publish())
# Lock the shared variable and skip frames, otherwise reports are too frequent
# async with cache_lock:
# # read or modify the shared variable
# send_count = shared_local_cache["send_count"]
# send_count = send_count + 1
# shared_local_cache["send_count"] = send_count
#
# # pseudo-code: safety helmets are currently hard to detect
# if 100014 == model_para_list[0]["model_list_func_id"]:
# if send_count > 2:
# # create an independent task for upload and publish
# shared_local_cache["send_count"] = 0
# asyncio.create_task(upload_and_publish())
# else:
# if send_count > 30:
# # create an independent task for upload and publish
# shared_local_cache["send_count"] = 0
# asyncio.create_task(upload_and_publish())
if 'frame' in locals():
upload_time = time.time() - upload_start
print(f"上传耗时: {upload_time:.4f}s")
except Exception as e:
print(f"send_frame_to_s3_mq 错误: {e}")
import traceback
traceback.print_exc()
await asyncio.sleep(0.1)
except asyncio.CancelledError:
print("send_frame_to_s3 收到取消信号")
raise
# finally:
# log_perf('send_frame_to_s3_mq', start_time)
# Update the throughput statistics: roughly once per second, compute and print the average FPS
stats['processed'] += 1
if time.time() - stats['last_time'] >= 1.0:
stats['avg_fps'] = stats['processed'] / (time.time() - stats['last_time'])
print(f"处理速度: {stats['avg_fps']:.2f} FPS")
stats['processed'] = 0
stats['last_time'] = time.time()
print("send_frame_to_s3线程已停止")
#
# def frames_to_video_bytes(frames, fps=25, format="mp4"):
# """
# 将多个帧转换为视频字节数组(内存处理,不存储到本地)
#
# Args:
# frames (list): 包含多个帧的列表,每个帧是 numpy 数组BGR 或 RGB 格式)
# fps (int): 视频帧率
# format (str): 视频格式(如 "mp4"、"avi"
#
# Returns:
# bytes: 视频的字节数组
# """
# if len(frames) == 0:
# raise ValueError("帧列表不能为空")
#
# # 确保所有帧是 uint8 类型
# frames = [np.array(frame, dtype=np.uint8) for frame in frames]
#
# # 使用 BytesIO 存储视频数据
# output = BytesIO()
#
# # 使用 imageio 写入内存
# with imageio.get_writer(output, format=format, fps=fps) as writer:
# for frame in frames:
# # 如果帧是 BGR 格式OpenCV 默认),转换为 RGB
# if frame.ndim == 3 and frame.shape[2] == 3:
# frame = frame[..., ::-1] # BGR → RGB
# writer.append_data(frame)
#
# # 获取字节数据
# video_bytes = output.getvalue()
# output.close()
#
# return video_bytes
#
# def frames_to_video_bytes(frames, fps=25, format="mp4"):
# """
# 将多个帧转换为视频字节数组(内存处理,不存储到本地)
# Args:
# frames (list): 包含多个帧的列表,每个帧是 numpy 数组BGR 或 RGB 格式)
# fps (int): 视频帧率
# format (str): 视频格式(如 "mp4"、"avi"
# Returns:
# bytes: 视频的字节数组 或 None如果失败
# """
# if len(frames) == 0:
# logger.warning("frames_to_video_bytes: 输入帧列表为空")
# return None
#
# try:
# # 确保所有帧是 uint8 类型
# frames = [np.array(frame, dtype=np.uint8) for frame in frames]
#
# output = BytesIO()
# with imageio.get_writer(output, format=format, fps=fps) as writer:
# for frame in frames:
# if frame.ndim == 3 and frame.shape[2] == 3:
# frame = frame[..., ::-1] # BGR → RGB
# writer.append_data(frame)
# return output.getvalue()
# except Exception as e:
# logger.error(f"frames_to_video_bytes 生成视频失败: {e}")
# return None
import subprocess
import numpy as np
import tempfile
import os
import time
import threading
import traceback
def frames_to_video_bytes(frames, fps=25, format="flv"):
    """Encode a list of frames into a video and return the encoded bytes.

    Every frame is converted to ``uint8``, resized to the size of the first
    frame when its height/width differ, has its channel axis reversed
    (the input is treated as BGR and written as RGB) and is appended to a
    temporary raw file.  ``/usr/bin/ffmpeg`` then encodes that raw file with
    libx264 into a second temporary file whose content is returned.  Both
    temporary files are removed before the function returns, whatever the
    outcome.

    Args:
        frames (list): Non-empty ``list`` of array-likes.  Each must convert
            to an ``H x W x 3`` array.
        fps (int): Frame rate given to ffmpeg for both input and output.  It
            is capped at 12 when the first frame is at least 3840x2160.
        format (str): Container name passed to ffmpeg's ``-f`` option and used
            as the suffix of the temporary output file.  A warning is printed
            for anything other than ``"flv"``, but the requested value is
            still used.

    Returns:
        bytes | None: The encoded video, or ``None`` when the input is
        invalid, OpenCV is unavailable, a temporary file cannot be created,
        or ffmpeg fails or exceeds the 120 second timeout.
    """

    def _log(level, msg):
        # Timestamped console output shared by the level-specific helpers.
        print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {level}: {msg}")

    def log_info(msg):
        _log("INFO", msg)

    def log_warning(msg):
        _log("WARNING", msg)

    def log_error(msg):
        _log("ERROR", msg)

    def log_debug(msg):
        _log("DEBUG", msg)

    # The format is NOT switched: callers label the result with the format
    # they asked for, so the warning must not claim otherwise.
    if format != "flv":
        log_warning(f"输出格式为{format}而非FLV,FLV兼容性最佳,仍按请求的格式编码")

    if not isinstance(frames, list) or len(frames) == 0:
        log_error("输入帧列表为空或非列表类型返回None")
        return None

    start_time = time.time()

    # Bound before the try so the cleanup in ``finally`` never meets an
    # unbound name, even when creating the temporary files fails.
    temp_frame_file = None
    temp_frame_path = None
    temp_video_path = None
    try:
        try:
            temp_frame_fd, temp_frame_path = tempfile.mkstemp(suffix=".raw", text=False)
            temp_frame_file = os.fdopen(temp_frame_fd, "wb")
            log_debug(f"创建帧数据临时文件: {temp_frame_path}")

            temp_video_fd, temp_video_path = tempfile.mkstemp(suffix=f".{format}", text=False)
            os.close(temp_video_fd)  # ffmpeg writes to the path itself
            log_debug(f"创建视频临时文件: {temp_video_path}")
        except OSError as e:
            log_error(f"创建临时文件失败:{str(e)}")
            traceback.print_exc()
            return None

        # --------------------------
        # Step 1: normalise the frames and dump them into the raw file
        # --------------------------
        try:
            log_info("开始帧预处理并写入临时文件")

            # The first frame defines the geometry of the whole video.
            first_frame = np.asarray(frames[0], dtype=np.uint8)
            if first_frame.ndim != 3 or first_frame.shape[2] != 3:
                log_error("第一帧必须是3通道BGR/RGB图像")
                return None
            frame_h, frame_w = first_frame.shape[:2]

            if frame_w >= 3840 and frame_h >= 2160:
                single_frame_mb = (frame_w * frame_h * 3) / (1024 * 1024)
                total_data_mb = single_frame_mb * len(frames)
                log_info(f"检测到4K帧{frame_w}x{frame_h},单帧{single_frame_mb:.1f}MB总数据{total_data_mb:.1f}MB")
                fps = min(fps, 12)
                log_info(f"4K适配帧率调整为{fps}fps")

            try:
                from cv2 import resize
            except ImportError:
                log_error("未安装opencv-python请执行 pip install opencv-python")
                return None

            for idx, frame in enumerate(frames):
                try:
                    frame_uint8 = np.asarray(frame, dtype=np.uint8)
                    if frame_uint8.shape[:2] != (frame_h, frame_w):
                        log_debug(f"{idx + 1}尺寸不匹配,缩放至{frame_w}x{frame_h}")
                        # interpolation=1 is what the original passed to cv2.resize
                        frame_uint8 = resize(frame_uint8, (frame_w, frame_h), interpolation=1)
                    # A frame with another channel layout would shift every
                    # following frame in the raw stream, so reject it here.
                    if frame_uint8.ndim != 3 or frame_uint8.shape[2] != 3:
                        raise ValueError(f"expected a 3-channel frame, got shape {frame_uint8.shape}")
                    # Reverse the channel axis (BGR -> RGB) to match "-pix_fmt rgb24".
                    temp_frame_file.write(frame_uint8[..., ::-1].tobytes())
                    log_debug(f"{idx + 1}已写入临时文件")
                except Exception as e:
                    log_error(f"处理帧{idx + 1}失败:{str(e)}")
                    traceback.print_exc()
                    return None

            # Make sure ffmpeg sees the complete data on disk.
            temp_frame_file.flush()
            os.fsync(temp_frame_file.fileno())
            temp_frame_file.close()
            log_info(f"所有帧已写入临时文件,共{len(frames)}")
        except Exception as e:
            log_error(f"帧预处理阶段异常:{str(e)}")
            traceback.print_exc()
            return None

        # --------------------------
        # Step 2: build the ffmpeg command (file in, file out)
        # --------------------------
        log_info("构造FFMPEG命令使用临时文件输入")
        ffmpeg_cmd = [
            "/usr/bin/ffmpeg",
            "-hide_banner",
            "-loglevel", "error",
            "-y",
            # input: raw RGB frames read from the temporary file
            "-f", "rawvideo",
            "-vcodec", "rawvideo",
            "-s", f"{frame_w}x{frame_h}",
            "-pix_fmt", "rgb24",
            "-r", str(fps),
            "-i", temp_frame_path,
            # output: H.264 written to the temporary video file
            "-vcodec", "libx264",
            "-pix_fmt", "yuv420p",
            "-r", str(fps),
            "-preset", "ultrafast",
            "-crf", "30",
            "-f", format,
            temp_video_path,
        ]
        log_info(f"FFMPEG命令{' '.join(ffmpeg_cmd)}")

        # --------------------------
        # Step 3: run ffmpeg and read the result back
        # --------------------------
        try:
            log_info("启动FFMPEG进程使用临时文件")
            log_info("等待FFMPEG处理完成超时120秒")
            # subprocess.run kills and reaps the child on timeout, and stdout
            # is discarded so an unread pipe can never stall ffmpeg.
            completed = subprocess.run(
                ffmpeg_cmd,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.PIPE,
                timeout=120,
            )

            stderr_text = completed.stderr.decode("utf-8", errors="replace")
            ffmpeg_error_log = [line.strip() for line in stderr_text.splitlines() if line.strip()]
            for err_str in ffmpeg_error_log:
                log_error(f"FFMPEG错误{err_str}")

            if completed.returncode != 0:
                full_error = "\n".join(ffmpeg_error_log)
                log_error(f"FFMPEG处理失败退出码{completed.returncode}")
                log_error(f"FFMPEG完整错误\n{full_error}")
                return None

            log_info("读取视频临时文件内容")
            with open(temp_video_path, 'rb') as f:
                video_bytes = f.read()
            log_info(
                f"视频生成成功!大小:{len(video_bytes) / 1024 / 1024:.2f}MB总耗时{time.time() - start_time:.2f}")
            return video_bytes
        except subprocess.TimeoutExpired:
            log_error("FFMPEG处理超时超过120秒")
            return None
        except Exception as e:
            log_error(f"FFMPEG处理异常{str(e)}")
            traceback.print_exc()
            return None
    finally:
        # Always remove both temporary files.
        try:
            if temp_frame_file is not None and not temp_frame_file.closed:
                temp_frame_file.close()
            if temp_frame_path and os.path.exists(temp_frame_path):
                os.remove(temp_frame_path)
                log_debug(f"已删除帧临时文件: {temp_frame_path}")
        except Exception as e:
            log_warning(f"删除帧临时文件失败:{str(e)}")
        try:
            if temp_video_path and os.path.exists(temp_video_path):
                os.remove(temp_video_path)
                log_debug(f"已删除视频临时文件: {temp_video_path}")
        except Exception as e:
            log_warning(f"删除视频临时文件失败:{str(e)}")
        log_info("视频生成流程结束")
async def cut_evnt_video_publish(task_id, mqtt, mqtt_topic, cancel_flag: asyncio.Event,
                                 event_queue: asyncio.Queue = None,
                                 timestamp_frame_queue: TimestampedQueue = None):
    """Turn queued events into short video clips, upload them and publish the result.

    For every event taken from ``event_queue`` the frames returned by
    ``timestamp_frame_queue.query_by_timestamp(event["timestamp"])`` are encoded
    with ``frames_to_video_bytes`` (fps=150, mp4), uploaded through
    ``upload_video_buff_from_buffer`` and announced on ``mqtt_topic``.

    Args:
        task_id: Identifier copied into the published message.
        mqtt: Object providing an awaitable ``publish(topic, payload)``.
        mqtt_topic: Topic the result message is published to.
        cancel_flag: Checked before each ``event_queue.get()``.  The coroutine
            does not notice the flag while it is blocked waiting for an event;
            cancelling the task is what ends that wait.
        event_queue: Queue of event dicts carrying a ``"timestamp"`` key.
        timestamp_frame_queue: Frame store queried by timestamp.

    Returns:
        None.  When either queue is missing the coroutine logs and returns
        immediately instead of looping without ever yielding to the event loop.
    """
    if event_queue is None or timestamp_frame_queue is None:
        print("[cut循环] 缺少 event_queue 或 timestamp_frame_queue,任务退出")
        return

    loop = asyncio.get_running_loop()
    upload_executor = ThreadPoolExecutor(max_workers=Config.UPLOAD_WORKERS)
    # Strong references keep the fire-and-forget uploads alive until they finish.
    background_uploads = set()

    async def upload_and_publish(video_bytes, timestamp):
        # Both values are parameters on purpose: the main loop may already be
        # handling the next event by the time this task starts running.
        try:
            minio_path, file_type = await loop.run_in_executor(
                upload_executor,
                lambda: upload_video_buff_from_buffer(video_bytes, video_format="mp4"),
            )
            message = {
                "task_id": task_id,
                "time": str(datetime.now()),
                "detection_id": timestamp,
                "minio": {"minio_path": minio_path,
                          "minio_origin_path": minio_path,
                          "file_type": file_type},
                "box_detail": [{
                }],
                "osd_location": {
                },
                "des_location": []
            }
            print(f"成功上传视频: {message}")
            await mqtt.publish(mqtt_topic, json.dumps(message, ensure_ascii=False))
        except Exception as e:
            print(f"上传失败: {e}")

    async def handle_event(event):
        # Validate the event, cut the clip and schedule its upload.
        if not isinstance(event, dict) or "timestamp" not in event:
            print(f"[cut循环] 无效事件,跳过")
            return
        timestamp = event.get("timestamp")
        if timestamp is None:
            print("⚠️ 警告:事件中缺少 timestamp")
            return

        matched_items = timestamp_frame_queue.query_by_timestamp(timestamp)
        print(f"[cut循环] 匹配到 {len(matched_items)} 个帧")
        if len(matched_items) == 0:
            return
        frames = [item["frame"] for item in matched_items]

        # Encoding shells out to ffmpeg and can take many seconds, so it runs
        # in a worker thread instead of blocking the event loop.
        video_bytes = await loop.run_in_executor(
            upload_executor,
            lambda: frames_to_video_bytes(frames, fps=150, format="mp4"),
        )
        if video_bytes is None:
            print("⚠️ 警告视频生成失败返回None")
            return

        upload = asyncio.create_task(upload_and_publish(video_bytes, timestamp))
        background_uploads.add(upload)
        upload.add_done_callback(background_uploads.discard)

    try:
        while not cancel_flag.is_set():
            event = await event_queue.get()
            # From here on exactly one task_done() is owed for this get().
            try:
                print(f"[cut循环] 成功获取事件: {event}")
                await handle_event(event)
            except Exception as e:
                print(f"[cut循环] 处理事件出错: {e}")
            finally:
                event_queue.task_done()
                remaining = getattr(event_queue, "_unfinished_tasks", "?")
                print(f"[cut循环] 确认消费,剩余未完成任务: {remaining}")
    finally:
        def release_executor(_=None):
            upload_executor.shutdown(wait=False)

        if background_uploads:
            # Let uploads that are still in flight finish before the pool is
            # closed; closing it first would make them fail to schedule.
            pending = asyncio.gather(*background_uploads, return_exceptions=True)
            pending.add_done_callback(release_executor)
        else:
            release_executor()
async def start_rtmp_processing(video_url: str, task_id: str, model_configs: List[Dict],
                                mqtt_pub_ip: str, mqtt_pub_port: int, mqtt_pub_topic: str,
                                mqtt_sub_ip: str, mqtt_sub_port: int, mqtt_sub_topic: str,
                                output_rtmp_url: str,
                                invade_enable: bool, invade_file: str, camera_para_url: str,
                                device_height: float, repeat_dis: float, repeat_time: float):
    """Run the live-stream detection pipeline for one task until it ends or is cancelled.

    Creates the per-task queues, optionally loads the intrusion polygons and
    camera parameters, starts the MQTT subscriber device, then launches the
    reader, detector, RTMP writer, optional intrusion and upload coroutines and
    registers them with ``task_manager``.  On exit, however it happens, the
    cancel flag is set, the sub-tasks are cancelled, the thread pools are shut
    down, the detector is destroyed, the MQTT device is stopped and the task is
    removed from ``task_manager``.

    Args:
        video_url: Source stream address handed to ``read_rtmp_frames``.
        task_id: Identifier used for registration and in published messages.
        model_configs: Model configuration dicts handed to the detector; their
            count is forwarded to ``cal_des_invade``.
        mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic: Broker and topic used for
            publishing results.
        mqtt_sub_ip, mqtt_sub_port, mqtt_sub_topic: Broker and topic the MQTT
            device subscribes to.
        output_rtmp_url: Destination address handed to ``write_results_to_rtmp``.
        invade_enable: Enables intrusion detection when polygons were loaded.
        invade_file: Location of the GeoJSON-like polygon file (optional).
        camera_para_url: Location of the camera parameter file (optional).
        device_height, repeat_dis, repeat_time: Forwarded unchanged to
            ``cal_des_invade`` and ``send_frame_to_s3_mq``.

    Raises:
        RuntimeError: If the MQTT device could not be started.
        Exception: Any other setup or pipeline error is logged and re-raised.
    """
    logger.info(f"拉流地址{video_url}")
    logger.info(f"推流地址{output_rtmp_url}")

    cancel_flag = asyncio.Event()
    # Per-task queues (kept local so concurrent tasks do not share state).
    frame_queue = asyncio.Queue(maxsize=Config.FRAME_QUEUE_SIZE)
    processed_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
    invade_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
    cv_frame_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
    event_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)  # events that trigger clip cutting
    print(f"[start_rtmp] 创建event_queue实例{id(event_queue)}")
    loop = asyncio.get_running_loop()
    stream_containers: Dict[str, Any] = {}
    # Bounded FIFO of recent frames, passed to the reader task.
    timestamp_frame_queue = TimestampedQueue(maxlen=1000)

    # Everything the cleanup code reads is bound here, so a failure early in
    # setup cannot turn into a NameError inside ``finally``.
    detector = None
    device = None
    tasks = []
    upload_tasks = []
    executors = []
    read_task = None
    process_task = None
    write_task = None
    invade_task = None

    def new_executor(max_workers):
        # Track every pool so it can be shut down during cleanup.
        pool = ThreadPoolExecutor(max_workers=max_workers)
        executors.append(pool)
        return pool

    try:
        # Intrusion polygons (one list of Points per feature) and camera parameters.
        list_points = []
        camera_para = None
        if invade_file and camera_para_url:
            camera_file_path = downFile(camera_para_url)
            camera_para = read_camera_params(camera_file_path)
            invade_file_path = downFile(invade_file)
            if invade_file_path:
                with open(invade_file_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                for feature in data.get('features', []):
                    geometry = feature.get('geometry', {})
                    coordinates = geometry.get('coordinates', [])
                    # Only the first ring is used; the first two coordinates are swapped.
                    points = [Point(coord[1], coord[0], coord[2], i)
                              for i, coord in enumerate(coordinates[0])]
                    list_points.append(points)

        mqtt = MQTTService(mqtt_pub_ip, port=mqtt_pub_port)
        detector = MultiYoloTrtDetectorTrackId(model_configs)
        print(f"mqtt_sub_ipmqtt_sub_ipmqtt_sub_ip {mqtt_sub_ip} {mqtt_sub_port} {mqtt_sub_topic}")

        async def start_mqtt_device():
            mqtt_device = MQTTDevice(
                ip=mqtt_sub_ip,
                port=mqtt_sub_port,
                topics=[mqtt_sub_topic],
                queue_size=50
            )
            # Register a callback for osd_info_push; per the original author,
            # message reading is very slow without it.
            mqtt_device.register_callback(
                topic=mqtt_sub_topic,
                method="osd_info_push",
                callback=empty_osd_callback
            )
            await mqtt_device.start()
            return mqtt_device

        # Wait for the device to start before launching any consumer.
        device = await start_mqtt_device()
        if device is None:
            raise RuntimeError("Failed to start MQTT device")

        model_count = len(model_configs)
        try:
            # Stream reader
            read_task = asyncio.create_task(
                read_rtmp_frames(
                    loop,
                    new_executor(Config.READ_RTMP_WORKERS),
                    video_url,
                    device,
                    mqtt_sub_topic,
                    "drc_camera_osd_info_push",
                    mqtt_sub_topic,
                    "osd_info_push",
                    cancel_flag,
                    frame_queue,
                    timestamp_frame_queue
                ),
                name="read_rtmp_frames"
            )
            tasks.append(read_task)

            # Detection
            process_task = asyncio.create_task(
                process_frames(detector, cancel_flag, frame_queue, processed_queue),
                name="process_frames"
            )
            tasks.append(process_task)

            # RTMP output
            invade_state = bool(list_points) and invade_enable
            write_task = asyncio.create_task(
                write_results_to_rtmp(
                    task_id,
                    output_rtmp_url,
                    None,
                    list_points,
                    camera_para,
                    invade_state,
                    cancel_flag,
                    processed_queue,
                    invade_queue,
                    cv_frame_queue,
                    stream_containers
                ),
                name="write_results_to_rtmp"
            )
            tasks.append(write_task)

            # Intrusion detection
            if invade_enable and list_points:
                invade_task = asyncio.create_task(
                    cal_des_invade(
                        loop,
                        new_executor(Config.INVADE_WORKERS),
                        task_id,
                        mqtt,
                        mqtt_pub_topic,
                        list_points,
                        camera_para,
                        model_count,
                        cancel_flag,
                        invade_queue,
                        event_queue,
                        device_height,
                        repeat_dis, repeat_time
                    ),
                    name="cal_des_invade"
                )
                tasks.append(invade_task)

            # Frame upload consumers
            upload_executor = new_executor(Config.UPLOAD_WORKERS)
            for worker_idx in range(1):
                upload_task = asyncio.create_task(
                    send_frame_to_s3_mq(loop, upload_executor, task_id, mqtt, mqtt_pub_topic,
                                        cancel_flag, cv_frame_queue, event_queue, device_height, repeat_dis,
                                        repeat_time),
                    name=f"send_frame_to_s3_mq_{worker_idx}"
                )
                upload_tasks.append(upload_task)
                tasks.append(upload_task)

            # Register the task with the TaskManager.
            device_list = [mqtt]
            task_info = {
                "video_url": video_url,
                "output_rtmp_url": output_rtmp_url,
                "task_id": task_id,
                "cancel_flag": cancel_flag,
                "device_list": device_list
            }
            # shield: cancelling this coroutine must not cancel the gather
            # directly; the sub-tasks are stopped in the finally blocks below.
            main_task = asyncio.shield(asyncio.gather(*tasks, return_exceptions=True))
            await task_manager.add_task(
                task_id,
                task_info,
                main_task,
                tasks
            )
            print("start_rtmp 1")
            await task_manager.update_heartbeat(task_id)
            print("start_rtmp 2")

            try:
                print("start_rtmp 3")
                await main_task
            except asyncio.CancelledError:
                logger.info(f"Task {task_id} was cancelled")
                raise
            except Exception as e:
                logger.error(f"Task {task_id} failed: {e}")
                raise
            finally:
                cancel_flag.set()
                # Grace period for the sub-tasks to notice the flag.  This has
                # to be a non-blocking sleep: a blocking one freezes the event
                # loop, so the sub-tasks could not run during the wait at all.
                await asyncio.sleep(3)
                print("start_rtmp 4")
                for task in tasks:
                    print("start_rtmp 5")
                    task.cancel()
                for upload_task in upload_tasks:
                    print("start_rtmp 6")
                    upload_task.cancel()
        except Exception as e:
            logger.error(f"Error in main task loop: {e}")
            raise
    except Exception as e:
        logger.error(f"Unexpected error in start_rtmp_processing: {e}")
        raise
    finally:
        print("start_rtmp 7")
        cancel_flag.set()
        print("start_rtmp 8")
        # Release CUDA memory held by the detector (if it was created).
        if detector is not None:
            detector.destroy()
        # Cancel whatever is still running.
        if read_task and not read_task.done():
            print("start_rtmp 9")
            read_task.cancel()
        if process_task and not process_task.done():
            print("start_rtmp 10")
            process_task.cancel()
        if write_task and not write_task.done():
            print("start_rtmp 11")
            write_task.cancel()
        if invade_task and not invade_task.done():
            print("start_rtmp 12")
            invade_task.cancel()
        for task in upload_tasks:
            if task and not task.done():
                task.cancel()
        print("start_rtmp 13")
        # Release the worker threads; work already submitted is left to finish.
        for pool in executors:
            pool.shutdown(wait=False)
        # Stop the MQTT subscriber if it was started.
        if device is not None:
            await device.stop()
            print("start_rtmp 14")
        await cleanup_resources()
        await task_manager.remove_task(task_id)
        logger.info(f"Task {task_id} resources cleaned up")
    print(f"start_rtmp_processing线程已停止")
async def start_video_processing(minio_path: str, task_id: str, model_configs: List[Dict],
                                 mqtt_ip: str, mqtt_port: int, mqtt_topic: str, output_rtmp_url: str,
                                 invade_enable: bool, invade_file: str, camera_para_url: str, device_height: float,
                                 repeat_dis: float, repeat_time: float):
    """Run the detection pipeline over a recorded video file for one task.

    Downloads the recording, extracts its first subtitle stream to
    ``<video>.srt`` with ffmpeg, optionally loads the intrusion polygons and
    camera parameters, then launches the reader, detector, RTMP writer,
    optional intrusion, upload and clip-cutting coroutines and registers them
    with ``task_manager``.  On exit the cancel flag is set, all sub-tasks are
    cancelled, the detector is destroyed and the task is removed from
    ``task_manager``.

    Args:
        minio_path: Location of the recording handed to ``downBigFile``.
        task_id: Identifier used for registration and in published messages.
        model_configs: Model configuration dicts handed to the detector; their
            count is forwarded to ``cal_des_invade``.
        mqtt_ip, mqtt_port, mqtt_topic: Broker and topic used for publishing.
        output_rtmp_url: Destination address handed to ``write_results_to_rtmp``.
        invade_enable: Enables intrusion detection when polygons were loaded.
        invade_file: Location of the GeoJSON-like polygon file (optional).
        camera_para_url: Location of the camera parameter file (optional).
        device_height, repeat_dis, repeat_time: Forwarded unchanged to
            ``cal_des_invade`` and ``send_frame_to_s3_mq``.

    Returns:
        None.  The function returns early (after logging) when the download,
        the subtitle extraction, the polygon file or the video open fails.
    """
    print("开起录像识别")
    cancel_flag = asyncio.Event()
    # Per-task queues (kept local so concurrent tasks do not share state).
    frame_queue = asyncio.Queue(maxsize=Config.FRAME_QUEUE_SIZE)
    processed_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
    invade_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
    cv_frame_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
    event_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)  # events that trigger clip cutting
    stream_containers: Dict[str, Any] = {}
    # Bounded FIFO of recent frames; events cut their clip out of it.
    timestamp_frame_queue = TimestampedQueue(maxlen=500)
    loop = asyncio.get_running_loop()

    try:
        download_path = downBigFile(minio_path)
    except Exception as e:
        logger.error(f"Unexpected error downloading big file: {e}", exc_info=True)
        return None
    if not download_path:
        logger.error(f"big file is none")
        return
    print(f"download_path路径 {os.path.abspath(download_path)}")

    # The subtitle file is written next to the video as "<video name>.srt".
    dir_name = os.path.dirname(download_path)
    srt_name = os.path.basename(download_path) + ".srt"
    srt_path = os.path.join(dir_name, srt_name)
    if os.path.exists(srt_path):
        os.remove(srt_path)  # ffmpeg is run without -y, so a stale file must go first
    command = [
        "ffmpeg",
        "-i", download_path,
        "-map", "0:s:0",  # first subtitle stream
        "-c:s", "srt",  # convert to SRT
        srt_path
    ]
    try:
        # ffmpeg can run for a while on large recordings; running it in a
        # worker thread keeps the event loop (and other tasks) responsive.
        await loop.run_in_executor(None, lambda: subprocess.run(command, check=True))
        print(f"start_video_processing方法 字幕提取成功,保存至: {srt_path}")
    except subprocess.CalledProcessError as e:
        print(f"start_video_processing方法错误: FFmpeg 执行失败 - {e}")
        return
    except FileNotFoundError:
        print("start_video_processing方法错误: 未安装 FFmpeg请先安装并添加到系统路径。")
        return
    print(f"srt_path路径 {os.path.abspath(srt_path)}")

    # Everything the cleanup code reads is bound here: the early ``return``s
    # inside the try below run the ``finally`` before any task exists.
    detector = None
    tasks = []
    upload_tasks = []
    read_task = None
    process_task = None
    write_task = None
    invade_task = None

    try:
        # Intrusion polygons: a list of polygons, each a list of Points.
        list_points = []
        camera_para = None
        if invade_file and camera_para_url:
            camera_file_path = downFile(camera_para_url)
            camera_para = read_camera_params(camera_file_path)
            invade_file_path = downFile(invade_file)
            if invade_file_path is None:
                print(f"invade_file is None task_id:{task_id}")
                return
            with open(invade_file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            features = data.get('features', [])
            if not features:
                print("没有找到有效的要素数据")
                return
            for polygon in features:
                geometry = polygon.get('geometry', {})
                coordinates = geometry.get('coordinates', [])
                # First ring only.  The first two coordinates are swapped on
                # purpose: per the original author, convert_points_to_utm
                # fails otherwise.
                points = [Point(coord[1], coord[0], coord[2], key)
                          for key, coord in enumerate(coordinates[0])]
                list_points.append(points)

        mqtt = MQTTService(mqtt_ip, port=mqtt_port)
        detector = MultiYOLODetector_TrackId(model_configs)
        mqtt_publish_topic = mqtt_topic

        # Open the video once just to read its frame rate.
        cap = cv2.VideoCapture(download_path)
        if not cap.isOpened():
            cap.release()
            print("Error: Could not open video.")
            return
        fps = cap.get(cv2.CAP_PROP_FPS)
        cap.release()

        invade_state = bool(list_points) and invade_enable
        model_count = len(model_configs)

        # Video reader
        read_task = asyncio.create_task(
            read_video_frames(task_id, mqtt, mqtt_publish_topic, download_path, srt_path, frame_queue,
                              timestamp_frame_queue, cancel_flag),
            name="read_rtmp_frames")
        tasks.append(read_task)

        # Detection
        process_task = asyncio.create_task(
            process_frames(detector, cancel_flag, frame_queue, processed_queue),
            name="process_frames"
        )
        tasks.append(process_task)

        # RTMP output
        write_task = asyncio.create_task(
            write_results_to_rtmp(
                task_id,
                output_rtmp_url,
                fps,
                list_points,
                camera_para,
                invade_state,
                cancel_flag,
                processed_queue,
                invade_queue,
                cv_frame_queue,
                stream_containers
            ),
            name="write_results_to_rtmp"
        )
        tasks.append(write_task)

        # Intrusion detection
        if invade_enable and list_points:
            # NOTE(review): start_rtmp_processing calls cal_des_invade with two
            # extra leading arguments (loop, executor).  The signature is not
            # visible from here - confirm which call matches it.
            invade_task = asyncio.create_task(
                cal_des_invade(
                    task_id,
                    mqtt,
                    mqtt_publish_topic,
                    list_points,
                    camera_para,
                    model_count,
                    cancel_flag,
                    invade_queue,
                    event_queue,
                    device_height,
                    repeat_dis, repeat_time
                ),
                name="cal_des_invade"
            )
            tasks.append(invade_task)

        # Frame upload consumers
        # NOTE(review): start_rtmp_processing calls send_frame_to_s3_mq with two
        # extra leading arguments (loop, executor).  The signature is not
        # visible from here - confirm which call matches it.
        for worker_idx in range(2):
            upload_task = asyncio.create_task(
                send_frame_to_s3_mq(task_id, mqtt, mqtt_topic,
                                    cancel_flag, cv_frame_queue, event_queue, device_height, repeat_dis, repeat_time),
                name=f"send_frame_to_s3_mq_{worker_idx}"
            )
            upload_tasks.append(upload_task)
            tasks.append(upload_task)

        # Cut a clip around each event, upload it and publish the result.
        upload_video = asyncio.create_task(cut_evnt_video_publish(task_id, mqtt, mqtt_topic, cancel_flag,
                                                                  event_queue, timestamp_frame_queue),
                                           name="cut_evnt_video_publish")
        tasks.append(upload_video)

        task_info = {
            "output_rtmp_url": output_rtmp_url,
            "task_id": task_id,
            "cancel_flag": cancel_flag,
        }
        # shield: cancelling this coroutine must not cancel the gather
        # directly; the sub-tasks are stopped in the finally blocks below.
        main_task = asyncio.shield(asyncio.gather(*tasks, return_exceptions=True))
        await task_manager.add_task(
            task_id,
            task_info,
            main_task,
            tasks
        )
        print("start_rtmp 1")
        await task_manager.update_heartbeat(task_id)
        print("start_rtmp 2")

        try:
            print("start_rtmp 3")
            await main_task
        except asyncio.CancelledError:
            logger.info(f"Task {task_id} was cancelled")
            raise
        except Exception as e:
            logger.error(f"Task {task_id} failed: {e}")
            raise
        finally:
            cancel_flag.set()
            # Grace period for the sub-tasks to notice the flag.  This has to
            # be a non-blocking sleep: a blocking one freezes the event loop,
            # so the sub-tasks could not run during the wait at all.
            await asyncio.sleep(3)
            print("start_rtmp 4")
            for task in tasks:
                print("start_rtmp 5")
                task.cancel()
            for upload_task in upload_tasks:
                print("start_rtmp 6")
                upload_task.cancel()
    finally:
        print("start_rtmp 7")
        cancel_flag.set()
        print("start_rtmp 8")
        # Cancel whatever is still running.
        if read_task and not read_task.done():
            print("start_rtmp 9")
            read_task.cancel()
        if process_task and not process_task.done():
            print("start_rtmp 10")
            process_task.cancel()
        if write_task and not write_task.done():
            print("start_rtmp 11")
            write_task.cancel()
        if invade_task and not invade_task.done():
            print("start_rtmp 12")
            invade_task.cancel()
        # Covers the upload consumers and the clip-cutting task as well.
        for task in tasks:
            if not task.done():
                task.cancel()
        print("start_rtmp 13")
        # Release CUDA memory held by the detector.  Doing it here also covers
        # the early-return paths taken after the detector was created.
        if detector is not None:
            detector.destroy()
        await cleanup_resources()
        await task_manager.remove_task(task_id)
        logger.info(f"Task {task_id} resources cleaned up")
    print(f"start_rtmp_processing线程已停止")