145 lines
4.7 KiB
Python
145 lines
4.7 KiB
Python
|
|
import os
|
|||
|
|
import torch
|
|||
|
|
import cv2
|
|||
|
|
import argparse
|
|||
|
|
import numpy as np
|
|||
|
|
from pprint import pprint
|
|||
|
|
from tqdm import tqdm
|
|||
|
|
|
|||
|
|
from middleware.minio_util import downFile
|
|||
|
|
from mmseg.apis import init_model, inference_model
|
|||
|
|
|
|||
|
|
# ---------------------------------------------------------------------------
# Inference configuration for the segmentation model used by predict_pic.
# ---------------------------------------------------------------------------

# Use the first CUDA device when CUDA is available, otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')

# Folder containing the test images.
# (Alternative location: r"D:\project\UAV_model/need_prediction")
IMAGE_FILE_PATH = r"C:\Users\14867\Desktop\seg_tiff"

# Path of the mmseg config file the model was trained with.
CONFIG = r'unetformer_UAV_6000X4000.py'

# Path of the trained weights (checkpoint) file.
CHECKPOINT = r'best_mIoU_iter_20000.pth'

# Directory where colour-rendered prediction images are written.
SAVE_DIR = r"uv_predict_result"
|
|||
|
|
|
|||
|
|
|
|||
|
|
def make_full_path(root_list, root_path):
    """Prefix every name in ``root_list`` with the directory ``root_path``.

    Args:
        root_list: Iterable of file/directory names.
        root_path: Directory the names are relative to.

    Returns:
        A new list of ``os.path.join(root_path, name)`` strings, in the same
        order as ``root_list``.
    """
    return [os.path.join(root_path, name) for name in root_list]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def read_filepath(root):
    """List the entries of directory ``root`` as full paths, naturally sorted.

    Every name returned by ``os.listdir`` is included (no filtering by type or
    extension). The names are ordered with ``natsort.natsorted``, so numeric
    parts compare by value (e.g. ``patch_2`` before ``patch_10``). Each name is
    then joined onto ``root``.
    """
    # Imported here so the rest of the module does not require natsort.
    from natsort import natsorted

    return make_full_path(natsorted(os.listdir(root)), root)
|
|||
|
|
|
|||
|
|
|
|||
|
|
from PIL import Image
|
|||
|
|
|
|||
|
|
|
|||
|
|
def save_colored_prediction(predictions, save_path, color_map=None):
    """Render a class-id mask as an RGB image and save it to ``save_path``.

    Pixel ``(r, c)`` is painted ``color_map[k]`` when ``predictions[r, c] == k``.
    Pixels whose value matches no palette index (e.g. ids >= ``len(color_map)``)
    are left black ``(0, 0, 0)``.

    Args:
        predictions: 2-D array-like of class ids indexed ``[row, col]``;
            ``predict_pic`` passes a ``uint8`` numpy array.
        save_path: Output file path. Pillow chooses the image format from the
            file extension; the parent directory must already exist.
        color_map: Optional sequence of ``[R, G, B]`` triples (0-255), one per
            class id. When omitted, the default 7-colour palette below is used,
            so existing two-argument callers behave exactly as before.
    """
    if color_map is None:
        # NOTE(review): an older legend in this file assigned the classes as
        # 0 background, 1 barren land, 2 woodland, 3 farmland, 4 water,
        # 5 road, 6 construction land (drawn black/brown/green/yellow/blue/
        # gray/cyan). The palette was later switched to the colours below;
        # confirm the class order against the model's training config.
        color_map = [
            [255, 0, 0],    # class 0: red
            [0, 255, 0],    # class 1: green
            [0, 0, 255],    # class 2: blue
            [255, 255, 0],  # class 3: yellow
            [255, 0, 255],  # class 4: magenta
            [0, 255, 255],  # class 5: cyan
            [128, 0, 128],  # class 6: purple
        ]

    predictions = np.asarray(predictions)

    # Start from an all-black RGB canvas the size of the mask.
    colored_image = np.zeros((predictions.shape[0], predictions.shape[1], 3), dtype=np.uint8)

    # Paint each class's pixels with that class's colour.
    for class_id, color in enumerate(color_map):
        colored_image[predictions == class_id] = color

    Image.fromarray(colored_image).save(save_path)
|
|||
|
|
|
|||
|
|
#
|
|||
|
|
# def predict(taskid, ts, img_url_list):
|
|||
|
|
# model_mmseg = init_model(CONFIG, CHECKPOINT, device=DEVICE)
|
|||
|
|
# result_img_list = []
|
|||
|
|
# for imgs in tqdm(img_url_list):
|
|||
|
|
# result = inference_model(model_mmseg, imgs)
|
|||
|
|
# pred_mask = result.pred_sem_seg.data.squeeze(0).detach().cpu().numpy().astype(np.uint8)
|
|||
|
|
#
|
|||
|
|
# # save_path = os.path.join(SAVE_DIR, f"{os.path.basename(CONFIG).split('.')[0]}")
|
|||
|
|
# # if not os.path.exists(save_path):
|
|||
|
|
# # os.makedirs(save_path)
|
|||
|
|
# saves_path = os.path.join(SAVE_DIR, f"{os.path.basename(result.img_path).split('.')[0]}.png")
|
|||
|
|
# save_colored_prediction(pred_mask, saves_path)
|
|||
|
|
# result_img_list.append(saves_path)
|
|||
|
|
#
|
|||
|
|
# if len(result_img_list) > 0:
|
|||
|
|
# return result_img_list
|
|||
|
|
# else:
|
|||
|
|
# return None
|
|||
|
|
#
|
|||
|
|
# #
|
|||
|
|
# # pred_mask[pred_mask == 1] = 255
|
|||
|
|
# # save_path = os.path.join(args.save_dir, f"{os.path.basename(args.config).split('.')[0]}")
|
|||
|
|
# #
|
|||
|
|
# # if not os.path.exists(save_path):
|
|||
|
|
# # os.makedirs(save_path)
|
|||
|
|
# #
|
|||
|
|
# # cv2.imwrite(os.path.join(save_path, f"{os.path.basename(result.img_path).split('.')[0]}.png"), pred_mask,
|
|||
|
|
# # [cv2.IMWRITE_PNG_COMPRESSION, 0])
|
|||
|
|
#
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Loaded models keyed by (config, checkpoint, device). This lets repeated
# predict_pic calls reuse one network instead of rebuilding it and reloading
# the weights for every image.
_MODEL_CACHE = {}


def _get_model():
    """Return the mmseg model for the current CONFIG/CHECKPOINT/DEVICE, loading it at most once."""
    key = (CONFIG, CHECKPOINT, str(DEVICE))
    model = _MODEL_CACHE.get(key)
    if model is None:
        model = init_model(CONFIG, CHECKPOINT, device=DEVICE)
        _MODEL_CACHE[key] = model
    return model


# Inference is done one image per call; batching several images in one call
# makes it awkward to combine the per-image results afterwards.
def predict_pic(taskid, ts, img_url):
    """Segment a single image and save a colour-rendered mask under SAVE_DIR.

    Args:
        taskid: Task identifier; currently unused by this function.
        ts: Currently unused by this function (the caller passes a number).
        img_url: Local path of the image to segment, passed to
            ``mmseg.apis.inference_model``.

    Returns:
        Path of the saved PNG: ``SAVE_DIR/<input file name without extension>.png``.
    """
    model_mmseg = _get_model()

    result = inference_model(model_mmseg, img_url)
    # Drop the leading axis of the predicted label map. This presumably turns
    # (1, H, W) into (H, W) -- TODO confirm against mmseg's SegDataSample.
    pred_mask = result.pred_sem_seg.data.squeeze(0).detach().cpu().numpy().astype(np.uint8)

    # The output directory is not created anywhere else, and saving into a
    # missing directory fails.
    os.makedirs(SAVE_DIR, exist_ok=True)

    # splitext keeps dotted stems intact: "a.b.png" -> "a.b", not "a".
    stem = os.path.splitext(os.path.basename(result.img_path))[0]
    saves_path = os.path.join(SAVE_DIR, f"{stem}.png")
    save_colored_prediction(pred_mask, saves_path)
    return saves_path
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == '__main__':
    # Example run: fetch two test patches from object storage and segment them.
    flight_task_id = "7a5c83e0-fe0d-47bf-a8e1-9bd663508783"  # flight task id
    list_s3_url = [
        "test/patch_0011.png",
        "test/patch_0012.png",
    ]  # S3 object keys of the input images

    # downFile returns a local path for each object; make it absolute.
    local_img_url_list = [os.path.abspath(downFile(img_url)) for img_url in list_s3_url]

    # The batch `predict` helper exists only as commented-out code, so calling
    # it raised NameError. Segment the images one at a time with predict_pic.
    # The value 12 is passed as `ts`, which predict_pic does not use.
    result_paths = [
        predict_pic(flight_task_id, 12, img_path)
        for img_path in tqdm(local_img_url_list)
    ]
    pprint(result_paths)
|