import streamlit as st import os import cv2 import numpy as np import torch from PIL import Image import matplotlib.pyplot as plt import plotly.express as px import plotly.graph_objects as go from plotly.subplots import make_subplots import tempfile import shutil from change_detection import ChangeDetection from znzh_x import PredictionVisualizer import argparse import geopandas as gpd from io import BytesIO import base64 class VisualizationApp: def __init__(self): self.setup_page() self.model = None self.temp_dir = tempfile.mkdtemp() def setup_page(self): st.set_page_config( page_title="变化检测可视化系统", page_icon="🛰️", layout="wide", initial_sidebar_state="expanded" ) st.title("🛰️ 变化检测可视化系统") st.markdown("---") def load_model(self, model_path): """加载预训练模型""" try: # 创建模拟的args对象 args = argparse.Namespace( gpu_id="0", img_size=256, lr=0.0001, n_class=2, batch_size=1 ) self.model = ChangeDetection(args) self.model.load_model(model_path) return True except Exception as e: st.error(f"模型加载失败: {str(e)}") return False def preprocess_image(self, image, target_size=(256, 256)): """预处理图像""" if isinstance(image, Image.Image): image = np.array(image) # 调整大小 image = cv2.resize(image, target_size) # 归一化 image = image.astype(np.float32) / 255.0 # 转换为tensor image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0) return image def predict_change(self, img1, img2): """执行变化检测预测""" if self.model is None: return None try: # 预处理图像 img1_tensor = self.preprocess_image(img1) img2_tensor = self.preprocess_image(img2) # 移动到设备 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') img1_tensor = img1_tensor.to(device) img2_tensor = img2_tensor.to(device) # 预测 self.model.model.eval() with torch.no_grad(): output = self.model.model(img1_tensor, img2_tensor) if isinstance(output, tuple): output = output[0] # 获取预测结果 pred = torch.argmax(output, 1).squeeze().cpu().numpy() return pred except Exception as e: st.error(f"预测失败: {str(e)}") return None def create_comparison_plot(self, img1, img2, pred_mask=None): """创建对比图""" if pred_mask is not None: fig = make_subplots( rows=2, cols=2, subplot_titles=["第一期影像", "第二期影像", "变化检测结果", "叠加显示"], specs=[[{"type": "image"}, {"type": "image"}], [{"type": "image"}, {"type": "image"}]] ) # 第一期影像 fig.add_trace( go.Image(z=img1), row=1, col=1 ) # 第二期影像 fig.add_trace( go.Image(z=img2), row=1, col=2 ) # 变化检测结果 change_mask_colored = np.zeros((*pred_mask.shape, 3), dtype=np.uint8) change_mask_colored[pred_mask > 0] = [255, 0, 0] # 红色表示变化 fig.add_trace( go.Image(z=change_mask_colored), row=2, col=1 ) # 叠加显示 overlay = img2.copy() if len(overlay.shape) == 3: overlay[pred_mask > 0] = [255, 0, 0] # 在第二期影像上标红 fig.add_trace( go.Image(z=overlay), row=2, col=2 ) else: fig = make_subplots( rows=1, cols=2, subplot_titles=["第一期影像", "第二期影像"], specs=[[{"type": "image"}, {"type": "image"}]] ) fig.add_trace(go.Image(z=img1), row=1, col=1) fig.add_trace(go.Image(z=img2), row=1, col=2) fig.update_layout( height=800, showlegend=False, title_text="影像对比分析" ) # 隐藏坐标轴 fig.update_xaxes(showticklabels=False) fig.update_yaxes(showticklabels=False) return fig def calculate_change_statistics(self, pred_mask): """计算变化统计信息""" total_pixels = pred_mask.size changed_pixels = np.sum(pred_mask > 0) unchanged_pixels = total_pixels - changed_pixels change_ratio = (changed_pixels / total_pixels) * 100 return { "总像素数": total_pixels, "变化像素数": changed_pixels, "未变化像素数": unchanged_pixels, "变化比例": f"{change_ratio:.2f}%" } def export_results(self, img1, img2, pred_mask, base_name="result"): """导出结果""" try: # 保存图像到临时目录 img1_path = os.path.join(self.temp_dir, f"{base_name}_img1.png") img2_path = os.path.join(self.temp_dir, f"{base_name}_img2.png") mask_path = os.path.join(self.temp_dir, f"{base_name}_mask.png") Image.fromarray(img1).save(img1_path) Image.fromarray(img2).save(img2_path) Image.fromarray((pred_mask * 255).astype(np.uint8)).save(mask_path) # 使用PredictionVisualizer生成可视化结果 visualizer = PredictionVisualizer(img1_path, mask_path) result_img, contours = visualizer.draw_contours_on_image() # 转换为可下载的格式 result_pil = Image.fromarray(cv2.cvtColor(result_img, cv2.COLOR_BGR2RGB)) return result_pil, contours except Exception as e: st.error(f"导出失败: {str(e)}") return None, None def run(self): """运行主界面""" # 侧边栏配置 with st.sidebar: st.header("⚙️ 系统配置") # 模型加载 st.subheader("模型设置") model_path = st.text_input( "模型路径", value="/media/data0/HL/2025-7-6/best_model.pth", help="输入预训练模型的路径" ) if st.button("加载模型"): if os.path.exists(model_path): if self.load_model(model_path): st.success("✅ 模型加载成功!") else: st.error("❌ 模型加载失败!") else: st.error("❌ 模型文件不存在!") # 显示模型状态 if self.model is not None: st.success("🟢 模型已就绪") else: st.warning("🟡 请先加载模型") st.markdown("---") # 参数设置 st.subheader("检测参数") min_area = st.slider("最小变化区域(像素)", 10, 1000, 100) thickness = st.slider("轮廓线粗细", 1, 10, 2) # 主界面 col1, col2 = st.columns(2) with col1: st.subheader("📤 第一期影像") uploaded_img1 = st.file_uploader( "选择第一期影像", type=['png', 'jpg', 'jpeg', 'tif', 'tiff'], key="img1" ) with col2: st.subheader("📤 第二期影像") uploaded_img2 = st.file_uploader( "选择第二期影像", type=['png', 'jpg', 'jpeg', 'tif', 'tiff'], key="img2" ) # 处理上传的图像 if uploaded_img1 and uploaded_img2: # 读取图像 img1 = np.array(Image.open(uploaded_img1)) img2 = np.array(Image.open(uploaded_img2)) # 确保图像尺寸一致 if img1.shape != img2.shape: st.warning("⚠️ 两期影像尺寸不一致,将自动调整") target_shape = (min(img1.shape[0], img2.shape[0]), min(img1.shape[1], img2.shape[1])) img1 = cv2.resize(img1, (target_shape[1], target_shape[0])) img2 = cv2.resize(img2, (target_shape[1], target_shape[0])) # 显示原始图像对比 st.subheader("📊 影像对比") fig_original = self.create_comparison_plot(img1, img2) st.plotly_chart(fig_original, use_container_width=True) # 变化检测按钮 if st.button("🔍 执行变化检测", type="primary"): if self.model is None: st.error("❌ 请先加载模型!") else: with st.spinner("正在进行变化检测..."): pred_mask = self.predict_change(img1, img2) if pred_mask is not None: # 显示检测结果 st.subheader("🎯 变化检测结果") fig_result = self.create_comparison_plot(img1, img2, pred_mask) st.plotly_chart(fig_result, use_container_width=True) # 统计信息 st.subheader("📈 变化统计") stats = self.calculate_change_statistics(pred_mask) col1, col2, col3, col4 = st.columns(4) with col1: st.metric("总像素数", stats["总像素数"]) with col2: st.metric("变化像素数", stats["变化像素数"]) with col3: st.metric("未变化像素数", stats["未变化像素数"]) with col4: st.metric("变化比例", stats["变化比例"]) # 导出功能 st.subheader("💾 结果导出") col1, col2 = st.columns(2) with col1: if st.button("导出可视化结果"): result_img, contours = self.export_results(img1, img2, pred_mask) if result_img is not None: st.image(result_img, caption="变化检测可视化结果") # 提供下载链接 buf = BytesIO() result_img.save(buf, format='PNG') buf.seek(0) st.download_button( label="下载结果图像", data=buf.getvalue(), file_name="change_detection_result.png", mime="image/png" ) with col2: if st.button("生成地理坐标"): st.info("💡 地理坐标导出需要输入影像包含地理参考信息") # 这里可以添加shapefile导出功能 else: st.error("❌ 变化检测失败!") # 使用说明 with st.expander("📖 使用说明"): st.markdown(""" ### 使用步骤: 1. **加载模型**:在侧边栏输入模型路径并点击"加载模型" 2. **上传影像**:分别上传第一期和第二期影像 3. **执行检测**:点击"执行变化检测"按钮 4. **查看结果**:查看变化检测结果和统计信息 5. **导出结果**:下载可视化结果或地理坐标文件 ### 支持格式: - 图像格式:PNG, JPG, JPEG, TIF, TIFF - 模型格式:PyTorch (.pth) ### 注意事项: - 确保两期影像尺寸相近 - 模型路径必须正确 - 地理坐标导出需要影像包含地理参考信息 """) if __name__ == "__main__": app = VisualizationApp() app.run()