已经测试完成

This commit is contained in:
yooooger 2026-03-10 17:35:36 +08:00
parent 540a20e15b
commit 173e94954d

View File

@ -2,7 +2,6 @@ import logging
import shutil import shutil
import zipfile import zipfile
from os.path import exists from os.path import exists
import torch import torch
import gc import gc
import os import os
@ -14,9 +13,7 @@ import glob
from typing import List, Tuple, Optional, Dict from typing import List, Tuple, Optional, Dict
from ultralytics import YOLO from ultralytics import YOLO
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from ultralytics import YOLO from ultralytics import YOLO
from middleware.recognition_task import RecognitionTaskDAO, RecognitionTask from middleware.recognition_task import RecognitionTaskDAO, RecognitionTask
from middleware.minio_util import upload_file, downFile, check_zip_size, upload_folder from middleware.minio_util import upload_file, downFile, check_zip_size, upload_folder
from util.yolo2pix_new import * from util.yolo2pix_new import *
@ -91,7 +88,6 @@ class InferenceResult:
self.scores = [] # 置信度列表 self.scores = [] # 置信度列表
self.class_names = [] # 类别名称列表 self.class_names = [] # 类别名称列表
self.inference_time = 0.0 # 推理时间(秒) self.inference_time = 0.0 # 推理时间(秒)
self.temp_label_path = None # YOLO临时输出目录路径
class YOLOSegmentationInference: class YOLOSegmentationInference:
@ -227,7 +223,7 @@ class YOLOSegmentationInference:
# 使用YOLO模型进行推理 # 使用YOLO模型进行推理
# predictions = self.model(image_path, conf=conf_threshold, iou=iou_threshold)[0] # predictions = self.model(image_path, conf=conf_threshold, iou=iou_threshold)[0]
predictions = self.model(original_image_rgb, conf=conf_threshold, iou=iou_threshold)[0] predictions = self.model(original_image_rgb, conf=conf_threshold, iou=iou_threshold, imgsz=640)[0]
result.inference_time = time.time() - start_time result.inference_time = time.time() - start_time
@ -291,12 +287,8 @@ class YOLOSegmentationInference:
print(f"正在处理图片: {os.path.basename(image_path)}") print(f"正在处理图片: {os.path.basename(image_path)}")
start_time = time.time() start_time = time.time()
# 使用YOLO模型进行推理启用save_txt自动生成标签 # 使用YOLO模型进行推理
temp_output_dir = os.path.join(os.path.dirname(image_path), ".yolo_temp") predictions = self.model(image_path, conf=conf_threshold, iou=iou_threshold, imgsz=640)[0]
predictions = self.model(image_path, conf=conf_threshold, iou=iou_threshold,
save=False, save_txt=True, project=temp_output_dir,
name="labels", exist_ok=True)[0]
result.temp_label_path = temp_output_dir
result.inference_time = time.time() - start_time result.inference_time = time.time() - start_time
@ -398,7 +390,7 @@ class YOLOSegmentationInference:
result: 推理结果 result: 推理结果
output_dir: 输出目录 output_dir: 输出目录
save_mask: 是否保存单独的掩码文件 save_mask: 是否保存单独的掩码文件
save_label: 是否保存YOLO格式的标签文件使用YOLO原生格式 save_label: 是否保存YOLO格式的标签文件
""" """
if result.result_image is None: if result.result_image is None:
return return
@ -437,26 +429,50 @@ class YOLOSegmentationInference:
print(f"共保存 {len(result.masks)} 个掩码文件到: {mask_dir}") print(f"共保存 {len(result.masks)} 个掩码文件到: {mask_dir}")
# 保存YOLO格式的标签文件直接从YOLO输出复制 # 保存YOLO格式的标签文件
if save_label and result.masks is not None and len(result.masks) > 0: if save_label and result.masks is not None and len(result.masks) > 0 and len(result.boxes) > 0:
# label_dir = os.path.join(output_dir, "labels")
label_dir = output_dir label_dir = output_dir
os.makedirs(label_dir, exist_ok=True) os.makedirs(label_dir, exist_ok=True)
label_path = os.path.join(label_dir, f"{base_name}.txt") label_path = os.path.join(label_dir, f"{base_name}.txt")
# 从YOLO临时输出目录复制txt文件 with open(label_path, 'w') as f:
if hasattr(result, 'temp_label_path') and result.temp_label_path: for i in range(len(result.masks)):
yolo_label_path = os.path.join(result.temp_label_path, "labels", f"{base_name}.txt") class_id = result.classes[i]
if os.path.exists(yolo_label_path): score = result.scores[i]
shutil.copy(yolo_label_path, label_path) mask = result.masks[i]
# 获取掩码的多边形轮廓
contours, _ = cv2.findContours((mask > 0.5).astype(np.uint8),
cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
if contours:
# 取最大的轮廓
contour = max(contours, key=cv2.contourArea)
# 简化轮廓
epsilon = 0.001 * cv2.arcLength(contour, True)
approx = cv2.approxPolyDP(contour, epsilon, True)
# 归一化坐标
h, w = mask.shape
points = []
for point in approx:
x = point[0][0] / w
y = point[0][1] / h
points.extend([x, y])
# 写入标签文件
if len(points) >= 6: # 至少3个点
line = f"{class_id} {' '.join(map(lambda x: f'{x:.6f}', points))} {score:.6f}\n"
f.write(line)
print(f"标签文件已保存: {label_path}") print(f"标签文件已保存: {label_path}")
redir_obj["label"] = label_path redir_obj["label"] = label_path
else:
print(f"警告: YOLO生成的标签文件不存在: {yolo_label_path}")
else:
print(f"警告: 未找到YOLO临时输出目录")
result_save.append(redir_obj) result_save.append(redir_obj)
except PermissionError: except PermissionError:
print(f"权限错误: 无法写入到目录 {output_dir}") print(f"权限错误: 无法写入到目录 {output_dir}")
except Exception as e: except Exception as e:
@ -1009,7 +1025,3 @@ def predict_images_share_dir(task_id, pt_name, zip_url, user_name, pwd, output_d
print(f"错误: 输入 {zip_url} 不是有效的文件或目录") print(f"错误: 输入 {zip_url} 不是有效的文件或目录")
return standardized_path(f"{target_path}_识别/{task_id}/{current_time}/{task_id}.zip"), "success" return standardized_path(f"{target_path}_识别/{task_id}/{current_time}/{task_id}.zip"), "success"
if __name__ == '__main__':
predict_images
#predict_images_share_dir(1, "road_crack", "smb://192.168.1.100/share/ai_train_platform/train.zip", "admin", "admin.123", "predictions", 0.25, True)