已经测试完成
This commit is contained in:
parent
540a20e15b
commit
173e94954d
@ -2,7 +2,6 @@ import logging
|
|||||||
import shutil
|
import shutil
|
||||||
import zipfile
|
import zipfile
|
||||||
from os.path import exists
|
from os.path import exists
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import gc
|
import gc
|
||||||
import os
|
import os
|
||||||
@ -14,9 +13,7 @@ import glob
|
|||||||
from typing import List, Tuple, Optional, Dict
|
from typing import List, Tuple, Optional, Dict
|
||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
|
|
||||||
from middleware.recognition_task import RecognitionTaskDAO, RecognitionTask
|
from middleware.recognition_task import RecognitionTaskDAO, RecognitionTask
|
||||||
from middleware.minio_util import upload_file, downFile, check_zip_size, upload_folder
|
from middleware.minio_util import upload_file, downFile, check_zip_size, upload_folder
|
||||||
from util.yolo2pix_new import *
|
from util.yolo2pix_new import *
|
||||||
@ -91,7 +88,6 @@ class InferenceResult:
|
|||||||
self.scores = [] # 置信度列表
|
self.scores = [] # 置信度列表
|
||||||
self.class_names = [] # 类别名称列表
|
self.class_names = [] # 类别名称列表
|
||||||
self.inference_time = 0.0 # 推理时间(秒)
|
self.inference_time = 0.0 # 推理时间(秒)
|
||||||
self.temp_label_path = None # YOLO临时输出目录路径
|
|
||||||
|
|
||||||
|
|
||||||
class YOLOSegmentationInference:
|
class YOLOSegmentationInference:
|
||||||
@ -227,7 +223,7 @@ class YOLOSegmentationInference:
|
|||||||
|
|
||||||
# 使用YOLO模型进行推理
|
# 使用YOLO模型进行推理
|
||||||
# predictions = self.model(image_path, conf=conf_threshold, iou=iou_threshold)[0]
|
# predictions = self.model(image_path, conf=conf_threshold, iou=iou_threshold)[0]
|
||||||
predictions = self.model(original_image_rgb, conf=conf_threshold, iou=iou_threshold)[0]
|
predictions = self.model(original_image_rgb, conf=conf_threshold, iou=iou_threshold, imgsz=640)[0]
|
||||||
|
|
||||||
result.inference_time = time.time() - start_time
|
result.inference_time = time.time() - start_time
|
||||||
|
|
||||||
@ -291,12 +287,8 @@ class YOLOSegmentationInference:
|
|||||||
print(f"正在处理图片: {os.path.basename(image_path)}")
|
print(f"正在处理图片: {os.path.basename(image_path)}")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# 使用YOLO模型进行推理(启用save_txt自动生成标签)
|
# 使用YOLO模型进行推理
|
||||||
temp_output_dir = os.path.join(os.path.dirname(image_path), ".yolo_temp")
|
predictions = self.model(image_path, conf=conf_threshold, iou=iou_threshold, imgsz=640)[0]
|
||||||
predictions = self.model(image_path, conf=conf_threshold, iou=iou_threshold,
|
|
||||||
save=False, save_txt=True, project=temp_output_dir,
|
|
||||||
name="labels", exist_ok=True)[0]
|
|
||||||
result.temp_label_path = temp_output_dir
|
|
||||||
|
|
||||||
result.inference_time = time.time() - start_time
|
result.inference_time = time.time() - start_time
|
||||||
|
|
||||||
@ -398,7 +390,7 @@ class YOLOSegmentationInference:
|
|||||||
result: 推理结果
|
result: 推理结果
|
||||||
output_dir: 输出目录
|
output_dir: 输出目录
|
||||||
save_mask: 是否保存单独的掩码文件
|
save_mask: 是否保存单独的掩码文件
|
||||||
save_label: 是否保存YOLO格式的标签文件(使用YOLO原生格式)
|
save_label: 是否保存YOLO格式的标签文件
|
||||||
"""
|
"""
|
||||||
if result.result_image is None:
|
if result.result_image is None:
|
||||||
return
|
return
|
||||||
@ -437,26 +429,50 @@ class YOLOSegmentationInference:
|
|||||||
|
|
||||||
print(f"共保存 {len(result.masks)} 个掩码文件到: {mask_dir}")
|
print(f"共保存 {len(result.masks)} 个掩码文件到: {mask_dir}")
|
||||||
|
|
||||||
# 保存YOLO格式的标签文件(直接从YOLO输出复制)
|
# 保存YOLO格式的标签文件
|
||||||
if save_label and result.masks is not None and len(result.masks) > 0:
|
if save_label and result.masks is not None and len(result.masks) > 0 and len(result.boxes) > 0:
|
||||||
|
# label_dir = os.path.join(output_dir, "labels")
|
||||||
label_dir = output_dir
|
label_dir = output_dir
|
||||||
os.makedirs(label_dir, exist_ok=True)
|
os.makedirs(label_dir, exist_ok=True)
|
||||||
|
|
||||||
label_path = os.path.join(label_dir, f"{base_name}.txt")
|
label_path = os.path.join(label_dir, f"{base_name}.txt")
|
||||||
|
|
||||||
# 从YOLO临时输出目录复制txt文件
|
with open(label_path, 'w') as f:
|
||||||
if hasattr(result, 'temp_label_path') and result.temp_label_path:
|
for i in range(len(result.masks)):
|
||||||
yolo_label_path = os.path.join(result.temp_label_path, "labels", f"{base_name}.txt")
|
class_id = result.classes[i]
|
||||||
if os.path.exists(yolo_label_path):
|
score = result.scores[i]
|
||||||
shutil.copy(yolo_label_path, label_path)
|
mask = result.masks[i]
|
||||||
|
|
||||||
|
# 获取掩码的多边形轮廓
|
||||||
|
contours, _ = cv2.findContours((mask > 0.5).astype(np.uint8),
|
||||||
|
cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||||
|
|
||||||
|
if contours:
|
||||||
|
# 取最大的轮廓
|
||||||
|
contour = max(contours, key=cv2.contourArea)
|
||||||
|
|
||||||
|
# 简化轮廓
|
||||||
|
epsilon = 0.001 * cv2.arcLength(contour, True)
|
||||||
|
approx = cv2.approxPolyDP(contour, epsilon, True)
|
||||||
|
|
||||||
|
# 归一化坐标
|
||||||
|
h, w = mask.shape
|
||||||
|
points = []
|
||||||
|
for point in approx:
|
||||||
|
x = point[0][0] / w
|
||||||
|
y = point[0][1] / h
|
||||||
|
points.extend([x, y])
|
||||||
|
|
||||||
|
# 写入标签文件
|
||||||
|
if len(points) >= 6: # 至少3个点
|
||||||
|
line = f"{class_id} {' '.join(map(lambda x: f'{x:.6f}', points))} {score:.6f}\n"
|
||||||
|
f.write(line)
|
||||||
|
|
||||||
print(f"标签文件已保存: {label_path}")
|
print(f"标签文件已保存: {label_path}")
|
||||||
redir_obj["label"] = label_path
|
redir_obj["label"] = label_path
|
||||||
else:
|
|
||||||
print(f"警告: YOLO生成的标签文件不存在: {yolo_label_path}")
|
|
||||||
else:
|
|
||||||
print(f"警告: 未找到YOLO临时输出目录")
|
|
||||||
result_save.append(redir_obj)
|
result_save.append(redir_obj)
|
||||||
|
|
||||||
|
|
||||||
except PermissionError:
|
except PermissionError:
|
||||||
print(f"权限错误: 无法写入到目录 {output_dir}")
|
print(f"权限错误: 无法写入到目录 {output_dir}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -1009,7 +1025,3 @@ def predict_images_share_dir(task_id, pt_name, zip_url, user_name, pwd, output_d
|
|||||||
print(f"错误: 输入 {zip_url} 不是有效的文件或目录")
|
print(f"错误: 输入 {zip_url} 不是有效的文件或目录")
|
||||||
|
|
||||||
return standardized_path(f"{target_path}_识别/{task_id}/{current_time}/{task_id}.zip"), "success"
|
return standardized_path(f"{target_path}_识别/{task_id}/{current_time}/{task_id}.zip"), "success"
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
predict_images
|
|
||||||
#predict_images_share_dir(1, "road_crack", "smb://192.168.1.100/share/ai_train_platform/train.zip", "admin", "admin.123", "predictions", 0.25, True)
|
|
||||||
Loading…
x
Reference in New Issue
Block a user