ai_project_v1/CropLand_CD_module/change_detection_module.py

327 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# change_detection_module.py - training, validation and inference wrapper around the CDNet change-detection model.
import os
import torch
import numpy as np
from tqdm import tqdm
import itertools
import torch.optim as optim
from torch.utils.data import DataLoader
from CropLand_CD_module.cood_csv import PatchIndexer, PredictionAggregator
from CropLand_CD_module.data_utils import DynamicPatchDataset, LoadDatasetFromCoords, calMetric_iou, \
TestDatasetFromPic
from CropLand_CD_module.loss.losses import cross_entropy
from CropLand_CD_module.model.network import CDNet
# from model.network import CDNet
# from data_utils import DynamicPatchDataset, LoadDatasetFromCoords, TestDatasetFromCoords, calMetric_iou, get_transform, \
# TestDatasetFromPic
# from loss.losses import cross_entropy
from PIL import Image
# from cood_csv import PatchIndexer, PredictionAggregator
import torchvision.transforms as transforms
class ChangeDetectionModule:
    """Train, validate and run inference with the CDNet change-detection model.

    The constructor builds the network, loss and optimizer. ``train`` and
    ``predict`` additionally need a configuration namespace in ``self.args``
    (pass ``args=...`` to the constructor or assign the attribute before
    calling them); ``predict_from_imgurl`` works without it.

    Attributes read from ``self.args``:
        train():   hr1_train, hr2_train, lab_train, hr1_val, hr2_val, lab_val,
                   num_workers, batchsize, val_batchsize, num_epochs, model_dir
        predict(): path_img1, save_dir, batch_size (optional)
    """

    # Side length and stride handed to PatchIndexer for every indexed folder.
    PATCH_SIZE = 256
    PATCH_STRIDE = 256

    # Number of DataLoader workers used by the batch ``predict`` entry point.
    PREDICT_NUM_WORKERS = 24

    # Prefix torch.nn.DataParallel puts in front of every state-dict key.
    _DATA_PARALLEL_PREFIX = "module."

    # NOTE(review): developer-local sample images that ``predict`` used to
    # hard-code. They are kept only as backward-compatible defaults; callers
    # should pass explicit paths to ``predict`` instead.
    _DEFAULT_IMG1_URL = r"D:\project\2025-7-10_data+model\val\image1\1.png"
    _DEFAULT_IMG2_URL = r"D:\project\2025-7-10_data+model\val\image2\1.png"

    def __init__(self, gpu_id, img_size, lr, batch_size=2, args=None):
        """Build the model, loss and optimizer.

        Args:
            gpu_id: Value for ``CUDA_VISIBLE_DEVICES`` (e.g. ``"0"`` or ``"0,1"``).
                Converted with ``str`` so an integer id is accepted too;
                ``None`` leaves the environment untouched.
            img_size: Forwarded to ``CDNet(img_size=...)``.
            lr: Learning rate for the Adam optimizer.
            batch_size: Batch size used by ``predict_from_imgurl`` (and by
                ``predict`` when ``args`` has no ``batch_size``).
            args: Optional configuration namespace required by ``train`` and
                ``predict``; see the class docstring for the attributes read.
        """
        if gpu_id is not None:
            # os.environ only accepts strings; an int gpu_id used to raise TypeError.
            os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.best_iou = 0.0
        self.batch_size = batch_size
        self.args = args

        # Initialize model
        self.model = CDNet(img_size=img_size).to(self.device, dtype=torch.float)
        gpu_count = torch.cuda.device_count()
        if gpu_count > 1:
            print(f"Using {gpu_count} GPUs!")
            self.model = torch.nn.DataParallel(self.model, device_ids=list(range(gpu_count)))

        # Loss and optimizer
        self.criterion = cross_entropy().to(self.device, dtype=torch.float)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr, betas=(0.9, 0.999))

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _require_args(self):
        """Return ``self.args`` or raise a descriptive AttributeError if it is unset."""
        args = getattr(self, "args", None)
        if args is None:
            raise AttributeError(
                "ChangeDetectionModule.args is not set; pass args=... to the "
                "constructor or assign it before calling train()/predict()."
            )
        return args

    def _index_patches(self, img_dir):
        """Run PatchIndexer over ``img_dir`` and return the expected coords.csv path.

        The CSV is looked up next to ``img_dir`` (in ``os.path.dirname(img_dir)``),
        which is where the original code read it from.
        NOTE(review): assumes PatchIndexer.index_all() writes ``coords.csv``
        there - confirm against PatchIndexer.
        """
        PatchIndexer(
            img_dir=img_dir,
            patch_size=self.PATCH_SIZE,
            stride=self.PATCH_STRIDE,
        ).index_all()
        return os.path.join(os.path.dirname(img_dir), "coords.csv")

    def _align_checkpoint_keys(self, state_dict):
        """Add or strip the DataParallel ``module.`` prefix so keys match ``self.model``.

        The mapping is returned unchanged when it already matches, or when
        neither variant matches (so ``load_state_dict`` reports the real
        mismatch itself).
        """
        if not isinstance(state_dict, dict):
            return state_dict
        expected = set(self.model.state_dict().keys())
        if set(state_dict.keys()) == expected:
            return state_dict

        prefix = self._DATA_PARALLEL_PREFIX
        stripped = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
        if set(stripped.keys()) == expected:
            return stripped

        prefixed = {
            (key if key.startswith(prefix) else prefix + key): value
            for key, value in state_dict.items()
        }
        if set(prefixed.keys()) == expected:
            return prefixed
        return state_dict

    def _run_inference(self, test_loader, aggregator, save_dir):
        """Predict every batch of ``test_loader``, stitch the patches and save them.

        Each batch must unpack to
        ``(img1, img2, label, patch_name, fname, x, y, h, w)``.

        Args:
            test_loader: Iterable of batches in the layout above.
            aggregator: Object exposing ``add_patch(fname, x, y, patch, h, w)``
                and ``save_all(save_dir)``.
            save_dir: Directory handed to ``aggregator.save_all``.

        Returns:
            ``(out_save_path, iou)`` where ``out_save_path`` is whatever
            ``aggregator.save_all`` returns and ``iou`` is the accumulated
            intersection / union (0 when the union is 0).
        """
        self.model.eval()
        inter, union = 0, 0
        test_bar = tqdm(test_loader)
        for img1, img2, label, _patch_name, fname_list, x_list, y_list, h_list, w_list in test_bar:
            img1 = img1.to(self.device)
            img2 = img2.to(self.device)
            with torch.no_grad():
                output = self.model(img1, img2)
            # The network may return several heads; only the first one is used.
            if isinstance(output, (tuple, list)):
                output = output[0]

            # Convert once per batch instead of one device->host copy per sample.
            pred_np = torch.argmax(output, 1).cpu().numpy().astype(np.uint8)
            label_np = label.float().cpu().numpy().astype(np.uint8)

            for i in range(img1.size(0)):
                patch = pred_np[i]
                gt_value = label_np[i][0]
                intr, unn = calMetric_iou(gt_value, patch)
                inter += intr
                union += unn

                x, y = x_list[i].item(), y_list[i].item()
                h, w = h_list[i].item(), w_list[i].item()
                aggregator.add_patch(fname_list[i], x, y, patch, h, w)
            test_bar.set_description('IoU: %.4f' % (inter / union if union > 0 else 0))

        out_save_path = aggregator.save_all(save_dir)
        iou = inter / union if union != 0 else 0
        return out_save_path, iou

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------
    def train(self):
        """Train for ``args.num_epochs`` epochs, checkpointing on validation IoU gains.

        Raises:
            AttributeError: If ``self.args`` has not been provided.
        """
        args = self._require_args()
        print("Starting training...")

        # Generate the patch-coordinate CSV files for the training and validation sets.
        train_cood = self._index_patches(args.hr1_train)
        val_cood = self._index_patches(args.hr1_val)

        train_set = DynamicPatchDataset(coords_csv=train_cood, img1_dir=args.hr1_train,
                                        img2_dir=args.hr2_train, label_dir=args.lab_train, crop=True,
                                        augment=True, angle=30, num_classes=2)
        val_set = LoadDatasetFromCoords(coords_csv=val_cood, hr1_dir=args.hr1_val, hr2_dir=args.hr2_val,
                                        label_dir=args.lab_val)

        train_loader = DataLoader(train_set, num_workers=args.num_workers,
                                  batch_size=args.batchsize, shuffle=True)
        val_loader = DataLoader(val_set, num_workers=args.num_workers,
                                batch_size=args.val_batchsize, shuffle=False)

        for epoch in range(1, args.num_epochs + 1):
            self.model.train()
            train_bar = tqdm(train_loader, desc=f"Epoch [{epoch}/{args.num_epochs}]")
            for hr_img1, hr_img2, label in train_bar:
                hr_img1 = hr_img1.to(self.device, dtype=torch.float)
                hr_img2 = hr_img2.to(self.device, dtype=torch.float)
                label = label.to(self.device, dtype=torch.float)
                # Collapse the channel axis to class indices.
                # NOTE(review): presumably labels are one-hot along dim 1 - confirm.
                label = torch.argmax(label, 1).unsqueeze(1).float()

                self.optimizer.zero_grad()
                out1, out2, out3 = self.model(hr_img1, hr_img2)
                # The same target supervises all three network outputs.
                cd_loss = (self.criterion(out1, label) +
                           self.criterion(out2, label) +
                           self.criterion(out3, label))
                cd_loss.backward()
                self.optimizer.step()
                train_bar.set_postfix(loss=cd_loss.item())

            # Validation runs after every epoch; a checkpoint is written only
            # when the IoU improves on the best value seen so far.
            val_iou = self.validate(val_loader)
            print(f"Validation IoU: {val_iou:.4f}")
            if val_iou > self.best_iou:
                self.best_iou = val_iou
                self._save_best_model()
                print(f"New best model saved with IoU: {val_iou:.4f}")

    def validate(self, val_loader):
        """Return the IoU accumulated over every batch of ``val_loader``.

        Args:
            val_loader: Iterable yielding ``(hr_img1, hr_img2, label)`` batches.

        Returns:
            Summed intersection divided by summed union, or 0 when the union is 0.
        """
        self.model.eval()
        inter, union = 0, 0
        with torch.no_grad():
            for hr_img1, hr_img2, label in tqdm(val_loader, desc='Validating'):
                hr_img1 = hr_img1.to(self.device, dtype=torch.float)
                hr_img2 = hr_img2.to(self.device, dtype=torch.float)
                label = label.to(self.device, dtype=torch.float)
                label = torch.argmax(label, 1).unsqueeze(1).float()

                # Only the first of the three network outputs is evaluated.
                output, _, _ = self.model(hr_img1, hr_img2)
                pred = torch.argmax(output, 1).unsqueeze(1).float()

                # Binarize: any non-zero class counts as "changed".
                gt_np = np.squeeze((label > 0).float().cpu().detach().numpy())
                pred_np = np.squeeze((pred > 0).float().cpu().detach().numpy())
                intr, unn = calMetric_iou(gt_np, pred_np)
                inter += intr
                union += unn
        return inter / union if union != 0 else 0

    def _save_best_model(self):
        """Write the current weights to ``<args.model_dir>/best_model.pth``.

        NOTE: every existing ``*.pth`` file in ``model_dir`` is deleted first,
        so do not point ``model_dir`` at a folder holding other checkpoints.
        """
        model_dir = self._require_args().model_dir
        os.makedirs(model_dir, exist_ok=True)
        # Remove any existing .pth files
        for f in os.listdir(model_dir):
            if f.endswith('.pth'):
                os.remove(os.path.join(model_dir, f))
        save_path = os.path.join(model_dir, 'best_model.pth')
        torch.save(self.model.state_dict(), save_path)
        print(f"Best model saved to {save_path}")

    def load_model(self, model_path):
        """Load weights from ``model_path`` and switch the model to eval mode.

        Checkpoints saved from a ``torch.nn.DataParallel`` model carry a
        ``module.`` key prefix; the prefix is added or stripped as needed so
        a checkpoint loads regardless of whether it, or the current model,
        was wrapped.

        Args:
            model_path: Path to a state-dict file written by ``torch.save``.
        """
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(self._align_checkpoint_keys(state_dict))
        self.model.eval()
        print(f"Loaded model from {model_path}")

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------
    def predict(self, img1_url=None, img2_url=None):
        """Run inference on one image pair and save results to ``args.save_dir``.

        Args:
            img1_url: Path of the first image. Defaults to the
                developer-local sample path the original code hard-coded.
            img2_url: Path of the second image; same default rule.

        Raises:
            AttributeError: If ``self.args`` has not been provided.

        NOTE(review): unlike ``predict_from_imgurl`` no transform is passed to
        ``TestDatasetFromPic`` here, matching the original call - confirm
        that the dataset's default is what is wanted.
        """
        args = self._require_args()
        print(f"结果将保存到目录: {args.save_dir}")
        print("Starting prediction...")

        # Generate the coordinate CSV for the test folder (side effect only).
        self._index_patches(args.path_img1)
        os.makedirs(args.save_dir, exist_ok=True)

        if img1_url is None:
            img1_url = self._DEFAULT_IMG1_URL
        if img2_url is None:
            img2_url = self._DEFAULT_IMG2_URL
        test_set = TestDatasetFromPic(img1_url, img2_url)

        # train() reads ``args.batchsize``; fall back to the constructor value
        # when the namespace has no ``batch_size`` attribute.
        batch_size = getattr(args, "batch_size", self.batch_size)
        test_loader = DataLoader(test_set, num_workers=self.PREDICT_NUM_WORKERS,
                                 batch_size=batch_size, shuffle=False)

        aggregator = PredictionAggregator(img_dir=args.path_img1)
        _, iou = self._run_inference(test_loader, aggregator, args.save_dir)
        print(f"Test IoU: {iou:.4f}")

    def predict_from_imgurl(self, img1Url, img2Url, save_dir):
        """Run inference on one image pair and return where the result was saved.

        Args:
            img1Url: Path of the first image.
            img2Url: Path of the second image.
            save_dir: Directory for the stitched prediction; created if missing.

        Returns:
            Whatever ``PredictionAggregator.save_all(save_dir)`` returns.

        NOTE(review): ``None`` is passed as the label source, so the printed
        IoU is presumably not meaningful for this entry point - confirm
        against TestDatasetFromPic.
        """
        print(f"结果将保存到目录: {save_dir}")
        print("Starting prediction...")
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)
        ])
        test_set = TestDatasetFromPic(img1Url, img2Url, None, transform)

        # num_workers must stay 0 here: DataLoader workers are child
        # processes, and Sanic worker processes are daemonic by default, so
        # they cannot spawn children. Revisit if the service is split out.
        test_loader = DataLoader(test_set, num_workers=0,
                                 batch_size=self.batch_size, shuffle=False)

        aggregator = PredictionAggregator(img_dir=os.path.dirname(img1Url))
        out_save_path, iou = self._run_inference(test_loader, aggregator, save_dir)
        print(f"out_save_pathout_save_path {out_save_path}")
        print(f"Test IoU: {iou:.4f}")
        return out_save_path
# for image1, image2, label, image_names in tqdm(test_loader, desc='Testing'):
# image1 = image1.to(self.device, dtype=torch.float)
# image2 = image2.to(self.device, dtype=torch.float)
# label = label.to(self.device, dtype=torch.float)
# output, _, _ = self.model(image1, image2)
# pred = torch.argmax(output, 1).unsqueeze(1)
# label = torch.argmax(label, 1).unsqueeze(1)
# for i in range(pred.size(0)):
# gt_value = (label[i] > 0).float()
# prob = (pred[i] > 0).float()
# gt_np = np.squeeze(gt_value.cpu().detach().numpy())
# prob_np = np.squeeze(prob.cpu().detach().numpy())
# intr, unn = calMetric_iou(gt_np, prob_np)
# inter += intr
# union += unn
# binary_result = np.where(prob_np > 0.5, 255, 0).astype('uint8') # 将预测值转换为0和255
# result = Image.fromarray(binary_result) # 使用转换后的二值图像
# result.save(os.path.join(self.args.save_dir, image_names[i]))
# iou = inter / union if union != 0 else 0
# print(f"Test IoU: {iou:.4f}"