import math
import os

import albumentations as A
import numpy as np
import rasterio
import torch
from albumentations.pytorch import ToTensorV2
from PIL import Image
from rasterio.windows import Window
from torch.utils.data import Dataset, DataLoader


def _validate_tiling(tile_size, stride):
    """Raise ValueError unless both tile_size and stride are positive."""
    if tile_size <= 0:
        raise ValueError(f"tile_size must be positive, got {tile_size}")
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")


def _tile_origins(length: int, tile_size: int, stride: int) -> list:
    """Return the start offsets of sliding-window tiles along one axis.

    Tiles are placed every ``stride`` pixels; the last one is shifted back so
    it ends exactly at ``length``. When the axis is not longer than one tile,
    a single tile starting at 0 is returned (never a negative offset and
    never an empty list).
    """
    if length <= tile_size:
        return [0]
    last_origin = length - tile_size
    count = math.ceil(last_origin / stride) + 1
    return [min(i * stride, last_origin) for i in range(count)]


def _read_tile(src, x, y, width, height, indexes=None):
    """Read a ``height`` x ``width`` window from an open rasterio dataset.

    Only the part of the window that lies inside the dataset is read; any
    remainder on the bottom/right is zero-padded so the result always has the
    requested spatial size.

    Args:
        src: Open rasterio dataset.
        x (int): Column offset of the window (expected to be >= 0).
        y (int): Row offset of the window (expected to be >= 0).
        width (int): Requested window width in pixels.
        height (int): Requested window height in pixels.
        indexes: Band index or list of indexes forwarded to ``src.read``;
            ``None`` reads all bands.

    Returns:
        numpy.ndarray: ``(bands, height, width)`` when ``indexes`` is None or
        a list, ``(height, width)`` when ``indexes`` is a single int.
    """
    read_w = min(width, src.width - x)
    read_h = min(height, src.height - y)
    data = src.read(indexes, window=Window(x, y, read_w, read_h))

    pad_h = height - read_h
    pad_w = width - read_w
    if pad_h or pad_w:
        # Pad only the two trailing (spatial) axes; leave the band axis alone.
        pad_spec = [(0, 0)] * (data.ndim - 2) + [(0, pad_h), (0, pad_w)]
        data = np.pad(data, pad_spec, mode="constant", constant_values=0)
    return data


def _chw_to_hwc_rgb(image):
    """Convert a ``(C, H, W)`` array to ``(H, W, C)`` with three channels.

    Single-band input is replicated to three channels; input with more than
    three bands keeps only the first three (assumed to be RGB -- confirm for
    your sensor). Two- and three-band inputs are passed through unchanged.
    """
    image = np.transpose(image, (1, 2, 0))
    channels = image.shape[2]
    if channels == 1:
        image = np.repeat(image, 3, axis=2)
    elif channels > 3:
        image = image[:, :, :3]
    return image


class RemoteSensingDataset(Dataset):
    """Tiled remote-sensing segmentation dataset (image tile, mask tile)."""

    def __init__(self, image_dir, mask_dir, transform=None, tile_size=512, stride=None):
        """
        Args:
            image_dir (str): Directory containing the images.
            mask_dir (str): Directory containing the masks.
            transform: Augmentation / preprocessing callable, invoked as
                ``transform(image=..., mask=...)`` and expected to return a
                mapping with ``'image'`` and ``'mask'`` entries.
            tile_size (int): Edge length of the square tiles, in pixels.
            stride (int): Sliding-window step; defaults to ``tile_size`` when
                None.

        Raises:
            ValueError: If ``tile_size`` or the effective stride is not
                positive.
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.tile_size = tile_size
        self.stride = stride if stride is not None else tile_size
        _validate_tiling(self.tile_size, self.stride)

        # os.listdir order is arbitrary; sort so that dataset indices are
        # reproducible across runs and machines.
        self.images = sorted(
            f for f in os.listdir(image_dir)
            if f.endswith(('.tif', '.png', '.jpg'))
        )
        self.tiles_info = self._prepare_tiles()

    def _prepare_tiles(self):
        """Build the list of tile descriptors for every image with a mask.

        Returns:
            list[dict]: One dict per tile with keys ``img_path``,
            ``mask_path``, ``x``, ``y``, ``width`` and ``height``.
        """
        tiles_info = []

        for img_name in self.images:
            img_path = os.path.join(self.image_dir, img_name)
            mask_path = os.path.join(self.mask_dir, img_name)

            # Fall back to the same base name with another extension when no
            # mask with the identical file name exists.
            if not os.path.exists(mask_path):
                base_name = os.path.splitext(img_name)[0]
                for ext in ['.tif', '.png', '.jpg']:
                    alt_mask_path = os.path.join(self.mask_dir, base_name + ext)
                    if os.path.exists(alt_mask_path):
                        mask_path = alt_mask_path
                        break
                else:
                    print(f"警告: 找不到图像 {img_name} 对应的掩码文件")
                    continue

            # Only the raster dimensions are needed here.
            with rasterio.open(img_path) as src:
                height, width = src.height, src.width

            # Images smaller than one tile yield a single tile at (0, 0);
            # the missing area is zero-padded at read time.
            for y in _tile_origins(height, self.tile_size, self.stride):
                for x in _tile_origins(width, self.tile_size, self.stride):
                    tiles_info.append({
                        'img_path': img_path,
                        'mask_path': mask_path,
                        'x': x,
                        'y': y,
                        'width': self.tile_size,
                        'height': self.tile_size,
                    })

        return tiles_info

    def __len__(self):
        """Return the total number of tiles across all images."""
        return len(self.tiles_info)

    def __getitem__(self, idx):
        """Load one image tile and its mask tile.

        Returns:
            tuple: ``(image, mask)``. ``mask`` is always an int64 torch
            tensor. ``image`` is whatever ``transform`` produced, or an
            ``(H, W, C)`` NumPy array when no transform is set.
        """
        tile_info = self.tiles_info[idx]
        x, y = tile_info['x'], tile_info['y']
        tile_w, tile_h = tile_info['width'], tile_info['height']

        with rasterio.open(tile_info['img_path']) as src:
            image = _read_tile(src, x, y, tile_w, tile_h)
        image = _chw_to_hwc_rgb(image)

        # Band 1 is read as the label band (masks are assumed single-band).
        # NOTE(review): padded mask pixels are 0 -- confirm 0 is background.
        with rasterio.open(tile_info['mask_path']) as src:
            mask = _read_tile(src, x, y, tile_w, tile_h, indexes=1)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Without a tensor-producing transform the mask is still a NumPy
        # array, which has no .long(); convert it explicitly.
        if not torch.is_tensor(mask):
            mask = torch.from_numpy(np.asarray(mask).astype(np.int64))

        return image, mask.long()


class TileDataset(Dataset):
    """Tiles of a single large image, for sliding-window inference."""

    def __init__(self, image_path, tile_size=512, stride=None, transform=None):
        """
        Args:
            image_path (str): Path of the image file.
            tile_size (int): Edge length of the square tiles, in pixels.
            stride (int): Sliding-window step; defaults to ``tile_size`` when
                None.
            transform: Preprocessing callable, invoked as
                ``transform(image=...)`` and expected to return a mapping
                with an ``'image'`` entry.

        Raises:
            ValueError: If ``tile_size`` or the effective stride is not
                positive.
        """
        self.image_path = image_path
        self.tile_size = tile_size
        self.stride = stride if stride is not None else tile_size
        self.transform = transform
        _validate_tiling(self.tile_size, self.stride)

        # Record the raster size and a copy of its profile.
        with rasterio.open(image_path) as src:
            self.height, self.width = src.height, src.width
            self.profile = src.profile.copy()

        row_origins = _tile_origins(self.height, self.tile_size, self.stride)
        col_origins = _tile_origins(self.width, self.tile_size, self.stride)
        self.num_tiles_h = len(row_origins)
        self.num_tiles_w = len(col_origins)

        # Row-major tile order. Tiles of an image smaller than tile_size are
        # zero-padded at read time, so consumers stitching predictions should
        # crop to (self.height, self.width).
        self.tiles_info = [
            {'x': x, 'y': y, 'width': self.tile_size, 'height': self.tile_size}
            for y in row_origins
            for x in col_origins
        ]

    def __len__(self):
        """Return the number of tiles covering the image."""
        return len(self.tiles_info)

    def __getitem__(self, idx):
        """Load one tile.

        Returns:
            tuple: ``(image, tile_info)`` where ``tile_info`` is the dict with
            ``x``, ``y``, ``width`` and ``height`` describing the tile.
        """
        tile_info = self.tiles_info[idx]

        with rasterio.open(self.image_path) as src:
            image = _read_tile(
                src,
                tile_info['x'],
                tile_info['y'],
                tile_info['width'],
                tile_info['height'],
            )
        image = _chw_to_hwc_rgb(image)

        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']

        return image, tile_info


def get_dataloader(image_dir, mask_dir, batch_size=4, tile_size=512, stride=None,
                   transform=None, shuffle=True, num_workers=4):
    """Build a DataLoader over a RemoteSensingDataset.

    Args:
        image_dir (str): Directory containing the images.
        mask_dir (str): Directory containing the masks.
        batch_size (int): Number of tiles per batch.
        tile_size (int): Edge length of the square tiles, in pixels.
        stride (int): Sliding-window step; defaults to ``tile_size`` when None.
        transform: Augmentation / preprocessing passed to the dataset.
        shuffle (bool): Whether to shuffle the tiles.
        num_workers (int): Number of DataLoader worker processes.

    Returns:
        DataLoader: PyTorch data loader with ``pin_memory=True``.
    """
    dataset = RemoteSensingDataset(
        image_dir=image_dir,
        mask_dir=mask_dir,
        transform=transform,
        tile_size=tile_size,
        stride=stride,
    )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
    )