"""Asyncio task registry with heartbeat tracking and periodic timeout cleanup."""
import asyncio
import logging
from datetime import datetime
from typing import List, Dict, Any, Optional
# Configure logging once, at import time.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Heartbeats are not wired up yet, so every task simply gets a flat
# 100-minute budget (6000 seconds) before it is considered timed out.
TaskTimeout = 6000
class TaskManager:
    """Registry of asyncio tasks with heartbeat tracking and timeout cleanup.

    Each registered task id owns a record (``self.tasks``), a last-heartbeat
    timestamp, a main task handle, optional sub-task handles and a stop event.
    A background loop started by :meth:`start` periodically removes every task
    whose last heartbeat is older than ``task_timeout`` seconds.
    """

    def __init__(self):
        self.tasks: Dict[str, Dict[str, Any]] = {}  # task_id -> task record
        self.heartbeats: Dict[str, datetime] = {}  # task_id -> last heartbeat
        self.task_handles: Dict[str, asyncio.Task] = {}  # task_id -> main task handle
        self.sub_tasks: Dict[str, List[asyncio.Task]] = {}  # task_id -> sub-task handles
        self.stop_events: Dict[str, asyncio.Event] = {}  # task_id -> stop event
        self.lock = asyncio.Lock()
        self.heartbeat_interval = 30  # not read anywhere in this class
        self.task_timeout = TaskTimeout  # seconds; compared against elapsed seconds
        self.cleanup_interval = 10  # seconds slept between cleanup passes
        # Handle of the background cleanup loop; populated by start().
        self.cleanup_task: Optional[asyncio.Task] = None

    async def start(self):
        """Start the background cleanup loop.

        Calling this again while the loop is still running is a no-op, so a
        second call cannot spawn a duplicate cleanup task.
        """
        if self.cleanup_task is not None and not self.cleanup_task.done():
            return
        self.cleanup_task = asyncio.create_task(self._cleanup_tasks())

    async def add_task(self, task_id: str, task_info: Dict[str, Any], main_task_handle: asyncio.Task,
                       sub_task_handles: Optional[List[asyncio.Task]] = None):
        """Register a new task together with its optional sub-tasks.

        Args:
            task_id: Unique key of the task.
            task_info: Caller-supplied metadata stored under ``"task_info"``.
            main_task_handle: Handle of the task's main asyncio task.
            sub_task_handles: Handles of auxiliary asyncio tasks, if any.

        Raises:
            ValueError: If ``task_id`` is already registered.
        """
        async with self.lock:
            if task_id in self.tasks:
                raise ValueError("Task already exists")

            self.tasks[task_id] = {
                "task_info": task_info,
                "status": "running",
                "main_task_handle": main_task_handle,
                "sub_task_handles": sub_task_handles if sub_task_handles else [],
            }
            self.heartbeats[task_id] = datetime.now()
            self.task_handles[task_id] = main_task_handle
            # Every task gets its own stop event, set when the task is removed.
            self.stop_events[task_id] = asyncio.Event()
            if sub_task_handles:
                self.sub_tasks[task_id] = sub_task_handles
            logger.info(f"Task {task_id} started")

    @staticmethod
    def _request_cancel(handle: Optional[asyncio.Task], current: Optional[asyncio.Task]) -> bool:
        """Cancel ``handle`` unless it is finished or is the calling task.

        Returns True when ``handle`` is the calling task, so the caller can
        postpone that cancellation until its own work is complete.
        """
        if handle is None or handle.done():
            return False
        if handle is current:
            return True
        handle.cancel()
        return False

    async def _stop_devices(self, task_id: str, record: Dict[str, Any]) -> None:
        """Stop the device(s) referenced by a task record, best effort.

        The ``"device"`` entry is looked up on the record itself and, failing
        that, in the caller-supplied ``task_info`` dict stored inside it. The
        value may be one object exposing ``stop()`` or an iterable of them.
        A failure on one device is logged and does not prevent the others
        from being stopped.
        """
        devices = record.get("device")
        if devices is None:
            info = record.get("task_info")
            if isinstance(info, dict):
                devices = info.get("device")
        if devices is None:
            return

        if hasattr(devices, "stop"):
            devices = [devices]
        try:
            device_list = list(devices)
        except TypeError as e:
            logger.warning(f"停止MQTT设备 {task_id} 失败: {e}")
            return

        for device in device_list:
            try:
                await device.stop()
            except Exception as e:
                logger.warning(f"停止MQTT设备 {task_id} 失败: {e}")

    async def remove_task(self, task_id: str):
        """Remove a task: signal stop, cancel its tasks, stop its devices.

        Cancellation is requested via ``Task.cancel()`` but not awaited.
        ``Task.cancel()`` returns a bool, so it must not be wrapped in
        ``asyncio.shield`` or awaited.

        Returns:
            True if the task was registered and has been removed, else False.
        """
        async with self.lock:
            record = self.tasks.get(task_id)
            if record is None:
                return False

            # Tell cooperative code that watches the event to wind down.
            stop_event = self.stop_events.get(task_id)
            if stop_event is not None:
                stop_event.set()

            # If the caller is one of the tasks being removed, cancelling it
            # now would abort this coroutine at the next await; defer it.
            current = asyncio.current_task()
            cancel_self = self._request_cancel(self.task_handles.get(task_id), current)

            await self._stop_devices(task_id, record)

            for sub_task in list(self.sub_tasks.get(task_id, ())):
                cancel_self = self._request_cancel(sub_task, current) or cancel_self

            # Drop every per-task entry, including the heartbeat timestamp.
            self.tasks.pop(task_id, None)
            self.task_handles.pop(task_id, None)
            self.sub_tasks.pop(task_id, None)
            self.stop_events.pop(task_id, None)
            self.heartbeats.pop(task_id, None)

            logger.info(f"Task {task_id} removed")

            if cancel_self and current is not None:
                current.cancel()
            return True

    async def mark_all_tasks_as_stopped(self):
        """Set the stored status of every registered task to ``"stopped"``."""
        async with self.lock:
            for record in self.tasks.values():
                record["status"] = "stopped"

    async def update_heartbeat(self, task_id: str) -> None:
        """Refresh the heartbeat timestamp of a registered task.

        Task ids without a heartbeat entry are ignored.
        """
        if task_id in self.heartbeats:
            self.heartbeats[task_id] = datetime.now()
            logger.debug(f"Heartbeat updated for task {task_id}")

    async def get_task_info(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Return a status snapshot for one task, or None if it is unknown.

        The snapshot holds the stored record under ``"task_info"``, a derived
        ``"status"`` (``"running"`` while the last heartbeat is younger than
        ``task_timeout``, otherwise ``"timeout"``), the ISO-formatted
        ``"last_heartbeat"`` and ``"elapsed_time"`` in seconds. The last two
        are None when no heartbeat is recorded.
        """
        async with self.lock:
            if task_id not in self.tasks:
                return None

            last_heartbeat = self.heartbeats.get(task_id)
            elapsed_time = (datetime.now() - last_heartbeat).total_seconds() if last_heartbeat else None
            alive = elapsed_time is not None and elapsed_time < self.task_timeout

            return {
                "task_info": self.tasks[task_id],
                "status": "running" if alive else "timeout",
                "last_heartbeat": last_heartbeat.isoformat() if last_heartbeat else None,
                "elapsed_time": elapsed_time,
            }

    async def check_tasks_health(self) -> Dict[str, str]:
        """Return ``{task_id: "timeout"}`` for tasks past ``task_timeout``.

        Tasks without a recorded heartbeat are not reported.
        """
        unhealthy_tasks = {}
        current_time = datetime.now()

        async with self.lock:
            for task_id in self.tasks:
                last_heartbeat = self.heartbeats.get(task_id)
                elapsed_seconds = (current_time - last_heartbeat).total_seconds() if last_heartbeat else None
                logger.debug(f"当前心跳已经进行{elapsed_seconds}s")
                if elapsed_seconds is not None and elapsed_seconds > self.task_timeout:
                    unhealthy_tasks[task_id] = "timeout"
                    logger.warning(f"----------------------------Task {task_id} has timed out-----------------------------")

        return unhealthy_tasks

    async def _cleanup_tasks(self):
        """Background loop that removes timed-out tasks.

        An unexpected error in one pass is logged and the loop continues;
        cancellation still terminates the loop.
        """
        while True:
            await asyncio.sleep(self.cleanup_interval)
            try:
                unhealthy_tasks = await self.check_tasks_health()
                for task_id in unhealthy_tasks:
                    # Only report removals that actually happened.
                    if await self.remove_task(task_id):
                        logger.warning(f"Task {task_id} removed due to timeout")
            except asyncio.CancelledError:
                raise
            except Exception:
                logger.exception("Task cleanup pass failed; will retry after the next interval")

    async def get_all_tasks(self) -> Dict[str, Dict[str, Any]]:
        """Return a status snapshot for every registered task.

        Each entry holds ``"status"``, ``"last_heartbeat"`` (ISO string) and
        ``"elapsed_time"`` in seconds. A task without a recorded heartbeat is
        reported as ``"timeout"`` with ``datetime.min`` as its heartbeat and
        an elapsed time of 0.
        """
        async with self.lock:
            now = datetime.now()
            snapshot: Dict[str, Dict[str, Any]] = {}
            for task_id in self.tasks:
                last_heartbeat = self.heartbeats.get(task_id, datetime.min)
                elapsed = (now - last_heartbeat).total_seconds()
                snapshot[task_id] = {
                    "status": "running" if elapsed < self.task_timeout else "timeout",
                    "last_heartbeat": last_heartbeat.isoformat(),
                    "elapsed_time": elapsed if task_id in self.heartbeats else 0,
                }
            return snapshot
# Module-level shared TaskManager instance.
task_manager = TaskManager()