"""YOLO training / inference HTTP service built on Sanic.

Endpoints
---------
POST /ai/project/train                     start a background ``download_train`` job
POST /ai/project/query_train_task          look up a process via ``get_process_status``
POST /ai/project/inference                 run ``predict_images`` on a zip of images
GET  /ai/project/task/status/<task_id>     state of one tracked task
GET  /ai/project/tasks                     state of every tracked task
POST /ai/project/task/cancel/<task_id>     cancel a tracked task
GET  /health                               liveness probe

Task state lives in process memory (``active_tasks``); the service is
configured to run with a single worker (``Config.WORKERS``).

NOTE: an earlier, fully commented-out prototype of this service was removed
from the top of this file. It hard-coded an API token and database
credentials; those values should be rotated and kept out of source control.
"""

import asyncio
import functools
import json
import logging
import os
import signal
import subprocess
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

from sanic import Sanic, response
from sanic.exceptions import SanicException
from sanic.request import Request

from download_train import download_train
from predict.predict_yolo11seg import predict_images
from query_process_status import get_process_status

# basicConfig is a no-op when the root logger already has handlers.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)


# Configuration
class Config:
    """Static service settings."""

    MAX_ACTIVE_TASKS = 10  # upper bound on concurrently running tasks
    TASK_EXPIRY_SECONDS = 86400  # task records expire after 24 hours
    LOG_FILE = "yolo_service.log"
    PID_FILE = "service.pid"
    HOST = "0.0.0.0"
    PORT = 12325
    DEBUG = False
    WORKERS = 1  # run in single-worker mode


# Task states
class TaskStatus:
    """String constants describing the lifecycle of a task."""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


# States from which a task can no longer be cancelled.
_TERMINAL_STATUSES = (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED)


# Task record
class TaskInfo:
    """In-memory record describing one training task."""

    def __init__(self, task_id: str, bz_training_task_id: int, pt_name: str):
        self.task_id = task_id
        self.bz_training_task_id = bz_training_task_id
        self.pt_name = pt_name
        self.status = TaskStatus.PENDING
        # Training process id. No code in this file assigns it; it can only
        # be set through update_task_status(..., pid=...).
        self.pid: Optional[int] = None
        self.start_time = time.time()
        self.end_time: Optional[float] = None
        self.error_message: Optional[str] = None
        self.progress = 0.0  # training progress (percent)
        self.result: Any = None  # training result
        # Background asyncio task running download_train; not serialized.
        self.job: Optional["asyncio.Task"] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the JSON-serializable fields of this record."""
        return {
            "task_id": self.task_id,
            "bz_training_task_id": self.bz_training_task_id,
            "pt_name": self.pt_name,
            "status": self.status,
            "pid": self.pid,
            "start_time": self.start_time,
            "end_time": self.end_time,
            "error_message": self.error_message,
            "progress": self.progress,
            "result": self.result,
        }


# Application setup
app = Sanic("YoloStreamService2")
app.config.update({
    "DEBUG": Config.DEBUG,
    "WORKERS": Config.WORKERS,
})

# Global state
active_tasks: Dict[str, TaskInfo] = {}
task_counter = 0

# Strong references to fire-and-forget asyncio tasks. The event loop only
# keeps weak references, so an unreferenced task may be garbage-collected
# before it finishes.
_background_tasks: set = set()


def _spawn_background(coro, loop=None) -> "asyncio.Task":
    """Schedule *coro* as a task and keep it referenced until it is done."""
    if loop is None:
        loop = asyncio.get_running_loop()
    task = loop.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task


def _error(code: int, message: str):
    """Build the ``{"code", "message", "data": None}`` error response."""
    return response.json(
        {"code": code, "message": message, "data": None},
        status=code,
    )


def _json_body(request: Request) -> Optional[dict]:
    """Return the request body as a dict, or None when it is not a JSON object."""
    try:
        body = request.json
    except SanicException:
        # Sanic raises for malformed JSON; report it as an invalid body.
        return None
    return body if isinstance(body, dict) else None


def _count_running_tasks() -> int:
    """Return how many tracked tasks are currently in the RUNNING state."""
    return sum(1 for task in active_tasks.values() if task.status == TaskStatus.RUNNING)


def save_pid_file():
    """Write this process's PID to ``Config.PID_FILE`` (best effort)."""
    try:
        with open(Config.PID_FILE, "w") as f:
            f.write(str(os.getpid()))
        logger.info("Service PID %s saved to %s", os.getpid(), Config.PID_FILE)
    except OSError as e:
        logger.error("Failed to save PID file: %s", e)


def remove_pid_file():
    """Delete ``Config.PID_FILE`` if it exists (best effort)."""
    try:
        if os.path.exists(Config.PID_FILE):
            os.remove(Config.PID_FILE)
            logger.info("PID file %s removed", Config.PID_FILE)
    except OSError as e:
        logger.error("Failed to remove PID file: %s", e)


def generate_task_id() -> str:
    """Return a new task id built from the current time and a counter."""
    global task_counter
    task_counter += 1
    return f"yolo_task_{int(time.time() * 1000)}_{task_counter}"


def task_exists(task_id: str) -> bool:
    """Return True when *task_id* is a key of ``active_tasks``."""
    return task_id in active_tasks


def get_task_info(task_id: str) -> Optional[TaskInfo]:
    """Return the record stored under *task_id*, or None."""
    return active_tasks.get(task_id)


def update_task_status(task_id: str, status: str, **kwargs):
    """Set the status of a tracked task and any extra attributes given.

    Unknown task ids are ignored. Keyword arguments that do not name an
    existing ``TaskInfo`` attribute are skipped.
    """
    task = active_tasks.get(task_id)
    if task is None:
        return

    task.status = status
    for key, value in kwargs.items():
        if hasattr(task, key):
            setattr(task, key, value)

    logger.info("Task %s status updated to %s", task_id, status)


async def cleanup_expired_tasks():
    """Drop every task record older than ``Config.TASK_EXPIRY_SECONDS``."""
    current_time = time.time()
    expired_tasks = [
        task_id
        for task_id, task in active_tasks.items()
        if (current_time - task.start_time) > Config.TASK_EXPIRY_SECONDS
    ]

    for task_id in expired_tasks:
        logger.info("Removing expired task: %s", task_id)
        del active_tasks[task_id]


async def scheduled_cleanup():
    """Run ``cleanup_expired_tasks`` once per hour, forever."""
    while True:
        await cleanup_expired_tasks()
        await asyncio.sleep(3600)


def _process_is_alive(pid: int) -> bool:
    """Return True when a process with id *pid* appears to exist."""
    if sys.platform == "win32":
        import ctypes

        kernel32 = ctypes.WinDLL("kernel32")
        # Same access mask (1) as the original implementation.
        handle = kernel32.OpenProcess(1, 0, pid)
        if handle == 0:
            return False
        kernel32.CloseHandle(handle)
        return True

    try:
        os.kill(pid, 0)  # signal 0 only performs the existence/permission check
    except PermissionError:
        # The process exists but we are not allowed to signal it.
        return True
    except OSError:
        return False
    return True


def _terminate_process(pid: int) -> None:
    """Ask process *pid* to terminate; raises when the request fails."""
    if sys.platform == "win32":
        subprocess.run(
            ["taskkill", "/F", "/T", "/PID", str(pid)],
            check=True,
            capture_output=True,
            text=True,
        )
    else:
        os.kill(pid, signal.SIGTERM)


async def check_training_processes():
    """Every 30 seconds, mark RUNNING tasks whose process is gone as COMPLETED."""
    while True:
        try:
            current_time = time.time()
            # Snapshot the items: handlers may mutate the registry meanwhile.
            for task_id, task in list(active_tasks.items()):
                if task.status != TaskStatus.RUNNING or not task.pid:
                    continue
                if not _process_is_alive(task.pid):
                    task.status = TaskStatus.COMPLETED
                    task.end_time = current_time
                    logger.info(
                        "Training process %s for task %s has completed",
                        task.pid,
                        task_id,
                    )
        except Exception:
            # Keep the monitor alive; an escaped error would end it silently.
            logger.exception("Error while checking training processes")

        await asyncio.sleep(30)


def _finalize_training_job(task_info: TaskInfo, job: "asyncio.Task") -> None:
    """Record the outcome of a finished ``download_train`` job on *task_info*.

    COMPLETED here means the ``download_train`` coroutine returned without
    raising.
    NOTE(review): if ``download_train`` only launches training and returns
    early, completion should be derived from the training process instead -
    confirm against its implementation.
    """
    error = None
    if job.cancelled():
        logger.info("Background job for task %s was cancelled", task_info.task_id)
        outcome = TaskStatus.CANCELLED
    else:
        # Retrieving the exception also prevents the "never retrieved" warning.
        error = job.exception()
        if error is not None:
            logger.error(
                "Background job for task %s failed",
                task_info.task_id,
                exc_info=error,
            )
            outcome = TaskStatus.FAILED
        else:
            outcome = TaskStatus.COMPLETED

    if task_info.status != TaskStatus.RUNNING:
        # Already finalized elsewhere (e.g. cancelled by the user).
        return

    task_info.status = outcome
    task_info.end_time = time.time()
    if error is not None:
        task_info.error_message = str(error)


# Training endpoint
@app.route("/ai/project/train", methods=["POST"])
async def train_project(request: Request):
    """Start a training task in the background.

    Request body (JSON object):
        {
            "train_task_id": 123,   # required
            "task_id": "abc"        # optional; generated when omitted
        }

    The weights file name is derived as ``"<time_ns>-<task_id>.pt"``.
    Responds 400 for an invalid body or missing ``train_task_id`` and 429
    when ``Config.MAX_ACTIVE_TASKS`` tasks are already running.
    """
    try:
        data = _json_body(request)
        if not data:
            return _error(400, "Invalid request body")

        bz_training_task_id = data.get("train_task_id")
        if not bz_training_task_id:
            return _error(400, "train_task_id is required")

        if _count_running_tasks() >= Config.MAX_ACTIVE_TASKS:
            return _error(
                429,
                f"Too many active tasks. Max allowed: {Config.MAX_ACTIVE_TASKS}",
            )

        task_id = data.get("task_id")
        if task_id is None:
            task_id = generate_task_id()
        # The registry is keyed by str so that ids taken from URL path
        # parameters (always strings) match ids sent as JSON numbers.
        task_key = str(task_id)

        time_ns = time.time_ns()
        pt_name = f"{time_ns}-{task_id}.pt"

        task_info = TaskInfo(task_key, bz_training_task_id, pt_name)
        active_tasks[task_key] = task_info
        logger.info(
            "Created new training task: %s for bz_training_task_id: %s",
            task_key,
            bz_training_task_id,
        )

        try:
            # Start download and training preparation without awaiting it.
            job = _spawn_background(
                download_train(task_id, bz_training_task_id, pt_name)
            )
        except Exception as e:
            active_tasks.pop(task_key, None)
            logger.exception("Training preparation failed for task %s", task_key)
            return _error(500, f"Training preparation failed: {str(e)}")

        task_info.job = job
        task_info.status = TaskStatus.RUNNING
        # Bind the record itself (not its key) so a later task that reuses
        # the same id cannot be affected by this job finishing.
        job.add_done_callback(functools.partial(_finalize_training_job, task_info))

        return response.json({
            "status": "success",
            "code": 200,
            "task_id": task_id,
            "pt_name": pt_name,
            "time_ns": time_ns,
            "message": "task started successfully",
        })

    except Exception as e:
        logger.exception("Error in train_project endpoint")
        return _error(500, f"Internal server error: {str(e)}")


# Training-process query endpoint
@app.route("/ai/project/query_train_task", methods=["POST"])
async def query_train_task(request: Request):
    """Return the result of ``get_process_status`` for a process id.

    Request body (JSON object): ``{"task_id": ..., "process_id": ...}``.
    ``task_id`` is only echoed back in the response.
    """
    try:
        data = _json_body(request)
        if not data:
            return _error(400, "Invalid request body")

        task_id = data.get("task_id")
        process_id = data.get("process_id")

        try:
            process_info = get_process_status(process_id)
        except Exception as e:
            # A failed lookup must not touch the task registry.
            logger.exception(
                "Failed to query process %s for task %s", process_id, task_id
            )
            return _error(500, f"Failed to query process status: {str(e)}")

        return response.json({
            "status": "success",
            "code": 200,
            "task_id": task_id,
            "time_ns": time.time_ns(),
            "process_info": process_info,
        })

    except Exception as e:
        logger.exception("Error in query_train_task endpoint")
        return _error(500, f"Internal server error: {str(e)}")


# Inference endpoint
@app.post("/ai/project/inference")
async def start_inference(request):
    """Run ``predict_images`` for the given weights file and image zip.

    Request body (JSON object) must contain ``task_id``, ``pt_name`` and
    ``zip_url``. Responds 400 when the body is invalid or a field is missing.
    """
    try:
        request_json = _json_body(request)
        if not request_json:
            return response.json(
                {"status": "error", "message": "Invalid request body"},
                status=400,
            )

        missing = [
            field
            for field in ("task_id", "pt_name", "zip_url")
            if field not in request_json
        ]
        if missing:
            return response.json(
                {
                    "status": "error",
                    "message": f"Missing required field(s): {', '.join(missing)}",
                },
                status=400,
            )

        task_id = request_json["task_id"]
        pt_name = request_json["pt_name"]
        zip_url = request_json["zip_url"]

        logger.info("task_id %s", task_id)

        # NOTE(review): predict_images is called synchronously, so the event
        # loop is blocked while it runs - confirm whether it should be moved
        # to an executor.
        inference_zip_url, message = predict_images(
            pt_name,
            zip_url,
            output_dir="predictions",
            conf_threshold=0.25,
            save_json=False,
        )

        if inference_zip_url:
            return response.json({
                "status": "success",
                "task_id": task_id,
                "inference_zip_url": inference_zip_url,
                "message": "predict successfully",
            })
        return response.json({
            "status": "fail",
            "task_id": task_id,
            "inference_zip_url": inference_zip_url,
            "message": message,
        })

    except ValueError as e:
        logger.error("Validation error: %s", e)
        return response.json({"status": "error", "message": str(e)}, status=400)
    except Exception as e:
        logger.exception("Unexpected error in start_inference endpoint")
        return response.json(
            {"status": "error", "message": f"Internal server error: {str(e)}"},
            status=500,
        )


# Task status endpoint
@app.route("/ai/project/task/status/<task_id>", methods=["GET"])
async def get_task_status(request: Request, task_id: str):
    """Return the stored record for *task_id*, or 404 when it is unknown."""
    try:
        task_info = get_task_info(task_id)
        if task_info is None:
            return _error(404, f"Task {task_id} not found")

        return response.json({
            "code": 200,
            "message": "success",
            "data": task_info.to_dict(),
        })

    except Exception as e:
        logger.exception("Error in get_task_status endpoint")
        return _error(500, f"Internal server error: {str(e)}")


# All-tasks endpoint
@app.route("/ai/project/tasks", methods=["GET"])
async def get_all_tasks(request: Request):
    """Return every tracked task record together with the total count."""
    try:
        tasks_data = [task.to_dict() for task in active_tasks.values()]

        return response.json({
            "code": 200,
            "message": "success",
            "data": {
                "total_tasks": len(tasks_data),
                "tasks": tasks_data,
            },
        })

    except Exception as e:
        logger.exception("Error in get_all_tasks endpoint")
        return _error(500, f"Internal server error: {str(e)}")


# Task cancellation endpoint
@app.route("/ai/project/task/cancel/<task_id>", methods=["POST"])
async def cancel_task(request: Request, task_id: str):
    """Cancel a tracked task.

    Responds 404 for an unknown task and 400 when the task has already
    completed, failed or been cancelled. Otherwise the background job is
    cancelled, the training process (when a pid is recorded) is asked to
    terminate, and the record is marked CANCELLED.
    """
    try:
        task_info = get_task_info(task_id)
        if task_info is None:
            return _error(404, f"Task {task_id} not found")

        if task_info.status in _TERMINAL_STATUSES:
            return _error(400, f"Task {task_id} is already {task_info.status}")

        # Stop the in-flight coroutine so a cancelled task does not keep working.
        if task_info.job is not None and not task_info.job.done():
            task_info.job.cancel()

        # Try to terminate the training process; failure is logged, not fatal.
        if task_info.pid:
            try:
                _terminate_process(task_info.pid)
                logger.info(
                    "Sent termination signal to training process %s for task %s",
                    task_info.pid,
                    task_id,
                )
            except Exception as e:
                logger.error(
                    "Failed to terminate training process %s: %s", task_info.pid, e
                )

        update_task_status(
            task_id,
            TaskStatus.CANCELLED,
            end_time=time.time(),
            error_message="Task cancelled by user",
        )

        return response.json({
            "code": 200,
            "message": f"Task {task_id} cancelled successfully",
            "data": task_info.to_dict(),
        })

    except Exception as e:
        logger.exception("Error in cancel_task endpoint")
        return _error(500, f"Internal server error: {str(e)}")


# Health check endpoint
@app.route("/health", methods=["GET"])
async def health_check(request: Request):
    """Report task counts and the current timestamp."""
    try:
        return response.json({
            "code": 200,
            "message": "Service is running",
            "data": {
                "status": "healthy",
                "active_tasks": len(active_tasks),
                "running_tasks": _count_running_tasks(),
                "max_active_tasks": Config.MAX_ACTIVE_TASKS,
                "timestamp": time.time(),
            },
        })
    except Exception:
        logger.exception("Health check failed")
        return _error(500, "Service is unhealthy")


# Server start hook
@app.listener("before_server_start")
async def before_server_start(app, loop):
    """Write the PID file and start the housekeeping background loops."""
    try:
        save_pid_file()

        _spawn_background(scheduled_cleanup(), loop)
        _spawn_background(check_training_processes(), loop)

        logger.info("Yolo Stream Service started successfully")
    except Exception:
        logger.exception("Error during server startup")
        raise


# Server stop hook
@app.listener("after_server_stop")
async def after_server_stop(app, loop):
    """Remove the PID file and terminate training processes still running."""
    try:
        remove_pid_file()

        for task_id, task in list(active_tasks.items()):
            if task.status != TaskStatus.RUNNING or not task.pid:
                continue
            try:
                _terminate_process(task.pid)
                logger.info(
                    "Terminated training process %s for task %s", task.pid, task_id
                )
            except Exception as e:
                logger.error("Failed to terminate process %s: %s", task.pid, e)

        logger.info("Yolo Stream Service stopped gracefully")
    except Exception:
        logger.exception("Error during server shutdown")


# Signal handling
def handle_signal(signal_num, frame):
    """Stop the application when a termination signal is received."""
    logger.info("Received signal %s. Shutting down...", signal_num)
    try:
        # app.stop() is a plain synchronous call, not a coroutine.
        app.stop()
    except Exception:
        # E.g. the server has not started yet; exit instead of hanging.
        logger.exception("Graceful stop failed; exiting")
        sys.exit(1)


# Register the handlers; skipped on Windows, which does not support some signals.
if sys.platform != "win32":
    signal.signal(signal.SIGINT, handle_signal)
    signal.signal(signal.SIGTERM, handle_signal)


# Application entry point
if __name__ == "__main__":
    try:
        app.run(
            host=Config.HOST,
            port=Config.PORT,
            workers=Config.WORKERS,
            debug=Config.DEBUG,
            access_log=True,
        )
    except Exception:
        logger.exception("Failed to start Yolo Stream Service")
        remove_pid_file()
        sys.exit(1)