import os
import shutil
import random
from tqdm import tqdm
import yaml


def split_img(img_path, label_path, split_list, output_path, class_names=None,
              *, exts=('.jpg', '.png'), seed=None):
    """Randomly split an image/label dataset into train/val/test subsets.

    Images in *img_path* whose extension is in *exts* (case-insensitive) are
    shuffled and partitioned. Each image is copied to
    ``<output_path>/images/<subset>`` and its ``<stem>.txt`` label from
    *label_path* (when present) to ``<output_path>/labels/<subset>``; finally
    ``dataset.yaml`` is written to *output_path*.

    Args:
        img_path: Directory containing the source images.
        label_path: Directory containing ``<stem>.txt`` label files.
        split_list: ``(train, val, test)`` ratios. ``train`` is a fraction of
            all images; ``val`` and ``test`` share the remainder in the
            proportion ``val : test``. The test subset receives whatever is
            left after train and val are taken.
        output_path: Root directory of the split dataset.
        class_names: Class names written to ``dataset.yaml``. Defaults to
            ``['people', 'car', 'truck', 'bicycle', 'tricycle', 'ship']``.
        exts: Image file extensions to include, matched case-insensitively.
        seed: Optional seed for a reproducible split. When ``None`` the
            module-level ``random`` generator is used.

    Returns:
        None. If the output directories cannot be created, an error message
        is printed and nothing is copied.
    """
    if class_names is None:
        # Fresh list per call: avoids the shared mutable-default pitfall.
        class_names = ['people', 'car', 'truck', 'bicycle', 'tricycle', 'ship']

    try:
        # Create the target directory structure.
        for sub in ['images/train', 'images/val', 'images/test',
                    'labels/train', 'labels/val', 'labels/test']:
            os.makedirs(os.path.join(output_path, sub), exist_ok=True)
    except OSError as e:
        print(f'❌ 文件目录创建失败: {e}')
        return

    train, val, test = split_list
    wanted_exts = tuple(ext.lower() for ext in exts)
    # Sorted so that a given seed always yields the same split, independent
    # of the platform-dependent order returned by os.listdir.
    all_img_paths = [
        os.path.join(img_path, f)
        for f in sorted(os.listdir(img_path))
        if f.lower().endswith(wanted_exts)
    ]

    rng = random.Random(seed) if seed is not None else random
    rng.shuffle(all_img_paths)

    n_total = len(all_img_paths)
    n_train = int(train * n_total)
    n_rest = n_total - n_train
    # val and test share what train leaves behind. Guard val + test == 0
    # (an "all train" split), which would otherwise divide by zero.
    n_val = int(val / (val + test) * n_rest) if (val + test) > 0 else 0

    # One shuffle plus slicing gives disjoint subsets in O(n).
    subsets = (
        ('train', all_img_paths[:n_train]),
        ('val', all_img_paths[n_train:n_train + n_val]),
        ('test', all_img_paths[n_train + n_val:]),  # the remainder
    )
    for subset, imgs in subsets:
        move_set(imgs, label_path,
                 os.path.join(output_path, 'images', subset),
                 os.path.join(output_path, 'labels', subset))

    # Write dataset.yaml.
    generate_yaml(output_path, class_names)


def move_set(img_list, label_root, dst_img_dir, dst_label_dir):
    """Copy each image in *img_list*, plus its label file when one exists.

    For an image ``<stem>.<ext>`` the label is looked up as
    ``<label_root>/<stem>.txt``. Images go to *dst_img_dir* and labels to
    *dst_label_dir*, keeping their file names. An image with no label file is
    still copied. Despite the function name, the sources are copied with
    ``shutil.copy`` and left in place.
    """
    subset = os.path.basename(dst_img_dir)
    progress = tqdm(img_list, desc=f'Copying to {subset}', ncols=80)
    for src_img in progress:
        file_name = os.path.basename(src_img)
        stem = os.path.splitext(file_name)[0]
        label_name = stem + '.txt'
        src_label = os.path.join(label_root, label_name)

        shutil.copy(src_img, os.path.join(dst_img_dir, file_name))
        if not os.path.exists(src_label):
            continue
        shutil.copy(src_label, os.path.join(dst_label_dir, label_name))


def generate_yaml(dataset_root, class_names):
    """Write ``dataset.yaml`` describing the split dataset under *dataset_root*.

    The file lists the train/val/test image directories as paths relative to
    *dataset_root*, the class count ``nc`` and the class ``names``.

    Args:
        dataset_root: Directory the YAML file is written into.
        class_names: Iterable of class names. It is converted to a plain list
            so that a tuple is not serialized with PyYAML's Python-specific
            ``!!python/tuple`` tag, which safe loaders reject.
    """
    names = list(class_names)
    yaml_content = {
        'train': 'images/train',
        'val': 'images/val',
        'test': 'images/test',
        'nc': len(names),
        'names': names,
    }
    yaml_file = os.path.join(dataset_root, 'dataset.yaml')
    # Explicit encoding so the platform default (e.g. GBK on Chinese Windows)
    # does not decide how the file is written.
    with open(yaml_file, 'w', encoding='utf-8') as f:
        yaml.dump(yaml_content, f, default_flow_style=False)
    print(f"✅ 已生成 YAML: {yaml_file}")


def broken_main(aim_path, output_path, class_names=None, split_ratio=None):
    """Split the dataset rooted at *aim_path* into train/val/test subsets.

    Reads images from ``<aim_path>/images`` and labels from
    ``<aim_path>/labels``, then delegates to ``split_img`` to write the split
    dataset under *output_path*.

    Args:
        aim_path: Source dataset root containing ``images`` and ``labels``.
        output_path: Destination root for the split dataset.
        class_names: Class names for ``dataset.yaml``. Defaults to
            ``['people', 'car', 'truck', 'bicycle', 'tricycle', 'ship']``.
        split_ratio: ``(train, val, test)`` ratios. Defaults to
            ``[0.7, 0.2, 0.1]``.
    """
    if class_names is None:
        # Fresh list per call: avoids the shared mutable-default pitfall.
        class_names = ['people', 'car', 'truck', 'bicycle', 'tricycle', 'ship']
    if split_ratio is None:
        split_ratio = [0.7, 0.2, 0.1]

    img_path = os.path.join(aim_path, 'images')
    label_path = os.path.join(aim_path, 'labels')
    split_img(img_path, label_path, split_ratio, output_path, class_names)


if __name__ == '__main__':
    # Source dataset root (broken_main reads its `images` and `labels`
    # sub-directories) and the destination for the split dataset.
    source_root = r"D:\Users\76118\Downloads\stanford_campus_dataset\filtered"
    output_root = r"D:\work\develop\AI\数据集\output"
    # Five classes are passed here; 'ship' from the function's default list
    # is not included.
    target_classes = ['people', 'car', 'truck', 'bicycle', 'tricycle']
    broken_main(source_root, output_root, class_names=target_classes)