"""Semantic-segmentation inference helpers built on MMSegmentation.

Loads the model described by ``CONFIG`` / ``CHECKPOINT``, runs inference on
image files and writes a color-coded PNG of the predicted class mask into
``SAVE_DIR``.
"""
import argparse
import os
from pprint import pprint
from typing import List, Optional

import torch
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
from mmseg.apis import init_model, inference_model

from middleware.minio_util import downFile

# Run on the first CUDA device when one is available, otherwise on the CPU.
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Folder that holds the images to run inference on.
# Alternative location used previously: r"D:\project\UAV_model/need_prediction"
IMAGE_FILE_PATH = r"C:\Users\14867\Desktop\seg_tiff"

# Path of the config file produced by model training.
CONFIG = r'unetformer_UAV_6000X4000.py'

# Path of the trained weights (checkpoint) file.
CHECKPOINT = r'best_mIoU_iter_20000.pth'

# Directory the colorized prediction PNGs are written to.
SAVE_DIR = r"uv_predict_result"

# RGB color per class id (list index == class id).
# The class names come from the legend in the original comments, which was
# written for an older palette (black/brown/green/yellow/blue/gray/cyan) --
# TODO confirm the names against the training config.
COLOR_MAP = [
    [255, 0, 0],    # class 0: red     (background)
    [0, 255, 0],    # class 1: green   (wasteland)
    [0, 0, 255],    # class 2: blue    (woodland)
    [255, 255, 0],  # class 3: yellow  (farmland)
    [255, 0, 255],  # class 4: magenta (water)
    [0, 255, 255],  # class 5: cyan    (road)
    [128, 0, 128],  # class 6: purple  (construction land)
]


def make_full_path(root_list, root_path):
    """Join every name in *root_list* onto *root_path*.

    Args:
        root_list: Iterable of file names.
        root_path: Directory the names are relative to.

    Returns:
        A list of ``os.path.join(root_path, name)`` results, in input order.
    """
    return [os.path.join(root_path, filename) for filename in root_list]


def read_filepath(root):
    """Return the full paths of the entries in *root*, naturally sorted.

    Args:
        root: Directory to list.

    Returns:
        A list of paths, ordered by ``natsort.natsorted`` on the entry names.
    """
    # Imported lazily so natsort is only required when this helper is used.
    from natsort import natsorted

    test_image_list = natsorted(os.listdir(root))
    return make_full_path(test_image_list, root)


def save_colored_prediction(predictions, save_path):
    """Colorize a class-id mask with ``COLOR_MAP`` and save it as an image.

    Args:
        predictions: 2-D array of class ids (indexed as ``[row, col]``).
        save_path: Destination file path handed to ``PIL.Image.save``.

    Pixels whose class id has no entry in ``COLOR_MAP`` stay black.
    """
    # Start from an all-black RGB canvas with the mask's height and width.
    colored_image = np.zeros(
        (predictions.shape[0], predictions.shape[1], 3), dtype=np.uint8
    )

    # Paint each known class with its color.
    for class_id, color in enumerate(COLOR_MAP):
        colored_image[predictions == class_id] = color

    Image.fromarray(colored_image).save(save_path)


def predict_pic(taskid, ts, img_url, model=None):
    """Run inference on a single image and save its colorized mask.

    Inference is done one image at a time so callers can combine the
    per-image results however they need.

    Args:
        taskid: Task identifier. Currently unused; kept for the caller-facing
            signature.
        ts: Currently unused; kept for the caller-facing signature.
        img_url: Image input passed straight to ``inference_model``
            (the script entry point passes a local file path).
        model: Optional already-initialised model. When ``None`` the model is
            loaded from ``CONFIG`` / ``CHECKPOINT`` on ``DEVICE``; pass a model
            to avoid reloading the checkpoint for every image.

    Returns:
        Path of the PNG written under ``SAVE_DIR``.
    """
    if model is None:
        model = init_model(CONFIG, CHECKPOINT, device=DEVICE)

    result = inference_model(model, img_url)
    # Drop the leading dimension and move the prediction to a uint8 numpy
    # array (presumably (1, H, W) -> (H, W) -- TODO confirm).
    pred_mask = (
        result.pred_sem_seg.data.squeeze(0).detach().cpu().numpy().astype(np.uint8)
    )

    # Saving fails if the output directory is missing, so create it on demand.
    os.makedirs(SAVE_DIR, exist_ok=True)

    # NOTE: the output name is the part of the input file name before its
    # FIRST dot, so "a.b.png" is saved as "a.png".
    saves_path = os.path.join(
        SAVE_DIR, f"{os.path.basename(result.img_path).split('.')[0]}.png"
    )
    save_colored_prediction(pred_mask, saves_path)
    return saves_path


def predict(taskid, ts, img_url_list) -> Optional[List[str]]:
    """Run inference on several images, loading the model only once.

    Args:
        taskid: Task identifier, forwarded to ``predict_pic``.
        ts: Forwarded to ``predict_pic``.
        img_url_list: Image inputs, each handled by ``predict_pic``.

    Returns:
        The list of saved PNG paths, or ``None`` when there was nothing to
        process.
    """
    if not img_url_list:
        return None

    model_mmseg = init_model(CONFIG, CHECKPOINT, device=DEVICE)
    result_img_list = [
        predict_pic(taskid, ts, img_url, model=model_mmseg)
        for img_url in tqdm(img_url_list)
    ]
    return result_img_list or None


def main():
    """Download the sample images and run segmentation on them."""
    flight_task_id = "7a5c83e0-fe0d-47bf-a8e1-9bd663508783"  # task id

    # S3 image addresses to fetch.
    list_s3_url = [
        "test/patch_0011.png",
        "test/patch_0012.png",
    ]

    # downFile presumably returns the local path of the downloaded object --
    # TODO confirm; it is made absolute before inference.
    local_img_url_list = [
        os.path.abspath(downFile(img_url)) for img_url in list_s3_url
    ]

    predict(flight_task_id, 12, local_img_url_list)


if __name__ == '__main__':
    main()