# ai_project_v1/grpc_util/grpc_sam3/sam3_grpc_server.py
# Last change shown by the repository viewer: 2026-03-05 14:51:08 +08:00
# Standard library, gRPC stubs, project middleware and SAM3 dependencies.
import json
from concurrent import futures
import grpc
import threading
import queue
import time
import logging
from typing import Dict, Optional
from grpc_util.grpc_sam3 import grpc_sam3_img_pb2_grpc, grpc_sam3_img_pb2
from middleware.MQTTService import MQTTService
from middleware.minio_util import downFile, upload_file
import sys
from middleware.util import get_current_date_and_milliseconds
# Print the path of the running interpreter at import time.
print(sys.executable)
import os
import matplotlib.pyplot as plt
import numpy as np
import sam3.sam3
from PIL import Image
from sam3.sam3 import build_sam3_image_model
from sam3.sam3 import build_sam3_image_model_0228
from sam3.sam3.model.box_ops import box_xywh_to_cxcywh
from sam3.sam3.model.sam3_image_processor import Sam3Processor
from sam3.sam3.visualization_utils import draw_box_on_image, normalize_bbox, plot_results,plot_results_savepic
# Parent directory of the imported sam3.sam3 package.
sam3_root = os.path.join(os.path.dirname(sam3.sam3.__file__), "..")
import torch
# Allow TF32 for CUDA matmul and cuDNN kernels.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Enter a CUDA bfloat16 autocast context at import time; it is never exited.
# NOTE(review): PyTorch documents autocast state as thread-local, so this
# presumably only affects the importing thread -- confirm that inference
# running in worker threads gets the intended precision.
torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TaskQueue:
    """Bounded FIFO of pending tasks plus a status registry keyed by task id.

    ``queue`` holds the task items waiting for a worker; ``task_status`` maps
    every accepted task id to the same item dict so that its ``status``,
    ``result`` and ``completed_time`` can be inspected later.  All access to
    ``task_status`` is serialized with ``lock``.
    """

    def __init__(self, max_size: int = 1000):
        """Create an empty queue that holds at most ``max_size`` pending tasks."""
        self.queue = queue.Queue(maxsize=max_size)
        self.task_status: Dict[str, dict] = {}  # task_id -> task item
        self.lock = threading.Lock()
        self.stop_event = threading.Event()

    def add_task(self, task_id: str, request_data: dict) -> bool:
        """Enqueue a task without blocking.

        Returns:
            True if the task was queued and registered, False if the queue
            was full or any other error occurred (the error is logged).
        """
        task_item = {
            'task_id': task_id,
            'data': request_data,
            'timestamp': time.time(),
            'status': 'pending',
        }
        try:
            # Enqueue and register under the same lock, and register only
            # after the put succeeded.  A rejected task therefore never
            # leaves a permanent 'pending' entry behind (such an entry has no
            # 'completed_time' and would never be cleaned up), and a rejected
            # duplicate id cannot overwrite the entry of a queued task.
            # A worker that dequeues the item immediately still has to take
            # the lock in update_task_status, so it always finds the entry.
            with self.lock:
                self.queue.put(task_item, block=False)
                self.task_status[task_id] = task_item
        except queue.Full:
            logger.warning(f"队列已满,任务 {task_id} 被拒绝")
            return False
        except Exception as e:
            logger.error(f"添加任务失败: {e}")
            return False
        logger.info(f"任务 {task_id} 已添加到队列,当前队列大小: {self.queue.qsize()}")
        return True

    def get_task(self, timeout: float = 1.0) -> Optional[dict]:
        """Return the next task item, or None if none arrives within ``timeout`` seconds."""
        try:
            return self.queue.get(timeout=timeout)
        except queue.Empty:
            return None

    def update_task_status(self, task_id: str, status: str, result: Optional[dict] = None):
        """Set the status of a known task; unknown task ids are ignored.

        When a non-empty ``result`` is given it is stored and the completion
        time is recorded, which makes the task eligible for cleanup.
        """
        with self.lock:
            entry = self.task_status.get(task_id)
            if entry is None:
                return
            entry['status'] = status
            if result:
                entry['result'] = result
                entry['completed_time'] = time.time()

    def get_task_status(self, task_id: str) -> Optional[dict]:
        """Return the registered item for ``task_id``, or None if unknown."""
        with self.lock:
            return self.task_status.get(task_id)

    def cleanup_old_tasks(self, max_age_seconds: int = 3600):
        """Drop tasks whose completion is older than ``max_age_seconds``.

        Tasks without a ``completed_time`` are kept.
        """
        with self.lock:
            now = time.time()
            expired = [
                task_id
                for task_id, task in self.task_status.items()
                if 'completed_time' in task
                and now - task['completed_time'] > max_age_seconds
            ]
            for task_id in expired:
                del self.task_status[task_id]
                logger.info(f"清理旧任务: {task_id}")
class TaskWorker(threading.Thread):
    """Daemon thread that takes tasks from a task queue and processes them.

    For each task the worker downloads the image, runs SAM3 text-prompted
    segmentation, uploads the rendered result and publishes a message over
    MQTT, then records the outcome in the task queue's status registry.
    """

    # Local SAM3 assets (absolute paths of the current deployment).
    BPE_PATH = "/home/beidou/test0623/sam3/sam3/sam3/assets/bpe_simple_vocab_16e6.txt.gz"
    CONFIG_PATH = "/home/beidou/test0623/sam3/sam3/sam3-weight/config.json"
    CHECKPOINT_PATH = "/home/beidou/test0623/sam3/sam3/sam3-weight/sam3.pt"

    # Values used when the request leaves the field empty / zero.  They are
    # the values that were previously hard-coded for every request.
    DEFAULT_PROMPT = "road"
    DEFAULT_CONFIDENCE = 0.5

    def __init__(self, worker_id: int, task_queue: "TaskQueue", stop_event: threading.Event):
        """Create the worker.

        Args:
            worker_id: Number used in log messages.
            task_queue: Queue to take tasks from and report status to.
            stop_event: When set, the worker leaves its loop.
        """
        super().__init__(daemon=True)
        self.worker_id = worker_id
        self.task_queue = task_queue
        self.stop_event = stop_event
        self.processed_count = 0
        self._sam3_model = None  # built lazily on the first task, then reused

    def run(self):
        """Process tasks until ``stop_event`` is set."""
        logger.info(f"工作线程 {self.worker_id} 启动")
        while not self.stop_event.is_set():
            # Reset on every iteration: the except handler below must never
            # see an unbound name (first iteration) or the task of a previous
            # iteration, which would mark an already finished task as failed.
            task = None
            try:
                task = self.task_queue.get_task(timeout=0.5)
                if not task:
                    continue
                task_id = task['task_id']
                request_data = task['data']
                logger.info(f"工作线程 {self.worker_id} 开始处理任务: {task_id}")
                self.task_queue.update_task_status(task_id, 'processing')
                result = self.process_task(task_id, request_data)
                self.task_queue.update_task_status(
                    task_id,
                    'completed' if result.get('success') else 'failed',
                    result
                )
                self.processed_count += 1
                logger.info(f"工作线程 {self.worker_id} 完成任务: {task_id}, 处理总数: {self.processed_count}")
            except Exception as e:
                # Top-level boundary of the worker: log with traceback and
                # keep the thread alive for the next task.
                logger.exception(f"工作线程 {self.worker_id} 处理任务失败: {e}")
                if task:
                    self.task_queue.update_task_status(
                        task['task_id'],
                        'failed',
                        {'error': str(e)}
                    )

    def _get_model(self):
        """Return this worker's SAM3 model, loading it from disk on first use."""
        if getattr(self, "_sam3_model", None) is None:
            self._sam3_model = build_sam3_image_model_0228(
                bpe_path=self.BPE_PATH,
                checkpoint_path=self.CHECKPOINT_PATH,
                config_path=self.CONFIG_PATH,
                load_from_HF=False,
                device="cuda",
                eval_mode=True,
            )
        return self._sam3_model

    @staticmethod
    def _remove_file(path):
        """Delete ``path`` if it exists; a failure is logged, not raised."""
        if not path:
            return
        try:
            if os.path.exists(path):
                os.remove(path)
        except OSError as e:
            # Runs inside ``finally``: must not mask the original error.
            logger.warning(f"删除临时文件失败: {path}: {e}")

    def process_task(self, task_id: str, request_data: dict) -> dict:
        """Run segmentation for one request and publish the result.

        Args:
            task_id: Id of the task; overridden by ``request_data["task_id"]``
                when that key is present.
            request_data: Dict with ``img_url``, ``mqtt_ip``, ``mqtt_port``,
                ``mqtt_topic`` and optionally ``task_id``, ``prompt`` and
                ``confidence``.

        Returns:
            A dict with ``success``, ``message`` and ``data``.

        Raises:
            Any exception raised by download, inference, upload or publish;
            the temporary local files are removed in every case.
        """
        task_id = request_data.get("task_id", task_id)
        img_url = request_data["img_url"]
        # Honour the request's prompt / confidence; fall back to the former
        # fixed values when the field is missing, empty or zero.
        prompt = request_data.get("prompt") or self.DEFAULT_PROMPT
        confidence = request_data.get("confidence") or self.DEFAULT_CONFIDENCE
        mqtt_ip = request_data["mqtt_ip"]
        mqtt_port = request_data["mqtt_port"]
        mqtt_topic = request_data["mqtt_topic"]

        local_image_path = downFile(img_url)
        predict_save_path = None
        try:
            # NOTE(review): assumes downFile returns a falsy value when the
            # download fails -- confirm against middleware.minio_util.
            if not local_image_path:
                raise RuntimeError(f"图片下载失败: {img_url}")

            model = self._get_model()
            _, milliseconds_timestamp = get_current_date_and_milliseconds()
            predict_save_path = os.path.join(
                os.path.dirname(local_image_path),
                str(milliseconds_timestamp) + os.path.basename(local_image_path),
            )

            # The context manager releases the file handle before the file
            # is deleted below.
            with Image.open(local_image_path) as img0:
                image = img0.convert("RGB")
                processor = Sam3Processor(model, confidence_threshold=confidence)
                # PyTorch autocast state is thread-local, so it has to be
                # entered in this worker thread for the inference calls.
                with torch.autocast("cuda", dtype=torch.bfloat16):
                    inference_state = processor.set_image(image)
                    processor.reset_all_prompts(inference_state)
                    inference_state = processor.set_text_prompt(
                        state=inference_state, prompt=prompt
                    )
                plot_results_savepic(img0, inference_state, save_path=predict_save_path)

            object_name, _ = upload_file(predict_save_path, None)
            mqtt = MQTTService(mqtt_ip, port=mqtt_port)
            message = {
                'success': True,
                "task_id": task_id,
                'object_name': object_name
            }
            mqtt.publish_sync(mqtt_topic, json.dumps(message, ensure_ascii=False))
        finally:
            # Remove the local files whether or not processing succeeded.
            self._remove_file(local_image_path)
            self._remove_file(predict_save_path)

        return {
            'success': True,
            'message': f'任务 {task_id} 处理完成',
            'data': {'result': 'some_result'}
        }
class TaskServiceServicer(grpc_sam3_img_pb2_grpc.TaskServiceServicer):
    """gRPC task service that enqueues requests and answers immediately.

    The servicer owns the worker threads: they are started in the
    constructor and stopped through ``stop()``.
    """

    def __init__(self, task_queue: TaskQueue, max_workers: int = 1):
        """Store the queue and start ``max_workers`` worker threads."""
        self.task_queue = task_queue
        self.max_workers = max_workers
        self.stop_event = threading.Event()
        self.workers = []
        self.start_workers()

    def start_workers(self):
        """Create and start the worker threads, remembering each one."""
        for index in range(self.max_workers):
            thread = TaskWorker(index, self.task_queue, self.stop_event)
            thread.start()
            self.workers.append(thread)
        logger.info(f"启动了 {self.max_workers} 个工作线程")

    def ProcessTask(self, request, context):
        """Queue the requested task and return an acceptance/rejection reply.

        The reply never carries the processing result; it only says whether
        the task was accepted into the queue.
        """

        def reply(ok, text):
            # Every answer of this RPC has the same shape.
            return grpc_sam3_img_pb2.TaskResponse(
                task_id=request.task_id,
                success=ok,
                message=text
            )

        try:
            # Reject early with a "busy" message when the queue is full.
            if self.task_queue.queue.full():
                logger.warning(f"队列已满,拒绝任务: {request.task_id}")
                return reply(False, "服务器忙,请稍后重试")

            # Flatten the request into the plain dict the workers consume.
            payload = {'task_id': request.task_id, 'sn': request.sn}
            content = request.content_body
            for field in ('img_url', 'prompt', 'confidence',
                          'mqtt_ip', 'mqtt_port', 'mqtt_topic'):
                payload[field] = getattr(content, field)

            if not self.task_queue.add_task(request.task_id, payload):
                return reply(False, "任务提交失败")
            return reply(
                True,
                f"任务已接收,正在排队处理。当前队列位置: {self.task_queue.queue.qsize()}"
            )
        except Exception as e:
            logger.error(f"处理任务请求失败: {e}")
            return reply(False, f"服务器内部错误: {str(e)}")

    def stop(self):
        """Signal the workers to stop and wait up to 2 seconds for each."""
        self.stop_event.set()
        for thread in self.workers:
            thread.join(timeout=2)
        logger.info("所有工作线程已停止")
class HealthCheckServicer(grpc_sam3_img_pb2_grpc.HealthCheckServicer):
    """gRPC health check that reports the service status from the queue load."""

    def __init__(self, task_queue: "TaskQueue", max_healthy_queue_size: int = 50):
        """Store the queue to watch.

        Args:
            task_queue: Task queue whose backlog is inspected.
            max_healthy_queue_size: Backlog above which the service reports
                NOT_SERVING.  Defaults to the previously hard-coded 50.
        """
        self.task_queue = task_queue
        self.max_healthy_queue_size = max_healthy_queue_size

    def Check(self, request, context):
        """Return NOT_SERVING when the queue is full or too long, else SERVING."""
        pending = self.task_queue.queue
        # A fixed threshold alone is never reached when the queue capacity is
        # smaller than the threshold, so a full queue (which already makes
        # new tasks get rejected) also counts as overloaded.
        overloaded = pending.full() or pending.qsize() > self.max_healthy_queue_size
        serving_status = grpc_sam3_img_pb2.HealthCheckResponse.ServingStatus
        return grpc_sam3_img_pb2.HealthCheckResponse(
            status=serving_status.NOT_SERVING if overloaded else serving_status.SERVING
        )
def serve():
    """Start the gRPC server on port 50051 and block until it terminates.

    Builds a task queue with capacity 20, a task service with one worker
    thread and a health service, serves RPCs from a 2-thread pool, runs a
    daemon thread that purges old task records every 5 minutes, and installs
    SIGINT/SIGTERM handlers that stop the workers and the server.
    """
    import signal

    # Shared queue and the two services built on top of it.
    task_queue = TaskQueue(max_size=20)
    task_service = TaskServiceServicer(task_queue, max_workers=1)
    health_service = HealthCheckServicer(task_queue)

    # Thread pool that handles the incoming RPC calls.
    rpc_pool = futures.ThreadPoolExecutor(max_workers=2)
    server = grpc.server(rpc_pool)
    grpc_sam3_img_pb2_grpc.add_TaskServiceServicer_to_server(task_service, server)
    grpc_sam3_img_pb2_grpc.add_HealthCheckServicer_to_server(health_service, server)

    server.add_insecure_port('[::]:50051')
    server.start()
    logger.info("服务器已启动,监听端口: 50051")
    logger.info("工作线程数: 1, 队列最大容量: 20")

    def cleanup_loop():
        # Purge old task records every 5 minutes, forever.
        while True:
            time.sleep(300)
            task_queue.cleanup_old_tasks()

    cleanup_thread = threading.Thread(target=cleanup_loop, daemon=True)
    cleanup_thread.start()

    def shutdown():
        # Stop the workers first, then the server with a 5 second grace period.
        logger.info("收到关闭信号,正在停止服务器...")
        task_service.stop()
        server.stop(5)
        logger.info("服务器已停止")

    def on_signal(signum, frame):
        shutdown()

    for signum in (signal.SIGINT, signal.SIGTERM):
        signal.signal(signum, on_signal)

    # Keep the main thread alive while the server runs.
    try:
        server.wait_for_termination()
    except KeyboardInterrupt:
        shutdown()


if __name__ == '__main__':
    serve()