import asyncio
import json
import logging
import os
import secrets
import signal
import sys
import time
import traceback
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List

from sanic import Blueprint
from sanic.exceptions import SanicException, Unauthorized
from sanic.response import json as json_response
# NOTE: sympy's `false` (a BooleanFalse object) is no longer used in this
# module -- plain `False` is used instead. The import is kept so the module's
# dependencies do not change in this edit.
from sympy import false  # noqa: F401

try:
    from middleware.TaskManager import task_manager
    from middleware.query_model import ModelConfigDAO
    from middleware.query_postgress import batch_query_model_func_id
    from middleware.read_yolo_config import read_local_func_config
    from yolo.cv_multi_model_back_video import start_rtmp_processing
    from cv_video import startAIVideo, stopAIVideo
    from cv_back_video import startBackAIVideo
except ImportError:
    # The project packages are not importable from the current working
    # directory: add the project root (assumed to be two levels above this
    # file -- adjust if the layout changes) to sys.path and retry.
    ROOT_DIR = Path(__file__).parent.parent
    sys.path.append(str(ROOT_DIR))
    from middleware.TaskManager import task_manager
    from middleware.query_model import ModelConfigDAO
    from middleware.query_postgress import batch_query_model_func_id
    from middleware.read_yolo_config import read_local_func_config
    from yolo.cv_multi_model_back_video import start_rtmp_processing
    from cv_video import startAIVideo, stopAIVideo
    from cv_back_video import startBackAIVideo


class Config:
    """Static service settings.

    Secrets can be overridden through environment variables; the literal
    fallbacks keep existing deployments working unchanged.
    """

    # Expected value of the X-API-Token request header.
    VALID_TOKEN = os.environ.get("AI_API_TOKEN", "5e8899fe-dc74-4280-8169-2f4d185f3afa")
    MAX_ACTIVE_TASKS = 10
    DEFAULT_CONFIDENCE = 0.5
    RESTART_DELAY = 2  # seconds to wait before the service tries to recover
    # Seconds the start endpoint waits before creating a task (awaited, so the
    # event loop is not blocked).
    REQUEST_DELAY = 3
    # TODO: test-only TensorRT engine / plugin paths. They currently override
    # config.engine_path / config.so_path from the model config table.
    TEST_ENGINE_PATH = r"D:\project\AI-PYTHON\Ai_tottle\engine\renche\renche.engine"
    TEST_SO_PATH = r"D:\project\AI-PYTHON\Ai_tottle\engine\renche\myplugins.dll"


# Production database connection settings (each value may be overridden via
# the environment).
DB_CONFIG = {
    "dbname": os.environ.get("AI_DB_NAME", "smart_dev"),
    "user": os.environ.get("AI_DB_USER", "postgres"),
    "password": os.environ.get("AI_DB_PASSWORD", "StrongPassword@123"),
    "host": os.environ.get("AI_DB_HOST", "222.212.85.86"),
    "port": os.environ.get("AI_DB_PORT", "5061"),
}

# Service health flags, updated by safe_stop_ai_video and read by the start endpoint.
service_status = {"is_healthy": True, "last_error": None, "error_time": None}

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

multi_back_detect_bp = Blueprint("multi_back_detect", url_prefix="")

# The event loop keeps only weak references to tasks; hold strong references
# to fire-and-forget background tasks so they cannot be garbage-collected
# mid-execution.
_background_tasks: set = set()


@dataclass
class BaseContentBody:
    """Marker base class for all ``content_body`` variants (shared fields would go here)."""


@dataclass
class ContentBodyFormat_VideoMultiBackDetect(BaseContentBody):
    """content_body for multi-model detection on a video file referenced by a MinIO path."""
    # mqtt_pub_id: int
    org_code: str
    # mqtt_ip: str
    # mqtt_port: int
    # mqtt_topic: str
    minio_file_path: str
    push_url: str  # temporary, for testing
    confidence: float
    para_list: list
    # invade_file: str
    invade: list


@dataclass
class ContentBodyFormat_Sam3Pic(BaseContentBody):
    """content_body for a prompt-driven image request (identified by the ``prompt`` key)."""
    img_url: str
    prompt: str
    confidence: float
    mqtt_ip: str
    mqtt_port: int
    mqtt_topic: str


@dataclass
class ContentBodyFormat_MultiBackDetect(BaseContentBody):
    """content_body for multi-model detection on a live stream (identified by ``para_list``)."""
    # mqtt_pub_id: int
    # mqtt_sub_id: int
    # mqtt_pub_ip: str
    # mqtt_pub_port: int
    # mqtt_pub_topic: str
    # mqtt_sub_ip: str
    # mqtt_sub_port: int
    # mqtt_sub_topic: str
    org_code: str
    func_id: list
    source_url: str
    push_url: str  # temporary, for testing
    confidence: float
    para_list: list
    # invade_file: str
    # JSON object in practice: run_back_Multi_Detect_async reads the keys
    # "invade_file", "camera_para_url" and "invade_switch" from it.
    invade: Any


@dataclass
class ContentBodyFormat_BackDetect(BaseContentBody):
    """content_body for single-stream detection (identified by ``source_url``)."""
    mqtt_ip: str
    mqtt_port: int
    mqtt_topic: str
    source_url: str
    push_url: str  # temporary, for testing
    confidence: float
    func_id: List[int]
    para: Dict[str, Any]


@dataclass
class ContentBodyFormat_BackDetectPic(BaseContentBody):
    """content_body for detection on a list of S3 images that carries ``confidence``."""
    s3_id: int
    s3_url: list[str]
    org_code: str
    # mqtt_pub_id: int
    confidence: float
    func_id: list[int]
    para: Dict[str, Any]


@dataclass
class EarlyLaterUrls:
    """A pair of image URLs (``early`` and ``later``) used by ContentBodyFormat_Detection."""
    early: str
    later: str


@dataclass
class ContentBodyFormat_Detection(BaseContentBody):
    """content_body whose ``s3_url`` is an {early, later} object."""
    s3_id: int
    s3_url: EarlyLaterUrls
    func_id: list[int]


@dataclass
class ContentBodyFormat_Segementation(BaseContentBody):
    """content_body whose ``s3_url`` is a list and which carries no ``confidence``."""
    s3_id: int
    s3_url: list[str]
    func_id: list[int]


@dataclass
class RequestJson:
    """A parsed and validated request: task id, device serial number and typed content_body."""
    task_id: str
    sn: str
    content_body: BaseContentBody

    @staticmethod
    def _check_confidence(confidence: Any, body_name: str) -> None:
        """Raise ValueError unless ``confidence`` is None or a number strictly inside (0, 1)."""
        if confidence is None:
            return
        if isinstance(confidence, bool) or not isinstance(confidence, (int, float)):
            raise ValueError(f"Confidence must be a number for {body_name}")
        if not 0 < confidence < 1:
            raise ValueError(f"Confidence must be between 0 and 1 for {body_name}")

    def validate(self) -> None:
        """Validate the required fields of the parsed request.

        Raises:
            ValueError: if a required field is empty, or ``confidence`` is not
                a number strictly between 0 and 1.
        """
        if not self.task_id:
            raise ValueError("task_id is required")
        body = self.content_body
        if isinstance(body, ContentBodyFormat_VideoMultiBackDetect):
            if not body.minio_file_path:
                raise ValueError("minio_file_path is required for ContentBodyFormat_VideoMultiBackDetect")
            self._check_confidence(body.confidence, "ContentBodyFormat_VideoMultiBackDetect")
        elif isinstance(body, ContentBodyFormat_MultiBackDetect):
            if not body.para_list:
                raise ValueError("para_list is required for ContentBodyFormat_MultiBackDetect")
            self._check_confidence(body.confidence, "ContentBodyFormat_MultiBackDetect")
        elif isinstance(body, ContentBodyFormat_BackDetect):
            if not body.source_url:
                raise ValueError("source_url is required for ContentBodyFormat_BackDetect")
            self._check_confidence(body.confidence, "ContentBodyFormat_BackDetect")
        elif isinstance(body, ContentBodyFormat_BackDetectPic):
            if not body.s3_id:
                raise ValueError("s3_id is required for ContentBodyFormat_BackDetectPic")
            if not body.s3_url:
                raise ValueError("s3_url is required for ContentBodyFormat_BackDetectPic")
            self._check_confidence(body.confidence, "ContentBodyFormat_BackDetectPic")
        elif isinstance(body, ContentBodyFormat_Detection):
            if not body.s3_id:
                raise ValueError("s3_id is required for ContentBodyFormat_Detection")
            if not body.s3_url.early or not body.s3_url.later:
                raise ValueError("Both early and later URLs are required for ContentBodyFormat_Detection")
        elif isinstance(body, ContentBodyFormat_Segementation):
            if not body.s3_id:
                raise ValueError("s3_id is required for ContentBodyFormat_Segementation")
            if not body.s3_url:
                raise ValueError("s3_url is required for ContentBodyFormat_Segementation")
        elif isinstance(body, ContentBodyFormat_Sam3Pic):
            if not body.prompt:
                raise ValueError("prompt is required for ContentBodyFormat_Sam3Pic")

    @staticmethod
    def _build_content_body(body: Dict[str, Any]) -> BaseContentBody:
        """Pick the content_body variant from the keys present in ``body`` (checked in priority order).

        Raises:
            KeyError: a field required by the chosen variant is missing.
            ValueError: no variant matches, or ``s3_url`` has an unsupported shape.
        """
        default_conf = Config.DEFAULT_CONFIDENCE
        if 'minio_file_path' in body:
            return ContentBodyFormat_VideoMultiBackDetect(
                org_code=body['org_code'],
                minio_file_path=body['minio_file_path'],
                push_url=body['push_url'],
                confidence=body.get('confidence', default_conf),
                para_list=body.get('para_list', []),
                invade=body.get('invade', [])
            )
        if 'para_list' in body:
            return ContentBodyFormat_MultiBackDetect(
                org_code=body['org_code'],
                func_id=body['func_id'],
                source_url=body['source_url'],
                push_url=body['push_url'],
                confidence=body.get('confidence', default_conf),
                para_list=body.get('para_list', []),
                invade=body.get('invade', {})
            )
        if 'source_url' in body:
            return ContentBodyFormat_BackDetect(
                mqtt_ip=body['mqtt_ip'],
                mqtt_port=body['mqtt_port'],
                mqtt_topic=body['mqtt_topic'],
                source_url=body['source_url'],
                push_url=body['push_url'],
                confidence=body.get('confidence', default_conf),
                func_id=body.get('func_id', []),
                para=body.get('para', {})
            )
        if 's3_id' in body and 's3_url' in body:
            s3_url = body['s3_url']
            if isinstance(s3_url, dict) and 'early' in s3_url and 'later' in s3_url:
                return ContentBodyFormat_Detection(
                    s3_id=body['s3_id'],
                    s3_url=EarlyLaterUrls(early=s3_url['early'], later=s3_url['later']),
                    func_id=body.get('func_id', [])
                )
            if isinstance(s3_url, list) and 'confidence' not in body:
                return ContentBodyFormat_Segementation(
                    s3_id=body['s3_id'],
                    s3_url=s3_url,
                    func_id=body.get('func_id', [])
                )
            if isinstance(s3_url, list):
                return ContentBodyFormat_BackDetectPic(
                    s3_id=body['s3_id'],
                    s3_url=s3_url,
                    # mqtt_pub_id=body['mqtt_pub_id'],
                    org_code=body['org_code'],
                    confidence=body.get('confidence', default_conf),
                    func_id=body.get('func_id', []),
                    para=body.get('para', {})
                )
            raise ValueError("Invalid s3_url format for ContentBodyFormat_Detection")
        if 'prompt' in body:
            return ContentBodyFormat_Sam3Pic(
                img_url=body['img_url'],
                prompt=body['prompt'],
                confidence=body.get('confidence', default_conf),
                mqtt_ip=body['mqtt_ip'],
                mqtt_port=body['mqtt_port'],
                mqtt_topic=body['mqtt_topic']
            )
        raise ValueError("Invalid content_body format")

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'RequestJson':
        """Parse and validate a request payload.

        Raises:
            ValueError: the payload is not a JSON object, a required field is
                missing, the content_body shape is unknown, or validation fails.
        """
        if not isinstance(data, dict):
            raise ValueError("Request body must be a JSON object")
        try:
            task_id = data['task_id']
            sn = data['sn']
            content_body_data = data['content_body']
            if not isinstance(content_body_data, dict):
                raise ValueError("content_body must be a JSON object")
            content_body = cls._build_content_body(content_body_data)
        except KeyError as e:
            raise ValueError(f"Missing required field: {str(e)}") from e

        instance = cls(task_id=task_id, sn=sn, content_body=content_body)
        instance.validate()
        return instance


def _terminate_ffmpeg_children() -> None:
    """Best effort: send SIGTERM to every descendant process of ours whose name contains 'ffmpeg'.

    All errors are logged and swallowed.
    """
    try:
        import psutil  # third-party; imported lazily because it is only needed on this recovery path

        current_process = psutil.Process(os.getpid())
        for child in current_process.children(recursive=True):
            try:
                child_name = child.name().lower()
                if 'ffmpeg' in child_name:
                    logger.info(f"强制终止子进程: {child.pid} ({child_name})")
                    child.send_signal(signal.SIGTERM)
            except Exception as child_e:
                logger.error(f"终止子进程出错: {str(child_e)}")
    except Exception as kill_e:
        logger.error(f"尝试清理进程时出错: {str(kill_e)}")


async def safe_stop_ai_video():
    """Stop AI video processing, falling back to forced cleanup on failure.

    Returns:
        True if ``stopAIVideo`` completed normally. False if it raised. In that
        case the service is marked unhealthy, all tasks are marked stopped,
        ffmpeg child processes are sent SIGTERM, and after
        ``Config.RESTART_DELAY`` seconds the service is marked healthy again.
    """
    try:
        # stopAIVideo is synchronous; run it in a worker thread so the event loop stays responsive.
        await asyncio.to_thread(stopAIVideo)
        return True
    except Exception as e:
        error_msg = f"停止AI视频处理出错: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)

        service_status["is_healthy"] = False
        service_status["last_error"] = str(e)
        service_status["error_time"] = datetime.now().isoformat()

        task_manager.mark_all_tasks_as_stopped()
        _terminate_ffmpeg_children()

        # Give the system a moment to settle before reporting healthy again.
        await asyncio.sleep(Config.RESTART_DELAY)
        service_status["is_healthy"] = True
        return False


def verify_token(request) -> None:
    """Reject the request unless its X-API-Token header equals ``Config.VALID_TOKEN``.

    Raises:
        Unauthorized: the header is missing or does not match.
    """
    token = request.headers.get('X-API-Token')
    # Constant-time comparison on bytes (compare_digest rejects non-ASCII str).
    if not token or not secrets.compare_digest(
        token.encode("utf-8"), Config.VALID_TOKEN.encode("utf-8")
    ):
        logger.warning("Invalid token attempt")
        raise Unauthorized("Invalid token")


@multi_back_detect_bp.post("/ai/stream/multi_back_detect")
async def start_multi_back_detection(request):
    """Start a background multi-model detection task on a live stream.

    Responds immediately with a JSON status; the detection itself runs in a
    background asyncio task. Status codes: 400 for validation errors, the
    exception's own status (e.g. 401 for a bad token) for Sanic exceptions,
    and 500 for a duplicate task id or an unexpected error.
    """
    try:
        verify_token(request)

        if not service_status["is_healthy"]:
            logger.warning(
                f"服务处于不健康状态,上次错误: {service_status['last_error']} 于 {service_status['error_time']}")
            service_status["is_healthy"] = True

        request_json = RequestJson.from_dict(request.json)
        logger.info("/ai/stream/multi_back_detect 请求:%s", request.json)
        # Non-blocking delay (previously time.sleep, which froze the whole event loop).
        await asyncio.sleep(Config.REQUEST_DELAY)

        task_id = request_json.task_id
        if task_id in task_manager.tasks:
            logger.warning(f"任务 {task_id} 已存在,跳过创建")
            return json_response({
                "status": "error",
                "message": f"任务 {task_id} 已存在,跳过创建"
            }, status=500)

        if not isinstance(request_json.content_body, ContentBodyFormat_MultiBackDetect):
            return json_response({
                "status": "failed",
                "message": "content_body structure is wrong"
            })

        async def wrapped_processing():
            # Background entry point: log termination instead of letting the exception go unobserved.
            try:
                await run_back_Multi_Detect_async(request, request_json, asyncio.Event())
            except asyncio.CancelledError:
                logger.info(f"任务 {task_id} 被取消")
            except Exception as e:
                logger.error(f"任务 {task_id} 异常终止: {e}")

        try:
            background_task = asyncio.create_task(wrapped_processing())
        except Exception as e:
            logger.error(f"启动AI视频处理失败: {e}")
            return json_response({
                "status": "error",
                "message": f"Failed to start AI video processing: {str(e)}"
            }, status=500)
        _background_tasks.add(background_task)
        background_task.add_done_callback(_background_tasks.discard)

        return json_response({
            "status": "success",
            "task_id": task_id,
            "message": "Detection started successfully"
        })

    except ValueError as e:
        logger.error(f"Validation error: {str(e)}")
        return json_response({"status": "error", "message": str(e)}, status=400)
    except SanicException as e:
        # e.g. Unauthorized from verify_token, BadRequest from malformed JSON.
        logger.warning(f"Request rejected: {str(e)}")
        status = getattr(e, "status_code", None) or 500
        return json_response({"status": "error", "message": str(e)}, status=status)
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}", exc_info=True)
        return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)


def _collect_model_configs(dao, para_list):
    """Build the model config dicts passed to ``start_rtmp_processing``.

    For every para, the rows returned by ``batch_query_model_func_id`` for its
    ``func_id`` each produce one config dict. Paras with no rows or no model
    config are skipped.

    Returns:
        (model_configs, invade_enable, repeat_dis, repeat_time, high_count_warn).
        ``invade_enable`` is True if any para sets ``para_invade_enable``. The
        last three values come from the last model config found. When none is
        found, ``repeat_dis`` is -1 and the other two are None.
    """
    model_configs = []
    invade_enable = False
    repeat_dis = -1  # de-duplication distance between two frames
    repeat_time = None
    high_count_warn = None

    for para in para_list:
        func_id = para["func_id"]
        category = para.get("py_func", [])
        para_invade_enable = para.get("para_invade_enable", False)
        if para_invade_enable:
            invade_enable = True

        query_results = batch_query_model_func_id([func_id], **DB_CONFIG)
        if not query_results:
            continue

        # The config depends only on (func_id, category), so look it up once per para.
        config = dao.get_config(func_id, category)
        if not config:
            logger.warning(f"未找到ID为 {func_id} 的模型配置")
            continue

        logger.info(
            "model path=%s filter_indices=%s repeat_dis=%s high_count_warn=%s",
            config.model_path, config.filter_indices, config.repeat_dis, config.high_count_warn,
        )
        repeat_dis = config.repeat_dis
        repeat_time = config.repeat_time
        high_count_warn = config.high_count_warn

        for row in query_results:
            model_configs.append({
                'path': config.model_path,
                # TODO: test override -- production should use config.engine_path / config.so_path.
                'engine_path': Config.TEST_ENGINE_PATH,
                'so_path': Config.TEST_SO_PATH,
                'cls_map': config.cls_zn_to_eh_dict,
                'allowed_classes': config.allowed_classes,
                "cls_index": config.class_indices,
                "class_names": config.cls_names,
                "chinese_label": config.cls_en_dict,
                "list_func_id": row["model_func_id"],
                "func_id": func_id,
                "para_invade_enable": para_invade_enable,
                "config_conf": config.conf
            })

    return model_configs, invade_enable, repeat_dis, repeat_time, high_count_warn


async def run_back_Multi_Detect_async(request, request_json, stop_event: asyncio.Event):
    """Resolve all configuration for a multi-model stream task, then run it to completion.

    The processing coroutine is wrapped in a task and registered with
    ``task_manager`` *before* it is awaited, so the task can be found and
    stopped while it runs. The task is removed from ``task_manager`` when it
    finishes, fails or is cancelled.

    ``stop_event`` is accepted for interface compatibility but is not used;
    stopping is done by cancelling the registered task.

    Raises:
        ValueError: ``invade`` is not an object, or the MQTT/device configuration is missing.
        Exception: any error from configuration lookup or processing is logged and re-raised.
    """
    task_id = request_json.task_id
    try:
        body = request_json.content_body
        sn = request_json.sn
        dao = ModelConfigDAO(DB_CONFIG)

        (model_configs, invade_enable, repeat_dis,
         repeat_time, high_count_warn) = _collect_model_configs(dao, body.para_list)
        if high_count_warn is None:
            high_count_warn = 0

        video_url = body.source_url
        org_code = body.org_code
        push_url = body.push_url

        invade = body.invade or {}
        if not isinstance(invade, dict):
            raise ValueError("invade must be a JSON object")
        invade_file = invade.get("invade_file")
        camera_para_url = invade.get("camera_para_url")
        invade_switch = invade.get("invade_switch", 0)

        # The socket info is appended to help identify the caller's address.
        dao.insert_request_log(task_id, sn, org_code, str(request.body), f"{request}&{request.socket}")

        mqtt_pub_config = dao.get_mqtt_config_by_orgcode(org_code, sn, "pub")
        mqtt_sub_config = dao.get_mqtt_config_by_orgcode(org_code, sn, "sub")
        if mqtt_pub_config is None or mqtt_sub_config is None:
            raise ValueError(f"MQTT pub/sub config not found for org_code={org_code}, sn={sn}")
        mqtt_pub_topic = mqtt_pub_config.mqtt_topic
        mqtt_sub_topic = mqtt_sub_config.mqtt_topic.format(sn=sn)
        logger.info("mqtt_pub_topic %s, mqtt_sub_topic %s", mqtt_pub_topic, mqtt_sub_topic)

        read_local_func_config()  # result unused here; call kept from the original flow

        device = dao.get_device(sn, org_code)
        if device is None:
            raise ValueError(f"device not found for sn={sn}, org_code={org_code}")
        device_height = device.height  # dock height, later used for on-site elevation calculation

        async def process_flow():
            try:
                await start_rtmp_processing(
                    video_url, task_id, model_configs,
                    mqtt_pub_config.mqtt_ip, mqtt_pub_config.mqtt_port, mqtt_pub_topic,
                    mqtt_sub_config.mqtt_ip, mqtt_sub_config.mqtt_port, mqtt_sub_topic,
                    push_url, invade_enable, invade_switch, invade_file, camera_para_url,
                    device_height, repeat_dis, repeat_time, high_count_warn
                )
            except Exception as e:
                logger.error(f"处理流程异常: {e}")
                raise

        task_handle = asyncio.create_task(process_flow())
        task_info = {
            "source_url": video_url,
            "push_url": push_url,
            "status": "running",
            "task_handle": task_handle,
            "model_configs": model_configs,
            "device_height": device_height,
            "repeat_dis": repeat_dis,
            "repeat_time": repeat_time
        }
        try:
            await task_manager.add_task(task_id, task_info, task_handle, [])  # no sub-tasks yet
            await task_handle
        except asyncio.CancelledError:
            pass  # cancellation (e.g. a stop request) is the normal way a stream task ends
        finally:
            if not task_handle.done():
                task_handle.cancel()  # e.g. add_task failed: do not leave the stream running orphaned
            if task_id in task_manager.tasks:
                await task_manager.remove_task(task_id)

    except asyncio.CancelledError:
        logger.info(f"任务 {task_id} 收到停止信号,正在清理...")
        raise
    except Exception as e:
        logger.error(f"任务 {task_id} 处理失败: {e}", exc_info=True)
        raise