from uav_seg import UAVSegPredictor if __name__ == '__main__': predictor = UAVSegPredictor( model_path='checkpoints/deeplabv3plus_best.pth', model_type='deeplabv3plus', num_classes=7 ) color_map = [ [255, 0, 0], # 类别0: 红色 [0, 255, 0], # 类别1: 绿色 [0, 0, 255], # 类别2: 蓝色 [255, 255, 0], # 类别3: 黄色 [255, 0, 255], # 类别4: 品红 [0, 255, 255], # 类别5: 青色 [128, 0, 128] # 类别6: 紫色 ] # 类别 0:黑色 ;背景 # 类别 1:棕色 ;荒地 # 类别 2:绿色 ;林地 # 类别 3:黄色 ;农田 # 类别 4:蓝色 ;水域 # 类别 5:灰色 ;道路 # 类别 6:青色 ;建筑用地 # # predictor.predict_folder('data/test', 'output_test', color_map) predictor.predict('pic/patch_0011.png', 'output_test/patch_0011.png', color_map)