2025-07-10 09:41:26 +08:00

210 lines
7.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
import rasterio
from rasterio.windows import Window
from .models import get_model
from .data import TileDataset, get_predict_transforms
class UAVSegPredictor:
    """Sliding-window semantic segmentation predictor for large remote-sensing images.

    The input raster is split into tiles by ``TileDataset``; each tile is run
    through the model and the per-class softmax probabilities are accumulated
    into a full-resolution buffer.  Overlapping tiles are blended either with
    a Gaussian weight kernel or by plain averaging, and the final mask is the
    per-pixel argmax over the blended probabilities.
    """

    def __init__(self,
                 model_path,
                 model_type='deeplabv3plus',
                 num_classes=6,
                 tile_size=512,
                 stride=None,
                 device=None,
                 overlap_weights=True):
        """Create the predictor and load the model weights.

        Args:
            model_path (str): Path to the model weights / checkpoint file.
            model_type (str): Model type forwarded to ``get_model``,
                e.g. 'deeplabv3plus' or 'unetpp'.
            num_classes (int): Number of segmentation classes.
            tile_size (int): Side length of the square tiles, in pixels.
            stride (int): Sliding-window step; ``None`` means half of
                ``tile_size``.
            device (str): Torch device string; ``None`` selects CUDA when
                available and falls back to CPU otherwise.
            overlap_weights (bool): Blend overlapping tiles with a Gaussian
                weight kernel instead of a plain average.
        """
        self.model_path = model_path
        self.model_type = model_type
        self.num_classes = num_classes
        self.tile_size = tile_size
        self.stride = stride if stride is not None else tile_size // 2
        self.overlap_weights = overlap_weights

        # Select the compute device.
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)
        print(f"使用设备: {self.device}")

        # Load the model, then build the preprocessing pipeline.
        self._load_model()
        self.transform = get_predict_transforms()

    def _load_model(self):
        """Build the model, load its weights and switch it to eval mode.

        The checkpoint may be either a bare state dict or a dict that holds
        the state dict under ``'model_state_dict'`` (optionally with a
        ``'model_type'`` entry, which is only used to print a mismatch
        warning).
        """
        self.model = get_model(self.model_type, self.num_classes)

        # NOTE: torch.load unpickles the file, so only load checkpoints that
        # come from a trusted source.
        checkpoint = torch.load(self.model_path, map_location=self.device)

        # Warn (but do not fail) when the checkpoint was saved for another model type.
        if 'model_type' in checkpoint and checkpoint['model_type'] != self.model_type:
            print(f"警告: 加载的模型类型 ({checkpoint['model_type']}) 与当前模型类型 ({self.model_type}) 不匹配")

        if 'model_state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            self.model.load_state_dict(checkpoint)

        self.model.to(self.device)
        self.model.eval()
        print(f"模型已从 {self.model_path} 加载")

    def predict(self, image_path, output_path=None, color_map=None):
        """Predict the segmentation mask of a single image.

        Args:
            image_path (str): Path of the input image.
            output_path (str): Where to write the class-index mask; ``None``
                disables saving.
            color_map: Indexable mapping from class index to an RGB triple.
                When given (and ``output_path`` is set) a colored
                visualisation is also written next to the mask as
                ``<output stem>_colored.png``.  ``None`` skips the colored
                image.

        Returns:
            numpy.ndarray: ``(height, width)`` uint8 array of class indices.
        """
        # Presumably yields (tensor, info) pairs where info['x'] / info['y']
        # are the tile's column / row offset in the image -- confirm against
        # TileDataset.
        dataset = TileDataset(
            image_path=image_path,
            tile_size=self.tile_size,
            stride=self.stride,
            transform=self.transform
        )

        # Image size and georeferencing profile of the source raster.
        with rasterio.open(image_path) as src:
            height, width = src.height, src.width
            profile = src.profile.copy()

        # Accumulated (weighted) probabilities and the matching weight sum.
        result = np.zeros((self.num_classes, height, width), dtype=np.float32)
        weights = np.zeros((height, width), dtype=np.float32)

        # None selects plain averaging inside _accumulate_tile.
        weight_kernel = self._get_weight_kernel() if self.overlap_weights else None

        with torch.no_grad():
            for image, tile_info in tqdm(dataset, desc="预测分块"):
                x, y = tile_info['x'], tile_info['y']
                image = image.unsqueeze(0).to(self.device)
                # Logits -> per-class probabilities -> (C, H, W) NumPy array.
                pred = F.softmax(self.model(image), dim=1)
                pred = pred.squeeze(0).cpu().numpy()
                self._accumulate_tile(result, weights, pred, x, y, weight_kernel)

        # Normalise in place.  Pixels that no tile covered keep their
        # accumulated value of 0, so no temporary buffer is needed.
        for c in range(self.num_classes):
            np.divide(result[c], weights, out=result[c], where=weights > 0)

        # Final mask: most probable class per pixel.
        mask = np.argmax(result, axis=0).astype(np.uint8)

        if output_path is not None:
            # NOTE(review): the GTiff driver is used regardless of the
            # extension of output_path.
            profile.update({
                'driver': 'GTiff',
                'height': height,
                'width': width,
                'count': 1,
                'dtype': 'uint8',
                'nodata': None
            })

            # The colored visualisation is optional: without a color map
            # there is nothing to look the colors up in.
            if color_map is not None:
                from PIL import Image
                colored_mask = self._colorize_mask(mask, color_map, self.num_classes)
                colored_path = self._colored_output_path(output_path)
                Image.fromarray(colored_mask).save(colored_path)
                print(f"彩色预测图已保存到 {colored_path}")

            # Save the raw class-index mask.
            with rasterio.open(output_path, 'w', **profile) as dst:
                dst.write(mask, 1)
            print(f"预测结果已保存到 {output_path}")

        return mask

    @staticmethod
    def _accumulate_tile(result, weights, pred, x, y, weight_kernel=None):
        """Add one tile's probabilities into the full-image accumulators.

        The tile is clipped to the image bounds, so a tile that sticks out
        past the right or bottom border contributes only its in-bounds part.

        Args:
            result (numpy.ndarray): ``(C, H, W)`` probability accumulator,
                updated in place.
            weights (numpy.ndarray): ``(H, W)`` weight accumulator, updated
                in place.
            pred (numpy.ndarray): ``(C, h, w)`` probabilities of the tile.
            x (int): Column of the tile's left edge in the image.
            y (int): Row of the tile's top edge in the image.
            weight_kernel (numpy.ndarray): Per-pixel blending weights of the
                tile; ``None`` gives every pixel a weight of 1.
        """
        x, y = int(x), int(y)
        height, width = weights.shape
        y_end = min(y + pred.shape[1], height)
        x_end = min(x + pred.shape[2], width)
        h, w = y_end - y, x_end - x
        if h <= 0 or w <= 0:
            # Tile lies completely outside the image.
            return

        tile = pred[:result.shape[0], :h, :w]
        if weight_kernel is None:
            result[:, y:y_end, x:x_end] += tile
            weights[y:y_end, x:x_end] += 1
        else:
            kernel = weight_kernel[:h, :w]
            result[:, y:y_end, x:x_end] += tile * kernel
            weights[y:y_end, x:x_end] += kernel

    @staticmethod
    def _colorize_mask(mask, color_map, num_classes):
        """Return an ``(H, W, 3)`` uint8 RGB image of ``mask`` colored via ``color_map``."""
        colored_mask = np.zeros(mask.shape + (3,), dtype=np.uint8)
        # color_map is only indexed, so both sequences and dicts work.
        for class_idx in range(num_classes):
            colored_mask[mask == class_idx] = color_map[class_idx]
        return colored_mask

    @staticmethod
    def _colored_output_path(output_path):
        """Return ``<stem>_colored.png`` for ``output_path``.

        Built from the path stem so that it can never equal ``output_path``
        itself, whatever the mask's extension is.
        """
        stem, _ = os.path.splitext(output_path)
        return f"{stem}_colored.png"

    def predict_folder(self, input_dir, output_dir, color_map=None):
        """Predict every image in a folder.

        Files in ``input_dir`` ending in '.tif', '.png' or '.jpg' are
        processed in sorted order; each mask is written to ``output_dir``
        under the input's file name.

        Args:
            input_dir (str): Directory containing the input images.
            output_dir (str): Directory for the output masks; created if
                missing.
            color_map: Forwarded to ``predict``.
        """
        os.makedirs(output_dir, exist_ok=True)

        # Sorted so the processing order does not depend on the file system.
        image_files = sorted(
            f for f in os.listdir(input_dir) if f.endswith(('.tif', '.png', '.jpg'))
        )

        for image_file in tqdm(image_files, desc="预测图像"):
            image_path = os.path.join(input_dir, image_file)
            output_path = os.path.join(output_dir, image_file)
            self.predict(image_path, output_path, color_map)

    def _get_weight_kernel(self):
        """Build the 2-D Gaussian weight kernel used to blend overlapping tiles.

        The kernel is ``tile_size`` x ``tile_size``, centred on index
        ``tile_size / 2`` with ``sigma = tile_size / 4``.

        Returns:
            numpy.ndarray: ``(tile_size, tile_size)`` array of weights.
        """
        y, x = np.mgrid[0:self.tile_size, 0:self.tile_size]
        x = x.astype(np.float32) - self.tile_size / 2
        y = y.astype(np.float32) - self.tile_size / 2
        sigma = self.tile_size / 4
        weight = np.exp(-(x**2 + y**2) / (2 * sigma**2))
        return weight