228 lines
9.4 KiB
Python
228 lines
9.4 KiB
Python
|
|
from datetime import datetime
|
|||
|
|
from typing import List, Dict, Any, Optional
|
|||
|
|
import logging
|
|||
|
|
import asyncio
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Root logging setup applied at import time: INFO level, timestamped lines.
# (logging.basicConfig is a no-op if the root logger already has handlers.)
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Seconds a task may go without a heartbeat before it counts as timed out.
# Heartbeats are not enabled yet, so a flat 100 minutes is used.
TaskTimeout = 100 * 60
|
|||
|
|
class TaskManager:
    """Registry of asyncio tasks with heartbeat-based timeout cleanup.

    Each registered task id maps to its main ``asyncio.Task``, optional
    sub-tasks, an ``asyncio.Event`` stop signal and the time of its last
    heartbeat. The background loop launched by :meth:`start` periodically
    removes tasks whose last heartbeat is older than ``task_timeout`` seconds.
    """

    def __init__(self):
        self.tasks: Dict[str, Dict[str, Any]] = {}  # task_id -> task record
        self.heartbeats: Dict[str, datetime] = {}  # task_id -> last heartbeat
        self.task_handles: Dict[str, asyncio.Task] = {}  # task_id -> main task
        self.sub_tasks: Dict[str, List[asyncio.Task]] = {}  # task_id -> sub-tasks
        self.stop_events: Dict[str, asyncio.Event] = {}  # task_id -> stop signal
        # NOTE(review): on Python < 3.10 an asyncio.Lock created outside a running
        # loop binds to get_event_loop(); if this object is built at import time
        # that may not be the loop used later — confirm the target Python version.
        self.lock = asyncio.Lock()
        self.heartbeat_interval = 30  # seconds; not read anywhere in this class
        self.task_timeout = TaskTimeout  # seconds without heartbeat => timeout
        self.cleanup_interval = 10  # seconds between cleanup sweeps
        self.cleanup_task: Optional[asyncio.Task] = None  # handle of the sweep loop

    async def start(self):
        """Start the background cleanup loop; a no-op if it is already running."""
        if self.cleanup_task is None or self.cleanup_task.done():
            self.cleanup_task = asyncio.create_task(self._cleanup_tasks())

    async def add_task(self, task_id: str, task_info: Dict[str, Any], main_task_handle: asyncio.Task,
                       sub_task_handles: Optional[List[asyncio.Task]] = None):
        """Register a task together with its sub-tasks.

        Args:
            task_id: Unique identifier of the task.
            task_info: Caller-supplied metadata, stored under the ``"task_info"``
                key of the task record. If it is a dict with a ``"device"`` entry
                (one device or a list/tuple/set of them), :meth:`remove_task`
                awaits ``stop()`` on each device.
            main_task_handle: The task's main ``asyncio.Task``.
            sub_task_handles: Optional auxiliary tasks cancelled along with it.

        Raises:
            ValueError: If ``task_id`` is already registered.
        """
        async with self.lock:
            if task_id in self.tasks:
                raise ValueError("Task already exists")

            self.tasks[task_id] = {
                "task_info": task_info,
                "status": "running",
                "main_task_handle": main_task_handle,
                "sub_task_handles": sub_task_handles if sub_task_handles else [],
            }
            self.heartbeats[task_id] = datetime.now()
            self.task_handles[task_id] = main_task_handle
            # Per-task stop signal; remove_task sets it for cooperative shutdown.
            self.stop_events[task_id] = asyncio.Event()
            if sub_task_handles:
                self.sub_tasks[task_id] = sub_task_handles
            logger.info(f"Task {task_id} started")

    async def remove_task(self, task_id: str):
        """Stop and unregister a task, its sub-tasks and any attached devices.

        Sets the task's stop event, requests cancellation of the main task,
        stops the device(s) found in the task's ``task_info["device"]``,
        requests cancellation of every sub-task, and finally drops all
        bookkeeping (record, handles, stop event, heartbeat) for the id.

        Returns:
            True if the task was registered and has been removed, else False.
        """
        async with self.lock:
            record = self.tasks.get(task_id)
            if record is None:
                return False

            # Signal cooperative shutdown first so well-behaved tasks can exit.
            stop_event = self.stop_events.get(task_id)
            if stop_event is not None:
                stop_event.set()

            # Task.cancel() only *requests* cancellation and returns a bool — it
            # is not awaitable. The task is intentionally not awaited here: we
            # hold self.lock, and the task's own cleanup may need that lock.
            main_task = self.task_handles.get(task_id)
            if main_task is not None and not main_task.done():
                try:
                    main_task.cancel()
                except Exception as e:
                    logger.warning(f"取消主任务 {task_id} 失败: {e}")

            await self._stop_devices(task_id, record)

            for sub_task in self.sub_tasks.get(task_id, []):
                if not sub_task.done():
                    try:
                        sub_task.cancel()
                    except Exception as e:
                        logger.warning(f"取消子任务 {task_id} 失败: {e}")

            self.tasks.pop(task_id, None)
            self.task_handles.pop(task_id, None)
            self.sub_tasks.pop(task_id, None)
            self.stop_events.pop(task_id, None)
            # Without this, heartbeats of removed tasks would leak and
            # update_heartbeat would keep refreshing ids no longer managed.
            self.heartbeats.pop(task_id, None)

            logger.info(f"Task {task_id} removed")
            return True

    @staticmethod
    async def _stop_devices(task_id: str, record: Dict[str, Any]) -> None:
        """Best-effort ``await device.stop()`` for the device(s) in the task's info."""
        # add_task wraps the caller's metadata under "task_info"; the device
        # reference (if any) lives there, not on the record itself.
        info = record.get("task_info")
        if not isinstance(info, dict) or info.get("device") is None:
            return
        devices = info["device"]
        if not isinstance(devices, (list, tuple, set)):
            devices = [devices]  # a single device object
        for device in devices:
            # Stop each device independently so one failure doesn't skip the rest.
            try:
                await device.stop()
            except Exception as e:
                logger.warning(f"停止MQTT设备 {task_id} 失败: {e}")

    async def mark_all_tasks_as_stopped(self):
        """Set the stored ``"status"`` of every registered task to ``"stopped"``."""
        async with self.lock:
            for record in self.tasks.values():
                record["status"] = "stopped"

    async def update_heartbeat(self, task_id: str) -> None:
        """Record the current time as the last heartbeat of a registered task."""
        if task_id in self.heartbeats:
            self.heartbeats[task_id] = datetime.now()
            logger.debug(f"Heartbeat updated for task {task_id}")

    async def get_task_info(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Return a heartbeat-based status snapshot of one task, or None if unknown.

        Keys: ``task_info`` (the stored task record), ``status`` (``"running"``
        while the last heartbeat is younger than ``task_timeout`` seconds, else
        ``"timeout"``), ``last_heartbeat`` (ISO string, or None if none is
        recorded) and ``elapsed_time`` (seconds since it, or None).
        """
        async with self.lock:
            if task_id not in self.tasks:
                return None

            last_heartbeat = self.heartbeats.get(task_id)
            elapsed_time = (
                (datetime.now() - last_heartbeat).total_seconds()
                if last_heartbeat is not None else None
            )
            is_running = elapsed_time is not None and elapsed_time < self.task_timeout
            return {
                "task_info": self.tasks[task_id],
                "status": "running" if is_running else "timeout",
                "last_heartbeat": last_heartbeat.isoformat() if last_heartbeat is not None else None,
                "elapsed_time": elapsed_time,
            }

    async def check_tasks_health(self) -> Dict[str, str]:
        """Return ``{task_id: "timeout"}`` for every task whose heartbeat is stale."""
        unhealthy_tasks: Dict[str, str] = {}
        current_time = datetime.now()

        async with self.lock:
            for task_id in self.tasks:
                last_heartbeat = self.heartbeats.get(task_id)
                elapsed_seconds = (
                    (current_time - last_heartbeat).total_seconds()
                    if last_heartbeat is not None else None
                )
                # Diagnostic goes through logging rather than print() so it does
                # not write to stdout for every task on every sweep.
                logger.debug(f"当前心跳已经进行{elapsed_seconds}s")
                if elapsed_seconds is not None and elapsed_seconds > self.task_timeout:
                    unhealthy_tasks[task_id] = "timeout"
                    logger.warning(f"----------------------------Task {task_id} has timed out-----------------------------")

        return unhealthy_tasks

    async def _cleanup_tasks(self):
        """Every ``cleanup_interval`` seconds, remove tasks reported as timed out."""
        while True:
            await asyncio.sleep(self.cleanup_interval)
            try:
                unhealthy_tasks = await self.check_tasks_health()
                for task_id in unhealthy_tasks:
                    if await self.remove_task(task_id):
                        logger.warning(f"Task {task_id} removed due to timeout")
            except asyncio.CancelledError:
                raise  # let cancellation stop the loop on every Python version
            except Exception:
                # Keep the sweeper alive: one failed sweep must not silently end
                # timeout cleanup for the lifetime of the manager.
                logger.exception("Task cleanup sweep failed")

    async def get_all_tasks(self) -> Dict[str, Dict[str, Any]]:
        """Return a heartbeat status summary for every registered task.

        Each value has ``status`` (``"running"``/``"timeout"``),
        ``last_heartbeat`` (ISO string; ``datetime.min`` when none is recorded)
        and ``elapsed_time`` (seconds since the heartbeat; 0 when none is recorded).
        """
        async with self.lock:
            now = datetime.now()  # one timestamp so all fields agree
            summary: Dict[str, Dict[str, Any]] = {}
            for task_id in self.tasks:
                last_heartbeat = self.heartbeats.get(task_id, datetime.min)
                elapsed = (now - last_heartbeat).total_seconds()
                summary[task_id] = {
                    "status": "running" if elapsed < self.task_timeout else "timeout",
                    "last_heartbeat": last_heartbeat.isoformat(),
                    "elapsed_time": elapsed if task_id in self.heartbeats else 0,
                }
            return summary
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Shared module-level TaskManager instance, created at import time.
task_manager: TaskManager = TaskManager()
|