# ai_project_v1/yolo/detect/yolo11_det_pic_trt.py
import ctypes
import os
import random
import sys
import time
import cv2
import numpy as np
import pycuda.autoinit # noqa: F401
import pycuda.driver as cuda
import tensorrt as trt
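# NOTE: this script assumes the TensorRT 8.x binding API (get_binding_shape,
# get_binding_dtype, binding_is_input, max_batch_size, execute_async_v2).
# These calls are deprecated in newer TensorRT releases and removed in TensorRT 10.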
CONF_THRESH = 0.5
IOU_THRESHOLD = 0.4
POSE_NUM = 17 * 3
DET_NUM = 6
SEG_NUM = 32
OBB_NUM = 1
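# Per-detection layout of the tensorrtx-style unified output buffer:
#   DET_NUM  = 6  -> 4 box coordinates, confidence, class id
#   POSE_NUM = 51 -> 17 keypoints x (x, y, visibility)
#   SEG_NUM  = 32 -> mask coefficients
#   OBB_NUM  = 1  -> rotation angle
# For a plain detection engine, only the first DET_NUM values are meaningful.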
# Global variable: the original code referenced `categories` without defining it, so it is defined here.
categories = ["NoHat", "Hat"]  # modify to match your actual classes
def get_img_path_batches(batch_size, img_dir):
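    """Walk `img_dir` and group every file path found into batches of at most `batch_size`."""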
ret = []
batch = []
for root, dirs, files in os.walk(img_dir):
for name in files:
if len(batch) == batch_size:
ret.append(batch)
batch = []
batch.append(os.path.join(root, name))
if len(batch) > 0:
ret.append(batch)
return ret
def plot_one_box(x, img, color=None, label=None, line_thickness=None):
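    """Draw one bounding box `x` = (x1, y1, x2, y2) on `img`, with an optional class `label`."""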
tl = (
line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1
)
color = color or [random.randint(0, 255) for _ in range(3)]
c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
if label:
tf = max(tl - 1, 1)
t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA)
cv2.putText(
img,
label,
(c1[0], c1[1] - 2),
0,
tl / 3,
[225, 255, 255],
thickness=tf,
lineType=cv2.LINE_AA,
)
class YoLo11TRT(object):
def __init__(self, engine_file_path):
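        """Deserialize the TensorRT engine from `engine_file_path` and allocate host/device I/O buffers."""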
        # Push the CUDA context exactly once here; it is popped once in destroy().
self.ctx = cuda.Device(0).make_context()
stream = cuda.Stream()
TRT_LOGGER = trt.Logger(trt.Logger.INFO)
runtime = trt.Runtime(TRT_LOGGER)
with open(engine_file_path, "rb") as f:
engine = runtime.deserialize_cuda_engine(f.read())
context = engine.create_execution_context()
host_inputs = []
cuda_inputs = []
host_outputs = []
cuda_outputs = []
bindings = []
for binding in engine:
print('binding:', binding, engine.get_binding_shape(binding))
self.batch_size = engine.get_binding_shape(binding)[0]
size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
dtype = trt.nptype(engine.get_binding_dtype(binding))
host_mem = cuda.pagelocked_empty(size, dtype)
cuda_mem = cuda.mem_alloc(host_mem.nbytes)
bindings.append(int(cuda_mem))
if engine.binding_is_input(binding):
self.input_w = engine.get_binding_shape(binding)[-1]
self.input_h = engine.get_binding_shape(binding)[-2]
host_inputs.append(host_mem)
cuda_inputs.append(cuda_mem)
else:
host_outputs.append(host_mem)
cuda_outputs.append(cuda_mem)
self.stream = stream
self.context = context
self.engine = engine
self.host_inputs = host_inputs
self.cuda_inputs = cuda_inputs
self.host_outputs = host_outputs
self.cuda_outputs = cuda_outputs
self.bindings = bindings
self.det_output_length = host_outputs[0].shape[0]
    def infer(self, raw_image_generator, conf=None, cate=None):
        """Run batched inference over `raw_image_generator`; returns the annotated images,
        per-image boxes/scores/class ids, and the wall-clock inference time in seconds."""
        global CONF_THRESH, categories
        if conf is not None:
            CONF_THRESH = conf
        # Bug fix: the original tested `categories is not None` (always true), which
        # overwrote `categories` with None whenever `cate` was omitted.
        if cate is not None:
            categories = cate
        # self.ctx.push()/self.ctx.pop() removed here; the context is pushed once in
        # __init__ and popped once in destroy()
start = time.time()
stream = self.stream
context = self.context
host_inputs = self.host_inputs
cuda_inputs = self.cuda_inputs
host_outputs = self.host_outputs
cuda_outputs = self.cuda_outputs
bindings = self.bindings
batch_image_raw = []
batch_origin_h = []
batch_origin_w = []
        # float32 to match the engine input dtype (np.empty defaults to float64)
        batch_input_image = np.empty(shape=[self.batch_size, 3, self.input_h, self.input_w], dtype=np.float32)
for i, image_raw in enumerate(raw_image_generator):
input_image, image_raw, origin_h, origin_w = self.preprocess_image(image_raw)
batch_image_raw.append(image_raw)
batch_origin_h.append(origin_h)
batch_origin_w.append(origin_w)
np.copyto(batch_input_image[i], input_image)
batch_input_image = np.ascontiguousarray(batch_input_image)
np.copyto(host_inputs[0], batch_input_image.ravel())
cuda.memcpy_htod_async(cuda_inputs[0], host_inputs[0], stream)
context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
cuda.memcpy_dtoh_async(host_outputs[0], cuda_outputs[0], stream)
stream.synchronize()
end = time.time()
output = host_outputs[0]
        result_box_list = []
        result_scores_list = []
        result_classid_list = []
        for i in range(self.batch_size):
            result_boxes, result_scores, result_classid = self.post_process(
                output[i * self.det_output_length: (i + 1) * self.det_output_length],
                batch_origin_h[i], batch_origin_w[i]
            )
            result_box_list.append(result_boxes.tolist())
            result_scores_list.append(result_scores.tolist())
            result_classid_list.append(result_classid.tolist())
            for j in range(len(result_boxes)):
                box = result_boxes[j]
                # Bug fix: the original used `<= len(categories)`, which permits an
                # out-of-range class index; `<` is the correct bound.
                if int(result_classid[j]) < len(categories):
                    plot_one_box(
                        box,
                        batch_image_raw[i],
                        label="{}:{:.2f}".format(categories[int(result_classid[j])], result_scores[j]),
                    )
        return batch_image_raw, result_box_list, result_scores_list, result_classid_list, end - start
def destroy(self):
        self.ctx.pop()  # pop the CUDA context exactly once, matching the push in __init__
def get_raw_image(self, image_path_batch):
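        """Yield decoded BGR images, one per path in `image_path_batch`."""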
for img_path in image_path_batch:
yield cv2.imread(img_path)
def get_raw_image_zeros(self, image_path_batch=None):
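        """Yield `batch_size` all-zero dummy images (useful for warm-up runs)."""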
for _ in range(self.batch_size):
yield np.zeros([self.input_h, self.input_w, 3], dtype=np.uint8)
def preprocess_image(self, raw_bgr_image):
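        """Letterbox a BGR image to (input_h, input_w) with gray padding, scale to [0, 1],
        and return an NCHW float32 tensor plus the original image and its height/width."""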
image_raw = raw_bgr_image
h, w, c = image_raw.shape
image = cv2.cvtColor(image_raw, cv2.COLOR_BGR2RGB)
r_w = self.input_w / w
r_h = self.input_h / h
if r_h > r_w:
tw = self.input_w
th = int(r_w * h)
tx1 = tx2 = 0
ty1 = int((self.input_h - th) / 2)
ty2 = self.input_h - th - ty1
else:
tw = int(r_h * w)
th = self.input_h
tx1 = int((self.input_w - tw) / 2)
tx2 = self.input_w - tw - tx1
ty1 = ty2 = 0
image = cv2.resize(image, (tw, th))
image = cv2.copyMakeBorder(image, ty1, ty2, tx1, tx2, cv2.BORDER_CONSTANT, None, (128, 128, 128))
image = image.astype(np.float32)
image /= 255.0
image = np.transpose(image, [2, 0, 1])
image = np.expand_dims(image, axis=0)
image = np.ascontiguousarray(image)
return image, image_raw, h, w
def xywh2xyxy(self, origin_h, origin_w, x):
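        """Map boxes from letterboxed network-input coordinates back to original-image
        coordinates. Despite the name, no center/size conversion happens here: the boxes
        arriving from post_process are already (x1, y1, x2, y2), and this only undoes
        the letterbox padding and scale."""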
y = np.zeros_like(x)
r_w = self.input_w / origin_w
r_h = self.input_h / origin_h
if r_h > r_w:
y[:, 0] = x[:, 0]
y[:, 2] = x[:, 2]
y[:, 1] = x[:, 1] - (self.input_h - r_w * origin_h) / 2
y[:, 3] = x[:, 3] - (self.input_h - r_w * origin_h) / 2
y /= r_w
else:
y[:, 0] = x[:, 0] - (self.input_w - r_h * origin_w) / 2
y[:, 2] = x[:, 2] - (self.input_w - r_h * origin_w) / 2
y[:, 1] = x[:, 1]
y[:, 3] = x[:, 3]
y /= r_h
return y
def post_process(self, output, origin_h, origin_w):
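        """Decode one image's output buffer: output[0] is the detection count, followed by
        rows of `num_values_per_detection` values each; apply NMS and split the survivors
        into boxes, scores, and class ids."""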
num_values_per_detection = DET_NUM + SEG_NUM + POSE_NUM + OBB_NUM
num = int(output[0])
pred = np.reshape(output[1:], (-1, num_values_per_detection))[:num, :]
boxes = self.non_max_suppression(pred, origin_h, origin_w, conf_thres=CONF_THRESH, nms_thres=IOU_THRESHOLD)
result_boxes = boxes[:, :4] if len(boxes) else np.array([])
result_scores = boxes[:, 4] if len(boxes) else np.array([])
result_classid = boxes[:, 5] if len(boxes) else np.array([])
return result_boxes, result_scores, result_classid
def bbox_iou(self, box1, box2, x1y1x2y2=True):
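        """IoU of `box1` (shape (1, 4), broadcast) against each row of `box2` (shape (n, 4))."""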
if not x1y1x2y2:
b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
else:
b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]
inter_rect_x1 = np.maximum(b1_x1, b2_x1)
inter_rect_y1 = np.maximum(b1_y1, b2_y1)
inter_rect_x2 = np.minimum(b1_x2, b2_x2)
inter_rect_y2 = np.minimum(b1_y2, b2_y2)
inter_area = (np.clip(inter_rect_x2 - inter_rect_x1 + 1, 0, None) *
np.clip(inter_rect_y2 - inter_rect_y1 + 1, 0, None))
b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1)
b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1)
iou = inter_area / (b1_area + b2_area - inter_area + 1e-16)
return iou
def non_max_suppression(self, prediction, origin_h, origin_w, conf_thres=0.5, nms_thres=0.4):
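        """Drop detections below `conf_thres`, map boxes to original-image coordinates,
        then apply greedy per-class NMS at IoU threshold `nms_thres`."""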
boxes = prediction[prediction[:, 4] >= conf_thres]
boxes[:, :4] = self.xywh2xyxy(origin_h, origin_w, boxes[:, :4])
boxes[:, 0] = np.clip(boxes[:, 0], 0, origin_w - 1)
boxes[:, 2] = np.clip(boxes[:, 2], 0, origin_w - 1)
boxes[:, 1] = np.clip(boxes[:, 1], 0, origin_h - 1)
boxes[:, 3] = np.clip(boxes[:, 3], 0, origin_h - 1)
confs = boxes[:, 4]
boxes = boxes[np.argsort(-confs)]
keep_boxes = []
while boxes.shape[0]:
large_overlap = self.bbox_iou(np.expand_dims(boxes[0, :4], 0), boxes[:, :4]) > nms_thres
            # Class id is column 5 in the DET_NUM layout (matching result_classid = boxes[:, 5]);
            # the original compared the last column, which is pose/OBB padding, making NMS class-agnostic.
            label_match = boxes[0, 5] == boxes[:, 5]
invalid = large_overlap & label_match
keep_boxes += [boxes[0]]
boxes = boxes[~invalid]
boxes = np.stack(keep_boxes, 0) if len(keep_boxes) else np.array([])
return boxes
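# A minimal batch-inference sketch over an image folder, using the helpers above.
# The paths are placeholders; note that get_img_path_batches() may yield a short
# final batch, which a fixed-batch engine would need padded or skipped:
#
#   yolo = YoLo11TRT(r"path\to\model.engine")
#   try:
#       for batch in get_img_path_batches(yolo.batch_size, r"path\to\images"):
#           if len(batch) != yolo.batch_size:
#               continue
#           imgs, boxes, scores, classids, t = yolo.infer(yolo.get_raw_image(batch))
#           print(batch, f"{t * 1000:.2f} ms")
#   finally:
#       yolo.destroy()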
# ====== Video inference + display + FPS statistics main program (unchanged; uncomment to use) ======
# if __name__ == "__main__":
# # PLUGIN_LIBRARY = r"D:\project\AI-PYTHON\tensorrtx-master\yolo11\build-b\Release\myplugins.dll"
# # engine_file_path = r"D:\project\AI-PYTHON\tensorrtx-master\yolo11\build-b\Release\build.engine"
# engine_file_path = r"D:\project\AI-PYTHON\Ai_tottle\engine\gdaq_hat_0926\gdaq_hat_0926.engine"
# PLUGIN_LIBRARY = r"D:\project\AI-PYTHON\Ai_tottle\engine\gdaq_hat_0926\myplugins.dll"
# if len(sys.argv) > 1:
# engine_file_path = sys.argv[1]
# if len(sys.argv) > 2:
# PLUGIN_LIBRARY = sys.argv[2]
#
# ctypes.CDLL(PLUGIN_LIBRARY)
#
# yolo11_wrapper = YoLo11TRT(engine_file_path)
#
# try:
# video_path = r"E:\yolo-dataset\test_video\shimian_invade\DJI_20250910142510_0001_V.mp4"
# cap = cv2.VideoCapture(video_path)
#
# if not cap.isOpened():
# print("[ERROR] 无法打开视频文件:", video_path)
# exit(-1)
#
# frame_count = 0
# fps_smoothed = 0
# last_time = time.time()
# frame_times = []
#
# print("[INFO] 开始处理视频,按 'q' 键退出 ...")
# cost_time=0
# count=0
# while True:
# ret, frame = cap.read()
# if not ret:
# print("[INFO] 视频播放完毕或读取失败")
# break
#
# frame_count += 1
# start_time = time.time()
#
# time_start=time.time_ns()
#
#             batch_image_raw, _, _, _, infer_time = yolo11_wrapper.infer([frame])
# time_end = time.time_ns()
# single_frame_time = infer_time
# frame_times.append(single_frame_time * 1000)
#
# cost_time=cost_time+(time_end-time_start)/1000000
# count=count+1
# if count>100:
# print(f"single_frame_time: {cost_time/count} ms")
# count=0
# cost_time=0
#
# detected_frame = batch_image_raw[0]
#
# current_time = time.time()
# elapsed = current_time - last_time
#
# if elapsed >= 1.0:
# fps = frame_count / elapsed
# fps_smoothed = fps
# frame_count = 0
# last_time = current_time
#
# fps_text = f"FPS: {fps_smoothed:.2f}"
# time_text = f"Frame Time: {single_frame_time * 1000:.2f}ms"
# cv2.putText(detected_frame, fps_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
# cv2.putText(detected_frame, time_text, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
#
# cv2.imshow("YOLO11 Detection", detected_frame)
#
# if cv2.waitKey(1) & 0xFF == ord('q'):
# print("[INFO] 用户退出")
# break
#
# cap.release()
# cv2.destroyAllWindows()
#
# finally:
# yolo11_wrapper.destroy()
# Single-image test code
if __name__ == "__main__":
    # Set engine and plugin paths
engine_file_path = r"D:\project\AI-PYTHON\Ai_tottle\engine\gdaq_hat_0926\gdaq_hat_0926.engine"
PLUGIN_LIBRARY = r"D:\project\AI-PYTHON\Ai_tottle\engine\gdaq_hat_0926\myplugins.dll"
# PLUGIN_LIBRARY = r"D:\project\AI-PYTHON\tensorrtx-master\yolo11\build-b\Release\myplugins.dll"
# engine_file_path = r"D:\project\AI-PYTHON\tensorrtx-master\yolo11\build-b\Release\build.engine"
if len(sys.argv) > 1:
engine_file_path = sys.argv[1]
if len(sys.argv) > 2:
PLUGIN_LIBRARY = sys.argv[2]
    # Load the custom-plugin library required by the engine
ctypes.CDLL(PLUGIN_LIBRARY)
    # Initialize the inference wrapper
yolo11_wrapper = YoLo11TRT(engine_file_path)
try:
        # Read the test image
        img_path = r"C:\Users\14867\Downloads\1758382029741-frame.jpg"  # replace with your test image path
frame = cv2.imread(img_path)
if frame is None:
print("[ERROR] 无法读取图片文件:", img_path)
exit(-1)
        # Run inference (the returned infer_time already covers the full forward pass)
        batch_image_raw, result_box_list, result_scores_list, result_classid_list, infer_time = yolo11_wrapper.infer([frame])
detected_frame = batch_image_raw[0]
        # Display the result
cv2.putText(detected_frame, f"Infer Time: {infer_time * 1000:.2f}ms",
(10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
cv2.imshow("YOLO11 Detection", detected_frame)
        # Save the result
output_path = img_path.replace(".jpg", "_detected.jpg")
cv2.imwrite(output_path, detected_frame)
print(f"结果已保存到: {output_path}")
cv2.waitKey(0)
cv2.destroyAllWindows()
finally:
yolo11_wrapper.destroy()