Compare commits

...

20 Commits

Author SHA1 Message Date
63f240ac3a 针对缺少图传模块,引起的进程挂起问题,优化读流时长参数 2026-02-26 14:45:05 +08:00
liyubo
dd931f6231 两期对比接口,土方量计算接口返回信息调整 2026-02-26 11:00:33 +08:00
146872a4dd 优化拉流丢包的问题,增加宽松策略 2026-02-05 02:26:38 +08:00
89181007c2 优化拉流丢包的问题,增加宽松策略 2026-01-29 05:59:39 +08:00
liyubo
9a09c1e1cf 坡度坡向tif生成 2026-01-29 11:51:20 +08:00
c5eeb87488 增加忽略项目 2026-01-27 14:13:37 +08:00
ee8733a0ce Merge branch 'develop' of http://222.212.85.86:8222/bdzl2/ai_project_v1 into develop 2026-01-27 14:13:01 +08:00
0ce543572b 增加忽略项目 2026-01-27 11:54:08 +08:00
929c670add 增加忽略项目 2026-01-27 11:51:05 +08:00
1656f81fe3 增加忽略项目 2026-01-27 11:46:36 +08:00
dfb89c70a3 增加墙面裂缝魔心,修改目标识别接口bug 2026-01-27 11:44:18 +08:00
a2d3e2e24b 增加墙面裂缝魔心,修改目标识别接口bug 2026-01-27 11:18:05 +08:00
liyubo
0f44df8cec 坡度坡向合并tif生成接口 2026-01-19 10:42:21 +08:00
liyubo
eb6ce0de46 3dtiles地图数据预加载接口,terrain3d_analyzer_color移动到slope_aspect_img 2026-01-19 09:36:22 +08:00
liyubo
8d4db9b6df 坡度坡向api集成到yolo api 2026-01-14 16:15:21 +08:00
eedca6cd50 Merge branch 'develop' of http://222.212.85.86:8222/bdzl2/ai_project_v1 into develop 2026-01-14 07:34:21 +08:00
liyubo
5c865a4418 坡度坡向土方量计算 2026-01-14 11:37:35 +08:00
0e952115c8 ffmpeg 拉流更换为cv2,解决了rtmp拉流延迟3s的问题 2026-01-14 06:53:44 +08:00
fbcc505a88 修改类别字段 2026-01-13 03:04:09 +08:00
b899c4e9de 上传grpc 测试demo 2026-01-05 16:29:39 +08:00
38 changed files with 11380 additions and 7074 deletions

47
.gitignore vendored Normal file
View File

@ -0,0 +1,47 @@
# 忽略所有 .log 文件
*.log
# 忽略特定目录(如 node_modules/
node_modules/
# 忽略本地配置文件(但保留示例文件)
config.local.json
!config.example.json
# 忽略编译输出目录
dist/
build/
test
*test*
*.pyc
__pycache__/
misc.xml
profiles_settings.xml
Project_Default.xml
*test*.py
# 忽略 Python 缓存文件
*.pyo
*.pyd
*.mo
*.so
uvmodule.log
# IDE
.idea/
*.iml
.vscode/
# Temp files
*.tmp
*.swp

3
.idea/.gitignore generated vendored
View File

@ -1,3 +0,0 @@
# 默认忽略的文件
/shelf/
/workspace.xml

3
.idea/misc.xml generated
View File

@ -1,4 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<project version="4"> <project version="4">
<component name="Black">
<option name="sdkName" value="yolo_tensorrt" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="yolo_tensorrt" project-jdk-type="Python SDK" /> <component name="ProjectRootManager" version="2" project-jdk-name="yolo_tensorrt" project-jdk-type="Python SDK" />
</project> </project>

View File

@ -1,100 +0,0 @@
import os
import shutil
import cv2
import collections
from ultralytics import YOLO
from miniohelp import downFile, upload_file, parse_minio_url # 确保你有这些工具函数
from minio import Minio
def process_images(yolo_model, image_list, class_filter, input_folder, output_folder, minio_info):
# 初始化 MinIO 客户端# 用配置字典初始化 Minio 客户端对象
# 清洗 endpoint去掉 http:// 或 https:// 前缀
endpoint = minio_info["MinIOEndpoint"].replace("http://", "").replace("https://", "")
# 初始化 MinIO 客户端
minio = Minio(
endpoint=endpoint,
access_key=minio_info["MinIOAccessKey"],
secret_key=minio_info["MinIOSecretKey"],
secure=False
)
os.makedirs(input_folder, exist_ok=True)
os.makedirs(output_folder, exist_ok=True)
model = YOLO(yolo_model)
class_ids_filter = [int(cls) for cls in class_filter.split(",")] if class_filter else None
output_image_list = []
for item in image_list:
img_id = item["id"]
img_url = item["path"]
# 解析 MinIO 地址
if img_url.startswith("http"):
bucket_name, img_path = parse_minio_url(img_url)
else:
bucket_name, img_path = "default-bucket", img_url
try:
# 下载原图到本地
local_input_path = os.path.join(input_folder, os.path.basename(img_path))
downFile(minio, img_path, bucket_name, local_input_path)
# 读取图像
image = cv2.imread(local_input_path)
if image is None:
raise ValueError(f"无法读取图像: {local_input_path}")
# YOLO 检测
results = model.predict(image,
classes=class_ids_filter,
conf=0.5,
iou = 0.111,
show_labels = False,)
result = results[0]
# 统计类别数
class_counts = collections.Counter(result.boxes.cls.cpu().numpy().astype(int)) if result.boxes is not None else {}
filtered_class_counts = {k: v for k, v in class_counts.items() if k in class_ids_filter}
# 转换所有的 numpy.int64 为 Python 的 int 类型
detected_classes = [int(cls) for cls in filtered_class_counts.keys()]
detected_numbers = [int(num) for num in filtered_class_counts.values()]
aim = bool(detected_classes)
# 保存标注图像
annotated_image = result.plot(labels=False)
filename_no_ext, ext = os.path.splitext(os.path.basename(img_path))
output_filename = f"{filename_no_ext}_ai{ext}"
local_output_path = os.path.join(output_folder, output_filename)
cv2.imwrite(local_output_path, annotated_image)
# 上传标注图像到 MinIO
minio_path = upload_file(minio, local_output_path, bucket_name, os.path.dirname(img_path))
except Exception as e:
print(f"[错误] 处理失败 - {img_path},错误: {str(e)}")
detected_classes = []
detected_numbers = []
aim = False
output_filename = ""
minio_path = ""
output_image_list.append({
"id": img_id,
"minio_path":minio_path,
"aim": aim,
"class": detected_classes,
"number": detected_numbers
})
# 清理临时目录
shutil.rmtree(input_folder, ignore_errors=True)
shutil.rmtree(output_folder, ignore_errors=True)
return {
"status": "success",
"message": "Detection completed",
"data": output_image_list
}

View File

@ -1,505 +0,0 @@
from sanic import Sanic, json, Blueprint,response
from sanic.exceptions import Unauthorized, NotFound
from sanic.response import json as json_response
from sanic_cors import CORS
from datetime import datetime
import logging
import uuid
import os
import asyncio
from minio import Minio
from ai_image import process_images # 你实现的图片处理函数
from queue import Queue
import gdal2tiles as gdal2tiles
from map_find import map_process_images
from yolo_train import auto_train
from map_cut import process_tiling
from cv_video_counter import start_video_session,switch_model_session,stop_video_session,stream_sessions
import torch
from yolo_photo import map_process_images_with_progress # 引入你的处理函数
# 日志配置
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
###################################################################################验证中间件和管理件##############################################################################################
async def token_and_resource_check(request):
# --- Token 验证 ---
token = request.headers.get('X-API-Token')
expected_token = request.app.config.get("VALID_TOKEN")
if not token or token != expected_token:
logger.warning(f"Unauthorized request with token: {token}")
raise Unauthorized("Invalid token")
# --- GPU 使用率检查 ---
try:
if torch.cuda.is_available():
num_gpus = torch.cuda.device_count()
max_usage_ratio = request.app.config.get("MAX_GPU_USAGE", 0.9) # 默认90%
for i in range(num_gpus):
used = torch.cuda.memory_reserved(i)
total = torch.cuda.max_memory_reserved(i)
ratio = used / total if total else 0
logger.info(f"GPU {i} Usage: {ratio:.2%}")
if ratio > max_usage_ratio:
logger.warning(f"GPU {i} usage too high: {ratio:.2%}")
return json_response({
"status": "error",
"message": f"GPU resource busy (GPU {i} at {ratio:.2%}). Try later."
}, status=503)
except Exception as e:
logger.error(f"GPU check failed: {e}")
return None # 允许请求继续
##################################################################################################################################################################################################
#创建Sanic应用
app = Sanic("ai_Service_v2")
CORS(app) # 允许跨域请求
task_progress = {}
@app.middleware("request")
async def global_middleware(request):
result = await token_and_resource_check(request)
if result:
return result
# 配置Token和最大GPU使用率
app.config.update({
"VALID_TOKEN": "Beidou_b8609e96-bfec-4485-8c64-6d4f662ee44a",
"MAX_GPU_USAGE": 0.9
})
######################################################################地图切割相关的API########################################################################################################
#创建地图的蓝图
map_tile_blueprint = Blueprint('map', url_prefix='/map/')
app.blueprint(map_tile_blueprint)
@map_tile_blueprint.post("/tile")
async def map_tile_api(request):
try:
# 1. 检查请求体
if not request.json:
return json_response(
{"status": "error", "message": "Request body is required"},
status=400
)
# 2. 解析必要字段
tile_data = request.json
tif_url = tile_data.get("tif_url")
prj_url = tile_data.get("prj_url")
if not tif_url or not prj_url:
return json_response(
{"status": "error", "message": "Both tif_url and prj_url are required"},
status=400
)
# 3. 处理业务逻辑(直接调用协程函数,不要用 asyncio.run
zoom_levels = tile_data.get("zoom_levels", "1-18")
try:
# 假设 process_tiling 是一个协程函数async def
result = await process_tiling(tif_url, prj_url, zoom_levels)
# 如果 process_tiling 是普通函数,用 asyncio.to_thread 包装
# result = await asyncio.to_thread(process_tiling, tif_url, prj_url, zoom_levels)
return json_response({
"status": "success",
"data": result
})
except Exception as processing_error:
logger.error(f"Processing failed: {str(processing_error)}", exc_info=True)
return json_response(
{"status": "error", "message": f"Processing error: {str(processing_error)}"},
status=500
)
except Exception as e:
logger.error(f"Unexpected error: {str(e)}", exc_info=True)
return json_response(
{"status": "error", "message": str(e)}, # 直接返回字符串,不要用集合
status=500
)
#语义识别
@map_tile_blueprint.post("/uav")
async def process_handler(request):
"""
接口/map/uav
输入 JSON:
{
"urls": [
"http://example.com/img1.jpg",
"http://example.com/img2.jpg"
],
"yaml_name": "112.44.103.230",
"bucket_name": "300bdf2b-a150-406e-be63-d28bd29b409f",
"bucket_directory": "2025/seg"
"model_path": "deeplabv3plus_best.pth"
}
输出 JSON:
{
"code": 200,
"msg": "success",
"data": [
"http://minio.example.com/uav-results/2025/seg/result1.png",
"http://minio.example.com/uav-results/2025/seg/result2.png"
]
}
"""
try:
body = request.json
urls = body.get("urls", [])
yaml_name = body.get("yaml_name")
bucket_name = body.get("bucket_name")
bucket_directory = body.get("bucket_directory")
model_path = os.path.join("map", "checkpoints", body.get("model_path"))
# 校验参数
if not urls or not isinstance(urls, list):
return json({"code": 400, "msg": "Missing or invalid 'urls'"})
if not all([yaml_name, bucket_name, bucket_directory]):
return json({"code": 400, "msg": "Missing required parameters"})
# 调用图像处理函数
result = map_process_images(urls, yaml_name, bucket_name, bucket_directory,model_path)
return json(result)
except Exception as e:
return json({"code": 500, "msg": f"Server error: {str(e)}"})
######################################################################yolo相关的API########################################################################################################
#创建yolo的蓝图
yolo_tile_blueprint = Blueprint('yolo', url_prefix='/yolo/')
app.blueprint(yolo_tile_blueprint)
# YOLO URL APT
# 存储任务进度和结果(内存示例,可用 Redis 或 DB 持久化)
@yolo_tile_blueprint.post("/process_images")
async def process_images(request):
"""
{
"urls": [
"http://example.com/image1.jpg",
"http://example.com/image2.jpg",
"http://example.com/image3.jpg"
],
"yaml_name": "your_minio_config",
"bucket_name": "my-bucket",
"bucket_directory": "2025/uav-results",
"model_path": "deeplabv3plus_best.pth"
}
"""
data = request.json
urls = data.get("urls")
yaml_name = data.get("yaml_name")
bucket_name = data.get("bucket_name")
bucket_directory = data.get("bucket_directory")
uav_model_path = data.get("uav_model_path")
if not urls or not yaml_name or not bucket_name or not uav_model_path:
return response.json({"code": 400, "msg": "Missing parameters"}, status=400)
task_id = str(uuid.uuid4())
task_progress[task_id] = {"status": "pending", "progress": 0, "result": None}
# 启动后台任务
asyncio.create_task(run_image_processing(task_id, urls, yaml_name, bucket_name, bucket_directory, uav_model_path))
return response.json({"code": 200, "msg": "Task started", "task_id": task_id})
@yolo_tile_blueprint.get("/task_status/<task_id>")
async def task_status(request, task_id):
progress = task_progress.get(task_id)
if not progress:
return response.json({"code": 404, "msg": "Task not found"}, status=404)
return response.json({"code": 200, "msg": "Task status", "data": progress})
async def run_image_processing(task_id, urls, yaml_name, bucket_name, bucket_directory, uav_model_path):
try:
task_progress[task_id]["status"] = "running"
task_progress[task_id]["progress"] = 10 # 开始进度
# 下载、推理、上传阶段分别更新进度
def progress_callback(stage, percent):
task_progress[task_id]["status"] = stage
task_progress[task_id]["progress"] = percent
result = await asyncio.to_thread(
map_process_images_with_progress,
urls, yaml_name, bucket_name, bucket_directory, uav_model_path, progress_callback
)
task_progress[task_id]["status"] = "completed"
task_progress[task_id]["progress"] = 100
task_progress[task_id]["result"] = result
except Exception as e:
task_progress[task_id]["status"] = "failed"
task_progress[task_id]["progress"] = 100
task_progress[task_id]["result"] = str(e)
# YOLO检测API
@yolo_tile_blueprint.post("/picture")
async def yolo_detect_api(request):
try:
detect_data = request.json
# 解析必要字段
image_list = detect_data.get("image_list")
yolo_model = detect_data.get("yolo_model", "best.pt")
class_filter = detect_data.get("class", None)
minio_info = detect_data.get("minio", None)
if not image_list:
return json_response({"status": "error", "message": "image_list is required"}, status=400)
if not minio_info:
return json_response({"status": "error", "message": "MinIO information is required"}, status=400)
# 创建临时文件夹
input_folder = f"./temp_input_{str(uuid.uuid4())}"
output_folder = f"./temp_output_{str(uuid.uuid4())}"
# 执行图像处理
result = await asyncio.to_thread(
process_images,
yolo_model=yolo_model,
image_list=image_list,
class_filter=class_filter,
input_folder=input_folder,
output_folder=output_folder,
minio_info=minio_info
)
# 返回处理结果
return json_response(result)
except Exception as e:
logger.error(f"Error occurred while processing request: {str(e)}", exc_info=True)
return json_response({
"status": "error",
"message": f"Internal server error: {str(e)}"
}, status=500)
# YOLO自动训练
@yolo_tile_blueprint.post("/train")
async def yolo_train_api(request):
"""
自动训练模型
输入 JSON:
{
"db_host": str,
"db_database": str,
"db_user": str,
"db_password": str,
"db_port": int,
"model_id": int,
"img_path": str,
"label_path": str,
"new_path": str,
"split_list": List[float],
"class_names": Optional[List[str]],
"project_name": str
}
输出 JSON:
{
"base_metrics": Dict[str, float],
"best_model_path": str,
"final_metrics": Dict[str, float]
}
"""
try:
# 修改为直接访问 request.json 而不是调用它
data = request.json
if not data:
return json_response({"status": "error", "message": "data is required"}, status=400)
# 执行图像处理
result = await asyncio.to_thread(
auto_train,
data
)
# 返回处理结果
return json_response(result)
except Exception as e:
logger.error(f"Error occurred while processing request: {str(e)}", exc_info=True)
return json_response({
"status": "error",
"message": f"Internal server error: {str(e)}"
}, status=500)
###########################################################################################视频流相关的API#######################################################################################################
#创建视频流的蓝图
stream_tile_blueprint = Blueprint('stream', url_prefix='/stream_test/')
app.blueprint(stream_tile_blueprint)
#
# 任务管理器
class StreamTaskManager:
def __init__(self):
self.active_tasks = {}
self.task_status = {}
self.task_timestamps = {}
self.task_queue = Queue(maxsize=10)
def add_task(self, task_id: str, task_info: dict) -> None:
if self.task_queue.full():
oldest_task_id = self.task_queue.get()
self.remove_task(oldest_task_id)
stop_video_session(self.active_tasks[oldest_task_id]["session_id"])
self.active_tasks[task_id] = task_info
self.task_status[task_id] = "running"
self.task_timestamps[task_id] = datetime.now()
self.task_queue.put(task_id)
logger.info(f"Task {task_id} started")
def remove_task(self, task_id: str) -> None:
if task_id in self.active_tasks:
del self.active_tasks[task_id]
del self.task_status[task_id]
del self.task_timestamps[task_id]
logger.info(f"Task {task_id} removed")
def get_task_info(self, task_id: str) -> dict:
if task_id not in self.active_tasks:
raise NotFound("Task not found")
return {
"task_info": self.active_tasks[task_id],
"status": self.task_status[task_id],
"start_time": self.task_timestamps[task_id].isoformat()
}
task_manager = StreamTaskManager()
# ---------- API Endpoints ----------
@stream_tile_blueprint.post("/start")
async def api_start(request):
"""
启动视频流会话
输入 JSON:
{
"video_path": str,
"output_url": str,
"model_path": str,
"cls": List[int],
"confidence": float,
"cls2": Optional[List[int]]
"push": bool
}
输出 JSON:
{
"session_id": str,
"task_id": str,
"message": "started"
}
"""
data = request.json
task_id = str(uuid.uuid4())
# 启动视频处理会话,并传入 task_id
session_id = start_video_session(
video_path = data.get("video_path"),
output_url = data.get("output_url"),
model_path = data.get("model_path"),
cls = data.get("cls"),
confidence = data.get("confidence", 0.5),
cls2 = data.get("cls2", []),
push = data.get("push", False),
)
# 注册到任务管理器
task_manager.add_task(task_id, {
"session_id": session_id,
"video_path": data.get("video_path"),
"output_url": data.get("output_url"),
"model_path": data.get("model_path"),
"class_filter": data.get("cls", []),
"push": data.get("push", False),
"start_time": datetime.now().isoformat()
})
return json({"session_id": session_id, "task_id": task_id, "message": "started"})
@stream_tile_blueprint.post("/stop")
async def api_stop(request):
"""
停止指定会话
输入 JSON: { "session_id": str }
输出 JSON: { "session_id": str, "message": "stopped" }
"""
session_id = request.json.get("session_id")
stop_video_session(session_id)
# 同步移除任务
for tid, info in list(task_manager.active_tasks.items()):
if info.get("session_id") == session_id:
task_manager.remove_task(tid)
break
return json({"session_id": session_id, "message": "stopped"})
@stream_tile_blueprint.post("/switch_model")
async def api_switch_model(request):
"""
切换会话模型
输入 JSON: { "session_id": str, "new_model_path": str }
输出 JSON: { "session_id": str, "new_model_path": str, "message": "model switched" }
"""
data = request.json
session_id = data.get("session_id")
new_model = data.get("new_model_path")
switch_model_session(session_id, new_model)
return json({"session_id": session_id, "new_model_path": new_model, "message": "model switched"})
@stream_tile_blueprint.get("/sessions")
async def api_list_sessions(request):
"""
列出所有当前会话
输出 JSON: { "sessions": [{"session_id": str, "status": "running"}, ...] }
"""
sessions = [
{"session_id": sid, "status": "running"}
for sid in stream_sessions.keys()
]
return json({"sessions": sessions})
# 统一的任务查询接口(含视频流)
@stream_tile_blueprint.get("/tasks")
async def api_list_tasks(request):
"""
列出所有任务含状态开始时间详情
"""
tasks = []
for tid in task_manager.active_tasks:
info = task_manager.get_task_info(tid)
tasks.append({"task_id": tid, **info})
return json({"tasks": tasks})
##################################################################################################################################################################################################
if __name__ == '__main__':
app.run(host="0.0.0.0", port=12366, debug=True,workers=1)

60
b3dm/b3dm_api.py Normal file
View File

@ -0,0 +1,60 @@
from sanic import Sanic, Request, json
from sanic_cors import CORS
import logging
import time
from earthwork_api import earthwork_bp
from terrain_api import terrain_bp
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 创建Sanic应用
app = Sanic("TerrainAnalysisAPI")
# 显式注册蓝图
app.blueprint(earthwork_bp)
app.blueprint(terrain_bp)
CORS(app, automatic_options=True)
# 中间件:请求计时
@app.middleware("request")
async def add_start_time(request: Request):
request.ctx.start_time = time.time()
@app.middleware("response")
async def add_response_time(request: Request, response):
if hasattr(request.ctx, "start_time"):
process_time = (time.time() - request.ctx.start_time) * 1000
response.headers["X-Process-Time"] = f"{process_time:.2f}ms"
@app.get("/api/v1/health")
async def health_check(request: Request):
"""健康检查"""
return json({
"status": "healthy",
"timestamp": time.time(),
"service": "terrain-analysis-api",
"version": "1.0.0"
})
# 错误处理
@app.exception(Exception)
async def handle_exception(request: Request, exception):
"""全局异常处理"""
logger.error(f"未处理的异常: {exception}")
return json({
"error": "服务器内部错误",
"message": str(exception) if app.debug else "请稍后重试",
"success": False
}, status=500)
if __name__ == "__main__":
# 启动服务器
app.run(
host="0.0.0.0",
port=8000,
debug=True, # 生产环境设为False
access_log=True,
auto_reload=True
)

View File

@ -0,0 +1,542 @@
from minio import Minio
from minio.error import S3Error
import json
import os
import numpy as np
from urllib.parse import urlparse
import hashlib
import time
import re
import pickle
from datetime import datetime
class MinIO3DTilesManager:
def __init__(self, endpoint_url, access_key, secret_key, secure=False,
mapping_file="minio_path_mapping.pkl"):
"""
初始化MinIO客户端
Args:
endpoint_url: MinIO服务地址 (: 222.212.85.86:9001)
access_key: 访问密钥
secret_key: 秘密密钥
secure: 是否使用HTTPS
mapping_file: 路径映射文件名
"""
if endpoint_url.startswith('http://'):
endpoint_url = endpoint_url.replace('http://', '')
elif endpoint_url.startswith('https://'):
endpoint_url = endpoint_url.replace('https://', '')
secure = True
self.endpoint_url = endpoint_url
self.access_key = access_key
self.secret_key = secret_key
self.minio_client = Minio(
endpoint_url,
access_key=access_key,
secret_key=secret_key,
secure=secure
)
# 获取脚本所在目录
self.script_dir = os.path.dirname(os.path.abspath(__file__))
# 映射文件路径
self.mapping_file = os.path.join(self.script_dir, mapping_file)
# 加载现有的路径映射
self.path_mapping = self.load_path_mapping()
def load_path_mapping(self):
"""加载路径映射数据"""
if os.path.exists(self.mapping_file):
try:
with open(self.mapping_file, 'rb') as f:
mapping = pickle.load(f)
return mapping
except Exception as e:
return {}
else:
return {}
def save_path_mapping(self):
"""保存路径映射数据"""
try:
with open(self.mapping_file, 'wb') as f:
pickle.dump(self.path_mapping, f)
return True
except Exception as e:
return False
def get_cache_key(self, tileset_url, save_dir=None):
"""生成缓存键"""
# 基于URL和保存目录生成缓存键
cache_data = f"{tileset_url}|{save_dir}"
return hashlib.md5(cache_data.encode()).hexdigest()
def get_cached_tileset_info(self, tileset_url, save_dir=None):
"""获取缓存的tileset信息"""
cache_key = self.get_cache_key(tileset_url, save_dir)
# 检查缓存映射中是否有这个tileset
for file_id, info in self.path_mapping.items():
if info.get('cache_key') == cache_key and info.get('is_tileset_root'):
# 检查入口文件是否存在
local_path = info.get('local_path')
if local_path and os.path.exists(local_path):
return local_path
return None
def update_tileset_cache(self, tileset_url, save_dir, local_path):
"""更新tileset缓存信息"""
cache_key = self.get_cache_key(tileset_url, save_dir)
# 将tileset根文件标记为缓存
entry_bucket, entry_path = self.parse_minio_url(tileset_url)
file_id = f"{entry_bucket}/{entry_path}"
if file_id in self.path_mapping:
self.path_mapping[file_id]['cache_key'] = cache_key
self.path_mapping[file_id]['is_tileset_root'] = True
self.path_mapping[file_id]['tileset_url'] = tileset_url
self.path_mapping[file_id]['save_dir'] = save_dir
self.path_mapping[file_id]['cache_time'] = datetime.now().isoformat()
def download_full_tileset(self, tileset_url, save_dir=None, region_filter=None, use_cache=True):
"""
下载完整的3D Tiles数据集支持缓存功能
Args:
tileset_url: MinIO上的tileset.json URL
save_dir: 本地保存目录
region_filter: 区域过滤器
use_cache: 是否使用缓存
Returns:
tuple: (success, result)
- success: True/False
- result: 如果success=True且use_cache=True返回本地路径否则返回True/False
"""
if save_dir is None:
save_dir = os.path.join(self.script_dir, "data_3dtiles")
# 清理保存目录名称
save_dir = self.clean_file_path(save_dir)
# 检查缓存:只需检查入口文件是否存在
if use_cache:
cached_path = self.get_cached_tileset_info(tileset_url, save_dir)
if cached_path:
# 入口文件存在,默认缓存完备
return True, cached_path
# 解析URL
entry_bucket, entry_path = self.parse_minio_url(tileset_url)
if not entry_bucket or not entry_path:
return False, "无法解析URL"
entry_dir = os.path.dirname(entry_path)
# 创建保存目录
os.makedirs(save_dir, exist_ok=True)
visited = set()
# 下载入口文件
entry_local_path = self.get_local_path(entry_bucket, entry_path, save_dir)
success, result = self.download_file(entry_bucket, entry_path, entry_local_path)
if not success:
return False, f"入口文件下载失败: {result}"
entry_id = f"{entry_bucket}/{entry_path}"
visited.add(entry_id)
# 加载tileset数据
tileset_data = self.load_json_from_minio(entry_bucket, entry_path)
if not tileset_data or "root" not in tileset_data:
return False, "无效的tileset.json文件"
# 遍历下载所有文件
self.traverse_and_download_tileset(
tileset_data["root"],
entry_bucket,
entry_dir,
entry_bucket,
entry_dir,
save_dir,
region_filter,
None,
visited
)
# 更新缓存信息
self.update_tileset_cache(tileset_url, save_dir, entry_local_path)
# 保存路径映射
self.save_path_mapping()
if use_cache:
return True, entry_local_path
else:
return True, True
def get_tileset_local_path(self, tileset_url, save_dir=None):
"""
获取已缓存的tileset本地路径
Args:
tileset_url: tileset的URL
save_dir: 保存目录
Returns:
str: 本地路径如果未缓存则返回None
"""
if save_dir is None:
save_dir = os.path.join(self.script_dir, "data_3dtiles")
return self.get_cached_tileset_info(tileset_url, save_dir)
def clear_tileset_cache(self, tileset_url=None, save_dir=None):
"""
清除tileset缓存
Args:
tileset_url: 指定要清除的tileset URL如果为None则清除所有
save_dir: 保存目录
Returns:
bool: 成功/失败
"""
try:
if tileset_url:
# 清除指定tileset的缓存
cache_key = self.get_cache_key(tileset_url, save_dir)
# 找出所有相关的缓存条目
to_remove = []
for file_id, info in self.path_mapping.items():
if info.get('cache_key') == cache_key:
to_remove.append(file_id)
# 删除这些条目
for file_id in to_remove:
del self.path_mapping[file_id]
print(f"已清除tileset缓存: {tileset_url}")
else:
# 清除所有缓存
self.path_mapping = {}
if os.path.exists(self.mapping_file):
os.remove(self.mapping_file)
print("已清除所有缓存")
return True
except Exception as e:
return False
# 以下是原有的辅助方法
def clean_filename(self, filename):
"""清理文件名中的特殊字符"""
if not filename:
return ""
cleaned = re.sub(r'[<>:"/\\|?*\x00-\x1F]', '_', filename)
cleaned = re.sub(r'_+', '_', cleaned)
cleaned = cleaned.strip(' _')
return cleaned
def parse_minio_url(self, url):
"""解析MinIO URL"""
if url.startswith('http://') or url.startswith('https://'):
parsed = urlparse(url)
path = parsed.path.lstrip('/')
parts = path.split('/', 1)
if len(parts) == 2:
bucket, key = parts
else:
bucket = parts[0]
key = ""
return bucket, key
else:
parts = url.split('/', 1)
if len(parts) == 2:
bucket, key = parts
else:
bucket = parts[0]
key = ""
return bucket, key
def download_file(self, bucket_name, object_name, file_path):
"""从MinIO下载文件"""
try:
os.makedirs(os.path.dirname(file_path), exist_ok=True)
# 清理文件名
clean_file_path = self.clean_file_path(file_path)
# 检查是否已下载
file_id = f"{bucket_name}/{object_name}"
if file_id in self.path_mapping:
mapped_path = self.path_mapping[file_id]['local_path']
if os.path.exists(mapped_path):
return True, mapped_path
# 下载文件
self.minio_client.fget_object(
bucket_name,
object_name,
clean_file_path
)
# 更新路径映射
self.path_mapping[file_id] = {
'local_path': clean_file_path,
'bucket': bucket_name,
'object': object_name,
'download_time': datetime.now().isoformat(),
'size': os.path.getsize(clean_file_path)
}
return True, clean_file_path
except S3Error as e:
return False, str(e)
except Exception as e:
return False, str(e)
def clean_file_path(self, file_path):
"""清理文件路径中的所有特殊字符"""
dir_name = os.path.dirname(file_path)
file_name = os.path.basename(file_path)
if dir_name:
dir_parts = dir_name.split(os.sep)
cleaned_parts = []
for part in dir_parts:
cleaned_part = self.clean_filename(part)
if cleaned_part:
cleaned_parts.append(cleaned_part)
cleaned_dir = os.sep.join(cleaned_parts)
else:
cleaned_dir = ""
cleaned_file = self.clean_filename(file_name)
if cleaned_dir:
cleaned_path = os.path.join(cleaned_dir, cleaned_file)
else:
cleaned_path = cleaned_file
return cleaned_path
def load_json_from_minio(self, bucket_name, object_name):
"""从MinIO加载JSON文件"""
try:
self.minio_client.stat_object(bucket_name, object_name)
response = self.minio_client.get_object(bucket_name, object_name)
content = response.read().decode('utf-8')
response.close()
response.release_conn()
return json.loads(content)
except S3Error as e:
return None
except Exception as e:
return None
def get_local_path(self, bucket_name, object_name, save_dir):
"""生成保持目录结构的本地路径"""
clean_bucket = self.clean_filename(bucket_name)
path_parts = object_name.split('/')
cleaned_parts = []
for part in path_parts:
cleaned_part = self.clean_filename(part)
if cleaned_part:
cleaned_parts.append(cleaned_part)
if cleaned_parts:
cleaned_relative = '/'.join(cleaned_parts)
local_path = os.path.join(save_dir, clean_bucket, cleaned_relative)
else:
local_path = os.path.join(save_dir, clean_bucket)
return os.path.normpath(local_path)
def traverse_and_download_tileset(self, tile_obj, current_bucket, current_dir,
base_bucket, base_dir, save_dir,
region_filter=None, parent_transform=None,
visited=None):
"""递归遍历并下载3D Tiles文件"""
if visited is None:
visited = set()
current_transform = parent_transform
if "transform" in tile_obj:
tile_mat = tile_obj["transform"]
if current_transform is None:
current_transform = tile_mat
else:
mat1 = np.array(current_transform).reshape(4, 4)
mat2 = np.array(tile_mat).reshape(4, 4)
combined_mat = np.dot(mat1, mat2).flatten().tolist()
current_transform = combined_mat
skip_current_tile = False
if region_filter and "boundingVolume" in tile_obj:
if not region_filter.check_tile_bounding_volume(tile_obj["boundingVolume"]):
skip_current_tile = True
if not skip_current_tile and "content" in tile_obj and "uri" in tile_obj["content"]:
tile_uri = tile_obj["content"]["uri"]
file_bucket = current_bucket
file_path = ""
if tile_uri.startswith('http://') or tile_uri.startswith('https://'):
parsed_bucket, parsed_path = self.parse_minio_url(tile_uri)
if parsed_bucket:
file_bucket = parsed_bucket
file_path = parsed_path
else:
if current_dir:
file_path = os.path.join(current_dir, tile_uri).replace('\\', '/')
else:
file_path = tile_uri
file_path = file_path.lstrip('/')
file_id = f"{file_bucket}/{file_path}"
if file_id not in visited:
print(f"下载文件:{file_id}")
visited.add(file_id)
local_path = self.get_local_path(file_bucket, file_path, save_dir)
self.download_file(file_bucket, file_path, local_path)
if file_path.lower().endswith('.json'):
sub_tileset = self.load_json_from_minio(file_bucket, file_path)
if sub_tileset and "root" in sub_tileset:
sub_dir = os.path.dirname(file_path) if file_path else ""
self.traverse_and_download_tileset(
sub_tileset["root"],
file_bucket,
sub_dir,
base_bucket,
base_dir,
save_dir,
region_filter,
current_transform,
visited
)
if "children" in tile_obj:
for child_tile in tile_obj["children"]:
self.traverse_and_download_tileset(
child_tile,
current_bucket,
current_dir,
base_bucket,
base_dir,
save_dir,
region_filter,
current_transform,
visited
)
def upload_file(self, bucket_name, object_name, file_path):
"""上传文件到MinIO"""
try:
if not os.path.exists(file_path):
return False, f"文件不存在: {file_path}"
file_size = os.path.getsize(file_path)
self.minio_client.fput_object(bucket_name, object_name, file_path)
return True, f"{bucket_name}/{object_name}"
except S3Error as e:
return False, f"MinIO上传错误: {e}"
except Exception as e:
return False, f"上传失败: {str(e)}"
def upload_directory(self, bucket_name, local_dir, remote_prefix=""):
"""上传目录到MinIO"""
if not os.path.exists(local_dir):
return [], [f"目录不存在: {local_dir}"]
uploaded_files = []
failed_files = []
for root, dirs, files in os.walk(local_dir):
for file in files:
local_path = os.path.join(root, file)
rel_path = os.path.relpath(local_path, local_dir)
if remote_prefix:
remote_path = os.path.join(remote_prefix, rel_path).replace('\\', '/')
else:
remote_path = rel_path.replace('\\', '/')
success, message = self.upload_file(bucket_name, remote_path, local_path)
if success:
uploaded_files.append(remote_path)
else:
failed_files.append((remote_path, message))
return uploaded_files, failed_files
def check_and_create_bucket(self, bucket_name):
"""检查并创建bucket"""
try:
if not self.minio_client.bucket_exists(bucket_name):
self.minio_client.make_bucket(bucket_name)
return True, f"创建bucket: {bucket_name}"
return True, f"bucket已存在: {bucket_name}"
except S3Error as e:
return False, f"创建bucket失败: {e}"
# 使用示例
if __name__ == "__main__":
# 配置参数
ENDPOINT_URL = "222.212.85.86:9000"
ACCESS_KEY = "WuRenJi"
SECRET_KEY = "WRJ@2024"
# 初始化管理器
manager = MinIO3DTilesManager(
endpoint_url=ENDPOINT_URL,
access_key=ACCESS_KEY,
secret_key=SECRET_KEY,
secure=False
)
# 使用缓存下载tileset
tileset_url = "http://222.212.85.86:9000/300bdf2b-a150-406e-be63-d28bd29b409f/model/石棉0908/terra_b3dms/tileset.json"
# 第一次下载(会下载到本地)
success, result = manager.download_full_tileset(tileset_url, use_cache=True)
if success:
print(f"下载成功,本地路径: {result}")
# 第二次下载相同URL直接从缓存返回
success, result = manager.download_full_tileset(tileset_url, use_cache=True)
if success:
print(f"从缓存获取,本地路径: {result}")
# 强制重新下载(忽略缓存)
success, result = manager.download_full_tileset(tileset_url, use_cache=False)
if success:
print("强制重新下载成功")
# 获取缓存的本地路径
local_path = manager.get_tileset_local_path(tileset_url)
if local_path:
print(f"缓存的本地路径: {local_path}")

1367
b3dm/data_3dtiles_to_dem.py Normal file

File diff suppressed because it is too large Load Diff

530
b3dm/earthwork_api.py Normal file
View File

@ -0,0 +1,530 @@
# pip install fastapi uvicorn pdal pyvista numpy
from sanic import Blueprint, Request, json
from sanic.response import text
from typing import Dict, Any
import logging
from b3dm.earthwork_calculator_point_cloud import EarthworkCalculatorPointCloud
# 导入计算模块
from b3dm.earthwork_calculator_3d_tiles import EarthworkCalculator3dTiles, AlgorithmType, EarthworkResult3dTiles
from b3dm.tileset_data_source import TilesetDataSource
earthwork_bp = Blueprint("earthwork", url_prefix="")
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 初始化函数
def init_app(url, type = "3dtiles"):
"""初始化应用"""
data_source = None
calculator_3d_tiles = None
calculator_point_cloud = None
try:
# 初始化数据源
data_source = TilesetDataSource(url)
data_source.dowload_map_data(url)
if type == "3dtiles" :
# 初始化计算器-3dTiles
calculator_3d_tiles = EarthworkCalculator3dTiles(data_source)
elif type == "pointcloud" :
# 初始化计算器-点云
calculator_point_cloud = EarthworkCalculatorPointCloud(data_source.tileset_path)
else :
logger.info(f"不支持的3d地图数据格式:{type}")
raise
logger.info("土方量计算器初始化完成")
return {
"data_source":data_source,
"calculator_3d_tiles":calculator_3d_tiles,
"calculator_point_cloud":calculator_point_cloud
}
except ImportError as e:
logger.error(f"依赖库缺失: {str(e)}")
raise
except Exception as e:
logger.error(f"初始化失败: {str(e)}")
raise
# 土方量计算接口-3dTiles
@earthwork_bp.post("/api/v1/calc/earthwork3dTiles")
async def calc_earthwork(request: Request):
"""
土方量计算接口
请求参数示例:
{
"polygonCoords": [
[
115.70440468338526,
30.77363140345639
],
[
115.70443054007985,
30.773510462589584
],
[
115.70459702429197,
30.77360789911405
]
],
"designElevation": 100,
"algorithm": "tin",
"resolution": 1,
"crs": "EPSG:4326",
"interpolationMethod": "linear",
"url": "http://222.212.85.86:9000/300bdf2b-a150-406e-be63-d28bd29b409f/model/hbgldk/yzk/20260113/3D/terra_b3dms/tileset.json"
}
"""
try:
# 1. 接收前端传参
data = request.json
if not data:
return _error_response("请求参数不能为空", 400)
# 2. 提取参数
polygon_coords = data.get("polygonCoords")
design_elevation = data.get("designElevation")
url = data.get("url")
if not polygon_coords:
return _error_response("多边形坐标不能为空", 400)
if design_elevation is None:
return _error_response("设计高程不能为空", 400)
if url is None:
return _error_response("地图不能为空", 400)
# 3. 可选参数
algorithm = data.get("algorithm", "tin")
resolution = data.get("resolution", 1.0)
crs = data.get("crs", "EPSG:4326")
interpolation_method = data.get("interpolationMethod", "linear")
# 4. 参数验证
if len(polygon_coords) < 3:
return _error_response("多边形至少需要3个点", 400)
# 检查多边形是否闭合,如不闭合则自动闭合
if polygon_coords[0] != polygon_coords[-1]:
polygon_coords.append(polygon_coords[0])
# 算法验证
if algorithm not in ["grid", "tin", "prism"]:
return _error_response("算法必须是 grid, tin 或 prism", 400)
# 分辨率验证
if resolution <= 0 or resolution > 100:
return _error_response("分辨率必须在0-100米之间", 400)
# 5. 确保计算器已初始化
app_info = init_app(url)
calculator_3d_tiles = app_info.get("calculator_3d_tiles")
# 6. 执行计算
algorithm_type = AlgorithmType(algorithm)
result = await calculator_3d_tiles.calculate(
polygon_coords=polygon_coords,
design_elevation=design_elevation,
algorithm=algorithm_type,
resolution=resolution,
target_crs=crs,
interpolation_method=interpolation_method
)
# 7. 返回成功响应
res_dict = result.to_dict()
res_dict["calculation_details"] = None
res_dict["elevation_statistics"] = None
res_dict["volume_distribution"] = None
return _success_response(res_dict)
except ValueError as e:
logger.warning(f"参数验证失败: {str(e)}")
return _error_response(f"参数错误: {str(e)}", 400)
except Exception as e:
logger.error(f"计算失败: {str(e)}")
return _error_response(f"服务器内部错误: {str(e)}", 500)
# 两期对比接口-3dTiles
@earthwork_bp.post("/api/v1/calc/twoPhaseComparison")
async def two_phase_comparison(request: Request):
"""
两期对比接口
请求参数示例:
{
"polygonCoords": [
[
115.70440468338526,
30.77363140345639
],
[
115.70443054007985,
30.773510462589584
],
[
115.70459702429197,
30.77360789911405
]
],
"designElevation": 100,
"algorithm": "grid",
"resolution": 1,
"crs": "EPSG:4326",
"interpolationMethod": "linear",
"urlA": "http://222.212.85.86:9000/300bdf2b-a150-406e-be63-d28bd29b409f/model/hbgldk/yzk/20260113/3D/terra_b3dms/tileset.json",
"urlB": "http://222.212.85.86:9000/300bdf2b-a150-406e-be63-d28bd29b409f/model/hbgldk/yzk/20260113/3D/terra_b3dms/tileset.json"
}
"""
try:
# 1. 接收前端传参
data = request.json
if not data:
return _error_response("请求参数不能为空", 400)
# 2. 提取参数
polygon_coords = data.get("polygonCoords")
design_elevation = data.get("designElevation", 1000)
urlA = data.get("urlA")
urlB = data.get("urlB")
if not polygon_coords:
return _error_response("多边形坐标不能为空", 400)
if design_elevation is None:
return _error_response("设计高程不能为空", 400)
if urlA is None or urlB is None :
return _error_response("对比地图不能为空", 400)
# 3. 可选参数
algorithm = data.get("algorithm", "tin")
resolution = data.get("resolution", 1.0)
crs = data.get("crs", "EPSG:4326")
interpolation_method = data.get("interpolationMethod", "linear")
# 4. 参数验证
if len(polygon_coords) < 3:
return _error_response("多边形至少需要3个点", 400)
# 检查多边形是否闭合,如不闭合则自动闭合
if polygon_coords[0] != polygon_coords[-1]:
polygon_coords.append(polygon_coords[0])
# 算法验证
if algorithm not in ["grid", "tin", "prism"]:
return _error_response("算法必须是 grid, tin 或 prism", 400)
# 分辨率验证
if resolution <= 0 or resolution > 100:
return _error_response("分辨率必须在0-100米之间", 400)
# 5. 确保计算器已初始化
app_info_a = init_app(urlA)
if not app_info_a.get('data_source').tileset_path :
return _error_response(f"下载地图失败:{urlA}", 400)
calculator_3d_tiles_a = app_info_a.get("calculator_3d_tiles")
app_info_b = init_app(urlB)
if not app_info_b.get('data_source').tileset_path :
return _error_response(f"下载地图失败:{urlB}", 400)
calculator_3d_tiles_b = app_info_b.get("calculator_3d_tiles")
# 6. 执行计算
algorithm_type = AlgorithmType.GRID
result_a = await calculator_3d_tiles_a.calculate(
polygon_coords=polygon_coords,
design_elevation=design_elevation,
algorithm=algorithm_type,
resolution=resolution,
target_crs=crs,
interpolation_method=interpolation_method
)
result_b = await calculator_3d_tiles_b.calculate(
polygon_coords=polygon_coords,
design_elevation=design_elevation,
algorithm=algorithm_type,
resolution=resolution,
target_crs=crs,
interpolation_method=interpolation_method
)
# 获取网格数据
grids_a = result_a.calculation_details
grids_b = result_b.calculation_details
# 比较网格数据
comparison_result = calculator_3d_tiles_a.compare_grid_cells(grids_a, grids_b)
# 转换为字典
result_dict = comparison_result.to_dict()
# 7. 返回成功响应
return _success_response(result_dict)
except ValueError as e:
logger.warning(f"参数验证失败: {str(e)}")
return _error_response(f"参数错误: {str(e)}", 400)
except Exception as e:
logger.error(f"计算失败: {str(e)}")
return _error_response(f"服务器内部错误: {str(e)}", 500)
# 验证接口
@earthwork_bp.post("/api/v1/calc/earthwork3dTiles/validate")
async def validate_earthwork(request: Request):
"""验证计算参数接口"""
try:
# 1. 接收前端传参
data = request.json
if not data:
return _error_response("请求参数不能为空", 400)
# 2. 提取参数
polygon_coords = data.get("polygonCoords")
if not polygon_coords:
return _error_response("多边形坐标不能为空", 400)
url = data.get("url")
if url is None:
return _error_response("地图不能为空", 400)
# 3. 参数验证
if len(polygon_coords) < 3:
return _error_response("多边形至少需要3个点", 400)
# 检查多边形是否闭合,如不闭合则自动闭合
if polygon_coords[0] != polygon_coords[-1]:
polygon_coords.append(polygon_coords[0])
# 4. 确保计算器已初始化
app_info = init_app(url)
calculator_3d_tiles = app_info.get("calculator_3d_tiles")
# 5. 执行验证
validation_result = await calculator_3d_tiles.validate(polygon_coords)
# 6. 返回结果
return _success_response(validation_result)
except Exception as e:
logger.error(f"验证失败: {str(e)}")
return _error_response(f"验证失败: {str(e)}", 400)
# 获取算法列表接口
@earthwork_bp.get("/api/v1/calc/earthwork3dTiles/algorithms")
async def get_algorithms(request: Request):
"""获取支持的算法列表接口"""
try:
algorithms = [
{
"id": "grid",
"name": "格网法",
"description": "将计算区域划分为规则格网,通过插值计算每个格网的高程变化,适合平坦或规则地形",
"accuracy": "中等",
"performance": "快速",
"parameters": {
"resolution": {
"name": "格网分辨率",
"description": "格网大小(米),影响计算精度和性能",
"default": 1.0,
"range": [0.1, 10.0]
}
}
},
{
"id": "tin",
"name": "三角网法",
"description": "基于不规则三角网TIN构建地形表面计算每个三角形的体积变化适合复杂地形",
"accuracy": "",
"performance": "中等",
"parameters": {
"resolution": {
"name": "不适用",
"description": "三角网法不使用固定的分辨率参数",
"default": None
}
}
},
{
"id": "prism",
"name": "三棱柱法",
"description": "结合三角网和垂直棱柱的高精度算法,计算每个三棱柱的体积,精度最高",
"accuracy": "最高",
"performance": "较慢",
"parameters": {
"resolution": {
"name": "棱柱宽度",
"description": "棱柱宽度(米),影响计算精度",
"default": 1.0,
"range": [0.1, 5.0]
}
}
}
]
return _success_response(algorithms)
except Exception as e:
logger.error(f"获取算法列表失败: {str(e)}")
return _error_response(f"获取算法列表失败: {str(e)}", 500)
# 批量计算接口
@earthwork_bp.post("/api/v1/calc/earthwork3dTiles/batch")
async def batch_calc_earthwork(request: Request):
"""批量土方量计算接口"""
try:
# 1. 接收前端传参
data = request.json
if not data:
return _error_response("请求参数不能为空", 400)
calculations = data.get("calculations", [])
if not calculations:
return _error_response("计算任务列表不能为空", 400)
if len(calculations) > 100:
return _error_response("批量计算数量超过限制最多100个", 400)
# 3. 执行批量计算
results = []
errors = []
for i, calc_data in enumerate(calculations):
try:
# 提取参数
polygon_coords = calc_data.get("polygonCoords")
design_elevation = calc_data.get("designElevation")
url = calc_data.get("url")
if not polygon_coords or design_elevation is None or url is None:
errors.append({
"index": i,
"error": "缺少必要参数"
})
continue
# 参数验证
if len(polygon_coords) < 3:
errors.append({
"index": i,
"error": "多边形至少需要3个点"
})
continue
# 检查多边形是否闭合
if polygon_coords[0] != polygon_coords[-1]:
polygon_coords.append(polygon_coords[0])
# 可选参数
algorithm = calc_data.get("algorithm", "tin")
resolution = calc_data.get("resolution", 1.0)
crs = calc_data.get("crs", "EPSG:4326")
interpolation_method = calc_data.get("interpolationMethod", "linear")
# 2. 确保计算器已初始化
app_info = init_app(url)
calculator_3d_tiles = app_info.get("calculator_3d_tiles")
# 执行计算
algorithm_type = AlgorithmType(algorithm)
result = await calculator_3d_tiles.calculate(
polygon_coords=polygon_coords,
design_elevation=design_elevation,
algorithm=algorithm_type,
resolution=resolution,
target_crs=crs,
interpolation_method=interpolation_method
)
results.append(result)
except Exception as e:
errors.append({
"index": i,
"error": str(e),
"polygon": polygon_coords if 'polygon_coords' in locals() else None
})
continue
# 4. 返回结果
batch_result = {
"results": results,
"errors": errors,
"summary": {
"total": len(calculations),
"success": len(results),
"failed": len(errors),
"successRate": f"{(len(results)/len(calculations)*100):.1f}%" if calculations else "0%"
}
}
message = f"批量计算完成,成功 {len(results)} 个,失败 {len(errors)}"
return _success_response(batch_result)
except Exception as e:
logger.error(f"批量计算失败: {str(e)}")
return _error_response(f"批量计算失败: {str(e)}", 500)
# 核心接口:土方量计算-点云
@earthwork_bp.post("/api/v1/calc/earthworkPointCloud")
async def calc_earthwork_point_cloud(request: Request):
try:
# 1. 接收前端传参
data = request.json
if not data:
return json({
"code": 400,
"msg": "请求参数不能为空",
"data": None
}, status=400)
polygon_coords = data.get("polygonCoords") # 计算区域多边形坐标
design_elev = data.get("designElevation") # 设计高程
crs = data.get("crs", "EPSG:4326") # 坐标系默认WGS84
url = data.get("url")
if url is None:
return _error_response("地图不能为空", 400)
# 2. 确保计算器已初始化
app_info = init_app(url)
calculator_point_cloud = app_info.get("calculator_point_cloud")
result = calculator_point_cloud.calculate_earthwork(polygon_coords=polygon_coords, design_elev=design_elev, crs=crs)
# 3. 处理结果
if not result["success"]:
return _error_response(result["error"], 400)
# 4. 格式化结果
formatted_result = calculator_point_cloud.format_result(result)
# 5. 返回成功响应
return _success_response(formatted_result)
except Exception as e:
logger.error(f"服务器错误: {str(e)}")
return _error_response(f"服务器内部错误: {str(e)}", 500)
def _success_response(data: Dict[str, Any]) -> json:
"""成功响应"""
return json({
"code": 200,
"msg": "计算成功",
"data": data
})
def _error_response(message: str, status_code: int = 400) -> json:
"""错误响应"""
return json({
"code": status_code,
"msg": message,
"data": None
}, status=status_code)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,691 @@
# earthwork_calculator.py
import pdal
import pyvista as pv
import numpy as np
import json
from typing import List, Tuple, Dict, Any, Optional
from dataclasses import dataclass, field
from enum import Enum
import traceback
class EarthworkAlgorithm(Enum):
"""土方量计算算法枚举"""
GRID = "grid"
TIN = "tin"
PRISM = "prism"
@dataclass
class EarthworkResultPointCloud:
"""土方量计算结果"""
cut_volume: float # 挖方量 (m³)
fill_volume: float # 填方量 (m³)
net_volume: float # 净方量 (m³)
area: float # 计算区域面积 (m²)
avg_elevation: float # 平均高程
min_elevation: float # 最低高程
max_elevation: float # 最高高程
points_count: int # 使用的点数
triangle_count: int = 0 # 三角形数量
grid_count: int = 0 # 网格数量仅GRID算法使用
prism_count: int = 0 # 棱柱体数量仅PRISM算法使用
bounding_box: Dict[str, List[float]] = field(default_factory=dict) # 边界框
volume_accuracy: float = 0.95 # 计算精度
algorithm: str = "TIN三角网法" # 使用的算法
resolution: float = 1.0 # 计算分辨率
algorithm_params: Dict[str, Any] = field(default_factory=dict) # 算法参数
def to_dict(self) -> Dict:
"""转换为字典"""
result = {
"volume": {
"cut": round(self.cut_volume, 3),
"fill": round(self.fill_volume, 3),
"net": round(self.net_volume, 3),
"unit": ""
},
"area": {
"value": round(self.area, 3),
"unit": ""
},
"elevation": {
"average": round(self.avg_elevation, 3),
"min": round(self.min_elevation, 3),
"max": round(self.max_elevation, 3),
"unit": "m"
},
"statistics": {
"points_count": self.points_count,
"accuracy": round(self.volume_accuracy, 3),
"algorithm": self.algorithm
},
"bounding_box": self.bounding_box,
"calculation_params": {
"resolution": self.resolution,
"accuracy": self.volume_accuracy,
**self.algorithm_params
}
}
# 根据算法类型添加特定的统计信息
if self.algorithm.startswith("GRID"):
result["statistics"]["grid_count"] = self.grid_count
elif self.algorithm.startswith("TIN"):
result["statistics"]["triangle_count"] = self.triangle_count
elif self.algorithm.startswith("PRISM"):
result["statistics"]["prism_count"] = self.prism_count
return result
class EarthworkCalculatorPointCloud:
"""土方量计算核心类(支持多种算法)"""
def __init__(self, point_cloud_path: str = "./data/your_point_cloud.laz"):
"""
初始化土方量计算器
Args:
point_cloud_path: 点云数据文件路径
"""
self.point_cloud_path = point_cloud_path
def validate_inputs(self, polygon_coords: List[List[float]], design_elev: float) -> Tuple[bool, str]:
"""验证输入参数"""
if not polygon_coords or len(polygon_coords) < 3:
return False, "多边形坐标至少需要3个点"
try:
design_elev = float(design_elev)
except (TypeError, ValueError):
return False, "设计高程必须是有效数字"
return True, ""
def create_polygon_string(self, polygon_coords: List[List[float]]) -> str:
"""创建PDAL多边形字符串"""
coords_list = []
for coord in polygon_coords:
if len(coord) >= 2:
coords_list.append(f"{coord[0]} {coord[1]}")
# 确保多边形闭合
if coords_list and coords_list[0] != coords_list[-1]:
coords_list.append(coords_list[0])
return "POLYGON((" + ", ".join(coords_list) + "))"
def calculate_bounding_box(self, points: np.ndarray) -> Dict[str, List[float]]:
"""
计算边界框
Args:
points: 点云坐标数组
Returns:
Dict: 边界框信息
"""
if len(points) == 0:
return {"min": [0, 0, 0], "max": [0, 0, 0]}
min_vals = np.min(points, axis=0)
max_vals = np.max(points, axis=0)
return {
"min": [float(min_vals[0]), float(min_vals[1]), float(min_vals[2])],
"max": [float(max_vals[0]), float(max_vals[1]), float(max_vals[2])]
}
def clip_point_cloud(self, polygon_coords: List[List[float]], crs: str = "EPSG:4326") -> pv.PolyData:
"""
裁剪点云数据
Args:
polygon_coords: 多边形坐标列表
crs: 坐标系
Returns:
pyvista.PolyData: 裁剪后的点云数据
"""
polygon_str = self.create_polygon_string(polygon_coords)
# PDAL管道配置
pipeline_config = {
"pipeline": [
{
"type": "readers.las",
"filename": self.point_cloud_path,
"spatialreference": crs
},
{
"type": "filters.crop",
"polygon": polygon_str
}
]
}
# 执行PDAL管道
pipeline = pdal.Pipeline(json.dumps(pipeline_config))
try:
pipeline.execute()
if len(pipeline.arrays) == 0:
raise ValueError("多边形区域内没有找到点云数据")
# 获取裁剪后的点云数据
points = pipeline.arrays[0]
x = points["X"]
y = points["Y"]
z = points["Z"]
return pv.PolyData(np.column_stack((x, y, z)))
except RuntimeError as e:
print(f"PDAL执行失败: {str(e)}")
# 如果没有PDAL数据生成模拟数据用于测试
return self.generate_mock_point_cloud(polygon_coords)
def generate_mock_point_cloud(self, polygon_coords: List[List[float]]) -> pv.PolyData:
"""
生成模拟点云数据仅用于测试
Args:
polygon_coords: 多边形坐标列表
Returns:
pyvista.PolyData: 模拟点云数据
"""
print("使用模拟数据进行测试...")
n_points = 1000
# 获取多边形边界
x_coords = [c[0] for c in polygon_coords]
y_coords = [c[1] for c in polygon_coords]
x_min, x_max = min(x_coords), max(x_coords)
y_min, y_max = min(y_coords), max(y_coords)
# 生成随机点
x = np.random.uniform(x_min, x_max, n_points)
y = np.random.uniform(y_min, y_max, n_points)
z = np.random.uniform(100, 120, n_points) # 模拟高程在100-120米之间
return pv.PolyData(np.column_stack((x, y, z)))
def create_tin_mesh(self, point_cloud: pv.PolyData) -> pv.PolyData:
"""
创建三角网TIN算法使用
Args:
point_cloud: 点云数据
Returns:
pyvista.PolyData: 三角网格
"""
if len(point_cloud.points) < 3:
raise ValueError("点云数据不足,无法构网")
try:
return point_cloud.delaunay_2d()
except Exception as e:
raise ValueError(f"三角网构网失败: {str(e)}")
def calculate_volumes_by_tin(self, mesh: pv.PolyData, design_elev: float) -> Dict[str, Any]:
"""
TIN算法计算土方量
Args:
mesh: 三角网格
design_elev: 设计高程
Returns:
Dict: 包含体积计算结果和额外信息
"""
points = mesh.points
elev_diff = points[:, 2] - design_elev
cut_volume = 0.0
fill_volume = 0.0
triangle_count = 0
# 遍历所有三角形面片计算体积
cells = mesh.cells.reshape(-1, 4)
if len(cells) == 0:
raise ValueError("无法生成有效的三角网")
for cell in cells:
if cell[0] == 3: # 三角形VTK格式3个顶点
triangle_count += 1
vertex_indices = cell[1:]
pts = points[vertex_indices]
# 计算三角形面积
v1 = pts[1] - pts[0]
v2 = pts[2] - pts[0]
area = 0.5 * np.linalg.norm(np.cross(v1, v2))
# 计算平均高程差
avg_diff = np.mean(elev_diff[vertex_indices])
vol = area * avg_diff
if vol > 0:
cut_volume += vol
else:
fill_volume += abs(vol)
return {
"cut_volume": cut_volume,
"fill_volume": fill_volume,
"net_volume": cut_volume - fill_volume,
"triangle_count": triangle_count
}
def calculate_volumes_by_grid(self, point_cloud: pv.PolyData, design_elev: float,
grid_size: float = 1.0) -> Dict[str, Any]:
"""
GRID算法计算土方量
Args:
point_cloud: 点云数据
design_elev: 设计高程
grid_size: 网格尺寸
Returns:
Dict: 包含体积计算结果和额外信息
"""
points = point_cloud.points
if len(points) == 0:
raise ValueError("点云数据为空")
# 计算边界
x_min, y_min = np.min(points[:, :2], axis=0)
x_max, y_max = np.max(points[:, :2], axis=0)
# 创建网格
x_edges = np.arange(x_min, x_max + grid_size, grid_size)
y_edges = np.arange(y_min, y_max + grid_size, grid_size)
grid_count = (len(x_edges) - 1) * (len(y_edges) - 1)
cut_volume = 0.0
fill_volume = 0.0
# 对每个网格计算土方量
for i in range(len(x_edges) - 1):
for j in range(len(y_edges) - 1):
# 获取当前网格内的点
mask = (points[:, 0] >= x_edges[i]) & (points[:, 0] < x_edges[i+1]) & \
(points[:, 1] >= y_edges[j]) & (points[:, 1] < y_edges[j+1])
grid_points = points[mask]
if len(grid_points) > 0:
# 计算网格内点的平均高程
avg_elevation = np.mean(grid_points[:, 2])
# 计算高程差
elev_diff = avg_elevation - design_elev
# 计算体积
cell_area = grid_size * grid_size
vol = cell_area * elev_diff
if vol > 0:
cut_volume += vol
else:
fill_volume += abs(vol)
return {
"cut_volume": cut_volume,
"fill_volume": fill_volume,
"net_volume": cut_volume - fill_volume,
"grid_count": grid_count
}
def calculate_volumes_by_prism(self, point_cloud: pv.PolyData, design_elev: float,
influence_radius: float = 0.5) -> Dict[str, Any]:
"""
PRISM算法计算土方量
Args:
point_cloud: 点云数据
design_elev: 设计高程
influence_radius: 影响半径
Returns:
Dict: 包含体积计算结果和额外信息
"""
points = point_cloud.points
if len(points) == 0:
raise ValueError("点云数据为空")
cut_volume = 0.0
fill_volume = 0.0
# 每个点的影响面积
influence_area = np.pi * influence_radius ** 2
prism_count = len(points)
for point in points:
# 计算高程差
elev_diff = point[2] - design_elev
# 计算体积
vol = influence_area * elev_diff
if vol > 0:
cut_volume += vol
else:
fill_volume += abs(vol)
return {
"cut_volume": cut_volume,
"fill_volume": fill_volume,
"net_volume": cut_volume - fill_volume,
"prism_count": prism_count
}
def calculate_statistics(self, point_cloud: pv.PolyData, mesh: pv.PolyData = None) -> Dict[str, float]:
"""
计算统计数据
Args:
point_cloud: 点云数据
mesh: 三角网格仅TIN算法需要
Returns:
Dict: 统计结果
"""
elevations = point_cloud.points[:, 2]
stats = {
"area": 0.0,
"max_elevation": np.max(elevations) if len(elevations) > 0 else 0.0,
"min_elevation": np.min(elevations) if len(elevations) > 0 else 0.0,
"avg_elevation": np.mean(elevations) if len(elevations) > 0 else 0.0,
"points_count": len(point_cloud.points)
}
# 计算面积
if mesh is not None:
stats["area"] = mesh.area
else:
# 对于非TIN算法使用多边形面积近似
if len(point_cloud.points) > 0:
# 使用点云的凸包面积
try:
hull = point_cloud.delaunay_2d()
stats["area"] = hull.area
except:
# 如果无法计算凸包,使用边界框面积
x_min, x_max = np.min(point_cloud.points[:, 0]), np.max(point_cloud.points[:, 0])
y_min, y_max = np.min(point_cloud.points[:, 1]), np.max(point_cloud.points[:, 1])
stats["area"] = (x_max - x_min) * (y_max - y_min)
return stats
def calculate_earthwork(self,
polygon_coords: List[List[float]],
design_elev: float,
algorithm: str = EarthworkAlgorithm.TIN.value,
algorithm_params: Optional[Dict[str, Any]] = None,
crs: str = "EPSG:4326",
volume_accuracy: Optional[float] = None,
resolution: Optional[float] = None) -> EarthworkResultPointCloud:
"""
主计算方法执行完整的土方量计算流程
Args:
polygon_coords: 多边形坐标列表
design_elev: 设计高程
algorithm: 计算算法可选值'grid', 'tin', 'prism'
algorithm_params: 算法特定参数
crs: 坐标系
volume_accuracy: 计算精度0-1之间
resolution: 计算分辨率
Returns:
EarthworkResultPointCloud: 计算结果
"""
try:
# 1. 验证输入
is_valid, message = self.validate_inputs(polygon_coords, design_elev)
if not is_valid:
raise ValueError(message)
design_elev = float(design_elev)
# 2. 验证算法参数
if algorithm not in [a.value for a in EarthworkAlgorithm]:
raise ValueError(f"不支持的算法: {algorithm}。支持的算法: {[a.value for a in EarthworkAlgorithm]}")
# 3. 设置默认参数
if algorithm_params is None:
algorithm_params = {}
# 4. 裁剪点云
point_cloud = self.clip_point_cloud(polygon_coords, crs)
# 5. 根据算法选择计算方法
algorithm_name = ""
mesh = None
if algorithm == EarthworkAlgorithm.TIN.value:
algorithm_name = "TIN三角网法"
mesh = self.create_tin_mesh(point_cloud)
volumes = self.calculate_volumes_by_tin(mesh, design_elev)
algorithm_params = {
"grid_size": algorithm_params.get("grid_size", 1.0)
}
elif algorithm == EarthworkAlgorithm.GRID.value:
algorithm_name = "GRID格网法"
grid_size = algorithm_params.get("grid_size", 1.0)
algorithm_name = f"GRID格网法(网格尺寸={grid_size}m)"
volumes = self.calculate_volumes_by_grid(point_cloud, design_elev, grid_size)
algorithm_params = {
"grid_size": grid_size
}
elif algorithm == EarthworkAlgorithm.PRISM.value:
algorithm_name = "PRISM棱柱体法"
influence_radius = algorithm_params.get("influence_radius", 0.5)
algorithm_name = f"PRISM棱柱体法(影响半径={influence_radius}m)"
volumes = self.calculate_volumes_by_prism(point_cloud, design_elev, influence_radius)
algorithm_params = {
"influence_radius": influence_radius
}
# 6. 计算统计数据
stats = self.calculate_statistics(point_cloud, mesh)
# 7. 计算边界框
bounding_box = self.calculate_bounding_box(point_cloud.points)
# 8. 计算或使用传入的精度和分辨率
if volume_accuracy is None:
# 根据算法和点云密度自动估算精度
volume_accuracy = self.estimate_accuracy(algorithm, point_cloud)
if resolution is None:
# 根据点云密度自动估算分辨率
resolution = self.estimate_resolution(point_cloud)
# 9. 创建EarthworkResultPointCloud对象
result = EarthworkResultPointCloud(
cut_volume=volumes["cut_volume"],
fill_volume=volumes["fill_volume"],
net_volume=volumes["net_volume"],
area=stats["area"],
avg_elevation=stats["avg_elevation"],
min_elevation=stats["min_elevation"],
max_elevation=stats["max_elevation"],
points_count=stats["points_count"],
triangle_count=volumes.get("triangle_count", 0),
grid_count=volumes.get("grid_count", 0),
prism_count=volumes.get("prism_count", 0),
bounding_box=bounding_box,
volume_accuracy=volume_accuracy,
algorithm=algorithm_name,
resolution=resolution,
algorithm_params=algorithm_params
)
return result
except Exception as e:
print(f"计算错误: {str(e)}")
print(traceback.format_exc())
raise
def estimate_accuracy(self, algorithm: str, point_cloud: pv.PolyData) -> float:
"""根据算法和点云密度估计计算精度"""
point_density = len(point_cloud.points) / max(point_cloud.area, 0.1)
# 基础精度
base_accuracy = {
EarthworkAlgorithm.TIN.value: 0.95,
EarthworkAlgorithm.GRID.value: 0.90,
EarthworkAlgorithm.PRISM.value: 0.85
}.get(algorithm, 0.90)
# 根据点云密度调整精度
if point_density > 10: # 高密度点云
accuracy_boost = min(0.05, point_density * 0.002)
elif point_density < 1: # 低密度点云
accuracy_boost = -0.05
else:
accuracy_boost = 0
estimated_accuracy = base_accuracy + accuracy_boost
# 确保精度在合理范围内
return max(0.7, min(0.99, estimated_accuracy))
def estimate_resolution(self, point_cloud: pv.PolyData) -> float:
"""根据点云密度估计分辨率"""
if len(point_cloud.points) < 2:
return 1.0
area = point_cloud.area
if area > 0:
point_count = len(point_cloud.points)
avg_spacing = np.sqrt(area / point_count)
return float(round(avg_spacing, 2))
return 1.0
def get_algorithm_info(self) -> Dict[str, Any]:
"""
获取支持的算法信息
Returns:
Dict: 算法信息
"""
return {
"supported_algorithms": [
{
"id": EarthworkAlgorithm.TIN.value,
"name": "TIN三角网法",
"description": "通过构建不规则三角网计算土方量,精度高,适合复杂地形",
"parameters": [
{
"name": "grid_size",
"type": "float",
"default": 1.0,
"description": "网格尺寸(m),用于点云预处理",
"required": False
}
]
},
{
"id": EarthworkAlgorithm.GRID.value,
"name": "GRID格网法",
"description": "将区域划分为规则网格计算土方量,计算速度快,适合大规模区域",
"parameters": [
{
"name": "grid_size",
"type": "float",
"default": 1.0,
"description": "网格尺寸(m)",
"required": True,
"min": 0.1,
"max": 10.0
}
]
},
{
"id": EarthworkAlgorithm.PRISM.value,
"name": "PRISM棱柱体法",
"description": "将每个点视为一个棱柱体计算土方量,计算简单快速",
"parameters": [
{
"name": "influence_radius",
"type": "float",
"default": 0.5,
"description": "点的影响半径(m)",
"required": True,
"min": 0.1,
"max": 5.0
}
]
}
],
"default_algorithm": EarthworkAlgorithm.TIN.value
}
# 使用示例
if __name__ == "__main__":
# 创建计算器实例
calculator = EarthworkCalculatorPointCloud("./data/sample_point_cloud.laz")
# 获取支持的算法信息
algorithm_info = calculator.get_algorithm_info()
print("支持的算法:")
for algo in algorithm_info["supported_algorithms"]:
print(f" {algo['id']}: {algo['name']} - {algo['description']}")
# 定义多边形区域
polygon = [
[116.3974, 39.9093],
[116.4084, 39.9093],
[116.4084, 39.9193],
[116.3974, 39.9193]
]
# 设计高程
design_elevation = 100.0
# 测试不同算法
algorithms = [
(EarthworkAlgorithm.TIN.value, {}, "TIN算法"),
(EarthworkAlgorithm.GRID.value, {"grid_size": 2.0}, "GRID算法(2米网格)"),
(EarthworkAlgorithm.PRISM.value, {"influence_radius": 1.0}, "PRISM算法(1米影响半径)")
]
for algo_id, params, description in algorithms:
print(f"\n使用{description}计算:")
try:
result = calculator.calculate_earthwork(
polygon_coords=polygon,
design_elev=design_elevation,
algorithm=algo_id,
algorithm_params=params
)
print(f" 挖方量: {result.cut_volume:.3f}")
print(f" 填方量: {result.fill_volume:.3f}")
print(f" 净方量: {result.net_volume:.3f}")
print(f" 计算面积: {result.area:.3f}")
print(f" 计算精度: {result.volume_accuracy:.3%}")
print(f" 分辨率: {result.resolution:.2f} m")
except Exception as e:
print(f" 计算失败: {str(e)}")

469
b3dm/glb_with_draco.py Normal file
View File

@ -0,0 +1,469 @@
import json
import struct
import numpy as np
import DracoPy
class DracoGLBParser:
"""使用 DracoPy 解析包含 Draco 压缩的 GLB 文件"""
def __init__(self, glb_file_path):
self.glb_file_path = glb_file_path
self.json_data = None
self.binary_data = None
self.decoded_meshes = [] # 缓存解码后的网格数据
def parse_glb_structure(self):
"""解析 GLB 文件结构"""
with open(self.glb_file_path, 'rb') as f:
# 读取 GLB 头部
magic = f.read(4)
version = struct.unpack('<I', f.read(4))[0]
total_length = struct.unpack('<I', f.read(4))[0]
print("=" * 60)
print(f"GLB 文件分析:")
print(f" 文件类型: {magic.decode('utf-8')}")
print(f" 版本: {version}")
print(f" 总大小: {total_length:,} bytes")
# 读取 JSON chunk
json_length = struct.unpack('<I', f.read(4))[0]
json_type = f.read(4)
if json_type != b'JSON':
raise ValueError(f"期望 JSON chunk但得到: {json_type}")
self.json_data = json.loads(f.read(json_length).decode('utf-8'))
print(f" JSON 大小: {json_length:,} bytes")
# 读取 Binary chunk
if f.tell() < total_length:
bin_length = struct.unpack('<I', f.read(4))[0]
bin_type = f.read(4)
if bin_type != b'BIN\x00':
raise ValueError(f"期望 BIN chunk但得到: {bin_type}")
self.binary_data = f.read(bin_length)
print(f" 二进制数据大小: {bin_length:,} bytes")
print("=" * 60)
return self
def analyze_structure(self):
"""分析 GLB 结构"""
if not self.json_data:
self.parse_glb_structure()
print("\nGLB 结构分析:")
print("-" * 40)
# 基本信息
asset = self.json_data.get('asset', {})
print(f"生成器: {asset.get('generator', '未知')}")
print(f"glTF 版本: {asset.get('version', '未知')}")
# Draco 扩展
extensions_used = self.json_data.get('extensionsUsed', [])
extensions_required = self.json_data.get('extensionsRequired', [])
if 'KHR_draco_mesh_compression' in extensions_used:
print("使用 Draco 压缩")
if 'KHR_draco_mesh_compression' in extensions_required:
print("Draco 压缩是必需的")
# 网格信息
meshes = self.json_data.get('meshes', [])
print(f"\n网格数量: {len(meshes)}")
for i, mesh in enumerate(meshes):
print(f" 网格 {i}: {mesh.get('name', '未命名')}")
primitives = mesh.get('primitives', [])
print(f" 图元数量: {len(primitives)}")
for j, primitive in enumerate(primitives):
print(f" 图元 {j}:")
if 'extensions' in primitive:
draco_info = primitive['extensions'].get('KHR_draco_mesh_compression')
if draco_info:
print(f" 使用 Draco 压缩")
print(f" 属性: {draco_info.get('attributes', {})}")
# 缓冲区信息
buffers = self.json_data.get('buffers', [])
buffer_views = self.json_data.get('bufferViews', [])
accessors = self.json_data.get('accessors', [])
print(f"\n缓冲区: {len(buffers)}")
print(f"BufferViews: {len(buffer_views)}")
print(f"访问器: {len(accessors)}")
return self
def decode_draco_meshes(self):
"""解码所有 Draco 压缩的网格"""
if not self.json_data:
self.parse_glb_structure()
meshes = []
buffer_views = self.json_data.get('bufferViews', [])
print("\n" + "=" * 60)
print("开始解码 Draco 压缩数据...")
print("=" * 60)
for mesh_idx, mesh in enumerate(self.json_data.get('meshes', [])):
mesh_name = mesh.get('name', f'mesh_{mesh_idx}')
for primitive_idx, primitive in enumerate(mesh.get('primitives', [])):
if 'extensions' in primitive:
draco_info = primitive['extensions'].get('KHR_draco_mesh_compression')
if draco_info:
print(f"\n解码: {mesh_name} - 图元 {primitive_idx}")
# 解码 Draco 数据
mesh_data = self._decode_primitive(draco_info, buffer_views)
if mesh_data:
meshes.append({
'mesh_idx': mesh_idx,
'primitive_idx': primitive_idx,
'name': mesh_name,
**mesh_data
})
self.decoded_meshes = meshes # 缓存解码结果
print("\n" + "=" * 60)
print(f"解码完成!共解码 {len(meshes)} 个网格")
print("=" * 60)
return meshes
def get_vertices(self, mesh_idx=0, primitive_idx=0):
"""
获取指定网格的顶点集合
参数:
mesh_idx: 网格索引默认0
primitive_idx: 图元索引默认0
返回:
np.array: 顶点数组形状为 (n, 3) None
"""
# 如果还没有解码数据,先解码
if not self.decoded_meshes:
self.decode_draco_meshes()
# 查找指定网格
for mesh in self.decoded_meshes:
if mesh['mesh_idx'] == mesh_idx and mesh['primitive_idx'] == primitive_idx:
return mesh.get('vertices')
print(f"未找到网格 {mesh_idx} 的图元 {primitive_idx}")
return None
def get_all_vertices(self):
"""
获取所有网格的所有顶点合并成一个数组
返回:
np.array: 所有顶点的合并数组形状为 (n, 3) None
"""
# 如果还没有解码数据,先解码
if not self.decoded_meshes:
self.decode_draco_meshes()
if not self.decoded_meshes:
print("没有解码的网格数据")
return None
# 收集所有顶点
all_vertices = []
for mesh in self.decoded_meshes:
if mesh.get('vertices') is not None:
all_vertices.append(mesh['vertices'])
if not all_vertices:
return None
# 合并所有顶点
return np.vstack(all_vertices)
def get_vertices_by_mesh_name(self, mesh_name):
"""
根据网格名称获取顶点集合
参数:
mesh_name: 网格名称
返回:
list: 包含所有匹配网格的顶点数组列表
"""
# 如果还没有解码数据,先解码
if not self.decoded_meshes:
self.decode_draco_meshes()
vertices_list = []
for mesh in self.decoded_meshes:
if mesh['name'] == mesh_name and mesh.get('vertices') is not None:
vertices_list.append(mesh['vertices'])
return vertices_list
def get_vertex_count(self):
"""
获取总顶点数
返回:
int: 所有网格的总顶点数
"""
vertices = self.get_all_vertices()
return len(vertices) if vertices is not None else 0
def _decode_primitive(self, draco_info, buffer_views):
"""解码单个图元的 Draco 数据"""
try:
# 获取 bufferView 信息
buffer_view_idx = draco_info['bufferView']
attributes = draco_info['attributes']
buffer_view = buffer_views[buffer_view_idx]
byte_offset = buffer_view.get('byteOffset', 0)
byte_length = buffer_view['byteLength']
print(f" BufferView: {buffer_view_idx}")
print(f" 属性映射: {attributes}")
print(f" 数据位置: offset={byte_offset}, length={byte_length}")
# 提取 Draco 压缩数据
draco_data = self.binary_data[byte_offset:byte_offset + byte_length]
print(f" Draco 数据大小: {len(draco_data):,} bytes")
# 使用 DracoPy 解码
print(" 正在使用 DracoPy 解码...")
draco_decoder = DracoPy.decode(draco_data)
# 解析解码结果
mesh_data = self._parse_draco_result(draco_decoder, attributes)
return mesh_data
except Exception as e:
print(f" 解码失败: {e}")
import traceback
traceback.print_exc()
return None
def _parse_draco_result(self, draco_decoder, attributes):
"""解析 DracoPy 解码结果"""
result = {
'vertices': None,
'faces': None,
'texcoords': None,
'batch_ids': None,
'normals': None,
'colors': None
}
# 获取顶点
if hasattr(draco_decoder, 'points'):
result['vertices'] = np.array(draco_decoder.points, dtype=np.float32)
print(f" 顶点数量: {len(result['vertices'])}")
# 获取面/三角形
if hasattr(draco_decoder, 'faces'):
faces_data = draco_decoder.faces
# 确保是三角形每面3个顶点
if len(faces_data) > 0:
if isinstance(faces_data[0], list) or isinstance(faces_data[0], tuple):
# 如果是列表的列表
result['faces'] = np.array(faces_data, dtype=np.uint32)
else:
# 如果是扁平化的数组
result['faces'] = np.array(faces_data, dtype=np.uint32).reshape(-1, 3)
print(f" 面数量: {len(result['faces']) if result['faces'] is not None else 0}")
# 获取属性数据
if hasattr(draco_decoder, 'attributes'):
attrs = draco_decoder.attributes
# 根据属性映射查找数据
for gltf_attr_name, draco_attr_id in attributes.items():
if draco_attr_id in attrs:
attr_data = attrs[draco_attr_id]
if gltf_attr_name == 'POSITION':
result['vertices'] = np.array(attr_data, dtype=np.float32)
elif gltf_attr_name == 'TEXCOORD_0':
result['texcoords'] = np.array(attr_data, dtype=np.float32)
elif gltf_attr_name == '_BATCHID':
result['batch_ids'] = np.array(attr_data, dtype=np.uint32)
elif gltf_attr_name == 'NORMAL':
result['normals'] = np.array(attr_data, dtype=np.float32)
elif gltf_attr_name == 'COLOR_0':
result['colors'] = np.array(attr_data, dtype=np.float32)
print(f" 已提取属性: {gltf_attr_name} (ID: {draco_attr_id})")
# 如果没有通过attributes获取到顶点尝试其他方式
if result['vertices'] is None and hasattr(draco_decoder, 'get_points'):
try:
result['vertices'] = np.array(draco_decoder.get_points(), dtype=np.float32)
except:
pass
return result
def save_decoded_meshes(self, meshes, output_format='obj'):
"""保存解码后的网格"""
import os
base_name = os.path.splitext(os.path.basename(self.glb_file_path))[0]
output_dir = f"{base_name}_decoded"
if not os.path.exists(output_dir):
os.makedirs(output_dir)
for mesh in meshes:
filename = f"{output_dir}/{mesh['name']}_p{mesh['primitive_idx']}.{output_format}"
if output_format == 'obj':
self._save_as_obj(mesh, filename)
elif output_format == 'ply':
self._save_as_ply(mesh, filename)
elif output_format == 'npz':
self._save_as_npz(mesh, filename)
else:
print(f"不支持的格式: {output_format}")
continue
print(f"已保存: {filename}")
def _save_as_obj(self, mesh, filename):
"""保存为 OBJ 格式"""
with open(filename, 'w') as f:
# 写入顶点
if mesh['vertices'] is not None:
for v in mesh['vertices']:
f.write(f"v {v[0]} {v[1]} {v[2]}\n")
# 写入纹理坐标
if mesh['texcoords'] is not None:
for uv in mesh['texcoords']:
f.write(f"vt {uv[0]} {uv[1]}\n")
# 写入法线
if mesh['normals'] is not None:
for n in mesh['normals']:
f.write(f"vn {n[0]} {n[1]} {n[2]}\n")
# 写入面
if mesh['faces'] is not None:
for face in mesh['faces']:
# OBJ 索引从1开始
face_indices = [str(idx + 1) for idx in face]
f.write(f"f {' '.join(face_indices)}\n")
def _save_as_ply(self, mesh, filename):
"""保存为 PLY 格式"""
from plyfile import PlyData, PlyElement
import numpy as np
vertices = mesh['vertices']
faces = mesh['faces']
if vertices is None:
return
# 创建顶点数据
vertex_data = np.zeros(len(vertices), dtype=[
('x', 'f4'), ('y', 'f4'), ('z', 'f4')
])
vertex_data['x'] = vertices[:, 0]
vertex_data['y'] = vertices[:, 1]
vertex_data['z'] = vertices[:, 2]
# 创建面数据
if faces is not None:
face_data = np.zeros(len(faces), dtype=[('vertex_indices', 'i4', (3,))])
face_data['vertex_indices'] = faces
# 写入文件
vertex_element = PlyElement.describe(vertex_data, 'vertex')
if faces is not None:
face_element = PlyElement.describe(face_data, 'face')
PlyData([vertex_element, face_element], text=False).write(filename)
else:
PlyData([vertex_element], text=False).write(filename)
def _save_as_npz(self, mesh, filename):
"""保存为 NPZ 格式"""
np.savez(
filename,
vertices=mesh['vertices'],
faces=mesh['faces'],
texcoords=mesh['texcoords'],
batch_ids=mesh['batch_ids'],
normals=mesh['normals'],
colors=mesh['colors']
)
# 使用示例
def main():
# 初始化解析器
parser = DracoGLBParser(r"D:\devForBdzlWork\ai_project_v1\b3dm\test\temp_glb\temp_6e895637.glb")
# 解析 GLB 结构
parser.parse_glb_structure()
# 分析结构
parser.analyze_structure()
# 解码 Draco 网格
meshes = parser.decode_draco_meshes()
# 使用新增的顶点获取方法
print("\n" + "=" * 60)
print("顶点获取方法演示:")
print("=" * 60)
# 1. 获取第一个网格的第一个图元的顶点
vertices = parser.get_vertices(mesh_idx=0, primitive_idx=0)
if vertices is not None:
print(f"1. 获取第一个网格顶点:")
print(f" 形状: {vertices.shape}")
print(f" 数据类型: {vertices.dtype}")
print(f" 前5个顶点: \n{vertices[:5]}")
# 2. 获取所有顶点(合并)
all_vertices = parser.get_all_vertices()
if all_vertices is not None:
print(f"\n2. 获取所有网格顶点(合并):")
print(f" 总顶点数: {len(all_vertices)}")
print(f" 形状: {all_vertices.shape}")
# 3. 获取总顶点数
total_vertices = parser.get_vertex_count()
print(f"\n3. 总顶点数: {total_vertices}")
# 4. 根据网格名称获取顶点
if meshes:
mesh_name = meshes[0]['name']
vertices_list = parser.get_vertices_by_mesh_name(mesh_name)
print(f"\n4. 根据名称 '{mesh_name}' 获取的顶点:")
for i, verts in enumerate(vertices_list):
print(f" 图元 {i}: {verts.shape if verts is not None else 'None'}")
# 保存解码后的网格
parser.save_decoded_meshes(meshes, output_format='obj')
if __name__ == "__main__":
main()

380
b3dm/slope_aspect_img.py Normal file
View File

@ -0,0 +1,380 @@
# slope_aspect_img13
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import sobel
import matplotlib as mpl
import rasterio
from osgeo import gdal
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d
import os
# 自定义3D箭头类
class Arrow3D(FancyArrowPatch):
def __init__(self, xs, ys, zs, *args, **kwargs):
super().__init__((0,0), (0,0), *args, **kwargs)
self._verts3d = xs, ys, zs
def do_3d_projection(self, renderer=None):
xs3d, ys3d, zs3d = self._verts3d
xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
return min(zs)
def read_dem_rasterio(filepath):
"""使用rasterio读取DEM数据"""
try:
with rasterio.open(filepath) as src:
dem_data = src.read(1)
transform = src.transform
bounds = src.bounds
crs = src.crs
print(f"DEM信息:")
print(f" 尺寸: {dem_data.shape}")
print(f" 范围: {bounds}")
print(f" 坐标系: {crs}")
print(f" 高程范围: {np.nanmin(dem_data):.2f} - {np.nanmax(dem_data):.2f}")
if np.isnan(dem_data).any():
print(" 检测到NaN值")
dem_data = np.nan_to_num(dem_data, nan=np.nanmean(dem_data))
rows, cols = dem_data.shape
x = np.linspace(bounds.left, bounds.right, cols)
y = np.linspace(bounds.bottom, bounds.top, rows)
X, Y = np.meshgrid(x, y)
return X, Y, dem_data
except Exception as e:
print(f"使用rasterio读取失败: {e}")
return None
def read_slope_aspect_by_dem(dem_path, overall_3d_output_path=None) :
# ---------------------- 核心:解决中文乱码配置 ----------------------
setup_chinese_font()
# 尝试读取DEM数据
print("正在读取DEM文件...")
X, Y, Z = read_dem_rasterio(dem_path)
if X is None:
raise FileNotFoundError(f"无法读取DEM文件: {dem_path}")
# 检查数据有效性
if Z.size == 0 or np.all(Z == 0):
raise ValueError("DEM数据无效或全部为0")
# 重采样
if Z.shape[0] > 200 or Z.shape[1] > 200:
print(f"原始DEM尺寸较大 ({Z.shape}),进行重采样...")
from scipy.ndimage import zoom
scale_factor = min(200/Z.shape[0], 200/Z.shape[1])
Z = zoom(Z, scale_factor, order=1)
X = zoom(X, scale_factor, order=1)
Y = zoom(Y, scale_factor, order=1)
print(f"重采样后尺寸: {Z.shape}")
# ---------------------- 2. 计算坡度和坡向 ----------------------
print("计算坡度和坡向...")
dx_pixel = np.abs(X[0,1] - X[0,0]) * 111000
dy_pixel = np.abs(Y[1,0] - Y[0,0]) * 111000
dx = sobel(Z, axis=1) / (2 * dx_pixel)
dy = sobel(Z, axis=0) / (2 * dy_pixel)
slope_rad = np.arctan(np.sqrt(dx**2 + dy**2))
slope_deg = slope_rad * (180 / np.pi)
aspect_rad = np.arctan2(-dx, dy)
aspect_deg = aspect_rad * (180 / np.pi)
aspect_deg[aspect_deg < 0] += 360
print(f"坡度范围: {np.min(slope_deg):.2f}° - {np.max(slope_deg):.2f}°")
print(f"坡向范围: {np.min(aspect_deg):.2f}° - {np.max(aspect_deg):.2f}°")
# ---------------------- 3. 计算整体坡向 ----------------------
print("\n计算整体坡向...")
def calculate_overall_aspect(aspect_deg, slope_deg, method='weighted_mean'):
"""计算整体坡向"""
aspect_rad = np.deg2rad(aspect_deg)
if method == 'weighted_mean':
# 坡度加权平均法
u = np.sin(aspect_rad)
v = np.cos(aspect_rad)
weights = slope_deg.flatten()
weighted_u = np.nansum(u.flatten() * weights) / np.nansum(weights)
weighted_v = np.nansum(v.flatten() * weights) / np.nansum(weights)
weighted_aspect_rad = np.arctan2(weighted_u, weighted_v)
weighted_aspect_deg = np.rad2deg(weighted_aspect_rad)
if weighted_aspect_deg < 0:
weighted_aspect_deg += 360
weighted_strength = np.sqrt(weighted_u**2 + weighted_v**2)
return weighted_aspect_deg, weighted_strength, "坡度加权平均法"
elif method == 'vector_mean':
# 向量平均法
u = np.sin(aspect_rad)
v = np.cos(aspect_rad)
mean_u = np.nanmean(u)
mean_v = np.nanmean(v)
mean_aspect_rad = np.arctan2(mean_u, mean_v)
mean_aspect_deg = np.rad2deg(mean_aspect_rad)
if mean_aspect_deg < 0:
mean_aspect_deg += 360
vector_strength = np.sqrt(mean_u**2 + mean_v**2)
return mean_aspect_deg, vector_strength, "向量平均法"
# 使用坡度加权平均法计算整体坡向
overall_aspect, overall_strength, method_name = calculate_overall_aspect(aspect_deg, slope_deg, 'weighted_mean')
# 将整体坡向转换为方向描述
if overall_aspect < 22.5 or overall_aspect >= 337.5:
overall_direction = ""
elif 22.5 <= overall_aspect < 67.5:
overall_direction = "东北"
elif 67.5 <= overall_aspect < 112.5:
overall_direction = ""
elif 112.5 <= overall_aspect < 157.5:
overall_direction = "东南"
elif 157.5 <= overall_aspect < 202.5:
overall_direction = ""
elif 202.5 <= overall_aspect < 247.5:
overall_direction = "西南"
elif 247.5 <= overall_aspect < 292.5:
overall_direction = "西"
else:
overall_direction = "西北"
print(f"整体坡向 ({method_name}):")
print(f" 角度: {overall_aspect:.1f}°")
print(f" 方向: {overall_direction}")
# print(f" 一致性: {overall_strength:.3f}")
# ---------------------- 5. 3D可视化俯视图包含整体坡向和关键点坡向 ----------------------
print("\n生成3D俯视图可视化...")
fig3d, ax3d = plt.subplots(figsize=(16, 12), subplot_kw={"projection": "3d"})
# 绘制地形曲面 - 俯视图需要更清晰的地形表现
norm = mpl.colors.Normalize(vmin=np.percentile(slope_deg, 5),
vmax=np.percentile(slope_deg, 95))
plot_skip = max(2, Z.shape[0] // 60) # 增加采样密度,使俯视图更清晰
X_plot = X[::plot_skip, ::plot_skip]
Y_plot = Y[::plot_skip, ::plot_skip]
Z_plot = Z[::plot_skip, ::plot_skip]
slope_plot = slope_deg[::plot_skip, ::plot_skip]
surf = ax3d.plot_surface(
X_plot, Y_plot, Z_plot,
cmap="viridis_r",
alpha=0.85, # 增加透明度,使箭头更明显
linewidth=0.1, # 很细的网格线
facecolors=plt.cm.viridis_r(norm(slope_plot)),
zorder=1
)
# ---------------------- 绘制整体坡向箭头(中心位置,红色粗箭头) ----------------------
center_x = (X.min() + X.max()) / 2
center_y = (Y.min() + Y.max()) / 2
# 整体坡向箭头长度
arrow_length_overall = 0.15 * min(X.max()-X.min(), Y.max()-Y.min())
# 计算整体坡向箭头方向
scale_factor = 1.5
overall_aspect_rad = np.deg2rad(overall_aspect)
dx_overall = np.sin(overall_aspect_rad) * arrow_length_overall * scale_factor
dy_overall = np.cos(overall_aspect_rad) * arrow_length_overall * scale_factor
# 方案2如果希望箭头在地形上方一定高度
terrain_max_z = Z.max() # 地形最高点
float_height = 0.2 * terrain_max_z # 在地形最高点上方20%的高度
# 绘制整体坡向箭头(红色,粗)
arrow_overall = Arrow3D(
[center_x, center_x + dx_overall],
[center_y, center_y + dy_overall],
[terrain_max_z + float_height, terrain_max_z + float_height], # 在地形上方
mutation_scale=25, # 稍大一些的箭头
lw=5, # 更粗的线
arrowstyle='-|>',
color='red',
alpha=0.98, # 更高的透明度
zorder=20 # 更高的绘制顺序
)
ax3d.add_artist(arrow_overall)
# 在整体坡向箭头起点添加标记
ax3d.scatter(center_x, center_y, terrain_max_z + float_height,
color='red', s=120, edgecolor='white', linewidth=2, zorder=21)
# ---------------------- 整体坡向面板放置在左上角 ----------------------
# 计算左上角位置
panel_x = X.min() + 0.02 * (X.max() - X.min()) # 左边留2%的边距
panel_y = Y.max() - 0.02 * (Y.max() - Y.min()) # 上边留2%的边距
panel_z = Z.max() + 0.5 * (Z.max() - Z.min()) # 进一步提高Z坐标确保在视野内
# 整体坡向面板内容
overall_info = (
f"整体坡向分析\n"
f"角度: {overall_aspect:.1f}°\n"
f"方向: {overall_direction}"
# f"一致性: {overall_strength:.3f}"
)
# 绘制整体坡向面板(左上角)
ax3d.text(panel_x, panel_y, panel_z,
overall_info,
fontsize=12, fontweight='bold',
bbox=dict(facecolor='white', alpha=0.5, boxstyle="round,pad=0.5",
edgecolor='red', linewidth=2.5),
ha='left', va='top', zorder=30)
# ---------------------- 设置3D图参数 ----------------------
ax3d.set_xlabel("经度 (X)", fontsize=12)
ax3d.set_ylabel("纬度 (Y)", fontsize=12)
ax3d.set_zlabel("高程 (m)", fontsize=12)
# 获取文件名用于标题
filename = os.path.basename(dem_path)
ax3d.set_title(f"DEM三维俯视图 - 坡向分析\n文件: {filename}",
fontsize=14, fontweight='bold', pad=20)
# 添加颜色条
cbar = plt.colorbar(
mpl.cm.ScalarMappable(norm=norm, cmap='viridis_r'),
ax=ax3d, shrink=0.6, aspect=25, pad=0.15
)
cbar.set_label("坡度 (°)", fontsize=12)
# ---------------------- 设置为俯视图(高仰角) ----------------------
view_elev = 85 # 接近90度的仰角俯视效果
view_azim = overall_aspect + 180 # 从坡向的相反方向观看,可以看到坡面
# 确保方位角在0-360度范围内
view_azim = view_azim % 360
print(f"\n设置3D俯视图视角:")
print(f" 整体坡向: {overall_aspect:.1f}° ({overall_direction})")
print(f" 视角方位角: {view_azim:.1f}°")
print(f" 视角仰角: {view_elev:.1f}°")
# 应用俯视图视角设置
ax3d.view_init(elev=view_elev, azim=view_azim)
# 调整相机距离,使视角更广
ax3d.dist = 9.0 # 增加相机距离
# 设置坐标轴范围,确保所有元素都显示
z_min, z_max = Z.min(), Z.max()
z_padding = 0.6 * (z_max - z_min) # 适当增加Z轴范围
ax3d.set_zlim(z_min - 0.1*z_padding, z_max + z_padding)
# 设置XY轴范围
x_margin = 0.1 * (X.max() - X.min())
y_margin = 0.1 * (Y.max() - Y.min())
ax3d.set_xlim(X.min() - x_margin, X.max() + x_margin)
ax3d.set_ylim(Y.min() - y_margin, Y.max() + y_margin)
plt.tight_layout()
# 保存图片
if not overall_3d_output_path :
overall_3d_output_path = dem_path.replace('.tif', '_slopeAspect_3D_overlook.png')
plt.savefig(overall_3d_output_path, dpi=250, bbox_inches='tight', facecolor='white')
# plt.show()
os.remove(dem_path)
print(f"\n3D俯视图已保存: {overall_3d_output_path}")
print("分析完成!")
print("="*60)
return overall_3d_output_path
# ---------------------- 跨平台中文字体配置 ----------------------
def setup_chinese_font():
"""设置中文字体支持兼容Windows、Linux、macOS"""
import platform
plt.rcParams['axes.unicode_minus'] = False # 正确显示负号
system = platform.system()
# 尝试添加系统中文字体路径
font_paths = []
if system == 'Windows':
# Windows 字体路径
font_paths = [
'C:/Windows/Fonts/simhei.ttf', # 黑体
'C:/Windows/Fonts/simkai.ttf', # 楷体
'C:/Windows/Fonts/simsun.ttc', # 宋体
'C:/Windows/Fonts/microsoftyahei.ttf', # 微软雅黑
]
elif system == 'Darwin': # macOS
# macOS 字体路径
font_paths = [
'/System/Library/Fonts/PingFang.ttc', # 苹方
'/System/Library/Fonts/STHeiti Light.ttc', # 华文黑体
'/System/Library/Fonts/STHeiti Medium.ttc',
'/Library/Fonts/Arial Unicode.ttf', # Arial Unicode
]
elif system == 'Linux':
# Linux 字体路径
font_paths = [
'/usr/share/fonts/truetype/wqy/wqy-microhei.ttc', # 文泉驿微米黑
'/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc', # 文泉驿正黑
'/usr/share/fonts/truetype/arphic/uming.ttc', # AR PL UMing
'/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc', # Noto
]
# 尝试找到并添加第一个可用的中文字体
font_added = False
for font_path in font_paths:
try:
if os.path.exists(font_path):
font_prop = mpl.font_manager.FontProperties(fname=font_path)
font_name = font_prop.get_name()
mpl.font_manager.fontManager.addfont(font_path)
plt.rcParams['font.sans-serif'] = [font_name] + plt.rcParams['font.sans-serif']
font_added = True
print(f"已添加中文字体: {font_name} ({font_path})")
break
except Exception as e:
print(f"添加字体 {font_path} 失败: {e}")
continue
# 如果以上方法都失败,使用通用备选方案
if not font_added:
if system == 'Windows':
fallback_fonts = ['SimHei', 'Microsoft YaHei', 'KaiTi', 'FangSong']
elif system == 'Darwin':
fallback_fonts = ['PingFang SC', 'Hiragino Sans GB', 'Apple LiGothic Medium']
else: # Linux and others
fallback_fonts = ['DejaVu Sans', 'WenQuanYi Micro Hei',
'Noto Sans CJK SC', 'Heiti TC', 'AR PL UMing CN']
current_fonts = plt.rcParams.get('font.sans-serif', [])
plt.rcParams['font.sans-serif'] = fallback_fonts + current_fonts
print(f"使用备选字体方案: {fallback_fonts[:2]}...")
# 设置字体族
plt.rcParams['font.family'] = 'sans-serif'
if __name__ == "__main__":
dem_path = r'D:/devForBdzlWork/ai_project_v1/b3dm/o_dem_f1cb6f69_slopeAspect.tif'
read_slope_aspect_by_dem(dem_path)

1062
b3dm/slope_aspect_tif.py Normal file

File diff suppressed because it is too large Load Diff

419
b3dm/terrain_api.py Normal file
View File

@ -0,0 +1,419 @@
from sanic import Blueprint, Request, json
from pydantic import BaseModel, Field, field_validator
from typing import List, Optional, Dict, Any
import numpy as np
import logging
import time
import uuid
import urllib.parse
import threading
import os
from b3dm.terrain_calculator import TerrainCalculator
terrain_bp = Blueprint("terrain", url_prefix="")
MINIO_SUB_PATH = "slopeAspectPng"
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 请求模型
class NormalVector(BaseModel):
"""法向量模型"""
nx: float = Field(..., description="法向量X分量")
ny: float = Field(..., description="法向量Y分量")
nz: float = Field(..., description="法向量Z分量")
@field_validator('nx', 'ny', 'nz')
def check_finite(cls, v):
if not np.isfinite(v):
raise ValueError(f"值必须是有限数字,得到: {v}")
return v
def to_list(self):
return [self.nx, self.ny, self.nz]
class BatchRequest(BaseModel):
"""批量请求模型"""
vectors: List[List[float]] = Field(..., description="法向量列表")
@field_validator('vectors')
def validate_vectors(cls, v):
if len(v) > 1000:
raise ValueError("批量处理最多支持1000个向量")
for vec in v:
if len(vec) != 3:
raise ValueError("每个向量必须是长度为3的列表")
if not all(isinstance(x, (int, float)) for x in vec):
raise ValueError("向量元素必须是数字")
return v
class PointItem(BaseModel):
"""单个点模型"""
x: float = Field(..., description="x坐标")
y: float = Field(..., description="y坐标")
z: float = Field(..., description="z坐标")
class PointRequest(BaseModel):
points: List[PointItem] = Field(..., description="点列表")
url: str = Field(..., description="URL地址")
@field_validator('points')
def validate_points_count(cls, v):
if len(v) > 10:
raise ValueError("批量处理最多支持10个点")
return v
class PreloadRequest(BaseModel):
url: str = Field(..., description="URL地址")
class AnalysisConfig(BaseModel):
"""分析配置"""
classify: bool = Field(default=True, description="是否进行分类")
include_percent: bool = Field(default=True, description="是否包含坡度百分比")
include_direction: bool = Field(default=True, description="是否包含方向描述")
# 中间件:请求计时
@terrain_bp.middleware("request")
async def add_start_time(request: Request):
request.ctx.start_time = time.time()
@terrain_bp.middleware("response")
async def add_response_time(request: Request, response):
if hasattr(request.ctx, "start_time"):
process_time = (time.time() - request.ctx.start_time) * 1000
response.headers["X-Process-Time"] = f"{process_time:.2f}ms"
@terrain_bp.post("/api/v1/calculate/slope")
async def calculate_slope(request: Request):
"""计算坡度"""
try:
data = request.json
if not data:
return json({"error": "请求体不能为空"}, status=400)
# 验证输入
try:
vector = NormalVector(**data)
except Exception as e:
return json({"error": f"输入验证失败: {str(e)}"}, status=400)
# 计算坡度
result = TerrainCalculator.calculate_slope(vector.to_list())
# 检查是否有错误
if "error" in result and result["error"]:
return json(result, status=400)
return json({
"success": True,
"data": result,
"request": {
"input_vector": vector.to_list(),
"timestamp": time.time()
}
})
except Exception as e:
logger.error(f"坡度计算API错误: {e}")
return json({
"error": f"服务器内部错误: {str(e)}",
"success": False
}, status=500)
@terrain_bp.post("/api/v1/calculate/aspect")
async def calculate_aspect1(request: Request):
"""计算坡向"""
try:
data = request.json
if not data:
return json({"error": "请求体不能为空"}, status=400)
# 验证输入
try:
vector = NormalVector(**data)
except Exception as e:
return json({"error": f"输入验证失败: {str(e)}"}, status=400)
# 计算坡向
result = TerrainCalculator.calculate_aspect(vector.to_list())
# 检查是否有错误
if "error" in result and result["error"]:
return json(result, status=400)
return json({
"success": True,
"data": result,
"request": {
"input_vector": vector.to_list(),
"timestamp": time.time()
}
})
except Exception as e:
logger.error(f"坡向计算API错误: {e}")
return json({
"error": f"服务器内部错误: {str(e)}",
"success": False
}, status=500)
@terrain_bp.post("/api/v1/calculate/preload3dTiles")
async def preload_3dtiles(request: Request):
"""预加载3dtiles地图数据"""
try:
data = request.json
if not data:
return json({"error": "请求体不能为空"}, status=400)
# 验证输入
try:
vector = PreloadRequest(**data)
except Exception as e:
return json({"error": f"输入验证失败: {str(e)}"}, status=400)
# 创建并启动线程
script_dir = os.path.dirname(os.path.abspath(__file__))
thread1 = threading.Thread(target=TerrainCalculator.preload_3dtiles, args=(vector.url,))
# 启动线程
thread1.start()
url_prefix = extract_and_rebuild_url(vector.url)
return json({
"success": True,
"data": f"{script_dir}/data_3dtiles",
"request": {
"input_vector": vector.model_dump(),
"timestamp": time.time()
}
})
except Exception as e:
logger.error(f"坡向计算API错误: {e}")
return json({
"error": f"服务器内部错误: {str(e)}",
"success": False
}, status=500)
@terrain_bp.post("/api/v1/calculate/slopeAspect")
async def calculate_slopeAspect(request: Request):
"""生成坡向坡度俯视图"""
try:
data = request.json
if not data:
return json({"error": "请求体不能为空"}, status=400)
# 验证输入
try:
vector = PointRequest(**data)
except Exception as e:
return json({"error": f"输入验证失败: {str(e)}"}, status=400)
# 生成坡向坡度俯视图
region_coords = [(point.x, point.y, point.z) for point in vector.points]
overall_3d_png_name = f"o_dem_{uuid.uuid4().hex[:8]}_slopeAspect.png"
# 创建并启动线程
thread1 = threading.Thread(target=TerrainCalculator.generate_slopeAspect_3d_overlook, args=(region_coords, vector.url, overall_3d_png_name, MINIO_SUB_PATH))
# 启动线程
thread1.start()
url_prefix = extract_and_rebuild_url(vector.url)
return json({
"success": True,
"data": f"{url_prefix}/{MINIO_SUB_PATH}/{overall_3d_png_name}",
"request": {
"input_vector": vector.model_dump(),
"timestamp": time.time()
}
})
except Exception as e:
logger.error(f"坡向计算API错误: {e}")
return json({
"error": f"服务器内部错误: {str(e)}",
"success": False
}, status=500)
@terrain_bp.post("/api/v1/calculate/slopeAspectTif")
async def calculate_slopeAspect_tif(request: Request):
"""生成坡向坡度tif文件"""
try:
data = request.json
if not data:
return json({"error": "请求体不能为空"}, status=400)
# 验证输入
try:
vector = PointRequest(**data)
except Exception as e:
return json({"error": f"输入验证失败: {str(e)}"}, status=400)
# 生成坡向坡度俯视图
region_coords = [(point.x, point.y, point.z) for point in vector.points]
slope_aspect_tif_name = f"o_dem_{uuid.uuid4().hex[:8]}_slopeAspect.tif"
# 创建并启动线程
thread1 = threading.Thread(target=TerrainCalculator.generate_slopeAspect_tif, args=(region_coords, vector.url, slope_aspect_tif_name, MINIO_SUB_PATH))
# 启动线程
thread1.start()
url_prefix = extract_and_rebuild_url(vector.url)
return json({
"success": True,
"data": f"{url_prefix}/{MINIO_SUB_PATH}/{slope_aspect_tif_name}",
"request": {
"input_vector": vector.model_dump(),
"timestamp": time.time()
}
})
except Exception as e:
logger.error(f"坡向计算API错误: {e}")
return json({
"error": f"服务器内部错误: {str(e)}",
"success": False
}, status=500)
@terrain_bp.post("/api/v1/calculate/both")
async def calculate_both(request: Request):
"""同时计算坡度和坡向"""
try:
data = request.json
if not data:
return json({"error": "请求体不能为空"}, status=400)
# 验证输入
try:
vector = NormalVector(**data)
except Exception as e:
return json({"error": f"输入验证失败: {str(e)}"}, status=400)
# 计算坡度和坡向
result = TerrainCalculator.calculate_slope_aspect(vector.to_list())
# 检查是否有错误
if "error" in result and result["error"]:
return json(result, status=400)
return json({
"success": True,
"data": result,
"request": {
"input_vector": vector.to_list(),
"timestamp": time.time()
}
})
except Exception as e:
logger.error(f"综合计算API错误: {e}")
return json({
"error": f"服务器内部错误: {str(e)}",
"success": False
}, status=500)
@terrain_bp.post("/api/v1/calculate/batch")
async def batch_calculate(request: Request):
"""批量计算"""
try:
data = request.json
if not data:
return json({"error": "请求体不能为空"}, status=400)
# 验证输入
try:
batch_request = BatchRequest(**data)
except Exception as e:
return json({"error": f"输入验证失败: {str(e)}"}, status=400)
# 批量计算
start_time = time.time()
result = TerrainCalculator.batch_calculate(batch_request.vectors)
process_time = (time.time() - start_time) * 1000
if "error" in result and result["error"]:
return json(result, status=400)
return json({
"success": True,
"data": result,
"performance": {
"process_time_ms": process_time,
"vectors_per_second": len(batch_request.vectors) / (process_time / 1000) if process_time > 0 else 0
},
"request": {
"vector_count": len(batch_request.vectors),
"timestamp": time.time()
}
})
except Exception as e:
logger.error(f"批量计算API错误: {e}")
return json({
"error": f"服务器内部错误: {str(e)}",
"success": False
}, status=500)
@terrain_bp.get("/api/v1/example")
async def get_examples(request: Request):
"""获取示例数据"""
examples = {
"flat": {
"nx": 0.0,
"ny": 0.0,
"nz": 1.0,
"expected_slope": 0.0,
"description": "完全水平面"
},
"north_slope_30": {
"nx": 0.0,
"ny": -0.5,
"nz": 0.8660254,
"expected_slope": 30.0,
"expected_aspect": 0.0,
"description": "朝北30度斜坡"
},
"east_slope_45": {
"nx": 0.7071068,
"ny": 0.0,
"nz": 0.7071068,
"expected_slope": 45.0,
"expected_aspect": 90.0,
"description": "朝东45度斜坡"
},
"vertical": {
"nx": 1.0,
"ny": 0.0,
"nz": 0.0,
"expected_slope": 90.0,
"description": "垂直面"
}
}
return json({
"examples": examples,
"count": len(examples)
})
def extract_and_rebuild_url(url):
"""提取URL的三部分并重建"""
# 解析URL
parsed = urllib.parse.urlparse(url)
# 1. 提取协议部分 (http/https)
scheme = parsed.scheme or "http" # 如果没有协议默认用http
# 2. 提取IP端口/主机部分
netloc = parsed.netloc
# 3. 提取第一个路径分段
path = parsed.path.strip("/") # 去掉首尾的斜杠
path_parts = path.split("/")
if path_parts and path_parts[0]:
first_segment = path_parts[0]
else:
first_segment = ""
# 重建URL
if first_segment:
rebuilt_url = f"{scheme}://{netloc}/{first_segment}"
else:
rebuilt_url = f"{scheme}://{netloc}"
return rebuilt_url

392
b3dm/terrain_calculator.py Normal file
View File

@ -0,0 +1,392 @@
import numpy as np
from typing import List, Tuple, Dict, Any
import logging
import os
import uuid
import b3dm.data_3dtiles_to_dem as data_3dtiles_to_dem
import b3dm.slope_aspect_img as slope_aspect_img
import b3dm.slope_aspect_tif as slope_aspect_tif
from b3dm.tileset_data_source import TilesetDataSource
logger = logging.getLogger(__name__)
_data_source = None
class TerrainCalculator:
"""地形坡度和坡向计算器"""
def preload_3dtiles(url: str) :
# 下载3dtiles地图数据
_data_source = TilesetDataSource(url)
_data_source.dowload_map_data(url)
if not _data_source.tileset_path :
logger.info(f"下载地图数据失败: {url}")
return "下载地图数据失败", None
def generate_slopeAspect_3d_overlook(region_coords, url, overall_3d_png_name, minio_sub_path) :
# 下载3dtiles地图数据
_data_source = TilesetDataSource(url)
_data_source.dowload_map_data(url)
if not _data_source.tileset_path :
logger.info(f"下载地图数据失败: {url},{region_coords}")
return "下载地图数据失败", None
tileset_path = _data_source.tileset_path
script_dir = os.path.dirname(os.path.abspath(__file__))
dem_path = os.path.join(script_dir, f"o_dem_{uuid.uuid4().hex[:8]}.tif")
data_3dtiles_to_dem.generate_dem(tileset_path, dem_path, region_coords)
if not os.path.exists(dem_path) :
logger.info(f"生成坡度坡向俯视图失败: {url},{region_coords}")
return "生成坡度坡向俯视图失败", None
overall_3d_png_path = os.path.join(script_dir, overall_3d_png_name)
slope_aspect_img.read_slope_aspect_by_dem(dem_path, overall_3d_png_path)
logger.info(f"生成成功: {url},{region_coords},{overall_3d_png_path}")
entry_bucket, _ = _data_source.parse_minio_url(url);
success, minio_path = _data_source.upload_file(entry_bucket, f"{minio_sub_path}/{overall_3d_png_name}", overall_3d_png_path)
if success :
return "生成成功", minio_path
else :
return "生成失败", None
def generate_slopeAspect_tif(region_coords, url, slope_aspect_tif_name, minio_sub_path) :
# 下载3dtiles地图数据
_data_source = TilesetDataSource(url)
_data_source.dowload_map_data(url)
if not _data_source.tileset_path :
logger.info(f"下载地图数据失败: {url},{region_coords}")
return "下载地图数据失败", None
tileset_path = _data_source.tileset_path
script_dir = os.path.dirname(os.path.abspath(__file__))
dem_path = os.path.join(script_dir, f"o_dem_{uuid.uuid4().hex[:8]}.tif")
data_3dtiles_to_dem.generate_dem(tileset_path, dem_path, region_coords)
if not os.path.exists(dem_path) :
logger.info(f"生成坡度坡向tif失败: {url},{region_coords}")
return "生成坡度坡向tif失败", None
slope_aspect_tif_path = os.path.join(script_dir, slope_aspect_tif_name)
slope_aspect_tif.create_slope_aspect(dem_path, 'combined', slope_aspect_tif_path)
logger.info(f"生成成功: {url},{region_coords},{slope_aspect_tif_path}")
entry_bucket, _ = _data_source.parse_minio_url(url);
success, minio_path = _data_source.upload_file(entry_bucket, f"{minio_sub_path}/{slope_aspect_tif_name}", slope_aspect_tif_path)
if success :
return "生成成功", minio_path
else :
return "生成失败", None
@staticmethod
def validate_vector(vector: List[float]) -> bool:
"""验证输入向量是否有效"""
if len(vector) != 3:
return False
if not all(isinstance(v, (int, float)) for v in vector):
return False
norm = np.linalg.norm(vector)
return norm > 1e-10 # 避免零向量
@staticmethod
def normalize_vector(vector: List[float]) -> np.ndarray:
"""向量归一化"""
arr = np.array(vector, dtype=np.float64)
norm = np.linalg.norm(arr)
return arr / norm if norm > 0 else arr
@staticmethod
def calculate_slope(normal_vector: List[float]) -> Dict[str, Any]:
"""
计算坡度
Args:
normal_vector: 法向量 [nx, ny, nz]
Returns:
dict: 包含坡度()和相关信息
"""
try:
# 验证输入
if not TerrainCalculator.validate_vector(normal_vector):
return {
"error": "无效的法向量,必须是长度为3的数值列表且不能为零向量",
"slope_deg": None
}
# 归一化
n = TerrainCalculator.normalize_vector(normal_vector)
# 计算坡度使用arccos法
nz_abs = abs(n[2])
# 处理数值误差
if nz_abs > 1.0:
nz_abs = 1.0
elif nz_abs < 0.0:
nz_abs = 0.0
# 计算坡度(弧度)
if abs(nz_abs - 1.0) < 1e-10: # 完全水平
slope_rad = 0.0
elif abs(nz_abs) < 1e-10: # 完全垂直
slope_rad = np.pi / 2
else:
slope_rad = np.arccos(nz_abs)
# 转换为度
slope_deg = np.degrees(slope_rad)
# 计算坡度百分比
slope_percent = np.tan(slope_rad) * 100 if slope_rad < np.pi/2 else float('inf')
return {
"slope_deg": float(slope_deg),
"slope_rad": float(slope_rad),
"slope_percent": float(slope_percent),
"normalized_vector": n.tolist(),
"classification": TerrainCalculator.classify_slope(slope_deg)
}
except Exception as e:
logger.error(f"坡度计算错误: {e}")
return {
"error": f"计算失败: {str(e)}",
"slope_deg": None
}
@staticmethod
def calculate_aspect(normal_vector: List[float]) -> Dict[str, Any]:
"""
计算坡向
Args:
normal_vector: 法向量 [nx, ny, nz]
Returns:
dict: 包含坡向()和相关信息
"""
try:
# 验证输入
if not TerrainCalculator.validate_vector(normal_vector):
return {
"error": "无效的法向量,必须是长度为3的数值列表且不能为零向量",
"aspect_deg": None
}
# 归一化
n = TerrainCalculator.normalize_vector(normal_vector)
# 检查是否为水平面
nx, ny, nz = n
horizontal_magnitude = np.sqrt(nx*nx + ny*ny)
if horizontal_magnitude < 1e-10: # 水平面,坡向无定义
return {
"aspect_deg": None,
"aspect_rad": None,
"is_flat": True,
"message": "水平面,坡向无定义",
"normalized_vector": n.tolist()
}
# 计算原始坡向(四象限反正切)
# 注意arctan2(nx, ny) 不是 arctan2(ny, nx)
raw_angle_rad = np.arctan2(nx, ny)
# 转换为坡向(下坡方向 = 法向量方向 + 180°
aspect_rad = raw_angle_rad + np.pi
# 转换为度
aspect_deg = np.degrees(aspect_rad)
# 归一化到 [0, 360) 范围
aspect_deg = aspect_deg % 360.0
# 转换为八方向
direction = TerrainCalculator.aspect_to_direction(aspect_deg)
return {
"aspect_deg": float(aspect_deg),
"aspect_rad": float(aspect_rad % (2*np.pi)),
"direction": direction,
"is_flat": False,
"normalized_vector": n.tolist(),
"raw_angle_deg": float(np.degrees(raw_angle_rad))
}
except Exception as e:
logger.error(f"坡向计算错误: {e}")
return {
"error": f"计算失败: {str(e)}",
"aspect_deg": None
}
@staticmethod
def calculate_slope_aspect(normal_vector: List[float]) -> Dict[str, Any]:
"""
同时计算坡度和坡向
Args:
normal_vector: 法向量 [nx, ny, nz]
Returns:
dict: 包含坡度和坡向的综合结果
"""
try:
slope_result = TerrainCalculator.calculate_slope(normal_vector)
aspect_result = TerrainCalculator.calculate_aspect(normal_vector)
result = {
"slope": slope_result,
"aspect": aspect_result,
"input_vector": normal_vector,
"calculation_time": None # 可在调用处添加时间戳
}
# 如果有错误,合并错误信息
errors = []
if "error" in slope_result and slope_result["error"]:
errors.append(f"坡度: {slope_result['error']}")
if "error" in aspect_result and aspect_result["error"]:
errors.append(f"坡向: {aspect_result['error']}")
if errors:
result["errors"] = errors
return result
except Exception as e:
logger.error(f"综合计算错误: {e}")
return {
"error": f"综合计算失败: {str(e)}"
}
@staticmethod
def classify_slope(slope_deg: float) -> Dict[str, Any]:
"""坡度分类"""
if slope_deg < 2:
return {"category": "平坦", "level": 0, "description": "基本平坦"}
elif slope_deg < 5:
return {"category": "缓坡", "level": 1, "description": "适合农业"}
elif slope_deg < 15:
return {"category": "斜坡", "level": 2, "description": "适合建设"}
elif slope_deg < 30:
return {"category": "陡坡", "level": 3, "description": "需要工程措施"}
elif slope_deg < 45:
return {"category": "急陡坡", "level": 4, "description": "高风险区域"}
else:
return {"category": "峭壁", "level": 5, "description": "危险区域"}
@staticmethod
def aspect_to_direction(aspect_deg: float) -> Dict[str, Any]:
"""将坡向转换为八方向"""
directions = ["", "东北", "", "东南", "", "西南", "西", "西北"]
# 计算方向索引 (45°一个区间)
index = int((aspect_deg + 22.5) % 360 / 45)
return {
"chinese": directions[index],
"english": ["N", "NE", "E", "SE", "S", "SW", "W", "NW"][index],
"degree_range": {
"min": (index * 45 - 22.5) % 360,
"max": (index * 45 + 22.5) % 360
}
}
@staticmethod
def batch_calculate(vectors: List[List[float]]) -> Dict[str, Any]:
"""批量计算多个法向量"""
try:
results = []
errors = []
for i, vec in enumerate(vectors):
try:
result = TerrainCalculator.calculate_slope_aspect(vec)
result["index"] = i
results.append(result)
except Exception as e:
errors.append({
"index": i,
"vector": vec,
"error": str(e)
})
return {
"total": len(vectors),
"successful": len(results),
"failed": len(errors),
"results": results,
"errors": errors,
"statistics": TerrainCalculator.calculate_statistics(results)
}
except Exception as e:
logger.error(f"批量计算错误: {e}")
return {"error": f"批量计算失败: {str(e)}"}
@staticmethod
def calculate_statistics(results: List[Dict]) -> Dict[str, Any]:
"""计算统计信息"""
if not results:
return {}
slope_values = []
aspect_values = []
for r in results:
if "slope" in r and "slope_deg" in r["slope"] and r["slope"]["slope_deg"] is not None:
slope_values.append(r["slope"]["slope_deg"])
if "aspect" in r and "aspect_deg" in r["aspect"] and r["aspect"]["aspect_deg"] is not None:
aspect_values.append(r["aspect"]["aspect_deg"])
if slope_values:
slope_arr = np.array(slope_values)
aspect_arr = np.array(aspect_values) if aspect_values else np.array([])
stats = {
"slope": {
"count": len(slope_values),
"mean": float(np.mean(slope_arr)),
"std": float(np.std(slope_arr)),
"min": float(np.min(slope_arr)),
"max": float(np.max(slope_arr)),
"median": float(np.median(slope_arr))
}
}
if aspect_values:
# 坡向统计需要循环统计
stats["aspect"] = {
"count": len(aspect_values),
"mean_vector": TerrainCalculator.circular_mean(aspect_arr),
"concentration": TerrainCalculator.circular_concentration(aspect_arr)
}
return stats
return {}
@staticmethod
def circular_mean(angles_deg: np.ndarray) -> float:
"""计算循环数据的平均值(角度)"""
angles_rad = np.radians(angles_deg)
x = np.mean(np.cos(angles_rad))
y = np.mean(np.sin(angles_rad))
mean_rad = np.arctan2(y, x)
return np.degrees(mean_rad) % 360
@staticmethod
def circular_concentration(angles_deg: np.ndarray) -> float:
"""计算角度数据的集中度 (0-1)"""
angles_rad = np.radians(angles_deg)
x = np.mean(np.cos(angles_rad))
y = np.mean(np.sin(angles_rad))
return np.sqrt(x*x + y*y)

105
b3dm/tileset_data_source.py Normal file
View File

@ -0,0 +1,105 @@
# tileset_data_source_py
import numpy as np
from typing import List, Dict, Optional, Tuple
import asyncio
from concurrent.futures import ThreadPoolExecutor
import logging
import os
from pathlib import Path
from b3dm.data_3dtiles_manager import MinIO3DTilesManager
import b3dm.data_3dtiles_to_dem as data_3dtiles_to_dem
logger = logging.getLogger(__name__)
ENDPOINT_URL = "222.212.85.86:9000"
ACCESS_KEY = "WuRenJi"
SECRET_KEY = "WRJ@2024"
class TilesetDataSource:
"""使用py3dtiles库的数据源"""
def __init__(self, url: str, cache_size: int = 1000):
self.url = url
self.tileset_path = None
self.tileset_dir = None
self._crs = "EPSG:4979"
def parse_minio_url(self, url):
manager = MinIO3DTilesManager(
endpoint_url=ENDPOINT_URL,
access_key=ACCESS_KEY,
secret_key=SECRET_KEY,
secure=False
)
return manager.parse_minio_url(url)
def upload_file(self, bucket_name, object_name, file_path):
manager = MinIO3DTilesManager(
endpoint_url=ENDPOINT_URL,
access_key=ACCESS_KEY,
secret_key=SECRET_KEY,
secure=False
)
flag, path = manager.upload_file(bucket_name, object_name, file_path)
if flag :
os.remove(file_path)
return flag, path
def dowload_map_data(self, url: str) :
# 下载3dtiles地图数据
manager = MinIO3DTilesManager(
endpoint_url=ENDPOINT_URL,
access_key=ACCESS_KEY,
secret_key=SECRET_KEY,
secure=False
)
success, tileset_path = manager.download_full_tileset(
tileset_url=url,
save_dir=f"data_3dtiles",
region_filter=None
)
if success :
self.tileset_path = os.path.abspath(tileset_path)
self.tileset_dir = os.path.dirname(tileset_path)
async def get_points_in_polygon(self, polygon_coords, z_range=None):
"""获取多边形内的点数据"""
points = data_3dtiles_to_dem.parse_tileset(self.tileset_path, polygon_coords)
return np.array(points)
async def get_data_bounds(self) -> Dict[str, List[float]]:
"""获取数据边界"""
bounds = {
"min": [float('inf'), float('inf'), float('inf')],
"max": [-float('inf'), -float('inf'), -float('inf')]
}
if self._tileset and hasattr(self._tileset.root_tile, 'bounding_volume'):
bv = self._tileset.root_tile.bounding_volume
if hasattr(bv, 'get_corners'):
corners = bv.get_corners()
if corners is not None:
bounds["min"] = corners.min(axis=0).tolist()
bounds["max"] = corners.max(axis=0).tolist()
return bounds
def get_crs(self) -> str:
return self._crs
async def main() :
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
SCRIPT_PAR_DIR = os.path.dirname(SCRIPT_DIR)
tileset_path = os.path.join(SCRIPT_PAR_DIR, "data/3dtiles/tileset.json")
data_source_3d_tiles = TilesetDataSource(tileset_path)
# tileSet = TileSet()
# path = Path(tileset_path)
# tileset_data = tileSet.from_file(path)
print("====================================================")
if __name__ == "__main__":
loop = asyncio.get_event_loop()
loop.run_until_complete(main())

395
b3dm/tileset_to_ply.py Normal file
View File

@ -0,0 +1,395 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
3D Tiles Tileset to PLY Converter
将整个3D Tiles tileset转换为单个PLY文件
"""
import json
import struct
import os
import sys
from pathlib import Path
import numpy as np
try:
import DracoPy
except ImportError:
print("警告: DracoPy库未安装无法处理Draco压缩的数据")
print("请运行: pip install DracoPy")
DracoPy = None
class TilesetToPLYConverter:
def __init__(self):
self.all_vertices = []
self.vertex_count = 0
def multiply_matrix_vector(self, matrix, vector):
"""4x4矩阵与4D向量相乘"""
# matrix是16个元素的列表按列主序排列
# 转换为4x4矩阵行主序
m = [
[matrix[0], matrix[4], matrix[8], matrix[12]],
[matrix[1], matrix[5], matrix[9], matrix[13]],
[matrix[2], matrix[6], matrix[10], matrix[14]],
[matrix[3], matrix[7], matrix[11], matrix[15]]
]
# 向量扩展为齐次坐标 [x, y, z, 1]
v = [vector[0], vector[1], vector[2], 1.0]
# 矩阵乘法
result = [
m[0][0]*v[0] + m[0][1]*v[1] + m[0][2]*v[2] + m[0][3]*v[3],
m[1][0]*v[0] + m[1][1]*v[1] + m[1][2]*v[2] + m[1][3]*v[3],
m[2][0]*v[0] + m[2][1]*v[1] + m[2][2]*v[2] + m[2][3]*v[3]
]
return result
def multiply_matrices(self, m1, m2):
"""两个4x4矩阵相乘"""
# 将16元素列表转换为4x4矩阵
def list_to_matrix(lst):
return [
[lst[0], lst[4], lst[8], lst[12]],
[lst[1], lst[5], lst[9], lst[13]],
[lst[2], lst[6], lst[10], lst[14]],
[lst[3], lst[7], lst[11], lst[15]]
]
def matrix_to_list(mat):
return [
mat[0][0], mat[1][0], mat[2][0], mat[3][0],
mat[0][1], mat[1][1], mat[2][1], mat[3][1],
mat[0][2], mat[1][2], mat[2][2], mat[3][2],
mat[0][3], mat[1][3], mat[2][3], mat[3][3]
]
a = list_to_matrix(m1)
b = list_to_matrix(m2)
result = [[0 for _ in range(4)] for _ in range(4)]
for i in range(4):
for j in range(4):
for k in range(4):
result[i][j] += a[i][k] * b[k][j]
return matrix_to_list(result)
def apply_transform_to_vertices(self, vertices, transform_matrix):
"""对顶点应用变换矩阵"""
if not transform_matrix:
return vertices
transformed_vertices = []
for vertex in vertices:
transformed = self.multiply_matrix_vector(transform_matrix, vertex)
transformed_vertices.append(transformed)
return transformed_vertices
def parse_tileset_json(self, tileset_path, parent_transform=None):
"""解析tileset.json文件收集B3DM文件和变换矩阵"""
try:
with open(tileset_path, 'r', encoding='utf-8') as f:
tileset_data = json.load(f)
b3dm_files = []
def process_node(node, base_path, accumulated_transform):
# 获取当前节点的变换矩阵
current_transform = node.get('transform')
# 计算累积变换矩阵
if current_transform and accumulated_transform:
# 矩阵相乘accumulated_transform * current_transform
final_transform = self.multiply_matrices(accumulated_transform, current_transform)
elif current_transform:
final_transform = current_transform
else:
final_transform = accumulated_transform
if 'content' in node and 'uri' in node['content']:
uri = node['content']['uri']
if uri.endswith('.b3dm'):
full_path = os.path.join(base_path, uri)
if os.path.exists(full_path):
b3dm_files.append((full_path, final_transform))
elif uri.endswith('.json'):
# 递归处理子tileset
sub_tileset_path = os.path.join(base_path, uri)
if os.path.exists(sub_tileset_path):
sub_files = self.parse_tileset_json(sub_tileset_path, final_transform)
b3dm_files.extend(sub_files)
if 'children' in node:
for child in node['children']:
process_node(child, base_path, final_transform)
base_path = os.path.dirname(tileset_path)
if 'root' in tileset_data:
process_node(tileset_data['root'], base_path, parent_transform)
return b3dm_files
except Exception as e:
print(f"解析tileset.json时出错: {e}")
return []
def parse_b3dm_file(self, file_path):
"""解析B3DM文件"""
try:
with open(file_path, 'rb') as f:
# 读取B3DM头部
magic = f.read(4)
if magic != b'b3dm':
print(f"警告: {file_path} 不是有效的B3DM文件")
return None
version = struct.unpack('<I', f.read(4))[0]
byte_length = struct.unpack('<I', f.read(4))[0]
feature_table_json_byte_length = struct.unpack('<I', f.read(4))[0]
feature_table_binary_byte_length = struct.unpack('<I', f.read(4))[0]
batch_table_json_byte_length = struct.unpack('<I', f.read(4))[0]
batch_table_binary_byte_length = struct.unpack('<I', f.read(4))[0]
# 跳过feature table和batch table
f.seek(28 + feature_table_json_byte_length + feature_table_binary_byte_length +
batch_table_json_byte_length + batch_table_binary_byte_length)
# 读取glTF数据
gltf_data = f.read()
return self.parse_gltf_data(gltf_data)
except Exception as e:
print(f"解析B3DM文件 {file_path} 失败: {e}")
return None
def parse_gltf_data(self, gltf_data):
"""解析glTF数据"""
try:
# 检查是否为GLB格式
if gltf_data[:4] == b'glTF':
return self.parse_glb_data(gltf_data)
else:
# 尝试作为JSON解析
gltf_json = json.loads(gltf_data.decode('utf-8'))
return self.extract_vertices_from_gltf(gltf_json, None)
except Exception as e:
print(f"解析glTF数据失败: {e}")
return None
def parse_glb_data(self, glb_data):
"""解析GLB格式的glTF数据"""
try:
# GLB头部: magic(4) + version(4) + length(4)
magic = glb_data[:4]
if magic != b'glTF':
return None
version = struct.unpack('<I', glb_data[4:8])[0]
total_length = struct.unpack('<I', glb_data[8:12])[0]
offset = 12
json_data = None
binary_data = None
# 读取chunks
while offset < len(glb_data):
if offset + 8 > len(glb_data):
break
chunk_length = struct.unpack('<I', glb_data[offset:offset+4])[0]
chunk_type = glb_data[offset+4:offset+8]
chunk_data = glb_data[offset+8:offset+8+chunk_length]
if chunk_type == b'JSON':
json_data = json.loads(chunk_data.decode('utf-8'))
elif chunk_type == b'BIN\x00':
binary_data = chunk_data
offset += 8 + chunk_length
# 对齐到4字节边界
offset = (offset + 3) & ~3
if json_data:
return self.extract_vertices_from_gltf(json_data, binary_data)
except Exception as e:
print(f"解析GLB数据失败: {e}")
return None
def extract_vertices_from_gltf(self, gltf_json, binary_data):
"""从glTF JSON中提取顶点数据"""
vertices = []
try:
# 检查是否使用了Draco压缩
if 'extensionsUsed' in gltf_json and 'KHR_draco_mesh_compression' in gltf_json['extensionsUsed']:
if DracoPy is None:
print("警告: 检测到Draco压缩但DracoPy未安装")
return vertices
return self.extract_draco_vertices(gltf_json, binary_data)
# 处理标准glTF格式
if 'meshes' not in gltf_json:
return vertices
for mesh in gltf_json['meshes']:
for primitive in mesh['primitives']:
if 'attributes' in primitive and 'POSITION' in primitive['attributes']:
position_accessor_index = primitive['attributes']['POSITION']
if 'accessors' in gltf_json and position_accessor_index < len(gltf_json['accessors']):
accessor = gltf_json['accessors'][position_accessor_index]
buffer_view_index = accessor['bufferView']
if 'bufferViews' in gltf_json and buffer_view_index < len(gltf_json['bufferViews']):
buffer_view = gltf_json['bufferViews'][buffer_view_index]
buffer_index = buffer_view['buffer']
byte_offset = buffer_view.get('byteOffset', 0) + accessor.get('byteOffset', 0)
if binary_data and buffer_index == 0:
# 从二进制数据中读取顶点
component_type = accessor['componentType']
count = accessor['count']
if component_type == 5126: # FLOAT
vertex_data = struct.unpack(f'<{count*3}f',
binary_data[byte_offset:byte_offset+count*12])
for i in range(0, len(vertex_data), 3):
vertices.append([vertex_data[i], vertex_data[i+1], vertex_data[i+2]])
except Exception as e:
print(f"提取顶点数据失败: {e}")
return vertices
def extract_draco_vertices(self, gltf_json, binary_data):
"""提取Draco压缩的顶点数据"""
vertices = []
try:
if 'meshes' not in gltf_json:
return vertices
for mesh in gltf_json['meshes']:
for primitive in mesh['primitives']:
if 'extensions' in primitive and 'KHR_draco_mesh_compression' in primitive['extensions']:
draco_ext = primitive['extensions']['KHR_draco_mesh_compression']
buffer_view_index = draco_ext['bufferView']
if 'bufferViews' in gltf_json and buffer_view_index < len(gltf_json['bufferViews']):
buffer_view = gltf_json['bufferViews'][buffer_view_index]
byte_offset = buffer_view.get('byteOffset', 0)
byte_length = buffer_view['byteLength']
if binary_data:
draco_data = binary_data[byte_offset:byte_offset+byte_length]
# 使用DracoPy解码
mesh_data = DracoPy.decode(draco_data)
if hasattr(mesh_data, 'points'):
points = mesh_data.points
for point in points:
vertices.append([float(point[0]), float(point[1]), float(point[2])])
except Exception as e:
print(f"解码Draco数据失败: {e}")
return vertices
def save_ply_file(self, output_path):
"""保存PLY文件"""
try:
with open(output_path, 'w') as f:
# 写入PLY头部
f.write("ply\n")
f.write("format ascii 1.0\n")
f.write(f"element vertex {len(self.all_vertices)}\n")
f.write("property float x\n")
f.write("property float y\n")
f.write("property float z\n")
f.write("end_header\n")
# 写入顶点数据
for vertex in self.all_vertices:
f.write(f"{vertex[0]} {vertex[1]} {vertex[2]}\n")
print(f"PLY文件已保存: {output_path}")
print(f"总顶点数: {len(self.all_vertices)}")
except Exception as e:
print(f"保存PLY文件失败: {e}")
def convert_tileset_to_ply(self, tileset_path, output_path):
"""将整个tileset转换为PLY文件"""
print(f"开始处理tileset: {tileset_path}")
# 解析主tileset.json
tileset_data = self.parse_tileset_json(tileset_path)
if not tileset_data:
print("无法解析tileset.json文件")
return False
# 获取基础路径
base_path = os.path.dirname(tileset_path)
# 提取所有b3dm文件和变换矩阵
b3dm_files = self.parse_tileset_json(tileset_path)
print(f"找到 {len(b3dm_files)} 个B3DM文件")
if not b3dm_files:
print("未找到任何B3DM文件")
return False
# 处理每个b3dm文件
processed_count = 0
for i, (b3dm_file, transform_matrix) in enumerate(b3dm_files):
print(f"处理文件 {i+1}/{len(b3dm_files)}: {os.path.basename(b3dm_file)}")
vertices = self.parse_b3dm_file(b3dm_file)
if vertices:
# 应用变换矩阵
if transform_matrix:
vertices = self.apply_transform_to_vertices(vertices, transform_matrix)
print(f" 应用了变换矩阵")
self.all_vertices.extend(vertices)
processed_count += 1
print(f" 提取到 {len(vertices)} 个顶点")
print(f"\n成功处理 {processed_count}/{len(b3dm_files)} 个文件")
print(f"总计提取 {len(self.all_vertices)} 个顶点")
if self.all_vertices:
self.save_ply_file(output_path)
return True
else:
print("未提取到任何顶点数据")
return False
def main():
if len(sys.argv) < 2:
print("用法: python tileset_to_ply.py <tileset.json路径> [输出PLY文件路径]")
print("示例: python tileset_to_ply.py tileset.json output.ply")
return
tileset_path = sys.argv[1]
output_path = sys.argv[2] if len(sys.argv) > 2 else "merged_tileset.ply"
if not os.path.exists(tileset_path):
print(f"错误: 文件不存在 {tileset_path}")
return
converter = TilesetToPLYConverter()
success = converter.convert_tileset_to_ply(tileset_path, output_path)
if success:
print("\n转换完成!")
else:
print("\n转换失败!")
if __name__ == "__main__":
main()

684
b3dm/volume_calculator.py Normal file
View File

@ -0,0 +1,684 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
三维模型体积计算器
基于指定地理范围内的三维模型数据使用三角构网方法计算体积
"""
import json
import struct
import os
import sys
from pathlib import Path
import numpy as np
from scipy.spatial import Delaunay
import math
from tqdm import tqdm
try:
import DracoPy
except ImportError:
print("警告: DracoPy库未安装无法处理Draco压缩的数据")
print("请运行: pip install DracoPy")
DracoPy = None
class VolumeCalculator:
def __init__(self, location_file):
self.location_bounds = self.load_location_bounds(location_file)
self.all_vertices = []
self.filtered_vertices = []
def load_location_bounds(self, location_file):
"""加载地理范围边界"""
try:
with open(location_file, 'r', encoding='utf-8') as f:
coords = json.load(f)
# 提取经纬度范围
lons = [coord[0] for coord in coords]
lats = [coord[1] for coord in coords]
elevs = [coord[2] for coord in coords]
bounds = {
'min_lon': min(lons),
'max_lon': max(lons),
'min_lat': min(lats),
'max_lat': max(lats),
'min_elev': min(elevs),
'max_elev': max(elevs),
'coords': coords
}
print(f"地理范围边界:")
print(f" 经度: {bounds['min_lon']:.8f} ~ {bounds['max_lon']:.8f}")
print(f" 纬度: {bounds['min_lat']:.8f} ~ {bounds['max_lat']:.8f}")
print(f" 高程: {bounds['min_elev']:.2f} ~ {bounds['max_elev']:.2f}")
return bounds
except Exception as e:
print(f"加载地理范围文件失败: {e}")
return None
def wgs84_to_cartesian(self, lon, lat, elev):
"""WGS84坐标转换为笛卡尔坐标高精度算法"""
# WGS84椭球参数EPSG:4326
a = 6378137.0 # 长半轴 (米)
f = 1/298.257223563 # 扁率
e2 = 2*f - f*f # 第一偏心率平方
# 角度转弧度
lon_rad = math.radians(lon)
lat_rad = math.radians(lat)
# 计算卯酉圈曲率半径
sin_lat = math.sin(lat_rad)
cos_lat = math.cos(lat_rad)
sin_lon = math.sin(lon_rad)
cos_lon = math.cos(lon_rad)
N = a / math.sqrt(1 - e2 * sin_lat * sin_lat)
# 高精度笛卡尔坐标计算
x = (N + elev) * cos_lat * cos_lon
y = (N + elev) * cos_lat * sin_lon
z = (N * (1 - e2) + elev) * sin_lat
return [x, y, z]
def cartesian_to_wgs84(self, x, y, z):
"""笛卡尔坐标转换为WGS84坐标高精度迭代法"""
# WGS84椭球参数
a = 6378137.0 # 长半轴
f = 1/298.257223563 # 扁率
e2 = 2*f - f*f # 第一偏心率平方
ep2 = e2 / (1 - e2) # 第二偏心率平方
# 计算经度(精确值)
lon = math.atan2(y, x)
# 计算纬度和高程使用Bowring迭代法
p = math.sqrt(x*x + y*y)
if p == 0:
# 极点情况
lat = math.pi/2 if z > 0 else -math.pi/2
elev = abs(z) - a * math.sqrt(1 - e2)
return [math.degrees(lon), math.degrees(lat), elev]
# 初始估计
theta = math.atan2(z, p * (1 - f))
lat_prev = math.atan2(z + ep2 * a * (1 - f) * math.sin(theta)**3,
p - e2 * a * math.cos(theta)**3)
# 迭代求解纬度
max_iterations = 10
tolerance = 1e-12
for i in range(max_iterations):
N = a / math.sqrt(1 - e2 * math.sin(lat_prev)**2)
elev = p / math.cos(lat_prev) - N
# 更新纬度估计
lat_new = math.atan2(z + e2 * N * math.sin(lat_prev), p)
# 检查收敛性
if abs(lat_new - lat_prev) < tolerance:
break
lat_prev = lat_new
# 最终计算高程
N = a / math.sqrt(1 - e2 * math.sin(lat_prev)**2)
elev = p / math.cos(lat_prev) - N
return [math.degrees(lon), math.degrees(lat_prev), elev]
def is_point_in_bounds(self, vertex):
"""检查点是否在指定的地理范围内"""
if not self.location_bounds:
return True
# 将笛卡尔坐标转换为WGS84
try:
lon, lat, elev = self.cartesian_to_wgs84(vertex[0], vertex[1], vertex[2])
# 检查是否在边界范围内
return (self.location_bounds['min_lon'] <= lon <= self.location_bounds['max_lon'] and
self.location_bounds['min_lat'] <= lat <= self.location_bounds['max_lat'])
except:
return False
def multiply_matrix_vector(self, matrix, vector):
"""4x4矩阵与4D向量相乘"""
m = [
[matrix[0], matrix[4], matrix[8], matrix[12]],
[matrix[1], matrix[5], matrix[9], matrix[13]],
[matrix[2], matrix[6], matrix[10], matrix[14]],
[matrix[3], matrix[7], matrix[11], matrix[15]]
]
v = [vector[0], vector[1], vector[2], 1.0]
result = [
m[0][0]*v[0] + m[0][1]*v[1] + m[0][2]*v[2] + m[0][3]*v[3],
m[1][0]*v[0] + m[1][1]*v[1] + m[1][2]*v[2] + m[1][3]*v[3],
m[2][0]*v[0] + m[2][1]*v[1] + m[2][2]*v[2] + m[2][3]*v[3]
]
return result
def multiply_matrices(self, m1, m2):
"""两个4x4矩阵相乘"""
def list_to_matrix(lst):
return [
[lst[0], lst[4], lst[8], lst[12]],
[lst[1], lst[5], lst[9], lst[13]],
[lst[2], lst[6], lst[10], lst[14]],
[lst[3], lst[7], lst[11], lst[15]]
]
def matrix_to_list(mat):
return [
mat[0][0], mat[1][0], mat[2][0], mat[3][0],
mat[0][1], mat[1][1], mat[2][1], mat[3][1],
mat[0][2], mat[1][2], mat[2][2], mat[3][2],
mat[0][3], mat[1][3], mat[2][3], mat[3][3]
]
a = list_to_matrix(m1)
b = list_to_matrix(m2)
result = [[0 for _ in range(4)] for _ in range(4)]
for i in range(4):
for j in range(4):
for k in range(4):
result[i][j] += a[i][k] * b[k][j]
return matrix_to_list(result)
def apply_transform_to_vertices(self, vertices, transform_matrix):
"""对顶点应用变换矩阵"""
if not transform_matrix:
return vertices
transformed_vertices = []
for vertex in vertices:
transformed = self.multiply_matrix_vector(transform_matrix, vertex)
transformed_vertices.append(transformed)
return transformed_vertices
def parse_tileset_json(self, tileset_path, parent_transform=None):
"""解析tileset.json文件收集B3DM文件和变换矩阵"""
try:
with open(tileset_path, 'r', encoding='utf-8') as f:
tileset_data = json.load(f)
b3dm_files = []
def process_node(node, base_path, accumulated_transform):
current_transform = node.get('transform')
if current_transform and accumulated_transform:
final_transform = self.multiply_matrices(accumulated_transform, current_transform)
elif current_transform:
final_transform = current_transform
else:
final_transform = accumulated_transform
if 'content' in node and 'uri' in node['content']:
uri = node['content']['uri']
if uri.endswith('.b3dm'):
full_path = os.path.join(base_path, uri)
if os.path.exists(full_path):
b3dm_files.append((full_path, final_transform))
elif uri.endswith('.json'):
sub_tileset_path = os.path.join(base_path, uri)
if os.path.exists(sub_tileset_path):
sub_files = self.parse_tileset_json(sub_tileset_path, final_transform)
b3dm_files.extend(sub_files)
if 'children' in node:
for child in node['children']:
process_node(child, base_path, final_transform)
base_path = os.path.dirname(tileset_path)
if 'root' in tileset_data:
process_node(tileset_data['root'], base_path, parent_transform)
return b3dm_files
except Exception as e:
print(f"解析tileset.json时出错: {e}")
return []
def parse_b3dm_file(self, file_path):
"""解析B3DM文件"""
try:
with open(file_path, 'rb') as f:
magic = f.read(4)
if magic != b'b3dm':
return None
version = struct.unpack('<I', f.read(4))[0]
byte_length = struct.unpack('<I', f.read(4))[0]
feature_table_json_byte_length = struct.unpack('<I', f.read(4))[0]
feature_table_binary_byte_length = struct.unpack('<I', f.read(4))[0]
batch_table_json_byte_length = struct.unpack('<I', f.read(4))[0]
batch_table_binary_byte_length = struct.unpack('<I', f.read(4))[0]
f.seek(28 + feature_table_json_byte_length + feature_table_binary_byte_length +
batch_table_json_byte_length + batch_table_binary_byte_length)
gltf_data = f.read()
return self.parse_gltf_data(gltf_data)
except Exception as e:
print(f"解析B3DM文件 {file_path} 失败: {e}")
return None
def parse_gltf_data(self, gltf_data):
"""解析glTF数据"""
try:
if gltf_data[:4] == b'glTF':
return self.parse_glb_data(gltf_data)
else:
gltf_json = json.loads(gltf_data.decode('utf-8'))
return self.extract_vertices_from_gltf(gltf_json, None)
except Exception as e:
print(f"解析glTF数据失败: {e}")
return None
def parse_glb_data(self, glb_data):
"""解析GLB格式的glTF数据"""
try:
magic = glb_data[:4]
if magic != b'glTF':
return None
version = struct.unpack('<I', glb_data[4:8])[0]
total_length = struct.unpack('<I', glb_data[8:12])[0]
offset = 12
json_data = None
binary_data = None
while offset < len(glb_data):
if offset + 8 > len(glb_data):
break
chunk_length = struct.unpack('<I', glb_data[offset:offset+4])[0]
chunk_type = glb_data[offset+4:offset+8]
chunk_data = glb_data[offset+8:offset+8+chunk_length]
if chunk_type == b'JSON':
json_data = json.loads(chunk_data.decode('utf-8'))
elif chunk_type == b'BIN\x00':
binary_data = chunk_data
offset += 8 + chunk_length
offset = (offset + 3) & ~3
if json_data:
return self.extract_vertices_from_gltf(json_data, binary_data)
except Exception as e:
print(f"解析GLB数据失败: {e}")
return None
def extract_vertices_from_gltf(self, gltf_json, binary_data):
"""从glTF JSON中提取顶点数据"""
vertices = []
try:
if 'extensionsUsed' in gltf_json and 'KHR_draco_mesh_compression' in gltf_json['extensionsUsed']:
if DracoPy is None:
print("警告: 检测到Draco压缩但DracoPy未安装")
return vertices
return self.extract_draco_vertices(gltf_json, binary_data)
if 'meshes' not in gltf_json:
return vertices
for mesh in gltf_json['meshes']:
for primitive in mesh['primitives']:
if 'attributes' in primitive and 'POSITION' in primitive['attributes']:
position_accessor_index = primitive['attributes']['POSITION']
if 'accessors' in gltf_json and position_accessor_index < len(gltf_json['accessors']):
accessor = gltf_json['accessors'][position_accessor_index]
buffer_view_index = accessor['bufferView']
if 'bufferViews' in gltf_json and buffer_view_index < len(gltf_json['bufferViews']):
buffer_view = gltf_json['bufferViews'][buffer_view_index]
buffer_index = buffer_view['buffer']
byte_offset = buffer_view.get('byteOffset', 0) + accessor.get('byteOffset', 0)
if binary_data and buffer_index == 0:
component_type = accessor['componentType']
count = accessor['count']
if component_type == 5126: # FLOAT
vertex_data = struct.unpack(f'<{count*3}f',
binary_data[byte_offset:byte_offset+count*12])
for i in range(0, len(vertex_data), 3):
vertices.append([vertex_data[i], vertex_data[i+1], vertex_data[i+2]])
except Exception as e:
print(f"提取顶点数据失败: {e}")
return vertices
def extract_draco_vertices(self, gltf_json, binary_data):
"""提取Draco压缩的顶点数据"""
vertices = []
try:
if 'meshes' not in gltf_json:
return vertices
for mesh in gltf_json['meshes']:
for primitive in mesh['primitives']:
if 'extensions' in primitive and 'KHR_draco_mesh_compression' in primitive['extensions']:
draco_ext = primitive['extensions']['KHR_draco_mesh_compression']
buffer_view_index = draco_ext['bufferView']
if 'bufferViews' in gltf_json and buffer_view_index < len(gltf_json['bufferViews']):
buffer_view = gltf_json['bufferViews'][buffer_view_index]
byte_offset = buffer_view.get('byteOffset', 0)
byte_length = buffer_view['byteLength']
if binary_data:
draco_data = binary_data[byte_offset:byte_offset+byte_length]
mesh_data = DracoPy.decode(draco_data)
if hasattr(mesh_data, 'points'):
points = mesh_data.points
for point in points:
vertices.append([float(point[0]), float(point[1]), float(point[2])])
except Exception as e:
print(f"解码Draco数据失败: {e}")
return vertices
def calculate_triangle_angles(self, p1, p2, p3):
"""计算三角形的三个内角(度)"""
# 计算三边长度
a = np.linalg.norm(p2 - p3) # 边a对应角A(p1)
b = np.linalg.norm(p1 - p3) # 边b对应角B(p2)
c = np.linalg.norm(p1 - p2) # 边c对应角C(p3)
# 避免除零错误
if a == 0 or b == 0 or c == 0:
return [0, 0, 0]
# 使用余弦定理计算角度
try:
# 角A = arccos((b²+c²-a²)/(2bc))
cos_A = (b*b + c*c - a*a) / (2*b*c)
cos_B = (a*a + c*c - b*b) / (2*a*c)
cos_C = (a*a + b*b - c*c) / (2*a*b)
# 限制余弦值范围,避免数值误差
cos_A = np.clip(cos_A, -1.0, 1.0)
cos_B = np.clip(cos_B, -1.0, 1.0)
cos_C = np.clip(cos_C, -1.0, 1.0)
angle_A = np.arccos(cos_A) * 180 / np.pi
angle_B = np.arccos(cos_B) * 180 / np.pi
angle_C = np.arccos(cos_C) * 180 / np.pi
return [angle_A, angle_B, angle_C]
except:
return [0, 0, 0]
def is_valid_triangle(self, p1, p2, p3, min_angle=10.0, max_aspect_ratio=10.0):
"""验证三角形质量,基于角度约束和长宽比"""
angles = self.calculate_triangle_angles(p1, p2, p3)
# 检查最小角度约束参考C#代码中的10度限制
if min(angles) < min_angle:
return False
# 计算三边长度
a = np.linalg.norm(p2 - p3)
b = np.linalg.norm(p1 - p3)
c = np.linalg.norm(p1 - p2)
# 检查长宽比约束
max_side = max(a, b, c)
min_side = min(a, b, c)
if min_side > 0 and max_side / min_side > max_aspect_ratio:
return False
return True
def calculate_circumcenter_and_radius(self, p1, p2, p3):
"""计算三角形外接圆圆心和半径(高精度算法)"""
try:
x1, y1 = p1[0], p1[1]
x2, y2 = p2[0], p2[1]
x3, y3 = p3[0], p3[1]
# 使用C#代码中的高精度外接圆计算公式
d = 2 * (x1 * (y2 - y3) + x2 * (y3 - y1) + x3 * (y1 - y2))
if abs(d) < 1e-10: # 三点共线
return None, float('inf')
ux = ((x1*x1 + y1*y1) * (y2 - y3) + (x2*x2 + y2*y2) * (y3 - y1) + (x3*x3 + y3*y3) * (y1 - y2)) / d
uy = ((x1*x1 + y1*y1) * (x3 - x2) + (x2*x2 + y2*y2) * (x1 - y3) + (x3*x3 + y3*y3) * (x2 - x1)) / d
# 计算半径
radius = np.sqrt((ux - x1)**2 + (uy - y1)**2)
return np.array([ux, uy]), radius
except:
return None, float('inf')
def calculate_volume_delaunay(self, vertices, base_elevation=None, min_angle=10.0, use_quality_filter=True):
"""使用Delaunay三角剖分计算体积"""
if len(vertices) < 4:
print("顶点数量不足,无法进行三角剖分")
return 0.0
try:
points = np.array(vertices)
if base_elevation is None:
base_elevation = np.min(points[:, 2])
print(f"Delaunay方法使用基准面高程: {base_elevation:.2f}")
print(f"质量控制参数: 最小角度={min_angle}°, 质量过滤={'开启' if use_quality_filter else '关闭'}")
adjusted_points = points.copy()
adjusted_points[:, 2] = np.maximum(points[:, 2] - base_elevation, 0)
print("正在进行Delaunay三角剖分...")
tri = Delaunay(adjusted_points)
valid_simplices = []
total_simplices = len(tri.simplices)
if use_quality_filter:
print("正在进行三角形质量检查...")
for simplex in tqdm(tri.simplices, desc="质量检查", unit=""):
p0 = adjusted_points[simplex[0]]
p1 = adjusted_points[simplex[1]]
p2 = adjusted_points[simplex[2]]
p3 = adjusted_points[simplex[3]]
faces = [(p0, p1, p2), (p0, p1, p3), (p0, p2, p3), (p1, p2, p3)]
valid_faces = 0
for face in faces:
if self.is_valid_triangle(face[0], face[1], face[2], min_angle):
valid_faces += 1
if valid_faces >= 3:
valid_simplices.append(simplex)
print(f"质量过滤: {len(valid_simplices)}/{total_simplices} 个四面体通过质量检查")
else:
valid_simplices = tri.simplices
total_volume = 0.0
valid_volume_count = 0
print(f"计算 {len(valid_simplices)} 个四面体的体积...")
for simplex in tqdm(valid_simplices, desc="计算四面体体积", unit=""):
p0 = adjusted_points[simplex[0]]
p1 = adjusted_points[simplex[1]]
p2 = adjusted_points[simplex[2]]
p3 = adjusted_points[simplex[3]]
v1 = p1 - p0
v2 = p2 - p0
v3 = p3 - p0
det = np.linalg.det(np.array([v1, v2, v3]))
volume = abs(det) / 6.0
if volume > 1e-12:
total_volume += volume
valid_volume_count += 1
print(f"有效体积计算: {valid_volume_count}/{len(valid_simplices)} 个四面体")
return total_volume
except Exception as e:
print(f"Delaunay三角剖分计算体积失败: {e}")
return 0.0
def load_and_filter_vertices(self, tileset_path):
"""加载并过滤指定范围内的顶点"""
print(f"开始处理tileset: {tileset_path}")
# 解析tileset获取所有B3DM文件
b3dm_files = self.parse_tileset_json(tileset_path)
print(f"找到 {len(b3dm_files)} 个B3DM文件")
if not b3dm_files:
print("未找到任何B3DM文件")
return False
# 处理每个B3DM文件
processed_count = 0
total_vertices = 0
filtered_count = 0
for i, (b3dm_file, transform_matrix) in enumerate(tqdm(b3dm_files, desc="处理B3DM文件", unit="文件")):
# print(f"处理文件 {i+1}/{len(b3dm_files)}: {os.path.basename(b3dm_file)}")
vertices = self.parse_b3dm_file(b3dm_file)
if vertices:
# 应用变换矩阵
if transform_matrix:
vertices = self.apply_transform_to_vertices(vertices, transform_matrix)
# 过滤范围内的顶点
for vertex in vertices:
total_vertices += 1
if self.is_point_in_bounds(vertex):
self.filtered_vertices.append(vertex)
filtered_count += 1
self.all_vertices.extend(vertices)
processed_count += 1
# tqdm.write(f" {os.path.basename(b3dm_file)}: 提取到 {len(vertices)} 个顶点")
print(f"\n成功处理 {processed_count}/{len(b3dm_files)} 个文件")
print(f"总计提取 {total_vertices} 个顶点")
print(f"范围内顶点 {filtered_count}")
return len(self.filtered_vertices) > 0
def calculate_volume(self, tileset_path, base_elevation=None, min_angle=10.0, use_quality_filter=True):
"""计算指定范围内三维模型的体积
Args:
tileset_path: tileset.json文件路径
base_elevation: 基准面高程
min_angle: 最小角度约束
use_quality_filter: 是否启用质量过滤
"""
if not self.load_and_filter_vertices(tileset_path):
print("未找到范围内的顶点数据")
return 0.0
print(f"\n开始计算体积使用Delaunay三角剖分方法")
print(f"参与计算的顶点数: {len(self.filtered_vertices)}")
points = np.array(self.filtered_vertices)
if base_elevation is None:
base_elevation = np.min(points[:, 2])
print(f"统一基准面高程: {base_elevation:.2f}")
volume = self.calculate_volume_delaunay(self.filtered_vertices, base_elevation, min_angle, use_quality_filter)
print(f"\n计算结果:")
print(f"体积: {volume:.6f} 立方米")
print(f"体积: {volume/1000000:.6f} 立方千米")
return volume
def main():
if len(sys.argv) < 3:
print("用法: python volume_calculator.py <lboundary.json路径> <tileset.json路径> [基准面高程] [最小角度] [质量过滤]")
print("基准面高程: 基准面高程(米),默认使用最低点")
print("最小角度: 三角形最小角度约束(度)默认10.0")
print("质量过滤: 是否启用质量过滤(true/false)默认true")
print("示例: python volume_calculator.py boundary.json tileset.json 100.0 15.0 true")
return
location_file = sys.argv[1]
tileset_path = sys.argv[2]
# 解析可选参数
base_elevation = None
min_angle = 10.0
use_quality_filter = True
if len(sys.argv) > 3:
try:
base_elevation = float(sys.argv[3])
except ValueError:
print("错误:基准面高程必须是数字")
return
if len(sys.argv) > 4:
try:
min_angle = float(sys.argv[4])
except ValueError:
print("错误:最小角度必须是数字")
return
if len(sys.argv) > 5:
use_quality_filter = sys.argv[5].lower() in ['true', '1', 'yes', 'on']
if not os.path.exists(location_file):
print(f"错误: 位置文件不存在 {location_file}")
return
if not os.path.exists(tileset_path):
print(f"错误: tileset文件不存在 {tileset_path}")
return
calculator = VolumeCalculator(location_file)
volume = calculator.calculate_volume(tileset_path, base_elevation, min_angle, use_quality_filter)
if volume > 0:
print("\n体积计算完成!")
else:
print("\n体积计算失败!")
if __name__ == "__main__":
main()

47
check_grpc_server.py Normal file
View File

@ -0,0 +1,47 @@
from concurrent import futures
import grpc
import time
from grpc_proto.check_grpc import check_grpc_pb2_grpc, check_grpc_pb2
class TaskServiceServicer(check_grpc_pb2_grpc.TaskServiceServicer):
def ProcessTask(self, request, context):
print(f"Received task_id: {request.task_id}")
return check_grpc_pb2.TaskResponse(
task_id=request.task_id,
success=True,
message="Task processed successfully"
)
class HealthCheckServicer(check_grpc_pb2_grpc.HealthCheckServicer):
def Check(self, request, context):
# 简单实现总是返回SERVING状态
# 实际应用中可以根据服务状态返回不同值
return check_grpc_pb2.HealthCheckResponse(
status=check_grpc_pb2.HealthCheckResponse.ServingStatus.SERVING
)
def serve():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
# 添加服务实现
check_grpc_pb2_grpc.add_TaskServiceServicer_to_server(TaskServiceServicer(), server)
check_grpc_pb2_grpc.add_HealthCheckServicer_to_server(HealthCheckServicer(), server)
server.add_insecure_port('[::]:50051')
server.start()
print("Server started, listening on port 50051...")
try:
while True:
time.sleep(86400) # 保持运行
except KeyboardInterrupt:
server.stop(0)
if __name__ == '__main__':
serve()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.2 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 5.3 MiB

View File

@ -1,8 +0,0 @@
import os.path
from middleware.minio_util import downBigFile
miniourl=r"media/22d45cc5-0ba7-4bc3-a302-ca1a28c40fd2/DJI_202509121519_001_22d45cc5-0ba7-4bc3-a302-ca1a28c40fd2/DJI_20250912152112_0001_V.mp4"
file_path=downBigFile(miniourl)
print(f"os.path.abspath(file_path) {os.path.abspath(file_path)}")

View File

@ -94,7 +94,7 @@ def func_100000(results, cls_id_list, type_name_list, func_id_10001, list_track_
trickier_detail = { trickier_detail = {
# "track_id": results.track_ids[i], # "track_id": results.track_ids[i],
"confidence": results.confs[i], "confidence": results.confs[i],
"cls_id": i, "cls_id": ind,
"type_name": type_name_list[ind], "type_name": type_name_list[ind],
"box": boxes[i] "box": boxes[i]
} }

View File

@ -0,0 +1,58 @@
syntax = "proto3";
package task;
service TaskService {
rpc ProcessTask (TaskRequest) returns (TaskResponse);
}
//
service HealthCheck {
rpc Check (HealthCheckRequest) returns (HealthCheckResponse);
}
message HealthCheckRequest {
string service = 1;
}
message HealthCheckResponse {
enum ServingStatus {
UNKNOWN = 0;
SERVING = 1;
NOT_SERVING = 2;
SERVICE_UNKNOWN = 3;
}
ServingStatus status = 1;
}
message TaskRequest {
string task_id = 1;
string sn = 2;
ContentBody content_body = 3;
}
message ContentBody {
string org_code = 1;
repeated int32 func_id = 2;
string source_url = 3;
string push_url = 4;
float confidence = 5;
repeated ParaList para_list = 6;
Invade invade = 7;
}
message ParaList {
int32 func_id = 1;
bool para_invade_enable = 2;
}
message Invade {
string invade_file = 1;
string camera_para_url = 2;
}
message TaskResponse {
string task_id = 1;
bool success = 2;
string message = 3;
}

View File

@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: grpc_proto/check_grpc/check_grpc.proto
# Protobuf Python Version: 6.31.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
6,
31,
1,
'',
'grpc_proto/check_grpc/check_grpc.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n&grpc_proto/check_grpc/check_grpc.proto\x12\x04task\"%\n\x12HealthCheckRequest\x12\x0f\n\x07service\x18\x01 \x01(\t\"\x9f\x01\n\x13HealthCheckResponse\x12\x37\n\x06status\x18\x01 \x01(\x0e\x32\'.task.HealthCheckResponse.ServingStatus\"O\n\rServingStatus\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x0b\n\x07SERVING\x10\x01\x12\x0f\n\x0bNOT_SERVING\x10\x02\x12\x13\n\x0fSERVICE_UNKNOWN\x10\x03\"S\n\x0bTaskRequest\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\n\n\x02sn\x18\x02 \x01(\t\x12\'\n\x0c\x63ontent_body\x18\x03 \x01(\x0b\x32\x11.task.ContentBody\"\xab\x01\n\x0b\x43ontentBody\x12\x10\n\x08org_code\x18\x01 \x01(\t\x12\x0f\n\x07\x66unc_id\x18\x02 \x03(\x05\x12\x12\n\nsource_url\x18\x03 \x01(\t\x12\x10\n\x08push_url\x18\x04 \x01(\t\x12\x12\n\nconfidence\x18\x05 \x01(\x02\x12!\n\tpara_list\x18\x06 \x03(\x0b\x32\x0e.task.ParaList\x12\x1c\n\x06invade\x18\x07 \x01(\x0b\x32\x0c.task.Invade\"7\n\x08ParaList\x12\x0f\n\x07\x66unc_id\x18\x01 \x01(\x05\x12\x1a\n\x12para_invade_enable\x18\x02 \x01(\x08\"6\n\x06Invade\x12\x13\n\x0binvade_file\x18\x01 \x01(\t\x12\x17\n\x0f\x63\x61mera_para_url\x18\x02 \x01(\t\"A\n\x0cTaskResponse\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x0f\n\x07message\x18\x03 \x01(\t2C\n\x0bTaskService\x12\x34\n\x0bProcessTask\x12\x11.task.TaskRequest\x1a\x12.task.TaskResponse2K\n\x0bHealthCheck\x12<\n\x05\x43heck\x12\x18.task.HealthCheckRequest\x1a\x19.task.HealthCheckResponseb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'grpc_proto.check_grpc.check_grpc_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_HEALTHCHECKREQUEST']._serialized_start=48
_globals['_HEALTHCHECKREQUEST']._serialized_end=85
_globals['_HEALTHCHECKRESPONSE']._serialized_start=88
_globals['_HEALTHCHECKRESPONSE']._serialized_end=247
_globals['_HEALTHCHECKRESPONSE_SERVINGSTATUS']._serialized_start=168
_globals['_HEALTHCHECKRESPONSE_SERVINGSTATUS']._serialized_end=247
_globals['_TASKREQUEST']._serialized_start=249
_globals['_TASKREQUEST']._serialized_end=332
_globals['_CONTENTBODY']._serialized_start=335
_globals['_CONTENTBODY']._serialized_end=506
_globals['_PARALIST']._serialized_start=508
_globals['_PARALIST']._serialized_end=563
_globals['_INVADE']._serialized_start=565
_globals['_INVADE']._serialized_end=619
_globals['_TASKRESPONSE']._serialized_start=621
_globals['_TASKRESPONSE']._serialized_end=686
_globals['_TASKSERVICE']._serialized_start=688
_globals['_TASKSERVICE']._serialized_end=755
_globals['_HEALTHCHECK']._serialized_start=757
_globals['_HEALTHCHECK']._serialized_end=832
# @@protoc_insertion_point(module_scope)

View File

@ -0,0 +1,172 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings
from grpc_proto.check_grpc import check_grpc_pb2 as grpc__proto_dot_check__grpc_dot_check__grpc__pb2
GRPC_GENERATED_VERSION = '1.76.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False
try:
from grpc._utilities import first_version_is_lower
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
_version_not_supported = True
if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ ' but the generated code in grpc_proto/check_grpc/check_grpc_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
)
class TaskServiceStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.ProcessTask = channel.unary_unary(
'/task.TaskService/ProcessTask',
request_serializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.TaskRequest.SerializeToString,
response_deserializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.TaskResponse.FromString,
_registered_method=True)
class TaskServiceServicer(object):
"""Missing associated documentation comment in .proto file."""
def ProcessTask(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_TaskServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'ProcessTask': grpc.unary_unary_rpc_method_handler(
servicer.ProcessTask,
request_deserializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.TaskRequest.FromString,
response_serializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.TaskResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'task.TaskService', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('task.TaskService', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class TaskService(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def ProcessTask(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/task.TaskService/ProcessTask',
grpc__proto_dot_check__grpc_dot_check__grpc__pb2.TaskRequest.SerializeToString,
grpc__proto_dot_check__grpc_dot_check__grpc__pb2.TaskResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
class HealthCheckStub(object):
"""添加健康检查服务
"""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.Check = channel.unary_unary(
'/task.HealthCheck/Check',
request_serializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.HealthCheckRequest.SerializeToString,
response_deserializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.HealthCheckResponse.FromString,
_registered_method=True)
class HealthCheckServicer(object):
"""添加健康检查服务
"""
def Check(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_HealthCheckServicer_to_server(servicer, server):
rpc_method_handlers = {
'Check': grpc.unary_unary_rpc_method_handler(
servicer.Check,
request_deserializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.HealthCheckRequest.FromString,
response_serializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.HealthCheckResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'task.HealthCheck', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('task.HealthCheck', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class HealthCheck(object):
"""添加健康检查服务
"""
@staticmethod
def Check(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/task.HealthCheck/Check',
grpc__proto_dot_check__grpc_dot_check__grpc__pb2.HealthCheckRequest.SerializeToString,
grpc__proto_dot_check__grpc_dot_check__grpc__pb2.HealthCheckResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

View File

@ -0,0 +1,81 @@
import grpc
import time
from grpc_proto.check_grpc import check_grpc_pb2_grpc, check_grpc_pb2
def check_server_status(channel):
try:
health_stub = check_grpc_pb2_grpc.HealthCheckStub(channel)
response = health_stub.Check(check_grpc_pb2.HealthCheckRequest(service="TaskService"))
return response.status == check_grpc_pb2.HealthCheckResponse.ServingStatus.SERVING
except grpc.RpcError as e:
print(f"Health check failed: {e}")
return False
def check_grpc_request(max_retries=3, delay=5):
channel = None
retries = 0
while retries < max_retries:
try:
# 创建通道
channel = grpc.insecure_channel('localhost:50051')
# 检查服务器状态
if not check_server_status(channel):
raise Exception("Server is not healthy")
stub = check_grpc_pb2_grpc.TaskServiceStub(channel)
# 创建请求消息
request = check_grpc_pb2.TaskRequest(
task_id="d6118954-a170-4e1c-84bd-ddbd3114b354",
sn="8UUXN6S00A0CK7",
content_body=check_grpc_pb2.ContentBody(
org_code="HMZHB",
func_id=[101204],
source_url="xxxxxxxxxx",
push_url="",
confidence=0.4,
para_list=[
check_grpc_pb2.ParaList(
func_id=101204,
para_invade_enable=True
)
],
invade=check_grpc_pb2.Invade(
invade_file="meta_data/高压线-0826.geojson",
camera_para_url="meta_data/camera_para/hami_camera_para .txt"
)
)
)
# 调用远程方法
response = stub.ProcessTask(request)
print(f"Response: task_id={response.task_id}, success={response.success}, message={response.message}")
return True
except grpc.RpcError as e:
retries += 1
print(f"RPC error occurred (attempt {retries}/{max_retries}): {e}")
if retries < max_retries:
print(f"Retrying in {delay} seconds...")
time.sleep(delay)
except Exception as e:
print(f"Error occurred: {e}")
retries += 1
if retries < max_retries:
print(f"Retrying in {delay} seconds...")
time.sleep(delay)
finally:
if channel:
channel.close()
print("All retry attempts failed")
return False
if __name__ == '__main__':
check_grpc_request()

284
md/grpc.md Normal file
View File

@ -0,0 +1,284 @@
# 前言
sanic 和 服务之间基于grpc 解绑一个服务一个grpc
可以参考接口grpc 最好留有健康校验
# grpc demo
为了增强 gRPC 通讯的可靠性,我们可以添加以下功能:
1. 服务器健康检查服务
2. 客户端连接前检查服务器状态
3. 通讯过程中的错误处理和重试机制
## 1. 修改 protobuf 定义 (task.proto)
首先添加健康检查服务定义:
```protobuf
syntax = "proto3";
package task;
service TaskService {
rpc ProcessTask (TaskRequest) returns (TaskResponse);
}
// 添加健康检查服务
service HealthCheck {
rpc Check (HealthCheckRequest) returns (HealthCheckResponse);
}
message HealthCheckRequest {
string service = 1;
}
message HealthCheckResponse {
enum ServingStatus {
UNKNOWN = 0;
SERVING = 1;
NOT_SERVING = 2;
SERVICE_UNKNOWN = 3;
}
ServingStatus status = 1;
}
message TaskRequest {
string task_id = 1;
string sn = 2;
ContentBody content_body = 3;
}
message ContentBody {
string org_code = 1;
repeated int32 func_id = 2;
string source_url = 3;
string push_url = 4;
float confidence = 5;
repeated ParaList para_list = 6;
Invade invade = 7;
}
message ParaList {
int32 func_id = 1;
bool para_invade_enable = 2;
}
message Invade {
string invade_file = 1;
string camera_para_url = 2;
}
message TaskResponse {
string task_id = 1;
bool success = 2;
string message = 3;
}
```
## 2. 重新生成 Python 代码
```bash
python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. task.proto
protoc -I=./proto --go_out=./gen proto/subfolder/task.proto
```
## 3. 增强服务端实现 (server.py)
```python
from concurrent import futures
import grpc
import time
import task_pb2
import task_pb2_grpc
class TaskServiceServicer(task_pb2_grpc.TaskServiceServicer):
def ProcessTask(self, request, context):
print(f"Received task_id: {request.task_id}")
# ... 原有处理逻辑
return task_pb2.TaskResponse(
task_id=request.task_id,
success=True,
message="Task processed successfully"
)
class HealthCheckServicer(task_pb2_grpc.HealthCheckServicer):
def Check(self, request, context):
# 简单实现总是返回SERVING状态
# 实际应用中可以根据服务状态返回不同值
return task_pb2.HealthCheckResponse(
status=task_pb2.HealthCheckResponse.ServingStatus.SERVING
)
def serve():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
# 添加服务实现
task_pb2_grpc.add_TaskServiceServicer_to_server(TaskServiceServicer(), server)
task_pb2_grpc.add_HealthCheckServicer_to_server(HealthCheckServicer(), server)
server.add_insecure_port('[::]:50051')
server.start()
print("Server started, listening on port 50051...")
try:
while True:
time.sleep(86400) # 保持运行
except KeyboardInterrupt:
server.stop(0)
if __name__ == '__main__':
serve()
```
## 4. 增强客户端实现 (client.py)
```python
import grpc
import time
import task_pb2
import task_pb2_grpc
def check_server_status(channel):
try:
health_stub = task_pb2_grpc.HealthCheckStub(channel)
response = health_stub.Check(task_pb2.HealthCheckRequest(service="TaskService"))
return response.status == task_pb2.HealthCheckResponse.ServingStatus.SERVING
except grpc.RpcError as e:
print(f"Health check failed: {e}")
return False
def run_with_retry(max_retries=3, delay=5):
channel = None
retries = 0
while retries < max_retries:
try:
# 创建通道
channel = grpc.insecure_channel('localhost:50051')
# 检查服务器状态
if not check_server_status(channel):
raise Exception("Server is not healthy")
stub = task_pb2_grpc.TaskServiceStub(channel)
# 创建请求消息
request = task_pb2.TaskRequest(
task_id="d6118954-a170-4e1c-84bd-ddbd3114b354",
sn="8UUXN6S00A0CK7",
content_body=task_pb2.ContentBody(
org_code="HMZHB",
func_id=[101204],
source_url="rtmp://222.212.85.86:1935/live/1581F8HGX253S00A05L8",
push_url="",
confidence=0.4,
para_list=[
task_pb2.ParaList(
func_id=101204,
para_invade_enable=True
)
],
invade=task_pb2.Invade(
invade_file="meta_data/高压线-0826.geojson",
camera_para_url="meta_data/camera_para/hami_camera_para .txt"
)
)
)
# 调用远程方法
response = stub.ProcessTask(request)
print(f"Response: task_id={response.task_id}, success={response.success}, message={response.message}")
return True
except grpc.RpcError as e:
retries += 1
print(f"RPC error occurred (attempt {retries}/{max_retries}): {e}")
if retries < max_retries:
print(f"Retrying in {delay} seconds...")
time.sleep(delay)
except Exception as e:
print(f"Error occurred: {e}")
retries += 1
if retries < max_retries:
print(f"Retrying in {delay} seconds...")
time.sleep(delay)
finally:
if channel:
channel.close()
print("All retry attempts failed")
return False
if __name__ == '__main__':
run_with_retry()
```
## 5. 增强功能说明
1. 健康检查服务
- 添加了标准的 gRPC 健康检查服务
- 客户端可以在执行主要操作前检查服务状态
2. 错误处理和重试机制
- 客户端现在会捕获 `grpc.RpcError` 和其他异常
- 实现了最大重试次数和重试间隔
- 每次重试前都会检查服务器状态
3. 资源管理
- 确保在所有情况下都正确关闭通道
- 使用上下文管理器处理通道生命周期
4. 状态反馈
- 提供更详细的错误信息
- 记录重试尝试
## 6. 运行说明
1. 安装依赖:
```bash
pip install grpcio grpcio-tools
```
2. 生成 protobuf 代码:
```bash
# 确保在项目根目录下执行
# 编译 proto 文件,输出到 a/b/ 目录
python -m grpc_tools.protoc \
-I. \ # 指定 proto 文件的根目录(当前目录 "."
--python_out=. \ # 生成 _pb2.py 到当前目录(或指定子目录)
--grpc_python_out=. \ # 生成 _pb2_grpc.py 到当前目录(或指定子目录)
a/b/task.proto # proto 文件路径(相对于 -I 指定的根目录)
```
3. 启动服务器:
```bash
python server.py
```
4. 运行客户端:
```bash
python client.py
```
## 7. 测试场景
1. 服务器未运行
- 客户端会检测到连接失败并重试
- 最终显示所有重试失败
2. 服务器运行但健康检查失败
- 可以修改 `HealthCheckServicer` 返回 `NOT_SERVING` 状态进行测试
- 客户端会拒绝执行主要操作
3. 网络中断
- 客户端会捕获异常并尝试重试
这个增强版本提供了更健壮的 gRPC 通讯机制,适合生产环境使用。

350
md/接口.md Normal file
View File

@ -0,0 +1,350 @@
# 算法与后台解耦规则
# 1、方法
postgres 的ai_model_list 表id字段声明为6位长度数字
1、第1位表示算法类别1xxxxx 表明为目标识别、2xxxxx标明为语义分割、3xxxxx表示变化监测
2、最后两位表示二次计算100001 表示做目标识别、100002表示做目标识别且做人员计数
# 接口名:视频流识别
接收前端的视频流、模型、识别类型算法做计算并且将计算结果存储到minio消息通过mqtt发送
## 1、请求
接口 /ai/stream/back_detect
方法 post
headers X-API-Token:5e8899fe-dc74-4280-8169-2f4d185f3afa
body
```
{
"task_id": "1234567890", #任务id
"sn":"", #无人机sn
"content_body": {
"source_url": "rtmp://192.168.0.142:1935/live/123456", #无人机视频流url
"confidence":0.4, #置信度
"model_func_id":[100001,100002] #方法id
}
}
```
## 2、响应
算法的响应分为两个部分
1、rest响应表明收到请求
2、mqtt消息持续输出计算结果
### 1、rest
```
{
"status": "success",
"task_id": "1234567890",
"message": "Detection started successfully"
}
```
### 2、mqtt
ip 112.44.103.230 端口 1883
topic thing/product/ai/events
```
{
"task_id": "1234567890",
"minio": {
"minio_path": "ai_result/20250702/1751425303860-output-1751425303800959985.jpg",
"file_type": "pic"
},
"box_detail": {
"result_100001": {
"func_id_10001": 100001,
"type_name": "行人",
"cls_count": 1,
"box_count": [
[
{
"track_id": 22099,
"confidence": 0.34013107419013977,
"cls_id": 0,
"type_name": "行人",
"box": [
15.935794830322266,
694.75390625,
33.22901916503906,
713.1658935546875
]
}
]
]
}
},
"uav_location": {
"data": {
"attitude_head": 60,
"gimbal_pitch": 60,
"gimbal_roll": 60,
"gimbal_yaw": 60,
"height": 10,
"latitude": 10,
"longitude": 10,
"speed_x": 10,
"speed_y": 10,
"speed_z": 10
},
"timestamp": 1751425301213249700
}
}
```
# 接口名:图片识别
接收前端的图片算法做计算并且将计算结果存储到minio消息通过mqtt发送
## 1、请求
接口 /ai/pic/back_detect_pic
方法 post
headers X-API-Token:5e8899fe-dc74-4280-8169-2f4d185f3afa
body
```
{
"task_id": "0001111",
"content_body": {
"s3_id":1, #根据id适配minio相关存储参数
"s3_url":[
"test/frame_0000.jpg","test/frame_0001.jpg","test/frame_0002.jpg" # minio文件地址
],
"confidence":0.4, #算法置信度
"model_func_id":[10001,10002] #方法id
}
}
```
## 2、响应
算法的响应分为两个部分
1、rest响应表明收到请求
2、mqtt消息持续输出计算结果
### 1、rest
```
{
"status": "success",
"task_id": "0001111",
"message": "Detection started successfully"
}
```
### 2、mqtt
ip 112.44.103.230 端口 1883
topic thing/product/ai/events
```
{
"task_id": "0001111", #任务id
"minio": {
"minio_path": "ai_result/20250627/1751006943659-frame_0001.jpg", # minio 存储路径
"file_type": "pic"
},
"box_detail": {
"model_id": 10001,
"box_count": [
{
"type": 3, # 类型
"type_name": "车辆", #类型名称
"count": 71 #数量
},
{
"type": 0,
"type_name": "车辆",
"count": 7
}
]
}
}
```
# 接口名:地类分割
接收前端的图片算法做计算并且将计算结果存储到minio消息通过mqtt发送
## 1、请求
接口 /ai/pic/back_detect_pic
方法 post
headers X-API-Token:5e8899fe-dc74-4280-8169-2f4d185f3afa
body
```
{
"task_id": "7a5c83e0-fe0d-47bf-a8e1-9bd663508783",
"content_body": {
"s3_id":1,#根据id适配minio相关存储参数
"s3_url":[
"test/patch_0011.png", # minio文件地址
"test/patch_0012.png"
],
"model_func_id":[20000,20001] #方法id
}
}
```
## 2、响应
算法的响应分为两个部分
1、rest响应表明收到请求
2、mqtt消息持续输出计算结果
### 1、rest
```
{
"status": "success",
"task_id": "0001111",
"message": "Detection started successfully"
}
```
### 2、mqtt
ip 112.44.103.230 端口 1883
topic thing/product/ai/events
```
{
"task_id": "7a5c83e0-fe0d-47bf-a8e1-9bd663508783",
"minio": [
{
"minio_path_before": "ai_result/20250710/1752128232469-patch_0011.png", # 需要分割的图片
"minio_path_after": "ai_result/20250710/1752128234222-patch_0011.png", #分割之后的图片
"minio_path_boundary": "ai_result/20250710/1752128234264-patch_0011.pngfinal_vis.png", # 分割的边界图片
"minio_path_json": "ai_result/20250710/1752128234326-patch_0011.pnginstance_results.json", #分割生成的json文件
"file_type": "pic"
},
{
"minio_path_before": "ai_result/20250710/1752128240382-patch_0012.png",
"minio_path_after": "ai_result/20250710/1752128241553-patch_0012.png",
"minio_path_boundary": "ai_result/20250710/1752128241587-patch_0012.pngfinal_vis.png",
"minio_path_json": "ai_result/20250710/1752128241631-patch_0012.pnginstance_results.json",
"file_type": "pic"
}
]
}
```
# 接口名:地类变化监测
接收前端的图片对一期、二期的图像做变化监测并且将计算结果存储到minio消息通过mqtt发送
## 1、请求
接口 /ai/pic/back_detect_pic
方法 post
headers X-API-Token:5e8899fe-dc74-4280-8169-2f4d185f3afa
body
```
{
"task_id": "9fa19ec3-d982-4897-af6c-2c78f786c760",
"content_body": {
"s3_id":1,
"s3_url":{
"early":"/test/1-00205.png", # 一期图像minio文件地址
"later":"/test/2-00205.png" # 二期图像minio文件地址
},
"model_func_id":[30000,30001]
}
}
```
## 2、响应
算法的响应分为两个部分
1、rest响应表明收到请求
2、mqtt消息持续输出计算结果
### 1、rest
```
{
"status": "success",
"task_id": "0001111",
"message": "Detection started successfully"
}
```
### 2、mqtt
ip 112.44.103.230 端口 1883
topic thing/product/ai/events
```
{
"task_id": "9fa19ec3-d982-4897-af6c-2c78f786c760",
"minio": {
"minio_path_1": "ai_result/20250627/1751007686483-1-00205.png", # 一期影像minio地址
"minio_path_2": "ai_result/20250627/1751007686541-2-00205.png", # 二期影像minio地址
"minio_path_result": "ai_result/20250627/1751007686.458642-result-2-00205.png", #识别结果minio地址
"file_type": "pic"
}
}
```

View File

@ -165,13 +165,15 @@ def pic_detect_func(mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic,
try: try:
frame_copy = frame.copy() frame_copy = frame.copy()
results = counter(frame) results = counter(frame)
func_id=model_func_id_list[0]
annotated_frame, box_result = cal_tricker_results(frame_copy, counter, class_names, annotated_frame, box_result = cal_tricker_results(frame_copy, counter, class_names,
model_func_id_list, func_id,
local_func_cache, para, cls, chinese_label, local_func_cache, para, cls, chinese_label,
model_func_id_list[0]) model_func_id_list[0])
except Exception as e: except Exception as e:
print(f"处理帧错误: {e}") print(f"处理帧错误1: {e}")
error_count += 1 error_count += 1
if error_count >= 5: if error_count >= 5:
print(f"连续处理错误达到5次 ,正在停止处理...") print(f"连续处理错误达到5次 ,正在停止处理...")

BIN
pt/build-wall.pt Normal file

Binary file not shown.

File diff suppressed because it is too large Load Diff

View File

@ -18,6 +18,7 @@ from websockets.exceptions import ConnectionClosed
from CropLand_CD_module.detection import corpland_detection_func from CropLand_CD_module.detection import corpland_detection_func
from cropland_module.detection import detection_func from cropland_module.detection import detection_func
from grpc_proto.check_grpc_client import check_grpc_request
from middleware.AsyncioMqttClient import AsyncMQTTClient, ConnectionContext, active_connections from middleware.AsyncioMqttClient import AsyncMQTTClient, ConnectionContext, active_connections
from middleware.TaskManager import TaskManager, task_manager from middleware.TaskManager import TaskManager, task_manager
from middleware.minio_util import downFile from middleware.minio_util import downFile
@ -37,6 +38,9 @@ from cv_back_video import startBackAIVideo
from sanic_cors import CORS from sanic_cors import CORS
from sanic import Sanic, Request from sanic import Sanic, Request
# 引入其他模块
from b3dm.earthwork_api import earthwork_bp
from b3dm.terrain_api import terrain_bp
# 配置日志 # 配置日志
logging.basicConfig( logging.basicConfig(
@ -63,7 +67,14 @@ DB_CONFIG = {
"host": "8.137.54.85", "host": "8.137.54.85",
"port": "5060" "port": "5060"
} }
# 什邡数据库
# DB_CONFIG = {
# "dbname": "postgres",
# "user": "postgres",
# "password": "root",
# "host": "222.213.91.11",
# "port": "5061"
# }
# 配置类 # 配置类
class Config: class Config:
@ -312,6 +323,9 @@ async def heartbeat_monitor(task_manager: TaskManager):
app = Sanic("YoloStreamService1") app = Sanic("YoloStreamService1")
CORS(app) CORS(app)
# 显式注册蓝图
app.blueprint(earthwork_bp)
app.blueprint(terrain_bp)
# 启动心跳监测 # 启动心跳监测
@ -657,11 +671,11 @@ async def run_back_Multi_Detect_async(request, request_json, stop_event: asyncio
{ {
'path': config.model_path, 'path': config.model_path,
'engine_path': config.engine_path, # 'engine_path': config.engine_path,
'so_path': config.so_path, # 'so_path': config.so_path,
# # 测试代码 # # 测试代码
# 'engine_path': r"D:\project\AI-PYTHON\tensorrtx-master\yolo11\build\Release\build.engine", 'engine_path': r"D:\project\AI-PYTHON\Ai_tottle\engine\renche\renche.engine",
# 'so_path': r"D:\project\AI-PYTHON\tensorrtx-master\yolo11\build\Release\myplugins.dll", 'so_path': r"D:\project\AI-PYTHON\Ai_tottle\engine\renche\myplugins.dll",
# 工地安全帽 # 工地安全帽
# 'engine_path': r"D:\project\AI-PYTHON\Ai_tottle\engine\gdaq_hat_0926\gdaq_hat_0926.engine", # 'engine_path': r"D:\project\AI-PYTHON\Ai_tottle\engine\gdaq_hat_0926\gdaq_hat_0926.engine",
# 'so_path': r"D:\project\AI-PYTHON\Ai_tottle\engine\gdaq_hat_0926\myplugins.dll", # 'so_path': r"D:\project\AI-PYTHON\Ai_tottle\engine\gdaq_hat_0926\myplugins.dll",
@ -1588,6 +1602,28 @@ async def stop_task_heart(request):
"message": str(e)}, status=500) "message": str(e)}, status=500)
@app.post("/ai/func/check_grpc")
async def check_grpc(request):
try:
verify_token(request)
check_grpc_request()
return json_response({
"status": "success",
"task_id": "task_id",
"message": "Detection started successfully"
})
except ValueError as e:
logger.error(f"Validation error: {str(e)}")
return json_response({"status": "error", "message": str(e)}, status=400)
except Exception as e:
logger.error(f"Unexpected error: {str(e)}", exc_info=True)
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
@app.websocket("/ws") @app.websocket("/ws")
async def websocket_endpoint(request: Request, ws): async def websocket_endpoint(request: Request, ws):
"""WebSocket端点处理前端连接和消息为每个连接创建独立的MQTT客户端""" """WebSocket端点处理前端连接和消息为每个连接创建独立的MQTT客户端"""
@ -1628,6 +1664,8 @@ async def websocket_endpoint(request: Request, ws):
camera_para_url = "meta_data/camera_para/xyzj_camera_para.txt" camera_para_url = "meta_data/camera_para/xyzj_camera_para.txt"
if model2 == "M4D": if model2 == "M4D":
camera_para_url = "meta_data/camera_para/xyzj_camera_para.txt" camera_para_url = "meta_data/camera_para/xyzj_camera_para.txt"
elif model2 == "M3TD":
camera_para_url = "meta_data/camera_para/hami_camera_para .txt"
elif model2 == "M4TD": elif model2 == "M4TD":
camera_para_url = "meta_data/camera_para/hami_camera_para .txt" camera_para_url = "meta_data/camera_para/hami_camera_para .txt"
camera_file_path = downFile(camera_para_url) camera_file_path = downFile(camera_para_url)