视频流-识别-多任务并发 拆分成小模块
This commit is contained in:
parent
795a028d0e
commit
3a7bdbc3a4
645
multi_back_detect/multi_back_detect_api.py
Normal file
645
multi_back_detect/multi_back_detect_api.py
Normal file
@ -0,0 +1,645 @@
|
||||
import json
|
||||
import time
|
||||
from sanic import Blueprint
|
||||
from sanic.response import json as json_response
|
||||
from sanic.exceptions import Unauthorized, SanicException
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import List, Dict, Any
|
||||
import logging
|
||||
import asyncio
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from sympy import false
|
||||
|
||||
|
||||
try :
|
||||
from middleware.TaskManager import task_manager
|
||||
from middleware.query_model import ModelConfigDAO
|
||||
from middleware.query_postgress import batch_query_model_func_id
|
||||
from middleware.read_yolo_config import read_local_func_config
|
||||
from yolo.cv_multi_model_back_video import start_rtmp_processing
|
||||
from cv_video import startAIVideo, stopAIVideo
|
||||
from cv_back_video import startBackAIVideo
|
||||
except Exception as e:
|
||||
import sys
|
||||
from pathlib import Path
|
||||
# 获取项目根目录(假设脚本在根目录的子目录中)
|
||||
ROOT_DIR = Path(__file__).parent.parent # 根据实际层级调整
|
||||
sys.path.append(str(ROOT_DIR))
|
||||
from middleware.TaskManager import task_manager
|
||||
from middleware.query_model import ModelConfigDAO
|
||||
from middleware.query_postgress import batch_query_model_func_id
|
||||
from middleware.read_yolo_config import read_local_func_config
|
||||
from yolo.cv_multi_model_back_video import start_rtmp_processing
|
||||
from cv_video import startAIVideo, stopAIVideo
|
||||
from cv_back_video import startBackAIVideo
|
||||
|
||||
|
||||
# 配置类
|
||||
class Config:
|
||||
VALID_TOKEN = "5e8899fe-dc74-4280-8169-2f4d185f3afa"
|
||||
MAX_ACTIVE_TASKS = 10
|
||||
DEFAULT_CONFIDENCE = 0.5
|
||||
RESTART_DELAY = 2 # 服务尝试自动恢复前的延迟(秒)
|
||||
|
||||
#正式数据库
|
||||
DB_CONFIG = {
|
||||
"dbname": "smart_dev",
|
||||
"user": "postgres",
|
||||
"password": "StrongPassword@123",
|
||||
"host": "222.212.85.86",
|
||||
"port": "5061"
|
||||
}
|
||||
|
||||
# 服务状态标志
|
||||
service_status = {"is_healthy": True, "last_error": None, "error_time": None}
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
multi_back_detect_bp = Blueprint("multi_back_detect", url_prefix="")
|
||||
|
||||
@dataclass
|
||||
class BaseContentBody:
|
||||
pass # 公共字段可以放在这里
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContentBodyFormat_VideoMultiBackDetect(BaseContentBody):
|
||||
# mqtt_pub_id: int
|
||||
org_code: str
|
||||
# mqtt_ip: str
|
||||
# mqtt_port: int
|
||||
# mqtt_topic: str
|
||||
minio_file_path: str
|
||||
push_url: str # 临时测试用
|
||||
confidence: float
|
||||
para_list: list
|
||||
# invade_file: str
|
||||
invade: list
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContentBodyFormat_Sam3Pic(BaseContentBody):
|
||||
img_url: str
|
||||
prompt: str
|
||||
confidence: int
|
||||
mqtt_ip: str
|
||||
mqtt_port: int
|
||||
mqtt_topic: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContentBodyFormat_MultiBackDetect(BaseContentBody):
|
||||
# mqtt_pub_id: int
|
||||
# mqtt_sub_id: int
|
||||
# mqtt_pub_ip: str
|
||||
# mqtt_pub_port: int
|
||||
# mqtt_pub_topic: str
|
||||
# mqtt_sub_ip: str
|
||||
# mqtt_sub_port: int
|
||||
# mqtt_sub_topic: str
|
||||
org_code: str
|
||||
func_id: list
|
||||
source_url: str
|
||||
push_url: str # 临时测试用
|
||||
confidence: float
|
||||
para_list: list
|
||||
# invade_file: str
|
||||
invade: list
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContentBodyFormat_BackDetect(BaseContentBody):
|
||||
mqtt_ip: str
|
||||
mqtt_port: int
|
||||
mqtt_topic: str
|
||||
source_url: str
|
||||
push_url: str # 临时测试用
|
||||
confidence: float
|
||||
func_id: List[int]
|
||||
para: {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContentBodyFormat_BackDetectPic(BaseContentBody):
|
||||
s3_id: int
|
||||
s3_url: list[str]
|
||||
org_code: str
|
||||
# mqtt_pub_id: int
|
||||
confidence: float
|
||||
func_id: list[int]
|
||||
para: {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class EarlyLaterUrls:
|
||||
early: str
|
||||
later: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContentBodyFormat_Detection(BaseContentBody):
|
||||
s3_id: int
|
||||
s3_url: EarlyLaterUrls
|
||||
func_id: list[int]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContentBodyFormat_Segementation(BaseContentBody):
|
||||
s3_id: int
|
||||
s3_url: list[str]
|
||||
func_id: list[int]
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestJson:
|
||||
task_id: str
|
||||
sn: str
|
||||
content_body: BaseContentBody
|
||||
|
||||
def validate(self) -> None:
|
||||
"""验证请求参数"""
|
||||
if not self.task_id:
|
||||
raise ValueError("task_id is required")
|
||||
|
||||
if isinstance(self.content_body, ContentBodyFormat_VideoMultiBackDetect):
|
||||
if not self.content_body.minio_file_path:
|
||||
raise ValueError("minio_file_path is required for ContentBodyFormat_VideoMultiBackDetect")
|
||||
if self.content_body.confidence is not None and not 0 < self.content_body.confidence < 1:
|
||||
raise ValueError("Confidence must be between 0 and 1 for ContentBodyFormat_BackDetect")
|
||||
elif isinstance(self.content_body, ContentBodyFormat_MultiBackDetect):
|
||||
if not self.content_body.para_list:
|
||||
raise ValueError("para_list is required for ContentBodyFormat_MultiBackDetect")
|
||||
if self.content_body.confidence is not None and not 0 < self.content_body.confidence < 1:
|
||||
raise ValueError("Confidence must be between 0 and 1 for ContentBodyFormat_BackDetect")
|
||||
|
||||
elif isinstance(self.content_body, ContentBodyFormat_BackDetect):
|
||||
if not self.content_body.source_url:
|
||||
raise ValueError("source_url is required for ContentBodyFormat_BackDetect")
|
||||
if self.content_body.confidence is not None and not 0 < self.content_body.confidence < 1:
|
||||
raise ValueError("Confidence must be between 0 and 1 for ContentBodyFormat_BackDetect")
|
||||
|
||||
elif isinstance(self.content_body, ContentBodyFormat_BackDetectPic):
|
||||
if not self.content_body.s3_id:
|
||||
raise ValueError("s3_id is required for ContentBodyFormat_BackDetectPic")
|
||||
if not self.content_body.s3_url:
|
||||
raise ValueError("s3_url is required for ContentBodyFormat_BackDetectPic")
|
||||
if self.content_body.confidence is not None and not 0 < self.content_body.confidence < 1:
|
||||
raise ValueError("Confidence must be between 0 and 1 for ContentBodyFormat_BackDetectPic")
|
||||
|
||||
elif isinstance(self.content_body, ContentBodyFormat_Detection):
|
||||
if not self.content_body.s3_id:
|
||||
raise ValueError("s3_id is required for ContentBodyFormat_Detection")
|
||||
if not self.content_body.s3_url.early or not self.content_body.s3_url.later:
|
||||
raise ValueError("Both early and later URLs are required for ContentBodyFormat_Detection")
|
||||
|
||||
elif isinstance(self.content_body, ContentBodyFormat_Segementation):
|
||||
if not self.content_body.s3_id:
|
||||
raise ValueError("s3_id is required for ContentBodyFormat_Segementation")
|
||||
if not self.content_body.s3_url:
|
||||
raise ValueError("s3_url is required for ContentBodyFormat_Segementation")
|
||||
elif isinstance(self.content_body, ContentBodyFormat_Sam3Pic):
|
||||
if not self.content_body.prompt:
|
||||
raise ValueError("prompt is required for ContentBodyFormat_Sam3Pic")
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'RequestJson':
|
||||
try:
|
||||
task_id = data['task_id']
|
||||
sn = data['sn']
|
||||
content_body_data = data['content_body']
|
||||
|
||||
if 'minio_file_path' in content_body_data:
|
||||
content_body = ContentBodyFormat_VideoMultiBackDetect(
|
||||
org_code=content_body_data['org_code'],
|
||||
minio_file_path=content_body_data['minio_file_path'],
|
||||
push_url=content_body_data['push_url'],
|
||||
confidence=content_body_data.get('confidence', 0.5),
|
||||
para_list=content_body_data.get('para_list', []),
|
||||
invade=content_body_data.get('invade', [])
|
||||
)
|
||||
|
||||
# 根据 content_body_data 的类型创建相应的实例
|
||||
elif 'para_list' in content_body_data:
|
||||
content_body = ContentBodyFormat_MultiBackDetect(
|
||||
org_code=content_body_data['org_code'],
|
||||
func_id=content_body_data['func_id'],
|
||||
source_url=content_body_data['source_url'],
|
||||
push_url=content_body_data['push_url'],
|
||||
confidence=content_body_data.get('confidence', 0.5),
|
||||
para_list=content_body_data.get('para_list', []),
|
||||
invade=content_body_data.get('invade', [])
|
||||
)
|
||||
# 根据 content_body_data 的类型创建相应的实例
|
||||
elif 'source_url' in content_body_data:
|
||||
content_body = ContentBodyFormat_BackDetect(
|
||||
mqtt_ip=content_body_data['mqtt_ip'],
|
||||
mqtt_port=content_body_data['mqtt_port'],
|
||||
mqtt_topic=content_body_data['mqtt_topic'],
|
||||
source_url=content_body_data['source_url'],
|
||||
push_url=content_body_data['push_url'],
|
||||
confidence=content_body_data.get('confidence', 0.5),
|
||||
func_id=content_body_data.get('func_id', []),
|
||||
para=content_body_data.get('para', {})
|
||||
)
|
||||
elif 's3_id' in content_body_data and 's3_url' in content_body_data:
|
||||
if isinstance(content_body_data['s3_url'], dict) and 'early' in content_body_data[
|
||||
's3_url'] and 'later' in content_body_data['s3_url']:
|
||||
content_body = ContentBodyFormat_Detection(
|
||||
s3_id=content_body_data['s3_id'],
|
||||
s3_url=EarlyLaterUrls(
|
||||
early=content_body_data['s3_url']['early'],
|
||||
later=content_body_data['s3_url']['later']
|
||||
),
|
||||
func_id=content_body_data.get('func_id', [])
|
||||
)
|
||||
elif isinstance(content_body_data['s3_url'], list) and 'confidence' not in content_body_data:
|
||||
content_body = ContentBodyFormat_Segementation(
|
||||
s3_id=content_body_data['s3_id'],
|
||||
s3_url=content_body_data['s3_url'],
|
||||
func_id=content_body_data.get('func_id', [])
|
||||
)
|
||||
elif isinstance(content_body_data['s3_url'], list):
|
||||
content_body = ContentBodyFormat_BackDetectPic(
|
||||
s3_id=content_body_data['s3_id'],
|
||||
s3_url=content_body_data['s3_url'],
|
||||
# mqtt_pub_id=content_body_data['mqtt_pub_id'],
|
||||
org_code=content_body_data['org_code'],
|
||||
confidence=content_body_data.get('confidence', 0.5),
|
||||
func_id=content_body_data.get('func_id', []),
|
||||
para=content_body_data.get('para', {})
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError("Invalid s3_url format for ContentBodyFormat_Detection")
|
||||
elif 'prompt' in content_body_data:
|
||||
content_body = ContentBodyFormat_Sam3Pic(
|
||||
img_url=content_body_data['img_url'],
|
||||
prompt=content_body_data['prompt'],
|
||||
confidence=content_body_data.get('confidence', 0.5),
|
||||
mqtt_ip=content_body_data['mqtt_ip'],
|
||||
mqtt_port=content_body_data['mqtt_port'],
|
||||
mqtt_topic=content_body_data['mqtt_topic']
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid content_body format")
|
||||
|
||||
instance = cls(
|
||||
task_id=task_id,
|
||||
sn=sn,
|
||||
content_body=content_body
|
||||
)
|
||||
instance.validate()
|
||||
return instance
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Missing required field: {str(e)}")
|
||||
|
||||
async def safe_stop_ai_video():
|
||||
"""安全地停止AI视频处理,带有错误处理和恢复机制"""
|
||||
try:
|
||||
await asyncio.to_thread(stopAIVideo)
|
||||
return True
|
||||
except Exception as e:
|
||||
error_msg = f"停止AI视频处理出错: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
|
||||
# 标记服务状态为不健康
|
||||
service_status["is_healthy"] = False
|
||||
service_status["last_error"] = str(e)
|
||||
service_status["error_time"] = datetime.now().isoformat()
|
||||
|
||||
# 强制结束所有任务
|
||||
task_manager.mark_all_tasks_as_stopped()
|
||||
|
||||
# 尝试通过其他方式杀死可能存在的进程
|
||||
try:
|
||||
import os
|
||||
import signal
|
||||
import psutil
|
||||
|
||||
current_process = psutil.Process(os.getpid())
|
||||
# 查找并终止ffmpeg子进程
|
||||
for child in current_process.children(recursive=True):
|
||||
try:
|
||||
child_name = child.name().lower()
|
||||
if 'ffmpeg' in child_name:
|
||||
logger.info(f"强制终止子进程: {child.pid} ({child_name})")
|
||||
child.send_signal(signal.SIGTERM)
|
||||
except Exception as child_e:
|
||||
logger.error(f"终止子进程出错: {str(child_e)}")
|
||||
except Exception as kill_e:
|
||||
logger.error(f"尝试清理进程时出错: {str(kill_e)}")
|
||||
|
||||
# 等待一段时间让系统恢复
|
||||
await asyncio.sleep(Config.RESTART_DELAY)
|
||||
|
||||
# 重置服务状态
|
||||
service_status["is_healthy"] = True
|
||||
return False
|
||||
|
||||
def verify_token(request) -> None:
|
||||
"""验证请求token"""
|
||||
token = request.headers.get('X-API-Token')
|
||||
if not token or token != Config.VALID_TOKEN:
|
||||
logger.warning("Invalid token attempt")
|
||||
raise Unauthorized("Invalid token")
|
||||
|
||||
@multi_back_detect_bp.post("/ai/stream/multi_back_detect")
|
||||
async def start_multi_back_detection(request):
|
||||
try:
|
||||
verify_token(request)
|
||||
|
||||
# 检查服务健康状态
|
||||
if not service_status["is_healthy"]:
|
||||
logger.warning(
|
||||
f"服务处于不健康状态,上次错误: {service_status['last_error']} 于 {service_status['error_time']}")
|
||||
service_status["is_healthy"] = True
|
||||
|
||||
# 停止所有现有任务(可选,根据需求调整)
|
||||
# for task_id in list(task_manager.tasks.keys()):
|
||||
# await task_manager.remove_task(task_id)
|
||||
|
||||
# 解析并验证请求数据
|
||||
request_json = RequestJson.from_dict(request.json)
|
||||
print(f"/ai/stream/multi_back_detect 请求:{request.json}")
|
||||
time.sleep(3)
|
||||
if request_json.task_id in task_manager.tasks:
|
||||
logger.warning(f"任务 {request_json.task_id} 已存在,跳过创建")
|
||||
return json_response({
|
||||
"status": "error",
|
||||
"message": f"任务 {request_json.task_id} 已存在,跳过创建"
|
||||
}, status=500)
|
||||
|
||||
if isinstance(request_json.content_body, ContentBodyFormat_MultiBackDetect):
|
||||
try:
|
||||
# 创建停止事件
|
||||
stop_event = asyncio.Event()
|
||||
|
||||
# 包装处理函数以支持停止事件
|
||||
async def wrapped_processing():
|
||||
try:
|
||||
await run_back_Multi_Detect_async(request, request_json, stop_event)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"任务 {request_json.task_id} 被取消")
|
||||
except Exception as e:
|
||||
logger.error(f"任务 {request_json.task_id} 异常终止: {e}")
|
||||
|
||||
# 创建并启动任务
|
||||
task_handle = asyncio.create_task(wrapped_processing())
|
||||
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"启动AI视频处理失败: {e}")
|
||||
return json_response({
|
||||
"status": "error",
|
||||
"message": f"Failed to start AI video processing: {str(e)}"
|
||||
}, status=500)
|
||||
else:
|
||||
return json_response({
|
||||
"status": "failed",
|
||||
"message": "content_body structure is wrong"
|
||||
})
|
||||
|
||||
return json_response({
|
||||
"status": "success",
|
||||
"task_id": request_json.task_id,
|
||||
"message": "Detection started successfully"
|
||||
})
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Validation error: {str(e)}")
|
||||
return json_response({"status": "error", "message": str(e)}, status=400)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error: {str(e)}", exc_info=True)
|
||||
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
||||
|
||||
async def run_back_Multi_Detect_async(request, request_json, stop_event: asyncio.Event):
|
||||
global DB_CONFIG
|
||||
model_configs = []
|
||||
task_handle = None # 初始化task_handle
|
||||
|
||||
try:
|
||||
|
||||
invade_enable = false # para_list 中可以包含多个侵限使能,其中一个使能,即将置为True
|
||||
py_func = []
|
||||
# 创建DAO实例
|
||||
dao = ModelConfigDAO(DB_CONFIG)
|
||||
# insert_request_log(self, task_id, sn, org_code, requset_json, request)
|
||||
for para in request_json.content_body.para_list:
|
||||
func_id = para["func_id"]
|
||||
category = para.get("py_func", []) # 提供默认值
|
||||
py_func = category # 湖北现场临时用
|
||||
para_invade_enable = para.get("para_invade_enable", False) # 提供默认值
|
||||
if para_invade_enable:
|
||||
invade_enable = True
|
||||
query_results = batch_query_model_func_id([func_id], **DB_CONFIG)
|
||||
row_func_id = 0 # 伪代码,后续记得修改
|
||||
if len(query_results) < 1:
|
||||
continue
|
||||
for row in query_results:
|
||||
row_func_id = row["model_func_id"]
|
||||
func_id = para["func_id"]
|
||||
config = dao.get_config(func_id, category)
|
||||
repeat_dis = -1 # 基于两帧之间的距离去重
|
||||
if config:
|
||||
# 打印结构化结果(使用自定义编码器处理datetime)
|
||||
# 访问特定字段
|
||||
print("\n模型路径:", config.model_path)
|
||||
print("过滤类别:", config.filter_indices)
|
||||
print("第一个类别:", asdict(config.classes[0]))
|
||||
print("创建时间:", config.created_at)
|
||||
print("更新时间:", config.updated_at)
|
||||
print("去重的距离:", config.repeat_dis)
|
||||
repeat_dis = config.repeat_dis
|
||||
repeat_time = config.repeat_time
|
||||
high_count_warn = config.high_count_warn
|
||||
print(f"config.high_count_warn {config.high_count_warn}")
|
||||
model_configs.append(
|
||||
{
|
||||
|
||||
'path': config.model_path,
|
||||
# 'engine_path': config.engine_path,
|
||||
# 'so_path': config.so_path,
|
||||
# # 测试代码
|
||||
'engine_path': r"D:\project\AI-PYTHON\Ai_tottle\engine\renche\renche.engine",
|
||||
'so_path': r"D:\project\AI-PYTHON\Ai_tottle\engine\renche\myplugins.dll",
|
||||
# 工地安全帽
|
||||
# 'engine_path': r"D:\project\AI-PYTHON\Ai_tottle\engine\gdaq_hat_0926\gdaq_hat_0926.engine",
|
||||
# 'so_path': r"D:\project\AI-PYTHON\Ai_tottle\engine\gdaq_hat_0926\myplugins.dll",
|
||||
|
||||
'cls_map': config.cls_zn_to_eh_dict,
|
||||
'allowed_classes': config.allowed_classes,
|
||||
"cls_index": config.class_indices,
|
||||
"class_names": config.cls_names,
|
||||
"chinese_label": config.cls_en_dict,
|
||||
"list_func_id": row_func_id,
|
||||
"func_id": func_id,
|
||||
"para_invade_enable": para_invade_enable,
|
||||
"config_conf": config.conf
|
||||
}
|
||||
)
|
||||
else:
|
||||
print(f"未找到ID为 {func_id} 的模型配置")
|
||||
|
||||
# category = para.get("category", []) # 提供默认值
|
||||
para_invade_enable = para.get("para_invade_enable", False) # 提供默认值
|
||||
if para_invade_enable:
|
||||
invade_enable = True
|
||||
|
||||
# 前置处理 不同的版本的模型,输入的数据的格式不同,需要做前置处理,入参应该是模型类型、版本、rest参数。出参应该是模型入参的格式
|
||||
|
||||
video_url = request_json.content_body.source_url
|
||||
sn = request_json.sn
|
||||
task_id = request_json.task_id
|
||||
# mqtt_pub_id = request_json.content_body.mqtt_pub_id
|
||||
# mqtt_sub_id = request_json.content_body.mqtt_sub_id
|
||||
org_code = request_json.content_body.org_code
|
||||
push_url = request_json.content_body.push_url
|
||||
# invade_file = request_json.content_body.invade_file
|
||||
invade = request_json.content_body.invade
|
||||
invade_file = invade["invade_file"]
|
||||
camera_para_url = invade["camera_para_url"]
|
||||
|
||||
if high_count_warn is None:
|
||||
high_count_warn=0
|
||||
|
||||
if "invade_switch" in invade:
|
||||
invade_switch = invade["invade_switch"]
|
||||
else:
|
||||
invade_switch = 0 # 或其他默认值
|
||||
# dao.get_mqtt_config_by_orgcode(org_code,)
|
||||
str_request = str(request) + "&" + str(request.socket) # 待测试,看看公网能不能捕获到请求端ip
|
||||
dao.insert_request_log(task_id, sn, org_code, str(request.body), str_request)
|
||||
mqtt_pub_config = dao.get_mqtt_config_by_orgcode(org_code, sn, "pub")
|
||||
mqtt_sub_config = dao.get_mqtt_config_by_orgcode(org_code, sn, "sub")
|
||||
|
||||
mqtt_pub_ip = mqtt_pub_config.mqtt_ip
|
||||
mqtt_pub_port = mqtt_pub_config.mqtt_port
|
||||
mqtt_pub_topic = mqtt_pub_config.mqtt_topic
|
||||
print(f"mqtt_pub_topic {mqtt_pub_topic}")
|
||||
mqtt_pub_username = mqtt_pub_config.mqtt_username
|
||||
mqtt_pub_pass = mqtt_pub_config.mqtt_pass
|
||||
mqtt_pub_description = mqtt_pub_config.mqtt_description
|
||||
mqtt_pub_org_code = mqtt_pub_config.org_code
|
||||
mqtt_pub_mqtt_type = mqtt_pub_config.mqtt_type
|
||||
|
||||
mqtt_sub_ip = mqtt_sub_config.mqtt_ip
|
||||
mqtt_sub_port = mqtt_sub_config.mqtt_port
|
||||
mqtt_sub_topic = mqtt_sub_config.mqtt_topic
|
||||
mqtt_sub_topic = mqtt_sub_topic.format(sn=sn)
|
||||
|
||||
print(f"mqtt_sub_topic {mqtt_sub_topic}")
|
||||
mqtt_sub_username = mqtt_sub_config.mqtt_username
|
||||
mqtt_sub_pass = mqtt_sub_config.mqtt_pass
|
||||
mqtt_sub_description = mqtt_sub_config.mqtt_description
|
||||
mqtt_sub_org_code = mqtt_sub_config.org_code
|
||||
mqtt_sub_mqtt_type = mqtt_sub_config.mqtt_type
|
||||
|
||||
local_func_config = read_local_func_config()
|
||||
|
||||
sn = request_json.sn
|
||||
device = dao.get_device(sn, org_code)
|
||||
print(f"device表 {sn} {org_code}")
|
||||
device_sn = device.sn
|
||||
device_orgcode = device.orgcode
|
||||
device_dname = device.dname
|
||||
device_lat = device.lat
|
||||
device_lng = device.lng
|
||||
device_height = device.height # 机场高度,后续用作现场的高程计算
|
||||
|
||||
# # 启动处理流程
|
||||
async def process_flow():
|
||||
try:
|
||||
await start_rtmp_processing(
|
||||
video_url,
|
||||
request_json.task_id,
|
||||
model_configs,
|
||||
mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic,
|
||||
mqtt_sub_ip, mqtt_sub_port, mqtt_sub_topic,
|
||||
push_url,
|
||||
invade_enable,invade_switch, invade_file, camera_para_url,
|
||||
device_height, repeat_dis, repeat_time,high_count_warn
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"处理流程异常: {e}")
|
||||
raise
|
||||
|
||||
# 运行处理流程,直到被取消
|
||||
await asyncio.shield(process_flow())
|
||||
|
||||
# # # # 针对湖北现场临时处理--------------------------------------------------------
|
||||
# model_path = model_configs[0]["path"]
|
||||
# detect_classes = py_func
|
||||
# print(f"detect_classesdetect_classes {detect_classes}")
|
||||
# confidence = model_configs[0]["config_conf"]
|
||||
# # 创建处理函数以支持停止事件
|
||||
# async def process_video():
|
||||
# nonlocal task_handle # 使用nonlocal访问外部变量
|
||||
# try:
|
||||
# # 针对湖北现场临时处理
|
||||
# source_url = video_url
|
||||
# model_path = model_configs[0]["path"]
|
||||
# detect_classes = py_func # 使用py_func作为检测类别
|
||||
# print(f"detect_classesdetect_classes {detect_classes}")
|
||||
#
|
||||
# confidence = model_configs[0]["config_conf"]
|
||||
#
|
||||
# # 启动YOLO检测
|
||||
# await asyncio.to_thread(
|
||||
# startAIVideo,
|
||||
# source_url,
|
||||
# push_url,
|
||||
# model_path,
|
||||
# detect_classes,
|
||||
# confidence
|
||||
# )
|
||||
# except asyncio.CancelledError:
|
||||
# logger.info(f"任务 {task_id} 被取消")
|
||||
# raise
|
||||
# except Exception as e:
|
||||
# logger.error(f"任务 {task_id} 异常终止: {e}")
|
||||
# raise
|
||||
#
|
||||
# # 创建并启动任务
|
||||
# task_handle = asyncio.create_task(process_video()) # 存储task_handle
|
||||
|
||||
# 记录任务信息到task_manager
|
||||
task_info = {
|
||||
"source_url": video_url,
|
||||
"push_url": push_url,
|
||||
"status": "running",
|
||||
"task_handle": task_handle, # 存储实际的任务句柄
|
||||
"model_configs": model_configs,
|
||||
"device_height": device_height,
|
||||
"repeat_dis": repeat_dis,
|
||||
"repeat_time": repeat_time
|
||||
}
|
||||
|
||||
# 使用task_manager管理任务
|
||||
await task_manager.add_task(
|
||||
task_id,
|
||||
task_info,
|
||||
task_handle, # 传递实际的任务句柄
|
||||
[] # 暂时没有子任务
|
||||
)
|
||||
|
||||
# 等待任务完成或被取消
|
||||
try:
|
||||
await task_handle
|
||||
except asyncio.CancelledError:
|
||||
pass # 任务被取消是正常的
|
||||
await task_manager.remove_task(task_id)
|
||||
# # # 针对湖北现场处理结束--------------------------------------------------------
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"任务 {request_json.task_id} 收到停止信号,正在清理...")
|
||||
# 清理资源逻辑...
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"任务 {request_json.task_id} 处理失败: {e}")
|
||||
raise
|
||||
58
multi_back_detect/router_multi_back_detect_api.py
Normal file
58
multi_back_detect/router_multi_back_detect_api.py
Normal file
@ -0,0 +1,58 @@
|
||||
import time
|
||||
import logging
|
||||
from sanic_cors import CORS
|
||||
from sanic import Sanic, Request, json
|
||||
from multi_back_detect_api import multi_back_detect_bp
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 创建Sanic应用
|
||||
app = Sanic("multiBackDetectAPI")
|
||||
# 显式注册蓝图
|
||||
app.blueprint(multi_back_detect_bp)
|
||||
|
||||
CORS(app, automatic_options=True)
|
||||
|
||||
# 中间件:请求计时
|
||||
@app.middleware("request")
|
||||
async def add_start_time(request: Request):
|
||||
request.ctx.start_time = time.time()
|
||||
|
||||
@app.middleware("response")
|
||||
async def add_response_time(request: Request, response):
|
||||
if hasattr(request.ctx, "start_time"):
|
||||
process_time = (time.time() - request.ctx.start_time) * 1000
|
||||
response.headers["X-Process-Time"] = f"{process_time:.2f}ms"
|
||||
|
||||
@app.get("/api/v1/health")
|
||||
async def health_check(request: Request):
|
||||
"""健康检查"""
|
||||
return json({
|
||||
"status": "healthy",
|
||||
"timestamp": time.time(),
|
||||
"service": "terrain-analysis-api",
|
||||
"version": "1.0.0"
|
||||
})
|
||||
|
||||
# 错误处理
|
||||
@app.exception(Exception)
|
||||
async def handle_exception(request: Request, exception):
|
||||
"""全局异常处理"""
|
||||
logger.error(f"未处理的异常: {exception}")
|
||||
return json({
|
||||
"error": "服务器内部错误",
|
||||
"message": str(exception) if app.debug else "请稍后重试",
|
||||
"success": False
|
||||
}, status=500)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 启动服务器
|
||||
app.run(
|
||||
host="0.0.0.0",
|
||||
port=12320,
|
||||
debug=False, # 生产环境设为False
|
||||
access_log=True,
|
||||
auto_reload=True
|
||||
)
|
||||
@ -36,6 +36,7 @@ from sanic import Sanic, Request
|
||||
# 引入其他模块
|
||||
from b3dm.earthwork_api import earthwork_bp
|
||||
from b3dm.terrain_api import terrain_bp
|
||||
from multi_back_detect.multi_back_detect_api import multi_back_detect_bp
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
@ -1055,8 +1056,8 @@ async def start_back_detection(request):
|
||||
return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
||||
|
||||
|
||||
@app.post("/ai/stream/multi_back_detect")
|
||||
async def start_multi_back_detection(request):
|
||||
@app.post("/ai/stream/multi_back_detect1")
|
||||
async def start_multi_back_detection1(request):
|
||||
try:
|
||||
verify_token(request)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user