import os
import shutil
import random
from tqdm import tqdm
import yaml

# Default class list. A tuple so the shared default can never be mutated in place.
DEFAULT_CLASS_NAMES = ('people', 'car', 'truck', 'bicycle', 'tricycle', 'ship')

# Image extensions picked up by split_img (matched case-insensitively).
DEFAULT_IMG_EXTS = ('.jpg', '.png')


def _ratio_count(fraction, total):
    """Return floor(fraction * total), tolerating floating-point noise.

    Plain int() truncation undercounts when the product lands just below an
    integer, e.g. 0.29 * 100 == 28.999999999999996 or
    0.2 / (0.2 + 0.1) * 30 == 19.999999999999996.
    """
    return int(fraction * total + 1e-9)


def split_img(img_path, label_path, split_list, output_path,
              class_names=DEFAULT_CLASS_NAMES, *, seed=None,
              img_exts=DEFAULT_IMG_EXTS):
    """Randomly split an image/label dataset into train/val/test and write dataset.yaml.

    Files directly inside ``img_path`` whose name ends with one of ``img_exts``
    (case-insensitive) are shuffled and copied into
    ``output_path/images/{train,val,test}``. The label ``<stem>.txt`` from
    ``label_path`` is copied to ``output_path/labels/{train,val,test}`` when it
    exists. Images without a label are still copied, just without a label file.

    NOTE: the output folders are created but never cleared, so files from a
    previous run with a different split are left in place.

    Args:
        img_path: Directory containing the source images.
        label_path: Directory containing per-image ``.txt`` label files.
        split_list: ``(train, val, test)`` ratios. ``train`` is a fraction of all
            images. The remainder is divided between val and test in the
            proportion ``val : test``. Images left over by rounding, or all of
            the remainder when ``val + test == 0``, go to test.
        output_path: Root directory of the generated dataset.
        class_names: Class names written to ``dataset.yaml``.
        seed: Optional seed for a reproducible split. ``None`` uses the global
            ``random`` state, as before.
        img_exts: Image file extensions to include (matched case-insensitively).

    Returns:
        None. If the output directories cannot be created, an error is printed
        and the function returns without copying anything.

    Raises:
        ValueError: If ``split_list`` does not hold exactly three values, any
            ratio is negative, or ``train`` exceeds 1.
    """
    train, val, test = split_list
    if min(train, val, test) < 0 or train > 1:
        raise ValueError(
            f'invalid split ratios {split_list!r}: '
            'values must be non-negative and train must be <= 1')

    # Create the target directory layout.
    try:
        for sub in ['images/train', 'images/val', 'images/test',
                    'labels/train', 'labels/val', 'labels/test']:
            os.makedirs(os.path.join(output_path, sub), exist_ok=True)
    except OSError as e:
        print(f'❌ 文件目录创建失败: {e}')
        return

    exts = tuple(ext.lower() for ext in img_exts)
    # Sorted so that, for a given seed, the split does not depend on the
    # arbitrary order returned by os.listdir.
    all_img_paths = sorted(
        os.path.join(img_path, f)
        for f in os.listdir(img_path)
        if f.lower().endswith(exts)
    )

    rng = random if seed is None else random.Random(seed)
    # One shuffle plus slicing replaces repeated sample() + list.remove(), which was O(n^2).
    rng.shuffle(all_img_paths)

    total = len(all_img_paths)
    n_train = _ratio_count(train, total)
    remaining = total - n_train
    # val's share of what is left after train. With val == test == 0,
    # everything left goes to test instead of dividing by zero.
    n_val = _ratio_count(val / (val + test), remaining) if val + test > 0 else 0

    splits = (
        ('train', all_img_paths[:n_train]),
        ('val', all_img_paths[n_train:n_train + n_val]),
        ('test', all_img_paths[n_train + n_val:]),  # remainder -> test
    )
    for split_name, imgs in splits:
        move_set(imgs, label_path,
                 os.path.join(output_path, 'images', split_name),
                 os.path.join(output_path, 'labels', split_name))

    # Write dataset.yaml describing the split.
    generate_yaml(output_path, class_names)


def move_set(img_list, label_root, dst_img_dir, dst_label_dir):
    """Copy each image, and its same-stem ``.txt`` label if present, into the destination dirs.

    Args:
        img_list: Paths of the images to copy.
        label_root: Directory searched for ``<image stem>.txt`` labels.
        dst_img_dir: Destination directory for the images.
        dst_label_dir: Destination directory for the labels.
    """
    desc = f'Copying to {os.path.basename(dst_img_dir)}'
    for img_file in tqdm(img_list, desc=desc, ncols=80):
        img_name = os.path.basename(img_file)
        stem = os.path.splitext(img_name)[0]
        shutil.copy(img_file, os.path.join(dst_img_dir, img_name))
        label_file = os.path.join(label_root, stem + '.txt')
        # A missing label is tolerated: the image is still copied without one.
        if os.path.isfile(label_file):
            shutil.copy(label_file, os.path.join(dst_label_dir, stem + '.txt'))


def generate_yaml(dataset_root, class_names):
    """Write ``dataset.yaml`` into ``dataset_root`` and print its path.

    The split paths are written relative (``images/train`` etc.), together with
    the class count ``nc`` and the class ``names``.
    """
    yaml_content = {
        'train': 'images/train',
        'val': 'images/val',
        'test': 'images/test',
        'nc': len(class_names),
        # list() so a tuple is dumped as a plain YAML sequence rather than a
        # Python-specific !!python/tuple tag.
        'names': list(class_names),
    }
    yaml_file = os.path.join(dataset_root, 'dataset.yaml')
    with open(yaml_file, 'w', encoding='utf-8') as f:
        yaml.dump(yaml_content, f, default_flow_style=False)
    print(f"✅ 已生成 YAML: {yaml_file}")


def broken_main(aim_path, output_path, class_names=DEFAULT_CLASS_NAMES):
    """Split ``aim_path/images`` + ``aim_path/labels`` 70/20/10 into ``output_path``."""
    img_path = os.path.join(aim_path, 'images')
    label_path = os.path.join(aim_path, 'labels')
    split_ratio = [0.7, 0.2, 0.1]
    split_img(img_path, label_path, split_ratio, output_path, class_names)


if __name__ == '__main__':
    broken_main(
        r"D:\Users\76118\Downloads\stanford_campus_dataset\filtered",
        r"D:\work\develop\AI\数据集\output",
        class_names=[ 'people', 'car', 'truck', 'bicycle', 'tricycle',]
    )