import os from glob import glob from aboutdataset.download_oss import download_and_save_images_from_oss from train.let_txt_to_true import process_files_in_folder from train.broken import broken_main from ultralytics import YOLO import torch # ------------------ 下载图片和标签 ------------------ def download_images_and_labels( config_name, # OSS 配置文件名,用于读取连接信息 table_name, # OSS 表名,指定下载数据的表 column_name, # OSS 表中图片 URL 列名 search_condition, # 筛选条件,用于查询 OSS 数据 aim_path, # 本地保存数据集根目录 image_dir, # 本地保存图片的目录 label_dir # 本地保存标签 txt 的目录 ): os.makedirs(aim_path, exist_ok=True) os.makedirs(image_dir, exist_ok=True) os.makedirs(label_dir, exist_ok=True) download_and_save_images_from_oss( yaml_name=config_name, where_clause=f"{column_name} = '{search_condition}'", image_dir=image_dir, label_dir=label_dir, table_name=table_name, ) return aim_path, image_dir, label_dir # ------------------ 标签修正与数据打乱 ------------------ def broken_and_convert_txt_to_yolo_format( aim_path, # 数据集根目录 output_path, # 打乱并输出后的数据集目录 image_dir, # 图片目录 label_dir, # 标签目录 class_names # 数据集类别列表 ): process_files_in_folder(label_dir) # 修正标签为 YOLO 格式 broken_main(aim_path, output_path, class_names) # 打乱数据集并生成 dataset.yaml yaml_path = os.path.join(output_path, 'dataset.yaml') return output_path, yaml_path # ------------------ 获取最新 pt 模型 ------------------ def get_latest_pt(project_dir, pt_path): """ 检查指定训练输出目录是否有最新 .pt 模型文件。 若存在则返回最新文件路径,否则返回传入的 pt_path。 """ if not os.path.exists(project_dir): print(f"[INFO] 项目目录 {project_dir} 不存在,使用传入模型 {pt_path}") return pt_path pt_files = glob(os.path.join(project_dir, "*.pt")) if not pt_files: print(f"[INFO] 目录中无 pt 文件,使用传入模型 {pt_path}") return pt_path latest_pt = max(pt_files, key=os.path.getmtime) print(f"[INFO] 检测到最新模型: {latest_pt}") return latest_pt # ------------------ 训练 ------------------ def train( yaml_path, # YOLO 数据集配置文件路径 pt_path, # 用于训练的初始权重 .pt 文件路径 imgsz, # 输入图片分辨率 epochs, # 训练轮次 device, # GPU 设备索引列表,例如 [0] 或 [0,1] hsv_v, # 图像亮度增强系数 cos_lr, # 是否使用余弦学习率 batch, # 批量大小 project_dir # 训练输出目录(模型权重、日志等) ): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"[INFO] Using device: {device}") pt_path = get_latest_pt(project_dir, pt_path) # 自动检测最新 pt 文件 model = YOLO(pt_path).to(device) model.train( data=yaml_path, epochs=epochs, imgsz=imgsz, device=device, hsv_v=hsv_v, cos_lr=cos_lr, batch=batch, project=project_dir, ) # ------------------ 主流程 ------------------ def train_main( # OSS 下载参数 config_name, # sql 配置文件名 table_name, # sql 表名 column_name, # sql 表中列名 search_condition, # sql 数据筛选条件 # 数据集路径 aim_path, # 本地数据集根目录,打乱后的 image_dir, # 本地图片保存目录 label_dir, # 本地标签保存目录 output_path, # 打乱并输出后的数据集目录 # YOLO 训练参数 pt_path, # 初始权重文件路径 imgsz, # 输入图片分辨率 epochs, # 训练轮次 device, # GPU 设备索引列表 hsv_v, # 图像亮度增强系数 cos_lr, # 是否使用余弦学习率 batch, # 批量大小 project_dir, # 训练输出目录 # 类别 class_names # 数据集类别列表 ): aim_path, image_dir, label_dir = download_images_and_labels( config_name, table_name, column_name, search_condition, aim_path, image_dir, label_dir ) output_path, yaml_path = broken_and_convert_txt_to_yolo_format( aim_path, output_path, image_dir, label_dir, class_names ) train( yaml_path=yaml_path, pt_path=pt_path, imgsz=imgsz, epochs=epochs, device=device, hsv_v=hsv_v, cos_lr=cos_lr, batch=batch, project_dir=project_dir ) # ------------------ 执行 ------------------ if __name__ == "__main__": train_main( config_name="config", table_name="aidataset", column_name="image_url", search_condition="your_search_id", aim_path="./datasets/aidataset_dataset", image_dir="./dataset/aidataset_dataset_images", label_dir="./dataset/aidataset_dataset_labels", output_path="./my_dataset", pt_path="custom_model.pt", imgsz=800, epochs=500, device=[0], hsv_v=0.3, cos_lr=True, batch=8, project_dir="./my_train_runs", class_names=['person','car'] )