21 lines
871 B
Python
21 lines
871 B
Python
|
from uav_seg import UAVSegPredictor
|
||
|
|
||
|
def test_uav_seg(model_path, input_path, output_path):
|
||
|
predictor = UAVSegPredictor(
|
||
|
model_path= model_path,
|
||
|
model_type='deeplabv3plus',
|
||
|
num_classes=7
|
||
|
)
|
||
|
color_map = [
|
||
|
[255, 0, 0], # 类别0: 红色;背景
|
||
|
[0, 255, 0], # 类别1: 绿色;荒地
|
||
|
[0, 0, 255], # 类别2: 蓝色;林地
|
||
|
[255, 255, 0], # 类别3: 黄色;农田
|
||
|
[255, 0, 255], # 类别4: 品红;水
|
||
|
[0, 255, 255], # 类别5: 青色;道路
|
||
|
[128, 0, 128] # 类别6: 紫色;建筑用地
|
||
|
]
|
||
|
predictor.predict_folder(input_path, output_path,color_map)
|
||
|
|
||
|
if __name__ == '__main__':
|
||
|
test_uav_seg('checkpoints/deeplabv3plus_best.pth', r'D:\work\AI_Python\Ai_tottle\map\test_demo', './output')
|