ai_project_v1/middleware/TaskManager.py

228 lines
9.4 KiB
Python
Raw Normal View History

from datetime import datetime
from typing import List, Dict, Any, Optional
import logging
import asyncio
# Logging setup: INFO level; each record shows timestamp, logger name,
# level and message.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

# Heartbeats are not enabled yet, so a flat 100-minute timeout is used
# (6000 seconds).
TaskTimeout = 6000
class TaskManager:
    """Registry of running asyncio tasks with heartbeat-based timeout cleanup.

    For every task id the manager keeps a record (caller-supplied info,
    status and task handles), the time of the last heartbeat, and an
    ``asyncio.Event`` that is set when the task is removed. After
    :meth:`start` is called, a background loop periodically removes tasks
    whose last heartbeat is older than ``task_timeout`` seconds.
    """

    def __init__(self):
        """Create an empty manager; nothing runs until :meth:`start` is awaited."""
        # NOTE(review): on Python < 3.10 ``asyncio.Lock()`` binds to the event
        # loop current at construction time - confirm the instance is created
        # on (or before) the loop that will use it.
        self.tasks: Dict[str, Dict[str, Any]] = {}  # task_id -> task record
        self.heartbeats: Dict[str, datetime] = {}  # task_id -> last heartbeat
        self.task_handles: Dict[str, asyncio.Task] = {}  # task_id -> main task handle
        self.sub_tasks: Dict[str, List[asyncio.Task]] = {}  # task_id -> sub-task handles
        self.stop_events: Dict[str, asyncio.Event] = {}  # task_id -> stop event
        self.lock = asyncio.Lock()
        self.heartbeat_interval = 30  # not read inside this class; presumably seconds
        self.task_timeout = TaskTimeout  # seconds without a heartbeat before timeout
        self.cleanup_interval = 10  # seconds between cleanup passes
        self.cleanup_task: Optional[asyncio.Task] = None  # handle of the cleanup loop

    async def start(self):
        """Start the background cleanup loop.

        Calling this again while the loop is still running is a no-op, so
        repeated calls do not spawn duplicate cleanup loops.
        """
        if self.cleanup_task is not None and not self.cleanup_task.done():
            return
        self.cleanup_task = asyncio.create_task(self._cleanup_tasks())

    async def add_task(self, task_id: str, task_info: Dict[str, Any], main_task_handle: asyncio.Task,
                       sub_task_handles: Optional[List[asyncio.Task]] = None):
        """Register a new task together with its sub-tasks.

        Args:
            task_id: Unique identifier of the task.
            task_info: Caller-supplied information stored under ``"task_info"``.
            main_task_handle: Handle of the task's main asyncio task.
            sub_task_handles: Optional handles of the task's sub-tasks.

        Raises:
            ValueError: If ``task_id`` is already registered.
        """
        async with self.lock:
            if task_id in self.tasks:
                raise ValueError("Task already exists")
            self.tasks[task_id] = {
                "task_info": task_info,
                "status": "running",
                "main_task_handle": main_task_handle,
                "sub_task_handles": sub_task_handles if sub_task_handles else [],
            }
            self.heartbeats[task_id] = datetime.now()
            self.task_handles[task_id] = main_task_handle
            # Each task gets its own stop event; it is set on removal.
            self.stop_events[task_id] = asyncio.Event()
            if sub_task_handles:
                self.sub_tasks[task_id] = sub_task_handles
            logger.info(f"Task {task_id} started")

    async def remove_task(self, task_id: str):
        """Remove a task: signal stop, cancel its tasks and stop its devices.

        Cancellation of the main task and sub-tasks is *requested* via
        ``Task.cancel()`` but not awaited; the tasks finish cancelling on
        their next scheduling step.

        Returns:
            True if the task was registered and has been removed, False if
            ``task_id`` was unknown.
        """
        async with self.lock:
            record = self.tasks.pop(task_id, None)
            if record is None:
                return False
            # Drop all bookkeeping up front so the registry stays consistent
            # even if this coroutine is cancelled while stopping devices.
            self.heartbeats.pop(task_id, None)
            main_task = self.task_handles.pop(task_id, None)
            sub_tasks = self.sub_tasks.pop(task_id, [])
            stop_event = self.stop_events.pop(task_id, None)

            if stop_event is not None:
                stop_event.set()  # fire the stop signal for cooperative shutdown

            self._cancel_handle(main_task, task_id, "主任务")
            try:
                await self._stop_devices(task_id, record)
            finally:
                # Always request sub-task cancellation, even when device
                # shutdown was interrupted.
                for sub_task in sub_tasks:
                    self._cancel_handle(sub_task, task_id, "子任务")
            logger.info(f"Task {task_id} removed")
            return True

    @staticmethod
    def _cancel_handle(handle, task_id: str, label: str) -> None:
        """Request cancellation of one task handle if it is still running."""
        if handle is None or handle.done():
            return
        try:
            # ``cancel()`` returns a bool, not an awaitable - it must not be
            # awaited or passed to ``asyncio.shield``.
            handle.cancel()
        except Exception as e:
            logger.warning(f"取消{label} {task_id} 失败: {e}")

    async def _stop_devices(self, task_id: str, record: Dict[str, Any]) -> None:
        """Best-effort stop of the devices attached to a task record.

        Devices are looked up under the ``"device"`` key of the record itself
        and, failing that, of the caller-supplied ``task_info`` dict (which is
        where ``add_task`` places caller data). The value may be a single
        object with an async ``stop()`` or an iterable of such objects -
        presumably MQTT devices; verify against the callers.
        """
        devices = record.get("device")
        info = record.get("task_info")
        if devices is None and isinstance(info, dict):
            devices = info.get("device")
        if devices is None:
            return
        if hasattr(devices, "stop"):
            devices = [devices]
        try:
            device_list = list(devices)
        except TypeError as e:
            logger.warning(f"停止MQTT设备 {task_id} 失败: {e}")
            return
        for device in device_list:
            # One failing device must not prevent the others from stopping.
            try:
                await device.stop()
            except Exception as e:
                logger.warning(f"停止MQTT设备 {task_id} 失败: {e}")

    async def mark_all_tasks_as_stopped(self):
        """Set the stored status of every registered task to ``"stopped"``."""
        async with self.lock:
            for record in self.tasks.values():
                record["status"] = "stopped"

    async def update_heartbeat(self, task_id: str) -> None:
        """Refresh the heartbeat time of a task; unknown ids are ignored."""
        if task_id in self.heartbeats:
            self.heartbeats[task_id] = datetime.now()
            logger.debug(f"Heartbeat updated for task {task_id}")

    async def get_task_info(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Return the record and heartbeat status of one task.

        Returns:
            None if ``task_id`` is unknown, otherwise a dict with keys
            ``task_info`` (the stored record), ``status`` (``"running"`` or
            ``"timeout"``), ``last_heartbeat`` (ISO string, or None when no
            heartbeat is recorded) and ``elapsed_time`` (seconds, or None).
        """
        async with self.lock:
            if task_id not in self.tasks:
                return None
            last_heartbeat = self.heartbeats.get(task_id)
            if last_heartbeat is None:
                # A task without a heartbeat is reported as timed out.
                elapsed_time = None
                heartbeat_text = None
            else:
                elapsed_time = (datetime.now() - last_heartbeat).total_seconds()
                heartbeat_text = last_heartbeat.isoformat()
            alive = elapsed_time is not None and elapsed_time < self.task_timeout
            return {
                "task_info": self.tasks[task_id],
                "status": "running" if alive else "timeout",
                "last_heartbeat": heartbeat_text,
                "elapsed_time": elapsed_time,
            }

    async def check_tasks_health(self) -> Dict[str, str]:
        """Return ``{task_id: "timeout"}`` for every task whose heartbeat expired.

        Tasks without a recorded heartbeat are not reported.
        """
        unhealthy_tasks: Dict[str, str] = {}
        current_time = datetime.now()
        async with self.lock:
            for task_id in self.tasks:
                last_heartbeat = self.heartbeats.get(task_id)
                if last_heartbeat is None:
                    continue
                elapsed_seconds = (current_time - last_heartbeat).total_seconds()
                logger.debug(f"当前心跳已经进行{elapsed_seconds}s")
                if elapsed_seconds > self.task_timeout:
                    unhealthy_tasks[task_id] = "timeout"
                    logger.warning(f"----------------------------Task {task_id} has timed out-----------------------------")
        return unhealthy_tasks

    async def _cleanup_tasks(self):
        """Periodically remove timed-out tasks; runs until cancelled."""
        while True:
            await asyncio.sleep(self.cleanup_interval)
            try:
                unhealthy_tasks = await self.check_tasks_health()
                for task_id in unhealthy_tasks:
                    # Only report removals that actually happened.
                    if await self.remove_task(task_id):
                        logger.warning(f"Task {task_id} removed due to timeout")
            except asyncio.CancelledError:
                # Cancellation must end the loop (CancelledError is an
                # ``Exception`` subclass on Python 3.7).
                raise
            except Exception:
                # A failed pass must not kill the loop permanently.
                logger.exception("Task cleanup pass failed")

    async def get_all_tasks(self) -> Dict[str, Dict[str, Any]]:
        """Return heartbeat status for every registered task.

        Each value has ``status`` (``"running"``/``"timeout"``),
        ``last_heartbeat`` (ISO string) and ``elapsed_time`` (seconds). A task
        with no recorded heartbeat is reported with ``datetime.min`` as its
        heartbeat and an elapsed time of 0.
        """
        async with self.lock:
            now = datetime.now()  # one timestamp for a consistent snapshot
            snapshot: Dict[str, Dict[str, Any]] = {}
            for task_id in self.tasks:
                known = task_id in self.heartbeats
                last_heartbeat = self.heartbeats.get(task_id, datetime.min)
                elapsed = (now - last_heartbeat).total_seconds()
                snapshot[task_id] = {
                    "status": "running" if elapsed < self.task_timeout else "timeout",
                    "last_heartbeat": last_heartbeat.isoformat(),
                    "elapsed_time": elapsed if known else 0,
                }
            return snapshot
# Module-level TaskManager instance, created when this module is imported.
task_manager = TaskManager()