ai-train_platform/test/predict_seg.py

289 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import cv2
import numpy as np
import time
from typing import List, Optional
from ultralytics import YOLO
import matplotlib.pyplot as plt
class InferenceResult:
    """Container for everything produced while processing one image.

    A fresh instance only knows the image path; every other field starts
    empty and is filled in later by the inference and drawing steps.
    """

    def __init__(self, image_path: str):
        # Where the image came from.
        self.image_path = image_path

        # Pixel data: the image as loaded (RGB) and the annotated copy.
        self.original_image = None
        self.result_image = None

        # Per-detection outputs; each instance gets its own containers.
        self.masks = []        # segmentation masks
        self.boxes = []        # bounding boxes
        self.classes = []      # class ids
        self.scores = []       # confidence scores
        self.class_names = []  # class names

        # Time spent in the model call, in seconds.
        self.inference_time = 0.0
class YOLOSegmentationFolderDemo:
    """Run a YOLO segmentation model over every image in a folder tree.

    Typical use: construct, call ``load_model()``, then ``process_folder()``.
    Images are kept in RGB channel order in memory and converted back to BGR
    only when they are written to disk.
    """

    # Overlay blend weights: share of the original pixel and of the class
    # colour that end up in each masked pixel.
    _IMAGE_WEIGHT = 0.7
    _MASK_WEIGHT = 0.3

    def __init__(self, model_path: str, device: Optional[str] = None):
        """Store the configuration; the model itself is loaded by load_model().

        Args:
            model_path: Path to the model weights file.
            device: Device to run on ('cpu', 'cuda', or None to leave the
                model on whatever device it loads to).
        """
        self.model_path = model_path
        self.device = device
        self.model = None
        self.class_names = []
        # Colour palette (RGB), indexed by ``class_id % len(self.colors)``.
        self.colors = [
            (255, 0, 0),    # red
            (0, 255, 0),    # green
            (0, 0, 255),    # blue
            (255, 255, 0),  # yellow
            (255, 0, 255),  # magenta
            (0, 255, 255),  # cyan
            (128, 0, 0),    # dark red
            (0, 128, 0),    # dark green
            (0, 0, 128),    # dark blue
            (128, 128, 0),  # dark yellow
        ]

    def load_model(self) -> bool:
        """Load the YOLO segmentation model from ``self.model_path``.

        Returns:
            True when the model was loaded and its class names were read,
            False if anything raised (the error is printed, not re-raised).
        """
        try:
            print(f"正在加载模型: {self.model_path}")
            self.model = YOLO(self.model_path)
            # Move to the requested device only when one was given.
            if self.device is not None:
                self.model.to(self.device)
            self.class_names = list(self.model.names.values())
            print(f"模型加载成功,包含 {len(self.class_names)} 个类别")
            return True
        except Exception as e:
            print(f"加载模型失败: {e}")
            return False

    def preprocess_image(self, image_path: str) -> Optional[np.ndarray]:
        """Read an image from disk and return it in RGB channel order.

        Args:
            image_path: Path of the image file to read.

        Returns:
            The decoded image as an RGB array, or None if the file could not
            be read or decoded (the error is printed).
        """
        try:
            # Read the bytes with NumPy and decode in memory: cv2.imread is
            # known to fail on paths containing non-ASCII characters on
            # Windows, whereas np.fromfile goes through Python's file API.
            raw_bytes = np.fromfile(image_path, dtype=np.uint8)
            image = cv2.imdecode(raw_bytes, cv2.IMREAD_COLOR)
            if image is None:
                raise ValueError(f"无法读取图片: {image_path}")
            return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        except Exception as e:
            print(f"图片预处理失败: {e}")
            return None

    def perform_inference(self, image_path: str, conf_threshold: float = 0.25,
                          iou_threshold: float = 0.5) -> "InferenceResult":
        """Run the model on one image and collect its detections.

        Args:
            image_path: Path of the image to process.
            conf_threshold: Confidence threshold passed to the model.
            iou_threshold: IoU threshold passed to the model.

        Returns:
            An InferenceResult. On any failure (model not loaded, unreadable
            image, model error) the error is printed and the result is
            returned with whatever fields had been filled in so far.
        """
        result = InferenceResult(image_path)
        try:
            if self.model is None:
                raise ValueError("模型未加载请先调用load_model()")

            original_image = self.preprocess_image(image_path)
            if original_image is None:
                return result
            result.original_image = original_image

            print(f"正在处理图片: {os.path.basename(image_path)}")
            # perf_counter is monotonic, so the measured duration cannot be
            # distorted by wall-clock adjustments.
            start_time = time.perf_counter()
            # The model is given the path (not the RGB array held above) and
            # only the first entry of the returned list is used.
            predictions = self.model(image_path, conf=conf_threshold, iou=iou_threshold)[0]
            result.inference_time = time.perf_counter() - start_time

            # Without masks there is nothing to record; the result keeps its
            # empty defaults.
            if predictions.masks is not None:
                result.masks = predictions.masks.data.cpu().numpy()
                result.boxes = predictions.boxes.data.cpu().numpy()
                result.classes = predictions.boxes.cls.cpu().numpy().astype(int)
                result.scores = predictions.boxes.conf.cpu().numpy()
                result.class_names = [self.model.names[i] for i in result.classes]

            print(f"检测到 {len(result.masks)} 个对象,推理时间: {result.inference_time:.3f}")
            return result
        except Exception as e:
            print(f"推理失败: {e}")
            return result

    def draw_results(self, result: "InferenceResult", conf_threshold: float = 0.25) -> Optional[np.ndarray]:
        """Overlay masks, boxes and labels on a copy of the original image.

        Args:
            result: Inference output for one image.
            conf_threshold: Detections scoring below this value are skipped.

        Returns:
            The annotated RGB image (also stored in ``result.result_image``).
            If there is no image or no mask, ``result.original_image`` is
            returned unchanged and ``result.result_image`` is left untouched.
        """
        if result.original_image is None or result.masks is None or len(result.masks) == 0:
            return result.original_image

        output_image = result.original_image.copy()
        height, width = output_image.shape[:2]

        for i, raw_mask in enumerate(result.masks):
            score = result.scores[i]
            if score < conf_threshold:
                continue

            class_id = result.classes[i]
            class_name = result.class_names[i]
            color = self.colors[class_id % len(self.colors)]

            # NOTE(review): the mask is simply stretched to the image size.
            # If the model emits masks at a padded input resolution this can
            # shift them slightly — confirm against the Ultralytics docs
            # (e.g. the ``retina_masks`` predict option).
            region = cv2.resize(raw_mask, (width, height)) > 0.5

            # Tint only the pixels inside the mask. Blending the whole frame
            # against a mostly-black layer would dim every unmasked pixel
            # again for each detection drawn.
            tint = np.asarray(color, dtype=np.float32)
            blended = self._IMAGE_WEIGHT * output_image[region] + self._MASK_WEIGHT * tint
            output_image[region] = np.rint(blended).astype(output_image.dtype)

            if len(result.boxes) > i:
                # Plain Python ints keep the OpenCV drawing calls happy.
                x1, y1, x2, y2 = (int(v) for v in result.boxes[i][:4])
                cv2.rectangle(output_image, (x1, y1), (x2, y2), color, 2)

                label = f"{class_name}: {score:.2f}"
                (text_width, text_height), _ = cv2.getTextSize(
                    label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
                # Keep the label inside the image when the box touches the
                # top edge.
                label_bottom = max(y1, text_height)
                cv2.rectangle(output_image, (x1, label_bottom - text_height),
                              (x1 + text_width, label_bottom), color, -1)
                cv2.putText(output_image, label, (x1, label_bottom),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

        result.result_image = output_image
        return output_image

    def save_result(self, result: "InferenceResult", output_dir: str) -> bool:
        """Write the annotated image of one result as ``<name>_result.jpg``.

        Args:
            result: Result whose ``result_image`` should be written.
            output_dir: Directory to write into; created if missing.

        Returns:
            True if the file was written, False if there was no annotated
            image or if encoding/writing failed (the error is printed).
        """
        try:
            if result.result_image is None:
                return False

            os.makedirs(output_dir, exist_ok=True)
            base_name = os.path.splitext(os.path.basename(result.image_path))[0]
            result_path = os.path.join(output_dir, f"{base_name}_result.jpg")

            result_image_bgr = cv2.cvtColor(result.result_image, cv2.COLOR_RGB2BGR)
            # Encode in memory and write through NumPy: this reports failure
            # (cv2.imwrite's False return value was previously ignored) and
            # copes with non-ASCII output paths.
            encoded_ok, encoded = cv2.imencode(".jpg", result_image_bgr)
            if not encoded_ok:
                raise ValueError(f"无法编码图片: {result_path}")
            encoded.tofile(result_path)

            print(f"结果图片已保存: {result_path}")
            return True
        except Exception as e:
            print(f"保存结果失败: {e}")
            return False

    def process_folder(self, input_dir: str, output_dir: str,
                       conf_threshold: float = 0.25, iou_threshold: float = 0.5) -> List[str]:
        """Process every image found under a folder, recursively.

        Only images with at least one detected mask are drawn and saved; the
        sub-folder layout of ``input_dir`` is mirrored under ``output_dir``.

        Args:
            input_dir: Folder containing the input images.
            output_dir: Folder that receives the annotated images.
            conf_threshold: Confidence threshold.
            iou_threshold: IoU threshold.

        Returns:
            Paths of the input images whose annotated result was saved.
            Empty if ``input_dir`` is not a directory.
        """
        if not os.path.isdir(input_dir):
            print(f"错误: {input_dir} 不是有效的目录")
            return []

        image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')
        processed_files = []
        output_root = os.path.abspath(output_dir)

        for root, dirs, files in os.walk(input_dir):
            # Do not descend into the output folder when it lives inside the
            # input tree, otherwise earlier results would be processed again.
            dirs[:] = [d for d in dirs
                       if os.path.abspath(os.path.join(root, d)) != output_root]

            for file in files:
                if not file.lower().endswith(image_extensions):
                    continue
                image_path = os.path.join(root, file)
                try:
                    result = self.perform_inference(image_path, conf_threshold, iou_threshold)
                    if result.masks is not None and len(result.masks) > 0:
                        self.draw_results(result, conf_threshold)
                    if result.result_image is not None:
                        # Mirror the input sub-folder structure in the output.
                        relative_path = os.path.relpath(root, input_dir)
                        folder_output_dir = os.path.join(output_dir, relative_path)
                        if self.save_result(result, folder_output_dir):
                            processed_files.append(image_path)
                except Exception as e:
                    # One bad image must not abort the whole batch.
                    print(f"处理图片 {image_path} 时出错: {e}")
                    continue
        return processed_files
def main():
    """Run the folder demo with the hard-coded paths and thresholds below."""
    # --- configuration ---------------------------------------------------
    source_folder = r"C:\Users\14867\xwechat_files\wxid_lqq76m3bwzja21_94e1\msg\file\2025-10\12121"  # input images
    results_folder = r"D:\project\ai-train_platform\test\out"  # annotated output
    weights_file = r"C:\Users\14867\xwechat_files\wxid_lqq76m3bwzja21_94e1\msg\file\2025-10\road_crack(1).pt"  # replace with your model path
    min_confidence = 0.25  # confidence threshold
    nms_iou = 0.5  # IoU threshold

    # --- model -----------------------------------------------------------
    demo = YOLOSegmentationFolderDemo(model_path=weights_file)
    model_ready = demo.load_model()
    if not model_ready:
        # load_model() has already printed the reason.
        return

    # --- batch run -------------------------------------------------------
    processed_files = demo.process_folder(
        source_folder, results_folder, min_confidence, nms_iou)

    print(f"\n处理完成! 共处理了 {len(processed_files)} 张图片.")
    print(f"结果已保存到: {os.path.abspath(results_folder)}")


if __name__ == "__main__":
    main()