ai_project_v1/multi_back_detect/multi_back_detect_api.py

645 lines
26 KiB
Python
Raw Permalink Normal View History

import asyncio
import json
import logging
import os
import time
import traceback
from dataclasses import dataclass, asdict
from datetime import datetime
from typing import List, Dict, Any

from sanic import Blueprint
from sanic.exceptions import Unauthorized, SanicException
from sanic.response import json as json_response
from sympy import false
try :
from middleware.TaskManager import task_manager
from middleware.query_model import ModelConfigDAO
from middleware.query_postgress import batch_query_model_func_id
from middleware.read_yolo_config import read_local_func_config
from yolo.cv_multi_model_back_video import start_rtmp_processing
from cv_video import startAIVideo, stopAIVideo
from cv_back_video import startBackAIVideo
except Exception as e:
import sys
from pathlib import Path
# 获取项目根目录(假设脚本在根目录的子目录中)
ROOT_DIR = Path(__file__).parent.parent # 根据实际层级调整
sys.path.append(str(ROOT_DIR))
from middleware.TaskManager import task_manager
from middleware.query_model import ModelConfigDAO
from middleware.query_postgress import batch_query_model_func_id
from middleware.read_yolo_config import read_local_func_config
from yolo.cv_multi_model_back_video import start_rtmp_processing
from cv_video import startAIVideo, stopAIVideo
from cv_back_video import startBackAIVideo
# Service configuration.
# The API token can be supplied through the MULTI_BACK_DETECT_API_TOKEN
# environment variable so it does not have to live in source control; the
# previous literal is kept as the default for backward compatibility.
class Config:
    VALID_TOKEN = os.environ.get(
        "MULTI_BACK_DETECT_API_TOKEN", "5e8899fe-dc74-4280-8169-2f4d185f3afa"
    )
    MAX_ACTIVE_TASKS = 10
    DEFAULT_CONFIDENCE = 0.5
    RESTART_DELAY = 2  # seconds to wait before the service attempts automatic recovery
# Production database connection settings.
# Every value can be overridden through an environment variable so credentials
# (notably the password) need not be committed; the previous literals remain
# the defaults for backward compatibility.
DB_CONFIG = {
    "dbname": os.environ.get("MULTI_BACK_DETECT_DB_NAME", "smart_dev"),
    "user": os.environ.get("MULTI_BACK_DETECT_DB_USER", "postgres"),
    "password": os.environ.get("MULTI_BACK_DETECT_DB_PASSWORD", "StrongPassword@123"),
    "host": os.environ.get("MULTI_BACK_DETECT_DB_HOST", "222.212.85.86"),
    "port": os.environ.get("MULTI_BACK_DETECT_DB_PORT", "5061"),
}
# Service health flags: safe_stop_ai_video marks the service unhealthy on a
# failed stop, and the multi_back_detect handler logs and clears the flag.
service_status = {"is_healthy": True, "last_error": None, "error_time": None}
# Root logging setup: INFO level, one line per record with time, logger name and level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_LOG_LEVEL = logging.INFO
logging.basicConfig(format=_LOG_FORMAT, level=_LOG_LEVEL)

# Logger for this module.
logger = logging.getLogger(__name__)
# Sanic blueprint holding this module's routes; url_prefix="" means route paths
# are used exactly as declared (e.g. "/ai/stream/multi_back_detect").
multi_back_detect_bp = Blueprint("multi_back_detect", url_prefix="")
@dataclass
class BaseContentBody:
    """Common base for every supported ``content_body`` payload shape.

    It declares no fields yet; fields shared by all shapes would go here.
    """


@dataclass
class ContentBodyFormat_VideoMultiBackDetect(BaseContentBody):
    """Payload chosen by ``RequestJson.from_dict`` when content_body has ``minio_file_path``."""

    org_code: str
    minio_file_path: str
    push_url: str  # temporary, for testing
    confidence: float
    para_list: list
    # A JSON object in practice (read by key elsewhere in this module);
    # from_dict falls back to [] when the request omits it.
    invade: Any


@dataclass
class ContentBodyFormat_Sam3Pic(BaseContentBody):
    """Payload chosen when content_body has ``prompt`` (and none of the other discriminating keys)."""

    img_url: str
    prompt: str
    confidence: float  # from_dict defaults this to 0.5, so it is not an int
    mqtt_ip: str
    mqtt_port: int
    mqtt_topic: str


@dataclass
class ContentBodyFormat_MultiBackDetect(BaseContentBody):
    """Payload chosen when content_body has ``para_list`` (and no ``minio_file_path``)."""

    org_code: str
    func_id: list
    source_url: str
    push_url: str  # temporary, for testing
    confidence: float
    para_list: list
    # A JSON object in practice (read by key elsewhere in this module);
    # from_dict falls back to [] when the request omits it.
    invade: Any


@dataclass
class ContentBodyFormat_BackDetect(BaseContentBody):
    """Payload chosen when content_body has ``source_url`` but no ``para_list``."""

    mqtt_ip: str
    mqtt_port: int
    mqtt_topic: str
    source_url: str
    push_url: str  # temporary, for testing
    confidence: float
    func_id: List[int]
    para: dict


@dataclass
class ContentBodyFormat_BackDetectPic(BaseContentBody):
    """Payload chosen when content_body has ``s3_id``, a list ``s3_url`` and ``confidence``."""

    s3_id: int
    s3_url: list[str]
    org_code: str
    confidence: float
    func_id: list[int]
    para: dict


@dataclass
class EarlyLaterUrls:
    """The pair of image URLs carried by a ``ContentBodyFormat_Detection`` payload."""

    early: str
    later: str


@dataclass
class ContentBodyFormat_Detection(BaseContentBody):
    """Payload chosen when content_body has ``s3_id`` and a dict ``s3_url`` with early/later."""

    s3_id: int
    s3_url: EarlyLaterUrls
    func_id: list[int]


@dataclass
class ContentBodyFormat_Segementation(BaseContentBody):
    """Payload chosen when content_body has ``s3_id`` and a list ``s3_url`` but no ``confidence``."""

    s3_id: int
    s3_url: list[str]
    func_id: list[int]


def _check_confidence(confidence: Any, format_name: str) -> None:
    """Raise ValueError unless *confidence* is None or a number strictly between 0 and 1.

    Args:
        confidence: Value taken from the payload (None means "not supplied").
        format_name: Payload class name, used in the error message.

    Raises:
        ValueError: if the value is not numeric or lies outside (0, 1).
    """
    if confidence is None:
        return
    # Without this check a string would raise TypeError on comparison, which the
    # HTTP handler reports as a 500 instead of a 400 validation error.
    if not isinstance(confidence, (int, float)):
        raise ValueError(f"Confidence must be a number for {format_name}")
    if not 0 < confidence < 1:
        raise ValueError(f"Confidence must be between 0 and 1 for {format_name}")


@dataclass
class RequestJson:
    """Top-level request: a task id, a device serial number and a typed content body."""

    task_id: str
    sn: str
    content_body: BaseContentBody

    def validate(self) -> None:
        """Check the fields required by the concrete ``content_body`` type.

        Raises:
            ValueError: on the first missing or out-of-range field.
        """
        if not self.task_id:
            raise ValueError("task_id is required")
        body = self.content_body
        if isinstance(body, ContentBodyFormat_VideoMultiBackDetect):
            if not body.minio_file_path:
                raise ValueError("minio_file_path is required for ContentBodyFormat_VideoMultiBackDetect")
            _check_confidence(body.confidence, "ContentBodyFormat_VideoMultiBackDetect")
        elif isinstance(body, ContentBodyFormat_MultiBackDetect):
            if not body.para_list:
                raise ValueError("para_list is required for ContentBodyFormat_MultiBackDetect")
            _check_confidence(body.confidence, "ContentBodyFormat_MultiBackDetect")
        elif isinstance(body, ContentBodyFormat_BackDetect):
            if not body.source_url:
                raise ValueError("source_url is required for ContentBodyFormat_BackDetect")
            _check_confidence(body.confidence, "ContentBodyFormat_BackDetect")
        elif isinstance(body, ContentBodyFormat_BackDetectPic):
            if not body.s3_id:
                raise ValueError("s3_id is required for ContentBodyFormat_BackDetectPic")
            if not body.s3_url:
                raise ValueError("s3_url is required for ContentBodyFormat_BackDetectPic")
            _check_confidence(body.confidence, "ContentBodyFormat_BackDetectPic")
        elif isinstance(body, ContentBodyFormat_Detection):
            if not body.s3_id:
                raise ValueError("s3_id is required for ContentBodyFormat_Detection")
            if not body.s3_url.early or not body.s3_url.later:
                raise ValueError("Both early and later URLs are required for ContentBodyFormat_Detection")
        elif isinstance(body, ContentBodyFormat_Segementation):
            if not body.s3_id:
                raise ValueError("s3_id is required for ContentBodyFormat_Segementation")
            if not body.s3_url:
                raise ValueError("s3_url is required for ContentBodyFormat_Segementation")
        elif isinstance(body, ContentBodyFormat_Sam3Pic):
            if not body.prompt:
                raise ValueError("prompt is required for ContentBodyFormat_Sam3Pic")

    @staticmethod
    def _content_body_from_dict(content_body_data: Dict[str, Any]) -> BaseContentBody:
        """Pick and build the content-body dataclass from its discriminating keys.

        Keys are checked in this order: ``minio_file_path``, ``para_list``,
        ``source_url``, ``s3_id`` + ``s3_url`` (dict -> Detection, list without
        confidence -> Segementation, list with confidence -> BackDetectPic),
        then ``prompt``.

        Raises:
            KeyError: if a field required by the chosen shape is missing.
            ValueError: if no shape matches or ``s3_url`` has an unusable form.
        """
        if 'minio_file_path' in content_body_data:
            return ContentBodyFormat_VideoMultiBackDetect(
                org_code=content_body_data['org_code'],
                minio_file_path=content_body_data['minio_file_path'],
                push_url=content_body_data['push_url'],
                confidence=content_body_data.get('confidence', 0.5),
                para_list=content_body_data.get('para_list', []),
                invade=content_body_data.get('invade', [])
            )
        if 'para_list' in content_body_data:
            return ContentBodyFormat_MultiBackDetect(
                org_code=content_body_data['org_code'],
                func_id=content_body_data['func_id'],
                source_url=content_body_data['source_url'],
                push_url=content_body_data['push_url'],
                confidence=content_body_data.get('confidence', 0.5),
                para_list=content_body_data.get('para_list', []),
                invade=content_body_data.get('invade', [])
            )
        if 'source_url' in content_body_data:
            return ContentBodyFormat_BackDetect(
                mqtt_ip=content_body_data['mqtt_ip'],
                mqtt_port=content_body_data['mqtt_port'],
                mqtt_topic=content_body_data['mqtt_topic'],
                source_url=content_body_data['source_url'],
                push_url=content_body_data['push_url'],
                confidence=content_body_data.get('confidence', 0.5),
                func_id=content_body_data.get('func_id', []),
                para=content_body_data.get('para', {})
            )
        if 's3_id' in content_body_data and 's3_url' in content_body_data:
            s3_url = content_body_data['s3_url']
            if isinstance(s3_url, dict) and 'early' in s3_url and 'later' in s3_url:
                return ContentBodyFormat_Detection(
                    s3_id=content_body_data['s3_id'],
                    s3_url=EarlyLaterUrls(early=s3_url['early'], later=s3_url['later']),
                    func_id=content_body_data.get('func_id', [])
                )
            if isinstance(s3_url, list) and 'confidence' not in content_body_data:
                return ContentBodyFormat_Segementation(
                    s3_id=content_body_data['s3_id'],
                    s3_url=s3_url,
                    func_id=content_body_data.get('func_id', [])
                )
            if isinstance(s3_url, list):
                return ContentBodyFormat_BackDetectPic(
                    s3_id=content_body_data['s3_id'],
                    s3_url=s3_url,
                    org_code=content_body_data['org_code'],
                    confidence=content_body_data.get('confidence', 0.5),
                    func_id=content_body_data.get('func_id', []),
                    para=content_body_data.get('para', {})
                )
            raise ValueError("Invalid s3_url format for ContentBodyFormat_Detection")
        if 'prompt' in content_body_data:
            return ContentBodyFormat_Sam3Pic(
                img_url=content_body_data['img_url'],
                prompt=content_body_data['prompt'],
                confidence=content_body_data.get('confidence', 0.5),
                mqtt_ip=content_body_data['mqtt_ip'],
                mqtt_port=content_body_data['mqtt_port'],
                mqtt_topic=content_body_data['mqtt_topic']
            )
        raise ValueError("Invalid content_body format")

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'RequestJson':
        """Parse and validate a decoded JSON request body.

        Args:
            data: Mapping with ``task_id``, ``sn`` and ``content_body``.

        Returns:
            A validated ``RequestJson``.

        Raises:
            ValueError: for a non-object body or content_body, a missing required
                field, an unrecognised content_body shape, or a failed validation.
        """
        # Guard non-object bodies (e.g. JSON null or a list) so they become
        # validation errors instead of TypeErrors.
        if not isinstance(data, dict):
            raise ValueError("Request body must be a JSON object")
        try:
            task_id = data['task_id']
            sn = data['sn']
            content_body_data = data['content_body']
            if not isinstance(content_body_data, dict):
                raise ValueError("content_body must be a JSON object")
            content_body = cls._content_body_from_dict(content_body_data)
        except KeyError as e:
            raise ValueError(f"Missing required field: {str(e)}") from e
        instance = cls(task_id=task_id, sn=sn, content_body=content_body)
        instance.validate()
        return instance
def _terminate_ffmpeg_children():
    """Best-effort cleanup: send SIGTERM to every ffmpeg descendant of this process.

    Failures are only logged, never raised: per child process, and for the
    whole sweep (for example when ``psutil`` cannot be imported).
    """
    try:
        import os
        import signal
        import psutil

        this_process = psutil.Process(os.getpid())
        for proc in this_process.children(recursive=True):
            try:
                proc_name = proc.name().lower()
                if 'ffmpeg' not in proc_name:
                    continue
                logger.info(f"强制终止子进程: {proc.pid} ({proc_name})")
                proc.send_signal(signal.SIGTERM)
            except Exception as child_e:
                logger.error(f"终止子进程出错: {str(child_e)}")
    except Exception as kill_e:
        logger.error(f"尝试清理进程时出错: {str(kill_e)}")


async def safe_stop_ai_video():
    """Stop AI video processing, falling back to a forced cleanup if that fails.

    ``stopAIVideo`` runs in a worker thread. If it raises, the error and its
    traceback are logged, ``service_status`` is marked unhealthy, all tasks are
    marked stopped through ``task_manager``, ffmpeg child processes are sent
    SIGTERM, and after ``Config.RESTART_DELAY`` seconds the service is marked
    healthy again.

    Returns:
        True if ``stopAIVideo`` completed normally, False if the fallback path ran.
    """
    try:
        await asyncio.to_thread(stopAIVideo)
    except Exception as exc:
        logger.error(f"停止AI视频处理出错: {str(exc)}\n{traceback.format_exc()}")
        service_status.update(
            is_healthy=False,
            last_error=str(exc),
            error_time=datetime.now().isoformat(),
        )
        task_manager.mark_all_tasks_as_stopped()
        _terminate_ffmpeg_children()
        # Give the system a moment to settle before declaring the service healthy again.
        await asyncio.sleep(Config.RESTART_DELAY)
        service_status["is_healthy"] = True
        return False
    return True
def verify_token(request) -> None:
    """Check the ``X-API-Token`` request header against ``Config.VALID_TOKEN``.

    Args:
        request: Incoming Sanic request.

    Raises:
        Unauthorized: if the header is missing, empty or does not match.
    """
    import hmac

    token = request.headers.get('X-API-Token')
    # compare_digest takes time independent of where the values differ, so
    # response timing cannot be used to recover the token byte by byte.
    # Compare as UTF-8 bytes because compare_digest rejects non-ASCII str.
    if not token or not hmac.compare_digest(
        token.encode("utf-8"), Config.VALID_TOKEN.encode("utf-8")
    ):
        logger.warning("Invalid token attempt")
        raise Unauthorized("Invalid token")
# Strong references to in-flight detection tasks, keyed by task_id. The event
# loop keeps only weak references to tasks created by asyncio.create_task(), so
# an unreferenced task can be garbage-collected before it finishes.
_background_detection_tasks: Dict[str, asyncio.Task] = {}


@multi_back_detect_bp.post("/ai/stream/multi_back_detect")
async def start_multi_back_detection(request):
    """Start background multi-model detection on a live stream.

    Authenticates the caller (``X-API-Token``) and parses the body into a
    ``RequestJson``. When ``content_body`` is a ``ContentBodyFormat_MultiBackDetect``,
    it launches ``run_back_Multi_Detect_async`` as a background task and returns
    immediately.

    Responses:
        200 ``success`` once the task is scheduled.
        200 ``failed`` for any other content_body shape.
        400 on validation errors.
        500 if the task id is already in use, or on unexpected errors.
        Sanic exceptions (401 from ``verify_token``, 400 for malformed JSON)
        propagate to Sanic's error handler.
    """
    try:
        verify_token(request)
        if not service_status["is_healthy"]:
            logger.warning(
                f"服务处于不健康状态,上次错误: {service_status['last_error']}{service_status['error_time']}")
            # Log once and clear the flag; this request then proceeds normally.
            service_status["is_healthy"] = True

        payload = request.json
        if not isinstance(payload, dict):
            raise ValueError("Request body must be a JSON object")
        request_json = RequestJson.from_dict(payload)
        logger.info(f"/ai/stream/multi_back_detect 请求:{payload}")
        # Keep the original 3-second pause, but yield to the event loop instead of
        # calling time.sleep(), which froze every other request on this worker.
        await asyncio.sleep(3)

        task_id = request_json.task_id
        # A task counts as existing if task_manager knows it, or if it was scheduled
        # here and is still running. There is no await between this check and the
        # registration below, so check-and-insert is atomic on the event loop.
        if task_id in task_manager.tasks or task_id in _background_detection_tasks:
            logger.warning(f"任务 {task_id} 已存在,跳过创建")
            return json_response({
                "status": "error",
                "message": f"任务 {task_id} 已存在,跳过创建"
            }, status=500)

        if not isinstance(request_json.content_body, ContentBodyFormat_MultiBackDetect):
            return json_response({
                "status": "failed",
                "message": "content_body structure is wrong"
            })

        # Passed through to the worker; nothing in this module sets it.
        stop_event = asyncio.Event()

        async def wrapped_processing():
            # Task boundary: log every outcome so nothing is lost silently.
            try:
                await run_back_Multi_Detect_async(request, request_json, stop_event)
            except asyncio.CancelledError:
                logger.info(f"任务 {task_id} 被取消")
            except Exception as e:
                logger.error(f"任务 {task_id} 异常终止: {e}", exc_info=True)

        try:
            task_handle = asyncio.create_task(wrapped_processing())
        except Exception as e:
            logger.error(f"启动AI视频处理失败: {e}")
            return json_response({
                "status": "error",
                "message": f"Failed to start AI video processing: {str(e)}"
            }, status=500)
        _background_detection_tasks[task_id] = task_handle
        task_handle.add_done_callback(
            lambda _done, _id=task_id: _background_detection_tasks.pop(_id, None)
        )

        return json_response({
            "status": "success",
            "task_id": task_id,
            "message": "Detection started successfully"
        })
    except SanicException:
        # Let Sanic render its own HTTP errors (e.g. 401 Unauthorized) instead of
        # turning them into a generic 500 below.
        raise
    except ValueError as e:
        logger.error(f"Validation error: {str(e)}")
        return json_response({"status": "error", "message": str(e)}, status=400)
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}", exc_info=True)
        return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
def _build_model_configs(dao, para_list):
    """Turn ``para_list`` entries into model-config dicts for ``start_rtmp_processing``.

    For each entry, ``batch_query_model_func_id`` is queried with its ``func_id``,
    and entries with no rows are skipped. For every returned row the model config
    is loaded with ``dao.get_config(func_id, py_func)``. When it is found, one
    dict describing it is appended to the result.

    Args:
        dao: ``ModelConfigDAO`` used for ``get_config``.
        para_list: Entries with a required ``"func_id"``, an optional ``"py_func"``
            (default ``[]``) and an optional ``"para_invade_enable"`` (default ``False``).

    Returns:
        ``(model_configs, invade_enable, repeat_dis, repeat_time, high_count_warn)``.

        * ``invade_enable`` is True if any entry enables intrusion detection.
        * The last three values come from the most recently processed row;
          ``repeat_dis`` is reset to -1 for a row without a config.
        * ``repeat_time`` and ``high_count_warn`` stay None when no config is found.
    """
    model_configs = []
    invade_enable = False
    repeat_dis = -1
    repeat_time = None
    high_count_warn = None
    for para in para_list:
        func_id = para["func_id"]
        category = para.get("py_func", [])
        para_invade_enable = para.get("para_invade_enable", False)
        if para_invade_enable:
            invade_enable = True
        # NOTE(review): synchronous DB call on the event loop; consider
        # asyncio.to_thread once the DAO helpers are confirmed thread-safe.
        query_results = batch_query_model_func_id([func_id], **DB_CONFIG)
        if not query_results:
            continue
        for row in query_results:
            row_func_id = row["model_func_id"]
            config = dao.get_config(func_id, category)
            # Distance used to de-duplicate between two frames; -1 when the row has no config.
            repeat_dis = -1
            if not config:
                logger.warning(f"未找到ID为 {func_id} 的模型配置")
                continue
            logger.info(f"模型路径: {config.model_path}")
            logger.info(f"过滤类别: {config.filter_indices}")
            # Guarded: an unconditional classes[0] raised IndexError on an empty list.
            if config.classes:
                logger.info(f"第一个类别: {asdict(config.classes[0])}")
            logger.info(f"创建时间: {config.created_at}")
            logger.info(f"更新时间: {config.updated_at}")
            logger.info(f"去重的距离: {config.repeat_dis}")
            logger.info(f"config.high_count_warn {config.high_count_warn}")
            repeat_dis = config.repeat_dis
            repeat_time = config.repeat_time
            high_count_warn = config.high_count_warn
            model_configs.append(
                {
                    'path': config.model_path,
                    # TODO: switch back to config.engine_path / config.so_path; these
                    # hard-coded paths were marked as test code in the original.
                    'engine_path': r"D:\project\AI-PYTHON\Ai_tottle\engine\renche\renche.engine",
                    'so_path': r"D:\project\AI-PYTHON\Ai_tottle\engine\renche\myplugins.dll",
                    'cls_map': config.cls_zn_to_eh_dict,
                    'allowed_classes': config.allowed_classes,
                    "cls_index": config.class_indices,
                    "class_names": config.cls_names,
                    "chinese_label": config.cls_en_dict,
                    "list_func_id": row_func_id,
                    "func_id": func_id,
                    "para_invade_enable": para_invade_enable,
                    "config_conf": config.conf
                }
            )
    return model_configs, invade_enable, repeat_dis, repeat_time, high_count_warn


async def run_back_Multi_Detect_async(request, request_json, stop_event: asyncio.Event):
    """Resolve the configuration for one multi-back-detect request and run the stream processor.

    Steps:
        1. Build model configs from ``content_body.para_list``.
        2. Write the request log.
        3. Look up the MQTT pub/sub endpoints and the device for ``(org_code, sn)``.
        4. Run ``start_rtmp_processing`` as a task registered with ``task_manager``
           under ``task_id``, so duplicate checks and stop requests can see it
           while it runs. The entry is removed again when processing ends.

    Args:
        request: The Sanic request; only used for the request-log entry.
        request_json: Parsed request whose ``content_body`` is a
            ``ContentBodyFormat_MultiBackDetect``.
        stop_event: Kept for the caller's interface; not read here.

    Raises:
        ValueError: if no model config is found, or if the MQTT-config or device
            lookup returns None.
        asyncio.CancelledError: if this coroutine itself is cancelled.

    Any other exception from the DAO helpers or the stream processor is logged
    and re-raised.
    """
    task_id = request_json.task_id
    try:
        dao = ModelConfigDAO(DB_CONFIG)
        body = request_json.content_body
        (model_configs, invade_enable, repeat_dis, repeat_time,
         high_count_warn) = _build_model_configs(dao, body.para_list)
        if not model_configs:
            # Previously this case crashed later with an UnboundLocalError.
            raise ValueError(f"No model configuration found for task {task_id}")
        if high_count_warn is None:
            high_count_warn = 0

        video_url = body.source_url
        sn = request_json.sn
        org_code = body.org_code
        push_url = body.push_url
        invade = body.invade
        invade_file = invade["invade_file"]
        camera_para_url = invade["camera_para_url"]
        invade_switch = invade.get("invade_switch", 0)

        # Include the socket in the log line, to test whether the caller's IP is
        # captured when requests come from the public network.
        str_request = str(request) + "&" + str(request.socket)
        dao.insert_request_log(task_id, sn, org_code, str(request.body), str_request)

        mqtt_pub_config = dao.get_mqtt_config_by_orgcode(org_code, sn, "pub")
        mqtt_sub_config = dao.get_mqtt_config_by_orgcode(org_code, sn, "sub")
        if mqtt_pub_config is None or mqtt_sub_config is None:
            raise ValueError(f"MQTT configuration missing for org_code={org_code}, sn={sn}")
        mqtt_pub_ip = mqtt_pub_config.mqtt_ip
        mqtt_pub_port = mqtt_pub_config.mqtt_port
        mqtt_pub_topic = mqtt_pub_config.mqtt_topic
        logger.info(f"mqtt_pub_topic {mqtt_pub_topic}")
        mqtt_sub_ip = mqtt_sub_config.mqtt_ip
        mqtt_sub_port = mqtt_sub_config.mqtt_port
        # The stored sub topic is a template containing an {sn} placeholder.
        mqtt_sub_topic = mqtt_sub_config.mqtt_topic.format(sn=sn)
        logger.info(f"mqtt_sub_topic {mqtt_sub_topic}")

        device = dao.get_device(sn, org_code)
        logger.info(f"device表 {sn} {org_code}")
        if device is None:
            raise ValueError(f"Device not found for sn={sn}, org_code={org_code}")
        device_height = device.height  # dock altitude, later used for on-site elevation calculation

        async def process_flow():
            try:
                await start_rtmp_processing(
                    video_url,
                    task_id,
                    model_configs,
                    mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic,
                    mqtt_sub_ip, mqtt_sub_port, mqtt_sub_topic,
                    push_url,
                    invade_enable, invade_switch, invade_file, camera_para_url,
                    device_height, repeat_dis, repeat_time, high_count_warn
                )
            except Exception as e:
                logger.error(f"处理流程异常: {e}")
                raise

        # Start processing as its own task and register it *before* waiting.
        # Previously the task was registered only after processing had finished,
        # with task_handle=None, and `await None` then skipped remove_task.
        task_handle = asyncio.create_task(process_flow())
        task_info = {
            "source_url": video_url,
            "push_url": push_url,
            "status": "running",
            "task_handle": task_handle,
            "model_configs": model_configs,
            "device_height": device_height,
            "repeat_dis": repeat_dis,
            "repeat_time": repeat_time
        }
        try:
            await task_manager.add_task(task_id, task_info, task_handle, [])
            try:
                # Shield so we can tell "the processing task was cancelled" (a stop
                # request; normal) apart from "this coroutine was cancelled".
                await asyncio.shield(task_handle)
            except asyncio.CancelledError:
                if not task_handle.cancelled():
                    raise
                logger.info(f"任务 {task_id} 已停止")
        finally:
            if task_id in task_manager.tasks:
                await task_manager.remove_task(task_id)
            if not task_handle.done():
                task_handle.cancel()
    except asyncio.CancelledError:
        logger.info(f"任务 {task_id} 收到停止信号,正在清理...")
        raise
    except Exception as e:
        logger.error(f"任务 {task_id} 处理失败: {e}")
        raise