"""gRPC front end for SAM3 image segmentation.

Requests are accepted over gRPC, placed on a bounded in-memory queue and
processed by background worker threads. Each worker downloads the image, runs
SAM3 with a text prompt, uploads the rendered result and publishes the
resulting object name over MQTT.
"""

import json
from concurrent import futures
import grpc
import threading
import queue
import signal
import time
import logging
from typing import Dict, Optional

from grpc_util.grpc_sam3 import grpc_sam3_img_pb2_grpc, grpc_sam3_img_pb2
from middleware.MQTTService import MQTTService
from middleware.minio_util import downFile, upload_file
import sys
from middleware.util import get_current_date_and_milliseconds

print(sys.executable)
import os

import matplotlib.pyplot as plt
import numpy as np

import sam3.sam3
from PIL import Image
from sam3.sam3 import build_sam3_image_model
from sam3.sam3 import build_sam3_image_model_0228
from sam3.sam3.model.box_ops import box_xywh_to_cxcywh
from sam3.sam3.model.sam3_image_processor import Sam3Processor
from sam3.sam3.visualization_utils import draw_box_on_image, normalize_bbox, plot_results,plot_results_savepic

sam3_root = os.path.join(os.path.dirname(sam3.sam3.__file__), "..")

import torch

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# NOTE: autocast state is thread-local, so this only affects the importing
# (main) thread. Worker threads enter their own autocast context inside
# TaskWorker.process_task. The call is kept so main-thread behaviour is unchanged.
torch.autocast("cuda", dtype=torch.bfloat16).__enter__()

# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Local model assets (deployment-specific absolute paths).
BPE_PATH = "/home/beidou/test0623/sam3/sam3/sam3/assets/bpe_simple_vocab_16e6.txt.gz"
CONFIG_PATH = "/home/beidou/test0623/sam3/sam3/sam3-weight/config.json"
CHECKPOINT_PATH = "/home/beidou/test0623/sam3/sam3/sam3-weight/sam3.pt"

# Fallbacks used when a request carries no usable prompt / confidence.
DEFAULT_PROMPT = "road"
DEFAULT_CONFIDENCE = 0.5

# Statuses after which a task is finished and eligible for cleanup.
_TERMINAL_STATUSES = ("completed", "failed")


class TaskQueue:
    """Bounded FIFO of pending tasks plus a per-task status table."""

    def __init__(self, max_size: int = 1000):
        self.queue = queue.Queue(maxsize=max_size)
        self.task_status: Dict[str, dict] = {}  # task_id -> task record
        self.lock = threading.Lock()  # guards task_status
        self.stop_event = threading.Event()

    def add_task(self, task_id: str, request_data: dict) -> bool:
        """Register a task and enqueue it without blocking.

        Returns True when the task was queued, False when the queue is full
        or enqueueing failed. On failure the status entry written for this
        call is rolled back so rejected tasks do not pile up as 'pending'.
        """
        task_item = {
            'task_id': task_id,
            'data': request_data,
            'timestamp': time.time(),
            'status': 'pending'
        }

        # The status entry must exist before the item becomes visible to a
        # worker, otherwise the worker's first status update would be dropped.
        with self.lock:
            previous = self.task_status.get(task_id)
            self.task_status[task_id] = task_item

        try:
            self.queue.put(task_item, block=False)
        except queue.Full:
            self._rollback_status(task_id, task_item, previous)
            logger.warning(f"队列已满,任务 {task_id} 被拒绝")
            return False
        except Exception as e:
            self._rollback_status(task_id, task_item, previous)
            logger.error(f"添加任务失败: {e}")
            return False

        logger.info(f"任务 {task_id} 已添加到队列,当前队列大小: {self.queue.qsize()}")
        return True

    def _rollback_status(self, task_id: str, task_item: dict, previous: Optional[dict]) -> None:
        """Undo the status entry written by a failed add_task call."""
        with self.lock:
            if self.task_status.get(task_id) is not task_item:
                return  # somebody else already replaced the entry
            if previous is None:
                del self.task_status[task_id]
            else:
                self.task_status[task_id] = previous

    def get_task(self, timeout: float = 1.0) -> Optional[dict]:
        """Return the next queued task, or None if none arrives within timeout."""
        try:
            return self.queue.get(timeout=timeout)
        except queue.Empty:
            return None

    def update_task_status(self, task_id: str, status: str, result: dict = None):
        """Set the status of a known task and optionally attach its result.

        Unknown task ids are ignored. A completion timestamp is recorded when
        a result is supplied or the status is terminal, so that finished tasks
        always become eligible for cleanup_old_tasks.
        """
        with self.lock:
            record = self.task_status.get(task_id)
            if record is None:
                return
            record['status'] = status
            if result is not None:
                record['result'] = result
            if result is not None or status in _TERMINAL_STATUSES:
                record['completed_time'] = time.time()

    def get_task_status(self, task_id: str) -> Optional[dict]:
        """Return the stored record for task_id, or None if unknown."""
        with self.lock:
            return self.task_status.get(task_id)

    def cleanup_old_tasks(self, max_age_seconds: int = 3600):
        """Drop finished tasks whose completion is older than max_age_seconds."""
        with self.lock:
            now = time.time()
            expired = [
                task_id
                for task_id, task in self.task_status.items()
                if 'completed_time' in task
                and now - task['completed_time'] > max_age_seconds
            ]
            for task_id in expired:
                del self.task_status[task_id]
                logger.info(f"清理旧任务: {task_id}")


class TaskWorker(threading.Thread):
    """Daemon thread that pulls tasks from a TaskQueue and processes them."""

    def __init__(self, worker_id: int, task_queue: TaskQueue, stop_event: threading.Event):
        super().__init__(daemon=True)
        self.worker_id = worker_id
        self.task_queue = task_queue
        self.stop_event = stop_event
        self.processed_count = 0
        self._model = None  # built lazily on first task, then reused

    def run(self):
        """Process tasks until stop_event is set; failures are recorded, not raised."""
        logger.info(f"工作线程 {self.worker_id} 启动")

        while not self.stop_event.is_set():
            # Reset every iteration so the except handler never sees an
            # unbound name or the task from a previous iteration.
            task = None
            try:
                task = self.task_queue.get_task(timeout=0.5)
                if not task:
                    continue

                task_id = task['task_id']
                request_data = task['data']

                logger.info(f"工作线程 {self.worker_id} 开始处理任务: {task_id}")
                self.task_queue.update_task_status(task_id, 'processing')

                result = self.process_task(task_id, request_data)

                self.task_queue.update_task_status(
                    task_id,
                    'completed' if result.get('success') else 'failed',
                    result
                )

                self.processed_count += 1
                logger.info(f"工作线程 {self.worker_id} 完成任务: {task_id}, 处理总数: {self.processed_count}")

            except Exception as e:
                logger.exception(f"工作线程 {self.worker_id} 处理任务失败: {e}")
                if task:
                    self.task_queue.update_task_status(
                        task['task_id'],
                        'failed',
                        {'error': str(e)}
                    )

    def _get_model(self):
        """Build the SAM3 model from local files on first use and cache it."""
        if self._model is None:
            self._model = build_sam3_image_model_0228(
                bpe_path=BPE_PATH,
                checkpoint_path=CHECKPOINT_PATH,
                config_path=CONFIG_PATH,
                load_from_HF=False,
                device="cuda",
                eval_mode=True,
            )
        return self._model

    @staticmethod
    def _resolve_confidence(value) -> float:
        """Return value as a threshold in (0, 1], else DEFAULT_CONFIDENCE."""
        try:
            confidence = float(value)
        except (TypeError, ValueError):
            return DEFAULT_CONFIDENCE
        if 0.0 < confidence <= 1.0:
            return confidence
        return DEFAULT_CONFIDENCE

    @staticmethod
    def _remove_files(*paths) -> None:
        """Best-effort removal of temporary files; failures are only logged."""
        for path in paths:
            if not path:
                continue
            try:
                if os.path.exists(path):
                    os.remove(path)
            except OSError as e:
                logger.warning(f"删除临时文件失败: {path}: {e}")

    def process_task(self, task_id: str, request_data: dict) -> dict:
        """Segment one image and publish the uploaded result over MQTT.

        request_data must contain 'img_url', 'mqtt_ip', 'mqtt_port' and
        'mqtt_topic'. 'prompt' and 'confidence' are optional and fall back to
        DEFAULT_PROMPT / DEFAULT_CONFIDENCE when missing or unusable.

        Returns a dict with 'success', 'message' and 'data'. Any exception
        from download, inference, upload or publish propagates to the caller;
        temporary files are removed in either case.
        """
        task_id = request_data.get("task_id", task_id)
        img_url = request_data["img_url"]
        prompt = (request_data.get("prompt") or "").strip() or DEFAULT_PROMPT
        confidence = self._resolve_confidence(request_data.get("confidence"))
        mqtt_ip = request_data["mqtt_ip"]
        mqtt_port = request_data["mqtt_port"]
        mqtt_topic = request_data["mqtt_topic"]

        local_image_path = None
        predict_save_path = None
        try:
            # downFile's result is used as a local file path below.
            local_image_path = downFile(img_url)

            # Prefix the file name with a millisecond timestamp so the
            # rendered output does not collide with the downloaded input.
            _, milliseconds_timestamp = get_current_date_and_milliseconds()
            predict_save_path = os.path.join(
                os.path.dirname(local_image_path),
                str(milliseconds_timestamp) + os.path.basename(local_image_path),
            )

            # Autocast is thread-local, so it has to be entered here, in the
            # worker thread that actually runs the model.
            with torch.autocast("cuda", dtype=torch.bfloat16):
                processor = Sam3Processor(self._get_model(), confidence_threshold=confidence)

                with Image.open(local_image_path) as raw_image:
                    image = raw_image.convert("RGB")

                inference_state = processor.set_image(image)
                processor.reset_all_prompts(inference_state)
                inference_state = processor.set_text_prompt(state=inference_state, prompt=prompt)

                with Image.open(local_image_path) as img0:
                    plot_results_savepic(img0, inference_state, save_path=predict_save_path)

            object_name, _ = upload_file(predict_save_path, None)

            mqtt = MQTTService(mqtt_ip, port=mqtt_port)
            message = {
                'success': True,
                "task_id": task_id,
                'object_name': object_name
            }
            mqtt.publish_sync(mqtt_topic, json.dumps(message, ensure_ascii=False))
        finally:
            # Clean up on both success and failure.
            self._remove_files(local_image_path, predict_save_path)

        return {
            'success': True,
            'message': f'任务 {task_id} 处理完成',
            'data': {'result': 'some_result', 'object_name': object_name}
        }


class TaskServiceServicer(grpc_sam3_img_pb2_grpc.TaskServiceServicer):
    """gRPC servicer that enqueues tasks and owns the worker threads."""

    def __init__(self, task_queue: TaskQueue, max_workers: int = 1):
        self.task_queue = task_queue
        self.max_workers = max_workers
        self.stop_event = threading.Event()
        self.workers = []

        self.start_workers()

    def start_workers(self):
        """Start max_workers TaskWorker threads sharing this servicer's stop event."""
        for i in range(self.max_workers):
            worker = TaskWorker(i, self.task_queue, self.stop_event)
            worker.start()
            self.workers.append(worker)
        logger.info(f"启动了 {self.max_workers} 个工作线程")

    def ProcessTask(self, request, context):
        """Enqueue the request and reply immediately without waiting for processing."""
        try:
            # Fast-path rejection; add_task still handles the race where the
            # queue fills up between this check and the put.
            if self.task_queue.queue.full():
                logger.warning(f"队列已满,拒绝任务: {request.task_id}")
                return grpc_sam3_img_pb2.TaskResponse(
                    task_id=request.task_id,
                    success=False,
                    message="服务器忙,请稍后重试"
                )

            body = request.content_body
            task_data = {
                'task_id': request.task_id,
                'sn': request.sn,
                'img_url': body.img_url,
                'prompt': body.prompt,
                'confidence': body.confidence,
                'mqtt_ip': body.mqtt_ip,
                'mqtt_port': body.mqtt_port,
                'mqtt_topic': body.mqtt_topic
            }

            if self.task_queue.add_task(request.task_id, task_data):
                return grpc_sam3_img_pb2.TaskResponse(
                    task_id=request.task_id,
                    success=True,
                    message=f"任务已接收,正在排队处理。当前队列位置: {self.task_queue.queue.qsize()}"
                )
            return grpc_sam3_img_pb2.TaskResponse(
                task_id=request.task_id,
                success=False,
                message="任务提交失败"
            )

        except Exception as e:
            logger.error(f"处理任务请求失败: {e}")
            return grpc_sam3_img_pb2.TaskResponse(
                task_id=request.task_id,
                success=False,
                message=f"服务器内部错误: {str(e)}"
            )

    def stop(self):
        """Signal all workers to stop and wait up to 2 seconds for each."""
        self.stop_event.set()
        for worker in self.workers:
            worker.join(timeout=2)
        logger.info("所有工作线程已停止")


class HealthCheckServicer(grpc_sam3_img_pb2_grpc.HealthCheckServicer):
    """Health check that reports NOT_SERVING when the task queue is backed up."""

    def __init__(self, task_queue: TaskQueue, max_queue_size: int = 50):
        self.task_queue = task_queue
        # Backlog above which the service reports itself as not serving.
        self.max_queue_size = max_queue_size

    def Check(self, request, context):
        """Return SERVING unless the backlog exceeds the threshold or the queue is full."""
        pending = self.task_queue.queue
        # A full queue is checked as well: with a queue capacity below the
        # threshold, the size comparison alone could never trigger.
        overloaded = pending.qsize() > self.max_queue_size or pending.full()

        statuses = grpc_sam3_img_pb2.HealthCheckResponse.ServingStatus
        return grpc_sam3_img_pb2.HealthCheckResponse(
            status=statuses.NOT_SERVING if overloaded else statuses.SERVING
        )


def serve(port: int = 50051, max_workers: int = 1, max_queue_size: int = 20):
    """Start the gRPC server and block until it terminates.

    Args:
        port: TCP port to listen on (all interfaces, insecure channel).
        max_workers: number of background task worker threads.
        max_queue_size: capacity of the pending-task queue.
    """
    task_queue = TaskQueue(max_size=max_queue_size)

    task_service = TaskServiceServicer(task_queue, max_workers=max_workers)
    health_service = HealthCheckServicer(task_queue)

    # RPC handlers only enqueue work, so a small handler pool is enough.
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))

    grpc_sam3_img_pb2_grpc.add_TaskServiceServicer_to_server(task_service, server)
    grpc_sam3_img_pb2_grpc.add_HealthCheckServicer_to_server(health_service, server)

    server.add_insecure_port(f'[::]:{port}')
    server.start()

    logger.info(f"服务器已启动,监听端口: {port}")
    logger.info(f"工作线程数: {max_workers}, 队列最大容量: {max_queue_size}")

    shutdown_event = threading.Event()

    def cleanup_loop():
        # Purge old finished tasks every 5 minutes until shutdown is requested.
        while not shutdown_event.wait(300):
            task_queue.cleanup_old_tasks()

    cleanup_thread = threading.Thread(target=cleanup_loop, daemon=True)
    cleanup_thread.start()

    def shutdown():
        # Idempotent: may be reached from a signal handler and from the
        # KeyboardInterrupt fallback below.
        if shutdown_event.is_set():
            return
        shutdown_event.set()
        logger.info("收到关闭信号,正在停止服务器...")
        task_service.stop()
        server.stop(5)  # 5 second grace period for in-flight RPCs
        logger.info("服务器已停止")

    signal.signal(signal.SIGINT, lambda s, f: shutdown())
    signal.signal(signal.SIGTERM, lambda s, f: shutdown())

    try:
        server.wait_for_termination()
    except KeyboardInterrupt:
        shutdown()


if __name__ == '__main__':
    serve()