#coding=utf-8
from os.path import join
import torch
from PIL import Image, ImageEnhance
from torch.utils.data.dataset import Dataset
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize
import numpy as np
import torchvision.transforms as transforms
import os
from argparse import Namespace
import imageio
import pandas as pd
import rasterio
from rasterio.windows import Window
import torch.nn.functional as F

Image.MAX_IMAGE_PIXELS = None  # disable Pillow's pixel-count limit warning

# File-name suffixes accepted by is_image_file (the match is case-sensitive).
_IMAGE_EXTENSIONS = ('.tif', '.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG')


def is_image_file(filename):
    """Return True when *filename* ends with one of the known image suffixes."""
    # str.endswith accepts a tuple, so one call replaces the per-suffix loop.
    return filename.endswith(_IMAGE_EXTENSIONS)


def calMetric_iou(predict, label):
    """Return the (intersection, union) pixel counts of class 1.

    Args:
        predict: array of predicted class indices.
        label: array of ground-truth class indices, compared element-wise
            with ``predict``.

    Returns:
        A tuple ``(intersection, union)`` where ``intersection`` counts the
        positions equal to 1 in both arrays and ``union`` counts the
        positions equal to 1 in at least one of them.
    """
    intersection = np.sum(np.logical_and(predict == 1, label == 1))
    predicted_positive = np.sum(predict == 1)
    labelled_positive = np.sum(label == 1)
    # |A or B| = |A| + |B| - |A and B|
    return intersection, predicted_positive + labelled_positive - intersection


def getDataList(img_path):
    """Read a text file and return its lines without the trailing newline.

    Args:
        img_path: path of the list file (one entry per line).

    Returns:
        A list with one string per line of the file.
    """
    # The context manager closes the handle; the previous version left it open.
    with open(img_path, 'r') as list_file:
        return [line.strip('\n') for line in list_file]


def make_one_hot(input, num_classes):
    """Convert class index tensor to one hot encoding tensor.

    Args:
        input: A tensor of shape [N, 1, *]
        num_classes: An int of number of class
    Returns:
        A tensor of shape [N, num_classes, *]
    """
    shape = list(input.shape)
    shape[1] = num_classes
    result = torch.zeros(shape)
    # The result is allocated on the CPU, so the index tensor is moved there
    # before a 1 is scattered along the class dimension.
    return result.scatter_(1, input.cpu(), 1)


def get_transform(convert=True, normalize=False):
    """Build the torchvision transform pipeline used by the datasets.

    Args:
        convert: when true, include ``transforms.ToTensor()``.
        normalize: when true, include a 3-channel ``Normalize`` with mean 0.5
            and std 0.5.

    Returns:
        A ``transforms.Compose`` of the selected steps.
    """
    transform_list = []
    if convert:
        transform_list += [transforms.ToTensor()]
    if normalize:
        transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)


# class LoadDatasetFromFolder(Dataset):
#     def __init__(self, args, hr1_path, hr2_path, lab_path):
#         super(LoadDatasetFromFolder, self).__init__()
#         # collect the image list
#         datalist = [name for name in os.listdir(hr1_path) for item in args.suffix if
#                     os.path.splitext(name)[1] == item]
#         self.hr1_filenames = [join(hr1_path, x) for x in datalist if is_image_file(x)]
#         self.hr2_filenames = [join(hr2_path, x) for x in datalist if is_image_file(x)]
#         self.lab_filenames = [join(lab_path, x) for x in datalist if is_image_file(x)]
#         self.transform = get_transform(convert=True, normalize=True)  # convert to tensor and normalize to [-1,1]
#         self.label_transform = get_transform()  # only convert to tensor
#
#     def __getitem__(self, index):
#         hr1_img = self.transform(Image.open(self.hr1_filenames[index]).convert('RGB'))
#         # lr2_img = self.transform(Image.open(self.lr2_filenames[index]).convert('RGB'))
#         hr2_img = self.transform(Image.open(self.hr2_filenames[index]).convert('RGB'))
#         label = self.label_transform(Image.open(self.lab_filenames[index]))
#         label = make_one_hot(label.unsqueeze(0).long(), 2).squeeze(0)
#         return hr1_img, hr2_img, label
#
#     def __len__(self):
#         return len(self.hr1_filenames)


# Validation
class LoadDatasetFromCoords(Dataset):
    """Patch dataset that reads windows from large rasters listed in a CSV.

    Each CSV row must provide ``orig_name`` (file name shared by the two image
    directories and the label directory) and the window ``x``, ``y``, ``h``,
    ``w``.

    Args:
        coords_csv: path of the CSV file with the patch coordinates.
        hr1_dir: directory of the first image of each pair.
        hr2_dir: directory of the second image of each pair.
        label_dir: directory of the label rasters.
        num_classes: number of classes of the one-hot label.
    """

    def __init__(self, coords_csv, hr1_dir, hr2_dir, label_dir, num_classes=2):
        super(LoadDatasetFromCoords, self).__init__()
        self.coords = pd.read_csv(coords_csv)
        self.hr1_dir = hr1_dir
        self.hr2_dir = hr2_dir
        self.label_dir = label_dir
        self.num_classes = num_classes
        self.transform = get_transform(convert=True, normalize=True)
        self.label_transform = get_transform()
        # path -> opened rasterio dataset; handles are kept open for reuse
        self.open_images = {}

    def _load_image(self, path):
        """Return the rasterio dataset for *path*, opening it on first use."""
        if path not in self.open_images:
            self.open_images[path] = rasterio.open(path)
        return self.open_images[path]

    def __len__(self):
        """Return the number of patches listed in the CSV."""
        return len(self.coords)

    def __getitem__(self, idx):
        """Return ``(patch1, patch2, one_hot_label)`` for CSV row *idx*."""
        row = self.coords.iloc[idx]
        fname = row['orig_name']
        x, y, h, w = int(row['x']), int(row['y']), int(row['h']), int(row['w'])

        # Open (or reuse) the three source rasters.
        img1 = self._load_image(os.path.join(self.hr1_dir, fname))
        img2 = self._load_image(os.path.join(self.hr2_dir, fname))
        label = self._load_image(os.path.join(self.label_dir, fname))

        # rasterio windows are given as (col_off, row_off, width, height).
        window = Window(x, y, w, h)
        patch1 = img1.read(window=window)  # [C, H, W]
        patch2 = img2.read(window=window)
        patch_label = label.read(1, window=window)  # [H, W]

        # Convert to PIL.Image so the torchvision transforms can be applied.
        patch1 = Image.fromarray(np.transpose(patch1, (1, 2, 0)).astype(np.uint8))
        patch2 = Image.fromarray(np.transpose(patch2, (1, 2, 0)).astype(np.uint8))
        patch_label = Image.fromarray(patch_label.astype(np.uint8))

        patch1 = self.transform(patch1)
        patch2 = self.transform(patch2)
        # NOTE(review): ToTensor is expected to rescale 8-bit images to [0, 1],
        # so after .long() only pixels stored as 255 would map to class 1 -
        # confirm the label rasters are encoded as 0/255.
        patch_label = self.label_transform(patch_label)
        patch_label = make_one_hot(patch_label.unsqueeze(0).long(), self.num_classes).squeeze(0)

        return patch1, patch2, patch_label


# Testing
# class TestDatasetFromCoords(Dataset):
#     def __init__(self, coords_csv, hr1_dir, hr2_dir, label_dir=None, transform=None):
#         self.coords = pd.read_csv(coords_csv)
#         self.hr1_dir = hr1_dir
#         self.hr2_dir = hr2_dir
#         self.label_dir = label_dir
#         self.transform = transform
#         self.label_transform = transforms.ToTensor()
#
#     def __len__(self):
#         return len(self.coords)
#
#     def __getitem__(self, idx):
#         row = self.coords.iloc[idx]
#         x, y, h, w = row['x'], row['y'], row['h'], row['w']
#         fname = row['orig_name']
#         img1 = Image.open(os.path.join(self.hr1_dir, fname)).convert('RGB')
#         img2 = Image.open(os.path.join(self.hr2_dir, fname)).convert('RGB')
#         img1 = img1.crop((x, y, x +
#                           w, y + h))
#         img2 = img2.crop((x, y, x + w, y + h))
#         if self.label_dir:
#             label = Image.open(os.path.join(self.label_dir, fname)).convert('L')
#             label = label.crop((x, y, x + w, y + h))
#             label = self.label_transform(label)
#         else:
#             label = torch.zeros((1, h, w))  # dummy label
#         if self.transform:
#             img1 = self.transform(img1)
#             img2 = self.transform(img2)
#         patch_name = f"{os.path.splitext(fname)[0]}_{x}_{y}"
#         return img1, img2, label, patch_name, fname, x, y


class TestDatasetFromCoords(Dataset):
    """Test-time patch dataset driven by a CSV of crop coordinates.

    Each CSV row must provide ``orig_name`` and the crop ``x``, ``y``, ``h``,
    ``w``. Patches smaller than 256x256 are zero-padded on the right and
    bottom; the original ``h`` and ``w`` are returned with every sample.

    Args:
        coords_csv: path of the CSV file with the patch coordinates.
        hr1_dir: directory of the first image of each pair.
        hr2_dir: directory of the second image of each pair.
        label_dir: optional directory of the label images; when falsy an
            all-zero dummy label is returned.
        transform: optional callable applied to both image patches. It has to
            return tensors whenever a patch needs padding, because the padding
            is done with ``torch.nn.functional.pad``.
    """

    def __init__(self, coords_csv, hr1_dir, hr2_dir, label_dir=None, transform=None):
        self.coords = pd.read_csv(coords_csv)
        self.hr1_dir = hr1_dir
        self.hr2_dir = hr2_dir
        self.label_dir = label_dir
        self.transform = transform
        self.label_transform = transforms.ToTensor()

    def __len__(self):
        """Return the number of patches listed in the CSV."""
        return len(self.coords)

    @staticmethod
    def _read_patch(path, mode, box):
        """Open *path*, convert it to *mode* and return the crop *box*.

        The file handle is closed before returning; the converted image is an
        independent in-memory copy.
        """
        with Image.open(path) as source:
            return source.convert(mode).crop(box)

    @staticmethod
    def _pad_bottom_right(tensor, h, w, size=256):
        """Zero-pad *tensor* on the right/bottom from ``h x w`` up to *size*.

        Pad amounts are clamped at zero: a negative amount handed to
        ``F.pad`` would crop the tensor instead of leaving it alone.
        """
        pad_h = max(size - h, 0)
        pad_w = max(size - w, 0)
        if pad_h == 0 and pad_w == 0:
            return tensor
        return F.pad(tensor, (0, pad_w, 0, pad_h))  # (left, right, top, bottom)

    def __getitem__(self, idx):
        """Return ``(img1, img2, label, patch_name, fname, x, y, h, w)``."""
        row = self.coords.iloc[idx]
        x, y, h, w = int(row['x']), int(row['y']), int(row['h']), int(row['w'])
        fname = row['orig_name']
        box = (x, y, x + w, y + h)

        # Read both images and crop the patch.
        img1 = self._read_patch(os.path.join(self.hr1_dir, fname), 'RGB', box)
        img2 = self._read_patch(os.path.join(self.hr2_dir, fname), 'RGB', box)

        # Label
        if self.label_dir:
            label = self._read_patch(os.path.join(self.label_dir, fname), 'L', box)
            label = self.label_transform(label)
        else:
            label = torch.zeros((1, h, w))  # dummy label

        # Image transform
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        # Pad to 256x256; the original h and w are returned to the caller.
        img1 = self._pad_bottom_right(img1, h, w)
        img2 = self._pad_bottom_right(img2, h, w)
        label = self._pad_bottom_right(label, h, w)

        patch_name = f"{os.path.splitext(fname)[0]}_{x}_{y}"
        return img1, img2, label, patch_name, fname, x, y, h, w  # original size


class TestDatasetFromImagePath(Dataset):
    """Single-sample test dataset built from one pair of image files.

    The previous implementation returned ``None`` from ``__len__`` and used
    undefined crop variables in ``__getitem__``, so it could not be used. It
    now yields exactly one sample: the full image pair with an all-zero dummy
    label, in the same 9-tuple layout as ``TestDatasetFromCoords``.

    Args:
        img1_path: path of the first image.
        img2_path: path of the second image.
        transform: optional callable applied to both images. It has to return
            tensors whenever the image is smaller than 256x256, because the
            padding is done with ``torch.nn.functional.pad``.
    """

    def __init__(self, img1_path, img2_path, transform=None):
        self.img1_path = img1_path
        self.img2_path = img2_path
        self.transform = transform
        self.label_transform = transforms.ToTensor()

    def __len__(self):
        """Return 1: the dataset holds a single image pair."""
        return 1

    def __getitem__(self, idx):
        """Return ``(img1, img2, label, patch_name, fname, x, y, h, w)``.

        Raises:
            IndexError: if *idx* is neither 0 nor -1.
        """
        if idx not in (0, -1):
            raise IndexError(f"index {idx} is out of range for a single-sample dataset")

        # Load the whole first image; it defines the patch extent.
        with Image.open(self.img1_path) as source:
            img1 = source.convert('RGB')
        x, y = 0, 0
        w, h = img1.size
        # Crop the second image to the same box so both share one size.
        with Image.open(self.img2_path) as source:
            img2 = source.convert('RGB').crop((x, y, x + w, y + h))

        # No label source is configured for this dataset.
        label = torch.zeros((1, h, w))  # dummy label

        # Image transform
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        # Pad to 256x256; the original h and w are returned to the caller.
        pad = TestDatasetFromCoords._pad_bottom_right
        img1 = pad(img1, h, w)
        img2 = pad(img2, h, w)
        label = pad(label, h, w)

        fname = os.path.basename(self.img1_path)
        patch_name = f"{os.path.splitext(fname)[0]}_{x}_{y}"
        return img1, img2, label, patch_name, fname, x, y, h, w  # original size


# class TestDatasetFromFolder(Dataset):
#     def __init__(self, args, Time1_dir, Time2_dir, Label_dir):
#         super(TestDatasetFromFolder, self).__init__()
#         # make sure args.suffix exists, otherwise fall back to the defaults
#         suffixes = getattr(args, 'suffix', ['.png', '.jpg'])
#         # keep only the files with a matching suffix
#         datalist = [
#             name for name in os.listdir(Time1_dir)
#             if os.path.splitext(name)[1].lower() in suffixes
#         ]
#         print(f"found {len(datalist)} files: {datalist[:5]}...")  # debug output
#         self.image1_filenames = [join(Time1_dir, x) for x in datalist if is_image_file(x)]
#         self.image2_filenames = [join(Time2_dir, x) for x in datalist if is_image_file(x)]
#         self.image3_filenames = [join(Label_dir, x) for x in datalist if is_image_file(x)]
#         self.transform = get_transform(convert=True, normalize=True)  # convert
#         # to tensor and normalize to [-1,1]
#         self.label_transform = get_transform()
#
#     def __getitem__(self, index):
#         image1 = self.transform(Image.open(self.image1_filenames[index]).convert('RGB'))
#         image2 = self.transform(Image.open(self.image2_filenames[index]).convert('RGB'))
#         label = self.label_transform(Image.open(self.image3_filenames[index]))
#         label = make_one_hot(label.unsqueeze(0).long(), 2).squeeze(0)
#         image_name = self.image1_filenames[index].split('/', -1)
#         image_name = image_name[len(image_name)-1]
#         return image1, image2, label, image_name
#
#     def __len__(self):
#         return len(self.image1_filenames)


class trainImageAug(object):
    """Joint random crop / flip / rotation for an image pair and its mask.

    Args:
        crop: when true, cut the same random 256x256 box out of all three
            inputs.
        augment: when true, randomly apply one flip or a rotation.
        angle: maximum absolute rotation angle in degrees.
    """

    def __init__(self, crop = True, augment = True, angle = 30):
        self.crop = crop
        self.augment = augment
        self.angle = angle

    def __call__(self, image1, image2, mask):
        """Apply the same geometric transform to *image1*, *image2*, *mask*.

        Returns:
            The tuple ``(image1, image2, mask)`` after augmentation.
        """
        if self.crop:
            # Offsets are drawn from [0, 256) and the box is 256x256.
            # NOTE(review): this looks like it assumes inputs of at least
            # 512x512 - confirm against the patch sizes actually used.
            w = np.random.randint(0, 256)
            h = np.random.randint(0, 256)
            box = (w, h, w + 256, h + 256)
            image1 = image1.crop(box)
            image2 = image2.crop(box)
            mask = mask.crop(box)
        if self.augment:
            prop = np.random.uniform(0, 1)
            if prop < 0.15:
                image1 = image1.transpose(Image.FLIP_LEFT_RIGHT)
                image2 = image2.transpose(Image.FLIP_LEFT_RIGHT)
                mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
            elif prop < 0.3:
                image1 = image1.transpose(Image.FLIP_TOP_BOTTOM)
                image2 = image2.transpose(Image.FLIP_TOP_BOTTOM)
                mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
            elif prop < 0.5:
                # Draw ONE angle and reuse it: sampling a separate angle for
                # each input would rotate the pair and the mask differently
                # and break their pixel alignment.
                rotation = transforms.RandomRotation.get_params([-self.angle, self.angle])
                image1 = image1.rotate(rotation)
                image2 = image2.rotate(rotation)
                mask = mask.rotate(rotation)

        return image1, image2, mask


def get_transform(convert=True, normalize=False):
    """Build the torchvision transform pipeline used by the datasets.

    Args:
        convert: when true, include ``transforms.ToTensor()``.
        normalize: when true, include a 3-channel ``Normalize`` with mean 0.5
            and std 0.5.

    Returns:
        A ``transforms.Compose`` of the selected steps.
    """
    transform_list = []
    if convert:
        transform_list += [
            transforms.ToTensor(),
        ]
    if normalize:
        transform_list += [
            # transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)


class DynamicPatchDataset(Dataset):
    """Training patch dataset: windows read from large rasters, then augmented.

    Each CSV row must provide ``orig_name`` (file name shared by the two image
    directories and the label directory) and the window ``x``, ``y``, ``h``,
    ``w``.

    Args:
        coords_csv: path of the CSV file with the patch coordinates.
        img1_dir: directory of the first image of each pair.
        img2_dir: directory of the second image of each pair.
        label_dir: directory of the label rasters.
        crop: forwarded to ``trainImageAug``.
        augment: forwarded to ``trainImageAug``.
        angle: forwarded to ``trainImageAug``.
        num_classes: number of classes of the one-hot label.
    """

    def __init__(self, coords_csv, img1_dir, img2_dir, label_dir,
                 crop=True, augment=True, angle=30, num_classes=2):
        self.coords = pd.read_csv(coords_csv)
        self.img1_dir = img1_dir
        self.img2_dir = img2_dir
        self.label_dir = label_dir
        self.num_classes = num_classes

        # Same settings as the original DA_DatasetFromFolder.
        self.data_augment = trainImageAug(crop=crop, augment=augment, angle=angle)
        self.img_transform = get_transform(convert=True, normalize=True)
        self.lab_transform = get_transform()

        self.open_images = {}  # path -> opened rasterio dataset (cache)

    def _load_image(self, path):
        """Return the rasterio dataset for *path*, opening it on first use."""
        if path not in self.open_images:
            self.open_images[path] = rasterio.open(path)
        return self.open_images[path]

    def __len__(self):
        """Return the number of patches listed in the CSV."""
        return len(self.coords)

    def __getitem__(self, idx):
        """Return ``(patch1, patch2, one_hot_label)`` for CSV row *idx*."""
        row = self.coords.iloc[idx]
        fname = row['orig_name']
        x, y, h, w = int(row['x']), int(row['y']), int(row['h']), int(row['w'])

        # Open the large rasters and read the patch window.
        img1 = self._load_image(os.path.join(self.img1_dir, fname))
        img2 = self._load_image(os.path.join(self.img2_dir, fname))
        label = self._load_image(os.path.join(self.label_dir, fname))

        # rasterio windows are given as (col_off, row_off, width, height).
        window = Window(x, y, w, h)
        patch1 = img1.read(window=window)  # [C, H, W]
        patch2 = img2.read(window=window)
        patch_label = label.read(1, window=window)  # [H, W]

        # Convert to PIL images so the existing augmentation can be reused.
        patch1 = Image.fromarray(np.transpose(patch1, (1, 2, 0)).astype(np.uint8))  # [H, W, C]
        patch2 = Image.fromarray(np.transpose(patch2, (1, 2, 0)).astype(np.uint8))
        patch_label = Image.fromarray(patch_label.astype(np.uint8))

        # Data augmentation (crop, flip, rotation).
        patch1, patch2, patch_label = self.data_augment(patch1, patch2, patch_label)

        # Image transform (to tensor, normalize).
        patch1 = self.img_transform(patch1)
        patch2 = self.img_transform(patch2)
        patch_label = self.lab_transform(patch_label)

        # one-hot label
        patch_label = make_one_hot(patch_label.unsqueeze(0).long(), self.num_classes).squeeze(0)

        return patch1, patch2, patch_label


# class DA_DatasetFromFolder(Dataset):
#     def __init__(self, Image_dir1, Image_dir2, Label_dir, crop=True, augment = True, angle = 30):
#         super(DA_DatasetFromFolder, self).__init__()
#         # collect the image list
#         datalist = os.listdir(Image_dir1)
#         self.image_filenames1 = [join(Image_dir1, x) for x in datalist if is_image_file(x)]
#         self.image_filenames2 = [join(Image_dir2, x) for x in datalist if is_image_file(x)]
#         self.label_filenames = [join(Label_dir, x) for x in datalist if is_image_file(x)]
#         self.data_augment = trainImageAug(crop=crop, augment = augment, angle=angle)
#         self.img_transform = get_transform(convert=True, normalize=True)
#         self.lab_transform = get_transform()
#
#     def __getitem__(self, index):
#         image1 = Image.open(self.image_filenames1[index]).convert('RGB')
#         image2 = Image.open(self.image_filenames2[index]).convert('RGB')
#         label = Image.open(self.label_filenames[index])
#         image1, image2, label = self.data_augment(image1, image2, label)
#         image1, image2 = self.img_transform(image1), self.img_transform(image2)
#         label = self.lab_transform(label)
#         label = make_one_hot(label.unsqueeze(0).long(), 2).squeeze(0)
#         return image1, image2, label
#
#     def __len__(self):
#         return len(self.image_filenames1)


class TestDatasetFromPic(Dataset):
    """Sliding-window test dataset over a single large image pair."""

    def __init__(self, img1_path, img2_path, label_path=None, transform=None,
                 patch_size=256, stride=256, pad_to_patch=False):
        """
        Build the sliding-window patch list for one image pair.

        Args:
            img1_path (str): full path of the first-date image
            img2_path (str): full path of the second-date image
            label_path (str, optional): label path (may be absent at test time)
            transform (callable, optional): image transform (e.g. normalization);
                defaults to ToTensor followed by Normalize(0.5, 0.5)
            patch_size (int): size of the cropped patch (default 256)
            stride (int): sliding step (default 256, no overlap)
            pad_to_patch (bool): whether image borders smaller than
                patch_size are padded up to a full patch
        """
        self.img1_path = img1_path
        self.img2_path = img2_path
        self.label_path = label_path
        self.transform = transform or transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        self.patch_size = patch_size
        self.stride = stride
        self.pad_to_patch = pad_to_patch
        self.label_transform = transforms.ToTensor()

        # File names, used when naming saved results.
        self.img1_name = os.path.basename(img1_path)
        self.img2_name = os.path.basename(img2_path)

        # Build the coordinate list up front.
        self.coords = self._generate_coords()

    @staticmethod
    def _axis_offsets(length, patch_size, stride, pad_to_patch):
        """Return the window start offsets along one image axis.

        Without padding only windows lying fully inside ``length`` are kept.
        With padding the axis is extended to the next multiple of
        ``patch_size``; offsets at or beyond ``length`` are dropped because
        such a window would contain no image pixel at all (this happened
        whenever ``stride`` was smaller than ``patch_size``).
        """
        if not pad_to_patch:
            return list(range(0, length - patch_size + 1, stride))
        padded_length = -(-length // patch_size) * patch_size  # round up
        count = padded_length // stride
        return [i * stride for i in range(count) if i * stride < length]

    def _generate_coords(self):
        """Generate the sliding-window coordinates for the single image."""
        # Read the image size; the handle is closed right away.
        with Image.open(self.img1_path) as img:
            width, height = img.size

        xs = self._axis_offsets(width, self.patch_size, self.stride, self.pad_to_patch)
        ys = self._axis_offsets(height, self.patch_size, self.stride, self.pad_to_patch)

        # Column-major order: every y for the first x, then the next x, ...
        return [
            {
                'orig_name': self.img1_name,  # kept for compatibility with the CSV-based datasets
                'x': x,
                'y': y,
                'h': self.patch_size,
                'w': self.patch_size
            }
            for x in xs
            for y in ys
        ]

    def __len__(self):
        """Return the number of generated patches."""
        return len(self.coords)

    def __getitem__(self, idx):
        """Return ``(img1, img2, label, patch_name, img1_name, x, y, h, w)``."""
        row = self.coords[idx]
        x, y, h, w = row['x'], row['y'], row['h'], row['w']
        box = (x, y, x + w, y + h)

        # Load and crop the images; the context managers release the file
        # handles that were previously left open on every call.
        # NOTE(review): the full image is decoded for every patch; caching the
        # converted images would be faster if memory allows.
        with Image.open(self.img1_path) as source:
            img1 = source.convert('RGB').crop(box)
        with Image.open(self.img2_path) as source:
            img2 = source.convert('RGB').crop(box)

        # Label (if available)
        if self.label_path:
            with Image.open(self.label_path) as source:
                label = self.label_transform(source.convert('L').crop(box))
        else:
            label = torch.zeros((1, h, w))  # dummy label

        # Apply the transform
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        # Pad to a uniform size (if enabled). h and w always equal patch_size
        # for the coordinates generated above, so this is normally a no-op;
        # the amounts are clamped so a negative value can never crop.
        if self.pad_to_patch:
            pad_h = max(self.patch_size - h, 0)
            pad_w = max(self.patch_size - w, 0)
            if pad_h > 0 or pad_w > 0:
                img1 = F.pad(img1, (0, pad_w, 0, pad_h))
                img2 = F.pad(img2, (0, pad_w, 0, pad_h))
                label = F.pad(label, (0, pad_w, 0, pad_h))

        # Return layout kept compatible with TestDatasetFromCoords.
        patch_name = f"{os.path.splitext(self.img1_name)[0]}_{x}_{y}.png"
        return img1, img2, label, patch_name, self.img1_name, x, y, h, w


if __name__ == "__main__":
    # Smoke test with placeholder paths. The previous version instantiated
    # TestDatasetFromFolder, which only exists as commented-out code and
    # therefore raised NameError.
    dataset = TestDatasetFromPic(
        img1_path="path/to/Time1.png",
        img2_path="path/to/Time2.png",
        label_path="path/to/Label.png",
    )
    print(f"{len(dataset)} patches")