from datetime import datetime
from typing import List, Dict, Any, Optional
import logging
import asyncio

# Configure logging (module-level side effect kept for backward compatibility).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Heartbeats are not actively sent yet, so a fixed timeout of
# 6000 seconds (100 minutes) is used instead.
TaskTimeout = 6000


class TaskManager:
    """Track asyncio tasks by id, with heartbeats and timeout-based cleanup.

    For every registered task the manager keeps its main ``asyncio.Task``,
    optional sub-tasks, an ``asyncio.Event`` that is set when the task is
    removed, and the time of the last heartbeat. The background loop started
    by :meth:`start` periodically removes tasks whose last heartbeat is older
    than ``task_timeout`` seconds.
    """

    def __init__(self):
        self.tasks: Dict[str, Dict[str, Any]] = {}  # task_id -> task entry
        self.heartbeats: Dict[str, datetime] = {}  # task_id -> last heartbeat
        self.task_handles: Dict[str, asyncio.Task] = {}  # task_id -> main task handle
        self.sub_tasks: Dict[str, List[asyncio.Task]] = {}  # task_id -> sub-task handles
        self.stop_events: Dict[str, asyncio.Event] = {}  # task_id -> stop event
        self.lock = asyncio.Lock()
        self.heartbeat_interval = 30  # seconds; currently not read by this class
        self.task_timeout = TaskTimeout  # seconds without a heartbeat before "timeout"
        self.cleanup_interval = 10  # seconds between cleanup sweeps
        # Handle of the background cleanup loop; set by start(), cleared by stop().
        self.cleanup_task: Optional[asyncio.Task] = None

    async def start(self):
        """Start the background cleanup loop (no-op if it is already running)."""
        if self.cleanup_task is not None and not self.cleanup_task.done():
            return
        self.cleanup_task = asyncio.create_task(self._cleanup_tasks())

    async def stop(self):
        """Cancel the background cleanup loop started by :meth:`start`, if any."""
        cleanup_task = self.cleanup_task
        self.cleanup_task = None
        if cleanup_task is None or cleanup_task.done():
            return
        cleanup_task.cancel()
        # gather(return_exceptions=True) waits for the loop to finish unwinding
        # without swallowing a cancellation aimed at *this* coroutine.
        await asyncio.gather(cleanup_task, return_exceptions=True)

    async def add_task(self, task_id: str, task_info: Dict[str, Any], main_task_handle: asyncio.Task,
                       sub_task_handles: Optional[List[asyncio.Task]] = None):
        """Register a task together with all of its sub-tasks.

        Args:
            task_id: Unique identifier of the task.
            task_info: Caller-supplied metadata, stored under the entry's
                ``"task_info"`` key. If it has a ``"device"`` key, the value
                (or each element, when it is a list/tuple/set) gets
                ``await device.stop()`` called on it by :meth:`remove_task`.
            main_task_handle: The task's main asyncio task.
            sub_task_handles: Optional further tasks, cancelled together with
                the main task on removal.

        Raises:
            ValueError: If ``task_id`` is already registered.
        """
        async with self.lock:
            if task_id in self.tasks:
                raise ValueError("Task already exists")

            # One stop event per task; set by remove_task() as a stop signal.
            stop_event = asyncio.Event()
            sub_handles = sub_task_handles if sub_task_handles else []

            self.tasks[task_id] = {
                "task_info": task_info,
                "status": "running",
                "main_task_handle": main_task_handle,
                "sub_task_handles": sub_handles,
            }
            self.heartbeats[task_id] = datetime.now()
            self.task_handles[task_id] = main_task_handle
            self.stop_events[task_id] = stop_event
            if sub_handles:
                self.sub_tasks[task_id] = sub_handles

            logger.info(f"Task {task_id} started")

    @staticmethod
    def _find_devices(entry: Dict[str, Any]) -> List[Any]:
        """Return the MQTT device object(s) referenced by a task entry.

        ``"device"`` is looked up on the entry itself and on the caller's
        ``task_info`` dict stored inside it (where add_task() puts caller
        metadata). A list/tuple/set value contributes its elements; any other
        non-None value is treated as a single device.
        """
        containers: List[Dict[str, Any]] = [entry]
        inner = entry.get("task_info")
        if isinstance(inner, dict):
            containers.append(inner)

        devices: List[Any] = []
        for container in containers:
            value = container.get("device")
            if value is None:
                continue
            if isinstance(value, (list, tuple, set, frozenset)):
                devices.extend(value)
            else:
                devices.append(value)
        return devices

    async def remove_task(self, task_id: str) -> bool:
        """Remove a task, cancel its asyncio tasks and stop its MQTT devices.

        Order of operations:

        1. Set the task's stop event.
        2. Request cancellation of the main task and all sub-tasks.
        3. Drop every bookkeeping entry for ``task_id``, including its heartbeat.
        4. After releasing the lock, ``await stop()`` on any device(s) found
           under a ``"device"`` key.

        Cancellation is only *requested* (``Task.cancel()``). The tasks are
        not awaited, so this never blocks on a slow-to-unwind task and cannot
        deadlock if that task itself calls back into this manager.

        Returns:
            True if the task was registered and has been removed, else False.
        """
        async with self.lock:
            entry = self.tasks.get(task_id)
            if entry is None:
                return False

            stop_event = self.stop_events.get(task_id)
            if stop_event is not None:
                stop_event.set()  # trigger the stop signal

            # Task.cancel() returns a bool and takes effect at the task's next
            # await. The old `await asyncio.shield(task.cancel())` raised
            # TypeError on every call and logged a spurious failure warning.
            main_task = self.task_handles.get(task_id)
            if main_task is not None and not main_task.done():
                main_task.cancel()
            for sub_task in self.sub_tasks.get(task_id, []):
                if not sub_task.done():
                    sub_task.cancel()

            # Forget the task before any further await, so the manager's state
            # stays consistent even if a device stop fails or we get cancelled.
            del self.tasks[task_id]
            self.heartbeats.pop(task_id, None)
            self.task_handles.pop(task_id, None)
            self.sub_tasks.pop(task_id, None)
            self.stop_events.pop(task_id, None)

            devices = self._find_devices(entry)

        # Stop devices outside the lock so slow I/O does not block other
        # manager calls. One failing device does not prevent stopping the rest.
        for device in devices:
            try:
                await device.stop()  # stop the MQTT device
            except Exception as e:
                logger.warning(f"停止MQTT设备 {task_id} 失败: {e}")

        logger.info(f"Task {task_id} removed")
        return True

    async def mark_all_tasks_as_stopped(self):
        """Set the stored ``status`` of every registered task to ``"stopped"``."""
        async with self.lock:
            for entry in self.tasks.values():
                entry["status"] = "stopped"

    async def update_heartbeat(self, task_id: str) -> None:
        """Record "now" as the last heartbeat of a registered task.

        Unknown task ids are ignored.
        """
        if task_id in self.heartbeats:
            self.heartbeats[task_id] = datetime.now()
            logger.debug(f"Heartbeat updated for task {task_id}")

    async def get_task_info(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Return a snapshot of one task, or None if it is not registered.

        The result has these keys:

        - ``"task_info"``: the internal entry dict (not a copy).
        - ``"status"``: ``"running"`` if the last heartbeat is younger than
          ``task_timeout`` seconds, otherwise ``"timeout"``.
        - ``"last_heartbeat"``: an ISO-format string, or None when no
          heartbeat is recorded.
        - ``"elapsed_time"``: seconds since the last heartbeat, or None.
        """
        async with self.lock:
            entry = self.tasks.get(task_id)
            if entry is None:
                return None

            last_heartbeat = self.heartbeats.get(task_id)
            elapsed_time = (datetime.now() - last_heartbeat).total_seconds() if last_heartbeat else None

            return {
                "task_info": entry,
                "status": "running" if (elapsed_time is not None and elapsed_time < self.task_timeout) else "timeout",
                "last_heartbeat": last_heartbeat.isoformat() if last_heartbeat else None,
                "elapsed_time": elapsed_time,
            }

    async def check_tasks_health(self) -> Dict[str, str]:
        """Return ``{task_id: "timeout"}`` for every task past ``task_timeout``."""
        unhealthy_tasks: Dict[str, str] = {}
        current_time = datetime.now()
        async with self.lock:
            for task_id in list(self.tasks.keys()):
                last_heartbeat = self.heartbeats.get(task_id)
                elapsed_seconds = (current_time - last_heartbeat).total_seconds() if last_heartbeat else None
                # Was a bare print() on every sweep; route it through the logger.
                logger.debug(f"当前心跳已经进行{elapsed_seconds}s")
                if elapsed_seconds is not None and elapsed_seconds > self.task_timeout:
                    unhealthy_tasks[task_id] = "timeout"
                    logger.warning(f"----------------------------Task {task_id} has timed out-----------------------------")
        return unhealthy_tasks

    async def _cleanup_tasks(self):
        """Background loop: every ``cleanup_interval`` seconds remove timed-out tasks.

        A failing sweep is logged and retried on the next interval instead of
        silently killing the loop. Cancellation still stops it.
        """
        while True:
            await asyncio.sleep(self.cleanup_interval)
            try:
                unhealthy_tasks = await self.check_tasks_health()
                for task_id in unhealthy_tasks:
                    if await self.remove_task(task_id):
                        logger.warning(f"Task {task_id} removed due to timeout")
            except asyncio.CancelledError:
                raise
            except Exception:
                logger.exception("Task cleanup sweep failed; retrying next interval")

    async def get_all_tasks(self) -> Dict[str, Dict[str, Any]]:
        """Return a status summary for every registered task.

        Each value has ``"status"``, ``"last_heartbeat"`` and
        ``"elapsed_time"``. A task without a recorded heartbeat is measured
        against ``datetime.min`` (and therefore reported as ``"timeout"``),
        and its ``elapsed_time`` is 0.
        """
        async with self.lock:
            now = datetime.now()  # single snapshot so all entries are consistent
            summary: Dict[str, Dict[str, Any]] = {}
            for task_id in self.tasks:
                last_heartbeat = self.heartbeats.get(task_id, datetime.min)
                elapsed = (now - last_heartbeat).total_seconds()
                summary[task_id] = {
                    "status": "running" if elapsed < self.task_timeout else "timeout",
                    "last_heartbeat": last_heartbeat.isoformat(),
                    "elapsed_time": elapsed if task_id in self.heartbeats else 0,
                }
            return summary


task_manager = TaskManager()