# ai-train_platform/download_train.py
import asyncio
import json
import os
import shutil
import subprocess
import sys
import threading
import time
import traceback
from pathlib import Path

import yaml

from middleware.minio_util import downFullPathFile
from middleware.query_model import ModelConfigDAO

# Note: torch and ultralytics are imported inside the generated worker script
# rather than here, since this orchestrator process does not train directly.


async def download_train(task_id: str, bz_training_task_id: int, pt_name: str):
    """
    Download the training data and launch training.

    Prepares the dataset, then uses a thread plus subprocess to run training
    in an independent process.
    """
    try:
        current_pid = os.getpid()
        print(f"Starting download and training for task {task_id} in process {current_pid}")
        DB_CONFIG = {
            "dbname": "smart_dev_123",
            "user": "postgres",
            "password": "root",
            "host": "8.137.54.85",
            "port": "5060",
        }
        # Create the DAO instance
        dao = ModelConfigDAO(DB_CONFIG)
        time_ns = time.time_ns()
        output_root = f"dataset/dataset-{time_ns}"
        try:
            # output_root is a nested path, so use makedirs rather than mkdir
            os.makedirs(output_root, exist_ok=True)
            print(f"Created output directory: {output_root}")
        except Exception as e:
            print(f"Failed to create output directory: {e}")
            raise
        try:
            # Fetch label and dataset information
            list_labels = dao.get_labels(bz_training_task_id)
            list_datasets = dao.get_datasets(bz_training_task_id)
            label_yaml_list = dao.get_label_yaml(bz_training_task_id)
            print(
                f"Retrieved {len(list_labels)} labels, {len(list_datasets)} datasets, "
                f"{len(label_yaml_list)} label configs"
            )
        except Exception as e:
            print(f"Failed to retrieve data from database: {e}")
            raise
        # Dataset description consumed by Ultralytics (written to data.yaml)
        uavid_config = {
            "path": "",                # filled in below with the absolute dataset path
            "train": "images/train",   # training set
            "val": "images/val",       # validation set
            "test": "images/test",     # test set (optional)
            "names": {},
        }
        try:
            uavid_config["path"] = os.path.abspath(output_root)
            for i, item in enumerate(label_yaml_list):
                item.id_order = i
                uavid_config["names"][f"{i}"] = item.e_name
            # Generate the YAML file
            data_yaml = "data.yaml"
            with open(data_yaml, "w", encoding="utf-8") as f:
                yaml.dump(
                    uavid_config,
                    f,
                    default_flow_style=False,  # block style, one key per line
                    allow_unicode=True,        # allow Unicode characters
                    sort_keys=False,           # preserve key order
                )
            file_name = os.path.basename(data_yaml)
            des_path = os.path.join(output_root, file_name)
            if os.path.exists(des_path):
                os.remove(des_path)
            shutil.move(data_yaml, output_root)
            print(f"Generated YAML config at: {os.path.abspath(output_root)}")
        except Exception as e:
            print(f"Failed to generate YAML config: {e}")
            raise
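        # For reference, the generated data.yaml ends up looking roughly like
        # this (class names come from label_yaml_list; the values below are
        # illustrative, not actual output):
        #
        #   path: /abs/path/to/dataset/dataset-1700000000000000000
        #   train: images/train
        #   val: images/val
        #   test: images/test
        #   names:
        #     '0': person
        #     '1': vehicle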
        # Download the dataset files
        invalid_indices = []
        try:
            for index, pic in enumerate(list_datasets):
                if hasattr(pic, 'resource_original_path') and pic.resource_original_path:  # valid image path
                    try:
                        download_path = downFullPathFile(pic.resource_original_path)
                        if download_path:  # download succeeded
                            pic.local_path = download_path
                            pic.label_name = Path(download_path).stem  # image stem doubles as the label file name
                            print(f"Downloaded file: {download_path}")
                        else:
                            invalid_indices.append(index)  # remember invalid entries for removal
                            print(f"Failed to download file: {pic.resource_original_path}")
                    except Exception as e:
                        invalid_indices.append(index)
                        print(f"Error downloading file {pic.resource_original_path}: {e}")
                else:
                    invalid_indices.append(index)  # remember invalid entries for removal
        except Exception as e:
            print(f"Error processing datasets: {e}")
            raise
        # Delete invalid entries from back to front so earlier indices stay valid
        try:
            for idx in sorted(invalid_indices, reverse=True):
                del list_datasets[idx]
            print(f"Filtered datasets: {len(list_datasets)} valid items remaining")
        except Exception as e:
            print(f"Error filtering datasets: {e}")
            raise
        # Assemble the label content for each image
        try:
            for data_pic in list_datasets:  # build the full image-to-label mapping
                for label in list_labels:
                    if hasattr(data_pic, 'id') and hasattr(label, 'id') and data_pic.id == label.id:
                        for item in label_yaml_list:
                            if hasattr(label, 'label_ids') and hasattr(item, 'id') and label.label_ids == item.id:
                                # label is assumed to carry an annotation_data attribute
                                annotation = getattr(label, 'annotation_data', '')
                                current_content = getattr(data_pic, 'label_content', '')
                                data_pic.label_content = f"{current_content}{item.id_order} {annotation}\n"
        except Exception as e:
            print(f"Error organizing labels: {e}")
            raise
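        # Note: each line assembled above is "<class_index> <annotation_data>",
        # which matches the YOLO segmentation label format
        # "<class_id> <x1> <y1> <x2> <y2> ..." only if annotation_data already
        # holds normalized polygon coordinates (assumed here, not verified).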
        # Write the label files
        try:
            for data_pic in list_datasets:
                if hasattr(data_pic, 'label_name'):
                    label_txt = f"{data_pic.label_name}.txt"
                    with open(label_txt, 'w', encoding='utf-8') as f:
                        f.write(getattr(data_pic, 'label_content', ''))
                    data_pic.label_txt_path = os.path.abspath(label_txt)
                    print(f"Created label file: {label_txt}")
        except Exception as e:
            print(f"Error creating label files: {e}")
            raise
        # Move files into place to build the dataset
        try:
            dataset_dirs = {
                "images": Path(output_root) / "images",
                "labels": Path(output_root) / "labels",
            }
            for ds_dir in dataset_dirs.values():
                (ds_dir / "val").mkdir(parents=True, exist_ok=True)
                (ds_dir / "train").mkdir(parents=True, exist_ok=True)
                (ds_dir / "test").mkdir(parents=True, exist_ok=True)
            print("Created dataset directory structure")
        except Exception as e:
            print(f"Error creating dataset directories: {e}")
            raise
        # Assign items to the train/val/test splits
        try:
            count_pic = 0
            for data_pic in list_datasets:
                count_pic += 1
                # 80% train, 10% val, 10% test
                if count_pic % 10 < 8:
                    split = "train"
                elif count_pic % 10 == 8:
                    split = "val"
                else:  # count_pic % 10 == 9
                    split = "test"
                # Move the image file
                if hasattr(data_pic, 'local_path') and os.path.exists(data_pic.local_path):
                    images_path = dataset_dirs["images"]
                    image_dir = os.path.join(images_path, split)
                    file_name = os.path.basename(data_pic.local_path)
                    des_path = os.path.join(image_dir, file_name)
                    if os.path.exists(des_path):
                        os.remove(des_path)
                    shutil.move(data_pic.local_path, image_dir)
                # Move the label file
                if hasattr(data_pic, 'label_txt_path') and os.path.exists(data_pic.label_txt_path):
                    labels_path = dataset_dirs["labels"]
                    label_dir = os.path.join(labels_path, split)
                    file_name = os.path.basename(data_pic.label_txt_path)
                    des_path = os.path.join(label_dir, file_name)
                    if os.path.exists(des_path):
                        os.remove(des_path)
                    shutil.move(data_pic.label_txt_path, label_dir)
            print(f"Organized {count_pic} files into dataset splits")
        except Exception as e:
            print(f"Error organizing dataset splits: {e}")
            raise
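        # The modulo rule above sends, out of every 10 consecutive items,
        # those with count % 10 in 0..7 to train, 8 to val, and 9 to test,
        # i.e. an 8/1/1 split per block of ten (items 1-7 and 10 -> train,
        # 8 -> val, 9 -> test).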
        # Training configuration
        custom_config = {
            "epochs": 50,    # small value for quick testing
            "batch_size": 4,
            "workers": 0,    # disable multiprocess data loading
        }
        # Persist the training configuration to a file
        config_file = f"train_config_{task_id}.json"
        with open(config_file, 'w', encoding='utf-8') as f:
            json.dump({
                'dataset_dir': output_root,
                'pt_name': pt_name,
                'config_overrides': custom_config,
                'db_config': DB_CONFIG,
                'task_id': task_id,
            }, f)
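        # The resulting train_config_<task_id>.json looks like this
        # (illustrative values, not actual output):
        #
        #   {"dataset_dir": "dataset/dataset-1700000000000000000",
        #    "pt_name": "uavid_seg_v1.pt",
        #    "config_overrides": {"epochs": 50, "batch_size": 4, "workers": 0},
        #    "db_config": {...},
        #    "task_id": "task-001"}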
print(f"Training data preparation completed for task {task_id}")
# 在Windows上使用线程+subprocess创建训练进程
# 避免使用asyncio.create_subprocess_exec
loop = asyncio.get_event_loop()
training_pid = await loop.run_in_executor(
None, # 使用默认的线程池
start_training_process,
config_file
)
if training_pid:
print(f"pid--{training_pid}")
dao.insert_train_pid(task_id, train_pid=training_pid)
return training_pid
else:
raise Exception("Failed to start training process")
except Exception as e:
print(f"Training failed for task {task_id}: {e}", exc_info=True)
raise

def start_training_process(config_file: str) -> int:
    """
    Launch the training process from a separate thread.

    Uses subprocess.Popen to create the training process.
    """
    try:
        # Source of the worker script that runs in the separate training process
        train_script = '''
import json
import os
import sys
import traceback

import torch
from ultralytics import YOLO


class MockModelConfigDAO:
    """Fallback DAO used when the middleware package cannot be imported."""

    def __init__(self, db_config):
        self.db_config = db_config

    def insert_train_pid(self, task_id, train_pid):
        print(f"Inserted training PID {train_pid} for task {task_id}")

def train_model(dataset_dir, weight_name="best_segmentation_model.pt", config_overrides=None):
    """
    Train the model and save the weights.
    """
    try:
        current_pid = os.getpid()
        print(f"Starting model training in process {current_pid} with dataset: {dataset_dir}")
        # Default configuration (individual keys can be overridden)
        DEFAULT_CONFIG = {
            "model": "pt/yolo11s-seg.pt",
            "pretrained": True,
            "data": os.path.join(dataset_dir, "data.yaml"),
            "project": "UAVid_Segmentation",
            "name": "v1.5_official",
            "epochs": 1000,
            "batch_size": 8,
            "img_size": 640,
            "workers": 0,  # disable multiprocess data loading
            "optimizer": "SGD",
            "lr0": 0.01,
            "lrf": 0.01,
            "momentum": 0.9,
            "weight_decay": 0.0005,
            "augment": True,
            "hyp": {
                "mosaic": 0.5,
                "copy_paste": 0.2,
                "mixup": 0.15,
            },
        }
        config = DEFAULT_CONFIG.copy()
        if config_overrides:
            config.update(config_overrides)
        print(f"Training config: {config}")
        # Check that the data configuration file exists
        data_path = config["data"]
        if not os.path.exists(data_path):
            raise FileNotFoundError(f"Data configuration file not found: {data_path}")
        # Initialize the model
        model = YOLO(config["model"])
        # Single quotes inside the f-string avoid a quote clash on Python < 3.12
        print(f"Model initialized with: {config['model']}")
        # Start training
        results = model.train(
            data=config["data"],
            project=config["project"],
            name=config["name"],
            epochs=config["epochs"],
            batch=config["batch_size"],
            imgsz=config["img_size"],
            workers=config["workers"],
            optimizer=config["optimizer"],
            lr0=config["lr0"],
            lrf=config["lrf"],
            momentum=config["momentum"],
            weight_decay=config["weight_decay"],
            augment=config["augment"],
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        )
        print(f"Training completed successfully in process {current_pid}")
        # Validate the model; segmentation metrics expose box.map / seg.map (mAP50-95)
        metrics = model.val()
        print(f"Validation mAP50-95: {metrics.box.map:.2f} (box), {metrics.seg.map:.2f} (mask)")
        # Save the best model
        try:
            if hasattr(results, 'best') and results.best:
                best_model_path = results.best
                if os.path.exists(best_model_path):
                    import shutil
                    shutil.copy2(best_model_path, weight_name)
                    print(f"Best model saved to: {os.path.abspath(weight_name)}")
                else:
                    torch.save(model.state_dict(), weight_name)
                    print(f"Best model path not found, saved state dict to: {weight_name}")
            else:
                torch.save(model.state_dict(), weight_name)
                print(f"Saved model state dict to: {weight_name}")
        except Exception as e:
            print(f"Warning: Failed to save best model: {e}")
            torch.save(model.state_dict(), weight_name)
            print(f"Fallback: Saved model state dict to: {weight_name}")
        return True
    except Exception as e:
        print(f"Model training failed in process {os.getpid()}: {e}")
        traceback.print_exc()
        raise

def main():
    if len(sys.argv) != 2:
        print("Usage: python <worker_script> <config_file>")
        sys.exit(1)
    config_file = sys.argv[1]
    try:
        with open(config_file, 'r', encoding='utf-8') as f:
            config = json.load(f)
        # Unpack the configuration
        dataset_dir = config['dataset_dir']
        pt_name = config['pt_name']
        config_overrides = config['config_overrides']
        db_config = config['db_config']
        task_id = config['task_id']
        # Current process ID
        pid = os.getpid()
        print(f"Training process started for task {task_id} with PID {pid}")
        # Record the PID in the database
        try:
            from middleware.query_model import ModelConfigDAO
            dao = ModelConfigDAO(db_config)
        except ImportError:
            dao = MockModelConfigDAO(db_config)
        dao.insert_train_pid(task_id, train_pid=pid)
        # Run the training
        success = train_model(dataset_dir, pt_name, config_overrides)
        if success:
            print(f"Training completed successfully for task {task_id}")
            sys.exit(0)
        else:
            print(f"Training failed for task {task_id}")
            sys.exit(1)
    except Exception as e:
        print(f"Training error: {e}", file=sys.stderr)
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
'''
        # Save the worker script; config_file is "train_config_{task_id}.json",
        # so split('_')[2] recovers "{task_id}.json" (this assumes task_id
        # itself contains no underscores)
        script_path = f"train_worker_{os.path.basename(config_file).split('_')[2].split('.')[0]}.py"
        with open(script_path, 'w', encoding='utf-8') as f:
            f.write(train_script)
        # Create the training process with subprocess.Popen; this sidesteps
        # the limitations of asyncio subprocesses on Windows
        process = subprocess.Popen(
            [sys.executable, script_path, config_file],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, shell=False,
        )
        print(f"Started training process with PID {process.pid}")
        # Start a thread to stream the process output
        threading.Thread(target=handle_process_output, args=(process,), daemon=True).start()
        return process.pid
    except Exception as e:
        print(f"Failed to start training process: {e}")
        traceback.print_exc()
        return 0

def handle_process_output(process: subprocess.Popen):
    """
    Handle the output of the training subprocess.
    """
    try:
        # Read stdout and stderr on separate threads so neither pipe fills up
        # and blocks the child process
        def read_stream(stream, stream_name):
            while True:
                line = stream.readline()
                if not line:
                    break
                print(f"[{stream_name}] {line.strip()}")

        stdout_thread = threading.Thread(target=read_stream, args=(process.stdout, 'STDOUT'))
        stderr_thread = threading.Thread(target=read_stream, args=(process.stderr, 'STDERR'))
        stdout_thread.start()
        stderr_thread.start()
        # Wait until both streams are drained
        stdout_thread.join()
        stderr_thread.join()
        # Collect the return code
        return_code = process.wait()
        print(f"Training process completed with return code: {return_code}")
    except Exception as e:
        print(f"Error handling process output: {e}")
        traceback.print_exc()

if __name__ == "__main__":
    # Entry point for direct execution (used for testing)
    if len(sys.argv) < 4:
        print(f"Usage: {sys.argv[0]} <task_id> <train_task_id> <pt_name>")
        sys.exit(1)
    task_id = sys.argv[1]
    train_task_id = int(sys.argv[2])
    pt_name = sys.argv[3]
    try:
        # Create a fresh event loop (asyncio.get_event_loop is deprecated here)
        loop = asyncio.new_event_loop()
        pid = loop.run_until_complete(download_train(task_id, train_task_id, pt_name))
        print(f"Training started in process {pid}")
        # Keep the event loop running
        try:
            loop.run_forever()
        except KeyboardInterrupt:
            print("Received keyboard interrupt. Exiting...")
        finally:
            loop.close()
    except Exception as e:
        print(f"Training failed: {e}")
        traceback.print_exc()
        sys.exit(1)
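
# Example (illustrative) invocation from the command line:
#
#   python download_train.py task-001 42 uavid_seg_v1.pt
#
# or from an async caller:
#
#   pid = await download_train("task-001", 42, "uavid_seg_v1.pt")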