39 lines
660 B
Markdown
39 lines
660 B
Markdown

### 训练模型

```python
from uav_seg import UAVSegTrainer

trainer = UAVSegTrainer(
    data_dir='path/to/dataset',
    model_type='deeplabv3plus',  # 可选: 'unetpp', 'deeplabv3plus'
    num_classes=7,
    batch_size=4,
    epochs=100,
    learning_rate=0.001,
)

trainer.train()
```

### 预测

```python
from uav_seg import UAVSegPredictor

predictor = UAVSegPredictor(
    model_path='path/to/saved/model',
    model_type='deeplabv3plus',
    num_classes=7,
)

# 对单张图像进行预测
mask = predictor.predict('path/to/image.tif')

# 对文件夹中的所有图像进行预测
predictor.predict_folder('path/to/images', 'path/to/output')
```