# File-viewer header: 942 lines, 36 KiB, Python
import ctypes
|
||
import os
|
||
import shutil
|
||
import random
|
||
import sys
|
||
import threading
|
||
import time
|
||
import cv2
|
||
import numpy as np
|
||
import pycuda.autoinit # noqa: F401
|
||
import pycuda.driver as cuda
|
||
import tensorrt as trt
|
||
from typing import List, Dict, Tuple, Any
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
import asyncio
|
||
from dataclasses import dataclass
|
||
import logging
|
||
|
||
from middleware.entity.detection import DetectionResult, DetectionResultList
|
||
|
||
# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger("MultiYoloTRT")

# Confidence threshold - only detections scoring at least this value are
# displayed and tracked.
CONF_THRESH = 0.5
# IoU threshold passed to non-max suppression.
IOU_THRESHOLD = 0.4

# Per-detection field counts; post_process() adds them up to get the width
# of one row in the raw engine output.
# NOTE(review): presumably box+score+class, mask coefficients, 17 keypoints
# x 3 values and one OBB angle - confirm against the engine plugin.
DET_NUM = 6
SEG_NUM = 32
POSE_NUM = 17 * 3
OBB_NUM = 1

# Test configuration
MAX_TEST_FRAMES = 100  # maximum number of frames to test

# Global variables
categories = ["building"]  # adjust to your actual classes

# Configuration class
class ConfigTrt:
    """Static defaults read by MultiYoloTrtDetectorTrackId."""

    # Worker count handed to the detector's ThreadPoolExecutor.
    MAX_WORKERS = 4

    # Default confidence cut-off, used when a model config has no
    # 'config_conf' entry.
    CONF_THRESHOLD = 0.5


# Fixed tracker class - prevents ghost boxes
class Tracker:
    """Greedy IoU tracker that only reports tracks matched in the current frame.

    A detection is an indexable ``(x1, y1, x2, y2, score, class_id)`` record.
    Each track is a dict with the keys ``track_id``, ``position``, ``score``,
    ``class_id``, ``hits``, ``age`` and ``active``.
    """

    def __init__(self, max_age=10, min_hits=3, iou_threshold=0.5, conf_threshold=0.5,
                 overlap_threshold=0.1):
        """Create an empty tracker.

        Args:
            max_age: A track is dropped once it has gone more than this many
                consecutive updates without a match.
            min_hits: Number of matches a track needs before it is reported.
            iou_threshold: A detection is matched to a track only when their
                IoU is strictly greater than this value.
            conf_threshold: Detections scoring below this value are ignored.
            overlap_threshold: An unmatched detection whose IoU with any
                existing track exceeds this value is treated as a duplicate
                and does not start a new track.  The default (0.1) is the
                value that used to be hard-coded.
        """
        self.tracks = []
        self.max_age = max_age
        self.min_hits = min_hits
        self.iou_threshold = iou_threshold
        self.conf_threshold = conf_threshold
        self.overlap_threshold = overlap_threshold
        self.frame_count = 0
        self.next_id = 1

    def update(self, detections):
        """Advance the tracker by one frame.

        Args:
            detections: Iterable of ``(x1, y1, x2, y2, score, class_id)``.

        Returns:
            A list of ``(x1, y1, x2, y2, score, class_id, track_id)`` tuples,
            one per track that has at least ``min_hits`` matches and was
            matched (or created) in this call.
        """
        self.frame_count += 1

        # Drop low-confidence detections.
        filtered = [det for det in detections if det[4] >= self.conf_threshold]

        # Step 1: age every track and mark it inactive until it is matched.
        for track in self.tracks:
            track['age'] += 1
            track['active'] = False

        # Step 2: match detections to tracks, then start new tracks.
        if filtered:
            matched_dets = set()
            for det_idx, track_idx in self.simple_matching(filtered):
                det = filtered[det_idx]
                track = self.tracks[track_idx]
                track['position'] = det[:4]
                track['score'] = det[4]
                track['class_id'] = det[5]
                track['hits'] += 1
                track['age'] = 0        # reset the age
                track['active'] = True  # seen in this frame
                matched_dets.add(det_idx)

            for det_idx, det in enumerate(filtered):
                if det_idx in matched_dets:
                    continue
                # A detection overlapping any existing track (including one
                # created earlier in this same loop) is probably the same
                # object, so it does not start a new track.
                overlaps_existing = any(
                    self.calculate_iou(det[:4], track['position']) > self.overlap_threshold
                    for track in self.tracks
                )
                if overlaps_existing:
                    continue
                self.tracks.append({
                    'track_id': self.next_id,
                    'position': det[:4],
                    'score': det[4],
                    'class_id': det[5],
                    'hits': 1,
                    'age': 0,
                    'active': True,
                })
                self.next_id += 1

        # Step 3: remove expired tracks.
        self.tracks = [track for track in self.tracks if track['age'] <= self.max_age]

        # Step 4: report only confirmed tracks that are active in this frame.
        return [
            (*track['position'], track['score'], track['class_id'], track['track_id'])
            for track in self.tracks
            if track['hits'] >= self.min_hits and track['active']
        ]

    def simple_matching(self, detections):
        """Greedily pair each detection with its best still-free track.

        Detections are visited in order; each one takes the unused track with
        the highest IoU, provided that IoU is strictly above
        ``iou_threshold``.

        Returns:
            A list of ``(detection_index, track_index)`` pairs.
        """
        matches = []
        used_tracks = set()

        for det_idx, det in enumerate(detections):
            best_iou = self.iou_threshold
            best_track_idx = -1

            for track_idx, track in enumerate(self.tracks):
                if track_idx in used_tracks:
                    continue
                iou = self.calculate_iou(det[:4], track['position'])
                if iou > best_iou:
                    best_iou = iou
                    best_track_idx = track_idx

            if best_track_idx != -1:
                matches.append((det_idx, best_track_idx))
                used_tracks.add(best_track_idx)

        return matches

    def calculate_iou(self, box1, box2):
        """Return the IoU of two ``(x1, y1, x2, y2)`` boxes (0.0 if disjoint)."""
        x1 = max(box1[0], box2[0])
        y1 = max(box1[1], box2[1])
        x2 = min(box1[2], box2[2])
        y2 = min(box1[3], box2[3])

        if x2 < x1 or y2 < y1:
            return 0.0

        intersection = (x2 - x1) * (y2 - y1)
        area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
        area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])

        # The epsilon avoids a zero division for degenerate (zero-area) boxes.
        return intersection / (area1 + area2 - intersection + 1e-6)


def plot_one_box(x, img, color=None, label=None, line_thickness=None, track_id=None):
    """Draw one bounding box, and optionally a caption, onto ``img`` in place.

    Args:
        x: Box as ``(x1, y1, x2, y2)``; each value is truncated to int.
        img: Image array to draw on (modified in place).
        color: Box colour; a random colour is picked when falsy.
        label: Caption text, or None/empty for no caption.
        line_thickness: Line width; derived from the image size when falsy.
        track_id: When not None, ``"ID:<track_id> "`` is prefixed to the caption.
    """
    thickness = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1
    if not color:
        color = [random.randint(0, 255) for _ in range(3)]

    top_left = (int(x[0]), int(x[1]))
    bottom_right = (int(x[2]), int(x[3]))
    cv2.rectangle(img, top_left, bottom_right, color, thickness=thickness, lineType=cv2.LINE_AA)

    # Show the tracking id in front of the caption.
    if track_id is not None:
        label = f"ID:{track_id} " + (label or "")

    if not label:
        return

    font_thickness = max(thickness - 1, 1)
    font_scale = thickness / 3
    text_w, text_h = cv2.getTextSize(label, 0, fontScale=font_scale, thickness=font_thickness)[0]

    # Filled rectangle just above the box's top-left corner as text background.
    caption_corner = top_left[0] + text_w, top_left[1] - text_h - 3
    cv2.rectangle(img, top_left, caption_corner, color, -1, cv2.LINE_AA)
    cv2.putText(
        img,
        label,
        (top_left[0], top_left[1] - 2),
        0,
        font_scale,
        [225, 255, 255],
        thickness=font_thickness,
        lineType=cv2.LINE_AA,
    )


class YoLo11TRT(object):
    """TensorRT YOLO11 detector that also assigns track ids to its detections.

    The instance owns a CUDA context, a stream, a deserialized engine with its
    execution context, page-locked host / device buffers for every binding,
    and one ``Tracker``.
    """

    # Column of the class id in one raw detection row; post_process() reads
    # the class id from this same column.
    _CLASS_ID_COLUMN = 5

    def __init__(self, engine_file_path):
        """Load the serialized engine at ``engine_file_path`` and allocate buffers.

        Raises:
            RuntimeError: If the engine cannot be deserialized.
            OSError: If the engine file cannot be opened.
        """
        print(f"[INFO] 初始化YoLo11TRT,引擎文件: {engine_file_path}")

        # make_context() pushes the new context; this is the only push
        # outside infer().
        self.ctx = cuda.Device(0).make_context()
        try:
            stream = cuda.Stream()
            TRT_LOGGER = trt.Logger(trt.Logger.INFO)
            runtime = trt.Runtime(TRT_LOGGER)

            with open(engine_file_path, "rb") as f:
                self.engine = runtime.deserialize_cuda_engine(f.read())
            if self.engine is None:
                raise RuntimeError(f"failed to deserialize TensorRT engine: {engine_file_path}")

            self.context = self.engine.create_execution_context()
            self.stream = stream

            # Allocate memory.
            self._allocate_buffers()

            # Initialise the tracker.
            self.tracker = Tracker(max_age=10, min_hits=3, iou_threshold=0.5, conf_threshold=CONF_THRESH)
        except Exception:
            # Do not leave the context pushed above on the stack when
            # construction fails; nobody would ever pop it.
            self.ctx.pop()
            raise

        print("[INFO] YoLo11TRT初始化完成")

    def _allocate_buffers(self):
        """Allocate a page-locked host buffer and a device buffer per binding."""
        self.host_inputs = []
        self.cuda_inputs = []
        self.host_outputs = []
        self.cuda_outputs = []
        self.bindings = []

        for i in range(self.engine.num_bindings):
            binding_name = self.engine.get_binding_name(i)
            shape = self.engine.get_binding_shape(i)
            dtype = trt.nptype(self.engine.get_binding_dtype(i))

            # Element count; fall back to 1 when the volume is not positive.
            size = trt.volume(shape)
            if size <= 0:
                size = 1

            host_mem = cuda.pagelocked_empty(size, dtype)  # host memory
            cuda_mem = cuda.mem_alloc(host_mem.nbytes)     # device memory

            self.bindings.append(int(cuda_mem))

            if self.engine.binding_is_input(i):
                self.input_w = shape[-1]
                self.input_h = shape[-2]
                self.batch_size = shape[0]
                self.host_inputs.append(host_mem)
                self.cuda_inputs.append(cuda_mem)
                print(f"[INFO] 输入绑定 {i}: {binding_name}, 形状: {shape}, 类型: {dtype}")
            else:
                self.host_outputs.append(host_mem)
                self.cuda_outputs.append(cuda_mem)
                print(f"[INFO] 输出绑定 {i}: {binding_name}, 形状: {shape}, 类型: {dtype}")

        # Output length used by infer() to slice the first output buffer.
        # NOTE(review): this is the size of the whole first output buffer;
        # confirm it is the per-image length when batch_size > 1.
        if self.host_outputs:
            self.det_output_length = self.host_outputs[0].shape[0]
        else:
            self.det_output_length = 0

        print(f"[INFO] 缓冲区分配完成,batch_size: {self.batch_size}, det_output_length: {self.det_output_length}")

    def infer(self, raw_image_generator):
        """Run detection and tracking on a batch of BGR images.

        Args:
            raw_image_generator: Iterable of at most ``batch_size`` images.

        Returns:
            ``(frame_result, batch_image_raw, seconds)``.  ``frame_result``
            is a dict of parallel lists ``box_list``, ``score_list``,
            ``class_id_list`` and ``track_id_list``.

        Raises:
            IndexError: If more than ``batch_size`` images are supplied.
        """
        # Make this instance's CUDA context current on the calling thread.
        self.ctx.push()

        try:
            start = time.time()
            stream = self.stream
            context = self.context
            host_inputs = self.host_inputs
            cuda_inputs = self.cuda_inputs
            host_outputs = self.host_outputs
            cuda_outputs = self.cuda_outputs
            bindings = self.bindings

            batch_image_raw = []
            batch_origin_h = []
            batch_origin_w = []
            # Zero-filled float32: slots without an image hold defined data
            # instead of uninitialised memory, and preprocess_image()
            # already yields float32.
            batch_input_image = np.zeros(
                (self.batch_size, 3, self.input_h, self.input_w), dtype=np.float32
            )
            for i, image_raw in enumerate(raw_image_generator):
                if i >= self.batch_size:
                    raise IndexError(
                        f"received more than batch_size={self.batch_size} images"
                    )
                input_image, image_raw, origin_h, origin_w = self.preprocess_image(image_raw)
                batch_image_raw.append(image_raw)
                batch_origin_h.append(origin_h)
                batch_origin_w.append(origin_w)
                np.copyto(batch_input_image[i], input_image)
            batch_input_image = np.ascontiguousarray(batch_input_image)

            np.copyto(host_inputs[0], batch_input_image.ravel())

            cuda.memcpy_htod_async(cuda_inputs[0], host_inputs[0], stream)
            context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
            cuda.memcpy_dtoh_async(host_outputs[0], cuda_outputs[0], stream)
            stream.synchronize()

            output = host_outputs[0]
            frame_result = {
                "box_list": [],
                "score_list": [],
                "class_id_list": [],
                "track_id_list": [],
            }
            # Only post-process the images actually supplied; iterating over
            # batch_size would index past the per-image size lists.
            for i in range(len(batch_image_raw)):
                result_boxes, result_scores, result_classid = self.post_process(
                    output[i * self.det_output_length: (i + 1) * self.det_output_length],
                    batch_origin_h[i], batch_origin_w[i]
                )

                # Build the detection records for the tracker.
                detections = [
                    (float(box[0]), float(box[1]), float(box[2]), float(box[3]),
                     float(result_scores[j]), result_classid[j])
                    for j, box in enumerate(result_boxes)
                ]

                # NOTE: a single tracker is shared by every image in the batch.
                tracked_detections = self.tracker.update(detections)

                for det in tracked_detections:
                    frame_result["box_list"].append(det[:4])
                    frame_result["score_list"].append(det[4])
                    frame_result["class_id_list"].append(det[5])
                    # -1 marks a record without a track id.
                    frame_result["track_id_list"].append(det[6] if len(det) > 6 else -1)

            end = time.time()
            return frame_result, batch_image_raw, end - start

        except Exception as e:
            print(f"[ERROR] 推理过程中出错: {e}")
            raise
        finally:
            # Always pop the context pushed at the top.
            self.ctx.pop()

    def destroy(self):
        """Release CUDA / TensorRT resources; errors are printed, not raised."""
        print("[INFO] 开始清理CUDA资源...")
        try:
            # Empty the context stack first.
            self._clear_cuda_context_stack()

            # Release order: memory -> stream -> execution context -> engine.
            self._release_all_resources()
            self._pop_cuda_context()
            print("[INFO] CUDA资源清理完成")
        except Exception as e:
            print(f"[ERROR] 资源清理异常: {e}")

    def _clear_cuda_context_stack(self, max_attempts=10):
        """Pop the current CUDA context repeatedly, at most ``max_attempts`` times.

        NOTE(review): this pops whatever context is current, which may
        include contexts this instance did not create - confirm intended.
        """
        try:
            attempts = 0
            while attempts < max_attempts:
                try:
                    current_ctx = cuda.Context.get_current()
                    if current_ctx:
                        current_ctx.pop()
                        print(f"[INFO] 弹出当前上下文 (尝试 {attempts + 1}/{max_attempts})")
                    else:
                        break
                except cuda.LogicError:
                    break
                attempts += 1
            print("[INFO] CUDA上下文栈已清空")
        except Exception as e:
            print(f"[WARNING] 上下文栈清理异常: {e}")

    def _release_all_resources(self):
        """Free device buffers, then drop the stream, execution context and engine."""
        # 1. Free all device memory (outputs first, then inputs).
        for label, attr in (("输出", "cuda_outputs"), ("输入", "cuda_inputs")):
            for i, cuda_mem in enumerate(getattr(self, attr, None) or []):
                try:
                    if cuda_mem:
                        cuda_mem.free()
                        print(f"[INFO] {label}内存 {i} 已释放")
                except Exception as e:
                    print(f"[WARNING] {label}内存释放异常 {i}: {e}")

        # 2. Drop the CUDA stream.
        if hasattr(self, 'stream') and self.stream:
            try:
                self.stream.synchronize()
                del self.stream
                print("[INFO] CUDA流已销毁")
            except Exception as e:
                print(f"[WARNING] 流销毁异常: {e}")

        # 3. Drop the execution context.
        if hasattr(self, 'context') and self.context:
            try:
                del self.context
                print("[INFO] 执行上下文已销毁")
            except Exception as e:
                print(f"[WARNING] 上下文销毁异常: {e}")

        # 4. Drop the engine.
        if hasattr(self, 'engine') and self.engine:
            try:
                del self.engine
                print("[INFO] TensorRT引擎已销毁")
            except Exception as e:
                print(f"[WARNING] 引擎销毁异常: {e}")

    def _pop_cuda_context(self):
        """Pop this instance's context if it is current, then detach it."""
        if hasattr(self, 'ctx') and self.ctx:
            try:
                # Only pop when our context is the current one.
                if self.ctx == cuda.Context.get_current():
                    self.ctx.pop()
                self.ctx.detach()
                print("[INFO] CUDA上下文已弹出")
            except Exception as e:
                print(f"[WARNING] 上下文弹出异常: {e}")

    def preprocess_image(self, raw_bgr_image):
        """Letterbox a BGR image to the engine input size.

        The image is converted to RGB, resized with its aspect ratio kept,
        padded with grey (128), scaled to [0, 1] and laid out as NCHW.

        Returns:
            ``(image, image_raw, h, w)``: a float32 array of shape
            ``(1, 3, input_h, input_w)``, the untouched input image, and the
            original height and width.
        """
        image_raw = raw_bgr_image
        h, w, c = image_raw.shape
        image = cv2.cvtColor(image_raw, cv2.COLOR_BGR2RGB)
        r_w = self.input_w / w
        r_h = self.input_h / h
        if r_h > r_w:
            # Width is the limiting side: pad top and bottom.
            tw = self.input_w
            th = int(r_w * h)
            tx1 = tx2 = 0
            ty1 = int((self.input_h - th) / 2)
            ty2 = self.input_h - th - ty1
        else:
            # Height is the limiting side: pad left and right.
            tw = int(r_h * w)
            th = self.input_h
            tx1 = int((self.input_w - tw) / 2)
            tx2 = self.input_w - tw - tx1
            ty1 = ty2 = 0
        image = cv2.resize(image, (tw, th))
        image = cv2.copyMakeBorder(image, ty1, ty2, tx1, tx2, cv2.BORDER_CONSTANT, None, (128, 128, 128))
        image = image.astype(np.float32)
        image /= 255.0
        image = np.transpose(image, [2, 0, 1])  # HWC -> CHW
        image = np.expand_dims(image, axis=0)   # CHW -> NCHW
        image = np.ascontiguousarray(image)
        return image, image_raw, h, w

    def xywh2xyxy(self, origin_h, origin_w, x):
        """Map corner boxes from letterboxed input space to original image space.

        Despite the name, the four columns of ``x`` are read as
        ``x1, y1, x2, y2``: the padding offset is subtracted on the padded
        axis and everything is divided by the resize ratio.
        """
        y = np.zeros_like(x)
        r_w = self.input_w / origin_w
        r_h = self.input_h / origin_h
        if r_h > r_w:
            pad = (self.input_h - r_w * origin_h) / 2
            y[:, 0] = x[:, 0]
            y[:, 2] = x[:, 2]
            y[:, 1] = x[:, 1] - pad
            y[:, 3] = x[:, 3] - pad
            y /= r_w
        else:
            pad = (self.input_w - r_h * origin_w) / 2
            y[:, 0] = x[:, 0] - pad
            y[:, 2] = x[:, 2] - pad
            y[:, 1] = x[:, 1]
            y[:, 3] = x[:, 3]
            y /= r_h
        return y

    def post_process(self, output, origin_h, origin_w):
        """Decode one image's raw output into boxes, scores and class ids.

        ``output[0]`` is the detection count; the rest is reshaped into rows
        of ``DET_NUM + SEG_NUM + POSE_NUM + OBB_NUM`` values.

        Returns:
            ``(result_boxes, result_scores, result_classid)``; three empty
            arrays when nothing survives.
        """
        num_values_per_detection = DET_NUM + SEG_NUM + POSE_NUM + OBB_NUM
        num = int(output[0])
        if num <= 0:
            return np.array([]), np.array([]), np.array([])

        pred = np.reshape(output[1:], (-1, num_values_per_detection))[:num, :]
        boxes = self.non_max_suppression(pred, origin_h, origin_w, conf_thres=CONF_THRESH, nms_thres=IOU_THRESHOLD)
        result_boxes = boxes[:, :4] if len(boxes) else np.array([])
        result_scores = boxes[:, 4] if len(boxes) else np.array([])
        result_classid = boxes[:, 5] if len(boxes) else np.array([])
        return result_boxes, result_scores, result_classid

    def bbox_iou(self, box1, box2, x1y1x2y2=True):
        """Return the element-wise IoU of two box arrays.

        Args:
            box1, box2: Arrays of shape ``(n, 4)``; they broadcast together.
            x1y1x2y2: True for corner format, False for centre/size format.

        Widths and heights are computed with a ``+ 1`` (inclusive pixels).
        """
        if not x1y1x2y2:
            b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
            b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
            b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
            b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
        else:
            b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
            b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]
        inter_rect_x1 = np.maximum(b1_x1, b2_x1)
        inter_rect_y1 = np.maximum(b1_y1, b2_y1)
        inter_rect_x2 = np.minimum(b1_x2, b2_x2)
        inter_rect_y2 = np.minimum(b1_y2, b2_y2)
        inter_area = (np.clip(inter_rect_x2 - inter_rect_x1 + 1, 0, None) *
                      np.clip(inter_rect_y2 - inter_rect_y1 + 1, 0, None))
        b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1)
        b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1)
        iou = inter_area / (b1_area + b2_area - inter_area + 1e-16)
        return iou

    def non_max_suppression(self, prediction, origin_h, origin_w, conf_thres=0.5, nms_thres=0.4):
        """Filter by confidence, rescale to the original image, and run per-class NMS.

        Args:
            prediction: Array of detection rows; columns 0-3 are the box,
                column 4 the confidence and column 5 the class id.
            origin_h, origin_w: Size of the original image.
            conf_thres: Rows with a lower confidence are dropped.
            nms_thres: A box is suppressed when its IoU with a kept box of
                the same class is strictly greater than this value.

        Returns:
            The kept rows stacked in descending confidence order, or an
            empty array.
        """
        # Drop low-confidence rows first (boolean indexing copies, so the
        # caller's array is not modified below).
        boxes = prediction[prediction[:, 4] >= conf_thres]

        if len(boxes) == 0:
            return np.array([])

        boxes[:, :4] = self.xywh2xyxy(origin_h, origin_w, boxes[:, :4])
        boxes[:, 0] = np.clip(boxes[:, 0], 0, origin_w - 1)
        boxes[:, 2] = np.clip(boxes[:, 2], 0, origin_w - 1)
        boxes[:, 1] = np.clip(boxes[:, 1], 0, origin_h - 1)
        boxes[:, 3] = np.clip(boxes[:, 3], 0, origin_h - 1)
        boxes = boxes[np.argsort(-boxes[:, 4])]

        # Compare class ids using the column post_process() reads them from.
        # The last column only equals the class id for 6-column rows, so for
        # wider rows it made suppression ignore the class.
        if boxes.shape[1] > self._CLASS_ID_COLUMN:
            label_col = self._CLASS_ID_COLUMN
        else:
            label_col = -1

        keep_boxes = []
        while boxes.shape[0]:
            large_overlap = self.bbox_iou(np.expand_dims(boxes[0, :4], 0), boxes[:, :4]) > nms_thres
            label_match = boxes[0, label_col] == boxes[:, label_col]
            invalid = large_overlap & label_match
            keep_boxes.append(boxes[0])
            boxes = boxes[~invalid]
        return np.stack(keep_boxes, 0) if len(keep_boxes) else np.array([])


class ThreadSafeYoLo11TRT(YoLo11TRT):
    """YoLo11TRT whose ``infer`` calls are serialised by a per-instance lock."""

    def __init__(self, engine_file_path, plugin_library=None):
        """Optionally load a plugin shared library, then build the engine.

        Args:
            engine_file_path: Path of the serialized TensorRT engine.
            plugin_library: Path of a shared library to load with ctypes
                before the engine is built; skipped when falsy.
        """
        if plugin_library:
            self._load_plugin_library(plugin_library)

        # The lock must exist before the base class is initialised.
        self._lock = threading.Lock()
        super().__init__(engine_file_path)

    @staticmethod
    def _load_plugin_library(library_path):
        """Load the plugin library; a failure is logged as a warning only."""
        try:
            ctypes.CDLL(library_path)
            logger.info(f"插件库加载成功: {library_path}")
        except Exception as err:
            logger.warning(f"加载插件库失败: {err}")

    def infer(self, raw_image_generator):
        """Run the base-class inference while holding the instance lock."""
        self._lock.acquire()
        try:
            return super().infer(raw_image_generator)
        finally:
            self._lock.release()


class MultiYoloTrtDetectorTrackId:
    """Runs several TensorRT YOLO models on one frame through a thread pool.

    Each entry of ``model_configs`` produces one ``ThreadSafeYoLo11TRT``;
    per-model settings are kept in parallel lists indexed like ``self.models``.
    """

    def __init__(self, model_configs: List[Dict]):
        """Load one model per config dict; configs that fail to load are skipped.

        Keys read from each config: ``engine_path``, ``so_path``,
        ``cls_map``, ``allowed_classes``, ``cls_index``, ``chinese_label``,
        ``list_func_id``, ``func_id``, ``config_conf``, ``class_names`` and
        ``para_invade_enable``.

        Raises:
            RuntimeError: If no model could be loaded.
        """
        self.models = []
        self.class_maps = []
        self.executor = ThreadPoolExecutor(max_workers=ConfigTrt.MAX_WORKERS)
        self.allowed_classes = []
        self.conf = ConfigTrt.CONF_THRESHOLD
        self.model_cls = []
        self.chinese_label = []
        self.list_func_id = []
        self.func_id = -1
        self.list_class_names = []
        self.list_para_invade_enable = []
        self.model_configs = []

        # Create an independent ThreadSafeYoLo11TRT instance per config.
        for config in model_configs:
            model_path = 'unknown'
            try:
                model_path = config.get('engine_path', '')
                plugin_library = config.get('so_path', '')

                model = ThreadSafeYoLo11TRT(model_path, plugin_library)
                self.models.append(model)

                # Store the per-model settings.
                self.class_maps.append(config.get('cls_map', {}))
                self.allowed_classes.append(config.get('allowed_classes', None))
                self.model_cls.append(config.get('cls_index', True))
                self.chinese_label.append(config.get('chinese_label', {}))
                self.list_func_id.append(config.get('list_func_id', -11))
                # NOTE: func_id and conf are single values shared by all
                # models; the last successfully loaded config wins.
                self.func_id = config.get('func_id', True)
                self.conf = config.get('config_conf', ConfigTrt.CONF_THRESHOLD)
                self.list_class_names.append(config.get('class_names', ['unknown']))
                self.list_para_invade_enable.append(config.get('para_invade_enable', False))
                self.model_configs.append(config)

                logger.info(f"成功加载模型: {model_path}")

            except Exception as e:
                # Report the path that was actually used ('engine_path').
                logger.error(f"加载模型失败 {model_path}: {e}")
                continue

        if not self.models:
            # Nothing will ever be submitted; release the pool before failing.
            self.executor.shutdown(wait=False)
            raise RuntimeError("没有可用的模型实例")

        logger.info(f"多模型检测器初始化完成,共加载 {len(self.models)} 个模型")

    def _predict_single_model(self, model_idx: int, image: np.ndarray) -> Tuple[List[DetectionResult], Any, Any]:
        """Run one model on ``image`` and convert its output.

        Returns:
            ``(detections, detection_result_list, model_para)``.  On any
            error the triple is ``([], <empty DetectionResultList>, {})``.
        """
        try:
            if model_idx >= len(self.models):
                return [], DetectionResultList([], [], [], [], []), {}

            model = self.models[model_idx]
            model_class_names = self.list_class_names[model_idx]

            # Run inference (serialised per model by the model's own lock).
            frame_result, _, _ = model.infer([image])

            detections = []
            detection_result_list = DetectionResultList([], [], [], [], [])

            # Model parameters handed back to the caller; the key names are
            # part of the interface and are kept verbatim.
            model_para = {
                "cls_map": self.class_maps[model_idx],
                "model_chinese_labe": self.chinese_label[model_idx],
                "model_cls_index": self.model_cls[model_idx],
                "model_list_func_id": self.list_func_id[model_idx],
                "func_id": self.func_id,
                "model_class_names": model_class_names,
                "para_invade_enable": self.list_para_invade_enable[model_idx]
            }

            # Parse the detection results.
            box_list = frame_result.get("box_list", [])
            score_list = frame_result.get("score_list", [])
            class_id_list = frame_result.get("class_id_list", [])
            track_id_list = frame_result.get("track_id_list", [])

            for i, box in enumerate(box_list):
                try:
                    if i >= len(score_list) or i >= len(class_id_list):
                        continue

                    score = score_list[i]
                    class_id = int(class_id_list[i])  # make sure class_id is an int
                    track_id = track_id_list[i] if i < len(track_id_list) else -1

                    # Drop low-confidence detections.
                    if score < self.conf:
                        continue

                    # Look the class name up safely; a negative id must not
                    # wrap around to the end of the list.
                    class_name = "unknown"
                    if isinstance(model_class_names, list):
                        if 0 <= class_id < len(model_class_names):
                            class_name = model_class_names[class_id]
                    elif isinstance(model_class_names, dict):
                        class_name = model_class_names.get(class_id, "unknown")

                    detections.append(DetectionResult(
                        bbox=box,
                        class_id=class_id,
                        class_name=class_name,
                        confidence=score,
                        track_id=track_id
                    ))

                    # Fill the DetectionResultList.
                    detection_result_list.boxes.append(box)
                    detection_result_list.clss.append(class_id)
                    # NOTE(review): clss_name receives the stringified class
                    # id, not class_name - confirm what consumers expect.
                    detection_result_list.clss_name.append(str(class_id))
                    detection_result_list.confs.append(score)
                    detection_result_list.track_ids.append(track_id)

                except Exception as e:
                    logger.error(f"处理检测结果时出错: {e}")
                    continue

            logger.debug(f"模型 {model_idx} 检测到 {len(detections)} 个目标")
            return detections, detection_result_list, model_para

        except Exception as e:
            logger.error(f"模型 {model_idx} 预测错误: {e}")
            return [], DetectionResultList([], [], [], [], []), {}

    async def predict(self, frame: np.ndarray) -> Tuple[List[DetectionResult], List, List]:
        """Run every model on ``frame`` concurrently and merge the results.

        Returns:
            ``(all_detections, all_detection_lists, all_model_paras)``: the
            flat list of detections from all models, plus one
            DetectionResultList and one parameter dict per model.
        """
        if not self.models:
            return [], [], []

        loop = asyncio.get_running_loop()

        # One task per model, each working on its own copy of the frame.
        futures = [
            loop.run_in_executor(
                self.executor,
                self._predict_single_model,
                model_idx,
                frame.copy()
            )
            for model_idx in range(len(self.models))
        ]

        # Wait for all tasks; exceptions come back as results.
        results = await asyncio.gather(*futures, return_exceptions=True)

        # Merge the results.
        all_detections = []
        all_detection_lists = []
        all_model_paras = []

        for result in results:
            if isinstance(result, Exception):
                logger.error(f"模型预测任务异常: {result}")
                continue

            detections, detection_list, model_para = result
            all_detections.extend(detections)
            all_detection_lists.append(detection_list)
            all_model_paras.append(model_para)

        logger.info(f"所有模型共检测到 {len(all_detections)} 个目标")
        return all_detections, all_detection_lists, all_model_paras

    def predict_sync(self, frame: np.ndarray) -> Tuple[List[DetectionResult], List, List]:
        """Blocking wrapper around ``predict``.

        Must not be called from a thread that is already running an event
        loop.  ``asyncio.run`` creates and closes a fresh loop per call and
        does not leave a closed loop installed as the thread's current loop.
        """
        return asyncio.run(self.predict(frame))

    def destroy(self):
        """Stop the thread pool, then release every model's resources."""
        # Drain the pool first so no worker is still inside infer() while
        # the model it uses is being torn down.
        self.executor.shutdown(wait=True)
        logger.info("线程池已关闭")

        for model in self.models:
            try:
                model.destroy()
                logger.info("模型资源已清理")
            except Exception as e:
                logger.error(f"清理模型资源时出错: {e}")


# ====== Main program ======
def main():
    """Run the multi-model detector on a test video ten times in a row.

    Each pass builds a detector, processes up to ``MAX_TEST_FRAMES`` frames,
    shows the annotated frames, prints a summary and tears everything down
    again.  The function returns early if the video cannot be opened.
    """
    import traceback

    # Example configuration.  The detector reads the engine from
    # 'engine_path' and the plugin library from 'so_path'.
    model_configs = [
        {
            'engine_path': r"D:\project\AI-PYTHON\tensorrtx-master\yolo11\build\Release\build.engine",
            'so_path': r"D:\project\AI-PYTHON\tensorrtx-master\yolo11\build\Release\myplugins.dll",
            'cls_map': {0: 'building'},
            'allowed_classes': None,
            'cls_index': True,
            'chinese_label': {0: '建筑物'},
            'list_func_id': -11,
            'func_id': True,
            'class_names': ['building'],
            'para_invade_enable': False,
            'config_conf': 0.5
        },
    ]

    # Path of the video file.
    video_path = r"E:\yolo-dataset\test_video\shimian_invade\DJI_20250910142510_0001_V.mp4"

    for _ in range(10):
        # Reset per pass so cleanup never touches a detector or capture
        # left over from an earlier pass.
        multi_detector = None
        cap = None
        completed = False

        try:
            multi_detector = MultiYoloTrtDetectorTrackId(model_configs)
            cap = cv2.VideoCapture(video_path)

            if not cap.isOpened():
                print(f"[ERROR] 无法打开视频文件: {video_path}")
                return  # the finally block still releases everything

            # Performance statistics.
            frame_count = 0        # frames since the last FPS refresh
            total_frames = 0       # frames actually processed
            total_infer_time = 0
            fps_smoothed = 0
            last_time = time.time()

            print(f"[INFO] 开始多模型并行检测测试,将在处理 {MAX_TEST_FRAMES} 帧后退出")
            print(f"[INFO] 视频文件: {video_path}")
            print(f"[INFO] 模型数量: {len(multi_detector.models)}")

            while True:
                # Check the limit before reading, so every counted frame is
                # also a processed frame.
                if total_frames >= MAX_TEST_FRAMES:
                    print(f"[INFO] 已达到最大测试帧数 {MAX_TEST_FRAMES},退出测试")
                    break

                ret, frame = cap.read()
                if not ret:
                    print("[INFO] 视频播放完毕")
                    break

                frame_count += 1
                total_frames += 1

                # Run the multi-model detection.
                start_time = time.time()
                detections, _, _ = multi_detector.predict_sync(frame)
                infer_time = time.time() - start_time
                total_infer_time += infer_time

                # Refresh the FPS figure roughly once per second.
                current_time = time.time()
                elapsed = current_time - last_time
                if elapsed >= 1.0:
                    fps_smoothed = frame_count / elapsed
                    frame_count = 0
                    last_time = current_time

                # Draw every detection on a copy of the frame.
                display_frame = frame.copy()
                for detection in detections:
                    plot_one_box(
                        detection.bbox,
                        display_frame,
                        label=f"{detection.class_name}:{detection.confidence:.2f}",
                        track_id=detection.track_id
                    )

                # Overlay the performance information, one line every 30 px.
                overlay_lines = [
                    f"FPS: {fps_smoothed:.2f}",
                    f"Frame Time: {infer_time * 1000:.2f}ms",
                    f"Frame: {total_frames}/{MAX_TEST_FRAMES}",
                    f"Detections: {len(detections)}",
                    f"Models: {len(multi_detector.models)}",
                ]
                for row, text in enumerate(overlay_lines, start=1):
                    cv2.putText(display_frame, text, (10, 30 * row),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

                # Show the result.
                cv2.imshow("Multi-Model YOLO Detection - CUDA Context Test", display_frame)

                # Print progress every 10 frames.
                if total_frames % 10 == 0:
                    avg_infer_time = total_infer_time / total_frames
                    print(
                        f"[INFO] 已处理 {total_frames}/{MAX_TEST_FRAMES} 帧,当前FPS: {fps_smoothed:.2f},平均推理时间: {avg_infer_time * 1000:.2f}ms")

                # Check whether the user asked to quit.
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    print("[INFO] 用户请求退出")
                    break

            # Overall performance.
            if total_frames > 0:
                # Guard against a zero total on very coarse clocks.
                avg_fps = total_frames / total_infer_time if total_infer_time > 0 else 0.0
                avg_infer_time = total_infer_time / total_frames
                print("\n[INFO] 测试完成摘要:")
                print(f"[INFO] - 总处理帧数: {total_frames}")
                print(f"[INFO] - 总耗时: {total_infer_time:.2f}秒")
                print(f"[INFO] - 平均FPS: {avg_fps:.2f}")
                print(f"[INFO] - 平均推理时间: {avg_infer_time * 1000:.2f}ms")
                print(f"[INFO] - 最大测试帧数限制: {MAX_TEST_FRAMES}")
                print(f"[INFO] - 模型数量: {len(multi_detector.models)}")

            completed = True

        except Exception as e:
            print(f"[ERROR] 多模型检测测试失败: {e}")
            traceback.print_exc()

        finally:
            # Release the capture, the windows and the detector on every
            # exit path (success, early return, exception).
            if cap is not None:
                cap.release()
            cv2.destroyAllWindows()

            if multi_detector is not None:
                print("[INFO] 正在清理资源...")
                try:
                    multi_detector.destroy()
                except Exception as e:
                    print(f"[ERROR] 资源清理异常: {e}")
                else:
                    print("[INFO] 资源清理完成")

            if completed:
                print("[INFO] CUDA上下文管理测试结束")


if __name__ == "__main__":
    main()