# 541 lines, 20 KiB, Python
import asyncio
import logging
import os
import traceback
import uuid
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional

from sanic import Sanic, json
from sanic.exceptions import NotFound, SanicException, Unauthorized
from sanic.response import json as json_response
from sanic_cors import CORS

from cv_video import getIfAI, startAIVideo, stopAIVideo


# Logging configuration: INFO and above, each record prefixed with time, logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Process-wide service health flags: safe_stop_ai_video() records stop failures
# here and the /ai/health endpoint reports them.
service_status = dict(is_healthy=True, last_error=None, error_time=None)


# Configuration
class Config:
    """Static service settings used by validation, task bookkeeping and recovery."""

    # Token every protected endpoint expects in the X-API-Token header.
    # Set YOLO_STREAM_API_TOKEN to keep the secret out of source control; the
    # historical value remains the fallback for backward compatibility (it is
    # committed in this file, so it should be rotated).
    VALID_TOKEN = (
        os.environ.get("YOLO_STREAM_API_TOKEN")
        or "5e8899fe-dc74-4280-8169-2f4d185f3afa"
    )
    # Upper bound on simultaneously tracked tasks (checked in TaskManager.add_task).
    MAX_ACTIVE_TASKS = 10
    # Confidence threshold used when a request omits "confidence".
    DEFAULT_CONFIDENCE = 0.5
    # Delay (seconds) before the service attempts automatic recovery after a failed stop.
    RESTART_DELAY = 2


@dataclass
class StreamRequest:
    """Parameters of a "start detection" request; build it with from_dict()."""

    source_url: str
    push_url: str
    model_path: str  # "yanwu.pt" is delegated to an external service by start_detection
    detect_classes: List[str]
    confidence: float = Config.DEFAULT_CONFIDENCE

    def validate(self) -> None:
        """Check field values.

        Raises:
            ValueError: a URL is empty, no detection class is given, or the
                confidence is not strictly between 0 and 1.
        """
        if not self.source_url or not self.push_url:
            raise ValueError("Source URL and Push URL are required")

        if not self.detect_classes:
            raise ValueError("At least one detection class must be specified")
        if not 0 < self.confidence < 1:
            raise ValueError("Confidence must be between 0 and 1")

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'StreamRequest':
        """Build and validate a request from a decoded JSON body.

        Every malformed input surfaces as ValueError, which the handler maps to
        HTTP 400. Previously, TypeError/AttributeError leaked out and became HTTP 500.

        Raises:
            ValueError: the body is not a JSON object, a required field is
                missing, the confidence is not numeric, or validate() fails.
        """
        # The decoded body may be None or a non-object JSON value (list, string, ...).
        if not isinstance(data, dict):
            raise ValueError("Request body must be a JSON object")
        try:
            source_url = data['source_url']
            push_url = data['push_url']
            model_path = data['model_path']
            detect_classes = data['detect_classes']
        except KeyError as e:
            raise ValueError(f"Missing required field: {str(e)}") from e

        raw_confidence = data.get('confidence', Config.DEFAULT_CONFIDENCE)
        # Accept numbers and numeric strings; reject anything else with a 400
        # instead of a TypeError from the range comparison in validate().
        try:
            confidence = float(raw_confidence)
        except (TypeError, ValueError):
            raise ValueError("Confidence must be a number") from None

        instance = cls(
            source_url=source_url,
            push_url=push_url,
            model_path=model_path,
            detect_classes=detect_classes,
            confidence=confidence,
        )
        instance.validate()
        return instance


class TaskManager:
    """In-memory registry of detection tasks, keyed by task id.

    Three parallel dicts hold, per task: the request parameters, a status string
    ("running", "stopped", ...) and the local start time. Nothing is persisted.
    """

    def __init__(self):
        self.active_tasks: Dict[str, Dict[str, Any]] = {}
        self.task_status: Dict[str, str] = {}
        self.task_timestamps: Dict[str, datetime] = {}

    def add_task(self, task_id: str, task_info: Dict[str, Any]) -> None:
        """Register a task as "running", stamped with the current local time.

        Raises:
            ValueError: Config.MAX_ACTIVE_TASKS tasks are already tracked.
                Re-registering an existing id replaces it and is not rejected.
        """
        if task_id not in self.active_tasks and len(self.active_tasks) >= Config.MAX_ACTIVE_TASKS:
            raise ValueError("Maximum number of active tasks reached")

        self.active_tasks[task_id] = task_info
        self.task_status[task_id] = "running"
        self.task_timestamps[task_id] = datetime.now()
        logger.info(f"Task {task_id} started")

    def remove_task(self, task_id: str) -> None:
        """Forget a task; unknown ids are ignored."""
        if task_id not in self.active_tasks:
            return
        del self.active_tasks[task_id]
        # pop() tolerates the status/timestamp dicts having drifted out of sync.
        self.task_status.pop(task_id, None)
        self.task_timestamps.pop(task_id, None)
        logger.info(f"Task {task_id} removed")

    def get_task_info(self, task_id: str) -> Dict[str, Any]:
        """Return a new dict with the task's info, status and ISO-formatted start time.

        Raises:
            NotFound: the task id is not tracked.
        """
        if task_id not in self.active_tasks:
            raise NotFound("Task not found")

        return {
            "task_info": self.active_tasks[task_id],
            "status": self.task_status[task_id],
            "start_time": self.task_timestamps[task_id].isoformat()
        }

    def check_tasks_health(self) -> Dict[str, str]:
        """Return {task_id: "stopped"} for every tracked task when getIfAI() is falsy.

        getIfAI() takes no task id, so a single call answers for all tasks; it
        is not called at all when nothing is tracked.
        """
        if not self.active_tasks or getIfAI():
            return {}

        unhealthy_tasks = {task_id: "stopped" for task_id in self.active_tasks}
        for task_id in unhealthy_tasks:
            logger.warning(f"Task {task_id} appears to be stopped unexpectedly")
        return unhealthy_tasks

    def mark_all_tasks_as_stopped(self):
        """Set the status of every tracked task to "stopped" (tasks stay registered)."""
        for task_id in self.active_tasks:
            self.task_status[task_id] = "stopped"
        logger.warning("已将所有任务标记为停止状态")


# Shared in-memory task registry used by the route handlers.
task_manager = TaskManager()

# Sanic application for the YOLO stream API; CORS enabled via sanic_cors defaults.
app = Sanic("YoloStreamService")
CORS(app)


def _terminate_ffmpeg_children() -> None:
    """Best effort: send SIGTERM to each descendant of this process whose name contains "ffmpeg".

    Synchronous (psutil walks the process table), so async callers should run it
    via asyncio.to_thread. Failures are logged, never raised.
    """
    try:
        import os
        import signal

        import psutil

        current_process = psutil.Process(os.getpid())
        # Find and terminate ffmpeg child processes
        for child in current_process.children(recursive=True):
            try:
                child_name = child.name().lower()
                if 'ffmpeg' in child_name:
                    logger.info(f"强制终止子进程: {child.pid} ({child_name})")
                    child.send_signal(signal.SIGTERM)
            except psutil.NoSuchProcess:
                # The child exited between listing and signalling: nothing left to stop.
                continue
            except Exception as child_e:
                logger.error(f"终止子进程出错: {str(child_e)}")
    except Exception as kill_e:
        logger.error(f"尝试清理进程时出错: {str(kill_e)}")


async def safe_stop_ai_video():
    """Stop AI video processing; on failure, run best-effort recovery instead of raising.

    Returns:
        True if stopAIVideo() completed normally.
        False if it raised. In that case:
          - the error is logged and recorded in service_status;
          - every tracked task is marked "stopped";
          - ffmpeg child processes are sent SIGTERM;
          - after Config.RESTART_DELAY seconds, is_healthy is set back to True
            (last_error/error_time are kept).
    """
    try:
        # Run in a worker thread so a slow stop cannot stall the event loop.
        await asyncio.to_thread(stopAIVideo)
        return True
    except Exception as e:
        error_msg = f"停止AI视频处理出错: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)

        # Mark the service as unhealthy
        service_status["is_healthy"] = False
        service_status["last_error"] = str(e)
        service_status["error_time"] = datetime.now().isoformat()

        # Force-end all tasks
        task_manager.mark_all_tasks_as_stopped()

        # Kill leftover ffmpeg processes; the process-table scan runs off the event loop.
        await asyncio.to_thread(_terminate_ffmpeg_children)

        # Give the system a moment to recover
        await asyncio.sleep(Config.RESTART_DELAY)

        # Reset the health flag
        service_status["is_healthy"] = True
        return False


def verify_token(request) -> None:
    """Reject the request unless its X-API-Token header equals Config.VALID_TOKEN.

    Raises:
        Unauthorized: the header is missing, empty, or does not match.
    """
    import hmac

    token = request.headers.get('X-API-Token')
    # compare_digest takes time independent of where the values differ, so
    # response timing cannot be used to guess the token incrementally.
    # Bytes are compared because compare_digest rejects non-ASCII str input.
    if not token or not hmac.compare_digest(
        token.encode("utf-8"), Config.VALID_TOKEN.encode("utf-8")
    ):
        logger.warning("Invalid token attempt")
        raise Unauthorized("Invalid token")


# Remote endpoint that starts a live task for the "yanwu.pt" model.
_YANWU_START_TASK_URL = "https://flightcontrol.huaiying-xunjian.com/prod-api/third/api/v1/task/startTask"


async def _stop_existing_tasks() -> None:
    """Stop the current AI video processing and drop every tracked task.

    stopAIVideo() takes no task id, so one stop covers all tracked tasks.
    Tasks are dropped even when the stop fails. Leaving them registered would let
    stale entries pile up until add_task() hits Config.MAX_ACTIVE_TASKS, which
    happens only after the new stream has already been started.
    """
    task_ids = list(task_manager.active_tasks.keys())
    if not task_ids:
        return

    for task_id in task_ids:
        logger.info(f"停止现有任务 {task_id} 以启动新任务")

    try:
        success = await safe_stop_ai_video()
    except Exception as e:
        logger.error(f"停止任务时出错: {e}")
        success = False

    if not success:
        for task_id in task_ids:
            logger.warning(f"无法正常停止任务 {task_id},但仍将继续")

    for task_id in task_ids:
        task_manager.remove_task(task_id)


async def _request_yanwu_live_url(source_url: str) -> str:
    """Ask the external service to start a "yanwu" task for source_url; return its liveUrl.

    Raises:
        Exception: non-200 status, a non-zero "code" in the reply, a reply
            without data.liveUrl, or any aiohttp/network error.
    """
    import io

    import aiohttp

    form_data = aiohttp.FormData()
    # The API expects a (here empty) file part. An in-memory buffer means nothing
    # is written to, or left behind in, the temp directory.
    form_data.add_field("file", io.BytesIO(b""), filename="empty.txt", content_type="text/plain")
    form_data.add_field("model", "yanwu")
    form_data.add_field("taskType", "2")
    form_data.add_field("source", source_url)
    # NOTE(review): this appends the literal text "+ai" to the source URL — confirm the remote API expects that.
    form_data.add_field("notifyUrl", f"{source_url}+ai")

    async with aiohttp.ClientSession() as session:
        async with session.post(_YANWU_START_TASK_URL, data=form_data) as resp:
            if resp.status != 200:
                raise Exception(f"外部服务状态码异常: {resp.status}")
            result = await resp.json()

    # Tolerate "data": null or non-object replies instead of crashing with AttributeError.
    data = result.get("data") if isinstance(result, dict) else None
    live_url = data.get("liveUrl") if isinstance(data, dict) else None
    if not isinstance(result, dict) or result.get("code") != 0 or not live_url:
        raise Exception(f"接口响应错误: {result}")
    return live_url


@app.post("/ai/stream/detect")
async def start_detection(request):
    """Start a detection task described by the JSON body, replacing any existing tasks.

    For model "yanwu.pt", processing is delegated to the external service and the
    returned liveUrl becomes the task's push_url. Otherwise startAIVideo() is run.

    Responses:
        200 with the new task_id.
        400 for an invalid body.
        401 for a bad token.
        500 if starting fails or on unexpected errors.
    """
    try:
        verify_token(request)

        # Check service health
        if not service_status["is_healthy"]:
            logger.warning(f"服务处于不健康状态,上次错误: {service_status['last_error']} 于 {service_status['error_time']}")
            service_status["is_healthy"] = True  # attempt recovery

        # Parse and validate BEFORE touching running tasks, so a malformed
        # request cannot tear down the stream that is currently active.
        stream_request = StreamRequest.from_dict(request.json)
        task_id = str(uuid.uuid4())

        # Stop every existing task before starting the new one
        await _stop_existing_tasks()

        if stream_request.model_path == "yanwu.pt":
            # "yanwu.pt": fetch the push URL from the external live-task service
            try:
                live_url = await _request_yanwu_live_url(stream_request.source_url)
            except Exception as ext:
                logger.error(f"调用外部直播任务接口失败: {ext}")
                return json_response({
                    "status": "error",
                    "message": f"调用直播任务接口失败: {str(ext)}"
                }, status=500)
            logger.info(f"外部接口返回推流地址: {live_url}")
            stream_request.push_url = live_url  # replace the push URL
        else:
            # Start YOLO detection in a worker thread
            try:
                await asyncio.to_thread(
                    startAIVideo,
                    stream_request.source_url,
                    stream_request.push_url,
                    stream_request.model_path,
                    stream_request.detect_classes,
                    stream_request.confidence
                )
            except Exception as e:
                logger.error(f"启动AI视频处理失败: {e}")
                return json_response({
                    "status": "error",
                    "message": f"Failed to start AI video processing: {str(e)}"
                }, status=500)

        # Record the task
        task_manager.add_task(task_id, {
            "source_url": stream_request.source_url,
            "push_url": stream_request.push_url,
            "model_path": stream_request.model_path,
            "detect_classes": stream_request.detect_classes,
            "confidence": stream_request.confidence
        })

        return json_response({
            "status": "success",
            "task_id": task_id,
            "message": "Detection started successfully"
        })

    except SanicException as e:
        # Sanic HTTP errors (Unauthorized from verify_token, BadRequest from a
        # malformed JSON body) keep their own status code instead of the 500 below.
        return json_response(
            {"status": "error", "message": str(e)},
            status=getattr(e, "status_code", None) or 500,
        )
    except ValueError as e:
        logger.error(f"Validation error: {str(e)}")
        return json_response({"status": "error", "message": str(e)}, status=400)
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}", exc_info=True)
        return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)


@app.post("/ai/stream/<task_id>")
async def stop_detection(request, task_id: str):
    """Stop AI processing and remove task `task_id` from the registry.

    Responses:
        200 "success".
        200 "warning" when stopping reported a failure (the task is removed anyway).
        401 for a bad token.
        404 for an unknown task.
        500 on unexpected errors.
    """
    try:
        verify_token(request)

        # Check that the task exists
        if task_id not in task_manager.active_tasks:
            return json_response({"status": "error", "message": "Task not found"}, status=404)

        # safe_stop_ai_video reports stop failures through its return value
        success = await safe_stop_ai_video()

        # Remove the task even if stopping failed
        task_manager.remove_task(task_id)

        if not success:
            logger.warning("虽然停止过程出现错误,但任务已被标记为结束")
            return json_response({
                "status": "warning",
                "message": "Task removal completed with warnings"
            })

        return json_response({
            "status": "success",
            "message": "Detection stopped successfully"
        })
    except SanicException as e:
        # Unauthorized (401), NotFound (404), ... keep their own status code
        # instead of being reported as 500 by the generic branch below.
        return json_response(
            {"status": "error", "message": str(e)},
            status=getattr(e, "status_code", None) or 500,
        )
    except Exception as e:
        logger.error(f"Error stopping task {task_id}: {str(e)}", exc_info=True)
        # Flag the task so status queries show that stopping went wrong
        if task_id in task_manager.task_status:
            task_manager.task_status[task_id] = "error_during_stop"
        return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)


@app.get("/ai/stream/<task_id>")
async def get_task_status(request, task_id: str):
    """Return the stored info of task `task_id`.

    A task recorded as "running" is reported as "stopped_unexpectedly" when
    getIfAI() is falsy. Only the returned copy is changed, not the stored status.

    Responses: 200, 401 for a bad token, 404 for an unknown task, 500 on unexpected errors.
    """
    try:
        verify_token(request)
        task_info = task_manager.get_task_info(task_id)

        # Check whether the task is really still running
        if task_info["status"] == "running" and not getIfAI():
            task_info["status"] = "stopped_unexpectedly"
            logger.warning(f"Task {task_id} 显示为运行状态,但实际已停止")

        # NOTE: task_info contains its own "status" key, which overrides the
        # literal "success" below, so clients receive the task status in
        # "status". Kept as-is because existing clients may rely on it.
        return json_response({
            "status": "success",
            "task_id": task_id,
            **task_info
        })
    except SanicException as e:
        # Unauthorized (401) / NotFound (404) keep their own status code.
        return json_response(
            {"status": "error", "message": str(e)},
            status=getattr(e, "status_code", None) or 500,
        )
    except Exception as e:
        logger.error(f"Error getting task status {task_id}: {str(e)}", exc_info=True)
        return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)


@app.get("/ai/stream/tasks")
async def list_tasks(request):
    """List every tracked task, first refreshing statuses from the live getIfAI() check.

    Responses: 200 with {"tasks": {task_id: info}}, 401 for a bad token, 500 on unexpected errors.
    """
    try:
        verify_token(request)

        # Store "stopped" for tasks the live check reports as not running
        for task_id, status in task_manager.check_tasks_health().items():
            if task_id in task_manager.task_status:
                task_manager.task_status[task_id] = status

        tasks = {
            task_id: task_manager.get_task_info(task_id)
            for task_id in task_manager.active_tasks
        }
        return json_response({
            "status": "success",
            "tasks": tasks
        })
    except SanicException as e:
        # Unauthorized (401) keeps its own status code instead of becoming a 500.
        return json_response(
            {"status": "error", "message": str(e)},
            status=getattr(e, "status_code", None) or 500,
        )
    except Exception as e:
        logger.error(f"Error listing tasks: {str(e)}", exc_info=True)
        return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)


@app.post("/ai/stream/stopTasks")
async def stop_all_detections(request):
    """Stop AI processing and forget every tracked task.

    Responses:
        200 "success", also when nothing was tracked.
        200 "warning" when stopping reported a failure (tasks are removed anyway).
        401 for a bad token.
        500 on unexpected errors, after marking all tasks "stopped".
    """
    try:
        verify_token(request)

        if not task_manager.active_tasks:
            return json_response({
                "status": "success",
                "message": "No active tasks to stop"
            })

        # Stop everything (stopAIVideo takes no task id)
        success = await safe_stop_ai_video()

        # Remove every task whatever the outcome
        for task_id in list(task_manager.active_tasks.keys()):
            task_manager.remove_task(task_id)

        if not success:
            return json_response({
                "status": "warning",
                "message": "Tasks stopped with warnings"
            })

        return json_response({
            "status": "success",
            "message": "All detections stopped successfully"
        })
    except SanicException as e:
        # An invalid token must yield 401 and must NOT fall through to the
        # branch below, which would mark every task as stopped.
        return json_response(
            {"status": "error", "message": str(e)},
            status=getattr(e, "status_code", None) or 500,
        )
    except Exception as e:
        logger.error(f"Error stopping all tasks: {str(e)}", exc_info=True)
        # Try to mark all tasks as stopped
        task_manager.mark_all_tasks_as_stopped()
        return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)


@app.get("/ai/health")
async def health_check(request):
    """Health probe for monitoring systems; deliberately requires no API token.

    Reports the following:
      - service state ("running" or "degraded");
      - the number of tracked tasks;
      - tasks that look stopped;
      - the last recorded error.
    Any failure while building the report yields HTTP 500 with service "degraded".
    """
    try:
        # No token check on purpose: monitoring systems must be able to call this.
        stalled = task_manager.check_tasks_health()
        state = "running" if service_status["is_healthy"] else "degraded"
        report = {
            "status": "success",
            "service": state,
            "active_tasks": len(task_manager.active_tasks),
            "unhealthy_tasks": stalled,
            "last_error": service_status["last_error"],
            "error_time": service_status["error_time"],
            "timestamp": datetime.now().isoformat(),
        }
        return json_response(report)
    except Exception as e:
        logger.error(f"健康检查失败: {str(e)}", exc_info=True)
        failure = {
            "status": "error",
            "service": "degraded",
            "error": str(e),
            "timestamp": datetime.now().isoformat(),
        }
        return json_response(failure, status=500)


@app.route("/ai/reset", methods=["POST"])
async def reset_service(request):
    """Reset the service.

    Steps: stop AI processing, forget all tasks, clear the recorded error state,
    then reap zombie child processes.

    Responses:
        200 "success" with the number of zombies reaped.
        200 "warning" if the zombie cleanup itself failed.
        401 for a bad token.
        500 on other errors.
    """
    try:
        verify_token(request)

        # Try to stop AI video processing (failures are handled inside)
        await safe_stop_ai_video()

        # Forget all tasks
        for task_id in list(task_manager.active_tasks.keys()):
            task_manager.remove_task(task_id)

        # Reset service status
        service_status["is_healthy"] = True
        service_status["last_error"] = None
        service_status["error_time"] = None

        # Reap zombie child processes
        try:
            import os

            import psutil

            current_process = psutil.Process(os.getpid())
            zombie_count = 0

            for child in current_process.children(recursive=True):
                try:
                    if child.status() != psutil.STATUS_ZOMBIE:
                        continue
                    # A zombie has already exited, so signals (even SIGKILL)
                    # cannot remove it. Only its parent can, by collecting the
                    # exit status. WNOHANG keeps this from blocking the event loop.
                    reaped_pid, _ = os.waitpid(child.pid, os.WNOHANG)
                    if reaped_pid:
                        zombie_count += 1
                except (psutil.Error, OSError):
                    # The process vanished, or it is a grandchild that only its
                    # own parent may reap; nothing more can be done here.
                    continue

            return json_response({
                "status": "success",
                "message": f"Service reset successful. Cleaned {zombie_count} zombie processes."
            })
        except Exception as e:
            logger.error(f"清理僵尸进程时出错: {e}")
            return json_response({
                "status": "warning",
                "message": "Service reset with warnings"
            })

    except SanicException as e:
        # Unauthorized (401) keeps its own status code instead of becoming a 500.
        return json_response(
            {"status": "error", "message": str(e)},
            status=getattr(e, "status_code", None) or 500,
        )
    except Exception as e:
        logger.error(f"重置服务时出错: {str(e)}", exc_info=True)
        return json_response({
            "status": "error",
            "message": f"Failed to reset service: {str(e)}"
        }, status=500)


@app.route("/ai/stream/restart/<task_id>", methods=["POST"])
async def restart_task(request, task_id: str):
    """Stop task `task_id` and start it again with the same parameters under a new id.

    Responses:
        200 with old_task_id/new_task_id.
        401 for a bad token.
        404 for an unknown task.
        500 if restarting fails (the old task has already been removed by then).
    """
    try:
        verify_token(request)

        # Fetch the task parameters
        try:
            task_info = task_manager.get_task_info(task_id)["task_info"]
        except NotFound:
            return json_response({"status": "error", "message": "Task not found"}, status=404)

        # Stop the task first
        success = await safe_stop_ai_video()
        task_manager.remove_task(task_id)

        if not success:
            logger.warning("停止任务出现问题,尝试继续重启")

        # Start it again.
        # NOTE(review): start_detection never calls startAIVideo for "yanwu.pt"
        # (it delegates to an external service), but this path always does —
        # confirm restart is intended for such tasks.
        new_task_id = str(uuid.uuid4())
        try:
            await asyncio.to_thread(
                startAIVideo,
                task_info["source_url"],
                task_info["push_url"],
                task_info["model_path"],
                task_info["detect_classes"],
                task_info["confidence"]
            )

            # Record the new task
            task_manager.add_task(new_task_id, task_info)

            return json_response({
                "status": "success",
                "old_task_id": task_id,
                "new_task_id": new_task_id,
                "message": "Task restarted successfully"
            })
        except Exception as e:
            logger.error(f"重启任务失败: {e}")
            return json_response({
                "status": "error",
                "message": f"Failed to restart task: {str(e)}"
            }, status=500)

    except SanicException as e:
        # Unauthorized (401) keeps its own status code instead of becoming a 500.
        return json_response(
            {"status": "error", "message": str(e)},
            status=getattr(e, "status_code", None) or 500,
        )
    except Exception as e:
        logger.error(f"重启任务时出错: {str(e)}", exc_info=True)
        return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)


if __name__ == "__main__":
    # Make sure nothing is left running from a previous run before serving
    try:
        stopAIVideo()
        print("服务启动前清理完成")
    except Exception as e:
        # Best effort. `except Exception` (not a bare except) so Ctrl+C and
        # SystemExit during cleanup still stop the process.
        print("服务启动前清理失败,但仍将继续")
        logger.warning(f"Pre-start cleanup failed: {e}")

    # psutil is used for process cleanup (imported lazily by the handlers).
    # NOTE(review): installing packages at runtime needs network access and pip
    # permissions; declaring psutil as a regular dependency would be safer.
    try:
        import psutil  # noqa: F401
    except ImportError:
        import subprocess
        import sys

        print("正在安装psutil库...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "psutil"])

    app.run(host="0.0.0.0", port=12315, debug=False, access_log=True)