修改为yolo官方提供的输出文件格式,保存标签文件时直接从yolo的临时输出目录复制txt文件,而不是重新生成标签文件。
This commit is contained in:
parent
f245d972c2
commit
540a20e15b
@ -91,6 +91,7 @@ class InferenceResult:
|
|||||||
self.scores = [] # 置信度列表
|
self.scores = [] # 置信度列表
|
||||||
self.class_names = [] # 类别名称列表
|
self.class_names = [] # 类别名称列表
|
||||||
self.inference_time = 0.0 # 推理时间(秒)
|
self.inference_time = 0.0 # 推理时间(秒)
|
||||||
|
self.temp_label_path = None # YOLO临时输出目录路径
|
||||||
|
|
||||||
|
|
||||||
class YOLOSegmentationInference:
|
class YOLOSegmentationInference:
|
||||||
@ -290,8 +291,12 @@ class YOLOSegmentationInference:
|
|||||||
print(f"正在处理图片: {os.path.basename(image_path)}")
|
print(f"正在处理图片: {os.path.basename(image_path)}")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# 使用YOLO模型进行推理
|
# 使用YOLO模型进行推理(启用save_txt自动生成标签)
|
||||||
predictions = self.model(image_path, conf=conf_threshold, iou=iou_threshold)[0]
|
temp_output_dir = os.path.join(os.path.dirname(image_path), ".yolo_temp")
|
||||||
|
predictions = self.model(image_path, conf=conf_threshold, iou=iou_threshold,
|
||||||
|
save=False, save_txt=True, project=temp_output_dir,
|
||||||
|
name="labels", exist_ok=True)[0]
|
||||||
|
result.temp_label_path = temp_output_dir
|
||||||
|
|
||||||
result.inference_time = time.time() - start_time
|
result.inference_time = time.time() - start_time
|
||||||
|
|
||||||
@ -393,7 +398,7 @@ class YOLOSegmentationInference:
|
|||||||
result: 推理结果
|
result: 推理结果
|
||||||
output_dir: 输出目录
|
output_dir: 输出目录
|
||||||
save_mask: 是否保存单独的掩码文件
|
save_mask: 是否保存单独的掩码文件
|
||||||
save_label: 是否保存YOLO格式的标签文件
|
save_label: 是否保存YOLO格式的标签文件(使用YOLO原生格式)
|
||||||
"""
|
"""
|
||||||
if result.result_image is None:
|
if result.result_image is None:
|
||||||
return
|
return
|
||||||
@ -432,50 +437,26 @@ class YOLOSegmentationInference:
|
|||||||
|
|
||||||
print(f"共保存 {len(result.masks)} 个掩码文件到: {mask_dir}")
|
print(f"共保存 {len(result.masks)} 个掩码文件到: {mask_dir}")
|
||||||
|
|
||||||
# 保存YOLO格式的标签文件
|
# 保存YOLO格式的标签文件(直接从YOLO输出复制)
|
||||||
if save_label and result.masks is not None and len(result.masks) > 0 and len(result.boxes) > 0:
|
if save_label and result.masks is not None and len(result.masks) > 0:
|
||||||
# label_dir = os.path.join(output_dir, "labels")
|
|
||||||
label_dir = output_dir
|
label_dir = output_dir
|
||||||
os.makedirs(label_dir, exist_ok=True)
|
os.makedirs(label_dir, exist_ok=True)
|
||||||
|
|
||||||
label_path = os.path.join(label_dir, f"{base_name}.txt")
|
label_path = os.path.join(label_dir, f"{base_name}.txt")
|
||||||
|
|
||||||
with open(label_path, 'w') as f:
|
# 从YOLO临时输出目录复制txt文件
|
||||||
for i in range(len(result.masks)):
|
if hasattr(result, 'temp_label_path') and result.temp_label_path:
|
||||||
class_id = result.classes[i]
|
yolo_label_path = os.path.join(result.temp_label_path, "labels", f"{base_name}.txt")
|
||||||
score = result.scores[i]
|
if os.path.exists(yolo_label_path):
|
||||||
mask = result.masks[i]
|
shutil.copy(yolo_label_path, label_path)
|
||||||
|
|
||||||
# 获取掩码的多边形轮廓
|
|
||||||
contours, _ = cv2.findContours((mask > 0.5).astype(np.uint8),
|
|
||||||
cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
|
||||||
|
|
||||||
if contours:
|
|
||||||
# 取最大的轮廓
|
|
||||||
contour = max(contours, key=cv2.contourArea)
|
|
||||||
|
|
||||||
# 简化轮廓
|
|
||||||
epsilon = 0.001 * cv2.arcLength(contour, True)
|
|
||||||
approx = cv2.approxPolyDP(contour, epsilon, True)
|
|
||||||
|
|
||||||
# 归一化坐标
|
|
||||||
h, w = mask.shape
|
|
||||||
points = []
|
|
||||||
for point in approx:
|
|
||||||
x = max(0.0, min(1.0, point[0][0] / w))
|
|
||||||
y = max(0.0, min(1.0, point[0][1] / h))
|
|
||||||
points.extend([x, y])
|
|
||||||
|
|
||||||
# 写入标签文件
|
|
||||||
if len(points) >= 6: # 至少3个点
|
|
||||||
line = f"{class_id} {' '.join(map(lambda x: f'{x:.6f}', points))}\n"
|
|
||||||
f.write(line)
|
|
||||||
|
|
||||||
print(f"标签文件已保存: {label_path}")
|
print(f"标签文件已保存: {label_path}")
|
||||||
redir_obj["label"] = label_path
|
redir_obj["label"] = label_path
|
||||||
|
else:
|
||||||
|
print(f"警告: YOLO生成的标签文件不存在: {yolo_label_path}")
|
||||||
|
else:
|
||||||
|
print(f"警告: 未找到YOLO临时输出目录")
|
||||||
result_save.append(redir_obj)
|
result_save.append(redir_obj)
|
||||||
|
|
||||||
|
|
||||||
except PermissionError:
|
except PermissionError:
|
||||||
print(f"权限错误: 无法写入到目录 {output_dir}")
|
print(f"权限错误: 无法写入到目录 {output_dir}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user