import os

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
import rasterio
from rasterio.windows import Window

from .models import get_model
from .data import TileDataset, get_predict_transforms


class UAVSegPredictor:
    """Semantic-segmentation predictor for large remote-sensing images.

    The image is processed tile by tile with a sliding window; per-tile class
    probabilities are accumulated into a full-size buffer (optionally weighted
    by a Gaussian kernel so tile centres dominate in overlapping regions) and
    the final mask is the per-pixel argmax.
    """

    def __init__(self, model_path, model_type='deeplabv3plus', num_classes=6,
                 tile_size=512, stride=None, device=None, overlap_weights=True):
        """Create the predictor and load the model weights.

        Args:
            model_path (str): Path to the model weight file.
            model_type (str): Model type, 'deeplabv3plus' or 'unetpp'.
            num_classes (int): Number of classes.
            tile_size (int): Side length of a tile in pixels.
            stride (int): Sliding-window stride. If None, half of
                ``tile_size`` is used (at least 1).
            device (str): Device to run on. If None, CUDA is used when
                available, otherwise the CPU.
            overlap_weights (bool): Whether to blend overlapping tiles with a
                Gaussian weight kernel instead of a plain average.
        """
        self.model_path = model_path
        self.model_type = model_type
        self.num_classes = num_classes
        self.tile_size = tile_size
        # Clamp the default so a tiny tile_size can never yield a zero stride.
        self.stride = stride if stride is not None else max(1, tile_size // 2)
        self.overlap_weights = overlap_weights

        # Select the device.
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)
        print(f"使用设备: {self.device}")

        # Load the model.
        self._load_model()

        # Pre-processing applied to every tile.
        self.transform = get_predict_transforms()

    def _load_model(self):
        """Build the network, load its weights and switch to eval mode.

        The checkpoint may either be a dict holding the weights under
        ``'model_state_dict'`` or a bare state dict.
        """
        self.model = get_model(self.model_type, self.num_classes)

        # SECURITY NOTE: torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(self.model_path, map_location=self.device)

        # Warn (but do not fail) when the checkpoint records a different model type.
        if 'model_type' in checkpoint and checkpoint['model_type'] != self.model_type:
            print(f"警告: 加载的模型类型 ({checkpoint['model_type']}) 与当前模型类型 ({self.model_type}) 不匹配")

        if 'model_state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            self.model.load_state_dict(checkpoint)

        self.model.to(self.device)
        self.model.eval()

        print(f"模型已从 {self.model_path} 加载")

    def predict(self, image_path, output_path=None, color_map=None):
        """Predict the segmentation mask of a single image.

        Args:
            image_path (str): Path of the input image.
            output_path (str): Path of the output mask. If None nothing is
                written to disk.
            color_map: Optional lookup indexed by class id (``color_map[i]``)
                giving the colour of class ``i``. When provided together with
                ``output_path``, a colourised PNG is saved next to the mask.
                When None, only the raw mask is saved.

        Returns:
            numpy.ndarray: Predicted mask of class indices (uint8), with the
            height and width of the input image.
        """
        # Tiled view of the input image.
        dataset = TileDataset(
            image_path=image_path,
            tile_size=self.tile_size,
            stride=self.stride,
            transform=self.transform
        )

        # Image size and raster profile of the source.
        with rasterio.open(image_path) as src:
            height, width = src.height, src.width
            profile = src.profile.copy()

        # Accumulated (weighted) probabilities and accumulated weights.
        result = np.zeros((self.num_classes, height, width), dtype=np.float32)
        weights = np.zeros((height, width), dtype=np.float32)

        # A kernel of ones makes the weighted path equal to a plain average.
        if self.overlap_weights:
            weight_kernel = self._get_weight_kernel()
        else:
            weight_kernel = np.ones((self.tile_size, self.tile_size), dtype=np.float32)

        with torch.no_grad():
            for image, tile_info in tqdm(dataset, desc="预测分块"):
                # Top-left corner of the tile in image coordinates.
                x, y = int(tile_info['x']), int(tile_info['y'])

                # Crop to the part of the tile that lies inside the image so
                # that border tiles do not cause a shape mismatch.
                tile_h = min(self.tile_size, height - y)
                tile_w = min(self.tile_size, width - x)
                if tile_h <= 0 or tile_w <= 0:
                    continue

                logits = self.model(image.unsqueeze(0).to(self.device))
                probs = F.softmax(logits, dim=1).squeeze(0).cpu().numpy()

                probs = probs[:, :tile_h, :tile_w]
                kernel = weight_kernel[:tile_h, :tile_w]

                # Broadcasting over the class axis replaces the per-class loop.
                result[:, y:y + tile_h, x:x + tile_w] += probs * kernel
                weights[y:y + tile_h, x:x + tile_w] += kernel

        # Normalise in place; pixels never covered by a tile stay at zero.
        np.divide(result, weights, out=result, where=weights > 0)

        # Final mask of class indices.
        mask = np.argmax(result, axis=0).astype(np.uint8)

        if output_path is not None:
            self._save_outputs(mask, profile, output_path, color_map)

        return mask

    def _save_outputs(self, mask, profile, output_path, color_map):
        """Write the raw mask (and, if a colour map is given, a colour PNG)."""
        height, width = mask.shape

        # NOTE(review): other keys inherited from the source profile are
        # passed through unchanged — confirm they are valid for GTiff when
        # the input is not a GeoTIFF.
        profile.update({
            'driver': 'GTiff',
            'height': height,
            'width': width,
            'count': 1,
            'dtype': 'uint8',
            'nodata': None
        })

        if color_map is not None:
            # Look-up table: row i holds the colour of class i.
            palette = np.asarray(
                [color_map[i] for i in range(self.num_classes)], dtype=np.uint8
            )
            colored_mask = palette[mask]

            from PIL import Image

            # Derive the name from the stem so the colour image can never
            # collide with (and be overwritten by) the raw mask file.
            stem, _ = os.path.splitext(output_path)
            colored_path = f"{stem}_colored.png"
            Image.fromarray(colored_mask).save(colored_path)
            print(f"彩色预测图已保存到 {colored_path}")

        # Raw single-band mask.
        with rasterio.open(output_path, 'w', **profile) as dst:
            dst.write(mask, 1)

        print(f"预测结果已保存到 {output_path}")

    def predict_folder(self, input_dir, output_dir, color_map=None):
        """Predict every image found directly inside a folder.

        Files whose name ends in ``.tif``, ``.png`` or ``.jpg`` (matched
        case-insensitively) are processed in sorted order; each output keeps
        the file name of its input.

        Args:
            input_dir (str): Directory containing the input images.
            output_dir (str): Directory receiving the output masks; created
                if it does not exist.
            color_map: Optional class-id-to-colour lookup forwarded to
                :meth:`predict`.
        """
        os.makedirs(output_dir, exist_ok=True)

        image_files = sorted(
            f for f in os.listdir(input_dir)
            if f.lower().endswith(('.tif', '.png', '.jpg'))
        )

        for image_file in tqdm(image_files, desc="预测图像"):
            image_path = os.path.join(input_dir, image_file)
            output_path = os.path.join(output_dir, image_file)
            self.predict(image_path, output_path, color_map)

    def _get_weight_kernel(self):
        """Return the 2-D Gaussian kernel used to blend overlapping tiles.

        The kernel is ``tile_size`` x ``tile_size`` (float32), centred at
        ``tile_size / 2`` with ``sigma = tile_size / 4``.
        """
        offsets = np.arange(self.tile_size, dtype=np.float32) - self.tile_size / 2
        squared = offsets ** 2
        sigma = self.tile_size / 4
        # Outer sum of squared offsets gives the squared distance to the centre.
        weight = np.exp(-(squared[None, :] + squared[:, None]) / (2 * sigma ** 2))
        return weight