ai-train_platform/train_worker_yolo.py

159 lines
5.0 KiB
Python
Raw Permalink Normal View History

import sys
import json
import os
import torch
from ultralytics import YOLO
class MockModelConfigDAO:
    """Fallback stand-in for ``ModelConfigDAO`` used when the middleware package is absent.

    It accepts the same constructor argument as the real DAO but never touches a
    database: ``insert_train_pid`` only prints what it would have recorded.
    """

    def __init__(self, db_config):
        # Stored for parity with the real DAO's constructor; the mock never reads it.
        self.db_config = db_config

    def insert_train_pid(self, task_id, train_pid):
        """Pretend to persist *train_pid* for *task_id* by printing a message."""
        message = "Inserted training PID {} for task {}".format(train_pid, task_id)
        print(message)
def _build_train_config(dataset_dir, config_overrides=None):
    """Return the default training config for *dataset_dir*, shallow-merged with *config_overrides*.

    The merge is shallow: overriding ``"hyp"`` replaces the whole augmentation
    dict rather than merging it key by key.
    """
    config = {
        "model": "pt/yolo11s-seg.pt",
        "pretrained": True,
        "data": os.path.join(dataset_dir, "data.yaml"),
        "project": "UAVid_Segmentation",
        "name": "v1.5_official",
        "epochs": 10,
        "batch_size": 8,
        "img_size": 640,
        "workers": 0,  # disable multi-process data loading
        "optimizer": "SGD",
        "lr0": 0.01,
        "lrf": 0.01,
        "momentum": 0.9,
        "weight_decay": 0.0005,
        "augment": True,
        # Augmentation hyper-parameters, forwarded to model.train() as keyword args.
        "hyp": {
            "mosaic": 0.5,
            "copy_paste": 0.2,
            "mixup": 0.15,
        },
    }
    if config_overrides:
        config.update(config_overrides)
    return config


def _report_validation(model):
    """Run ``model.val()`` and print box/mask mAP; any failure only prints a warning.

    Validation here is informational: training has already finished, so a
    validation error must not prevent the trained weights from being exported.
    """
    try:
        metrics = model.val()
        # metrics.box / metrics.seg are ultralytics Metric objects, not numbers;
        # their ``.map`` attribute holds the scalar mAP value.
        box_map = metrics.box.map
        seg_map = metrics.seg.map
        print(f"Validation mAP: {box_map:.2f} (box), {seg_map:.2f} (mask)")
    except Exception as e:
        print(f"Warning: Validation failed: {e}")


def _save_best_weights(model, results, weight_name):
    """Copy the best checkpoint to *weight_name*, falling back to ``torch.save`` of the state dict."""
    import shutil

    # ultralytics records the best checkpoint path on the trainer
    # (``model.trainer.best``); ``results.best`` is still checked first for
    # compatibility with whatever the original code expected.
    best_model_path = getattr(results, "best", None) or getattr(
        getattr(model, "trainer", None), "best", None
    )
    try:
        # Make sure the destination directory exists so neither the copy
        # nor the fallback fails just because of a missing folder.
        os.makedirs(os.path.dirname(os.path.abspath(weight_name)), exist_ok=True)
        if best_model_path:
            if os.path.exists(best_model_path):
                shutil.copy2(best_model_path, weight_name)
                print(f"Best model saved to: {os.path.abspath(weight_name)}")
            else:
                torch.save(model.state_dict(), weight_name)
                print(f"Best model path not found, saved state dict to: {weight_name}")
        else:
            torch.save(model.state_dict(), weight_name)
            print(f"Saved model state dict to: {weight_name}")
    except Exception as e:
        print(f"Warning: Failed to save best model: {e}")
        torch.save(model.state_dict(), weight_name)
        print(f"Fallback: Saved model state dict to: {weight_name}")


def train_model(dataset_dir, weight_name="best_segmentation_model.pt", config_overrides=None):
    """Train a YOLO segmentation model and save its weights.

    Args:
        dataset_dir: Directory whose ``data.yaml`` is used as the dataset config
            (unless ``config_overrides`` supplies a different ``"data"`` path).
        weight_name: Destination path for the exported weights file.
        config_overrides: Optional dict whose keys replace entries of the
            default config (see ``_build_train_config``).

    Returns:
        True once training completed and a weights file was written.

    Raises:
        FileNotFoundError: If the dataset config file does not exist.
        Exception: Any other training/saving error is printed and re-raised.
    """
    try:
        current_pid = os.getpid()
        print(f"Starting model training in process {current_pid} with dataset: {dataset_dir}")

        config = _build_train_config(dataset_dir, config_overrides)
        print(f"Training config: {config}")

        # Fail fast before loading the model if the dataset config is missing.
        data_path = config["data"]
        if not os.path.exists(data_path):
            raise FileNotFoundError(f"Data configuration file not found: {data_path}")

        model = YOLO(config["model"])
        # Single quotes inside the f-string: reusing double quotes there is a
        # SyntaxError on Python < 3.12.
        print(f"Model initialized with: {config['model']}")

        # Start from the augmentation hyper-parameters, then apply the explicit
        # settings so a clashing key in "hyp" cannot cause a duplicate-kwarg error.
        train_args = dict(config.get("hyp") or {})
        train_args.update(
            data=data_path,
            project=config["project"],
            name=config["name"],
            epochs=config["epochs"],
            batch=config["batch_size"],
            imgsz=config["img_size"],
            workers=config["workers"],
            optimizer=config["optimizer"],
            lr0=config["lr0"],
            lrf=config["lrf"],
            momentum=config["momentum"],
            weight_decay=config["weight_decay"],
            augment=config["augment"],
            pretrained=config["pretrained"],
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        )
        results = model.train(**train_args)
        print(f"Training completed successfully in process {current_pid}")

        _report_validation(model)
        _save_best_weights(model, results, weight_name)
        return True
    except Exception as e:
        print(f"Model training failed in process {os.getpid()}: {e}")
        raise
def main():
    """Script entry point: read a JSON job file, record this process's PID, then train.

    Expects exactly one command-line argument, the path to a UTF-8 JSON file
    with keys ``dataset_dir``, ``pt_name``, ``db_config`` and ``task_id``;
    ``config_overrides`` is optional. Exits with status 0 on success and 1 on
    any failure (bad arguments, unreadable/invalid config, training error).
    """
    if len(sys.argv) != 2:
        print("Usage: python -c '<script>' <config_file>")
        sys.exit(1)

    config_file = sys.argv[1]
    try:
        with open(config_file, 'r', encoding='utf-8') as f:
            config = json.load(f)

        # Extract the job settings.
        dataset_dir = config['dataset_dir']
        pt_name = config['pt_name']
        # Optional: a job without overrides trains with the default config.
        config_overrides = config.get('config_overrides')
        db_config = config['db_config']
        task_id = config['task_id']

        pid = os.getpid()
        print(f"Training process started for task {task_id} with PID {pid}")

        # Record the PID in the database. Only the import is guarded, so an
        # ImportError raised inside the DAO constructor is not misreported as
        # "middleware missing".
        try:
            from middleware.query_model import ModelConfigDAO
        except ImportError as e:
            print(
                f"Warning: middleware.query_model unavailable ({e}); "
                "using MockModelConfigDAO, the training PID will not be persisted",
                file=sys.stderr,
            )
            dao = MockModelConfigDAO(db_config)
        else:
            dao = ModelConfigDAO(db_config)
        dao.insert_train_pid(task_id, train_pid=pid)

        success = train_model(dataset_dir, pt_name, config_overrides)
        if success:
            print(f"Training completed successfully for task {task_id}")
            sys.exit(0)
        else:
            print(f"Training failed for task {task_id}")
            sys.exit(1)
    except Exception as e:
        # SystemExit is a BaseException, so the sys.exit() calls above pass through.
        print(f"Training error: {e}", file=sys.stderr)
        import traceback
        traceback.print_exc()
        sys.exit(1)
# Run the training worker only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()