ai_project_v1/yolo/cv_multi_model_back_video.py

4512 lines
199 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import math
import os.path
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from time import sleep
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
import json
import time
from ultralytics import YOLO
import torch
# 假设其他导入模块正确存在
from cv_back_video import cal_tricker_results, get_local_drc_message, read_drc_mqtt, async_read_drc_mqtt
from middleware.MQTTService import MQTTService, MQTTDevice, empty_osd_callback
from middleware.TaskManager import task_manager
from middleware.entity.air_attitude import Air_Attitude
from middleware.entity.camera_para import read_camera_params, Camera_Para
from middleware.entity.detection import DetectionResult, DetectionResultList
from middleware.entity.timestamp_queue import TimestampedQueue
from middleware.minio_util import upload_file_from_buffer, upload_frame_buff_from_buffer, downFile, \
downBigFile, upload_video_buff_from_buffer
import av
import asyncio
import cv2
import numpy as np
from typing import Dict, Any
from ultralytics import solutions
from middleware.read_srt import parse_srt_file
from touying.ImageReproject_python.cal_func import cal_canv_location_by_osd, red_line_reproject
from touying.ImageReproject_python.img_types import Point
from yolo.detect.multi_yolo_trt_detect_track import MultiYoloTrtDetectorTrackId
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# 全局配置
class Config:
MAX_WORKERS = 8 # 图像识别线程池大小
FRAME_QUEUE_SIZE = 5 # 帧队列大小
PROCESSED_QUEUE_SIZE = 5 # 处理后帧队列大小
RETRY_COUNT = 3 # 模型加载重试次数
READ_RTMP_WORKERS = 2 # 读取专用线程池大小
PROCESS_FRAME_WORKERS = 2 # 推理专用线程池大小
UPLOAD_WORKERS = 2 # 上传专用线程池大小
WRITE_FRAME_WORKERS = 2 # 视频推流专用线程池大小
INVADE_WORKERS = 2 # 侵限专用线程池大小
EVENT_VIDEO_WORKERS = 2 # 事件录像专用线程池大小
MAX_QUEUE_WARN_SIZE = 15 # 队列警告阈值(PROCESSED_QUEUE_SIZE的一半)
MODEL_INPUT_SIZE = (640, 640) # YOLO模型输入尺寸确保能被32整除
DEFAULT_DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu" # 默认使用第一个GPU
TARGET_FPS = 25 # 目标推流帧率
MAX_FRAME_DROP = 5 # 最大允许丢帧数(避免队列积压)
PERF_LOG_INTERVAL = 5.0 # 性能日志输出间隔(秒)
# 全局性能统计变量
perf_stats = {
'read_rtmp_frames': {'count': 0, 'total_time': 0.0},
'process_frames': {'count': 0, 'total_time': 0.0},
'write_results_to_rtmp': {'count': 0, 'total_time': 0.0},
'cal_des_invade': {'count': 0, 'total_time': 0.0},
'send_frame_to_s3_mq': {'count': 0, 'total_time': 0.0},
'last_log_time': time.time()
}
def log_perf(func_name, start_time):
"""记录函数执行时间和调用次数"""
end_time = time.time()
duration = end_time - start_time
# 更新统计
perf_stats[func_name]['count'] += 1
perf_stats[func_name]['total_time'] += duration
# 定期输出性能统计
current_time = time.time()
if current_time - perf_stats['last_log_time'] >= Config.PERF_LOG_INTERVAL:
print("\n=== 性能统计 ===")
for name, stats in perf_stats.items():
if name not in ['last_log_time']:
avg_time = stats['total_time'] / stats['count'] if stats['count'] > 0 else 0
print(f"{name}: 调用次数={stats['count']}, 平均耗时={avg_time:.4f}s")
print("===============\n")
perf_stats['last_log_time'] = current_time
# # 创建专用线程池
# upload_executor = ThreadPoolExecutor(max_workers=Config.UPLOAD_WORKERS)
# frame_queue = asyncio.Queue(maxsize=Config.FRAME_QUEUE_SIZE)
# processed_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
# invade_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
# cv_frame_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
# executor = ThreadPoolExecutor(max_workers=Config.MAX_WORKERS)
# stop_event = asyncio.Event()
# 模型缓存
model_cache = {}
# 性能统计
stats = {
'processed': 0,
'last_time': time.time(),
'avg_fps': 0.0,
'dropped_frames': 0
}
# 修改 cleanup_resources 函数,确保正确取消所有协程
async def cleanup_resources(upload_executor, executor, frame_queue,
processed_queue, invade_queue, cv_frame_queue,
stop_event):
"""清理资源"""
# global upload_executor, executor, frame_queue, processed_queue, invade_queue, cv_frame_queue, stop_event
print("cleanup_resources 1")
stop_event.set() # 首先设置取消标志
print("cleanup_resources 2")
# 等待所有队列清空(添加超时防止死锁)
await asyncio.wait_for(asyncio.gather(
_empty_queue(frame_queue),
_empty_queue(processed_queue),
_empty_queue(invade_queue),
_empty_queue(cv_frame_queue)
), timeout=5.0)
print("cleanup_resources 8")
# 关闭线程池
if 'executor' in globals() and executor is not None:
print("cleanup_resources 9")
executor.shutdown(wait=True, cancel_futures=True)
if 'upload_executor' in globals() and upload_executor is not None:
print("cleanup_resources 10")
upload_executor.shutdown(wait=True, cancel_futures=True)
print("cleanup_resources 11")
print("cleanup_resources 14")
async def _empty_queue(queue):
"""安全清空队列"""
try:
while not queue.empty():
await queue.get()
except Exception as e:
logger.warning(f"清空队列时出错: {e}")
class MultiDetectionResults:
def __init__(self):
self.boxes = [] # 边界框
self.clss = [] # 类别ID
self.cls_names = [] # 类别名称
self.cls_en_names = [] # 英文类别名称
self.confs = [] # 置信度
self.track_ids = [] # 置信度
class MultiYOLODetector:
"""多模型并行检测器修复多GPU设备不匹配问题"""
def __init__(self, model_configs: List[Dict]):
self.models = []
self.class_maps = []
self.executor = ThreadPoolExecutor(max_workers=Config.MAX_WORKERS)
self.allowed_classes = []
self.model_cls = []
self.chinese_label = []
self.list_func_id = []
self.conf = 0.4
self.func_id = -1
self.list_class_names = []
self.list_para_invade_enable = []
self.input_size = Config.MODEL_INPUT_SIZE # 统一输入尺寸
self.device = Config.DEFAULT_DEVICE # 强制使用默认设备
self._setup_multi_gpu()
for config in model_configs:
model_path = config['path']
cls_map = config.get('cls_map', {})
allowed = config.get('allowed_classes', None)
tracking = config.get('tracking', True) # 是否启用跟踪
model_cls_index = config.get('cls_index', [])
model_chinese_labe = config.get('chinese_label', [])
model_list_func_id = config.get('list_func_id', [])
func_id = config.get('func_id', -1)
class_names = config.get('class_names', True)
para_invade_enable = config.get('para_invade_enable', False)
config_conf = config.get('config_conf', 0.4)
# 加载模型并确保在正确设备上
model = self._load_model(model_path, tracking)
self.models.append(model)
self.class_maps.append(cls_map)
self.allowed_classes.append(allowed)
self.model_cls.append(model_cls_index)
self.chinese_label.append(model_chinese_labe)
self.list_func_id.append(model_list_func_id)
self.func_id = func_id
self.conf = config_conf
self.list_class_names.append(class_names)
self.list_para_invade_enable.append(para_invade_enable)
def _setup_multi_gpu(self):
"""配置多GPU环境"""
if torch.cuda.is_available():
gpu_count = torch.cuda.device_count()
print(f"检测到 {gpu_count} 个GPU将使用默认设备: {self.device}")
# 强制所有操作使用默认设备
torch.cuda.set_device(self.device)
def _load_model(self, model_path: str, tracking: bool) -> torch.nn.Module:
"""加载模型并确保在正确设备上,禁用模型融合以避免设备冲突"""
try:
# 加载模型
model = YOLO(model_path)
# model.tracker = None # 手动禁用跟踪,关闭模型相关日志输出
# 移动到目标设备
model = model.to(self.device)
# 启用跟踪仅对YOLOv8有效
if hasattr(model, 'tracker'):
model.tracker = "botsort.yaml" # 或 "bytetrack.yaml"
print(f"已为模型 {model_path} 启用跟踪")
else:
print(f"警告:模型 {model_path} 不支持跟踪功能")
# 对于多GPU仅在必要时使用DataParallel并确保所有参数在同一设备
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
# 确保模型参数在默认设备上
model = model.to(self.device)
model = torch.nn.DataParallel(model, device_ids=[self.device])
# 禁用模型融合以避免多设备问题
model.fuse = lambda verbose=False: model # 替换融合方法
return model
except Exception as e:
print(f"模型加载失败: {e}")
raise
@staticmethod
def preprocess_frame(frame: np.ndarray, input_size: tuple, device: str) -> torch.Tensor:
"""预处理帧:调整大小、转换格式并标准化"""
# 调整大小
resized = cv2.resize(frame, input_size)
# 转换为RGB
rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
# 转换为张量并标准化
tensor = torch.from_numpy(rgb).permute(2, 0, 1).float() / 255.0
# 添加批次维度并移动到目标设备
tensor = tensor.unsqueeze(0).to(device)
return tensor
@staticmethod
def scale_bbox(bbox: List[int], original_size: tuple, input_size: tuple) -> List[int]:
"""将模型输出的边界框坐标缩放回原始图像尺寸"""
oh, ow = original_size
iw, ih = input_size
# 计算缩放比例
scale_w = ow / iw
scale_h = oh / ih
# 缩放边界框
x1, y1, x2, y2 = bbox
return [
int(x1 * scale_w),
int(y1 * scale_h),
int(x2 * scale_w),
int(y2 * scale_h)
]
async def predict(self, frame: np.ndarray) -> Tuple[List[DetectionResult], List, List]:
"""异步调用多模型预测"""
loop = asyncio.get_running_loop()
original_size = (frame.shape[0], frame.shape[1]) # (height, width)
def _predict(model_idx: int, frame: np.ndarray) -> Tuple[List[DetectionResult], Any, Any]:
model = self.models[model_idx]
cls_map = self.class_maps[model_idx]
allowed = self.allowed_classes[model_idx]
model_cls_index = self.model_cls[model_idx]
model_chinese_labe = self.chinese_label[model_idx]
model_list_func_id = self.list_func_id[model_idx]
func_id = self.func_id
model_class_names = self.list_class_names[model_idx]
para_invade_enable = self.list_para_invade_enable[model_idx]
# 预处理帧并确保在正确设备上
input_tensor = self.preprocess_frame(frame, self.input_size, self.device)
timestart = time.time()
with torch.no_grad():
# 确保输入和模型在同一设备
if input_tensor.device != next(model.parameters()).device:
input_tensor = input_tensor.to(next(model.parameters()).device)
results = model(input_tensor, verbose=False)
timeend = time.time()
# print(f"模型 {model_idx} 推理耗时: {timeend - timestart:.4f}s (设备: {self.device})")
detections = []
model_para = {
"cls_map": cls_map,
"model_chinese_labe": model_chinese_labe,
"model_cls_index": model_cls_index,
"model_list_func_id": model_list_func_id,
"func_id": func_id,
"model_class_names": model_class_names,
"para_invade_enable": para_invade_enable,
"results": results
}
# model_cls_index{}
detection_result_list = DetectionResultList([], [], [], [], [])
for result in results:
boxes = result.boxes.cpu().numpy() # 移动到CPU进行处理
tracks = result.boxes.id.cpu().numpy() if result.boxes.id is not None else None
for i, box in enumerate(boxes):
cls_id = int(box.cls[0])
# 获取原始模型处理DataParallel包装
original_model = model.module if isinstance(model, torch.nn.DataParallel) else model
cls_name = original_model.names[cls_id]
# 过滤不允许的类别
if allowed and cls_name not in allowed:
continue
# 模型输出的边界框是相对于输入尺寸的
x1, y1, x2, y2 = box.xyxy[0].astype(int)
# 缩放回原始图像尺寸
scaled_bbox = self.scale_bbox(
[x1, y1, x2, y2],
original_size,
self.input_size
)
conf = float(box.conf[0])
# 临时测试用,后面需要抽取出来参数
if conf < self.conf:
continue
en_name = cls_map.get(cls_name, "unknown")
# 添加跟踪ID如果存在
track_id = int(tracks[i]) if tracks is not None else None
detection_result_list.boxes.append(scaled_bbox)
detection_result_list.clss.append(cls_id)
detection_result_list.clss_name.append(cls_name)
detection_result_list.confs.append(conf)
detection_result_list.track_ids.append(-1)
detections.append(DetectionResult(
bbox=scaled_bbox,
class_id=cls_id,
class_name=cls_name,
confidence=conf, # 新增字段
track_id=-1
))
return detections, detection_result_list, model_para
# 并行执行
futures = [
loop.run_in_executor(
self.executor,
_predict,
model_idx,
frame.copy()
)
for model_idx in range(len(self.models))
]
results = await asyncio.gather(*futures) # List[Tuple[List[DetectionResult], Dict]]
# 合并检测结果
all_detections = []
detections_list = [] # 这里的格式跟results 的二次计算相同
# 合并所有模型的参数到一个字典(避免键冲突)
all_model_paras = []
for model_idx, (detections, detection_result_list, model_para) in enumerate(results):
all_detections.extend(detections)
# 给每个模型的参数加前缀,避免覆盖(如 model0_model_cls_index
# prefixed_para = {f"model{model_idx}_{k}": v for k, v in model_para.items()}
# all_model_paras.update(prefixed_para)
all_model_paras.append(model_para)
detections_list.append(detection_result_list)
return all_detections, detections_list, all_model_paras
class MultiYOLODetector_TrackId:
"""多模型并行检测器修复多GPU设备不匹配问题"""
def __init__(self, model_configs: List[Dict]):
self.models = []
self.class_maps = []
self.executor = ThreadPoolExecutor(max_workers=Config.MAX_WORKERS)
self.allowed_classes = []
self.conf = 0.4
self.model_cls = []
self.chinese_label = []
self.list_func_id = []
self.func_id = -1
self.list_class_names = []
self.list_para_invade_enable = []
self.input_size = Config.MODEL_INPUT_SIZE # 统一输入尺寸
self.device = Config.DEFAULT_DEVICE # 强制使用默认设备
self._setup_multi_gpu()
# 初始化ObjectCounter实例
self.object_counters = []
for config in model_configs:
model_path = config['path']
cls_map = config.get('cls_map', {})
allowed = config.get('allowed_classes', None)
tracking = config.get('tracking', True) # 是否启用跟踪
model_cls_index = config.get('cls_index', True)
model_chinese_labe = config.get('chinese_label', {})
model_list_func_id = config.get('list_func_id', -11)
func_id = config.get('func_id', True)
class_names = config.get('class_names', True)
para_invade_enable = config.get('para_invade_enable', False)
config_conf = config.get('config_conf', 0.4)
# 加载模型并确保在正确设备上
model = self._load_model(model_path, tracking)
self.models.append(model)
self.class_maps.append(cls_map)
self.allowed_classes.append(allowed)
self.model_cls.append(model_cls_index)
self.chinese_label.append(model_chinese_labe)
self.list_func_id.append(model_list_func_id)
self.func_id = func_id
self.conf = config_conf
self.list_class_names.append(class_names)
self.list_para_invade_enable.append(para_invade_enable)
# 为每个模型创建ObjectCounter实例
if hasattr(model, 'tracker'):
# 如果模型支持跟踪则创建带跟踪的ObjectCounter
self.object_counters.append(
solutions.ObjectCounter(
show=False,
region=None, # 可以根据需要设置检测区域
model=model_path,
classes=model_cls_index if allowed else None,
verbose=False # 新增:禁用性能日志打印
)
)
else:
# 如果模型不支持跟踪则创建普通ObjectCounter
self.object_counters.append(
solutions.ObjectCounter(
show=False,
region=None,
model=model_path,
classes=model_cls_index if allowed else None,
verbose=False # 新增:禁用性能日志打印
)
)
def _setup_multi_gpu(self):
"""配置多GPU环境"""
if torch.cuda.is_available():
gpu_count = torch.cuda.device_count()
print(f"检测到 {gpu_count} 个GPU将使用默认设备: {self.device}")
# 强制所有操作使用默认设备
torch.cuda.set_device(self.device)
def _load_model(self, model_path: str, tracking: bool) -> torch.nn.Module:
"""加载模型并确保在正确设备上,禁用模型融合以避免设备冲突"""
try:
# 加载模型
model = YOLO(model_path)
# model.tracker = None # 手动禁用跟踪,关闭模型相关日志输出
# 移动到目标设备
model = model.to(self.device)
# 启用跟踪仅对YOLOv8有效
if hasattr(model, 'tracker') and tracking:
model.tracker = "botsort.yaml" # 或 "bytetrack.yaml"
print(f"已为模型 {model_path} 启用跟踪")
else:
print(f"警告:模型 {model_path} 不支持跟踪功能")
# 对于多GPU仅在必要时使用DataParallel并确保所有参数在同一设备
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
# 确保模型参数在默认设备上
model = model.to(self.device)
model = torch.nn.DataParallel(model, device_ids=[self.device])
# 禁用模型融合以避免多设备问题
model.fuse = lambda verbose=False: model # 替换融合方法
return model
except Exception as e:
print(f"模型加载失败: {e}")
raise
@staticmethod
def preprocess_frame(frame: np.ndarray, input_size: tuple, device: str) -> torch.Tensor:
"""预处理帧:调整大小、转换格式并标准化"""
# 获取原始帧尺寸 (width, height)
original_h, original_w = frame.shape[:2]
# 计算缩放比例,保持宽高比
scale = min(input_size[0] / original_w, input_size[1] / original_h)
new_w = int(original_w * scale)
new_h = int(original_h * scale)
# 调整大小并保持宽高比
resized = cv2.resize(frame, (new_w, new_h))
# 创建画布并保持模型输入尺寸
canvas = np.zeros((input_size[1], input_size[0], 3), dtype=np.uint8)
canvas[:new_h, :new_w] = resized
# 转换为RGB并标准化
rgb = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
tensor = torch.from_numpy(rgb).permute(2, 0, 1).float() / 255.0
return tensor.unsqueeze(0).to(device)
@staticmethod
def scale_bbox(bbox, original_size, input_size):
"""将模型输出的边界框坐标缩放回原始图像尺寸"""
ow, oh = original_size # (width, height)
iw, ih = input_size # (width, height)
# 计算缩放比例(基于输入尺寸和原始尺寸的比例)
scale_w = ow / iw
scale_h = oh / ih
x1, y1, x2, y2 = bbox
# 缩放并限制坐标范围
x1 = int(max(0, min(x1 * scale_w, ow - 1)))
y1 = int(max(0, min(y1 * scale_h, oh - 1)))
x2 = int(max(0, min(x2 * scale_w, ow - 1)))
y2 = int(max(0, min(y2 * scale_h, oh - 1)))
return [x1, y1, x2, y2]
async def predict(self, frame: np.ndarray) -> Tuple[List[DetectionResult], List, List]:
"""异步调用多模型预测"""
loop = asyncio.get_running_loop()
# original_size = (frame.shape[0], frame.shape[1]) # (height, width)
original_size = (frame.shape[1], frame.shape[0]) # (widthheight )
def _predict(model_idx: int, frame: np.ndarray) -> Tuple[List[DetectionResult], Any, Any]:
try:
model = self.models[model_idx]
object_counter = self.object_counters[model_idx]
cls_map = self.class_maps[model_idx]
allowed = self.allowed_classes[model_idx]
model_cls_index = self.model_cls[model_idx]
model_chinese_labe = self.chinese_label[model_idx]
model_list_func_id = self.list_func_id[model_idx]
func_id = self.func_id
model_class_names = self.list_class_names[model_idx]
para_invade_enable = self.list_para_invade_enable[model_idx]
# 预处理帧
input_tensor = self.preprocess_frame(frame, self.input_size, self.device)
timestart = time.time()
with torch.no_grad():
if input_tensor.device != next(model.parameters()).device:
input_tensor = input_tensor.to(next(model.parameters()).device)
# 使用ObjectCounter进行检测
results = object_counter(frame)
# 调试输出结果结构
# print(f"Model {model_idx} results type: {type(results)}")
if hasattr(results, 'boxes'):
print(
f"Boxes type: {type(results.boxes)}, length: {len(results.boxes) if hasattr(results.boxes, '__len__') else 'N/A'}")
timeend = time.time()
# print(f"模型 {model_idx} 推理耗时: {timeend - timestart:.4f}s")
detections = []
model_para = {
"cls_map": cls_map,
"model_chinese_labe": model_chinese_labe,
"model_cls_index": model_cls_index,
"model_list_func_id": model_list_func_id,
"func_id": func_id,
"model_class_names": model_class_names,
"para_invade_enable": para_invade_enable,
"results": object_counter # 保存原始结果用于调试
}
results = object_counter
# 处理检测结果
detection_result_list = DetectionResultList([], [], [], [], [])
# 尝试不同的结果访问方式
boxes = []
tracks = results.track_ids
confs = results.confs
clss = results.clss
names = results.names
result_boxes = results.boxes
if isinstance(result_boxes, list) and len(result_boxes) == 0:
return [], DetectionResultList([], [], [], [], []), {}
if isinstance(result_boxes, list) and len(result_boxes) > 0:
boxes = result_boxes
else:
boxes = result_boxes.tolist()
filter_conf = False
# 处理每个检测结果
for i, box in enumerate(boxes):
try:
# 确保box是可迭代的
if not hasattr(box, '__iter__'):
print(f"无效的边界框格式: {box}")
continue
# 转换为列表(如果还不是)
box = list(box) if not isinstance(box, list) else box
# 确保有4个坐标值
if len(box) < 4:
print(f"不完整的边界框坐标: {box}")
continue
x1, y1, x2, y2 = box[:4]
# cls_id = int(results.clss[i])
cls_id = int(results.clss[i])
ind = int(results.clss[i])
cls_name = names[ind]
conf = float(confs[i]) if i < len(confs) else 0.0
# 待优化,将参数提取出来
if conf < self.conf:
# print(f"conf self.conf{conf} {self.conf}")
# filter_conf=True
continue
track_id = int(tracks[i]) if i < len(tracks) and tracks[i] is not None else None
# 过滤不允许的类别
if allowed and cls_name not in allowed:
continue
scaled_bbox = [int(x1), int(y1), int(x2), int(y2)]
detection_result_list.boxes.append([x1, y1, x2, y2])
detection_result_list.clss.append(cls_id)
detection_result_list.clss_name.append(cls_name)
detection_result_list.confs.append(conf)
detection_result_list.track_ids.append(track_id)
detections.append(DetectionResult(
bbox=scaled_bbox,
class_id=cls_id,
class_name=cls_name,
confidence=conf,
track_id=track_id
))
except Exception as e:
print(f"处理单个检测结果时出错: {str(e)}")
continue
# # 基于conf 筛选目标
# if not filter_conf:
# return detections, detection_result_list, model_para
# else:
# return [], DetectionResultList([], [], [], [], []), {}
return detections, detection_result_list, model_para
except Exception as e:
print(f"模型 {model_idx} 预测过程中发生错误: {str(e)}")
import traceback
traceback.print_exc()
return [], DetectionResultList([], [], [], [], []), {}
# 并行执行
futures = [
loop.run_in_executor(
self.executor,
_predict,
model_idx,
frame.copy()
)
for model_idx in range(len(self.models))
]
results111 = await asyncio.gather(*futures) # List[Tuple[List[DetectionResult], Dict]]
# 合并检测结果
all_detections = []
detections_list = [] # 这里的格式跟results 的二次计算相同
# 合并所有模型的参数到一个字典(避免键冲突)
all_model_paras = []
for model_idx, (detections, detection_result_list, model_para) in enumerate(results111):
all_detections.extend(detections)
all_model_paras.append(model_para)
detections_list.append(detection_result_list)
return all_detections, detections_list, all_model_paras
async def read_video_frames(task_id, mqtt, mqtt_publish_topic,
local_video_path: str, srt_path: str, frame_queue: asyncio.Queue,
timestamp_frame_queue: TimestampedQueue,
cancel_flag: asyncio.Event = None, ):
srt_list = parse_srt_file(srt_path)
cap = cv2.VideoCapture(local_video_path)
if not cap.isOpened():
print("Error: Could not open video.")
return
prev_time = time.time()
frame_count = 0
fps = 0
orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
# # 临时使用,后续还是考虑使用文件整体传输消息好一点
# async with cache_lock:
# # 读取或修改共享变量
# video_process_status = shared_local_cache["video_process_status"]
# if video_process_status == 0:
# shared_local_cache["video_process_status"] = 1 # 将状态改为开始
# 临时处理,用作发送当前录像处理的状态,开始还是结束
async def publist_status(status_to_publish):
try:
message = {"task_id": task_id, "video_status": status_to_publish}
await mqtt.publish(mqtt_publish_topic, json.dumps(message, ensure_ascii=False))
except Exception as e:
logger.error(f"Failed to publish status: {e}")
raise
await publist_status(1)
# 控制读帧,避免队列积压,进而出现大量丢帧
cap = cv2.VideoCapture(local_video_path)
video_fps = cap.get(cv2.CAP_PROP_FPS) # 获取视频原始FPS
frame_interval = 1.0 / video_fps # 每帧应等待的时间
last_read_time = time.time()
while not cancel_flag.is_set() and cap.isOpened():
ret, frame = cap.read()
if not ret: # 检查是否成功读取帧(视频结束或读取错误)
await publist_status(2)
# 录像的任务停止,由录像本身发起
await task_manager.remove_task(task_id)
# stop_event.set()
# # 使用共享变量时加锁,进而进行跳帧,不然上报太频繁
# async with cache_lock:
# # 读取或修改共享变量
# video_process_status = shared_local_cache["video_process_status"]
# if video_process_status == 0:
# shared_local_cache["video_process_status"] = 2 # 将状态改为结束
break
# 控制读取速度:确保每帧间隔不小于 frame_interval
current_time = time.time()
if current_time - last_read_time < frame_interval:
await asyncio.sleep(frame_interval - (current_time - last_read_time))
last_read_time = current_time
# # 保存当前帧为图片(或直接处理)
# frame_path = f"{output_folder}/frame_{frame_count:06d}.jpg"
# cv2.imwrite(frame_path, frame)
frame_count += 1
# if frame_count % 10 == 0:
# curr_time = time.time()
# fps = 10 / (curr_time - prev_time)
# prev_time = curr_time
# print(f" ({orig_width}x{orig_height}, FPS: {fps:.2f})")
# # 可选每处理N帧打印一次进度
# if frame_count % 100 == 0:
# print(f"Processed {frame_count} frames")
if not frame_queue.full():
if frame_count < len(srt_list) and len(srt_list) > 0:
dji_srt = srt_list[frame_count]
gimbal_pitch = dji_srt.gb_pitch
gimbal_roll = dji_srt.gb_roll
gb_yaw = dji_srt.gb_yaw
air_height = dji_srt.abs_alt
cam_longitude = dji_srt.longitude
cam_latitude = dji_srt.latitude
art_tit = Air_Attitude(gimbal_pitch, gimbal_roll, gb_yaw, air_height, cam_latitude, cam_longitude)
time_ns = time.time_ns()
await frame_queue.put((frame, art_tit, time_ns))
timestamp_frame_queue.append({
"timestamp": time_ns,
"frame": frame
})
else:
await asyncio.sleep(0.01) # 队列满时稍作等待
cap.release()
# # # # # # 测试时候,临时注释
if os.path.exists(srt_path):
os.remove(srt_path)
if os.path.exists(local_video_path):
os.remove(local_video_path)
#
# async def read_rtmp_frames(
# loop,
# read_rtmp_frames_executor: ThreadPoolExecutor,
# video_url: str,
# device: Optional[MQTTDevice] = None,
# topic_camera_osd: Optional[str] = None,
# method_camera_osd: Optional[str] = None,
# topic_osd_info: Optional[str] = None,
# method_osd_info: Optional[str] = None,
# cancel_flag: Optional[asyncio.Event] = None,
# frame_queue: asyncio.Queue = None,
# timestamp_frame_queue: TimestampedQueue = None
# ):
# """
# 异步读取 RTMP 流帧(优化版:移除帧率控制,优化线程池)
# """
# max_retries = 20
# retry_delay = 2
# pic_count = 0
# attempt = 0
# time_start = time.time_ns() # 添加开始时间统计
# frame_count = 0 # 统计总帧数
#
# if cancel_flag is None:
# cancel_flag = asyncio.Event()
#
# # loop = asyncio.get_running_loop()
#
# # 打印初始统计信息
# print(f"开始读取RTMP流: {video_url}")
#
# while not cancel_flag.is_set() and attempt < max_retries:
# attempt += 1
# if cancel_flag.is_set():
# logger.info("收到停止信号,终止 RTMP 读取")
# break
#
# container = None
# try:
# logger.info(f"尝试连接 RTMP 流 (尝试 {attempt}/{max_retries}): {video_url}")
# # 1. 关键优化:将同步的 av.open 和流初始化放到线程池
# container = await loop.run_in_executor(read_rtmp_frames_executor, av.open, video_url)
# video_stream = await loop.run_in_executor(read_rtmp_frames_executor, next,
# (s for s in container.streams if s.type == 'video'))
# logger.info(f"成功连接到 RTMP 流: {video_url} ({video_stream.width}x{video_stream.height})")
#
# # 2. 提前获取一次OSD消息验证MQTT是否正常
# if device and topic_osd_info and method_osd_info:
# osd_msg = device.get_latest_message(topic=topic_osd_info, method=method_osd_info)
# if osd_msg:
# logger.info(f"初始OSD消息获取成功: 高度={osd_msg.data.height}")
# else:
# logger.warning("初始OSD消息为空可能MQTT尚未收到消息")
#
# # 3. 关键优化:将同步的帧迭代放到线程池,通过生成器异步获取
# async def async_frame_generator():
# """异步帧生成器在后台线程迭代同步帧通过yield返回给事件循环"""
#
# def sync_frame_iter():
# try:
# for frame in container.decode(video=0):
# # 线程内检查取消标志(需定期检查,避免线程无法退出)
# if cancel_flag.is_set():
# logger.info("后台线程检测到取消信号,停止帧迭代")
# break
#
# # 确保是3通道RGB
# if len(frame.planes) == 1: # 如果是灰度图
# gray = frame.to_ndarray(format='gray')
# # 转换为3通道BGR不修改尺寸
# bgr = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
# yield bgr
# else:
# # 保持原始尺寸和色彩空间,只转换格式
# bgr = frame.to_ndarray(format='bgr24')
# yield bgr
# except Exception as e:
# logger.error(f"同步帧迭代出错: {e}")
# finally:
# if container:
# container.close()
# logger.info("RTMP容器已关闭")
#
# # 将同步迭代器包装为异步生成器
# gen = sync_frame_iter()
# while not cancel_flag.is_set():
# try:
# # 每次获取一帧都通过线程池执行,避免长时间阻塞
# frame = await loop.run_in_executor(read_rtmp_frames_executor, next, gen, None)
# if frame is None: # 迭代结束
# break
# yield frame
# except StopIteration:
# logger.info("RTMP流帧迭代结束")
# break
# except Exception as e:
# logger.error(f"异步获取帧出错: {e}")
# break
#
# # 4. 异步迭代帧(不阻塞事件循环)
# async for frame in async_frame_generator():
# if cancel_flag.is_set():
# logger.info("检测到取消信号,停止读取帧")
# break
#
# try:
# # 5. 帧转换也放到线程池av.Frame.to_ndarray是CPU密集操作
# img = frame.copy() # 确保不修改原始帧
# osd_info = None
#
# # 6. 此时事件循环未被阻塞MQTT消息已缓存get_latest_message可即时获取
# if device and topic_osd_info and method_osd_info:
# osd_msg = device.get_latest_message(topic=topic_osd_info, method=method_osd_info)
# if osd_msg and hasattr(osd_msg, 'data'):
# osd_info = Air_Attitude(
# gimbal_pitch=osd_msg.data.gimbal_pitch,
# gimbal_roll=osd_msg.data.gimbal_roll,
# gimbal_yaw=osd_msg.data.gimbal_yaw,
# height=osd_msg.data.height,
# latitude=osd_msg.data.latitude,
# longitude=osd_msg.data.longitude
# )
#
# # 7. 异步放入帧队列(避免队列满时阻塞)
# if not frame_queue.full():
# pic_count += 1
# frame_count += 1 # 增加总帧数统计
# time_ns = time.time_ns()
#
# # 定期输出统计信息每1000帧
# if time_ns - time_start > 1000000000:
# print(f"readFrames {pic_count}")
# pic_count = 0
# time_start = time_ns
# if img is not None and osd_info is not None:
# await frame_queue.put((img, osd_info, time_ns))
# timestamp_frame_queue.append({
# "timestamp": time_ns,
# "frame": img
# })
# logger.debug(
# f"已放入帧队列,累计帧数: {pic_count},队列剩余空间: {frame_queue.maxsize - frame_queue.qsize()}")
# else:
# logger.warning("帧队列已满等待1ms后重试")
# await asyncio.sleep(0.001)
#
# except Exception as frame_error:
# logger.error(f"处理单帧时出错: {frame_error}", exc_info=True)
# continue
#
# except (av.AVError, IOError) as e:
# logger.error(f"RTMP 流错误 (尝试 {attempt}/{max_retries}): {e}")
# if attempt < max_retries:
# await asyncio.sleep(retry_delay)
# else:
# raise RuntimeError(f"无法连接 RTMP 流 (尝试 {max_retries} 次后失败): {video_url}")
# except asyncio.CancelledError:
# logger.info("read_rtmp_frames 收到取消信号")
# raise
# except Exception as e:
# logger.error(f"未知错误: {e}", exc_info=True)
# if attempt < max_retries:
# await asyncio.sleep(retry_delay)
# finally:
# # 双重保险:确保容器关闭
# if container and not container.closed:
# await loop.run_in_executor(None, container.close)
# logger.info("RTMP容器在finally中关闭")
#
# # 最终统计信息
# if frame_count > 0:
# total_time = (time.time_ns() - time_start) / 1e9
# avg_fps = frame_count / total_time if total_time > 0 else 0
# print(f"RTMP流读取完成总帧数: {frame_count}, 总时间: {total_time:.2f}秒, 平均FPS: {avg_fps:.2f}")
# else:
# print("RTMP流读取失败未获取到任何帧")
#
# logger.info(f"RTMP 流已结束或被取消,累计处理帧数: {pic_count}")
# ------------------------------- 下述方法使用ffmpeg 拉流可以解决cv2拉流的一些问题主要是虚拟环境ffmpeg不匹配的问题。但是ffmpeg拉流慢3s左右
# import cv2
# import json
# import asyncio
# from typing import Optional
# from concurrent.futures import ThreadPoolExecutor
#
#
# async def read_rtmp_frames(
# loop,
# read_rtmp_frames_executor: ThreadPoolExecutor,
# video_url: str,
# device: Optional[MQTTDevice] = None,
# topic_camera_osd: Optional[str] = None,
# method_camera_osd: Optional[str] = None,
# topic_osd_info: Optional[str] = None,
# method_osd_info: Optional[str] = None,
# cancel_flag: Optional[asyncio.Event] = None,
# frame_queue: asyncio.Queue = None,
# timestamp_frame_queue=None
# ):
# """
# 基于 FFmpeg 读取 RTMP 流帧(优化版:高性能读取,处理损坏帧)
# """
# max_retries = 20
# retry_delay = 2
# pic_count = 0
# attempt = 0
# time_start = time.time_ns()
# frame_count = 0
#
# if cancel_flag is None:
# cancel_flag = asyncio.Event()
#
# print(f"开始读取RTMP流: {video_url}")
#
# while not cancel_flag.is_set() and attempt < max_retries:
# attempt += 1
# if cancel_flag.is_set():
# logger.info("收到停止信号,终止 RTMP 读取")
# break
#
# ffmpeg_process = None
# width, height = None, None
# frame_size = None
#
# try:
# logger.info(f"尝试连接 RTMP 流 (尝试 {attempt}/{max_retries}): {video_url}")
#
# # 1. 探测视频分辨率
# width, height = await detect_video_resolution(loop, read_rtmp_frames_executor, video_url)
#
# if width is None or height is None:
# logger.warning("使用默认分辨率 1920x1080")
# width, height = 1920, 1080
#
# frame_size = width * height * 3
# logger.info(f"视频分辨率: {width}x{height}, 帧大小: {frame_size} bytes")
#
# # 2. 启动 FFmpeg 进程(优化参数提高性能)
# ffmpeg_cmd = [
# 'ffmpeg',
# '-hide_banner',
# '-loglevel', 'warning', # 改为warning可以看到更多错误信息
# '-fflags', '+nobuffer+genpts',
# '-err_detect', 'ignore_err',
# '-max_delay', '0',
# '-flags', 'low_delay',
# '-i', video_url,
# '-an', # 无音频
# '-c:v', 'rawvideo', # 关键:输出原始视频帧,而不是复制编码
# '-pix_fmt', 'bgr24', # OpenCV使用BGR格式
# '-f', 'rawvideo', # 关键:输出原始视频格式
# '-flush_packets', '1',
# '-'
# ]
# ffmpeg_process = await loop.run_in_executor(
# read_rtmp_frames_executor,
# lambda: subprocess.Popen(
# ffmpeg_cmd,
# stdout=subprocess.PIPE,
# stderr=subprocess.PIPE,
# bufsize=frame_size # 设置合适的缓冲区大小
# )
# )
#
# logger.info(f"成功启动 FFmpeg 进程连接 RTMP 流: {video_url}")
#
# # 3. 初始化帧读取状态
# frame_sequence = 0
# last_timestamp = time_start
# consecutive_corrupted_frames = 0 # 连续损坏帧计数
# max_consecutive_corrupted = 10 # 最大连续损坏帧数
#
# while not cancel_flag.is_set():
# try:
# # 直接读取完整帧(高性能方式)
# raw_frame = await loop.run_in_executor(
# read_rtmp_frames_executor,
# lambda: ffmpeg_process.stdout.read(frame_size)
# )
#
# if not raw_frame:
# logger.warning("读取到空帧数据,流可能已结束")
# break
#
# current_time_ns = time.time_ns()
# frame_sequence += 1
# frame_count += 1
#
# # 处理帧数据(无论是否完整)
# img = None
# is_corrupted = False
# print(f"读取 read_rtmp_frames 判断")
# try:
# if len(raw_frame) == frame_size:
# # 完整帧处理
# frame = np.frombuffer(raw_frame, dtype=np.uint8).reshape((height, width, 3))
# img = frame.copy()
# consecutive_corrupted_frames = 0 # 重置连续损坏计数
# else:
# # 损坏帧处理
# logger.warning(f"帧数据损坏: {len(raw_frame)}/{frame_size} bytes, 序列: {frame_sequence}")
# is_corrupted = True
# consecutive_corrupted_frames += 1
#
# # 创建替代帧
# if consecutive_corrupted_frames <= max_consecutive_corrupted:
# # 尝试部分恢复
# img = np.zeros((height, width, 3), dtype=np.uint8)
# valid_data = min(len(raw_frame), frame_size)
# if valid_data > 0:
# # 尽可能填充有效数据
# temp_frame = np.frombuffer(raw_frame[:valid_data], dtype=np.uint8)
# img.flat[:len(temp_frame)] = temp_frame
# else:
# # 连续损坏过多,创建空白帧
# img = np.zeros((height, width, 3), dtype=np.uint8)
# logger.error(f"连续损坏帧过多 ({consecutive_corrupted_frames}),创建空白帧")
#
# except Exception as frame_error:
# logger.error(f"帧数据处理错误: {frame_error}")
# # 创建空白帧作为后备
# img = np.zeros((height, width, 3), dtype=np.uint8)
# is_corrupted = True
# consecutive_corrupted_frames += 1
# print(f"读取 read_rtmp_frames 判断1")
# # 获取OSD信息
# osd_info = None
# if device and topic_osd_info and method_osd_info:
# try:
# osd_msg = device.get_latest_message(topic=topic_osd_info, method=method_osd_info)
# if osd_msg and hasattr(osd_msg, 'data'):
# osd_info = Air_Attitude(
# gimbal_pitch=osd_msg.data.gimbal_pitch,
# gimbal_roll=osd_msg.data.gimbal_roll,
# gimbal_yaw=osd_msg.data.gimbal_yaw,
# height=osd_msg.data.height,
# latitude=osd_msg.data.latitude,
# longitude=osd_msg.data.longitude
# )
# except Exception as osd_error:
# logger.warning(f"获取OSD信息失败: {osd_error}")
# print(f"读取 read_rtmp_frames 判断2")
# # 放入帧队列
# if img is not None and not frame_queue.full():
# # 确保时间戳递增
# if current_time_ns <= last_timestamp:
# current_time_ns = last_timestamp + 1
# last_timestamp = current_time_ns
#
# # 统计信息
# pic_count += 1
# if current_time_ns - time_start > 1000000000: # 1秒
# elapsed_seconds = (current_time_ns - time_start) / 1e9
# fps = pic_count / elapsed_seconds if elapsed_seconds > 0 else 0
# corrupted_rate = (consecutive_corrupted_frames / pic_count * 100) if pic_count > 0 else 0
# print(
# f"readFrames 序列:{frame_sequence} 帧数:{pic_count} FPS:{fps:.2f} 损坏率:{corrupted_rate:.1f}%")
# pic_count = 0
# time_start = current_time_ns
#
# print(f"读取 read_rtmp_frames 实时流")
#
# # 准备帧数据
# frame_data = {
# "sequence": frame_sequence,
# "frame": img,
# "osd_info": osd_info,
# "timestamp": current_time_ns,
# "is_corrupted": is_corrupted
# }
# time_ns = time.time_ns()
# if img is not None and osd_info is not None:
# await frame_queue.put((img, osd_info, time_ns))
# timestamp_frame_queue.append({
# "timestamp": time_ns,
# "frame": img
# })
#
# if frame_sequence % 100 == 0: # 每100帧输出一次日志
# logger.debug(f"已处理帧 序列:{frame_sequence} 累计:{frame_count}")
#
# elif frame_queue.full():
# logger.warning("帧队列已满,跳过此帧")
# await asyncio.sleep(0.001) # 短暂等待
#
# # 检查是否需要重新探测分辨率(仅在连续损坏时)
# if consecutive_corrupted_frames > 5:
# logger.warning("连续帧损坏,尝试重新探测分辨率")
# try:
# new_width, new_height = await detect_video_resolution(loop, read_rtmp_frames_executor,
# video_url)
# if new_width and new_height and (new_width != width or new_height != height):
# logger.info(f"分辨率变化: {width}x{height} -> {new_width}x{new_height}")
# width, height = new_width, new_height
# frame_size = width * height * 3
# consecutive_corrupted_frames = 0 # 重置计数
# except Exception as probe_error:
# logger.warning(f"重新探测分辨率失败: {probe_error}")
#
# except Exception as e:
# logger.error(f"读取帧数据时出错: {e}", exc_info=True)
# # 检查 FFmpeg 进程状态
# if ffmpeg_process and ffmpeg_process.poll() is not None:
# try:
# stderr_output = ffmpeg_process.stderr.read().decode('utf-8', errors='ignore')
# if stderr_output:
# logger.error(f"FFmpeg 进程错误: {stderr_output}")
# except:
# pass
# logger.error("FFmpeg 进程已退出")
# break
# # 短暂等待后继续
# await asyncio.sleep(0.01)
# continue
#
# except (subprocess.SubprocessError, IOError) as e:
# logger.error(f"RTMP 流错误 (尝试 {attempt}/{max_retries}): {e}")
# if attempt < max_retries:
# await asyncio.sleep(retry_delay)
# else:
# raise RuntimeError(f"无法连接 RTMP 流 (尝试 {max_retries} 次后失败): {video_url}")
# except asyncio.CancelledError:
# logger.info("read_rtmp_frames 收到取消信号")
# raise
# except Exception as e:
# logger.error(f"未知错误: {e}", exc_info=True)
# if attempt < max_retries:
# await asyncio.sleep(retry_delay)
# finally:
# if ffmpeg_process:
# try:
# ffmpeg_process.terminate()
# try:
# await asyncio.wait_for(
# loop.run_in_executor(read_rtmp_frames_executor, ffmpeg_process.wait),
# timeout=2.0
# )
# except asyncio.TimeoutError:
# ffmpeg_process.kill()
# await loop.run_in_executor(read_rtmp_frames_executor, ffmpeg_process.wait)
# except Exception as e:
# logger.warning(f"关闭FFmpeg进程时出错: {e}")
# logger.info("FFmpeg 进程已关闭")
#
# if frame_count > 0:
# total_time = (time.time_ns() - time_start) / 1e9
# avg_fps = frame_count / total_time if total_time > 0 else 0
# print(f"RTMP流读取完成总帧数: {frame_count}, 总时间: {total_time:.2f}秒, 平均FPS: {avg_fps:.2f}")
# else:
# print("RTMP流读取失败未获取到任何帧")
#
# logger.info(f"RTMP 流已结束或被取消,累计处理帧数: {frame_count}")
#
# async def detect_video_resolution(loop, executor, video_url):
# """
# 探测视频流的分辨率(修复版)
# """
# try:
# logger.info(f"开始探测视频分辨率: {video_url}")
#
# # 方法1: 使用简单的ffprobe命令更可靠
# simple_cmd = [
# 'ffprobe',
# '-v', 'error',
# '-select_streams', 'v:0',
# '-show_entries', 'stream=width,height',
# '-of', 'csv=p=0:s=x',
# video_url
# ]
#
# def run_simple_probe():
# try:
# result = subprocess.run(simple_cmd, capture_output=True, text=True, timeout=15)
# logger.info(f"ffprobe返回码: {result.returncode}, 输出: {result.stdout.strip()}")
#
# if result.returncode == 0 and result.stdout.strip():
# dimensions = result.stdout.strip().split('x')
# if len(dimensions) == 2:
# width = int(dimensions[0])
# height = int(dimensions[1])
# if width > 0 and height > 0:
# return width, height
# except Exception as e:
# logger.warning(f"简单分辨率探测失败: {e}")
# return None
#
# dimensions = await loop.run_in_executor(executor, run_simple_probe)
# if dimensions:
# width, height = dimensions
# logger.info(f"探测到视频分辨率: {width}x{height}")
# return width, height
#
# # 方法2: 使用详细的ffprobe命令
# detailed_cmd = [
# 'ffprobe',
# '-v', 'quiet',
# '-print_format', 'json',
# '-show_streams',
# video_url
# ]
#
# def run_detailed_probe():
# try:
# result = subprocess.run(detailed_cmd, capture_output=True, text=True, timeout=15)
# if result.returncode == 0:
# data = json.loads(result.stdout)
# if 'streams' in data:
# for stream in data['streams']:
# if stream.get('codec_type') == 'video':
# width = stream.get('width')
# height = stream.get('height')
# if width and height:
# return int(width), int(height)
# except Exception as e:
# logger.warning(f"详细分辨率探测失败: {e}")
# return None
#
# dimensions = await loop.run_in_executor(executor, run_detailed_probe)
# if dimensions:
# width, height = dimensions
# logger.info(f"探测到视频分辨率: {width}x{height}")
# return width, height
#
# # 方法3: 尝试使用ffmpeg快速探测
# quick_cmd = [
# 'ffmpeg',
# '-i', video_url,
# '-t', '1', # 只读取1秒
# '-f', 'null',
# '-'
# ]
#
# def run_quick_probe():
# try:
# result = subprocess.run(quick_cmd, capture_output=True, text=True, timeout=10)
# # 从stderr中解析分辨率信息
# if result.stderr:
# import re
# # 尝试从输出中解析分辨率
# resolution_match = re.search(r'(\d+)x(\d+)', result.stderr)
# if resolution_match:
# width = int(resolution_match.group(1))
# height = int(resolution_match.group(2))
# if width > 0 and height > 0:
# return width, height
# except Exception as e:
# logger.warning(f"快速分辨率探测失败: {e}")
# return None
#
# dimensions = await loop.run_in_executor(executor, run_quick_probe)
# if dimensions:
# width, height = dimensions
# logger.info(f"探测到视频分辨率: {width}x{height}")
# return width, height
#
# logger.warning("所有分辨率探测方法都失败,使用默认值 1920x1080")
# return 1920, 1080
#
# except Exception as e:
# logger.error(f"分辨率探测异常: {e}")
# logger.warning("使用默认分辨率 1920x1080")
# return 1920, 1080
import cv2
import asyncio
from typing import Optional
from concurrent.futures import ThreadPoolExecutor
# 使用cv2 拉流避免了ffmpeg 拉流的rtmp延时3s的问题
# async def read_rtmp_frames(
# loop,
# read_rtmp_frames_executor: ThreadPoolExecutor,
# video_url: str,
# device: Optional[MQTTDevice] = None,
# topic_camera_osd: Optional[str] = None,
# method_camera_osd: Optional[str] = None,
# topic_osd_info: Optional[str] = None,
# method_osd_info: Optional[str] = None,
# cancel_flag: Optional[asyncio.Event] = None,
# frame_queue: asyncio.Queue = None,
# timestamp_frame_queue=None
# ):
# """
# 基于 OpenCV+FFmpeg 读取 RTMP 流帧(优化版:高性能读取,处理损坏帧)
# ✅ 核心修改替换原FFmpeg子进程读流为 cv2.VideoCapture 读流,保留所有原有业务逻辑
# ✅ 核心优化:自适应分辨率、低延迟无残影、断流自动重连、超时保护
# """
# max_retries = 20
# retry_delay = 2
# pic_count = 0
# attempt = 0
# time_start = time.time_ns()
# frame_count = 0
#
# if cancel_flag is None:
# cancel_flag = asyncio.Event()
# if timestamp_frame_queue is None:
# timestamp_frame_queue = []
#
# print(f"开始读取RTMP流: {video_url}")
#
# while not cancel_flag.is_set() and attempt < max_retries:
# attempt += 1
# if cancel_flag.is_set():
# logger.info("收到停止信号,终止 RTMP 读取")
# break
#
# cap = None
# width, height = None, None
# stream_fps = None
#
# try:
# logger.info(f"尝试连接 RTMP 流 (尝试 {attempt}/{max_retries}): {video_url}")
#
# # ✅ 核心替换使用cv2.VideoCapture + CAP_FFMPEG 打开RTMP流最优参数配置
# # 切换到线程池执行opencv操作避免阻塞协程
# cap = await loop.run_in_executor(
# read_rtmp_frames_executor,
# lambda: cv2.VideoCapture(video_url, cv2.CAP_FFMPEG)
# )
# # 设置核心参数 - 重中之重,缺一不可
# await loop.run_in_executor(read_rtmp_frames_executor, lambda: cap.set(cv2.CAP_PROP_OPEN_TIMEOUT_MSEC, 60000))
# await loop.run_in_executor(read_rtmp_frames_executor, lambda: cap.set(cv2.CAP_PROP_READ_TIMEOUT_MSEC, 50000))
# await loop.run_in_executor(read_rtmp_frames_executor, lambda: cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)) # 无缓存,低延迟无残影
# await loop.run_in_executor(read_rtmp_frames_executor, lambda: cap.set(cv2.CAP_PROP_FPS, 25))
#
# # 校验流是否成功打开
# is_opened = await loop.run_in_executor(read_rtmp_frames_executor, lambda: cap.isOpened())
# if not is_opened:
# logger.warning(f"尝试 {attempt} 次打开RTMP流失败准备重试")
# await loop.run_in_executor(read_rtmp_frames_executor, lambda: cap.release() if cap else None)
# await asyncio.sleep(retry_delay)
# continue
#
# # ✅ 自适应获取流的【真实分辨率】,无需手动探测,精准无误差
# width = await loop.run_in_executor(read_rtmp_frames_executor, lambda: int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)))
# height = await loop.run_in_executor(read_rtmp_frames_executor, lambda: int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
# stream_fps = await loop.run_in_executor(read_rtmp_frames_executor, lambda: cap.get(cv2.CAP_PROP_FPS))
#
# # 兜底分辨率,防止异常
# if width is None or height is None or width == 0 or height == 0:
# logger.warning("使用默认分辨率 1920x1080")
# width, height = 1920, 1080
#
# logger.info(f"视频分辨率: {width}x{height}, 流帧率: {stream_fps:.1f} FPS")
# logger.info(f"成功启动 OpenCV+FFmpeg 连接 RTMP 流: {video_url}")
#
# # 初始化帧读取状态 (保留原逻辑不变)
# frame_sequence = 0
# last_timestamp = time_start
# consecutive_corrupted_frames = 0 # 连续损坏帧计数
# max_consecutive_corrupted = 10 # 最大连续损坏帧数
#
# while not cancel_flag.is_set():
# try:
# # ✅ 核心替换使用cv2.read()读取帧,线程池执行避免阻塞协程
# ret, frame = await loop.run_in_executor(
# read_rtmp_frames_executor,
# lambda: cap.read()
# )
#
# current_time_ns = time.time_ns()
# frame_sequence += 1
# frame_count += 1
#
# # 处理帧数据(保留原逻辑完全不变,兼容原有的损坏帧处理)
# img = None
# is_corrupted = False
# print(f"读取 read_rtmp_frames 判断")
# try:
# if ret and frame is not None and frame.shape == (height, width, 3):
# # 完整有效帧处理
# img = frame.copy()
# consecutive_corrupted_frames = 0 # 重置连续损坏计数
# else:
# # 损坏帧/空帧处理
# logger.warning(f"帧数据损坏/空帧, 序列: {frame_sequence}")
# is_corrupted = True
# consecutive_corrupted_frames += 1
#
# # 创建替代帧,保留原逻辑
# if consecutive_corrupted_frames <= max_consecutive_corrupted:
# img = np.zeros((height, width, 3), dtype=np.uint8)
# else:
# img = np.zeros((height, width, 3), dtype=np.uint8)
# logger.error(f"连续损坏帧过多 ({consecutive_corrupted_frames}),创建空白帧")
#
# except Exception as frame_error:
# logger.error(f"帧数据处理错误: {frame_error}")
# # 创建空白帧作为后备,保留原逻辑
# img = np.zeros((height, width, 3), dtype=np.uint8)
# is_corrupted = True
# consecutive_corrupted_frames += 1
# print(f"读取 read_rtmp_frames 判断1")
#
# # 获取OSD信息 - 保留原逻辑完全不变
# osd_info = None
# if device and topic_osd_info and method_osd_info:
# try:
# osd_msg = device.get_latest_message(topic=topic_osd_info, method=method_osd_info)
# if osd_msg and hasattr(osd_msg, 'data'):
# osd_info = Air_Attitude(
# gimbal_pitch=osd_msg.data.gimbal_pitch,
# gimbal_roll=osd_msg.data.gimbal_roll,
# gimbal_yaw=osd_msg.data.gimbal_yaw,
# height=osd_msg.data.height,
# latitude=osd_msg.data.latitude,
# longitude=osd_msg.data.longitude
# )
# except Exception as osd_error:
# logger.warning(f"获取OSD信息失败: {osd_error}")
# print(f"读取 read_rtmp_frames 判断2")
#
# # 放入帧队列 - 保留原逻辑完全不变
# if img is not None and not frame_queue.full():
# # 确保时间戳递增
# if current_time_ns <= last_timestamp:
# current_time_ns = last_timestamp + 1
# last_timestamp = current_time_ns
#
# # 统计信息 - 保留原逻辑完全不变
# pic_count += 1
# if current_time_ns - time_start > 1000000000: # 1秒
# elapsed_seconds = (current_time_ns - time_start) / 1e9
# fps = pic_count / elapsed_seconds if elapsed_seconds > 0 else 0
# corrupted_rate = (consecutive_corrupted_frames / pic_count * 100) if pic_count > 0 else 0
# print(
# f"readFrames 序列:{frame_sequence} 帧数:{pic_count} FPS:{fps:.2f} 损坏率:{corrupted_rate:.1f}%")
# pic_count = 0
# time_start = current_time_ns
#
# print(f"读取 read_rtmp_frames 实时流")
#
# # 准备帧数据 - 保留原逻辑完全不变
# frame_data = {
# "sequence": frame_sequence,
# "frame": img,
# "osd_info": osd_info,
# "timestamp": current_time_ns,
# "is_corrupted": is_corrupted
# }
# time_ns = time.time_ns()
# if img is not None and osd_info is not None:
# await frame_queue.put((img, osd_info, time_ns))
# timestamp_frame_queue.append({
# "timestamp": time_ns,
# "frame": img
# })
#
# if frame_sequence % 100 == 0: # 每100帧输出一次日志
# logger.debug(f"已处理帧 序列:{frame_sequence} 累计:{frame_count}")
#
# elif frame_queue.full():
# logger.warning("帧队列已满,跳过此帧")
# await asyncio.sleep(0.001) # 短暂等待
#
# # 连续帧损坏触发重连前置判断 - 保留原逻辑
# if consecutive_corrupted_frames > max_consecutive_corrupted:
# logger.warning(f"连续{consecutive_corrupted_frames}帧损坏,触发流重连逻辑")
# break
#
# except Exception as e:
# logger.error(f"读取帧数据时出错: {e}", exc_info=True)
# break
#
# except Exception as e:
# logger.error(f"RTMP 流错误 (尝试 {attempt}/{max_retries}): {e}")
# if attempt < max_retries:
# await asyncio.sleep(retry_delay)
# else:
# raise RuntimeError(f"无法连接 RTMP 流 (尝试 {max_retries} 次后失败): {video_url}")
# finally:
# # ✅ 释放opencv的VideoCapture资源替代原FFmpeg进程的关闭逻辑
# if cap is not None:
# await loop.run_in_executor(
# read_rtmp_frames_executor,
# lambda: cap.release()
# )
# logger.info("OpenCV VideoCapture 资源已释放")
#
# if frame_count > 0:
# total_time = (time.time_ns() - time_start) / 1e9
# avg_fps = frame_count / total_time if total_time > 0 else 0
# print(f"RTMP流读取完成总帧数: {frame_count}, 总时间: {total_time:.2f}秒, 平均FPS: {avg_fps:.2f}")
# else:
# print("RTMP流读取失败未获取到任何帧")
#
# logger.info(f"RTMP 流已结束或被取消,累计处理帧数: {frame_count}")
# SEI修复核心配置
os.environ["OPENCV_FFMPEG_LOG_LEVEL"] = "ERROR" # 屏蔽SEI错误日志
MAX_RETRIES = 5
RETRY_DELAY = 2
BUFFER_SIZE = 1 # 最小缓冲区减少SEI积压
TARGET_FPS = 25
FOURCC = cv2.VideoWriter_fourcc(*'H264')
MAX_CORRUPTED = 30
#
# async def read_rtmp_frames(
# loop,
# read_rtmp_frames_executor: ThreadPoolExecutor,
# video_url: str,
# device: Optional[MQTTDevice] = None,
# topic_camera_osd: Optional[str] = None,
# method_camera_osd: Optional[str] = None,
# topic_osd_info: Optional[str] = None,
# method_osd_info: Optional[str] = None,
# cancel_flag: Optional[asyncio.Event] = None,
# frame_queue: asyncio.Queue = None,
# timestamp_frame_queue: TimestampedQueue = None
# ):
# """优化版RTMP流读取集成SEI修复+完整逻辑)
# 核心修复:
# 1. 屏蔽FFmpeg SEI截断日志
# 2. 精简OpenCV参数仅保留Python支持的核心配置
# 3. 增强帧格式校验和异常处理
# 4. 修复事件循环嵌套运行的致命错误
# 5. 优化重连机制和SEI帧跳过逻辑
# """
# print(f"开始读取RTMP流: {video_url}")
#
# # ✅ 关键修复1设置FFmpeg全局参数屏蔽SEI帧截断日志
# os.environ["OPENCV_FFMPEG_LOG_LEVEL"] = "ERROR" # 只输出致命错误屏蔽SEI相关警告
#
# def ensure_cv8uc3(frame):
# """确保帧格式为CV_8UC3增强版修复"""
# if frame is None:
# return None
# if frame.dtype != np.uint8:
# frame = frame.astype(np.uint8)
# if len(frame.shape) == 2: # 灰度图转彩色
# frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
# elif frame.shape[2] == 4: # RGBA 转 BGR
# frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2BGR)
# return frame
#
# # ✅ 关键修复2将init_capture改为同步函数核心解决事件循环冲突
# def init_capture(attempt: int = 1):
# """初始化捕获器同步版本带SEI修复参数+重连逻辑)"""
# print(f"第 {attempt}/{MAX_RETRIES} 次尝试初始化RTMP捕获器")
#
# # 指定FFmpeg后端创建捕获器
# cap = cv2.VideoCapture(video_url, cv2.CAP_FFMPEG)
# if not cap.isOpened():
# raise RuntimeError(f"无法打开RTMP流 (第{attempt}次尝试)")
#
# # 仅保留Python版OpenCV支持的核心参数
# cap.set(cv2.CAP_PROP_BUFFERSIZE, BUFFER_SIZE) # 缓冲区设为1减少SEI帧积压
# cap.set(cv2.CAP_PROP_FOURCC, FOURCC) # 指定H264解码器减少SEI解析开销
# cap.set(cv2.CAP_PROP_FPS, TARGET_FPS) # 同步流帧率,避免冗余处理
#
# # 获取分辨率(增加异常处理)
# try:
# width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
# height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
# if width <= 0 or height <= 0:
# width, height = 1280, 720 # 默认分辨率兜底
# except Exception as e:
# print(f"获取分辨率失败,使用默认值: {e}")
# width, height = 1280, 720
#
# print(f"RTMP捕获器初始化成功分辨率: {width}x{height}")
# return cap, (width, height)
#
# # 初始化捕获器(支持重连)
# cap = None
# width, height = 1280, 720
# for attempt in range(1, MAX_RETRIES + 1):
# try:
# # ✅ 关键修复3直接在线程池执行同步的init_capture不再嵌套事件循环
# cap, (width, height) = await loop.run_in_executor(
# read_rtmp_frames_executor,
# init_capture, # 直接传函数不再用lambda嵌套loop.run_until_complete
# attempt # 传递attempt参数
# )
# break
# except RuntimeError as e:
# print(f"初始化失败: {e}")
# if attempt >= MAX_RETRIES:
# raise RuntimeError(f"所有{MAX_RETRIES}次初始化尝试均失败")
# await asyncio.sleep(RETRY_DELAY)
#
# try:
# last_valid_frame = np.zeros((height, width, 3), dtype=np.uint8)
# consecutive_corrupted_frames = 0
# frame_count = 0
# time_start = time.time_ns()
#
# while not cancel_flag.is_set():
# try:
# # 读取帧(使用线程池避免阻塞事件循环)
# ret, frame = await loop.run_in_executor(
# read_rtmp_frames_executor,
# cap.read # 直接传递方法,更简洁
# )
#
# # 增强SEI帧/损坏帧处理逻辑
# current_frame = last_valid_frame.copy()
# if ret and frame is not None and frame.size > 0:
# # 正常帧:格式转换 + 更新兜底帧
# processed_frame = ensure_cv8uc3(frame)
# if processed_frame is not None:
# current_frame = processed_frame
# last_valid_frame = current_frame.copy()
# consecutive_corrupted_frames = 0
# else:
# consecutive_corrupted_frames += 1
# else:
# # SEI帧/损坏帧:使用兜底帧 + 计数
# consecutive_corrupted_frames += 1
# if consecutive_corrupted_frames % 15 == 0:
# print(f"跳过SEI/损坏帧 (连续: {consecutive_corrupted_frames})")
#
# # 获取OSD信息保留原逻辑
# osd_info = None
# if device and topic_osd_info and method_osd_info:
# try:
# osd_msg = await loop.run_in_executor(
# read_rtmp_frames_executor,
# device.get_latest_message,
# topic_osd_info,
# method_osd_info
# )
# if osd_msg and hasattr(osd_msg, 'data'):
# osd_info = Air_Attitude(
# gimbal_pitch=osd_msg.data.gimbal_pitch,
# gimbal_roll=osd_msg.data.gimbal_roll,
# gimbal_yaw=osd_msg.data.gimbal_yaw,
# height=osd_msg.data.height,
# latitude=osd_msg.data.latitude,
# longitude=osd_msg.data.longitude
# )
# except Exception as e:
# print(f"获取OSD信息失败: {str(e)}")
#
# # 时间戳和队列处理(优化超时逻辑)
# timestamp = time.time_ns()
# if frame_queue is not None and timestamp_frame_queue is not None:
# try:
# # 非阻塞放入队列,避免长时间等待
# if not frame_queue.full():
# await asyncio.wait_for(
# frame_queue.put((current_frame, osd_info, timestamp)),
# timeout=0.01
# )
# timestamp_frame_queue.append({
# "timestamp": timestamp,
# "frame": current_frame
# })
# frame_count += 1
#
# # 每秒打印一次帧数统计
# if timestamp - time_start > 1000000000:
# fps = frame_count / ((timestamp - time_start) / 1000000000)
# print(f"读取帧数: {frame_count} | 实时FPS: {fps:.2f}")
# frame_count = 0
# time_start = timestamp
# else:
# # 队列满时短暂休眠避免CPU占用过高
# await asyncio.sleep(0.001)
# except asyncio.TimeoutError:
# print("帧队列已满,跳过此帧")
#
# # 连续损坏帧触发重连逻辑(增强版)
# if consecutive_corrupted_frames > MAX_CORRUPTED:
# print(f"连续{MAX_CORRUPTED}帧异常,尝试重新初始化捕获器")
# # 释放旧捕获器
# if cap and cap.isOpened():
# await loop.run_in_executor(
# read_rtmp_frames_executor,
# cap.release
# )
# # 重新初始化
# reconnected = False
# for attempt in range(1, MAX_RETRIES + 1):
# try:
# cap, (width, height) = await loop.run_in_executor(
# read_rtmp_frames_executor,
# init_capture,
# attempt
# )
# consecutive_corrupted_frames = 0
# last_valid_frame = np.zeros((height, width, 3), dtype=np.uint8)
# reconnected = True
# break
# except RuntimeError as e:
# print(f"重连失败 (第{attempt}次): {e}")
# if attempt >= MAX_RETRIES:
# raise RuntimeError("重连次数超限,停止尝试")
# await asyncio.sleep(RETRY_DELAY)
# if not reconnected:
# break
#
# except Exception as e:
# print(f"帧处理错误: {str(e)}")
# await asyncio.sleep(0.1)
#
# except asyncio.CancelledError:
# print("读取任务已取消")
# except Exception as e:
# print(f"流读取异常: {str(e)}")
# # 异常恢复机制(增强版)
# if not cancel_flag.is_set():
# await asyncio.sleep(RETRY_DELAY)
# # 重新初始化捕获器
# reconnected = False
# for attempt in range(1, MAX_RETRIES + 1):
# try:
# cap, (width, height) = await loop.run_in_executor(
# read_rtmp_frames_executor,
# init_capture,
# attempt
# )
# reconnected = True
# break
# except RuntimeError as e:
# print(f"异常恢复重连失败 (第{attempt}次): {e}")
# if attempt >= MAX_RETRIES:
# raise
# await asyncio.sleep(RETRY_DELAY)
# if reconnected:
# # 恢复后继续运行(递归调用自身,保持逻辑完整)
# await read_rtmp_frames(
# loop=loop,
# read_rtmp_frames_executor=read_rtmp_frames_executor,
# video_url=video_url,
# device=device,
# topic_camera_osd=topic_camera_osd,
# method_camera_osd=method_camera_osd,
# topic_osd_info=topic_osd_info,
# method_osd_info=method_osd_info,
# cancel_flag=cancel_flag,
# frame_queue=frame_queue,
# timestamp_frame_queue=timestamp_frame_queue
# )
# finally:
# # 确保资源释放
# if cap:
# await loop.run_in_executor(
# read_rtmp_frames_executor,
# lambda: cap.release() if cap.isOpened() else None
# )
# print("RTMP流读取已停止")
def init_capture_with_sei_fix(video_url: str, attempt: int = 1):
"""
修复SEI错误的VideoCapture初始化
"""
print(f"\n===== 第 {attempt}/{MAX_RETRIES} 次尝试连接 =====")
# ✅ 关键设置FFmpeg全局参数屏蔽SEI帧截断日志
os.environ["OPENCV_FFMPEG_LOG_LEVEL"] = "ERROR" # 只输出FFmpeg致命错误屏蔽警告包括SEI截断
# 初始化VideoCapture指定FFmpeg后端
cap = cv2.VideoCapture(video_url, cv2.CAP_FFMPEG)
if not cap.isOpened():
raise RuntimeError(f"无法打开RTMP流 (第{attempt}次尝试)")
# 设置核心参数
cap.set(cv2.CAP_PROP_BUFFERSIZE, BUFFER_SIZE) # 小缓冲区,实时推帧
cap.set(cv2.CAP_PROP_FOURCC, FOURCC) # 指定H264解码器
cap.set(cv2.CAP_PROP_FPS, TARGET_FPS) # 同步流帧率
# 获取流分辨率
try:
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
except:
width, height = 1440, 1080 # 默认分辨率
print(f"拉流成功:分辨率 {width}x{height}")
return cap, (width, height)
def ensure_cv8uc3(frame):
"""确保帧格式为8位3通道BGR"""
if frame is None or frame.size == 0:
return None
if frame.dtype != np.uint8:
frame = frame.astype(np.uint8)
if len(frame.shape) == 2: # 灰度图转彩色
frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
elif frame.shape[2] == 4: # RGBA转BGR
frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2BGR)
return frame
async def read_rtmp_frames(
loop,
executor: ThreadPoolExecutor,
video_url: str,
device: Optional[MQTTDevice] = None,
topic_camera_osd: Optional[str] = None,
method_camera_osd: Optional[str] = None,
topic_osd_info: Optional[str] = None,
method_osd_info: Optional[str] = None,
cancel_flag: Optional[asyncio.Event] = None,
frame_queue: asyncio.Queue = None,
timestamp_frame_queue: TimestampedQueue = None
):
"""异步RTMP读取修复SEI错误"""
print(f"开始测试 RTMP 拉流(验证 SEI 丢包问题): {video_url}")
print("测试说明1. 控制台无 SEI 截断警告即为正常 2. 观察 imshow 窗口画面是否稳定 3. 按 'q' 退出")
if cancel_flag is None:
cancel_flag = asyncio.Event()
# 创建预览队列
preview_task = None
# 初始化捕获器
cap = None
width, height = 1280, 720
for attempt in range(1, MAX_RETRIES + 1):
try:
cap, (width, height) = await loop.run_in_executor(
executor,
init_capture_with_sei_fix,
video_url,
attempt
)
break
except RuntimeError as e:
print(f"{attempt} 次尝试失败:{e}{RETRY_DELAY} 秒后重试")
if attempt >= MAX_RETRIES:
raise RuntimeError(f"所有{MAX_RETRIES}次初始化尝试均失败")
await asyncio.sleep(RETRY_DELAY)
try:
last_valid_frame = np.zeros((height, width, 3), dtype=np.uint8)
consecutive_corrupted_frames = 0
frame_count = 0
fps_count = 0
time_start = time.time_ns()
fps_start = time.time()
print(f"拉流成功:分辨率 {width}x{height},开始预览(按 'q' 退出)")
while not cancel_flag.is_set():
try:
# 异步读取帧
ret, frame = await loop.run_in_executor(executor, cap.read)
# 帧有效性判断 + 格式转换
current_frame = last_valid_frame.copy()
if ret and frame is not None and frame.size > 0:
frame = ensure_cv8uc3(frame)
current_frame = frame.copy()
last_valid_frame = current_frame
consecutive_corrupted_frames = 0
frame_count += 1
fps_count += 1
else:
consecutive_corrupted_frames += 1
if consecutive_corrupted_frames % 15 == 0:
print(f"跳过 SEI 帧/临时解码异常,连续异常帧:{consecutive_corrupted_frames}(画面稳定)")
# # 放入预览队列
# if enable_preview and preview_queue is not None:
# try:
# # 清空旧帧,只保留最新的
# while not preview_queue.empty():
# try:
# preview_queue.get_nowait()
# except asyncio.QueueEmpty:
# break
#
# if not preview_queue.full():
# await asyncio.wait_for(preview_queue.put(current_frame), timeout=0.001)
# except (asyncio.QueueFull, asyncio.TimeoutError):
# pass
# OSD信息
osd_info = None
if device and topic_osd_info and method_osd_info:
try:
osd_msg = await loop.run_in_executor(
executor, device.get_latest_message, topic_osd_info, method_osd_info
)
if osd_msg and hasattr(osd_msg, 'data'):
osd_info = Air_Attitude(
gimbal_pitch=osd_msg.data.gimbal_pitch,
gimbal_roll=osd_msg.data.gimbal_roll,
gimbal_yaw=osd_msg.data.gimbal_yaw,
height=osd_msg.data.height,
latitude=osd_msg.data.latitude,
longitude=osd_msg.data.longitude
)
except Exception as e:
if consecutive_corrupted_frames < 5: # 减少日志输出
print(f"获取OSD信息失败: {str(e)}")
# 帧数统计
timestamp = time.time_ns()
if frame_queue is not None and timestamp_frame_queue is not None and ret:
try:
if not frame_queue.full():
await asyncio.wait_for(
frame_queue.put((current_frame, osd_info, timestamp)),
timeout=0.01
)
timestamp_frame_queue.append({"timestamp": timestamp, "frame": current_frame})
except (asyncio.QueueFull, asyncio.TimeoutError):
pass
# 实时FPS统计
fps_elapsed = time.time() - fps_start
if fps_elapsed >= 1.0:
fps = fps_count / fps_elapsed
timestamp_elapsed = (timestamp - time_start) / 1000000000
if timestamp_elapsed > 0:
timestamp_fps = frame_count / timestamp_elapsed
print(f"实时 FPS{fps:.2f} | 累计FPS{timestamp_fps:.2f} | 分辨率:{width}x{height}")
fps_count = 0
fps_start = time.time()
# 连续损坏帧重连
if consecutive_corrupted_frames > MAX_CORRUPTED:
print(f"连续 {MAX_CORRUPTED} 帧异常,停止预览,尝试重连")
if cap and cap.isOpened():
await loop.run_in_executor(executor, cap.release)
reconnected = False
for retry_attempt in range(1, MAX_RETRIES + 1):
try:
cap, (width, height) = await loop.run_in_executor(
executor, init_capture_with_sei_fix, video_url, retry_attempt
)
consecutive_corrupted_frames = 0
last_valid_frame = np.zeros((height, width, 3), dtype=np.uint8)
reconnected = True
print(f"重连成功 (第{retry_attempt}次)")
break
except RuntimeError as e:
print(f"重连失败 (第{retry_attempt}次): {e}")
if retry_attempt >= MAX_RETRIES:
print("重连次数超限,退出")
cancel_flag.set()
await asyncio.sleep(RETRY_DELAY)
if not reconnected:
break
# 小幅延迟避免CPU占用过高
await asyncio.sleep(0.001)
except asyncio.CancelledError:
print("读取任务被取消")
break
except Exception as e:
print(f"帧处理错误: {str(e)}")
await asyncio.sleep(0.1)
except asyncio.CancelledError:
print("读取任务已取消")
except Exception as e:
print(f"流读取异常: {str(e)}")
finally:
# 资源释放
if cap:
await loop.run_in_executor(executor, lambda: cap.release() if cap.isOpened() else None)
print("释放VideoCapture资源")
if preview_task is not None:
preview_task.cancel()
try:
await preview_task
except asyncio.CancelledError:
pass
except Exception:
pass
await loop.run_in_executor(executor, cv2.destroyAllWindows)
print("RTMP流读取已停止")
#
# async def read_rtmp_frames_skip_sei(
# loop,
# read_rtmp_frames_executor: ThreadPoolExecutor,
# video_url: str,
# device: Optional[MQTTDevice] = None,
# topic_camera_osd: Optional[str] = None,
# method_camera_osd: Optional[str] = None,
# topic_osd_info: Optional[str] = None,
# method_osd_info: Optional[str] = None,
# cancel_flag: Optional[asyncio.Event] = None,
# frame_queue: asyncio.Queue = None,
# timestamp_frame_queue=None
# ):
# """
# 优化版(兼容 Python OpenCV + 屏蔽 SEI 帧错误)
# ✅ 核心修复1. 移除 Python 不支持的 CAP_PROP_OUTPUT_FORMAT 2. 屏蔽 FFmpeg SEI 截断日志
# ✅ 核心实现:配置解码器参数忽略 SEI 帧,保证实时拉流无性能损耗,保留异步逻辑和队列功能
# ✅ 解决问题:[h264 @ xxx] SEI type 245 truncated 错误,同时保证拉流实时性和有效性
# """
# max_retries = 20
# retry_delay = 3
# pic_count = 0
# attempt = 0
# time_start = cv2.getTickCount()
# frame_count = 0
#
# # ✅ 关键优化:设置 FFmpeg 全局参数,屏蔽 SEI 帧截断日志Python OpenCV 间接控制 FFmpeg
# # 禁止 FFmpeg 输出冗余警告,仅保留致命错误,彻底隐藏 SEI type 245 truncated 信息
# os.environ["OPENCV_FFMPEG_LOG_LEVEL"] = "ERROR"
#
# # 核心拉流参数(沿用同步测试代码的有效配置)
# BUFFER_SIZE = 1 # 减少缓存,不积压 SEI 帧
# TARGET_FPS = 30 # 适配 RTMP 流帧率
# FOURCC = cv2.VideoWriter_fourcc(*'H264') # 兼容 YUV420P 像素格式,减少 SEI 解析冗余
#
# # 初始化默认参数
# if cancel_flag is None:
# cancel_flag = asyncio.Event()
# if timestamp_frame_queue is None:
# timestamp_frame_queue = []
# if frame_queue is None:
# frame_queue = asyncio.Queue(maxsize=5) # 小队列减少缓存,保证实时性
#
# print(f"开始读取 RTMP 流OpenCV 优化版,跳过 SEI 帧错误): {video_url}")
#
# while not cancel_flag.is_set() and attempt < max_retries:
# attempt += 1
# if cancel_flag.is_set():
# logger.info("收到停止信号,终止 RTMP 读取")
# break
#
# cap = None
# consecutive_corrupted_frames = 0
# max_consecutive_corrupted = 30
# last_valid_frame = None
#
# try:
# logger.info(f"尝试连接 RTMP 流 (尝试 {attempt}/{max_retries}): {video_url}")
#
# # ✅ 步骤1初始化 VideoCapture指定 FFmpeg 后端,修复 Python 兼容问题
# def init_capture_with_sei_skip():
# # 1. 指定 FFmpeg 后端,启用高级参数配置
# cap = cv2.VideoCapture(video_url, cv2.CAP_FFMPEG)
# if not cap.isOpened():
# return None
#
# # ✅ 步骤2配置解码器参数仅保留 Python 支持的核心参数,移除 CAP_PROP_OUTPUT_FORMAT
# # 2.1 关键配置:缓冲区大小设为 1实时推帧不缓存 SEI 帧
# cap.set(cv2.CAP_PROP_BUFFERSIZE, BUFFER_SIZE)
#
# # 2.2 配置像素格式,确保与硬件兼容(对应 PIX_FMT_YUV420P
# cap.set(cv2.CAP_PROP_FOURCC, FOURCC)
#
# # 2.3 配置帧率,与流帧率同步,避免冗余处理
# cap.set(cv2.CAP_PROP_FPS, TARGET_FPS)
#
# # ✅ 步骤3可选启用硬件加速保留原逻辑增加异常兼容
# try:
# # 启用 CUDA 硬件加速(对应 cv::CAP_PROP_HW_ACCELERATION
# cap.set(cv2.CAP_PROP_HW_ACCELERATION, cv2.VIDEO_ACCELERATION_CUDA)
# logger.info("已启用 CUDA 硬件加速,降低 SEI 帧解析负担")
# except (AttributeError, cv2.error):
# try:
# # 备用:启用 VA-API 硬件加速Intel 平台)
# cap.set(cv2.CAP_PROP_HW_ACCELERATION, cv2.VIDEO_ACCELERATION_VAAPI)
# logger.info("已启用 VA-API 硬件加速,降低 SEI 帧解析负担")
# except (AttributeError, cv2.error):
# logger.info("未检测到硬件加速模块,使用软件解码(仍可跳过 SEI 帧错误)")
#
# # ✅ 步骤4替代 CAP_PROP_OUTPUT_FORMAT定义帧格式转换函数确保 CV_8UC3 格式)
# global ensure_cv8uc3 # 全局声明,方便后续帧处理调用
# def ensure_cv8uc3(frame):
# if frame.dtype != np.uint8:
# frame = frame.astype(np.uint8)
# if len(frame.shape) == 2: # 灰度图转彩色
# frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
# elif frame.shape[2] == 4: # RGBA 转 BGR
# frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2BGR)
# return frame
#
# # ✅ 步骤5获取流分辨率初始化兜底帧增加容错防止获取失败
# try:
# width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
# height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
# except:
# width, height = 1280, 720 # 默认分辨率,提升鲁棒性
# nonlocal last_valid_frame
# last_valid_frame = np.zeros((height, width, 3), dtype=np.uint8)
#
# return cap
#
# # 在线程池中初始化捕获器,避免阻塞事件循环
# cap = await loop.run_in_executor(
# read_rtmp_frames_executor,
# init_capture_with_sei_skip
# )
#
# if cap is None or not cap.isOpened():
# logger.warning(f"尝试 {attempt} 次OpenCV 捕获器初始化失败,准备重试")
# await asyncio.sleep(retry_delay)
# continue
#
# # 获取有效分辨率,打印启动日志
# try:
# width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
# height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
# except:
# width, height = 1280, 720
# logger.info(f"OpenCV 捕获器已启动(跳过 SEI 帧),分辨率:{width}x{height}")
# logger.info("解码器已配置:忽略 SEI 帧,仅处理核心视频帧,无截断错误日志")
#
# # ✅ 步骤6读取视频流跳过 SEI 帧影响(整合格式转换,保留异步逻辑)
# while not cancel_flag.is_set():
# # 读取一帧(解码器自动跳过 SEI 帧,仅返回核心解码帧)
# def read_frame_core():
# ret, frame = cap.read()
# return ret, frame
#
# ret, frame = await loop.run_in_executor(
# read_rtmp_frames_executor,
# read_frame_core
# )
#
# current_time = cv2.getTickCount()
# frame_count += 1
# img = None
# is_corrupted = False
#
# # ✅ 流程判断:帧有效性检查 + 格式转换(替代 CAP_PROP_OUTPUT_FORMAT
# if ret and frame is not None and frame.size > 0:
# # 非 SEI 帧,解码成功,转换为 CV_8UC3 格式
# frame = ensure_cv8uc3(frame)
# img = frame.copy()
# last_valid_frame = img.copy()
# consecutive_corrupted_frames = 0
# is_corrupted = False
# else:
# # SEI 帧或解码异常,使用兜底帧(避免解码器状态异常)
# img = last_valid_frame.copy()
# is_corrupted = True
# consecutive_corrupted_frames += 1
# if consecutive_corrupted_frames % 15 == 0:
# logger.warning(f"跳过 SEI 帧/解码异常,连续损坏帧: {consecutive_corrupted_frames}(解码器状态稳定)")
#
# # 强制兜底,保证帧非空
# if img is None or img.size == 0:
# img = np.zeros((height, width, 3), dtype=np.uint8)
#
# # 获取 OSD 信息(保留原逻辑,实时性优先)
# osd_info = None
# if device and topic_osd_info and method_osd_info:
# try:
# osd_msg = device.get_latest_message(topic=topic_osd_info, method=method_osd_info)
# if osd_msg and hasattr(osd_msg, 'data'):
# osd_info = Air_Attitude(
# gimbal_pitch=osd_msg.data.gimbal_pitch,
# gimbal_roll=osd_msg.data.gimbal_roll,
# gimbal_yaw=osd_msg.data.gimbal_yaw,
# height=osd_msg.data.height,
# latitude=osd_msg.data.latitude,
# longitude=osd_msg.data.longitude
# )
# except Exception as osd_error:
# pass
#
# # 实时放入队列,满队列丢弃旧帧(保证实时性,保留原逻辑)
# if img is not None:
# try:
# if frame_queue.full():
# frame_queue.get_nowait()
# current_time_ns = time.time_ns()
# await frame_queue.put((img, osd_info, current_time_ns), timeout=0.001)
# except asyncio.TimeoutError:
# pass
#
# # 统计实时 FPS验证性能无额外 CPU 开销)
# pic_count += 1
# tick_count = cv2.getTickCount() - time_start
# elapsed_seconds = tick_count / cv2.getTickFrequency()
# if elapsed_seconds >= 1.0:
# fps = pic_count / elapsed_seconds
# print(f"实时统计FPS:{fps:.2f} 分辨率:{img.shape[1]}x{img.shape[0]} 无 SEI 帧截断错误日志")
# pic_count = 0
# time_start = cv2.getTickCount()
#
# # 连续损坏帧触发重连(解码器状态异常时重试)
# if consecutive_corrupted_frames > max_consecutive_corrupted:
# logger.warning(f"连续{consecutive_corrupted_frames}帧异常,触发捕获器重连")
# break
#
# except Exception as e:
# logger.error(f"RTMP 流错误 (尝试 {attempt}/{max_retries}): {e}")
# if attempt < max_retries:
# await asyncio.sleep(retry_delay)
# else:
# raise RuntimeError(f"无法连接 RTMP 流 (尝试 {max_retries} 次后失败): {video_url}")
# finally:
# # 安全释放捕获器资源
# if cap is not None and hasattr(cap, 'isOpened') and cap.isOpened():
# cap.release()
# logger.info("OpenCV 捕获器资源已释放")
#
# # 统计最终结果
# total_elapsed = (cv2.getTickCount() - time_start) / cv2.getTickFrequency()
# if frame_count > 0 and frame_count > max_retries:
# avg_fps = frame_count / total_elapsed if total_elapsed > 0 else 0
# print(f"\nRTMP 流读取完成,总有效帧数: {frame_count}, 平均 FPS: {avg_fps:.2f},无 SEI 帧截断错误日志")
# else:
# print("\nRTMP 流读取失败,未获取到有效帧")
#
# logger.info(f"RTMP 流已结束,累计处理帧数: {frame_count}")
# async def process_frames(detector: MultiYOLODetector):
# async def process_frames(detector: MultiYOLODetector_TrackId, cancel_flag: asyncio.Event,
# frame_queue: asyncio.Queue, processed_queue: asyncio.Queue):
async def process_frames(detector: MultiYoloTrtDetectorTrackId, cancel_flag: asyncio.Event,
frame_queue: asyncio.Queue, processed_queue: asyncio.Queue):
"""协程处理帧队列"""
start_time = time.time()
time_start = time.time_ns()
pic_count = 0
while not cancel_flag.is_set():
frame_start = time.time()
try:
frame, osd_info, timestamp = await asyncio.wait_for(
frame_queue.get(),
timeout=0.5 # 延长超时,适配处理耗时
)
try:
time_pr_start = time.time_ns()
detections, detections_list, model_para = await detector.predict(frame)
time_pr_end = time.time_ns()
print(f"time_pr_starttime_pr_start {(time_pr_end - time_pr_start) / 1000000}")
predict_state = True
if detections:
print("检测到任何目标")
if not detections:
predict_state = False
logger.debug("未检测到任何目标")
# continue
# # # # 显示帧用于调试(可选)
# cv2.imshow('process_frames', frame)
# if cv2.waitKey(1) & 0xFF == ord('q'):
# stop_event.set()
# break
processed_data = {
'frame': frame,
'osd_info': osd_info,
'detections': detections,
'detections_list': detections_list,
'timestamp': timestamp,
'model_para': model_para,
'predict_state': predict_state # predict状态判定方便rtmp推流做状态判定
}
if not processed_queue.full():
time_end = time.time_ns()
pic_count = pic_count + 1
if time_end - time_start > 1000000000:
print(f"processframes {pic_count}")
pic_count = 0
time_start = time_end
await processed_queue.put(processed_data)
else:
logger.warning("处理队列已满,丢弃帧")
stats['dropped_frames'] += 1
if 'frame' in locals():
frame_time = time.time() - frame_start
# print(f"处理帧耗时: {frame_time:.4f}s")
except Exception as e:
logger.error(f"处理帧时出错: {e}", exc_info=True)
stats['dropped_frames'] += 1
await asyncio.sleep(0.1)
except asyncio.TimeoutError:
continue
except asyncio.CancelledError:
print("process_frames 收到取消信号")
raise
except Exception as e:
logger.error(f"获取帧时发生意外错误: {e}", exc_info=True)
await asyncio.sleep(0.1)
# finally:
# log_perf('process_frames', start_time)
print("process_frames读取线程已停止")
from PIL import Image, ImageDraw, ImageFont
chinese_font = None
try:
chinese_font = ImageFont.truetype("config/SIMSUN.TTC", 20)
except:
chinese_font = ImageFont.load_default()
def put_chinese_text(img, text, position, font=chinese_font, color=(0, 255, 0)):
"""使用预加载的字体绘制中文"""
img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
draw = ImageDraw.Draw(img_pil)
draw.text(position, text, font=font, fill=color)
return cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
# # 全局变量存储推流容器
# stream_containers: Dict[str, Any] = {}
from collections import defaultdict
import time
from typing import Dict, Any
class TrackIDEventFilter:
def __init__(self, max_inactive_time: float = 5.0):
"""
初始化 TrackID 过滤器
:param max_inactive_time: 最大不活跃时间(秒),超时未出现的 track_id 会被移除
"""
self.track_status = defaultdict(int) # 记录每个 track_id 的连续识别次数
self.last_active_time = defaultdict(float) # 记录每个 track_id 最后一次出现的时间
self.max_inactive_time = max_inactive_time # 超时清理阈值
def should_report(self, track_id: int) -> bool:
"""
判断是否应该上报事件
:param track_id: 跟踪ID
:return: True 表示应该上报False 表示过滤
"""
current_time = time.time()
# for tid, last_time in self.last_active_time.items():
# print(f"tidtidtidtid {tid} {current_time} {last_time} {current_time-last_time} {self.max_inactive_time}")
# 1. 清理长时间不活跃的 track_id
inactive_tracks = [
tid for tid, last_time in self.last_active_time.items()
if (current_time - last_time) > self.max_inactive_time
]
for tid in inactive_tracks:
print(f"inactive_tracksinactive_tracksinactive_tracks{tid}")
self.track_status.pop(tid, None)
self.last_active_time.pop(tid, None)
# 2. 更新最后活跃时间
self.last_active_time[track_id] = current_time
show_count = self.track_status[track_id]
print(f"show_countshow_countshow_count {track_id} {show_count}")
# 3. 过滤逻辑
if show_count == 0:
# 首次出现
self.track_status[track_id] = 1
print(f"track_status[track_id] = 1 {self.track_status[track_id]} : {track_id}")
return True
elif show_count > 0:
# 连续出现第二次,上报并重置状态(需要再次连续出现才会触发)
self.track_status[track_id] = show_count + 1 # 重置为 0下次需重新连续出现两次
print(f"track_status[track_id] =show_count + 1 {self.track_status[track_id]} : {track_id}")
return False
async def write_results_to_rtmp(task_id: str, output_url: str = None, input_fps: float = None,
list_points: list[list[any]] = None, camera_para: Camera_Para = None,
invade_state: bool = False, cancel_flag: asyncio.Event = None,
processed_queue: asyncio.Queue = None, invade_queue: asyncio.Queue = None,
cv_frame_queue: asyncio.Queue = None, stream_containers: Dict[str, Any] = None):
# global stream_containers, count_pic
start_time = time.time()
time_start = time.time_ns()
pic_count = 0
# 修改推流参数
options = {
'preset': 'veryfast',
'tune': 'zerolatency',
'crf': '23',
'g': '50', # 关键帧间隔
'threads': '2', # 限制编码线程
}
codec_name = 'libx264'
max_retries = 3
retry_delay = 2.0
# 初始化视频输出
output_video_path = None
video_writer = None
frame_width, frame_height = None, None
fps = input_fps or Config.TARGET_FPS
last_frame_time = time.time() - 1
frame_interval = 1.0 / fps
try:
while not cancel_flag.is_set():
frame_start = time.time()
try:
# 第一层帧率控制
# current_time = time.time()
#
# time_diff = frame_interval - (current_time - last_frame_time)
# if time_diff > 0:
# await asyncio.sleep(time_diff)
# last_frame_time = current_time
processed_data = await asyncio.wait_for(
processed_queue.get(),
timeout=1
)
# 确保 processed_data 是字典
if not isinstance(processed_data, dict):
print(f"❌ 错误processed_data 不是字典,而是 {type(processed_data)}")
continue
frame = processed_data['frame']
# 绘制检测结果
frame_copy = frame.copy()
predict_state = processed_data['predict_state']
osd_info = processed_data['osd_info']
img_height, img_width = frame.shape[:2]
results = []
results_list = []
# 启用侵限且拿到了飞机的姿态信息,再绘制红线
if invade_state and osd_info:
gimbal_yaw = osd_info.gimbal_yaw
gimbal_pitch = osd_info.gimbal_pitch
gimbal_roll = osd_info.gimbal_roll
height = osd_info.height
# print(f"heightheightheight {height}")
cam_longitude = osd_info.longitude
cam_latitude = osd_info.latitude
# 当前list_points 虽然是二维数组,但是只存了一个,后续根据业务变化
for points in list_points:
# 批量返回图像的像素坐标
point_list = []
results = red_line_reproject(gimbal_yaw, gimbal_pitch, gimbal_roll, height, cam_longitude,
cam_latitude,
img_width,
img_height, points, camera_para)
if results:
results_list.append(results) # 支持两个区域,高压侵限、营业线侵限
for point in results:
point_list.append([point["u"], point["v"]])
cv2.polylines(frame_copy, [np.array(point_list, dtype=np.int64)], isClosed=True,
color=(0, 0, 255),
thickness=2)
# print(f"predict_statepredict_state {predict_state}")
# 模型输出了推理结果
if predict_state:
# 测试代码,用做测试推理结果,初始化视频写入器(如果尚未初始化)
if video_writer is None and output_video_path:
frame_width = frame.shape[1]
frame_height = frame.shape[0]
# 定义视频编码器和输出文件
fourcc = cv2.VideoWriter_fourcc(*'mp4v') # 或者使用 'avc1' 等其他编码
# fps = Config.TARGET_FPS # 使用配置中的目标帧率
video_writer = cv2.VideoWriter(
output_video_path,
fourcc,
fps,
(frame_width, frame_height)
)
print(f"视频写入器已初始化,分辨率: {frame_width}x{frame_height}, FPS: {fps}")
detections = processed_data['detections']
detections_list = processed_data['detections_list']
model_para = processed_data['model_para']
class_names = model_para[0]["model_class_names"]
chinese_label = model_para[0]["model_chinese_labe"]
cls_map = model_para[0]["cls_map"]
para_invade_enable = model_para[0]["para_invade_enable"]
model_list_func_id = model_para[0]["model_list_func_id"]
model_func_id = model_para[0]["func_id"]
invade_point = []
message_point = []
target_point = [] # 存储满足条件的图像坐标,方便后续经纬度转换
cls_count = 0
# # 初始化统计字典
class_stats = defaultdict(int)
reversed_dict = {value: 0 for value in cls_map.values()}
bg_color = (173, 216, 230) # 文本框底色使用淡蓝色
text_color = (0, 255, 0) # 绿色
for det in detections:
x1, y1, x2, y2 = map(int, det.bbox) # 确保坐标是整数
cls_id = det.class_id # 假设Detection对象有class_id属性
class_name = det.class_name
confidence = det.confidence
track_id = det.track_id
new_track_id = track_id * 100 + cls_id # 类型小于100或者为负数
# 更新统计
class_stats[cls_id] += 1
# 如果开起侵限功能,就只显示侵限内的框
point_x = (x1 + x2) / 2
point_y = (y1 + y2) / 2
print(f"class_name--{class_name}")
print(f"model_class_names: {model_para[0]['model_class_names']}")
if class_name not in model_para[0]["model_class_names"]:
continue
en_name = model_para[0]["model_chinese_labe"][
model_para[0]["model_class_names"].index(class_name)]
if invade_state:
# 同时适配多个区域的侵限判断
is_invade = is_point_in_polygonlist(point_x, point_y, results_list)
# is_invade = is_point_in_polygon(point_x, point_y, results)
# print(f"is_invadeis_invadeis_invade {is_invade} {len(results)}")
if is_invade:
cls_count += 1
invade_point.append({
"u": point_x,
"v": point_y,
"class_name": class_name
})
target_point.append({
"u": point_x,
"v": point_y,
"cls_id": cls_id,
"track_id": track_id,
"new_track_id": new_track_id
}) # 对于侵限,只存储侵限目标
# model_list_func_id = model_para[0]["model_list_func_id"]
# model_func_id = model_para[0]["func_id"]
message_point.append({
"confidence": float(confidence),
"cls_id": cls_id,
"type_name": en_name,
"track_id": track_id,
"box": [x1, y1, x2, y2]
})
label = f"{en_name}:{confidence:.2f}:{track_id}"
label_name = f"{en_name}"
# 计算文本位置
text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, fontScale=8, thickness=4)[
0]
text_width, text_height = text_size[0], text_size[1]
text_x = x1
text_y = y1 - 5
# 如果文本超出图像顶部,则放在框内部下方
if text_y < 0:
text_y = y2 + text_height + 5
temp_img = frame_copy.copy()
frame_copy = put_chinese_text(
temp_img,
# label, # 置信度、类别、用作测试
# "", # 注释掉汉字
label_name, # 仅显示汉字
(text_x, text_y- 40),
)
else:
cls_count += 1
# 绘制边界框
cv2.rectangle(frame_copy, (x1, y1), (x2, y2), (0, 255, 255), 2)
message_point.append({
"confidence": float(confidence),
"cls_id": cls_id,
"type_name": en_name,
"track_id": track_id,
"box": [x1, y1, x2, y2]
})
target_point.append({
"u": point_x,
"v": point_y,
"cls_id": cls_id,
"track_id": track_id,
"new_track_id": new_track_id
}) # 对于侵限,只存储侵限目标
# 准备标签文本
# label = f"{chinese_label.get(cls_id, class_name)}: {confidence:.2f}:{track_id}"
label = f"{confidence:.2f}:{track_id}"
label_name = f"{en_name}"
# 计算文本位置
text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, fontScale=8, thickness=8)[0]
text_width, text_height = text_size[0], text_size[1]
text_x = x1
text_y = y1 - 5
# 如果文本超出图像顶部,则放在框内部下方
if text_y < 0:
text_y = y2 + text_height + 5
# 绘制文本背景
padding = 2
temp_img = frame_copy.copy()
frame_copy = put_chinese_text(
temp_img,
# label, # 置信度、类别、用作测试
# "", # 注释掉汉字
label_name, # 仅显示汉字
(text_x,text_y- 40),
)
if invade_state:
for point in message_point:
cv2.rectangle(frame_copy, (point["box"][0], point["box"][1]),
(point["box"][2], point["box"][3]),
(0, 255, 255), 2)
# 画红线
# 在左上角显示统计结果
stats_text = []
for cls_id, count in class_stats.items():
cls_name = chinese_label.get(cls_id,
class_names[cls_id] if class_names and cls_id < len(
class_names) else str(
cls_id))
reversed_dict[cls_name] = count
for key, value in reversed_dict.items():
stats_text.append(f"{key}: {value}")
if stats_text:
# 计算统计文本的总高度
text_height = cv2.getTextSize("Test", cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.5, thickness=1)[0][
1]
total_height = len(stats_text) * (text_height + 18) # 5是行间距
# 统计文本的起始位置(左上角)
start_x = 50 # 留出200像素宽度
start_y = 20 # 从顶部开始
# # 绘制统计背景
# # bg_color = (0, 0, 0)
# frame_copy = cv2.rectangle(
# frame_copy,
# (start_x - 10, start_y - 10),
# (400, start_y + total_height + 20),
# bg_color,
# -1
# )
# # 逐行绘制统计文本
# for i, text in enumerate(stats_text):
# y_pos = start_y + i * (text_height + 30)
# temp_img = frame_copy.copy()
# frame_copy = put_chinese_text(
# temp_img,
# text,
# (start_x, y_pos),
# color=text_color
# )
new_data = {
'frame_copy': frame_copy,
'frame': frame,
"osd_info": osd_info,
'detections': detections,
"message_point": message_point,
"cls_count": cls_count,
"target_point": target_point,
"model_list_func_id": model_list_func_id,
"model_func_id": model_func_id,
# 提取第一个模型的func_id 字段因为现在已经要做整合无法在区分各个提取第一个模型的func_id
'timestamp': processed_data.get('timestamp'),
"detections_list": detections_list,
"model_para": model_para
# 'model_para': processed_data.get('model_para', {}) # 确保 model_para 存在
}
# 临时代码 rtmp 和侵限逻辑要改
if invade_state:
# para_list 中使能了 para_invade_enable才做侵限判断
if para_invade_enable:
if not invade_queue.full():
await invade_queue.put(new_data)
else:
if not cv_frame_queue.full():
await cv_frame_queue.put(new_data)
time_end = time.time_ns()
pic_count = pic_count + 1
if time_end - time_start > 1000000000:
print(f"writeFrames {pic_count}")
pic_count = 0
time_start = time_end
#
# count_p = count_p + 1
# cv2.imwrite(f"save_pic/rtmp/test-{count_p}.jpg", frame_copy)
# video_writer.write(frame_copy)
# else:
# frame_copy = frame
# 转换颜色空间
rgb_frame = cv2.cvtColor(frame_copy, cv2.COLOR_BGR2RGB)
# # # # # 显示帧用于调试(可选)
# cv2.imshow(f"write_results_to_rtmp-{task_id}", frame_copy)
# if cv2.waitKey(1) & 0xFF == ord('q'):
# stop_event.set()
# break
# 初始化推流容器(如果尚未初始化)
if output_url and output_url not in stream_containers:
try:
# container = av.open(output_url, mode='w', format='flv')
container = av.open(
output_url,
mode='w',
format='flv',
options={
# 添加RTMP连接超时5秒和数据超时10秒
'rtmp_connect_timeout': '5000000', # 单位微秒5秒
'rtmp_timeout': '10000000', # 单位微秒10秒
'stimeout': '5000000' # 底层socket超时
}
)
stream = container.add_stream(codec_name, rate=Config.TARGET_FPS)
# stream.time_base = f"1/{Config.TARGET_FPS}"
stream.width = frame.shape[1]
stream.height = frame.shape[0]
stream.pix_fmt = 'yuv420p'
stream.options = options
stream_containers[output_url] = {
'container': container,
'stream': stream,
'last_frame_time': time.time(),
'frame_count': 0,
'retry_count': 0
}
print(f"✅ 推流初始化成功: {output_url}")
except Exception as e:
print(f"❌ 推流初始化失败: {e}")
if 'container' in locals():
try:
container.close()
except:
pass
await asyncio.sleep(1.0)
continue
# 推流逻辑
if output_url and output_url in stream_containers:
try:
container_info = stream_containers[output_url]
stream = container_info['stream']
container = container_info['container']
if rgb_frame.dtype == np.uint8:
av_frame = av.VideoFrame.from_ndarray(rgb_frame, format='rgb24')
packets = stream.encode(av_frame) # 这是异步方法
# print(f"📦 encode 生成 {len(packets)} 个 packet")
if packets:
for packet in packets:
try:
container.mux(packet)
container_info['last_frame_time'] = time.time()
container_info['frame_count'] += 1
except Exception as e:
logger.warning(f"推流数据包错误: {e}")
container_info['retry_count'] += 1
if container_info['retry_count'] > max_retries:
raise
else:
# 编码器仍在初始化,不更新 last_frame_time
pass
# 每100帧打印一次状态
# if container_info['frame_count'] % 100 == 0:
# print(f" 已推送 {container_info['frame_count']} 帧到 {output_url}")
if 'frame' in locals():
frame_time = time.time() - frame_start
print(f"推流帧耗时: {frame_time:.4f}s")
else:
print(f"⚠️ 无效帧格式: {rgb_frame.dtype}")
except Exception as e:
logger.error(f"❌ 推流错误: {e}")
# 尝试重新初始化推流
if output_url in stream_containers:
try:
stream_containers[output_url]['container'].close()
except:
pass
del stream_containers[output_url]
await asyncio.sleep(retry_delay)
continue
except asyncio.TimeoutError:
# if stop_event.is_set():
# break
continue
except asyncio.CancelledError:
print("write_results_to_rtmp 收到取消信号")
raise
except Exception as e:
logger.error(f"推流处理异常: {e}", exc_info=True)
await asyncio.sleep(0.1)
finally:
# log_perf('write_results_to_rtmp', start_time)
# 清理推流容器
for url, info in list(stream_containers.items()):
try:
if 'container' in info:
info['container'].close()
except Exception as e:
logger.warning(f"关闭推流容器时出错: {e}")
stream_containers.clear()
# 清理视频写入器
if video_writer is not None:
video_writer.release()
print("write_results_to_rtmp读取线程已停止")
# 基于射线法,判断设备是否在红线内
def is_point_in_polygonlist(point_x, point_y, polygon_list):
inside = False
for polygon in polygon_list:
x, y = point_x, point_y
n = len(polygon)
inside = False
for i in range(n):
p1 = polygon[i]
p2 = polygon[(i + 1) % n]
x1, y1 = p1["u"], p1["v"]
x2, y2 = p2["u"], p2["v"]
# 检查点是否在多边形的顶点或边上
if (x == x1 and y == y1) or (x == x2 and y == y2):
return True # 点在顶点上
if (x - x1) * (y2 - y1) == (y - y1) * (x2 - x1): # 点在边上
if min(x1, x2) <= x <= max(x1, x2) and min(y1, y2) <= y <= max(y1, y2):
return True
# 射线法核心逻辑
if (y1 > y) != (y2 > y): # 确保点在边的垂直范围内
x_intersect = (x2 - x1) * (y - y1) / (y2 - y1) + x1
if x <= x_intersect:
inside = not inside
return inside
# 基于射线法,判断设备是否在红线内
def is_point_in_polygon(point_x, point_y, polygon):
x, y = point_x, point_y
n = len(polygon)
inside = False
for i in range(n):
p1 = polygon[i]
p2 = polygon[(i + 1) % n]
x1, y1 = p1["u"], p1["v"]
x2, y2 = p2["u"], p2["v"]
# 检查点是否在多边形的顶点或边上
if (x == x1 and y == y1) or (x == x2 and y == y2):
return True # 点在顶点上
if (x - x1) * (y2 - y1) == (y - y1) * (x2 - x1): # 点在边上
if min(x1, x2) <= x <= max(x1, x2) and min(y1, y2) <= y <= max(y1, y2):
return True
# 射线法核心逻辑
if (y1 > y) != (y2 > y): # 确保点在边的垂直范围内
x_intersect = (x2 - x1) * (y - y1) / (y2 - y1) + x1
if x <= x_intersect:
inside = not inside
return inside
# 基于两个点的经纬度,计算两个点之间的距离
def haversine(lon1, lat1, lon2, lat2):
"""计算两个经纬度点之间的直线距离(单位:米)"""
lon1, lat1, lon2, lat2 = map(math.radians, [lon1, lat1, lon2, lat2])
dlon = lon2 - lon1
dlat = lat2 - lat1
a = math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2
c = 2 * math.asin(math.sqrt(a))
r = 6371000 # 地球平均半径(米)
return c * r
async def cal_des_invade(loop,invade_executor, task_id: str, mqtt, mqtt_publish_topic,
list_points: list[list[any]], camera_para: Camera_Para, model_count: int,
cancel_flag: asyncio.Event = None, invade_queue: asyncio.Queue = None,
event_queue: asyncio.Queue = None,
device_height: float = float(200), repeat_dis: float = -1, repeat_time: float = -1):
# loop = asyncio.get_running_loop()
# upload_executor = ThreadPoolExecutor(max_workers=Config.UPLOAD_WORKERS)
pic_count_hongxian = 0
track_filter = TrackIDEventFilter(max_inactive_time=8.0)
# 用于记录已上报的track_id及其上报时间
reported_track_ids = defaultdict(float)
# 上报间隔时间(秒)
report_interval = 8
target_location_back = [] # 本地缓存,用作位置重复计算
current_time_second = int(time.time())
while not cancel_flag.is_set():
# 检查队列长度,避免堆积
if invade_queue.qsize() > Config.PROCESSED_QUEUE_SIZE // 2:
print(f"警告invade_queue 积压(当前长度={invade_queue.qsize()}),清空队列")
while not invade_queue.empty():
await invade_queue.get()
print("队列已清空,等待新数据...")
await asyncio.sleep(0.1)
continue
# 获取队列数据
try:
cv_frame = await asyncio.wait_for(invade_queue.get(), timeout=0.8)
# 检查数据类型
if not isinstance(cv_frame, dict):
print(f"⚠️ 警告cv_frame 不是字典,而是 {type(cv_frame)}")
continue
except asyncio.TimeoutError:
print(f"cal_des_invade TimeoutError")
continue
except asyncio.CancelledError:
print("cal_des_invade 收到取消信号")
raise
if cv_frame is None:
continue
if repeat_time > 0: # 基于时间的重复计算,是否使能
read_cv_time_second = int(time.time())
if read_cv_time_second - current_time_second < repeat_time:
continue
else:
current_time_second = read_cv_time_second
# print("cal_des_invade inside")
frame_copy = cv_frame['frame_copy']
frame = cv_frame['frame']
target_point = cv_frame['target_point']
message_point = cv_frame['message_point']
cls_count = cv_frame['cls_count']
model_list_func_id = cv_frame['model_list_func_id']
model_func_id = cv_frame['model_func_id']
air_alti = cv_frame['osd_info']
detections = cv_frame['detections']
detections_list = cv_frame['detections_list']
model_para_list = cv_frame.get('model_para', {}) # 默认空字典
timestamp = cv_frame.get('timestamp', time.time_ns()) # 默认空字典
if not isinstance(frame, np.ndarray) or frame.size == 0:
print("⚠️ 警告frame不是有效的numpy数组")
continue
# 获取宽高
img_height, img_width = frame.shape[:2]
if not air_alti:
continue
gimbal_yaw = air_alti.gimbal_yaw
gimbal_pitch = air_alti.gimbal_pitch
gimbal_roll = air_alti.gimbal_roll
height = air_alti.height
cam_longitude = air_alti.longitude
cam_latitude = air_alti.latitude
try:
current_time = time.time()
h = device_height
us = []
vs = []
heights = []
should_report = True
for item in target_point:
# # 跳过无效的track_id
# 检查是否应该上报该track_id
u = item["u"]
v = item["v"]
cls_id = item["cls_id"]
track_id = item["track_id"]
new_track_id = item["new_track_id"]
# # 如果这个track_id已经上报过检查是否超过上报间隔
# # if new_track_id in reported_track_ids:
# if track_id in reported_track_ids:
# last_report_time = reported_track_ids[track_id]
# if current_time - last_report_time < report_interval:
# print(f"基于track_id触发去重事件{track_id}")
# should_report = False
print(f"方法 cal_des_invade")
if not track_filter.should_report(track_id):
should_report = False
break
# if track_id < 0: # 适配MultiYOLODetector类该类不支持追踪默认track_id为-1
# should_report = True
# 如果使用TrackIDEventFilter判断需要上报
if should_report:
us.append(u)
vs.append(v)
heights.append(h)
if should_report:
print("进行侵限计算")
location_results = cal_canv_location_by_osd(us, vs, gimbal_pitch, gimbal_yaw, gimbal_roll,
height, cam_longitude, cam_latitude, img_width, img_height,
heights)
if not location_results:
continue
# for point in location_results:
# target_location.append({point[1], point[2]})
# point_list = [] # 整理红线集合
repeat_state = False
show_des = 0
str_loca = ""
des_location_result=[]
if repeat_dis > 0: # ai_model_list repeat_dis 字段大于零,才启用去重
if len(target_location_back) > 0: # 当前逻辑并不严谨,只是比较了第一个位置信息
des1_back = target_location_back[0]
des1_back_longitude = des1_back[0]
des1_back_latitude = des1_back[1]
des1_back_height = des1_back[2]
des1 = location_results[0]
des1_longitude = des1[0]
des1_latitude = des1[1]
des_location_result.append({"longitude": des1_longitude,
"latitude": des1_latitude})
des1_height = des1[2]
str_loca = f"{des1_back_longitude}:{des1_back_latitude}---{des1_longitude}{des1_latitude}"
des = haversine(des1_back_longitude, des1_back_latitude, des1_longitude, des1_latitude)
show_des = des
if des < repeat_dis:
print(f"触发基于坐标判断重复,坐标距离:{des}")
repeat_state = True
else:
target_location_back = location_results # 未触发去重逻辑,即更新本地位置缓存
else:
target_location_back = location_results # 基于坐标位置去重判断,如果失败就缓存本次位置信息
# 测试红线
# for point in results:
# point_list.append([point["u"], point["v"]])
# for point in message_point:
# cv2.rectangle(frame_copy, (point["box"][0], point["box"][1]),
# (point["box"][2], point["box"][3]), (0, 255, 255), 2)
# cv2.polylines(frame_copy, [np.array(point_list, dtype=np.int32)],
# isClosed=True, color=(0, 0, 255), thickness=2)
pic_count_hongxian = pic_count_hongxian + 1
# cv2.imwrite(f"save_pic\invade-hongxian\hongxianongly-{pic_count_hongxian}.jpg", frame_copy)
# if len(invade_point) > 0:
# print("hongxianhongxianhongxianhongxian")
# pic_count = pic_count + 1
# # cv2.imwrite(f"save_pic\invade\hongxian-{pic_count}.jpg", frame)
# drawn_frame = frame_copy # 关键修复:深拷贝绘制后的帧
# 图像编码
if not repeat_state: # 未触发去重逻辑,即执行图像上传逻辑
print(f"未发现重复,现在上传{str_loca}======{show_des}")
def encode_origin_frame():
success, buffer = cv2.imencode(".jpg", frame)
return buffer.tobytes() if success else None
def encode_frame():
success, buffer = cv2.imencode(".jpg", frame_copy)
# if task_id == "2a5d7a80-109a-4cd6-aa95-0e8c9aab6b3f-1": # 测试安全帽
# cv2.imwrite(f"save_pic\qm\hongxian-encode-{pic_count_hongxian}.jpg", frame)
#
# if task_id == "7eecadd6-001f-488c-bed9-1086079c3450-1": # 测试工地车辆
# cv2.imwrite(f"save_pic\gdcl\hongxian-encode-{pic_count_hongxian}.jpg", frame_copy)
return buffer.tobytes() if success else None
buffer_bytes = await loop.run_in_executor(invade_executor, encode_frame)
buffer_origin_bytes = await loop.run_in_executor(invade_executor, encode_origin_frame)
if not buffer_bytes:
continue
# 并行处理上传和MQTT发布
async def upload_and_publish():
# 上传到MinIO
def upload_minio():
minio_path, file_type = upload_frame_buff_from_buffer(buffer_bytes, None)
minio_origin_path, file_type = upload_frame_buff_from_buffer(buffer_origin_bytes, None)
return minio_path, minio_origin_path, file_type
# return upload_frame_buff_from_buffer(buffer_bytes, None)
minio_path, minio_origin_path, file_type = await loop.run_in_executor(
invade_executor, upload_minio
)
print(f"minio_pathminio_pathminio_pathminio_path {minio_path}")
# 构造消息
message = {
"task_id": task_id,
"time": str(datetime.now()),
"detection_id": timestamp,
"minio": {"minio_path": minio_path,
"minio_origin_path": minio_origin_path,
"file_type": file_type},
"box_detail": [{
"model_id": model_func_id,
"cls_count": cls_count,
"box_count": [message_point], # 特殊处理
"location_results": location_results # 增加经纬度信息
}],
"osd_location": {
"longitude": cam_longitude,
"latitude": cam_latitude
},
"des_location":des_location_result
}
if not event_queue.full():
await event_queue.put({
"timestamp": timestamp # 存储事件触发的时刻,用作视频制作
})
else:
logger.warning("event_queue 帧队列已满等待1ms后重试")
await asyncio.sleep(0.001)
print(f"hongxianhongxianhongxianhongxian上传 {message}:{event_queue.qsize()} ")
message_json = json.dumps(message, ensure_ascii=False)
await mqtt.publish(mqtt_publish_topic, message_json)
asyncio.create_task(upload_and_publish())
# 使用共享变量时加锁,进而进行跳帧,不然上报太频繁
# async with invade_cache_lock:
# # 读取或修改共享变量
# send_count = shared_local_cache["invade_send_count"]
# send_count = send_count + 1
# shared_local_cache["invade_send_count"] = send_count
# if send_count > 1:
# # 创建独立任务执行上传和发布
# shared_local_cache["invade_send_count"] = 0
# asyncio.create_task(upload_and_publish())
except Exception as e:
print(f"cal_des_invade 错误: {e}")
import traceback
traceback.print_exc()
await asyncio.sleep(0.1)
print("cal_des_invade读取线程已停止")
# 全局共享变量
shared_local_cache = {
"send_count": 0,
"invade_send_count": 0,
"video_process_status": 0 # 0、1、2 分别表示录像识别的三种状态,未开始、开始、结束
}
from asyncio import Lock, AbstractEventLoop
cache_lock = Lock() # 用于保护共享变量的锁
invade_cache_lock = Lock() # 用于保护共享变量的锁
async def send_frame_to_s3_mq(loop,upload_executor,task_id, mqtt, mqtt_topic, cancel_flag: asyncio.Event,
cv_frame_queue: asyncio.Queue,
event_queue: asyncio.Queue = None,
device_height: float = float(200), repeat_dis: float = -1, repeat_time: float = -1):
global stats
start_time = time.time()
# executor = ThreadPoolExecutor(max_workers=Config.MAX_WORKERS)
# upload_executor = ThreadPoolExecutor(max_workers=Config.UPLOAD_WORKERS)
track_filter = TrackIDEventFilter(max_inactive_time=10)
reported_track_ids = defaultdict(float)
count_pic = 0
report_interval = 8
# loop = asyncio.get_running_loop()
local_func_cache = {
"func_100000": None,
"func_100004": None, # 存储缓存缓存人员track_id
"func_100006": None # 存储缓存缓存车辆track_id
}
para = {
"category": 3
}
target_location_back = [] # 本地缓存,用作位置重复计算
current_time_second = int(time.time())
repeat_time_count = 0
while not cancel_flag.is_set():
upload_start = time.time()
try:
# 检查队列长度,避免堆积
if cv_frame_queue.qsize() > Config.PROCESSED_QUEUE_SIZE // 2:
print(f"警告cv_frame_queue 积压(当前长度={cv_frame_queue.qsize()}),清空队列")
while not cv_frame_queue.empty():
await cv_frame_queue.get()
print("队列已清空,等待新数据...")
await asyncio.sleep(0.1)
continue
# 获取队列数据
try:
cv_frame = await asyncio.wait_for(cv_frame_queue.get(), timeout=0.05)
# 检查数据类型
if not isinstance(cv_frame, dict):
print(f"⚠️ 警告cv_frame 不是字典,而是 {type(cv_frame)}")
continue
except asyncio.TimeoutError:
continue
if repeat_time > 0: # 基于时间的重复计算,是否使能
read_cv_time_second = int(time.time())
# print(f"read_cv_time_secondread_cv_time_second {read_cv_time_second}")
if read_cv_time_second - current_time_second < repeat_time and repeat_time_count > 0:
# print("触发事件去重")
continue
else:
print("没有触发事件去重")
repeat_time_count += 1 # 防止丢失第一个目标
current_time_second = read_cv_time_second
# 准备数据
frame_copy = cv_frame['frame_copy']
frame = cv_frame['frame']
target_point = cv_frame['target_point']
detections = cv_frame['detections']
detections_list = cv_frame['detections_list']
model_para_list = cv_frame.get('model_para', {}) # 默认空字典
timestamp = cv_frame.get('timestamp', time.time_ns()) # 默认空字典
air_alti = cv_frame['osd_info']
img_height, img_width = frame.shape[:2]
gimbal_yaw = air_alti.gimbal_yaw
gimbal_pitch = air_alti.gimbal_pitch
gimbal_roll = air_alti.gimbal_roll
height = air_alti.height
cam_longitude = air_alti.longitude
cam_latitude = air_alti.latitude
# 初始化默认值
frame11 = frame_copy # 默认使用原始帧
box_detail = [] # 默认空列表
current_time = time.time()
h = device_height
us = []
vs = []
heights = []
should_report = True
print(f"target_pointtarget_point {len(target_point)}")
count_item = 0
des_location_result=[]
for item in target_point:
# # 跳过无效的track_id
# 检查是否应该上报该track_id
u = item["u"]
v = item["v"]
cls_id = item["cls_id"]
track_id = item["track_id"]
new_track_id = item["new_track_id"]
# should_report = True
# # 如果这个track_id已经上报过检查是否超过上报间隔
# if new_track_id in reported_track_ids:
# last_report_time = reported_track_ids[new_track_id]
# if current_time - last_report_time < report_interval:
# print(f"基于track_id触发去重事件{new_track_id}")
# should_report = False
# print(f"new_track_idnew_track_id {new_track_id} {track_id}")
# if should_report and track_filter.should_report(new_track_id):
# should_report = True
print(f"方法 send_frame——to {count_item}")
count_item += 1
if not track_filter.should_report(track_id):
should_report = False
break
# if track_id < 0: # 适配MultiYOLODetector类该类不支持追踪默认track_id为-1
# should_report = True
# 如果使用TrackIDEventFilter判断需要上报
if should_report:
us.append(u)
vs.append(v)
heights.append(h)
# location_results = cal_canv_location_by_osd(us, vs, gimbal_pitch, gimbal_yaw, gimbal_roll,
# height, cam_longitude, cam_latitude, img_width, img_height,
# heights)
#
# if not location_results:
# print("location_results is null")
# continue
if len(us) > 0:
print("进行侵限计算")
location_results = cal_canv_location_by_osd(us, vs, gimbal_pitch, gimbal_yaw, gimbal_roll,
height, cam_longitude, cam_latitude, img_width, img_height,
heights)
if not location_results:
continue
if location_results:
des1 = location_results[0]
des1_longitude = des1[0]
des1_latitude = des1[1]
des1_height = des1[2]
des_location_result.append({"longitude": des1_longitude,
"latitude": des1_latitude})
repeat_state = False
show_des = 0
str_loca = ""
if repeat_dis > 0: # ai_model_list repeat_dis 字段大于零,才启用去重
if len(target_location_back) > 0: # 当前逻辑并不严谨,只是比较了第一个位置信息
des1_back = target_location_back[0]
des1_back_longitude = des1_back[0]
des1_back_latitude = des1_back[1]
des1_back_height = des1_back[2]
des1 = location_results[0]
des1_longitude = des1[0]
des1_latitude = des1[1]
des1_height = des1[2]
str_loca = f"{des1_back_longitude}:{des1_back_latitude}---{des1_longitude}{des1_latitude}"
des = haversine(des1_back_longitude, des1_back_latitude, des1_longitude, des1_latitude)
show_des = des
if des < repeat_dis:
print(f"触发基于坐标判断重复,坐标距离:{des}")
repeat_state = True
else:
target_location_back = location_results # 未触发去重逻辑,即更新本地位置缓存
else:
target_location_back = location_results # 基于坐标位置去重判断,如果失败就缓存本次位置信息
if not repeat_state: # 未触发去重逻辑,即执行图像上传逻辑
for idx, model_para in enumerate(model_para_list):
if not isinstance(detections_list, list) or idx >= len(detections_list):
continue
detections = detections_list[idx] # 正确获取当前模型的检测结果
chinese_label = model_para.get("model_chinese_labe", {})
model_cls = model_para.get("model_cls_index", {})
list_func_id = model_para.get("model_list_func_id", -11)
func_id = model_para.get("func_id", [])
# 获取DRC消息同步操作放到线程池
local_drc_message = await loop.run_in_executor(upload_executor, get_local_drc_message)
if detections is None or len(detections.boxes) < 1:
continue
try:
# 图像处理和结果计算
frame11, box_detail1 = await loop.run_in_executor(
upload_executor,
cal_tricker_results,
frame_copy, detections, None,
func_id, local_func_cache, para, model_cls, chinese_label, list_func_id
)
except Exception as e:
print(f"处理帧时出错: {str(e)}")
import traceback
traceback.print_exc()
continue
count_pic = count_pic + 1
# 图像编码
def encode_frame():
success, buffer = cv2.imencode(".jpg", frame11)
return buffer.tobytes() if success else None
def encode_origin_frame():
success, buffer = cv2.imencode(".jpg", frame)
return buffer.tobytes() if success else None
buffer_bytes = await loop.run_in_executor(upload_executor, encode_frame)
buffer_origin_bytes = await loop.run_in_executor(upload_executor, encode_origin_frame)
if not buffer_bytes:
continue
# 并行处理上传和MQTT发布
async def upload_and_publish():
# 上传到MinIO
def upload_minio():
minio_path, file_type = upload_frame_buff_from_buffer(buffer_bytes, None)
minio_origin_path, file_type = upload_frame_buff_from_buffer(buffer_origin_bytes, None)
return minio_path, minio_origin_path, file_type
minio_path, minio_origin_path, file_type = await loop.run_in_executor(
upload_executor, upload_minio
)
# 构造消息
message = {
"task_id": task_id,
"time": str(datetime.now()),
"detection_id": timestamp,
"minio": {"minio_path": minio_path, "minio_origin_path": minio_origin_path,
"file_type": file_type},
"box_detail": box_detail1,
"uav_location": local_drc_message,
"osd_location": {
"longitude": cam_longitude,
"latitude": cam_latitude
},
"des_location":des_location_result
}
await event_queue.put({
"timestamp": timestamp # 存储事件触发的时刻,用作视频制作
})
message_json = json.dumps(message, ensure_ascii=False)
await mqtt.publish(mqtt_topic, message_json)
asyncio.create_task(upload_and_publish())
# 使用共享变量时加锁,进而进行跳帧,不然上报太频繁
# async with cache_lock:
# # 读取或修改共享变量
# send_count = shared_local_cache["send_count"]
# send_count = send_count + 1
# shared_local_cache["send_count"] = send_count
#
# # 伪代码,目前安全帽比较难识别到
# if 100014 == model_para_list[0]["model_list_func_id"]:
# if send_count > 2:
# # 创建独立任务执行上传和发布
# shared_local_cache["send_count"] = 0
# asyncio.create_task(upload_and_publish())
# else:
# if send_count > 30:
# # 创建独立任务执行上传和发布
# shared_local_cache["send_count"] = 0
# asyncio.create_task(upload_and_publish())
if 'frame' in locals():
upload_time = time.time() - upload_start
print(f"上传耗时: {upload_time:.4f}s")
except Exception as e:
print(f"send_frame_to_s3_mq 错误: {e}")
import traceback
traceback.print_exc()
await asyncio.sleep(0.1)
except asyncio.CancelledError:
print("send_frame_to_s3 收到取消信号")
raise
# finally:
# log_perf('send_frame_to_s3_mq', start_time)
# 更新性能统计
stats['processed'] += 1
if time.time() - stats['last_time'] >= 1.0:
stats['avg_fps'] = stats['processed'] / (time.time() - stats['last_time'])
print(f"处理速度: {stats['avg_fps']:.2f} FPS")
stats['processed'] = 0
stats['last_time'] = time.time()
print("send_frame_to_s3线程已停止")
#
# def frames_to_video_bytes(frames, fps=25, format="mp4"):
# """
# 将多个帧转换为视频字节数组(内存处理,不存储到本地)
#
# Args:
# frames (list): 包含多个帧的列表,每个帧是 numpy 数组BGR 或 RGB 格式)
# fps (int): 视频帧率
# format (str): 视频格式(如 "mp4"、"avi"
#
# Returns:
# bytes: 视频的字节数组
# """
# if len(frames) == 0:
# raise ValueError("帧列表不能为空")
#
# # 确保所有帧是 uint8 类型
# frames = [np.array(frame, dtype=np.uint8) for frame in frames]
#
# # 使用 BytesIO 存储视频数据
# output = BytesIO()
#
# # 使用 imageio 写入内存
# with imageio.get_writer(output, format=format, fps=fps) as writer:
# for frame in frames:
# # 如果帧是 BGR 格式OpenCV 默认),转换为 RGB
# if frame.ndim == 3 and frame.shape[2] == 3:
# frame = frame[..., ::-1] # BGR → RGB
# writer.append_data(frame)
#
# # 获取字节数据
# video_bytes = output.getvalue()
# output.close()
#
# return video_bytes
#
# def frames_to_video_bytes(frames, fps=25, format="mp4"):
# """
# 将多个帧转换为视频字节数组(内存处理,不存储到本地)
# Args:
# frames (list): 包含多个帧的列表,每个帧是 numpy 数组BGR 或 RGB 格式)
# fps (int): 视频帧率
# format (str): 视频格式(如 "mp4"、"avi"
# Returns:
# bytes: 视频的字节数组 或 None如果失败
# """
# if len(frames) == 0:
# logger.warning("frames_to_video_bytes: 输入帧列表为空")
# return None
#
# try:
# # 确保所有帧是 uint8 类型
# frames = [np.array(frame, dtype=np.uint8) for frame in frames]
#
# output = BytesIO()
# with imageio.get_writer(output, format=format, fps=fps) as writer:
# for frame in frames:
# if frame.ndim == 3 and frame.shape[2] == 3:
# frame = frame[..., ::-1] # BGR → RGB
# writer.append_data(frame)
# return output.getvalue()
# except Exception as e:
# logger.error(f"frames_to_video_bytes 生成视频失败: {e}")
# return None
import subprocess
import numpy as np
import tempfile
import os
import time
import threading
import traceback
def frames_to_video_bytes(frames, fps=25, format="flv"):
"""
使用临时文件暂存帧数据,生成视频字节后自动删除临时文件
特点降低管道传输压力适配大尺寸帧和旧版FFMPEG
"""
# 日志工具:带时间戳和详细级别
def log_info(msg):
print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] INFO: {msg}")
def log_warning(msg):
print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] WARNING: {msg}")
def log_error(msg):
print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] ERROR: {msg}")
def log_debug(msg):
print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] DEBUG: {msg}")
# 强制切换为FLV格式兼容性最佳
if format != "flv":
log_warning(f"自动切换为FLV格式兼容性最佳")
# format = "flv"
# 输入帧合法性检查
if not isinstance(frames, list) or len(frames) == 0:
log_error("输入帧列表为空或非列表类型返回None")
return None
# 记录开始时间
start_time = time.time()
# 创建临时文件(自动删除)
temp_frame_file = None
temp_video_file = None
try:
# 创建存储帧数据的临时文件
temp_frame_fd, temp_frame_path = tempfile.mkstemp(suffix=".raw", text=False)
os.close(temp_frame_fd) # 关闭文件描述符,使用路径操作
temp_frame_file = open(temp_frame_path, 'wb')
log_debug(f"创建帧数据临时文件: {temp_frame_path}")
# 创建存储视频的临时文件
temp_video_fd, temp_video_path = tempfile.mkstemp(suffix=f".{format}", text=False)
os.close(temp_video_fd)
log_debug(f"创建视频临时文件: {temp_video_path}")
# --------------------------
# 步骤1帧预处理并写入临时文件
# --------------------------
try:
log_info("开始帧预处理并写入临时文件")
# 处理第一帧:确认基础信息
first_frame = np.asarray(frames[0], dtype=np.uint8)
if first_frame.ndim != 3 or first_frame.shape[2] != 3:
log_error("第一帧必须是3通道BGR/RGB图像")
return None
frame_h, frame_w = first_frame.shape[:2]
is_4k = frame_w >= 3840 and frame_h >= 2160
# 4K帧特殊优化
if is_4k:
single_frame_mb = (frame_w * frame_h * 3) / (1024 * 1024)
total_data_mb = single_frame_mb * len(frames)
log_info(f"检测到4K帧{frame_w}x{frame_h},单帧{single_frame_mb:.1f}MB总数据{total_data_mb:.1f}MB")
fps = min(fps, 12)
log_info(f"4K适配帧率调整为{fps}fps")
# 统一所有帧格式并写入临时文件
try:
from cv2 import resize
except ImportError:
log_error("未安装opencv-python请执行 pip install opencv-python")
return None
for idx, frame in enumerate(frames):
try:
# 1. 转为uint8类型
frame_uint8 = np.asarray(frame, dtype=np.uint8)
# 2. 统一尺寸
if frame_uint8.shape[:2] != (frame_h, frame_w):
log_debug(f"{idx + 1}尺寸不匹配,缩放至{frame_w}x{frame_h}")
frame_uint8 = resize(frame_uint8, (frame_w, frame_h), interpolation=1)
# 3. BGR转RGB
frame_rgb = frame_uint8[..., ::-1]
# 4. 写入临时文件
temp_frame_file.write(frame_rgb.tobytes())
log_debug(f"{idx + 1}已写入临时文件")
except Exception as e:
log_error(f"处理帧{idx + 1}失败:{str(e)}")
traceback.print_exc()
return None
# 确保所有数据写入磁盘
temp_frame_file.flush()
os.fsync(temp_frame_file.fileno())
temp_frame_file.close()
log_info(f"所有帧已写入临时文件,共{len(frames)}")
except Exception as e:
log_error(f"帧预处理阶段异常:{str(e)}")
traceback.print_exc()
return None
# --------------------------
# 步骤2构造FFMPEG命令使用临时文件
# --------------------------
try:
log_info("构造FFMPEG命令使用临时文件输入")
ffmpeg_cmd = [
# "ffmpeg",
"/usr/bin/ffmpeg",
"-hide_banner",
"-loglevel", "error",
"-y",
# 输入配置(从临时文件读取)
"-f", "rawvideo",
"-vcodec", "rawvideo",
"-s", f"{frame_w}x{frame_h}",
"-pix_fmt", "rgb24",
"-r", str(fps),
"-i", temp_frame_path, # 从临时文件读取帧数据
# 输出配置
"-vcodec", "libx264",
"-pix_fmt", "yuv420p",
"-r", str(fps),
"-preset", "ultrafast",
"-crf", "30",
# 输出到临时文件
"-f", format,
temp_video_path
]
log_info(f"FFMPEG命令{' '.join(ffmpeg_cmd)}")
except Exception as e:
log_error(f"构造FFMPEG命令失败{str(e)}")
traceback.print_exc()
return None
# --------------------------
# 步骤3执行FFMPEG处理
# --------------------------
try:
log_info("启动FFMPEG进程使用临时文件")
# 启动FFMPEG进程
process = subprocess.Popen(
ffmpeg_cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=False
)
# 捕获错误输出
ffmpeg_error_log = []
def read_stderr():
while True:
err_line = process.stderr.readline()
if not err_line:
break
err_str = err_line.decode("utf-8", errors="replace").strip()
if err_str:
ffmpeg_error_log.append(err_str)
log_error(f"FFMPEG错误{err_str}")
stderr_thread = threading.Thread(target=read_stderr, daemon=True)
stderr_thread.start()
# 等待处理完成
log_info("等待FFMPEG处理完成超时120秒")
process.wait(timeout=120)
stderr_thread.join()
# 检查处理结果
if process.returncode != 0:
log_error(f"FFMPEG处理失败退出码{process.returncode}")
log_error(f"FFMPEG完整错误\n{chr(10).join(ffmpeg_error_log)}")
return None
# 读取视频文件内容
log_info("读取视频临时文件内容")
with open(temp_video_path, 'rb') as f:
video_bytes = f.read()
log_info(
f"视频生成成功!大小:{len(video_bytes) / 1024 / 1024:.2f}MB总耗时{time.time() - start_time:.2f}")
return video_bytes
except subprocess.TimeoutExpired:
log_error("FFMPEG处理超时超过120秒")
process.kill()
return None
except Exception as e:
log_error(f"FFMPEG处理异常{str(e)}")
traceback.print_exc()
return None
finally:
# 确保临时文件被删除
try:
if temp_frame_file and not temp_frame_file.closed:
temp_frame_file.close()
if temp_frame_path and os.path.exists(temp_frame_path):
os.remove(temp_frame_path)
log_debug(f"已删除帧临时文件: {temp_frame_path}")
except Exception as e:
log_warning(f"删除帧临时文件失败:{str(e)}")
try:
if temp_video_path and os.path.exists(temp_video_path):
os.remove(temp_video_path)
log_debug(f"已删除视频临时文件: {temp_video_path}")
except Exception as e:
log_warning(f"删除视频临时文件失败:{str(e)}")
log_info("视频生成流程结束")
async def cut_evnt_video_publish(task_id, mqtt, mqtt_topic, cancel_flag: asyncio.Event,
event_queue: asyncio.Queue = None,
timestamp_frame_queue: TimestampedQueue = None):
loop = asyncio.get_running_loop()
upload_executor = ThreadPoolExecutor(max_workers=Config.UPLOAD_WORKERS)
while not cancel_flag.is_set():
try:
event = await event_queue.get()
print(f"[cut循环] 成功获取事件: {event}")
if not isinstance(event, dict) or "timestamp" not in event:
print(f"[cut循环] 无效事件,跳过")
continue
timestamp = event.get("timestamp")
if timestamp is None:
print("⚠️ 警告:事件中缺少 timestamp")
continue
matched_items = timestamp_frame_queue.query_by_timestamp(timestamp)
print(f"[cut循环] 匹配到 {len(matched_items)} 个帧")
if len(matched_items) == 0:
continue
frames = [item["frame"] for item in matched_items]
if not frames: # 显式检查帧列表是否为空
print("⚠️ 警告:匹配到的帧列表为空")
continue
video_bytes = frames_to_video_bytes(frames, fps=150, format="mp4")
if video_bytes is None: # 检查返回值
print("⚠️ 警告视频生成失败返回None")
continue
async def upload_and_publish():
def upload_minio():
return upload_video_buff_from_buffer(video_bytes, video_format="mp4")
try:
minio_path, file_type = await loop.run_in_executor(upload_executor, upload_minio)
# message = {"detection_id": timestamp, "minio_path": minio_path}
message = {
"task_id": task_id,
"time": str(datetime.now()),
"detection_id": timestamp,
"minio": {"minio_path": minio_path,
"minio_origin_path": minio_path,
"file_type": file_type},
"box_detail": [{
}],
"osd_location": {
},
"des_location": []
}
print(f"成功上传视频: {message}")
await mqtt.publish(mqtt_topic, json.dumps(message, ensure_ascii=False))
except Exception as e:
print(f"上传失败: {e}")
asyncio.create_task(upload_and_publish())
except Exception as e:
print(f"[cut循环] 处理事件出错: {e}")
finally:
if event is not None:
event_queue.task_done()
print(f"[cut循环] 确认消费,剩余未完成任务: {event_queue._unfinished_tasks}")
if event_queue.qsize() > 0 and event_queue._unfinished_tasks == 0:
try:
event_queue.put_nowait(None)
dummy = await event_queue.get()
event_queue.task_done()
except:
pass
async def start_rtmp_processing(video_url: str, task_id: str, model_configs: List[Dict],
mqtt_pub_ip: str, mqtt_pub_port: int, mqtt_pub_topic: str,
mqtt_sub_ip: str, mqtt_sub_port: int, mqtt_sub_topic: str,
output_rtmp_url: str,
invade_enable: bool, invade_file: str, camera_para_url: str,
device_height: float, repeat_dis: float, repeat_time: float):
# 初始化资源
# await initialize_resources()
logger.info(f"拉流地址{video_url}")
logger.info(f"推流地址{output_rtmp_url}")
cancel_flag = asyncio.Event()
# 初始化局部变量(避免全局污染)
frame_queue = asyncio.Queue(maxsize=Config.FRAME_QUEUE_SIZE)
processed_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
invade_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
cv_frame_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
event_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE) # 存储事件,作为截取视频的标准
print(f"[start_rtmp] 创建event_queue实例{id(event_queue)}") # 打印内存地址
loop = asyncio.get_running_loop() # 声明事件循环管理
stream_containers: Dict[str, Any] = {}
# 创建取消标志
cancel_flag = asyncio.Event()
timestamp_frame_queue = TimestampedQueue(maxlen=1000) # 限定长度的队列,先进先出存储连续帧,事件触发即截取前后视频
try:
# 加载侵限区域数据
list_points = []
# 加载现场摄像头相关数据
camera_para = None
if invade_file and camera_para_url:
camera_file_path = downFile(camera_para_url)
camera_para = read_camera_params(camera_file_path)
invade_file_path = downFile(invade_file)
if invade_file_path:
with open(invade_file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
for feature in data.get('features', []):
geometry = feature.get('geometry', {})
coordinates = geometry.get('coordinates', [])
points = [Point(coord[1], coord[0], coord[2], i)
for i, coord in enumerate(coordinates[0])]
list_points.append(points)
# 初始化MQTT和设备
mqtt = MQTTService(mqtt_pub_ip, port=mqtt_pub_port)
# detector = MultiYOLODetector_TrackId(model_configs)
detector = MultiYoloTrtDetectorTrackId(model_configs)
print(f"mqtt_sub_ipmqtt_sub_ipmqtt_sub_ip {mqtt_sub_ip} {mqtt_sub_port} {mqtt_sub_topic}")
# 创建MQTT设备任务
async def start_mqtt_device():
device = MQTTDevice(
ip=mqtt_sub_ip,
port=mqtt_sub_port,
topics=[mqtt_sub_topic],
queue_size=50
)
# 对osd_info_push 相关信息进行回调,否则消息读取非常慢
device.register_callback(
topic=mqtt_sub_topic,
method="osd_info_push",
callback=empty_osd_callback # 直接引用已定义的函数
)
await device.start()
return device
# device_task = asyncio.create_task(start_mqtt_device(), name="start_mqtt_device")
# await asyncio.sleep(2) # 等待设备初始化
#
# # 确保设备已启动
# device = await device_task
# 然后在 start_rtmp_processing 中直接等待:
device = await start_mqtt_device() # 直到连接和订阅完成才继续
if device is None:
raise RuntimeError("Failed to start MQTT device")
# 创建任务列表
tasks = []
read_task = None
process_task = None
write_task = None
invade_task = None
upload_tasks = []
model_count = len(model_configs)
try:
# RTMP读取任务
read_rtmp_frames_executor = ThreadPoolExecutor(max_workers=Config.READ_RTMP_WORKERS)
read_task = asyncio.create_task(
read_rtmp_frames(
loop,
read_rtmp_frames_executor,
video_url,
device,
mqtt_sub_topic,
"drc_camera_osd_info_push",
mqtt_sub_topic,
"osd_info_push",
cancel_flag,
frame_queue,
timestamp_frame_queue
),
name="read_rtmp_frames"
)
tasks.append(read_task)
#
# # 处理任务
process_frame_executor = ThreadPoolExecutor(max_workers=Config.PROCESS_FRAME_WORKERS)
process_task = asyncio.create_task(
process_frames(detector, cancel_flag, frame_queue, processed_queue),
name="process_frames"
)
tasks.append(process_task)
#
# # 推流任务
invade_state = bool(list_points) and invade_enable
write_frame_executor = ThreadPoolExecutor(max_workers=Config.WRITE_FRAME_WORKERS)
write_task = asyncio.create_task(
write_results_to_rtmp(
task_id,
output_rtmp_url,
None,
list_points,
camera_para,
invade_state,
cancel_flag,
processed_queue,
invade_queue,
cv_frame_queue,
stream_containers
),
name="write_results_to_rtmp"
)
tasks.append(write_task)
# # # 侵限检测任务
if invade_enable and list_points:
invade_executor = ThreadPoolExecutor(max_workers=Config.INVADE_WORKERS)
invade_task = asyncio.create_task(
cal_des_invade(
loop,
invade_executor,
task_id,
mqtt,
mqtt_pub_topic,
list_points,
camera_para,
model_count,
cancel_flag,
invade_queue,
event_queue,
device_height,
repeat_dis, repeat_time
),
name="cal_des_invade"
)
tasks.append(invade_task)
# # S3上传任务
upload_executor = ThreadPoolExecutor(max_workers=Config.UPLOAD_WORKERS)
for _ in range(1):
upload_task = asyncio.create_task(
send_frame_to_s3_mq(loop, upload_executor, task_id, mqtt, mqtt_pub_topic,
cancel_flag, cv_frame_queue, event_queue, device_height, repeat_dis,
repeat_time),
name=f"send_frame_to_s3_mq_{_}"
)
upload_tasks.append(upload_task)
tasks.append(upload_task)
#
# # # 截取事件并将frame存储为video然后执行上传
# event_video_executor = ThreadPoolExecutor(max_workers=Config.EVENT_VIDEO_WORKERS)
# upload_video = asyncio.create_task(cut_evnt_video_publish(task_id,mqtt, mqtt_pub_topic, cancel_flag,
# event_queue, timestamp_frame_queue),
# name="cut_evnt_video_publish")
# tasks.append(upload_video)
# 注册任务到TaskManager
device_list = [mqtt]
task_info = {
"video_url": video_url,
"output_rtmp_url": output_rtmp_url,
"task_id": task_id,
"cancel_flag": cancel_flag,
"device_list": device_list
}
# 使用asyncio.shield保护主任务,确保任务不会提前停止
main_task = asyncio.shield(asyncio.gather(*tasks, return_exceptions=True))
# 注册任务
await task_manager.add_task(
task_id,
task_info,
main_task,
tasks
)
print("start_rtmp 1")
await task_manager.update_heartbeat(task_id)
print("start_rtmp 2")
# 等待任务完成或取消
try:
print("start_rtmp 3")
await main_task
except asyncio.CancelledError:
logger.info(f"Task {task_id} was cancelled")
raise
except Exception as e:
logger.error(f"Task {task_id} failed: {e}")
raise
finally:
# 确保所有任务都被取消
cancel_flag.set()
sleep(3)
print("start_rtmp 4")
for task in tasks:
print("start_rtmp 5")
task.cancel()
for upload_task in upload_tasks:
print("start_rtmp 6")
upload_task.cancel()
except Exception as e:
logger.error(f"Error in main task loop: {e}")
raise
except Exception as e:
logger.error(f"Unexpected error in start_rtmp_processing: {e}")
raise
finally:
# 确保所有任务被取消
print("start_rtmp 7")
cancel_flag.set()
print("start_rtmp 8")
# 释放cuda内存
detector.destroy()
# 停止所有任务
if read_task and not read_task.done():
print("start_rtmp 9")
read_task.cancel()
if process_task and not process_task.done():
print("start_rtmp 10")
process_task.cancel()
if write_task and not write_task.done():
print("start_rtmp 11")
write_task.cancel()
if invade_task and not invade_task.done():
print("start_rtmp 12")
invade_task.cancel()
for task in upload_tasks:
if task and not task.done():
task.cancel()
print("start_rtmp 13")
# 停止设备
if 'device' in locals():
await device.stop()
print("start_rtmp 14")
# 清理资源
await cleanup_resources()
await task_manager.remove_task(task_id)
logger.info(f"Task {task_id} resources cleaned up")
print(f"start_rtmp_processing线程已停止")
async def start_video_processing(minio_path: str, task_id: str, model_configs: List[Dict],
mqtt_ip: str, mqtt_port: int, mqtt_topic: str, output_rtmp_url: str,
invade_enable: bool, invade_file: str, camera_para_url: str, device_height: float,
repeat_dis: float, repeat_time: float):
# global stop_event, frame_queue, processed_queue, executor, upload_executor
# await initialize_resources() # 初始化资源
print("开起录像识别")
cancel_flag = asyncio.Event()
# 初始化局部变量(避免全局污染)
frame_queue = asyncio.Queue(maxsize=Config.FRAME_QUEUE_SIZE)
processed_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
invade_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
cv_frame_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
event_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE) # 存储事件,作为截取视频的标准
stream_containers: Dict[str, Any] = {}
timestamp_frame_queue = TimestampedQueue(maxlen=500) # 限定长度的队列,先进先出存储连续帧,事件触发即截取前后视频
try:
download_path = downBigFile(minio_path)
# download_path = r"E:\hami\gdaq\aqm_0829-55-7_cut_output.mp4"
# download_path = r"E:\hami\gdaq\aqm_0829-55-7_cut.mp4"
# download_path = r"E:\hami\aqm_0829-55-7_cut.mp4"
# # 近距离 1/3
# download_path = r"C:\Users\14867\Downloads\DJI_20250912140036_0001_V.mp4"
# srt_path = r"C:\Users\14867\Downloads\DJI_20250912140036_0001_V.mp4.srt"
# 近距离 1/5
# # download_path = r"C:\Users\14867\Downloads\DJI_20250912165859_0001_V.mp4"
# if task_id == "7eecadd6-001f-488c-bed9-1086079c3450-1": # 测试工地车辆
# download_path = r"C:\Users\14867\Downloads\DJI_20250918175331_0001_V.mp4"
# if task_id == "2a5d7a80-109a-4cd6-aa95-0e8c9aab6b3f-1": # 测试安全帽
# download_path = r"C:\Users\14867\Downloads\DJI_20250917114158_0001_V.mp4"
except Exception as e:
logger.error(f"Unexpected error downloading big file: {e}", exc_info=True)
return None
if not download_path:
logger.error(f"big file is none", exc_info=True)
return
print(f"download_path路径 {os.path.abspath(download_path)}")
dir_name = os.path.dirname(download_path)
srt_name = os.path.basename(download_path) + ".srt"
srt_path = os.path.join(dir_name, srt_name)
if os.path.exists(srt_path):
os.remove(srt_path)
command = [
"ffmpeg",
"-i", download_path,
"-map", "0:s:0", # 选择第一个字幕流
"-c:s", "srt", # 强制转换为 SRT 格式(可选)
srt_path
]
# # 近距离 1/3
# srt_path = r"C:\Users\14867\Downloads\DJI_20250912140036_0001_V.mp4.srt"
# # # # 近距离 1/5
# if task_id == "7eecadd6-001f-488c-bed9-1086079c3450-1": # 测试工地车辆
# srt_path = r"C:\Users\14867\Downloads\DJI_20250918175331_0001_V.mp4.srt"
# if task_id == "2a5d7a80-109a-4cd6-aa95-0e8c9aab6b3f-1": # 测试安全帽
# srt_path = r"C:\Users\14867\Downloads\DJI_20250917114158_0001_V.mp4.srt"
try:
subprocess.run(command, check=True)
print(f"start_video_processing方法 字幕提取成功,保存至: {srt_path}")
except subprocess.CalledProcessError as e:
print(f"start_video_processing方法错误: FFmpeg 执行失败 - {e}")
return
except FileNotFoundError:
print("start_video_processing方法错误: 未安装 FFmpeg请先安装并添加到系统路径。")
return
print(f"srt_path路径 {os.path.abspath(srt_path)}")
try:
list_points = [] # 二维数组,里面的一维数组就是面
camera_para = None
if invade_file and camera_para_url:
camera_file_path = downFile(camera_para_url)
camera_para = read_camera_params(camera_file_path)
invade_file_path = downFile(invade_file)
if invade_file_path is None:
print(f"invade_file is None task_id:{task_id}")
return
with open(invade_file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
# 提取多边形坐标
features = data.get('features', [])
if not features:
print("没有找到有效的要素数据")
return
# 获取第一个多边形的坐标(假设只有一个多边形)
for polygon in features:
# polygon = features[0]
geometry = polygon.get('geometry', {})
coordinates = geometry.get('coordinates', [])
# 提取经纬度点忽略Z值
points = []
for key, coord in enumerate(coordinates[0]): # 取第一个环的坐标
# 只取经纬度,忽略第三个值(高度) 注意这里是反序不然后面convert_points_to_utm 会报错
points.append(Point(coord[1], coord[0], coord[2], key))
list_points.append(points)
mqtt = MQTTService(mqtt_ip, port=mqtt_port)
# 初始化检测器
# detector = MultiYOLODetector(model_configs)
detector = MultiYOLODetector_TrackId(model_configs)
mqtt_publish_topic = mqtt_topic
# 读取视频获取帧率
cap = cv2.VideoCapture(download_path)
if not cap.isOpened():
print("Error: Could not open video.")
return
fps = cap.get(cv2.CAP_PROP_FPS)
cap.release()
invade_state = False
if len(list_points) > 0 and invade_enable:
invade_state = True
model_count = len(model_configs)
# 创建任务列表
tasks = []
read_task = None
process_task = None
write_task = None
invade_task = None
upload_tasks = []
# RTMP读取任务
read_task = asyncio.create_task(
read_video_frames(task_id, mqtt, mqtt_publish_topic, download_path, srt_path, frame_queue,
timestamp_frame_queue, cancel_flag),
name="read_rtmp_frames")
tasks.append(read_task)
# 处理任务
process_task = asyncio.create_task(
process_frames(detector, cancel_flag, frame_queue, processed_queue),
name="process_frames"
)
tasks.append(process_task)
# 推流任务
invade_state = bool(list_points) and invade_enable
write_task = asyncio.create_task(
write_results_to_rtmp(
task_id,
output_rtmp_url,
fps,
list_points,
camera_para,
invade_state,
cancel_flag,
processed_queue,
invade_queue,
cv_frame_queue,
stream_containers
),
name="write_results_to_rtmp"
)
tasks.append(write_task)
# # # 侵限文件不为空,即输出当前事件
# if len(list_points) > 0 and invade_enable:
# print("基于录像开起侵限识别")
# tasks.append(asyncio.create_task(
# cal_des_invade(task_id, mqtt, mqtt_publish_topic, list_points, cancel_flag)))
# 侵限检测任务
if invade_enable and list_points:
invade_task = asyncio.create_task(
cal_des_invade(
task_id,
mqtt,
mqtt_publish_topic,
list_points,
camera_para,
model_count,
cancel_flag,
invade_queue,
event_queue,
device_height,
repeat_dis, repeat_time
),
name="cal_des_invade"
)
tasks.append(invade_task)
# S3上传任务
for _ in range(2):
upload_task = asyncio.create_task(
send_frame_to_s3_mq(task_id, mqtt, mqtt_topic,
cancel_flag, cv_frame_queue, event_queue, device_height, repeat_dis, repeat_time),
name=f"send_frame_to_s3_mq_{_}"
)
upload_tasks.append(upload_task)
tasks.append(upload_task)
# # # 截取事件并将frame存储为video然后执行上传
upload_video = asyncio.create_task(cut_evnt_video_publish(task_id, mqtt, mqtt_topic, cancel_flag,
event_queue, timestamp_frame_queue),
name="cut_evnt_video_publish")
tasks.append(upload_video)
# 创建多个上传任务并行处理
# for _ in range(2): # 2个上传消费者
# tasks.append(asyncio.create_task(send_frame_to_s3_mq(task_id, mqtt, mqtt_topic, cancel_flag)))
task_info = {
"output_rtmp_url": output_rtmp_url,
"task_id": task_id,
"cancel_flag": cancel_flag,
}
# 使用asyncio.shield保护主任务
main_task = asyncio.shield(asyncio.gather(*tasks, return_exceptions=True))
# 注册任务
await task_manager.add_task(
task_id,
task_info,
main_task,
tasks
)
print("start_rtmp 1")
await task_manager.update_heartbeat(task_id)
print("start_rtmp 2")
# 等待任务完成或取消
try:
print("start_rtmp 3")
await main_task
except asyncio.CancelledError:
logger.info(f"Task {task_id} was cancelled")
raise
except Exception as e:
logger.error(f"Task {task_id} failed: {e}")
raise
finally:
# 确保所有任务都被取消
cancel_flag.set()
sleep(3)
# 释放cuda内存
detector.destroy()
print("start_rtmp 4")
for task in tasks:
print("start_rtmp 5")
task.cancel()
for upload_task in upload_tasks:
print("start_rtmp 6")
upload_task.cancel()
finally:
# 确保所有任务被取消
print("start_rtmp 7")
cancel_flag.set()
print("start_rtmp 8")
# 停止所有任务
if read_task and not read_task.done():
print("start_rtmp 9")
read_task.cancel()
if process_task and not process_task.done():
print("start_rtmp 10")
process_task.cancel()
if write_task and not write_task.done():
print("start_rtmp 11")
write_task.cancel()
if invade_task and not invade_task.done():
print("start_rtmp 12")
invade_task.cancel()
for task in upload_tasks:
if task and not task.done():
task.cancel()
print("start_rtmp 13")
# # 停止设备
# if 'device' in locals():
# await device.stop()
# print("start_rtmp 14")
# 清理资源
await cleanup_resources()
await task_manager.remove_task(task_id)
logger.info(f"Task {task_id} resources cleaned up")
print(f"start_rtmp_processing线程已停止")