#coding=utf-8
from os.path import join
import torch
from PIL import Image, ImageEnhance
from torch.utils.data.dataset import Dataset
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize
import numpy as np
import torchvision.transforms as transforms
import os
from argparse import Namespace
import imageio
import pandas as pd
import rasterio
from rasterio.windows import Window
import torch.nn.functional as F
Image.MAX_IMAGE_PIXELS = None  # disable PIL's pixel-limit (decompression bomb) warning for very large images
def is_image_file(filename):
return any(filename.endswith(extension) for extension in ['.tif','.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG'])
def calMetric_iou(predict, label):
    # Intersection and union of the positive class:
    # union = predicted positives + ground-truth positives - intersection.
    tp = np.sum(np.logical_and(predict == 1, label == 1))
    pred_pos = np.sum(predict == 1)
    gt_pos = np.sum(label == 1)
    return tp, pred_pos + gt_pos - tp
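# Illustrative sketch (not part of the original pipeline): calMetric_iou returns the
# (intersection, union) counts of the positive class, which are usually accumulated
# over all patches before taking the final ratio.
def example_iou_over_batches(pred_label_pairs):
    """Accumulate (intersection, union) pairs from calMetric_iou into one IoU score."""
    inter_total, union_total = 0, 0
    for predict, label in pred_label_pairs:
        inter, union = calMetric_iou(predict, label)
        inter_total += inter
        union_total += union
    return inter_total / union_total if union_total > 0 else 1.0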
def getDataList(img_path):
    # Read a newline-separated list of names from a text file.
    with open(img_path, 'r') as f:
        datalist = [line.strip('\n') for line in f]
    return datalist
def make_one_hot(input, num_classes):
"""Convert class index tensor to one hot encoding tensor.
Args:
input: A tensor of shape [N, 1, *]
        num_classes: the number of classes
Returns:
A tensor of shape [N, num_classes, *]
"""
shape = np.array(input.shape)
shape[1] = num_classes
shape = tuple(shape)
result = torch.zeros(shape)
result = result.scatter_(1, input.cpu(), 1)
return result
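# Minimal usage sketch for make_one_hot: a [N, 1, H, W] tensor of class indices
# becomes a [N, num_classes, H, W] one-hot tensor (indices are assumed valid).
def example_make_one_hot():
    idx = torch.randint(0, 2, (4, 1, 256, 256))   # [N, 1, H, W] integer class map
    onehot = make_one_hot(idx, 2)                 # [N, 2, H, W]
    return onehot.shape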
def get_transform(convert=True, normalize=False):
transform_list = []
if convert:
transform_list += [transforms.ToTensor()]
if normalize:
transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))]
return transforms.Compose(transform_list)
# class LoadDatasetFromFolder(Dataset):
# def __init__(self, args, hr1_path, hr2_path, lab_path):
# super(LoadDatasetFromFolder, self).__init__()
# # Build the list of image filenames
# datalist = [name for name in os.listdir(hr1_path) for item in args.suffix if
# os.path.splitext(name)[1] == item]
# self.hr1_filenames = [join(hr1_path, x) for x in datalist if is_image_file(x)]
# self.hr2_filenames = [join(hr2_path, x) for x in datalist if is_image_file(x)]
# self.lab_filenames = [join(lab_path, x) for x in datalist if is_image_file(x)]
# self.transform = get_transform(convert=True, normalize=True) # convert to tensor and normalize to [-1,1]
# self.label_transform = get_transform() # only convert to tensor
# def __getitem__(self, index):
# hr1_img = self.transform(Image.open(self.hr1_filenames[index]).convert('RGB'))
# # lr2_img = self.transform(Image.open(self.lr2_filenames[index]).convert('RGB'))
# hr2_img = self.transform(Image.open(self.hr2_filenames[index]).convert('RGB'))
# label = self.label_transform(Image.open(self.lab_filenames[index]))
# label = make_one_hot(label.unsqueeze(0).long(), 2).squeeze(0)
# return hr1_img, hr2_img, label
# def __len__(self):
# return len(self.hr1_filenames)
# Validation dataset
class LoadDatasetFromCoords(Dataset):
def __init__(self, coords_csv, hr1_dir, hr2_dir, label_dir, num_classes=2):
super(LoadDatasetFromCoords, self).__init__()
self.coords = pd.read_csv(coords_csv)
self.hr1_dir = hr1_dir
self.hr2_dir = hr2_dir
self.label_dir = label_dir
self.num_classes = num_classes
self.transform = get_transform(convert=True, normalize=True)
self.label_transform = get_transform()
self.open_images = {}
def _load_image(self, path):
if path not in self.open_images:
self.open_images[path] = rasterio.open(path)
return self.open_images[path]
def __len__(self):
return len(self.coords)
def __getitem__(self, idx):
row = self.coords.iloc[idx]
fname = row['orig_name']
x, y, h, w = int(row['x']), int(row['y']), int(row['h']), int(row['w'])
        # Open the images via cached rasterio handles
img1 = self._load_image(os.path.join(self.hr1_dir, fname))
img2 = self._load_image(os.path.join(self.hr2_dir, fname))
label = self._load_image(os.path.join(self.label_dir, fname))
window = Window(x, y, w, h)
patch1 = img1.read(window=window) # [C, H, W]
patch2 = img2.read(window=window)
patch_label = label.read(1, window=window) # [H, W]
        # Convert to PIL.Image so the torchvision transforms can be applied
patch1 = Image.fromarray(np.transpose(patch1, (1, 2, 0)).astype(np.uint8))
patch2 = Image.fromarray(np.transpose(patch2, (1, 2, 0)).astype(np.uint8))
patch_label = Image.fromarray(patch_label.astype(np.uint8))
patch1 = self.transform(patch1)
patch2 = self.transform(patch2)
patch_label = self.label_transform(patch_label)
patch_label = make_one_hot(patch_label.unsqueeze(0).long(), self.num_classes).squeeze(0)
return patch1, patch2, patch_label
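# Hedged usage sketch for LoadDatasetFromCoords: the CSV path and directories below
# are placeholders, not files shipped with this code.
def example_build_val_loader():
    from torch.utils.data import DataLoader
    val_set = LoadDatasetFromCoords(
        coords_csv="path/to/val_coords.csv",
        hr1_dir="path/to/Time1",
        hr2_dir="path/to/Time2",
        label_dir="path/to/Label",
        num_classes=2,
    )
    # num_workers=0 is the conservative choice here: the cached rasterio handles in
    # self.open_images may not be safe to share with forked worker processes.
    return DataLoader(val_set, batch_size=8, shuffle=False, num_workers=0)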
# Test dataset
# class TestDatasetFromCoords(Dataset):
# def __init__(self, coords_csv, hr1_dir, hr2_dir, label_dir=None, transform=None):
# self.coords = pd.read_csv(coords_csv)
# self.hr1_dir = hr1_dir
# self.hr2_dir = hr2_dir
# self.label_dir = label_dir
# self.transform = transform
# self.label_transform = transforms.ToTensor()
# def __len__(self):
# return len(self.coords)
# def __getitem__(self, idx):
# row = self.coords.iloc[idx]
# x, y, h, w = row['x'], row['y'], row['h'], row['w']
# fname = row['orig_name']
# img1 = Image.open(os.path.join(self.hr1_dir, fname)).convert('RGB')
# img2 = Image.open(os.path.join(self.hr2_dir, fname)).convert('RGB')
# img1 = img1.crop((x, y, x + w, y + h))
# img2 = img2.crop((x, y, x + w, y + h))
# if self.label_dir:
# label = Image.open(os.path.join(self.label_dir, fname)).convert('L')
# label = label.crop((x, y, x + w, y + h))
# label = self.label_transform(label)
# else:
# label = torch.zeros((1, h, w)) # dummy label
# if self.transform:
# img1 = self.transform(img1)
# img2 = self.transform(img2)
# patch_name = f"{os.path.splitext(fname)[0]}_{x}_{y}"
# return img1, img2, label, patch_name, fname, x, y
class TestDatasetFromCoords(Dataset):
def __init__(self, coords_csv, hr1_dir, hr2_dir, label_dir=None, transform=None):
self.coords = pd.read_csv(coords_csv)
self.hr1_dir = hr1_dir
self.hr2_dir = hr2_dir
self.label_dir = label_dir
self.transform = transform
self.label_transform = transforms.ToTensor()
def __len__(self):
return len(self.coords)
def __getitem__(self, idx):
row = self.coords.iloc[idx]
x, y, h, w = int(row['x']), int(row['y']), int(row['h']), int(row['w'])
fname = row['orig_name']
        # Read the images and crop the patch
img1 = Image.open(os.path.join(self.hr1_dir, fname)).convert('RGB')
img2 = Image.open(os.path.join(self.hr2_dir, fname)).convert('RGB')
img1 = img1.crop((x, y, x + w, y + h))
img2 = img2.crop((x, y, x + w, y + h))
        # Label (optional)
if self.label_dir:
label = Image.open(os.path.join(self.label_dir, fname)).convert('L')
label = label.crop((x, y, x + w, y + h))
label = self.label_transform(label)
else:
label = torch.zeros((1, h, w)) # dummy label
        # Image transform
if self.transform:
img1 = self.transform(img1)
img2 = self.transform(img2)
        # Pad to 256x256 and keep the original h, w
pad_h = 256 - h
pad_w = 256 - w
if pad_h > 0 or pad_w > 0:
img1 = F.pad(img1, (0, pad_w, 0, pad_h)) # pad right & bottom
img2 = F.pad(img2, (0, pad_w, 0, pad_h))
label = F.pad(label, (0, pad_w, 0, pad_h))
patch_name = f"{os.path.splitext(fname)[0]}_{x}_{y}"
        return img1, img2, label, patch_name, fname, x, y, h, w  # also return the original patch size and position
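# Hedged sketch: the extra values returned by TestDatasetFromCoords (x, y, h, w) can be
# used to strip the right/bottom padding and paste a patch prediction back into a
# full-size canvas. `canvas` and `pred` are hypothetical tensors, not names from this repo.
def example_unpad_and_stitch(canvas, pred, x, y, h, w):
    pred = pred[..., :h, :w]               # drop the padding added to reach 256x256
    canvas[..., y:y + h, x:x + w] = pred   # write the patch at its original location
    return canvas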
class TestDatasetFromImagePath(Dataset):
    """Test dataset for a single image pair given directly by file paths (no coordinate CSV)."""
    def __init__(self, img1_path, img2_path, label_path=None, transform=None):
        self.img1_path = img1_path
        self.img2_path = img2_path
        self.label_path = label_path
        self.transform = transform
        self.label_transform = transforms.ToTensor()
    def __len__(self):
        # A single image pair.
        return 1
    def __getitem__(self, idx):
        # Read the full images (no cropping, since there are no patch coordinates)
        img1 = Image.open(self.img1_path).convert('RGB')
        img2 = Image.open(self.img2_path).convert('RGB')
        w, h = img1.size
        # Label (optional)
        if self.label_path:
            label = Image.open(self.label_path).convert('L')
            label = self.label_transform(label)
        else:
            label = torch.zeros((1, h, w))  # dummy label
        # Image transform
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        fname = os.path.basename(self.img1_path)
        patch_name = f"{os.path.splitext(fname)[0]}_0_0"
        return img1, img2, label, patch_name, fname, 0, 0, h, w  # keep the same return layout as the patch datasets
# class TestDatasetFromFolder(Dataset):
# def __init__(self, args, Time1_dir, Time2_dir, Label_dir):
# super(TestDatasetFromFolder, self).__init__()
# # Make sure args.suffix exists, otherwise fall back to defaults
# suffixes = getattr(args, 'suffix', ['.png', '.jpg'])
# # Keep only files with a matching suffix
# datalist = [
# name for name in os.listdir(Time1_dir)
# if os.path.splitext(name)[1].lower() in suffixes
# ]
# print(f"Found {len(datalist)} files: {datalist[:5]}...")  # debug output
# self.image1_filenames = [join(Time1_dir, x) for x in datalist if is_image_file(x)]
# self.image2_filenames = [join(Time2_dir, x) for x in datalist if is_image_file(x)]
# self.image3_filenames = [join(Label_dir, x) for x in datalist if is_image_file(x)]
# self.transform = get_transform(convert=True, normalize=True) # convert to tensor and normalize to [-1,1]
# self.label_transform = get_transform()
# def __getitem__(self, index):
# image1 = self.transform(Image.open(self.image1_filenames[index]).convert('RGB'))
# image2 = self.transform(Image.open(self.image2_filenames[index]).convert('RGB'))
# label = self.label_transform(Image.open(self.image3_filenames[index]))
# label = make_one_hot(label.unsqueeze(0).long(), 2).squeeze(0)
# image_name = self.image1_filenames[index].split('/', -1)
# image_name = image_name[len(image_name)-1]
# return image1, image2, label, image_name
# def __len__(self):
# return len(self.image1_filenames)
class trainImageAug(object):
def __init__(self, crop = True, augment = True, angle = 30):
        self.crop = crop
self.augment = augment
self.angle = angle
def __call__(self, image1, image2, mask):
if self.crop:
w = np.random.randint(0,256)
h = np.random.randint(0,256)
box = (w, h, w+256, h+256)
image1 = image1.crop(box)
image2 = image2.crop(box)
mask = mask.crop(box)
if self.augment:
prop = np.random.uniform(0, 1)
if prop < 0.15:
image1 = image1.transpose(Image.FLIP_LEFT_RIGHT)
image2 = image2.transpose(Image.FLIP_LEFT_RIGHT)
mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
elif prop < 0.3:
image1 = image1.transpose(Image.FLIP_TOP_BOTTOM)
image2 = image2.transpose(Image.FLIP_TOP_BOTTOM)
mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
            elif prop < 0.5:
                # Sample one rotation angle and apply it to both images and the mask
                # so the pair and its label stay aligned.
                angle = transforms.RandomRotation.get_params([-self.angle, self.angle])
                image1 = image1.rotate(angle)
                image2 = image2.rotate(angle)
                mask = mask.rotate(angle)
return image1, image2, mask
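# Quick sketch of the joint augmentation: trainImageAug applies the same random crop,
# flip, and rotation to both image dates and the mask so they stay pixel-aligned.
def example_joint_augment():
    aug = trainImageAug(crop=True, augment=True, angle=30)
    a = Image.new("RGB", (512, 512))
    b = Image.new("RGB", (512, 512))
    m = Image.new("L", (512, 512))
    a, b, m = aug(a, b, m)
    return a.size, b.size, m.size   # each is (256, 256) after the random crop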
class DynamicPatchDataset(Dataset):
def __init__(self, coords_csv, img1_dir, img2_dir, label_dir,
crop=True, augment=True, angle=30, num_classes=2):
self.coords = pd.read_csv(coords_csv)
self.img1_dir = img1_dir
self.img2_dir = img2_dir
self.label_dir = label_dir
self.num_classes = num_classes
        # Same settings as the original DA_DatasetFromFolder
self.data_augment = trainImageAug(crop=crop, augment=augment, angle=angle)
self.img_transform = get_transform(convert=True, normalize=True)
self.lab_transform = get_transform()
        self.open_images = {}  # cache of open rasterio datasets
def _load_image(self, path):
if path not in self.open_images:
self.open_images[path] = rasterio.open(path)
return self.open_images[path]
def __len__(self):
return len(self.coords)
def __getitem__(self, idx):
row = self.coords.iloc[idx]
fname = row['orig_name']
x, y, h, w = int(row['x']), int(row['y']), int(row['h']), int(row['w'])
        # Open the large images and read the patch window
img1 = self._load_image(os.path.join(self.img1_dir, fname))
img2 = self._load_image(os.path.join(self.img2_dir, fname))
label = self._load_image(os.path.join(self.label_dir, fname))
window = Window(x, y, w, h)
patch1 = img1.read(window=window) # [C, H, W]
patch2 = img2.read(window=window)
patch_label = label.read(1, window=window) # [H, W]
        # Convert to PIL Image for compatibility with the existing augmentations
patch1 = Image.fromarray(np.transpose(patch1, (1, 2, 0)).astype(np.uint8)) # [H, W, C]
patch2 = Image.fromarray(np.transpose(patch2, (1, 2, 0)).astype(np.uint8))
patch_label = Image.fromarray(patch_label.astype(np.uint8))
        # Data augmentation (random crop, flip, rotation)
patch1, patch2, patch_label = self.data_augment(patch1, patch2, patch_label)
        # Image transforms (convert to tensor, normalize)
patch1 = self.img_transform(patch1)
patch2 = self.img_transform(patch2)
patch_label = self.lab_transform(patch_label)
# one-hot label
patch_label = make_one_hot(patch_label.unsqueeze(0).long(), self.num_classes).squeeze(0)
return patch1, patch2, patch_label
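# Hedged sketch of a training loader built on DynamicPatchDataset; the CSV path and
# directories are placeholders.
def example_build_train_loader():
    from torch.utils.data import DataLoader
    train_set = DynamicPatchDataset(
        coords_csv="path/to/train_coords.csv",
        img1_dir="path/to/Time1",
        img2_dir="path/to/Time2",
        label_dir="path/to/Label",
        crop=True, augment=True, angle=30, num_classes=2,
    )
    # With crop=True each batch yields (img1, img2, label): images normalized to [-1, 1]
    # and a one-hot label of shape [B, num_classes, 256, 256].
    return DataLoader(train_set, batch_size=16, shuffle=True, num_workers=0)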
# class DA_DatasetFromFolder(Dataset):
# def __init__(self, Image_dir1, Image_dir2, Label_dir, crop=True, augment = True, angle = 30):
# super(DA_DatasetFromFolder, self).__init__()
# # Build the list of image filenames
# datalist = os.listdir(Image_dir1)
# self.image_filenames1 = [join(Image_dir1, x) for x in datalist if is_image_file(x)]
# self.image_filenames2 = [join(Image_dir2, x) for x in datalist if is_image_file(x)]
# self.label_filenames = [join(Label_dir, x) for x in datalist if is_image_file(x)]
# self.data_augment = trainImageAug(crop=crop, augment = augment, angle=angle)
# self.img_transform = get_transform(convert=True, normalize=True)
# self.lab_transform = get_transform()
# def __getitem__(self, index):
# image1 = Image.open(self.image_filenames1[index]).convert('RGB')
# image2 = Image.open(self.image_filenames2[index]).convert('RGB')
# label = Image.open(self.label_filenames[index])
# image1, image2, label = self.data_augment(image1, image2, label)
# image1, image2 = self.img_transform(image1), self.img_transform(image2)
# label = self.lab_transform(label)
# label = make_one_hot(label.unsqueeze(0).long(), 2).squeeze(0)
# return image1, image2, label
# def __len__(self):
# return len(self.image_filenames1)
class TestDatasetFromPic(Dataset):
def __init__(self, img1_path, img2_path, label_path=None, transform=None,
patch_size=256, stride=256, pad_to_patch=False):
"""
        Sliding-window test dataset generated dynamically from a single large image pair.
        Args:
            img1_path (str): full path to the first-date image
            img2_path (str): full path to the second-date image
            label_path (str, optional): path to the label (may be omitted at test time)
            transform (callable, optional): image transform (e.g. normalization)
            patch_size (int): size of the cropped patches, default 256
            stride (int): sliding-window stride, default 256 (no overlap)
            pad_to_patch (bool): whether to pad edge patches smaller than patch_size up to a full patch
"""
self.img1_path = img1_path
self.img2_path = img2_path
self.label_path = label_path
self.transform = transform or transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
self.patch_size = patch_size
self.stride = stride
self.pad_to_patch = pad_to_patch
self.label_transform = transforms.ToTensor()
        # File names (used when naming saved results)
        self.img1_name = os.path.basename(img1_path)
        self.img2_name = os.path.basename(img2_path)
        # Dynamically generate the list of patch coordinates
        self.coords = self._generate_coords()
    def _generate_coords(self):
        """Generate sliding-window coordinates for a single image."""
        coords = []
        # Get the image size
        with Image.open(self.img1_path) as img:
            width, height = img.size
        # Number of patches needed (accounting for edge padding)
        if self.pad_to_patch:
            pad_w = (self.patch_size - width % self.patch_size) % self.patch_size
            pad_h = (self.patch_size - height % self.patch_size) % self.patch_size
            n_patches_w = (width + pad_w) // self.stride
            n_patches_h = (height + pad_h) // self.stride
        else:
            n_patches_w = (width - self.patch_size) // self.stride + 1
            n_patches_h = (height - self.patch_size) // self.stride + 1
        # Generate the coordinates of every patch
        for i in range(n_patches_w):
            for j in range(n_patches_h):
                x = i * self.stride
                y = j * self.stride
                # Without padding, skip windows that would run past the image border
                if not self.pad_to_patch:
                    if x + self.patch_size > width or y + self.patch_size > height:
                        continue
                coords.append({
                    'orig_name': self.img1_name,  # keep the key used by the coordinate CSVs
                    'x': x,
                    'y': y,
                    'h': self.patch_size,
                    'w': self.patch_size
                })
        return coords
def __len__(self):
return len(self.coords)
    def __getitem__(self, idx):
        """Return the data for a single patch."""
        row = self.coords[idx]
        x, y, h, w = row['x'], row['y'], row['h'], row['w']
        # Load the images
        img1 = Image.open(self.img1_path).convert('RGB')
        img2 = Image.open(self.img2_path).convert('RGB')
        # Label (if available)
        if self.label_path:
            label = Image.open(self.label_path).convert('L')
            label = label.crop((x, y, x + w, y + h))
            label = self.label_transform(label)
        else:
            label = torch.zeros((1, h, w))  # dummy label
        # Crop the images
        img1 = img1.crop((x, y, x + w, y + h))
        img2 = img2.crop((x, y, x + w, y + h))
        # Apply the transform
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        # Pad to a uniform size (if enabled)
        if self.pad_to_patch:
            pad_h = self.patch_size - h
            pad_w = self.patch_size - w
            if pad_h > 0 or pad_w > 0:
                img1 = F.pad(img1, (0, pad_w, 0, pad_h))
                img2 = F.pad(img2, (0, pad_w, 0, pad_h))
                label = F.pad(label, (0, pad_w, 0, pad_h))
        # Return values (same layout as the coordinate-based test datasets)
        patch_name = f"{os.path.splitext(self.img1_name)[0]}_{x}_{y}.png"
        return img1, img2, label, patch_name, self.img1_name, x, y, h, w
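# Hedged sketch of whole-image inference with TestDatasetFromPic: run a model on every
# patch and paste the prediction back at (x, y). `model` is a hypothetical change-detection
# network assumed to take (img1, img2) and return [B, 2, H, W] logits.
def example_sliding_window_inference(model, img1_path, img2_path, device="cpu"):
    from torch.utils.data import DataLoader
    dataset = TestDatasetFromPic(img1_path, img2_path, patch_size=256, stride=256, pad_to_patch=True)
    loader = DataLoader(dataset, batch_size=1, shuffle=False)
    with Image.open(img1_path) as im:
        width, height = im.size
    full_pred = np.zeros((height, width), dtype=np.uint8)
    with torch.no_grad():
        for img1, img2, _, _, _, x, y, h, w in loader:
            x, y, h, w = int(x), int(y), int(h), int(w)
            out = model(img1.to(device), img2.to(device))
            pred = out.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)
            h_eff = min(h, height - y)   # clip patches that overhang the image border
            w_eff = min(w, width - x)
            full_pred[y:y + h_eff, x:x + w_eff] = pred[:h_eff, :w_eff]
    return full_pred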
if __name__ == "__main__":
    # Minimal smoke test with placeholder paths: point these at a real coordinate CSV
    # and image directories before running.
    dataset = TestDatasetFromCoords(
        coords_csv="path/to/coords.csv",
        hr1_dir="path/to/Time1",
        hr2_dir="path/to/Time2",
        label_dir="path/to/Label",
        transform=get_transform(convert=True, normalize=True)
    )
    print(f"Dataset contains {len(dataset)} patches")