修改了识别逻辑

This commit is contained in:
yooooger 2026-03-05 15:30:21 +08:00
parent 3d6210adb8
commit b0da7002e3

View File

@ -2,7 +2,7 @@ import logging
import shutil
import zipfile
from os.path import exists
from pathlib import Path
import torch
import gc
import os
@ -385,7 +385,7 @@ class YOLOSegmentationInference:
return output_image
def save_results(self, result: InferenceResult, output_dir: str, save_mask: bool = False,
save_label: bool = False, result_save: [] = None) -> None:
save_label: bool = False, result_save: List[Dict] = None) -> None:
"""
保存推理结果
@ -508,45 +508,127 @@ class YOLOSegmentationInference:
plt.show()
def process_single_image(self, image_path: str, output_dir: Optional[str] = None,
conf_threshold: float = 0.25, iou_threshold: float = 0.5,
save_mask: bool = False, save_label: bool = False, show: bool = True,
result_save: [] = None) -> InferenceResult:
conf_threshold: float = 0.25, iou_threshold: float = 0.5,
save_mask: bool = False, save_label: bool = False, show: bool = True,
result_save: Optional[List] = None):
"""
处理单张图片
处理单张图片使用YOLO分割模型
Args:
image_path: 图片路径
output_dir: 输出目录如果为None则不保存
image_path: 输入图片路径
output_dir: 输出目录如果为None则不保存
conf_threshold: 置信度阈值
iou_threshold: IOU阈值
save_mask: 是否保存单独的掩码文件
save_label: 是否保存YOLO格式的标签文件
show: 是否显示结果
iou_threshold: NMS IoU阈值
save_mask: 是否保存单独的掩码图像PNG格式
save_label: 是否保存YOLO格式标签文件txt含多边形或边界框
show: 是否显示结果图像按任意键关闭
result_save: 可选的外部列表用于收集本次处理的信息如图像路径保存的文件等
Returns:
推理结果
Ultralytics Results 对象包含原始图像边界框掩码等信息
"""
# 执行推理
result = self.perform_inference(image_path, conf_threshold, iou_threshold)
# 1. 确保模型已加载
if not hasattr(self, 'model') or self.model is None:
raise ValueError("模型未加载请先初始化模型self.model = YOLO('...')")
# 绘制结果
if result.masks is not None and len(result.masks) > 0:
self.draw_results(result, conf_threshold)
# 2. 执行推理
results = self.model.predict(
source=image_path,
conf=conf_threshold,
iou=iou_threshold,
imgsz=640, # 可根据需要调整,或作为参数传入
save=False, # 手动保存,不依赖自动保存
save_txt=False, # 手动处理标签
retina_masks=True, # 获取与原图同分辨率的高精度掩码
)
# 单张图片,取第一个结果
result = results[0]
# 保存结果
if output_dir is not None:
self.save_results(result, output_dir, save_mask, save_label, result_save)
# 3. 准备输出目录
if output_dir:
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
base_name = Path(image_path).stem
else:
output_path = None
base_name = None
# # 显示结果
# 4. 绘制标注图像(使用 result.plot()
annotated_img = result.plot() # 返回BGR numpy数组
# 5. 保存标注图像到输出目录(如果指定)
if output_path:
img_save_path = output_path / f"{base_name}_annotated.jpg"
cv2.imwrite(str(img_save_path), annotated_img)
else:
img_save_path = None
# 6. 保存单独的掩码文件(如果 save_mask 为 True 且有掩码)
mask_save_paths = []
if save_mask and result.masks is not None and output_path:
masks = result.masks.data.cpu().numpy() # (N, H, W) 二值掩码
for i, mask in enumerate(masks):
mask_img = (mask * 255).astype(np.uint8) # 转换为0-255图像
mask_file = output_path / f"{base_name}_mask_{i}.png"
cv2.imwrite(str(mask_file), mask_img)
mask_save_paths.append(str(mask_file))
# 7. 保存YOLO格式标签文件如果 save_label 为 True
label_save_path = None
if save_label and output_path:
# 优先保存分割多边形标签masks.xy
if result.masks is not None and hasattr(result.masks, 'xy') and result.masks.xy:
label_file = output_path / f"{base_name}.txt"
with open(label_file, 'w') as f:
for i, poly in enumerate(result.masks.xy):
cls_id = int(result.boxes.cls[i]) if result.boxes is not None else 0
# 归一化多边形点坐标(除以图像宽高)
h, w = result.orig_shape
normalized = poly / [w, h]
flat_coords = normalized.flatten().tolist()
line = f"{cls_id} " + " ".join([f"{x:.6f}" for x in flat_coords])
f.write(line + "\n")
label_save_path = str(label_file)
# 如果没有掩码但有边界框则保存检测标签YOLO检测格式
elif result.boxes is not None:
label_file = output_path / f"{base_name}.txt"
with open(label_file, 'w') as f:
boxes = result.boxes.xyxyn.cpu().numpy() # 归一化边界框 (xyxy)
cls_ids = result.boxes.cls.cpu().numpy().astype(int)
for box, cls_id in zip(boxes, cls_ids):
# 转换为 YOLO 格式 (x_center, y_center, width, height)
x1, y1, x2, y2 = box
x_center = (x1 + x2) / 2
y_center = (y1 + y2) / 2
width = x2 - x1
height = y2 - y1
line = f"{cls_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
f.write(line + "\n")
label_save_path = str(label_file)
# # 8. 显示结果(如果 show 为 True
# if show:
# self.show_results(result)
# cv2.imshow("Segmentation Result", annotated_img)
# cv2.waitKey(0)
# cv2.destroyAllWindows()
# 9. 如果提供了 result_save 列表,将本次结果信息追加进去
if result_save is not None:
result_save.append({
'image_path': image_path,
'annotated_image_path': str(img_save_path) if img_save_path else None,
'mask_paths': mask_save_paths,
'label_path': label_save_path,
'num_objects': len(result.boxes) if result.boxes else 0,
})
# 10. 返回结果Ultralytics 的 Results 对象,包含所有推理信息)
return result
def process_single_image_share_dir(self, image_path, user_name, pwd, output_dir: Optional[str] = None,
conf_threshold: float = 0.25, iou_threshold: float = 0.5,
save_mask: bool = False, save_label: bool = False, show: bool = True,
result_save: [] = None) -> None:
result_save: List[Dict] = None) -> None:
"""
处理单张图片
@ -585,7 +667,7 @@ class YOLOSegmentationInference:
def process_image_directory(self, input_dir: str, output_dir: Optional[str] = None,
conf_threshold: float = 0.25, iou_threshold: float = 0.5,
save_mask: bool = False, save_label: bool = False, show: bool = False,
result_save: [] = None) -> List[
result_save: List[Dict] = None) -> List[
InferenceResult]:
"""
处理目录中的所有图片
@ -681,7 +763,7 @@ class YOLOSegmentationInference:
def process_image_directory_share_dir_circle(self, task_id, current_time, input_dir_list, user_name, pwd, output_dir: Optional[str] = None,
conf_threshold: float = 0.25, iou_threshold: float = 0.5,
save_mask: bool = False, save_label: bool = False, show: bool = False,
result_save: [] = None) -> None:
result_save: List[Dict] = None) -> None:
for input_dir in input_dir_list :
self.process_image_directory_share_dir(task_id,current_time,input_dir,user_name,pwd,output_dir,conf_threshold,iou_threshold,save_mask,save_label,show,result_save)
del_file_shutil(output_dir)
@ -689,7 +771,7 @@ class YOLOSegmentationInference:
def process_image_directory_share_dir(self, task_id, current_time, input_dir, user_name, pwd, output_dir: Optional[str] = None,
conf_threshold: float = 0.25, iou_threshold: float = 0.5,
save_mask: bool = False, save_label: bool = False, show: bool = False,
result_save: [] = None) -> None:
result_save: List[Dict] = None) -> None:
"""
处理目录中的所有图片