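"""
Standalone training entry point.

Reads a JSON config file passed on the command line, records the training
process PID (via ModelConfigDAO, or MockModelConfigDAO when the middleware
package is unavailable), trains an Ultralytics YOLO segmentation model, and
saves the resulting weights.
"""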
import sys
import json
import os

import torch
from ultralytics import YOLO


class MockModelConfigDAO:
    """Fallback DAO used when middleware.query_model cannot be imported; it only logs calls."""

    def __init__(self, db_config):
        self.db_config = db_config

    def insert_train_pid(self, task_id, train_pid):
        print(f"Inserted training PID {train_pid} for task {task_id}")


def train_model(dataset_dir, weight_name="best_segmentation_model.pt", config_overrides=None):
    """
    Train the model and save the weights.
    """
    try:
        current_pid = os.getpid()
        print(f"Starting model training in process {current_pid} with dataset: {dataset_dir}")

        # Default configuration (can be overridden via config_overrides).
        # Note: "pretrained" and "hyp" are kept here for reference but are not
        # forwarded to model.train() below.
        DEFAULT_CONFIG = {
            "model": "pt/yolo11s-seg.pt",
            "pretrained": True,
            "data": os.path.join(dataset_dir, "data.yaml"),
            "project": "UAVid_Segmentation",
            "name": "v1.5_official",
            "epochs": 10,
            "batch_size": 8,
            "img_size": 640,
            "workers": 0,  # disable multi-process data loading
            "optimizer": "SGD",
            "lr0": 0.01,
            "lrf": 0.01,
            "momentum": 0.9,
            "weight_decay": 0.0005,
            "augment": True,
            "hyp": {
                "mosaic": 0.5,
                "copy_paste": 0.2,
                "mixup": 0.15,
            },
        }
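
        # Example (hypothetical values) of an override dict a caller might pass:
        #     config_overrides = {"epochs": 50, "batch_size": 16, "img_size": 1024}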
        config = DEFAULT_CONFIG.copy()
        if config_overrides:
            config.update(config_overrides)

        print(f"Training config: {config}")

        # Check that the data configuration file exists
        data_path = config["data"]
        if not os.path.exists(data_path):
            raise FileNotFoundError(f"Data configuration file not found: {data_path}")

        # Initialize the model
        model = YOLO(config["model"])
        print(f"Model initialized with: {config['model']}")

        # Start training
        results = model.train(
            data=config["data"],
            project=config["project"],
            name=config["name"],
            epochs=config["epochs"],
            batch=config["batch_size"],
            imgsz=config["img_size"],
            workers=config["workers"],
            optimizer=config["optimizer"],
            lr0=config["lr0"],
            lrf=config["lrf"],
            momentum=config["momentum"],
            weight_decay=config["weight_decay"],
            augment=config["augment"],
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        )
print(f"Training completed successfully in process {current_pid}")
|
||
|
|
|
||
|
|
# 验证模型
|
||
|
|
metrics = model.val()
|
||
|
|
print(f"Validation mAP: {metrics.box:.2f} (box), {metrics.seg:.2f} (mask)")
|
||
|
|
|
||
|
|
        # Save the best model
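        # Ultralytics typically writes run outputs under <project>/<name>/, with the
        # best checkpoint at <project>/<name>/weights/best.pt. The hasattr() check
        # below guards against versions where the results object exposes no .best path.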
        try:
            if hasattr(results, 'best') and results.best:
                best_model_path = results.best
                if os.path.exists(best_model_path):
                    import shutil
                    shutil.copy2(best_model_path, weight_name)
                    print(f"Best model saved to: {os.path.abspath(weight_name)}")
                else:
                    torch.save(model.state_dict(), weight_name)
                    print(f"Best model path not found, saved state dict to: {weight_name}")
            else:
                torch.save(model.state_dict(), weight_name)
                print(f"Saved model state dict to: {weight_name}")
        except Exception as e:
            print(f"Warning: Failed to save best model: {e}")
            torch.save(model.state_dict(), weight_name)
            print(f"Fallback: Saved model state dict to: {weight_name}")

        return True

    except Exception as e:
        print(f"Model training failed in process {os.getpid()}: {e}")
        raise


def main():
    if len(sys.argv) != 2:
        print("Usage: python -c '<script>' <config_file>")
        sys.exit(1)

    config_file = sys.argv[1]

    try:
        with open(config_file, 'r', encoding='utf-8') as f:
            config = json.load(f)
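
        # Expected config file shape (values here are hypothetical examples):
        #     {
        #         "dataset_dir": "/data/uavid",
        #         "pt_name": "best_segmentation_model.pt",
        #         "config_overrides": {"epochs": 50},
        #         "db_config": {"host": "localhost", "port": 3306},
        #         "task_id": 1
        #     }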

        # Extract configuration values
        dataset_dir = config['dataset_dir']
        pt_name = config['pt_name']
        config_overrides = config['config_overrides']
        db_config = config['db_config']
        task_id = config['task_id']

        # Get the current process ID
        pid = os.getpid()
        print(f"Training process started for task {task_id} with PID {pid}")

        # Record the PID in the database
        try:
            from middleware.query_model import ModelConfigDAO
            dao = ModelConfigDAO(db_config)
        except ImportError:
            dao = MockModelConfigDAO(db_config)

        dao.insert_train_pid(task_id, train_pid=pid)

        # Run training
        success = train_model(dataset_dir, pt_name, config_overrides)

        if success:
            print(f"Training completed successfully for task {task_id}")
            sys.exit(0)
        else:
            print(f"Training failed for task {task_id}")
            sys.exit(1)

    except Exception as e:
        print(f"Training error: {e}", file=sys.stderr)
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()