"""Train a YOLO11n detector on a dataset and export the best checkpoint."""

import argparse
import shutil
from pathlib import Path

import torch
from ultralytics import YOLO


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of the training script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", required=True)
    parser.add_argument("--weight", required=True)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--batch", type=int, default=8)
    return parser.parse_args()


def _resolve_checkpoint(model: YOLO) -> Path:
    """Return the path of the checkpoint written by the finished training run.

    Prefers the trainer's ``best`` checkpoint and falls back to ``last``.

    Raises:
        FileNotFoundError: If neither checkpoint exists on disk.
    """
    trainer = getattr(model, "trainer", None)
    for attr in ("best", "last"):
        candidate = getattr(trainer, attr, None)
        if candidate is not None and Path(candidate).exists():
            return Path(candidate)
    raise FileNotFoundError(
        "Training finished but no checkpoint (best/last) was found."
    )


def main() -> None:
    """Train ``yolo11n.pt`` on ``<dataset>/data.yaml`` and save the weights.

    The best checkpoint produced by training is copied to the path given
    by ``--weight``.

    Raises:
        FileNotFoundError: If training produced no checkpoint to export.
    """
    args = _parse_args()

    # Train the model.
    model = YOLO("yolo11n.pt")
    model.train(
        data=f"{args.dataset}/data.yaml",
        epochs=args.epochs,
        batch=args.batch,
        imgsz=640,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    # Save the model. The YOLO object has no ``best.pt`` attribute; the
    # trainer writes checkpoint files, so copy the best one to the
    # requested destination instead of calling ``torch.save`` on it.
    checkpoint = _resolve_checkpoint(model)
    destination = Path(args.weight)
    destination.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(checkpoint, destination)
    print(f"Training finished. Model saved to {args.weight}")


if __name__ == "__main__":
    main()