2025-07-10 09:41:26 +08:00
|
|
|
|
import os
|
2025-11-11 09:43:25 +08:00
|
|
|
|
from glob import glob
|
2025-10-09 09:29:18 +08:00
|
|
|
|
from aboutdataset.download_oss import download_and_save_images_from_oss
|
2025-11-11 09:43:25 +08:00
|
|
|
|
from train.let_txt_to_true import process_files_in_folder
|
|
|
|
|
|
from train.broken import broken_main
|
|
|
|
|
|
from ultralytics import YOLO
|
|
|
|
|
|
import torch
|
2025-09-26 18:03:49 +08:00
|
|
|
|
|
|
|
|
|
|
|
2025-11-11 09:43:25 +08:00
|
|
|
|
# ------------------ 下载图片和标签 ------------------
|
|
|
|
|
|
def download_images_and_labels(
    config_name,        # name of the OSS/SQL connection config file (passed as yaml_name)
    table_name,         # table to query
    column_name,        # column compared against search_condition in the WHERE clause
    search_condition,   # value the column must equal
    aim_path,           # local dataset root directory
    image_dir,          # local directory for downloaded images
    label_dir,          # local directory for downloaded label txt files
):
    """Create the local folders and download the matching images/labels from OSS.

    The rows selected are those where ``column_name`` equals ``search_condition``.

    Returns:
        The tuple ``(aim_path, image_dir, label_dir)`` unchanged, so the caller
        can chain it into the next pipeline step.
    """
    for folder in (aim_path, image_dir, label_dir):
        os.makedirs(folder, exist_ok=True)

    # The download helper only accepts a raw WHERE string, so the value cannot be
    # bound as a query parameter. Double embedded single quotes (standard SQL
    # string-literal escaping) so a value such as "O'Brien" yields a valid literal
    # instead of breaking or injecting into the query.
    # NOTE(review): MySQL additionally treats backslash as an escape character by
    # default; if the backend is MySQL, backslashes in the value are not neutralised.
    quoted_value = str(search_condition).replace("'", "''")
    where_clause = f"{column_name} = '{quoted_value}'"

    download_and_save_images_from_oss(
        yaml_name=config_name,
        where_clause=where_clause,
        image_dir=image_dir,
        label_dir=label_dir,
        table_name=table_name,
    )

    return aim_path, image_dir, label_dir
|
2025-09-26 18:03:49 +08:00
|
|
|
|
|
|
|
|
|
|
|
2025-11-11 09:43:25 +08:00
|
|
|
|
# ------------------ 标签修正与数据打乱 ------------------
|
|
|
|
|
|
def broken_and_convert_txt_to_yolo_format(
    aim_path,      # dataset root handed to the shuffle step as its input
    output_path,   # directory the shuffled dataset is written to
    image_dir,     # image directory; accepted for interface compatibility, not used here
    label_dir,     # directory holding the downloaded label txt files
    class_names,   # list of class names for the dataset
):
    """Post-process the downloaded labels and build the shuffled dataset.

    Runs ``process_files_in_folder`` on ``label_dir`` (the original notes say this
    rewrites the labels into YOLO format), then ``broken_main`` to shuffle the
    dataset into ``output_path`` (which, per the original notes, also generates
    ``dataset.yaml`` there).

    Returns:
        ``(output_path, path_to_dataset_yaml)`` where the yaml path is
        ``output_path/dataset.yaml``.
    """
    # Step 1: normalise the label files in place.
    process_files_in_folder(label_dir)

    # Step 2: shuffle/split into output_path.
    broken_main(aim_path, output_path, class_names)

    dataset_yaml = os.path.join(output_path, 'dataset.yaml')
    return output_path, dataset_yaml
|
2025-10-09 09:29:18 +08:00
|
|
|
|
|
2025-09-26 18:03:49 +08:00
|
|
|
|
|
2025-11-11 09:43:25 +08:00
|
|
|
|
# ------------------ 获取最新 pt 模型 ------------------
|
|
|
|
|
|
def get_latest_pt(project_dir, pt_path):
    """Return the most recently modified ``.pt`` file under *project_dir*.

    Ultralytics saves checkpoints under ``<project>/<run name>/weights/`` (e.g.
    ``last.pt`` / ``best.pt``), so the search is recursive. ``.pt`` files placed
    directly inside *project_dir* are still found.

    Args:
        project_dir: training output directory to search.
        pt_path: fallback weights path, used when nothing is found.

    Returns:
        Path of the newest ``.pt`` file (by modification time), or *pt_path* if
        the directory does not exist or contains no ``.pt`` files.
    """
    if not os.path.exists(project_dir):
        print(f"[INFO] 项目目录 {project_dir} 不存在,使用传入模型 {pt_path}")
        return pt_path

    # With recursive=True, "**" also matches zero directories, so top-level
    # files are included alongside the nested run/weights checkpoints.
    pt_files = glob(os.path.join(project_dir, "**", "*.pt"), recursive=True)
    if not pt_files:
        print(f"[INFO] 目录中无 pt 文件,使用传入模型 {pt_path}")
        return pt_path

    latest_pt = max(pt_files, key=os.path.getmtime)
    print(f"[INFO] 检测到最新模型: {latest_pt}")
    return latest_pt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ------------------ 训练 ------------------
|
|
|
|
|
|
def train(
    yaml_path,     # path to the YOLO dataset config file
    pt_path,       # initial weights .pt file
    imgsz,         # input image size
    epochs,        # number of training epochs
    device,        # GPU index list such as [0] or [0, 1]; "cpu" or None also accepted
    hsv_v,         # brightness (HSV value) augmentation factor
    cos_lr,        # whether to use a cosine learning-rate schedule
    batch,         # batch size
    project_dir,   # training output directory (weights, logs, ...)
):
    """Train a YOLO model on ``yaml_path``.

    If ``project_dir`` already contains ``.pt`` checkpoints, the newest one is
    used instead of ``pt_path`` (see ``get_latest_pt``).

    Device selection: the caller's ``device`` is honoured when CUDA is available
    (previously it was silently replaced by a single "cuda" device). When CUDA is
    unavailable, training falls back to the CPU. When ``device`` is None or
    empty, any CUDA device is used, as before.

    Returns:
        The value returned by ``YOLO.train`` (the original function returned
        ``None``; callers ignoring the result are unaffected).
    """
    if not torch.cuda.is_available():
        # No GPU visible: requested indices cannot be honoured.
        train_device = "cpu"
    elif device is None or (isinstance(device, (str, list, tuple)) and not device):
        # Nothing requested: keep the previous default of "any CUDA device".
        # (Explicit checks so that GPU index 0 is not treated as "empty".)
        train_device = torch.device("cuda")
    else:
        # Respect the caller's choice, e.g. [0, 1] for multi-GPU training.
        train_device = device
    print(f"[INFO] Using device: {train_device}")

    pt_path = get_latest_pt(project_dir, pt_path)  # resume from newest checkpoint if any

    # No explicit .to(): the Ultralytics trainer places the model on the device(s)
    # given to train() itself, and nn.Module.to() would reject a list of indices.
    model = YOLO(pt_path)

    return model.train(
        data=yaml_path,
        epochs=epochs,
        imgsz=imgsz,
        device=train_device,
        hsv_v=hsv_v,
        cos_lr=cos_lr,
        batch=batch,
        project=project_dir,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
2025-11-11 09:43:25 +08:00
|
|
|
|
# ------------------ 主流程 ------------------
|
|
|
|
|
|
def train_main(
    # --- SQL / OSS download parameters ---
    config_name,        # SQL/OSS connection config file name
    table_name,         # table to query
    column_name,        # column compared against search_condition
    search_condition,   # value the column must equal
    # --- dataset paths ---
    aim_path,           # local dataset root, handed to the shuffle step as its input
    image_dir,          # where downloaded images are saved
    label_dir,          # where downloaded label txt files are saved
    output_path,        # where the shuffled dataset (and its dataset.yaml) ends up
    # --- YOLO training parameters ---
    pt_path,            # initial weights file
    imgsz,              # input image size
    epochs,             # number of training epochs
    device,             # GPU index list
    hsv_v,              # brightness augmentation factor
    cos_lr,             # whether to use a cosine learning-rate schedule
    batch,              # batch size
    project_dir,        # training output directory
    # --- classes ---
    class_names,        # list of dataset class names
):
    """Run the full pipeline: download data, fix labels and shuffle, then train."""
    # Stage 1: pull images and labels from OSS into the local folders.
    dataset_root, images_folder, labels_folder = download_images_and_labels(
        config_name, table_name, column_name, search_condition,
        aim_path, image_dir, label_dir,
    )

    # Stage 2: repair labels and build the shuffled dataset + dataset.yaml.
    _, data_yaml = broken_and_convert_txt_to_yolo_format(
        dataset_root, output_path, images_folder, labels_folder, class_names
    )

    # Stage 3: train on the generated dataset config.
    train(
        yaml_path=data_yaml,
        pt_path=pt_path,
        imgsz=imgsz,
        epochs=epochs,
        device=device,
        hsv_v=hsv_v,
        cos_lr=cos_lr,
        batch=batch,
        project_dir=project_dir,
    )
|
2025-07-10 09:41:26 +08:00
|
|
|
|
|
2025-09-02 10:23:21 +08:00
|
|
|
|
|
2025-11-11 09:43:25 +08:00
|
|
|
|
# ------------------ 执行 ------------------
|
|
|
|
|
|
if __name__ == "__main__":
    # Example end-to-end run; the search id, weights file and paths are placeholders.
    # NOTE(review): aim_path uses "./datasets/..." while image_dir/label_dir use
    # "./dataset/..." — confirm whether images/labels are meant to live under aim_path.
    run_config = dict(
        # SQL / OSS query
        config_name="config",
        table_name="aidataset",
        column_name="image_url",
        search_condition="your_search_id",
        # local dataset layout
        aim_path="./datasets/aidataset_dataset",
        image_dir="./dataset/aidataset_dataset_images",
        label_dir="./dataset/aidataset_dataset_labels",
        output_path="./my_dataset",
        # YOLO training
        pt_path="custom_model.pt",
        imgsz=800,
        epochs=500,
        device=[0],
        hsv_v=0.3,
        cos_lr=True,
        batch=8,
        project_dir="./my_train_runs",
        # classes
        class_names=['person', 'car'],
    )
    train_main(**run_config)
|