139 lines
6.3 KiB
Python

import argparse
import os
# from change_detection import ChangeDetection
from PIL import Image
import numpy as np
from Ai_tottle.cropland_module.change_detection_module import ChangeDetectionModule
# from cropland_module.change_detection_module import ChangeDetectionModule
def process_and_replace_images(input_folder):
"""处理输入文件夹中的图像,并直接覆盖原文件"""
# 支持的文件格式
extensions = {'.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp'}
# 处理每个文件
for filename in os.listdir(input_folder):
if os.path.splitext(filename)[1].lower() in extensions:
filepath = os.path.join(input_folder, filename)
# 处理图像
img = Image.open(filepath)
arr = np.array(img)
if arr.dtype == np.float32 or arr.max() <= 1.0:
arr = (arr * 255).clip(0, 255).astype('uint8')
else:
arr = arr.astype('uint8')
result = Image.fromarray(arr)
result.save(filepath)
def parse_args():
parser = argparse.ArgumentParser(description='Change Detection Pipeline')
parser.add_argument('--suffix', nargs='+', default=['.tif', '.png', '.jpg'],
help='Supported image suffixes (default: .tif .png .jpg)')
parser.add_argument('--lr', type=float, default=0.0001,
help='Learning rate (default: 0.0001)')
# 让用户交互式选择模式
mode = input("请选择模式 (输入 1 或 2):\n1. 训练模式\n2. 测试模式\n").strip()
while mode not in ['1', '2']:
mode = input("输入无效!请重新输入 1 或 2: ").strip()
# 基本参数
parser.add_argument('--mode', default='train' if mode == '1' else 'test',
help='train or test mode (自动设置)')
parser.add_argument('--gpu_id', default="0", type=str)
parser.add_argument('--n_class', default=2, type=int)
parser.add_argument('--img_size', default=256, type=int)
# 训练参数默认值
train_defaults = {
'hr1_train': '/media/data0/HL/6-3_CLCD/train_1/time1', # 训练集,第一期影像路径
'hr2_train': '/media/data0/HL/6-3_CLCD/train_1/time2', # 训练集,第二期影像路径
'lab_train': '/media/data0/HL/6-3_CLCD/train_1/label', # 训练集,标签路径
'hr1_val': '/media/data0/HL/6-3_CLCD/val_1/time1', # 验证集,第一期影像路径
'hr2_val': '/media/data0/HL/6-3_CLCD/val_1/time2', # 验证集,第二期影像路径
'lab_val': '/media/data0/HL/6-3_CLCD/val_1/label', # 验证集,标签路径
'model_dir': '/media/data0/HL/CropLand-CD-main_3/CropLand-CD-main/model/save_epoch' # 最优权重保存路径
}
# # 测试参数默认值
# test_defaults = {
# 'model_path': r'D:\code\612\model+best_epoch\best_epoch\crop\netCD_epoch_22.pth',#最优权重保存路径
# 'path_img1': r'D:\code\612\test_1\time1',#测试集,第一期影像路径
# 'path_img2': r'D:\code\612\test_1\time2',#测试集,第二期影像路径
# 'path_lab': r'D:\code\612\test_1\label',#测试集,标签路径
# 'save_dir': r'D:\code\612\model+best_epoch\CropLand-CD-main_3\res'#测试保存路径
# }
# # 测试参数默认值
# test_defaults = {
# 'model_path': r'D:\project\dlfg\model+best_epoch\CropLand-CD-main_3\CropLand-CD-main\epochs\CLCD\MSCANet\netCD_epoch_22.pth', # 最优权重保存路径
# 'path_img1': r'D:\project\dlfg\CLCD\val_1\time1', # 测试集,第一期影像路径
# 'path_img2': r'D:\project\dlfg\CLCD\val_1\time2', # 测试集,第二期影像路径
# 'path_lab': r'D:\project\dlfg\CLCD\val_1\label', # 测试集,标签路径
# 'save_dir': r'D:\project\dlfg\model+best_epoch\CropLand-CD-main_3\CropLand-CD-main\res' # 测试保存路径
# }
# 测试参数默认值
test_defaults = {
'model_path': r'D:\project\dlfg\model+best_epoch\best_epoch\building\netCD_epoch_34.pth', # 最优权重保存路径
'path_img1': r'D:\project\dlfg\LEVIR-CD\test\A1', # 测试集,第一期影像路径
'path_img2': r'D:\project\dlfg\LEVIR-CD\test\B1', # 测试集,第二期影像路径
'path_lab': r'D:\project\dlfg\LEVIR-CD\test\out', # 测试集,标签路径
'save_dir': r'D:\project\dlfg\model+best_epoch\CropLand-CD-main_3\CropLand-CD-main\res' # 测试保存路径
}
# 根据模式添加参数
if mode == '1':
for name, value in train_defaults.items():
parser.add_argument(f'--{name}', default=value, type=str)
parser.add_argument('--num_epochs', default=100, type=int)
parser.add_argument('--batchsize', default=2, type=int)
parser.add_argument('--val_batchsize', default=16, type=int)
parser.add_argument('--num_workers', default=24, type=int)
else:
for name, value in test_defaults.items():
parser.add_argument(f'--{name}', default=value, type=str)
parser.add_argument('--batch_size', default=8, type=int)
return parser.parse_args([]) # 空列表表示不解析命令行参数
if __name__ == '__main__':
print("\n=== 变化检测系统 ===")
# args = parse_args()
gpu_id = "0"
img_size = 256
lr = 0.0001
model_path = r'D:\project\dlfg\model+best_epoch\best_epoch\building\netCD_epoch_34.pth'
# image1_path = r'D:\project\dlfg\LEVIR-CD\test\A1\00201.png'
# image1_path = r'C:\Users\14867\Desktop\0613-building\Axz\xz\270\00201.jpg'
# # 平移 5%-10%
# image1_path = r'C:\Users\14867\Desktop\0613-building\Axz\pingyi\00201.png'
# image2_path = r'D:\project\dlfg\LEVIR-CD\test\B1\00201.png'
image1_path = r'C:\Users\14867\Desktop\0613-building\Axz\pingyi\00205.png'
image2_path = r'C:\Users\14867\Desktop\0613-building\B\00205.png'
# image1_path=r'C:\Users\14867\Desktop\0613-building\B\00201.png'
# image2_path=r'C:\Users\14867\Desktop\0613-building\A\00201.png'
cd = ChangeDetectionModule(gpu_id, img_size, lr)
cd.load_model(model_path)
from PIL import Image
image1 = Image.open(image1_path).convert('RGB')
image2 = Image.open(image2_path).convert('RGB')
output, binary_result = cd.predict_without_label(image1, image2)
# 保存结果(可选)
result = Image.fromarray(binary_result)
result.save("output.png")