# ChangeDetection.py
"""Inference wrapper around the CDNet change-detection network."""
import itertools
import os

import numpy as np
import torch
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm

from cropland_module.data_utils import get_transform, calMetric_iou, TestDatasetFromFolder
from cropland_module.loss.losses import cross_entropy
from cropland_module.model.network import CDNet

# Alternative imports for running from inside the package directory (disabled):
# from model.network import CDNet
# from data_utils import DA_DatasetFromFolder, LoadDatasetFromFolder, TestDatasetFromFolder, calMetric_iou, \
#     TestDatasetFromFolderWithoutLabel, get_transform
# from loss.losses import cross_entropy


class ChangeDetectionModule:
    """Build a CDNet model and run change-detection inference with it.

    The constructor creates the network, a cross-entropy criterion and an Adam
    optimizer. ``load_model`` restores weights, ``predict`` evaluates a labelled
    test folder, and ``predict_without_label`` runs a single image pair.
    """

    # Attributes ``predict`` reads directly from the ``args`` namespace.
    _PREDICT_ARGS = ('path_img1', 'path_img2', 'path_lab', 'batch_size')

    def __init__(self, gpu_id, img_size, lr, args=None):
        """Create the model, loss and optimizer.

        Args:
            gpu_id: Value for ``CUDA_VISIBLE_DEVICES`` (e.g. ``"0"``, ``"0,1"`` or ``0``).
            img_size: Forwarded to ``CDNet(img_size=...)``.
            lr: Learning rate for the Adam optimizer.
            args: Optional configuration namespace used by ``predict``. May also be
                assigned later via ``self.args`` or passed to ``predict`` directly.
        """
        # Disabled argument-completeness check from an earlier args-based constructor:
        # required_args = ['gpu_id', 'img_size', 'lr']
        # for arg in required_args:
        #     if not hasattr(args, arg):
        #         raise ValueError(f"Missing required argument: {arg}")
        self.args = args

        # os.environ only accepts strings; str() lets callers pass an integer index.
        # NOTE(review): this presumably only takes effect if CUDA has not been
        # initialised in this process yet - confirm against callers.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.best_iou = 0.0

        # Initialize model
        self.model = CDNet(img_size=img_size).to(self.device, dtype=torch.float)
        gpu_count = torch.cuda.device_count()
        if gpu_count > 1:
            print(f"Using {gpu_count} GPUs!")
            self.model = torch.nn.DataParallel(self.model, device_ids=list(range(gpu_count)))

        # Loss and optimizer
        self.criterion = cross_entropy().to(self.device, dtype=torch.float)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr, betas=(0.9, 0.999))

    # ------------------------------------------------------------------
    # Legacy training code, disabled. It relies on an ``args`` namespace
    # (hr1_train, hr2_train, lab_train, num_workers, model_dir, ...) and on
    # dataset classes that are no longer imported above.
    # ------------------------------------------------------------------
    # def train(self):
    #     print("Starting training...")
    #     train_set = DA_DatasetFromFolder(self.args.hr1_train, self.args.hr2_train, self.args.lab_train, crop=False)
    #     val_set = LoadDatasetFromFolder(self.args, self.args.hr1_val, self.args.hr2_val, self.args.lab_val)
    #     train_loader = DataLoader(train_set, num_workers=self.args.num_workers,
    #                               batch_size=self.args.batchsize, shuffle=True)
    #     val_loader = DataLoader(val_set, num_workers=self.args.num_workers,
    #                             batch_size=self.args.val_batchsize, shuffle=False)
    #
    #     for epoch in range(1, self.args.num_epochs + 1):
    #         self.model.train()
    #         train_bar = tqdm(train_loader, desc=f"Epoch [{epoch}/{self.args.num_epochs}]")
    #         for hr_img1, hr_img2, label in train_bar:
    #             hr_img1 = hr_img1.to(self.device, dtype=torch.float)
    #             hr_img2 = hr_img2.to(self.device, dtype=torch.float)
    #             label = label.to(self.device, dtype=torch.float)
    #             label = torch.argmax(label, 1).unsqueeze(1).float()
    #
    #             self.optimizer.zero_grad()
    #             out1, out2, out3 = self.model(hr_img1, hr_img2)
    #             cd_loss = (self.criterion(out1, label) +
    #                        self.criterion(out2, label) +
    #                        self.criterion(out3, label))
    #             cd_loss.backward()
    #             self.optimizer.step()
    #             train_bar.set_postfix(loss=cd_loss.item())
    #
    #         # Evaluate only on best epoch (IoU improved)
    #         val_iou = self.validate(val_loader)
    #         print(f"Validation IoU: {val_iou:.4f}")
    #
    #         if val_iou > self.best_iou:
    #             self.best_iou = val_iou
    #             self._save_best_model()
    #             print(f"New best model saved with IoU: {val_iou:.4f}")

    # def validate(self, val_loader):
    #     self.model.eval()
    #     inter, union = 0, 0
    #     with torch.no_grad():
    #         for hr_img1, hr_img2, label in tqdm(val_loader, desc='Validating'):
    #             hr_img1 = hr_img1.to(self.device, dtype=torch.float)
    #             hr_img2 = hr_img2.to(self.device, dtype=torch.float)
    #             label = label.to(self.device, dtype=torch.float)
    #             label = torch.argmax(label, 1).unsqueeze(1).float()
    #
    #             output, _, _ = self.model(hr_img1, hr_img2)
    #             pred = torch.argmax(output, 1).unsqueeze(1).float()
    #             gt = (label > 0).float()
    #             prob = (pred > 0).float()
    #
    #             gt_np = np.squeeze(gt.cpu().detach().numpy())
    #             pred_np = np.squeeze(prob.cpu().detach().numpy())
    #             intr, unn = calMetric_iou(gt_np, pred_np)
    #             inter += intr
    #             union += unn
    #
    #     iou = inter / union if union != 0 else 0
    #     return iou

    # def _save_best_model(self):
    #     if not os.path.exists(self.args.model_dir):
    #         os.makedirs(self.args.model_dir)
    #     # Remove any existing .pth files
    #     for f in os.listdir(self.args.model_dir):
    #         if f.endswith('.pth'):
    #             os.remove(os.path.join(self.args.model_dir, f))
    #     save_path = os.path.join(self.args.model_dir, 'best_model.pth')
    #     torch.save(self.model.state_dict(), save_path)
    #     print(f"Best model saved to {save_path}")

    def _match_data_parallel_prefix(self, state_dict):
        """Return ``state_dict`` with keys adapted to the DataParallel state of the model.

        ``torch.nn.DataParallel`` prefixes every parameter name with ``module.``.
        The keys are rewritten only when none of them match the live model, so a
        checkpoint that would already load is passed through untouched.
        """
        prefix = 'module.'
        if set(self.model.state_dict()) & set(state_dict):
            return state_dict  # already uses the model's naming scheme

        if isinstance(self.model, torch.nn.DataParallel):
            return {prefix + key: value for key, value in state_dict.items()}
        if state_dict and all(key.startswith(prefix) for key in state_dict):
            return {key[len(prefix):]: value for key, value in state_dict.items()}
        # Unknown layout: let load_state_dict report the mismatch.
        return state_dict

    def load_model(self, model_path):
        """Load weights from ``model_path`` into the model and switch it to eval mode.

        Args:
            model_path: Path to a file that ``torch.load`` resolves to a state dict.
        """
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(self._match_data_parallel_prefix(state_dict))
        self.model.eval()
        print(f"Loaded model from {model_path}")

    def predict(self, save_dir, args=None, num_workers=24):
        """Run inference on a labelled test set, save binary masks and report IoU.

        Args:
            save_dir: Directory the predicted masks are written to (created if missing).
            args: Configuration namespace providing ``path_img1``, ``path_img2``,
                ``path_lab`` and ``batch_size``; it is also handed to
                ``TestDatasetFromFolder``. Defaults to ``self.args``.
            num_workers: Worker count for the ``DataLoader``.

        Returns:
            The accumulated IoU over the test set (0 when the union is empty).

        Raises:
            ValueError: If no configuration is available or it lacks a required attribute.
        """
        print("Starting prediction...")
        if args is None:
            args = self.args
        if args is None:
            raise ValueError(
                "predict() needs a configuration: pass `args` or set `self.args` "
                f"(required attributes: {', '.join(self._PREDICT_ARGS)})"
            )
        missing = [name for name in self._PREDICT_ARGS if not hasattr(args, name)]
        if missing:
            raise ValueError(f"Missing required argument(s): {', '.join(missing)}")

        os.makedirs(save_dir, exist_ok=True)

        test_set = TestDatasetFromFolder(args, args.path_img1, args.path_img2, args.path_lab)
        test_loader = DataLoader(test_set, num_workers=num_workers,
                                 batch_size=args.batch_size, shuffle=False)

        self.model.eval()
        inter, union = 0, 0
        with torch.no_grad():
            for image1, image2, label, image_names in tqdm(test_loader, desc='Testing'):
                image1 = image1.to(self.device, dtype=torch.float)
                image2 = image2.to(self.device, dtype=torch.float)
                label = label.to(self.device, dtype=torch.float)

                output, _, _ = self.model(image1, image2)
                pred = torch.argmax(output, 1).unsqueeze(1)
                label = torch.argmax(label, 1).unsqueeze(1)

                # Binarise and move the whole batch to the host once, not per sample.
                gt_batch = (label > 0).float().cpu().numpy()
                prob_batch = (pred > 0).float().cpu().numpy()

                for gt_sample, prob_sample, image_name in zip(gt_batch, prob_batch, image_names):
                    gt_np = np.squeeze(gt_sample)
                    prob_np = np.squeeze(prob_sample)
                    intr, unn = calMetric_iou(gt_np, prob_np)
                    inter += intr
                    union += unn

                    # Map the 0/1 prediction to 0/255 so it is visible as an image.
                    binary_result = np.where(prob_np > 0.5, 255, 0).astype('uint8')
                    result = Image.fromarray(binary_result)
                    # Write into the directory created above.
                    result.save(os.path.join(save_dir, os.path.basename(image_name)))

        iou = inter / union if union != 0 else 0
        print(f"Test IoU: {iou:.4f}")
        return iou

    # Earlier folder-based variant of predict_without_label, disabled:
    # def predict_without_label(self):
    #     print("Starting prediction...")
    #     if not os.path.exists(self.args.save_dir):
    #         os.makedirs(self.args.save_dir)
    #
    #     # Use the dataset without labels
    #     test_set = TestDatasetFromFolderWithoutLabel(self.args, self.args.path_img1, self.args.path_img2)
    #     test_loader = DataLoader(test_set, num_workers=24, batch_size=self.args.batch_size, shuffle=False)
    #
    #     self.model.eval()
    #     with torch.no_grad():
    #         for image1, image2, image_names in tqdm(test_loader, desc='Testing'):
    #             image1 = image1.to(self.device, dtype=torch.float)
    #             image2 = image2.to(self.device, dtype=torch.float)
    #
    #             output, _, _ = self.model(image1, image2)
    #             pred = torch.argmax(output, 1).unsqueeze(1)
    #             prob_np = pred.squeeze().cpu().detach().numpy()
    #
    #             # Binarise and save the result
    #             binary_result = np.where(prob_np > 0.5, 255, 0).astype('uint8')
    #             result = Image.fromarray(binary_result)
    #             result.save(os.path.join(self.args.save_dir, image_names))

    def predict_without_label(self, img1, img2):
        """Predict the change mask for a single pair of images.

        Args:
            img1: First image; must provide ``convert('RGB')`` (e.g. a PIL image).
            img2: Second image; same requirements as ``img1``.

        Returns:
            tuple: ``(output, binary_result)`` where ``output`` is the raw first
            model output (``torch.Tensor``, before argmax, batch dimension kept)
            and ``binary_result`` is a ``uint8`` NumPy array holding 0 or 255.
        """
        print("Starting prediction...")

        transform = get_transform(convert=True, normalize=True)
        # Cast to float like predict() does so both entry points feed the model the same dtype.
        image1 = transform(img1.convert('RGB')).to(self.device, dtype=torch.float)
        image2 = transform(img2.convert('RGB')).to(self.device, dtype=torch.float)

        # Disabled ndarray input path from an earlier version:
        # if isinstance(image1, np.ndarray):
        #     image1 = torch.from_numpy(image1).to(self.device, dtype=torch.float)
        # if isinstance(image2, np.ndarray):
        #     image2 = torch.from_numpy(image2).to(self.device, dtype=torch.float)

        # Model inference
        self.model.eval()
        with torch.no_grad():
            # The model is called on batches, so add a batch dimension of 1.
            output, _, _ = self.model(image1.unsqueeze(0), image2.unsqueeze(0))
            pred = torch.argmax(output, 1).squeeze(0)  # drop the batch dimension again
            prob_np = pred.cpu().numpy()

        # Binarise: class index > 0 becomes 255, background stays 0.
        binary_result = np.where(prob_np > 0.5, 255, 0).astype('uint8')

        return output, binary_result