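"""gRPC task service for SAM3 image segmentation.

Requests are accepted over gRPC, placed on a bounded in-memory queue and
processed by background worker threads, which download the image, run SAM3 on
it, upload the rendered result and publish a notification over MQTT.
"""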
import json
|
|
from concurrent import futures
|
|
import grpc
|
|
import threading
|
|
import queue
|
|
import time
|
|
import logging
|
|
from typing import Dict, Optional
|
|
from grpc_util.grpc_sam3 import grpc_sam3_img_pb2_grpc, grpc_sam3_img_pb2
|
|
from middleware.MQTTService import MQTTService
|
|
from middleware.minio_util import downFile, upload_file
|
|
|
|
import sys
|
|
|
|
from middleware.util import get_current_date_and_milliseconds
|
|
|
|
print(sys.executable)
|
|
|
|
import os
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
|
|
import sam3.sam3
|
|
from PIL import Image
|
|
from sam3.sam3 import build_sam3_image_model
|
|
from sam3.sam3 import build_sam3_image_model_0228
|
|
from sam3.sam3.model.box_ops import box_xywh_to_cxcywh
|
|
from sam3.sam3.model.sam3_image_processor import Sam3Processor
|
|
from sam3.sam3.visualization_utils import draw_box_on_image, normalize_bbox, plot_results,plot_results_savepic
|
|
|
|
# Parent directory of the imported ``sam3.sam3`` package.
sam3_root = os.path.join(os.path.dirname(sam3.sam3.__file__), "..")
import torch
# Allow TF32 for CUDA matmul and cuDNN operations.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Enter a bfloat16 CUDA autocast context at import time; it is never exited.
# NOTE(review): PyTorch documents autocast state as thread-local, so this
# presumably only affects the importing thread and not the TaskWorker
# threads started later -- confirm.
torch.autocast("cuda", dtype=torch.bfloat16).__enter__()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TaskQueue:
    """Bounded FIFO of pending tasks plus a lock-protected status table.

    Every queued item is a dict with the keys ``task_id``, ``data``,
    ``timestamp`` (``time.time()`` at submission) and ``status``.  The same
    dict object is stored in ``task_status``, so a status update is visible
    both through the status table and through the item handed to a worker.
    """

    def __init__(self, max_size: int = 1000):
        """Create an empty queue that holds at most ``max_size`` tasks."""
        self.queue = queue.Queue(maxsize=max_size)
        self.task_status: Dict[str, dict] = {}  # task_id -> task record
        self.lock = threading.Lock()  # guards task_status
        self.stop_event = threading.Event()

    def add_task(self, task_id: str, request_data: dict) -> bool:
        """Enqueue a task without blocking.

        Args:
            task_id: Identifier used as the key in the status table.
            request_data: Payload handed to the worker under the ``data`` key.

        Returns:
            True if the task was queued, False if the queue was full or an
            unexpected error occurred.  A rejected task leaves no entry in
            ``task_status``.
        """
        task_item = {
            'task_id': task_id,
            'data': request_data,
            'timestamp': time.time(),
            'status': 'pending',
        }
        try:
            # Enqueue first and register second, both under the lock:
            #  * if the queue is full nothing is registered, so rejected
            #    tasks cannot pile up as eternal 'pending' entries;
            #  * a worker that dequeues the item immediately still has to
            #    wait for this lock in update_task_status, so it always
            #    finds the entry registered.
            with self.lock:
                self.queue.put_nowait(task_item)
                self.task_status[task_id] = task_item
        except queue.Full:
            logger.warning(f"队列已满,任务 {task_id} 被拒绝")
            return False
        except Exception as e:
            logger.error(f"添加任务失败: {e}")
            return False

        logger.info(f"任务 {task_id} 已添加到队列,当前队列大小: {self.queue.qsize()}")
        return True

    def get_task(self, timeout: float = 1.0) -> Optional[dict]:
        """Return the next queued task, or None if none arrives within ``timeout`` seconds."""
        try:
            return self.queue.get(timeout=timeout)
        except queue.Empty:
            return None

    def update_task_status(self, task_id: str, status: str, result: Optional[dict] = None):
        """Set the status of a known task; unknown task ids are ignored.

        When a non-empty ``result`` is supplied it is stored on the record
        together with ``completed_time``, which is what makes the record
        eligible for ``cleanup_old_tasks``.
        """
        with self.lock:
            record = self.task_status.get(task_id)
            if record is None:
                return
            record['status'] = status
            if result:
                record['result'] = result
                record['completed_time'] = time.time()

    def get_task_status(self, task_id: str) -> Optional[dict]:
        """Return the record stored for ``task_id``, or None if there is none."""
        with self.lock:
            return self.task_status.get(task_id)

    def cleanup_old_tasks(self, max_age_seconds: int = 3600):
        """Drop records whose ``completed_time`` is more than ``max_age_seconds`` old.

        Records without ``completed_time`` (still pending or processing) are
        kept.
        """
        with self.lock:
            current_time = time.time()
            # Collect first: a dict must not be resized while it is iterated.
            expired = [
                task_id
                for task_id, task in self.task_status.items()
                if 'completed_time' in task
                and current_time - task['completed_time'] > max_age_seconds
            ]
            for task_id in expired:
                del self.task_status[task_id]
                logger.info(f"清理旧任务: {task_id}")
|
|
|
|
|
|
class TaskWorker(threading.Thread):
    """Daemon thread that takes tasks from a TaskQueue and runs SAM3 on them.

    For each task the worker downloads the image, runs text-prompted
    segmentation, saves and uploads the rendered result, publishes a
    notification over MQTT and records the outcome in the task queue.
    """

    # Local SAM3 assets used to build the model.
    BPE_PATH = "/home/beidou/test0623/sam3/sam3/sam3/assets/bpe_simple_vocab_16e6.txt.gz"
    CONFIG_PATH = "/home/beidou/test0623/sam3/sam3/sam3-weight/config.json"
    CHECKPOINT_PATH = "/home/beidou/test0623/sam3/sam3/sam3-weight/sam3.pt"

    # Fallbacks used when a request carries an empty prompt / zero confidence.
    DEFAULT_PROMPT = "road"
    DEFAULT_CONFIDENCE = 0.5

    def __init__(self, worker_id: int, task_queue: TaskQueue, stop_event: threading.Event):
        """Create the worker.

        Args:
            worker_id: Number used to identify this worker in log messages.
            task_queue: Queue to pull tasks from and report status to.
            stop_event: When set, the run loop exits after the current task.
        """
        super().__init__(daemon=True)
        self.worker_id = worker_id
        self.task_queue = task_queue
        self.stop_event = stop_event
        self.processed_count = 0
        self._model = None  # built on first use, then reused for every task

    def run(self):
        """Process tasks until ``stop_event`` is set."""
        logger.info(f"工作线程 {self.worker_id} 启动")
        while not self.stop_event.is_set():
            # Reset on every iteration so the except branch never sees an
            # unbound name or the task of a previous iteration.
            task = None
            try:
                # Short timeout so stop_event is re-checked regularly.
                task = self.task_queue.get_task(timeout=0.5)
                if not task:
                    continue

                task_id = task['task_id']
                request_data = task['data']

                logger.info(f"工作线程 {self.worker_id} 开始处理任务: {task_id}")
                self.task_queue.update_task_status(task_id, 'processing')

                result = self.process_task(task_id, request_data)

                self.task_queue.update_task_status(
                    task_id,
                    'completed' if result.get('success') else 'failed',
                    result,
                )

                self.processed_count += 1
                logger.info(f"工作线程 {self.worker_id} 完成任务: {task_id}, 处理总数: {self.processed_count}")

            except Exception as e:
                # Top-level boundary of the worker: log with traceback, mark
                # the task as failed and keep the thread alive.
                logger.exception(f"工作线程 {self.worker_id} 处理任务失败: {e}")
                if task:
                    self.task_queue.update_task_status(
                        task['task_id'],
                        'failed',
                        {'error': str(e)},
                    )

    def _get_model(self):
        """Build the SAM3 image model on first use and return the cached instance."""
        if self._model is None:
            # Loading the checkpoint is expensive; do it once per worker
            # instead of once per task.
            self._model = build_sam3_image_model_0228(
                bpe_path=self.BPE_PATH,
                checkpoint_path=self.CHECKPOINT_PATH,
                config_path=self.CONFIG_PATH,
                load_from_HF=False,
                device="cuda",
                eval_mode=True,
            )
        return self._model

    def process_task(self, task_id: str, request_data: dict) -> dict:
        """Segment one image and publish the result.

        Args:
            task_id: Task identifier; overridden by ``request_data["task_id"]``.
            request_data: Dict with the keys ``task_id``, ``img_url``,
                ``prompt``, ``confidence``, ``mqtt_ip``, ``mqtt_port`` and
                ``mqtt_topic``.

        Returns:
            A dict with ``success``, ``message`` and ``data``.

        Raises:
            KeyError: If a required key is missing from ``request_data``.
            Exception: Any error raised while downloading, running the model,
                uploading or publishing propagates to the caller (``run``
                records the task as failed).
        """
        task_id = request_data["task_id"]
        img_url = request_data["img_url"]
        # Use the requested prompt/threshold; an empty prompt or a zero
        # confidence (presumably the unset protobuf defaults -- confirm)
        # falls back to the previous fixed values.
        prompt = request_data["prompt"] or self.DEFAULT_PROMPT
        confidence = request_data["confidence"] or self.DEFAULT_CONFIDENCE
        mqtt_ip = request_data["mqtt_ip"]
        mqtt_port = request_data["mqtt_port"]
        mqtt_topic = request_data["mqtt_topic"]

        local_image_path = downFile(img_url)
        predict_save_path = None
        try:
            _, milliseconds_timestamp = get_current_date_and_milliseconds()
            # The rendered result is written next to the downloaded image,
            # prefixed with a millisecond timestamp.
            predict_save_path = os.path.join(
                os.path.dirname(local_image_path),
                str(milliseconds_timestamp) + os.path.basename(local_image_path),
            )

            model = self._get_model()

            # The context manager closes the image file even if inference fails.
            with Image.open(local_image_path) as source_image:
                rgb_image = source_image.convert("RGB")
                # Autocast state is thread-local, so the context entered at
                # module import does not cover this thread; enable it here.
                with torch.autocast("cuda", dtype=torch.bfloat16):
                    processor = Sam3Processor(model, confidence_threshold=confidence)
                    inference_state = processor.set_image(rgb_image)
                    processor.reset_all_prompts(inference_state)
                    inference_state = processor.set_text_prompt(
                        state=inference_state, prompt=prompt
                    )
                plot_results_savepic(source_image, inference_state, save_path=predict_save_path)

            object_name, _ = upload_file(predict_save_path, None)

            mqtt = MQTTService(mqtt_ip, port=mqtt_port)
            message = {
                'success': True,
                "task_id": task_id,
                'object_name': object_name,
            }
            mqtt.publish_sync(mqtt_topic, json.dumps(message, ensure_ascii=False))
        finally:
            # Remove the local files on failure as well as on success.
            for path in (local_image_path, predict_save_path):
                if path and os.path.exists(path):
                    os.remove(path)

        return {
            'success': True,
            'message': f'任务 {task_id} 处理完成',
            'data': {'result': 'some_result'},
        }
|
|
|
|
|
|
class TaskServiceServicer(grpc_sam3_img_pb2_grpc.TaskServiceServicer):
    """gRPC servicer that queues incoming tasks and owns the worker threads."""

    def __init__(self, task_queue: TaskQueue, max_workers: int = 1):
        """Store the queue and immediately start ``max_workers`` worker threads."""
        self.task_queue = task_queue
        self.max_workers = max_workers
        self.stop_event = threading.Event()
        self.workers = []

        self.start_workers()

    def start_workers(self):
        """Create, start and remember one TaskWorker per configured worker slot."""
        for index in range(self.max_workers):
            thread = TaskWorker(index, self.task_queue, self.stop_event)
            thread.start()
            self.workers.append(thread)
        logger.info(f"启动了 {self.max_workers} 个工作线程")

    def ProcessTask(self, request, context):
        """Queue the requested task and answer right away.

        The reply only says whether the task was accepted; the task itself is
        processed later by a worker thread.
        """
        try:
            # Reject early while the queue is at capacity.
            if self.task_queue.queue.full():
                logger.warning(f"队列已满,拒绝任务: {request.task_id}")
                return grpc_sam3_img_pb2.TaskResponse(
                    task_id=request.task_id,
                    success=False,
                    message="服务器忙,请稍后重试",
                )

            # Flatten the request into the plain dict handed to the worker.
            body = request.content_body
            task_data = {
                'task_id': request.task_id,
                'sn': request.sn,
                'img_url': body.img_url,
                'prompt': body.prompt,
                'confidence': body.confidence,
                'mqtt_ip': body.mqtt_ip,
                'mqtt_port': body.mqtt_port,
                'mqtt_topic': body.mqtt_topic,
            }

            if not self.task_queue.add_task(request.task_id, task_data):
                return grpc_sam3_img_pb2.TaskResponse(
                    task_id=request.task_id,
                    success=False,
                    message="任务提交失败",
                )

            return grpc_sam3_img_pb2.TaskResponse(
                task_id=request.task_id,
                success=True,
                message=f"任务已接收,正在排队处理。当前队列位置: {self.task_queue.queue.qsize()}",
            )

        except Exception as e:
            logger.error(f"处理任务请求失败: {e}")
            return grpc_sam3_img_pb2.TaskResponse(
                task_id=request.task_id,
                success=False,
                message=f"服务器内部错误: {str(e)}",
            )

    def stop(self):
        """Signal the workers to stop and wait up to two seconds for each one."""
        self.stop_event.set()
        for thread in self.workers:
            thread.join(timeout=2)
        logger.info("所有工作线程已停止")
|
|
|
|
|
|
class HealthCheckServicer(grpc_sam3_img_pb2_grpc.HealthCheckServicer):
    """gRPC health check that derives its status from the task backlog."""

    def __init__(self, task_queue: TaskQueue, max_queue_size: int = 50):
        """Create the servicer.

        Args:
            task_queue: Queue whose backlog is inspected on every check.
            max_queue_size: Backlog above which the service reports
                NOT_SERVING.  If the queue's own capacity is not larger than
                this value, NOT_SERVING can never be reported.
        """
        self.task_queue = task_queue
        self.max_queue_size = max_queue_size

    def Check(self, request, context):
        """Return NOT_SERVING while the backlog exceeds ``max_queue_size``, else SERVING."""
        serving_status = grpc_sam3_img_pb2.HealthCheckResponse.ServingStatus
        backlog = self.task_queue.queue.qsize()

        if backlog > self.max_queue_size:  # backlog too long
            return grpc_sam3_img_pb2.HealthCheckResponse(
                status=serving_status.NOT_SERVING
            )
        return grpc_sam3_img_pb2.HealthCheckResponse(
            status=serving_status.SERVING
        )
|
|
|
|
|
|
def serve(port: int = 50051, max_workers: int = 1, queue_size: int = 20):
    """Start the gRPC server and block until it terminates.

    Args:
        port: TCP port to listen on (all interfaces, insecure channel).
        max_workers: Number of task worker threads to start.
        queue_size: Maximum number of tasks waiting in the queue.

    SIGINT and SIGTERM trigger a graceful shutdown: the server stops
    accepting RPCs (5 second grace period) and then the workers are stopped.
    """
    import signal

    task_queue = TaskQueue(max_size=queue_size)

    # Creating the task servicer also starts the worker threads.
    task_service = TaskServiceServicer(task_queue, max_workers=max_workers)
    health_service = HealthCheckServicer(task_queue)

    # Thread pool that serves the RPCs themselves (not the queued tasks).
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))

    grpc_sam3_img_pb2_grpc.add_TaskServiceServicer_to_server(task_service, server)
    grpc_sam3_img_pb2_grpc.add_HealthCheckServicer_to_server(health_service, server)

    server.add_insecure_port(f'[::]:{port}')
    server.start()

    # Log the values actually in use rather than repeating them as literals.
    logger.info(f"服务器已启动,监听端口: {port}")
    logger.info(f"工作线程数: {max_workers}, 队列最大容量: {queue_size}")

    def cleanup_loop():
        """Purge old finished task records every five minutes."""
        while True:
            time.sleep(300)
            task_queue.cleanup_old_tasks()

    cleanup_thread = threading.Thread(target=cleanup_loop, daemon=True)
    cleanup_thread.start()

    shutdown_started = threading.Event()

    def shutdown():
        """Stop the server, then the workers; repeated calls are no-ops."""
        if shutdown_started.is_set():
            return
        shutdown_started.set()

        logger.info("收到关闭信号,正在停止服务器...")
        # Stop accepting RPCs first so no task is queued after the workers
        # are gone.  server.stop() returns an event that is set once the
        # server has fully stopped; wait for it before reporting success.
        server.stop(5).wait()
        task_service.stop()
        logger.info("服务器已停止")

    signal.signal(signal.SIGINT, lambda s, f: shutdown())
    signal.signal(signal.SIGTERM, lambda s, f: shutdown())

    # Keep the main thread alive while the server runs.
    try:
        server.wait_for_termination()
    except KeyboardInterrupt:
        shutdown()


if __name__ == '__main__':
    serve()
|