Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 63f240ac3a | |||
|
|
dd931f6231 | ||
| 146872a4dd | |||
| 89181007c2 | |||
|
|
9a09c1e1cf | ||
| c5eeb87488 | |||
| ee8733a0ce | |||
| 0ce543572b | |||
| 929c670add | |||
| 1656f81fe3 | |||
| dfb89c70a3 | |||
| a2d3e2e24b | |||
|
|
0f44df8cec | ||
|
|
eb6ce0de46 | ||
|
|
8d4db9b6df | ||
| eedca6cd50 | |||
|
|
5c865a4418 | ||
| 0e952115c8 | |||
| fbcc505a88 | |||
| b899c4e9de |
47
.gitignore
vendored
Normal file
47
.gitignore
vendored
Normal file
@ -0,0 +1,47 @@
|
||||
# 忽略所有 .log 文件
|
||||
*.log
|
||||
|
||||
# 忽略特定目录(如 node_modules/)
|
||||
node_modules/
|
||||
|
||||
# 忽略本地配置文件(但保留示例文件)
|
||||
config.local.json
|
||||
!config.example.json
|
||||
|
||||
|
||||
# 忽略编译输出目录
|
||||
dist/
|
||||
build/
|
||||
|
||||
test
|
||||
*test*
|
||||
|
||||
*.pyc
|
||||
__pycache__/
|
||||
|
||||
|
||||
misc.xml
|
||||
profiles_settings.xml
|
||||
Project_Default.xml
|
||||
*test*.py
|
||||
|
||||
|
||||
# 忽略 Python 缓存文件
|
||||
*.pyo
|
||||
*.pyd
|
||||
*.mo
|
||||
*.so
|
||||
|
||||
|
||||
uvmodule.log
|
||||
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
*.iml
|
||||
.vscode/
|
||||
|
||||
|
||||
# Temp files
|
||||
*.tmp
|
||||
*.swp
|
||||
3
.idea/.gitignore
generated
vendored
3
.idea/.gitignore
generated
vendored
@ -1,3 +0,0 @@
|
||||
# 默认忽略的文件
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
3
.idea/misc.xml
generated
3
.idea/misc.xml
generated
@ -1,4 +1,7 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="yolo_tensorrt" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="yolo_tensorrt" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
100
ai_image.py
100
ai_image.py
@ -1,100 +0,0 @@
|
||||
import os
|
||||
import shutil
|
||||
import cv2
|
||||
import collections
|
||||
from ultralytics import YOLO
|
||||
from miniohelp import downFile, upload_file, parse_minio_url # 确保你有这些工具函数
|
||||
from minio import Minio
|
||||
|
||||
def process_images(yolo_model, image_list, class_filter, input_folder, output_folder, minio_info):
    """Run YOLO detection on a batch of MinIO-hosted images and upload annotated results.

    Args:
        yolo_model: Path/name of the YOLO weights file to load.
        image_list: List of dicts with keys "id" and "path" (MinIO URL or object key).
        class_filter: Comma-separated class-id string (e.g. "0,2"), or falsy to keep all classes.
        input_folder: Local scratch directory for downloaded originals (removed on exit).
        output_folder: Local scratch directory for annotated outputs (removed on exit).
        minio_info: Dict with "MinIOEndpoint", "MinIOAccessKey", "MinIOSecretKey".

    Returns:
        Dict with "status", "message" and "data" — one entry per input image holding
        the uploaded path, whether any target class was found, and per-class counts.
    """
    # Strip the scheme: the Minio client expects a bare host[:port] endpoint.
    endpoint = minio_info["MinIOEndpoint"].replace("http://", "").replace("https://", "")

    minio = Minio(
        endpoint=endpoint,
        access_key=minio_info["MinIOAccessKey"],
        secret_key=minio_info["MinIOSecretKey"],
        secure=False
    )
    os.makedirs(input_folder, exist_ok=True)
    os.makedirs(output_folder, exist_ok=True)

    model = YOLO(yolo_model)
    class_ids_filter = [int(cls) for cls in class_filter.split(",")] if class_filter else None
    output_image_list = []

    for item in image_list:
        img_id = item["id"]
        img_url = item["path"]

        # Resolve bucket/object: full URLs are parsed, bare keys go to the default bucket.
        if img_url.startswith("http"):
            bucket_name, img_path = parse_minio_url(img_url)
        else:
            bucket_name, img_path = "default-bucket", img_url

        try:
            # Download the original image to local scratch space.
            local_input_path = os.path.join(input_folder, os.path.basename(img_path))
            downFile(minio, img_path, bucket_name, local_input_path)

            image = cv2.imread(local_input_path)
            if image is None:
                raise ValueError(f"无法读取图像: {local_input_path}")

            results = model.predict(image,
                                    classes=class_ids_filter,
                                    conf=0.5,
                                    iou=0.111,
                                    show_labels=False)
            result = results[0]

            # Count detections per class id.
            class_counts = collections.Counter(result.boxes.cls.cpu().numpy().astype(int)) if result.boxes is not None else {}
            # BUG FIX: when no class filter is given, class_ids_filter is None and the
            # original membership test (`k in None`) raised TypeError; keep all classes
            # in that case.
            if class_ids_filter is None:
                filtered_class_counts = dict(class_counts)
            else:
                filtered_class_counts = {k: v for k, v in class_counts.items() if k in class_ids_filter}

            # Convert numpy integers to plain Python ints for JSON serialization.
            detected_classes = [int(cls) for cls in filtered_class_counts.keys()]
            detected_numbers = [int(num) for num in filtered_class_counts.values()]
            aim = bool(detected_classes)

            # Save the annotated image as <name>_ai<ext> and push it next to the original.
            annotated_image = result.plot(labels=False)
            filename_no_ext, ext = os.path.splitext(os.path.basename(img_path))
            output_filename = f"{filename_no_ext}_ai{ext}"
            local_output_path = os.path.join(output_folder, output_filename)
            cv2.imwrite(local_output_path, annotated_image)

            minio_path = upload_file(minio, local_output_path, bucket_name, os.path.dirname(img_path))

        except Exception as e:
            # Per-image failures are reported but must not abort the whole batch.
            print(f"[错误] 处理失败 - {img_path},错误: {str(e)}")
            detected_classes = []
            detected_numbers = []
            aim = False
            output_filename = ""
            minio_path = ""

        output_image_list.append({
            "id": img_id,
            "minio_path": minio_path,
            "aim": aim,
            "class": detected_classes,
            "number": detected_numbers
        })

    # Remove scratch directories even when individual images failed.
    shutil.rmtree(input_folder, ignore_errors=True)
    shutil.rmtree(output_folder, ignore_errors=True)

    return {
        "status": "success",
        "message": "Detection completed",
        "data": output_image_list
    }
|
||||
|
||||
505
ai_tottle_api.py
505
ai_tottle_api.py
@ -1,505 +0,0 @@
|
||||
from sanic import Sanic, json, Blueprint,response
|
||||
from sanic.exceptions import Unauthorized, NotFound
|
||||
from sanic.response import json as json_response
|
||||
from sanic_cors import CORS
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import uuid
|
||||
import os
|
||||
import asyncio
|
||||
from minio import Minio
|
||||
from ai_image import process_images # 你实现的图片处理函数
|
||||
from queue import Queue
|
||||
import gdal2tiles as gdal2tiles
|
||||
from map_find import map_process_images
|
||||
from yolo_train import auto_train
|
||||
from map_cut import process_tiling
|
||||
from cv_video_counter import start_video_session,switch_model_session,stop_video_session,stream_sessions
|
||||
import torch
|
||||
from yolo_photo import map_process_images_with_progress # 引入你的处理函数
|
||||
|
||||
# 日志配置
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
###################################################################################验证中间件和管理件##############################################################################################
|
||||
|
||||
async def token_and_resource_check(request):
    """Validate the API token and reject requests while GPU memory is saturated.

    Returns None to let the request continue, or a 503 JSON response when any
    GPU is above the configured usage threshold. Raises Unauthorized on a
    missing or incorrect X-API-Token header.
    """
    # --- Token check ---
    token = request.headers.get('X-API-Token')
    expected_token = request.app.config.get("VALID_TOKEN")
    if not token or token != expected_token:
        logger.warning(f"Unauthorized request with token: {token}")
        raise Unauthorized("Invalid token")

    # --- GPU usage check ---
    try:
        if torch.cuda.is_available():
            num_gpus = torch.cuda.device_count()
            max_usage_ratio = request.app.config.get("MAX_GPU_USAGE", 0.9)  # default 90%

            for i in range(num_gpus):
                used = torch.cuda.memory_reserved(i)
                # BUG FIX: the original divided by max_memory_reserved(i) — this
                # process's historical peak — which approaches 100% as soon as any
                # memory has been cached. Compare against real device capacity.
                total = torch.cuda.get_device_properties(i).total_memory
                ratio = used / total if total else 0

                logger.info(f"GPU {i} Usage: {ratio:.2%}")

                if ratio > max_usage_ratio:
                    logger.warning(f"GPU {i} usage too high: {ratio:.2%}")
                    return json_response({
                        "status": "error",
                        "message": f"GPU resource busy (GPU {i} at {ratio:.2%}). Try later."
                    }, status=503)
    except Exception as e:
        # Resource probing must never block traffic; log and continue.
        logger.error(f"GPU check failed: {e}")

    return None  # allow the request to proceed
|
||||
|
||||
##################################################################################################################################################################################################
|
||||
#创建Sanic应用
|
||||
app = Sanic("ai_Service_v2")
|
||||
CORS(app) # 允许跨域请求
|
||||
task_progress = {}
|
||||
|
||||
@app.middleware("request")
async def global_middleware(request):
    """Gate every incoming request through token + GPU-load validation."""
    rejection = await token_and_resource_check(request)
    return rejection if rejection else None
|
||||
|
||||
# 配置Token和最大GPU使用率
|
||||
app.config.update({
|
||||
"VALID_TOKEN": "Beidou_b8609e96-bfec-4485-8c64-6d4f662ee44a",
|
||||
"MAX_GPU_USAGE": 0.9
|
||||
})
|
||||
|
||||
######################################################################地图切割相关的API########################################################################################################
|
||||
#创建地图的蓝图
|
||||
map_tile_blueprint = Blueprint('map', url_prefix='/map/')
|
||||
app.blueprint(map_tile_blueprint)
|
||||
|
||||
|
||||
@map_tile_blueprint.post("/tile")
async def map_tile_api(request):
    """Slice a GeoTIFF into map tiles; expects tif_url/prj_url (+ optional zoom_levels)."""
    try:
        payload = request.json
        # Reject empty bodies up front.
        if not payload:
            return json_response(
                {"status": "error", "message": "Request body is required"},
                status=400
            )

        tif_url = payload.get("tif_url")
        prj_url = payload.get("prj_url")
        if not (tif_url and prj_url):
            return json_response(
                {"status": "error", "message": "Both tif_url and prj_url are required"},
                status=400
            )

        zoom_levels = payload.get("zoom_levels", "1-18")
        try:
            # process_tiling is a coroutine function, so it is awaited directly.
            tiling_result = await process_tiling(tif_url, prj_url, zoom_levels)
            return json_response({"status": "success", "data": tiling_result})
        except Exception as processing_error:
            logger.error(f"Processing failed: {str(processing_error)}", exc_info=True)
            return json_response(
                {"status": "error", "message": f"Processing error: {str(processing_error)}"},
                status=500
            )

    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}", exc_info=True)
        return json_response({"status": "error", "message": str(e)}, status=500)
|
||||
|
||||
|
||||
#语义识别
|
||||
@map_tile_blueprint.post("/uav")
async def process_handler(request):
    """
    POST /map/uav — semantic segmentation over a list of image URLs.

    Input JSON:
    {
        "urls": ["http://example.com/img1.jpg", ...],
        "yaml_name": str,           # MinIO config name (required)
        "bucket_name": str,         # target bucket (required)
        "bucket_directory": str,    # target prefix (required)
        "model_path": str           # checkpoint file name under map/checkpoints/ (required)
    }
    Output JSON:
    {
        "code": 200,
        "msg": "success",
        "data": ["http://minio.example.com/.../result1.png", ...]
    }
    """
    try:
        body = request.json or {}
        urls = body.get("urls", [])
        yaml_name = body.get("yaml_name")
        bucket_name = body.get("bucket_name")
        bucket_directory = body.get("bucket_directory")
        model_name = body.get("model_path")

        # BUG FIX: validate before building the checkpoint path — the original
        # joined os.path.join(..., body.get("model_path")) first, which raised
        # TypeError on a missing model_path instead of returning a clean 400.
        if not urls or not isinstance(urls, list):
            return json({"code": 400, "msg": "Missing or invalid 'urls'"})
        if not all([yaml_name, bucket_name, bucket_directory, model_name]):
            return json({"code": 400, "msg": "Missing required parameters"})

        model_path = os.path.join("map", "checkpoints", model_name)

        # Run the segmentation pipeline.
        result = map_process_images(urls, yaml_name, bucket_name, bucket_directory, model_path)
        return json(result)

    except Exception as e:
        return json({"code": 500, "msg": f"Server error: {str(e)}"})
|
||||
|
||||
######################################################################yolo相关的API########################################################################################################
|
||||
#创建yolo的蓝图
|
||||
yolo_tile_blueprint = Blueprint('yolo', url_prefix='/yolo/')
|
||||
app.blueprint(yolo_tile_blueprint)
|
||||
|
||||
# YOLO URL APT
|
||||
# 存储任务进度和结果(内存示例,可用 Redis 或 DB 持久化)
|
||||
|
||||
@yolo_tile_blueprint.post("/process_images")
async def process_images(request):
    """Start an async YOLO batch-processing task and return its task_id.

    Expected JSON body:
    {
        "urls": ["http://example.com/image1.jpg", ...],
        "yaml_name": "your_minio_config",
        "bucket_name": "my-bucket",
        "bucket_directory": "2025/uav-results",
        "uav_model_path": "deeplabv3plus_best.pth"
    }

    NOTE(review): this handler's name shadows the `process_images` imported
    from ai_image at module level; /yolo/picture appears to intend the
    imported worker — confirm which callable asyncio.to_thread should receive.
    """
    data = request.json
    urls = data.get("urls")
    yaml_name = data.get("yaml_name")
    bucket_name = data.get("bucket_name")
    bucket_directory = data.get("bucket_directory")
    uav_model_path = data.get("uav_model_path")

    # bucket_directory is optional here; the other fields are required.
    if not urls or not yaml_name or not bucket_name or not uav_model_path:
        return response.json({"code": 400, "msg": "Missing parameters"}, status=400)

    task_id = str(uuid.uuid4())
    task_progress[task_id] = {"status": "pending", "progress": 0, "result": None}

    # Fire-and-forget: progress is polled via /yolo/task_status/<task_id>.
    asyncio.create_task(run_image_processing(task_id, urls, yaml_name, bucket_name, bucket_directory, uav_model_path))

    return response.json({"code": 200, "msg": "Task started", "task_id": task_id})
|
||||
|
||||
@yolo_tile_blueprint.get("/task_status/<task_id>")
async def task_status(request, task_id):
    """Report status/progress/result for a previously started background task."""
    entry = task_progress.get(task_id)
    if not entry:
        return response.json({"code": 404, "msg": "Task not found"}, status=404)
    return response.json({"code": 200, "msg": "Task status", "data": entry})
|
||||
|
||||
async def run_image_processing(task_id, urls, yaml_name, bucket_name, bucket_directory, uav_model_path):
    """Background worker: run the blocking image pipeline and record its progress."""
    state = task_progress[task_id]
    try:
        state["status"] = "running"
        state["progress"] = 10  # initial progress

        def progress_callback(stage, percent):
            # The pipeline reports its stage (download/infer/upload) and percentage.
            state["status"] = stage
            state["progress"] = percent

        result = await asyncio.to_thread(
            map_process_images_with_progress,
            urls, yaml_name, bucket_name, bucket_directory, uav_model_path, progress_callback
        )

        state["status"] = "completed"
        state["progress"] = 100
        state["result"] = result

    except Exception as e:
        state["status"] = "failed"
        state["progress"] = 100
        state["result"] = str(e)
|
||||
|
||||
# YOLO检测API
|
||||
@yolo_tile_blueprint.post("/picture")
async def yolo_detect_api(request):
    """Synchronous YOLO detection over image_list; annotated results go back to MinIO.

    Expects JSON with "image_list" (required), "minio" connection info (required),
    and optional "yolo_model" (default "best.pt") and "class" filter string.
    """
    try:
        detect_data = request.json

        image_list = detect_data.get("image_list")
        yolo_model = detect_data.get("yolo_model", "best.pt")
        class_filter = detect_data.get("class", None)
        minio_info = detect_data.get("minio", None)

        if not image_list:
            return json_response({"status": "error", "message": "image_list is required"}, status=400)

        if not minio_info:
            return json_response({"status": "error", "message": "MinIO information is required"}, status=400)

        # BUG FIX: the module-level name `process_images` is shadowed by the
        # /yolo/process_images route handler defined above, so the original call
        # dispatched to that async handler instead of the worker from ai_image.
        # Import the worker explicitly to bypass the shadowed name.
        from ai_image import process_images as run_detection

        # Per-request scratch directories (cleaned up inside the worker).
        input_folder = f"./temp_input_{str(uuid.uuid4())}"
        output_folder = f"./temp_output_{str(uuid.uuid4())}"

        # Run the blocking pipeline off the event loop.
        result = await asyncio.to_thread(
            run_detection,
            yolo_model=yolo_model,
            image_list=image_list,
            class_filter=class_filter,
            input_folder=input_folder,
            output_folder=output_folder,
            minio_info=minio_info
        )

        return json_response(result)

    except Exception as e:
        logger.error(f"Error occurred while processing request: {str(e)}", exc_info=True)
        return json_response({
            "status": "error",
            "message": f"Internal server error: {str(e)}"
        }, status=500)
|
||||
|
||||
# YOLO自动训练
|
||||
@yolo_tile_blueprint.post("/train")
async def yolo_train_api(request):
    """
    Automatically train a model.

    Input JSON:
    {
        "db_host": str, "db_database": str, "db_user": str, "db_password": str,
        "db_port": int, "model_id": int, "img_path": str, "label_path": str,
        "new_path": str, "split_list": List[float],
        "class_names": Optional[List[str]], "project_name": str
    }
    Output JSON:
    {
        "base_metrics": Dict[str, float],
        "best_model_path": str,
        "final_metrics": Dict[str, float]
    }
    """
    try:
        payload = request.json  # property access, not a call

        if not payload:
            return json_response({"status": "error", "message": "data is required"}, status=400)

        # auto_train blocks; run it off the event loop.
        training_result = await asyncio.to_thread(auto_train, payload)
        return json_response(training_result)

    except Exception as e:
        logger.error(f"Error occurred while processing request: {str(e)}", exc_info=True)
        return json_response({
            "status": "error",
            "message": f"Internal server error: {str(e)}"
        }, status=500)
|
||||
|
||||
###########################################################################################视频流相关的API#######################################################################################################
|
||||
|
||||
#创建视频流的蓝图
|
||||
stream_tile_blueprint = Blueprint('stream', url_prefix='/stream_test/')
|
||||
app.blueprint(stream_tile_blueprint)
|
||||
|
||||
#
|
||||
# 任务管理器
|
||||
class StreamTaskManager:
    """Tracks running stream tasks; evicts the oldest when capacity (10) is reached."""

    def __init__(self):
        self.active_tasks = {}      # task_id -> task_info dict
        self.task_status = {}       # task_id -> status string ("running", ...)
        self.task_timestamps = {}   # task_id -> start datetime
        self.task_queue = Queue(maxsize=10)  # FIFO of task_ids, bounds concurrency

    def add_task(self, task_id: str, task_info: dict) -> None:
        """Register a new task, stopping and evicting the oldest if at capacity."""
        if self.task_queue.full():
            oldest_task_id = self.task_queue.get()
            # BUG FIX: stop the video session *before* removing the task record —
            # the original removed the record first and then raised KeyError while
            # looking up its session_id.
            oldest_info = self.active_tasks.get(oldest_task_id)
            if oldest_info is not None:
                stop_video_session(oldest_info["session_id"])
            self.remove_task(oldest_task_id)
        self.active_tasks[task_id] = task_info
        self.task_status[task_id] = "running"
        self.task_timestamps[task_id] = datetime.now()
        self.task_queue.put(task_id)
        logger.info(f"Task {task_id} started")

    def remove_task(self, task_id: str) -> None:
        """Forget a task; safe to call for unknown ids."""
        if task_id in self.active_tasks:
            del self.active_tasks[task_id]
            del self.task_status[task_id]
            del self.task_timestamps[task_id]
            logger.info(f"Task {task_id} removed")

    def get_task_info(self, task_id: str) -> dict:
        """Return info/status/start-time for a task; raises NotFound if unknown."""
        if task_id not in self.active_tasks:
            raise NotFound("Task not found")
        return {
            "task_info": self.active_tasks[task_id],
            "status": self.task_status[task_id],
            "start_time": self.task_timestamps[task_id].isoformat()
        }
|
||||
|
||||
task_manager = StreamTaskManager()
|
||||
|
||||
# ---------- API Endpoints ----------
|
||||
|
||||
@stream_tile_blueprint.post("/start")
async def api_start(request):
    """
    Start a video-stream session.

    Input JSON:
    {
        "video_path": str,
        "output_url": str,
        "model_path": str,
        "cls": List[int],
        "confidence": float,
        "cls2": Optional[List[int]],
        "push": bool
    }
    Output JSON:
    {
        "session_id": str,
        "task_id": str,
        "message": "started"
    }
    """
    data = request.json
    task_id = str(uuid.uuid4())

    # Start the video-processing session.
    session_id = start_video_session(
        video_path = data.get("video_path"),
        output_url = data.get("output_url"),
        model_path = data.get("model_path"),
        cls = data.get("cls"),
        confidence = data.get("confidence", 0.5),
        cls2 = data.get("cls2", []),
        push = data.get("push", False),
    )

    # Register with the task manager (may evict the oldest task at capacity).
    task_manager.add_task(task_id, {
        "session_id": session_id,
        "video_path": data.get("video_path"),
        "output_url": data.get("output_url"),
        "model_path": data.get("model_path"),
        "class_filter": data.get("cls", []),
        "push": data.get("push", False),
        "start_time": datetime.now().isoformat()
    })

    return json({"session_id": session_id, "task_id": task_id, "message": "started"})
|
||||
|
||||
|
||||
@stream_tile_blueprint.post("/stop")
async def api_stop(request):
    """
    Stop a running session.
    Input JSON:  { "session_id": str }
    Output JSON: { "session_id": str, "message": "stopped" }
    """
    session_id = request.json.get("session_id")
    stop_video_session(session_id)

    # Drop the matching task record so /tasks stays in sync.
    matching_tid = next(
        (tid for tid, info in list(task_manager.active_tasks.items())
         if info.get("session_id") == session_id),
        None,
    )
    if matching_tid is not None:
        task_manager.remove_task(matching_tid)

    return json({"session_id": session_id, "message": "stopped"})
|
||||
|
||||
|
||||
@stream_tile_blueprint.post("/switch_model")
async def api_switch_model(request):
    """
    Hot-swap the model used by a session.
    Input JSON:  { "session_id": str, "new_model_path": str }
    Output JSON: { "session_id": str, "new_model_path": str, "message": "model switched" }
    """
    body = request.json
    sid = body.get("session_id")
    requested_model = body.get("new_model_path")
    switch_model_session(sid, requested_model)
    return json({"session_id": sid, "new_model_path": requested_model, "message": "model switched"})
|
||||
|
||||
|
||||
@stream_tile_blueprint.get("/sessions")
async def api_list_sessions(request):
    """
    List all live sessions.
    Output JSON: { "sessions": [{"session_id": str, "status": "running"}, ...] }
    """
    listing = []
    for sid in stream_sessions.keys():
        listing.append({"session_id": sid, "status": "running"})
    return json({"sessions": listing})
|
||||
|
||||
|
||||
# 统一的任务查询接口(含视频流)
|
||||
@stream_tile_blueprint.get("/tasks")
async def api_list_tasks(request):
    """List every tracked task (including video streams) with status, start time, details."""
    task_dump = [
        {"task_id": tid, **task_manager.get_task_info(tid)}
        for tid in task_manager.active_tasks
    ]
    return json({"tasks": task_dump})
|
||||
|
||||
|
||||
##################################################################################################################################################################################################
|
||||
if __name__ == '__main__':
    # Single worker: task_progress and task_manager live in process memory,
    # so multiple workers would not share state.
    app.run(host="0.0.0.0", port=12366, debug=True,workers=1)
|
||||
60
b3dm/b3dm_api.py
Normal file
60
b3dm/b3dm_api.py
Normal file
@ -0,0 +1,60 @@
|
||||
from sanic import Sanic, Request, json
|
||||
from sanic_cors import CORS
|
||||
import logging
|
||||
import time
|
||||
from earthwork_api import earthwork_bp
|
||||
from terrain_api import terrain_bp
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 创建Sanic应用
|
||||
app = Sanic("TerrainAnalysisAPI")
|
||||
# 显式注册蓝图
|
||||
app.blueprint(earthwork_bp)
|
||||
app.blueprint(terrain_bp)
|
||||
|
||||
CORS(app, automatic_options=True)
|
||||
|
||||
# 中间件:请求计时
|
||||
@app.middleware("request")
async def add_start_time(request: Request):
    """Stamp the request with its arrival time for response-time measurement."""
    request.ctx.start_time = time.time()
|
||||
|
||||
@app.middleware("response")
async def add_response_time(request: Request, response):
    """Attach X-Process-Time (ms) computed from the request middleware's stamp."""
    if hasattr(request.ctx, "start_time"):
        elapsed_ms = (time.time() - request.ctx.start_time) * 1000
        response.headers["X-Process-Time"] = f"{elapsed_ms:.2f}ms"
|
||||
|
||||
@app.get("/api/v1/health")
async def health_check(request: Request):
    """Health probe: reports service liveness and metadata."""
    payload = {
        "status": "healthy",
        "timestamp": time.time(),
        "service": "terrain-analysis-api",
        "version": "1.0.0",
    }
    return json(payload)
|
||||
|
||||
# 错误处理
|
||||
@app.exception(Exception)
async def handle_exception(request: Request, exception):
    """Global exception handler: log and return a uniform 500 payload."""
    logger.error(f"未处理的异常: {exception}")
    # Only expose the real message in debug mode.
    detail = str(exception) if app.debug else "请稍后重试"
    return json({
        "error": "服务器内部错误",
        "message": detail,
        "success": False
    }, status=500)
|
||||
|
||||
if __name__ == "__main__":
    # Start the development server.
    app.run(
        host="0.0.0.0",
        port=8000,
        debug=True,  # set to False in production
        access_log=True,
        auto_reload=True
    )
|
||||
542
b3dm/data_3dtiles_manager.py
Normal file
542
b3dm/data_3dtiles_manager.py
Normal file
@ -0,0 +1,542 @@
|
||||
from minio import Minio
|
||||
from minio.error import S3Error
|
||||
import json
|
||||
import os
|
||||
import numpy as np
|
||||
from urllib.parse import urlparse
|
||||
import hashlib
|
||||
import time
|
||||
import re
|
||||
import pickle
|
||||
from datetime import datetime
|
||||
|
||||
class MinIO3DTilesManager:
|
||||
    def __init__(self, endpoint_url, access_key, secret_key, secure=False,
                 mapping_file="minio_path_mapping.pkl"):
        """
        Initialize the MinIO client and load the local path-mapping cache.

        Args:
            endpoint_url: MinIO service address (e.g. 222.212.85.86:9001); an
                http:// or https:// prefix is stripped (https:// implies secure).
            access_key: access key
            secret_key: secret key
            secure: whether to use HTTPS
            mapping_file: path-mapping file name (stored next to this script)
        """
        if endpoint_url.startswith('http://'):
            endpoint_url = endpoint_url.replace('http://', '')
        elif endpoint_url.startswith('https://'):
            endpoint_url = endpoint_url.replace('https://', '')
            secure = True

        self.endpoint_url = endpoint_url
        self.access_key = access_key
        self.secret_key = secret_key

        self.minio_client = Minio(
            endpoint_url,
            access_key=access_key,
            secret_key=secret_key,
            secure=secure
        )

        # Directory containing this script; the mapping file lives beside it.
        self.script_dir = os.path.dirname(os.path.abspath(__file__))

        # Absolute path of the mapping file.
        self.mapping_file = os.path.join(self.script_dir, mapping_file)

        # Load any previously persisted path mapping.
        self.path_mapping = self.load_path_mapping()
|
||||
|
||||
def load_path_mapping(self):
|
||||
"""加载路径映射数据"""
|
||||
if os.path.exists(self.mapping_file):
|
||||
try:
|
||||
with open(self.mapping_file, 'rb') as f:
|
||||
mapping = pickle.load(f)
|
||||
return mapping
|
||||
except Exception as e:
|
||||
return {}
|
||||
else:
|
||||
return {}
|
||||
|
||||
def save_path_mapping(self):
|
||||
"""保存路径映射数据"""
|
||||
try:
|
||||
with open(self.mapping_file, 'wb') as f:
|
||||
pickle.dump(self.path_mapping, f)
|
||||
return True
|
||||
except Exception as e:
|
||||
return False
|
||||
|
||||
def get_cache_key(self, tileset_url, save_dir=None):
|
||||
"""生成缓存键"""
|
||||
# 基于URL和保存目录生成缓存键
|
||||
cache_data = f"{tileset_url}|{save_dir}"
|
||||
return hashlib.md5(cache_data.encode()).hexdigest()
|
||||
|
||||
def get_cached_tileset_info(self, tileset_url, save_dir=None):
|
||||
"""获取缓存的tileset信息"""
|
||||
cache_key = self.get_cache_key(tileset_url, save_dir)
|
||||
|
||||
# 检查缓存映射中是否有这个tileset
|
||||
for file_id, info in self.path_mapping.items():
|
||||
if info.get('cache_key') == cache_key and info.get('is_tileset_root'):
|
||||
# 检查入口文件是否存在
|
||||
local_path = info.get('local_path')
|
||||
if local_path and os.path.exists(local_path):
|
||||
return local_path
|
||||
return None
|
||||
|
||||
def update_tileset_cache(self, tileset_url, save_dir, local_path):
|
||||
"""更新tileset缓存信息"""
|
||||
cache_key = self.get_cache_key(tileset_url, save_dir)
|
||||
|
||||
# 将tileset根文件标记为缓存
|
||||
entry_bucket, entry_path = self.parse_minio_url(tileset_url)
|
||||
file_id = f"{entry_bucket}/{entry_path}"
|
||||
|
||||
if file_id in self.path_mapping:
|
||||
self.path_mapping[file_id]['cache_key'] = cache_key
|
||||
self.path_mapping[file_id]['is_tileset_root'] = True
|
||||
self.path_mapping[file_id]['tileset_url'] = tileset_url
|
||||
self.path_mapping[file_id]['save_dir'] = save_dir
|
||||
self.path_mapping[file_id]['cache_time'] = datetime.now().isoformat()
|
||||
|
||||
    def download_full_tileset(self, tileset_url, save_dir=None, region_filter=None, use_cache=True):
        """
        Download a complete 3D Tiles dataset, with optional caching.

        Args:
            tileset_url: URL of tileset.json on MinIO
            save_dir: local save directory (default: <script_dir>/data_3dtiles)
            region_filter: region filter passed through to the tree walk
            use_cache: whether to use/populate the cache

        Returns:
            tuple: (success, result)
                - success: True/False
                - result: when success=True and use_cache=True, the local path of the
                  root file; on failure an error message; otherwise True/False
        """
        if save_dir is None:
            save_dir = os.path.join(self.script_dir, "data_3dtiles")

        # Sanitize the save-directory name.
        save_dir = self.clean_file_path(save_dir)

        # Cache check: only the entry file needs to exist.
        if use_cache:
            cached_path = self.get_cached_tileset_info(tileset_url, save_dir)
            if cached_path:
                # Entry file present — the cached tree is assumed complete.
                return True, cached_path

        # Parse the URL into bucket/object.
        entry_bucket, entry_path = self.parse_minio_url(tileset_url)
        if not entry_bucket or not entry_path:
            return False, "无法解析URL"

        entry_dir = os.path.dirname(entry_path)

        # Create the save directory.
        os.makedirs(save_dir, exist_ok=True)

        # Tracks already-downloaded "<bucket>/<object>" ids during the walk.
        visited = set()

        # Download the entry (root) file first.
        entry_local_path = self.get_local_path(entry_bucket, entry_path, save_dir)

        success, result = self.download_file(entry_bucket, entry_path, entry_local_path)
        if not success:
            return False, f"入口文件下载失败: {result}"

        entry_id = f"{entry_bucket}/{entry_path}"
        visited.add(entry_id)

        # Load and validate the tileset JSON.
        tileset_data = self.load_json_from_minio(entry_bucket, entry_path)
        if not tileset_data or "root" not in tileset_data:
            return False, "无效的tileset.json文件"

        # Walk the tile tree and download every referenced file.
        self.traverse_and_download_tileset(
            tileset_data["root"],
            entry_bucket,
            entry_dir,
            entry_bucket,
            entry_dir,
            save_dir,
            region_filter,
            None,
            visited
        )

        # Record cache metadata for the root entry.
        self.update_tileset_cache(tileset_url, save_dir, entry_local_path)

        # Persist the path mapping.
        self.save_path_mapping()

        if use_cache:
            return True, entry_local_path
        else:
            return True, True
|
||||
|
||||
def get_tileset_local_path(self, tileset_url, save_dir=None):
|
||||
"""
|
||||
获取已缓存的tileset本地路径
|
||||
|
||||
Args:
|
||||
tileset_url: tileset的URL
|
||||
save_dir: 保存目录
|
||||
|
||||
Returns:
|
||||
str: 本地路径,如果未缓存则返回None
|
||||
"""
|
||||
if save_dir is None:
|
||||
save_dir = os.path.join(self.script_dir, "data_3dtiles")
|
||||
|
||||
return self.get_cached_tileset_info(tileset_url, save_dir)
|
||||
|
||||
def clear_tileset_cache(self, tileset_url=None, save_dir=None):
|
||||
"""
|
||||
清除tileset缓存
|
||||
|
||||
Args:
|
||||
tileset_url: 指定要清除的tileset URL,如果为None则清除所有
|
||||
save_dir: 保存目录
|
||||
|
||||
Returns:
|
||||
bool: 成功/失败
|
||||
"""
|
||||
try:
|
||||
if tileset_url:
|
||||
# 清除指定tileset的缓存
|
||||
cache_key = self.get_cache_key(tileset_url, save_dir)
|
||||
|
||||
# 找出所有相关的缓存条目
|
||||
to_remove = []
|
||||
for file_id, info in self.path_mapping.items():
|
||||
if info.get('cache_key') == cache_key:
|
||||
to_remove.append(file_id)
|
||||
|
||||
# 删除这些条目
|
||||
for file_id in to_remove:
|
||||
del self.path_mapping[file_id]
|
||||
|
||||
print(f"已清除tileset缓存: {tileset_url}")
|
||||
else:
|
||||
# 清除所有缓存
|
||||
self.path_mapping = {}
|
||||
if os.path.exists(self.mapping_file):
|
||||
os.remove(self.mapping_file)
|
||||
print("已清除所有缓存")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
return False
|
||||
|
||||
# 以下是原有的辅助方法
|
||||
def clean_filename(self, filename):
|
||||
"""清理文件名中的特殊字符"""
|
||||
if not filename:
|
||||
return ""
|
||||
cleaned = re.sub(r'[<>:"/\\|?*\x00-\x1F]', '_', filename)
|
||||
cleaned = re.sub(r'_+', '_', cleaned)
|
||||
cleaned = cleaned.strip(' _')
|
||||
return cleaned
|
||||
|
||||
def parse_minio_url(self, url):
    """Split a MinIO URL or a bare 'bucket/key' string into (bucket, key).

    Returns:
        tuple[str, str]: bucket name and object key ("" when absent).
    """
    if url.startswith(('http://', 'https://')):
        # Full URL: the path after the host is "bucket/key".
        remainder = urlparse(url).path.lstrip('/')
    else:
        remainder = url

    pieces = remainder.split('/', 1)
    if len(pieces) == 2:
        return pieces[0], pieces[1]
    return pieces[0], ""
|
||||
|
||||
def download_file(self, bucket_name, object_name, file_path):
    """Download one object from MinIO to *file_path*, reusing a previous
    download when the path mapping still points at an existing file.

    Returns:
        (bool, str): (True, local_path) on success, (False, error) on failure.
    """
    try:
        os.makedirs(os.path.dirname(file_path), exist_ok=True)

        # Sanitize the target path before writing.
        clean_file_path = self.clean_file_path(file_path)

        # Short-circuit if an earlier download is still on disk.
        file_id = f"{bucket_name}/{object_name}"
        cached = self.path_mapping.get(file_id)
        if cached and os.path.exists(cached['local_path']):
            return True, cached['local_path']

        # Fetch the object from MinIO.
        self.minio_client.fget_object(
            bucket_name,
            object_name,
            clean_file_path
        )

        # Record the download so later calls can reuse it.
        self.path_mapping[file_id] = {
            'local_path': clean_file_path,
            'bucket': bucket_name,
            'object': object_name,
            'download_time': datetime.now().isoformat(),
            'size': os.path.getsize(clean_file_path)
        }

        return True, clean_file_path

    except S3Error as e:
        return False, str(e)
    except Exception as e:
        return False, str(e)
|
||||
|
||||
def clean_file_path(self, file_path):
    """Sanitize every path component of *file_path* via clean_filename,
    dropping components that sanitize to the empty string."""
    dir_name, file_name = os.path.split(file_path)

    if dir_name:
        sanitized = [self.clean_filename(part) for part in dir_name.split(os.sep)]
        cleaned_dir = os.sep.join(part for part in sanitized if part)
    else:
        cleaned_dir = ""

    cleaned_file = self.clean_filename(file_name)

    if cleaned_dir:
        return os.path.join(cleaned_dir, cleaned_file)
    return cleaned_file
|
||||
|
||||
def load_json_from_minio(self, bucket_name, object_name):
    """Fetch an object from MinIO and parse it as JSON.

    Returns:
        dict | list | None: the parsed JSON, or None on any failure
        (missing object, network error, malformed JSON).
    """
    try:
        # Stat first so a missing object fails fast with S3Error.
        self.minio_client.stat_object(bucket_name, object_name)

        response = self.minio_client.get_object(bucket_name, object_name)
        try:
            content = response.read().decode('utf-8')
        finally:
            # Fix: always release the connection — previously a failing
            # read()/decode skipped close()/release_conn() and leaked it.
            response.close()
            response.release_conn()

        return json.loads(content)

    except S3Error:
        return None
    except Exception:
        return None
|
||||
|
||||
def get_local_path(self, bucket_name, object_name, save_dir):
    """Build a local path under *save_dir* mirroring bucket/object structure,
    with every component sanitized and empty components dropped."""
    clean_bucket = self.clean_filename(bucket_name)

    cleaned_parts = [p for p in (self.clean_filename(seg) for seg in object_name.split('/')) if p]

    if cleaned_parts:
        local_path = os.path.join(save_dir, clean_bucket, '/'.join(cleaned_parts))
    else:
        local_path = os.path.join(save_dir, clean_bucket)

    return os.path.normpath(local_path)
|
||||
|
||||
def traverse_and_download_tileset(self, tile_obj, current_bucket, current_dir,
                                  base_bucket, base_dir, save_dir,
                                  region_filter=None, parent_transform=None,
                                  visited=None):
    """Recursively traverse a 3D Tiles tile tree and download referenced files.

    Args:
        tile_obj: current tile JSON object (a "root" entry or a child tile).
        current_bucket: MinIO bucket the current tileset JSON lives in.
        current_dir: object-key prefix (directory) of the current tileset.
        base_bucket: bucket of the entry tileset, carried through recursion.
        base_dir: directory of the entry tileset, carried through recursion.
        save_dir: local root directory the downloads are mirrored into.
        region_filter: optional filter; when a tile's boundingVolume fails
            its check, that tile's own content is skipped (children are
            still visited).
        parent_transform: accumulated 4x4 transform (flat 16-element list)
            inherited from ancestor tiles, or None.
        visited: set of "bucket/key" ids already downloaded — breaks cycles
            and avoids duplicate downloads across sub-tilesets.
    """
    if visited is None:
        visited = set()

    # Compose this tile's transform with the inherited one as a 4x4 matrix
    # product. NOTE(review): 3D Tiles transforms are column-major per spec;
    # confirm the row-major reshape + dot order here is intentional.
    current_transform = parent_transform
    if "transform" in tile_obj:
        tile_mat = tile_obj["transform"]
        if current_transform is None:
            current_transform = tile_mat
        else:
            mat1 = np.array(current_transform).reshape(4, 4)
            mat2 = np.array(tile_mat).reshape(4, 4)
            combined_mat = np.dot(mat1, mat2).flatten().tolist()
            current_transform = combined_mat

    # Region filtering only suppresses this tile's content download;
    # recursion into children continues regardless.
    skip_current_tile = False
    if region_filter and "boundingVolume" in tile_obj:
        if not region_filter.check_tile_bounding_volume(tile_obj["boundingVolume"]):
            skip_current_tile = True

    if not skip_current_tile and "content" in tile_obj and "uri" in tile_obj["content"]:
        tile_uri = tile_obj["content"]["uri"]

        file_bucket = current_bucket
        file_path = ""

        # Resolve the content URI: absolute URLs carry their own bucket/key,
        # relative URIs are resolved against the current tileset's directory.
        if tile_uri.startswith('http://') or tile_uri.startswith('https://'):
            parsed_bucket, parsed_path = self.parse_minio_url(tile_uri)
            if parsed_bucket:
                file_bucket = parsed_bucket
                file_path = parsed_path
        else:
            if current_dir:
                file_path = os.path.join(current_dir, tile_uri).replace('\\', '/')
            else:
                file_path = tile_uri

        file_path = file_path.lstrip('/')

        file_id = f"{file_bucket}/{file_path}"

        if file_id not in visited:
            print(f"下载文件:{file_id}")
            visited.add(file_id)

            local_path = self.get_local_path(file_bucket, file_path, save_dir)

            # Download errors are reported via download_file's return value,
            # which is ignored here — traversal is best-effort.
            self.download_file(file_bucket, file_path, local_path)

            # A .json content is an external sub-tileset: recurse into its
            # root, with paths now relative to the sub-tileset's directory.
            if file_path.lower().endswith('.json'):
                sub_tileset = self.load_json_from_minio(file_bucket, file_path)
                if sub_tileset and "root" in sub_tileset:
                    sub_dir = os.path.dirname(file_path) if file_path else ""
                    self.traverse_and_download_tileset(
                        sub_tileset["root"],
                        file_bucket,
                        sub_dir,
                        base_bucket,
                        base_dir,
                        save_dir,
                        region_filter,
                        current_transform,
                        visited
                    )

    # Recurse into child tiles, propagating the composed transform.
    if "children" in tile_obj:
        for child_tile in tile_obj["children"]:
            self.traverse_and_download_tileset(
                child_tile,
                current_bucket,
                current_dir,
                base_bucket,
                base_dir,
                save_dir,
                region_filter,
                current_transform,
                visited
            )
|
||||
|
||||
def upload_file(self, bucket_name, object_name, file_path):
    """Upload a local file to MinIO.

    Returns:
        (bool, str): (True, "bucket/object") on success,
        (False, error message) on failure.
    """
    try:
        if not os.path.exists(file_path):
            return False, f"文件不存在: {file_path}"

        # Fix: removed unused local `file_size` (computed and never read).
        self.minio_client.fput_object(bucket_name, object_name, file_path)

        return True, f"{bucket_name}/{object_name}"

    except S3Error as e:
        return False, f"MinIO上传错误: {e}"
    except Exception as e:
        return False, f"上传失败: {str(e)}"
|
||||
|
||||
def upload_directory(self, bucket_name, local_dir, remote_prefix=""):
    """Upload every file under *local_dir* to MinIO, preserving relative paths.

    Returns:
        (list, list): uploaded remote paths, and (remote_path, error) pairs
        for files that failed.
    """
    if not os.path.exists(local_dir):
        return [], [f"目录不存在: {local_dir}"]

    uploaded_files, failed_files = [], []

    for root, _dirs, files in os.walk(local_dir):
        for name in files:
            local_path = os.path.join(root, name)
            rel_path = os.path.relpath(local_path, local_dir)
            # Remote keys always use forward slashes.
            joined = os.path.join(remote_prefix, rel_path) if remote_prefix else rel_path
            remote_path = joined.replace('\\', '/')

            ok, message = self.upload_file(bucket_name, remote_path, local_path)
            if ok:
                uploaded_files.append(remote_path)
            else:
                failed_files.append((remote_path, message))

    return uploaded_files, failed_files
|
||||
|
||||
def check_and_create_bucket(self, bucket_name):
    """Ensure *bucket_name* exists, creating it when absent.

    Returns:
        (bool, str): success flag plus a human-readable message.
    """
    try:
        if self.minio_client.bucket_exists(bucket_name):
            return True, f"bucket已存在: {bucket_name}"
        self.minio_client.make_bucket(bucket_name)
        return True, f"创建bucket: {bucket_name}"
    except S3Error as e:
        return False, f"创建bucket失败: {e}"
|
||||
|
||||
|
||||
# Usage example
if __name__ == "__main__":
    # Connection settings.
    # NOTE(review): MinIO credentials are hardcoded in source — move them to
    # environment variables or a config file before publishing this code.
    ENDPOINT_URL = "222.212.85.86:9000"
    ACCESS_KEY = "WuRenJi"
    SECRET_KEY = "WRJ@2024"

    # Build the manager (secure=False => plain HTTP endpoint).
    manager = MinIO3DTilesManager(
        endpoint_url=ENDPOINT_URL,
        access_key=ACCESS_KEY,
        secret_key=SECRET_KEY,
        secure=False
    )

    # Tileset to download, with caching enabled.
    tileset_url = "http://222.212.85.86:9000/300bdf2b-a150-406e-be63-d28bd29b409f/model/石棉0908/terra_b3dms/tileset.json"

    # First download (fetches everything to local disk).
    success, result = manager.download_full_tileset(tileset_url, use_cache=True)
    if success:
        print(f"下载成功,本地路径: {result}")

    # Second download of the same URL (served straight from the cache).
    success, result = manager.download_full_tileset(tileset_url, use_cache=True)
    if success:
        print(f"从缓存获取,本地路径: {result}")

    # Force a re-download, ignoring the cache.
    success, result = manager.download_full_tileset(tileset_url, use_cache=False)
    if success:
        print("强制重新下载成功")

    # Look up the cached local path.
    local_path = manager.get_tileset_local_path(tileset_url)
    if local_path:
        print(f"缓存的本地路径: {local_path}")
|
||||
1367
b3dm/data_3dtiles_to_dem.py
Normal file
1367
b3dm/data_3dtiles_to_dem.py
Normal file
File diff suppressed because it is too large
Load Diff
530
b3dm/earthwork_api.py
Normal file
530
b3dm/earthwork_api.py
Normal file
@ -0,0 +1,530 @@
|
||||
# pip install fastapi uvicorn pdal pyvista numpy
|
||||
from sanic import Blueprint, Request, json
|
||||
from sanic.response import text
|
||||
from typing import Dict, Any
|
||||
import logging
|
||||
from b3dm.earthwork_calculator_point_cloud import EarthworkCalculatorPointCloud
|
||||
# 导入计算模块
|
||||
from b3dm.earthwork_calculator_3d_tiles import EarthworkCalculator3dTiles, AlgorithmType, EarthworkResult3dTiles
|
||||
from b3dm.tileset_data_source import TilesetDataSource
|
||||
|
||||
earthwork_bp = Blueprint("earthwork", url_prefix="")
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialization helper
def init_app(url, type = "3dtiles"):
    """Build the data source and the calculator matching *type*.

    Args:
        url: tileset / point-cloud URL the data source should download.
        type: "3dtiles" or "pointcloud" (note: shadows the builtin; kept
            for backward compatibility with existing callers).

    Returns:
        dict with keys "data_source", "calculator_3d_tiles",
        "calculator_point_cloud"; the calculator not matching *type* is None.

    Raises:
        ValueError: when *type* is not supported.
        ImportError / Exception: re-raised after logging when construction
            or the map download fails.
    """
    data_source = None
    calculator_3d_tiles = None
    calculator_point_cloud = None

    try:
        # Build the data source and download the map data for this URL.
        data_source = TilesetDataSource(url)
        data_source.dowload_map_data(url)

        if type == "3dtiles":
            # 3D Tiles calculator.
            calculator_3d_tiles = EarthworkCalculator3dTiles(data_source)
        elif type == "pointcloud":
            # Point-cloud calculator.
            calculator_point_cloud = EarthworkCalculatorPointCloud(data_source.tileset_path)
        else:
            logger.info(f"不支持的3d地图数据格式:{type}")
            # BUG FIX: a bare `raise` with no active exception raised
            # RuntimeError("No active exception to re-raise"); raise an
            # explicit, meaningful error instead.
            raise ValueError(f"不支持的3d地图数据格式:{type}")

        logger.info("土方量计算器初始化完成")
        return {
            "data_source": data_source,
            "calculator_3d_tiles": calculator_3d_tiles,
            "calculator_point_cloud": calculator_point_cloud
        }
    except ImportError as e:
        logger.error(f"依赖库缺失: {str(e)}")
        raise
    except Exception as e:
        logger.error(f"初始化失败: {str(e)}")
        raise
|
||||
|
||||
# Earthwork volume calculation endpoint - 3D Tiles
@earthwork_bp.post("/api/v1/calc/earthwork3dTiles")
async def calc_earthwork(request: Request):
    """
    Earthwork volume calculation endpoint.

    Example request body:
    {
        "polygonCoords": [
            [
                115.70440468338526,
                30.77363140345639
            ],
            [
                115.70443054007985,
                30.773510462589584
            ],
            [
                115.70459702429197,
                30.77360789911405
            ]
        ],
        "designElevation": 100,
        "algorithm": "tin",
        "resolution": 1,
        "crs": "EPSG:4326",
        "interpolationMethod": "linear",
        "url": "http://222.212.85.86:9000/300bdf2b-a150-406e-be63-d28bd29b409f/model/hbgldk/yzk/20260113/3D/terra_b3dms/tileset.json"
    }
    """
    try:
        # 1. Read the request payload.
        data = request.json
        if not data:
            return _error_response("请求参数不能为空", 400)

        # 2. Required parameters.
        polygon_coords = data.get("polygonCoords")
        design_elevation = data.get("designElevation")
        url = data.get("url")

        if not polygon_coords:
            return _error_response("多边形坐标不能为空", 400)
        if design_elevation is None:
            return _error_response("设计高程不能为空", 400)
        if url is None:
            return _error_response("地图不能为空", 400)

        # 3. Optional parameters with defaults.
        algorithm = data.get("algorithm", "tin")
        resolution = data.get("resolution", 1.0)
        crs = data.get("crs", "EPSG:4326")
        interpolation_method = data.get("interpolationMethod", "linear")

        # 4. Validation.
        if len(polygon_coords) < 3:
            return _error_response("多边形至少需要3个点", 400)

        # Auto-close the polygon ring if the caller left it open
        # (mutates the request's coordinate list in place).
        if polygon_coords[0] != polygon_coords[-1]:
            polygon_coords.append(polygon_coords[0])

        # Algorithm whitelist.
        if algorithm not in ["grid", "tin", "prism"]:
            return _error_response("算法必须是 grid, tin 或 prism", 400)

        # Resolution bounds (metres).
        if resolution <= 0 or resolution > 100:
            return _error_response("分辨率必须在0-100米之间", 400)

        # 5. Build the calculator for this map URL.
        #    NOTE(review): init_app is called per request — confirm whether
        #    tileset caching is handled inside TilesetDataSource.
        app_info = init_app(url)
        calculator_3d_tiles = app_info.get("calculator_3d_tiles")

        # 6. Run the calculation.
        algorithm_type = AlgorithmType(algorithm)

        result = await calculator_3d_tiles.calculate(
            polygon_coords=polygon_coords,
            design_elevation=design_elevation,
            algorithm=algorithm_type,
            resolution=resolution,
            target_crs=crs,
            interpolation_method=interpolation_method
        )

        # 7. Strip the bulky per-cell details before responding.
        res_dict = result.to_dict()
        res_dict["calculation_details"] = None
        res_dict["elevation_statistics"] = None
        res_dict["volume_distribution"] = None
        return _success_response(res_dict)

    except ValueError as e:
        logger.warning(f"参数验证失败: {str(e)}")
        return _error_response(f"参数错误: {str(e)}", 400)
    except Exception as e:
        logger.error(f"计算失败: {str(e)}")
        return _error_response(f"服务器内部错误: {str(e)}", 500)
|
||||
|
||||
# Two-phase comparison endpoint - 3D Tiles
@earthwork_bp.post("/api/v1/calc/twoPhaseComparison")
async def two_phase_comparison(request: Request):
    """
    Two-phase (before/after) comparison endpoint.

    Example request body:
    {
        "polygonCoords": [
            [
                115.70440468338526,
                30.77363140345639
            ],
            [
                115.70443054007985,
                30.773510462589584
            ],
            [
                115.70459702429197,
                30.77360789911405
            ]
        ],
        "designElevation": 100,
        "algorithm": "grid",
        "resolution": 1,
        "crs": "EPSG:4326",
        "interpolationMethod": "linear",
        "urlA": "http://222.212.85.86:9000/300bdf2b-a150-406e-be63-d28bd29b409f/model/hbgldk/yzk/20260113/3D/terra_b3dms/tileset.json",
        "urlB": "http://222.212.85.86:9000/300bdf2b-a150-406e-be63-d28bd29b409f/model/hbgldk/yzk/20260113/3D/terra_b3dms/tileset.json"
    }
    """
    try:
        # 1. Read the request payload.
        data = request.json
        if not data:
            return _error_response("请求参数不能为空", 400)

        # 2. Required parameters (designElevation defaults to 1000, so the
        #    None-check below can only fail on an explicit null).
        polygon_coords = data.get("polygonCoords")
        design_elevation = data.get("designElevation", 1000)
        urlA = data.get("urlA")
        urlB = data.get("urlB")

        if not polygon_coords:
            return _error_response("多边形坐标不能为空", 400)
        if design_elevation is None:
            return _error_response("设计高程不能为空", 400)
        if urlA is None or urlB is None :
            return _error_response("对比地图不能为空", 400)

        # 3. Optional parameters with defaults.
        algorithm = data.get("algorithm", "tin")
        resolution = data.get("resolution", 1.0)
        crs = data.get("crs", "EPSG:4326")
        interpolation_method = data.get("interpolationMethod", "linear")

        # 4. Validation.
        if len(polygon_coords) < 3:
            return _error_response("多边形至少需要3个点", 400)

        # Auto-close the polygon ring if the caller left it open.
        if polygon_coords[0] != polygon_coords[-1]:
            polygon_coords.append(polygon_coords[0])

        # Algorithm whitelist.
        if algorithm not in ["grid", "tin", "prism"]:
            return _error_response("算法必须是 grid, tin 或 prism", 400)

        # Resolution bounds (metres).
        if resolution <= 0 or resolution > 100:
            return _error_response("分辨率必须在0-100米之间", 400)

        # 5. Build one calculator per map; abort if either download failed.
        app_info_a = init_app(urlA)
        if not app_info_a.get('data_source').tileset_path :
            return _error_response(f"下载地图失败:{urlA}", 400)
        calculator_3d_tiles_a = app_info_a.get("calculator_3d_tiles")
        app_info_b = init_app(urlB)
        if not app_info_b.get('data_source').tileset_path :
            return _error_response(f"下载地图失败:{urlB}", 400)
        calculator_3d_tiles_b = app_info_b.get("calculator_3d_tiles")

        # 6. Run both calculations.
        #    NOTE(review): the request's "algorithm" is validated above but
        #    then forced to GRID here — presumably because
        #    compare_grid_cells needs grid cells; confirm that is intended.
        algorithm_type = AlgorithmType.GRID
        result_a = await calculator_3d_tiles_a.calculate(
            polygon_coords=polygon_coords,
            design_elevation=design_elevation,
            algorithm=algorithm_type,
            resolution=resolution,
            target_crs=crs,
            interpolation_method=interpolation_method
        )
        result_b = await calculator_3d_tiles_b.calculate(
            polygon_coords=polygon_coords,
            design_elevation=design_elevation,
            algorithm=algorithm_type,
            resolution=resolution,
            target_crs=crs,
            interpolation_method=interpolation_method
        )

        # Per-cell grid data from each phase.
        grids_a = result_a.calculation_details
        grids_b = result_b.calculation_details

        # Compare the two grids cell by cell.
        comparison_result = calculator_3d_tiles_a.compare_grid_cells(grids_a, grids_b)

        # Serialize for the response.
        result_dict = comparison_result.to_dict()

        # 7. Success response.
        return _success_response(result_dict)

    except ValueError as e:
        logger.warning(f"参数验证失败: {str(e)}")
        return _error_response(f"参数错误: {str(e)}", 400)
    except Exception as e:
        logger.error(f"计算失败: {str(e)}")
        return _error_response(f"服务器内部错误: {str(e)}", 500)
|
||||
|
||||
|
||||
# Parameter validation endpoint
@earthwork_bp.post("/api/v1/calc/earthwork3dTiles/validate")
async def validate_earthwork(request: Request):
    """Validate calculation parameters without running a full calculation."""
    try:
        # Read and sanity-check the payload.
        payload = request.json
        if not payload:
            return _error_response("请求参数不能为空", 400)

        polygon_coords = payload.get("polygonCoords")
        if not polygon_coords:
            return _error_response("多边形坐标不能为空", 400)

        url = payload.get("url")
        if url is None:
            return _error_response("地图不能为空", 400)

        if len(polygon_coords) < 3:
            return _error_response("多边形至少需要3个点", 400)

        # Auto-close the polygon ring if the caller left it open.
        if polygon_coords[0] != polygon_coords[-1]:
            polygon_coords.append(polygon_coords[0])

        # Build the calculator for this map URL, then run validation.
        app_info = init_app(url)
        calculator_3d_tiles = app_info.get("calculator_3d_tiles")
        validation_result = await calculator_3d_tiles.validate(polygon_coords)

        return _success_response(validation_result)

    except Exception as e:
        logger.error(f"验证失败: {str(e)}")
        return _error_response(f"验证失败: {str(e)}", 400)
|
||||
|
||||
# Algorithm catalogue endpoint
@earthwork_bp.get("/api/v1/calc/earthwork3dTiles/algorithms")
async def get_algorithms(request: Request):
    """Return the static catalogue of supported earthwork algorithms."""
    try:
        grid_algo = {
            "id": "grid",
            "name": "格网法",
            "description": "将计算区域划分为规则格网,通过插值计算每个格网的高程变化,适合平坦或规则地形",
            "accuracy": "中等",
            "performance": "快速",
            "parameters": {
                "resolution": {
                    "name": "格网分辨率",
                    "description": "格网大小(米),影响计算精度和性能",
                    "default": 1.0,
                    "range": [0.1, 10.0],
                },
            },
        }
        tin_algo = {
            "id": "tin",
            "name": "三角网法",
            "description": "基于不规则三角网(TIN)构建地形表面,计算每个三角形的体积变化,适合复杂地形",
            "accuracy": "高",
            "performance": "中等",
            "parameters": {
                "resolution": {
                    "name": "不适用",
                    "description": "三角网法不使用固定的分辨率参数",
                    "default": None,
                },
            },
        }
        prism_algo = {
            "id": "prism",
            "name": "三棱柱法",
            "description": "结合三角网和垂直棱柱的高精度算法,计算每个三棱柱的体积,精度最高",
            "accuracy": "最高",
            "performance": "较慢",
            "parameters": {
                "resolution": {
                    "name": "棱柱宽度",
                    "description": "棱柱宽度(米),影响计算精度",
                    "default": 1.0,
                    "range": [0.1, 5.0],
                },
            },
        }

        return _success_response([grid_algo, tin_algo, prism_algo])

    except Exception as e:
        logger.error(f"获取算法列表失败: {str(e)}")
        return _error_response(f"获取算法列表失败: {str(e)}", 500)
|
||||
|
||||
# Batch calculation endpoint
@earthwork_bp.post("/api/v1/calc/earthwork3dTiles/batch")
async def batch_calc_earthwork(request: Request):
    """Batch earthwork calculation: runs up to 100 independent tasks and
    reports per-task errors without aborting the whole batch."""
    try:
        # 1. Read the request payload.
        data = request.json
        if not data:
            return _error_response("请求参数不能为空", 400)

        calculations = data.get("calculations", [])

        if not calculations:
            return _error_response("计算任务列表不能为空", 400)

        if len(calculations) > 100:
            return _error_response("批量计算数量超过限制(最多100个)", 400)

        # 2. Run each task, collecting results and per-task errors.
        results = []
        errors = []

        for i, calc_data in enumerate(calculations):
            try:
                # Required parameters.
                polygon_coords = calc_data.get("polygonCoords")
                design_elevation = calc_data.get("designElevation")
                url = calc_data.get("url")

                if not polygon_coords or design_elevation is None or url is None:
                    errors.append({
                        "index": i,
                        "error": "缺少必要参数"
                    })
                    continue

                # Validation.
                if len(polygon_coords) < 3:
                    errors.append({
                        "index": i,
                        "error": "多边形至少需要3个点"
                    })
                    continue

                # Auto-close the polygon ring if needed.
                if polygon_coords[0] != polygon_coords[-1]:
                    polygon_coords.append(polygon_coords[0])

                # Optional parameters with defaults.
                algorithm = calc_data.get("algorithm", "tin")
                resolution = calc_data.get("resolution", 1.0)
                crs = calc_data.get("crs", "EPSG:4326")
                interpolation_method = calc_data.get("interpolationMethod", "linear")

                # Build the calculator for this task's map URL.
                app_info = init_app(url)
                calculator_3d_tiles = app_info.get("calculator_3d_tiles")

                # Run the calculation.
                algorithm_type = AlgorithmType(algorithm)

                result = await calculator_3d_tiles.calculate(
                    polygon_coords=polygon_coords,
                    design_elevation=design_elevation,
                    algorithm=algorithm_type,
                    resolution=resolution,
                    target_crs=crs,
                    interpolation_method=interpolation_method
                )

                # BUG FIX: serialize each result — raw result objects are not
                # JSON-serializable; the single-task endpoint already returns
                # result.to_dict().
                results.append(result.to_dict())

            except Exception as e:
                errors.append({
                    "index": i,
                    "error": str(e),
                    "polygon": polygon_coords if 'polygon_coords' in locals() else None
                })
                continue

        # 3. Summarize and respond. (Removed the unused `message` local.)
        batch_result = {
            "results": results,
            "errors": errors,
            "summary": {
                "total": len(calculations),
                "success": len(results),
                "failed": len(errors),
                "successRate": f"{(len(results)/len(calculations)*100):.1f}%" if calculations else "0%"
            }
        }

        return _success_response(batch_result)

    except Exception as e:
        logger.error(f"批量计算失败: {str(e)}")
        return _error_response(f"批量计算失败: {str(e)}", 500)
|
||||
|
||||
# Core endpoint: earthwork calculation - point cloud
@earthwork_bp.post("/api/v1/calc/earthworkPointCloud")
async def calc_earthwork_point_cloud(request: Request):
    """Earthwork volume calculation backed by point-cloud data."""
    try:
        # 1. Read the request payload.
        data = request.json
        if not data:
            # Consistency fix: use the shared helper; payload and status are
            # identical to the previous inline json(...) response.
            return _error_response("请求参数不能为空", 400)

        polygon_coords = data.get("polygonCoords")  # polygon of the calculation area
        design_elev = data.get("designElevation")   # design elevation
        crs = data.get("crs", "EPSG:4326")          # CRS, defaults to WGS84
        url = data.get("url")
        if url is None:
            return _error_response("地图不能为空", 400)

        # 2. Build the point-cloud calculator.
        #    BUG FIX: init_app defaults to type="3dtiles", which leaves
        #    "calculator_point_cloud" as None and crashed below with an
        #    AttributeError — request the point-cloud calculator explicitly.
        app_info = init_app(url, type="pointcloud")
        calculator_point_cloud = app_info.get("calculator_point_cloud")

        result = calculator_point_cloud.calculate_earthwork(polygon_coords=polygon_coords, design_elev=design_elev, crs=crs)

        # 3. Translate calculator-level failures into a 400 response.
        if not result["success"]:
            return _error_response(result["error"], 400)

        # 4. Format the raw result for the API.
        formatted_result = calculator_point_cloud.format_result(result)

        # 5. Success response.
        return _success_response(formatted_result)

    except Exception as e:
        logger.error(f"服务器错误: {str(e)}")
        return _error_response(f"服务器内部错误: {str(e)}", 500)
|
||||
|
||||
def _success_response(data: Dict[str, Any]):
    """Wrap *data* in the standard success envelope (code 200).

    Fix: dropped the incorrect `-> json` annotation — `json` here is sanic's
    response factory function, not a type; the actual return is an
    HTTPResponse.
    """
    return json({
        "code": 200,
        "msg": "计算成功",
        "data": data
    })
|
||||
|
||||
def _error_response(message: str, status_code: int = 400):
    """Wrap *message* in the standard error envelope; the HTTP status code
    mirrors the body's "code" field.

    Fix: dropped the incorrect `-> json` annotation (sanic's `json` is a
    function, not a type).
    """
    return json({
        "code": status_code,
        "msg": message,
        "data": None
    }, status=status_code)
|
||||
1647
b3dm/earthwork_calculator_3d_tiles.py
Normal file
1647
b3dm/earthwork_calculator_3d_tiles.py
Normal file
File diff suppressed because it is too large
Load Diff
691
b3dm/earthwork_calculator_point_cloud.py
Normal file
691
b3dm/earthwork_calculator_point_cloud.py
Normal file
@ -0,0 +1,691 @@
|
||||
# earthwork_calculator.py
|
||||
import pdal
|
||||
import pyvista as pv
|
||||
import numpy as np
|
||||
import json
|
||||
from typing import List, Tuple, Dict, Any, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
import traceback
|
||||
|
||||
|
||||
class EarthworkAlgorithm(Enum):
    """Enumeration of supported earthwork volume calculation algorithms."""
    GRID = "grid"    # regular-grid differencing
    TIN = "tin"      # triangulated irregular network
    PRISM = "prism"  # triangular-prism integration
|
||||
|
||||
|
||||
@dataclass
class EarthworkResultPointCloud:
    """Earthwork volume calculation result for the point-cloud pipeline."""
    cut_volume: float   # excavation volume (m³)
    fill_volume: float  # fill volume (m³)
    net_volume: float   # net volume, cut minus fill (m³)
    area: float         # calculated region area (m²)
    avg_elevation: float  # mean elevation (m)
    min_elevation: float  # lowest elevation (m)
    max_elevation: float  # highest elevation (m)
    points_count: int     # number of points used
    triangle_count: int = 0  # triangle count (TIN algorithms)
    grid_count: int = 0      # grid-cell count (GRID algorithms only)
    prism_count: int = 0     # prism count (PRISM algorithms only)
    bounding_box: Dict[str, List[float]] = field(default_factory=dict)  # spatial bounds
    volume_accuracy: float = 0.95  # estimated calculation accuracy
    algorithm: str = "TIN三角网法"  # algorithm label (prefix selects the count field)
    resolution: float = 1.0        # calculation resolution
    algorithm_params: Dict[str, Any] = field(default_factory=dict)  # extra parameters

    def to_dict(self) -> Dict:
        """Serialize the result into the nested API response structure."""
        payload = {
            "volume": {
                "cut": round(self.cut_volume, 3),
                "fill": round(self.fill_volume, 3),
                "net": round(self.net_volume, 3),
                "unit": "m³",
            },
            "area": {
                "value": round(self.area, 3),
                "unit": "m²",
            },
            "elevation": {
                "average": round(self.avg_elevation, 3),
                "min": round(self.min_elevation, 3),
                "max": round(self.max_elevation, 3),
                "unit": "m",
            },
            "statistics": {
                "points_count": self.points_count,
                "accuracy": round(self.volume_accuracy, 3),
                "algorithm": self.algorithm,
            },
            "bounding_box": self.bounding_box,
            "calculation_params": {
                "resolution": self.resolution,
                "accuracy": self.volume_accuracy,
                **self.algorithm_params,
            },
        }

        # Attach the count field that matches the algorithm family.
        for prefix, key, count in (
            ("GRID", "grid_count", self.grid_count),
            ("TIN", "triangle_count", self.triangle_count),
            ("PRISM", "prism_count", self.prism_count),
        ):
            if self.algorithm.startswith(prefix):
                payload["statistics"][key] = count
                break

        return payload
|
||||
|
||||
|
||||
class EarthworkCalculatorPointCloud:
|
||||
"""土方量计算核心类(支持多种算法)"""
|
||||
|
||||
def __init__(self, point_cloud_path: str = "./data/your_point_cloud.laz"):
    """
    Initialize the earthwork calculator.

    Args:
        point_cloud_path: path to the point-cloud data file (.las/.laz).
    """
    # The path is only stored here; the file is read lazily when
    # clip_point_cloud builds its PDAL pipeline.
    self.point_cloud_path = point_cloud_path
|
||||
|
||||
def validate_inputs(self, polygon_coords: List[List[float]], design_elev: float) -> Tuple[bool, str]:
    """Check polygon and design-elevation inputs.

    Returns:
        (bool, str): (True, "") when valid, otherwise (False, error message).
    """
    if not polygon_coords or len(polygon_coords) < 3:
        return False, "多边形坐标至少需要3个点"

    # The elevation only needs to be convertible to float.
    try:
        float(design_elev)
    except (TypeError, ValueError):
        return False, "设计高程必须是有效数字"

    return True, ""
|
||||
|
||||
def create_polygon_string(self, polygon_coords: List[List[float]]) -> str:
    """Build a WKT POLYGON string for PDAL from [[x, y], ...] pairs.

    Pairs with fewer than two values are skipped; the ring is closed by
    repeating the first vertex when necessary.
    """
    vertices = [f"{pt[0]} {pt[1]}" for pt in polygon_coords if len(pt) >= 2]

    # WKT rings must be closed.
    if vertices and vertices[0] != vertices[-1]:
        vertices.append(vertices[0])

    return "POLYGON((" + ", ".join(vertices) + "))"
|
||||
|
||||
def calculate_bounding_box(self, points: np.ndarray) -> Dict[str, List[float]]:
    """
    Compute the axis-aligned bounding box of a coordinate array.

    Args:
        points: (N, 3) array of point coordinates.

    Returns:
        Dict: {"min": [x, y, z], "max": [x, y, z]}; zeros when empty.
    """
    if len(points) == 0:
        return {"min": [0, 0, 0], "max": [0, 0, 0]}

    lo = points.min(axis=0)
    hi = points.max(axis=0)

    return {
        "min": [float(lo[0]), float(lo[1]), float(lo[2])],
        "max": [float(hi[0]), float(hi[1]), float(hi[2])],
    }
|
||||
|
||||
def clip_point_cloud(self, polygon_coords: List[List[float]], crs: str = "EPSG:4326") -> pv.PolyData:
    """
    Clip the point cloud to a polygon region via a PDAL crop pipeline.

    Args:
        polygon_coords: polygon vertex list [[x, y], ...].
        crs: coordinate reference system of the data and polygon.

    Returns:
        pyvista.PolyData: the clipped point cloud. If PDAL raises a
        RuntimeError, a randomly generated mock cloud is returned instead
        (test fallback); a ValueError for an empty crop result propagates.
    """
    polygon_str = self.create_polygon_string(polygon_coords)

    # PDAL pipeline: LAS/LAZ reader followed by a polygon crop filter.
    pipeline_config = {
        "pipeline": [
            {
                "type": "readers.las",
                "filename": self.point_cloud_path,
                "spatialreference": crs
            },
            {
                "type": "filters.crop",
                "polygon": polygon_str
            }
        ]
    }

    # Execute the pipeline (PDAL takes the config as a JSON string).
    pipeline = pdal.Pipeline(json.dumps(pipeline_config))

    try:
        pipeline.execute()

        if len(pipeline.arrays) == 0:
            raise ValueError("多边形区域内没有找到点云数据")

        # Extract the cropped points (structured array with X/Y/Z fields).
        points = pipeline.arrays[0]
        x = points["X"]
        y = points["Y"]
        z = points["Z"]

        return pv.PolyData(np.column_stack((x, y, z)))

    except RuntimeError as e:
        # Only PDAL's RuntimeError triggers the mock fallback; other
        # exceptions (including the ValueError above) propagate.
        print(f"PDAL执行失败: {str(e)}")
        # Fall back to simulated data for testing environments without data.
        return self.generate_mock_point_cloud(polygon_coords)
|
||||
|
||||
def generate_mock_point_cloud(self, polygon_coords: List[List[float]]) -> pv.PolyData:
    """
    Generate a random point cloud inside the polygon's bounding box.

    Test-only fallback used when PDAL fails; elevations are drawn
    uniformly from 100–120 m.

    Args:
        polygon_coords: polygon vertex list [[x, y], ...].

    Returns:
        pyvista.PolyData: 1000 random points.
    """
    print("使用模拟数据进行测试...")

    n_points = 1000

    # Bounding box of the polygon.
    xs = [c[0] for c in polygon_coords]
    ys = [c[1] for c in polygon_coords]
    x_lo, x_hi = min(xs), max(xs)
    y_lo, y_hi = min(ys), max(ys)

    # Uniform random points; z mocks terrain between 100 m and 120 m.
    x = np.random.uniform(x_lo, x_hi, n_points)
    y = np.random.uniform(y_lo, y_hi, n_points)
    z = np.random.uniform(100, 120, n_points)

    return pv.PolyData(np.column_stack((x, y, z)))
|
||||
|
||||
def create_tin_mesh(self, point_cloud: pv.PolyData) -> pv.PolyData:
    """
    Triangulate the point cloud in 2D (Delaunay) for the TIN algorithm.

    Args:
        point_cloud: input point cloud.

    Returns:
        pyvista.PolyData: the triangulated mesh.

    Raises:
        ValueError: fewer than 3 points, or triangulation failure.
    """
    if len(point_cloud.points) < 3:
        raise ValueError("点云数据不足,无法构网")

    try:
        return point_cloud.delaunay_2d()
    except Exception as e:
        raise ValueError(f"三角网构网失败: {str(e)}")
|
||||
|
||||
def calculate_volumes_by_tin(self, mesh: pv.PolyData, design_elev: float) -> Dict[str, Any]:
    """Compute cut/fill earthwork volumes from a TIN mesh.

    Per triangle: volume = triangle area x mean elevation difference of its
    three vertices relative to the design elevation. Positive contributions
    count as cut (surface above design), negative as fill.

    Args:
        mesh: triangulated surface (delaunay_2d output; triangle cells)
        design_elev: design (target) elevation

    Returns:
        Dict: cut_volume, fill_volume, net_volume (cut - fill), triangle_count

    Raises:
        ValueError: if the mesh contains no cells
    """
    points = mesh.points
    # Signed elevation difference of every vertex w.r.t. the design surface.
    elev_diff = points[:, 2] - design_elev

    cut_volume = 0.0
    fill_volume = 0.0
    triangle_count = 0

    # VTK cell layout: [n_vertices, i0, i1, i2] per cell.
    # NOTE(review): reshape(-1, 4) assumes every cell is a triangle
    # (count 3 + three indices); delaunay_2d output satisfies this, mixed
    # cell types would break the reshape — confirm if other meshes are fed in.
    cells = mesh.cells.reshape(-1, 4)

    if len(cells) == 0:
        raise ValueError("无法生成有效的三角网")

    # Walk every triangle and accumulate its signed prism volume.
    for cell in cells:
        if cell[0] == 3:  # triangle cell (VTK: leading 3 = vertex count)
            triangle_count += 1
            vertex_indices = cell[1:]
            pts = points[vertex_indices]

            # Triangle area = half the cross-product magnitude of two edges.
            v1 = pts[1] - pts[0]
            v2 = pts[2] - pts[0]
            area = 0.5 * np.linalg.norm(np.cross(v1, v2))

            # Mean vertex elevation difference approximates the prism height.
            avg_diff = np.mean(elev_diff[vertex_indices])
            vol = area * avg_diff

            if vol > 0:
                cut_volume += vol          # surface above design -> excavation
            else:
                fill_volume += abs(vol)    # surface below design -> fill

    return {
        "cut_volume": cut_volume,
        "fill_volume": fill_volume,
        "net_volume": cut_volume - fill_volume,
        "triangle_count": triangle_count
    }
|
||||
|
||||
def calculate_volumes_by_grid(self, point_cloud: pv.PolyData, design_elev: float,
                              grid_size: float = 1.0) -> Dict[str, Any]:
    """Compute cut/fill earthwork volumes with a regular grid (GRID method).

    The bounding box is divided into grid_size x grid_size cells; each cell
    contributes cell_area x (mean point elevation - design elevation).
    Cells containing no points contribute nothing.

    Args:
        point_cloud: clipped point cloud
        design_elev: design (target) elevation
        grid_size: cell edge length (same units as the coordinates)

    Returns:
        Dict: cut_volume, fill_volume, net_volume (cut - fill), grid_count

    Raises:
        ValueError: if the point cloud is empty
    """
    points = point_cloud.points

    if len(points) == 0:
        raise ValueError("点云数据为空")

    # Axis-aligned bounding box of the cloud (x/y only).
    x_min, y_min = np.min(points[:, :2], axis=0)
    x_max, y_max = np.max(points[:, :2], axis=0)

    # Cell edges; the extra grid_size in the stop value keeps the max-side
    # points inside the final cell.
    x_edges = np.arange(x_min, x_max + grid_size, grid_size)
    y_edges = np.arange(y_min, y_max + grid_size, grid_size)

    grid_count = (len(x_edges) - 1) * (len(y_edges) - 1)

    cut_volume = 0.0
    fill_volume = 0.0

    # NOTE(review): this re-scans every point for each cell
    # (O(points x cells)); fine for modest clouds, consider np.digitize
    # binning for large ones.
    for i in range(len(x_edges) - 1):
        for j in range(len(y_edges) - 1):
            # Points in the half-open cell [x_i, x_{i+1}) x [y_j, y_{j+1}).
            mask = (points[:, 0] >= x_edges[i]) & (points[:, 0] < x_edges[i+1]) & \
                   (points[:, 1] >= y_edges[j]) & (points[:, 1] < y_edges[j+1])

            grid_points = points[mask]

            if len(grid_points) > 0:
                # Mean elevation of the points inside this cell.
                avg_elevation = np.mean(grid_points[:, 2])

                # Signed height of this cell's prism.
                elev_diff = avg_elevation - design_elev

                # Prism volume = footprint area x signed height.
                cell_area = grid_size * grid_size
                vol = cell_area * elev_diff

                if vol > 0:
                    cut_volume += vol          # above design -> cut
                else:
                    fill_volume += abs(vol)    # below design -> fill

    return {
        "cut_volume": cut_volume,
        "fill_volume": fill_volume,
        "net_volume": cut_volume - fill_volume,
        "grid_count": grid_count
    }
|
||||
|
||||
def calculate_volumes_by_prism(self, point_cloud: pv.PolyData, design_elev: float,
                               influence_radius: float = 0.5) -> Dict[str, Any]:
    """Compute cut/fill earthwork volumes with the PRISM method.

    Each point is modelled as a vertical prism whose footprint is a disc of
    the given influence radius; its signed volume is the footprint area
    times the elevation difference to the design surface.

    Args:
        point_cloud: clipped point cloud
        design_elev: design (target) elevation
        influence_radius: radius (m) of each point's circular footprint

    Returns:
        Dict: cut_volume, fill_volume, net_volume (cut - fill), prism_count

    Raises:
        ValueError: if the point cloud is empty
    """
    points = point_cloud.points

    if len(points) == 0:
        raise ValueError("点云数据为空")

    cut_volume = 0.0
    fill_volume = 0.0

    # Footprint area shared by every prism.
    influence_area = np.pi * influence_radius ** 2
    prism_count = len(points)

    # Accumulate signed prism volumes: positive -> cut, negative -> fill.
    for elevation in points[:, 2]:
        signed_vol = influence_area * (elevation - design_elev)
        if signed_vol > 0:
            cut_volume += signed_vol
        else:
            fill_volume += abs(signed_vol)

    return {
        "cut_volume": cut_volume,
        "fill_volume": fill_volume,
        "net_volume": cut_volume - fill_volume,
        "prism_count": prism_count
    }
|
||||
|
||||
def calculate_statistics(self, point_cloud: pv.PolyData, mesh: pv.PolyData = None) -> Dict[str, float]:
    """Compute summary statistics for a clipped point cloud.

    Args:
        point_cloud: clipped point cloud
        mesh: optional TIN mesh; when given, its surface area is used directly
              (TIN workflow); otherwise the area is approximated from the cloud

    Returns:
        Dict: area, max_elevation, min_elevation, avg_elevation, points_count
    """
    elevations = point_cloud.points[:, 2]
    has_points = len(elevations) > 0

    stats = {
        "area": 0.0,
        "max_elevation": np.max(elevations) if has_points else 0.0,
        "min_elevation": np.min(elevations) if has_points else 0.0,
        "avg_elevation": np.mean(elevations) if has_points else 0.0,
        "points_count": len(point_cloud.points)
    }

    # Area: prefer the TIN mesh; otherwise triangulate the cloud, falling
    # back to the bounding-box area when triangulation fails.
    if mesh is not None:
        stats["area"] = mesh.area
    elif has_points:
        try:
            hull = point_cloud.delaunay_2d()
            stats["area"] = hull.area
        except Exception:
            # BUG FIX: was a bare `except:` which also swallowed
            # SystemExit/KeyboardInterrupt; narrowed to Exception.
            x_min, x_max = np.min(point_cloud.points[:, 0]), np.max(point_cloud.points[:, 0])
            y_min, y_max = np.min(point_cloud.points[:, 1]), np.max(point_cloud.points[:, 1])
            stats["area"] = (x_max - x_min) * (y_max - y_min)

    return stats
|
||||
|
||||
def calculate_earthwork(self,
                        polygon_coords: List[List[float]],
                        design_elev: float,
                        algorithm: str = EarthworkAlgorithm.TIN.value,
                        algorithm_params: Optional[Dict[str, Any]] = None,
                        crs: str = "EPSG:4326",
                        volume_accuracy: Optional[float] = None,
                        resolution: Optional[float] = None) -> EarthworkResultPointCloud:
    """Main entry point: run the full earthwork-volume pipeline.

    Steps: validate inputs -> clip the point cloud to the polygon ->
    dispatch to the chosen volume algorithm -> collect statistics,
    bounding box, accuracy and resolution -> assemble the result object.

    Args:
        polygon_coords: polygon vertices as [[x, y], ...]
        design_elev: design (target) elevation
        algorithm: one of 'grid', 'tin', 'prism' (EarthworkAlgorithm values)
        algorithm_params: algorithm-specific parameters (grid_size /
            influence_radius); defaults are filled in per algorithm
        crs: coordinate reference system of the polygon
        volume_accuracy: accuracy in (0, 1); estimated automatically if None
        resolution: computation resolution; estimated automatically if None

    Returns:
        EarthworkResultPointCloud: volumes, statistics and metadata

    Raises:
        ValueError: invalid inputs or unsupported algorithm
    """
    try:
        # 1. Validate polygon and design elevation.
        is_valid, message = self.validate_inputs(polygon_coords, design_elev)
        if not is_valid:
            raise ValueError(message)

        design_elev = float(design_elev)

        # 2. Validate the requested algorithm id.
        if algorithm not in [a.value for a in EarthworkAlgorithm]:
            raise ValueError(f"不支持的算法: {algorithm}。支持的算法: {[a.value for a in EarthworkAlgorithm]}")

        # 3. Default to an empty parameter dict.
        if algorithm_params is None:
            algorithm_params = {}

        # 4. Clip the point cloud to the polygon area.
        point_cloud = self.clip_point_cloud(polygon_coords, crs)

        # 5. Dispatch to the selected volume algorithm.
        algorithm_name = ""
        mesh = None

        if algorithm == EarthworkAlgorithm.TIN.value:
            algorithm_name = "TIN三角网法"
            mesh = self.create_tin_mesh(point_cloud)
            volumes = self.calculate_volumes_by_tin(mesh, design_elev)
            # NOTE(review): grid_size is echoed back for the TIN branch but
            # is not used by calculate_volumes_by_tin — confirm intent.
            algorithm_params = {
                "grid_size": algorithm_params.get("grid_size", 1.0)
            }

        elif algorithm == EarthworkAlgorithm.GRID.value:
            # NOTE(review): this first assignment is immediately overwritten
            # by the parameterized label below (dead store).
            algorithm_name = "GRID格网法"
            grid_size = algorithm_params.get("grid_size", 1.0)
            algorithm_name = f"GRID格网法(网格尺寸={grid_size}m)"
            volumes = self.calculate_volumes_by_grid(point_cloud, design_elev, grid_size)
            algorithm_params = {
                "grid_size": grid_size
            }

        elif algorithm == EarthworkAlgorithm.PRISM.value:
            # NOTE(review): dead store, replaced by the parameterized label.
            algorithm_name = "PRISM棱柱体法"
            influence_radius = algorithm_params.get("influence_radius", 0.5)
            algorithm_name = f"PRISM棱柱体法(影响半径={influence_radius}m)"
            volumes = self.calculate_volumes_by_prism(point_cloud, design_elev, influence_radius)
            algorithm_params = {
                "influence_radius": influence_radius
            }

        # 6. Summary statistics (area, elevations, point count).
        stats = self.calculate_statistics(point_cloud, mesh)

        # 7. Axis-aligned bounding box of the clipped cloud.
        bounding_box = self.calculate_bounding_box(point_cloud.points)

        # 8. Estimate accuracy/resolution when not supplied by the caller.
        if volume_accuracy is None:
            # Derived from the algorithm's base accuracy and point density.
            volume_accuracy = self.estimate_accuracy(algorithm, point_cloud)

        if resolution is None:
            # Derived from the average point spacing.
            resolution = self.estimate_resolution(point_cloud)

        # 9. Assemble the result object.
        result = EarthworkResultPointCloud(
            cut_volume=volumes["cut_volume"],
            fill_volume=volumes["fill_volume"],
            net_volume=volumes["net_volume"],
            area=stats["area"],
            avg_elevation=stats["avg_elevation"],
            min_elevation=stats["min_elevation"],
            max_elevation=stats["max_elevation"],
            points_count=stats["points_count"],
            triangle_count=volumes.get("triangle_count", 0),
            grid_count=volumes.get("grid_count", 0),
            prism_count=volumes.get("prism_count", 0),
            bounding_box=bounding_box,
            volume_accuracy=volume_accuracy,
            algorithm=algorithm_name,
            resolution=resolution,
            algorithm_params=algorithm_params
        )

        return result

    except Exception as e:
        # Log with traceback, then propagate to the caller.
        print(f"计算错误: {str(e)}")
        print(traceback.format_exc())
        raise
|
||||
|
||||
def estimate_accuracy(self, algorithm: str, point_cloud: pv.PolyData) -> float:
    """Estimate result accuracy from the algorithm and point-cloud density.

    Args:
        algorithm: algorithm id ('tin' / 'grid' / 'prism')
        point_cloud: clipped point cloud

    Returns:
        float: estimated accuracy, clamped to [0.7, 0.99]
    """
    # Points per unit area; 0.1 floor avoids division by zero on tiny areas.
    density = len(point_cloud.points) / max(point_cloud.area, 0.1)

    # Base accuracy per algorithm (unknown ids default to 0.90).
    base = {
        EarthworkAlgorithm.TIN.value: 0.95,
        EarthworkAlgorithm.GRID.value: 0.90,
        EarthworkAlgorithm.PRISM.value: 0.85
    }.get(algorithm, 0.90)

    # Density adjustment: dense clouds earn a capped bonus, sparse a penalty.
    if density > 10:
        adjustment = min(0.05, density * 0.002)
    elif density < 1:
        adjustment = -0.05
    else:
        adjustment = 0

    # Clamp to a sane range.
    return max(0.7, min(0.99, base + adjustment))
|
||||
|
||||
def estimate_resolution(self, point_cloud: pv.PolyData) -> float:
    """Estimate the effective resolution from the average point spacing.

    Args:
        point_cloud: clipped point cloud

    Returns:
        float: average spacing rounded to 2 decimals, or 1.0 as a fallback
        (fewer than 2 points, or zero area).
    """
    count = len(point_cloud.points)
    if count < 2:
        return 1.0

    surface = point_cloud.area
    if surface > 0:
        # Mean spacing assuming a roughly uniform distribution.
        return float(round(np.sqrt(surface / count), 2))

    return 1.0
|
||||
|
||||
def get_algorithm_info(self) -> Dict[str, Any]:
    """Describe the supported earthwork algorithms and their parameters.

    Returns:
        Dict: 'supported_algorithms' — a list of descriptors (id, name,
        description, parameter specs with defaults/bounds) — and
        'default_algorithm' (the TIN method).
    """
    return {
        "supported_algorithms": [
            {
                "id": EarthworkAlgorithm.TIN.value,
                "name": "TIN三角网法",
                "description": "通过构建不规则三角网计算土方量,精度高,适合复杂地形",
                "parameters": [
                    {
                        "name": "grid_size",
                        "type": "float",
                        "default": 1.0,
                        "description": "网格尺寸(m),用于点云预处理",
                        "required": False
                    }
                ]
            },
            {
                "id": EarthworkAlgorithm.GRID.value,
                "name": "GRID格网法",
                "description": "将区域划分为规则网格计算土方量,计算速度快,适合大规模区域",
                "parameters": [
                    {
                        "name": "grid_size",
                        "type": "float",
                        "default": 1.0,
                        "description": "网格尺寸(m)",
                        "required": True,
                        "min": 0.1,
                        "max": 10.0
                    }
                ]
            },
            {
                "id": EarthworkAlgorithm.PRISM.value,
                "name": "PRISM棱柱体法",
                "description": "将每个点视为一个棱柱体计算土方量,计算简单快速",
                "parameters": [
                    {
                        "name": "influence_radius",
                        "type": "float",
                        "default": 0.5,
                        "description": "点的影响半径(m)",
                        "required": True,
                        "min": 0.1,
                        "max": 5.0
                    }
                ]
            }
        ],
        "default_algorithm": EarthworkAlgorithm.TIN.value
    }
|
||||
|
||||
|
||||
# 使用示例
|
||||
# Usage example: exercise every algorithm against a sample point cloud.
if __name__ == "__main__":
    # Calculator bound to a sample LAZ file (falls back to mock data if the
    # file cannot be cropped — see clip_point_cloud).
    calculator = EarthworkCalculatorPointCloud("./data/sample_point_cloud.laz")

    # List the supported algorithms.
    algorithm_info = calculator.get_algorithm_info()
    print("支持的算法:")
    for algo in algorithm_info["supported_algorithms"]:
        print(f" {algo['id']}: {algo['name']} - {algo['description']}")

    # Rectangular test region (lon/lat — assumes EPSG:4326, the default CRS).
    polygon = [
        [116.3974, 39.9093],
        [116.4084, 39.9093],
        [116.4084, 39.9193],
        [116.3974, 39.9193]
    ]

    # Target design elevation in metres.
    design_elevation = 100.0

    # (algorithm id, parameters, human-readable label) triples to test.
    algorithms = [
        (EarthworkAlgorithm.TIN.value, {}, "TIN算法"),
        (EarthworkAlgorithm.GRID.value, {"grid_size": 2.0}, "GRID算法(2米网格)"),
        (EarthworkAlgorithm.PRISM.value, {"influence_radius": 1.0}, "PRISM算法(1米影响半径)")
    ]

    for algo_id, params, description in algorithms:
        print(f"\n使用{description}计算:")
        try:
            result = calculator.calculate_earthwork(
                polygon_coords=polygon,
                design_elev=design_elevation,
                algorithm=algo_id,
                algorithm_params=params
            )

            # Report the key figures from the result object.
            print(f" 挖方量: {result.cut_volume:.3f} m³")
            print(f" 填方量: {result.fill_volume:.3f} m³")
            print(f" 净方量: {result.net_volume:.3f} m³")
            print(f" 计算面积: {result.area:.3f} m²")
            print(f" 计算精度: {result.volume_accuracy:.3%}")
            print(f" 分辨率: {result.resolution:.2f} m")

        except Exception as e:
            # Keep the demo running even if one algorithm fails.
            print(f" 计算失败: {str(e)}")
|
||||
469
b3dm/glb_with_draco.py
Normal file
469
b3dm/glb_with_draco.py
Normal file
@ -0,0 +1,469 @@
|
||||
import json
|
||||
import struct
|
||||
import numpy as np
|
||||
import DracoPy
|
||||
|
||||
class DracoGLBParser:
    """Parse a GLB file whose meshes use Draco compression, via DracoPy.

    Typical usage:
        parser = DracoGLBParser(path)
        parser.parse_glb_structure()
        meshes = parser.decode_draco_meshes()
        vertices = parser.get_all_vertices()
    """

    def __init__(self, glb_file_path):
        # Path of the .glb container to parse.
        self.glb_file_path = glb_file_path
        # glTF JSON chunk (dict) once parsed.
        self.json_data = None
        # Raw BIN chunk bytes once parsed (None if the file has no BIN chunk).
        self.binary_data = None
        # Cache of decoded mesh dicts (filled by decode_draco_meshes).
        self.decoded_meshes = []

    def parse_glb_structure(self):
        """Read the GLB container: 12-byte header, JSON chunk, optional BIN chunk.

        Returns:
            self (for chaining)

        Raises:
            ValueError: not a GLB file, or an unexpected chunk type.
        """
        with open(self.glb_file_path, 'rb') as f:
            # GLB header: magic 'glTF', uint32 version, uint32 total length.
            magic = f.read(4)
            # Robustness fix: fail fast on non-GLB input instead of
            # misinterpreting arbitrary bytes as chunk lengths.
            if magic != b'glTF':
                raise ValueError(f"不是有效的 GLB 文件: {magic}")
            version = struct.unpack('<I', f.read(4))[0]
            total_length = struct.unpack('<I', f.read(4))[0]

            print("=" * 60)
            print(f"GLB 文件分析:")
            print(f" 文件类型: {magic.decode('utf-8')}")
            print(f" 版本: {version}")
            print(f" 总大小: {total_length:,} bytes")

            # JSON chunk: uint32 length + 4-byte type tag + payload.
            json_length = struct.unpack('<I', f.read(4))[0]
            json_type = f.read(4)

            if json_type != b'JSON':
                raise ValueError(f"期望 JSON chunk,但得到: {json_type}")

            self.json_data = json.loads(f.read(json_length).decode('utf-8'))

            print(f" JSON 大小: {json_length:,} bytes")

            # Optional BIN chunk with the packed buffer data.
            if f.tell() < total_length:
                bin_length = struct.unpack('<I', f.read(4))[0]
                bin_type = f.read(4)

                if bin_type != b'BIN\x00':
                    raise ValueError(f"期望 BIN chunk,但得到: {bin_type}")

                self.binary_data = f.read(bin_length)
                print(f" 二进制数据大小: {bin_length:,} bytes")

            print("=" * 60)

        return self

    def analyze_structure(self):
        """Print a structural summary: asset info, extensions, meshes, buffers.

        Returns:
            self (for chaining)
        """
        if not self.json_data:
            self.parse_glb_structure()

        print("\nGLB 结构分析:")
        print("-" * 40)

        # Asset metadata.
        asset = self.json_data.get('asset', {})
        print(f"生成器: {asset.get('generator', '未知')}")
        print(f"glTF 版本: {asset.get('version', '未知')}")

        # Draco extension flags.
        extensions_used = self.json_data.get('extensionsUsed', [])
        extensions_required = self.json_data.get('extensionsRequired', [])

        if 'KHR_draco_mesh_compression' in extensions_used:
            print("使用 Draco 压缩")
        if 'KHR_draco_mesh_compression' in extensions_required:
            print("Draco 压缩是必需的")

        # Per-mesh / per-primitive summary.
        meshes = self.json_data.get('meshes', [])
        print(f"\n网格数量: {len(meshes)}")

        for i, mesh in enumerate(meshes):
            print(f" 网格 {i}: {mesh.get('name', '未命名')}")
            primitives = mesh.get('primitives', [])
            print(f" 图元数量: {len(primitives)}")

            for j, primitive in enumerate(primitives):
                print(f" 图元 {j}:")
                if 'extensions' in primitive:
                    draco_info = primitive['extensions'].get('KHR_draco_mesh_compression')
                    if draco_info:
                        print(f" 使用 Draco 压缩")
                        print(f" 属性: {draco_info.get('attributes', {})}")

        # Buffer bookkeeping counts.
        buffers = self.json_data.get('buffers', [])
        buffer_views = self.json_data.get('bufferViews', [])
        accessors = self.json_data.get('accessors', [])

        print(f"\n缓冲区: {len(buffers)}")
        print(f"BufferViews: {len(buffer_views)}")
        print(f"访问器: {len(accessors)}")

        return self

    def decode_draco_meshes(self):
        """Decode every Draco-compressed primitive in the file.

        Returns:
            list[dict]: one entry per decoded primitive with keys
            mesh_idx, primitive_idx, name plus the decoded arrays
            (vertices, faces, texcoords, batch_ids, normals, colors).
            The result is also cached on self.decoded_meshes.
        """
        if not self.json_data:
            self.parse_glb_structure()

        meshes = []
        buffer_views = self.json_data.get('bufferViews', [])

        print("\n" + "=" * 60)
        print("开始解码 Draco 压缩数据...")
        print("=" * 60)

        for mesh_idx, mesh in enumerate(self.json_data.get('meshes', [])):
            mesh_name = mesh.get('name', f'mesh_{mesh_idx}')

            for primitive_idx, primitive in enumerate(mesh.get('primitives', [])):
                if 'extensions' in primitive:
                    draco_info = primitive['extensions'].get('KHR_draco_mesh_compression')

                    if draco_info:
                        print(f"\n解码: {mesh_name} - 图元 {primitive_idx}")

                        # Decode this primitive's Draco blob.
                        mesh_data = self._decode_primitive(draco_info, buffer_views)

                        if mesh_data:
                            meshes.append({
                                'mesh_idx': mesh_idx,
                                'primitive_idx': primitive_idx,
                                'name': mesh_name,
                                **mesh_data
                            })

        # Cache so the get_* accessors do not re-decode.
        self.decoded_meshes = meshes

        print("\n" + "=" * 60)
        print(f"解码完成!共解码 {len(meshes)} 个网格")
        print("=" * 60)

        return meshes

    def get_vertices(self, mesh_idx=0, primitive_idx=0):
        """Return the vertex array of one primitive.

        Args:
            mesh_idx: mesh index (default 0)
            primitive_idx: primitive index within the mesh (default 0)

        Returns:
            np.ndarray of shape (n, 3), or None if not found.
        """
        # Decode lazily on first access.
        if not self.decoded_meshes:
            self.decode_draco_meshes()

        for mesh in self.decoded_meshes:
            if mesh['mesh_idx'] == mesh_idx and mesh['primitive_idx'] == primitive_idx:
                return mesh.get('vertices')

        print(f"未找到网格 {mesh_idx} 的图元 {primitive_idx}")
        return None

    def get_all_vertices(self):
        """Return every decoded vertex, merged into one array.

        Returns:
            np.ndarray of shape (n, 3), or None when nothing was decoded.
        """
        # Decode lazily on first access.
        if not self.decoded_meshes:
            self.decode_draco_meshes()

        if not self.decoded_meshes:
            print("没有解码的网格数据")
            return None

        all_vertices = [m['vertices'] for m in self.decoded_meshes
                        if m.get('vertices') is not None]

        if not all_vertices:
            return None

        return np.vstack(all_vertices)

    def get_vertices_by_mesh_name(self, mesh_name):
        """Return vertex arrays of every primitive belonging to a named mesh.

        Args:
            mesh_name: mesh name as stored in the glTF JSON

        Returns:
            list[np.ndarray]: one entry per matching primitive (may be empty).
        """
        # Decode lazily on first access.
        if not self.decoded_meshes:
            self.decode_draco_meshes()

        return [m['vertices'] for m in self.decoded_meshes
                if m['name'] == mesh_name and m.get('vertices') is not None]

    def get_vertex_count(self):
        """Return the total number of decoded vertices across all meshes."""
        vertices = self.get_all_vertices()
        return len(vertices) if vertices is not None else 0

    def _decode_primitive(self, draco_info, buffer_views):
        """Decode a single primitive's Draco blob.

        Args:
            draco_info: the primitive's KHR_draco_mesh_compression extension
                dict (bufferView index + attribute id mapping)
            buffer_views: the file's bufferViews list

        Returns:
            dict from _parse_draco_result, or None on decode failure.
        """
        try:
            # Locate the compressed blob inside the BIN chunk.
            buffer_view_idx = draco_info['bufferView']
            attributes = draco_info['attributes']

            buffer_view = buffer_views[buffer_view_idx]
            byte_offset = buffer_view.get('byteOffset', 0)
            byte_length = buffer_view['byteLength']

            print(f" BufferView: {buffer_view_idx}")
            print(f" 属性映射: {attributes}")
            print(f" 数据位置: offset={byte_offset}, length={byte_length}")

            draco_data = self.binary_data[byte_offset:byte_offset + byte_length]
            print(f" Draco 数据大小: {len(draco_data):,} bytes")

            # Hand the raw bytes to DracoPy.
            print(" 正在使用 DracoPy 解码...")
            draco_decoder = DracoPy.decode(draco_data)

            return self._parse_draco_result(draco_decoder, attributes)

        except Exception as e:
            # Best-effort: report and skip this primitive.
            print(f" 解码失败: {e}")
            import traceback
            traceback.print_exc()
            return None

    def _parse_draco_result(self, draco_decoder, attributes):
        """Map a DracoPy decode result onto a uniform mesh dict.

        Args:
            draco_decoder: object returned by DracoPy.decode
            attributes: glTF attribute-name -> Draco attribute-id mapping

        Returns:
            dict with vertices, faces, texcoords, batch_ids, normals, colors
            (unavailable entries stay None).
        """
        result = {
            'vertices': None,
            'faces': None,
            'texcoords': None,
            'batch_ids': None,
            'normals': None,
            'colors': None
        }

        # Vertex positions.
        if hasattr(draco_decoder, 'points'):
            result['vertices'] = np.array(draco_decoder.points, dtype=np.float32)
            print(f" 顶点数量: {len(result['vertices'])}")

        # Triangle indices — accept nested lists or a flat index array.
        if hasattr(draco_decoder, 'faces'):
            faces_data = draco_decoder.faces

            if len(faces_data) > 0:
                if isinstance(faces_data[0], list) or isinstance(faces_data[0], tuple):
                    result['faces'] = np.array(faces_data, dtype=np.uint32)
                else:
                    result['faces'] = np.array(faces_data, dtype=np.uint32).reshape(-1, 3)

            print(f" 面数量: {len(result['faces']) if result['faces'] is not None else 0}")

        # Extra vertex attributes, resolved through the glTF attribute map.
        if hasattr(draco_decoder, 'attributes'):
            attrs = draco_decoder.attributes

            for gltf_attr_name, draco_attr_id in attributes.items():
                if draco_attr_id in attrs:
                    attr_data = attrs[draco_attr_id]

                    if gltf_attr_name == 'POSITION':
                        result['vertices'] = np.array(attr_data, dtype=np.float32)
                    elif gltf_attr_name == 'TEXCOORD_0':
                        result['texcoords'] = np.array(attr_data, dtype=np.float32)
                    elif gltf_attr_name == '_BATCHID':
                        result['batch_ids'] = np.array(attr_data, dtype=np.uint32)
                    elif gltf_attr_name == 'NORMAL':
                        result['normals'] = np.array(attr_data, dtype=np.float32)
                    elif gltf_attr_name == 'COLOR_0':
                        result['colors'] = np.array(attr_data, dtype=np.float32)

                    print(f" 已提取属性: {gltf_attr_name} (ID: {draco_attr_id})")

        # Last-resort vertex fetch for DracoPy variants exposing get_points().
        if result['vertices'] is None and hasattr(draco_decoder, 'get_points'):
            try:
                result['vertices'] = np.array(draco_decoder.get_points(), dtype=np.float32)
            except Exception:
                # BUG FIX: was a bare `except:`; narrowed to Exception so
                # SystemExit/KeyboardInterrupt are not swallowed.
                pass

        return result

    def save_decoded_meshes(self, meshes, output_format='obj'):
        """Write decoded meshes to <basename>_decoded/ in the given format.

        Args:
            meshes: list returned by decode_draco_meshes
            output_format: 'obj', 'ply' or 'npz'
        """
        import os

        base_name = os.path.splitext(os.path.basename(self.glb_file_path))[0]
        output_dir = f"{base_name}_decoded"

        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        for mesh in meshes:
            filename = f"{output_dir}/{mesh['name']}_p{mesh['primitive_idx']}.{output_format}"

            if output_format == 'obj':
                self._save_as_obj(mesh, filename)
            elif output_format == 'ply':
                self._save_as_ply(mesh, filename)
            elif output_format == 'npz':
                self._save_as_npz(mesh, filename)
            else:
                print(f"不支持的格式: {output_format}")
                continue

            # BUG FIX: the message printed a literal "(unknown)" placeholder
            # instead of the file that was just written.
            print(f"已保存: {filename}")

    def _save_as_obj(self, mesh, filename):
        """Write one decoded mesh as a Wavefront OBJ file."""
        with open(filename, 'w') as f:
            # Vertices.
            if mesh['vertices'] is not None:
                for v in mesh['vertices']:
                    f.write(f"v {v[0]} {v[1]} {v[2]}\n")

            # Texture coordinates.
            if mesh['texcoords'] is not None:
                for uv in mesh['texcoords']:
                    f.write(f"vt {uv[0]} {uv[1]}\n")

            # Normals.
            if mesh['normals'] is not None:
                for n in mesh['normals']:
                    f.write(f"vn {n[0]} {n[1]} {n[2]}\n")

            # Faces — OBJ indices are 1-based.
            if mesh['faces'] is not None:
                for face in mesh['faces']:
                    face_indices = [str(idx + 1) for idx in face]
                    f.write(f"f {' '.join(face_indices)}\n")

    def _save_as_ply(self, mesh, filename):
        """Write one decoded mesh as a binary PLY file (requires plyfile)."""
        from plyfile import PlyData, PlyElement

        vertices = mesh['vertices']
        faces = mesh['faces']

        if vertices is None:
            return

        # Structured vertex records.
        vertex_data = np.zeros(len(vertices), dtype=[
            ('x', 'f4'), ('y', 'f4'), ('z', 'f4')
        ])

        vertex_data['x'] = vertices[:, 0]
        vertex_data['y'] = vertices[:, 1]
        vertex_data['z'] = vertices[:, 2]

        # Structured face records (triangles).
        if faces is not None:
            face_data = np.zeros(len(faces), dtype=[('vertex_indices', 'i4', (3,))])
            face_data['vertex_indices'] = faces

        vertex_element = PlyElement.describe(vertex_data, 'vertex')

        if faces is not None:
            face_element = PlyElement.describe(face_data, 'face')
            PlyData([vertex_element, face_element], text=False).write(filename)
        else:
            PlyData([vertex_element], text=False).write(filename)

    def _save_as_npz(self, mesh, filename):
        """Write one decoded mesh as a NumPy .npz archive."""
        np.savez(
            filename,
            vertices=mesh['vertices'],
            faces=mesh['faces'],
            texcoords=mesh['texcoords'],
            batch_ids=mesh['batch_ids'],
            normals=mesh['normals'],
            colors=mesh['colors']
        )
|
||||
|
||||
# 使用示例
|
||||
def main():
    """Demo: parse a Draco-compressed GLB, decode it and dump the meshes."""
    # NOTE(review): hard-coded local Windows path — adjust for your environment.
    parser = DracoGLBParser(r"D:\devForBdzlWork\ai_project_v1\b3dm\test\temp_glb\temp_6e895637.glb")

    # Read the GLB container (header + JSON + optional BIN chunk).
    parser.parse_glb_structure()

    # Print a structural summary (extensions, meshes, buffers).
    parser.analyze_structure()

    # Decode every Draco-compressed primitive.
    meshes = parser.decode_draco_meshes()

    # Demonstrate the vertex accessor methods.
    print("\n" + "=" * 60)
    print("顶点获取方法演示:")
    print("=" * 60)

    # 1. Vertices of the first primitive of the first mesh.
    vertices = parser.get_vertices(mesh_idx=0, primitive_idx=0)
    if vertices is not None:
        print(f"1. 获取第一个网格顶点:")
        print(f" 形状: {vertices.shape}")
        print(f" 数据类型: {vertices.dtype}")
        print(f" 前5个顶点: \n{vertices[:5]}")

    # 2. All vertices merged into a single array.
    all_vertices = parser.get_all_vertices()
    if all_vertices is not None:
        print(f"\n2. 获取所有网格顶点(合并):")
        print(f" 总顶点数: {len(all_vertices)}")
        print(f" 形状: {all_vertices.shape}")

    # 3. Total vertex count.
    total_vertices = parser.get_vertex_count()
    print(f"\n3. 总顶点数: {total_vertices}")

    # 4. Look up vertices by mesh name.
    if meshes:
        mesh_name = meshes[0]['name']
        vertices_list = parser.get_vertices_by_mesh_name(mesh_name)
        print(f"\n4. 根据名称 '{mesh_name}' 获取的顶点:")
        for i, verts in enumerate(vertices_list):
            print(f" 图元 {i}: {verts.shape if verts is not None else 'None'}")

    # Dump decoded meshes as OBJ files.
    parser.save_decoded_meshes(meshes, output_format='obj')
|
||||
|
||||
# Script entry point: run the demo only when executed directly.
if __name__ == "__main__":
    main()
|
||||
380
b3dm/slope_aspect_img.py
Normal file
380
b3dm/slope_aspect_img.py
Normal file
@ -0,0 +1,380 @@
|
||||
# slope_aspect_img13
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from scipy.ndimage import sobel
|
||||
import matplotlib as mpl
|
||||
import rasterio
|
||||
from osgeo import gdal
|
||||
from matplotlib.patches import FancyArrowPatch
|
||||
from mpl_toolkits.mplot3d import proj3d
|
||||
import os
|
||||
|
||||
# 自定义3D箭头类
|
||||
class Arrow3D(FancyArrowPatch):
    """A FancyArrowPatch whose endpoints live in 3D axes coordinates.

    The 3D endpoints are stored at construction time and projected into 2D
    display coordinates each time the axes are drawn.
    """

    def __init__(self, xs, ys, zs, *args, **kwargs):
        # Start with dummy 2D positions; the real ones are set at draw time
        # by do_3d_projection.
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def do_3d_projection(self, renderer=None):
        """Project the stored 3D endpoints and return the minimum depth."""
        x3, y3, z3 = self._verts3d
        x2, y2, z2 = proj3d.proj_transform(x3, y3, z3, self.axes.M)
        self.set_positions((x2[0], y2[0]), (x2[1], y2[1]))
        # Depth value used by matplotlib for draw-order sorting.
        return min(z2)
|
||||
|
||||
def read_dem_rasterio(filepath):
    """Read a DEM raster with rasterio.

    Args:
        filepath: path to the DEM file (e.g. GeoTIFF)

    Returns:
        tuple: (X, Y, Z) — coordinate meshgrids and the elevation array —
        or (None, None, None) when reading fails.
    """
    try:
        with rasterio.open(filepath) as src:
            dem_data = src.read(1)  # first band = elevation
            transform = src.transform
            bounds = src.bounds
            crs = src.crs

            print(f"DEM信息:")
            print(f" 尺寸: {dem_data.shape}")
            print(f" 范围: {bounds}")
            print(f" 坐标系: {crs}")
            print(f" 高程范围: {np.nanmin(dem_data):.2f} - {np.nanmax(dem_data):.2f}")

            # Fill NaN holes with the mean elevation so downstream math is safe.
            if np.isnan(dem_data).any():
                print(" 检测到NaN值")
                dem_data = np.nan_to_num(dem_data, nan=np.nanmean(dem_data))

            # Coordinate meshgrids matching the raster extent.
            rows, cols = dem_data.shape
            x = np.linspace(bounds.left, bounds.right, cols)
            y = np.linspace(bounds.bottom, bounds.top, rows)
            X, Y = np.meshgrid(x, y)

            return X, Y, dem_data

    except Exception as e:
        print(f"使用rasterio读取失败: {e}")
        # BUG FIX: previously returned a bare None, which crashed callers
        # doing `X, Y, Z = read_dem_rasterio(...)` with a TypeError before
        # their `if X is None` check could run. Return a 3-tuple instead.
        return None, None, None
|
||||
|
||||
def read_slope_aspect_by_dem(dem_path, overall_3d_output_path=None) :
    """Read a DEM GeoTIFF, derive slope/aspect, and render a 3D overlook PNG.

    Parameters
    ----------
    dem_path : str
        Path to the DEM .tif file. NOTE(review): this file is DELETED
        (``os.remove``) after the figure is saved — callers must treat it
        as a consumable temp artefact; confirm this is intended for all
        call sites.
    overall_3d_output_path : str, optional
        Where to save the rendered PNG. Defaults to ``dem_path`` with the
        suffix ``_slopeAspect_3D_overlook.png``.

    Returns
    -------
    str
        Path of the saved PNG.

    Raises
    ------
    FileNotFoundError
        If the DEM cannot be read.
    ValueError
        If the DEM grid is empty or entirely zero.
    """
    # ---------------------- Core: register a CJK-capable font so Chinese labels render ----------------------
    setup_chinese_font()

    # Load the DEM grid: X/Y coordinate meshes plus elevation Z.
    print("正在读取DEM文件...")
    X, Y, Z = read_dem_rasterio(dem_path)

    if X is None:
        raise FileNotFoundError(f"无法读取DEM文件: {dem_path}")

    # Sanity-check the elevation data.
    if Z.size == 0 or np.all(Z == 0):
        raise ValueError("DEM数据无效或全部为0")

    # Downsample large grids to at most ~200x200 to keep plotting responsive.
    if Z.shape[0] > 200 or Z.shape[1] > 200:
        print(f"原始DEM尺寸较大 ({Z.shape}),进行重采样...")
        from scipy.ndimage import zoom
        scale_factor = min(200/Z.shape[0], 200/Z.shape[1])
        Z = zoom(Z, scale_factor, order=1)
        X = zoom(X, scale_factor, order=1)
        Y = zoom(Y, scale_factor, order=1)
        print(f"重采样后尺寸: {Z.shape}")

    # ---------------------- 2. Slope and aspect ----------------------
    print("计算坡度和坡向...")

    # Approximate pixel size in metres (1 degree ~= 111 km).
    # Assumes the DEM is in geographic lon/lat coordinates — TODO confirm.
    dx_pixel = np.abs(X[0,1] - X[0,0]) * 111000
    dy_pixel = np.abs(Y[1,0] - Y[0,0]) * 111000

    # Sobel-filtered elevation gradients (rise over horizontal run).
    dx = sobel(Z, axis=1) / (2 * dx_pixel)
    dy = sobel(Z, axis=0) / (2 * dy_pixel)

    slope_rad = np.arctan(np.sqrt(dx**2 + dy**2))
    slope_deg = slope_rad * (180 / np.pi)

    # Aspect: bearing of the gradient, normalised into [0, 360).
    aspect_rad = np.arctan2(-dx, dy)
    aspect_deg = aspect_rad * (180 / np.pi)
    aspect_deg[aspect_deg < 0] += 360

    print(f"坡度范围: {np.min(slope_deg):.2f}° - {np.max(slope_deg):.2f}°")
    print(f"坡向范围: {np.min(aspect_deg):.2f}° - {np.max(aspect_deg):.2f}°")

    # ---------------------- 3. Overall (dominant) aspect ----------------------
    print("\n计算整体坡向...")

    def calculate_overall_aspect(aspect_deg, slope_deg, method='weighted_mean'):
        """Aggregate per-pixel aspects into one dominant direction.

        Returns (aspect_deg, strength, method_label).
        NOTE(review): implicitly returns None for any other ``method``.
        """
        aspect_rad = np.deg2rad(aspect_deg)

        if method == 'weighted_mean':
            # Slope-weighted circular mean: steeper pixels count for more.
            u = np.sin(aspect_rad)
            v = np.cos(aspect_rad)

            weights = slope_deg.flatten()
            weighted_u = np.nansum(u.flatten() * weights) / np.nansum(weights)
            weighted_v = np.nansum(v.flatten() * weights) / np.nansum(weights)

            weighted_aspect_rad = np.arctan2(weighted_u, weighted_v)
            weighted_aspect_deg = np.rad2deg(weighted_aspect_rad)

            if weighted_aspect_deg < 0:
                weighted_aspect_deg += 360

            # Resultant length in [0, 1]: how consistent the aspects are.
            weighted_strength = np.sqrt(weighted_u**2 + weighted_v**2)
            return weighted_aspect_deg, weighted_strength, "坡度加权平均法"

        elif method == 'vector_mean':
            # Unweighted circular (vector) mean over all pixels.
            u = np.sin(aspect_rad)
            v = np.cos(aspect_rad)

            mean_u = np.nanmean(u)
            mean_v = np.nanmean(v)

            mean_aspect_rad = np.arctan2(mean_u, mean_v)
            mean_aspect_deg = np.rad2deg(mean_aspect_rad)

            if mean_aspect_deg < 0:
                mean_aspect_deg += 360

            vector_strength = np.sqrt(mean_u**2 + mean_v**2)
            return mean_aspect_deg, vector_strength, "向量平均法"

    # Use the slope-weighted mean as the overall aspect.
    overall_aspect, overall_strength, method_name = calculate_overall_aspect(aspect_deg, slope_deg, 'weighted_mean')

    # Translate the overall aspect into an eight-sector compass label.
    if overall_aspect < 22.5 or overall_aspect >= 337.5:
        overall_direction = "北"
    elif 22.5 <= overall_aspect < 67.5:
        overall_direction = "东北"
    elif 67.5 <= overall_aspect < 112.5:
        overall_direction = "东"
    elif 112.5 <= overall_aspect < 157.5:
        overall_direction = "东南"
    elif 157.5 <= overall_aspect < 202.5:
        overall_direction = "南"
    elif 202.5 <= overall_aspect < 247.5:
        overall_direction = "西南"
    elif 247.5 <= overall_aspect < 292.5:
        overall_direction = "西"
    else:
        overall_direction = "西北"

    print(f"整体坡向 ({method_name}):")
    print(f" 角度: {overall_aspect:.1f}°")
    print(f" 方向: {overall_direction}")
    # print(f" 一致性: {overall_strength:.3f}")

    # ---------------------- 5. 3D visualisation (overlook view with overall-aspect arrow) ----------------------
    print("\n生成3D俯视图可视化...")
    fig3d, ax3d = plt.subplots(figsize=(16, 12), subplot_kw={"projection": "3d"})

    # Colour scale clipped to the 5th–95th slope percentile for contrast.
    norm = mpl.colors.Normalize(vmin=np.percentile(slope_deg, 5),
                                vmax=np.percentile(slope_deg, 95))

    # Subsample the grid for plotting so the surface stays light.
    plot_skip = max(2, Z.shape[0] // 60)
    X_plot = X[::plot_skip, ::plot_skip]
    Y_plot = Y[::plot_skip, ::plot_skip]
    Z_plot = Z[::plot_skip, ::plot_skip]
    slope_plot = slope_deg[::plot_skip, ::plot_skip]

    surf = ax3d.plot_surface(
        X_plot, Y_plot, Z_plot,
        cmap="viridis_r",
        alpha=0.85,  # slightly transparent so the overlay arrow stays visible
        linewidth=0.1,  # hairline mesh lines
        facecolors=plt.cm.viridis_r(norm(slope_plot)),
        zorder=1
    )

    # ---------------------- Overall-aspect arrow (red, thick, centred) ----------------------
    center_x = (X.min() + X.max()) / 2
    center_y = (Y.min() + Y.max()) / 2

    # Arrow length proportional to the smaller horizontal extent.
    arrow_length_overall = 0.15 * min(X.max()-X.min(), Y.max()-Y.min())

    # Arrow direction from the overall aspect (x = sin, y = cos of bearing).
    scale_factor = 1.5
    overall_aspect_rad = np.deg2rad(overall_aspect)
    dx_overall = np.sin(overall_aspect_rad) * arrow_length_overall * scale_factor
    dy_overall = np.cos(overall_aspect_rad) * arrow_length_overall * scale_factor

    # Float the arrow 20% above the terrain's highest point so it is not buried.
    terrain_max_z = Z.max()
    float_height = 0.2 * terrain_max_z

    # NOTE(review): Arrow3D is not defined in this portion of the file —
    # presumably a module-level FancyArrowPatch subclass; confirm.
    arrow_overall = Arrow3D(
        [center_x, center_x + dx_overall],
        [center_y, center_y + dy_overall],
        [terrain_max_z + float_height, terrain_max_z + float_height],  # held above terrain
        mutation_scale=25,  # slightly larger arrow head
        lw=5,  # thick shaft
        arrowstyle='-|>',
        color='red',
        alpha=0.98,
        zorder=20  # draw above the surface
    )
    ax3d.add_artist(arrow_overall)

    # Mark the arrow's origin.
    ax3d.scatter(center_x, center_y, terrain_max_z + float_height,
                 color='red', s=120, edgecolor='white', linewidth=2, zorder=21)

    # ---------------------- Overall-aspect info panel, top-left corner ----------------------
    panel_x = X.min() + 0.02 * (X.max() - X.min())  # 2% margin from the left
    panel_y = Y.max() - 0.02 * (Y.max() - Y.min())  # 2% margin from the top
    panel_z = Z.max() + 0.5 * (Z.max() - Z.min())   # raised so it stays in view

    overall_info = (
        f"整体坡向分析\n"
        f"角度: {overall_aspect:.1f}°\n"
        f"方向: {overall_direction}"
        # f"一致性: {overall_strength:.3f}"
    )

    ax3d.text(panel_x, panel_y, panel_z,
              overall_info,
              fontsize=12, fontweight='bold',
              bbox=dict(facecolor='white', alpha=0.5, boxstyle="round,pad=0.5",
                        edgecolor='red', linewidth=2.5),
              ha='left', va='top', zorder=30)

    # ---------------------- Axes, title, colour bar ----------------------
    ax3d.set_xlabel("经度 (X)", fontsize=12)
    ax3d.set_ylabel("纬度 (Y)", fontsize=12)
    ax3d.set_zlabel("高程 (m)", fontsize=12)

    # File name for the title.
    filename = os.path.basename(dem_path)
    # NOTE(review): `filename` is never used and the title contains the
    # literal text "(unknown)" — an f-string placeholder (likely
    # {filename}) appears to have been lost; confirm and restore.
    ax3d.set_title(f"DEM三维俯视图 - 坡向分析\n文件: (unknown)",
                   fontsize=14, fontweight='bold', pad=20)

    cbar = plt.colorbar(
        mpl.cm.ScalarMappable(norm=norm, cmap='viridis_r'),
        ax=ax3d, shrink=0.6, aspect=25, pad=0.15
    )
    cbar.set_label("坡度 (°)", fontsize=12)

    # ---------------------- Near-vertical overlook viewpoint ----------------------
    view_elev = 85  # close to straight down
    view_azim = overall_aspect + 180  # look from opposite the overall aspect

    # Keep the azimuth within [0, 360).
    view_azim = view_azim % 360

    print(f"\n设置3D俯视图视角:")
    print(f" 整体坡向: {overall_aspect:.1f}° ({overall_direction})")
    print(f" 视角方位角: {view_azim:.1f}°")
    print(f" 视角仰角: {view_elev:.1f}°")

    ax3d.view_init(elev=view_elev, azim=view_azim)

    # Pull the camera back for a wider view.
    # NOTE(review): Axes3D.dist is deprecated/removed in newer matplotlib,
    # where this assignment becomes a silent no-op — confirm target version.
    ax3d.dist = 9.0

    # Expand the Z range so the floating arrow/panel are not clipped.
    z_min, z_max = Z.min(), Z.max()
    z_padding = 0.6 * (z_max - z_min)
    ax3d.set_zlim(z_min - 0.1*z_padding, z_max + z_padding)

    # Pad the XY extents by 10%.
    x_margin = 0.1 * (X.max() - X.min())
    y_margin = 0.1 * (Y.max() - Y.min())
    ax3d.set_xlim(X.min() - x_margin, X.max() + x_margin)
    ax3d.set_ylim(Y.min() - y_margin, Y.max() + y_margin)

    plt.tight_layout()

    # Save the figure (default name derives from the DEM path).
    if not overall_3d_output_path :
        overall_3d_output_path = dem_path.replace('.tif', '_slopeAspect_3D_overlook.png')
    plt.savefig(overall_3d_output_path, dpi=250, bbox_inches='tight', facecolor='white')
    # plt.show()

    # The DEM is consumed: it is removed once the figure has been rendered.
    os.remove(dem_path)
    print(f"\n3D俯视图已保存: {overall_3d_output_path}")
    print("分析完成!")
    print("="*60)

    return overall_3d_output_path
|
||||
|
||||
# ---------------------- Cross-platform Chinese font configuration ----------------------
def setup_chinese_font():
    """Configure matplotlib so Chinese text renders on Windows/Linux/macOS.

    Registers the first known CJK font file found on disk for the current
    OS; if none exists, prepends well-known font family names as a
    fallback. Always forces the sans-serif family and correct minus signs.
    """
    import platform

    plt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly

    system = platform.system()

    # Candidate font files per OS, in preference order.
    candidates_by_os = {
        'Windows': [
            'C:/Windows/Fonts/simhei.ttf',           # SimHei
            'C:/Windows/Fonts/simkai.ttf',           # KaiTi
            'C:/Windows/Fonts/simsun.ttc',           # SimSun
            'C:/Windows/Fonts/microsoftyahei.ttf',   # Microsoft YaHei
        ],
        'Darwin': [
            '/System/Library/Fonts/PingFang.ttc',        # PingFang
            '/System/Library/Fonts/STHeiti Light.ttc',   # STHeiti
            '/System/Library/Fonts/STHeiti Medium.ttc',
            '/Library/Fonts/Arial Unicode.ttf',          # Arial Unicode
        ],
        'Linux': [
            '/usr/share/fonts/truetype/wqy/wqy-microhei.ttc',          # WenQuanYi Micro Hei
            '/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc',            # WenQuanYi Zen Hei
            '/usr/share/fonts/truetype/arphic/uming.ttc',              # AR PL UMing
            '/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc',  # Noto
        ],
    }
    font_paths = candidates_by_os.get(system, [])

    # Register the first candidate that exists on disk.
    font_added = False
    for font_path in font_paths:
        try:
            if os.path.exists(font_path):
                font_prop = mpl.font_manager.FontProperties(fname=font_path)
                font_name = font_prop.get_name()
                mpl.font_manager.fontManager.addfont(font_path)
                plt.rcParams['font.sans-serif'] = [font_name] + plt.rcParams['font.sans-serif']
                font_added = True
                print(f"已添加中文字体: {font_name} ({font_path})")
                break
        except Exception as e:
            print(f"添加字体 {font_path} 失败: {e}")

    # Fall back to family names matplotlib may already know about.
    if not font_added:
        if system == 'Windows':
            fallback_fonts = ['SimHei', 'Microsoft YaHei', 'KaiTi', 'FangSong']
        elif system == 'Darwin':
            fallback_fonts = ['PingFang SC', 'Hiragino Sans GB', 'Apple LiGothic Medium']
        else:  # Linux and others
            fallback_fonts = ['DejaVu Sans', 'WenQuanYi Micro Hei',
                              'Noto Sans CJK SC', 'Heiti TC', 'AR PL UMing CN']

        current_fonts = plt.rcParams.get('font.sans-serif', [])
        plt.rcParams['font.sans-serif'] = fallback_fonts + current_fonts
        print(f"使用备选字体方案: {fallback_fonts[:2]}...")

    # Make the sans-serif list take effect.
    plt.rcParams['font.family'] = 'sans-serif'
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Ad-hoc manual test entry point (developer-local Windows path).
    dem_path = r'D:/devForBdzlWork/ai_project_v1/b3dm/o_dem_f1cb6f69_slopeAspect.tif'
    read_slope_aspect_by_dem(dem_path)
|
||||
1062
b3dm/slope_aspect_tif.py
Normal file
1062
b3dm/slope_aspect_tif.py
Normal file
File diff suppressed because it is too large
Load Diff
419
b3dm/terrain_api.py
Normal file
419
b3dm/terrain_api.py
Normal file
@ -0,0 +1,419 @@
|
||||
from sanic import Blueprint, Request, json
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing import List, Optional, Dict, Any
|
||||
import numpy as np
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
import urllib.parse
|
||||
import threading
|
||||
import os
|
||||
from b3dm.terrain_calculator import TerrainCalculator
|
||||
|
||||
# Sanic blueprint hosting all terrain-calculation endpoints (no URL prefix).
terrain_bp = Blueprint("terrain", url_prefix="")
# MinIO sub-path (object-key prefix) where generated slope/aspect artefacts are uploaded.
MINIO_SUB_PATH = "slopeAspectPng"

# Configure logging.
# NOTE(review): basicConfig at import time configures the ROOT logger for the
# whole process — confirm this is intended rather than app-level log setup.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
||||
# 请求模型
|
||||
class NormalVector(BaseModel):
    """Surface normal vector [nx, ny, nz] supplied by API clients."""
    nx: float = Field(..., description="法向量X分量")
    ny: float = Field(..., description="法向量Y分量")
    nz: float = Field(..., description="法向量Z分量")

    @field_validator('nx', 'ny', 'nz')
    def check_finite(cls, v):
        # Reject NaN/inf early so downstream math never sees them.
        if not np.isfinite(v):
            raise ValueError(f"值必须是有限数字,得到: {v}")
        return v

    def to_list(self):
        # Component order matches what TerrainCalculator expects: [nx, ny, nz].
        return [self.nx, self.ny, self.nz]
|
||||
|
||||
class BatchRequest(BaseModel):
    """Batch request: a list of raw [nx, ny, nz] normal vectors."""
    vectors: List[List[float]] = Field(..., description="法向量列表")

    @field_validator('vectors')
    def validate_vectors(cls, v):
        # Cap batch size and require 3-component, all-numeric vectors.
        if len(v) > 1000:
            raise ValueError("批量处理最多支持1000个向量")
        for vec in v:
            if len(vec) != 3:
                raise ValueError("每个向量必须是长度为3的列表")
            if not all(isinstance(x, (int, float)) for x in vec):
                raise ValueError("向量元素必须是数字")
        return v
|
||||
|
||||
class PointItem(BaseModel):
    """A single 3D point (region-boundary vertex)."""
    x: float = Field(..., description="x坐标")
    y: float = Field(..., description="y坐标")
    z: float = Field(..., description="z坐标")
|
||||
class PointRequest(BaseModel):
    """Region request: boundary points plus the source 3dtiles URL."""
    points: List[PointItem] = Field(..., description="点列表")
    url: str = Field(..., description="URL地址")

    @field_validator('points')
    def validate_points_count(cls, v):
        # Keep region outlines small; DEM generation is expensive.
        if len(v) > 10:
            raise ValueError("批量处理最多支持10个点")
        return v
|
||||
|
||||
class PreloadRequest(BaseModel):
    """Request to pre-download a 3dtiles tileset."""
    url: str = Field(..., description="URL地址")
|
||||
|
||||
class AnalysisConfig(BaseModel):
    """Analysis options. NOTE(review): appears unused by the endpoints
    visible in this module — confirm whether it is still needed."""
    classify: bool = Field(default=True, description="是否进行分类")
    include_percent: bool = Field(default=True, description="是否包含坡度百分比")
    include_direction: bool = Field(default=True, description="是否包含方向描述")
|
||||
# 中间件:请求计时
|
||||
# Middleware: stamp each request so the response middleware can compute latency.
@terrain_bp.middleware("request")
async def add_start_time(request: Request):
    request.ctx.start_time = time.time()
|
||||
|
||||
@terrain_bp.middleware("response")
async def add_response_time(request: Request, response):
    """Attach an X-Process-Time header reporting the request latency in ms."""
    started = getattr(request.ctx, "start_time", None)
    if started is not None:
        elapsed_ms = (time.time() - started) * 1000
        response.headers["X-Process-Time"] = f"{elapsed_ms:.2f}ms"
|
||||
|
||||
@terrain_bp.post("/api/v1/calculate/slope")
async def calculate_slope(request: Request):
    """Compute the slope for one surface normal vector.

    Body must match NormalVector; returns the calculator result wrapped in
    a success envelope, or an error payload with HTTP 400/500.
    """
    try:
        payload = request.json
        if not payload:
            return json({"error": "请求体不能为空"}, status=400)

        # Validate the request body against the NormalVector model.
        try:
            normal = NormalVector(**payload)
        except Exception as exc:
            return json({"error": f"输入验证失败: {str(exc)}"}, status=400)

        outcome = TerrainCalculator.calculate_slope(normal.to_list())

        # A populated "error" key signals a domain-level failure.
        if outcome.get("error"):
            return json(outcome, status=400)

        return json({
            "success": True,
            "data": outcome,
            "request": {
                "input_vector": normal.to_list(),
                "timestamp": time.time()
            }
        })

    except Exception as exc:
        logger.error(f"坡度计算API错误: {exc}")
        return json({
            "error": f"服务器内部错误: {str(exc)}",
            "success": False
        }, status=500)
|
||||
|
||||
@terrain_bp.post("/api/v1/calculate/aspect")
async def calculate_aspect1(request: Request):
    """Compute the aspect for one surface normal vector.

    Body must match NormalVector; returns the calculator result wrapped in
    a success envelope, or an error payload with HTTP 400/500.
    """
    try:
        payload = request.json
        if not payload:
            return json({"error": "请求体不能为空"}, status=400)

        # Validate the request body against the NormalVector model.
        try:
            normal = NormalVector(**payload)
        except Exception as exc:
            return json({"error": f"输入验证失败: {str(exc)}"}, status=400)

        outcome = TerrainCalculator.calculate_aspect(normal.to_list())

        # A populated "error" key signals a domain-level failure.
        if outcome.get("error"):
            return json(outcome, status=400)

        return json({
            "success": True,
            "data": outcome,
            "request": {
                "input_vector": normal.to_list(),
                "timestamp": time.time()
            }
        })

    except Exception as exc:
        logger.error(f"坡向计算API错误: {exc}")
        return json({
            "error": f"服务器内部错误: {str(exc)}",
            "success": False
        }, status=500)
|
||||
|
||||
@terrain_bp.post("/api/v1/calculate/preload3dTiles")
async def preload_3dtiles(request: Request):
    """Kick off a background download of 3dtiles map data.

    Returns immediately with the local directory the tileset will be
    cached into; the download itself runs on a non-daemon worker thread.
    """
    try:
        data = request.json
        if not data:
            return json({"error": "请求体不能为空"}, status=400)

        # Validate the request body.
        try:
            body = PreloadRequest(**data)
        except Exception as e:
            return json({"error": f"输入验证失败: {str(e)}"}, status=400)

        script_dir = os.path.dirname(os.path.abspath(__file__))
        # Fire-and-forget: the download happens off the event loop.
        thread = threading.Thread(target=TerrainCalculator.preload_3dtiles, args=(body.url,))
        thread.start()

        # (Removed a dead `url_prefix = extract_and_rebuild_url(...)` local
        # that was computed but never used by this endpoint.)
        return json({
            "success": True,
            "data": f"{script_dir}/data_3dtiles",
            "request": {
                "input_vector": body.model_dump(),
                "timestamp": time.time()
            }
        })

    except Exception as e:
        # Log label fixed: was a copy-paste of the aspect endpoint's message.
        logger.error(f"预加载3dTiles API错误: {e}")
        return json({
            "error": f"服务器内部错误: {str(e)}",
            "success": False
        }, status=500)
|
||||
|
||||
@terrain_bp.post("/api/v1/calculate/slopeAspect")
async def calculate_slopeAspect(request: Request):
    """Generate a slope/aspect 3D overlook image for a region.

    Validates the region points + tileset URL, starts the rendering on a
    background thread, and immediately returns the MinIO URL where the PNG
    will appear once the worker finishes.
    """
    try:
        payload = request.json
        if not payload:
            return json({"error": "请求体不能为空"}, status=400)

        # Validate the request body against the PointRequest model.
        try:
            body = PointRequest(**payload)
        except Exception as exc:
            return json({"error": f"输入验证失败: {str(exc)}"}, status=400)

        # Region outline plus a unique name for the output artefact.
        coords = [(p.x, p.y, p.z) for p in body.points]
        png_name = f"o_dem_{uuid.uuid4().hex[:8]}_slopeAspect.png"

        # Fire-and-forget: heavy DEM extraction/rendering runs off the loop.
        worker = threading.Thread(
            target=TerrainCalculator.generate_slopeAspect_3d_overlook,
            args=(coords, body.url, png_name, MINIO_SUB_PATH)
        )
        worker.start()

        base_url = extract_and_rebuild_url(body.url)
        return json({
            "success": True,
            "data": f"{base_url}/{MINIO_SUB_PATH}/{png_name}",
            "request": {
                "input_vector": body.model_dump(),
                "timestamp": time.time()
            }
        })

    except Exception as exc:
        logger.error(f"坡向计算API错误: {exc}")
        return json({
            "error": f"服务器内部错误: {str(exc)}",
            "success": False
        }, status=500)
|
||||
|
||||
@terrain_bp.post("/api/v1/calculate/slopeAspectTif")
async def calculate_slopeAspect_tif(request: Request):
    """Generate a slope/aspect GeoTIFF for a region.

    Validates the region points + tileset URL, starts the generation on a
    background thread, and immediately returns the MinIO URL where the TIFF
    will appear once the worker finishes.
    """
    try:
        payload = request.json
        if not payload:
            return json({"error": "请求体不能为空"}, status=400)

        # Validate the request body against the PointRequest model.
        try:
            body = PointRequest(**payload)
        except Exception as exc:
            return json({"error": f"输入验证失败: {str(exc)}"}, status=400)

        # Region outline plus a unique name for the output artefact.
        coords = [(p.x, p.y, p.z) for p in body.points]
        tif_name = f"o_dem_{uuid.uuid4().hex[:8]}_slopeAspect.tif"

        # Fire-and-forget: heavy DEM extraction/raster work runs off the loop.
        worker = threading.Thread(
            target=TerrainCalculator.generate_slopeAspect_tif,
            args=(coords, body.url, tif_name, MINIO_SUB_PATH)
        )
        worker.start()

        base_url = extract_and_rebuild_url(body.url)
        return json({
            "success": True,
            "data": f"{base_url}/{MINIO_SUB_PATH}/{tif_name}",
            "request": {
                "input_vector": body.model_dump(),
                "timestamp": time.time()
            }
        })

    except Exception as exc:
        logger.error(f"坡向计算API错误: {exc}")
        return json({
            "error": f"服务器内部错误: {str(exc)}",
            "success": False
        }, status=500)
|
||||
|
||||
@terrain_bp.post("/api/v1/calculate/both")
async def calculate_both(request: Request):
    """Compute slope and aspect together for one surface normal vector."""
    try:
        payload = request.json
        if not payload:
            return json({"error": "请求体不能为空"}, status=400)

        # Validate the request body against the NormalVector model.
        try:
            normal = NormalVector(**payload)
        except Exception as exc:
            return json({"error": f"输入验证失败: {str(exc)}"}, status=400)

        outcome = TerrainCalculator.calculate_slope_aspect(normal.to_list())

        # A populated "error" key signals a domain-level failure.
        if outcome.get("error"):
            return json(outcome, status=400)

        return json({
            "success": True,
            "data": outcome,
            "request": {
                "input_vector": normal.to_list(),
                "timestamp": time.time()
            }
        })

    except Exception as exc:
        logger.error(f"综合计算API错误: {exc}")
        return json({
            "error": f"服务器内部错误: {str(exc)}",
            "success": False
        }, status=500)
|
||||
|
||||
@terrain_bp.post("/api/v1/calculate/batch")
async def batch_calculate(request: Request):
    """Batch slope/aspect computation with simple throughput metrics."""
    try:
        payload = request.json
        if not payload:
            return json({"error": "请求体不能为空"}, status=400)

        # Validate the request body against the BatchRequest model.
        try:
            batch = BatchRequest(**payload)
        except Exception as exc:
            return json({"error": f"输入验证失败: {str(exc)}"}, status=400)

        # Time the batch to report throughput alongside the results.
        started = time.time()
        outcome = TerrainCalculator.batch_calculate(batch.vectors)
        elapsed_ms = (time.time() - started) * 1000

        if outcome.get("error"):
            return json(outcome, status=400)

        throughput = len(batch.vectors) / (elapsed_ms / 1000) if elapsed_ms > 0 else 0
        return json({
            "success": True,
            "data": outcome,
            "performance": {
                "process_time_ms": elapsed_ms,
                "vectors_per_second": throughput
            },
            "request": {
                "vector_count": len(batch.vectors),
                "timestamp": time.time()
            }
        })

    except Exception as exc:
        logger.error(f"批量计算API错误: {exc}")
        return json({
            "error": f"服务器内部错误: {str(exc)}",
            "success": False
        }, status=500)
|
||||
|
||||
@terrain_bp.get("/api/v1/example")
async def get_examples(request: Request):
    """Return canned normal-vector examples with their expected results."""
    # NOTE(review): the expected_aspect values mix conventions —
    # east_slope_45 (nx>0 -> 90°) implies aspect equals the bearing of the
    # horizontal normal component, while north_slope_30 (ny<0 -> 0°)
    # implies the opposite sign; confirm against
    # TerrainCalculator.calculate_aspect before trusting these as ground truth.
    examples = {
        "flat": {
            "nx": 0.0,
            "ny": 0.0,
            "nz": 1.0,
            "expected_slope": 0.0,
            "description": "完全水平面"
        },
        "north_slope_30": {
            "nx": 0.0,
            "ny": -0.5,
            "nz": 0.8660254,
            "expected_slope": 30.0,
            "expected_aspect": 0.0,
            "description": "朝北30度斜坡"
        },
        "east_slope_45": {
            "nx": 0.7071068,
            "ny": 0.0,
            "nz": 0.7071068,
            "expected_slope": 45.0,
            "expected_aspect": 90.0,
            "description": "朝东45度斜坡"
        },
        "vertical": {
            "nx": 1.0,
            "ny": 0.0,
            "nz": 0.0,
            "expected_slope": 90.0,
            "description": "垂直面"
        }
    }

    return json({
        "examples": examples,
        "count": len(examples)
    })
|
||||
|
||||
def extract_and_rebuild_url(url):
    """Reduce *url* to ``scheme://host[/first-path-segment]``.

    Used to derive the MinIO base URL (endpoint + bucket) from a full
    object URL. A missing scheme defaults to "http".
    """
    parsed = urllib.parse.urlparse(url)

    # Scheme (default http) and host:port parts.
    scheme = parsed.scheme or "http"
    netloc = parsed.netloc

    # First non-empty path segment, if any.
    segments = [part for part in parsed.path.split("/") if part]

    if segments:
        return f"{scheme}://{netloc}/{segments[0]}"
    return f"{scheme}://{netloc}"
|
||||
392
b3dm/terrain_calculator.py
Normal file
392
b3dm/terrain_calculator.py
Normal file
@ -0,0 +1,392 @@
|
||||
import numpy as np
|
||||
from typing import List, Tuple, Dict, Any
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
import b3dm.data_3dtiles_to_dem as data_3dtiles_to_dem
|
||||
import b3dm.slope_aspect_img as slope_aspect_img
|
||||
import b3dm.slope_aspect_tif as slope_aspect_tif
|
||||
from b3dm.tileset_data_source import TilesetDataSource
|
||||
|
||||
logger = logging.getLogger(__name__)

# NOTE(review): this module-level cache is never read or written — the
# TerrainCalculator methods each create a LOCAL TilesetDataSource instead;
# confirm whether a shared/cached data source was intended.
_data_source = None
|
||||
|
||||
|
||||
class TerrainCalculator:
|
||||
"""地形坡度和坡向计算器"""
|
||||
|
||||
def preload_3dtiles(url: str) :
|
||||
# 下载3dtiles地图数据
|
||||
_data_source = TilesetDataSource(url)
|
||||
_data_source.dowload_map_data(url)
|
||||
|
||||
if not _data_source.tileset_path :
|
||||
logger.info(f"下载地图数据失败: {url}")
|
||||
return "下载地图数据失败", None
|
||||
|
||||
def generate_slopeAspect_3d_overlook(region_coords, url, overall_3d_png_name, minio_sub_path) :
|
||||
# 下载3dtiles地图数据
|
||||
_data_source = TilesetDataSource(url)
|
||||
_data_source.dowload_map_data(url)
|
||||
|
||||
if not _data_source.tileset_path :
|
||||
logger.info(f"下载地图数据失败: {url},{region_coords}")
|
||||
return "下载地图数据失败", None
|
||||
|
||||
tileset_path = _data_source.tileset_path
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
dem_path = os.path.join(script_dir, f"o_dem_{uuid.uuid4().hex[:8]}.tif")
|
||||
data_3dtiles_to_dem.generate_dem(tileset_path, dem_path, region_coords)
|
||||
|
||||
if not os.path.exists(dem_path) :
|
||||
logger.info(f"生成坡度坡向俯视图失败: {url},{region_coords}")
|
||||
return "生成坡度坡向俯视图失败", None
|
||||
|
||||
overall_3d_png_path = os.path.join(script_dir, overall_3d_png_name)
|
||||
slope_aspect_img.read_slope_aspect_by_dem(dem_path, overall_3d_png_path)
|
||||
logger.info(f"生成成功: {url},{region_coords},{overall_3d_png_path}")
|
||||
|
||||
entry_bucket, _ = _data_source.parse_minio_url(url);
|
||||
success, minio_path = _data_source.upload_file(entry_bucket, f"{minio_sub_path}/{overall_3d_png_name}", overall_3d_png_path)
|
||||
if success :
|
||||
return "生成成功", minio_path
|
||||
else :
|
||||
return "生成失败", None
|
||||
|
||||
def generate_slopeAspect_tif(region_coords, url, slope_aspect_tif_name, minio_sub_path) :
|
||||
# 下载3dtiles地图数据
|
||||
_data_source = TilesetDataSource(url)
|
||||
_data_source.dowload_map_data(url)
|
||||
|
||||
if not _data_source.tileset_path :
|
||||
logger.info(f"下载地图数据失败: {url},{region_coords}")
|
||||
return "下载地图数据失败", None
|
||||
|
||||
tileset_path = _data_source.tileset_path
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
dem_path = os.path.join(script_dir, f"o_dem_{uuid.uuid4().hex[:8]}.tif")
|
||||
data_3dtiles_to_dem.generate_dem(tileset_path, dem_path, region_coords)
|
||||
|
||||
if not os.path.exists(dem_path) :
|
||||
logger.info(f"生成坡度坡向tif失败: {url},{region_coords}")
|
||||
return "生成坡度坡向tif失败", None
|
||||
|
||||
slope_aspect_tif_path = os.path.join(script_dir, slope_aspect_tif_name)
|
||||
slope_aspect_tif.create_slope_aspect(dem_path, 'combined', slope_aspect_tif_path)
|
||||
logger.info(f"生成成功: {url},{region_coords},{slope_aspect_tif_path}")
|
||||
|
||||
entry_bucket, _ = _data_source.parse_minio_url(url);
|
||||
success, minio_path = _data_source.upload_file(entry_bucket, f"{minio_sub_path}/{slope_aspect_tif_name}", slope_aspect_tif_path)
|
||||
if success :
|
||||
return "生成成功", minio_path
|
||||
else :
|
||||
return "生成失败", None
|
||||
|
||||
@staticmethod
|
||||
def validate_vector(vector: List[float]) -> bool:
|
||||
"""验证输入向量是否有效"""
|
||||
if len(vector) != 3:
|
||||
return False
|
||||
if not all(isinstance(v, (int, float)) for v in vector):
|
||||
return False
|
||||
norm = np.linalg.norm(vector)
|
||||
return norm > 1e-10 # 避免零向量
|
||||
|
||||
@staticmethod
|
||||
def normalize_vector(vector: List[float]) -> np.ndarray:
|
||||
"""向量归一化"""
|
||||
arr = np.array(vector, dtype=np.float64)
|
||||
norm = np.linalg.norm(arr)
|
||||
return arr / norm if norm > 0 else arr
|
||||
|
||||
@staticmethod
|
||||
def calculate_slope(normal_vector: List[float]) -> Dict[str, Any]:
|
||||
"""
|
||||
计算坡度
|
||||
|
||||
Args:
|
||||
normal_vector: 法向量 [nx, ny, nz]
|
||||
|
||||
Returns:
|
||||
dict: 包含坡度(度)和相关信息
|
||||
"""
|
||||
try:
|
||||
# 验证输入
|
||||
if not TerrainCalculator.validate_vector(normal_vector):
|
||||
return {
|
||||
"error": "无效的法向量,必须是长度为3的数值列表且不能为零向量",
|
||||
"slope_deg": None
|
||||
}
|
||||
|
||||
# 归一化
|
||||
n = TerrainCalculator.normalize_vector(normal_vector)
|
||||
|
||||
# 计算坡度(使用arccos法)
|
||||
nz_abs = abs(n[2])
|
||||
|
||||
# 处理数值误差
|
||||
if nz_abs > 1.0:
|
||||
nz_abs = 1.0
|
||||
elif nz_abs < 0.0:
|
||||
nz_abs = 0.0
|
||||
|
||||
# 计算坡度(弧度)
|
||||
if abs(nz_abs - 1.0) < 1e-10: # 完全水平
|
||||
slope_rad = 0.0
|
||||
elif abs(nz_abs) < 1e-10: # 完全垂直
|
||||
slope_rad = np.pi / 2
|
||||
else:
|
||||
slope_rad = np.arccos(nz_abs)
|
||||
|
||||
# 转换为度
|
||||
slope_deg = np.degrees(slope_rad)
|
||||
|
||||
# 计算坡度百分比
|
||||
slope_percent = np.tan(slope_rad) * 100 if slope_rad < np.pi/2 else float('inf')
|
||||
|
||||
return {
|
||||
"slope_deg": float(slope_deg),
|
||||
"slope_rad": float(slope_rad),
|
||||
"slope_percent": float(slope_percent),
|
||||
"normalized_vector": n.tolist(),
|
||||
"classification": TerrainCalculator.classify_slope(slope_deg)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"坡度计算错误: {e}")
|
||||
return {
|
||||
"error": f"计算失败: {str(e)}",
|
||||
"slope_deg": None
|
||||
}
|
||||
|
||||
    @staticmethod
    def calculate_aspect(normal_vector: List[float]) -> Dict[str, Any]:
        """
        Compute the aspect (compass bearing of the slope) from a normal vector.

        Args:
            normal_vector: normal vector [nx, ny, nz]

        Returns:
            dict: aspect in degrees plus metadata; ``is_flat`` is True (and
            aspect is None) for a horizontal surface, and an "error" entry
            is returned for invalid input.
        """
        try:
            # Validate input
            if not TerrainCalculator.validate_vector(normal_vector):
                return {
                    "error": "无效的法向量,必须是长度为3的数值列表且不能为零向量",
                    "aspect_deg": None
                }

            # Normalize
            n = TerrainCalculator.normalize_vector(normal_vector)

            # A (near-)vertical normal means a horizontal surface.
            nx, ny, nz = n
            horizontal_magnitude = np.sqrt(nx*nx + ny*ny)

            if horizontal_magnitude < 1e-10:  # horizontal surface: aspect undefined
                return {
                    "aspect_deg": None,
                    "aspect_rad": None,
                    "is_flat": True,
                    "message": "水平面,坡向无定义",
                    "normalized_vector": n.tolist()
                }

            # Bearing of the horizontal normal component, measured from +Y
            # (north) toward +X (east); note the argument order arctan2(nx, ny).
            raw_angle_rad = np.arctan2(nx, ny)

            # NOTE(review): adding 180° treats the normal's horizontal
            # component as pointing UP-slope. This contradicts the
            # east_slope_45 example in terrain_api.py (nx>0 expected to give
            # aspect 90°, this code yields 270°) while matching the
            # north_slope_30 example — confirm the intended sign convention
            # before relying on absolute aspect values.
            aspect_rad = raw_angle_rad + np.pi

            # Convert to degrees
            aspect_deg = np.degrees(aspect_rad)

            # Normalize into [0, 360)
            aspect_deg = aspect_deg % 360.0

            # Eight-sector compass direction
            direction = TerrainCalculator.aspect_to_direction(aspect_deg)

            return {
                "aspect_deg": float(aspect_deg),
                "aspect_rad": float(aspect_rad % (2*np.pi)),
                "direction": direction,
                "is_flat": False,
                "normalized_vector": n.tolist(),
                "raw_angle_deg": float(np.degrees(raw_angle_rad))
            }

        except Exception as e:
            logger.error(f"坡向计算错误: {e}")
            return {
                "error": f"计算失败: {str(e)}",
                "aspect_deg": None
            }
|
||||
|
||||
@staticmethod
|
||||
def calculate_slope_aspect(normal_vector: List[float]) -> Dict[str, Any]:
|
||||
"""
|
||||
同时计算坡度和坡向
|
||||
|
||||
Args:
|
||||
normal_vector: 法向量 [nx, ny, nz]
|
||||
|
||||
Returns:
|
||||
dict: 包含坡度和坡向的综合结果
|
||||
"""
|
||||
try:
|
||||
slope_result = TerrainCalculator.calculate_slope(normal_vector)
|
||||
aspect_result = TerrainCalculator.calculate_aspect(normal_vector)
|
||||
|
||||
result = {
|
||||
"slope": slope_result,
|
||||
"aspect": aspect_result,
|
||||
"input_vector": normal_vector,
|
||||
"calculation_time": None # 可在调用处添加时间戳
|
||||
}
|
||||
|
||||
# 如果有错误,合并错误信息
|
||||
errors = []
|
||||
if "error" in slope_result and slope_result["error"]:
|
||||
errors.append(f"坡度: {slope_result['error']}")
|
||||
if "error" in aspect_result and aspect_result["error"]:
|
||||
errors.append(f"坡向: {aspect_result['error']}")
|
||||
|
||||
if errors:
|
||||
result["errors"] = errors
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"综合计算错误: {e}")
|
||||
return {
|
||||
"error": f"综合计算失败: {str(e)}"
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def classify_slope(slope_deg: float) -> Dict[str, Any]:
|
||||
"""坡度分类"""
|
||||
if slope_deg < 2:
|
||||
return {"category": "平坦", "level": 0, "description": "基本平坦"}
|
||||
elif slope_deg < 5:
|
||||
return {"category": "缓坡", "level": 1, "description": "适合农业"}
|
||||
elif slope_deg < 15:
|
||||
return {"category": "斜坡", "level": 2, "description": "适合建设"}
|
||||
elif slope_deg < 30:
|
||||
return {"category": "陡坡", "level": 3, "description": "需要工程措施"}
|
||||
elif slope_deg < 45:
|
||||
return {"category": "急陡坡", "level": 4, "description": "高风险区域"}
|
||||
else:
|
||||
return {"category": "峭壁", "level": 5, "description": "危险区域"}
|
||||
|
||||
@staticmethod
|
||||
def aspect_to_direction(aspect_deg: float) -> Dict[str, Any]:
|
||||
"""将坡向转换为八方向"""
|
||||
directions = ["北", "东北", "东", "东南", "南", "西南", "西", "西北"]
|
||||
|
||||
# 计算方向索引 (45°一个区间)
|
||||
index = int((aspect_deg + 22.5) % 360 / 45)
|
||||
|
||||
return {
|
||||
"chinese": directions[index],
|
||||
"english": ["N", "NE", "E", "SE", "S", "SW", "W", "NW"][index],
|
||||
"degree_range": {
|
||||
"min": (index * 45 - 22.5) % 360,
|
||||
"max": (index * 45 + 22.5) % 360
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def batch_calculate(vectors: List[List[float]]) -> Dict[str, Any]:
|
||||
"""批量计算多个法向量"""
|
||||
try:
|
||||
results = []
|
||||
errors = []
|
||||
|
||||
for i, vec in enumerate(vectors):
|
||||
try:
|
||||
result = TerrainCalculator.calculate_slope_aspect(vec)
|
||||
result["index"] = i
|
||||
results.append(result)
|
||||
except Exception as e:
|
||||
errors.append({
|
||||
"index": i,
|
||||
"vector": vec,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
return {
|
||||
"total": len(vectors),
|
||||
"successful": len(results),
|
||||
"failed": len(errors),
|
||||
"results": results,
|
||||
"errors": errors,
|
||||
"statistics": TerrainCalculator.calculate_statistics(results)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量计算错误: {e}")
|
||||
return {"error": f"批量计算失败: {str(e)}"}
|
||||
|
||||
@staticmethod
|
||||
def calculate_statistics(results: List[Dict]) -> Dict[str, Any]:
|
||||
"""计算统计信息"""
|
||||
if not results:
|
||||
return {}
|
||||
|
||||
slope_values = []
|
||||
aspect_values = []
|
||||
|
||||
for r in results:
|
||||
if "slope" in r and "slope_deg" in r["slope"] and r["slope"]["slope_deg"] is not None:
|
||||
slope_values.append(r["slope"]["slope_deg"])
|
||||
if "aspect" in r and "aspect_deg" in r["aspect"] and r["aspect"]["aspect_deg"] is not None:
|
||||
aspect_values.append(r["aspect"]["aspect_deg"])
|
||||
|
||||
if slope_values:
|
||||
slope_arr = np.array(slope_values)
|
||||
aspect_arr = np.array(aspect_values) if aspect_values else np.array([])
|
||||
|
||||
stats = {
|
||||
"slope": {
|
||||
"count": len(slope_values),
|
||||
"mean": float(np.mean(slope_arr)),
|
||||
"std": float(np.std(slope_arr)),
|
||||
"min": float(np.min(slope_arr)),
|
||||
"max": float(np.max(slope_arr)),
|
||||
"median": float(np.median(slope_arr))
|
||||
}
|
||||
}
|
||||
|
||||
if aspect_values:
|
||||
# 坡向统计需要循环统计
|
||||
stats["aspect"] = {
|
||||
"count": len(aspect_values),
|
||||
"mean_vector": TerrainCalculator.circular_mean(aspect_arr),
|
||||
"concentration": TerrainCalculator.circular_concentration(aspect_arr)
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def circular_mean(angles_deg: np.ndarray) -> float:
|
||||
"""计算循环数据的平均值(角度)"""
|
||||
angles_rad = np.radians(angles_deg)
|
||||
x = np.mean(np.cos(angles_rad))
|
||||
y = np.mean(np.sin(angles_rad))
|
||||
mean_rad = np.arctan2(y, x)
|
||||
return np.degrees(mean_rad) % 360
|
||||
|
||||
@staticmethod
|
||||
def circular_concentration(angles_deg: np.ndarray) -> float:
|
||||
"""计算角度数据的集中度 (0-1)"""
|
||||
angles_rad = np.radians(angles_deg)
|
||||
x = np.mean(np.cos(angles_rad))
|
||||
y = np.mean(np.sin(angles_rad))
|
||||
return np.sqrt(x*x + y*y)
|
||||
105
b3dm/tileset_data_source.py
Normal file
105
b3dm/tileset_data_source.py
Normal file
@ -0,0 +1,105 @@
|
||||
# tileset_data_source_py
|
||||
import numpy as np
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from b3dm.data_3dtiles_manager import MinIO3DTilesManager
|
||||
import b3dm.data_3dtiles_to_dem as data_3dtiles_to_dem
|
||||
|
||||
# Module logger for this data source.
logger = logging.getLogger(__name__)


# MinIO connection settings shared by all TilesetDataSource instances.
# NOTE(review): credentials are hard-coded in source control — move them to
# environment variables or a secrets manager before shipping.
ENDPOINT_URL = "222.212.85.86:9000"
ACCESS_KEY = "WuRenJi"
SECRET_KEY = "WRJ@2024"
|
||||
|
||||
class TilesetDataSource:
    """Data source that downloads a 3D Tiles tileset from MinIO and exposes
    point / bounds queries over it.

    (The original docstring said "uses the py3dtiles library"; the current
    implementation delegates parsing to ``b3dm.data_3dtiles_to_dem``.)
    """

    def __init__(self, url: str, cache_size: int = 1000):
        # URL of the remote tileset.json (MinIO-style URL).
        self.url = url
        # Filled in by dowload_map_data() after a successful download.
        self.tileset_path = None
        self.tileset_dir = None
        # Fix: initialise _tileset so get_data_bounds() no longer raises
        # AttributeError when called before any tileset object is attached.
        self._tileset = None
        self._crs = "EPSG:4979"

    def _create_manager(self):
        """Build a MinIO client from the module-level connection settings.

        Extracted helper: this construction was previously repeated in
        parse_minio_url, upload_file and dowload_map_data.
        """
        return MinIO3DTilesManager(
            endpoint_url=ENDPOINT_URL,
            access_key=ACCESS_KEY,
            secret_key=SECRET_KEY,
            secure=False
        )

    def parse_minio_url(self, url):
        """Split a MinIO URL into its components via the manager."""
        return self._create_manager().parse_minio_url(url)

    def upload_file(self, bucket_name, object_name, file_path):
        """Upload *file_path* to MinIO and return (success_flag, remote_path).

        The local file is deleted once the upload succeeds.
        """
        flag, path = self._create_manager().upload_file(bucket_name, object_name, file_path)
        if flag:
            # Local copy is only a staging artefact; remove it after upload.
            os.remove(file_path)
        return flag, path

    def dowload_map_data(self, url: str):
        """Download the full 3D Tiles tileset referenced by *url*.

        On success records the absolute tileset.json path and its directory.
        (Method name kept as-is — "dowload" typo — to preserve the public API.)
        """
        manager = self._create_manager()
        success, tileset_path = manager.download_full_tileset(
            tileset_url=url,
            save_dir="data_3dtiles",
            region_filter=None
        )
        if success:
            self.tileset_path = os.path.abspath(tileset_path)
            self.tileset_dir = os.path.dirname(tileset_path)

    async def get_points_in_polygon(self, polygon_coords, z_range=None):
        """Return an array of points inside *polygon_coords*.

        Parsing is delegated to data_3dtiles_to_dem.parse_tileset; *z_range*
        is currently accepted but not forwarded.
        """
        points = data_3dtiles_to_dem.parse_tileset(self.tileset_path, polygon_coords)
        return np.array(points)

    async def get_data_bounds(self) -> Dict[str, List[float]]:
        """Return {"min": [...], "max": [...]} of the root bounding volume.

        When no tileset object is attached (the default), the +/-inf
        placeholder bounds are returned unchanged.
        """
        bounds = {
            "min": [float('inf'), float('inf'), float('inf')],
            "max": [-float('inf'), -float('inf'), -float('inf')]
        }

        if self._tileset and hasattr(self._tileset.root_tile, 'bounding_volume'):
            bv = self._tileset.root_tile.bounding_volume
            if hasattr(bv, 'get_corners'):
                corners = bv.get_corners()
                if corners is not None:
                    bounds["min"] = corners.min(axis=0).tolist()
                    bounds["max"] = corners.max(axis=0).tolist()

        return bounds

    def get_crs(self) -> str:
        """CRS of the coordinates this source produces (WGS84 3D)."""
        return self._crs
|
||||
|
||||
async def main():
    """Ad-hoc smoke test: build a TilesetDataSource pointing at the sample
    tileset bundled next to this package."""
    script_dir = os.path.dirname(os.path.abspath(__file__))
    parent_dir = os.path.dirname(script_dir)
    tileset_path = os.path.join(parent_dir, "data/3dtiles/tileset.json")
    _source = TilesetDataSource(tileset_path)
    print("====================================================")
|
||||
|
||||
|
||||
# Script entry point.
if __name__ == "__main__":
    # asyncio.run replaces the deprecated get_event_loop()/run_until_complete
    # pattern (DeprecationWarning since Python 3.10) and handles loop
    # creation and cleanup itself.
    asyncio.run(main())
|
||||
395
b3dm/tileset_to_ply.py
Normal file
395
b3dm/tileset_to_ply.py
Normal file
@ -0,0 +1,395 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
3D Tiles Tileset to PLY Converter
|
||||
将整个3D Tiles tileset转换为单个PLY文件
|
||||
"""
|
||||
|
||||
import json
|
||||
import struct
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import DracoPy
|
||||
except ImportError:
|
||||
print("警告: DracoPy库未安装,无法处理Draco压缩的数据")
|
||||
print("请运行: pip install DracoPy")
|
||||
DracoPy = None
|
||||
|
||||
class TilesetToPLYConverter:
    """Collect vertex positions from every B3DM tile referenced by a
    3D Tiles tileset and write them out as a single ASCII PLY point cloud.

    NOTE(review): most of the parsing logic here is duplicated in
    VolumeCalculator — consider extracting the shared B3DM/glTF helpers into
    a common module.
    """

    def __init__(self):
        # Accumulated [x, y, z] vertices from all processed tiles.
        self.all_vertices = []
        self.vertex_count = 0

    def multiply_matrix_vector(self, matrix, vector):
        """Multiply a 4x4 matrix (flat 16-element, column-major list) by a
        3D point treated as the homogeneous vector [x, y, z, 1].

        Returns the transformed [x, y, z] (the w row is dropped).
        """
        # Reinterpret the column-major flat list as a row-major 4x4 matrix.
        m = [
            [matrix[0], matrix[4], matrix[8], matrix[12]],
            [matrix[1], matrix[5], matrix[9], matrix[13]],
            [matrix[2], matrix[6], matrix[10], matrix[14]],
            [matrix[3], matrix[7], matrix[11], matrix[15]]
        ]

        # Extend the point to homogeneous coordinates [x, y, z, 1].
        v = [vector[0], vector[1], vector[2], 1.0]

        # Plain row-by-column multiplication; only the x/y/z rows are needed.
        result = [
            m[0][0]*v[0] + m[0][1]*v[1] + m[0][2]*v[2] + m[0][3]*v[3],
            m[1][0]*v[0] + m[1][1]*v[1] + m[1][2]*v[2] + m[1][3]*v[3],
            m[2][0]*v[0] + m[2][1]*v[1] + m[2][2]*v[2] + m[2][3]*v[3]
        ]

        return result

    def multiply_matrices(self, m1, m2):
        """Multiply two 4x4 matrices given as flat 16-element column-major
        lists; returns the product in the same flat layout."""
        def list_to_matrix(lst):
            return [
                [lst[0], lst[4], lst[8], lst[12]],
                [lst[1], lst[5], lst[9], lst[13]],
                [lst[2], lst[6], lst[10], lst[14]],
                [lst[3], lst[7], lst[11], lst[15]]
            ]

        def matrix_to_list(mat):
            return [
                mat[0][0], mat[1][0], mat[2][0], mat[3][0],
                mat[0][1], mat[1][1], mat[2][1], mat[3][1],
                mat[0][2], mat[1][2], mat[2][2], mat[3][2],
                mat[0][3], mat[1][3], mat[2][3], mat[3][3]
            ]

        a = list_to_matrix(m1)
        b = list_to_matrix(m2)

        result = [[0 for _ in range(4)] for _ in range(4)]
        for i in range(4):
            for j in range(4):
                for k in range(4):
                    result[i][j] += a[i][k] * b[k][j]

        return matrix_to_list(result)

    def apply_transform_to_vertices(self, vertices, transform_matrix):
        """Apply a 4x4 transform to every vertex; a falsy matrix means "no
        transform" and the input list is returned untouched."""
        if not transform_matrix:
            return vertices

        transformed_vertices = []
        for vertex in vertices:
            transformed = self.multiply_matrix_vector(transform_matrix, vertex)
            transformed_vertices.append(transformed)

        return transformed_vertices

    def parse_tileset_json(self, tileset_path, parent_transform=None):
        """Walk a tileset.json tree (including external child .json
        tilesets) and return [(b3dm_path, accumulated_transform), ...] for
        every tile content that exists on disk. Returns [] on any error."""
        try:
            with open(tileset_path, 'r', encoding='utf-8') as f:
                tileset_data = json.load(f)

            b3dm_files = []

            def process_node(node, base_path, accumulated_transform):
                # Combine this node's transform with the inherited one.
                current_transform = node.get('transform')

                if current_transform and accumulated_transform:
                    # Order matters: the ancestor transform is applied last.
                    final_transform = self.multiply_matrices(accumulated_transform, current_transform)
                elif current_transform:
                    final_transform = current_transform
                else:
                    final_transform = accumulated_transform

                if 'content' in node and 'uri' in node['content']:
                    uri = node['content']['uri']
                    if uri.endswith('.b3dm'):
                        full_path = os.path.join(base_path, uri)
                        if os.path.exists(full_path):
                            b3dm_files.append((full_path, final_transform))
                    elif uri.endswith('.json'):
                        # External child tileset: recurse with the
                        # accumulated transform.
                        sub_tileset_path = os.path.join(base_path, uri)
                        if os.path.exists(sub_tileset_path):
                            sub_files = self.parse_tileset_json(sub_tileset_path, final_transform)
                            b3dm_files.extend(sub_files)

                if 'children' in node:
                    for child in node['children']:
                        process_node(child, base_path, final_transform)

            base_path = os.path.dirname(tileset_path)
            if 'root' in tileset_data:
                process_node(tileset_data['root'], base_path, parent_transform)

            return b3dm_files

        except Exception as e:
            print(f"解析tileset.json时出错: {e}")
            return []

    def parse_b3dm_file(self, file_path):
        """Extract vertex positions from a single .b3dm tile.

        Reads the 28-byte B3DM header, skips the feature/batch tables, and
        hands the embedded glTF payload to parse_gltf_data. Returns a list
        of [x, y, z] vertices, or None on error / bad magic.
        """
        try:
            with open(file_path, 'rb') as f:
                # B3DM header: magic, version, byteLength, then the four
                # feature/batch table lengths (all little-endian uint32).
                magic = f.read(4)
                if magic != b'b3dm':
                    print(f"警告: {file_path} 不是有效的B3DM文件")
                    return None

                version = struct.unpack('<I', f.read(4))[0]
                byte_length = struct.unpack('<I', f.read(4))[0]
                feature_table_json_byte_length = struct.unpack('<I', f.read(4))[0]
                feature_table_binary_byte_length = struct.unpack('<I', f.read(4))[0]
                batch_table_json_byte_length = struct.unpack('<I', f.read(4))[0]
                batch_table_binary_byte_length = struct.unpack('<I', f.read(4))[0]

                # Seek past the header (28 bytes) and both tables to reach
                # the glTF payload at the end of the tile.
                f.seek(28 + feature_table_json_byte_length + feature_table_binary_byte_length +
                       batch_table_json_byte_length + batch_table_binary_byte_length)

                gltf_data = f.read()
                return self.parse_gltf_data(gltf_data)

        except Exception as e:
            print(f"解析B3DM文件 {file_path} 失败: {e}")
            return None

    def parse_gltf_data(self, gltf_data):
        """Dispatch a glTF payload to the right parser.

        Binary GLB containers start with the magic b'glTF'; anything else is
        assumed to be embedded JSON glTF.
        """
        try:
            if gltf_data[:4] == b'glTF':
                return self.parse_glb_data(gltf_data)
            else:
                # Plain JSON glTF with no binary chunk.
                gltf_json = json.loads(gltf_data.decode('utf-8'))
                return self.extract_vertices_from_gltf(gltf_json, None)
        except Exception as e:
            print(f"解析glTF数据失败: {e}")
            return None

    def parse_glb_data(self, glb_data):
        """Parse a binary GLB container: split into JSON and BIN chunks,
        then extract vertices from the JSON scene description."""
        try:
            # GLB header: magic(4) + version(4) + total length(4).
            magic = glb_data[:4]
            if magic != b'glTF':
                return None

            version = struct.unpack('<I', glb_data[4:8])[0]
            total_length = struct.unpack('<I', glb_data[8:12])[0]

            offset = 12
            json_data = None
            binary_data = None

            # Walk the chunk list: each chunk is length(4) + type(4) + data.
            while offset < len(glb_data):
                if offset + 8 > len(glb_data):
                    break

                chunk_length = struct.unpack('<I', glb_data[offset:offset+4])[0]
                chunk_type = glb_data[offset+4:offset+8]
                chunk_data = glb_data[offset+8:offset+8+chunk_length]

                if chunk_type == b'JSON':
                    json_data = json.loads(chunk_data.decode('utf-8'))
                elif chunk_type == b'BIN\x00':
                    binary_data = chunk_data

                offset += 8 + chunk_length
                # Chunks are aligned to 4-byte boundaries per the GLB spec.
                offset = (offset + 3) & ~3

            if json_data:
                return self.extract_vertices_from_gltf(json_data, binary_data)

        except Exception as e:
            print(f"解析GLB数据失败: {e}")
            return None

    def extract_vertices_from_gltf(self, gltf_json, binary_data):
        """Pull POSITION attribute data out of a parsed glTF document.

        Only uncompressed float32 accessors backed by buffer 0 (the GLB BIN
        chunk) are read; Draco-compressed meshes are delegated to
        extract_draco_vertices. Returns a list of [x, y, z] floats.
        """
        vertices = []

        try:
            # Draco-compressed payloads need the optional DracoPy decoder.
            if 'extensionsUsed' in gltf_json and 'KHR_draco_mesh_compression' in gltf_json['extensionsUsed']:
                if DracoPy is None:
                    print("警告: 检测到Draco压缩但DracoPy未安装")
                    return vertices
                return self.extract_draco_vertices(gltf_json, binary_data)

            # Standard (uncompressed) glTF path.
            if 'meshes' not in gltf_json:
                return vertices

            for mesh in gltf_json['meshes']:
                for primitive in mesh['primitives']:
                    if 'attributes' in primitive and 'POSITION' in primitive['attributes']:
                        position_accessor_index = primitive['attributes']['POSITION']

                        if 'accessors' in gltf_json and position_accessor_index < len(gltf_json['accessors']):
                            accessor = gltf_json['accessors'][position_accessor_index]
                            buffer_view_index = accessor['bufferView']

                            if 'bufferViews' in gltf_json and buffer_view_index < len(gltf_json['bufferViews']):
                                buffer_view = gltf_json['bufferViews'][buffer_view_index]
                                buffer_index = buffer_view['buffer']
                                byte_offset = buffer_view.get('byteOffset', 0) + accessor.get('byteOffset', 0)

                                if binary_data and buffer_index == 0:
                                    # Read vertices straight out of the BIN chunk.
                                    component_type = accessor['componentType']
                                    count = accessor['count']

                                    if component_type == 5126:  # FLOAT: vec3 => 12 bytes/vertex
                                        vertex_data = struct.unpack(f'<{count*3}f',
                                                                    binary_data[byte_offset:byte_offset+count*12])
                                        for i in range(0, len(vertex_data), 3):
                                            vertices.append([vertex_data[i], vertex_data[i+1], vertex_data[i+2]])

        except Exception as e:
            print(f"提取顶点数据失败: {e}")

        return vertices

    def extract_draco_vertices(self, gltf_json, binary_data):
        """Decode Draco-compressed primitives via DracoPy and return their
        point positions as [x, y, z] floats."""
        vertices = []

        try:
            if 'meshes' not in gltf_json:
                return vertices

            for mesh in gltf_json['meshes']:
                for primitive in mesh['primitives']:
                    if 'extensions' in primitive and 'KHR_draco_mesh_compression' in primitive['extensions']:
                        draco_ext = primitive['extensions']['KHR_draco_mesh_compression']
                        buffer_view_index = draco_ext['bufferView']

                        if 'bufferViews' in gltf_json and buffer_view_index < len(gltf_json['bufferViews']):
                            buffer_view = gltf_json['bufferViews'][buffer_view_index]
                            byte_offset = buffer_view.get('byteOffset', 0)
                            byte_length = buffer_view['byteLength']

                            if binary_data:
                                draco_data = binary_data[byte_offset:byte_offset+byte_length]

                                # Decode the compressed blob with DracoPy.
                                mesh_data = DracoPy.decode(draco_data)
                                if hasattr(mesh_data, 'points'):
                                    points = mesh_data.points
                                    for point in points:
                                        vertices.append([float(point[0]), float(point[1]), float(point[2])])

        except Exception as e:
            print(f"解码Draco数据失败: {e}")

        return vertices

    def save_ply_file(self, output_path):
        """Write all accumulated vertices to *output_path* as ASCII PLY."""
        try:
            with open(output_path, 'w') as f:
                # PLY header: vertex count plus three float properties.
                f.write("ply\n")
                f.write("format ascii 1.0\n")
                f.write(f"element vertex {len(self.all_vertices)}\n")
                f.write("property float x\n")
                f.write("property float y\n")
                f.write("property float z\n")
                f.write("end_header\n")

                # One "x y z" line per vertex.
                for vertex in self.all_vertices:
                    f.write(f"{vertex[0]} {vertex[1]} {vertex[2]}\n")

            print(f"PLY文件已保存: {output_path}")
            print(f"总顶点数: {len(self.all_vertices)}")

        except Exception as e:
            print(f"保存PLY文件失败: {e}")

    def convert_tileset_to_ply(self, tileset_path, output_path):
        """Convert a whole tileset to one PLY file.

        Returns True when at least one vertex was extracted and written.

        Fix: the original parsed tileset.json twice — once just for a
        truthiness check and again for the file list — doubling the JSON
        tree walk; a single parse yields the same result.
        """
        print(f"开始处理tileset: {tileset_path}")

        # Collect every (b3dm path, accumulated transform) pair in one pass.
        b3dm_files = self.parse_tileset_json(tileset_path)
        print(f"找到 {len(b3dm_files)} 个B3DM文件")

        if not b3dm_files:
            print("未找到任何B3DM文件")
            return False

        # Process each tile, applying its accumulated transform.
        processed_count = 0
        for i, (b3dm_file, transform_matrix) in enumerate(b3dm_files):
            print(f"处理文件 {i+1}/{len(b3dm_files)}: {os.path.basename(b3dm_file)}")

            vertices = self.parse_b3dm_file(b3dm_file)
            if vertices:
                if transform_matrix:
                    vertices = self.apply_transform_to_vertices(vertices, transform_matrix)
                    print(f" 应用了变换矩阵")

                self.all_vertices.extend(vertices)
                processed_count += 1
                print(f" 提取到 {len(vertices)} 个顶点")

        print(f"\n成功处理 {processed_count}/{len(b3dm_files)} 个文件")
        print(f"总计提取 {len(self.all_vertices)} 个顶点")

        if self.all_vertices:
            self.save_ply_file(output_path)
            return True
        else:
            print("未提取到任何顶点数据")
            return False
||||
|
||||
def main():
    """CLI entry point: convert a tileset.json into a single merged PLY."""
    if len(sys.argv) < 2:
        print("用法: python tileset_to_ply.py <tileset.json路径> [输出PLY文件路径]")
        print("示例: python tileset_to_ply.py tileset.json output.ply")
        return

    tileset_path = sys.argv[1]
    output_path = sys.argv[2] if len(sys.argv) > 2 else "merged_tileset.ply"

    if not os.path.exists(tileset_path):
        print(f"错误: 文件不存在 {tileset_path}")
        return

    converter = TilesetToPLYConverter()
    if converter.convert_tileset_to_ply(tileset_path, output_path):
        print("\n转换完成!")
    else:
        print("\n转换失败!")
|
||||
|
||||
# Script entry point.
if __name__ == "__main__":
    main()
|
||||
684
b3dm/volume_calculator.py
Normal file
684
b3dm/volume_calculator.py
Normal file
@ -0,0 +1,684 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
三维模型体积计算器
|
||||
基于指定地理范围内的三维模型数据,使用三角构网方法计算体积
|
||||
"""
|
||||
|
||||
import json
|
||||
import struct
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from scipy.spatial import Delaunay
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
|
||||
try:
|
||||
import DracoPy
|
||||
except ImportError:
|
||||
print("警告: DracoPy库未安装,无法处理Draco压缩的数据")
|
||||
print("请运行: pip install DracoPy")
|
||||
DracoPy = None
|
||||
|
||||
class VolumeCalculator:
|
||||
    def __init__(self, location_file):
        """Load the region of interest from *location_file* (a JSON list of
        [lon, lat, elev] coordinates) and prepare vertex accumulators."""
        # Geographic bounds derived from the polygon file (or None on load
        # failure — downstream checks treat a falsy value as "no filtering").
        self.location_bounds = self.load_location_bounds(location_file)
        # All vertices extracted from the tileset (ECEF coordinates).
        self.all_vertices = []
        # Subset of all_vertices that falls inside location_bounds.
        self.filtered_vertices = []
|
||||
|
||||
    def load_location_bounds(self, location_file):
        """Read a JSON file of [lon, lat, elev] triples and return their
        axis-aligned bounds plus the raw coordinate list.

        Returns None (after printing the error) when the file cannot be
        read or parsed — callers must handle the None case.
        """
        try:
            with open(location_file, 'r', encoding='utf-8') as f:
                coords = json.load(f)

            # Split out per-axis value lists.
            lons = [coord[0] for coord in coords]
            lats = [coord[1] for coord in coords]
            elevs = [coord[2] for coord in coords]

            bounds = {
                'min_lon': min(lons),
                'max_lon': max(lons),
                'min_lat': min(lats),
                'max_lat': max(lats),
                'min_elev': min(elevs),
                'max_elev': max(elevs),
                'coords': coords
            }

            print(f"地理范围边界:")
            print(f" 经度: {bounds['min_lon']:.8f} ~ {bounds['max_lon']:.8f}")
            print(f" 纬度: {bounds['min_lat']:.8f} ~ {bounds['max_lat']:.8f}")
            print(f" 高程: {bounds['min_elev']:.2f} ~ {bounds['max_elev']:.2f}")

            return bounds

        except Exception as e:
            print(f"加载地理范围文件失败: {e}")
            return None
|
||||
|
||||
def wgs84_to_cartesian(self, lon, lat, elev):
|
||||
"""WGS84坐标转换为笛卡尔坐标(高精度算法)"""
|
||||
# WGS84椭球参数(EPSG:4326)
|
||||
a = 6378137.0 # 长半轴 (米)
|
||||
f = 1/298.257223563 # 扁率
|
||||
e2 = 2*f - f*f # 第一偏心率平方
|
||||
|
||||
# 角度转弧度
|
||||
lon_rad = math.radians(lon)
|
||||
lat_rad = math.radians(lat)
|
||||
|
||||
# 计算卯酉圈曲率半径
|
||||
sin_lat = math.sin(lat_rad)
|
||||
cos_lat = math.cos(lat_rad)
|
||||
sin_lon = math.sin(lon_rad)
|
||||
cos_lon = math.cos(lon_rad)
|
||||
|
||||
N = a / math.sqrt(1 - e2 * sin_lat * sin_lat)
|
||||
|
||||
# 高精度笛卡尔坐标计算
|
||||
x = (N + elev) * cos_lat * cos_lon
|
||||
y = (N + elev) * cos_lat * sin_lon
|
||||
z = (N * (1 - e2) + elev) * sin_lat
|
||||
|
||||
return [x, y, z]
|
||||
|
||||
    def cartesian_to_wgs84(self, x, y, z):
        """Convert ECEF [x, y, z] metres to WGS84 [lon, lat, elev]
        (degrees, degrees, metres) using Bowring's iterative method."""
        # WGS84 ellipsoid parameters.
        a = 6378137.0  # semi-major axis
        f = 1/298.257223563  # flattening
        e2 = 2*f - f*f  # first eccentricity squared
        ep2 = e2 / (1 - e2)  # second eccentricity squared

        # Longitude has a closed-form solution.
        lon = math.atan2(y, x)

        # Distance from the Earth's rotation axis.
        p = math.sqrt(x*x + y*y)

        if p == 0:
            # On the polar axis: latitude is +/-90° and height is measured
            # from the pole along the minor axis (b = a*sqrt(1 - e2)).
            lat = math.pi/2 if z > 0 else -math.pi/2
            elev = abs(z) - a * math.sqrt(1 - e2)
            return [math.degrees(lon), math.degrees(lat), elev]

        # Bowring's initial estimate via the parametric latitude theta.
        theta = math.atan2(z, p * (1 - f))
        lat_prev = math.atan2(z + ep2 * a * (1 - f) * math.sin(theta)**3,
                              p - e2 * a * math.cos(theta)**3)

        # Iterate latitude to convergence (converges in a few steps).
        max_iterations = 10
        tolerance = 1e-12

        for i in range(max_iterations):
            N = a / math.sqrt(1 - e2 * math.sin(lat_prev)**2)
            elev = p / math.cos(lat_prev) - N

            # Refine the latitude estimate.
            lat_new = math.atan2(z + e2 * N * math.sin(lat_prev), p)

            # Stop once the update is below tolerance.
            if abs(lat_new - lat_prev) < tolerance:
                break

            lat_prev = lat_new

        # Recompute the height at the final latitude.
        N = a / math.sqrt(1 - e2 * math.sin(lat_prev)**2)
        elev = p / math.cos(lat_prev) - N

        return [math.degrees(lon), math.degrees(lat_prev), elev]
|
||||
|
||||
def is_point_in_bounds(self, vertex):
|
||||
"""检查点是否在指定的地理范围内"""
|
||||
if not self.location_bounds:
|
||||
return True
|
||||
|
||||
# 将笛卡尔坐标转换为WGS84
|
||||
try:
|
||||
lon, lat, elev = self.cartesian_to_wgs84(vertex[0], vertex[1], vertex[2])
|
||||
|
||||
# 检查是否在边界范围内
|
||||
return (self.location_bounds['min_lon'] <= lon <= self.location_bounds['max_lon'] and
|
||||
self.location_bounds['min_lat'] <= lat <= self.location_bounds['max_lat'])
|
||||
except:
|
||||
return False
|
||||
|
||||
def multiply_matrix_vector(self, matrix, vector):
|
||||
"""4x4矩阵与4D向量相乘"""
|
||||
m = [
|
||||
[matrix[0], matrix[4], matrix[8], matrix[12]],
|
||||
[matrix[1], matrix[5], matrix[9], matrix[13]],
|
||||
[matrix[2], matrix[6], matrix[10], matrix[14]],
|
||||
[matrix[3], matrix[7], matrix[11], matrix[15]]
|
||||
]
|
||||
|
||||
v = [vector[0], vector[1], vector[2], 1.0]
|
||||
|
||||
result = [
|
||||
m[0][0]*v[0] + m[0][1]*v[1] + m[0][2]*v[2] + m[0][3]*v[3],
|
||||
m[1][0]*v[0] + m[1][1]*v[1] + m[1][2]*v[2] + m[1][3]*v[3],
|
||||
m[2][0]*v[0] + m[2][1]*v[1] + m[2][2]*v[2] + m[2][3]*v[3]
|
||||
]
|
||||
|
||||
return result
|
||||
|
||||
def multiply_matrices(self, m1, m2):
|
||||
"""两个4x4矩阵相乘"""
|
||||
def list_to_matrix(lst):
|
||||
return [
|
||||
[lst[0], lst[4], lst[8], lst[12]],
|
||||
[lst[1], lst[5], lst[9], lst[13]],
|
||||
[lst[2], lst[6], lst[10], lst[14]],
|
||||
[lst[3], lst[7], lst[11], lst[15]]
|
||||
]
|
||||
|
||||
def matrix_to_list(mat):
|
||||
return [
|
||||
mat[0][0], mat[1][0], mat[2][0], mat[3][0],
|
||||
mat[0][1], mat[1][1], mat[2][1], mat[3][1],
|
||||
mat[0][2], mat[1][2], mat[2][2], mat[3][2],
|
||||
mat[0][3], mat[1][3], mat[2][3], mat[3][3]
|
||||
]
|
||||
|
||||
a = list_to_matrix(m1)
|
||||
b = list_to_matrix(m2)
|
||||
|
||||
result = [[0 for _ in range(4)] for _ in range(4)]
|
||||
for i in range(4):
|
||||
for j in range(4):
|
||||
for k in range(4):
|
||||
result[i][j] += a[i][k] * b[k][j]
|
||||
|
||||
return matrix_to_list(result)
|
||||
|
||||
def apply_transform_to_vertices(self, vertices, transform_matrix):
|
||||
"""对顶点应用变换矩阵"""
|
||||
if not transform_matrix:
|
||||
return vertices
|
||||
|
||||
transformed_vertices = []
|
||||
for vertex in vertices:
|
||||
transformed = self.multiply_matrix_vector(transform_matrix, vertex)
|
||||
transformed_vertices.append(transformed)
|
||||
|
||||
return transformed_vertices
|
||||
|
||||
    def parse_tileset_json(self, tileset_path, parent_transform=None):
        """Recursively walk a tileset.json tree and collect (b3dm_path,
        accumulated_transform) pairs for every tile content on disk.

        *parent_transform* is the column-major 4x4 transform accumulated
        from ancestor tilesets. Returns [] on any error.
        """
        try:
            with open(tileset_path, 'r', encoding='utf-8') as f:
                tileset_data = json.load(f)

            b3dm_files = []

            def process_node(node, base_path, accumulated_transform):
                # Combine this node's transform with the inherited one.
                current_transform = node.get('transform')

                if current_transform and accumulated_transform:
                    # Order matters: the ancestor transform is applied last.
                    final_transform = self.multiply_matrices(accumulated_transform, current_transform)
                elif current_transform:
                    final_transform = current_transform
                else:
                    final_transform = accumulated_transform

                if 'content' in node and 'uri' in node['content']:
                    uri = node['content']['uri']
                    if uri.endswith('.b3dm'):
                        full_path = os.path.join(base_path, uri)
                        if os.path.exists(full_path):
                            b3dm_files.append((full_path, final_transform))
                    elif uri.endswith('.json'):
                        # External child tileset: recurse with the
                        # accumulated transform.
                        sub_tileset_path = os.path.join(base_path, uri)
                        if os.path.exists(sub_tileset_path):
                            sub_files = self.parse_tileset_json(sub_tileset_path, final_transform)
                            b3dm_files.extend(sub_files)

                if 'children' in node:
                    for child in node['children']:
                        process_node(child, base_path, final_transform)

            base_path = os.path.dirname(tileset_path)
            if 'root' in tileset_data:
                process_node(tileset_data['root'], base_path, parent_transform)

            return b3dm_files

        except Exception as e:
            print(f"解析tileset.json时出错: {e}")
            return []
|
||||
|
||||
    def parse_b3dm_file(self, file_path):
        """Extract vertex positions from a single .b3dm tile.

        Reads the 28-byte B3DM header, skips the feature/batch tables, and
        hands the embedded glTF payload to parse_gltf_data. Returns a list
        of [x, y, z] vertices, or None on error / bad magic.
        """
        try:
            with open(file_path, 'rb') as f:
                # B3DM header: magic, version, byteLength, then the four
                # feature/batch table lengths (all little-endian uint32).
                magic = f.read(4)
                if magic != b'b3dm':
                    return None

                version = struct.unpack('<I', f.read(4))[0]
                byte_length = struct.unpack('<I', f.read(4))[0]
                feature_table_json_byte_length = struct.unpack('<I', f.read(4))[0]
                feature_table_binary_byte_length = struct.unpack('<I', f.read(4))[0]
                batch_table_json_byte_length = struct.unpack('<I', f.read(4))[0]
                batch_table_binary_byte_length = struct.unpack('<I', f.read(4))[0]

                # Seek past the header (28 bytes) and both tables to reach
                # the glTF payload at the end of the tile.
                f.seek(28 + feature_table_json_byte_length + feature_table_binary_byte_length +
                       batch_table_json_byte_length + batch_table_binary_byte_length)

                gltf_data = f.read()
                return self.parse_gltf_data(gltf_data)

        except Exception as e:
            print(f"解析B3DM文件 {file_path} 失败: {e}")
            return None
|
||||
|
||||
def parse_gltf_data(self, gltf_data):
|
||||
"""解析glTF数据"""
|
||||
try:
|
||||
if gltf_data[:4] == b'glTF':
|
||||
return self.parse_glb_data(gltf_data)
|
||||
else:
|
||||
gltf_json = json.loads(gltf_data.decode('utf-8'))
|
||||
return self.extract_vertices_from_gltf(gltf_json, None)
|
||||
except Exception as e:
|
||||
print(f"解析glTF数据失败: {e}")
|
||||
return None
|
||||
|
||||
    def parse_glb_data(self, glb_data):
        """Parse a binary GLB container: split it into its JSON and BIN
        chunks, then extract vertices from the JSON scene description.

        Returns a vertex list; returns None when the magic is wrong, no
        JSON chunk is found, or parsing fails.
        """
        try:
            # GLB header: magic(4) + version(4) + total length(4).
            magic = glb_data[:4]
            if magic != b'glTF':
                return None

            version = struct.unpack('<I', glb_data[4:8])[0]
            total_length = struct.unpack('<I', glb_data[8:12])[0]

            offset = 12
            json_data = None
            binary_data = None

            # Walk the chunk list: each chunk is length(4) + type(4) + data.
            while offset < len(glb_data):
                if offset + 8 > len(glb_data):
                    break

                chunk_length = struct.unpack('<I', glb_data[offset:offset+4])[0]
                chunk_type = glb_data[offset+4:offset+8]
                chunk_data = glb_data[offset+8:offset+8+chunk_length]

                if chunk_type == b'JSON':
                    json_data = json.loads(chunk_data.decode('utf-8'))
                elif chunk_type == b'BIN\x00':
                    binary_data = chunk_data

                offset += 8 + chunk_length
                # Chunks are aligned to 4-byte boundaries per the GLB spec.
                offset = (offset + 3) & ~3

            if json_data:
                return self.extract_vertices_from_gltf(json_data, binary_data)

        except Exception as e:
            print(f"解析GLB数据失败: {e}")
            return None
|
||||
|
||||
    def extract_vertices_from_gltf(self, gltf_json, binary_data):
        """Pull POSITION attribute data out of a parsed glTF document.

        Only uncompressed float32 accessors backed by buffer 0 (the GLB BIN
        chunk) are read; Draco-compressed meshes are delegated to
        extract_draco_vertices. Returns a list of [x, y, z] floats.
        """
        vertices = []

        try:
            # Draco-compressed payloads need the optional DracoPy decoder.
            if 'extensionsUsed' in gltf_json and 'KHR_draco_mesh_compression' in gltf_json['extensionsUsed']:
                if DracoPy is None:
                    print("警告: 检测到Draco压缩但DracoPy未安装")
                    return vertices
                return self.extract_draco_vertices(gltf_json, binary_data)

            # Standard (uncompressed) glTF path.
            if 'meshes' not in gltf_json:
                return vertices

            for mesh in gltf_json['meshes']:
                for primitive in mesh['primitives']:
                    if 'attributes' in primitive and 'POSITION' in primitive['attributes']:
                        position_accessor_index = primitive['attributes']['POSITION']

                        if 'accessors' in gltf_json and position_accessor_index < len(gltf_json['accessors']):
                            accessor = gltf_json['accessors'][position_accessor_index]
                            buffer_view_index = accessor['bufferView']

                            if 'bufferViews' in gltf_json and buffer_view_index < len(gltf_json['bufferViews']):
                                buffer_view = gltf_json['bufferViews'][buffer_view_index]
                                buffer_index = buffer_view['buffer']
                                byte_offset = buffer_view.get('byteOffset', 0) + accessor.get('byteOffset', 0)

                                if binary_data and buffer_index == 0:
                                    # Read vertices straight out of the BIN chunk.
                                    component_type = accessor['componentType']
                                    count = accessor['count']

                                    if component_type == 5126:  # FLOAT: vec3 => 12 bytes/vertex
                                        vertex_data = struct.unpack(f'<{count*3}f',
                                                                    binary_data[byte_offset:byte_offset+count*12])
                                        for i in range(0, len(vertex_data), 3):
                                            vertices.append([vertex_data[i], vertex_data[i+1], vertex_data[i+2]])

        except Exception as e:
            print(f"提取顶点数据失败: {e}")

        return vertices
|
||||
|
||||
def extract_draco_vertices(self, gltf_json, binary_data):
    """Extract vertex positions from Draco-compressed glTF primitives.

    Looks for the KHR_draco_mesh_compression extension on each primitive,
    slices the referenced bufferView out of the GLB binary chunk and
    decodes it with DracoPy.

    Args:
        gltf_json: parsed glTF JSON dict.
        binary_data: the GLB binary (BIN) chunk bytes, or None.

    Returns:
        List of [x, y, z] float vertices; empty list on error or no meshes.
    """
    vertices = []

    try:
        if 'meshes' not in gltf_json:
            return vertices

        for mesh in gltf_json['meshes']:
            for primitive in mesh['primitives']:
                if 'extensions' in primitive and 'KHR_draco_mesh_compression' in primitive['extensions']:
                    draco_ext = primitive['extensions']['KHR_draco_mesh_compression']
                    buffer_view_index = draco_ext['bufferView']

                    if 'bufferViews' in gltf_json and buffer_view_index < len(gltf_json['bufferViews']):
                        buffer_view = gltf_json['bufferViews'][buffer_view_index]
                        byte_offset = buffer_view.get('byteOffset', 0)
                        byte_length = buffer_view['byteLength']

                        if binary_data:
                            # Slice the compressed Draco blob out of the BIN chunk.
                            draco_data = binary_data[byte_offset:byte_offset+byte_length]

                            mesh_data = DracoPy.decode(draco_data)
                            # Some DracoPy decode results carry no point cloud.
                            if hasattr(mesh_data, 'points'):
                                points = mesh_data.points
                                for point in points:
                                    vertices.append([float(point[0]), float(point[1]), float(point[2])])

    except Exception as e:
        print(f"解码Draco数据失败: {e}")

    return vertices
|
||||
|
||||
def calculate_triangle_angles(self, p1, p2, p3):
    """Return the three interior angles of triangle p1-p2-p3, in degrees.

    Args:
        p1, p2, p3: vertex coordinates as numpy arrays.

    Returns:
        [angle_at_p1, angle_at_p2, angle_at_p3] in degrees, or [0, 0, 0]
        for a degenerate triangle (coincident points or numerical failure).
    """
    # Side lengths, each named after the opposite vertex's angle.
    a = np.linalg.norm(p2 - p3)  # opposite angle A (at p1)
    b = np.linalg.norm(p1 - p3)  # opposite angle B (at p2)
    c = np.linalg.norm(p1 - p2)  # opposite angle C (at p3)

    # Degenerate triangle: at least two points coincide -> avoid division by zero.
    if a == 0 or b == 0 or c == 0:
        return [0, 0, 0]

    try:
        # Law of cosines; clip to [-1, 1] to guard against rounding drift
        # before arccos.
        cos_a = np.clip((b * b + c * c - a * a) / (2 * b * c), -1.0, 1.0)
        cos_b = np.clip((a * a + c * c - b * b) / (2 * a * c), -1.0, 1.0)
        cos_c = np.clip((a * a + b * b - c * c) / (2 * a * b), -1.0, 1.0)

        # np.degrees is the idiomatic equivalent of * 180 / pi.
        return [float(np.degrees(np.arccos(cos_a))),
                float(np.degrees(np.arccos(cos_b))),
                float(np.degrees(np.arccos(cos_c)))]
    except Exception:
        # BUGFIX: was a bare `except:` which also swallowed
        # KeyboardInterrupt/SystemExit; narrowed to Exception while keeping
        # the original [0, 0, 0] fallback.
        return [0, 0, 0]
|
||||
|
||||
def is_valid_triangle(self, p1, p2, p3, min_angle=10.0, max_aspect_ratio=10.0):
    """Check triangle quality via a minimum-angle and an aspect-ratio test.

    Args:
        p1, p2, p3: vertex coordinates (numpy arrays).
        min_angle: smallest acceptable interior angle in degrees.
        max_aspect_ratio: largest acceptable longest/shortest side ratio.

    Returns:
        True when the triangle passes both quality checks.
    """
    # Reject triangles with any interior angle below the threshold
    # (mirrors the 10-degree limit of the reference C# implementation).
    if min(self.calculate_triangle_angles(p1, p2, p3)) < min_angle:
        return False

    # Side lengths for the elongation (sliver) test.
    sides = [np.linalg.norm(p2 - p3),
             np.linalg.norm(p1 - p3),
             np.linalg.norm(p1 - p2)]
    shortest, longest = min(sides), max(sides)

    # Reject overly elongated triangles; a zero shortest side skips the
    # ratio test, exactly as the original guard did.
    if shortest > 0 and longest / shortest > max_aspect_ratio:
        return False

    return True
|
||||
|
||||
def calculate_circumcenter_and_radius(self, p1, p2, p3):
    """Return the 2D circumcenter and circumradius of triangle p1-p2-p3.

    Only the x/y components of the inputs are used.

    Args:
        p1, p2, p3: vertex coordinates (indexable, at least 2 components).

    Returns:
        (np.array([ux, uy]), radius), or (None, float('inf')) when the
        points are (nearly) collinear or the computation fails.
    """
    try:
        x1, y1 = p1[0], p1[1]
        x2, y2 = p2[0], p2[1]
        x3, y3 = p3[0], p3[1]

        # d is 4x the signed triangle area; ~0 means collinear points.
        d = 2 * (x1 * (y2 - y3) + x2 * (y3 - y1) + x3 * (y1 - y2))

        if abs(d) < 1e-10:  # three collinear points
            return None, float('inf')

        # Standard closed-form circumcenter.
        ux = ((x1*x1 + y1*y1) * (y2 - y3) + (x2*x2 + y2*y2) * (y3 - y1) + (x3*x3 + y3*y3) * (y1 - y2)) / d
        # BUGFIX: the second term previously used (x1 - y3); the circumcenter
        # formula requires (x1 - x3), which made uy wrong for most inputs.
        uy = ((x1*x1 + y1*y1) * (x3 - x2) + (x2*x2 + y2*y2) * (x1 - x3) + (x3*x3 + y3*y3) * (x2 - x1)) / d

        # Radius = distance from the center to any vertex (p1 here).
        radius = np.sqrt((ux - x1)**2 + (uy - y1)**2)

        return np.array([ux, uy]), radius
    except Exception:
        # BUGFIX: was a bare `except:`; narrowed to Exception, keeping the
        # original sentinel return.
        return None, float('inf')
|
||||
|
||||
def calculate_volume_delaunay(self, vertices, base_elevation=None, min_angle=10.0, use_quality_filter=True):
    """Compute the volume above a base plane via 3D Delaunay tetrahedralization.

    The z of every vertex is re-based onto `base_elevation` (clamped at 0),
    the point cloud is tetrahedralized with scipy's Delaunay, optionally
    quality-filtered, and the tetrahedron volumes are summed.

    Args:
        vertices: list of [x, y, z] vertices.
        base_elevation: base plane elevation in meters; defaults to the
            lowest z of the input.
        min_angle: minimum triangle angle (degrees) for the quality filter.
        use_quality_filter: whether to reject low-quality tetrahedra.

    Returns:
        Total volume in cubic meters (0.0 on failure or too few points).
    """
    # A tetrahedralization needs at least 4 non-degenerate points.
    if len(vertices) < 4:
        print("顶点数量不足,无法进行三角剖分")
        return 0.0

    try:
        points = np.array(vertices)

        if base_elevation is None:
            base_elevation = np.min(points[:, 2])

        print(f"Delaunay方法使用基准面高程: {base_elevation:.2f} 米")
        print(f"质量控制参数: 最小角度={min_angle}°, 质量过滤={'开启' if use_quality_filter else '关闭'}")

        # Shift heights onto the base plane; clamp below-plane points to 0.
        adjusted_points = points.copy()
        adjusted_points[:, 2] = np.maximum(points[:, 2] - base_elevation, 0)

        print("正在进行Delaunay三角剖分...")
        tri = Delaunay(adjusted_points)

        valid_simplices = []
        total_simplices = len(tri.simplices)

        if use_quality_filter:
            print("正在进行三角形质量检查...")
            for simplex in tqdm(tri.simplices, desc="质量检查", unit="个"):
                p0 = adjusted_points[simplex[0]]
                p1 = adjusted_points[simplex[1]]
                p2 = adjusted_points[simplex[2]]
                p3 = adjusted_points[simplex[3]]

                # Check all four triangular faces of the tetrahedron.
                faces = [(p0, p1, p2), (p0, p1, p3), (p0, p2, p3), (p1, p2, p3)]
                valid_faces = 0

                for face in faces:
                    if self.is_valid_triangle(face[0], face[1], face[2], min_angle):
                        valid_faces += 1

                # Keep the tetrahedron when at least 3 of 4 faces pass.
                if valid_faces >= 3:
                    valid_simplices.append(simplex)

            print(f"质量过滤: {len(valid_simplices)}/{total_simplices} 个四面体通过质量检查")
        else:
            valid_simplices = tri.simplices

        total_volume = 0.0
        valid_volume_count = 0

        print(f"计算 {len(valid_simplices)} 个四面体的体积...")

        for simplex in tqdm(valid_simplices, desc="计算四面体体积", unit="个"):
            p0 = adjusted_points[simplex[0]]
            p1 = adjusted_points[simplex[1]]
            p2 = adjusted_points[simplex[2]]
            p3 = adjusted_points[simplex[3]]

            # Edge vectors from p0 span the tetrahedron.
            v1 = p1 - p0
            v2 = p2 - p0
            v3 = p3 - p0

            # Tetrahedron volume = |det([v1, v2, v3])| / 6.
            det = np.linalg.det(np.array([v1, v2, v3]))
            volume = abs(det) / 6.0

            # Skip numerically-degenerate (flat) tetrahedra.
            if volume > 1e-12:
                total_volume += volume
                valid_volume_count += 1

        print(f"有效体积计算: {valid_volume_count}/{len(valid_simplices)} 个四面体")

        return total_volume

    except Exception as e:
        print(f"Delaunay三角剖分计算体积失败: {e}")
        return 0.0
|
||||
|
||||
|
||||
def load_and_filter_vertices(self, tileset_path):
    """Load every B3DM referenced by the tileset and keep in-bounds vertices.

    Populates self.all_vertices (every extracted vertex) and
    self.filtered_vertices (only those inside the configured boundary).

    Args:
        tileset_path: path to the tileset.json file.

    Returns:
        True when at least one vertex lies inside the boundary.
    """
    print(f"开始处理tileset: {tileset_path}")

    # Walk the tileset tree to collect all B3DM tiles with their transforms.
    b3dm_files = self.parse_tileset_json(tileset_path)
    print(f"找到 {len(b3dm_files)} 个B3DM文件")

    if not b3dm_files:
        print("未找到任何B3DM文件")
        return False

    # Process each B3DM file, counting totals for the summary below.
    processed_count = 0
    total_vertices = 0
    filtered_count = 0

    for i, (b3dm_file, transform_matrix) in enumerate(tqdm(b3dm_files, desc="处理B3DM文件", unit="文件")):
        # print(f"处理文件 {i+1}/{len(b3dm_files)}: {os.path.basename(b3dm_file)}")

        vertices = self.parse_b3dm_file(b3dm_file)
        if vertices:
            # Bring tile-local vertices into the tileset's world coordinates.
            if transform_matrix:
                vertices = self.apply_transform_to_vertices(vertices, transform_matrix)

            # Keep only the vertices inside the boundary.
            for vertex in vertices:
                total_vertices += 1
                if self.is_point_in_bounds(vertex):
                    self.filtered_vertices.append(vertex)
                    filtered_count += 1

            self.all_vertices.extend(vertices)
            processed_count += 1
            # tqdm.write(f" {os.path.basename(b3dm_file)}: 提取到 {len(vertices)} 个顶点")

    print(f"\n成功处理 {processed_count}/{len(b3dm_files)} 个文件")
    print(f"总计提取 {total_vertices} 个顶点")
    print(f"范围内顶点 {filtered_count} 个")

    return len(self.filtered_vertices) > 0
|
||||
|
||||
|
||||
|
||||
def calculate_volume(self, tileset_path, base_elevation=None, min_angle=10.0, use_quality_filter=True):
    """Compute the volume of the 3D model inside the configured boundary.

    Args:
        tileset_path: path to the tileset.json file.
        base_elevation: base plane elevation in meters; defaults to the
            lowest filtered vertex.
        min_angle: minimum triangle angle constraint in degrees.
        use_quality_filter: whether to enable triangle quality filtering.

    Returns:
        Volume in cubic meters, or 0.0 when no vertices fall in range.
    """
    if not self.load_and_filter_vertices(tileset_path):
        print("未找到范围内的顶点数据")
        return 0.0

    print(f"\n开始计算体积,使用Delaunay三角剖分方法")
    print(f"参与计算的顶点数: {len(self.filtered_vertices)}")

    points = np.array(self.filtered_vertices)
    if base_elevation is None:
        # Default base plane: the lowest z among the filtered vertices.
        base_elevation = np.min(points[:, 2])

    print(f"统一基准面高程: {base_elevation:.2f} 米")

    volume = self.calculate_volume_delaunay(self.filtered_vertices, base_elevation, min_angle, use_quality_filter)

    print(f"\n计算结果:")
    print(f"体积: {volume:.6f} 立方米")
    # NOTE(review): the label says 立方千米 (km^3), but 1 km^3 = 1e9 m^3, so
    # volume/1000000 is 10^6 m^3, not km^3 — confirm the intended unit.
    print(f"体积: {volume/1000000:.6f} 立方千米")

    return volume
|
||||
|
||||
def main():
    """Command-line entry point: parse arguments and run the volume calculation.

    Usage: volume_calculator.py <boundary.json> <tileset.json>
           [base_elevation] [min_angle] [quality_filter]
    """
    argv = sys.argv
    if len(argv) < 3:
        print("用法: python volume_calculator.py <lboundary.json路径> <tileset.json路径> [基准面高程] [最小角度] [质量过滤]")
        print("基准面高程: 基准面高程(米),默认使用最低点")
        print("最小角度: 三角形最小角度约束(度),默认10.0")
        print("质量过滤: 是否启用质量过滤(true/false),默认true")
        print("示例: python volume_calculator.py boundary.json tileset.json 100.0 15.0 true")
        return

    location_file, tileset_path = argv[1], argv[2]

    # Optional arguments with their defaults.
    base_elevation = None
    min_angle = 10.0
    use_quality_filter = True

    if len(argv) > 3:
        try:
            base_elevation = float(argv[3])
        except ValueError:
            print("错误:基准面高程必须是数字")
            return

    if len(argv) > 4:
        try:
            min_angle = float(argv[4])
        except ValueError:
            print("错误:最小角度必须是数字")
            return

    if len(argv) > 5:
        use_quality_filter = argv[5].lower() in ['true', '1', 'yes', 'on']

    # Both input files must exist before any work is started.
    if not os.path.exists(location_file):
        print(f"错误: 位置文件不存在 {location_file}")
        return
    if not os.path.exists(tileset_path):
        print(f"错误: tileset文件不存在 {tileset_path}")
        return

    calculator = VolumeCalculator(location_file)
    volume = calculator.calculate_volume(tileset_path, base_elevation, min_angle, use_quality_filter)

    if volume > 0:
        print("\n体积计算完成!")
    else:
        print("\n体积计算失败!")
|
||||
|
||||
# Script entry point: run the volume calculation when executed directly.
if __name__ == "__main__":
    main()
|
||||
47
check_grpc_server.py
Normal file
47
check_grpc_server.py
Normal file
@ -0,0 +1,47 @@
|
||||
from concurrent import futures
|
||||
import grpc
|
||||
import time
|
||||
|
||||
from grpc_proto.check_grpc import check_grpc_pb2_grpc, check_grpc_pb2
|
||||
|
||||
|
||||
class TaskServiceServicer(check_grpc_pb2_grpc.TaskServiceServicer):
    """Task service implementation: logs the incoming task id and acknowledges it."""

    def ProcessTask(self, request, context):
        """Log the task id and reply with a fixed success response."""
        task_id = request.task_id
        print(f"Received task_id: {task_id}")

        reply = check_grpc_pb2.TaskResponse(
            task_id=task_id,
            success=True,
            message="Task processed successfully",
        )
        return reply
|
||||
|
||||
|
||||
class HealthCheckServicer(check_grpc_pb2_grpc.HealthCheckServicer):
    """Health-check endpoint for the task server."""

    def Check(self, request, context):
        """Always report SERVING.

        A real deployment could derive the status from actual service state.
        """
        serving = check_grpc_pb2.HealthCheckResponse.ServingStatus.SERVING
        return check_grpc_pb2.HealthCheckResponse(status=serving)
|
||||
|
||||
|
||||
def serve():
    """Start the gRPC server on port 50051 and block until interrupted.

    Registers the task and health-check services, then keeps the main
    thread alive while worker threads handle requests.
    """
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))

    # Register service implementations (must happen before start()).
    check_grpc_pb2_grpc.add_TaskServiceServicer_to_server(TaskServiceServicer(), server)
    check_grpc_pb2_grpc.add_HealthCheckServicer_to_server(HealthCheckServicer(), server)

    server.add_insecure_port('[::]:50051')
    server.start()
    print("Server started, listening on port 50051...")

    try:
        while True:
            time.sleep(86400)  # keep the main thread alive; requests run in the pool
    except KeyboardInterrupt:
        server.stop(0)
|
||||
|
||||
|
||||
# Script entry point: start the gRPC server when run directly.
if __name__ == '__main__':
    serve()
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Binary file not shown.
|
Before Width: | Height: | Size: 4.2 MiB |
Binary file not shown.
|
Before Width: | Height: | Size: 5.3 MiB |
@ -1,8 +0,0 @@
|
||||
import os.path
|
||||
|
||||
from middleware.minio_util import downBigFile
|
||||
|
||||
|
||||
miniourl=r"media/22d45cc5-0ba7-4bc3-a302-ca1a28c40fd2/DJI_202509121519_001_22d45cc5-0ba7-4bc3-a302-ca1a28c40fd2/DJI_20250912152112_0001_V.mp4"
|
||||
file_path=downBigFile(miniourl)
|
||||
print(f"os.path.abspath(file_path) {os.path.abspath(file_path)}")
|
||||
@ -94,7 +94,7 @@ def func_100000(results, cls_id_list, type_name_list, func_id_10001, list_track_
|
||||
trickier_detail = {
|
||||
# "track_id": results.track_ids[i],
|
||||
"confidence": results.confs[i],
|
||||
"cls_id": i,
|
||||
"cls_id": ind,
|
||||
"type_name": type_name_list[ind],
|
||||
"box": boxes[i]
|
||||
}
|
||||
|
||||
58
grpc_proto/check_grpc/check_grpc.proto
Normal file
58
grpc_proto/check_grpc/check_grpc.proto
Normal file
@ -0,0 +1,58 @@
|
||||
syntax = "proto3";

package task;

// Processes a single task request and reports its outcome.
service TaskService {
  rpc ProcessTask (TaskRequest) returns (TaskResponse);
}

// Health check service (mirrors the standard gRPC health-checking API).
service HealthCheck {
  rpc Check (HealthCheckRequest) returns (HealthCheckResponse);
}

message HealthCheckRequest {
  string service = 1;  // name of the service being probed
}

message HealthCheckResponse {
  enum ServingStatus {
    UNKNOWN = 0;
    SERVING = 1;
    NOT_SERVING = 2;
    SERVICE_UNKNOWN = 3;
  }
  ServingStatus status = 1;
}

message TaskRequest {
  string task_id = 1;           // unique task identifier
  string sn = 2;                // serial number of the requesting device
  ContentBody content_body = 3; // task payload
}

message ContentBody {
  string org_code = 1;
  repeated int32 func_id = 2;      // ids of the functions/algorithms to run
  string source_url = 3;           // input source url
  string push_url = 4;             // output push url (may be empty)
  float confidence = 5;            // confidence threshold
  repeated ParaList para_list = 6; // per-function parameters
  Invade invade = 7;               // intrusion-detection configuration
}

message ParaList {
  int32 func_id = 1;
  bool para_invade_enable = 2;
}

message Invade {
  string invade_file = 1;      // geojson file describing the zone
  string camera_para_url = 2;  // camera parameter file path
}

message TaskResponse {
  string task_id = 1;
  bool success = 2;
  string message = 3;
}
|
||||
54
grpc_proto/check_grpc/check_grpc_pb2.py
Normal file
54
grpc_proto/check_grpc/check_grpc_pb2.py
Normal file
@ -0,0 +1,54 @@
|
||||
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# Regenerate with:
#   python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. grpc_proto/check_grpc/check_grpc.proto
# NO CHECKED-IN PROTOBUF GENCODE
# source: grpc_proto/check_grpc/check_grpc.proto
# Protobuf Python Version: 6.31.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC,
    6,
    31,
    1,
    '',
    'grpc_proto/check_grpc/check_grpc.proto'
)
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()




DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n&grpc_proto/check_grpc/check_grpc.proto\x12\x04task\"%\n\x12HealthCheckRequest\x12\x0f\n\x07service\x18\x01 \x01(\t\"\x9f\x01\n\x13HealthCheckResponse\x12\x37\n\x06status\x18\x01 \x01(\x0e\x32\'.task.HealthCheckResponse.ServingStatus\"O\n\rServingStatus\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x0b\n\x07SERVING\x10\x01\x12\x0f\n\x0bNOT_SERVING\x10\x02\x12\x13\n\x0fSERVICE_UNKNOWN\x10\x03\"S\n\x0bTaskRequest\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\n\n\x02sn\x18\x02 \x01(\t\x12\'\n\x0c\x63ontent_body\x18\x03 \x01(\x0b\x32\x11.task.ContentBody\"\xab\x01\n\x0b\x43ontentBody\x12\x10\n\x08org_code\x18\x01 \x01(\t\x12\x0f\n\x07\x66unc_id\x18\x02 \x03(\x05\x12\x12\n\nsource_url\x18\x03 \x01(\t\x12\x10\n\x08push_url\x18\x04 \x01(\t\x12\x12\n\nconfidence\x18\x05 \x01(\x02\x12!\n\tpara_list\x18\x06 \x03(\x0b\x32\x0e.task.ParaList\x12\x1c\n\x06invade\x18\x07 \x01(\x0b\x32\x0c.task.Invade\"7\n\x08ParaList\x12\x0f\n\x07\x66unc_id\x18\x01 \x01(\x05\x12\x1a\n\x12para_invade_enable\x18\x02 \x01(\x08\"6\n\x06Invade\x12\x13\n\x0binvade_file\x18\x01 \x01(\t\x12\x17\n\x0f\x63\x61mera_para_url\x18\x02 \x01(\t\"A\n\x0cTaskResponse\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x0f\n\x07message\x18\x03 \x01(\t2C\n\x0bTaskService\x12\x34\n\x0bProcessTask\x12\x11.task.TaskRequest\x1a\x12.task.TaskResponse2K\n\x0bHealthCheck\x12<\n\x05\x43heck\x12\x18.task.HealthCheckRequest\x1a\x19.task.HealthCheckResponseb\x06proto3')

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'grpc_proto.check_grpc.check_grpc_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
  DESCRIPTOR._loaded_options = None
  _globals['_HEALTHCHECKREQUEST']._serialized_start=48
  _globals['_HEALTHCHECKREQUEST']._serialized_end=85
  _globals['_HEALTHCHECKRESPONSE']._serialized_start=88
  _globals['_HEALTHCHECKRESPONSE']._serialized_end=247
  _globals['_HEALTHCHECKRESPONSE_SERVINGSTATUS']._serialized_start=168
  _globals['_HEALTHCHECKRESPONSE_SERVINGSTATUS']._serialized_end=247
  _globals['_TASKREQUEST']._serialized_start=249
  _globals['_TASKREQUEST']._serialized_end=332
  _globals['_CONTENTBODY']._serialized_start=335
  _globals['_CONTENTBODY']._serialized_end=506
  _globals['_PARALIST']._serialized_start=508
  _globals['_PARALIST']._serialized_end=563
  _globals['_INVADE']._serialized_start=565
  _globals['_INVADE']._serialized_end=619
  _globals['_TASKRESPONSE']._serialized_start=621
  _globals['_TASKRESPONSE']._serialized_end=686
  _globals['_TASKSERVICE']._serialized_start=688
  _globals['_TASKSERVICE']._serialized_end=755
  _globals['_HEALTHCHECK']._serialized_start=757
  _globals['_HEALTHCHECK']._serialized_end=832
# @@protoc_insertion_point(module_scope)
|
||||
172
grpc_proto/check_grpc/check_grpc_pb2_grpc.py
Normal file
172
grpc_proto/check_grpc/check_grpc_pb2_grpc.py
Normal file
@ -0,0 +1,172 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
# Regenerate with:
#   python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. grpc_proto/check_grpc/check_grpc.proto
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings

from grpc_proto.check_grpc import check_grpc_pb2 as grpc__proto_dot_check__grpc_dot_check__grpc__pb2

GRPC_GENERATED_VERSION = '1.76.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False

try:
    from grpc._utilities import first_version_is_lower
    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
    _version_not_supported = True

if _version_not_supported:
    raise RuntimeError(
        f'The grpc package installed is at version {GRPC_VERSION},'
        + ' but the generated code in grpc_proto/check_grpc/check_grpc_pb2_grpc.py depends on'
        + f' grpcio>={GRPC_GENERATED_VERSION}.'
        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
    )


class TaskServiceStub(object):
    """Missing associated documentation comment in .proto file."""

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        self.ProcessTask = channel.unary_unary(
                '/task.TaskService/ProcessTask',
                request_serializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.TaskRequest.SerializeToString,
                response_deserializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.TaskResponse.FromString,
                _registered_method=True)


class TaskServiceServicer(object):
    """Missing associated documentation comment in .proto file."""

    def ProcessTask(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')


def add_TaskServiceServicer_to_server(servicer, server):
    rpc_method_handlers = {
            'ProcessTask': grpc.unary_unary_rpc_method_handler(
                    servicer.ProcessTask,
                    request_deserializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.TaskRequest.FromString,
                    response_serializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.TaskResponse.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'task.TaskService', rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))
    server.add_registered_method_handlers('task.TaskService', rpc_method_handlers)


# This class is part of an EXPERIMENTAL API.
class TaskService(object):
    """Missing associated documentation comment in .proto file."""

    @staticmethod
    def ProcessTask(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/task.TaskService/ProcessTask',
            grpc__proto_dot_check__grpc_dot_check__grpc__pb2.TaskRequest.SerializeToString,
            grpc__proto_dot_check__grpc_dot_check__grpc__pb2.TaskResponse.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)


class HealthCheckStub(object):
    """Health check service.
    """

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        self.Check = channel.unary_unary(
                '/task.HealthCheck/Check',
                request_serializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.HealthCheckRequest.SerializeToString,
                response_deserializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.HealthCheckResponse.FromString,
                _registered_method=True)


class HealthCheckServicer(object):
    """Health check service.
    """

    def Check(self, request, context):
        """Missing associated documentation comment in .proto file."""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')


def add_HealthCheckServicer_to_server(servicer, server):
    rpc_method_handlers = {
            'Check': grpc.unary_unary_rpc_method_handler(
                    servicer.Check,
                    request_deserializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.HealthCheckRequest.FromString,
                    response_serializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.HealthCheckResponse.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'task.HealthCheck', rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))
    server.add_registered_method_handlers('task.HealthCheck', rpc_method_handlers)


# This class is part of an EXPERIMENTAL API.
class HealthCheck(object):
    """Health check service.
    """

    @staticmethod
    def Check(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/task.HealthCheck/Check',
            grpc__proto_dot_check__grpc_dot_check__grpc__pb2.HealthCheckRequest.SerializeToString,
            grpc__proto_dot_check__grpc_dot_check__grpc__pb2.HealthCheckResponse.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)
|
||||
81
grpc_proto/check_grpc_client.py
Normal file
81
grpc_proto/check_grpc_client.py
Normal file
@ -0,0 +1,81 @@
|
||||
import grpc
|
||||
import time
|
||||
|
||||
from grpc_proto.check_grpc import check_grpc_pb2_grpc, check_grpc_pb2
|
||||
|
||||
|
||||
def check_server_status(channel):
    """Return True when the server's HealthCheck endpoint reports SERVING.

    Args:
        channel: an open grpc.Channel to the target server.
    """
    try:
        stub = check_grpc_pb2_grpc.HealthCheckStub(channel)
        reply = stub.Check(check_grpc_pb2.HealthCheckRequest(service="TaskService"))
        return reply.status == check_grpc_pb2.HealthCheckResponse.ServingStatus.SERVING
    except grpc.RpcError as e:
        # Any transport/RPC failure counts as "not healthy".
        print(f"Health check failed: {e}")
        return False
|
||||
|
||||
|
||||
def check_grpc_request(max_retries=3, delay=5):
    """Send a sample TaskRequest to the local task server, with retries.

    Each attempt opens a fresh channel, verifies the server via the
    HealthCheck service, then calls TaskService.ProcessTask. The channel
    is closed in `finally` on every attempt and re-created on the next.

    Args:
        max_retries: maximum number of attempts before giving up.
        delay: seconds to wait between attempts.

    Returns:
        True on a successful ProcessTask response, False otherwise.
    """
    channel = None
    retries = 0

    while retries < max_retries:
        try:
            # Create the channel (one per attempt; closed in `finally`).
            channel = grpc.insecure_channel('localhost:50051')

            # Check server health before sending the real request.
            if not check_server_status(channel):
                raise Exception("Server is not healthy")

            stub = check_grpc_pb2_grpc.TaskServiceStub(channel)

            # Build the request message (fixed sample payload).
            request = check_grpc_pb2.TaskRequest(
                task_id="d6118954-a170-4e1c-84bd-ddbd3114b354",
                sn="8UUXN6S00A0CK7",
                content_body=check_grpc_pb2.ContentBody(
                    org_code="HMZHB",
                    func_id=[101204],
                    source_url="xxxxxxxxxx",
                    push_url="",
                    confidence=0.4,
                    para_list=[
                        check_grpc_pb2.ParaList(
                            func_id=101204,
                            para_invade_enable=True
                        )
                    ],
                    invade=check_grpc_pb2.Invade(
                        invade_file="meta_data/高压线-0826.geojson",
                        # NOTE(review): the path below contains a space before
                        # ".txt" — confirm this is the actual file name.
                        camera_para_url="meta_data/camera_para/hami_camera_para .txt"
                    )
                )
            )

            # Invoke the remote method.
            response = stub.ProcessTask(request)
            print(f"Response: task_id={response.task_id}, success={response.success}, message={response.message}")
            return True

        except grpc.RpcError as e:
            retries += 1
            print(f"RPC error occurred (attempt {retries}/{max_retries}): {e}")
            if retries < max_retries:
                print(f"Retrying in {delay} seconds...")
                time.sleep(delay)
        except Exception as e:
            # Covers the "Server is not healthy" case raised above.
            print(f"Error occurred: {e}")
            retries += 1
            if retries < max_retries:
                print(f"Retrying in {delay} seconds...")
                time.sleep(delay)
        finally:
            if channel:
                channel.close()

    print("All retry attempts failed")
    return False
|
||||
|
||||
|
||||
# Script entry point: fire a sample request when run directly.
if __name__ == '__main__':
    check_grpc_request()
|
||||
284
md/grpc.md
Normal file
284
md/grpc.md
Normal file
@ -0,0 +1,284 @@
|
||||
# 前言
|
||||
|
||||
sanic 服务与各处理服务之间基于 gRPC 解耦:每个处理服务对应一个独立的 gRPC 服务。
|
||||
|
||||
|
||||
|
||||
接口可参考下文示例;每个 gRPC 服务最好都保留健康检查(Health Check)接口。
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# grpc demo
|
||||
|
||||
为了增强 gRPC 通讯的可靠性,我们可以添加以下功能:
|
||||
|
||||
1. 服务器健康检查服务
|
||||
2. 客户端连接前检查服务器状态
|
||||
3. 通讯过程中的错误处理和重试机制
|
||||
|
||||
## 1. 修改 protobuf 定义 (task.proto)
|
||||
|
||||
首先添加健康检查服务定义:
|
||||
|
||||
```protobuf
|
||||
syntax = "proto3";
|
||||
|
||||
package task;
|
||||
|
||||
service TaskService {
|
||||
rpc ProcessTask (TaskRequest) returns (TaskResponse);
|
||||
}
|
||||
|
||||
// 添加健康检查服务
|
||||
service HealthCheck {
|
||||
rpc Check (HealthCheckRequest) returns (HealthCheckResponse);
|
||||
}
|
||||
|
||||
message HealthCheckRequest {
|
||||
string service = 1;
|
||||
}
|
||||
|
||||
message HealthCheckResponse {
|
||||
enum ServingStatus {
|
||||
UNKNOWN = 0;
|
||||
SERVING = 1;
|
||||
NOT_SERVING = 2;
|
||||
SERVICE_UNKNOWN = 3;
|
||||
}
|
||||
ServingStatus status = 1;
|
||||
}
|
||||
|
||||
message TaskRequest {
|
||||
string task_id = 1;
|
||||
string sn = 2;
|
||||
ContentBody content_body = 3;
|
||||
}
|
||||
|
||||
message ContentBody {
|
||||
string org_code = 1;
|
||||
repeated int32 func_id = 2;
|
||||
string source_url = 3;
|
||||
string push_url = 4;
|
||||
float confidence = 5;
|
||||
repeated ParaList para_list = 6;
|
||||
Invade invade = 7;
|
||||
}
|
||||
|
||||
message ParaList {
|
||||
int32 func_id = 1;
|
||||
bool para_invade_enable = 2;
|
||||
}
|
||||
|
||||
message Invade {
|
||||
string invade_file = 1;
|
||||
string camera_para_url = 2;
|
||||
}
|
||||
|
||||
message TaskResponse {
|
||||
string task_id = 1;
|
||||
bool success = 2;
|
||||
string message = 3;
|
||||
}
|
||||
```
|
||||
|
||||
## 2. 重新生成 Python 代码
|
||||
|
||||
```bash
|
||||
python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. task.proto
|
||||
protoc -I=./proto --go_out=./gen proto/subfolder/task.proto
|
||||
```
|
||||
|
||||
## 3. 增强服务端实现 (server.py)
|
||||
|
||||
```python
|
||||
from concurrent import futures
|
||||
import grpc
|
||||
import time
|
||||
import task_pb2
|
||||
import task_pb2_grpc
|
||||
|
||||
class TaskServiceServicer(task_pb2_grpc.TaskServiceServicer):
|
||||
def ProcessTask(self, request, context):
|
||||
print(f"Received task_id: {request.task_id}")
|
||||
# ... 原有处理逻辑
|
||||
return task_pb2.TaskResponse(
|
||||
task_id=request.task_id,
|
||||
success=True,
|
||||
message="Task processed successfully"
|
||||
)
|
||||
|
||||
class HealthCheckServicer(task_pb2_grpc.HealthCheckServicer):
|
||||
def Check(self, request, context):
|
||||
# 简单实现:总是返回SERVING状态
|
||||
# 实际应用中可以根据服务状态返回不同值
|
||||
return task_pb2.HealthCheckResponse(
|
||||
status=task_pb2.HealthCheckResponse.ServingStatus.SERVING
|
||||
)
|
||||
|
||||
def serve():
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
|
||||
# 添加服务实现
|
||||
task_pb2_grpc.add_TaskServiceServicer_to_server(TaskServiceServicer(), server)
|
||||
task_pb2_grpc.add_HealthCheckServicer_to_server(HealthCheckServicer(), server)
|
||||
|
||||
server.add_insecure_port('[::]:50051')
|
||||
server.start()
|
||||
print("Server started, listening on port 50051...")
|
||||
|
||||
try:
|
||||
while True:
|
||||
time.sleep(86400) # 保持运行
|
||||
except KeyboardInterrupt:
|
||||
server.stop(0)
|
||||
|
||||
if __name__ == '__main__':
|
||||
serve()
|
||||
```
|
||||
|
||||
## 4. 增强客户端实现 (client.py)
|
||||
|
||||
```python
|
||||
import grpc
|
||||
import time
|
||||
import task_pb2
|
||||
import task_pb2_grpc
|
||||
|
||||
def check_server_status(channel):
|
||||
try:
|
||||
health_stub = task_pb2_grpc.HealthCheckStub(channel)
|
||||
response = health_stub.Check(task_pb2.HealthCheckRequest(service="TaskService"))
|
||||
return response.status == task_pb2.HealthCheckResponse.ServingStatus.SERVING
|
||||
except grpc.RpcError as e:
|
||||
print(f"Health check failed: {e}")
|
||||
return False
|
||||
|
||||
def _wait_before_retry(retries, max_retries, delay):
    """Sleep *delay* seconds before the next attempt when retries remain."""
    if retries < max_retries:
        print(f"Retrying in {delay} seconds...")
        time.sleep(delay)


def run_with_retry(max_retries=3, delay=5):
    """Send one ProcessTask request, retrying on failure.

    Each attempt opens a fresh channel, verifies server health first,
    then calls TaskService.ProcessTask. Returns True on the first
    successful call, False once all attempts are exhausted.

    :param max_retries: maximum number of attempts before giving up
    :param delay: seconds to wait between attempts
    """
    retries = 0

    while retries < max_retries:
        # A fresh channel per attempt; a failed/closed one is never reused.
        channel = None
        try:
            channel = grpc.insecure_channel('localhost:50051')

            # Refuse to send work to a server that is not healthy.
            if not check_server_status(channel):
                raise Exception("Server is not healthy")

            stub = task_pb2_grpc.TaskServiceStub(channel)

            # Build the request message.
            request = task_pb2.TaskRequest(
                task_id="d6118954-a170-4e1c-84bd-ddbd3114b354",
                sn="8UUXN6S00A0CK7",
                content_body=task_pb2.ContentBody(
                    org_code="HMZHB",
                    func_id=[101204],
                    source_url="rtmp://222.212.85.86:1935/live/1581F8HGX253S00A05L8",
                    push_url="",
                    confidence=0.4,
                    para_list=[
                        task_pb2.ParaList(
                            func_id=101204,
                            para_invade_enable=True
                        )
                    ],
                    invade=task_pb2.Invade(
                        invade_file="meta_data/高压线-0826.geojson",
                        camera_para_url="meta_data/camera_para/hami_camera_para .txt"
                    )
                )
            )

            # Invoke the remote method.
            response = stub.ProcessTask(request)
            print(f"Response: task_id={response.task_id}, success={response.success}, message={response.message}")
            return True

        except grpc.RpcError as e:
            # Transport-level failure (connection refused, deadline, ...).
            retries += 1
            print(f"RPC error occurred (attempt {retries}/{max_retries}): {e}")
            _wait_before_retry(retries, max_retries, delay)
        except Exception as e:
            # Health-check failure or any other unexpected error.
            print(f"Error occurred: {e}")
            retries += 1
            _wait_before_retry(retries, max_retries, delay)
        finally:
            # Always release the channel, on success and on failure alike.
            if channel:
                channel.close()

    print("All retry attempts failed")
    return False
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_with_retry()
|
||||
```
|
||||
|
||||
## 5. 增强功能说明
|
||||
|
||||
1. 健康检查服务
|
||||
- 添加了标准的 gRPC 健康检查服务
|
||||
- 客户端可以在执行主要操作前检查服务状态
|
||||
2. 错误处理和重试机制
|
||||
- 客户端现在会捕获 `grpc.RpcError` 和其他异常
|
||||
- 实现了最大重试次数和重试间隔
|
||||
- 每次重试前都会检查服务器状态
|
||||
3. 资源管理
|
||||
- 确保在所有情况下都正确关闭通道
|
||||
- 使用上下文管理器处理通道生命周期
|
||||
4. 状态反馈
|
||||
- 提供更详细的错误信息
|
||||
- 记录重试尝试
|
||||
|
||||
## 6. 运行说明
|
||||
|
||||
1. 安装依赖:
|
||||
|
||||
```bash
|
||||
pip install grpcio grpcio-tools
|
||||
```
|
||||
|
||||
2. 生成 protobuf 代码:
|
||||
|
||||
```bash
|
||||
|
||||
# 确保在项目根目录下执行
|
||||
# 编译 proto 文件,输出到 a/b/ 目录
|
||||
# 注意:bash 中行尾反斜杠之后不能再跟注释,否则续行失效,命令无法执行
python -m grpc_tools.protoc \
    -I. \
    --python_out=. \
    --grpc_python_out=. \
    a/b/task.proto
# -I.                 指定 proto 文件的根目录(当前目录 ".")
# --python_out=.      生成 _pb2.py 到当前目录(或指定子目录)
# --grpc_python_out=. 生成 _pb2_grpc.py 到当前目录(或指定子目录)
# a/b/task.proto      proto 文件路径(相对于 -I 指定的根目录)
|
||||
```
|
||||
|
||||
3. 启动服务器:
|
||||
|
||||
```bash
|
||||
python server.py
|
||||
```
|
||||
|
||||
4. 运行客户端:
|
||||
|
||||
```bash
|
||||
python client.py
|
||||
```
|
||||
|
||||
## 7. 测试场景
|
||||
|
||||
1. 服务器未运行
|
||||
- 客户端会检测到连接失败并重试
|
||||
- 最终显示所有重试失败
|
||||
2. 服务器运行但健康检查失败
|
||||
- 可以修改 `HealthCheckServicer` 返回 `NOT_SERVING` 状态进行测试
|
||||
- 客户端会拒绝执行主要操作
|
||||
3. 网络中断
|
||||
- 客户端会捕获异常并尝试重试
|
||||
|
||||
这个增强版本提供了更健壮的 gRPC 通讯机制,适合生产环境使用。
|
||||
350
md/接口.md
Normal file
350
md/接口.md
Normal file
@ -0,0 +1,350 @@
|
||||
# 算法与后台解耦规则
|
||||
|
||||
# 1、方法
|
||||
|
||||
postgres 的ai_model_list 表,id字段声明为6位长度数字
|
||||
|
||||
1、第1位表示算法类别,1xxxxx 表示目标识别、2xxxxx 表示语义分割、3xxxxx 表示变化监测
|
||||
|
||||
2、最后两位表示二次计算,100001 表示做目标识别、100002表示做目标识别,且做人员计数
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# 接口名:视频流识别
|
||||
|
||||
接收前端的视频流、模型、识别类型,算法做计算,并且将计算结果存储到minio,消息通过mqtt发送
|
||||
|
||||
## 1、请求
|
||||
|
||||
接口 /ai/stream/back_detect
|
||||
|
||||
方法 post
|
||||
|
||||
headers X-API-Token:5e8899fe-dc74-4280-8169-2f4d185f3afa
|
||||
|
||||
body
|
||||
|
||||
```
|
||||
{
|
||||
"task_id": "1234567890", #任务id
|
||||
"sn":"", #无人机sn
|
||||
"content_body": {
|
||||
"source_url": "rtmp://192.168.0.142:1935/live/123456", #无人机视频流url
|
||||
"confidence":0.4, #置信度
|
||||
"model_func_id":[100001,100002] #方法id
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 2、响应
|
||||
|
||||
算法的响应分为两个部分
|
||||
|
||||
1、rest响应,表明收到请求
|
||||
|
||||
2、mqtt消息,持续输出计算结果
|
||||
|
||||
### 1、rest
|
||||
|
||||
```
|
||||
{
|
||||
"status": "success",
|
||||
"task_id": "1234567890",
|
||||
"message": "Detection started successfully"
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
|
||||
### 2、mqtt
|
||||
|
||||
ip 112.44.103.230 端口 1883
|
||||
|
||||
topic thing/product/ai/events
|
||||
|
||||
```
|
||||
{
|
||||
"task_id": "1234567890",
|
||||
"minio": {
|
||||
"minio_path": "ai_result/20250702/1751425303860-output-1751425303800959985.jpg",
|
||||
"file_type": "pic"
|
||||
},
|
||||
"box_detail": {
|
||||
"result_100001": {
|
||||
"func_id_10001": 100001,
|
||||
"type_name": "行人",
|
||||
"cls_count": 1,
|
||||
"box_count": [
|
||||
[
|
||||
{
|
||||
"track_id": 22099,
|
||||
"confidence": 0.34013107419013977,
|
||||
"cls_id": 0,
|
||||
"type_name": "行人",
|
||||
"box": [
|
||||
15.935794830322266,
|
||||
694.75390625,
|
||||
33.22901916503906,
|
||||
713.1658935546875
|
||||
]
|
||||
}
|
||||
]
|
||||
]
|
||||
}
|
||||
},
|
||||
"uav_location": {
|
||||
"data": {
|
||||
"attitude_head": 60,
|
||||
"gimbal_pitch": 60,
|
||||
"gimbal_roll": 60,
|
||||
"gimbal_yaw": 60,
|
||||
"height": 10,
|
||||
"latitude": 10,
|
||||
"longitude": 10,
|
||||
"speed_x": 10,
|
||||
"speed_y": 10,
|
||||
"speed_z": 10
|
||||
},
|
||||
"timestamp": 1751425301213249700
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
# 接口名:图片识别
|
||||
|
||||
接收前端的图片,算法做计算,并且将计算结果存储到minio,消息通过mqtt发送
|
||||
|
||||
## 1、请求
|
||||
|
||||
接口 /ai/pic/back_detect_pic
|
||||
|
||||
方法 post
|
||||
|
||||
headers X-API-Token:5e8899fe-dc74-4280-8169-2f4d185f3afa
|
||||
|
||||
body
|
||||
|
||||
```
|
||||
{
|
||||
"task_id": "0001111",
|
||||
"content_body": {
|
||||
"s3_id":1, #根据id适配,minio相关存储参数
|
||||
"s3_url":[
|
||||
"test/frame_0000.jpg","test/frame_0001.jpg","test/frame_0002.jpg" # minio文件地址
|
||||
],
|
||||
"confidence":0.4, #算法置信度
|
||||
"model_func_id":[10001,10002] #方法id
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 2、响应
|
||||
|
||||
算法的响应分为两个部分
|
||||
|
||||
1、rest响应,表明收到请求
|
||||
|
||||
2、mqtt消息,持续输出计算结果
|
||||
|
||||
### 1、rest
|
||||
|
||||
```
|
||||
{
|
||||
"status": "success",
|
||||
"task_id": "0001111",
|
||||
"message": "Detection started successfully"
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
|
||||
### 2、mqtt
|
||||
|
||||
ip 112.44.103.230 端口 1883
|
||||
|
||||
topic thing/product/ai/events
|
||||
|
||||
```
|
||||
{
|
||||
"task_id": "0001111", #任务id
|
||||
"minio": {
|
||||
"minio_path": "ai_result/20250627/1751006943659-frame_0001.jpg", # minio 存储路径
|
||||
"file_type": "pic"
|
||||
},
|
||||
"box_detail": {
|
||||
"model_id": 10001,
|
||||
"box_count": [
|
||||
{
|
||||
"type": 3, # 类型
|
||||
"type_name": "车辆", #类型名称
|
||||
"count": 71 #数量
|
||||
},
|
||||
{
|
||||
"type": 0,
|
||||
"type_name": "车辆",
|
||||
"count": 7
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
|
||||
# 接口名:地类分割
|
||||
|
||||
|
||||
|
||||
接收前端的图片,算法做计算,并且将计算结果存储到minio,消息通过mqtt发送
|
||||
|
||||
## 1、请求
|
||||
|
||||
接口 /ai/pic/back_detect_pic
|
||||
|
||||
方法 post
|
||||
|
||||
headers X-API-Token:5e8899fe-dc74-4280-8169-2f4d185f3afa
|
||||
|
||||
body
|
||||
|
||||
```
|
||||
{
|
||||
"task_id": "7a5c83e0-fe0d-47bf-a8e1-9bd663508783",
|
||||
"content_body": {
|
||||
"s3_id":1,#根据id适配,minio相关存储参数
|
||||
"s3_url":[
|
||||
"test/patch_0011.png", # minio文件地址
|
||||
"test/patch_0012.png"
|
||||
],
|
||||
"model_func_id":[20000,20001] #方法id
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
```
|
||||
|
||||
## 2、响应
|
||||
|
||||
算法的响应分为两个部分
|
||||
|
||||
1、rest响应,表明收到请求
|
||||
|
||||
2、mqtt消息,持续输出计算结果
|
||||
|
||||
### 1、rest
|
||||
|
||||
```
|
||||
{
|
||||
"status": "success",
|
||||
"task_id": "0001111",
|
||||
"message": "Detection started successfully"
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
|
||||
### 2、mqtt
|
||||
|
||||
ip 112.44.103.230 端口 1883
|
||||
|
||||
topic thing/product/ai/events
|
||||
|
||||
```
|
||||
{
|
||||
"task_id": "7a5c83e0-fe0d-47bf-a8e1-9bd663508783",
|
||||
"minio": [
|
||||
{
|
||||
"minio_path_before": "ai_result/20250710/1752128232469-patch_0011.png", # 需要分割的图片
|
||||
"minio_path_after": "ai_result/20250710/1752128234222-patch_0011.png", #分割之后的图片
|
||||
"minio_path_boundary": "ai_result/20250710/1752128234264-patch_0011.pngfinal_vis.png", # 分割的边界图片
|
||||
"minio_path_json": "ai_result/20250710/1752128234326-patch_0011.pnginstance_results.json", #分割生成的json文件
|
||||
"file_type": "pic"
|
||||
},
|
||||
{
|
||||
"minio_path_before": "ai_result/20250710/1752128240382-patch_0012.png",
|
||||
"minio_path_after": "ai_result/20250710/1752128241553-patch_0012.png",
|
||||
"minio_path_boundary": "ai_result/20250710/1752128241587-patch_0012.pngfinal_vis.png",
|
||||
"minio_path_json": "ai_result/20250710/1752128241631-patch_0012.pnginstance_results.json",
|
||||
"file_type": "pic"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# 接口名:地类变化监测
|
||||
|
||||
|
||||
|
||||
接收前端的图片,对一期、二期的图像做变化监测,并且将计算结果存储到minio,消息通过mqtt发送
|
||||
|
||||
## 1、请求
|
||||
|
||||
接口 /ai/pic/back_detect_pic
|
||||
|
||||
方法 post
|
||||
|
||||
headers X-API-Token:5e8899fe-dc74-4280-8169-2f4d185f3afa
|
||||
|
||||
body
|
||||
|
||||
```
|
||||
{
|
||||
"task_id": "9fa19ec3-d982-4897-af6c-2c78f786c760",
|
||||
"content_body": {
|
||||
"s3_id":1,
|
||||
"s3_url":{
|
||||
"early":"/test/1-00205.png", # 一期图像minio文件地址
|
||||
"later":"/test/2-00205.png" # 二期图像minio文件地址
|
||||
},
|
||||
"model_func_id":[30000,30001]
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
## 2、响应
|
||||
|
||||
算法的响应分为两个部分
|
||||
|
||||
1、rest响应,表明收到请求
|
||||
|
||||
2、mqtt消息,持续输出计算结果
|
||||
|
||||
### 1、rest
|
||||
|
||||
```
|
||||
{
|
||||
"status": "success",
|
||||
"task_id": "0001111",
|
||||
"message": "Detection started successfully"
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
|
||||
### 2、mqtt
|
||||
|
||||
ip 112.44.103.230 端口 1883
|
||||
|
||||
topic thing/product/ai/events
|
||||
|
||||
```
|
||||
{
|
||||
"task_id": "9fa19ec3-d982-4897-af6c-2c78f786c760",
|
||||
"minio": {
|
||||
"minio_path_1": "ai_result/20250627/1751007686483-1-00205.png", # 一期影像,minio地址
|
||||
"minio_path_2": "ai_result/20250627/1751007686541-2-00205.png", # 二期影像,minio地址
|
||||
"minio_path_result": "ai_result/20250627/1751007686.458642-result-2-00205.png", #识别结果,minio地址
|
||||
"file_type": "pic"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
@ -165,13 +165,15 @@ def pic_detect_func(mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic,
|
||||
try:
|
||||
frame_copy = frame.copy()
|
||||
results = counter(frame)
|
||||
|
||||
func_id=model_func_id_list[0]
|
||||
annotated_frame, box_result = cal_tricker_results(frame_copy, counter, class_names,
|
||||
model_func_id_list,
|
||||
func_id,
|
||||
local_func_cache, para, cls, chinese_label,
|
||||
model_func_id_list[0])
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理帧错误: {e}")
|
||||
print(f"处理帧错误1: {e}")
|
||||
error_count += 1
|
||||
if error_count >= 5:
|
||||
print(f"连续处理错误达到5次 ,正在停止处理...")
|
||||
|
||||
BIN
pt/build-wall.pt
Normal file
BIN
pt/build-wall.pt
Normal file
Binary file not shown.
File diff suppressed because it is too large
Load Diff
48
yolo_api.py
48
yolo_api.py
@ -18,6 +18,7 @@ from websockets.exceptions import ConnectionClosed
|
||||
|
||||
from CropLand_CD_module.detection import corpland_detection_func
|
||||
from cropland_module.detection import detection_func
|
||||
from grpc_proto.check_grpc_client import check_grpc_request
|
||||
from middleware.AsyncioMqttClient import AsyncMQTTClient, ConnectionContext, active_connections
|
||||
from middleware.TaskManager import TaskManager, task_manager
|
||||
from middleware.minio_util import downFile
|
||||
@ -37,6 +38,9 @@ from cv_back_video import startBackAIVideo
|
||||
from sanic_cors import CORS
|
||||
|
||||
from sanic import Sanic, Request
|
||||
# 引入其他模块
|
||||
from b3dm.earthwork_api import earthwork_bp
|
||||
from b3dm.terrain_api import terrain_bp
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
@ -63,7 +67,14 @@ DB_CONFIG = {
|
||||
"host": "8.137.54.85",
|
||||
"port": "5060"
|
||||
}
|
||||
|
||||
# 什邡数据库
|
||||
# DB_CONFIG = {
|
||||
# "dbname": "postgres",
|
||||
# "user": "postgres",
|
||||
# "password": "root",
|
||||
# "host": "222.213.91.11",
|
||||
# "port": "5061"
|
||||
# }
|
||||
|
||||
# 配置类
|
||||
class Config:
|
||||
@ -312,6 +323,9 @@ async def heartbeat_monitor(task_manager: TaskManager):
|
||||
|
||||
app = Sanic("YoloStreamService1")
|
||||
CORS(app)
|
||||
# 显式注册蓝图
|
||||
app.blueprint(earthwork_bp)
|
||||
app.blueprint(terrain_bp)
|
||||
|
||||
|
||||
# 启动心跳监测
|
||||
@ -657,11 +671,11 @@ async def run_back_Multi_Detect_async(request, request_json, stop_event: asyncio
|
||||
{
|
||||
|
||||
'path': config.model_path,
|
||||
'engine_path': config.engine_path,
|
||||
'so_path': config.so_path,
|
||||
# 'engine_path': config.engine_path,
|
||||
# 'so_path': config.so_path,
|
||||
# # 测试代码
|
||||
# 'engine_path': r"D:\project\AI-PYTHON\tensorrtx-master\yolo11\build\Release\build.engine",
|
||||
# 'so_path': r"D:\project\AI-PYTHON\tensorrtx-master\yolo11\build\Release\myplugins.dll",
|
||||
'engine_path': r"D:\project\AI-PYTHON\Ai_tottle\engine\renche\renche.engine",
|
||||
'so_path': r"D:\project\AI-PYTHON\Ai_tottle\engine\renche\myplugins.dll",
|
||||
# 工地安全帽
|
||||
# 'engine_path': r"D:\project\AI-PYTHON\Ai_tottle\engine\gdaq_hat_0926\gdaq_hat_0926.engine",
|
||||
# 'so_path': r"D:\project\AI-PYTHON\Ai_tottle\engine\gdaq_hat_0926\myplugins.dll",
|
||||
@ -1588,6 +1602,28 @@ async def stop_task_heart(request):
|
||||
"message": str(e)}, status=500)
|
||||
|
||||
|
||||
|
||||
@app.post("/ai/func/check_grpc")
|
||||
async def check_grpc(request):
|
||||
try:
|
||||
verify_token(request)
|
||||
|
||||
check_grpc_request()
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"task_id": "task_id",
|
||||
"message": "Detection started successfully"
|
||||
})
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Validation error: {str(e)}")
|
||||
return json_response({"status": "error", "message": str(e)}, status=400)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error: {str(e)}", exc_info=True)
|
||||
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
||||
|
||||
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(request: Request, ws):
|
||||
"""WebSocket端点,处理前端连接和消息,为每个连接创建独立的MQTT客户端"""
|
||||
@ -1628,6 +1664,8 @@ async def websocket_endpoint(request: Request, ws):
|
||||
camera_para_url = "meta_data/camera_para/xyzj_camera_para.txt"
|
||||
if model2 == "M4D":
|
||||
camera_para_url = "meta_data/camera_para/xyzj_camera_para.txt"
|
||||
elif model2 == "M3TD":
|
||||
camera_para_url = "meta_data/camera_para/hami_camera_para .txt"
|
||||
elif model2 == "M4TD":
|
||||
camera_para_url = "meta_data/camera_para/hami_camera_para .txt"
|
||||
camera_file_path = downFile(camera_para_url)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user