645 lines
26 KiB
Python
645 lines
26 KiB
Python
import asyncio
import json
import logging
import os
import time
import traceback
from dataclasses import asdict, dataclass
from datetime import datetime
from typing import Any, Dict, List

from sanic import Blueprint
from sanic.exceptions import SanicException, Unauthorized
from sanic.response import json as json_response
from sympy import false
|
||
|
||
|
||
try :
|
||
from middleware.TaskManager import task_manager
|
||
from middleware.query_model import ModelConfigDAO
|
||
from middleware.query_postgress import batch_query_model_func_id
|
||
from middleware.read_yolo_config import read_local_func_config
|
||
from yolo.cv_multi_model_back_video import start_rtmp_processing
|
||
from cv_video import startAIVideo, stopAIVideo
|
||
from cv_back_video import startBackAIVideo
|
||
except Exception as e:
|
||
import sys
|
||
from pathlib import Path
|
||
# 获取项目根目录(假设脚本在根目录的子目录中)
|
||
ROOT_DIR = Path(__file__).parent.parent # 根据实际层级调整
|
||
sys.path.append(str(ROOT_DIR))
|
||
from middleware.TaskManager import task_manager
|
||
from middleware.query_model import ModelConfigDAO
|
||
from middleware.query_postgress import batch_query_model_func_id
|
||
from middleware.read_yolo_config import read_local_func_config
|
||
from yolo.cv_multi_model_back_video import start_rtmp_processing
|
||
from cv_video import startAIVideo, stopAIVideo
|
||
from cv_back_video import startBackAIVideo
|
||
|
||
|
||
# Configuration class.
class Config:
    """Service-level settings for the multi-back-detect endpoints.

    ``VALID_TOKEN`` can be overridden with the ``MULTI_BACK_DETECT_API_TOKEN``
    environment variable. The literal fallback keeps the previous behaviour
    when the variable is unset.
    """

    # NOTE(review): a secret committed to source control should be rotated and
    # supplied only through the environment.
    VALID_TOKEN = os.environ.get(
        "MULTI_BACK_DETECT_API_TOKEN", "5e8899fe-dc74-4280-8169-2f4d185f3afa"
    )

    MAX_ACTIVE_TASKS = 10

    DEFAULT_CONFIDENCE = 0.5

    RESTART_DELAY = 2  # delay (seconds) before the service attempts automatic recovery
|
||
|
||
# Production database connection settings.
# Each key can be overridden through an environment variable. The literal
# fallbacks keep the previous defaults when the variables are unset.
# NOTE(review): committed credentials should be rotated and supplied only via the environment.
DB_CONFIG = {
    "dbname": os.environ.get("AI_DB_NAME", "smart_dev"),
    "user": os.environ.get("AI_DB_USER", "postgres"),
    "password": os.environ.get("AI_DB_PASSWORD", "StrongPassword@123"),
    "host": os.environ.get("AI_DB_HOST", "222.212.85.86"),
    "port": os.environ.get("AI_DB_PORT", "5061"),
}

# Service health flag: safe_stop_ai_video marks the service unhealthy when
# stopping fails, and the start endpoint logs and resets it.
service_status = {"is_healthy": True, "last_error": None, "error_time": None}
|
||
|
||
# Logging: INFO-level root configuration with timestamp, logger name and level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

logger = logging.getLogger(__name__)

# Blueprint holding this module's routes; registered with no URL prefix.
multi_back_detect_bp = Blueprint("multi_back_detect", url_prefix="")
|
||
|
||
@dataclass
class BaseContentBody:
    """Common base for every content-body dataclass; shared fields would go here."""
|
||
|
||
|
||
@dataclass
class ContentBodyFormat_VideoMultiBackDetect(BaseContentBody):
    """Content body that references a MinIO object via ``minio_file_path``.

    RequestJson.from_dict picks this type whenever ``minio_file_path`` is present.
    """

    org_code: str
    minio_file_path: str
    push_url: str  # temporary, for testing
    confidence: float
    para_list: list
    # NOTE(review): annotated as list, but the handler for the sibling
    # ContentBodyFormat_MultiBackDetect reads ``invade`` as a mapping; confirm
    # the expected shape for this body type.
    invade: list
|
||
|
||
|
||
@dataclass
class ContentBodyFormat_Sam3Pic(BaseContentBody):
    """Content body for a prompt-based picture request.

    RequestJson.from_dict selects it when ``prompt`` is present and no
    higher-priority marker key matched.
    """

    img_url: str
    prompt: str
    # Annotated float: RequestJson.from_dict defaults it to 0.5, which the
    # previous ``int`` annotation could not describe.
    confidence: float
    mqtt_ip: str
    mqtt_port: int
    mqtt_topic: str
|
||
|
||
|
||
@dataclass
class ContentBodyFormat_MultiBackDetect(BaseContentBody):
    """Content body accepted by the /ai/stream/multi_back_detect endpoint.

    RequestJson.from_dict selects it when ``para_list`` is present (and
    ``minio_file_path`` is not).
    """

    org_code: str
    func_id: list
    source_url: str
    push_url: str  # temporary, for testing
    confidence: float
    # Each entry is a mapping read by run_back_Multi_Detect_async: "func_id"
    # (required), "py_func" and "para_invade_enable" (optional).
    para_list: list
    # A mapping, not a list: run_back_Multi_Detect_async reads "invade_file" and
    # "camera_para_url" (required) and "invade_switch" (optional) from it.
    invade: dict
|
||
|
||
|
||
@dataclass
class ContentBodyFormat_BackDetect(BaseContentBody):
    """Content body with a source URL and explicit MQTT settings.

    RequestJson.from_dict selects it when ``source_url`` is present but neither
    ``minio_file_path`` nor ``para_list`` is.
    """

    mqtt_ip: str
    mqtt_port: int
    mqtt_topic: str
    source_url: str
    push_url: str  # temporary, for testing
    confidence: float
    func_id: List[int]
    # Free-form extra parameters (from_dict defaults to {}). The previous
    # annotation was the dict literal ``{}``, which is a value, not a type.
    para: dict
|
||
|
||
|
||
@dataclass
class ContentBodyFormat_BackDetectPic(BaseContentBody):
    """Content body with a list of S3 URLs plus an explicit confidence.

    RequestJson.from_dict selects it when ``s3_url`` is a list and the payload
    contains a ``confidence`` key.
    """

    s3_id: int
    s3_url: list[str]
    org_code: str
    confidence: float
    func_id: list[int]
    # Free-form extra parameters (from_dict defaults to {}). The previous
    # annotation was the dict literal ``{}``, which is a value, not a type.
    para: dict
|
||
|
||
|
||
@dataclass
class EarlyLaterUrls:
    """A pair of URLs labelled "early" and "later"."""

    early: str  # URL taken from s3_url["early"]
    later: str  # URL taken from s3_url["later"]
|
||
|
||
|
||
@dataclass
class ContentBodyFormat_Detection(BaseContentBody):
    """Content body carrying an early/later URL pair under one S3 id.

    RequestJson.from_dict selects it when ``s3_url`` is a mapping that has both
    ``early`` and ``later`` keys.
    """

    s3_id: int
    s3_url: EarlyLaterUrls  # the "early" and "later" URLs
    func_id: list[int]
|
||
|
||
|
||
@dataclass
class ContentBodyFormat_Segementation(BaseContentBody):
    """Content body with a list of S3 URLs and no confidence value.

    RequestJson.from_dict selects it when ``s3_url`` is a list and the payload
    has no ``confidence`` key. The class name's spelling is kept for callers.
    """

    s3_id: int
    s3_url: list[str]
    func_id: list[int]
|
||
|
||
|
||
@dataclass
class RequestJson:
    """Parsed request envelope: task id, device serial number and a typed content body."""

    task_id: str
    sn: str
    content_body: BaseContentBody

    @staticmethod
    def _check_confidence(confidence, type_name: str) -> None:
        """Raise ValueError unless ``confidence`` is None or a number strictly inside (0, 1).

        Non-numeric values (e.g. a JSON string) are rejected with ValueError
        instead of letting the comparison raise TypeError.
        """
        if confidence is None:
            return
        if not isinstance(confidence, (int, float)) or not 0 < confidence < 1:
            raise ValueError(f"Confidence must be between 0 and 1 for {type_name}")

    def validate(self) -> None:
        """Validate request parameters.

        Raises:
            ValueError: if ``task_id`` is empty or the content body misses a
                required value / has an out-of-range confidence.
        """
        if not self.task_id:
            raise ValueError("task_id is required")

        body = self.content_body
        if isinstance(body, ContentBodyFormat_VideoMultiBackDetect):
            if not body.minio_file_path:
                raise ValueError("minio_file_path is required for ContentBodyFormat_VideoMultiBackDetect")
            self._check_confidence(body.confidence, "ContentBodyFormat_VideoMultiBackDetect")

        elif isinstance(body, ContentBodyFormat_MultiBackDetect):
            if not body.para_list:
                raise ValueError("para_list is required for ContentBodyFormat_MultiBackDetect")
            self._check_confidence(body.confidence, "ContentBodyFormat_MultiBackDetect")

        elif isinstance(body, ContentBodyFormat_BackDetect):
            if not body.source_url:
                raise ValueError("source_url is required for ContentBodyFormat_BackDetect")
            self._check_confidence(body.confidence, "ContentBodyFormat_BackDetect")

        elif isinstance(body, ContentBodyFormat_BackDetectPic):
            if not body.s3_id:
                raise ValueError("s3_id is required for ContentBodyFormat_BackDetectPic")
            if not body.s3_url:
                raise ValueError("s3_url is required for ContentBodyFormat_BackDetectPic")
            self._check_confidence(body.confidence, "ContentBodyFormat_BackDetectPic")

        elif isinstance(body, ContentBodyFormat_Detection):
            if not body.s3_id:
                raise ValueError("s3_id is required for ContentBodyFormat_Detection")
            if not body.s3_url.early or not body.s3_url.later:
                raise ValueError("Both early and later URLs are required for ContentBodyFormat_Detection")

        elif isinstance(body, ContentBodyFormat_Segementation):
            if not body.s3_id:
                raise ValueError("s3_id is required for ContentBodyFormat_Segementation")
            if not body.s3_url:
                raise ValueError("s3_url is required for ContentBodyFormat_Segementation")

        elif isinstance(body, ContentBodyFormat_Sam3Pic):
            if not body.prompt:
                raise ValueError("prompt is required for ContentBodyFormat_Sam3Pic")

    @staticmethod
    def _parse_content_body(body: Dict[str, Any]) -> BaseContentBody:
        """Build the content-body dataclass that matches the marker keys in ``body``.

        Marker keys are checked in a fixed priority order (minio_file_path,
        para_list, source_url, s3_id+s3_url, prompt); the first match wins.

        Raises:
            KeyError: a required key of the selected type is missing
                (from_dict converts it to ValueError).
            ValueError: no known shape matches.
        """
        if 'minio_file_path' in body:
            return ContentBodyFormat_VideoMultiBackDetect(
                org_code=body['org_code'],
                minio_file_path=body['minio_file_path'],
                push_url=body['push_url'],
                confidence=body.get('confidence', 0.5),
                para_list=body.get('para_list', []),
                invade=body.get('invade', [])
            )

        if 'para_list' in body:
            return ContentBodyFormat_MultiBackDetect(
                org_code=body['org_code'],
                func_id=body['func_id'],
                source_url=body['source_url'],
                push_url=body['push_url'],
                confidence=body.get('confidence', 0.5),
                para_list=body.get('para_list', []),
                invade=body.get('invade', [])
            )

        if 'source_url' in body:
            return ContentBodyFormat_BackDetect(
                mqtt_ip=body['mqtt_ip'],
                mqtt_port=body['mqtt_port'],
                mqtt_topic=body['mqtt_topic'],
                source_url=body['source_url'],
                push_url=body['push_url'],
                confidence=body.get('confidence', 0.5),
                func_id=body.get('func_id', []),
                para=body.get('para', {})
            )

        if 's3_id' in body and 's3_url' in body:
            s3_url = body['s3_url']
            if isinstance(s3_url, dict) and 'early' in s3_url and 'later' in s3_url:
                return ContentBodyFormat_Detection(
                    s3_id=body['s3_id'],
                    s3_url=EarlyLaterUrls(
                        early=s3_url['early'],
                        later=s3_url['later']
                    ),
                    func_id=body.get('func_id', [])
                )
            if isinstance(s3_url, list) and 'confidence' not in body:
                return ContentBodyFormat_Segementation(
                    s3_id=body['s3_id'],
                    s3_url=s3_url,
                    func_id=body.get('func_id', [])
                )
            if isinstance(s3_url, list):
                return ContentBodyFormat_BackDetectPic(
                    s3_id=body['s3_id'],
                    s3_url=s3_url,
                    org_code=body['org_code'],
                    confidence=body.get('confidence', 0.5),
                    func_id=body.get('func_id', []),
                    para=body.get('para', {})
                )
            raise ValueError("Invalid s3_url format for ContentBodyFormat_Detection")

        if 'prompt' in body:
            return ContentBodyFormat_Sam3Pic(
                img_url=body['img_url'],
                prompt=body['prompt'],
                confidence=body.get('confidence', 0.5),
                mqtt_ip=body['mqtt_ip'],
                mqtt_port=body['mqtt_port'],
                mqtt_topic=body['mqtt_topic']
            )

        raise ValueError("Invalid content_body format")

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'RequestJson':
        """Parse and validate a decoded JSON request body.

        Raises:
            ValueError: the body or ``content_body`` is not a JSON object, a
                required field is missing, the content body has an unknown
                shape, or validation fails.
        """
        # Guard non-object bodies (e.g. request.json is None) so callers see a
        # ValueError (400) rather than a TypeError (500). This also stops `in`
        # from doing substring tests when content_body is a string.
        if not isinstance(data, dict):
            raise ValueError("Request body must be a JSON object")
        try:
            task_id = data['task_id']
            sn = data['sn']
            content_body_data = data['content_body']
            if not isinstance(content_body_data, dict):
                raise ValueError("content_body must be a JSON object")
            content_body = cls._parse_content_body(content_body_data)
        except KeyError as e:
            raise ValueError(f"Missing required field: {str(e)}") from e

        instance = cls(
            task_id=task_id,
            sn=sn,
            content_body=content_body
        )
        instance.validate()
        return instance
|
||
|
||
def _terminate_ffmpeg_children():
    """Best-effort: send SIGTERM to every descendant process whose name contains 'ffmpeg'.

    Errors are logged, never raised.
    """
    try:
        import os
        import signal

        import psutil

        this_proc = psutil.Process(os.getpid())
        # Look for ffmpeg descendants and terminate them.
        for proc in this_proc.children(recursive=True):
            try:
                proc_name = proc.name().lower()
                if 'ffmpeg' in proc_name:
                    logger.info(f"强制终止子进程: {proc.pid} ({proc_name})")
                    proc.send_signal(signal.SIGTERM)
            except Exception as child_e:
                logger.error(f"终止子进程出错: {str(child_e)}")
    except Exception as kill_e:
        logger.error(f"尝试清理进程时出错: {str(kill_e)}")


async def safe_stop_ai_video():
    """Stop AI video processing, with forced cleanup if the normal stop fails.

    Returns:
        True if ``stopAIVideo`` (run in a worker thread) finished without error.
        False otherwise. In that case the error is logged, ``service_status``
        is set unhealthy with the error text and timestamp, all tasks are
        marked stopped, ffmpeg child processes get SIGTERM, and after
        ``Config.RESTART_DELAY`` seconds the service is flagged healthy again.
    """
    try:
        await asyncio.to_thread(stopAIVideo)
    except Exception as exc:
        logger.error(f"停止AI视频处理出错: {str(exc)}\n{traceback.format_exc()}")

        # Flag the service as unhealthy.
        service_status.update(
            is_healthy=False,
            last_error=str(exc),
            error_time=datetime.now().isoformat(),
        )

        # Force every task to stopped, then clean up leftover processes.
        task_manager.mark_all_tasks_as_stopped()
        _terminate_ffmpeg_children()

        # Give the system a moment to recover, then reset the flag.
        await asyncio.sleep(Config.RESTART_DELAY)
        service_status["is_healthy"] = True
        return False
    else:
        return True
|
||
|
||
def verify_token(request) -> None:
    """Reject the request unless its ``X-API-Token`` header equals ``Config.VALID_TOKEN``.

    Raises:
        Unauthorized: the header is missing, empty or does not match.
    """
    import hmac

    token = request.headers.get('X-API-Token')
    # compare_digest runs in time independent of where the strings differ, so
    # response timing does not leak how much of a guessed token was correct.
    # Comparing bytes avoids the TypeError compare_digest raises for non-ASCII str.
    if not token or not hmac.compare_digest(
        token.encode("utf-8"), Config.VALID_TOKEN.encode("utf-8")
    ):
        logger.warning("Invalid token attempt")
        raise Unauthorized("Invalid token")
|
||
|
||
# Strong references to the fire-and-forget detection tasks started below. The
# event loop keeps only weak references to tasks, so without this set a running
# task could be garbage-collected before it finishes.
_background_tasks = set()


@multi_back_detect_bp.post("/ai/stream/multi_back_detect")
async def start_multi_back_detection(request):
    """Start a multi-model detection task in the background.

    Steps: check the API token; log and reset an unhealthy service flag; parse
    and validate the body into a RequestJson; pause 3 s; reject duplicate task
    ids; then schedule run_back_Multi_Detect_async without awaiting it.

    Responses:
        200 "success" once the background task is scheduled.
        200 "failed" if content_body is not a ContentBodyFormat_MultiBackDetect.
        400 on validation errors (ValueError).
        500 on a duplicate task id, a scheduling failure or any unexpected error.
        Sanic exceptions (e.g. 401 Unauthorized from verify_token) are re-raised
        so Sanic renders them with their own status instead of a 500.
    """
    try:
        verify_token(request)

        # Check the service health flag.
        if not service_status["is_healthy"]:
            logger.warning(
                f"服务处于不健康状态,上次错误: {service_status['last_error']} 于 {service_status['error_time']}")
            service_status["is_healthy"] = True

        # Parse and validate the request data.
        request_json = RequestJson.from_dict(request.json)
        logger.info(f"/ai/stream/multi_back_detect 请求:{request.json}")

        # Keep the original 3-second pause, but yield to the event loop instead
        # of blocking it with time.sleep (which stalled every other request).
        await asyncio.sleep(3)

        # NOTE(review): the task is only registered in task_manager once the
        # background task reaches that point, so two near-simultaneous requests
        # with the same task_id can both pass this check.
        if request_json.task_id in task_manager.tasks:
            logger.warning(f"任务 {request_json.task_id} 已存在,跳过创建")
            return json_response({
                "status": "error",
                "message": f"任务 {request_json.task_id} 已存在,跳过创建"
            }, status=500)

        if not isinstance(request_json.content_body, ContentBodyFormat_MultiBackDetect):
            return json_response({
                "status": "failed",
                "message": "content_body structure is wrong"
            })

        try:
            # Stop event handed to the processing coroutine.
            stop_event = asyncio.Event()

            async def wrapped_processing():
                """Run the detection pipeline, logging cancellation and failures."""
                try:
                    await run_back_Multi_Detect_async(request, request_json, stop_event)
                except asyncio.CancelledError:
                    logger.info(f"任务 {request_json.task_id} 被取消")
                except Exception as e:
                    logger.error(f"任务 {request_json.task_id} 异常终止: {e}", exc_info=True)

            task_handle = asyncio.create_task(wrapped_processing())
            _background_tasks.add(task_handle)
            task_handle.add_done_callback(_background_tasks.discard)
        except Exception as e:
            logger.error(f"启动AI视频处理失败: {e}")
            return json_response({
                "status": "error",
                "message": f"Failed to start AI video processing: {str(e)}"
            }, status=500)

        return json_response({
            "status": "success",
            "task_id": request_json.task_id,
            "message": "Detection started successfully"
        })

    except SanicException:
        # Let Sanic produce the proper status (401 for Unauthorized, 400 for bad JSON, ...).
        raise
    except ValueError as e:
        logger.error(f"Validation error: {str(e)}")
        return json_response({"status": "error", "message": str(e)}, status=400)
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}", exc_info=True)
        return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
||
|
||
def _collect_model_configs(dao, para_list):
    """Resolve every ``para_list`` entry into model-config dicts for start_rtmp_processing.

    Returns:
        (model_configs, invade_enable, repeat_dis, repeat_time, high_count_warn)
        where ``invade_enable`` is True if any entry sets ``para_invade_enable``.
        The repeat_* / high_count_warn values come from the last matching config.
        ``repeat_dis`` is reset to -1 for each result row, as before.
    """
    model_configs = []
    invade_enable = False  # True as soon as any entry enables intrusion detection
    repeat_dis = -1  # de-duplication distance between two frames
    repeat_time = None
    high_count_warn = None

    for para in para_list:
        func_id = para["func_id"]
        category = para.get("py_func", [])
        para_invade_enable = para.get("para_invade_enable", False)
        if para_invade_enable:
            invade_enable = True

        query_results = batch_query_model_func_id([func_id], **DB_CONFIG)
        if not query_results:
            continue

        for row in query_results:
            row_func_id = row["model_func_id"]
            config = dao.get_config(func_id, category)
            repeat_dis = -1
            if not config:
                logger.warning(f"未找到ID为 {func_id} 的模型配置")
                continue

            logger.info(
                "模型路径: %s | 过滤类别: %s | 第一个类别: %s | 创建时间: %s | 更新时间: %s | 去重的距离: %s | high_count_warn: %s",
                config.model_path,
                config.filter_indices,
                # Guarded: the old unconditional classes[0] crashed on an empty list.
                asdict(config.classes[0]) if config.classes else None,
                config.created_at,
                config.updated_at,
                config.repeat_dis,
                config.high_count_warn,
            )
            repeat_dis = config.repeat_dis
            repeat_time = config.repeat_time
            high_count_warn = config.high_count_warn

            model_configs.append(
                {
                    'path': config.model_path,
                    # TODO(review): hard-coded test engine/plugin paths are used for
                    # every model instead of config.engine_path / config.so_path.
                    'engine_path': r"D:\project\AI-PYTHON\Ai_tottle\engine\renche\renche.engine",
                    'so_path': r"D:\project\AI-PYTHON\Ai_tottle\engine\renche\myplugins.dll",
                    'cls_map': config.cls_zn_to_eh_dict,
                    'allowed_classes': config.allowed_classes,
                    "cls_index": config.class_indices,
                    "class_names": config.cls_names,
                    "chinese_label": config.cls_en_dict,
                    "list_func_id": row_func_id,
                    "func_id": func_id,
                    "para_invade_enable": para_invade_enable,
                    "config_conf": config.conf
                }
            )

    return model_configs, invade_enable, repeat_dis, repeat_time, high_count_warn


async def run_back_Multi_Detect_async(request, request_json, stop_event: asyncio.Event):
    """Resolve configuration for a multi-model detection request and run it.

    Steps: build model configs from ``content_body.para_list``; read the
    ``invade`` settings; write a request-log row; load the pub/sub MQTT settings
    and the device record for (sn, org_code); then run start_rtmp_processing as
    a task. That task is registered with task_manager while it runs and removed
    when it ends.

    Args:
        request: the originating Sanic request (stringified into the request log).
        request_json: parsed RequestJson whose content_body is a
            ContentBodyFormat_MultiBackDetect.
        stop_event: currently unused; kept for interface compatibility.

    Raises:
        ValueError: no model config, MQTT config or device record was found.
        Any DAO or pipeline exception is logged and re-raised.
    """
    task_id = request_json.task_id
    try:
        dao = ModelConfigDAO(DB_CONFIG)
        content_body = request_json.content_body

        (model_configs, invade_enable, repeat_dis, repeat_time,
         high_count_warn) = _collect_model_configs(dao, content_body.para_list)
        # Previously an empty result crashed later with UnboundLocalError.
        if not model_configs:
            raise ValueError(f"No model configuration found for task {task_id}")
        if high_count_warn is None:
            high_count_warn = 0

        video_url = content_body.source_url
        sn = request_json.sn
        org_code = content_body.org_code
        push_url = content_body.push_url

        # ``invade`` defaults to [] in from_dict; normalise so lookups fail with
        # KeyError on the missing field rather than TypeError on a list.
        invade = content_body.invade or {}
        invade_file = invade["invade_file"]
        camera_para_url = invade["camera_para_url"]
        invade_switch = invade.get("invade_switch", 0)

        str_request = str(request) + "&" + str(request.socket)  # to verify: can the client IP be captured over the public network?
        dao.insert_request_log(task_id, sn, org_code, str(request.body), str_request)

        mqtt_pub_config = dao.get_mqtt_config_by_orgcode(org_code, sn, "pub")
        mqtt_sub_config = dao.get_mqtt_config_by_orgcode(org_code, sn, "sub")
        if mqtt_pub_config is None or mqtt_sub_config is None:
            raise ValueError(f"MQTT pub/sub configuration not found for org_code={org_code} sn={sn}")

        mqtt_pub_ip = mqtt_pub_config.mqtt_ip
        mqtt_pub_port = mqtt_pub_config.mqtt_port
        mqtt_pub_topic = mqtt_pub_config.mqtt_topic
        logger.info(f"mqtt_pub_topic {mqtt_pub_topic}")

        mqtt_sub_ip = mqtt_sub_config.mqtt_ip
        mqtt_sub_port = mqtt_sub_config.mqtt_port
        mqtt_sub_topic = mqtt_sub_config.mqtt_topic.format(sn=sn)
        logger.info(f"mqtt_sub_topic {mqtt_sub_topic}")

        # Result currently unused; call kept so any errors it raises still surface.
        read_local_func_config()

        device = dao.get_device(sn, org_code)
        logger.info(f"device表 {sn} {org_code}")
        if device is None:
            raise ValueError(f"Device not found for sn={sn} org_code={org_code}")
        device_height = device.height  # dock height, later used for on-site elevation calculations

        async def process_flow():
            """Run the RTMP processing pipeline with the resolved settings."""
            try:
                await start_rtmp_processing(
                    video_url,
                    task_id,
                    model_configs,
                    mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic,
                    mqtt_sub_ip, mqtt_sub_port, mqtt_sub_topic,
                    push_url,
                    invade_enable, invade_switch, invade_file, camera_para_url,
                    device_height, repeat_dis, repeat_time, high_count_warn
                )
            except Exception as e:
                logger.error(f"处理流程异常: {e}")
                raise

        # Register the running task (with a real handle) before waiting on it.
        # Previously task_handle was always None: the task was only added after
        # processing had finished, and `await None` then raised TypeError, so
        # remove_task never ran.
        task_handle = asyncio.create_task(process_flow())
        task_info = {
            "source_url": video_url,
            "push_url": push_url,
            "status": "running",
            "task_handle": task_handle,
            "model_configs": model_configs,
            "device_height": device_height,
            "repeat_dis": repeat_dis,
            "repeat_time": repeat_time
        }
        try:
            await task_manager.add_task(task_id, task_info, task_handle, [])
        except BaseException:
            task_handle.cancel()
            raise

        try:
            await task_handle
        except asyncio.CancelledError:
            pass  # cancellation of the processing task is a normal way to stop
        finally:
            await task_manager.remove_task(task_id)

    except asyncio.CancelledError:
        logger.info(f"任务 {task_id} 收到停止信号,正在清理...")
        raise
    except Exception as e:
        logger.error(f"任务 {task_id} 处理失败: {e}", exc_info=True)
        raise