580 lines
22 KiB
Python
580 lines
22 KiB
Python
|
|
#coding=utf-8
|
|||
|
|
from os.path import join
|
|||
|
|
import torch
|
|||
|
|
from PIL import Image, ImageEnhance
|
|||
|
|
from torch.utils.data.dataset import Dataset
|
|||
|
|
|
|||
|
|
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize
|
|||
|
|
import numpy as np
|
|||
|
|
import torchvision.transforms as transforms
|
|||
|
|
import os
|
|||
|
|
from argparse import Namespace
|
|||
|
|
import imageio
|
|||
|
|
import pandas as pd
|
|||
|
|
import rasterio
|
|||
|
|
from rasterio.windows import Window
|
|||
|
|
import torch.nn.functional as F
|
|||
|
|
Image.MAX_IMAGE_PIXELS = None # Disable Pillow's maximum-pixel (decompression-bomb) check so very large images can be opened without a warning/error
|
|||
|
|
def is_image_file(filename):
    """Return True if *filename* ends with a supported image extension.

    Supported extensions are ``.tif``, ``.png``, ``.jpg`` and ``.jpeg``.
    Matching is case-insensitive, so ``.TIF`` and ``.Jpeg`` are accepted just
    like ``.PNG`` and ``.JPG`` were before.
    """
    return filename.lower().endswith(('.tif', '.png', '.jpg', '.jpeg'))
|
|||
|
|
|
|||
|
|
def calMetric_iou(predict, label):
    """Count the intersection and union of the positive class (value 1).

    Args:
        predict: array of predicted class values.
        label: array of ground-truth class values.

    Returns:
        ``(intersection, union)`` pixel counts. The caller can form the IoU as
        ``intersection / union``.
    """
    intersection = np.sum(np.logical_and(predict == 1, label == 1))
    predicted_pos = np.sum(predict == 1)
    actual_pos = np.sum(label == 1)
    union = predicted_pos + actual_pos - intersection
    return intersection, union
|
|||
|
|
|
|||
|
|
|
|||
|
|
def getDataList(img_path):
    """Read a text file and return its lines with newline characters stripped.

    The file is opened in a ``with`` block so the handle is closed even if
    reading fails (the previous version never closed it).

    Args:
        img_path: path of a text file listing one entry per line.

    Returns:
        List of lines in file order; empty lines are kept as ``''``.
    """
    with open(img_path, 'r') as handle:
        return [line.strip('\n') for line in handle]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def make_one_hot(input, num_classes):
    """Expand a class-index tensor into a one-hot tensor along dim 1.

    Args:
        input: integer (long) tensor of shape ``[N, 1, *]`` holding class indices.
        num_classes: number of classes, i.e. the size of the output's dim 1.

    Returns:
        A float CPU tensor of shape ``[N, num_classes, *]`` with 1 at each
        index position and 0 elsewhere.
    """
    out_shape = list(input.shape)
    out_shape[1] = num_classes
    one_hot = torch.zeros(tuple(out_shape))
    # scatter_ writes 1 at the class index of every position (in place).
    one_hot.scatter_(1, input.cpu(), 1)
    return one_hot
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_transform(convert=True, normalize=False):
    """Build a torchvision preprocessing pipeline.

    Args:
        convert: prepend ``ToTensor()``.
        normalize: append ``Normalize`` with mean 0.5 and std 0.5 on three
            channels.

    Returns:
        A ``transforms.Compose`` of the selected steps (possibly empty).
    """
    steps = [transforms.ToTensor()] if convert else []
    if normalize:
        steps.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    return transforms.Compose(steps)
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
# class LoadDatasetFromFolder(Dataset):
|
|||
|
|
# def __init__(self, args, hr1_path, hr2_path, lab_path):
|
|||
|
|
# super(LoadDatasetFromFolder, self).__init__()
|
|||
|
|
# # 获取图片列表
|
|||
|
|
# datalist = [name for name in os.listdir(hr1_path) for item in args.suffix if
|
|||
|
|
# os.path.splitext(name)[1] == item]
|
|||
|
|
|
|||
|
|
# self.hr1_filenames = [join(hr1_path, x) for x in datalist if is_image_file(x)]
|
|||
|
|
# self.hr2_filenames = [join(hr2_path, x) for x in datalist if is_image_file(x)]
|
|||
|
|
# self.lab_filenames = [join(lab_path, x) for x in datalist if is_image_file(x)]
|
|||
|
|
|
|||
|
|
# self.transform = get_transform(convert=True, normalize=True) # convert to tensor and normalize to [-1,1]
|
|||
|
|
# self.label_transform = get_transform() # only convert to tensor
|
|||
|
|
|
|||
|
|
# def __getitem__(self, index):
|
|||
|
|
# hr1_img = self.transform(Image.open(self.hr1_filenames[index]).convert('RGB'))
|
|||
|
|
# # lr2_img = self.transform(Image.open(self.lr2_filenames[index]).convert('RGB'))
|
|||
|
|
# hr2_img = self.transform(Image.open(self.hr2_filenames[index]).convert('RGB'))
|
|||
|
|
|
|||
|
|
# label = self.label_transform(Image.open(self.lab_filenames[index]))
|
|||
|
|
# label = make_one_hot(label.unsqueeze(0).long(), 2).squeeze(0)
|
|||
|
|
|
|||
|
|
# return hr1_img, hr2_img, label
|
|||
|
|
|
|||
|
|
# def __len__(self):
|
|||
|
|
# return len(self.hr1_filenames)
|
|||
|
|
#验证
|
|||
|
|
class LoadDatasetFromCoords(Dataset):
    """Validation dataset that cuts windows out of large co-registered rasters.

    Every row of ``coords_csv`` must provide ``orig_name`` (a file name present
    in ``hr1_dir``, ``hr2_dir`` and ``label_dir``) and the window geometry
    ``x`` (column offset), ``y`` (row offset), ``h`` (height) and ``w`` (width).

    ``__getitem__`` returns ``(t1, t2, label)``. ``t1``/``t2`` went through
    ``get_transform(convert=True, normalize=True)``; ``label`` is one-hot
    encoded with ``num_classes`` channels.
    """

    def __init__(self, coords_csv, hr1_dir, hr2_dir, label_dir, num_classes=2):
        super(LoadDatasetFromCoords, self).__init__()
        self.coords = pd.read_csv(coords_csv)
        self.hr1_dir = hr1_dir
        self.hr2_dir = hr2_dir
        self.label_dir = label_dir
        self.num_classes = num_classes

        # Images: to tensor + per-channel normalisation; labels: to tensor only.
        self.transform = get_transform(convert=True, normalize=True)
        self.label_transform = get_transform()
        # path -> open rasterio dataset; filled on first access, never closed here.
        self.open_images = {}

    def _load_image(self, path):
        """Return the cached rasterio handle for ``path``, opening it on first use."""
        handle = self.open_images.get(path)
        if handle is None:
            handle = self.open_images[path] = rasterio.open(path)
        return handle

    def __len__(self):
        return len(self.coords)

    @staticmethod
    def _bands_to_image(bands):
        """Turn a band-first ``[C, H, W]`` array into a PIL image, cast to uint8."""
        return Image.fromarray(bands.transpose(1, 2, 0).astype(np.uint8))

    def __getitem__(self, idx):
        record = self.coords.iloc[idx]
        name = record['orig_name']
        col, row, height, width = (int(record[key]) for key in ('x', 'y', 'h', 'w'))

        src1, src2, src_label = (
            self._load_image(os.path.join(folder, name))
            for folder in (self.hr1_dir, self.hr2_dir, self.label_dir)
        )
        # rasterio Window takes (col_off, row_off, width, height).
        window = Window(col, row, width, height)

        raw1 = src1.read(window=window)               # [C, H, W]
        raw2 = src2.read(window=window)
        raw_label = src_label.read(1, window=window)  # [H, W], first band only

        # Go through PIL so the torchvision transforms can be reused.
        pil1, pil2 = self._bands_to_image(raw1), self._bands_to_image(raw2)
        pil_label = Image.fromarray(raw_label.astype(np.uint8))

        t1, t2 = self.transform(pil1), self.transform(pil2)
        label = self.label_transform(pil_label)

        # NOTE(review): ToTensor scales uint8 data by 1/255 and .long() truncates,
        # so only label value 255 ends up as class 1 -- presumably masks are stored
        # as 0/255; confirm against the data.
        label = make_one_hot(label.unsqueeze(0).long(), self.num_classes).squeeze(0)
        return t1, t2, label
|
|||
|
|
|
|||
|
|
|
|||
|
|
#测试
|
|||
|
|
# class TestDatasetFromCoords(Dataset):
|
|||
|
|
# def __init__(self, coords_csv, hr1_dir, hr2_dir, label_dir=None, transform=None):
|
|||
|
|
# self.coords = pd.read_csv(coords_csv)
|
|||
|
|
# self.hr1_dir = hr1_dir
|
|||
|
|
# self.hr2_dir = hr2_dir
|
|||
|
|
# self.label_dir = label_dir
|
|||
|
|
# self.transform = transform
|
|||
|
|
# self.label_transform = transforms.ToTensor()
|
|||
|
|
|
|||
|
|
# def __len__(self):
|
|||
|
|
# return len(self.coords)
|
|||
|
|
|
|||
|
|
# def __getitem__(self, idx):
|
|||
|
|
# row = self.coords.iloc[idx]
|
|||
|
|
# x, y, h, w = row['x'], row['y'], row['h'], row['w']
|
|||
|
|
# fname = row['orig_name']
|
|||
|
|
|
|||
|
|
# img1 = Image.open(os.path.join(self.hr1_dir, fname)).convert('RGB')
|
|||
|
|
# img2 = Image.open(os.path.join(self.hr2_dir, fname)).convert('RGB')
|
|||
|
|
|
|||
|
|
# img1 = img1.crop((x, y, x + w, y + h))
|
|||
|
|
# img2 = img2.crop((x, y, x + w, y + h))
|
|||
|
|
|
|||
|
|
# if self.label_dir:
|
|||
|
|
# label = Image.open(os.path.join(self.label_dir, fname)).convert('L')
|
|||
|
|
# label = label.crop((x, y, x + w, y + h))
|
|||
|
|
# label = self.label_transform(label)
|
|||
|
|
# else:
|
|||
|
|
# label = torch.zeros((1, h, w)) # dummy label
|
|||
|
|
|
|||
|
|
# if self.transform:
|
|||
|
|
# img1 = self.transform(img1)
|
|||
|
|
# img2 = self.transform(img2)
|
|||
|
|
|
|||
|
|
# patch_name = f"{os.path.splitext(fname)[0]}_{x}_{y}"
|
|||
|
|
# return img1, img2, label, patch_name, fname, x, y
|
|||
|
|
class TestDatasetFromCoords(Dataset):
    """Test dataset that crops CSV-listed patches and pads them to ``pad_size``.

    Each row of ``coords_csv`` provides ``orig_name`` (a file inside
    ``hr1_dir`` / ``hr2_dir`` / ``label_dir``) and the patch geometry ``x``,
    ``y``, ``h``, ``w``. Patches smaller than ``pad_size`` are zero-padded on
    the right/bottom. The original ``h``/``w`` are returned so predictions can
    be cropped back.

    Args:
        coords_csv: CSV file with columns ``orig_name, x, y, h, w``.
        hr1_dir: folder of the first-date images.
        hr2_dir: folder of the second-date images.
        label_dir: optional label folder; without it a zero label is returned.
        transform: optional image transform. Padding uses ``F.pad`` and
            therefore needs tensors, so pass a tensor-producing transform when
            patches can be smaller than ``pad_size``.
        pad_size: target side length used for padding (default 256).
    """

    def __init__(self, coords_csv, hr1_dir, hr2_dir, label_dir=None, transform=None, pad_size=256):
        self.coords = pd.read_csv(coords_csv)
        self.hr1_dir = hr1_dir
        self.hr2_dir = hr2_dir
        self.label_dir = label_dir
        self.transform = transform
        self.label_transform = transforms.ToTensor()
        self.pad_size = pad_size

    def __len__(self):
        return len(self.coords)

    @staticmethod
    def _pad_amounts(h, w, size):
        """Return ``(pad_h, pad_w)`` needed to grow an ``h x w`` patch to ``size``.

        Amounts are clamped at 0. A negative amount would make ``F.pad`` crop
        that side, e.g. a 200x300 patch would lose 44 columns.
        """
        return max(0, size - h), max(0, size - w)

    @staticmethod
    def _open_cropped(path, mode, box):
        """Open ``path``, convert it to ``mode`` and crop ``box``; the file is closed on return."""
        with Image.open(path) as img:
            return img.convert(mode).crop(box)

    def __getitem__(self, idx):
        """Return ``(img1, img2, label, patch_name, fname, x, y, h, w)`` for row ``idx``."""
        row = self.coords.iloc[idx]
        x, y, h, w = int(row['x']), int(row['y']), int(row['h']), int(row['w'])
        fname = row['orig_name']
        box = (x, y, x + w, y + h)

        # Read both dates and crop the patch.
        img1 = self._open_cropped(os.path.join(self.hr1_dir, fname), 'RGB', box)
        img2 = self._open_cropped(os.path.join(self.hr2_dir, fname), 'RGB', box)

        # Label (or a zero placeholder when no label folder is given).
        if self.label_dir:
            label = self.label_transform(
                self._open_cropped(os.path.join(self.label_dir, fname), 'L', box))
        else:
            label = torch.zeros((1, h, w))  # dummy label

        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        # Pad right & bottom up to pad_size; the original h, w are returned.
        pad_h, pad_w = self._pad_amounts(h, w, self.pad_size)
        if pad_h or pad_w:
            padding = (0, pad_w, 0, pad_h)  # (left, right, top, bottom) of the last two dims
            img1 = F.pad(img1, padding)
            img2 = F.pad(img2, padding)
            label = F.pad(label, padding)

        patch_name = f"{os.path.splitext(fname)[0]}_{x}_{y}"
        return img1, img2, label, patch_name, fname, x, y, h, w
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestDatasetFromImagePath(Dataset):
    """Test dataset that exposes one whole image pair as a single sample.

    The previous version could not run: ``__len__`` returned ``None`` (so
    ``len()`` raised TypeError) and ``__getitem__`` used undefined names
    (``x``, ``y``, ``h``, ``w``, ``fname``, ``self.label_dir``).

    The sample tuple has the same layout as ``TestDatasetFromCoords``:
    ``(img1, img2, label, patch_name, fname, x, y, h, w)``, with ``x = y = 0``
    and ``h``/``w`` the size of ``img1``.

    Args:
        img1_path: full path of the first-date image.
        img2_path: full path of the second-date image.
        transform: optional image transform. Padding uses ``F.pad`` and so
            needs tensors when the image is smaller than ``pad_size``.
        label_path: optional label image path; without it a zero label is returned.
        pad_size: images smaller than this are zero-padded right/bottom to it.
    """

    def __init__(self, img1_path, img2_path, transform=None, label_path=None, pad_size=256):
        self.img1_path = img1_path
        self.img2_path = img2_path
        self.transform = transform
        self.label_transform = transforms.ToTensor()
        self.label_path = label_path
        self.pad_size = pad_size

    def __len__(self):
        # The whole image pair forms exactly one sample.
        return 1

    def __getitem__(self, idx):
        if idx not in (0, -1):
            raise IndexError(f"index {idx} out of range for a single-sample dataset")

        with Image.open(self.img1_path) as src:
            img1 = src.convert('RGB')
        with Image.open(self.img2_path) as src:
            img2 = src.convert('RGB')
        # NOTE(review): assumes both dates share the same size; only img1 is measured.
        w, h = img1.size
        x = y = 0
        fname = os.path.basename(self.img1_path)

        if self.label_path:
            with Image.open(self.label_path) as src:
                label = self.label_transform(src.convert('L'))
        else:
            label = torch.zeros((1, h, w))  # dummy label

        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        # Pad right & bottom up to pad_size (never negative, so nothing is cropped).
        pad_h = max(0, self.pad_size - h)
        pad_w = max(0, self.pad_size - w)
        if pad_h or pad_w:
            padding = (0, pad_w, 0, pad_h)
            img1 = F.pad(img1, padding)
            img2 = F.pad(img2, padding)
            label = F.pad(label, padding)

        patch_name = f"{os.path.splitext(fname)[0]}_{x}_{y}"
        return img1, img2, label, patch_name, fname, x, y, h, w
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
# class TestDatasetFromFolder(Dataset):
|
|||
|
|
# def __init__(self, args, Time1_dir, Time2_dir, Label_dir):
|
|||
|
|
# super(TestDatasetFromFolder, self).__init__()
|
|||
|
|
|
|||
|
|
# # 确保 args.suffix 存在,否则使用默认值
|
|||
|
|
# suffixes = getattr(args, 'suffix', ['.png', '.jpg'])
|
|||
|
|
|
|||
|
|
# # 筛选符合后缀的文件
|
|||
|
|
# datalist = [
|
|||
|
|
# name for name in os.listdir(Time1_dir)
|
|||
|
|
# if os.path.splitext(name)[1].lower() in suffixes
|
|||
|
|
# ]
|
|||
|
|
|
|||
|
|
# print(f"找到 {len(datalist)} 个文件: {datalist[:5]}...") # 调试输出
|
|||
|
|
|
|||
|
|
# self.image1_filenames = [join(Time1_dir, x) for x in datalist if is_image_file(x)]
|
|||
|
|
# self.image2_filenames = [join(Time2_dir, x) for x in datalist if is_image_file(x)]
|
|||
|
|
# self.image3_filenames = [join(Label_dir, x) for x in datalist if is_image_file(x)]
|
|||
|
|
|
|||
|
|
# self.transform = get_transform(convert=True, normalize=True) # convert to tensor and normalize to [-1,1]
|
|||
|
|
# self.label_transform = get_transform()
|
|||
|
|
|
|||
|
|
# def __getitem__(self, index):
|
|||
|
|
# image1 = self.transform(Image.open(self.image1_filenames[index]).convert('RGB'))
|
|||
|
|
# image2 = self.transform(Image.open(self.image2_filenames[index]).convert('RGB'))
|
|||
|
|
|
|||
|
|
# label = self.label_transform(Image.open(self.image3_filenames[index]))
|
|||
|
|
# label = make_one_hot(label.unsqueeze(0).long(), 2).squeeze(0)
|
|||
|
|
|
|||
|
|
# image_name = self.image1_filenames[index].split('/', -1)
|
|||
|
|
# image_name = image_name[len(image_name)-1]
|
|||
|
|
|
|||
|
|
# return image1, image2, label, image_name
|
|||
|
|
|
|||
|
|
# def __len__(self):
|
|||
|
|
# return len(self.image1_filenames)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class trainImageAug(object):
    """Joint random augmentation for a bi-temporal image pair and its mask.

    Whatever geometric operation is sampled is applied identically to
    ``image1``, ``image2`` and ``mask``, so the three stay pixel-aligned.
    The inputs are used through ``crop`` / ``transpose`` / ``rotate``, i.e.
    they are expected to be PIL images.

    Args:
        crop: take a random 256x256 crop first.
        augment: then possibly flip or rotate.
        angle: rotation angles are sampled from ``[-angle, angle]`` degrees.
    """

    def __init__(self, crop=True, augment=True, angle=30):
        self.crop = crop
        self.augment = augment
        self.angle = angle

    def __call__(self, image1, image2, mask):
        """Return the augmented ``(image1, image2, mask)`` triple."""
        if self.crop:
            # Offsets come from [0, 255] for a fixed 256x256 window, which
            # presumably assumes inputs of at least 511x511 (e.g. 512x512);
            # smaller inputs get a zero-filled border from PIL's crop. TODO confirm.
            w = np.random.randint(0, 256)
            h = np.random.randint(0, 256)
            box = (w, h, w + 256, h + 256)
            image1 = image1.crop(box)
            image2 = image2.crop(box)
            mask = mask.crop(box)
        if self.augment:
            prop = np.random.uniform(0, 1)
            if prop < 0.15:  # 15%: horizontal flip
                image1 = image1.transpose(Image.FLIP_LEFT_RIGHT)
                image2 = image2.transpose(Image.FLIP_LEFT_RIGHT)
                mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
            elif prop < 0.3:  # 15%: vertical flip
                image1 = image1.transpose(Image.FLIP_TOP_BOTTOM)
                image2 = image2.transpose(Image.FLIP_TOP_BOTTOM)
                mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
            elif prop < 0.5:  # 20%: random rotation
                # Sample the angle ONCE and reuse it. Independently drawn angles
                # would rotate the two dates and the mask differently and
                # misalign them.
                degrees = transforms.RandomRotation.get_params([-self.angle, self.angle])
                image1 = image1.rotate(degrees)
                image2 = image2.rotate(degrees)
                mask = mask.rotate(degrees)
            # Remaining 50%: no flip / rotation.
        return image1, image2, mask
|
|||
|
|
|
|||
|
|
def get_transform(convert=True, normalize=False):
    """Build a torchvision preprocessing pipeline.

    NOTE(review): this is a second definition of ``get_transform`` in this
    module. Being later, it rebinds the name at import time, so it is the one
    every caller uses. It builds the same pipeline as the earlier definition.

    Args:
        convert: start with ``ToTensor()``.
        normalize: add ``Normalize`` with mean 0.5 / std 0.5 per channel
            (three channels).

    Returns:
        ``transforms.Compose`` over the selected steps (may be empty).
    """
    pipeline = []
    if convert:
        pipeline.append(transforms.ToTensor())
    if normalize:
        # Optional colour augmentation, currently disabled:
        # transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
        half = (0.5, 0.5, 0.5)
        pipeline.append(transforms.Normalize(half, half))
    return transforms.Compose(pipeline)
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
class DynamicPatchDataset(Dataset):
    """Training dataset: windows read from large rasters plus paired augmentation.

    Each row of ``coords_csv`` gives ``orig_name`` (file name shared by
    ``img1_dir``, ``img2_dir`` and ``label_dir``) and ``x``, ``y``, ``h``, ``w``
    of the window. The window is read with rasterio, converted to PIL,
    augmented by ``trainImageAug``, transformed to tensors, and the label is
    one-hot encoded with ``num_classes`` channels.
    """

    def __init__(self, coords_csv, img1_dir, img2_dir, label_dir,
                 crop=True, augment=True, angle=30, num_classes=2):
        self.coords = pd.read_csv(coords_csv)
        self.img1_dir = img1_dir
        self.img2_dir = img2_dir
        self.label_dir = label_dir
        self.num_classes = num_classes

        # Same augmentation / transform settings as the folder-based DA_DatasetFromFolder.
        self.data_augment = trainImageAug(crop=crop, augment=augment, angle=angle)
        self.img_transform = get_transform(convert=True, normalize=True)
        self.lab_transform = get_transform()

        self.open_images = {}  # path -> open rasterio dataset (cache)

    def _load_image(self, path):
        """Return a cached rasterio dataset for ``path``, opening it on first request."""
        if path in self.open_images:
            return self.open_images[path]
        dataset = rasterio.open(path)
        self.open_images[path] = dataset
        return dataset

    def __len__(self):
        return len(self.coords)

    def __getitem__(self, idx):
        entry = self.coords.iloc[idx]
        fname = entry['orig_name']
        x, y, h, w = map(int, (entry['x'], entry['y'], entry['h'], entry['w']))

        # Open (or reuse) the three large rasters, then read the same window.
        img1, img2, lab = [
            self._load_image(os.path.join(folder, fname))
            for folder in (self.img1_dir, self.img2_dir, self.label_dir)
        ]
        win = Window(x, y, w, h)  # (col_off, row_off, width, height)

        bands1 = img1.read(window=win)  # [C, H, W]
        bands2 = img2.read(window=win)
        mask = lab.read(1, window=win)  # [H, W]

        # Back to PIL ([H, W, C], uint8) so trainImageAug can be reused.
        patch1 = Image.fromarray(np.transpose(bands1, (1, 2, 0)).astype(np.uint8))
        patch2 = Image.fromarray(np.transpose(bands2, (1, 2, 0)).astype(np.uint8))
        patch_label = Image.fromarray(mask.astype(np.uint8))

        # Paired crop / flip / rotation.
        patch1, patch2, patch_label = self.data_augment(patch1, patch2, patch_label)

        # To tensor (+ normalisation for the images).
        patch1 = self.img_transform(patch1)
        patch2 = self.img_transform(patch2)
        patch_label = self.lab_transform(patch_label)

        # [1, H, W] -> [1, 1, H, W] for the scatter, then back to [num_classes, H, W].
        patch_label = make_one_hot(patch_label.unsqueeze(0).long(), self.num_classes).squeeze(0)
        return patch1, patch2, patch_label
|
|||
|
|
# class DA_DatasetFromFolder(Dataset):
|
|||
|
|
# def __init__(self, Image_dir1, Image_dir2, Label_dir, crop=True, augment = True, angle = 30):
|
|||
|
|
# super(DA_DatasetFromFolder, self).__init__()
|
|||
|
|
# # 获取图片列表
|
|||
|
|
# datalist = os.listdir(Image_dir1)
|
|||
|
|
# self.image_filenames1 = [join(Image_dir1, x) for x in datalist if is_image_file(x)]
|
|||
|
|
# self.image_filenames2 = [join(Image_dir2, x) for x in datalist if is_image_file(x)]
|
|||
|
|
# self.label_filenames = [join(Label_dir, x) for x in datalist if is_image_file(x)]
|
|||
|
|
# self.data_augment = trainImageAug(crop=crop, augment = augment, angle=angle)
|
|||
|
|
# self.img_transform = get_transform(convert=True, normalize=True)
|
|||
|
|
# self.lab_transform = get_transform()
|
|||
|
|
|
|||
|
|
# def __getitem__(self, index):
|
|||
|
|
# image1 = Image.open(self.image_filenames1[index]).convert('RGB')
|
|||
|
|
# image2 = Image.open(self.image_filenames2[index]).convert('RGB')
|
|||
|
|
# label = Image.open(self.label_filenames[index])
|
|||
|
|
# image1, image2, label = self.data_augment(image1, image2, label)
|
|||
|
|
# image1, image2 = self.img_transform(image1), self.img_transform(image2)
|
|||
|
|
# label = self.lab_transform(label)
|
|||
|
|
# label = make_one_hot(label.unsqueeze(0).long(), 2).squeeze(0)
|
|||
|
|
# return image1, image2, label
|
|||
|
|
|
|||
|
|
# def __len__(self):
|
|||
|
|
# return len(self.image_filenames1)
|
|||
|
|
|
|||
|
|
class TestDatasetFromPic(Dataset):
    def __init__(self, img1_path, img2_path, label_path=None, transform=None,
                 patch_size=256, stride=256, pad_to_patch=False):
        """Sliding-window test dataset over a single pair of large images.

        Args:
            img1_path (str): full path of the first-date image.
            img2_path (str): full path of the second-date image.
            label_path (str, optional): label image path (may be omitted at test time).
            transform (callable, optional): image transform; defaults to
                ToTensor + Normalize(mean 0.5, std 0.5).
            patch_size (int): side length of the square patches (default 256).
            stride (int): sliding step (default 256, i.e. no overlap).
            pad_to_patch (bool): if True, edge windows that run past the image
                border are kept (PIL fills the outside area with zeros). If
                False, only windows lying fully inside the image are produced.

        Raises:
            ValueError: if ``patch_size`` or ``stride`` is not positive.
        """
        self.img1_path = img1_path
        self.img2_path = img2_path
        self.label_path = label_path
        self.transform = transform or transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        self.patch_size = patch_size
        self.stride = stride
        self.pad_to_patch = pad_to_patch
        self.label_transform = transforms.ToTensor()

        # File names, used when naming saved results.
        self.img1_name = os.path.basename(img1_path)
        self.img2_name = os.path.basename(img2_path)

        # Decoded full images keyed by (path, mode). Filled lazily so each image
        # is decoded once instead of once per patch.
        self._image_cache = {}

        # Window coordinates are computed once from the image size.
        self.coords = self._generate_coords()

    @staticmethod
    def _axis_starts(length, patch_size, stride, pad_to_patch):
        """Return the window start offsets along one image axis of ``length`` pixels.

        Without padding, only windows lying fully inside ``[0, length)`` are
        kept. With padding, the axis is virtually extended to a multiple of
        ``patch_size``. Windows starting at or beyond ``length`` are dropped
        because they would contain padding only (this happens when
        ``stride < patch_size``).
        """
        if patch_size <= 0 or stride <= 0:
            raise ValueError(
                f"patch_size and stride must be positive, got {patch_size} and {stride}")
        if pad_to_patch:
            padded = length + (patch_size - length % patch_size) % patch_size
            count = padded // stride
            return [k * stride for k in range(count) if k * stride < length]
        count = (length - patch_size) // stride + 1  # <= 0 when the image is too small
        return [k * stride for k in range(count) if k * stride + patch_size <= length]

    def _generate_coords(self):
        """Build one descriptor dict per window (x in the outer loop, y in the inner loop)."""
        with Image.open(self.img1_path) as img:
            width, height = img.size
        xs = self._axis_starts(width, self.patch_size, self.stride, self.pad_to_patch)
        ys = self._axis_starts(height, self.patch_size, self.stride, self.pad_to_patch)
        return [
            {
                'orig_name': self.img1_name,  # same key as the CSV-based datasets
                'x': x,
                'y': y,
                'h': self.patch_size,
                'w': self.patch_size,
            }
            for x in xs
            for y in ys
        ]

    def _cached_image(self, path, mode):
        """Decode ``path`` in ``mode`` on first use and reuse it for every patch."""
        key = (path, mode)
        if key not in self._image_cache:
            with Image.open(path) as src:
                self._image_cache[key] = src.convert(mode)
        return self._image_cache[key]

    def __len__(self):
        return len(self.coords)

    def __getitem__(self, idx):
        """Return ``(img1, img2, label, patch_name, img1_name, x, y, h, w)`` for window ``idx``."""
        row = self.coords[idx]
        x, y, h, w = row['x'], row['y'], row['h'], row['w']
        box = (x, y, x + w, y + h)

        img1 = self._cached_image(self.img1_path, 'RGB').crop(box)
        img2 = self._cached_image(self.img2_path, 'RGB').crop(box)

        if self.label_path:
            label = self.label_transform(self._cached_image(self.label_path, 'L').crop(box))
        else:
            label = torch.zeros((1, h, w))  # dummy label

        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        # Every window is patch_size x patch_size. For edge windows in
        # pad_to_patch mode, PIL's crop already zero-fills the area outside the
        # image, so no extra tensor padding is needed.

        patch_name = f"{os.path.splitext(self.img1_name)[0]}_{x}_{y}.png"
        return img1, img2, label, patch_name, self.img1_name, x, y, h, w
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # Smoke test: tile an image pair into patches and report what the dataset
    # yields. (TestDatasetFromFolder, used here before, is not defined in this
    # module, so the old demo always raised NameError.)
    import argparse

    parser = argparse.ArgumentParser(description="Smoke-test TestDatasetFromPic on an image pair.")
    parser.add_argument("img1", help="first-date image path")
    parser.add_argument("img2", help="second-date image path")
    parser.add_argument("--label", default=None, help="optional label image path")
    parser.add_argument("--patch-size", type=int, default=256)
    parser.add_argument("--stride", type=int, default=256)
    parser.add_argument("--pad-to-patch", action="store_true")
    cli = parser.parse_args()

    dataset = TestDatasetFromPic(
        cli.img1,
        cli.img2,
        label_path=cli.label,
        patch_size=cli.patch_size,
        stride=cli.stride,
        pad_to_patch=cli.pad_to_patch,
    )
    print(f"{len(dataset)} patches")
    if len(dataset):
        img1, img2, label, patch_name, *_ = dataset[0]
        print(patch_name, tuple(img1.shape), tuple(img2.shape), tuple(label.shape))
|