ai_project_v1/middleware/TaskManager.py

228 lines
9.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from datetime import datetime
from typing import List, Dict, Any, Optional
import logging
import asyncio
# Module logging setup: INFO level, "time - logger name - level - message" lines.
# NOTE(review): basicConfig() at import time configures the root logger for the
# whole process; consider moving it to the application entry point.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Per the original note, heartbeat monitoring is not enabled yet, so every task
# gets a flat timeout of 100 minutes, expressed in seconds (6000).
TaskTimeout = 100 * 60
class TaskManager:
    """Registry of running asyncio tasks, their sub-tasks and heartbeats.

    Each registered task has a main task handle, an optional list of sub-task
    handles, an ``asyncio.Event`` that is set when the task is removed, and the
    time of its last heartbeat. A background loop (see :meth:`start`) removes
    tasks whose last heartbeat is older than ``task_timeout`` seconds.
    """

    def __init__(self):
        # task_id -> {"task_info", "status", "main_task_handle", "sub_task_handles"}
        self.tasks: Dict[str, Dict[str, Any]] = {}
        # task_id -> time of the last heartbeat (naive datetime.now())
        self.heartbeats: Dict[str, datetime] = {}
        # task_id -> main task handle
        self.task_handles: Dict[str, asyncio.Task] = {}
        # task_id -> sub-task handles (only present when sub-tasks were given)
        self.sub_tasks: Dict[str, List[asyncio.Task]] = {}
        # task_id -> event set by remove_task() as a cooperative stop signal
        self.stop_events: Dict[str, asyncio.Event] = {}
        self.lock = asyncio.Lock()
        # Seconds; not read by this class.
        self.heartbeat_interval = 30
        # Seconds without a heartbeat before a task counts as timed out.
        self.task_timeout = TaskTimeout
        # Seconds between two passes of the cleanup loop.
        self.cleanup_interval = 10
        # Handle of the background cleanup loop; None until start() is awaited.
        self.cleanup_task: Optional[asyncio.Task] = None

    async def start(self):
        """Start the background cleanup loop; no-op if it is already running."""
        if self.cleanup_task is None or self.cleanup_task.done():
            self.cleanup_task = asyncio.create_task(self._cleanup_tasks())

    async def stop(self):
        """Cancel the background cleanup loop started by :meth:`start`, if any."""
        cleanup = self.cleanup_task
        self.cleanup_task = None
        if cleanup is None or cleanup.done():
            return
        cleanup.cancel()
        # asyncio.wait() does not re-raise the child's CancelledError, so a
        # cancellation of the caller itself is not swallowed here.
        await asyncio.wait([cleanup])

    async def add_task(self, task_id: str, task_info: Dict[str, Any], main_task_handle: asyncio.Task,
                       sub_task_handles: Optional[List[asyncio.Task]] = None):
        """Register a new task together with all of its sub-tasks.

        Args:
            task_id: Unique identifier of the task.
            task_info: Caller-supplied metadata, stored under the ``"task_info"``
                key. A ``"device"`` entry (one device or a list/tuple/set of
                devices) is stopped by :meth:`remove_task`.
            main_task_handle: Handle of the task's main coroutine.
            sub_task_handles: Optional helper tasks to cancel with the main task.
                The list object itself is stored, not a copy.

        Raises:
            ValueError: If ``task_id`` is already registered.
        """
        async with self.lock:
            if task_id in self.tasks:
                raise ValueError("Task already exists")
            self.tasks[task_id] = {
                "task_info": task_info,
                "status": "running",
                "main_task_handle": main_task_handle,
                "sub_task_handles": sub_task_handles if sub_task_handles else [],
            }
            self.heartbeats[task_id] = datetime.now()
            self.task_handles[task_id] = main_task_handle
            # One stop event per task; remove_task() sets it.
            self.stop_events[task_id] = asyncio.Event()
            if sub_task_handles:
                self.sub_tasks[task_id] = sub_task_handles
            logger.info(f"Task {task_id} started")

    async def remove_task(self, task_id: str) -> bool:
        """Stop and unregister a task, its sub-tasks and its MQTT devices.

        Sets the task's stop event, requests cancellation of the main task and
        all sub-tasks (without awaiting them), awaits ``stop()`` on every device
        returned by :meth:`_devices_of`, and drops all bookkeeping for
        ``task_id`` including its heartbeat. If the caller is itself one of the
        tasks being cancelled, its own cancellation is requested last, after the
        lock is released, so the cleanup is not interrupted.

        Returns:
            True if the task was registered and has been removed, else False.
        """
        cancel_current = None
        async with self.lock:
            record = self.tasks.get(task_id)
            if record is None:
                return False

            stop_event = self.stop_events.get(task_id)
            if stop_event is not None:
                stop_event.set()

            # Task.cancel() returns a bool and is not awaitable; it must not be
            # wrapped in asyncio.shield() (that raised TypeError on every call).
            current = asyncio.current_task()
            handles = [self.task_handles.get(task_id), *self.sub_tasks.get(task_id, [])]
            for handle in handles:
                if handle is None or handle.done():
                    continue
                if handle is current:
                    # Cancelling ourselves now would abort the cleanup below at
                    # its first await; defer until the bookkeeping is done.
                    cancel_current = handle
                    continue
                handle.cancel()

            # Stop MQTT devices; one failing device does not skip the others.
            for device in self._devices_of(record):
                try:
                    await device.stop()
                except Exception as e:
                    logger.warning(f"停止MQTT设备 {task_id} 失败: {e}")

            self.tasks.pop(task_id, None)
            self.task_handles.pop(task_id, None)
            self.sub_tasks.pop(task_id, None)
            self.stop_events.pop(task_id, None)
            self.heartbeats.pop(task_id, None)
            logger.info(f"Task {task_id} removed")

        if cancel_current is not None:
            cancel_current.cancel()
        return True

    @staticmethod
    def _devices_of(record: Dict[str, Any]) -> List[Any]:
        """Return the device objects registered for a task record.

        Looks for a ``"device"`` entry on the record itself first, then inside
        the caller's ``task_info`` dict (where add_task() stores it). A single
        device is wrapped in a list; a missing or None entry yields ``[]``.
        """
        devices = record.get("device")
        if devices is None:
            payload = record.get("task_info")
            if isinstance(payload, dict):
                devices = payload.get("device")
        if devices is None:
            return []
        if isinstance(devices, (list, tuple, set)):
            return list(devices)
        return [devices]

    async def mark_all_tasks_as_stopped(self):
        """Set the stored ``status`` of every registered task to ``"stopped"``."""
        async with self.lock:
            for record in self.tasks.values():
                record["status"] = "stopped"

    async def update_heartbeat(self, task_id: str) -> None:
        """Record a heartbeat for ``task_id``; unknown ids are ignored."""
        if task_id in self.heartbeats:
            self.heartbeats[task_id] = datetime.now()
            logger.debug(f"Heartbeat updated for task {task_id}")

    async def get_task_info(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Return the stored record plus heartbeat-derived status for a task.

        Returns:
            None for unknown ids. Otherwise a dict with ``task_info`` (the
            internal record), ``status`` ("running" while the heartbeat is
            younger than ``task_timeout`` seconds, else "timeout"),
            ``last_heartbeat`` (ISO string or None) and ``elapsed_time``
            (seconds or None).
        """
        async with self.lock:
            if task_id not in self.tasks:
                return None
            last_heartbeat = self.heartbeats.get(task_id)
            elapsed_time = (
                (datetime.now() - last_heartbeat).total_seconds()
                if last_heartbeat is not None else None
            )
            is_running = elapsed_time is not None and elapsed_time < self.task_timeout
            return {
                "task_info": self.tasks[task_id],
                "status": "running" if is_running else "timeout",
                # Guard: a missing heartbeat used to raise AttributeError here.
                "last_heartbeat": last_heartbeat.isoformat() if last_heartbeat is not None else None,
                "elapsed_time": elapsed_time,
            }

    async def check_tasks_health(self) -> Dict[str, str]:
        """Return ``{task_id: "timeout"}`` for tasks whose heartbeat is too old.

        A task with no recorded heartbeat is not reported.
        """
        unhealthy_tasks: Dict[str, str] = {}
        current_time = datetime.now()
        async with self.lock:
            for task_id in self.tasks:
                last_heartbeat = self.heartbeats.get(task_id)
                elapsed_seconds = (
                    (current_time - last_heartbeat).total_seconds()
                    if last_heartbeat is not None else None
                )
                logger.debug(f"Task {task_id}: 当前心跳已经进行{elapsed_seconds}s")
                if elapsed_seconds is not None and elapsed_seconds > self.task_timeout:
                    unhealthy_tasks[task_id] = "timeout"
                    logger.warning(f"----------------------------Task {task_id} has timed out-----------------------------")
        return unhealthy_tasks

    async def _cleanup_tasks(self):
        """Every ``cleanup_interval`` seconds, remove timed-out tasks.

        Runs until cancelled. Errors are logged and the loop keeps running, so
        one failing pass or removal does not disable future cleanups.
        """
        while True:
            await asyncio.sleep(self.cleanup_interval)
            try:
                unhealthy_tasks = await self.check_tasks_health()
            except Exception:
                logger.exception("Task health check failed")
                continue
            for task_id in unhealthy_tasks:
                try:
                    removed = await self.remove_task(task_id)
                except Exception:
                    logger.exception(f"Failed to remove timed-out task {task_id}")
                    continue
                if removed:
                    logger.warning(f"Task {task_id} removed due to timeout")

    async def get_all_tasks(self) -> Dict[str, Dict[str, Any]]:
        """Return a heartbeat status summary for every registered task.

        Each entry holds ``status`` ("running"/"timeout"), ``last_heartbeat``
        (ISO string; ``datetime.min`` when no heartbeat is recorded) and
        ``elapsed_time`` in seconds (0 when no heartbeat is recorded).
        """
        async with self.lock:
            now = datetime.now()
            summary: Dict[str, Dict[str, Any]] = {}
            for task_id in self.tasks:
                heartbeat = self.heartbeats.get(task_id)
                reference = heartbeat if heartbeat is not None else datetime.min
                elapsed = (now - reference).total_seconds()
                summary[task_id] = {
                    "status": "running" if elapsed < self.task_timeout else "timeout",
                    "last_heartbeat": reference.isoformat(),
                    "elapsed_time": elapsed if heartbeat is not None else 0,
                }
            return summary
# Module-level TaskManager instance; its cleanup loop does not run until
# start() is awaited.
task_manager: TaskManager = TaskManager()