222 lines
10 KiB
Python
222 lines
10 KiB
Python
# ChangeDetection.py
|
||
|
||
import os
|
||
import torch
|
||
import numpy as np
|
||
from tqdm import tqdm
|
||
import itertools
|
||
import torch.optim as optim
|
||
from torch.utils.data import DataLoader
|
||
|
||
from cropland_module.data_utils import get_transform, calMetric_iou, TestDatasetFromFolder
|
||
from cropland_module.loss.losses import cross_entropy
|
||
from cropland_module.model.network import CDNet
|
||
# from model.network import CDNet
|
||
# from data_utils import DA_DatasetFromFolder, LoadDatasetFromFolder, TestDatasetFromFolder, calMetric_iou, \
|
||
# TestDatasetFromFolderWithoutLabel, get_transform
|
||
# from loss.losses import cross_entropy
|
||
from PIL import Image
|
||
|
||
class ChangeDetectionModule:
    """Wrap a CDNet change-detection model with checkpoint loading and inference.

    The constructor also builds the cross-entropy criterion and the Adam
    optimizer, but only inference entry points (``predict`` and
    ``predict_without_label``) are implemented in this class.
    """

    def __init__(self, gpu_id, img_size, lr, args=None):
        """Build the model, loss criterion and optimizer.

        Args:
            gpu_id: Value written to ``CUDA_VISIBLE_DEVICES`` (e.g. ``"0"`` or
                ``"0,1"``); non-string values are converted with ``str``.
            img_size: Input image size forwarded to ``CDNet``.
            lr: Learning rate of the Adam optimizer.
            args: Optional configuration object used by :meth:`predict`. It must
                provide ``path_img1``, ``path_img2``, ``path_lab`` and
                ``batch_size`` and is forwarded to ``TestDatasetFromFolder``.
        """
        # os.environ only accepts strings. NOTE: this only influences device
        # selection if CUDA has not been initialised yet in this process.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.best_iou = 0.0
        self.args = args

        # Initialize model (wrapped in DataParallel when several GPUs are visible).
        self.model = CDNet(img_size=img_size).to(self.device, dtype=torch.float)
        gpu_count = torch.cuda.device_count()
        if gpu_count > 1:
            print(f"Using {gpu_count} GPUs!")
            self.model = torch.nn.DataParallel(self.model, device_ids=range(gpu_count))

        # Loss and optimizer
        self.criterion = cross_entropy().to(self.device, dtype=torch.float)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr, betas=(0.9, 0.999))

    def load_model(self, model_path):
        """Load weights from *model_path* and switch the model to eval mode.

        A checkpoint saved from a ``DataParallel``-wrapped model has every key
        prefixed with ``"module."``. That prefix is stripped and the weights are
        loaded into the unwrapped network. This lets a checkpoint be reused
        regardless of how many GPUs were visible when it was saved or loaded.
        """
        from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

        state_dict = torch.load(model_path, map_location=self.device)
        consume_prefix_in_state_dict_if_present(state_dict, "module.")
        if isinstance(self.model, torch.nn.DataParallel):
            network = self.model.module
        else:
            network = self.model
        network.load_state_dict(state_dict)
        self.model.eval()
        print(f"Loaded model from {model_path}")

    def predict(self, save_dir, args=None, *, num_workers=24):
        """Run the model over a labelled test set, save binary masks and report IoU.

        Args:
            save_dir: Directory receiving one ``uint8`` 0/255 mask per test
                image, named after the basename of that image. It is created
                if missing.
            args: Configuration object providing ``path_img1``, ``path_img2``,
                ``path_lab`` and ``batch_size``. Defaults to the ``args`` passed
                to the constructor.
            num_workers: Number of ``DataLoader`` worker processes.

        Returns:
            Intersection-over-union accumulated over the whole test set
            (0 when the accumulated union is 0).

        Raises:
            ValueError: If no configuration object is available.
        """
        print("Starting prediction...")
        if args is None:
            args = getattr(self, "args", None)
        if args is None:
            raise ValueError(
                "predict() needs a configuration object: pass `args` here or to the constructor"
            )
        os.makedirs(save_dir, exist_ok=True)

        test_set = TestDatasetFromFolder(args, args.path_img1, args.path_img2, args.path_lab)
        test_loader = DataLoader(test_set, num_workers=num_workers,
                                 batch_size=args.batch_size, shuffle=False)

        self.model.eval()
        inter, union = 0, 0
        with torch.no_grad():
            for image1, image2, label, image_names in tqdm(test_loader, desc='Testing'):
                image1 = image1.to(self.device, dtype=torch.float)
                image2 = image2.to(self.device, dtype=torch.float)
                label = label.to(self.device, dtype=torch.float)

                output, _, _ = self.model(image1, image2)
                # Reduce dim 1 to per-pixel class indices for both the prediction
                # and the label (labels presumably one-hot along dim 1 — confirm).
                pred = torch.argmax(output, 1).unsqueeze(1)
                label = torch.argmax(label, 1).unsqueeze(1)

                for i in range(pred.size(0)):
                    gt_np = np.squeeze((label[i] > 0).float().cpu().numpy())
                    prob_np = np.squeeze((pred[i] > 0).float().cpu().numpy())
                    intr, unn = calMetric_iou(gt_np, prob_np)
                    inter += intr
                    union += unn

                    # Map the 0/1 change mask to 0/255 so it is viewable as an image.
                    binary_result = np.where(prob_np > 0.5, 255, 0).astype('uint8')
                    Image.fromarray(binary_result).save(
                        os.path.join(save_dir, os.path.basename(image_names[i]))
                    )

        iou = inter / union if union != 0 else 0
        print(f"Test IoU: {iou:.4f}")
        return iou

    def predict_without_label(self, img1, img2):
        """Predict a change mask for a single pair of images.

        Args:
            img1: First image; must support ``convert('RGB')`` (e.g. a
                ``PIL.Image.Image``). It is transformed with
                ``get_transform(convert=True, normalize=True)``.
            img2: Second image, with the same requirements as *img1*.

        Returns:
            A tuple ``(output, binary_result)``:

            - ``output`` is the first model output for the batch of one, before
              ``argmax``, on ``self.device``.
            - ``binary_result`` is a ``uint8`` NumPy array that is 255 where the
              predicted class index is non-zero and 0 elsewhere.
        """
        print("Starting prediction...")
        transform = get_transform(convert=True, normalize=True)
        image1 = transform(img1.convert('RGB')).to(self.device)
        image2 = transform(img2.convert('RGB')).to(self.device)

        self.model.eval()
        with torch.no_grad():
            # Add a batch axis of size 1 for the model, then drop it from the result.
            output, _, _ = self.model(image1.unsqueeze(0), image2.unsqueeze(0))
            pred = torch.argmax(output, 1).squeeze(0)
            prob_np = pred.cpu().numpy()

        # Binarize: any non-zero class index becomes 255.
        binary_result = np.where(prob_np > 0.5, 255, 0).astype('uint8')
        return output, binary_result