645 lines
26 KiB
Python
645 lines
26 KiB
Python
|
|
import json
|
|||
|
|
import time
|
|||
|
|
from sanic import Blueprint
|
|||
|
|
from sanic.response import json as json_response
|
|||
|
|
from sanic.exceptions import Unauthorized, SanicException
|
|||
|
|
from dataclasses import dataclass, asdict
|
|||
|
|
from typing import List, Dict, Any
|
|||
|
|
import logging
|
|||
|
|
import asyncio
|
|||
|
|
import traceback
|
|||
|
|
from datetime import datetime
|
|||
|
|
from sympy import false
|
|||
|
|
|
|||
|
|
|
|||
|
|
try :
|
|||
|
|
from middleware.TaskManager import task_manager
|
|||
|
|
from middleware.query_model import ModelConfigDAO
|
|||
|
|
from middleware.query_postgress import batch_query_model_func_id
|
|||
|
|
from middleware.read_yolo_config import read_local_func_config
|
|||
|
|
from yolo.cv_multi_model_back_video import start_rtmp_processing
|
|||
|
|
from cv_video import startAIVideo, stopAIVideo
|
|||
|
|
from cv_back_video import startBackAIVideo
|
|||
|
|
except Exception as e:
|
|||
|
|
import sys
|
|||
|
|
from pathlib import Path
|
|||
|
|
# 获取项目根目录(假设脚本在根目录的子目录中)
|
|||
|
|
ROOT_DIR = Path(__file__).parent.parent # 根据实际层级调整
|
|||
|
|
sys.path.append(str(ROOT_DIR))
|
|||
|
|
from middleware.TaskManager import task_manager
|
|||
|
|
from middleware.query_model import ModelConfigDAO
|
|||
|
|
from middleware.query_postgress import batch_query_model_func_id
|
|||
|
|
from middleware.read_yolo_config import read_local_func_config
|
|||
|
|
from yolo.cv_multi_model_back_video import start_rtmp_processing
|
|||
|
|
from cv_video import startAIVideo, stopAIVideo
|
|||
|
|
from cv_back_video import startBackAIVideo
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Service-level configuration constants.
class Config:
    # Shared secret that verify_token() compares against the X-API-Token header.
    # NOTE(review): the token is hard-coded in source; consider loading it from
    # the environment or a secrets store.
    VALID_TOKEN = "5e8899fe-dc74-4280-8169-2f4d185f3afa"
    # Not referenced in the visible part of this file.
    MAX_ACTIVE_TASKS = 10
    # Not referenced in the visible part of this file (from_dict uses a literal 0.5).
    DEFAULT_CONFIDENCE = 0.5
    RESTART_DELAY = 2  # Delay (seconds) before the service attempts automatic recovery
|
|||
|
|
|
|||
|
|
# Production database connection settings, passed to ModelConfigDAO and
# batch_query_model_func_id.
# NOTE(review): host and credentials are hard-coded in source; consider loading
# them from the environment or a secrets store.
DB_CONFIG = {
    "dbname": "smart_dev",
    "user": "postgres",
    "password": "StrongPassword@123",
    "host": "222.212.85.86",
    "port": "5061"
}

# Service health flags, updated by safe_stop_ai_video() and checked by the
# request handler.
service_status = {"is_healthy": True, "last_error": None, "error_time": None}

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Blueprint that carries the /ai/stream/multi_back_detect route.
multi_back_detect_bp = Blueprint("multi_back_detect", url_prefix="")
|
|||
|
|
|
|||
|
|
@dataclass
class BaseContentBody:
    """Common base for every ``content_body`` payload variant.

    It declares no fields of its own; shared fields can be added here.
    """
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class ContentBodyFormat_VideoMultiBackDetect(BaseContentBody):
    """Payload selected by RequestJson.from_dict when ``minio_file_path`` is present."""

    # mqtt_pub_id: int
    org_code: str
    # mqtt_ip: str
    # mqtt_port: int
    # mqtt_topic: str
    minio_file_path: str  # required; validate() rejects an empty value
    push_url: str  # temporary, for testing
    confidence: float  # validate() requires 0 < confidence < 1 when not None
    para_list: list
    # invade_file: str
    invade: list
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class ContentBodyFormat_Sam3Pic(BaseContentBody):
    """Payload selected by RequestJson.from_dict when ``prompt`` is present."""

    img_url: str
    prompt: str  # required; validate() rejects an empty value
    # Annotated as float: from_dict fills in 0.5 when the field is omitted,
    # so the previous ``int`` annotation did not match the data.
    confidence: float
    mqtt_ip: str
    mqtt_port: int
    mqtt_topic: str
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class ContentBodyFormat_MultiBackDetect(BaseContentBody):
    """Payload selected by RequestJson.from_dict when ``para_list`` is present.

    This is the only payload type accepted by the
    ``/ai/stream/multi_back_detect`` handler.
    """

    # MQTT endpoints are no longer part of the payload (the processing
    # function reads them via ModelConfigDAO.get_mqtt_config_by_orgcode):
    # mqtt_pub_id: int
    # mqtt_sub_id: int
    # mqtt_pub_ip: str
    # mqtt_pub_port: int
    # mqtt_pub_topic: str
    # mqtt_sub_ip: str
    # mqtt_sub_port: int
    # mqtt_sub_topic: str
    org_code: str
    func_id: list
    source_url: str
    push_url: str  # temporary, for testing
    confidence: float  # validate() requires 0 < confidence < 1 when not None
    para_list: list  # each entry is read with ["func_id"] / .get(...) by the processing function
    # invade_file: str
    # Annotated as Any rather than ``list``: run_back_Multi_Detect_async
    # indexes this value by string key ("invade_file", "camera_para_url",
    # optional "invade_switch"), while from_dict defaults it to [] when absent.
    invade: Any
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class ContentBodyFormat_BackDetect(BaseContentBody):
    """Payload selected by RequestJson.from_dict when ``source_url`` is present
    (and neither ``minio_file_path`` nor ``para_list`` is)."""

    mqtt_ip: str
    mqtt_port: int
    mqtt_topic: str
    source_url: str  # required; validate() rejects an empty value
    push_url: str  # temporary, for testing
    confidence: float  # validate() requires 0 < confidence < 1 when not None
    func_id: List[int]
    # Was ``para: {}`` — a dict *instance* used as the annotation. from_dict
    # supplies a dict (default {}), so annotate the type properly.
    para: Dict[str, Any]
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class ContentBodyFormat_BackDetectPic(BaseContentBody):
    """Payload selected by RequestJson.from_dict when ``s3_url`` is a list and
    ``confidence`` is present in the request."""

    s3_id: int  # required; validate() rejects a falsy value
    s3_url: list[str]  # required; validate() rejects an empty list
    org_code: str
    # mqtt_pub_id: int
    confidence: float  # validate() requires 0 < confidence < 1 when not None
    func_id: list[int]
    # Was ``para: {}`` — a dict *instance* used as the annotation. from_dict
    # supplies a dict (default {}), so annotate the type properly.
    para: Dict[str, Any]
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class EarlyLaterUrls:
    """Pair of URLs built from the ``early`` / ``later`` keys of an ``s3_url`` dict."""

    early: str
    later: str
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class ContentBodyFormat_Detection(BaseContentBody):
    """Payload selected by RequestJson.from_dict when ``s3_url`` is a dict
    holding both ``early`` and ``later``."""

    s3_id: int  # required; validate() rejects a falsy value
    s3_url: EarlyLaterUrls  # validate() requires both URLs to be non-empty
    func_id: list[int]
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class ContentBodyFormat_Segementation(BaseContentBody):
    """Payload selected by RequestJson.from_dict when ``s3_url`` is a list and
    the request carries no ``confidence`` key."""

    s3_id: int  # required; validate() rejects a falsy value
    s3_url: list[str]  # required; validate() rejects an empty list
    func_id: list[int]
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class RequestJson:
    """Envelope of an incoming request: task id, device serial number and a
    typed ``content_body``.

    Instances are normally built with :meth:`from_dict`, which selects the
    concrete ``ContentBodyFormat_*`` class from the keys present in
    ``content_body`` and then runs :meth:`validate`.
    """

    task_id: str
    sn: str
    content_body: BaseContentBody

    def _check_confidence(self, owner: str) -> None:
        """Raise ValueError unless the body's confidence is None or in (0, 1).

        ``owner`` is the payload class name reported in the error message.
        """
        confidence = self.content_body.confidence
        if confidence is not None and not 0 < confidence < 1:
            raise ValueError(f"Confidence must be between 0 and 1 for {owner}")

    def validate(self) -> None:
        """Validate the request parameters.

        Raises:
            ValueError: if ``task_id`` is empty or a field required by the
                concrete content body type is missing or out of range.
        """
        if not self.task_id:
            raise ValueError("task_id is required")

        body = self.content_body

        if isinstance(body, ContentBodyFormat_VideoMultiBackDetect):
            if not body.minio_file_path:
                raise ValueError("minio_file_path is required for ContentBodyFormat_VideoMultiBackDetect")
            # The message previously named ContentBodyFormat_BackDetect (copy-paste slip).
            self._check_confidence("ContentBodyFormat_VideoMultiBackDetect")

        elif isinstance(body, ContentBodyFormat_MultiBackDetect):
            if not body.para_list:
                raise ValueError("para_list is required for ContentBodyFormat_MultiBackDetect")
            # The message previously named ContentBodyFormat_BackDetect (copy-paste slip).
            self._check_confidence("ContentBodyFormat_MultiBackDetect")

        elif isinstance(body, ContentBodyFormat_BackDetect):
            if not body.source_url:
                raise ValueError("source_url is required for ContentBodyFormat_BackDetect")
            self._check_confidence("ContentBodyFormat_BackDetect")

        elif isinstance(body, ContentBodyFormat_BackDetectPic):
            if not body.s3_id:
                raise ValueError("s3_id is required for ContentBodyFormat_BackDetectPic")
            if not body.s3_url:
                raise ValueError("s3_url is required for ContentBodyFormat_BackDetectPic")
            self._check_confidence("ContentBodyFormat_BackDetectPic")

        elif isinstance(body, ContentBodyFormat_Detection):
            if not body.s3_id:
                raise ValueError("s3_id is required for ContentBodyFormat_Detection")
            if not body.s3_url.early or not body.s3_url.later:
                raise ValueError("Both early and later URLs are required for ContentBodyFormat_Detection")

        elif isinstance(body, ContentBodyFormat_Segementation):
            if not body.s3_id:
                raise ValueError("s3_id is required for ContentBodyFormat_Segementation")
            if not body.s3_url:
                raise ValueError("s3_url is required for ContentBodyFormat_Segementation")

        elif isinstance(body, ContentBodyFormat_Sam3Pic):
            if not body.prompt:
                raise ValueError("prompt is required for ContentBodyFormat_Sam3Pic")

    @staticmethod
    def _build_content_body(body_data: Dict[str, Any]) -> BaseContentBody:
        """Create the concrete content body matching the keys in ``body_data``.

        The checks run in a fixed order and the first match wins:
        ``minio_file_path`` -> ``para_list`` -> ``source_url`` ->
        ``s3_id`` + ``s3_url`` -> ``prompt``.

        Raises:
            KeyError: if a field required by the selected type is absent.
            ValueError: if no type matches or ``s3_url`` has an unsupported shape.
        """
        if 'minio_file_path' in body_data:
            return ContentBodyFormat_VideoMultiBackDetect(
                org_code=body_data['org_code'],
                minio_file_path=body_data['minio_file_path'],
                push_url=body_data['push_url'],
                confidence=body_data.get('confidence', 0.5),
                para_list=body_data.get('para_list', []),
                invade=body_data.get('invade', []),
            )

        if 'para_list' in body_data:
            return ContentBodyFormat_MultiBackDetect(
                org_code=body_data['org_code'],
                func_id=body_data['func_id'],
                source_url=body_data['source_url'],
                push_url=body_data['push_url'],
                confidence=body_data.get('confidence', 0.5),
                para_list=body_data.get('para_list', []),
                invade=body_data.get('invade', []),
            )

        if 'source_url' in body_data:
            return ContentBodyFormat_BackDetect(
                mqtt_ip=body_data['mqtt_ip'],
                mqtt_port=body_data['mqtt_port'],
                mqtt_topic=body_data['mqtt_topic'],
                source_url=body_data['source_url'],
                push_url=body_data['push_url'],
                confidence=body_data.get('confidence', 0.5),
                func_id=body_data.get('func_id', []),
                para=body_data.get('para', {}),
            )

        if 's3_id' in body_data and 's3_url' in body_data:
            s3_url = body_data['s3_url']
            if isinstance(s3_url, dict) and 'early' in s3_url and 'later' in s3_url:
                return ContentBodyFormat_Detection(
                    s3_id=body_data['s3_id'],
                    s3_url=EarlyLaterUrls(early=s3_url['early'], later=s3_url['later']),
                    func_id=body_data.get('func_id', []),
                )
            if isinstance(s3_url, list):
                # A list without "confidence" is a segmentation request;
                # with "confidence" it is a picture detection request.
                if 'confidence' not in body_data:
                    return ContentBodyFormat_Segementation(
                        s3_id=body_data['s3_id'],
                        s3_url=s3_url,
                        func_id=body_data.get('func_id', []),
                    )
                return ContentBodyFormat_BackDetectPic(
                    s3_id=body_data['s3_id'],
                    s3_url=s3_url,
                    # mqtt_pub_id=body_data['mqtt_pub_id'],
                    org_code=body_data['org_code'],
                    confidence=body_data.get('confidence', 0.5),
                    func_id=body_data.get('func_id', []),
                    para=body_data.get('para', {}),
                )
            raise ValueError("Invalid s3_url format for ContentBodyFormat_Detection")

        if 'prompt' in body_data:
            return ContentBodyFormat_Sam3Pic(
                img_url=body_data['img_url'],
                prompt=body_data['prompt'],
                confidence=body_data.get('confidence', 0.5),
                mqtt_ip=body_data['mqtt_ip'],
                mqtt_port=body_data['mqtt_port'],
                mqtt_topic=body_data['mqtt_topic'],
            )

        raise ValueError("Invalid content_body format")

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'RequestJson':
        """Build and validate a request from a decoded JSON object.

        Args:
            data: decoded request payload with ``task_id``, ``sn`` and
                ``content_body``.

        Returns:
            A validated ``RequestJson``.

        Raises:
            ValueError: if the payload or ``content_body`` is not a JSON
                object, a required field is missing, the content body matches
                no known format, or validation fails.
        """
        # A missing or non-object body used to surface as TypeError (reported
        # by the handler as a 500); report it as a validation error instead.
        if not isinstance(data, dict):
            raise ValueError("Request body must be a JSON object")

        try:
            task_id = data['task_id']
            sn = data['sn']
            body_data = data['content_body']
            if not isinstance(body_data, dict):
                raise ValueError("content_body must be a JSON object")

            instance = cls(
                task_id=task_id,
                sn=sn,
                content_body=cls._build_content_body(body_data),
            )
            instance.validate()
            return instance
        except KeyError as e:
            raise ValueError(f"Missing required field: {str(e)}") from e
|
|||
|
|
|
|||
|
|
async def safe_stop_ai_video():
    """Stop AI video processing, with error handling and a recovery path.

    ``stopAIVideo`` runs in a worker thread. Returns True when it completes
    without raising. On failure the error is logged and recorded in
    ``service_status``, all tasks are marked as stopped, ffmpeg child
    processes of this process are sent SIGTERM, and after
    ``Config.RESTART_DELAY`` seconds the service is flagged healthy again;
    the function then returns False.
    """
    try:
        await asyncio.to_thread(stopAIVideo)
    except Exception as e:
        logger.error(f"停止AI视频处理出错: {str(e)}\n{traceback.format_exc()}")

        # Flag the service as unhealthy and remember what went wrong.
        service_status.update({
            "is_healthy": False,
            "last_error": str(e),
            "error_time": datetime.now().isoformat(),
        })

        # Force every tracked task into the stopped state.
        task_manager.mark_all_tasks_as_stopped()

        # Best effort: terminate leftover ffmpeg child processes.
        try:
            import os
            import signal
            import psutil

            descendants = psutil.Process(os.getpid()).children(recursive=True)
            for proc in descendants:
                try:
                    proc_name = proc.name().lower()
                    if 'ffmpeg' not in proc_name:
                        continue
                    logger.info(f"强制终止子进程: {proc.pid} ({proc_name})")
                    proc.send_signal(signal.SIGTERM)
                except Exception as child_e:
                    logger.error(f"终止子进程出错: {str(child_e)}")
        except Exception as kill_e:
            logger.error(f"尝试清理进程时出错: {str(kill_e)}")

        # Give the system a moment to settle before accepting work again.
        await asyncio.sleep(Config.RESTART_DELAY)

        # Reset the health flag.
        service_status["is_healthy"] = True
        return False
    else:
        return True
|
|||
|
|
|
|||
|
|
def verify_token(request) -> None:
    """Validate the request's ``X-API-Token`` header.

    Args:
        request: incoming request; only ``request.headers`` is read.

    Raises:
        Unauthorized: if the header is missing, empty, or does not match
            ``Config.VALID_TOKEN``.
    """
    import hmac

    token = request.headers.get('X-API-Token')
    # Compare in constant time so response timing does not reveal how much of
    # the token matched. Both sides are encoded to bytes because
    # compare_digest rejects non-ASCII str arguments.
    if not token or not hmac.compare_digest(
        token.encode("utf-8"), Config.VALID_TOKEN.encode("utf-8")
    ):
        logger.warning("Invalid token attempt")
        raise Unauthorized("Invalid token")
|
|||
|
|
|
|||
|
|
# Strong references to the fire-and-forget processing tasks. The event loop
# only keeps weak references to tasks, so an unreferenced task may be
# garbage-collected before it finishes.
_multi_back_detect_tasks = set()


@multi_back_detect_bp.post("/ai/stream/multi_back_detect")
async def start_multi_back_detection(request):
    """Start multi-model detection on a video stream as a background task.

    The request must carry a valid ``X-API-Token`` header and a JSON body
    whose ``content_body`` parses as ``ContentBodyFormat_MultiBackDetect``.

    Responses:
        200 ``status: success`` - the background task was scheduled.
        200 ``status: failed``  - content_body is of another payload type.
        400 - request validation failed (ValueError).
        500 - the task id already exists, scheduling failed, or an
              unexpected error occurred.
        Sanic exceptions (for example ``Unauthorized`` from the token check)
        propagate so the framework returns their own status code.
    """
    try:
        verify_token(request)

        # Check the service health flag; an unhealthy state is logged and reset.
        if not service_status["is_healthy"]:
            logger.warning(
                f"服务处于不健康状态,上次错误: {service_status['last_error']} 于 {service_status['error_time']}")
            service_status["is_healthy"] = True

        # Stopping all existing tasks first is optional and currently disabled:
        # for task_id in list(task_manager.tasks.keys()):
        #     await task_manager.remove_task(task_id)

        # Parse and validate the request data.
        request_json = RequestJson.from_dict(request.json)
        print(f"/ai/stream/multi_back_detect 请求:{request.json}")
        # Was time.sleep(3), which blocked the whole event loop (and every
        # running stream) for three seconds. Keep the delay, but yield.
        await asyncio.sleep(3)

        if request_json.task_id in task_manager.tasks:
            logger.warning(f"任务 {request_json.task_id} 已存在,跳过创建")
            return json_response({
                "status": "error",
                "message": f"任务 {request_json.task_id} 已存在,跳过创建"
            }, status=500)

        if not isinstance(request_json.content_body, ContentBodyFormat_MultiBackDetect):
            return json_response({
                "status": "failed",
                "message": "content_body structure is wrong"
            })

        try:
            # Stop event handed to the processing coroutine.
            stop_event = asyncio.Event()

            # Wrapper that logs cancellation and failures of the background task.
            async def wrapped_processing():
                try:
                    await run_back_Multi_Detect_async(request, request_json, stop_event)
                except asyncio.CancelledError:
                    logger.info(f"任务 {request_json.task_id} 被取消")
                except Exception as e:
                    logger.error(f"任务 {request_json.task_id} 异常终止: {e}")

            # Create and start the task, keeping a reference until it is done.
            task_handle = asyncio.create_task(wrapped_processing())
            _multi_back_detect_tasks.add(task_handle)
            task_handle.add_done_callback(_multi_back_detect_tasks.discard)
        except Exception as e:
            logger.error(f"启动AI视频处理失败: {e}")
            return json_response({
                "status": "error",
                "message": f"Failed to start AI video processing: {str(e)}"
            }, status=500)

        return json_response({
            "status": "success",
            "task_id": request_json.task_id,
            "message": "Detection started successfully"
        })

    except SanicException:
        # Let Sanic render its own errors (e.g. 401 for an invalid token)
        # instead of masking them as a generic 500.
        raise
    except ValueError as e:
        logger.error(f"Validation error: {str(e)}")
        return json_response({"status": "error", "message": str(e)}, status=400)
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}", exc_info=True)
        return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
|||
|
|
|
|||
|
|
async def run_back_Multi_Detect_async(request, request_json, stop_event: asyncio.Event):
    """Resolve model, MQTT and device settings for a request and run the RTMP
    processing flow until it finishes.

    Args:
        request: the originating request; its string form, ``socket`` and
            ``body`` are written to the request log.
        request_json: parsed request whose ``content_body`` is a
            ``ContentBodyFormat_MultiBackDetect``.
        stop_event: accepted for interface compatibility; not used by this
            function.

    Raises:
        ValueError: if no model configuration is found for any requested
            ``func_id``, or per-parameter invade detection is enabled but the
            ``invade`` settings are incomplete.
        asyncio.CancelledError: re-raised after logging when cancelled.
        Exception: any other failure is logged and re-raised.
    """
    model_configs = []
    task_handle = None

    try:
        # para_list may contain several invade-enable flags; one is enough.
        # (Was sympy's ``false`` rather than the Python bool.)
        invade_enable = False
        # Defaults for values otherwise only bound once a config is found.
        repeat_dis = -1  # de-duplication based on the distance between two frames
        repeat_time = None
        high_count_warn = None

        dao = ModelConfigDAO(DB_CONFIG)

        for para in request_json.content_body.para_list:
            func_id = para["func_id"]
            category = para.get("py_func", [])
            para_invade_enable = para.get("para_invade_enable", False)
            if para_invade_enable:
                invade_enable = True

            query_results = batch_query_model_func_id([func_id], **DB_CONFIG)
            if not query_results:
                continue

            for row in query_results:
                row_func_id = row["model_func_id"]
                config = dao.get_config(func_id, category)
                repeat_dis = -1
                if not config:
                    print(f"未找到ID为 {func_id} 的模型配置")
                    continue

                print("\n模型路径:", config.model_path)
                print("过滤类别:", config.filter_indices)
                print("第一个类别:", asdict(config.classes[0]))
                print("创建时间:", config.created_at)
                print("更新时间:", config.updated_at)
                print("去重的距离:", config.repeat_dis)
                repeat_dis = config.repeat_dis
                repeat_time = config.repeat_time
                high_count_warn = config.high_count_warn
                print(f"config.high_count_warn {config.high_count_warn}")
                model_configs.append(
                    {
                        'path': config.model_path,
                        # NOTE(review): engine_path / so_path are hard-coded
                        # test paths; the config.engine_path / config.so_path
                        # values are not used yet.
                        'engine_path': r"D:\project\AI-PYTHON\Ai_tottle\engine\renche\renche.engine",
                        'so_path': r"D:\project\AI-PYTHON\Ai_tottle\engine\renche\myplugins.dll",
                        'cls_map': config.cls_zn_to_eh_dict,
                        'allowed_classes': config.allowed_classes,
                        "cls_index": config.class_indices,
                        "class_names": config.cls_names,
                        "chinese_label": config.cls_en_dict,
                        "list_func_id": row_func_id,
                        "func_id": func_id,
                        "para_invade_enable": para_invade_enable,
                        "config_conf": config.conf
                    }
                )

        # Previously this case failed later with UnboundLocalError on
        # high_count_warn; fail with an explicit message instead.
        if not model_configs:
            raise ValueError("No model configuration found for the requested func_id values")

        video_url = request_json.content_body.source_url
        sn = request_json.sn
        task_id = request_json.task_id
        org_code = request_json.content_body.org_code
        push_url = request_json.content_body.push_url

        # from_dict defaults ``invade`` to [] when the field is absent, so
        # treat anything that is not a mapping as "no invade settings".
        invade = request_json.content_body.invade
        if not isinstance(invade, dict):
            invade = {}
        invade_file = invade.get("invade_file")
        camera_para_url = invade.get("camera_para_url")
        invade_switch = invade.get("invade_switch", 0)
        if invade_enable and (invade_file is None or camera_para_url is None):
            raise ValueError(
                "invade.invade_file and invade.camera_para_url are required "
                "when para_invade_enable is set"
            )

        if high_count_warn is None:
            high_count_warn = 0

        # To be tested: whether the client IP can be captured over the public network.
        str_request = str(request) + "&" + str(request.socket)
        dao.insert_request_log(task_id, sn, org_code, str(request.body), str_request)

        mqtt_pub_config = dao.get_mqtt_config_by_orgcode(org_code, sn, "pub")
        mqtt_sub_config = dao.get_mqtt_config_by_orgcode(org_code, sn, "sub")

        mqtt_pub_ip = mqtt_pub_config.mqtt_ip
        mqtt_pub_port = mqtt_pub_config.mqtt_port
        mqtt_pub_topic = mqtt_pub_config.mqtt_topic
        print(f"mqtt_pub_topic {mqtt_pub_topic}")

        mqtt_sub_ip = mqtt_sub_config.mqtt_ip
        mqtt_sub_port = mqtt_sub_config.mqtt_port
        # The stored subscribe topic is a template with an {sn} placeholder.
        mqtt_sub_topic = mqtt_sub_config.mqtt_topic.format(sn=sn)
        print(f"mqtt_sub_topic {mqtt_sub_topic}")

        # Result is currently unused; the call is kept in case it has side
        # effects — TODO confirm and remove if it has none.
        read_local_func_config()

        device = dao.get_device(sn, org_code)
        print(f"device表 {sn} {org_code}")
        device_height = device.height  # dock height, later used for on-site elevation calculation

        async def process_flow():
            """Run the RTMP processing pipeline, logging any failure."""
            try:
                await start_rtmp_processing(
                    video_url,
                    request_json.task_id,
                    model_configs,
                    mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic,
                    mqtt_sub_ip, mqtt_sub_port, mqtt_sub_topic,
                    push_url,
                    invade_enable, invade_switch, invade_file, camera_para_url,
                    device_height, repeat_dis, repeat_time, high_count_warn
                )
            except Exception as e:
                logger.error(f"处理流程异常: {e}")
                raise

        # Run the flow as a real task so a handle exists for the bookkeeping
        # below (it used to be None, which made ``await task_handle`` raise
        # TypeError and skipped remove_task). shield() keeps the flow running
        # if this coroutine is cancelled, as before.
        task_handle = asyncio.create_task(process_flow())
        await asyncio.shield(task_handle)

        # NOTE(review): as in the original code, the task is registered only
        # after the flow has completed, so task_manager does not see it while
        # it runs — confirm whether start_rtmp_processing registers the task
        # itself before moving this registration earlier.
        task_info = {
            "source_url": video_url,
            "push_url": push_url,
            "status": "running",
            "task_handle": task_handle,
            "model_configs": model_configs,
            "device_height": device_height,
            "repeat_dis": repeat_dis,
            "repeat_time": repeat_time
        }

        await task_manager.add_task(
            task_id,
            task_info,
            task_handle,
            []  # no sub-tasks for now
        )

        # Wait for the task to complete or be cancelled.
        try:
            await task_handle
        except asyncio.CancelledError:
            pass  # cancellation is an expected outcome
        await task_manager.remove_task(task_id)

    except asyncio.CancelledError:
        logger.info(f"任务 {request_json.task_id} 收到停止信号,正在清理...")
        raise
    except Exception as e:
        logger.error(f"任务 {request_json.task_id} 处理失败: {e}")
        raise
|