This commit is contained in:
yooooger 2026-03-05 16:12:48 +08:00
parent 394e5aa071
commit c5a0996c88

View File

@ -2,7 +2,7 @@ import logging
import shutil import shutil
import zipfile import zipfile
from os.path import exists from os.path import exists
from pathlib import Path
import torch import torch
import gc import gc
import os import os
@ -385,7 +385,7 @@ class YOLOSegmentationInference:
return output_image return output_image
def save_results(self, result: InferenceResult, output_dir: str, save_mask: bool = False, def save_results(self, result: InferenceResult, output_dir: str, save_mask: bool = False,
save_label: bool = False, result_save: List[Dict] = None) -> None: save_label: bool = False, result_save: [] = None) -> None:
""" """
保存推理结果 保存推理结果
@ -459,7 +459,7 @@ class YOLOSegmentationInference:
approx = cv2.approxPolyDP(contour, epsilon, True) approx = cv2.approxPolyDP(contour, epsilon, True)
# 归一化坐标 # 归一化坐标
h, w = result.original_image.shape[:2] h, w = mask.shape
points = [] points = []
for point in approx: for point in approx:
x = point[0][0] / w x = point[0][0] / w
@ -508,127 +508,45 @@ class YOLOSegmentationInference:
plt.show() plt.show()
def process_single_image(self, image_path: str, output_dir: Optional[str] = None,
                         conf_threshold: float = 0.25, iou_threshold: float = 0.5,
                         save_mask: bool = False, save_label: bool = False, show: bool = True,
                         result_save: Optional[List[Dict]] = None) -> "InferenceResult":
    """
    Run inference on a single image, then optionally draw and save the result.

    Args:
        image_path: Path of the image to process.
        output_dir: Directory handed to ``save_results``. When ``None``,
            nothing is saved.
        conf_threshold: Confidence threshold; forwarded to both
            ``perform_inference`` and ``draw_results``.
        iou_threshold: IoU threshold forwarded to ``perform_inference``.
        save_mask: Forwarded to ``save_results`` (save separate mask files).
        save_label: Forwarded to ``save_results`` (save YOLO-format labels).
        show: Currently has no effect; see the NOTE at the end of the body.
        result_save: Optional caller-owned list, passed through unchanged to
            ``save_results``. It is only forwarded when ``output_dir`` is set.

    Returns:
        The object returned by ``perform_inference`` for ``image_path``.
    """
    # Run the model on the image.
    result = self.perform_inference(image_path, conf_threshold, iou_threshold)

    # Only draw when at least one mask was produced.
    if result.masks is not None and len(result.masks) > 0:
        self.draw_results(result, conf_threshold)

    # Persist outputs only when the caller supplied a target directory.
    if output_dir is not None:
        self.save_results(result, output_dir, save_mask, save_label, result_save)

    # NOTE: the `self.show_results(result)` call is disabled in this revision,
    # so `show` is accepted for interface compatibility but never acted on.

    return result
def process_single_image_share_dir(self, image_path, user_name, pwd, output_dir: Optional[str] = None, def process_single_image_share_dir(self, image_path, user_name, pwd, output_dir: Optional[str] = None,
conf_threshold: float = 0.25, iou_threshold: float = 0.5, conf_threshold: float = 0.25, iou_threshold: float = 0.5,
save_mask: bool = False, save_label: bool = False, show: bool = True, save_mask: bool = False, save_label: bool = False, show: bool = True,
result_save: List[Dict] = None) -> None: result_save: [] = None) -> None:
""" """
处理单张图片 处理单张图片
@ -667,7 +585,7 @@ class YOLOSegmentationInference:
def process_image_directory(self, input_dir: str, output_dir: Optional[str] = None, def process_image_directory(self, input_dir: str, output_dir: Optional[str] = None,
conf_threshold: float = 0.25, iou_threshold: float = 0.5, conf_threshold: float = 0.25, iou_threshold: float = 0.5,
save_mask: bool = False, save_label: bool = False, show: bool = False, save_mask: bool = False, save_label: bool = False, show: bool = False,
result_save: List[Dict] = None) -> List[ result_save: [] = None) -> List[
InferenceResult]: InferenceResult]:
""" """
处理目录中的所有图片 处理目录中的所有图片
@ -763,7 +681,7 @@ class YOLOSegmentationInference:
def process_image_directory_share_dir_circle(self, task_id, current_time, input_dir_list, user_name, pwd, output_dir: Optional[str] = None, def process_image_directory_share_dir_circle(self, task_id, current_time, input_dir_list, user_name, pwd, output_dir: Optional[str] = None,
conf_threshold: float = 0.25, iou_threshold: float = 0.5, conf_threshold: float = 0.25, iou_threshold: float = 0.5,
save_mask: bool = False, save_label: bool = False, show: bool = False, save_mask: bool = False, save_label: bool = False, show: bool = False,
result_save: List[Dict] = None) -> None: result_save: [] = None) -> None:
for input_dir in input_dir_list : for input_dir in input_dir_list :
self.process_image_directory_share_dir(task_id,current_time,input_dir,user_name,pwd,output_dir,conf_threshold,iou_threshold,save_mask,save_label,show,result_save) self.process_image_directory_share_dir(task_id,current_time,input_dir,user_name,pwd,output_dir,conf_threshold,iou_threshold,save_mask,save_label,show,result_save)
del_file_shutil(output_dir) del_file_shutil(output_dir)
@ -771,7 +689,7 @@ class YOLOSegmentationInference:
def process_image_directory_share_dir(self, task_id, current_time, input_dir, user_name, pwd, output_dir: Optional[str] = None, def process_image_directory_share_dir(self, task_id, current_time, input_dir, user_name, pwd, output_dir: Optional[str] = None,
conf_threshold: float = 0.25, iou_threshold: float = 0.5, conf_threshold: float = 0.25, iou_threshold: float = 0.5,
save_mask: bool = False, save_label: bool = False, show: bool = False, save_mask: bool = False, save_label: bool = False, show: bool = False,
result_save: List[Dict] = None) -> None: result_save: [] = None) -> None:
""" """
处理目录中的所有图片 处理目录中的所有图片