# Listing metadata from the hosting page: 210 lines, 7.6 KiB, Python.
import os

import numpy as np
import rasterio
import torch
import torch.nn.functional as F
from rasterio.windows import Window
from tqdm import tqdm

from .data import TileDataset, get_predict_transforms
from .models import get_model


class UAVSegPredictor:
    """Semantic-segmentation predictor for large remote-sensing images.

    The image is processed tile by tile (sliding window).  Per-tile class
    probabilities are accumulated into a full-size buffer, optionally weighted
    by a Gaussian kernel so that overlapping tiles blend smoothly, and the
    final mask is the per-pixel argmax of the normalised probabilities.
    """

    def __init__(self,
                 model_path,
                 model_type='deeplabv3plus',
                 num_classes=6,
                 tile_size=512,
                 stride=None,
                 device=None,
                 overlap_weights=True):
        """Load the model and prepare the prediction transforms.

        Args:
            model_path (str): Path to the model weights / checkpoint file.
            model_type (str): Model architecture name passed to ``get_model``
                (the original documentation lists 'deeplabv3plus' and
                'unetpp').
            num_classes (int): Number of segmentation classes.
            tile_size (int): Edge length of the square tiles, in pixels.
            stride (int): Sliding-window step.  If None, half of
                ``tile_size`` is used.
            device (str): Torch device string.  If None, CUDA is used when
                available, otherwise the CPU.
            overlap_weights (bool): Whether overlapping tiles are blended
                with a Gaussian weight kernel instead of a plain average.
        """
        self.model_path = model_path
        self.model_type = model_type
        self.num_classes = num_classes
        self.tile_size = tile_size
        self.stride = stride if stride is not None else tile_size // 2
        self.overlap_weights = overlap_weights

        # Select the device.
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        print(f"使用设备: {self.device}")

        # Load the model weights.
        self._load_model()

        # Pre-processing applied to every tile.
        self.transform = get_predict_transforms()

    def _load_model(self):
        """Build the model, load its weights and switch it to eval mode."""
        self.model = get_model(self.model_type, self.num_classes)

        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code -- only load checkpoints from trusted sources.
        checkpoint = torch.load(self.model_path, map_location=self.device)

        # Warn (but do not fail) when the checkpoint was saved for another
        # architecture than the one requested.
        if 'model_type' in checkpoint and checkpoint['model_type'] != self.model_type:
            print(f"警告: 加载的模型类型 ({checkpoint['model_type']}) 与当前模型类型 ({self.model_type}) 不匹配")

        # The file is either a full training checkpoint or a bare state dict.
        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        self.model.load_state_dict(state_dict)

        self.model.to(self.device)
        self.model.eval()

        print(f"模型已从 {self.model_path} 加载")

    def predict(self, image_path, output_path=None, color_map=None):
        """Predict the segmentation mask of a single image.

        Args:
            image_path (str): Path of the input image.
            output_path (str): Where to write the class-index mask (single
                band, uint8, GTiff driver).  If None nothing is written.
            color_map: Mapping or sequence indexed by class id that yields an
                RGB triple.  When given together with ``output_path``, a
                colourised PNG is written next to the mask
                (``<name>_colored.png``).  If None the colour image is
                skipped.

        Returns:
            numpy.ndarray: ``(height, width)`` uint8 array of class indices.
        """
        # Tiled view over the input image.
        # assumes each item is (image_tensor, {'x': col, 'y': row}) with the
        # offsets in pixels -- TODO confirm against TileDataset.
        dataset = TileDataset(
            image_path=image_path,
            tile_size=self.tile_size,
            stride=self.stride,
            transform=self.transform,
        )

        # Image geometry and georeferencing profile.
        with rasterio.open(image_path) as src:
            height, width = src.height, src.width
            profile = src.profile.copy()

        # Accumulators for weighted class probabilities and the weights.
        result = np.zeros((self.num_classes, height, width), dtype=np.float32)
        weights = np.zeros((height, width), dtype=np.float32)

        # None means "uniform weight of 1 per tile".
        weight_kernel = self._get_weight_kernel() if self.overlap_weights else None

        with torch.no_grad():
            for image, tile_info in tqdm(dataset, desc="预测分块"):
                x, y = tile_info['x'], tile_info['y']

                batch = image.unsqueeze(0).to(self.device)
                probs = F.softmax(self.model(batch), dim=1)
                probs = probs.squeeze(0).cpu().numpy()

                self._accumulate_tile(result, weights, probs, x, y, weight_kernel)

        # Normalise in place; pixels no tile covered keep their zeros (and
        # therefore end up as class 0 after the argmax).
        np.divide(result, weights[None], out=result, where=(weights > 0)[None])

        # Final mask of class indices.
        mask = np.argmax(result, axis=0).astype(np.uint8)

        if output_path is not None:
            profile.update({
                'driver': 'GTiff',
                'height': height,
                'width': width,
                'count': 1,
                'dtype': 'uint8',
                'nodata': None,
            })

            if color_map is not None:
                # Imported lazily: PIL is only needed for the colour image.
                from PIL import Image

                colored_mask = self._colorize(mask, color_map)
                colored_path = self._colored_output_path(output_path)
                Image.fromarray(colored_mask).save(colored_path)
                print(f"彩色预测图已保存到 {colored_path}")

            # Raw class-index mask.
            with rasterio.open(output_path, 'w', **profile) as dst:
                dst.write(mask, 1)

            print(f"预测结果已保存到 {output_path}")

        return mask

    @staticmethod
    def _accumulate_tile(result, weights, probs, x, y, kernel=None):
        """Add one tile's probabilities into the accumulators, in place.

        The tile (and the kernel) are clipped to the image bounds, so a tile
        that overhangs the right or bottom edge contributes only its valid
        part instead of raising a shape error.

        Args:
            result (numpy.ndarray): ``(classes, height, width)`` accumulator.
            weights (numpy.ndarray): ``(height, width)`` weight accumulator.
            probs (numpy.ndarray): ``(classes, tile_h, tile_w)`` tile output.
            x, y: Column / row of the tile's top-left corner.
            kernel (numpy.ndarray): ``(tile_h, tile_w)`` blending weights, or
                None for a uniform weight of 1.

        Raises:
            ValueError: If ``probs`` has fewer channels than ``result``.
        """
        n_classes = result.shape[0]
        if probs.shape[0] < n_classes:
            raise ValueError(
                f"model returned {probs.shape[0]} channels, expected at least {n_classes}"
            )

        height, width = weights.shape
        x, y = int(x), int(y)
        tile_h = min(probs.shape[1], height - y)
        tile_w = min(probs.shape[2], width - x)
        if tile_h <= 0 or tile_w <= 0:
            return

        rows = slice(y, y + tile_h)
        cols = slice(x, x + tile_w)
        tile = probs[:n_classes, :tile_h, :tile_w]

        if kernel is None:
            result[:, rows, cols] += tile
            weights[rows, cols] += 1
        else:
            clipped = kernel[:tile_h, :tile_w]
            result[:, rows, cols] += tile * clipped
            weights[rows, cols] += clipped

    def _colorize(self, mask, color_map):
        """Turn a uint8 class-index mask into an ``(H, W, 3)`` uint8 RGB image.

        Classes without an entry (ids >= ``num_classes``) are rendered black.
        """
        # Lookup table over the whole uint8 domain of the mask.
        palette = np.zeros((256, 3), dtype=np.uint8)
        for class_id in range(min(self.num_classes, 256)):
            palette[class_id] = color_map[class_id]
        return palette[mask]

    @staticmethod
    def _colored_output_path(output_path):
        """Return ``<output_path without extension>_colored.png``.

        Built from the file stem so the colour image never collides with the
        mask itself, whatever extension ``output_path`` has.
        """
        stem, _ = os.path.splitext(output_path)
        return f"{stem}_colored.png"

    def predict_folder(self, input_dir, output_dir, color_map=None):
        """Predict every image found directly inside a folder.

        Files ending in .tif, .png or .jpg (case-insensitive) are processed
        in sorted order; each mask is written to ``output_dir`` under the
        input file's name.

        Args:
            input_dir (str): Directory containing the input images.
            output_dir (str): Directory for the output masks (created if
                missing).
            color_map: Forwarded to ``predict``.
        """
        os.makedirs(output_dir, exist_ok=True)

        image_files = sorted(
            name for name in os.listdir(input_dir)
            if name.lower().endswith(('.tif', '.png', '.jpg'))
        )

        for image_file in tqdm(image_files, desc="预测图像"):
            image_path = os.path.join(input_dir, image_file)
            output_path = os.path.join(output_dir, image_file)
            self.predict(image_path, output_path, color_map)

    def _get_weight_kernel(self):
        """Build the 2-D Gaussian kernel used to blend overlapping tiles.

        Returns:
            numpy.ndarray: ``(tile_size, tile_size)`` float32 array with its
            peak (1.0) at index ``tile_size / 2`` and ``sigma = tile_size / 4``.
        """
        offsets = np.arange(self.tile_size).astype(np.float32) - self.tile_size / 2
        squared = offsets ** 2
        sigma = self.tile_size / 4
        # Column term + row term == x**2 + y**2 for every pixel of the tile.
        weight = np.exp(-(squared[None, :] + squared[:, None]) / (2 * sigma ** 2))

        return weight