ai_project_v1/yolo/detect/multi_yolo_trt_detect_track.py

942 lines
36 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import ctypes
import os
import shutil
import random
import sys
import threading
import time
import cv2
import numpy as np
import pycuda.autoinit # noqa: F401
import pycuda.driver as cuda
import tensorrt as trt
from typing import List, Dict, Tuple, Any
from concurrent.futures import ThreadPoolExecutor
import asyncio
from dataclasses import dataclass
import logging
from middleware.entity.detection import DetectionResult, DetectionResultList
# Configure logging (runs at import time and affects the root logger).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger("MultiYoloTRT")
# Confidence threshold - only detections scoring above this value are displayed and tracked.
# Used both by the NMS pre-filter and as the tracker's conf_threshold.
CONF_THRESH = 0.5
# IoU threshold passed to non-max suppression.
IOU_THRESHOLD = 0.4
# Per-detection field counts; post_process() sums these four values to get the
# row width used to reshape the raw engine output.
# NOTE(review): 17 * 3 presumably means 17 keypoints x 3 values each - confirm against the engine plugin.
POSE_NUM = 17 * 3
DET_NUM = 6
SEG_NUM = 32
OBB_NUM = 1
# Test configuration
MAX_TEST_FRAMES = 100 # maximum number of frames processed per test pass
# Globals
categories = ["building"] # adjust to match your actual class list
# Configuration holder
class ConfigTrt:
    """Static defaults read by the multi-model detector."""

    # Number of worker threads in the detector's thread pool.
    MAX_WORKERS: int = 4
    # Confidence threshold used when a model config has no 'config_conf' entry.
    CONF_THRESHOLD: float = 0.5
# Tracker variant that avoids "ghost" boxes
class Tracker:
    """Greedy IoU tracker that reports only tracks refreshed in the current frame.

    A track is a dict with the keys ``track_id``, ``position``, ``score``,
    ``class_id``, ``hits``, ``age`` and ``active``. Detections are sequences
    laid out as ``(x1, y1, x2, y2, score, class_id)``.
    """

    def __init__(self, max_age=10, min_hits=3, iou_threshold=0.5, conf_threshold=0.5):
        """Store the tracking parameters and start with no tracks.

        Args:
            max_age: A track is dropped once it has gone more than this many
                consecutive updates without a match.
            min_hits: Number of matches a track needs before it is reported.
            iou_threshold: A detection must overlap a track by more than this
                IoU to be matched to it.
            conf_threshold: Detections scoring below this are ignored.
        """
        self.tracks = []
        self.max_age = max_age
        self.min_hits = min_hits
        self.iou_threshold = iou_threshold
        self.conf_threshold = conf_threshold
        self.frame_count = 0
        self.next_id = 1

    def update(self, detections):
        """Advance the tracker by one frame.

        Args:
            detections: Iterable of ``(x1, y1, x2, y2, score, class_id)``.

        Returns:
            A list of ``(x1, y1, x2, y2, score, class_id, track_id)`` tuples,
            one per track that was matched in this call and has accumulated
            at least ``min_hits`` matches.
        """
        self.frame_count += 1
        confident = [d for d in detections if d[4] >= self.conf_threshold]

        # Every track ages by one frame and is inactive until matched again.
        for trk in self.tracks:
            trk['age'] += 1
            trk['active'] = False

        if confident:
            claimed = set()
            for det_idx, trk_idx in self.simple_matching(confident):
                if det_idx >= len(confident) or trk_idx >= len(self.tracks):
                    continue
                det = confident[det_idx]
                trk = self.tracks[trk_idx]
                trk.update(position=det[:4], score=det[4], class_id=det[5],
                           age=0, active=True)
                trk['hits'] += 1
                claimed.add(det_idx)

            # Unmatched detections start a new track only when they do not
            # overlap any existing track (including ones created just above
            # in this same loop) by more than 0.1 IoU.
            for det_idx, det in enumerate(confident):
                if det_idx in claimed:
                    continue
                if any(self.calculate_iou(det[:4], trk['position']) > 0.1
                       for trk in self.tracks):
                    continue
                self.tracks.append({
                    'track_id': self.next_id,
                    'position': det[:4],
                    'score': det[4],
                    'class_id': det[5],
                    'hits': 1,
                    'age': 0,
                    'active': True,
                })
                self.next_id += 1

        # Forget tracks that have been unmatched for too long.
        self.tracks = [trk for trk in self.tracks if trk['age'] <= self.max_age]

        # Report confirmed tracks that were refreshed in this frame only.
        return [
            (*trk['position'], trk['score'], trk['class_id'], trk['track_id'])
            for trk in self.tracks
            if trk['hits'] >= self.min_hits and trk['active']
        ]

    def simple_matching(self, detections):
        """Greedily pair each detection with its best still-unclaimed track.

        Detections are visited in order; each takes the free track with the
        highest IoU strictly above ``iou_threshold``.

        Returns:
            A list of ``(detection_index, track_index)`` pairs.
        """
        pairs = []
        taken = set()
        for det_idx, det in enumerate(detections):
            winner = -1
            winner_iou = self.iou_threshold
            for trk_idx, trk in enumerate(self.tracks):
                if trk_idx in taken:
                    continue
                overlap = self.calculate_iou(det[:4], trk['position'])
                if overlap > winner_iou:
                    winner, winner_iou = trk_idx, overlap
            if winner != -1:
                pairs.append((det_idx, winner))
                taken.add(winner)
        return pairs

    def calculate_iou(self, box1, box2):
        """Return the IoU of two ``(x1, y1, x2, y2)`` boxes (0.0 when disjoint)."""
        left = max(box1[0], box2[0])
        top = max(box1[1], box2[1])
        right = min(box1[2], box2[2])
        bottom = min(box1[3], box2[3])
        if right < left or bottom < top:
            return 0.0
        overlap = (right - left) * (bottom - top)
        area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
        area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
        union = area1 + area2 - overlap
        # The epsilon keeps the division defined for zero-area boxes.
        return overlap / (union + 1e-6)
def plot_one_box(x, img, color=None, label=None, line_thickness=None, track_id=None):
    """Draw one bounding box, with an optional filled label, onto ``img`` in place.

    Args:
        x: Box corners indexed as ``x[0..3]`` = x1, y1, x2, y2.
        img: Image array that is drawn on; its first two shape entries are
            used to derive a default line thickness.
        color: Drawing colour; a random colour is picked when falsy.
        label: Optional text drawn above the box's top-left corner.
        line_thickness: Optional line thickness; derived from the image size
            when falsy.
        track_id: When not None, ``"ID:<track_id> "`` is prefixed to the label.
    """
    if line_thickness:
        tl = line_thickness
    else:
        tl = round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1
    if not color:
        color = [random.randint(0, 255) for _ in range(3)]

    top_left = (int(x[0]), int(x[1]))
    bottom_right = (int(x[2]), int(x[3]))
    cv2.rectangle(img, top_left, bottom_right, color, thickness=tl, lineType=cv2.LINE_AA)

    # Prefix the tracking id so it is shown even when no label was supplied.
    if track_id is not None:
        label = f"ID:{track_id} " + (label or "")
    if not label:
        return

    tf = max(tl - 1, 1)
    text_w, text_h = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
    # Filled background sized to the text, sitting just above the box corner.
    label_corner = (top_left[0] + text_w, top_left[1] - text_h - 3)
    cv2.rectangle(img, top_left, label_corner, color, -1, cv2.LINE_AA)
    cv2.putText(
        img,
        label,
        (top_left[0], top_left[1] - 2),
        0,
        tl / 3,
        [225, 255, 255],
        thickness=tf,
        lineType=cv2.LINE_AA,
    )
class YoLo11TRT(object):
    """TensorRT YOLO11 detector that owns a CUDA context, I/O buffers and a tracker.

    Each instance creates its own CUDA context on device 0, deserializes one
    engine file, allocates a page-locked host buffer and a device buffer for
    every engine binding, and keeps one ``Tracker`` that is updated for every
    image passed through ``infer``.
    """

    def __init__(self, engine_file_path):
        """Load the engine and allocate all resources.

        Args:
            engine_file_path: Path of the serialized TensorRT engine.

        Raises:
            OSError: If the engine file cannot be opened.
            RuntimeError: If TensorRT cannot deserialize the engine.
        """
        print(f"[INFO] 初始化YoLo11TRT引擎文件: {engine_file_path}")
        # make_context() creates the context AND leaves it pushed on this
        # thread; this is the only push that is not paired with a pop in infer().
        self.ctx = cuda.Device(0).make_context()
        try:
            stream = cuda.Stream()
            TRT_LOGGER = trt.Logger(trt.Logger.INFO)
            runtime = trt.Runtime(TRT_LOGGER)
            with open(engine_file_path, "rb") as f:
                self.engine = runtime.deserialize_cuda_engine(f.read())
            if self.engine is None:
                # deserialize_cuda_engine reports failure by returning None.
                raise RuntimeError(f"TensorRT引擎反序列化失败: {engine_file_path}")
            self.context = self.engine.create_execution_context()
            self.stream = stream
            # Allocate host/device buffers for every binding.
            self._allocate_buffers()
        except Exception:
            # Do not leave a half-initialised context on the thread's context
            # stack: callers that catch the error and load another model
            # would otherwise keep stacking leaked contexts.
            self.ctx.pop()
            self.ctx.detach()
            raise
        # Tracker that suppresses unconfirmed / stale boxes.
        self.tracker = Tracker(max_age=10, min_hits=3, iou_threshold=0.5, conf_threshold=CONF_THRESH)
        print("[INFO] YoLo11TRT初始化完成")

    def _allocate_buffers(self):
        """Allocate a page-locked host buffer and a device buffer per binding.

        Also records ``input_w``, ``input_h`` and ``batch_size`` from the
        input binding's shape and ``det_output_length`` from the first
        output buffer.

        NOTE(review): uses the index-based binding API (``num_bindings``,
        ``get_binding_shape`` ...); confirm the deployed TensorRT version
        still provides it.
        """
        self.host_inputs = []
        self.cuda_inputs = []
        self.host_outputs = []
        self.cuda_outputs = []
        self.bindings = []
        for i in range(self.engine.num_bindings):
            binding_name = self.engine.get_binding_name(i)
            shape = self.engine.get_binding_shape(i)
            dtype = trt.nptype(self.engine.get_binding_dtype(i))
            # Element count of the binding; fall back to one element when the
            # reported volume is not positive (presumably dynamic dimensions).
            size = trt.volume(shape)
            if size <= 0:
                size = 1
            host_mem = cuda.pagelocked_empty(size, dtype)
            cuda_mem = cuda.mem_alloc(host_mem.nbytes)
            self.bindings.append(int(cuda_mem))
            if self.engine.binding_is_input(i):
                self.input_w = shape[-1]
                self.input_h = shape[-2]
                self.batch_size = shape[0]
                self.host_inputs.append(host_mem)
                self.cuda_inputs.append(cuda_mem)
                print(f"[INFO] 输入绑定 {i}: {binding_name}, 形状: {shape}, 类型: {dtype}")
            else:
                self.host_outputs.append(host_mem)
                self.cuda_outputs.append(cuda_mem)
                print(f"[INFO] 输出绑定 {i}: {binding_name}, 形状: {shape}, 类型: {dtype}")
        # Length of the slice of the first output consumed per image.
        # NOTE(review): this is the size of the whole first output buffer; for
        # batch sizes above 1 confirm it should not be divided by batch_size.
        if self.host_outputs:
            self.det_output_length = self.host_outputs[0].shape[0]
        else:
            self.det_output_length = 0
        print(f"[INFO] 缓冲区分配完成batch_size: {self.batch_size}, det_output_length: {self.det_output_length}")

    def infer(self, raw_image_generator):
        """Run detection and tracking on a batch of BGR images.

        The instance's CUDA context is pushed for the duration of the call
        and always popped again, so the method can be called from a thread
        other than the one that created the instance.

        Args:
            raw_image_generator: Iterable of BGR images (H x W x 3 arrays).
                Supplying more images than the engine batch size raises
                ``IndexError``; supplying fewer is allowed.

        Returns:
            ``(frame_result, batch_image_raw, elapsed_seconds)`` where
            ``frame_result`` is a dict with the parallel lists ``box_list``,
            ``score_list``, ``class_id_list`` and ``track_id_list`` holding
            the tracked detections of all supplied images.

        Raises:
            Exception: Any error raised during inference is printed and re-raised.
        """
        self.ctx.push()
        try:
            start = time.time()
            batch_image_raw = []
            batch_origin_h = []
            batch_origin_w = []
            # Zero-filled so that unused batch slots hold defined values.
            batch_input_image = np.zeros(
                shape=[self.batch_size, 3, self.input_h, self.input_w], dtype=np.float32
            )
            for i, image_raw in enumerate(raw_image_generator):
                input_image, image_raw, origin_h, origin_w = self.preprocess_image(image_raw)
                batch_image_raw.append(image_raw)
                batch_origin_h.append(origin_h)
                batch_origin_w.append(origin_w)
                np.copyto(batch_input_image[i], input_image)
            batch_input_image = np.ascontiguousarray(batch_input_image)

            # Host -> device, execute, device -> host, all on one stream.
            np.copyto(self.host_inputs[0], batch_input_image.ravel())
            cuda.memcpy_htod_async(self.cuda_inputs[0], self.host_inputs[0], self.stream)
            self.context.execute_async_v2(bindings=self.bindings, stream_handle=self.stream.handle)
            cuda.memcpy_dtoh_async(self.host_outputs[0], self.cuda_outputs[0], self.stream)
            self.stream.synchronize()
            output = self.host_outputs[0]

            frame_result = {
                "box_list": [],
                "score_list": [],
                "class_id_list": [],
                "track_id_list": [],
            }
            # Only post-process the images that were actually supplied; the
            # engine batch may be larger than the number of inputs.
            for i in range(len(batch_image_raw)):
                result_boxes, result_scores, result_classid = self.post_process(
                    output[i * self.det_output_length: (i + 1) * self.det_output_length],
                    batch_origin_h[i], batch_origin_w[i]
                )
                # Convert the arrays into plain tuples for the tracker.
                detections = [
                    (float(box[0]), float(box[1]), float(box[2]), float(box[3]),
                     float(result_scores[j]), result_classid[j])
                    for j, box in enumerate(result_boxes)
                ]
                # The tracker returns only confirmed tracks seen in this frame.
                for det in self.tracker.update(detections):
                    frame_result["box_list"].append(det[:4])
                    frame_result["score_list"].append(det[4])
                    frame_result["class_id_list"].append(det[5])
                    # -1 marks a result that carries no track id.
                    frame_result["track_id_list"].append(det[6] if len(det) > 6 else -1)
            end = time.time()
            return frame_result, batch_image_raw, end - start
        except Exception as e:
            print(f"[ERROR] 推理过程中出错: {e}")
            raise
        finally:
            # Always undo the push made at the top of this call.
            self.ctx.pop()

    def destroy(self):
        """Release buffers, stream, execution context, engine and CUDA context.

        Errors are printed, never raised.

        NOTE(review): the first step pops whatever contexts are current on
        the calling thread (up to 10), not only this instance's context -
        confirm this is intended when several instances share one thread.
        """
        print("[INFO] 开始清理CUDA资源...")
        try:
            # Empty the thread's context stack first.
            self._clear_cuda_context_stack()
            # Release order: memory -> stream -> execution context -> engine,
            # then the CUDA context itself.
            self._release_all_resources()
            self._pop_cuda_context()
            print("[INFO] CUDA资源清理完成")
        except Exception as e:
            print(f"[ERROR] 资源清理异常: {e}")

    def _clear_cuda_context_stack(self, max_attempts=10):
        """Pop the current CUDA context repeatedly until none is current.

        Args:
            max_attempts: Upper bound on the number of pops.
        """
        try:
            attempts = 0
            while attempts < max_attempts:
                try:
                    current_ctx = cuda.Context.get_current()
                    if current_ctx:
                        current_ctx.pop()
                        print(f"[INFO] 弹出当前上下文 (尝试 {attempts + 1}/{max_attempts})")
                    else:
                        break
                except cuda.LogicError:
                    break
                attempts += 1
            print("[INFO] CUDA上下文栈已清空")
        except Exception as e:
            print(f"[WARNING] 上下文栈清理异常: {e}")

    def _release_all_resources(self):
        """Free device memory, then drop the stream, execution context and engine.

        Every step is best-effort: a failure is printed and the remaining
        steps still run.
        """
        # 1. Device memory (outputs first, then inputs).
        for attr, label in (('cuda_outputs', '输出'), ('cuda_inputs', '输入')):
            for i, cuda_mem in enumerate(getattr(self, attr, None) or []):
                try:
                    if cuda_mem:
                        cuda_mem.free()
                        print(f"[INFO] {label}内存 {i} 已释放")
                except Exception as e:
                    print(f"[WARNING] {label}内存释放异常 {i}: {e}")
        # 2. CUDA stream: wait for outstanding work before dropping it.
        if hasattr(self, 'stream') and self.stream:
            try:
                self.stream.synchronize()
                del self.stream
                print("[INFO] CUDA流已销毁")
            except Exception as e:
                print(f"[WARNING] 流销毁异常: {e}")
        # 3. TensorRT execution context.
        if hasattr(self, 'context') and self.context:
            try:
                del self.context
                print("[INFO] 执行上下文已销毁")
            except Exception as e:
                print(f"[WARNING] 上下文销毁异常: {e}")
        # 4. TensorRT engine.
        if hasattr(self, 'engine') and self.engine:
            try:
                del self.engine
                print("[INFO] TensorRT引擎已销毁")
            except Exception as e:
                print(f"[WARNING] 引擎销毁异常: {e}")

    def _pop_cuda_context(self):
        """Pop this instance's context if it is current, then detach it."""
        if hasattr(self, 'ctx') and self.ctx:
            try:
                # Only pop when our own context is the current one.
                if self.ctx == cuda.Context.get_current():
                    self.ctx.pop()
                self.ctx.detach()
                print("[INFO] CUDA上下文已弹出")
            except Exception as e:
                print(f"[WARNING] 上下文弹出异常: {e}")

    def preprocess_image(self, raw_bgr_image):
        """Letterbox a BGR image to the engine input size.

        The image is converted to RGB, resized with its aspect ratio kept,
        padded with grey (128) to ``input_w`` x ``input_h``, scaled to
        [0, 1] and rearranged to a contiguous ``(1, 3, H, W)`` float32 array.

        Returns:
            ``(network_input, original_image, original_h, original_w)``.
        """
        image_raw = raw_bgr_image
        h, w, c = image_raw.shape
        image = cv2.cvtColor(image_raw, cv2.COLOR_BGR2RGB)
        r_w = self.input_w / w
        r_h = self.input_h / h
        if r_h > r_w:
            # Width is the limiting side: pad top and bottom.
            tw = self.input_w
            th = int(r_w * h)
            tx1 = tx2 = 0
            ty1 = int((self.input_h - th) / 2)
            ty2 = self.input_h - th - ty1
        else:
            # Height is the limiting side: pad left and right.
            tw = int(r_h * w)
            th = self.input_h
            tx1 = int((self.input_w - tw) / 2)
            tx2 = self.input_w - tw - tx1
            ty1 = ty2 = 0
        image = cv2.resize(image, (tw, th))
        image = cv2.copyMakeBorder(image, ty1, ty2, tx1, tx2, cv2.BORDER_CONSTANT, None, (128, 128, 128))
        image = image.astype(np.float32)
        image /= 255.0
        image = np.transpose(image, [2, 0, 1])
        image = np.expand_dims(image, axis=0)
        image = np.ascontiguousarray(image)
        return image, image_raw, h, w

    def xywh2xyxy(self, origin_h, origin_w, x):
        """Map box corners from letterboxed input space to original image space.

        Despite its name this method performs no centre/size conversion: the
        four columns of ``x`` are treated as x1, y1, x2, y2 and only the
        letterbox padding and scale applied by ``preprocess_image`` are undone.

        Args:
            origin_h: Height of the original image.
            origin_w: Width of the original image.
            x: ``(N, 4)`` array of corner coordinates in network-input pixels.

        Returns:
            A new ``(N, 4)`` array in original-image pixels.
        """
        y = np.zeros_like(x)
        r_w = self.input_w / origin_w
        r_h = self.input_h / origin_h
        if r_h > r_w:
            # Padding was added above and below the image.
            pad = (self.input_h - r_w * origin_h) / 2
            y[:, 0] = x[:, 0]
            y[:, 2] = x[:, 2]
            y[:, 1] = x[:, 1] - pad
            y[:, 3] = x[:, 3] - pad
            y /= r_w
        else:
            # Padding was added left and right of the image.
            pad = (self.input_w - r_h * origin_w) / 2
            y[:, 0] = x[:, 0] - pad
            y[:, 2] = x[:, 2] - pad
            y[:, 1] = x[:, 1]
            y[:, 3] = x[:, 3]
            y /= r_h
        return y

    def post_process(self, output, origin_h, origin_w):
        """Decode one image's raw engine output into boxes, scores and class ids.

        ``output[0]`` is read as the number of detections; the remaining
        values are reshaped into rows of
        ``DET_NUM + SEG_NUM + POSE_NUM + OBB_NUM`` values, of which columns
        0-3 are the box, 4 the score and 5 the class id.

        Returns:
            ``(boxes, scores, class_ids)``; three empty arrays when nothing
            is detected.
        """
        num_values_per_detection = DET_NUM + SEG_NUM + POSE_NUM + OBB_NUM
        num = int(output[0])
        if num <= 0:
            return np.array([]), np.array([]), np.array([])
        pred = np.reshape(output[1:], (-1, num_values_per_detection))[:num, :]
        boxes = self.non_max_suppression(pred, origin_h, origin_w, conf_thres=CONF_THRESH, nms_thres=IOU_THRESHOLD)
        result_boxes = boxes[:, :4] if len(boxes) else np.array([])
        result_scores = boxes[:, 4] if len(boxes) else np.array([])
        result_classid = boxes[:, 5] if len(boxes) else np.array([])
        return result_boxes, result_scores, result_classid

    def bbox_iou(self, box1, box2, x1y1x2y2=True):
        """Return the element-wise IoU of two ``(N, 4)`` box arrays.

        Widths and heights are measured inclusively (``x2 - x1 + 1``).

        Args:
            box1: Boxes; broadcast against ``box2``.
            box2: Boxes to compare with.
            x1y1x2y2: When False the boxes are given as centre x, centre y,
                width, height and converted to corners first.
        """
        if not x1y1x2y2:
            b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
            b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
            b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
            b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
        else:
            b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
            b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]
        inter_rect_x1 = np.maximum(b1_x1, b2_x1)
        inter_rect_y1 = np.maximum(b1_y1, b2_y1)
        inter_rect_x2 = np.minimum(b1_x2, b2_x2)
        inter_rect_y2 = np.minimum(b1_y2, b2_y2)
        inter_area = (np.clip(inter_rect_x2 - inter_rect_x1 + 1, 0, None) *
                      np.clip(inter_rect_y2 - inter_rect_y1 + 1, 0, None))
        b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1)
        b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1)
        iou = inter_area / (b1_area + b2_area - inter_area + 1e-16)
        return iou

    def non_max_suppression(self, prediction, origin_h, origin_w, conf_thres=0.5, nms_thres=0.4):
        """Filter by confidence and apply per-class non-max suppression.

        Args:
            prediction: ``(N, K)`` array with K >= 6; columns 0-3 are box
                corners in network-input pixels, 4 the score, 5 the class id.
            origin_h: Height of the original image.
            origin_w: Width of the original image.
            conf_thres: Rows scoring below this are discarded.
            nms_thres: A box is suppressed when its IoU with a higher-scoring
                box of the same class exceeds this value.

        Returns:
            ``(M, K)`` array of kept rows, sorted by descending score and with
            boxes in original-image pixels, or an empty array.
        """
        # Boolean indexing copies, so ``prediction`` itself is not modified.
        boxes = prediction[prediction[:, 4] >= conf_thres]
        if len(boxes) == 0:
            return np.array([])
        boxes[:, :4] = self.xywh2xyxy(origin_h, origin_w, boxes[:, :4])
        # Keep every corner inside the original image.
        boxes[:, 0] = np.clip(boxes[:, 0], 0, origin_w - 1)
        boxes[:, 2] = np.clip(boxes[:, 2], 0, origin_w - 1)
        boxes[:, 1] = np.clip(boxes[:, 1], 0, origin_h - 1)
        boxes[:, 3] = np.clip(boxes[:, 3], 0, origin_h - 1)
        confs = boxes[:, 4]
        boxes = boxes[np.argsort(-confs)]
        keep_boxes = []
        while boxes.shape[0]:
            large_overlap = self.bbox_iou(np.expand_dims(boxes[0, :4], 0), boxes[:, :4]) > nms_thres
            # The class id lives in column 5 (the same column post_process
            # reads); the last column of a wide row is a different field.
            label_match = boxes[0, 5] == boxes[:, 5]
            invalid = large_overlap & label_match
            keep_boxes.append(boxes[0])
            boxes = boxes[~invalid]
        boxes = np.stack(keep_boxes, 0) if len(keep_boxes) else np.array([])
        return boxes
class ThreadSafeYoLo11TRT(YoLo11TRT):
    """YoLo11TRT variant whose ``infer`` calls are serialized by a lock.

    Only ``infer`` takes the lock; the other inherited methods are unchanged.
    """

    def __init__(self, engine_file_path, plugin_library=None):
        """Optionally load a plugin shared library, then build the engine.

        Args:
            engine_file_path: Path of the serialized TensorRT engine.
            plugin_library: Optional path of a shared library loaded with
                ``ctypes.CDLL`` before the engine is deserialized. A load
                failure is logged as a warning and initialisation continues.
        """
        if plugin_library:
            try:
                ctypes.CDLL(plugin_library)
                logger.info(f"插件库加载成功: {plugin_library}")
            except Exception as exc:
                logger.warning(f"加载插件库失败: {exc}")
        # The lock must exist before the base class finishes initialising.
        self._lock = threading.Lock()
        super().__init__(engine_file_path)

    def infer(self, raw_image_generator):
        """Run the base-class inference while holding the instance lock."""
        with self._lock:
            outcome = super().infer(raw_image_generator)
        return outcome
class MultiYoloTrtDetectorTrackId:
    """Runs several TensorRT models on the same frame in a thread pool.

    One ``ThreadSafeYoLo11TRT`` instance is created per model config. The
    per-model settings are stored in parallel lists indexed like ``models``.
    """

    def __init__(self, model_configs: List[Dict]):
        """Create one model instance per config entry.

        Recognised config keys: ``engine_path``, ``so_path``, ``cls_map``,
        ``allowed_classes``, ``cls_index``, ``chinese_label``,
        ``list_func_id``, ``func_id``, ``config_conf``, ``class_names`` and
        ``para_invade_enable``. A config whose model fails to load is logged
        and skipped.

        NOTE(review): ``func_id`` and ``conf`` are single attributes that are
        overwritten by every config, so the last loaded config wins for all
        models - confirm this is intended.

        Raises:
            RuntimeError: If no model could be loaded.
        """
        self.models = []
        self.class_maps = []
        self.executor = ThreadPoolExecutor(max_workers=ConfigTrt.MAX_WORKERS)
        self.allowed_classes = []
        self.conf = ConfigTrt.CONF_THRESHOLD
        self.model_cls = []
        self.chinese_label = []
        self.list_func_id = []
        self.func_id = -1
        self.list_class_names = []
        self.list_para_invade_enable = []
        self.model_configs = []
        # Build an independent thread-safe model instance for each config.
        for config in model_configs:
            model_path = ''
            try:
                model_path = config.get('engine_path', '')
                plugin_library = config.get('so_path', '')
                cls_map = config.get('cls_map', {})
                allowed = config.get('allowed_classes', None)
                model = ThreadSafeYoLo11TRT(model_path, plugin_library)
                self.models.append(model)
                # Store the per-model settings at the same index as the model.
                self.class_maps.append(cls_map)
                self.allowed_classes.append(allowed)
                self.model_cls.append(config.get('cls_index', True))
                self.chinese_label.append(config.get('chinese_label', {}))
                self.list_func_id.append(config.get('list_func_id', -11))
                self.func_id = config.get('func_id', True)
                self.conf = config.get('config_conf', ConfigTrt.CONF_THRESHOLD)
                self.list_class_names.append(config.get('class_names', ['unknown']))
                self.list_para_invade_enable.append(config.get('para_invade_enable', False))
                self.model_configs.append(config)
                logger.info(f"成功加载模型: {model_path}")
            except Exception as e:
                # Report the path that was actually used for loading.
                logger.error(f"加载模型失败 {model_path or 'unknown'}: {e}")
                continue
        if not self.models:
            raise RuntimeError("没有可用的模型实例")
        logger.info(f"多模型检测器初始化完成,共加载 {len(self.models)} 个模型")

    def _predict_single_model(self, model_idx: int, image: np.ndarray) -> Tuple[List[DetectionResult], Any, Any]:
        """Run one model on ``image`` and convert its output to result objects.

        Never raises: on any error an empty result triple is returned and the
        error is logged.

        NOTE(review): the stored ``allowed_classes`` setting is not applied
        here, and ``clss_name`` receives ``str(class_id)`` rather than the
        resolved class name - confirm both against the consumers.

        Returns:
            ``(detections, detection_result_list, model_para)``.
        """
        try:
            if model_idx >= len(self.models):
                return [], DetectionResultList([], [], [], [], []), {}
            model = self.models[model_idx]
            model_class_names = self.list_class_names[model_idx]
            # Settings handed back to the caller next to the detections.
            model_para = {
                "cls_map": self.class_maps[model_idx],
                "model_chinese_labe": self.chinese_label[model_idx],
                "model_cls_index": self.model_cls[model_idx],
                "model_list_func_id": self.list_func_id[model_idx],
                "func_id": self.func_id,
                "model_class_names": model_class_names,
                "para_invade_enable": self.list_para_invade_enable[model_idx]
            }
            frame_result, _, _ = model.infer([image])
            detections = []
            detection_result_list = DetectionResultList([], [], [], [], [])
            box_list = frame_result.get("box_list", [])
            score_list = frame_result.get("score_list", [])
            class_id_list = frame_result.get("class_id_list", [])
            track_id_list = frame_result.get("track_id_list", [])
            for i, box in enumerate(box_list):
                try:
                    if i >= len(score_list) or i >= len(class_id_list):
                        continue
                    score = score_list[i]
                    # Class ids arrive as floats; indexing needs an int.
                    class_id = int(class_id_list[i])
                    track_id = track_id_list[i] if i < len(track_id_list) else -1
                    # Drop detections below the detector-level threshold.
                    if score < self.conf:
                        continue
                    # Resolve the class name from a list or a dict of names.
                    class_name = "unknown"
                    if isinstance(model_class_names, list) and class_id < len(model_class_names):
                        class_name = model_class_names[class_id]
                    elif isinstance(model_class_names, dict):
                        class_name = model_class_names.get(class_id, "unknown")
                    detections.append(DetectionResult(
                        bbox=box,
                        class_id=class_id,
                        class_name=class_name,
                        confidence=score,
                        track_id=track_id
                    ))
                    detection_result_list.boxes.append(box)
                    detection_result_list.clss.append(class_id)
                    detection_result_list.clss_name.append(str(class_id))
                    detection_result_list.confs.append(score)
                    detection_result_list.track_ids.append(track_id)
                except Exception as e:
                    logger.error(f"处理检测结果时出错: {e}")
                    continue
            logger.debug(f"模型 {model_idx} 检测到 {len(detections)} 个目标")
            return detections, detection_result_list, model_para
        except Exception as e:
            logger.error(f"模型 {model_idx} 预测错误: {e}")
            return [], DetectionResultList([], [], [], [], []), {}

    async def predict(self, frame: np.ndarray) -> Tuple[List[DetectionResult], List, List]:
        """Run every model on ``frame`` concurrently and merge the results.

        Each model receives its own copy of the frame and runs in the
        detector's thread pool.

        Returns:
            ``(all_detections, detection_lists, model_paras)``: the
            detections of all models concatenated, plus one detection list
            and one parameter dict per model that completed.
        """
        if not self.models:
            return [], [], []
        loop = asyncio.get_running_loop()
        futures = [
            loop.run_in_executor(
                self.executor,
                self._predict_single_model,
                model_idx,
                frame.copy()
            )
            for model_idx in range(len(self.models))
        ]
        # Collect exceptions as values so one failing task cannot cancel the rest.
        results = await asyncio.gather(*futures, return_exceptions=True)
        all_detections = []
        all_detection_lists = []
        all_model_paras = []
        for result in results:
            if isinstance(result, Exception):
                logger.error(f"模型预测任务异常: {result}")
                continue
            detections, detection_list, model_para = result
            all_detections.extend(detections)
            all_detection_lists.append(detection_list)
            all_model_paras.append(model_para)
        logger.info(f"所有模型共检测到 {len(all_detections)} 个目标")
        return all_detections, all_detection_lists, all_model_paras

    def predict_sync(self, frame: np.ndarray) -> Tuple[List[DetectionResult], List, List]:
        """Blocking wrapper around ``predict``.

        Must not be called from a thread that is already running an event
        loop. ``asyncio.run`` creates and fully tears down a private loop, so
        no closed loop is left installed on the calling thread.
        """
        return asyncio.run(self.predict(frame))

    def destroy(self):
        """Destroy every model, then shut the thread pool down.

        A failure while destroying one model is logged and does not stop the
        remaining clean-up.
        """
        for model in self.models:
            try:
                model.destroy()
                logger.info("模型资源已清理")
            except Exception as e:
                logger.error(f"清理模型资源时出错: {e}")
        self.executor.shutdown(wait=True)
        logger.info("线程池已关闭")
# ====== Entry point ======
def main():
    """Run the multi-model detection test ten times over a sample video.

    Each pass builds a fresh detector, processes up to ``MAX_TEST_FRAMES``
    frames while displaying the annotated result, prints a timing summary
    and always releases the capture, the window and the detector again -
    also when the video cannot be opened or an error occurs. Repeating the
    pass exercises CUDA context creation and teardown.
    """
    for _ in range(10):
        # Example configuration. The key names must match what
        # MultiYoloTrtDetectorTrackId reads: 'engine_path' and 'so_path'.
        model_configs = [
            {
                'engine_path': r"D:\project\AI-PYTHON\tensorrtx-master\yolo11\build\Release\build.engine",
                'so_path': r"D:\project\AI-PYTHON\tensorrtx-master\yolo11\build\Release\myplugins.dll",
                'cls_map': {0: 'building'},
                'allowed_classes': None,
                'cls_index': True,
                'chinese_label': {0: '建筑物'},
                'list_func_id': -11,
                'func_id': True,
                'class_names': ['building'],
                'para_invade_enable': False,
                'config_conf': 0.5
            },
        ]
        multi_detector = None
        cap = None
        window_shown = False
        succeeded = False
        try:
            multi_detector = MultiYoloTrtDetectorTrackId(model_configs)
            video_path = r"E:\yolo-dataset\test_video\shimian_invade\DJI_20250910142510_0001_V.mp4"
            cap = cv2.VideoCapture(video_path)
            if not cap.isOpened():
                print(f"[ERROR] 无法打开视频文件: {video_path}")
                return
            # Performance counters.
            frame_count = 0        # frames since the last FPS refresh
            total_frames = 0       # frames processed in this pass
            total_infer_time = 0
            fps_smoothed = 0
            last_time = time.time()
            print(f"[INFO] 开始多模型并行检测测试,将在处理 {MAX_TEST_FRAMES} 帧后退出")
            print(f"[INFO] 视频文件: {video_path}")
            print(f"[INFO] 模型数量: {len(multi_detector.models)}")
            # The limit is checked before reading, so every counted frame is
            # also processed.
            while total_frames < MAX_TEST_FRAMES:
                ret, frame = cap.read()
                if not ret:
                    print("[INFO] 视频播放完毕")
                    break
                frame_count += 1
                total_frames += 1
                start_time = time.time()
                detections, detection_lists, model_paras = multi_detector.predict_sync(frame)
                infer_time = time.time() - start_time
                total_infer_time += infer_time
                # Refresh the displayed FPS roughly once per second.
                current_time = time.time()
                elapsed = current_time - last_time
                if elapsed >= 1.0:
                    fps_smoothed = frame_count / elapsed
                    frame_count = 0
                    last_time = current_time
                # Draw every detection on a copy of the frame.
                display_frame = frame.copy()
                for detection in detections:
                    plot_one_box(
                        detection.bbox,
                        display_frame,
                        label=f"{detection.class_name}:{detection.confidence:.2f}",
                        track_id=detection.track_id
                    )
                # Overlay the performance figures, one line every 30 pixels.
                overlay_lines = [
                    f"FPS: {fps_smoothed:.2f}",
                    f"Frame Time: {infer_time * 1000:.2f}ms",
                    f"Frame: {total_frames}/{MAX_TEST_FRAMES}",
                    f"Detections: {len(detections)}",
                    f"Models: {len(multi_detector.models)}",
                ]
                for row, text in enumerate(overlay_lines, start=1):
                    cv2.putText(display_frame, text, (10, 30 * row), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
                cv2.imshow("Multi-Model YOLO Detection - CUDA Context Test", display_frame)
                window_shown = True
                # Print progress every 10 frames.
                if total_frames % 10 == 0:
                    avg_infer_time = total_infer_time / total_frames
                    print(
                        f"[INFO] 已处理 {total_frames}/{MAX_TEST_FRAMES}当前FPS: {fps_smoothed:.2f},平均推理时间: {avg_infer_time * 1000:.2f}ms")
                # Let the user abort with 'q'.
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    print("[INFO] 用户请求退出")
                    break
            else:
                print(f"[INFO] 已达到最大测试帧数 {MAX_TEST_FRAMES},退出测试")
            # Overall summary; skipped when no time was measured to avoid a
            # division by zero.
            if total_frames > 0 and total_infer_time > 0:
                avg_fps = total_frames / total_infer_time
                avg_infer_time = total_infer_time / total_frames
                print(f"\n[INFO] 测试完成摘要:")
                print(f"[INFO] - 总处理帧数: {total_frames}")
                print(f"[INFO] - 总耗时: {total_infer_time:.2f}")
                print(f"[INFO] - 平均FPS: {avg_fps:.2f}")
                print(f"[INFO] - 平均推理时间: {avg_infer_time * 1000:.2f}ms")
                print(f"[INFO] - 最大测试帧数限制: {MAX_TEST_FRAMES}")
                print(f"[INFO] - 模型数量: {len(multi_detector.models)}")
            succeeded = True
        except Exception as e:
            print(f"[ERROR] 多模型检测测试失败: {e}")
            import traceback
            traceback.print_exc()
        finally:
            # Release everything on every exit path (success, error, return).
            if cap is not None:
                cap.release()
            if window_shown:
                cv2.destroyAllWindows()
            if multi_detector is not None:
                print("[INFO] 正在清理资源...")
                try:
                    multi_detector.destroy()
                    print("[INFO] 资源清理完成")
                except Exception as e:
                    print(f"[ERROR] 资源清理异常: {e}")
        if succeeded:
            print("[INFO] CUDA上下文管理测试结束")


if __name__ == "__main__":
    main()