Compare commits

..

28 Commits

Author SHA1 Message Date
liyubo
3a7bdbc3a4 视频流-识别-多任务并发 拆分成小模块 2026-05-12 09:36:53 +08:00
795a028d0e 增加数量超限报警功能 2026-04-29 19:03:48 +08:00
fc534a096e 增加超限施工功能 2026-04-25 21:37:18 +08:00
9720a07683 增减数量超限报警功能 2026-04-23 17:05:29 +08:00
0b283f6b8c 增减数量超限报警功能 2026-04-15 16:12:24 +08:00
be99472837 修复bug,过滤不允许的类别 2026-04-08 10:40:13 +08:00
fad554632e 增加类别计数 2026-03-24 17:44:25 +08:00
c17df2e460 增加sam3,集成并通讯成功 2026-03-05 14:51:08 +08:00
63f240ac3a 针对缺少图传模块,引起的进程挂起问题,优化读流时长参数 2026-02-26 14:45:05 +08:00
liyubo
dd931f6231 两期对比接口,土方量计算接口返回信息调整 2026-02-26 11:00:33 +08:00
146872a4dd 优化拉流丢包的问题,增加宽松策略 2026-02-05 02:26:38 +08:00
89181007c2 优化拉流丢包的问题,增加宽松策略 2026-01-29 05:59:39 +08:00
liyubo
9a09c1e1cf 坡度坡向tif生成 2026-01-29 11:51:20 +08:00
c5eeb87488 增加忽略项目 2026-01-27 14:13:37 +08:00
ee8733a0ce Merge branch 'develop' of http://222.212.85.86:8222/bdzl2/ai_project_v1 into develop 2026-01-27 14:13:01 +08:00
0ce543572b 增加忽略项目 2026-01-27 11:54:08 +08:00
929c670add 增加忽略项目 2026-01-27 11:51:05 +08:00
1656f81fe3 增加忽略项目 2026-01-27 11:46:36 +08:00
dfb89c70a3 增加墙面裂缝魔心,修改目标识别接口bug 2026-01-27 11:44:18 +08:00
a2d3e2e24b 增加墙面裂缝魔心,修改目标识别接口bug 2026-01-27 11:18:05 +08:00
liyubo
0f44df8cec 坡度坡向合并tif生成接口 2026-01-19 10:42:21 +08:00
liyubo
eb6ce0de46 3dtiles地图数据预加载接口,terrain3d_analyzer_color移动到slope_aspect_img 2026-01-19 09:36:22 +08:00
liyubo
8d4db9b6df 坡度坡向api集成到yolo api 2026-01-14 16:15:21 +08:00
eedca6cd50 Merge branch 'develop' of http://222.212.85.86:8222/bdzl2/ai_project_v1 into develop 2026-01-14 07:34:21 +08:00
liyubo
5c865a4418 坡度坡向土方量计算 2026-01-14 11:37:35 +08:00
0e952115c8 ffmpeg 拉流更换为cv2,解决了rtmp拉流延迟3s的问题 2026-01-14 06:53:44 +08:00
fbcc505a88 修改类别字段 2026-01-13 03:04:09 +08:00
b899c4e9de 上传grpc 测试demo 2026-01-05 16:29:39 +08:00
559 changed files with 432038 additions and 7239 deletions

47
.gitignore vendored Normal file
View File

@ -0,0 +1,47 @@
# 忽略所有 .log 文件
*.log
# 忽略特定目录(如 node_modules/)
node_modules/
# 忽略本地配置文件(但保留示例文件)
config.local.json
!config.example.json
# 忽略编译输出目录
dist/
build/
test
*test*
*.pyc
__pycache__/
misc.xml
profiles_settings.xml
Project_Default.xml
*test*.py
# 忽略 Python 缓存文件
*.pyo
*.pyd
*.mo
*.so
uvmodule.log
# IDE
.idea/
*.iml
.vscode/
# Temp files
*.tmp
*.swp

3
.idea/.gitignore generated vendored
View File

@ -1,3 +0,0 @@
# 默认忽略的文件
/shelf/
/workspace.xml

3
.idea/misc.xml generated
View File

@ -1,4 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<project version="4"> <project version="4">
<component name="Black">
<option name="sdkName" value="yolo_tensorrt" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="yolo_tensorrt" project-jdk-type="Python SDK" /> <component name="ProjectRootManager" version="2" project-jdk-name="yolo_tensorrt" project-jdk-type="Python SDK" />
</project> </project>

View File

@ -1,100 +0,0 @@
import os
import shutil
import cv2
import collections
from ultralytics import YOLO
from miniohelp import downFile, upload_file, parse_minio_url # 确保你有这些工具函数
from minio import Minio
def process_images(yolo_model, image_list, class_filter, input_folder, output_folder, minio_info):
    """Run YOLO detection on a batch of MinIO-hosted images and upload annotated copies.

    Args:
        yolo_model: path/name of the YOLO weights file.
        image_list: list of {"id": ..., "path": ...} image descriptors.
        class_filter: comma-separated class-id string (e.g. "0,2"), or None/"" for all classes.
        input_folder: local scratch directory for downloads (removed on exit).
        output_folder: local scratch directory for annotated output (removed on exit).
        minio_info: dict with MinIOEndpoint / MinIOAccessKey / MinIOSecretKey.

    Returns:
        dict with "status", "message" and a per-image result list under "data".
    """
    # Strip the scheme prefix: the Minio client expects a bare host:port endpoint.
    endpoint = minio_info["MinIOEndpoint"].replace("http://", "").replace("https://", "")
    minio = Minio(
        endpoint=endpoint,
        access_key=minio_info["MinIOAccessKey"],
        secret_key=minio_info["MinIOSecretKey"],
        secure=False
    )
    os.makedirs(input_folder, exist_ok=True)
    os.makedirs(output_folder, exist_ok=True)
    model = YOLO(yolo_model)
    # None means "no class filter" — keep every detected class.
    class_ids_filter = [int(cls) for cls in class_filter.split(",")] if class_filter else None
    output_image_list = []
    for item in image_list:
        img_id = item["id"]
        img_url = item["path"]
        # Resolve the MinIO bucket/key for the image.
        if img_url.startswith("http"):
            bucket_name, img_path = parse_minio_url(img_url)
        else:
            bucket_name, img_path = "default-bucket", img_url
        try:
            # Download the original image locally.
            local_input_path = os.path.join(input_folder, os.path.basename(img_path))
            downFile(minio, img_path, bucket_name, local_input_path)
            image = cv2.imread(local_input_path)
            if image is None:
                raise ValueError(f"无法读取图像: {local_input_path}")
            # YOLO inference; classes=None lets the model report all classes.
            results = model.predict(image,
                                    classes=class_ids_filter,
                                    conf=0.5,
                                    iou=0.111,
                                    show_labels=False)
            result = results[0]
            # Count detections per class id.
            class_counts = collections.Counter(result.boxes.cls.cpu().numpy().astype(int)) if result.boxes is not None else {}
            # BUGFIX: the original evaluated `k in class_ids_filter` even when the
            # filter was None, raising TypeError; keep all counts when no filter given.
            if class_ids_filter is None:
                filtered_class_counts = dict(class_counts)
            else:
                filtered_class_counts = {k: v for k, v in class_counts.items() if k in class_ids_filter}
            # numpy int64 -> plain int so the payload is JSON-serialisable.
            detected_classes = [int(cls) for cls in filtered_class_counts.keys()]
            detected_numbers = [int(num) for num in filtered_class_counts.values()]
            aim = bool(detected_classes)
            # Save the annotated frame under the original name with an "_ai" suffix.
            annotated_image = result.plot(labels=False)
            filename_no_ext, ext = os.path.splitext(os.path.basename(img_path))
            output_filename = f"{filename_no_ext}_ai{ext}"
            local_output_path = os.path.join(output_folder, output_filename)
            cv2.imwrite(local_output_path, annotated_image)
            # Upload the annotated image back into the source bucket/directory.
            minio_path = upload_file(minio, local_output_path, bucket_name, os.path.dirname(img_path))
        except Exception as e:
            # Per-image failures are reported in the result list, not raised.
            print(f"[错误] 处理失败 - {img_path},错误: {str(e)}")
            detected_classes = []
            detected_numbers = []
            aim = False
            output_filename = ""
            minio_path = ""
        output_image_list.append({
            "id": img_id,
            "minio_path": minio_path,
            "aim": aim,
            "class": detected_classes,
            "number": detected_numbers
        })
    # Drop the scratch directories.
    shutil.rmtree(input_folder, ignore_errors=True)
    shutil.rmtree(output_folder, ignore_errors=True)
    return {
        "status": "success",
        "message": "Detection completed",
        "data": output_image_list
    }

View File

@ -1,505 +0,0 @@
from sanic import Sanic, json, Blueprint,response
from sanic.exceptions import Unauthorized, NotFound
from sanic.response import json as json_response
from sanic_cors import CORS
from datetime import datetime
import logging
import uuid
import os
import asyncio
from minio import Minio
from ai_image import process_images # 你实现的图片处理函数
from queue import Queue
import gdal2tiles as gdal2tiles
from map_find import map_process_images
from yolo_train import auto_train
from map_cut import process_tiling
from cv_video_counter import start_video_session,switch_model_session,stop_video_session,stream_sessions
import torch
from yolo_photo import map_process_images_with_progress # 引入你的处理函数
# 日志配置
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
###################################################################################验证中间件和管理件##############################################################################################
async def token_and_resource_check(request):
    """Global guard: validate the X-API-Token header and reject requests when GPUs are saturated.

    Returns None to let the request continue, or a 503 JSON response when a GPU
    is above the configured usage threshold. Raises Unauthorized on a bad token.
    """
    # --- token check ---
    token = request.headers.get('X-API-Token')
    expected_token = request.app.config.get("VALID_TOKEN")
    if not token or token != expected_token:
        logger.warning(f"Unauthorized request with token: {token}")
        raise Unauthorized("Invalid token")
    # --- GPU usage check ---
    try:
        if torch.cuda.is_available():
            num_gpus = torch.cuda.device_count()
            max_usage_ratio = request.app.config.get("MAX_GPU_USAGE", 0.9)  # default 90%
            for i in range(num_gpus):
                used = torch.cuda.memory_reserved(i)
                # BUGFIX: the original divided by max_memory_reserved(i) — the peak of
                # the same counter — which tends toward 100% and falsely reports the
                # GPU as busy. Compare against the device's physical memory instead.
                total = torch.cuda.get_device_properties(i).total_memory
                ratio = used / total if total else 0
                logger.info(f"GPU {i} Usage: {ratio:.2%}")
                if ratio > max_usage_ratio:
                    logger.warning(f"GPU {i} usage too high: {ratio:.2%}")
                    return json_response({
                        "status": "error",
                        "message": f"GPU resource busy (GPU {i} at {ratio:.2%}). Try later."
                    }, status=503)
    except Exception as e:
        # Resource probing must never take the service down; log and continue.
        logger.error(f"GPU check failed: {e}")
    return None  # allow the request to proceed
##################################################################################################################################################################################################
# Create the Sanic application
app = Sanic("ai_Service_v2")
CORS(app)  # allow cross-origin requests
# In-memory task registry: task_id -> {"status", "progress", "result"}
task_progress = {}


@app.middleware("request")
async def global_middleware(request):
    """Run the token/GPU guard before every request; short-circuit on failure."""
    result = await token_and_resource_check(request)
    if result:
        return result


# Configure the API token and maximum allowed GPU usage ratio.
# NOTE(review): the token is hard-coded — consider moving it to an environment variable.
app.config.update({
    "VALID_TOKEN": "Beidou_b8609e96-bfec-4485-8c64-6d4f662ee44a",
    "MAX_GPU_USAGE": 0.9
})
###################################################################### map tiling APIs ########################################################################################################
# map blueprint
map_tile_blueprint = Blueprint('map', url_prefix='/map/')
app.blueprint(map_tile_blueprint)


@map_tile_blueprint.post("/tile")
async def map_tile_api(request):
    """Cut a GeoTIFF (tif_url + prj_url) into map tiles for the given zoom levels."""
    try:
        # 1. a JSON request body is mandatory
        if not request.json:
            return json_response(
                {"status": "error", "message": "Request body is required"},
                status=400
            )
        # 2. required fields
        tile_data = request.json
        tif_url = tile_data.get("tif_url")
        prj_url = tile_data.get("prj_url")
        if not tif_url or not prj_url:
            return json_response(
                {"status": "error", "message": "Both tif_url and prj_url are required"},
                status=400
            )
        # 3. run the tiling job (await directly — never asyncio.run inside a handler)
        zoom_levels = tile_data.get("zoom_levels", "1-18")
        try:
            # process_tiling is assumed to be a coroutine function (async def)
            result = await process_tiling(tif_url, prj_url, zoom_levels)
            # if process_tiling were a plain function, wrap it instead:
            # result = await asyncio.to_thread(process_tiling, tif_url, prj_url, zoom_levels)
            return json_response({
                "status": "success",
                "data": result
            })
        except Exception as processing_error:
            logger.error(f"Processing failed: {str(processing_error)}", exc_info=True)
            return json_response(
                {"status": "error", "message": f"Processing error: {str(processing_error)}"},
                status=500
            )
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}", exc_info=True)
        return json_response(
            {"status": "error", "message": str(e)},  # return a plain string, not a set
            status=500
        )
# semantic segmentation on UAV imagery
@map_tile_blueprint.post("/uav")
async def process_handler(request):
    """
    Endpoint: /map/uav

    Request JSON:
    {
        "urls": [
            "http://example.com/img1.jpg",
            "http://example.com/img2.jpg"
        ],
        "yaml_name": "112.44.103.230",
        "bucket_name": "300bdf2b-a150-406e-be63-d28bd29b409f",
        "bucket_directory": "2025/seg",
        "model_path": "deeplabv3plus_best.pth"
    }

    Response JSON:
    {
        "code": 200,
        "msg": "success",
        "data": [
            "http://minio.example.com/uav-results/2025/seg/result1.png",
            "http://minio.example.com/uav-results/2025/seg/result2.png"
        ]
    }
    """
    try:
        body = request.json
        urls = body.get("urls", [])
        yaml_name = body.get("yaml_name")
        bucket_name = body.get("bucket_name")
        bucket_directory = body.get("bucket_directory")
        # model weights are resolved under map/checkpoints/
        model_path = os.path.join("map", "checkpoints", body.get("model_path"))
        # validate parameters
        if not urls or not isinstance(urls, list):
            return json({"code": 400, "msg": "Missing or invalid 'urls'"})
        if not all([yaml_name, bucket_name, bucket_directory]):
            return json({"code": 400, "msg": "Missing required parameters"})
        # NOTE(review): this call is synchronous and will block the event loop;
        # consider asyncio.to_thread as done elsewhere in this file.
        result = map_process_images(urls, yaml_name, bucket_name, bucket_directory, model_path)
        return json(result)
    except Exception as e:
        return json({"code": 500, "msg": f"Server error: {str(e)}"})
###################################################################### YOLO APIs ########################################################################################################
# yolo blueprint
yolo_tile_blueprint = Blueprint('yolo', url_prefix='/yolo/')
app.blueprint(yolo_tile_blueprint)


# Async batch endpoint: starts a background task and returns a task_id that can be
# polled via /yolo/task_status/<task_id>. Progress lives in the in-memory
# `task_progress` dict (swap for Redis/DB for persistence).
@yolo_tile_blueprint.post("/process_images")
async def process_images_task_api(request):
    """
    Start a background image-processing task.

    BUGFIX: this handler was previously named `process_images`, shadowing the
    function imported from ai_image at module level and breaking the
    /yolo/picture endpoint, which calls the imported function. The handler was
    renamed; the route URL is unchanged.

    Request JSON:
    {
        "urls": ["http://example.com/image1.jpg", ...],
        "yaml_name": "your_minio_config",
        "bucket_name": "my-bucket",
        "bucket_directory": "2025/uav-results",
        "uav_model_path": "deeplabv3plus_best.pth"
    }
    """
    data = request.json
    urls = data.get("urls")
    yaml_name = data.get("yaml_name")
    bucket_name = data.get("bucket_name")
    bucket_directory = data.get("bucket_directory")
    uav_model_path = data.get("uav_model_path")
    if not urls or not yaml_name or not bucket_name or not uav_model_path:
        return response.json({"code": 400, "msg": "Missing parameters"}, status=400)
    task_id = str(uuid.uuid4())
    task_progress[task_id] = {"status": "pending", "progress": 0, "result": None}
    # fire-and-forget background job
    asyncio.create_task(run_image_processing(task_id, urls, yaml_name, bucket_name, bucket_directory, uav_model_path))
    return response.json({"code": 200, "msg": "Task started", "task_id": task_id})
@yolo_tile_blueprint.get("/task_status/<task_id>")
async def task_status(request, task_id):
    """Poll the progress/result of a background image-processing task."""
    progress = task_progress.get(task_id)
    if not progress:
        return response.json({"code": 404, "msg": "Task not found"}, status=404)
    return response.json({"code": 200, "msg": "Task status", "data": progress})


async def run_image_processing(task_id, urls, yaml_name, bucket_name, bucket_directory, uav_model_path):
    """Background worker: runs the blocking pipeline in a thread and records progress."""
    try:
        task_progress[task_id]["status"] = "running"
        task_progress[task_id]["progress"] = 10  # initial progress
        # download / inference / upload stages report through this callback
        def progress_callback(stage, percent):
            task_progress[task_id]["status"] = stage
            task_progress[task_id]["progress"] = percent
        result = await asyncio.to_thread(
            map_process_images_with_progress,
            urls, yaml_name, bucket_name, bucket_directory, uav_model_path, progress_callback
        )
        task_progress[task_id]["status"] = "completed"
        task_progress[task_id]["progress"] = 100
        task_progress[task_id]["result"] = result
    except Exception as e:
        # surface the failure to pollers instead of raising into the event loop
        task_progress[task_id]["status"] = "failed"
        task_progress[task_id]["progress"] = 100
        task_progress[task_id]["result"] = str(e)
# YOLO detection API
@yolo_tile_blueprint.post("/picture")
async def yolo_detect_api(request):
    """Run YOLO detection over a list of MinIO-hosted images and return per-image results."""
    try:
        detect_data = request.json
        # required / optional fields
        image_list = detect_data.get("image_list")
        yolo_model = detect_data.get("yolo_model", "best.pt")
        class_filter = detect_data.get("class", None)
        minio_info = detect_data.get("minio", None)
        if not image_list:
            return json_response({"status": "error", "message": "image_list is required"}, status=400)
        if not minio_info:
            return json_response({"status": "error", "message": "MinIO information is required"}, status=400)
        # unique scratch directories per request
        input_folder = f"./temp_input_{str(uuid.uuid4())}"
        output_folder = f"./temp_output_{str(uuid.uuid4())}"
        # NOTE(review): `process_images` is intended to be the function imported from
        # ai_image, but a route handler above reuses that name and shadows it —
        # verify which callable this resolves to.
        result = await asyncio.to_thread(
            process_images,
            yolo_model=yolo_model,
            image_list=image_list,
            class_filter=class_filter,
            input_folder=input_folder,
            output_folder=output_folder,
            minio_info=minio_info
        )
        # return the processing result
        return json_response(result)
    except Exception as e:
        logger.error(f"Error occurred while processing request: {str(e)}", exc_info=True)
        return json_response({
            "status": "error",
            "message": f"Internal server error: {str(e)}"
        }, status=500)
# YOLO automatic training
@yolo_tile_blueprint.post("/train")
async def yolo_train_api(request):
    """
    Automatically train a model.

    Request JSON:
    {
        "db_host": str,
        "db_database": str,
        "db_user": str,
        "db_password": str,
        "db_port": int,
        "model_id": int,
        "img_path": str,
        "label_path": str,
        "new_path": str,
        "split_list": List[float],
        "class_names": Optional[List[str]],
        "project_name": str
    }

    Response JSON:
    {
        "base_metrics": Dict[str, float],
        "best_model_path": str,
        "final_metrics": Dict[str, float]
    }
    """
    try:
        # access request.json as a property (not a call)
        data = request.json
        if not data:
            return json_response({"status": "error", "message": "data is required"}, status=400)
        # run the blocking training job in a worker thread
        result = await asyncio.to_thread(
            auto_train,
            data
        )
        return json_response(result)
    except Exception as e:
        logger.error(f"Error occurred while processing request: {str(e)}", exc_info=True)
        return json_response({
            "status": "error",
            "message": f"Internal server error: {str(e)}"
        }, status=500)
########################################################################################### video-stream APIs #######################################################################################################
# stream blueprint (note the /stream_test/ prefix)
stream_tile_blueprint = Blueprint('stream', url_prefix='/stream_test/')
app.blueprint(stream_tile_blueprint)
#
# Task manager
class StreamTaskManager:
    """Keeps at most 10 live stream tasks; when a new task arrives on a full
    queue, the oldest one is evicted and its video session stopped."""

    def __init__(self):
        self.active_tasks = {}      # task_id -> task info dict
        self.task_status = {}      # task_id -> status string ("running", ...)
        self.task_timestamps = {}  # task_id -> datetime started
        self.task_queue = Queue(maxsize=10)  # FIFO defining eviction order

    def add_task(self, task_id: str, task_info: dict) -> None:
        """Register a new task, evicting the oldest one when the queue is full."""
        if self.task_queue.full():
            oldest_task_id = self.task_queue.get()
            # BUGFIX: stop the session *before* removing the task record — the
            # original called remove_task first and then read the deleted
            # active_tasks entry, raising KeyError.
            oldest_info = self.active_tasks.get(oldest_task_id)
            if oldest_info:
                stop_video_session(oldest_info["session_id"])
            self.remove_task(oldest_task_id)
        self.active_tasks[task_id] = task_info
        self.task_status[task_id] = "running"
        self.task_timestamps[task_id] = datetime.now()
        self.task_queue.put(task_id)
        logger.info(f"Task {task_id} started")

    def remove_task(self, task_id: str) -> None:
        """Forget a task; unknown ids are silently ignored."""
        if task_id in self.active_tasks:
            del self.active_tasks[task_id]
            del self.task_status[task_id]
            del self.task_timestamps[task_id]
            logger.info(f"Task {task_id} removed")

    def get_task_info(self, task_id: str) -> dict:
        """Return info/status/start-time for a task; raises NotFound when unknown."""
        if task_id not in self.active_tasks:
            raise NotFound("Task not found")
        return {
            "task_info": self.active_tasks[task_id],
            "status": self.task_status[task_id],
            "start_time": self.task_timestamps[task_id].isoformat()
        }
task_manager = StreamTaskManager()


# ---------- API Endpoints ----------
@stream_tile_blueprint.post("/start")
async def api_start(request):
    """
    Start a video-stream session.

    Request JSON:
    {
        "video_path": str,
        "output_url": str,
        "model_path": str,
        "cls": List[int],
        "confidence": float,
        "cls2": Optional[List[int]],
        "push": bool
    }

    Response JSON:
    {
        "session_id": str,
        "task_id": str,
        "message": "started"
    }
    """
    data = request.json
    task_id = str(uuid.uuid4())
    # start the processing session, passing through the request parameters
    session_id = start_video_session(
        video_path=data.get("video_path"),
        output_url=data.get("output_url"),
        model_path=data.get("model_path"),
        cls=data.get("cls"),
        confidence=data.get("confidence", 0.5),
        cls2=data.get("cls2", []),
        push=data.get("push", False),
    )
    # register with the task manager (may evict the oldest task)
    task_manager.add_task(task_id, {
        "session_id": session_id,
        "video_path": data.get("video_path"),
        "output_url": data.get("output_url"),
        "model_path": data.get("model_path"),
        "class_filter": data.get("cls", []),
        "push": data.get("push", False),
        "start_time": datetime.now().isoformat()
    })
    return json({"session_id": session_id, "task_id": task_id, "message": "started"})
@stream_tile_blueprint.post("/stop")
async def api_stop(request):
    """
    Stop a session.

    Request JSON:  { "session_id": str }
    Response JSON: { "session_id": str, "message": "stopped" }
    """
    session_id = request.json.get("session_id")
    stop_video_session(session_id)
    # drop the matching task record as well
    for tid, info in list(task_manager.active_tasks.items()):
        if info.get("session_id") == session_id:
            task_manager.remove_task(tid)
            break
    return json({"session_id": session_id, "message": "stopped"})


@stream_tile_blueprint.post("/switch_model")
async def api_switch_model(request):
    """
    Swap the model used by a running session.

    Request JSON:  { "session_id": str, "new_model_path": str }
    Response JSON: { "session_id": str, "new_model_path": str, "message": "model switched" }
    """
    data = request.json
    session_id = data.get("session_id")
    new_model = data.get("new_model_path")
    switch_model_session(session_id, new_model)
    return json({"session_id": session_id, "new_model_path": new_model, "message": "model switched"})


@stream_tile_blueprint.get("/sessions")
async def api_list_sessions(request):
    """
    List all live sessions.

    Response JSON: { "sessions": [{"session_id": str, "status": "running"}, ...] }
    """
    sessions = [
        {"session_id": sid, "status": "running"}
        for sid in stream_sessions.keys()
    ]
    return json({"sessions": sessions})


# unified task listing (includes video-stream tasks)
@stream_tile_blueprint.get("/tasks")
async def api_list_tasks(request):
    """List every managed task with status, start time and details."""
    tasks = []
    for tid in task_manager.active_tasks:
        info = task_manager.get_task_info(tid)
        tasks.append({"task_id": tid, **info})
    return json({"tasks": tasks})


##################################################################################################################################################################################################
if __name__ == '__main__':
    # single worker: the in-memory task registries are not shared across processes
    app.run(host="0.0.0.0", port=12366, debug=True, workers=1)

60
b3dm/b3dm_api.py Normal file
View File

@ -0,0 +1,60 @@
from sanic import Sanic, Request, json
from sanic_cors import CORS
import logging
import time
from earthwork_api import earthwork_bp
from terrain_api import terrain_bp
# logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# create the Sanic application
app = Sanic("TerrainAnalysisAPI")
# register blueprints explicitly
app.blueprint(earthwork_bp)
app.blueprint(terrain_bp)
CORS(app, automatic_options=True)


# middleware: time every request
@app.middleware("request")
async def add_start_time(request: Request):
    """Record the request start time on the request context."""
    request.ctx.start_time = time.time()


@app.middleware("response")
async def add_response_time(request: Request, response):
    """Attach the elapsed processing time as an X-Process-Time header."""
    if hasattr(request.ctx, "start_time"):
        process_time = (time.time() - request.ctx.start_time) * 1000
        response.headers["X-Process-Time"] = f"{process_time:.2f}ms"


@app.get("/api/v1/health")
async def health_check(request: Request):
    """Health probe for load balancers / monitoring."""
    return json({
        "status": "healthy",
        "timestamp": time.time(),
        "service": "terrain-analysis-api",
        "version": "1.0.0"
    })


# error handling
@app.exception(Exception)
async def handle_exception(request: Request, exception):
    """Global exception handler: hide details unless running in debug mode."""
    logger.error(f"未处理的异常: {exception}")
    return json({
        "error": "服务器内部错误",
        "message": str(exception) if app.debug else "请稍后重试",
        "success": False
    }, status=500)


if __name__ == "__main__":
    # start the server
    app.run(
        host="0.0.0.0",
        port=8000,
        debug=True,  # set False in production
        access_log=True,
        auto_reload=True
    )

View File

@ -0,0 +1,542 @@
from minio import Minio
from minio.error import S3Error
import json
import os
import numpy as np
from urllib.parse import urlparse
import hashlib
import time
import re
import pickle
from datetime import datetime
class MinIO3DTilesManager:
    def __init__(self, endpoint_url, access_key, secret_key, secure=False,
                 mapping_file="minio_path_mapping.pkl"):
        """
        Initialise the MinIO client.

        Args:
            endpoint_url: MinIO service address (e.g. 222.212.85.86:9001); any
                http:// or https:// prefix is stripped (https forces secure=True)
            access_key: access key
            secret_key: secret key
            secure: whether to use HTTPS
            mapping_file: file name for the persisted path mapping
        """
        if endpoint_url.startswith('http://'):
            endpoint_url = endpoint_url.replace('http://', '')
        elif endpoint_url.startswith('https://'):
            endpoint_url = endpoint_url.replace('https://', '')
            secure = True
        self.endpoint_url = endpoint_url
        self.access_key = access_key
        self.secret_key = secret_key
        self.minio_client = Minio(
            endpoint_url,
            access_key=access_key,
            secret_key=secret_key,
            secure=secure
        )
        # the mapping file lives next to this script
        self.script_dir = os.path.dirname(os.path.abspath(__file__))
        self.mapping_file = os.path.join(self.script_dir, mapping_file)
        # load any previously persisted bucket/object -> local-path mapping
        self.path_mapping = self.load_path_mapping()
def load_path_mapping(self):
"""加载路径映射数据"""
if os.path.exists(self.mapping_file):
try:
with open(self.mapping_file, 'rb') as f:
mapping = pickle.load(f)
return mapping
except Exception as e:
return {}
else:
return {}
def save_path_mapping(self):
"""保存路径映射数据"""
try:
with open(self.mapping_file, 'wb') as f:
pickle.dump(self.path_mapping, f)
return True
except Exception as e:
return False
def get_cache_key(self, tileset_url, save_dir=None):
"""生成缓存键"""
# 基于URL和保存目录生成缓存键
cache_data = f"{tileset_url}|{save_dir}"
return hashlib.md5(cache_data.encode()).hexdigest()
    def get_cached_tileset_info(self, tileset_url, save_dir=None):
        """Return the cached local entry-file path for a tileset, or None."""
        cache_key = self.get_cache_key(tileset_url, save_dir)
        # look for a root entry with this cache key whose local file still exists
        for file_id, info in self.path_mapping.items():
            if info.get('cache_key') == cache_key and info.get('is_tileset_root'):
                local_path = info.get('local_path')
                if local_path and os.path.exists(local_path):
                    return local_path
        return None

    def update_tileset_cache(self, tileset_url, save_dir, local_path):
        """Mark the tileset's root file in the mapping as a cached entry point."""
        # NOTE(review): the local_path argument is currently unused here.
        cache_key = self.get_cache_key(tileset_url, save_dir)
        entry_bucket, entry_path = self.parse_minio_url(tileset_url)
        file_id = f"{entry_bucket}/{entry_path}"
        # only annotate entries that download_file actually recorded
        if file_id in self.path_mapping:
            self.path_mapping[file_id]['cache_key'] = cache_key
            self.path_mapping[file_id]['is_tileset_root'] = True
            self.path_mapping[file_id]['tileset_url'] = tileset_url
            self.path_mapping[file_id]['save_dir'] = save_dir
            self.path_mapping[file_id]['cache_time'] = datetime.now().isoformat()
    def download_full_tileset(self, tileset_url, save_dir=None, region_filter=None, use_cache=True):
        """
        Download a complete 3D Tiles dataset, with optional caching.

        Args:
            tileset_url: URL of tileset.json on MinIO
            save_dir: local target directory (defaults to <script_dir>/data_3dtiles)
            region_filter: optional bounding-volume filter object
            use_cache: reuse the cached copy when the entry file already exists

        Returns:
            tuple (success, result):
              - success: True/False
              - result: on success, the local entry path when use_cache=True,
                otherwise True; on failure, an error-message string
        """
        if save_dir is None:
            save_dir = os.path.join(self.script_dir, "data_3dtiles")
        # sanitise the target directory name
        save_dir = self.clean_file_path(save_dir)
        # cache hit: presence of the entry file is taken to mean the whole
        # tileset is present locally
        if use_cache:
            cached_path = self.get_cached_tileset_info(tileset_url, save_dir)
            if cached_path:
                return True, cached_path
        # resolve bucket/key of the entry file
        entry_bucket, entry_path = self.parse_minio_url(tileset_url)
        if not entry_bucket or not entry_path:
            return False, "无法解析URL"
        entry_dir = os.path.dirname(entry_path)
        os.makedirs(save_dir, exist_ok=True)
        visited = set()
        # download the entry tileset.json
        entry_local_path = self.get_local_path(entry_bucket, entry_path, save_dir)
        success, result = self.download_file(entry_bucket, entry_path, entry_local_path)
        if not success:
            return False, f"入口文件下载失败: {result}"
        entry_id = f"{entry_bucket}/{entry_path}"
        visited.add(entry_id)
        # parse the tileset
        tileset_data = self.load_json_from_minio(entry_bucket, entry_path)
        if not tileset_data or "root" not in tileset_data:
            return False, "无效的tileset.json文件"
        # walk the tile tree and download every referenced file
        self.traverse_and_download_tileset(
            tileset_data["root"],
            entry_bucket,
            entry_dir,
            entry_bucket,
            entry_dir,
            save_dir,
            region_filter,
            None,
            visited
        )
        # record cache metadata and persist the mapping
        self.update_tileset_cache(tileset_url, save_dir, entry_local_path)
        self.save_path_mapping()
        if use_cache:
            return True, entry_local_path
        else:
            return True, True
    def get_tileset_local_path(self, tileset_url, save_dir=None):
        """Return the cached local entry path for a tileset, or None if not cached."""
        if save_dir is None:
            save_dir = os.path.join(self.script_dir, "data_3dtiles")
        # NOTE(review): download_full_tileset cleans save_dir via clean_file_path
        # before caching, but this lookup does not — verify the cache keys match
        # for non-default directories.
        return self.get_cached_tileset_info(tileset_url, save_dir)

    def clear_tileset_cache(self, tileset_url=None, save_dir=None):
        """
        Clear cached tileset entries.

        Args:
            tileset_url: specific tileset to clear; None clears everything
            save_dir: directory the tileset was cached into

        Returns:
            bool: success/failure
        """
        try:
            if tileset_url:
                # remove all mapping entries recorded for this tileset
                cache_key = self.get_cache_key(tileset_url, save_dir)
                to_remove = []
                for file_id, info in self.path_mapping.items():
                    if info.get('cache_key') == cache_key:
                        to_remove.append(file_id)
                for file_id in to_remove:
                    del self.path_mapping[file_id]
                print(f"已清除tileset缓存: {tileset_url}")
            else:
                # clear everything, including the on-disk mapping file
                self.path_mapping = {}
                if os.path.exists(self.mapping_file):
                    os.remove(self.mapping_file)
                print("已清除所有缓存")
            return True
        except Exception as e:
            # best-effort: report failure instead of raising
            return False

    # ---- original helper methods below ----
def clean_filename(self, filename):
"""清理文件名中的特殊字符"""
if not filename:
return ""
cleaned = re.sub(r'[<>:"/\\|?*\x00-\x1F]', '_', filename)
cleaned = re.sub(r'_+', '_', cleaned)
cleaned = cleaned.strip(' _')
return cleaned
def parse_minio_url(self, url):
"""解析MinIO URL"""
if url.startswith('http://') or url.startswith('https://'):
parsed = urlparse(url)
path = parsed.path.lstrip('/')
parts = path.split('/', 1)
if len(parts) == 2:
bucket, key = parts
else:
bucket = parts[0]
key = ""
return bucket, key
else:
parts = url.split('/', 1)
if len(parts) == 2:
bucket, key = parts
else:
bucket = parts[0]
key = ""
return bucket, key
    def download_file(self, bucket_name, object_name, file_path):
        """Download one object from MinIO, reusing a previously recorded download."""
        try:
            os.makedirs(os.path.dirname(file_path), exist_ok=True)
            # sanitised destination path
            # NOTE(review): directories are created for the raw path above but
            # the file is written to the cleaned path — confirm both agree when
            # the path contains special characters.
            clean_file_path = self.clean_file_path(file_path)
            # already downloaded?
            file_id = f"{bucket_name}/{object_name}"
            if file_id in self.path_mapping:
                mapped_path = self.path_mapping[file_id]['local_path']
                if os.path.exists(mapped_path):
                    return True, mapped_path
            # fetch the object
            self.minio_client.fget_object(
                bucket_name,
                object_name,
                clean_file_path
            )
            # record the download in the path mapping
            self.path_mapping[file_id] = {
                'local_path': clean_file_path,
                'bucket': bucket_name,
                'object': object_name,
                'download_time': datetime.now().isoformat(),
                'size': os.path.getsize(clean_file_path)
            }
            return True, clean_file_path
        except S3Error as e:
            return False, str(e)
        except Exception as e:
            return False, str(e)
def clean_file_path(self, file_path):
"""清理文件路径中的所有特殊字符"""
dir_name = os.path.dirname(file_path)
file_name = os.path.basename(file_path)
if dir_name:
dir_parts = dir_name.split(os.sep)
cleaned_parts = []
for part in dir_parts:
cleaned_part = self.clean_filename(part)
if cleaned_part:
cleaned_parts.append(cleaned_part)
cleaned_dir = os.sep.join(cleaned_parts)
else:
cleaned_dir = ""
cleaned_file = self.clean_filename(file_name)
if cleaned_dir:
cleaned_path = os.path.join(cleaned_dir, cleaned_file)
else:
cleaned_path = cleaned_file
return cleaned_path
    def load_json_from_minio(self, bucket_name, object_name):
        """Fetch an object and parse it as UTF-8 JSON; None on any failure."""
        try:
            # stat first so a missing object fails before streaming
            self.minio_client.stat_object(bucket_name, object_name)
            response = self.minio_client.get_object(bucket_name, object_name)
            content = response.read().decode('utf-8')
            response.close()
            response.release_conn()
            return json.loads(content)
        except S3Error as e:
            return None
        except Exception as e:
            return None
def get_local_path(self, bucket_name, object_name, save_dir):
"""生成保持目录结构的本地路径"""
clean_bucket = self.clean_filename(bucket_name)
path_parts = object_name.split('/')
cleaned_parts = []
for part in path_parts:
cleaned_part = self.clean_filename(part)
if cleaned_part:
cleaned_parts.append(cleaned_part)
if cleaned_parts:
cleaned_relative = '/'.join(cleaned_parts)
local_path = os.path.join(save_dir, clean_bucket, cleaned_relative)
else:
local_path = os.path.join(save_dir, clean_bucket)
return os.path.normpath(local_path)
    def traverse_and_download_tileset(self, tile_obj, current_bucket, current_dir,
                                      base_bucket, base_dir, save_dir,
                                      region_filter=None, parent_transform=None,
                                      visited=None):
        """Recursively walk a 3D Tiles tile tree, downloading every referenced file.

        Nested tileset.json contents are recursed into; `visited` (bucket/key ids)
        guards against downloading the same object twice.
        """
        if visited is None:
            visited = set()
        # accumulate 4x4 transforms down the tree (16-float lists per spec)
        current_transform = parent_transform
        if "transform" in tile_obj:
            tile_mat = tile_obj["transform"]
            if current_transform is None:
                current_transform = tile_mat
            else:
                mat1 = np.array(current_transform).reshape(4, 4)
                mat2 = np.array(tile_mat).reshape(4, 4)
                combined_mat = np.dot(mat1, mat2).flatten().tolist()
                current_transform = combined_mat
        # optional spatial filtering on the tile's bounding volume
        skip_current_tile = False
        if region_filter and "boundingVolume" in tile_obj:
            if not region_filter.check_tile_bounding_volume(tile_obj["boundingVolume"]):
                skip_current_tile = True
        if not skip_current_tile and "content" in tile_obj and "uri" in tile_obj["content"]:
            tile_uri = tile_obj["content"]["uri"]
            file_bucket = current_bucket
            file_path = ""
            if tile_uri.startswith('http://') or tile_uri.startswith('https://'):
                # absolute URI: may point into a different bucket
                parsed_bucket, parsed_path = self.parse_minio_url(tile_uri)
                if parsed_bucket:
                    file_bucket = parsed_bucket
                    file_path = parsed_path
            else:
                # relative URI: resolve against the current tileset's directory
                if current_dir:
                    file_path = os.path.join(current_dir, tile_uri).replace('\\', '/')
                else:
                    file_path = tile_uri
            file_path = file_path.lstrip('/')
            file_id = f"{file_bucket}/{file_path}"
            if file_id not in visited:
                print(f"下载文件:{file_id}")
                visited.add(file_id)
                local_path = self.get_local_path(file_bucket, file_path, save_dir)
                self.download_file(file_bucket, file_path, local_path)
                # a .json content is an external tileset — recurse into its root
                if file_path.lower().endswith('.json'):
                    sub_tileset = self.load_json_from_minio(file_bucket, file_path)
                    if sub_tileset and "root" in sub_tileset:
                        sub_dir = os.path.dirname(file_path) if file_path else ""
                        self.traverse_and_download_tileset(
                            sub_tileset["root"],
                            file_bucket,
                            sub_dir,
                            base_bucket,
                            base_dir,
                            save_dir,
                            region_filter,
                            current_transform,
                            visited
                        )
        if "children" in tile_obj:
            for child_tile in tile_obj["children"]:
                self.traverse_and_download_tileset(
                    child_tile,
                    current_bucket,
                    current_dir,
                    base_bucket,
                    base_dir,
                    save_dir,
                    region_filter,
                    current_transform,
                    visited
                )
def upload_file(self, bucket_name, object_name, file_path):
"""上传文件到MinIO"""
try:
if not os.path.exists(file_path):
return False, f"文件不存在: {file_path}"
file_size = os.path.getsize(file_path)
self.minio_client.fput_object(bucket_name, object_name, file_path)
return True, f"{bucket_name}/{object_name}"
except S3Error as e:
return False, f"MinIO上传错误: {e}"
except Exception as e:
return False, f"上传失败: {str(e)}"
def upload_directory(self, bucket_name, local_dir, remote_prefix=""):
"""上传目录到MinIO"""
if not os.path.exists(local_dir):
return [], [f"目录不存在: {local_dir}"]
uploaded_files = []
failed_files = []
for root, dirs, files in os.walk(local_dir):
for file in files:
local_path = os.path.join(root, file)
rel_path = os.path.relpath(local_path, local_dir)
if remote_prefix:
remote_path = os.path.join(remote_prefix, rel_path).replace('\\', '/')
else:
remote_path = rel_path.replace('\\', '/')
success, message = self.upload_file(bucket_name, remote_path, local_path)
if success:
uploaded_files.append(remote_path)
else:
failed_files.append((remote_path, message))
return uploaded_files, failed_files
def check_and_create_bucket(self, bucket_name):
    """Ensure *bucket_name* exists, creating it when missing.

    Returns:
        (True, message) when the bucket exists or was created,
        (False, message) when the MinIO call fails.
    """
    try:
        already_there = self.minio_client.bucket_exists(bucket_name)
        if already_there:
            return True, f"bucket已存在: {bucket_name}"
        self.minio_client.make_bucket(bucket_name)
        return True, f"创建bucket: {bucket_name}"
    except S3Error as e:
        return False, f"创建bucket失败: {e}"
# Usage example
if __name__ == "__main__":
    # Connection settings.
    # NOTE(review): credentials are hard-coded in source — move them to
    # environment variables or a config file.
    ENDPOINT_URL = "222.212.85.86:9000"
    ACCESS_KEY = "WuRenJi"
    SECRET_KEY = "WRJ@2024"
    # Build the manager (plain HTTP, no TLS).
    manager = MinIO3DTilesManager(
        endpoint_url=ENDPOINT_URL,
        access_key=ACCESS_KEY,
        secret_key=SECRET_KEY,
        secure=False
    )
    # Tileset to download (cache-enabled).
    tileset_url = "http://222.212.85.86:9000/300bdf2b-a150-406e-be63-d28bd29b409f/model/石棉0908/terra_b3dms/tileset.json"
    # First download (fetches to local disk).
    success, result = manager.download_full_tileset(tileset_url, use_cache=True)
    if success:
        print(f"下载成功,本地路径: {result}")
    # Second download of the same URL (served straight from cache).
    success, result = manager.download_full_tileset(tileset_url, use_cache=True)
    if success:
        print(f"从缓存获取,本地路径: {result}")
    # Force a re-download, ignoring the cache.
    success, result = manager.download_full_tileset(tileset_url, use_cache=False)
    if success:
        print("强制重新下载成功")
    # Look up the cached local path for a URL.
    local_path = manager.get_tileset_local_path(tileset_url)
    if local_path:
        print(f"缓存的本地路径: {local_path}")

1367
b3dm/data_3dtiles_to_dem.py Normal file

File diff suppressed because it is too large Load Diff

530
b3dm/earthwork_api.py Normal file
View File

@ -0,0 +1,530 @@
# pip install fastapi uvicorn pdal pyvista numpy
from sanic import Blueprint, Request, json
from sanic.response import text
from typing import Dict, Any
import logging
from b3dm.earthwork_calculator_point_cloud import EarthworkCalculatorPointCloud
# 导入计算模块
from b3dm.earthwork_calculator_3d_tiles import EarthworkCalculator3dTiles, AlgorithmType, EarthworkResult3dTiles
from b3dm.tileset_data_source import TilesetDataSource
earthwork_bp = Blueprint("earthwork", url_prefix="")
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialisation helper
def init_app(url, type = "3dtiles"):
    """Download the map at *url* and build the matching calculator.

    Args:
        url: tileset / point-cloud URL to download.
        type: 3D map data format, "3dtiles" or "pointcloud". (Parameter
            name kept for backward compatibility even though it shadows
            the builtin ``type``.)

    Returns:
        dict with keys "data_source", "calculator_3d_tiles" and
        "calculator_point_cloud"; only the calculator matching *type*
        is non-None.

    Raises:
        ValueError: when *type* is not a supported format.
        Exception: download / initialisation errors are logged and re-raised.
    """
    data_source = None
    calculator_3d_tiles = None
    calculator_point_cloud = None
    try:
        # Download the map data first; both calculator kinds need it locally.
        data_source = TilesetDataSource(url)
        data_source.dowload_map_data(url)
        if type == "3dtiles":
            # 3D Tiles calculator
            calculator_3d_tiles = EarthworkCalculator3dTiles(data_source)
        elif type == "pointcloud":
            # Point-cloud calculator
            calculator_point_cloud = EarthworkCalculatorPointCloud(data_source.tileset_path)
        else:
            logger.info(f"不支持的3d地图数据格式:{type}")
            # BUG FIX: this was a bare `raise` with no active exception,
            # which produced "RuntimeError: No active exception to re-raise".
            raise ValueError(f"不支持的3d地图数据格式:{type}")
        logger.info("土方量计算器初始化完成")
        return {
            "data_source": data_source,
            "calculator_3d_tiles": calculator_3d_tiles,
            "calculator_point_cloud": calculator_point_cloud
        }
    except ImportError as e:
        logger.error(f"依赖库缺失: {str(e)}")
        raise
    except Exception as e:
        logger.error(f"初始化失败: {str(e)}")
        raise
# Earthwork-volume endpoint (3D Tiles)
@earthwork_bp.post("/api/v1/calc/earthwork3dTiles")
async def calc_earthwork(request: Request):
    """
    Earthwork volume calculation endpoint.
    Request body example:
    {
        "polygonCoords": [
            [
                115.70440468338526,
                30.77363140345639
            ],
            [
                115.70443054007985,
                30.773510462589584
            ],
            [
                115.70459702429197,
                30.77360789911405
            ]
        ],
        "designElevation": 100,
        "algorithm": "tin",
        "resolution": 1,
        "crs": "EPSG:4326",
        "interpolationMethod": "linear",
        "url": "http://222.212.85.86:9000/300bdf2b-a150-406e-be63-d28bd29b409f/model/hbgldk/yzk/20260113/3D/terra_b3dms/tileset.json"
    }
    """
    try:
        # 1. Parse the request body.
        data = request.json
        if not data:
            return _error_response("请求参数不能为空", 400)
        # 2. Required parameters.
        polygon_coords = data.get("polygonCoords")
        design_elevation = data.get("designElevation")
        url = data.get("url")
        if not polygon_coords:
            return _error_response("多边形坐标不能为空", 400)
        if design_elevation is None:
            return _error_response("设计高程不能为空", 400)
        if url is None:
            return _error_response("地图不能为空", 400)
        # 3. Optional parameters with defaults.
        algorithm = data.get("algorithm", "tin")
        resolution = data.get("resolution", 1.0)
        crs = data.get("crs", "EPSG:4326")
        interpolation_method = data.get("interpolationMethod", "linear")
        # 4. Validation.
        if len(polygon_coords) < 3:
            return _error_response("多边形至少需要3个点", 400)
        # Auto-close the polygon ring when the caller did not.
        if polygon_coords[0] != polygon_coords[-1]:
            polygon_coords.append(polygon_coords[0])
        # Algorithm whitelist.
        if algorithm not in ["grid", "tin", "prism"]:
            return _error_response("算法必须是 grid, tin 或 prism", 400)
        # Resolution bounds.
        if resolution <= 0 or resolution > 100:
            return _error_response("分辨率必须在0-100米之间", 400)
        # 5. Download the map and build a calculator for this URL.
        #    NOTE(review): init_app downloads on every request — presumably
        #    cached downstream; verify, otherwise this is expensive.
        app_info = init_app(url)
        calculator_3d_tiles = app_info.get("calculator_3d_tiles")
        # 6. Run the calculation.
        algorithm_type = AlgorithmType(algorithm)
        result = await calculator_3d_tiles.calculate(
            polygon_coords=polygon_coords,
            design_elevation=design_elevation,
            algorithm=algorithm_type,
            resolution=resolution,
            target_crs=crs,
            interpolation_method=interpolation_method
        )
        # 7. Strip the bulky detail fields before returning.
        res_dict = result.to_dict()
        res_dict["calculation_details"] = None
        res_dict["elevation_statistics"] = None
        res_dict["volume_distribution"] = None
        return _success_response(res_dict)
    except ValueError as e:
        logger.warning(f"参数验证失败: {str(e)}")
        return _error_response(f"参数错误: {str(e)}", 400)
    except Exception as e:
        logger.error(f"计算失败: {str(e)}")
        return _error_response(f"服务器内部错误: {str(e)}", 500)
# Two-phase comparison endpoint (3D Tiles)
@earthwork_bp.post("/api/v1/calc/twoPhaseComparison")
async def two_phase_comparison(request: Request):
    """
    Two-phase (before/after) terrain comparison endpoint.
    Request body example:
    {
        "polygonCoords": [
            [
                115.70440468338526,
                30.77363140345639
            ],
            [
                115.70443054007985,
                30.773510462589584
            ],
            [
                115.70459702429197,
                30.77360789911405
            ]
        ],
        "designElevation": 100,
        "algorithm": "grid",
        "resolution": 1,
        "crs": "EPSG:4326",
        "interpolationMethod": "linear",
        "urlA": "http://222.212.85.86:9000/300bdf2b-a150-406e-be63-d28bd29b409f/model/hbgldk/yzk/20260113/3D/terra_b3dms/tileset.json",
        "urlB": "http://222.212.85.86:9000/300bdf2b-a150-406e-be63-d28bd29b409f/model/hbgldk/yzk/20260113/3D/terra_b3dms/tileset.json"
    }
    """
    try:
        # 1. Parse the request body.
        data = request.json
        if not data:
            return _error_response("请求参数不能为空", 400)
        # 2. Required parameters (designElevation defaults to 1000).
        polygon_coords = data.get("polygonCoords")
        design_elevation = data.get("designElevation", 1000)
        urlA = data.get("urlA")
        urlB = data.get("urlB")
        if not polygon_coords:
            return _error_response("多边形坐标不能为空", 400)
        # Only reachable when the client sends an explicit null.
        if design_elevation is None:
            return _error_response("设计高程不能为空", 400)
        if urlA is None or urlB is None :
            return _error_response("对比地图不能为空", 400)
        # 3. Optional parameters.
        algorithm = data.get("algorithm", "tin")
        resolution = data.get("resolution", 1.0)
        crs = data.get("crs", "EPSG:4326")
        interpolation_method = data.get("interpolationMethod", "linear")
        # 4. Validation.
        if len(polygon_coords) < 3:
            return _error_response("多边形至少需要3个点", 400)
        # Auto-close the polygon ring when the caller did not.
        if polygon_coords[0] != polygon_coords[-1]:
            polygon_coords.append(polygon_coords[0])
        # Algorithm whitelist.
        if algorithm not in ["grid", "tin", "prism"]:
            return _error_response("算法必须是 grid, tin 或 prism", 400)
        # Resolution bounds.
        if resolution <= 0 or resolution > 100:
            return _error_response("分辨率必须在0-100米之间", 400)
        # 5. Build one calculator per epoch (A and B).
        app_info_a = init_app(urlA)
        if not app_info_a.get('data_source').tileset_path :
            return _error_response(f"下载地图失败:{urlA}", 400)
        calculator_3d_tiles_a = app_info_a.get("calculator_3d_tiles")
        app_info_b = init_app(urlB)
        if not app_info_b.get('data_source').tileset_path :
            return _error_response(f"下载地图失败:{urlB}", 400)
        calculator_3d_tiles_b = app_info_b.get("calculator_3d_tiles")
        # 6. Run both calculations.
        # NOTE(review): the validated "algorithm" request field is ignored
        # here — GRID is forced, presumably because compare_grid_cells
        # needs grid cells. Confirm and document in the API contract.
        algorithm_type = AlgorithmType.GRID
        result_a = await calculator_3d_tiles_a.calculate(
            polygon_coords=polygon_coords,
            design_elevation=design_elevation,
            algorithm=algorithm_type,
            resolution=resolution,
            target_crs=crs,
            interpolation_method=interpolation_method
        )
        result_b = await calculator_3d_tiles_b.calculate(
            polygon_coords=polygon_coords,
            design_elevation=design_elevation,
            algorithm=algorithm_type,
            resolution=resolution,
            target_crs=crs,
            interpolation_method=interpolation_method
        )
        # Grid cells of each epoch.
        grids_a = result_a.calculation_details
        grids_b = result_b.calculation_details
        # Cell-by-cell comparison.
        comparison_result = calculator_3d_tiles_a.compare_grid_cells(grids_a, grids_b)
        # Serialise for the response.
        result_dict = comparison_result.to_dict()
        # 7. Success response.
        return _success_response(result_dict)
    except ValueError as e:
        logger.warning(f"参数验证失败: {str(e)}")
        return _error_response(f"参数错误: {str(e)}", 400)
    except Exception as e:
        logger.error(f"计算失败: {str(e)}")
        return _error_response(f"服务器内部错误: {str(e)}", 500)
# Parameter-validation endpoint
@earthwork_bp.post("/api/v1/calc/earthwork3dTiles/validate")
async def validate_earthwork(request: Request):
    """Validate calculation parameters against the referenced tileset."""
    try:
        payload = request.json
        if not payload:
            return _error_response("请求参数不能为空", 400)
        coords = payload.get("polygonCoords")
        if not coords:
            return _error_response("多边形坐标不能为空", 400)
        map_url = payload.get("url")
        if map_url is None:
            return _error_response("地图不能为空", 400)
        if len(coords) < 3:
            return _error_response("多边形至少需要3个点", 400)
        # Close the polygon ring when the caller left it open.
        if coords[0] != coords[-1]:
            coords.append(coords[0])
        # Build a calculator for this map and delegate the validation.
        calculator = init_app(map_url).get("calculator_3d_tiles")
        outcome = await calculator.validate(coords)
        return _success_response(outcome)
    except Exception as e:
        logger.error(f"验证失败: {str(e)}")
        return _error_response(f"验证失败: {str(e)}", 400)
# Algorithm-catalogue endpoint
@earthwork_bp.get("/api/v1/calc/earthwork3dTiles/algorithms")
async def get_algorithms(request: Request):
    """Return the static list of supported algorithms and their parameters."""
    try:
        algorithms = [
            {
                "id": "grid",
                "name": "格网法",
                "description": "将计算区域划分为规则格网,通过插值计算每个格网的高程变化,适合平坦或规则地形",
                "accuracy": "中等",
                "performance": "快速",
                "parameters": {
                    "resolution": {
                        "name": "格网分辨率",
                        "description": "格网大小(米),影响计算精度和性能",
                        "default": 1.0,
                        "range": [0.1, 10.0]
                    }
                }
            },
            {
                "id": "tin",
                "name": "三角网法",
                "description": "基于不规则三角网TIN构建地形表面计算每个三角形的体积变化适合复杂地形",
                # NOTE(review): this "accuracy" value is empty — looks like
                # lost text; confirm the intended label.
                "accuracy": "",
                "performance": "中等",
                "parameters": {
                    "resolution": {
                        "name": "不适用",
                        "description": "三角网法不使用固定的分辨率参数",
                        "default": None
                    }
                }
            },
            {
                "id": "prism",
                "name": "三棱柱法",
                "description": "结合三角网和垂直棱柱的高精度算法,计算每个三棱柱的体积,精度最高",
                "accuracy": "最高",
                "performance": "较慢",
                "parameters": {
                    "resolution": {
                        "name": "棱柱宽度",
                        "description": "棱柱宽度(米),影响计算精度",
                        "default": 1.0,
                        "range": [0.1, 5.0]
                    }
                }
            }
        ]
        return _success_response(algorithms)
    except Exception as e:
        logger.error(f"获取算法列表失败: {str(e)}")
        return _error_response(f"获取算法列表失败: {str(e)}", 500)
# Batch-calculation endpoint
@earthwork_bp.post("/api/v1/calc/earthwork3dTiles/batch")
async def batch_calc_earthwork(request: Request):
    """Run up to 100 independent earthwork calculations in one request.

    Each item validates and computes on its own; a failing item is recorded
    in `errors` and does not abort the batch. Returns results, errors and a
    summary with a success rate.
    """
    try:
        # 1. Parse the request body.
        data = request.json
        if not data:
            return _error_response("请求参数不能为空", 400)
        calculations = data.get("calculations", [])
        if not calculations:
            return _error_response("计算任务列表不能为空", 400)
        if len(calculations) > 100:
            return _error_response("批量计算数量超过限制最多100个", 400)
        # 2. Run every task, collecting successes and failures.
        results = []
        errors = []
        for i, calc_data in enumerate(calculations):
            try:
                # Required per-task parameters.
                polygon_coords = calc_data.get("polygonCoords")
                design_elevation = calc_data.get("designElevation")
                url = calc_data.get("url")
                if not polygon_coords or design_elevation is None or url is None:
                    errors.append({
                        "index": i,
                        "error": "缺少必要参数"
                    })
                    continue
                if len(polygon_coords) < 3:
                    errors.append({
                        "index": i,
                        "error": "多边形至少需要3个点"
                    })
                    continue
                # Auto-close the polygon ring.
                if polygon_coords[0] != polygon_coords[-1]:
                    polygon_coords.append(polygon_coords[0])
                # Optional per-task parameters.
                algorithm = calc_data.get("algorithm", "tin")
                resolution = calc_data.get("resolution", 1.0)
                crs = calc_data.get("crs", "EPSG:4326")
                interpolation_method = calc_data.get("interpolationMethod", "linear")
                # Build a calculator for this task's map.
                app_info = init_app(url)
                calculator_3d_tiles = app_info.get("calculator_3d_tiles")
                # Run the calculation.
                algorithm_type = AlgorithmType(algorithm)
                result = await calculator_3d_tiles.calculate(
                    polygon_coords=polygon_coords,
                    design_elevation=design_elevation,
                    algorithm=algorithm_type,
                    resolution=resolution,
                    target_crs=crs,
                    interpolation_method=interpolation_method
                )
                # BUG FIX: serialise before collecting — the raw result
                # object is not JSON-serialisable (the single-calc endpoint
                # calls .to_dict() for the same type).
                results.append(result.to_dict())
            except Exception as e:
                errors.append({
                    "index": i,
                    "error": str(e),
                    "polygon": polygon_coords if 'polygon_coords' in locals() else None
                })
                continue
        # 3. Assemble the batch summary.  (Removed the unused `message`
        # local that was built but never returned.)
        batch_result = {
            "results": results,
            "errors": errors,
            "summary": {
                "total": len(calculations),
                "success": len(results),
                "failed": len(errors),
                "successRate": f"{(len(results)/len(calculations)*100):.1f}%" if calculations else "0%"
            }
        }
        return _success_response(batch_result)
    except Exception as e:
        logger.error(f"批量计算失败: {str(e)}")
        return _error_response(f"批量计算失败: {str(e)}", 500)
# Core endpoint: earthwork volume over point-cloud data
@earthwork_bp.post("/api/v1/calc/earthworkPointCloud")
async def calc_earthwork_point_cloud(request: Request):
    """Compute earthwork volume from a point-cloud map.

    Body: polygonCoords (required), designElevation (required),
    crs (default "EPSG:4326"), url (required map location).
    """
    try:
        # 1. Parse the request body.
        data = request.json
        if not data:
            # CONSISTENCY FIX: use the shared helper like every other
            # endpoint (identical payload and status to the old inline json()).
            return _error_response("请求参数不能为空", 400)
        polygon_coords = data.get("polygonCoords")  # calculation polygon
        design_elev = data.get("designElevation")   # design elevation
        crs = data.get("crs", "EPSG:4326")          # CRS, WGS84 by default
        url = data.get("url")
        # Validate required fields up front (parity with the 3dTiles
        # endpoint) instead of letting the calculator fail with a 500.
        if not polygon_coords:
            return _error_response("多边形坐标不能为空", 400)
        if design_elev is None:
            return _error_response("设计高程不能为空", 400)
        if url is None:
            return _error_response("地图不能为空", 400)
        # 2. Build a POINT-CLOUD calculator.
        # BUG FIX: init_app(url) used the default "3dtiles" type, so
        # calculator_point_cloud was always None and the call below raised
        # AttributeError (500) on every request.
        app_info = init_app(url, "pointcloud")
        calculator_point_cloud = app_info.get("calculator_point_cloud")
        # NOTE(review): calculate_earthwork in
        # earthwork_calculator_point_cloud.py returns a result object, not a
        # {"success": ...} dict, and format_result is not defined there —
        # verify which calculator implementation this endpoint targets.
        result = calculator_point_cloud.calculate_earthwork(polygon_coords=polygon_coords, design_elev=design_elev, crs=crs)
        # 3. Propagate calculator-reported failure.
        if not result["success"]:
            return _error_response(result["error"], 400)
        # 4. Format and return.
        formatted_result = calculator_point_cloud.format_result(result)
        return _success_response(formatted_result)
    except Exception as e:
        logger.error(f"服务器错误: {str(e)}")
        return _error_response(f"服务器内部错误: {str(e)}", 500)
def _success_response(data: Dict[str, Any]):
    """Wrap *data* in the standard success envelope (code 200, HTTP 200).

    FIX: dropped the former ``-> json`` return annotation — ``json`` here is
    sanic's response *factory function*, not a type; the call returns a sanic
    HTTPResponse.
    """
    return json({
        "code": 200,
        "msg": "计算成功",
        "data": data
    })
def _error_response(message: str, status_code: int = 400):
    """Wrap *message* in the standard error envelope, mirrored in HTTP status.

    FIX: dropped the former ``-> json`` return annotation — ``json`` here is
    sanic's response *factory function*, not a type; the call returns a sanic
    HTTPResponse.
    """
    return json({
        "code": status_code,
        "msg": message,
        "data": None
    }, status=status_code)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,691 @@
# earthwork_calculator.py
import pdal
import pyvista as pv
import numpy as np
import json
from typing import List, Tuple, Dict, Any, Optional
from dataclasses import dataclass, field
from enum import Enum
import traceback
class EarthworkAlgorithm(Enum):
    """Earthwork-volume algorithm identifiers (wire values used by the API)."""
    GRID = "grid"    # regular-grid method
    TIN = "tin"      # triangulated-irregular-network method
    PRISM = "prism"  # per-point prism method
@dataclass
class EarthworkResultPointCloud:
    """Result record for a point-cloud earthwork calculation."""
    cut_volume: float              # excavation volume
    fill_volume: float             # fill volume
    net_volume: float              # net = cut - fill
    area: float                    # area of the computed region
    avg_elevation: float           # mean elevation
    min_elevation: float           # lowest elevation
    max_elevation: float           # highest elevation
    points_count: int              # number of points used
    triangle_count: int = 0        # populated by the TIN algorithm only
    grid_count: int = 0            # populated by the GRID algorithm only
    prism_count: int = 0           # populated by the PRISM algorithm only
    bounding_box: Dict[str, List[float]] = field(default_factory=dict)
    volume_accuracy: float = 0.95  # estimated accuracy (0-1)
    algorithm: str = "TIN三角网法"  # human-readable algorithm label
    resolution: float = 1.0        # computation resolution
    algorithm_params: Dict[str, Any] = field(default_factory=dict)
    def to_dict(self) -> Dict:
        """Serialise into the nested layout the API returns."""
        statistics = {
            "points_count": self.points_count,
            "accuracy": round(self.volume_accuracy, 3),
            "algorithm": self.algorithm,
        }
        # Append the counter matching the algorithm family (label prefix).
        if self.algorithm.startswith("GRID"):
            statistics["grid_count"] = self.grid_count
        elif self.algorithm.startswith("TIN"):
            statistics["triangle_count"] = self.triangle_count
        elif self.algorithm.startswith("PRISM"):
            statistics["prism_count"] = self.prism_count
        return {
            "volume": {
                "cut": round(self.cut_volume, 3),
                "fill": round(self.fill_volume, 3),
                "net": round(self.net_volume, 3),
                "unit": "",
            },
            "area": {
                "value": round(self.area, 3),
                "unit": "",
            },
            "elevation": {
                "average": round(self.avg_elevation, 3),
                "min": round(self.min_elevation, 3),
                "max": round(self.max_elevation, 3),
                "unit": "m",
            },
            "statistics": statistics,
            "bounding_box": self.bounding_box,
            "calculation_params": {
                "resolution": self.resolution,
                "accuracy": self.volume_accuracy,
                **self.algorithm_params,
            },
        }
class EarthworkCalculatorPointCloud:
    """Core earthwork-volume calculator over a point cloud (multiple algorithms)."""
    def __init__(self, point_cloud_path: str = "./data/your_point_cloud.laz"):
        """
        Initialise the calculator.
        Args:
            point_cloud_path: path of the point-cloud (.las/.laz) file.
        """
        self.point_cloud_path = point_cloud_path
def validate_inputs(self, polygon_coords: List[List[float]], design_elev: float) -> Tuple[bool, str]:
    """Check polygon/elevation inputs; return (is_valid, error_message)."""
    if not polygon_coords or len(polygon_coords) < 3:
        return False, "多边形坐标至少需要3个点"
    try:
        float(design_elev)
    except (TypeError, ValueError):
        return False, "设计高程必须是有效数字"
    return True, ""
def create_polygon_string(self, polygon_coords: List[List[float]]) -> str:
    """Build a WKT POLYGON string for PDAL's crop filter (auto-closes the ring)."""
    ring = [f"{c[0]} {c[1]}" for c in polygon_coords if len(c) >= 2]
    # WKT requires the ring to end on its first vertex.
    if ring and ring[0] != ring[-1]:
        ring.append(ring[0])
    return "POLYGON((" + ", ".join(ring) + "))"
def calculate_bounding_box(self, points: np.ndarray) -> Dict[str, List[float]]:
    """Axis-aligned bounding box of an (n, 3) point array.

    Returns:
        {"min": [x, y, z], "max": [x, y, z]} (zeros for an empty array).
    """
    if len(points) == 0:
        return {"min": [0, 0, 0], "max": [0, 0, 0]}
    lo = np.min(points, axis=0)
    hi = np.max(points, axis=0)
    return {
        "min": [float(lo[0]), float(lo[1]), float(lo[2])],
        "max": [float(hi[0]), float(hi[1]), float(hi[2])],
    }
def clip_point_cloud(self, polygon_coords: List[List[float]], crs: str = "EPSG:4326") -> pv.PolyData:
    """
    Clip the point cloud to a polygon via a PDAL crop pipeline.
    Args:
        polygon_coords: polygon vertex list.
        crs: coordinate reference system of the data.
    Returns:
        pyvista.PolyData: the clipped point cloud.
    """
    polygon_str = self.create_polygon_string(polygon_coords)
    # PDAL pipeline: read the LAS/LAZ file, then crop to the polygon.
    pipeline_config = {
        "pipeline": [
            {
                "type": "readers.las",
                "filename": self.point_cloud_path,
                "spatialreference": crs
            },
            {
                "type": "filters.crop",
                "polygon": polygon_str
            }
        ]
    }
    # Execute the PDAL pipeline.
    pipeline = pdal.Pipeline(json.dumps(pipeline_config))
    try:
        pipeline.execute()
        if len(pipeline.arrays) == 0:
            raise ValueError("多边形区域内没有找到点云数据")
        # Extract the XYZ columns of the structured array.
        points = pipeline.arrays[0]
        x = points["X"]
        y = points["Y"]
        z = points["Z"]
        return pv.PolyData(np.column_stack((x, y, z)))
    except RuntimeError as e:
        print(f"PDAL执行失败: {str(e)}")
        # NOTE(review): a real PDAL failure is silently replaced by random
        # mock data — acceptable for demos, dangerous in production; confirm.
        return self.generate_mock_point_cloud(polygon_coords)
def generate_mock_point_cloud(self, polygon_coords: List[List[float]]) -> pv.PolyData:
    """
    Generate a random mock point cloud for testing only.
    Points are uniform over the polygon's bounding box; elevations are
    uniform in [100, 120] m.
    Args:
        polygon_coords: polygon vertex list.
    Returns:
        pyvista.PolyData: mock point cloud (1000 points).
    """
    print("使用模拟数据进行测试...")
    n_points = 1000
    # Polygon bounding box.
    x_coords = [c[0] for c in polygon_coords]
    y_coords = [c[1] for c in polygon_coords]
    x_min, x_max = min(x_coords), max(x_coords)
    y_min, y_max = min(y_coords), max(y_coords)
    # Uniform random points; z simulates terrain between 100 and 120 m.
    x = np.random.uniform(x_min, x_max, n_points)
    y = np.random.uniform(y_min, y_max, n_points)
    z = np.random.uniform(100, 120, n_points)
    return pv.PolyData(np.column_stack((x, y, z)))
def create_tin_mesh(self, point_cloud: "pv.PolyData") -> "pv.PolyData":
    """Delaunay-triangulate the point cloud into a TIN surface."""
    if len(point_cloud.points) < 3:
        raise ValueError("点云数据不足,无法构网")
    try:
        tin = point_cloud.delaunay_2d()
    except Exception as exc:
        raise ValueError(f"三角网构网失败: {str(exc)}")
    return tin
def calculate_volumes_by_tin(self, mesh: "pv.PolyData", design_elev: float) -> Dict[str, Any]:
    """TIN volume computation: sum signed prisms over each triangle of *mesh*.

    Returns:
        {"cut_volume", "fill_volume", "net_volume", "triangle_count"}.
    """
    vertices = mesh.points
    height_delta = vertices[:, 2] - design_elev
    cut_total = 0.0
    fill_total = 0.0
    tri_count = 0
    # VTK cell layout: [n_vertices, i0, i1, i2]; triangles have n_vertices == 3.
    faces = mesh.cells.reshape(-1, 4)
    if len(faces) == 0:
        raise ValueError("无法生成有效的三角网")
    for face in faces:
        if face[0] != 3:
            continue
        tri_count += 1
        idx = face[1:]
        tri = vertices[idx]
        # Triangle area from the cross product of two edge vectors.
        edge_a = tri[1] - tri[0]
        edge_b = tri[2] - tri[0]
        tri_area = 0.5 * np.linalg.norm(np.cross(edge_a, edge_b))
        # Signed prism volume: area times mean height difference.
        signed = tri_area * np.mean(height_delta[idx])
        if signed > 0:
            cut_total += signed
        else:
            fill_total += abs(signed)
    return {
        "cut_volume": cut_total,
        "fill_volume": fill_total,
        "net_volume": cut_total - fill_total,
        "triangle_count": tri_count,
    }
def calculate_volumes_by_grid(self, point_cloud: "pv.PolyData", design_elev: float,
                              grid_size: float = 1.0) -> Dict[str, Any]:
    """GRID volume computation: mean elevation per cell times cell area.

    Cells are half-open [x_i, x_{i+1}) x [y_j, y_{j+1}); empty cells are
    skipped.

    Returns:
        {"cut_volume", "fill_volume", "net_volume", "grid_count"}.
    """
    pts = point_cloud.points
    if len(pts) == 0:
        raise ValueError("点云数据为空")
    xs, ys, zs = pts[:, 0], pts[:, 1], pts[:, 2]
    # Cell edges covering the XY extent of the cloud.
    x_edges = np.arange(xs.min(), xs.max() + grid_size, grid_size)
    y_edges = np.arange(ys.min(), ys.max() + grid_size, grid_size)
    total_cells = (len(x_edges) - 1) * (len(y_edges) - 1)
    cell_area = grid_size * grid_size
    cut_total = 0.0
    fill_total = 0.0
    for col in range(len(x_edges) - 1):
        in_col = (xs >= x_edges[col]) & (xs < x_edges[col + 1])
        for row in range(len(y_edges) - 1):
            sel = in_col & (ys >= y_edges[row]) & (ys < y_edges[row + 1])
            cell_z = zs[sel]
            if len(cell_z) == 0:
                continue
            # Signed cell volume from the mean height difference.
            signed = cell_area * (np.mean(cell_z) - design_elev)
            if signed > 0:
                cut_total += signed
            else:
                fill_total += abs(signed)
    return {
        "cut_volume": cut_total,
        "fill_volume": fill_total,
        "net_volume": cut_total - fill_total,
        "grid_count": total_cells,
    }
def calculate_volumes_by_prism(self, point_cloud: "pv.PolyData", design_elev: float,
                               influence_radius: float = 0.5) -> Dict[str, Any]:
    """PRISM volume computation: each point carries a circular prism.

    Returns:
        {"cut_volume", "fill_volume", "net_volume", "prism_count"}.
    """
    pts = point_cloud.points
    if len(pts) == 0:
        raise ValueError("点云数据为空")
    # Circular footprint contributed by every point.
    footprint = np.pi * influence_radius ** 2
    cut_total = 0.0
    fill_total = 0.0
    for pt in pts:
        signed = footprint * (pt[2] - design_elev)
        if signed > 0:
            cut_total += signed
        else:
            fill_total += abs(signed)
    return {
        "cut_volume": cut_total,
        "fill_volume": fill_total,
        "net_volume": cut_total - fill_total,
        "prism_count": len(pts),
    }
def calculate_statistics(self, point_cloud: "pv.PolyData", mesh: "pv.PolyData" = None) -> Dict[str, float]:
    """
    Compute elevation/area statistics for the clipped cloud.
    Args:
        point_cloud: clipped point cloud.
        mesh: TIN mesh (TIN algorithm only); its area is preferred when given.
    Returns:
        Dict with area, max/min/avg elevation and points_count.
    """
    elevations = point_cloud.points[:, 2]
    stats = {
        "area": 0.0,
        "max_elevation": np.max(elevations) if len(elevations) > 0 else 0.0,
        "min_elevation": np.min(elevations) if len(elevations) > 0 else 0.0,
        "avg_elevation": np.mean(elevations) if len(elevations) > 0 else 0.0,
        "points_count": len(point_cloud.points)
    }
    # Area: prefer the TIN mesh area; otherwise triangulate the cloud,
    # falling back to the bounding-box area when triangulation fails.
    if mesh is not None:
        stats["area"] = mesh.area
    else:
        if len(point_cloud.points) > 0:
            try:
                hull = point_cloud.delaunay_2d()
                stats["area"] = hull.area
            except Exception:
                # BUG FIX: was a bare `except:`, which also swallowed
                # KeyboardInterrupt / SystemExit.
                x_min, x_max = np.min(point_cloud.points[:, 0]), np.max(point_cloud.points[:, 0])
                y_min, y_max = np.min(point_cloud.points[:, 1]), np.max(point_cloud.points[:, 1])
                stats["area"] = (x_max - x_min) * (y_max - y_min)
    return stats
def calculate_earthwork(self,
                        polygon_coords: List[List[float]],
                        design_elev: float,
                        algorithm: str = EarthworkAlgorithm.TIN.value,
                        algorithm_params: Optional[Dict[str, Any]] = None,
                        crs: str = "EPSG:4326",
                        volume_accuracy: Optional[float] = None,
                        resolution: Optional[float] = None) -> EarthworkResultPointCloud:
    """
    Main entry point: run the full earthwork-volume pipeline.
    Args:
        polygon_coords: polygon vertex list.
        design_elev: design elevation.
        algorithm: one of 'grid', 'tin', 'prism'.
        algorithm_params: algorithm-specific parameters.
        crs: coordinate reference system.
        volume_accuracy: accuracy in (0, 1); estimated when None.
        resolution: computation resolution; estimated when None.
    Returns:
        EarthworkResultPointCloud: the calculation result.
    """
    try:
        # 1. Validate inputs.
        is_valid, message = self.validate_inputs(polygon_coords, design_elev)
        if not is_valid:
            raise ValueError(message)
        design_elev = float(design_elev)
        # 2. Validate the algorithm id.
        if algorithm not in [a.value for a in EarthworkAlgorithm]:
            raise ValueError(f"不支持的算法: {algorithm}。支持的算法: {[a.value for a in EarthworkAlgorithm]}")
        # 3. Default parameters.
        if algorithm_params is None:
            algorithm_params = {}
        # 4. Clip the cloud to the polygon.
        point_cloud = self.clip_point_cloud(polygon_coords, crs)
        # 5. Dispatch to the selected algorithm.
        algorithm_name = ""
        mesh = None
        if algorithm == EarthworkAlgorithm.TIN.value:
            algorithm_name = "TIN三角网法"
            mesh = self.create_tin_mesh(point_cloud)
            volumes = self.calculate_volumes_by_tin(mesh, design_elev)
            # NOTE(review): grid_size is not used by the TIN computation yet
            # is echoed back as the algorithm parameter — confirm intent.
            algorithm_params = {
                "grid_size": algorithm_params.get("grid_size", 1.0)
            }
        elif algorithm == EarthworkAlgorithm.GRID.value:
            algorithm_name = "GRID格网法"
            grid_size = algorithm_params.get("grid_size", 1.0)
            algorithm_name = f"GRID格网法(网格尺寸={grid_size}m)"
            volumes = self.calculate_volumes_by_grid(point_cloud, design_elev, grid_size)
            algorithm_params = {
                "grid_size": grid_size
            }
        elif algorithm == EarthworkAlgorithm.PRISM.value:
            algorithm_name = "PRISM棱柱体法"
            influence_radius = algorithm_params.get("influence_radius", 0.5)
            algorithm_name = f"PRISM棱柱体法(影响半径={influence_radius}m)"
            volumes = self.calculate_volumes_by_prism(point_cloud, design_elev, influence_radius)
            algorithm_params = {
                "influence_radius": influence_radius
            }
        # 6. Summary statistics (mesh only set for TIN).
        stats = self.calculate_statistics(point_cloud, mesh)
        # 7. Bounding box of the clipped cloud.
        bounding_box = self.calculate_bounding_box(point_cloud.points)
        # 8. Accuracy / resolution: estimate when not supplied.
        if volume_accuracy is None:
            volume_accuracy = self.estimate_accuracy(algorithm, point_cloud)
        if resolution is None:
            resolution = self.estimate_resolution(point_cloud)
        # 9. Assemble the result object.
        result = EarthworkResultPointCloud(
            cut_volume=volumes["cut_volume"],
            fill_volume=volumes["fill_volume"],
            net_volume=volumes["net_volume"],
            area=stats["area"],
            avg_elevation=stats["avg_elevation"],
            min_elevation=stats["min_elevation"],
            max_elevation=stats["max_elevation"],
            points_count=stats["points_count"],
            triangle_count=volumes.get("triangle_count", 0),
            grid_count=volumes.get("grid_count", 0),
            prism_count=volumes.get("prism_count", 0),
            bounding_box=bounding_box,
            volume_accuracy=volume_accuracy,
            algorithm=algorithm_name,
            resolution=resolution,
            algorithm_params=algorithm_params
        )
        return result
    except Exception as e:
        print(f"计算错误: {str(e)}")
        print(traceback.format_exc())
        raise
def estimate_accuracy(self, algorithm: str, point_cloud: "pv.PolyData") -> float:
    """Heuristic accuracy estimate from algorithm type and point density."""
    density = len(point_cloud.points) / max(point_cloud.area, 0.1)
    # Per-algorithm baselines (keys mirror EarthworkAlgorithm values).
    baseline = {"tin": 0.95, "grid": 0.90, "prism": 0.85}.get(algorithm, 0.90)
    if density > 10:
        # Dense cloud: small bonus, capped at +0.05.
        adjustment = min(0.05, density * 0.002)
    elif density < 1:
        # Sparse cloud: flat penalty.
        adjustment = -0.05
    else:
        adjustment = 0
    # Clamp to a sane [0.7, 0.99] band.
    return max(0.7, min(0.99, baseline + adjustment))
def estimate_resolution(self, point_cloud: "pv.PolyData") -> float:
    """Estimate grid resolution as the mean point spacing sqrt(area / n)."""
    count = len(point_cloud.points)
    if count < 2:
        return 1.0
    surface = point_cloud.area
    if surface <= 0:
        return 1.0
    spacing = np.sqrt(surface / count)
    return float(round(spacing, 2))
def get_algorithm_info(self) -> Dict[str, Any]:
    """
    Describe the supported algorithms and their tunable parameters.
    Returns:
        Dict: static catalogue consumed by API clients.
    """
    return {
        "supported_algorithms": [
            {
                "id": EarthworkAlgorithm.TIN.value,
                "name": "TIN三角网法",
                "description": "通过构建不规则三角网计算土方量,精度高,适合复杂地形",
                "parameters": [
                    {
                        "name": "grid_size",
                        "type": "float",
                        "default": 1.0,
                        "description": "网格尺寸(m),用于点云预处理",
                        "required": False
                    }
                ]
            },
            {
                "id": EarthworkAlgorithm.GRID.value,
                "name": "GRID格网法",
                "description": "将区域划分为规则网格计算土方量,计算速度快,适合大规模区域",
                "parameters": [
                    {
                        "name": "grid_size",
                        "type": "float",
                        "default": 1.0,
                        "description": "网格尺寸(m)",
                        "required": True,
                        "min": 0.1,
                        "max": 10.0
                    }
                ]
            },
            {
                "id": EarthworkAlgorithm.PRISM.value,
                "name": "PRISM棱柱体法",
                "description": "将每个点视为一个棱柱体计算土方量,计算简单快速",
                "parameters": [
                    {
                        "name": "influence_radius",
                        "type": "float",
                        "default": 0.5,
                        "description": "点的影响半径(m)",
                        "required": True,
                        "min": 0.1,
                        "max": 5.0
                    }
                ]
            }
        ],
        "default_algorithm": EarthworkAlgorithm.TIN.value
    }
# Usage example
if __name__ == "__main__":
    # Build a calculator over a sample cloud.
    calculator = EarthworkCalculatorPointCloud("./data/sample_point_cloud.laz")
    # Print the supported-algorithm catalogue.
    algorithm_info = calculator.get_algorithm_info()
    print("支持的算法:")
    for algo in algorithm_info["supported_algorithms"]:
        print(f" {algo['id']}: {algo['name']} - {algo['description']}")
    # Demo polygon (lon/lat).
    polygon = [
        [116.3974, 39.9093],
        [116.4084, 39.9093],
        [116.4084, 39.9193],
        [116.3974, 39.9193]
    ]
    # Demo design elevation (m).
    design_elevation = 100.0
    # Exercise each algorithm with representative parameters.
    algorithms = [
        (EarthworkAlgorithm.TIN.value, {}, "TIN算法"),
        (EarthworkAlgorithm.GRID.value, {"grid_size": 2.0}, "GRID算法(2米网格)"),
        (EarthworkAlgorithm.PRISM.value, {"influence_radius": 1.0}, "PRISM算法(1米影响半径)")
    ]
    for algo_id, params, description in algorithms:
        print(f"\n使用{description}计算:")
        try:
            result = calculator.calculate_earthwork(
                polygon_coords=polygon,
                design_elev=design_elevation,
                algorithm=algo_id,
                algorithm_params=params
            )
            print(f" 挖方量: {result.cut_volume:.3f}")
            print(f" 填方量: {result.fill_volume:.3f}")
            print(f" 净方量: {result.net_volume:.3f}")
            print(f" 计算面积: {result.area:.3f}")
            print(f" 计算精度: {result.volume_accuracy:.3%}")
            print(f" 分辨率: {result.resolution:.2f} m")
        except Exception as e:
            print(f" 计算失败: {str(e)}")

469
b3dm/glb_with_draco.py Normal file
View File

@ -0,0 +1,469 @@
import json
import struct
import numpy as np
import DracoPy
class DracoGLBParser:
    """Parse a GLB file whose meshes use Draco compression, via DracoPy."""
    def __init__(self, glb_file_path):
        self.glb_file_path = glb_file_path
        self.json_data = None      # parsed glTF JSON chunk
        self.binary_data = None    # raw BIN chunk bytes
        self.decoded_meshes = []   # cache of decoded mesh dicts
def parse_glb_structure(self):
    """Parse the binary GLB container: header, JSON chunk, optional BIN chunk.

    Populates self.json_data and (when present) self.binary_data.
    Returns self for chaining.

    Raises:
        ValueError: when the file is not a GLB or a chunk tag is wrong.
    """
    with open(self.glb_file_path, 'rb') as f:
        # 12-byte GLB header: magic, version, total length.
        magic = f.read(4)
        # FIX: the magic bytes were never validated although the JSON/BIN
        # chunk tags below are — fail early on non-GLB input.
        if magic != b'glTF':
            raise ValueError(f"不是有效的 GLB 文件: {magic}")
        version = struct.unpack('<I', f.read(4))[0]
        total_length = struct.unpack('<I', f.read(4))[0]
        print("=" * 60)
        print(f"GLB 文件分析:")
        print(f" 文件类型: {magic.decode('utf-8')}")
        print(f" 版本: {version}")
        print(f" 总大小: {total_length:,} bytes")
        # JSON chunk: length, tag, payload.
        json_length = struct.unpack('<I', f.read(4))[0]
        json_type = f.read(4)
        if json_type != b'JSON':
            raise ValueError(f"期望 JSON chunk但得到: {json_type}")
        self.json_data = json.loads(f.read(json_length).decode('utf-8'))
        print(f" JSON 大小: {json_length:,} bytes")
        # Optional BIN chunk follows when bytes remain.
        if f.tell() < total_length:
            bin_length = struct.unpack('<I', f.read(4))[0]
            bin_type = f.read(4)
            if bin_type != b'BIN\x00':
                raise ValueError(f"期望 BIN chunk但得到: {bin_type}")
            self.binary_data = f.read(bin_length)
            print(f" 二进制数据大小: {bin_length:,} bytes")
        print("=" * 60)
    return self
def analyze_structure(self):
    """Print a human-readable summary of the GLB's glTF structure.

    Parses the file first when needed; returns self for chaining.
    """
    if not self.json_data:
        self.parse_glb_structure()
    print("\nGLB 结构分析:")
    print("-" * 40)
    # Asset metadata.
    asset = self.json_data.get('asset', {})
    print(f"生成器: {asset.get('generator', '未知')}")
    print(f"glTF 版本: {asset.get('version', '未知')}")
    # Draco extension flags.
    extensions_used = self.json_data.get('extensionsUsed', [])
    extensions_required = self.json_data.get('extensionsRequired', [])
    if 'KHR_draco_mesh_compression' in extensions_used:
        print("使用 Draco 压缩")
    if 'KHR_draco_mesh_compression' in extensions_required:
        print("Draco 压缩是必需的")
    # Meshes and their primitives.
    meshes = self.json_data.get('meshes', [])
    print(f"\n网格数量: {len(meshes)}")
    for i, mesh in enumerate(meshes):
        print(f" 网格 {i}: {mesh.get('name', '未命名')}")
        primitives = mesh.get('primitives', [])
        print(f" 图元数量: {len(primitives)}")
        for j, primitive in enumerate(primitives):
            print(f" 图元 {j}:")
            if 'extensions' in primitive:
                draco_info = primitive['extensions'].get('KHR_draco_mesh_compression')
                if draco_info:
                    print(f" 使用 Draco 压缩")
                    print(f" 属性: {draco_info.get('attributes', {})}")
    # Buffer bookkeeping.
    buffers = self.json_data.get('buffers', [])
    buffer_views = self.json_data.get('bufferViews', [])
    accessors = self.json_data.get('accessors', [])
    print(f"\n缓冲区: {len(buffers)}")
    print(f"BufferViews: {len(buffer_views)}")
    print(f"访问器: {len(accessors)}")
    return self
def decode_draco_meshes(self):
    """Decode every Draco-compressed primitive in the file.

    Caches the decoded entries on self.decoded_meshes (one dict per
    primitive with mesh_idx / primitive_idx / name plus the decoded
    payload) and returns the list.
    """
    if not self.json_data:
        self.parse_glb_structure()
    meshes = []
    buffer_views = self.json_data.get('bufferViews', [])
    print("\n" + "=" * 60)
    print("开始解码 Draco 压缩数据...")
    print("=" * 60)
    for mesh_idx, mesh in enumerate(self.json_data.get('meshes', [])):
        mesh_name = mesh.get('name', f'mesh_{mesh_idx}')
        for primitive_idx, primitive in enumerate(mesh.get('primitives', [])):
            if 'extensions' in primitive:
                draco_info = primitive['extensions'].get('KHR_draco_mesh_compression')
                if draco_info:
                    print(f"\n解码: {mesh_name} - 图元 {primitive_idx}")
                    # Decode this primitive's Draco blob.
                    mesh_data = self._decode_primitive(draco_info, buffer_views)
                    if mesh_data:
                        meshes.append({
                            'mesh_idx': mesh_idx,
                            'primitive_idx': primitive_idx,
                            'name': mesh_name,
                            **mesh_data
                        })
    self.decoded_meshes = meshes  # cache for the get_* accessors
    print("\n" + "=" * 60)
    print(f"解码完成!共解码 {len(meshes)} 个网格")
    print("=" * 60)
    return meshes
def get_vertices(self, mesh_idx=0, primitive_idx=0):
    """Return the decoded vertex array of one primitive.

    Args:
        mesh_idx: mesh index (default 0)
        primitive_idx: primitive index (default 0)

    Returns:
        np.ndarray of shape (n, 3), or None when the primitive is unknown.
    """
    # Decode on demand so callers never have to decode explicitly.
    if not self.decoded_meshes:
        self.decode_draco_meshes()

    match = next(
        (entry for entry in self.decoded_meshes
         if entry['mesh_idx'] == mesh_idx and entry['primitive_idx'] == primitive_idx),
        None,
    )
    if match is not None:
        return match.get('vertices')
    print(f"未找到网格 {mesh_idx} 的图元 {primitive_idx}")
    return None
def get_all_vertices(self):
    """Concatenate the vertices of every decoded primitive into one array.

    Returns:
        np.ndarray of shape (n, 3), or None when nothing was decoded.
    """
    # Decode on demand.
    if not self.decoded_meshes:
        self.decode_draco_meshes()
    if not self.decoded_meshes:
        print("没有解码的网格数据")
        return None

    chunks = [entry['vertices'] for entry in self.decoded_meshes
              if entry.get('vertices') is not None]
    if not chunks:
        return None
    # Stack all per-primitive arrays into a single (n, 3) array.
    return np.vstack(chunks)
def get_vertices_by_mesh_name(self, mesh_name):
    """Collect the vertex arrays of every decoded primitive named *mesh_name*.

    Args:
        mesh_name: mesh name to match.

    Returns:
        list of vertex arrays (possibly empty).
    """
    # Decode on demand.
    if not self.decoded_meshes:
        self.decode_draco_meshes()
    return [
        entry['vertices']
        for entry in self.decoded_meshes
        if entry['name'] == mesh_name and entry.get('vertices') is not None
    ]
def get_vertex_count(self):
    """Return the total number of vertices across every decoded mesh (0 when none)."""
    merged = self.get_all_vertices()
    return 0 if merged is None else len(merged)
def _decode_primitive(self, draco_info, buffer_views):
"""解码单个图元的 Draco 数据"""
try:
# 获取 bufferView 信息
buffer_view_idx = draco_info['bufferView']
attributes = draco_info['attributes']
buffer_view = buffer_views[buffer_view_idx]
byte_offset = buffer_view.get('byteOffset', 0)
byte_length = buffer_view['byteLength']
print(f" BufferView: {buffer_view_idx}")
print(f" 属性映射: {attributes}")
print(f" 数据位置: offset={byte_offset}, length={byte_length}")
# 提取 Draco 压缩数据
draco_data = self.binary_data[byte_offset:byte_offset + byte_length]
print(f" Draco 数据大小: {len(draco_data):,} bytes")
# 使用 DracoPy 解码
print(" 正在使用 DracoPy 解码...")
draco_decoder = DracoPy.decode(draco_data)
# 解析解码结果
mesh_data = self._parse_draco_result(draco_decoder, attributes)
return mesh_data
except Exception as e:
print(f" 解码失败: {e}")
import traceback
traceback.print_exc()
return None
def _parse_draco_result(self, draco_decoder, attributes):
"""解析 DracoPy 解码结果"""
result = {
'vertices': None,
'faces': None,
'texcoords': None,
'batch_ids': None,
'normals': None,
'colors': None
}
# 获取顶点
if hasattr(draco_decoder, 'points'):
result['vertices'] = np.array(draco_decoder.points, dtype=np.float32)
print(f" 顶点数量: {len(result['vertices'])}")
# 获取面/三角形
if hasattr(draco_decoder, 'faces'):
faces_data = draco_decoder.faces
# 确保是三角形每面3个顶点
if len(faces_data) > 0:
if isinstance(faces_data[0], list) or isinstance(faces_data[0], tuple):
# 如果是列表的列表
result['faces'] = np.array(faces_data, dtype=np.uint32)
else:
# 如果是扁平化的数组
result['faces'] = np.array(faces_data, dtype=np.uint32).reshape(-1, 3)
print(f" 面数量: {len(result['faces']) if result['faces'] is not None else 0}")
# 获取属性数据
if hasattr(draco_decoder, 'attributes'):
attrs = draco_decoder.attributes
# 根据属性映射查找数据
for gltf_attr_name, draco_attr_id in attributes.items():
if draco_attr_id in attrs:
attr_data = attrs[draco_attr_id]
if gltf_attr_name == 'POSITION':
result['vertices'] = np.array(attr_data, dtype=np.float32)
elif gltf_attr_name == 'TEXCOORD_0':
result['texcoords'] = np.array(attr_data, dtype=np.float32)
elif gltf_attr_name == '_BATCHID':
result['batch_ids'] = np.array(attr_data, dtype=np.uint32)
elif gltf_attr_name == 'NORMAL':
result['normals'] = np.array(attr_data, dtype=np.float32)
elif gltf_attr_name == 'COLOR_0':
result['colors'] = np.array(attr_data, dtype=np.float32)
print(f" 已提取属性: {gltf_attr_name} (ID: {draco_attr_id})")
# 如果没有通过attributes获取到顶点尝试其他方式
if result['vertices'] is None and hasattr(draco_decoder, 'get_points'):
try:
result['vertices'] = np.array(draco_decoder.get_points(), dtype=np.float32)
except:
pass
return result
def save_decoded_meshes(self, meshes, output_format='obj'):
    """Write each decoded mesh to disk as OBJ / PLY / NPZ.

    Args:
        meshes: decoded-mesh dicts as returned by decode_draco_meshes().
        output_format: 'obj', 'ply' or 'npz'; other formats are skipped.

    Files land in '<glb basename>_decoded/' under the working directory.
    """
    import os
    base_name = os.path.splitext(os.path.basename(self.glb_file_path))[0]
    output_dir = f"{base_name}_decoded"
    # exist_ok avoids the race between the exists() check and makedirs().
    os.makedirs(output_dir, exist_ok=True)

    writers = {
        'obj': self._save_as_obj,
        'ply': self._save_as_ply,
        'npz': self._save_as_npz,
    }
    for mesh in meshes:
        filename = f"{output_dir}/{mesh['name']}_p{mesh['primitive_idx']}.{output_format}"
        writer = writers.get(output_format)
        if writer is None:
            print(f"不支持的格式: {output_format}")
            continue
        writer(mesh, filename)
        # BUG FIX: previously printed a literal placeholder instead of the path.
        print(f"已保存: {filename}")
def _save_as_obj(self, mesh, filename):
"""保存为 OBJ 格式"""
with open(filename, 'w') as f:
# 写入顶点
if mesh['vertices'] is not None:
for v in mesh['vertices']:
f.write(f"v {v[0]} {v[1]} {v[2]}\n")
# 写入纹理坐标
if mesh['texcoords'] is not None:
for uv in mesh['texcoords']:
f.write(f"vt {uv[0]} {uv[1]}\n")
# 写入法线
if mesh['normals'] is not None:
for n in mesh['normals']:
f.write(f"vn {n[0]} {n[1]} {n[2]}\n")
# 写入面
if mesh['faces'] is not None:
for face in mesh['faces']:
# OBJ 索引从1开始
face_indices = [str(idx + 1) for idx in face]
f.write(f"f {' '.join(face_indices)}\n")
def _save_as_ply(self, mesh, filename):
    """Write one decoded mesh as a binary PLY file (requires the plyfile package)."""
    from plyfile import PlyData, PlyElement
    import numpy as np

    points = mesh['vertices']
    tris = mesh['faces']
    if points is None:
        return

    # Structured vertex array in the layout plyfile expects.
    vertex_data = np.zeros(len(points), dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])
    vertex_data['x'] = points[:, 0]
    vertex_data['y'] = points[:, 1]
    vertex_data['z'] = points[:, 2]

    elements = [PlyElement.describe(vertex_data, 'vertex')]
    # Faces are optional — point clouds have none.
    if tris is not None:
        face_data = np.zeros(len(tris), dtype=[('vertex_indices', 'i4', (3,))])
        face_data['vertex_indices'] = tris
        elements.append(PlyElement.describe(face_data, 'face'))
    PlyData(elements, text=False).write(filename)
def _save_as_npz(self, mesh, filename):
"""保存为 NPZ 格式"""
np.savez(
filename,
vertices=mesh['vertices'],
faces=mesh['faces'],
texcoords=mesh['texcoords'],
batch_ids=mesh['batch_ids'],
normals=mesh['normals'],
colors=mesh['colors']
)
# Usage example
def main():
    """Demo driver: parse a GLB, decode its Draco meshes, and save them as OBJ."""
    # Initialise the parser (hard-coded local sample file)
    parser = DracoGLBParser(r"D:\devForBdzlWork\ai_project_v1\b3dm\test\temp_glb\temp_6e895637.glb")
    # Parse the GLB container structure
    parser.parse_glb_structure()
    # Print a structural summary
    parser.analyze_structure()
    # Decode the Draco-compressed meshes
    meshes = parser.decode_draco_meshes()
    # Demonstrate the vertex-accessor helpers added above
    print("\n" + "=" * 60)
    print("顶点获取方法演示:")
    print("=" * 60)
    # 1. Vertices of the first primitive of the first mesh
    vertices = parser.get_vertices(mesh_idx=0, primitive_idx=0)
    if vertices is not None:
        print(f"1. 获取第一个网格顶点:")
        print(f" 形状: {vertices.shape}")
        print(f" 数据类型: {vertices.dtype}")
        print(f" 前5个顶点: \n{vertices[:5]}")
    # 2. All vertices merged into one array
    all_vertices = parser.get_all_vertices()
    if all_vertices is not None:
        print(f"\n2. 获取所有网格顶点(合并):")
        print(f" 总顶点数: {len(all_vertices)}")
        print(f" 形状: {all_vertices.shape}")
    # 3. Total vertex count
    total_vertices = parser.get_vertex_count()
    print(f"\n3. 总顶点数: {total_vertices}")
    # 4. Vertices looked up by mesh name
    if meshes:
        mesh_name = meshes[0]['name']
        vertices_list = parser.get_vertices_by_mesh_name(mesh_name)
        print(f"\n4. 根据名称 '{mesh_name}' 获取的顶点:")
        for i, verts in enumerate(vertices_list):
            print(f" 图元 {i}: {verts.shape if verts is not None else 'None'}")
    # Persist the decoded meshes to disk
    parser.save_decoded_meshes(meshes, output_format='obj')
if __name__ == "__main__":
    main()

380
b3dm/slope_aspect_img.py Normal file
View File

@ -0,0 +1,380 @@
# slope_aspect_img13
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import sobel
import matplotlib as mpl
import rasterio
from osgeo import gdal
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d
import os
# Custom 3D arrow patch (matplotlib ships no 3D FancyArrowPatch)
class Arrow3D(FancyArrowPatch):
    """A FancyArrowPatch whose endpoints live in 3D data coordinates."""

    def __init__(self, xs, ys, zs, *args, **kwargs):
        # Start as a degenerate 2D arrow; real endpoints are projected at draw time.
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def do_3d_projection(self, renderer=None):
        """Project the stored 3D endpoints into 2D before drawing."""
        x3, y3, z3 = self._verts3d
        px, py, pz = proj3d.proj_transform(x3, y3, z3, self.axes.M)
        self.set_positions((px[0], py[0]), (px[1], py[1]))
        # Depth value matplotlib uses for draw-order sorting.
        return min(pz)
def read_dem_rasterio(filepath):
    """Read a DEM GeoTIFF with rasterio.

    Args:
        filepath: path to the DEM raster.

    Returns:
        (X, Y, Z): 2-D coordinate grids and the elevation array, or
        (None, None, None) when the file cannot be read.
    """
    try:
        with rasterio.open(filepath) as src:
            dem_data = src.read(1)
            bounds = src.bounds
            crs = src.crs
            print(f"DEM信息:")
            print(f" 尺寸: {dem_data.shape}")
            print(f" 范围: {bounds}")
            print(f" 坐标系: {crs}")
            print(f" 高程范围: {np.nanmin(dem_data):.2f} - {np.nanmax(dem_data):.2f}")
            # Fill NaN holes with the mean elevation so downstream math is safe.
            if np.isnan(dem_data).any():
                print(" 检测到NaN值")
                dem_data = np.nan_to_num(dem_data, nan=np.nanmean(dem_data))
            rows, cols = dem_data.shape
            x = np.linspace(bounds.left, bounds.right, cols)
            y = np.linspace(bounds.bottom, bounds.top, rows)
            X, Y = np.meshgrid(x, y)
            return X, Y, dem_data
    except Exception as e:
        print(f"使用rasterio读取失败: {e}")
        # BUG FIX: callers unpack three values ("X, Y, Z = ..."), so returning a
        # bare None raised TypeError before the "if X is None" guard could run.
        return None, None, None
def read_slope_aspect_by_dem(dem_path, overall_3d_output_path=None):
    """Render a 3D overhead slope/aspect visualization for a DEM GeoTIFF.

    Args:
        dem_path: input DEM path. NOTE: the DEM file is deleted after rendering.
        overall_3d_output_path: output PNG path; derived from dem_path when falsy.

    Returns:
        Path of the saved PNG.

    Raises:
        FileNotFoundError: when the DEM cannot be read.
        ValueError: when the DEM is empty or all zeros.
    """
    # ---------------------- Core: configure CJK-capable fonts ----------------------
    setup_chinese_font()
    # Read the DEM grid
    print("正在读取DEM文件...")
    X, Y, Z = read_dem_rasterio(dem_path)
    if X is None:
        raise FileNotFoundError(f"无法读取DEM文件: {dem_path}")
    # Sanity-check the data
    if Z.size == 0 or np.all(Z == 0):
        raise ValueError("DEM数据无效或全部为0")
    # Downsample large rasters to at most ~200x200 to keep plotting fast
    if Z.shape[0] > 200 or Z.shape[1] > 200:
        print(f"原始DEM尺寸较大 ({Z.shape}),进行重采样...")
        from scipy.ndimage import zoom
        scale_factor = min(200/Z.shape[0], 200/Z.shape[1])
        Z = zoom(Z, scale_factor, order=1)
        X = zoom(X, scale_factor, order=1)
        Y = zoom(Y, scale_factor, order=1)
        print(f"重采样后尺寸: {Z.shape}")
    # ---------------------- 2. Slope & aspect from Sobel gradients ----------------------
    print("计算坡度和坡向...")
    # ~111000 m per degree converts the degree spacing to metres — assumes the
    # DEM is in geographic (lat/lon) coordinates; TODO confirm for projected input.
    dx_pixel = np.abs(X[0,1] - X[0,0]) * 111000
    dy_pixel = np.abs(Y[1,0] - Y[0,0]) * 111000
    dx = sobel(Z, axis=1) / (2 * dx_pixel)
    dy = sobel(Z, axis=0) / (2 * dy_pixel)
    slope_rad = np.arctan(np.sqrt(dx**2 + dy**2))
    slope_deg = slope_rad * (180 / np.pi)
    aspect_rad = np.arctan2(-dx, dy)
    aspect_deg = aspect_rad * (180 / np.pi)
    aspect_deg[aspect_deg < 0] += 360
    print(f"坡度范围: {np.min(slope_deg):.2f}° - {np.max(slope_deg):.2f}°")
    print(f"坡向范围: {np.min(aspect_deg):.2f}° - {np.max(aspect_deg):.2f}°")
    # ---------------------- 3. Overall (dominant) aspect ----------------------
    print("\n计算整体坡向...")
    def calculate_overall_aspect(aspect_deg, slope_deg, method='weighted_mean'):
        """Aggregate per-cell aspects into one dominant direction (circular mean)."""
        aspect_rad = np.deg2rad(aspect_deg)
        if method == 'weighted_mean':
            # Slope-weighted circular mean: steep cells dominate the result.
            u = np.sin(aspect_rad)
            v = np.cos(aspect_rad)
            weights = slope_deg.flatten()
            weighted_u = np.nansum(u.flatten() * weights) / np.nansum(weights)
            weighted_v = np.nansum(v.flatten() * weights) / np.nansum(weights)
            weighted_aspect_rad = np.arctan2(weighted_u, weighted_v)
            weighted_aspect_deg = np.rad2deg(weighted_aspect_rad)
            if weighted_aspect_deg < 0:
                weighted_aspect_deg += 360
            weighted_strength = np.sqrt(weighted_u**2 + weighted_v**2)
            return weighted_aspect_deg, weighted_strength, "坡度加权平均法"
        elif method == 'vector_mean':
            # Plain circular mean of unit direction vectors.
            u = np.sin(aspect_rad)
            v = np.cos(aspect_rad)
            mean_u = np.nanmean(u)
            mean_v = np.nanmean(v)
            mean_aspect_rad = np.arctan2(mean_u, mean_v)
            mean_aspect_deg = np.rad2deg(mean_aspect_rad)
            if mean_aspect_deg < 0:
                mean_aspect_deg += 360
            vector_strength = np.sqrt(mean_u**2 + mean_v**2)
            return mean_aspect_deg, vector_strength, "向量平均法"
    # Use the slope-weighted circular mean as the headline figure.
    overall_aspect, overall_strength, method_name = calculate_overall_aspect(aspect_deg, slope_deg, 'weighted_mean')
    # Map the dominant angle onto an eight-way compass label.
    # BUG FIX: the N / E / S labels were empty strings (characters lost).
    if overall_aspect < 22.5 or overall_aspect >= 337.5:
        overall_direction = "北"
    elif 22.5 <= overall_aspect < 67.5:
        overall_direction = "东北"
    elif 67.5 <= overall_aspect < 112.5:
        overall_direction = "东"
    elif 112.5 <= overall_aspect < 157.5:
        overall_direction = "东南"
    elif 157.5 <= overall_aspect < 202.5:
        overall_direction = "南"
    elif 202.5 <= overall_aspect < 247.5:
        overall_direction = "西南"
    elif 247.5 <= overall_aspect < 292.5:
        overall_direction = "西"
    else:
        overall_direction = "西北"
    print(f"整体坡向 ({method_name}):")
    print(f" 角度: {overall_aspect:.1f}°")
    print(f" 方向: {overall_direction}")
    # print(f" 一致性: {overall_strength:.3f}")
    # ---------------------- 5. 3D overhead visualization ----------------------
    print("\n生成3D俯视图可视化...")
    fig3d, ax3d = plt.subplots(figsize=(16, 12), subplot_kw={"projection": "3d"})
    # Terrain surface colored by slope (robust 5th-95th percentile normalization).
    norm = mpl.colors.Normalize(vmin=np.percentile(slope_deg, 5),
                                vmax=np.percentile(slope_deg, 95))
    plot_skip = max(2, Z.shape[0] // 60)  # denser sampling for a crisper overhead view
    X_plot = X[::plot_skip, ::plot_skip]
    Y_plot = Y[::plot_skip, ::plot_skip]
    Z_plot = Z[::plot_skip, ::plot_skip]
    slope_plot = slope_deg[::plot_skip, ::plot_skip]
    surf = ax3d.plot_surface(
        X_plot, Y_plot, Z_plot,
        cmap="viridis_r",
        alpha=0.85,  # translucent so the arrows read clearly
        linewidth=0.1,  # hairline grid lines
        facecolors=plt.cm.viridis_r(norm(slope_plot)),
        zorder=1
    )
    # ---------------------- Overall-aspect arrow (thick red, centred) ----------------------
    center_x = (X.min() + X.max()) / 2
    center_y = (Y.min() + Y.max()) / 2
    # Arrow length relative to the shorter map extent
    arrow_length_overall = 0.15 * min(X.max()-X.min(), Y.max()-Y.min())
    scale_factor = 1.5
    overall_aspect_rad = np.deg2rad(overall_aspect)
    dx_overall = np.sin(overall_aspect_rad) * arrow_length_overall * scale_factor
    dy_overall = np.cos(overall_aspect_rad) * arrow_length_overall * scale_factor
    # Float the arrow 20% above the highest terrain point so it is never occluded.
    terrain_max_z = Z.max()
    float_height = 0.2 * terrain_max_z
    arrow_overall = Arrow3D(
        [center_x, center_x + dx_overall],
        [center_y, center_y + dy_overall],
        [terrain_max_z + float_height, terrain_max_z + float_height],  # above the terrain
        mutation_scale=25,  # slightly larger arrow head
        lw=5,  # thicker shaft
        arrowstyle='-|>',
        color='red',
        alpha=0.98,
        zorder=20  # draw on top of the surface
    )
    ax3d.add_artist(arrow_overall)
    # Dot marking the arrow origin
    ax3d.scatter(center_x, center_y, terrain_max_z + float_height,
                 color='red', s=120, edgecolor='white', linewidth=2, zorder=21)
    # ---------------------- Summary panel in the upper-left corner ----------------------
    panel_x = X.min() + 0.02 * (X.max() - X.min())  # 2% margin from the left
    panel_y = Y.max() - 0.02 * (Y.max() - Y.min())  # 2% margin from the top
    panel_z = Z.max() + 0.5 * (Z.max() - Z.min())   # raised well above terrain
    overall_info = (
        f"整体坡向分析\n"
        f"角度: {overall_aspect:.1f}°\n"
        f"方向: {overall_direction}"
        # f"一致性: {overall_strength:.3f}"
    )
    ax3d.text(panel_x, panel_y, panel_z,
              overall_info,
              fontsize=12, fontweight='bold',
              bbox=dict(facecolor='white', alpha=0.5, boxstyle="round,pad=0.5",
                        edgecolor='red', linewidth=2.5),
              ha='left', va='top', zorder=30)
    # ---------------------- Axes, title and colorbar ----------------------
    ax3d.set_xlabel("经度 (X)", fontsize=12)
    ax3d.set_ylabel("纬度 (Y)", fontsize=12)
    ax3d.set_zlabel("高程 (m)", fontsize=12)
    filename = os.path.basename(dem_path)
    # BUG FIX: the title previously embedded a literal placeholder instead of
    # the actual DEM file name.
    ax3d.set_title(f"DEM三维俯视图 - 坡向分析\n文件: {filename}",
                   fontsize=14, fontweight='bold', pad=20)
    cbar = plt.colorbar(
        mpl.cm.ScalarMappable(norm=norm, cmap='viridis_r'),
        ax=ax3d, shrink=0.6, aspect=25, pad=0.15
    )
    cbar.set_label("坡度 (°)", fontsize=12)
    # ---------------------- Overhead camera (high elevation angle) ----------------------
    view_elev = 85  # near-vertical bird's-eye view
    # Look from opposite the dominant aspect so the slope face is visible.
    view_azim = (overall_aspect + 180) % 360
    print(f"\n设置3D俯视图视角:")
    print(f" 整体坡向: {overall_aspect:.1f}° ({overall_direction})")
    print(f" 视角方位角: {view_azim:.1f}°")
    print(f" 视角仰角: {view_elev:.1f}°")
    ax3d.view_init(elev=view_elev, azim=view_azim)
    # NOTE(review): Axes3D.dist is deprecated/removed in recent matplotlib;
    # on new versions this assignment has no visual effect.
    ax3d.dist = 9.0
    # Pad axis limits so the arrow and panel stay inside the frame.
    z_min, z_max = Z.min(), Z.max()
    z_padding = 0.6 * (z_max - z_min)
    ax3d.set_zlim(z_min - 0.1*z_padding, z_max + z_padding)
    x_margin = 0.1 * (X.max() - X.min())
    y_margin = 0.1 * (Y.max() - Y.min())
    ax3d.set_xlim(X.min() - x_margin, X.max() + x_margin)
    ax3d.set_ylim(Y.min() - y_margin, Y.max() + y_margin)
    plt.tight_layout()
    # Save the figure; default path sits next to the input DEM.
    if not overall_3d_output_path:
        overall_3d_output_path = dem_path.replace('.tif', '_slopeAspect_3D_overlook.png')
    plt.savefig(overall_3d_output_path, dpi=250, bbox_inches='tight', facecolor='white')
    # plt.show()
    # The DEM is a throwaway intermediate — remove it once rendered.
    os.remove(dem_path)
    print(f"\n3D俯视图已保存: {overall_3d_output_path}")
    print("分析完成!")
    print("="*60)
    return overall_3d_output_path
# ---------------------- Cross-platform Chinese font configuration ----------------------
def setup_chinese_font():
    """Configure matplotlib to render CJK text on Windows, Linux and macOS."""
    import platform

    plt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly
    system = platform.system()

    # Candidate font files per operating system.
    paths_by_system = {
        'Windows': [
            'C:/Windows/Fonts/simhei.ttf',            # SimHei
            'C:/Windows/Fonts/simkai.ttf',            # KaiTi
            'C:/Windows/Fonts/simsun.ttc',            # SimSun
            'C:/Windows/Fonts/microsoftyahei.ttf',    # Microsoft YaHei
        ],
        'Darwin': [
            '/System/Library/Fonts/PingFang.ttc',     # PingFang
            '/System/Library/Fonts/STHeiti Light.ttc',
            '/System/Library/Fonts/STHeiti Medium.ttc',
            '/Library/Fonts/Arial Unicode.ttf',
        ],
        'Linux': [
            '/usr/share/fonts/truetype/wqy/wqy-microhei.ttc',
            '/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc',
            '/usr/share/fonts/truetype/arphic/uming.ttc',
            '/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc',
        ],
    }
    font_paths = paths_by_system.get(system, [])

    # Register the first candidate font file that exists on disk.
    font_added = False
    for font_path in font_paths:
        try:
            if os.path.exists(font_path):
                font_prop = mpl.font_manager.FontProperties(fname=font_path)
                font_name = font_prop.get_name()
                mpl.font_manager.fontManager.addfont(font_path)
                plt.rcParams['font.sans-serif'] = [font_name] + plt.rcParams['font.sans-serif']
                font_added = True
                print(f"已添加中文字体: {font_name} ({font_path})")
                break
        except Exception as e:
            print(f"添加字体 {font_path} 失败: {e}")
            continue

    # Fall back to font-family names when no file could be registered.
    if not font_added:
        fallbacks_by_system = {
            'Windows': ['SimHei', 'Microsoft YaHei', 'KaiTi', 'FangSong'],
            'Darwin': ['PingFang SC', 'Hiragino Sans GB', 'Apple LiGothic Medium'],
        }
        fallback_fonts = fallbacks_by_system.get(
            system,
            ['DejaVu Sans', 'WenQuanYi Micro Hei',
             'Noto Sans CJK SC', 'Heiti TC', 'AR PL UMing CN'],
        )
        current_fonts = plt.rcParams.get('font.sans-serif', [])
        plt.rcParams['font.sans-serif'] = fallback_fonts + current_fonts
        print(f"使用备选字体方案: {fallback_fonts[:2]}...")

    # Make sans-serif the active family so the entries above take effect.
    plt.rcParams['font.family'] = 'sans-serif'
if __name__ == "__main__":
    # Ad-hoc manual test against a local DEM file.
    dem_path = r'D:/devForBdzlWork/ai_project_v1/b3dm/o_dem_f1cb6f69_slopeAspect.tif'
    read_slope_aspect_by_dem(dem_path)

1062
b3dm/slope_aspect_tif.py Normal file

File diff suppressed because it is too large Load Diff

419
b3dm/terrain_api.py Normal file
View File

@ -0,0 +1,419 @@
from sanic import Blueprint, Request, json
from pydantic import BaseModel, Field, field_validator
from typing import List, Optional, Dict, Any
import numpy as np
import logging
import time
import uuid
import urllib.parse
import threading
import os
from b3dm.terrain_calculator import TerrainCalculator
# Sanic blueprint hosting every terrain-analysis endpoint (no URL prefix).
terrain_bp = Blueprint("terrain", url_prefix="")
# MinIO sub-directory that generated slope/aspect artifacts are uploaded into.
MINIO_SUB_PATH = "slopeAspectPng"
# Configure module-level logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Request models
class NormalVector(BaseModel):
    """Surface normal vector supplied by the client (need not be unit length)."""
    nx: float = Field(..., description="法向量X分量")
    ny: float = Field(..., description="法向量Y分量")
    nz: float = Field(..., description="法向量Z分量")
    @field_validator('nx', 'ny', 'nz')
    def check_finite(cls, v):
        # Reject NaN/inf components before any trigonometry runs on them.
        if not np.isfinite(v):
            raise ValueError(f"值必须是有限数字,得到: {v}")
        return v
    def to_list(self):
        # Convenience accessor: [nx, ny, nz] in the order the calculator expects.
        return [self.nx, self.ny, self.nz]
class BatchRequest(BaseModel):
    """Batch of normal vectors for bulk slope/aspect computation."""
    vectors: List[List[float]] = Field(..., description="法向量列表")
    @field_validator('vectors')
    def validate_vectors(cls, v):
        # Cap the batch size and require each entry to be a numeric 3-vector.
        if len(v) > 1000:
            raise ValueError("批量处理最多支持1000个向量")
        for vec in v:
            if len(vec) != 3:
                raise ValueError("每个向量必须是长度为3的列表")
            if not all(isinstance(x, (int, float)) for x in vec):
                raise ValueError("向量元素必须是数字")
        return v
class PointItem(BaseModel):
    """A single 3-D point (presumably lon/lat/height order — TODO confirm against callers)."""
    x: float = Field(..., description="x坐标")
    y: float = Field(..., description="y坐标")
    z: float = Field(..., description="z坐标")
class PointRequest(BaseModel):
    """Region-of-interest points plus the 3D Tiles entry URL to analyse."""
    points: List[PointItem] = Field(..., description="点列表")
    url: str = Field(..., description="URL地址")
    @field_validator('points')
    def validate_points_count(cls, v):
        # Keep DEM-generation requests small.
        if len(v) > 10:
            raise ValueError("批量处理最多支持10个点")
        return v
class PreloadRequest(BaseModel):
    """Request body carrying only the 3D Tiles tileset URL to preload."""
    url: str = Field(..., description="URL地址")
class AnalysisConfig(BaseModel):
    """Optional analysis toggles (not referenced by the endpoints below yet)."""
    classify: bool = Field(default=True, description="是否进行分类")
    include_percent: bool = Field(default=True, description="是否包含坡度百分比")
    include_direction: bool = Field(default=True, description="是否包含方向描述")
# Middleware: request timing
@terrain_bp.middleware("request")
async def add_start_time(request: Request):
    # Stamp the arrival time so the response middleware can compute latency.
    request.ctx.start_time = time.time()
@terrain_bp.middleware("response")
async def add_response_time(request: Request, response):
    """Attach an X-Process-Time header measured from the request middleware stamp."""
    started = getattr(request.ctx, "start_time", None)
    if started is not None:
        elapsed_ms = (time.time() - started) * 1000
        response.headers["X-Process-Time"] = f"{elapsed_ms:.2f}ms"
@terrain_bp.post("/api/v1/calculate/slope")
async def calculate_slope(request: Request):
    """POST endpoint: slope (degrees/radians/percent) from one normal vector."""
    try:
        payload = request.json
        if not payload:
            return json({"error": "请求体不能为空"}, status=400)

        # Validate against the pydantic model; report bad input as 400.
        try:
            vector = NormalVector(**payload)
        except Exception as e:
            return json({"error": f"输入验证失败: {str(e)}"}, status=400)

        # Delegate the math to the calculator.
        result = TerrainCalculator.calculate_slope(vector.to_list())
        if result.get("error"):
            return json(result, status=400)

        return json({
            "success": True,
            "data": result,
            "request": {
                "input_vector": vector.to_list(),
                "timestamp": time.time()
            }
        })
    except Exception as e:
        logger.error(f"坡度计算API错误: {e}")
        return json({
            "error": f"服务器内部错误: {str(e)}",
            "success": False
        }, status=500)
@terrain_bp.post("/api/v1/calculate/aspect")
async def calculate_aspect1(request: Request):
    """POST endpoint: aspect (downslope compass direction) from one normal vector."""
    try:
        payload = request.json
        if not payload:
            return json({"error": "请求体不能为空"}, status=400)

        # Validate against the pydantic model; report bad input as 400.
        try:
            vector = NormalVector(**payload)
        except Exception as e:
            return json({"error": f"输入验证失败: {str(e)}"}, status=400)

        # Delegate the math to the calculator.
        result = TerrainCalculator.calculate_aspect(vector.to_list())
        if result.get("error"):
            return json(result, status=400)

        return json({
            "success": True,
            "data": result,
            "request": {
                "input_vector": vector.to_list(),
                "timestamp": time.time()
            }
        })
    except Exception as e:
        logger.error(f"坡向计算API错误: {e}")
        return json({
            "error": f"服务器内部错误: {str(e)}",
            "success": False
        }, status=500)
@terrain_bp.post("/api/v1/calculate/preload3dTiles")
async def preload_3dtiles(request: Request):
    """POST endpoint: kick off a background download of a 3D Tiles tileset.

    Responds immediately with the local cache directory; the actual download
    continues on a worker thread.
    """
    try:
        data = request.json
        if not data:
            return json({"error": "请求体不能为空"}, status=400)

        # Validate the request body.
        try:
            req = PreloadRequest(**data)
        except Exception as e:
            return json({"error": f"输入验证失败: {str(e)}"}, status=400)

        # Fire-and-forget download; the HTTP response does not wait for it.
        script_dir = os.path.dirname(os.path.abspath(__file__))
        worker = threading.Thread(target=TerrainCalculator.preload_3dtiles, args=(req.url,))
        worker.start()

        # NOTE: an unused url_prefix computation was removed here.
        return json({
            "success": True,
            "data": f"{script_dir}/data_3dtiles",
            "request": {
                "input_vector": req.model_dump(),
                "timestamp": time.time()
            }
        })
    except Exception as e:
        # BUG FIX: the log line previously said "坡向计算API错误" (copy-paste from
        # the aspect endpoint), mislabelling preload failures.
        logger.error(f"3dtiles预加载API错误: {e}")
        return json({
            "error": f"服务器内部错误: {str(e)}",
            "success": False
        }, status=500)
@terrain_bp.post("/api/v1/calculate/slopeAspect")
async def calculate_slopeAspect(request: Request):
    """POST endpoint: render a slope/aspect overhead PNG for a region (async)."""
    try:
        payload = request.json
        if not payload:
            return json({"error": "请求体不能为空"}, status=400)

        # Validate the request body.
        try:
            req = PointRequest(**payload)
        except Exception as e:
            return json({"error": f"输入验证失败: {str(e)}"}, status=400)

        # Region polygon plus a unique name for the artifact to be produced.
        region_coords = [(p.x, p.y, p.z) for p in req.points]
        overall_3d_png_name = f"o_dem_{uuid.uuid4().hex[:8]}_slopeAspect.png"

        # Heavy rendering runs on a worker thread; respond immediately with
        # the URL the artifact will eventually be uploaded to.
        threading.Thread(
            target=TerrainCalculator.generate_slopeAspect_3d_overlook,
            args=(region_coords, req.url, overall_3d_png_name, MINIO_SUB_PATH),
        ).start()

        url_prefix = extract_and_rebuild_url(req.url)
        return json({
            "success": True,
            "data": f"{url_prefix}/{MINIO_SUB_PATH}/{overall_3d_png_name}",
            "request": {
                "input_vector": req.model_dump(),
                "timestamp": time.time()
            }
        })
    except Exception as e:
        logger.error(f"坡向计算API错误: {e}")
        return json({
            "error": f"服务器内部错误: {str(e)}",
            "success": False
        }, status=500)
@terrain_bp.post("/api/v1/calculate/slopeAspectTif")
async def calculate_slopeAspect_tif(request: Request):
    """POST endpoint: generate a slope/aspect GeoTIFF for a region (async)."""
    try:
        payload = request.json
        if not payload:
            return json({"error": "请求体不能为空"}, status=400)

        # Validate the request body.
        try:
            req = PointRequest(**payload)
        except Exception as e:
            return json({"error": f"输入验证失败: {str(e)}"}, status=400)

        # Region polygon plus a unique name for the TIFF to be produced.
        region_coords = [(p.x, p.y, p.z) for p in req.points]
        slope_aspect_tif_name = f"o_dem_{uuid.uuid4().hex[:8]}_slopeAspect.tif"

        # Heavy generation runs on a worker thread; respond immediately with
        # the URL the artifact will eventually be uploaded to.
        threading.Thread(
            target=TerrainCalculator.generate_slopeAspect_tif,
            args=(region_coords, req.url, slope_aspect_tif_name, MINIO_SUB_PATH),
        ).start()

        url_prefix = extract_and_rebuild_url(req.url)
        return json({
            "success": True,
            "data": f"{url_prefix}/{MINIO_SUB_PATH}/{slope_aspect_tif_name}",
            "request": {
                "input_vector": req.model_dump(),
                "timestamp": time.time()
            }
        })
    except Exception as e:
        logger.error(f"坡向计算API错误: {e}")
        return json({
            "error": f"服务器内部错误: {str(e)}",
            "success": False
        }, status=500)
@terrain_bp.post("/api/v1/calculate/both")
async def calculate_both(request: Request):
    """POST endpoint: slope and aspect from one normal vector in a single call."""
    try:
        payload = request.json
        if not payload:
            return json({"error": "请求体不能为空"}, status=400)

        # Validate against the pydantic model; report bad input as 400.
        try:
            vector = NormalVector(**payload)
        except Exception as e:
            return json({"error": f"输入验证失败: {str(e)}"}, status=400)

        # Combined slope + aspect computation.
        combined = TerrainCalculator.calculate_slope_aspect(vector.to_list())
        if combined.get("error"):
            return json(combined, status=400)

        return json({
            "success": True,
            "data": combined,
            "request": {
                "input_vector": vector.to_list(),
                "timestamp": time.time()
            }
        })
    except Exception as e:
        logger.error(f"综合计算API错误: {e}")
        return json({
            "error": f"服务器内部错误: {str(e)}",
            "success": False
        }, status=500)
@terrain_bp.post("/api/v1/calculate/batch")
async def batch_calculate(request: Request):
    """POST endpoint: slope/aspect for up to 1000 normal vectors at once."""
    try:
        payload = request.json
        if not payload:
            return json({"error": "请求体不能为空"}, status=400)

        # Validate the request body (size cap + per-vector shape checks).
        try:
            batch_request = BatchRequest(**payload)
        except Exception as e:
            return json({"error": f"输入验证失败: {str(e)}"}, status=400)

        # Time the bulk computation so throughput can be reported.
        started = time.time()
        result = TerrainCalculator.batch_calculate(batch_request.vectors)
        process_time = (time.time() - started) * 1000
        if result.get("error"):
            return json(result, status=400)

        count = len(batch_request.vectors)
        throughput = count / (process_time / 1000) if process_time > 0 else 0
        return json({
            "success": True,
            "data": result,
            "performance": {
                "process_time_ms": process_time,
                "vectors_per_second": throughput
            },
            "request": {
                "vector_count": count,
                "timestamp": time.time()
            }
        })
    except Exception as e:
        logger.error(f"批量计算API错误: {e}")
        return json({
            "error": f"服务器内部错误: {str(e)}",
            "success": False
        }, status=500)
@terrain_bp.get("/api/v1/example")
async def get_examples(request: Request):
    """GET endpoint: canned normal vectors with their expected slope/aspect.

    Static reference data — useful for manual testing of the POST endpoints.
    """
    examples = {
        "flat": {
            "nx": 0.0,
            "ny": 0.0,
            "nz": 1.0,
            "expected_slope": 0.0,
            "description": "完全水平面"
        },
        "north_slope_30": {
            "nx": 0.0,
            "ny": -0.5,
            "nz": 0.8660254,
            "expected_slope": 30.0,
            "expected_aspect": 0.0,
            "description": "朝北30度斜坡"
        },
        "east_slope_45": {
            "nx": 0.7071068,
            "ny": 0.0,
            "nz": 0.7071068,
            "expected_slope": 45.0,
            "expected_aspect": 90.0,
            "description": "朝东45度斜坡"
        },
        "vertical": {
            "nx": 1.0,
            "ny": 0.0,
            "nz": 0.0,
            "expected_slope": 90.0,
            "description": "垂直面"
        }
    }
    return json({
        "examples": examples,
        "count": len(examples)
    })
def extract_and_rebuild_url(url):
    """Reduce a URL to ``scheme://host[:port]/<first path segment>``.

    Used to derive the MinIO bucket entry URL from a full object URL.
    Falls back to the http scheme when none is present.
    """
    parsed = urllib.parse.urlparse(url)
    scheme = parsed.scheme or "http"  # default scheme when missing
    netloc = parsed.netloc
    # First non-empty path segment (the bucket name), if any.
    segments = [part for part in parsed.path.strip("/").split("/") if part]
    base = f"{scheme}://{netloc}"
    return f"{base}/{segments[0]}" if segments else base

392
b3dm/terrain_calculator.py Normal file
View File

@ -0,0 +1,392 @@
import numpy as np
from typing import List, Tuple, Dict, Any
import logging
import os
import uuid
import b3dm.data_3dtiles_to_dem as data_3dtiles_to_dem
import b3dm.slope_aspect_img as slope_aspect_img
import b3dm.slope_aspect_tif as slope_aspect_tif
from b3dm.tileset_data_source import TilesetDataSource
logger = logging.getLogger(__name__)
# Module-level slot for a shared tileset data source. NOTE(review): the methods
# below assign a *local* of the same name, so this global is never updated —
# confirm whether that shadowing is intentional.
_data_source = None
class TerrainCalculator:
"""地形坡度和坡向计算器"""
@staticmethod
def preload_3dtiles(url: str):
    """Download (cache) the 3D Tiles map data behind *url*.

    Runs on a worker thread whose return value is ignored, so failures are
    reported via logging and a status tuple instead of an exception.

    CONSISTENCY FIX: declared @staticmethod like the sibling calculator
    methods — the bare def only worked because callers used the class object.
    """
    # Local data source; intentionally does not touch the module-level cache.
    data_source = TilesetDataSource(url)
    data_source.dowload_map_data(url)
    if not data_source.tileset_path:
        logger.info(f"下载地图数据失败: {url}")
        return "下载地图数据失败", None
@staticmethod
def generate_slopeAspect_3d_overlook(region_coords, url, overall_3d_png_name, minio_sub_path):
    """Render a 3D overhead slope/aspect PNG for a region and upload it to MinIO.

    Args:
        region_coords: list of (x, y, z) tuples bounding the region.
        url: 3D Tiles tileset entry URL (also identifies the MinIO bucket).
        overall_3d_png_name: file name for the rendered PNG.
        minio_sub_path: bucket sub-directory to upload into.

    Returns:
        (status message, MinIO object path or None).

    CONSISTENCY FIX: declared @staticmethod like the sibling methods; stray
    trailing semicolon removed.
    """
    # Download / cache the 3D Tiles map data.
    data_source = TilesetDataSource(url)
    data_source.dowload_map_data(url)
    if not data_source.tileset_path:
        logger.info(f"下载地图数据失败: {url},{region_coords}")
        return "下载地图数据失败", None

    # Rasterize the tiles into a temporary DEM GeoTIFF.
    tileset_path = data_source.tileset_path
    script_dir = os.path.dirname(os.path.abspath(__file__))
    dem_path = os.path.join(script_dir, f"o_dem_{uuid.uuid4().hex[:8]}.tif")
    data_3dtiles_to_dem.generate_dem(tileset_path, dem_path, region_coords)
    if not os.path.exists(dem_path):
        logger.info(f"生成坡度坡向俯视图失败: {url},{region_coords}")
        return "生成坡度坡向俯视图失败", None

    # Render the overhead view (this call also deletes the temporary DEM).
    overall_3d_png_path = os.path.join(script_dir, overall_3d_png_name)
    slope_aspect_img.read_slope_aspect_by_dem(dem_path, overall_3d_png_path)
    logger.info(f"生成成功: {url},{region_coords},{overall_3d_png_path}")

    # Upload the PNG into the bucket derived from the entry URL.
    entry_bucket, _ = data_source.parse_minio_url(url)
    success, minio_path = data_source.upload_file(entry_bucket, f"{minio_sub_path}/{overall_3d_png_name}", overall_3d_png_path)
    if success:
        return "生成成功", minio_path
    return "生成失败", None
@staticmethod
def generate_slopeAspect_tif(region_coords, url, slope_aspect_tif_name, minio_sub_path):
    """Generate a slope/aspect GeoTIFF for a region and upload it to MinIO.

    Args:
        region_coords: list of (x, y, z) tuples bounding the region.
        url: 3D Tiles tileset entry URL (also identifies the MinIO bucket).
        slope_aspect_tif_name: file name for the generated TIFF.
        minio_sub_path: bucket sub-directory to upload into.

    Returns:
        (status message, MinIO object path or None).

    CONSISTENCY FIX: declared @staticmethod like the sibling methods; stray
    trailing semicolon removed.
    """
    # Download / cache the 3D Tiles map data.
    data_source = TilesetDataSource(url)
    data_source.dowload_map_data(url)
    if not data_source.tileset_path:
        logger.info(f"下载地图数据失败: {url},{region_coords}")
        return "下载地图数据失败", None

    # Rasterize the tiles into a temporary DEM GeoTIFF.
    tileset_path = data_source.tileset_path
    script_dir = os.path.dirname(os.path.abspath(__file__))
    dem_path = os.path.join(script_dir, f"o_dem_{uuid.uuid4().hex[:8]}.tif")
    data_3dtiles_to_dem.generate_dem(tileset_path, dem_path, region_coords)
    if not os.path.exists(dem_path):
        logger.info(f"生成坡度坡向tif失败: {url},{region_coords}")
        return "生成坡度坡向tif失败", None

    # Derive the combined slope/aspect raster from the DEM.
    slope_aspect_tif_path = os.path.join(script_dir, slope_aspect_tif_name)
    slope_aspect_tif.create_slope_aspect(dem_path, 'combined', slope_aspect_tif_path)
    logger.info(f"生成成功: {url},{region_coords},{slope_aspect_tif_path}")

    # Upload the TIFF into the bucket derived from the entry URL.
    entry_bucket, _ = data_source.parse_minio_url(url)
    success, minio_path = data_source.upload_file(entry_bucket, f"{minio_sub_path}/{slope_aspect_tif_name}", slope_aspect_tif_path)
    if success:
        return "生成成功", minio_path
    return "生成失败", None
@staticmethod
def validate_vector(vector: List[float]) -> bool:
"""验证输入向量是否有效"""
if len(vector) != 3:
return False
if not all(isinstance(v, (int, float)) for v in vector):
return False
norm = np.linalg.norm(vector)
return norm > 1e-10 # 避免零向量
@staticmethod
def normalize_vector(vector: List[float]) -> np.ndarray:
"""向量归一化"""
arr = np.array(vector, dtype=np.float64)
norm = np.linalg.norm(arr)
return arr / norm if norm > 0 else arr
@staticmethod
def calculate_slope(normal_vector: List[float]) -> Dict[str, Any]:
    """Compute slope from a surface normal.

    Args:
        normal_vector: normal [nx, ny, nz] (any non-zero length).

    Returns:
        dict with slope in degrees/radians/percent, the normalized vector and
        a slope classification — or an "error" entry on invalid input.
    """
    try:
        # Validate the input first.
        if not TerrainCalculator.validate_vector(normal_vector):
            return {
                "error": "无效的法向量,必须是长度为3的数值列表且不能为零向量",
                "slope_deg": None
            }
        # Normalize so the z component is a direction cosine.
        n = TerrainCalculator.normalize_vector(normal_vector)
        # Slope is the angle between the normal and the vertical axis (arccos).
        nz_abs = abs(n[2])
        # Clamp numerical noise above 1.0. (The original also clamped values
        # below 0.0, but abs() guarantees nz_abs >= 0, so that branch was
        # unreachable and has been removed.)
        if nz_abs > 1.0:
            nz_abs = 1.0
        # Slope in radians, with exact values at the two degenerate extremes.
        if abs(nz_abs - 1.0) < 1e-10:  # perfectly horizontal surface
            slope_rad = 0.0
        elif nz_abs < 1e-10:           # perfectly vertical surface
            slope_rad = np.pi / 2
        else:
            slope_rad = np.arccos(nz_abs)
        # Convert to degrees.
        slope_deg = np.degrees(slope_rad)
        # Percent slope = tan(angle); infinite for a vertical face.
        slope_percent = np.tan(slope_rad) * 100 if slope_rad < np.pi/2 else float('inf')
        return {
            "slope_deg": float(slope_deg),
            "slope_rad": float(slope_rad),
            "slope_percent": float(slope_percent),
            "normalized_vector": n.tolist(),
            "classification": TerrainCalculator.classify_slope(slope_deg)
        }
    except Exception as e:
        logger.error(f"坡度计算错误: {e}")
        return {
            "error": f"计算失败: {str(e)}",
            "slope_deg": None
        }
@staticmethod
def calculate_aspect(normal_vector: List[float]) -> Dict[str, Any]:
    """Compute terrain aspect (downslope compass direction) from a normal.

    Args:
        normal_vector: normal vector [nx, ny, nz]
    Returns:
        dict with the aspect in degrees/radians plus an eight-way compass
        direction; flat surfaces report is_flat=True and aspect_deg=None.
    """
    try:
        if not TerrainCalculator.validate_vector(normal_vector):
            return {
                "error": "无效的法向量,必须是长度为3的数值列表且不能为零向量",
                "aspect_deg": None
            }
        unit = TerrainCalculator.normalize_vector(normal_vector)
        nx, ny, nz = unit
        # A (near-)horizontal surface has no horizontal normal component,
        # hence no defined aspect.
        if np.sqrt(nx * nx + ny * ny) < 1e-10:
            return {
                "aspect_deg": None,
                "aspect_rad": None,
                "is_flat": True,
                "message": "水平面,坡向无定义",
                "normalized_vector": unit.tolist()
            }
        # Compass-style angle: arctan2(nx, ny) measures from north (+y)
        # clockwise — deliberately NOT arctan2(ny, nx).
        facing_rad = np.arctan2(nx, ny)
        # The downslope direction is opposite the horizontal normal: +180°.
        downslope_rad = facing_rad + np.pi
        downslope_deg = np.degrees(downslope_rad) % 360.0
        return {
            "aspect_deg": float(downslope_deg),
            "aspect_rad": float(downslope_rad % (2 * np.pi)),
            "direction": TerrainCalculator.aspect_to_direction(downslope_deg),
            "is_flat": False,
            "normalized_vector": unit.tolist(),
            "raw_angle_deg": float(np.degrees(facing_rad))
        }
    except Exception as e:
        logger.error(f"坡向计算错误: {e}")
        return {
            "error": f"计算失败: {str(e)}",
            "aspect_deg": None
        }
@staticmethod
def calculate_slope_aspect(normal_vector: List[float]) -> Dict[str, Any]:
    """Compute slope and aspect together for one normal vector.

    Args:
        normal_vector: normal vector [nx, ny, nz]
    Returns:
        dict with "slope" and "aspect" sub-results; any per-component
        error messages are additionally collected under "errors".
    """
    try:
        combined = {
            "slope": TerrainCalculator.calculate_slope(normal_vector),
            "aspect": TerrainCalculator.calculate_aspect(normal_vector),
            "input_vector": normal_vector,
            "calculation_time": None  # caller may stamp a timestamp here
        }
        # Surface the individual error messages in one place.
        problems = []
        for label, key in (("坡度", "slope"), ("坡向", "aspect")):
            part = combined[key]
            if part.get("error"):
                problems.append(f"{label}: {part['error']}")
        if problems:
            combined["errors"] = problems
        return combined
    except Exception as e:
        logger.error(f"综合计算错误: {e}")
        return {
            "error": f"综合计算失败: {str(e)}"
        }
@staticmethod
def classify_slope(slope_deg: float) -> Dict[str, Any]:
"""坡度分类"""
if slope_deg < 2:
return {"category": "平坦", "level": 0, "description": "基本平坦"}
elif slope_deg < 5:
return {"category": "缓坡", "level": 1, "description": "适合农业"}
elif slope_deg < 15:
return {"category": "斜坡", "level": 2, "description": "适合建设"}
elif slope_deg < 30:
return {"category": "陡坡", "level": 3, "description": "需要工程措施"}
elif slope_deg < 45:
return {"category": "急陡坡", "level": 4, "description": "高风险区域"}
else:
return {"category": "峭壁", "level": 5, "description": "危险区域"}
@staticmethod
def aspect_to_direction(aspect_deg: float) -> Dict[str, Any]:
"""将坡向转换为八方向"""
directions = ["", "东北", "", "东南", "", "西南", "西", "西北"]
# 计算方向索引 (45°一个区间)
index = int((aspect_deg + 22.5) % 360 / 45)
return {
"chinese": directions[index],
"english": ["N", "NE", "E", "SE", "S", "SW", "W", "NW"][index],
"degree_range": {
"min": (index * 45 - 22.5) % 360,
"max": (index * 45 + 22.5) % 360
}
}
@staticmethod
def batch_calculate(vectors: List[List[float]]) -> Dict[str, Any]:
    """Run calculate_slope_aspect over a list of normal vectors.

    Args:
        vectors: list of normal vectors [nx, ny, nz].
    Returns:
        dict with total/successful/failed counts, per-vector results,
        per-vector errors and summary statistics; a top-level "error" key
        appears only when the batch itself fails.
    """
    try:
        ok, failed = [], []
        for position, normal in enumerate(vectors):
            try:
                entry = TerrainCalculator.calculate_slope_aspect(normal)
                entry["index"] = position
                ok.append(entry)
            except Exception as e:
                failed.append({
                    "index": position,
                    "vector": normal,
                    "error": str(e)
                })
        return {
            "total": len(vectors),
            "successful": len(ok),
            "failed": len(failed),
            "results": ok,
            "errors": failed,
            "statistics": TerrainCalculator.calculate_statistics(ok)
        }
    except Exception as e:
        logger.error(f"批量计算错误: {e}")
        return {"error": f"批量计算失败: {str(e)}"}
@staticmethod
def calculate_statistics(results: List[Dict]) -> Dict[str, Any]:
    """Aggregate slope/aspect statistics over batch results.

    Args:
        results: entries produced by calculate_slope_aspect.
    Returns:
        dict with a "slope" summary (mean/std/min/max/median) and/or
        circular "aspect" statistics; an empty dict when nothing is
        measurable.
    Fix: aspect statistics are now computed independently of slope values —
    previously any batch without valid slopes returned {} and silently
    dropped its aspect data; a dead duplicate ``return {}`` was removed.
    """
    if not results:
        return {}
    slope_values = []
    aspect_values = []
    for r in results:
        slope_part = r.get("slope") or {}
        if slope_part.get("slope_deg") is not None:
            slope_values.append(slope_part["slope_deg"])
        aspect_part = r.get("aspect") or {}
        if aspect_part.get("aspect_deg") is not None:
            aspect_values.append(aspect_part["aspect_deg"])
    stats: Dict[str, Any] = {}
    if slope_values:
        slope_arr = np.array(slope_values)
        stats["slope"] = {
            "count": len(slope_values),
            "mean": float(np.mean(slope_arr)),
            "std": float(np.std(slope_arr)),
            "min": float(np.min(slope_arr)),
            "max": float(np.max(slope_arr)),
            "median": float(np.median(slope_arr))
        }
    if aspect_values:
        aspect_arr = np.array(aspect_values)
        # Aspect is circular data; a plain arithmetic mean would be wrong,
        # so we use the mean direction and the resultant-length concentration.
        stats["aspect"] = {
            "count": len(aspect_values),
            "mean_vector": TerrainCalculator.circular_mean(aspect_arr),
            "concentration": TerrainCalculator.circular_concentration(aspect_arr)
        }
    return stats
@staticmethod
def circular_mean(angles_deg: np.ndarray) -> float:
"""计算循环数据的平均值(角度)"""
angles_rad = np.radians(angles_deg)
x = np.mean(np.cos(angles_rad))
y = np.mean(np.sin(angles_rad))
mean_rad = np.arctan2(y, x)
return np.degrees(mean_rad) % 360
@staticmethod
def circular_concentration(angles_deg: np.ndarray) -> float:
"""计算角度数据的集中度 (0-1)"""
angles_rad = np.radians(angles_deg)
x = np.mean(np.cos(angles_rad))
y = np.mean(np.sin(angles_rad))
return np.sqrt(x*x + y*y)

105
b3dm/tileset_data_source.py Normal file
View File

@ -0,0 +1,105 @@
# tileset_data_source_py
import numpy as np
from typing import List, Dict, Optional, Tuple
import asyncio
from concurrent.futures import ThreadPoolExecutor
import logging
import os
from pathlib import Path
from b3dm.data_3dtiles_manager import MinIO3DTilesManager
import b3dm.data_3dtiles_to_dem as data_3dtiles_to_dem
# Module-level logger for this data-source module.
logger = logging.getLogger(__name__)
# MinIO connection settings shared by every TilesetDataSource instance.
# NOTE(review): endpoint and credentials are hard-coded in source control;
# they should be loaded from environment variables or configuration —
# TODO confirm with deployment/ops before changing.
ENDPOINT_URL = "222.212.85.86:9000"
ACCESS_KEY = "WuRenJi"
SECRET_KEY = "WRJ@2024"
class TilesetDataSource:
    """3D Tiles data source backed by MinIO.

    Downloads a tileset from MinIO, extracts points inside a polygon, and
    reports data bounds. The public interface (method names, signatures,
    return shapes) is unchanged from the previous revision.

    Fixes vs. previous revision:
      * get_data_bounds read ``self._tileset``, an attribute that is never
        assigned anywhere in this class, so the method always raised
        AttributeError; it now degrades gracefully to infinite bounds.
      * The MinIO client construction (endpoint + credentials) was repeated
        in three methods; it is now centralized in ``_make_manager``.
    """

    def __init__(self, url: str, cache_size: int = 1000):
        """
        Args:
            url: MinIO URL (or local path) of the tileset.
            cache_size: kept for interface compatibility; currently unused.
        """
        self.url = url
        self.tileset_path = None  # absolute path of the downloaded tileset.json
        self.tileset_dir = None   # directory containing the downloaded tileset
        self._crs = "EPSG:4979"

    @staticmethod
    def _make_manager():
        """Build a MinIO manager from the module-level connection settings."""
        # NOTE(review): credentials come from hard-coded module constants;
        # move them to configuration — TODO confirm.
        return MinIO3DTilesManager(
            endpoint_url=ENDPOINT_URL,
            access_key=ACCESS_KEY,
            secret_key=SECRET_KEY,
            secure=False
        )

    def parse_minio_url(self, url):
        """Split a MinIO URL into its (bucket, object-path) components."""
        return self._make_manager().parse_minio_url(url)

    def upload_file(self, bucket_name, object_name, file_path):
        """Upload a local file to MinIO.

        The local copy is deleted once the upload succeeds.
        Returns (success flag, remote path).
        """
        flag, path = self._make_manager().upload_file(bucket_name, object_name, file_path)
        if flag:
            os.remove(file_path)
        return flag, path

    def dowload_map_data(self, url: str):
        """Download the full 3D Tiles map data referenced by *url*.

        On success, records the absolute tileset path and its directory on
        the instance. (The name keeps the historical 'dowload' typo so
        existing callers keep working.)
        """
        success, tileset_path = self._make_manager().download_full_tileset(
            tileset_url=url,
            save_dir="data_3dtiles",
            region_filter=None
        )
        if success:
            self.tileset_path = os.path.abspath(tileset_path)
            self.tileset_dir = os.path.dirname(tileset_path)

    async def get_points_in_polygon(self, polygon_coords, z_range=None):
        """Return the tileset points inside *polygon_coords* as an array.

        Args:
            polygon_coords: polygon boundary coordinates.
            z_range: accepted for interface compatibility; currently unused.
        """
        points = data_3dtiles_to_dem.parse_tileset(self.tileset_path, polygon_coords)
        return np.array(points)

    async def get_data_bounds(self) -> Dict[str, List[float]]:
        """Return the axis-aligned min/max bounds of the loaded tileset.

        When no parsed tileset object is available (the usual case — this
        class only records file paths), the sentinel infinite bounds are
        returned instead of raising.
        """
        bounds = {
            "min": [float('inf'), float('inf'), float('inf')],
            "max": [-float('inf'), -float('inf'), -float('inf')]
        }
        tileset = getattr(self, "_tileset", None)
        if tileset is not None and hasattr(tileset.root_tile, 'bounding_volume'):
            bv = tileset.root_tile.bounding_volume
            if hasattr(bv, 'get_corners'):
                corners = bv.get_corners()
                if corners is not None:
                    bounds["min"] = corners.min(axis=0).tolist()
                    bounds["max"] = corners.max(axis=0).tolist()
        return bounds

    def get_crs(self) -> str:
        """Return the coordinate reference system identifier of the data."""
        return self._crs
async def main():
    """Ad-hoc smoke test: build a TilesetDataSource for a local tileset."""
    script_dir = os.path.dirname(os.path.abspath(__file__))
    parent_dir = os.path.dirname(script_dir)
    local_tileset = os.path.join(parent_dir, "data/3dtiles/tileset.json")
    data_source_3d_tiles = TilesetDataSource(local_tileset)
    print("====================================================")
if __name__ == "__main__":
    # asyncio.run replaces the get_event_loop()/run_until_complete pattern,
    # which is deprecated for this use since Python 3.10, and guarantees the
    # event loop is properly closed afterwards.
    asyncio.run(main())

395
b3dm/tileset_to_ply.py Normal file
View File

@ -0,0 +1,395 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
3D Tiles Tileset to PLY Converter
将整个3D Tiles tileset转换为单个PLY文件
"""
import json
import struct
import os
import sys
from pathlib import Path
import numpy as np
# DracoPy is optional: it is required only to decode Draco-compressed glTF
# meshes (KHR_draco_mesh_compression). When it is missing we warn once at
# import time and leave DracoPy as None; Draco-compressed tiles are skipped.
try:
    import DracoPy
except ImportError:
    print("警告: DracoPy库未安装无法处理Draco压缩的数据")
    print("请运行: pip install DracoPy")
    DracoPy = None
class TilesetToPLYConverter:
    """Convert a 3D Tiles tileset (B3DM tiles) into one merged ASCII PLY file.

    Vertices are collected from every B3DM tile referenced by the tileset
    (recursively, through sub-tilesets), transformed by each tile's
    accumulated 4x4 transform, and written out as a single point cloud.

    Fix vs. previous revision: convert_tileset_to_ply parsed the tileset
    twice (once only to test for success) and computed an unused base_path;
    both redundancies were removed.
    """

    def __init__(self):
        self.all_vertices = []  # merged vertex list, one [x, y, z] per entry
        self.vertex_count = 0   # NOTE(review): never updated; len(all_vertices) is used instead

    def multiply_matrix_vector(self, matrix, vector):
        """Multiply a 4x4 transform by a 3D point.

        Args:
            matrix: 16 floats in column-major order (3D Tiles convention).
            vector: point [x, y, z]; promoted to homogeneous [x, y, z, 1].
        Returns:
            the transformed point [x, y, z] (the w row is dropped).
        """
        # Re-pack the column-major list into a row-major 4x4 matrix.
        m = [
            [matrix[0], matrix[4], matrix[8], matrix[12]],
            [matrix[1], matrix[5], matrix[9], matrix[13]],
            [matrix[2], matrix[6], matrix[10], matrix[14]],
            [matrix[3], matrix[7], matrix[11], matrix[15]]
        ]
        v = [vector[0], vector[1], vector[2], 1.0]
        result = [
            m[0][0]*v[0] + m[0][1]*v[1] + m[0][2]*v[2] + m[0][3]*v[3],
            m[1][0]*v[0] + m[1][1]*v[1] + m[1][2]*v[2] + m[1][3]*v[3],
            m[2][0]*v[0] + m[2][1]*v[1] + m[2][2]*v[2] + m[2][3]*v[3]
        ]
        return result

    def multiply_matrices(self, m1, m2):
        """Multiply two 4x4 matrices given as column-major 16-element lists."""
        def list_to_matrix(lst):
            # column-major list -> row-major 4x4
            return [
                [lst[0], lst[4], lst[8], lst[12]],
                [lst[1], lst[5], lst[9], lst[13]],
                [lst[2], lst[6], lst[10], lst[14]],
                [lst[3], lst[7], lst[11], lst[15]]
            ]
        def matrix_to_list(mat):
            # row-major 4x4 -> column-major list
            return [
                mat[0][0], mat[1][0], mat[2][0], mat[3][0],
                mat[0][1], mat[1][1], mat[2][1], mat[3][1],
                mat[0][2], mat[1][2], mat[2][2], mat[3][2],
                mat[0][3], mat[1][3], mat[2][3], mat[3][3]
            ]
        a = list_to_matrix(m1)
        b = list_to_matrix(m2)
        result = [[0 for _ in range(4)] for _ in range(4)]
        for i in range(4):
            for j in range(4):
                for k in range(4):
                    result[i][j] += a[i][k] * b[k][j]
        return matrix_to_list(result)

    def apply_transform_to_vertices(self, vertices, transform_matrix):
        """Apply a 4x4 transform to every vertex; no-op when matrix is falsy."""
        if not transform_matrix:
            return vertices
        transformed_vertices = []
        for vertex in vertices:
            transformed = self.multiply_matrix_vector(transform_matrix, vertex)
            transformed_vertices.append(transformed)
        return transformed_vertices

    def parse_tileset_json(self, tileset_path, parent_transform=None):
        """Parse a tileset.json, collecting (b3dm path, accumulated transform).

        Recurses into child nodes and referenced sub-tilesets, composing the
        node transforms along the way.
        Returns a list of (b3dm_file_path, transform_or_None) tuples.
        """
        try:
            with open(tileset_path, 'r', encoding='utf-8') as f:
                tileset_data = json.load(f)
            b3dm_files = []
            def process_node(node, base_path, accumulated_transform):
                # Compose this node's transform onto the accumulated one.
                current_transform = node.get('transform')
                if current_transform and accumulated_transform:
                    final_transform = self.multiply_matrices(accumulated_transform, current_transform)
                elif current_transform:
                    final_transform = current_transform
                else:
                    final_transform = accumulated_transform
                if 'content' in node and 'uri' in node['content']:
                    uri = node['content']['uri']
                    if uri.endswith('.b3dm'):
                        full_path = os.path.join(base_path, uri)
                        if os.path.exists(full_path):
                            b3dm_files.append((full_path, final_transform))
                    elif uri.endswith('.json'):
                        # Recurse into a referenced sub-tileset.
                        sub_tileset_path = os.path.join(base_path, uri)
                        if os.path.exists(sub_tileset_path):
                            sub_files = self.parse_tileset_json(sub_tileset_path, final_transform)
                            b3dm_files.extend(sub_files)
                if 'children' in node:
                    for child in node['children']:
                        process_node(child, base_path, final_transform)
            base_path = os.path.dirname(tileset_path)
            if 'root' in tileset_data:
                process_node(tileset_data['root'], base_path, parent_transform)
            return b3dm_files
        except Exception as e:
            print(f"解析tileset.json时出错: {e}")
            return []

    def parse_b3dm_file(self, file_path):
        """Parse one B3DM file and return its vertices (or None on failure)."""
        try:
            with open(file_path, 'rb') as f:
                # B3DM header: magic + 6 little-endian uint32 fields = 28 bytes.
                magic = f.read(4)
                if magic != b'b3dm':
                    print(f"警告: {file_path} 不是有效的B3DM文件")
                    return None
                version = struct.unpack('<I', f.read(4))[0]
                byte_length = struct.unpack('<I', f.read(4))[0]
                feature_table_json_byte_length = struct.unpack('<I', f.read(4))[0]
                feature_table_binary_byte_length = struct.unpack('<I', f.read(4))[0]
                batch_table_json_byte_length = struct.unpack('<I', f.read(4))[0]
                batch_table_binary_byte_length = struct.unpack('<I', f.read(4))[0]
                # Skip the feature table and batch table; the glTF payload follows.
                f.seek(28 + feature_table_json_byte_length + feature_table_binary_byte_length +
                       batch_table_json_byte_length + batch_table_binary_byte_length)
                gltf_data = f.read()
                return self.parse_gltf_data(gltf_data)
        except Exception as e:
            print(f"解析B3DM文件 {file_path} 失败: {e}")
            return None

    def parse_gltf_data(self, gltf_data):
        """Dispatch on payload type: binary GLB vs. plain JSON glTF."""
        try:
            if gltf_data[:4] == b'glTF':
                return self.parse_glb_data(gltf_data)
            else:
                gltf_json = json.loads(gltf_data.decode('utf-8'))
                return self.extract_vertices_from_gltf(gltf_json, None)
        except Exception as e:
            print(f"解析glTF数据失败: {e}")
            return None

    def parse_glb_data(self, glb_data):
        """Parse a binary GLB container: header, then JSON and BIN chunks."""
        try:
            # GLB header: magic(4) + version(4) + total length(4).
            magic = glb_data[:4]
            if magic != b'glTF':
                return None
            version = struct.unpack('<I', glb_data[4:8])[0]
            total_length = struct.unpack('<I', glb_data[8:12])[0]
            offset = 12
            json_data = None
            binary_data = None
            while offset < len(glb_data):
                if offset + 8 > len(glb_data):
                    break
                chunk_length = struct.unpack('<I', glb_data[offset:offset+4])[0]
                chunk_type = glb_data[offset+4:offset+8]
                chunk_data = glb_data[offset+8:offset+8+chunk_length]
                if chunk_type == b'JSON':
                    json_data = json.loads(chunk_data.decode('utf-8'))
                elif chunk_type == b'BIN\x00':
                    binary_data = chunk_data
                offset += 8 + chunk_length
                # Chunks are aligned to 4-byte boundaries.
                offset = (offset + 3) & ~3
            if json_data:
                return self.extract_vertices_from_gltf(json_data, binary_data)
        except Exception as e:
            print(f"解析GLB数据失败: {e}")
            return None

    def extract_vertices_from_gltf(self, gltf_json, binary_data):
        """Extract POSITION vertices from a glTF JSON + optional BIN chunk."""
        vertices = []
        try:
            # Draco-compressed meshes need the optional DracoPy decoder.
            if 'extensionsUsed' in gltf_json and 'KHR_draco_mesh_compression' in gltf_json['extensionsUsed']:
                if DracoPy is None:
                    print("警告: 检测到Draco压缩但DracoPy未安装")
                    return vertices
                return self.extract_draco_vertices(gltf_json, binary_data)
            # Standard (uncompressed) glTF path.
            if 'meshes' not in gltf_json:
                return vertices
            for mesh in gltf_json['meshes']:
                for primitive in mesh['primitives']:
                    if 'attributes' in primitive and 'POSITION' in primitive['attributes']:
                        position_accessor_index = primitive['attributes']['POSITION']
                        if 'accessors' in gltf_json and position_accessor_index < len(gltf_json['accessors']):
                            accessor = gltf_json['accessors'][position_accessor_index]
                            buffer_view_index = accessor['bufferView']
                            if 'bufferViews' in gltf_json and buffer_view_index < len(gltf_json['bufferViews']):
                                buffer_view = gltf_json['bufferViews'][buffer_view_index]
                                buffer_index = buffer_view['buffer']
                                byte_offset = buffer_view.get('byteOffset', 0) + accessor.get('byteOffset', 0)
                                if binary_data and buffer_index == 0:
                                    component_type = accessor['componentType']
                                    count = accessor['count']
                                    if component_type == 5126:  # FLOAT
                                        # 3 floats (12 bytes) per vertex.
                                        vertex_data = struct.unpack(f'<{count*3}f',
                                                                    binary_data[byte_offset:byte_offset+count*12])
                                        for i in range(0, len(vertex_data), 3):
                                            vertices.append([vertex_data[i], vertex_data[i+1], vertex_data[i+2]])
        except Exception as e:
            print(f"提取顶点数据失败: {e}")
        return vertices

    def extract_draco_vertices(self, gltf_json, binary_data):
        """Extract vertices from Draco-compressed primitives via DracoPy."""
        vertices = []
        try:
            if 'meshes' not in gltf_json:
                return vertices
            for mesh in gltf_json['meshes']:
                for primitive in mesh['primitives']:
                    if 'extensions' in primitive and 'KHR_draco_mesh_compression' in primitive['extensions']:
                        draco_ext = primitive['extensions']['KHR_draco_mesh_compression']
                        buffer_view_index = draco_ext['bufferView']
                        if 'bufferViews' in gltf_json and buffer_view_index < len(gltf_json['bufferViews']):
                            buffer_view = gltf_json['bufferViews'][buffer_view_index]
                            byte_offset = buffer_view.get('byteOffset', 0)
                            byte_length = buffer_view['byteLength']
                            if binary_data:
                                draco_data = binary_data[byte_offset:byte_offset+byte_length]
                                mesh_data = DracoPy.decode(draco_data)
                                if hasattr(mesh_data, 'points'):
                                    points = mesh_data.points
                                    for point in points:
                                        vertices.append([float(point[0]), float(point[1]), float(point[2])])
        except Exception as e:
            print(f"解码Draco数据失败: {e}")
        return vertices

    def save_ply_file(self, output_path):
        """Write the accumulated vertices as an ASCII PLY point cloud."""
        try:
            with open(output_path, 'w') as f:
                f.write("ply\n")
                f.write("format ascii 1.0\n")
                f.write(f"element vertex {len(self.all_vertices)}\n")
                f.write("property float x\n")
                f.write("property float y\n")
                f.write("property float z\n")
                f.write("end_header\n")
                for vertex in self.all_vertices:
                    f.write(f"{vertex[0]} {vertex[1]} {vertex[2]}\n")
            print(f"PLY文件已保存: {output_path}")
            print(f"总顶点数: {len(self.all_vertices)}")
        except Exception as e:
            print(f"保存PLY文件失败: {e}")

    def convert_tileset_to_ply(self, tileset_path, output_path):
        """Convert a whole tileset into one PLY file; returns True on success."""
        print(f"开始处理tileset: {tileset_path}")
        # Single pass: collect every B3DM file with its accumulated transform.
        b3dm_files = self.parse_tileset_json(tileset_path)
        print(f"找到 {len(b3dm_files)} 个B3DM文件")
        if not b3dm_files:
            print("未找到任何B3DM文件")
            return False
        processed_count = 0
        for i, (b3dm_file, transform_matrix) in enumerate(b3dm_files):
            print(f"处理文件 {i+1}/{len(b3dm_files)}: {os.path.basename(b3dm_file)}")
            vertices = self.parse_b3dm_file(b3dm_file)
            if vertices:
                if transform_matrix:
                    vertices = self.apply_transform_to_vertices(vertices, transform_matrix)
                    print(f" 应用了变换矩阵")
                self.all_vertices.extend(vertices)
                processed_count += 1
                print(f" 提取到 {len(vertices)} 个顶点")
        print(f"\n成功处理 {processed_count}/{len(b3dm_files)} 个文件")
        print(f"总计提取 {len(self.all_vertices)} 个顶点")
        if self.all_vertices:
            self.save_ply_file(output_path)
            return True
        else:
            print("未提取到任何顶点数据")
            return False
def main():
    """CLI entry point: merge a tileset's B3DM tiles into one PLY file."""
    argv = sys.argv
    if len(argv) < 2:
        print("用法: python tileset_to_ply.py <tileset.json路径> [输出PLY文件路径]")
        print("示例: python tileset_to_ply.py tileset.json output.ply")
        return
    tileset_path = argv[1]
    # Default output name when the caller gives only the input path.
    output_path = "merged_tileset.ply" if len(argv) < 3 else argv[2]
    if not os.path.exists(tileset_path):
        print(f"错误: 文件不存在 {tileset_path}")
        return
    ok = TilesetToPLYConverter().convert_tileset_to_ply(tileset_path, output_path)
    if ok:
        print("\n转换完成!")
    else:
        print("\n转换失败!")


if __name__ == "__main__":
    main()

684
b3dm/volume_calculator.py Normal file
View File

@ -0,0 +1,684 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
三维模型体积计算器
基于指定地理范围内的三维模型数据使用三角构网方法计算体积
"""
import json
import struct
import os
import sys
from pathlib import Path
import numpy as np
from scipy.spatial import Delaunay
import math
from tqdm import tqdm
# DracoPy is optional: it is required only to decode Draco-compressed glTF
# meshes (KHR_draco_mesh_compression). When it is missing we warn once at
# import time and leave DracoPy as None; Draco-compressed tiles are skipped.
try:
    import DracoPy
except ImportError:
    print("警告: DracoPy库未安装无法处理Draco压缩的数据")
    print("请运行: pip install DracoPy")
    DracoPy = None
class VolumeCalculator:
def __init__(self, location_file):
self.location_bounds = self.load_location_bounds(location_file)
self.all_vertices = []
self.filtered_vertices = []
def load_location_bounds(self, location_file):
"""加载地理范围边界"""
try:
with open(location_file, 'r', encoding='utf-8') as f:
coords = json.load(f)
# 提取经纬度范围
lons = [coord[0] for coord in coords]
lats = [coord[1] for coord in coords]
elevs = [coord[2] for coord in coords]
bounds = {
'min_lon': min(lons),
'max_lon': max(lons),
'min_lat': min(lats),
'max_lat': max(lats),
'min_elev': min(elevs),
'max_elev': max(elevs),
'coords': coords
}
print(f"地理范围边界:")
print(f" 经度: {bounds['min_lon']:.8f} ~ {bounds['max_lon']:.8f}")
print(f" 纬度: {bounds['min_lat']:.8f} ~ {bounds['max_lat']:.8f}")
print(f" 高程: {bounds['min_elev']:.2f} ~ {bounds['max_elev']:.2f}")
return bounds
except Exception as e:
print(f"加载地理范围文件失败: {e}")
return None
def wgs84_to_cartesian(self, lon, lat, elev):
"""WGS84坐标转换为笛卡尔坐标高精度算法"""
# WGS84椭球参数EPSG:4326
a = 6378137.0 # 长半轴 (米)
f = 1/298.257223563 # 扁率
e2 = 2*f - f*f # 第一偏心率平方
# 角度转弧度
lon_rad = math.radians(lon)
lat_rad = math.radians(lat)
# 计算卯酉圈曲率半径
sin_lat = math.sin(lat_rad)
cos_lat = math.cos(lat_rad)
sin_lon = math.sin(lon_rad)
cos_lon = math.cos(lon_rad)
N = a / math.sqrt(1 - e2 * sin_lat * sin_lat)
# 高精度笛卡尔坐标计算
x = (N + elev) * cos_lat * cos_lon
y = (N + elev) * cos_lat * sin_lon
z = (N * (1 - e2) + elev) * sin_lat
return [x, y, z]
def cartesian_to_wgs84(self, x, y, z):
"""笛卡尔坐标转换为WGS84坐标高精度迭代法"""
# WGS84椭球参数
a = 6378137.0 # 长半轴
f = 1/298.257223563 # 扁率
e2 = 2*f - f*f # 第一偏心率平方
ep2 = e2 / (1 - e2) # 第二偏心率平方
# 计算经度(精确值)
lon = math.atan2(y, x)
# 计算纬度和高程使用Bowring迭代法
p = math.sqrt(x*x + y*y)
if p == 0:
# 极点情况
lat = math.pi/2 if z > 0 else -math.pi/2
elev = abs(z) - a * math.sqrt(1 - e2)
return [math.degrees(lon), math.degrees(lat), elev]
# 初始估计
theta = math.atan2(z, p * (1 - f))
lat_prev = math.atan2(z + ep2 * a * (1 - f) * math.sin(theta)**3,
p - e2 * a * math.cos(theta)**3)
# 迭代求解纬度
max_iterations = 10
tolerance = 1e-12
for i in range(max_iterations):
N = a / math.sqrt(1 - e2 * math.sin(lat_prev)**2)
elev = p / math.cos(lat_prev) - N
# 更新纬度估计
lat_new = math.atan2(z + e2 * N * math.sin(lat_prev), p)
# 检查收敛性
if abs(lat_new - lat_prev) < tolerance:
break
lat_prev = lat_new
# 最终计算高程
N = a / math.sqrt(1 - e2 * math.sin(lat_prev)**2)
elev = p / math.cos(lat_prev) - N
return [math.degrees(lon), math.degrees(lat_prev), elev]
def is_point_in_bounds(self, vertex):
"""检查点是否在指定的地理范围内"""
if not self.location_bounds:
return True
# 将笛卡尔坐标转换为WGS84
try:
lon, lat, elev = self.cartesian_to_wgs84(vertex[0], vertex[1], vertex[2])
# 检查是否在边界范围内
return (self.location_bounds['min_lon'] <= lon <= self.location_bounds['max_lon'] and
self.location_bounds['min_lat'] <= lat <= self.location_bounds['max_lat'])
except:
return False
def multiply_matrix_vector(self, matrix, vector):
"""4x4矩阵与4D向量相乘"""
m = [
[matrix[0], matrix[4], matrix[8], matrix[12]],
[matrix[1], matrix[5], matrix[9], matrix[13]],
[matrix[2], matrix[6], matrix[10], matrix[14]],
[matrix[3], matrix[7], matrix[11], matrix[15]]
]
v = [vector[0], vector[1], vector[2], 1.0]
result = [
m[0][0]*v[0] + m[0][1]*v[1] + m[0][2]*v[2] + m[0][3]*v[3],
m[1][0]*v[0] + m[1][1]*v[1] + m[1][2]*v[2] + m[1][3]*v[3],
m[2][0]*v[0] + m[2][1]*v[1] + m[2][2]*v[2] + m[2][3]*v[3]
]
return result
def multiply_matrices(self, m1, m2):
"""两个4x4矩阵相乘"""
def list_to_matrix(lst):
return [
[lst[0], lst[4], lst[8], lst[12]],
[lst[1], lst[5], lst[9], lst[13]],
[lst[2], lst[6], lst[10], lst[14]],
[lst[3], lst[7], lst[11], lst[15]]
]
def matrix_to_list(mat):
return [
mat[0][0], mat[1][0], mat[2][0], mat[3][0],
mat[0][1], mat[1][1], mat[2][1], mat[3][1],
mat[0][2], mat[1][2], mat[2][2], mat[3][2],
mat[0][3], mat[1][3], mat[2][3], mat[3][3]
]
a = list_to_matrix(m1)
b = list_to_matrix(m2)
result = [[0 for _ in range(4)] for _ in range(4)]
for i in range(4):
for j in range(4):
for k in range(4):
result[i][j] += a[i][k] * b[k][j]
return matrix_to_list(result)
def apply_transform_to_vertices(self, vertices, transform_matrix):
"""对顶点应用变换矩阵"""
if not transform_matrix:
return vertices
transformed_vertices = []
for vertex in vertices:
transformed = self.multiply_matrix_vector(transform_matrix, vertex)
transformed_vertices.append(transformed)
return transformed_vertices
def parse_tileset_json(self, tileset_path, parent_transform=None):
"""解析tileset.json文件收集B3DM文件和变换矩阵"""
try:
with open(tileset_path, 'r', encoding='utf-8') as f:
tileset_data = json.load(f)
b3dm_files = []
def process_node(node, base_path, accumulated_transform):
current_transform = node.get('transform')
if current_transform and accumulated_transform:
final_transform = self.multiply_matrices(accumulated_transform, current_transform)
elif current_transform:
final_transform = current_transform
else:
final_transform = accumulated_transform
if 'content' in node and 'uri' in node['content']:
uri = node['content']['uri']
if uri.endswith('.b3dm'):
full_path = os.path.join(base_path, uri)
if os.path.exists(full_path):
b3dm_files.append((full_path, final_transform))
elif uri.endswith('.json'):
sub_tileset_path = os.path.join(base_path, uri)
if os.path.exists(sub_tileset_path):
sub_files = self.parse_tileset_json(sub_tileset_path, final_transform)
b3dm_files.extend(sub_files)
if 'children' in node:
for child in node['children']:
process_node(child, base_path, final_transform)
base_path = os.path.dirname(tileset_path)
if 'root' in tileset_data:
process_node(tileset_data['root'], base_path, parent_transform)
return b3dm_files
except Exception as e:
print(f"解析tileset.json时出错: {e}")
return []
def parse_b3dm_file(self, file_path):
"""解析B3DM文件"""
try:
with open(file_path, 'rb') as f:
magic = f.read(4)
if magic != b'b3dm':
return None
version = struct.unpack('<I', f.read(4))[0]
byte_length = struct.unpack('<I', f.read(4))[0]
feature_table_json_byte_length = struct.unpack('<I', f.read(4))[0]
feature_table_binary_byte_length = struct.unpack('<I', f.read(4))[0]
batch_table_json_byte_length = struct.unpack('<I', f.read(4))[0]
batch_table_binary_byte_length = struct.unpack('<I', f.read(4))[0]
f.seek(28 + feature_table_json_byte_length + feature_table_binary_byte_length +
batch_table_json_byte_length + batch_table_binary_byte_length)
gltf_data = f.read()
return self.parse_gltf_data(gltf_data)
except Exception as e:
print(f"解析B3DM文件 {file_path} 失败: {e}")
return None
def parse_gltf_data(self, gltf_data):
"""解析glTF数据"""
try:
if gltf_data[:4] == b'glTF':
return self.parse_glb_data(gltf_data)
else:
gltf_json = json.loads(gltf_data.decode('utf-8'))
return self.extract_vertices_from_gltf(gltf_json, None)
except Exception as e:
print(f"解析glTF数据失败: {e}")
return None
def parse_glb_data(self, glb_data):
"""解析GLB格式的glTF数据"""
try:
magic = glb_data[:4]
if magic != b'glTF':
return None
version = struct.unpack('<I', glb_data[4:8])[0]
total_length = struct.unpack('<I', glb_data[8:12])[0]
offset = 12
json_data = None
binary_data = None
while offset < len(glb_data):
if offset + 8 > len(glb_data):
break
chunk_length = struct.unpack('<I', glb_data[offset:offset+4])[0]
chunk_type = glb_data[offset+4:offset+8]
chunk_data = glb_data[offset+8:offset+8+chunk_length]
if chunk_type == b'JSON':
json_data = json.loads(chunk_data.decode('utf-8'))
elif chunk_type == b'BIN\x00':
binary_data = chunk_data
offset += 8 + chunk_length
offset = (offset + 3) & ~3
if json_data:
return self.extract_vertices_from_gltf(json_data, binary_data)
except Exception as e:
print(f"解析GLB数据失败: {e}")
return None
def extract_vertices_from_gltf(self, gltf_json, binary_data):
"""从glTF JSON中提取顶点数据"""
vertices = []
try:
if 'extensionsUsed' in gltf_json and 'KHR_draco_mesh_compression' in gltf_json['extensionsUsed']:
if DracoPy is None:
print("警告: 检测到Draco压缩但DracoPy未安装")
return vertices
return self.extract_draco_vertices(gltf_json, binary_data)
if 'meshes' not in gltf_json:
return vertices
for mesh in gltf_json['meshes']:
for primitive in mesh['primitives']:
if 'attributes' in primitive and 'POSITION' in primitive['attributes']:
position_accessor_index = primitive['attributes']['POSITION']
if 'accessors' in gltf_json and position_accessor_index < len(gltf_json['accessors']):
accessor = gltf_json['accessors'][position_accessor_index]
buffer_view_index = accessor['bufferView']
if 'bufferViews' in gltf_json and buffer_view_index < len(gltf_json['bufferViews']):
buffer_view = gltf_json['bufferViews'][buffer_view_index]
buffer_index = buffer_view['buffer']
byte_offset = buffer_view.get('byteOffset', 0) + accessor.get('byteOffset', 0)
if binary_data and buffer_index == 0:
component_type = accessor['componentType']
count = accessor['count']
if component_type == 5126: # FLOAT
vertex_data = struct.unpack(f'<{count*3}f',
binary_data[byte_offset:byte_offset+count*12])
for i in range(0, len(vertex_data), 3):
vertices.append([vertex_data[i], vertex_data[i+1], vertex_data[i+2]])
except Exception as e:
print(f"提取顶点数据失败: {e}")
return vertices
def extract_draco_vertices(self, gltf_json, binary_data):
"""提取Draco压缩的顶点数据"""
vertices = []
try:
if 'meshes' not in gltf_json:
return vertices
for mesh in gltf_json['meshes']:
for primitive in mesh['primitives']:
if 'extensions' in primitive and 'KHR_draco_mesh_compression' in primitive['extensions']:
draco_ext = primitive['extensions']['KHR_draco_mesh_compression']
buffer_view_index = draco_ext['bufferView']
if 'bufferViews' in gltf_json and buffer_view_index < len(gltf_json['bufferViews']):
buffer_view = gltf_json['bufferViews'][buffer_view_index]
byte_offset = buffer_view.get('byteOffset', 0)
byte_length = buffer_view['byteLength']
if binary_data:
draco_data = binary_data[byte_offset:byte_offset+byte_length]
mesh_data = DracoPy.decode(draco_data)
if hasattr(mesh_data, 'points'):
points = mesh_data.points
for point in points:
vertices.append([float(point[0]), float(point[1]), float(point[2])])
except Exception as e:
print(f"解码Draco数据失败: {e}")
return vertices
def calculate_triangle_angles(self, p1, p2, p3):
"""计算三角形的三个内角(度)"""
# 计算三边长度
a = np.linalg.norm(p2 - p3) # 边a对应角A(p1)
b = np.linalg.norm(p1 - p3) # 边b对应角B(p2)
c = np.linalg.norm(p1 - p2) # 边c对应角C(p3)
# 避免除零错误
if a == 0 or b == 0 or c == 0:
return [0, 0, 0]
# 使用余弦定理计算角度
try:
# 角A = arccos((b²+c²-a²)/(2bc))
cos_A = (b*b + c*c - a*a) / (2*b*c)
cos_B = (a*a + c*c - b*b) / (2*a*c)
cos_C = (a*a + b*b - c*c) / (2*a*b)
# 限制余弦值范围,避免数值误差
cos_A = np.clip(cos_A, -1.0, 1.0)
cos_B = np.clip(cos_B, -1.0, 1.0)
cos_C = np.clip(cos_C, -1.0, 1.0)
angle_A = np.arccos(cos_A) * 180 / np.pi
angle_B = np.arccos(cos_B) * 180 / np.pi
angle_C = np.arccos(cos_C) * 180 / np.pi
return [angle_A, angle_B, angle_C]
except:
return [0, 0, 0]
def is_valid_triangle(self, p1, p2, p3, min_angle=10.0, max_aspect_ratio=10.0):
"""验证三角形质量,基于角度约束和长宽比"""
angles = self.calculate_triangle_angles(p1, p2, p3)
# 检查最小角度约束参考C#代码中的10度限制
if min(angles) < min_angle:
return False
# 计算三边长度
a = np.linalg.norm(p2 - p3)
b = np.linalg.norm(p1 - p3)
c = np.linalg.norm(p1 - p2)
# 检查长宽比约束
max_side = max(a, b, c)
min_side = min(a, b, c)
if min_side > 0 and max_side / min_side > max_aspect_ratio:
return False
return True
def calculate_circumcenter_and_radius(self, p1, p2, p3):
"""计算三角形外接圆圆心和半径(高精度算法)"""
try:
x1, y1 = p1[0], p1[1]
x2, y2 = p2[0], p2[1]
x3, y3 = p3[0], p3[1]
# 使用C#代码中的高精度外接圆计算公式
d = 2 * (x1 * (y2 - y3) + x2 * (y3 - y1) + x3 * (y1 - y2))
if abs(d) < 1e-10: # 三点共线
return None, float('inf')
ux = ((x1*x1 + y1*y1) * (y2 - y3) + (x2*x2 + y2*y2) * (y3 - y1) + (x3*x3 + y3*y3) * (y1 - y2)) / d
uy = ((x1*x1 + y1*y1) * (x3 - x2) + (x2*x2 + y2*y2) * (x1 - y3) + (x3*x3 + y3*y3) * (x2 - x1)) / d
# 计算半径
radius = np.sqrt((ux - x1)**2 + (uy - y1)**2)
return np.array([ux, uy]), radius
except:
return None, float('inf')
def calculate_volume_delaunay(self, vertices, base_elevation=None, min_angle=10.0, use_quality_filter=True):
"""使用Delaunay三角剖分计算体积"""
if len(vertices) < 4:
print("顶点数量不足,无法进行三角剖分")
return 0.0
try:
points = np.array(vertices)
if base_elevation is None:
base_elevation = np.min(points[:, 2])
print(f"Delaunay方法使用基准面高程: {base_elevation:.2f}")
print(f"质量控制参数: 最小角度={min_angle}°, 质量过滤={'开启' if use_quality_filter else '关闭'}")
adjusted_points = points.copy()
adjusted_points[:, 2] = np.maximum(points[:, 2] - base_elevation, 0)
print("正在进行Delaunay三角剖分...")
tri = Delaunay(adjusted_points)
valid_simplices = []
total_simplices = len(tri.simplices)
if use_quality_filter:
print("正在进行三角形质量检查...")
for simplex in tqdm(tri.simplices, desc="质量检查", unit=""):
p0 = adjusted_points[simplex[0]]
p1 = adjusted_points[simplex[1]]
p2 = adjusted_points[simplex[2]]
p3 = adjusted_points[simplex[3]]
faces = [(p0, p1, p2), (p0, p1, p3), (p0, p2, p3), (p1, p2, p3)]
valid_faces = 0
for face in faces:
if self.is_valid_triangle(face[0], face[1], face[2], min_angle):
valid_faces += 1
if valid_faces >= 3:
valid_simplices.append(simplex)
print(f"质量过滤: {len(valid_simplices)}/{total_simplices} 个四面体通过质量检查")
else:
valid_simplices = tri.simplices
total_volume = 0.0
valid_volume_count = 0
print(f"计算 {len(valid_simplices)} 个四面体的体积...")
for simplex in tqdm(valid_simplices, desc="计算四面体体积", unit=""):
p0 = adjusted_points[simplex[0]]
p1 = adjusted_points[simplex[1]]
p2 = adjusted_points[simplex[2]]
p3 = adjusted_points[simplex[3]]
v1 = p1 - p0
v2 = p2 - p0
v3 = p3 - p0
det = np.linalg.det(np.array([v1, v2, v3]))
volume = abs(det) / 6.0
if volume > 1e-12:
total_volume += volume
valid_volume_count += 1
print(f"有效体积计算: {valid_volume_count}/{len(valid_simplices)} 个四面体")
return total_volume
except Exception as e:
print(f"Delaunay三角剖分计算体积失败: {e}")
return 0.0
def load_and_filter_vertices(self, tileset_path):
    """Load every vertex referenced by a 3D Tiles tileset and keep the in-bounds ones.

    Walks ``tileset_path`` for B3DM tiles, extracts their vertices, applies the
    per-tile transform matrix when present, appends all vertices to
    ``self.all_vertices`` and the in-bounds subset to ``self.filtered_vertices``.

    Returns:
        True when at least one vertex falls inside the configured bounds.
    """
    print(f"开始处理tileset: {tileset_path}")
    # Collect all B3DM tile files (with their transforms) from the tileset tree.
    b3dm_files = self.parse_tileset_json(tileset_path)
    print(f"找到 {len(b3dm_files)} 个B3DM文件")
    if not b3dm_files:
        print("未找到任何B3DM文件")
        return False
    # Process each B3DM file.
    processed_count = 0
    total_vertices = 0
    filtered_count = 0
    for i, (b3dm_file, transform_matrix) in enumerate(tqdm(b3dm_files, desc="处理B3DM文件", unit="文件")):
        # print(f"处理文件 {i+1}/{len(b3dm_files)}: {os.path.basename(b3dm_file)}")
        vertices = self.parse_b3dm_file(b3dm_file)
        if vertices:
            # Apply the tile transform so vertices end up in world coordinates.
            if transform_matrix:
                vertices = self.apply_transform_to_vertices(vertices, transform_matrix)
            # Keep only the vertices that fall inside the boundary.
            for vertex in vertices:
                total_vertices += 1
                if self.is_point_in_bounds(vertex):
                    self.filtered_vertices.append(vertex)
                    filtered_count += 1
            self.all_vertices.extend(vertices)
            processed_count += 1
            # tqdm.write(f" {os.path.basename(b3dm_file)}: 提取到 {len(vertices)} 个顶点")
    print(f"\n成功处理 {processed_count}/{len(b3dm_files)} 个文件")
    print(f"总计提取 {total_vertices} 个顶点")
    print(f"范围内顶点 {filtered_count}")
    return len(self.filtered_vertices) > 0
def calculate_volume(self, tileset_path, base_elevation=None, min_angle=10.0, use_quality_filter=True):
    """Compute the volume of the 3-D model inside the configured boundary.

    Args:
        tileset_path: path to the tileset.json file.
        base_elevation: reference plane height in meters; defaults to the
            lowest filtered vertex.
        min_angle: minimum triangle angle (degrees) for the quality filter.
        use_quality_filter: whether to drop poorly shaped tetrahedra.

    Returns:
        Volume in cubic meters; 0.0 when no vertex falls inside the bounds.
    """
    if not self.load_and_filter_vertices(tileset_path):
        print("未找到范围内的顶点数据")
        return 0.0
    print(f"\n开始计算体积使用Delaunay三角剖分方法")
    print(f"参与计算的顶点数: {len(self.filtered_vertices)}")
    points = np.array(self.filtered_vertices)
    if base_elevation is None:
        base_elevation = np.min(points[:, 2])
    print(f"统一基准面高程: {base_elevation:.2f}")
    volume = self.calculate_volume_delaunay(self.filtered_vertices, base_elevation, min_angle, use_quality_filter)
    print(f"\n计算结果:")
    print(f"体积: {volume:.6f} 立方米")
    # BUG FIX: 1 km³ = 1e9 m³ — the previous code divided by 1e6, overstating
    # the cubic-kilometre figure by a factor of 1000.
    print(f"体积: {volume/1e9:.6f} 立方千米")
    return volume
def main():
    """CLI entry point: parse arguments, validate inputs, run the calculation."""
    argv = sys.argv
    if len(argv) < 3:
        print("用法: python volume_calculator.py <lboundary.json路径> <tileset.json路径> [基准面高程] [最小角度] [质量过滤]")
        print("基准面高程: 基准面高程(米),默认使用最低点")
        print("最小角度: 三角形最小角度约束(度)默认10.0")
        print("质量过滤: 是否启用质量过滤(true/false)默认true")
        print("示例: python volume_calculator.py boundary.json tileset.json 100.0 15.0 true")
        return
    location_file, tileset_path = argv[1], argv[2]
    # Optional arguments with their defaults.
    base_elevation = None
    min_angle = 10.0
    use_quality_filter = True
    if len(argv) > 3:
        try:
            base_elevation = float(argv[3])
        except ValueError:
            print("错误:基准面高程必须是数字")
            return
    if len(argv) > 4:
        try:
            min_angle = float(argv[4])
        except ValueError:
            print("错误:最小角度必须是数字")
            return
    if len(argv) > 5:
        use_quality_filter = argv[5].lower() in ['true', '1', 'yes', 'on']
    if not os.path.exists(location_file):
        print(f"错误: 位置文件不存在 {location_file}")
        return
    if not os.path.exists(tileset_path):
        print(f"错误: tileset文件不存在 {tileset_path}")
        return
    calculator = VolumeCalculator(location_file)
    volume = calculator.calculate_volume(tileset_path, base_elevation, min_angle, use_quality_filter)
    print("\n体积计算完成!" if volume > 0 else "\n体积计算失败!")
if __name__ == "__main__":
    main()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.2 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 5.3 MiB

View File

@ -1,8 +0,0 @@
# Ad-hoc manual test: download a large MP4 from MinIO and print where it landed.
import os.path
from middleware.minio_util import downBigFile
# Object key of the video inside the MinIO bucket.
miniourl=r"media/22d45cc5-0ba7-4bc3-a302-ca1a28c40fd2/DJI_202509121519_001_22d45cc5-0ba7-4bc3-a302-ca1a28c40fd2/DJI_20250912152112_0001_V.mp4"
file_path=downBigFile(miniourl)
print(f"os.path.abspath(file_path) {os.path.abspath(file_path)}")

View File

@ -94,7 +94,7 @@ def func_100000(results, cls_id_list, type_name_list, func_id_10001, list_track_
trickier_detail = { trickier_detail = {
# "track_id": results.track_ids[i], # "track_id": results.track_ids[i],
"confidence": results.confs[i], "confidence": results.confs[i],
"cls_id": i, "cls_id": ind,
"type_name": type_name_list[ind], "type_name": type_name_list[ind],
"box": boxes[i] "box": boxes[i]
} }

View File

@ -0,0 +1,76 @@
import asyncio
import grpc
from grpc_util.grpc_proto_demo.check_grpc import check_grpc_pb2_grpc, check_grpc_pb2
async def async_check_server_status(channel):
    """Return True when the health check reports the TaskService as SERVING."""
    try:
        probe = check_grpc_pb2_grpc.HealthCheckStub(channel)
        # Check() is a plain synchronous unary call, so no await is needed.
        reply = probe.Check(check_grpc_pb2.HealthCheckRequest(service="TaskService"))
        return reply.status == check_grpc_pb2.HealthCheckResponse.ServingStatus.SERVING
    except grpc.RpcError as e:
        print(f"Health check failed: {e}")
        return False
async def async_check_grpc_request(max_retries=3, delay=5, target='localhost:50051'):
    """Send a demo TaskRequest to the gRPC TaskService, retrying on failure.

    Args:
        max_retries: number of attempts before giving up.
        delay: seconds to wait (asynchronously) between attempts.
        target: gRPC server address (host:port); previously hard-coded, now a
            backward-compatible parameter.

    Returns:
        True when the request succeeded, False after all retries failed.
    """
    retries = 0
    channel = None
    while retries < max_retries:
        try:
            channel = grpc.insecure_channel(target)
            if not await async_check_server_status(channel):
                raise Exception("Server is not healthy")
            stub = check_grpc_pb2_grpc.TaskServiceStub(channel)
            # Hard-coded demo payload.
            request = check_grpc_pb2.TaskRequest(
                task_id="d6118954-a170-4e1c-84bd-ddbd3114b354",
                sn="8UUXN6S00A0CK7",
                content_body=check_grpc_pb2.ContentBody(
                    org_code="HMZHB",
                    func_id=[101204],
                    source_url="xxxxxxxxxx",
                    push_url="",
                    confidence=0.4,
                    para_list=[
                        check_grpc_pb2.ParaList(
                            func_id=101204,
                            para_invade_enable=True
                        )
                    ],
                    invade=check_grpc_pb2.Invade(
                        invade_file="meta_data/高压线-0826.geojson",
                        camera_para_url="meta_data/camera_para/hami_camera_para .txt"
                    )
                )
            )
            # ProcessTask is a synchronous stub call, so no await here either.
            response = stub.ProcessTask(request)
            print(f"Response: task_id={response.task_id}, success={response.success}, message={response.message}")
            if channel:
                channel.close()  # Synchronous close; no await required.
            return True
        except grpc.RpcError as e:
            retries += 1
            print(f"RPC error occurred (attempt {retries}/{max_retries}): {e}")
        except Exception as e:
            print(f"Error occurred: {e}")
            retries += 1
        # Always release the channel before retrying.
        if channel:
            channel.close()
            channel = None
        # BUG FIX: only back off when another attempt remains; the old code
        # also slept after the final failure, delaying the error return.
        if retries < max_retries:
            await asyncio.sleep(delay)
    print("All retry attempts failed")
    return False
if __name__ == '__main__':
    asyncio.run(async_check_grpc_request())

View File

@ -0,0 +1,58 @@
// Demo task-dispatch service used by the check_grpc client/server examples.
syntax = "proto3";
package task;
// Main RPC: submit one task for processing.
service TaskService {
  rpc ProcessTask (TaskRequest) returns (TaskResponse);
}
// Health-check service polled by clients before sending tasks.
service HealthCheck {
  rpc Check (HealthCheckRequest) returns (HealthCheckResponse);
}
message HealthCheckRequest {
  string service = 1;  // name of the service being probed
}
message HealthCheckResponse {
  enum ServingStatus {
    UNKNOWN = 0;
    SERVING = 1;
    NOT_SERVING = 2;
    SERVICE_UNKNOWN = 3;
  }
  ServingStatus status = 1;
}
message TaskRequest {
  string task_id = 1;
  string sn = 2;  // device serial number
  ContentBody content_body = 3;
}
message ContentBody {
  string org_code = 1;
  repeated int32 func_id = 2;       // algorithm/function ids to run
  string source_url = 3;            // input stream/file location
  string push_url = 4;              // output push address (may be empty)
  float confidence = 5;             // detection confidence threshold
  repeated ParaList para_list = 6;  // per-function parameter overrides
  Invade invade = 7;
}
message ParaList {
  int32 func_id = 1;
  bool para_invade_enable = 2;
}
message Invade {
  string invade_file = 1;      // geojson describing the restricted area
  string camera_para_url = 2;  // camera parameter file
}
message TaskResponse {
  string task_id = 1;
  bool success = 2;
  string message = 3;
}

View File

@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: check_grpc.proto
# Protobuf Python Version: 6.31.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
6,
31,
1,
'',
'check_grpc.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63heck_grpc.proto\x12\x04task\"%\n\x12HealthCheckRequest\x12\x0f\n\x07service\x18\x01 \x01(\t\"\x9f\x01\n\x13HealthCheckResponse\x12\x37\n\x06status\x18\x01 \x01(\x0e\x32\'.task.HealthCheckResponse.ServingStatus\"O\n\rServingStatus\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x0b\n\x07SERVING\x10\x01\x12\x0f\n\x0bNOT_SERVING\x10\x02\x12\x13\n\x0fSERVICE_UNKNOWN\x10\x03\"S\n\x0bTaskRequest\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\n\n\x02sn\x18\x02 \x01(\t\x12\'\n\x0c\x63ontent_body\x18\x03 \x01(\x0b\x32\x11.task.ContentBody\"\xab\x01\n\x0b\x43ontentBody\x12\x10\n\x08org_code\x18\x01 \x01(\t\x12\x0f\n\x07\x66unc_id\x18\x02 \x03(\x05\x12\x12\n\nsource_url\x18\x03 \x01(\t\x12\x10\n\x08push_url\x18\x04 \x01(\t\x12\x12\n\nconfidence\x18\x05 \x01(\x02\x12!\n\tpara_list\x18\x06 \x03(\x0b\x32\x0e.task.ParaList\x12\x1c\n\x06invade\x18\x07 \x01(\x0b\x32\x0c.task.Invade\"7\n\x08ParaList\x12\x0f\n\x07\x66unc_id\x18\x01 \x01(\x05\x12\x1a\n\x12para_invade_enable\x18\x02 \x01(\x08\"6\n\x06Invade\x12\x13\n\x0binvade_file\x18\x01 \x01(\t\x12\x17\n\x0f\x63\x61mera_para_url\x18\x02 \x01(\t\"A\n\x0cTaskResponse\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x0f\n\x07message\x18\x03 \x01(\t2C\n\x0bTaskService\x12\x34\n\x0bProcessTask\x12\x11.task.TaskRequest\x1a\x12.task.TaskResponse2K\n\x0bHealthCheck\x12<\n\x05\x43heck\x12\x18.task.HealthCheckRequest\x1a\x19.task.HealthCheckResponseb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'check_grpc_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_HEALTHCHECKREQUEST']._serialized_start=26
_globals['_HEALTHCHECKREQUEST']._serialized_end=63
_globals['_HEALTHCHECKRESPONSE']._serialized_start=66
_globals['_HEALTHCHECKRESPONSE']._serialized_end=225
_globals['_HEALTHCHECKRESPONSE_SERVINGSTATUS']._serialized_start=146
_globals['_HEALTHCHECKRESPONSE_SERVINGSTATUS']._serialized_end=225
_globals['_TASKREQUEST']._serialized_start=227
_globals['_TASKREQUEST']._serialized_end=310
_globals['_CONTENTBODY']._serialized_start=313
_globals['_CONTENTBODY']._serialized_end=484
_globals['_PARALIST']._serialized_start=486
_globals['_PARALIST']._serialized_end=541
_globals['_INVADE']._serialized_start=543
_globals['_INVADE']._serialized_end=597
_globals['_TASKRESPONSE']._serialized_start=599
_globals['_TASKRESPONSE']._serialized_end=664
_globals['_TASKSERVICE']._serialized_start=666
_globals['_TASKSERVICE']._serialized_end=733
_globals['_HEALTHCHECK']._serialized_start=735
_globals['_HEALTHCHECK']._serialized_end=810
# @@protoc_insertion_point(module_scope)

View File

@ -0,0 +1,172 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings
import grpc_util.grpc_proto_demo.check_grpc.check_grpc_pb2 as check__grpc__pb2
GRPC_GENERATED_VERSION = '1.76.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False
try:
from grpc._utilities import first_version_is_lower
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
_version_not_supported = True
if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ ' but the generated code in check_grpc_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
)
class TaskServiceStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.ProcessTask = channel.unary_unary(
'/task.TaskService/ProcessTask',
request_serializer=check__grpc__pb2.TaskRequest.SerializeToString,
response_deserializer=check__grpc__pb2.TaskResponse.FromString,
_registered_method=True)
class TaskServiceServicer(object):
"""Missing associated documentation comment in .proto file."""
def ProcessTask(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_TaskServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'ProcessTask': grpc.unary_unary_rpc_method_handler(
servicer.ProcessTask,
request_deserializer=check__grpc__pb2.TaskRequest.FromString,
response_serializer=check__grpc__pb2.TaskResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'task.TaskService', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('task.TaskService', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class TaskService(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def ProcessTask(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/task.TaskService/ProcessTask',
check__grpc__pb2.TaskRequest.SerializeToString,
check__grpc__pb2.TaskResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
class HealthCheckStub(object):
"""添加健康检查服务
"""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.Check = channel.unary_unary(
'/task.HealthCheck/Check',
request_serializer=check__grpc__pb2.HealthCheckRequest.SerializeToString,
response_deserializer=check__grpc__pb2.HealthCheckResponse.FromString,
_registered_method=True)
class HealthCheckServicer(object):
"""添加健康检查服务
"""
def Check(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_HealthCheckServicer_to_server(servicer, server):
rpc_method_handlers = {
'Check': grpc.unary_unary_rpc_method_handler(
servicer.Check,
request_deserializer=check__grpc__pb2.HealthCheckRequest.FromString,
response_serializer=check__grpc__pb2.HealthCheckResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'task.HealthCheck', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('task.HealthCheck', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class HealthCheck(object):
"""添加健康检查服务
"""
@staticmethod
def Check(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/task.HealthCheck/Check',
check__grpc__pb2.HealthCheckRequest.SerializeToString,
check__grpc__pb2.HealthCheckResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

View File

@ -0,0 +1,83 @@
import grpc
import time
from grpc_util.grpc_proto_demo.check_grpc import check_grpc_pb2_grpc, check_grpc_pb2
def check_server_status(channel):
    """Return True when the health check reports the TaskService as SERVING."""
    try:
        probe = check_grpc_pb2_grpc.HealthCheckStub(channel)
        reply = probe.Check(check_grpc_pb2.HealthCheckRequest(service="TaskService"))
        return reply.status == check_grpc_pb2.HealthCheckResponse.ServingStatus.SERVING
    except grpc.RpcError as e:
        print(f"Health check failed: {e}")
        return False
def check_grpc_request(max_retries=3, delay=5):
    """Send one demo TaskRequest to the local TaskService, retrying on failure.

    Args:
        max_retries: total number of attempts before giving up.
        delay: seconds to sleep between attempts.

    Returns:
        True on a successful RPC, False when every attempt failed.
    """
    channel = None
    retries = 0
    while retries < max_retries:
        try:
            # Create the channel.
            channel = grpc.insecure_channel('localhost:50051')
            # Verify the server reports SERVING before sending the task.
            if not check_server_status(channel):
                raise Exception("Server is not healthy")
            stub = check_grpc_pb2_grpc.TaskServiceStub(channel)
            # Build the request message (hard-coded demo payload).
            request = check_grpc_pb2.TaskRequest(
                task_id="d6118954-a170-4e1c-84bd-ddbd3114b354",
                sn="8UUXN6S00A0CK7",
                content_body=check_grpc_pb2.ContentBody(
                    org_code="HMZHB",
                    func_id=[101204],
                    source_url="xxxxxxxxxx",
                    push_url="",
                    confidence=0.4,
                    para_list=[
                        check_grpc_pb2.ParaList(
                            func_id=101204,
                            para_invade_enable=True
                        )
                    ],
                    invade=check_grpc_pb2.Invade(
                        invade_file="meta_data/高压线-0826.geojson",
                        camera_para_url="meta_data/camera_para/hami_camera_para .txt"
                    )
                )
            )
            # Invoke the remote method.
            response = stub.ProcessTask(request)
            print(f"Response: task_id={response.task_id}, success={response.success}, message={response.message}")
            return True
        except grpc.RpcError as e:
            retries += 1
            print(f"RPC error occurred (attempt {retries}/{max_retries}): {e}")
            if retries < max_retries:
                print(f"Retrying in {delay} seconds...")
                time.sleep(delay)
        except Exception as e:
            print(f"Error occurred: {e}")
            retries += 1
            if retries < max_retries:
                print(f"Retrying in {delay} seconds...")
                time.sleep(delay)
        finally:
            # NOTE(review): the channel is closed at the end of EVERY loop
            # iteration (and before the success return), so each retry builds
            # a fresh channel.
            if channel:
                channel.close()
    print("All retry attempts failed")
    return False
if __name__ == '__main__':
    check_grpc_request()

View File

@ -0,0 +1,55 @@
from concurrent import futures
import grpc
from grpc_util.grpc_proto_demo.check_grpc import check_grpc_pb2, check_grpc_pb2_grpc
class TaskServiceServicer(check_grpc_pb2_grpc.TaskServiceServicer):
    """Demo implementation that acknowledges every task as processed."""
    def ProcessTask(self, request, context):
        print(f"Received task_id: {request.task_id}")
        reply = check_grpc_pb2.TaskResponse(
            task_id=request.task_id,
            success=True,
            message="Task processed successfully",
        )
        return reply
class HealthCheckServicer(check_grpc_pb2_grpc.HealthCheckServicer):
    """Health probe that unconditionally reports SERVING.

    A production implementation would inspect real service state and return
    different statuses.
    """
    def Check(self, request, context):
        serving = check_grpc_pb2.HealthCheckResponse.ServingStatus.SERVING
        return check_grpc_pb2.HealthCheckResponse(status=serving)
def serve():
    """Start the demo gRPC server on port 50051 and block until shutdown.

    Registers the TaskService and HealthCheck implementations, installs
    SIGINT/SIGTERM handlers for graceful shutdown, then blocks on
    wait_for_termination().
    """
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    # Register both service implementations.
    check_grpc_pb2_grpc.add_TaskServiceServicer_to_server(TaskServiceServicer(), server)
    check_grpc_pb2_grpc.add_HealthCheckServicer_to_server(HealthCheckServicer(), server)
    server.add_insecure_port('[::]:50051')
    server.start()
    print("Server started, listening on port 50051...")
    # Graceful shutdown on signals (the old commented-out sleep loop was dead
    # code and has been removed).
    import signal
    def shutdown_handler(signum, frame):
        print(f"Received signal {signum}, shutting down...")
        server.stop(0)
    signal.signal(signal.SIGINT, shutdown_handler)
    signal.signal(signal.SIGTERM, shutdown_handler)
    server.wait_for_termination()
if __name__ == '__main__':
    serve()

View File

@ -0,0 +1,8 @@
1、参考check_grpc.proto 手写 proto 相关结构
2、使用如下编译命令生成grpc通讯相关代码:
python -m grpc_tools.protoc --proto_path=proto_dir \
--python_out=gen_dir \
--grpc_python_out=gen_dir \
proto_dir/check_grpc.proto
只会生成request、response相关代码
3、拷贝当前文件下的check_grpc_client.py 、check_grpc_server.py重写逻辑代码

View File

View File

@ -0,0 +1,89 @@
import asyncio
import grpc
# from grpc_util.grpc_sam3 import grpc_sam3_img_pb2_grpc, grpc_sam3_img_pb2
# from grpc_util.grpc_sam3.grpc_sam3_img_pb2_grpc import TaskServiceStub, HealthCheckStub
# from grpc_util.grpc_sam3.grpc_sam3_img_pb2 import TaskRequest, TaskResponse, HealthCheckRequest, HealthCheckResponse
# from . import grpc_sam3_img_pb2_grpc, grpc_sam3_img_pb2
from grpc_util.grpc_sam3 import grpc_sam3_img_pb2_grpc, grpc_sam3_img_pb2
async def async_check_server_status(channel):
    """Return True when the SAM3 health check reports the TaskService as SERVING."""
    try:
        probe = grpc_sam3_img_pb2_grpc.HealthCheckStub(channel)
        # Check() is a plain synchronous unary call, so no await is needed.
        reply = probe.Check(grpc_sam3_img_pb2.HealthCheckRequest(service="TaskService"))
        return reply.status == grpc_sam3_img_pb2.HealthCheckResponse.ServingStatus.SERVING
    except grpc.RpcError as e:
        print(f"Health check failed: {e}")
        return False
async def grpc_sam3_pic_predict(
    task_id:str,
    sn:str,
    img_url:str,
    prompt:str,
    confidence:float,
    mqtt_ip:str,
    mqtt_port:int,
    mqtt_topic:str,
    max_retries=3, delay=5):
    """Submit one SAM3 image-prediction task over gRPC, retrying on failure.

    Args:
        task_id: caller-assigned task identifier.
        sn: device serial number.
        img_url: input image location.
        prompt: text prompt for SAM3.
        confidence: detection confidence threshold to apply.
        mqtt_ip / mqtt_port / mqtt_topic: MQTT endpoint results are pushed to.
        max_retries: number of attempts before giving up.
        delay: seconds to wait (asynchronously) between attempts.

    Returns:
        True when the request succeeded, False after all retries failed.
    """
    retries = 0
    channel = None
    while retries < max_retries:
        try:
            channel = grpc.insecure_channel('0.0.0.0:50051')
            if not await async_check_server_status(channel):
                raise Exception("Server is not healthy")
            stub = grpc_sam3_img_pb2_grpc.TaskServiceStub(channel)
            request = grpc_sam3_img_pb2.TaskRequest(
                task_id=task_id,
                sn=sn,
                content_body=grpc_sam3_img_pb2.ContentBody(
                    img_url=img_url,
                    prompt=prompt,
                    # BUG FIX: was hard-coded to 0.5, silently ignoring the
                    # caller-supplied confidence threshold.
                    confidence=confidence,
                    mqtt_ip=mqtt_ip,
                    mqtt_port=mqtt_port,
                    mqtt_topic=mqtt_topic
                )
            )
            # ProcessTask is a synchronous stub call, so no await is needed.
            response = stub.ProcessTask(request)
            print(f"Response: task_id={response.task_id}, success={response.success}, message={response.message}")
            if channel:
                channel.close()  # Synchronous close; no await required.
            return True
        except grpc.RpcError as e:
            retries += 1
            print(f"RPC error occurred (attempt {retries}/{max_retries}): {e}")
        except Exception as e:
            print(f"Error occurred: {e}")
            retries += 1
        # Release the channel before retrying.
        if channel:
            channel.close()
            channel = None
        await asyncio.sleep(delay)  # Asynchronous back-off between attempts.
    print("All retry attempts failed")
    return False
if __name__ == '__main__':
    asyncio.run(grpc_sam3_pic_predict(
        task_id="d6118954-a170-4e1c-84bd-ddbd3114b354111",
        sn="8UUXN6S00A0CK7",
        img_url="demo/03.png",
        prompt="cat",
        confidence=0.5,
        mqtt_ip="47.108.62.6",
        mqtt_port=12503,
        mqtt_topic="thing/product/ai/events",
    ))

View File

@ -0,0 +1,48 @@
// SAM3 image-prediction service: submit an image plus a text prompt; the
// prediction result is published back over MQTT.
syntax = "proto3";
package grpc_sam3_img;
// Main RPC: submit one image-prediction task.
service TaskService {
  rpc ProcessTask (TaskRequest) returns (TaskResponse);
}
// Health-check service polled by clients before sending tasks.
service HealthCheck {
  rpc Check (HealthCheckRequest) returns (HealthCheckResponse);
}
message HealthCheckRequest {
  string service = 1;  // name of the service being probed
}
message HealthCheckResponse {
  enum ServingStatus {
    UNKNOWN = 0;
    SERVING = 1;
    NOT_SERVING = 2;
    SERVICE_UNKNOWN = 3;
  }
  ServingStatus status = 1;
}
message TaskRequest {
  string task_id = 1;
  string sn = 2;  // device serial number
  ContentBody content_body = 3;
}
message ContentBody {
  string img_url = 1;   // input image location
  string prompt = 2;    // text prompt for SAM3
  float confidence =3;  // detection confidence threshold
  string mqtt_ip = 4;   // MQTT broker the result is published to
  int32 mqtt_port = 5;
  string mqtt_topic = 6;
}
message TaskResponse {
  string task_id = 1;
  bool success = 2;
  string message = 3;
}

View File

@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: grpc_sam3_img.proto
# Protobuf Python Version: 6.31.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
6,
31,
1,
'',
'grpc_sam3_img.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13grpc_sam3_img.proto\x12\rgrpc_sam3_img\"%\n\x12HealthCheckRequest\x12\x0f\n\x07service\x18\x01 \x01(\t\"\xa8\x01\n\x13HealthCheckResponse\x12@\n\x06status\x18\x01 \x01(\x0e\x32\x30.grpc_sam3_img.HealthCheckResponse.ServingStatus\"O\n\rServingStatus\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x0b\n\x07SERVING\x10\x01\x12\x0f\n\x0bNOT_SERVING\x10\x02\x12\x13\n\x0fSERVICE_UNKNOWN\x10\x03\"\\\n\x0bTaskRequest\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\n\n\x02sn\x18\x02 \x01(\t\x12\x30\n\x0c\x63ontent_body\x18\x03 \x01(\x0b\x32\x1a.grpc_sam3_img.ContentBody\"z\n\x0b\x43ontentBody\x12\x0f\n\x07img_url\x18\x01 \x01(\t\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x12\n\nconfidence\x18\x03 \x01(\x02\x12\x0f\n\x07mqtt_ip\x18\x04 \x01(\t\x12\x11\n\tmqtt_port\x18\x05 \x01(\x05\x12\x12\n\nmqtt_topic\x18\x06 \x01(\t\"A\n\x0cTaskResponse\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x0f\n\x07message\x18\x03 \x01(\t2U\n\x0bTaskService\x12\x46\n\x0bProcessTask\x12\x1a.grpc_sam3_img.TaskRequest\x1a\x1b.grpc_sam3_img.TaskResponse2]\n\x0bHealthCheck\x12N\n\x05\x43heck\x12!.grpc_sam3_img.HealthCheckRequest\x1a\".grpc_sam3_img.HealthCheckResponseb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'grpc_sam3_img_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_HEALTHCHECKREQUEST']._serialized_start=38
_globals['_HEALTHCHECKREQUEST']._serialized_end=75
_globals['_HEALTHCHECKRESPONSE']._serialized_start=78
_globals['_HEALTHCHECKRESPONSE']._serialized_end=246
_globals['_HEALTHCHECKRESPONSE_SERVINGSTATUS']._serialized_start=167
_globals['_HEALTHCHECKRESPONSE_SERVINGSTATUS']._serialized_end=246
_globals['_TASKREQUEST']._serialized_start=248
_globals['_TASKREQUEST']._serialized_end=340
_globals['_CONTENTBODY']._serialized_start=342
_globals['_CONTENTBODY']._serialized_end=464
_globals['_TASKRESPONSE']._serialized_start=466
_globals['_TASKRESPONSE']._serialized_end=531
_globals['_TASKSERVICE']._serialized_start=533
_globals['_TASKSERVICE']._serialized_end=618
_globals['_HEALTHCHECK']._serialized_start=620
_globals['_HEALTHCHECK']._serialized_end=713
# @@protoc_insertion_point(module_scope)

View File

@ -0,0 +1,172 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings
import grpc_util.grpc_sam3.grpc_sam3_img_pb2 as grpc__sam3__img__pb2
GRPC_GENERATED_VERSION = '1.76.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False
try:
from grpc._utilities import first_version_is_lower
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
_version_not_supported = True
if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ ' but the generated code in grpc_sam3_img_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
)
class TaskServiceStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.ProcessTask = channel.unary_unary(
'/grpc_sam3_img.TaskService/ProcessTask',
request_serializer=grpc__sam3__img__pb2.TaskRequest.SerializeToString,
response_deserializer=grpc__sam3__img__pb2.TaskResponse.FromString,
_registered_method=True)
class TaskServiceServicer(object):
"""Missing associated documentation comment in .proto file."""
def ProcessTask(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_TaskServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'ProcessTask': grpc.unary_unary_rpc_method_handler(
servicer.ProcessTask,
request_deserializer=grpc__sam3__img__pb2.TaskRequest.FromString,
response_serializer=grpc__sam3__img__pb2.TaskResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'grpc_sam3_img.TaskService', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('grpc_sam3_img.TaskService', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class TaskService(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def ProcessTask(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/grpc_sam3_img.TaskService/ProcessTask',
grpc__sam3__img__pb2.TaskRequest.SerializeToString,
grpc__sam3__img__pb2.TaskResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
class HealthCheckStub(object):
"""添加健康检查服务
"""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.Check = channel.unary_unary(
'/grpc_sam3_img.HealthCheck/Check',
request_serializer=grpc__sam3__img__pb2.HealthCheckRequest.SerializeToString,
response_deserializer=grpc__sam3__img__pb2.HealthCheckResponse.FromString,
_registered_method=True)
class HealthCheckServicer(object):
"""添加健康检查服务
"""
def Check(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_HealthCheckServicer_to_server(servicer, server):
rpc_method_handlers = {
'Check': grpc.unary_unary_rpc_method_handler(
servicer.Check,
request_deserializer=grpc__sam3__img__pb2.HealthCheckRequest.FromString,
response_serializer=grpc__sam3__img__pb2.HealthCheckResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'grpc_sam3_img.HealthCheck', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('grpc_sam3_img.HealthCheck', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class HealthCheck(object):
"""添加健康检查服务
"""
@staticmethod
def Check(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/grpc_sam3_img.HealthCheck/Check',
grpc__sam3__img__pb2.HealthCheckRequest.SerializeToString,
grpc__sam3__img__pb2.HealthCheckResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

View File

@ -0,0 +1,72 @@
import grpc
import time
from grpc_util.grpc_sam3 import grpc_sam3_img_pb2_grpc, grpc_sam3_img_pb2
def check_server_status(channel):
    """Return True when the SAM3 health check reports the TaskService as SERVING."""
    try:
        probe = grpc_sam3_img_pb2_grpc.HealthCheckStub(channel)
        reply = probe.Check(grpc_sam3_img_pb2.HealthCheckRequest(service="TaskService"))
        return reply.status == grpc_sam3_img_pb2.HealthCheckResponse.ServingStatus.SERVING
    except grpc.RpcError as e:
        print(f"Health check failed: {e}")
        return False
def check_grpc_request(max_retries=3, delay=5):
    """Send one demo SAM3 image-prediction request, retrying on failure.

    Args:
        max_retries: total number of attempts before giving up.
        delay: seconds to sleep between attempts.

    Returns:
        True on a successful RPC, False when every attempt failed.
    """
    channel = None
    retries = 0
    while retries < max_retries:
        try:
            # Create the channel.
            channel = grpc.insecure_channel('192.168.110.187:9999')
            # Verify the server reports SERVING before sending work.
            if not check_server_status(channel):
                raise Exception("Server is not healthy")
            stub = grpc_sam3_img_pb2_grpc.TaskServiceStub(channel)
            # Build the request message (hard-coded demo payload).
            request = grpc_sam3_img_pb2.TaskRequest(
                task_id="d6118954-a170-4e1c-84bd-ddbd3114b354111",
                sn="8UUXN6S00A0CK7",
                content_body=grpc_sam3_img_pb2.ContentBody(
                    img_url="demo/03.png",
                    prompt="cat",
                    confidence=0.5,
                    mqtt_ip="47.108.62.6",
                    mqtt_port=12503,
                    mqtt_topic="thing/product/ai/events"
                )
            )
            # Invoke the remote method.
            response = stub.ProcessTask(request)
            print(f"Response: task_id={response.task_id}, success={response.success}, message={response.message}")
            return True
        except grpc.RpcError as e:
            retries += 1
            print(f"RPC error occurred (attempt {retries}/{max_retries}): {e}")
            if retries < max_retries:
                print(f"Retrying in {delay} seconds...")
                time.sleep(delay)
        except Exception as e:
            print(f"Error occurred: {e}")
            retries += 1
            if retries < max_retries:
                print(f"Retrying in {delay} seconds...")
                time.sleep(delay)
        finally:
            # NOTE(review): the channel is closed at the end of EVERY loop
            # iteration (and before the success return), so each retry builds
            # a fresh channel.
            if channel:
                channel.close()
    print("All retry attempts failed")
    return False
if __name__ == '__main__':
    check_grpc_request()

View File

@ -0,0 +1,369 @@
import json
from concurrent import futures
import grpc
import threading
import queue
import time
import logging
from typing import Dict, Optional
from grpc_util.grpc_sam3 import grpc_sam3_img_pb2_grpc, grpc_sam3_img_pb2
from middleware.MQTTService import MQTTService
from middleware.minio_util import downFile, upload_file
import sys
from middleware.util import get_current_date_and_milliseconds
print(sys.executable)
import os
import matplotlib.pyplot as plt
import numpy as np
import sam3.sam3
from PIL import Image
from sam3.sam3 import build_sam3_image_model
from sam3.sam3 import build_sam3_image_model_0228
from sam3.sam3.model.box_ops import box_xywh_to_cxcywh
from sam3.sam3.model.sam3_image_processor import Sam3Processor
from sam3.sam3.visualization_utils import draw_box_on_image, normalize_bbox, plot_results,plot_results_savepic
sam3_root = os.path.join(os.path.dirname(sam3.sam3.__file__), "..")
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TaskQueue:
    """Thread-safe FIFO task queue with per-task status bookkeeping.

    Tasks flow pending -> processing -> completed/failed; terminal entries
    get a ``completed_time`` stamp so ``cleanup_old_tasks`` can
    garbage-collect them later.
    """

    def __init__(self, max_size: int = 1000):
        self.queue = queue.Queue(maxsize=max_size)
        self.task_status: Dict[str, dict] = {}   # task_id -> task record
        self.lock = threading.Lock()             # guards task_status
        self.stop_event = threading.Event()      # shared stop flag for workers
        self._logger = logging.getLogger(__name__)

    def add_task(self, task_id: str, request_data: dict) -> bool:
        """Enqueue a task without blocking; returns False when the queue is full."""
        try:
            task_item = {
                'task_id': task_id,
                'data': request_data,
                'timestamp': time.time(),
                'status': 'pending'
            }
            with self.lock:
                self.task_status[task_id] = task_item
            # Non-blocking put so callers (RPC handlers) never stall here.
            self.queue.put(task_item, block=False)
            self._logger.info(f"任务 {task_id} 已添加到队列,当前队列大小: {self.queue.qsize()}")
            return True
        except queue.Full:
            # Bug fix: drop the status entry of a rejected task; otherwise it
            # stays 'pending' forever and is never garbage-collected.
            with self.lock:
                self.task_status.pop(task_id, None)
            self._logger.warning(f"队列已满,任务 {task_id} 被拒绝")
            return False
        except Exception as e:
            self._logger.error(f"添加任务失败: {e}")
            return False

    def get_task(self, timeout: float = 1.0) -> Optional[dict]:
        """Pop the next task, or return None after `timeout` seconds."""
        try:
            return self.queue.get(timeout=timeout)
        except queue.Empty:
            return None

    def update_task_status(self, task_id: str, status: str, result: dict = None):
        """Update a task's status; terminal states get a completion timestamp."""
        with self.lock:
            entry = self.task_status.get(task_id)
            if entry is None:
                return
            entry['status'] = status
            if result:
                entry['result'] = result
            # Bug fix: previously the timestamp was written only when a result
            # dict was supplied, so result-less terminal tasks were never
            # eligible for cleanup_old_tasks().
            if status in ('completed', 'failed'):
                entry['completed_time'] = time.time()

    def get_task_status(self, task_id: str) -> Optional[dict]:
        """Return the task record, or None if the task is unknown."""
        with self.lock:
            return self.task_status.get(task_id)

    def cleanup_old_tasks(self, max_age_seconds: int = 3600):
        """Drop records of tasks that finished more than max_age_seconds ago."""
        with self.lock:
            current_time = time.time()
            to_delete = [
                task_id
                for task_id, task in self.task_status.items()
                if 'completed_time' in task
                and current_time - task['completed_time'] > max_age_seconds
            ]
            for task_id in to_delete:
                del self.task_status[task_id]
                self._logger.info(f"清理旧任务: {task_id}")
class TaskWorker(threading.Thread):
    """Worker thread: pulls tasks from the queue and runs SAM3 inference."""

    def __init__(self, worker_id: int, task_queue: TaskQueue, stop_event: threading.Event):
        super().__init__(daemon=True)
        self.worker_id = worker_id
        self.task_queue = task_queue
        self.stop_event = stop_event
        self.processed_count = 0

    def run(self):
        """Main loop: pull, process and mark tasks until stop_event is set."""
        logger.info(f"工作线程 {self.worker_id} 启动")
        while not self.stop_event.is_set():
            task = None  # bug fix: ensure 'task' exists if get_task() itself raises
            try:
                task = self.task_queue.get_task(timeout=0.5)
                if not task:
                    continue
                task_id = task['task_id']
                request_data = task['data']
                logger.info(f"工作线程 {self.worker_id} 开始处理任务: {task_id}")
                # Mark the task as being processed.
                self.task_queue.update_task_status(task_id, 'processing')
                result = self.process_task(task_id, request_data)
                # Record the terminal state.
                self.task_queue.update_task_status(
                    task_id,
                    'completed' if result.get('success') else 'failed',
                    result
                )
                self.processed_count += 1
                logger.info(f"工作线程 {self.worker_id} 完成任务: {task_id}, 处理总数: {self.processed_count}")
            except Exception as e:
                logger.error(f"工作线程 {self.worker_id} 处理任务失败: {e}")
                if task:
                    self.task_queue.update_task_status(
                        task['task_id'],
                        'failed',
                        {'error': str(e)}
                    )

    def process_task(self, task_id: str, request_data: dict) -> dict:
        """Run SAM3 open-vocabulary segmentation for one task.

        Downloads the image, runs the model with the caller-supplied
        prompt and confidence, uploads the rendered result to MinIO and
        publishes a completion message over MQTT.
        """
        task_id = request_data["task_id"]
        sn = request_data["sn"]
        img_url = request_data["img_url"]
        prompt = request_data["prompt"]
        confidence = request_data["confidence"]
        mqtt_ip = request_data["mqtt_ip"]
        mqtt_port = request_data["mqtt_port"]
        mqtt_topic = request_data["mqtt_topic"]
        local_image_path = downFile(img_url)
        # NOTE(review): model asset paths are hard-coded; consider moving to config.
        bpe_path = f"/home/beidou/test0623/sam3/sam3/sam3/assets/bpe_simple_vocab_16e6.txt.gz"
        config_path = f"/home/beidou/test0623/sam3/sam3/sam3-weight/config.json"
        checkpoint_path = f"/home/beidou/test0623/sam3/sam3/sam3-weight/sam3.pt"
        # NOTE(review): the model is rebuilt for every task; caching it on the
        # worker would avoid reloading the checkpoint each time.
        model = build_sam3_image_model_0228(
            bpe_path=bpe_path,
            checkpoint_path=checkpoint_path,
            config_path=config_path,
            load_from_HF=False,
            device="cuda",
            eval_mode=True,
        )
        formatted_date, milliseconds_timestamp = get_current_date_and_milliseconds()
        img_name = os.path.basename(local_image_path)
        dir_name = os.path.dirname(local_image_path)
        predict_save_path = os.path.join(dir_name, str(milliseconds_timestamp) + img_name)
        image = Image.open(local_image_path).convert("RGB")
        width, height = image.size
        # Bug fix: the request's confidence threshold was ignored (hard-coded 0.5).
        processor = Sam3Processor(model, confidence_threshold=confidence)
        inference_state = processor.set_image(image)
        processor.reset_all_prompts(inference_state)
        # Bug fix: the request's prompt was ignored (hard-coded "road").
        inference_state = processor.set_text_prompt(state=inference_state, prompt=prompt)
        img0 = Image.open(local_image_path)
        plot_results_savepic(img0, inference_state, save_path=predict_save_path)
        object_name, _ = upload_file(predict_save_path, None)
        # Publish the result location over MQTT.
        mqtt = MQTTService(mqtt_ip, port=mqtt_port)
        message = {
            'success': True,
            "task_id": task_id,
            'object_name': object_name
        }
        mqtt.publish_sync(mqtt_topic, json.dumps(message, ensure_ascii=False))
        # Clean up local temporary files.
        if os.path.exists(local_image_path):
            os.remove(local_image_path)
        if os.path.exists(predict_save_path):
            os.remove(predict_save_path)
        return {
            'success': True,
            'message': f'任务 {task_id} 处理完成',
            'data': {'result': 'some_result'}
        }
class TaskServiceServicer(grpc_sam3_img_pb2_grpc.TaskServiceServicer):
    """gRPC task service: enqueues work and answers without waiting."""

    def __init__(self, task_queue: TaskQueue, max_workers: int = 1):
        self.task_queue = task_queue
        self.max_workers = max_workers
        self.stop_event = threading.Event()
        self.workers = []
        self.start_workers()

    def start_workers(self):
        """Spin up the pool of background worker threads."""
        for idx in range(self.max_workers):
            thread = TaskWorker(idx, self.task_queue, self.stop_event)
            thread.start()
            self.workers.append(thread)
        logger.info(f"启动了 {self.max_workers} 个工作线程")

    def ProcessTask(self, request, context):
        """Accept a task request: queue it and return immediately."""
        try:
            # Fast reject when the queue has no room.
            if self.task_queue.queue.full():
                logger.warning(f"队列已满,拒绝任务: {request.task_id}")
                return grpc_sam3_img_pb2.TaskResponse(
                    task_id=request.task_id,
                    success=False,
                    message="服务器忙,请稍后重试"
                )
            body = request.content_body
            task_data = {
                'task_id': request.task_id,
                'sn': request.sn,
                'img_url': body.img_url,
                'prompt': body.prompt,
                'confidence': body.confidence,
                'mqtt_ip': body.mqtt_ip,
                'mqtt_port': body.mqtt_port,
                'mqtt_topic': body.mqtt_topic
            }
            accepted = self.task_queue.add_task(request.task_id, task_data)
            if accepted:
                return grpc_sam3_img_pb2.TaskResponse(
                    task_id=request.task_id,
                    success=True,
                    message=f"任务已接收,正在排队处理。当前队列位置: {self.task_queue.queue.qsize()}"
                )
            return grpc_sam3_img_pb2.TaskResponse(
                task_id=request.task_id,
                success=False,
                message="任务提交失败"
            )
        except Exception as e:
            logger.error(f"处理任务请求失败: {e}")
            return grpc_sam3_img_pb2.TaskResponse(
                task_id=request.task_id,
                success=False,
                message=f"服务器内部错误: {str(e)}"
            )

    def stop(self):
        """Signal workers to stop and wait briefly for them to exit."""
        self.stop_event.set()
        for thread in self.workers:
            thread.join(timeout=2)
        logger.info("所有工作线程已停止")
class HealthCheckServicer(grpc_sam3_img_pb2_grpc.HealthCheckServicer):
    """Health endpoint: reports NOT_SERVING once the queue backs up."""

    def __init__(self, task_queue: TaskQueue):
        self.task_queue = task_queue

    def Check(self, request, context):
        """Return SERVING unless the backlog exceeds 50 queued tasks."""
        backlog = self.task_queue.queue.qsize()
        status = (
            grpc_sam3_img_pb2.HealthCheckResponse.ServingStatus.NOT_SERVING
            if backlog > 50
            else grpc_sam3_img_pb2.HealthCheckResponse.ServingStatus.SERVING
        )
        return grpc_sam3_img_pb2.HealthCheckResponse(status=status)
def serve():
    """Build and run the gRPC server with its worker pool.

    Starts the task and health services, a periodic status-cleanup thread,
    and installs SIGINT/SIGTERM handlers for graceful shutdown.
    """
    # Single config point (fixes the stale "10 workers" comment: there is 1).
    queue_capacity = 20
    worker_count = 1  # GPU-bound SAM3 inference: keep a single worker
    task_queue = TaskQueue(max_size=queue_capacity)
    task_service = TaskServiceServicer(task_queue, max_workers=worker_count)
    health_service = HealthCheckServicer(task_queue)
    # Small RPC thread pool: handlers only enqueue, they never block on inference.
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))
    # Register both services.
    grpc_sam3_img_pb2_grpc.add_TaskServiceServicer_to_server(task_service, server)
    grpc_sam3_img_pb2_grpc.add_HealthCheckServicer_to_server(health_service, server)
    server.add_insecure_port('[::]:50051')
    server.start()
    logger.info("服务器已启动,监听端口: 50051")
    logger.info(f"工作线程数: {worker_count}, 队列最大容量: {queue_capacity}")

    def cleanup_loop():
        # Periodically drop finished task records (every 5 minutes).
        while True:
            time.sleep(300)
            task_queue.cleanup_old_tasks()

    cleanup_thread = threading.Thread(target=cleanup_loop, daemon=True)
    cleanup_thread.start()

    def shutdown():
        # Graceful shutdown: stop workers first, then the server.
        logger.info("收到关闭信号,正在停止服务器...")
        task_service.stop()
        server.stop(5)  # 5-second grace period
        logger.info("服务器已停止")

    import signal
    signal.signal(signal.SIGINT, lambda s, f: shutdown())
    signal.signal(signal.SIGTERM, lambda s, f: shutdown())
    # Block until terminated.
    try:
        server.wait_for_termination()
    except KeyboardInterrupt:
        shutdown()
# Script entry point: run the SAM3 gRPC task server.
if __name__ == '__main__':
    serve()

284
md/grpc.md Normal file
View File

@ -0,0 +1,284 @@
# 前言
sanic 和 服务之间基于grpc 解绑一个服务一个grpc
可以参考接口grpc 最好留有健康校验
# grpc demo
为了增强 gRPC 通讯的可靠性,我们可以添加以下功能:
1. 服务器健康检查服务
2. 客户端连接前检查服务器状态
3. 通讯过程中的错误处理和重试机制
## 1. 修改 protobuf 定义 (task.proto)
首先添加健康检查服务定义:
```protobuf
syntax = "proto3";
package task;
service TaskService {
rpc ProcessTask (TaskRequest) returns (TaskResponse);
}
// 添加健康检查服务
service HealthCheck {
rpc Check (HealthCheckRequest) returns (HealthCheckResponse);
}
message HealthCheckRequest {
string service = 1;
}
message HealthCheckResponse {
enum ServingStatus {
UNKNOWN = 0;
SERVING = 1;
NOT_SERVING = 2;
SERVICE_UNKNOWN = 3;
}
ServingStatus status = 1;
}
message TaskRequest {
string task_id = 1;
string sn = 2;
ContentBody content_body = 3;
}
message ContentBody {
string org_code = 1;
repeated int32 func_id = 2;
string source_url = 3;
string push_url = 4;
float confidence = 5;
repeated ParaList para_list = 6;
Invade invade = 7;
}
message ParaList {
int32 func_id = 1;
bool para_invade_enable = 2;
}
message Invade {
string invade_file = 1;
string camera_para_url = 2;
}
message TaskResponse {
string task_id = 1;
bool success = 2;
string message = 3;
}
```
## 2. 重新生成 Python 代码
```bash
python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. task.proto
protoc -I=./proto --go_out=./gen proto/subfolder/task.proto
```
## 3. 增强服务端实现 (server.py)
```python
from concurrent import futures
import grpc
import time
import task_pb2
import task_pb2_grpc
class TaskServiceServicer(task_pb2_grpc.TaskServiceServicer):
def ProcessTask(self, request, context):
print(f"Received task_id: {request.task_id}")
# ... 原有处理逻辑
return task_pb2.TaskResponse(
task_id=request.task_id,
success=True,
message="Task processed successfully"
)
class HealthCheckServicer(task_pb2_grpc.HealthCheckServicer):
def Check(self, request, context):
# 简单实现总是返回SERVING状态
# 实际应用中可以根据服务状态返回不同值
return task_pb2.HealthCheckResponse(
status=task_pb2.HealthCheckResponse.ServingStatus.SERVING
)
def serve():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
# 添加服务实现
task_pb2_grpc.add_TaskServiceServicer_to_server(TaskServiceServicer(), server)
task_pb2_grpc.add_HealthCheckServicer_to_server(HealthCheckServicer(), server)
server.add_insecure_port('[::]:50051')
server.start()
print("Server started, listening on port 50051...")
try:
while True:
time.sleep(86400) # 保持运行
except KeyboardInterrupt:
server.stop(0)
if __name__ == '__main__':
serve()
```
## 4. 增强客户端实现 (client.py)
```python
import grpc
import time
import task_pb2
import task_pb2_grpc
def check_server_status(channel):
try:
health_stub = task_pb2_grpc.HealthCheckStub(channel)
response = health_stub.Check(task_pb2.HealthCheckRequest(service="TaskService"))
return response.status == task_pb2.HealthCheckResponse.ServingStatus.SERVING
except grpc.RpcError as e:
print(f"Health check failed: {e}")
return False
def run_with_retry(max_retries=3, delay=5):
channel = None
retries = 0
while retries < max_retries:
try:
# 创建通道
channel = grpc.insecure_channel('localhost:50051')
# 检查服务器状态
if not check_server_status(channel):
raise Exception("Server is not healthy")
stub = task_pb2_grpc.TaskServiceStub(channel)
# 创建请求消息
request = task_pb2.TaskRequest(
task_id="d6118954-a170-4e1c-84bd-ddbd3114b354",
sn="8UUXN6S00A0CK7",
content_body=task_pb2.ContentBody(
org_code="HMZHB",
func_id=[101204],
source_url="rtmp://222.212.85.86:1935/live/1581F8HGX253S00A05L8",
push_url="",
confidence=0.4,
para_list=[
task_pb2.ParaList(
func_id=101204,
para_invade_enable=True
)
],
invade=task_pb2.Invade(
invade_file="meta_data/高压线-0826.geojson",
camera_para_url="meta_data/camera_para/hami_camera_para .txt"
)
)
)
# 调用远程方法
response = stub.ProcessTask(request)
print(f"Response: task_id={response.task_id}, success={response.success}, message={response.message}")
return True
except grpc.RpcError as e:
retries += 1
print(f"RPC error occurred (attempt {retries}/{max_retries}): {e}")
if retries < max_retries:
print(f"Retrying in {delay} seconds...")
time.sleep(delay)
except Exception as e:
print(f"Error occurred: {e}")
retries += 1
if retries < max_retries:
print(f"Retrying in {delay} seconds...")
time.sleep(delay)
finally:
if channel:
channel.close()
print("All retry attempts failed")
return False
if __name__ == '__main__':
run_with_retry()
```
## 5. 增强功能说明
1. 健康检查服务
- 添加了标准的 gRPC 健康检查服务
- 客户端可以在执行主要操作前检查服务状态
2. 错误处理和重试机制
- 客户端现在会捕获 `grpc.RpcError` 和其他异常
- 实现了最大重试次数和重试间隔
- 每次重试前都会检查服务器状态
3. 资源管理
- 在 finally 块中确保在所有情况下都正确关闭通道
- 也可以改用 with 上下文管理器来管理通道生命周期
4. 状态反馈
- 提供更详细的错误信息
- 记录重试尝试
## 6. 运行说明
1. 安装依赖:
```bash
pip install grpcio grpcio-tools
```
2. 生成 protobuf 代码:
```bash
# 确保在项目根目录下执行
# 编译 proto 文件,输出到 a/b/ 目录
python -m grpc_tools.protoc \
    -I. \
    --python_out=. \
    --grpc_python_out=. \
    a/b/task.proto
# 说明:-I 指定 proto 文件的根目录(当前目录 ".");--python_out 生成 _pb2.py
# --grpc_python_out 生成 _pb2_grpc.py;最后一项为 proto 文件路径(相对于 -I
# 注意:bash 中行尾反斜杠后不能再跟注释,否则续行会被截断
```
3. 启动服务器:
```bash
python server.py
```
4. 运行客户端:
```bash
python client.py
```
## 7. 测试场景
1. 服务器未运行
- 客户端会检测到连接失败并重试
- 最终显示所有重试失败
2. 服务器运行但健康检查失败
- 可以修改 `HealthCheckServicer` 返回 `NOT_SERVING` 状态进行测试
- 客户端会拒绝执行主要操作
3. 网络中断
- 客户端会捕获异常并尝试重试
这个增强版本提供了更健壮的 gRPC 通讯机制,适合生产环境使用。

350
md/接口.md Normal file
View File

@ -0,0 +1,350 @@
# 算法与后台解耦规则
# 1、方法
postgres 的ai_model_list 表id字段声明为6位长度数字
1、第1位表示算法类别:1xxxxx 表示为目标识别、2xxxxx 表示为语义分割、3xxxxx 表示为变化监测
2、最后两位表示二次计算:100001 表示做目标识别、100002 表示做目标识别且做人员计数
# 接口名:视频流识别
接收前端的视频流、模型、识别类型算法做计算并且将计算结果存储到minio消息通过mqtt发送
## 1、请求
接口 /ai/stream/back_detect
方法 post
headers X-API-Token:5e8899fe-dc74-4280-8169-2f4d185f3afa
body
```
{
"task_id": "1234567890", #任务id
"sn":"", #无人机sn
"content_body": {
"source_url": "rtmp://192.168.0.142:1935/live/123456", #无人机视频流url
"confidence":0.4, #置信度
"model_func_id":[100001,100002] #方法id
}
}
```
## 2、响应
算法的响应分为两个部分
1、rest响应表明收到请求
2、mqtt消息持续输出计算结果
### 1、rest
```
{
"status": "success",
"task_id": "1234567890",
"message": "Detection started successfully"
}
```
### 2、mqtt
ip 112.44.103.230 端口 1883
topic thing/product/ai/events
```
{
"task_id": "1234567890",
"minio": {
"minio_path": "ai_result/20250702/1751425303860-output-1751425303800959985.jpg",
"file_type": "pic"
},
"box_detail": {
"result_100001": {
"func_id_10001": 100001,
"type_name": "行人",
"cls_count": 1,
"box_count": [
[
{
"track_id": 22099,
"confidence": 0.34013107419013977,
"cls_id": 0,
"type_name": "行人",
"box": [
15.935794830322266,
694.75390625,
33.22901916503906,
713.1658935546875
]
}
]
]
}
},
"uav_location": {
"data": {
"attitude_head": 60,
"gimbal_pitch": 60,
"gimbal_roll": 60,
"gimbal_yaw": 60,
"height": 10,
"latitude": 10,
"longitude": 10,
"speed_x": 10,
"speed_y": 10,
"speed_z": 10
},
"timestamp": 1751425301213249700
}
}
```
# 接口名:图片识别
接收前端的图片算法做计算并且将计算结果存储到minio消息通过mqtt发送
## 1、请求
接口 /ai/pic/back_detect_pic
方法 post
headers X-API-Token:5e8899fe-dc74-4280-8169-2f4d185f3afa
body
```
{
"task_id": "0001111",
"content_body": {
"s3_id":1, #根据id适配minio相关存储参数
"s3_url":[
"test/frame_0000.jpg","test/frame_0001.jpg","test/frame_0002.jpg" # minio文件地址
],
"confidence":0.4, #算法置信度
"model_func_id":[10001,10002] #方法id
}
}
```
## 2、响应
算法的响应分为两个部分
1、rest响应表明收到请求
2、mqtt消息持续输出计算结果
### 1、rest
```
{
"status": "success",
"task_id": "0001111",
"message": "Detection started successfully"
}
```
### 2、mqtt
ip 112.44.103.230 端口 1883
topic thing/product/ai/events
```
{
"task_id": "0001111", #任务id
"minio": {
"minio_path": "ai_result/20250627/1751006943659-frame_0001.jpg", # minio 存储路径
"file_type": "pic"
},
"box_detail": {
"model_id": 10001,
"box_count": [
{
"type": 3, # 类型
"type_name": "车辆", #类型名称
"count": 71 #数量
},
{
"type": 0,
"type_name": "车辆",
"count": 7
}
]
}
}
```
# 接口名:地类分割
接收前端的图片算法做计算并且将计算结果存储到minio消息通过mqtt发送
## 1、请求
接口 /ai/pic/back_detect_pic
方法 post
headers X-API-Token:5e8899fe-dc74-4280-8169-2f4d185f3afa
body
```
{
"task_id": "7a5c83e0-fe0d-47bf-a8e1-9bd663508783",
"content_body": {
"s3_id":1,#根据id适配minio相关存储参数
"s3_url":[
"test/patch_0011.png", # minio文件地址
"test/patch_0012.png"
],
"model_func_id":[20000,20001] #方法id
}
}
```
## 2、响应
算法的响应分为两个部分
1、rest响应表明收到请求
2、mqtt消息持续输出计算结果
### 1、rest
```
{
"status": "success",
"task_id": "0001111",
"message": "Detection started successfully"
}
```
### 2、mqtt
ip 112.44.103.230 端口 1883
topic thing/product/ai/events
```
{
"task_id": "7a5c83e0-fe0d-47bf-a8e1-9bd663508783",
"minio": [
{
"minio_path_before": "ai_result/20250710/1752128232469-patch_0011.png", # 需要分割的图片
"minio_path_after": "ai_result/20250710/1752128234222-patch_0011.png", #分割之后的图片
"minio_path_boundary": "ai_result/20250710/1752128234264-patch_0011.pngfinal_vis.png", # 分割的边界图片
"minio_path_json": "ai_result/20250710/1752128234326-patch_0011.pnginstance_results.json", #分割生成的json文件
"file_type": "pic"
},
{
"minio_path_before": "ai_result/20250710/1752128240382-patch_0012.png",
"minio_path_after": "ai_result/20250710/1752128241553-patch_0012.png",
"minio_path_boundary": "ai_result/20250710/1752128241587-patch_0012.pngfinal_vis.png",
"minio_path_json": "ai_result/20250710/1752128241631-patch_0012.pnginstance_results.json",
"file_type": "pic"
}
]
}
```
# 接口名:地类变化监测
接收前端的图片对一期、二期的图像做变化监测并且将计算结果存储到minio消息通过mqtt发送
## 1、请求
接口 /ai/pic/back_detect_pic
方法 post
headers X-API-Token:5e8899fe-dc74-4280-8169-2f4d185f3afa
body
```
{
"task_id": "9fa19ec3-d982-4897-af6c-2c78f786c760",
"content_body": {
"s3_id":1,
"s3_url":{
"early":"/test/1-00205.png", # 一期图像minio文件地址
"later":"/test/2-00205.png" # 二期图像minio文件地址
},
"model_func_id":[30000,30001]
}
}
```
## 2、响应
算法的响应分为两个部分
1、rest响应表明收到请求
2、mqtt消息持续输出计算结果
### 1、rest
```
{
"status": "success",
"task_id": "0001111",
"message": "Detection started successfully"
}
```
### 2、mqtt
ip 112.44.103.230 端口 1883
topic thing/product/ai/events
```
{
"task_id": "9fa19ec3-d982-4897-af6c-2c78f786c760",
"minio": {
"minio_path_1": "ai_result/20250627/1751007686483-1-00205.png", # 一期影像minio地址
"minio_path_2": "ai_result/20250627/1751007686541-2-00205.png", # 二期影像minio地址
"minio_path_result": "ai_result/20250627/1751007686.458642-result-2-00205.png", #识别结果minio地址
"file_type": "pic"
}
}
```

View File

@ -51,6 +51,7 @@ class MQTTService:
self._message_task = None self._message_task = None
self._connection_lock = asyncio.Lock() self._connection_lock = asyncio.Lock()
self.os_type = sys.platform.lower() self.os_type = sys.platform.lower()
self._loop = None # 保存事件循环
async def connect(self): async def connect(self):
async with self._connection_lock: async with self._connection_lock:
@ -122,6 +123,25 @@ class MQTTService:
await self.reconnect() await self.reconnect()
await self.client.publish(topic, payload, qos=qos, retain=retain) await self.client.publish(topic, payload, qos=qos, retain=retain)
def publish_sync(self, topic, payload, qos=0, retain=False):
"""同步发布消息(适用于同步代码调用)"""
if self._loop is None or self._loop.is_closed():
# 如果没有事件循环,创建一个新的
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
# 在事件循环中运行异步publish
if not self._loop.is_running():
return self._loop.run_until_complete(
self.publish(topic, payload, qos=qos, retain=retain)
)
else:
# 如果事件循环已经在运行,创建一个任务
asyncio.ensure_future(
self.publish(topic, payload, qos=qos, retain=retain)
)
return None
async def subscribe(self, topic, callback=None, qos=0): async def subscribe(self, topic, callback=None, qos=0):
if not self.is_connected: if not self.is_connected:
await self.connect() await self.connect()

View File

@ -1,7 +1,11 @@
from dataclasses import dataclass from dataclasses import dataclass
import json
from typing import Optional
import asyncio import asyncio
from typing import Union, Optional
from typing import Any
@dataclass @dataclass
@ -25,9 +29,23 @@ class OSDInfo:
wind_speed: float wind_speed: float
@dataclass
class OSDInfo_v1:
attitude_head: float
latitude: float
longitude: float
height: float
speed_x: float
speed_y: float
speed_z: float
gimbal_pitch: float
gimbal_roll: float
gimbal_yaw: float
@dataclass @dataclass
class OSDMessage: class OSDMessage:
data: OSDInfo data: Any # 可以是两种类型之一
method: str method: str
seq: int seq: int
timestamp: int timestamp: int
@ -36,19 +54,47 @@ class OSDMessage:
def parse_osd_message(json_str: Optional[str]) -> Optional[OSDMessage]: def parse_osd_message(json_str: Optional[str]) -> Optional[OSDMessage]:
if not json_str: if not json_str:
return None return None
data = json_str
try: try:
data=json_str
osd_info = OSDInfo(**data["data"]) osd_info = OSDInfo(**data["data"])
return OSDMessage( data_seq=data["seq"]
data=osd_info, except (TypeError, KeyError) as e:
method=data["method"], # 如果OSDInfo格式失败尝试使用OSDInfo_v1格式
seq=data["seq"], try:
timestamp=data["timestamp"] osd_info = OSDInfo_v1(**data["data"])
) data_seq = 0
except Exception as e: except Exception as e2:
print(f"Error parsing OSD message: {e}") print(f"Error parsing OSD message with both formats: {e2}")
return None return None
return OSDMessage(
data=osd_info,
method=data["method"],
seq=data_seq,
timestamp=data["timestamp"]
)
# try:
# data=json_str
# # osd_info = OSDInfo(**data["data"])
# # if osd_info is None:
# # osd_info = OSDInfo_v1(**data["data"]) #适配东西湖区一代飞机的格式
# try:
# osd_info = OSDInfo(**data["data"])
# except (TypeError, KeyError) as e:
# # 如果OSDInfo格式失败尝试使用OSDInfo_v1格式
# try:
# osd_info = OSDInfo_v1(**data["data"])
# except Exception as e2:
# print(f"Error parsing OSD message with both formats: {e2}")
# return None
# return OSDMessage(
# data=osd_info,
# method=data["method"],
# seq=data["seq"],
# timestamp=data["timestamp"]
# )
# except Exception as e:
# print(f"Error parsing OSD message: {e}")
# return None
async def main(): async def main():

View File

@ -46,6 +46,7 @@ class ModelData:
so_path: str so_path: str
repeat_dis: float repeat_dis: float
repeat_time: float repeat_time: float
high_count_warn: float
func_description: Optional[str] func_description: Optional[str]
filter_indices: List[int] filter_indices: List[int]
class_indices: List[int] class_indices: List[int]
@ -250,6 +251,7 @@ class ModelConfigDAO:
aml.py_func, aml.py_func,
aml.repeat_dis, aml.repeat_dis,
aml.repeat_time, aml.repeat_time,
aml.high_count_warn,
am.scope, am.scope,
am.yolo_version, am.yolo_version,
am.PATH, am.PATH,
@ -572,6 +574,7 @@ WHERE
filter_indices=filter_indices, filter_indices=filter_indices,
repeat_dis=repeat_dis, repeat_dis=repeat_dis,
repeat_time=row.get('repeat_time'), repeat_time=row.get('repeat_time'),
high_count_warn=row.get('high_count_warn'),
class_indices=row['cls_index'], class_indices=row['cls_index'],
conf=conf, conf=conf,
classes=classes, classes=classes,

View File

@ -0,0 +1,645 @@
import json
import time
from sanic import Blueprint
from sanic.response import json as json_response
from sanic.exceptions import Unauthorized, SanicException
from dataclasses import dataclass, asdict
from typing import List, Dict, Any
import logging
import asyncio
import traceback
from datetime import datetime
from sympy import false
try :
from middleware.TaskManager import task_manager
from middleware.query_model import ModelConfigDAO
from middleware.query_postgress import batch_query_model_func_id
from middleware.read_yolo_config import read_local_func_config
from yolo.cv_multi_model_back_video import start_rtmp_processing
from cv_video import startAIVideo, stopAIVideo
from cv_back_video import startBackAIVideo
except Exception as e:
import sys
from pathlib import Path
# 获取项目根目录(假设脚本在根目录的子目录中)
ROOT_DIR = Path(__file__).parent.parent # 根据实际层级调整
sys.path.append(str(ROOT_DIR))
from middleware.TaskManager import task_manager
from middleware.query_model import ModelConfigDAO
from middleware.query_postgress import batch_query_model_func_id
from middleware.read_yolo_config import read_local_func_config
from yolo.cv_multi_model_back_video import start_rtmp_processing
from cv_video import startAIVideo, stopAIVideo
from cv_back_video import startBackAIVideo
# Configuration class
class Config:
    """Static service configuration.

    SECURITY NOTE(review): the API token and database credentials are
    hard-coded in source; they should be loaded from environment variables
    or a secrets store before shipping.
    """
    VALID_TOKEN = "5e8899fe-dc74-4280-8169-2f4d185f3afa"
    MAX_ACTIVE_TASKS = 10
    DEFAULT_CONFIDENCE = 0.5
    RESTART_DELAY = 2  # delay (seconds) before the service tries to auto-recover
    # Production database connection parameters
    DB_CONFIG = {
        "dbname": "smart_dev",
        "user": "postgres",
        "password": "StrongPassword@123",
        "host": "222.212.85.86",
        "port": "5061"
    }
# Service health flag shared by the handlers in this blueprint; flipped to
# unhealthy when stopping AI video processing fails, so the next request
# knows a recovery attempt is needed.
service_status = {"is_healthy": True, "last_error": None, "error_time": None}
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
multi_back_detect_bp = Blueprint("multi_back_detect", url_prefix="")
@dataclass
class BaseContentBody:
    """Base class for all request content-body variants; shared fields go here."""
    pass  # no common fields yet
@dataclass
class ContentBodyFormat_VideoMultiBackDetect(BaseContentBody):
    """Request body: multi-model detection on a video file stored in MinIO."""
    # mqtt_pub_id: int
    org_code: str
    # mqtt_ip: str
    # mqtt_port: int
    # mqtt_topic: str
    minio_file_path: str
    push_url: str  # temporary, for testing
    confidence: float
    para_list: list
    # invade_file: str
    invade: list
@dataclass
class ContentBodyFormat_Sam3Pic(BaseContentBody):
    """Request body: SAM3 open-vocabulary segmentation on a single image."""
    img_url: str
    prompt: str
    # Annotation fixed: from_dict defaults this to 0.5, so it is a float.
    confidence: float
    mqtt_ip: str
    mqtt_port: int
    mqtt_topic: str
@dataclass
class ContentBodyFormat_MultiBackDetect(BaseContentBody):
    """Request body: multi-model detection on a live video stream."""
    # mqtt_pub_id: int
    # mqtt_sub_id: int
    # mqtt_pub_ip: str
    # mqtt_pub_port: int
    # mqtt_pub_topic: str
    # mqtt_sub_ip: str
    # mqtt_sub_port: int
    # mqtt_sub_topic: str
    org_code: str
    func_id: list
    source_url: str
    push_url: str  # temporary, for testing
    confidence: float
    para_list: list
    # invade_file: str
    invade: list
@dataclass
class ContentBodyFormat_BackDetect(BaseContentBody):
    """Request body: single-model detection on a video stream."""
    mqtt_ip: str
    mqtt_port: int
    mqtt_topic: str
    source_url: str
    push_url: str  # temporary, for testing
    confidence: float
    func_id: List[int]
    # Annotation fixed: was the literal `{}`, which is not a valid type.
    para: dict
@dataclass
class ContentBodyFormat_BackDetectPic(BaseContentBody):
    """Request body: detection on one or more images stored in MinIO/S3."""
    s3_id: int
    s3_url: list[str]
    org_code: str
    # mqtt_pub_id: int
    confidence: float
    func_id: list[int]
    # Annotation fixed: was the literal `{}`, which is not a valid type.
    para: dict
@dataclass
class EarlyLaterUrls:
    """Pair of S3 object paths for two survey phases of the same area."""
    early: str  # phase-1 image path
    later: str  # phase-2 image path
@dataclass
class ContentBodyFormat_Detection(BaseContentBody):
    """Request body: land-cover change detection between two image phases."""
    s3_id: int
    s3_url: EarlyLaterUrls
    func_id: list[int]
@dataclass
class ContentBodyFormat_Segementation(BaseContentBody):
    """Request body: land-cover semantic segmentation of images.

    NOTE(review): class name misspells "Segmentation"; kept for
    backward compatibility with existing callers.
    """
    s3_id: int
    s3_url: list[str]
    func_id: list[int]
@dataclass
class RequestJson:
    """Top-level request envelope: task id, drone SN and one typed content body.

    ``content_body`` is dispatched to a concrete ContentBodyFormat_* class by
    ``from_dict`` based on which keys the raw payload carries.
    """
    task_id: str
    sn: str
    content_body: BaseContentBody

    def validate(self) -> None:
        """Validate required fields; raises ValueError with a descriptive message."""
        if not self.task_id:
            raise ValueError("task_id is required")
        if isinstance(self.content_body, ContentBodyFormat_VideoMultiBackDetect):
            if not self.content_body.minio_file_path:
                raise ValueError("minio_file_path is required for ContentBodyFormat_VideoMultiBackDetect")
            if self.content_body.confidence is not None and not 0 < self.content_body.confidence < 1:
                # Bug fix: message previously named ContentBodyFormat_BackDetect.
                raise ValueError("Confidence must be between 0 and 1 for ContentBodyFormat_VideoMultiBackDetect")
        elif isinstance(self.content_body, ContentBodyFormat_MultiBackDetect):
            if not self.content_body.para_list:
                raise ValueError("para_list is required for ContentBodyFormat_MultiBackDetect")
            if self.content_body.confidence is not None and not 0 < self.content_body.confidence < 1:
                # Bug fix: message previously named ContentBodyFormat_BackDetect.
                raise ValueError("Confidence must be between 0 and 1 for ContentBodyFormat_MultiBackDetect")
        elif isinstance(self.content_body, ContentBodyFormat_BackDetect):
            if not self.content_body.source_url:
                raise ValueError("source_url is required for ContentBodyFormat_BackDetect")
            if self.content_body.confidence is not None and not 0 < self.content_body.confidence < 1:
                raise ValueError("Confidence must be between 0 and 1 for ContentBodyFormat_BackDetect")
        elif isinstance(self.content_body, ContentBodyFormat_BackDetectPic):
            if not self.content_body.s3_id:
                raise ValueError("s3_id is required for ContentBodyFormat_BackDetectPic")
            if not self.content_body.s3_url:
                raise ValueError("s3_url is required for ContentBodyFormat_BackDetectPic")
            if self.content_body.confidence is not None and not 0 < self.content_body.confidence < 1:
                raise ValueError("Confidence must be between 0 and 1 for ContentBodyFormat_BackDetectPic")
        elif isinstance(self.content_body, ContentBodyFormat_Detection):
            if not self.content_body.s3_id:
                raise ValueError("s3_id is required for ContentBodyFormat_Detection")
            if not self.content_body.s3_url.early or not self.content_body.s3_url.later:
                raise ValueError("Both early and later URLs are required for ContentBodyFormat_Detection")
        elif isinstance(self.content_body, ContentBodyFormat_Segementation):
            if not self.content_body.s3_id:
                raise ValueError("s3_id is required for ContentBodyFormat_Segementation")
            if not self.content_body.s3_url:
                raise ValueError("s3_url is required for ContentBodyFormat_Segementation")
        elif isinstance(self.content_body, ContentBodyFormat_Sam3Pic):
            if not self.content_body.prompt:
                raise ValueError("prompt is required for ContentBodyFormat_Sam3Pic")

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'RequestJson':
        """Build and validate a RequestJson from a raw JSON dict.

        Dispatch order matters: keys are checked from most specific
        ('minio_file_path', 'para_list', ...) to most generic.

        Raises:
            ValueError: on a missing required field or an unrecognized body shape.
        """
        try:
            task_id = data['task_id']
            # Robustness fix: picture/segmentation requests documented in
            # md/接口.md carry no "sn" field, so treat it as optional.
            sn = data.get('sn', '')
            content_body_data = data['content_body']
            if 'minio_file_path' in content_body_data:
                content_body = ContentBodyFormat_VideoMultiBackDetect(
                    org_code=content_body_data['org_code'],
                    minio_file_path=content_body_data['minio_file_path'],
                    push_url=content_body_data['push_url'],
                    confidence=content_body_data.get('confidence', 0.5),
                    para_list=content_body_data.get('para_list', []),
                    invade=content_body_data.get('invade', [])
                )
            elif 'para_list' in content_body_data:
                content_body = ContentBodyFormat_MultiBackDetect(
                    org_code=content_body_data['org_code'],
                    func_id=content_body_data['func_id'],
                    source_url=content_body_data['source_url'],
                    push_url=content_body_data['push_url'],
                    confidence=content_body_data.get('confidence', 0.5),
                    para_list=content_body_data.get('para_list', []),
                    invade=content_body_data.get('invade', [])
                )
            elif 'source_url' in content_body_data:
                content_body = ContentBodyFormat_BackDetect(
                    mqtt_ip=content_body_data['mqtt_ip'],
                    mqtt_port=content_body_data['mqtt_port'],
                    mqtt_topic=content_body_data['mqtt_topic'],
                    source_url=content_body_data['source_url'],
                    push_url=content_body_data['push_url'],
                    confidence=content_body_data.get('confidence', 0.5),
                    func_id=content_body_data.get('func_id', []),
                    para=content_body_data.get('para', {})
                )
            elif 's3_id' in content_body_data and 's3_url' in content_body_data:
                # Change detection: s3_url is a dict with early/later phases.
                if isinstance(content_body_data['s3_url'], dict) and 'early' in content_body_data[
                    's3_url'] and 'later' in content_body_data['s3_url']:
                    content_body = ContentBodyFormat_Detection(
                        s3_id=content_body_data['s3_id'],
                        s3_url=EarlyLaterUrls(
                            early=content_body_data['s3_url']['early'],
                            later=content_body_data['s3_url']['later']
                        ),
                        func_id=content_body_data.get('func_id', [])
                    )
                # Segmentation: a url list with no confidence field.
                elif isinstance(content_body_data['s3_url'], list) and 'confidence' not in content_body_data:
                    content_body = ContentBodyFormat_Segementation(
                        s3_id=content_body_data['s3_id'],
                        s3_url=content_body_data['s3_url'],
                        func_id=content_body_data.get('func_id', [])
                    )
                # Picture detection: a url list with a confidence field.
                elif isinstance(content_body_data['s3_url'], list):
                    content_body = ContentBodyFormat_BackDetectPic(
                        s3_id=content_body_data['s3_id'],
                        s3_url=content_body_data['s3_url'],
                        # mqtt_pub_id=content_body_data['mqtt_pub_id'],
                        org_code=content_body_data['org_code'],
                        confidence=content_body_data.get('confidence', 0.5),
                        func_id=content_body_data.get('func_id', []),
                        para=content_body_data.get('para', {})
                    )
                else:
                    raise ValueError("Invalid s3_url format for ContentBodyFormat_Detection")
            elif 'prompt' in content_body_data:
                content_body = ContentBodyFormat_Sam3Pic(
                    img_url=content_body_data['img_url'],
                    prompt=content_body_data['prompt'],
                    confidence=content_body_data.get('confidence', 0.5),
                    mqtt_ip=content_body_data['mqtt_ip'],
                    mqtt_port=content_body_data['mqtt_port'],
                    mqtt_topic=content_body_data['mqtt_topic']
                )
            else:
                raise ValueError("Invalid content_body format")
            instance = cls(
                task_id=task_id,
                sn=sn,
                content_body=content_body
            )
            instance.validate()
            return instance
        except KeyError as e:
            raise ValueError(f"Missing required field: {str(e)}")
async def safe_stop_ai_video():
    """Stop AI video processing safely, with error handling and recovery.

    Returns True when ``stopAIVideo`` completes normally.  On failure the
    service is marked unhealthy, every managed task is force-stopped, any
    lingering ffmpeg child processes are terminated best-effort, the coroutine
    waits ``Config.RESTART_DELAY`` seconds, restores the healthy flag and
    returns False.
    """
    try:
        await asyncio.to_thread(stopAIVideo)
    except Exception as e:
        error_msg = f"停止AI视频处理出错: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        # Record the failure so health checks can surface it.
        service_status["is_healthy"] = False
        service_status["last_error"] = str(e)
        service_status["error_time"] = datetime.now().isoformat()
        # Force every tracked task into the stopped state.
        task_manager.mark_all_tasks_as_stopped()
        # Best effort: hunt down ffmpeg children that may still be running.
        try:
            import os
            import signal
            import psutil
            me = psutil.Process(os.getpid())
            for proc in me.children(recursive=True):
                try:
                    proc_name = proc.name().lower()
                    if 'ffmpeg' not in proc_name:
                        continue
                    logger.info(f"强制终止子进程: {proc.pid} ({proc_name})")
                    proc.send_signal(signal.SIGTERM)
                except Exception as child_e:
                    logger.error(f"终止子进程出错: {str(child_e)}")
        except Exception as kill_e:
            logger.error(f"尝试清理进程时出错: {str(kill_e)}")
        # Give the system a moment to settle, then declare healthy again.
        await asyncio.sleep(Config.RESTART_DELAY)
        service_status["is_healthy"] = True
        return False
    return True
def verify_token(request) -> None:
    """Validate the request's ``X-API-Token`` header.

    Raises:
        Unauthorized: when the header is absent or does not match
            ``Config.VALID_TOKEN``.
    """
    supplied = request.headers.get('X-API-Token')
    if supplied != Config.VALID_TOKEN or not supplied:
        logger.warning("Invalid token attempt")
        raise Unauthorized("Invalid token")
@multi_back_detect_bp.post("/ai/stream/multi_back_detect")
async def start_multi_back_detection(request):
    """Start a multi-model background detection task for a video stream.

    Validates the API token and request body, refuses duplicate task ids,
    then launches ``run_back_Multi_Detect_async`` as a fire-and-forget
    asyncio task and returns immediately.

    Returns:
        A JSON response: success with the task id, 400 on validation errors,
        500 on duplicates or startup failures.
    """
    try:
        verify_token(request)
        # Health gate: log the last recorded failure, then optimistically
        # reset the flag so a new task can be attempted.
        if not service_status["is_healthy"]:
            logger.warning(
                f"服务处于不健康状态,上次错误: {service_status['last_error']}{service_status['error_time']}")
            service_status["is_healthy"] = True
        # Parse and validate the request payload.
        request_json = RequestJson.from_dict(request.json)
        print(f"/ai/stream/multi_back_detect 请求:{request.json}")
        # BUG FIX: this was time.sleep(3), which blocked the entire event
        # loop (every other request stalled for 3s). Use the awaitable form.
        await asyncio.sleep(3)
        if request_json.task_id in task_manager.tasks:
            logger.warning(f"任务 {request_json.task_id} 已存在,跳过创建")
            return json_response({
                "status": "error",
                "message": f"任务 {request_json.task_id} 已存在,跳过创建"
            }, status=500)
        if isinstance(request_json.content_body, ContentBodyFormat_MultiBackDetect):
            try:
                # Cooperative stop signal handed to the processing coroutine.
                stop_event = asyncio.Event()

                async def wrapped_processing():
                    # Contain task-level exceptions here so they never
                    # propagate into the event loop unobserved.
                    try:
                        await run_back_Multi_Detect_async(request, request_json, stop_event)
                    except asyncio.CancelledError:
                        logger.info(f"任务 {request_json.task_id} 被取消")
                    except Exception as e:
                        logger.error(f"任务 {request_json.task_id} 异常终止: {e}")

                # Launch the background task; the handle is tracked inside
                # run_back_Multi_Detect_async via task_manager.
                task_handle = asyncio.create_task(wrapped_processing())
            except Exception as e:
                logger.error(f"启动AI视频处理失败: {e}")
                return json_response({
                    "status": "error",
                    "message": f"Failed to start AI video processing: {str(e)}"
                }, status=500)
        else:
            return json_response({
                "status": "failed",
                "message": "content_body structure is wrong"
            })
        return json_response({
            "status": "success",
            "task_id": request_json.task_id,
            "message": "Detection started successfully"
        })
    except ValueError as e:
        logger.error(f"Validation error: {str(e)}")
        return json_response({"status": "error", "message": str(e)}, status=400)
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}", exc_info=True)
        return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
async def run_back_Multi_Detect_async(request, request_json, stop_event: asyncio.Event):
    """Drive one multi-model background detection task end to end.

    Steps:
      1. Resolve each ``para_list`` entry to a model configuration via the DAO.
      2. Load MQTT pub/sub settings and device metadata for the request's org.
      3. Start the RTMP processing flow and register it with ``task_manager``.

    Args:
        request: the originating Sanic request (persisted to the request log).
        request_json: validated request payload (task id, sn, content body).
        stop_event: cooperative cancellation signal — NOTE(review): currently
            not consumed by start_rtmp_processing; confirm wiring.

    Raises:
        asyncio.CancelledError: when the surrounding task is cancelled.
        Exception: re-raised after logging for any other processing failure.
    """
    global DB_CONFIG
    model_configs = []
    task_handle = None  # only ever set by a (removed) on-site fallback path
    try:
        # Any para entry may enable invasion detection; one truthy flag wins.
        # BUG FIX: this was the bare name `false`, a NameError at runtime —
        # Python's boolean literal is `False`.
        invade_enable = False
        py_func = []
        # BUG FIX: defaults for names consumed after the loop; previously a
        # para list that never produced a config left them unbound (NameError).
        repeat_dis = -1         # frame-to-frame de-dup distance; -1 = disabled
        repeat_time = 0         # de-dup time window; 0 = disabled — TODO confirm
        high_count_warn = None  # normalized to 0 below when no config set it
        dao = ModelConfigDAO(DB_CONFIG)
        for para in request_json.content_body.para_list:
            func_id = para["func_id"]
            category = para.get("py_func", [])  # default: no category filter
            py_func = category  # temporary passthrough for the Hubei site
            para_invade_enable = para.get("para_invade_enable", False)
            if para_invade_enable:
                invade_enable = True
            query_results = batch_query_model_func_id([func_id], **DB_CONFIG)
            row_func_id = 0  # placeholder; refined per row below
            if len(query_results) < 1:
                continue
            for row in query_results:
                row_func_id = row["model_func_id"]
                func_id = para["func_id"]
                config = dao.get_config(func_id, category)
                repeat_dis = -1  # reset per model; -1 means "no de-dup"
                if config:
                    print("\n模型路径:", config.model_path)
                    print("过滤类别:", config.filter_indices)
                    print("第一个类别:", asdict(config.classes[0]))
                    print("创建时间:", config.created_at)
                    print("更新时间:", config.updated_at)
                    print("去重的距离:", config.repeat_dis)
                    repeat_dis = config.repeat_dis
                    repeat_time = config.repeat_time
                    high_count_warn = config.high_count_warn
                    print(f"config.high_count_warn {config.high_count_warn}")
                    model_configs.append(
                        {
                            'path': config.model_path,
                            # NOTE(review): engine/so paths are hard-coded to a
                            # local test model; restore config.engine_path /
                            # config.so_path before deploying.
                            'engine_path': r"D:\project\AI-PYTHON\Ai_tottle\engine\renche\renche.engine",
                            'so_path': r"D:\project\AI-PYTHON\Ai_tottle\engine\renche\myplugins.dll",
                            'cls_map': config.cls_zn_to_eh_dict,
                            'allowed_classes': config.allowed_classes,
                            "cls_index": config.class_indices,
                            "class_names": config.cls_names,
                            "chinese_label": config.cls_en_dict,
                            "list_func_id": row_func_id,
                            "func_id": func_id,
                            "para_invade_enable": para_invade_enable,
                            "config_conf": config.conf
                        }
                    )
                else:
                    print(f"未找到ID为 {func_id} 的模型配置")
            para_invade_enable = para.get("para_invade_enable", False)
            if para_invade_enable:
                invade_enable = True
        # Pull the stream/task parameters out of the validated request body.
        video_url = request_json.content_body.source_url
        sn = request_json.sn
        task_id = request_json.task_id
        org_code = request_json.content_body.org_code
        push_url = request_json.content_body.push_url
        invade = request_json.content_body.invade
        invade_file = invade["invade_file"]
        camera_para_url = invade["camera_para_url"]
        if high_count_warn is None:
            high_count_warn = 0
        invade_switch = invade.get("invade_switch", 0)
        # Persist the raw request for auditing; the socket repr is appended to
        # try to capture the caller's public IP — untested on public networks.
        str_request = str(request) + "&" + str(request.socket)
        dao.insert_request_log(task_id, sn, org_code, str(request.body), str_request)
        mqtt_pub_config = dao.get_mqtt_config_by_orgcode(org_code, sn, "pub")
        mqtt_sub_config = dao.get_mqtt_config_by_orgcode(org_code, sn, "sub")
        mqtt_pub_ip = mqtt_pub_config.mqtt_ip
        mqtt_pub_port = mqtt_pub_config.mqtt_port
        mqtt_pub_topic = mqtt_pub_config.mqtt_topic
        print(f"mqtt_pub_topic {mqtt_pub_topic}")
        mqtt_sub_ip = mqtt_sub_config.mqtt_ip
        mqtt_sub_port = mqtt_sub_config.mqtt_port
        # Subscribe topic is templated with the device serial number.
        mqtt_sub_topic = mqtt_sub_config.mqtt_topic.format(sn=sn)
        print(f"mqtt_sub_topic {mqtt_sub_topic}")
        # Kept for its side effects — TODO confirm still needed; value unused.
        local_func_config = read_local_func_config()
        device = dao.get_device(sn, org_code)
        print(f"device表 {sn} {org_code}")
        device_height = device.height  # dock altitude, used for elevation math

        async def process_flow():
            # Wrapper so asyncio.shield below sees a single awaitable.
            try:
                await start_rtmp_processing(
                    video_url,
                    request_json.task_id,
                    model_configs,
                    mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic,
                    mqtt_sub_ip, mqtt_sub_port, mqtt_sub_topic,
                    push_url,
                    invade_enable, invade_switch, invade_file, camera_para_url,
                    device_height, repeat_dis, repeat_time, high_count_warn
                )
            except Exception as e:
                logger.error(f"处理流程异常: {e}")
                raise

        # Shield the flow so an outer cancellation does not kill it mid-frame.
        await asyncio.shield(process_flow())
        # Register the task so it can be inspected / stopped via task_manager.
        task_info = {
            "source_url": video_url,
            "push_url": push_url,
            "status": "running",
            "task_handle": task_handle,
            "model_configs": model_configs,
            "device_height": device_height,
            "repeat_dis": repeat_dis,
            "repeat_time": repeat_time
        }
        await task_manager.add_task(
            task_id,
            task_info,
            task_handle,
            []  # no sub-tasks yet
        )
        # BUG FIX: task_handle is None unless the disabled fallback path set
        # it, and awaiting None raises TypeError. Guard before awaiting.
        if task_handle is not None:
            try:
                await task_handle
            except asyncio.CancelledError:
                pass  # cancellation is a normal way for the task to end
        await task_manager.remove_task(task_id)
    except asyncio.CancelledError:
        logger.info(f"任务 {request_json.task_id} 收到停止信号,正在清理...")
        raise
    except Exception as e:
        logger.error(f"任务 {request_json.task_id} 处理失败: {e}")
        raise

View File

@ -0,0 +1,58 @@
import time
import logging
from sanic_cors import CORS
from sanic import Sanic, Request, json
from multi_back_detect_api import multi_back_detect_bp

# Logging setup: INFO level, module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Create the Sanic application and register the detection blueprint.
app = Sanic("multiBackDetectAPI")
app.blueprint(multi_back_detect_bp)
CORS(app, automatic_options=True)

# Middleware pair: stamp each request with its arrival time, then report
# the elapsed time on the response.
@app.middleware("request")
async def add_start_time(request: Request):
    """Record the request start time on the request context."""
    request.ctx.start_time = time.time()

@app.middleware("response")
async def add_response_time(request: Request, response):
    """Attach an ``X-Process-Time`` header (milliseconds) to every response."""
    if hasattr(request.ctx, "start_time"):
        process_time = (time.time() - request.ctx.start_time) * 1000
        response.headers["X-Process-Time"] = f"{process_time:.2f}ms"

@app.get("/api/v1/health")
async def health_check(request: Request):
    """Liveness probe: return static service metadata and a timestamp."""
    # NOTE(review): "service" reports "terrain-analysis-api" but this app is
    # named "multiBackDetectAPI" — looks copy-pasted; confirm which name the
    # monitoring side expects before changing.
    return json({
        "status": "healthy",
        "timestamp": time.time(),
        "service": "terrain-analysis-api",
        "version": "1.0.0"
    })

@app.exception(Exception)
async def handle_exception(request: Request, exception):
    """Global exception handler: log the error and return a generic 500."""
    logger.error(f"未处理的异常: {exception}")
    return json({
        "error": "服务器内部错误",
        "message": str(exception) if app.debug else "请稍后重试",
        "success": False
    }, status=500)

if __name__ == "__main__":
    # Start the server.
    app.run(
        host="0.0.0.0",
        port=12320,
        debug=False,  # keep False in production
        access_log=True,
        auto_reload=True  # NOTE(review): auto-reload in production is unusual — confirm intent
    )

View File

@ -165,13 +165,15 @@ def pic_detect_func(mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic,
try: try:
frame_copy = frame.copy() frame_copy = frame.copy()
results = counter(frame) results = counter(frame)
func_id=model_func_id_list[0]
annotated_frame, box_result = cal_tricker_results(frame_copy, counter, class_names, annotated_frame, box_result = cal_tricker_results(frame_copy, counter, class_names,
model_func_id_list, func_id,
local_func_cache, para, cls, chinese_label, local_func_cache, para, cls, chinese_label,
model_func_id_list[0]) model_func_id_list[0])
except Exception as e: except Exception as e:
print(f"处理帧错误: {e}") print(f"处理帧错误1: {e}")
error_count += 1 error_count += 1
if error_count >= 5: if error_count >= 5:
print(f"连续处理错误达到5次 ,正在停止处理...") print(f"连续处理错误达到5次 ,正在停止处理...")

BIN
pt/build-wall.pt Normal file

Binary file not shown.

2
readme
View File

@ -1,4 +1,4 @@
conda create -n aienv2 python=3.10 conda create -n test_sahi python=3.10
激活环境 激活环境
conda create -n yolo_trt_10 python=3.10 -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ --override-channels conda create -n yolo_trt_10 python=3.10 -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ --override-channels

153
sam3/.gitignore vendored Normal file
View File

@ -0,0 +1,153 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
*-Copy*.ipynb
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# PyCharm
.idea/
# VS Code
.vscode/
*.code-workspace
# Model weights and checkpoints
*.pth
*.pt
*.bin
*.ckpt
*.safetensors
weights/
checkpoints/
sam3_logs/
# Data files
*.h5
*.hdf5
*.pkl
*.pickle
*.npy
*.npz
# Logs
logs/
runs/
tensorboard/
# OS specific
.DS_Store
Thumbs.db
# BPE vocabulary files
*.bpe
*.vocab

80
sam3/CODE_OF_CONDUCT.md Normal file
View File

@ -0,0 +1,80 @@
# Code of Conduct
## Our Pledge
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.
## Our Standards
Examples of behavior that contributes to creating a positive environment
include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.
## Scope
This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.
This Code of Conduct also applies outside the project spaces when there is a
reasonable belief that an individual's behavior may have a negative impact on
the project or its community.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <opensource-conduct@meta.com>. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq

30
sam3/CONTRIBUTING.md Normal file
View File

@ -0,0 +1,30 @@
# Contributing to sam3
We want to make contributing to this project as easy and transparent as
possible.
## Pull Requests
We actively welcome your pull requests.
1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Make sure your code lints.
5. If you haven't already, complete the Contributor License Agreement ("CLA").
## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.
Complete your CLA here: <https://code.facebook.com/cla>
## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.
Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.
## License
By contributing to sam3, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.

61
sam3/LICENSE Normal file
View File

@ -0,0 +1,61 @@
SAM License
Last Updated: November 19, 2025
“Agreement” means the terms and conditions for use, reproduction, distribution and modification of the SAM Materials set forth herein.
“SAM Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, and other elements of the foregoing distributed by Meta and made available under this Agreement.
“Documentation” means the specifications, manuals and documentation accompanying
SAM Materials distributed by Meta.
“Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entitys behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
“Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) or Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland).
“Sanctions” means any economic or trade sanctions or restrictions administered or enforced by the United States (including the Office of Foreign Assets Control of the U.S. Department of the Treasury (“OFAC”), the U.S. Department of State and the U.S. Department of Commerce), the United Nations, the European Union, or the United Kingdom.
“Trade Controls” means any of the following: Sanctions and applicable export and import controls.
By using or distributing any portion or element of the SAM Materials, you agree to be bound by this Agreement.
1. License Rights and Redistribution.
a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Metas intellectual property or other rights owned by Meta embodied in the SAM Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the SAM Materials.
b. Redistribution and Use.
i. Distribution of SAM Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the SAM Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement and you shall provide a copy of this Agreement with any such SAM Materials.
ii. If you submit for publication the results of research you perform on, using, or otherwise in connection with SAM Materials, you must acknowledge the use of SAM Materials in your publication.
iii. Your use of the SAM Materials must comply with applicable laws and regulations, including Trade Control Laws and applicable privacy and data protection laws.
iv. Your use of the SAM Materials will not involve or encourage others to reverse engineer, decompile or discover the underlying components of the SAM Materials.
v. You are not the target of Trade Controls and your use of SAM Materials must comply with Trade Controls. You agree not to use, or permit others to use, SAM Materials for any activities subject to the International Traffic in Arms Regulations (ITAR) or end uses prohibited by Trade Controls, including those related to military or warfare purposes, nuclear industries or applications, espionage, or the development or use of guns or illegal weapons.
2. User Support. Your use of the SAM Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the SAM Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SAM MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SAM MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SAM MATERIALS AND ANY OUTPUT AND RESULTS.
4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
5. Intellectual Property.
a. Subject to Metas ownership of SAM Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the SAM Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the SAM Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the SAM Materials.
6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the SAM Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the SAM Materials. Sections 3, 4 and 7 shall survive the termination of this Agreement.
7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
8. Modifications and Amendments. Meta may modify this Agreement from time to time; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the SAM Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.

6
sam3/MANIFEST.in Normal file
View File

@ -0,0 +1,6 @@
include LICENSE
include README.md
recursive-include examples *.py
recursive-include examples *.ipynb
recursive-include examples *.md
recursive-include tests *.py

395
sam3/README.md Normal file
View File

@ -0,0 +1,395 @@
# SAM 3: Segment Anything with Concepts
Meta Superintelligence Labs
[Nicolas Carion](https://www.nicolascarion.com/)\*,
[Laura Gustafson](https://scholar.google.com/citations?user=c8IpF9gAAAAJ&hl=en)\*,
[Yuan-Ting Hu](https://scholar.google.com/citations?user=E8DVVYQAAAAJ&hl=en)\*,
[Shoubhik Debnath](https://scholar.google.com/citations?user=fb6FOfsAAAAJ&hl=en)\*,
[Ronghang Hu](https://ronghanghu.com/)\*,
[Didac Suris](https://www.didacsuris.com/)\*,
[Chaitanya Ryali](https://scholar.google.com/citations?user=4LWx24UAAAAJ&hl=en)\*,
[Kalyan Vasudev Alwala](https://scholar.google.co.in/citations?user=m34oaWEAAAAJ&hl=en)\*,
[Haitham Khedr](https://hkhedr.com/)\*, Andrew Huang,
[Jie Lei](https://jayleicn.github.io/),
[Tengyu Ma](https://scholar.google.com/citations?user=VeTSl0wAAAAJ&hl=en),
[Baishan Guo](https://scholar.google.com/citations?user=BC5wDu8AAAAJ&hl=en),
Arpit Kalla, [Markus Marks](https://damaggu.github.io/),
[Joseph Greer](https://scholar.google.com/citations?user=guL96CkAAAAJ&hl=en),
Meng Wang, [Peize Sun](https://peizesun.github.io/),
[Roman Rädle](https://scholar.google.com/citations?user=Tpt57v0AAAAJ&hl=en),
[Triantafyllos Afouras](https://www.robots.ox.ac.uk/~afourast/),
[Effrosyni Mavroudi](https://scholar.google.com/citations?user=vYRzGGEAAAAJ&hl=en),
[Katherine Xu](https://k8xu.github.io/)°,
[Tsung-Han Wu](https://patrickthwu.com/)°,
[Yu Zhou](https://yu-bryan-zhou.github.io/)°,
[Liliane Momeni](https://scholar.google.com/citations?user=Lb-KgVYAAAAJ&hl=en)°,
[Rishi Hazra](https://rishihazra.github.io/)°,
[Shuangrui Ding](https://mark12ding.github.io/)°,
[Sagar Vaze](https://sgvaze.github.io/)°,
[Francois Porcher](https://scholar.google.com/citations?user=LgHZ8hUAAAAJ&hl=en)°,
[Feng Li](https://fengli-ust.github.io/)°,
[Siyuan Li](https://siyuanliii.github.io/)°,
[Aishwarya Kamath](https://ashkamath.github.io/)°,
[Ho Kei Cheng](https://hkchengrex.com/)°,
[Piotr Dollar](https://pdollar.github.io/)†,
[Nikhila Ravi](https://nikhilaravi.com/)†,
[Kate Saenko](https://ai.bu.edu/ksaenko.html)†,
[Pengchuan Zhang](https://pzzhang.github.io/pzzhang/)†,
[Christoph Feichtenhofer](https://feichtenhofer.github.io/)†
\* core contributor, ° intern, † project lead, order is random within groups
[[`Paper`](https://ai.meta.com/research/publications/sam-3-segment-anything-with-concepts/)]
[[`Project`](https://ai.meta.com/sam3)]
[[`Demo`](https://segment-anything.com/)]
[[`Blog`](https://ai.meta.com/blog/segment-anything-model-3/)]
[[`BibTeX`](#citing-sam-3)]
![SAM 3 architecture](assets/model_diagram.png?raw=true) SAM 3 is a unified foundation model for promptable segmentation in images and videos. It can detect, segment, and track objects using text or visual prompts such as points, boxes, and masks. Compared to its predecessor [SAM 2](https://github.com/facebookresearch/sam2), SAM 3 introduces the ability to exhaustively segment all instances of an open-vocabulary concept specified by a short text phrase or exemplars. Unlike prior work, SAM 3 can handle a vastly larger set of open-vocabulary prompts. It achieves 75-80% of human performance on our new [SA-CO benchmark](https://github.com/facebookresearch/sam3?tab=readme-ov-file#sa-co-dataset) which contains 270K unique concepts, over 50 times more than existing benchmarks.
This breakthrough is driven by an innovative data engine that has automatically annotated over 4 million unique concepts, creating the largest high-quality open-vocabulary segmentation dataset to date. In addition, SAM 3 introduces a new model architecture featuring a presence token that improves discrimination between closely related text prompts (e.g., “a player in white” vs. “a player in red”), as well as a decoupled detectortracker design that minimizes task interference and scales efficiently with data.
<p align="center">
<img src="assets/dog.gif" width=380 />
<img src="assets/player.gif" width=380 />
</p>
## Installation
### Prerequisites
- Python 3.12 or higher
- PyTorch 2.7 or higher
- CUDA-compatible GPU with CUDA 12.6 or higher
1. **Create a new Conda environment:**
```bash
conda create -n sam3 python=3.12
conda deactivate
conda activate sam3
```
2. **Install PyTorch with CUDA support:**
```bash
pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
```
3. **Clone the repository and install the package:**
```bash
git clone https://github.com/facebookresearch/sam3.git
cd sam3
pip install -e .
```
4. **Install additional dependencies for example notebooks or development:**
```bash
# For running example notebooks
pip install -e ".[notebooks]"
# For development
pip install -e ".[train,dev]"
```
## Getting Started
⚠️ Before using SAM 3, please request access to the checkpoints on the SAM 3
Hugging Face [repo](https://huggingface.co/facebook/sam3). Once accepted, you
need to be authenticated to download the checkpoints. You can do this by running
the following [steps](https://huggingface.co/docs/huggingface_hub/en/quick-start#authentication)
(e.g. `hf auth login` after generating an access token.)
### Basic Usage
```python
import torch
#################################### For Image ####################################
from PIL import Image
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
# Load the model
model = build_sam3_image_model()
processor = Sam3Processor(model)
# Load an image
image = Image.open("<YOUR_IMAGE_PATH.jpg>")
inference_state = processor.set_image(image)
# Prompt the model with text
output = processor.set_text_prompt(state=inference_state, prompt="<YOUR_TEXT_PROMPT>")
# Get the masks, bounding boxes, and scores
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
#################################### For Video ####################################
from sam3.model_builder import build_sam3_video_predictor
video_predictor = build_sam3_video_predictor()
video_path = "<YOUR_VIDEO_PATH>" # a JPEG folder or an MP4 video file
# Start a session
response = video_predictor.handle_request(
request=dict(
type="start_session",
resource_path=video_path,
)
)
response = video_predictor.handle_request(
request=dict(
type="add_prompt",
session_id=response["session_id"],
frame_index=0, # Arbitrary frame index
text="<YOUR_TEXT_PROMPT>",
)
)
output = response["outputs"]
```
## Examples
The `examples` directory contains notebooks demonstrating how to use SAM3 with
various types of prompts:
- [`sam3_image_predictor_example.ipynb`](examples/sam3_image_predictor_example.ipynb)
: Demonstrates how to prompt SAM 3 with text and visual box prompts on images.
- [`sam3_video_predictor_example.ipynb`](examples/sam3_video_predictor_example.ipynb)
: Demonstrates how to prompt SAM 3 with text prompts on videos, and doing
further interactive refinements with points.
- [`sam3_image_batched_inference.ipynb`](examples/sam3_image_batched_inference.ipynb)
: Demonstrates how to run batched inference with SAM 3 on images.
- [`sam3_agent.ipynb`](examples/sam3_agent.ipynb): Demonstrates the use of SAM
3 Agent to segment complex text prompt on images.
- [`saco_gold_silver_vis_example.ipynb`](examples/saco_gold_silver_vis_example.ipynb)
: Shows a few examples from SA-Co image evaluation set.
- [`saco_veval_vis_example.ipynb`](examples/saco_veval_vis_example.ipynb) :
Shows a few examples from SA-Co video evaluation set.
There are additional notebooks in the examples directory that demonstrate how to
use SAM 3 for interactive instance segmentation in images and videos (SAM 1/2
tasks), or as a tool for an MLLM, and how to run evaluations on the SA-Co
dataset.
To run the Jupyter notebook examples:
```bash
# Make sure you have the notebooks dependencies installed
pip install -e ".[notebooks]"
# Start Jupyter notebook
jupyter notebook examples/sam3_image_predictor_example.ipynb
```
## Model
SAM 3 consists of a detector and a tracker that share a vision encoder. It has 848M parameters. The
detector is a DETR-based model conditioned on text, geometry, and image
exemplars. The tracker inherits the SAM 2 transformer encoder-decoder
architecture, supporting video segmentation and interactive refinement.
## Image Results
<div align="center">
<table style="min-width: 80%; border: 2px solid #ddd; border-collapse: collapse">
<thead>
<tr>
<th rowspan="3" style="border-right: 2px solid #ddd; padding: 12px 20px">Model</th>
<th colspan="3" style="text-align: center; border-right: 2px solid #ddd; padding: 12px 20px">Instance Segmentation</th>
<th colspan="5" style="text-align: center; padding: 12px 20px">Box Detection</th>
</tr>
<tr>
<th colspan="2" style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">LVIS</th>
<th style="text-align: center; border-right: 2px solid #ddd; padding: 12px 20px">SA-Co/Gold</th>
<th colspan="2" style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">LVIS</th>
<th colspan="2" style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">COCO</th>
<th style="text-align: center; padding: 12px 20px">SA-Co/Gold</th>
</tr>
<tr>
<th style="text-align: center; padding: 12px 20px">cgF1</th>
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">AP</th>
<th style="text-align: center; border-right: 2px solid #ddd; padding: 12px 20px">cgF1</th>
<th style="text-align: center; padding: 12px 20px">cgF1</th>
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">AP</th>
<th style="text-align: center; padding: 12px 20px">AP</th>
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">AP<sub>o</sub>
</th>
<th style="text-align: center; padding: 12px 20px">cgF1</th>
</tr>
</thead>
<tbody>
<tr>
<td style="border-right: 2px solid #ddd; padding: 10px 20px">Human</td>
<td style="text-align: center; padding: 10px 20px">-</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
<td style="text-align: center; border-right: 2px solid #ddd; padding: 10px 20px">72.8</td>
<td style="text-align: center; padding: 10px 20px">-</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
<td style="text-align: center; padding: 10px 20px">-</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
<td style="text-align: center; padding: 10px 20px">74.0</td>
</tr>
<tr>
<td style="border-right: 2px solid #ddd; padding: 10px 20px">OWLv2*</td>
<td style="text-align: center; padding: 10px 20px; color: #999">29.3</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px; color: #999">43.4</td>
<td style="text-align: center; border-right: 2px solid #ddd; padding: 10px 20px">24.6</td>
<td style="text-align: center; padding: 10px 20px; color: #999">30.2</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px; color: #999">45.5</td>
<td style="text-align: center; padding: 10px 20px">46.1</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">23.9</td>
<td style="text-align: center; padding: 10px 20px">24.5</td>
</tr>
<tr>
<td style="border-right: 2px solid #ddd; padding: 10px 20px">DINO-X</td>
<td style="text-align: center; padding: 10px 20px">-</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">38.5</td>
<td style="text-align: center; border-right: 2px solid #ddd; padding: 10px 20px">21.3</td>
<td style="text-align: center; padding: 10px 20px">-</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">52.4</td>
<td style="text-align: center; padding: 10px 20px">56.0</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
<td style="text-align: center; padding: 10px 20px">22.5</td>
</tr>
<tr>
<td style="border-right: 2px solid #ddd; padding: 10px 20px">Gemini 2.5</td>
<td style="text-align: center; padding: 10px 20px">13.4</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
<td style="text-align: center; border-right: 2px solid #ddd; padding: 10px 20px">13.0</td>
<td style="text-align: center; padding: 10px 20px">16.1</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
<td style="text-align: center; padding: 10px 20px">-</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
<td style="text-align: center; padding: 10px 20px">14.4</td>
</tr>
<tr style="border-top: 2px solid #b19c9cff">
<td style="border-right: 2px solid #ddd; padding: 10px 20px">SAM 3</td>
<td style="text-align: center; padding: 10px 20px">37.2</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">48.5</td>
<td style="text-align: center; border-right: 2px solid #ddd; padding: 10px 20px">54.1</td>
<td style="text-align: center; padding: 10px 20px">40.6</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">53.6</td>
<td style="text-align: center; padding: 10px 20px">56.4</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">55.7</td>
<td style="text-align: center; padding: 10px 20px">55.7</td>
</tr>
</tbody>
</table>
<p style="text-align: center; margin-top: 10px; font-size: 0.9em; color: #ddd;">* Partially trained on LVIS, AP<sub>o</sub> refers to COCO-O accuracy</p>
</div>
## Video Results
<div align="center">
<table style="min-width: 80%; border: 2px solid #ddd; border-collapse: collapse">
<thead>
<tr>
<th rowspan="2" style="border-right: 2px solid #ddd; padding: 12px 20px">Model</th>
<th colspan="2" style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">SA-V test</th>
<th colspan="2" style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">YT-Temporal-1B test</th>
<th colspan="2" style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">SmartGlasses test</th>
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">LVVIS test</th>
<th style="text-align: center; padding: 12px 20px">BURST test</th>
</tr>
<tr>
<th style="text-align: center; padding: 12px 20px">cgF1</th>
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">pHOTA</th>
<th style="text-align: center; padding: 12px 20px">cgF1</th>
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">pHOTA</th>
<th style="text-align: center; padding: 12px 20px">cgF1</th>
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">pHOTA</th>
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">mAP</th>
<th style="text-align: center; padding: 12px 20px">HOTA</th>
</tr>
</thead>
<tbody>
<tr>
<td style="border-right: 2px solid #ddd; padding: 10px 20px">Human</td>
<td style="text-align: center; padding: 10px 20px">53.1</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">70.5</td>
<td style="text-align: center; padding: 10px 20px">71.2</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">78.4</td>
<td style="text-align: center; padding: 10px 20px">58.5</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">72.3</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
<td style="text-align: center; padding: 10px 20px">-</td>
</tr>
<tr style="border-top: 2px solid #b19c9cff">
<td style="border-right: 2px solid #ddd; padding: 10px 20px">SAM 3</td>
<td style="text-align: center; padding: 10px 20px">30.3</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">58.0</td>
<td style="text-align: center; padding: 10px 20px">50.8</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">69.9</td>
<td style="text-align: center; padding: 10px 20px">36.4</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">63.6</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">36.3</td>
<td style="text-align: center; padding: 10px 20px">44.5</td>
</tr>
</tbody>
</table>
</div>
## SA-Co Dataset
We release 2 image benchmarks, [SA-Co/Gold](scripts/eval/gold/README.md) and
[SA-Co/Silver](scripts/eval/silver/README.md), and a video benchmark
[SA-Co/VEval](scripts/eval/veval/README.md). The datasets contain images (or videos) with annotated noun phrases. Each image/video and noun phrase pair is annotated with instance masks and unique IDs of each object matching the phrase. Phrases that have no matching objects (negative prompts) have no masks, shown in red font in the figure. See the linked READMEs for more details on how to download and run evaluations on the datasets.
* HuggingFace host: [SA-Co/Gold](https://huggingface.co/datasets/facebook/SACo-Gold), [SA-Co/Silver](https://huggingface.co/datasets/facebook/SACo-Silver) and [SA-Co/VEval](https://huggingface.co/datasets/facebook/SACo-VEval)
* Roboflow host: [SA-Co/Gold](https://universe.roboflow.com/sa-co-gold), [SA-Co/Silver](https://universe.roboflow.com/sa-co-silver) and [SA-Co/VEval](https://universe.roboflow.com/sa-co-veval)
![SA-Co dataset](assets/sa_co_dataset.jpg?raw=true)
## Development
To set up the development environment:
```bash
pip install -e ".[dev,train]"
```
To format the code:
```bash
ufmt format .
```
## Contributing
See [contributing](CONTRIBUTING.md) and the
[code of conduct](CODE_OF_CONDUCT.md).
## License
This project is licensed under the SAM License - see the [LICENSE](LICENSE) file
for details.
## Acknowledgements
We would like to thank the following people for their contributions to the SAM 3 project: Alex He, Alexander Kirillov,
Alyssa Newcomb, Ana Paula Kirschner Mofarrej, Andrea Madotto, Andrew Westbury, Ashley Gabriel, Azita Shokpour,
Ben Samples, Bernie Huang, Carleigh Wood, Ching-Feng Yeh, Christian Puhrsch, Claudette Ward, Daniel Bolya,
Daniel Li, Facundo Figueroa, Fazila Vhora, George Orlin, Hanzi Mao, Helen Klein, Hu Xu, Ida Cheng, Jake Kinney,
Jiale Zhi, Jo Sampaio, Joel Schlosser, Justin Johnson, Kai Brown, Karen Bergan, Karla Martucci, Kenny Lehmann,
Maddie Mintz, Mallika Malhotra, Matt Ward, Michelle Chan, Michelle Restrepo, Miranda Hartley, Muhammad Maaz,
Nisha Deo, Peter Park, Phillip Thomas, Raghu Nayani, Rene Martinez Doehner, Robbie Adkins, Ross Girshick, Sasha
Mitts, Shashank Jain, Spencer Whitehead, Ty Toledano, Valentin Gabeur, Vincent Cho, Vivian Lee, William Ngan,
Xuehai He, Yael Yungster, Ziqi Pang, Ziyi Dou, Zoe Quake.
## Citing SAM 3
If you use SAM 3 or the SA-Co dataset in your research, please use the following BibTeX entry.
```bibtex
@misc{carion2025sam3segmentconcepts,
title={SAM 3: Segment Anything with Concepts},
author={Nicolas Carion and Laura Gustafson and Yuan-Ting Hu and Shoubhik Debnath and Ronghang Hu and Didac Suris and Chaitanya Ryali and Kalyan Vasudev Alwala and Haitham Khedr and Andrew Huang and Jie Lei and Tengyu Ma and Baishan Guo and Arpit Kalla and Markus Marks and Joseph Greer and Meng Wang and Peize Sun and Roman Rädle and Triantafyllos Afouras and Effrosyni Mavroudi and Katherine Xu and Tsung-Han Wu and Yu Zhou and Liliane Momeni and Rishi Hazra and Shuangrui Ding and Sagar Vaze and Francois Porcher and Feng Li and Siyuan Li and Aishwarya Kamath and Ho Kei Cheng and Piotr Dollár and Nikhila Ravi and Kate Saenko and Pengchuan Zhang and Christoph Feichtenhofer},
year={2025},
eprint={2511.16719},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2511.16719},
}
```

190
sam3/README_TRAIN.md Normal file
View File

@ -0,0 +1,190 @@
# Training
This repository supports finetuning SAM3 models on custom datasets in multi-node setup or local execution. The training script is located at `sam3/train.py` and uses Hydra configuration management to handle complex training setups.
## Installation
```bash
cd sam3
pip install -e ".[train]"
```
### Training Script Usage
The main training script is located at `sam3/train.py`. It uses Hydra configuration management to handle complex training setups.
#### Basic Usage
```bash
# Example: Train on Roboflow dataset
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml
# Example: Train on ODinW13 dataset
python sam3/train/train.py -c configs/odinw13/odinw_text_only_train.yaml
```
Follow [`Roboflow 100-VL`](https://github.com/roboflow/rf100-vl/) to download the roboflow 100-vl datasets. Follow [`GLIP`](https://github.com/microsoft/GLIP) to download the ODinW datasets. The data folder should be organized as follows, and put your roboflow_vl_100_root and odinw_data_root in the job configs.
```
roboflow_vl_100_root:
13-lkc01
train
valid
test
2024-frc
actions
...
odinw_data_root:
AerialMaritimeDrone
large
train
valid
test
Aquarium
...
```
#### Command Line Arguments
The training script supports several command line arguments:
```bash
python sam3/train/train.py \
-c CONFIG_NAME \
[--use-cluster 0|1] \
[--partition PARTITION_NAME] \
[--account ACCOUNT_NAME] \
[--qos QOS_NAME] \
[--num-gpus NUM_GPUS] \
[--num-nodes NUM_NODES]
```
**Arguments:**
- `-c, --config`: **Required.** Path to the configuration file (e.g., `sam3/train/configs/roboflow_v100_full_ft_100_images.yaml`)
- `--use-cluster`: Whether to launch on a cluster (0: local, 1: cluster). Default: uses config setting
- `--partition`: SLURM partition name for cluster execution
- `--account`: SLURM account name for cluster execution
- `--qos`: SLURM QOS (Quality of Service) setting
- `--num-gpus`: Number of GPUs per node. Default: uses config setting
- `--num-nodes`: Number of nodes for distributed training. Default: uses config setting
#### Local Training Examples
```bash
# Single GPU training
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml --use-cluster 0 --num-gpus 1
# Multi-GPU training on a single node
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml --use-cluster 0 --num-gpus 4
# Force local execution even if config specifies GPUs
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml --use-cluster 0
```
#### Cluster Training Examples
```bash
# Basic cluster training with default settings from config
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml --use-cluster 1
# Cluster training with specific SLURM settings
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml \
--use-cluster 1 \
--partition gpu_partition \
--account my_account \
--qos high_priority \
--num-gpus 8 \
--num-nodes 2
```
### Configuration Files
Training configurations are stored in `sam3/train/configs/`. The configuration files use Hydra's YAML format and support:
- **Dataset Configuration**: Data paths, transforms, and loading parameters
- **Model Configuration**: Architecture settings, checkpoint paths, and model parameters
- **Training Configuration**: Batch sizes, learning rates, optimization settings
- **Launcher Configuration**: Distributed training and cluster settings
- **Logging Configuration**: TensorBoard, experiment tracking, and output directories
#### Key Configuration Sections
```yaml
# Paths to datasets and checkpoints
paths:
bpe_path: /path/to/bpe/file
dataset_root: /path/to/dataset
experiment_log_dir: /path/to/logs
# Launcher settings for local/cluster execution
launcher:
num_nodes: 1
gpus_per_node: 2
experiment_log_dir: ${paths.experiment_log_dir}
# Cluster execution settings
submitit:
use_cluster: True
timeout_hour: 72
cpus_per_task: 10
partition: null
account: null
```
### Monitoring Training
The training script automatically sets up logging and saves outputs to the experiment directory:
```bash
# Logs are saved to the experiment_log_dir specified in config
experiment_log_dir/
├── config.yaml # Original configuration
├── config_resolved.yaml # Resolved configuration with all variables expanded
├── checkpoints/ # Model checkpoints (if skip_checkpointing=False)
├── tensorboard/ # TensorBoard logs
├── logs/ # Text logs
└── submitit_logs/ # Cluster job logs (if using cluster)
```
You can monitor training progress using TensorBoard:
```bash
tensorboard --logdir /path/to/experiment_log_dir/tensorboard
```
### Job Arrays for Dataset Sweeps
The Roboflow and ODinW configuration supports job arrays for training multiple models on different datasets:
This feature is specifically enabled via,
```yaml
submitit:
job_array:
num_tasks: 100
task_index: 0
```
The configuration includes a complete list of 100 Roboflow supercategories, and the `submitit.job_array.task_index` automatically selects which dataset to use based on the array job index.
```bash
# Submit job array to train on different Roboflow datasets
# The job array index selects which dataset from all_roboflow_supercategories
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml \
--use-cluster 1
```
### Reproduce ODinW13 10-shot results
Running the following job will give the results on the ODinW13 seed 300, see `odinw_train.train_file: fewshot_train_shot10_seed300` in the config file.
```bash
# Example: Train on ODinW13 dataset
python sam3/train/train.py -c configs/odinw13/odinw_text_only_train.yaml
```
Change `odinw_train.train_file` to `fewshot_train_shot10_seed30` and `fewshot_train_shot10_seed3` to get the results for the other two seeds. Final results are aggregated from the three seeds. Notice that a small number of jobs may diverge during training, in which case we just use the last checkpoint's result before it diverges.
### Eval Script Usage
With a similar setup as the training config, the training script `sam3/train.py` can also be used for evaluation by setting `trainer.mode = val` in the job config. Running the following job will give the zero-shot results on the RF100-VL and ODinW13 datasets.
```bash
# Example: Evaluate on Roboflow dataset
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_eval.yaml
# Example: Evaluate on ODinW13 dataset
python sam3/train/train.py -c configs/odinw13/odinw_text_only.yaml
```

BIN
sam3/assets/dog.gif Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.8 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 164 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 265 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 707 KiB

BIN
sam3/assets/player.gif Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.2 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 991 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.8 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 141 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 138 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 134 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 112 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 114 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 111 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 111 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 110 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 112 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 110 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 112 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 111 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 114 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 136 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 113 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 113 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 112 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 113 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 111 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 110 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 110 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 109 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 107 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 105 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 134 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 106 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 104 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 104 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 106 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 108 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 107 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 109 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 105 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 104 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 102 KiB

Some files were not shown because too many files have changed in this diff Show More