ai_project_v1/yolo_api.py

1737 lines
70 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import time
from sanic import Sanic, Websocket
from sanic.response import json as json_response
from sanic.exceptions import Unauthorized, NotFound, SanicException
from dataclasses import dataclass, asdict
from typing import List, Dict, Any, Optional
import uuid
import logging
import asyncio
import traceback
from datetime import datetime
from sympy import false
from websockets.exceptions import ConnectionClosed
from CropLand_CD_module.detection import corpland_detection_func
from cropland_module.detection import detection_func
from middleware.AsyncioMqttClient import AsyncMQTTClient, ConnectionContext, active_connections
from middleware.TaskManager import TaskManager, task_manager
from middleware.minio_util import downFile
from middleware.query_model import ModelConfigDAO
from middleware.query_postgress import batch_query_model_func_id
from middleware.read_yolo_config import read_local_func_config, get_local_func_by_id, \
get_local_func_by_id_and_category
# from Ai_tottle.Cropland-CD import detection_func
from pic_detect import pic_detect_func, pic_detect_func_trt
from touying.getmq_sendresult import CalTouYing
from uav_module.mqtt_request import send_mqtt_uv_request
from uav_module.segementation import segementation_func
from yolo.cv_multi_model_back_video import start_video_processing, start_rtmp_processing, \
cleanup_resources
from cv_video import startAIVideo, stopAIVideo, getIfAI
from cv_back_video import startBackAIVideo
from sanic_cors import CORS
from sanic import Sanic, Request
# Logging configuration.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Service health flags. Request handlers log a warning and reset "is_healthy"
# to True when they find it False; safe_stop_ai_video() fills in all three
# fields when stopping the AI video fails.
service_status = {"is_healthy": True, "last_error": None, "error_time": None}

import os  # noqa: E402 - needed only for the environment overrides below

# PostgreSQL connection settings passed to ModelConfigDAO and
# batch_query_model_func_id. Every value can be overridden with an
# environment variable; the defaults are the values that used to be
# hard-coded here, so existing deployments keep working unchanged.
# NOTE(review): credentials should not live in source control - set the
# YOLO_DB_* variables in production and drop these defaults.
DB_CONFIG = {
    "dbname": os.environ.get("YOLO_DB_NAME", "smart_dev_123"),
    "user": os.environ.get("YOLO_DB_USER", "postgres"),
    "password": os.environ.get("YOLO_DB_PASSWORD", "root"),
    "host": os.environ.get("YOLO_DB_HOST", "8.137.54.85"),
    "port": os.environ.get("YOLO_DB_PORT", "5060"),
}


class Config:
    """Service-wide settings."""

    # Token expected in the X-API-Token request header (checked by
    # verify_token); can be overridden with the YOLO_API_TOKEN env variable.
    VALID_TOKEN = os.environ.get("YOLO_API_TOKEN", "5e8899fe-dc74-4280-8169-2f4d185f3afa")
    MAX_ACTIVE_TASKS = 10
    DEFAULT_CONFIDENCE = 0.5
    # Seconds safe_stop_ai_video waits before marking the service healthy again.
    RESTART_DELAY = 2
@dataclass()
class BaseContentBody:
    """Common base class of every ``content_body`` variant.

    It declares no fields yet; fields shared by all variants belong here.
    """
@dataclass
class ContentBodyFormat_VideoMultiBackDetect(BaseContentBody):
    """content_body for multi-model detection on a stored MinIO video.

    ``RequestJson.from_dict`` builds this variant whenever the body contains
    ``minio_file_path``; it is processed by ``run_back_Video_Multi_Detect_async``.
    Field order is the positional ``__init__`` order - do not reorder.
    """
    # Former MQTT fields; the MQTT config is now looked up by org_code/sn:
    # mqtt_pub_id: int
    org_code: str  # organisation code used for the MQTT config and device lookups
    # mqtt_ip: str
    # mqtt_port: int
    # mqtt_topic: str
    minio_file_path: str  # path of the video file to process
    push_url: str  # temporary, for testing
    confidence: float
    para_list: list  # per-model entries: "func_id", optional "py_func" / "para_invade_enable"
    # invade_file: str
    invade: list  # NOTE(review): annotated as list but read as a dict ("invade_file", "camera_para_url")
@dataclass
class ContentBodyFormat_MultiBackDetect(BaseContentBody):
    """content_body for multi-model detection on a live stream.

    ``RequestJson.from_dict`` builds this variant when the body contains
    ``para_list`` but not ``minio_file_path``; it is processed by
    ``run_back_Multi_Detect_async``. Field order is the positional
    ``__init__`` order - do not reorder.
    """
    # Former MQTT fields; the pub/sub MQTT configs are now looked up by org_code/sn:
    # mqtt_pub_id: int
    # mqtt_sub_id: int
    # mqtt_pub_ip: str
    # mqtt_pub_port: int
    # mqtt_pub_topic: str
    # mqtt_sub_ip: str
    # mqtt_sub_port: int
    # mqtt_sub_topic: str
    org_code: str  # organisation code used for the MQTT config and device lookups
    func_id: list
    source_url: str  # stream URL handed to start_rtmp_processing
    push_url: str  # temporary, for testing
    confidence: float
    para_list: list  # per-model entries: "func_id", optional "py_func" / "para_invade_enable"
    # invade_file: str
    invade: list  # NOTE(review): annotated as list but read as a dict ("invade_file", "camera_para_url")
@dataclass
class ContentBodyFormat_BackDetect(BaseContentBody):
    """content_body for single-stream back-end detection (/ai/stream/back_detect).

    ``RequestJson.from_dict`` builds this variant when the body contains
    ``source_url`` but neither ``minio_file_path`` nor ``para_list``.
    Field order is the positional ``__init__`` order - do not reorder.
    """
    mqtt_ip: str
    mqtt_port: int
    mqtt_topic: str
    source_url: str  # stream to analyse
    push_url: str  # temporary, for testing
    confidence: float
    func_id: List[int]
    # Free-form per-request parameters, e.g. "N" for func_id 100021/100023.
    # Annotated as ``dict``: the previous annotation ``{}`` was a dict literal,
    # not a type, which type checkers reject.
    para: dict
@dataclass
class ContentBodyFormat_BackDetectPic(BaseContentBody):
    """content_body for picture detection over a list of image URLs.

    ``RequestJson.from_dict`` builds this variant when the body has ``s3_id``,
    a list-valued ``s3_url`` and a ``confidence`` key.
    Field order is the positional ``__init__`` order - do not reorder.
    """
    s3_id: int
    s3_url: list[str]
    org_code: str  # organisation code used to look up the picture MQTT config
    # mqtt_pub_id: int
    confidence: float
    func_id: list[int]
    # Free-form per-request parameters. Annotated as ``dict``: the previous
    # annotation ``{}`` was a dict literal, not a type.
    para: dict
@dataclass()
class EarlyLaterUrls:
    """The ``early``/``later`` URL pair carried by ``ContentBodyFormat_Detection.s3_url``."""

    early: str  # URL of the earlier image
    later: str  # URL of the later image
@dataclass
class ContentBodyFormat_Detection(BaseContentBody):
    """content_body whose ``s3_url`` is an ``{"early": ..., "later": ...}`` pair.

    ``RequestJson.from_dict`` builds this variant when ``s3_url`` is a dict
    holding both keys; ``validate`` requires both URLs to be non-empty.
    """
    s3_id: int
    s3_url: EarlyLaterUrls
    func_id: list[int]
@dataclass
class ContentBodyFormat_Segementation(BaseContentBody):
    """content_body variant with a list of image URLs and no confidence.

    ``RequestJson.from_dict`` builds this variant when ``s3_url`` is a list and
    the body has no ``confidence`` key. The "Segementation" spelling is part of
    the class name used elsewhere and is kept as-is.
    """
    s3_id: int
    s3_url: list[str]
    func_id: list[int]
@dataclass
class RequestJson:
    """Request envelope shared by the detection endpoints.

    ``task_id`` identifies the task, ``sn`` is the device serial number (used
    for the device and MQTT lookups), and ``content_body`` is one of the
    ``ContentBodyFormat_*`` dataclasses, chosen by :meth:`from_dict`.
    """

    task_id: str
    sn: str
    content_body: BaseContentBody

    @staticmethod
    def _check_confidence(confidence: Any, type_name: str) -> None:
        """Raise ValueError unless *confidence* is None or a number in the open interval (0, 1)."""
        if confidence is None:
            return
        # Reject bools and non-numbers explicitly: a string from the JSON body
        # would otherwise make the comparison raise TypeError (HTTP 500).
        is_number = isinstance(confidence, (int, float)) and not isinstance(confidence, bool)
        if not is_number or not 0 < confidence < 1:
            raise ValueError(f"Confidence must be between 0 and 1 for {type_name}")

    def validate(self) -> None:
        """Check the fields required by the concrete ``content_body`` variant.

        Raises:
            ValueError: ``task_id`` is empty, a required field of the variant is
                missing/empty, or ``confidence`` is not strictly between 0 and 1.
        """
        if not self.task_id:
            raise ValueError("task_id is required")
        body = self.content_body
        if isinstance(body, ContentBodyFormat_VideoMultiBackDetect):
            if not body.minio_file_path:
                raise ValueError("minio_file_path is required for ContentBodyFormat_VideoMultiBackDetect")
            self._check_confidence(body.confidence, "ContentBodyFormat_VideoMultiBackDetect")
        elif isinstance(body, ContentBodyFormat_MultiBackDetect):
            if not body.para_list:
                raise ValueError("para_list is required for ContentBodyFormat_MultiBackDetect")
            self._check_confidence(body.confidence, "ContentBodyFormat_MultiBackDetect")
        elif isinstance(body, ContentBodyFormat_BackDetect):
            if not body.source_url:
                raise ValueError("source_url is required for ContentBodyFormat_BackDetect")
            self._check_confidence(body.confidence, "ContentBodyFormat_BackDetect")
        elif isinstance(body, ContentBodyFormat_BackDetectPic):
            if not body.s3_id:
                raise ValueError("s3_id is required for ContentBodyFormat_BackDetectPic")
            if not body.s3_url:
                raise ValueError("s3_url is required for ContentBodyFormat_BackDetectPic")
            self._check_confidence(body.confidence, "ContentBodyFormat_BackDetectPic")
        elif isinstance(body, ContentBodyFormat_Detection):
            if not body.s3_id:
                raise ValueError("s3_id is required for ContentBodyFormat_Detection")
            if not body.s3_url.early or not body.s3_url.later:
                raise ValueError("Both early and later URLs are required for ContentBodyFormat_Detection")
        elif isinstance(body, ContentBodyFormat_Segementation):
            if not body.s3_id:
                raise ValueError("s3_id is required for ContentBodyFormat_Segementation")
            if not body.s3_url:
                raise ValueError("s3_url is required for ContentBodyFormat_Segementation")

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'RequestJson':
        """Build and validate a RequestJson from a decoded JSON request body.

        Raises:
            ValueError: the body or its ``content_body`` is not a JSON object,
                a required key is missing, the ``content_body`` shape is not
                recognised, or :meth:`validate` fails.
        """
        # request.json is None for an empty body; reject non-objects up front so
        # the endpoints answer 400 (ValueError) rather than 500 (TypeError).
        if not isinstance(data, dict):
            raise ValueError("Request body must be a JSON object")
        try:
            task_id = data['task_id']
            sn = data['sn']
            content_body = cls._parse_content_body(data['content_body'])
        except KeyError as e:
            raise ValueError(f"Missing required field: {str(e)}") from e
        instance = cls(
            task_id=task_id,
            sn=sn,
            content_body=content_body
        )
        instance.validate()
        return instance

    @staticmethod
    def _parse_content_body(body: Any) -> BaseContentBody:
        """Pick and build the content_body variant from the keys present (first match wins).

        ``minio_file_path`` -> VideoMultiBackDetect; ``para_list`` -> MultiBackDetect;
        ``source_url`` -> BackDetect; ``s3_id`` + ``s3_url`` -> Detection (dict with
        ``early``/``later``), Segementation (list, no ``confidence``) or
        BackDetectPic (list). Missing required keys raise KeyError.
        """
        if not isinstance(body, dict):
            raise ValueError("Invalid content_body format")
        if 'minio_file_path' in body:
            return ContentBodyFormat_VideoMultiBackDetect(
                org_code=body['org_code'],
                minio_file_path=body['minio_file_path'],
                push_url=body['push_url'],
                confidence=body.get('confidence', 0.5),
                para_list=body.get('para_list', []),
                invade=body.get('invade', [])
            )
        if 'para_list' in body:
            return ContentBodyFormat_MultiBackDetect(
                org_code=body['org_code'],
                func_id=body['func_id'],
                source_url=body['source_url'],
                push_url=body['push_url'],
                confidence=body.get('confidence', 0.5),
                para_list=body.get('para_list', []),
                invade=body.get('invade', [])
            )
        if 'source_url' in body:
            return ContentBodyFormat_BackDetect(
                mqtt_ip=body['mqtt_ip'],
                mqtt_port=body['mqtt_port'],
                mqtt_topic=body['mqtt_topic'],
                source_url=body['source_url'],
                push_url=body['push_url'],
                confidence=body.get('confidence', 0.5),
                func_id=body.get('func_id', []),
                para=body.get('para', {})
            )
        if 's3_id' in body and 's3_url' in body:
            s3_url = body['s3_url']
            if isinstance(s3_url, dict) and 'early' in s3_url and 'later' in s3_url:
                return ContentBodyFormat_Detection(
                    s3_id=body['s3_id'],
                    s3_url=EarlyLaterUrls(early=s3_url['early'], later=s3_url['later']),
                    func_id=body.get('func_id', [])
                )
            if isinstance(s3_url, list) and 'confidence' not in body:
                return ContentBodyFormat_Segementation(
                    s3_id=body['s3_id'],
                    s3_url=s3_url,
                    func_id=body.get('func_id', [])
                )
            if isinstance(s3_url, list):
                return ContentBodyFormat_BackDetectPic(
                    s3_id=body['s3_id'],
                    s3_url=s3_url,
                    org_code=body['org_code'],
                    confidence=body.get('confidence', 0.5),
                    func_id=body.get('func_id', []),
                    para=body.get('para', {})
                )
            raise ValueError("Invalid s3_url format for ContentBodyFormat_Detection")
        raise ValueError("Invalid content_body format")
async def heartbeat_monitor(task_manager: TaskManager):
    """Periodically remove tasks that ``task_manager`` reports as unhealthy.

    Loops forever: sleeps ``task_manager.heartbeat_interval`` seconds, asks
    ``task_manager.check_tasks_health()`` for the ids of timed-out tasks and
    removes each one. Errors in a round are logged and the loop carries on,
    so one failure does not end monitoring for the rest of the process.
    Cancellation (``asyncio.CancelledError``, a BaseException) still stops it.
    """
    while True:
        await asyncio.sleep(task_manager.heartbeat_interval)
        try:
            unhealthy_tasks = await task_manager.check_tasks_health()
        except Exception:
            logger.exception("Heartbeat health check failed; retrying next interval")
            continue
        for task_id in unhealthy_tasks:
            logger.warning(f"Task {task_id} has timed out, removing...")
            try:
                await task_manager.remove_task(task_id)
            except Exception:
                # Keep removing the remaining tasks even if one removal fails.
                logger.exception(f"Failed to remove timed-out task {task_id}")
# Sanic application instance; the @app.post routes below are registered on it.
app = Sanic("YoloStreamService1")
# Enable CORS on the app through sanic_cors, called with no options (its defaults).
CORS(app)
async def start_heartbeat_monitor():
    """Run ``heartbeat_monitor`` against the module-level ``task_manager``.

    Awaits the monitor directly, so this coroutine only finishes when the
    monitor itself stops.
    """
    monitor = heartbeat_monitor(task_manager)
    await monitor
#
# @app.listener("before_server_start")
# async def setup_task_manager(app, loop):
# await task_manager.start()
# asyncio.create_task(start_heartbeat_monitor())
def _terminate_ffmpeg_children() -> None:
    """Send SIGTERM to every descendant process of this process whose name contains 'ffmpeg'.

    Best effort: every failure (including psutil not being installed) is logged,
    never raised.
    """
    try:
        import os
        import signal
        import psutil
        current_process = psutil.Process(os.getpid())
        for child in current_process.children(recursive=True):
            try:
                child_name = child.name().lower()
                if 'ffmpeg' in child_name:
                    logger.info(f"强制终止子进程: {child.pid} ({child_name})")
                    child.send_signal(signal.SIGTERM)
            except Exception as child_e:
                logger.error(f"终止子进程出错: {str(child_e)}")
    except Exception as kill_e:
        logger.error(f"尝试清理进程时出错: {str(kill_e)}")


async def safe_stop_ai_video():
    """Stop AI video processing, falling back to a forced cleanup on failure.

    Runs ``stopAIVideo`` in a worker thread. If it raises, the error is recorded
    in ``service_status`` (marked unhealthy), all tasks are marked stopped, ffmpeg
    child processes are sent SIGTERM, and after ``Config.RESTART_DELAY`` seconds
    the service is marked healthy again. Each recovery step is guarded, so a
    failure in one step does not skip the others.

    Returns:
        True if ``stopAIVideo`` succeeded, False if the fallback path ran.
    """
    try:
        await asyncio.to_thread(stopAIVideo)
        return True
    except Exception as e:
        error_msg = f"停止AI视频处理出错: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        # Mark the service unhealthy while recovering.
        service_status["is_healthy"] = False
        service_status["last_error"] = str(e)
        service_status["error_time"] = datetime.now().isoformat()
        # Force every task to stopped; guarded so the cleanup below still runs.
        try:
            task_manager.mark_all_tasks_as_stopped()
        except Exception as mark_e:
            logger.error(f"Failed to mark all tasks as stopped: {mark_e}")
        # Kill any ffmpeg processes that may still be running.
        _terminate_ffmpeg_children()
        # Give the system a moment to recover, then report healthy again.
        await asyncio.sleep(Config.RESTART_DELAY)
        service_status["is_healthy"] = True
        return False
def verify_token(request) -> None:
    """Reject the request unless its ``X-API-Token`` header equals ``Config.VALID_TOKEN``.

    Raises:
        Unauthorized: the header is missing, empty, or does not match.
    """
    import secrets  # local import, matching the file's existing local-import style

    token = request.headers.get('X-API-Token')
    # compare_digest takes time independent of where the values differ, so the
    # check does not leak how much of the token a caller guessed. Compare UTF-8
    # bytes because compare_digest rejects str arguments containing non-ASCII.
    if not token or not secrets.compare_digest(
            token.encode("utf-8"), Config.VALID_TOKEN.encode("utf-8")):
        logger.warning("Invalid token attempt")
        raise Unauthorized("Invalid token")
async def run_picDetect_async(request_json):
    """Run picture detection for a ``ContentBodyFormat_BackDetectPic`` request.

    Looks up the picture MQTT publish config for ``content_body.org_code``, then,
    for each id in ``content_body.func_id``, loads that model's config and runs
    ``pic_detect_func`` in a worker thread (one call at a time). Ids without a
    model config are logged and skipped.

    Raises:
        SanicException: status 500, wrapping any lookup or detection error.
    """
    # TODO: not yet specialised per detection method.
    try:
        dao = ModelConfigDAO(DB_CONFIG)
        content_body = request_json.content_body
        org_code = content_body.org_code
        mqtt_pub_config = dao.get_pic_mqtt_config_by_orgcode(org_code, "pic")
        if mqtt_pub_config is None:
            raise ValueError(f"No pic MQTT config found for org_code {org_code}")
        mqtt_pub_ip = mqtt_pub_config.mqtt_ip
        mqtt_pub_port = mqtt_pub_config.mqtt_port
        mqtt_pub_topic = mqtt_pub_config.mqtt_topic
        print(f"mqtt_pub_topic {mqtt_pub_topic}")
        for model_id in content_body.func_id:
            # TODO: should also filter by a category field.
            config = dao.get_config(model_id, [])
            if not config:
                logger.warning(f"No model config found for func_id {model_id}, skipping")
                continue
            await asyncio.to_thread(
                pic_detect_func,
                mqtt_pub_ip,
                mqtt_pub_port,
                mqtt_pub_topic,
                request_json.task_id,
                content_body.func_id,
                content_body.para,
                content_body.s3_url,
                config.model_path,
                config.filter_indices,
                config.filtered_cls_en_dict,
                config.conf
            )
    except Exception as e:
        logger.error(f"启动AI视频处理失败: {e}")
        raise SanicException(f"Failed to start AI video processing: {str(e)}", status_code=500) from e
# 未针对具体方法做实现,待完成
async def run_backDetect_async(request_json):
"""异步运行分割检测算法"""
try:
model_path = "pt/best.pt" # 默认使用人车
model_cls = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
chinese_label = {0: "行人", 1: "", 2: "自行车", 3: "汽车", 4: "厢型车", 5: "卡车", 6: "三轮车", 7: "三轮车",
8: "公交", 9: "摩托"}
# 针对湖南现场特殊处理
py_func = request_json.content_body.para_list[0]["py_func"]
for v in request_json.content_body.func_id:
match v:
case 100000:
# 0: honeycomb & rough_surface
# 1: peeling & chipping
# 2: cavity & hole
# 3: rust on rebar
# 4: damage & exposedrebar
# 5: creck
model_path = "pt/best.pt" # 桥梁裂缝检测
model_cls = [0, 1]
chinese_label = {0: "行人", 1: ""}
case 100004:
# 0: honeycomb & rough_surface
# 1: peeling & chipping
# 2: cavity & hole
# 3: rust on rebar
# 4: damage & exposedrebar
# 5: creck
model_path = "pt/best.pt" # 桥梁裂缝检测
model_cls = [0, 1]
chinese_label = {0: "行人", 1: ""}
case 100002:
# 0: honeycomb & rough_surface
# 1: peeling & chipping
# 2: cavity & hole
# 3: rust on rebar
# 4: damage & exposedrebar
# 5: creck
model_path = "pt/best.pt" # 桥梁裂缝检测
# model_cls = [2, 3, 4, 5, 6, 7, 8, 9]
# chinese_label = {2: "自行车", 3: "汽车", 4: "厢型车", 5: "卡车", 6: "三轮车",
# 7: "三轮车", 8: "公交", 9: "摩托"}
model_cls = py_func
chinese_label = {0: "", 1: "行人", 2: "自行车", 3: "汽车", 4: "厢型车", 5: "卡车", 6: "三轮车",
7: "三轮车", 8: "公交", 9: "摩托"}
# 针对湖北现场特殊处理
new_chinese_label = {}
seen_values = set()
for key in model_cls:
if key in chinese_label:
value = chinese_label[key]
if value not in seen_values:
new_chinese_label[key] = value
seen_values.add(value)
chinese_label = new_chinese_label
print("去重后字典:", model_cls, " ", new_chinese_label)
case 100006:
# 0: honeycomb & rough_surface
# 1: peeling & chipping
# 2: cavity & hole
# 3: rust on rebar
# 4: damage & exposedrebar
# 5: creck
model_path = "pt/best.pt" # 桥梁裂缝检测
model_cls = [2, 3, 4, 5, 6, 7, 8, 9]
chinese_label = {2: "自行车", 3: "汽车", 4: "厢型车", 5: "卡车", 6: "三轮车",
7: "三轮车", 8: "公交", 9: "摩托"}
case 100031:
# 0: honeycomb & rough_surface
# 1: peeling & chipping
# 2: cavity & hole
# 3: rust on rebar
# 4: damage & exposedrebar
# 5: creck
model_path = "pt/build.pt" # 桥梁裂缝检测
model_cls = [0, 1, 2, 3, 4, 5]
chinese_label = {0: "蜂窝", 1: "剥落", 2: "空腔", 3: "锈蚀", 4: "裸露", 5: "裂缝"}
case 100041:
model_path = "pt/smoke.pt" # 烟雾,对应类型为 0: smoke1: hote(infrared)&fire
model_cls = [0, 1]
chinese_label = {0: "烟雾", 1: ""}
case 100051:
# 0: 挖掘机
# 1: 自卸卡车
# 2: 压路机
# 3: 移动式起重机
# 4: 塔式起重机
# 5: 轮式装载机
# 6: 混凝土搅拌车
# 7: 反铲装载机
# 8: 推土机
# 9: 平地机
model_path = "pt/GDCL.pt" # 红外人车
model_cls = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
chinese_label = {0: "挖掘机", 1: "自卸卡车", 2: "压路机", 3: "移动式起重机", 4: "塔式起重机",
5: "轮式装载机", 6: "混凝土搅拌车", 7: "反铲装载机", 8: "推土机", 9: "平地机"}
case 100061:
# 0: human
# 1: track
# 2: car
# 3: bicycle
model_path = "pt/HWRC.pt" # 红外人车
model_cls = [0, 1, 2, 3]
chinese_label = {0: "", 1: "卡车", 2: "", 3: "自行车"}
case 100071:
# 0: 垃圾
model_path = "pt/trash.pt" # 只有中文垃圾,所以是乱码
model_cls = [0]
chinese_label = {0: "垃圾"}
case 100081:
# 0: dyrb
# 1: dmjrb
# 10: ygfs_hw
# 11: ejgdl_ycdw
# 2: dyrb_ycdw
# 3: dmjrb_ycdw
# 4: ycdw
# 5: dyrb_zd
# 6: zd_hw
# 7: dmjrb_zd
# 8: ycdw_zd
# 9: ejgdl
model_path = "pt/hwgf.pt" # 红外人车
model_cls = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
chinese_label = {0: "单一热班", 1: "大面积热斑", 2: "单一热班&异常低温", 3: "大面积热班&异常低温",
4: "异常低温", 5: "单一热斑&遮挡", 6: "异常低温&热斑", 7: "二极管短路",
8: "阳光反射", 9: "二极管短路&异常低温"}
case _:
model_path = "pt/gdaq.pt" # 默认使用人车
model_cls = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
chinese_label = {0: "行人", 1: "", 2: "自行车", 3: "汽车", 4: "厢型车", 5: "卡车", 6: "三轮车",
7: "三轮车", 8: "公交", 9: "摩托"}
await asyncio.to_thread(
startBackAIVideo,
request_json.task_id,
# request_json.content_body.model_func_id[0],
request_json.content_body.source_url,
request_json.content_body.push_url, # 临时增加推流
# "best.pt",
model_path,
model_cls,
chinese_label,
request_json.content_body.func_id,
request_json.content_body.confidence,
request_json.content_body.para,
request_json.content_body.mqtt_ip,
request_json.content_body.mqtt_port,
request_json.content_body.mqtt_topic,
)
except Exception as e:
logger.error(f"启动AI视频处理失败: {e}")
raise SanicException(f"Failed to start AI video processing: {str(e)}", status_code=500)
async def run_back_Multi_Detect_async(request, request_json, stop_event: asyncio.Event):
    """Run multi-model detection on a live stream (``ContentBodyFormat_MultiBackDetect``).

    1. For each ``content_body.para_list`` entry, query the model-function rows
       for its ``func_id`` and load the model config (filtered by the entry's
       ``py_func``), collecting one ``model_configs`` dict per row found.
    2. Write a request-log row, then load the pub/sub MQTT configs and the
       device record for ``(org_code, sn)``.
    3. Run ``start_rtmp_processing`` as a task registered in ``task_manager``
       under ``task_id``. Wait for it to finish or be cancelled, then
       unregister it.

    Args:
        request: the Sanic request; used only for the request log.
        request_json: the parsed ``RequestJson``.
        stop_event: kept for interface compatibility; not used here.

    Raises:
        ValueError: a missing MQTT config or device record, or no usable model config.
        Exception: any lookup or processing error is logged and re-raised.
    """
    task_id = request_json.task_id
    try:
        content_body = request_json.content_body
        dao = ModelConfigDAO(DB_CONFIG)

        model_configs = []
        invade_enable = False  # becomes True if any para_list entry enables intrusion detection
        repeat_dis = -1  # frame-to-frame distance dedup; stays -1 unless a config sets it
        repeat_time = None
        for para in content_body.para_list:
            func_id = para["func_id"]
            category = para.get("py_func", [])  # class filter supplied by the caller
            para_invade_enable = para.get("para_invade_enable", False)
            if para_invade_enable:
                invade_enable = True
            for row in batch_query_model_func_id([func_id], **DB_CONFIG):
                row_func_id = row["model_func_id"]
                config = dao.get_config(func_id, category)
                if not config:
                    print(f"未找到ID为 {func_id} 的模型配置")
                    continue
                print("\n模型路径:", config.model_path)
                print("过滤类别:", config.filter_indices)
                if config.classes:  # asdict(classes[0]) raises IndexError on an empty list
                    print("第一个类别:", asdict(config.classes[0]))
                print("创建时间:", config.created_at)
                print("更新时间:", config.updated_at)
                print("去重的距离:", config.repeat_dis)
                repeat_dis = config.repeat_dis
                repeat_time = config.repeat_time
                model_configs.append({
                    'path': config.model_path,
                    'engine_path': config.engine_path,
                    'so_path': config.so_path,
                    'cls_map': config.cls_zn_to_eh_dict,
                    'allowed_classes': config.allowed_classes,
                    "cls_index": config.class_indices,
                    "class_names": config.cls_names,
                    "chinese_label": config.cls_en_dict,
                    "list_func_id": row_func_id,
                    "func_id": func_id,
                    "para_invade_enable": para_invade_enable,
                    "config_conf": config.conf
                })

        video_url = content_body.source_url
        sn = request_json.sn
        org_code = content_body.org_code
        push_url = content_body.push_url
        # invade is annotated as a list (from_dict defaults it to []) but is read
        # as a dict; fall back to empty values instead of raising TypeError.
        invade = content_body.invade if isinstance(content_body.invade, dict) else {}
        invade_file = invade.get("invade_file", "")
        camera_para_url = invade.get("camera_para_url", "")

        # str(request.socket) is appended in the hope of recording the caller's address.
        str_request = str(request) + "&" + str(request.socket)
        dao.insert_request_log(task_id, sn, org_code, str(request.body), str_request)

        mqtt_pub_config = dao.get_mqtt_config_by_orgcode(org_code, sn, "pub")
        mqtt_sub_config = dao.get_mqtt_config_by_orgcode(org_code, sn, "sub")
        if not mqtt_pub_config or not mqtt_sub_config:
            raise ValueError(f"No pub/sub MQTT config for org_code={org_code}, sn={sn}")
        mqtt_pub_ip = mqtt_pub_config.mqtt_ip
        mqtt_pub_port = mqtt_pub_config.mqtt_port
        mqtt_pub_topic = mqtt_pub_config.mqtt_topic
        print(f"mqtt_pub_topic {mqtt_pub_topic}")
        mqtt_sub_ip = mqtt_sub_config.mqtt_ip
        mqtt_sub_port = mqtt_sub_config.mqtt_port
        # The stored subscribe topic may contain a {sn} placeholder.
        mqtt_sub_topic = mqtt_sub_config.mqtt_topic.format(sn=sn)
        print(f"mqtt_sub_topic {mqtt_sub_topic}")

        device = dao.get_device(sn, org_code)
        print(f"device表 {sn} {org_code}")
        if device is None:
            raise ValueError(f"No device record for sn={sn}, org_code={org_code}")
        device_height = device.height  # dock height, later used for on-site elevation calculation

        if not model_configs:
            raise ValueError(f"No usable model config found in para_list for task {task_id}")

        async def process_flow():
            try:
                await start_rtmp_processing(
                    video_url,
                    task_id,
                    model_configs,
                    mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic,
                    mqtt_sub_ip, mqtt_sub_port, mqtt_sub_topic,
                    push_url,
                    invade_enable, invade_file, camera_para_url,
                    device_height, repeat_dis, repeat_time
                )
            except Exception as e:
                logger.error(f"处理流程异常: {e}")
                raise

        # Register the task while it is running, so task_manager can see and
        # cancel it, and always unregister it once it ends.
        task_handle = asyncio.create_task(process_flow())
        task_info = {
            "source_url": video_url,
            "push_url": push_url,
            "status": "running",
            "task_handle": task_handle,
            "model_configs": model_configs,
            "device_height": device_height,
            "repeat_dis": repeat_dis,
            "repeat_time": repeat_time
        }
        try:
            await task_manager.add_task(task_id, task_info, task_handle, [])  # no sub-tasks yet
        except Exception:
            task_handle.cancel()  # do not leave an untracked stream running
            raise
        try:
            await task_handle
        except asyncio.CancelledError:
            pass  # cancelling the task is the normal way to stop it
        finally:
            # The task may already have been removed (e.g. by whoever cancelled it).
            if task_id in task_manager.tasks:
                await task_manager.remove_task(task_id)
    except asyncio.CancelledError:
        logger.info(f"任务 {request_json.task_id} 收到停止信号,正在清理...")
        raise
    except Exception as e:
        logger.error(f"任务 {request_json.task_id} 处理失败: {e}")
        raise
async def run_back_Video_Multi_Detect_async(request, request_json):
    """Run multi-model detection on a stored MinIO video (``ContentBodyFormat_VideoMultiBackDetect``).

    Builds ``model_configs`` from ``content_body.para_list``, loads the MQTT
    publish config and the device record for ``(org_code, sn)``, writes a
    request-log row, and awaits ``start_video_processing``.

    Returns:
        None. Returns early, after printing a message, when no MQTT publish
        config exists for ``(org_code, sn)``.

    Raises:
        SanicException: status 500, wrapping any error raised on the way.
    """
    # TODO: not yet specialised per detection method.
    try:
        content_body = request_json.content_body
        dao = ModelConfigDAO(DB_CONFIG)
        sn = request_json.sn
        task_id = request_json.task_id
        org_code = content_body.org_code

        model_configs = []
        invade_enable = False  # becomes True if any para_list entry enables intrusion detection
        repeat_dis = -1  # frame-to-frame distance dedup; stays -1 unless a config sets it
        repeat_time = None
        # TODO: replace the per-entry lookups with a single batched postgres query.
        for para in content_body.para_list:
            category = para.get("py_func", [])  # class filter supplied by the caller
            for row in batch_query_model_func_id([para["func_id"]], **DB_CONFIG):
                row_func_id = row["model_func_id"]
                func_id = row["func_id"]
                para_invade_enable = para.get("para_invade_enable", False)
                if para_invade_enable:
                    invade_enable = True
                config = dao.get_config(func_id, category)
                if not config:
                    print(f"未找到ID为 {func_id} 的模型配置")
                    continue
                print("\n模型路径:", config.model_path)
                print("过滤类别:", config.filter_indices)
                if config.classes:  # asdict(classes[0]) raises IndexError on an empty list
                    print("第一个类别:", asdict(config.classes[0]))
                print("创建时间:", config.created_at)
                print("更新时间:", config.updated_at)
                print("去重的距离:", config.repeat_dis)
                repeat_dis = config.repeat_dis
                repeat_time = config.repeat_time
                model_configs.append({
                    'path': config.model_path,
                    'cls_map': config.cls_zn_to_eh_dict,
                    'allowed_classes': config.allowed_classes,
                    "cls_index": config.class_indices,
                    "class_names": config.cls_names,
                    "chinese_label": config.cls_en_dict,
                    "list_func_id": row_func_id,
                    "func_id": func_id,
                    "config_conf": config.conf,
                    "para_invade_enable": para_invade_enable
                })

        mqtt_pub_config = dao.get_mqtt_config_by_orgcode(org_code, sn, "pub")
        if not mqtt_pub_config:
            print("org_code 查询失败")
            return
        mqtt_pub_ip = mqtt_pub_config.mqtt_ip
        mqtt_pub_port = mqtt_pub_config.mqtt_port
        mqtt_pub_topic = mqtt_pub_config.mqtt_topic
        print(f"mqtt_pub_topic {mqtt_pub_topic}")

        device = dao.get_device(sn, org_code)
        if device is None:
            raise ValueError(f"No device record for sn={sn}, org_code={org_code}")
        device_height = device.height  # dock height, later used for on-site elevation calculation

        # Record the call; str(request.socket) is appended in the hope of capturing the caller's address.
        str_request = str(request) + "&" + str(request.socket)
        dao.insert_request_log(task_id, sn, org_code, str(request.body), str_request)

        # invade is annotated as a list (from_dict defaults it to []) but is read
        # as a dict; fall back to empty values instead of raising TypeError.
        invade = content_body.invade if isinstance(content_body.invade, dict) else {}
        invade_file = invade.get("invade_file", "")
        camera_para_url = invade.get("camera_para_url", "")

        if not model_configs:
            raise ValueError(f"No usable model config found in para_list for task {task_id}")
        await start_video_processing(content_body.minio_file_path, task_id, model_configs,
                                     mqtt_pub_ip, mqtt_pub_port, mqtt_pub_topic,
                                     content_body.push_url,
                                     invade_enable, invade_file, camera_para_url,
                                     device_height, repeat_dis, repeat_time)
    except Exception as e:
        logger.error(f"启动AI视频处理失败: {e}")
        raise SanicException(f"Failed to start AI video processing: {str(e)}", status_code=500) from e
# 接收前端实时流进行后台计算并且将结果存储到minio
@app.post("/ai/stream/back_detect")
async def start_back_detection(request):
    """Start single-stream back-end detection (``ContentBodyFormat_BackDetect``).

    Requires a valid ``X-API-Token`` header; otherwise ``verify_token`` raises
    Unauthorized, which Sanic turns into a 401. Every task currently in
    ``task_manager`` is stopped first. The body is then validated,
    ``run_backDetect_async`` is started in the background, and the task is
    recorded.

    Responses: 200 "success"; 200 "failed" for a body of another shape;
    400 on validation errors; 500 when ``para["N"]`` is missing for func_id
    100021/100023, or on unexpected errors.
    """
    import inspect  # local import: only needed by settle() below

    async def settle(result):
        # task_manager.add_task/remove_task are awaited elsewhere in this file.
        # Await the result when it is awaitable, so the call is not dropped as an
        # un-awaited coroutine; a plain return value is accepted too.
        if inspect.isawaitable(result):
            await result

    # Outside the try block: the generic handler below would otherwise turn
    # Unauthorized into a 500 response.
    verify_token(request)
    try:
        if not service_status["is_healthy"]:
            logger.warning(
                f"服务处于不健康状态,上次错误: {service_status['last_error']}{service_status['error_time']}")
            # Attempt to recover the service.
            service_status["is_healthy"] = True
        # Stop every existing task before starting the new one.
        for task_id in list(task_manager.tasks.keys()):
            logger.info(f"停止现有任务 {task_id} 以启动新任务")
            try:
                if await safe_stop_ai_video():
                    await settle(task_manager.remove_task(task_id))
                else:
                    logger.warning(f"无法正常停止任务 {task_id},但仍将继续")
                    task_manager.mark_all_tasks_as_stopped()
            except Exception as e:
                logger.error(f"停止任务时出错: {e}")
                # Keep going and still try to start the new task.

        request_json = RequestJson.from_dict(request.json)
        content_body = request_json.content_body
        if not isinstance(content_body, ContentBodyFormat_BackDetect):
            return json_response({
                "status": "failed",
                "message": "content_body structure is wrong"
            })

        # Parameter checks are keyed on func_id: people-gathering (100021) and
        # vehicle-gathering (100023) detection need para["N"]. A missing key is
        # treated like None instead of surfacing as a KeyError.
        para = content_body.para if isinstance(content_body.para, dict) else {}
        needs_n = 100021 in content_body.func_id or 100023 in content_body.func_id
        if needs_n and para.get("N") is None:
            return json_response({
                "status": "error",
                "message": "para[N] could not be None"
            }, status=500)

        try:
            asyncio.create_task(run_backDetect_async(request_json))
        except Exception as e:
            logger.error(f"启动AI视频处理失败: {e}")
            return json_response({
                "status": "error",
                "message": f"Failed to start AI video processing: {str(e)}"
            }, status=500)

        # Record the task.
        await settle(task_manager.add_task(request_json.task_id, {
            "task_id": request_json.task_id,
            "source_url": content_body.source_url,
            "func_id": content_body.func_id
        }))
        return json_response({
            "status": "success",
            "task_id": request_json.task_id,
            "message": "Detection started successfully"
        })
    except ValueError as e:
        logger.error(f"Validation error: {str(e)}")
        return json_response({"status": "error", "message": str(e)}, status=400)
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}", exc_info=True)
        return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
@app.post("/ai/stream/multi_back_detect")
async def start_multi_back_detection(request):
    """Start multi-model detection on a live stream (``ContentBodyFormat_MultiBackDetect``).

    Requires a valid ``X-API-Token`` header; otherwise ``verify_token`` raises
    Unauthorized, which Sanic turns into a 401. A ``task_id`` already present in
    ``task_manager.tasks`` is rejected. ``run_back_Multi_Detect_async`` runs as a
    background task whose errors are logged.

    Responses: 200 "success"; 200 "failed" for a body of another shape;
    400 on validation errors; 500 for a duplicate task id or unexpected errors.
    """
    # Outside the try block: the generic handler below would otherwise turn
    # Unauthorized into a 500 response.
    verify_token(request)
    try:
        if not service_status["is_healthy"]:
            logger.warning(
                f"服务处于不健康状态,上次错误: {service_status['last_error']}{service_status['error_time']}")
            service_status["is_healthy"] = True
        request_json = RequestJson.from_dict(request.json)
        print(f"/ai/stream/multi_back_detect 请求:{request.json}")
        # Non-blocking pause: time.sleep() here froze the whole event loop, including
        # every running detection task, for 3 seconds.
        # NOTE(review): the reason for this delay is not documented - confirm it is needed.
        await asyncio.sleep(3)
        if request_json.task_id in task_manager.tasks:
            logger.warning(f"任务 {request_json.task_id} 已存在,跳过创建")
            return json_response({
                "status": "error",
                "message": f"任务 {request_json.task_id} 已存在,跳过创建"
            }, status=500)
        if not isinstance(request_json.content_body, ContentBodyFormat_MultiBackDetect):
            return json_response({
                "status": "failed",
                "message": "content_body structure is wrong"
            })
        try:
            stop_event = asyncio.Event()

            async def wrapped_processing():
                # Log instead of propagating, since nothing awaits this background task.
                try:
                    await run_back_Multi_Detect_async(request, request_json, stop_event)
                except asyncio.CancelledError:
                    logger.info(f"任务 {request_json.task_id} 被取消")
                except Exception as e:
                    logger.error(f"任务 {request_json.task_id} 异常终止: {e}")

            asyncio.create_task(wrapped_processing())
        except Exception as e:
            logger.error(f"启动AI视频处理失败: {e}")
            return json_response({
                "status": "error",
                "message": f"Failed to start AI video processing: {str(e)}"
            }, status=500)
        return json_response({
            "status": "success",
            "task_id": request_json.task_id,
            "message": "Detection started successfully"
        })
    except ValueError as e:
        logger.error(f"Validation error: {str(e)}")
        return json_response({"status": "error", "message": str(e)}, status=400)
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}", exc_info=True)
        return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
# 接收前端实时流进行后台计算并且将结果存储到minio
@app.post("/ai/stream/video_multi_back_detect")
async def start_video_multi_back_detect(request):
    """Start multi-model detection on a stored video (``ContentBodyFormat_VideoMultiBackDetect``).

    Requires a valid ``X-API-Token`` header; otherwise ``verify_token`` raises
    Unauthorized, which Sanic turns into a 401. ``run_back_Video_Multi_Detect_async``
    runs as a background task. On unexpected errors ``cleanup_resources`` is
    attempted before answering 500.

    Responses: 200 "success"; 200 "failed" for a body of another shape;
    400 on validation errors; 500 on unexpected errors.
    """
    # Outside the try block: the generic handler below would otherwise turn
    # Unauthorized into a 500 response (and needlessly run cleanup_resources).
    verify_token(request)
    try:
        if not service_status["is_healthy"]:
            logger.warning(
                f"服务处于不健康状态,上次错误: {service_status['last_error']}{service_status['error_time']}")
            # Attempt to recover the service.
            service_status["is_healthy"] = True
        request_json = RequestJson.from_dict(request.json)
        print(f"/ai/stream/video_multi_back_detect 请求:{request.json}")
        # Non-blocking pause: time.sleep() here froze the whole event loop for 3 seconds.
        # NOTE(review): the reason for this delay is not documented - confirm it is needed.
        await asyncio.sleep(3)
        if not isinstance(request_json.content_body, ContentBodyFormat_VideoMultiBackDetect):
            return json_response({
                "status": "failed",
                "message": "content_body structure is wrong"
            })
        try:
            asyncio.create_task(run_back_Video_Multi_Detect_async(request, request_json))
        except Exception as e:
            logger.error(f"启动AI视频处理失败: {e}")
            return json_response({
                "status": "error",
                "message": f"Failed to start AI video processing: {str(e)}"
            }, status=500)
        return json_response({
            "status": "success",
            "task_id": request_json.task_id,
            "message": "Detection started successfully"
        })
    except ValueError as e:
        logger.error(f"Validation error: {str(e)}")
        return json_response({"status": "error", "message": str(e)}, status=400)
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}", exc_info=True)
        # Clean up resources; a cleanup failure must not replace the JSON error response.
        try:
            await cleanup_resources()
        except Exception as cleanup_e:
            logger.error(f"cleanup_resources failed: {cleanup_e}")
        return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
# Background detection on a submitted picture; the work runs in run_picDetect_async, scheduled as an asyncio task
@app.post("/ai/pic/back_detect_pic")
async def start_back_pic_detection(request):
    """Stop running tasks, then schedule background picture detection.

    Verifies the token, stops every task registered in ``task_manager``,
    parses the body with ``RequestJson.from_dict`` and, when ``content_body``
    is a ``ContentBodyFormat_BackDetectPic``, schedules
    ``run_picDetect_async(request_json)`` as an asyncio task.

    Responses:
        200 ``success`` -- the task was scheduled.
        200 ``failed``  -- ``content_body`` has the wrong structure.
        400 ``error``   -- a ``ValueError`` was raised while parsing.
        500 ``error``   -- scheduling failed, or any other unexpected error.
    """
    try:
        verify_token(request)
        logger.debug("/ai/pic/back_detect_pic: request received")
        # Health check: log the last recorded error, then try to recover by
        # marking the service healthy again.
        if not service_status["is_healthy"]:
            logger.warning(
                f"服务处于不健康状态,上次错误: {service_status['last_error']}{service_status['error_time']}")
            service_status["is_healthy"] = True
        # Stop every existing task first; failures are logged and ignored so
        # the new task can still be started.
        for running_id in list(task_manager.tasks.keys()):
            logger.info(f"停止现有任务 {running_id} 以启动新任务")
            try:
                stopped = await safe_stop_ai_video()
                if stopped:
                    # remove_task is awaited elsewhere in this module, i.e. it
                    # returns a coroutine; calling it bare would never run it.
                    await task_manager.remove_task(running_id)
                else:
                    logger.warning(f"无法正常停止任务 {running_id},但仍将继续")
                    task_manager.mark_all_tasks_as_stopped()
            except Exception as e:
                logger.error(f"停止任务时出错: {e}")
        request_json = RequestJson.from_dict(request.json)
        if not isinstance(request_json.content_body, ContentBodyFormat_BackDetectPic):
            return json_response({
                "status": "failed",
                "message": "content_body structure is wrong"
            })
        try:
            logger.debug("/ai/pic/back_detect_pic: scheduling run_picDetect_async")
            asyncio.create_task(run_picDetect_async(request_json))
        except Exception as e:
            logger.error(f"启动AI视频处理失败: {e}")
            return json_response({
                "status": "error",
                "message": f"Failed to start AI video processing: {str(e)}"
            }, status=500)
        return json_response({
            "status": "success",
            "task_id": request_json.task_id,
            "message": "Detection started successfully"
        })
    except ValueError as e:
        logger.error(f"Validation error: {str(e)}")
        return json_response({"status": "error", "message": str(e)}, status=400)
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}", exc_info=True)
        return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
@app.post("/ai/stream/detect")
async def start_detection(request):
    """Start live YOLO detection on a stream (plain-JSON interface).

    Expected JSON object keys: ``source_url``, ``push_url``, ``model_path``
    (relative to the ``pt/`` folder), ``detect_classes`` and ``confidence``.
    After validation, every task registered in ``task_manager`` is stopped and
    ``startAIVideo`` is run in a worker thread via ``asyncio.to_thread``.

    Responses:
        200 ``success`` with a freshly generated ``task_id``.
        400 ``error``   -- body is not a JSON object or a key is missing.
        500 ``error``   -- starting video processing failed, or any other error.
    """
    required_fields = ("source_url", "push_url", "model_path", "detect_classes", "confidence")
    try:
        verify_token(request)
        # Health check: log the last recorded error, then try to recover by
        # marking the service healthy again.
        if not service_status["is_healthy"]:
            logger.warning(
                f"服务处于不健康状态,上次错误: {service_status['last_error']}{service_status['error_time']}")
            service_status["is_healthy"] = True
        # Validate before stopping anything, so a malformed request cannot tear
        # down a running job; missing keys become ValueError -> HTTP 400.
        request_json = request.json
        if not isinstance(request_json, dict):
            raise ValueError("Request body must be a JSON object")
        missing = [name for name in required_fields if name not in request_json]
        if missing:
            raise ValueError(f"Missing required field(s): {', '.join(missing)}")
        # Stop every existing task first; failures are logged and ignored so
        # the new task can still be started.
        for running_id in list(task_manager.tasks.keys()):
            logger.info(f"停止现有任务 {running_id} 以启动新任务")
            try:
                stopped = await safe_stop_ai_video()
                if stopped:
                    # remove_task is awaited elsewhere in this module, i.e. it
                    # returns a coroutine; calling it bare would never run it.
                    await task_manager.remove_task(running_id)
                else:
                    logger.warning(f"无法正常停止任务 {running_id},但仍将继续")
                    task_manager.mark_all_tasks_as_stopped()
            except Exception as e:
                logger.error(f"停止任务时出错: {e}")
        source_url = request_json["source_url"]
        push_url = request_json["push_url"]
        # Models for this legacy endpoint all live under the pt/ folder.
        model_path = "pt/" + request_json["model_path"]
        detect_classes = request_json["detect_classes"]
        confidence = request_json["confidence"]
        task_id = str(uuid.uuid4())
        try:
            await asyncio.to_thread(
                startAIVideo,
                source_url,
                push_url,
                model_path,
                detect_classes,
                confidence
            )
        except Exception as e:
            logger.error(f"启动AI视频处理失败: {e}")
            return json_response({
                "status": "error",
                "message": f"Failed to start AI video processing: {str(e)}"
            }, status=500)
        return json_response({
            "status": "success",
            "task_id": task_id,
            "message": "Detection started successfully"
        })
    except ValueError as e:
        logger.error(f"Validation error: {str(e)}")
        return json_response({"status": "error", "message": str(e)}, status=400)
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}", exc_info=True)
        return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
async def run_detection_async(request_json):
    """Run two-period change detection (``corpland_detection_func``) in a worker thread.

    Calls ``corpland_detection_func(task_id, s3_id, s3_url.early,
    s3_url.later, func_id)`` via ``asyncio.to_thread`` so the blocking work
    does not stall the event loop.

    Raises:
        SanicException: status 500, wrapping (and chained to) any error raised
            while reading the request fields or running the detection.
            NOTE(review): pic_detection schedules this with asyncio.create_task,
            so this exception never reaches an HTTP response; it only shows up
            as an unretrieved task exception.
    """
    try:
        await asyncio.to_thread(
            # detection_func,  # alternative detector, kept for reference
            corpland_detection_func,
            request_json.task_id,
            request_json.content_body.s3_id,
            request_json.content_body.s3_url.early,
            request_json.content_body.s3_url.later,
            request_json.content_body.func_id
        )
    except Exception as e:
        logger.error(f"启动AI视频处理失败: {e}", exc_info=True)
        raise SanicException(f"Failed to start AI video processing: {str(e)}", status_code=500) from e
async def run_segmentation_async(request_json):
    """Dispatch a segmentation request through ``send_mqtt_uv_request`` off the event loop.

    The call ``send_mqtt_uv_request(task_id, s3_id, s3_url, func_id)`` runs in
    a worker thread. Any failure is logged and re-raised as a
    ``SanicException`` with status 500.
    """
    try:
        body = request_json.content_body
        call_args = (
            request_json.task_id,
            body.s3_id,
            body.s3_url,
            body.func_id,
        )
        await asyncio.to_thread(send_mqtt_uv_request, *call_args)
    except Exception as exc:
        logger.error(f"启动AI视频处理失败: {exc}")
        raise SanicException(f"Failed to start AI video processing: {str(exc)}", status_code=500)
# 两期影像变化监测
@app.post("/ai/pic/detection")
async def pic_detection(request):
    """Schedule two-period image change detection for the posted request.

    After token verification and the health-flag reset, the body is parsed
    with ``RequestJson.from_dict``. A ``ContentBodyFormat_Detection`` body is
    handed to ``run_detection_async`` as a fire-and-forget asyncio task.

    Responses:
        200 ``success`` -- the task was scheduled.
        200 ``failed``  -- ``content_body`` has the wrong structure.
        400 ``error``   -- a ``ValueError`` was raised while parsing.
        500 ``error``   -- scheduling failed, or any other unexpected error.
    """
    try:
        verify_token(request)
        status = service_status
        if not status["is_healthy"]:
            logger.warning(
                f"服务处于不健康状态,上次错误: {service_status['last_error']}{service_status['error_time']}")
            # Recover optimistically before accepting new work.
            status["is_healthy"] = True

        parsed = RequestJson.from_dict(request.json)

        # Guard clause: reject bodies of the wrong shape up front.
        if not isinstance(parsed.content_body, ContentBodyFormat_Detection):
            return json_response({
                "status": "failed",
                "message": "content_body structure is wrong"
            })

        try:
            asyncio.create_task(run_detection_async(parsed))
        except Exception as err:
            logger.error(f"启动AI视频处理失败: {err}")
            failure = {
                "status": "error",
                "message": f"Failed to start AI video processing: {str(err)}"
            }
            return json_response(failure, status=500)

        accepted = {
            "status": "success",
            "task_id": parsed.task_id,
            "message": "Detection started successfully"
        }
        return json_response(accepted)
    except ValueError as err:
        logger.error(f"Validation error: {str(err)}")
        return json_response({"status": "error", "message": str(err)}, status=400)
    except Exception as err:
        logger.error(f"Unexpected error: {str(err)}", exc_info=True)
        return json_response({"status": "error", "message": f"Internal server error: {str(err)}"}, status=500)
@app.post("/ai/pic/segementation")
async def pic_segementation(request):
    """Stop running tasks, then schedule image segmentation.

    Verifies the token, stops every task registered in ``task_manager``,
    parses the body with ``RequestJson.from_dict`` and, when ``content_body``
    is a ``ContentBodyFormat_Segementation``, schedules
    ``run_segmentation_async(request_json)`` as an asyncio task.

    Responses:
        200 ``success`` -- the task was scheduled.
        400 ``failed``  -- ``content_body`` has the wrong structure.
        400 ``error``   -- a ``ValueError`` was raised while parsing.
        500 ``error``   -- any other unexpected error.
    """
    try:
        verify_token(request)
        # Health check: log the last recorded error, then try to recover by
        # marking the service healthy again.
        if not service_status["is_healthy"]:
            logger.warning(
                f"服务处于不健康状态,上次错误: {service_status['last_error']}{service_status['error_time']}")
            service_status["is_healthy"] = True
        # Stop every existing task first; failures are logged and ignored so
        # the new task can still be started.
        for running_id in list(task_manager.tasks.keys()):
            logger.info(f"停止现有任务 {running_id} 以启动新任务")
            try:
                stopped = await safe_stop_ai_video()
                if stopped:
                    # remove_task is awaited elsewhere in this module, i.e. it
                    # returns a coroutine; calling it bare would never run it.
                    await task_manager.remove_task(running_id)
                else:
                    logger.warning(f"无法正常停止任务 {running_id},但仍将继续")
                    task_manager.mark_all_tasks_as_stopped()
            except Exception as e:
                logger.error(f"停止任务时出错: {e}")
        request_json = RequestJson.from_dict(request.json)
        if not isinstance(request_json.content_body, ContentBodyFormat_Segementation):
            return json_response({
                "status": "failed",
                "message": "content_body structure is wrong"
            }, status=400)
        # Run the segmentation dispatch in the background.
        asyncio.create_task(run_segmentation_async(request_json))
        return json_response({
            "status": "success",
            "task_id": request_json.task_id,
            "message": "Detection started successfully"
        })
    except ValueError as e:
        logger.error(f"Validation error: {str(e)}")
        return json_response({"status": "error", "message": str(e)}, status=400)
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}", exc_info=True)
        return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
# 查询任务状态
# Sanic path parameters use <name:type>; a "{task_id}" segment would be matched
# literally and never bind the handler's task_id argument.
@app.get("/api/tasks/<task_id:str>")
async def get_task_status(request: Request, task_id: str):
    """Return the info stored in ``task_manager`` for one task.

    Responses:
        200 ``{"status": "success", "data": info}`` when found.
        404 ``error`` when ``get_task_info`` returns a falsy value.
        500 ``error`` on any exception (message is the exception text).
    """
    try:
        task_info = await task_manager.get_task_info(task_id)
        if not task_info:
            return json_response({
                "status": "error",
                "message": "Task not found"
            }, status=404)
        return json_response({
            "status": "success",
            "data": task_info
        })
    except Exception as e:
        logger.error(f"Failed to get task status: {str(e)}")
        return json_response({
            "status": "error",
            "message": str(e)
        }, status=500)
# 查询所有任务
@app.get("/api/tasks")
async def list_tasks(request: Request):
    """Return every task known to ``task_manager``.

    Success: ``{"status": "success", "data": <get_all_tasks() result>}``.
    Failure: the exception is logged and its text returned with status 500.
    """
    try:
        all_tasks = await task_manager.get_all_tasks()
        payload = {"status": "success", "data": all_tasks}
        return json_response(payload)
    except Exception as exc:
        logger.error(f"Failed to list tasks: {str(exc)}")
        failure = {"status": "error", "message": str(exc)}
        return json_response(failure, status=500)
#
# # REST API 删除任务接口
# @app.route("/api/tasks/<task_id:str>", methods=["DELETE"])
# async def delete_task(request, task_id: str):
# """删除任务并取消所有关联的协程"""
# if task_id in task_manager.tasks:
# task_info = task_manager.tasks[task_id]["task_info"]
# if 'cancel_flag' in task_info:
# task_info['cancel_flag'].set() # 触发任务取消
#
# # 取消主任务
# if 'main_task' in task_info and not task_info['main_task'].done():
# task_info['main_task'].cancel()
#
# # 清理任务信息
# del task_manager.tasks[task_id]
#
# return json_response({"status": "success", "message": f"Task {task_id} deleted successfully"})
# else:
# return json_response({"status": "error", "message": f"Task {task_id} not found"}, status=404)
@app.post("/ai/task/reset")
async def reset_task_heart(request):
    """Refresh the heartbeat of a task via ``task_manager.update_heartbeat``.

    Expects a JSON object body with a non-empty ``task_id``.

    Responses:
        200 ``success`` -- the heartbeat was updated.
        500 ``error``   -- ``task_id`` missing/empty, or any exception.
    """
    try:
        verify_token(request)
        payload = request.json if isinstance(request.json, dict) else {}
        task_id = payload.get("task_id")
        # Validate before touching the task manager; the check used to run only
        # after update_heartbeat had already been called with the bad id.
        if not task_id:
            return json_response({
                "status": "error",
                "message": f"task:{task_id} is None"
            }, status=500)
        await task_manager.update_heartbeat(task_id)
        return json_response({
            "status": "success",
            "message": "Task update_heartbeat successfully"
        })
    except Exception as e:
        logger.error(f"重置任务心跳失败: {str(e)}")
        return json_response({
            "status": "error",
            "message": str(e)}, status=500)
@app.post("/ai/task/stop")
async def stop_task_heart(request):
    """Stop all running AI video work and remove the requested task.

    Expects a JSON object body with a non-empty ``task_id``. For every task
    currently in ``task_manager.tasks``, ``safe_stop_ai_video()`` and
    ``stopAIVideo()`` are called. The task is removed when the safe stop
    succeeds; otherwise all tasks are marked stopped. Per-task errors are
    logged and ignored. Finally the requested ``task_id`` is removed.

    Responses:
        200 ``success`` -- stop sequence completed.
        500 ``error``   -- ``task_id`` missing/empty, or any exception.
    """
    try:
        verify_token(request)
        payload = request.json if isinstance(request.json, dict) else {}
        task_id = payload.get("task_id")
        # Validate before stopping anything.
        if not task_id:
            return json_response({
                "status": "error",
                "message": f"task:{task_id} is None"
            }, status=500)
        # Use a distinct loop variable: reusing ``task_id`` overwrote the
        # requested id, so the final remove_task hit the last running task.
        for running_id in list(task_manager.tasks.keys()):
            logger.info(f"停止现有任务 {running_id}")
            try:
                stopped = await safe_stop_ai_video()
                # NOTE(review): stopAIVideo() runs synchronously on the event
                # loop; if it can block, consider asyncio.to_thread.
                stopAIVideo()
                if stopped:
                    await task_manager.remove_task(running_id)
                else:
                    logger.warning(f"无法正常停止任务 {running_id},但仍将继续")
                    task_manager.mark_all_tasks_as_stopped()
            except Exception as e:
                logger.error(f"停止任务时出错: {e}")
        await task_manager.remove_task(task_id)
        return json_response({
            "status": "success",
            "message": "Task stop successfully"
        })
    except Exception as e:
        logger.error(f"停止任务失败: {str(e)}")
        return json_response({
            "status": "error",
            "message": str(e)}, status=500)
# Camera-parameter file per drone model. Temporary measure (dated 20250928):
# the file should really be bound to the individual aircraft, not the model.
_CAMERA_PARA_URL_BY_MODEL = {
    "M3D": "meta_data/camera_para/xyzj_camera_para.txt",
    "M4D": "meta_data/camera_para/xyzj_camera_para.txt",
    # NOTE(review): path kept verbatim, including the space before ".txt";
    # confirm it matches the stored object name.
    "M4TD": "meta_data/camera_para/hami_camera_para .txt",
}


def _camera_para_url_for_model(model2):
    """Return the camera-parameter file path for ``model2``.

    Raises:
        ValueError: if ``model2`` is not one of the mapped model names.
    """
    url = _CAMERA_PARA_URL_BY_MODEL.get(model2) if isinstance(model2, str) else None
    if url is None:
        raise ValueError(f"Unsupported model2: {model2}")
    return url


@app.websocket("/ws")
async def websocket_endpoint(request: Request, ws):
    """Serve one frontend WebSocket connection with its own MQTT client.

    Each message must be a JSON object with ``img_width``, ``img_height``,
    ``file_url``, ``broker``, ``port``, ``topic`` and ``model2`` (optional
    ``client_id``). For each message the handler:

    1. downloads ``file_url`` and the camera-parameter file for ``model2``;
    2. builds a ``CalTouYing`` calculator on the connection context;
    3. (re)connects the context's MQTT client if it is disconnected or the
       broker/port changed;
    4. subscribes to ``topic`` and replies with a JSON status.

    Per-message errors are sent back as ``{"status": "error", "error": ...}``
    and the loop keeps reading. The connection lives in ``active_connections``
    until the socket closes, then its MQTT client is disconnected and the
    entry removed.
    """
    connection_id = str(uuid.uuid4())
    logger.info(f"New WebSocket connection established. Connection ID: {connection_id}")
    context = ConnectionContext(connection_id, ws)
    active_connections[connection_id] = context
    try:
        while True:
            message = await ws.recv()
            logger.info(f"Connection {connection_id} received WebSocket message: {message}")
            try:
                data = json.loads(message)
                # model2 is read unconditionally below, so it is required too.
                required_fields = ["img_width", "img_height", "file_url", "broker", "port", "topic", "model2"]
                for field in required_fields:
                    if field not in data:
                        raise ValueError(f"Missing required field: {field}")
                # Resolve per message so an unmapped model fails here instead of
                # leaving the path unbound or reusing one from an earlier message.
                camera_para_url = _camera_para_url_for_model(data["model2"])
                local_file_path = downFile(data["file_url"])
                if not local_file_path or not isinstance(local_file_path, str):
                    raise ValueError("Failed to download file or invalid file path")
                camera_file_path = downFile(camera_para_url)
                try:
                    logger.info(f"Connection {connection_id} initializing calculator with file path: {local_file_path}")
                    context.calculator = CalTouYing(
                        str(local_file_path),
                        str(camera_file_path),
                        int(data["img_width"]),
                        int(data["img_height"])
                    )
                    context.file_path = local_file_path
                except Exception as calc_error:
                    logger.error(f"Connection {connection_id} calculator initialization failed: {calc_error}")
                    raise ValueError(f"Calculator initialization failed: {str(calc_error)}") from calc_error
                broker = data["broker"]
                port = int(data["port"])
                topic = data["topic"]
                client_id = data.get("client_id", f"sanic-client-{connection_id}")
                # Reconnect only when disconnected or pointed at a different broker.
                if (not context.mqtt_client.connected
                        or context.mqtt_client.broker != broker
                        or context.mqtt_client.port != port):
                    if context.mqtt_client.connected:
                        logger.info(f"Connection {connection_id} reconnecting to new MQTT broker")
                        await context.mqtt_client.disconnect()
                    success = await context.mqtt_client.connect(broker, port, client_id)
                    if not success:
                        raise ConnectionError("Failed to connect to MQTT broker")
                subscribe_success = await context.mqtt_client.subscribe(topic)
                if not subscribe_success:
                    raise ConnectionError(f"Failed to subscribe to topic {topic}")
                await ws.send(json.dumps({
                    "status": "success",
                    "message": f"Connected to MQTT broker {broker}:{port} and subscribed to topic {topic}",
                    "connection_id": connection_id
                }))
            except json.JSONDecodeError:
                logger.error(f"Connection {connection_id} received invalid JSON format")
                await ws.send(json.dumps({
                    "status": "error",
                    "error": "Invalid JSON format"
                }))
            except ValueError as ve:
                logger.error(f"Connection {connection_id} value error: {str(ve)}")
                await ws.send(json.dumps({
                    "status": "error",
                    "error": str(ve)
                }))
            except Exception as e:
                logger.error(f"Connection {connection_id} error processing message: {str(e)}", exc_info=True)
                await ws.send(json.dumps({
                    "status": "error",
                    "error": f"Server error: {str(e)}"
                }))
    except ConnectionClosed:
        logger.info(f"客户端断开连接 WebSocket connection closed by client. Connection ID: {connection_id}")
    except Exception as e:
        logger.error(f"WebSocket error for connection {connection_id}: {str(e)}", exc_info=True)
    finally:
        if connection_id in active_connections:
            try:
                if context.mqtt_client.connected:
                    logger.info(f"Disconnecting MQTT for connection {connection_id}")
                    await context.mqtt_client.disconnect()
            except Exception as e:
                logger.error(f"Connection {connection_id} MQTT disconnect failed: {e}")
            finally:
                # Always drop the registry entry, even if disconnect failed.
                del active_connections[connection_id]
            logger.info(f"Connection {connection_id} cleaned up. Remaining connections: {len(active_connections)}")
if __name__ == "__main__":
    # Make sure no AI video job is left over from a previous run before serving.
    try:
        stopAIVideo()
        print("服务启动前清理完成")
    except Exception:
        # Best-effort cleanup; catching Exception (not a bare except) lets
        # KeyboardInterrupt/SystemExit still stop startup.
        print("服务启动前清理失败,但仍将继续")
    # psutil is used for process management; install it on the fly if missing.
    # NOTE(review): installing packages at runtime is fragile; also, after the
    # install below the module is not re-imported in this process.
    try:
        import psutil
    except ImportError:
        import subprocess
        import sys
        print("正在安装psutil库...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "psutil"])
    app.run(host="0.0.0.0", port=12315, debug=False, access_log=True)
    # app.run(host="0.0.0.0", workers=3, port=12315)