import asyncio
import os
import shutil
import sys
import threading
import time
import subprocess
import json
from pathlib import Path

from middleware.minio_util import downFullPathFile
from middleware.query_model import ModelConfigDAO
import yaml
import torch
from ultralytics import YOLO


async def download_train(task_id: str, bz_training_task_id: int, pt_name: str):
    """
    Download the training data and start training.

    This function prepares the dataset, then uses a thread + subprocess to run
    training in an independent process.
    """
    try:
        current_pid = os.getpid()
        print(f"Starting download and training for task {task_id} in process {current_pid}")

        DB_CONFIG = {
            "dbname": "smart_dev_123",
            "user": "postgres",
            "password": "root",
            "host": "8.137.54.85",
            "port": "5060"
        }

        # Create the DAO instance
        dao = ModelConfigDAO(DB_CONFIG)
        time_ns = time.time_ns()

        output_root = f"dataset/dataset-{time_ns}"

        try:
            # makedirs also creates the parent "dataset" directory if it is missing
            os.makedirs(output_root, exist_ok=True)
            print(f"Created output directory: {output_root}")
        except Exception as e:
            print(f"Failed to create output directory: {e}")
            raise

        try:
            # Fetch label and dataset information
            list_labels = dao.get_labels(bz_training_task_id)
            list_datasets = dao.get_datasets(bz_training_task_id)
            label_yaml_list = dao.get_label_yaml(bz_training_task_id)

            print(
                f"Retrieved {len(list_labels)} labels, {len(list_datasets)} datasets, {len(label_yaml_list)} label configs")
        except Exception as e:
            print(f"Failed to retrieve data from database: {e}")
            raise

        # Define the data.yaml structure (dict)
        uavid_config = {
            "path": "",  # filled in with the dataset's absolute path below
            "train": "images/train",  # training images
            "val": "images/val",  # validation images
            "test": "images/test",  # test images (optional)
            "names": {}
        }

        try:
            uavid_config["path"] = os.path.abspath(output_root)
            for i, item in enumerate(label_yaml_list):
                item.id_order = i
                uavid_config["names"][i] = item.e_name  # integer class index -> English name

            # Write the YAML file
            data_yaml = "data.yaml"
            with open(data_yaml, "w", encoding="utf-8") as f:
                yaml.dump(
                    uavid_config,
                    f,
                    default_flow_style=False,  # block style, one key per line
                    allow_unicode=True,  # allow Unicode characters
                    sort_keys=False  # preserve key order
                )

            file_name = os.path.basename(data_yaml)
            des_path = os.path.join(output_root, file_name)
            if os.path.exists(des_path):
                os.remove(des_path)
            shutil.move(data_yaml, output_root)
            print(f"Generated YAML config at: {os.path.abspath(output_root)}")
        except Exception as e:
            print(f"Failed to generate YAML config: {e}")
            raise
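
        # For illustration, the generated data.yaml is expected to look roughly like
        # this (class names below are hypothetical examples drawn from label_yaml_list.e_name):
        #
        #   path: /abs/path/to/dataset/dataset-<time_ns>
        #   train: images/train
        #   val: images/val
        #   test: images/test
        #   names:
        #     0: building
        #     1: road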

        # Download the dataset images
        invalid_indices = []

        try:
            for index, pic in enumerate(list_datasets):
                if hasattr(pic, 'resource_original_path') and pic.resource_original_path:  # image path is valid
                    try:
                        download_path = downFullPathFile(pic.resource_original_path)
                        if download_path:  # download succeeded
                            pic.local_path = download_path
                            pic.label_name = Path(download_path).stem  # image name stem doubles as the label file name
                            print(f"Downloaded file: {download_path}")
                        else:
                            invalid_indices.append(index)  # remember invalid indices for later removal
                            print(f"Failed to download file: {pic.resource_original_path}")
                    except Exception as e:
                        invalid_indices.append(index)
                        print(f"Error downloading file {pic.resource_original_path}: {e}")
                else:
                    invalid_indices.append(index)  # remember invalid indices for later removal
        except Exception as e:
            print(f"Error processing datasets: {e}")
            raise

        # Delete invalid entries from back to front so the remaining indices stay valid
        try:
            for idx in sorted(invalid_indices, reverse=True):
                del list_datasets[idx]

            print(f"Filtered datasets: {len(list_datasets)} valid items remaining")
        except Exception as e:
            print(f"Error filtering datasets: {e}")
            raise

        # Assemble label contents
        try:
            for data_pic in list_datasets:  # build the full image-to-label mapping
                for label in list_labels:
                    if hasattr(data_pic, 'id') and hasattr(label, 'id') and data_pic.id == label.id:
                        for item in label_yaml_list:
                            if hasattr(label, 'label_ids') and hasattr(item, 'id') and label.label_ids == item.id:
                                # label is assumed to carry an annotation_data attribute
                                annotation = getattr(label, 'annotation_data', '')
                                current_content = getattr(data_pic, 'label_content', '')
                                data_pic.label_content = f"{current_content}{item.id_order} {annotation}\n"
        except Exception as e:
            print(f"Error organizing labels: {e}")
            raise
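
        # Each line appended to label_content follows the YOLO txt convention:
        #   <class_index> <annotation values...>
        # assuming annotation_data already holds normalized coordinates, e.g. a
        # (hypothetical) segmentation polygon line such as:
        #   3 0.12 0.34 0.15 0.40 0.22 0.38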

        # Write one .txt label file per image
        try:
            for data_pic in list_datasets:
                if hasattr(data_pic, 'label_name'):
                    label_txt = f"{data_pic.label_name}.txt"
                    with open(label_txt, 'w', encoding='utf-8') as f:
                        f.write(getattr(data_pic, 'label_content', ''))
                    data_pic.label_txt_path = os.path.abspath(label_txt)
                    print(f"Created label file: {label_txt}")
        except Exception as e:
            print(f"Error creating label files: {e}")
            raise

        # Move files into the dataset layout
        try:
            dataset_dirs = {
                "images": Path(output_root) / "images",
                "labels": Path(output_root) / "labels"
            }

            for ds_dir in dataset_dirs.values():
                (ds_dir / "val").mkdir(parents=True, exist_ok=True)
                (ds_dir / "train").mkdir(parents=True, exist_ok=True)
                (ds_dir / "test").mkdir(parents=True, exist_ok=True)

            print("Created dataset directory structure")
        except Exception as e:
            print(f"Error creating dataset directories: {e}")
            raise
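
        # Resulting layout (data.yaml was already moved to the dataset root above):
        #   dataset-<time_ns>/
        #     data.yaml
        #     images/{train,val,test}/
        #     labels/{train,val,test}/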

        # Assign samples to train/val/test splits
        try:
            count_pic = 0
            for data_pic in list_datasets:
                count_pic += 1

                # 80% train, 10% val, 10% test
                if count_pic % 10 < 8:
                    split = "train"
                elif count_pic % 10 == 8:
                    split = "val"
                else:  # count_pic % 10 == 9
                    split = "test"

                # Move the image file
                if hasattr(data_pic, 'local_path') and os.path.exists(data_pic.local_path):
                    images_path = dataset_dirs["images"]
                    image_dir = os.path.join(images_path, split)
                    file_name = os.path.basename(data_pic.local_path)
                    des_path = os.path.join(image_dir, file_name)

                    if os.path.exists(des_path):
                        os.remove(des_path)
                    shutil.move(data_pic.local_path, image_dir)

                # Move the label file
                if hasattr(data_pic, 'label_txt_path') and os.path.exists(data_pic.label_txt_path):
                    labels_path = dataset_dirs["labels"]
                    label_dir = os.path.join(labels_path, split)
                    file_name = os.path.basename(data_pic.label_txt_path)
                    des_path = os.path.join(label_dir, file_name)

                    if os.path.exists(des_path):
                        os.remove(des_path)
                    shutil.move(data_pic.label_txt_path, label_dir)

            print(f"Organized {count_pic} files into dataset splits")
        except Exception as e:
            print(f"Error organizing dataset splits: {e}")
            raise
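
        # Worked example of the modulo split: for images 1..10 the remainders are
        # 1..9 and 0, so eight go to train (remainder < 8), one to val (== 8) and
        # one to test (== 9).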

        # Training configuration
        custom_config = {
            "epochs": 50,  # small value for quick tests
            "batch_size": 4,
            "workers": 0,  # disable multi-process data loading
        }

        # Persist the training configuration to a file
        config_file = f"train_config_{task_id}.json"
        with open(config_file, 'w', encoding='utf-8') as f:
            json.dump({
                'dataset_dir': output_root,
                'pt_name': pt_name,
                'config_overrides': custom_config,
                'db_config': DB_CONFIG,
                'task_id': task_id
            }, f)

        print(f"Training data preparation completed for task {task_id}")
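
        # For illustration, train_config_<task_id>.json ends up roughly like:
        #   {"dataset_dir": "dataset/dataset-<time_ns>", "pt_name": "<pt_name>",
        #    "config_overrides": {"epochs": 50, "batch_size": 4, "workers": 0},
        #    "db_config": {...}, "task_id": "<task_id>"}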

        # On Windows, launch the training process from a worker thread via subprocess
        # instead of asyncio.create_subprocess_exec
        loop = asyncio.get_running_loop()
        training_pid = await loop.run_in_executor(
            None,  # use the default thread pool
            start_training_process,
            config_file
        )

        if training_pid:
            print(f"pid--{training_pid}")
            dao.insert_train_pid(task_id, train_pid=training_pid)
            return training_pid
        else:
            raise Exception("Failed to start training process")

    except Exception as e:
        print(f"Training failed for task {task_id}: {e}")
        raise
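
# Minimal usage sketch (hypothetical IDs and weight name; assumes the database and
# MinIO middleware configured above are reachable):
#
#     asyncio.run(download_train("task-001", 42, "uavid_seg_v1.pt"))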


def start_training_process(config_file: str) -> int:
    """
    Start the training process from a worker thread.
    The process itself is created with subprocess.Popen.
    """
    try:
        # Build the training worker script
        train_script = '''
import sys
import json
import os
import torch
from ultralytics import YOLO


class MockModelConfigDAO:
    def __init__(self, db_config):
        self.db_config = db_config

    def insert_train_pid(self, task_id, train_pid):
        print(f"Inserted training PID {train_pid} for task {task_id}")


def train_model(dataset_dir, weight_name="best_segmentation_model.pt", config_overrides=None):
    """
    Train the model and save the weights.
    """
    try:
        current_pid = os.getpid()
        print(f"Starting model training in process {current_pid} with dataset: {dataset_dir}")

        # Default configuration (values can be overridden via config_overrides)
        DEFAULT_CONFIG = {
            "model": "pt/yolo11s-seg.pt",
            "pretrained": True,
            "data": os.path.join(dataset_dir, "data.yaml"),
            "project": "UAVid_Segmentation",
            "name": "v1.5_official",
            "epochs": 1000,
            "batch_size": 8,
            "img_size": 640,
            "workers": 0,  # disable multi-process data loading
            "optimizer": "SGD",
            "lr0": 0.01,
            "lrf": 0.01,
            "momentum": 0.9,
            "weight_decay": 0.0005,
            "augment": True,
            "hyp": {
                "mosaic": 0.5,
                "copy_paste": 0.2,
                "mixup": 0.15,
            },
        }

        config = DEFAULT_CONFIG.copy()
        if config_overrides:
            config.update(config_overrides)

        print(f"Training config: {config}")
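
        # For example, the overrides written by download_train
        # ({"epochs": 50, "batch_size": 4, "workers": 0}) shrink the run for quick tests.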

        # Check that the data configuration file exists
        data_path = config["data"]
        if not os.path.exists(data_path):
            raise FileNotFoundError(f"Data configuration file not found: {data_path}")

        # Initialize the model
        model = YOLO(config["model"])
        print(f"Model initialized with: {config['model']}")

        # Start training
        results = model.train(
            data=config["data"],
            project=config["project"],
            name=config["name"],
            epochs=config["epochs"],
            batch=config["batch_size"],
            imgsz=config["img_size"],
            workers=config["workers"],
            optimizer=config["optimizer"],
            lr0=config["lr0"],
            lrf=config["lrf"],
            momentum=config["momentum"],
            weight_decay=config["weight_decay"],
            augment=config["augment"],
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        )
        print(f"Training completed successfully in process {current_pid}")

        # Validate the model
        metrics = model.val()
        print(f"Validation mAP50-95: {metrics.box.map:.2f} (box), {metrics.seg.map:.2f} (mask)")

        # Save the best model
        try:
            if hasattr(results, 'best') and results.best:
                best_model_path = results.best
                if os.path.exists(best_model_path):
                    import shutil
                    shutil.copy2(best_model_path, weight_name)
                    print(f"Best model saved to: {os.path.abspath(weight_name)}")
                else:
                    torch.save(model.state_dict(), weight_name)
                    print(f"Best model path not found, saved state dict to: {weight_name}")
            else:
                torch.save(model.state_dict(), weight_name)
                print(f"Saved model state dict to: {weight_name}")
        except Exception as e:
            print(f"Warning: Failed to save best model: {e}")
            torch.save(model.state_dict(), weight_name)
            print(f"Fallback: Saved model state dict to: {weight_name}")

        return True

    except Exception as e:
        print(f"Model training failed in process {os.getpid()}: {e}")
        raise


def main():
    if len(sys.argv) != 2:
        print("Usage: python train_worker_<task_id>.py <config_file>")
        sys.exit(1)

    config_file = sys.argv[1]

    try:
        with open(config_file, 'r', encoding='utf-8') as f:
            config = json.load(f)

        # Extract the configuration
        dataset_dir = config['dataset_dir']
        pt_name = config['pt_name']
        config_overrides = config['config_overrides']
        db_config = config['db_config']
        task_id = config['task_id']

        # Current process ID
        pid = os.getpid()
        print(f"Training process started for task {task_id} with PID {pid}")

        # Record the PID in the database
        try:
            from middleware.query_model import ModelConfigDAO
            dao = ModelConfigDAO(db_config)
        except ImportError:
            dao = MockModelConfigDAO(db_config)

        dao.insert_train_pid(task_id, train_pid=pid)

        # Run training
        success = train_model(dataset_dir, pt_name, config_overrides)

        if success:
            print(f"Training completed successfully for task {task_id}")
            sys.exit(0)
        else:
            print(f"Training failed for task {task_id}")
            sys.exit(1)

    except Exception as e:
        print(f"Training error: {e}", file=sys.stderr)
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
'''

        # Save the worker script; the file name embeds the task_id extracted from
        # train_config_<task_id>.json
        script_path = f"train_worker_{os.path.basename(config_file).split('_')[2].split('.')[0]}.py"
        with open(script_path, 'w', encoding='utf-8') as f:
            f.write(train_script)

        # Create the training process with subprocess.Popen
        # (shell=False keeps the argument list intact; on Windows, shell=True can
        # work around some path issues if ever needed)
        process = subprocess.Popen([
            sys.executable, script_path, config_file
        ], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, shell=False)

        print(f"Started training process with PID {process.pid}")

        # Start a daemon thread to consume and log the process output
        threading.Thread(target=handle_process_output, args=(process,), daemon=True).start()

        return process.pid

    except Exception as e:
        print(f"Failed to start training process: {e}")
        return 0


def handle_process_output(process: subprocess.Popen):
    """
    Consume and log the child process's output.
    """
    try:
        # Read stdout and stderr separately
        def read_stream(stream, stream_name):
            while True:
                line = stream.readline()
                if not line:
                    break
                line = line.strip()
                print(f"[{stream_name}] {line}")

        # Start reader threads for stdout and stderr
        stdout_thread = threading.Thread(target=read_stream, args=(process.stdout, 'STDOUT'))
        stderr_thread = threading.Thread(target=read_stream, args=(process.stderr, 'STDERR'))

        stdout_thread.start()
        stderr_thread.start()

        # Wait until both streams are drained
        stdout_thread.join()
        stderr_thread.join()

        # Get the return code
        return_code = process.wait()
        print(f"Training process completed with return code: {return_code}")

    except Exception as e:
        print(f"Error handling process output: {e}")


if __name__ == "__main__":
    """
    Entry point for direct execution, used for testing.
    """

    if len(sys.argv) < 4:
        print(f"Usage: {sys.argv[0]} <task_id> <train_task_id> <pt_name>")
        sys.exit(1)

    task_id = sys.argv[1]
    train_task_id = int(sys.argv[2])
    pt_name = sys.argv[3]

    try:
        # Create an event loop and run the preparation/launch coroutine
        loop = asyncio.new_event_loop()
        pid = loop.run_until_complete(download_train(task_id, train_task_id, pt_name))
        print(f"Training started in process {pid}")

        # Keep the event loop running
        try:
            loop.run_forever()
        except KeyboardInterrupt:
            print("Received keyboard interrupt. Exiting...")
        finally:
            loop.close()

    except Exception as e:
        print(f"Training failed: {e}")
        sys.exit(1)
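
# Example invocation (hypothetical argument values):
#     python <this_module>.py 1001 42 uavid_seg_v1.pt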