# ai_project_v1/middleware/MQTTService.py
#
# NOTE: this header replaces web-viewer residue from the original paste
# (the viewer reported the file as 554 lines, 22 KiB, Python). The viewer
# also warned that the file contains ambiguous Unicode characters that might
# be confused with other characters; review them if that is unintentional.

import asyncio
import inspect
import json
import logging
import os
import sys
from collections import deque
from typing import Any, Awaitable, Callable, Dict, List, Optional, Type, Union

from aiomqtt import Client, MqttError

from middleware.entity.up_drc_camera_osd_info_push import CameraOsdInfo, parse_camera_osd_info
from middleware.entity.up_osd_info_push import OSDMessage, parse_osd_message
# Logging setup. Note: basicConfig() configures the *root* logger as an
# import-time side effect of this module (it is a no-op if the root logger
# already has handlers).
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# Pick an asyncio event-loop policy that suits the host operating system.
def setup_event_loop():
    """Install the selector event-loop policy on Windows; leave the default elsewhere.

    Prints which choice was made. The selector loop is presumably required
    because aiomqtt relies on ``loop.add_reader()``, which the Windows default
    (Proactor) loop does not provide — see the aiomqtt README's Windows note;
    confirm for the aiomqtt version in use.
    """
    running_on_windows = sys.platform.lower() == "win32" or os.name.lower() == "nt"
    if not running_on_windows:
        print("使用系统默认事件循环")
        return
    from asyncio import WindowsSelectorEventLoopPolicy, set_event_loop_policy
    set_event_loop_policy(WindowsSelectorEventLoopPolicy())
    print("已设置Windows兼容事件循环")


# Runs at import time, i.e. before asyncio.run() later creates its loop.
setup_event_loop()
# ------------------------------
# MQTT服务类(基础通信层) / MQTT service class (base communication layer)
# ------------------------------
class MQTTService:
    """Asyncio wrapper around a single aiomqtt ``Client`` connection.

    Holds one broker connection, a background task that dispatches every
    incoming message to the first subscription whose topic filter matches
    it, and a cache of the last decoded payload per subscribed topic filter.
    """

    def __init__(self, host, port=1883, username=None, password=None):
        self.host = host
        self.port = port
        self.username = username
        self.password = password
        self.client = None
        # Last decoded payload per subscribed topic filter (see subscribe()).
        self.message_cache = {}
        # topic filter -> {"callback": coroutine fn taking the raw message, "qos": int}
        self.subscriptions = {}
        self.is_connected = False
        self._message_task = None
        # Serializes connect() so concurrent callers do not open two clients.
        self._connection_lock = asyncio.Lock()
        self.os_type = sys.platform.lower()

    @staticmethod
    def _preview(payload, limit=200):
        """Return *payload* truncated to *limit* items if it is str/bytes, else unchanged.

        Only used for log output. Non-sliceable payloads (None, int, float) are
        returned as-is instead of raising TypeError after a successful publish.
        """
        if isinstance(payload, (str, bytes, bytearray)):
            return payload[:limit]
        return payload

    async def _close_client(self, client):
        """Best-effort ``__aexit__`` of *client*; failures are logged, not raised."""
        try:
            await client.__aexit__(None, None, None)
        except Exception as e:
            logger.warning(f"[{self.os_type}] 关闭旧连接时出错: {e}")

    async def connect(self):
        """Open the broker connection; no-op if already connected.

        After connecting, re-issues SUBSCRIBE for every topic already present in
        ``self.subscriptions``. A fresh session (e.g. after reconnect()) would
        otherwise stop receiving them. Then starts the background listener.

        Raises:
            MqttError: the connection (or restoring a subscription) failed.
            Exception: any other connection error is logged and re-raised.
        """
        async with self._connection_lock:
            if self.is_connected:
                return
            client = Client(
                hostname=self.host,
                port=self.port,
                username=self.username,
                password=self.password,
            )
            try:
                await client.__aenter__()
            except MqttError as e:
                logger.error(f"[{self.os_type}] 连接MQTT服务器失败: {e}")
                self.is_connected = False
                self.client = None
                raise
            except Exception as e:
                logger.error(f"[{self.os_type}] 连接过程发生未知错误: {e}")
                self.is_connected = False
                self.client = None
                raise
            self.client = client
            self.is_connected = True
            logger.info(f"[{self.os_type}] 成功连接到MQTT服务器: {self.host}:{self.port}")
            try:
                for topic, sub_info in list(self.subscriptions.items()):
                    await client.subscribe(topic, qos=sub_info["qos"])
            except MqttError as e:
                logger.error(f"[{self.os_type}] 恢复订阅失败: {e}")
                self.is_connected = False
                self.client = None
                await self._close_client(client)
                raise
            self._start_message_listener()

    def _start_message_listener(self):
        """Start the listener task unless one is already running."""
        if not self._message_task or self._message_task.done():
            self._message_task = asyncio.create_task(self._message_listener())

    async def _message_listener(self):
        """Dispatch each incoming message to the first matching subscription.

        A failing callback is logged and does not stop the listener. On a
        connection error the listener triggers reconnect(), unless the client
        it was reading from has already been replaced or disconnected.
        """
        client = self.client
        try:
            async for message in client.messages:
                # Snapshot: a callback may subscribe() while we iterate.
                for topic_pattern, sub_info in list(self.subscriptions.items()):
                    if message.topic.matches(topic_pattern):
                        try:
                            await sub_info["callback"](message)
                        except Exception as e:
                            logger.error(f"[{self.os_type}] 消息回调执行出错: {e}", exc_info=True)
                        break
        except MqttError as e:
            logger.error(f"[{self.os_type}] 消息监听出错: {e}")
            if self.client is client:
                try:
                    await self.reconnect()
                except Exception as reconnect_error:
                    logger.error(f"[{self.os_type}] 重新连接失败: {reconnect_error}")
        except Exception as e:
            logger.error(f"[{self.os_type}] 消息监听发生未知错误: {e}", exc_info=True)

    async def reconnect(self):
        """Close the current connection and open a new one.

        Safe to call from inside the listener task. That task is not cancelled
        (it would otherwise cancel itself half-way through reconnecting); it
        simply finishes once connect() has started a new listener.
        ``self.subscriptions`` is kept as-is and re-issued by connect(), so the
        registered callbacks are not re-wrapped.
        """
        logger.warning(f"[{self.os_type}] 尝试重新连接MQTT服务器...")
        self.is_connected = False
        task = self._message_task
        self._message_task = None
        if task is not None and task is not asyncio.current_task():
            task.cancel()
        old_client = self.client
        self.client = None
        if old_client is not None:
            await self._close_client(old_client)
        await self.connect()

    async def publish(self, topic, payload, qos=0, retain=False):
        """Publish *payload* to *topic*, connecting first if needed.

        On MqttError, reconnects once and retries; a second failure propagates.
        """
        if not self.is_connected:
            await self.connect()
        try:
            await self.client.publish(topic, payload, qos=qos, retain=retain)
        except MqttError as e:
            logger.error(f"[{self.os_type}] 发布消息失败: {e}")
            await self.reconnect()
            await self.client.publish(topic, payload, qos=qos, retain=retain)
        # Long payloads are truncated in the log line.
        logger.info(f"[{self.os_type}] 已发布到 {topic}: {self._preview(payload)}...")

    async def subscribe(self, topic, callback=None, qos=0):
        """Subscribe to *topic* (a topic filter), connecting first if needed.

        Each matching message's payload is UTF-8 decoded (None when empty) and
        stored in ``message_cache[topic]``. Then ``await callback(topic, payload)``
        is called if a callback was given. Subscribing the same topic again
        replaces its callback.
        """
        if not self.is_connected:
            await self.connect()

        async def internal_callback(message):
            payload = message.payload.decode() if message.payload else None
            self.message_cache[topic] = payload
            if callback:
                await callback(topic, payload)

        self.subscriptions[topic] = {
            "callback": internal_callback,
            "qos": qos,
        }
        await self.client.subscribe(topic, qos=qos)
        logger.info(f"[{self.os_type}] 已订阅主题: {topic}")

    def get_cached_message(self, topic, default=None):
        """Return the last decoded payload received on *topic*, or *default*."""
        return self.message_cache.get(topic, default)

    async def disconnect(self):
        """Stop the listener, close the connection and forget all subscriptions.

        The listener is cancelled (and awaited) *before* the client is closed,
        so it cannot treat the intentional disconnect as an error and reconnect.
        """
        task = self._message_task
        self._message_task = None
        if task is not None and task is not asyncio.current_task():
            task.cancel()
            await asyncio.gather(task, return_exceptions=True)
        if self.is_connected and self.client:
            try:
                await self.client.__aexit__(None, None, None)
                logger.info(f"[{self.os_type}] 已断开MQTT连接")
            except MqttError as e:
                logger.error(f"[{self.os_type}] 断开连接时出错: {e}")
            except Exception as e:
                logger.error(f"[{self.os_type}] 断开连接发生未知错误: {e}")
        self.is_connected = False
        self.subscriptions.clear()
        self.client = None
async def empty_osd_callback(osd_info: OSDMessage):
    """No-op OSD update callback; a placeholder where no handling (or only logging) is needed."""
    return None
# ------------------------------
# 支持按method解析的MQTT设备类
# ------------------------------
class MQTTDevice:
    """MQTT subscriber that parses JSON messages according to their ``method`` field.

    Messages on the configured topics are JSON-decoded and parsed into the
    structure mapped to their ``method`` in ``method_struct_map``. Each parsed
    message is appended to a bounded per-(topic, method) queue and handed to
    every callback registered for that (topic, method).

    Devices pointing at the same ``ip:port`` reuse an already-connected
    MQTTService from ``_connection_cache``.
    NOTE(review): stop() disconnects that service even if another device
    still shares it; safe sharing would need reference counting. The cache
    key also ignores username/password.
    """

    _connection_cache: Dict[str, MQTTService] = {}

    def __init__(self,
                 ip: str,
                 port: int = 1883,
                 topics: Union[str, List[str], None] = None,
                 msg_types: Union[str, Dict[str, str]] = "json",  # only JSON is supported
                 username: Optional[str] = None,
                 password: Optional[str] = None,
                 queue_size: int = 100):
        self.ip = ip
        self.port = port
        self.username = username
        self.password = password
        self.queue_size = queue_size
        # Normalise topics to a list (a tuple is accepted as well).
        if isinstance(topics, str):
            self.topics = [topics]
        elif isinstance(topics, (list, tuple)):
            self.topics = list(topics)
        else:
            self.topics = []
        # JSON is forced because the method field must be parsed.
        self.msg_types = self._validate_msg_types(msg_types)
        # topic -> method -> bounded deque of parsed messages
        self.message_queues: Dict[str, Dict[str, deque]] = {
            topic: {} for topic in self.topics
        }
        # topic -> method -> list of callbacks (coroutine functions or plain functions)
        self.callbacks: Dict[str, Dict[str, List[Callable[[Any], Any]]]] = {
            topic: {} for topic in self.topics
        }
        # method -> structure type; a method whose type has no parser below is dropped.
        self.method_struct_map: Dict[str, Type] = {
            "drc_camera_osd_info_push": CameraOsdInfo,
            "osd_info_push": OSDMessage,
            # Extend here: "other_method": CorrespondingStruct,
        }
        # Pending callback tasks/futures, cancelled by stop().
        self.callback_tasks = set()
        self.mqtt_service = self._get_or_create_connection()
        self.is_connected = False

    def _validate_msg_types(self, msg_types) -> Dict[str, str]:
        """Map every topic to "json", warning if a different type was requested.

        Raises:
            ValueError: *msg_types* is neither a str nor a dict.
        """
        valid_type = "json"
        if isinstance(msg_types, str):
            requested = [msg_types]
        elif isinstance(msg_types, dict):
            requested = list(msg_types.values())
        else:
            raise ValueError("msg_types必须是字符串或字典类型")
        if any(str(t).lower() != valid_type for t in requested):
            logger.warning(f"警告:自动将消息类型转换为{valid_type}需解析method字段")
        return {topic: valid_type for topic in self.topics}

    def _get_or_create_connection(self) -> MQTTService:
        """Reuse the cached *connected* service for ip:port, else create and cache a new one."""
        cache_key = f"{self.ip}:{self.port}"
        existing_service = MQTTDevice._connection_cache.get(cache_key)
        if existing_service is not None and existing_service.is_connected:
            return existing_service
        new_service = MQTTService(
            host=self.ip,
            port=self.port,
            username=self.username,
            password=self.password,
        )
        MQTTDevice._connection_cache[cache_key] = new_service
        return new_service

    async def start(self):
        """Connect (once) and subscribe to all configured topics.

        Returns:
            ``self`` once every subscription is in place, or None when no
            topics were configured.
        Raises:
            Exception: any connect/subscribe error is logged and re-raised.
        """
        if not self.topics:
            logger.warning(f"设备 {self.ip}:{self.port} 未指定任何主题,不会订阅消息")
            return
        try:
            if not self.is_connected:
                await self.mqtt_service.connect()
                self.is_connected = True
            for topic in self.topics:
                await self.mqtt_service.subscribe(topic, self._on_message_received)
            logger.info(f"设备 {self.ip}:{self.port} 已订阅所有主题: {self.topics}")
            return self
        except Exception as e:
            logger.error(f"启动MQTT设备失败: {e}", exc_info=True)
            raise

    async def _on_message_received(self, topic: str, payload: Optional[str]):
        """Decode, parse, queue and dispatch one message received on *topic*.

        Empty payloads, JSON that is not an object, messages without a string
        ``method`` and methods without a parser are ignored.
        """
        if topic not in self.message_queues:
            logger.warning(f"收到未订阅的主题消息: {topic},忽略处理")
            return
        if payload is None:
            # MQTTService passes None for an empty payload; nothing to parse.
            return
        try:
            msg_json = json.loads(payload)
        except json.JSONDecodeError:
            logger.error(f"JSON解析失败原始消息: {payload[:100]}...")
            return
        if not isinstance(msg_json, dict):
            return
        method = msg_json.get("method")
        if not method or not isinstance(method, str):
            return
        try:
            parsed_struct = self._parse_by_method(msg_json, method)
            if parsed_struct is None:
                return
            queues = self.message_queues[topic]
            if method not in queues:
                queues[method] = deque(maxlen=self.queue_size)
            queues[method].append(parsed_struct)
            await self._trigger_callbacks(topic, method, parsed_struct)
        except Exception as e:
            logger.error(f"解析主题 {topic} 消息失败: {e}", exc_info=True)

    def _parse_by_method(self, msg_json: dict, method: str) -> Any:
        """Parse *msg_json* with the parser for *method*'s mapped structure.

        The ``wide_fov_info`` entry of ``data`` is dropped before parsing
        (the original notes it cannot be parsed). *msg_json* itself is not
        modified. Returns None when the method has no parser or parsing fails
        (the failure is logged).
        """
        parsers = {
            CameraOsdInfo: parse_camera_osd_info,
            OSDMessage: parse_osd_message,
        }
        parser = parsers.get(self.method_struct_map.get(method))
        if parser is None:
            return None
        data = msg_json.get('data', {})
        if not isinstance(data, dict):
            logger.warning(f"解析{method}失败: data字段不是JSON对象")
            return None
        try:
            filtered_data = {k: v for k, v in data.items() if k != 'wide_fov_info'}
            return parser({**msg_json, 'data': filtered_data})
        except Exception as e:
            logger.warning(f"解析{method}失败: {e}")
            return None

    async def _trigger_callbacks(self, topic: str, method: str, parsed_struct: Any):
        """Schedule every callback registered for (topic, method) without awaiting it.

        Coroutine functions run as event-loop tasks. Plain functions run in the
        loop's default thread-pool executor, so a slow callback does not block
        message handling. Scheduled futures are kept in ``callback_tasks`` so
        stop() can cancel them; their exceptions are logged on completion.
        """
        callbacks = list(self.callbacks[topic].get(method, ()))
        if not callbacks:
            return
        loop = asyncio.get_running_loop()
        for callback in callbacks:
            try:
                if inspect.iscoroutinefunction(callback):
                    future = loop.create_task(callback(parsed_struct))
                else:
                    future = loop.run_in_executor(None, callback, parsed_struct)
            except Exception as e:
                logger.error(f"创建回调任务失败: {e}", exc_info=True)
                continue
            self.callback_tasks.add(future)
            future.add_done_callback(self._on_callback_done)

    def _on_callback_done(self, future) -> None:
        """Forget a finished callback future and log its exception, if any."""
        self.callback_tasks.discard(future)
        if future.cancelled():
            return
        exc = future.exception()
        if exc is not None:
            logger.error(f"回调执行失败: {exc!r}", exc_info=exc)

    def register_callback(self, topic: str, method: str,
                          callback: Callable[[Any], Optional[Awaitable[None]]]):
        """Register *callback* for messages of *method* on *topic*.

        *callback* receives the parsed structure. It may be an ``async def``
        function (run as a task) or a plain function (run in a thread pool).

        Raises:
            ValueError: *topic* is not one of this device's topics.
        """
        if topic not in self.topics:
            raise ValueError(f"主题 {topic} 未在此设备中订阅")
        self.callbacks[topic].setdefault(method, []).append(callback)
        logger.info(f"已为主题 {topic} 方法 {method} 注册回调函数")

    def get_messages(self, topic: str, method: str, max_count: Optional[int] = None) -> List[Any]:
        """Pop up to *max_count* (all if None) queued messages for (topic, method), oldest first."""
        queue = self.message_queues.get(topic, {}).get(method)
        if not queue:
            return []
        count = len(queue) if max_count is None else min(max_count, len(queue))
        return [queue.popleft() for _ in range(count)]

    def get_latest_message(self, topic: str, method: str) -> Optional[Any]:
        """Return (without removing) the newest queued message for (topic, method), or None."""
        queue = self.message_queues.get(topic, {}).get(method)
        return queue[-1] if queue else None

    async def publish(self, topic: str, payload: str, qos: int = 0, retain: bool = False):
        """Publish through the underlying service; warns if *topic* is not one of ours."""
        if topic not in self.topics:
            logger.warning(f"警告: 发布到未订阅的主题 {topic}")
        await self.mqtt_service.publish(topic, payload, qos=qos, retain=retain)

    async def stop(self):
        """Cancel pending callbacks, disconnect the service and drop it from the cache."""
        logger.info(f"停止MQTT设备 {self.ip}:{self.port}")
        for task in list(self.callback_tasks):
            task.cancel()
        self.callback_tasks.clear()
        if self.is_connected and self.mqtt_service:
            try:
                await self.mqtt_service.disconnect()
            except Exception as e:
                logger.warning(f"断开MQTT连接时出错: {e}")
        cache_key = f"{self.ip}:{self.port}"
        # Only evict the entry if it is *our* service; another device may have
        # cached a different MQTTService for the same endpoint.
        if MQTTDevice._connection_cache.get(cache_key) is self.mqtt_service:
            del MQTTDevice._connection_cache[cache_key]
        self.is_connected = False
        logger.info(f"MQTT设备 {cache_key} 已完全停止")
# ------------------------------
# 回调函数示例(处理特定method的结构体) / Example callbacks (handle the structure parsed for a specific method)
# ------------------------------
async def handle_camera_osd(info: CameraOsdInfo):
    """Example callback for drc_camera_osd_info_push: print gimbal angles, target position and seq/timestamp."""
    print("\n[回调触发] 相机OSD信息更新:")
    gimbal = f"俯仰={info.gimbal_pitch}, 横滚={info.gimbal_roll}, 偏航={info.gimbal_yaw}"
    print(f" 云台角度: {gimbal}")
    target = f"纬度={info.measure_target_latitude}, 经度={info.measure_target_longitude}"
    print(f" 目标位置: {target}")
    print(f" 序列号: {info.seq}, 时间戳: {info.timestamp}")
async def handle_generic_method(data: OSDMessage):
    """Example callback intended for methods without a dedicated structure.

    Prints the method name, seq/timestamp and the first 100 characters of
    ``data.data``.
    NOTE(review): MQTTDevice drops messages whose method has no parser, so
    this callback only sees parsed messages — confirm its intended use.
    """
    print(f"\n[回调触发] 通用方法 {data.method}:")
    print(f" 序列号: {data.seq}, 时间戳: {data.timestamp}")
    summary = str(data.data)[:100]
    print(f" 数据摘要: {summary}...")
async def handle_osd_info(info: OSDMessage):
    """Example callback for osd_info_push: print height, gimbal angles, position and seq/timestamp."""
    print("\n[回调触发] OSD信息更新:")
    osd = info.data
    print(f" 高度: {osd.height}")
    print(f" 云台角度: 俯仰={osd.gimbal_pitch}, 横滚={osd.gimbal_roll}, 偏航={osd.gimbal_yaw}")
    print(f" 位置: 纬度={osd.latitude}, 经度={osd.longitude}")
    print(f" 序列号: {info.seq}, 时间戳: {info.timestamp}")
async def main(ip: str = "8.137.54.85",
               port: int = 1883,
               topic: str = "thing/product/8UUXN5U00A09UF/drc/up"):
    """Subscribe to *topic* and print osd_info_push updates until stopped.

    Args:
        ip: broker host.
        port: broker port.
        topic: topic to subscribe to; must match the topic the device publishes on.
    """
    method = "osd_info_push"
    device = MQTTDevice(
        ip=ip,
        port=port,
        topics=[topic],
        queue_size=50,
    )
    device.register_callback(topic=topic, method=method, callback=handle_osd_info)
    try:
        # start() is inside the try so stop() also runs when start fails midway.
        await device.start()
        while True:
            await asyncio.sleep(1)
            # The queue can also be polled directly.
            latest = device.get_latest_message(topic=topic, method=method)
            if latest:
                print(f"最新高度: {latest.data.height}")
    except KeyboardInterrupt:
        print("用户中断")
    finally:
        await device.stop()
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C surfaces here: on Python 3.11+ asyncio.run() cancels main()
        # (its finally still runs device.stop()) and then re-raises
        # KeyboardInterrupt. Report it instead of dumping a traceback.
        print("用户中断")