Compare commits

...

1 Commit

Author SHA1 Message Date
b899c4e9de Upload gRPC test demo 2026-01-05 16:29:39 +08:00
15 changed files with 1093 additions and 6919 deletions

20
.gitignore vendored Normal file
View File

@ -0,0 +1,20 @@
# Ignore all .log files
*.log
# Ignore specific directories (e.g. node_modules/)
node_modules/
# Ignore local config files (but keep the example file)
config.local.json
!config.example.json
# Ignore IDE files
.idea/
*.iml
# Ignore build output directories
dist/
build/
test
*test*

3
.idea/misc.xml generated
View File

@ -1,4 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="yolo_tensorrt" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="yolo_tensorrt" project-jdk-type="Python SDK" />
</project>

View File

@ -1,100 +0,0 @@
import os
import shutil
import cv2
import collections
from ultralytics import YOLO
from miniohelp import downFile, upload_file, parse_minio_url  # make sure these helper functions are available
from minio import Minio
def process_images(yolo_model, image_list, class_filter, input_folder, output_folder, minio_info):
# Clean the endpoint: strip any http:// or https:// prefix
endpoint = minio_info["MinIOEndpoint"].replace("http://", "").replace("https://", "")
# Initialize the MinIO client from the config dict
minio = Minio(
endpoint=endpoint,
access_key=minio_info["MinIOAccessKey"],
secret_key=minio_info["MinIOSecretKey"],
secure=False
)
os.makedirs(input_folder, exist_ok=True)
os.makedirs(output_folder, exist_ok=True)
model = YOLO(yolo_model)
class_ids_filter = [int(cls) for cls in class_filter.split(",")] if class_filter else None
output_image_list = []
for item in image_list:
img_id = item["id"]
img_url = item["path"]
# Parse the MinIO URL
if img_url.startswith("http"):
bucket_name, img_path = parse_minio_url(img_url)
else:
bucket_name, img_path = "default-bucket", img_url
try:
# Download the source image locally
local_input_path = os.path.join(input_folder, os.path.basename(img_path))
downFile(minio, img_path, bucket_name, local_input_path)
# Read the image
image = cv2.imread(local_input_path)
if image is None:
raise ValueError(f"Failed to read image: {local_input_path}")
# YOLO detection
results = model.predict(image,
classes=class_ids_filter,
conf=0.5,
iou=0.111,
show_labels=False)
result = results[0]
# Count detections per class
class_counts = collections.Counter(result.boxes.cls.cpu().numpy().astype(int)) if result.boxes is not None else {}
filtered_class_counts = {k: v for k, v in class_counts.items() if class_ids_filter is None or k in class_ids_filter}
# Convert numpy.int64 values to plain Python int
detected_classes = [int(cls) for cls in filtered_class_counts.keys()]
detected_numbers = [int(num) for num in filtered_class_counts.values()]
aim = bool(detected_classes)
# Save the annotated image
annotated_image = result.plot(labels=False)
filename_no_ext, ext = os.path.splitext(os.path.basename(img_path))
output_filename = f"{filename_no_ext}_ai{ext}"
local_output_path = os.path.join(output_folder, output_filename)
cv2.imwrite(local_output_path, annotated_image)
# Upload the annotated image to MinIO
minio_path = upload_file(minio, local_output_path, bucket_name, os.path.dirname(img_path))
except Exception as e:
print(f"[错误] 处理失败 - {img_path},错误: {str(e)}")
detected_classes = []
detected_numbers = []
aim = False
output_filename = ""
minio_path = ""
output_image_list.append({
"id": img_id,
"minio_path":minio_path,
"aim": aim,
"class": detected_classes,
"number": detected_numbers
})
# Clean up the temporary directories
shutil.rmtree(input_folder, ignore_errors=True)
shutil.rmtree(output_folder, ignore_errors=True)
return {
"status": "success",
"message": "Detection completed",
"data": output_image_list
}

View File

@ -1,505 +0,0 @@
from sanic import Sanic, json, Blueprint,response
from sanic.exceptions import Unauthorized, NotFound
from sanic.response import json as json_response
from sanic_cors import CORS
from datetime import datetime
import logging
import uuid
import os
import asyncio
from minio import Minio
from ai_image import process_images # your image-processing implementation
from queue import Queue
import gdal2tiles as gdal2tiles
from map_find import map_process_images
from yolo_train import auto_train
from map_cut import process_tiling
from cv_video_counter import start_video_session,switch_model_session,stop_video_session,stream_sessions
import torch
from yolo_photo import map_process_images_with_progress # progress-aware image processing
# Logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
################################################################################### Auth middleware and resource checks ##############################################################################################
async def token_and_resource_check(request):
# --- Token validation ---
token = request.headers.get('X-API-Token')
expected_token = request.app.config.get("VALID_TOKEN")
if not token or token != expected_token:
logger.warning(f"Unauthorized request with token: {token}")
raise Unauthorized("Invalid token")
# --- GPU usage check ---
try:
if torch.cuda.is_available():
num_gpus = torch.cuda.device_count()
max_usage_ratio = request.app.config.get("MAX_GPU_USAGE", 0.9) # default: 90%
for i in range(num_gpus):
used = torch.cuda.memory_reserved(i)
total = torch.cuda.max_memory_reserved(i)
ratio = used / total if total else 0
logger.info(f"GPU {i} Usage: {ratio:.2%}")
if ratio > max_usage_ratio:
logger.warning(f"GPU {i} usage too high: {ratio:.2%}")
return json_response({
"status": "error",
"message": f"GPU resource busy (GPU {i} at {ratio:.2%}). Try later."
}, status=503)
except Exception as e:
logger.error(f"GPU check failed: {e}")
return None # allow the request to continue
##################################################################################################################################################################################################
# Create the Sanic application
app = Sanic("ai_Service_v2")
CORS(app) # allow cross-origin requests
task_progress = {}
@app.middleware("request")
async def global_middleware(request):
result = await token_and_resource_check(request)
if result:
return result
# Configure the token and the maximum GPU usage
app.config.update({
"VALID_TOKEN": "Beidou_b8609e96-bfec-4485-8c64-6d4f662ee44a",
"MAX_GPU_USAGE": 0.9
})
###################################################################### Map-tiling APIs ########################################################################################################
# Blueprint for map endpoints
map_tile_blueprint = Blueprint('map', url_prefix='/map/')
app.blueprint(map_tile_blueprint)
@map_tile_blueprint.post("/tile")
async def map_tile_api(request):
try:
# 1. Validate the request body
if not request.json:
return json_response(
{"status": "error", "message": "Request body is required"},
status=400
)
# 2. Parse the required fields
tile_data = request.json
tif_url = tile_data.get("tif_url")
prj_url = tile_data.get("prj_url")
if not tif_url or not prj_url:
return json_response(
{"status": "error", "message": "Both tif_url and prj_url are required"},
status=400
)
# 3. Run the tiling job (await the coroutine directly; do not use asyncio.run)
zoom_levels = tile_data.get("zoom_levels", "1-18")
try:
# assuming process_tiling is a coroutine function (async def)
result = await process_tiling(tif_url, prj_url, zoom_levels)
# if process_tiling is a plain function, wrap it with asyncio.to_thread:
# result = await asyncio.to_thread(process_tiling, tif_url, prj_url, zoom_levels)
return json_response({
"status": "success",
"data": result
})
except Exception as processing_error:
logger.error(f"Processing failed: {str(processing_error)}", exc_info=True)
return json_response(
{"status": "error", "message": f"Processing error: {str(processing_error)}"},
status=500
)
except Exception as e:
logger.error(f"Unexpected error: {str(e)}", exc_info=True)
return json_response(
{"status": "error", "message": str(e)}, # 直接返回字符串,不要用集合
status=500
)
# Semantic segmentation
@map_tile_blueprint.post("/uav")
async def process_handler(request):
"""
Endpoint: /map/uav
Input JSON:
{
"urls": [
"http://example.com/img1.jpg",
"http://example.com/img2.jpg"
],
"yaml_name": "112.44.103.230",
"bucket_name": "300bdf2b-a150-406e-be63-d28bd29b409f",
"bucket_directory": "2025/seg",
"model_path": "deeplabv3plus_best.pth"
}
Output JSON:
{
"code": 200,
"msg": "success",
"data": [
"http://minio.example.com/uav-results/2025/seg/result1.png",
"http://minio.example.com/uav-results/2025/seg/result2.png"
]
}
"""
try:
body = request.json
urls = body.get("urls", [])
yaml_name = body.get("yaml_name")
bucket_name = body.get("bucket_name")
bucket_directory = body.get("bucket_directory")
model_path = os.path.join("map", "checkpoints", body.get("model_path"))
# Validate parameters
if not urls or not isinstance(urls, list):
return json({"code": 400, "msg": "Missing or invalid 'urls'"})
if not all([yaml_name, bucket_name, bucket_directory]):
return json({"code": 400, "msg": "Missing required parameters"})
# Call the image-processing function
result = map_process_images(urls, yaml_name, bucket_name, bucket_directory,model_path)
return json(result)
except Exception as e:
return json({"code": 500, "msg": f"Server error: {str(e)}"})
###################################################################### YOLO APIs ########################################################################################################
# Blueprint for YOLO endpoints
yolo_tile_blueprint = Blueprint('yolo', url_prefix='/yolo/')
app.blueprint(yolo_tile_blueprint)
# YOLO URL API
# In-memory task progress and result store (use Redis or a DB for persistence)
@yolo_tile_blueprint.post("/process_images")
async def process_images_api(request):  # renamed so it does not shadow ai_image.process_images used by /yolo/picture
"""
{
"urls": [
"http://example.com/image1.jpg",
"http://example.com/image2.jpg",
"http://example.com/image3.jpg"
],
"yaml_name": "your_minio_config",
"bucket_name": "my-bucket",
"bucket_directory": "2025/uav-results",
"model_path": "deeplabv3plus_best.pth"
}
"""
data = request.json
urls = data.get("urls")
yaml_name = data.get("yaml_name")
bucket_name = data.get("bucket_name")
bucket_directory = data.get("bucket_directory")
uav_model_path = data.get("uav_model_path")
if not urls or not yaml_name or not bucket_name or not uav_model_path:
return response.json({"code": 400, "msg": "Missing parameters"}, status=400)
task_id = str(uuid.uuid4())
task_progress[task_id] = {"status": "pending", "progress": 0, "result": None}
# Start the background task
asyncio.create_task(run_image_processing(task_id, urls, yaml_name, bucket_name, bucket_directory, uav_model_path))
return response.json({"code": 200, "msg": "Task started", "task_id": task_id})
@yolo_tile_blueprint.get("/task_status/<task_id>")
async def task_status(request, task_id):
progress = task_progress.get(task_id)
if not progress:
return response.json({"code": 404, "msg": "Task not found"}, status=404)
return response.json({"code": 200, "msg": "Task status", "data": progress})
async def run_image_processing(task_id, urls, yaml_name, bucket_name, bucket_directory, uav_model_path):
try:
task_progress[task_id]["status"] = "running"
task_progress[task_id]["progress"] = 10 # 开始进度
# 下载、推理、上传阶段分别更新进度
def progress_callback(stage, percent):
task_progress[task_id]["status"] = stage
task_progress[task_id]["progress"] = percent
result = await asyncio.to_thread(
map_process_images_with_progress,
urls, yaml_name, bucket_name, bucket_directory, uav_model_path, progress_callback
)
task_progress[task_id]["status"] = "completed"
task_progress[task_id]["progress"] = 100
task_progress[task_id]["result"] = result
except Exception as e:
task_progress[task_id]["status"] = "failed"
task_progress[task_id]["progress"] = 100
task_progress[task_id]["result"] = str(e)
# YOLO detection API
@yolo_tile_blueprint.post("/picture")
async def yolo_detect_api(request):
try:
detect_data = request.json
# Parse the required fields
image_list = detect_data.get("image_list")
yolo_model = detect_data.get("yolo_model", "best.pt")
class_filter = detect_data.get("class", None)
minio_info = detect_data.get("minio", None)
if not image_list:
return json_response({"status": "error", "message": "image_list is required"}, status=400)
if not minio_info:
return json_response({"status": "error", "message": "MinIO information is required"}, status=400)
# Create temporary folders
input_folder = f"./temp_input_{str(uuid.uuid4())}"
output_folder = f"./temp_output_{str(uuid.uuid4())}"
# Run image processing (ai_image.process_images) in a worker thread
result = await asyncio.to_thread(
process_images,
yolo_model=yolo_model,
image_list=image_list,
class_filter=class_filter,
input_folder=input_folder,
output_folder=output_folder,
minio_info=minio_info
)
# Return the processing result
return json_response(result)
except Exception as e:
logger.error(f"Error occurred while processing request: {str(e)}", exc_info=True)
return json_response({
"status": "error",
"message": f"Internal server error: {str(e)}"
}, status=500)
# YOLO automatic training
@yolo_tile_blueprint.post("/train")
async def yolo_train_api(request):
"""
Automatically train a model
Input JSON:
{
"db_host": str,
"db_database": str,
"db_user": str,
"db_password": str,
"db_port": int,
"model_id": int,
"img_path": str,
"label_path": str,
"new_path": str,
"split_list": List[float],
"class_names": Optional[List[str]],
"project_name": str
}
Output JSON:
{
"base_metrics": Dict[str, float],
"best_model_path": str,
"final_metrics": Dict[str, float]
}
"""
try:
# access request.json directly instead of calling it
data = request.json
if not data:
return json_response({"status": "error", "message": "data is required"}, status=400)
# Run automatic training in a worker thread
result = await asyncio.to_thread(
auto_train,
data
)
# Return the training result
return json_response(result)
except Exception as e:
logger.error(f"Error occurred while processing request: {str(e)}", exc_info=True)
return json_response({
"status": "error",
"message": f"Internal server error: {str(e)}"
}, status=500)
########################################################################################### Video-stream APIs #######################################################################################################
# Blueprint for video-stream endpoints
stream_tile_blueprint = Blueprint('stream', url_prefix='/stream_test/')
app.blueprint(stream_tile_blueprint)
# Task manager
class StreamTaskManager:
def __init__(self):
self.active_tasks = {}
self.task_status = {}
self.task_timestamps = {}
self.task_queue = Queue(maxsize=10)
def add_task(self, task_id: str, task_info: dict) -> None:
if self.task_queue.full():
oldest_task_id = self.task_queue.get()
# Stop the session before dropping its bookkeeping, otherwise the session_id lookup fails
stop_video_session(self.active_tasks[oldest_task_id]["session_id"])
self.remove_task(oldest_task_id)
self.active_tasks[task_id] = task_info
self.task_status[task_id] = "running"
self.task_timestamps[task_id] = datetime.now()
self.task_queue.put(task_id)
logger.info(f"Task {task_id} started")
def remove_task(self, task_id: str) -> None:
if task_id in self.active_tasks:
del self.active_tasks[task_id]
del self.task_status[task_id]
del self.task_timestamps[task_id]
logger.info(f"Task {task_id} removed")
def get_task_info(self, task_id: str) -> dict:
if task_id not in self.active_tasks:
raise NotFound("Task not found")
return {
"task_info": self.active_tasks[task_id],
"status": self.task_status[task_id],
"start_time": self.task_timestamps[task_id].isoformat()
}
task_manager = StreamTaskManager()
# ---------- API Endpoints ----------
@stream_tile_blueprint.post("/start")
async def api_start(request):
"""
Start a video-stream session
Input JSON:
{
"video_path": str,
"output_url": str,
"model_path": str,
"cls": List[int],
"confidence": float,
"cls2": Optional[List[int]],
"push": bool
}
Output JSON:
{
"session_id": str,
"task_id": str,
"message": "started"
}
"""
data = request.json
task_id = str(uuid.uuid4())
# Start the video-processing session for this task_id
session_id = start_video_session(
video_path = data.get("video_path"),
output_url = data.get("output_url"),
model_path = data.get("model_path"),
cls = data.get("cls"),
confidence = data.get("confidence", 0.5),
cls2 = data.get("cls2", []),
push = data.get("push", False),
)
# Register with the task manager
task_manager.add_task(task_id, {
"session_id": session_id,
"video_path": data.get("video_path"),
"output_url": data.get("output_url"),
"model_path": data.get("model_path"),
"class_filter": data.get("cls", []),
"push": data.get("push", False),
"start_time": datetime.now().isoformat()
})
return json({"session_id": session_id, "task_id": task_id, "message": "started"})
@stream_tile_blueprint.post("/stop")
async def api_stop(request):
"""
Stop the specified session
Input JSON: { "session_id": str }
Output JSON: { "session_id": str, "message": "stopped" }
"""
session_id = request.json.get("session_id")
stop_video_session(session_id)
# Remove the matching task as well
for tid, info in list(task_manager.active_tasks.items()):
if info.get("session_id") == session_id:
task_manager.remove_task(tid)
break
return json({"session_id": session_id, "message": "stopped"})
@stream_tile_blueprint.post("/switch_model")
async def api_switch_model(request):
"""
Switch the model of a session
Input JSON: { "session_id": str, "new_model_path": str }
Output JSON: { "session_id": str, "new_model_path": str, "message": "model switched" }
"""
data = request.json
session_id = data.get("session_id")
new_model = data.get("new_model_path")
switch_model_session(session_id, new_model)
return json({"session_id": session_id, "new_model_path": new_model, "message": "model switched"})
@stream_tile_blueprint.get("/sessions")
async def api_list_sessions(request):
"""
List all current sessions
Output JSON: { "sessions": [{"session_id": str, "status": "running"}, ...] }
"""
sessions = [
{"session_id": sid, "status": "running"}
for sid in stream_sessions.keys()
]
return json({"sessions": sessions})
# Unified task-query endpoint (including video streams)
@stream_tile_blueprint.get("/tasks")
async def api_list_tasks(request):
"""
List all tasks with their status, start time and details
"""
tasks = []
for tid in task_manager.active_tasks:
info = task_manager.get_task_info(tid)
tasks.append({"task_id": tid, **info})
return json({"tasks": tasks})
##################################################################################################################################################################################################
if __name__ == '__main__':
app.run(host="0.0.0.0", port=12366, debug=True, workers=1)

47
check_grpc_server.py Normal file
View File

@ -0,0 +1,47 @@
from concurrent import futures
import grpc
import time
from grpc_proto.check_grpc import check_grpc_pb2_grpc, check_grpc_pb2
class TaskServiceServicer(check_grpc_pb2_grpc.TaskServiceServicer):
def ProcessTask(self, request, context):
print(f"Received task_id: {request.task_id}")
return check_grpc_pb2.TaskResponse(
task_id=request.task_id,
success=True,
message="Task processed successfully"
)
class HealthCheckServicer(check_grpc_pb2_grpc.HealthCheckServicer):
def Check(self, request, context):
# Minimal implementation: always return SERVING
# A real implementation can return different values based on the service state
return check_grpc_pb2.HealthCheckResponse(
status=check_grpc_pb2.HealthCheckResponse.ServingStatus.SERVING
)
def serve():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
# Register the service implementations
check_grpc_pb2_grpc.add_TaskServiceServicer_to_server(TaskServiceServicer(), server)
check_grpc_pb2_grpc.add_HealthCheckServicer_to_server(HealthCheckServicer(), server)
server.add_insecure_port('[::]:50051')
server.start()
print("Server started, listening on port 50051...")
try:
while True:
time.sleep(86400) # keep the server running
except KeyboardInterrupt:
server.stop(0)
if __name__ == '__main__':
serve()

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -1,8 +0,0 @@
import os.path
from middleware.minio_util import downBigFile
miniourl=r"media/22d45cc5-0ba7-4bc3-a302-ca1a28c40fd2/DJI_202509121519_001_22d45cc5-0ba7-4bc3-a302-ca1a28c40fd2/DJI_20250912152112_0001_V.mp4"
file_path=downBigFile(miniourl)
print(f"os.path.abspath(file_path) {os.path.abspath(file_path)}")

View File

@ -0,0 +1,58 @@
syntax = "proto3";
package task;
service TaskService {
rpc ProcessTask (TaskRequest) returns (TaskResponse);
}
// Health check service
service HealthCheck {
rpc Check (HealthCheckRequest) returns (HealthCheckResponse);
}
message HealthCheckRequest {
string service = 1;
}
message HealthCheckResponse {
enum ServingStatus {
UNKNOWN = 0;
SERVING = 1;
NOT_SERVING = 2;
SERVICE_UNKNOWN = 3;
}
ServingStatus status = 1;
}
message TaskRequest {
string task_id = 1;
string sn = 2;
ContentBody content_body = 3;
}
message ContentBody {
string org_code = 1;
repeated int32 func_id = 2;
string source_url = 3;
string push_url = 4;
float confidence = 5;
repeated ParaList para_list = 6;
Invade invade = 7;
}
message ParaList {
int32 func_id = 1;
bool para_invade_enable = 2;
}
message Invade {
string invade_file = 1;
string camera_para_url = 2;
}
message TaskResponse {
string task_id = 1;
bool success = 2;
string message = 3;
}

View File

@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: grpc_proto/check_grpc/check_grpc.proto
# Protobuf Python Version: 6.31.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
6,
31,
1,
'',
'grpc_proto/check_grpc/check_grpc.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n&grpc_proto/check_grpc/check_grpc.proto\x12\x04task\"%\n\x12HealthCheckRequest\x12\x0f\n\x07service\x18\x01 \x01(\t\"\x9f\x01\n\x13HealthCheckResponse\x12\x37\n\x06status\x18\x01 \x01(\x0e\x32\'.task.HealthCheckResponse.ServingStatus\"O\n\rServingStatus\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x0b\n\x07SERVING\x10\x01\x12\x0f\n\x0bNOT_SERVING\x10\x02\x12\x13\n\x0fSERVICE_UNKNOWN\x10\x03\"S\n\x0bTaskRequest\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\n\n\x02sn\x18\x02 \x01(\t\x12\'\n\x0c\x63ontent_body\x18\x03 \x01(\x0b\x32\x11.task.ContentBody\"\xab\x01\n\x0b\x43ontentBody\x12\x10\n\x08org_code\x18\x01 \x01(\t\x12\x0f\n\x07\x66unc_id\x18\x02 \x03(\x05\x12\x12\n\nsource_url\x18\x03 \x01(\t\x12\x10\n\x08push_url\x18\x04 \x01(\t\x12\x12\n\nconfidence\x18\x05 \x01(\x02\x12!\n\tpara_list\x18\x06 \x03(\x0b\x32\x0e.task.ParaList\x12\x1c\n\x06invade\x18\x07 \x01(\x0b\x32\x0c.task.Invade\"7\n\x08ParaList\x12\x0f\n\x07\x66unc_id\x18\x01 \x01(\x05\x12\x1a\n\x12para_invade_enable\x18\x02 \x01(\x08\"6\n\x06Invade\x12\x13\n\x0binvade_file\x18\x01 \x01(\t\x12\x17\n\x0f\x63\x61mera_para_url\x18\x02 \x01(\t\"A\n\x0cTaskResponse\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x0f\n\x07message\x18\x03 \x01(\t2C\n\x0bTaskService\x12\x34\n\x0bProcessTask\x12\x11.task.TaskRequest\x1a\x12.task.TaskResponse2K\n\x0bHealthCheck\x12<\n\x05\x43heck\x12\x18.task.HealthCheckRequest\x1a\x19.task.HealthCheckResponseb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'grpc_proto.check_grpc.check_grpc_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_HEALTHCHECKREQUEST']._serialized_start=48
_globals['_HEALTHCHECKREQUEST']._serialized_end=85
_globals['_HEALTHCHECKRESPONSE']._serialized_start=88
_globals['_HEALTHCHECKRESPONSE']._serialized_end=247
_globals['_HEALTHCHECKRESPONSE_SERVINGSTATUS']._serialized_start=168
_globals['_HEALTHCHECKRESPONSE_SERVINGSTATUS']._serialized_end=247
_globals['_TASKREQUEST']._serialized_start=249
_globals['_TASKREQUEST']._serialized_end=332
_globals['_CONTENTBODY']._serialized_start=335
_globals['_CONTENTBODY']._serialized_end=506
_globals['_PARALIST']._serialized_start=508
_globals['_PARALIST']._serialized_end=563
_globals['_INVADE']._serialized_start=565
_globals['_INVADE']._serialized_end=619
_globals['_TASKRESPONSE']._serialized_start=621
_globals['_TASKRESPONSE']._serialized_end=686
_globals['_TASKSERVICE']._serialized_start=688
_globals['_TASKSERVICE']._serialized_end=755
_globals['_HEALTHCHECK']._serialized_start=757
_globals['_HEALTHCHECK']._serialized_end=832
# @@protoc_insertion_point(module_scope)

View File

@ -0,0 +1,172 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings
from grpc_proto.check_grpc import check_grpc_pb2 as grpc__proto_dot_check__grpc_dot_check__grpc__pb2
GRPC_GENERATED_VERSION = '1.76.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False
try:
from grpc._utilities import first_version_is_lower
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
_version_not_supported = True
if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ ' but the generated code in grpc_proto/check_grpc/check_grpc_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
)
class TaskServiceStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.ProcessTask = channel.unary_unary(
'/task.TaskService/ProcessTask',
request_serializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.TaskRequest.SerializeToString,
response_deserializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.TaskResponse.FromString,
_registered_method=True)
class TaskServiceServicer(object):
"""Missing associated documentation comment in .proto file."""
def ProcessTask(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_TaskServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'ProcessTask': grpc.unary_unary_rpc_method_handler(
servicer.ProcessTask,
request_deserializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.TaskRequest.FromString,
response_serializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.TaskResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'task.TaskService', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('task.TaskService', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class TaskService(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def ProcessTask(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/task.TaskService/ProcessTask',
grpc__proto_dot_check__grpc_dot_check__grpc__pb2.TaskRequest.SerializeToString,
grpc__proto_dot_check__grpc_dot_check__grpc__pb2.TaskResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
class HealthCheckStub(object):
"""添加健康检查服务
"""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.Check = channel.unary_unary(
'/task.HealthCheck/Check',
request_serializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.HealthCheckRequest.SerializeToString,
response_deserializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.HealthCheckResponse.FromString,
_registered_method=True)
class HealthCheckServicer(object):
"""添加健康检查服务
"""
def Check(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_HealthCheckServicer_to_server(servicer, server):
rpc_method_handlers = {
'Check': grpc.unary_unary_rpc_method_handler(
servicer.Check,
request_deserializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.HealthCheckRequest.FromString,
response_serializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.HealthCheckResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'task.HealthCheck', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('task.HealthCheck', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class HealthCheck(object):
"""添加健康检查服务
"""
@staticmethod
def Check(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/task.HealthCheck/Check',
grpc__proto_dot_check__grpc_dot_check__grpc__pb2.HealthCheckRequest.SerializeToString,
grpc__proto_dot_check__grpc_dot_check__grpc__pb2.HealthCheckResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

View File

@ -0,0 +1,81 @@
import grpc
import time
from grpc_proto.check_grpc import check_grpc_pb2_grpc, check_grpc_pb2
def check_server_status(channel):
try:
health_stub = check_grpc_pb2_grpc.HealthCheckStub(channel)
response = health_stub.Check(check_grpc_pb2.HealthCheckRequest(service="TaskService"))
return response.status == check_grpc_pb2.HealthCheckResponse.ServingStatus.SERVING
except grpc.RpcError as e:
print(f"Health check failed: {e}")
return False
def check_grpc_request(max_retries=3, delay=5):
channel = None
retries = 0
while retries < max_retries:
try:
# Create the channel
channel = grpc.insecure_channel('localhost:50051')
# Check server health before sending the task
if not check_server_status(channel):
raise Exception("Server is not healthy")
stub = check_grpc_pb2_grpc.TaskServiceStub(channel)
# Build the request message
request = check_grpc_pb2.TaskRequest(
task_id="d6118954-a170-4e1c-84bd-ddbd3114b354",
sn="8UUXN6S00A0CK7",
content_body=check_grpc_pb2.ContentBody(
org_code="HMZHB",
func_id=[101204],
source_url="xxxxxxxxxx",
push_url="",
confidence=0.4,
para_list=[
check_grpc_pb2.ParaList(
func_id=101204,
para_invade_enable=True
)
],
invade=check_grpc_pb2.Invade(
invade_file="meta_data/高压线-0826.geojson",
camera_para_url="meta_data/camera_para/hami_camera_para .txt"
)
)
)
# Invoke the remote method
response = stub.ProcessTask(request)
print(f"Response: task_id={response.task_id}, success={response.success}, message={response.message}")
return True
except grpc.RpcError as e:
retries += 1
print(f"RPC error occurred (attempt {retries}/{max_retries}): {e}")
if retries < max_retries:
print(f"Retrying in {delay} seconds...")
time.sleep(delay)
except Exception as e:
print(f"Error occurred: {e}")
retries += 1
if retries < max_retries:
print(f"Retrying in {delay} seconds...")
time.sleep(delay)
finally:
if channel:
channel.close()
print("All retry attempts failed")
return False
if __name__ == '__main__':
check_grpc_request()

283
md/grpc.md Normal file
View File

@ -0,0 +1,283 @@
# Preface
Sanic and the backend services are decoupled via gRPC: one service, one gRPC interface.
The gRPC interface below can be used as a reference; each service should ideally expose a health check.
# gRPC demo
To make the gRPC communication more reliable, we add the following features:
1. A server-side health-check service
2. A client-side server-status check before connecting
3. Error handling and a retry mechanism during communication
## 1. Update the protobuf definition (task.proto)
First, add the health-check service definition:
```protobuf
syntax = "proto3";
package task;
service TaskService {
rpc ProcessTask (TaskRequest) returns (TaskResponse);
}
// Health check service
service HealthCheck {
rpc Check (HealthCheckRequest) returns (HealthCheckResponse);
}
message HealthCheckRequest {
string service = 1;
}
message HealthCheckResponse {
enum ServingStatus {
UNKNOWN = 0;
SERVING = 1;
NOT_SERVING = 2;
SERVICE_UNKNOWN = 3;
}
ServingStatus status = 1;
}
message TaskRequest {
string task_id = 1;
string sn = 2;
ContentBody content_body = 3;
}
message ContentBody {
string org_code = 1;
repeated int32 func_id = 2;
string source_url = 3;
string push_url = 4;
float confidence = 5;
repeated ParaList para_list = 6;
Invade invade = 7;
}
message ParaList {
int32 func_id = 1;
bool para_invade_enable = 2;
}
message Invade {
string invade_file = 1;
string camera_para_url = 2;
}
message TaskResponse {
string task_id = 1;
bool success = 2;
string message = 3;
}
```
## 2. Regenerate the Python code
```bash
python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. task.proto
```
## 3. Enhanced server implementation (server.py)
```python
from concurrent import futures
import grpc
import time
import task_pb2
import task_pb2_grpc
class TaskServiceServicer(task_pb2_grpc.TaskServiceServicer):
def ProcessTask(self, request, context):
print(f"Received task_id: {request.task_id}")
# ... existing processing logic
return task_pb2.TaskResponse(
task_id=request.task_id,
success=True,
message="Task processed successfully"
)
class HealthCheckServicer(task_pb2_grpc.HealthCheckServicer):
def Check(self, request, context):
# Minimal implementation: always return SERVING
# A real implementation can return different values based on the service state
return task_pb2.HealthCheckResponse(
status=task_pb2.HealthCheckResponse.ServingStatus.SERVING
)
def serve():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
# Register the service implementations
task_pb2_grpc.add_TaskServiceServicer_to_server(TaskServiceServicer(), server)
task_pb2_grpc.add_HealthCheckServicer_to_server(HealthCheckServicer(), server)
server.add_insecure_port('[::]:50051')
server.start()
print("Server started, listening on port 50051...")
try:
while True:
time.sleep(86400) # keep the server running
except KeyboardInterrupt:
server.stop(0)
if __name__ == '__main__':
serve()
```
## 4. Enhanced client implementation (client.py)
```python
import grpc
import time
import task_pb2
import task_pb2_grpc
def check_server_status(channel):
try:
health_stub = task_pb2_grpc.HealthCheckStub(channel)
response = health_stub.Check(task_pb2.HealthCheckRequest(service="TaskService"))
return response.status == task_pb2.HealthCheckResponse.ServingStatus.SERVING
except grpc.RpcError as e:
print(f"Health check failed: {e}")
return False
def run_with_retry(max_retries=3, delay=5):
channel = None
retries = 0
while retries < max_retries:
try:
# Create the channel
channel = grpc.insecure_channel('localhost:50051')
# Check server health before sending the task
if not check_server_status(channel):
raise Exception("Server is not healthy")
stub = task_pb2_grpc.TaskServiceStub(channel)
# Build the request message
request = task_pb2.TaskRequest(
task_id="d6118954-a170-4e1c-84bd-ddbd3114b354",
sn="8UUXN6S00A0CK7",
content_body=task_pb2.ContentBody(
org_code="HMZHB",
func_id=[101204],
source_url="rtmp://222.212.85.86:1935/live/1581F8HGX253S00A05L8",
push_url="",
confidence=0.4,
para_list=[
task_pb2.ParaList(
func_id=101204,
para_invade_enable=True
)
],
invade=task_pb2.Invade(
invade_file="meta_data/高压线-0826.geojson",
camera_para_url="meta_data/camera_para/hami_camera_para .txt"
)
)
)
# Invoke the remote method
response = stub.ProcessTask(request)
print(f"Response: task_id={response.task_id}, success={response.success}, message={response.message}")
return True
except grpc.RpcError as e:
retries += 1
print(f"RPC error occurred (attempt {retries}/{max_retries}): {e}")
if retries < max_retries:
print(f"Retrying in {delay} seconds...")
time.sleep(delay)
except Exception as e:
print(f"Error occurred: {e}")
retries += 1
if retries < max_retries:
print(f"Retrying in {delay} seconds...")
time.sleep(delay)
finally:
if channel:
channel.close()
print("All retry attempts failed")
return False
if __name__ == '__main__':
run_with_retry()
```
## 5. What the enhancements add
1. Health-check service
- Adds a standard gRPC health-check service
- The client can verify the service status before performing the main operation
2. Error handling and retries
- The client now catches `grpc.RpcError` and other exceptions
- A maximum retry count and a retry interval are enforced
- The server status is checked before every retry
3. Resource management
- The channel is closed in all cases (via try/finally)
- Alternatively, the channel lifecycle can be handled with a context manager; see the sketch after this list
4. Status feedback
- More detailed error messages
- Retry attempts are logged
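A minimal sketch of the context-manager variant mentioned in item 3 (not part of the demo files above): `grpc.insecure_channel` returns a channel that supports the `with` statement, so the channel is closed even if the RPC raises.
```python
import grpc
import task_pb2
import task_pb2_grpc

def run_once(address: str = "localhost:50051") -> bool:
    # The channel is closed automatically when the with-block exits.
    with grpc.insecure_channel(address) as channel:
        stub = task_pb2_grpc.TaskServiceStub(channel)
        response = stub.ProcessTask(
            task_pb2.TaskRequest(task_id="demo-task"),
            timeout=10,  # fail fast instead of hanging on an unreachable server
        )
        return response.success
```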
## 6. How to run
1. Install dependencies:
```bash
pip install grpcio grpcio-tools
```
2. Generate the protobuf code:
```bash
python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. task.proto
```
3. Start the server:
```bash
python server.py
```
4. Run the client:
```bash
python client.py
```
## 7. Test scenarios
1. Server not running
- The client detects the connection failure and retries
- Eventually reports that all retry attempts failed
2. Server running but health check failing
- Modify `HealthCheckServicer` to return `NOT_SERVING` to test this (see the sketch after this list)
- The client refuses to perform the main operation
3. Network interruption
- The client catches the exception and retries
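For scenario 2, a minimal sketch of making the health status configurable (the `serving` flag is only an illustration, not part of the demo code above):
```python
class HealthCheckServicer(task_pb2_grpc.HealthCheckServicer):
    def __init__(self, serving: bool = True):
        # Flip this flag to simulate an unhealthy service during tests.
        self.serving = serving

    def Check(self, request, context):
        status_enum = task_pb2.HealthCheckResponse.ServingStatus
        status = status_enum.SERVING if self.serving else status_enum.NOT_SERVING
        return task_pb2.HealthCheckResponse(status=status)

# Register it as unhealthy for the test:
# task_pb2_grpc.add_HealthCheckServicer_to_server(HealthCheckServicer(serving=False), server)
```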
This enhanced version provides a more robust gRPC communication mechanism and is suitable for production use.

350
md/接口.md Normal file
View File

@ -0,0 +1,350 @@
# Rules for decoupling the algorithms from the backend
# 1. Method
The id field of the postgres ai_model_list table is declared as a 6-digit number:
1. The first digit encodes the algorithm category: 1xxxxx = object detection, 2xxxxx = semantic segmentation, 3xxxxx = change detection.
2. The last two digits encode the secondary computation: 100001 = object detection only, 100002 = object detection plus person counting. A decoding sketch follows this list.
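A minimal sketch of decoding such an id under the rules above (the name mappings below only cover the examples given here; the authoritative mapping lives in the ai_model_list table):
```python
def decode_model_func_id(func_id: int) -> dict:
    """Split a 6-digit ai_model_list id into its category and secondary-computation parts."""
    s = f"{func_id:06d}"
    categories = {"1": "object detection", "2": "semantic segmentation", "3": "change detection"}
    secondary = {"01": "object detection only", "02": "object detection + person counting"}
    return {
        "category": categories.get(s[0], "unknown"),
        "secondary": secondary.get(s[-2:], "unknown"),
    }

# decode_model_func_id(100002)
# -> {'category': 'object detection', 'secondary': 'object detection + person counting'}
```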
# Endpoint: video-stream recognition
Receives a video stream, model and recognition type from the frontend; the algorithm runs the computation, stores the results in MinIO and publishes result messages over MQTT.
## 1. Request
Endpoint: /ai/stream/back_detect
Method: POST
Headers: X-API-Token: 5e8899fe-dc74-4280-8169-2f4d185f3afa
Body:
```
{
"task_id": "1234567890", #任务id
"sn":"", #无人机sn
"content_body": {
"source_url": "rtmp://192.168.0.142:1935/live/123456", #无人机视频流url
"confidence":0.4, #置信度
"model_func_id":[100001,100002] #方法id
}
}
```
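A minimal sketch of issuing this request with the Python `requests` library (the host and port are placeholders; the token and body are the example values above):
```python
import requests

BASE_URL = "http://127.0.0.1:8080"  # placeholder; use the actual AI-service address

resp = requests.post(
    f"{BASE_URL}/ai/stream/back_detect",
    headers={"X-API-Token": "5e8899fe-dc74-4280-8169-2f4d185f3afa"},
    json={
        "task_id": "1234567890",
        "sn": "",
        "content_body": {
            "source_url": "rtmp://192.168.0.142:1935/live/123456",
            "confidence": 0.4,
            "model_func_id": [100001, 100002],
        },
    },
    timeout=30,
)
print(resp.status_code, resp.json())
```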
## 2. Response
The algorithm responds in two parts:
1. A REST response acknowledging that the request was received
2. MQTT messages that continuously publish the computation results
### 1. REST
```
{
"status": "success",
"task_id": "1234567890",
"message": "Detection started successfully"
}
```
### 2. MQTT
Broker IP 112.44.103.230, port 1883 (a minimal subscriber sketch follows the payload example below)
Topic: thing/product/ai/events
```
{
"task_id": "1234567890",
"minio": {
"minio_path": "ai_result/20250702/1751425303860-output-1751425303800959985.jpg",
"file_type": "pic"
},
"box_detail": {
"result_100001": {
"func_id_10001": 100001,
"type_name": "行人",
"cls_count": 1,
"box_count": [
[
{
"track_id": 22099,
"confidence": 0.34013107419013977,
"cls_id": 0,
"type_name": "行人",
"box": [
15.935794830322266,
694.75390625,
33.22901916503906,
713.1658935546875
]
}
]
]
}
},
"uav_location": {
"data": {
"attitude_head": 60,
"gimbal_pitch": 60,
"gimbal_roll": 60,
"gimbal_yaw": 60,
"height": 10,
"latitude": 10,
"longitude": 10,
"speed_x": 10,
"speed_y": 10,
"speed_z": 10
},
"timestamp": 1751425301213249700
}
}
```
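A minimal subscriber sketch for receiving these result messages (assuming paho-mqtt 1.x; broker address and topic as given above):
```python
import json
import paho.mqtt.client as mqtt

def on_message(client, userdata, msg):
    # Each message is one JSON result payload like the example above.
    payload = json.loads(msg.payload)
    print(payload["task_id"], payload.get("minio"))

client = mqtt.Client()
client.on_message = on_message
client.connect("112.44.103.230", 1883)
client.subscribe("thing/product/ai/events")
client.loop_forever()
```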
# Endpoint: image recognition
Receives images from the frontend; the algorithm runs the computation, stores the results in MinIO and publishes result messages over MQTT.
## 1. Request
Endpoint: /ai/pic/back_detect_pic
Method: POST
Headers: X-API-Token: 5e8899fe-dc74-4280-8169-2f4d185f3afa
Body:
```
{
"task_id": "0001111",
"content_body": {
"s3_id":1, #根据id适配minio相关存储参数
"s3_url":[
"test/frame_0000.jpg","test/frame_0001.jpg","test/frame_0002.jpg" # minio文件地址
],
"confidence":0.4, #算法置信度
"model_func_id":[10001,10002] #方法id
}
}
```
## 2. Response
The algorithm responds in two parts:
1. A REST response acknowledging that the request was received
2. MQTT messages that continuously publish the computation results
### 1. REST
```
{
"status": "success",
"task_id": "0001111",
"message": "Detection started successfully"
}
```
### 2. MQTT
Broker IP 112.44.103.230, port 1883
Topic: thing/product/ai/events
```
{
"task_id": "0001111", #任务id
"minio": {
"minio_path": "ai_result/20250627/1751006943659-frame_0001.jpg", # minio 存储路径
"file_type": "pic"
},
"box_detail": {
"model_id": 10001,
"box_count": [
{
"type": 3, # 类型
"type_name": "车辆", #类型名称
"count": 71 #数量
},
{
"type": 0,
"type_name": "车辆",
"count": 7
}
]
}
}
```
# Endpoint: land-type segmentation
Receives images from the frontend; the algorithm runs the computation, stores the results in MinIO and publishes result messages over MQTT.
## 1. Request
Endpoint: /ai/pic/back_detect_pic
Method: POST
Headers: X-API-Token: 5e8899fe-dc74-4280-8169-2f4d185f3afa
Body:
```
{
"task_id": "7a5c83e0-fe0d-47bf-a8e1-9bd663508783",
"content_body": {
"s3_id":1,#根据id适配minio相关存储参数
"s3_url":[
"test/patch_0011.png", # minio文件地址
"test/patch_0012.png"
],
"model_func_id":[20000,20001] #方法id
}
}
```
## 2. Response
The algorithm responds in two parts:
1. A REST response acknowledging that the request was received
2. MQTT messages that continuously publish the computation results
### 1. REST
```
{
"status": "success",
"task_id": "0001111",
"message": "Detection started successfully"
}
```
### 2. MQTT
Broker IP 112.44.103.230, port 1883
Topic: thing/product/ai/events
```
{
"task_id": "7a5c83e0-fe0d-47bf-a8e1-9bd663508783",
"minio": [
{
"minio_path_before": "ai_result/20250710/1752128232469-patch_0011.png", # 需要分割的图片
"minio_path_after": "ai_result/20250710/1752128234222-patch_0011.png", #分割之后的图片
"minio_path_boundary": "ai_result/20250710/1752128234264-patch_0011.pngfinal_vis.png", # 分割的边界图片
"minio_path_json": "ai_result/20250710/1752128234326-patch_0011.pnginstance_results.json", #分割生成的json文件
"file_type": "pic"
},
{
"minio_path_before": "ai_result/20250710/1752128240382-patch_0012.png",
"minio_path_after": "ai_result/20250710/1752128241553-patch_0012.png",
"minio_path_boundary": "ai_result/20250710/1752128241587-patch_0012.pngfinal_vis.png",
"minio_path_json": "ai_result/20250710/1752128241631-patch_0012.pnginstance_results.json",
"file_type": "pic"
}
]
}
```
# Endpoint: land-type change detection
Receives images from the frontend, runs change detection between the phase-one and phase-two imagery, stores the results in MinIO and publishes result messages over MQTT.
## 1. Request
Endpoint: /ai/pic/back_detect_pic
Method: POST
Headers: X-API-Token: 5e8899fe-dc74-4280-8169-2f4d185f3afa
Body:
```
{
"task_id": "9fa19ec3-d982-4897-af6c-2c78f786c760",
"content_body": {
"s3_id":1,
"s3_url":{
"early":"/test/1-00205.png", # 一期图像minio文件地址
"later":"/test/2-00205.png" # 二期图像minio文件地址
},
"model_func_id":[30000,30001]
}
}
```
## 2. Response
The algorithm responds in two parts:
1. A REST response acknowledging that the request was received
2. MQTT messages that continuously publish the computation results
### 1. REST
```
{
"status": "success",
"task_id": "0001111",
"message": "Detection started successfully"
}
```
### 2. MQTT
Broker IP 112.44.103.230, port 1883
Topic: thing/product/ai/events
```
{
"task_id": "9fa19ec3-d982-4897-af6c-2c78f786c760",
"minio": {
"minio_path_1": "ai_result/20250627/1751007686483-1-00205.png", # 一期影像minio地址
"minio_path_2": "ai_result/20250627/1751007686541-2-00205.png", # 二期影像minio地址
"minio_path_result": "ai_result/20250627/1751007686.458642-result-2-00205.png", #识别结果minio地址
"file_type": "pic"
}
}
```

View File

@ -18,6 +18,7 @@ from websockets.exceptions import ConnectionClosed
from CropLand_CD_module.detection import corpland_detection_func
from cropland_module.detection import detection_func
from grpc_proto.check_grpc_client import check_grpc_request
from middleware.AsyncioMqttClient import AsyncMQTTClient, ConnectionContext, active_connections
from middleware.TaskManager import TaskManager, task_manager
from middleware.minio_util import downFile
@ -1588,6 +1589,30 @@ async def stop_task_heart(request):
"message": str(e)}, status=500) "message": str(e)}, status=500)
@app.post("/ai/func/check_grpc")
async def check_grpc(request):
try:
verify_token(request)
check_grpc_request()
return json_response({
"status": "success",
"task_id": "task_id",
"message": "Detection started successfully"
})
except ValueError as e:
logger.error(f"Validation error: {str(e)}")
return json_response({"status": "error", "message": str(e)}, status=400)
except Exception as e:
logger.error(f"Unexpected error: {str(e)}", exc_info=True)
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
@app.websocket("/ws") @app.websocket("/ws")
async def websocket_endpoint(request: Request, ws): async def websocket_endpoint(request: Request, ws):
"""WebSocket端点处理前端连接和消息为每个连接创建独立的MQTT客户端""" """WebSocket端点处理前端连接和消息为每个连接创建独立的MQTT客户端"""