205 lines
7.4 KiB
Python
Raw Permalink Normal View History

2025-07-10 09:41:26 +08:00
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import rasterio
from rasterio.windows import Window
import albumentations as A
from albumentations.pytorch import ToTensorV2
import math
class RemoteSensingDataset(Dataset):
    """Remote-sensing segmentation dataset that serves fixed-size tiles.

    Every image in ``image_dir`` is cut into ``tile_size`` x ``tile_size``
    windows on a sliding grid with step ``stride``.
    - The last row/column of windows is shifted back so that it ends exactly
      on the image border.
    - An image smaller than one tile yields a single tile anchored at (0, 0).
      Its bottom/right part is zero-padded (for both image and mask).

    Each item is an ``(image, mask)`` pair. The mask is always returned as an
    int64 tensor.
    """

    _EXTENSIONS = ('.tif', '.png', '.jpg')

    def __init__(self, image_dir, mask_dir, transform=None, tile_size=512, stride=None):
        """
        Args:
            image_dir (str): Directory containing the images.
            mask_dir (str): Directory containing the masks. A mask is matched
                to an image by identical file name first, then by the same
                stem with a .tif/.png/.jpg extension.
            transform: Optional albumentations-style callable. It is invoked
                as ``transform(image=..., mask=...)`` and must return a dict
                with ``'image'`` and ``'mask'``.
            tile_size (int): Side length of each square tile, in pixels.
            stride (int): Sliding-window step. Defaults to ``tile_size``,
                i.e. non-overlapping tiles.

        Raises:
            ValueError: If ``tile_size`` or the effective ``stride`` is not positive.
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.tile_size = tile_size
        self.stride = stride if stride is not None else tile_size
        if self.tile_size <= 0 or self.stride <= 0:
            raise ValueError(
                f"tile_size and stride must be positive, "
                f"got tile_size={self.tile_size}, stride={self.stride}"
            )
        # Sorted so tile indices are reproducible; os.listdir order is arbitrary.
        self.images = sorted(
            f for f in os.listdir(image_dir) if f.endswith(self._EXTENSIONS)
        )
        self.tiles_info = self._prepare_tiles()

    @staticmethod
    def _tile_starts(length, tile_size, stride):
        """Return the start offsets of tiles along one axis of size ``length``."""
        if length <= tile_size:
            # A single tile at 0. It is padded later instead of starting at a
            # negative offset, which rasterio cannot read.
            return [0]
        count = math.ceil((length - tile_size) / stride) + 1
        # Clamp the last start so the final tile ends on the border.
        return [min(i * stride, length - tile_size) for i in range(count)]

    def _find_mask(self, img_name):
        """Return the mask path that matches ``img_name``, or None if there is none."""
        mask_path = os.path.join(self.mask_dir, img_name)
        if os.path.exists(mask_path):
            return mask_path
        # Try the same stem with different extensions.
        base_name = os.path.splitext(img_name)[0]
        for ext in self._EXTENSIONS:
            alt_mask_path = os.path.join(self.mask_dir, base_name + ext)
            if os.path.exists(alt_mask_path):
                return alt_mask_path
        return None

    def _prepare_tiles(self):
        """Build the tile index: one dict per tile, row-major within each image."""
        tiles_info = []
        for img_name in self.images:
            img_path = os.path.join(self.image_dir, img_name)
            mask_path = self._find_mask(img_name)
            if mask_path is None:
                print(f"警告: 找不到图像 {img_name} 对应的掩码文件")
                continue
            with rasterio.open(img_path) as src:
                height, width = src.height, src.width
            for y in self._tile_starts(height, self.tile_size, self.stride):
                for x in self._tile_starts(width, self.tile_size, self.stride):
                    tiles_info.append({
                        'img_path': img_path,
                        'mask_path': mask_path,
                        'x': x,
                        'y': y,
                        'width': self.tile_size,
                        'height': self.tile_size,
                    })
        return tiles_info

    @staticmethod
    def _pad_to(data, height, width):
        """Zero-pad the last two axes of ``data`` (bottom/right) up to ``height`` x ``width``."""
        pad_h = max(height - data.shape[-2], 0)
        pad_w = max(width - data.shape[-1], 0)
        if pad_h == 0 and pad_w == 0:
            return data
        pad = [(0, 0)] * (data.ndim - 2) + [(0, pad_h), (0, pad_w)]
        return np.pad(data, pad)

    @classmethod
    def _read_tile(cls, path, tile_info, band=None):
        """Read one tile from ``path``.

        If ``band`` is None, all bands are read into a (C, H, W) array.
        Otherwise only that band is read into an (H, W) array. Any part of the
        tile outside the raster is zero-filled.
        """
        x, y = tile_info['x'], tile_info['y']
        with rasterio.open(path) as src:
            # Clip the window to the raster so the read never goes out of bounds.
            w = min(tile_info['width'], src.width - x)
            h = min(tile_info['height'], src.height - y)
            data = src.read(band, window=Window(x, y, w, h))
        return cls._pad_to(data, tile_info['height'], tile_info['width'])

    @staticmethod
    def _to_rgb(image):
        """Coerce an (H, W, C) array to 3 channels.

        A single band is replicated to 3 channels; extra bands beyond the
        first three are dropped.
        """
        if image.shape[2] == 1:
            return np.repeat(image, 3, axis=2)
        if image.shape[2] > 3:
            return image[:, :, :3]  # keep only the first three (RGB) bands
        return image

    @staticmethod
    def _as_tensors(image, mask):
        """Return ``(image, mask)`` as tensors, with the mask cast to int64.

        Tensor inputs (e.g. produced by ``ToTensorV2``) pass through unchanged,
        apart from the mask cast. A numpy image is assumed to be (H, W, C) and
        is converted to a float32 (C, H, W) tensor. A numpy mask is converted
        to int64, so the method also works without a transform.
        """
        if isinstance(image, np.ndarray):
            image = torch.from_numpy(
                np.ascontiguousarray(image.transpose(2, 0, 1), dtype=np.float32)
            )
        if isinstance(mask, np.ndarray):
            mask = torch.from_numpy(mask.astype(np.int64))
        return image, mask.long()

    def __len__(self):
        return len(self.tiles_info)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` tensors for tile ``idx``."""
        tile_info = self.tiles_info[idx]
        # (C, H, W) -> (H, W, C), then normalize to 3 channels.
        image = self._read_tile(tile_info['img_path'], tile_info)
        image = self._to_rgb(np.transpose(image, (1, 2, 0)))
        # The mask is read from its first band only.
        mask = self._read_tile(tile_info['mask_path'], tile_info, band=1)
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']
        return self._as_tensors(image, mask)
class TileDataset(Dataset):
    """Tile dataset over a single large image, for sliding-window prediction.

    The image is cut into ``tile_size`` x ``tile_size`` windows with step
    ``stride``; the last row/column of windows is shifted back to end on the
    border.
    - If the image is smaller than one tile along an axis, a single tile at
      offset 0 is used on that axis.
    - That tile is zero-padded on the bottom/right. Callers stitching
      predictions should crop to ``self.height`` x ``self.width``.

    Each item is ``(image, tile_info)``, where ``tile_info`` holds the
    tile's ``x``, ``y``, ``width`` and ``height``.
    """

    def __init__(self, image_path, tile_size=512, stride=None, transform=None):
        """
        Args:
            image_path (str): Path to the image file.
            tile_size (int): Side length of each square tile, in pixels.
            stride (int): Sliding-window step. Defaults to ``tile_size``.
            transform: Optional preprocessing callable. It is invoked as
                ``transform(image=...)`` and must return a dict with ``'image'``.

        Raises:
            ValueError: If ``tile_size`` or the effective ``stride`` is not positive.
        """
        self.image_path = image_path
        self.tile_size = tile_size
        self.stride = stride if stride is not None else tile_size
        self.transform = transform
        if self.tile_size <= 0 or self.stride <= 0:
            raise ValueError(
                f"tile_size and stride must be positive, "
                f"got tile_size={self.tile_size}, stride={self.stride}"
            )
        # Image size and profile; the profile is kept for writing the output raster.
        with rasterio.open(image_path) as src:
            self.height, self.width = src.height, src.width
            self.profile = src.profile.copy()
        row_starts = self._tile_starts(self.height, self.tile_size, self.stride)
        col_starts = self._tile_starts(self.width, self.tile_size, self.stride)
        self.num_tiles_h = len(row_starts)
        self.num_tiles_w = len(col_starts)
        # Row-major tile index.
        self.tiles_info = [
            {'x': x, 'y': y, 'width': self.tile_size, 'height': self.tile_size}
            for y in row_starts
            for x in col_starts
        ]

    @staticmethod
    def _tile_starts(length, tile_size, stride):
        """Return the start offsets of tiles along one axis of size ``length``."""
        if length <= tile_size:
            # A single tile at 0 (padded on read) instead of a negative offset.
            return [0]
        count = math.ceil((length - tile_size) / stride) + 1
        # Clamp the last start so the final tile ends on the border.
        return [min(i * stride, length - tile_size) for i in range(count)]

    @staticmethod
    def _pad_to(data, height, width):
        """Zero-pad the last two axes of ``data`` (bottom/right) up to ``height`` x ``width``."""
        pad_h = max(height - data.shape[-2], 0)
        pad_w = max(width - data.shape[-1], 0)
        if pad_h == 0 and pad_w == 0:
            return data
        pad = [(0, 0)] * (data.ndim - 2) + [(0, pad_h), (0, pad_w)]
        return np.pad(data, pad)

    @staticmethod
    def _to_rgb(image):
        """Coerce an (H, W, C) array to 3 channels.

        A single band is replicated to 3 channels; extra bands beyond the
        first three are dropped.
        """
        if image.shape[2] == 1:
            return np.repeat(image, 3, axis=2)
        if image.shape[2] > 3:
            return image[:, :, :3]  # keep only the first three (RGB) bands
        return image

    def __len__(self):
        return len(self.tiles_info)

    def __getitem__(self, idx):
        """Return ``(image, tile_info)`` for tile ``idx``.

        Without a transform, the image is an (H, W, 3) numpy array.
        """
        tile_info = self.tiles_info[idx]
        x, y = tile_info['x'], tile_info['y']
        with rasterio.open(self.image_path) as src:
            # Clip the window to the raster; any missing part is zero-padded below.
            w = min(tile_info['width'], src.width - x)
            h = min(tile_info['height'], src.height - y)
            image = src.read(window=Window(x, y, w, h))
        image = self._pad_to(image, tile_info['height'], tile_info['width'])
        # (C, H, W) -> (H, W, C), then normalize to 3 channels.
        image = self._to_rgb(np.transpose(image, (1, 2, 0)))
        if self.transform:
            image = self.transform(image=image)['image']
        return image, tile_info
def get_dataloader(image_dir, mask_dir, batch_size=4, tile_size=512, stride=None,
                   transform=None, shuffle=True, num_workers=4):
    """Build a DataLoader over a ``RemoteSensingDataset``.

    Args:
        image_dir (str): Directory containing the images.
        mask_dir (str): Directory containing the masks.
        batch_size (int): Number of tiles per batch.
        tile_size (int): Side length of each square tile, in pixels.
        stride (int): Sliding-window step. None means ``tile_size``.
        transform: Augmentation / preprocessing passed to the dataset.
        shuffle (bool): Whether to shuffle tiles each epoch.
        num_workers (int): Number of DataLoader worker processes.

    Returns:
        DataLoader: PyTorch data loader yielding ``(image, mask)`` batches.
    """
    dataset = RemoteSensingDataset(
        image_dir=image_dir,
        mask_dir=mask_dir,
        transform=transform,
        tile_size=tile_size,
        stride=stride,
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        # Pinned memory only helps host->GPU copies. Without CUDA, PyTorch
        # warns and ignores the flag, so enable it only when CUDA exists.
        pin_memory=torch.cuda.is_available(),
    )