30 lines
1007 B
Python
30 lines
1007 B
Python
|
|
from uav_seg import UAVSegPredictor
|
|||
|
|
|
|||
|
|
if __name__ == '__main__':
|
|||
|
|
predictor = UAVSegPredictor(
|
|||
|
|
model_path='checkpoints/deeplabv3plus_best.pth',
|
|||
|
|
model_type='deeplabv3plus',
|
|||
|
|
num_classes=7
|
|||
|
|
)
|
|||
|
|
color_map = [
|
|||
|
|
[255, 0, 0], # 类别0: 红色
|
|||
|
|
[0, 255, 0], # 类别1: 绿色
|
|||
|
|
[0, 0, 255], # 类别2: 蓝色
|
|||
|
|
[255, 255, 0], # 类别3: 黄色
|
|||
|
|
[255, 0, 255], # 类别4: 品红
|
|||
|
|
[0, 255, 255], # 类别5: 青色
|
|||
|
|
[128, 0, 128] # 类别6: 紫色
|
|||
|
|
]
|
|||
|
|
|
|||
|
|
# 类别 0:黑色 ;背景
|
|||
|
|
# 类别 1:棕色 ;荒地
|
|||
|
|
# 类别 2:绿色 ;林地
|
|||
|
|
# 类别 3:黄色 ;农田
|
|||
|
|
# 类别 4:蓝色 ;水域
|
|||
|
|
# 类别 5:灰色 ;道路
|
|||
|
|
# 类别 6:青色 ;建筑用地
|
|||
|
|
#
|
|||
|
|
|
|||
|
|
# predictor.predict_folder('data/test', 'output_test', color_map)
|
|||
|
|
|
|||
|
|
predictor.predict('pic/patch_0011.png', 'output_test/patch_0011.png', color_map)
|