210 lines
7.6 KiB
Python
Raw Permalink Normal View History

2025-07-10 09:41:26 +08:00
import os
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
import rasterio
from rasterio.windows import Window
from .models import get_model
from .data import TileDataset, get_predict_transforms
class UAVSegPredictor:
    """Sliding-window semantic segmentation predictor for large remote-sensing images.

    The image is split into (possibly overlapping) tiles by ``TileDataset``.
    Each tile is run through the model, and its per-class softmax
    probabilities are accumulated into a full-resolution buffer. Accumulation
    is optionally weighted by a Gaussian kernel so tile centres dominate tile
    borders. The per-pixel argmax of the fused probabilities is the output mask.
    """

    def __init__(self,
                 model_path,
                 model_type='deeplabv3plus',
                 num_classes=6,
                 tile_size=512,
                 stride=None,
                 device=None,
                 overlap_weights=True):
        """
        Args:
            model_path (str): Path to the model weight / checkpoint file.
            model_type (str): Model type forwarded to ``get_model``, e.g.
                'deeplabv3plus' or 'unetpp'.
            num_classes (int): Number of segmentation classes.
            tile_size (int): Side length in pixels of each square tile.
            stride (int): Sliding-window step. If None, half of ``tile_size``.
            device (str): Torch device string. If None, CUDA is used when
                available, otherwise CPU.
            overlap_weights (bool): If True, fuse overlapping tiles with a
                Gaussian weight kernel; otherwise average them uniformly.
        """
        self.model_path = model_path
        self.model_type = model_type
        self.num_classes = num_classes
        self.tile_size = tile_size
        self.stride = stride if stride is not None else tile_size // 2
        self.overlap_weights = overlap_weights

        # Choose the device.
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)
        print(f"使用设备: {self.device}")

        # Load the model.
        self._load_model()

        # Preprocessing applied to every tile.
        self.transform = get_predict_transforms()

    def _load_model(self):
        """Build the network and load its weights from ``self.model_path``.

        Accepts either a training checkpoint dict that contains a
        ``'model_state_dict'`` entry (optionally also ``'model_type'``), or a
        bare state dict.
        """
        self.model = get_model(self.model_type, self.num_classes)

        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code - only load checkpoints from trusted sources.
        checkpoint = torch.load(self.model_path, map_location=self.device)

        # Warn (but continue) if the checkpoint was saved for another architecture.
        if 'model_type' in checkpoint and checkpoint['model_type'] != self.model_type:
            print(f"警告: 加载的模型类型 ({checkpoint['model_type']}) 与当前模型类型 ({self.model_type}) 不匹配")

        if 'model_state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            self.model.load_state_dict(checkpoint)

        self.model.to(self.device)
        self.model.eval()
        print(f"模型已从 {self.model_path} 加载")

    def predict(self, image_path, output_path=None, color_map=None):
        """Predict a segmentation mask for a single image.

        Args:
            image_path (str): Path to the input image (opened with rasterio).
            output_path (str): Where to write the class-index mask, as a
                single-band uint8 GTiff. If None, nothing is saved.
            color_map (sequence or dict): Optional per-class colours, where
                ``color_map[i]`` is the RGB colour of class ``i``. When given
                together with ``output_path``, a colourised PNG is also
                written as ``<output path without extension>_colored.png``.

        Returns:
            numpy.ndarray: (height, width) uint8 array of class indices.
        """
        dataset = TileDataset(
            image_path=image_path,
            tile_size=self.tile_size,
            stride=self.stride,
            transform=self.transform
        )

        with rasterio.open(image_path) as src:
            height, width = src.height, src.width
            profile = src.profile.copy()

        # Accumulated (weighted) class probabilities and the matching weight sums.
        result = np.zeros((self.num_classes, height, width), dtype=np.float32)
        weights = np.zeros((height, width), dtype=np.float32)

        if self.overlap_weights:
            weight_kernel = self._get_weight_kernel()
        else:
            # A ones kernel is equivalent to plain summation + counting.
            weight_kernel = np.ones((self.tile_size, self.tile_size), dtype=np.float32)

        with torch.no_grad():
            for image, tile_info in tqdm(dataset, desc="预测分块"):
                # int() also accepts 0-d tensors / numpy scalars.
                x, y = int(tile_info['x']), int(tile_info['y'])
                image = image.unsqueeze(0).to(self.device)
                pred = F.softmax(self.model(image), dim=1)
                pred = pred.squeeze(0).cpu().numpy()
                self._accumulate_tile(result, weights, pred, weight_kernel, x, y)

        # Normalise in place. Pixels never covered by a tile keep weight 0 and
        # were never accumulated into, so they stay 0.
        np.divide(result, weights, out=result, where=weights > 0)

        mask = np.argmax(result, axis=0).astype(np.uint8)

        if output_path is not None:
            self._save_prediction(mask, output_path, profile, color_map)

        return mask

    @staticmethod
    def _accumulate_tile(result, weights, pred, kernel, x, y):
        """Add one tile's class probabilities into the full-image buffers in place.

        ``result`` is (C, H, W), ``weights`` is (H, W) and ``pred`` is
        (C, h, w) for the tile whose top-left corner is at column ``x``,
        row ``y``. The update is clipped to the image bounds and to the
        prediction/kernel extent. Tiles that overhang the border, or are
        smaller than the kernel, therefore do not cause shape mismatches.
        """
        _, height, width = result.shape
        h = min(pred.shape[1], kernel.shape[0], height - y)
        w = min(pred.shape[2], kernel.shape[1], width - x)
        if h <= 0 or w <= 0:
            return
        k = kernel[:h, :w]
        result[:, y:y + h, x:x + w] += pred[:, :h, :w] * k
        weights[y:y + h, x:x + w] += k

    def _save_prediction(self, mask, output_path, profile, color_map):
        """Write ``mask`` to ``output_path`` as a single-band uint8 GTiff.

        When ``color_map`` is not None, a colourised PNG is also written next
        to it first.
        """
        height, width = mask.shape
        profile.update({
            'driver': 'GTiff',
            'height': height,
            'width': width,
            'count': 1,
            'dtype': 'uint8',
            'nodata': None
        })

        if color_map is not None:
            from PIL import Image
            colored_mask = np.zeros((height, width, 3), dtype=np.uint8)
            for class_idx in range(self.num_classes):
                colored_mask[mask == class_idx] = color_map[class_idx]
            # splitext keeps the colour image distinct from the mask for any
            # extension, and never touches '.png' inside directory names.
            colored_path = os.path.splitext(output_path)[0] + '_colored.png'
            Image.fromarray(colored_mask).save(colored_path)
            print(f"彩色预测图已保存到 {colored_path}")

        with rasterio.open(output_path, 'w', **profile) as dst:
            dst.write(mask, 1)
        print(f"预测结果已保存到 {output_path}")

    def predict_folder(self, input_dir, output_dir, color_map=None,
                       extensions=('.tif', '.tiff', '.png', '.jpg', '.jpeg')):
        """Run ``predict`` on every image file directly inside ``input_dir``.

        Args:
            input_dir (str): Directory of input images (not searched recursively).
            output_dir (str): Directory for output masks; created if missing.
                Each mask is written under its input file name.
            color_map (sequence or dict): Optional per-class colours forwarded
                to ``predict``.
            extensions (tuple of str): File extensions to process, matched
                case-insensitively.
        """
        os.makedirs(output_dir, exist_ok=True)

        suffixes = tuple(ext.lower() for ext in extensions)
        image_files = sorted(
            name for name in os.listdir(input_dir)
            if name.lower().endswith(suffixes)
            and os.path.isfile(os.path.join(input_dir, name))
        )

        for image_file in tqdm(image_files, desc="预测图像"):
            image_path = os.path.join(input_dir, image_file)
            output_path = os.path.join(output_dir, image_file)
            self.predict(image_path, output_path, color_map)

    def _get_weight_kernel(self):
        """Return a (tile_size, tile_size) float32 2-D Gaussian weight kernel.

        The kernel is centred on the pixel centre of the tile,
        ``(tile_size - 1) / 2``, so it is exactly symmetric. It uses
        sigma = tile_size / 4, peaks near 1 in the middle, and decays to
        roughly exp(-4) at the corners (never zero).
        """
        center = (self.tile_size - 1) / 2
        coords = np.arange(self.tile_size, dtype=np.float32) - center
        sigma = self.tile_size / 4
        profile_1d = np.exp(-(coords ** 2) / (2 * sigma ** 2))
        # The 2-D Gaussian is separable: exp(-(x^2+y^2)/2s^2) = g(y) * g(x).
        return np.outer(profile_1d, profile_1d).astype(np.float32)