2025-07-10 09:41:26 +08:00

205 lines
7.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import rasterio
from rasterio.windows import Window
import albumentations as A
from albumentations.pytorch import ToTensorV2
import math
class RemoteSensingDataset(Dataset):
    """Tiled dataset of remote-sensing images paired with segmentation masks.

    Every image in ``image_dir`` that has a mask in ``mask_dir`` is split into
    ``tile_size`` x ``tile_size`` windows laid out on a sliding grid. Each
    dataset item is one (image tile, mask tile) pair that is read from disk
    with rasterio when the item is requested.
    """

    # File extensions accepted for images and tried for masks.
    _EXTENSIONS = ('.tif', '.png', '.jpg')

    def __init__(self, image_dir, mask_dir, transform=None, tile_size=512, stride=None):
        """
        Args:
            image_dir (str): Directory containing the images.
            mask_dir (str): Directory containing the masks.
            transform: Optional callable invoked as
                ``transform(image=..., mask=...)`` that returns a mapping with
                ``'image'`` and ``'mask'`` entries.
            tile_size (int): Side length of each square tile, in pixels.
            stride (int): Sliding-window step; equals ``tile_size`` when None.

        Raises:
            ValueError: If ``tile_size`` or the effective stride is not positive.
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.tile_size = tile_size
        self.stride = stride if stride is not None else tile_size
        if self.tile_size <= 0 or self.stride <= 0:
            raise ValueError(
                f"tile_size and stride must be positive, got "
                f"tile_size={self.tile_size}, stride={self.stride}"
            )
        # os.listdir order is arbitrary; sorting keeps the index -> tile
        # mapping reproducible. Extension matching ignores case.
        self.images = sorted(
            f for f in os.listdir(image_dir)
            if f.lower().endswith(self._EXTENSIONS)
        )
        self.tiles_info = self._prepare_tiles()

    @staticmethod
    def _tile_offsets(length, tile_size, stride):
        """Return the start offsets of the tiles along one axis of size ``length``.

        The last tile is shifted back so that it ends exactly at the image
        border. An axis that is not longer than one tile yields the single
        offset 0 (instead of a negative offset or no tile at all).
        """
        if length <= tile_size:
            return [0]
        last_start = length - tile_size
        count = math.ceil(last_start / stride) + 1
        return [min(i * stride, last_start) for i in range(count)]

    @staticmethod
    def _pad_to_tile(data, height, width):
        """Zero-pad the last two axes of ``data`` up to ``height`` x ``width``."""
        missing_h = max(height - data.shape[-2], 0)
        missing_w = max(width - data.shape[-1], 0)
        if missing_h == 0 and missing_w == 0:
            return data
        padding = [(0, 0)] * (data.ndim - 2) + [(0, missing_h), (0, missing_w)]
        return np.pad(data, padding, mode='constant', constant_values=0)

    @staticmethod
    def _to_hwc_rgb(image):
        """Convert a (C, H, W) array into (H, W, C) with three channels.

        A single channel is repeated three times; more than three channels are
        cut down to the first three. Other channel counts are left unchanged.
        """
        image = np.transpose(image, (1, 2, 0))
        channels = image.shape[2]
        if channels == 1:
            return np.repeat(image, 3, axis=2)
        if channels > 3:
            return image[:, :, :3]
        return image

    @staticmethod
    def _mask_to_long(mask):
        """Return ``mask`` as an int64 tensor, accepting tensors and arrays."""
        if torch.is_tensor(mask):
            return mask.long()
        # Without a tensor-producing transform the mask is still a numpy array.
        return torch.from_numpy(np.ascontiguousarray(mask, dtype=np.int64))

    def _find_mask(self, img_name):
        """Return the mask path for ``img_name`` or None if there is none.

        The mask is looked up under the same file name first, then under the
        same base name with each of the known extensions.
        """
        same_name = os.path.join(self.mask_dir, img_name)
        if os.path.exists(same_name):
            return same_name
        base_name = os.path.splitext(img_name)[0]
        for ext in self._EXTENSIONS:
            candidate = os.path.join(self.mask_dir, base_name + ext)
            if os.path.exists(candidate):
                return candidate
        return None

    def _prepare_tiles(self):
        """Build the list of tile descriptors for all images that have a mask."""
        tiles_info = []
        for img_name in self.images:
            img_path = os.path.join(self.image_dir, img_name)
            mask_path = self._find_mask(img_name)
            if mask_path is None:
                print(f"警告: 找不到图像 {img_name} 对应的掩码文件")
                continue
            # Only the raster size is needed here; pixels are read lazily.
            with rasterio.open(img_path) as src:
                height, width = src.height, src.width
            for y in self._tile_offsets(height, self.tile_size, self.stride):
                for x in self._tile_offsets(width, self.tile_size, self.stride):
                    tiles_info.append({
                        'img_path': img_path,
                        'mask_path': mask_path,
                        'x': x,
                        'y': y,
                        'width': self.tile_size,
                        'height': self.tile_size,
                    })
        return tiles_info

    def _read_tile(self, path, tile, band=None):
        """Read one tile from ``path`` and pad it to the full tile size.

        The window is clipped to the raster extent before reading, so rasters
        smaller than a tile are read in full and then zero-padded.
        NOTE(review): masks are padded with 0 as well - confirm that label 0
        is an acceptable fill value (e.g. background) for this project.
        """
        with rasterio.open(path) as src:
            read_w = min(tile['width'], src.width - tile['x'])
            read_h = min(tile['height'], src.height - tile['y'])
            window = Window(tile['x'], tile['y'], read_w, read_h)
            data = src.read(band, window=window)
        return self._pad_to_tile(data, tile['height'], tile['width'])

    def __len__(self):
        """Return the total number of tiles over all images."""
        return len(self.tiles_info)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair of tile ``idx``.

        The image is an (H, W, C) array before the transform is applied; the
        mask is read from band 1 of the mask file (the mask is assumed to be
        single-band). The mask is always returned as an int64 tensor; the
        image is returned as produced by ``transform``, or as the raw numpy
        array when no transform is set.
        """
        tile = self.tiles_info[idx]
        image = self._to_hwc_rgb(self._read_tile(tile['img_path'], tile))
        mask = self._read_tile(tile['mask_path'], tile, band=1)
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']
        return image, self._mask_to_long(mask)
class TileDataset(Dataset):
    """Tiled view of a single large image, used for sliding-window prediction.

    The image at ``image_path`` is covered by ``tile_size`` x ``tile_size``
    windows on a sliding grid. Each item is one image tile together with the
    dictionary describing where the tile sits in the full image.
    """

    def __init__(self, image_path, tile_size=512, stride=None, transform=None):
        """
        Args:
            image_path (str): Path of the image file.
            tile_size (int): Side length of each square tile, in pixels.
            stride (int): Sliding-window step; equals ``tile_size`` when None.
            transform: Optional callable invoked as ``transform(image=...)``
                that returns a mapping with an ``'image'`` entry.

        Raises:
            ValueError: If ``tile_size`` or the effective stride is not positive.
        """
        self.image_path = image_path
        self.tile_size = tile_size
        self.stride = stride if stride is not None else tile_size
        self.transform = transform
        if self.tile_size <= 0 or self.stride <= 0:
            raise ValueError(
                f"tile_size and stride must be positive, got "
                f"tile_size={self.tile_size}, stride={self.stride}"
            )
        # Record the raster size and a copy of its profile.
        with rasterio.open(image_path) as src:
            self.height, self.width = src.height, src.width
            self.profile = src.profile.copy()
        rows = self._tile_offsets(self.height, self.tile_size, self.stride)
        cols = self._tile_offsets(self.width, self.tile_size, self.stride)
        # Grid dimensions; always at least 1, even for images smaller than a tile.
        self.num_tiles_h = len(rows)
        self.num_tiles_w = len(cols)
        # Row-major list of tile descriptors.
        self.tiles_info = [
            {'x': x, 'y': y, 'width': self.tile_size, 'height': self.tile_size}
            for y in rows
            for x in cols
        ]

    @staticmethod
    def _tile_offsets(length, tile_size, stride):
        """Return the start offsets of the tiles along one axis of size ``length``.

        The last tile is shifted back so that it ends exactly at the image
        border. An axis that is not longer than one tile yields the single
        offset 0 (instead of a negative offset or no tile at all).
        """
        if length <= tile_size:
            return [0]
        last_start = length - tile_size
        count = math.ceil(last_start / stride) + 1
        return [min(i * stride, last_start) for i in range(count)]

    @staticmethod
    def _pad_to_tile(data, height, width):
        """Zero-pad the last two axes of ``data`` up to ``height`` x ``width``."""
        missing_h = max(height - data.shape[-2], 0)
        missing_w = max(width - data.shape[-1], 0)
        if missing_h == 0 and missing_w == 0:
            return data
        padding = [(0, 0)] * (data.ndim - 2) + [(0, missing_h), (0, missing_w)]
        return np.pad(data, padding, mode='constant', constant_values=0)

    @staticmethod
    def _to_hwc_rgb(image):
        """Convert a (C, H, W) array into (H, W, C) with three channels.

        A single channel is repeated three times; more than three channels are
        cut down to the first three. Other channel counts are left unchanged.
        """
        image = np.transpose(image, (1, 2, 0))
        channels = image.shape[2]
        if channels == 1:
            return np.repeat(image, 3, axis=2)
        if channels > 3:
            return image[:, :, :3]
        return image

    def __len__(self):
        """Return the number of tiles covering the image."""
        return len(self.tiles_info)

    def __getitem__(self, idx):
        """Return ``(image, tile_info)`` for tile ``idx``.

        ``tile_info`` holds the ``x``/``y`` offset and the nominal
        ``width``/``height`` of the tile. When the image is smaller than a
        tile, the part of the tile outside the image is filled with zeros.
        """
        tile_info = self.tiles_info[idx]
        with rasterio.open(self.image_path) as src:
            # Clip the window to the raster extent, then pad to full size.
            read_w = min(tile_info['width'], src.width - tile_info['x'])
            read_h = min(tile_info['height'], src.height - tile_info['y'])
            window = Window(tile_info['x'], tile_info['y'], read_w, read_h)
            image = src.read(window=window)
        image = self._pad_to_tile(image, tile_info['height'], tile_info['width'])
        image = self._to_hwc_rgb(image)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        return image, tile_info
def get_dataloader(image_dir, mask_dir, batch_size=4, tile_size=512, stride=None,
                   transform=None, shuffle=True, num_workers=4, pin_memory=True):
    """Build a DataLoader over a tiled ``RemoteSensingDataset``.

    Args:
        image_dir (str): Directory containing the images.
        mask_dir (str): Directory containing the masks.
        batch_size (int): Number of tiles per batch.
        tile_size (int): Side length of each square tile, in pixels.
        stride (int): Sliding-window step, forwarded to the dataset.
        transform: Augmentation / preprocessing callable, forwarded to the dataset.
        shuffle (bool): Whether to shuffle the tiles.
        num_workers (int): Number of DataLoader worker subprocesses.
        pin_memory (bool): Forwarded to ``DataLoader``. Defaults to True, the
            value that was previously hard-coded; pass False on machines
            without an accelerator.

    Returns:
        DataLoader: PyTorch data loader over the tiles.
    """
    dataset = RemoteSensingDataset(
        image_dir=image_dir,
        mask_dir=mask_dir,
        transform=transform,
        tile_size=tile_size,
        stride=stride,
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )