# 训练模型 (Training)
from uav_seg import UAVSegTrainer

# Configure a trainer for the dataset at `data_dir`, then run training.
trainer = UAVSegTrainer(
    data_dir='path/to/dataset',
    model_type='deeplabv3plus',  # supported choices: 'unetpp', 'deeplabv3plus'
    num_classes=7,
    batch_size=4,
    epochs=100,
    learning_rate=0.001,
)
trainer.train()
# 预测 (Prediction)
from uav_seg import UAVSegPredictor

# Build a predictor from saved weights. model_type / num_classes here mirror
# the training example above; presumably they must match the trained model.
predictor = UAVSegPredictor(
    model_path='path/to/saved/model',
    model_type='deeplabv3plus',
    num_classes=7,
)

# Run prediction on a single image; the result is stored in `mask`.
mask = predictor.predict('path/to/image.tif')

# Run prediction on every image in a folder (the second argument is
# presumably the directory where outputs are written — confirm in uav_seg).
predictor.predict_folder('path/to/images', 'path/to/output')