554 lines
22 KiB
Python
554 lines
22 KiB
Python
import logging
|
||
|
||
from aiomqtt import MqttError, Client
|
||
import asyncio
|
||
import sys
|
||
import os
|
||
from collections import deque
|
||
import json
|
||
from typing import Dict, List, Optional, Union, Callable, Awaitable, Type, Any
|
||
|
||
from middleware.entity.up_drc_camera_osd_info_push import CameraOsdInfo, parse_camera_osd_info
|
||
from middleware.entity.up_osd_info_push import parse_osd_message, OSDMessage
|
||
|
||
|
||
# Logging configuration: INFO level, timestamped "name - level - message" lines.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def setup_event_loop():
    """Install an asyncio event-loop policy suited to the current OS.

    On Windows (``sys.platform`` is win32 or ``os.name`` is nt) the selector
    based policy is installed. Presumably this is because the MQTT client needs
    ``add_reader``/``add_writer``, which the default Proactor loop lacks
    (NOTE(review): confirm against the aiomqtt docs). On every other platform
    the default policy is left untouched. A one-line notice is printed either way.
    """
    on_windows = sys.platform.lower() == "win32" or os.name.lower() == "nt"
    if not on_windows:
        print("使用系统默认事件循环")
        return

    from asyncio import set_event_loop_policy, WindowsSelectorEventLoopPolicy
    set_event_loop_policy(WindowsSelectorEventLoopPolicy())
    print("已设置Windows兼容事件循环")


# Configure the event-loop policy once, at import time.
setup_event_loop()
|
||
|
||
|
||
|
||
# ------------------------------
|
||
# MQTT服务类(基础通信层)
|
||
# ------------------------------
|
||
class MQTTService:
    """Thin asyncio wrapper around a single aiomqtt ``Client`` connection.

    Keeps a registry of subscriptions (topic pattern -> dispatcher, user
    callback, qos) and caches the last decoded payload per subscribed topic
    pattern. A single background task reads incoming messages and hands each
    one to the first subscription whose pattern matches the message topic.
    """

    def __init__(self, host, port=1883, username=None, password=None):
        self.host = host
        self.port = port
        self.username = username
        self.password = password
        self.client = None
        # Last decoded payload, keyed by the *subscribed pattern* (not the concrete topic).
        self.message_cache = {}
        # pattern -> {"callback": dispatcher(message), "user_callback": cb(topic, payload), "qos": int}
        self.subscriptions = {}
        self.is_connected = False
        self._message_task = None
        # Serialises connect() so concurrent callers do not open two clients.
        self._connection_lock = asyncio.Lock()
        self.os_type = sys.platform.lower()

    async def connect(self):
        """Open the broker connection and start the listener; no-op if already connected.

        Raises:
            MqttError: the broker could not be reached or refused the connection.
            Exception: any other error raised while entering the client context.
        """
        async with self._connection_lock:
            if self.is_connected:
                return

            self.client = Client(
                hostname=self.host,
                port=self.port,
                username=self.username,
                password=self.password,
            )

            try:
                await self.client.__aenter__()
            except MqttError as e:
                print(f"[{self.os_type}] 连接MQTT服务器失败: {e}")
                self.is_connected = False
                self.client = None
                raise
            except Exception as e:
                print(f"[{self.os_type}] 连接过程发生未知错误: {e}")
                self.is_connected = False
                self.client = None
                raise

            self.is_connected = True
            print(f"[{self.os_type}] 成功连接到MQTT服务器: {self.host}:{self.port}")
            self._start_message_listener()

    def _start_message_listener(self):
        """Spawn the listener task unless one is already running."""
        if not self._message_task or self._message_task.done():
            self._message_task = asyncio.create_task(self._message_listener())

    def _cancel_listener(self):
        """Forget the listener task and cancel it, unless we are executing inside it.

        Cancelling the current task would abort whatever it is awaiting next
        (e.g. a reconnect started from the listener itself).
        """
        task, self._message_task = self._message_task, None
        if task is None or task.done():
            return
        if task is asyncio.current_task():
            return
        task.cancel()

    async def _close_client(self):
        """Best-effort exit of the current client context; errors are only reported."""
        client, self.client = self.client, None
        if client is None:
            return
        try:
            await client.__aexit__(None, None, None)
        except Exception as e:
            print(f"[{self.os_type}] 关闭旧连接时出错: {e}")

    async def _dispatch(self, message):
        """Route one message to the first subscription whose pattern matches its topic."""
        # Iterate a snapshot: a callback may (un)subscribe while we await it.
        for topic_pattern, sub_info in list(self.subscriptions.items()):
            if message.topic.matches(topic_pattern):
                try:
                    await sub_info["callback"](message)
                except Exception as e:
                    # One failing callback must not kill the shared listener task.
                    print(f"[{self.os_type}] 处理主题 {topic_pattern} 的消息时回调出错: {e}")
                break

    async def _message_listener(self):
        """Background task: dispatch incoming messages; on connection errors, reconnect."""
        try:
            async for message in self.client.messages:
                await self._dispatch(message)
        except MqttError as e:
            print(f"[{self.os_type}] 消息监听出错: {e}")
            try:
                await self.reconnect()
            except Exception as reconnect_error:
                # Nobody awaits this task, so report instead of leaving the error unobserved.
                print(f"[{self.os_type}] 重新连接MQTT服务器失败: {reconnect_error}")
        except Exception as e:
            print(f"[{self.os_type}] 消息监听发生未知错误: {e}")

    async def reconnect(self):
        """Tear down the current connection, connect again and restore all subscriptions.

        Raises whatever ``connect()`` raises if the new connection fails.
        """
        print(f"[{self.os_type}] 尝试重新连接MQTT服务器...")
        self.is_connected = False
        self._cancel_listener()
        await self._close_client()
        await self.connect()

        # Re-subscribe with the *user* callback. Passing the stored dispatcher
        # would wrap it a second time and break its call signature.
        # subscribe() overwrites each entry in place, so no clear() is needed.
        for topic, sub_info in list(self.subscriptions.items()):
            await self.subscribe(topic, sub_info.get("user_callback"), sub_info["qos"])

    async def publish(self, topic, payload, qos=0, retain=False):
        """Publish ``payload`` to ``topic``; on MqttError reconnect once and retry."""
        if not self.is_connected:
            await self.connect()

        try:
            await self.client.publish(topic, payload, qos=qos, retain=retain)
        except MqttError as e:
            print(f"[{self.os_type}] 发布消息失败: {e}")
            await self.reconnect()
            await self.client.publish(topic, payload, qos=qos, retain=retain)
        else:
            # Truncate long text/bytes payloads; other payload types are not sliceable.
            preview = payload[:200] if isinstance(payload, (str, bytes, bytearray)) else payload
            print(f"[{self.os_type}] 已发布到 {topic}: {preview}...")

    async def subscribe(self, topic, callback=None, qos=0):
        """Subscribe to ``topic`` and register ``callback(topic, payload_str)``.

        Each message's payload is decoded and cached under ``topic`` before
        ``callback`` (if any) is awaited.
        """
        if not self.is_connected:
            await self.connect()

        async def internal_callback(message):
            payload = message.payload.decode() if message.payload else None
            self.message_cache[topic] = payload
            if callback:
                await callback(topic, payload)

        self.subscriptions[topic] = {
            "callback": internal_callback,
            "user_callback": callback,
            "qos": qos,
        }

        await self.client.subscribe(topic, qos=qos)
        print(f"[{self.os_type}] 已订阅主题: {topic}")

    def get_cached_message(self, topic, default=None):
        """Return the last decoded payload cached for ``topic``, or ``default``."""
        return self.message_cache.get(topic, default)

    async def disconnect(self):
        """Stop the listener, close the connection and forget all subscriptions."""
        # Stop the listener first; otherwise it may observe the shutdown as an
        # MqttError and start reconnecting.
        self._cancel_listener()

        if self.is_connected and self.client:
            try:
                await self.client.__aexit__(None, None, None)
                print(f"[{self.os_type}] 已断开MQTT连接")
            except MqttError as e:
                print(f"[{self.os_type}] 断开连接时出错: {e}")
            except Exception as e:
                print(f"[{self.os_type}] 断开连接发生未知错误: {e}")

        self.is_connected = False
        self.subscriptions.clear()
        self.client = None
|
||
|
||
async def empty_osd_callback(osd_info: OSDMessage):
    """No-op OSD update callback; a placeholder where logging could be added."""
    return None
|
||
|
||
# ------------------------------
|
||
# 支持按method解析的MQTT设备类
|
||
# ------------------------------
|
||
class MQTTDevice:
    """MQTT device that parses JSON messages by their ``method`` field.

    Each instance subscribes to its topics on an ``MQTTService`` (pooled per
    ``ip:port`` while connected). For every JSON message whose ``method`` has
    a registered parser, it does three things:
    - parses the message into a structured object;
    - buffers it in a bounded deque per (topic, method);
    - schedules the callbacks registered for that (topic, method).
    """

    # "ip:port" -> MQTTService. An entry is reused only while it is connected.
    _connection_cache: Dict[str, "MQTTService"] = {}

    def __init__(self,
                 ip: str,
                 port: int = 1883,
                 topics: Optional[Union[str, List[str]]] = None,
                 msg_types: Union[str, Dict[str, str]] = "json",  # JSON is required to read "method"
                 username: Optional[str] = None,
                 password: Optional[str] = None,
                 queue_size: int = 100):
        """Create the device (no network I/O happens until ``start()``).

        Args:
            ip, port: broker address.
            topics: one topic, or a list/tuple of topics, to subscribe to.
            msg_types: accepted for compatibility; always coerced to "json".
            username, password: broker credentials.
            queue_size: max buffered messages per (topic, method); oldest are dropped.
        """
        self.ip = ip
        self.port = port
        self.username = username
        self.password = password
        self.queue_size = queue_size

        # Normalise topics to a fresh list; copy so the caller's list is not aliased.
        if isinstance(topics, str):
            self.topics = [topics]
        elif isinstance(topics, (list, tuple)):
            self.topics = list(topics)
        else:
            self.topics = []

        self.msg_types = self._validate_msg_types(msg_types)

        # topic -> method -> bounded deque of parsed messages
        self.message_queues: Dict[str, Dict[str, deque]] = {topic: {} for topic in self.topics}

        # topic -> method -> list of callbacks
        self.callbacks: Dict[str, Dict[str, List[Callable[[Any], Awaitable[None]]]]] = {
            topic: {} for topic in self.topics
        }

        # method -> structure type it is parsed into (extend here for new methods)
        self.method_struct_map: Dict[str, Type] = {
            "drc_camera_osd_info_push": CameraOsdInfo,
            "osd_info_push": OSDMessage,
        }

        # In-flight callback tasks/futures, kept referenced until they finish.
        self.callback_tasks = set()
        self.mqtt_service = self._get_or_create_connection()
        self.is_connected = False

    def _validate_msg_types(self, msg_types) -> Dict[str, str]:
        """Map every topic to "json" (required to read the method field).

        Raises:
            ValueError: ``msg_types`` is neither a str nor a dict.
        """
        valid_type = "json"

        if isinstance(msg_types, str):
            if msg_types.lower() != valid_type:
                print(f"警告:自动将消息类型转换为{valid_type}(需解析method字段)")
        elif isinstance(msg_types, dict):
            if any(str(v).lower() != valid_type for v in msg_types.values()):
                print(f"警告:自动将消息类型转换为{valid_type}(需解析method字段)")
        else:
            raise ValueError("msg_types必须是字符串或字典类型")

        return {topic: valid_type for topic in self.topics}

    def _get_or_create_connection(self) -> "MQTTService":
        """Reuse the cached service for ip:port if it is connected, else create and cache a new one.

        NOTE(review): services are shared across devices. stop() on one device
        disconnects the shared service for all of them, and subscribe() keys by
        topic, so two devices on the same topic overwrite each other's callback.
        """
        cache_key = f"{self.ip}:{self.port}"

        existing_service = MQTTDevice._connection_cache.get(cache_key)
        if existing_service is not None and existing_service.is_connected:
            return existing_service

        new_service = MQTTService(
            host=self.ip,
            port=self.port,
            username=self.username,
            password=self.password,
        )
        MQTTDevice._connection_cache[cache_key] = new_service
        return new_service

    async def start(self):
        """Connect (if needed) and subscribe to every topic.

        Returns:
            ``self`` on success, or ``None`` when no topics were configured.
        Raises:
            Any connection/subscription error, after logging it.
        """
        if not self.topics:
            logger.warning(f"设备 {self.ip}:{self.port} 未指定任何主题,不会订阅消息")
            return
        try:
            if not self.is_connected:
                await self.mqtt_service.connect()
                self.is_connected = True

            for topic in self.topics:
                await self.mqtt_service.subscribe(topic, self._on_message_received)

            logger.info(f"设备 {self.ip}:{self.port} 已订阅所有主题: {self.topics}")
            return self
        except Exception as e:
            logger.error(f"启动MQTT设备失败: {e}", exc_info=True)
            raise

    async def _on_message_received(self, topic: str, payload: str):
        """Parse one payload, buffer the result and fire its callbacks.

        Messages that are not JSON objects, lack a method, or have no parser
        are dropped. Errors are logged, never raised.
        """
        if topic not in self.message_queues:
            logger.warning(f"收到未订阅的主题消息: {topic},忽略处理")
            return

        try:
            msg_json = json.loads(payload)
            # Only a JSON object can carry a "method" field; ignore arrays/scalars.
            if not isinstance(msg_json, dict):
                return

            method = msg_json.get("method")
            if not method:
                return

            parsed_struct = self._parse_by_method(msg_json, method)
            if parsed_struct is None:
                return

            queue = self.message_queues[topic].get(method)
            if queue is None:
                queue = deque(maxlen=self.queue_size)
                self.message_queues[topic][method] = queue
            queue.append(parsed_struct)

            await self._trigger_callbacks(topic, method, parsed_struct)
        except json.JSONDecodeError:
            logger.error(f"JSON解析失败,原始消息: {payload[:100]}...")
        except Exception as e:
            logger.error(f"解析主题 {topic} 消息失败: {e}", exc_info=True)

    def _parse_by_method(self, msg_json: dict, method: str) -> Any:
        """Parse ``msg_json`` with the parser registered for ``method``.

        The ``wide_fov_info`` key is dropped from ``msg_json['data']`` (in
        place) before parsing.

        Returns:
            The parsed object, or ``None`` if the method is unknown or parsing failed.
        """
        struct_cls = self.method_struct_map.get(method)
        if struct_cls is CameraOsdInfo:
            parser, label = parse_camera_osd_info, "drc_camera_osd_info_push"
        elif struct_cls is OSDMessage:
            parser, label = parse_osd_message, "osd_info_push"
        else:
            return None

        try:
            data = msg_json.get('data', {})
            # Drop fields the parsers cannot handle.
            msg_json['data'] = {k: v for k, v in data.items() if k != 'wide_fov_info'}
            return parser(msg_json)
        except Exception as e:
            print(f"解析{label}失败: {e}")
            return None

    async def _trigger_callbacks(self, topic: str, method: str, parsed_struct: Any):
        """Schedule every callback registered for (topic, method) without awaiting it.

        Coroutine functions (``async def``) are scheduled as tasks on the
        running loop. Plain callables run in the loop's default thread pool so
        slow synchronous work cannot block message handling.
        """
        import inspect  # local import, mirroring setup_event_loop's style

        loop = asyncio.get_running_loop()
        for callback in list(self.callbacks.get(topic, {}).get(method, ())):
            try:
                if inspect.iscoroutinefunction(callback):
                    task = loop.create_task(callback(parsed_struct))
                else:
                    task = loop.run_in_executor(None, callback, parsed_struct)
            except Exception as e:
                logger.error(f"创建回调任务失败: {e}")
                continue
            self.callback_tasks.add(task)
            task.add_done_callback(self._on_callback_done)

    def _on_callback_done(self, task) -> None:
        """Drop a finished callback task and log any exception it raised."""
        self.callback_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.error(f"回调执行失败: {exc!r}", exc_info=exc)

    def register_callback(self, topic: str, method: str, callback: Callable[[Any], Awaitable[None]]):
        """Register ``callback(parsed_struct)`` for a (topic, method) pair.

        Raises:
            ValueError: ``topic`` is not one of this device's topics.
        """
        if topic not in self.topics:
            raise ValueError(f"主题 {topic} 未在此设备中订阅")

        self.callbacks[topic].setdefault(method, []).append(callback)
        print(f"已为主题 {topic} 方法 {method} 注册回调函数")

    def get_messages(self, topic: str, method: str, max_count: Optional[int] = None) -> List[Any]:
        """Remove and return up to ``max_count`` buffered messages, oldest first (all if None)."""
        queue = self.message_queues.get(topic, {}).get(method)
        if not queue:
            return []
        count = len(queue) if max_count is None else min(max_count, len(queue))
        return [queue.popleft() for _ in range(count)]

    def get_latest_message(self, topic: str, method: str) -> Optional[Any]:
        """Return (without removing) the newest buffered message, or ``None``."""
        queue = self.message_queues.get(topic, {}).get(method)
        return queue[-1] if queue else None

    async def publish(self, topic: str, payload: str, qos: int = 0, retain: bool = False):
        """Publish through the underlying service; warns when ``topic`` is not one of ours."""
        if topic not in self.topics:
            print(f"警告: 发布到未订阅的主题 {topic}")

        await self.mqtt_service.publish(topic, payload, qos=qos, retain=retain)

    async def stop(self):
        """Cancel pending callbacks, disconnect and evict this device's cached service."""
        logger.info(f"停止MQTT设备 {self.ip}:{self.port}")

        # Iterate a copy: done-callbacks discard entries from this set.
        for task in list(self.callback_tasks):
            task.cancel()
        self.callback_tasks.clear()

        if self.is_connected and self.mqtt_service:
            try:
                await self.mqtt_service.disconnect()
            except Exception as e:
                logger.warning(f"断开MQTT连接时出错: {e}")

        # Evict only our own service; another device may have replaced the entry.
        cache_key = f"{self.ip}:{self.port}"
        if MQTTDevice._connection_cache.get(cache_key) is self.mqtt_service:
            del MQTTDevice._connection_cache[cache_key]

        self.is_connected = False
        logger.info(f"MQTT设备 {cache_key} 已完全停止")
|
||
|
||
|
||
# ------------------------------
|
||
# 回调函数示例(处理特定method的结构体)
|
||
# ------------------------------
|
||
async def handle_camera_osd(info: CameraOsdInfo):
    """Example callback: print a summary of a drc_camera_osd_info_push message.

    NOTE(review): the fields are read directly from ``info``, whereas
    handle_osd_info reads them via ``info.data``. Confirm that CameraOsdInfo
    exposes them at top level.
    """
    print("\n[回调触发] 相机OSD信息更新:")
    pitch, roll, yaw = info.gimbal_pitch, info.gimbal_roll, info.gimbal_yaw
    print(f" 云台角度: 俯仰={pitch}, 横滚={roll}, 偏航={yaw}")
    lat, lon = info.measure_target_latitude, info.measure_target_longitude
    print(f" 目标位置: 纬度={lat}, 经度={lon}")
    seq, ts = info.seq, info.timestamp
    print(f" 序列号: {seq}, 时间戳: {ts}")
|
||
|
||
|
||
async def handle_generic_method(data: OSDMessage):
    """Example callback: print method name, seq/timestamp and a 100-char data preview."""
    print(f"\n[回调触发] 通用方法 {data.method}:")
    seq, ts = data.seq, data.timestamp
    print(f" 序列号: {seq}, 时间戳: {ts}")
    summary = str(data.data)[:100]
    print(f" 数据摘要: {summary}...")
|
||
|
||
|
||
|
||
|
||
async def handle_osd_info(info: OSDMessage):
    """Example callback for osd_info_push: print height, gimbal angles, position, seq and timestamp."""
    print("\n[回调触发] OSD信息更新:")
    body = info.data
    print(f" 高度: {body.height}")
    print(f" 云台角度: 俯仰={body.gimbal_pitch}, 横滚={body.gimbal_roll}, 偏航={body.gimbal_yaw}")
    print(f" 位置: 纬度={body.latitude}, 经度={body.longitude}")
    print(f" 序列号: {info.seq}, 时间戳: {info.timestamp}")
|
||
|
||
|
||
async def main():
    """Demo: subscribe to one DRC uplink topic, print OSD updates until interrupted."""
    topic = "thing/product/8UUXN5U00A09UF/drc/up"  # must match the publisher's topic
    method = "osd_info_push"

    device = MQTTDevice(
        ip="8.137.54.85",
        port=1883,
        topics=[topic],
        queue_size=50,
    )

    device.register_callback(topic=topic, method=method, callback=handle_osd_info)

    try:
        # Inside the try so a failed start still runs stop() cleanup.
        await device.start()

        last_seen = None
        while True:
            await asyncio.sleep(1)
            # Poll the buffer, but only report messages that arrived since the last poll.
            latest = device.get_latest_message(topic=topic, method=method)
            if latest is not None and latest is not last_seen:
                last_seen = latest
                print(f"最新高度: {latest.data.height}")
    except KeyboardInterrupt:
        # NOTE: under asyncio.run on Python 3.11+, Ctrl+C usually reaches this
        # coroutine as cancellation instead; `finally` still performs cleanup.
        print("用户中断")
    finally:
        await device.stop()
|
||
|
||
|
||
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # asyncio.run re-raises Ctrl+C here after main() was cancelled and its
        # `finally` cleanup ran; exit quietly instead of printing a traceback.
        print("用户中断")