ai_project_v1/yolo/yolo_multi_model.py

167 lines
5.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import threading
import queue
import time
from collections import defaultdict
import numpy as np
from ultralytics import YOLO
class MultiResults:
    """Column-oriented container for a batch of detections.

    Every attribute is a list. The lists are filled in lockstep (one append
    per detection), so index ``i`` of each list refers to the same detection.
    Each instance starts with its own fresh, empty lists.

    Example::

        >>> MultiResults().boxes
        []
    """

    def __init__(self):
        # Columns, in order: bounding boxes, numeric class ids, model class
        # names, mapped ("en_name") class names, confidences.
        self.boxes, self.clss, self.cls_names, self.cls_en_names, self.confs = (
            [],
            [],
            [],
            [],
            [],
        )
class YOLOModel:
    """Run one ultralytics ``YOLO`` model on a background worker thread.

    Requests are ``(frame, frame_id)`` tuples put on ``input_queue``. The
    worker answers with ``(frame_id, detections)`` tuples on ``output_queue``.
    ``frame_id`` is passed through untouched, so callers may use any value to
    match answers to requests. Both queues hold at most one item.
    """

    # Seconds a blocking queue call waits before re-checking ``running``, so
    # that ``stop()`` is honoured even while the worker is idle or blocked on
    # a full output queue.
    _POLL_TIMEOUT = 0.1

    def __init__(self, model_path, cls_map, allowed_classes=None):
        """Load the model and create its private request/reply queues.

        Args:
            model_path: Weights path or model name passed to ``YOLO``.
            cls_map: Mapping from model class name to the ``en_name`` value
                reported for each detection.
            allowed_classes: Class names to keep; ``None``/empty keeps all.
        """
        self.model = YOLO(model_path)
        self.allowed_classes = allowed_classes if allowed_classes else []
        self.cls_map = cls_map
        self.input_queue = queue.Queue(maxsize=1)   # (frame, frame_id) requests
        self.output_queue = queue.Queue(maxsize=1)  # (frame_id, detections) replies
        self.running = False
        self._thread = None  # worker thread, created by start()

    def predict(self, frame):
        """Run the model on one frame and return the filtered detections.

        Returns:
            A list of dicts with keys ``bbox`` (``[x1, y1, x2, y2]`` as Python
            ints), ``class_id``, ``class_name``, ``en_name`` and
            ``confidence`` (Python float).
        """
        detections = []
        for result in self.model(frame, verbose=False):
            for box in result.boxes.cpu().numpy():
                cls_id = int(box.cls[0])
                cls_name = self.model.names[cls_id]
                if self.allowed_classes and cls_name not in self.allowed_classes:
                    continue
                # Plain Python ints, like class_id/confidence; numpy integer
                # scalars are not JSON-serialisable. int() truncates exactly
                # like the former astype(int).
                x1, y1, x2, y2 = (int(v) for v in box.xyxy[0])
                detections.append({
                    'bbox': [x1, y1, x2, y2],
                    'class_id': cls_id,
                    'class_name': cls_name,
                    # Fall back to the raw name: a KeyError from an incomplete
                    # cls_map would otherwise abort the whole frame.
                    'en_name': self.cls_map.get(cls_name, cls_name),
                    'confidence': float(box.conf[0]),
                })
        return detections

    def worker(self):
        """Worker-thread loop: answer every request on input_queue until stopped."""
        while self.running:
            try:
                frame, frame_id = self.input_queue.get(timeout=self._POLL_TIMEOUT)
            except queue.Empty:
                continue
            start = time.perf_counter()
            try:
                results = self.predict(frame)
            except Exception as exc:  # thread boundary: report, don't die
                # If this thread died, whoever waits on output_queue for
                # frame_id would wait forever, so answer with no detections.
                print(f"self.predict(frame) failed for frame {frame_id!r}: {exc!r}")
                results = []
            else:
                print(f"self.predict(frame) {time.perf_counter() - start}")
            self._put_result(frame_id, results)

    def _put_result(self, frame_id, results):
        """Publish one reply, giving up only if the model is stopped meanwhile."""
        # A plain blocking put() never returns if nobody drains the queue,
        # which would leave the thread stuck and deaf to stop().
        while self.running:
            try:
                self.output_queue.put((frame_id, results), timeout=self._POLL_TIMEOUT)
                return
            except queue.Full:
                continue

    def start(self):
        """Start the worker thread; a no-op if this model is already running."""
        if self.running:
            return
        self.running = True
        self._thread = threading.Thread(target=self.worker, daemon=True)
        self._thread.start()

    def stop(self):
        """Ask the worker thread to exit (non-blocking).

        The thread notices within about ``_POLL_TIMEOUT`` seconds, or once the
        inference currently in progress has finished.
        """
        self.running = False
def multi_model_process_frame(frame, models, result_dict, timeout=None):
    """Send *frame* to every model's worker and gather their detections.

    Any request a model has not yet picked up is replaced by this frame. The
    function then waits for each model's reply. The detections of all models
    are concatenated in the order of *models* and stored in
    ``result_dict[id(frame)]``. Nothing is stored when no model replied
    (e.g. *models* is empty).

    Args:
        frame: Image put on each model's ``input_queue``.
        models: Objects exposing ``input_queue``/``output_queue`` as
            ``YOLOModel`` does.
        result_dict: Dict that receives the merged detections.
        timeout: Optional overall wait in seconds (checked between polling
            rounds, so it is approximate). ``None`` (default) waits until
            every model has replied. Otherwise models still silent at the
            deadline are skipped and the partial result is stored.
    """
    models = list(models)
    frame_id = id(frame)
    # id(frame) can be reused once an earlier frame is freed, so replies are
    # matched on a fresh, unique token instead. A late reply for an earlier
    # frame can then never be mistaken for this frame's.
    token = object()

    for model in models:
        # Drop a request the worker has not picked up yet. empty() followed
        # by get_nowait() races with the worker thread (get_nowait may raise),
        # so simply drain until Empty.
        while True:
            try:
                model.input_queue.get_nowait()
            except queue.Empty:
                break
        model.input_queue.put((frame, token))

    deadline = None if timeout is None else time.monotonic() + timeout
    answers = {}  # index into models -> that model's detections
    pending = list(range(len(models)))
    while pending and (deadline is None or time.monotonic() < deadline):
        still_pending = []
        for idx in pending:
            try:
                reply_token, detections = models[idx].output_queue.get(timeout=0.1)
            except queue.Empty:
                still_pending.append(idx)
                continue
            if reply_token is token:
                answers[idx] = detections
            else:
                # Stale reply to an earlier request; keep waiting for ours.
                still_pending.append(idx)
        # Models that already replied are no longer polled.
        pending = still_pending

    if answers:
        result_dict[frame_id] = [det for idx in sorted(answers) for det in answers[idx]]
def generate_bright_color(class_id, saturation=240, value=230):
    """Return a bright BGR colour that is stable for a given class id.

    Hues step by 137 units per class id around OpenCV's 8-bit hue circle,
    which spans 0-179 (degrees / 2). Because 137 and 180 are coprime,
    class ids 0-179 all get distinct hues.

    Args:
        class_id: Integer class id.
        saturation: HSV saturation, 0-255 (default 240: vivid).
        value: HSV value/brightness, 0-255 (default 230: bright).

    Returns:
        ``(b, g, r)`` tuple of Python ints, usable directly as an OpenCV colour.
    """
    # cv2.COLOR_HSV2BGR on uint8 data expects H in [0, 180). Using "% 360"
    # would yield hues up to 359, which both leave that range and overflow
    # uint8 (an OverflowError on NumPy 2 for any hue >= 256).
    hue = class_id * 137 % 180
    hsv = np.array([[[hue, saturation, value]]], dtype=np.uint8)
    bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
    return tuple(int(channel) for channel in bgr[0, 0])
def draw_results(frame, results):
    """Return a copy of *frame* with each detection's box and label drawn.

    Args:
        frame: Image to annotate; it is copied, never modified in place.
        results: Detection dicts with ``bbox`` (``[x1, y1, x2, y2]``),
            ``class_id``, ``class_name`` and ``confidence`` keys, as produced
            by ``YOLOModel.predict``.

    Returns:
        The annotated copy.
    """
    font, font_scale, thickness = cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
    output_frame = frame.copy()
    for result in results:
        x1, y1, x2, y2 = result['bbox']
        # Colour depends only on the class id, so a class always looks the same.
        color = generate_bright_color(int(result['class_id']))
        cv2.rectangle(output_frame, (x1, y1), (x2, y2), color, 2)

        label = f"{result['class_name']}: {result['confidence']:.2f}"
        (label_width, label_height), _baseline = cv2.getTextSize(label, font, font_scale, thickness)
        # The label normally sits on top of the box. If that would push it
        # above the image (box touching the top edge), draw it just inside
        # the box instead.
        label_top = y1 - label_height - 5
        if label_top < 0:
            label_top = y1
        label_bottom = label_top + label_height + 5
        # Filled background, then black text on it.
        cv2.rectangle(output_frame, (x1, label_top), (x1 + label_width, label_bottom), color, cv2.FILLED)
        cv2.putText(output_frame, label, (x1, label_bottom - 5), font, font_scale, (0, 0, 0), thickness)
    return output_frame
def convert_result_to_multi_result(results):
    """Transpose per-detection dicts into one column-oriented ``MultiResults``.

    This is a row-to-column conversion: detection ``i`` of *results* ends up
    at index ``i`` of every list attribute of the returned object.

    Args:
        results: Iterable of detection dicts carrying the keys ``bbox``,
            ``class_id``, ``class_name``, ``en_name`` and ``confidence``.

    Returns:
        A new ``MultiResults`` whose lists are filled in the order of *results*.

    Raises:
        KeyError: If a detection lacks one of the keys above.
    """
    # (attribute on MultiResults, key in the detection dict), in fill order.
    columns = (
        ("boxes", "bbox"),
        ("clss", "class_id"),
        ("cls_names", "class_name"),
        ("cls_en_names", "en_name"),
        ("confs", "confidence"),
    )
    table = MultiResults()
    for detection in results:
        for attr, key in columns:
            getattr(table, attr).append(detection[key])
    return table