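"""ai-train_platform/train_worker.py

Command-line worker that trains an Ultralytics YOLO model on a dataset and
saves the trained weights to a caller-chosen path.
"""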

import argparse
import shutil
from pathlib import Path

import torch
from ultralytics import YOLO
def main(argv=None):
    """Train a YOLO model on a dataset and save the best checkpoint.

    Command-line arguments:
        --dataset     Directory containing the dataset's ``data.yaml``.
        --weight      Destination path for the trained weights file.
        --epochs      Number of training epochs (default: 100).
        --batch       Training batch size (default: 8).
        --base-model  Checkpoint/config YOLO starts from (default: yolo11n.pt).
        --imgsz       Training image size (default: 640).
        --device      Device string passed to training; when omitted, "cuda"
                      is used if torch reports CUDA available, else "cpu".

    Args:
        argv: Optional list of argument strings; ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as the original CLI did.

    Raises:
        SystemExit: via ``parser.error`` when ``<dataset>/data.yaml`` is missing.
        FileNotFoundError: if training left no checkpoint file to copy.
    """
    parser = argparse.ArgumentParser(
        description="Train a YOLO model and save the best weights."
    )
    parser.add_argument("--dataset", required=True)
    parser.add_argument("--weight", required=True)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--batch", type=int, default=8)
    parser.add_argument("--base-model", default="yolo11n.pt")
    parser.add_argument("--imgsz", type=int, default=640)
    parser.add_argument("--device", default=None)
    args = parser.parse_args(argv)

    # Fail fast with a clear CLI error instead of deep inside the trainer.
    data_yaml = Path(args.dataset) / "data.yaml"
    if not data_yaml.is_file():
        parser.error(f"dataset config not found: {data_yaml}")

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")

    # Train the model.
    model = YOLO(args.base_model)
    model.train(
        data=str(data_yaml),
        epochs=args.epochs,
        batch=args.batch,
        imgsz=args.imgsz,
        device=device,
    )

    # Save the model. A YOLO object has no `best` attribute (the previous
    # `model.best.pt` raised AttributeError); the trainer records the path of
    # the best.pt checkpoint it wrote. Copying that file keeps the output in
    # the format YOLO(...) can load again.
    trainer = model.trainer
    checkpoint = Path(trainer.best)
    if not checkpoint.is_file():
        # No best.pt (e.g. training stopped early); use the last checkpoint.
        checkpoint = Path(trainer.last)
    if not checkpoint.is_file():
        raise FileNotFoundError(
            f"training produced no checkpoint in {checkpoint.parent}"
        )

    destination = Path(args.weight)
    destination.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(checkpoint, destination)
    print(f"Training finished. Model saved to {args.weight}")
if __name__ == "__main__":
    # Start training only when run as a script, never on import.
    main()