#coding=utf-8
from os.path import join
import torch
from PIL import Image, ImageEnhance
from torch.utils.data.dataset import Dataset
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize
import numpy as np
import torchvision.transforms as transforms
import os
from argparse import Namespace
import imageio
def is_image_file(filename):
    return any(filename.endswith(extension) for extension in ['.tif', '.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG'])
def calMetric_iou(predict, label):
    # Returns (intersection, union) of the positive class so that IoU can be
    # accumulated over many samples as sum(intersection) / sum(union).
    tp = np.sum(np.logical_and(predict == 1, label == 1))
    fp = np.sum(predict == 1)
    fn = np.sum(label == 1)
    return tp, fp + fn - tp
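# Usage sketch (hypothetical loop, not part of this module): accumulate the two
# returned counts over a validation set and divide once at the end to get the
# change-class IoU.
#   inter, union = 0, 0
#   for pred, gt in samples:                        # pred/gt are {0, 1} numpy arrays
#       tp, un = calMetric_iou(pred, gt)
#       inter, union = inter + tp, union + un
#   iou = inter / union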
def getDataList(img_path):
    # Read a text file that lists one sample path per line.
    datalist = []
    with open(img_path, 'r') as f:
        for line in f:
            datalist.append(line.strip('\n'))
    return datalist
def make_one_hot(input, num_classes):
    """Convert a class-index tensor to a one-hot encoded tensor.
    Args:
        input: A tensor of shape [N, 1, *]
        num_classes: An int, the number of classes
    Returns:
        A tensor of shape [N, num_classes, *]
    """
    shape = np.array(input.shape)
    shape[1] = num_classes
    shape = tuple(shape)
    result = torch.zeros(shape)
    result = result.scatter_(1, input.cpu(), 1)
    return result
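# Shape sketch for make_one_hot (hypothetical tensor): a [N, 1, H, W] index map
# with values in {0, 1} becomes a [N, 2, H, W] one-hot tensor.
#   idx = torch.randint(0, 2, (4, 1, 256, 256))
#   onehot = make_one_hot(idx, num_classes=2)       # -> torch.Size([4, 2, 256, 256])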
def get_transform(convert=True, normalize=False):
    transform_list = []
    if convert:
        transform_list += [transforms.ToTensor()]
    if normalize:
        transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
                                                (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)
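# With normalize=True the composed transform is ToTensor() (uint8 pixels -> floats
# in [0, 1]) followed by Normalize(0.5, 0.5), i.e. x -> (x - 0.5) / 0.5, so images
# end up in [-1, 1]. A quick sanity check (hypothetical input, not from this repo):
#   t = get_transform(convert=True, normalize=True)
#   white = Image.new('RGB', (8, 8), (255, 255, 255))
#   assert t(white).max().item() == 1.0             # 255 -> 1.0 -> (1.0 - 0.5) / 0.5 = 1.0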
class LoadDatasetFromFolder(Dataset):
    def __init__(self, args, hr1_path, hr2_path, lab_path):
        super().__init__()
        # Get the suffix list safely, with a default fallback.
        # suffixes = getattr(args, 'suffix', ['.tif', '.png', '.jpg'])
        suffixes = ['.tif', '.png', '.jpg']
        if not isinstance(suffixes, list):
            suffixes = [suffixes]
        # Debug print
        print(f"[DEBUG] Using suffixes: {suffixes}")
        # Build the file list (supports several suffix formats)
        datalist = []
        for suffix in suffixes:
            datalist.extend([
                f for f in os.listdir(hr1_path)
                if f.lower().endswith(suffix.lower())
            ])
        # Remove duplicates
        datalist = list(set(datalist))
        print(f"[DEBUG] Found {len(datalist)} files")
        self.hr1_filenames = [join(hr1_path, x) for x in datalist if is_image_file(x)]
        self.hr2_filenames = [join(hr2_path, x) for x in datalist if is_image_file(x)]
        self.lab_filenames = [join(lab_path, x) for x in datalist if is_image_file(x)]
        self.transform = get_transform(convert=True, normalize=True)
        self.label_transform = get_transform()

    def __getitem__(self, index):
        hr1_img = self.transform(Image.open(self.hr1_filenames[index]).convert('RGB'))
        # lr2_img = self.transform(Image.open(self.lr2_filenames[index]).convert('RGB'))
        hr2_img = self.transform(Image.open(self.hr2_filenames[index]).convert('RGB'))
        label = self.label_transform(Image.open(self.lab_filenames[index]))
        label = make_one_hot(label.unsqueeze(0).long(), 2).squeeze(0)
        return hr1_img, hr2_img, label

    def __len__(self):
        return len(self.hr1_filenames)
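# Usage sketch (paths and batch size are placeholders, not from this project):
#   opt = Namespace(suffix=['.png'])
#   train_set = LoadDatasetFromFolder(opt, 'data/A', 'data/B', 'data/label')
#   loader = torch.utils.data.DataLoader(train_set, batch_size=4, shuffle=True)
#   img1, img2, lab = next(iter(loader))            # lab is a [4, 2, H, W] one-hot tensor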
class TestDatasetFromFolder(Dataset):
    def __init__(self, suffixes, Time1_dir, Time2_dir, Label_dir):
        super(TestDatasetFromFolder, self).__init__()
        # Fall back to default suffixes if none are given.
        if not suffixes:
            suffixes = ['.png', '.jpg']
        if not isinstance(suffixes, list):
            suffixes = [suffixes]
        # Keep only files whose extension matches one of the suffixes.
        datalist = [
            name for name in os.listdir(Time1_dir)
            if os.path.splitext(name)[1].lower() in suffixes
        ]
        print(f"Found {len(datalist)} files: {datalist[:5]}...")  # debug output
        self.image1_filenames = [join(Time1_dir, x) for x in datalist if is_image_file(x)]
        self.image2_filenames = [join(Time2_dir, x) for x in datalist if is_image_file(x)]
        self.image3_filenames = [join(Label_dir, x) for x in datalist if is_image_file(x)]
        self.transform = get_transform(convert=True, normalize=True)  # convert to tensor and normalize to [-1, 1]
        self.label_transform = get_transform()

    def __getitem__(self, index):
        image1 = self.transform(Image.open(self.image1_filenames[index]).convert('RGB'))
        image2 = self.transform(Image.open(self.image2_filenames[index]).convert('RGB'))
        label = self.label_transform(Image.open(self.image3_filenames[index]))
        label = make_one_hot(label.unsqueeze(0).long(), 2).squeeze(0)
        image_name = os.path.basename(self.image1_filenames[index])
        return image1, image2, label, image_name

    def __len__(self):
        return len(self.image1_filenames)
class TestDatasetFromFolderWithoutLabel(Dataset):
    def __init__(self, suffixes, Time1_dir, Time2_dir):
        super(TestDatasetFromFolderWithoutLabel, self).__init__()
        # Use the suffixes passed in, falling back to a default set.
        if not suffixes:
            suffixes = ['.png', '.jpg']
        if not isinstance(suffixes, list):
            suffixes = [suffixes]
        datalist = [
            name for name in os.listdir(Time1_dir)
            if os.path.splitext(name)[1].lower() in suffixes
        ]
        self.image1_filenames = [join(Time1_dir, x) for x in datalist if is_image_file(x)]
        self.image2_filenames = [join(Time2_dir, x) for x in datalist if is_image_file(x)]
        self.transform = get_transform(convert=True, normalize=True)

    def __getitem__(self, index):
        image1 = self.transform(Image.open(self.image1_filenames[index]).convert('RGB'))
        image2 = self.transform(Image.open(self.image2_filenames[index]).convert('RGB'))
        image_name = self.image1_filenames[index].split('/')[-1]
        return image1, image2, image_name  # return the image name instead of a label

    def __len__(self):
        return len(self.image1_filenames)
class trainImageAug(object):
    def __init__(self, crop=True, augment=True, angle=30):
        self.crop = crop
        self.augment = augment
        self.angle = angle

    def __call__(self, image1, image2, mask):
        if self.crop:
            # Random 256x256 crop; assumes the inputs are large enough (e.g. 512x512).
            w = np.random.randint(0, 256)
            h = np.random.randint(0, 256)
            box = (w, h, w + 256, h + 256)
            image1 = image1.crop(box)
            image2 = image2.crop(box)
            mask = mask.crop(box)
        if self.augment:
            prop = np.random.uniform(0, 1)
            if prop < 0.15:
                image1 = image1.transpose(Image.FLIP_LEFT_RIGHT)
                image2 = image2.transpose(Image.FLIP_LEFT_RIGHT)
                mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
            elif prop < 0.3:
                image1 = image1.transpose(Image.FLIP_TOP_BOTTOM)
                image2 = image2.transpose(Image.FLIP_TOP_BOTTOM)
                mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
            elif prop < 0.5:
                # Draw one angle and reuse it so both images and the mask stay aligned.
                angle = transforms.RandomRotation.get_params([-self.angle, self.angle])
                image1 = image1.rotate(angle)
                image2 = image2.rotate(angle)
                mask = mask.rotate(angle)
        return image1, image2, mask
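# Usage sketch for trainImageAug (file names are placeholders; assumes inputs of
# roughly 512x512 so the fixed 256x256 crop always fits): the same crop, flip and
# rotation are applied jointly so the image pair and mask stay pixel-aligned.
#   aug = trainImageAug(crop=True, augment=True, angle=30)
#   im1, im2, m = aug(Image.open('t1.png'), Image.open('t2.png'), Image.open('m.png'))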
# Note: this redefines the get_transform above; being later in the module, it is
# the version in effect once the file has been imported.
def get_transform(convert=True, normalize=False):
    transform_list = []
    if convert:
        transform_list += [
            transforms.ToTensor(),
        ]
    if normalize:
        transform_list += [
            # transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)
class DA_DatasetFromFolder(Dataset):
    def __init__(self, Image_dir1, Image_dir2, Label_dir, crop=True, augment=True, angle=30):
        super(DA_DatasetFromFolder, self).__init__()
        # Build the image list shared by both time phases and the labels.
        datalist = os.listdir(Image_dir1)
        self.image_filenames1 = [join(Image_dir1, x) for x in datalist if is_image_file(x)]
        self.image_filenames2 = [join(Image_dir2, x) for x in datalist if is_image_file(x)]
        self.label_filenames = [join(Label_dir, x) for x in datalist if is_image_file(x)]
        self.data_augment = trainImageAug(crop=crop, augment=augment, angle=angle)
        self.img_transform = get_transform(convert=True, normalize=True)
        self.lab_transform = get_transform()

    def __getitem__(self, index):
        image1 = Image.open(self.image_filenames1[index]).convert('RGB')
        image2 = Image.open(self.image_filenames2[index]).convert('RGB')
        label = Image.open(self.label_filenames[index])
        image1, image2, label = self.data_augment(image1, image2, label)
        image1, image2 = self.img_transform(image1), self.img_transform(image2)
        label = self.lab_transform(label)
        label = make_one_hot(label.unsqueeze(0).long(), 2).squeeze(0)
        return image1, image2, label

    def __len__(self):
        return len(self.image_filenames1)
if __name__ == '__main__':
    # Parse arguments; parse_args() is expected to be provided elsewhere in the
    # project (it is not defined in this file) and should include a 'suffix' field.
    args = parse_args()
    # Debug print
    print(f"[MAIN] Args namespace: {vars(args)}")
    # Make sure the required parameter exists
    if not hasattr(args, 'suffix'):
        args.suffix = ['.tif', '.png', '.jpg']  # hard-coded fallback
    # Initialise the dataset
    try:
        if args.mode == 'train':
            dataset = LoadDatasetFromFolder(args,
                                            args.hr1_train,
                                            args.hr2_train,
                                            args.lab_train)
        else:
            dataset = TestDatasetFromFolder(args.suffix,
                                            args.path_img1,
                                            args.path_img2,
                                            args.path_lab)
    except Exception as e:
        print(f"Failed to initialise the dataset: {str(e)}")
        raise
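    # Minimal smoke test (a sketch; only meaningful when the paths in args exist):
    # pull one batch through a DataLoader and print the tensor shapes.
    # batch_size=2 is an arbitrary choice for the check.
    loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=False, num_workers=0)
    batch = next(iter(loader))
    print('[MAIN] Batch tensor shapes:', [tuple(t.shape) for t in batch if torch.is_tensor(t)])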