250 lines
8.9 KiB
Python
Raw Normal View History

2025-11-11 09:46:49 +08:00
import os
import re
import cv2
import numpy as np
import shutil
from concurrent.futures import ThreadPoolExecutor
# ------------------ 工具函数 ------------------
def clean_filename(name):
    """Normalize a file name for comparison.

    Trims both ends, deletes every whitespace run anywhere in the string
    (spaces, tabs, CR/LF, ...) and returns the lower-cased result.
    """
    return re.sub(r'[\s\r\n\t]+', '', name.strip()).lower()
def num_to_coord(num, cols, cell_width, cell_height, offset=1):
    """Map a grid-cell number to the pixel rectangle of that cell.

    The linear cell index is ``num - 1 + offset`` laid out row-major over
    ``cols`` columns. With the default ``offset=1`` the index equals ``num``.
    NOTE(review): that means cell number 1 maps to the second cell of the
    first row; confirm the annotation numbering really is 0-based.

    Returns:
        (x1, y1, x2, y2): top-left and bottom-right pixel corners.
    """
    row, col = divmod(num - 1 + offset, cols)
    left, top = col * cell_width, row * cell_height
    return left, top, left + cell_width, top + cell_height
def polygon_to_yolo(poly, img_width, img_height):
    """Flatten a polygon into normalized YOLO-segmentation coordinates.

    Args:
        poly: sequence of points, each an (x, y) pair in pixels.
        img_width: divisor for x values (even positions of the flat list).
        img_height: divisor for y values (odd positions of the flat list).

    Returns:
        ``[x1/w, y1/h, x2/w, y2/h, ...]``. Values are not clamped, so points
        outside the image produce numbers outside [0, 1].
    """
    flat = [value for vertex in poly for value in vertex]
    divisors = (img_width, img_height)
    return [value / divisors[idx % 2] for idx, value in enumerate(flat)]
def convex_hull_poly(points):
    """Return the convex hull of ``points`` as a list of [x, y] int pairs.

    An empty input short-circuits to ``[]`` without touching OpenCV.
    Otherwise the points are cast to int32 and passed to ``cv2.convexHull``.
    Its (N, 1, 2) result is flattened to a plain Python list of pairs.
    """
    if not points:
        return []
    hull = cv2.convexHull(np.array(points, dtype=np.int32))
    return hull.reshape(-1, 2).tolist()
# Class id -> drawing color, keyed 0..11 by list position. The tuples are
# passed directly to OpenCV drawing calls, which read them in B, G, R order.
color_map = dict(enumerate([
    (0, 255, 255),
    (255, 0, 255),
    (0, 255, 0),
    (255, 0, 0),
    (0, 0, 255),
    (255, 255, 0),
    (128, 128, 0),
    (128, 0, 128),
    (0, 128, 128),
    (128, 128, 128),
    (0, 0, 128),
    (0, 128, 0),
]))
# ------------------ 匹配图片 ------------------
def find_matching_image(txt_path, input_root):
    """Locate the image that belongs to an annotation TXT file.

    Matching is an exact, case-insensitive comparison of base names (it is
    not a fuzzy match). The expected image base name is derived from the TXT
    file name as follows:

    - the file name is lower-cased;
    - an optional ``_partclass`` marker and the ``.txt`` suffix are removed;
    - a leftover image extension (``.jpg``, ``.jpeg`` or ``.png``) is removed,
      so ``foo.jpg.txt``, ``foo.png.txt`` and ``foo_PartClass.txt`` all map
      to ``foo``.

    ``input_root`` is then walked recursively. The first ``.jpg``/``.jpeg``/
    ``.png`` file whose stem (lower-cased) equals that base name wins, in
    ``os.walk`` order.

    Args:
        txt_path: path of the annotation file; only its base name is used.
        input_root: directory tree searched for the image.

    Returns:
        Full path of the matching image, or None when nothing matches.
    """
    image_exts = (".jpg", ".jpeg", ".png")
    txt_name = os.path.basename(txt_path).lower()
    # Drop the optional "_partclass" marker together with the ".txt" suffix.
    base_name = re.sub(r'(_partclass)?\.txt$', '', txt_name)
    # Some tools keep the image extension in the TXT name ("foo.jpg.txt").
    # Strip it for every accepted image type, not only ".jpg".
    base_name = re.sub(r'\.(jpg|jpeg|png)$', '', base_name)
    for root, _, files in os.walk(input_root):
        for f in files:
            if f.lower().endswith(image_exts) and os.path.splitext(f)[0].lower() == base_name:
                return os.path.join(root, f)
    return None
# ------------------ 处理函数 ------------------
def process_pixel_txt(img_path, txt_path, class_map, output_root):
    """Convert a pixel-box annotation TXT into a YOLO segmentation label file.

    Each usable line of ``txt_path`` is ``x y w h label [...]``. The first four
    tokens are integers: a box corner (x, y) plus its width and height in
    pixels. The fifth token is a class name looked up in ``class_map``.
    Extra tokens are ignored. Lines with fewer than five tokens or
    non-integer geometry are skipped. Unknown class names are collected and
    reported once per file. Each box becomes a 4-point polygon, normalized by
    the image size, and is emitted as one ``<cls_id> x1 y1 ... x4 y4`` line.

    Outputs on success:
        - a copy of the image in ``output_root/images``;
        - the labels in ``output_root/labels/<stem>.txt``;
        - a preview with the boxes outlined in
          ``output_root/visual/<stem>-visual.jpg``.

    Args:
        img_path: image the annotation belongs to.
        txt_path: UTF-8 annotation file.
        class_map: mapping of class name -> integer class id.
        output_root: directory that receives the output sub-folders.

    Returns:
        True when at least one label line was written. False when the image
        cannot be read or no line yielded a box of a known class.
    """
    image = cv2.imread(img_path)
    if image is None:
        return False
    h, w = image.shape[:2]
    vis_img = image.copy()
    yolo_labels = []
    unknown_labels = set()
    with open(txt_path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) < 5:
                continue
            try:
                x, y, w_box, h_box = map(int, parts[:4])
            except ValueError:
                # Non-integer geometry (header/malformed row): skip the line.
                # Catching only ValueError no longer swallows Ctrl+C/SystemExit.
                continue
            label = parts[4]
            cls_id = class_map.get(label, -1)
            if cls_id == -1:
                unknown_labels.add(label)
                continue
            poly = [(x, y), (x + w_box, y), (x + w_box, y + h_box), (x, y + h_box)]
            hull = convex_hull_poly(poly)
            yolo_labels.append(f"{cls_id} " + " ".join(map(str, polygon_to_yolo(hull, w, h))))
            cv2.polylines(vis_img, [np.array(hull, np.int32)], True,
                          color=color_map.get(cls_id, (255, 255, 255)), thickness=2)
    if unknown_labels:
        print(f"⚠️ 未知类别 {unknown_labels} 在文件: {txt_path}")
    if not yolo_labels:
        return False
    base = os.path.splitext(os.path.basename(img_path))[0]
    # Create the folders here so the function also works when called directly.
    os.makedirs(os.path.join(output_root, "images"), exist_ok=True)
    os.makedirs(os.path.join(output_root, "labels"), exist_ok=True)
    os.makedirs(os.path.join(output_root, "visual"), exist_ok=True)
    shutil.copy2(img_path, os.path.join(output_root, "images", os.path.basename(img_path)))
    with open(os.path.join(output_root, "labels", base + ".txt"), "w", encoding="utf-8") as f:
        f.write("\n".join(yolo_labels))
    cv2.imwrite(os.path.join(output_root, "visual", base + "-visual.jpg"), vis_img)
    print(f"✅ 已处理像素点 TXT: {base}")
    return True
def process_grid_txt(img_path, txt_path, class_map, output_root,
                     cell_width=108, cell_height=102, alpha=0.5):
    """Convert a grid-cell ("_partclass") annotation TXT into YOLO labels.

    Line format:
        Each non-empty line starts with a class name from ``class_map``.
        The longest matching name wins, so a name that is a prefix of another
        cannot capture its lines. The cell numbers are the digit runs in the
        line's last whitespace-separated token that are followed by ``-`` or
        end that token (e.g. ``...1-2-15`` -> 1, 2, 15). Lines without a known
        class name or without cell numbers are skipped.

    Conversion:
        The image is treated as a grid of ``cell_width`` x ``cell_height``
        pixel cells with ``max(1, w // cell_width)`` columns, and each number
        is mapped to its cell via ``num_to_coord``. All cells of one line are
        merged into their convex hull, which becomes one YOLO polygon line.
        NOTE(review): the hull also covers any cells lying between
        non-adjacent listed cells; confirm that is acceptable.

    Outputs on success (folders are created if missing):
        - a copy of the image in ``output_root/images``;
        - ``output_root/labels/<stem>.txt``;
        - ``output_root/visual/<stem>-visual.jpg``, with hull outlines;
        - ``output_root/highlighted/<stem>-highlighted.jpg``, with the listed
          cells filled in their class color and blended over the image with
          opacity ``alpha``.

    Args:
        img_path: image the annotation belongs to.
        txt_path: UTF-8 annotation file.
        class_map: mapping of class name -> integer class id.
        output_root: directory that receives the output sub-folders.
        cell_width, cell_height: grid cell size in pixels (defaults match the
            previously hard-coded values).
        alpha: opacity of the filled cells in the highlighted image.

    Returns:
        True when at least one label line was written. False when the image
        cannot be read or no line produced a label.
    """
    image = cv2.imread(img_path)
    if image is None:
        return False
    h, w = image.shape[:2]
    cols = max(1, w // cell_width)
    vis_img = image.copy()
    overlay = image.copy()
    yolo_labels = []
    # Longest names first so that e.g. "A" cannot steal lines meant for "AB".
    names_longest_first = sorted(class_map, key=len, reverse=True)
    with open(txt_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            numbers = [int(n) for n in re.findall(r"(\d+)(?=-|$)", line.split()[-1])]
            cname = next((key for key in names_longest_first if line.startswith(key)), None)
            if cname is None or not numbers:
                continue
            cls_id = class_map[cname]
            color = color_map.get(cls_id, (128, 128, 128))
            points = []
            for num in numbers:
                x1, y1, x2, y2 = num_to_coord(num, cols, cell_width, cell_height)
                cv2.rectangle(overlay, (x1, y1), (x2, y2), color, -1)
                points.extend([(x1, y1), (x2, y1), (x2, y2), (x1, y2)])
            hull = convex_hull_poly(points)
            pts = np.array(hull, np.int32).reshape((-1, 1, 2))
            cv2.polylines(vis_img, [pts], True, color, 2)
            yolo_labels.append(f"{cls_id} " + " ".join(map(str, polygon_to_yolo(hull, w, h))))
    if not yolo_labels:
        return False
    # Blend once, after every cell is painted. Blending inside the loop
    # re-applied the cumulative overlay per line, so earlier classes ended up
    # more opaque than later ones.
    highlighted = cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0)
    for sub in ("images", "labels", "visual", "highlighted"):
        os.makedirs(os.path.join(output_root, sub), exist_ok=True)
    base = os.path.splitext(os.path.basename(img_path))[0]
    shutil.copy2(img_path, os.path.join(output_root, "images", os.path.basename(img_path)))
    with open(os.path.join(output_root, "labels", base + ".txt"), "w", encoding="utf-8") as f:
        f.write("\n".join(yolo_labels))
    cv2.imwrite(os.path.join(output_root, "visual", base + "-visual.jpg"), vis_img)
    cv2.imwrite(os.path.join(output_root, "highlighted", base + "-highlighted.jpg"), highlighted)
    return True
# ------------------ 批量处理 ------------------
def batch_process_txt_first(input_root, class_map, output_root="output", max_workers=4):
    """Convert every annotation TXT under ``input_root`` using a thread pool.

    Every ``*.txt`` found recursively is paired with its image via
    ``find_matching_image``. TXT paths containing ``_partclass``
    (case-insensitive) go to ``process_grid_txt``; all others go to
    ``process_pixel_txt``.

    A summary is printed and a log is written to
    ``output_root/process_log.txt``. The log contains the processed-file
    lines, then a "失败文件:" section listing each failure and its reason.

    NOTE(review): the ``_partclass`` test looks at the whole path, so a
    folder named ``*_partclass*`` routes all of its files to the grid
    converter. Also, outputs are named by image base name, so same-named
    images in different sub-folders overwrite each other. Confirm both are
    intended.

    Args:
        input_root: directory tree containing TXT files and images.
        class_map: mapping of class name -> integer class id.
        output_root: directory receiving images/labels/visual/highlighted.
        max_workers: number of worker threads.
    """
    for sub in ("images", "labels", "visual", "highlighted"):
        os.makedirs(os.path.join(output_root, sub), exist_ok=True)

    # Collect all TXT files.
    txt_files = [
        os.path.join(root, file)
        for root, _, files in os.walk(input_root)
        for file in files
        if file.lower().endswith(".txt")
    ]

    def process_single(txt_path):
        """Process one TXT file and return (succeeded, log_line, fail_line).

        Runs in a worker thread, so it only returns results and never touches
        shared state. Any exception, including one raised while searching for
        the image, is reported as a failure instead of being lost.
        """
        txt_name = os.path.basename(txt_path)
        try:
            img_path = find_matching_image(txt_path, input_root)
            if not img_path:
                return False, None, f"{txt_name} -> No matching image found"
            if "_partclass" in txt_path.lower():
                status = process_grid_txt(img_path, txt_path, class_map, output_root)
                log_line = f"{txt_name} -> Grid TXT processed with {os.path.basename(img_path)}"
            else:
                status = process_pixel_txt(img_path, txt_path, class_map, output_root)
                log_line = f"{txt_name} -> Pixel TXT processed with {os.path.basename(img_path)}"
        except Exception as e:  # top-level boundary: record the failure and keep going
            return False, None, f"{txt_name} -> Processing error: {e}"
        if status:
            return True, log_line, None
        return False, log_line, f"{txt_name} -> Processed but no valid labels generated"

    success_count, fail_count = 0, 0
    log_lines = []
    fail_logs = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Aggregate in the calling thread: the counters stay race-free, logs
        # keep the input order, and worker exceptions are not silently dropped.
        for ok, log_line, fail_line in executor.map(process_single, txt_files):
            if ok:
                success_count += 1
            else:
                fail_count += 1
            if log_line:
                log_lines.append(log_line)
            if fail_line:
                fail_logs.append(fail_line)

    # Write the log.
    log_file = os.path.join(output_root, "process_log.txt")
    with open(log_file, "w", encoding="utf-8") as f:
        f.write("\n".join(log_lines + ["\n失败文件:"] + fail_logs))
    print(f"\n✅ 批量处理完成: 成功 {success_count}, 失败 {fail_count}")
    if fail_logs:
        print("⚠️ 失败文件及原因如下:")
        for line in fail_logs:
            print(line)
    print(f"📄 处理日志已保存: {log_file}")
# ------------------ 主程序 ------------------
if __name__ == "__main__":
    import argparse

    # The defaults reproduce the original hard-coded run. Override them on
    # the command line instead of editing the file for another machine.
    parser = argparse.ArgumentParser(
        description="Convert pixel/grid annotation TXT files into YOLO segmentation labels."
    )
    parser.add_argument("--input-root", default=r"D:\work\develop\LF-where\01",
                        help="folder searched recursively for TXT files and images")
    parser.add_argument("--output-root", default=r"D:\work\develop\LF-where\out",
                        help="folder receiving images/labels/visual/highlighted outputs")
    parser.add_argument("--workers", type=int, default=8,
                        help="number of worker threads")
    args = parser.parse_args()

    input_root = args.input_root
    output_root = args.output_root
    # Class name (as written in the annotation files) -> YOLO class id.
    class_map = {
        "裂缝": 0,
        "横向裂缝": 1,
        "纵向裂缝": 2,
        "修补": 3,
        "坑洞": 4,
        "网裂": 5,
        "破碎板": 6,
    }
    batch_process_txt_first(input_root, class_map, output_root, max_workers=args.workers)