diff --git a/predict/predict_yolo11seg.py b/predict/predict_yolo11seg.py index 99746c7..b4c5247 100644 --- a/predict/predict_yolo11seg.py +++ b/predict/predict_yolo11seg.py @@ -2,7 +2,6 @@ import logging import shutil import zipfile from os.path import exists - import torch import gc import os @@ -14,9 +13,7 @@ import glob from typing import List, Tuple, Optional, Dict from ultralytics import YOLO import matplotlib.pyplot as plt - from ultralytics import YOLO - from middleware.recognition_task import RecognitionTaskDAO, RecognitionTask from middleware.minio_util import upload_file, downFile, check_zip_size, upload_folder from util.yolo2pix_new import * @@ -91,7 +88,6 @@ class InferenceResult: self.scores = [] # 置信度列表 self.class_names = [] # 类别名称列表 self.inference_time = 0.0 # 推理时间(秒) - self.temp_label_path = None # YOLO临时输出目录路径 class YOLOSegmentationInference: @@ -227,7 +223,7 @@ class YOLOSegmentationInference: # 使用YOLO模型进行推理 # predictions = self.model(image_path, conf=conf_threshold, iou=iou_threshold)[0] - predictions = self.model(original_image_rgb, conf=conf_threshold, iou=iou_threshold)[0] + predictions = self.model(original_image_rgb, conf=conf_threshold, iou=iou_threshold, imgsz=640)[0] result.inference_time = time.time() - start_time @@ -291,12 +287,8 @@ class YOLOSegmentationInference: print(f"正在处理图片: {os.path.basename(image_path)}") start_time = time.time() - # 使用YOLO模型进行推理(启用save_txt自动生成标签) - temp_output_dir = os.path.join(os.path.dirname(image_path), ".yolo_temp") - predictions = self.model(image_path, conf=conf_threshold, iou=iou_threshold, - save=False, save_txt=True, project=temp_output_dir, - name="labels", exist_ok=True)[0] - result.temp_label_path = temp_output_dir + # 使用YOLO模型进行推理 + predictions = self.model(image_path, conf=conf_threshold, iou=iou_threshold, imgsz=640)[0] result.inference_time = time.time() - start_time @@ -398,7 +390,7 @@ class YOLOSegmentationInference: result: 推理结果 output_dir: 输出目录 save_mask: 是否保存单独的掩码文件 - save_label: 是否保存YOLO格式的标签文件(使用YOLO原生格式) + save_label: 是否保存YOLO格式的标签文件 """ if result.result_image is None: return @@ -437,26 +429,50 @@ class YOLOSegmentationInference: print(f"共保存 {len(result.masks)} 个掩码文件到: {mask_dir}") - # 保存YOLO格式的标签文件(直接从YOLO输出复制) - if save_label and result.masks is not None and len(result.masks) > 0: + # 保存YOLO格式的标签文件 + if save_label and result.masks is not None and len(result.masks) > 0 and len(result.boxes) > 0: + # label_dir = os.path.join(output_dir, "labels") label_dir = output_dir os.makedirs(label_dir, exist_ok=True) label_path = os.path.join(label_dir, f"{base_name}.txt") - - # 从YOLO临时输出目录复制txt文件 - if hasattr(result, 'temp_label_path') and result.temp_label_path: - yolo_label_path = os.path.join(result.temp_label_path, "labels", f"{base_name}.txt") - if os.path.exists(yolo_label_path): - shutil.copy(yolo_label_path, label_path) - print(f"标签文件已保存: {label_path}") - redir_obj["label"] = label_path - else: - print(f"警告: YOLO生成的标签文件不存在: {yolo_label_path}") - else: - print(f"警告: 未找到YOLO临时输出目录") + + with open(label_path, 'w') as f: + for i in range(len(result.masks)): + class_id = result.classes[i] + score = result.scores[i] + mask = result.masks[i] + + # 获取掩码的多边形轮廓 + contours, _ = cv2.findContours((mask > 0.5).astype(np.uint8), + cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + if contours: + # 取最大的轮廓 + contour = max(contours, key=cv2.contourArea) + + # 简化轮廓 + epsilon = 0.001 * cv2.arcLength(contour, True) + approx = cv2.approxPolyDP(contour, epsilon, True) + + # 归一化坐标 + h, w = mask.shape + points = [] + for point in approx: + x = point[0][0] / w + y = point[0][1] / h + points.extend([x, y]) + + # 写入标签文件 + if len(points) >= 6: # 至少3个点 + line = f"{class_id} {' '.join(map(lambda x: f'{x:.6f}', points))} {score:.6f}\n" + f.write(line) + + print(f"标签文件已保存: {label_path}") + redir_obj["label"] = label_path result_save.append(redir_obj) - + + except PermissionError: print(f"权限错误: 无法写入到目录 {output_dir}") except Exception as e: @@ -1008,8 +1024,4 @@ def predict_images_share_dir(task_id, pt_name, zip_url, user_name, pwd, output_d else : print(f"错误: 输入 {zip_url} 不是有效的文件或目录") - return standardized_path(f"{target_path}_识别/{task_id}/{current_time}/{task_id}.zip"), "success" - -if __name__ == '__main__': - predict_images - #predict_images_share_dir(1, "road_crack", "smb://192.168.1.100/share/ai_train_platform/train.zip", "admin", "admin.123", "predictions", 0.25, True) \ No newline at end of file + return standardized_path(f"{target_path}_识别/{task_id}/{current_time}/{task_id}.zip"), "success" \ No newline at end of file