ai_project_v1/CropLand_CD_module/change_detection.py

364 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# change_detection.py - training, validation and inference driver for the CDNet change-detection model.
import os
import torch
import numpy as np
from tqdm import tqdm
import itertools
import torch.optim as optim
from torch.utils.data import DataLoader
from model.network import CDNet
from data_utils import DynamicPatchDataset, LoadDatasetFromCoords, TestDatasetFromCoords, calMetric_iou, get_transform, \
TestDatasetFromPic
from loss.losses import cross_entropy
from PIL import Image
from cood_csv import PatchIndexer, PredictionAggregator
import torchvision.transforms as transforms
class ChangeDetection:
    """Train, validate and run inference with a ``CDNet`` change-detection model.

    The instance owns the model, the loss and an Adam optimizer. All
    configuration is read from the ``args`` namespace handed to ``__init__``.
    """

    # Patch geometry handed to PatchIndexer for every split.
    PATCH_SIZE = 256
    PATCH_STRIDE = 256
    # Loader worker count used for inference when ``args`` has no ``num_workers``.
    DEFAULT_TEST_WORKERS = 24
    # Image pair used by ``predict()`` when the caller does not supply one.
    # NOTE(review): machine-specific development paths; callers should pass
    # explicit paths instead of relying on these.
    _DEFAULT_IMG1_URL = r"D:\project\2025-7-10_data+model\val\image1\1.png"
    _DEFAULT_IMG2_URL = r"D:\project\2025-7-10_data+model\val\image2\1.png"

    def __init__(self, args):
        """Build the model, loss and optimizer from ``args``.

        Args:
            args: namespace that must expose at least ``gpu_id``, ``img_size``
                and ``lr``. Other methods read further attributes (data
                directories, batch sizes, ``model_dir``, ``save_dir``, ...).

        Raises:
            ValueError: if one of the three required attributes is missing.
        """
        # Argument completeness check.
        required_args = ['gpu_id', 'img_size', 'lr']
        for arg in required_args:
            if not hasattr(args, arg):
                raise ValueError(f"Missing required argument: {arg}")
        self.args = args
        # Environment values must be strings; str() keeps an integer gpu_id
        # from raising TypeError.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.best_iou = 0.0
        # Initialize model
        self.model = CDNet(img_size=self.args.img_size).to(self.device, dtype=torch.float)
        gpu_count = torch.cuda.device_count()
        if gpu_count > 1:
            print(f"Using {gpu_count} GPUs!")
            self.model = torch.nn.DataParallel(self.model, device_ids=list(range(gpu_count)))
        # Loss and optimizer
        self.criterion = cross_entropy().to(self.device, dtype=torch.float)
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=self.args.lr, betas=(0.9, 0.999))

    def _index_patches(self, img_dir):
        """Run ``PatchIndexer`` over ``img_dir`` and return the coords CSV path.

        The CSV is expected as ``coords.csv`` in the parent directory of
        ``img_dir`` (presumably written there by ``PatchIndexer.index_all``;
        confirm against ``cood_csv``).
        """
        PatchIndexer(
            img_dir=img_dir,
            patch_size=self.PATCH_SIZE,
            stride=self.PATCH_STRIDE,
        ).index_all()
        return os.path.join(os.path.dirname(img_dir), "coords.csv")

    def train(self):
        """Train for ``args.num_epochs`` epochs, keeping the best-IoU checkpoint.

        Validation runs after every epoch; the model is saved through
        ``_save_best_model`` whenever the validation IoU improves on
        ``self.best_iou``.
        """
        print("Starting training...")
        # Generate the coordinate CSV files for the training and validation splits.
        train_cood = self._index_patches(self.args.hr1_train)
        val_cood = self._index_patches(self.args.hr1_val)

        train_set = DynamicPatchDataset(coords_csv=train_cood, img1_dir=self.args.hr1_train,
                                        img2_dir=self.args.hr2_train, label_dir=self.args.lab_train,
                                        crop=True, augment=True, angle=30, num_classes=2)
        val_set = LoadDatasetFromCoords(coords_csv=val_cood, hr1_dir=self.args.hr1_val,
                                        hr2_dir=self.args.hr2_val, label_dir=self.args.lab_val)
        train_loader = DataLoader(train_set, num_workers=self.args.num_workers,
                                  batch_size=self.args.batchsize, shuffle=True)
        val_loader = DataLoader(val_set, num_workers=self.args.num_workers,
                                batch_size=self.args.val_batchsize, shuffle=False)

        for epoch in range(1, self.args.num_epochs + 1):
            self.model.train()
            train_bar = tqdm(train_loader, desc=f"Epoch [{epoch}/{self.args.num_epochs}]")
            for hr_img1, hr_img2, label in train_bar:
                hr_img1 = hr_img1.to(self.device, dtype=torch.float)
                hr_img2 = hr_img2.to(self.device, dtype=torch.float)
                label = label.to(self.device, dtype=torch.float)
                # Collapse the class axis (dim 1) into a single index channel.
                label = torch.argmax(label, 1).unsqueeze(1).float()

                self.optimizer.zero_grad()
                out1, out2, out3 = self.model(hr_img1, hr_img2)
                # The loss is the sum of the criterion over all three model outputs.
                cd_loss = (self.criterion(out1, label) +
                           self.criterion(out2, label) +
                           self.criterion(out3, label))
                cd_loss.backward()
                self.optimizer.step()
                train_bar.set_postfix(loss=cd_loss.item())

            # Validate after every epoch; checkpoint only when the IoU improves.
            val_iou = self.validate(val_loader)
            print(f"Validation IoU: {val_iou:.4f}")
            if val_iou > self.best_iou:
                self.best_iou = val_iou
                self._save_best_model()
                print(f"New best model saved with IoU: {val_iou:.4f}")

    def validate(self, val_loader):
        """Return the IoU accumulated over ``val_loader``.

        Intersection and union are summed over all batches via
        ``calMetric_iou``; the result is ``0`` when the union is zero.
        """
        self.model.eval()
        inter, union = 0, 0
        with torch.no_grad():
            for hr_img1, hr_img2, label in tqdm(val_loader, desc='Validating'):
                hr_img1 = hr_img1.to(self.device, dtype=torch.float)
                hr_img2 = hr_img2.to(self.device, dtype=torch.float)
                label = label.to(self.device, dtype=torch.float)
                label = torch.argmax(label, 1).unsqueeze(1).float()

                # Only the first of the three model outputs is evaluated.
                output, _, _ = self.model(hr_img1, hr_img2)
                pred = torch.argmax(output, 1).unsqueeze(1).float()

                gt = (label > 0).float()
                prob = (pred > 0).float()
                gt_np = np.squeeze(gt.cpu().detach().numpy())
                pred_np = np.squeeze(prob.cpu().detach().numpy())

                intr, unn = calMetric_iou(gt_np, pred_np)
                inter += intr
                union += unn
        iou = inter / union if union != 0 else 0
        return iou

    def _save_best_model(self):
        """Write the current weights to ``<args.model_dir>/best_model.pth``.

        The checkpoint is first written to a temporary file and moved into
        place, and only then are the other ``.pth`` files in the directory
        removed. A failed save therefore never destroys the previous best
        checkpoint.
        """
        model_dir = self.args.model_dir
        os.makedirs(model_dir, exist_ok=True)
        save_path = os.path.join(model_dir, 'best_model.pth')
        tmp_path = save_path + '.tmp'
        torch.save(self.model.state_dict(), tmp_path)
        os.replace(tmp_path, save_path)
        # Remove any other existing .pth files
        for name in os.listdir(model_dir):
            if name.endswith('.pth') and name != 'best_model.pth':
                os.remove(os.path.join(model_dir, name))
        print(f"Best model saved to {save_path}")

    def load_model(self, model_path):
        """Load weights from ``model_path`` and switch the model to eval mode.

        Checkpoints saved from a ``DataParallel``-wrapped model carry a
        ``module.`` prefix on every key. The prefix is stripped and the
        weights are loaded into the underlying module, so a checkpoint loads
        regardless of whether it, or the current model, is wrapped.
        """
        state = torch.load(model_path, map_location=self.device)
        prefix = "module."
        # Strip only when every key is prefixed, i.e. a wrapped checkpoint.
        # NOTE(review): a bare model whose only top-level attribute is named
        # "module" would also match; not expected for CDNet, confirm.
        if state and all(key.startswith(prefix) for key in state):
            state = {key[len(prefix):]: value for key, value in state.items()}
        if isinstance(self.model, torch.nn.DataParallel):
            target = self.model.module
        else:
            target = self.model
        target.load_state_dict(state)
        self.model.eval()
        print(f"Loaded model from {model_path}")

    def _run_patch_inference(self, test_set, save_dir):
        """Predict every patch of ``test_set`` and save the stitched results.

        Args:
            test_set: dataset yielding 9-tuples ``(img1, img2, label,
                patch_name, fname, x, y, h, w)``.
            save_dir: directory handed to ``PredictionAggregator.save_all``.

        Returns:
            tuple: ``(saved, iou)`` where ``saved`` is whatever ``save_all``
            returns and ``iou`` is the IoU accumulated over all patches
            (``0`` when the union is zero).
        """
        # Honour args.num_workers when present; otherwise keep the historical 24.
        workers = getattr(self.args, "num_workers", self.DEFAULT_TEST_WORKERS)
        test_loader = DataLoader(test_set, num_workers=workers,
                                 batch_size=self.args.batch_size, shuffle=False)
        self.model.eval()
        aggregator = PredictionAggregator(img_dir=self.args.path_img1)
        test_bar = tqdm(test_loader)
        inter, union = 0, 0
        for img1, img2, label, _patch_name, fname_list, x_list, y_list, h_list, w_list in test_bar:
            img1 = img1.to(self.device)
            img2 = img2.to(self.device)
            label = label.to(self.device)
            with torch.no_grad():
                output = self.model(img1, img2)
                # The model may return several outputs; keep only the first.
                if isinstance(output, (tuple, list)):
                    output = output[0]
                pred = torch.argmax(output, 1).unsqueeze(1).float()
            label = label.float()
            for i in range(img1.size(0)):
                patch = pred[i][0].cpu().numpy().astype(np.uint8)
                # NOTE(review): the IoU is only meaningful if the dataset
                # yields real labels; confirm what TestDatasetFromPic returns
                # when constructed without a label.
                gt_value = label[i][0].cpu().numpy().astype(np.uint8)
                intr, unn = calMetric_iou(gt_value, patch)
                inter += intr
                union += unn
                fname = fname_list[i]
                x, y = x_list[i].item(), y_list[i].item()
                h, w = h_list[i].item(), w_list[i].item()
                aggregator.add_patch(fname, x, y, patch, h, w)
            running_iou = inter / union if union > 0 else 0
            test_bar.set_description(f"IoU: {running_iou:.4f}")
        saved = aggregator.save_all(save_dir)
        iou = inter / union if union != 0 else 0
        return saved, iou

    def predict(self, img1_url=None, img2_url=None):
        """Predict one image pair and save the results under ``args.save_dir``.

        Args:
            img1_url: path of the first image; defaults to the built-in
                development image when omitted.
            img2_url: path of the second image; same default rule.
        """
        print(f"结果将保存到目录: {self.args.save_dir}")
        print("Starting prediction...")
        self._index_patches(self.args.path_img1)
        os.makedirs(self.args.save_dir, exist_ok=True)
        if img1_url is None:
            img1_url = self._DEFAULT_IMG1_URL
        if img2_url is None:
            img2_url = self._DEFAULT_IMG2_URL
        # NOTE(review): unlike predict_from_imgurl, no explicit transform is
        # passed here, so the dataset's own default applies; confirm intended.
        test_set = TestDatasetFromPic(img1_url, img2_url)
        _, iou = self._run_patch_inference(test_set, self.args.save_dir)
        print(f"Test IoU: {iou:.4f}")

    def predict_from_imgurl(self, img1Url, img2Url, save_dir):
        """Predict the given image pair and save the results into ``save_dir``.

        Args:
            img1Url: path of the first image.
            img2Url: path of the second image.
            save_dir: output directory; created if it does not exist.

        Returns:
            Whatever ``PredictionAggregator.save_all`` returns for
            ``save_dir`` (presumably the output path; confirm).
        """
        print(f"结果将保存到目录: {save_dir}")
        print("Starting prediction...")
        self._index_patches(self.args.path_img1)
        os.makedirs(save_dir, exist_ok=True)
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)
        ])
        # Use the caller-supplied images; no label is available (None).
        test_set = TestDatasetFromPic(img1Url, img2Url, None, transform)
        out_save_path, iou = self._run_patch_inference(test_set, save_dir)
        print(f"out_save_pathout_save_path {out_save_path}")
        print(f"Test IoU: {iou:.4f}")
        return out_save_path

    def predict_without_label(self, img1, img2):
        """Run the model directly on one in-memory image pair.

        Args:
            img1: first image; must provide ``.convert('RGB')`` (presumably a
                ``PIL.Image``).
            img2: second image; same requirement.

        Returns:
            tuple: ``(output, binary_result)`` where ``output`` is the first
            raw model output (before argmax, batch dimension kept) and
            ``binary_result`` is a ``uint8`` NumPy array holding 255 where the
            predicted class index is non-zero and 0 elsewhere.
        """
        print("Starting prediction...")
        transform = get_transform(convert=True, normalize=True)
        image1 = transform(img1.convert('RGB')).to(self.device)
        image2 = transform(img2.convert('RGB')).to(self.device)
        # Model inference
        self.model.eval()
        with torch.no_grad():
            # Add the batch dimension expected by the model.
            output, _, _ = self.model(image1.unsqueeze(0), image2.unsqueeze(0))
        pred = torch.argmax(output, 1).squeeze(0)  # drop the batch dimension
        prob_np = pred.cpu().detach().numpy()
        # Binarize to 0 / 255.
        binary_result = np.where(prob_np > 0.5, 255, 0).astype('uint8')
        return output, binary_result