ai_project_v1/uv_module/uv_prediction.py

145 lines
4.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import torch
import cv2
import argparse
import numpy as np
from pprint import pprint
from tqdm import tqdm
from middleware.minio_util import downFile
from mmseg.apis import init_model, inference_model
"""
"""
# Run on the first CUDA device when one is available; otherwise fall back to CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Folder holding the images to run inference on.
# IMAGE_FILE_PATH = r"D:\project\UAV_model/need_prediction"
IMAGE_FILE_PATH = r"C:\Users\14867\Desktop\seg_tiff"

# Path of the config file produced by model training.
CONFIG = r'unetformer_UAV_6000X4000.py'

# Path of the weights (checkpoint) file produced by model training.
CHECKPOINT = r'best_mIoU_iter_20000.pth'

# Directory that receives the inference results.
# NOTE: the original comment described a `{save_dir}/{config-named folder}`
# layout; predict_pic() in this file writes its PNGs directly into SAVE_DIR.
SAVE_DIR = r"uv_predict_result"
def make_full_path(root_list, root_path):
    """Join every name in ``root_list`` onto ``root_path``.

    Args:
        root_list: Iterable of file names (or relative paths).
        root_path: Directory prefix placed in front of each name.

    Returns:
        A new list with one ``os.path.join(root_path, name)`` entry per
        input name, in the same order as ``root_list``.
    """
    return [os.path.join(root_path, name) for name in root_list]
def read_filepath(root):
    """List the entries of ``root`` in natural sort order as joined paths.

    Args:
        root: Directory whose entries are listed with ``os.listdir``.

    Returns:
        A list of ``os.path.join(root, name)`` paths, ordered by
        ``natsort.natsorted``.
    """
    # Imported inside the function, so natsort is only needed when this
    # helper is actually called.
    from natsort import natsorted

    ordered_names = natsorted(os.listdir(root))
    return make_full_path(ordered_names, root)
from PIL import Image
def save_colored_prediction(predictions, save_path):
    """Colourise a class-index mask and write it to ``save_path``.

    Args:
        predictions: Array of class ids; its first two dimensions give the
            height and width of the output image.
        save_path: Destination passed to ``PIL.Image.Image.save``.

    Pixels whose value is not one of the palette indices (0-6) keep the
    canvas' initial colour, black.
    """
    # Class meanings noted in the original source. The colour names in that
    # note matched an earlier, now-disabled palette rather than the active one
    # below -- TODO confirm the class mapping is still current:
    #   0 background, 1 wasteland, 2 forest, 3 farmland, 4 water, 5 road,
    #   6 built-up land
    # Earlier palette, kept for reference:
    #   [0, 0, 0], [165, 42, 42], [0, 255, 0], [255, 255, 0],
    #   [0, 0, 255], [128, 128, 128], [0, 255, 255]
    palette = (
        (255, 0, 0),    # class 0: red
        (0, 255, 0),    # class 1: green
        (0, 0, 255),    # class 2: blue
        (255, 255, 0),  # class 3: yellow
        (255, 0, 255),  # class 4: magenta
        (0, 255, 255),  # class 5: cyan
        (128, 0, 128),  # class 6: purple
    )

    # Start from an all-black RGB canvas with the mask's height and width.
    rows = predictions.shape[0]
    cols = predictions.shape[1]
    canvas = np.zeros((rows, cols, 3), dtype=np.uint8)

    # Paint each class' pixels with its palette colour.
    for class_id, rgb in enumerate(palette):
        canvas[predictions == class_id] = rgb

    # Convert to a PIL image and write it out.
    Image.fromarray(canvas).save(save_path)
#
# def predict(taskid, ts, img_url_list):
# model_mmseg = init_model(CONFIG, CHECKPOINT, device=DEVICE)
# result_img_list = []
# for imgs in tqdm(img_url_list):
# result = inference_model(model_mmseg, imgs)
# pred_mask = result.pred_sem_seg.data.squeeze(0).detach().cpu().numpy().astype(np.uint8)
#
# # save_path = os.path.join(SAVE_DIR, f"{os.path.basename(CONFIG).split('.')[0]}")
# # if not os.path.exists(save_path):
# # os.makedirs(save_path)
# saves_path = os.path.join(SAVE_DIR, f"{os.path.basename(result.img_path).split('.')[0]}.png")
# save_colored_prediction(pred_mask, saves_path)
# result_img_list.append(saves_path)
#
# if len(result_img_list) > 0:
# return result_img_list
# else:
# return None
#
# #
# # pred_mask[pred_mask == 1] = 255
# # save_path = os.path.join(args.save_dir, f"{os.path.basename(args.config).split('.')[0]}")
# #
# # if not os.path.exists(save_path):
# # os.makedirs(save_path)
# #
# # cv2.imwrite(os.path.join(save_path, f"{os.path.basename(result.img_path).split('.')[0]}.png"), pred_mask,
# # [cv2.IMWRITE_PNG_COMPRESSION, 0])
#
# Models already built by _get_model(), keyed by (config, checkpoint, device),
# so repeated predict_pic() calls do not reload the checkpoint every time.
_MODEL_CACHE = {}


def _get_model():
    """Return the mmseg model for CONFIG/CHECKPOINT/DEVICE, building it once."""
    cache_key = (CONFIG, CHECKPOINT, str(DEVICE))
    if cache_key not in _MODEL_CACHE:
        _MODEL_CACHE[cache_key] = init_model(CONFIG, CHECKPOINT, device=DEVICE)
    return _MODEL_CACHE[cache_key]


# Single-image inference: results for several images are awkward to combine,
# so each call handles exactly one image.
def predict_pic(taskid, ts, img_url):
    """Segment one image and save the colourised mask as a PNG in SAVE_DIR.

    Args:
        taskid: Task identifier. Accepted for interface compatibility; not
            used by the current implementation.
        ts: Accepted for interface compatibility; not used by the current
            implementation.
        img_url: Image handed to ``inference_model`` (the script entry point
            passes a local file path).

    Returns:
        Path of the written PNG: ``SAVE_DIR`` joined with the input image's
        base name up to its first dot, plus ``.png``.
    """
    model_mmseg = _get_model()
    result = inference_model(model_mmseg, img_url)

    # Drop the leading dimension of the predicted mask and move it to a
    # uint8 NumPy array on the CPU.
    pred_mask = (
        result.pred_sem_seg.data.squeeze(0).detach().cpu().numpy().astype(np.uint8)
    )

    # The output directory is not guaranteed to exist; without this the save
    # below fails on a fresh checkout.
    os.makedirs(SAVE_DIR, exist_ok=True)

    # NOTE: only the part of the file name before the FIRST dot is kept, so
    # inputs such as "a.1.png" and "a.2.png" would map to the same output.
    saves_path = os.path.join(
        SAVE_DIR, f"{os.path.basename(result.img_path).split('.')[0]}.png"
    )
    save_colored_prediction(pred_mask, saves_path)
    return saves_path
if __name__ == '__main__':
    flight_task_id = "7a5c83e0-fe0d-47bf-a8e1-9bd663508783"  # task id
    list_s3_url = [
        "test/patch_0011.png",
        "test/patch_0012.png",
    ]  # S3 addresses of the images to process

    # Fetch each image with downFile() and keep the absolute form of the path
    # it returns (presumably the local download location -- verify against
    # middleware.minio_util).
    local_img_url_list = []
    for img_url in list_s3_url:
        pic = downFile(img_url)
        pic_path = os.path.abspath(pic)
        local_img_url_list.append(pic_path)

    list_func_id = [20000, 20001]  # method ids (not used by this script)

    # The batch predict() function exists only as commented-out code, so
    # calling it raised NameError. Run the single-image predict_pic() once per
    # downloaded image instead and collect the saved result paths.
    result_img_list = []
    for local_img_url in local_img_url_list:
        result_img_list.append(predict_pic(flight_task_id, 12, local_img_url))
    pprint(result_img_list)