788 lines
24 KiB
Python
788 lines
24 KiB
Python
# import time
|
||
#
|
||
# from sanic.response import json as json_response
|
||
# from sanic.exceptions import Unauthorized, NotFound, SanicException
|
||
#
|
||
# import logging
|
||
# import asyncio
|
||
#
|
||
# from sanic_cors import CORS
|
||
#
|
||
# from sanic import Sanic, Request
|
||
#
|
||
# from download_train import download_train
|
||
#
|
||
# # 配置日志
|
||
# logging.basicConfig(
|
||
# level=logging.INFO,
|
||
# format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||
# )
|
||
# logger = logging.getLogger(__name__)
|
||
#
|
||
# DB_CONFIG = {
|
||
# "dbname": "smart_dev_123",
|
||
# "user": "postgres",
|
||
# "password": "root",
|
||
# "host": "8.137.54.85",
|
||
# "port": "5060"
|
||
# }
|
||
#
|
||
#
|
||
# # 配置类
|
||
# class Config:
|
||
# VALID_TOKEN = "5e8899fe-dc74-4280-8169-2f4d185f3afa"
|
||
# MAX_ACTIVE_TASKS = 10
|
||
# DEFAULT_CONFIDENCE = 0.5
|
||
# RESTART_DELAY = 2 # 服务尝试自动恢复前的延迟(秒)
|
||
#
|
||
# app = Sanic("YoloStreamService2")
|
||
# CORS(app)
|
||
#
|
||
#
|
||
# def verify_token(request) -> None:
|
||
# """验证请求token"""
|
||
# token = request.headers.get('X-API-Token')
|
||
# if not token or token != Config.VALID_TOKEN:
|
||
# logger.warning("Invalid token attempt")
|
||
# raise Unauthorized("Invalid token")
|
||
#
|
||
#
|
||
# # 未针对具体方法做实现,待完成
|
||
# async def run_train(request_json):
|
||
# try:
|
||
# print("12121")
|
||
# except Exception as e:
|
||
# logger.error(f"启动AI视频处理失败: {e}")
|
||
# raise SanicException(f"Failed to start AI video processing: {str(e)}", status_code=500)
|
||
#
|
||
#
|
||
# # 接收前端实时流,进行任务训练
|
||
# @app.post("/ai/project/train")
|
||
# async def start_train(request):
|
||
# try:
|
||
# verify_token(request)
|
||
#
|
||
# # 解析并验证请求数据
|
||
# request_json=request.json
|
||
# task_id=request_json["task_id"]
|
||
# train_task_id=request_json["train_task_id"]
|
||
# time_ns=time.time_ns()
|
||
# pt_name=f"{time_ns}-{task_id}.pt"
|
||
#
|
||
# print(f"task_id {task_id}")
|
||
#
|
||
# # result = await download_train(task_id,train_task_id,pt_name) # 等待异步函数执行完成
|
||
# asyncio.create_task(download_train(task_id, train_task_id, pt_name))
|
||
#
|
||
# return json_response({
|
||
# "status": "success",
|
||
# "task_id": task_id,
|
||
# "pt_name": pt_name,
|
||
# "time_ns": time_ns,
|
||
# "process_id":123,
|
||
# "message": "task started successfully"
|
||
# })
|
||
#
|
||
# except ValueError as e:
|
||
# logger.error(f"Validation error: {str(e)}")
|
||
# return json_response({"status": "error", "message": str(e)}, status=400)
|
||
# except Exception as e:
|
||
# logger.error(f"Unexpected error: {str(e)}", exc_info=True)
|
||
# return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
||
#
|
||
#
|
||
# # 接收前端实时流,进行任务推理
|
||
# @app.post("/ai/project/inference")
|
||
# async def start_inference(request):
|
||
# try:
|
||
# verify_token(request)
|
||
# # 解析并验证请求数据
|
||
# request_json = request.json
|
||
# task_id = request_json["task_id"]
|
||
# time_ns = time.time_ns()
|
||
# pt_name = f"{time_ns}-{task_id}.pt"
|
||
#
|
||
# print(f"task_id {task_id}")
|
||
#
|
||
# return json_response({
|
||
# "status": "success",
|
||
# "task_id": task_id,
|
||
# "message": "Detection started successfully"
|
||
# })
|
||
# except ValueError as e:
|
||
# logger.error(f"Validation error: {str(e)}")
|
||
# return json_response({"status": "error", "message": str(e)}, status=400)
|
||
# except Exception as e:
|
||
# logger.error(f"Unexpected error: {str(e)}", exc_info=True)
|
||
# return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
||
#
|
||
#
|
||
#
|
||
# # 接收前端实时流,进行任务查询
|
||
# @app.post("/ai/project/query_train_task")
|
||
# async def query_train_task(request):
|
||
# try:
|
||
# verify_token(request)
|
||
#
|
||
# # 解析并验证请求数据
|
||
# request_json = request.json
|
||
# task_id = request_json["task_id"]
|
||
# time_ns = time.time_ns()
|
||
# pt_name = f"{time_ns}-{task_id}.pt"
|
||
#
|
||
# print(f"task_id {task_id}")
|
||
#
|
||
# return json_response({
|
||
# "task_id": task_id,
|
||
# "process_id":123,
|
||
# "status": "Running"
|
||
# })
|
||
#
|
||
# except ValueError as e:
|
||
# logger.error(f"Validation error: {str(e)}")
|
||
# return json_response({"status": "error", "message": str(e)}, status=400)
|
||
# except Exception as e:
|
||
# logger.error(f"Unexpected error: {str(e)}", exc_info=True)
|
||
# return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
||
#
|
||
#
|
||
#
|
||
# if __name__ == "__main__":
|
||
# # 保证服务启动前没有残留任务
|
||
# # 安装psutil库,用于进程管理
|
||
# try:
|
||
# import psutil
|
||
# except ImportError:
|
||
# import subprocess
|
||
# import sys
|
||
#
|
||
# print("正在安装psutil库...")
|
||
# subprocess.check_call([sys.executable, "-m", "pip", "install", "psutil"])
|
||
#
|
||
# # app.run(host="0.0.0.0", port=12316, debug=False, access_log=True)
|
||
# app.run(host="0.0.0.0", port=12325)
|
||
|
||
|
||
import asyncio
import json
import os
import signal
import subprocess
import sys
import time
import traceback
from pathlib import Path
from typing import Dict, Optional, List, Any

from sanic import Sanic, response
from sanic.exceptions import SanicException
from sanic.request import Request

from download_train import download_train
from predict.predict_yolo11seg import predict_images
from predict.predict_yolo11seg import predict_images_share_dir
from query_process_status import get_process_status
|
||
|
||
|
||
# 配置
|
||
class Config:
    """Static settings for the service (limits, file names, bind address)."""

    # Upper bound on the number of tasks in the RUNNING state.
    MAX_ACTIVE_TASKS = 10
    # Age in seconds after which a task record is dropped (24 hours).
    TASK_EXPIRY_SECONDS = 24 * 60 * 60

    # File names used by the service.
    LOG_FILE = "yolo_service.log"
    PID_FILE = "service.pid"

    # Network binding.
    HOST = "0.0.0.0"
    PORT = 12325

    DEBUG = False
    # The service is run with a single worker.
    WORKERS = 1
|
||
|
||
|
||
# 任务状态
|
||
class TaskStatus:
    """String constants naming the lifecycle states of a task."""

    PENDING, RUNNING = "pending", "running"
    COMPLETED, FAILED, CANCELLED = "completed", "failed", "cancelled"
|
||
|
||
|
||
# 任务信息
|
||
class TaskInfo:
    """Mutable record describing one training task."""

    def __init__(self, task_id: str, bz_training_task_id: int, pt_name: str):
        """Create a task in the PENDING state, stamped with the current time."""
        # Identity supplied by the caller.
        self.task_id = task_id
        self.bz_training_task_id = bz_training_task_id
        self.pt_name = pt_name

        # Lifecycle bookkeeping.
        self.status = TaskStatus.PENDING
        self.pid = None  # ID of the training process
        self.start_time = time.time()
        self.end_time = None

        # Outcome fields.
        self.error_message = None
        self.progress = 0.0  # training progress (percentage)
        self.result = None  # training result

    def to_dict(self) -> Dict:
        """Return the task's fields as a plain dict for JSON serialization."""
        field_names = (
            "task_id",
            "bz_training_task_id",
            "pt_name",
            "status",
            "pid",
            "start_time",
            "end_time",
            "error_message",
            "progress",
            "result",
        )
        return {name: getattr(self, name) for name in field_names}
|
||
|
||
|
||
# 应用初始化
|
||
# Sanic application object, configured from Config.
app = Sanic("YoloStreamService2")
app.config.update(dict(DEBUG=Config.DEBUG, WORKERS=Config.WORKERS))

# Module-level state: registry of known tasks and a counter used when
# generating task IDs.
active_tasks: Dict[str, TaskInfo] = {}
task_counter = 0
|
||
|
||
|
||
# 保存PID文件
|
||
def save_pid_file():
    """Write this process's PID to Config.PID_FILE.

    Any failure is reported on stdout and otherwise ignored.
    """
    try:
        with open(Config.PID_FILE, 'w') as pid_file:
            pid_file.write(str(os.getpid()))
        print(f"Service PID {os.getpid()} saved to {Config.PID_FILE}")
    except Exception as e:
        print(f"Failed to save PID file: {e}")
|
||
|
||
|
||
# 删除PID文件
|
||
def remove_pid_file():
    """Delete Config.PID_FILE if it exists.

    Any failure is reported on stdout and otherwise ignored.
    """
    try:
        if not os.path.exists(Config.PID_FILE):
            return
        os.remove(Config.PID_FILE)
        print(f"PID file {Config.PID_FILE} removed")
    except Exception as e:
        print(f"Failed to remove PID file: {e}")
|
||
|
||
|
||
# 生成唯一任务ID
|
||
def generate_task_id() -> str:
    """Return a new task ID built from the current time and a counter.

    Increments the module-level ``task_counter`` on every call.
    """
    global task_counter
    task_counter = task_counter + 1
    millis = int(time.time() * 1000)
    return f"yolo_task_{millis}_{task_counter}"
|
||
|
||
|
||
# 检查任务是否存在
|
||
def task_exists(task_id: str) -> bool:
    """Tell whether ``task_id`` is a key of the task registry."""
    is_registered = task_id in active_tasks
    return is_registered
|
||
|
||
|
||
# 获取任务信息
|
||
def get_task_info(task_id: str) -> Optional[TaskInfo]:
    """Look up ``task_id`` in the task registry; ``None`` when absent."""
    return active_tasks.get(task_id, None)
|
||
|
||
|
||
# 更新任务状态
|
||
def update_task_status(task_id: str, status: str, **kwargs):
    """Set a task's status and, optionally, other existing attributes.

    Does nothing when ``task_id`` is not registered. Keyword arguments that
    do not name an existing attribute of the task are ignored.
    """
    try:
        record = active_tasks[task_id]
    except KeyError:
        return

    record.status = status

    # Apply the extra fields, skipping names the task does not have.
    for attr_name in kwargs:
        if hasattr(record, attr_name):
            setattr(record, attr_name, kwargs[attr_name])

    print(f"Task {task_id} status updated to {status}")
|
||
|
||
|
||
# 清理过期任务
|
||
async def cleanup_expired_tasks():
    """Drop every task whose age exceeds Config.TASK_EXPIRY_SECONDS."""
    now = time.time()

    # Collect the IDs first so the registry is not mutated while iterating.
    stale_ids = [
        tid
        for tid, record in active_tasks.items()
        if (now - record.start_time) > Config.TASK_EXPIRY_SECONDS
    ]

    for tid in stale_ids:
        print(f"Removing expired task: {tid}")
        del active_tasks[tid]
|
||
|
||
|
||
# 定期清理任务
|
||
async def scheduled_cleanup():
    """Run cleanup_expired_tasks forever, pausing one hour between runs."""
    pause_seconds = 3600  # one hour between sweeps
    while True:
        await cleanup_expired_tasks()
        await asyncio.sleep(pause_seconds)
|
||
|
||
|
||
# 检查训练进程状态
|
||
def _process_is_alive(pid: int) -> bool:
    """Return True when a process with the given PID can still be found."""
    if sys.platform == 'win32':
        # Windows: try to open a handle to the process.
        import ctypes

        kernel32 = ctypes.WinDLL('kernel32')
        # Access mask 1 is kept from the original implementation.
        handle = kernel32.OpenProcess(1, 0, pid)
        if handle == 0:
            return False
        kernel32.CloseHandle(handle)
        return True

    # Unix-like: signal 0 performs the existence check without sending anything.
    try:
        os.kill(pid, 0)
    except PermissionError:
        # The process exists but belongs to another user.
        return True
    except OSError:
        return False
    return True


async def check_training_processes():
    """Poll RUNNING tasks every 30 seconds and mark vanished ones COMPLETED.

    A task is considered only when its status is RUNNING and it has a PID.
    When the process can no longer be found the task's status becomes
    COMPLETED and its end_time is set.

    The original handler was ``except (OSError, ctypes.WinError)``. Outside
    Windows ``ctypes`` was never imported in this function, and
    ``ctypes.WinError`` is a function rather than an exception class, so the
    first vanished process crashed the monitor. The liveness probe now lives
    in ``_process_is_alive`` and catches only ``OSError``.
    """
    while True:
        current_time = time.time()

        # Iterate over a snapshot: the registry may change while we await.
        for task_id, task in list(active_tasks.items()):
            if task.status != TaskStatus.RUNNING or not task.pid:
                continue
            if _process_is_alive(task.pid):
                continue

            # NOTE(review): a vanished process is always recorded as COMPLETED;
            # the exit code is not inspected here.
            task.status = TaskStatus.COMPLETED
            task.end_time = current_time
            print(f"Training process {task.pid} for task {task_id} has completed")

        await asyncio.sleep(30)  # check every 30 seconds
|
||
|
||
|
||
# 训练接口
|
||
# Strong references to fire-and-forget background jobs. The event loop keeps
# only weak references to tasks, so an unreferenced task may be garbage
# collected before it finishes.
_background_jobs = set()


@app.route("/ai/project/train", methods=["POST"])
async def train_project(request: Request):
    """
    Start a training job in the background.

    Request body (JSON):
    {
        "train_task_id": 123,
        "task_id": "..."
    }

    Responses:
        200 - job scheduled; body carries task_id, pt_name and time_ns.
        400 - body missing/empty, or train_task_id / task_id not supplied.
        429 - Config.MAX_ACTIVE_TASKS tasks are already RUNNING.
        500 - unexpected failure.
    """
    try:
        # Parse the request body.
        data = request.json
        if not data:
            return response.json(
                {"code": 400, "message": "Invalid request body", "data": None},
                status=400
            )

        time_ns = time.time_ns()
        bz_training_task_id = data.get("train_task_id")
        task_id = data.get("task_id")

        if not bz_training_task_id:
            return response.json(
                {"code": 400, "message": "train_task_id is required", "data": None},
                status=400
            )

        # task_id is part of the weight file name and is the registry key, so
        # a missing value must be rejected rather than becoming "None".
        if task_id is None or task_id == "":
            return response.json(
                {"code": 400, "message": "task_id is required", "data": None},
                status=400
            )

        pt_name = f"{time_ns}-{task_id}.pt"

        # Enforce the limit on concurrently running tasks.
        # NOTE(review): tasks are created PENDING and nothing in this function
        # moves them to RUNNING - confirm where that transition happens.
        running_tasks = sum(1 for task in active_tasks.values()
                            if task.status == TaskStatus.RUNNING)

        if running_tasks >= Config.MAX_ACTIVE_TASKS:
            return response.json(
                {"code": 429, "message": f"Too many active tasks. Max allowed: {Config.MAX_ACTIVE_TASKS}",
                 "data": None},
                status=429
            )

        # Generate the internal task ID and create the record.
        task_id_new = generate_task_id()
        task_info = TaskInfo(task_id_new, bz_training_task_id, pt_name)
        # NOTE(review): the record is keyed by the caller's task_id while it
        # stores the generated ID internally; kept as in the original.
        active_tasks[task_id] = task_info

        print(f"Created new training task: {task_id} for bz_training_task_id: {bz_training_task_id}")

        try:
            # Schedule download + training preparation without awaiting it.
            job = asyncio.create_task(download_train(task_id, bz_training_task_id, pt_name))
            _background_jobs.add(job)
            job.add_done_callback(_background_jobs.discard)

            return response.json({
                "status": "success",
                "code": 200,
                "task_id": task_id,
                "pt_name": pt_name,
                "time_ns": time_ns,
                "message": "task started successfully"
            })
        except Exception as e:
            # Scheduling failed: forget the record we just created.
            active_tasks.pop(task_id, None)
            print(f"Training preparation failed for task {task_id}: {e}")
            traceback.print_exc()
            return response.json(
                {"code": 500, "message": f"Training preparation failed: {str(e)}", "data": None},
                status=500
            )

    except Exception as e:
        print(f"Error in train_project endpoint: {e}")
        traceback.print_exc()
        return response.json(
            {"code": 500, "message": f"Internal server error: {str(e)}", "data": None},
            status=500
        )
|
||
|
||
|
||
|
||
|
||
|
||
# 训练接口
|
||
@app.route("/ai/project/query_train_task", methods=["POST"])
async def query_train_task(request: Request):
    """
    Report the status of a training process.

    Request body (JSON): ``task_id`` and ``process_id``. ``process_id`` is
    handed to ``get_process_status`` and its return value is sent back as
    ``process_info``.

    Responses:
        200 - status obtained.
        400 - body missing or empty.
        500 - the status lookup or the handler itself failed.
    """
    try:
        # Parse the request body.
        data = request.json
        if not data:
            return response.json(
                {"code": 400, "message": "Invalid request body", "data": None},
                status=400
            )

        time_ns = time.time_ns()
        task_id = data.get("task_id")
        process_id = data.get("process_id")

        try:
            process_info = get_process_status(process_id)
            return response.json({
                "status": "success",
                "code": 200,
                "task_id": task_id,
                "time_ns": time_ns,
                "process_info": process_info
            })
        except Exception as e:
            # A failed read-only query must not remove the task from the
            # registry, so the original `del active_tasks[task_id]` is gone.
            print(f"Failed to query process status for task {task_id}: {e}")
            traceback.print_exc()
            return response.json(
                {"code": 500, "message": f"Failed to query process status: {str(e)}", "data": None},
                status=500
            )

    except Exception as e:
        print(f"Error in query_train_task endpoint: {e}")
        traceback.print_exc()
        return response.json(
            {"code": 500, "message": f"Internal server error: {str(e)}", "data": None},
            status=500
        )
|
||
|
||
|
||
|
||
# 接收前端实时流,进行任务推理-共享目录
|
||
@app.post("/ai/project/inference4ShareDir")
async def start_inference_share_dir(request):
    """
    Run inference through ``predict_images_share_dir``.

    Request body (JSON) must contain ``task_id``, ``pt_name``, ``zip_url``,
    ``user_name`` and ``pwd``. Empty ``user_name`` / ``pwd`` fall back to
    built-in defaults. Output goes to ``predictions/<task_id>``.

    Responses:
        200 - "success" with inference_zip_url, or "fail" with the message
              returned by the predictor.
        400 - body is not a JSON object or a required field is missing.
        500 - unexpected failure.
    """
    try:
        # Parse and validate the request body. Previously a missing field
        # surfaced as KeyError/TypeError and was reported as a 500.
        request_json = request.json
        if not isinstance(request_json, dict):
            return response.json({"status": "error", "message": "Invalid request body"}, status=400)

        required = ("task_id", "pt_name", "zip_url", "user_name", "pwd")
        missing = [key for key in required if key not in request_json]
        if missing:
            return response.json(
                {"status": "error", "message": f"Missing required field(s): {', '.join(missing)}"},
                status=400
            )

        task_id = request_json["task_id"]
        pt_name = request_json["pt_name"]
        zip_url = request_json["zip_url"]
        user_name = request_json["user_name"]
        pwd = request_json["pwd"]
        print(f"task_id {task_id}")

        # NOTE(review): hard-coded fallback credentials live in source code;
        # consider moving them to configuration or the environment.
        if user_name == "":
            user_name = "administrator"
        if pwd == "":
            pwd = "abc@1234"

        # NOTE(review): task_id comes from the request and is used in a path.
        output_dir = f"predictions/{task_id}"
        # NOTE(review): this call is synchronous inside an async handler.
        inference_zip_url, message = predict_images_share_dir(
            pt_name, zip_url, user_name, pwd,
            output_dir=output_dir, conf_threshold=0.25, save_json=False
        )

        if inference_zip_url:
            return response.json({
                "status": "success",
                "task_id": task_id,
                "inference_zip_url": inference_zip_url,
                "message": "predict request successfully"
            })

        return response.json({
            "status": "fail",
            "task_id": task_id,
            "inference_zip_url": inference_zip_url,
            "message": message
        })
    except ValueError as e:
        print(f"Validation error: {str(e)}")
        return response.json({"status": "error", "message": str(e)}, status=400)
    except Exception as e:
        print(f"Unexpected error: {str(e)}")
        traceback.print_exc()
        return response.json({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
||
|
||
|
||
|
||
# 接收前端实时流,进行任务推理
|
||
@app.post("/ai/project/inference")
async def start_inference(request):
    """
    Run inference through ``predict_images``.

    Request body (JSON) must contain ``task_id``, ``pt_name`` and
    ``zip_url``. Output goes to the ``predictions`` directory.

    Responses:
        200 - "success" with inference_zip_url, or "fail" with the message
              returned by the predictor.
        400 - body is not a JSON object or a required field is missing.
        500 - unexpected failure.
    """
    try:
        # Parse and validate the request body. Previously a missing field
        # surfaced as KeyError/TypeError and was reported as a 500.
        request_json = request.json
        if not isinstance(request_json, dict):
            return response.json({"status": "error", "message": "Invalid request body"}, status=400)

        required = ("task_id", "pt_name", "zip_url")
        missing = [key for key in required if key not in request_json]
        if missing:
            return response.json(
                {"status": "error", "message": f"Missing required field(s): {', '.join(missing)}"},
                status=400
            )

        task_id = request_json["task_id"]
        pt_name = request_json["pt_name"]
        zip_url = request_json["zip_url"]
        print(f"task_id {task_id}")

        # NOTE(review): this call is synchronous inside an async handler.
        inference_zip_url, message = predict_images(
            pt_name, zip_url,
            output_dir="predictions", conf_threshold=0.25, save_json=False
        )

        if inference_zip_url:
            return response.json({
                "status": "success",
                "task_id": task_id,
                "inference_zip_url": inference_zip_url,
                "message": "predict successfully"
            })

        return response.json({
            "status": "fail",
            "task_id": task_id,
            "inference_zip_url": inference_zip_url,
            "message": message
        })
    except ValueError as e:
        print(f"Validation error: {str(e)}")
        return response.json({"status": "error", "message": str(e)}, status=400)
    except Exception as e:
        print(f"Unexpected error: {str(e)}")
        traceback.print_exc()
        return response.json({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
||
|
||
|
||
|
||
|
||
# 获取任务状态接口
|
||
@app.route("/ai/project/task/status/<task_id>", methods=["GET"])
async def get_task_status(request: Request, task_id: str):
    """Return the registry record for ``task_id``.

    Responses:
        200 - record found; ``data`` holds ``TaskInfo.to_dict()``.
        404 - no such task.
        500 - unexpected failure.
    """
    try:
        # A single lookup replaces the exists-then-get pair.
        task_info = get_task_info(task_id)
        if task_info is None:
            return response.json(
                {"code": 404, "message": f"Task {task_id} not found", "data": None},
                status=404
            )

        return response.json({
            "code": 200,
            "message": "success",
            "data": task_info.to_dict()
        })

    except Exception as e:
        # print() has no exc_info keyword; emit the traceback explicitly.
        print(f"Error in get_task_status endpoint: {e}")
        traceback.print_exc()
        return response.json(
            {"code": 500, "message": f"Internal server error: {str(e)}", "data": None},
            status=500
        )
|
||
|
||
|
||
# 获取所有任务状态接口
|
||
@app.route("/ai/project/tasks", methods=["GET"])
async def get_all_tasks(request: Request):
    """Return every task in the registry.

    Responses:
        200 - ``data`` holds ``total_tasks`` and the list ``tasks``.
        500 - unexpected failure.
    """
    try:
        tasks_data = [task.to_dict() for task in active_tasks.values()]
        return response.json({
            "code": 200,
            "message": "success",
            "data": {
                "total_tasks": len(tasks_data),
                "tasks": tasks_data
            }
        })

    except Exception as e:
        # print() has no exc_info keyword; emit the traceback explicitly.
        print(f"Error in get_all_tasks endpoint: {e}")
        traceback.print_exc()
        return response.json(
            {"code": 500, "message": f"Internal server error: {str(e)}", "data": None},
            status=500
        )
|
||
|
||
|
||
# 取消任务接口
|
||
@app.route("/ai/project/task/cancel/<task_id>", methods=["POST"])
async def cancel_task(request: Request, task_id: str):
    """Cancel a task and try to terminate its training process.

    Responses:
        200 - task marked CANCELLED; ``data`` holds the updated record.
        400 - task is already COMPLETED, FAILED or CANCELLED.
        404 - no such task.
        500 - unexpected failure.
    """
    try:
        if not task_exists(task_id):
            return response.json(
                {"code": 404, "message": f"Task {task_id} not found", "data": None},
                status=404
            )

        task_info = get_task_info(task_id)

        if task_info.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]:
            return response.json(
                {"code": 400, "message": f"Task {task_id} is already {task_info.status}", "data": None},
                status=400
            )

        # Best effort: a failure to kill the process does not abort the cancel.
        if task_info.pid:
            try:
                if sys.platform == 'win32':
                    # Windows: force-kill the process tree.
                    subprocess.run(['taskkill', '/F', '/T', '/PID', str(task_info.pid)],
                                   check=True, capture_output=True, text=True)
                else:
                    # Unix-like platforms.
                    os.kill(task_info.pid, signal.SIGTERM)

                print(f"Sent termination signal to training process {task_info.pid} for task {task_id}")

            except Exception as e:
                print(f"Failed to terminate training process {task_info.pid}: {e}")

        # Record the cancellation.
        update_task_status(
            task_id,
            TaskStatus.CANCELLED,
            end_time=time.time(),
            error_message="Task cancelled by user"
        )

        return response.json({
            "code": 200,
            "message": f"Task {task_id} cancelled successfully",
            "data": task_info.to_dict()
        })

    except Exception as e:
        # print() has no exc_info keyword; emit the traceback explicitly.
        print(f"Error in cancel_task endpoint: {e}")
        traceback.print_exc()
        return response.json(
            {"code": 500, "message": f"Internal server error: {str(e)}", "data": None},
            status=500
        )
|
||
|
||
|
||
# 健康检查接口
|
||
@app.route("/health", methods=["GET"])
async def health_check(request: Request):
    """Report service liveness together with task counts.

    Responses:
        200 - ``data`` holds status, active_tasks, running_tasks,
              max_active_tasks and the current timestamp.
        500 - the check itself failed.
    """
    try:
        running_tasks = sum(1 for task in active_tasks.values()
                            if task.status == TaskStatus.RUNNING)

        return response.json({
            "code": 200,
            "message": "Service is running",
            "data": {
                "status": "healthy",
                "active_tasks": len(active_tasks),
                "running_tasks": running_tasks,
                "max_active_tasks": Config.MAX_ACTIVE_TASKS,
                "timestamp": time.time()
            }
        })

    except Exception as e:
        # print() has no exc_info keyword; emit the traceback explicitly.
        print(f"Health check failed: {e}")
        traceback.print_exc()
        return response.json(
            {"code": 500, "message": "Service is unhealthy", "data": None},
            status=500
        )
|
||
|
||
|
||
# 应用启动事件
|
||
@app.listener('before_server_start')
async def before_server_start(app, loop):
    """Startup hook: write the PID file and launch the background loops.

    Schedules ``scheduled_cleanup`` and ``check_training_processes`` on the
    given loop. Any exception is reported and re-raised.
    """
    try:
        # Save the PID file.
        save_pid_file()

        # Start the periodic cleanup loop.
        loop.create_task(scheduled_cleanup())

        # Start the process monitor loop.
        loop.create_task(check_training_processes())

        print("Yolo Stream Service started successfully")

    except Exception as e:
        # print() has no exc_info keyword; the old call raised TypeError here
        # and hid the real startup error.
        print(f"Error during server startup: {e}")
        traceback.print_exc()
        raise
|
||
|
||
|
||
# 应用停止事件
|
||
@app.listener('after_server_stop')
async def after_server_stop(app, loop):
    """Shutdown hook: remove the PID file and terminate running trainings.

    Every task that is RUNNING and has a PID gets a termination request.
    Failures are reported and do not stop the remaining cleanup.
    """
    try:
        # Remove the PID file.
        remove_pid_file()

        # Terminate all training processes; iterate over a snapshot so the
        # registry can change underneath without breaking the loop.
        for task_id, task in list(active_tasks.items()):
            if task.status != TaskStatus.RUNNING or not task.pid:
                continue
            try:
                if sys.platform == 'win32':
                    subprocess.run(['taskkill', '/F', '/T', '/PID', str(task.pid)],
                                   capture_output=True, text=True)
                else:
                    os.kill(task.pid, signal.SIGTERM)
                print(f"Terminated training process {task.pid} for task {task_id}")
            except Exception as e:
                print(f"Failed to terminate process {task.pid}: {e}")

        print("Yolo Stream Service stopped gracefully")

    except Exception as e:
        # print() has no exc_info keyword; emit the traceback explicitly.
        print(f"Error during server shutdown: {e}")
        traceback.print_exc()
|
||
|
||
|
||
# 信号处理
|
||
def handle_signal(signal_num, frame):
    """Signal handler: announce the signal and ask the app to stop.

    ``app.stop()`` is a regular (non-coroutine) method, so it is called
    directly. The original wrapped its return value in
    ``asyncio.create_task``, which raises because that value is not a
    coroutine.
    """
    print(f"Received signal {signal_num}. Shutting down...")

    # Request a graceful shutdown of the application.
    try:
        app.stop()
    except Exception as e:
        # Never let an exception escape from a signal handler.
        print(f"Failed to stop application: {e}")
|
||
|
||
|
||
# 注册信号处理
|
||
# Install handle_signal for SIGINT and SIGTERM everywhere except Windows,
# which does not support some signals.
if not sys.platform == 'win32':
    for _signum in (signal.SIGINT, signal.SIGTERM):
        signal.signal(_signum, handle_signal)
|
||
|
||
# 主函数
|
||
if __name__ == "__main__":
    # Application entry point.
    try:
        # Run the application.
        app.run(
            host=Config.HOST,
            port=Config.PORT,
            workers=Config.WORKERS,
            debug=Config.DEBUG,
            access_log=True
        )

    except Exception as e:
        # print() has no exc_info keyword; the old call raised TypeError, so
        # the PID file cleanup and the exit below never ran.
        print(f"Failed to start Yolo Stream Service: {e}")
        traceback.print_exc()
        remove_pid_file()
        sys.exit(1)
|