580 lines
22 KiB
Python
580 lines
22 KiB
Python
#coding=utf-8
|
||
from os.path import join
|
||
import torch
|
||
from PIL import Image, ImageEnhance
|
||
from torch.utils.data.dataset import Dataset
|
||
|
||
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize
|
||
import numpy as np
|
||
import torchvision.transforms as transforms
|
||
import os
|
||
from argparse import Namespace
|
||
import imageio
|
||
import pandas as pd
|
||
import rasterio
|
||
from rasterio.windows import Window
|
||
import torch.nn.functional as F
|
||
# Disable Pillow's decompression-bomb guard (no maximum pixel count) so very
# large rasters can be opened without DecompressionBombWarning/Error.
Image.MAX_IMAGE_PIXELS = None
|
||
def is_image_file(filename):
    """Return True if *filename* ends with a supported image extension.

    Matching is case-sensitive and limited to the exact spellings below,
    e.g. '.tif' is accepted but '.TIF' and '.tiff' are not.
    """
    accepted = ('.tif', '.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG')
    return filename.endswith(accepted)
|
||
|
||
def calMetric_iou(predict, label):
    """Return the (intersection, union) pixel counts for class 1.

    Args:
        predict: Array of predicted class ids; pixels equal to 1 are positive.
        label: Array of ground-truth class ids with the same shape.

    Returns:
        ``(intersection, union)``: the number of pixels that are 1 in both
        arrays, and the number that are 1 in at least one of them.
        ``intersection / union`` is the IoU of class 1.
    """
    pred_pos = predict == 1
    true_pos = label == 1
    intersection = np.sum(np.logical_and(pred_pos, true_pos))
    union = np.sum(pred_pos) + np.sum(true_pos) - intersection
    return intersection, union
|
||
|
||
|
||
def getDataList(img_path):
    """Read a text file and return its lines without the trailing newline.

    Args:
        img_path: Path of a text file listing one entry per line.

    Returns:
        A list with one string per line, in file order. Blank lines become ''.
    """
    # The context manager closes the handle deterministically instead of
    # leaving it open until garbage collection.
    with open(img_path, 'r') as f:
        return [line.strip('\n') for line in f]
|
||
|
||
|
||
def make_one_hot(input, num_classes):
    """Expand a class-index tensor into a one-hot tensor along dim 1.

    Args:
        input: Integer (int64, as required by ``scatter_``) tensor of shape
            [N, 1, *] holding class indices in ``[0, num_classes)``.
        num_classes: Size of the class dimension in the output.

    Returns:
        A CPU tensor of torch's default float dtype with shape
        [N, num_classes, *]: 1 at each element's class index, 0 elsewhere.
    """
    out_shape = list(input.shape)
    out_shape[1] = num_classes
    one_hot = torch.zeros(tuple(out_shape))
    # The target lives on the CPU, so the index tensor is moved there too.
    return one_hot.scatter_(1, input.cpu(), 1)
|
||
|
||
|
||
def get_transform(convert=True, normalize=False):
    """Build a torchvision ``Compose`` pipeline from optional steps.

    NOTE: ``get_transform`` is defined a second time further down this module
    with the same behaviour. That later definition is what the name is bound
    to once the module has finished importing.

    Args:
        convert: Include ``ToTensor()`` (uint8 input is scaled to [0, 1]).
        normalize: Append ``Normalize`` with mean = std = 0.5 per channel,
            which maps [0, 1] to [-1, 1] for 3-channel input.

    Returns:
        ``transforms.Compose`` over the selected steps (possibly empty).
    """
    steps = [transforms.ToTensor()] if convert else []
    if normalize:
        steps.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    return transforms.Compose(steps)
|
||
|
||
|
||
|
||
# class LoadDatasetFromFolder(Dataset):
|
||
# def __init__(self, args, hr1_path, hr2_path, lab_path):
|
||
# super(LoadDatasetFromFolder, self).__init__()
|
||
# # 获取图片列表
|
||
# datalist = [name for name in os.listdir(hr1_path) for item in args.suffix if
|
||
# os.path.splitext(name)[1] == item]
|
||
|
||
# self.hr1_filenames = [join(hr1_path, x) for x in datalist if is_image_file(x)]
|
||
# self.hr2_filenames = [join(hr2_path, x) for x in datalist if is_image_file(x)]
|
||
# self.lab_filenames = [join(lab_path, x) for x in datalist if is_image_file(x)]
|
||
|
||
# self.transform = get_transform(convert=True, normalize=True) # convert to tensor and normalize to [-1,1]
|
||
# self.label_transform = get_transform() # only convert to tensor
|
||
|
||
# def __getitem__(self, index):
|
||
# hr1_img = self.transform(Image.open(self.hr1_filenames[index]).convert('RGB'))
|
||
# # lr2_img = self.transform(Image.open(self.lr2_filenames[index]).convert('RGB'))
|
||
# hr2_img = self.transform(Image.open(self.hr2_filenames[index]).convert('RGB'))
|
||
|
||
# label = self.label_transform(Image.open(self.lab_filenames[index]))
|
||
# label = make_one_hot(label.unsqueeze(0).long(), 2).squeeze(0)
|
||
|
||
# return hr1_img, hr2_img, label
|
||
|
||
# def __len__(self):
|
||
# return len(self.hr1_filenames)
|
||
# Validation dataset (验证)
|
||
class LoadDatasetFromCoords(Dataset):
    """Bi-temporal dataset that reads coordinate-listed patches via rasterio.

    Each row of ``coords_csv`` names a raster (``orig_name``) and a window:
    ``x``/``y`` are the column/row offsets and ``w``/``h`` the size in pixels.
    The same file name is looked up in ``hr1_dir``, ``hr2_dir`` and
    ``label_dir``, and only the window is read from each raster.

    Each item is ``(img1, img2, label)``:
      * ``img1``/``img2``: tensors normalised with mean = std = 0.5
        (i.e. [-1, 1] for 3-band 8-bit input);
      * ``label``: one-hot float tensor of shape [num_classes, H, W].

    Open rasterio handles are cached per path. Call ``close()`` to release
    them; they are dropped (and lazily reopened) when the dataset is pickled.
    """

    def __init__(self, coords_csv, hr1_dir, hr2_dir, label_dir, num_classes=2):
        super(LoadDatasetFromCoords, self).__init__()
        self.coords = pd.read_csv(coords_csv)
        self.hr1_dir = hr1_dir
        self.hr2_dir = hr2_dir
        self.label_dir = label_dir
        self.num_classes = num_classes

        self.transform = get_transform(convert=True, normalize=True)  # ToTensor + Normalize
        self.label_transform = get_transform()  # ToTensor only
        self.open_images = {}  # path -> open rasterio dataset

    def _load_image(self, path):
        """Return the cached rasterio dataset for *path*, opening it on first use."""
        if path not in self.open_images:
            self.open_images[path] = rasterio.open(path)
        return self.open_images[path]

    def close(self):
        """Close every cached rasterio dataset and empty the cache."""
        for src in self.open_images.values():
            src.close()
        self.open_images.clear()

    def __getstate__(self):
        # Open rasterio handles cannot be pickled (e.g. when a DataLoader with
        # spawn-based workers copies the dataset). Ship an empty cache instead;
        # each copy reopens files lazily in _load_image.
        state = self.__dict__.copy()
        state['open_images'] = {}
        return state

    def __len__(self):
        return len(self.coords)

    def _read_patch(self, fname, x, y, h, w):
        """Read window (x, y, w, h) of *fname* from the three rasters as PIL images."""
        img1 = self._load_image(os.path.join(self.hr1_dir, fname))
        img2 = self._load_image(os.path.join(self.hr2_dir, fname))
        label = self._load_image(os.path.join(self.label_dir, fname))

        window = Window(x, y, w, h)  # (col_off, row_off, width, height)
        patch1 = img1.read(window=window)  # [C, H, W]
        patch2 = img2.read(window=window)
        patch_label = label.read(1, window=window)  # first band only, [H, W]

        # Band-first -> band-last for PIL.
        # NOTE(review): astype(np.uint8) wraps values above 255, so this
        # assumes 8-bit imagery -- confirm for the data in use.
        patch1 = Image.fromarray(np.transpose(patch1, (1, 2, 0)).astype(np.uint8))
        patch2 = Image.fromarray(np.transpose(patch2, (1, 2, 0)).astype(np.uint8))
        patch_label = Image.fromarray(patch_label.astype(np.uint8))
        return patch1, patch2, patch_label

    def __getitem__(self, idx):
        row = self.coords.iloc[idx]
        fname = row['orig_name']
        x, y, h, w = int(row['x']), int(row['y']), int(row['h']), int(row['w'])

        patch1, patch2, patch_label = self._read_patch(fname, x, y, h, w)

        patch1 = self.transform(patch1)
        patch2 = self.transform(patch2)
        # ToTensor divides uint8 labels by 255. After .long() only pixels equal
        # to 255 become class 1 (1..254 truncate to 0).
        # NOTE(review): assumes masks are stored as 0/255 -- confirm.
        patch_label = self.label_transform(patch_label)
        patch_label = make_one_hot(patch_label.unsqueeze(0).long(), self.num_classes).squeeze(0)

        return patch1, patch2, patch_label
|
||
|
||
|
||
# Test datasets (测试)
|
||
# class TestDatasetFromCoords(Dataset):
|
||
# def __init__(self, coords_csv, hr1_dir, hr2_dir, label_dir=None, transform=None):
|
||
# self.coords = pd.read_csv(coords_csv)
|
||
# self.hr1_dir = hr1_dir
|
||
# self.hr2_dir = hr2_dir
|
||
# self.label_dir = label_dir
|
||
# self.transform = transform
|
||
# self.label_transform = transforms.ToTensor()
|
||
|
||
# def __len__(self):
|
||
# return len(self.coords)
|
||
|
||
# def __getitem__(self, idx):
|
||
# row = self.coords.iloc[idx]
|
||
# x, y, h, w = row['x'], row['y'], row['h'], row['w']
|
||
# fname = row['orig_name']
|
||
|
||
# img1 = Image.open(os.path.join(self.hr1_dir, fname)).convert('RGB')
|
||
# img2 = Image.open(os.path.join(self.hr2_dir, fname)).convert('RGB')
|
||
|
||
# img1 = img1.crop((x, y, x + w, y + h))
|
||
# img2 = img2.crop((x, y, x + w, y + h))
|
||
|
||
# if self.label_dir:
|
||
# label = Image.open(os.path.join(self.label_dir, fname)).convert('L')
|
||
# label = label.crop((x, y, x + w, y + h))
|
||
# label = self.label_transform(label)
|
||
# else:
|
||
# label = torch.zeros((1, h, w)) # dummy label
|
||
|
||
# if self.transform:
|
||
# img1 = self.transform(img1)
|
||
# img2 = self.transform(img2)
|
||
|
||
# patch_name = f"{os.path.splitext(fname)[0]}_{x}_{y}"
|
||
# return img1, img2, label, patch_name, fname, x, y
|
||
class TestDatasetFromCoords(Dataset):
    """Test-time dataset that crops the patches listed in a coordinates CSV.

    Each CSV row supplies ``orig_name`` and a box: ``x``/``y`` (left/top) and
    ``w``/``h`` (width/height). The named file is opened from ``hr1_dir`` and
    ``hr2_dir`` (and ``label_dir`` when given), converted to RGB (label: 'L')
    and cropped to the box. Patches smaller than ``patch_size`` are zero-padded
    on the bottom/right so they can be batched. Larger patches are never
    cropped.

    Items are ``(img1, img2, label, patch_name, fname, x, y, h, w)``, where
    ``h``/``w`` are the un-padded patch size and ``patch_name`` is
    ``"<file stem>_<x>_<y>"``.

    Args:
        coords_csv: CSV with columns ``orig_name``, ``x``, ``y``, ``h``, ``w``.
        hr1_dir, hr2_dir: Directories holding the two acquisition dates.
        label_dir: Optional label directory. When falsy, a zero label of shape
            [1, h, w] is returned.
        transform: Optional callable applied to both RGB crops. Padding uses
            ``torch.nn.functional.pad``, so it must return tensors whenever a
            patch needs padding.
        patch_size: Side length patches are padded up to (default 256).
    """

    def __init__(self, coords_csv, hr1_dir, hr2_dir, label_dir=None, transform=None,
                 patch_size=256):
        super(TestDatasetFromCoords, self).__init__()
        self.coords = pd.read_csv(coords_csv)
        self.hr1_dir = hr1_dir
        self.hr2_dir = hr2_dir
        self.label_dir = label_dir
        self.transform = transform
        self.label_transform = transforms.ToTensor()
        self.patch_size = patch_size

    def __len__(self):
        return len(self.coords)

    @staticmethod
    def _open_as(path, mode):
        """Open *path*, convert it to PIL *mode* and close the file handle."""
        with Image.open(path) as im:
            return im.convert(mode)

    @staticmethod
    def _bottom_right_padding(h, w, size):
        """Return the ``F.pad`` tuple that grows an (h, w) patch to (size, size).

        The tuple is ``(left, right, top, bottom)`` for the last two dims.
        Amounts are clamped at 0 because ``F.pad`` treats negative values as
        cropping, which would silently discard pixels of oversized patches.
        """
        return (0, max(size - w, 0), 0, max(size - h, 0))

    def __getitem__(self, idx):
        row = self.coords.iloc[idx]
        x, y, h, w = int(row['x']), int(row['y']), int(row['h']), int(row['w'])
        fname = row['orig_name']
        box = (x, y, x + w, y + h)

        # Read the images and crop the patch.
        img1 = self._open_as(os.path.join(self.hr1_dir, fname), 'RGB').crop(box)
        img2 = self._open_as(os.path.join(self.hr2_dir, fname), 'RGB').crop(box)

        # Label
        if self.label_dir:
            label = self._open_as(os.path.join(self.label_dir, fname), 'L').crop(box)
            label = self.label_transform(label)  # float [1, h, w] scaled to [0, 1]
        else:
            label = torch.zeros((1, h, w))  # dummy label

        # Image transform
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        # Pad up to patch_size x patch_size; the original h, w are returned below.
        padding = self._bottom_right_padding(h, w, self.patch_size)
        if any(padding):
            img1 = F.pad(img1, padding)
            img2 = F.pad(img2, padding)
            label = F.pad(label, padding)

        patch_name = f"{os.path.splitext(fname)[0]}_{x}_{y}"
        return img1, img2, label, patch_name, fname, x, y, h, w  # un-padded size
|
||
|
||
|
||
|
||
|
||
class TestDatasetFromImagePath(Dataset):
    """Test-time dataset that tiles one bi-temporal image pair into patches.

    The pair (and the optional label) is split into a non-overlapping grid of
    ``patch_size`` x ``patch_size`` tiles in row-major order. Tiles on the
    right/bottom edges are smaller and are zero-padded on the bottom/right up
    to ``patch_size``.

    Items are ``(img1, img2, label, patch_name, fname, x, y, h, w)``:
    ``fname`` is the basename of ``img1_path``, ``patch_name`` is
    ``"<stem>_<x>_<y>"``, and ``h``/``w`` are the un-padded tile size.

    Args:
        img1_path, img2_path: Paths of the two acquisition dates.
        transform: Transform applied to both tiles. When falsy, defaults to
            ToTensor + Normalize(mean=0.5, std=0.5), so padding always
            operates on tensors.
        label_path: Optional label image. When falsy, a zero label of shape
            [1, h, w] is returned.
        patch_size: Tile side length (default 256).

    The decoded images are cached after first use, so the (possibly very
    large) files are read once per process instead of once per tile.
    """

    def __init__(self, img1_path, img2_path, transform=None, label_path=None, patch_size=256):
        super(TestDatasetFromImagePath, self).__init__()
        self.img1_path = img1_path
        self.img2_path = img2_path
        self.label_path = label_path
        self.transform = transform or transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        self.label_transform = transforms.ToTensor()
        self.patch_size = patch_size
        self.fname = os.path.basename(img1_path)

        with Image.open(img1_path) as im:
            width, height = im.size
        self.coords = self._tile_coords(width, height, patch_size)
        self._images = None  # (img1, img2, label) decoded on first access

    @staticmethod
    def _tile_coords(width, height, patch_size):
        """Return ``(x, y, h, w)`` for a row-major tile grid covering the image."""
        return [
            (x, y, min(patch_size, height - y), min(patch_size, width - x))
            for y in range(0, height, patch_size)
            for x in range(0, width, patch_size)
        ]

    def _load_images(self):
        """Decode the full image pair (and label, if any) once and cache them."""
        if getattr(self, '_images', None) is None:
            with Image.open(self.img1_path) as im1, Image.open(self.img2_path) as im2:
                img1, img2 = im1.convert('RGB'), im2.convert('RGB')
            label = None
            if self.label_path:
                with Image.open(self.label_path) as lab:
                    label = lab.convert('L')
            self._images = (img1, img2, label)
        return self._images

    def __len__(self):
        return len(self.coords)

    def __getitem__(self, idx):
        x, y, h, w = self.coords[idx]
        box = (x, y, x + w, y + h)

        full1, full2, full_label = self._load_images()
        img1 = full1.crop(box)
        img2 = full2.crop(box)

        if full_label is not None:
            label = self.label_transform(full_label.crop(box))  # float [1, h, w] in [0, 1]
        else:
            label = torch.zeros((1, h, w))  # dummy label

        img1 = self.transform(img1)
        img2 = self.transform(img2)

        # Edge tiles are smaller than patch_size; pad them bottom/right.
        pad_h = self.patch_size - h
        pad_w = self.patch_size - w
        if pad_h > 0 or pad_w > 0:
            padding = (0, pad_w, 0, pad_h)
            img1 = F.pad(img1, padding)
            img2 = F.pad(img2, padding)
            label = F.pad(label, padding)

        patch_name = f"{os.path.splitext(self.fname)[0]}_{x}_{y}"
        return img1, img2, label, patch_name, self.fname, x, y, h, w  # un-padded size
|
||
|
||
|
||
|
||
|
||
|
||
# class TestDatasetFromFolder(Dataset):
|
||
# def __init__(self, args, Time1_dir, Time2_dir, Label_dir):
|
||
# super(TestDatasetFromFolder, self).__init__()
|
||
|
||
# # 确保 args.suffix 存在,否则使用默认值
|
||
# suffixes = getattr(args, 'suffix', ['.png', '.jpg'])
|
||
|
||
# # 筛选符合后缀的文件
|
||
# datalist = [
|
||
# name for name in os.listdir(Time1_dir)
|
||
# if os.path.splitext(name)[1].lower() in suffixes
|
||
# ]
|
||
|
||
# print(f"找到 {len(datalist)} 个文件: {datalist[:5]}...") # 调试输出
|
||
|
||
# self.image1_filenames = [join(Time1_dir, x) for x in datalist if is_image_file(x)]
|
||
# self.image2_filenames = [join(Time2_dir, x) for x in datalist if is_image_file(x)]
|
||
# self.image3_filenames = [join(Label_dir, x) for x in datalist if is_image_file(x)]
|
||
|
||
# self.transform = get_transform(convert=True, normalize=True) # convert to tensor and normalize to [-1,1]
|
||
# self.label_transform = get_transform()
|
||
|
||
# def __getitem__(self, index):
|
||
# image1 = self.transform(Image.open(self.image1_filenames[index]).convert('RGB'))
|
||
# image2 = self.transform(Image.open(self.image2_filenames[index]).convert('RGB'))
|
||
|
||
# label = self.label_transform(Image.open(self.image3_filenames[index]))
|
||
# label = make_one_hot(label.unsqueeze(0).long(), 2).squeeze(0)
|
||
|
||
# image_name = self.image1_filenames[index].split('/', -1)
|
||
# image_name = image_name[len(image_name)-1]
|
||
|
||
# return image1, image2, label, image_name
|
||
|
||
# def __len__(self):
|
||
# return len(self.image1_filenames)
|
||
|
||
|
||
class trainImageAug(object):
    """Joint random crop / flip / rotation for an image pair and its mask.

    The same geometric transform is applied to ``image1``, ``image2`` and
    ``mask`` so their pixels stay aligned.

    Args:
        crop: Take a random ``crop_size`` x ``crop_size`` crop.
        augment: Apply at most one random geometric augmentation:
            15% horizontal flip, 15% vertical flip, 20% rotation by a random
            angle in [-angle, angle] degrees, 50% unchanged.
        angle: Maximum absolute rotation angle in degrees.
        crop_size: Side length of the random crop (default 256).
    """

    def __init__(self, crop=True, augment=True, angle=30, crop_size=256):
        self.crop = crop
        self.augment = augment
        self.angle = angle
        self.crop_size = crop_size

    def __call__(self, image1, image2, mask):
        """Return ``(image1, image2, mask)`` after the shared random transform."""
        if self.crop:
            width, height = image1.size
            # Offsets come from [0, max(dim - crop_size, 1)). For 512x512 input
            # this equals the former fixed randint(0, 256). Input no larger than
            # crop_size always gets offset 0, so the box does not run past the
            # image and pick up zero padding.
            x0 = np.random.randint(0, max(width - self.crop_size, 1))
            y0 = np.random.randint(0, max(height - self.crop_size, 1))
            box = (x0, y0, x0 + self.crop_size, y0 + self.crop_size)
            image1 = image1.crop(box)
            image2 = image2.crop(box)
            mask = mask.crop(box)

        if self.augment:
            prop = np.random.uniform(0, 1)
            if prop < 0.15:
                image1 = image1.transpose(Image.FLIP_LEFT_RIGHT)
                image2 = image2.transpose(Image.FLIP_LEFT_RIGHT)
                mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
            elif prop < 0.3:
                image1 = image1.transpose(Image.FLIP_TOP_BOTTOM)
                image2 = image2.transpose(Image.FLIP_TOP_BOTTOM)
                mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
            elif prop < 0.5:
                # Draw a single angle so both images and the mask are rotated
                # identically and stay pixel-aligned.
                degrees = transforms.RandomRotation.get_params([-self.angle, self.angle])
                image1 = image1.rotate(degrees)
                image2 = image2.rotate(degrees)
                mask = mask.rotate(degrees)

        return image1, image2, mask
|
||
|
||
def get_transform(convert=True, normalize=False):
    """Return a ``transforms.Compose`` assembled from optional steps.

    This redefines the identically behaving ``get_transform`` earlier in the
    module. Being later, this is the binding callers actually get.

    Steps, in order:
      * ``ToTensor()`` when ``convert`` is true;
      * ``Normalize`` with mean = std = (0.5, 0.5, 0.5) when ``normalize`` is
        true, which maps [0, 1] to [-1, 1] for 3-channel images.
    """
    pipeline = []
    if convert:
        pipeline.append(transforms.ToTensor())
    if normalize:
        pipeline.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    return transforms.Compose(pipeline)
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
class DynamicPatchDataset(Dataset):
    """Training dataset that reads coordinate-listed patches and augments them.

    Rows of ``coords_csv`` give ``orig_name`` plus a window: ``x``/``y``
    (column/row offset) and ``w``/``h`` (size). The window is read with
    rasterio from ``img1_dir``, ``img2_dir`` and ``label_dir`` (same file name
    in each) and passed through ``trainImageAug``. The images are then
    normalised and the label is one-hot encoded.

    Each item is ``(img1, img2, label)``:
      * ``img1``/``img2``: tensors normalised with mean = std = 0.5;
      * ``label``: one-hot float tensor of shape [num_classes, H, W].

    Open rasterio handles are cached per path. Call ``close()`` to release
    them; they are dropped (and lazily reopened) when the dataset is pickled.
    """

    def __init__(self, coords_csv, img1_dir, img2_dir, label_dir,
                 crop=True, augment=True, angle=30, num_classes=2):
        super(DynamicPatchDataset, self).__init__()
        self.coords = pd.read_csv(coords_csv)
        self.img1_dir = img1_dir
        self.img2_dir = img2_dir
        self.label_dir = label_dir
        self.num_classes = num_classes

        # Same settings as the (commented-out) DA_DatasetFromFolder.
        self.data_augment = trainImageAug(crop=crop, augment=augment, angle=angle)
        self.img_transform = get_transform(convert=True, normalize=True)
        self.lab_transform = get_transform()

        self.open_images = {}  # path -> open rasterio dataset

    def _load_image(self, path):
        """Return the cached rasterio dataset for *path*, opening it on first use."""
        if path not in self.open_images:
            self.open_images[path] = rasterio.open(path)
        return self.open_images[path]

    def close(self):
        """Close every cached rasterio dataset and empty the cache."""
        for src in self.open_images.values():
            src.close()
        self.open_images.clear()

    def __getstate__(self):
        # Open rasterio handles cannot be pickled (e.g. DataLoader workers
        # started with spawn). Ship an empty cache; copies reopen lazily.
        state = self.__dict__.copy()
        state['open_images'] = {}
        return state

    def __len__(self):
        return len(self.coords)

    def _read_patch(self, fname, x, y, h, w):
        """Read window (x, y, w, h) of *fname* from the three rasters as PIL images."""
        img1 = self._load_image(os.path.join(self.img1_dir, fname))
        img2 = self._load_image(os.path.join(self.img2_dir, fname))
        label = self._load_image(os.path.join(self.label_dir, fname))

        window = Window(x, y, w, h)  # (col_off, row_off, width, height)
        patch1 = img1.read(window=window)  # [C, H, W]
        patch2 = img2.read(window=window)
        patch_label = label.read(1, window=window)  # first band only, [H, W]

        # PIL expects band-last arrays. NOTE(review): astype(np.uint8) wraps
        # values above 255, so this assumes 8-bit imagery -- confirm.
        patch1 = Image.fromarray(np.transpose(patch1, (1, 2, 0)).astype(np.uint8))  # [H, W, C]
        patch2 = Image.fromarray(np.transpose(patch2, (1, 2, 0)).astype(np.uint8))
        patch_label = Image.fromarray(patch_label.astype(np.uint8))
        return patch1, patch2, patch_label

    def __getitem__(self, idx):
        row = self.coords.iloc[idx]
        fname = row['orig_name']
        x, y, h, w = int(row['x']), int(row['y']), int(row['h']), int(row['w'])

        patch1, patch2, patch_label = self._read_patch(fname, x, y, h, w)

        # Joint geometric augmentation (crop, flip, rotation).
        patch1, patch2, patch_label = self.data_augment(patch1, patch2, patch_label)

        # To tensor + normalisation.
        patch1 = self.img_transform(patch1)
        patch2 = self.img_transform(patch2)
        # ToTensor divides uint8 labels by 255, so after .long() only 255
        # becomes class 1. NOTE(review): assumes 0/255 masks -- confirm.
        patch_label = self.lab_transform(patch_label)

        patch_label = make_one_hot(patch_label.unsqueeze(0).long(), self.num_classes).squeeze(0)
        return patch1, patch2, patch_label
|
||
# class DA_DatasetFromFolder(Dataset):
|
||
# def __init__(self, Image_dir1, Image_dir2, Label_dir, crop=True, augment = True, angle = 30):
|
||
# super(DA_DatasetFromFolder, self).__init__()
|
||
# # 获取图片列表
|
||
# datalist = os.listdir(Image_dir1)
|
||
# self.image_filenames1 = [join(Image_dir1, x) for x in datalist if is_image_file(x)]
|
||
# self.image_filenames2 = [join(Image_dir2, x) for x in datalist if is_image_file(x)]
|
||
# self.label_filenames = [join(Label_dir, x) for x in datalist if is_image_file(x)]
|
||
# self.data_augment = trainImageAug(crop=crop, augment = augment, angle=angle)
|
||
# self.img_transform = get_transform(convert=True, normalize=True)
|
||
# self.lab_transform = get_transform()
|
||
|
||
# def __getitem__(self, index):
|
||
# image1 = Image.open(self.image_filenames1[index]).convert('RGB')
|
||
# image2 = Image.open(self.image_filenames2[index]).convert('RGB')
|
||
# label = Image.open(self.label_filenames[index])
|
||
# image1, image2, label = self.data_augment(image1, image2, label)
|
||
# image1, image2 = self.img_transform(image1), self.img_transform(image2)
|
||
# label = self.lab_transform(label)
|
||
# label = make_one_hot(label.unsqueeze(0).long(), 2).squeeze(0)
|
||
# return image1, image2, label
|
||
|
||
# def __len__(self):
|
||
# return len(self.image_filenames1)
|
||
|
||
class TestDatasetFromPic(Dataset):
    """Sliding-window test dataset over a single bi-temporal image pair.

    Windows are generated once from the size of ``img1_path``. The decoded
    images are cached on first access, so each file is read once per process
    rather than once per patch.

    Items are ``(img1, img2, label, patch_name, img1_name, x, y, h, w)``, where
    ``patch_name`` is ``"<img1 stem>_<x>_<y>.png"``.
    """

    def __init__(self, img1_path, img2_path, label_path=None, transform=None,
                 patch_size=256, stride=256, pad_to_patch=False):
        """
        Args:
            img1_path (str): Full path of the first-date image.
            img2_path (str): Full path of the second-date image.
            label_path (str, optional): Label image path (may be omitted at test
                time, in which case a zero label of shape [1, h, w] is returned).
            transform (callable, optional): Transform applied to both crops.
                Defaults to ToTensor + Normalize(mean=0.5, std=0.5).
            patch_size (int): Side length of each window (default 256).
            stride (int): Step between windows (default 256, i.e. no overlap).
            pad_to_patch (bool): If true, also emit windows over the
                right/bottom remainder; pixels outside the image come back as
                zeros from the crop. If false, only windows fully inside the
                image are produced, so a remainder strip may be skipped.
        """
        super(TestDatasetFromPic, self).__init__()
        self.img1_path = img1_path
        self.img2_path = img2_path
        self.label_path = label_path
        self.transform = transform or transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        self.patch_size = patch_size
        self.stride = stride
        self.pad_to_patch = pad_to_patch
        self.label_transform = transforms.ToTensor()

        # File names, used to name saved results.
        self.img1_name = os.path.basename(img1_path)
        self.img2_name = os.path.basename(img2_path)

        self._images = None  # (img1, img2, label) decoded on first access
        self.coords = self._generate_coords()

    def _generate_coords(self):
        """Read the size of img1 and build the sliding-window list for it."""
        with Image.open(self.img1_path) as img:
            width, height = img.size
        return self._coords_for_size(width, height)

    def _coords_for_size(self, width, height):
        """Return window dicts (orig_name, x, y, h, w) for a width x height image.

        Windows are emitted column by column (x outer, y inner); every window is
        ``patch_size`` x ``patch_size``.
        """
        if self.pad_to_patch:
            # Round each dimension up to a multiple of patch_size.
            pad_w = (self.patch_size - width % self.patch_size) % self.patch_size
            pad_h = (self.patch_size - height % self.patch_size) % self.patch_size
            n_patches_w = (width + pad_w) // self.stride
            n_patches_h = (height + pad_h) // self.stride
        else:
            n_patches_w = (width - self.patch_size) // self.stride + 1
            n_patches_h = (height - self.patch_size) // self.stride + 1

        coords = []
        for i in range(n_patches_w):
            for j in range(n_patches_h):
                x = i * self.stride
                y = j * self.stride
                # Without padding, keep only windows inside the image.
                if not self.pad_to_patch and (
                        x + self.patch_size > width or y + self.patch_size > height):
                    continue
                coords.append({
                    'orig_name': self.img1_name,  # same key as the CSV-based datasets
                    'x': x,
                    'y': y,
                    'h': self.patch_size,
                    'w': self.patch_size,
                })
        return coords

    def _load_images(self):
        """Decode both images (and the label, if any) once and cache them."""
        if getattr(self, '_images', None) is None:
            with Image.open(self.img1_path) as im1, Image.open(self.img2_path) as im2:
                img1, img2 = im1.convert('RGB'), im2.convert('RGB')
            label = None
            if self.label_path:
                with Image.open(self.label_path) as lab:
                    label = lab.convert('L')
            self._images = (img1, img2, label)
        return self._images

    def __len__(self):
        return len(self.coords)

    def __getitem__(self, idx):
        """Return the cropped (and transformed) data for window *idx*."""
        row = self.coords[idx]
        x, y, h, w = row['x'], row['y'], row['h'], row['w']
        box = (x, y, x + w, y + h)

        full1, full2, full_label = self._load_images()

        # Label, if one was given.
        if full_label is not None:
            label = self.label_transform(full_label.crop(box))
        else:
            label = torch.zeros((1, h, w))  # dummy label

        img1 = full1.crop(box)
        img2 = full2.crop(box)

        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        # Pad to a full patch when enabled. Generated windows always have
        # h == w == patch_size, so this only matters for externally edited coords.
        if self.pad_to_patch:
            pad_h = self.patch_size - h
            pad_w = self.patch_size - w
            if pad_h > 0 or pad_w > 0:
                img1 = F.pad(img1, (0, pad_w, 0, pad_h))
                img2 = F.pad(img2, (0, pad_w, 0, pad_h))
                label = F.pad(label, (0, pad_w, 0, pad_h))

        patch_name = f"{os.path.splitext(self.img1_name)[0]}_{x}_{y}.png"
        return img1, img2, label, patch_name, self.img1_name, x, y, h, w
|
||
|
||
|
||
|
||
if __name__ == "__main__":
    # Smoke test: build a sliding-window dataset over one image pair and report
    # the patch count and the shapes of the first item. The paths are
    # placeholders; replace them with real files before running.
    # (The previous demo used TestDatasetFromFolder, which is commented out in
    # this module and raised NameError.)
    dataset = TestDatasetFromPic(
        img1_path="path/to/Time1/image.png",
        img2_path="path/to/Time2/image.png",
        label_path="path/to/Label/image.png",
    )
    print(f"{len(dataset)} patches")
    if len(dataset):
        img1, img2, label, patch_name, *_ = dataset[0]
        print(patch_name, tuple(img1.shape), tuple(img2.shape), tuple(label.shape))
|