# ChangeDetection.py

import os
import torch
import numpy as np
from tqdm import tqdm
import itertools
import torch.optim as optim
from torch.utils.data import DataLoader

from CropLand_CD_module.cood_csv import PatchIndexer, PredictionAggregator
from CropLand_CD_module.data_utils import DynamicPatchDataset, LoadDatasetFromCoords, calMetric_iou, \
    TestDatasetFromPic
from CropLand_CD_module.loss.losses import cross_entropy
from CropLand_CD_module.model.network import CDNet
# from model.network import CDNet
# from data_utils import DynamicPatchDataset, LoadDatasetFromCoords, TestDatasetFromCoords, calMetric_iou, get_transform, \
#     TestDatasetFromPic
# from loss.losses import cross_entropy
from PIL import Image
# from cood_csv import PatchIndexer, PredictionAggregator
import torchvision.transforms as transforms


class ChangeDetectionModule:
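    """Cropland change-detection pipeline built around CDNet.

    Handles model/optimizer setup, patch-based training, IoU validation, and
    tiled prediction that is stitched back together by PredictionAggregator.
    Dataset paths and loader settings are read from ``self.args`` (pass args to
    __init__ or attach it before calling train()/predict()).
    """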

    def __init__(self, gpu_id, img_size, lr, batch_size=2, args=None):
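        """Configure the device, build CDNet, and set up the loss and Adam optimizer.

        gpu_id is a string such as "0" or "0,1" (it is written to CUDA_VISIBLE_DEVICES).
        args (optional, added here for convenience) is the argparse-style namespace of
        dataset paths and loader settings that train() and predict() read from self.args.
        """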
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.best_iou = 0.0
        self.batch_size = batch_size
        self.args = args  # assumed: paths/hyperparameters namespace; attach later if not passed here
        # Initialize model
        self.model = CDNet(img_size=img_size).to(self.device, dtype=torch.float)
        if torch.cuda.device_count() > 1:
            print(f"Using {torch.cuda.device_count()} GPUs!")
            self.model = torch.nn.DataParallel(self.model, device_ids=range(torch.cuda.device_count()))

        # Loss and optimizer
        self.criterion = cross_entropy().to(self.device, dtype=torch.float)
        self.optimizer = optim.Adam(itertools.chain(self.model.parameters()),
                                    lr=lr, betas=(0.9, 0.999))

    def train(self):
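        """Index 256x256 patches, train CDNet, and checkpoint the best model by validation IoU.

        Reads hr1/hr2/lab train and val paths, num_workers, batchsize, val_batchsize,
        num_epochs, and model_dir from self.args.
        """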
        print("Starting training...")
        # os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

        # save_dir is inferred automatically (e.g. 'val/'); patch coordinates are saved as 'coords.csv'
        # 1. Generate the coords csv for the training set
        indexer_train = PatchIndexer(
            img_dir=self.args.hr1_train,
            patch_size=256,
            stride=256
        )
        indexer_train.index_all()
        # 2. Generate the coords csv for the validation set
        indexer_val = PatchIndexer(
            img_dir=self.args.hr1_val,
            patch_size=256,
            stride=256
        )
        indexer_val.index_all()

        # Resolve the csv paths relative to the image folders:
        # coords.csv sits in the parent directory of each image folder
        parent_train_cood = os.path.dirname(self.args.hr1_train)
        train_cood = os.path.join(parent_train_cood, "coords.csv")
        parent_val_cood = os.path.dirname(self.args.hr1_val)
        val_cood = os.path.join(parent_val_cood, "coords.csv")

        train_set = DynamicPatchDataset(coords_csv=train_cood, img1_dir=self.args.hr1_train,
                                        img2_dir=self.args.hr2_train, label_dir=self.args.lab_train, crop=True,
                                        augment=True, angle=30, num_classes=2)
        val_set = LoadDatasetFromCoords(coords_csv=val_cood, hr1_dir=self.args.hr1_val, hr2_dir=self.args.hr2_val,
                                        label_dir=self.args.lab_val)

        # train_set = DynamicPatchDataset(self.args.hr1_train, self.args.hr2_train, self.args.lab_train, crop=False)
        # val_set = LoadDatasetFromFolder(self.args, self.args.hr1_val, self.args.hr2_val, self.args.lab_val)
        train_loader = DataLoader(train_set, num_workers=self.args.num_workers,
                                  batch_size=self.args.batchsize, shuffle=True)
        val_loader = DataLoader(val_set, num_workers=self.args.num_workers,
                                batch_size=self.args.val_batchsize, shuffle=False)

        for epoch in range(1, self.args.num_epochs + 1):
            self.model.train()
            train_bar = tqdm(train_loader, desc=f"Epoch [{epoch}/{self.args.num_epochs}]")
            for hr_img1, hr_img2, label in train_bar:
                hr_img1 = hr_img1.to(self.device, dtype=torch.float)
                hr_img2 = hr_img2.to(self.device, dtype=torch.float)
                label = label.to(self.device, dtype=torch.float)
                label = torch.argmax(label, 1).unsqueeze(1).float()

                self.optimizer.zero_grad()
                out1, out2, out3 = self.model(hr_img1, hr_img2)
                cd_loss = (self.criterion(out1, label) +
                           self.criterion(out2, label) +
                           self.criterion(out3, label))
                cd_loss.backward()
                self.optimizer.step()
                train_bar.set_postfix(loss=cd_loss.item())

            # Validate after every epoch; save a checkpoint only when the IoU improves
            val_iou = self.validate(val_loader)
            print(f"Validation IoU: {val_iou:.4f}")

            if val_iou > self.best_iou:
                self.best_iou = val_iou
                self._save_best_model()
                print(f"New best model saved with IoU: {val_iou:.4f}")

    def validate(self, val_loader):
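        """Compute the change-class IoU (intersection / union) over val_loader."""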
        self.model.eval()
        inter, union = 0, 0
        with torch.no_grad():
            for hr_img1, hr_img2, label in tqdm(val_loader, desc='Validating'):
                hr_img1 = hr_img1.to(self.device, dtype=torch.float)
                hr_img2 = hr_img2.to(self.device, dtype=torch.float)
                label = label.to(self.device, dtype=torch.float)
                label = torch.argmax(label, 1).unsqueeze(1).float()

                output, _, _ = self.model(hr_img1, hr_img2)
                pred = torch.argmax(output, 1).unsqueeze(1).float()
                gt = (label > 0).float()
                prob = (pred > 0).float()

                gt_np = np.squeeze(gt.cpu().detach().numpy())
                pred_np = np.squeeze(prob.cpu().detach().numpy())
                intr, unn = calMetric_iou(gt_np, pred_np)
                inter += intr
                union += unn

        iou = inter / union if union != 0 else 0
        return iou

    def _save_best_model(self):
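        """Write the current weights to args.model_dir/best_model.pth, removing older .pth files first."""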
        if not os.path.exists(self.args.model_dir):
            os.makedirs(self.args.model_dir)
        # Remove any existing .pth files
        for f in os.listdir(self.args.model_dir):
            if f.endswith('.pth'):
                os.remove(os.path.join(self.args.model_dir, f))
        save_path = os.path.join(self.args.model_dir, 'best_model.pth')
        torch.save(self.model.state_dict(), save_path)
        print(f"Best model saved to {save_path}")

    def load_model(self, model_path):
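        """Load weights from model_path onto self.device and put the model in eval mode."""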
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.eval()
        print(f"Loaded model from {model_path}")

    def predict(self):
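        """Run patch-wise prediction over a test image pair and stitch the results into self.args.save_dir.

        Uses self.args.path_img1 for patch indexing/aggregation and self.args.batch_size for the loader;
        note that the input pair is currently hardcoded to sample paths below.
        """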
        print(f"Results will be saved to: {self.args.save_dir}")
        print("Starting prediction...")
        indexer_test = PatchIndexer(
            img_dir=self.args.path_img1,
            patch_size=256,
            stride=256
        )
        indexer_test.index_all()
        # Build the coords csv path relative to the image folder
        parent_test_cood = os.path.dirname(self.args.path_img1)
        test_cood = os.path.join(parent_test_cood, "coords.csv")

        if not os.path.exists(self.args.save_dir):
            os.makedirs(self.args.save_dir)

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)
        ])

        # NOTE: hardcoded sample inputs; the coords-based test set below is kept for reference
        img1Url = r"D:\project\2025-7-10_data+model\val\image1\1.png"
        img2Url = r"D:\project\2025-7-10_data+model\val\image2\1.png"

        test_set = TestDatasetFromPic(img1Url, img2Url)

        # test_set = TestDatasetFromCoords(coords_csv=test_cood, hr1_dir=self.args.path_img1, hr2_dir=self.args.path_img2,
        #                                  label_dir=None, transform=transform)
        # test_set = TestDatasetFromFolder(self.args, self.args.path_img1,
        #                                  self.args.path_img2, self.args.path_lab)
        test_loader = DataLoader(test_set, num_workers=24,
                                 batch_size=self.args.batch_size, shuffle=False)

        self.model.eval()
        aggregator = PredictionAggregator(img_dir=self.args.path_img1)
        # aggregator = PredictionAggregator()
        test_bar = tqdm(test_loader)
        inter, union = 0, 0

        for img1, img2, label, patch_name, fname_list, x_list, y_list, h_list, w_list in test_bar:
            img1 = img1.to(self.device)
            img2 = img2.to(self.device)
            label = label.to(self.device)

            with torch.no_grad():
                output = self.model(img1, img2)
                if isinstance(output, tuple):  # the model may return several heads; keep the first output
                    output = output[0]

            pred = torch.argmax(output, 1).unsqueeze(1).float()
            # label = torch.argmax(label, 1).unsqueeze(1).float()
            label = label.float()

            for i in range(img1.size(0)):
                # prob = pred[i].cpu().numpy()
                # gt_value = label[i].cpu().numpy()
                patch = pred[i][0].cpu().numpy().astype(np.uint8)
                gt_value = label[i][0].cpu().numpy().astype(np.uint8)

                # intr, unn = calMetric_iou(gt_value.squeeze(), prob.squeeze())
                intr, unn = calMetric_iou(gt_value, patch)
                inter += intr
                union += unn
                fname = fname_list[i]
                x, y = x_list[i].item(), y_list[i].item()
                # patch = pred[i][0].cpu().numpy().astype(np.uint8)
                h, w = h_list[i].item(), w_list[i].item()
                aggregator.add_patch(fname, x, y, patch, h, w)
                # aggregator.add_patch(fname, x, y, patch)

            test_bar.set_description('IoU: %.4f' % (inter / union if union > 0 else 0))
        aggregator.save_all(self.args.save_dir)
        iou = inter / union if union != 0 else 0
        print(f"Test IoU: {iou:.4f}")

    def predict_from_imgurl(self, img1Url, img2Url, save_dir):
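        """Predict a change mask for one image pair (img1Url, img2Url) and stitch it into save_dir.

        Returns the output path produced by PredictionAggregator.save_all().
        """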
        print(f"Results will be saved to: {save_dir}")
        print("Starting prediction...")
        # indexer_test = PatchIndexer(
        #     img_dir=self.args.path_img1,
        #     patch_size=256,
        #     stride=256
        # )
        # indexer_test.index_all()
        # # Build the coords csv path relative to the image folder
        # parent_test_cood = os.path.dirname(self.args.path_img1)
        # # test_cood = os.path.join(parent_test_cood, "coords.csv")

        # if not os.path.exists(self.args.save_dir):
        #     os.makedirs(self.args.save_dir)

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)
        ])

        # img1Url = r"D:\project\2025-7-10_data+model\val\image1\1.png"
        # img2Url = r"D:\project\2025-7-10_data+model\val\image2\1.png"

        test_set = TestDatasetFromPic(img1Url, img2Url, None, transform)

        # test_set = TestDatasetFromCoords(coords_csv=test_cood, hr1_dir=self.args.path_img1, hr2_dir=self.args.path_img2,
        #                                  label_dir=None, transform=transform)
        # test_set = TestDatasetFromFolder(self.args, self.args.path_img1,
        #                                  self.args.path_img2, self.args.path_lab)
        # DataLoader(num_workers=24) would spawn worker processes, but Sanic workers run as daemon
        # processes (daemon=True) by default and cannot create child processes, so keep num_workers=0
        # here; consider moving this into a standalone service later.
        test_loader = DataLoader(test_set, num_workers=0,
                                 batch_size=self.batch_size, shuffle=False)

        self.model.eval()
        img1_dir = os.path.dirname(img1Url)
        aggregator = PredictionAggregator(img_dir=img1_dir)
        # aggregator = PredictionAggregator()
        test_bar = tqdm(test_loader)
        inter, union = 0, 0

        for img1, img2, label, patch_name, fname_list, x_list, y_list, h_list, w_list in test_bar:
            img1 = img1.to(self.device)
            img2 = img2.to(self.device)
            label = label.to(self.device)

            with torch.no_grad():
                output = self.model(img1, img2)
                if isinstance(output, tuple):  # the model may return several heads; keep the first output
                    output = output[0]

            pred = torch.argmax(output, 1).unsqueeze(1).float()
            # label = torch.argmax(label, 1).unsqueeze(1).float()
            label = label.float()

            for i in range(img1.size(0)):
                # prob = pred[i].cpu().numpy()
                # gt_value = label[i].cpu().numpy()
                patch = pred[i][0].cpu().numpy().astype(np.uint8)
                gt_value = label[i][0].cpu().numpy().astype(np.uint8)

                # intr, unn = calMetric_iou(gt_value.squeeze(), prob.squeeze())
                intr, unn = calMetric_iou(gt_value, patch)
                inter += intr
                union += unn
                fname = fname_list[i]
                x, y = x_list[i].item(), y_list[i].item()
                # patch = pred[i][0].cpu().numpy().astype(np.uint8)
                h, w = h_list[i].item(), w_list[i].item()
                aggregator.add_patch(fname, x, y, patch, h, w)
                # aggregator.add_patch(fname, x, y, patch)

            test_bar.set_description('IoU: %.4f' % (inter / union if union > 0 else 0))
        out_save_path = aggregator.save_all(save_dir)
        print(f"Output saved to: {out_save_path}")
        iou = inter / union if union != 0 else 0
        print(f"Test IoU: {iou:.4f}")
        return out_save_path

        # Legacy whole-image test loop, kept for reference (superseded by the patch-based loop above):
        # for image1, image2, label, image_names in tqdm(test_loader, desc='Testing'):
        #     image1 = image1.to(self.device, dtype=torch.float)
        #     image2 = image2.to(self.device, dtype=torch.float)
        #     label = label.to(self.device, dtype=torch.float)
        #
        #     output, _, _ = self.model(image1, image2)
        #     pred = torch.argmax(output, 1).unsqueeze(1)
        #     label = torch.argmax(label, 1).unsqueeze(1)
        #
        #     for i in range(pred.size(0)):
        #         gt_value = (label[i] > 0).float()
        #         prob = (pred[i] > 0).float()
        #         gt_np = np.squeeze(gt_value.cpu().detach().numpy())
        #         prob_np = np.squeeze(prob.cpu().detach().numpy())
        #         intr, unn = calMetric_iou(gt_np, prob_np)
        #         inter += intr
        #         union += unn
        #         binary_result = np.where(prob_np > 0.5, 255, 0).astype('uint8')  # map predictions to 0 and 255
        #
        #         result = Image.fromarray(binary_result)  # use the binarized image
        #         result.save(os.path.join(self.args.save_dir, image_names[i]))
        #
        #     iou = inter / union if union != 0 else 0
        #     print(f"Test IoU: {iou:.4f}")
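

# Minimal usage sketch (illustrative, not part of the original pipeline). The checkpoint
# and image paths below are placeholders, and lr only matters for training.
if __name__ == "__main__":
    cd = ChangeDetectionModule(gpu_id="0", img_size=256, lr=1e-4, batch_size=2)
    cd.load_model("checkpoints/best_model.pth")  # weights written by _save_best_model()
    mask_path = cd.predict_from_imgurl("data/image1/1.png",  # pre-change image
                                       "data/image2/1.png",  # post-change image
                                       save_dir="outputs")
    print(f"Stitched change mask: {mask_path}")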
|