已经测试完成
This commit is contained in:
parent
540a20e15b
commit
173e94954d
@ -2,7 +2,6 @@ import logging
|
||||
import shutil
|
||||
import zipfile
|
||||
from os.path import exists
|
||||
|
||||
import torch
|
||||
import gc
|
||||
import os
|
||||
@ -14,9 +13,7 @@ import glob
|
||||
from typing import List, Tuple, Optional, Dict
|
||||
from ultralytics import YOLO
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from ultralytics import YOLO
|
||||
|
||||
from middleware.recognition_task import RecognitionTaskDAO, RecognitionTask
|
||||
from middleware.minio_util import upload_file, downFile, check_zip_size, upload_folder
|
||||
from util.yolo2pix_new import *
|
||||
@ -91,7 +88,6 @@ class InferenceResult:
|
||||
self.scores = [] # 置信度列表
|
||||
self.class_names = [] # 类别名称列表
|
||||
self.inference_time = 0.0 # 推理时间(秒)
|
||||
self.temp_label_path = None # YOLO临时输出目录路径
|
||||
|
||||
|
||||
class YOLOSegmentationInference:
|
||||
@ -227,7 +223,7 @@ class YOLOSegmentationInference:
|
||||
|
||||
# 使用YOLO模型进行推理
|
||||
# predictions = self.model(image_path, conf=conf_threshold, iou=iou_threshold)[0]
|
||||
predictions = self.model(original_image_rgb, conf=conf_threshold, iou=iou_threshold)[0]
|
||||
predictions = self.model(original_image_rgb, conf=conf_threshold, iou=iou_threshold, imgsz=640)[0]
|
||||
|
||||
result.inference_time = time.time() - start_time
|
||||
|
||||
@ -291,12 +287,8 @@ class YOLOSegmentationInference:
|
||||
print(f"正在处理图片: {os.path.basename(image_path)}")
|
||||
start_time = time.time()
|
||||
|
||||
# 使用YOLO模型进行推理(启用save_txt自动生成标签)
|
||||
temp_output_dir = os.path.join(os.path.dirname(image_path), ".yolo_temp")
|
||||
predictions = self.model(image_path, conf=conf_threshold, iou=iou_threshold,
|
||||
save=False, save_txt=True, project=temp_output_dir,
|
||||
name="labels", exist_ok=True)[0]
|
||||
result.temp_label_path = temp_output_dir
|
||||
# 使用YOLO模型进行推理
|
||||
predictions = self.model(image_path, conf=conf_threshold, iou=iou_threshold, imgsz=640)[0]
|
||||
|
||||
result.inference_time = time.time() - start_time
|
||||
|
||||
@ -398,7 +390,7 @@ class YOLOSegmentationInference:
|
||||
result: 推理结果
|
||||
output_dir: 输出目录
|
||||
save_mask: 是否保存单独的掩码文件
|
||||
save_label: 是否保存YOLO格式的标签文件(使用YOLO原生格式)
|
||||
save_label: 是否保存YOLO格式的标签文件
|
||||
"""
|
||||
if result.result_image is None:
|
||||
return
|
||||
@ -437,26 +429,50 @@ class YOLOSegmentationInference:
|
||||
|
||||
print(f"共保存 {len(result.masks)} 个掩码文件到: {mask_dir}")
|
||||
|
||||
# 保存YOLO格式的标签文件(直接从YOLO输出复制)
|
||||
if save_label and result.masks is not None and len(result.masks) > 0:
|
||||
# 保存YOLO格式的标签文件
|
||||
if save_label and result.masks is not None and len(result.masks) > 0 and len(result.boxes) > 0:
|
||||
# label_dir = os.path.join(output_dir, "labels")
|
||||
label_dir = output_dir
|
||||
os.makedirs(label_dir, exist_ok=True)
|
||||
|
||||
label_path = os.path.join(label_dir, f"{base_name}.txt")
|
||||
|
||||
# 从YOLO临时输出目录复制txt文件
|
||||
if hasattr(result, 'temp_label_path') and result.temp_label_path:
|
||||
yolo_label_path = os.path.join(result.temp_label_path, "labels", f"{base_name}.txt")
|
||||
if os.path.exists(yolo_label_path):
|
||||
shutil.copy(yolo_label_path, label_path)
|
||||
print(f"标签文件已保存: {label_path}")
|
||||
redir_obj["label"] = label_path
|
||||
else:
|
||||
print(f"警告: YOLO生成的标签文件不存在: {yolo_label_path}")
|
||||
else:
|
||||
print(f"警告: 未找到YOLO临时输出目录")
|
||||
|
||||
with open(label_path, 'w') as f:
|
||||
for i in range(len(result.masks)):
|
||||
class_id = result.classes[i]
|
||||
score = result.scores[i]
|
||||
mask = result.masks[i]
|
||||
|
||||
# 获取掩码的多边形轮廓
|
||||
contours, _ = cv2.findContours((mask > 0.5).astype(np.uint8),
|
||||
cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
if contours:
|
||||
# 取最大的轮廓
|
||||
contour = max(contours, key=cv2.contourArea)
|
||||
|
||||
# 简化轮廓
|
||||
epsilon = 0.001 * cv2.arcLength(contour, True)
|
||||
approx = cv2.approxPolyDP(contour, epsilon, True)
|
||||
|
||||
# 归一化坐标
|
||||
h, w = mask.shape
|
||||
points = []
|
||||
for point in approx:
|
||||
x = point[0][0] / w
|
||||
y = point[0][1] / h
|
||||
points.extend([x, y])
|
||||
|
||||
# 写入标签文件
|
||||
if len(points) >= 6: # 至少3个点
|
||||
line = f"{class_id} {' '.join(map(lambda x: f'{x:.6f}', points))} {score:.6f}\n"
|
||||
f.write(line)
|
||||
|
||||
print(f"标签文件已保存: {label_path}")
|
||||
redir_obj["label"] = label_path
|
||||
result_save.append(redir_obj)
|
||||
|
||||
|
||||
|
||||
except PermissionError:
|
||||
print(f"权限错误: 无法写入到目录 {output_dir}")
|
||||
except Exception as e:
|
||||
@ -1008,8 +1024,4 @@ def predict_images_share_dir(task_id, pt_name, zip_url, user_name, pwd, output_d
|
||||
else :
|
||||
print(f"错误: 输入 {zip_url} 不是有效的文件或目录")
|
||||
|
||||
return standardized_path(f"{target_path}_识别/{task_id}/{current_time}/{task_id}.zip"), "success"
|
||||
|
||||
if __name__ == '__main__':
|
||||
predict_images
|
||||
#predict_images_share_dir(1, "road_crack", "smb://192.168.1.100/share/ai_train_platform/train.zip", "admin", "admin.123", "predictions", 0.25, True)
|
||||
return standardized_path(f"{target_path}_识别/{task_id}/{current_time}/{task_id}.zip"), "success"
|
||||
Loading…
x
Reference in New Issue
Block a user