import ctypes
import os
import random
import sys
import time

import cv2
import numpy as np
import pycuda.autoinit  # noqa: F401
import pycuda.driver as cuda
import tensorrt as trt

CONF_THRESH = 0.5
IOU_THRESHOLD = 0.4
POSE_NUM = 17 * 3
DET_NUM = 6
SEG_NUM = 32
OBB_NUM = 1

# Global variable (the original code referenced `categories` without defining
# it, so it is defined here); adjust to your actual classes.
categories = ["NoHat", "Hat"]


def get_img_path_batches(batch_size, img_dir):
    ret = []
    batch = []
    for root, dirs, files in os.walk(img_dir):
        for name in files:
            if len(batch) == batch_size:
                ret.append(batch)
                batch = []
            batch.append(os.path.join(root, name))
    if len(batch) > 0:
        ret.append(batch)
    return ret


def plot_one_box(x, img, color=None, label=None, line_thickness=None):
    # Line/font thickness scaled to the image size.
    tl = (
        line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1
    )
    color = color or [random.randint(0, 255) for _ in range(3)]
    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
    cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
    if label:
        tf = max(tl - 1, 1)  # font thickness
        t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
        cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA)  # filled label background
        cv2.putText(
            img,
            label,
            (c1[0], c1[1] - 2),
            0,
            tl / 3,
            [225, 255, 255],
            thickness=tf,
            lineType=cv2.LINE_AA,
        )
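
# A minimal sketch (not part of the original script) of how the two helpers
# above can be exercised; "images/" is a placeholder directory, and the batch
# size should normally match the engine's batch size:
#
#     canvas = np.zeros((640, 640, 3), dtype=np.uint8)
#     plot_one_box([100, 100, 300, 300], canvas, label="Hat:0.97")
#     for batch in get_img_path_batches(batch_size=4, img_dir="images/"):
#         print(batch)  # lists of up to 4 image paths each
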
class YoLo11TRT(object):
    def __init__(self, engine_file_path):
        # Push the CUDA context once, here only.
        self.ctx = cuda.Device(0).make_context()
        stream = cuda.Stream()
        TRT_LOGGER = trt.Logger(trt.Logger.INFO)
        runtime = trt.Runtime(TRT_LOGGER)

        with open(engine_file_path, "rb") as f:
            engine = runtime.deserialize_cuda_engine(f.read())
        context = engine.create_execution_context()

        host_inputs = []
        cuda_inputs = []
        host_outputs = []
        cuda_outputs = []
        bindings = []

        for binding in engine:
            print('binding:', binding, engine.get_binding_shape(binding))
            self.batch_size = engine.get_binding_shape(binding)[0]
            size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
            dtype = trt.nptype(engine.get_binding_dtype(binding))
            # Allocate page-locked host memory and the matching device buffer.
            host_mem = cuda.pagelocked_empty(size, dtype)
            cuda_mem = cuda.mem_alloc(host_mem.nbytes)
            bindings.append(int(cuda_mem))
            if engine.binding_is_input(binding):
                self.input_w = engine.get_binding_shape(binding)[-1]
                self.input_h = engine.get_binding_shape(binding)[-2]
                host_inputs.append(host_mem)
                cuda_inputs.append(cuda_mem)
            else:
                host_outputs.append(host_mem)
                cuda_outputs.append(cuda_mem)

        self.stream = stream
        self.context = context
        self.engine = engine
        self.host_inputs = host_inputs
        self.cuda_inputs = cuda_inputs
        self.host_outputs = host_outputs
        self.cuda_outputs = cuda_outputs
        self.bindings = bindings
        self.det_output_length = host_outputs[0].shape[0]

    def infer(self, raw_image_generator, conf=None, cate=None):
        global CONF_THRESH, categories
        if conf is not None:
            CONF_THRESH = conf
        if cate is not None:
            categories = cate
        # self.ctx.push() / self.ctx.pop() were removed here, since the context
        # is pushed once in __init__ and popped once in destroy().
        start = time.time()
        stream = self.stream
        context = self.context
        host_inputs = self.host_inputs
        cuda_inputs = self.cuda_inputs
        host_outputs = self.host_outputs
        cuda_outputs = self.cuda_outputs
        bindings = self.bindings

        batch_image_raw = []
        batch_origin_h = []
        batch_origin_w = []
        batch_input_image = np.empty(
            shape=[self.batch_size, 3, self.input_h, self.input_w], dtype=np.float32
        )
        for i, image_raw in enumerate(raw_image_generator):
            input_image, image_raw, origin_h, origin_w = self.preprocess_image(image_raw)
            batch_image_raw.append(image_raw)
            batch_origin_h.append(origin_h)
            batch_origin_w.append(origin_w)
            np.copyto(batch_input_image[i], input_image)
        batch_input_image = np.ascontiguousarray(batch_input_image)

        # Host -> device copy, execute, device -> host copy, then wait on the stream.
        np.copyto(host_inputs[0], batch_input_image.ravel())
        cuda.memcpy_htod_async(cuda_inputs[0], host_inputs[0], stream)
        context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
        cuda.memcpy_dtoh_async(host_outputs[0], cuda_outputs[0], stream)
        stream.synchronize()
        end = time.time()

        output = host_outputs[0]
        result_box_list = []
        result_scores_list = []
        result_classid_list = []
        for i in range(self.batch_size):
            result_boxes, result_scores, result_classid = self.post_process(
                output[i * self.det_output_length: (i + 1) * self.det_output_length],
                batch_origin_h[i], batch_origin_w[i]
            )
            result_box_list.append(result_boxes.tolist())
            result_scores_list.append(result_scores.tolist())
            result_classid_list.append(result_classid.tolist())
            for j in range(len(result_boxes)):
                box = result_boxes[j]
                # Guard against out-of-range class ids.
                if int(result_classid[j]) < len(categories):
                    plot_one_box(
                        box,
                        batch_image_raw[i],
                        label="{}:{:.2f}".format(
                            categories[int(result_classid[j])], result_scores[j]
                        ),
                    )
        return batch_image_raw, result_box_list, result_scores_list, result_classid_list, end - start

    def destroy(self):
        self.ctx.pop()  # pop the CUDA context once, here only

    def get_raw_image(self, image_path_batch):
        for img_path in image_path_batch:
            yield cv2.imread(img_path)

    def get_raw_image_zeros(self, image_path_batch=None):
        for _ in range(self.batch_size):
            yield np.zeros([self.input_h, self.input_w, 3], dtype=np.uint8)

    def preprocess_image(self, raw_bgr_image):
        image_raw = raw_bgr_image
        h, w, c = image_raw.shape
        image = cv2.cvtColor(image_raw, cv2.COLOR_BGR2RGB)
        # Letterbox: resize by the limiting ratio and pad the other axis with gray.
        r_w = self.input_w / w
        r_h = self.input_h / h
        if r_h > r_w:
            tw = self.input_w
            th = int(r_w * h)
            tx1 = tx2 = 0
            ty1 = int((self.input_h - th) / 2)
            ty2 = self.input_h - th - ty1
        else:
            tw = int(r_h * w)
            th = self.input_h
            tx1 = int((self.input_w - tw) / 2)
            tx2 = self.input_w - tw - tx1
            ty1 = ty2 = 0
        image = cv2.resize(image, (tw, th))
        image = cv2.copyMakeBorder(
            image, ty1, ty2, tx1, tx2, cv2.BORDER_CONSTANT, None, (128, 128, 128)
        )
        image = image.astype(np.float32)
        image /= 255.0
        image = np.transpose(image, [2, 0, 1])  # HWC -> CHW
        image = np.expand_dims(image, axis=0)
        image = np.ascontiguousarray(image)
        return image, image_raw, h, w

    def xywh2xyxy(self, origin_h, origin_w, x):
        # Undo the letterbox offset and scaling to map boxes back to the
        # original image coordinates.
        y = np.zeros_like(x)
        r_w = self.input_w / origin_w
        r_h = self.input_h / origin_h
        if r_h > r_w:
            y[:, 0] = x[:, 0]
            y[:, 2] = x[:, 2]
            y[:, 1] = x[:, 1] - (self.input_h - r_w * origin_h) / 2
            y[:, 3] = x[:, 3] - (self.input_h - r_w * origin_h) / 2
            y /= r_w
        else:
            y[:, 0] = x[:, 0] - (self.input_w - r_h * origin_w) / 2
            y[:, 2] = x[:, 2] - (self.input_w - r_h * origin_w) / 2
            y[:, 1] = x[:, 1]
            y[:, 3] = x[:, 3]
            y /= r_h
        return y

    def post_process(self, output, origin_h, origin_w):
        # output[0] holds the detection count; each detection then carries
        # DET_NUM + SEG_NUM + POSE_NUM + OBB_NUM values.
        num_values_per_detection = DET_NUM + SEG_NUM + POSE_NUM + OBB_NUM
        num = int(output[0])
        pred = np.reshape(output[1:], (-1, num_values_per_detection))[:num, :]
        boxes = self.non_max_suppression(
            pred, origin_h, origin_w, conf_thres=CONF_THRESH, nms_thres=IOU_THRESHOLD
        )
        result_boxes = boxes[:, :4] if len(boxes) else np.array([])
        result_scores = boxes[:, 4] if len(boxes) else np.array([])
        result_classid = boxes[:, 5] if len(boxes) else np.array([])
        return result_boxes, result_scores, result_classid

    def bbox_iou(self, box1, box2, x1y1x2y2=True):
        if not x1y1x2y2:
            # Convert center-x, center-y, w, h to corner coordinates.
            b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
            b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
            b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
            b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
        else:
            b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
            b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]

        inter_rect_x1 = np.maximum(b1_x1, b2_x1)
        inter_rect_y1 = np.maximum(b1_y1, b2_y1)
        inter_rect_x2 = np.minimum(b1_x2, b2_x2)
        inter_rect_y2 = np.minimum(b1_y2, b2_y2)
        inter_area = (np.clip(inter_rect_x2 - inter_rect_x1 + 1, 0, None)
                      * np.clip(inter_rect_y2 - inter_rect_y1 + 1, 0, None))
        b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1)
        b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1)
        iou = inter_area / (b1_area + b2_area - inter_area + 1e-16)
        return iou

    def non_max_suppression(self, prediction, origin_h, origin_w, conf_thres=0.5, nms_thres=0.4):
        # Keep detections above the confidence threshold, map boxes back to the
        # original image, then greedily suppress overlapping same-class boxes.
        boxes = prediction[prediction[:, 4] >= conf_thres]
        boxes[:, :4] = self.xywh2xyxy(origin_h, origin_w, boxes[:, :4])
        boxes[:, 0] = np.clip(boxes[:, 0], 0, origin_w - 1)
        boxes[:, 2] = np.clip(boxes[:, 2], 0, origin_w - 1)
        boxes[:, 1] = np.clip(boxes[:, 1], 0, origin_h - 1)
        boxes[:, 3] = np.clip(boxes[:, 3], 0, origin_h - 1)
        confs = boxes[:, 4]
        boxes = boxes[np.argsort(-confs)]
        keep_boxes = []
        while boxes.shape[0]:
            large_overlap = self.bbox_iou(np.expand_dims(boxes[0, :4], 0), boxes[:, :4]) > nms_thres
            label_match = boxes[0, -1] == boxes[:, -1]
            invalid = large_overlap & label_match
            keep_boxes += [boxes[0]]
            boxes = boxes[~invalid]
        boxes = np.stack(keep_boxes, 0) if len(keep_boxes) else np.array([])
        return boxes
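
# A minimal sketch (not part of the original script) showing how YoLo11TRT
# could be driven over a whole folder using get_img_path_batches and
# get_raw_image defined above; engine_path and img_dir are placeholder
# assumptions supplied by the caller.
def run_folder_demo(engine_path, img_dir):
    wrapper = YoLo11TRT(engine_path)
    try:
        # The helper's batch size must match what the engine was built with.
        for batch in get_img_path_batches(wrapper.batch_size, img_dir):
            images, boxes, scores, classids, cost = wrapper.infer(
                wrapper.get_raw_image(batch)
            )
            print("processed {} image(s) in {:.2f} ms".format(len(batch), cost * 1000))
    finally:
        wrapper.destroy()
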
# ====== Video inference + display + FPS stats (main program, use as-is) ======
# if __name__ == "__main__":
#     # PLUGIN_LIBRARY = r"D:\project\AI-PYTHON\tensorrtx-master\yolo11\build-b\Release\myplugins.dll"
#     # engine_file_path = r"D:\project\AI-PYTHON\tensorrtx-master\yolo11\build-b\Release\build.engine"
#     engine_file_path = r"D:\project\AI-PYTHON\Ai_tottle\engine\gdaq_hat_0926\gdaq_hat_0926.engine"
#     PLUGIN_LIBRARY = r"D:\project\AI-PYTHON\Ai_tottle\engine\gdaq_hat_0926\myplugins.dll"
#     if len(sys.argv) > 1:
#         engine_file_path = sys.argv[1]
#     if len(sys.argv) > 2:
#         PLUGIN_LIBRARY = sys.argv[2]
#
#     ctypes.CDLL(PLUGIN_LIBRARY)
#
#     yolo11_wrapper = YoLo11TRT(engine_file_path)
#
#     try:
#         video_path = r"E:\yolo-dataset\test_video\shimian_invade\DJI_20250910142510_0001_V.mp4"
#         cap = cv2.VideoCapture(video_path)
#
#         if not cap.isOpened():
#             print("[ERROR] Failed to open video file:", video_path)
#             exit(-1)
#
#         frame_count = 0
#         fps_smoothed = 0
#         last_time = time.time()
#         frame_times = []
#
#         print("[INFO] Processing video; press 'q' to quit ...")
#         cost_time = 0
#         count = 0
#         while True:
#             ret, frame = cap.read()
#             if not ret:
#                 print("[INFO] Video finished or read failed")
#                 break
#
#             frame_count += 1
#             start_time = time.time()
#             time_start = time.time_ns()
#             # image_path = r"C:\Users\14867\Downloads\1758382029741-frame.jpg"
#             batch_image_raw, _, _, _, infer_time = yolo11_wrapper.infer([frame])
#             time_end = time.time_ns()
#             single_frame_time = infer_time
#             frame_times.append(single_frame_time * 1000)
#
#             cost_time = cost_time + (time_end - time_start) / 1000000
#             count = count + 1
#             if count > 100:
#                 print(f"single_frame_time: {cost_time / count} ms")
#                 count = 0
#                 cost_time = 0
#
#             detected_frame = batch_image_raw[0]
#
#             current_time = time.time()
#             elapsed = current_time - last_time
#
#             if elapsed >= 1.0:
#                 fps = frame_count / elapsed
#                 fps_smoothed = fps
#                 frame_count = 0
#                 last_time = current_time
#
#             fps_text = f"FPS: {fps_smoothed:.2f}"
#             time_text = f"Frame Time: {single_frame_time * 1000:.2f}ms"
#             cv2.putText(detected_frame, fps_text, (10, 30),
#                         cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
#             cv2.putText(detected_frame, time_text, (10, 60),
#                         cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
#
#             cv2.imshow("YOLO11 Detection", detected_frame)
#
#             if cv2.waitKey(1) & 0xFF == ord('q'):
#                 print("[INFO] User quit")
#                 break
#
#         cap.release()
#         cv2.destroyAllWindows()
#
#     finally:
#         yolo11_wrapper.destroy()
# Single-image test
if __name__ == "__main__":
    # Engine and plugin paths
    engine_file_path = r"D:\project\AI-PYTHON\Ai_tottle\engine\gdaq_hat_0926\gdaq_hat_0926.engine"
    PLUGIN_LIBRARY = r"D:\project\AI-PYTHON\Ai_tottle\engine\gdaq_hat_0926\myplugins.dll"
    # PLUGIN_LIBRARY = r"D:\project\AI-PYTHON\tensorrtx-master\yolo11\build-b\Release\myplugins.dll"
    # engine_file_path = r"D:\project\AI-PYTHON\tensorrtx-master\yolo11\build-b\Release\build.engine"

    if len(sys.argv) > 1:
        engine_file_path = sys.argv[1]
    if len(sys.argv) > 2:
        PLUGIN_LIBRARY = sys.argv[2]

    # Load the TensorRT plugin library
    ctypes.CDLL(PLUGIN_LIBRARY)

    # Initialize the inference wrapper
    yolo11_wrapper = YoLo11TRT(engine_file_path)

    try:
        # Read the test image (replace with your own path)
        img_path = r"C:\Users\14867\Downloads\1758382029741-frame.jpg"
        frame = cv2.imread(img_path)
        if frame is None:
            print("[ERROR] Failed to read image file:", img_path)
            exit(-1)

        # Run inference
        start_time = time.time()
        batch_image_raw, result_box_list, result_scores_list, result_classid_list, infer_time = \
            yolo11_wrapper.infer([frame])
        detected_frame = batch_image_raw[0]

        # Display the result
        cv2.putText(detected_frame, f"Infer Time: {infer_time * 1000:.2f}ms", (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
        cv2.imshow("YOLO11 Detection", detected_frame)

        # Save the result
        output_path = img_path.replace(".jpg", "_detected.jpg")
        cv2.imwrite(output_path, detected_frame)
        print(f"Result saved to: {output_path}")

        cv2.waitKey(0)
        cv2.destroyAllWindows()
    finally:
        yolo11_wrapper.destroy()
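
# Example invocation (the script name below is a placeholder; the engine and
# plugin defaults above can be overridden via the two positional command-line
# arguments handled in __main__):
#
#     python yolo11_trt_demo.py <engine_file_path> <plugin_dll_path>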