import sys import json import os import torch from ultralytics import YOLO class MockModelConfigDAO: def __init__(self, db_config): self.db_config = db_config def insert_train_pid(self, task_id, train_pid): print(f"Inserted training PID {train_pid} for task {task_id}") def train_model(dataset_dir, weight_name="best_segmentation_model.pt", config_overrides=None): """ 训练模型并保存权重 """ try: current_pid = os.getpid() print(f"Starting model training in process {current_pid} with dataset: {dataset_dir}") # 默认配置(可通过参数覆盖) DEFAULT_CONFIG = { "model": "pt/yolo11s-seg.pt", "pretrained": True, "data": os.path.join(dataset_dir, "data.yaml"), "project": "UAVid_Segmentation", "name": "v1.5_official", "epochs": 10, "batch_size": 8, "img_size": 640, "workers": 0, # 禁用多进程数据加载 "optimizer": "SGD", "lr0": 0.01, "lrf": 0.01, "momentum": 0.9, "weight_decay": 0.0005, "augment": True, "hyp": { "mosaic": 0.5, "copy_paste": 0.2, "mixup": 0.15, }, } config = DEFAULT_CONFIG.copy() if config_overrides: config.update(config_overrides) print(f"Training config: {config}") # 检查数据配置文件 data_path = config["data"] if not os.path.exists(data_path): raise FileNotFoundError(f"Data configuration file not found: {data_path}") # 初始化模型 model = YOLO(config["model"]) print(f"Model initialized with: {config["model"]}") # 开始训练 results = model.train( data=config["data"], project=config["project"], name=config["name"], epochs=config["epochs"], batch=config["batch_size"], imgsz=config["img_size"], workers=config["workers"], optimizer=config["optimizer"], lr0=config["lr0"], lrf=config["lrf"], momentum=config["momentum"], weight_decay=config["weight_decay"], augment=config["augment"], device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), ) print(f"Training completed successfully in process {current_pid}") # 验证模型 metrics = model.val() print(f"Validation mAP: {metrics.box:.2f} (box), {metrics.seg:.2f} (mask)") # 保存最佳模型 try: if hasattr(results, 'best') and results.best: best_model_path = results.best if os.path.exists(best_model_path): import shutil shutil.copy2(best_model_path, weight_name) print(f"Best model saved to: {os.path.abspath(weight_name)}") else: torch.save(model.state_dict(), weight_name) print(f"Best model path not found, saved state dict to: {weight_name}") else: torch.save(model.state_dict(), weight_name) print(f"Saved model state dict to: {weight_name}") except Exception as e: print(f"Warning: Failed to save best model: {e}") torch.save(model.state_dict(), weight_name) print(f"Fallback: Saved model state dict to: {weight_name}") return True except Exception as e: print(f"Model training failed in process {os.getpid()}: {e}") raise def main(): if len(sys.argv) != 2: print("Usage: python -c '