import argparse import os from change_detection import ChangeDetection from PIL import Image from znzh_x import PredictionVisualizer,ChangeVisualizer import numpy as np def process_and_replace_images(input_folder): """处理输入文件夹中的图像,并直接覆盖原文件""" # 支持的文件格式 extensions = {'.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp'} # 处理每个文件 for filename in os.listdir(input_folder): if os.path.splitext(filename)[1].lower() in extensions: filepath = os.path.join(input_folder, filename) # 处理图像 img = Image.open(filepath) arr = np.array(img) if arr.dtype == np.float32 or arr.max() <= 1.0: arr = (arr * 255).clip(0, 255).astype('uint8') else: arr = arr.astype('uint8') result = Image.fromarray(arr) result.save(filepath) def parse_args(): parser = argparse.ArgumentParser(description='Change Detection Pipeline') parser.add_argument('--suffix', nargs='+', default=['.tif','.png','.jpg'], help='Supported image suffixes (default: .tif .png .jpg)') parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate (default: 0.0001)') # 让用户交互式选择模式 mode = input("请选择模式 (输入 1 或 2):\n1. 训练模式\n2. 测试模式\n").strip() while mode not in ['1', '2']: mode = input("输入无效!请重新输入 1 或 2: ").strip() # 基本参数 parser.add_argument('--mode', default='train' if mode == '1' else 'test', help='train or test mode (自动设置)') parser.add_argument('--gpu_id', default="3", type=str) parser.add_argument('--n_class', default=2, type=int) parser.add_argument('--img_size', default=256, type=int) #训练参数默认值 train_defaults = { 'hr1_train': r'/media/data0/HL/2025-7-6/train/image1',#训练集,第一期影像路径 'hr2_train': r'/media/data0/HL/2025-7-6/train/image2',#训练集,第二期影像路径 'lab_train': r'/media/data0/HL/2025-7-6/train/label',#训练集,o签路径 'hr1_val': r'/media/data0/HL/2025-7-6/val/image1',#验证集,第一期影像路径 'hr2_val': r'/media/data0/HL/2025-7-6/val/image2',#验证集,第二期影像路径 'lab_val': r'/media/data0/HL/2025-7-6/val/label',#验证集,标签路径 'model_dir': r'/media/data0/HL/CropLand-CD-main_3/CropLand-CD-main/save_epoch'#最优权重保存路径 } # train_defaults = { # 'hr1_train': r'/media/data0/HL/LEVIR-CD/train/A',#训练集,第一期影像路径 # 'hr2_train': r'/media/data0/HL/LEVIR-CD/train/B',#训练集,第二期影像路径 # 'lab_train': r'/media/data0/HL/LEVIR-CD/train/OUT',#训练集,o签路径 # 'hr1_val': r'/media/data0/HL/LEVIR-CD/val/A',#验证集,第一期影像路径 # 'hr2_val': r'/media/data0/HL/LEVIR-CD/val/B',#验证集,第二期影像路径 # 'lab_val': r'/media/data0/HL/LEVIR-CD/val/OUT',#验证集,标签路径 # 'model_dir': r'/media/data0/HL/CropLand-CD-main_3/CropLand-CD-main/save_epoch'#最优权重保存路径 # } # # 测试参数默认值 # test_defaults = { # 'model_path': r'/media/data0/HL/CropLand-CD-main_3/CropLand-CD-main/save_epoch/best_model.pth',#最优权重保存结果 # 'path_img1': r'/media/data0/HL/2025-7-6/val/image1',#测试集,第一期影像路径 # 'path_img2': r'/media/data0/HL/2025-7-6/val/image2',#测试集,第二期影像路径 # 'path_lab': r'/media/data0/HL/2025-7-6/val/label',#测试集,标签路径 # 'save_dir': r'/media/data0/HL/CropLand-CD-main_3/CropLand-CD-main/save'#测试保存路径 # # } # 测试参数默认值 test_defaults = { 'model_path': r'D:\project\2025-7-10_data+model\CropLand-CD-main_3\CropLand-CD-main\save_epoch\best_model.pth',#最优权重保存结果 'path_img1': r'D:\project\2025-7-10_data+model\val\image1',#测试集,第一期影像路径 'path_img2': r'D:\project\2025-7-10_data+model\val\image2',#测试集,第二期影像路径 # 'path_img1': r'C:\Users\14867\Desktop\seg_tiff', # 测试集,第一期影像路径 # 'path_img2': r'C:\Users\14867\Desktop\seg_tiff', # 测试集,第二期影像路径 # 'path_lab': r'D:\project\2025-7-10_data+model\val\label',#测试集,标签路径 'save_dir': r'D:\project\2025-7-10_data+model\CropLand-CD-main_3\CropLand-CD-main\save'#测试保存路径 } # 根据模式添加参数 if mode == '1': for name, value in train_defaults.items(): parser.add_argument(f'--{name}', default=value, type=str) parser.add_argument('--num_epochs', default=100, type=int) parser.add_argument('--batchsize', default=2, type=int) parser.add_argument('--val_batchsize', default=16, type=int) parser.add_argument('--num_workers', default=24, type=int) else: for name, value in test_defaults.items(): parser.add_argument(f'--{name}', default=value, type=str) parser.add_argument('--batch_size', default=8, type=int) return parser.parse_args([]) # 空列表表示不解析命令行参数 if __name__ == '__main__': print("\n=== 变化检测系统 ===") args = parse_args() cd = ChangeDetection(args) if args.mode == 'train': print(f"\n开始训练...\n模型将保存至: {args.model_dir}") cd.train() else: print(f"\n开始测试...\n结果将保存至: {args.save_dir}") img1=r"D:\project\2025-7-10_data+model\val\image1\2.png" img2=r"D:\project\2025-7-10_data+model\val\image2\2.png" cd.load_model(args.model_path) save_dir= r'D:\project\2025-7-10_data+model\CropLand-CD-main_3\CropLand-CD-main\save' out_save_path=cd.predict_from_imgurl(img1,img2,save_dir) # input_folder =r'D:\project\2025-7-10_data+model\CropLand-CD-main_3\CropLand-CD-main\save' #测试保存路径 # process_and_replace_images(input_folder) # image1_dir = args.path_img1 # image2_dir = args.path_img2 # pred_mask_dir = args.path_lab # # pred_mask_dir = "" # save_dir = r'D:\project\2025-7-10_data+model\CropLand-CD-main_3\CropLand-CD-main\save_red'#变化标注回影像保存路径 # # visualizer = ChangeVisualizer(image1_dir, image2_dir, pred_mask_dir, save_dir) # visualizer.run() # print("测试完成,变化框回标成功")