3166 lines
136 KiB
Python
3166 lines
136 KiB
Python
import logging
|
||
import math
|
||
import os.path
|
||
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
from datetime import datetime
|
||
from time import sleep
|
||
from typing import List, Dict, Any, Optional, Tuple
|
||
from dataclasses import dataclass
|
||
import json
|
||
import time
|
||
|
||
from ultralytics import YOLO
|
||
import torch
|
||
|
||
# 假设其他导入模块正确存在
|
||
from cv_back_video import cal_tricker_results, get_local_drc_message, read_drc_mqtt, async_read_drc_mqtt
|
||
from middleware.MQTTService import MQTTService, MQTTDevice, empty_osd_callback
|
||
from middleware.TaskManager import task_manager
|
||
from middleware.entity.air_attitude import Air_Attitude
|
||
from middleware.entity.camera_para import read_camera_params, Camera_Para
|
||
from middleware.entity.detection import DetectionResult, DetectionResultList
|
||
from middleware.entity.timestamp_queue import TimestampedQueue
|
||
from middleware.minio_util import upload_file_from_buffer, upload_frame_buff_from_buffer, downFile, \
|
||
downBigFile, upload_video_buff_from_buffer
|
||
import av
|
||
import asyncio
|
||
import cv2
|
||
import numpy as np
|
||
from typing import Dict, Any
|
||
from ultralytics import solutions
|
||
from middleware.read_srt import parse_srt_file
|
||
from touying.ImageReproject_python.cal_func import cal_canv_location_by_osd, red_line_reproject
|
||
from touying.ImageReproject_python.img_types import Point
|
||
from yolo.detect.multi_yolo_trt_detect_track import MultiYoloTrtDetectorTrackId
|
||
|
||
# Logging setup: configure the root logger once at import time (INFO level,
# timestamp / logger name / severity / message) and create this module's logger.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
|
||
|
||
|
||
# 全局配置
|
||
class Config:
    """Module-wide tuning constants (pool sizes, queue sizes, model and stream settings)."""

    # --- thread pool sizes -------------------------------------------------
    MAX_WORKERS = 8  # image-recognition thread pool
    READ_RTMP_WORKERS = 2  # pool dedicated to reading the stream
    PROCESS_FRAME_WORKERS = 2  # pool dedicated to inference
    UPLOAD_WORKERS = 2  # pool dedicated to uploads
    WRITE_FRAME_WORKERS = 2  # pool dedicated to pushing the video stream
    INVADE_WORKERS = 2  # pool dedicated to intrusion checks
    EVENT_VIDEO_WORKERS = 2  # pool dedicated to event recordings

    # --- queues ------------------------------------------------------------
    FRAME_QUEUE_SIZE = 5  # raw frame queue capacity
    PROCESSED_QUEUE_SIZE = 5  # processed frame queue capacity
    # Queue warning threshold.
    # NOTE(review): the original comment called this "half of
    # PROCESSED_QUEUE_SIZE", which does not match the values (15 vs 5) - confirm.
    MAX_QUEUE_WARN_SIZE = 15
    MAX_FRAME_DROP = 5  # maximum tolerated dropped frames (avoids queue backlog)

    # --- model -------------------------------------------------------------
    RETRY_COUNT = 3  # model loading retries
    MODEL_INPUT_SIZE = (640, 640)  # YOLO input size (must be divisible by 32)
    # First GPU when CUDA is available, otherwise the CPU.
    DEFAULT_DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

    # --- streaming / diagnostics -------------------------------------------
    TARGET_FPS = 25  # target frame rate of the pushed stream
    PERF_LOG_INTERVAL = 5.0  # seconds between performance log dumps
|
||
|
||
|
||
# 全局性能统计变量
|
||
# Per-stage timing counters read and updated by log_perf(): one
# {'count', 'total_time'} record per pipeline stage, plus the timestamp of the
# last summary dump under 'last_log_time'.
perf_stats = {
    stage: {'count': 0, 'total_time': 0.0}
    for stage in (
        'read_rtmp_frames',
        'process_frames',
        'write_results_to_rtmp',
        'cal_des_invade',
        'send_frame_to_s3_mq',
    )
}
perf_stats['last_log_time'] = time.time()
|
||
|
||
|
||
def log_perf(func_name, start_time):
    """Record one call of ``func_name`` and periodically print a summary.

    Adds the elapsed time since ``start_time`` to the ``perf_stats`` entry of
    ``func_name``. When at least ``Config.PERF_LOG_INTERVAL`` seconds have
    passed since the previous summary, the call count and average duration of
    every entry are printed and the summary timestamp is reset.

    Args:
        func_name: key of the stage in ``perf_stats``. An unknown name gets a
            fresh entry instead of raising ``KeyError``, so a typo in an
            instrumented call site cannot crash the pipeline.
        start_time: ``time.time()`` value captured when the stage started.
    """
    # One clock read serves both the duration and the summary-interval check.
    now = time.time()

    entry = perf_stats.setdefault(func_name, {'count': 0, 'total_time': 0.0})
    entry['count'] += 1
    entry['total_time'] += now - start_time

    if now - perf_stats['last_log_time'] < Config.PERF_LOG_INTERVAL:
        return

    print("\n=== 性能统计 ===")
    for name, counters in perf_stats.items():
        if name == 'last_log_time':
            continue
        avg_time = counters['total_time'] / counters['count'] if counters['count'] > 0 else 0
        print(f"{name}: 调用次数={counters['count']}, 平均耗时={avg_time:.4f}s")
    print("===============\n")
    perf_stats['last_log_time'] = now
|
||
|
||
|
||
# # 创建专用线程池
|
||
# upload_executor = ThreadPoolExecutor(max_workers=Config.UPLOAD_WORKERS)
|
||
# frame_queue = asyncio.Queue(maxsize=Config.FRAME_QUEUE_SIZE)
|
||
# processed_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
|
||
# invade_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
|
||
# cv_frame_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
|
||
# executor = ThreadPoolExecutor(max_workers=Config.MAX_WORKERS)
|
||
# stop_event = asyncio.Event()
|
||
|
||
|
||
# Model cache (empty at import time).
model_cache = dict()

# Pipeline throughput counters; 'dropped_frames' is incremented by the frame
# processing coroutine whenever a frame is discarded.
stats = dict(
    processed=0,
    last_time=time.time(),
    avg_fps=0.0,
    dropped_frames=0,
)
|
||
|
||
|
||
# 修改 cleanup_resources 函数,确保正确取消所有协程
|
||
async def cleanup_resources(upload_executor, executor, frame_queue,
                            processed_queue, invade_queue, cv_frame_queue,
                            stop_event):
    """Signal shutdown, drain the pipeline queues and shut down the thread pools.

    Steps:
        1. Set ``stop_event``.
        2. Drain the four queues, waiting at most 5 seconds in total.
        3. Shut down ``executor`` and ``upload_executor`` (pending futures are
           cancelled, running ones are awaited).

    Args:
        upload_executor: upload thread pool, or ``None``.
        executor: general thread pool, or ``None``.
        frame_queue, processed_queue, invade_queue, cv_frame_queue: queues to drain.
        stop_event: event set first to signal shutdown.
    """
    print("cleanup_resources 1")
    stop_event.set()  # raise the stop flag before touching anything else
    print("cleanup_resources 2")

    # Drain all queues; the timeout guards against a drain that never ends.
    # A timeout must not abort the cleanup, otherwise the pools below would
    # never be shut down.
    try:
        await asyncio.wait_for(asyncio.gather(
            _empty_queue(frame_queue),
            _empty_queue(processed_queue),
            _empty_queue(invade_queue),
            _empty_queue(cv_frame_queue)
        ), timeout=5.0)
    except asyncio.TimeoutError:
        logger.warning("Timed out while draining queues; continuing with executor shutdown")

    print("cleanup_resources 8")
    # The pools arrive as parameters. The previous `'executor' in globals()`
    # guard looked for a module-level variable instead, so in the code visible
    # here the shutdown branches were never taken.
    # shutdown(wait=True) blocks, so it runs in the loop's default executor to
    # keep the event loop responsive while worker threads finish.
    loop = asyncio.get_running_loop()
    if executor is not None:
        print("cleanup_resources 9")
        await loop.run_in_executor(
            None, lambda: executor.shutdown(wait=True, cancel_futures=True))
    if upload_executor is not None:
        print("cleanup_resources 10")
        await loop.run_in_executor(
            None, lambda: upload_executor.shutdown(wait=True, cancel_futures=True))
    print("cleanup_resources 11")

    print("cleanup_resources 14")
|
||
|
||
|
||
async def _empty_queue(queue):
|
||
"""安全清空队列"""
|
||
try:
|
||
while not queue.empty():
|
||
await queue.get()
|
||
except Exception as e:
|
||
logger.warning(f"清空队列时出错: {e}")
|
||
|
||
|
||
class MultiDetectionResults:
    """Plain container holding one list per detection attribute."""

    def __init__(self):
        # Bounding boxes, class IDs and class names.
        self.boxes, self.clss, self.cls_names = [], [], []
        # English class names and confidences.
        self.cls_en_names, self.confs = [], []
        # Track IDs (the original comment mislabelled this field as confidence).
        self.track_ids = []
|
||
|
||
|
||
class MultiYOLODetector:
    """Run several YOLO models on one frame in parallel and merge their output.

    Per-model settings are stored in parallel lists indexed by the model's
    position in ``model_configs``. ``conf`` and ``func_id`` are single
    attributes overwritten by every config, so the values of the LAST config
    apply to all models.

    This detector does not track objects: every detection is reported with
    ``track_id=-1``.
    """

    def __init__(self, model_configs: List[Dict]):
        """Load every configured model onto the default device.

        Args:
            model_configs: one dict per model. ``path`` is required; optional
                keys are ``cls_map``, ``allowed_classes``, ``tracking``,
                ``cls_index``, ``chinese_label``, ``list_func_id``,
                ``func_id``, ``class_names``, ``para_invade_enable`` and
                ``config_conf``.
        """
        self.models = []
        self.class_maps = []
        self.executor = ThreadPoolExecutor(max_workers=Config.MAX_WORKERS)
        self.allowed_classes = []

        self.model_cls = []
        self.chinese_label = []
        self.list_func_id = []
        self.conf = 0.4
        self.func_id = -1
        self.list_class_names = []
        self.list_para_invade_enable = []
        self.input_size = Config.MODEL_INPUT_SIZE  # one input size for all models
        self.device = Config.DEFAULT_DEVICE  # always the default device
        self._setup_multi_gpu()

        for config in model_configs:
            model_path = config['path']
            tracking = config.get('tracking', True)  # whether to enable tracking

            # Load the model and make sure it lives on the target device.
            self.models.append(self._load_model(model_path, tracking))
            self.class_maps.append(config.get('cls_map', {}))
            self.allowed_classes.append(config.get('allowed_classes', None))
            self.model_cls.append(config.get('cls_index', []))
            self.chinese_label.append(config.get('chinese_label', []))
            self.list_func_id.append(config.get('list_func_id', []))
            # NOTE: not per-model - each config overwrites the previous value.
            self.func_id = config.get('func_id', -1)
            self.conf = config.get('config_conf', 0.4)
            self.list_class_names.append(config.get('class_names', True))
            self.list_para_invade_enable.append(config.get('para_invade_enable', False))

    def _setup_multi_gpu(self):
        """Pin CUDA work to the default device when CUDA is available."""
        if torch.cuda.is_available():
            gpu_count = torch.cuda.device_count()
            print(f"检测到 {gpu_count} 个GPU,将使用默认设备: {self.device}")
            # Force all operations onto the default device.
            torch.cuda.set_device(self.device)

    def _load_model(self, model_path: str, tracking: bool) -> torch.nn.Module:
        """Load a YOLO model onto ``self.device`` with model fusion disabled.

        Args:
            model_path: weights path handed to ``YOLO``.
            tracking: when true, the tracker config is assigned if the model
                exposes a ``tracker`` attribute. The flag used to be ignored
                here; it is now honoured like in ``MultiYOLODetector_TrackId``.

        Returns:
            The loaded model, wrapped in ``DataParallel`` restricted to the
            default device when more than one GPU is present.

        Raises:
            Exception: any loading error is printed and re-raised.
        """
        try:
            model = YOLO(model_path)

            # Move to the target device.
            model = model.to(self.device)

            if tracking:
                if hasattr(model, 'tracker'):
                    model.tracker = "botsort.yaml"  # or "bytetrack.yaml"
                    print(f"已为模型 {model_path} 启用跟踪")
                else:
                    print(f"警告:模型 {model_path} 不支持跟踪功能")

            # With several GPUs, wrap in DataParallel limited to the default
            # device; the model was already moved there above.
            if torch.cuda.is_available() and torch.cuda.device_count() > 1:
                model = torch.nn.DataParallel(model, device_ids=[self.device])

            # Replace fuse() with a no-op to avoid multi-device problems.
            model.fuse = lambda verbose=False: model

            return model
        except Exception as e:
            print(f"模型加载失败: {e}")
            raise

    @staticmethod
    def preprocess_frame(frame: np.ndarray, input_size: tuple, device: str) -> torch.Tensor:
        """Turn a BGR frame into a normalised model input tensor.

        The frame is resized straight to ``input_size`` (aspect ratio is not
        preserved), converted to RGB, scaled to [0, 1], given a batch
        dimension and moved to ``device``.
        """
        resized = cv2.resize(frame, input_size)
        rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
        tensor = torch.from_numpy(rgb).permute(2, 0, 1).float() / 255.0
        return tensor.unsqueeze(0).to(device)

    @staticmethod
    def scale_bbox(bbox: List[int], original_size: tuple, input_size: tuple) -> List[int]:
        """Scale a box from model-input coordinates back to the original image.

        Args:
            bbox: ``[x1, y1, x2, y2]`` in model-input pixels.
            original_size: ``(height, width)`` of the original frame.
            input_size: ``(width, height)`` of the model input.

        Returns:
            ``[x1, y1, x2, y2]`` in original pixels, truncated to ``int``.
        """
        oh, ow = original_size
        iw, ih = input_size

        scale_w = ow / iw
        scale_h = oh / ih

        x1, y1, x2, y2 = bbox
        return [
            int(x1 * scale_w),
            int(y1 * scale_h),
            int(x2 * scale_w),
            int(y2 * scale_h)
        ]

    async def predict(self, frame: np.ndarray) -> "Tuple[List[DetectionResult], List, List]":
        """Run all models on ``frame`` concurrently and merge the detections.

        Each model runs in ``self.executor`` on its own copy of the frame.

        Returns:
            A tuple ``(all_detections, detections_list, all_model_paras)``:
            the flat list of ``DetectionResult`` from every model, one
            ``DetectionResultList`` per model, and one parameter dict per
            model (which also carries the raw results under ``"results"``).
        """
        loop = asyncio.get_running_loop()
        original_size = (frame.shape[0], frame.shape[1])  # (height, width)

        def _predict(model_idx: int, frame: np.ndarray) -> Tuple[List[DetectionResult], Any, Any]:
            model = self.models[model_idx]
            allowed = self.allowed_classes[model_idx]

            # Preprocess and make sure the input sits on the model's device.
            input_tensor = self.preprocess_frame(frame, self.input_size, self.device)
            with torch.no_grad():
                model_device = next(model.parameters()).device
                if input_tensor.device != model_device:
                    input_tensor = input_tensor.to(model_device)
                results = model(input_tensor, verbose=False)

            model_para = {
                "cls_map": self.class_maps[model_idx],
                "model_chinese_labe": self.chinese_label[model_idx],
                "model_cls_index": self.model_cls[model_idx],
                "model_list_func_id": self.list_func_id[model_idx],
                "func_id": self.func_id,
                "model_class_names": self.list_class_names[model_idx],
                "para_invade_enable": self.list_para_invade_enable[model_idx],
                "results": results
            }

            # Unwrap DataParallel once (instead of once per box) to reach the
            # class-name table.
            original_model = model.module if isinstance(model, torch.nn.DataParallel) else model
            names = original_model.names

            detections = []
            detection_result_list = DetectionResultList([], [], [], [], [])
            for result in results:
                boxes = result.boxes.cpu().numpy()  # process on the CPU

                for box in boxes:
                    cls_id = int(box.cls[0])
                    cls_name = names[cls_id]

                    # Skip classes that are not allowed for this model.
                    if allowed and cls_name not in allowed:
                        continue

                    conf = float(box.conf[0])
                    if conf < self.conf:
                        continue

                    # Model boxes are relative to the input size; scale them
                    # back to the original frame.
                    x1, y1, x2, y2 = box.xyxy[0].astype(int)
                    scaled_bbox = self.scale_bbox(
                        [x1, y1, x2, y2],
                        original_size,
                        self.input_size
                    )

                    detection_result_list.boxes.append(scaled_bbox)
                    detection_result_list.clss.append(cls_id)
                    detection_result_list.clss_name.append(cls_name)
                    detection_result_list.confs.append(conf)
                    # No tracking in this detector: the track id is always -1.
                    detection_result_list.track_ids.append(-1)

                    detections.append(DetectionResult(
                        bbox=scaled_bbox,
                        class_id=cls_id,
                        class_name=cls_name,
                        confidence=conf,
                        track_id=-1
                    ))

            return detections, detection_result_list, model_para

        # Fan out: one executor job per model.
        futures = [
            loop.run_in_executor(
                self.executor,
                _predict,
                model_idx,
                frame.copy()
            )
            for model_idx in range(len(self.models))
        ]
        results = await asyncio.gather(*futures)

        # Merge the per-model results, keeping model order.
        all_detections = []
        detections_list = []
        all_model_paras = []
        for detections, detection_result_list, model_para in results:
            all_detections.extend(detections)
            all_model_paras.append(model_para)
            detections_list.append(detection_result_list)

        return all_detections, detections_list, all_model_paras
|
||
|
||
|
||
class MultiYOLODetector_TrackId:
    """Multi-model detector that reads boxes and track IDs from ``ObjectCounter``.

    Each configured model gets one ``solutions.ObjectCounter``. Inference is
    done by calling the counter on the raw frame; boxes, classes, confidences
    and track IDs are then read from the counter's attributes.

    ``conf`` and ``func_id`` are single attributes overwritten by every
    config, so the values of the LAST config apply to all models.

    NOTE: ``self.models`` still holds a separately loaded YOLO instance per
    config. ``predict`` only uses its length; the instances are kept because
    the attribute is public.
    """

    def __init__(self, model_configs: List[Dict]):
        """Load the models and create one ``ObjectCounter`` per config.

        Args:
            model_configs: one dict per model. ``path`` is required; optional
                keys are ``cls_map``, ``allowed_classes``, ``tracking``,
                ``cls_index``, ``chinese_label``, ``list_func_id``,
                ``func_id``, ``class_names``, ``para_invade_enable`` and
                ``config_conf``.
        """
        self.models = []
        self.class_maps = []
        self.executor = ThreadPoolExecutor(max_workers=Config.MAX_WORKERS)
        self.allowed_classes = []
        self.conf = 0.4
        self.model_cls = []
        self.chinese_label = []
        self.list_func_id = []
        self.func_id = -1
        self.list_class_names = []
        self.list_para_invade_enable = []
        self.input_size = Config.MODEL_INPUT_SIZE  # one input size for all models
        self.device = Config.DEFAULT_DEVICE  # always the default device
        self._setup_multi_gpu()

        # One ObjectCounter per model, same order as self.models.
        self.object_counters = []

        for config in model_configs:
            model_path = config['path']
            cls_map = config.get('cls_map', {})
            allowed = config.get('allowed_classes', None)
            tracking = config.get('tracking', True)  # whether to enable tracking

            # NOTE(review): these fallbacks (True / -11 / True) differ from the
            # ones in MultiYOLODetector ([] / [] / -1) - confirm they are intended.
            model_cls_index = config.get('cls_index', True)
            model_chinese_labe = config.get('chinese_label', {})
            model_list_func_id = config.get('list_func_id', -11)
            func_id = config.get('func_id', True)
            class_names = config.get('class_names', True)
            para_invade_enable = config.get('para_invade_enable', False)
            config_conf = config.get('config_conf', 0.4)

            # Load the model and make sure it lives on the target device.
            model = self._load_model(model_path, tracking)
            self.models.append(model)
            self.class_maps.append(cls_map)
            self.allowed_classes.append(allowed)
            self.model_cls.append(model_cls_index)
            self.chinese_label.append(model_chinese_labe)
            self.list_func_id.append(model_list_func_id)
            # NOTE: not per-model - each config overwrites the previous value.
            self.func_id = func_id
            self.conf = config_conf
            self.list_class_names.append(class_names)
            self.list_para_invade_enable.append(para_invade_enable)

            # The counter was built with identical arguments whether or not
            # the model exposes a ``tracker`` attribute, so the two duplicated
            # branches are merged into one call.
            self.object_counters.append(
                solutions.ObjectCounter(
                    show=False,
                    region=None,  # a detection region can be set here if needed
                    model=model_path,
                    classes=model_cls_index if allowed else None,
                    verbose=False  # suppress per-call performance logging
                )
            )

    def _setup_multi_gpu(self):
        """Pin CUDA work to the default device when CUDA is available."""
        if torch.cuda.is_available():
            gpu_count = torch.cuda.device_count()
            print(f"检测到 {gpu_count} 个GPU,将使用默认设备: {self.device}")
            # Force all operations onto the default device.
            torch.cuda.set_device(self.device)

    def _load_model(self, model_path: str, tracking: bool) -> torch.nn.Module:
        """Load a YOLO model onto ``self.device`` with model fusion disabled.

        Args:
            model_path: weights path handed to ``YOLO``.
            tracking: when true, the tracker config is assigned if the model
                exposes a ``tracker`` attribute. When false nothing is
                printed; the "does not support tracking" warning is now
                reserved for models that really lack the attribute.

        Returns:
            The loaded model, wrapped in ``DataParallel`` restricted to the
            default device when more than one GPU is present.

        Raises:
            Exception: any loading error is printed and re-raised.
        """
        try:
            model = YOLO(model_path)

            # Move to the target device.
            model = model.to(self.device)

            if tracking:
                if hasattr(model, 'tracker'):
                    model.tracker = "botsort.yaml"  # or "bytetrack.yaml"
                    print(f"已为模型 {model_path} 启用跟踪")
                else:
                    print(f"警告:模型 {model_path} 不支持跟踪功能")

            # With several GPUs, wrap in DataParallel limited to the default
            # device; the model was already moved there above.
            if torch.cuda.is_available() and torch.cuda.device_count() > 1:
                model = torch.nn.DataParallel(model, device_ids=[self.device])

            # Replace fuse() with a no-op to avoid multi-device problems.
            model.fuse = lambda verbose=False: model

            return model
        except Exception as e:
            print(f"模型加载失败: {e}")
            raise

    @staticmethod
    def preprocess_frame(frame: np.ndarray, input_size: tuple, device: str) -> torch.Tensor:
        """Letterbox a BGR frame into a normalised model input tensor.

        The frame is resized with its aspect ratio preserved, pasted into the
        top-left corner of a black ``input_size`` canvas, converted to RGB,
        scaled to [0, 1], given a batch dimension and moved to ``device``.

        Args:
            frame: BGR image of shape (height, width, 3).
            input_size: ``(width, height)`` of the model input.
            device: target torch device.
        """
        original_h, original_w = frame.shape[:2]

        # Uniform scale that fits the frame inside the input size.
        scale = min(input_size[0] / original_w, input_size[1] / original_h)
        new_w = int(original_w * scale)
        new_h = int(original_h * scale)

        resized = cv2.resize(frame, (new_w, new_h))

        # Black canvas of the model input size; image anchored top-left.
        canvas = np.zeros((input_size[1], input_size[0], 3), dtype=np.uint8)
        canvas[:new_h, :new_w] = resized

        rgb = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
        tensor = torch.from_numpy(rgb).permute(2, 0, 1).float() / 255.0
        return tensor.unsqueeze(0).to(device)

    @staticmethod
    def scale_bbox(bbox, original_size, input_size):
        """Scale a box from model-input coordinates to the original image and clamp it.

        Args:
            bbox: ``[x1, y1, x2, y2]`` in model-input pixels.
            original_size: ``(width, height)`` of the original frame.
            input_size: ``(width, height)`` of the model input.

        Returns:
            ``[x1, y1, x2, y2]`` as ints, clamped to the original image bounds.

        NOTE(review): this uses independent x/y ratios, while
        ``preprocess_frame`` letterboxes with one uniform scale; for
        non-square frames the two are not inverses of each other. Nothing in
        this class calls the method - confirm the intended pairing before
        relying on it.
        """
        ow, oh = original_size  # (width, height)
        iw, ih = input_size  # (width, height)

        scale_w = ow / iw
        scale_h = oh / ih

        x1, y1, x2, y2 = bbox
        # Scale, then clamp into the image.
        x1 = int(max(0, min(x1 * scale_w, ow - 1)))
        y1 = int(max(0, min(y1 * scale_h, oh - 1)))
        x2 = int(max(0, min(x2 * scale_w, ow - 1)))
        y2 = int(max(0, min(y2 * scale_h, oh - 1)))

        return [x1, y1, x2, y2]

    async def predict(self, frame: np.ndarray) -> "Tuple[List[DetectionResult], List, List]":
        """Run every model's ``ObjectCounter`` on ``frame`` concurrently and merge.

        Each model runs in ``self.executor`` on its own copy of the frame.

        Returns:
            A tuple ``(all_detections, detections_list, all_model_paras)``:
            the flat list of ``DetectionResult`` from every model, one
            ``DetectionResultList`` per model, and one parameter dict per
            model. A model that produced no boxes, or whose prediction
            failed, contributes an empty ``DetectionResultList`` and an
            empty dict.
        """
        loop = asyncio.get_running_loop()

        def _predict(model_idx: int, frame: np.ndarray) -> Tuple[List[DetectionResult], Any, Any]:
            try:
                object_counter = self.object_counters[model_idx]
                allowed = self.allowed_classes[model_idx]

                # The counter consumes the raw frame. The tensor that used to
                # be built here (resize + colour conversion + device upload on
                # every frame) was never passed to anything and was removed.
                with torch.no_grad():
                    results = object_counter(frame)

                # Debug output describing the structure of the returned object.
                if hasattr(results, 'boxes'):
                    print(
                        f"Boxes type: {type(results.boxes)}, length: {len(results.boxes) if hasattr(results.boxes, '__len__') else 'N/A'}")

                model_para = {
                    "cls_map": self.class_maps[model_idx],
                    "model_chinese_labe": self.chinese_label[model_idx],
                    "model_cls_index": self.model_cls[model_idx],
                    "model_list_func_id": self.list_func_id[model_idx],
                    "func_id": self.func_id,
                    "model_class_names": self.list_class_names[model_idx],
                    "para_invade_enable": self.list_para_invade_enable[model_idx],
                    "results": object_counter  # the counter itself, kept for debugging
                }

                # From here on the per-frame data is read from the counter's
                # attributes, not from the call's return value.
                results = object_counter
                detection_result_list = DetectionResultList([], [], [], [], [])
                tracks = results.track_ids
                confs = results.confs
                names = results.names
                result_boxes = results.boxes

                if isinstance(result_boxes, list):
                    if len(result_boxes) == 0:
                        return [], DetectionResultList([], [], [], [], []), {}
                    boxes = result_boxes
                else:
                    boxes = result_boxes.tolist()

                detections = []
                for i, box in enumerate(boxes):
                    try:
                        # The box must be iterable.
                        if not hasattr(box, '__iter__'):
                            print(f"无效的边界框格式: {box}")
                            continue

                        box = list(box) if not isinstance(box, list) else box

                        # Four coordinates are required.
                        if len(box) < 4:
                            print(f"不完整的边界框坐标: {box}")
                            continue

                        x1, y1, x2, y2 = box[:4]
                        cls_id = int(results.clss[i])
                        cls_name = names[cls_id]

                        conf = float(confs[i]) if i < len(confs) else 0.0
                        if conf < self.conf:
                            continue
                        track_id = int(tracks[i]) if i < len(tracks) and tracks[i] is not None else None

                        # Skip classes that are not allowed for this model.
                        if allowed and cls_name not in allowed:
                            continue

                        scaled_bbox = [int(x1), int(y1), int(x2), int(y2)]
                        # The list keeps the coordinates as delivered; the
                        # DetectionResult gets the int-truncated box.
                        detection_result_list.boxes.append([x1, y1, x2, y2])
                        detection_result_list.clss.append(cls_id)
                        detection_result_list.clss_name.append(cls_name)
                        detection_result_list.confs.append(conf)
                        detection_result_list.track_ids.append(track_id)

                        detections.append(DetectionResult(
                            bbox=scaled_bbox,
                            class_id=cls_id,
                            class_name=cls_name,
                            confidence=conf,
                            track_id=track_id
                        ))
                    except Exception as e:
                        # One malformed detection must not discard the rest.
                        print(f"处理单个检测结果时出错: {str(e)}")
                        continue

                return detections, detection_result_list, model_para

            except Exception as e:
                print(f"模型 {model_idx} 预测过程中发生错误: {str(e)}")
                import traceback
                traceback.print_exc()
                return [], DetectionResultList([], [], [], [], []), {}

        # Fan out: one executor job per model.
        futures = [
            loop.run_in_executor(
                self.executor,
                _predict,
                model_idx,
                frame.copy()
            )
            for model_idx in range(len(self.models))
        ]
        per_model = await asyncio.gather(*futures)

        # Merge the per-model results, keeping model order.
        all_detections = []
        detections_list = []
        all_model_paras = []
        for detections, detection_result_list, model_para in per_model:
            all_detections.extend(detections)
            all_model_paras.append(model_para)
            detections_list.append(detection_result_list)

        return all_detections, detections_list, all_model_paras
|
||
|
||
|
||
async def read_video_frames(task_id, mqtt, mqtt_publish_topic,
                            local_video_path: str, srt_path: str, frame_queue: asyncio.Queue,
                            timestamp_frame_queue: TimestampedQueue,
                            cancel_flag: asyncio.Event = None, ):
    """Read a local video at its native frame rate and feed frames into ``frame_queue``.

    Each frame that has a matching SRT entry is paired with an
    ``Air_Attitude`` built from that entry and enqueued as
    ``(frame, attitude, time_ns)``; the same frame is also appended to
    ``timestamp_frame_queue``. Status ``1`` is published over MQTT before
    reading starts and status ``2`` when the video ends; at end of video the
    task is also removed from ``task_manager``. After the loop exits normally
    the SRT file and the video file are deleted.

    Args:
        task_id: task identifier, included in status messages.
        mqtt: client exposing an awaitable ``publish(topic, payload)``.
        mqtt_publish_topic: topic for the status messages.
        local_video_path: path of the video file (deleted afterwards).
        srt_path: path of the SRT telemetry file (deleted afterwards).
        frame_queue: destination queue; a frame is dropped when it is full.
        timestamp_frame_queue: receives ``{"timestamp", "frame"}`` dicts.
        cancel_flag: stops reading when set. ``None`` means "never
            cancelled" (the previous code crashed on ``None``).

    Raises:
        Exception: a failure to publish a status message is logged and
            re-raised; the capture is released in that case as well.
    """
    if cancel_flag is None:
        cancel_flag = asyncio.Event()

    srt_list = parse_srt_file(srt_path)

    # Open the capture once. It used to be opened a second time further down,
    # leaving the first handle unreleased.
    cap = cv2.VideoCapture(local_video_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        cap.release()
        return

    async def publist_status(status_to_publish):
        """Publish the video processing status (1 = started, 2 = finished)."""
        try:
            message = {"task_id": task_id, "video_status": status_to_publish}
            await mqtt.publish(mqtt_publish_topic, json.dumps(message, ensure_ascii=False))
        except Exception as e:
            logger.error(f"Failed to publish status: {e}")
            raise

    frame_count = 0
    try:
        await publist_status(1)

        # Throttle reading to the video's own frame rate so the queue does
        # not back up and force large frame drops.
        video_fps = cap.get(cv2.CAP_PROP_FPS)
        if not video_fps or video_fps <= 0:
            # OpenCV reports 0 when the rate is unknown; avoid dividing by zero.
            video_fps = 25.0
        frame_interval = 1.0 / video_fps
        last_read_time = time.time()

        while not cancel_flag.is_set() and cap.isOpened():
            ret, frame = cap.read()
            if not ret:  # end of video or read error
                await publist_status(2)
                # The recording itself ends its task.
                await task_manager.remove_task(task_id)
                break

            # Keep at least frame_interval between consecutive frames. The
            # reference timestamp is taken AFTER sleeping; storing the
            # pre-sleep time made only every other frame wait.
            elapsed = time.time() - last_read_time
            if elapsed < frame_interval:
                await asyncio.sleep(frame_interval - elapsed)
            last_read_time = time.time()

            frame_count += 1

            if frame_queue.full():
                await asyncio.sleep(0.01)  # back off briefly; this frame is dropped
                continue

            # NOTE(review): frame_count is already 1 for the first frame, so
            # the first frame is paired with srt_list[1] and srt_list[0] is
            # never used - confirm whether this offset is intended.
            if frame_count < len(srt_list):
                dji_srt = srt_list[frame_count]
                art_tit = Air_Attitude(dji_srt.gb_pitch, dji_srt.gb_roll, dji_srt.gb_yaw,
                                       dji_srt.abs_alt, dji_srt.latitude, dji_srt.longitude)
                time_ns = time.time_ns()
                await frame_queue.put((frame, art_tit, time_ns))
                timestamp_frame_queue.append({
                    "timestamp": time_ns,
                    "frame": frame
                })
    finally:
        # Release the capture on every exit path, including exceptions.
        cap.release()

    # Remove the processed inputs (normal exit only, as before).
    if os.path.exists(srt_path):
        os.remove(srt_path)
    if os.path.exists(local_video_path):
        os.remove(local_video_path)
|
||
|
||
|
||
async def read_rtmp_frames(
        loop,
        read_rtmp_frames_executor: ThreadPoolExecutor,
        video_url: str,
        device: Optional[MQTTDevice] = None,
        topic_camera_osd: Optional[str] = None,
        method_camera_osd: Optional[str] = None,
        topic_osd_info: Optional[str] = None,
        method_osd_info: Optional[str] = None,
        cancel_flag: Optional[asyncio.Event] = None,
        frame_queue: asyncio.Queue = None,
        timestamp_frame_queue: TimestampedQueue = None
):
    """Read an RTMP stream and feed decoded BGR frames into ``frame_queue``.

    The stream is opened and decoded in ``read_rtmp_frames_executor`` so the
    event loop is not blocked. For every frame the latest OSD message is read
    from ``device``; a frame is enqueued as ``(img, Air_Attitude, time_ns)``
    only when OSD data is available, so without ``device`` / OSD topic /
    method nothing is enqueued. Frames are dropped while the queue is full.
    The connection is attempted up to 20 times in total, with a 2 second
    pause after an error.

    Args:
        loop: event loop used for ``run_in_executor``.
        read_rtmp_frames_executor: pool running the blocking PyAV calls.
        video_url: RTMP URL handed to ``av.open``.
        device: MQTT device queried with ``get_latest_message``.
        topic_camera_osd, method_camera_osd: accepted but unused here.
        topic_osd_info, method_osd_info: topic/method of the OSD message.
        cancel_flag: stops reading when set; a fresh event is used for ``None``.
        frame_queue: destination queue (required).
        timestamp_frame_queue: optional; receives ``{"timestamp", "frame"}``.

    Raises:
        ValueError: ``frame_queue`` is ``None``.
        RuntimeError: the last allowed attempt failed with a stream/IO error.
        asyncio.CancelledError: re-raised when the coroutine is cancelled.
    """
    if frame_queue is None:
        # Fail fast: without a queue every frame used to hit the per-frame
        # error handler and nothing was ever delivered.
        raise ValueError("frame_queue must be provided")

    max_retries = 20
    retry_delay = 2
    attempt = 0
    pic_count = 0  # frames seen in the current one-second window
    frame_count = 0  # frames seen over the whole call
    stream_start = time.time_ns()  # never reset; basis of the final statistics
    time_start = stream_start  # start of the current one-second window

    if cancel_flag is None:
        cancel_flag = asyncio.Event()

    # Use av.AVError where the installed PyAV still provides it, otherwise
    # the FFmpegError base class.
    av_error = getattr(av, "AVError", None) or av.error.FFmpegError

    print(f"开始读取RTMP流: {video_url}")

    while not cancel_flag.is_set() and attempt < max_retries:
        attempt += 1

        container = None
        try:
            logger.info(f"尝试连接 RTMP 流 (尝试 {attempt}/{max_retries}): {video_url}")
            # 1. Opening the stream is blocking: run it in the thread pool.
            container = await loop.run_in_executor(read_rtmp_frames_executor, av.open, video_url)
            # next() with a default: StopIteration must not be raised into a
            # Future when the stream has no video track.
            video_stream = next((s for s in container.streams if s.type == 'video'), None)
            if video_stream is None:
                raise IOError(f"No video stream found in: {video_url}")
            logger.info(f"成功连接到 RTMP 流: {video_url} ({video_stream.width}x{video_stream.height})")

            # 2. Fetch one OSD message up front to check that MQTT delivers data.
            if device and topic_osd_info and method_osd_info:
                osd_msg = device.get_latest_message(topic=topic_osd_info, method=method_osd_info)
                if osd_msg and hasattr(osd_msg, 'data'):
                    logger.info(f"初始OSD消息获取成功: 高度={osd_msg.data.height}")
                else:
                    logger.warning("初始OSD消息为空,可能MQTT尚未收到消息")

            # 3. Decode in the thread pool and hand frames back one at a time.
            async def async_frame_generator():
                """Yield BGR frames decoded in the executor."""

                def sync_frame_iter():
                    try:
                        for frame in container.decode(video=0):
                            # Checked inside the worker so it can stop promptly.
                            if cancel_flag.is_set():
                                logger.info("后台线程检测到取消信号,停止帧迭代")
                                break

                            if len(frame.planes) == 1:
                                # Single-plane (grayscale) frame: expand to 3-channel BGR.
                                gray = frame.to_ndarray(format='gray')
                                yield cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
                            else:
                                # Keep the size; only convert the pixel format.
                                yield frame.to_ndarray(format='bgr24')
                    except Exception as e:
                        logger.error(f"同步帧迭代出错: {e}")
                    finally:
                        if container:
                            container.close()
                            logger.info("RTMP容器已关闭")

                gen = sync_frame_iter()
                while not cancel_flag.is_set():
                    try:
                        # One executor hop per frame; None marks the end of the stream.
                        frame = await loop.run_in_executor(read_rtmp_frames_executor, next, gen, None)
                    except Exception as e:
                        logger.error(f"异步获取帧出错: {e}")
                        break
                    if frame is None:
                        break
                    yield frame

            # 4. Consume frames without blocking the event loop.
            async for frame in async_frame_generator():
                if cancel_flag.is_set():
                    logger.info("检测到取消信号,停止读取帧")
                    break

                try:
                    img = frame.copy()  # work on a private copy of the decoded frame
                    osd_info = None

                    # 5. Attach the most recent OSD attitude, if any.
                    if device and topic_osd_info and method_osd_info:
                        osd_msg = device.get_latest_message(topic=topic_osd_info, method=method_osd_info)
                        if osd_msg and hasattr(osd_msg, 'data'):
                            osd_info = Air_Attitude(
                                gimbal_pitch=osd_msg.data.gimbal_pitch,
                                gimbal_roll=osd_msg.data.gimbal_roll,
                                gimbal_yaw=osd_msg.data.gimbal_yaw,
                                height=osd_msg.data.height,
                                latitude=osd_msg.data.latitude,
                                longitude=osd_msg.data.longitude
                            )

                    # 6. Enqueue unless the queue is full.
                    if not frame_queue.full():
                        pic_count += 1
                        frame_count += 1
                        time_ns = time.time_ns()

                        # Print the per-second frame counter once a second.
                        if time_ns - time_start > 1000000000:
                            print(f"readFrames {pic_count}")
                            pic_count = 0
                            time_start = time_ns
                        if img is not None and osd_info is not None:
                            await frame_queue.put((img, osd_info, time_ns))
                            if timestamp_frame_queue is not None:
                                timestamp_frame_queue.append({
                                    "timestamp": time_ns,
                                    "frame": img
                                })
                            # Cumulative count: frame_count (pic_count resets every second).
                            logger.debug(
                                f"已放入帧队列,累计帧数: {frame_count},队列剩余空间: {frame_queue.maxsize - frame_queue.qsize()}")
                    else:
                        logger.warning("帧队列已满,等待1ms后重试")
                        await asyncio.sleep(0.001)

                except Exception as frame_error:
                    logger.error(f"处理单帧时出错: {frame_error}", exc_info=True)
                    continue

        except (av_error, IOError) as e:
            logger.error(f"RTMP 流错误 (尝试 {attempt}/{max_retries}): {e}")
            if attempt < max_retries:
                await asyncio.sleep(retry_delay)
            else:
                raise RuntimeError(f"无法连接 RTMP 流 (尝试 {max_retries} 次后失败): {video_url}")
        except asyncio.CancelledError:
            logger.info("read_rtmp_frames 收到取消信号")
            raise
        except Exception as e:
            logger.error(f"未知错误: {e}", exc_info=True)
            if attempt < max_retries:
                await asyncio.sleep(retry_delay)
        finally:
            # Second line of defence: close the container if the worker has
            # not done so. getattr tolerates container types without a
            # ``closed`` attribute.
            if container is not None and not getattr(container, "closed", False):
                await loop.run_in_executor(None, container.close)
                logger.info("RTMP容器在finally中关闭")

    # Final statistics, measured from the start of the call. They used to be
    # measured from a timestamp that is reset every second.
    if frame_count > 0:
        total_time = (time.time_ns() - stream_start) / 1e9
        avg_fps = frame_count / total_time if total_time > 0 else 0
        print(f"RTMP流读取完成,总帧数: {frame_count}, 总时间: {total_time:.2f}秒, 平均FPS: {avg_fps:.2f}")
    else:
        print("RTMP流读取失败,未获取到任何帧")

    logger.info(f"RTMP 流已结束或被取消,累计处理帧数: {frame_count}")
|
||
|
||
|
||
# async def process_frames(detector: MultiYOLODetector):
|
||
# async def process_frames(detector: MultiYOLODetector_TrackId, cancel_flag: asyncio.Event,
|
||
# frame_queue: asyncio.Queue, processed_queue: asyncio.Queue):
|
||
async def process_frames(detector: MultiYoloTrtDetectorTrackId, cancel_flag: asyncio.Event,
                         frame_queue: asyncio.Queue, processed_queue: asyncio.Queue):
    """Run detection on queued frames and forward the results.

    Repeatedly takes a ``(frame, osd_info, timestamp)`` tuple from
    ``frame_queue``, awaits ``detector.predict(frame)`` and puts a result dict
    on ``processed_queue``. If ``processed_queue`` is full, or prediction
    raises, the frame is dropped and ``stats['dropped_frames']`` is
    incremented. The loop ends once ``cancel_flag`` is set.

    Args:
        detector: Object whose awaitable ``predict(frame)`` returns
            ``(detections, detections_list, model_para)``.
        cancel_flag: Event that stops the loop once set.
        frame_queue: Source queue of ``(frame, osd_info, timestamp)`` tuples.
        processed_queue: Destination queue. Each item is a dict with the keys
            ``frame``, ``osd_info``, ``detections``, ``detections_list``,
            ``timestamp``, ``model_para`` and ``predict_state``.

    Raises:
        asyncio.CancelledError: Re-raised when the task is cancelled.
    """
    window_start_ns = time.time_ns()  # start of the current one-second throughput window
    frames_in_window = 0
    while not cancel_flag.is_set():
        try:
            # Bounded wait so the cancel flag is re-checked even when no frames arrive.
            frame, osd_info, timestamp = await asyncio.wait_for(
                frame_queue.get(),
                timeout=0.5,
            )

            try:
                predict_started_ns = time.time_ns()
                detections, detections_list, model_para = await detector.predict(frame)
                predict_finished_ns = time.time_ns()
                print(f"time_pr_starttime_pr_start {(predict_finished_ns - predict_started_ns) / 1000000}")

                # predict_state tells the RTMP writer whether there is anything to draw.
                predict_state = bool(detections)
                if predict_state:
                    print("检测到任何目标")
                else:
                    logger.debug("未检测到任何目标")

                processed_data = {
                    'frame': frame,
                    'osd_info': osd_info,
                    'detections': detections,
                    'detections_list': detections_list,
                    'timestamp': timestamp,
                    'model_para': model_para,
                    'predict_state': predict_state,
                }

                if processed_queue.full():
                    # Drop rather than block: a stalled consumer must not stall inference.
                    logger.warning("处理队列已满,丢弃帧")
                    stats['dropped_frames'] += 1
                else:
                    now_ns = time.time_ns()
                    frames_in_window += 1
                    if now_ns - window_start_ns > 1000000000:
                        # Report how many frames were forwarded during the last second.
                        print(f"processframes {frames_in_window}")
                        frames_in_window = 0
                        window_start_ns = now_ns

                    await processed_queue.put(processed_data)
            except Exception as e:
                logger.error(f"处理帧时出错: {e}", exc_info=True)
                stats['dropped_frames'] += 1
                await asyncio.sleep(0.1)

        except asyncio.TimeoutError:
            continue
        except asyncio.CancelledError:
            print("process_frames 收到取消信号")
            raise
        except Exception as e:
            logger.error(f"获取帧时发生意外错误: {e}", exc_info=True)
            await asyncio.sleep(0.1)
    print("process_frames读取线程已停止")
|
||
|
||
|
||
from PIL import Image, ImageDraw, ImageFont
|
||
|
||
# Preload the font once at import time; put_chinese_text uses it as its default.
try:
    chinese_font = ImageFont.truetype("config/SIMSUN.TTC", 60)
except Exception as font_error:
    # Best effort: any load failure falls back to PIL's built-in font, but the
    # reason is logged and KeyboardInterrupt/SystemExit are no longer swallowed.
    logger.warning("加载字体 config/SIMSUN.TTC 失败,使用默认字体: %s", font_error)
    chinese_font = ImageFont.load_default()
|
||
|
||
|
||
def put_chinese_text(img, text, position, font=chinese_font, color=(0, 255, 0)):
    """Draw ``text`` on a copy of ``img`` with a PIL font and return the result.

    The image is converted BGR -> RGB, rendered through PIL, then converted
    back RGB -> BGR. ``color`` is passed to PIL as the fill for the RGB image.

    Args:
        img: Image array accepted by ``cv2.cvtColor(..., cv2.COLOR_BGR2RGB)``.
        text: String to draw.
        position: ``(x, y)`` passed to ``ImageDraw.text``.
        font: PIL font; defaults to the font preloaded at import time.
        color: Fill colour passed to ``ImageDraw.text``.

    Returns:
        A new array holding the rendered image, converted back with
        ``cv2.COLOR_RGB2BGR``.
    """
    rgb_pixels = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    canvas = Image.fromarray(rgb_pixels)
    ImageDraw.Draw(canvas).text(position, text, font=font, fill=color)
    rendered = np.array(canvas)
    return cv2.cvtColor(rendered, cv2.COLOR_RGB2BGR)
|
||
|
||
|
||
# # 全局变量存储推流容器
|
||
# stream_containers: Dict[str, Any] = {}
|
||
|
||
from collections import defaultdict
|
||
import time
|
||
from typing import Dict, Any
|
||
|
||
|
||
class TrackIDEventFilter:
    """Suppress repeated events for the same track id.

    A track id is reported the first time it is seen. Later sightings are
    suppressed until the id has gone unseen for more than
    ``max_inactive_time`` seconds; it is then forgotten, and its next sighting
    is reported again.
    """

    def __init__(self, max_inactive_time: float = 5.0):
        """Create an empty filter.

        Args:
            max_inactive_time: Seconds a track id may go unseen before it is
                forgotten.
        """
        # track_id -> number of sightings since it was first seen (or last forgotten)
        self.track_status = defaultdict(int)
        # track_id -> time.time() of the most recent sighting
        self.last_active_time = defaultdict(float)
        self.max_inactive_time = max_inactive_time

    def should_report(self, track_id: int) -> bool:
        """Record a sighting of ``track_id`` and say whether to report it.

        Args:
            track_id: Identifier of the tracked object.

        Returns:
            True on the first sighting of an id (including the first sighting
            after it was forgotten for inactivity), False otherwise.
        """
        log = logging.getLogger(__name__)
        now = time.time()

        # 1. Forget every id whose last sighting is older than the inactivity
        #    window. This runs before the current id is refreshed, so a stale
        #    current id is forgotten too and will be reported again below.
        expired = [
            tid for tid, seen_at in self.last_active_time.items()
            if (now - seen_at) > self.max_inactive_time
        ]
        for tid in expired:
            log.debug("track_id %s inactive for more than %ss, forgetting it", tid, self.max_inactive_time)
            self.track_status.pop(tid, None)
            self.last_active_time.pop(tid, None)

        # 2. Refresh the last-seen time and count this sighting.
        self.last_active_time[track_id] = now
        self.track_status[track_id] += 1
        sightings = self.track_status[track_id]
        log.debug("track_id %s sighting #%s", track_id, sightings)

        # 3. Only the first sighting of an active streak is reported.
        return sightings == 1
|
||
|
||
|
||
async def write_results_to_rtmp(task_id: str, output_url: str = None, input_fps: float = None,
                                list_points: list[list[any]] = None, camera_para: Camera_Para = None,
                                invade_state: bool = False, cancel_flag: asyncio.Event = None,
                                processed_queue: asyncio.Queue = None, invade_queue: asyncio.Queue = None,
                                cv_frame_queue: asyncio.Queue = None, stream_containers: Dict[str, Any] = None):
    """Annotate processed frames, hand them to the event queues and push them to RTMP.

    For every dict taken from ``processed_queue`` (1 s timeout per attempt):

    1. When ``invade_state`` is set and OSD data is present, each entry of
       ``list_points`` is reprojected with ``red_line_reproject`` and drawn as
       a closed red polyline.
    2. When ``predict_state`` is true, detections are filtered and boxed, and a
       ``new_data`` dict is put on ``invade_queue`` or ``cv_frame_queue``
       (skipped when the target queue is full).
    3. The annotated frame is encoded with ``libx264`` and muxed into an FLV
       container opened on ``output_url``. Containers are cached in
       ``stream_containers`` keyed by URL, and are closed and removed on a
       push error so the next frame re-opens them.

    On exit every container in ``stream_containers`` is closed and the dict is
    cleared.

    NOTE(review): this copy of the function had lost its indentation; the
    nesting below was reconstructed from syntax and the surrounding comments.
    Spots where more than one nesting was plausible carry a NOTE(review) —
    confirm them against version control.

    Args:
        task_id: Task identifier (not used in the body).
        output_url: RTMP target; when falsy no stream is opened.
        input_fps: Source frame rate; falls back to ``Config.TARGET_FPS``.
        list_points: Regions passed one by one to ``red_line_reproject``.
        camera_para: Camera parameters passed to ``red_line_reproject``.
        invade_state: Enables red-line drawing and intrusion filtering.
        cancel_flag: Event that stops the loop once set.
        processed_queue: Source queue of result dicts.
        invade_queue: Destination for intrusion events.
        cv_frame_queue: Destination for ordinary detection events.
        stream_containers: Mutable cache of open output containers.

    Raises:
        asyncio.CancelledError: Re-raised when the task is cancelled.
    """
    start_time = time.time()
    time_start = time.time_ns()
    pic_count = 0
    # Encoder options applied to the output stream.
    options = {
        'preset': 'veryfast',
        'tune': 'zerolatency',
        'crf': '23',
        'g': '50',  # key-frame interval
        'threads': '2',  # limit encoder threads
    }
    codec_name = 'libx264'

    max_retries = 3
    retry_delay = 2.0

    # Optional local video output (test aid). output_video_path stays None
    # here, so the writer below is never created.
    output_video_path = None
    video_writer = None
    frame_width, frame_height = None, None
    fps = input_fps or Config.TARGET_FPS
    last_frame_time = time.time() - 1
    frame_interval = 1.0 / fps
    try:
        while not cancel_flag.is_set():
            frame_start = time.time()
            try:
                processed_data = await asyncio.wait_for(
                    processed_queue.get(),
                    timeout=1
                )

                # Make sure processed_data is a dict before indexing it.
                if not isinstance(processed_data, dict):
                    print(f"❌ 错误:processed_data 不是字典,而是 {type(processed_data)}")
                    continue

                frame = processed_data['frame']
                # Draw on a copy so the original frame stays untouched.
                frame_copy = frame.copy()
                predict_state = processed_data['predict_state']
                osd_info = processed_data['osd_info']
                img_height, img_width = frame.shape[:2]

                results = []
                results_list = []
                # Draw the red lines only when intrusion checking is enabled
                # and the aircraft attitude (OSD) is available.
                if invade_state and osd_info:
                    gimbal_yaw = osd_info.gimbal_yaw
                    gimbal_pitch = osd_info.gimbal_pitch
                    gimbal_roll = osd_info.gimbal_roll
                    height = osd_info.height
                    cam_longitude = osd_info.longitude
                    cam_latitude = osd_info.latitude
                    # list_points is two-dimensional but currently holds a
                    # single region; this may change with business needs.

                    for points in list_points:
                        # Pixel coordinates of this region, returned in one batch.
                        point_list = []
                        results = red_line_reproject(gimbal_yaw, gimbal_pitch, gimbal_roll, height, cam_longitude,
                                                     cam_latitude,
                                                     img_width,
                                                     img_height, points, camera_para)
                        if results:
                            # Supports two regions: high-voltage and operating-line intrusion.
                            results_list.append(results)
                            for point in results:
                                point_list.append([point["u"], point["v"]])
                            # NOTE(review): assumed to run once per non-empty region.
                            cv2.polylines(frame_copy, [np.array(point_list, dtype=np.int64)], isClosed=True,
                                          color=(0, 0, 255),
                                          thickness=2)
                # The model produced an inference result.
                if predict_state:
                    # Test aid: lazily create the local video writer.
                    if video_writer is None and output_video_path:
                        frame_width = frame.shape[1]
                        frame_height = frame.shape[0]
                        # Video codec and output file.
                        fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # 'avc1' or another codec also works
                        video_writer = cv2.VideoWriter(
                            output_video_path,
                            fourcc,
                            fps,
                            (frame_width, frame_height)
                        )
                        print(f"视频写入器已初始化,分辨率: {frame_width}x{frame_height}, FPS: {fps}")

                    detections = processed_data['detections']
                    detections_list = processed_data['detections_list']
                    model_para = processed_data['model_para']

                    # Only the first model's parameters are consulted.
                    class_names = model_para[0]["model_class_names"]
                    chinese_label = model_para[0]["model_chinese_labe"]
                    cls_map = model_para[0]["cls_map"]
                    para_invade_enable = model_para[0]["para_invade_enable"]
                    model_list_func_id = model_para[0]["model_list_func_id"]
                    model_func_id = model_para[0]["func_id"]
                    invade_point = []
                    message_point = []
                    target_point = []  # image coordinates kept for the later lon/lat conversion
                    cls_count = 0

                    # Per-class counters.
                    class_stats = defaultdict(int)
                    reversed_dict = {value: 0 for value in cls_map.values()}
                    bg_color = (173, 216, 230)  # light-blue text background
                    text_color = (0, 255, 0)  # green

                    for det in detections:
                        x1, y1, x2, y2 = map(int, det.bbox)  # make sure coordinates are integers
                        cls_id = det.class_id
                        class_name = det.class_name
                        confidence = det.confidence
                        track_id = det.track_id
                        new_track_id = track_id * 100 + cls_id  # original note: class id is below 100 or negative
                        # Update the per-class statistics.
                        class_stats[cls_id] += 1
                        # Box centre; with intrusion enabled only boxes whose
                        # centre lies inside a red-line region are kept.
                        point_x = (x1 + x2) / 2
                        point_y = (y1 + y2) / 2
                        print(f"class_name--{class_name}")
                        print(f"model_class_names: {model_para[0]['model_class_names']}")

                        if class_name not in model_para[0]["model_class_names"]:
                            continue

                        en_name = model_para[0]["model_chinese_labe"][
                            model_para[0]["model_class_names"].index(class_name)]
                        if invade_state:
                            # Intrusion test against every reprojected region.
                            is_invade = is_point_in_polygonlist(point_x, point_y, results_list)
                            # NOTE(review): everything down to put_chinese_text is
                            # assumed to sit inside `if is_invade`.
                            if is_invade:
                                cls_count += 1
                                invade_point.append({
                                    "u": point_x,
                                    "v": point_y,
                                    "class_name": class_name
                                })
                                target_point.append({
                                    "u": point_x,
                                    "v": point_y,
                                    "cls_id": cls_id,
                                    "track_id": track_id,
                                    "new_track_id": new_track_id
                                })  # in intrusion mode only intruding targets are stored

                                message_point.append({
                                    "confidence": float(confidence),
                                    "cls_id": cls_id,
                                    "type_name": en_name,
                                    "track_id": track_id,
                                    "box": [x1, y1, x2, y2]
                                })
                                label = f"{en_name}:{confidence:.2f}:{track_id}"
                                # Work out where the label would go.
                                text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, fontScale=8, thickness=4)[
                                    0]
                                text_width, text_height = text_size[0], text_size[1]
                                text_x = x1
                                text_y = y1 - 5

                                # If the label would leave the top of the image, move it below the box.
                                if text_y < 0:
                                    text_y = y2 + text_height + 5
                                temp_img = frame_copy.copy()
                                frame_copy = put_chinese_text(
                                    temp_img,
                                    "",  # label text is disabled; an empty string is drawn
                                    (text_x, text_y - text_height),
                                )
                        else:
                            # NOTE(review): this else is assumed to pair with `if invade_state`.
                            cls_count += 1
                            # Draw the bounding box.
                            cv2.rectangle(frame_copy, (x1, y1), (x2, y2), (0, 255, 255), 2)
                            message_point.append({
                                "confidence": float(confidence),
                                "cls_id": cls_id,
                                "type_name": en_name,
                                "track_id": track_id,
                                "box": [x1, y1, x2, y2]
                            })
                            target_point.append({
                                "u": point_x,
                                "v": point_y,
                                "cls_id": cls_id,
                                "track_id": track_id,
                                "new_track_id": new_track_id
                            })
                            # Label text (not drawn, see the empty string below).
                            label = f"{confidence:.2f}:{track_id}"
                            # Work out where the label would go.
                            text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, fontScale=8, thickness=8)[0]
                            text_width, text_height = text_size[0], text_size[1]
                            text_x = x1
                            text_y = y1 - 5
                            # If the label would leave the top of the image, move it below the box.
                            if text_y < 0:
                                text_y = y2 + text_height + 5

                            padding = 2
                            temp_img = frame_copy.copy()
                            frame_copy = put_chinese_text(
                                temp_img,
                                "",  # label text is disabled; an empty string is drawn
                                (text_x, text_y - text_height),
                            )

                    # In intrusion mode the boxes are drawn after the loop,
                    # from the collected message points.
                    if invade_state:
                        for point in message_point:
                            cv2.rectangle(frame_copy, (point["box"][0], point["box"][1]),
                                          (point["box"][2], point["box"][3]),
                                          (0, 255, 255), 2)
                    # Build the per-class statistics text. Drawing it in the
                    # top-left corner is currently disabled.
                    stats_text = []
                    for cls_id, count in class_stats.items():
                        cls_name = chinese_label.get(cls_id,
                                                     class_names[cls_id] if class_names and cls_id < len(
                                                         class_names) else str(
                                                         cls_id))
                        reversed_dict[cls_name] = count
                    for key, value in reversed_dict.items():
                        stats_text.append(f"{key}: {value}")

                    if stats_text:
                        # Total height the statistics text would occupy.
                        text_height = cv2.getTextSize("Test", cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.5, thickness=1)[0][
                            1]
                        total_height = len(stats_text) * (text_height + 18)  # 18 px of line spacing

                        # Top-left anchor of the statistics text.
                        start_x = 50
                        start_y = 20

                    new_data = {
                        'frame_copy': frame_copy,
                        'frame': frame,
                        "osd_info": osd_info,
                        'detections': detections,
                        "message_point": message_point,
                        "cls_count": cls_count,
                        "target_point": target_point,
                        "model_list_func_id": model_list_func_id,
                        "model_func_id": model_func_id,
                        # func_id comes from the first model: the models are
                        # merged, so they can no longer be told apart here.
                        'timestamp': processed_data.get('timestamp'),
                        "detections_list": detections_list,
                        "model_para": model_para
                    }
                    # Temporary code: the RTMP and intrusion logic is due to change.
                    if invade_state:
                        # Intrusion events are forwarded only when the model
                        # parameters enable para_invade_enable.
                        if para_invade_enable:
                            if not invade_queue.full():
                                await invade_queue.put(new_data)
                    else:
                        # NOTE(review): this else is assumed to pair with `if invade_state`.
                        if not cv_frame_queue.full():
                            await cv_frame_queue.put(new_data)

                    # Throughput counter, printed about once per second.
                    # NOTE(review): assumed to sit inside `if predict_state`.
                    time_end = time.time_ns()
                    pic_count = pic_count + 1
                    if time_end - time_start > 1000000000:
                        print(f"writeFrames {pic_count}")
                        pic_count = 0
                        time_start = time_end

                # Convert the colour space for the encoder.
                rgb_frame = cv2.cvtColor(frame_copy, cv2.COLOR_BGR2RGB)

                # Open the output container on first use.
                if output_url and output_url not in stream_containers:
                    try:
                        container = av.open(
                            output_url,
                            mode='w',
                            format='flv',
                            options={
                                # RTMP connect timeout (5 s) and data timeout (10 s).
                                'rtmp_connect_timeout': '5000000',  # microseconds (5 s)
                                'rtmp_timeout': '10000000',  # microseconds (10 s)
                                'stimeout': '5000000'  # low-level socket timeout
                            }
                        )
                        stream = container.add_stream(codec_name, rate=Config.TARGET_FPS)

                        stream.width = frame.shape[1]

                        stream.height = frame.shape[0]
                        stream.pix_fmt = 'yuv420p'
                        stream.options = options

                        stream_containers[output_url] = {
                            'container': container,
                            'stream': stream,
                            'last_frame_time': time.time(),
                            'frame_count': 0,
                            'retry_count': 0
                        }
                        print(f"✅ 推流初始化成功: {output_url}")
                    except Exception as e:
                        print(f"❌ 推流初始化失败: {e}")
                        if 'container' in locals():
                            try:
                                container.close()
                            except:
                                pass
                        await asyncio.sleep(1.0)
                        continue

                # Push the frame.
                if output_url and output_url in stream_containers:
                    try:
                        container_info = stream_containers[output_url]
                        stream = container_info['stream']
                        container = container_info['container']

                        if rgb_frame.dtype == np.uint8:
                            av_frame = av.VideoFrame.from_ndarray(rgb_frame, format='rgb24')
                            packets = stream.encode(av_frame)

                            if packets:
                                for packet in packets:
                                    try:
                                        container.mux(packet)
                                        container_info['last_frame_time'] = time.time()
                                        container_info['frame_count'] += 1
                                    except Exception as e:
                                        logger.warning(f"推流数据包错误: {e}")
                                        container_info['retry_count'] += 1
                                        # Give up on this container once the
                                        # retry budget is exhausted.
                                        if container_info['retry_count'] > max_retries:
                                            raise
                            else:
                                # No packets produced for this frame:
                                # last_frame_time is left untouched.
                                pass

                            if 'frame' in locals():
                                frame_time = time.time() - frame_start
                                print(f"推流帧耗时: {frame_time:.4f}s")
                        else:
                            print(f"⚠️ 无效帧格式: {rgb_frame.dtype}")
                    except Exception as e:
                        logger.error(f"❌ 推流错误: {e}")
                        # Drop the container so the next frame re-opens the stream.
                        if output_url in stream_containers:
                            try:
                                stream_containers[output_url]['container'].close()
                            except:
                                pass
                            del stream_containers[output_url]
                        await asyncio.sleep(retry_delay)
                        continue

            except asyncio.TimeoutError:
                continue
            except asyncio.CancelledError:
                print("write_results_to_rtmp 收到取消信号")
                raise
            except Exception as e:
                logger.error(f"推流处理异常: {e}", exc_info=True)
                await asyncio.sleep(0.1)

    finally:
        # Close every output container.
        for url, info in list(stream_containers.items()):
            try:
                if 'container' in info:
                    info['container'].close()
            except Exception as e:
                logger.warning(f"关闭推流容器时出错: {e}")
        stream_containers.clear()

        # Release the local video writer, if one was created.
        if video_writer is not None:
            video_writer.release()

    # NOTE(review): assumed to sit after the try/finally, as in the sibling coroutines.
    print("write_results_to_rtmp读取线程已停止")
|
||
|
||
|
||
# 基于射线法,判断设备是否在红线内
|
||
def is_point_in_polygonlist(point_x, point_y, polygon_list):
    """Return True if the point is inside, or on the boundary of, ANY polygon.

    Each polygon is tested with the ray-casting rule. The previous version
    reset its flag for every polygon and returned a single value, so the
    answer reflected only one polygon of the list.

    Args:
        point_x: X (``u``) coordinate of the point.
        point_y: Y (``v``) coordinate of the point.
        polygon_list: Iterable of polygons. Each polygon is a sequence of
            mappings with ``"u"`` and ``"v"`` keys, in boundary order.
            ``None`` or an empty iterable yields False.

    Returns:
        bool: True when at least one polygon contains the point.
    """

    def _contains(polygon):
        """Ray-casting test of (point_x, point_y) against one polygon."""
        vertex_count = len(polygon)
        inside = False
        for idx in range(vertex_count):
            start = polygon[idx]
            end = polygon[(idx + 1) % vertex_count]
            x1, y1 = start["u"], start["v"]
            x2, y2 = end["u"], end["v"]

            # A point exactly on a vertex counts as inside.
            if (point_x == x1 and point_y == y1) or (point_x == x2 and point_y == y2):
                return True
            # A point collinear with the edge and within its bounding box
            # lies on the edge, which also counts as inside.
            if (point_x - x1) * (y2 - y1) == (point_y - y1) * (x2 - x1):
                if min(x1, x2) <= point_x <= max(x1, x2) and min(y1, y2) <= point_y <= max(y1, y2):
                    return True

            # Ray casting: toggle on each edge that straddles the horizontal
            # line through the point and crosses it at or to the right of the point.
            if (y1 > point_y) != (y2 > point_y):
                x_intersect = (x2 - x1) * (point_y - y1) / (y2 - y1) + x1
                if point_x <= x_intersect:
                    inside = not inside
        return inside

    return any(_contains(polygon) for polygon in (polygon_list or ()))
|
||
|
||
|
||
# 基于射线法,判断设备是否在红线内
|
||
def is_point_in_polygon(point_x, point_y, polygon):
    """Ray-casting test: is the point inside, or on the boundary of, ``polygon``?

    Args:
        point_x: X (``u``) coordinate of the point.
        point_y: Y (``v``) coordinate of the point.
        polygon: Sequence of mappings with ``"u"`` and ``"v"`` keys, in
            boundary order; the last vertex connects back to the first.

    Returns:
        bool: True when the point is on a vertex, on an edge, or inside.
    """
    vertex_count = len(polygon)
    crossings_odd = False

    for idx, head in enumerate(polygon):
        tail = polygon[(idx + 1) % vertex_count]
        hx, hy = head["u"], head["v"]
        tx, ty = tail["u"], tail["v"]

        # Exactly on one of the edge's end points.
        if point_x == hx and point_y == hy:
            return True
        if point_x == tx and point_y == ty:
            return True

        # Collinear with the edge and within its bounding box: on the edge.
        collinear = (point_x - hx) * (ty - hy) == (point_y - hy) * (tx - hx)
        if collinear:
            within_x = min(hx, tx) <= point_x <= max(hx, tx)
            within_y = min(hy, ty) <= point_y <= max(hy, ty)
            if within_x and within_y:
                return True

        # Skip edges that do not straddle the horizontal line through the point.
        if (hy > point_y) == (ty > point_y):
            continue
        crossing_x = (tx - hx) * (point_y - hy) / (ty - hy) + hx
        if point_x <= crossing_x:
            crossings_odd = not crossings_odd

    return crossings_odd
|
||
|
||
|
||
# 基于两个点的经纬度,计算两个点之间的距离
|
||
def haversine(lon1, lat1, lon2, lat2):
    """Return the great-circle distance in metres between two points.

    Args:
        lon1: Longitude of the first point, in degrees.
        lat1: Latitude of the first point, in degrees.
        lon2: Longitude of the second point, in degrees.
        lat2: Latitude of the second point, in degrees.

    Returns:
        float: Distance along a sphere of radius 6 371 000 m.
    """
    lon1, lat1, lon2, lat2 = map(math.radians, (lon1, lat1, lon2, lat2))
    dlon = lon2 - lon1
    dlat = lat2 - lat1
    a = math.sin(dlat / 2) ** 2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon / 2) ** 2
    # Floating-point rounding can push `a` marginally outside [0, 1] for
    # (near-)antipodal points, which would make sqrt/asin raise ValueError.
    a = min(1.0, max(0.0, a))
    c = 2 * math.asin(math.sqrt(a))
    r = 6371000  # mean Earth radius in metres
    return c * r
|
||
|
||
|
||
async def cal_des_invade(loop, invade_executor, task_id: str, mqtt, mqtt_publish_topic,
                         list_points: list[list[any]], camera_para: Camera_Para, model_count: int,
                         cancel_flag: asyncio.Event = None, invade_queue: asyncio.Queue = None,
                         event_queue: asyncio.Queue = None,
                         device_height: float = float(200), repeat_dis: float = -1, repeat_time: float = -1):
    """Turn intrusion frames into uploaded images and MQTT events.

    Loops until ``cancel_flag`` is set. For each dict taken from
    ``invade_queue`` it:

    1. optionally throttles by ``repeat_time`` seconds between handled frames;
    2. skips the whole frame if any target's track id was already reported
       (see ``TrackIDEventFilter``);
    3. converts the target pixel centres to locations with
       ``cal_canv_location_by_osd``;
    4. optionally suppresses the event when the first location is closer than
       ``repeat_dis`` metres to the previously cached first location;
    5. JPEG-encodes the annotated and original frames in ``invade_executor``
       and starts a background task that uploads them, puts the event
       timestamp on ``event_queue`` and publishes the message over MQTT.

    The background task receives a snapshot of the frame's values. It used to
    read the loop's local variables after awaiting the upload, by which time
    the loop could already have rebound them to a later frame. Task references
    are kept until completion so the tasks cannot be garbage-collected
    mid-flight, and task failures are logged.

    Args:
        loop: Event loop used for ``run_in_executor``.
        invade_executor: Executor for encoding and uploading.
        task_id: Identifier copied into the published message.
        mqtt: Client with an awaitable ``publish(topic, payload)``.
        mqtt_publish_topic: Topic the message is published to.
        list_points: Not used in the body.
        camera_para: Not used in the body.
        model_count: Not used in the body.
        cancel_flag: Event that stops the loop once set.
        invade_queue: Source queue of frame dicts.
        event_queue: Receives ``{"timestamp": ...}`` for each reported event.
        device_height: Height passed for every target to the location helper.
        repeat_dis: Distance threshold in metres; values <= 0 disable it.
        repeat_time: Minimum seconds between handled frames; values <= 0
            disable it.

    Raises:
        asyncio.CancelledError: Re-raised when cancelled while waiting on the queue.
    """
    track_filter = TrackIDEventFilter(max_inactive_time=8.0)
    target_location_back = []  # last cached locations, used for the distance check
    current_time_second = int(time.time())
    pending_uploads = set()  # strong references to in-flight upload tasks

    def _encode_jpeg(image):
        """Encode an image to JPEG bytes, or return None on failure."""
        success, buffer = cv2.imencode(".jpg", image)
        return buffer.tobytes() if success else None

    def _on_upload_done(task):
        """Drop the finished task and surface its exception, if any."""
        pending_uploads.discard(task)
        if task.cancelled():
            return
        error = task.exception()
        if error is not None:
            logger.error(f"cal_des_invade 上传任务失败: {error}", exc_info=error)

    async def upload_and_publish(snapshot):
        """Upload both images and publish the event described by ``snapshot``."""

        def upload_minio():
            minio_path, file_type = upload_frame_buff_from_buffer(snapshot["buffer_bytes"], None)
            minio_origin_path, file_type = upload_frame_buff_from_buffer(snapshot["buffer_origin_bytes"], None)
            return minio_path, minio_origin_path, file_type

        minio_path, minio_origin_path, file_type = await loop.run_in_executor(
            invade_executor, upload_minio
        )
        print(f"minio_pathminio_pathminio_pathminio_path {minio_path}")
        message = {
            "task_id": task_id,
            "time": str(datetime.now()),
            "detection_id": snapshot["timestamp"],
            "minio": {"minio_path": minio_path,
                      "minio_origin_path": minio_origin_path,
                      "file_type": file_type},
            "box_detail": [{
                "model_id": snapshot["model_func_id"],
                "cls_count": snapshot["cls_count"],
                "box_count": [snapshot["message_point"]],  # wrapped in a list on purpose
                "location_results": snapshot["location_results"]  # lon/lat of the targets
            }],
            "osd_location": {
                "longitude": snapshot["cam_longitude"],
                "latitude": snapshot["cam_latitude"]
            },
            "des_location": snapshot["des_location_result"]
        }

        if not event_queue.full():
            await event_queue.put({
                "timestamp": snapshot["timestamp"]  # moment of the event, used later for video making
            })
        else:
            logger.warning("event_queue 帧队列已满,等待1ms后重试")
            await asyncio.sleep(0.001)

        print(f"hongxianhongxianhongxianhongxian上传 {message}:{event_queue.qsize()} ")
        message_json = json.dumps(message, ensure_ascii=False)
        await mqtt.publish(mqtt_publish_topic, message_json)

    while not cancel_flag.is_set():
        # Avoid a backlog: flush the queue once it is more than half the configured size.
        if invade_queue.qsize() > Config.PROCESSED_QUEUE_SIZE // 2:
            print(f"警告:invade_queue 积压(当前长度={invade_queue.qsize()}),清空队列")
            while not invade_queue.empty():
                await invade_queue.get()
            print("队列已清空,等待新数据...")
            await asyncio.sleep(0.1)
            continue
        try:
            cv_frame = await asyncio.wait_for(invade_queue.get(), timeout=0.8)
            if not isinstance(cv_frame, dict):
                print(f"⚠️ 警告:cv_frame 不是字典,而是 {type(cv_frame)}")
                continue
        except asyncio.TimeoutError:
            print("cal_des_invade TimeoutError")
            continue
        except asyncio.CancelledError:
            print("cal_des_invade 收到取消信号")
            raise

        if repeat_time > 0:  # time-based throttling enabled
            read_cv_time_second = int(time.time())
            if read_cv_time_second - current_time_second < repeat_time:
                continue
            current_time_second = read_cv_time_second

        frame_copy = cv_frame['frame_copy']
        frame = cv_frame['frame']
        target_point = cv_frame['target_point']
        message_point = cv_frame['message_point']
        cls_count = cv_frame['cls_count']
        model_func_id = cv_frame['model_func_id']
        air_alti = cv_frame['osd_info']
        timestamp = cv_frame.get('timestamp', time.time_ns())
        if not isinstance(frame, np.ndarray) or frame.size == 0:
            print("⚠️ 警告:frame不是有效的numpy数组")
            continue

        img_height, img_width = frame.shape[:2]
        if not air_alti:
            continue

        gimbal_yaw = air_alti.gimbal_yaw
        gimbal_pitch = air_alti.gimbal_pitch
        gimbal_roll = air_alti.gimbal_roll
        height = air_alti.height
        cam_longitude = air_alti.longitude
        cam_latitude = air_alti.latitude

        try:
            us = []
            vs = []
            heights = []

            # One already-reported track id suppresses the whole frame.
            should_report = True
            for item in target_point:
                u = item["u"]
                v = item["v"]
                track_id = item["track_id"]
                print("方法 cal_des_invade")
                if not track_filter.should_report(track_id):
                    should_report = False
                    break
                us.append(u)
                vs.append(v)
                heights.append(device_height)

            if not should_report:
                continue

            print("进行侵限计算")
            location_results = cal_canv_location_by_osd(us, vs, gimbal_pitch, gimbal_yaw, gimbal_roll,
                                                        height, cam_longitude, cam_latitude, img_width, img_height,
                                                        heights)
            if not location_results:
                continue

            repeat_state = False
            show_des = 0
            str_loca = ""
            des_location_result = []
            if repeat_dis > 0:  # distance de-duplication enabled
                if len(target_location_back) > 0:
                    # Only the first cached and first current locations are compared.
                    des1_back = target_location_back[0]
                    des1_back_longitude = des1_back[0]
                    des1_back_latitude = des1_back[1]

                    des1 = location_results[0]
                    des1_longitude = des1[0]
                    des1_latitude = des1[1]
                    des_location_result.append({"longitude": des1_longitude,
                                                "latitude": des1_latitude})

                    str_loca = f"{des1_back_longitude}:{des1_back_latitude}---{des1_longitude}{des1_latitude}"
                    des = haversine(des1_back_longitude, des1_back_latitude, des1_longitude, des1_latitude)
                    show_des = des
                    if des < repeat_dis:
                        print(f"触发基于坐标判断重复,坐标距离:{des}")
                        repeat_state = True
                    else:
                        # Not a duplicate: refresh the cached location.
                        target_location_back = location_results
                else:
                    # Nothing cached yet: remember this location.
                    target_location_back = location_results

            if repeat_state:
                continue

            print(f"未发现重复,现在上传{str_loca}======{show_des}")

            buffer_bytes = await loop.run_in_executor(invade_executor, _encode_jpeg, frame_copy)
            buffer_origin_bytes = await loop.run_in_executor(invade_executor, _encode_jpeg, frame)
            if not buffer_bytes:
                continue

            # Snapshot everything the background task needs, so later loop
            # iterations cannot change what it publishes.
            snapshot = {
                "buffer_bytes": buffer_bytes,
                "buffer_origin_bytes": buffer_origin_bytes,
                "timestamp": timestamp,
                "model_func_id": model_func_id,
                "cls_count": cls_count,
                "message_point": message_point,
                "location_results": location_results,
                "cam_longitude": cam_longitude,
                "cam_latitude": cam_latitude,
                "des_location_result": des_location_result,
            }
            upload_task = asyncio.create_task(upload_and_publish(snapshot))
            pending_uploads.add(upload_task)
            upload_task.add_done_callback(_on_upload_done)

        except Exception as e:
            print(f"cal_des_invade 错误: {e}")
            import traceback
            traceback.print_exc()
            await asyncio.sleep(0.1)

    print("cal_des_invade读取线程已停止")
|
||
|
||
|
||
# 全局共享变量
|
||
shared_local_cache = {
    "send_count": 0,
    "invade_send_count": 0,
    "video_process_status": 0  # 0 / 1 / 2: recording recognition not started / started / finished
}
from asyncio import Lock, AbstractEventLoop

cache_lock = Lock()  # lock intended to guard the shared variables
invade_cache_lock = Lock()  # lock intended to guard the shared variables
|
||
|
||
|
||
async def send_frame_to_s3_mq(loop, upload_executor, task_id, mqtt, mqtt_topic, cancel_flag: asyncio.Event,
                              cv_frame_queue: asyncio.Queue,
                              event_queue: asyncio.Queue = None,
                              device_height: float = float(200), repeat_dis: float = -1, repeat_time: float = -1):
    """Consume detection frames, upload snapshots to MinIO and publish events over MQTT.

    Runs until ``cancel_flag`` is set. For each frame dict taken from
    ``cv_frame_queue`` the coroutine:

    1. optionally drops it when the previous accepted frame is less than
       ``repeat_time`` seconds old (only when ``repeat_time > 0``);
    2. asks ``TrackIDEventFilter`` whether each target's track id should be
       reported and geolocates the accepted targets;
    3. optionally drops it when the first target lies within ``repeat_dis``
       of the previously cached first target (only when ``repeat_dis > 0``);
    4. for every model with detections, renders the annotated frame, JPEG
       encodes annotated and original frames and schedules a background
       upload + MQTT publish.

    Args:
        loop: Event loop used for ``run_in_executor``.
        upload_executor: Executor running the blocking helpers.
        task_id: Identifier copied into every published message.
        mqtt: Object exposing an awaitable ``publish(topic, payload)``.
        mqtt_topic: Topic the event messages are published to.
        cancel_flag: Set by the owner to stop the loop.
        cv_frame_queue: Source queue of frame dicts.
        event_queue: Optional queue receiving ``{"timestamp": ...}`` per event.
        device_height: Height passed for every target to the geolocation helper.
        repeat_dis: Distance threshold for position de-duplication (<= 0 disables).
        repeat_time: Time threshold in seconds for time de-duplication (<= 0 disables).

    NOTE(review): the statement nesting of this function was reconstructed
    from a copy of the file that had lost its indentation; confirm that the
    geolocation / de-duplication / upload steps are meant to run only when at
    least one target passed the track filter.
    """
    global stats
    import traceback

    track_filter = TrackIDEventFilter(max_inactive_time=10)
    local_func_cache = {
        "func_100000": None,
        "func_100004": None,  # cache of person track ids
        "func_100006": None  # cache of vehicle track ids
    }
    para = {
        "category": 3
    }
    target_location_back = []  # last reported locations, used for position de-duplication
    current_time_second = int(time.time())
    repeat_time_count = 0
    # Strong references to fire-and-forget tasks: the event loop only keeps
    # weak references, so an unreferenced task may be garbage-collected.
    background_tasks = set()

    def encode_jpeg(image):
        """Return ``image`` encoded as JPEG bytes, or None when encoding fails."""
        success, buffer = cv2.imencode(".jpg", image)
        return buffer.tobytes() if success else None

    def upload_minio(annotated_bytes, origin_bytes):
        """Upload the annotated and the original snapshot (blocking)."""
        minio_path, file_type = upload_frame_buff_from_buffer(annotated_bytes, None)
        minio_origin_path, file_type = upload_frame_buff_from_buffer(origin_bytes, None)
        return minio_path, minio_origin_path, file_type

    async def upload_and_publish(annotated_bytes, origin_bytes, detection_id, box_detail,
                                 uav_location, longitude, latitude, des_location):
        """Upload one event and publish its message.

        Every value is received as an argument instead of being read from the
        enclosing scope: this coroutine runs after the main loop has moved on,
        so closure variables could already describe a later model or frame.
        """
        try:
            minio_path, minio_origin_path, file_type = await loop.run_in_executor(
                upload_executor, upload_minio, annotated_bytes, origin_bytes
            )
            message = {
                "task_id": task_id,
                "time": str(datetime.now()),
                "detection_id": detection_id,
                "minio": {"minio_path": minio_path, "minio_origin_path": minio_origin_path,
                          "file_type": file_type},
                "box_detail": box_detail,
                "uav_location": uav_location,
                "osd_location": {
                    "longitude": longitude,
                    "latitude": latitude
                },
                "des_location": des_location
            }
            if event_queue is not None:
                # Never block on a full queue: without an active consumer a
                # blocking put would stall every later publish forever.
                try:
                    event_queue.put_nowait({
                        "timestamp": detection_id  # moment of the event, used for clip cutting
                    })
                except asyncio.QueueFull:
                    print("event_queue is full, event timestamp dropped")
            message_json = json.dumps(message, ensure_ascii=False)
            await mqtt.publish(mqtt_topic, message_json)
        except asyncio.CancelledError:
            raise
        except Exception as e:
            print(f"send_frame_to_s3_mq 错误: {e}")
            traceback.print_exc()

    while not cancel_flag.is_set():
        upload_start = time.time()
        try:
            # Drop the backlog when the queue is more than half full.
            if cv_frame_queue.qsize() > Config.PROCESSED_QUEUE_SIZE // 2:
                print(f"警告:cv_frame_queue 积压(当前长度={cv_frame_queue.qsize()}),清空队列")
                while not cv_frame_queue.empty():
                    cv_frame_queue.get_nowait()
                print("队列已清空,等待新数据...")
                await asyncio.sleep(0.1)
                continue

            try:
                cv_frame = await asyncio.wait_for(cv_frame_queue.get(), timeout=0.05)
            except asyncio.TimeoutError:
                continue
            if not isinstance(cv_frame, dict):
                print(f"⚠️ 警告:cv_frame 不是字典,而是 {type(cv_frame)}")
                continue

            if repeat_time > 0:  # time-based de-duplication enabled
                read_cv_time_second = int(time.time())
                if read_cv_time_second - current_time_second < repeat_time and repeat_time_count > 0:
                    continue
                print("没有触发事件去重")
                repeat_time_count += 1  # makes sure the very first target is not dropped
                current_time_second = read_cv_time_second

            frame_copy = cv_frame['frame_copy']
            frame = cv_frame['frame']
            target_point = cv_frame['target_point']
            detections_list = cv_frame['detections_list']
            model_para_list = cv_frame.get('model_para', {})
            timestamp = cv_frame.get('timestamp', time.time_ns())
            air_alti = cv_frame['osd_info']

            img_height, img_width = frame.shape[:2]
            gimbal_yaw = air_alti.gimbal_yaw
            gimbal_pitch = air_alti.gimbal_pitch
            gimbal_roll = air_alti.gimbal_roll
            height = air_alti.height
            cam_longitude = air_alti.longitude
            cam_latitude = air_alti.latitude

            us = []
            vs = []
            heights = []
            print(f"target_pointtarget_point {len(target_point)}")
            des_location_result = []
            for count_item, item in enumerate(target_point):
                print(f"方法 send_frame——to {count_item}")
                # Stop at the first track id the filter rejects; targets
                # accepted before it are still reported.
                if not track_filter.should_report(item["track_id"]):
                    break
                us.append(item["u"])
                vs.append(item["v"])
                heights.append(device_height)

            if len(us) > 0:
                print("进行侵限计算")
                location_results = cal_canv_location_by_osd(us, vs, gimbal_pitch, gimbal_yaw, gimbal_roll,
                                                            height, cam_longitude, cam_latitude, img_width,
                                                            img_height, heights)
                if not location_results:
                    continue

                des1 = location_results[0]
                des_location_result.append({"longitude": des1[0],
                                            "latitude": des1[1]})

                repeat_state = False
                if repeat_dis > 0:  # position de-duplication enabled
                    if len(target_location_back) > 0:
                        # Only the first cached location is compared with the
                        # first current one.
                        des1_back = target_location_back[0]
                        des = haversine(des1_back[0], des1_back[1], des1[0], des1[1])
                        if des < repeat_dis:
                            print(f"触发基于坐标判断重复,坐标距离:{des}")
                            repeat_state = True
                        else:
                            target_location_back = location_results
                    else:
                        target_location_back = location_results

                if not repeat_state:
                    for idx, model_para in enumerate(model_para_list):
                        if not isinstance(detections_list, list) or idx >= len(detections_list):
                            continue
                        detections = detections_list[idx]
                        chinese_label = model_para.get("model_chinese_labe", {})
                        model_cls = model_para.get("model_cls_index", {})
                        list_func_id = model_para.get("model_list_func_id", -11)
                        func_id = model_para.get("func_id", [])

                        # Blocking helper, executed on the thread pool.
                        local_drc_message = await loop.run_in_executor(upload_executor, get_local_drc_message)

                        if detections is None or len(detections.boxes) < 1:
                            continue

                        try:
                            frame11, box_detail1 = await loop.run_in_executor(
                                upload_executor,
                                cal_tricker_results,
                                frame_copy, detections, None,
                                func_id, local_func_cache, para, model_cls, chinese_label, list_func_id
                            )
                        except Exception as e:
                            print(f"处理帧时出错: {str(e)}")
                            traceback.print_exc()
                            continue

                        buffer_bytes = await loop.run_in_executor(upload_executor, encode_jpeg, frame11)
                        buffer_origin_bytes = await loop.run_in_executor(upload_executor, encode_jpeg, frame)
                        if not buffer_bytes:
                            continue

                        publish_task = asyncio.create_task(upload_and_publish(
                            buffer_bytes, buffer_origin_bytes, timestamp, box_detail1,
                            local_drc_message, cam_longitude, cam_latitude, des_location_result
                        ))
                        background_tasks.add(publish_task)
                        publish_task.add_done_callback(background_tasks.discard)

            upload_time = time.time() - upload_start
            print(f"上传耗时: {upload_time:.4f}s")
        except asyncio.CancelledError:
            print("send_frame_to_s3 收到取消信号")
            raise
        except Exception as e:
            print(f"send_frame_to_s3_mq 错误: {e}")
            traceback.print_exc()
            await asyncio.sleep(0.1)

        # Throughput statistics, refreshed roughly once per second.
        stats['processed'] += 1
        if time.time() - stats['last_time'] >= 1.0:
            stats['avg_fps'] = stats['processed'] / (time.time() - stats['last_time'])
            print(f"处理速度: {stats['avg_fps']:.2f} FPS")
            stats['processed'] = 0
            stats['last_time'] = time.time()
    print("send_frame_to_s3线程已停止")
|
||
|
||
|
||
#
|
||
# def frames_to_video_bytes(frames, fps=25, format="mp4"):
|
||
# """
|
||
# 将多个帧转换为视频字节数组(内存处理,不存储到本地)
|
||
#
|
||
# Args:
|
||
# frames (list): 包含多个帧的列表,每个帧是 numpy 数组(BGR 或 RGB 格式)
|
||
# fps (int): 视频帧率
|
||
# format (str): 视频格式(如 "mp4"、"avi")
|
||
#
|
||
# Returns:
|
||
# bytes: 视频的字节数组
|
||
# """
|
||
# if len(frames) == 0:
|
||
# raise ValueError("帧列表不能为空")
|
||
#
|
||
# # 确保所有帧是 uint8 类型
|
||
# frames = [np.array(frame, dtype=np.uint8) for frame in frames]
|
||
#
|
||
# # 使用 BytesIO 存储视频数据
|
||
# output = BytesIO()
|
||
#
|
||
# # 使用 imageio 写入内存
|
||
# with imageio.get_writer(output, format=format, fps=fps) as writer:
|
||
# for frame in frames:
|
||
# # 如果帧是 BGR 格式(OpenCV 默认),转换为 RGB
|
||
# if frame.ndim == 3 and frame.shape[2] == 3:
|
||
# frame = frame[..., ::-1] # BGR → RGB
|
||
# writer.append_data(frame)
|
||
#
|
||
# # 获取字节数据
|
||
# video_bytes = output.getvalue()
|
||
# output.close()
|
||
#
|
||
# return video_bytes
|
||
|
||
#
|
||
# def frames_to_video_bytes(frames, fps=25, format="mp4"):
|
||
# """
|
||
# 将多个帧转换为视频字节数组(内存处理,不存储到本地)
|
||
# Args:
|
||
# frames (list): 包含多个帧的列表,每个帧是 numpy 数组(BGR 或 RGB 格式)
|
||
# fps (int): 视频帧率
|
||
# format (str): 视频格式(如 "mp4"、"avi")
|
||
# Returns:
|
||
# bytes: 视频的字节数组 或 None(如果失败)
|
||
# """
|
||
# if len(frames) == 0:
|
||
# logger.warning("frames_to_video_bytes: 输入帧列表为空")
|
||
# return None
|
||
#
|
||
# try:
|
||
# # 确保所有帧是 uint8 类型
|
||
# frames = [np.array(frame, dtype=np.uint8) for frame in frames]
|
||
#
|
||
# output = BytesIO()
|
||
# with imageio.get_writer(output, format=format, fps=fps) as writer:
|
||
# for frame in frames:
|
||
# if frame.ndim == 3 and frame.shape[2] == 3:
|
||
# frame = frame[..., ::-1] # BGR → RGB
|
||
# writer.append_data(frame)
|
||
# return output.getvalue()
|
||
# except Exception as e:
|
||
# logger.error(f"frames_to_video_bytes 生成视频失败: {e}")
|
||
# return None
|
||
|
||
|
||
import subprocess
|
||
import numpy as np
|
||
import tempfile
|
||
import os
|
||
import time
|
||
import threading
|
||
import traceback
|
||
|
||
|
||
def frames_to_video_bytes(frames, fps=25, format="flv"):
    """Encode a list of 3-channel frames into a video and return its bytes.

    The frames are written as raw RGB24 data to a temporary file, encoded with
    ``/usr/bin/ffmpeg`` (libx264, yuv420p, ultrafast, crf 30) into a second
    temporary file, and that file's content is returned. Both temporary files
    are removed before returning.

    Args:
        frames: Non-empty ``list`` of image arrays. The first frame fixes the
            output size; frames of another size are resized to it. The channel
            order is reversed before writing (BGR input is assumed).
        fps: Frame rate; clamped to at most 12 for frames of 3840x2160 or larger.
        format: ffmpeg container name, also used as temp-file suffix.

    Returns:
        The encoded video as ``bytes``, or ``None`` on invalid input, on any
        processing error, on a non-zero ffmpeg exit code or on timeout (120 s).
    """

    def _log(level, msg):
        print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {level}: {msg}")

    def log_info(msg):
        _log("INFO", msg)

    def log_warning(msg):
        _log("WARNING", msg)

    def log_error(msg):
        _log("ERROR", msg)

    def log_debug(msg):
        _log("DEBUG", msg)

    if format != "flv":
        # The requested container is honoured. The previous message claimed an
        # automatic switch to FLV that was never performed.
        log_warning(f"输出格式为{format},未切换为FLV(FLV兼容性最佳)")

    if not isinstance(frames, list) or len(frames) == 0:
        log_error("输入帧列表为空或非列表类型,返回None")
        return None

    start_time = time.time()
    # Bound up-front so the cleanup in ``finally`` never sees unbound names.
    temp_frame_path = None
    temp_video_path = None
    try:
        # --------------------------
        # Step 1: validate frames and dump them as raw RGB24
        # --------------------------
        try:
            log_info("开始帧预处理并写入临时文件")
            first_frame = np.asarray(frames[0], dtype=np.uint8)
            if first_frame.ndim != 3 or first_frame.shape[2] != 3:
                log_error("第一帧必须是3通道BGR/RGB图像")
                return None

            frame_h, frame_w = first_frame.shape[:2]
            if frame_w >= 3840 and frame_h >= 2160:
                single_frame_mb = (frame_w * frame_h * 3) / (1024 * 1024)
                total_data_mb = single_frame_mb * len(frames)
                log_info(f"检测到4K帧:{frame_w}x{frame_h},单帧{single_frame_mb:.1f}MB,总数据{total_data_mb:.1f}MB")
                fps = min(fps, 12)
                log_info(f"4K适配:帧率调整为{fps}fps")

            try:
                from cv2 import resize
            except ImportError:
                log_error("未安装opencv-python,请执行 pip install opencv-python")
                return None

            frame_fd, temp_frame_path = tempfile.mkstemp(suffix=".raw", text=False)
            log_debug(f"创建帧数据临时文件: {temp_frame_path}")
            with os.fdopen(frame_fd, "wb") as raw_file:
                for idx, frame in enumerate(frames):
                    try:
                        frame_uint8 = np.asarray(frame, dtype=np.uint8)
                        # Every frame must contribute exactly w*h*3 bytes or
                        # the raw stream is corrupted; a 2-D frame would also
                        # have its width reversed by the channel flip below.
                        if frame_uint8.ndim != 3 or frame_uint8.shape[2] != 3:
                            raise ValueError(f"expected an HxWx3 frame, got shape {frame_uint8.shape}")
                        if frame_uint8.shape[:2] != (frame_h, frame_w):
                            log_debug(f"帧{idx + 1}尺寸不匹配,缩放至{frame_w}x{frame_h}")
                            frame_uint8 = resize(frame_uint8, (frame_w, frame_h), interpolation=1)
                        # Reverse the channel axis (BGR -> RGB) and append.
                        raw_file.write(frame_uint8[..., ::-1].tobytes())
                        log_debug(f"帧{idx + 1}已写入临时文件")
                    except Exception as e:
                        log_error(f"处理帧{idx + 1}失败:{str(e)}")
                        traceback.print_exc()
                        return None
                raw_file.flush()
                os.fsync(raw_file.fileno())
            log_info(f"所有帧已写入临时文件,共{len(frames)}帧")

            video_fd, temp_video_path = tempfile.mkstemp(suffix=f".{format}", text=False)
            os.close(video_fd)  # ffmpeg writes to the path itself
            log_debug(f"创建视频临时文件: {temp_video_path}")
        except Exception as e:
            log_error(f"帧预处理阶段异常:{str(e)}")
            traceback.print_exc()
            return None

        # --------------------------
        # Step 2: build the ffmpeg command
        # --------------------------
        log_info("构造FFMPEG命令(使用临时文件输入)")
        ffmpeg_cmd = [
            "/usr/bin/ffmpeg",
            "-hide_banner",
            "-loglevel", "error",
            "-y",
            # input: raw frames read from the temporary file
            "-f", "rawvideo",
            "-vcodec", "rawvideo",
            "-s", f"{frame_w}x{frame_h}",
            "-pix_fmt", "rgb24",
            "-r", str(fps),
            "-i", temp_frame_path,
            # output
            "-vcodec", "libx264",
            "-pix_fmt", "yuv420p",
            "-r", str(fps),
            "-preset", "ultrafast",
            "-crf", "30",
            "-f", format,
            temp_video_path
        ]
        log_info(f"FFMPEG命令:{' '.join(ffmpeg_cmd)}")

        # --------------------------
        # Step 3: run ffmpeg and collect the result
        # --------------------------
        try:
            log_info("启动FFMPEG进程(使用临时文件)")
            log_info("等待FFMPEG处理完成(超时120秒)")
            # subprocess.run kills and reaps the child on timeout, so no
            # zombie process or dangling reader thread is left behind.
            process = subprocess.run(
                ffmpeg_cmd,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.PIPE,
                timeout=120
            )
            ffmpeg_error_log = []
            for err_line in process.stderr.splitlines():
                err_str = err_line.decode("utf-8", errors="replace").strip()
                if err_str:
                    ffmpeg_error_log.append(err_str)
                    log_error(f"FFMPEG错误:{err_str}")

            if process.returncode != 0:
                log_error(f"FFMPEG处理失败(退出码:{process.returncode})")
                log_error(f"FFMPEG完整错误:\n{chr(10).join(ffmpeg_error_log)}")
                return None

            log_info("读取视频临时文件内容")
            with open(temp_video_path, 'rb') as f:
                video_bytes = f.read()

            log_info(
                f"视频生成成功!大小:{len(video_bytes) / 1024 / 1024:.2f}MB,总耗时:{time.time() - start_time:.2f}秒")
            return video_bytes
        except subprocess.TimeoutExpired:
            log_error("FFMPEG处理超时(超过120秒)")
            return None
        except Exception as e:
            log_error(f"FFMPEG处理异常:{str(e)}")
            traceback.print_exc()
            return None

    finally:
        # Always remove the temporary files, whatever path led here.
        for label, path in (("帧", temp_frame_path), ("视频", temp_video_path)):
            try:
                if path and os.path.exists(path):
                    os.remove(path)
                    log_debug(f"已删除{label}临时文件: {path}")
            except Exception as e:
                log_warning(f"删除{label}临时文件失败:{str(e)}")

        log_info("视频生成流程结束")
|
||
|
||
|
||
async def cut_evnt_video_publish(task_id, mqtt, mqtt_topic, cancel_flag: asyncio.Event,
                                 event_queue: asyncio.Queue = None,
                                 timestamp_frame_queue: TimestampedQueue = None):
    """Turn queued events into video clips, upload them and publish the result.

    Until ``cancel_flag`` is set, each ``{"timestamp": ...}`` dict taken from
    ``event_queue`` is looked up in ``timestamp_frame_queue``; the matching
    frames are encoded with ``frames_to_video_bytes`` and the clip is uploaded
    and announced over MQTT in a background task.

    Args:
        task_id: Identifier copied into every published message.
        mqtt: Object exposing an awaitable ``publish(topic, payload)``.
        mqtt_topic: Topic the clip messages are published to.
        cancel_flag: Set by the owner to stop the loop.
        event_queue: Queue of event dicts carrying a ``"timestamp"`` key.
        timestamp_frame_queue: Frame store queried via ``query_by_timestamp``.

    Raises:
        ValueError: If ``event_queue`` or ``timestamp_frame_queue`` is None.
    """
    if event_queue is None or timestamp_frame_queue is None:
        raise ValueError("event_queue and timestamp_frame_queue are required")

    loop = asyncio.get_running_loop()
    upload_executor = ThreadPoolExecutor(max_workers=Config.UPLOAD_WORKERS)
    # Strong references so pending upload tasks are not garbage-collected.
    background_tasks = set()

    async def upload_and_publish(video_bytes, detection_id):
        """Upload one clip and publish its message.

        Values are passed in rather than captured, because the main loop
        rebinds its variables while this task is still running.
        """
        try:
            minio_path, file_type = await loop.run_in_executor(
                upload_executor,
                lambda: upload_video_buff_from_buffer(video_bytes, video_format="mp4")
            )
            message = {
                "task_id": task_id,
                "time": str(datetime.now()),
                "detection_id": detection_id,
                "minio": {"minio_path": minio_path,
                          "minio_origin_path": minio_path,
                          "file_type": file_type},
                "box_detail": [{

                }],
                "osd_location": {
                },
                "des_location": []
            }
            print(f"成功上传视频: {message}")
            await mqtt.publish(mqtt_topic, json.dumps(message, ensure_ascii=False))
        except Exception as e:
            print(f"上传失败: {e}")

    try:
        while not cancel_flag.is_set():
            # Bounded wait so the cancel flag is re-checked even when no
            # event ever arrives.
            try:
                event = await asyncio.wait_for(event_queue.get(), timeout=1.0)
            except asyncio.TimeoutError:
                continue

            # From here on exactly one item has been taken from the queue, so
            # task_done() is called exactly once in the finally below.
            try:
                print(f"[cut循环] 成功获取事件: {event}")

                if not isinstance(event, dict) or "timestamp" not in event:
                    print(f"[cut循环] 无效事件,跳过")
                    continue

                timestamp = event.get("timestamp")
                if timestamp is None:
                    print("⚠️ 警告:事件中缺少 timestamp")
                    continue

                matched_items = timestamp_frame_queue.query_by_timestamp(timestamp)
                print(f"[cut循环] 匹配到 {len(matched_items)} 个帧")
                if len(matched_items) == 0:
                    continue

                frames = [item["frame"] for item in matched_items]
                if not frames:
                    print("⚠️ 警告:匹配到的帧列表为空")
                    continue

                # The encoder is blocking (temp files + ffmpeg subprocess);
                # run it on the thread pool to keep the event loop responsive.
                video_bytes = await loop.run_in_executor(
                    upload_executor,
                    lambda: frames_to_video_bytes(frames, fps=150, format="mp4")
                )
                if video_bytes is None:
                    print("⚠️ 警告:视频生成失败,返回None")
                    continue

                publish_task = asyncio.create_task(upload_and_publish(video_bytes, timestamp))
                background_tasks.add(publish_task)
                publish_task.add_done_callback(background_tasks.discard)

            except Exception as e:
                print(f"[cut循环] 处理事件出错: {e}")
            finally:
                event_queue.task_done()
                # Private attribute kept only for this diagnostic message.
                print(f"[cut循环] 确认消费,剩余未完成任务: {event_queue._unfinished_tasks}")
    finally:
        # Let already scheduled uploads finish, then release the worker threads.
        if background_tasks:
            await asyncio.gather(*background_tasks, return_exceptions=True)
        upload_executor.shutdown(wait=False)
|
||
|
||
|
||
async def start_rtmp_processing(video_url: str, task_id: str, model_configs: List[Dict],
                                mqtt_pub_ip: str, mqtt_pub_port: int, mqtt_pub_topic: str,
                                mqtt_sub_ip: str, mqtt_sub_port: int, mqtt_sub_topic: str,
                                output_rtmp_url: str,
                                invade_enable: bool, invade_file: str, camera_para_url: str,
                                device_height: float, repeat_dis: float, repeat_time: float):
    """Run the live-stream pipeline for one task until it finishes or is cancelled.

    Creates the queues, loads the optional intrusion-zone and camera-parameter
    files, starts the MQTT publisher/subscriber and the detector, then launches
    the pipeline coroutines (``read_rtmp_frames``, ``process_frames``,
    ``write_results_to_rtmp``, optionally ``cal_des_invade`` and
    ``send_frame_to_s3_mq``), registers them with ``task_manager`` and waits
    for them. All resources are released in ``finally``.

    Args:
        video_url: Source stream address.
        task_id: Task identifier used for registration and in messages.
        model_configs: Model configuration list handed to the detector.
        mqtt_pub_ip / mqtt_pub_port / mqtt_pub_topic: Publisher endpoint and topic.
        mqtt_sub_ip / mqtt_sub_port / mqtt_sub_topic: Subscriber endpoint and topic.
        output_rtmp_url: Address the annotated stream is pushed to.
        invade_enable: Whether intrusion detection is requested.
        invade_file: Location of the intrusion-zone JSON file.
        camera_para_url: Location of the camera-parameter file.
        device_height / repeat_dis / repeat_time: Forwarded to the reporting coroutines.

    Raises:
        RuntimeError: If the MQTT device could not be started.
        Exception: Any setup or pipeline error is logged and re-raised.

    NOTE(review): the statement nesting of this function was reconstructed
    from a copy of the file that had lost its indentation; in particular the
    zone-file download is assumed to sit inside
    ``if invade_file and camera_para_url`` — confirm.
    """
    logger.info(f"拉流地址{video_url}")
    logger.info(f"推流地址{output_rtmp_url}")
    frame_queue = asyncio.Queue(maxsize=Config.FRAME_QUEUE_SIZE)
    processed_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
    invade_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
    cv_frame_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
    event_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)  # events used as clip anchors
    print(f"[start_rtmp] 创建event_queue实例:{id(event_queue)}")

    loop = asyncio.get_running_loop()

    stream_containers: Dict[str, Any] = {}
    cancel_flag = asyncio.Event()
    # Bounded FIFO of recent frames, queried when an event needs a clip.
    timestamp_frame_queue = TimestampedQueue(maxlen=1000)

    # Everything the cleanup in ``finally`` touches is bound up-front, so a
    # failure during setup cannot raise NameError there and mask the real error.
    detector = None
    device = None
    read_task = None
    process_task = None
    write_task = None
    invade_task = None
    upload_tasks = []
    executors = []  # every thread pool created below, shut down in ``finally``

    try:
        list_points = []  # intrusion zones, one list of Point per feature
        camera_para = None
        if invade_file and camera_para_url:
            camera_file_path = downFile(camera_para_url)
            camera_para = read_camera_params(camera_file_path)

            invade_file_path = downFile(invade_file)
            if invade_file_path:
                with open(invade_file_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                for feature in data.get('features', []):
                    geometry = feature.get('geometry', {})
                    coordinates = geometry.get('coordinates', [])
                    if not coordinates:
                        # A feature without coordinates would otherwise abort
                        # the whole task with an IndexError.
                        continue
                    points = [Point(coord[1], coord[0], coord[2], i)
                              for i, coord in enumerate(coordinates[0])]
                    list_points.append(points)

        mqtt = MQTTService(mqtt_pub_ip, port=mqtt_pub_port)
        detector = MultiYoloTrtDetectorTrackId(model_configs)

        print(f"mqtt_sub_ipmqtt_sub_ipmqtt_sub_ip {mqtt_sub_ip} {mqtt_sub_port} {mqtt_sub_topic}")

        async def start_mqtt_device():
            """Create, configure and start the subscribing MQTT device."""
            mqtt_device = MQTTDevice(
                ip=mqtt_sub_ip,
                port=mqtt_sub_port,
                topics=[mqtt_sub_topic],
                queue_size=50
            )
            # Register a callback for osd_info_push messages (the original
            # comment notes that reading is very slow without it).
            mqtt_device.register_callback(
                topic=mqtt_sub_topic,
                method="osd_info_push",
                callback=empty_osd_callback
            )
            await mqtt_device.start()
            return mqtt_device

        device = await start_mqtt_device()

        if device is None:
            raise RuntimeError("Failed to start MQTT device")

        tasks = []
        model_count = len(model_configs)
        try:
            # Stream reading
            read_rtmp_frames_executor = ThreadPoolExecutor(max_workers=Config.READ_RTMP_WORKERS)
            executors.append(read_rtmp_frames_executor)
            read_task = asyncio.create_task(
                read_rtmp_frames(
                    loop,
                    read_rtmp_frames_executor,
                    video_url,
                    device,
                    mqtt_sub_topic,
                    "drc_camera_osd_info_push",
                    mqtt_sub_topic,
                    "osd_info_push",
                    cancel_flag,
                    frame_queue,
                    timestamp_frame_queue
                ),
                name="read_rtmp_frames"
            )
            tasks.append(read_task)

            # Inference
            process_task = asyncio.create_task(
                process_frames(detector, cancel_flag, frame_queue, processed_queue),
                name="process_frames"
            )
            tasks.append(process_task)

            # Stream pushing
            invade_state = bool(list_points) and invade_enable
            write_task = asyncio.create_task(
                write_results_to_rtmp(
                    task_id,
                    output_rtmp_url,
                    None,
                    list_points,
                    camera_para,
                    invade_state,
                    cancel_flag,
                    processed_queue,
                    invade_queue,
                    cv_frame_queue,
                    stream_containers
                ),
                name="write_results_to_rtmp"
            )
            tasks.append(write_task)

            # Intrusion detection
            if invade_enable and list_points:
                invade_executor = ThreadPoolExecutor(max_workers=Config.INVADE_WORKERS)
                executors.append(invade_executor)
                invade_task = asyncio.create_task(
                    cal_des_invade(
                        loop,
                        invade_executor,
                        task_id,
                        mqtt,
                        mqtt_pub_topic,
                        list_points,
                        camera_para,
                        model_count,
                        cancel_flag,
                        invade_queue,
                        event_queue,
                        device_height,
                        repeat_dis, repeat_time
                    ),
                    name="cal_des_invade"
                )
                tasks.append(invade_task)

            # Snapshot upload + event publishing
            upload_executor = ThreadPoolExecutor(max_workers=Config.UPLOAD_WORKERS)
            executors.append(upload_executor)
            for worker_index in range(1):
                upload_task = asyncio.create_task(
                    send_frame_to_s3_mq(loop, upload_executor, task_id, mqtt, mqtt_pub_topic,
                                        cancel_flag, cv_frame_queue, event_queue, device_height, repeat_dis,
                                        repeat_time),
                    name=f"send_frame_to_s3_mq_{worker_index}"
                )
                upload_tasks.append(upload_task)
                tasks.append(upload_task)

            device_list = [mqtt]
            task_info = {
                "video_url": video_url,
                "output_rtmp_url": output_rtmp_url,
                "task_id": task_id,
                "cancel_flag": cancel_flag,
                "device_list": device_list
            }

            # shield(): cancelling this coroutine does not directly cancel the
            # gather; the children are cancelled explicitly below.
            main_task = asyncio.shield(asyncio.gather(*tasks, return_exceptions=True))

            await task_manager.add_task(
                task_id,
                task_info,
                main_task,
                tasks
            )
            print("start_rtmp 1")
            await task_manager.update_heartbeat(task_id)
            print("start_rtmp 2")
            try:
                print("start_rtmp 3")
                await main_task
            except asyncio.CancelledError:
                logger.info(f"Task {task_id} was cancelled")
                raise
            except Exception as e:
                logger.error(f"Task {task_id} failed: {e}")
                raise
            finally:
                cancel_flag.set()

                # Grace period for the workers to observe the flag. This must
                # be the asyncio sleep: time.sleep() here would freeze the
                # event loop, so no worker could react during the wait.
                await asyncio.sleep(3)
                print("start_rtmp 4")
                for task in tasks:
                    print("start_rtmp 5")
                    task.cancel()
                for upload_task in upload_tasks:
                    print("start_rtmp 6")
                    upload_task.cancel()

        except Exception as e:
            logger.error(f"Error in main task loop: {e}")
            raise

    except Exception as e:
        logger.error(f"Unexpected error in start_rtmp_processing: {e}")
        raise
    finally:
        print("start_rtmp 7")
        cancel_flag.set()
        print("start_rtmp 8")
        # Release CUDA memory held by the detector, if it was created.
        if detector is not None:
            detector.destroy()

        if read_task and not read_task.done():
            print("start_rtmp 9")
            read_task.cancel()
        if process_task and not process_task.done():
            print("start_rtmp 10")
            process_task.cancel()
        if write_task and not write_task.done():
            print("start_rtmp 11")
            write_task.cancel()
        if invade_task and not invade_task.done():
            print("start_rtmp 12")
            invade_task.cancel()
        for task in upload_tasks:
            if task and not task.done():
                task.cancel()
        print("start_rtmp 13")
        if device is not None:
            await device.stop()
        print("start_rtmp 14")
        # Release the worker threads; without this every task leaks its pools.
        for executor in executors:
            executor.shutdown(wait=False)
        await cleanup_resources()
        await task_manager.remove_task(task_id)
        logger.info(f"Task {task_id} resources cleaned up")
    print(f"start_rtmp_processing线程已停止")
|
||
|
||
|
||
async def start_video_processing(minio_path: str, task_id: str, model_configs: List[Dict],
                                 mqtt_ip: str, mqtt_port: int, mqtt_topic: str, output_rtmp_url: str,
                                 invade_enable: bool, invade_file: str, camera_para_url: str, device_height: float,
                                 repeat_dis: float, repeat_time: float):
    """Run detection on a recorded video file and push the results to RTMP/MQTT.

    The video is downloaded with ``downBigFile``, its first subtitle stream is
    extracted to ``<video>.srt`` with FFmpeg, and a set of worker tasks is
    started (frame reader, detector, RTMP writer, optional intrusion check,
    two frame uploaders and an event-clip uploader). The coroutine then waits
    for those workers and always tears everything down before returning.

    Args:
        minio_path: Location handed to ``downBigFile`` to fetch the video.
        task_id: Identifier used to register the job with ``task_manager`` and
            forwarded to every worker.
        model_configs: Model configurations handed to the detector; their
            count is forwarded to the intrusion worker.
        mqtt_ip: Host passed to ``MQTTService``.
        mqtt_port: Port passed to ``MQTTService``.
        mqtt_topic: Topic forwarded to the workers that publish over MQTT.
        output_rtmp_url: Destination forwarded to ``write_results_to_rtmp``.
        invade_enable: Enables the intrusion worker when polygons were loaded.
        invade_file: Location of a JSON file with a ``features`` list whose
            geometries hold the restricted-area polygons (only read when
            ``camera_para_url`` is also set).
        camera_para_url: Location of the camera parameter file.
        device_height: Forwarded unchanged to the intrusion and upload workers.
        repeat_dis: Forwarded unchanged to the intrusion and upload workers.
        repeat_time: Forwarded unchanged to the intrusion and upload workers.

    Returns:
        None. Setup failures (download, FFmpeg, missing polygon data,
        unreadable video) are reported and end the coroutine early.

    Raises:
        asyncio.CancelledError: Re-raised when the coroutine is cancelled
            while waiting for the workers.
    """
    # Function-scope import so the block does not depend on the module header.
    import subprocess

    print("开起录像识别")
    cancel_flag = asyncio.Event()
    # Per-job queues and containers (kept local to avoid shared global state).
    frame_queue = asyncio.Queue(maxsize=Config.FRAME_QUEUE_SIZE)
    processed_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
    invade_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
    cv_frame_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
    # Events stored here decide which video clips get cut out.
    event_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
    stream_containers: Dict[str, Any] = {}
    # Bounded FIFO of recent frames; an event cuts the clip around itself from it.
    timestamp_frame_queue = TimestampedQueue(maxlen=500)

    # NOTE(review): downBigFile/downFile and the FFmpeg call below are
    # synchronous and hold the event loop while they run - consider moving
    # them to an executor if other jobs share this loop.
    try:
        download_path = downBigFile(minio_path)
    except Exception as e:
        logger.error(f"Unexpected error downloading big file: {e}", exc_info=True)
        return None
    if not download_path:
        # No exception is active here, so no exc_info.
        logger.error("big file is none")
        return

    print(f"download_path路径 {os.path.abspath(download_path)}")

    # The subtitle file lives next to the video as "<video name>.srt".
    dir_name = os.path.dirname(download_path)
    srt_name = os.path.basename(download_path) + ".srt"
    srt_path = os.path.join(dir_name, srt_name)
    if os.path.exists(srt_path):
        # Remove a stale file first; FFmpeg is invoked without an overwrite flag.
        os.remove(srt_path)
    command = [
        "ffmpeg",
        "-i", download_path,
        "-map", "0:s:0",  # pick the first subtitle stream
        "-c:s", "srt",  # force conversion to SRT
        srt_path
    ]
    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        print(f"start_video_processing方法错误: FFmpeg 执行失败 - {e}")
        return
    except FileNotFoundError:
        print("start_video_processing方法错误: 未安装 FFmpeg,请先安装并添加到系统路径。")
        return
    else:
        print(f"start_video_processing方法 字幕提取成功,保存至: {srt_path}")
    print(f"srt_path路径 {os.path.abspath(srt_path)}")

    # Bind everything the outer ``finally`` inspects *before* entering the
    # ``try``, so an early return or an exception during setup cannot turn
    # the cleanup into an UnboundLocalError.
    detector = None
    tasks = []
    upload_tasks = []
    read_task = None
    process_task = None
    write_task = None
    invade_task = None

    try:
        list_points = []  # list of polygons; each polygon is a list of Point
        camera_para = None
        if invade_file and camera_para_url:
            camera_file_path = downFile(camera_para_url)
            camera_para = read_camera_params(camera_file_path)

            invade_file_path = downFile(invade_file)
            if invade_file_path is None:
                print(f"invade_file is None task_id:{task_id}")
                return

            with open(invade_file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            # Extract the polygon coordinates.
            features = data.get('features', [])
            if not features:
                print("没有找到有效的要素数据")
                return

            for polygon in features:
                geometry = polygon.get('geometry', {})
                coordinates = geometry.get('coordinates', [])
                # Only the first ring of each polygon is used. The first two
                # coordinate components are swapped on purpose: the original
                # author notes convert_points_to_utm fails otherwise.
                points = [
                    Point(coord[1], coord[0], coord[2], key)
                    for key, coord in enumerate(coordinates[0])
                ]
                list_points.append(points)

        mqtt = MQTTService(mqtt_ip, port=mqtt_port)

        # Create the detector.
        detector = MultiYOLODetector_TrackId(model_configs)

        mqtt_publish_topic = mqtt_topic

        # Open the video once just to read its frame rate.
        cap = cv2.VideoCapture(download_path)
        if not cap.isOpened():
            print("Error: Could not open video.")
            return
        fps = cap.get(cv2.CAP_PROP_FPS)
        cap.release()

        invade_state = bool(list_points) and invade_enable
        model_count = len(model_configs)

        # Frame reading task.
        read_task = asyncio.create_task(
            read_video_frames(task_id, mqtt, mqtt_publish_topic, download_path, srt_path, frame_queue,
                              timestamp_frame_queue, cancel_flag),
            name="read_rtmp_frames")
        tasks.append(read_task)

        # Detection task.
        process_task = asyncio.create_task(
            process_frames(detector, cancel_flag, frame_queue, processed_queue),
            name="process_frames"
        )
        tasks.append(process_task)

        # Stream output task.
        write_task = asyncio.create_task(
            write_results_to_rtmp(
                task_id,
                output_rtmp_url,
                fps,
                list_points,
                camera_para,
                invade_state,
                cancel_flag,
                processed_queue,
                invade_queue,
                cv_frame_queue,
                stream_containers
            ),
            name="write_results_to_rtmp"
        )
        tasks.append(write_task)

        # Intrusion detection task (only when polygons were loaded).
        if invade_enable and list_points:
            invade_task = asyncio.create_task(
                cal_des_invade(
                    task_id,
                    mqtt,
                    mqtt_publish_topic,
                    list_points,
                    camera_para,
                    model_count,
                    cancel_flag,
                    invade_queue,
                    event_queue,
                    device_height,
                    repeat_dis, repeat_time
                ),
                name="cal_des_invade"
            )
            tasks.append(invade_task)

        # Two parallel upload consumers.
        for idx in range(2):
            upload_task = asyncio.create_task(
                send_frame_to_s3_mq(task_id, mqtt, mqtt_topic,
                                    cancel_flag, cv_frame_queue, event_queue, device_height, repeat_dis,
                                    repeat_time),
                name=f"send_frame_to_s3_mq_{idx}"
            )
            upload_tasks.append(upload_task)
            tasks.append(upload_task)

        # Cut event clips out of the buffered frames and upload them.
        upload_video = asyncio.create_task(cut_evnt_video_publish(task_id, mqtt, mqtt_topic, cancel_flag,
                                                                  event_queue, timestamp_frame_queue),
                                           name="cut_evnt_video_publish")
        tasks.append(upload_video)

        task_info = {
            "output_rtmp_url": output_rtmp_url,
            "task_id": task_id,
            "cancel_flag": cancel_flag,
        }
        # Shield the gather so cancelling this coroutine does not directly
        # cancel the workers; they are stopped through cancel_flag below.
        main_task = asyncio.shield(asyncio.gather(*tasks, return_exceptions=True))

        # Register the job.
        await task_manager.add_task(
            task_id,
            task_info,
            main_task,
            tasks
        )
        print("start_rtmp 1")
        await task_manager.update_heartbeat(task_id)
        print("start_rtmp 2")
        # Wait until the workers finish or the job is cancelled.
        try:
            print("start_rtmp 3")
            await main_task
        except asyncio.CancelledError:
            logger.info(f"Task {task_id} was cancelled")
            raise
        except Exception as e:
            logger.error(f"Task {task_id} failed: {e}")
            raise
        finally:
            # Ask the workers to stop, then give them a grace period. This
            # must be a non-blocking sleep: time.sleep() would freeze the
            # event loop, so the workers could never observe the flag.
            cancel_flag.set()
            await asyncio.sleep(3)
            print("start_rtmp 4")
            for task in tasks:
                print("start_rtmp 5")
                task.cancel()
            for upload_task in upload_tasks:
                print("start_rtmp 6")
                upload_task.cancel()

    finally:
        # Runs for every exit path, including early returns during setup.
        print("start_rtmp 7")
        cancel_flag.set()
        print("start_rtmp 8")
        # Release the detector (CUDA memory) whenever it was created, also
        # when setup bailed out before any worker was started.
        if detector is not None:
            detector.destroy()

        # Cancel whatever is still running.
        if read_task and not read_task.done():
            print("start_rtmp 9")
            read_task.cancel()
        if process_task and not process_task.done():
            print("start_rtmp 10")
            process_task.cancel()
        if write_task and not write_task.done():
            print("start_rtmp 11")
            write_task.cancel()
        if invade_task and not invade_task.done():
            print("start_rtmp 12")
            invade_task.cancel()
        for task in upload_tasks:
            if task and not task.done():
                task.cancel()
        print("start_rtmp 13")
        # Release shared resources and unregister the job.
        # NOTE(review): remove_task is now also reached when setup returned
        # before add_task ran - confirm it tolerates an unknown task_id.
        await cleanup_resources()
        await task_manager.remove_task(task_id)
        logger.info(f"Task {task_id} resources cleaned up")
        print(f"start_rtmp_processing线程已停止")
|