# File: ai_project_v1/cropland_module/change_detection_module.py
# Repository-viewer metadata: 222 lines, 10 KiB, Python.
#
# NOTE: the repository viewer flagged this file as containing ambiguous Unicode
# characters, i.e. characters that might be confused with other characters.
# If that is intentional, the warning can safely be ignored.

# ChangeDetection.py
import os
import torch
import numpy as np
from tqdm import tqdm
import itertools
import torch.optim as optim
from torch.utils.data import DataLoader
from cropland_module.data_utils import get_transform, calMetric_iou, TestDatasetFromFolder
from cropland_module.loss.losses import cross_entropy
from cropland_module.model.network import CDNet
# from model.network import CDNet
# from data_utils import DA_DatasetFromFolder, LoadDatasetFromFolder, TestDatasetFromFolder, calMetric_iou, \
# TestDatasetFromFolderWithoutLabel, get_transform
# from loss.losses import cross_entropy
from PIL import Image
class ChangeDetectionModule:
    """Inference wrapper around ``CDNet`` for two-image change detection.

    The constructor builds the network, the loss and an Adam optimizer.
    ``load_model`` restores weights from a checkpoint, ``predict`` runs a
    labelled test folder and reports IoU, and ``predict_without_label``
    predicts a change mask for a single image pair.
    """

    def __init__(self, gpu_id, img_size, lr, args=None):
        """Build the model, loss and optimizer.

        Args:
            gpu_id: Value for ``CUDA_VISIBLE_DEVICES`` (e.g. ``"0"`` or ``0``).
                ``None`` leaves the environment untouched.
            img_size: Forwarded to ``CDNet(img_size=...)``.
            lr: Learning rate for the Adam optimizer.
            args: Optional configuration object used by ``predict``. It must
                expose ``path_img1``, ``path_img2``, ``path_lab`` and
                ``batch_size``. It may also be assigned later through
                ``self.args``.
        """
        self.args = args
        if gpu_id is not None:
            # Environment values must be strings; an int id would raise TypeError.
            # NOTE: this only has an effect if CUDA has not been initialised
            # yet in the current process.
            os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.best_iou = 0.0

        # Initialize model
        self.model = CDNet(img_size=img_size).to(self.device, dtype=torch.float)
        gpu_count = torch.cuda.device_count()
        if gpu_count > 1:
            print(f"Using {gpu_count} GPUs!")
            self.model = torch.nn.DataParallel(self.model, device_ids=list(range(gpu_count)))

        # Loss and optimizer
        self.criterion = cross_entropy().to(self.device, dtype=torch.float)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr, betas=(0.9, 0.999))

    # ------------------------------------------------------------------
    # Disabled legacy training code, kept for reference only.
    # It relies on ``self.args`` and on dataset classes whose imports are
    # commented out at the top of this file.
    # ------------------------------------------------------------------
    # def train(self):
    #     print("Starting training...")
    #     train_set = DA_DatasetFromFolder(self.args.hr1_train, self.args.hr2_train, self.args.lab_train, crop=False)
    #     val_set = LoadDatasetFromFolder(self.args, self.args.hr1_val, self.args.hr2_val, self.args.lab_val)
    #     train_loader = DataLoader(train_set, num_workers=self.args.num_workers,
    #                               batch_size=self.args.batchsize, shuffle=True)
    #     val_loader = DataLoader(val_set, num_workers=self.args.num_workers,
    #                             batch_size=self.args.val_batchsize, shuffle=False)
    #
    #     for epoch in range(1, self.args.num_epochs + 1):
    #         self.model.train()
    #         train_bar = tqdm(train_loader, desc=f"Epoch [{epoch}/{self.args.num_epochs}]")
    #         for hr_img1, hr_img2, label in train_bar:
    #             hr_img1 = hr_img1.to(self.device, dtype=torch.float)
    #             hr_img2 = hr_img2.to(self.device, dtype=torch.float)
    #             label = label.to(self.device, dtype=torch.float)
    #             label = torch.argmax(label, 1).unsqueeze(1).float()
    #
    #             self.optimizer.zero_grad()
    #             out1, out2, out3 = self.model(hr_img1, hr_img2)
    #             cd_loss = (self.criterion(out1, label) +
    #                        self.criterion(out2, label) +
    #                        self.criterion(out3, label))
    #             cd_loss.backward()
    #             self.optimizer.step()
    #             train_bar.set_postfix(loss=cd_loss.item())
    #
    #         # Evaluate only on best epoch (IoU improved)
    #         val_iou = self.validate(val_loader)
    #         print(f"Validation IoU: {val_iou:.4f}")
    #
    #         if val_iou > self.best_iou:
    #             self.best_iou = val_iou
    #             self._save_best_model()
    #             print(f"New best model saved with IoU: {val_iou:.4f}")
    #
    # def validate(self, val_loader):
    #     self.model.eval()
    #     inter, union = 0, 0
    #     with torch.no_grad():
    #         for hr_img1, hr_img2, label in tqdm(val_loader, desc='Validating'):
    #             hr_img1 = hr_img1.to(self.device, dtype=torch.float)
    #             hr_img2 = hr_img2.to(self.device, dtype=torch.float)
    #             label = label.to(self.device, dtype=torch.float)
    #             label = torch.argmax(label, 1).unsqueeze(1).float()
    #
    #             output, _, _ = self.model(hr_img1, hr_img2)
    #             pred = torch.argmax(output, 1).unsqueeze(1).float()
    #             gt = (label > 0).float()
    #             prob = (pred > 0).float()
    #
    #             gt_np = np.squeeze(gt.cpu().detach().numpy())
    #             pred_np = np.squeeze(prob.cpu().detach().numpy())
    #             intr, unn = calMetric_iou(gt_np, pred_np)
    #             inter += intr
    #             union += unn
    #
    #     iou = inter / union if union != 0 else 0
    #     return iou
    #
    # def _save_best_model(self):
    #     if not os.path.exists(self.args.model_dir):
    #         os.makedirs(self.args.model_dir)
    #     # Remove any existing .pth files
    #     for f in os.listdir(self.args.model_dir):
    #         if f.endswith('.pth'):
    #             os.remove(os.path.join(self.args.model_dir, f))
    #     save_path = os.path.join(self.args.model_dir, 'best_model.pth')
    #     torch.save(self.model.state_dict(), save_path)
    #     print(f"Best model saved to {save_path}")

    @staticmethod
    def _toggle_parallel_prefix(state_dict):
        """Return a copy of ``state_dict`` with the ``module.`` key prefix toggled.

        ``torch.nn.DataParallel`` stores the wrapped network under the attribute
        ``module``, so its checkpoint keys carry a ``module.`` prefix. If every
        key has the prefix it is stripped; otherwise it is added to every key.
        Version metadata attached to the original mapping is not carried over.
        """
        prefix = "module."
        if state_dict and all(key.startswith(prefix) for key in state_dict):
            return {key[len(prefix):]: value for key, value in state_dict.items()}
        return {prefix + key: value for key, value in state_dict.items()}

    def load_model(self, model_path):
        """Load weights from ``model_path`` and switch the model to eval mode.

        The checkpoint is first loaded as-is. If the keys do not match, one
        retry is made with the ``DataParallel`` ``module.`` prefix toggled, so
        a checkpoint saved from a wrapped model loads into a bare one and vice
        versa.

        Args:
            model_path: Path to a file written by ``torch.save(state_dict)``.

        Raises:
            RuntimeError: If the checkpoint matches neither key layout. The
                error from the first, unmodified attempt is the one raised.
        """
        # SECURITY NOTE: torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        state_dict = torch.load(model_path, map_location=self.device)
        try:
            self.model.load_state_dict(state_dict)
        except RuntimeError as first_error:
            try:
                self.model.load_state_dict(self._toggle_parallel_prefix(state_dict))
            except RuntimeError:
                # Neither layout fits: report the mismatch for the file as saved.
                raise first_error from None
        self.model.eval()
        print(f"Loaded model from {model_path}")

    def predict(self, save_dir, num_workers=24):
        """Run the labelled test set, save binary masks and report IoU.

        The dataset is built from ``self.args`` (``path_img1``, ``path_img2``,
        ``path_lab``, ``batch_size``). One 0/255 ``uint8`` mask per sample is
        written to ``save_dir`` under the base name of the input image.

        Args:
            save_dir: Output directory for the masks; created if missing.
            num_workers: Number of ``DataLoader`` worker processes.

        Returns:
            The accumulated IoU over the test set (``0`` if the union is 0).

        Raises:
            AttributeError: If no configuration was supplied via ``args``.
        """
        # AttributeError keeps the exception type callers saw before, when
        # ``self.args`` was simply never assigned.
        args = getattr(self, "args", None)
        if args is None:
            raise AttributeError(
                "ChangeDetectionModule.predict requires a configuration object; "
                "pass args=... to the constructor or assign self.args first."
            )

        print("Starting prediction...")
        os.makedirs(save_dir, exist_ok=True)

        test_set = TestDatasetFromFolder(args, args.path_img1,
                                         args.path_img2, args.path_lab)
        test_loader = DataLoader(test_set, num_workers=num_workers,
                                 batch_size=args.batch_size, shuffle=False)

        self.model.eval()
        inter, union = 0, 0
        with torch.no_grad():
            for image1, image2, label, image_names in tqdm(test_loader, desc='Testing'):
                image1 = image1.to(self.device, dtype=torch.float)
                image2 = image2.to(self.device, dtype=torch.float)
                label = label.to(self.device, dtype=torch.float)

                # The model returns three outputs; only the first is evaluated.
                output, _, _ = self.model(image1, image2)
                pred = torch.argmax(output, 1).unsqueeze(1)
                # Labels are reduced over dim 1 like the logits, presumably
                # because they are stored one-hot (confirm against the dataset).
                label = torch.argmax(label, 1).unsqueeze(1)

                for sample_pred, sample_label, image_name in zip(pred, label, image_names):
                    gt_np = np.squeeze((sample_label > 0).float().cpu().numpy())
                    prob_np = np.squeeze((sample_pred > 0).float().cpu().numpy())

                    # calMetric_iou appears to return (intersection, union) for
                    # one sample; both are accumulated over the whole set.
                    intr, unn = calMetric_iou(gt_np, prob_np)
                    inter += intr
                    union += unn

                    # Convert the 0/1 prediction to a 0/255 image for saving.
                    binary_result = np.where(prob_np > 0.5, 255, 0).astype('uint8')
                    result = Image.fromarray(binary_result)
                    # Write to the directory created above, not args.save_dir.
                    result.save(os.path.join(save_dir, os.path.basename(image_name)))

        iou = inter / union if union != 0 else 0
        print(f"Test IoU: {iou:.4f}")
        return iou

    # ------------------------------------------------------------------
    # Disabled legacy folder-based variant of predict_without_label, kept
    # for reference only. It relies on ``self.args``.
    # ------------------------------------------------------------------
    # def predict_without_label(self):
    #     print("Starting prediction...")
    #     if not os.path.exists(self.args.save_dir):
    #         os.makedirs(self.args.save_dir)
    #
    #     # Use the dataset without labels
    #     test_set = TestDatasetFromFolderWithoutLabel(self.args, self.args.path_img1, self.args.path_img2)
    #     test_loader = DataLoader(test_set, num_workers=24, batch_size=self.args.batch_size, shuffle=False)
    #
    #     self.model.eval()
    #     with torch.no_grad():
    #         for image1, image2, image_names in tqdm(test_loader, desc='Testing'):
    #             image1 = image1.to(self.device, dtype=torch.float)
    #             image2 = image2.to(self.device, dtype=torch.float)
    #
    #             output, _, _ = self.model(image1, image2)
    #             pred = torch.argmax(output, 1).unsqueeze(1)
    #             prob_np = pred.squeeze().cpu().detach().numpy()
    #
    #             # Binarize and save the result
    #             binary_result = np.where(prob_np > 0.5, 255, 0).astype('uint8')
    #             result = Image.fromarray(binary_result)
    #             result.save(os.path.join(self.args.save_dir, image_names))

    def predict_without_label(self, img1, img2):
        """Predict the change mask for a single pair of images.

        Args:
            img1: First image. It must provide ``.convert('RGB')`` (e.g. a
                ``PIL.Image.Image``).
            img2: Second image, with the same requirement as ``img1``.

        Returns:
            A tuple ``(output, binary_result)``. ``output`` is the first raw
            model output (no argmax applied, batch dimension kept).
            ``binary_result`` is a ``uint8`` NumPy array holding 0 or 255.
        """
        print("Starting prediction...")
        transform = get_transform(convert=True, normalize=True)
        # Transform each image and move it to the model's device and dtype.
        image1 = transform(img1.convert('RGB')).to(self.device, dtype=torch.float)
        image2 = transform(img2.convert('RGB')).to(self.device, dtype=torch.float)

        # Model inference
        self.model.eval()
        with torch.no_grad():
            # Add the batch dimension the network is called with.
            output, _, _ = self.model(image1.unsqueeze(0), image2.unsqueeze(0))
            pred = torch.argmax(output, 1).squeeze(0)  # drop the batch dimension
            prob_np = pred.cpu().numpy()

        # Binarize: any predicted class index above 0 becomes 255.
        binary_result = np.where(prob_np > 0.5, 255, 0).astype('uint8')
        return output, binary_result