205 lines
7.4 KiB
Python
205 lines
7.4 KiB
Python
|
import os
|
|||
|
import numpy as np
|
|||
|
import torch
|
|||
|
from torch.utils.data import Dataset, DataLoader
|
|||
|
from PIL import Image
|
|||
|
import rasterio
|
|||
|
from rasterio.windows import Window
|
|||
|
import albumentations as A
|
|||
|
from albumentations.pytorch import ToTensorV2
|
|||
|
import math
|
|||
|
|
|||
|
class RemoteSensingDataset(Dataset):
    """Remote-sensing segmentation dataset that serves fixed-size tiles.

    Every image in ``image_dir`` is cut into ``tile_size`` x ``tile_size``
    windows on a sliding grid with step ``stride``; each tile is paired with
    the same window of the matching mask file in ``mask_dir``.  Rasters that
    are smaller than a tile along an axis yield one tile on that axis, padded
    on the bottom/right (image with 0, mask with ``mask_pad_value``).
    """

    # File extensions recognised as images (matched case-insensitively) and
    # tried, in this order, when looking for a mask with a different suffix.
    IMAGE_EXTENSIONS = ('.tif', '.tiff', '.png', '.jpg')

    def __init__(self, image_dir, mask_dir, transform=None, tile_size=512, stride=None,
                 mask_pad_value=0):
        """
        Args:
            image_dir (str): Directory containing the images.
            mask_dir (str): Directory containing the masks.
            transform: Optional augmentation/preprocessing callable, invoked as
                ``transform(image=..., mask=...)`` and expected to return a dict
                with ``'image'`` and ``'mask'`` keys (albumentations style).
            tile_size (int): Side length of the square tiles, in pixels.
            stride (int): Sliding-window step; ``None`` means ``tile_size``
                (non-overlapping tiles).
            mask_pad_value (int): Label written into padded mask pixels when a
                raster is smaller than ``tile_size``.  Pass your loss's ignore
                index here if 0 is a real class.

        Raises:
            ValueError: If ``tile_size`` or ``stride`` is not positive.
        """
        if tile_size <= 0:
            raise ValueError(f"tile_size must be positive, got {tile_size}")
        if stride is not None and stride <= 0:
            raise ValueError(f"stride must be positive, got {stride}")

        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.tile_size = tile_size
        self.stride = stride if stride is not None else tile_size
        self.mask_pad_value = mask_pad_value

        # Sorted so tile indices are reproducible; os.listdir order is arbitrary.
        self.images = sorted(
            f for f in os.listdir(image_dir)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )

        self.tiles_info = self._prepare_tiles()

    @staticmethod
    def _axis_offsets(length, tile_size, stride):
        """Return the tile start offsets along one axis of ``length`` pixels.

        Offsets advance by ``stride``; the last one is clamped so the final tile
        ends exactly at the raster edge.  A raster shorter than ``tile_size``
        yields the single offset 0 (the tile is padded when read).
        """
        last = max(length - tile_size, 0)
        count = math.ceil(last / stride) + 1
        return [min(i * stride, last) for i in range(count)]

    @staticmethod
    def _pad_to(array, height, width, fill_value):
        """Pad the last two axes of ``array`` on the bottom/right to (height, width)."""
        pad_h = max(height - array.shape[-2], 0)
        pad_w = max(width - array.shape[-1], 0)
        if pad_h == 0 and pad_w == 0:
            return array
        pad_width = [(0, 0)] * (array.ndim - 2) + [(0, pad_h), (0, pad_w)]
        return np.pad(array, pad_width, mode='constant', constant_values=fill_value)

    @staticmethod
    def _to_three_channels(image):
        """Turn an (H, W, C) array into 3 channels: replicate 1 band, keep first 3 of >3."""
        channels = image.shape[2]
        if channels == 1:
            return np.repeat(image, 3, axis=2)
        if channels > 3:
            return image[:, :, :3]  # keep only the first three (assumed RGB) bands
        return image

    @staticmethod
    def _mask_to_long(mask):
        """Return ``mask`` as an int64 tensor, whether it is a NumPy array or a tensor."""
        if isinstance(mask, np.ndarray):
            # Cast in NumPy first: torch.from_numpy rejects some unsigned dtypes
            # (e.g. uint16) on older PyTorch versions.
            return torch.from_numpy(np.ascontiguousarray(mask, dtype=np.int64))
        return mask.long()

    @staticmethod
    def _read_tile(path, tile_info, fill_value, band=None):
        """Read one tile window from ``path``, padded to the tile's full size.

        The window is clipped to the raster so reads never leave the dataset;
        any missing bottom/right border is filled with ``fill_value``.  With
        ``band=None`` all bands are returned as (C, H, W), otherwise the single
        band is returned as (H, W).
        """
        x, y = tile_info['x'], tile_info['y']
        tile_w, tile_h = tile_info['width'], tile_info['height']
        with rasterio.open(path) as src:
            read_w = min(tile_w, src.width - x)
            read_h = min(tile_h, src.height - y)
            if read_w <= 0 or read_h <= 0:
                raise ValueError(
                    f"tile at ({x}, {y}) lies outside raster {path} "
                    f"({src.width}x{src.height}); do image and mask sizes differ?"
                )
            window = Window(x, y, read_w, read_h)
            if band is None:
                data = src.read(window=window)
            else:
                data = src.read(band, window=window)
        return RemoteSensingDataset._pad_to(data, tile_h, tile_w, fill_value)

    def _find_mask(self, img_name):
        """Return the mask path for ``img_name``, or ``None`` if none exists.

        The mask with the identical file name is preferred; otherwise the same
        base name is tried with each of ``IMAGE_EXTENSIONS`` in order.
        """
        mask_path = os.path.join(self.mask_dir, img_name)
        if os.path.exists(mask_path):
            return mask_path
        base_name = os.path.splitext(img_name)[0]
        for ext in self.IMAGE_EXTENSIONS:
            alt_mask_path = os.path.join(self.mask_dir, base_name + ext)
            if os.path.exists(alt_mask_path):
                return alt_mask_path
        return None

    def _prepare_tiles(self):
        """Build the list of tile descriptors for every image that has a mask.

        Returns:
            list[dict]: One dict per tile with keys ``img_path``, ``mask_path``,
            ``x``, ``y``, ``width`` and ``height`` (pixel window of the tile).
        """
        tiles_info = []
        for img_name in self.images:
            img_path = os.path.join(self.image_dir, img_name)
            mask_path = self._find_mask(img_name)
            if mask_path is None:
                print(f"警告: 找不到图像 {img_name} 对应的掩码文件")
                continue

            # Only the raster size is needed here; pixels are read lazily.
            with rasterio.open(img_path) as src:
                height, width = src.height, src.width

            for y in self._axis_offsets(height, self.tile_size, self.stride):
                for x in self._axis_offsets(width, self.tile_size, self.stride):
                    tiles_info.append({
                        'img_path': img_path,
                        'mask_path': mask_path,
                        'x': x,
                        'y': y,
                        'width': self.tile_size,
                        'height': self.tile_size,
                    })
        return tiles_info

    def __len__(self):
        return len(self.tiles_info)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for tile ``idx``.

        The image is (H, W, 3) before ``transform``; whatever ``transform``
        returns is passed through.  The mask is always an int64 tensor, also
        when no transform (or one without tensor conversion) is configured.
        """
        tile_info = self.tiles_info[idx]

        image = self._read_tile(tile_info['img_path'], tile_info, fill_value=0)
        # (C, H, W) -> (H, W, C), the layout albumentations expects.
        image = self._to_three_channels(np.transpose(image, (1, 2, 0)))

        # NOTE(review): assumes the label is stored in the mask's first band.
        mask = self._read_tile(tile_info['mask_path'], tile_info,
                               fill_value=self.mask_pad_value, band=1)

        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return image, self._mask_to_long(mask)
|
|||
|
|
|||
|
class TileDataset(Dataset):
    """Sliding-window tile dataset for predicting on one large image.

    The image is cut into ``tile_size`` x ``tile_size`` windows with step
    ``stride``; each item is the (optionally transformed) tile plus its window
    descriptor so predictions can be stitched back.  Rasters smaller than a
    tile along an axis yield one tile on that axis, zero-padded on the
    bottom/right, so callers should crop stitched output to
    ``(self.height, self.width)``.
    """

    def __init__(self, image_path, tile_size=512, stride=None, transform=None):
        """
        Args:
            image_path (str): Path of the image file.
            tile_size (int): Side length of the square tiles, in pixels.
            stride (int): Sliding-window step; ``None`` means ``tile_size``.
            transform: Optional preprocessing callable, invoked as
                ``transform(image=...)`` and expected to return a dict with an
                ``'image'`` key (albumentations style).

        Raises:
            ValueError: If ``tile_size`` or ``stride`` is not positive.
        """
        if tile_size <= 0:
            raise ValueError(f"tile_size must be positive, got {tile_size}")
        if stride is not None and stride <= 0:
            raise ValueError(f"stride must be positive, got {stride}")

        self.image_path = image_path
        self.tile_size = tile_size
        self.stride = stride if stride is not None else tile_size
        self.transform = transform

        # Raster size plus profile (kept so callers can write outputs with
        # the same metadata).
        with rasterio.open(image_path) as src:
            self.height, self.width = src.height, src.width
            self.profile = src.profile.copy()

        row_offsets = self._axis_offsets(self.height, self.tile_size, self.stride)
        col_offsets = self._axis_offsets(self.width, self.tile_size, self.stride)
        self.num_tiles_h = len(row_offsets)
        self.num_tiles_w = len(col_offsets)

        # Row-major order: all tiles of the first row, then the next row, ...
        self.tiles_info = [
            {'x': x, 'y': y, 'width': self.tile_size, 'height': self.tile_size}
            for y in row_offsets
            for x in col_offsets
        ]

    @staticmethod
    def _axis_offsets(length, tile_size, stride):
        """Return the tile start offsets along one axis of ``length`` pixels.

        Offsets advance by ``stride``; the last one is clamped so the final tile
        ends exactly at the raster edge.  A raster shorter than ``tile_size``
        yields the single offset 0 (the tile is padded when read).
        """
        last = max(length - tile_size, 0)
        count = math.ceil(last / stride) + 1
        return [min(i * stride, last) for i in range(count)]

    @staticmethod
    def _pad_to(array, height, width, fill_value):
        """Pad the last two axes of ``array`` on the bottom/right to (height, width)."""
        pad_h = max(height - array.shape[-2], 0)
        pad_w = max(width - array.shape[-1], 0)
        if pad_h == 0 and pad_w == 0:
            return array
        pad_width = [(0, 0)] * (array.ndim - 2) + [(0, pad_h), (0, pad_w)]
        return np.pad(array, pad_width, mode='constant', constant_values=fill_value)

    @staticmethod
    def _to_three_channels(image):
        """Turn an (H, W, C) array into 3 channels: replicate 1 band, keep first 3 of >3."""
        channels = image.shape[2]
        if channels == 1:
            return np.repeat(image, 3, axis=2)
        if channels > 3:
            return image[:, :, :3]  # keep only the first three (assumed RGB) bands
        return image

    def __len__(self):
        return len(self.tiles_info)

    def __getitem__(self, idx):
        """Return ``(image, tile_info)`` for tile ``idx``.

        ``image`` is the (H, W, 3) tile, or whatever ``transform`` returns for
        it; ``tile_info`` is a copy of the tile's ``x``/``y``/``width``/``height``
        descriptor, so callers may mutate it freely.
        """
        tile_info = self.tiles_info[idx]
        x, y = tile_info['x'], tile_info['y']
        tile_w, tile_h = tile_info['width'], tile_info['height']

        # Clip the window to the raster (only matters for rasters smaller than
        # a tile), then zero-pad back to the full tile size.
        with rasterio.open(self.image_path) as src:
            window = Window(x, y, min(tile_w, self.width - x), min(tile_h, self.height - y))
            image = src.read(window=window)
        image = self._pad_to(image, tile_h, tile_w, 0)

        # (C, H, W) -> (H, W, C), the layout albumentations expects.
        image = self._to_three_channels(np.transpose(image, (1, 2, 0)))

        if self.transform is not None:
            image = self.transform(image=image)['image']

        return image, dict(tile_info)
|
|||
|
|
|||
|
def get_dataloader(image_dir, mask_dir, batch_size=4, tile_size=512, stride=None,
                   transform=None, shuffle=True, num_workers=4):
    """Build a DataLoader over the tiles of a ``RemoteSensingDataset``.

    Args:
        image_dir (str): Directory containing the images.
        mask_dir (str): Directory containing the masks.
        batch_size (int): Number of tiles per batch.
        tile_size (int): Side length of the square tiles, in pixels.
        stride (int): Sliding-window step; ``None`` means ``tile_size``.
        transform: Augmentation/preprocessing callable forwarded to the dataset.
        shuffle (bool): Whether to reshuffle the tiles every epoch.
        num_workers (int): Number of worker processes used for loading
            (0 loads in the main process).

    Returns:
        DataLoader: PyTorch data loader yielding ``(image, mask)`` batches.
    """
    dataset = RemoteSensingDataset(
        image_dir=image_dir,
        mask_dir=mask_dir,
        transform=transform,
        tile_size=tile_size,
        stride=stride,
    )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        # Pinned host memory only speeds up host-to-GPU copies; without CUDA
        # it is pure overhead and PyTorch warns about it.
        pin_memory=torch.cuda.is_available(),
    )
|