357 lines
13 KiB
Python
357 lines
13 KiB
Python
import streamlit as st
|
||
import os
|
||
import cv2
|
||
import numpy as np
|
||
import torch
|
||
from PIL import Image
|
||
import matplotlib.pyplot as plt
|
||
import plotly.express as px
|
||
import plotly.graph_objects as go
|
||
from plotly.subplots import make_subplots
|
||
import tempfile
|
||
import shutil
|
||
from change_detection import ChangeDetection
|
||
from znzh_x import PredictionVisualizer
|
||
import argparse
|
||
import geopandas as gpd
|
||
from io import BytesIO
|
||
import base64
|
||
|
||
class VisualizationApp:
|
||
def __init__(self):
|
||
self.setup_page()
|
||
self.model = None
|
||
self.temp_dir = tempfile.mkdtemp()
|
||
|
||
def setup_page(self):
|
||
st.set_page_config(
|
||
page_title="变化检测可视化系统",
|
||
page_icon="🛰️",
|
||
layout="wide",
|
||
initial_sidebar_state="expanded"
|
||
)
|
||
|
||
st.title("🛰️ 变化检测可视化系统")
|
||
st.markdown("---")
|
||
|
||
def load_model(self, model_path):
|
||
"""加载预训练模型"""
|
||
try:
|
||
# 创建模拟的args对象
|
||
args = argparse.Namespace(
|
||
gpu_id="0",
|
||
img_size=256,
|
||
lr=0.0001,
|
||
n_class=2,
|
||
batch_size=1
|
||
)
|
||
|
||
self.model = ChangeDetection(args)
|
||
self.model.load_model(model_path)
|
||
return True
|
||
except Exception as e:
|
||
st.error(f"模型加载失败: {str(e)}")
|
||
return False
|
||
|
||
def preprocess_image(self, image, target_size=(256, 256)):
|
||
"""预处理图像"""
|
||
if isinstance(image, Image.Image):
|
||
image = np.array(image)
|
||
|
||
# 调整大小
|
||
image = cv2.resize(image, target_size)
|
||
|
||
# 归一化
|
||
image = image.astype(np.float32) / 255.0
|
||
|
||
# 转换为tensor
|
||
image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
|
||
|
||
return image
|
||
|
||
def predict_change(self, img1, img2):
|
||
"""执行变化检测预测"""
|
||
if self.model is None:
|
||
return None
|
||
|
||
try:
|
||
# 预处理图像
|
||
img1_tensor = self.preprocess_image(img1)
|
||
img2_tensor = self.preprocess_image(img2)
|
||
|
||
# 移动到设备
|
||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||
img1_tensor = img1_tensor.to(device)
|
||
img2_tensor = img2_tensor.to(device)
|
||
|
||
# 预测
|
||
self.model.model.eval()
|
||
with torch.no_grad():
|
||
output = self.model.model(img1_tensor, img2_tensor)
|
||
if isinstance(output, tuple):
|
||
output = output[0]
|
||
|
||
# 获取预测结果
|
||
pred = torch.argmax(output, 1).squeeze().cpu().numpy()
|
||
|
||
return pred
|
||
except Exception as e:
|
||
st.error(f"预测失败: {str(e)}")
|
||
return None
|
||
|
||
def create_comparison_plot(self, img1, img2, pred_mask=None):
|
||
"""创建对比图"""
|
||
if pred_mask is not None:
|
||
fig = make_subplots(
|
||
rows=2, cols=2,
|
||
subplot_titles=["第一期影像", "第二期影像", "变化检测结果", "叠加显示"],
|
||
specs=[[{"type": "image"}, {"type": "image"}],
|
||
[{"type": "image"}, {"type": "image"}]]
|
||
)
|
||
|
||
# 第一期影像
|
||
fig.add_trace(
|
||
go.Image(z=img1),
|
||
row=1, col=1
|
||
)
|
||
|
||
# 第二期影像
|
||
fig.add_trace(
|
||
go.Image(z=img2),
|
||
row=1, col=2
|
||
)
|
||
|
||
# 变化检测结果
|
||
change_mask_colored = np.zeros((*pred_mask.shape, 3), dtype=np.uint8)
|
||
change_mask_colored[pred_mask > 0] = [255, 0, 0] # 红色表示变化
|
||
|
||
fig.add_trace(
|
||
go.Image(z=change_mask_colored),
|
||
row=2, col=1
|
||
)
|
||
|
||
# 叠加显示
|
||
overlay = img2.copy()
|
||
if len(overlay.shape) == 3:
|
||
overlay[pred_mask > 0] = [255, 0, 0] # 在第二期影像上标红
|
||
|
||
fig.add_trace(
|
||
go.Image(z=overlay),
|
||
row=2, col=2
|
||
)
|
||
else:
|
||
fig = make_subplots(
|
||
rows=1, cols=2,
|
||
subplot_titles=["第一期影像", "第二期影像"],
|
||
specs=[[{"type": "image"}, {"type": "image"}]]
|
||
)
|
||
|
||
fig.add_trace(go.Image(z=img1), row=1, col=1)
|
||
fig.add_trace(go.Image(z=img2), row=1, col=2)
|
||
|
||
fig.update_layout(
|
||
height=800,
|
||
showlegend=False,
|
||
title_text="影像对比分析"
|
||
)
|
||
|
||
# 隐藏坐标轴
|
||
fig.update_xaxes(showticklabels=False)
|
||
fig.update_yaxes(showticklabels=False)
|
||
|
||
return fig
|
||
|
||
def calculate_change_statistics(self, pred_mask):
|
||
"""计算变化统计信息"""
|
||
total_pixels = pred_mask.size
|
||
changed_pixels = np.sum(pred_mask > 0)
|
||
unchanged_pixels = total_pixels - changed_pixels
|
||
|
||
change_ratio = (changed_pixels / total_pixels) * 100
|
||
|
||
return {
|
||
"总像素数": total_pixels,
|
||
"变化像素数": changed_pixels,
|
||
"未变化像素数": unchanged_pixels,
|
||
"变化比例": f"{change_ratio:.2f}%"
|
||
}
|
||
|
||
def export_results(self, img1, img2, pred_mask, base_name="result"):
|
||
"""导出结果"""
|
||
try:
|
||
# 保存图像到临时目录
|
||
img1_path = os.path.join(self.temp_dir, f"{base_name}_img1.png")
|
||
img2_path = os.path.join(self.temp_dir, f"{base_name}_img2.png")
|
||
mask_path = os.path.join(self.temp_dir, f"{base_name}_mask.png")
|
||
|
||
Image.fromarray(img1).save(img1_path)
|
||
Image.fromarray(img2).save(img2_path)
|
||
Image.fromarray((pred_mask * 255).astype(np.uint8)).save(mask_path)
|
||
|
||
# 使用PredictionVisualizer生成可视化结果
|
||
visualizer = PredictionVisualizer(img1_path, mask_path)
|
||
result_img, contours = visualizer.draw_contours_on_image()
|
||
|
||
# 转换为可下载的格式
|
||
result_pil = Image.fromarray(cv2.cvtColor(result_img, cv2.COLOR_BGR2RGB))
|
||
|
||
return result_pil, contours
|
||
|
||
except Exception as e:
|
||
st.error(f"导出失败: {str(e)}")
|
||
return None, None
|
||
|
||
def run(self):
|
||
"""运行主界面"""
|
||
# 侧边栏配置
|
||
with st.sidebar:
|
||
st.header("⚙️ 系统配置")
|
||
|
||
# 模型加载
|
||
st.subheader("模型设置")
|
||
model_path = st.text_input(
|
||
"模型路径",
|
||
value="/media/data0/HL/2025-7-6/best_model.pth",
|
||
help="输入预训练模型的路径"
|
||
)
|
||
|
||
if st.button("加载模型"):
|
||
if os.path.exists(model_path):
|
||
if self.load_model(model_path):
|
||
st.success("✅ 模型加载成功!")
|
||
else:
|
||
st.error("❌ 模型加载失败!")
|
||
else:
|
||
st.error("❌ 模型文件不存在!")
|
||
|
||
# 显示模型状态
|
||
if self.model is not None:
|
||
st.success("🟢 模型已就绪")
|
||
else:
|
||
st.warning("🟡 请先加载模型")
|
||
|
||
st.markdown("---")
|
||
|
||
# 参数设置
|
||
st.subheader("检测参数")
|
||
min_area = st.slider("最小变化区域(像素)", 10, 1000, 100)
|
||
thickness = st.slider("轮廓线粗细", 1, 10, 2)
|
||
|
||
# 主界面
|
||
col1, col2 = st.columns(2)
|
||
|
||
with col1:
|
||
st.subheader("📤 第一期影像")
|
||
uploaded_img1 = st.file_uploader(
|
||
"选择第一期影像",
|
||
type=['png', 'jpg', 'jpeg', 'tif', 'tiff'],
|
||
key="img1"
|
||
)
|
||
|
||
with col2:
|
||
st.subheader("📤 第二期影像")
|
||
uploaded_img2 = st.file_uploader(
|
||
"选择第二期影像",
|
||
type=['png', 'jpg', 'jpeg', 'tif', 'tiff'],
|
||
key="img2"
|
||
)
|
||
|
||
# 处理上传的图像
|
||
if uploaded_img1 and uploaded_img2:
|
||
# 读取图像
|
||
img1 = np.array(Image.open(uploaded_img1))
|
||
img2 = np.array(Image.open(uploaded_img2))
|
||
|
||
# 确保图像尺寸一致
|
||
if img1.shape != img2.shape:
|
||
st.warning("⚠️ 两期影像尺寸不一致,将自动调整")
|
||
target_shape = (min(img1.shape[0], img2.shape[0]),
|
||
min(img1.shape[1], img2.shape[1]))
|
||
img1 = cv2.resize(img1, (target_shape[1], target_shape[0]))
|
||
img2 = cv2.resize(img2, (target_shape[1], target_shape[0]))
|
||
|
||
# 显示原始图像对比
|
||
st.subheader("📊 影像对比")
|
||
fig_original = self.create_comparison_plot(img1, img2)
|
||
st.plotly_chart(fig_original, use_container_width=True)
|
||
|
||
# 变化检测按钮
|
||
if st.button("🔍 执行变化检测", type="primary"):
|
||
if self.model is None:
|
||
st.error("❌ 请先加载模型!")
|
||
else:
|
||
with st.spinner("正在进行变化检测..."):
|
||
pred_mask = self.predict_change(img1, img2)
|
||
|
||
if pred_mask is not None:
|
||
# 显示检测结果
|
||
st.subheader("🎯 变化检测结果")
|
||
fig_result = self.create_comparison_plot(img1, img2, pred_mask)
|
||
st.plotly_chart(fig_result, use_container_width=True)
|
||
|
||
# 统计信息
|
||
st.subheader("📈 变化统计")
|
||
stats = self.calculate_change_statistics(pred_mask)
|
||
|
||
col1, col2, col3, col4 = st.columns(4)
|
||
with col1:
|
||
st.metric("总像素数", stats["总像素数"])
|
||
with col2:
|
||
st.metric("变化像素数", stats["变化像素数"])
|
||
with col3:
|
||
st.metric("未变化像素数", stats["未变化像素数"])
|
||
with col4:
|
||
st.metric("变化比例", stats["变化比例"])
|
||
|
||
# 导出功能
|
||
st.subheader("💾 结果导出")
|
||
col1, col2 = st.columns(2)
|
||
|
||
with col1:
|
||
if st.button("导出可视化结果"):
|
||
result_img, contours = self.export_results(img1, img2, pred_mask)
|
||
if result_img is not None:
|
||
st.image(result_img, caption="变化检测可视化结果")
|
||
|
||
# 提供下载链接
|
||
buf = BytesIO()
|
||
result_img.save(buf, format='PNG')
|
||
buf.seek(0)
|
||
|
||
st.download_button(
|
||
label="下载结果图像",
|
||
data=buf.getvalue(),
|
||
file_name="change_detection_result.png",
|
||
mime="image/png"
|
||
)
|
||
|
||
with col2:
|
||
if st.button("生成地理坐标"):
|
||
st.info("💡 地理坐标导出需要输入影像包含地理参考信息")
|
||
# 这里可以添加shapefile导出功能
|
||
else:
|
||
st.error("❌ 变化检测失败!")
|
||
|
||
# 使用说明
|
||
with st.expander("📖 使用说明"):
|
||
st.markdown("""
|
||
### 使用步骤:
|
||
1. **加载模型**:在侧边栏输入模型路径并点击"加载模型"
|
||
2. **上传影像**:分别上传第一期和第二期影像
|
||
3. **执行检测**:点击"执行变化检测"按钮
|
||
4. **查看结果**:查看变化检测结果和统计信息
|
||
5. **导出结果**:下载可视化结果或地理坐标文件
|
||
|
||
### 支持格式:
|
||
- 图像格式:PNG, JPG, JPEG, TIF, TIFF
|
||
- 模型格式:PyTorch (.pth)
|
||
|
||
### 注意事项:
|
||
- 确保两期影像尺寸相近
|
||
- 模型路径必须正确
|
||
- 地理坐标导出需要影像包含地理参考信息
|
||
""")
|
||
|
||
if __name__ == "__main__":
|
||
app = VisualizationApp()
|
||
app.run() |