741 lines
23 KiB
Python
741 lines
23 KiB
Python
|
|
# import time
|
|||
|
|
#
|
|||
|
|
# from sanic.response import json as json_response
|
|||
|
|
# from sanic.exceptions import Unauthorized, NotFound, SanicException
|
|||
|
|
#
|
|||
|
|
# import logging
|
|||
|
|
# import asyncio
|
|||
|
|
#
|
|||
|
|
# from sanic_cors import CORS
|
|||
|
|
#
|
|||
|
|
# from sanic import Sanic, Request
|
|||
|
|
#
|
|||
|
|
# from download_train import download_train
|
|||
|
|
#
|
|||
|
|
# # 配置日志
|
|||
|
|
# logging.basicConfig(
|
|||
|
|
# level=logging.INFO,
|
|||
|
|
# format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
|||
|
|
# )
|
|||
|
|
# logger = logging.getLogger(__name__)
|
|||
|
|
#
|
|||
|
|
# DB_CONFIG = {
|
|||
|
|
# "dbname": "smart_dev_123",
|
|||
|
|
# "user": "postgres",
|
|||
|
|
# "password": "root",
|
|||
|
|
# "host": "8.137.54.85",
|
|||
|
|
# "port": "5060"
|
|||
|
|
# }
|
|||
|
|
#
|
|||
|
|
#
|
|||
|
|
# # 配置类
|
|||
|
|
# class Config:
|
|||
|
|
# VALID_TOKEN = "5e8899fe-dc74-4280-8169-2f4d185f3afa"
|
|||
|
|
# MAX_ACTIVE_TASKS = 10
|
|||
|
|
# DEFAULT_CONFIDENCE = 0.5
|
|||
|
|
# RESTART_DELAY = 2 # 服务尝试自动恢复前的延迟(秒)
|
|||
|
|
#
|
|||
|
|
# app = Sanic("YoloStreamService2")
|
|||
|
|
# CORS(app)
|
|||
|
|
#
|
|||
|
|
#
|
|||
|
|
# def verify_token(request) -> None:
|
|||
|
|
# """验证请求token"""
|
|||
|
|
# token = request.headers.get('X-API-Token')
|
|||
|
|
# if not token or token != Config.VALID_TOKEN:
|
|||
|
|
# logger.warning("Invalid token attempt")
|
|||
|
|
# raise Unauthorized("Invalid token")
|
|||
|
|
#
|
|||
|
|
#
|
|||
|
|
# # 未针对具体方法做实现,待完成
|
|||
|
|
# async def run_train(request_json):
|
|||
|
|
# try:
|
|||
|
|
# print("12121")
|
|||
|
|
# except Exception as e:
|
|||
|
|
# logger.error(f"启动AI视频处理失败: {e}")
|
|||
|
|
# raise SanicException(f"Failed to start AI video processing: {str(e)}", status_code=500)
|
|||
|
|
#
|
|||
|
|
#
|
|||
|
|
# # 接收前端实时流,进行任务训练
|
|||
|
|
# @app.post("/ai/project/train")
|
|||
|
|
# async def start_train(request):
|
|||
|
|
# try:
|
|||
|
|
# verify_token(request)
|
|||
|
|
#
|
|||
|
|
# # 解析并验证请求数据
|
|||
|
|
# request_json=request.json
|
|||
|
|
# task_id=request_json["task_id"]
|
|||
|
|
# train_task_id=request_json["train_task_id"]
|
|||
|
|
# time_ns=time.time_ns()
|
|||
|
|
# pt_name=f"{time_ns}-{task_id}.pt"
|
|||
|
|
#
|
|||
|
|
# print(f"task_id {task_id}")
|
|||
|
|
#
|
|||
|
|
# # result = await download_train(task_id,train_task_id,pt_name) # 等待异步函数执行完成
|
|||
|
|
# asyncio.create_task(download_train(task_id, train_task_id, pt_name))
|
|||
|
|
#
|
|||
|
|
# return json_response({
|
|||
|
|
# "status": "success",
|
|||
|
|
# "task_id": task_id,
|
|||
|
|
# "pt_name": pt_name,
|
|||
|
|
# "time_ns": time_ns,
|
|||
|
|
# "process_id":123,
|
|||
|
|
# "message": "task started successfully"
|
|||
|
|
# })
|
|||
|
|
#
|
|||
|
|
# except ValueError as e:
|
|||
|
|
# logger.error(f"Validation error: {str(e)}")
|
|||
|
|
# return json_response({"status": "error", "message": str(e)}, status=400)
|
|||
|
|
# except Exception as e:
|
|||
|
|
# logger.error(f"Unexpected error: {str(e)}", exc_info=True)
|
|||
|
|
# return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
|||
|
|
#
|
|||
|
|
#
|
|||
|
|
# # 接收前端实时流,进行任务推理
|
|||
|
|
# @app.post("/ai/project/inference")
|
|||
|
|
# async def start_inference(request):
|
|||
|
|
# try:
|
|||
|
|
# verify_token(request)
|
|||
|
|
# # 解析并验证请求数据
|
|||
|
|
# request_json = request.json
|
|||
|
|
# task_id = request_json["task_id"]
|
|||
|
|
# time_ns = time.time_ns()
|
|||
|
|
# pt_name = f"{time_ns}-{task_id}.pt"
|
|||
|
|
#
|
|||
|
|
# print(f"task_id {task_id}")
|
|||
|
|
#
|
|||
|
|
# return json_response({
|
|||
|
|
# "status": "success",
|
|||
|
|
# "task_id": task_id,
|
|||
|
|
# "message": "Detection started successfully"
|
|||
|
|
# })
|
|||
|
|
# except ValueError as e:
|
|||
|
|
# logger.error(f"Validation error: {str(e)}")
|
|||
|
|
# return json_response({"status": "error", "message": str(e)}, status=400)
|
|||
|
|
# except Exception as e:
|
|||
|
|
# logger.error(f"Unexpected error: {str(e)}", exc_info=True)
|
|||
|
|
# return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
|||
|
|
#
|
|||
|
|
#
|
|||
|
|
#
|
|||
|
|
# # 接收前端实时流,进行任务查询
|
|||
|
|
# @app.post("/ai/project/query_train_task")
|
|||
|
|
# async def query_train_task(request):
|
|||
|
|
# try:
|
|||
|
|
# verify_token(request)
|
|||
|
|
#
|
|||
|
|
# # 解析并验证请求数据
|
|||
|
|
# request_json = request.json
|
|||
|
|
# task_id = request_json["task_id"]
|
|||
|
|
# time_ns = time.time_ns()
|
|||
|
|
# pt_name = f"{time_ns}-{task_id}.pt"
|
|||
|
|
#
|
|||
|
|
# print(f"task_id {task_id}")
|
|||
|
|
#
|
|||
|
|
# return json_response({
|
|||
|
|
# "task_id": task_id,
|
|||
|
|
# "process_id":123,
|
|||
|
|
# "status": "Running"
|
|||
|
|
# })
|
|||
|
|
#
|
|||
|
|
# except ValueError as e:
|
|||
|
|
# logger.error(f"Validation error: {str(e)}")
|
|||
|
|
# return json_response({"status": "error", "message": str(e)}, status=400)
|
|||
|
|
# except Exception as e:
|
|||
|
|
# logger.error(f"Unexpected error: {str(e)}", exc_info=True)
|
|||
|
|
# return json_response({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
|||
|
|
#
|
|||
|
|
#
|
|||
|
|
#
|
|||
|
|
# if __name__ == "__main__":
|
|||
|
|
# # 保证服务启动前没有残留任务
|
|||
|
|
# # 安装psutil库,用于进程管理
|
|||
|
|
# try:
|
|||
|
|
# import psutil
|
|||
|
|
# except ImportError:
|
|||
|
|
# import subprocess
|
|||
|
|
# import sys
|
|||
|
|
#
|
|||
|
|
# print("正在安装psutil库...")
|
|||
|
|
# subprocess.check_call([sys.executable, "-m", "pip", "install", "psutil"])
|
|||
|
|
#
|
|||
|
|
# # app.run(host="0.0.0.0", port=12316, debug=False, access_log=True)
|
|||
|
|
# app.run(host="0.0.0.0", port=12325)
|
|||
|
|
|
|||
|
|
|
|||
|
|
import asyncio
import json
import os
import signal
import subprocess
import sys
import time
import traceback
from pathlib import Path
from typing import Any, Dict, List, Optional

from sanic import Sanic, response
from sanic.exceptions import SanicException
from sanic.request import Request

from download_train import download_train
from predict.predict_yolo11seg import predict_images
from query_process_status import get_process_status
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 配置
|
|||
|
|
class Config:
    """Static service settings, read as plain class attributes."""

    # Network binding and server mode.
    HOST = "0.0.0.0"
    PORT = 12325
    DEBUG = False
    WORKERS = 1  # run in single-worker mode

    # Task bookkeeping limits.
    MAX_ACTIVE_TASKS = 10  # upper bound on concurrently running tasks
    TASK_EXPIRY_SECONDS = 86400  # task expiry time (24 hours)

    # File names for the log and PID files.
    LOG_FILE = "yolo_service.log"
    PID_FILE = "service.pid"
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 任务状态
|
|||
|
|
class TaskStatus:
    """String constants naming the lifecycle states a task can be in."""

    PENDING, RUNNING = "pending", "running"
    COMPLETED, FAILED, CANCELLED = "completed", "failed", "cancelled"
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 任务信息
|
|||
|
|
class TaskInfo:
    """Mutable record holding the bookkeeping data of one training task."""

    def __init__(self, task_id: str, bz_training_task_id: int, pt_name: str):
        """Create a record in the PENDING state, stamped with the current time."""
        # Identity values supplied by the caller.
        self.task_id = task_id
        self.bz_training_task_id = bz_training_task_id
        self.pt_name = pt_name

        # Lifecycle bookkeeping; everything below starts out "not yet known".
        self.status = TaskStatus.PENDING
        self.pid = None  # id of the training process
        self.start_time = time.time()
        self.end_time = None
        self.error_message = None
        self.progress = 0.0  # training progress (percentage)
        self.result = None  # training result

    def to_dict(self) -> Dict:
        """Return the record as a plain dict for JSON serialization."""
        exported = (
            "task_id",
            "bz_training_task_id",
            "pt_name",
            "status",
            "pid",
            "start_time",
            "end_time",
            "error_message",
            "progress",
            "result",
        )
        return {name: getattr(self, name) for name in exported}
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 应用初始化
|
|||
|
|
# Create the Sanic application and copy the static settings into its config.
app = Sanic("YoloStreamService2")
app.config.update(DEBUG=Config.DEBUG, WORKERS=Config.WORKERS)

# Process-wide mutable state shared by the handlers below.
task_counter = 0  # incremented by generate_task_id()
active_tasks: Dict[str, TaskInfo] = {}  # task registry
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 保存PID文件
|
|||
|
|
def save_pid_file():
    """Write this process's id to ``Config.PID_FILE``.

    Best effort: any exception is caught and reported with ``print``.
    """
    try:
        pid_text = str(os.getpid())
        with open(Config.PID_FILE, 'w') as pid_handle:
            pid_handle.write(pid_text)
        print(f"Service PID {os.getpid()} saved to {Config.PID_FILE}")
    except Exception as e:
        print(f"Failed to save PID file: {e}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 删除PID文件
|
|||
|
|
def remove_pid_file():
    """Delete ``Config.PID_FILE`` if it exists.

    Best effort: any exception is caught and reported with ``print``.
    """
    try:
        if not os.path.exists(Config.PID_FILE):
            return  # nothing to clean up
        os.remove(Config.PID_FILE)
        print(f"PID file {Config.PID_FILE} removed")
    except Exception as e:
        print(f"Failed to remove PID file: {e}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 生成唯一任务ID
|
|||
|
|
def generate_task_id() -> str:
    """Return a new id of the form ``yolo_task_<epoch-millis>_<counter>``.

    Increments the module-level ``task_counter`` on every call.
    """
    global task_counter
    task_counter = task_counter + 1
    epoch_millis = int(time.time() * 1000)
    return f"yolo_task_{epoch_millis}_{task_counter}"
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 检查任务是否存在
|
|||
|
|
def task_exists(task_id: str) -> bool:
    """Report whether ``task_id`` is a key of the ``active_tasks`` registry."""
    is_registered = task_id in active_tasks
    return is_registered
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 获取任务信息
|
|||
|
|
def get_task_info(task_id: str) -> Optional[TaskInfo]:
    """Return the registry entry for ``task_id``, or ``None`` when absent."""
    try:
        return active_tasks[task_id]
    except KeyError:
        return None
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 更新任务状态
|
|||
|
|
def update_task_status(task_id: str, status: str, **kwargs):
    """Set the status of a registered task and optionally other attributes.

    Args:
        task_id: Key of the task in ``active_tasks``; unknown ids are ignored.
        status: New status value stored on the task.
        **kwargs: Extra attribute values. A value is applied only when the
            task object already has an attribute of that name.
    """
    try:
        record = active_tasks[task_id]
    except KeyError:
        return  # unknown task: silently do nothing

    record.status = status

    # Apply the optional extras, skipping names the record does not carry.
    for attr_name in kwargs:
        if not hasattr(record, attr_name):
            continue
        setattr(record, attr_name, kwargs[attr_name])

    print(f"Task {task_id} status updated to {status}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 清理过期任务
|
|||
|
|
async def cleanup_expired_tasks():
    """Drop every task whose age exceeds ``Config.TASK_EXPIRY_SECONDS``.

    Age is measured from ``start_time``; the task's status is not consulted.
    """
    now = time.time()

    # Collect first, delete afterwards, so the dict is not mutated while
    # it is being iterated.
    stale_ids = [
        tid
        for tid, info in active_tasks.items()
        if (now - info.start_time) > Config.TASK_EXPIRY_SECONDS
    ]

    for tid in stale_ids:
        print(f"Removing expired task: {tid}")
        del active_tasks[tid]
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 定期清理任务
|
|||
|
|
async def scheduled_cleanup():
    """Run ``cleanup_expired_tasks`` forever, sleeping one hour between passes."""
    pause_seconds = 3600  # clean up once per hour
    while True:
        await cleanup_expired_tasks()
        await asyncio.sleep(pause_seconds)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 检查训练进程状态
|
|||
|
|
def _training_process_alive(pid: int) -> bool:
    """Return True when a process with id ``pid`` currently exists.

    On Unix-like platforms this sends signal 0, which performs the existence
    and permission checks without delivering a signal. ``PermissionError``
    means the process exists but belongs to someone else, so it counts as
    alive. On Windows the process is opened and its exit code inspected.
    """
    if sys.platform == 'win32':
        # Imported here so non-Windows platforms never touch the WinDLL API.
        import ctypes

        PROCESS_QUERY_LIMITED_INFORMATION = 0x1000
        STILL_ACTIVE = 259
        ERROR_ACCESS_DENIED = 5

        kernel32 = ctypes.WinDLL('kernel32', use_last_error=True)
        # Declare the signatures so handles are not truncated on 64-bit builds.
        kernel32.OpenProcess.restype = ctypes.c_void_p
        kernel32.OpenProcess.argtypes = (ctypes.c_ulong, ctypes.c_int, ctypes.c_ulong)
        kernel32.GetExitCodeProcess.argtypes = (ctypes.c_void_p, ctypes.POINTER(ctypes.c_ulong))
        kernel32.CloseHandle.argtypes = (ctypes.c_void_p,)

        handle = kernel32.OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, 0, pid)
        if not handle:
            # Access denied means the process exists but cannot be queried.
            return ctypes.get_last_error() == ERROR_ACCESS_DENIED
        try:
            exit_code = ctypes.c_ulong()
            if not kernel32.GetExitCodeProcess(handle, ctypes.byref(exit_code)):
                return True  # cannot tell; do not declare it finished
            return exit_code.value == STILL_ACTIVE
        finally:
            kernel32.CloseHandle(handle)

    try:
        os.kill(pid, 0)
    except ProcessLookupError:
        return False
    except PermissionError:
        return True
    return True


async def check_training_processes():
    """Poll RUNNING tasks every 30 seconds and close those whose process is gone.

    A task is inspected only when its status is ``TaskStatus.RUNNING`` and it
    has a ``pid``. When the process no longer exists the task is marked
    ``TaskStatus.COMPLETED`` and ``end_time`` is set. The exit status of the
    process is not available here, so success and failure are not told apart.
    """
    while True:
        current_time = time.time()

        for task_id, task in list(active_tasks.items()):
            if task.status != TaskStatus.RUNNING or not task.pid:
                continue

            try:
                alive = _training_process_alive(task.pid)
            except Exception as e:
                # A failed probe must not kill the monitor loop or flip the
                # task state; report it and retry on the next pass.
                print(f"Failed to probe training process {task.pid} for task {task_id}: {e}")
                continue

            if not alive:
                task.status = TaskStatus.COMPLETED
                task.end_time = current_time
                print(f"Training process {task.pid} for task {task_id} has completed")

        await asyncio.sleep(30)  # check every 30 seconds
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 训练接口
|
|||
|
|
@app.route("/ai/project/train", methods=["POST"])
async def train_project(request: Request):
    """Register a training task and start ``download_train`` in the background.

    Expected JSON body::

        {
            "task_id": "<caller-chosen task id>",
            "train_task_id": 123
        }

    The weights file name is generated here as ``"<time_ns>-<task_id>.pt"``
    and returned in the response.

    Responses:
        200: task registered and the background job scheduled.
        400: body is not a non-empty JSON object, or a required field is missing.
        429: ``Config.MAX_ACTIVE_TASKS`` tasks are already in the RUNNING state.
        500: unexpected failure.
    """
    try:
        data = request.json
        if not data or not isinstance(data, dict):
            return response.json(
                {"code": 400, "message": "Invalid request body", "data": None},
                status=400
            )

        bz_training_task_id = data.get("train_task_id")
        task_id = data.get("task_id")

        if not bz_training_task_id:
            return response.json(
                {"code": 400, "message": "train_task_id is required", "data": None},
                status=400
            )

        # Without a task_id the registry key would be None and the weights
        # file would be named "...-None.pt".
        if task_id is None or task_id == "":
            return response.json(
                {"code": 400, "message": "task_id is required", "data": None},
                status=400
            )

        time_ns = time.time_ns()
        # NOTE(review): task_id comes from the request and ends up in a file
        # name; confirm download_train cannot be steered outside its
        # directory by a crafted value.
        pt_name = f"{time_ns}-{task_id}.pt"

        # Enforce the concurrency limit.
        running_tasks = sum(1 for task in active_tasks.values()
                            if task.status == TaskStatus.RUNNING)

        if running_tasks >= Config.MAX_ACTIVE_TASKS:
            return response.json(
                {"code": 429, "message": f"Too many active tasks. Max allowed: {Config.MAX_ACTIVE_TASKS}",
                 "data": None},
                status=429
            )

        # The status/cancel routes receive the id as a URL path string, so the
        # registry key must be a string, and the record must carry the same
        # id it is stored under.
        task_key = str(task_id)
        task_info = TaskInfo(task_key, bz_training_task_id, pt_name)
        active_tasks[task_key] = task_info

        print(f"Created new training task: {task_id} for bz_training_task_id: {bz_training_task_id}")

        def _report_background_failure(finished):
            """Record an exception raised by the background job on the task."""
            if finished.cancelled():
                return
            error = finished.exception()
            if error is not None:
                print(f"Background training job for task {task_key} failed: {error}")
                update_task_status(
                    task_key,
                    TaskStatus.FAILED,
                    end_time=time.time(),
                    error_message=str(error)
                )

        try:
            # Start download and training without waiting for completion.
            background = asyncio.create_task(download_train(task_id, bz_training_task_id, pt_name))
        except Exception as e:
            # Scheduling failed: roll back the registration.
            active_tasks.pop(task_key, None)
            print(f"Training preparation failed for task {task_id}: {e}")
            traceback.print_exc()
            return response.json(
                {"code": 500, "message": f"Training preparation failed: {str(e)}", "data": None},
                status=500
            )

        # The event loop keeps only a weak reference to tasks, so hold a
        # strong one on the record for as long as the task is registered.
        task_info.background_task = background
        background.add_done_callback(_report_background_failure)

        return response.json({
            "status": "success",
            "code": 200,
            "task_id": task_id,
            "pt_name": pt_name,
            "time_ns": time_ns,
            "message": "task started successfully"
        })

    except Exception as e:
        print(f"Error in train_project endpoint: {e}")
        traceback.print_exc()
        return response.json(
            {"code": 500, "message": f"Internal server error: {str(e)}", "data": None},
            status=500
        )
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 训练接口
|
|||
|
|
@app.route("/ai/project/query_train_task", methods=["POST"])
async def query_train_task(request: Request):
    """Return the status of a training process via ``get_process_status``.

    Expected JSON body::

        {
            "task_id": "<task id, echoed back>",
            "process_id": <value passed to get_process_status>
        }

    This endpoint is read-only: it never modifies ``active_tasks``.

    Responses:
        200: ``process_info`` holds the value returned by ``get_process_status``.
        400: body is not a non-empty JSON object.
        500: the status lookup or the handler itself failed.
    """
    try:
        data = request.json
        if not data or not isinstance(data, dict):
            return response.json(
                {"code": 400, "message": "Invalid request body", "data": None},
                status=400
            )

        time_ns = time.time_ns()
        task_id = data.get("task_id")
        process_id = data.get("process_id")

        try:
            process_info = get_process_status(process_id)
        except Exception as e:
            # A failed lookup must not touch the task registry.
            print(f"Process status query failed for task {task_id}: {e}")
            traceback.print_exc()
            return response.json(
                {"code": 500, "message": f"Process status query failed: {str(e)}", "data": None},
                status=500
            )

        return response.json({
            "status": "success",
            "code": 200,
            "task_id": task_id,
            "time_ns": time_ns,
            "process_info": process_info
        })

    except Exception as e:
        print(f"Error in query_train_task endpoint: {e}")
        traceback.print_exc()
        return response.json(
            {"code": 500, "message": f"Internal server error: {str(e)}", "data": None},
            status=500
        )
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 接收前端实时流,进行任务推理
|
|||
|
|
@app.post("/ai/project/inference")
async def start_inference(request):
    """Run ``predict_images`` for the given weights and image archive.

    Expected JSON body::

        {
            "task_id": "<task id, echoed back>",
            "pt_name": "<weights name passed to predict_images>",
            "zip_url": "<archive location passed to predict_images>"
        }

    Responses:
        200 with ``status: "success"``: ``predict_images`` returned a truthy
            result URL.
        200 with ``status: "fail"``: it returned a falsy URL; ``message``
            carries the second element of its return value.
        400: body is not a JSON object, a required field is missing, or a
            ``ValueError`` was raised.
        500: any other exception.
    """
    try:
        request_json = request.json
        if not isinstance(request_json, dict):
            return response.json({"status": "error", "message": "Invalid request body"}, status=400)

        # Report missing fields as a client error instead of letting the
        # KeyError surface as a 500.
        required_fields = ("task_id", "pt_name", "zip_url")
        missing = [name for name in required_fields if name not in request_json]
        if missing:
            return response.json(
                {"status": "error", "message": f"Missing required field(s): {', '.join(missing)}"},
                status=400
            )

        task_id = request_json["task_id"]
        pt_name = request_json["pt_name"]
        zip_url = request_json["zip_url"]

        print(f"task_id {task_id}")

        # NOTE(review): pt_name and zip_url come straight from the request;
        # confirm predict_images restricts which weights file and URL it opens.
        # NOTE(review): this is a synchronous call inside an async handler, so
        # the event loop is presumably blocked for the duration of inference.
        inference_zip_url, message = predict_images(
            pt_name, zip_url, output_dir="predictions", conf_threshold=0.25, save_json=False
        )

        if inference_zip_url:
            return response.json({
                "status": "success",
                "task_id": task_id,
                "inference_zip_url": inference_zip_url,
                "message": "predict successfully"
            })

        return response.json({
            "status": "fail",
            "task_id": task_id,
            "inference_zip_url": inference_zip_url,
            "message": message
        })
    except ValueError as e:
        print(f"Validation error: {str(e)}")
        return response.json({"status": "error", "message": str(e)}, status=400)
    except Exception as e:
        print(f"Unexpected error: {str(e)}")
        traceback.print_exc()
        return response.json({"status": "error", "message": f"Internal server error: {str(e)}"}, status=500)
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 获取任务状态接口
|
|||
|
|
@app.route("/ai/project/task/status/<task_id>", methods=["GET"])
async def get_task_status(request: Request, task_id: str):
    """Return the registry record of one task.

    Responses:
        200: ``data`` is ``TaskInfo.to_dict()`` for the task.
        404: no task is registered under ``task_id``.
        500: unexpected failure.
    """
    try:
        # One lookup decides both "exists" and "which record".
        task_info = get_task_info(task_id)
        if task_info is None:
            return response.json(
                {"code": 404, "message": f"Task {task_id} not found", "data": None},
                status=404
            )

        return response.json({
            "code": 200,
            "message": "success",
            "data": task_info.to_dict()
        })

    except Exception as e:
        print(f"Error in get_task_status endpoint: {e}")
        traceback.print_exc()
        return response.json(
            {"code": 500, "message": f"Internal server error: {str(e)}", "data": None},
            status=500
        )
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 获取所有任务状态接口
|
|||
|
|
@app.route("/ai/project/tasks", methods=["GET"])
async def get_all_tasks(request: Request):
    """Return every record in ``active_tasks`` together with a total count.

    Responses:
        200: ``data`` holds ``total_tasks`` and the list ``tasks``.
        500: unexpected failure.
    """
    try:
        tasks_data = [task.to_dict() for task in active_tasks.values()]
        return response.json({
            "code": 200,
            "message": "success",
            "data": {
                "total_tasks": len(tasks_data),
                "tasks": tasks_data
            }
        })

    except Exception as e:
        print(f"Error in get_all_tasks endpoint: {e}")
        traceback.print_exc()
        return response.json(
            {"code": 500, "message": f"Internal server error: {str(e)}", "data": None},
            status=500
        )
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 取消任务接口
|
|||
|
|
@app.route("/ai/project/task/cancel/<task_id>", methods=["POST"])
async def cancel_task(request: Request, task_id: str):
    """Cancel a registered task and try to terminate its training process.

    Terminating the process is best effort: a failure is printed and the task
    is still marked ``TaskStatus.CANCELLED``.

    Responses:
        200: task marked cancelled; ``data`` is the updated record.
        400: task is already completed, failed or cancelled.
        404: no task is registered under ``task_id``.
        500: unexpected failure.
    """
    try:
        task_info = get_task_info(task_id)
        if task_info is None:
            return response.json(
                {"code": 404, "message": f"Task {task_id} not found", "data": None},
                status=404
            )

        if task_info.status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED):
            return response.json(
                {"code": 400, "message": f"Task {task_id} is already {task_info.status}", "data": None},
                status=400
            )

        # Try to stop the training process when its pid is known.
        if task_info.pid:
            try:
                if sys.platform == 'win32':
                    # /T also terminates child processes, /F forces termination.
                    subprocess.run(['taskkill', '/F', '/T', '/PID', str(task_info.pid)],
                                   check=True, capture_output=True, text=True)
                else:
                    os.kill(task_info.pid, signal.SIGTERM)

                print(f"Sent termination signal to training process {task_info.pid} for task {task_id}")

            except Exception as e:
                print(f"Failed to terminate training process {task_info.pid}: {e}")

        update_task_status(
            task_id,
            TaskStatus.CANCELLED,
            end_time=time.time(),
            error_message="Task cancelled by user"
        )

        return response.json({
            "code": 200,
            "message": f"Task {task_id} cancelled successfully",
            "data": task_info.to_dict()
        })

    except Exception as e:
        print(f"Error in cancel_task endpoint: {e}")
        traceback.print_exc()
        return response.json(
            {"code": 500, "message": f"Internal server error: {str(e)}", "data": None},
            status=500
        )
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 健康检查接口
|
|||
|
|
@app.route("/health", methods=["GET"])
async def health_check(request: Request):
    """Report liveness together with task registry counters.

    Responses:
        200: ``data`` holds the registry size, the number of RUNNING tasks,
            the configured limit and the current timestamp.
        500: building the report failed.
    """
    try:
        running_tasks = sum(1 for task in active_tasks.values()
                            if task.status == TaskStatus.RUNNING)

        return response.json({
            "code": 200,
            "message": "Service is running",
            "data": {
                "status": "healthy",
                "active_tasks": len(active_tasks),
                "running_tasks": running_tasks,
                "max_active_tasks": Config.MAX_ACTIVE_TASKS,
                "timestamp": time.time()
            }
        })

    except Exception as e:
        print(f"Health check failed: {e}")
        traceback.print_exc()
        return response.json(
            {"code": 500, "message": "Service is unhealthy", "data": None},
            status=500
        )
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 应用启动事件
|
|||
|
|
@app.listener('before_server_start')
async def before_server_start(app, loop):
    """Write the PID file and launch the two background maintenance loops.

    Any failure is printed with its traceback and re-raised so that server
    startup is aborted.
    """
    try:
        save_pid_file()

        # Periodic removal of expired tasks.
        loop.create_task(scheduled_cleanup())

        # Periodic liveness check of training processes.
        loop.create_task(check_training_processes())

        print("Yolo Stream Service started successfully")

    except Exception as e:
        print(f"Error during server startup: {e}")
        traceback.print_exc()
        raise
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 应用停止事件
|
|||
|
|
@app.listener('after_server_stop')
async def after_server_stop(app, loop):
    """Remove the PID file and terminate training processes still running.

    Termination is best effort per task: a failure for one process is printed
    and the remaining tasks are still handled.
    """
    try:
        remove_pid_file()

        # Iterate over a snapshot so the registry can change underneath safely.
        for task_id, task in list(active_tasks.items()):
            if task.status != TaskStatus.RUNNING or not task.pid:
                continue
            try:
                if sys.platform == 'win32':
                    subprocess.run(['taskkill', '/F', '/T', '/PID', str(task.pid)],
                                   capture_output=True, text=True)
                else:
                    os.kill(task.pid, signal.SIGTERM)
                print(f"Terminated training process {task.pid} for task {task_id}")
            except Exception as e:
                print(f"Failed to terminate process {task.pid}: {e}")

        print("Yolo Stream Service stopped gracefully")

    except Exception as e:
        print(f"Error during server shutdown: {e}")
        traceback.print_exc()
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 信号处理
|
|||
|
|
def handle_signal(signal_num, frame):
    """Signal handler: announce the signal and ask the Sanic app to stop.

    Args:
        signal_num: Number of the received signal.
        frame: Current stack frame supplied by the ``signal`` module (unused).
    """
    print(f"Received signal {signal_num}. Shutting down...")

    # app.stop() is a plain synchronous call that returns None. Wrapping it
    # in asyncio.create_task() raised TypeError instead of stopping the app.
    app.stop()
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 注册信号处理
|
|||
|
|
# Route termination signals to handle_signal. Skipped on Windows, which
# does not support some signals.
if sys.platform != 'win32':
    signal.signal(signal.SIGTERM, handle_signal)
    signal.signal(signal.SIGINT, handle_signal)
|
|||
|
|
|
|||
|
|
# 主函数
|
|||
|
|
if __name__ == "__main__":
    # Script entry point: run the Sanic server with the settings from Config.
    try:
        app.run(
            host=Config.HOST,
            port=Config.PORT,
            workers=Config.WORKERS,
            debug=Config.DEBUG,
            access_log=True
        )

    except Exception as e:
        # print() has no exc_info keyword; emit the traceback explicitly so
        # the cleanup below is actually reached.
        print(f"Failed to start Yolo Stream Service: {e}")
        traceback.print_exc()
        remove_pid_file()
        sys.exit(1)
|