# ai_project_v1/yolo/detect/yolo11_det_pic_trt.py


import ctypes
import os
import random
import sys
import time
import cv2
import numpy as np
import pycuda.autoinit # noqa: F401
import pycuda.driver as cuda
import tensorrt as trt
CONF_THRESH = 0.5
IOU_THRESHOLD = 0.4
POSE_NUM = 17 * 3
DET_NUM = 6
SEG_NUM = 32
OBB_NUM = 1
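
# Layout note (this matches how post_process() slices the output below): each
# candidate detection occupies DET_NUM + SEG_NUM + POSE_NUM + OBB_NUM = 90 floats,
# and only the first DET_NUM values (cx, cy, w, h, confidence, class_id) are used
# here; the seg/pose/obb slots exist because the tensorrtx-style engine emits one
# unified buffer for all task heads.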
# Global variable (the original code used `categories` without defining it; defined here).
categories = ["NoHat", "Hat"]  # adjust to match your actual classes

def get_img_path_batches(batch_size, img_dir):
    ret = []
    batch = []
    for root, dirs, files in os.walk(img_dir):
        for name in files:
            if len(batch) == batch_size:
                ret.append(batch)
                batch = []
            batch.append(os.path.join(root, name))
    if len(batch) > 0:
        ret.append(batch)
    return ret
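
# Illustrative usage (hypothetical paths): with batch_size=2 and a folder holding
# a.jpg, b.jpg, c.jpg, get_img_path_batches(2, "samples") yields
# [["samples/a.jpg", "samples/b.jpg"], ["samples/c.jpg"]]; ordering follows os.walk.
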
def plot_one_box(x, img, color=None, label=None, line_thickness=None):
    tl = (
        line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1
    )  # line/font thickness
    color = color or [random.randint(0, 255) for _ in range(3)]
    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
    cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
    if label:
        tf = max(tl - 1, 1)  # font thickness
        t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
        cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA)  # filled label background
        cv2.putText(
            img,
            label,
            (c1[0], c1[1] - 2),
            0,
            tl / 3,
            [225, 255, 255],
            thickness=tf,
            lineType=cv2.LINE_AA,
        )

class YoLo11TRT(object):
    def __init__(self, engine_file_path):
        # Push the CUDA context exactly once, here (popped once in destroy()).
        self.ctx = cuda.Device(0).make_context()
        stream = cuda.Stream()
        TRT_LOGGER = trt.Logger(trt.Logger.INFO)
        runtime = trt.Runtime(TRT_LOGGER)
        with open(engine_file_path, "rb") as f:
            engine = runtime.deserialize_cuda_engine(f.read())
        context = engine.create_execution_context()
        host_inputs = []
        cuda_inputs = []
        host_outputs = []
        cuda_outputs = []
        bindings = []
        for binding in engine:
            print('binding:', binding, engine.get_binding_shape(binding))
            self.batch_size = engine.get_binding_shape(binding)[0]
            size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
            dtype = trt.nptype(engine.get_binding_dtype(binding))
            # Allocate page-locked host memory and the matching device buffer.
            host_mem = cuda.pagelocked_empty(size, dtype)
            cuda_mem = cuda.mem_alloc(host_mem.nbytes)
            bindings.append(int(cuda_mem))
            if engine.binding_is_input(binding):
                self.input_w = engine.get_binding_shape(binding)[-1]
                self.input_h = engine.get_binding_shape(binding)[-2]
                host_inputs.append(host_mem)
                cuda_inputs.append(cuda_mem)
            else:
                host_outputs.append(host_mem)
                cuda_outputs.append(cuda_mem)
        self.stream = stream
        self.context = context
        self.engine = engine
        self.host_inputs = host_inputs
        self.cuda_inputs = cuda_inputs
        self.host_outputs = host_outputs
        self.cuda_outputs = cuda_outputs
        self.bindings = bindings
        self.det_output_length = host_outputs[0].shape[0]

    def infer(self, raw_image_generator, conf=None, cate=None):
        global CONF_THRESH, categories
        if conf is not None:
            CONF_THRESH = conf
        if cate is not None:  # fixed: was `if categories is not None`, which could overwrite categories with None
            categories = cate
        # ctx.push()/ctx.pop() removed here: handled once in __init__ and destroy().
        start = time.time()
        stream = self.stream
        context = self.context
        host_inputs = self.host_inputs
        cuda_inputs = self.cuda_inputs
        host_outputs = self.host_outputs
        cuda_outputs = self.cuda_outputs
        bindings = self.bindings
        batch_image_raw = []
        batch_origin_h = []
        batch_origin_w = []
        batch_input_image = np.empty(shape=[self.batch_size, 3, self.input_h, self.input_w])
        for i, image_raw in enumerate(raw_image_generator):
            input_image, image_raw, origin_h, origin_w = self.preprocess_image(image_raw)
            batch_image_raw.append(image_raw)
            batch_origin_h.append(origin_h)
            batch_origin_w.append(origin_w)
            np.copyto(batch_input_image[i], input_image)
        batch_input_image = np.ascontiguousarray(batch_input_image)
        # Host -> device copy, async execution, device -> host copy, then sync.
        np.copyto(host_inputs[0], batch_input_image.ravel())
        cuda.memcpy_htod_async(cuda_inputs[0], host_inputs[0], stream)
        context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
        cuda.memcpy_dtoh_async(host_outputs[0], cuda_outputs[0], stream)
        stream.synchronize()
        end = time.time()
        output = host_outputs[0]
        result_box_list = []
        result_scores_list = []
        result_classid_list = []
        for i in range(self.batch_size):
            result_boxes, result_scores, result_classid = self.post_process(
                output[i * self.det_output_length: (i + 1) * self.det_output_length],
                batch_origin_h[i], batch_origin_w[i]
            )
            result_box_list.append(result_boxes.tolist())
            result_scores_list.append(result_scores.tolist())
            result_classid_list.append(result_classid.tolist())
            for j in range(len(result_boxes)):
                box = result_boxes[j]
                if int(result_classid[j]) < len(categories):  # fixed: was <=, which could index past the end
                    plot_one_box(
                        box,
                        batch_image_raw[i],
                        label="{}:{:.2f}".format(categories[int(result_classid[j])], result_scores[j])
                    )
        return batch_image_raw, result_box_list, result_scores_list, result_classid_list, end - start

    def destroy(self):
        self.ctx.pop()  # pop exactly once, here (pushed once in __init__)
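
    # Convenience sketch, not part of the original API: context-manager support so
    # callers can write `with YoLo11TRT(path) as det: ...` instead of pairing the
    # constructor with an explicit destroy() in a finally block.
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.destroy()
        return False  # do not swallow exceptions
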
    def get_raw_image(self, image_path_batch):
        for img_path in image_path_batch:
            yield cv2.imread(img_path)

    def get_raw_image_zeros(self, image_path_batch=None):
        # Dummy all-zero frames, useful for warm-up runs.
        for _ in range(self.batch_size):
            yield np.zeros([self.input_h, self.input_w, 3], dtype=np.uint8)

    def preprocess_image(self, raw_bgr_image):
        image_raw = raw_bgr_image
        h, w, c = image_raw.shape
        image = cv2.cvtColor(image_raw, cv2.COLOR_BGR2RGB)
        # Letterbox: scale by the smaller ratio, then pad the short side with gray.
        r_w = self.input_w / w
        r_h = self.input_h / h
        if r_h > r_w:
            tw = self.input_w
            th = int(r_w * h)
            tx1 = tx2 = 0
            ty1 = int((self.input_h - th) / 2)
            ty2 = self.input_h - th - ty1
        else:
            tw = int(r_h * w)
            th = self.input_h
            tx1 = int((self.input_w - tw) / 2)
            tx2 = self.input_w - tw - tx1
            ty1 = ty2 = 0
        image = cv2.resize(image, (tw, th))
        image = cv2.copyMakeBorder(image, ty1, ty2, tx1, tx2, cv2.BORDER_CONSTANT, None, (128, 128, 128))
        # HWC uint8 [0,255] -> normalized NCHW float32.
        image = image.astype(np.float32)
        image /= 255.0
        image = np.transpose(image, [2, 0, 1])
        image = np.expand_dims(image, axis=0)
        image = np.ascontiguousarray(image)
        return image, image_raw, h, w
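
    # Worked example (illustrative numbers): a 1920x1080 frame into a 640x640 input
    # gives r_w = 1/3 and r_h ~= 0.593; since r_h > r_w, the frame is resized to
    # 640x360 and padded with 140 gray rows on top and bottom.
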
    def xywh2xyxy(self, origin_h, origin_w, x):
        # Map boxes from letterboxed input coordinates back to the original frame:
        # undo the padding offset on the padded axis, then undo the resize ratio.
        y = np.zeros_like(x)
        r_w = self.input_w / origin_w
        r_h = self.input_h / origin_h
        if r_h > r_w:
            y[:, 0] = x[:, 0]
            y[:, 2] = x[:, 2]
            y[:, 1] = x[:, 1] - (self.input_h - r_w * origin_h) / 2
            y[:, 3] = x[:, 3] - (self.input_h - r_w * origin_h) / 2
            y /= r_w
        else:
            y[:, 0] = x[:, 0] - (self.input_w - r_h * origin_w) / 2
            y[:, 2] = x[:, 2] - (self.input_w - r_h * origin_w) / 2
            y[:, 1] = x[:, 1]
            y[:, 3] = x[:, 3]
            y /= r_h
        return y
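
    # Continuing the 1920x1080 example above: a y-coordinate of 140 in the padded
    # 640x640 space maps to (140 - 140) / (1/3) = 0 in the original frame.
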
    def post_process(self, output, origin_h, origin_w):
        # output[0] holds the detection count; the remainder is a flat array of
        # num_values_per_detection floats per candidate (see the layout note above).
        num_values_per_detection = DET_NUM + SEG_NUM + POSE_NUM + OBB_NUM
        num = int(output[0])
        pred = np.reshape(output[1:], (-1, num_values_per_detection))[:num, :]
        boxes = self.non_max_suppression(pred, origin_h, origin_w, conf_thres=CONF_THRESH, nms_thres=IOU_THRESHOLD)
        result_boxes = boxes[:, :4] if len(boxes) else np.array([])
        result_scores = boxes[:, 4] if len(boxes) else np.array([])
        result_classid = boxes[:, 5] if len(boxes) else np.array([])
        return result_boxes, result_scores, result_classid

    def bbox_iou(self, box1, box2, x1y1x2y2=True):
        if not x1y1x2y2:
            b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
            b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
            b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
            b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
        else:
            b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
            b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]
        inter_rect_x1 = np.maximum(b1_x1, b2_x1)
        inter_rect_y1 = np.maximum(b1_y1, b2_y1)
        inter_rect_x2 = np.minimum(b1_x2, b2_x2)
        inter_rect_y2 = np.minimum(b1_y2, b2_y2)
        inter_area = (np.clip(inter_rect_x2 - inter_rect_x1 + 1, 0, None) *
                      np.clip(inter_rect_y2 - inter_rect_y1 + 1, 0, None))
        b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1)
        b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1)
        iou = inter_area / (b1_area + b2_area - inter_area + 1e-16)
        return iou
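
    # Sanity sketch (illustrative, not part of the pipeline): identical
    # corner-format boxes give IoU ~= 1.0, e.g.
    #     b = np.array([[0.0, 0.0, 10.0, 10.0]])
    #     self.bbox_iou(b, b)  # -> array([1.])
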
    def non_max_suppression(self, prediction, origin_h, origin_w, conf_thres=0.5, nms_thres=0.4):
        boxes = prediction[prediction[:, 4] >= conf_thres]
        boxes[:, :4] = self.xywh2xyxy(origin_h, origin_w, boxes[:, :4])
        boxes[:, 0] = np.clip(boxes[:, 0], 0, origin_w - 1)
        boxes[:, 2] = np.clip(boxes[:, 2], 0, origin_w - 1)
        boxes[:, 1] = np.clip(boxes[:, 1], 0, origin_h - 1)
        boxes[:, 3] = np.clip(boxes[:, 3], 0, origin_h - 1)
        confs = boxes[:, 4]
        boxes = boxes[np.argsort(-confs)]
        keep_boxes = []
        while boxes.shape[0]:
            # Suppress lower-scoring boxes of the same class that overlap the top box.
            large_overlap = self.bbox_iou(np.expand_dims(boxes[0, :4], 0), boxes[:, :4]) > nms_thres
            label_match = boxes[0, 5] == boxes[:, 5]  # fixed: column 5 is class_id; column -1 is the OBB slot
            invalid = large_overlap & label_match
            keep_boxes += [boxes[0]]
            boxes = boxes[~invalid]
        boxes = np.stack(keep_boxes, 0) if len(keep_boxes) else np.array([])
        return boxes

# ====== Video inference + display + FPS-stats main program (kept as reference, unchanged) ======
# if __name__ == "__main__":
#     # PLUGIN_LIBRARY = r"D:\project\AI-PYTHON\tensorrtx-master\yolo11\build-b\Release\myplugins.dll"
#     # engine_file_path = r"D:\project\AI-PYTHON\tensorrtx-master\yolo11\build-b\Release\build.engine"
#     engine_file_path = r"D:\project\AI-PYTHON\Ai_tottle\engine\gdaq_hat_0926\gdaq_hat_0926.engine"
#     PLUGIN_LIBRARY = r"D:\project\AI-PYTHON\Ai_tottle\engine\gdaq_hat_0926\myplugins.dll"
#     if len(sys.argv) > 1:
#         engine_file_path = sys.argv[1]
#     if len(sys.argv) > 2:
#         PLUGIN_LIBRARY = sys.argv[2]
#
#     ctypes.CDLL(PLUGIN_LIBRARY)
#
#     yolo11_wrapper = YoLo11TRT(engine_file_path)
#
#     try:
#         video_path = r"E:\yolo-dataset\test_video\shimian_invade\DJI_20250910142510_0001_V.mp4"
#         cap = cv2.VideoCapture(video_path)
#
#         if not cap.isOpened():
#             print("[ERROR] Failed to open video file:", video_path)
#             exit(-1)
#
#         frame_count = 0
#         fps_smoothed = 0
#         last_time = time.time()
#         frame_times = []
#
#         print("[INFO] Processing video, press 'q' to quit ...")
#         cost_time = 0
#         count = 0
#         while True:
#             ret, frame = cap.read()
#             if not ret:
#                 print("[INFO] Video finished or frame read failed")
#                 break
#
#             frame_count += 1
#             time_start = time.time_ns()
#
#             # fixed: infer() returns a 5-tuple, not (images, time)
#             batch_image_raw, _, _, _, infer_time = yolo11_wrapper.infer([frame])
#             time_end = time.time_ns()
#             single_frame_time = infer_time
#             frame_times.append(single_frame_time * 1000)
#
#             cost_time = cost_time + (time_end - time_start) / 1000000
#             count = count + 1
#             if count > 100:
#                 print(f"single_frame_time: {cost_time / count} ms")
#                 count = 0
#                 cost_time = 0
#
#             detected_frame = batch_image_raw[0]
#
#             current_time = time.time()
#             elapsed = current_time - last_time
#
#             if elapsed >= 1.0:
#                 fps = frame_count / elapsed
#                 fps_smoothed = fps
#                 frame_count = 0
#                 last_time = current_time
#
#             fps_text = f"FPS: {fps_smoothed:.2f}"
#             time_text = f"Frame Time: {single_frame_time * 1000:.2f}ms"
#             cv2.putText(detected_frame, fps_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
#             cv2.putText(detected_frame, time_text, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
#
#             cv2.imshow("YOLO11 Detection", detected_frame)
#
#             if cv2.waitKey(1) & 0xFF == ord('q'):
#                 print("[INFO] Quit by user")
#                 break
#
#         cap.release()
#         cv2.destroyAllWindows()
#
#     finally:
#         yolo11_wrapper.destroy()

# Single-image test code
if __name__ == "__main__":
    # Set engine and plugin paths (overridable via argv).
    engine_file_path = r"D:\project\AI-PYTHON\Ai_tottle\engine\gdaq_hat_0926\gdaq_hat_0926.engine"
    PLUGIN_LIBRARY = r"D:\project\AI-PYTHON\Ai_tottle\engine\gdaq_hat_0926\myplugins.dll"
    # PLUGIN_LIBRARY = r"D:\project\AI-PYTHON\tensorrtx-master\yolo11\build-b\Release\myplugins.dll"
    # engine_file_path = r"D:\project\AI-PYTHON\tensorrtx-master\yolo11\build-b\Release\build.engine"
    if len(sys.argv) > 1:
        engine_file_path = sys.argv[1]
    if len(sys.argv) > 2:
        PLUGIN_LIBRARY = sys.argv[2]

    # Load the custom-layer plugin library.
    ctypes.CDLL(PLUGIN_LIBRARY)

    # Initialize the inference wrapper.
    yolo11_wrapper = YoLo11TRT(engine_file_path)
    try:
        # Read the test image.
        img_path = r"C:\Users\14867\Downloads\1758382029741-frame.jpg"  # replace with your test image path
        frame = cv2.imread(img_path)
        if frame is None:
            print("[ERROR] Failed to read image file:", img_path)
            exit(-1)

        # Run inference.
        start_time = time.time()
        batch_image_raw, result_box_list, result_scores_list, result_classid_list, infer_time = yolo11_wrapper.infer([frame])
        detected_frame = batch_image_raw[0]

        # Display the result.
        cv2.putText(detected_frame, f"Infer Time: {infer_time * 1000:.2f}ms",
                    (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
        cv2.imshow("YOLO11 Detection", detected_frame)

        # Save the result.
        output_path = img_path.replace(".jpg", "_detected.jpg")
        cv2.imwrite(output_path, detected_frame)
        print(f"Result saved to: {output_path}")
        cv2.waitKey(0)
        cv2.destroyAllWindows()
    finally:
        yolo11_wrapper.destroy()
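
# Directory-batch usage sketch (illustrative; the folder path is a placeholder,
# and this would belong inside the try block above, before destroy()):
#     for batch in get_img_path_batches(yolo11_wrapper.batch_size, r"D:\samples"):
#         imgs, boxes, scores, ids, t = yolo11_wrapper.infer(
#             yolo11_wrapper.get_raw_image(batch))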