Compare commits
28 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3a7bdbc3a4 | ||
| 795a028d0e | |||
| fc534a096e | |||
| 9720a07683 | |||
| 0b283f6b8c | |||
| be99472837 | |||
| fad554632e | |||
| c17df2e460 | |||
| 63f240ac3a | |||
|
|
dd931f6231 | ||
| 146872a4dd | |||
| 89181007c2 | |||
|
|
9a09c1e1cf | ||
| c5eeb87488 | |||
| ee8733a0ce | |||
| 0ce543572b | |||
| 929c670add | |||
| 1656f81fe3 | |||
| dfb89c70a3 | |||
| a2d3e2e24b | |||
|
|
0f44df8cec | ||
|
|
eb6ce0de46 | ||
|
|
8d4db9b6df | ||
| eedca6cd50 | |||
|
|
5c865a4418 | ||
| 0e952115c8 | |||
| fbcc505a88 | |||
| b899c4e9de |
47
.gitignore
vendored
Normal file
@ -0,0 +1,47 @@
|
||||
# 忽略所有 .log 文件
|
||||
*.log
|
||||
|
||||
# 忽略特定目录(如 node_modules/)
|
||||
node_modules/
|
||||
|
||||
# 忽略本地配置文件(但保留示例文件)
|
||||
config.local.json
|
||||
!config.example.json
|
||||
|
||||
|
||||
# 忽略编译输出目录
|
||||
dist/
|
||||
build/
|
||||
|
||||
test
|
||||
*test*
|
||||
|
||||
*.pyc
|
||||
__pycache__/
|
||||
|
||||
|
||||
misc.xml
|
||||
profiles_settings.xml
|
||||
Project_Default.xml
|
||||
*test*.py
|
||||
|
||||
|
||||
# 忽略 Python 缓存文件
|
||||
*.pyo
|
||||
*.pyd
|
||||
*.mo
|
||||
*.so
|
||||
|
||||
|
||||
uvmodule.log
|
||||
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
*.iml
|
||||
.vscode/
|
||||
|
||||
|
||||
# Temp files
|
||||
*.tmp
|
||||
*.swp
|
||||
3
.idea/.gitignore
generated
vendored
@ -1,3 +0,0 @@
|
||||
# 默认忽略的文件
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
3
.idea/misc.xml
generated
@ -1,4 +1,7 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="yolo_tensorrt" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="yolo_tensorrt" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
100
ai_image.py
@ -1,100 +0,0 @@
|
||||
import os
|
||||
import shutil
|
||||
import cv2
|
||||
import collections
|
||||
from ultralytics import YOLO
|
||||
from miniohelp import downFile, upload_file, parse_minio_url # 确保你有这些工具函数
|
||||
from minio import Minio
|
||||
|
||||
def process_images(yolo_model, image_list, class_filter, input_folder, output_folder, minio_info):
|
||||
# 初始化 MinIO 客户端# 用配置字典初始化 Minio 客户端对象
|
||||
# 清洗 endpoint,去掉 http:// 或 https:// 前缀
|
||||
endpoint = minio_info["MinIOEndpoint"].replace("http://", "").replace("https://", "")
|
||||
|
||||
# 初始化 MinIO 客户端
|
||||
minio = Minio(
|
||||
endpoint=endpoint,
|
||||
access_key=minio_info["MinIOAccessKey"],
|
||||
secret_key=minio_info["MinIOSecretKey"],
|
||||
secure=False
|
||||
)
|
||||
os.makedirs(input_folder, exist_ok=True)
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
|
||||
model = YOLO(yolo_model)
|
||||
class_ids_filter = [int(cls) for cls in class_filter.split(",")] if class_filter else None
|
||||
output_image_list = []
|
||||
|
||||
for item in image_list:
|
||||
img_id = item["id"]
|
||||
img_url = item["path"]
|
||||
|
||||
# 解析 MinIO 地址
|
||||
if img_url.startswith("http"):
|
||||
bucket_name, img_path = parse_minio_url(img_url)
|
||||
else:
|
||||
bucket_name, img_path = "default-bucket", img_url
|
||||
|
||||
try:
|
||||
# 下载原图到本地
|
||||
local_input_path = os.path.join(input_folder, os.path.basename(img_path))
|
||||
downFile(minio, img_path, bucket_name, local_input_path)
|
||||
|
||||
# 读取图像
|
||||
image = cv2.imread(local_input_path)
|
||||
if image is None:
|
||||
raise ValueError(f"无法读取图像: {local_input_path}")
|
||||
|
||||
# YOLO 检测
|
||||
results = model.predict(image,
|
||||
classes=class_ids_filter,
|
||||
conf=0.5,
|
||||
iou = 0.111,
|
||||
show_labels = False,)
|
||||
result = results[0]
|
||||
|
||||
# 统计类别数
|
||||
class_counts = collections.Counter(result.boxes.cls.cpu().numpy().astype(int)) if result.boxes is not None else {}
|
||||
filtered_class_counts = {k: v for k, v in class_counts.items() if k in class_ids_filter}
|
||||
|
||||
# 转换所有的 numpy.int64 为 Python 的 int 类型
|
||||
detected_classes = [int(cls) for cls in filtered_class_counts.keys()]
|
||||
detected_numbers = [int(num) for num in filtered_class_counts.values()]
|
||||
aim = bool(detected_classes)
|
||||
|
||||
# 保存标注图像
|
||||
annotated_image = result.plot(labels=False)
|
||||
filename_no_ext, ext = os.path.splitext(os.path.basename(img_path))
|
||||
output_filename = f"{filename_no_ext}_ai{ext}"
|
||||
local_output_path = os.path.join(output_folder, output_filename)
|
||||
cv2.imwrite(local_output_path, annotated_image)
|
||||
|
||||
# 上传标注图像到 MinIO
|
||||
minio_path = upload_file(minio, local_output_path, bucket_name, os.path.dirname(img_path))
|
||||
|
||||
except Exception as e:
|
||||
print(f"[错误] 处理失败 - {img_path},错误: {str(e)}")
|
||||
detected_classes = []
|
||||
detected_numbers = []
|
||||
aim = False
|
||||
output_filename = ""
|
||||
minio_path = ""
|
||||
|
||||
output_image_list.append({
|
||||
"id": img_id,
|
||||
"minio_path":minio_path,
|
||||
"aim": aim,
|
||||
"class": detected_classes,
|
||||
"number": detected_numbers
|
||||
})
|
||||
|
||||
# 清理临时目录
|
||||
shutil.rmtree(input_folder, ignore_errors=True)
|
||||
shutil.rmtree(output_folder, ignore_errors=True)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "Detection completed",
|
||||
"data": output_image_list
|
||||
}
|
||||
|
||||
505
ai_tottle_api.py
@ -1,505 +0,0 @@
|
||||
from sanic import Sanic, json, Blueprint,response
|
||||
from sanic.exceptions import Unauthorized, NotFound
|
||||
from sanic.response import json as json_response
|
||||
from sanic_cors import CORS
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import uuid
|
||||
import os
|
||||
import asyncio
|
||||
from minio import Minio
|
||||
from ai_image import process_images # 你实现的图片处理函数
|
||||
from queue import Queue
|
||||
import gdal2tiles as gdal2tiles
|
||||
from map_find import map_process_images
|
||||
from yolo_train import auto_train
|
||||
from map_cut import process_tiling
|
||||
from cv_video_counter import start_video_session,switch_model_session,stop_video_session,stream_sessions
|
||||
import torch
|
||||
from yolo_photo import map_process_images_with_progress # 引入你的处理函数
|
||||
|
||||
# 日志配置
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
###################################################################################验证中间件和管理件##############################################################################################
|
||||
|
||||
async def token_and_resource_check(request):
|
||||
# --- Token 验证 ---
|
||||
token = request.headers.get('X-API-Token')
|
||||
expected_token = request.app.config.get("VALID_TOKEN")
|
||||
if not token or token != expected_token:
|
||||
logger.warning(f"Unauthorized request with token: {token}")
|
||||
raise Unauthorized("Invalid token")
|
||||
|
||||
# --- GPU 使用率检查 ---
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
num_gpus = torch.cuda.device_count()
|
||||
max_usage_ratio = request.app.config.get("MAX_GPU_USAGE", 0.9) # 默认90%
|
||||
|
||||
for i in range(num_gpus):
|
||||
used = torch.cuda.memory_reserved(i)
|
||||
total = torch.cuda.max_memory_reserved(i)
|
||||
ratio = used / total if total else 0
|
||||
|
||||
logger.info(f"GPU {i} Usage: {ratio:.2%}")
|
||||
|
||||
if ratio > max_usage_ratio:
|
||||
logger.warning(f"GPU {i} usage too high: {ratio:.2%}")
|
||||
return json_response({
|
||||
"status": "error",
|
||||
"message": f"GPU resource busy (GPU {i} at {ratio:.2%}). Try later."
|
||||
}, status=503)
|
||||
except Exception as e:
|
||||
logger.error(f"GPU check failed: {e}")
|
||||
|
||||
return None # 允许请求继续
|
||||
|
||||
##################################################################################################################################################################################################
|
||||
#创建Sanic应用
|
||||
app = Sanic("ai_Service_v2")
|
||||
CORS(app) # 允许跨域请求
|
||||
task_progress = {}
|
||||
|
||||
@app.middleware("request")
|
||||
async def global_middleware(request):
|
||||
result = await token_and_resource_check(request)
|
||||
if result:
|
||||
return result
|
||||
|
||||
# 配置Token和最大GPU使用率
|
||||
app.config.update({
|
||||
"VALID_TOKEN": "Beidou_b8609e96-bfec-4485-8c64-6d4f662ee44a",
|
||||
"MAX_GPU_USAGE": 0.9
|
||||
})
|
||||
|
||||
######################################################################地图切割相关的API########################################################################################################
|
||||
#创建地图的蓝图
|
||||
map_tile_blueprint = Blueprint('map', url_prefix='/map/')
|
||||
app.blueprint(map_tile_blueprint)
|
||||
|
||||
|
||||
@map_tile_blueprint.post("/tile")
|
||||
async def map_tile_api(request):
|
||||
try:
|
||||
# 1. 检查请求体
|
||||
if not request.json:
|
||||
return json_response(
|
||||
{"status": "error", "message": "Request body is required"},
|
||||
status=400
|
||||
)
|
||||
|
||||
# 2. 解析必要字段
|
||||
tile_data = request.json
|
||||
tif_url = tile_data.get("tif_url")
|
||||
prj_url = tile_data.get("prj_url")
|
||||
|
||||
if not tif_url or not prj_url:
|
||||
return json_response(
|
||||
{"status": "error", "message": "Both tif_url and prj_url are required"},
|
||||
status=400
|
||||
)
|
||||
|
||||
# 3. 处理业务逻辑(直接调用协程函数,不要用 asyncio.run)
|
||||
zoom_levels = tile_data.get("zoom_levels", "1-18")
|
||||
try:
|
||||
# 假设 process_tiling 是一个协程函数(async def)
|
||||
result = await process_tiling(tif_url, prj_url, zoom_levels)
|
||||
|
||||
# 如果 process_tiling 是普通函数,用 asyncio.to_thread 包装
|
||||
# result = await asyncio.to_thread(process_tiling, tif_url, prj_url, zoom_levels)
|
||||
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"data": result
|
||||
})
|
||||
except Exception as processing_error:
|
||||
logger.error(f"Processing failed: {str(processing_error)}", exc_info=True)
|
||||
return json_response(
|
||||
{"status": "error", "message": f"Processing error: {str(processing_error)}"},
|
||||
status=500
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error: {str(e)}", exc_info=True)
|
||||
return json_response(
|
||||
{"status": "error", "message": str(e)}, # 直接返回字符串,不要用集合
|
||||
status=500
|
||||
)
|
||||
|
||||
|
||||
#语义识别
|
||||
@map_tile_blueprint.post("/uav")
|
||||
async def process_handler(request):
|
||||
"""
|
||||
接口:/map/uav
|
||||
输入 JSON:
|
||||
{
|
||||
"urls": [
|
||||
"http://example.com/img1.jpg",
|
||||
"http://example.com/img2.jpg"
|
||||
],
|
||||
"yaml_name": "112.44.103.230",
|
||||
"bucket_name": "300bdf2b-a150-406e-be63-d28bd29b409f",
|
||||
"bucket_directory": "2025/seg"
|
||||
"model_path": "deeplabv3plus_best.pth"
|
||||
}
|
||||
输出 JSON:
|
||||
{
|
||||
"code": 200,
|
||||
"msg": "success",
|
||||
"data": [
|
||||
"http://minio.example.com/uav-results/2025/seg/result1.png",
|
||||
"http://minio.example.com/uav-results/2025/seg/result2.png"
|
||||
]
|
||||
}
|
||||
|
||||
"""
|
||||
try:
|
||||
body = request.json
|
||||
urls = body.get("urls", [])
|
||||
yaml_name = body.get("yaml_name")
|
||||
bucket_name = body.get("bucket_name")
|
||||
bucket_directory = body.get("bucket_directory")
|
||||
model_path = os.path.join("map", "checkpoints", body.get("model_path"))
|
||||
|
||||
# 校验参数
|
||||
if not urls or not isinstance(urls, list):
|
||||
return json({"code": 400, "msg": "Missing or invalid 'urls'"})
|
||||
if not all([yaml_name, bucket_name, bucket_directory]):
|
||||
return json({"code": 400, "msg": "Missing required parameters"})
|
||||
|
||||
# 调用图像处理函数
|
||||
result = map_process_images(urls, yaml_name, bucket_name, bucket_directory,model_path)
|
||||
return json(result)
|
||||
|
||||
except Exception as e:
|
||||
return json({"code": 500, "msg": f"Server error: {str(e)}"})
|
||||
|
||||
######################################################################yolo相关的API########################################################################################################
|
||||
#创建yolo的蓝图
|
||||
yolo_tile_blueprint = Blueprint('yolo', url_prefix='/yolo/')
|
||||
app.blueprint(yolo_tile_blueprint)
|
||||
|
||||
# YOLO URL APT
|
||||
# 存储任务进度和结果(内存示例,可用 Redis 或 DB 持久化)
|
||||
|
||||
@yolo_tile_blueprint.post("/process_images")
|
||||
async def process_images(request):
|
||||
"""
|
||||
{
|
||||
"urls": [
|
||||
"http://example.com/image1.jpg",
|
||||
"http://example.com/image2.jpg",
|
||||
"http://example.com/image3.jpg"
|
||||
],
|
||||
"yaml_name": "your_minio_config",
|
||||
"bucket_name": "my-bucket",
|
||||
"bucket_directory": "2025/uav-results",
|
||||
"model_path": "deeplabv3plus_best.pth"
|
||||
}
|
||||
"""
|
||||
data = request.json
|
||||
urls = data.get("urls")
|
||||
yaml_name = data.get("yaml_name")
|
||||
bucket_name = data.get("bucket_name")
|
||||
bucket_directory = data.get("bucket_directory")
|
||||
uav_model_path = data.get("uav_model_path")
|
||||
|
||||
if not urls or not yaml_name or not bucket_name or not uav_model_path:
|
||||
return response.json({"code": 400, "msg": "Missing parameters"}, status=400)
|
||||
|
||||
task_id = str(uuid.uuid4())
|
||||
task_progress[task_id] = {"status": "pending", "progress": 0, "result": None}
|
||||
|
||||
# 启动后台任务
|
||||
asyncio.create_task(run_image_processing(task_id, urls, yaml_name, bucket_name, bucket_directory, uav_model_path))
|
||||
|
||||
return response.json({"code": 200, "msg": "Task started", "task_id": task_id})
|
||||
|
||||
@yolo_tile_blueprint.get("/task_status/<task_id>")
|
||||
async def task_status(request, task_id):
|
||||
progress = task_progress.get(task_id)
|
||||
if not progress:
|
||||
return response.json({"code": 404, "msg": "Task not found"}, status=404)
|
||||
return response.json({"code": 200, "msg": "Task status", "data": progress})
|
||||
|
||||
async def run_image_processing(task_id, urls, yaml_name, bucket_name, bucket_directory, uav_model_path):
|
||||
|
||||
|
||||
try:
|
||||
task_progress[task_id]["status"] = "running"
|
||||
task_progress[task_id]["progress"] = 10 # 开始进度
|
||||
|
||||
# 下载、推理、上传阶段分别更新进度
|
||||
def progress_callback(stage, percent):
|
||||
task_progress[task_id]["status"] = stage
|
||||
task_progress[task_id]["progress"] = percent
|
||||
|
||||
result = await asyncio.to_thread(
|
||||
map_process_images_with_progress,
|
||||
urls, yaml_name, bucket_name, bucket_directory, uav_model_path, progress_callback
|
||||
)
|
||||
|
||||
task_progress[task_id]["status"] = "completed"
|
||||
task_progress[task_id]["progress"] = 100
|
||||
task_progress[task_id]["result"] = result
|
||||
|
||||
except Exception as e:
|
||||
task_progress[task_id]["status"] = "failed"
|
||||
task_progress[task_id]["progress"] = 100
|
||||
task_progress[task_id]["result"] = str(e)
|
||||
|
||||
# YOLO检测API
|
||||
@yolo_tile_blueprint.post("/picture")
|
||||
async def yolo_detect_api(request):
|
||||
try:
|
||||
|
||||
detect_data = request.json
|
||||
|
||||
# 解析必要字段
|
||||
image_list = detect_data.get("image_list")
|
||||
yolo_model = detect_data.get("yolo_model", "best.pt")
|
||||
class_filter = detect_data.get("class", None)
|
||||
minio_info = detect_data.get("minio", None)
|
||||
|
||||
if not image_list:
|
||||
return json_response({"status": "error", "message": "image_list is required"}, status=400)
|
||||
|
||||
if not minio_info:
|
||||
return json_response({"status": "error", "message": "MinIO information is required"}, status=400)
|
||||
|
||||
# 创建临时文件夹
|
||||
input_folder = f"./temp_input_{str(uuid.uuid4())}"
|
||||
output_folder = f"./temp_output_{str(uuid.uuid4())}"
|
||||
|
||||
# 执行图像处理
|
||||
result = await asyncio.to_thread(
|
||||
process_images,
|
||||
yolo_model=yolo_model,
|
||||
image_list=image_list,
|
||||
class_filter=class_filter,
|
||||
input_folder=input_folder,
|
||||
output_folder=output_folder,
|
||||
minio_info=minio_info
|
||||
)
|
||||
|
||||
# 返回处理结果
|
||||
return json_response(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred while processing request: {str(e)}", exc_info=True)
|
||||
return json_response({
|
||||
"status": "error",
|
||||
"message": f"Internal server error: {str(e)}"
|
||||
}, status=500)
|
||||
|
||||
# YOLO自动训练
|
||||
@yolo_tile_blueprint.post("/train")
|
||||
async def yolo_train_api(request):
|
||||
"""
|
||||
自动训练模型
|
||||
输入 JSON:
|
||||
{
|
||||
"db_host": str,
|
||||
"db_database": str,
|
||||
"db_user": str,
|
||||
"db_password": str,
|
||||
"db_port": int,
|
||||
"model_id": int,
|
||||
"img_path": str,
|
||||
"label_path": str,
|
||||
"new_path": str,
|
||||
"split_list": List[float],
|
||||
"class_names": Optional[List[str]],
|
||||
"project_name": str
|
||||
}
|
||||
输出 JSON:
|
||||
{
|
||||
"base_metrics": Dict[str, float],
|
||||
"best_model_path": str,
|
||||
"final_metrics": Dict[str, float]
|
||||
}
|
||||
|
||||
"""
|
||||
try:
|
||||
# 修改为直接访问 request.json 而不是调用它
|
||||
data = request.json
|
||||
|
||||
|
||||
if not data:
|
||||
return json_response({"status": "error", "message": "data is required"}, status=400)
|
||||
|
||||
# 执行图像处理
|
||||
result = await asyncio.to_thread(
|
||||
auto_train,
|
||||
data
|
||||
)
|
||||
# 返回处理结果
|
||||
return json_response(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred while processing request: {str(e)}", exc_info=True)
|
||||
return json_response({
|
||||
"status": "error",
|
||||
"message": f"Internal server error: {str(e)}"
|
||||
}, status=500)
|
||||
|
||||
###########################################################################################视频流相关的API#######################################################################################################
|
||||
|
||||
#创建视频流的蓝图
|
||||
stream_tile_blueprint = Blueprint('stream', url_prefix='/stream_test/')
|
||||
app.blueprint(stream_tile_blueprint)
|
||||
|
||||
#
|
||||
# 任务管理器
|
||||
class StreamTaskManager:
|
||||
def __init__(self):
|
||||
self.active_tasks = {}
|
||||
self.task_status = {}
|
||||
self.task_timestamps = {}
|
||||
self.task_queue = Queue(maxsize=10)
|
||||
|
||||
def add_task(self, task_id: str, task_info: dict) -> None:
|
||||
if self.task_queue.full():
|
||||
oldest_task_id = self.task_queue.get()
|
||||
self.remove_task(oldest_task_id)
|
||||
stop_video_session(self.active_tasks[oldest_task_id]["session_id"])
|
||||
self.active_tasks[task_id] = task_info
|
||||
self.task_status[task_id] = "running"
|
||||
self.task_timestamps[task_id] = datetime.now()
|
||||
self.task_queue.put(task_id)
|
||||
logger.info(f"Task {task_id} started")
|
||||
|
||||
def remove_task(self, task_id: str) -> None:
|
||||
if task_id in self.active_tasks:
|
||||
del self.active_tasks[task_id]
|
||||
del self.task_status[task_id]
|
||||
del self.task_timestamps[task_id]
|
||||
logger.info(f"Task {task_id} removed")
|
||||
|
||||
def get_task_info(self, task_id: str) -> dict:
|
||||
if task_id not in self.active_tasks:
|
||||
raise NotFound("Task not found")
|
||||
return {
|
||||
"task_info": self.active_tasks[task_id],
|
||||
"status": self.task_status[task_id],
|
||||
"start_time": self.task_timestamps[task_id].isoformat()
|
||||
}
|
||||
|
||||
task_manager = StreamTaskManager()
|
||||
|
||||
# ---------- API Endpoints ----------
|
||||
|
||||
@stream_tile_blueprint.post("/start")
|
||||
async def api_start(request):
|
||||
"""
|
||||
启动视频流会话
|
||||
输入 JSON:
|
||||
{
|
||||
"video_path": str,
|
||||
"output_url": str,
|
||||
"model_path": str,
|
||||
"cls": List[int],
|
||||
"confidence": float,
|
||||
"cls2": Optional[List[int]]
|
||||
"push": bool
|
||||
}
|
||||
输出 JSON:
|
||||
{
|
||||
"session_id": str,
|
||||
"task_id": str,
|
||||
"message": "started"
|
||||
}
|
||||
"""
|
||||
data = request.json
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
# 启动视频处理会话,并传入 task_id
|
||||
session_id = start_video_session(
|
||||
video_path = data.get("video_path"),
|
||||
output_url = data.get("output_url"),
|
||||
model_path = data.get("model_path"),
|
||||
cls = data.get("cls"),
|
||||
confidence = data.get("confidence", 0.5),
|
||||
cls2 = data.get("cls2", []),
|
||||
push = data.get("push", False),
|
||||
)
|
||||
|
||||
# 注册到任务管理器
|
||||
task_manager.add_task(task_id, {
|
||||
"session_id": session_id,
|
||||
"video_path": data.get("video_path"),
|
||||
"output_url": data.get("output_url"),
|
||||
"model_path": data.get("model_path"),
|
||||
"class_filter": data.get("cls", []),
|
||||
"push": data.get("push", False),
|
||||
"start_time": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
return json({"session_id": session_id, "task_id": task_id, "message": "started"})
|
||||
|
||||
|
||||
@stream_tile_blueprint.post("/stop")
|
||||
async def api_stop(request):
|
||||
"""
|
||||
停止指定会话
|
||||
输入 JSON: { "session_id": str }
|
||||
输出 JSON: { "session_id": str, "message": "stopped" }
|
||||
"""
|
||||
session_id = request.json.get("session_id")
|
||||
stop_video_session(session_id)
|
||||
|
||||
# 同步移除任务
|
||||
for tid, info in list(task_manager.active_tasks.items()):
|
||||
if info.get("session_id") == session_id:
|
||||
task_manager.remove_task(tid)
|
||||
break
|
||||
|
||||
return json({"session_id": session_id, "message": "stopped"})
|
||||
|
||||
|
||||
@stream_tile_blueprint.post("/switch_model")
|
||||
async def api_switch_model(request):
|
||||
"""
|
||||
切换会话模型
|
||||
输入 JSON: { "session_id": str, "new_model_path": str }
|
||||
输出 JSON: { "session_id": str, "new_model_path": str, "message": "model switched" }
|
||||
"""
|
||||
data = request.json
|
||||
session_id = data.get("session_id")
|
||||
new_model = data.get("new_model_path")
|
||||
switch_model_session(session_id, new_model)
|
||||
return json({"session_id": session_id, "new_model_path": new_model, "message": "model switched"})
|
||||
|
||||
|
||||
@stream_tile_blueprint.get("/sessions")
|
||||
async def api_list_sessions(request):
|
||||
"""
|
||||
列出所有当前会话
|
||||
输出 JSON: { "sessions": [{"session_id": str, "status": "running"}, ...] }
|
||||
"""
|
||||
sessions = [
|
||||
{"session_id": sid, "status": "running"}
|
||||
for sid in stream_sessions.keys()
|
||||
]
|
||||
return json({"sessions": sessions})
|
||||
|
||||
|
||||
# 统一的任务查询接口(含视频流)
|
||||
@stream_tile_blueprint.get("/tasks")
|
||||
async def api_list_tasks(request):
|
||||
"""
|
||||
列出所有任务(含状态、开始时间、详情)
|
||||
"""
|
||||
tasks = []
|
||||
for tid in task_manager.active_tasks:
|
||||
info = task_manager.get_task_info(tid)
|
||||
tasks.append({"task_id": tid, **info})
|
||||
return json({"tasks": tasks})
|
||||
|
||||
|
||||
##################################################################################################################################################################################################
|
||||
if __name__ == '__main__':
|
||||
app.run(host="0.0.0.0", port=12366, debug=True,workers=1)
|
||||
60
b3dm/b3dm_api.py
Normal file
@ -0,0 +1,60 @@
|
||||
from sanic import Sanic, Request, json
|
||||
from sanic_cors import CORS
|
||||
import logging
|
||||
import time
|
||||
from earthwork_api import earthwork_bp
|
||||
from terrain_api import terrain_bp
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 创建Sanic应用
|
||||
app = Sanic("TerrainAnalysisAPI")
|
||||
# 显式注册蓝图
|
||||
app.blueprint(earthwork_bp)
|
||||
app.blueprint(terrain_bp)
|
||||
|
||||
CORS(app, automatic_options=True)
|
||||
|
||||
# 中间件:请求计时
|
||||
@app.middleware("request")
|
||||
async def add_start_time(request: Request):
|
||||
request.ctx.start_time = time.time()
|
||||
|
||||
@app.middleware("response")
|
||||
async def add_response_time(request: Request, response):
|
||||
if hasattr(request.ctx, "start_time"):
|
||||
process_time = (time.time() - request.ctx.start_time) * 1000
|
||||
response.headers["X-Process-Time"] = f"{process_time:.2f}ms"
|
||||
|
||||
@app.get("/api/v1/health")
|
||||
async def health_check(request: Request):
|
||||
"""健康检查"""
|
||||
return json({
|
||||
"status": "healthy",
|
||||
"timestamp": time.time(),
|
||||
"service": "terrain-analysis-api",
|
||||
"version": "1.0.0"
|
||||
})
|
||||
|
||||
# 错误处理
|
||||
@app.exception(Exception)
|
||||
async def handle_exception(request: Request, exception):
|
||||
"""全局异常处理"""
|
||||
logger.error(f"未处理的异常: {exception}")
|
||||
return json({
|
||||
"error": "服务器内部错误",
|
||||
"message": str(exception) if app.debug else "请稍后重试",
|
||||
"success": False
|
||||
}, status=500)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 启动服务器
|
||||
app.run(
|
||||
host="0.0.0.0",
|
||||
port=8000,
|
||||
debug=True, # 生产环境设为False
|
||||
access_log=True,
|
||||
auto_reload=True
|
||||
)
|
||||
542
b3dm/data_3dtiles_manager.py
Normal file
@ -0,0 +1,542 @@
|
||||
from minio import Minio
|
||||
from minio.error import S3Error
|
||||
import json
|
||||
import os
|
||||
import numpy as np
|
||||
from urllib.parse import urlparse
|
||||
import hashlib
|
||||
import time
|
||||
import re
|
||||
import pickle
|
||||
from datetime import datetime
|
||||
|
||||
class MinIO3DTilesManager:
|
||||
def __init__(self, endpoint_url, access_key, secret_key, secure=False,
|
||||
mapping_file="minio_path_mapping.pkl"):
|
||||
"""
|
||||
初始化MinIO客户端
|
||||
|
||||
Args:
|
||||
endpoint_url: MinIO服务地址 (如: 222.212.85.86:9001)
|
||||
access_key: 访问密钥
|
||||
secret_key: 秘密密钥
|
||||
secure: 是否使用HTTPS
|
||||
mapping_file: 路径映射文件名
|
||||
"""
|
||||
if endpoint_url.startswith('http://'):
|
||||
endpoint_url = endpoint_url.replace('http://', '')
|
||||
elif endpoint_url.startswith('https://'):
|
||||
endpoint_url = endpoint_url.replace('https://', '')
|
||||
secure = True
|
||||
|
||||
self.endpoint_url = endpoint_url
|
||||
self.access_key = access_key
|
||||
self.secret_key = secret_key
|
||||
|
||||
self.minio_client = Minio(
|
||||
endpoint_url,
|
||||
access_key=access_key,
|
||||
secret_key=secret_key,
|
||||
secure=secure
|
||||
)
|
||||
|
||||
# 获取脚本所在目录
|
||||
self.script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
# 映射文件路径
|
||||
self.mapping_file = os.path.join(self.script_dir, mapping_file)
|
||||
|
||||
# 加载现有的路径映射
|
||||
self.path_mapping = self.load_path_mapping()
|
||||
|
||||
def load_path_mapping(self):
|
||||
"""加载路径映射数据"""
|
||||
if os.path.exists(self.mapping_file):
|
||||
try:
|
||||
with open(self.mapping_file, 'rb') as f:
|
||||
mapping = pickle.load(f)
|
||||
return mapping
|
||||
except Exception as e:
|
||||
return {}
|
||||
else:
|
||||
return {}
|
||||
|
||||
def save_path_mapping(self):
|
||||
"""保存路径映射数据"""
|
||||
try:
|
||||
with open(self.mapping_file, 'wb') as f:
|
||||
pickle.dump(self.path_mapping, f)
|
||||
return True
|
||||
except Exception as e:
|
||||
return False
|
||||
|
||||
def get_cache_key(self, tileset_url, save_dir=None):
|
||||
"""生成缓存键"""
|
||||
# 基于URL和保存目录生成缓存键
|
||||
cache_data = f"{tileset_url}|{save_dir}"
|
||||
return hashlib.md5(cache_data.encode()).hexdigest()
|
||||
|
||||
def get_cached_tileset_info(self, tileset_url, save_dir=None):
|
||||
"""获取缓存的tileset信息"""
|
||||
cache_key = self.get_cache_key(tileset_url, save_dir)
|
||||
|
||||
# 检查缓存映射中是否有这个tileset
|
||||
for file_id, info in self.path_mapping.items():
|
||||
if info.get('cache_key') == cache_key and info.get('is_tileset_root'):
|
||||
# 检查入口文件是否存在
|
||||
local_path = info.get('local_path')
|
||||
if local_path and os.path.exists(local_path):
|
||||
return local_path
|
||||
return None
|
||||
|
||||
def update_tileset_cache(self, tileset_url, save_dir, local_path):
|
||||
"""更新tileset缓存信息"""
|
||||
cache_key = self.get_cache_key(tileset_url, save_dir)
|
||||
|
||||
# 将tileset根文件标记为缓存
|
||||
entry_bucket, entry_path = self.parse_minio_url(tileset_url)
|
||||
file_id = f"{entry_bucket}/{entry_path}"
|
||||
|
||||
if file_id in self.path_mapping:
|
||||
self.path_mapping[file_id]['cache_key'] = cache_key
|
||||
self.path_mapping[file_id]['is_tileset_root'] = True
|
||||
self.path_mapping[file_id]['tileset_url'] = tileset_url
|
||||
self.path_mapping[file_id]['save_dir'] = save_dir
|
||||
self.path_mapping[file_id]['cache_time'] = datetime.now().isoformat()
|
||||
|
||||
def download_full_tileset(self, tileset_url, save_dir=None, region_filter=None, use_cache=True):
|
||||
"""
|
||||
下载完整的3D Tiles数据集,支持缓存功能
|
||||
|
||||
Args:
|
||||
tileset_url: MinIO上的tileset.json URL
|
||||
save_dir: 本地保存目录
|
||||
region_filter: 区域过滤器
|
||||
use_cache: 是否使用缓存
|
||||
|
||||
Returns:
|
||||
tuple: (success, result)
|
||||
- success: True/False
|
||||
- result: 如果success=True且use_cache=True,返回本地路径;否则返回True/False
|
||||
"""
|
||||
if save_dir is None:
|
||||
save_dir = os.path.join(self.script_dir, "data_3dtiles")
|
||||
|
||||
# 清理保存目录名称
|
||||
save_dir = self.clean_file_path(save_dir)
|
||||
|
||||
# 检查缓存:只需检查入口文件是否存在
|
||||
if use_cache:
|
||||
cached_path = self.get_cached_tileset_info(tileset_url, save_dir)
|
||||
if cached_path:
|
||||
# 入口文件存在,默认缓存完备
|
||||
return True, cached_path
|
||||
|
||||
# 解析URL
|
||||
entry_bucket, entry_path = self.parse_minio_url(tileset_url)
|
||||
if not entry_bucket or not entry_path:
|
||||
return False, "无法解析URL"
|
||||
|
||||
entry_dir = os.path.dirname(entry_path)
|
||||
|
||||
# 创建保存目录
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
visited = set()
|
||||
|
||||
# 下载入口文件
|
||||
entry_local_path = self.get_local_path(entry_bucket, entry_path, save_dir)
|
||||
|
||||
success, result = self.download_file(entry_bucket, entry_path, entry_local_path)
|
||||
if not success:
|
||||
return False, f"入口文件下载失败: {result}"
|
||||
|
||||
entry_id = f"{entry_bucket}/{entry_path}"
|
||||
visited.add(entry_id)
|
||||
|
||||
# 加载tileset数据
|
||||
tileset_data = self.load_json_from_minio(entry_bucket, entry_path)
|
||||
if not tileset_data or "root" not in tileset_data:
|
||||
return False, "无效的tileset.json文件"
|
||||
|
||||
# 遍历下载所有文件
|
||||
self.traverse_and_download_tileset(
|
||||
tileset_data["root"],
|
||||
entry_bucket,
|
||||
entry_dir,
|
||||
entry_bucket,
|
||||
entry_dir,
|
||||
save_dir,
|
||||
region_filter,
|
||||
None,
|
||||
visited
|
||||
)
|
||||
|
||||
# 更新缓存信息
|
||||
self.update_tileset_cache(tileset_url, save_dir, entry_local_path)
|
||||
|
||||
# 保存路径映射
|
||||
self.save_path_mapping()
|
||||
|
||||
if use_cache:
|
||||
return True, entry_local_path
|
||||
else:
|
||||
return True, True
|
||||
|
||||
def get_tileset_local_path(self, tileset_url, save_dir=None):
|
||||
"""
|
||||
获取已缓存的tileset本地路径
|
||||
|
||||
Args:
|
||||
tileset_url: tileset的URL
|
||||
save_dir: 保存目录
|
||||
|
||||
Returns:
|
||||
str: 本地路径,如果未缓存则返回None
|
||||
"""
|
||||
if save_dir is None:
|
||||
save_dir = os.path.join(self.script_dir, "data_3dtiles")
|
||||
|
||||
return self.get_cached_tileset_info(tileset_url, save_dir)
|
||||
|
||||
def clear_tileset_cache(self, tileset_url=None, save_dir=None):
|
||||
"""
|
||||
清除tileset缓存
|
||||
|
||||
Args:
|
||||
tileset_url: 指定要清除的tileset URL,如果为None则清除所有
|
||||
save_dir: 保存目录
|
||||
|
||||
Returns:
|
||||
bool: 成功/失败
|
||||
"""
|
||||
try:
|
||||
if tileset_url:
|
||||
# 清除指定tileset的缓存
|
||||
cache_key = self.get_cache_key(tileset_url, save_dir)
|
||||
|
||||
# 找出所有相关的缓存条目
|
||||
to_remove = []
|
||||
for file_id, info in self.path_mapping.items():
|
||||
if info.get('cache_key') == cache_key:
|
||||
to_remove.append(file_id)
|
||||
|
||||
# 删除这些条目
|
||||
for file_id in to_remove:
|
||||
del self.path_mapping[file_id]
|
||||
|
||||
print(f"已清除tileset缓存: {tileset_url}")
|
||||
else:
|
||||
# 清除所有缓存
|
||||
self.path_mapping = {}
|
||||
if os.path.exists(self.mapping_file):
|
||||
os.remove(self.mapping_file)
|
||||
print("已清除所有缓存")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
return False
|
||||
|
||||
# 以下是原有的辅助方法
|
||||
def clean_filename(self, filename):
|
||||
"""清理文件名中的特殊字符"""
|
||||
if not filename:
|
||||
return ""
|
||||
cleaned = re.sub(r'[<>:"/\\|?*\x00-\x1F]', '_', filename)
|
||||
cleaned = re.sub(r'_+', '_', cleaned)
|
||||
cleaned = cleaned.strip(' _')
|
||||
return cleaned
|
||||
|
||||
def parse_minio_url(self, url):
|
||||
"""解析MinIO URL"""
|
||||
if url.startswith('http://') or url.startswith('https://'):
|
||||
parsed = urlparse(url)
|
||||
path = parsed.path.lstrip('/')
|
||||
parts = path.split('/', 1)
|
||||
if len(parts) == 2:
|
||||
bucket, key = parts
|
||||
else:
|
||||
bucket = parts[0]
|
||||
key = ""
|
||||
return bucket, key
|
||||
else:
|
||||
parts = url.split('/', 1)
|
||||
if len(parts) == 2:
|
||||
bucket, key = parts
|
||||
else:
|
||||
bucket = parts[0]
|
||||
key = ""
|
||||
return bucket, key
|
||||
|
||||
def download_file(self, bucket_name, object_name, file_path):
|
||||
"""从MinIO下载文件"""
|
||||
try:
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
|
||||
# 清理文件名
|
||||
clean_file_path = self.clean_file_path(file_path)
|
||||
|
||||
# 检查是否已下载
|
||||
file_id = f"{bucket_name}/{object_name}"
|
||||
if file_id in self.path_mapping:
|
||||
mapped_path = self.path_mapping[file_id]['local_path']
|
||||
if os.path.exists(mapped_path):
|
||||
return True, mapped_path
|
||||
|
||||
# 下载文件
|
||||
self.minio_client.fget_object(
|
||||
bucket_name,
|
||||
object_name,
|
||||
clean_file_path
|
||||
)
|
||||
|
||||
# 更新路径映射
|
||||
self.path_mapping[file_id] = {
|
||||
'local_path': clean_file_path,
|
||||
'bucket': bucket_name,
|
||||
'object': object_name,
|
||||
'download_time': datetime.now().isoformat(),
|
||||
'size': os.path.getsize(clean_file_path)
|
||||
}
|
||||
|
||||
return True, clean_file_path
|
||||
|
||||
except S3Error as e:
|
||||
return False, str(e)
|
||||
except Exception as e:
|
||||
return False, str(e)
|
||||
|
||||
def clean_file_path(self, file_path):
|
||||
"""清理文件路径中的所有特殊字符"""
|
||||
dir_name = os.path.dirname(file_path)
|
||||
file_name = os.path.basename(file_path)
|
||||
|
||||
if dir_name:
|
||||
dir_parts = dir_name.split(os.sep)
|
||||
cleaned_parts = []
|
||||
for part in dir_parts:
|
||||
cleaned_part = self.clean_filename(part)
|
||||
if cleaned_part:
|
||||
cleaned_parts.append(cleaned_part)
|
||||
cleaned_dir = os.sep.join(cleaned_parts)
|
||||
else:
|
||||
cleaned_dir = ""
|
||||
|
||||
cleaned_file = self.clean_filename(file_name)
|
||||
|
||||
if cleaned_dir:
|
||||
cleaned_path = os.path.join(cleaned_dir, cleaned_file)
|
||||
else:
|
||||
cleaned_path = cleaned_file
|
||||
|
||||
return cleaned_path
|
||||
|
||||
def load_json_from_minio(self, bucket_name, object_name):
|
||||
"""从MinIO加载JSON文件"""
|
||||
try:
|
||||
self.minio_client.stat_object(bucket_name, object_name)
|
||||
|
||||
response = self.minio_client.get_object(bucket_name, object_name)
|
||||
content = response.read().decode('utf-8')
|
||||
response.close()
|
||||
response.release_conn()
|
||||
|
||||
return json.loads(content)
|
||||
|
||||
except S3Error as e:
|
||||
return None
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
def get_local_path(self, bucket_name, object_name, save_dir):
|
||||
"""生成保持目录结构的本地路径"""
|
||||
clean_bucket = self.clean_filename(bucket_name)
|
||||
|
||||
path_parts = object_name.split('/')
|
||||
cleaned_parts = []
|
||||
for part in path_parts:
|
||||
cleaned_part = self.clean_filename(part)
|
||||
if cleaned_part:
|
||||
cleaned_parts.append(cleaned_part)
|
||||
|
||||
if cleaned_parts:
|
||||
cleaned_relative = '/'.join(cleaned_parts)
|
||||
local_path = os.path.join(save_dir, clean_bucket, cleaned_relative)
|
||||
else:
|
||||
local_path = os.path.join(save_dir, clean_bucket)
|
||||
|
||||
|
||||
return os.path.normpath(local_path)
|
||||
|
||||
def traverse_and_download_tileset(self, tile_obj, current_bucket, current_dir,
|
||||
base_bucket, base_dir, save_dir,
|
||||
region_filter=None, parent_transform=None,
|
||||
visited=None):
|
||||
"""递归遍历并下载3D Tiles文件"""
|
||||
if visited is None:
|
||||
visited = set()
|
||||
|
||||
current_transform = parent_transform
|
||||
if "transform" in tile_obj:
|
||||
tile_mat = tile_obj["transform"]
|
||||
if current_transform is None:
|
||||
current_transform = tile_mat
|
||||
else:
|
||||
mat1 = np.array(current_transform).reshape(4, 4)
|
||||
mat2 = np.array(tile_mat).reshape(4, 4)
|
||||
combined_mat = np.dot(mat1, mat2).flatten().tolist()
|
||||
current_transform = combined_mat
|
||||
|
||||
skip_current_tile = False
|
||||
if region_filter and "boundingVolume" in tile_obj:
|
||||
if not region_filter.check_tile_bounding_volume(tile_obj["boundingVolume"]):
|
||||
skip_current_tile = True
|
||||
|
||||
if not skip_current_tile and "content" in tile_obj and "uri" in tile_obj["content"]:
|
||||
tile_uri = tile_obj["content"]["uri"]
|
||||
|
||||
file_bucket = current_bucket
|
||||
file_path = ""
|
||||
|
||||
if tile_uri.startswith('http://') or tile_uri.startswith('https://'):
|
||||
parsed_bucket, parsed_path = self.parse_minio_url(tile_uri)
|
||||
if parsed_bucket:
|
||||
file_bucket = parsed_bucket
|
||||
file_path = parsed_path
|
||||
else:
|
||||
if current_dir:
|
||||
file_path = os.path.join(current_dir, tile_uri).replace('\\', '/')
|
||||
else:
|
||||
file_path = tile_uri
|
||||
|
||||
file_path = file_path.lstrip('/')
|
||||
|
||||
file_id = f"{file_bucket}/{file_path}"
|
||||
|
||||
if file_id not in visited:
|
||||
print(f"下载文件:{file_id}")
|
||||
visited.add(file_id)
|
||||
|
||||
local_path = self.get_local_path(file_bucket, file_path, save_dir)
|
||||
|
||||
self.download_file(file_bucket, file_path, local_path)
|
||||
|
||||
if file_path.lower().endswith('.json'):
|
||||
sub_tileset = self.load_json_from_minio(file_bucket, file_path)
|
||||
if sub_tileset and "root" in sub_tileset:
|
||||
sub_dir = os.path.dirname(file_path) if file_path else ""
|
||||
self.traverse_and_download_tileset(
|
||||
sub_tileset["root"],
|
||||
file_bucket,
|
||||
sub_dir,
|
||||
base_bucket,
|
||||
base_dir,
|
||||
save_dir,
|
||||
region_filter,
|
||||
current_transform,
|
||||
visited
|
||||
)
|
||||
|
||||
if "children" in tile_obj:
|
||||
for child_tile in tile_obj["children"]:
|
||||
self.traverse_and_download_tileset(
|
||||
child_tile,
|
||||
current_bucket,
|
||||
current_dir,
|
||||
base_bucket,
|
||||
base_dir,
|
||||
save_dir,
|
||||
region_filter,
|
||||
current_transform,
|
||||
visited
|
||||
)
|
||||
|
||||
def upload_file(self, bucket_name, object_name, file_path):
|
||||
"""上传文件到MinIO"""
|
||||
try:
|
||||
if not os.path.exists(file_path):
|
||||
return False, f"文件不存在: {file_path}"
|
||||
|
||||
file_size = os.path.getsize(file_path)
|
||||
self.minio_client.fput_object(bucket_name, object_name, file_path)
|
||||
|
||||
return True, f"{bucket_name}/{object_name}"
|
||||
|
||||
except S3Error as e:
|
||||
return False, f"MinIO上传错误: {e}"
|
||||
except Exception as e:
|
||||
return False, f"上传失败: {str(e)}"
|
||||
|
||||
def upload_directory(self, bucket_name, local_dir, remote_prefix=""):
|
||||
"""上传目录到MinIO"""
|
||||
if not os.path.exists(local_dir):
|
||||
return [], [f"目录不存在: {local_dir}"]
|
||||
|
||||
uploaded_files = []
|
||||
failed_files = []
|
||||
|
||||
for root, dirs, files in os.walk(local_dir):
|
||||
for file in files:
|
||||
local_path = os.path.join(root, file)
|
||||
rel_path = os.path.relpath(local_path, local_dir)
|
||||
if remote_prefix:
|
||||
remote_path = os.path.join(remote_prefix, rel_path).replace('\\', '/')
|
||||
else:
|
||||
remote_path = rel_path.replace('\\', '/')
|
||||
|
||||
success, message = self.upload_file(bucket_name, remote_path, local_path)
|
||||
if success:
|
||||
uploaded_files.append(remote_path)
|
||||
else:
|
||||
failed_files.append((remote_path, message))
|
||||
|
||||
return uploaded_files, failed_files
|
||||
|
||||
def check_and_create_bucket(self, bucket_name):
|
||||
"""检查并创建bucket"""
|
||||
try:
|
||||
if not self.minio_client.bucket_exists(bucket_name):
|
||||
self.minio_client.make_bucket(bucket_name)
|
||||
return True, f"创建bucket: {bucket_name}"
|
||||
return True, f"bucket已存在: {bucket_name}"
|
||||
except S3Error as e:
|
||||
return False, f"创建bucket失败: {e}"
|
||||
|
||||
|
||||
# 使用示例
|
||||
if __name__ == "__main__":
|
||||
# 配置参数
|
||||
ENDPOINT_URL = "222.212.85.86:9000"
|
||||
ACCESS_KEY = "WuRenJi"
|
||||
SECRET_KEY = "WRJ@2024"
|
||||
|
||||
# 初始化管理器
|
||||
manager = MinIO3DTilesManager(
|
||||
endpoint_url=ENDPOINT_URL,
|
||||
access_key=ACCESS_KEY,
|
||||
secret_key=SECRET_KEY,
|
||||
secure=False
|
||||
)
|
||||
|
||||
# 使用缓存下载tileset
|
||||
tileset_url = "http://222.212.85.86:9000/300bdf2b-a150-406e-be63-d28bd29b409f/model/石棉0908/terra_b3dms/tileset.json"
|
||||
|
||||
# 第一次下载(会下载到本地)
|
||||
success, result = manager.download_full_tileset(tileset_url, use_cache=True)
|
||||
if success:
|
||||
print(f"下载成功,本地路径: {result}")
|
||||
|
||||
# 第二次下载相同URL(直接从缓存返回)
|
||||
success, result = manager.download_full_tileset(tileset_url, use_cache=True)
|
||||
if success:
|
||||
print(f"从缓存获取,本地路径: {result}")
|
||||
|
||||
# 强制重新下载(忽略缓存)
|
||||
success, result = manager.download_full_tileset(tileset_url, use_cache=False)
|
||||
if success:
|
||||
print("强制重新下载成功")
|
||||
|
||||
# 获取缓存的本地路径
|
||||
local_path = manager.get_tileset_local_path(tileset_url)
|
||||
if local_path:
|
||||
print(f"缓存的本地路径: {local_path}")
|
||||
1367
b3dm/data_3dtiles_to_dem.py
Normal file
530
b3dm/earthwork_api.py
Normal file
@ -0,0 +1,530 @@
|
||||
# pip install fastapi uvicorn pdal pyvista numpy
|
||||
from sanic import Blueprint, Request, json
|
||||
from sanic.response import text
|
||||
from typing import Dict, Any
|
||||
import logging
|
||||
from b3dm.earthwork_calculator_point_cloud import EarthworkCalculatorPointCloud
|
||||
# 导入计算模块
|
||||
from b3dm.earthwork_calculator_3d_tiles import EarthworkCalculator3dTiles, AlgorithmType, EarthworkResult3dTiles
|
||||
from b3dm.tileset_data_source import TilesetDataSource
|
||||
|
||||
earthwork_bp = Blueprint("earthwork", url_prefix="")
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 初始化函数
|
||||
def init_app(url, type = "3dtiles"):
|
||||
"""初始化应用"""
|
||||
data_source = None
|
||||
calculator_3d_tiles = None
|
||||
calculator_point_cloud = None
|
||||
|
||||
try:
|
||||
# 初始化数据源
|
||||
data_source = TilesetDataSource(url)
|
||||
data_source.dowload_map_data(url)
|
||||
|
||||
if type == "3dtiles" :
|
||||
# 初始化计算器-3dTiles
|
||||
calculator_3d_tiles = EarthworkCalculator3dTiles(data_source)
|
||||
elif type == "pointcloud" :
|
||||
# 初始化计算器-点云
|
||||
calculator_point_cloud = EarthworkCalculatorPointCloud(data_source.tileset_path)
|
||||
else :
|
||||
logger.info(f"不支持的3d地图数据格式:{type}")
|
||||
raise
|
||||
|
||||
logger.info("土方量计算器初始化完成")
|
||||
return {
|
||||
"data_source":data_source,
|
||||
"calculator_3d_tiles":calculator_3d_tiles,
|
||||
"calculator_point_cloud":calculator_point_cloud
|
||||
}
|
||||
except ImportError as e:
|
||||
logger.error(f"依赖库缺失: {str(e)}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"初始化失败: {str(e)}")
|
||||
raise
|
||||
|
||||
# 土方量计算接口-3dTiles
|
||||
@earthwork_bp.post("/api/v1/calc/earthwork3dTiles")
|
||||
async def calc_earthwork(request: Request):
|
||||
"""
|
||||
土方量计算接口
|
||||
|
||||
请求参数示例:
|
||||
{
|
||||
"polygonCoords": [
|
||||
[
|
||||
115.70440468338526,
|
||||
30.77363140345639
|
||||
],
|
||||
[
|
||||
115.70443054007985,
|
||||
30.773510462589584
|
||||
],
|
||||
[
|
||||
115.70459702429197,
|
||||
30.77360789911405
|
||||
]
|
||||
],
|
||||
"designElevation": 100,
|
||||
"algorithm": "tin",
|
||||
"resolution": 1,
|
||||
"crs": "EPSG:4326",
|
||||
"interpolationMethod": "linear",
|
||||
"url": "http://222.212.85.86:9000/300bdf2b-a150-406e-be63-d28bd29b409f/model/hbgldk/yzk/20260113/3D/terra_b3dms/tileset.json"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
# 1. 接收前端传参
|
||||
data = request.json
|
||||
if not data:
|
||||
return _error_response("请求参数不能为空", 400)
|
||||
|
||||
# 2. 提取参数
|
||||
polygon_coords = data.get("polygonCoords")
|
||||
design_elevation = data.get("designElevation")
|
||||
url = data.get("url")
|
||||
|
||||
if not polygon_coords:
|
||||
return _error_response("多边形坐标不能为空", 400)
|
||||
if design_elevation is None:
|
||||
return _error_response("设计高程不能为空", 400)
|
||||
if url is None:
|
||||
return _error_response("地图不能为空", 400)
|
||||
|
||||
# 3. 可选参数
|
||||
algorithm = data.get("algorithm", "tin")
|
||||
resolution = data.get("resolution", 1.0)
|
||||
crs = data.get("crs", "EPSG:4326")
|
||||
interpolation_method = data.get("interpolationMethod", "linear")
|
||||
|
||||
# 4. 参数验证
|
||||
if len(polygon_coords) < 3:
|
||||
return _error_response("多边形至少需要3个点", 400)
|
||||
|
||||
# 检查多边形是否闭合,如不闭合则自动闭合
|
||||
if polygon_coords[0] != polygon_coords[-1]:
|
||||
polygon_coords.append(polygon_coords[0])
|
||||
|
||||
# 算法验证
|
||||
if algorithm not in ["grid", "tin", "prism"]:
|
||||
return _error_response("算法必须是 grid, tin 或 prism", 400)
|
||||
|
||||
# 分辨率验证
|
||||
if resolution <= 0 or resolution > 100:
|
||||
return _error_response("分辨率必须在0-100米之间", 400)
|
||||
|
||||
# 5. 确保计算器已初始化
|
||||
app_info = init_app(url)
|
||||
calculator_3d_tiles = app_info.get("calculator_3d_tiles")
|
||||
|
||||
# 6. 执行计算
|
||||
algorithm_type = AlgorithmType(algorithm)
|
||||
|
||||
result = await calculator_3d_tiles.calculate(
|
||||
polygon_coords=polygon_coords,
|
||||
design_elevation=design_elevation,
|
||||
algorithm=algorithm_type,
|
||||
resolution=resolution,
|
||||
target_crs=crs,
|
||||
interpolation_method=interpolation_method
|
||||
)
|
||||
|
||||
# 7. 返回成功响应
|
||||
res_dict = result.to_dict()
|
||||
res_dict["calculation_details"] = None
|
||||
res_dict["elevation_statistics"] = None
|
||||
res_dict["volume_distribution"] = None
|
||||
return _success_response(res_dict)
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"参数验证失败: {str(e)}")
|
||||
return _error_response(f"参数错误: {str(e)}", 400)
|
||||
except Exception as e:
|
||||
logger.error(f"计算失败: {str(e)}")
|
||||
return _error_response(f"服务器内部错误: {str(e)}", 500)
|
||||
|
||||
# 两期对比接口-3dTiles
|
||||
@earthwork_bp.post("/api/v1/calc/twoPhaseComparison")
|
||||
async def two_phase_comparison(request: Request):
|
||||
"""
|
||||
两期对比接口
|
||||
|
||||
请求参数示例:
|
||||
{
|
||||
"polygonCoords": [
|
||||
[
|
||||
115.70440468338526,
|
||||
30.77363140345639
|
||||
],
|
||||
[
|
||||
115.70443054007985,
|
||||
30.773510462589584
|
||||
],
|
||||
[
|
||||
115.70459702429197,
|
||||
30.77360789911405
|
||||
]
|
||||
],
|
||||
"designElevation": 100,
|
||||
"algorithm": "grid",
|
||||
"resolution": 1,
|
||||
"crs": "EPSG:4326",
|
||||
"interpolationMethod": "linear",
|
||||
"urlA": "http://222.212.85.86:9000/300bdf2b-a150-406e-be63-d28bd29b409f/model/hbgldk/yzk/20260113/3D/terra_b3dms/tileset.json",
|
||||
"urlB": "http://222.212.85.86:9000/300bdf2b-a150-406e-be63-d28bd29b409f/model/hbgldk/yzk/20260113/3D/terra_b3dms/tileset.json"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
# 1. 接收前端传参
|
||||
data = request.json
|
||||
if not data:
|
||||
return _error_response("请求参数不能为空", 400)
|
||||
|
||||
# 2. 提取参数
|
||||
polygon_coords = data.get("polygonCoords")
|
||||
design_elevation = data.get("designElevation", 1000)
|
||||
urlA = data.get("urlA")
|
||||
urlB = data.get("urlB")
|
||||
|
||||
if not polygon_coords:
|
||||
return _error_response("多边形坐标不能为空", 400)
|
||||
if design_elevation is None:
|
||||
return _error_response("设计高程不能为空", 400)
|
||||
if urlA is None or urlB is None :
|
||||
return _error_response("对比地图不能为空", 400)
|
||||
|
||||
# 3. 可选参数
|
||||
algorithm = data.get("algorithm", "tin")
|
||||
resolution = data.get("resolution", 1.0)
|
||||
crs = data.get("crs", "EPSG:4326")
|
||||
interpolation_method = data.get("interpolationMethod", "linear")
|
||||
|
||||
# 4. 参数验证
|
||||
if len(polygon_coords) < 3:
|
||||
return _error_response("多边形至少需要3个点", 400)
|
||||
|
||||
# 检查多边形是否闭合,如不闭合则自动闭合
|
||||
if polygon_coords[0] != polygon_coords[-1]:
|
||||
polygon_coords.append(polygon_coords[0])
|
||||
|
||||
# 算法验证
|
||||
if algorithm not in ["grid", "tin", "prism"]:
|
||||
return _error_response("算法必须是 grid, tin 或 prism", 400)
|
||||
|
||||
# 分辨率验证
|
||||
if resolution <= 0 or resolution > 100:
|
||||
return _error_response("分辨率必须在0-100米之间", 400)
|
||||
|
||||
# 5. 确保计算器已初始化
|
||||
app_info_a = init_app(urlA)
|
||||
if not app_info_a.get('data_source').tileset_path :
|
||||
return _error_response(f"下载地图失败:{urlA}", 400)
|
||||
calculator_3d_tiles_a = app_info_a.get("calculator_3d_tiles")
|
||||
app_info_b = init_app(urlB)
|
||||
if not app_info_b.get('data_source').tileset_path :
|
||||
return _error_response(f"下载地图失败:{urlB}", 400)
|
||||
calculator_3d_tiles_b = app_info_b.get("calculator_3d_tiles")
|
||||
|
||||
# 6. 执行计算
|
||||
algorithm_type = AlgorithmType.GRID
|
||||
result_a = await calculator_3d_tiles_a.calculate(
|
||||
polygon_coords=polygon_coords,
|
||||
design_elevation=design_elevation,
|
||||
algorithm=algorithm_type,
|
||||
resolution=resolution,
|
||||
target_crs=crs,
|
||||
interpolation_method=interpolation_method
|
||||
)
|
||||
result_b = await calculator_3d_tiles_b.calculate(
|
||||
polygon_coords=polygon_coords,
|
||||
design_elevation=design_elevation,
|
||||
algorithm=algorithm_type,
|
||||
resolution=resolution,
|
||||
target_crs=crs,
|
||||
interpolation_method=interpolation_method
|
||||
)
|
||||
|
||||
# 获取网格数据
|
||||
grids_a = result_a.calculation_details
|
||||
grids_b = result_b.calculation_details
|
||||
|
||||
# 比较网格数据
|
||||
comparison_result = calculator_3d_tiles_a.compare_grid_cells(grids_a, grids_b)
|
||||
|
||||
# 转换为字典
|
||||
result_dict = comparison_result.to_dict()
|
||||
|
||||
# 7. 返回成功响应
|
||||
return _success_response(result_dict)
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"参数验证失败: {str(e)}")
|
||||
return _error_response(f"参数错误: {str(e)}", 400)
|
||||
except Exception as e:
|
||||
logger.error(f"计算失败: {str(e)}")
|
||||
return _error_response(f"服务器内部错误: {str(e)}", 500)
|
||||
|
||||
|
||||
# 验证接口
|
||||
@earthwork_bp.post("/api/v1/calc/earthwork3dTiles/validate")
|
||||
async def validate_earthwork(request: Request):
|
||||
"""验证计算参数接口"""
|
||||
try:
|
||||
# 1. 接收前端传参
|
||||
data = request.json
|
||||
if not data:
|
||||
return _error_response("请求参数不能为空", 400)
|
||||
|
||||
# 2. 提取参数
|
||||
polygon_coords = data.get("polygonCoords")
|
||||
|
||||
if not polygon_coords:
|
||||
return _error_response("多边形坐标不能为空", 400)
|
||||
|
||||
url = data.get("url")
|
||||
if url is None:
|
||||
return _error_response("地图不能为空", 400)
|
||||
|
||||
# 3. 参数验证
|
||||
if len(polygon_coords) < 3:
|
||||
return _error_response("多边形至少需要3个点", 400)
|
||||
|
||||
# 检查多边形是否闭合,如不闭合则自动闭合
|
||||
if polygon_coords[0] != polygon_coords[-1]:
|
||||
polygon_coords.append(polygon_coords[0])
|
||||
|
||||
# 4. 确保计算器已初始化
|
||||
app_info = init_app(url)
|
||||
calculator_3d_tiles = app_info.get("calculator_3d_tiles")
|
||||
|
||||
# 5. 执行验证
|
||||
validation_result = await calculator_3d_tiles.validate(polygon_coords)
|
||||
|
||||
# 6. 返回结果
|
||||
return _success_response(validation_result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"验证失败: {str(e)}")
|
||||
return _error_response(f"验证失败: {str(e)}", 400)
|
||||
|
||||
# 获取算法列表接口
|
||||
@earthwork_bp.get("/api/v1/calc/earthwork3dTiles/algorithms")
|
||||
async def get_algorithms(request: Request):
|
||||
"""获取支持的算法列表接口"""
|
||||
try:
|
||||
algorithms = [
|
||||
{
|
||||
"id": "grid",
|
||||
"name": "格网法",
|
||||
"description": "将计算区域划分为规则格网,通过插值计算每个格网的高程变化,适合平坦或规则地形",
|
||||
"accuracy": "中等",
|
||||
"performance": "快速",
|
||||
"parameters": {
|
||||
"resolution": {
|
||||
"name": "格网分辨率",
|
||||
"description": "格网大小(米),影响计算精度和性能",
|
||||
"default": 1.0,
|
||||
"range": [0.1, 10.0]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "tin",
|
||||
"name": "三角网法",
|
||||
"description": "基于不规则三角网(TIN)构建地形表面,计算每个三角形的体积变化,适合复杂地形",
|
||||
"accuracy": "高",
|
||||
"performance": "中等",
|
||||
"parameters": {
|
||||
"resolution": {
|
||||
"name": "不适用",
|
||||
"description": "三角网法不使用固定的分辨率参数",
|
||||
"default": None
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "prism",
|
||||
"name": "三棱柱法",
|
||||
"description": "结合三角网和垂直棱柱的高精度算法,计算每个三棱柱的体积,精度最高",
|
||||
"accuracy": "最高",
|
||||
"performance": "较慢",
|
||||
"parameters": {
|
||||
"resolution": {
|
||||
"name": "棱柱宽度",
|
||||
"description": "棱柱宽度(米),影响计算精度",
|
||||
"default": 1.0,
|
||||
"range": [0.1, 5.0]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
return _success_response(algorithms)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取算法列表失败: {str(e)}")
|
||||
return _error_response(f"获取算法列表失败: {str(e)}", 500)
|
||||
|
||||
# 批量计算接口
|
||||
@earthwork_bp.post("/api/v1/calc/earthwork3dTiles/batch")
|
||||
async def batch_calc_earthwork(request: Request):
|
||||
"""批量土方量计算接口"""
|
||||
try:
|
||||
# 1. 接收前端传参
|
||||
data = request.json
|
||||
if not data:
|
||||
return _error_response("请求参数不能为空", 400)
|
||||
|
||||
calculations = data.get("calculations", [])
|
||||
|
||||
if not calculations:
|
||||
return _error_response("计算任务列表不能为空", 400)
|
||||
|
||||
if len(calculations) > 100:
|
||||
return _error_response("批量计算数量超过限制(最多100个)", 400)
|
||||
|
||||
|
||||
|
||||
# 3. 执行批量计算
|
||||
results = []
|
||||
errors = []
|
||||
|
||||
for i, calc_data in enumerate(calculations):
|
||||
try:
|
||||
# 提取参数
|
||||
polygon_coords = calc_data.get("polygonCoords")
|
||||
design_elevation = calc_data.get("designElevation")
|
||||
url = calc_data.get("url")
|
||||
|
||||
if not polygon_coords or design_elevation is None or url is None:
|
||||
errors.append({
|
||||
"index": i,
|
||||
"error": "缺少必要参数"
|
||||
})
|
||||
continue
|
||||
|
||||
# 参数验证
|
||||
if len(polygon_coords) < 3:
|
||||
errors.append({
|
||||
"index": i,
|
||||
"error": "多边形至少需要3个点"
|
||||
})
|
||||
continue
|
||||
|
||||
# 检查多边形是否闭合
|
||||
if polygon_coords[0] != polygon_coords[-1]:
|
||||
polygon_coords.append(polygon_coords[0])
|
||||
|
||||
# 可选参数
|
||||
algorithm = calc_data.get("algorithm", "tin")
|
||||
resolution = calc_data.get("resolution", 1.0)
|
||||
crs = calc_data.get("crs", "EPSG:4326")
|
||||
interpolation_method = calc_data.get("interpolationMethod", "linear")
|
||||
|
||||
# 2. 确保计算器已初始化
|
||||
app_info = init_app(url)
|
||||
calculator_3d_tiles = app_info.get("calculator_3d_tiles")
|
||||
|
||||
# 执行计算
|
||||
algorithm_type = AlgorithmType(algorithm)
|
||||
|
||||
result = await calculator_3d_tiles.calculate(
|
||||
polygon_coords=polygon_coords,
|
||||
design_elevation=design_elevation,
|
||||
algorithm=algorithm_type,
|
||||
resolution=resolution,
|
||||
target_crs=crs,
|
||||
interpolation_method=interpolation_method
|
||||
)
|
||||
|
||||
results.append(result)
|
||||
|
||||
except Exception as e:
|
||||
errors.append({
|
||||
"index": i,
|
||||
"error": str(e),
|
||||
"polygon": polygon_coords if 'polygon_coords' in locals() else None
|
||||
})
|
||||
continue
|
||||
|
||||
# 4. 返回结果
|
||||
batch_result = {
|
||||
"results": results,
|
||||
"errors": errors,
|
||||
"summary": {
|
||||
"total": len(calculations),
|
||||
"success": len(results),
|
||||
"failed": len(errors),
|
||||
"successRate": f"{(len(results)/len(calculations)*100):.1f}%" if calculations else "0%"
|
||||
}
|
||||
}
|
||||
|
||||
message = f"批量计算完成,成功 {len(results)} 个,失败 {len(errors)} 个"
|
||||
return _success_response(batch_result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量计算失败: {str(e)}")
|
||||
return _error_response(f"批量计算失败: {str(e)}", 500)
|
||||
|
||||
# 核心接口:土方量计算-点云
|
||||
@earthwork_bp.post("/api/v1/calc/earthworkPointCloud")
|
||||
async def calc_earthwork_point_cloud(request: Request):
|
||||
try:
|
||||
# 1. 接收前端传参
|
||||
data = request.json
|
||||
if not data:
|
||||
return json({
|
||||
"code": 400,
|
||||
"msg": "请求参数不能为空",
|
||||
"data": None
|
||||
}, status=400)
|
||||
|
||||
polygon_coords = data.get("polygonCoords") # 计算区域多边形坐标
|
||||
design_elev = data.get("designElevation") # 设计高程
|
||||
crs = data.get("crs", "EPSG:4326") # 坐标系,默认WGS84
|
||||
url = data.get("url")
|
||||
if url is None:
|
||||
return _error_response("地图不能为空", 400)
|
||||
|
||||
# 2. 确保计算器已初始化
|
||||
app_info = init_app(url)
|
||||
calculator_point_cloud = app_info.get("calculator_point_cloud")
|
||||
|
||||
result = calculator_point_cloud.calculate_earthwork(polygon_coords=polygon_coords, design_elev=design_elev, crs=crs)
|
||||
|
||||
# 3. 处理结果
|
||||
if not result["success"]:
|
||||
return _error_response(result["error"], 400)
|
||||
|
||||
# 4. 格式化结果
|
||||
formatted_result = calculator_point_cloud.format_result(result)
|
||||
|
||||
# 5. 返回成功响应
|
||||
return _success_response(formatted_result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"服务器错误: {str(e)}")
|
||||
return _error_response(f"服务器内部错误: {str(e)}", 500)
|
||||
|
||||
def _success_response(data: Dict[str, Any]) -> json:
|
||||
"""成功响应"""
|
||||
|
||||
return json({
|
||||
"code": 200,
|
||||
"msg": "计算成功",
|
||||
"data": data
|
||||
})
|
||||
|
||||
def _error_response(message: str, status_code: int = 400) -> json:
|
||||
"""错误响应"""
|
||||
return json({
|
||||
"code": status_code,
|
||||
"msg": message,
|
||||
"data": None
|
||||
}, status=status_code)
|
||||
1647
b3dm/earthwork_calculator_3d_tiles.py
Normal file
691
b3dm/earthwork_calculator_point_cloud.py
Normal file
@ -0,0 +1,691 @@
|
||||
# earthwork_calculator.py
|
||||
import pdal
|
||||
import pyvista as pv
|
||||
import numpy as np
|
||||
import json
|
||||
from typing import List, Tuple, Dict, Any, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
import traceback
|
||||
|
||||
|
||||
class EarthworkAlgorithm(Enum):
|
||||
"""土方量计算算法枚举"""
|
||||
GRID = "grid"
|
||||
TIN = "tin"
|
||||
PRISM = "prism"
|
||||
|
||||
|
||||
@dataclass
|
||||
class EarthworkResultPointCloud:
|
||||
"""土方量计算结果"""
|
||||
cut_volume: float # 挖方量 (m³)
|
||||
fill_volume: float # 填方量 (m³)
|
||||
net_volume: float # 净方量 (m³)
|
||||
area: float # 计算区域面积 (m²)
|
||||
avg_elevation: float # 平均高程
|
||||
min_elevation: float # 最低高程
|
||||
max_elevation: float # 最高高程
|
||||
points_count: int # 使用的点数
|
||||
triangle_count: int = 0 # 三角形数量
|
||||
grid_count: int = 0 # 网格数量(仅GRID算法使用)
|
||||
prism_count: int = 0 # 棱柱体数量(仅PRISM算法使用)
|
||||
bounding_box: Dict[str, List[float]] = field(default_factory=dict) # 边界框
|
||||
volume_accuracy: float = 0.95 # 计算精度
|
||||
algorithm: str = "TIN三角网法" # 使用的算法
|
||||
resolution: float = 1.0 # 计算分辨率
|
||||
algorithm_params: Dict[str, Any] = field(default_factory=dict) # 算法参数
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""转换为字典"""
|
||||
result = {
|
||||
"volume": {
|
||||
"cut": round(self.cut_volume, 3),
|
||||
"fill": round(self.fill_volume, 3),
|
||||
"net": round(self.net_volume, 3),
|
||||
"unit": "m³"
|
||||
},
|
||||
"area": {
|
||||
"value": round(self.area, 3),
|
||||
"unit": "m²"
|
||||
},
|
||||
"elevation": {
|
||||
"average": round(self.avg_elevation, 3),
|
||||
"min": round(self.min_elevation, 3),
|
||||
"max": round(self.max_elevation, 3),
|
||||
"unit": "m"
|
||||
},
|
||||
"statistics": {
|
||||
"points_count": self.points_count,
|
||||
"accuracy": round(self.volume_accuracy, 3),
|
||||
"algorithm": self.algorithm
|
||||
},
|
||||
"bounding_box": self.bounding_box,
|
||||
"calculation_params": {
|
||||
"resolution": self.resolution,
|
||||
"accuracy": self.volume_accuracy,
|
||||
**self.algorithm_params
|
||||
}
|
||||
}
|
||||
|
||||
# 根据算法类型添加特定的统计信息
|
||||
if self.algorithm.startswith("GRID"):
|
||||
result["statistics"]["grid_count"] = self.grid_count
|
||||
elif self.algorithm.startswith("TIN"):
|
||||
result["statistics"]["triangle_count"] = self.triangle_count
|
||||
elif self.algorithm.startswith("PRISM"):
|
||||
result["statistics"]["prism_count"] = self.prism_count
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class EarthworkCalculatorPointCloud:
|
||||
"""土方量计算核心类(支持多种算法)"""
|
||||
|
||||
def __init__(self, point_cloud_path: str = "./data/your_point_cloud.laz"):
|
||||
"""
|
||||
初始化土方量计算器
|
||||
|
||||
Args:
|
||||
point_cloud_path: 点云数据文件路径
|
||||
"""
|
||||
self.point_cloud_path = point_cloud_path
|
||||
|
||||
def validate_inputs(self, polygon_coords: List[List[float]], design_elev: float) -> Tuple[bool, str]:
|
||||
"""验证输入参数"""
|
||||
if not polygon_coords or len(polygon_coords) < 3:
|
||||
return False, "多边形坐标至少需要3个点"
|
||||
|
||||
try:
|
||||
design_elev = float(design_elev)
|
||||
except (TypeError, ValueError):
|
||||
return False, "设计高程必须是有效数字"
|
||||
|
||||
return True, ""
|
||||
|
||||
def create_polygon_string(self, polygon_coords: List[List[float]]) -> str:
|
||||
"""创建PDAL多边形字符串"""
|
||||
coords_list = []
|
||||
for coord in polygon_coords:
|
||||
if len(coord) >= 2:
|
||||
coords_list.append(f"{coord[0]} {coord[1]}")
|
||||
|
||||
# 确保多边形闭合
|
||||
if coords_list and coords_list[0] != coords_list[-1]:
|
||||
coords_list.append(coords_list[0])
|
||||
|
||||
return "POLYGON((" + ", ".join(coords_list) + "))"
|
||||
|
||||
def calculate_bounding_box(self, points: np.ndarray) -> Dict[str, List[float]]:
|
||||
"""
|
||||
计算边界框
|
||||
|
||||
Args:
|
||||
points: 点云坐标数组
|
||||
|
||||
Returns:
|
||||
Dict: 边界框信息
|
||||
"""
|
||||
if len(points) == 0:
|
||||
return {"min": [0, 0, 0], "max": [0, 0, 0]}
|
||||
|
||||
min_vals = np.min(points, axis=0)
|
||||
max_vals = np.max(points, axis=0)
|
||||
|
||||
return {
|
||||
"min": [float(min_vals[0]), float(min_vals[1]), float(min_vals[2])],
|
||||
"max": [float(max_vals[0]), float(max_vals[1]), float(max_vals[2])]
|
||||
}
|
||||
|
||||
def clip_point_cloud(self, polygon_coords: List[List[float]], crs: str = "EPSG:4326") -> pv.PolyData:
|
||||
"""
|
||||
裁剪点云数据
|
||||
|
||||
Args:
|
||||
polygon_coords: 多边形坐标列表
|
||||
crs: 坐标系
|
||||
|
||||
Returns:
|
||||
pyvista.PolyData: 裁剪后的点云数据
|
||||
"""
|
||||
polygon_str = self.create_polygon_string(polygon_coords)
|
||||
|
||||
# PDAL管道配置
|
||||
pipeline_config = {
|
||||
"pipeline": [
|
||||
{
|
||||
"type": "readers.las",
|
||||
"filename": self.point_cloud_path,
|
||||
"spatialreference": crs
|
||||
},
|
||||
{
|
||||
"type": "filters.crop",
|
||||
"polygon": polygon_str
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# 执行PDAL管道
|
||||
pipeline = pdal.Pipeline(json.dumps(pipeline_config))
|
||||
|
||||
try:
|
||||
pipeline.execute()
|
||||
|
||||
if len(pipeline.arrays) == 0:
|
||||
raise ValueError("多边形区域内没有找到点云数据")
|
||||
|
||||
# 获取裁剪后的点云数据
|
||||
points = pipeline.arrays[0]
|
||||
x = points["X"]
|
||||
y = points["Y"]
|
||||
z = points["Z"]
|
||||
|
||||
return pv.PolyData(np.column_stack((x, y, z)))
|
||||
|
||||
except RuntimeError as e:
|
||||
print(f"PDAL执行失败: {str(e)}")
|
||||
# 如果没有PDAL数据,生成模拟数据用于测试
|
||||
return self.generate_mock_point_cloud(polygon_coords)
|
||||
|
||||
def generate_mock_point_cloud(self, polygon_coords: List[List[float]]) -> pv.PolyData:
|
||||
"""
|
||||
生成模拟点云数据(仅用于测试)
|
||||
|
||||
Args:
|
||||
polygon_coords: 多边形坐标列表
|
||||
|
||||
Returns:
|
||||
pyvista.PolyData: 模拟点云数据
|
||||
"""
|
||||
print("使用模拟数据进行测试...")
|
||||
|
||||
n_points = 1000
|
||||
|
||||
# 获取多边形边界
|
||||
x_coords = [c[0] for c in polygon_coords]
|
||||
y_coords = [c[1] for c in polygon_coords]
|
||||
|
||||
x_min, x_max = min(x_coords), max(x_coords)
|
||||
y_min, y_max = min(y_coords), max(y_coords)
|
||||
|
||||
# 生成随机点
|
||||
x = np.random.uniform(x_min, x_max, n_points)
|
||||
y = np.random.uniform(y_min, y_max, n_points)
|
||||
z = np.random.uniform(100, 120, n_points) # 模拟高程在100-120米之间
|
||||
|
||||
return pv.PolyData(np.column_stack((x, y, z)))
|
||||
|
||||
def create_tin_mesh(self, point_cloud: pv.PolyData) -> pv.PolyData:
|
||||
"""
|
||||
创建三角网(TIN算法使用)
|
||||
|
||||
Args:
|
||||
point_cloud: 点云数据
|
||||
|
||||
Returns:
|
||||
pyvista.PolyData: 三角网格
|
||||
"""
|
||||
if len(point_cloud.points) < 3:
|
||||
raise ValueError("点云数据不足,无法构网")
|
||||
|
||||
try:
|
||||
return point_cloud.delaunay_2d()
|
||||
except Exception as e:
|
||||
raise ValueError(f"三角网构网失败: {str(e)}")
|
||||
|
||||
def calculate_volumes_by_tin(self, mesh: pv.PolyData, design_elev: float) -> Dict[str, Any]:
|
||||
"""
|
||||
TIN算法计算土方量
|
||||
|
||||
Args:
|
||||
mesh: 三角网格
|
||||
design_elev: 设计高程
|
||||
|
||||
Returns:
|
||||
Dict: 包含体积计算结果和额外信息
|
||||
"""
|
||||
points = mesh.points
|
||||
elev_diff = points[:, 2] - design_elev
|
||||
|
||||
cut_volume = 0.0
|
||||
fill_volume = 0.0
|
||||
triangle_count = 0
|
||||
|
||||
# 遍历所有三角形面片计算体积
|
||||
cells = mesh.cells.reshape(-1, 4)
|
||||
|
||||
if len(cells) == 0:
|
||||
raise ValueError("无法生成有效的三角网")
|
||||
|
||||
for cell in cells:
|
||||
if cell[0] == 3: # 三角形(VTK格式:3个顶点)
|
||||
triangle_count += 1
|
||||
vertex_indices = cell[1:]
|
||||
pts = points[vertex_indices]
|
||||
|
||||
# 计算三角形面积
|
||||
v1 = pts[1] - pts[0]
|
||||
v2 = pts[2] - pts[0]
|
||||
area = 0.5 * np.linalg.norm(np.cross(v1, v2))
|
||||
|
||||
# 计算平均高程差
|
||||
avg_diff = np.mean(elev_diff[vertex_indices])
|
||||
vol = area * avg_diff
|
||||
|
||||
if vol > 0:
|
||||
cut_volume += vol
|
||||
else:
|
||||
fill_volume += abs(vol)
|
||||
|
||||
return {
|
||||
"cut_volume": cut_volume,
|
||||
"fill_volume": fill_volume,
|
||||
"net_volume": cut_volume - fill_volume,
|
||||
"triangle_count": triangle_count
|
||||
}
|
||||
|
||||
def calculate_volumes_by_grid(self, point_cloud: pv.PolyData, design_elev: float,
|
||||
grid_size: float = 1.0) -> Dict[str, Any]:
|
||||
"""
|
||||
GRID算法计算土方量
|
||||
|
||||
Args:
|
||||
point_cloud: 点云数据
|
||||
design_elev: 设计高程
|
||||
grid_size: 网格尺寸
|
||||
|
||||
Returns:
|
||||
Dict: 包含体积计算结果和额外信息
|
||||
"""
|
||||
points = point_cloud.points
|
||||
|
||||
if len(points) == 0:
|
||||
raise ValueError("点云数据为空")
|
||||
|
||||
# 计算边界
|
||||
x_min, y_min = np.min(points[:, :2], axis=0)
|
||||
x_max, y_max = np.max(points[:, :2], axis=0)
|
||||
|
||||
# 创建网格
|
||||
x_edges = np.arange(x_min, x_max + grid_size, grid_size)
|
||||
y_edges = np.arange(y_min, y_max + grid_size, grid_size)
|
||||
|
||||
grid_count = (len(x_edges) - 1) * (len(y_edges) - 1)
|
||||
|
||||
cut_volume = 0.0
|
||||
fill_volume = 0.0
|
||||
|
||||
# 对每个网格计算土方量
|
||||
for i in range(len(x_edges) - 1):
|
||||
for j in range(len(y_edges) - 1):
|
||||
# 获取当前网格内的点
|
||||
mask = (points[:, 0] >= x_edges[i]) & (points[:, 0] < x_edges[i+1]) & \
|
||||
(points[:, 1] >= y_edges[j]) & (points[:, 1] < y_edges[j+1])
|
||||
|
||||
grid_points = points[mask]
|
||||
|
||||
if len(grid_points) > 0:
|
||||
# 计算网格内点的平均高程
|
||||
avg_elevation = np.mean(grid_points[:, 2])
|
||||
|
||||
# 计算高程差
|
||||
elev_diff = avg_elevation - design_elev
|
||||
|
||||
# 计算体积
|
||||
cell_area = grid_size * grid_size
|
||||
vol = cell_area * elev_diff
|
||||
|
||||
if vol > 0:
|
||||
cut_volume += vol
|
||||
else:
|
||||
fill_volume += abs(vol)
|
||||
|
||||
return {
|
||||
"cut_volume": cut_volume,
|
||||
"fill_volume": fill_volume,
|
||||
"net_volume": cut_volume - fill_volume,
|
||||
"grid_count": grid_count
|
||||
}
|
||||
|
||||
def calculate_volumes_by_prism(self, point_cloud: pv.PolyData, design_elev: float,
|
||||
influence_radius: float = 0.5) -> Dict[str, Any]:
|
||||
"""
|
||||
PRISM算法计算土方量
|
||||
|
||||
Args:
|
||||
point_cloud: 点云数据
|
||||
design_elev: 设计高程
|
||||
influence_radius: 影响半径
|
||||
|
||||
Returns:
|
||||
Dict: 包含体积计算结果和额外信息
|
||||
"""
|
||||
points = point_cloud.points
|
||||
|
||||
if len(points) == 0:
|
||||
raise ValueError("点云数据为空")
|
||||
|
||||
cut_volume = 0.0
|
||||
fill_volume = 0.0
|
||||
|
||||
# 每个点的影响面积
|
||||
influence_area = np.pi * influence_radius ** 2
|
||||
prism_count = len(points)
|
||||
|
||||
for point in points:
|
||||
# 计算高程差
|
||||
elev_diff = point[2] - design_elev
|
||||
|
||||
# 计算体积
|
||||
vol = influence_area * elev_diff
|
||||
|
||||
if vol > 0:
|
||||
cut_volume += vol
|
||||
else:
|
||||
fill_volume += abs(vol)
|
||||
|
||||
return {
|
||||
"cut_volume": cut_volume,
|
||||
"fill_volume": fill_volume,
|
||||
"net_volume": cut_volume - fill_volume,
|
||||
"prism_count": prism_count
|
||||
}
|
||||
|
||||
def calculate_statistics(self, point_cloud: pv.PolyData, mesh: pv.PolyData = None) -> Dict[str, float]:
|
||||
"""
|
||||
计算统计数据
|
||||
|
||||
Args:
|
||||
point_cloud: 点云数据
|
||||
mesh: 三角网格(仅TIN算法需要)
|
||||
|
||||
Returns:
|
||||
Dict: 统计结果
|
||||
"""
|
||||
elevations = point_cloud.points[:, 2]
|
||||
|
||||
stats = {
|
||||
"area": 0.0,
|
||||
"max_elevation": np.max(elevations) if len(elevations) > 0 else 0.0,
|
||||
"min_elevation": np.min(elevations) if len(elevations) > 0 else 0.0,
|
||||
"avg_elevation": np.mean(elevations) if len(elevations) > 0 else 0.0,
|
||||
"points_count": len(point_cloud.points)
|
||||
}
|
||||
|
||||
# 计算面积
|
||||
if mesh is not None:
|
||||
stats["area"] = mesh.area
|
||||
else:
|
||||
# 对于非TIN算法,使用多边形面积近似
|
||||
if len(point_cloud.points) > 0:
|
||||
# 使用点云的凸包面积
|
||||
try:
|
||||
hull = point_cloud.delaunay_2d()
|
||||
stats["area"] = hull.area
|
||||
except:
|
||||
# 如果无法计算凸包,使用边界框面积
|
||||
x_min, x_max = np.min(point_cloud.points[:, 0]), np.max(point_cloud.points[:, 0])
|
||||
y_min, y_max = np.min(point_cloud.points[:, 1]), np.max(point_cloud.points[:, 1])
|
||||
stats["area"] = (x_max - x_min) * (y_max - y_min)
|
||||
|
||||
return stats
|
||||
|
||||
def calculate_earthwork(self,
|
||||
polygon_coords: List[List[float]],
|
||||
design_elev: float,
|
||||
algorithm: str = EarthworkAlgorithm.TIN.value,
|
||||
algorithm_params: Optional[Dict[str, Any]] = None,
|
||||
crs: str = "EPSG:4326",
|
||||
volume_accuracy: Optional[float] = None,
|
||||
resolution: Optional[float] = None) -> EarthworkResultPointCloud:
|
||||
"""
|
||||
主计算方法:执行完整的土方量计算流程
|
||||
|
||||
Args:
|
||||
polygon_coords: 多边形坐标列表
|
||||
design_elev: 设计高程
|
||||
algorithm: 计算算法,可选值:'grid', 'tin', 'prism'
|
||||
algorithm_params: 算法特定参数
|
||||
crs: 坐标系
|
||||
volume_accuracy: 计算精度(0-1之间)
|
||||
resolution: 计算分辨率
|
||||
|
||||
Returns:
|
||||
EarthworkResultPointCloud: 计算结果
|
||||
"""
|
||||
try:
|
||||
# 1. 验证输入
|
||||
is_valid, message = self.validate_inputs(polygon_coords, design_elev)
|
||||
if not is_valid:
|
||||
raise ValueError(message)
|
||||
|
||||
design_elev = float(design_elev)
|
||||
|
||||
# 2. 验证算法参数
|
||||
if algorithm not in [a.value for a in EarthworkAlgorithm]:
|
||||
raise ValueError(f"不支持的算法: {algorithm}。支持的算法: {[a.value for a in EarthworkAlgorithm]}")
|
||||
|
||||
# 3. 设置默认参数
|
||||
if algorithm_params is None:
|
||||
algorithm_params = {}
|
||||
|
||||
# 4. 裁剪点云
|
||||
point_cloud = self.clip_point_cloud(polygon_coords, crs)
|
||||
|
||||
# 5. 根据算法选择计算方法
|
||||
algorithm_name = ""
|
||||
mesh = None
|
||||
|
||||
if algorithm == EarthworkAlgorithm.TIN.value:
|
||||
algorithm_name = "TIN三角网法"
|
||||
mesh = self.create_tin_mesh(point_cloud)
|
||||
volumes = self.calculate_volumes_by_tin(mesh, design_elev)
|
||||
algorithm_params = {
|
||||
"grid_size": algorithm_params.get("grid_size", 1.0)
|
||||
}
|
||||
|
||||
elif algorithm == EarthworkAlgorithm.GRID.value:
|
||||
algorithm_name = "GRID格网法"
|
||||
grid_size = algorithm_params.get("grid_size", 1.0)
|
||||
algorithm_name = f"GRID格网法(网格尺寸={grid_size}m)"
|
||||
volumes = self.calculate_volumes_by_grid(point_cloud, design_elev, grid_size)
|
||||
algorithm_params = {
|
||||
"grid_size": grid_size
|
||||
}
|
||||
|
||||
elif algorithm == EarthworkAlgorithm.PRISM.value:
|
||||
algorithm_name = "PRISM棱柱体法"
|
||||
influence_radius = algorithm_params.get("influence_radius", 0.5)
|
||||
algorithm_name = f"PRISM棱柱体法(影响半径={influence_radius}m)"
|
||||
volumes = self.calculate_volumes_by_prism(point_cloud, design_elev, influence_radius)
|
||||
algorithm_params = {
|
||||
"influence_radius": influence_radius
|
||||
}
|
||||
|
||||
# 6. 计算统计数据
|
||||
stats = self.calculate_statistics(point_cloud, mesh)
|
||||
|
||||
# 7. 计算边界框
|
||||
bounding_box = self.calculate_bounding_box(point_cloud.points)
|
||||
|
||||
# 8. 计算或使用传入的精度和分辨率
|
||||
if volume_accuracy is None:
|
||||
# 根据算法和点云密度自动估算精度
|
||||
volume_accuracy = self.estimate_accuracy(algorithm, point_cloud)
|
||||
|
||||
if resolution is None:
|
||||
# 根据点云密度自动估算分辨率
|
||||
resolution = self.estimate_resolution(point_cloud)
|
||||
|
||||
# 9. 创建EarthworkResultPointCloud对象
|
||||
result = EarthworkResultPointCloud(
|
||||
cut_volume=volumes["cut_volume"],
|
||||
fill_volume=volumes["fill_volume"],
|
||||
net_volume=volumes["net_volume"],
|
||||
area=stats["area"],
|
||||
avg_elevation=stats["avg_elevation"],
|
||||
min_elevation=stats["min_elevation"],
|
||||
max_elevation=stats["max_elevation"],
|
||||
points_count=stats["points_count"],
|
||||
triangle_count=volumes.get("triangle_count", 0),
|
||||
grid_count=volumes.get("grid_count", 0),
|
||||
prism_count=volumes.get("prism_count", 0),
|
||||
bounding_box=bounding_box,
|
||||
volume_accuracy=volume_accuracy,
|
||||
algorithm=algorithm_name,
|
||||
resolution=resolution,
|
||||
algorithm_params=algorithm_params
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
print(f"计算错误: {str(e)}")
|
||||
print(traceback.format_exc())
|
||||
raise
|
||||
|
||||
def estimate_accuracy(self, algorithm: str, point_cloud: pv.PolyData) -> float:
|
||||
"""根据算法和点云密度估计计算精度"""
|
||||
point_density = len(point_cloud.points) / max(point_cloud.area, 0.1)
|
||||
|
||||
# 基础精度
|
||||
base_accuracy = {
|
||||
EarthworkAlgorithm.TIN.value: 0.95,
|
||||
EarthworkAlgorithm.GRID.value: 0.90,
|
||||
EarthworkAlgorithm.PRISM.value: 0.85
|
||||
}.get(algorithm, 0.90)
|
||||
|
||||
# 根据点云密度调整精度
|
||||
if point_density > 10: # 高密度点云
|
||||
accuracy_boost = min(0.05, point_density * 0.002)
|
||||
elif point_density < 1: # 低密度点云
|
||||
accuracy_boost = -0.05
|
||||
else:
|
||||
accuracy_boost = 0
|
||||
|
||||
estimated_accuracy = base_accuracy + accuracy_boost
|
||||
|
||||
# 确保精度在合理范围内
|
||||
return max(0.7, min(0.99, estimated_accuracy))
|
||||
|
||||
def estimate_resolution(self, point_cloud: pv.PolyData) -> float:
|
||||
"""根据点云密度估计分辨率"""
|
||||
if len(point_cloud.points) < 2:
|
||||
return 1.0
|
||||
|
||||
area = point_cloud.area
|
||||
if area > 0:
|
||||
point_count = len(point_cloud.points)
|
||||
avg_spacing = np.sqrt(area / point_count)
|
||||
return float(round(avg_spacing, 2))
|
||||
|
||||
return 1.0
|
||||
|
||||
def get_algorithm_info(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取支持的算法信息
|
||||
|
||||
Returns:
|
||||
Dict: 算法信息
|
||||
"""
|
||||
return {
|
||||
"supported_algorithms": [
|
||||
{
|
||||
"id": EarthworkAlgorithm.TIN.value,
|
||||
"name": "TIN三角网法",
|
||||
"description": "通过构建不规则三角网计算土方量,精度高,适合复杂地形",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "grid_size",
|
||||
"type": "float",
|
||||
"default": 1.0,
|
||||
"description": "网格尺寸(m),用于点云预处理",
|
||||
"required": False
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": EarthworkAlgorithm.GRID.value,
|
||||
"name": "GRID格网法",
|
||||
"description": "将区域划分为规则网格计算土方量,计算速度快,适合大规模区域",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "grid_size",
|
||||
"type": "float",
|
||||
"default": 1.0,
|
||||
"description": "网格尺寸(m)",
|
||||
"required": True,
|
||||
"min": 0.1,
|
||||
"max": 10.0
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": EarthworkAlgorithm.PRISM.value,
|
||||
"name": "PRISM棱柱体法",
|
||||
"description": "将每个点视为一个棱柱体计算土方量,计算简单快速",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "influence_radius",
|
||||
"type": "float",
|
||||
"default": 0.5,
|
||||
"description": "点的影响半径(m)",
|
||||
"required": True,
|
||||
"min": 0.1,
|
||||
"max": 5.0
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"default_algorithm": EarthworkAlgorithm.TIN.value
|
||||
}
|
||||
|
||||
|
||||
# 使用示例
|
||||
if __name__ == "__main__":
|
||||
# 创建计算器实例
|
||||
calculator = EarthworkCalculatorPointCloud("./data/sample_point_cloud.laz")
|
||||
|
||||
# 获取支持的算法信息
|
||||
algorithm_info = calculator.get_algorithm_info()
|
||||
print("支持的算法:")
|
||||
for algo in algorithm_info["supported_algorithms"]:
|
||||
print(f" {algo['id']}: {algo['name']} - {algo['description']}")
|
||||
|
||||
# 定义多边形区域
|
||||
polygon = [
|
||||
[116.3974, 39.9093],
|
||||
[116.4084, 39.9093],
|
||||
[116.4084, 39.9193],
|
||||
[116.3974, 39.9193]
|
||||
]
|
||||
|
||||
# 设计高程
|
||||
design_elevation = 100.0
|
||||
|
||||
# 测试不同算法
|
||||
algorithms = [
|
||||
(EarthworkAlgorithm.TIN.value, {}, "TIN算法"),
|
||||
(EarthworkAlgorithm.GRID.value, {"grid_size": 2.0}, "GRID算法(2米网格)"),
|
||||
(EarthworkAlgorithm.PRISM.value, {"influence_radius": 1.0}, "PRISM算法(1米影响半径)")
|
||||
]
|
||||
|
||||
for algo_id, params, description in algorithms:
|
||||
print(f"\n使用{description}计算:")
|
||||
try:
|
||||
result = calculator.calculate_earthwork(
|
||||
polygon_coords=polygon,
|
||||
design_elev=design_elevation,
|
||||
algorithm=algo_id,
|
||||
algorithm_params=params
|
||||
)
|
||||
|
||||
print(f" 挖方量: {result.cut_volume:.3f} m³")
|
||||
print(f" 填方量: {result.fill_volume:.3f} m³")
|
||||
print(f" 净方量: {result.net_volume:.3f} m³")
|
||||
print(f" 计算面积: {result.area:.3f} m²")
|
||||
print(f" 计算精度: {result.volume_accuracy:.3%}")
|
||||
print(f" 分辨率: {result.resolution:.2f} m")
|
||||
|
||||
except Exception as e:
|
||||
print(f" 计算失败: {str(e)}")
|
||||
469
b3dm/glb_with_draco.py
Normal file
@ -0,0 +1,469 @@
|
||||
import json
|
||||
import struct
|
||||
import numpy as np
|
||||
import DracoPy
|
||||
|
||||
class DracoGLBParser:
|
||||
"""使用 DracoPy 解析包含 Draco 压缩的 GLB 文件"""
|
||||
|
||||
def __init__(self, glb_file_path):
|
||||
self.glb_file_path = glb_file_path
|
||||
self.json_data = None
|
||||
self.binary_data = None
|
||||
self.decoded_meshes = [] # 缓存解码后的网格数据
|
||||
|
||||
def parse_glb_structure(self):
|
||||
"""解析 GLB 文件结构"""
|
||||
with open(self.glb_file_path, 'rb') as f:
|
||||
# 读取 GLB 头部
|
||||
magic = f.read(4)
|
||||
version = struct.unpack('<I', f.read(4))[0]
|
||||
total_length = struct.unpack('<I', f.read(4))[0]
|
||||
|
||||
print("=" * 60)
|
||||
print(f"GLB 文件分析:")
|
||||
print(f" 文件类型: {magic.decode('utf-8')}")
|
||||
print(f" 版本: {version}")
|
||||
print(f" 总大小: {total_length:,} bytes")
|
||||
|
||||
# 读取 JSON chunk
|
||||
json_length = struct.unpack('<I', f.read(4))[0]
|
||||
json_type = f.read(4)
|
||||
|
||||
if json_type != b'JSON':
|
||||
raise ValueError(f"期望 JSON chunk,但得到: {json_type}")
|
||||
|
||||
self.json_data = json.loads(f.read(json_length).decode('utf-8'))
|
||||
|
||||
print(f" JSON 大小: {json_length:,} bytes")
|
||||
|
||||
# 读取 Binary chunk
|
||||
if f.tell() < total_length:
|
||||
bin_length = struct.unpack('<I', f.read(4))[0]
|
||||
bin_type = f.read(4)
|
||||
|
||||
if bin_type != b'BIN\x00':
|
||||
raise ValueError(f"期望 BIN chunk,但得到: {bin_type}")
|
||||
|
||||
self.binary_data = f.read(bin_length)
|
||||
print(f" 二进制数据大小: {bin_length:,} bytes")
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
return self
|
||||
|
||||
def analyze_structure(self):
|
||||
"""分析 GLB 结构"""
|
||||
if not self.json_data:
|
||||
self.parse_glb_structure()
|
||||
|
||||
print("\nGLB 结构分析:")
|
||||
print("-" * 40)
|
||||
|
||||
# 基本信息
|
||||
asset = self.json_data.get('asset', {})
|
||||
print(f"生成器: {asset.get('generator', '未知')}")
|
||||
print(f"glTF 版本: {asset.get('version', '未知')}")
|
||||
|
||||
# Draco 扩展
|
||||
extensions_used = self.json_data.get('extensionsUsed', [])
|
||||
extensions_required = self.json_data.get('extensionsRequired', [])
|
||||
|
||||
if 'KHR_draco_mesh_compression' in extensions_used:
|
||||
print("使用 Draco 压缩")
|
||||
if 'KHR_draco_mesh_compression' in extensions_required:
|
||||
print("Draco 压缩是必需的")
|
||||
|
||||
# 网格信息
|
||||
meshes = self.json_data.get('meshes', [])
|
||||
print(f"\n网格数量: {len(meshes)}")
|
||||
|
||||
for i, mesh in enumerate(meshes):
|
||||
print(f" 网格 {i}: {mesh.get('name', '未命名')}")
|
||||
primitives = mesh.get('primitives', [])
|
||||
print(f" 图元数量: {len(primitives)}")
|
||||
|
||||
for j, primitive in enumerate(primitives):
|
||||
print(f" 图元 {j}:")
|
||||
if 'extensions' in primitive:
|
||||
draco_info = primitive['extensions'].get('KHR_draco_mesh_compression')
|
||||
if draco_info:
|
||||
print(f" 使用 Draco 压缩")
|
||||
print(f" 属性: {draco_info.get('attributes', {})}")
|
||||
|
||||
# 缓冲区信息
|
||||
buffers = self.json_data.get('buffers', [])
|
||||
buffer_views = self.json_data.get('bufferViews', [])
|
||||
accessors = self.json_data.get('accessors', [])
|
||||
|
||||
print(f"\n缓冲区: {len(buffers)}")
|
||||
print(f"BufferViews: {len(buffer_views)}")
|
||||
print(f"访问器: {len(accessors)}")
|
||||
|
||||
return self
|
||||
|
||||
def decode_draco_meshes(self):
|
||||
"""解码所有 Draco 压缩的网格"""
|
||||
if not self.json_data:
|
||||
self.parse_glb_structure()
|
||||
|
||||
meshes = []
|
||||
buffer_views = self.json_data.get('bufferViews', [])
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("开始解码 Draco 压缩数据...")
|
||||
print("=" * 60)
|
||||
|
||||
for mesh_idx, mesh in enumerate(self.json_data.get('meshes', [])):
|
||||
mesh_name = mesh.get('name', f'mesh_{mesh_idx}')
|
||||
|
||||
for primitive_idx, primitive in enumerate(mesh.get('primitives', [])):
|
||||
if 'extensions' in primitive:
|
||||
draco_info = primitive['extensions'].get('KHR_draco_mesh_compression')
|
||||
|
||||
if draco_info:
|
||||
print(f"\n解码: {mesh_name} - 图元 {primitive_idx}")
|
||||
|
||||
# 解码 Draco 数据
|
||||
mesh_data = self._decode_primitive(draco_info, buffer_views)
|
||||
|
||||
if mesh_data:
|
||||
meshes.append({
|
||||
'mesh_idx': mesh_idx,
|
||||
'primitive_idx': primitive_idx,
|
||||
'name': mesh_name,
|
||||
**mesh_data
|
||||
})
|
||||
|
||||
self.decoded_meshes = meshes # 缓存解码结果
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(f"解码完成!共解码 {len(meshes)} 个网格")
|
||||
print("=" * 60)
|
||||
|
||||
return meshes
|
||||
|
||||
def get_vertices(self, mesh_idx=0, primitive_idx=0):
|
||||
"""
|
||||
获取指定网格的顶点集合
|
||||
|
||||
参数:
|
||||
mesh_idx: 网格索引,默认0
|
||||
primitive_idx: 图元索引,默认0
|
||||
|
||||
返回:
|
||||
np.array: 顶点数组,形状为 (n, 3) 或 None
|
||||
"""
|
||||
# 如果还没有解码数据,先解码
|
||||
if not self.decoded_meshes:
|
||||
self.decode_draco_meshes()
|
||||
|
||||
# 查找指定网格
|
||||
for mesh in self.decoded_meshes:
|
||||
if mesh['mesh_idx'] == mesh_idx and mesh['primitive_idx'] == primitive_idx:
|
||||
return mesh.get('vertices')
|
||||
|
||||
print(f"未找到网格 {mesh_idx} 的图元 {primitive_idx}")
|
||||
return None
|
||||
|
||||
def get_all_vertices(self):
|
||||
"""
|
||||
获取所有网格的所有顶点,合并成一个数组
|
||||
|
||||
返回:
|
||||
np.array: 所有顶点的合并数组,形状为 (n, 3) 或 None
|
||||
"""
|
||||
# 如果还没有解码数据,先解码
|
||||
if not self.decoded_meshes:
|
||||
self.decode_draco_meshes()
|
||||
|
||||
if not self.decoded_meshes:
|
||||
print("没有解码的网格数据")
|
||||
return None
|
||||
|
||||
# 收集所有顶点
|
||||
all_vertices = []
|
||||
for mesh in self.decoded_meshes:
|
||||
if mesh.get('vertices') is not None:
|
||||
all_vertices.append(mesh['vertices'])
|
||||
|
||||
if not all_vertices:
|
||||
return None
|
||||
|
||||
# 合并所有顶点
|
||||
return np.vstack(all_vertices)
|
||||
|
||||
def get_vertices_by_mesh_name(self, mesh_name):
|
||||
"""
|
||||
根据网格名称获取顶点集合
|
||||
|
||||
参数:
|
||||
mesh_name: 网格名称
|
||||
|
||||
返回:
|
||||
list: 包含所有匹配网格的顶点数组列表
|
||||
"""
|
||||
# 如果还没有解码数据,先解码
|
||||
if not self.decoded_meshes:
|
||||
self.decode_draco_meshes()
|
||||
|
||||
vertices_list = []
|
||||
for mesh in self.decoded_meshes:
|
||||
if mesh['name'] == mesh_name and mesh.get('vertices') is not None:
|
||||
vertices_list.append(mesh['vertices'])
|
||||
|
||||
return vertices_list
|
||||
|
||||
def get_vertex_count(self):
|
||||
"""
|
||||
获取总顶点数
|
||||
|
||||
返回:
|
||||
int: 所有网格的总顶点数
|
||||
"""
|
||||
vertices = self.get_all_vertices()
|
||||
return len(vertices) if vertices is not None else 0
|
||||
|
||||
def _decode_primitive(self, draco_info, buffer_views):
|
||||
"""解码单个图元的 Draco 数据"""
|
||||
try:
|
||||
# 获取 bufferView 信息
|
||||
buffer_view_idx = draco_info['bufferView']
|
||||
attributes = draco_info['attributes']
|
||||
|
||||
buffer_view = buffer_views[buffer_view_idx]
|
||||
byte_offset = buffer_view.get('byteOffset', 0)
|
||||
byte_length = buffer_view['byteLength']
|
||||
|
||||
print(f" BufferView: {buffer_view_idx}")
|
||||
print(f" 属性映射: {attributes}")
|
||||
print(f" 数据位置: offset={byte_offset}, length={byte_length}")
|
||||
|
||||
# 提取 Draco 压缩数据
|
||||
draco_data = self.binary_data[byte_offset:byte_offset + byte_length]
|
||||
print(f" Draco 数据大小: {len(draco_data):,} bytes")
|
||||
|
||||
# 使用 DracoPy 解码
|
||||
print(" 正在使用 DracoPy 解码...")
|
||||
draco_decoder = DracoPy.decode(draco_data)
|
||||
|
||||
# 解析解码结果
|
||||
mesh_data = self._parse_draco_result(draco_decoder, attributes)
|
||||
|
||||
return mesh_data
|
||||
|
||||
except Exception as e:
|
||||
print(f" 解码失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
def _parse_draco_result(self, draco_decoder, attributes):
|
||||
"""解析 DracoPy 解码结果"""
|
||||
result = {
|
||||
'vertices': None,
|
||||
'faces': None,
|
||||
'texcoords': None,
|
||||
'batch_ids': None,
|
||||
'normals': None,
|
||||
'colors': None
|
||||
}
|
||||
|
||||
# 获取顶点
|
||||
if hasattr(draco_decoder, 'points'):
|
||||
result['vertices'] = np.array(draco_decoder.points, dtype=np.float32)
|
||||
print(f" 顶点数量: {len(result['vertices'])}")
|
||||
|
||||
# 获取面/三角形
|
||||
if hasattr(draco_decoder, 'faces'):
|
||||
faces_data = draco_decoder.faces
|
||||
|
||||
# 确保是三角形(每面3个顶点)
|
||||
if len(faces_data) > 0:
|
||||
if isinstance(faces_data[0], list) or isinstance(faces_data[0], tuple):
|
||||
# 如果是列表的列表
|
||||
result['faces'] = np.array(faces_data, dtype=np.uint32)
|
||||
else:
|
||||
# 如果是扁平化的数组
|
||||
result['faces'] = np.array(faces_data, dtype=np.uint32).reshape(-1, 3)
|
||||
|
||||
print(f" 面数量: {len(result['faces']) if result['faces'] is not None else 0}")
|
||||
|
||||
# 获取属性数据
|
||||
if hasattr(draco_decoder, 'attributes'):
|
||||
attrs = draco_decoder.attributes
|
||||
|
||||
# 根据属性映射查找数据
|
||||
for gltf_attr_name, draco_attr_id in attributes.items():
|
||||
if draco_attr_id in attrs:
|
||||
attr_data = attrs[draco_attr_id]
|
||||
|
||||
if gltf_attr_name == 'POSITION':
|
||||
result['vertices'] = np.array(attr_data, dtype=np.float32)
|
||||
elif gltf_attr_name == 'TEXCOORD_0':
|
||||
result['texcoords'] = np.array(attr_data, dtype=np.float32)
|
||||
elif gltf_attr_name == '_BATCHID':
|
||||
result['batch_ids'] = np.array(attr_data, dtype=np.uint32)
|
||||
elif gltf_attr_name == 'NORMAL':
|
||||
result['normals'] = np.array(attr_data, dtype=np.float32)
|
||||
elif gltf_attr_name == 'COLOR_0':
|
||||
result['colors'] = np.array(attr_data, dtype=np.float32)
|
||||
|
||||
print(f" 已提取属性: {gltf_attr_name} (ID: {draco_attr_id})")
|
||||
|
||||
# 如果没有通过attributes获取到顶点,尝试其他方式
|
||||
if result['vertices'] is None and hasattr(draco_decoder, 'get_points'):
|
||||
try:
|
||||
result['vertices'] = np.array(draco_decoder.get_points(), dtype=np.float32)
|
||||
except:
|
||||
pass
|
||||
|
||||
return result
|
||||
|
||||
def save_decoded_meshes(self, meshes, output_format='obj'):
|
||||
"""保存解码后的网格"""
|
||||
import os
|
||||
|
||||
base_name = os.path.splitext(os.path.basename(self.glb_file_path))[0]
|
||||
output_dir = f"{base_name}_decoded"
|
||||
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
for mesh in meshes:
|
||||
filename = f"{output_dir}/{mesh['name']}_p{mesh['primitive_idx']}.{output_format}"
|
||||
|
||||
if output_format == 'obj':
|
||||
self._save_as_obj(mesh, filename)
|
||||
elif output_format == 'ply':
|
||||
self._save_as_ply(mesh, filename)
|
||||
elif output_format == 'npz':
|
||||
self._save_as_npz(mesh, filename)
|
||||
else:
|
||||
print(f"不支持的格式: {output_format}")
|
||||
continue
|
||||
|
||||
print(f"已保存: {filename}")
|
||||
|
||||
def _save_as_obj(self, mesh, filename):
|
||||
"""保存为 OBJ 格式"""
|
||||
with open(filename, 'w') as f:
|
||||
# 写入顶点
|
||||
if mesh['vertices'] is not None:
|
||||
for v in mesh['vertices']:
|
||||
f.write(f"v {v[0]} {v[1]} {v[2]}\n")
|
||||
|
||||
# 写入纹理坐标
|
||||
if mesh['texcoords'] is not None:
|
||||
for uv in mesh['texcoords']:
|
||||
f.write(f"vt {uv[0]} {uv[1]}\n")
|
||||
|
||||
# 写入法线
|
||||
if mesh['normals'] is not None:
|
||||
for n in mesh['normals']:
|
||||
f.write(f"vn {n[0]} {n[1]} {n[2]}\n")
|
||||
|
||||
# 写入面
|
||||
if mesh['faces'] is not None:
|
||||
for face in mesh['faces']:
|
||||
# OBJ 索引从1开始
|
||||
face_indices = [str(idx + 1) for idx in face]
|
||||
f.write(f"f {' '.join(face_indices)}\n")
|
||||
|
||||
def _save_as_ply(self, mesh, filename):
|
||||
"""保存为 PLY 格式"""
|
||||
from plyfile import PlyData, PlyElement
|
||||
import numpy as np
|
||||
|
||||
vertices = mesh['vertices']
|
||||
faces = mesh['faces']
|
||||
|
||||
if vertices is None:
|
||||
return
|
||||
|
||||
# 创建顶点数据
|
||||
vertex_data = np.zeros(len(vertices), dtype=[
|
||||
('x', 'f4'), ('y', 'f4'), ('z', 'f4')
|
||||
])
|
||||
|
||||
vertex_data['x'] = vertices[:, 0]
|
||||
vertex_data['y'] = vertices[:, 1]
|
||||
vertex_data['z'] = vertices[:, 2]
|
||||
|
||||
# 创建面数据
|
||||
if faces is not None:
|
||||
face_data = np.zeros(len(faces), dtype=[('vertex_indices', 'i4', (3,))])
|
||||
face_data['vertex_indices'] = faces
|
||||
|
||||
# 写入文件
|
||||
vertex_element = PlyElement.describe(vertex_data, 'vertex')
|
||||
|
||||
if faces is not None:
|
||||
face_element = PlyElement.describe(face_data, 'face')
|
||||
PlyData([vertex_element, face_element], text=False).write(filename)
|
||||
else:
|
||||
PlyData([vertex_element], text=False).write(filename)
|
||||
|
||||
def _save_as_npz(self, mesh, filename):
|
||||
"""保存为 NPZ 格式"""
|
||||
np.savez(
|
||||
filename,
|
||||
vertices=mesh['vertices'],
|
||||
faces=mesh['faces'],
|
||||
texcoords=mesh['texcoords'],
|
||||
batch_ids=mesh['batch_ids'],
|
||||
normals=mesh['normals'],
|
||||
colors=mesh['colors']
|
||||
)
|
||||
|
||||
# 使用示例
|
||||
def main():
|
||||
# 初始化解析器
|
||||
parser = DracoGLBParser(r"D:\devForBdzlWork\ai_project_v1\b3dm\test\temp_glb\temp_6e895637.glb")
|
||||
|
||||
# 解析 GLB 结构
|
||||
parser.parse_glb_structure()
|
||||
|
||||
# 分析结构
|
||||
parser.analyze_structure()
|
||||
|
||||
# 解码 Draco 网格
|
||||
meshes = parser.decode_draco_meshes()
|
||||
|
||||
# 使用新增的顶点获取方法
|
||||
print("\n" + "=" * 60)
|
||||
print("顶点获取方法演示:")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. 获取第一个网格的第一个图元的顶点
|
||||
vertices = parser.get_vertices(mesh_idx=0, primitive_idx=0)
|
||||
if vertices is not None:
|
||||
print(f"1. 获取第一个网格顶点:")
|
||||
print(f" 形状: {vertices.shape}")
|
||||
print(f" 数据类型: {vertices.dtype}")
|
||||
print(f" 前5个顶点: \n{vertices[:5]}")
|
||||
|
||||
# 2. 获取所有顶点(合并)
|
||||
all_vertices = parser.get_all_vertices()
|
||||
if all_vertices is not None:
|
||||
print(f"\n2. 获取所有网格顶点(合并):")
|
||||
print(f" 总顶点数: {len(all_vertices)}")
|
||||
print(f" 形状: {all_vertices.shape}")
|
||||
|
||||
# 3. 获取总顶点数
|
||||
total_vertices = parser.get_vertex_count()
|
||||
print(f"\n3. 总顶点数: {total_vertices}")
|
||||
|
||||
# 4. 根据网格名称获取顶点
|
||||
if meshes:
|
||||
mesh_name = meshes[0]['name']
|
||||
vertices_list = parser.get_vertices_by_mesh_name(mesh_name)
|
||||
print(f"\n4. 根据名称 '{mesh_name}' 获取的顶点:")
|
||||
for i, verts in enumerate(vertices_list):
|
||||
print(f" 图元 {i}: {verts.shape if verts is not None else 'None'}")
|
||||
|
||||
# 保存解码后的网格
|
||||
parser.save_decoded_meshes(meshes, output_format='obj')
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
380
b3dm/slope_aspect_img.py
Normal file
@ -0,0 +1,380 @@
|
||||
# slope_aspect_img13
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from scipy.ndimage import sobel
|
||||
import matplotlib as mpl
|
||||
import rasterio
|
||||
from osgeo import gdal
|
||||
from matplotlib.patches import FancyArrowPatch
|
||||
from mpl_toolkits.mplot3d import proj3d
|
||||
import os
|
||||
|
||||
# 自定义3D箭头类
|
||||
class Arrow3D(FancyArrowPatch):
|
||||
def __init__(self, xs, ys, zs, *args, **kwargs):
|
||||
super().__init__((0,0), (0,0), *args, **kwargs)
|
||||
self._verts3d = xs, ys, zs
|
||||
|
||||
def do_3d_projection(self, renderer=None):
|
||||
xs3d, ys3d, zs3d = self._verts3d
|
||||
xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
|
||||
self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
|
||||
return min(zs)
|
||||
|
||||
def read_dem_rasterio(filepath):
|
||||
"""使用rasterio读取DEM数据"""
|
||||
try:
|
||||
with rasterio.open(filepath) as src:
|
||||
dem_data = src.read(1)
|
||||
transform = src.transform
|
||||
bounds = src.bounds
|
||||
crs = src.crs
|
||||
|
||||
print(f"DEM信息:")
|
||||
print(f" 尺寸: {dem_data.shape}")
|
||||
print(f" 范围: {bounds}")
|
||||
print(f" 坐标系: {crs}")
|
||||
print(f" 高程范围: {np.nanmin(dem_data):.2f} - {np.nanmax(dem_data):.2f}")
|
||||
|
||||
if np.isnan(dem_data).any():
|
||||
print(" 检测到NaN值")
|
||||
dem_data = np.nan_to_num(dem_data, nan=np.nanmean(dem_data))
|
||||
|
||||
rows, cols = dem_data.shape
|
||||
x = np.linspace(bounds.left, bounds.right, cols)
|
||||
y = np.linspace(bounds.bottom, bounds.top, rows)
|
||||
X, Y = np.meshgrid(x, y)
|
||||
|
||||
return X, Y, dem_data
|
||||
|
||||
except Exception as e:
|
||||
print(f"使用rasterio读取失败: {e}")
|
||||
return None
|
||||
|
||||
def read_slope_aspect_by_dem(dem_path, overall_3d_output_path=None) :
|
||||
# ---------------------- 核心:解决中文乱码配置 ----------------------
|
||||
setup_chinese_font()
|
||||
|
||||
# 尝试读取DEM数据
|
||||
print("正在读取DEM文件...")
|
||||
X, Y, Z = read_dem_rasterio(dem_path)
|
||||
|
||||
if X is None:
|
||||
raise FileNotFoundError(f"无法读取DEM文件: {dem_path}")
|
||||
|
||||
# 检查数据有效性
|
||||
if Z.size == 0 or np.all(Z == 0):
|
||||
raise ValueError("DEM数据无效或全部为0")
|
||||
|
||||
# 重采样
|
||||
if Z.shape[0] > 200 or Z.shape[1] > 200:
|
||||
print(f"原始DEM尺寸较大 ({Z.shape}),进行重采样...")
|
||||
from scipy.ndimage import zoom
|
||||
scale_factor = min(200/Z.shape[0], 200/Z.shape[1])
|
||||
Z = zoom(Z, scale_factor, order=1)
|
||||
X = zoom(X, scale_factor, order=1)
|
||||
Y = zoom(Y, scale_factor, order=1)
|
||||
print(f"重采样后尺寸: {Z.shape}")
|
||||
|
||||
# ---------------------- 2. 计算坡度和坡向 ----------------------
|
||||
print("计算坡度和坡向...")
|
||||
|
||||
dx_pixel = np.abs(X[0,1] - X[0,0]) * 111000
|
||||
dy_pixel = np.abs(Y[1,0] - Y[0,0]) * 111000
|
||||
|
||||
dx = sobel(Z, axis=1) / (2 * dx_pixel)
|
||||
dy = sobel(Z, axis=0) / (2 * dy_pixel)
|
||||
|
||||
slope_rad = np.arctan(np.sqrt(dx**2 + dy**2))
|
||||
slope_deg = slope_rad * (180 / np.pi)
|
||||
|
||||
aspect_rad = np.arctan2(-dx, dy)
|
||||
aspect_deg = aspect_rad * (180 / np.pi)
|
||||
aspect_deg[aspect_deg < 0] += 360
|
||||
|
||||
print(f"坡度范围: {np.min(slope_deg):.2f}° - {np.max(slope_deg):.2f}°")
|
||||
print(f"坡向范围: {np.min(aspect_deg):.2f}° - {np.max(aspect_deg):.2f}°")
|
||||
|
||||
# ---------------------- 3. 计算整体坡向 ----------------------
|
||||
print("\n计算整体坡向...")
|
||||
|
||||
def calculate_overall_aspect(aspect_deg, slope_deg, method='weighted_mean'):
|
||||
"""计算整体坡向"""
|
||||
aspect_rad = np.deg2rad(aspect_deg)
|
||||
|
||||
if method == 'weighted_mean':
|
||||
# 坡度加权平均法
|
||||
u = np.sin(aspect_rad)
|
||||
v = np.cos(aspect_rad)
|
||||
|
||||
weights = slope_deg.flatten()
|
||||
weighted_u = np.nansum(u.flatten() * weights) / np.nansum(weights)
|
||||
weighted_v = np.nansum(v.flatten() * weights) / np.nansum(weights)
|
||||
|
||||
weighted_aspect_rad = np.arctan2(weighted_u, weighted_v)
|
||||
weighted_aspect_deg = np.rad2deg(weighted_aspect_rad)
|
||||
|
||||
if weighted_aspect_deg < 0:
|
||||
weighted_aspect_deg += 360
|
||||
|
||||
weighted_strength = np.sqrt(weighted_u**2 + weighted_v**2)
|
||||
return weighted_aspect_deg, weighted_strength, "坡度加权平均法"
|
||||
|
||||
elif method == 'vector_mean':
|
||||
# 向量平均法
|
||||
u = np.sin(aspect_rad)
|
||||
v = np.cos(aspect_rad)
|
||||
|
||||
mean_u = np.nanmean(u)
|
||||
mean_v = np.nanmean(v)
|
||||
|
||||
mean_aspect_rad = np.arctan2(mean_u, mean_v)
|
||||
mean_aspect_deg = np.rad2deg(mean_aspect_rad)
|
||||
|
||||
if mean_aspect_deg < 0:
|
||||
mean_aspect_deg += 360
|
||||
|
||||
vector_strength = np.sqrt(mean_u**2 + mean_v**2)
|
||||
return mean_aspect_deg, vector_strength, "向量平均法"
|
||||
|
||||
# 使用坡度加权平均法计算整体坡向
|
||||
overall_aspect, overall_strength, method_name = calculate_overall_aspect(aspect_deg, slope_deg, 'weighted_mean')
|
||||
|
||||
# 将整体坡向转换为方向描述
|
||||
if overall_aspect < 22.5 or overall_aspect >= 337.5:
|
||||
overall_direction = "北"
|
||||
elif 22.5 <= overall_aspect < 67.5:
|
||||
overall_direction = "东北"
|
||||
elif 67.5 <= overall_aspect < 112.5:
|
||||
overall_direction = "东"
|
||||
elif 112.5 <= overall_aspect < 157.5:
|
||||
overall_direction = "东南"
|
||||
elif 157.5 <= overall_aspect < 202.5:
|
||||
overall_direction = "南"
|
||||
elif 202.5 <= overall_aspect < 247.5:
|
||||
overall_direction = "西南"
|
||||
elif 247.5 <= overall_aspect < 292.5:
|
||||
overall_direction = "西"
|
||||
else:
|
||||
overall_direction = "西北"
|
||||
|
||||
print(f"整体坡向 ({method_name}):")
|
||||
print(f" 角度: {overall_aspect:.1f}°")
|
||||
print(f" 方向: {overall_direction}")
|
||||
# print(f" 一致性: {overall_strength:.3f}")
|
||||
|
||||
# ---------------------- 5. 3D可视化(俯视图,包含整体坡向和关键点坡向) ----------------------
|
||||
print("\n生成3D俯视图可视化...")
|
||||
fig3d, ax3d = plt.subplots(figsize=(16, 12), subplot_kw={"projection": "3d"})
|
||||
|
||||
# 绘制地形曲面 - 俯视图需要更清晰的地形表现
|
||||
norm = mpl.colors.Normalize(vmin=np.percentile(slope_deg, 5),
|
||||
vmax=np.percentile(slope_deg, 95))
|
||||
|
||||
plot_skip = max(2, Z.shape[0] // 60) # 增加采样密度,使俯视图更清晰
|
||||
X_plot = X[::plot_skip, ::plot_skip]
|
||||
Y_plot = Y[::plot_skip, ::plot_skip]
|
||||
Z_plot = Z[::plot_skip, ::plot_skip]
|
||||
slope_plot = slope_deg[::plot_skip, ::plot_skip]
|
||||
|
||||
surf = ax3d.plot_surface(
|
||||
X_plot, Y_plot, Z_plot,
|
||||
cmap="viridis_r",
|
||||
alpha=0.85, # 增加透明度,使箭头更明显
|
||||
linewidth=0.1, # 很细的网格线
|
||||
facecolors=plt.cm.viridis_r(norm(slope_plot)),
|
||||
zorder=1
|
||||
)
|
||||
|
||||
# ---------------------- 绘制整体坡向箭头(中心位置,红色粗箭头) ----------------------
|
||||
center_x = (X.min() + X.max()) / 2
|
||||
center_y = (Y.min() + Y.max()) / 2
|
||||
|
||||
# 整体坡向箭头长度
|
||||
arrow_length_overall = 0.15 * min(X.max()-X.min(), Y.max()-Y.min())
|
||||
|
||||
# 计算整体坡向箭头方向
|
||||
scale_factor = 1.5
|
||||
overall_aspect_rad = np.deg2rad(overall_aspect)
|
||||
dx_overall = np.sin(overall_aspect_rad) * arrow_length_overall * scale_factor
|
||||
dy_overall = np.cos(overall_aspect_rad) * arrow_length_overall * scale_factor
|
||||
|
||||
# 方案2:如果希望箭头在地形上方一定高度
|
||||
terrain_max_z = Z.max() # 地形最高点
|
||||
float_height = 0.2 * terrain_max_z # 在地形最高点上方20%的高度
|
||||
|
||||
# 绘制整体坡向箭头(红色,粗)
|
||||
arrow_overall = Arrow3D(
|
||||
[center_x, center_x + dx_overall],
|
||||
[center_y, center_y + dy_overall],
|
||||
[terrain_max_z + float_height, terrain_max_z + float_height], # 在地形上方
|
||||
mutation_scale=25, # 稍大一些的箭头
|
||||
lw=5, # 更粗的线
|
||||
arrowstyle='-|>',
|
||||
color='red',
|
||||
alpha=0.98, # 更高的透明度
|
||||
zorder=20 # 更高的绘制顺序
|
||||
)
|
||||
ax3d.add_artist(arrow_overall)
|
||||
|
||||
# 在整体坡向箭头起点添加标记
|
||||
ax3d.scatter(center_x, center_y, terrain_max_z + float_height,
|
||||
color='red', s=120, edgecolor='white', linewidth=2, zorder=21)
|
||||
|
||||
# ---------------------- 整体坡向面板放置在左上角 ----------------------
|
||||
# 计算左上角位置
|
||||
panel_x = X.min() + 0.02 * (X.max() - X.min()) # 左边留2%的边距
|
||||
panel_y = Y.max() - 0.02 * (Y.max() - Y.min()) # 上边留2%的边距
|
||||
panel_z = Z.max() + 0.5 * (Z.max() - Z.min()) # 进一步提高Z坐标,确保在视野内
|
||||
|
||||
# 整体坡向面板内容
|
||||
overall_info = (
|
||||
f"整体坡向分析\n"
|
||||
f"角度: {overall_aspect:.1f}°\n"
|
||||
f"方向: {overall_direction}"
|
||||
# f"一致性: {overall_strength:.3f}"
|
||||
)
|
||||
|
||||
# 绘制整体坡向面板(左上角)
|
||||
ax3d.text(panel_x, panel_y, panel_z,
|
||||
overall_info,
|
||||
fontsize=12, fontweight='bold',
|
||||
bbox=dict(facecolor='white', alpha=0.5, boxstyle="round,pad=0.5",
|
||||
edgecolor='red', linewidth=2.5),
|
||||
ha='left', va='top', zorder=30)
|
||||
|
||||
# ---------------------- 设置3D图参数 ----------------------
|
||||
ax3d.set_xlabel("经度 (X)", fontsize=12)
|
||||
ax3d.set_ylabel("纬度 (Y)", fontsize=12)
|
||||
ax3d.set_zlabel("高程 (m)", fontsize=12)
|
||||
|
||||
# 获取文件名用于标题
|
||||
filename = os.path.basename(dem_path)
|
||||
ax3d.set_title(f"DEM三维俯视图 - 坡向分析\n文件: {filename}",
|
||||
fontsize=14, fontweight='bold', pad=20)
|
||||
|
||||
# 添加颜色条
|
||||
cbar = plt.colorbar(
|
||||
mpl.cm.ScalarMappable(norm=norm, cmap='viridis_r'),
|
||||
ax=ax3d, shrink=0.6, aspect=25, pad=0.15
|
||||
)
|
||||
cbar.set_label("坡度 (°)", fontsize=12)
|
||||
|
||||
# ---------------------- 设置为俯视图(高仰角) ----------------------
|
||||
view_elev = 85 # 接近90度的仰角,俯视效果
|
||||
view_azim = overall_aspect + 180 # 从坡向的相反方向观看,可以看到坡面
|
||||
|
||||
# 确保方位角在0-360度范围内
|
||||
view_azim = view_azim % 360
|
||||
|
||||
print(f"\n设置3D俯视图视角:")
|
||||
print(f" 整体坡向: {overall_aspect:.1f}° ({overall_direction})")
|
||||
print(f" 视角方位角: {view_azim:.1f}°")
|
||||
print(f" 视角仰角: {view_elev:.1f}°")
|
||||
|
||||
# 应用俯视图视角设置
|
||||
ax3d.view_init(elev=view_elev, azim=view_azim)
|
||||
|
||||
# 调整相机距离,使视角更广
|
||||
ax3d.dist = 9.0 # 增加相机距离
|
||||
|
||||
# 设置坐标轴范围,确保所有元素都显示
|
||||
z_min, z_max = Z.min(), Z.max()
|
||||
z_padding = 0.6 * (z_max - z_min) # 适当增加Z轴范围
|
||||
ax3d.set_zlim(z_min - 0.1*z_padding, z_max + z_padding)
|
||||
|
||||
# 设置XY轴范围
|
||||
x_margin = 0.1 * (X.max() - X.min())
|
||||
y_margin = 0.1 * (Y.max() - Y.min())
|
||||
ax3d.set_xlim(X.min() - x_margin, X.max() + x_margin)
|
||||
ax3d.set_ylim(Y.min() - y_margin, Y.max() + y_margin)
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
# 保存图片
|
||||
if not overall_3d_output_path :
|
||||
overall_3d_output_path = dem_path.replace('.tif', '_slopeAspect_3D_overlook.png')
|
||||
plt.savefig(overall_3d_output_path, dpi=250, bbox_inches='tight', facecolor='white')
|
||||
# plt.show()
|
||||
|
||||
os.remove(dem_path)
|
||||
print(f"\n3D俯视图已保存: {overall_3d_output_path}")
|
||||
print("分析完成!")
|
||||
print("="*60)
|
||||
|
||||
return overall_3d_output_path
|
||||
|
||||
# ---------------------- 跨平台中文字体配置 ----------------------
|
||||
def setup_chinese_font():
|
||||
"""设置中文字体支持,兼容Windows、Linux、macOS"""
|
||||
import platform
|
||||
|
||||
plt.rcParams['axes.unicode_minus'] = False # 正确显示负号
|
||||
|
||||
system = platform.system()
|
||||
|
||||
# 尝试添加系统中文字体路径
|
||||
font_paths = []
|
||||
|
||||
if system == 'Windows':
|
||||
# Windows 字体路径
|
||||
font_paths = [
|
||||
'C:/Windows/Fonts/simhei.ttf', # 黑体
|
||||
'C:/Windows/Fonts/simkai.ttf', # 楷体
|
||||
'C:/Windows/Fonts/simsun.ttc', # 宋体
|
||||
'C:/Windows/Fonts/microsoftyahei.ttf', # 微软雅黑
|
||||
]
|
||||
elif system == 'Darwin': # macOS
|
||||
# macOS 字体路径
|
||||
font_paths = [
|
||||
'/System/Library/Fonts/PingFang.ttc', # 苹方
|
||||
'/System/Library/Fonts/STHeiti Light.ttc', # 华文黑体
|
||||
'/System/Library/Fonts/STHeiti Medium.ttc',
|
||||
'/Library/Fonts/Arial Unicode.ttf', # Arial Unicode
|
||||
]
|
||||
elif system == 'Linux':
|
||||
# Linux 字体路径
|
||||
font_paths = [
|
||||
'/usr/share/fonts/truetype/wqy/wqy-microhei.ttc', # 文泉驿微米黑
|
||||
'/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc', # 文泉驿正黑
|
||||
'/usr/share/fonts/truetype/arphic/uming.ttc', # AR PL UMing
|
||||
'/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc', # Noto
|
||||
]
|
||||
|
||||
# 尝试找到并添加第一个可用的中文字体
|
||||
font_added = False
|
||||
for font_path in font_paths:
|
||||
try:
|
||||
if os.path.exists(font_path):
|
||||
font_prop = mpl.font_manager.FontProperties(fname=font_path)
|
||||
font_name = font_prop.get_name()
|
||||
mpl.font_manager.fontManager.addfont(font_path)
|
||||
plt.rcParams['font.sans-serif'] = [font_name] + plt.rcParams['font.sans-serif']
|
||||
font_added = True
|
||||
print(f"已添加中文字体: {font_name} ({font_path})")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"添加字体 {font_path} 失败: {e}")
|
||||
continue
|
||||
|
||||
# 如果以上方法都失败,使用通用备选方案
|
||||
if not font_added:
|
||||
if system == 'Windows':
|
||||
fallback_fonts = ['SimHei', 'Microsoft YaHei', 'KaiTi', 'FangSong']
|
||||
elif system == 'Darwin':
|
||||
fallback_fonts = ['PingFang SC', 'Hiragino Sans GB', 'Apple LiGothic Medium']
|
||||
else: # Linux and others
|
||||
fallback_fonts = ['DejaVu Sans', 'WenQuanYi Micro Hei',
|
||||
'Noto Sans CJK SC', 'Heiti TC', 'AR PL UMing CN']
|
||||
|
||||
current_fonts = plt.rcParams.get('font.sans-serif', [])
|
||||
plt.rcParams['font.sans-serif'] = fallback_fonts + current_fonts
|
||||
print(f"使用备选字体方案: {fallback_fonts[:2]}...")
|
||||
|
||||
# 设置字体族
|
||||
plt.rcParams['font.family'] = 'sans-serif'
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dem_path = r'D:/devForBdzlWork/ai_project_v1/b3dm/o_dem_f1cb6f69_slopeAspect.tif'
|
||||
read_slope_aspect_by_dem(dem_path)
|
||||
1062
b3dm/slope_aspect_tif.py
Normal file
419
b3dm/terrain_api.py
Normal file
@ -0,0 +1,419 @@
|
||||
from sanic import Blueprint, Request, json
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing import List, Optional, Dict, Any
|
||||
import numpy as np
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
import urllib.parse
|
||||
import threading
|
||||
import os
|
||||
from b3dm.terrain_calculator import TerrainCalculator
|
||||
|
||||
terrain_bp = Blueprint("terrain", url_prefix="")
|
||||
MINIO_SUB_PATH = "slopeAspectPng"
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 请求模型
|
||||
class NormalVector(BaseModel):
|
||||
"""法向量模型"""
|
||||
nx: float = Field(..., description="法向量X分量")
|
||||
ny: float = Field(..., description="法向量Y分量")
|
||||
nz: float = Field(..., description="法向量Z分量")
|
||||
|
||||
@field_validator('nx', 'ny', 'nz')
|
||||
def check_finite(cls, v):
|
||||
if not np.isfinite(v):
|
||||
raise ValueError(f"值必须是有限数字,得到: {v}")
|
||||
return v
|
||||
|
||||
def to_list(self):
|
||||
return [self.nx, self.ny, self.nz]
|
||||
|
||||
class BatchRequest(BaseModel):
|
||||
"""批量请求模型"""
|
||||
vectors: List[List[float]] = Field(..., description="法向量列表")
|
||||
|
||||
@field_validator('vectors')
|
||||
def validate_vectors(cls, v):
|
||||
if len(v) > 1000:
|
||||
raise ValueError("批量处理最多支持1000个向量")
|
||||
for vec in v:
|
||||
if len(vec) != 3:
|
||||
raise ValueError("每个向量必须是长度为3的列表")
|
||||
if not all(isinstance(x, (int, float)) for x in vec):
|
||||
raise ValueError("向量元素必须是数字")
|
||||
return v
|
||||
|
||||
class PointItem(BaseModel):
|
||||
"""单个点模型"""
|
||||
x: float = Field(..., description="x坐标")
|
||||
y: float = Field(..., description="y坐标")
|
||||
z: float = Field(..., description="z坐标")
|
||||
|
||||
class PointRequest(BaseModel):
|
||||
points: List[PointItem] = Field(..., description="点列表")
|
||||
url: str = Field(..., description="URL地址")
|
||||
|
||||
@field_validator('points')
|
||||
def validate_points_count(cls, v):
|
||||
if len(v) > 10:
|
||||
raise ValueError("批量处理最多支持10个点")
|
||||
return v
|
||||
|
||||
class PreloadRequest(BaseModel):
|
||||
url: str = Field(..., description="URL地址")
|
||||
|
||||
class AnalysisConfig(BaseModel):
|
||||
"""分析配置"""
|
||||
classify: bool = Field(default=True, description="是否进行分类")
|
||||
include_percent: bool = Field(default=True, description="是否包含坡度百分比")
|
||||
include_direction: bool = Field(default=True, description="是否包含方向描述")
|
||||
|
||||
# 中间件:请求计时
|
||||
@terrain_bp.middleware("request")
|
||||
async def add_start_time(request: Request):
|
||||
request.ctx.start_time = time.time()
|
||||
|
||||
@terrain_bp.middleware("response")
|
||||
async def add_response_time(request: Request, response):
|
||||
if hasattr(request.ctx, "start_time"):
|
||||
process_time = (time.time() - request.ctx.start_time) * 1000
|
||||
response.headers["X-Process-Time"] = f"{process_time:.2f}ms"
|
||||
|
||||
@terrain_bp.post("/api/v1/calculate/slope")
|
||||
async def calculate_slope(request: Request):
|
||||
"""计算坡度"""
|
||||
try:
|
||||
data = request.json
|
||||
if not data:
|
||||
return json({"error": "请求体不能为空"}, status=400)
|
||||
|
||||
# 验证输入
|
||||
try:
|
||||
vector = NormalVector(**data)
|
||||
except Exception as e:
|
||||
return json({"error": f"输入验证失败: {str(e)}"}, status=400)
|
||||
|
||||
# 计算坡度
|
||||
result = TerrainCalculator.calculate_slope(vector.to_list())
|
||||
|
||||
# 检查是否有错误
|
||||
if "error" in result and result["error"]:
|
||||
return json(result, status=400)
|
||||
|
||||
return json({
|
||||
"success": True,
|
||||
"data": result,
|
||||
"request": {
|
||||
"input_vector": vector.to_list(),
|
||||
"timestamp": time.time()
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"坡度计算API错误: {e}")
|
||||
return json({
|
||||
"error": f"服务器内部错误: {str(e)}",
|
||||
"success": False
|
||||
}, status=500)
|
||||
|
||||
@terrain_bp.post("/api/v1/calculate/aspect")
|
||||
async def calculate_aspect1(request: Request):
|
||||
"""计算坡向"""
|
||||
try:
|
||||
data = request.json
|
||||
if not data:
|
||||
return json({"error": "请求体不能为空"}, status=400)
|
||||
|
||||
# 验证输入
|
||||
try:
|
||||
vector = NormalVector(**data)
|
||||
except Exception as e:
|
||||
return json({"error": f"输入验证失败: {str(e)}"}, status=400)
|
||||
|
||||
# 计算坡向
|
||||
result = TerrainCalculator.calculate_aspect(vector.to_list())
|
||||
|
||||
# 检查是否有错误
|
||||
if "error" in result and result["error"]:
|
||||
return json(result, status=400)
|
||||
|
||||
return json({
|
||||
"success": True,
|
||||
"data": result,
|
||||
"request": {
|
||||
"input_vector": vector.to_list(),
|
||||
"timestamp": time.time()
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"坡向计算API错误: {e}")
|
||||
return json({
|
||||
"error": f"服务器内部错误: {str(e)}",
|
||||
"success": False
|
||||
}, status=500)
|
||||
|
||||
@terrain_bp.post("/api/v1/calculate/preload3dTiles")
|
||||
async def preload_3dtiles(request: Request):
|
||||
"""预加载3dtiles地图数据"""
|
||||
try:
|
||||
data = request.json
|
||||
if not data:
|
||||
return json({"error": "请求体不能为空"}, status=400)
|
||||
|
||||
# 验证输入
|
||||
try:
|
||||
vector = PreloadRequest(**data)
|
||||
except Exception as e:
|
||||
return json({"error": f"输入验证失败: {str(e)}"}, status=400)
|
||||
|
||||
# 创建并启动线程
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
thread1 = threading.Thread(target=TerrainCalculator.preload_3dtiles, args=(vector.url,))
|
||||
# 启动线程
|
||||
thread1.start()
|
||||
url_prefix = extract_and_rebuild_url(vector.url)
|
||||
return json({
|
||||
"success": True,
|
||||
"data": f"{script_dir}/data_3dtiles",
|
||||
"request": {
|
||||
"input_vector": vector.model_dump(),
|
||||
"timestamp": time.time()
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"坡向计算API错误: {e}")
|
||||
return json({
|
||||
"error": f"服务器内部错误: {str(e)}",
|
||||
"success": False
|
||||
}, status=500)
|
||||
|
||||
@terrain_bp.post("/api/v1/calculate/slopeAspect")
|
||||
async def calculate_slopeAspect(request: Request):
|
||||
"""生成坡向坡度俯视图"""
|
||||
try:
|
||||
data = request.json
|
||||
if not data:
|
||||
return json({"error": "请求体不能为空"}, status=400)
|
||||
|
||||
# 验证输入
|
||||
try:
|
||||
vector = PointRequest(**data)
|
||||
except Exception as e:
|
||||
return json({"error": f"输入验证失败: {str(e)}"}, status=400)
|
||||
|
||||
# 生成坡向坡度俯视图
|
||||
region_coords = [(point.x, point.y, point.z) for point in vector.points]
|
||||
overall_3d_png_name = f"o_dem_{uuid.uuid4().hex[:8]}_slopeAspect.png"
|
||||
# 创建并启动线程
|
||||
thread1 = threading.Thread(target=TerrainCalculator.generate_slopeAspect_3d_overlook, args=(region_coords, vector.url, overall_3d_png_name, MINIO_SUB_PATH))
|
||||
# 启动线程
|
||||
thread1.start()
|
||||
url_prefix = extract_and_rebuild_url(vector.url)
|
||||
return json({
|
||||
"success": True,
|
||||
"data": f"{url_prefix}/{MINIO_SUB_PATH}/{overall_3d_png_name}",
|
||||
"request": {
|
||||
"input_vector": vector.model_dump(),
|
||||
"timestamp": time.time()
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"坡向计算API错误: {e}")
|
||||
return json({
|
||||
"error": f"服务器内部错误: {str(e)}",
|
||||
"success": False
|
||||
}, status=500)
|
||||
|
||||
@terrain_bp.post("/api/v1/calculate/slopeAspectTif")
|
||||
async def calculate_slopeAspect_tif(request: Request):
|
||||
"""生成坡向坡度tif文件"""
|
||||
try:
|
||||
data = request.json
|
||||
if not data:
|
||||
return json({"error": "请求体不能为空"}, status=400)
|
||||
|
||||
# 验证输入
|
||||
try:
|
||||
vector = PointRequest(**data)
|
||||
except Exception as e:
|
||||
return json({"error": f"输入验证失败: {str(e)}"}, status=400)
|
||||
|
||||
# 生成坡向坡度俯视图
|
||||
region_coords = [(point.x, point.y, point.z) for point in vector.points]
|
||||
slope_aspect_tif_name = f"o_dem_{uuid.uuid4().hex[:8]}_slopeAspect.tif"
|
||||
# 创建并启动线程
|
||||
thread1 = threading.Thread(target=TerrainCalculator.generate_slopeAspect_tif, args=(region_coords, vector.url, slope_aspect_tif_name, MINIO_SUB_PATH))
|
||||
# 启动线程
|
||||
thread1.start()
|
||||
url_prefix = extract_and_rebuild_url(vector.url)
|
||||
return json({
|
||||
"success": True,
|
||||
"data": f"{url_prefix}/{MINIO_SUB_PATH}/{slope_aspect_tif_name}",
|
||||
"request": {
|
||||
"input_vector": vector.model_dump(),
|
||||
"timestamp": time.time()
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"坡向计算API错误: {e}")
|
||||
return json({
|
||||
"error": f"服务器内部错误: {str(e)}",
|
||||
"success": False
|
||||
}, status=500)
|
||||
|
||||
@terrain_bp.post("/api/v1/calculate/both")
|
||||
async def calculate_both(request: Request):
|
||||
"""同时计算坡度和坡向"""
|
||||
try:
|
||||
data = request.json
|
||||
if not data:
|
||||
return json({"error": "请求体不能为空"}, status=400)
|
||||
|
||||
# 验证输入
|
||||
try:
|
||||
vector = NormalVector(**data)
|
||||
except Exception as e:
|
||||
return json({"error": f"输入验证失败: {str(e)}"}, status=400)
|
||||
|
||||
# 计算坡度和坡向
|
||||
result = TerrainCalculator.calculate_slope_aspect(vector.to_list())
|
||||
|
||||
# 检查是否有错误
|
||||
if "error" in result and result["error"]:
|
||||
return json(result, status=400)
|
||||
|
||||
return json({
|
||||
"success": True,
|
||||
"data": result,
|
||||
"request": {
|
||||
"input_vector": vector.to_list(),
|
||||
"timestamp": time.time()
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"综合计算API错误: {e}")
|
||||
return json({
|
||||
"error": f"服务器内部错误: {str(e)}",
|
||||
"success": False
|
||||
}, status=500)
|
||||
|
||||
@terrain_bp.post("/api/v1/calculate/batch")
|
||||
async def batch_calculate(request: Request):
|
||||
"""批量计算"""
|
||||
try:
|
||||
data = request.json
|
||||
if not data:
|
||||
return json({"error": "请求体不能为空"}, status=400)
|
||||
|
||||
# 验证输入
|
||||
try:
|
||||
batch_request = BatchRequest(**data)
|
||||
except Exception as e:
|
||||
return json({"error": f"输入验证失败: {str(e)}"}, status=400)
|
||||
|
||||
# 批量计算
|
||||
start_time = time.time()
|
||||
result = TerrainCalculator.batch_calculate(batch_request.vectors)
|
||||
process_time = (time.time() - start_time) * 1000
|
||||
|
||||
if "error" in result and result["error"]:
|
||||
return json(result, status=400)
|
||||
|
||||
return json({
|
||||
"success": True,
|
||||
"data": result,
|
||||
"performance": {
|
||||
"process_time_ms": process_time,
|
||||
"vectors_per_second": len(batch_request.vectors) / (process_time / 1000) if process_time > 0 else 0
|
||||
},
|
||||
"request": {
|
||||
"vector_count": len(batch_request.vectors),
|
||||
"timestamp": time.time()
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量计算API错误: {e}")
|
||||
return json({
|
||||
"error": f"服务器内部错误: {str(e)}",
|
||||
"success": False
|
||||
}, status=500)
|
||||
|
||||
@terrain_bp.get("/api/v1/example")
|
||||
async def get_examples(request: Request):
|
||||
"""获取示例数据"""
|
||||
examples = {
|
||||
"flat": {
|
||||
"nx": 0.0,
|
||||
"ny": 0.0,
|
||||
"nz": 1.0,
|
||||
"expected_slope": 0.0,
|
||||
"description": "完全水平面"
|
||||
},
|
||||
"north_slope_30": {
|
||||
"nx": 0.0,
|
||||
"ny": -0.5,
|
||||
"nz": 0.8660254,
|
||||
"expected_slope": 30.0,
|
||||
"expected_aspect": 0.0,
|
||||
"description": "朝北30度斜坡"
|
||||
},
|
||||
"east_slope_45": {
|
||||
"nx": 0.7071068,
|
||||
"ny": 0.0,
|
||||
"nz": 0.7071068,
|
||||
"expected_slope": 45.0,
|
||||
"expected_aspect": 90.0,
|
||||
"description": "朝东45度斜坡"
|
||||
},
|
||||
"vertical": {
|
||||
"nx": 1.0,
|
||||
"ny": 0.0,
|
||||
"nz": 0.0,
|
||||
"expected_slope": 90.0,
|
||||
"description": "垂直面"
|
||||
}
|
||||
}
|
||||
|
||||
return json({
|
||||
"examples": examples,
|
||||
"count": len(examples)
|
||||
})
|
||||
|
||||
def extract_and_rebuild_url(url):
|
||||
"""提取URL的三部分并重建"""
|
||||
# 解析URL
|
||||
parsed = urllib.parse.urlparse(url)
|
||||
|
||||
# 1. 提取协议部分 (http/https)
|
||||
scheme = parsed.scheme or "http" # 如果没有协议,默认用http
|
||||
|
||||
# 2. 提取IP端口/主机部分
|
||||
netloc = parsed.netloc
|
||||
|
||||
# 3. 提取第一个路径分段
|
||||
path = parsed.path.strip("/") # 去掉首尾的斜杠
|
||||
path_parts = path.split("/")
|
||||
|
||||
if path_parts and path_parts[0]:
|
||||
first_segment = path_parts[0]
|
||||
else:
|
||||
first_segment = ""
|
||||
|
||||
# 重建URL
|
||||
if first_segment:
|
||||
rebuilt_url = f"{scheme}://{netloc}/{first_segment}"
|
||||
else:
|
||||
rebuilt_url = f"{scheme}://{netloc}"
|
||||
|
||||
return rebuilt_url
|
||||
392
b3dm/terrain_calculator.py
Normal file
@ -0,0 +1,392 @@
|
||||
import numpy as np
|
||||
from typing import List, Tuple, Dict, Any
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
import b3dm.data_3dtiles_to_dem as data_3dtiles_to_dem
|
||||
import b3dm.slope_aspect_img as slope_aspect_img
|
||||
import b3dm.slope_aspect_tif as slope_aspect_tif
|
||||
from b3dm.tileset_data_source import TilesetDataSource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_data_source = None
|
||||
|
||||
|
||||
class TerrainCalculator:
|
||||
"""地形坡度和坡向计算器"""
|
||||
|
||||
def preload_3dtiles(url: str) :
|
||||
# 下载3dtiles地图数据
|
||||
_data_source = TilesetDataSource(url)
|
||||
_data_source.dowload_map_data(url)
|
||||
|
||||
if not _data_source.tileset_path :
|
||||
logger.info(f"下载地图数据失败: {url}")
|
||||
return "下载地图数据失败", None
|
||||
|
||||
def generate_slopeAspect_3d_overlook(region_coords, url, overall_3d_png_name, minio_sub_path) :
|
||||
# 下载3dtiles地图数据
|
||||
_data_source = TilesetDataSource(url)
|
||||
_data_source.dowload_map_data(url)
|
||||
|
||||
if not _data_source.tileset_path :
|
||||
logger.info(f"下载地图数据失败: {url},{region_coords}")
|
||||
return "下载地图数据失败", None
|
||||
|
||||
tileset_path = _data_source.tileset_path
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
dem_path = os.path.join(script_dir, f"o_dem_{uuid.uuid4().hex[:8]}.tif")
|
||||
data_3dtiles_to_dem.generate_dem(tileset_path, dem_path, region_coords)
|
||||
|
||||
if not os.path.exists(dem_path) :
|
||||
logger.info(f"生成坡度坡向俯视图失败: {url},{region_coords}")
|
||||
return "生成坡度坡向俯视图失败", None
|
||||
|
||||
overall_3d_png_path = os.path.join(script_dir, overall_3d_png_name)
|
||||
slope_aspect_img.read_slope_aspect_by_dem(dem_path, overall_3d_png_path)
|
||||
logger.info(f"生成成功: {url},{region_coords},{overall_3d_png_path}")
|
||||
|
||||
entry_bucket, _ = _data_source.parse_minio_url(url);
|
||||
success, minio_path = _data_source.upload_file(entry_bucket, f"{minio_sub_path}/{overall_3d_png_name}", overall_3d_png_path)
|
||||
if success :
|
||||
return "生成成功", minio_path
|
||||
else :
|
||||
return "生成失败", None
|
||||
|
||||
def generate_slopeAspect_tif(region_coords, url, slope_aspect_tif_name, minio_sub_path) :
|
||||
# 下载3dtiles地图数据
|
||||
_data_source = TilesetDataSource(url)
|
||||
_data_source.dowload_map_data(url)
|
||||
|
||||
if not _data_source.tileset_path :
|
||||
logger.info(f"下载地图数据失败: {url},{region_coords}")
|
||||
return "下载地图数据失败", None
|
||||
|
||||
tileset_path = _data_source.tileset_path
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
dem_path = os.path.join(script_dir, f"o_dem_{uuid.uuid4().hex[:8]}.tif")
|
||||
data_3dtiles_to_dem.generate_dem(tileset_path, dem_path, region_coords)
|
||||
|
||||
if not os.path.exists(dem_path) :
|
||||
logger.info(f"生成坡度坡向tif失败: {url},{region_coords}")
|
||||
return "生成坡度坡向tif失败", None
|
||||
|
||||
slope_aspect_tif_path = os.path.join(script_dir, slope_aspect_tif_name)
|
||||
slope_aspect_tif.create_slope_aspect(dem_path, 'combined', slope_aspect_tif_path)
|
||||
logger.info(f"生成成功: {url},{region_coords},{slope_aspect_tif_path}")
|
||||
|
||||
entry_bucket, _ = _data_source.parse_minio_url(url);
|
||||
success, minio_path = _data_source.upload_file(entry_bucket, f"{minio_sub_path}/{slope_aspect_tif_name}", slope_aspect_tif_path)
|
||||
if success :
|
||||
return "生成成功", minio_path
|
||||
else :
|
||||
return "生成失败", None
|
||||
|
||||
@staticmethod
|
||||
def validate_vector(vector: List[float]) -> bool:
|
||||
"""验证输入向量是否有效"""
|
||||
if len(vector) != 3:
|
||||
return False
|
||||
if not all(isinstance(v, (int, float)) for v in vector):
|
||||
return False
|
||||
norm = np.linalg.norm(vector)
|
||||
return norm > 1e-10 # 避免零向量
|
||||
|
||||
@staticmethod
|
||||
def normalize_vector(vector: List[float]) -> np.ndarray:
|
||||
"""向量归一化"""
|
||||
arr = np.array(vector, dtype=np.float64)
|
||||
norm = np.linalg.norm(arr)
|
||||
return arr / norm if norm > 0 else arr
|
||||
|
||||
@staticmethod
|
||||
def calculate_slope(normal_vector: List[float]) -> Dict[str, Any]:
|
||||
"""
|
||||
计算坡度
|
||||
|
||||
Args:
|
||||
normal_vector: 法向量 [nx, ny, nz]
|
||||
|
||||
Returns:
|
||||
dict: 包含坡度(度)和相关信息
|
||||
"""
|
||||
try:
|
||||
# 验证输入
|
||||
if not TerrainCalculator.validate_vector(normal_vector):
|
||||
return {
|
||||
"error": "无效的法向量,必须是长度为3的数值列表且不能为零向量",
|
||||
"slope_deg": None
|
||||
}
|
||||
|
||||
# 归一化
|
||||
n = TerrainCalculator.normalize_vector(normal_vector)
|
||||
|
||||
# 计算坡度(使用arccos法)
|
||||
nz_abs = abs(n[2])
|
||||
|
||||
# 处理数值误差
|
||||
if nz_abs > 1.0:
|
||||
nz_abs = 1.0
|
||||
elif nz_abs < 0.0:
|
||||
nz_abs = 0.0
|
||||
|
||||
# 计算坡度(弧度)
|
||||
if abs(nz_abs - 1.0) < 1e-10: # 完全水平
|
||||
slope_rad = 0.0
|
||||
elif abs(nz_abs) < 1e-10: # 完全垂直
|
||||
slope_rad = np.pi / 2
|
||||
else:
|
||||
slope_rad = np.arccos(nz_abs)
|
||||
|
||||
# 转换为度
|
||||
slope_deg = np.degrees(slope_rad)
|
||||
|
||||
# 计算坡度百分比
|
||||
slope_percent = np.tan(slope_rad) * 100 if slope_rad < np.pi/2 else float('inf')
|
||||
|
||||
return {
|
||||
"slope_deg": float(slope_deg),
|
||||
"slope_rad": float(slope_rad),
|
||||
"slope_percent": float(slope_percent),
|
||||
"normalized_vector": n.tolist(),
|
||||
"classification": TerrainCalculator.classify_slope(slope_deg)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"坡度计算错误: {e}")
|
||||
return {
|
||||
"error": f"计算失败: {str(e)}",
|
||||
"slope_deg": None
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def calculate_aspect(normal_vector: List[float]) -> Dict[str, Any]:
|
||||
"""
|
||||
计算坡向
|
||||
|
||||
Args:
|
||||
normal_vector: 法向量 [nx, ny, nz]
|
||||
|
||||
Returns:
|
||||
dict: 包含坡向(度)和相关信息
|
||||
"""
|
||||
try:
|
||||
# 验证输入
|
||||
if not TerrainCalculator.validate_vector(normal_vector):
|
||||
return {
|
||||
"error": "无效的法向量,必须是长度为3的数值列表且不能为零向量",
|
||||
"aspect_deg": None
|
||||
}
|
||||
|
||||
# 归一化
|
||||
n = TerrainCalculator.normalize_vector(normal_vector)
|
||||
|
||||
# 检查是否为水平面
|
||||
nx, ny, nz = n
|
||||
horizontal_magnitude = np.sqrt(nx*nx + ny*ny)
|
||||
|
||||
if horizontal_magnitude < 1e-10: # 水平面,坡向无定义
|
||||
return {
|
||||
"aspect_deg": None,
|
||||
"aspect_rad": None,
|
||||
"is_flat": True,
|
||||
"message": "水平面,坡向无定义",
|
||||
"normalized_vector": n.tolist()
|
||||
}
|
||||
|
||||
# 计算原始坡向(四象限反正切)
|
||||
# 注意:arctan2(nx, ny) 不是 arctan2(ny, nx)
|
||||
raw_angle_rad = np.arctan2(nx, ny)
|
||||
|
||||
# 转换为坡向(下坡方向 = 法向量方向 + 180°)
|
||||
aspect_rad = raw_angle_rad + np.pi
|
||||
|
||||
# 转换为度
|
||||
aspect_deg = np.degrees(aspect_rad)
|
||||
|
||||
# 归一化到 [0, 360) 范围
|
||||
aspect_deg = aspect_deg % 360.0
|
||||
|
||||
# 转换为八方向
|
||||
direction = TerrainCalculator.aspect_to_direction(aspect_deg)
|
||||
|
||||
return {
|
||||
"aspect_deg": float(aspect_deg),
|
||||
"aspect_rad": float(aspect_rad % (2*np.pi)),
|
||||
"direction": direction,
|
||||
"is_flat": False,
|
||||
"normalized_vector": n.tolist(),
|
||||
"raw_angle_deg": float(np.degrees(raw_angle_rad))
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"坡向计算错误: {e}")
|
||||
return {
|
||||
"error": f"计算失败: {str(e)}",
|
||||
"aspect_deg": None
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def calculate_slope_aspect(normal_vector: List[float]) -> Dict[str, Any]:
|
||||
"""
|
||||
同时计算坡度和坡向
|
||||
|
||||
Args:
|
||||
normal_vector: 法向量 [nx, ny, nz]
|
||||
|
||||
Returns:
|
||||
dict: 包含坡度和坡向的综合结果
|
||||
"""
|
||||
try:
|
||||
slope_result = TerrainCalculator.calculate_slope(normal_vector)
|
||||
aspect_result = TerrainCalculator.calculate_aspect(normal_vector)
|
||||
|
||||
result = {
|
||||
"slope": slope_result,
|
||||
"aspect": aspect_result,
|
||||
"input_vector": normal_vector,
|
||||
"calculation_time": None # 可在调用处添加时间戳
|
||||
}
|
||||
|
||||
# 如果有错误,合并错误信息
|
||||
errors = []
|
||||
if "error" in slope_result and slope_result["error"]:
|
||||
errors.append(f"坡度: {slope_result['error']}")
|
||||
if "error" in aspect_result and aspect_result["error"]:
|
||||
errors.append(f"坡向: {aspect_result['error']}")
|
||||
|
||||
if errors:
|
||||
result["errors"] = errors
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"综合计算错误: {e}")
|
||||
return {
|
||||
"error": f"综合计算失败: {str(e)}"
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def classify_slope(slope_deg: float) -> Dict[str, Any]:
|
||||
"""坡度分类"""
|
||||
if slope_deg < 2:
|
||||
return {"category": "平坦", "level": 0, "description": "基本平坦"}
|
||||
elif slope_deg < 5:
|
||||
return {"category": "缓坡", "level": 1, "description": "适合农业"}
|
||||
elif slope_deg < 15:
|
||||
return {"category": "斜坡", "level": 2, "description": "适合建设"}
|
||||
elif slope_deg < 30:
|
||||
return {"category": "陡坡", "level": 3, "description": "需要工程措施"}
|
||||
elif slope_deg < 45:
|
||||
return {"category": "急陡坡", "level": 4, "description": "高风险区域"}
|
||||
else:
|
||||
return {"category": "峭壁", "level": 5, "description": "危险区域"}
|
||||
|
||||
@staticmethod
|
||||
def aspect_to_direction(aspect_deg: float) -> Dict[str, Any]:
|
||||
"""将坡向转换为八方向"""
|
||||
directions = ["北", "东北", "东", "东南", "南", "西南", "西", "西北"]
|
||||
|
||||
# 计算方向索引 (45°一个区间)
|
||||
index = int((aspect_deg + 22.5) % 360 / 45)
|
||||
|
||||
return {
|
||||
"chinese": directions[index],
|
||||
"english": ["N", "NE", "E", "SE", "S", "SW", "W", "NW"][index],
|
||||
"degree_range": {
|
||||
"min": (index * 45 - 22.5) % 360,
|
||||
"max": (index * 45 + 22.5) % 360
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def batch_calculate(vectors: List[List[float]]) -> Dict[str, Any]:
|
||||
"""批量计算多个法向量"""
|
||||
try:
|
||||
results = []
|
||||
errors = []
|
||||
|
||||
for i, vec in enumerate(vectors):
|
||||
try:
|
||||
result = TerrainCalculator.calculate_slope_aspect(vec)
|
||||
result["index"] = i
|
||||
results.append(result)
|
||||
except Exception as e:
|
||||
errors.append({
|
||||
"index": i,
|
||||
"vector": vec,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
return {
|
||||
"total": len(vectors),
|
||||
"successful": len(results),
|
||||
"failed": len(errors),
|
||||
"results": results,
|
||||
"errors": errors,
|
||||
"statistics": TerrainCalculator.calculate_statistics(results)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量计算错误: {e}")
|
||||
return {"error": f"批量计算失败: {str(e)}"}
|
||||
|
||||
@staticmethod
|
||||
def calculate_statistics(results: List[Dict]) -> Dict[str, Any]:
|
||||
"""计算统计信息"""
|
||||
if not results:
|
||||
return {}
|
||||
|
||||
slope_values = []
|
||||
aspect_values = []
|
||||
|
||||
for r in results:
|
||||
if "slope" in r and "slope_deg" in r["slope"] and r["slope"]["slope_deg"] is not None:
|
||||
slope_values.append(r["slope"]["slope_deg"])
|
||||
if "aspect" in r and "aspect_deg" in r["aspect"] and r["aspect"]["aspect_deg"] is not None:
|
||||
aspect_values.append(r["aspect"]["aspect_deg"])
|
||||
|
||||
if slope_values:
|
||||
slope_arr = np.array(slope_values)
|
||||
aspect_arr = np.array(aspect_values) if aspect_values else np.array([])
|
||||
|
||||
stats = {
|
||||
"slope": {
|
||||
"count": len(slope_values),
|
||||
"mean": float(np.mean(slope_arr)),
|
||||
"std": float(np.std(slope_arr)),
|
||||
"min": float(np.min(slope_arr)),
|
||||
"max": float(np.max(slope_arr)),
|
||||
"median": float(np.median(slope_arr))
|
||||
}
|
||||
}
|
||||
|
||||
if aspect_values:
|
||||
# 坡向统计需要循环统计
|
||||
stats["aspect"] = {
|
||||
"count": len(aspect_values),
|
||||
"mean_vector": TerrainCalculator.circular_mean(aspect_arr),
|
||||
"concentration": TerrainCalculator.circular_concentration(aspect_arr)
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def circular_mean(angles_deg: np.ndarray) -> float:
|
||||
"""计算循环数据的平均值(角度)"""
|
||||
angles_rad = np.radians(angles_deg)
|
||||
x = np.mean(np.cos(angles_rad))
|
||||
y = np.mean(np.sin(angles_rad))
|
||||
mean_rad = np.arctan2(y, x)
|
||||
return np.degrees(mean_rad) % 360
|
||||
|
||||
@staticmethod
|
||||
def circular_concentration(angles_deg: np.ndarray) -> float:
|
||||
"""计算角度数据的集中度 (0-1)"""
|
||||
angles_rad = np.radians(angles_deg)
|
||||
x = np.mean(np.cos(angles_rad))
|
||||
y = np.mean(np.sin(angles_rad))
|
||||
return np.sqrt(x*x + y*y)
|
||||
105
b3dm/tileset_data_source.py
Normal file
@ -0,0 +1,105 @@
|
||||
# tileset_data_source_py
|
||||
import numpy as np
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from b3dm.data_3dtiles_manager import MinIO3DTilesManager
|
||||
import b3dm.data_3dtiles_to_dem as data_3dtiles_to_dem
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ENDPOINT_URL = "222.212.85.86:9000"
|
||||
ACCESS_KEY = "WuRenJi"
|
||||
SECRET_KEY = "WRJ@2024"
|
||||
|
||||
class TilesetDataSource:
|
||||
"""使用py3dtiles库的数据源"""
|
||||
|
||||
def __init__(self, url: str, cache_size: int = 1000):
|
||||
self.url = url
|
||||
self.tileset_path = None
|
||||
self.tileset_dir = None
|
||||
self._crs = "EPSG:4979"
|
||||
|
||||
def parse_minio_url(self, url):
|
||||
manager = MinIO3DTilesManager(
|
||||
endpoint_url=ENDPOINT_URL,
|
||||
access_key=ACCESS_KEY,
|
||||
secret_key=SECRET_KEY,
|
||||
secure=False
|
||||
)
|
||||
return manager.parse_minio_url(url)
|
||||
|
||||
def upload_file(self, bucket_name, object_name, file_path):
|
||||
manager = MinIO3DTilesManager(
|
||||
endpoint_url=ENDPOINT_URL,
|
||||
access_key=ACCESS_KEY,
|
||||
secret_key=SECRET_KEY,
|
||||
secure=False
|
||||
)
|
||||
flag, path = manager.upload_file(bucket_name, object_name, file_path)
|
||||
if flag :
|
||||
os.remove(file_path)
|
||||
return flag, path
|
||||
|
||||
|
||||
def dowload_map_data(self, url: str) :
|
||||
# 下载3dtiles地图数据
|
||||
manager = MinIO3DTilesManager(
|
||||
endpoint_url=ENDPOINT_URL,
|
||||
access_key=ACCESS_KEY,
|
||||
secret_key=SECRET_KEY,
|
||||
secure=False
|
||||
)
|
||||
success, tileset_path = manager.download_full_tileset(
|
||||
tileset_url=url,
|
||||
save_dir=f"data_3dtiles",
|
||||
region_filter=None
|
||||
)
|
||||
if success :
|
||||
self.tileset_path = os.path.abspath(tileset_path)
|
||||
self.tileset_dir = os.path.dirname(tileset_path)
|
||||
|
||||
async def get_points_in_polygon(self, polygon_coords, z_range=None):
|
||||
"""获取多边形内的点数据"""
|
||||
points = data_3dtiles_to_dem.parse_tileset(self.tileset_path, polygon_coords)
|
||||
return np.array(points)
|
||||
|
||||
async def get_data_bounds(self) -> Dict[str, List[float]]:
|
||||
"""获取数据边界"""
|
||||
|
||||
bounds = {
|
||||
"min": [float('inf'), float('inf'), float('inf')],
|
||||
"max": [-float('inf'), -float('inf'), -float('inf')]
|
||||
}
|
||||
|
||||
if self._tileset and hasattr(self._tileset.root_tile, 'bounding_volume'):
|
||||
bv = self._tileset.root_tile.bounding_volume
|
||||
if hasattr(bv, 'get_corners'):
|
||||
corners = bv.get_corners()
|
||||
if corners is not None:
|
||||
bounds["min"] = corners.min(axis=0).tolist()
|
||||
bounds["max"] = corners.max(axis=0).tolist()
|
||||
|
||||
return bounds
|
||||
|
||||
def get_crs(self) -> str:
|
||||
return self._crs
|
||||
|
||||
async def main() :
|
||||
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
SCRIPT_PAR_DIR = os.path.dirname(SCRIPT_DIR)
|
||||
tileset_path = os.path.join(SCRIPT_PAR_DIR, "data/3dtiles/tileset.json")
|
||||
data_source_3d_tiles = TilesetDataSource(tileset_path)
|
||||
# tileSet = TileSet()
|
||||
# path = Path(tileset_path)
|
||||
# tileset_data = tileSet.from_file(path)
|
||||
print("====================================================")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(main())
|
||||
395
b3dm/tileset_to_ply.py
Normal file
@ -0,0 +1,395 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
3D Tiles Tileset to PLY Converter
|
||||
将整个3D Tiles tileset转换为单个PLY文件
|
||||
"""
|
||||
|
||||
import json
|
||||
import struct
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import DracoPy
|
||||
except ImportError:
|
||||
print("警告: DracoPy库未安装,无法处理Draco压缩的数据")
|
||||
print("请运行: pip install DracoPy")
|
||||
DracoPy = None
|
||||
|
||||
class TilesetToPLYConverter:
|
||||
def __init__(self):
|
||||
self.all_vertices = []
|
||||
self.vertex_count = 0
|
||||
|
||||
def multiply_matrix_vector(self, matrix, vector):
|
||||
"""4x4矩阵与4D向量相乘"""
|
||||
# matrix是16个元素的列表,按列主序排列
|
||||
# 转换为4x4矩阵(行主序)
|
||||
m = [
|
||||
[matrix[0], matrix[4], matrix[8], matrix[12]],
|
||||
[matrix[1], matrix[5], matrix[9], matrix[13]],
|
||||
[matrix[2], matrix[6], matrix[10], matrix[14]],
|
||||
[matrix[3], matrix[7], matrix[11], matrix[15]]
|
||||
]
|
||||
|
||||
# 向量扩展为齐次坐标 [x, y, z, 1]
|
||||
v = [vector[0], vector[1], vector[2], 1.0]
|
||||
|
||||
# 矩阵乘法
|
||||
result = [
|
||||
m[0][0]*v[0] + m[0][1]*v[1] + m[0][2]*v[2] + m[0][3]*v[3],
|
||||
m[1][0]*v[0] + m[1][1]*v[1] + m[1][2]*v[2] + m[1][3]*v[3],
|
||||
m[2][0]*v[0] + m[2][1]*v[1] + m[2][2]*v[2] + m[2][3]*v[3]
|
||||
]
|
||||
|
||||
return result
|
||||
|
||||
def multiply_matrices(self, m1, m2):
|
||||
"""两个4x4矩阵相乘"""
|
||||
# 将16元素列表转换为4x4矩阵
|
||||
def list_to_matrix(lst):
|
||||
return [
|
||||
[lst[0], lst[4], lst[8], lst[12]],
|
||||
[lst[1], lst[5], lst[9], lst[13]],
|
||||
[lst[2], lst[6], lst[10], lst[14]],
|
||||
[lst[3], lst[7], lst[11], lst[15]]
|
||||
]
|
||||
|
||||
def matrix_to_list(mat):
|
||||
return [
|
||||
mat[0][0], mat[1][0], mat[2][0], mat[3][0],
|
||||
mat[0][1], mat[1][1], mat[2][1], mat[3][1],
|
||||
mat[0][2], mat[1][2], mat[2][2], mat[3][2],
|
||||
mat[0][3], mat[1][3], mat[2][3], mat[3][3]
|
||||
]
|
||||
|
||||
a = list_to_matrix(m1)
|
||||
b = list_to_matrix(m2)
|
||||
|
||||
result = [[0 for _ in range(4)] for _ in range(4)]
|
||||
for i in range(4):
|
||||
for j in range(4):
|
||||
for k in range(4):
|
||||
result[i][j] += a[i][k] * b[k][j]
|
||||
|
||||
return matrix_to_list(result)
|
||||
|
||||
def apply_transform_to_vertices(self, vertices, transform_matrix):
|
||||
"""对顶点应用变换矩阵"""
|
||||
if not transform_matrix:
|
||||
return vertices
|
||||
|
||||
transformed_vertices = []
|
||||
for vertex in vertices:
|
||||
transformed = self.multiply_matrix_vector(transform_matrix, vertex)
|
||||
transformed_vertices.append(transformed)
|
||||
|
||||
return transformed_vertices
|
||||
|
||||
def parse_tileset_json(self, tileset_path, parent_transform=None):
|
||||
"""解析tileset.json文件,收集B3DM文件和变换矩阵"""
|
||||
try:
|
||||
with open(tileset_path, 'r', encoding='utf-8') as f:
|
||||
tileset_data = json.load(f)
|
||||
|
||||
b3dm_files = []
|
||||
|
||||
def process_node(node, base_path, accumulated_transform):
|
||||
# 获取当前节点的变换矩阵
|
||||
current_transform = node.get('transform')
|
||||
|
||||
# 计算累积变换矩阵
|
||||
if current_transform and accumulated_transform:
|
||||
# 矩阵相乘:accumulated_transform * current_transform
|
||||
final_transform = self.multiply_matrices(accumulated_transform, current_transform)
|
||||
elif current_transform:
|
||||
final_transform = current_transform
|
||||
else:
|
||||
final_transform = accumulated_transform
|
||||
|
||||
if 'content' in node and 'uri' in node['content']:
|
||||
uri = node['content']['uri']
|
||||
if uri.endswith('.b3dm'):
|
||||
full_path = os.path.join(base_path, uri)
|
||||
if os.path.exists(full_path):
|
||||
b3dm_files.append((full_path, final_transform))
|
||||
elif uri.endswith('.json'):
|
||||
# 递归处理子tileset
|
||||
sub_tileset_path = os.path.join(base_path, uri)
|
||||
if os.path.exists(sub_tileset_path):
|
||||
sub_files = self.parse_tileset_json(sub_tileset_path, final_transform)
|
||||
b3dm_files.extend(sub_files)
|
||||
|
||||
if 'children' in node:
|
||||
for child in node['children']:
|
||||
process_node(child, base_path, final_transform)
|
||||
|
||||
base_path = os.path.dirname(tileset_path)
|
||||
if 'root' in tileset_data:
|
||||
process_node(tileset_data['root'], base_path, parent_transform)
|
||||
|
||||
return b3dm_files
|
||||
|
||||
except Exception as e:
|
||||
print(f"解析tileset.json时出错: {e}")
|
||||
return []
|
||||
|
||||
def parse_b3dm_file(self, file_path):
|
||||
"""解析B3DM文件"""
|
||||
try:
|
||||
with open(file_path, 'rb') as f:
|
||||
# 读取B3DM头部
|
||||
magic = f.read(4)
|
||||
if magic != b'b3dm':
|
||||
print(f"警告: {file_path} 不是有效的B3DM文件")
|
||||
return None
|
||||
|
||||
version = struct.unpack('<I', f.read(4))[0]
|
||||
byte_length = struct.unpack('<I', f.read(4))[0]
|
||||
feature_table_json_byte_length = struct.unpack('<I', f.read(4))[0]
|
||||
feature_table_binary_byte_length = struct.unpack('<I', f.read(4))[0]
|
||||
batch_table_json_byte_length = struct.unpack('<I', f.read(4))[0]
|
||||
batch_table_binary_byte_length = struct.unpack('<I', f.read(4))[0]
|
||||
|
||||
# 跳过feature table和batch table
|
||||
f.seek(28 + feature_table_json_byte_length + feature_table_binary_byte_length +
|
||||
batch_table_json_byte_length + batch_table_binary_byte_length)
|
||||
|
||||
# 读取glTF数据
|
||||
gltf_data = f.read()
|
||||
return self.parse_gltf_data(gltf_data)
|
||||
|
||||
except Exception as e:
|
||||
print(f"解析B3DM文件 {file_path} 失败: {e}")
|
||||
return None
|
||||
|
||||
def parse_gltf_data(self, gltf_data):
|
||||
"""解析glTF数据"""
|
||||
try:
|
||||
# 检查是否为GLB格式
|
||||
if gltf_data[:4] == b'glTF':
|
||||
return self.parse_glb_data(gltf_data)
|
||||
else:
|
||||
# 尝试作为JSON解析
|
||||
gltf_json = json.loads(gltf_data.decode('utf-8'))
|
||||
return self.extract_vertices_from_gltf(gltf_json, None)
|
||||
except Exception as e:
|
||||
print(f"解析glTF数据失败: {e}")
|
||||
return None
|
||||
|
||||
def parse_glb_data(self, glb_data):
|
||||
"""解析GLB格式的glTF数据"""
|
||||
try:
|
||||
# GLB头部: magic(4) + version(4) + length(4)
|
||||
magic = glb_data[:4]
|
||||
if magic != b'glTF':
|
||||
return None
|
||||
|
||||
version = struct.unpack('<I', glb_data[4:8])[0]
|
||||
total_length = struct.unpack('<I', glb_data[8:12])[0]
|
||||
|
||||
offset = 12
|
||||
json_data = None
|
||||
binary_data = None
|
||||
|
||||
# 读取chunks
|
||||
while offset < len(glb_data):
|
||||
if offset + 8 > len(glb_data):
|
||||
break
|
||||
|
||||
chunk_length = struct.unpack('<I', glb_data[offset:offset+4])[0]
|
||||
chunk_type = glb_data[offset+4:offset+8]
|
||||
chunk_data = glb_data[offset+8:offset+8+chunk_length]
|
||||
|
||||
if chunk_type == b'JSON':
|
||||
json_data = json.loads(chunk_data.decode('utf-8'))
|
||||
elif chunk_type == b'BIN\x00':
|
||||
binary_data = chunk_data
|
||||
|
||||
offset += 8 + chunk_length
|
||||
# 对齐到4字节边界
|
||||
offset = (offset + 3) & ~3
|
||||
|
||||
if json_data:
|
||||
return self.extract_vertices_from_gltf(json_data, binary_data)
|
||||
|
||||
except Exception as e:
|
||||
print(f"解析GLB数据失败: {e}")
|
||||
return None
|
||||
|
||||
def extract_vertices_from_gltf(self, gltf_json, binary_data):
|
||||
"""从glTF JSON中提取顶点数据"""
|
||||
vertices = []
|
||||
|
||||
try:
|
||||
# 检查是否使用了Draco压缩
|
||||
if 'extensionsUsed' in gltf_json and 'KHR_draco_mesh_compression' in gltf_json['extensionsUsed']:
|
||||
if DracoPy is None:
|
||||
print("警告: 检测到Draco压缩但DracoPy未安装")
|
||||
return vertices
|
||||
return self.extract_draco_vertices(gltf_json, binary_data)
|
||||
|
||||
# 处理标准glTF格式
|
||||
if 'meshes' not in gltf_json:
|
||||
return vertices
|
||||
|
||||
for mesh in gltf_json['meshes']:
|
||||
for primitive in mesh['primitives']:
|
||||
if 'attributes' in primitive and 'POSITION' in primitive['attributes']:
|
||||
position_accessor_index = primitive['attributes']['POSITION']
|
||||
|
||||
if 'accessors' in gltf_json and position_accessor_index < len(gltf_json['accessors']):
|
||||
accessor = gltf_json['accessors'][position_accessor_index]
|
||||
buffer_view_index = accessor['bufferView']
|
||||
|
||||
if 'bufferViews' in gltf_json and buffer_view_index < len(gltf_json['bufferViews']):
|
||||
buffer_view = gltf_json['bufferViews'][buffer_view_index]
|
||||
buffer_index = buffer_view['buffer']
|
||||
byte_offset = buffer_view.get('byteOffset', 0) + accessor.get('byteOffset', 0)
|
||||
|
||||
if binary_data and buffer_index == 0:
|
||||
# 从二进制数据中读取顶点
|
||||
component_type = accessor['componentType']
|
||||
count = accessor['count']
|
||||
|
||||
if component_type == 5126: # FLOAT
|
||||
vertex_data = struct.unpack(f'<{count*3}f',
|
||||
binary_data[byte_offset:byte_offset+count*12])
|
||||
for i in range(0, len(vertex_data), 3):
|
||||
vertices.append([vertex_data[i], vertex_data[i+1], vertex_data[i+2]])
|
||||
|
||||
except Exception as e:
|
||||
print(f"提取顶点数据失败: {e}")
|
||||
|
||||
return vertices
|
||||
|
||||
def extract_draco_vertices(self, gltf_json, binary_data):
|
||||
"""提取Draco压缩的顶点数据"""
|
||||
vertices = []
|
||||
|
||||
try:
|
||||
if 'meshes' not in gltf_json:
|
||||
return vertices
|
||||
|
||||
for mesh in gltf_json['meshes']:
|
||||
for primitive in mesh['primitives']:
|
||||
if 'extensions' in primitive and 'KHR_draco_mesh_compression' in primitive['extensions']:
|
||||
draco_ext = primitive['extensions']['KHR_draco_mesh_compression']
|
||||
buffer_view_index = draco_ext['bufferView']
|
||||
|
||||
if 'bufferViews' in gltf_json and buffer_view_index < len(gltf_json['bufferViews']):
|
||||
buffer_view = gltf_json['bufferViews'][buffer_view_index]
|
||||
byte_offset = buffer_view.get('byteOffset', 0)
|
||||
byte_length = buffer_view['byteLength']
|
||||
|
||||
if binary_data:
|
||||
draco_data = binary_data[byte_offset:byte_offset+byte_length]
|
||||
|
||||
# 使用DracoPy解码
|
||||
mesh_data = DracoPy.decode(draco_data)
|
||||
if hasattr(mesh_data, 'points'):
|
||||
points = mesh_data.points
|
||||
for point in points:
|
||||
vertices.append([float(point[0]), float(point[1]), float(point[2])])
|
||||
|
||||
except Exception as e:
|
||||
print(f"解码Draco数据失败: {e}")
|
||||
|
||||
return vertices
|
||||
|
||||
def save_ply_file(self, output_path):
|
||||
"""保存PLY文件"""
|
||||
try:
|
||||
with open(output_path, 'w') as f:
|
||||
# 写入PLY头部
|
||||
f.write("ply\n")
|
||||
f.write("format ascii 1.0\n")
|
||||
f.write(f"element vertex {len(self.all_vertices)}\n")
|
||||
f.write("property float x\n")
|
||||
f.write("property float y\n")
|
||||
f.write("property float z\n")
|
||||
f.write("end_header\n")
|
||||
|
||||
# 写入顶点数据
|
||||
for vertex in self.all_vertices:
|
||||
f.write(f"{vertex[0]} {vertex[1]} {vertex[2]}\n")
|
||||
|
||||
print(f"PLY文件已保存: {output_path}")
|
||||
print(f"总顶点数: {len(self.all_vertices)}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"保存PLY文件失败: {e}")
|
||||
|
||||
def convert_tileset_to_ply(self, tileset_path, output_path):
|
||||
"""将整个tileset转换为PLY文件"""
|
||||
print(f"开始处理tileset: {tileset_path}")
|
||||
|
||||
# 解析主tileset.json
|
||||
tileset_data = self.parse_tileset_json(tileset_path)
|
||||
if not tileset_data:
|
||||
print("无法解析tileset.json文件")
|
||||
return False
|
||||
|
||||
# 获取基础路径
|
||||
base_path = os.path.dirname(tileset_path)
|
||||
|
||||
# 提取所有b3dm文件和变换矩阵
|
||||
b3dm_files = self.parse_tileset_json(tileset_path)
|
||||
print(f"找到 {len(b3dm_files)} 个B3DM文件")
|
||||
|
||||
if not b3dm_files:
|
||||
print("未找到任何B3DM文件")
|
||||
return False
|
||||
|
||||
# 处理每个b3dm文件
|
||||
processed_count = 0
|
||||
for i, (b3dm_file, transform_matrix) in enumerate(b3dm_files):
|
||||
print(f"处理文件 {i+1}/{len(b3dm_files)}: {os.path.basename(b3dm_file)}")
|
||||
|
||||
vertices = self.parse_b3dm_file(b3dm_file)
|
||||
if vertices:
|
||||
# 应用变换矩阵
|
||||
if transform_matrix:
|
||||
vertices = self.apply_transform_to_vertices(vertices, transform_matrix)
|
||||
print(f" 应用了变换矩阵")
|
||||
|
||||
self.all_vertices.extend(vertices)
|
||||
processed_count += 1
|
||||
print(f" 提取到 {len(vertices)} 个顶点")
|
||||
|
||||
print(f"\n成功处理 {processed_count}/{len(b3dm_files)} 个文件")
|
||||
print(f"总计提取 {len(self.all_vertices)} 个顶点")
|
||||
|
||||
if self.all_vertices:
|
||||
self.save_ply_file(output_path)
|
||||
return True
|
||||
else:
|
||||
print("未提取到任何顶点数据")
|
||||
return False
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 2:
|
||||
print("用法: python tileset_to_ply.py <tileset.json路径> [输出PLY文件路径]")
|
||||
print("示例: python tileset_to_ply.py tileset.json output.ply")
|
||||
return
|
||||
|
||||
tileset_path = sys.argv[1]
|
||||
output_path = sys.argv[2] if len(sys.argv) > 2 else "merged_tileset.ply"
|
||||
|
||||
if not os.path.exists(tileset_path):
|
||||
print(f"错误: 文件不存在 {tileset_path}")
|
||||
return
|
||||
|
||||
converter = TilesetToPLYConverter()
|
||||
success = converter.convert_tileset_to_ply(tileset_path, output_path)
|
||||
|
||||
if success:
|
||||
print("\n转换完成!")
|
||||
else:
|
||||
print("\n转换失败!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
684
b3dm/volume_calculator.py
Normal file
@ -0,0 +1,684 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
三维模型体积计算器
|
||||
基于指定地理范围内的三维模型数据,使用三角构网方法计算体积
|
||||
"""
|
||||
|
||||
import json
|
||||
import struct
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from scipy.spatial import Delaunay
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
|
||||
try:
|
||||
import DracoPy
|
||||
except ImportError:
|
||||
print("警告: DracoPy库未安装,无法处理Draco压缩的数据")
|
||||
print("请运行: pip install DracoPy")
|
||||
DracoPy = None
|
||||
|
||||
class VolumeCalculator:
|
||||
def __init__(self, location_file):
|
||||
self.location_bounds = self.load_location_bounds(location_file)
|
||||
self.all_vertices = []
|
||||
self.filtered_vertices = []
|
||||
|
||||
def load_location_bounds(self, location_file):
|
||||
"""加载地理范围边界"""
|
||||
try:
|
||||
with open(location_file, 'r', encoding='utf-8') as f:
|
||||
coords = json.load(f)
|
||||
|
||||
# 提取经纬度范围
|
||||
lons = [coord[0] for coord in coords]
|
||||
lats = [coord[1] for coord in coords]
|
||||
elevs = [coord[2] for coord in coords]
|
||||
|
||||
bounds = {
|
||||
'min_lon': min(lons),
|
||||
'max_lon': max(lons),
|
||||
'min_lat': min(lats),
|
||||
'max_lat': max(lats),
|
||||
'min_elev': min(elevs),
|
||||
'max_elev': max(elevs),
|
||||
'coords': coords
|
||||
}
|
||||
|
||||
print(f"地理范围边界:")
|
||||
print(f" 经度: {bounds['min_lon']:.8f} ~ {bounds['max_lon']:.8f}")
|
||||
print(f" 纬度: {bounds['min_lat']:.8f} ~ {bounds['max_lat']:.8f}")
|
||||
print(f" 高程: {bounds['min_elev']:.2f} ~ {bounds['max_elev']:.2f}")
|
||||
|
||||
return bounds
|
||||
|
||||
except Exception as e:
|
||||
print(f"加载地理范围文件失败: {e}")
|
||||
return None
|
||||
|
||||
def wgs84_to_cartesian(self, lon, lat, elev):
|
||||
"""WGS84坐标转换为笛卡尔坐标(高精度算法)"""
|
||||
# WGS84椭球参数(EPSG:4326)
|
||||
a = 6378137.0 # 长半轴 (米)
|
||||
f = 1/298.257223563 # 扁率
|
||||
e2 = 2*f - f*f # 第一偏心率平方
|
||||
|
||||
# 角度转弧度
|
||||
lon_rad = math.radians(lon)
|
||||
lat_rad = math.radians(lat)
|
||||
|
||||
# 计算卯酉圈曲率半径
|
||||
sin_lat = math.sin(lat_rad)
|
||||
cos_lat = math.cos(lat_rad)
|
||||
sin_lon = math.sin(lon_rad)
|
||||
cos_lon = math.cos(lon_rad)
|
||||
|
||||
N = a / math.sqrt(1 - e2 * sin_lat * sin_lat)
|
||||
|
||||
# 高精度笛卡尔坐标计算
|
||||
x = (N + elev) * cos_lat * cos_lon
|
||||
y = (N + elev) * cos_lat * sin_lon
|
||||
z = (N * (1 - e2) + elev) * sin_lat
|
||||
|
||||
return [x, y, z]
|
||||
|
||||
def cartesian_to_wgs84(self, x, y, z):
|
||||
"""笛卡尔坐标转换为WGS84坐标(高精度迭代法)"""
|
||||
# WGS84椭球参数
|
||||
a = 6378137.0 # 长半轴
|
||||
f = 1/298.257223563 # 扁率
|
||||
e2 = 2*f - f*f # 第一偏心率平方
|
||||
ep2 = e2 / (1 - e2) # 第二偏心率平方
|
||||
|
||||
# 计算经度(精确值)
|
||||
lon = math.atan2(y, x)
|
||||
|
||||
# 计算纬度和高程(使用Bowring迭代法)
|
||||
p = math.sqrt(x*x + y*y)
|
||||
|
||||
if p == 0:
|
||||
# 极点情况
|
||||
lat = math.pi/2 if z > 0 else -math.pi/2
|
||||
elev = abs(z) - a * math.sqrt(1 - e2)
|
||||
return [math.degrees(lon), math.degrees(lat), elev]
|
||||
|
||||
# 初始估计
|
||||
theta = math.atan2(z, p * (1 - f))
|
||||
lat_prev = math.atan2(z + ep2 * a * (1 - f) * math.sin(theta)**3,
|
||||
p - e2 * a * math.cos(theta)**3)
|
||||
|
||||
# 迭代求解纬度
|
||||
max_iterations = 10
|
||||
tolerance = 1e-12
|
||||
|
||||
for i in range(max_iterations):
|
||||
N = a / math.sqrt(1 - e2 * math.sin(lat_prev)**2)
|
||||
elev = p / math.cos(lat_prev) - N
|
||||
|
||||
# 更新纬度估计
|
||||
lat_new = math.atan2(z + e2 * N * math.sin(lat_prev), p)
|
||||
|
||||
# 检查收敛性
|
||||
if abs(lat_new - lat_prev) < tolerance:
|
||||
break
|
||||
|
||||
lat_prev = lat_new
|
||||
|
||||
# 最终计算高程
|
||||
N = a / math.sqrt(1 - e2 * math.sin(lat_prev)**2)
|
||||
elev = p / math.cos(lat_prev) - N
|
||||
|
||||
return [math.degrees(lon), math.degrees(lat_prev), elev]
|
||||
|
||||
def is_point_in_bounds(self, vertex):
|
||||
"""检查点是否在指定的地理范围内"""
|
||||
if not self.location_bounds:
|
||||
return True
|
||||
|
||||
# 将笛卡尔坐标转换为WGS84
|
||||
try:
|
||||
lon, lat, elev = self.cartesian_to_wgs84(vertex[0], vertex[1], vertex[2])
|
||||
|
||||
# 检查是否在边界范围内
|
||||
return (self.location_bounds['min_lon'] <= lon <= self.location_bounds['max_lon'] and
|
||||
self.location_bounds['min_lat'] <= lat <= self.location_bounds['max_lat'])
|
||||
except:
|
||||
return False
|
||||
|
||||
def multiply_matrix_vector(self, matrix, vector):
|
||||
"""4x4矩阵与4D向量相乘"""
|
||||
m = [
|
||||
[matrix[0], matrix[4], matrix[8], matrix[12]],
|
||||
[matrix[1], matrix[5], matrix[9], matrix[13]],
|
||||
[matrix[2], matrix[6], matrix[10], matrix[14]],
|
||||
[matrix[3], matrix[7], matrix[11], matrix[15]]
|
||||
]
|
||||
|
||||
v = [vector[0], vector[1], vector[2], 1.0]
|
||||
|
||||
result = [
|
||||
m[0][0]*v[0] + m[0][1]*v[1] + m[0][2]*v[2] + m[0][3]*v[3],
|
||||
m[1][0]*v[0] + m[1][1]*v[1] + m[1][2]*v[2] + m[1][3]*v[3],
|
||||
m[2][0]*v[0] + m[2][1]*v[1] + m[2][2]*v[2] + m[2][3]*v[3]
|
||||
]
|
||||
|
||||
return result
|
||||
|
||||
def multiply_matrices(self, m1, m2):
|
||||
"""两个4x4矩阵相乘"""
|
||||
def list_to_matrix(lst):
|
||||
return [
|
||||
[lst[0], lst[4], lst[8], lst[12]],
|
||||
[lst[1], lst[5], lst[9], lst[13]],
|
||||
[lst[2], lst[6], lst[10], lst[14]],
|
||||
[lst[3], lst[7], lst[11], lst[15]]
|
||||
]
|
||||
|
||||
def matrix_to_list(mat):
|
||||
return [
|
||||
mat[0][0], mat[1][0], mat[2][0], mat[3][0],
|
||||
mat[0][1], mat[1][1], mat[2][1], mat[3][1],
|
||||
mat[0][2], mat[1][2], mat[2][2], mat[3][2],
|
||||
mat[0][3], mat[1][3], mat[2][3], mat[3][3]
|
||||
]
|
||||
|
||||
a = list_to_matrix(m1)
|
||||
b = list_to_matrix(m2)
|
||||
|
||||
result = [[0 for _ in range(4)] for _ in range(4)]
|
||||
for i in range(4):
|
||||
for j in range(4):
|
||||
for k in range(4):
|
||||
result[i][j] += a[i][k] * b[k][j]
|
||||
|
||||
return matrix_to_list(result)
|
||||
|
||||
def apply_transform_to_vertices(self, vertices, transform_matrix):
|
||||
"""对顶点应用变换矩阵"""
|
||||
if not transform_matrix:
|
||||
return vertices
|
||||
|
||||
transformed_vertices = []
|
||||
for vertex in vertices:
|
||||
transformed = self.multiply_matrix_vector(transform_matrix, vertex)
|
||||
transformed_vertices.append(transformed)
|
||||
|
||||
return transformed_vertices
|
||||
|
||||
def parse_tileset_json(self, tileset_path, parent_transform=None):
|
||||
"""解析tileset.json文件,收集B3DM文件和变换矩阵"""
|
||||
try:
|
||||
with open(tileset_path, 'r', encoding='utf-8') as f:
|
||||
tileset_data = json.load(f)
|
||||
|
||||
b3dm_files = []
|
||||
|
||||
def process_node(node, base_path, accumulated_transform):
|
||||
current_transform = node.get('transform')
|
||||
|
||||
if current_transform and accumulated_transform:
|
||||
final_transform = self.multiply_matrices(accumulated_transform, current_transform)
|
||||
elif current_transform:
|
||||
final_transform = current_transform
|
||||
else:
|
||||
final_transform = accumulated_transform
|
||||
|
||||
if 'content' in node and 'uri' in node['content']:
|
||||
uri = node['content']['uri']
|
||||
if uri.endswith('.b3dm'):
|
||||
full_path = os.path.join(base_path, uri)
|
||||
if os.path.exists(full_path):
|
||||
b3dm_files.append((full_path, final_transform))
|
||||
elif uri.endswith('.json'):
|
||||
sub_tileset_path = os.path.join(base_path, uri)
|
||||
if os.path.exists(sub_tileset_path):
|
||||
sub_files = self.parse_tileset_json(sub_tileset_path, final_transform)
|
||||
b3dm_files.extend(sub_files)
|
||||
|
||||
if 'children' in node:
|
||||
for child in node['children']:
|
||||
process_node(child, base_path, final_transform)
|
||||
|
||||
base_path = os.path.dirname(tileset_path)
|
||||
if 'root' in tileset_data:
|
||||
process_node(tileset_data['root'], base_path, parent_transform)
|
||||
|
||||
return b3dm_files
|
||||
|
||||
except Exception as e:
|
||||
print(f"解析tileset.json时出错: {e}")
|
||||
return []
|
||||
|
||||
def parse_b3dm_file(self, file_path):
|
||||
"""解析B3DM文件"""
|
||||
try:
|
||||
with open(file_path, 'rb') as f:
|
||||
magic = f.read(4)
|
||||
if magic != b'b3dm':
|
||||
return None
|
||||
|
||||
version = struct.unpack('<I', f.read(4))[0]
|
||||
byte_length = struct.unpack('<I', f.read(4))[0]
|
||||
feature_table_json_byte_length = struct.unpack('<I', f.read(4))[0]
|
||||
feature_table_binary_byte_length = struct.unpack('<I', f.read(4))[0]
|
||||
batch_table_json_byte_length = struct.unpack('<I', f.read(4))[0]
|
||||
batch_table_binary_byte_length = struct.unpack('<I', f.read(4))[0]
|
||||
|
||||
f.seek(28 + feature_table_json_byte_length + feature_table_binary_byte_length +
|
||||
batch_table_json_byte_length + batch_table_binary_byte_length)
|
||||
|
||||
gltf_data = f.read()
|
||||
return self.parse_gltf_data(gltf_data)
|
||||
|
||||
except Exception as e:
|
||||
print(f"解析B3DM文件 {file_path} 失败: {e}")
|
||||
return None
|
||||
|
||||
def parse_gltf_data(self, gltf_data):
|
||||
"""解析glTF数据"""
|
||||
try:
|
||||
if gltf_data[:4] == b'glTF':
|
||||
return self.parse_glb_data(gltf_data)
|
||||
else:
|
||||
gltf_json = json.loads(gltf_data.decode('utf-8'))
|
||||
return self.extract_vertices_from_gltf(gltf_json, None)
|
||||
except Exception as e:
|
||||
print(f"解析glTF数据失败: {e}")
|
||||
return None
|
||||
|
||||
def parse_glb_data(self, glb_data):
|
||||
"""解析GLB格式的glTF数据"""
|
||||
try:
|
||||
magic = glb_data[:4]
|
||||
if magic != b'glTF':
|
||||
return None
|
||||
|
||||
version = struct.unpack('<I', glb_data[4:8])[0]
|
||||
total_length = struct.unpack('<I', glb_data[8:12])[0]
|
||||
|
||||
offset = 12
|
||||
json_data = None
|
||||
binary_data = None
|
||||
|
||||
while offset < len(glb_data):
|
||||
if offset + 8 > len(glb_data):
|
||||
break
|
||||
|
||||
chunk_length = struct.unpack('<I', glb_data[offset:offset+4])[0]
|
||||
chunk_type = glb_data[offset+4:offset+8]
|
||||
chunk_data = glb_data[offset+8:offset+8+chunk_length]
|
||||
|
||||
if chunk_type == b'JSON':
|
||||
json_data = json.loads(chunk_data.decode('utf-8'))
|
||||
elif chunk_type == b'BIN\x00':
|
||||
binary_data = chunk_data
|
||||
|
||||
offset += 8 + chunk_length
|
||||
offset = (offset + 3) & ~3
|
||||
|
||||
if json_data:
|
||||
return self.extract_vertices_from_gltf(json_data, binary_data)
|
||||
|
||||
except Exception as e:
|
||||
print(f"解析GLB数据失败: {e}")
|
||||
return None
|
||||
|
||||
def extract_vertices_from_gltf(self, gltf_json, binary_data):
|
||||
"""从glTF JSON中提取顶点数据"""
|
||||
vertices = []
|
||||
|
||||
try:
|
||||
if 'extensionsUsed' in gltf_json and 'KHR_draco_mesh_compression' in gltf_json['extensionsUsed']:
|
||||
if DracoPy is None:
|
||||
print("警告: 检测到Draco压缩但DracoPy未安装")
|
||||
return vertices
|
||||
return self.extract_draco_vertices(gltf_json, binary_data)
|
||||
|
||||
if 'meshes' not in gltf_json:
|
||||
return vertices
|
||||
|
||||
for mesh in gltf_json['meshes']:
|
||||
for primitive in mesh['primitives']:
|
||||
if 'attributes' in primitive and 'POSITION' in primitive['attributes']:
|
||||
position_accessor_index = primitive['attributes']['POSITION']
|
||||
|
||||
if 'accessors' in gltf_json and position_accessor_index < len(gltf_json['accessors']):
|
||||
accessor = gltf_json['accessors'][position_accessor_index]
|
||||
buffer_view_index = accessor['bufferView']
|
||||
|
||||
if 'bufferViews' in gltf_json and buffer_view_index < len(gltf_json['bufferViews']):
|
||||
buffer_view = gltf_json['bufferViews'][buffer_view_index]
|
||||
buffer_index = buffer_view['buffer']
|
||||
byte_offset = buffer_view.get('byteOffset', 0) + accessor.get('byteOffset', 0)
|
||||
|
||||
if binary_data and buffer_index == 0:
|
||||
component_type = accessor['componentType']
|
||||
count = accessor['count']
|
||||
|
||||
if component_type == 5126: # FLOAT
|
||||
vertex_data = struct.unpack(f'<{count*3}f',
|
||||
binary_data[byte_offset:byte_offset+count*12])
|
||||
for i in range(0, len(vertex_data), 3):
|
||||
vertices.append([vertex_data[i], vertex_data[i+1], vertex_data[i+2]])
|
||||
|
||||
except Exception as e:
|
||||
print(f"提取顶点数据失败: {e}")
|
||||
|
||||
return vertices
|
||||
|
||||
def extract_draco_vertices(self, gltf_json, binary_data):
|
||||
"""提取Draco压缩的顶点数据"""
|
||||
vertices = []
|
||||
|
||||
try:
|
||||
if 'meshes' not in gltf_json:
|
||||
return vertices
|
||||
|
||||
for mesh in gltf_json['meshes']:
|
||||
for primitive in mesh['primitives']:
|
||||
if 'extensions' in primitive and 'KHR_draco_mesh_compression' in primitive['extensions']:
|
||||
draco_ext = primitive['extensions']['KHR_draco_mesh_compression']
|
||||
buffer_view_index = draco_ext['bufferView']
|
||||
|
||||
if 'bufferViews' in gltf_json and buffer_view_index < len(gltf_json['bufferViews']):
|
||||
buffer_view = gltf_json['bufferViews'][buffer_view_index]
|
||||
byte_offset = buffer_view.get('byteOffset', 0)
|
||||
byte_length = buffer_view['byteLength']
|
||||
|
||||
if binary_data:
|
||||
draco_data = binary_data[byte_offset:byte_offset+byte_length]
|
||||
|
||||
mesh_data = DracoPy.decode(draco_data)
|
||||
if hasattr(mesh_data, 'points'):
|
||||
points = mesh_data.points
|
||||
for point in points:
|
||||
vertices.append([float(point[0]), float(point[1]), float(point[2])])
|
||||
|
||||
except Exception as e:
|
||||
print(f"解码Draco数据失败: {e}")
|
||||
|
||||
return vertices
|
||||
|
||||
def calculate_triangle_angles(self, p1, p2, p3):
|
||||
"""计算三角形的三个内角(度)"""
|
||||
# 计算三边长度
|
||||
a = np.linalg.norm(p2 - p3) # 边a对应角A(p1)
|
||||
b = np.linalg.norm(p1 - p3) # 边b对应角B(p2)
|
||||
c = np.linalg.norm(p1 - p2) # 边c对应角C(p3)
|
||||
|
||||
# 避免除零错误
|
||||
if a == 0 or b == 0 or c == 0:
|
||||
return [0, 0, 0]
|
||||
|
||||
# 使用余弦定理计算角度
|
||||
try:
|
||||
# 角A = arccos((b²+c²-a²)/(2bc))
|
||||
cos_A = (b*b + c*c - a*a) / (2*b*c)
|
||||
cos_B = (a*a + c*c - b*b) / (2*a*c)
|
||||
cos_C = (a*a + b*b - c*c) / (2*a*b)
|
||||
|
||||
# 限制余弦值范围,避免数值误差
|
||||
cos_A = np.clip(cos_A, -1.0, 1.0)
|
||||
cos_B = np.clip(cos_B, -1.0, 1.0)
|
||||
cos_C = np.clip(cos_C, -1.0, 1.0)
|
||||
|
||||
angle_A = np.arccos(cos_A) * 180 / np.pi
|
||||
angle_B = np.arccos(cos_B) * 180 / np.pi
|
||||
angle_C = np.arccos(cos_C) * 180 / np.pi
|
||||
|
||||
return [angle_A, angle_B, angle_C]
|
||||
except:
|
||||
return [0, 0, 0]
|
||||
|
||||
def is_valid_triangle(self, p1, p2, p3, min_angle=10.0, max_aspect_ratio=10.0):
|
||||
"""验证三角形质量,基于角度约束和长宽比"""
|
||||
angles = self.calculate_triangle_angles(p1, p2, p3)
|
||||
|
||||
# 检查最小角度约束(参考C#代码中的10度限制)
|
||||
if min(angles) < min_angle:
|
||||
return False
|
||||
|
||||
# 计算三边长度
|
||||
a = np.linalg.norm(p2 - p3)
|
||||
b = np.linalg.norm(p1 - p3)
|
||||
c = np.linalg.norm(p1 - p2)
|
||||
|
||||
# 检查长宽比约束
|
||||
max_side = max(a, b, c)
|
||||
min_side = min(a, b, c)
|
||||
if min_side > 0 and max_side / min_side > max_aspect_ratio:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def calculate_circumcenter_and_radius(self, p1, p2, p3):
|
||||
"""计算三角形外接圆圆心和半径(高精度算法)"""
|
||||
try:
|
||||
x1, y1 = p1[0], p1[1]
|
||||
x2, y2 = p2[0], p2[1]
|
||||
x3, y3 = p3[0], p3[1]
|
||||
|
||||
# 使用C#代码中的高精度外接圆计算公式
|
||||
d = 2 * (x1 * (y2 - y3) + x2 * (y3 - y1) + x3 * (y1 - y2))
|
||||
|
||||
if abs(d) < 1e-10: # 三点共线
|
||||
return None, float('inf')
|
||||
|
||||
ux = ((x1*x1 + y1*y1) * (y2 - y3) + (x2*x2 + y2*y2) * (y3 - y1) + (x3*x3 + y3*y3) * (y1 - y2)) / d
|
||||
uy = ((x1*x1 + y1*y1) * (x3 - x2) + (x2*x2 + y2*y2) * (x1 - y3) + (x3*x3 + y3*y3) * (x2 - x1)) / d
|
||||
|
||||
# 计算半径
|
||||
radius = np.sqrt((ux - x1)**2 + (uy - y1)**2)
|
||||
|
||||
return np.array([ux, uy]), radius
|
||||
except:
|
||||
return None, float('inf')
|
||||
|
||||
def calculate_volume_delaunay(self, vertices, base_elevation=None, min_angle=10.0, use_quality_filter=True):
|
||||
"""使用Delaunay三角剖分计算体积"""
|
||||
if len(vertices) < 4:
|
||||
print("顶点数量不足,无法进行三角剖分")
|
||||
return 0.0
|
||||
|
||||
try:
|
||||
points = np.array(vertices)
|
||||
|
||||
if base_elevation is None:
|
||||
base_elevation = np.min(points[:, 2])
|
||||
|
||||
print(f"Delaunay方法使用基准面高程: {base_elevation:.2f} 米")
|
||||
print(f"质量控制参数: 最小角度={min_angle}°, 质量过滤={'开启' if use_quality_filter else '关闭'}")
|
||||
|
||||
adjusted_points = points.copy()
|
||||
adjusted_points[:, 2] = np.maximum(points[:, 2] - base_elevation, 0)
|
||||
|
||||
print("正在进行Delaunay三角剖分...")
|
||||
tri = Delaunay(adjusted_points)
|
||||
|
||||
valid_simplices = []
|
||||
total_simplices = len(tri.simplices)
|
||||
|
||||
if use_quality_filter:
|
||||
print("正在进行三角形质量检查...")
|
||||
for simplex in tqdm(tri.simplices, desc="质量检查", unit="个"):
|
||||
p0 = adjusted_points[simplex[0]]
|
||||
p1 = adjusted_points[simplex[1]]
|
||||
p2 = adjusted_points[simplex[2]]
|
||||
p3 = adjusted_points[simplex[3]]
|
||||
|
||||
faces = [(p0, p1, p2), (p0, p1, p3), (p0, p2, p3), (p1, p2, p3)]
|
||||
valid_faces = 0
|
||||
|
||||
for face in faces:
|
||||
if self.is_valid_triangle(face[0], face[1], face[2], min_angle):
|
||||
valid_faces += 1
|
||||
|
||||
if valid_faces >= 3:
|
||||
valid_simplices.append(simplex)
|
||||
|
||||
print(f"质量过滤: {len(valid_simplices)}/{total_simplices} 个四面体通过质量检查")
|
||||
else:
|
||||
valid_simplices = tri.simplices
|
||||
|
||||
total_volume = 0.0
|
||||
valid_volume_count = 0
|
||||
|
||||
print(f"计算 {len(valid_simplices)} 个四面体的体积...")
|
||||
|
||||
for simplex in tqdm(valid_simplices, desc="计算四面体体积", unit="个"):
|
||||
p0 = adjusted_points[simplex[0]]
|
||||
p1 = adjusted_points[simplex[1]]
|
||||
p2 = adjusted_points[simplex[2]]
|
||||
p3 = adjusted_points[simplex[3]]
|
||||
|
||||
v1 = p1 - p0
|
||||
v2 = p2 - p0
|
||||
v3 = p3 - p0
|
||||
|
||||
det = np.linalg.det(np.array([v1, v2, v3]))
|
||||
volume = abs(det) / 6.0
|
||||
|
||||
if volume > 1e-12:
|
||||
total_volume += volume
|
||||
valid_volume_count += 1
|
||||
|
||||
print(f"有效体积计算: {valid_volume_count}/{len(valid_simplices)} 个四面体")
|
||||
|
||||
return total_volume
|
||||
|
||||
except Exception as e:
|
||||
print(f"Delaunay三角剖分计算体积失败: {e}")
|
||||
return 0.0
|
||||
|
||||
|
||||
def load_and_filter_vertices(self, tileset_path):
|
||||
"""加载并过滤指定范围内的顶点"""
|
||||
print(f"开始处理tileset: {tileset_path}")
|
||||
|
||||
# 解析tileset获取所有B3DM文件
|
||||
b3dm_files = self.parse_tileset_json(tileset_path)
|
||||
print(f"找到 {len(b3dm_files)} 个B3DM文件")
|
||||
|
||||
if not b3dm_files:
|
||||
print("未找到任何B3DM文件")
|
||||
return False
|
||||
|
||||
# 处理每个B3DM文件
|
||||
processed_count = 0
|
||||
total_vertices = 0
|
||||
filtered_count = 0
|
||||
|
||||
for i, (b3dm_file, transform_matrix) in enumerate(tqdm(b3dm_files, desc="处理B3DM文件", unit="文件")):
|
||||
# print(f"处理文件 {i+1}/{len(b3dm_files)}: {os.path.basename(b3dm_file)}")
|
||||
|
||||
vertices = self.parse_b3dm_file(b3dm_file)
|
||||
if vertices:
|
||||
# 应用变换矩阵
|
||||
if transform_matrix:
|
||||
vertices = self.apply_transform_to_vertices(vertices, transform_matrix)
|
||||
|
||||
# 过滤范围内的顶点
|
||||
for vertex in vertices:
|
||||
total_vertices += 1
|
||||
if self.is_point_in_bounds(vertex):
|
||||
self.filtered_vertices.append(vertex)
|
||||
filtered_count += 1
|
||||
|
||||
self.all_vertices.extend(vertices)
|
||||
processed_count += 1
|
||||
# tqdm.write(f" {os.path.basename(b3dm_file)}: 提取到 {len(vertices)} 个顶点")
|
||||
|
||||
print(f"\n成功处理 {processed_count}/{len(b3dm_files)} 个文件")
|
||||
print(f"总计提取 {total_vertices} 个顶点")
|
||||
print(f"范围内顶点 {filtered_count} 个")
|
||||
|
||||
return len(self.filtered_vertices) > 0
|
||||
|
||||
|
||||
|
||||
def calculate_volume(self, tileset_path, base_elevation=None, min_angle=10.0, use_quality_filter=True):
|
||||
"""计算指定范围内三维模型的体积
|
||||
|
||||
Args:
|
||||
tileset_path: tileset.json文件路径
|
||||
base_elevation: 基准面高程
|
||||
min_angle: 最小角度约束(度)
|
||||
use_quality_filter: 是否启用质量过滤
|
||||
"""
|
||||
if not self.load_and_filter_vertices(tileset_path):
|
||||
print("未找到范围内的顶点数据")
|
||||
return 0.0
|
||||
|
||||
print(f"\n开始计算体积,使用Delaunay三角剖分方法")
|
||||
print(f"参与计算的顶点数: {len(self.filtered_vertices)}")
|
||||
|
||||
points = np.array(self.filtered_vertices)
|
||||
if base_elevation is None:
|
||||
base_elevation = np.min(points[:, 2])
|
||||
|
||||
print(f"统一基准面高程: {base_elevation:.2f} 米")
|
||||
|
||||
volume = self.calculate_volume_delaunay(self.filtered_vertices, base_elevation, min_angle, use_quality_filter)
|
||||
|
||||
print(f"\n计算结果:")
|
||||
print(f"体积: {volume:.6f} 立方米")
|
||||
print(f"体积: {volume/1000000:.6f} 立方千米")
|
||||
|
||||
return volume
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 3:
|
||||
print("用法: python volume_calculator.py <lboundary.json路径> <tileset.json路径> [基准面高程] [最小角度] [质量过滤]")
|
||||
print("基准面高程: 基准面高程(米),默认使用最低点")
|
||||
print("最小角度: 三角形最小角度约束(度),默认10.0")
|
||||
print("质量过滤: 是否启用质量过滤(true/false),默认true")
|
||||
print("示例: python volume_calculator.py boundary.json tileset.json 100.0 15.0 true")
|
||||
return
|
||||
|
||||
location_file = sys.argv[1]
|
||||
tileset_path = sys.argv[2]
|
||||
|
||||
# 解析可选参数
|
||||
base_elevation = None
|
||||
min_angle = 10.0
|
||||
use_quality_filter = True
|
||||
|
||||
if len(sys.argv) > 3:
|
||||
try:
|
||||
base_elevation = float(sys.argv[3])
|
||||
except ValueError:
|
||||
print("错误:基准面高程必须是数字")
|
||||
return
|
||||
|
||||
if len(sys.argv) > 4:
|
||||
try:
|
||||
min_angle = float(sys.argv[4])
|
||||
except ValueError:
|
||||
print("错误:最小角度必须是数字")
|
||||
return
|
||||
|
||||
if len(sys.argv) > 5:
|
||||
use_quality_filter = sys.argv[5].lower() in ['true', '1', 'yes', 'on']
|
||||
|
||||
if not os.path.exists(location_file):
|
||||
print(f"错误: 位置文件不存在 {location_file}")
|
||||
return
|
||||
|
||||
if not os.path.exists(tileset_path):
|
||||
print(f"错误: tileset文件不存在 {tileset_path}")
|
||||
return
|
||||
|
||||
calculator = VolumeCalculator(location_file)
|
||||
volume = calculator.calculate_volume(tileset_path, base_elevation, min_angle, use_quality_filter)
|
||||
|
||||
if volume > 0:
|
||||
print("\n体积计算完成!")
|
||||
else:
|
||||
print("\n体积计算失败!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
Before Width: | Height: | Size: 4.2 MiB |
|
Before Width: | Height: | Size: 5.3 MiB |
@ -1,8 +0,0 @@
|
||||
import os.path
|
||||
|
||||
from middleware.minio_util import downBigFile
|
||||
|
||||
|
||||
miniourl=r"media/22d45cc5-0ba7-4bc3-a302-ca1a28c40fd2/DJI_202509121519_001_22d45cc5-0ba7-4bc3-a302-ca1a28c40fd2/DJI_20250912152112_0001_V.mp4"
|
||||
file_path=downBigFile(miniourl)
|
||||
print(f"os.path.abspath(file_path) {os.path.abspath(file_path)}")
|
||||
@ -94,7 +94,7 @@ def func_100000(results, cls_id_list, type_name_list, func_id_10001, list_track_
|
||||
trickier_detail = {
|
||||
# "track_id": results.track_ids[i],
|
||||
"confidence": results.confs[i],
|
||||
"cls_id": i,
|
||||
"cls_id": ind,
|
||||
"type_name": type_name_list[ind],
|
||||
"box": boxes[i]
|
||||
}
|
||||
|
||||
76
grpc_util/grpc_proto_demo/async_check_grpc_client.py
Normal file
@ -0,0 +1,76 @@
|
||||
import asyncio
|
||||
import grpc
|
||||
from grpc_util.grpc_proto_demo.check_grpc import check_grpc_pb2_grpc, check_grpc_pb2
|
||||
|
||||
|
||||
async def async_check_server_status(channel):
|
||||
try:
|
||||
health_stub = check_grpc_pb2_grpc.HealthCheckStub(channel)
|
||||
# 注意:这里不需要await,因为Check方法不是异步的
|
||||
response = health_stub.Check(check_grpc_pb2.HealthCheckRequest(service="TaskService"))
|
||||
return response.status == check_grpc_pb2.HealthCheckResponse.ServingStatus.SERVING
|
||||
except grpc.RpcError as e:
|
||||
print(f"Health check failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def async_check_grpc_request(max_retries=3, delay=5):
|
||||
retries = 0
|
||||
channel = None
|
||||
while retries < max_retries:
|
||||
try:
|
||||
channel = grpc.insecure_channel('localhost:50051')
|
||||
|
||||
if not await async_check_server_status(channel):
|
||||
raise Exception("Server is not healthy")
|
||||
|
||||
stub = check_grpc_pb2_grpc.TaskServiceStub(channel)
|
||||
request = check_grpc_pb2.TaskRequest(
|
||||
task_id="d6118954-a170-4e1c-84bd-ddbd3114b354",
|
||||
sn="8UUXN6S00A0CK7",
|
||||
content_body=check_grpc_pb2.ContentBody(
|
||||
org_code="HMZHB",
|
||||
func_id=[101204],
|
||||
source_url="xxxxxxxxxx",
|
||||
push_url="",
|
||||
confidence=0.4,
|
||||
para_list=[
|
||||
check_grpc_pb2.ParaList(
|
||||
func_id=101204,
|
||||
para_invade_enable=True
|
||||
)
|
||||
],
|
||||
invade=check_grpc_pb2.Invade(
|
||||
invade_file="meta_data/高压线-0826.geojson",
|
||||
camera_para_url="meta_data/camera_para/hami_camera_para .txt"
|
||||
)
|
||||
)
|
||||
)
|
||||
# ProcessTask可能也不是异步的,所以不需要await
|
||||
response = stub.ProcessTask(request)
|
||||
print(f"Response: task_id={response.task_id}, success={response.success}, message={response.message}")
|
||||
|
||||
if channel:
|
||||
channel.close() # 同步关闭,不需要await
|
||||
|
||||
return True
|
||||
|
||||
except grpc.RpcError as e:
|
||||
retries += 1
|
||||
print(f"RPC error occurred (attempt {retries}/{max_retries}): {e}")
|
||||
except Exception as e:
|
||||
print(f"Error occurred: {e}")
|
||||
retries += 1
|
||||
|
||||
if channel:
|
||||
channel.close() # 确保在重试前关闭连接
|
||||
channel = None
|
||||
|
||||
await asyncio.sleep(delay) # 异步等待
|
||||
|
||||
print("All retry attempts failed")
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(async_check_grpc_request())
|
||||
58
grpc_util/grpc_proto_demo/check_grpc/check_grpc.proto
Normal file
@ -0,0 +1,58 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package task;
|
||||
|
||||
service TaskService {
|
||||
rpc ProcessTask (TaskRequest) returns (TaskResponse);
|
||||
}
|
||||
|
||||
// 添加健康检查服务
|
||||
service HealthCheck {
|
||||
rpc Check (HealthCheckRequest) returns (HealthCheckResponse);
|
||||
}
|
||||
|
||||
message HealthCheckRequest {
|
||||
string service = 1;
|
||||
}
|
||||
|
||||
message HealthCheckResponse {
|
||||
enum ServingStatus {
|
||||
UNKNOWN = 0;
|
||||
SERVING = 1;
|
||||
NOT_SERVING = 2;
|
||||
SERVICE_UNKNOWN = 3;
|
||||
}
|
||||
ServingStatus status = 1;
|
||||
}
|
||||
|
||||
message TaskRequest {
|
||||
string task_id = 1;
|
||||
string sn = 2;
|
||||
ContentBody content_body = 3;
|
||||
}
|
||||
|
||||
message ContentBody {
|
||||
string org_code = 1;
|
||||
repeated int32 func_id = 2;
|
||||
string source_url = 3;
|
||||
string push_url = 4;
|
||||
float confidence = 5;
|
||||
repeated ParaList para_list = 6;
|
||||
Invade invade = 7;
|
||||
}
|
||||
|
||||
message ParaList {
|
||||
int32 func_id = 1;
|
||||
bool para_invade_enable = 2;
|
||||
}
|
||||
|
||||
message Invade {
|
||||
string invade_file = 1;
|
||||
string camera_para_url = 2;
|
||||
}
|
||||
|
||||
message TaskResponse {
|
||||
string task_id = 1;
|
||||
bool success = 2;
|
||||
string message = 3;
|
||||
}
|
||||
54
grpc_util/grpc_proto_demo/check_grpc/check_grpc_pb2.py
Normal file
@ -0,0 +1,54 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# NO CHECKED-IN PROTOBUF GENCODE
|
||||
# source: check_grpc.proto
|
||||
# Protobuf Python Version: 6.31.1
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import runtime_version as _runtime_version
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
_runtime_version.ValidateProtobufRuntimeVersion(
|
||||
_runtime_version.Domain.PUBLIC,
|
||||
6,
|
||||
31,
|
||||
1,
|
||||
'',
|
||||
'check_grpc.proto'
|
||||
)
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63heck_grpc.proto\x12\x04task\"%\n\x12HealthCheckRequest\x12\x0f\n\x07service\x18\x01 \x01(\t\"\x9f\x01\n\x13HealthCheckResponse\x12\x37\n\x06status\x18\x01 \x01(\x0e\x32\'.task.HealthCheckResponse.ServingStatus\"O\n\rServingStatus\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x0b\n\x07SERVING\x10\x01\x12\x0f\n\x0bNOT_SERVING\x10\x02\x12\x13\n\x0fSERVICE_UNKNOWN\x10\x03\"S\n\x0bTaskRequest\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\n\n\x02sn\x18\x02 \x01(\t\x12\'\n\x0c\x63ontent_body\x18\x03 \x01(\x0b\x32\x11.task.ContentBody\"\xab\x01\n\x0b\x43ontentBody\x12\x10\n\x08org_code\x18\x01 \x01(\t\x12\x0f\n\x07\x66unc_id\x18\x02 \x03(\x05\x12\x12\n\nsource_url\x18\x03 \x01(\t\x12\x10\n\x08push_url\x18\x04 \x01(\t\x12\x12\n\nconfidence\x18\x05 \x01(\x02\x12!\n\tpara_list\x18\x06 \x03(\x0b\x32\x0e.task.ParaList\x12\x1c\n\x06invade\x18\x07 \x01(\x0b\x32\x0c.task.Invade\"7\n\x08ParaList\x12\x0f\n\x07\x66unc_id\x18\x01 \x01(\x05\x12\x1a\n\x12para_invade_enable\x18\x02 \x01(\x08\"6\n\x06Invade\x12\x13\n\x0binvade_file\x18\x01 \x01(\t\x12\x17\n\x0f\x63\x61mera_para_url\x18\x02 \x01(\t\"A\n\x0cTaskResponse\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x0f\n\x07message\x18\x03 \x01(\t2C\n\x0bTaskService\x12\x34\n\x0bProcessTask\x12\x11.task.TaskRequest\x1a\x12.task.TaskResponse2K\n\x0bHealthCheck\x12<\n\x05\x43heck\x12\x18.task.HealthCheckRequest\x1a\x19.task.HealthCheckResponseb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'check_grpc_pb2', _globals)
|
||||
if not _descriptor._USE_C_DESCRIPTORS:
|
||||
DESCRIPTOR._loaded_options = None
|
||||
_globals['_HEALTHCHECKREQUEST']._serialized_start=26
|
||||
_globals['_HEALTHCHECKREQUEST']._serialized_end=63
|
||||
_globals['_HEALTHCHECKRESPONSE']._serialized_start=66
|
||||
_globals['_HEALTHCHECKRESPONSE']._serialized_end=225
|
||||
_globals['_HEALTHCHECKRESPONSE_SERVINGSTATUS']._serialized_start=146
|
||||
_globals['_HEALTHCHECKRESPONSE_SERVINGSTATUS']._serialized_end=225
|
||||
_globals['_TASKREQUEST']._serialized_start=227
|
||||
_globals['_TASKREQUEST']._serialized_end=310
|
||||
_globals['_CONTENTBODY']._serialized_start=313
|
||||
_globals['_CONTENTBODY']._serialized_end=484
|
||||
_globals['_PARALIST']._serialized_start=486
|
||||
_globals['_PARALIST']._serialized_end=541
|
||||
_globals['_INVADE']._serialized_start=543
|
||||
_globals['_INVADE']._serialized_end=597
|
||||
_globals['_TASKRESPONSE']._serialized_start=599
|
||||
_globals['_TASKRESPONSE']._serialized_end=664
|
||||
_globals['_TASKSERVICE']._serialized_start=666
|
||||
_globals['_TASKSERVICE']._serialized_end=733
|
||||
_globals['_HEALTHCHECK']._serialized_start=735
|
||||
_globals['_HEALTHCHECK']._serialized_end=810
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
172
grpc_util/grpc_proto_demo/check_grpc/check_grpc_pb2_grpc.py
Normal file
@ -0,0 +1,172 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
import warnings
|
||||
|
||||
import grpc_util.grpc_proto_demo.check_grpc.check_grpc_pb2 as check__grpc__pb2
|
||||
|
||||
GRPC_GENERATED_VERSION = '1.76.0'
|
||||
GRPC_VERSION = grpc.__version__
|
||||
_version_not_supported = False
|
||||
|
||||
try:
|
||||
from grpc._utilities import first_version_is_lower
|
||||
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
|
||||
except ImportError:
|
||||
_version_not_supported = True
|
||||
|
||||
if _version_not_supported:
|
||||
raise RuntimeError(
|
||||
f'The grpc package installed is at version {GRPC_VERSION},'
|
||||
+ ' but the generated code in check_grpc_pb2_grpc.py depends on'
|
||||
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
|
||||
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
|
||||
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
|
||||
)
|
||||
|
||||
|
||||
class TaskServiceStub(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.ProcessTask = channel.unary_unary(
|
||||
'/task.TaskService/ProcessTask',
|
||||
request_serializer=check__grpc__pb2.TaskRequest.SerializeToString,
|
||||
response_deserializer=check__grpc__pb2.TaskResponse.FromString,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class TaskServiceServicer(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def ProcessTask(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_TaskServiceServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'ProcessTask': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.ProcessTask,
|
||||
request_deserializer=check__grpc__pb2.TaskRequest.FromString,
|
||||
response_serializer=check__grpc__pb2.TaskResponse.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'task.TaskService', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
server.add_registered_method_handlers('task.TaskService', rpc_method_handlers)
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class TaskService(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
@staticmethod
|
||||
def ProcessTask(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/task.TaskService/ProcessTask',
|
||||
check__grpc__pb2.TaskRequest.SerializeToString,
|
||||
check__grpc__pb2.TaskResponse.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class HealthCheckStub(object):
|
||||
"""添加健康检查服务
|
||||
"""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.Check = channel.unary_unary(
|
||||
'/task.HealthCheck/Check',
|
||||
request_serializer=check__grpc__pb2.HealthCheckRequest.SerializeToString,
|
||||
response_deserializer=check__grpc__pb2.HealthCheckResponse.FromString,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class HealthCheckServicer(object):
|
||||
"""添加健康检查服务
|
||||
"""
|
||||
|
||||
def Check(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_HealthCheckServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'Check': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Check,
|
||||
request_deserializer=check__grpc__pb2.HealthCheckRequest.FromString,
|
||||
response_serializer=check__grpc__pb2.HealthCheckResponse.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'task.HealthCheck', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
server.add_registered_method_handlers('task.HealthCheck', rpc_method_handlers)
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class HealthCheck(object):
|
||||
"""添加健康检查服务
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def Check(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/task.HealthCheck/Check',
|
||||
check__grpc__pb2.HealthCheckRequest.SerializeToString,
|
||||
check__grpc__pb2.HealthCheckResponse.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
83
grpc_util/grpc_proto_demo/check_grpc_client.py
Normal file
@ -0,0 +1,83 @@
|
||||
import grpc
|
||||
import time
|
||||
|
||||
from grpc_util.grpc_proto_demo.check_grpc import check_grpc_pb2_grpc, check_grpc_pb2
|
||||
|
||||
|
||||
|
||||
|
||||
def check_server_status(channel):
|
||||
try:
|
||||
health_stub = check_grpc_pb2_grpc.HealthCheckStub(channel)
|
||||
response = health_stub.Check(check_grpc_pb2.HealthCheckRequest(service="TaskService"))
|
||||
return response.status == check_grpc_pb2.HealthCheckResponse.ServingStatus.SERVING
|
||||
except grpc.RpcError as e:
|
||||
print(f"Health check failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def check_grpc_request(max_retries=3, delay=5):
|
||||
channel = None
|
||||
retries = 0
|
||||
|
||||
while retries < max_retries:
|
||||
try:
|
||||
# 创建通道
|
||||
channel = grpc.insecure_channel('localhost:50051')
|
||||
|
||||
# 检查服务器状态
|
||||
if not check_server_status(channel):
|
||||
raise Exception("Server is not healthy")
|
||||
|
||||
stub = check_grpc_pb2_grpc.TaskServiceStub(channel)
|
||||
|
||||
# 创建请求消息
|
||||
request = check_grpc_pb2.TaskRequest(
|
||||
task_id="d6118954-a170-4e1c-84bd-ddbd3114b354",
|
||||
sn="8UUXN6S00A0CK7",
|
||||
content_body=check_grpc_pb2.ContentBody(
|
||||
org_code="HMZHB",
|
||||
func_id=[101204],
|
||||
source_url="xxxxxxxxxx",
|
||||
push_url="",
|
||||
confidence=0.4,
|
||||
para_list=[
|
||||
check_grpc_pb2.ParaList(
|
||||
func_id=101204,
|
||||
para_invade_enable=True
|
||||
)
|
||||
],
|
||||
invade=check_grpc_pb2.Invade(
|
||||
invade_file="meta_data/高压线-0826.geojson",
|
||||
camera_para_url="meta_data/camera_para/hami_camera_para .txt"
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# 调用远程方法
|
||||
response = stub.ProcessTask(request)
|
||||
print(f"Response: task_id={response.task_id}, success={response.success}, message={response.message}")
|
||||
return True
|
||||
|
||||
except grpc.RpcError as e:
|
||||
retries += 1
|
||||
print(f"RPC error occurred (attempt {retries}/{max_retries}): {e}")
|
||||
if retries < max_retries:
|
||||
print(f"Retrying in {delay} seconds...")
|
||||
time.sleep(delay)
|
||||
except Exception as e:
|
||||
print(f"Error occurred: {e}")
|
||||
retries += 1
|
||||
if retries < max_retries:
|
||||
print(f"Retrying in {delay} seconds...")
|
||||
time.sleep(delay)
|
||||
finally:
|
||||
if channel:
|
||||
channel.close()
|
||||
|
||||
print("All retry attempts failed")
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
check_grpc_request()
|
||||
55
grpc_util/grpc_proto_demo/check_grpc_server.py
Normal file
@ -0,0 +1,55 @@
|
||||
from concurrent import futures
|
||||
import grpc
|
||||
|
||||
from grpc_util.grpc_proto_demo.check_grpc import check_grpc_pb2, check_grpc_pb2_grpc
|
||||
|
||||
|
||||
class TaskServiceServicer(check_grpc_pb2_grpc.TaskServiceServicer):
|
||||
def ProcessTask(self, request, context):
|
||||
print(f"Received task_id: {request.task_id}")
|
||||
|
||||
return check_grpc_pb2.TaskResponse(
|
||||
task_id=request.task_id,
|
||||
success=True,
|
||||
message="Task processed successfully"
|
||||
)
|
||||
|
||||
|
||||
class HealthCheckServicer(check_grpc_pb2_grpc.HealthCheckServicer):
|
||||
def Check(self, request, context):
|
||||
# 简单实现:总是返回SERVING状态
|
||||
# 实际应用中可以根据服务状态返回不同值
|
||||
return check_grpc_pb2.HealthCheckResponse(
|
||||
status=check_grpc_pb2.HealthCheckResponse.ServingStatus.SERVING
|
||||
)
|
||||
|
||||
|
||||
def serve():
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
|
||||
# 添加服务实现
|
||||
check_grpc_pb2_grpc.add_TaskServiceServicer_to_server(TaskServiceServicer(), server)
|
||||
check_grpc_pb2_grpc.add_HealthCheckServicer_to_server(HealthCheckServicer(), server)
|
||||
|
||||
server.add_insecure_port('[::]:50051')
|
||||
server.start()
|
||||
print("Server started, listening on port 50051...")
|
||||
|
||||
# try:
|
||||
# while True:
|
||||
# time.sleep(86400) # 保持运行
|
||||
# except KeyboardInterrupt:
|
||||
# server.stop(0)
|
||||
# 信号处理
|
||||
import signal
|
||||
def shutdown_handler(signum, frame):
|
||||
print(f"Received signal {signum}, shutting down...")
|
||||
server.stop(0)
|
||||
|
||||
signal.signal(signal.SIGINT, shutdown_handler)
|
||||
signal.signal(signal.SIGTERM, shutdown_handler)
|
||||
|
||||
server.wait_for_termination()
|
||||
|
||||
if __name__ == '__main__':
|
||||
serve()
|
||||
8
grpc_util/grpc_proto_demo/readme
Normal file
@ -0,0 +1,8 @@
|
||||
1、参考check_grpc.proto 手写 proto 相关结构
|
||||
2、使用编译命令生成grpc通讯相关,命令如下
|
||||
python -m grpc_tools.protoc --proto_path=proto_dir \
|
||||
--python_out=gen_dir \
|
||||
--grpc_python_out=gen_dir \
|
||||
proto_dir/check_grpc.proto
|
||||
只会生成requeset、response
|
||||
3、拷贝当前文件下的check_grpc_client.py 、check_grpc_server.py,重写逻辑代码
|
||||
0
grpc_util/grpc_sam3/__init__.py
Normal file
89
grpc_util/grpc_sam3/async_sam3_grpc_client.py
Normal file
@ -0,0 +1,89 @@
|
||||
import asyncio
|
||||
import grpc
|
||||
# from grpc_util.grpc_sam3 import grpc_sam3_img_pb2_grpc, grpc_sam3_img_pb2
|
||||
# from grpc_util.grpc_sam3.grpc_sam3_img_pb2_grpc import TaskServiceStub, HealthCheckStub
|
||||
# from grpc_util.grpc_sam3.grpc_sam3_img_pb2 import TaskRequest, TaskResponse, HealthCheckRequest, HealthCheckResponse
|
||||
# from . import grpc_sam3_img_pb2_grpc, grpc_sam3_img_pb2
|
||||
from grpc_util.grpc_sam3 import grpc_sam3_img_pb2_grpc, grpc_sam3_img_pb2
|
||||
|
||||
|
||||
async def async_check_server_status(channel):
|
||||
try:
|
||||
health_stub = grpc_sam3_img_pb2_grpc.HealthCheckStub(channel)
|
||||
# 注意:这里不需要await,因为Check方法不是异步的
|
||||
response = health_stub.Check(grpc_sam3_img_pb2.HealthCheckRequest(service="TaskService"))
|
||||
return response.status == grpc_sam3_img_pb2.HealthCheckResponse.ServingStatus.SERVING
|
||||
except grpc.RpcError as e:
|
||||
print(f"Health check failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def grpc_sam3_pic_predict(
|
||||
task_id:str,
|
||||
sn:str,
|
||||
img_url:str,
|
||||
prompt:str,
|
||||
confidence:float,
|
||||
mqtt_ip:str,
|
||||
mqtt_port:int,
|
||||
mqtt_topic:str,
|
||||
max_retries=3, delay=5):
|
||||
retries = 0
|
||||
channel = None
|
||||
while retries < max_retries:
|
||||
try:
|
||||
channel = grpc.insecure_channel('0.0.0.0:50051')
|
||||
|
||||
if not await async_check_server_status(channel):
|
||||
raise Exception("Server is not healthy")
|
||||
|
||||
stub = grpc_sam3_img_pb2_grpc.TaskServiceStub(channel)
|
||||
request = grpc_sam3_img_pb2.TaskRequest(
|
||||
task_id=task_id,
|
||||
sn=sn,
|
||||
content_body=grpc_sam3_img_pb2.ContentBody(
|
||||
img_url=img_url,
|
||||
prompt=prompt,
|
||||
confidence=0.5,
|
||||
mqtt_ip=mqtt_ip,
|
||||
mqtt_port=mqtt_port,
|
||||
mqtt_topic=mqtt_topic
|
||||
)
|
||||
)
|
||||
# ProcessTask可能也不是异步的,所以不需要await
|
||||
response = stub.ProcessTask(request)
|
||||
print(f"Response: task_id={response.task_id}, success={response.success}, message={response.message}")
|
||||
|
||||
if channel:
|
||||
channel.close() # 同步关闭,不需要await
|
||||
|
||||
return True
|
||||
|
||||
except grpc.RpcError as e:
|
||||
retries += 1
|
||||
print(f"RPC error occurred (attempt {retries}/{max_retries}): {e}")
|
||||
except Exception as e:
|
||||
print(f"Error occurred: {e}")
|
||||
retries += 1
|
||||
|
||||
if channel:
|
||||
channel.close() # 确保在重试前关闭连接
|
||||
channel = None
|
||||
|
||||
await asyncio.sleep(delay) # 异步等待
|
||||
|
||||
print("All retry attempts failed")
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(grpc_sam3_pic_predict(
|
||||
task_id="d6118954-a170-4e1c-84bd-ddbd3114b354111",
|
||||
sn="8UUXN6S00A0CK7",
|
||||
img_url="demo/03.png",
|
||||
prompt="cat",
|
||||
confidence=0.5,
|
||||
mqtt_ip="47.108.62.6",
|
||||
mqtt_port=12503,
|
||||
mqtt_topic="thing/product/ai/events",
|
||||
))
|
||||
48
grpc_util/grpc_sam3/grpc_sam3_img.proto
Normal file
@ -0,0 +1,48 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package grpc_sam3_img;
|
||||
|
||||
service TaskService {
|
||||
rpc ProcessTask (TaskRequest) returns (TaskResponse);
|
||||
}
|
||||
|
||||
// 添加健康检查服务
|
||||
service HealthCheck {
|
||||
rpc Check (HealthCheckRequest) returns (HealthCheckResponse);
|
||||
}
|
||||
|
||||
message HealthCheckRequest {
|
||||
string service = 1;
|
||||
}
|
||||
|
||||
message HealthCheckResponse {
|
||||
enum ServingStatus {
|
||||
UNKNOWN = 0;
|
||||
SERVING = 1;
|
||||
NOT_SERVING = 2;
|
||||
SERVICE_UNKNOWN = 3;
|
||||
}
|
||||
ServingStatus status = 1;
|
||||
}
|
||||
|
||||
message TaskRequest {
|
||||
string task_id = 1;
|
||||
string sn = 2;
|
||||
ContentBody content_body = 3;
|
||||
}
|
||||
|
||||
message ContentBody {
|
||||
string img_url = 1;
|
||||
string prompt = 2;
|
||||
float confidence =3;
|
||||
string mqtt_ip = 4;
|
||||
int32 mqtt_port = 5;
|
||||
string mqtt_topic = 6;
|
||||
|
||||
}
|
||||
|
||||
message TaskResponse {
|
||||
string task_id = 1;
|
||||
bool success = 2;
|
||||
string message = 3;
|
||||
}
|
||||
50
grpc_util/grpc_sam3/grpc_sam3_img_pb2.py
Normal file
@ -0,0 +1,50 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# NO CHECKED-IN PROTOBUF GENCODE
|
||||
# source: grpc_sam3_img.proto
|
||||
# Protobuf Python Version: 6.31.1
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import runtime_version as _runtime_version
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
_runtime_version.ValidateProtobufRuntimeVersion(
|
||||
_runtime_version.Domain.PUBLIC,
|
||||
6,
|
||||
31,
|
||||
1,
|
||||
'',
|
||||
'grpc_sam3_img.proto'
|
||||
)
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13grpc_sam3_img.proto\x12\rgrpc_sam3_img\"%\n\x12HealthCheckRequest\x12\x0f\n\x07service\x18\x01 \x01(\t\"\xa8\x01\n\x13HealthCheckResponse\x12@\n\x06status\x18\x01 \x01(\x0e\x32\x30.grpc_sam3_img.HealthCheckResponse.ServingStatus\"O\n\rServingStatus\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x0b\n\x07SERVING\x10\x01\x12\x0f\n\x0bNOT_SERVING\x10\x02\x12\x13\n\x0fSERVICE_UNKNOWN\x10\x03\"\\\n\x0bTaskRequest\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\n\n\x02sn\x18\x02 \x01(\t\x12\x30\n\x0c\x63ontent_body\x18\x03 \x01(\x0b\x32\x1a.grpc_sam3_img.ContentBody\"z\n\x0b\x43ontentBody\x12\x0f\n\x07img_url\x18\x01 \x01(\t\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x12\n\nconfidence\x18\x03 \x01(\x02\x12\x0f\n\x07mqtt_ip\x18\x04 \x01(\t\x12\x11\n\tmqtt_port\x18\x05 \x01(\x05\x12\x12\n\nmqtt_topic\x18\x06 \x01(\t\"A\n\x0cTaskResponse\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x0f\n\x07message\x18\x03 \x01(\t2U\n\x0bTaskService\x12\x46\n\x0bProcessTask\x12\x1a.grpc_sam3_img.TaskRequest\x1a\x1b.grpc_sam3_img.TaskResponse2]\n\x0bHealthCheck\x12N\n\x05\x43heck\x12!.grpc_sam3_img.HealthCheckRequest\x1a\".grpc_sam3_img.HealthCheckResponseb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'grpc_sam3_img_pb2', _globals)
|
||||
if not _descriptor._USE_C_DESCRIPTORS:
|
||||
DESCRIPTOR._loaded_options = None
|
||||
_globals['_HEALTHCHECKREQUEST']._serialized_start=38
|
||||
_globals['_HEALTHCHECKREQUEST']._serialized_end=75
|
||||
_globals['_HEALTHCHECKRESPONSE']._serialized_start=78
|
||||
_globals['_HEALTHCHECKRESPONSE']._serialized_end=246
|
||||
_globals['_HEALTHCHECKRESPONSE_SERVINGSTATUS']._serialized_start=167
|
||||
_globals['_HEALTHCHECKRESPONSE_SERVINGSTATUS']._serialized_end=246
|
||||
_globals['_TASKREQUEST']._serialized_start=248
|
||||
_globals['_TASKREQUEST']._serialized_end=340
|
||||
_globals['_CONTENTBODY']._serialized_start=342
|
||||
_globals['_CONTENTBODY']._serialized_end=464
|
||||
_globals['_TASKRESPONSE']._serialized_start=466
|
||||
_globals['_TASKRESPONSE']._serialized_end=531
|
||||
_globals['_TASKSERVICE']._serialized_start=533
|
||||
_globals['_TASKSERVICE']._serialized_end=618
|
||||
_globals['_HEALTHCHECK']._serialized_start=620
|
||||
_globals['_HEALTHCHECK']._serialized_end=713
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
172
grpc_util/grpc_sam3/grpc_sam3_img_pb2_grpc.py
Normal file
@ -0,0 +1,172 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
import warnings
|
||||
|
||||
import grpc_util.grpc_sam3.grpc_sam3_img_pb2 as grpc__sam3__img__pb2
|
||||
|
||||
GRPC_GENERATED_VERSION = '1.76.0'
|
||||
GRPC_VERSION = grpc.__version__
|
||||
_version_not_supported = False
|
||||
|
||||
try:
|
||||
from grpc._utilities import first_version_is_lower
|
||||
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
|
||||
except ImportError:
|
||||
_version_not_supported = True
|
||||
|
||||
if _version_not_supported:
|
||||
raise RuntimeError(
|
||||
f'The grpc package installed is at version {GRPC_VERSION},'
|
||||
+ ' but the generated code in grpc_sam3_img_pb2_grpc.py depends on'
|
||||
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
|
||||
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
|
||||
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
|
||||
)
|
||||
|
||||
|
||||
class TaskServiceStub(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.ProcessTask = channel.unary_unary(
|
||||
'/grpc_sam3_img.TaskService/ProcessTask',
|
||||
request_serializer=grpc__sam3__img__pb2.TaskRequest.SerializeToString,
|
||||
response_deserializer=grpc__sam3__img__pb2.TaskResponse.FromString,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class TaskServiceServicer(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def ProcessTask(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_TaskServiceServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'ProcessTask': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.ProcessTask,
|
||||
request_deserializer=grpc__sam3__img__pb2.TaskRequest.FromString,
|
||||
response_serializer=grpc__sam3__img__pb2.TaskResponse.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'grpc_sam3_img.TaskService', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
server.add_registered_method_handlers('grpc_sam3_img.TaskService', rpc_method_handlers)
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class TaskService(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
@staticmethod
|
||||
def ProcessTask(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/grpc_sam3_img.TaskService/ProcessTask',
|
||||
grpc__sam3__img__pb2.TaskRequest.SerializeToString,
|
||||
grpc__sam3__img__pb2.TaskResponse.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class HealthCheckStub(object):
|
||||
"""添加健康检查服务
|
||||
"""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.Check = channel.unary_unary(
|
||||
'/grpc_sam3_img.HealthCheck/Check',
|
||||
request_serializer=grpc__sam3__img__pb2.HealthCheckRequest.SerializeToString,
|
||||
response_deserializer=grpc__sam3__img__pb2.HealthCheckResponse.FromString,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class HealthCheckServicer(object):
|
||||
"""添加健康检查服务
|
||||
"""
|
||||
|
||||
def Check(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_HealthCheckServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'Check': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.Check,
|
||||
request_deserializer=grpc__sam3__img__pb2.HealthCheckRequest.FromString,
|
||||
response_serializer=grpc__sam3__img__pb2.HealthCheckResponse.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'grpc_sam3_img.HealthCheck', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
server.add_registered_method_handlers('grpc_sam3_img.HealthCheck', rpc_method_handlers)
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class HealthCheck(object):
|
||||
"""添加健康检查服务
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def Check(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/grpc_sam3_img.HealthCheck/Check',
|
||||
grpc__sam3__img__pb2.HealthCheckRequest.SerializeToString,
|
||||
grpc__sam3__img__pb2.HealthCheckResponse.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
72
grpc_util/grpc_sam3/sam3_grpc_client.py
Normal file
@ -0,0 +1,72 @@
|
||||
import grpc
|
||||
import time
|
||||
|
||||
from grpc_util.grpc_sam3 import grpc_sam3_img_pb2_grpc, grpc_sam3_img_pb2
|
||||
|
||||
|
||||
def check_server_status(channel):
|
||||
try:
|
||||
health_stub = grpc_sam3_img_pb2_grpc.HealthCheckStub(channel)
|
||||
response = health_stub.Check(grpc_sam3_img_pb2.HealthCheckRequest(service="TaskService"))
|
||||
return response.status == grpc_sam3_img_pb2.HealthCheckResponse.ServingStatus.SERVING
|
||||
except grpc.RpcError as e:
|
||||
print(f"Health check failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def check_grpc_request(max_retries=3, delay=5):
|
||||
channel = None
|
||||
retries = 0
|
||||
|
||||
while retries < max_retries:
|
||||
try:
|
||||
# 创建通道
|
||||
channel = grpc.insecure_channel('192.168.110.187:9999')
|
||||
|
||||
# 检查服务器状态
|
||||
if not check_server_status(channel):
|
||||
raise Exception("Server is not healthy")
|
||||
|
||||
stub = grpc_sam3_img_pb2_grpc.TaskServiceStub(channel)
|
||||
|
||||
# 创建请求消息
|
||||
request = grpc_sam3_img_pb2.TaskRequest(
|
||||
task_id="d6118954-a170-4e1c-84bd-ddbd3114b354111",
|
||||
sn="8UUXN6S00A0CK7",
|
||||
content_body=grpc_sam3_img_pb2.ContentBody(
|
||||
img_url="demo/03.png",
|
||||
prompt="cat",
|
||||
confidence=0.5,
|
||||
mqtt_ip="47.108.62.6",
|
||||
mqtt_port=12503,
|
||||
mqtt_topic="thing/product/ai/events"
|
||||
)
|
||||
)
|
||||
|
||||
# 调用远程方法
|
||||
response = stub.ProcessTask(request)
|
||||
print(f"Response: task_id={response.task_id}, success={response.success}, message={response.message}")
|
||||
return True
|
||||
|
||||
except grpc.RpcError as e:
|
||||
retries += 1
|
||||
print(f"RPC error occurred (attempt {retries}/{max_retries}): {e}")
|
||||
if retries < max_retries:
|
||||
print(f"Retrying in {delay} seconds...")
|
||||
time.sleep(delay)
|
||||
except Exception as e:
|
||||
print(f"Error occurred: {e}")
|
||||
retries += 1
|
||||
if retries < max_retries:
|
||||
print(f"Retrying in {delay} seconds...")
|
||||
time.sleep(delay)
|
||||
finally:
|
||||
if channel:
|
||||
channel.close()
|
||||
|
||||
print("All retry attempts failed")
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
check_grpc_request()
|
||||
369
grpc_util/grpc_sam3/sam3_grpc_server.py
Normal file
@ -0,0 +1,369 @@
|
||||
import json
|
||||
from concurrent import futures
|
||||
import grpc
|
||||
import threading
|
||||
import queue
|
||||
import time
|
||||
import logging
|
||||
from typing import Dict, Optional
|
||||
from grpc_util.grpc_sam3 import grpc_sam3_img_pb2_grpc, grpc_sam3_img_pb2
|
||||
from middleware.MQTTService import MQTTService
|
||||
from middleware.minio_util import downFile, upload_file
|
||||
|
||||
import sys
|
||||
|
||||
from middleware.util import get_current_date_and_milliseconds
|
||||
|
||||
print(sys.executable)
|
||||
|
||||
import os
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
import sam3.sam3
|
||||
from PIL import Image
|
||||
from sam3.sam3 import build_sam3_image_model
|
||||
from sam3.sam3 import build_sam3_image_model_0228
|
||||
from sam3.sam3.model.box_ops import box_xywh_to_cxcywh
|
||||
from sam3.sam3.model.sam3_image_processor import Sam3Processor
|
||||
from sam3.sam3.visualization_utils import draw_box_on_image, normalize_bbox, plot_results,plot_results_savepic
|
||||
|
||||
sam3_root = os.path.join(os.path.dirname(sam3.sam3.__file__), "..")
|
||||
import torch
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
|
||||
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskQueue:
|
||||
"""任务队列管理类"""
|
||||
|
||||
def __init__(self, max_size: int = 1000):
|
||||
self.queue = queue.Queue(maxsize=max_size)
|
||||
self.task_status: Dict[str, dict] = {} # 存储任务状态
|
||||
self.lock = threading.Lock()
|
||||
self.stop_event = threading.Event()
|
||||
|
||||
def add_task(self, task_id: str, request_data: dict) -> bool:
|
||||
"""添加任务到队列"""
|
||||
try:
|
||||
task_item = {
|
||||
'task_id': task_id,
|
||||
'data': request_data,
|
||||
'timestamp': time.time(),
|
||||
'status': 'pending'
|
||||
}
|
||||
|
||||
with self.lock:
|
||||
self.task_status[task_id] = task_item
|
||||
|
||||
# 非阻塞方式放入队列
|
||||
self.queue.put(task_item, block=False)
|
||||
logger.info(f"任务 {task_id} 已添加到队列,当前队列大小: {self.queue.qsize()}")
|
||||
return True
|
||||
|
||||
except queue.Full:
|
||||
logger.warning(f"队列已满,任务 {task_id} 被拒绝")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"添加任务失败: {e}")
|
||||
return False
|
||||
|
||||
def get_task(self, timeout: float = 1.0) -> Optional[dict]:
|
||||
"""从队列获取任务"""
|
||||
try:
|
||||
return self.queue.get(timeout=timeout)
|
||||
except queue.Empty:
|
||||
return None
|
||||
|
||||
def update_task_status(self, task_id: str, status: str, result: dict = None):
|
||||
"""更新任务状态"""
|
||||
with self.lock:
|
||||
if task_id in self.task_status:
|
||||
self.task_status[task_id]['status'] = status
|
||||
if result:
|
||||
self.task_status[task_id]['result'] = result
|
||||
self.task_status[task_id]['completed_time'] = time.time()
|
||||
|
||||
def get_task_status(self, task_id: str) -> Optional[dict]:
|
||||
"""获取任务状态"""
|
||||
with self.lock:
|
||||
return self.task_status.get(task_id)
|
||||
|
||||
def cleanup_old_tasks(self, max_age_seconds: int = 3600):
|
||||
"""清理旧任务"""
|
||||
with self.lock:
|
||||
current_time = time.time()
|
||||
to_delete = []
|
||||
for task_id, task in self.task_status.items():
|
||||
if 'completed_time' in task:
|
||||
age = current_time - task['completed_time']
|
||||
if age > max_age_seconds:
|
||||
to_delete.append(task_id)
|
||||
|
||||
for task_id in to_delete:
|
||||
del self.task_status[task_id]
|
||||
logger.info(f"清理旧任务: {task_id}")
|
||||
|
||||
|
||||
class TaskWorker(threading.Thread):
|
||||
"""工作线程,从队列中取任务并处理"""
|
||||
|
||||
def __init__(self, worker_id: int, task_queue: TaskQueue, stop_event: threading.Event):
|
||||
super().__init__(daemon=True)
|
||||
self.worker_id = worker_id
|
||||
self.task_queue = task_queue
|
||||
self.stop_event = stop_event
|
||||
self.processed_count = 0
|
||||
|
||||
def run(self):
|
||||
logger.info(f"工作线程 {self.worker_id} 启动")
|
||||
while not self.stop_event.is_set():
|
||||
try:
|
||||
# 从队列获取任务
|
||||
task = self.task_queue.get_task(timeout=0.5)
|
||||
if not task:
|
||||
continue
|
||||
|
||||
task_id = task['task_id']
|
||||
request_data = task['data']
|
||||
|
||||
logger.info(f"工作线程 {self.worker_id} 开始处理任务: {task_id}")
|
||||
|
||||
# 更新任务状态为处理中
|
||||
self.task_queue.update_task_status(task_id, 'processing')
|
||||
|
||||
# 这里是你的实际处理逻辑
|
||||
result = self.process_task(task_id, request_data)
|
||||
|
||||
# 更新任务状态为完成
|
||||
self.task_queue.update_task_status(
|
||||
task_id,
|
||||
'completed' if result.get('success') else 'failed',
|
||||
result
|
||||
)
|
||||
|
||||
self.processed_count += 1
|
||||
logger.info(f"工作线程 {self.worker_id} 完成任务: {task_id}, 处理总数: {self.processed_count}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工作线程 {self.worker_id} 处理任务失败: {e}")
|
||||
if task:
|
||||
self.task_queue.update_task_status(
|
||||
task['task_id'],
|
||||
'failed',
|
||||
{'error': str(e)}
|
||||
)
|
||||
|
||||
def process_task(self, task_id: str, request_data: dict) -> dict:
|
||||
"""模拟耗时任务处理"""
|
||||
# 这里替换为你的实际处理逻辑
|
||||
# time.sleep(10) # 模拟10秒处理时间
|
||||
task_id=request_data["task_id"]
|
||||
sn=request_data["sn"]
|
||||
img_url=request_data["img_url"]
|
||||
prompt=request_data["prompt"]
|
||||
confidence=request_data["confidence"]
|
||||
mqtt_ip=request_data["mqtt_ip"]
|
||||
mqtt_port=request_data["mqtt_port"]
|
||||
mqtt_topic=request_data["mqtt_topic"]
|
||||
local_image_path=downFile(img_url)
|
||||
|
||||
bpe_path = f"/home/beidou/test0623/sam3/sam3/sam3/assets/bpe_simple_vocab_16e6.txt.gz"
|
||||
config_path = f"/home/beidou/test0623/sam3/sam3/sam3-weight/config.json" # 替换为本地路径
|
||||
checkpoint_path = f"/home/beidou/test0623/sam3/sam3/sam3-weight/sam3.pt" # 替换为本地路径
|
||||
# model = build_sam3_image_model(bpe_path=bpe_path)
|
||||
|
||||
# 2. 构建模型(从本地加载)
|
||||
model = build_sam3_image_model_0228(
|
||||
bpe_path=bpe_path,
|
||||
checkpoint_path=checkpoint_path,
|
||||
config_path=config_path, # 可选
|
||||
load_from_HF=False,
|
||||
device="cuda",
|
||||
eval_mode=True,
|
||||
)
|
||||
formatted_date, milliseconds_timestamp = get_current_date_and_milliseconds()
|
||||
img_name=os.path.basename(local_image_path)
|
||||
dir_name=os.path.dirname(local_image_path)
|
||||
predict_save_path=os.path.join(dir_name,str(milliseconds_timestamp)+img_name)
|
||||
|
||||
# image = Image.open(image_path)
|
||||
image = Image.open(local_image_path).convert("RGB")
|
||||
width, height = image.size
|
||||
processor = Sam3Processor(model, confidence_threshold=0.5)
|
||||
inference_state = processor.set_image(image)
|
||||
|
||||
processor.reset_all_prompts(inference_state)
|
||||
inference_state = processor.set_text_prompt(state=inference_state, prompt="road")
|
||||
img0 = Image.open(local_image_path)
|
||||
plot_results_savepic(img0, inference_state, save_path=predict_save_path)
|
||||
object_name, _=upload_file(predict_save_path,None)
|
||||
|
||||
mqtt = MQTTService(mqtt_ip, port=mqtt_port)
|
||||
message = {
|
||||
'success': True,
|
||||
"task_id":task_id,
|
||||
'object_name': object_name
|
||||
}
|
||||
mqtt.publish_sync(mqtt_topic, json.dumps(message, ensure_ascii=False))
|
||||
# 删除本地文件
|
||||
if os.path.exists(local_image_path):
|
||||
os.remove(local_image_path)
|
||||
if os.path.exists(predict_save_path):
|
||||
os.remove(predict_save_path)
|
||||
# 模拟处理结果
|
||||
return {
|
||||
'success': True,
|
||||
'message': f'任务 {task_id} 处理完成',
|
||||
'data': {'result': 'some_result'}
|
||||
}
|
||||
|
||||
|
||||
class TaskServiceServicer(grpc_sam3_img_pb2_grpc.TaskServiceServicer):
|
||||
def __init__(self, task_queue: TaskQueue, max_workers: int = 1):
|
||||
self.task_queue = task_queue
|
||||
self.max_workers = max_workers
|
||||
self.stop_event = threading.Event()
|
||||
self.workers = []
|
||||
|
||||
# 启动工作线程
|
||||
self.start_workers()
|
||||
|
||||
def start_workers(self):
|
||||
"""启动工作线程池"""
|
||||
for i in range(self.max_workers):
|
||||
worker = TaskWorker(i, self.task_queue, self.stop_event)
|
||||
worker.start()
|
||||
self.workers.append(worker)
|
||||
logger.info(f"启动了 {self.max_workers} 个工作线程")
|
||||
|
||||
def ProcessTask(self, request, context):
|
||||
"""处理任务请求 - 将任务放入队列后立即返回"""
|
||||
try:
|
||||
# 检查队列是否已满
|
||||
if self.task_queue.queue.full():
|
||||
logger.warning(f"队列已满,拒绝任务: {request.task_id}")
|
||||
return grpc_sam3_img_pb2.TaskResponse(
|
||||
task_id=request.task_id,
|
||||
success=False,
|
||||
message="服务器忙,请稍后重试"
|
||||
)
|
||||
|
||||
# 准备任务数据
|
||||
task_data = {
|
||||
'task_id': request.task_id,
|
||||
'sn': request.sn,
|
||||
'img_url': request.content_body.img_url,
|
||||
'prompt': request.content_body.prompt,
|
||||
'confidence': request.content_body.confidence,
|
||||
'mqtt_ip': request.content_body.mqtt_ip,
|
||||
'mqtt_port': request.content_body.mqtt_port,
|
||||
'mqtt_topic': request.content_body.mqtt_topic
|
||||
}
|
||||
|
||||
# 将任务添加到队列
|
||||
if self.task_queue.add_task(request.task_id, task_data):
|
||||
return grpc_sam3_img_pb2.TaskResponse(
|
||||
task_id=request.task_id,
|
||||
success=True,
|
||||
message=f"任务已接收,正在排队处理。当前队列位置: {self.task_queue.queue.qsize()}"
|
||||
)
|
||||
else:
|
||||
return grpc_sam3_img_pb2.TaskResponse(
|
||||
task_id=request.task_id,
|
||||
success=False,
|
||||
message="任务提交失败"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理任务请求失败: {e}")
|
||||
return grpc_sam3_img_pb2.TaskResponse(
|
||||
task_id=request.task_id,
|
||||
success=False,
|
||||
message=f"服务器内部错误: {str(e)}"
|
||||
)
|
||||
|
||||
def stop(self):
|
||||
"""停止工作线程"""
|
||||
self.stop_event.set()
|
||||
for worker in self.workers:
|
||||
worker.join(timeout=2)
|
||||
logger.info("所有工作线程已停止")
|
||||
|
||||
|
||||
class HealthCheckServicer(grpc_sam3_img_pb2_grpc.HealthCheckServicer):
|
||||
def __init__(self, task_queue: TaskQueue):
|
||||
self.task_queue = task_queue
|
||||
|
||||
def Check(self, request, context):
|
||||
"""健康检查,包含队列状态"""
|
||||
queue_size = self.task_queue.queue.qsize()
|
||||
|
||||
if queue_size > 50: # 队列过长
|
||||
return grpc_sam3_img_pb2.HealthCheckResponse(
|
||||
status=grpc_sam3_img_pb2.HealthCheckResponse.ServingStatus.NOT_SERVING
|
||||
)
|
||||
else:
|
||||
return grpc_sam3_img_pb2.HealthCheckResponse(
|
||||
status=grpc_sam3_img_pb2.HealthCheckResponse.ServingStatus.SERVING
|
||||
)
|
||||
|
||||
|
||||
def serve():
|
||||
# 创建任务队列
|
||||
task_queue = TaskQueue(max_size=20)
|
||||
|
||||
# 创建服务实例
|
||||
task_service = TaskServiceServicer(task_queue, max_workers=1) # 10个工作线程
|
||||
health_service = HealthCheckServicer(task_queue)
|
||||
|
||||
# 创建gRPC服务器
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=2)) # 处理RPC请求的线程
|
||||
|
||||
# 注册服务
|
||||
grpc_sam3_img_pb2_grpc.add_TaskServiceServicer_to_server(task_service, server)
|
||||
grpc_sam3_img_pb2_grpc.add_HealthCheckServicer_to_server(health_service, server)
|
||||
|
||||
# 启动服务器
|
||||
server.add_insecure_port('[::]:50051')
|
||||
server.start()
|
||||
|
||||
logger.info("服务器已启动,监听端口: 50051")
|
||||
logger.info(f"工作线程数: 1, 队列最大容量: 20")
|
||||
|
||||
# 定时清理旧任务
|
||||
def cleanup_loop():
|
||||
while True:
|
||||
time.sleep(300) # 每5分钟清理一次
|
||||
task_queue.cleanup_old_tasks()
|
||||
|
||||
cleanup_thread = threading.Thread(target=cleanup_loop, daemon=True)
|
||||
cleanup_thread.start()
|
||||
|
||||
# 优雅关闭处理
|
||||
def shutdown():
|
||||
logger.info("收到关闭信号,正在停止服务器...")
|
||||
task_service.stop()
|
||||
server.stop(5) # 5秒宽限期
|
||||
logger.info("服务器已停止")
|
||||
|
||||
import signal
|
||||
signal.signal(signal.SIGINT, lambda s, f: shutdown())
|
||||
signal.signal(signal.SIGTERM, lambda s, f: shutdown())
|
||||
|
||||
# 保持服务器运行
|
||||
try:
|
||||
server.wait_for_termination()
|
||||
except KeyboardInterrupt:
|
||||
shutdown()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
serve()
|
||||
284
md/grpc.md
Normal file
@ -0,0 +1,284 @@
|
||||
# 前言
|
||||
|
||||
sanic 和 服务之间基于grpc 解绑,一个服务一个grpc
|
||||
|
||||
|
||||
|
||||
可以参考接口,grpc 最好留有健康校验
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# grpc demo
|
||||
|
||||
为了增强 gRPC 通讯的可靠性,我们可以添加以下功能:
|
||||
|
||||
1. 服务器健康检查服务
|
||||
2. 客户端连接前检查服务器状态
|
||||
3. 通讯过程中的错误处理和重试机制
|
||||
|
||||
## 1. 修改 protobuf 定义 (task.proto)
|
||||
|
||||
首先添加健康检查服务定义:
|
||||
|
||||
```protobuf
|
||||
syntax = "proto3";
|
||||
|
||||
package task;
|
||||
|
||||
service TaskService {
|
||||
rpc ProcessTask (TaskRequest) returns (TaskResponse);
|
||||
}
|
||||
|
||||
// 添加健康检查服务
|
||||
service HealthCheck {
|
||||
rpc Check (HealthCheckRequest) returns (HealthCheckResponse);
|
||||
}
|
||||
|
||||
message HealthCheckRequest {
|
||||
string service = 1;
|
||||
}
|
||||
|
||||
message HealthCheckResponse {
|
||||
enum ServingStatus {
|
||||
UNKNOWN = 0;
|
||||
SERVING = 1;
|
||||
NOT_SERVING = 2;
|
||||
SERVICE_UNKNOWN = 3;
|
||||
}
|
||||
ServingStatus status = 1;
|
||||
}
|
||||
|
||||
message TaskRequest {
|
||||
string task_id = 1;
|
||||
string sn = 2;
|
||||
ContentBody content_body = 3;
|
||||
}
|
||||
|
||||
message ContentBody {
|
||||
string org_code = 1;
|
||||
repeated int32 func_id = 2;
|
||||
string source_url = 3;
|
||||
string push_url = 4;
|
||||
float confidence = 5;
|
||||
repeated ParaList para_list = 6;
|
||||
Invade invade = 7;
|
||||
}
|
||||
|
||||
message ParaList {
|
||||
int32 func_id = 1;
|
||||
bool para_invade_enable = 2;
|
||||
}
|
||||
|
||||
message Invade {
|
||||
string invade_file = 1;
|
||||
string camera_para_url = 2;
|
||||
}
|
||||
|
||||
message TaskResponse {
|
||||
string task_id = 1;
|
||||
bool success = 2;
|
||||
string message = 3;
|
||||
}
|
||||
```
|
||||
|
||||
## 2. 重新生成 Python 代码
|
||||
|
||||
```bash
|
||||
python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. task.proto
|
||||
protoc -I=./proto --go_out=./gen proto/subfolder/task.proto
|
||||
```
|
||||
|
||||
## 3. 增强服务端实现 (server.py)
|
||||
|
||||
```python
|
||||
from concurrent import futures
|
||||
import grpc
|
||||
import time
|
||||
import task_pb2
|
||||
import task_pb2_grpc
|
||||
|
||||
class TaskServiceServicer(task_pb2_grpc.TaskServiceServicer):
|
||||
def ProcessTask(self, request, context):
|
||||
print(f"Received task_id: {request.task_id}")
|
||||
# ... 原有处理逻辑
|
||||
return task_pb2.TaskResponse(
|
||||
task_id=request.task_id,
|
||||
success=True,
|
||||
message="Task processed successfully"
|
||||
)
|
||||
|
||||
class HealthCheckServicer(task_pb2_grpc.HealthCheckServicer):
|
||||
def Check(self, request, context):
|
||||
# 简单实现:总是返回SERVING状态
|
||||
# 实际应用中可以根据服务状态返回不同值
|
||||
return task_pb2.HealthCheckResponse(
|
||||
status=task_pb2.HealthCheckResponse.ServingStatus.SERVING
|
||||
)
|
||||
|
||||
def serve():
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
|
||||
# 添加服务实现
|
||||
task_pb2_grpc.add_TaskServiceServicer_to_server(TaskServiceServicer(), server)
|
||||
task_pb2_grpc.add_HealthCheckServicer_to_server(HealthCheckServicer(), server)
|
||||
|
||||
server.add_insecure_port('[::]:50051')
|
||||
server.start()
|
||||
print("Server started, listening on port 50051...")
|
||||
|
||||
try:
|
||||
while True:
|
||||
time.sleep(86400) # 保持运行
|
||||
except KeyboardInterrupt:
|
||||
server.stop(0)
|
||||
|
||||
if __name__ == '__main__':
|
||||
serve()
|
||||
```
|
||||
|
||||
## 4. 增强客户端实现 (client.py)
|
||||
|
||||
```python
|
||||
import grpc
|
||||
import time
|
||||
import task_pb2
|
||||
import task_pb2_grpc
|
||||
|
||||
def check_server_status(channel):
|
||||
try:
|
||||
health_stub = task_pb2_grpc.HealthCheckStub(channel)
|
||||
response = health_stub.Check(task_pb2.HealthCheckRequest(service="TaskService"))
|
||||
return response.status == task_pb2.HealthCheckResponse.ServingStatus.SERVING
|
||||
except grpc.RpcError as e:
|
||||
print(f"Health check failed: {e}")
|
||||
return False
|
||||
|
||||
def run_with_retry(max_retries=3, delay=5):
|
||||
channel = None
|
||||
retries = 0
|
||||
|
||||
while retries < max_retries:
|
||||
try:
|
||||
# 创建通道
|
||||
channel = grpc.insecure_channel('localhost:50051')
|
||||
|
||||
# 检查服务器状态
|
||||
if not check_server_status(channel):
|
||||
raise Exception("Server is not healthy")
|
||||
|
||||
stub = task_pb2_grpc.TaskServiceStub(channel)
|
||||
|
||||
# 创建请求消息
|
||||
request = task_pb2.TaskRequest(
|
||||
task_id="d6118954-a170-4e1c-84bd-ddbd3114b354",
|
||||
sn="8UUXN6S00A0CK7",
|
||||
content_body=task_pb2.ContentBody(
|
||||
org_code="HMZHB",
|
||||
func_id=[101204],
|
||||
source_url="rtmp://222.212.85.86:1935/live/1581F8HGX253S00A05L8",
|
||||
push_url="",
|
||||
confidence=0.4,
|
||||
para_list=[
|
||||
task_pb2.ParaList(
|
||||
func_id=101204,
|
||||
para_invade_enable=True
|
||||
)
|
||||
],
|
||||
invade=task_pb2.Invade(
|
||||
invade_file="meta_data/高压线-0826.geojson",
|
||||
camera_para_url="meta_data/camera_para/hami_camera_para .txt"
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# 调用远程方法
|
||||
response = stub.ProcessTask(request)
|
||||
print(f"Response: task_id={response.task_id}, success={response.success}, message={response.message}")
|
||||
return True
|
||||
|
||||
except grpc.RpcError as e:
|
||||
retries += 1
|
||||
print(f"RPC error occurred (attempt {retries}/{max_retries}): {e}")
|
||||
if retries < max_retries:
|
||||
print(f"Retrying in {delay} seconds...")
|
||||
time.sleep(delay)
|
||||
except Exception as e:
|
||||
print(f"Error occurred: {e}")
|
||||
retries += 1
|
||||
if retries < max_retries:
|
||||
print(f"Retrying in {delay} seconds...")
|
||||
time.sleep(delay)
|
||||
finally:
|
||||
if channel:
|
||||
channel.close()
|
||||
|
||||
print("All retry attempts failed")
|
||||
return False
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_with_retry()
|
||||
```
|
||||
|
||||
## 5. 增强功能说明
|
||||
|
||||
1. 健康检查服务
|
||||
- 添加了标准的 gRPC 健康检查服务
|
||||
- 客户端可以在执行主要操作前检查服务状态
|
||||
2. 错误处理和重试机制
|
||||
- 客户端现在会捕获 `grpc.RpcError` 和其他异常
|
||||
- 实现了最大重试次数和重试间隔
|
||||
- 每次重试前都会检查服务器状态
|
||||
3. 资源管理
|
||||
- 确保在所有情况下都正确关闭通道
|
||||
- 使用上下文管理器处理通道生命周期
|
||||
4. 状态反馈
|
||||
- 提供更详细的错误信息
|
||||
- 记录重试尝试
|
||||
|
||||
## 6. 运行说明
|
||||
|
||||
1. 安装依赖:
|
||||
|
||||
```bash
|
||||
pip install grpcio grpcio-tools
|
||||
```
|
||||
|
||||
2. 生成 protobuf 代码:
|
||||
|
||||
```bash
|
||||
|
||||
# 确保在项目根目录下执行
|
||||
# 编译 proto 文件,输出到 a/b/ 目录
|
||||
python -m grpc_tools.protoc \
|
||||
-I. \ # 指定 proto 文件的根目录(当前目录 ".")
|
||||
--python_out=. \ # 生成 _pb2.py 到当前目录(或指定子目录)
|
||||
--grpc_python_out=. \ # 生成 _pb2_grpc.py 到当前目录(或指定子目录)
|
||||
a/b/task.proto # proto 文件路径(相对于 -I 指定的根目录)
|
||||
```
|
||||
|
||||
3. 启动服务器:
|
||||
|
||||
```bash
|
||||
python server.py
|
||||
```
|
||||
|
||||
4. 运行客户端:
|
||||
|
||||
```bash
|
||||
python client.py
|
||||
```
|
||||
|
||||
## 7. 测试场景
|
||||
|
||||
1. 服务器未运行
|
||||
- 客户端会检测到连接失败并重试
|
||||
- 最终显示所有重试失败
|
||||
2. 服务器运行但健康检查失败
|
||||
- 可以修改 `HealthCheckServicer` 返回 `NOT_SERVING` 状态进行测试
|
||||
- 客户端会拒绝执行主要操作
|
||||
3. 网络中断
|
||||
- 客户端会捕获异常并尝试重试
|
||||
|
||||
这个增强版本提供了更健壮的 gRPC 通讯机制,适合生产环境使用。
|
||||
350
md/接口.md
Normal file
@ -0,0 +1,350 @@
|
||||
# 算法与后台解耦规则
|
||||
|
||||
# 1、方法
|
||||
|
||||
postgres 的ai_model_list 表,id字段声明为6位长度数字
|
||||
|
||||
1、第1位表示算法类别,1xxxxx 表明为目标识别、2xxxxx标明为语义分割、3xxxxx表示变化监测
|
||||
|
||||
2、最后两位表示二次计算,100001 表示做目标识别、100002表示做目标识别,且做人员计数
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# 接口名:视频流识别
|
||||
|
||||
接收前端的视频流、模型、识别类型,算法做计算,并且将计算结果存储到minio,消息通过mqtt发送
|
||||
|
||||
## 1、请求
|
||||
|
||||
接口 /ai/stream/back_detect
|
||||
|
||||
方法 post
|
||||
|
||||
headers X-API-Token:5e8899fe-dc74-4280-8169-2f4d185f3afa
|
||||
|
||||
body
|
||||
|
||||
```
|
||||
{
|
||||
"task_id": "1234567890", #任务id
|
||||
"sn":"", #无人机sn
|
||||
"content_body": {
|
||||
"source_url": "rtmp://192.168.0.142:1935/live/123456", #无人机视频流url
|
||||
"confidence":0.4, #置信度
|
||||
"model_func_id":[100001,100002] #方法id
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 2、响应
|
||||
|
||||
算法的响应分为两个部分
|
||||
|
||||
1、rest响应,表明收到请求
|
||||
|
||||
2、mqtt消息,持续输出计算结果
|
||||
|
||||
### 1、rest
|
||||
|
||||
```
|
||||
{
|
||||
"status": "success",
|
||||
"task_id": "1234567890",
|
||||
"message": "Detection started successfully"
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
|
||||
### 2、mqtt
|
||||
|
||||
ip 112.44.103.230 端口 1883
|
||||
|
||||
topic thing/product/ai/events
|
||||
|
||||
```
|
||||
{
|
||||
"task_id": "1234567890",
|
||||
"minio": {
|
||||
"minio_path": "ai_result/20250702/1751425303860-output-1751425303800959985.jpg",
|
||||
"file_type": "pic"
|
||||
},
|
||||
"box_detail": {
|
||||
"result_100001": {
|
||||
"func_id_10001": 100001,
|
||||
"type_name": "行人",
|
||||
"cls_count": 1,
|
||||
"box_count": [
|
||||
[
|
||||
{
|
||||
"track_id": 22099,
|
||||
"confidence": 0.34013107419013977,
|
||||
"cls_id": 0,
|
||||
"type_name": "行人",
|
||||
"box": [
|
||||
15.935794830322266,
|
||||
694.75390625,
|
||||
33.22901916503906,
|
||||
713.1658935546875
|
||||
]
|
||||
}
|
||||
]
|
||||
]
|
||||
}
|
||||
},
|
||||
"uav_location": {
|
||||
"data": {
|
||||
"attitude_head": 60,
|
||||
"gimbal_pitch": 60,
|
||||
"gimbal_roll": 60,
|
||||
"gimbal_yaw": 60,
|
||||
"height": 10,
|
||||
"latitude": 10,
|
||||
"longitude": 10,
|
||||
"speed_x": 10,
|
||||
"speed_y": 10,
|
||||
"speed_z": 10
|
||||
},
|
||||
"timestamp": 1751425301213249700
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
# 接口名:图片识别
|
||||
|
||||
接收前端的图片,算法做计算,并且将计算结果存储到minio,消息通过mqtt发送
|
||||
|
||||
## 1、请求
|
||||
|
||||
接口 /ai/pic/back_detect_pic
|
||||
|
||||
方法 post
|
||||
|
||||
headers X-API-Token:5e8899fe-dc74-4280-8169-2f4d185f3afa
|
||||
|
||||
body
|
||||
|
||||
```
|
||||
{
|
||||
"task_id": "0001111",
|
||||
"content_body": {
|
||||
"s3_id":1, #根据id适配,minio相关存储参数
|
||||
"s3_url":[
|
||||
"test/frame_0000.jpg","test/frame_0001.jpg","test/frame_0002.jpg" # minio文件地址
|
||||
],
|
||||
"confidence":0.4, #算法置信度
|
||||
"model_func_id":[10001,10002] #方法id
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 2、响应
|
||||
|
||||
算法的响应分为两个部分
|
||||
|
||||
1、rest响应,表明收到请求
|
||||
|
||||
2、mqtt消息,持续输出计算结果
|
||||
|
||||
### 1、rest
|
||||
|
||||
```
|
||||
{
|
||||
"status": "success",
|
||||
"task_id": "0001111",
|
||||
"message": "Detection started successfully"
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
|
||||
### 2、mqtt
|
||||
|
||||
ip 112.44.103.230 端口 1883
|
||||
|
||||
topic thing/product/ai/events
|
||||
|
||||
```
|
||||
{
|
||||
"task_id": "0001111", #任务id
|
||||
"minio": {
|
||||
"minio_path": "ai_result/20250627/1751006943659-frame_0001.jpg", # minio 存储路径
|
||||
"file_type": "pic"
|
||||
},
|
||||
"box_detail": {
|
||||
"model_id": 10001,
|
||||
"box_count": [
|
||||
{
|
||||
"type": 3, # 类型
|
||||
"type_name": "车辆", #类型名称
|
||||
"count": 71 #数量
|
||||
},
|
||||
{
|
||||
"type": 0,
|
||||
"type_name": "车辆",
|
||||
"count": 7
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
|
||||
# 接口名:地类分割
|
||||
|
||||
|
||||
|
||||
接收前端的图片,算法做计算,并且将计算结果存储到minio,消息通过mqtt发送
|
||||
|
||||
## 1、请求
|
||||
|
||||
接口 /ai/pic/back_detect_pic
|
||||
|
||||
方法 post
|
||||
|
||||
headers X-API-Token:5e8899fe-dc74-4280-8169-2f4d185f3afa
|
||||
|
||||
body
|
||||
|
||||
```
|
||||
{
|
||||
"task_id": "7a5c83e0-fe0d-47bf-a8e1-9bd663508783",
|
||||
"content_body": {
|
||||
"s3_id":1,#根据id适配,minio相关存储参数
|
||||
"s3_url":[
|
||||
"test/patch_0011.png", # minio文件地址
|
||||
"test/patch_0012.png"
|
||||
],
|
||||
"model_func_id":[20000,20001] #方法id
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
```
|
||||
|
||||
## 2、响应
|
||||
|
||||
算法的响应分为两个部分
|
||||
|
||||
1、rest响应,表明收到请求
|
||||
|
||||
2、mqtt消息,持续输出计算结果
|
||||
|
||||
### 1、rest
|
||||
|
||||
```
|
||||
{
|
||||
"status": "success",
|
||||
"task_id": "0001111",
|
||||
"message": "Detection started successfully"
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
|
||||
### 2、mqtt
|
||||
|
||||
ip 112.44.103.230 端口 1883
|
||||
|
||||
topic thing/product/ai/events
|
||||
|
||||
```
|
||||
{
|
||||
"task_id": "7a5c83e0-fe0d-47bf-a8e1-9bd663508783",
|
||||
"minio": [
|
||||
{
|
||||
"minio_path_before": "ai_result/20250710/1752128232469-patch_0011.png", # 需要分割的图片
|
||||
"minio_path_after": "ai_result/20250710/1752128234222-patch_0011.png", #分割之后的图片
|
||||
"minio_path_boundary": "ai_result/20250710/1752128234264-patch_0011.pngfinal_vis.png", # 分割的边界图片
|
||||
"minio_path_json": "ai_result/20250710/1752128234326-patch_0011.pnginstance_results.json", #分割生成的json文件
|
||||
"file_type": "pic"
|
||||
},
|
||||
{
|
||||
"minio_path_before": "ai_result/20250710/1752128240382-patch_0012.png",
|
||||
"minio_path_after": "ai_result/20250710/1752128241553-patch_0012.png",
|
||||
"minio_path_boundary": "ai_result/20250710/1752128241587-patch_0012.pngfinal_vis.png",
|
||||
"minio_path_json": "ai_result/20250710/1752128241631-patch_0012.pnginstance_results.json",
|
||||
"file_type": "pic"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# 接口名:地类变化监测
|
||||
|
||||
|
||||
|
||||
接收前端的图片,对一期、二期的图像做变化监测,并且将计算结果存储到minio,消息通过mqtt发送
|
||||
|
||||
## 1、请求
|
||||
|
||||
接口 /ai/pic/back_detect_pic
|
||||
|
||||
方法 post
|
||||
|
||||
headers X-API-Token:5e8899fe-dc74-4280-8169-2f4d185f3afa
|
||||
|
||||
body
|
||||
|
||||
```
|
||||
{
|
||||
"task_id": "9fa19ec3-d982-4897-af6c-2c78f786c760",
|
||||
"content_body": {
|
||||
"s3_id":1,
|
||||
"s3_url":{
|
||||
"early":"/test/1-00205.png", # 一期图像minio文件地址
|
||||
"later":"/test/2-00205.png" # 二期图像minio文件地址
|
||||
},
|
||||
"model_func_id":[30000,30001]
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
## 2、响应
|
||||
|
||||
算法的响应分为两个部分
|
||||
|
||||
1、rest响应,表明收到请求
|
||||
|
||||
2、mqtt消息,持续输出计算结果
|
||||
|
||||
### 1、rest
|
||||
|
||||
```
|
||||
{
|
||||
"status": "success",
|
||||
"task_id": "0001111",
|
||||
"message": "Detection started successfully"
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
|
||||
### 2、mqtt
|
||||
|
||||
ip 112.44.103.230 端口 1883
|
||||
|
||||
topic thing/product/ai/events
|
||||
|
||||
```
|
||||
{
|
||||
"task_id": "9fa19ec3-d982-4897-af6c-2c78f786c760",
|
||||
"minio": {
|
||||
"minio_path_1": "ai_result/20250627/1751007686483-1-00205.png", # 一期影像,minio地址
|
||||
"minio_path_2": "ai_result/20250627/1751007686541-2-00205.png", # 二期影像,minio地址
|
||||
"minio_path_result": "ai_result/20250627/1751007686.458642-result-2-00205.png", #识别结果,minio地址
|
||||
"file_type": "pic"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
@ -51,6 +51,7 @@ class MQTTService:
|
||||
self._message_task = None
|
||||
self._connection_lock = asyncio.Lock()
|
||||
self.os_type = sys.platform.lower()
|
||||
self._loop = None # 保存事件循环
|
||||
|
||||
async def connect(self):
|
||||
async with self._connection_lock:
|
||||
@ -122,6 +123,25 @@ class MQTTService:
|
||||
await self.reconnect()
|
||||
await self.client.publish(topic, payload, qos=qos, retain=retain)
|
||||
|
||||
def publish_sync(self, topic, payload, qos=0, retain=False):
|
||||
"""同步发布消息(适用于同步代码调用)"""
|
||||
if self._loop is None or self._loop.is_closed():
|
||||
# 如果没有事件循环,创建一个新的
|
||||
self._loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._loop)
|
||||
|
||||
# 在事件循环中运行异步publish
|
||||
if not self._loop.is_running():
|
||||
return self._loop.run_until_complete(
|
||||
self.publish(topic, payload, qos=qos, retain=retain)
|
||||
)
|
||||
else:
|
||||
# 如果事件循环已经在运行,创建一个任务
|
||||
asyncio.ensure_future(
|
||||
self.publish(topic, payload, qos=qos, retain=retain)
|
||||
)
|
||||
return None
|
||||
|
||||
async def subscribe(self, topic, callback=None, qos=0):
|
||||
if not self.is_connected:
|
||||
await self.connect()
|
||||
|
||||
@ -1,7 +1,11 @@
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
import asyncio
|
||||
from typing import Union, Optional
|
||||
from typing import Any
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -25,9 +29,23 @@ class OSDInfo:
|
||||
wind_speed: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class OSDInfo_v1:
|
||||
attitude_head: float
|
||||
latitude: float
|
||||
longitude: float
|
||||
height: float
|
||||
speed_x: float
|
||||
speed_y: float
|
||||
speed_z: float
|
||||
gimbal_pitch: float
|
||||
gimbal_roll: float
|
||||
gimbal_yaw: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class OSDMessage:
|
||||
data: OSDInfo
|
||||
data: Any # 可以是两种类型之一
|
||||
method: str
|
||||
seq: int
|
||||
timestamp: int
|
||||
@ -36,19 +54,47 @@ class OSDMessage:
|
||||
def parse_osd_message(json_str: Optional[str]) -> Optional[OSDMessage]:
|
||||
if not json_str:
|
||||
return None
|
||||
|
||||
data = json_str
|
||||
try:
|
||||
data=json_str
|
||||
osd_info = OSDInfo(**data["data"])
|
||||
data_seq=data["seq"]
|
||||
except (TypeError, KeyError) as e:
|
||||
# 如果OSDInfo格式失败,尝试使用OSDInfo_v1格式
|
||||
try:
|
||||
osd_info = OSDInfo_v1(**data["data"])
|
||||
data_seq = 0
|
||||
except Exception as e2:
|
||||
print(f"Error parsing OSD message with both formats: {e2}")
|
||||
return None
|
||||
return OSDMessage(
|
||||
data=osd_info,
|
||||
method=data["method"],
|
||||
seq=data["seq"],
|
||||
seq=data_seq,
|
||||
timestamp=data["timestamp"]
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error parsing OSD message: {e}")
|
||||
return None
|
||||
# try:
|
||||
# data=json_str
|
||||
# # osd_info = OSDInfo(**data["data"])
|
||||
# # if osd_info is None:
|
||||
# # osd_info = OSDInfo_v1(**data["data"]) #适配东西湖区一代飞机的格式
|
||||
# try:
|
||||
# osd_info = OSDInfo(**data["data"])
|
||||
# except (TypeError, KeyError) as e:
|
||||
# # 如果OSDInfo格式失败,尝试使用OSDInfo_v1格式
|
||||
# try:
|
||||
# osd_info = OSDInfo_v1(**data["data"])
|
||||
# except Exception as e2:
|
||||
# print(f"Error parsing OSD message with both formats: {e2}")
|
||||
# return None
|
||||
# return OSDMessage(
|
||||
# data=osd_info,
|
||||
# method=data["method"],
|
||||
# seq=data["seq"],
|
||||
# timestamp=data["timestamp"]
|
||||
# )
|
||||
# except Exception as e:
|
||||
# print(f"Error parsing OSD message: {e}")
|
||||
# return None
|
||||
|
||||
|
||||
async def main():
|
||||
|
||||
@ -46,6 +46,7 @@ class ModelData:
|
||||
so_path: str
|
||||
repeat_dis: float
|
||||
repeat_time: float
|
||||
high_count_warn: float
|
||||
func_description: Optional[str]
|
||||
filter_indices: List[int]
|
||||
class_indices: List[int]
|
||||
@ -250,6 +251,7 @@ class ModelConfigDAO:
|
||||
aml.py_func,
|
||||
aml.repeat_dis,
|
||||
aml.repeat_time,
|
||||
aml.high_count_warn,
|
||||
am.scope,
|
||||
am.yolo_version,
|
||||
am.PATH,
|
||||
@ -572,6 +574,7 @@ WHERE
|
||||
filter_indices=filter_indices,
|
||||
repeat_dis=repeat_dis,
|
||||
repeat_time=row.get('repeat_time'),
|
||||
high_count_warn=row.get('high_count_warn'),
|
||||
class_indices=row['cls_index'],
|
||||
conf=conf,
|
||||
classes=classes,
|
||||
|
||||
645
multi_back_detect/multi_back_detect_api.py
Normal file
@ -0,0 +1,645 @@
|
||||
import json
|
||||
import time
|
||||
from sanic import Blueprint
|
||||
from sanic.response import json as json_response
|
||||
from sanic.exceptions import Unauthorized, SanicException
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import List, Dict, Any
|
||||
import logging
|
||||
import asyncio
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from sympy import false
|
||||
|
||||
|
||||
try :
|
||||
from middleware.TaskManager import task_manager
|
||||
from middleware.query_model import ModelConfigDAO
|
||||
from middleware.query_postgress import batch_query_model_func_id
|
||||
from middleware.read_yolo_config import read_local_func_config
|
||||
from yolo.cv_multi_model_back_video import start_rtmp_processing
|
||||
from cv_video import startAIVideo, stopAIVideo
|
||||
from cv_back_video import startBackAIVideo
|
||||
except Exception as e:
|
||||
import sys
|
||||
from pathlib import Path
|
||||
# 获取项目根目录(假设脚本在根目录的子目录中)
|
||||
ROOT_DIR = Path(__file__).parent.parent # 根据实际层级调整
|
||||
sys.path.append(str(ROOT_DIR))
|
||||
from middleware.TaskManager import task_manager
|
||||
from middleware.query_model import ModelConfigDAO
|
||||
from middleware.query_postgress import batch_query_model_func_id
|
||||
from middleware.read_yolo_config import read_local_func_config
|
||||
from yolo.cv_multi_model_back_video import start_rtmp_processing
|
||||
from cv_video import startAIVideo, stopAIVideo
|
||||
from cv_back_video import startBackAIVideo
|
||||
|
||||
|
||||
# 配置类
|
||||
class Config:
|
||||
VALID_TOKEN = "5e8899fe-dc74-4280-8169-2f4d185f3afa"
|
||||
MAX_ACTIVE_TASKS = 10
|
||||
DEFAULT_CONFIDENCE = 0.5
|
||||
RESTART_DELAY = 2 # 服务尝试自动恢复前的延迟(秒)
|
||||
|
||||
#正式数据库
|
||||
DB_CONFIG = {
|
||||
"dbname": "smart_dev",
|
||||
"user": "postgres",
|
||||
"password": "StrongPassword@123",
|
||||
"host": "222.212.85.86",
|
||||
"port": "5061"
|
||||
}
|
||||
|
||||
# 服务状态标志
|
||||
service_status = {"is_healthy": True, "last_error": None, "error_time": None}
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
multi_back_detect_bp = Blueprint("multi_back_detect", url_prefix="")
|
||||
|
||||
@dataclass
|
||||
class BaseContentBody:
|
||||
pass # 公共字段可以放在这里
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContentBodyFormat_VideoMultiBackDetect(BaseContentBody):
|
||||
# mqtt_pub_id: int
|
||||
org_code: str
|
||||
# mqtt_ip: str
|
||||
# mqtt_port: int
|
||||
# mqtt_topic: str
|
||||
minio_file_path: str
|
||||
push_url: str # 临时测试用
|
||||
confidence: float
|
||||
para_list: list
|
||||
# invade_file: str
|
||||
invade: list
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContentBodyFormat_Sam3Pic(BaseContentBody):
|
||||
img_url: str
|
||||
prompt: str
|
||||
confidence: int
|
||||
mqtt_ip: str
|
||||
mqtt_port: int
|
||||
mqtt_topic: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContentBodyFormat_MultiBackDetect(BaseContentBody):
|
||||
# mqtt_pub_id: int
|
||||
# mqtt_sub_id: int
|
||||
# mqtt_pub_ip: str
|
||||
# mqtt_pub_port: int
|
||||
# mqtt_pub_topic: str
|
||||
# mqtt_sub_ip: str
|
||||
# mqtt_sub_port: int
|
||||
# mqtt_sub_topic: str
|
||||
org_code: str
|
||||
func_id: list
|
||||
source_url: str
|
||||
push_url: str # 临时测试用
|
||||
confidence: float
|
||||
para_list: list
|
||||
# invade_file: str
|
||||
invade: list
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContentBodyFormat_BackDetect(BaseContentBody):
|
||||
mqtt_ip: str
|
||||
mqtt_port: int
|
||||
mqtt_topic: str
|
||||
source_url: str
|
||||
push_url: str # 临时测试用
|
||||
confidence: float
|
||||
func_id: List[int]
|
||||
para: {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContentBodyFormat_BackDetectPic(BaseContentBody):
|
||||
s3_id: int
|
||||
s3_url: list[str]
|
||||
org_code: str
|
||||
# mqtt_pub_id: int
|
||||
confidence: float
|
||||
func_id: list[int]
|
||||
para: {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class EarlyLaterUrls:
|
||||
early: str
|
||||
later: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContentBodyFormat_Detection(BaseContentBody):
|
||||
s3_id: int
|
||||
s3_url: EarlyLaterUrls
|
||||
func_id: list[int]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContentBodyFormat_Segementation(BaseContentBody):
|
||||
s3_id: int
|
||||
s3_url: list[str]
|
||||
func_id: list[int]
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestJson:
|
||||
task_id: str
|
||||
sn: str
|
||||
content_body: BaseContentBody
|
||||
|
||||
def validate(self) -> None:
|
||||
"""验证请求参数"""
|
||||
if not self.task_id:
|
||||
raise ValueError("task_id is required")
|
||||
|
||||
if isinstance(self.content_body, ContentBodyFormat_VideoMultiBackDetect):
|
||||
if not self.content_body.minio_file_path:
|
||||
raise ValueError("minio_file_path is required for ContentBodyFormat_VideoMultiBackDetect")
|
||||
if self.content_body.confidence is not None and not 0 < self.content_body.confidence < 1:
|
||||
raise ValueError("Confidence must be between 0 and 1 for ContentBodyFormat_BackDetect")
|
||||
elif isinstance(self.content_body, ContentBodyFormat_MultiBackDetect):
|
||||
if not self.content_body.para_list:
|
||||
raise ValueError("para_list is required for ContentBodyFormat_MultiBackDetect")
|
||||
if self.content_body.confidence is not None and not 0 < self.content_body.confidence < 1:
|
||||
raise ValueError("Confidence must be between 0 and 1 for ContentBodyFormat_BackDetect")
|
||||
|
||||
elif isinstance(self.content_body, ContentBodyFormat_BackDetect):
|
||||
if not self.content_body.source_url:
|
||||
raise ValueError("source_url is required for ContentBodyFormat_BackDetect")
|
||||
if self.content_body.confidence is not None and not 0 < self.content_body.confidence < 1:
|
||||
raise ValueError("Confidence must be between 0 and 1 for ContentBodyFormat_BackDetect")
|
||||
|
||||
elif isinstance(self.content_body, ContentBodyFormat_BackDetectPic):
|
||||
if not self.content_body.s3_id:
|
||||
raise ValueError("s3_id is required for ContentBodyFormat_BackDetectPic")
|
||||
if not self.content_body.s3_url:
|
||||
raise ValueError("s3_url is required for ContentBodyFormat_BackDetectPic")
|
||||
if self.content_body.confidence is not None and not 0 < self.content_body.confidence < 1:
|
||||
raise ValueError("Confidence must be between 0 and 1 for ContentBodyFormat_BackDetectPic")
|
||||
|
||||
elif isinstance(self.content_body, ContentBodyFormat_Detection):
|
||||
if not self.content_body.s3_id:
|
||||
raise ValueError("s3_id is required for ContentBodyFormat_Detection")
|
||||
if not self.content_body.s3_url.early or not self.content_body.s3_url.later:
|
||||
raise ValueError("Both early and later URLs are required for ContentBodyFormat_Detection")
|
||||
|
||||
elif isinstance(self.content_body, ContentBodyFormat_Segementation):
|
||||
if not self.content_body.s3_id:
|
||||
raise ValueError("s3_id is required for ContentBodyFormat_Segementation")
|
||||
if not self.content_body.s3_url:
|
||||
raise ValueError("s3_url is required for ContentBodyFormat_Segementation")
|
||||
elif isinstance(self.content_body, ContentBodyFormat_Sam3Pic):
|
||||
if not self.content_body.prompt:
|
||||
raise ValueError("prompt is required for ContentBodyFormat_Sam3Pic")
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'RequestJson':
|
||||
try:
|
||||
task_id = data['task_id']
|
||||
sn = data['sn']
|
||||
content_body_data = data['content_body']
|
||||
|
||||
if 'minio_file_path' in content_body_data:
|
||||
content_body = ContentBodyFormat_VideoMultiBackDetect(
|
||||
org_code=content_body_data['org_code'],
|
||||
minio_file_path=content_body_data['minio_file_path'],
|
||||
push_url=content_body_data['push_url'],
|
||||
confidence=content_body_data.get('confidence', 0.5),
|
||||
para_list=content_body_data.get('para_list', []),
|
||||
invade=content_body_data.get('invade', [])
|
||||
)
|
||||
|
||||
# 根据 content_body_data 的类型创建相应的实例
|
||||
elif 'para_list' in content_body_data:
|
||||
content_body = ContentBodyFormat_MultiBackDetect(
|
||||
org_code=content_body_data['org_code'],
|
||||
func_id=content_body_data['func_id'],
|
||||
source_url=content_body_data['source_url'],
|
||||
push_url=content_body_data['push_url'],
|
||||
confidence=content_body_data.get('confidence', 0.5),
|
||||
para_list=content_body_data.get('para_list', []),
|
||||
invade=content_body_data.get('invade', [])
|
||||
)
|
||||
# 根据 content_body_data 的类型创建相应的实例
|
||||
elif 'source_url' in content_body_data:
|
||||
content_body = ContentBodyFormat_BackDetect(
|
||||
mqtt_ip=content_body_data['mqtt_ip'],
|
||||
mqtt_port=content_body_data['mqtt_port'],
|
||||
mqtt_topic=content_body_data['mqtt_topic'],
|
||||
source_url=content_body_data['source_url'],
|
||||
push_url=content_body_data['push_url'],
|
||||
confidence=content_body_data.get('confidence', 0.5),
|
||||
func_id=content_body_data.get('func_id', []),
|
||||
para=content_body_data.get('para', {})
|
||||
)
|
||||
elif 's3_id' in content_body_data and 's3_url' in content_body_data:
|
||||
if isinstance(content_body_data['s3_url'], dict) and 'early' in content_body_data[
|
||||
's3_url'] and 'later' in content_body_data['s3_url']:
|
||||
content_body = ContentBodyFormat_Detection(
|
||||
s3_id=content_body_data['s3_id'],
|
||||
s3_url=EarlyLaterUrls(
|
||||
early=content_body_data['s3_url']['early'],
|
||||
later=content_body_data['s3_url']['later']
|
||||
),
|
||||
func_id=content_body_data.get('func_id', [])
|
||||
)
|
||||
elif isinstance(content_body_data['s3_url'], list) and 'confidence' not in content_body_data:
|
||||
content_body = ContentBodyFormat_Segementation(
|
||||
s3_id=content_body_data['s3_id'],
|
||||
s3_url=content_body_data['s3_url'],
|
||||
func_id=content_body_data.get('func_id', [])
|
||||
)
|
||||
elif isinstance(content_body_data['s3_url'], list):
|
||||
content_body = ContentBodyFormat_BackDetectPic(
|
||||
s3_id=content_body_data['s3_id'],
|
||||
s3_url=content_body_data['s3_url'],
|
||||
# mqtt_pub_id=content_body_data['mqtt_pub_id'],
|
||||
org_code=content_body_data['org_code'],
|
||||
confidence=content_body_data.get('confidence', 0.5),
|
||||
func_id=content_body_data.get('func_id', []),
|
||||
para=content_body_data.get('para', {})
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError("Invalid s3_url format for ContentBodyFormat_Detection")
|
||||
elif 'prompt' in content_body_data:
|
||||
content_body = ContentBodyFormat_Sam3Pic(
|
||||
img_url=content_body_data['img_url'],
|
||||
prompt=content_body_data['prompt'],
|
||||
confidence=content_body_data.get('confidence', 0.5),
|
||||
mqtt_ip=content_body_data['mqtt_ip'],
|
||||
mqtt_port=content_body_data['mqtt_port'],
|
||||
mqtt_topic=content_body_data['mqtt_topic']
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid content_body format")
|
||||
|
||||
instance = cls(
|
||||
task_id=task_id,
|
||||
sn=sn,
|
||||
content_body=content_body
|
||||
)
|
||||
instance.validate()
|
||||
return instance
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Missing required field: {str(e)}")
|
||||
|
||||
async def safe_stop_ai_video():
|
||||
"""安全地停止AI视频处理,带有错误处理和恢复机制"""
|
||||
try:
|
||||
await asyncio.to_thread(stopAIVideo)
|
||||
return True
|
||||
except Exception as e:
|
||||
error_msg = f"停止AI视频处理出错: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
|
||||
# 标记服务状态为不健康
|
||||
service_status["is_healthy"] = False
|
||||
service_status["last_error"] = str(e)
|
||||
service_status["error_time"] = datetime.now().isoformat()
|
||||
|
||||
# 强制结束所有任务
|
||||
task_manager.mark_all_tasks_as_stopped()
|
||||
|
||||
# 尝试通过其他方式杀死可能存在的进程
|
||||
try:
|
||||
import os
|
||||
import signal
|
||||
import psutil
|
||||
|
||||
current_process = psutil.Process(os.getpid())
|
||||
# 查找并终止ffmpeg子进程
|
||||
for child in current_process.children(recursive=True):
|
||||
try:
|
||||
child_name = child.name().lower()
|
||||
if 'ffmpeg' in child_name:
|
||||
logger.info(f"强制终止子进程: {child.pid} ({child_name})")
|
||||
child.send_signal(signal.SIGTERM)
|
||||
except Exception as child_e:
|
||||
logger.error(f"终止子进程出错: {str(child_e)}")
|
||||
except Exception as kill_e:
|
||||
logger.error(f"尝试清理进程时出错: {str(kill_e)}")
|
||||
|
||||
# 等待一段时间让系统恢复
|
||||
await asyncio.sleep(Config.RESTART_DELAY)
|
||||
|
||||
# 重置服务状态
|
||||
service_status["is_healthy"] = True
|
||||
return False
|
||||
|
||||
def verify_token(request) -> None:
|
||||
"""验证请求token"""
|
||||
token = request.headers.get('X-API-Token')
|
||||
if not token or token != Config.VALID_TOKEN:
|
||||
logger.warning("Invalid token attempt")
|
||||
raise Unauthorized("Invalid token")
|
||||
|
||||
@multi_back_detect_bp.post("/ai/stream/multi_back_detect")
|
||||
async def start_multi_back_detection(request):
|
||||
try:
|
||||
verify_token(request)
|
||||
|
||||
# 检查服务健康状态
|
||||
if not service_status["is_healthy"]:
|
||||
logger.warning(
|
||||
f"服务处于不健康状态,上次错误: {service_status['last_error']} 于 {service_status['error_time']}")
|
||||
service_status["is_healthy"] = True
|
||||
|
||||
# 停止所有现有任务(可选,根据需求调整)
|
||||
# for task_id in list(task_manager.tasks.keys()):
|
||||
# await task_manager.remove_task(task_id)
|
||||
|
||||
# 解析并验证请求数据
|
||||
request_json = RequestJson.from_dict(request.json)
|
||||
print(f"/ai/stream/multi_back_detect 请求:{request.json}")
|
||||
time.sleep(3)
|
||||
if request_json.task_id in task_manager.tasks:
|
||||
logger.warning(f"任务 {request_json.task_id} 已存在,跳过创建")
|
||||
return json_response({
|
||||
"status": "error",
|
||||
"message": f"任务 {request_json.task_id} 已存在,跳过创建"
|
||||
}, status=500)
|
||||
|
||||
if isinstance(request_json.content_body, ContentBodyFormat_MultiBackDetect):
|
||||
try:
|
||||
# 创建停止事件
|
||||
stop_event = asyncio.Event()
|
||||
|
||||
# 包装处理函数以支持停止事件
|
||||
async def wrapped_processing():
|
||||
try:
|
||||
await run_back_Multi_Detect_async(request, request_json, stop_event)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"任务 {request_json.task_id} 被取消")
|
||||
except Exception as e:
|
||||
logger.error(f"任务 {request_json.task_id} 异常终止: {e}")
|
||||
|
||||
# 创建并启动任务
|
||||
task_handle = asyncio.create_task(wrapped_processing())
|
||||
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"启动AI视频处理失败: {e}")
|
||||
return json_response({
|
||||
"status": "error",
|
||||
"message": f"Failed to start AI video processing: {str(e)}"
|
||||
}, status=500)
|
||||
else:
|
||||
return json_response({
|
||||
"status": "failed",
|
||||
"message": "content_body structure is wrong"
|
||||
})
|
||||
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"task_id": request_json.task_id,
|
||||
"message": "Detection started successfully"
|
||||
})
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Validation error: {str(e)}")
|
||||
return json_response({"status": "error", "message": str(e)}, status=400)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error: {str(e)}", exc_info=True)
|
||||
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
||||
|
||||
async def run_back_Multi_Detect_async(request, request_json, stop_event: asyncio.Event):
|
||||
global DB_CONFIG
|
||||
model_configs = []
|
||||
task_handle = None # 初始化task_handle
|
||||
|
||||
try:
|
||||
|
||||
invade_enable = false # para_list 中可以包含多个侵限使能,其中一个使能,即将置为True
|
||||
py_func = []
|
||||
# 创建DAO实例
|
||||
dao = ModelConfigDAO(DB_CONFIG)
|
||||
# insert_request_log(self, task_id, sn, org_code, requset_json, request)
|
||||
for para in request_json.content_body.para_list:
|
||||
func_id = para["func_id"]
|
||||
category = para.get("py_func", []) # 提供默认值
|
||||
py_func = category # 湖北现场临时用
|
||||
para_invade_enable = para.get("para_invade_enable", False) # 提供默认值
|
||||
if para_invade_enable:
|
||||
invade_enable = True
|
||||
query_results = batch_query_model_func_id([func_id], **DB_CONFIG)
|
||||
row_func_id = 0 # 伪代码,后续记得修改
|
||||
if len(query_results) < 1:
|
||||
continue
|
||||
for row in query_results:
|
||||
row_func_id = row["model_func_id"]
|
||||
func_id = para["func_id"]
|
||||
config = dao.get_config(func_id, category)
|
||||
repeat_dis = -1 # 基于两帧之间的距离去重
|
||||
if config:
|
||||
# 打印结构化结果(使用自定义编码器处理datetime)
|
||||
# 访问特定字段
|
||||
print("\n模型路径:", config.model_path)
|
||||
print("过滤类别:", config.filter_indices)
|
||||
print("第一个类别:", asdict(config.classes[0]))
|
||||
print("创建时间:", config.created_at)
|
||||
print("更新时间:", config.updated_at)
|
||||
print("去重的距离:", config.repeat_dis)
|
||||
repeat_dis = config.repeat_dis
|
||||
repeat_time = config.repeat_time
|
||||
high_count_warn = config.high_count_warn
|
||||
print(f"config.high_count_warn {config.high_count_warn}")
|
||||
model_configs.append(
|
||||
{
|
||||
|
||||
'path': config.model_path,
|
||||
# 'engine_path': config.engine_path,
|
||||
# 'so_path': config.so_path,
|
||||
# # 测试代码
|
||||
'engine_path': r"D:\project\AI-PYTHON\Ai_tottle\engine\renche\renche.engine",
|
||||
'so_path': r"D:\project\AI-PYTHON\Ai_tottle\engine\renche\myplugins.dll",
|
||||
# 工地安全帽
|
||||
# 'engine_path': r"D:\project\AI-PYTHON\Ai_tottle\engine\gdaq_hat_0926\gdaq_hat_0926.engine",
|
||||
# 'so_path': r"D:\project\AI-PYTHON\Ai_tottle\engine\gdaq_hat_0926\myplugins.dll",
|
||||
|
||||
'cls_map': config.cls_zn_to_eh_dict,
|
||||
'allowed_classes': config.allowed_classes,
|
||||
"cls_index": config.class_indices,
|
||||
"class_names": config.cls_names,
|
||||
"chinese_label": config.cls_en_dict,
|
||||
"list_func_id": row_func_id,
|
||||
"func_id": func_id,
|
||||
"para_invade_enable": para_invade_enable,
|
||||
"config_conf": config.conf
|
||||
}
|
||||
)
|
||||
else:
|
||||
print(f"未找到ID为 {func_id} 的模型配置")
|
||||
|
||||
# category = para.get("category", []) # 提供默认值
|
||||
para_invade_enable = para.get("para_invade_enable", False) # 提供默认值
|
||||
if para_invade_enable:
|
||||
invade_enable = True
|
||||
|
||||
# 前置处理 不同的版本的模型,输入的数据的格式不同,需要做前置处理,入参应该是模型类型、版本、rest参数。出参应该是模型入参的格式
|
||||
|
||||
video_url = request_json.content_body.source_url
|
||||
sn = request_json.sn
|
||||
task_id = request_json.task_id
|
||||
# mqtt_pub_id = request_json.content_body.mqtt_pub_id
|
||||
# mqtt_sub_id = request_json.content_body.mqtt_sub_id
|
||||
org_code = request_json.content_body.org_code
|
||||
push_url = request_json.content_body.push_url
|
||||
# invade_file = request_json.content_body.invade_file
|
||||
invade = request_json.content_body.invade
|
||||
invade_file = invade["invade_file"]
|
||||
camera_para_url = invade["camera_para_url"]
|
||||
|
||||
if high_count_warn is None:
|
||||
high_count_warn=0
|
||||
|
||||
if "invade_switch" in invade:
|
||||
invade_switch = invade["invade_switch"]
|
||||
else:
|
||||
invade_switch = 0 # 或其他默认值
|
||||
# dao.get_mqtt_config_by_orgcode(org_code,)
|
||||
str_request = str(request) + "&" + str(request.socket) # 待测试,看看公网能不能捕获到请求端ip
|
||||
dao.insert_request_log(task_id, sn, org_code, str(request.body), str_request)
|
||||
mqtt_pub_config = dao.get_mqtt_config_by_orgcode(org_code, sn, "pub")
|
||||
mqtt_sub_config = dao.get_mqtt_config_by_orgcode(org_code, sn, "sub")
|
||||
|
||||
mqtt_pub_ip = mqtt_pub_config.mqtt_ip
|
||||
mqtt_pub_port = mqtt_pub_config.mqtt_port
|
||||
mqtt_pub_topic = mqtt_pub_config.mqtt_topic
|
||||
print(f"mqtt_pub_topic {mqtt_pub_topic}")
|
||||
mqtt_pub_username = mqtt_pub_config.mqtt_username
|
||||
mqtt_pub_pass = mqtt_pub_config.mqtt_pass
|
||||
mqtt_pub_description = mqtt_pub_config.mqtt_description
|
||||
mqtt_pub_org_code = mqtt_pub_config.org_code
|
||||
mqtt_pub_mqtt_type = mqtt_pub_config.mqtt_type
|
||||
|
||||
mqtt_sub_ip = mqtt_sub_config.mqtt_ip
|
||||
mqtt_sub_port = mqtt_sub_config.mqtt_port
|
||||
mqtt_sub_topic = mqtt_sub_config.mqtt_topic
|
||||
mqtt_sub_topic = mqtt_sub_topic.format(sn=sn)
|
||||
|
||||
print(f"mqtt_sub_topic {mqtt_sub_topic}")
|
||||
mqtt_sub_username = mqtt_sub_config.mqtt_username
|
||||
mqtt_sub_pass = mqtt_sub_config.mqtt_pass
|
||||
mqtt_sub_description = mqtt_sub_config.mqtt_description
|
||||
mqtt_sub_org_code = mqtt_sub_config.org_code
|
||||
mqtt_sub_mqtt_type = mqtt_sub_config.mqtt_type
|
||||
|
||||
local_func_config = read_local_func_config()
|
||||
|
||||
sn = request_json.sn
|
||||
device = dao.get_device(sn, org_code)
|
||||
print(f"device表 {sn} {org_code}")
|
||||
device_sn = device.sn
|
||||
device_orgcode = device.orgcode
|
||||
device_dname = device.dname
|
||||
device_lat = device.lat
|
||||
device_lng = device.lng
|
||||
device_height = device.height # 机场高度,后续用作现场的高程计算
|
||||
|
||||
# # 启动处理流程
|
||||
async def process_flow():
|
||||
try:
|
||||
await start_rtmp_processing(
|
||||
video_url,
|
||||
request_json.task_id,
|
||||
model_configs,
|
||||
mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic,
|
||||
mqtt_sub_ip, mqtt_sub_port, mqtt_sub_topic,
|
||||
push_url,
|
||||
invade_enable,invade_switch, invade_file, camera_para_url,
|
||||
device_height, repeat_dis, repeat_time,high_count_warn
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"处理流程异常: {e}")
|
||||
raise
|
||||
|
||||
# 运行处理流程,直到被取消
|
||||
await asyncio.shield(process_flow())
|
||||
|
||||
# # # # 针对湖北现场临时处理--------------------------------------------------------
|
||||
# model_path = model_configs[0]["path"]
|
||||
# detect_classes = py_func
|
||||
# print(f"detect_classesdetect_classes {detect_classes}")
|
||||
# confidence = model_configs[0]["config_conf"]
|
||||
# # 创建处理函数以支持停止事件
|
||||
# async def process_video():
|
||||
# nonlocal task_handle # 使用nonlocal访问外部变量
|
||||
# try:
|
||||
# # 针对湖北现场临时处理
|
||||
# source_url = video_url
|
||||
# model_path = model_configs[0]["path"]
|
||||
# detect_classes = py_func # 使用py_func作为检测类别
|
||||
# print(f"detect_classesdetect_classes {detect_classes}")
|
||||
#
|
||||
# confidence = model_configs[0]["config_conf"]
|
||||
#
|
||||
# # 启动YOLO检测
|
||||
# await asyncio.to_thread(
|
||||
# startAIVideo,
|
||||
# source_url,
|
||||
# push_url,
|
||||
# model_path,
|
||||
# detect_classes,
|
||||
# confidence
|
||||
# )
|
||||
# except asyncio.CancelledError:
|
||||
# logger.info(f"任务 {task_id} 被取消")
|
||||
# raise
|
||||
# except Exception as e:
|
||||
# logger.error(f"任务 {task_id} 异常终止: {e}")
|
||||
# raise
|
||||
#
|
||||
# # 创建并启动任务
|
||||
# task_handle = asyncio.create_task(process_video()) # 存储task_handle
|
||||
|
||||
# 记录任务信息到task_manager
|
||||
task_info = {
|
||||
"source_url": video_url,
|
||||
"push_url": push_url,
|
||||
"status": "running",
|
||||
"task_handle": task_handle, # 存储实际的任务句柄
|
||||
"model_configs": model_configs,
|
||||
"device_height": device_height,
|
||||
"repeat_dis": repeat_dis,
|
||||
"repeat_time": repeat_time
|
||||
}
|
||||
|
||||
# 使用task_manager管理任务
|
||||
await task_manager.add_task(
|
||||
task_id,
|
||||
task_info,
|
||||
task_handle, # 传递实际的任务句柄
|
||||
[] # 暂时没有子任务
|
||||
)
|
||||
|
||||
# 等待任务完成或被取消
|
||||
try:
|
||||
await task_handle
|
||||
except asyncio.CancelledError:
|
||||
pass # 任务被取消是正常的
|
||||
await task_manager.remove_task(task_id)
|
||||
# # # 针对湖北现场处理结束--------------------------------------------------------
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"任务 {request_json.task_id} 收到停止信号,正在清理...")
|
||||
# 清理资源逻辑...
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"任务 {request_json.task_id} 处理失败: {e}")
|
||||
raise
|
||||
58
multi_back_detect/router_multi_back_detect_api.py
Normal file
@ -0,0 +1,58 @@
|
||||
import time
|
||||
import logging
|
||||
from sanic_cors import CORS
|
||||
from sanic import Sanic, Request, json
|
||||
from multi_back_detect_api import multi_back_detect_bp
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 创建Sanic应用
|
||||
app = Sanic("multiBackDetectAPI")
|
||||
# 显式注册蓝图
|
||||
app.blueprint(multi_back_detect_bp)
|
||||
|
||||
CORS(app, automatic_options=True)
|
||||
|
||||
# 中间件:请求计时
|
||||
@app.middleware("request")
|
||||
async def add_start_time(request: Request):
|
||||
request.ctx.start_time = time.time()
|
||||
|
||||
@app.middleware("response")
|
||||
async def add_response_time(request: Request, response):
|
||||
if hasattr(request.ctx, "start_time"):
|
||||
process_time = (time.time() - request.ctx.start_time) * 1000
|
||||
response.headers["X-Process-Time"] = f"{process_time:.2f}ms"
|
||||
|
||||
@app.get("/api/v1/health")
|
||||
async def health_check(request: Request):
|
||||
"""健康检查"""
|
||||
return json({
|
||||
"status": "healthy",
|
||||
"timestamp": time.time(),
|
||||
"service": "terrain-analysis-api",
|
||||
"version": "1.0.0"
|
||||
})
|
||||
|
||||
# 错误处理
|
||||
@app.exception(Exception)
|
||||
async def handle_exception(request: Request, exception):
|
||||
"""全局异常处理"""
|
||||
logger.error(f"未处理的异常: {exception}")
|
||||
return json({
|
||||
"error": "服务器内部错误",
|
||||
"message": str(exception) if app.debug else "请稍后重试",
|
||||
"success": False
|
||||
}, status=500)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 启动服务器
|
||||
app.run(
|
||||
host="0.0.0.0",
|
||||
port=12320,
|
||||
debug=False, # 生产环境设为False
|
||||
access_log=True,
|
||||
auto_reload=True
|
||||
)
|
||||
@ -165,13 +165,15 @@ def pic_detect_func(mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic,
|
||||
try:
|
||||
frame_copy = frame.copy()
|
||||
results = counter(frame)
|
||||
|
||||
func_id=model_func_id_list[0]
|
||||
annotated_frame, box_result = cal_tricker_results(frame_copy, counter, class_names,
|
||||
model_func_id_list,
|
||||
func_id,
|
||||
local_func_cache, para, cls, chinese_label,
|
||||
model_func_id_list[0])
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理帧错误: {e}")
|
||||
print(f"处理帧错误1: {e}")
|
||||
error_count += 1
|
||||
if error_count >= 5:
|
||||
print(f"连续处理错误达到5次 ,正在停止处理...")
|
||||
|
||||
BIN
pt/build-wall.pt
Normal file
2
readme
@ -1,4 +1,4 @@
|
||||
conda create -n aienv2 python=3.10
|
||||
conda create -n test_sahi python=3.10
|
||||
激活环境
|
||||
|
||||
conda create -n yolo_trt_10 python=3.10 -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ -c https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ --override-channels
|
||||
|
||||
153
sam3/.gitignore
vendored
Normal file
@ -0,0 +1,153 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
*-Copy*.ipynb
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# celery beat schedule file
|
||||
celerybeat-schedule
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# PyCharm
|
||||
.idea/
|
||||
|
||||
# VS Code
|
||||
.vscode/
|
||||
*.code-workspace
|
||||
|
||||
# Model weights and checkpoints
|
||||
*.pth
|
||||
*.pt
|
||||
*.bin
|
||||
*.ckpt
|
||||
*.safetensors
|
||||
weights/
|
||||
checkpoints/
|
||||
sam3_logs/
|
||||
|
||||
# Data files
|
||||
*.h5
|
||||
*.hdf5
|
||||
*.pkl
|
||||
*.pickle
|
||||
*.npy
|
||||
*.npz
|
||||
|
||||
# Logs
|
||||
logs/
|
||||
runs/
|
||||
tensorboard/
|
||||
|
||||
# OS specific
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# BPE vocabulary files
|
||||
*.bpe
|
||||
*.vocab
|
||||
80
sam3/CODE_OF_CONDUCT.md
Normal file
@ -0,0 +1,80 @@
|
||||
# Code of Conduct
|
||||
|
||||
## Our Pledge
|
||||
|
||||
In the interest of fostering an open and welcoming environment, we as
|
||||
contributors and maintainers pledge to make participation in our project and
|
||||
our community a harassment-free experience for everyone, regardless of age, body
|
||||
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
||||
level of experience, education, socio-economic status, nationality, personal
|
||||
appearance, race, religion, or sexual identity and orientation.
|
||||
|
||||
## Our Standards
|
||||
|
||||
Examples of behavior that contributes to creating a positive environment
|
||||
include:
|
||||
|
||||
* Using welcoming and inclusive language
|
||||
* Being respectful of differing viewpoints and experiences
|
||||
* Gracefully accepting constructive criticism
|
||||
* Focusing on what is best for the community
|
||||
* Showing empathy towards other community members
|
||||
|
||||
Examples of unacceptable behavior by participants include:
|
||||
|
||||
* The use of sexualized language or imagery and unwelcome sexual attention or
|
||||
advances
|
||||
* Trolling, insulting/derogatory comments, and personal or political attacks
|
||||
* Public or private harassment
|
||||
* Publishing others' private information, such as a physical or electronic
|
||||
address, without explicit permission
|
||||
* Other conduct which could reasonably be considered inappropriate in a
|
||||
professional setting
|
||||
|
||||
## Our Responsibilities
|
||||
|
||||
Project maintainers are responsible for clarifying the standards of acceptable
|
||||
behavior and are expected to take appropriate and fair corrective action in
|
||||
response to any instances of unacceptable behavior.
|
||||
|
||||
Project maintainers have the right and responsibility to remove, edit, or
|
||||
reject comments, commits, code, wiki edits, issues, and other contributions
|
||||
that are not aligned to this Code of Conduct, or to ban temporarily or
|
||||
permanently any contributor for other behaviors that they deem inappropriate,
|
||||
threatening, offensive, or harmful.
|
||||
|
||||
## Scope
|
||||
|
||||
This Code of Conduct applies within all project spaces, and it also applies when
|
||||
an individual is representing the project or its community in public spaces.
|
||||
Examples of representing a project or community include using an official
|
||||
project e-mail address, posting via an official social media account, or acting
|
||||
as an appointed representative at an online or offline event. Representation of
|
||||
a project may be further defined and clarified by project maintainers.
|
||||
|
||||
This Code of Conduct also applies outside the project spaces when there is a
|
||||
reasonable belief that an individual's behavior may have a negative impact on
|
||||
the project or its community.
|
||||
|
||||
## Enforcement
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||
reported by contacting the project team at <opensource-conduct@meta.com>. All
|
||||
complaints will be reviewed and investigated and will result in a response that
|
||||
is deemed necessary and appropriate to the circumstances. The project team is
|
||||
obligated to maintain confidentiality with regard to the reporter of an incident.
|
||||
Further details of specific enforcement policies may be posted separately.
|
||||
|
||||
Project maintainers who do not follow or enforce the Code of Conduct in good
|
||||
faith may face temporary or permanent repercussions as determined by other
|
||||
members of the project's leadership.
|
||||
|
||||
## Attribution
|
||||
|
||||
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
||||
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
||||
|
||||
[homepage]: https://www.contributor-covenant.org
|
||||
|
||||
For answers to common questions about this code of conduct, see
|
||||
https://www.contributor-covenant.org/faq
|
||||
30
sam3/CONTRIBUTING.md
Normal file
@ -0,0 +1,30 @@
|
||||
# Contributing to sam3
|
||||
We want to make contributing to this project as easy and transparent as
|
||||
possible.
|
||||
|
||||
## Pull Requests
|
||||
We actively welcome your pull requests.
|
||||
|
||||
1. Fork the repo and create your branch from `main`.
|
||||
2. If you've added code that should be tested, add tests.
|
||||
3. If you've changed APIs, update the documentation.
|
||||
4. Make sure your code lints.
|
||||
5. If you haven't already, complete the Contributor License Agreement ("CLA").
|
||||
|
||||
## Contributor License Agreement ("CLA")
|
||||
In order to accept your pull request, we need you to submit a CLA. You only need
|
||||
to do this once to work on any of Facebook's open source projects.
|
||||
|
||||
Complete your CLA here: <https://code.facebook.com/cla>
|
||||
|
||||
## Issues
|
||||
We use GitHub issues to track public bugs. Please ensure your description is
|
||||
clear and has sufficient instructions to be able to reproduce the issue.
|
||||
|
||||
Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
|
||||
disclosure of security bugs. In those cases, please go through the process
|
||||
outlined on that page and do not file a public issue.
|
||||
|
||||
## License
|
||||
By contributing to sam3, you agree that your contributions will be licensed
|
||||
under the LICENSE file in the root directory of this source tree.
|
||||
61
sam3/LICENSE
Normal file
@ -0,0 +1,61 @@
|
||||
SAM License
|
||||
Last Updated: November 19, 2025
|
||||
|
||||
“Agreement” means the terms and conditions for use, reproduction, distribution and modification of the SAM Materials set forth herein.
|
||||
|
||||
|
||||
“SAM Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, and other elements of the foregoing distributed by Meta and made available under this Agreement.
|
||||
|
||||
“Documentation” means the specifications, manuals and documentation accompanying
|
||||
SAM Materials distributed by Meta.
|
||||
|
||||
|
||||
“Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
|
||||
|
||||
|
||||
“Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) or Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland).
|
||||
|
||||
|
||||
“Sanctions” means any economic or trade sanctions or restrictions administered or enforced by the United States (including the Office of Foreign Assets Control of the U.S. Department of the Treasury (“OFAC”), the U.S. Department of State and the U.S. Department of Commerce), the United Nations, the European Union, or the United Kingdom.
|
||||
|
||||
|
||||
“Trade Controls” means any of the following: Sanctions and applicable export and import controls.
|
||||
|
||||
By using or distributing any portion or element of the SAM Materials, you agree to be bound by this Agreement.
|
||||
|
||||
|
||||
1. License Rights and Redistribution.
|
||||
|
||||
|
||||
a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the SAM Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the SAM Materials.
|
||||
|
||||
b. Redistribution and Use.
|
||||
i. Distribution of SAM Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the SAM Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement and you shall provide a copy of this Agreement with any such SAM Materials.
|
||||
|
||||
|
||||
ii. If you submit for publication the results of research you perform on, using, or otherwise in connection with SAM Materials, you must acknowledge the use of SAM Materials in your publication.
|
||||
|
||||
|
||||
iii. Your use of the SAM Materials must comply with applicable laws and regulations, including Trade Control Laws and applicable privacy and data protection laws.
|
||||
iv. Your use of the SAM Materials will not involve or encourage others to reverse engineer, decompile or discover the underlying components of the SAM Materials.
|
||||
v. You are not the target of Trade Controls and your use of SAM Materials must comply with Trade Controls. You agree not to use, or permit others to use, SAM Materials for any activities subject to the International Traffic in Arms Regulations (ITAR) or end uses prohibited by Trade Controls, including those related to military or warfare purposes, nuclear industries or applications, espionage, or the development or use of guns or illegal weapons.
|
||||
2. User Support. Your use of the SAM Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the SAM Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
|
||||
|
||||
|
||||
3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SAM MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SAM MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SAM MATERIALS AND ANY OUTPUT AND RESULTS.
|
||||
|
||||
4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
||||
|
||||
5. Intellectual Property.
|
||||
|
||||
|
||||
a. Subject to Meta’s ownership of SAM Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the SAM Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
|
||||
|
||||
b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the SAM Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the SAM Materials.
|
||||
|
||||
6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the SAM Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the SAM Materials. Sections 3, 4 and 7 shall survive the termination of this Agreement.
|
||||
|
||||
7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
|
||||
|
||||
|
||||
8. Modifications and Amendments. Meta may modify this Agreement from time to time; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the SAM Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.
|
||||
6
sam3/MANIFEST.in
Normal file
@ -0,0 +1,6 @@
|
||||
include LICENSE
|
||||
include README.md
|
||||
recursive-include examples *.py
|
||||
recursive-include examples *.ipynb
|
||||
recursive-include examples *.md
|
||||
recursive-include tests *.py
|
||||
395
sam3/README.md
Normal file
@ -0,0 +1,395 @@
|
||||
# SAM 3: Segment Anything with Concepts
|
||||
|
||||
Meta Superintelligence Labs
|
||||
|
||||
[Nicolas Carion](https://www.nicolascarion.com/)\*,
|
||||
[Laura Gustafson](https://scholar.google.com/citations?user=c8IpF9gAAAAJ&hl=en)\*,
|
||||
[Yuan-Ting Hu](https://scholar.google.com/citations?user=E8DVVYQAAAAJ&hl=en)\*,
|
||||
[Shoubhik Debnath](https://scholar.google.com/citations?user=fb6FOfsAAAAJ&hl=en)\*,
|
||||
[Ronghang Hu](https://ronghanghu.com/)\*,
|
||||
[Didac Suris](https://www.didacsuris.com/)\*,
|
||||
[Chaitanya Ryali](https://scholar.google.com/citations?user=4LWx24UAAAAJ&hl=en)\*,
|
||||
[Kalyan Vasudev Alwala](https://scholar.google.co.in/citations?user=m34oaWEAAAAJ&hl=en)\*,
|
||||
[Haitham Khedr](https://hkhedr.com/)\*, Andrew Huang,
|
||||
[Jie Lei](https://jayleicn.github.io/),
|
||||
[Tengyu Ma](https://scholar.google.com/citations?user=VeTSl0wAAAAJ&hl=en),
|
||||
[Baishan Guo](https://scholar.google.com/citations?user=BC5wDu8AAAAJ&hl=en),
|
||||
Arpit Kalla, [Markus Marks](https://damaggu.github.io/),
|
||||
[Joseph Greer](https://scholar.google.com/citations?user=guL96CkAAAAJ&hl=en),
|
||||
Meng Wang, [Peize Sun](https://peizesun.github.io/),
|
||||
[Roman Rädle](https://scholar.google.com/citations?user=Tpt57v0AAAAJ&hl=en),
|
||||
[Triantafyllos Afouras](https://www.robots.ox.ac.uk/~afourast/),
|
||||
[Effrosyni Mavroudi](https://scholar.google.com/citations?user=vYRzGGEAAAAJ&hl=en),
|
||||
[Katherine Xu](https://k8xu.github.io/)°,
|
||||
[Tsung-Han Wu](https://patrickthwu.com/)°,
|
||||
[Yu Zhou](https://yu-bryan-zhou.github.io/)°,
|
||||
[Liliane Momeni](https://scholar.google.com/citations?user=Lb-KgVYAAAAJ&hl=en)°,
|
||||
[Rishi Hazra](https://rishihazra.github.io/)°,
|
||||
[Shuangrui Ding](https://mark12ding.github.io/)°,
|
||||
[Sagar Vaze](https://sgvaze.github.io/)°,
|
||||
[Francois Porcher](https://scholar.google.com/citations?user=LgHZ8hUAAAAJ&hl=en)°,
|
||||
[Feng Li](https://fengli-ust.github.io/)°,
|
||||
[Siyuan Li](https://siyuanliii.github.io/)°,
|
||||
[Aishwarya Kamath](https://ashkamath.github.io/)°,
|
||||
[Ho Kei Cheng](https://hkchengrex.com/)°,
|
||||
[Piotr Dollar](https://pdollar.github.io/)†,
|
||||
[Nikhila Ravi](https://nikhilaravi.com/)†,
|
||||
[Kate Saenko](https://ai.bu.edu/ksaenko.html)†,
|
||||
[Pengchuan Zhang](https://pzzhang.github.io/pzzhang/)†,
|
||||
[Christoph Feichtenhofer](https://feichtenhofer.github.io/)†
|
||||
|
||||
\* core contributor, ° intern, † project lead, order is random within groups
|
||||
|
||||
[[`Paper`](https://ai.meta.com/research/publications/sam-3-segment-anything-with-concepts/)]
|
||||
[[`Project`](https://ai.meta.com/sam3)]
|
||||
[[`Demo`](https://segment-anything.com/)]
|
||||
[[`Blog`](https://ai.meta.com/blog/segment-anything-model-3/)]
|
||||
[[`BibTeX`](#citing-sam-3)]
|
||||
|
||||
 SAM 3 is a unified foundation model for promptable segmentation in images and videos. It can detect, segment, and track objects using text or visual prompts such as points, boxes, and masks. Compared to its predecessor [SAM 2](https://github.com/facebookresearch/sam2), SAM 3 introduces the ability to exhaustively segment all instances of an open-vocabulary concept specified by a short text phrase or exemplars. Unlike prior work, SAM 3 can handle a vastly larger set of open-vocabulary prompts. It achieves 75-80% of human performance on our new [SA-CO benchmark](https://github.com/facebookresearch/sam3?tab=readme-ov-file#sa-co-dataset) which contains 270K unique concepts, over 50 times more than existing benchmarks.
|
||||
|
||||
This breakthrough is driven by an innovative data engine that has automatically annotated over 4 million unique concepts, creating the largest high-quality open-vocabulary segmentation dataset to date. In addition, SAM 3 introduces a new model architecture featuring a presence token that improves discrimination between closely related text prompts (e.g., “a player in white” vs. “a player in red”), as well as a decoupled detector–tracker design that minimizes task interference and scales efficiently with data.
|
||||
|
||||
<p align="center">
|
||||
<img src="assets/dog.gif" width=380 />
|
||||
<img src="assets/player.gif" width=380 />
|
||||
</p>
|
||||
|
||||
## Installation
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Python 3.12 or higher
|
||||
- PyTorch 2.7 or higher
|
||||
- CUDA-compatible GPU with CUDA 12.6 or higher
|
||||
|
||||
1. **Create a new Conda environment:**
|
||||
|
||||
```bash
|
||||
conda create -n sam3 python=3.12
|
||||
conda deactivate
|
||||
conda activate sam3
|
||||
```
|
||||
|
||||
2. **Install PyTorch with CUDA support:**
|
||||
|
||||
```bash
|
||||
pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
|
||||
```
|
||||
|
||||
3. **Clone the repository and install the package:**
|
||||
|
||||
```bash
|
||||
git clone https://github.com/facebookresearch/sam3.git
|
||||
cd sam3
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
4. **Install additional dependencies for example notebooks or development:**
|
||||
|
||||
```bash
|
||||
# For running example notebooks
|
||||
pip install -e ".[notebooks]"
|
||||
|
||||
# For development
|
||||
pip install -e ".[train,dev]"
|
||||
```
|
||||
|
||||
## Getting Started
|
||||
|
||||
⚠️ Before using SAM 3, please request access to the checkpoints on the SAM 3
|
||||
Hugging Face [repo](https://huggingface.co/facebook/sam3). Once accepted, you
|
||||
need to be authenticated to download the checkpoints. You can do this by running
|
||||
the following [steps](https://huggingface.co/docs/huggingface_hub/en/quick-start#authentication)
|
||||
(e.g. `hf auth login` after generating an access token.)
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```python
|
||||
import torch
|
||||
#################################### For Image ####################################
|
||||
from PIL import Image
|
||||
from sam3.model_builder import build_sam3_image_model
|
||||
from sam3.model.sam3_image_processor import Sam3Processor
|
||||
# Load the model
|
||||
model = build_sam3_image_model()
|
||||
processor = Sam3Processor(model)
|
||||
# Load an image
|
||||
image = Image.open("<YOUR_IMAGE_PATH.jpg>")
|
||||
inference_state = processor.set_image(image)
|
||||
# Prompt the model with text
|
||||
output = processor.set_text_prompt(state=inference_state, prompt="<YOUR_TEXT_PROMPT>")
|
||||
|
||||
# Get the masks, bounding boxes, and scores
|
||||
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
|
||||
|
||||
#################################### For Video ####################################
|
||||
|
||||
from sam3.model_builder import build_sam3_video_predictor
|
||||
|
||||
video_predictor = build_sam3_video_predictor()
|
||||
video_path = "<YOUR_VIDEO_PATH>" # a JPEG folder or an MP4 video file
|
||||
# Start a session
|
||||
response = video_predictor.handle_request(
|
||||
request=dict(
|
||||
type="start_session",
|
||||
resource_path=video_path,
|
||||
)
|
||||
)
|
||||
response = video_predictor.handle_request(
|
||||
request=dict(
|
||||
type="add_prompt",
|
||||
session_id=response["session_id"],
|
||||
frame_index=0, # Arbitrary frame index
|
||||
text="<YOUR_TEXT_PROMPT>",
|
||||
)
|
||||
)
|
||||
output = response["outputs"]
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
The `examples` directory contains notebooks demonstrating how to use SAM3 with
|
||||
various types of prompts:
|
||||
|
||||
- [`sam3_image_predictor_example.ipynb`](examples/sam3_image_predictor_example.ipynb)
|
||||
: Demonstrates how to prompt SAM 3 with text and visual box prompts on images.
|
||||
- [`sam3_video_predictor_example.ipynb`](examples/sam3_video_predictor_example.ipynb)
|
||||
: Demonstrates how to prompt SAM 3 with text prompts on videos, and doing
|
||||
further interactive refinements with points.
|
||||
- [`sam3_image_batched_inference.ipynb`](examples/sam3_image_batched_inference.ipynb)
|
||||
: Demonstrates how to run batched inference with SAM 3 on images.
|
||||
- [`sam3_agent.ipynb`](examples/sam3_agent.ipynb): Demonsterates the use of SAM
|
||||
3 Agent to segment complex text prompt on images.
|
||||
- [`saco_gold_silver_vis_example.ipynb`](examples/saco_gold_silver_vis_example.ipynb)
|
||||
: Shows a few examples from SA-Co image evaluation set.
|
||||
- [`saco_veval_vis_example.ipynb`](examples/saco_veval_vis_example.ipynb) :
|
||||
Shows a few examples from SA-Co video evaluation set.
|
||||
|
||||
There are additional notebooks in the examples directory that demonstrate how to
|
||||
use SAM 3 for interactive instance segmentation in images and videos (SAM 1/2
|
||||
tasks), or as a tool for an MLLM, and how to run evaluations on the SA-Co
|
||||
dataset.
|
||||
|
||||
To run the Jupyter notebook examples:
|
||||
|
||||
```bash
|
||||
# Make sure you have the notebooks dependencies installed
|
||||
pip install -e ".[notebooks]"
|
||||
|
||||
# Start Jupyter notebook
|
||||
jupyter notebook examples/sam3_image_predictor_example.ipynb
|
||||
```
|
||||
|
||||
## Model
|
||||
|
||||
SAM 3 consists of a detector and a tracker that share a vision encoder. It has 848M parameters. The
|
||||
detector is a DETR-based model conditioned on text, geometry, and image
|
||||
exemplars. The tracker inherits the SAM 2 transformer encoder-decoder
|
||||
architecture, supporting video segmentation and interactive refinement.
|
||||
|
||||
## Image Results
|
||||
|
||||
<div align="center">
|
||||
<table style="min-width: 80%; border: 2px solid #ddd; border-collapse: collapse">
|
||||
<thead>
|
||||
<tr>
|
||||
<th rowspan="3" style="border-right: 2px solid #ddd; padding: 12px 20px">Model</th>
|
||||
<th colspan="3" style="text-align: center; border-right: 2px solid #ddd; padding: 12px 20px">Instance Segmentation</th>
|
||||
<th colspan="5" style="text-align: center; padding: 12px 20px">Box Detection</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<th colspan="2" style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">LVIS</th>
|
||||
<th style="text-align: center; border-right: 2px solid #ddd; padding: 12px 20px">SA-Co/Gold</th>
|
||||
<th colspan="2" style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">LVIS</th>
|
||||
<th colspan="2" style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">COCO</th>
|
||||
<th style="text-align: center; padding: 12px 20px">SA-Co/Gold</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<th style="text-align: center; padding: 12px 20px">cgF1</th>
|
||||
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">AP</th>
|
||||
<th style="text-align: center; border-right: 2px solid #ddd; padding: 12px 20px">cgF1</th>
|
||||
<th style="text-align: center; padding: 12px 20px">cgF1</th>
|
||||
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">AP</th>
|
||||
<th style="text-align: center; padding: 12px 20px">AP</th>
|
||||
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">AP<sub>o</sub>
|
||||
</th>
|
||||
<th style="text-align: center; padding: 12px 20px">cgF1</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr>
|
||||
<td style="border-right: 2px solid #ddd; padding: 10px 20px">Human</td>
|
||||
<td style="text-align: center; padding: 10px 20px">-</td>
|
||||
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
|
||||
<td style="text-align: center; border-right: 2px solid #ddd; padding: 10px 20px">72.8</td>
|
||||
<td style="text-align: center; padding: 10px 20px">-</td>
|
||||
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
|
||||
<td style="text-align: center; padding: 10px 20px">-</td>
|
||||
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
|
||||
<td style="text-align: center; padding: 10px 20px">74.0</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="border-right: 2px solid #ddd; padding: 10px 20px">OWLv2*</td>
|
||||
<td style="text-align: center; padding: 10px 20px; color: #999">29.3</td>
|
||||
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px; color: #999">43.4</td>
|
||||
<td style="text-align: center; border-right: 2px solid #ddd; padding: 10px 20px">24.6</td>
|
||||
<td style="text-align: center; padding: 10px 20px; color: #999">30.2</td>
|
||||
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px; color: #999">45.5</td>
|
||||
<td style="text-align: center; padding: 10px 20px">46.1</td>
|
||||
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">23.9</td>
|
||||
<td style="text-align: center; padding: 10px 20px">24.5</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="border-right: 2px solid #ddd; padding: 10px 20px">DINO-X</td>
|
||||
<td style="text-align: center; padding: 10px 20px">-</td>
|
||||
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">38.5</td>
|
||||
<td style="text-align: center; border-right: 2px solid #ddd; padding: 10px 20px">21.3</td>
|
||||
<td style="text-align: center; padding: 10px 20px">-</td>
|
||||
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">52.4</td>
|
||||
<td style="text-align: center; padding: 10px 20px">56.0</td>
|
||||
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
|
||||
<td style="text-align: center; padding: 10px 20px">22.5</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style="border-right: 2px solid #ddd; padding: 10px 20px">Gemini 2.5</td>
|
||||
<td style="text-align: center; padding: 10px 20px">13.4</td>
|
||||
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
|
||||
<td style="text-align: center; border-right: 2px solid #ddd; padding: 10px 20px">13.0</td>
|
||||
<td style="text-align: center; padding: 10px 20px">16.1</td>
|
||||
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
|
||||
<td style="text-align: center; padding: 10px 20px">-</td>
|
||||
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
|
||||
<td style="text-align: center; padding: 10px 20px">14.4</td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid #b19c9cff">
|
||||
<td style="border-right: 2px solid #ddd; padding: 10px 20px">SAM 3</td>
|
||||
<td style="text-align: center; padding: 10px 20px">37.2</td>
|
||||
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">48.5</td>
|
||||
<td style="text-align: center; border-right: 2px solid #ddd; padding: 10px 20px">54.1</td>
|
||||
<td style="text-align: center; padding: 10px 20px">40.6</td>
|
||||
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">53.6</td>
|
||||
<td style="text-align: center; padding: 10px 20px">56.4</td>
|
||||
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">55.7</td>
|
||||
<td style="text-align: center; padding: 10px 20px">55.7</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
<p style="text-align: center; margin-top: 10px; font-size: 0.9em; color: #ddd;">* Partially trained on LVIS, AP<sub>o</sub> refers to COCO-O accuracy</p>
|
||||
|
||||
</div>
|
||||
|
||||
## Video Results
|
||||
|
||||
<div align="center">
|
||||
<table style="min-width: 80%; border: 2px solid #ddd; border-collapse: collapse">
|
||||
<thead>
|
||||
<tr>
|
||||
<th rowspan="2" style="border-right: 2px solid #ddd; padding: 12px 20px">Model</th>
|
||||
<th colspan="2" style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">SA-V test</th>
|
||||
<th colspan="2" style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">YT-Temporal-1B test</th>
|
||||
<th colspan="2" style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">SmartGlasses test</th>
|
||||
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">LVVIS test</th>
|
||||
<th style="text-align: center; padding: 12px 20px">BURST test</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<th style="text-align: center; padding: 12px 20px">cgF1</th>
|
||||
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">pHOTA</th>
|
||||
<th style="text-align: center; padding: 12px 20px">cgF1</th>
|
||||
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">pHOTA</th>
|
||||
<th style="text-align: center; padding: 12px 20px">cgF1</th>
|
||||
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">pHOTA</th>
|
||||
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">mAP</th>
|
||||
<th style="text-align: center; padding: 12px 20px">HOTA</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr>
|
||||
<td style="border-right: 2px solid #ddd; padding: 10px 20px">Human</td>
|
||||
<td style="text-align: center; padding: 10px 20px">53.1</td>
|
||||
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">70.5</td>
|
||||
<td style="text-align: center; padding: 10px 20px">71.2</td>
|
||||
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">78.4</td>
|
||||
<td style="text-align: center; padding: 10px 20px">58.5</td>
|
||||
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">72.3</td>
|
||||
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
|
||||
<td style="text-align: center; padding: 10px 20px">-</td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid #b19c9cff">
|
||||
<td style="border-right: 2px solid #ddd; padding: 10px 20px">SAM 3</td>
|
||||
<td style="text-align: center; padding: 10px 20px">30.3</td>
|
||||
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">58.0</td>
|
||||
<td style="text-align: center; padding: 10px 20px">50.8</td>
|
||||
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">69.9</td>
|
||||
<td style="text-align: center; padding: 10px 20px">36.4</td>
|
||||
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">63.6</td>
|
||||
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">36.3</td>
|
||||
<td style="text-align: center; padding: 10px 20px">44.5</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
## SA-Co Dataset
|
||||
|
||||
We release 2 image benchmarks, [SA-Co/Gold](scripts/eval/gold/README.md) and
|
||||
[SA-Co/Silver](scripts/eval/silver/README.md), and a video benchmark
|
||||
[SA-Co/VEval](scripts/eval/veval/README.md). The datasets contain images (or videos) with annotated noun phrases. Each image/video and noun phrase pair is annotated with instance masks and unique IDs of each object matching the phrase. Phrases that have no matching objects (negative prompts) have no masks, shown in red font in the figure. See the linked READMEs for more details on how to download and run evaluations on the datasets.
|
||||
|
||||
* HuggingFace host: [SA-Co/Gold](https://huggingface.co/datasets/facebook/SACo-Gold), [SA-Co/Silver](https://huggingface.co/datasets/facebook/SACo-Silver) and [SA-Co/VEval](https://huggingface.co/datasets/facebook/SACo-VEval)
|
||||
* Roboflow host: [SA-Co/Gold](https://universe.roboflow.com/sa-co-gold), [SA-Co/Silver](https://universe.roboflow.com/sa-co-silver) and [SA-Co/VEval](https://universe.roboflow.com/sa-co-veval)
|
||||
|
||||

|
||||
|
||||
## Development
|
||||
|
||||
To set up the development environment:
|
||||
|
||||
```bash
|
||||
pip install -e ".[dev,train]"
|
||||
```
|
||||
|
||||
To format the code:
|
||||
|
||||
```bash
|
||||
ufmt format .
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
See [contributing](CONTRIBUTING.md) and the
|
||||
[code of conduct](CODE_OF_CONDUCT.md).
|
||||
|
||||
## License
|
||||
|
||||
This project is licensed under the SAM License - see the [LICENSE](LICENSE) file
|
||||
for details.
|
||||
|
||||
## Acknowledgements
|
||||
|
||||
We would like to thank the following people for their contributions to the SAM 3 project: Alex He, Alexander Kirillov,
|
||||
Alyssa Newcomb, Ana Paula Kirschner Mofarrej, Andrea Madotto, Andrew Westbury, Ashley Gabriel, Azita Shokpour,
|
||||
Ben Samples, Bernie Huang, Carleigh Wood, Ching-Feng Yeh, Christian Puhrsch, Claudette Ward, Daniel Bolya,
|
||||
Daniel Li, Facundo Figueroa, Fazila Vhora, George Orlin, Hanzi Mao, Helen Klein, Hu Xu, Ida Cheng, Jake Kinney,
|
||||
Jiale Zhi, Jo Sampaio, Joel Schlosser, Justin Johnson, Kai Brown, Karen Bergan, Karla Martucci, Kenny Lehmann,
|
||||
Maddie Mintz, Mallika Malhotra, Matt Ward, Michelle Chan, Michelle Restrepo, Miranda Hartley, Muhammad Maaz,
|
||||
Nisha Deo, Peter Park, Phillip Thomas, Raghu Nayani, Rene Martinez Doehner, Robbie Adkins, Ross Girshik, Sasha
|
||||
Mitts, Shashank Jain, Spencer Whitehead, Ty Toledano, Valentin Gabeur, Vincent Cho, Vivian Lee, William Ngan,
|
||||
Xuehai He, Yael Yungster, Ziqi Pang, Ziyi Dou, Zoe Quake.
|
||||
|
||||
## Citing SAM 3
|
||||
|
||||
If you use SAM 3 or the SA-Co dataset in your research, please use the following BibTeX entry.
|
||||
|
||||
```bibtex
|
||||
@misc{carion2025sam3segmentconcepts,
|
||||
title={SAM 3: Segment Anything with Concepts},
|
||||
author={Nicolas Carion and Laura Gustafson and Yuan-Ting Hu and Shoubhik Debnath and Ronghang Hu and Didac Suris and Chaitanya Ryali and Kalyan Vasudev Alwala and Haitham Khedr and Andrew Huang and Jie Lei and Tengyu Ma and Baishan Guo and Arpit Kalla and Markus Marks and Joseph Greer and Meng Wang and Peize Sun and Roman Rädle and Triantafyllos Afouras and Effrosyni Mavroudi and Katherine Xu and Tsung-Han Wu and Yu Zhou and Liliane Momeni and Rishi Hazra and Shuangrui Ding and Sagar Vaze and Francois Porcher and Feng Li and Siyuan Li and Aishwarya Kamath and Ho Kei Cheng and Piotr Dollár and Nikhila Ravi and Kate Saenko and Pengchuan Zhang and Christoph Feichtenhofer},
|
||||
year={2025},
|
||||
eprint={2511.16719},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CV},
|
||||
url={https://arxiv.org/abs/2511.16719},
|
||||
}
|
||||
```
|
||||
190
sam3/README_TRAIN.md
Normal file
@ -0,0 +1,190 @@
|
||||
# Training
|
||||
|
||||
This repository supports finetuning SAM3 models on custom datasets in multi-node setup or local execution. The training script is located at `sam3/train.py` and uses Hydra configuration management to handle complex training setups.
|
||||
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
cd sam3
|
||||
pip install -e ".[train]"
|
||||
```
|
||||
|
||||
### Training Script Usage
|
||||
|
||||
The main training script is located at `sam3/train.py`. It uses Hydra configuration management to handle complex training setups.
|
||||
|
||||
#### Basic Usage
|
||||
|
||||
```bash
|
||||
# Example: Train on Roboflow dataset
|
||||
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml
|
||||
# Example: Train on ODinW13 dataset
|
||||
python sam3/train/train.py -c configs/odinw13/odinw_text_only_train.yaml
|
||||
```
|
||||
Follow [`Roboflow 100-VL`](https://github.com/roboflow/rf100-vl/) to download the roboflow 100-vl datasets. Follow [`GLIP`](https://github.com/microsoft/GLIP) to download the ODinW datasets. The data folder should be organized as follows, and put your roboflow_vl_100_root and odinw_data_root in the job configs.
|
||||
```
|
||||
roboflow_vl_100_root:
|
||||
13-lkc01
|
||||
train
|
||||
valid
|
||||
test
|
||||
2024-frc
|
||||
actions
|
||||
...
|
||||
odinw_data_root:
|
||||
AerialMaritimeDrone
|
||||
large
|
||||
train
|
||||
valid
|
||||
test
|
||||
Aquarium
|
||||
...
|
||||
```
|
||||
|
||||
#### Command Line Arguments
|
||||
|
||||
The training script supports several command line arguments:
|
||||
|
||||
```bash
|
||||
python sam3/train/train.py \
|
||||
-c CONFIG_NAME \
|
||||
[--use-cluster 0|1] \
|
||||
[--partition PARTITION_NAME] \
|
||||
[--account ACCOUNT_NAME] \
|
||||
[--qos QOS_NAME] \
|
||||
[--num-gpus NUM_GPUS] \
|
||||
[--num-nodes NUM_NODES]
|
||||
```
|
||||
|
||||
**Arguments:**
|
||||
- `-c, --config`: **Required.** Path to the configuration file (e.g., `sam3/train/configs/roboflow_v100_full_ft_100_images.yaml`)
|
||||
- `--use-cluster`: Whether to launch on a cluster (0: local, 1: cluster). Default: uses config setting
|
||||
- `--partition`: SLURM partition name for cluster execution
|
||||
- `--account`: SLURM account name for cluster execution
|
||||
- `--qos`: SLURM QOS (Quality of Service) setting
|
||||
- `--num-gpus`: Number of GPUs per node. Default: uses config setting
|
||||
- `--num-nodes`: Number of nodes for distributed training. Default: uses config setting
|
||||
|
||||
#### Local Training Examples
|
||||
|
||||
```bash
|
||||
# Single GPU training
|
||||
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml --use-cluster 0 --num-gpus 1
|
||||
|
||||
# Multi-GPU training on a single node
|
||||
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml --use-cluster 0 --num-gpus 4
|
||||
|
||||
# Force local execution even if config specifies GPUs
|
||||
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml --use-cluster 0
|
||||
```
|
||||
|
||||
#### Cluster Training Examples
|
||||
|
||||
```bash
|
||||
# Basic cluster training with default settings from config
|
||||
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml --use-cluster 1
|
||||
|
||||
# Cluster training with specific SLURM settings
|
||||
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml \
|
||||
--use-cluster 1 \
|
||||
--partition gpu_partition \
|
||||
--account my_account \
|
||||
--qos high_priority \
|
||||
--num-gpus 8 \
|
||||
--num-nodes 2
|
||||
```
|
||||
|
||||
### Configuration Files
|
||||
|
||||
Training configurations are stored in `sam3/train/configs/`. The configuration files use Hydra's YAML format and support:
|
||||
|
||||
- **Dataset Configuration**: Data paths, transforms, and loading parameters
|
||||
- **Model Configuration**: Architecture settings, checkpoint paths, and model parameters
|
||||
- **Training Configuration**: Batch sizes, learning rates, optimization settings
|
||||
- **Launcher Configuration**: Distributed training and cluster settings
|
||||
- **Logging Configuration**: TensorBoard, experiment tracking, and output directories
|
||||
|
||||
#### Key Configuration Sections
|
||||
|
||||
```yaml
|
||||
# Paths to datasets and checkpoints
|
||||
paths:
|
||||
bpe_path: /path/to/bpe/file
|
||||
dataset_root: /path/to/dataset
|
||||
experiment_log_dir: /path/to/logs
|
||||
|
||||
# Launcher settings for local/cluster execution
|
||||
launcher:
|
||||
num_nodes: 1
|
||||
gpus_per_node: 2
|
||||
experiment_log_dir: ${paths.experiment_log_dir}
|
||||
|
||||
# Cluster execution settings
|
||||
submitit:
|
||||
use_cluster: True
|
||||
timeout_hour: 72
|
||||
cpus_per_task: 10
|
||||
partition: null
|
||||
account: null
|
||||
```
|
||||
|
||||
### Monitoring Training
|
||||
|
||||
The training script automatically sets up logging and saves outputs to the experiment directory:
|
||||
|
||||
```bash
|
||||
# Logs are saved to the experiment_log_dir specified in config
|
||||
experiment_log_dir/
|
||||
├── config.yaml # Original configuration
|
||||
├── config_resolved.yaml # Resolved configuration with all variables expanded
|
||||
├── checkpoints/ # Model checkpoints (if skip_checkpointing=False)
|
||||
├── tensorboard/ # TensorBoard logs
|
||||
├── logs/ # Text logs
|
||||
└── submitit_logs/ # Cluster job logs (if using cluster)
|
||||
```
|
||||
|
||||
You can monitor training progress using TensorBoard:
|
||||
|
||||
```bash
|
||||
tensorboard --logdir /path/to/experiment_log_dir/tensorboard
|
||||
```
|
||||
|
||||
### Job Arrays for Dataset Sweeps
|
||||
|
||||
The Roboflow and ODinW configuration supports job arrays for training multiple models on different datasets:
|
||||
|
||||
This feature is specifically enabled via,
|
||||
```yaml
|
||||
submitit:
|
||||
job_array:
|
||||
num_tasks: 100
|
||||
task_index: 0
|
||||
```
|
||||
|
||||
The configuration includes a complete list of 100 Roboflow supercategories, and the `submitit.job_array.task_index` automatically selects which dataset to use based on the array job index.
|
||||
|
||||
```bash
|
||||
# Submit job array to train on different Roboflow datasets
|
||||
# The job array index selects which dataset from all_roboflow_supercategories
|
||||
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml \
|
||||
--use-cluster 1
|
||||
```
|
||||
|
||||
### Reproduce ODinW13 10-shot results
|
||||
Running the following job will give the results on the ODinW13 seed 300, see `odinw_train.train_file: fewshot_train_shot10_seed300` in the config file.
|
||||
```bash
|
||||
# Example: Train on ODinW13 dataset
|
||||
python sam3/train/train.py -c configs/odinw13/odinw_text_only_train.yaml
|
||||
```
|
||||
Change `odinw_train.train_file` to `fewshot_train_shot10_seed30` and `fewshot_train_shot10_seed3` to get the results for the other two seeds. Final results are aggregated from the three seeds. Notice that a small number of jobs may diverge during training, in which case we just use the last checkpoint's result before it diverges.
|
||||
|
||||
|
||||
### Eval Script Usage
|
||||
With a similar setup as the training config, the training script `sam3/train.py` can also be used for evaluation, too, when setting `trainer.mode = val` in the job config. Run the following job will give the results on the zero-shot results on RF100-VL and ODinW13 datasets.
|
||||
```bash
|
||||
# Example: Evaluate on Roboflow dataset
|
||||
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_eval.yaml
|
||||
# Example: Evaluate on ODinW13 dataset
|
||||
python sam3/train/train.py -c configs/odinw13/odinw_text_only.yaml
|
||||
```
|
||||
BIN
sam3/assets/dog.gif
Normal file
|
After Width: | Height: | Size: 6.8 MiB |
BIN
sam3/assets/images/groceries.jpg
Normal file
|
After Width: | Height: | Size: 164 KiB |
BIN
sam3/assets/images/truck.jpg
Normal file
|
After Width: | Height: | Size: 265 KiB |
BIN
sam3/assets/model_diagram.png
Normal file
|
After Width: | Height: | Size: 707 KiB |
BIN
sam3/assets/player.gif
Normal file
|
After Width: | Height: | Size: 4.2 MiB |
BIN
sam3/assets/sa_co_dataset.jpg
Normal file
|
After Width: | Height: | Size: 991 KiB |
BIN
sam3/assets/saco_gold_annotation.png
Normal file
|
After Width: | Height: | Size: 3.8 MiB |
BIN
sam3/assets/videos/0001/0.jpg
Normal file
|
After Width: | Height: | Size: 141 KiB |
BIN
sam3/assets/videos/0001/1.jpg
Normal file
|
After Width: | Height: | Size: 138 KiB |
BIN
sam3/assets/videos/0001/10.jpg
Normal file
|
After Width: | Height: | Size: 134 KiB |
BIN
sam3/assets/videos/0001/100.jpg
Normal file
|
After Width: | Height: | Size: 112 KiB |
BIN
sam3/assets/videos/0001/101.jpg
Normal file
|
After Width: | Height: | Size: 114 KiB |
BIN
sam3/assets/videos/0001/102.jpg
Normal file
|
After Width: | Height: | Size: 111 KiB |
BIN
sam3/assets/videos/0001/103.jpg
Normal file
|
After Width: | Height: | Size: 111 KiB |
BIN
sam3/assets/videos/0001/104.jpg
Normal file
|
After Width: | Height: | Size: 110 KiB |
BIN
sam3/assets/videos/0001/105.jpg
Normal file
|
After Width: | Height: | Size: 112 KiB |
BIN
sam3/assets/videos/0001/106.jpg
Normal file
|
After Width: | Height: | Size: 110 KiB |
BIN
sam3/assets/videos/0001/107.jpg
Normal file
|
After Width: | Height: | Size: 112 KiB |
BIN
sam3/assets/videos/0001/108.jpg
Normal file
|
After Width: | Height: | Size: 111 KiB |
BIN
sam3/assets/videos/0001/109.jpg
Normal file
|
After Width: | Height: | Size: 114 KiB |
BIN
sam3/assets/videos/0001/11.jpg
Normal file
|
After Width: | Height: | Size: 136 KiB |
BIN
sam3/assets/videos/0001/110.jpg
Normal file
|
After Width: | Height: | Size: 113 KiB |
BIN
sam3/assets/videos/0001/111.jpg
Normal file
|
After Width: | Height: | Size: 113 KiB |
BIN
sam3/assets/videos/0001/112.jpg
Normal file
|
After Width: | Height: | Size: 112 KiB |
BIN
sam3/assets/videos/0001/113.jpg
Normal file
|
After Width: | Height: | Size: 113 KiB |
BIN
sam3/assets/videos/0001/114.jpg
Normal file
|
After Width: | Height: | Size: 111 KiB |
BIN
sam3/assets/videos/0001/115.jpg
Normal file
|
After Width: | Height: | Size: 110 KiB |
BIN
sam3/assets/videos/0001/116.jpg
Normal file
|
After Width: | Height: | Size: 110 KiB |
BIN
sam3/assets/videos/0001/117.jpg
Normal file
|
After Width: | Height: | Size: 109 KiB |
BIN
sam3/assets/videos/0001/118.jpg
Normal file
|
After Width: | Height: | Size: 107 KiB |
BIN
sam3/assets/videos/0001/119.jpg
Normal file
|
After Width: | Height: | Size: 105 KiB |
BIN
sam3/assets/videos/0001/12.jpg
Normal file
|
After Width: | Height: | Size: 134 KiB |
BIN
sam3/assets/videos/0001/120.jpg
Normal file
|
After Width: | Height: | Size: 106 KiB |
BIN
sam3/assets/videos/0001/121.jpg
Normal file
|
After Width: | Height: | Size: 104 KiB |
BIN
sam3/assets/videos/0001/122.jpg
Normal file
|
After Width: | Height: | Size: 104 KiB |
BIN
sam3/assets/videos/0001/123.jpg
Normal file
|
After Width: | Height: | Size: 106 KiB |
BIN
sam3/assets/videos/0001/124.jpg
Normal file
|
After Width: | Height: | Size: 108 KiB |
BIN
sam3/assets/videos/0001/125.jpg
Normal file
|
After Width: | Height: | Size: 107 KiB |
BIN
sam3/assets/videos/0001/126.jpg
Normal file
|
After Width: | Height: | Size: 109 KiB |
BIN
sam3/assets/videos/0001/127.jpg
Normal file
|
After Width: | Height: | Size: 105 KiB |
BIN
sam3/assets/videos/0001/128.jpg
Normal file
|
After Width: | Height: | Size: 104 KiB |
BIN
sam3/assets/videos/0001/129.jpg
Normal file
|
After Width: | Height: | Size: 102 KiB |