ai_project_v1/CropLand_CD_module/visualization_app.py

357 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import streamlit as st
import os
import cv2
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import tempfile
import shutil
from change_detection import ChangeDetection
from znzh_x import PredictionVisualizer
import argparse
import geopandas as gpd
from io import BytesIO
import base64
class VisualizationApp:
def __init__(self):
self.setup_page()
self.model = None
self.temp_dir = tempfile.mkdtemp()
def setup_page(self):
st.set_page_config(
page_title="变化检测可视化系统",
page_icon="🛰️",
layout="wide",
initial_sidebar_state="expanded"
)
st.title("🛰️ 变化检测可视化系统")
st.markdown("---")
def load_model(self, model_path):
"""加载预训练模型"""
try:
# 创建模拟的args对象
args = argparse.Namespace(
gpu_id="0",
img_size=256,
lr=0.0001,
n_class=2,
batch_size=1
)
self.model = ChangeDetection(args)
self.model.load_model(model_path)
return True
except Exception as e:
st.error(f"模型加载失败: {str(e)}")
return False
def preprocess_image(self, image, target_size=(256, 256)):
"""预处理图像"""
if isinstance(image, Image.Image):
image = np.array(image)
# 调整大小
image = cv2.resize(image, target_size)
# 归一化
image = image.astype(np.float32) / 255.0
# 转换为tensor
image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0)
return image
def predict_change(self, img1, img2):
"""执行变化检测预测"""
if self.model is None:
return None
try:
# 预处理图像
img1_tensor = self.preprocess_image(img1)
img2_tensor = self.preprocess_image(img2)
# 移动到设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
img1_tensor = img1_tensor.to(device)
img2_tensor = img2_tensor.to(device)
# 预测
self.model.model.eval()
with torch.no_grad():
output = self.model.model(img1_tensor, img2_tensor)
if isinstance(output, tuple):
output = output[0]
# 获取预测结果
pred = torch.argmax(output, 1).squeeze().cpu().numpy()
return pred
except Exception as e:
st.error(f"预测失败: {str(e)}")
return None
def create_comparison_plot(self, img1, img2, pred_mask=None):
"""创建对比图"""
if pred_mask is not None:
fig = make_subplots(
rows=2, cols=2,
subplot_titles=["第一期影像", "第二期影像", "变化检测结果", "叠加显示"],
specs=[[{"type": "image"}, {"type": "image"}],
[{"type": "image"}, {"type": "image"}]]
)
# 第一期影像
fig.add_trace(
go.Image(z=img1),
row=1, col=1
)
# 第二期影像
fig.add_trace(
go.Image(z=img2),
row=1, col=2
)
# 变化检测结果
change_mask_colored = np.zeros((*pred_mask.shape, 3), dtype=np.uint8)
change_mask_colored[pred_mask > 0] = [255, 0, 0] # 红色表示变化
fig.add_trace(
go.Image(z=change_mask_colored),
row=2, col=1
)
# 叠加显示
overlay = img2.copy()
if len(overlay.shape) == 3:
overlay[pred_mask > 0] = [255, 0, 0] # 在第二期影像上标红
fig.add_trace(
go.Image(z=overlay),
row=2, col=2
)
else:
fig = make_subplots(
rows=1, cols=2,
subplot_titles=["第一期影像", "第二期影像"],
specs=[[{"type": "image"}, {"type": "image"}]]
)
fig.add_trace(go.Image(z=img1), row=1, col=1)
fig.add_trace(go.Image(z=img2), row=1, col=2)
fig.update_layout(
height=800,
showlegend=False,
title_text="影像对比分析"
)
# 隐藏坐标轴
fig.update_xaxes(showticklabels=False)
fig.update_yaxes(showticklabels=False)
return fig
def calculate_change_statistics(self, pred_mask):
"""计算变化统计信息"""
total_pixels = pred_mask.size
changed_pixels = np.sum(pred_mask > 0)
unchanged_pixels = total_pixels - changed_pixels
change_ratio = (changed_pixels / total_pixels) * 100
return {
"总像素数": total_pixels,
"变化像素数": changed_pixels,
"未变化像素数": unchanged_pixels,
"变化比例": f"{change_ratio:.2f}%"
}
def export_results(self, img1, img2, pred_mask, base_name="result"):
"""导出结果"""
try:
# 保存图像到临时目录
img1_path = os.path.join(self.temp_dir, f"{base_name}_img1.png")
img2_path = os.path.join(self.temp_dir, f"{base_name}_img2.png")
mask_path = os.path.join(self.temp_dir, f"{base_name}_mask.png")
Image.fromarray(img1).save(img1_path)
Image.fromarray(img2).save(img2_path)
Image.fromarray((pred_mask * 255).astype(np.uint8)).save(mask_path)
# 使用PredictionVisualizer生成可视化结果
visualizer = PredictionVisualizer(img1_path, mask_path)
result_img, contours = visualizer.draw_contours_on_image()
# 转换为可下载的格式
result_pil = Image.fromarray(cv2.cvtColor(result_img, cv2.COLOR_BGR2RGB))
return result_pil, contours
except Exception as e:
st.error(f"导出失败: {str(e)}")
return None, None
def run(self):
"""运行主界面"""
# 侧边栏配置
with st.sidebar:
st.header("⚙️ 系统配置")
# 模型加载
st.subheader("模型设置")
model_path = st.text_input(
"模型路径",
value="/media/data0/HL/2025-7-6/best_model.pth",
help="输入预训练模型的路径"
)
if st.button("加载模型"):
if os.path.exists(model_path):
if self.load_model(model_path):
st.success("✅ 模型加载成功!")
else:
st.error("❌ 模型加载失败!")
else:
st.error("❌ 模型文件不存在!")
# 显示模型状态
if self.model is not None:
st.success("🟢 模型已就绪")
else:
st.warning("🟡 请先加载模型")
st.markdown("---")
# 参数设置
st.subheader("检测参数")
min_area = st.slider("最小变化区域(像素)", 10, 1000, 100)
thickness = st.slider("轮廓线粗细", 1, 10, 2)
# 主界面
col1, col2 = st.columns(2)
with col1:
st.subheader("📤 第一期影像")
uploaded_img1 = st.file_uploader(
"选择第一期影像",
type=['png', 'jpg', 'jpeg', 'tif', 'tiff'],
key="img1"
)
with col2:
st.subheader("📤 第二期影像")
uploaded_img2 = st.file_uploader(
"选择第二期影像",
type=['png', 'jpg', 'jpeg', 'tif', 'tiff'],
key="img2"
)
# 处理上传的图像
if uploaded_img1 and uploaded_img2:
# 读取图像
img1 = np.array(Image.open(uploaded_img1))
img2 = np.array(Image.open(uploaded_img2))
# 确保图像尺寸一致
if img1.shape != img2.shape:
st.warning("⚠️ 两期影像尺寸不一致,将自动调整")
target_shape = (min(img1.shape[0], img2.shape[0]),
min(img1.shape[1], img2.shape[1]))
img1 = cv2.resize(img1, (target_shape[1], target_shape[0]))
img2 = cv2.resize(img2, (target_shape[1], target_shape[0]))
# 显示原始图像对比
st.subheader("📊 影像对比")
fig_original = self.create_comparison_plot(img1, img2)
st.plotly_chart(fig_original, use_container_width=True)
# 变化检测按钮
if st.button("🔍 执行变化检测", type="primary"):
if self.model is None:
st.error("❌ 请先加载模型!")
else:
with st.spinner("正在进行变化检测..."):
pred_mask = self.predict_change(img1, img2)
if pred_mask is not None:
# 显示检测结果
st.subheader("🎯 变化检测结果")
fig_result = self.create_comparison_plot(img1, img2, pred_mask)
st.plotly_chart(fig_result, use_container_width=True)
# 统计信息
st.subheader("📈 变化统计")
stats = self.calculate_change_statistics(pred_mask)
col1, col2, col3, col4 = st.columns(4)
with col1:
st.metric("总像素数", stats["总像素数"])
with col2:
st.metric("变化像素数", stats["变化像素数"])
with col3:
st.metric("未变化像素数", stats["未变化像素数"])
with col4:
st.metric("变化比例", stats["变化比例"])
# 导出功能
st.subheader("💾 结果导出")
col1, col2 = st.columns(2)
with col1:
if st.button("导出可视化结果"):
result_img, contours = self.export_results(img1, img2, pred_mask)
if result_img is not None:
st.image(result_img, caption="变化检测可视化结果")
# 提供下载链接
buf = BytesIO()
result_img.save(buf, format='PNG')
buf.seek(0)
st.download_button(
label="下载结果图像",
data=buf.getvalue(),
file_name="change_detection_result.png",
mime="image/png"
)
with col2:
if st.button("生成地理坐标"):
st.info("💡 地理坐标导出需要输入影像包含地理参考信息")
# 这里可以添加shapefile导出功能
else:
st.error("❌ 变化检测失败!")
# 使用说明
with st.expander("📖 使用说明"):
st.markdown("""
### 使用步骤:
1. **加载模型**:在侧边栏输入模型路径并点击"加载模型"
2. **上传影像**:分别上传第一期和第二期影像
3. **执行检测**:点击"执行变化检测"按钮
4. **查看结果**:查看变化检测结果和统计信息
5. **导出结果**:下载可视化结果或地理坐标文件
### 支持格式:
- 图像格式PNG, JPG, JPEG, TIF, TIFF
- 模型格式PyTorch (.pth)
### 注意事项:
- 确保两期影像尺寸相近
- 模型路径必须正确
- 地理坐标导出需要影像包含地理参考信息
""")
if __name__ == "__main__":
app = VisualizationApp()
app.run()