diff --git a/predict/predict_yolo11seg.py b/predict/predict_yolo11seg.py index 48dfe12..ff83b3b 100644 --- a/predict/predict_yolo11seg.py +++ b/predict/predict_yolo11seg.py @@ -2,7 +2,7 @@ import logging import shutil import zipfile from os.path import exists - +from pathlib import Path import torch import gc import os @@ -385,7 +385,7 @@ class YOLOSegmentationInference: return output_image def save_results(self, result: InferenceResult, output_dir: str, save_mask: bool = False, - save_label: bool = False, result_save: [] = None) -> None: + save_label: bool = False, result_save: List[Dict] = None) -> None: """ 保存推理结果 @@ -508,45 +508,127 @@ class YOLOSegmentationInference: plt.show() def process_single_image(self, image_path: str, output_dir: Optional[str] = None, - conf_threshold: float = 0.25, iou_threshold: float = 0.5, - save_mask: bool = False, save_label: bool = False, show: bool = True, - result_save: [] = None) -> InferenceResult: + conf_threshold: float = 0.25, iou_threshold: float = 0.5, + save_mask: bool = False, save_label: bool = False, show: bool = True, + result_save: Optional[List] = None): """ - 处理单张图片 - + 处理单张图片(使用YOLO分割模型) + Args: - image_path: 图片路径 - output_dir: 输出目录,如果为None则不保存 + image_path: 输入图片路径 + output_dir: 输出目录(如果为None则不保存) conf_threshold: 置信度阈值 - iou_threshold: IOU阈值 - save_mask: 是否保存单独的掩码文件 - save_label: 是否保存YOLO格式的标签文件 - show: 是否显示结果 + iou_threshold: NMS IoU阈值 + save_mask: 是否保存单独的掩码图像(PNG格式) + save_label: 是否保存YOLO格式标签文件(txt,含多边形或边界框) + show: 是否显示结果图像(按任意键关闭) + result_save: 可选的外部列表,用于收集本次处理的信息(如图像路径、保存的文件等) Returns: - 推理结果 + Ultralytics 的 Results 对象(包含原始图像、边界框、掩码等信息) """ - # 执行推理 - result = self.perform_inference(image_path, conf_threshold, iou_threshold) + # 1. 确保模型已加载 + if not hasattr(self, 'model') or self.model is None: + raise ValueError("模型未加载,请先初始化模型(self.model = YOLO('...'))") - # 绘制结果 - if result.masks is not None and len(result.masks) > 0: - self.draw_results(result, conf_threshold) + # 2. 执行推理 + results = self.model.predict( + source=image_path, + conf=conf_threshold, + iou=iou_threshold, + imgsz=640, # 可根据需要调整,或作为参数传入 + save=False, # 手动保存,不依赖自动保存 + save_txt=False, # 手动处理标签 + retina_masks=True, # 获取与原图同分辨率的高精度掩码 + ) + # 单张图片,取第一个结果 + result = results[0] - # 保存结果 - if output_dir is not None: - self.save_results(result, output_dir, save_mask, save_label, result_save) + # 3. 准备输出目录 + if output_dir: + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + base_name = Path(image_path).stem + else: + output_path = None + base_name = None - # # 显示结果 + # 4. 绘制标注图像(使用 result.plot()) + annotated_img = result.plot() # 返回BGR numpy数组 + + # 5. 保存标注图像到输出目录(如果指定) + if output_path: + img_save_path = output_path / f"{base_name}_annotated.jpg" + cv2.imwrite(str(img_save_path), annotated_img) + else: + img_save_path = None + + # 6. 保存单独的掩码文件(如果 save_mask 为 True 且有掩码) + mask_save_paths = [] + if save_mask and result.masks is not None and output_path: + masks = result.masks.data.cpu().numpy() # (N, H, W) 二值掩码 + for i, mask in enumerate(masks): + mask_img = (mask * 255).astype(np.uint8) # 转换为0-255图像 + mask_file = output_path / f"{base_name}_mask_{i}.png" + cv2.imwrite(str(mask_file), mask_img) + mask_save_paths.append(str(mask_file)) + + # 7. 保存YOLO格式标签文件(如果 save_label 为 True) + label_save_path = None + if save_label and output_path: + # 优先保存分割多边形标签(masks.xy) + if result.masks is not None and hasattr(result.masks, 'xy') and result.masks.xy: + label_file = output_path / f"{base_name}.txt" + with open(label_file, 'w') as f: + for i, poly in enumerate(result.masks.xy): + cls_id = int(result.boxes.cls[i]) if result.boxes is not None else 0 + # 归一化多边形点坐标(除以图像宽高) + h, w = result.orig_shape + normalized = poly / [w, h] + flat_coords = normalized.flatten().tolist() + line = f"{cls_id} " + " ".join([f"{x:.6f}" for x in flat_coords]) + f.write(line + "\n") + label_save_path = str(label_file) + # 如果没有掩码但有边界框,则保存检测标签(YOLO检测格式) + elif result.boxes is not None: + label_file = output_path / f"{base_name}.txt" + with open(label_file, 'w') as f: + boxes = result.boxes.xyxyn.cpu().numpy() # 归一化边界框 (xyxy) + cls_ids = result.boxes.cls.cpu().numpy().astype(int) + for box, cls_id in zip(boxes, cls_ids): + # 转换为 YOLO 格式 (x_center, y_center, width, height) + x1, y1, x2, y2 = box + x_center = (x1 + x2) / 2 + y_center = (y1 + y2) / 2 + width = x2 - x1 + height = y2 - y1 + line = f"{cls_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}" + f.write(line + "\n") + label_save_path = str(label_file) + + # # 8. 显示结果(如果 show 为 True) # if show: - # self.show_results(result) + # cv2.imshow("Segmentation Result", annotated_img) + # cv2.waitKey(0) + # cv2.destroyAllWindows() + # 9. 如果提供了 result_save 列表,将本次结果信息追加进去 + if result_save is not None: + result_save.append({ + 'image_path': image_path, + 'annotated_image_path': str(img_save_path) if img_save_path else None, + 'mask_paths': mask_save_paths, + 'label_path': label_save_path, + 'num_objects': len(result.boxes) if result.boxes else 0, + }) + + # 10. 返回结果(Ultralytics 的 Results 对象,包含所有推理信息) return result def process_single_image_share_dir(self, image_path, user_name, pwd, output_dir: Optional[str] = None, conf_threshold: float = 0.25, iou_threshold: float = 0.5, save_mask: bool = False, save_label: bool = False, show: bool = True, - result_save: [] = None) -> None: + result_save: List[Dict] = None) -> None: """ 处理单张图片 @@ -585,7 +667,7 @@ class YOLOSegmentationInference: def process_image_directory(self, input_dir: str, output_dir: Optional[str] = None, conf_threshold: float = 0.25, iou_threshold: float = 0.5, save_mask: bool = False, save_label: bool = False, show: bool = False, - result_save: [] = None) -> List[ + result_save: List[Dict] = None) -> List[ InferenceResult]: """ 处理目录中的所有图片 @@ -681,7 +763,7 @@ class YOLOSegmentationInference: def process_image_directory_share_dir_circle(self, task_id, current_time, input_dir_list, user_name, pwd, output_dir: Optional[str] = None, conf_threshold: float = 0.25, iou_threshold: float = 0.5, save_mask: bool = False, save_label: bool = False, show: bool = False, - result_save: [] = None) -> None: + result_save: List[Dict] = None) -> None: for input_dir in input_dir_list : self.process_image_directory_share_dir(task_id,current_time,input_dir,user_name,pwd,output_dir,conf_threshold,iou_threshold,save_mask,save_label,show,result_save) del_file_shutil(output_dir) @@ -689,7 +771,7 @@ class YOLOSegmentationInference: def process_image_directory_share_dir(self, task_id, current_time, input_dir, user_name, pwd, output_dir: Optional[str] = None, conf_threshold: float = 0.25, iou_threshold: float = 0.5, save_mask: bool = False, save_label: bool = False, show: bool = False, - result_save: [] = None) -> None: + result_save: List[Dict] = None) -> None: """ 处理目录中的所有图片