257 lines
10 KiB
Python
257 lines
10 KiB
Python
#coding=utf-8
|
||
from os.path import join
|
||
import torch
|
||
from PIL import Image, ImageEnhance
|
||
from torch.utils.data.dataset import Dataset
|
||
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize
|
||
import numpy as np
|
||
import torchvision.transforms as transforms
|
||
import os
|
||
from argparse import Namespace
|
||
import imageio
|
||
|
||
|
||
def is_image_file(filename):
    """Return True if *filename* ends with a recognised image extension.

    The comparison is case-insensitive, so ``a.TIF`` and ``a.Png`` are accepted
    just like ``a.tif`` / ``a.png``. The dataset loaders filter their directory
    listings case-insensitively, and a case-sensitive check here would silently
    drop files those listings had already accepted.

    Args:
        filename: File name or path to test.

    Returns:
        bool: True for .tif, .png, .jpg and .jpeg files (any letter case).
    """
    return filename.lower().endswith(('.tif', '.png', '.jpg', '.jpeg'))
|
||
|
||
def calMetric_iou(predict, label):
    """Count intersection and union of the positive class (value 1).

    Args:
        predict: Array of predicted class values, compared element-wise to 1.
        label: Array of ground-truth class values, compared element-wise to 1.

    Returns:
        tuple: ``(intersection, union)`` where intersection is the number of
        positions equal to 1 in both arrays, and union is the number of
        positions equal to 1 in either array. Divide the two for IoU.
    """
    pred_pos = predict == 1
    label_pos = label == 1
    intersection = np.sum(np.logical_and(pred_pos, label_pos))
    # |P u L| = |P| + |L| - |P n L|
    union = np.sum(pred_pos) + np.sum(label_pos) - intersection
    return intersection, union
|
||
|
||
|
||
def getDataList(img_path):
    """Read a text file and return its lines with newline characters stripped.

    Args:
        img_path: Path of a text file; presumably one entry (e.g. an image
            name) per line — confirm against callers.

    Returns:
        list[str]: One string per line of the file, with leading and trailing
        newline characters removed. Blank lines become empty strings.
    """
    # The context manager closes the file handle as soon as reading is done.
    with open(img_path, 'r') as f:
        return [line.strip('\n') for line in f]
|
||
|
||
|
||
def make_one_hot(input, num_classes):
    """Turn a class-index tensor into a one-hot encoded tensor.

    Args:
        input: Integer (torch.long) tensor of shape [N, 1, *] holding class
            indices; ``scatter_`` requires an int64 index tensor.
        num_classes: Number of classes, i.e. the size of the output dim 1.

    Returns:
        A tensor of shape [N, num_classes, *] in torch's default float dtype.
        It lives on the CPU, because the index is moved to the CPU before
        scattering.
    """
    out_shape = list(input.shape)
    out_shape[1] = num_classes
    encoded = torch.zeros(tuple(out_shape))
    # For every position, write 1 into the channel named by input[n, 0, ...].
    return encoded.scatter_(1, input.cpu(), 1)
|
||
|
||
|
||
def get_transform(convert=True, normalize=False):
    """Build the torchvision preprocessing pipeline.

    Args:
        convert: If True, start with ``ToTensor`` (PIL image -> float tensor
            scaled to [0, 1]).
        normalize: If True, append ``Normalize`` with mean = std = 0.5 for three
            channels, which maps [0, 1] to [-1, 1].

    Returns:
        transforms.Compose: The composed pipeline (empty if both flags are off).

    Note: this module defines ``get_transform`` a second time further down;
    that later definition is the one bound at import time.
    """
    pipeline = []
    if convert:
        pipeline.append(transforms.ToTensor())
    if normalize:
        pipeline.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    return transforms.Compose(pipeline)
|
||
|
||
|
||
|
||
class LoadDatasetFromFolder(Dataset):
    """Training dataset yielding ``(image_t1, image_t2, one_hot_label)``.

    File names are listed from ``hr1_path``. The same names are looked up in
    ``hr2_path`` and ``lab_path``, so the three folders must share file names.
    """

    # Suffixes accepted when listing ``hr1_path`` (compared case-insensitively).
    # NOTE(review): ``args.suffix`` is intentionally not consulted; the lookup
    # was commented out upstream. Confirm before re-enabling it.
    SUFFIXES = ['.tif', '.png', '.jpg']

    def __init__(self, args, hr1_path, hr2_path, lab_path):
        """
        Args:
            args: Parsed options. Currently unused; kept for interface compatibility.
            hr1_path: Directory with the first-date images.
            hr2_path: Directory with the second-date images.
            lab_path: Directory with the label masks.
        """
        super().__init__()

        suffixes = list(self.SUFFIXES)
        print(f"[DEBUG] Using suffixes: {suffixes}")

        wanted = tuple(s.lower() for s in suffixes)
        # The set removes duplicates. Sorting makes the index -> file mapping
        # reproducible, because str hashing (and so set order) varies between runs.
        datalist = sorted({f for f in os.listdir(hr1_path) if f.lower().endswith(wanted)})
        print(f"[DEBUG] Found {len(datalist)} files")

        names = [x for x in datalist if is_image_file(x)]
        self.hr1_filenames = [join(hr1_path, x) for x in names]
        self.hr2_filenames = [join(hr2_path, x) for x in names]
        self.lab_filenames = [join(lab_path, x) for x in names]

        # Inputs: tensor in [-1, 1]. Labels: plain ToTensor (scaled to [0, 1]).
        self.transform = get_transform(convert=True, normalize=True)
        self.label_transform = get_transform()

    def _load_image(self, path):
        """Open *path*, convert it to RGB, apply the input transform, and close the file."""
        with Image.open(path) as img:
            return self.transform(img.convert('RGB'))

    def __getitem__(self, index):
        hr1_img = self._load_image(self.hr1_filenames[index])
        hr2_img = self._load_image(self.hr2_filenames[index])

        with Image.open(self.lab_filenames[index]) as lab_img:
            label = self.label_transform(lab_img)
        # Assumes a single-channel mask whose values map to {0, 1} after
        # ToTensor + long() (e.g. 0/255) — TODO confirm.
        # unsqueeze adds the batch dim make_one_hot expects; squeeze drops it.
        label = make_one_hot(label.unsqueeze(0).long(), 2).squeeze(0)

        return hr1_img, hr2_img, label

    def __len__(self):
        return len(self.hr1_filenames)
|
||
|
||
|
||
class TestDatasetFromFolder(Dataset):
    """Evaluation dataset yielding ``(image1, image2, one_hot_label, file_name)``.

    File names are listed from ``Time1_dir`` and filtered by extension. The same
    names are looked up in ``Time2_dir`` and ``Label_dir``.
    """

    # Used when ``suffixes`` is empty/None or lacks a ``suffix`` attribute.
    DEFAULT_SUFFIXES = ['.tif', '.png', '.jpg']

    @staticmethod
    def _resolve_suffixes(suffixes, default):
        """Normalise *suffixes* to a list of lower-case extensions.

        Accepts a list/tuple/set of extensions, a single extension string, or an
        object with a ``suffix`` attribute (such as an argparse ``Namespace``).
        Falls back to *default* when nothing usable is given.
        """
        if suffixes is not None and not isinstance(suffixes, (str, list, tuple, set)):
            suffixes = getattr(suffixes, 'suffix', None)
        if not suffixes:
            suffixes = default
        if isinstance(suffixes, str):
            suffixes = [suffixes]
        return [s.lower() for s in suffixes]

    def __init__(self, suffixes, Time1_dir, Time2_dir, Label_dir):
        """
        Args:
            suffixes: Extensions to accept (see ``_resolve_suffixes``). Callers
                may also pass the parsed ``args`` namespace.
            Time1_dir: Directory with the first-date images.
            Time2_dir: Directory with the second-date images.
            Label_dir: Directory with the label masks.
        """
        super(TestDatasetFromFolder, self).__init__()

        wanted = self._resolve_suffixes(suffixes, self.DEFAULT_SUFFIXES)
        # Sorted for a reproducible index -> file mapping (os.listdir order is arbitrary).
        datalist = sorted(
            name for name in os.listdir(Time1_dir)
            if os.path.splitext(name)[1].lower() in wanted and is_image_file(name)
        )

        self.image1_filenames = [join(Time1_dir, x) for x in datalist]
        self.image2_filenames = [join(Time2_dir, x) for x in datalist]
        self.image3_filenames = [join(Label_dir, x) for x in datalist]

        self.transform = get_transform(convert=True, normalize=True)  # convert to tensor and normalize to [-1,1]
        self.label_transform = get_transform()

    def _load_image(self, path):
        """Open *path*, convert it to RGB, apply the input transform, and close the file."""
        with Image.open(path) as img:
            return self.transform(img.convert('RGB'))

    def __getitem__(self, index):
        image1 = self._load_image(self.image1_filenames[index])
        image2 = self._load_image(self.image2_filenames[index])

        with Image.open(self.image3_filenames[index]) as lab_img:
            label = self.label_transform(lab_img)
        label = make_one_hot(label.unsqueeze(0).long(), 2).squeeze(0)

        # basename handles the OS path separator used by join(); split('/') does not on Windows.
        image_name = os.path.basename(self.image1_filenames[index])

        return image1, image2, label, image_name

    def __len__(self):
        return len(self.image1_filenames)
|
||
|
||
class TestDatasetFromFolderWithoutLabel(Dataset):
    """Inference dataset yielding ``(image1, image2, file_name)`` without labels.

    File names are listed from ``Time1_dir`` and filtered by extension. The same
    names are looked up in ``Time2_dir``.
    """

    # Used when ``suffixes`` is empty/None or lacks a ``suffix`` attribute.
    DEFAULT_SUFFIXES = ['.tif', '.png', '.jpg']

    @staticmethod
    def _resolve_suffixes(suffixes, default):
        """Normalise *suffixes* to a list of lower-case extensions.

        Accepts a list/tuple/set of extensions, a single extension string, or an
        object with a ``suffix`` attribute (such as an argparse ``Namespace``).
        Falls back to *default* when nothing usable is given.
        """
        if suffixes is not None and not isinstance(suffixes, (str, list, tuple, set)):
            suffixes = getattr(suffixes, 'suffix', None)
        if not suffixes:
            suffixes = default
        if isinstance(suffixes, str):
            suffixes = [suffixes]
        return [s.lower() for s in suffixes]

    def __init__(self, suffixes, Time1_dir, Time2_dir):
        """
        Args:
            suffixes: Extensions to accept (see ``_resolve_suffixes``). Callers
                may also pass the parsed ``args`` namespace.
            Time1_dir: Directory with the first-date images.
            Time2_dir: Directory with the second-date images.
        """
        super(TestDatasetFromFolderWithoutLabel, self).__init__()

        # Resolve from the parameter; the previous code read an undefined global `args`.
        wanted = self._resolve_suffixes(suffixes, self.DEFAULT_SUFFIXES)
        datalist = sorted(
            name for name in os.listdir(Time1_dir)
            if os.path.splitext(name)[1].lower() in wanted and is_image_file(name)
        )

        self.image1_filenames = [join(Time1_dir, x) for x in datalist]
        self.image2_filenames = [join(Time2_dir, x) for x in datalist]
        self.transform = get_transform(convert=True, normalize=True)

    def _load_image(self, path):
        """Open *path*, convert it to RGB, apply the input transform, and close the file."""
        with Image.open(path) as img:
            return self.transform(img.convert('RGB'))

    def __getitem__(self, index):
        image1 = self._load_image(self.image1_filenames[index])
        image2 = self._load_image(self.image2_filenames[index])
        image_name = os.path.basename(self.image1_filenames[index])
        return image1, image2, image_name  # image name instead of a label

    def __len__(self):
        return len(self.image1_filenames)
|
||
|
||
class trainImageAug(object):
    """Joint random crop / flip / rotation for an ``(image1, image2, mask)`` triple.

    All three inputs receive exactly the same geometric transform, which keeps
    the bi-temporal images aligned with their label mask.
    """

    def __init__(self, crop=True, augment=True, angle=30, crop_size=256):
        """
        Args:
            crop: Take a random ``crop_size`` x ``crop_size`` window.
            augment: Apply at most one random flip or rotation (see ``__call__``).
            angle: Rotation angle is drawn from [-angle, angle] degrees.
            crop_size: Side length of the square crop window (default 256).
        """
        self.crop = crop
        self.augment = augment
        self.angle = angle
        self.crop_size = crop_size

    @staticmethod
    def _offset(extent, size):
        """Random start in [0, extent - size), or 0 if the image is not larger than the window."""
        room = extent - size
        return np.random.randint(0, room) if room > 0 else 0

    def __call__(self, image1, image2, mask):
        """Apply the same random transforms to all three PIL images and return them."""
        if self.crop:
            # Offsets are derived from the actual size rather than assuming a
            # 512x512 input (for 512x512 this draws the same range, 0..255).
            width, height = image1.size
            left = self._offset(width, self.crop_size)
            top = self._offset(height, self.crop_size)
            box = (left, top, left + self.crop_size, top + self.crop_size)
            image1, image2, mask = [im.crop(box) for im in (image1, image2, mask)]

        if self.augment:
            prop = np.random.uniform(0, 1)
            if prop < 0.15:  # 15%: horizontal flip
                image1, image2, mask = [im.transpose(Image.FLIP_LEFT_RIGHT) for im in (image1, image2, mask)]
            elif prop < 0.3:  # 15%: vertical flip
                image1, image2, mask = [im.transpose(Image.FLIP_TOP_BOTTOM) for im in (image1, image2, mask)]
            elif prop < 0.5:  # 20%: rotation
                # Draw ONE angle for all three. Independent draws would rotate
                # the images and the mask by different amounts.
                degrees = transforms.RandomRotation.get_params([-self.angle, self.angle])
                image1, image2, mask = [im.rotate(degrees) for im in (image1, image2, mask)]
            # remaining 50%: no flip/rotation

        return image1, image2, mask
|
||
|
||
def get_transform(convert=True, normalize=False):
    """Return a ``transforms.Compose`` of optional ToTensor and Normalize steps.

    This definition overrides the identical earlier ``get_transform`` in this
    module, so it is the one the dataset classes actually call.

    Args:
        convert: Include ``ToTensor`` (PIL image -> float tensor in [0, 1]).
        normalize: Include ``Normalize(mean=0.5, std=0.5)`` over three channels,
            mapping [0, 1] to [-1, 1]. A ColorJitter step was once tried at this
            point and left disabled.

    Returns:
        transforms.Compose: ToTensor first (if enabled), then Normalize (if enabled).
    """
    to_tensor_steps = [transforms.ToTensor()] if convert else []
    normalize_steps = [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] if normalize else []
    return transforms.Compose(to_tensor_steps + normalize_steps)
|
||
|
||
|
||
class DA_DatasetFromFolder(Dataset):
    """Training dataset with joint augmentation, yielding ``(image1, image2, one_hot_label)``.

    File names are listed from ``Image_dir1``. The same names are looked up in
    ``Image_dir2`` and ``Label_dir``.
    """

    def __init__(self, Image_dir1, Image_dir2, Label_dir, crop=True, augment = True, angle = 30):
        """
        Args:
            Image_dir1: Directory with the first-date images.
            Image_dir2: Directory with the second-date images.
            Label_dir: Directory with the label masks.
            crop, augment, angle: Forwarded to ``trainImageAug``.
        """
        super(DA_DatasetFromFolder, self).__init__()
        # List the image files once. Sorting makes the index -> file mapping
        # reproducible (os.listdir order is arbitrary).
        datalist = sorted(x for x in os.listdir(Image_dir1) if is_image_file(x))
        self.image_filenames1 = [join(Image_dir1, x) for x in datalist]
        self.image_filenames2 = [join(Image_dir2, x) for x in datalist]
        self.label_filenames = [join(Label_dir, x) for x in datalist]
        self.data_augment = trainImageAug(crop=crop, augment = augment, angle=angle)
        self.img_transform = get_transform(convert=True, normalize=True)
        self.lab_transform = get_transform()

    @staticmethod
    def _read(path, mode=None):
        """Load *path* fully into memory and close the file.

        Returns the image converted to *mode*, or an in-memory copy in its
        original mode if *mode* is None.
        """
        with Image.open(path) as img:
            return img.convert(mode) if mode else img.copy()

    def __getitem__(self, index):
        image1 = self._read(self.image_filenames1[index], 'RGB')
        image2 = self._read(self.image_filenames2[index], 'RGB')
        label = self._read(self.label_filenames[index])

        # The same geometric augmentation is applied to all three images.
        image1, image2, label = self.data_augment(image1, image2, label)

        image1, image2 = self.img_transform(image1), self.img_transform(image2)
        label = self.lab_transform(label)
        label = make_one_hot(label.unsqueeze(0).long(), 2).squeeze(0)
        return image1, image2, label

    def __len__(self):
        return len(self.image_filenames1)
|
||
if __name__ == '__main__':
    import argparse

    def _parse_args():
        """Parse the command-line options this smoke test reads.

        ``parse_args`` was called here but is neither defined nor imported in
        this module, so a local parser exposing exactly the attributes used
        below is defined instead.
        """
        parser = argparse.ArgumentParser(description='Dataset construction smoke test.')
        parser.add_argument('--mode', default='train',
                            help="'train' builds LoadDatasetFromFolder; anything else builds TestDatasetFromFolder")
        parser.add_argument('--hr1_train', help='first-date training image directory')
        parser.add_argument('--hr2_train', help='second-date training image directory')
        parser.add_argument('--lab_train', help='training label directory')
        parser.add_argument('--path_img1', help='first-date test image directory')
        parser.add_argument('--path_img2', help='second-date test image directory')
        parser.add_argument('--path_lab', help='test label directory')
        parser.add_argument('--suffix', nargs='+', default=['.tif', '.png', '.jpg'],
                            help='image file extensions to accept')
        return parser.parse_args()

    # Parse arguments (suffix is always defined by the parser).
    args = _parse_args()

    # Debug print
    print(f"[MAIN] Args namespace: {vars(args)}")

    # Hard-coded fallback in case args comes from elsewhere without a suffix.
    if not hasattr(args, 'suffix'):
        args.suffix = ['.tif','.png','.jpg']

    # Build the dataset; report failures, then re-raise with the original traceback.
    try:
        if args.mode == 'train':
            dataset = LoadDatasetFromFolder(args,
                                            args.hr1_train,
                                            args.hr2_train,
                                            args.lab_train)
        else:
            dataset = TestDatasetFromFolder(args,
                                            args.path_img1,
                                            args.path_img2,
                                            args.path_lab)
    except Exception as e:
        print(f"初始化数据集失败: {str(e)}")
        raise