# ai_project_v1/uv_module/prediction.py
# (123 lines, 4.2 KiB, Python; header copied from the code-hosting web view.)
#
# The hosting page warned that this file contains ambiguous Unicode characters
# that might be confused with other characters. If that is unintentional,
# review the non-ASCII characters in the comments.

import os
import torch
import cv2
import argparse
import numpy as np
from pprint import pprint
from tqdm import tqdm
from mmseg.apis import init_model, inference_model
"""
"""
# Inference device: the first CUDA GPU when one is present, otherwise the CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Config (.py) file written by the training run of the model.
CONFIG = r'D:\project\UAV_model/tools/work_dirs/unetformer_UAV_6000X4000/unetformer_UAV_6000X4000.py'
# Trained weights (checkpoint) file to load into that model.
CHECKPOINT = r'D:\project\UAV_model/tools/work_dirs/unetformer_UAV_6000X4000/best_mIoU_iter_20000.pth'

# Folder holding the images to segment.
# Alternative input folder: r"D:\project\UAV_model/need_prediction"
IMAGE_FILE_PATH = r"C:\Users\14867\Desktop\seg_tiff"
# Root folder for prediction results; each model's outputs are written into
# `{SAVE_DIR}/{name of the model's config file}`.
SAVE_DIR = r"D:\project\UAV_model/prediction_results"
def parse_args(argv=None):
    """Parse command-line options for batch segmentation inference.

    Every option falls back to the module-level constant of the same purpose
    (``IMAGE_FILE_PATH``, ``CONFIG``, ``CHECKPOINT``, ``DEVICE``, ``SAVE_DIR``),
    so the script can be run without any arguments.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (handy for programmatic calls). ``None`` (the
            default) keeps the normal command-line behaviour.

    Returns:
        argparse.Namespace with attributes ``img``, ``config``, ``checkpoint``,
        ``device`` and ``save_dir``.
    """
    parser = argparse.ArgumentParser(
        description='Run MMSegmentation inference on a folder of images and '
                    'save color-coded prediction masks')
    parser.add_argument('--img', default=IMAGE_FILE_PATH,
                        help='Directory containing the images to segment')
    parser.add_argument('--config', default=CONFIG, help='Model config file')
    parser.add_argument('--checkpoint', default=CHECKPOINT,
                        help='Model checkpoint (.pth) file')
    # Store the default as a string so args.device has the same type whether it
    # comes from the command line (always str) or from this default.
    parser.add_argument('--device', default=str(DEVICE),
                        help="Device to run inference on, e.g. 'cuda:0' or 'cpu'")
    parser.add_argument('--save_dir', default=SAVE_DIR,
                        help='Root directory for the colored prediction masks')
    return parser.parse_args(argv)
def make_full_path(root_list, root_path):
    """Join each name in *root_list* onto the directory *root_path*.

    Args:
        root_list: Iterable of file names (e.g. the result of ``os.listdir``).
        root_path: Directory those names are relative to.

    Returns:
        A new list of ``os.path.join(root_path, name)`` strings, in the same
        order as *root_list*.
    """
    return [os.path.join(root_path, filename) for filename in root_list]
def read_filepath(root):
    """Return the image paths to run inference on, in natural sort order.

    Args:
        root: A directory containing the images, or the path of a single
            image file.

    Returns:
        ``[root]`` if *root* is a file. Otherwise the full paths of the regular
        files directly inside *root*, naturally sorted by name (``img2`` before
        ``img10``). Sub-directories are skipped because they are not images.
    """
    # Accept a single image as-is; os.listdir would raise NotADirectoryError.
    if os.path.isfile(root):
        return [root]
    from natsort import natsorted
    # Only keep regular files so that e.g. an output folder placed next to the
    # images is not handed to the model as if it were an image.
    file_names = natsorted(
        [name for name in os.listdir(root)
         if os.path.isfile(os.path.join(root, name))]
    )
    return make_full_path(file_names, root)
from PIL import Image
def save_colored_prediction(predictions, save_path, color_map=None):
    """Render a class-id mask as an RGB image and write it to *save_path*.

    Args:
        predictions: 2-D array of per-pixel class ids, indexed as
            ``predictions[row, col]`` (``main`` passes a ``uint8`` mask).
        save_path: Output file path; PIL picks the image format from the file
            extension (``main`` uses ``.png``).
        color_map: Optional sequence of ``[R, G, B]`` triples (0-255), where
            entry ``i`` is the color for class id ``i``. Defaults to the
            7-color palette below.

    Pixels whose class id has no entry in the color map stay black (0, 0, 0).
    """
    if color_map is None:
        # Default palette, indexed by class id. The class names in brackets
        # come from an earlier legend in this file and are presumably the
        # model's classes -- confirm against the training config's class list.
        color_map = [
            [255, 0, 0],      # 0: red     (background)
            [0, 255, 0],      # 1: green   (barren land)
            [0, 0, 255],      # 2: blue    (forest)
            [255, 255, 0],    # 3: yellow  (farmland)
            [255, 0, 255],    # 4: magenta (water)
            [0, 255, 255],    # 5: cyan    (road)
            [128, 0, 128],    # 6: purple  (construction land)
        ]
    # Start from an all-black RGB canvas with the mask's height and width.
    colored_image = np.zeros((predictions.shape[0], predictions.shape[1], 3),
                             dtype=np.uint8)
    # Paint every pixel of each class with that class's color; ids without a
    # palette entry are never matched and therefore stay black.
    for class_id, color in enumerate(color_map):
        colored_image[predictions == class_id] = color
    Image.fromarray(colored_image).save(save_path)
def main():
    """Segment every image from ``args.img`` and save color-coded masks.

    Masks are written to ``{save_dir}/{config name}/{image name}.png``, where
    both names are file names with only their final extension removed.
    """
    args = parse_args()
    model_mmseg = init_model(args.config, args.checkpoint, device=args.device)

    # All results of one model go into a folder named after its config file.
    # This is loop-invariant, so compute and create it once; exist_ok replaces
    # the separate exists() check.
    config_name = os.path.splitext(os.path.basename(args.config))[0]
    output_dir = os.path.join(args.save_dir, config_name)
    os.makedirs(output_dir, exist_ok=True)

    for img_path in tqdm(read_filepath(args.img)):
        result = inference_model(model_mmseg, img_path)
        # Drop the leading dimension of the predicted class-id map (presumably
        # shape (1, H, W)) and bring it to host memory as uint8 class ids.
        pred_mask = (result.pred_sem_seg.data.squeeze(0)
                     .detach().cpu().numpy().astype(np.uint8))
        # splitext only strips the last extension ("a.1.tif" -> "a.1"), so
        # images sharing a prefix before the first dot no longer overwrite
        # each other's output.
        image_name = os.path.splitext(os.path.basename(result.img_path))[0]
        save_colored_prediction(pred_mask,
                                os.path.join(output_dir, f"{image_name}.png"))


if __name__ == '__main__':
    main()