2421 lines
102 KiB
Plaintext
2421 lines
102 KiB
Plaintext
import logging
|
||
import os.path
|
||
import subprocess
|
||
from asyncio import timeout
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
from typing import List, Dict, Any, Optional, Tuple
|
||
from dataclasses import dataclass
|
||
import json
|
||
import time
|
||
|
||
from scipy.optimize import brent
|
||
from sympy import false
|
||
from sympy.codegen.ast import continue_
|
||
from ultralytics import YOLO
|
||
import torch
|
||
|
||
# 假设其他导入模块正确存在
|
||
from cv_back_video import cal_tricker_results, get_local_drc_message, read_drc_mqtt, async_read_drc_mqtt
|
||
from middleware.MQTTService import MQTTService, MQTTDevice
|
||
from middleware.entity.air_attitude import Air_Attitude
|
||
from middleware.entity.up_drc_camera_osd_info_push import parse_camera_osd_info
|
||
from middleware.entity.up_osd_info_push import parse_osd_message
|
||
from middleware.minio_util import upload_file_from_buffer, upload_frame_buff_from_buffer, downFile, \
|
||
downBigFile
|
||
import av
|
||
import asyncio
|
||
import cv2
|
||
import numpy as np
|
||
from typing import Dict, Any
|
||
from ultralytics import solutions
|
||
from middleware.read_srt import parse_srt_file
|
||
from touying.ImageReproject_python.cal_func import cal_canv_location_by_osd, red_line_reproject
|
||
from touying.ImageReproject_python.img_types import Point
|
||
|
||
# Configure root logging once, at import time: INFO level, with a
# timestamp / logger-name / level prefix on every record.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger shared by the functions in this file.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
# Global configuration.
class Config:
    """Static tuning constants for the detection / streaming pipeline.

    Read as class attributes (``Config.MAX_WORKERS``); never instantiated.
    """

    MAX_WORKERS = 8  # size of the image-recognition (inference) thread pool
    FRAME_QUEUE_SIZE = 120  # capacity of the raw-frame queue
    PROCESSED_QUEUE_SIZE = 60  # capacity of each processed-frame queue
    RETRY_COUNT = 3  # number of retries when loading a model
    UPLOAD_WORKERS = 2  # size of the dedicated upload thread pool
    # Queue-length warning threshold.
    # NOTE(review): the previous comment described this as "half of
    # PROCESSED_QUEUE_SIZE", but 15 is a quarter of 60 — confirm which value is
    # intended. The value itself is unchanged.
    MAX_QUEUE_WARN_SIZE = 15
    # YOLO model input size as (width, height), the order cv2.resize expects;
    # each side must be divisible by 32.
    MODEL_INPUT_SIZE = (640, 640)
    # First GPU when CUDA is available, otherwise the CPU.
    DEFAULT_DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
    TARGET_FPS = 25  # target frame rate of the output (push) stream
    MAX_FRAME_DROP = 5  # maximum frames allowed to be dropped (avoids queue backlog)
|
||
|
||
|
||
# # 创建专用线程池
|
||
# upload_executor = ThreadPoolExecutor(max_workers=Config.UPLOAD_WORKERS)
|
||
# frame_queue = asyncio.Queue(maxsize=Config.FRAME_QUEUE_SIZE)
|
||
# processed_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
|
||
# invade_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
|
||
# cv_frame_queue = asyncio.Queue(maxsize=Config.PROCESSED_QUEUE_SIZE)
|
||
# executor = ThreadPoolExecutor(max_workers=Config.MAX_WORKERS)
|
||
# stop_event = asyncio.Event()
|
||
|
||
|
||
# Model cache (no reads or writes of it appear in this section of the file).
model_cache = dict()

# Running performance counters. In this section of the file only
# 'dropped_frames' is updated (by process_frames); the rest are initial values.
stats = dict(
    processed=0,
    last_time=time.time(),
    avg_fps=0.0,
    dropped_frames=0,
)
|
||
|
||
|
||
# These resources are created (and, in cleanup_resources, destroyed) on every
# run instead of once at import time.
async def initialize_resources():
    """Create the per-run thread pools, queues and stop event as module globals.

    Rebinds ``upload_executor``, ``executor``, ``frame_queue``,
    ``processed_queue``, ``invade_queue``, ``cv_frame_queue`` and
    ``stop_event``; pool and queue sizes come from ``Config``.
    """
    global upload_executor, executor, frame_queue, processed_queue, invade_queue, cv_frame_queue, stop_event

    # Thread pools: a small one for uploads, a larger one for inference.
    upload_executor, executor = (
        ThreadPoolExecutor(max_workers=Config.UPLOAD_WORKERS),
        ThreadPoolExecutor(max_workers=Config.MAX_WORKERS),
    )

    # Raw frames get their own capacity; the three downstream queues share one.
    frame_queue = asyncio.Queue(maxsize=Config.FRAME_QUEUE_SIZE)
    downstream_size = Config.PROCESSED_QUEUE_SIZE
    processed_queue = asyncio.Queue(maxsize=downstream_size)
    invade_queue = asyncio.Queue(maxsize=downstream_size)
    cv_frame_queue = asyncio.Queue(maxsize=downstream_size)

    stop_event = asyncio.Event()
|
||
|
||
|
||
# async def cleanup_resources():
|
||
# """清理资源"""
|
||
# global upload_executor, executor, frame_queue, processed_queue, invade_queue, cv_frame_queue, stop_event
|
||
# stop_event.set()
|
||
# if 'executor' in globals() and executor is not None:
|
||
# executor.shutdown(wait=True)
|
||
# if 'upload_executor' in globals() and upload_executor is not None:
|
||
# upload_executor.shutdown(wait=True)
|
||
|
||
|
||
async def cleanup_resources(grace_period: float = 1.0):
    """Stop the pipeline and release the per-run resources.

    Sets ``stop_event``, waits *grace_period* seconds so frames already in
    flight can finish, drains every frame queue, shuts down both thread pools
    and closes every registered push-stream container.

    Each resource is looked up in the module globals and skipped when it does
    not exist, so calling this before initialize_resources() does not raise
    NameError.

    :param grace_period: seconds to wait after signalling stop (default 1.0,
        the previously hard-coded value; 0 skips the wait).
    """
    module_globals = globals()

    stop = module_globals.get('stop_event')
    if stop is not None:
        stop.set()

    # Give frames that are currently being processed some time to finish.
    if grace_period > 0:
        await asyncio.sleep(grace_period)

    # Drain the queues; get_nowait() never suspends, so this runs to completion.
    for queue_name in ('frame_queue', 'processed_queue', 'invade_queue', 'cv_frame_queue'):
        queue = module_globals.get(queue_name)
        if queue is None:
            continue
        try:
            while True:
                queue.get_nowait()
        except asyncio.QueueEmpty:
            pass
        except Exception as e:
            logger.warning(f"清理队列时出错: {e}")

    # shutdown(wait=True) blocks until running jobs finish; do it in a worker
    # thread so the event loop is not frozen meanwhile.
    for pool_name in ('executor', 'upload_executor'):
        pool = module_globals.get(pool_name)
        if pool is not None:
            await asyncio.to_thread(pool.shutdown, wait=True, cancel_futures=True)

    # Close every push-stream container and forget them.
    containers = module_globals.get('stream_containers')
    if containers is not None:
        for container in containers.values():
            try:
                if 'container' in container:
                    container['container'].close()
            except Exception as e:
                logger.warning(f"关闭推流容器时出错: {e}")
        containers.clear()
|
||
|
||
|
||
class MultiDetectionResults:
    """Mutable container of per-detection lists.

    Presumably the lists are kept index-aligned (entry ``i`` of each describes
    the same detection) — TODO confirm against the callers.
    """

    def __init__(self):
        self.boxes = []  # bounding boxes
        self.clss = []  # class ids
        self.cls_names = []  # class names
        self.cls_en_names = []  # English class names
        self.confs = []  # confidence scores
        self.track_ids = []  # tracker ids (previously mislabelled as "confidence")
|
||
|
||
|
||
@dataclass
class DetectionResult:
    """A single detection: integer box, class, confidence and tracker id."""

    bbox: List[int]  # [x1, y1, x2, y2]
    class_id: int
    class_name: str
    confidence: float
    # Tracker id. Producers in this file pass either an int (-1 meaning "no id")
    # or None, so the annotation is Optional.
    track_id: Optional[int]
|
||
|
||
|
||
@dataclass
class DetectionResultList:
    """Per-model detections stored as parallel lists.

    The producers in this file append to all five lists together, so index
    ``i`` of every list describes the same detection. Construct positionally,
    e.g. ``DetectionResultList([], [], [], [], [])``.
    """

    boxes: List[Any]  # each entry is [x1, y1, x2, y2]
    clss: List[Any]  # class ids
    clss_name: List[Any]  # class names
    confs: List[Any]  # confidence scores
    track_ids: List[Any]  # tracker ids (-1 or None when unavailable)

    def __iter__(self):
        """Yield ``(bbox, cls_id, cls_name, conf, track_id)`` for each detection.

        The lists are walked in lockstep; iteration stops at the shortest list.
        """
        return zip(self.boxes, self.clss, self.clss_name, self.confs, self.track_ids)
|
||
|
||
|
||
class MultiYOLODetector:
    """Run several YOLO models on the same frame in parallel and merge the results.

    Every model is forced onto ``Config.DEFAULT_DEVICE`` (avoids multi-GPU
    device mismatches). Frames are resized to ``Config.MODEL_INPUT_SIZE``
    without letterboxing, and boxes are scaled back to the original frame size.
    """

    def __init__(self, model_configs: List[Dict], conf_threshold: float = 0.45):
        """Load one model per entry of *model_configs*.

        Each config dict must contain ``'path'``. Optional keys: ``cls_map``,
        ``allowed_classes``, ``tracking``, ``cls_index``, ``chinese_label``,
        ``list_func_id``, ``func_id``, ``class_names``, ``para_invade_enable``.
        ``func_id`` is stored as a single value, so the last config's value wins.

        :param conf_threshold: detections with a lower confidence are discarded
            (default 0.45, the previously hard-coded value).
        """
        self.models = []
        self.class_maps = []
        self.allowed_classes = []
        self.model_cls = []
        self.chinese_label = []
        self.list_func_id = []
        self.func_id = -1
        self.list_class_names = []
        self.list_para_invade_enable = []
        self.conf_threshold = conf_threshold
        self.input_size = Config.MODEL_INPUT_SIZE  # shared (width, height) input size
        self.device = Config.DEFAULT_DEVICE  # every model is forced onto this device
        self._setup_multi_gpu()

        for config in model_configs:
            model_path = config['path']
            cls_map = config.get('cls_map', {})
            allowed = config.get('allowed_classes', None)
            tracking = config.get('tracking', True)
            model_cls_index = config.get('cls_index', [])
            model_chinese_labe = config.get('chinese_label', [])
            model_list_func_id = config.get('list_func_id', [])
            func_id = config.get('func_id', -1)
            class_names = config.get('class_names', True)
            para_invade_enable = config.get('para_invade_enable', False)

            # Load the model onto the configured device.
            model = self._load_model(model_path, tracking)
            self.models.append(model)
            self.class_maps.append(cls_map)
            self.allowed_classes.append(allowed)
            self.model_cls.append(model_cls_index)
            self.chinese_label.append(model_chinese_labe)
            self.list_func_id.append(model_list_func_id)
            self.func_id = func_id
            self.list_class_names.append(class_names)
            self.list_para_invade_enable.append(para_invade_enable)

    def _setup_multi_gpu(self):
        """Make ``self.device`` the current CUDA device when CUDA is available."""
        if torch.cuda.is_available():
            gpu_count = torch.cuda.device_count()
            print(f"检测到 {gpu_count} 个GPU,将使用默认设备: {self.device}")
            # Force all CUDA operations onto the default device.
            torch.cuda.set_device(self.device)

    def _load_model(self, model_path: str, tracking: bool) -> torch.nn.Module:
        """Load a YOLO model onto ``self.device`` with layer fusion disabled.

        :param tracking: when true and the model exposes a ``tracker``
            attribute, it is configured to use BoT-SORT. (This flag was
            previously ignored here.)
        :raises Exception: any loading error is printed and re-raised.
        """
        try:
            model = YOLO(model_path)
            model = model.to(self.device)

            if tracking and hasattr(model, 'tracker'):
                model.tracker = "botsort.yaml"  # or "bytetrack.yaml"
                print(f"已为模型 {model_path} 启用跟踪")
            elif tracking:
                print(f"警告:模型 {model_path} 不支持跟踪功能")

            # With several GPUs, wrap in DataParallel pinned to the default device.
            if torch.cuda.is_available() and torch.cuda.device_count() > 1:
                model = model.to(self.device)
                model = torch.nn.DataParallel(model, device_ids=[self.device])

            # Replace fuse() with a no-op to avoid multi-device problems.
            model.fuse = lambda verbose=False: model
            return model
        except Exception as e:
            print(f"模型加载失败: {e}")
            raise

    @staticmethod
    def preprocess_frame(frame: np.ndarray, input_size: tuple, device: str) -> torch.Tensor:
        """Resize a BGR frame to *input_size* (width, height) and return a 1xCxHxW RGB tensor in [0, 1] on *device*."""
        resized = cv2.resize(frame, input_size)
        rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
        tensor = torch.from_numpy(rgb).permute(2, 0, 1).float() / 255.0
        # Add the batch dimension and move to the target device.
        return tensor.unsqueeze(0).to(device)

    @staticmethod
    def scale_bbox(bbox: List[float], original_size: tuple, input_size: tuple) -> List[int]:
        """Scale a model-space box back to the original image and truncate to int.

        :param original_size: (height, width) of the original frame.
        :param input_size: (width, height) of the model input.
        """
        oh, ow = original_size
        iw, ih = input_size
        scale_w = ow / iw
        scale_h = oh / ih
        x1, y1, x2, y2 = bbox
        return [
            int(x1 * scale_w),
            int(y1 * scale_h),
            int(x2 * scale_w),
            int(y2 * scale_h),
        ]

    def _predict_one(self, model_idx: int, frame: np.ndarray, original_size: tuple):
        """Run model *model_idx* on *frame* (blocking; called from a worker thread).

        :param original_size: (height, width) of *frame*.
        :returns: ``(detections, detection_result_list, model_para)`` for this model.
        """
        model = self.models[model_idx]
        cls_map = self.class_maps[model_idx]
        allowed = self.allowed_classes[model_idx]
        # DataParallel hides the YOLO object (and its `names`) behind `.module`.
        original_model = model.module if isinstance(model, torch.nn.DataParallel) else model

        input_tensor = self.preprocess_frame(frame, self.input_size, self.device)
        with torch.no_grad():
            # Make sure the input and the model live on the same device.
            model_device = next(model.parameters()).device
            if input_tensor.device != model_device:
                input_tensor = input_tensor.to(model_device)
            results = model(input_tensor, verbose=False)

        model_para = {
            "cls_map": cls_map,
            "model_chinese_labe": self.chinese_label[model_idx],
            "model_cls_index": self.model_cls[model_idx],
            "model_list_func_id": self.list_func_id[model_idx],
            "func_id": self.func_id,
            "model_class_names": self.list_class_names[model_idx],
            "para_invade_enable": self.list_para_invade_enable[model_idx],
            "results": results,
        }

        detections = []
        detection_result_list = DetectionResultList([], [], [], [], [])
        for result in results:
            boxes = result.boxes.cpu().numpy()  # move to CPU for post-processing
            tracks = result.boxes.id.cpu().numpy() if result.boxes.id is not None else None

            for i, box in enumerate(boxes):
                cls_id = int(box.cls[0])
                cls_name = original_model.names[cls_id]
                if allowed and cls_name not in allowed:
                    continue

                conf = float(box.conf[0])
                if conf < self.conf_threshold:
                    continue

                # Boxes are in model-input space. Scale the float coordinates and
                # truncate once; truncating first multiplied the rounding error by
                # the scale factor.
                x1, y1, x2, y2 = (float(v) for v in box.xyxy[0])
                scaled_bbox = self.scale_bbox([x1, y1, x2, y2], original_size, self.input_size)

                # Use the tracker id when present; -1 remains the "no id" sentinel
                # (plain inference yields no ids).
                track_id = int(tracks[i]) if tracks is not None else -1

                detection_result_list.boxes.append(scaled_bbox)
                detection_result_list.clss.append(cls_id)
                detection_result_list.clss_name.append(cls_name)
                detection_result_list.confs.append(conf)
                detection_result_list.track_ids.append(track_id)

                detections.append(DetectionResult(
                    bbox=scaled_bbox,
                    class_id=cls_id,
                    class_name=cls_name,
                    confidence=conf,
                    track_id=track_id,
                ))

        return detections, detection_result_list, model_para

    async def predict(self, frame: np.ndarray) -> Tuple[List[DetectionResult], List, List]:
        """Run every model on *frame* concurrently in the module-level ``executor``.

        :returns: ``(all_detections, detections_list, all_model_paras)``:
            a flat list of DetectionResult from all models, plus one
            DetectionResultList and one parameter dict per model, in model order.
        """
        loop = asyncio.get_running_loop()
        original_size = (frame.shape[0], frame.shape[1])  # (height, width)

        futures = [
            loop.run_in_executor(executor, self._predict_one, model_idx, frame.copy(), original_size)
            for model_idx in range(len(self.models))
        ]
        per_model = await asyncio.gather(*futures)

        all_detections = []
        detections_list = []
        all_model_paras = []
        for detections, detection_result_list, model_para in per_model:
            all_detections.extend(detections)
            detections_list.append(detection_result_list)
            all_model_paras.append(model_para)
        return all_detections, detections_list, all_model_paras
|
||
|
||
|
||
class MultiYOLODetector_TrackId:
    """Multi-model detector that tracks objects via ultralytics ``solutions.ObjectCounter``.

    Each config entry gets a YOLO model (``self.models``) and an ObjectCounter
    built from the same weights path. Detection and tracking run through the
    ObjectCounter on the full-resolution frame.
    """

    def __init__(self, model_configs: List[Dict], conf_threshold: float = 0.45):
        """Load one model and one ObjectCounter per entry of *model_configs*.

        Each config dict must contain ``'path'``. ``func_id`` is stored as a
        single value, so the last config's value wins.

        :param conf_threshold: detections with a lower confidence are discarded
            (default 0.45, the previously hard-coded value).
        """
        self.models = []
        self.class_maps = []
        self.allowed_classes = []
        self.model_cls = []
        self.chinese_label = []
        self.list_func_id = []
        self.func_id = -1
        self.list_class_names = []
        self.list_para_invade_enable = []
        self.conf_threshold = conf_threshold
        self.input_size = Config.MODEL_INPUT_SIZE  # shared (width, height) input size
        self.device = Config.DEFAULT_DEVICE  # every model is forced onto this device
        self._setup_multi_gpu()

        # One ObjectCounter per model.
        self.object_counters = []

        for config in model_configs:
            model_path = config['path']
            cls_map = config.get('cls_map', {})
            allowed = config.get('allowed_classes', None)
            tracking = config.get('tracking', True)
            # NOTE(review): these defaults (True / True / -11 / True) differ from
            # MultiYOLODetector's ([] / [] / [] / -1), and -11 looks like a typo.
            # They are kept unchanged because they flow into model_para downstream.
            model_cls_index = config.get('cls_index', True)
            model_chinese_labe = config.get('chinese_label', True)
            model_list_func_id = config.get('list_func_id', -11)
            func_id = config.get('func_id', True)
            class_names = config.get('class_names', True)
            para_invade_enable = config.get('para_invade_enable', False)

            model = self._load_model(model_path, tracking)
            self.models.append(model)
            self.class_maps.append(cls_map)
            self.allowed_classes.append(allowed)
            self.model_cls.append(model_cls_index)
            self.chinese_label.append(model_chinese_labe)
            self.list_func_id.append(model_list_func_id)
            self.func_id = func_id
            self.list_class_names.append(class_names)
            self.list_para_invade_enable.append(para_invade_enable)

            # The two former branches (tracker / no tracker) built identical
            # counters, so they are merged.
            # NOTE(review): ObjectCounter loads its own copy of the weights from
            # model_path, so each model appears to be held in memory twice.
            self.object_counters.append(
                solutions.ObjectCounter(
                    show=False,
                    region=None,  # a counting region can be set here if needed
                    model=model_path,
                    classes=model_cls_index if allowed else None,
                )
            )

    def _setup_multi_gpu(self):
        """Make ``self.device`` the current CUDA device when CUDA is available."""
        if torch.cuda.is_available():
            gpu_count = torch.cuda.device_count()
            print(f"检测到 {gpu_count} 个GPU,将使用默认设备: {self.device}")
            # Force all CUDA operations onto the default device.
            torch.cuda.set_device(self.device)

    def _load_model(self, model_path: str, tracking: bool) -> torch.nn.Module:
        """Load a YOLO model onto ``self.device`` with layer fusion disabled.

        :param tracking: when true and the model exposes a ``tracker``
            attribute, it is configured to use BoT-SORT.
        :raises Exception: any loading error is printed and re-raised.
        """
        try:
            model = YOLO(model_path)
            model = model.to(self.device)

            if tracking and hasattr(model, 'tracker'):
                model.tracker = "botsort.yaml"  # or "bytetrack.yaml"
                print(f"已为模型 {model_path} 启用跟踪")
            elif tracking:
                # Only warn when tracking was actually requested.
                print(f"警告:模型 {model_path} 不支持跟踪功能")

            # With several GPUs, wrap in DataParallel pinned to the default device.
            if torch.cuda.is_available() and torch.cuda.device_count() > 1:
                model = model.to(self.device)
                model = torch.nn.DataParallel(model, device_ids=[self.device])

            # Replace fuse() with a no-op to avoid multi-device problems.
            model.fuse = lambda verbose=False: model
            return model
        except Exception as e:
            print(f"模型加载失败: {e}")
            raise

    @staticmethod
    def preprocess_frame(frame: np.ndarray, input_size: tuple, device: str) -> torch.Tensor:
        """Resize a BGR frame to *input_size* (width, height) and return a 1xCxHxW RGB tensor in [0, 1] on *device*."""
        resized = cv2.resize(frame, (input_size[0], input_size[1]))  # explicit (width, height)
        logger.debug(f"预处理后图像尺寸: {resized.shape}")
        rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
        tensor = torch.from_numpy(rgb).permute(2, 0, 1).float() / 255.0
        # Add the batch dimension and move to the target device.
        return tensor.unsqueeze(0).to(device)

    @staticmethod
    def scale_bbox(bbox, original_size, input_size):
        """Scale a model-space box to the original image, clamped to its bounds.

        :param original_size: (width, height) of the original frame.
        :param input_size: (width, height) of the model input.
        """
        ow, oh = original_size
        iw, ih = input_size
        scale_w = ow / iw
        scale_h = oh / ih
        x1, y1, x2, y2 = bbox
        x1 = int(max(0, min(x1 * scale_w, ow - 1)))
        y1 = int(max(0, min(y1 * scale_h, oh - 1)))
        x2 = int(max(0, min(x2 * scale_w, ow - 1)))
        y2 = int(max(0, min(y2 * scale_h, oh - 1)))
        return [x1, y1, x2, y2]

    def _predict_one(self, model_idx: int, frame: np.ndarray):
        """Run ObjectCounter *model_idx* on *frame* (blocking; worker thread).

        Never raises: on any error the traceback is logged and
        ``([], DetectionResultList([], [], [], [], []), {})`` is returned.
        """
        try:
            object_counter = self.object_counters[model_idx]
            allowed = self.allowed_classes[model_idx]

            model_para = {
                "cls_map": self.class_maps[model_idx],
                "model_chinese_labe": self.chinese_label[model_idx],
                "model_cls_index": self.model_cls[model_idx],
                "model_list_func_id": self.list_func_id[model_idx],
                "func_id": self.func_id,
                "model_class_names": self.list_class_names[model_idx],
                "para_invade_enable": self.list_para_invade_enable[model_idx],
                "results": object_counter,  # the counter; its attributes hold the latest results
            }

            # The counter works on the full-resolution frame, so its boxes need
            # no rescaling. The resized tensor from preprocess_frame was built
            # here and never used, so it is no longer computed per frame.
            timestart = time.time()
            with torch.no_grad():
                object_counter(frame)
            logger.debug(f"模型 {model_idx} 推理耗时: {time.time() - timestart:.4f}s")

            tracks = object_counter.track_ids
            confs = object_counter.confs
            clss = object_counter.clss
            names = object_counter.names
            result_boxes = object_counter.boxes
            logger.debug(f"模型输出 boxes 原始值: {result_boxes}")

            # The (modified) YOLO11 counter stores boxes as a list; otherwise convert.
            if isinstance(result_boxes, list):
                if not result_boxes:
                    return [], DetectionResultList([], [], [], [], []), {}
                boxes = result_boxes
            else:
                boxes = result_boxes.tolist()

            detections = []
            detection_result_list = DetectionResultList([], [], [], [], [])
            for i, box in enumerate(boxes):
                try:
                    if not hasattr(box, '__iter__'):
                        logger.warning(f"无效的边界框格式: {box}")
                        continue
                    box = box if isinstance(box, list) else list(box)
                    if len(box) < 4:
                        logger.warning(f"不完整的边界框坐标: {box}")
                        continue

                    x1, y1, x2, y2 = box[:4]
                    cls_id = int(clss[i])
                    cls_name = names[cls_id]

                    conf = float(confs[i]) if i < len(confs) else 0.0
                    if conf < self.conf_threshold:
                        continue

                    track_id = int(tracks[i]) if i < len(tracks) and tracks[i] is not None else None

                    # Skip classes that are not allowed.
                    if allowed and cls_name not in allowed:
                        continue

                    # Already in frame coordinates; only truncate to int.
                    # detection_result_list keeps the raw coordinates, as before.
                    int_bbox = [int(x1), int(y1), int(x2), int(y2)]

                    detection_result_list.boxes.append([x1, y1, x2, y2])
                    detection_result_list.clss.append(cls_id)
                    detection_result_list.clss_name.append(cls_name)
                    detection_result_list.confs.append(conf)
                    detection_result_list.track_ids.append(track_id)

                    detections.append(DetectionResult(
                        bbox=int_bbox,
                        class_id=cls_id,
                        class_name=cls_name,
                        confidence=conf,
                        track_id=track_id,
                    ))
                except Exception as e:
                    logger.warning(f"处理单个检测结果时出错: {str(e)}")

            return detections, detection_result_list, model_para

        except Exception as e:
            logger.exception(f"模型 {model_idx} 预测过程中发生错误: {str(e)}")
            return [], DetectionResultList([], [], [], [], []), {}

    async def predict(self, frame: np.ndarray) -> Tuple[List[DetectionResult], List, List]:
        """Run every ObjectCounter on *frame* concurrently in the module-level ``executor``.

        :returns: ``(all_detections, detections_list, all_model_paras)``:
            a flat list of DetectionResult from all models, plus one
            DetectionResultList and one parameter dict per model, in model order.
        """
        loop = asyncio.get_running_loop()
        futures = [
            loop.run_in_executor(executor, self._predict_one, model_idx, frame.copy())
            for model_idx in range(len(self.models))
        ]
        per_model = await asyncio.gather(*futures)

        all_detections = []
        detections_list = []
        all_model_paras = []
        for detections, detection_result_list, model_para in per_model:
            all_detections.extend(detections)
            detections_list.append(detection_result_list)
            all_model_paras.append(model_para)
        return all_detections, detections_list, all_model_paras
|
||
|
||
|
||
#
|
||
# async def read_rtmp_frames(video_url: str):
|
||
# """异步读取RTMP流帧"""
|
||
# print(f"开始视频拉流: {video_url}")
|
||
# cap = cv2.VideoCapture(video_url, cv2.CAP_FFMPEG)
|
||
#
|
||
# cap.set(cv2.CAP_PROP_OPEN_TIMEOUT_MSEC, 60000)
|
||
# cap.set(cv2.CAP_PROP_READ_TIMEOUT_MSEC, 50000)
|
||
#
|
||
# if not cap.isOpened():
|
||
# raise RuntimeError(f"无法打开视频流: {video_url}")
|
||
#
|
||
# frame_count = 0
|
||
# retry_count = 0
|
||
# max_retries = 10
|
||
#
|
||
# while not stop_event.is_set() and retry_count < max_retries:
|
||
# try:
|
||
# ret, frame = cap.read()
|
||
# if not ret:
|
||
# retry_count += 1
|
||
# if retry_count >= max_retries:
|
||
# raise RuntimeError("连续读取帧失败,达到最大重试次数")
|
||
# await asyncio.sleep(0.1)
|
||
# continue
|
||
#
|
||
# retry_count = 0
|
||
# frame_count += 1
|
||
#
|
||
# # 跳过前5帧以稳定流
|
||
# if frame_count <= 5:
|
||
# continue
|
||
#
|
||
# # 放入帧队列
|
||
# if not frame_queue.full():
|
||
# await frame_queue.put((frame, time.time()))
|
||
# else:
|
||
# await asyncio.sleep(0.01) # 队列满时稍作等待
|
||
#
|
||
# except Exception as e:
|
||
# print(f"读取帧时出错: {e}")
|
||
# retry_count += 1
|
||
# await asyncio.sleep(0.5)
|
||
#
|
||
# cap.release()
|
||
# print("RTMP读取线程已停止")
|
||
|
||
|
||
async def read_video_frames(task_id, mqtt, mqtt_publish_topic, local_video_path: str, srt_path: str):
    """Stream a local video file into ``frame_queue``, paired with SRT telemetry.

    Each enqueued item is ``(frame, Air_Attitude, timestamp)``. The attitude is
    built from the SRT entry matching the frame. Reading is throttled to the
    video's FPS. A status message ``{"task_id", "video_status"}`` is published
    to *mqtt_publish_topic*: 1 when reading starts, 2 when the video ends
    (``stop_event`` is then set). After a complete read, the video and SRT
    files are deleted.

    Frames are dropped (not enqueued) while ``frame_queue`` is full, or when no
    SRT entry exists for them.
    """
    stop_event.clear()

    srt_list = parse_srt_file(srt_path)

    # Open the file once; the capture streams frames rather than loading all of
    # them into memory. It was previously opened twice and the first handle leaked.
    cap = cv2.VideoCapture(local_video_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        cap.release()
        return

    async def publish_status(status_to_publish):
        """Publish the processing status (1 = started, 2 = finished) over MQTT."""
        try:
            message = {"task_id": task_id, "video_status": status_to_publish}
            await mqtt.publish(mqtt_publish_topic, json.dumps(message, ensure_ascii=False))
        except Exception as e:
            logger.error(f"Failed to publish status: {e}")
            raise

    try:
        orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        await publish_status(1)

        # Throttle reading so the queue does not back up and force mass frame
        # drops. Some files report 0 (or NaN) FPS; fall back to Config.TARGET_FPS
        # instead of dividing by zero.
        video_fps = cap.get(cv2.CAP_PROP_FPS)
        frame_interval = 1.0 / (video_fps if video_fps > 0 else Config.TARGET_FPS)

        prev_time = time.time()
        last_read_time = time.time()
        frame_count = 0

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:  # end of video or read error
                await publish_status(2)
                stop_event.set()
                break

            # Keep successive reads at least frame_interval apart. The reference
            # time is taken after sleeping; taking it before meant every other
            # frame skipped the wait, so reading ran about twice as fast.
            elapsed = time.time() - last_read_time
            if elapsed < frame_interval:
                await asyncio.sleep(frame_interval - elapsed)
            last_read_time = time.time()

            frame_count += 1
            if frame_count % 10 == 0:
                curr_time = time.time()
                fps = 10 / (curr_time - prev_time)
                prev_time = curr_time
                print(f" ({orig_width}x{orig_height}, FPS: {fps:.2f})")
            if frame_count % 100 == 0:
                print(f"Processed {frame_count} frames")

            if frame_queue.full():
                await asyncio.sleep(0.01)  # queue full: drop this frame and back off briefly
                continue

            # frame_count is 1-based, srt_list is 0-based. Assuming one SRT entry
            # per frame (TODO confirm with parse_srt_file), frame N pairs with
            # srt_list[N - 1]. Indexing with frame_count skipped the first entry
            # and the final frame.
            srt_index = frame_count - 1
            if srt_index < len(srt_list):
                dji_srt = srt_list[srt_index]
                art_tit = Air_Attitude(
                    dji_srt.gb_pitch,
                    dji_srt.gb_roll,
                    dji_srt.gb_yaw,
                    dji_srt.abs_alt,
                    dji_srt.latitude,
                    dji_srt.longitude,
                )
                await frame_queue.put((frame, art_tit, time.time()))
    finally:
        cap.release()

    # Remove the processed inputs (only reached after a complete read).
    if os.path.exists(srt_path):
        os.remove(srt_path)
    if os.path.exists(local_video_path):
        os.remove(local_video_path)
|
||
|
||
|
||
async def read_rtmp_frames(
    video_url: str,
    device: MQTTDevice = None,  # optional; None when OSD information is not needed
    topic_camera_osd: str = None,
    method_camera_osd: str = None,
    topic_osd_info: str = None,
    method_osd_info: str = None,
    open_timeout: float = 20.0,
):
    """Read an RTMP stream asynchronously and push BGR frames into ``frame_queue``.

    Each queue item is ``(image, osd_info, timestamp)``. *osd_info* is an
    ``Air_Attitude`` built from the latest OSD MQTT message when *device*,
    *topic_osd_info* and *method_osd_info* are all given; otherwise it is None.
    Frames are dropped while ``frame_queue`` is full. Reading stops when the
    stream ends or ``stop_event`` is set. Connection/stream errors are retried
    up to 5 times, 2 s apart.

    :param video_url: RTMP stream URL.
    :param device: optional MQTTDevice used to fetch OSD information.
    :param topic_camera_osd: currently unused.
    :param method_camera_osd: currently unused.
    :param topic_osd_info: MQTT topic of the OSD information (optional).
    :param method_osd_info: MQTT method of the OSD information (optional).
    :param open_timeout: seconds PyAV waits for data before giving up. A stalled
        or disconnected stream then raises (and is retried) instead of hanging
        the process.
    :raises RuntimeError: when every attempt fails with a stream error.
    """
    max_retries = 5  # maximum number of attempts
    retry_delay = 2  # seconds between attempts
    loop = asyncio.get_running_loop()

    # PyAV calls (connect, wait for the next frame, close) block. Running them on
    # one dedicated worker thread keeps the event loop free for the detection
    # coroutines. Previously `await frame_queue.put()` never yielded while the
    # queue had room, so other tasks were starved. With a single thread,
    # close() also always runs after any in-flight decode.
    av_thread = ThreadPoolExecutor(max_workers=1, thread_name_prefix="rtmp-av")

    try:
        for attempt in range(max_retries):
            container = None
            try:
                print(f"尝试连接 RTMP 流 (尝试 {attempt + 1}/{max_retries}): {video_url}")

                container = await loop.run_in_executor(
                    av_thread, lambda: av.open(video_url, timeout=open_timeout)
                )

                video_stream = next((s for s in container.streams if s.type == 'video'), None)
                if video_stream is None:
                    raise RuntimeError(f"RTMP 流中没有视频轨道: {video_url}")

                print(f"成功连接到 RTMP 流: {video_url} ({video_stream.width}x{video_stream.height})")

                frames = container.decode(video=0)
                while True:
                    if stop_event.is_set():
                        print("收到停止信号,终止 RTMP 读取")
                        break

                    frame = await loop.run_in_executor(av_thread, next, frames, None)
                    if frame is None:  # decoder exhausted
                        break

                    try:
                        img = frame.to_ndarray(format='bgr24')

                        osd_info = None
                        if device and topic_osd_info and method_osd_info:
                            osd_msg = device.get_latest_message(topic=topic_osd_info, method=method_osd_info)
                            if osd_msg and hasattr(osd_msg, 'data'):
                                osd_info = Air_Attitude(
                                    gimbal_pitch=osd_msg.data.gimbal_pitch,
                                    gimbal_roll=osd_msg.data.gimbal_roll,
                                    gimbal_yaw=osd_msg.data.gimbal_yaw,
                                    height=osd_msg.data.height,
                                    latitude=osd_msg.data.latitude,
                                    longitude=osd_msg.data.longitude,
                                )

                        if not frame_queue.full():
                            await frame_queue.put((img, osd_info, time.time()))
                        else:
                            await asyncio.sleep(0.001)  # queue full: drop the frame, yield briefly
                    except Exception as frame_error:
                        print(f"处理单帧时出错: {frame_error}")

                # Normal exit from the read loop: the stream ended (or stop was requested).
                print("RTMP 流已结束")
                break

            except (av.AVError, IOError) as e:
                print(f"RTMP 流错误 (尝试 {attempt + 1}/{max_retries}): {e}")
                if attempt < max_retries - 1:
                    await asyncio.sleep(retry_delay)
                else:
                    raise RuntimeError(f"无法连接 RTMP 流 (尝试 {max_retries} 次后失败): {video_url}") from e

            except Exception as e:
                print(f"未知错误: {e}")
                if attempt < max_retries - 1:
                    await asyncio.sleep(retry_delay)
                else:
                    raise

            finally:
                # Always close the container, queued behind any decode still
                # running on av_thread (e.g. when this coroutine was cancelled).
                if container:
                    av_thread.submit(container.close)
    finally:
        # Queued work (pending close calls) still runs; don't block the loop on it.
        av_thread.shutdown(wait=False)
|
||
# cv2.destroyAllWindows()
|
||
|
||
|
||
# async def process_frames(detector: MultiYOLODetector):
|
||
async def process_frames(detector: MultiYOLODetector_TrackId):
    """Run detection on queued frames and forward the results to ``processed_queue``.

    Items taken from ``frame_queue`` are unpacked as ``(frame, osd_info, timestamp)``
    and passed to ``detector.predict``. Every frame is forwarded, including frames
    without detections, which carry ``predict_state=False``. ``write_results_to_rtmp``
    can then still push them unannotated through its ``predict_state`` False branch.

    A frame is dropped, and ``stats['dropped_frames']`` incremented, when
    ``processed_queue`` is full or when prediction raises.
    """
    while not stop_event.is_set():
        try:
            frame, osd_info, timestamp = await asyncio.wait_for(
                frame_queue.get(),
                timeout=0.5  # generous timeout to accommodate slow inference
            )

            try:
                detections, detections_list, model_para = await detector.predict(frame)

                # predict_state tells the RTMP writer whether there is anything to draw.
                predict_state = bool(detections)
                if not predict_state:
                    logger.debug("未检测到任何目标")

                processed_data = {
                    'frame': frame,
                    'osd_info': osd_info,
                    'detections': detections,
                    'detections_list': detections_list,
                    'timestamp': timestamp,
                    'model_para': model_para,
                    'predict_state': predict_state
                }

                if not processed_queue.full():
                    await processed_queue.put(processed_data)
                else:
                    logger.warning("处理队列已满,丢弃帧")
                    stats['dropped_frames'] += 1

            except Exception as e:
                logger.error(f"处理帧时出错: {e}", exc_info=True)
                stats['dropped_frames'] += 1
                await asyncio.sleep(0.1)

        except asyncio.TimeoutError:
            continue
        except Exception as e:
            logger.error(f"获取帧时发生意外错误: {e}", exc_info=True)
            await asyncio.sleep(0.1)
|
||
|
||
|
||
from PIL import Image, ImageDraw, ImageFont
|
||
|
||
# Font used by put_chinese_text to render labels. If the TrueType file cannot be
# loaded, fall back to PIL's default font.
# NOTE(review): the default font may lack CJK glyphs, so Chinese labels could
# render as boxes in that case. Confirm the font ships with deployments.
try:
    chinese_font = ImageFont.truetype("config/SIMSUN.TTC", 20)
except Exception as font_error:  # best effort: a missing font must not break module import
    logger.warning("Failed to load font config/SIMSUN.TTC (%s); using PIL default font", font_error)
    chinese_font = ImageFont.load_default()
|
||
|
||
|
||
def put_chinese_text(img, text, position, font=chinese_font, color=(0, 255, 0)):
    """Draw *text* onto a BGR image with a PIL font and return the result as a new BGR image.

    The image is round-tripped through PIL because cv2.putText's built-in
    Hershey fonts only cover ASCII. *img* itself is not modified: the drawing
    happens on a converted copy. *color* is applied on the RGB canvas, so it is
    an (R, G, B) tuple. *font* defaults to the module-level ``chinese_font``,
    bound when this function is defined.
    """
    rgb_canvas = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    painter = ImageDraw.Draw(rgb_canvas)
    painter.text(position, text, font=font, fill=color)
    rgb_pixels = np.array(rgb_canvas)
    return cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)
|
||
|
||
|
||
# Live push outputs keyed by output URL. Each value is a dict holding the PyAV
# 'container' and 'stream', the 'last_frame_time' of the most recent mux and a
# 'frame_count' (entries are created and removed by write_results_to_rtmp).
stream_containers: Dict[str, Any] = dict()
|
||
|
||
from collections import defaultdict
|
||
import time
|
||
from typing import Dict, Any
|
||
|
||
|
||
class TrackIDEventFilter:
    """Debounce per-track_id events: report only every ``hits_to_report``-th sighting.

    Each call to :meth:`should_report` counts one sighting of a track_id. Once the
    count reaches ``hits_to_report``, the call returns True and the count restarts
    from zero. A track_id that has not been seen for more than ``max_inactive_time``
    seconds is forgotten, so its count also restarts.
    """

    def __init__(self, max_inactive_time: float = 5.0, hits_to_report: int = 2, clock=None):
        """
        Initialise the filter.

        :param max_inactive_time: seconds after which an unseen track_id is dropped.
        :param hits_to_report: sightings needed before ``should_report`` returns True.
            The default of 2 reproduces the original "every second sighting" behaviour.
        :param clock: optional zero-argument callable returning the current time in
            seconds. Defaults to ``time.time``; inject a fake clock for testing.
        :raises ValueError: if ``hits_to_report`` is smaller than 1.
        """
        if hits_to_report < 1:
            raise ValueError(f"hits_to_report must be >= 1, got {hits_to_report}")
        self.track_status = defaultdict(int)  # sightings since the last report, per track_id
        self.last_active_time = defaultdict(float)  # time each track_id was last seen
        self.max_inactive_time = max_inactive_time  # eviction threshold in seconds
        self.hits_to_report = hits_to_report
        self._clock = clock

    def _now(self) -> float:
        """Return the current time from the injected clock, or ``time.time()``."""
        return self._clock() if self._clock is not None else time.time()

    def _evict_inactive(self, now: float) -> None:
        """Forget every track_id whose last sighting is older than ``max_inactive_time``."""
        expired = [
            tid for tid, last_seen in self.last_active_time.items()
            if (now - last_seen) > self.max_inactive_time
        ]
        for tid in expired:
            self.track_status.pop(tid, None)
            self.last_active_time.pop(tid, None)

    def should_report(self, track_id: int) -> bool:
        """
        Record one sighting of *track_id* and decide whether to raise an event.

        :param track_id: tracker ID of the detection.
        :return: True on every ``hits_to_report``-th sighting (counting restarts after
            a report or after the ID has been inactive too long), False otherwise.
        """
        now = self._now()
        self._evict_inactive(now)
        self.last_active_time[track_id] = now

        hits = self.track_status[track_id] + 1
        if hits >= self.hits_to_report:
            # Threshold reached: report, then require a fresh run of sightings.
            self.track_status[track_id] = 0
            return True
        self.track_status[track_id] = hits
        return False
|
||
|
||
|
||
async def write_results_to_rtmp(task_id: str, output_url: str = None, input_fps: float = None,
                                list_points: list[list[any]] = None, invade_state: bool = False):
    """Consume ``processed_queue``, annotate frames and push them to ``output_url``.

    For every processed item whose ``predict_state`` is truthy, the frame is annotated:
    bounding boxes, ``confidence:track_id`` labels, the reprojected red-line polygon
    (when ``invade_state`` is set) and a per-class count panel. The annotated item is
    then handed to ``invade_queue`` when ``invade_state`` is True, otherwise to
    ``cv_frame_queue``. Items with a falsy ``predict_state`` are pushed unannotated.

    Every frame is encoded with libx264 and muxed into an FLV container opened on
    ``output_url`` and cached in ``stream_containers``. When the loop ends, the debug
    video writer is released and this URL's push container is flushed and closed.

    :param task_id: task identifier (not used by this function).
    :param output_url: push target; when falsy nothing is pushed or recorded.
    :param input_fps: pacing rate; falls back to ``Config.TARGET_FPS``.
    :param list_points: red-line polygons passed to ``red_line_reproject``. Only the
        last polygon's reprojection is kept. ``None`` is treated as no polygons.
    :param invade_state: if True, only detections whose box centre lies inside the
        reprojected polygon are labelled and boxed.
    """
    global stream_containers

    print(f"推流地址: {output_url}")

    # x264 options tuned for low-latency live push.
    options = {
        'preset': 'veryfast',
        'tune': 'zerolatency',
        'crf': '23',
        'g': '50',  # keyframe interval
        'threads': '2',  # cap encoder threads
    }
    codec_name = 'libx264'

    # Debug-only local recording; frames are currently never written to it.
    output_video_path = None
    video_writer = None
    fps = input_fps or Config.TARGET_FPS  # prefer the caller-supplied frame rate

    if output_url:
        output_video_path = r"D:\project\AI-PYTHON\Ai_tottle\save_pic\a12.mp4"
        print(f"将处理后的视频保存到: {output_video_path}")

    frame_interval = 1.0 / fps
    last_frame_time = time.time()

    try:
        while not stop_event.is_set():
            try:
                # Pace the loop to ``fps``. The reference time is taken *after* sleeping.
                # Taking it before the sleep makes the next iteration treat the whole
                # interval as elapsed and skip its sleep, roughly doubling the rate.
                time_diff = frame_interval - (time.time() - last_frame_time)
                if time_diff > 0:
                    await asyncio.sleep(time_diff)
                last_frame_time = time.time()

                processed_data = await asyncio.wait_for(processed_queue.get(), timeout=10)

                if not isinstance(processed_data, dict):
                    print(f"❌ 错误:processed_data 不是字典,而是 {type(processed_data)}")
                    continue

                frame = processed_data['frame']
                predict_state = processed_data['predict_state']
                if predict_state:
                    # Debug writer, created lazily once the frame size is known.
                    if video_writer is None and output_video_path:
                        frame_height, frame_width = frame.shape[:2]
                        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                        video_writer = cv2.VideoWriter(
                            output_video_path,
                            fourcc,
                            fps,
                            (frame_width, frame_height)
                        )
                        print(f"视频写入器已初始化,分辨率: {frame_width}x{frame_height}, FPS: {fps}")

                    osd_info = processed_data['osd_info']
                    detections = processed_data['detections']
                    detections_list = processed_data['detections_list']
                    model_para = processed_data['model_para']

                    class_names = model_para[0]["model_class_names"]
                    chinese_label = model_para[0]["model_chinese_labe"]
                    cls_map = model_para[0]["cls_map"]

                    img_height, img_width = frame.shape[:2]
                    results = []  # red-line polygon in pixel coordinates ({"u", "v"} dicts)
                    message_point = []  # boxes found inside the red line
                    if osd_info:
                        gimbal_yaw = osd_info.gimbal_yaw
                        gimbal_pitch = osd_info.gimbal_pitch
                        gimbal_roll = osd_info.gimbal_roll
                        height = osd_info.height
                        cam_longitude = osd_info.longitude
                        cam_latitude = osd_info.latitude
                        # list_points is 2-D but currently holds a single polygon;
                        # only the last reprojection is kept.
                        for points in list_points or []:
                            results = red_line_reproject(gimbal_yaw, gimbal_pitch, gimbal_roll, height,
                                                         cam_longitude, cam_latitude,
                                                         img_width, img_height, points)

                    frame_copy = frame.copy()
                    class_stats = defaultdict(int)  # detections per class id in this frame
                    # Display name -> count, pre-seeded so absent classes show 0.
                    reversed_dict = {value: 0 for value in cls_map.values()}
                    bg_color = (173, 216, 230)  # light blue panel background
                    text_color = (0, 255, 0)  # green

                    for det in detections:
                        x1, y1, x2, y2 = map(int, det.bbox)
                        cls_id = det.class_id
                        class_name = det.class_name
                        confidence = det.confidence
                        track_id = det.track_id
                        class_stats[cls_id] += 1

                        if invade_state:
                            # With intrusion detection on, only boxes inside the red line are shown.
                            point_x = (x1 + x2) / 2
                            point_y = (y1 + y2) / 2
                            if not is_point_in_polygon(point_x, point_y, results):
                                continue
                            en_name = chinese_label[class_names.index(class_name)]
                            message_point.append({
                                "confidence": confidence,
                                "cls_id": cls_id,
                                "type_name": en_name,
                                "box": [x1, y1, x2, y2]
                            })
                            label = f"{confidence:.2f}:{track_id}"
                        else:
                            cv2.rectangle(frame_copy, (x1, y1), (x2, y2), (0, 255, 255), 2)
                            label = f"{chinese_label.get(cls_id, class_name)}: {confidence:.2f}:{track_id}"

                        # Place the label above the box, or below it if it would leave the image.
                        text_height = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0][1]
                        text_y = y1 - 5
                        if text_y < 0:
                            text_y = y2 + text_height + 5
                        # put_chinese_text returns a new image and leaves its input untouched.
                        frame_copy = put_chinese_text(frame_copy, label, (x1, text_y - text_height))

                    if invade_state:
                        for point in message_point:
                            bx1, by1, bx2, by2 = point["box"]
                            cv2.rectangle(frame_copy, (bx1, by1), (bx2, by2), (0, 255, 255), 2)
                        # Red line.
                        if results:
                            point_list = [[point["u"], point["v"]] for point in results]
                            cv2.polylines(frame_copy, [np.array(point_list, dtype=np.int32)], isClosed=True,
                                          color=(0, 0, 255), thickness=2)

                    # Per-class count panel in the top-left corner.
                    for cls_id, count in class_stats.items():
                        fallback_name = (class_names[cls_id]
                                         if class_names and cls_id < len(class_names) else str(cls_id))
                        reversed_dict[chinese_label.get(cls_id, fallback_name)] = count
                    stats_text = [f"{key}: {value}" for key, value in reversed_dict.items()]

                    if stats_text:
                        line_height = cv2.getTextSize("Test", cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0][1]
                        total_height = len(stats_text) * (line_height + 5)  # 5 px line spacing
                        start_x = 50
                        start_y = 20
                        frame_copy = cv2.rectangle(
                            frame_copy,
                            (start_x - 10, start_y - 10),
                            (200, start_y + total_height + 10),
                            bg_color,
                            -1
                        )
                        for i, text in enumerate(stats_text):
                            y_pos = start_y + i * (line_height + 5)
                            frame_copy = put_chinese_text(frame_copy, text, (start_x, y_pos), color=text_color)

                    new_data = {
                        'frame_copy': frame_copy,
                        'frame': frame,
                        "osd_info": osd_info,
                        'detections': detections,
                        'timestamp': processed_data.get('timestamp'),
                        "detections_list": detections_list,
                        "model_para": model_para
                    }
                    # Intrusion mode feeds invade_queue; otherwise frames go to cv_frame_queue.
                    if invade_state:
                        if not invade_queue.full():
                            await invade_queue.put(new_data)
                    elif not cv_frame_queue.full():
                        print(f"cv_frame_queue.put(new_data)cv_frame_queue.put(new_data) ")
                        await cv_frame_queue.put(new_data)
                else:
                    frame_copy = frame

                rgb_frame = cv2.cvtColor(frame_copy, cv2.COLOR_BGR2RGB)

                # Open the push container on first use (or after it was dropped).
                # NOTE(review): the stream rate uses Config.TARGET_FPS while pacing uses fps,
                # which differ when input_fps is given. Confirm this is intended.
                if output_url and output_url not in stream_containers:
                    container = None
                    try:
                        container = av.open(output_url, mode='w', format='flv')
                        stream = container.add_stream(codec_name, rate=Config.TARGET_FPS)
                        stream.width = frame.shape[1]
                        stream.height = frame.shape[0]
                        stream.pix_fmt = 'yuv420p'
                        stream.options = options

                        stream_containers[output_url] = {
                            'container': container,
                            'stream': stream,
                            'last_frame_time': time.time(),
                            'frame_count': 0
                        }
                        print(f"✅ 推流初始化成功: {output_url}")
                    except Exception as e:
                        print(f"❌ 推流初始化失败: {e}")
                        if container is not None:
                            try:
                                container.close()
                            except Exception:
                                pass
                        await asyncio.sleep(1.0)
                        continue

                if output_url and output_url in stream_containers:
                    container_info = stream_containers[output_url]
                    try:
                        stream = container_info['stream']
                        container = container_info['container']

                        if rgb_frame.dtype == np.uint8:
                            av_frame = av.VideoFrame.from_ndarray(rgb_frame, format='rgb24')
                            # Synchronous call; it may return no packets while the encoder
                            # buffers its first frames.
                            packets = stream.encode(av_frame)
                            print(f"📦 encode 生成 {len(packets)} 个 packet")

                            if packets:
                                for packet in packets:
                                    container.mux(packet)
                                # Only an actual mux counts as progress for the stall check.
                                container_info['last_frame_time'] = time.time()
                                container_info['frame_count'] += 1
                                if container_info['frame_count'] % 100 == 0:
                                    print(f"ℹ️ 已推送 {container_info['frame_count']} 帧到 {output_url}")
                        else:
                            print(f"⚠️ 无效帧格式: {rgb_frame.dtype}")
                    except Exception as e:
                        print(f"❌ 推流错误: {e}")
                        try:
                            container_info['container'].close()
                        except Exception:
                            pass
                        stream_containers.pop(output_url, None)
                        await asyncio.sleep(1.0)

            except asyncio.TimeoutError:
                # No processed frame for 10 s: drop containers idle for more than 5 s
                # so they are reopened on the next frame.
                for url, info in list(stream_containers.items()):
                    if time.time() - info['last_frame_time'] > 5:
                        print(f"⚠️ 检测到推流停滞,重启 {url}")
                        try:
                            info['container'].close()
                        except Exception:
                            pass
                        stream_containers.pop(url)
                continue

            except Exception as e:
                print(f"❌ 推流处理异常: {e}")
                await asyncio.sleep(0.1)
    finally:
        # Release the resources this coroutine opened, including on cancellation.
        if video_writer is not None:
            video_writer.release()
        push_info = stream_containers.pop(output_url, None) if output_url else None
        if push_info is not None:
            try:
                # encode() with no frame flushes the packets the encoder still buffers.
                for packet in push_info['stream'].encode():
                    push_info['container'].mux(packet)
            except Exception as e:
                logger.warning("Flushing push stream %s failed: %s", output_url, e)
            try:
                push_info['container'].close()
            except Exception:
                pass
|
||
|
||
|
||
# 基于射线法,判断设备是否在红线内
|
||
|
||
def is_point_in_polygon(point_x, point_y, polygon):
    """Even-odd ray-casting test: does (point_x, point_y) lie inside *polygon*?

    *polygon* is a sequence of vertex mappings with "u" (x) and "v" (y) keys. The
    last vertex is implicitly joined to the first. A point exactly on a vertex or
    on an edge counts as inside. An empty polygon contains no point.
    """
    vertex_total = len(polygon)
    odd_crossings = False

    for edge_index in range(vertex_total):
        head = polygon[edge_index]
        tail = polygon[(edge_index + 1) % vertex_total]
        hx, hy = head["u"], head["v"]
        tx, ty = tail["u"], tail["v"]

        # Sitting on either endpoint of the edge counts as inside.
        on_vertex = (point_x == hx and point_y == hy) or (point_x == tx and point_y == ty)
        if on_vertex:
            return True

        # Collinear with the edge and within its bounding box means on the edge.
        collinear = (point_x - hx) * (ty - hy) == (point_y - hy) * (tx - hx)
        if collinear and min(hx, tx) <= point_x <= max(hx, tx) and min(hy, ty) <= point_y <= max(hy, ty):
            return True

        # Only edges straddling the horizontal line through the point can be crossed
        # by a ray cast towards +x. This also guarantees ty != hy below.
        straddles = (hy > point_y) != (ty > point_y)
        if not straddles:
            continue
        crossing_x = (tx - hx) * (point_y - hy) / (ty - hy) + hx
        if point_x <= crossing_x:
            odd_crossings = not odd_crossings

    return odd_crossings
|
||
|
||
#
|
||
## 判断目标是否侵限
|
||
#async def cal_des_invade(task_id: str, mqtt, mqtt_publish_topic, list_points: list[list[any]]):
|
||
# loop = asyncio.get_running_loop()
|
||
# print(6)
|
||
# pic_count = 0
|
||
# pic_count_hongxian = 0
|
||
# while not stop_event.is_set():
|
||
# # while True:
|
||
# print(777777777)
|
||
# cv_frame = None
|
||
# # 检查队列长度,避免堆积
|
||
# if invade_queue.qsize() > Config.PROCESSED_QUEUE_SIZE // 2:
|
||
# print(f"警告:invade_queue 积压(当前长度={invade_queue.qsize()}),清空队列")
|
||
# while not invade_queue.empty():
|
||
# await invade_queue.get()
|
||
# print("队列已清空,等待新数据...")
|
||
# await asyncio.sleep(0.1)
|
||
# continue
|
||
# # 获取队列数据
|
||
# try:
|
||
# cv_frame = await asyncio.wait_for(invade_queue.get(), timeout=0.8)
|
||
# # 检查数据类型
|
||
# if not isinstance(cv_frame, dict):
|
||
# print(f"⚠️ 警告:cv_frame 不是字典,而是 {type(cv_frame)}")
|
||
# continue
|
||
# except asyncio.TimeoutError:
|
||
# print(f"cal_des_invade TimeoutError")
|
||
# continue
|
||
#
|
||
# if cv_frame is None:
|
||
# continue
|
||
# print("cal_des_invade inside")
|
||
# frame_copy = cv_frame['frame_copy']
|
||
# frame = cv_frame['frame']
|
||
# air_alti = cv_frame['osd_info']
|
||
# detections = cv_frame['detections']
|
||
# detections_list = cv_frame['detections_list']
|
||
# model_para_list = cv_frame.get('model_para', {}) # 默认空字典
|
||
# if not isinstance(frame, np.ndarray) or frame.size == 0:
|
||
# print("⚠️ 警告:frame不是有效的numpy数组")
|
||
# continue
|
||
#
|
||
# # 获取宽高
|
||
# img_height, img_width = frame.shape[:2]
|
||
# if not air_alti:
|
||
# continue
|
||
#
|
||
# # if latest_osd_info:
|
||
# # # message_json = json.loads(latest_osd_info)
|
||
# # # osd_info = parse_osd_message(json.dumps(message_json))
|
||
# # print(f"解析成功: {latest_osd_info.data.gimbal_yaw}, {latest_osd_info.data.longitude}")
|
||
# # osd_info = parse_osd_message(latest_osd_info)
|
||
# gimbal_yaw = air_alti.gimbal_yaw
|
||
# gimbal_pitch = air_alti.gimbal_pitch
|
||
# gimbal_roll = air_alti.gimbal_roll
|
||
# height = air_alti.height
|
||
# cam_longitude = air_alti.longitude
|
||
# cam_latitude = air_alti.latitude
|
||
#
|
||
# # img_width = 1440
|
||
# # img_height = 1080
|
||
# # 二维数组,特殊要求
|
||
# for points in list_points:
|
||
# # 批量返回图像的像素坐标
|
||
# results = red_line_reproject(gimbal_yaw, gimbal_pitch, gimbal_roll, height, cam_longitude, cam_latitude,
|
||
# img_width,
|
||
# img_height, points)
|
||
#
|
||
# try:
|
||
# invade_point = []
|
||
# message_point = []
|
||
# cls_count = 0
|
||
# model_id = 101101
|
||
# for idx, model_para in enumerate(model_para_list):
|
||
# if not isinstance(detections_list, list) or idx >= len(detections_list):
|
||
# continue
|
||
# detections = detections_list[idx] # 正确获取当前模型的检测结果
|
||
# if detections is None or len(detections.boxes) < 1:
|
||
# continue
|
||
# # for det in detections:
|
||
# for bbox, class_id, class_name, confidence, track_id in detections:
|
||
# x1, y1, x2, y2 = map(int, bbox) # 确保坐标是整数
|
||
# cls_id = class_id # 假设Detection对象有class_id属性
|
||
# class_name = class_name
|
||
# confidence = confidence
|
||
# point_x = (x1 + x2) / 2
|
||
# point_y = (y1 + y2) / 2
|
||
#
|
||
# is_invade = is_point_in_polygon(point_x, point_y, results)
|
||
# if is_invade:
|
||
# cls_count += 1
|
||
# invade_point.append({
|
||
# "u": point_x,
|
||
# "v": point_y,
|
||
# "class_name": class_name
|
||
# })
|
||
# model_list_func_id = model_para["model_list_func_id"]
|
||
# model_func_id = model_para["func_id"]
|
||
# en_name = model_para["model_chinese_labe"][
|
||
# model_para["model_class_names"].index(class_name)]
|
||
# message_point.append({
|
||
# "confidence": confidence,
|
||
# "cls_id": cls_id,
|
||
# "type_name": en_name,
|
||
# "box": [x1, y1, x2, y2]
|
||
# })
|
||
#
|
||
# label = f"{confidence:.2f}:{track_id}"
|
||
#
|
||
# # 计算文本位置
|
||
# text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0]
|
||
# text_width, text_height = text_size[0], text_size[1]
|
||
# text_x = x1
|
||
# text_y = y1 - 5
|
||
#
|
||
# # 如果文本超出图像顶部,则放在框内部下方
|
||
# if text_y < 0:
|
||
# text_y = y2 + text_height + 5
|
||
#
|
||
# temp_img = frame_copy.copy()
|
||
# frame_copy = put_chinese_text(
|
||
# temp_img,
|
||
# label,
|
||
# # "", # 注释掉汉字
|
||
# (text_x, text_y - text_height),
|
||
# )
|
||
#
|
||
# point_list = [] # 整理红线集合
|
||
# # point_des_list=[]
|
||
#
|
||
# # 测试红线
|
||
# for point in results:
|
||
# point_list.append([point["u"], point["v"]])
|
||
# for point in message_point:
|
||
# cv2.rectangle(frame_copy, (point["box"][0], point["box"][1]), (point["box"][2], point["box"][3]),
|
||
# (0, 255, 255), 2)
|
||
# cv2.polylines(frame_copy, [np.array(point_list, dtype=np.int32)], isClosed=True, color=(0, 0, 255),
|
||
# thickness=2)
|
||
# pic_count_hongxian = pic_count_hongxian + 1
|
||
## cv2.imwrite(f"save_pic\invade-hongxian\hongxianongly-{pic_count_hongxian}.jpg", frame_copy)
|
||
#
|
||
# if len(invade_point) > 0:
|
||
# for point in results:
|
||
# point_list.append([point["u"], point["v"]])
|
||
#
|
||
# for point in message_point:
|
||
# cv2.rectangle(frame, (point["box"][0], point["box"][1]), (point["box"][2], point["box"][3]),
|
||
# (0, 255, 255), 2)
|
||
#
|
||
# # 绘制红线
|
||
# cv2.polylines(frame, [np.array(point_list, dtype=np.int32)], isClosed=True, color=(0, 0, 255),
|
||
# thickness=2)
|
||
#
|
||
# print("hongxianhongxianhongxianhongxian")
|
||
# pic_count = pic_count + 1
|
||
## cv2.imwrite(f"save_pic\invade\hongxian-{pic_count}.jpg", frame)
|
||
#
|
||
# # await loop.run_in_executor(None, _show_frame, frame)
|
||
# # cv2.imshow("hongxian ", frame)
|
||
#
|
||
# drawn_frame = frame.copy() # 关键修复:深拷贝绘制后的帧
|
||
#
|
||
# # 图像编码
|
||
# def encode_frame():
|
||
# success, buffer = cv2.imencode(".jpg", drawn_frame)
|
||
# return buffer.tobytes() if success else None
|
||
#
|
||
# buffer_bytes = await loop.run_in_executor(upload_executor, encode_frame)
|
||
# if not buffer_bytes:
|
||
# continue
|
||
#
|
||
# # 并行处理上传和MQTT发布
|
||
# async def upload_and_publish():
|
||
# # 上传到MinIO
|
||
# def upload_minio():
|
||
# return upload_frame_buff_from_buffer(buffer_bytes, None)
|
||
#
|
||
# minio_path, file_type = await loop.run_in_executor(
|
||
# upload_executor, upload_minio
|
||
# )
|
||
#
|
||
# # 构造消息
|
||
# message = {
|
||
# "task_id": task_id,
|
||
# "minio": {"minio_path": minio_path, "file_type": file_type},
|
||
# # "box_detail": invade_point
|
||
# "box_detail": [{
|
||
# "model_id": model_list_func_id,
|
||
# "cls_count": cls_count,
|
||
# "box_count": [message_point]
|
||
# }]
|
||
# }
|
||
# print(f"hongxianhongxianhongxianhongxian上传 {message}")
|
||
# message_json = json.dumps(message, ensure_ascii=False)
|
||
# await mqtt.publish(mqtt_publish_topic, message_json)
|
||
#
|
||
# # 使用共享变量时加锁,进而进行跳帧,不然上报太频繁
|
||
# async with invade_cache_lock:
|
||
# # 读取或修改共享变量
|
||
# send_count = shared_local_cache["invade_send_count"]
|
||
# send_count = send_count + 1
|
||
# shared_local_cache["invade_send_count"] = send_count
|
||
# if send_count > 1:
|
||
# # 创建独立任务执行上传和发布
|
||
# shared_local_cache["invade_send_count"] = 0
|
||
# asyncio.create_task(upload_and_publish())
|
||
#
|
||
# # status_to_publish = None
|
||
# # async with invade_cache_lock:
|
||
# # video_process_status = shared_local_cache["video_process_status"]
|
||
# # if video_process_status > 0:
|
||
# # shared_local_cache["video_process_status"] = 0
|
||
# # status_to_publish = video_process_status
|
||
# #
|
||
# # if status_to_publish is not None:
|
||
# # await publist_status(status_to_publish)
|
||
# except Exception as e:
|
||
# print(f"cal_des_invade 错误: {e}")
|
||
# await asyncio.sleep(0.1)
|
||
|
||
|
||
|
||
async def cal_des_invade(task_id: str, mqtt, mqtt_publish_topic,
                         list_points: list[list[any]], model_count: int):
    """Consume ``invade_queue`` and publish red-line intrusion events over MQTT.

    For each queued frame:
    1. The red-line polygon is reprojected into pixel space from the frame's OSD attitude.
    2. Each tracked detection whose box centre lies inside the polygon is collected.
    3. If at least one is found, the boxes and polygon are drawn on the frame (in place).
    4. The frame is JPEG-encoded, uploaded to MinIO and announced on
       ``mqtt_publish_topic``. Uploads are throttled to every second qualifying frame
       through ``shared_local_cache["invade_send_count"]``.

    A detection is considered only when both of these hold:
    * its track_id was not reported within the last ``report_interval`` seconds;
    * the per-model ``TrackIDEventFilter`` approves it (every second sighting).

    Negative track_ids are always considered; MultiYOLODetector does not track and
    uses -1.

    :param task_id: copied into the published message.
    :param mqtt: client whose ``publish(topic, payload)`` coroutine sends the message.
    :param mqtt_publish_topic: topic for intrusion messages.
    :param list_points: red-line polygons. Only the last reprojection is used;
        ``None`` or empty means nothing can intrude.
    :param model_count: number of per-model track filters to pre-create.
    """
    import traceback

    loop = asyncio.get_running_loop()
    track_id_filters = [TrackIDEventFilter(max_inactive_time=5.0) for _ in range(model_count)]
    reported_track_ids: Dict[Any, float] = {}  # track_id -> time of its last report
    report_interval = 2.0  # seconds during which a reported track_id is suppressed
    # asyncio keeps only weak references to tasks; hold them until they finish.
    background_tasks = set()

    async def upload_and_publish(jpeg_bytes, box_points, invade_count, list_func_id):
        """Upload one JPEG to MinIO and publish the matching MQTT message.

        Per-frame data is passed in explicitly, so a task scheduled for one frame
        cannot pick up variables that later loop iterations rebind.
        """
        try:
            minio_path, file_type = await loop.run_in_executor(
                upload_executor, upload_frame_buff_from_buffer, jpeg_bytes, None
            )
            message = {
                "task_id": task_id,
                "minio": {"minio_path": minio_path, "file_type": file_type},
                "box_detail": [{
                    "model_id": list_func_id,
                    "cls_count": invade_count,
                    "box_count": [box_points]
                }]
            }
            print(f"hongxianhongxianhongxianhongxian上传 {message}")
            await mqtt.publish(mqtt_publish_topic, json.dumps(message, ensure_ascii=False))
        except Exception as e:
            print(f"cal_des_invade 错误: {e}")
            traceback.print_exc()

    while not stop_event.is_set():
        # Drop a backlog instead of working through stale frames.
        if invade_queue.qsize() > Config.PROCESSED_QUEUE_SIZE // 2:
            print(f"警告:invade_queue 积压(当前长度={invade_queue.qsize()}),清空队列")
            while not invade_queue.empty():
                await invade_queue.get()
            print("队列已清空,等待新数据...")
            await asyncio.sleep(0.1)
            continue

        try:
            cv_frame = await asyncio.wait_for(invade_queue.get(), timeout=0.8)
        except asyncio.TimeoutError:
            print(f"cal_des_invade TimeoutError")
            continue
        if not isinstance(cv_frame, dict):
            print(f"⚠️ 警告:cv_frame 不是字典,而是 {type(cv_frame)}")
            continue

        print("cal_des_invade inside")
        frame = cv_frame['frame']
        air_alti = cv_frame['osd_info']
        detections_list = cv_frame['detections_list']
        model_para_list = cv_frame.get('model_para', {})
        if not isinstance(frame, np.ndarray) or frame.size == 0:
            print("⚠️ 警告:frame不是有效的numpy数组")
            continue
        if not air_alti:
            continue

        try:
            img_height, img_width = frame.shape[:2]
            # Reproject the red line; with several polygons only the last one is used.
            results = []
            for points in list_points or []:
                results = red_line_reproject(air_alti.gimbal_yaw, air_alti.gimbal_pitch, air_alti.gimbal_roll,
                                             air_alti.height, air_alti.longitude, air_alti.latitude,
                                             img_width, img_height, points)

            message_point = []
            cls_count = 0
            model_list_func_id = None
            current_time = time.time()

            # Timestamps older than report_interval can no longer suppress anything.
            stale_ids = [tid for tid, ts in reported_track_ids.items() if current_time - ts >= report_interval]
            for tid in stale_ids:
                del reported_track_ids[tid]

            for idx, model_para in enumerate(model_para_list):
                if not isinstance(detections_list, list) or idx >= len(detections_list):
                    continue
                detections = detections_list[idx]  # detections of this model
                if detections is None or len(detections.boxes) < 1:
                    continue

                while idx >= len(track_id_filters):
                    track_id_filters.append(TrackIDEventFilter(max_inactive_time=5.0))
                track_filter = track_id_filters[idx]

                for bbox, class_id, class_name, confidence, track_id in detections:
                    if track_id is None:
                        continue

                    recently_reported = (track_id in reported_track_ids
                                         and current_time - reported_track_ids[track_id] < report_interval)
                    if recently_reported:
                        print("基于track_id,触发去重事件")
                    # The filter is consulted (and advanced) only when not suppressed,
                    # and its verdict decides whether the detection is considered.
                    should_report = not recently_reported and track_filter.should_report(track_id)
                    if track_id < 0:
                        # Untracked detector (MultiYOLODetector) reports track_id -1: always consider.
                        should_report = True
                    if not should_report:
                        continue

                    x1, y1, x2, y2 = map(int, bbox)
                    point_x = (x1 + x2) / 2
                    point_y = (y1 + y2) / 2
                    if not is_point_in_polygon(point_x, point_y, results):
                        continue

                    cls_count += 1
                    model_list_func_id = model_para["model_list_func_id"]
                    en_name = model_para["model_chinese_labe"][
                        model_para["model_class_names"].index(class_name)]
                    message_point.append({
                        "confidence": confidence,
                        "cls_id": class_id,
                        "type_name": en_name,
                        "box": [x1, y1, x2, y2],
                        "track_id": track_id
                    })
                    reported_track_ids[track_id] = current_time

            if not message_point:
                continue

            # Draw the intruding boxes and the red line on the frame (in place).
            for point in message_point:
                bx1, by1, bx2, by2 = point["box"]
                cv2.rectangle(frame, (bx1, by1), (bx2, by2), (0, 255, 255), 2)
            if results:
                polygon = np.array([[p["u"], p["v"]] for p in results], dtype=np.int32)
                cv2.polylines(frame, [polygon], isClosed=True, color=(0, 0, 255), thickness=2)
            print("hongxianhongxianhongxianhongxian")

            drawn_frame = frame.copy()  # snapshot of the annotated frame for encoding

            def encode_frame():
                success, buffer = cv2.imencode(".jpg", drawn_frame)
                return buffer.tobytes() if success else None

            buffer_bytes = await loop.run_in_executor(upload_executor, encode_frame)
            if not buffer_bytes:
                continue

            # Throttle uploads: publish on every second qualifying frame.
            async with invade_cache_lock:
                send_count = shared_local_cache["invade_send_count"] + 1
                shared_local_cache["invade_send_count"] = send_count
                if send_count > 1:
                    shared_local_cache["invade_send_count"] = 0
                    task = asyncio.create_task(
                        upload_and_publish(buffer_bytes, message_point, cls_count, model_list_func_id)
                    )
                    background_tasks.add(task)
                    task.add_done_callback(background_tasks.discard)

        except Exception as e:
            print(f"cal_des_invade 错误: {e}")
            traceback.print_exc()
            await asyncio.sleep(0.1)
|
||
|
||
|
||
|
||
|
||
# Counters and state shared between coroutines. Updates are made under
# cache_lock ("send_count") and invade_cache_lock ("invade_send_count").
shared_local_cache = dict(
    send_count=0,  # qualifying frames since the last upload in send_frame_to_s3_mq
    invade_send_count=0,  # qualifying frames since the last upload in cal_des_invade
    video_process_status=0,  # recorded-video analysis state: 0 not started, 1 started, 2 finished
)
|
||
from asyncio import Lock
|
||
|
||
# asyncio locks serialising read-modify-write of shared_local_cache counters:
# cache_lock guards "send_count", invade_cache_lock guards "invade_send_count".
# NOTE(review): on Python < 3.10 an asyncio.Lock binds to the event loop current at
# creation time; confirm the runtime version if these are used under asyncio.run().
cache_lock, invade_cache_lock = Lock(), Lock()
|
||
|
||
|
||
async def send_frame_to_s3_mq(task_id, mqtt, mqtt_topic):
|
||
global stats
|
||
count_pic = 0
|
||
loop = asyncio.get_running_loop()
|
||
local_func_cache = {
|
||
"func_100000": None,
|
||
"func_100004": None, # 存储缓存,缓存人员track_id
|
||
"func_100006": None # 存储缓存,缓存车辆track_id
|
||
}
|
||
|
||
para = {
|
||
"category": 3
|
||
}
|
||
|
||
while not stop_event.is_set():
|
||
try:
|
||
# 检查队列长度,避免堆积
|
||
if cv_frame_queue.qsize() > Config.PROCESSED_QUEUE_SIZE // 2:
|
||
print(f"警告:cv_frame_queue 积压(当前长度={cv_frame_queue.qsize()}),清空队列")
|
||
while not cv_frame_queue.empty():
|
||
await cv_frame_queue.get()
|
||
print("队列已清空,等待新数据...")
|
||
await asyncio.sleep(0.1)
|
||
continue
|
||
|
||
# 获取队列数据
|
||
try:
|
||
print(f"send_frame_to_s3_mq cv_frame_queue.get()")
|
||
cv_frame = await asyncio.wait_for(cv_frame_queue.get(), timeout=0.05)
|
||
# 检查数据类型
|
||
if not isinstance(cv_frame, dict):
|
||
print(f"⚠️ 警告:cv_frame 不是字典,而是 {type(cv_frame)}")
|
||
continue
|
||
except asyncio.TimeoutError:
|
||
continue
|
||
|
||
# 准备数据
|
||
frame_copy = cv_frame['frame_copy']
|
||
frame = cv_frame['frame']
|
||
detections = cv_frame['detections']
|
||
detections_list = cv_frame['detections_list']
|
||
model_para_list = cv_frame.get('model_para', {}) # 默认空字典
|
||
|
||
# 初始化默认值
|
||
frame11 = frame_copy # 默认使用原始帧
|
||
box_detail = [] # 默认空列表
|
||
|
||
for idx, model_para in enumerate(model_para_list):
|
||
if not isinstance(detections_list, list) or idx >= len(detections_list):
|
||
continue
|
||
detections = detections_list[idx] # 正确获取当前模型的检测结果
|
||
chinese_label = model_para.get("model_chinese_labe", {})
|
||
model_cls = model_para.get("model_cls_index", {})
|
||
list_func_id = model_para.get("model_list_func_id", -11)
|
||
func_id = model_para.get("func_id", [])
|
||
|
||
# 获取DRC消息(同步操作,放到线程池)
|
||
local_drc_message = await loop.run_in_executor(executor, get_local_drc_message)
|
||
|
||
if detections is None or len(detections.boxes) < 1:
|
||
continue
|
||
|
||
try:
|
||
# 图像处理和结果计算
|
||
result = await loop.run_in_executor(
|
||
executor,
|
||
cal_tricker_results,
|
||
frame_copy, detections, None,
|
||
func_id, local_func_cache, para, model_cls, chinese_label, list_func_id
|
||
)
|
||
|
||
# 检查返回结果是否是元组
|
||
if isinstance(result, tuple) and len(result) == 2:
|
||
frame11, box_detail = result
|
||
else:
|
||
print(f"⚠️ 警告:cal_tricker_results 返回了意外格式的结果: {type(result)}")
|
||
continue
|
||
|
||
except Exception as e:
|
||
print(f"处理帧时出错: {str(e)}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
continue
|
||
|
||
count_pic = count_pic + 1
|
||
|
||
# 图像编码
|
||
def encode_frame():
|
||
success, buffer = cv2.imencode(".jpg", frame11)
|
||
return buffer.tobytes() if success else None
|
||
|
||
def encode_origin_frame():
|
||
success, buffer = cv2.imencode(".jpg", frame)
|
||
return buffer.tobytes() if success else None
|
||
|
||
buffer_bytes = await loop.run_in_executor(upload_executor, encode_frame)
|
||
buffer_origin_bytes = await loop.run_in_executor(upload_executor, encode_origin_frame)
|
||
if not buffer_bytes:
|
||
continue
|
||
|
||
# 并行处理上传和MQTT发布
|
||
async def upload_and_publish():
|
||
# 上传到MinIO
|
||
def upload_minio():
|
||
minio_path, file_type = upload_frame_buff_from_buffer(buffer_bytes, None)
|
||
minio_origin_path, file_type = upload_frame_buff_from_buffer(buffer_origin_bytes, None)
|
||
return minio_path, minio_origin_path, file_type
|
||
|
||
minio_path, minio_origin_path, file_type = await loop.run_in_executor(
|
||
upload_executor, upload_minio
|
||
)
|
||
|
||
# 构造消息
|
||
message = {
|
||
"task_id": task_id,
|
||
"minio": {"minio_path": minio_path, "minio_origin_path": minio_origin_path,
|
||
"file_type": file_type},
|
||
"box_detail": box_detail,
|
||
"uav_location": local_drc_message
|
||
}
|
||
message_json = json.dumps(message, ensure_ascii=False)
|
||
await mqtt.publish(mqtt_topic, message_json)
|
||
|
||
# 使用共享变量时加锁,进而进行跳帧,不然上报太频繁
|
||
async with cache_lock:
|
||
# 读取或修改共享变量
|
||
send_count = shared_local_cache["send_count"]
|
||
send_count = send_count + 1
|
||
shared_local_cache["send_count"] = send_count
|
||
|
||
# 伪代码,目前安全帽比较难识别到
|
||
if 100014 == model_para_list[0]["model_list_func_id"]:
|
||
if send_count > 2:
|
||
# 创建独立任务执行上传和发布
|
||
shared_local_cache["send_count"] = 0
|
||
asyncio.create_task(upload_and_publish())
|
||
else:
|
||
if send_count > 30:
|
||
# 创建独立任务执行上传和发布
|
||
shared_local_cache["send_count"] = 0
|
||
asyncio.create_task(upload_and_publish())
|
||
|
||
except Exception as e:
|
||
print(f"send_frame_to_s3_mq 错误: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
await asyncio.sleep(0.1)
|
||
|
||
# 更新性能统计
|
||
stats['processed'] += 1
|
||
if time.time() - stats['last_time'] >= 1.0:
|
||
stats['avg_fps'] = stats['processed'] / (time.time() - stats['last_time'])
|
||
print(f"处理速度: {stats['avg_fps']:.2f} FPS")
|
||
stats['processed'] = 0
|
||
stats['last_time'] = time.time()
|
||
|
||
|
||
async def start_rtmp_processing(video_url: str, task_id: str, model_configs: List[Dict],
                                mqtt_pub_ip: str, mqtt_pub_port: int, mqtt_pub_topic: str,
                                mqtt_sub_ip: str, mqtt_sub_port: int, mqtt_sub_topic: str,
                                output_rtmp_url: str,
                                invade_enable: bool, invade_file: str,
                                *, device_warmup_seconds: float = 10.0):
    """Run the live-stream (RTMP) recognition pipeline for one task.

    Optionally loads intrusion polygons from ``invade_file``. It then connects
    the result publisher (``MQTTService``) and the telemetry subscriber
    (``MQTTDevice``), and runs the frame reader, detector and result writer
    concurrently. It also runs either the intrusion checker (when polygons were
    loaded and ``invade_enable`` is set) or two S3/MQTT upload consumers (when
    no ``output_rtmp_url`` is given). Shared resources are initialised on entry
    and cleaned up on exit, including early exits.

    Args:
        video_url: Input stream read by ``read_rtmp_frames``.
        task_id: Identifier forwarded to the pipeline stages.
        model_configs: Model configuration for ``MultiYOLODetector_TrackId``.
            Its length is passed to ``cal_des_invade`` as the model count.
        mqtt_pub_ip: Broker address used for publishing results.
        mqtt_pub_port: Broker port used for publishing results.
        mqtt_pub_topic: Topic results are published to.
        mqtt_sub_ip: Broker address of the telemetry subscription.
        mqtt_sub_port: Broker port of the telemetry subscription.
        mqtt_sub_topic: Telemetry topic. Both the ``drc_camera_osd_info_push``
            and the ``osd_info_push`` methods are read from it.
        output_rtmp_url: Output stream handed to ``write_results_to_rtmp``.
            When empty/None (and intrusion detection is off), upload consumers
            are started instead.
        invade_enable: Whether intrusion detection should run. It also
            requires at least one polygon in ``invade_file``.
        invade_file: Optional path resolved with ``downFile``. The file is
            GeoJSON-style: each entry of ``features`` contributes the first
            ring of ``geometry.coordinates`` as one polygon.
        device_warmup_seconds: Seconds to wait after starting the telemetry
            subscriber before frames are read (previously hard-coded to 10).
    """
    await initialize_resources()  # set up shared pipeline resources

    try:
        list_points = []  # one list of Point per polygon

        if invade_file:
            # downFile is synchronous I/O; keep it off the event loop.
            invade_file_path = await asyncio.to_thread(downFile, invade_file)
            if invade_file_path is None:
                logger.error("invade_file is None task_id:%s", task_id)
                return

            with open(invade_file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            features = data.get('features', [])
            if not features:
                logger.error("没有找到有效的要素数据")
                return

            # Every feature is one polygon; only its outer ring (coordinates[0]) is used.
            for feature_idx, feature in enumerate(features):
                geometry = feature.get('geometry') or {}  # GeoJSON allows "geometry": null
                coordinates = geometry.get('coordinates') or []
                if not coordinates or not coordinates[0]:
                    logger.warning("Skipping feature %d without coordinates (task_id=%s)",
                                   feature_idx, task_id)
                    continue

                points = []
                for key, coord in enumerate(coordinates[0]):
                    # Assuming standard GeoJSON [lon, lat(, alt)] positions, Point receives
                    # the two horizontal values swapped on purpose; convert_points_to_utm
                    # fails otherwise.
                    # NOTE(review): 2-D positions used to raise IndexError; altitude now
                    # defaults to 0.0. Confirm this is acceptable for the reprojection.
                    altitude = coord[2] if len(coord) > 2 else 0.0
                    points.append(Point(coord[1], coord[0], altitude, key))
                list_points.append(points)

        mqtt = MQTTService(mqtt_pub_ip, port=mqtt_pub_port)
        detector = MultiYOLODetector_TrackId(model_configs)

        # Telemetry subscriber; both OSD methods arrive on the same topic.
        device = MQTTDevice(
            ip=mqtt_sub_ip,
            port=mqtt_sub_port,
            topics=[mqtt_sub_topic],
            queue_size=50,  # max queued messages per method
        )
        topic_camera_osd = mqtt_sub_topic
        method_camera_osd = "drc_camera_osd_info_push"  # value of the message's "method" field
        topic_osd_info = mqtt_sub_topic
        method_osd_info = "osd_info_push"  # value of the message's "method" field

        await device.start()
        logger.info("MQTT device started, waiting %.1fs for telemetry (task_id=%s)",
                    device_warmup_seconds, task_id)
        await asyncio.sleep(device_warmup_seconds)
        # NOTE(review): neither `device` nor `mqtt` is stopped/closed in this function;
        # confirm cleanup_resources() (or the objects themselves) release the connections.

        invade_state = bool(list_points) and bool(invade_enable)

        tasks = [
            asyncio.create_task(read_rtmp_frames(video_url, device, topic_camera_osd, method_camera_osd,
                                                 topic_osd_info, method_osd_info)),
            asyncio.create_task(process_frames(detector)),
            asyncio.create_task(write_results_to_rtmp(task_id, output_rtmp_url, None, list_points, invade_state)),
        ]

        # Events are reported either by the intrusion checker, or (when there is no
        # output stream) by the S3/MQTT upload consumers.
        model_count = len(model_configs)
        if invade_state:
            logger.info("Intrusion detection enabled for task_id=%s", task_id)
            tasks.append(asyncio.create_task(
                cal_des_invade(task_id, mqtt, mqtt_pub_topic, list_points, model_count)))
        elif not output_rtmp_url:
            for _ in range(2):  # two upload consumers
                tasks.append(asyncio.create_task(send_frame_to_s3_mq(task_id, mqtt, mqtt_pub_topic)))

        # return_exceptions=True keeps one failing stage from cancelling the others,
        # but the failures must still be surfaced instead of silently dropped.
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for idx, result in enumerate(results):
            if isinstance(result, Exception):
                logger.error("Pipeline sub-task %d failed (task_id=%s)", idx, task_id,
                             exc_info=result)

    finally:
        await cleanup_resources()  # release shared pipeline resources
|
||
|
||
|
||
async def start_video_processing(minio_path: str, task_id: str, model_configs: List[Dict],
                                 mqtt_ip: str, mqtt_port: int, mqtt_topic: str, output_rtmp_url: str,
                                 invade_enable: bool, invade_file: str):
    """Run the recognition pipeline over a recorded video stored in MinIO.

    Steps:
      1. Download the video with ``downBigFile``.
      2. Extract its first subtitle stream with FFmpeg into ``<video>.srt``
         next to the video.
      3. Optionally load intrusion polygons from ``invade_file``.
      4. Read the video's frame rate.
      5. Run the reader, detector, writer and two upload consumers (plus the
         intrusion checker when enabled) until they all finish.

    Shared resources are initialised on entry and cleaned up on every exit
    path, including the early returns on download/FFmpeg/video-open failure.

    Args:
        minio_path: Object path of the recording, passed to ``downBigFile``.
        task_id: Identifier forwarded to the pipeline stages.
        model_configs: Model configuration for ``MultiYOLODetector_TrackId``.
            Its length is passed to ``cal_des_invade`` as the model count.
        mqtt_ip: Broker address used for publishing results.
        mqtt_port: Broker port used for publishing results.
        mqtt_topic: Topic results are published to.
        output_rtmp_url: Output stream handed to ``write_results_to_rtmp``.
        invade_enable: Whether intrusion detection should run. It also
            requires at least one polygon in ``invade_file``.
        invade_file: Optional path resolved with ``downFile``. The file is
            GeoJSON-style: each entry of ``features`` contributes the first
            ring of ``geometry.coordinates`` as one polygon.

    Returns:
        None. Failures are logged and end the function early.
    """
    await initialize_resources()  # set up shared pipeline resources
    logger.info("开起录像识别")

    # The whole body is guarded so cleanup_resources() also runs on the early
    # returns below; previously the download/FFmpeg failure paths skipped it.
    try:
        # --- 1. Download the recording (blocking I/O, run off the event loop).
        try:
            download_path = await asyncio.to_thread(downBigFile, minio_path)
        except Exception as e:
            logger.error("Unexpected error downloading big file: %s", e, exc_info=True)
            return None

        if not download_path:
            logger.error("big file is none")  # no active exception, so no exc_info
            return

        logger.info("download_path路径 %s", os.path.abspath(download_path))

        # --- 2. Extract the first subtitle stream to "<video>.srt" next to the video.
        srt_path = os.path.join(os.path.dirname(download_path),
                                os.path.basename(download_path) + ".srt")
        if os.path.exists(srt_path):
            # Remove a stale output; FFmpeg prompts before overwriting unless -y is given.
            os.remove(srt_path)

        command = [
            "ffmpeg",
            "-i", download_path,
            "-map", "0:s:0",  # first subtitle stream
            "-c:s", "srt",  # force SRT output
            srt_path,
        ]
        try:
            # FFmpeg may run for a while on large files; don't block the event loop.
            await asyncio.to_thread(subprocess.run, command, check=True)
            logger.info(f"start_video_processing方法 字幕提取成功,保存至: {srt_path}")
        except subprocess.CalledProcessError as e:
            logger.error(f"start_video_processing方法错误: FFmpeg 执行失败 - {e}")
            return
        except FileNotFoundError:
            logger.error("start_video_processing方法错误: 未安装 FFmpeg,请先安装并添加到系统路径。")
            return

        logger.info("srt_path路径 %s", os.path.abspath(srt_path))

        # --- 3. Optional intrusion polygons (one list of Point per polygon).
        list_points = []
        if invade_file:
            invade_file_path = await asyncio.to_thread(downFile, invade_file)
            if invade_file_path is None:
                logger.error("invade_file is None task_id:%s", task_id)
                return

            with open(invade_file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            features = data.get('features', [])
            if not features:
                logger.error("没有找到有效的要素数据")
                return

            # Every feature is one polygon; only its outer ring (coordinates[0]) is used.
            for feature_idx, feature in enumerate(features):
                geometry = feature.get('geometry') or {}  # GeoJSON allows "geometry": null
                coordinates = geometry.get('coordinates') or []
                if not coordinates or not coordinates[0]:
                    logger.warning("Skipping feature %d without coordinates (task_id=%s)",
                                   feature_idx, task_id)
                    continue

                points = []
                for key, coord in enumerate(coordinates[0]):
                    # Assuming standard GeoJSON [lon, lat(, alt)] positions, Point receives
                    # the two horizontal values swapped on purpose; convert_points_to_utm
                    # fails otherwise.
                    # NOTE(review): 2-D positions used to raise IndexError; altitude now
                    # defaults to 0.0. Confirm this is acceptable for the reprojection.
                    altitude = coord[2] if len(coord) > 2 else 0.0
                    points.append(Point(coord[1], coord[0], altitude, key))
                list_points.append(points)

        mqtt = MQTTService(mqtt_ip, port=mqtt_port)
        detector = MultiYOLODetector_TrackId(model_configs)
        mqtt_publish_topic = mqtt_topic

        # --- 4. Read the frame rate; always release the capture handle.
        cap = cv2.VideoCapture(download_path)
        try:
            if not cap.isOpened():
                logger.error("Error: Could not open video.")
                return
            fps = cap.get(cv2.CAP_PROP_FPS)
        finally:
            cap.release()

        invade_state = bool(list_points) and bool(invade_enable)
        model_count = len(model_configs)

        # --- 5. Run the pipeline stages concurrently.
        tasks = [
            asyncio.create_task(read_video_frames(task_id, mqtt, mqtt_publish_topic, download_path, srt_path)),
            asyncio.create_task(process_frames(detector)),
            asyncio.create_task(write_results_to_rtmp(task_id, output_rtmp_url, fps, list_points, invade_state)),
        ]

        if invade_state:
            logger.info("基于录像开起侵限识别")
            tasks.append(asyncio.create_task(
                cal_des_invade(task_id, mqtt, mqtt_publish_topic, list_points, model_count)))

        for _ in range(2):  # two upload consumers run in parallel
            tasks.append(asyncio.create_task(send_frame_to_s3_mq(task_id, mqtt, mqtt_topic)))

        # return_exceptions=True keeps one failing stage from cancelling the others,
        # but the failures must still be surfaced instead of silently dropped.
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for idx, result in enumerate(results):
            if isinstance(result, Exception):
                logger.error("Pipeline sub-task %d failed (task_id=%s)", idx, task_id,
                             exc_info=result)

    finally:
        await cleanup_resources()  # release shared pipeline resources
|