# File viewer header (paste residue): 289 lines, 10 KiB, Python
import argparse
import os
import time
from typing import List, Optional

import cv2
import matplotlib.pyplot as plt
import numpy as np
from ultralytics import YOLO
|
||
|
||
|
||
class InferenceResult:
    """Container holding everything produced by inference on a single image."""

    def __init__(self, image_path: str):
        """Create an empty result for ``image_path``.

        All fields start empty and are filled in by the code that runs the
        model and draws the overlays.
        """
        self.image_path = image_path

        # Images: the source picture (RGB order) and the annotated copy.
        self.original_image, self.result_image = None, None

        # Per-detection data; each attribute gets its own fresh list.
        self.masks, self.boxes = [], []        # segmentation masks / bounding boxes
        self.classes, self.scores = [], []     # class ids / confidence values
        self.class_names = []                  # class names

        # Inference duration in seconds.
        self.inference_time = 0.0
|
||
|
||
|
||
class YOLOSegmentationFolderDemo:
    """Demo that runs a YOLO segmentation model over a folder of images.

    Intended call order: construct, ``load_model()``, then ``process_folder()``
    (or ``perform_inference()`` / ``draw_results()`` / ``save_result()`` for a
    single image). Failures are reported with ``print`` and signalled through
    return values rather than raised.
    """

    def __init__(self, model_path: str, device: Optional[str] = None):
        """Store the configuration; the model is loaded later by ``load_model``.

        Args:
            model_path: Path to the model file.
            device: Device to run on ('cpu', 'cuda', or None to skip the
                explicit device move).
        """
        self.model_path = model_path
        self.device = device
        self.model = None
        self.class_names = []

        # Colour palette, picked per detection with ``class_id % len(colors)``.
        # Tuples are in RGB order because drawing happens on RGB images.
        self.colors = [
            (255, 0, 0),    # red
            (0, 255, 0),    # green
            (0, 0, 255),    # blue
            (255, 255, 0),  # yellow
            (255, 0, 255),  # magenta
            (0, 255, 255),  # cyan
            (128, 0, 0),    # dark red
            (0, 128, 0),    # dark green
            (0, 0, 128),    # dark blue
            (128, 128, 0),  # dark yellow
        ]

    def load_model(self) -> bool:
        """Load the YOLO segmentation model from ``self.model_path``.

        Returns:
            True when the model was loaded, False if any step failed (the
            error is printed).
        """
        try:
            print(f"正在加载模型: {self.model_path}")
            self.model = YOLO(self.model_path)

            # Move to the requested device only when one was given.
            if self.device is not None:
                self.model.to(self.device)

            # Cache the class names exposed by the model.
            self.class_names = list(self.model.names.values())
            print(f"模型加载成功,包含 {len(self.class_names)} 个类别")
            return True
        except Exception as e:
            print(f"加载模型失败: {e}")
            return False

    def preprocess_image(self, image_path: str) -> Optional[np.ndarray]:
        """Read an image from disk and convert it from BGR to RGB.

        Returns:
            The RGB image, or None when the file cannot be read or converted
            (the error is printed).
        """
        try:
            image = cv2.imread(image_path)
            # cv2.imread signals failure by returning None instead of raising.
            if image is None:
                raise ValueError(f"无法读取图片: {image_path}")

            return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        except Exception as e:
            print(f"图片预处理失败: {e}")
            return None

    def perform_inference(self, image_path: str, conf_threshold: float = 0.25,
                          iou_threshold: float = 0.5) -> "InferenceResult":
        """Run the model on one image and collect masks, boxes, classes and scores.

        Args:
            image_path: Path of the image to process.
            conf_threshold: Confidence threshold passed to the model as ``conf``.
            iou_threshold: IoU threshold passed to the model as ``iou``.

        Returns:
            An ``InferenceResult``. On any failure (model not loaded, image
            unreadable, model error) the error is printed and the result is
            returned with whatever fields were filled in up to that point.
        """
        result = InferenceResult(image_path)

        try:
            if self.model is None:
                raise ValueError("模型未加载,请先调用load_model()")

            # Keep an RGB copy of the image for drawing later.
            original_image = self.preprocess_image(image_path)
            if original_image is None:
                return result

            result.original_image = original_image

            print(f"正在处理图片: {os.path.basename(image_path)}")
            # perf_counter is monotonic, so the duration cannot go negative
            # if the wall clock is adjusted during inference.
            start_time = time.perf_counter()

            # The model receives the path and loads the image itself.
            predictions = self.model(image_path, conf=conf_threshold, iou=iou_threshold)[0]
            result.inference_time = time.perf_counter() - start_time

            # Detections are only recorded when the model returned masks.
            if predictions.masks is not None:
                result.masks = predictions.masks.data.cpu().numpy()
                result.boxes = predictions.boxes.data.cpu().numpy()
                result.classes = predictions.boxes.cls.cpu().numpy().astype(int)
                result.scores = predictions.boxes.conf.cpu().numpy()
                result.class_names = [self.model.names[i] for i in result.classes]

            print(f"检测到 {len(result.masks)} 个对象,推理时间: {result.inference_time:.3f} 秒")

            return result
        except Exception as e:
            print(f"推理失败: {e}")
            return result

    @staticmethod
    def _blend_mask(image: np.ndarray, mask: np.ndarray, color, alpha: float = 0.3) -> np.ndarray:
        """Alpha-blend ``color`` into ``image`` in place, only where ``mask`` is set.

        Args:
            image: H x W x C image that is modified in place.
            mask: H x W array; truthy entries select the pixels to tint.
            color: Per-channel colour, same channel order as ``image``.
            alpha: Weight of ``color``; the image keeps weight ``1 - alpha``.

        Returns:
            ``image`` itself, for convenience.
        """
        region = np.asarray(mask, dtype=bool)
        if not region.any():
            return image

        tint = np.asarray(color, dtype=np.float32)
        blended = image[region].astype(np.float32) * (1.0 - alpha) + tint * alpha
        image[region] = np.clip(np.rint(blended), 0, 255).astype(image.dtype)
        return image

    def draw_results(self, result: "InferenceResult", conf_threshold: float = 0.25) -> Optional[np.ndarray]:
        """Draw masks, boxes and labels for every detection above the threshold.

        Args:
            result: Inference result to visualise; ``result.result_image`` is
                set when at least one mask exists.
            conf_threshold: Detections scoring below this value are skipped.

        Returns:
            The annotated image, or ``result.original_image`` unchanged (which
            may be None) when there is no image or no masks.
        """
        if result.original_image is None:
            return None
        if result.masks is None or len(result.masks) == 0:
            return result.original_image

        # Draw on a copy so the original image stays untouched.
        output_image = result.original_image.copy()
        height, width = output_image.shape[:2]

        for i, raw_mask in enumerate(result.masks):
            score = result.scores[i]
            if score < conf_threshold:
                continue

            class_id = result.classes[i]
            class_name = result.class_names[i]
            color = self.colors[class_id % len(self.colors)]

            # Bring the mask to image size, then binarise it.
            mask = cv2.resize(raw_mask, (width, height))

            # Tint only the masked pixels. Blending the whole frame against a
            # mostly black layer would darken the image once per detection.
            self._blend_mask(output_image, mask > 0.5, color, 0.3)

            if len(result.boxes) > i:
                x1, y1, x2, y2 = (int(v) for v in result.boxes[i][:4].astype(int))
                cv2.rectangle(output_image, (x1, y1), (x2, y2), color, 2)

                label = f"{class_name}: {score:.2f}"
                (text_width, text_height), _ = cv2.getTextSize(
                    label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
                # Keep the label inside the image when the box touches the top edge.
                label_y = max(y1, text_height)
                cv2.rectangle(output_image, (x1, label_y - text_height),
                              (x1 + text_width, label_y), color, -1)
                cv2.putText(output_image, label, (x1, label_y),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

        result.result_image = output_image
        return output_image

    def save_result(self, result: "InferenceResult", output_dir: str) -> bool:
        """Write the annotated image to ``output_dir`` as ``<name>_result.jpg``.

        Returns:
            True when the file was written; False when there is no annotated
            image or saving failed (the error is printed).
        """
        try:
            if result.result_image is None:
                return False

            os.makedirs(output_dir, exist_ok=True)

            base_name = os.path.splitext(os.path.basename(result.image_path))[0]
            result_path = os.path.join(output_dir, f"{base_name}_result.jpg")

            # The annotated image is RGB; OpenCV writes BGR.
            result_image_bgr = cv2.cvtColor(result.result_image, cv2.COLOR_RGB2BGR)
            # cv2.imwrite reports failure through its return value, not an
            # exception, so it must be checked explicitly.
            if not cv2.imwrite(result_path, result_image_bgr):
                raise OSError(f"无法写入文件: {result_path}")

            print(f"结果图片已保存: {result_path}")
            return True
        except Exception as e:
            print(f"保存结果失败: {e}")
            return False

    def process_folder(self, input_dir: str, output_dir: str,
                       conf_threshold: float = 0.25, iou_threshold: float = 0.5) -> List[str]:
        """Process every image found under ``input_dir`` (recursively).

        The sub-directory layout of ``input_dir`` is mirrored under
        ``output_dir``. Only images with at least one drawn detection are saved.

        Args:
            input_dir: Folder containing the input images.
            output_dir: Folder that receives the result images.
            conf_threshold: Confidence threshold.
            iou_threshold: IoU threshold.

        Returns:
            Paths of the input images whose result image was saved.
        """
        if not os.path.isdir(input_dir):
            print(f"错误: {input_dir} 不是有效的目录")
            return []

        image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')
        processed_files = []
        output_abs = os.path.normcase(os.path.abspath(output_dir))

        for root, dirs, files in os.walk(input_dir):
            # If the output folder lives inside the input folder, do not
            # descend into it, otherwise result images would be processed again.
            dirs[:] = [
                d for d in dirs
                if os.path.normcase(os.path.abspath(os.path.join(root, d))) != output_abs
            ]

            for file in files:
                if not file.lower().endswith(image_extensions):
                    continue

                image_path = os.path.join(root, file)
                try:
                    result = self.perform_inference(image_path, conf_threshold, iou_threshold)

                    if result.masks is not None and len(result.masks) > 0:
                        self.draw_results(result, conf_threshold)

                    if result.result_image is not None:
                        # Mirror the input sub-directory structure in the output.
                        relative_path = os.path.relpath(root, input_dir)
                        folder_output_dir = os.path.join(output_dir, relative_path)

                        if self.save_result(result, folder_output_dir):
                            processed_files.append(image_path)
                except Exception as e:
                    # One bad image must not abort the whole batch.
                    print(f"处理图片 {image_path} 时出错: {e}")
                    continue

        return processed_files
|
||
|
||
|
||
def main(argv: Optional[List[str]] = None) -> None:
    """Run the folder demo: load the model, process a folder, print a summary.

    Every setting can be overridden on the command line; when an option is
    omitted the value that used to be hard-coded here is used, so calling
    ``main()`` without arguments behaves as before.

    Args:
        argv: Argument list to parse; None means the process command line.
    """
    parser = argparse.ArgumentParser(description="YOLO segmentation folder demo")
    parser.add_argument(
        "--input-dir",
        default=r"C:\Users\14867\xwechat_files\wxid_lqq76m3bwzja21_94e1\msg\file\2025-10\12121",
        help="folder containing the input images",
    )
    parser.add_argument(
        "--output-dir",
        default=r"D:\project\ai-train_platform\test\out",
        help="folder that receives the result images",
    )
    parser.add_argument(
        "--model-path",
        default=r"C:\Users\14867\xwechat_files\wxid_lqq76m3bwzja21_94e1\msg\file\2025-10\road_crack(1).pt",
        help="path to the model file",
    )
    parser.add_argument("--conf", type=float, default=0.25, help="confidence threshold")
    parser.add_argument("--iou", type=float, default=0.5, help="IoU threshold")
    args = parser.parse_args(argv)

    demo = YOLOSegmentationFolderDemo(model_path=args.model_path)

    # Nothing to do if the model cannot be loaded (load_model prints the reason).
    if not demo.load_model():
        return

    processed_files = demo.process_folder(
        input_dir=args.input_dir,
        output_dir=args.output_dir,
        conf_threshold=args.conf,
        iou_threshold=args.iou,
    )

    print(f"\n处理完成! 共处理了 {len(processed_files)} 张图片.")
    print(f"结果已保存到: {os.path.abspath(args.output_dir)}")
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: run the demo only when executed directly, not on import.
    main()
|