210 lines
7.6 KiB
Python
210 lines
7.6 KiB
Python
import os
|
||
import numpy as np
|
||
import torch
|
||
import torch.nn.functional as F
|
||
from tqdm import tqdm
|
||
import rasterio
|
||
from rasterio.windows import Window
|
||
|
||
from .models import get_model
|
||
from .data import TileDataset, get_predict_transforms
|
||
|
||
class UAVSegPredictor:
    """Semantic-segmentation predictor for large remote-sensing images.

    The input raster is cut into square tiles by ``TileDataset``; every tile is
    run through the model, its per-class softmax probabilities are blended into
    a full-size probability map (optionally Gaussian-weighted so tile centres
    dominate in overlap regions), and the final mask is the per-pixel argmax.
    """

    # File extensions (compared case-insensitively) picked up by predict_folder.
    _IMAGE_EXTENSIONS = ('.tif', '.tiff', '.png', '.jpg', '.jpeg')

    def __init__(self,
                 model_path,
                 model_type='deeplabv3plus',
                 num_classes=6,
                 tile_size=512,
                 stride=None,
                 device=None,
                 overlap_weights=True):
        """
        Args:
            model_path (str): path to the model weights file.
            model_type (str): model type, 'deeplabv3plus' or 'unetpp'
                (forwarded to ``get_model``).
            num_classes (int): number of segmentation classes.
            tile_size (int): side length of the square tiles, in pixels.
            stride (int): sliding-window step; defaults to ``tile_size // 2``
                when None.
            device (str): torch device string; when None, CUDA is used if
                available, otherwise CPU.
            overlap_weights (bool): blend overlapping tiles with a Gaussian
                weight kernel instead of a plain average.

        Raises:
            ValueError: if ``tile_size`` or the effective ``stride`` is not
                positive.
        """
        self.model_path = model_path
        self.model_type = model_type
        self.num_classes = num_classes
        self.tile_size = tile_size
        self.stride = stride if stride is not None else tile_size // 2
        self.overlap_weights = overlap_weights

        # A non-positive step would never advance the sliding window
        # (e.g. tile_size=1 gives a default stride of 0).
        if self.tile_size <= 0 or self.stride <= 0:
            raise ValueError(
                f"tile_size and stride must be positive, "
                f"got tile_size={self.tile_size}, stride={self.stride}"
            )

        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        print(f"使用设备: {self.device}")

        self._load_model()

        # Per-tile preprocessing handed to TileDataset.
        self.transform = get_predict_transforms()

    def _load_model(self):
        """Build the network and load its weights from ``self.model_path``.

        The checkpoint may be either a dict containing ``'model_state_dict'``
        (and optionally ``'model_type'``), or a bare state_dict.
        """
        self.model = get_model(self.model_type, self.num_classes)

        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(self.model_path, map_location=self.device)

        # Warn (but continue) when the checkpoint records a different model type.
        if 'model_type' in checkpoint and checkpoint['model_type'] != self.model_type:
            print(f"警告: 加载的模型类型 ({checkpoint['model_type']}) 与当前模型类型 ({self.model_type}) 不匹配")

        if 'model_state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            self.model.load_state_dict(checkpoint)

        self.model.to(self.device)
        self.model.eval()

        print(f"模型已从 {self.model_path} 加载")

    def predict(self, image_path, output_path=None, color_map=None):
        """Predict a segmentation mask for a single image.

        Args:
            image_path (str): path to the input image (opened with rasterio).
            output_path (str): where to write the class-index mask; nothing is
                saved when None. The mask is written with the GTiff driver
                regardless of the file extension.
            color_map (sequence or mapping): optional class-index -> RGB lookup,
                indexed as ``color_map[i]`` for ``i`` in ``range(num_classes)``.
                When given together with ``output_path``, a colourised PNG is
                also written as ``<output path without extension>_colored.png``.

        Returns:
            numpy.ndarray: (height, width) uint8 array of predicted class indices.
        """
        dataset = TileDataset(
            image_path=image_path,
            tile_size=self.tile_size,
            stride=self.stride,
            transform=self.transform,
        )

        with rasterio.open(image_path) as src:
            height, width = src.height, src.width
            profile = src.profile.copy()

        # Running weighted sums of class probabilities and of the weights.
        result = np.zeros((self.num_classes, height, width), dtype=np.float32)
        weights = np.zeros((height, width), dtype=np.float32)

        # The Gaussian kernel favours tile centres in overlap regions; a flat
        # kernel of ones reproduces a plain average.
        if self.overlap_weights:
            weight_kernel = self._get_weight_kernel()
        else:
            weight_kernel = np.ones((self.tile_size, self.tile_size), dtype=np.float32)

        with torch.no_grad():
            for image, tile_info in tqdm(dataset, desc="预测分块"):
                # Top-left corner of the tile in image pixel coordinates.
                x, y = int(tile_info['x']), int(tile_info['y'])

                # unsqueeze(0) adds a batch dimension. softmax over dim=1
                # assumes the model returns (1, num_classes, H, W) logits --
                # TODO confirm for every model returned by get_model.
                logits = self.model(image.unsqueeze(0).to(self.device))
                probs = F.softmax(logits, dim=1).squeeze(0).cpu().numpy()

                self._accumulate(result, weights, probs, x, y, weight_kernel)

        # Weighted average, in place; pixels never covered by a tile stay 0.
        np.divide(result, weights, out=result, where=weights > 0)

        mask = np.argmax(result, axis=0).astype(np.uint8)

        if output_path is not None:
            self._save_outputs(mask, output_path, profile, color_map)

        return mask

    @staticmethod
    def _accumulate(result, weights, probs, x, y, kernel):
        """Add one tile's weighted probabilities into the running sums, in place.

        The contribution is clipped to the image bounds, so tiles that hang
        over the right/bottom edge (padded or undersized border tiles) are
        handled instead of failing on a shape mismatch.

        Args:
            result (numpy.ndarray): (C, H, W) running weighted probability sums.
            weights (numpy.ndarray): (H, W) running weight sums.
            probs (numpy.ndarray): (C, h, w) probabilities for this tile.
            x (int): column of the tile's top-left corner.
            y (int): row of the tile's top-left corner.
            kernel (numpy.ndarray): 2-D per-pixel weight for the tile.
        """
        height, width = weights.shape
        h = min(probs.shape[1], kernel.shape[0], height - y)
        w = min(probs.shape[2], kernel.shape[1], width - x)
        if h <= 0 or w <= 0:
            return
        k = kernel[:h, :w]
        result[:, y:y + h, x:x + w] += probs[:, :h, :w] * k
        weights[y:y + h, x:x + w] += k

    def _save_outputs(self, mask, output_path, profile, color_map):
        """Write the mask as a single-band uint8 GTiff, plus an optional colour PNG."""
        height, width = mask.shape
        profile = dict(profile)
        profile.update({
            'driver': 'GTiff',
            'height': height,
            'width': width,
            'count': 1,
            'dtype': 'uint8',
            'nodata': None,
        })
        # A photometric interpretation inherited from a multi-band source
        # (e.g. RGB / YCbCr) is not valid for a single-band mask.
        profile.pop('photometric', None)

        # The colour preview is optional: it is only produced when a palette
        # is supplied.
        if color_map is not None:
            from PIL import Image

            colored_path = self._colored_output_path(output_path)
            colored_mask = self._colorize(mask, color_map, self.num_classes)
            Image.fromarray(colored_mask).save(colored_path)
            print(f"彩色预测图已保存到 {colored_path}")

        with rasterio.open(output_path, 'w', **profile) as dst:
            dst.write(mask, 1)

        print(f"预测结果已保存到 {output_path}")

    @staticmethod
    def _colorize(mask, color_map, num_classes):
        """Map a (H, W) class-index mask to a (H, W, 3) uint8 RGB image."""
        colored = np.zeros(mask.shape + (3,), dtype=np.uint8)
        for cls in range(num_classes):
            colored[mask == cls] = color_map[cls]
        return colored

    @staticmethod
    def _colored_output_path(output_path):
        """Return the colour-preview path: output path without extension + '_colored.png'.

        Only the final extension is replaced, so the preview never collides
        with the mask file (even for .tif/.jpg outputs) and '.png' appearing
        in a directory name is left untouched.
        """
        root, _ = os.path.splitext(output_path)
        return root + '_colored.png'

    def predict_folder(self, input_dir, output_dir, color_map=None):
        """Predict every image in a directory.

        Each mask is written to ``output_dir`` under the input file's name.

        Args:
            input_dir (str): directory containing input images; files ending in
                .tif/.tiff/.png/.jpg/.jpeg (any case) are processed in sorted
                order.
            output_dir (str): directory for the output masks (created if needed).
            color_map (sequence or mapping): optional palette forwarded to
                :meth:`predict`.

        Raises:
            ValueError: if ``output_dir`` resolves to ``input_dir``, because
                the outputs would overwrite the source images.
        """
        if os.path.realpath(input_dir) == os.path.realpath(output_dir):
            raise ValueError(
                "output_dir must differ from input_dir: masks reuse the input "
                "file names and would overwrite the source images"
            )

        os.makedirs(output_dir, exist_ok=True)

        image_files = sorted(
            name for name in os.listdir(input_dir)
            if name.lower().endswith(UAVSegPredictor._IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(input_dir, name))
        )

        for image_file in tqdm(image_files, desc="预测图像"):
            self.predict(
                os.path.join(input_dir, image_file),
                os.path.join(output_dir, image_file),
                color_map,
            )

    def _get_weight_kernel(self):
        """Return a (tile_size, tile_size) float32 Gaussian weight kernel.

        The kernel is centred exactly on the tile centre (``(tile_size - 1) / 2``),
        with sigma = tile_size / 4 and a peak value of 1 at the centre.
        """
        coords = np.arange(self.tile_size, dtype=np.float32) - (self.tile_size - 1) / 2.0
        sigma = self.tile_size / 4
        profile_1d = np.exp(-(coords ** 2) / (2 * sigma ** 2))
        # A separable product of two 1-D Gaussians equals the 2-D Gaussian.
        return np.outer(profile_1d, profile_1d).astype(np.float32)