2025-11-11 09:43:25 +08:00

167 lines
5.5 KiB
Python

import os
from glob import glob
from aboutdataset.download_oss import download_and_save_images_from_oss
from train.let_txt_to_true import process_files_in_folder
from train.broken import broken_main
from ultralytics import YOLO
import torch
# ------------------ 下载图片和标签 ------------------
def download_images_and_labels(
config_name, # OSS 配置文件名,用于读取连接信息
table_name, # OSS 表名,指定下载数据的表
column_name, # OSS 表中图片 URL 列名
search_condition, # 筛选条件,用于查询 OSS 数据
aim_path, # 本地保存数据集根目录
image_dir, # 本地保存图片的目录
label_dir # 本地保存标签 txt 的目录
):
os.makedirs(aim_path, exist_ok=True)
os.makedirs(image_dir, exist_ok=True)
os.makedirs(label_dir, exist_ok=True)
download_and_save_images_from_oss(
yaml_name=config_name,
where_clause=f"{column_name} = '{search_condition}'",
image_dir=image_dir,
label_dir=label_dir,
table_name=table_name,
)
return aim_path, image_dir, label_dir
# ------------------ 标签修正与数据打乱 ------------------
def broken_and_convert_txt_to_yolo_format(
aim_path, # 数据集根目录
output_path, # 打乱并输出后的数据集目录
image_dir, # 图片目录
label_dir, # 标签目录
class_names # 数据集类别列表
):
process_files_in_folder(label_dir) # 修正标签为 YOLO 格式
broken_main(aim_path, output_path, class_names) # 打乱数据集并生成 dataset.yaml
yaml_path = os.path.join(output_path, 'dataset.yaml')
return output_path, yaml_path
# ------------------ 获取最新 pt 模型 ------------------
def get_latest_pt(project_dir, pt_path):
"""
检查指定训练输出目录是否有最新 .pt 模型文件。
若存在则返回最新文件路径,否则返回传入的 pt_path。
"""
if not os.path.exists(project_dir):
print(f"[INFO] 项目目录 {project_dir} 不存在,使用传入模型 {pt_path}")
return pt_path
pt_files = glob(os.path.join(project_dir, "*.pt"))
if not pt_files:
print(f"[INFO] 目录中无 pt 文件,使用传入模型 {pt_path}")
return pt_path
latest_pt = max(pt_files, key=os.path.getmtime)
print(f"[INFO] 检测到最新模型: {latest_pt}")
return latest_pt
# ------------------ 训练 ------------------
def train(
yaml_path, # YOLO 数据集配置文件路径
pt_path, # 用于训练的初始权重 .pt 文件路径
imgsz, # 输入图片分辨率
epochs, # 训练轮次
device, # GPU 设备索引列表,例如 [0] 或 [0,1]
hsv_v, # 图像亮度增强系数
cos_lr, # 是否使用余弦学习率
batch, # 批量大小
project_dir # 训练输出目录(模型权重、日志等)
):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {device}")
pt_path = get_latest_pt(project_dir, pt_path) # 自动检测最新 pt 文件
model = YOLO(pt_path).to(device)
model.train(
data=yaml_path,
epochs=epochs,
imgsz=imgsz,
device=device,
hsv_v=hsv_v,
cos_lr=cos_lr,
batch=batch,
project=project_dir,
)
# ------------------ 主流程 ------------------
def train_main(
# OSS 下载参数
config_name, # sql 配置文件名
table_name, # sql 表名
column_name, # sql 表中列名
search_condition, # sql 数据筛选条件
# 数据集路径
aim_path, # 本地数据集根目录,打乱后的
image_dir, # 本地图片保存目录
label_dir, # 本地标签保存目录
output_path, # 打乱并输出后的数据集目录
# YOLO 训练参数
pt_path, # 初始权重文件路径
imgsz, # 输入图片分辨率
epochs, # 训练轮次
device, # GPU 设备索引列表
hsv_v, # 图像亮度增强系数
cos_lr, # 是否使用余弦学习率
batch, # 批量大小
project_dir, # 训练输出目录
# 类别
class_names # 数据集类别列表
):
aim_path, image_dir, label_dir = download_images_and_labels(
config_name, table_name, column_name, search_condition,
aim_path, image_dir, label_dir
)
output_path, yaml_path = broken_and_convert_txt_to_yolo_format(
aim_path, output_path, image_dir, label_dir, class_names
)
train(
yaml_path=yaml_path,
pt_path=pt_path,
imgsz=imgsz,
epochs=epochs,
device=device,
hsv_v=hsv_v,
cos_lr=cos_lr,
batch=batch,
project_dir=project_dir
)
# ------------------ 执行 ------------------
if __name__ == "__main__":
train_main(
config_name="config",
table_name="aidataset",
column_name="image_url",
search_condition="your_search_id",
aim_path="./datasets/aidataset_dataset",
image_dir="./dataset/aidataset_dataset_images",
label_dir="./dataset/aidataset_dataset_labels",
output_path="./my_dataset",
pt_path="custom_model.pt",
imgsz=800,
epochs=500,
device=[0],
hsv_v=0.3,
cos_lr=True,
batch=8,
project_dir="./my_train_runs",
class_names=['person','car']
)