28 lines
774 B
Python
28 lines
774 B
Python
|
|
import argparse
|
||
|
|
import torch
|
||
|
|
from ultralytics import YOLO
|
||
|
|
|
||
|
|
def main():
    """Train a YOLO11n model from the command line and save the resulting weights.

    Command-line arguments:
        --dataset: Directory that contains the ``data.yaml`` dataset
            description (required).
        --weight: Destination file for the trained weights (required).
        --epochs: Number of training epochs (default: 100).
        --batch: Training batch size (default: 8).

    Training runs on CUDA when ``torch.cuda.is_available()`` is true and on
    the CPU otherwise. A confirmation line is printed once the weights have
    been written.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset", required=True, help="directory containing data.yaml"
    )
    parser.add_argument(
        "--weight", required=True, help="output path for the trained weights"
    )
    parser.add_argument(
        "--epochs", type=int, default=100, help="number of training epochs"
    )
    parser.add_argument(
        "--batch", type=int, default=8, help="training batch size"
    )
    args = parser.parse_args()

    # Train the model, starting from the pretrained YOLO11n checkpoint.
    model = YOLO("yolo11n.pt")
    model.train(
        data=f"{args.dataset}/data.yaml",
        epochs=args.epochs,
        batch=args.batch,
        imgsz=640,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    # Save the model.
    # The previous `torch.save(model.best.pt, ...)` read a `best` attribute
    # that the YOLO wrapper does not provide, so the script crashed with
    # AttributeError only after the whole training run had completed.
    # `save()` is the wrapper's public API for writing the model it currently
    # holds. NOTE(review): per the Ultralytics docs the wrapper reloads the
    # best (or last) checkpoint after train(); confirm for the pinned version.
    model.save(args.weight)
    print(f"Training finished. Model saved to {args.weight}")
|
||
|
|
|
||
|
|
# Run the training CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|