# 88 lines
# 3.1 KiB
# Python
# Raw Permalink Normal View History
#
# 2025-11-11 09:43:25 +08:00
import os
import shutil
import random
from tqdm import tqdm
import yaml
def split_img(img_path, label_path, split_list, output_path, class_names=None):
    """Split a flat image/label dataset into train/val/test folders.

    Images found directly inside ``img_path`` are randomly assigned to the
    three subsets, copied (together with their same-stem ``.txt`` label, when
    one exists in ``label_path``) into ``output_path/images/<subset>`` and
    ``output_path/labels/<subset>``, and finally a ``dataset.yaml`` is written
    into ``output_path``.

    Args:
        img_path: Directory containing the source images.
        label_path: Directory containing the ``.txt`` label files.
        split_list: Three ratios ``(train, val, test)``. ``train`` is the
            fraction of all images; the remainder is divided between val and
            test in the proportion ``val : test``, with every leftover image
            going to test.
        output_path: Root directory of the generated dataset.
        class_names: Class names written to ``dataset.yaml``. Defaults to
            ``['people', 'car', 'truck', 'bicycle', 'tricycle', 'ship']``.

    Returns:
        None. If the output directories cannot be created, an error message
        is printed and the function returns without copying anything.
    """
    # Build the default per call instead of sharing one mutable list object
    # across every invocation.
    if class_names is None:
        class_names = ['people', 'car', 'truck', 'bicycle', 'tricycle', 'ship']

    try:
        # Create the target directory layout.
        for sub in ['images/train', 'images/val', 'images/test',
                    'labels/train', 'labels/val', 'labels/test']:
            os.makedirs(os.path.join(output_path, sub), exist_ok=True)
    except OSError as e:
        print(f'❌ 文件目录创建失败: {e}')
        return

    train, val, test = split_list

    # Match extensions case-insensitively (e.g. "IMG_0001.JPG") and accept
    # ".jpeg". The listing is sorted because os.listdir order is arbitrary,
    # which would otherwise make a seeded split non-reproducible.
    image_exts = ('.jpg', '.jpeg', '.png')
    all_img_paths = [
        os.path.join(img_path, name)
        for name in sorted(os.listdir(img_path))
        if name.lower().endswith(image_exts)
    ]

    # Training subset.
    train_imgs = random.sample(all_img_paths, int(train * len(all_img_paths)))
    move_set(train_imgs, label_path,
             os.path.join(output_path, 'images/train'),
             os.path.join(output_path, 'labels/train'))
    # Set-based filtering keeps the remaining order and is O(n), unlike
    # calling list.remove() once per selected image.
    chosen = set(train_imgs)
    remaining = [p for p in all_img_paths if p not in chosen]

    # Validation subset. Guard against val + test == 0 (e.g. a split of
    # [1.0, 0, 0]), which would otherwise raise ZeroDivisionError.
    holdout = val + test
    val_share = val / holdout if holdout > 0 else 0.0
    val_imgs = random.sample(remaining, int(val_share * len(remaining)))
    move_set(val_imgs, label_path,
             os.path.join(output_path, 'images/val'),
             os.path.join(output_path, 'labels/val'))

    # Everything left over becomes the test subset.
    chosen = set(val_imgs)
    test_imgs = [p for p in remaining if p not in chosen]
    move_set(test_imgs, label_path,
             os.path.join(output_path, 'images/test'),
             os.path.join(output_path, 'labels/test'))

    # Write dataset.yaml.
    generate_yaml(output_path, class_names)
def move_set(img_list, label_root, dst_img_dir, dst_label_dir):
    """Copy images and their matching label files into a subset directory.

    Despite the name, files are copied, not moved; the sources are left in
    place. For every image path in ``img_list`` the image is copied into
    ``dst_img_dir``. If ``label_root`` contains a ``.txt`` file with the same
    stem as the image, it is copied into ``dst_label_dir``; images without a
    label file are still copied.

    Args:
        img_list: Paths of the images to copy.
        label_root: Directory searched for ``<stem>.txt`` label files.
        dst_img_dir: Destination directory for images.
        dst_label_dir: Destination directory for labels.
    """
    progress = tqdm(
        img_list,
        desc=f'Copying to {os.path.basename(dst_img_dir)}',
        ncols=80,
    )
    for src_img in progress:
        img_name = os.path.basename(src_img)
        stem, _ext = os.path.splitext(img_name)
        label_name = stem + '.txt'
        src_label = os.path.join(label_root, label_name)

        shutil.copy(src_img, os.path.join(dst_img_dir, img_name))

        # No label file for this image: nothing more to copy.
        if not os.path.exists(src_label):
            continue
        shutil.copy(src_label, os.path.join(dst_label_dir, label_name))
def generate_yaml(dataset_root, class_names):
    """Write ``dataset.yaml`` describing the split dataset into ``dataset_root``.

    The file contains the relative image directories for the three subsets
    (``images/train``, ``images/val``, ``images/test``), the class count
    ``nc`` and the class ``names``.

    Args:
        dataset_root: Directory in which ``dataset.yaml`` is created.
        class_names: Collection of class names; its length is written as
            ``nc``.
    """
    yaml_content = {
        'train': 'images/train',
        'val': 'images/val',
        'test': 'images/test',
        'nc': len(class_names),
        'names': class_names,
    }
    yaml_path = os.path.join(dataset_root, 'dataset.yaml')
    # safe_dump emits only standard YAML tags, so a tuple of names is written
    # as a plain sequence rather than "!!python/tuple", which safe loaders
    # reject. UTF-8 plus allow_unicode keeps non-ASCII class names readable
    # and independent of the platform's default encoding.
    with open(yaml_path, 'w', encoding='utf-8') as f:
        yaml.safe_dump(yaml_content, f, default_flow_style=False, allow_unicode=True)
    print(f"✅ 已生成 YAML: {yaml_path}")
def broken_main(aim_path, output_path, class_names=None, split_ratio=(0.7, 0.2, 0.1)):
    """Split the dataset under ``aim_path`` into train/val/test under ``output_path``.

    Expects ``aim_path`` to contain an ``images`` and a ``labels``
    sub-directory, and forwards them to ``split_img``.

    Args:
        aim_path: Root directory of the source dataset.
        output_path: Root directory for the generated split dataset.
        class_names: Class names for ``dataset.yaml``. Defaults to
            ``['people', 'car', 'truck', 'bicycle', 'tricycle', 'ship']``.
        split_ratio: ``(train, val, test)`` ratios passed to ``split_img``.
            Defaults to the previously hard-coded 70/20/10 split.
    """
    # Build the default per call instead of sharing one mutable list object
    # across every invocation.
    if class_names is None:
        class_names = ['people', 'car', 'truck', 'bicycle', 'tricycle', 'ship']

    img_path = os.path.join(aim_path, 'images')
    label_path = os.path.join(aim_path, 'labels')
    split_img(img_path, label_path, list(split_ratio), output_path, class_names)
if __name__ == '__main__':
    # Script entry point: split the filtered Stanford campus dataset using a
    # five-class list (note: no 'ship' class here, unlike the function default).
    source_root = r"D:\Users\76118\Downloads\stanford_campus_dataset\filtered"
    target_root = r"D:\work\develop\AI\数据集\output"
    target_classes = ['people', 'car', 'truck', 'bicycle', 'tricycle']
    broken_main(source_root, target_root, class_names=target_classes)