167 lines
5.7 KiB
Python
167 lines
5.7 KiB
Python
import cv2
|
||
import threading
|
||
import queue
|
||
import time
|
||
from collections import defaultdict
|
||
import numpy as np
|
||
from ultralytics import YOLO
|
||
|
||
|
||
class MultiResults:
|
||
def __init__(self):
|
||
self.boxes = [] # 初始化 boxes 为空列表
|
||
self.clss = [] # 初始化 boxes 为空列表
|
||
self.cls_names = [] # 初始化 boxes 为空列表
|
||
self.cls_en_names = [] # 初始化 boxes 为空列表
|
||
self.confs = [] # 初始化 boxes 为空列表
|
||
|
||
|
||
# results = Results()
|
||
# print(results.boxes) # 输出: []
|
||
|
||
|
||
class YOLOModel:
|
||
def __init__(self, model_path,cls_map, allowed_classes=None):
|
||
self.model = YOLO(model_path)
|
||
self.allowed_classes = allowed_classes if allowed_classes else []
|
||
self.cls_map=cls_map
|
||
self.input_queue = queue.Queue(maxsize=1) # 每个模型有自己的输入队列
|
||
self.output_queue = queue.Queue(maxsize=1) # 每个模型有自己的输出队列
|
||
self.running = False
|
||
|
||
def predict(self, frame):
|
||
"""单帧预测(线程安全)"""
|
||
results = self.model(frame, verbose=False)
|
||
filtered_results = []
|
||
|
||
for result in results:
|
||
boxes = result.boxes.cpu().numpy()
|
||
for box in boxes:
|
||
cls_id = int(box.cls[0])
|
||
cls_name = self.model.names[cls_id]
|
||
|
||
if self.allowed_classes and cls_name not in self.allowed_classes:
|
||
continue
|
||
|
||
x1, y1, x2, y2 = box.xyxy[0].astype(int)
|
||
conf = box.conf[0]
|
||
en_name=self.cls_map[cls_name]
|
||
filtered_results.append({
|
||
'bbox': [x1, y1, x2, y2],
|
||
'class_id': cls_id,
|
||
'class_name': cls_name,
|
||
'en_name':en_name,
|
||
'confidence': float(conf)
|
||
})
|
||
|
||
return filtered_results
|
||
|
||
def worker(self):
|
||
"""模型推理工作线程"""
|
||
while self.running:
|
||
try:
|
||
# 从输入队列获取帧(带超时以避免永久阻塞)
|
||
frame, frame_id = self.input_queue.get(timeout=0.1)
|
||
starttime=time.time()
|
||
results = self.predict(frame)
|
||
endtime = time.time()
|
||
print(f"self.predict(frame) {endtime-starttime}")
|
||
# 将结果放入输出队列
|
||
self.output_queue.put((frame_id, results))
|
||
except queue.Empty:
|
||
continue
|
||
|
||
def start(self):
|
||
"""启动模型处理线程"""
|
||
self.running = True
|
||
threading.Thread(target=self.worker, daemon=True).start()
|
||
|
||
def stop(self):
|
||
"""停止模型处理线程"""
|
||
self.running = False
|
||
|
||
|
||
def multi_model_process_frame(frame, models, result_dict):
|
||
"""将帧分发给各模型线程,并收集结果"""
|
||
frame_id = id(frame)
|
||
|
||
# 将帧分发给所有模型线程
|
||
for model in models:
|
||
# 如果队列已满,先清空(处理积压)
|
||
while not model.input_queue.empty():
|
||
model.input_queue.get_nowait()
|
||
model.input_queue.put((frame, frame_id))
|
||
|
||
# 收集所有模型的结果
|
||
results = defaultdict(list)
|
||
remaining_models = len(models)
|
||
|
||
while remaining_models > 0:
|
||
for model in models:
|
||
try:
|
||
# 获取该模型的结果(带超时以避免永久阻塞)
|
||
result_frame_id, model_results = model.output_queue.get(timeout=0.1)
|
||
if result_frame_id == frame_id:
|
||
results[result_frame_id].extend(model_results)
|
||
remaining_models -= 1
|
||
except queue.Empty:
|
||
continue
|
||
|
||
# 将结果存入共享字典
|
||
if frame_id in results:
|
||
result_dict[frame_id] = results[frame_id]
|
||
|
||
print(1)
|
||
|
||
|
||
def generate_bright_color(class_id):
|
||
"""生成明亮且区分度高的颜色(基于class_id确保同类同色)"""
|
||
# 方法1:使用HSV色彩空间(推荐)
|
||
hue = class_id * 137 % 360 # 137和360互质,保证不同class_id有不同颜色
|
||
saturation = 240 # 高饱和度 (0-255)
|
||
value = 230 # 高亮度 (0-255)
|
||
|
||
# 将HSV转换为BGR
|
||
hsv = np.array([[[hue, saturation, value]]], dtype=np.uint8)
|
||
bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
|
||
return tuple(map(int, bgr[0][0]))
|
||
|
||
def draw_results(frame, results):
|
||
"""在帧上绘制检测结果"""
|
||
output_frame = frame.copy()
|
||
for result in results:
|
||
x1, y1, x2, y2 = result['bbox']
|
||
cls_name = result['class_name']
|
||
conf = result['confidence']
|
||
|
||
# # 随机颜色生成(基于class_id确保同类同色)
|
||
# color = (int(result['class_id'] * 57 % 200),
|
||
# int(result['class_id'] * 113 % 200),
|
||
# int(result['class_id'] * 197 % 200))
|
||
|
||
color=generate_bright_color(int(result['class_id']))
|
||
|
||
# 绘制边界框
|
||
cv2.rectangle(output_frame, (x1, y1), (x2, y2), color, 2)
|
||
|
||
# 绘制标签背景
|
||
label = f"{cls_name}: {conf:.2f}"
|
||
(label_width, label_height), baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
|
||
cv2.rectangle(output_frame, (x1, y1 - label_height - 5), (x1 + label_width, y1), color, cv2.FILLED)
|
||
|
||
# 绘制标签文本
|
||
cv2.putText(output_frame, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)
|
||
|
||
return output_frame
|
||
|
||
# 将讲过做行列转换
|
||
def convert_result_to_multi_result(results):
|
||
multi_result=MultiResults()
|
||
for result in results:
|
||
multi_result.boxes.append(result["bbox"])
|
||
multi_result.clss.append(result["class_id"])
|
||
multi_result.cls_names.append(result["class_name"])
|
||
multi_result.cls_en_names.append(result["en_name"])
|
||
multi_result.confs.append(result["confidence"])
|
||
return multi_result
|