140 lines
6.6 KiB
Python
Raw Normal View History

import argparse
import os
from change_detection import ChangeDetection
from PIL import Image
from znzh_x import PredictionVisualizer,ChangeVisualizer
import numpy as np
def process_and_replace_images(input_folder):
"""处理输入文件夹中的图像,并直接覆盖原文件"""
# 支持的文件格式
extensions = {'.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp'}
# 处理每个文件
for filename in os.listdir(input_folder):
if os.path.splitext(filename)[1].lower() in extensions:
filepath = os.path.join(input_folder, filename)
# 处理图像
img = Image.open(filepath)
arr = np.array(img)
if arr.dtype == np.float32 or arr.max() <= 1.0:
arr = (arr * 255).clip(0, 255).astype('uint8')
else:
arr = arr.astype('uint8')
result = Image.fromarray(arr)
result.save(filepath)
def parse_args():
parser = argparse.ArgumentParser(description='Change Detection Pipeline')
parser.add_argument('--suffix', nargs='+', default=['.tif','.png','.jpg'],
help='Supported image suffixes (default: .tif .png .jpg)')
parser.add_argument('--lr', type=float, default=0.0001,
help='Learning rate (default: 0.0001)')
# 让用户交互式选择模式
mode = input("请选择模式 (输入 1 或 2):\n1. 训练模式\n2. 测试模式\n").strip()
while mode not in ['1', '2']:
mode = input("输入无效!请重新输入 1 或 2: ").strip()
# 基本参数
parser.add_argument('--mode', default='train' if mode == '1' else 'test',
help='train or test mode (自动设置)')
parser.add_argument('--gpu_id', default="3", type=str)
parser.add_argument('--n_class', default=2, type=int)
parser.add_argument('--img_size', default=256, type=int)
#训练参数默认值
train_defaults = {
'hr1_train': r'/media/data0/HL/2025-7-6/train/image1',#训练集,第一期影像路径
'hr2_train': r'/media/data0/HL/2025-7-6/train/image2',#训练集,第二期影像路径
'lab_train': r'/media/data0/HL/2025-7-6/train/label',#训练集o签路径
'hr1_val': r'/media/data0/HL/2025-7-6/val/image1',#验证集,第一期影像路径
'hr2_val': r'/media/data0/HL/2025-7-6/val/image2',#验证集,第二期影像路径
'lab_val': r'/media/data0/HL/2025-7-6/val/label',#验证集,标签路径
'model_dir': r'/media/data0/HL/CropLand-CD-main_3/CropLand-CD-main/save_epoch'#最优权重保存路径
}
# train_defaults = {
# 'hr1_train': r'/media/data0/HL/LEVIR-CD/train/A',#训练集,第一期影像路径
# 'hr2_train': r'/media/data0/HL/LEVIR-CD/train/B',#训练集,第二期影像路径
# 'lab_train': r'/media/data0/HL/LEVIR-CD/train/OUT',#训练集o签路径
# 'hr1_val': r'/media/data0/HL/LEVIR-CD/val/A',#验证集,第一期影像路径
# 'hr2_val': r'/media/data0/HL/LEVIR-CD/val/B',#验证集,第二期影像路径
# 'lab_val': r'/media/data0/HL/LEVIR-CD/val/OUT',#验证集,标签路径
# 'model_dir': r'/media/data0/HL/CropLand-CD-main_3/CropLand-CD-main/save_epoch'#最优权重保存路径
# }
# # 测试参数默认值
# test_defaults = {
# 'model_path': r'/media/data0/HL/CropLand-CD-main_3/CropLand-CD-main/save_epoch/best_model.pth',#最优权重保存结果
# 'path_img1': r'/media/data0/HL/2025-7-6/val/image1',#测试集,第一期影像路径
# 'path_img2': r'/media/data0/HL/2025-7-6/val/image2',#测试集,第二期影像路径
# 'path_lab': r'/media/data0/HL/2025-7-6/val/label',#测试集,标签路径
# 'save_dir': r'/media/data0/HL/CropLand-CD-main_3/CropLand-CD-main/save'#测试保存路径
#
# }
# 测试参数默认值
test_defaults = {
'model_path': r'D:\project\2025-7-10_data+model\CropLand-CD-main_3\CropLand-CD-main\save_epoch\best_model.pth',#最优权重保存结果
'path_img1': r'D:\project\2025-7-10_data+model\val\image1',#测试集,第一期影像路径
'path_img2': r'D:\project\2025-7-10_data+model\val\image2',#测试集,第二期影像路径
# 'path_img1': r'C:\Users\14867\Desktop\seg_tiff', # 测试集,第一期影像路径
# 'path_img2': r'C:\Users\14867\Desktop\seg_tiff', # 测试集,第二期影像路径
# 'path_lab': r'D:\project\2025-7-10_data+model\val\label',#测试集,标签路径
'save_dir': r'D:\project\2025-7-10_data+model\CropLand-CD-main_3\CropLand-CD-main\save'#测试保存路径
}
# 根据模式添加参数
if mode == '1':
for name, value in train_defaults.items():
parser.add_argument(f'--{name}', default=value, type=str)
parser.add_argument('--num_epochs', default=100, type=int)
parser.add_argument('--batchsize', default=2, type=int)
parser.add_argument('--val_batchsize', default=16, type=int)
parser.add_argument('--num_workers', default=24, type=int)
else:
for name, value in test_defaults.items():
parser.add_argument(f'--{name}', default=value, type=str)
parser.add_argument('--batch_size', default=8, type=int)
return parser.parse_args([]) # 空列表表示不解析命令行参数
if __name__ == '__main__':
print("\n=== 变化检测系统 ===")
args = parse_args()
cd = ChangeDetection(args)
if args.mode == 'train':
print(f"\n开始训练...\n模型将保存至: {args.model_dir}")
cd.train()
else:
print(f"\n开始测试...\n结果将保存至: {args.save_dir}")
img1=r"D:\project\2025-7-10_data+model\val\image1\2.png"
img2=r"D:\project\2025-7-10_data+model\val\image2\2.png"
cd.load_model(args.model_path)
save_dir= r'D:\project\2025-7-10_data+model\CropLand-CD-main_3\CropLand-CD-main\save'
out_save_path=cd.predict_from_imgurl(img1,img2,save_dir)
# input_folder =r'D:\project\2025-7-10_data+model\CropLand-CD-main_3\CropLand-CD-main\save' #测试保存路径
# process_and_replace_images(input_folder)
# image1_dir = args.path_img1
# image2_dir = args.path_img2
# pred_mask_dir = args.path_lab
# # pred_mask_dir = ""
# save_dir = r'D:\project\2025-7-10_data+model\CropLand-CD-main_3\CropLand-CD-main\save_red'#变化标注回影像保存路径
#
# visualizer = ChangeVisualizer(image1_dir, image2_dir, pred_mask_dir, save_dir)
# visualizer.run()
# print("测试完成,变化框回标成功")