139 lines
6.3 KiB
Python
139 lines
6.3 KiB
Python
import argparse
|
|
import os
|
|
# from change_detection import ChangeDetection
|
|
from PIL import Image
|
|
import numpy as np
|
|
|
|
from Ai_tottle.cropland_module.change_detection_module import ChangeDetectionModule
|
|
|
|
|
|
# from cropland_module.change_detection_module import ChangeDetectionModule
|
|
|
|
|
|
def process_and_replace_images(input_folder):
|
|
"""处理输入文件夹中的图像,并直接覆盖原文件"""
|
|
# 支持的文件格式
|
|
extensions = {'.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp'}
|
|
|
|
# 处理每个文件
|
|
for filename in os.listdir(input_folder):
|
|
if os.path.splitext(filename)[1].lower() in extensions:
|
|
filepath = os.path.join(input_folder, filename)
|
|
|
|
# 处理图像
|
|
img = Image.open(filepath)
|
|
arr = np.array(img)
|
|
if arr.dtype == np.float32 or arr.max() <= 1.0:
|
|
arr = (arr * 255).clip(0, 255).astype('uint8')
|
|
else:
|
|
arr = arr.astype('uint8')
|
|
result = Image.fromarray(arr)
|
|
result.save(filepath)
|
|
|
|
|
|
def parse_args():
|
|
parser = argparse.ArgumentParser(description='Change Detection Pipeline')
|
|
parser.add_argument('--suffix', nargs='+', default=['.tif', '.png', '.jpg'],
|
|
help='Supported image suffixes (default: .tif .png .jpg)')
|
|
parser.add_argument('--lr', type=float, default=0.0001,
|
|
help='Learning rate (default: 0.0001)')
|
|
# 让用户交互式选择模式
|
|
mode = input("请选择模式 (输入 1 或 2):\n1. 训练模式\n2. 测试模式\n").strip()
|
|
while mode not in ['1', '2']:
|
|
mode = input("输入无效!请重新输入 1 或 2: ").strip()
|
|
|
|
# 基本参数
|
|
parser.add_argument('--mode', default='train' if mode == '1' else 'test',
|
|
help='train or test mode (自动设置)')
|
|
parser.add_argument('--gpu_id', default="0", type=str)
|
|
parser.add_argument('--n_class', default=2, type=int)
|
|
parser.add_argument('--img_size', default=256, type=int)
|
|
|
|
# 训练参数默认值
|
|
train_defaults = {
|
|
'hr1_train': '/media/data0/HL/6-3_CLCD/train_1/time1', # 训练集,第一期影像路径
|
|
'hr2_train': '/media/data0/HL/6-3_CLCD/train_1/time2', # 训练集,第二期影像路径
|
|
'lab_train': '/media/data0/HL/6-3_CLCD/train_1/label', # 训练集,标签路径
|
|
'hr1_val': '/media/data0/HL/6-3_CLCD/val_1/time1', # 验证集,第一期影像路径
|
|
'hr2_val': '/media/data0/HL/6-3_CLCD/val_1/time2', # 验证集,第二期影像路径
|
|
'lab_val': '/media/data0/HL/6-3_CLCD/val_1/label', # 验证集,标签路径
|
|
'model_dir': '/media/data0/HL/CropLand-CD-main_3/CropLand-CD-main/model/save_epoch' # 最优权重保存路径
|
|
}
|
|
|
|
# # 测试参数默认值
|
|
# test_defaults = {
|
|
# 'model_path': r'D:\code\612\model+best_epoch\best_epoch\crop\netCD_epoch_22.pth',#最优权重保存路径
|
|
# 'path_img1': r'D:\code\612\test_1\time1',#测试集,第一期影像路径
|
|
# 'path_img2': r'D:\code\612\test_1\time2',#测试集,第二期影像路径
|
|
# 'path_lab': r'D:\code\612\test_1\label',#测试集,标签路径
|
|
# 'save_dir': r'D:\code\612\model+best_epoch\CropLand-CD-main_3\res'#测试保存路径
|
|
# }
|
|
|
|
# # 测试参数默认值
|
|
# test_defaults = {
|
|
# 'model_path': r'D:\project\dlfg\model+best_epoch\CropLand-CD-main_3\CropLand-CD-main\epochs\CLCD\MSCANet\netCD_epoch_22.pth', # 最优权重保存路径
|
|
# 'path_img1': r'D:\project\dlfg\CLCD\val_1\time1', # 测试集,第一期影像路径
|
|
# 'path_img2': r'D:\project\dlfg\CLCD\val_1\time2', # 测试集,第二期影像路径
|
|
# 'path_lab': r'D:\project\dlfg\CLCD\val_1\label', # 测试集,标签路径
|
|
# 'save_dir': r'D:\project\dlfg\model+best_epoch\CropLand-CD-main_3\CropLand-CD-main\res' # 测试保存路径
|
|
# }
|
|
|
|
# 测试参数默认值
|
|
test_defaults = {
|
|
'model_path': r'D:\project\dlfg\model+best_epoch\best_epoch\building\netCD_epoch_34.pth', # 最优权重保存路径
|
|
'path_img1': r'D:\project\dlfg\LEVIR-CD\test\A1', # 测试集,第一期影像路径
|
|
'path_img2': r'D:\project\dlfg\LEVIR-CD\test\B1', # 测试集,第二期影像路径
|
|
'path_lab': r'D:\project\dlfg\LEVIR-CD\test\out', # 测试集,标签路径
|
|
'save_dir': r'D:\project\dlfg\model+best_epoch\CropLand-CD-main_3\CropLand-CD-main\res' # 测试保存路径
|
|
}
|
|
|
|
# 根据模式添加参数
|
|
if mode == '1':
|
|
for name, value in train_defaults.items():
|
|
parser.add_argument(f'--{name}', default=value, type=str)
|
|
parser.add_argument('--num_epochs', default=100, type=int)
|
|
parser.add_argument('--batchsize', default=2, type=int)
|
|
parser.add_argument('--val_batchsize', default=16, type=int)
|
|
parser.add_argument('--num_workers', default=24, type=int)
|
|
|
|
else:
|
|
for name, value in test_defaults.items():
|
|
parser.add_argument(f'--{name}', default=value, type=str)
|
|
parser.add_argument('--batch_size', default=8, type=int)
|
|
|
|
return parser.parse_args([]) # 空列表表示不解析命令行参数
|
|
|
|
|
|
if __name__ == '__main__':
|
|
print("\n=== 变化检测系统 ===")
|
|
# args = parse_args()
|
|
gpu_id = "0"
|
|
img_size = 256
|
|
lr = 0.0001
|
|
model_path = r'D:\project\dlfg\model+best_epoch\best_epoch\building\netCD_epoch_34.pth'
|
|
|
|
# image1_path = r'D:\project\dlfg\LEVIR-CD\test\A1\00201.png'
|
|
# image1_path = r'C:\Users\14867\Desktop\0613-building\Axz\xz\270\00201.jpg'
|
|
|
|
# # 平移 5%-10%
|
|
# image1_path = r'C:\Users\14867\Desktop\0613-building\Axz\pingyi\00201.png'
|
|
# image2_path = r'D:\project\dlfg\LEVIR-CD\test\B1\00201.png'
|
|
image1_path = r'C:\Users\14867\Desktop\0613-building\Axz\pingyi\00205.png'
|
|
image2_path = r'C:\Users\14867\Desktop\0613-building\B\00205.png'
|
|
# image1_path=r'C:\Users\14867\Desktop\0613-building\B\00201.png'
|
|
# image2_path=r'C:\Users\14867\Desktop\0613-building\A\00201.png'
|
|
|
|
cd = ChangeDetectionModule(gpu_id, img_size, lr)
|
|
cd.load_model(model_path)
|
|
|
|
from PIL import Image
|
|
|
|
image1 = Image.open(image1_path).convert('RGB')
|
|
image2 = Image.open(image2_path).convert('RGB')
|
|
|
|
output, binary_result = cd.predict_without_label(image1, image2)
|
|
|
|
# 保存结果(可选)
|
|
result = Image.fromarray(binary_result)
|
|
result.save("output.png")
|