205 lines
7.4 KiB
Python
205 lines
7.4 KiB
Python
import os
|
||
import numpy as np
|
||
import torch
|
||
from torch.utils.data import Dataset, DataLoader
|
||
from PIL import Image
|
||
import rasterio
|
||
from rasterio.windows import Window
|
||
import albumentations as A
|
||
from albumentations.pytorch import ToTensorV2
|
||
import math
|
||
|
||
class RemoteSensingDataset(Dataset):
    """Tiled remote-sensing segmentation dataset.

    Every image in ``image_dir`` is paired with a mask in ``mask_dir`` and cut
    into ``tile_size`` x ``tile_size`` windows laid out on a sliding grid with
    step ``stride``. The last row/column of tiles is shifted back so that it
    ends exactly on the image border. An image smaller than ``tile_size`` along
    an axis yields a single tile at offset 0 on that axis, zero-padded on the
    bottom/right so that every item has the same spatial size.
    """

    # Image / mask file extensions recognised (image names matched case-insensitively).
    EXTENSIONS = ('.tif', '.tiff', '.png', '.jpg')

    def __init__(self, image_dir, mask_dir, transform=None, tile_size=512, stride=None):
        """
        Args:
            image_dir (str): Directory containing the input images.
            mask_dir (str): Directory containing the masks. A mask is matched to
                an image by identical file name, or else by identical base name
                with one of ``EXTENSIONS``.
            transform: Optional albumentations-style callable, invoked as
                ``transform(image=..., mask=...)`` and expected to return a dict
                with ``'image'`` and ``'mask'`` entries.
            tile_size (int): Side length of each square tile, in pixels.
            stride (int): Sliding-window step; ``None`` means ``tile_size``
                (non-overlapping tiles).

        Raises:
            ValueError: If ``tile_size`` or ``stride`` is not positive.
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.tile_size = tile_size
        self.stride = stride if stride is not None else tile_size
        if self.tile_size <= 0 or self.stride <= 0:
            raise ValueError(
                f"tile_size and stride must be positive, got "
                f"tile_size={self.tile_size}, stride={self.stride}"
            )

        # Sorted so tile indices are reproducible regardless of directory order.
        self.images = sorted(
            f for f in os.listdir(image_dir) if f.lower().endswith(self.EXTENSIONS)
        )
        self.tiles_info = self._prepare_tiles()

    def _find_mask(self, img_name):
        """Return the mask path for ``img_name``, or ``None`` if no mask exists."""
        mask_path = os.path.join(self.mask_dir, img_name)
        if os.path.exists(mask_path):
            return mask_path
        # Fall back to the same base name with any supported extension.
        base_name = os.path.splitext(img_name)[0]
        for ext in self.EXTENSIONS:
            alt_mask_path = os.path.join(self.mask_dir, base_name + ext)
            if os.path.exists(alt_mask_path):
                return alt_mask_path
        return None

    @staticmethod
    def _tile_origins(length, tile_size, stride):
        """Return tile start offsets along one axis of size ``length``.

        Offsets step by ``stride``; the last one is clamped to
        ``length - tile_size`` so the final tile ends on the border. When the
        axis is not longer than a tile, a single tile at 0 is used (never a
        negative offset).
        """
        if length <= tile_size:
            return [0]
        count = math.ceil((length - tile_size) / stride) + 1
        return [min(k * stride, length - tile_size) for k in range(count)]

    @staticmethod
    def _pad_tile(data, height, width):
        """Zero-pad the last two axes of ``data`` on the bottom/right to (height, width)."""
        pad_h = max(height - data.shape[-2], 0)
        pad_w = max(width - data.shape[-1], 0)
        if not (pad_h or pad_w):
            return data
        pad = [(0, 0)] * (data.ndim - 2) + [(0, pad_h), (0, pad_w)]
        return np.pad(data, pad, mode='constant')

    @staticmethod
    def _read_tile(path, x, y, width, height, indexes=None):
        """Read a (height, width) window at (x, y) from ``path``, zero-padded past the edge.

        ``indexes=None`` reads all bands as (C, H, W); an int band index reads a
        single band as (H, W).
        """
        with rasterio.open(path) as src:
            # Clip the window to the raster so rasterio never reads out of range.
            read_w = max(0, min(width, src.width - x))
            read_h = max(0, min(height, src.height - y))
            data = src.read(indexes, window=Window(x, y, read_w, read_h))
        return RemoteSensingDataset._pad_tile(data, height, width)

    @staticmethod
    def _to_rgb(image):
        """Coerce an (H, W, C) array to 3 channels: replicate gray, drop extra bands."""
        channels = image.shape[2]
        if channels == 1:
            return np.repeat(image, 3, axis=2)
        if channels > 3:
            return image[:, :, :3]  # keep the first three bands only
        return image

    @staticmethod
    def _mask_to_long(mask):
        """Convert a mask (torch tensor or array-like) to an int64 tensor."""
        if isinstance(mask, torch.Tensor):
            return mask.long()
        # astype first: torch cannot wrap some numpy dtypes (e.g. uint16) directly.
        return torch.from_numpy(np.asarray(mask).astype(np.int64))

    def _prepare_tiles(self):
        """Build one tile descriptor per grid position for every image that has a mask."""
        tiles_info = []
        for img_name in self.images:
            img_path = os.path.join(self.image_dir, img_name)
            mask_path = self._find_mask(img_name)
            if mask_path is None:
                print(f"警告: 找不到图像 {img_name} 对应的掩码文件")
                continue

            with rasterio.open(img_path) as src:
                height, width = src.height, src.width

            # Row-major order: all tiles of one row before moving down.
            for y in self._tile_origins(height, self.tile_size, self.stride):
                for x in self._tile_origins(width, self.tile_size, self.stride):
                    tiles_info.append({
                        'img_path': img_path,
                        'mask_path': mask_path,
                        'x': x,
                        'y': y,
                        'width': self.tile_size,
                        'height': self.tile_size,
                    })
        return tiles_info

    def __len__(self):
        return len(self.tiles_info)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for tile ``idx``; ``mask`` is always an int64 tensor.

        Without a transform, ``image`` is an (H, W, 3) numpy array.
        NOTE(review): pixels outside the raster are padded with 0 in both image
        and mask; confirm 0 is an acceptable fill (e.g. background) label.
        """
        tile_info = self.tiles_info[idx]
        x, y = tile_info['x'], tile_info['y']
        width, height = tile_info['width'], tile_info['height']

        # (C, H, W) -> (H, W, C), the layout albumentations expects.
        image = self._read_tile(tile_info['img_path'], x, y, width, height)
        image = self._to_rgb(np.transpose(image, (1, 2, 0)))

        # Only band 1 of the mask file is used.
        mask = self._read_tile(tile_info['mask_path'], x, y, width, height, indexes=1)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return image, self._mask_to_long(mask)
|
||
|
||
class TileDataset(Dataset):
    """Sliding-window tiles of a single large image, for prediction.

    Each item is ``(image, tile_info)``. ``tile_info`` holds the tile's pixel
    offset (``x``, ``y``) and nominal size (``width``, ``height``). Edge tiles
    are shifted back so they end exactly on the image border. If the image is
    smaller than ``tile_size`` along an axis, a single tile at offset 0 is used
    on that axis and zero-padded on the bottom/right. In that case only the
    top-left ``min(tile_size, image extent)`` pixels of the tile are real data.
    """

    def __init__(self, image_path, tile_size=512, stride=None, transform=None):
        """
        Args:
            image_path (str): Path of the image to tile.
            tile_size (int): Side length of each square tile, in pixels.
            stride (int): Sliding-window step; ``None`` means ``tile_size``
                (non-overlapping tiles).
            transform: Optional albumentations-style callable, invoked as
                ``transform(image=...)`` and expected to return a dict with an
                ``'image'`` entry.

        Raises:
            ValueError: If ``tile_size`` or ``stride`` is not positive.
        """
        self.image_path = image_path
        self.tile_size = tile_size
        self.stride = stride if stride is not None else tile_size
        self.transform = transform
        if self.tile_size <= 0 or self.stride <= 0:
            raise ValueError(
                f"tile_size and stride must be positive, got "
                f"tile_size={self.tile_size}, stride={self.stride}"
            )

        # Image size and rasterio profile (kept so outputs can be written with the same metadata).
        with rasterio.open(image_path) as src:
            self.height, self.width = src.height, src.width
            self.profile = src.profile.copy()

        row_origins = self._tile_origins(self.height, self.tile_size, self.stride)
        col_origins = self._tile_origins(self.width, self.tile_size, self.stride)
        self.num_tiles_h = len(row_origins)
        self.num_tiles_w = len(col_origins)

        # Row-major order: all tiles of one row before moving down.
        self.tiles_info = [
            {'x': x, 'y': y, 'width': self.tile_size, 'height': self.tile_size}
            for y in row_origins
            for x in col_origins
        ]

    @staticmethod
    def _tile_origins(length, tile_size, stride):
        """Return tile start offsets along one axis of size ``length``.

        Offsets step by ``stride``; the last one is clamped to
        ``length - tile_size`` so the final tile ends on the border. When the
        axis is not longer than a tile, a single tile at 0 is used (never a
        negative offset).
        """
        if length <= tile_size:
            return [0]
        count = math.ceil((length - tile_size) / stride) + 1
        return [min(k * stride, length - tile_size) for k in range(count)]

    @staticmethod
    def _pad_tile(data, height, width):
        """Zero-pad the last two axes of ``data`` on the bottom/right to (height, width)."""
        pad_h = max(height - data.shape[-2], 0)
        pad_w = max(width - data.shape[-1], 0)
        if not (pad_h or pad_w):
            return data
        pad = [(0, 0)] * (data.ndim - 2) + [(0, pad_h), (0, pad_w)]
        return np.pad(data, pad, mode='constant')

    @staticmethod
    def _to_rgb(image):
        """Coerce an (H, W, C) array to 3 channels: replicate gray, drop extra bands."""
        channels = image.shape[2]
        if channels == 1:
            return np.repeat(image, 3, axis=2)
        if channels > 3:
            return image[:, :, :3]  # keep the first three bands only
        return image

    def __len__(self):
        return len(self.tiles_info)

    def __getitem__(self, idx):
        """Return ``(image, tile_info)`` for tile ``idx``.

        Without a transform, ``image`` is an (H, W, 3) numpy array.
        """
        tile_info = self.tiles_info[idx]
        x, y = tile_info['x'], tile_info['y']
        width, height = tile_info['width'], tile_info['height']

        with rasterio.open(self.image_path) as src:
            # Clip the window to the raster so rasterio never reads out of range.
            read_w = max(0, min(width, src.width - x))
            read_h = max(0, min(height, src.height - y))
            image = src.read(window=Window(x, y, read_w, read_h))
        image = self._pad_tile(image, height, width)

        # (C, H, W) -> (H, W, C), the layout albumentations expects.
        image = self._to_rgb(np.transpose(image, (1, 2, 0)))

        if self.transform:
            image = self.transform(image=image)['image']

        return image, tile_info
|
||
|
||
def get_dataloader(image_dir, mask_dir, batch_size=4, tile_size=512, stride=None,
                   transform=None, shuffle=True, num_workers=4):
    """Build a DataLoader over a tiled ``RemoteSensingDataset``.

    Args:
        image_dir (str): Directory containing the input images.
        mask_dir (str): Directory containing the masks.
        batch_size (int): Number of tiles per batch.
        tile_size (int): Side length of each square tile, in pixels.
        stride (int): Sliding-window step; ``None`` means ``tile_size``.
        transform: Data augmentation / preprocessing passed to the dataset.
        shuffle (bool): Whether to shuffle the tiles every epoch.
        num_workers (int): Number of worker processes used for loading.

    Returns:
        DataLoader: PyTorch data loader yielding ``(image, mask)`` batches.
    """
    dataset = RemoteSensingDataset(
        image_dir=image_dir,
        mask_dir=mask_dir,
        transform=transform,
        tile_size=tile_size,
        stride=stride,
    )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        # Pinned host memory only helps host->GPU copies; without CUDA it just
        # triggers a warning, so enable it only when a GPU is present.
        pin_memory=torch.cuda.is_available(),
    )