140 lines
6.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import os
from change_detection import ChangeDetection
from PIL import Image
from znzh_x import PredictionVisualizer,ChangeVisualizer
import numpy as np
def process_and_replace_images(input_folder):
"""处理输入文件夹中的图像,并直接覆盖原文件"""
# 支持的文件格式
extensions = {'.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp'}
# 处理每个文件
for filename in os.listdir(input_folder):
if os.path.splitext(filename)[1].lower() in extensions:
filepath = os.path.join(input_folder, filename)
# 处理图像
img = Image.open(filepath)
arr = np.array(img)
if arr.dtype == np.float32 or arr.max() <= 1.0:
arr = (arr * 255).clip(0, 255).astype('uint8')
else:
arr = arr.astype('uint8')
result = Image.fromarray(arr)
result.save(filepath)
def parse_args():
parser = argparse.ArgumentParser(description='Change Detection Pipeline')
parser.add_argument('--suffix', nargs='+', default=['.tif','.png','.jpg'],
help='Supported image suffixes (default: .tif .png .jpg)')
parser.add_argument('--lr', type=float, default=0.0001,
help='Learning rate (default: 0.0001)')
# 让用户交互式选择模式
mode = input("请选择模式 (输入 1 或 2):\n1. 训练模式\n2. 测试模式\n").strip()
while mode not in ['1', '2']:
mode = input("输入无效!请重新输入 1 或 2: ").strip()
# 基本参数
parser.add_argument('--mode', default='train' if mode == '1' else 'test',
help='train or test mode (自动设置)')
parser.add_argument('--gpu_id', default="3", type=str)
parser.add_argument('--n_class', default=2, type=int)
parser.add_argument('--img_size', default=256, type=int)
#训练参数默认值
train_defaults = {
'hr1_train': r'/media/data0/HL/2025-7-6/train/image1',#训练集,第一期影像路径
'hr2_train': r'/media/data0/HL/2025-7-6/train/image2',#训练集,第二期影像路径
'lab_train': r'/media/data0/HL/2025-7-6/train/label',#训练集o签路径
'hr1_val': r'/media/data0/HL/2025-7-6/val/image1',#验证集,第一期影像路径
'hr2_val': r'/media/data0/HL/2025-7-6/val/image2',#验证集,第二期影像路径
'lab_val': r'/media/data0/HL/2025-7-6/val/label',#验证集,标签路径
'model_dir': r'/media/data0/HL/CropLand-CD-main_3/CropLand-CD-main/save_epoch'#最优权重保存路径
}
# train_defaults = {
# 'hr1_train': r'/media/data0/HL/LEVIR-CD/train/A',#训练集,第一期影像路径
# 'hr2_train': r'/media/data0/HL/LEVIR-CD/train/B',#训练集,第二期影像路径
# 'lab_train': r'/media/data0/HL/LEVIR-CD/train/OUT',#训练集o签路径
# 'hr1_val': r'/media/data0/HL/LEVIR-CD/val/A',#验证集,第一期影像路径
# 'hr2_val': r'/media/data0/HL/LEVIR-CD/val/B',#验证集,第二期影像路径
# 'lab_val': r'/media/data0/HL/LEVIR-CD/val/OUT',#验证集,标签路径
# 'model_dir': r'/media/data0/HL/CropLand-CD-main_3/CropLand-CD-main/save_epoch'#最优权重保存路径
# }
# # 测试参数默认值
# test_defaults = {
# 'model_path': r'/media/data0/HL/CropLand-CD-main_3/CropLand-CD-main/save_epoch/best_model.pth',#最优权重保存结果
# 'path_img1': r'/media/data0/HL/2025-7-6/val/image1',#测试集,第一期影像路径
# 'path_img2': r'/media/data0/HL/2025-7-6/val/image2',#测试集,第二期影像路径
# 'path_lab': r'/media/data0/HL/2025-7-6/val/label',#测试集,标签路径
# 'save_dir': r'/media/data0/HL/CropLand-CD-main_3/CropLand-CD-main/save'#测试保存路径
#
# }
# 测试参数默认值
test_defaults = {
'model_path': r'D:\project\2025-7-10_data+model\CropLand-CD-main_3\CropLand-CD-main\save_epoch\best_model.pth',#最优权重保存结果
'path_img1': r'D:\project\2025-7-10_data+model\val\image1',#测试集,第一期影像路径
'path_img2': r'D:\project\2025-7-10_data+model\val\image2',#测试集,第二期影像路径
# 'path_img1': r'C:\Users\14867\Desktop\seg_tiff', # 测试集,第一期影像路径
# 'path_img2': r'C:\Users\14867\Desktop\seg_tiff', # 测试集,第二期影像路径
# 'path_lab': r'D:\project\2025-7-10_data+model\val\label',#测试集,标签路径
'save_dir': r'D:\project\2025-7-10_data+model\CropLand-CD-main_3\CropLand-CD-main\save'#测试保存路径
}
# 根据模式添加参数
if mode == '1':
for name, value in train_defaults.items():
parser.add_argument(f'--{name}', default=value, type=str)
parser.add_argument('--num_epochs', default=100, type=int)
parser.add_argument('--batchsize', default=2, type=int)
parser.add_argument('--val_batchsize', default=16, type=int)
parser.add_argument('--num_workers', default=24, type=int)
else:
for name, value in test_defaults.items():
parser.add_argument(f'--{name}', default=value, type=str)
parser.add_argument('--batch_size', default=8, type=int)
return parser.parse_args([]) # 空列表表示不解析命令行参数
if __name__ == '__main__':
print("\n=== 变化检测系统 ===")
args = parse_args()
cd = ChangeDetection(args)
if args.mode == 'train':
print(f"\n开始训练...\n模型将保存至: {args.model_dir}")
cd.train()
else:
print(f"\n开始测试...\n结果将保存至: {args.save_dir}")
img1=r"D:\project\2025-7-10_data+model\val\image1\2.png"
img2=r"D:\project\2025-7-10_data+model\val\image2\2.png"
cd.load_model(args.model_path)
save_dir= r'D:\project\2025-7-10_data+model\CropLand-CD-main_3\CropLand-CD-main\save'
out_save_path=cd.predict_from_imgurl(img1,img2,save_dir)
# input_folder =r'D:\project\2025-7-10_data+model\CropLand-CD-main_3\CropLand-CD-main\save' #测试保存路径
# process_and_replace_images(input_folder)
# image1_dir = args.path_img1
# image2_dir = args.path_img2
# pred_mask_dir = args.path_lab
# # pred_mask_dir = ""
# save_dir = r'D:\project\2025-7-10_data+model\CropLand-CD-main_3\CropLand-CD-main\save_red'#变化标注回影像保存路径
#
# visualizer = ChangeVisualizer(image1_dir, image2_dir, pred_mask_dir, save_dir)
# visualizer.run()
# print("测试完成,变化框回标成功")