增加sam3,集成并通讯成功

This commit is contained in:
martin 2026-03-05 14:51:08 +08:00
parent 63f240ac3a
commit c17df2e460
522 changed files with 419681 additions and 126 deletions

View File

@ -1,54 +0,0 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: grpc_proto/check_grpc/check_grpc.proto
# Protobuf Python Version: 6.31.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
6,
31,
1,
'',
'grpc_proto/check_grpc/check_grpc.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n&grpc_proto/check_grpc/check_grpc.proto\x12\x04task\"%\n\x12HealthCheckRequest\x12\x0f\n\x07service\x18\x01 \x01(\t\"\x9f\x01\n\x13HealthCheckResponse\x12\x37\n\x06status\x18\x01 \x01(\x0e\x32\'.task.HealthCheckResponse.ServingStatus\"O\n\rServingStatus\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x0b\n\x07SERVING\x10\x01\x12\x0f\n\x0bNOT_SERVING\x10\x02\x12\x13\n\x0fSERVICE_UNKNOWN\x10\x03\"S\n\x0bTaskRequest\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\n\n\x02sn\x18\x02 \x01(\t\x12\'\n\x0c\x63ontent_body\x18\x03 \x01(\x0b\x32\x11.task.ContentBody\"\xab\x01\n\x0b\x43ontentBody\x12\x10\n\x08org_code\x18\x01 \x01(\t\x12\x0f\n\x07\x66unc_id\x18\x02 \x03(\x05\x12\x12\n\nsource_url\x18\x03 \x01(\t\x12\x10\n\x08push_url\x18\x04 \x01(\t\x12\x12\n\nconfidence\x18\x05 \x01(\x02\x12!\n\tpara_list\x18\x06 \x03(\x0b\x32\x0e.task.ParaList\x12\x1c\n\x06invade\x18\x07 \x01(\x0b\x32\x0c.task.Invade\"7\n\x08ParaList\x12\x0f\n\x07\x66unc_id\x18\x01 \x01(\x05\x12\x1a\n\x12para_invade_enable\x18\x02 \x01(\x08\"6\n\x06Invade\x12\x13\n\x0binvade_file\x18\x01 \x01(\t\x12\x17\n\x0f\x63\x61mera_para_url\x18\x02 \x01(\t\"A\n\x0cTaskResponse\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x0f\n\x07message\x18\x03 \x01(\t2C\n\x0bTaskService\x12\x34\n\x0bProcessTask\x12\x11.task.TaskRequest\x1a\x12.task.TaskResponse2K\n\x0bHealthCheck\x12<\n\x05\x43heck\x12\x18.task.HealthCheckRequest\x1a\x19.task.HealthCheckResponseb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'grpc_proto.check_grpc.check_grpc_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_HEALTHCHECKREQUEST']._serialized_start=48
_globals['_HEALTHCHECKREQUEST']._serialized_end=85
_globals['_HEALTHCHECKRESPONSE']._serialized_start=88
_globals['_HEALTHCHECKRESPONSE']._serialized_end=247
_globals['_HEALTHCHECKRESPONSE_SERVINGSTATUS']._serialized_start=168
_globals['_HEALTHCHECKRESPONSE_SERVINGSTATUS']._serialized_end=247
_globals['_TASKREQUEST']._serialized_start=249
_globals['_TASKREQUEST']._serialized_end=332
_globals['_CONTENTBODY']._serialized_start=335
_globals['_CONTENTBODY']._serialized_end=506
_globals['_PARALIST']._serialized_start=508
_globals['_PARALIST']._serialized_end=563
_globals['_INVADE']._serialized_start=565
_globals['_INVADE']._serialized_end=619
_globals['_TASKRESPONSE']._serialized_start=621
_globals['_TASKRESPONSE']._serialized_end=686
_globals['_TASKSERVICE']._serialized_start=688
_globals['_TASKSERVICE']._serialized_end=755
_globals['_HEALTHCHECK']._serialized_start=757
_globals['_HEALTHCHECK']._serialized_end=832
# @@protoc_insertion_point(module_scope)

View File

@ -0,0 +1,76 @@
import asyncio
import grpc
from grpc_util.grpc_proto_demo.check_grpc import check_grpc_pb2_grpc, check_grpc_pb2
async def async_check_server_status(channel):
    """Return True when the server's HealthCheck service reports SERVING.

    Args:
        channel: A synchronous ``grpc.Channel`` connected to the server.

    Returns:
        bool: True if the ``Check`` RPC succeeds with status SERVING; False
        if it reports any other status or fails with ``grpc.RpcError``.
    """
    health_stub = check_grpc_pb2_grpc.HealthCheckStub(channel)
    request = check_grpc_pb2.HealthCheckRequest(service="TaskService")
    try:
        # The stub is synchronous: calling it directly inside a coroutine
        # would block the event loop for the whole RPC, so it is executed
        # on the default thread pool and awaited instead.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(None, health_stub.Check, request)
    except grpc.RpcError as e:
        print(f"Health check failed: {e}")
        return False
    return response.status == check_grpc_pb2.HealthCheckResponse.ServingStatus.SERVING
async def async_check_grpc_request(max_retries=3, delay=5):
    """Send one hard-coded demo TaskRequest to ``localhost:50051``.

    Every attempt opens a fresh insecure channel, runs the health check and
    then calls ``TaskService.ProcessTask``. The channel is always closed
    before the next attempt.

    Args:
        max_retries: Maximum number of attempts before giving up.
        delay: Seconds to wait between two failed attempts.

    Returns:
        bool: True as soon as one attempt receives a TaskResponse, False
        when every attempt failed.
    """
    loop = asyncio.get_running_loop()
    retries = 0
    while retries < max_retries:
        channel = None
        try:
            channel = grpc.insecure_channel('localhost:50051')
            if not await async_check_server_status(channel):
                raise Exception("Server is not healthy")

            stub = check_grpc_pb2_grpc.TaskServiceStub(channel)
            request = check_grpc_pb2.TaskRequest(
                task_id="d6118954-a170-4e1c-84bd-ddbd3114b354",
                sn="8UUXN6S00A0CK7",
                content_body=check_grpc_pb2.ContentBody(
                    org_code="HMZHB",
                    func_id=[101204],
                    source_url="xxxxxxxxxx",
                    push_url="",
                    confidence=0.4,
                    para_list=[
                        check_grpc_pb2.ParaList(
                            func_id=101204,
                            para_invade_enable=True
                        )
                    ],
                    invade=check_grpc_pb2.Invade(
                        invade_file="meta_data/高压线-0826.geojson",
                        camera_para_url="meta_data/camera_para/hami_camera_para .txt"
                    )
                )
            )

            # ProcessTask is a blocking (synchronous) stub call; run it on a
            # worker thread so the event loop stays responsive.
            response = await loop.run_in_executor(None, stub.ProcessTask, request)
            print(f"Response: task_id={response.task_id}, success={response.success}, message={response.message}")
            return True
        except grpc.RpcError as e:
            retries += 1
            print(f"RPC error occurred (attempt {retries}/{max_retries}): {e}")
        except Exception as e:
            print(f"Error occurred: {e}")
            retries += 1
        finally:
            # Single cleanup point for both the success and the failure path.
            if channel is not None:
                channel.close()

        # Only back off when another attempt will actually follow.
        if retries < max_retries:
            await asyncio.sleep(delay)

    print("All retry attempts failed")
    return False


if __name__ == '__main__':
    asyncio.run(async_check_grpc_request())

View File

@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: check_grpc.proto
# Protobuf Python Version: 6.31.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
6,
31,
1,
'',
'check_grpc.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63heck_grpc.proto\x12\x04task\"%\n\x12HealthCheckRequest\x12\x0f\n\x07service\x18\x01 \x01(\t\"\x9f\x01\n\x13HealthCheckResponse\x12\x37\n\x06status\x18\x01 \x01(\x0e\x32\'.task.HealthCheckResponse.ServingStatus\"O\n\rServingStatus\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x0b\n\x07SERVING\x10\x01\x12\x0f\n\x0bNOT_SERVING\x10\x02\x12\x13\n\x0fSERVICE_UNKNOWN\x10\x03\"S\n\x0bTaskRequest\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\n\n\x02sn\x18\x02 \x01(\t\x12\'\n\x0c\x63ontent_body\x18\x03 \x01(\x0b\x32\x11.task.ContentBody\"\xab\x01\n\x0b\x43ontentBody\x12\x10\n\x08org_code\x18\x01 \x01(\t\x12\x0f\n\x07\x66unc_id\x18\x02 \x03(\x05\x12\x12\n\nsource_url\x18\x03 \x01(\t\x12\x10\n\x08push_url\x18\x04 \x01(\t\x12\x12\n\nconfidence\x18\x05 \x01(\x02\x12!\n\tpara_list\x18\x06 \x03(\x0b\x32\x0e.task.ParaList\x12\x1c\n\x06invade\x18\x07 \x01(\x0b\x32\x0c.task.Invade\"7\n\x08ParaList\x12\x0f\n\x07\x66unc_id\x18\x01 \x01(\x05\x12\x1a\n\x12para_invade_enable\x18\x02 \x01(\x08\"6\n\x06Invade\x12\x13\n\x0binvade_file\x18\x01 \x01(\t\x12\x17\n\x0f\x63\x61mera_para_url\x18\x02 \x01(\t\"A\n\x0cTaskResponse\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x0f\n\x07message\x18\x03 \x01(\t2C\n\x0bTaskService\x12\x34\n\x0bProcessTask\x12\x11.task.TaskRequest\x1a\x12.task.TaskResponse2K\n\x0bHealthCheck\x12<\n\x05\x43heck\x12\x18.task.HealthCheckRequest\x1a\x19.task.HealthCheckResponseb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'check_grpc_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_HEALTHCHECKREQUEST']._serialized_start=26
_globals['_HEALTHCHECKREQUEST']._serialized_end=63
_globals['_HEALTHCHECKRESPONSE']._serialized_start=66
_globals['_HEALTHCHECKRESPONSE']._serialized_end=225
_globals['_HEALTHCHECKRESPONSE_SERVINGSTATUS']._serialized_start=146
_globals['_HEALTHCHECKRESPONSE_SERVINGSTATUS']._serialized_end=225
_globals['_TASKREQUEST']._serialized_start=227
_globals['_TASKREQUEST']._serialized_end=310
_globals['_CONTENTBODY']._serialized_start=313
_globals['_CONTENTBODY']._serialized_end=484
_globals['_PARALIST']._serialized_start=486
_globals['_PARALIST']._serialized_end=541
_globals['_INVADE']._serialized_start=543
_globals['_INVADE']._serialized_end=597
_globals['_TASKRESPONSE']._serialized_start=599
_globals['_TASKRESPONSE']._serialized_end=664
_globals['_TASKSERVICE']._serialized_start=666
_globals['_TASKSERVICE']._serialized_end=733
_globals['_HEALTHCHECK']._serialized_start=735
_globals['_HEALTHCHECK']._serialized_end=810
# @@protoc_insertion_point(module_scope)

View File

@ -3,7 +3,7 @@
import grpc import grpc
import warnings import warnings
from grpc_proto.check_grpc import check_grpc_pb2 as grpc__proto_dot_check__grpc_dot_check__grpc__pb2 import grpc_util.grpc_proto_demo.check_grpc.check_grpc_pb2 as check__grpc__pb2
GRPC_GENERATED_VERSION = '1.76.0' GRPC_GENERATED_VERSION = '1.76.0'
GRPC_VERSION = grpc.__version__ GRPC_VERSION = grpc.__version__
@ -18,7 +18,7 @@ except ImportError:
if _version_not_supported: if _version_not_supported:
raise RuntimeError( raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},' f'The grpc package installed is at version {GRPC_VERSION},'
+ ' but the generated code in grpc_proto/check_grpc/check_grpc_pb2_grpc.py depends on' + ' but the generated code in check_grpc_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.' + f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
@ -36,8 +36,8 @@ class TaskServiceStub(object):
""" """
self.ProcessTask = channel.unary_unary( self.ProcessTask = channel.unary_unary(
'/task.TaskService/ProcessTask', '/task.TaskService/ProcessTask',
request_serializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.TaskRequest.SerializeToString, request_serializer=check__grpc__pb2.TaskRequest.SerializeToString,
response_deserializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.TaskResponse.FromString, response_deserializer=check__grpc__pb2.TaskResponse.FromString,
_registered_method=True) _registered_method=True)
@ -55,8 +55,8 @@ def add_TaskServiceServicer_to_server(servicer, server):
rpc_method_handlers = { rpc_method_handlers = {
'ProcessTask': grpc.unary_unary_rpc_method_handler( 'ProcessTask': grpc.unary_unary_rpc_method_handler(
servicer.ProcessTask, servicer.ProcessTask,
request_deserializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.TaskRequest.FromString, request_deserializer=check__grpc__pb2.TaskRequest.FromString,
response_serializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.TaskResponse.SerializeToString, response_serializer=check__grpc__pb2.TaskResponse.SerializeToString,
), ),
} }
generic_handler = grpc.method_handlers_generic_handler( generic_handler = grpc.method_handlers_generic_handler(
@ -84,8 +84,8 @@ class TaskService(object):
request, request,
target, target,
'/task.TaskService/ProcessTask', '/task.TaskService/ProcessTask',
grpc__proto_dot_check__grpc_dot_check__grpc__pb2.TaskRequest.SerializeToString, check__grpc__pb2.TaskRequest.SerializeToString,
grpc__proto_dot_check__grpc_dot_check__grpc__pb2.TaskResponse.FromString, check__grpc__pb2.TaskResponse.FromString,
options, options,
channel_credentials, channel_credentials,
insecure, insecure,
@ -109,8 +109,8 @@ class HealthCheckStub(object):
""" """
self.Check = channel.unary_unary( self.Check = channel.unary_unary(
'/task.HealthCheck/Check', '/task.HealthCheck/Check',
request_serializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.HealthCheckRequest.SerializeToString, request_serializer=check__grpc__pb2.HealthCheckRequest.SerializeToString,
response_deserializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.HealthCheckResponse.FromString, response_deserializer=check__grpc__pb2.HealthCheckResponse.FromString,
_registered_method=True) _registered_method=True)
@ -129,8 +129,8 @@ def add_HealthCheckServicer_to_server(servicer, server):
rpc_method_handlers = { rpc_method_handlers = {
'Check': grpc.unary_unary_rpc_method_handler( 'Check': grpc.unary_unary_rpc_method_handler(
servicer.Check, servicer.Check,
request_deserializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.HealthCheckRequest.FromString, request_deserializer=check__grpc__pb2.HealthCheckRequest.FromString,
response_serializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.HealthCheckResponse.SerializeToString, response_serializer=check__grpc__pb2.HealthCheckResponse.SerializeToString,
), ),
} }
generic_handler = grpc.method_handlers_generic_handler( generic_handler = grpc.method_handlers_generic_handler(
@ -159,8 +159,8 @@ class HealthCheck(object):
request, request,
target, target,
'/task.HealthCheck/Check', '/task.HealthCheck/Check',
grpc__proto_dot_check__grpc_dot_check__grpc__pb2.HealthCheckRequest.SerializeToString, check__grpc__pb2.HealthCheckRequest.SerializeToString,
grpc__proto_dot_check__grpc_dot_check__grpc__pb2.HealthCheckResponse.FromString, check__grpc__pb2.HealthCheckResponse.FromString,
options, options,
channel_credentials, channel_credentials,
insecure, insecure,

View File

@ -1,7 +1,9 @@
import grpc import grpc
import time import time
from grpc_proto.check_grpc import check_grpc_pb2_grpc, check_grpc_pb2 from grpc_util.grpc_proto_demo.check_grpc import check_grpc_pb2_grpc, check_grpc_pb2
def check_server_status(channel): def check_server_status(channel):

View File

@ -1,8 +1,7 @@
from concurrent import futures from concurrent import futures
import grpc import grpc
import time
from grpc_proto.check_grpc import check_grpc_pb2_grpc, check_grpc_pb2 from grpc_util.grpc_proto_demo.check_grpc import check_grpc_pb2, check_grpc_pb2_grpc
class TaskServiceServicer(check_grpc_pb2_grpc.TaskServiceServicer): class TaskServiceServicer(check_grpc_pb2_grpc.TaskServiceServicer):
@ -36,12 +35,21 @@ def serve():
server.start() server.start()
print("Server started, listening on port 50051...") print("Server started, listening on port 50051...")
try: # try:
while True: # while True:
time.sleep(86400) # 保持运行 # time.sleep(86400) # 保持运行
except KeyboardInterrupt: # except KeyboardInterrupt:
# server.stop(0)
# 信号处理
import signal
def shutdown_handler(signum, frame):
print(f"Received signal {signum}, shutting down...")
server.stop(0) server.stop(0)
signal.signal(signal.SIGINT, shutdown_handler)
signal.signal(signal.SIGTERM, shutdown_handler)
server.wait_for_termination()
if __name__ == '__main__': if __name__ == '__main__':
serve() serve()

View File

@ -0,0 +1,8 @@
1、参考check_grpc.proto 手写 proto 相关结构
2、使用如下编译命令生成 gRPC 通讯相关代码:
python -m grpc_tools.protoc --proto_path=proto_dir \
--python_out=gen_dir \
--grpc_python_out=gen_dir \
proto_dir/check_grpc.proto
只会生成 request、response
3、拷贝当前文件下的check_grpc_client.py 、check_grpc_server.py重写逻辑代码

View File

View File

@ -0,0 +1,89 @@
import asyncio
import grpc
# from grpc_util.grpc_sam3 import grpc_sam3_img_pb2_grpc, grpc_sam3_img_pb2
# from grpc_util.grpc_sam3.grpc_sam3_img_pb2_grpc import TaskServiceStub, HealthCheckStub
# from grpc_util.grpc_sam3.grpc_sam3_img_pb2 import TaskRequest, TaskResponse, HealthCheckRequest, HealthCheckResponse
# from . import grpc_sam3_img_pb2_grpc, grpc_sam3_img_pb2
from grpc_util.grpc_sam3 import grpc_sam3_img_pb2_grpc, grpc_sam3_img_pb2
async def async_check_server_status(channel):
    """Return True when the SAM3 server's HealthCheck service reports SERVING.

    Args:
        channel: A synchronous ``grpc.Channel`` connected to the server.

    Returns:
        bool: True if the ``Check`` RPC succeeds with status SERVING; False
        if it reports any other status or fails with ``grpc.RpcError``.
    """
    health_stub = grpc_sam3_img_pb2_grpc.HealthCheckStub(channel)
    request = grpc_sam3_img_pb2.HealthCheckRequest(service="TaskService")
    try:
        # The stub is synchronous: calling it directly inside a coroutine
        # would block the event loop for the whole RPC, so it is executed
        # on the default thread pool and awaited instead.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(None, health_stub.Check, request)
    except grpc.RpcError as e:
        print(f"Health check failed: {e}")
        return False
    return response.status == grpc_sam3_img_pb2.HealthCheckResponse.ServingStatus.SERVING
async def grpc_sam3_pic_predict(
        task_id: str,
        sn: str,
        img_url: str,
        prompt: str,
        confidence: float,
        mqtt_ip: str,
        mqtt_port: int,
        mqtt_topic: str,
        max_retries=3, delay=5,
        target: str = '0.0.0.0:50051'):
    """Submit one SAM3 image-prediction task over gRPC, retrying on failure.

    Every attempt opens a fresh insecure channel to ``target``, runs the
    health check and then calls ``TaskService.ProcessTask``. The channel is
    always closed before the next attempt.

    Args:
        task_id: Task identifier placed in the TaskRequest.
        sn: Value for the ``sn`` field of the TaskRequest.
        img_url: Value for ``ContentBody.img_url``.
        prompt: Value for ``ContentBody.prompt``.
        confidence: Value for ``ContentBody.confidence``.
        mqtt_ip: Value for ``ContentBody.mqtt_ip``.
        mqtt_port: Value for ``ContentBody.mqtt_port``.
        mqtt_topic: Value for ``ContentBody.mqtt_topic``.
        max_retries: Maximum number of attempts before giving up.
        delay: Seconds to wait between two failed attempts.
        target: ``host:port`` of the gRPC server. The default is kept for
            backward compatibility; note that ``0.0.0.0`` is a bind address
            and is not a portable destination, so callers should prefer an
            explicit host such as ``localhost:50051``.

    Returns:
        bool: True as soon as one attempt receives a TaskResponse, False
        when every attempt failed.
    """
    loop = asyncio.get_running_loop()
    retries = 0
    while retries < max_retries:
        channel = None
        try:
            channel = grpc.insecure_channel(target)
            if not await async_check_server_status(channel):
                raise Exception("Server is not healthy")

            stub = grpc_sam3_img_pb2_grpc.TaskServiceStub(channel)
            request = grpc_sam3_img_pb2.TaskRequest(
                task_id=task_id,
                sn=sn,
                content_body=grpc_sam3_img_pb2.ContentBody(
                    img_url=img_url,
                    prompt=prompt,
                    # Forward the caller's value; it used to be silently
                    # replaced by a hard-coded 0.5.
                    confidence=confidence,
                    mqtt_ip=mqtt_ip,
                    mqtt_port=mqtt_port,
                    mqtt_topic=mqtt_topic
                )
            )

            # ProcessTask is a blocking (synchronous) stub call; run it on a
            # worker thread so the event loop stays responsive.
            response = await loop.run_in_executor(None, stub.ProcessTask, request)
            print(f"Response: task_id={response.task_id}, success={response.success}, message={response.message}")
            return True
        except grpc.RpcError as e:
            retries += 1
            print(f"RPC error occurred (attempt {retries}/{max_retries}): {e}")
        except Exception as e:
            print(f"Error occurred: {e}")
            retries += 1
        finally:
            # Single cleanup point for both the success and the failure path.
            if channel is not None:
                channel.close()

        # Only back off when another attempt will actually follow.
        if retries < max_retries:
            await asyncio.sleep(delay)

    print("All retry attempts failed")
    return False


if __name__ == '__main__':
    asyncio.run(grpc_sam3_pic_predict(
        task_id="d6118954-a170-4e1c-84bd-ddbd3114b354111",
        sn="8UUXN6S00A0CK7",
        img_url="demo/03.png",
        prompt="cat",
        confidence=0.5,
        mqtt_ip="47.108.62.6",
        mqtt_port=12503,
        mqtt_topic="thing/product/ai/events",
    ))

View File

@ -0,0 +1,48 @@
syntax = "proto3";
package grpc_sam3_img;
// Task submission service: a single unary RPC that takes a TaskRequest and
// returns a TaskResponse.
service TaskService {
rpc ProcessTask (TaskRequest) returns (TaskResponse);
}
// Health check service: a single unary RPC that takes a HealthCheckRequest
// and returns a HealthCheckResponse carrying a ServingStatus.
service HealthCheck {
rpc Check (HealthCheckRequest) returns (HealthCheckResponse);
}
// Request for HealthCheck.Check.
message HealthCheckRequest {
// Service name being queried; the clients in this repo send "TaskService".
string service = 1;
}
// Response for HealthCheck.Check.
message HealthCheckResponse {
// Possible serving states. The clients in this repo treat only SERVING as
// healthy; every other value is handled as "not healthy".
enum ServingStatus {
UNKNOWN = 0;
SERVING = 1;
NOT_SERVING = 2;
SERVICE_UNKNOWN = 3;
}
// Reported state of the queried service.
ServingStatus status = 1;
}
// Request for TaskService.ProcessTask.
message TaskRequest {
// Caller-chosen task identifier; TaskResponse also carries a task_id field.
string task_id = 1;
// NOTE(review): presumably a device serial number -- confirm with the caller.
string sn = 2;
// Task parameters, see ContentBody.
ContentBody content_body = 3;
}
// Parameters of one image task.
message ContentBody {
// Location of the input image; the server in this repo hands it to
// downFile() to obtain a local copy.
string img_url = 1;
// Text prompt supplied by the client (the sample clients send "cat").
string prompt = 2;
// Confidence value supplied by the client.
// NOTE(review): presumably a detection threshold -- confirm against the server.
float confidence =3;
// MQTT broker host and port; the server in this repo builds its
// MQTTService from these two values.
string mqtt_ip = 4;
int32 mqtt_port = 5;
// MQTT topic the server publishes its JSON result message to.
string mqtt_topic = 6;
}
// Response for TaskService.ProcessTask.
message TaskResponse {
// Task identifier of the request this response belongs to.
string task_id = 1;
// Success flag reported by the server.
bool success = 2;
// Human-readable status or error text.
string message = 3;
}

View File

@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: grpc_sam3_img.proto
# Protobuf Python Version: 6.31.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
6,
31,
1,
'',
'grpc_sam3_img.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13grpc_sam3_img.proto\x12\rgrpc_sam3_img\"%\n\x12HealthCheckRequest\x12\x0f\n\x07service\x18\x01 \x01(\t\"\xa8\x01\n\x13HealthCheckResponse\x12@\n\x06status\x18\x01 \x01(\x0e\x32\x30.grpc_sam3_img.HealthCheckResponse.ServingStatus\"O\n\rServingStatus\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x0b\n\x07SERVING\x10\x01\x12\x0f\n\x0bNOT_SERVING\x10\x02\x12\x13\n\x0fSERVICE_UNKNOWN\x10\x03\"\\\n\x0bTaskRequest\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\n\n\x02sn\x18\x02 \x01(\t\x12\x30\n\x0c\x63ontent_body\x18\x03 \x01(\x0b\x32\x1a.grpc_sam3_img.ContentBody\"z\n\x0b\x43ontentBody\x12\x0f\n\x07img_url\x18\x01 \x01(\t\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x12\n\nconfidence\x18\x03 \x01(\x02\x12\x0f\n\x07mqtt_ip\x18\x04 \x01(\t\x12\x11\n\tmqtt_port\x18\x05 \x01(\x05\x12\x12\n\nmqtt_topic\x18\x06 \x01(\t\"A\n\x0cTaskResponse\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x0f\n\x07message\x18\x03 \x01(\t2U\n\x0bTaskService\x12\x46\n\x0bProcessTask\x12\x1a.grpc_sam3_img.TaskRequest\x1a\x1b.grpc_sam3_img.TaskResponse2]\n\x0bHealthCheck\x12N\n\x05\x43heck\x12!.grpc_sam3_img.HealthCheckRequest\x1a\".grpc_sam3_img.HealthCheckResponseb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'grpc_sam3_img_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_HEALTHCHECKREQUEST']._serialized_start=38
_globals['_HEALTHCHECKREQUEST']._serialized_end=75
_globals['_HEALTHCHECKRESPONSE']._serialized_start=78
_globals['_HEALTHCHECKRESPONSE']._serialized_end=246
_globals['_HEALTHCHECKRESPONSE_SERVINGSTATUS']._serialized_start=167
_globals['_HEALTHCHECKRESPONSE_SERVINGSTATUS']._serialized_end=246
_globals['_TASKREQUEST']._serialized_start=248
_globals['_TASKREQUEST']._serialized_end=340
_globals['_CONTENTBODY']._serialized_start=342
_globals['_CONTENTBODY']._serialized_end=464
_globals['_TASKRESPONSE']._serialized_start=466
_globals['_TASKRESPONSE']._serialized_end=531
_globals['_TASKSERVICE']._serialized_start=533
_globals['_TASKSERVICE']._serialized_end=618
_globals['_HEALTHCHECK']._serialized_start=620
_globals['_HEALTHCHECK']._serialized_end=713
# @@protoc_insertion_point(module_scope)

View File

@ -0,0 +1,172 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings
import grpc_util.grpc_sam3.grpc_sam3_img_pb2 as grpc__sam3__img__pb2
GRPC_GENERATED_VERSION = '1.76.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False
try:
from grpc._utilities import first_version_is_lower
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
_version_not_supported = True
if _version_not_supported:
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ ' but the generated code in grpc_sam3_img_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
)
class TaskServiceStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.ProcessTask = channel.unary_unary(
'/grpc_sam3_img.TaskService/ProcessTask',
request_serializer=grpc__sam3__img__pb2.TaskRequest.SerializeToString,
response_deserializer=grpc__sam3__img__pb2.TaskResponse.FromString,
_registered_method=True)
class TaskServiceServicer(object):
"""Missing associated documentation comment in .proto file."""
def ProcessTask(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_TaskServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'ProcessTask': grpc.unary_unary_rpc_method_handler(
servicer.ProcessTask,
request_deserializer=grpc__sam3__img__pb2.TaskRequest.FromString,
response_serializer=grpc__sam3__img__pb2.TaskResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'grpc_sam3_img.TaskService', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('grpc_sam3_img.TaskService', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class TaskService(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def ProcessTask(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/grpc_sam3_img.TaskService/ProcessTask',
grpc__sam3__img__pb2.TaskRequest.SerializeToString,
grpc__sam3__img__pb2.TaskResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
class HealthCheckStub(object):
"""添加健康检查服务
"""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.Check = channel.unary_unary(
'/grpc_sam3_img.HealthCheck/Check',
request_serializer=grpc__sam3__img__pb2.HealthCheckRequest.SerializeToString,
response_deserializer=grpc__sam3__img__pb2.HealthCheckResponse.FromString,
_registered_method=True)
class HealthCheckServicer(object):
"""添加健康检查服务
"""
def Check(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_HealthCheckServicer_to_server(servicer, server):
rpc_method_handlers = {
'Check': grpc.unary_unary_rpc_method_handler(
servicer.Check,
request_deserializer=grpc__sam3__img__pb2.HealthCheckRequest.FromString,
response_serializer=grpc__sam3__img__pb2.HealthCheckResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'grpc_sam3_img.HealthCheck', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers('grpc_sam3_img.HealthCheck', rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class HealthCheck(object):
"""添加健康检查服务
"""
@staticmethod
def Check(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/grpc_sam3_img.HealthCheck/Check',
grpc__sam3__img__pb2.HealthCheckRequest.SerializeToString,
grpc__sam3__img__pb2.HealthCheckResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

View File

@ -0,0 +1,72 @@
import grpc
import time
from grpc_util.grpc_sam3 import grpc_sam3_img_pb2_grpc, grpc_sam3_img_pb2
def check_server_status(channel):
    """Probe the HealthCheck service over ``channel``.

    Args:
        channel: A ``grpc.Channel`` connected to the server.

    Returns:
        bool: True when the reported status equals SERVING; False when the
        status differs or the RPC raises ``grpc.RpcError``.
    """
    try:
        stub = grpc_sam3_img_pb2_grpc.HealthCheckStub(channel)
        probe = grpc_sam3_img_pb2.HealthCheckRequest(service="TaskService")
        reply = stub.Check(probe)
        serving = grpc_sam3_img_pb2.HealthCheckResponse.ServingStatus.SERVING
        return reply.status == serving
    except grpc.RpcError as err:
        # An RPC failure is reported and treated the same as "not serving".
        print(f"Health check failed: {err}")
        return False
def check_grpc_request(max_retries=3, delay=5, target='192.168.110.187:9999'):
    """Send one hard-coded demo TaskRequest to the SAM3 server, with retries.

    Every attempt opens a fresh insecure channel to ``target``, runs the
    health check and then calls ``TaskService.ProcessTask``.

    Args:
        max_retries: Maximum number of attempts before giving up.
        delay: Seconds to wait between two failed attempts.
        target: ``host:port`` of the gRPC server. Defaults to the address
            that used to be hard-coded in this function.

    Returns:
        bool: True as soon as one attempt receives a TaskResponse, False
        when every attempt failed.
    """
    retries = 0
    while retries < max_retries:
        try:
            # The context manager closes the channel on every exit path, and
            # does so before the back-off sleep instead of after it.
            with grpc.insecure_channel(target) as channel:
                # Check the server status before submitting the task.
                if not check_server_status(channel):
                    raise Exception("Server is not healthy")

                stub = grpc_sam3_img_pb2_grpc.TaskServiceStub(channel)

                # Build the request message.
                request = grpc_sam3_img_pb2.TaskRequest(
                    task_id="d6118954-a170-4e1c-84bd-ddbd3114b354111",
                    sn="8UUXN6S00A0CK7",
                    content_body=grpc_sam3_img_pb2.ContentBody(
                        img_url="demo/03.png",
                        prompt="cat",
                        confidence=0.5,
                        mqtt_ip="47.108.62.6",
                        mqtt_port=12503,
                        mqtt_topic="thing/product/ai/events"
                    )
                )

                # Invoke the remote method.
                response = stub.ProcessTask(request)
                print(f"Response: task_id={response.task_id}, success={response.success}, message={response.message}")
                return True
        except grpc.RpcError as e:
            retries += 1
            print(f"RPC error occurred (attempt {retries}/{max_retries}): {e}")
        except Exception as e:
            print(f"Error occurred: {e}")
            retries += 1

        # Shared back-off for both failure kinds; skipped after the last try.
        if retries < max_retries:
            print(f"Retrying in {delay} seconds...")
            time.sleep(delay)

    print("All retry attempts failed")
    return False


if __name__ == '__main__':
    check_grpc_request()

View File

@ -0,0 +1,369 @@
import json
from concurrent import futures
import grpc
import threading
import queue
import time
import logging
from typing import Dict, Optional
from grpc_util.grpc_sam3 import grpc_sam3_img_pb2_grpc, grpc_sam3_img_pb2
from middleware.MQTTService import MQTTService
from middleware.minio_util import downFile, upload_file
import sys
from middleware.util import get_current_date_and_milliseconds
print(sys.executable)
import os
import matplotlib.pyplot as plt
import numpy as np
import sam3.sam3
from PIL import Image
from sam3.sam3 import build_sam3_image_model
from sam3.sam3 import build_sam3_image_model_0228
from sam3.sam3.model.box_ops import box_xywh_to_cxcywh
from sam3.sam3.model.sam3_image_processor import Sam3Processor
from sam3.sam3.visualization_utils import draw_box_on_image, normalize_bbox, plot_results,plot_results_savepic
sam3_root = os.path.join(os.path.dirname(sam3.sam3.__file__), "..")
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TaskQueue:
    """Bounded FIFO of pending tasks plus a lock-protected status table.

    ``queue`` holds the task items waiting for a worker; ``task_status``
    maps task_id to the same item dict so its state can be looked up and
    updated while (and after) it is processed.
    """

    def __init__(self, max_size: int = 1000):
        """Create an empty queue that holds at most ``max_size`` tasks."""
        self.queue = queue.Queue(maxsize=max_size)
        self.task_status: Dict[str, dict] = {}  # task_id -> task item
        self.lock = threading.Lock()  # guards task_status
        self.stop_event = threading.Event()

    def add_task(self, task_id: str, request_data: dict) -> bool:
        """Enqueue a task without blocking.

        Args:
            task_id: Identifier used as the key in the status table.
            request_data: Payload stored under the item's ``'data'`` key.

        Returns:
            bool: True if the task was queued; False if the queue is full or
            an unexpected error occurred. A rejected task leaves no entry in
            the status table.
        """
        task_item = {
            'task_id': task_id,
            'data': request_data,
            'timestamp': time.time(),
            'status': 'pending'
        }
        try:
            with self.lock:
                # Enqueue first and register only on success, so a rejected
                # task never lingers as 'pending'. Doing both under the lock
                # guarantees a worker that grabs the item immediately still
                # finds its status entry when it calls update_task_status().
                self.queue.put(task_item, block=False)
                self.task_status[task_id] = task_item
        except queue.Full:
            logger.warning(f"队列已满,任务 {task_id} 被拒绝")
            return False
        except Exception as e:
            logger.error(f"添加任务失败: {e}")
            return False
        logger.info(f"任务 {task_id} 已添加到队列,当前队列大小: {self.queue.qsize()}")
        return True

    def get_task(self, timeout: float = 1.0) -> Optional[dict]:
        """Return the next queued task item, or None after ``timeout`` seconds."""
        try:
            return self.queue.get(timeout=timeout)
        except queue.Empty:
            return None

    def update_task_status(self, task_id: str, status: str, result: Optional[dict] = None):
        """Set the status of a known task; unknown task ids are ignored.

        When ``result`` is given (even an empty dict) it is stored and the
        completion time is recorded, which makes the task eligible for
        ``cleanup_old_tasks``.
        """
        with self.lock:
            entry = self.task_status.get(task_id)
            if entry is None:
                return
            entry['status'] = status
            # `is not None` rather than truthiness: an empty result dict
            # still marks the task as finished.
            if result is not None:
                entry['result'] = result
                entry['completed_time'] = time.time()

    def get_task_status(self, task_id: str) -> Optional[dict]:
        """Return the stored item for ``task_id``, or None if unknown."""
        with self.lock:
            return self.task_status.get(task_id)

    def cleanup_old_tasks(self, max_age_seconds: int = 3600):
        """Drop finished tasks whose completion is older than ``max_age_seconds``."""
        with self.lock:
            current_time = time.time()
            # Collect first: a dict must not be resized while iterating it.
            expired = [
                task_id
                for task_id, task in self.task_status.items()
                if 'completed_time' in task
                and current_time - task['completed_time'] > max_age_seconds
            ]
            for task_id in expired:
                del self.task_status[task_id]
                logger.info(f"清理旧任务: {task_id}")
class TaskWorker(threading.Thread):
"""工作线程,从队列中取任务并处理"""
def __init__(self, worker_id: int, task_queue: TaskQueue, stop_event: threading.Event):
    """Set up a daemon worker thread.

    Args:
        worker_id: Number used to identify this worker in log messages.
        task_queue: Queue the worker pulls its tasks from.
        stop_event: Event that ends the run loop once it is set.
    """
    super().__init__(daemon=True)
    # Number of tasks this worker has finished so far.
    self.processed_count = 0
    self.stop_event = stop_event
    self.task_queue = task_queue
    self.worker_id = worker_id
def run(self):
    """Pull tasks from the queue and process them until ``stop_event`` is set.

    Each task is marked 'processing', handed to ``process_task`` and then
    marked 'completed' or 'failed' depending on the result's ``'success'``
    entry. Any exception raised while handling a task is logged and the
    task is marked 'failed'; the loop keeps running.
    """
    logger.info(f"工作线程 {self.worker_id} 启动")

    while not self.stop_event.is_set():
        # Reset on every iteration: otherwise the except handler below could
        # hit an unbound name, or mark the *previous* task as failed, when
        # get_task() itself raises.
        task = None
        try:
            # Short timeout so a set stop_event is noticed promptly.
            task = self.task_queue.get_task(timeout=0.5)
            if not task:
                continue

            task_id = task['task_id']
            request_data = task['data']

            logger.info(f"工作线程 {self.worker_id} 开始处理任务: {task_id}")

            # Mark the task as in progress.
            self.task_queue.update_task_status(task_id, 'processing')

            # Actual processing logic.
            result = self.process_task(task_id, request_data)

            # Record the final state together with the result.
            self.task_queue.update_task_status(
                task_id,
                'completed' if result.get('success') else 'failed',
                result
            )

            self.processed_count += 1
            logger.info(f"工作线程 {self.worker_id} 完成任务: {task_id}, 处理总数: {self.processed_count}")

        except Exception as e:
            logger.error(f"工作线程 {self.worker_id} 处理任务失败: {e}")
            if task:
                self.task_queue.update_task_status(
                    task['task_id'],
                    'failed',
                    {'error': str(e)}
                )
def process_task(self, task_id: str, request_data: dict) -> dict:
"""模拟耗时任务处理"""
# 这里替换为你的实际处理逻辑
# time.sleep(10) # 模拟10秒处理时间
task_id=request_data["task_id"]
sn=request_data["sn"]
img_url=request_data["img_url"]
prompt=request_data["prompt"]
confidence=request_data["confidence"]
mqtt_ip=request_data["mqtt_ip"]
mqtt_port=request_data["mqtt_port"]
mqtt_topic=request_data["mqtt_topic"]
local_image_path=downFile(img_url)
bpe_path = f"/home/beidou/test0623/sam3/sam3/sam3/assets/bpe_simple_vocab_16e6.txt.gz"
config_path = f"/home/beidou/test0623/sam3/sam3/sam3-weight/config.json" # 替换为本地路径
checkpoint_path = f"/home/beidou/test0623/sam3/sam3/sam3-weight/sam3.pt" # 替换为本地路径
# model = build_sam3_image_model(bpe_path=bpe_path)
# 2. 构建模型(从本地加载)
model = build_sam3_image_model_0228(
bpe_path=bpe_path,
checkpoint_path=checkpoint_path,
config_path=config_path, # 可选
load_from_HF=False,
device="cuda",
eval_mode=True,
)
formatted_date, milliseconds_timestamp = get_current_date_and_milliseconds()
img_name=os.path.basename(local_image_path)
dir_name=os.path.dirname(local_image_path)
predict_save_path=os.path.join(dir_name,str(milliseconds_timestamp)+img_name)
# image = Image.open(image_path)
image = Image.open(local_image_path).convert("RGB")
width, height = image.size
processor = Sam3Processor(model, confidence_threshold=0.5)
inference_state = processor.set_image(image)
processor.reset_all_prompts(inference_state)
inference_state = processor.set_text_prompt(state=inference_state, prompt="road")
img0 = Image.open(local_image_path)
plot_results_savepic(img0, inference_state, save_path=predict_save_path)
object_name, _=upload_file(predict_save_path,None)
mqtt = MQTTService(mqtt_ip, port=mqtt_port)
message = {
'success': True,
"task_id":task_id,
'object_name': object_name
}
mqtt.publish_sync(mqtt_topic, json.dumps(message, ensure_ascii=False))
# 删除本地文件
if os.path.exists(local_image_path):
os.remove(local_image_path)
if os.path.exists(predict_save_path):
os.remove(predict_save_path)
# 模拟处理结果
return {
'success': True,
'message': f'任务 {task_id} 处理完成',
'data': {'result': 'some_result'}
}
class TaskServiceServicer(grpc_sam3_img_pb2_grpc.TaskServiceServicer):
    """gRPC task service that queues requests for background workers."""

    def __init__(self, task_queue: TaskQueue, max_workers: int = 1):
        """Store the queue and immediately start ``max_workers`` worker threads."""
        self.task_queue = task_queue
        self.max_workers = max_workers
        self.stop_event = threading.Event()
        self.workers = []
        self.start_workers()

    def start_workers(self):
        """Create and start the worker threads, keeping them in ``self.workers``."""
        for index in range(self.max_workers):
            thread = TaskWorker(index, self.task_queue, self.stop_event)
            thread.start()
            self.workers.append(thread)
        logger.info(f"启动了 {self.max_workers} 个工作线程")

    def ProcessTask(self, request, context):
        """Queue the request and answer right away without waiting for processing.

        Returns:
            A ``TaskResponse`` whose ``success`` flag tells whether the task
            was queued. Rejections (queue full, enqueue failure, unexpected
            error) are reported through ``message``.
        """
        def reply(ok, text):
            return grpc_sam3_img_pb2.TaskResponse(
                task_id=request.task_id,
                success=ok,
                message=text
            )

        try:
            # Reject early when the queue has no free slot.
            if self.task_queue.queue.full():
                logger.warning(f"队列已满,拒绝任务: {request.task_id}")
                return reply(False, "服务器忙,请稍后重试")

            # Flatten the request into the plain dict the workers consume.
            body = request.content_body
            task_data = {
                'task_id': request.task_id,
                'sn': request.sn,
                'img_url': body.img_url,
                'prompt': body.prompt,
                'confidence': body.confidence,
                'mqtt_ip': body.mqtt_ip,
                'mqtt_port': body.mqtt_port,
                'mqtt_topic': body.mqtt_topic
            }

            accepted = self.task_queue.add_task(request.task_id, task_data)
            if not accepted:
                return reply(False, "任务提交失败")
            return reply(
                True,
                f"任务已接收,正在排队处理。当前队列位置: {self.task_queue.queue.qsize()}"
            )
        except Exception as e:
            logger.error(f"处理任务请求失败: {e}")
            return reply(False, f"服务器内部错误: {str(e)}")

    def stop(self):
        """Signal the workers to stop and wait up to two seconds for each one."""
        self.stop_event.set()
        for thread in self.workers:
            thread.join(timeout=2)
        logger.info("所有工作线程已停止")
class HealthCheckServicer(grpc_sam3_img_pb2_grpc.HealthCheckServicer):
    """Health-check service that reports status from the task queue backlog."""

    def __init__(self, task_queue: TaskQueue, max_healthy_queue_size: int = 50):
        """Store the queue to inspect and the backlog limit.

        Args:
            task_queue: Queue whose current size is inspected on every check.
            max_healthy_queue_size: Backlog above which the service reports
                ``NOT_SERVING``. Defaults to the previously hard-coded 50.
                NOTE(review): if the queue holds fewer items than this limit
                the ``NOT_SERVING`` answer can never be produced - confirm
                the limit against the queue capacity used by the caller.
        """
        self.task_queue = task_queue
        self.max_healthy_queue_size = max_healthy_queue_size

    def Check(self, request, context):
        """Return ``NOT_SERVING`` when the backlog exceeds the limit, else ``SERVING``."""
        statuses = grpc_sam3_img_pb2.HealthCheckResponse.ServingStatus
        backlog = self.task_queue.queue.qsize()
        if backlog > self.max_healthy_queue_size:  # queue too long
            status = statuses.NOT_SERVING
        else:
            status = statuses.SERVING
        return grpc_sam3_img_pb2.HealthCheckResponse(status=status)
def serve(port: int = 50051, max_workers: int = 1, max_queue_size: int = 20,
          rpc_threads: int = 2, cleanup_interval: float = 300):
    """Start the gRPC server and block until it terminates.

    Args:
        port: TCP port the insecure gRPC endpoint listens on.
        max_workers: Number of task worker threads.
        max_queue_size: Capacity passed to ``TaskQueue``.
        rpc_threads: Size of the thread pool that serves RPC requests.
        cleanup_interval: Seconds between two clean-ups of old task records.

    The defaults reproduce the previously hard-coded configuration.
    """
    import signal

    task_queue = TaskQueue(max_size=max_queue_size)
    task_service = TaskServiceServicer(task_queue, max_workers=max_workers)
    health_service = HealthCheckServicer(task_queue)

    # Thread pool that serves the RPC requests themselves.
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=rpc_threads))
    grpc_sam3_img_pb2_grpc.add_TaskServiceServicer_to_server(task_service, server)
    grpc_sam3_img_pb2_grpc.add_HealthCheckServicer_to_server(health_service, server)

    server.add_insecure_port(f'[::]:{port}')
    server.start()
    logger.info(f"服务器已启动,监听端口: {port}")
    logger.info(f"工作线程数: {max_workers}, 队列最大容量: {max_queue_size}")

    shutdown_requested = threading.Event()

    def cleanup_loop():
        # Event.wait doubles as an interruptible sleep: it returns True, and
        # ends the loop, as soon as shutdown has been requested.
        while not shutdown_requested.wait(cleanup_interval):
            task_queue.cleanup_old_tasks()

    cleanup_thread = threading.Thread(target=cleanup_loop, daemon=True)
    cleanup_thread.start()

    def shutdown():
        # Several signals (or a signal plus KeyboardInterrupt) may arrive;
        # only the first one performs the shutdown.
        if shutdown_requested.is_set():
            return
        shutdown_requested.set()
        logger.info("收到关闭信号,正在停止服务器...")
        task_service.stop()
        # stop() returns an event that is set once the server has fully
        # stopped; wait for it (bounded by the 5 second grace period) so the
        # final log line is accurate.
        server.stop(5).wait(5)
        logger.info("服务器已停止")

    signal.signal(signal.SIGINT, lambda s, f: shutdown())
    signal.signal(signal.SIGTERM, lambda s, f: shutdown())

    # Keep the main thread alive until the server terminates.
    try:
        server.wait_for_termination()
    except KeyboardInterrupt:
        shutdown()
# Script entry point: start the server only when this file is run directly.
if __name__ == "__main__":
    serve()

View File

@ -51,6 +51,7 @@ class MQTTService:
self._message_task = None self._message_task = None
self._connection_lock = asyncio.Lock() self._connection_lock = asyncio.Lock()
self.os_type = sys.platform.lower() self.os_type = sys.platform.lower()
self._loop = None # 保存事件循环
async def connect(self): async def connect(self):
async with self._connection_lock: async with self._connection_lock:
@ -122,6 +123,25 @@ class MQTTService:
await self.reconnect() await self.reconnect()
await self.client.publish(topic, payload, qos=qos, retain=retain) await self.client.publish(topic, payload, qos=qos, retain=retain)
def publish_sync(self, topic, payload, qos=0, retain=False):
    """Publish a message from synchronous code.

    An event loop is created and kept on ``self._loop`` when none is
    available. If that loop is idle the publish runs to completion on it and
    its result is returned. If the loop is already running, the publish is
    handed over to that loop and this call returns ``None`` immediately
    without waiting for delivery.

    Args:
        topic: Topic to publish to.
        payload: Message payload.
        qos: Quality-of-service level forwarded to ``publish``.
        retain: Retain flag forwarded to ``publish``.

    Returns:
        The result of ``publish`` when it ran to completion here, otherwise
        ``None``.
    """
    if self._loop is None or self._loop.is_closed():
        # No usable loop yet: create one and keep it for later calls.
        self._loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._loop)

    if not self._loop.is_running():
        return self._loop.run_until_complete(
            self.publish(topic, payload, qos=qos, retain=retain)
        )

    # The loop is running (possibly in another thread). Schedule explicitly
    # on self._loop in a thread-safe way; ensure_future() would bind the
    # coroutine to whatever loop belongs to the calling thread instead.
    asyncio.run_coroutine_threadsafe(
        self.publish(topic, payload, qos=qos, retain=retain),
        self._loop,
    )
    return None
async def subscribe(self, topic, callback=None, qos=0): async def subscribe(self, topic, callback=None, qos=0):
if not self.is_connected: if not self.is_connected:
await self.connect() await self.connect()

153
sam3/.gitignore vendored Normal file
View File

@ -0,0 +1,153 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
*-Copy*.ipynb
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# PyCharm
.idea/
# VS Code
.vscode/
*.code-workspace
# Model weights and checkpoints
*.pth
*.pt
*.bin
*.ckpt
*.safetensors
weights/
checkpoints/
sam3_logs/
# Data files
*.h5
*.hdf5
*.pkl
*.pickle
*.npy
*.npz
# Logs
logs/
runs/
tensorboard/
# OS specific
.DS_Store
Thumbs.db
# BPE vocabulary files
*.bpe
*.vocab

80
sam3/CODE_OF_CONDUCT.md Normal file
View File

@ -0,0 +1,80 @@
# Code of Conduct
## Our Pledge
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.
## Our Standards
Examples of behavior that contributes to creating a positive environment
include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.
## Scope
This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.
This Code of Conduct also applies outside the project spaces when there is a
reasonable belief that an individual's behavior may have a negative impact on
the project or its community.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <opensource-conduct@meta.com>. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq

30
sam3/CONTRIBUTING.md Normal file
View File

@ -0,0 +1,30 @@
# Contributing to sam3
We want to make contributing to this project as easy and transparent as
possible.
## Pull Requests
We actively welcome your pull requests.
1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Make sure your code lints.
5. If you haven't already, complete the Contributor License Agreement ("CLA").
## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.
Complete your CLA here: <https://code.facebook.com/cla>
## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.
Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.
## License
By contributing to sam3, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.

61
sam3/LICENSE Normal file
View File

@ -0,0 +1,61 @@
SAM License
Last Updated: November 19, 2025
“Agreement” means the terms and conditions for use, reproduction, distribution and modification of the SAM Materials set forth herein.
“SAM Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, and other elements of the foregoing distributed by Meta and made available under this Agreement.
“Documentation” means the specifications, manuals and documentation accompanying
SAM Materials distributed by Meta.
“Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
“Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) or Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland).
“Sanctions” means any economic or trade sanctions or restrictions administered or enforced by the United States (including the Office of Foreign Assets Control of the U.S. Department of the Treasury (“OFAC”), the U.S. Department of State and the U.S. Department of Commerce), the United Nations, the European Union, or the United Kingdom.
“Trade Controls” means any of the following: Sanctions and applicable export and import controls.
By using or distributing any portion or element of the SAM Materials, you agree to be bound by this Agreement.
1. License Rights and Redistribution.
a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the SAM Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the SAM Materials.
b. Redistribution and Use.
i. Distribution of SAM Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the SAM Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement and you shall provide a copy of this Agreement with any such SAM Materials.
ii. If you submit for publication the results of research you perform on, using, or otherwise in connection with SAM Materials, you must acknowledge the use of SAM Materials in your publication.
iii. Your use of the SAM Materials must comply with applicable laws and regulations, including Trade Control Laws and applicable privacy and data protection laws.
iv. Your use of the SAM Materials will not involve or encourage others to reverse engineer, decompile or discover the underlying components of the SAM Materials.
v. You are not the target of Trade Controls and your use of SAM Materials must comply with Trade Controls. You agree not to use, or permit others to use, SAM Materials for any activities subject to the International Traffic in Arms Regulations (ITAR) or end uses prohibited by Trade Controls, including those related to military or warfare purposes, nuclear industries or applications, espionage, or the development or use of guns or illegal weapons.
2. User Support. Your use of the SAM Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the SAM Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SAM MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SAM MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SAM MATERIALS AND ANY OUTPUT AND RESULTS.
4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
5. Intellectual Property.
a. Subject to Meta’s ownership of SAM Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the SAM Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the SAM Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the SAM Materials.
6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the SAM Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the SAM Materials. Sections 3, 4 and 7 shall survive the termination of this Agreement.
7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
8. Modifications and Amendments. Meta may modify this Agreement from time to time; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the SAM Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.

6
sam3/MANIFEST.in Normal file
View File

@ -0,0 +1,6 @@
include LICENSE
include README.md
recursive-include examples *.py
recursive-include examples *.ipynb
recursive-include examples *.md
recursive-include tests *.py

395
sam3/README.md Normal file
View File

@ -0,0 +1,395 @@
# SAM 3: Segment Anything with Concepts
Meta Superintelligence Labs
[Nicolas Carion](https://www.nicolascarion.com/)\*,
[Laura Gustafson](https://scholar.google.com/citations?user=c8IpF9gAAAAJ&hl=en)\*,
[Yuan-Ting Hu](https://scholar.google.com/citations?user=E8DVVYQAAAAJ&hl=en)\*,
[Shoubhik Debnath](https://scholar.google.com/citations?user=fb6FOfsAAAAJ&hl=en)\*,
[Ronghang Hu](https://ronghanghu.com/)\*,
[Didac Suris](https://www.didacsuris.com/)\*,
[Chaitanya Ryali](https://scholar.google.com/citations?user=4LWx24UAAAAJ&hl=en)\*,
[Kalyan Vasudev Alwala](https://scholar.google.co.in/citations?user=m34oaWEAAAAJ&hl=en)\*,
[Haitham Khedr](https://hkhedr.com/)\*, Andrew Huang,
[Jie Lei](https://jayleicn.github.io/),
[Tengyu Ma](https://scholar.google.com/citations?user=VeTSl0wAAAAJ&hl=en),
[Baishan Guo](https://scholar.google.com/citations?user=BC5wDu8AAAAJ&hl=en),
Arpit Kalla, [Markus Marks](https://damaggu.github.io/),
[Joseph Greer](https://scholar.google.com/citations?user=guL96CkAAAAJ&hl=en),
Meng Wang, [Peize Sun](https://peizesun.github.io/),
[Roman Rädle](https://scholar.google.com/citations?user=Tpt57v0AAAAJ&hl=en),
[Triantafyllos Afouras](https://www.robots.ox.ac.uk/~afourast/),
[Effrosyni Mavroudi](https://scholar.google.com/citations?user=vYRzGGEAAAAJ&hl=en),
[Katherine Xu](https://k8xu.github.io/)°,
[Tsung-Han Wu](https://patrickthwu.com/)°,
[Yu Zhou](https://yu-bryan-zhou.github.io/)°,
[Liliane Momeni](https://scholar.google.com/citations?user=Lb-KgVYAAAAJ&hl=en)°,
[Rishi Hazra](https://rishihazra.github.io/)°,
[Shuangrui Ding](https://mark12ding.github.io/)°,
[Sagar Vaze](https://sgvaze.github.io/)°,
[Francois Porcher](https://scholar.google.com/citations?user=LgHZ8hUAAAAJ&hl=en)°,
[Feng Li](https://fengli-ust.github.io/)°,
[Siyuan Li](https://siyuanliii.github.io/)°,
[Aishwarya Kamath](https://ashkamath.github.io/)°,
[Ho Kei Cheng](https://hkchengrex.com/)°,
[Piotr Dollar](https://pdollar.github.io/)†,
[Nikhila Ravi](https://nikhilaravi.com/)†,
[Kate Saenko](https://ai.bu.edu/ksaenko.html)†,
[Pengchuan Zhang](https://pzzhang.github.io/pzzhang/)†,
[Christoph Feichtenhofer](https://feichtenhofer.github.io/)†
\* core contributor, ° intern, † project lead, order is random within groups
[[`Paper`](https://ai.meta.com/research/publications/sam-3-segment-anything-with-concepts/)]
[[`Project`](https://ai.meta.com/sam3)]
[[`Demo`](https://segment-anything.com/)]
[[`Blog`](https://ai.meta.com/blog/segment-anything-model-3/)]
[[`BibTeX`](#citing-sam-3)]
![SAM 3 architecture](assets/model_diagram.png?raw=true) SAM 3 is a unified foundation model for promptable segmentation in images and videos. It can detect, segment, and track objects using text or visual prompts such as points, boxes, and masks. Compared to its predecessor [SAM 2](https://github.com/facebookresearch/sam2), SAM 3 introduces the ability to exhaustively segment all instances of an open-vocabulary concept specified by a short text phrase or exemplars. Unlike prior work, SAM 3 can handle a vastly larger set of open-vocabulary prompts. It achieves 75-80% of human performance on our new [SA-CO benchmark](https://github.com/facebookresearch/sam3?tab=readme-ov-file#sa-co-dataset) which contains 270K unique concepts, over 50 times more than existing benchmarks.
This breakthrough is driven by an innovative data engine that has automatically annotated over 4 million unique concepts, creating the largest high-quality open-vocabulary segmentation dataset to date. In addition, SAM 3 introduces a new model architecture featuring a presence token that improves discrimination between closely related text prompts (e.g., “a player in white” vs. “a player in red”), as well as a decoupled detector–tracker design that minimizes task interference and scales efficiently with data.
<p align="center">
<img src="assets/dog.gif" width=380 />
<img src="assets/player.gif" width=380 />
</p>
## Installation
### Prerequisites
- Python 3.12 or higher
- PyTorch 2.7 or higher
- CUDA-compatible GPU with CUDA 12.6 or higher
1. **Create a new Conda environment:**
```bash
conda create -n sam3 python=3.12
conda deactivate
conda activate sam3
```
2. **Install PyTorch with CUDA support:**
```bash
pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
```
3. **Clone the repository and install the package:**
```bash
git clone https://github.com/facebookresearch/sam3.git
cd sam3
pip install -e .
```
4. **Install additional dependencies for example notebooks or development:**
```bash
# For running example notebooks
pip install -e ".[notebooks]"
# For development
pip install -e ".[train,dev]"
```
## Getting Started
⚠️ Before using SAM 3, please request access to the checkpoints on the SAM 3
Hugging Face [repo](https://huggingface.co/facebook/sam3). Once accepted, you
need to be authenticated to download the checkpoints. You can do this by running
the following [steps](https://huggingface.co/docs/huggingface_hub/en/quick-start#authentication)
(e.g. `hf auth login` after generating an access token.)
### Basic Usage
```python
import torch
#################################### For Image ####################################
from PIL import Image
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
# Load the model
model = build_sam3_image_model()
processor = Sam3Processor(model)
# Load an image
image = Image.open("<YOUR_IMAGE_PATH.jpg>")
inference_state = processor.set_image(image)
# Prompt the model with text
output = processor.set_text_prompt(state=inference_state, prompt="<YOUR_TEXT_PROMPT>")
# Get the masks, bounding boxes, and scores
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
#################################### For Video ####################################
from sam3.model_builder import build_sam3_video_predictor
video_predictor = build_sam3_video_predictor()
video_path = "<YOUR_VIDEO_PATH>" # a JPEG folder or an MP4 video file
# Start a session
response = video_predictor.handle_request(
request=dict(
type="start_session",
resource_path=video_path,
)
)
response = video_predictor.handle_request(
request=dict(
type="add_prompt",
session_id=response["session_id"],
frame_index=0, # Arbitrary frame index
text="<YOUR_TEXT_PROMPT>",
)
)
output = response["outputs"]
```
## Examples
The `examples` directory contains notebooks demonstrating how to use SAM3 with
various types of prompts:
- [`sam3_image_predictor_example.ipynb`](examples/sam3_image_predictor_example.ipynb)
: Demonstrates how to prompt SAM 3 with text and visual box prompts on images.
- [`sam3_video_predictor_example.ipynb`](examples/sam3_video_predictor_example.ipynb)
: Demonstrates how to prompt SAM 3 with text prompts on videos and how to make
further interactive refinements with points.
- [`sam3_image_batched_inference.ipynb`](examples/sam3_image_batched_inference.ipynb)
: Demonstrates how to run batched inference with SAM 3 on images.
- [`sam3_agent.ipynb`](examples/sam3_agent.ipynb): Demonstrates the use of SAM
3 Agent to segment complex text prompts on images.
- [`saco_gold_silver_vis_example.ipynb`](examples/saco_gold_silver_vis_example.ipynb)
: Shows a few examples from SA-Co image evaluation set.
- [`saco_veval_vis_example.ipynb`](examples/saco_veval_vis_example.ipynb) :
Shows a few examples from SA-Co video evaluation set.
There are additional notebooks in the examples directory that demonstrate how to
use SAM 3 for interactive instance segmentation in images and videos (SAM 1/2
tasks), or as a tool for an MLLM, and how to run evaluations on the SA-Co
dataset.
To run the Jupyter notebook examples:
```bash
# Make sure you have the notebooks dependencies installed
pip install -e ".[notebooks]"
# Start Jupyter notebook
jupyter notebook examples/sam3_image_predictor_example.ipynb
```
## Model
SAM 3 consists of a detector and a tracker that share a vision encoder. It has 848M parameters. The
detector is a DETR-based model conditioned on text, geometry, and image
exemplars. The tracker inherits the SAM 2 transformer encoder-decoder
architecture, supporting video segmentation and interactive refinement.
## Image Results
<div align="center">
<table style="min-width: 80%; border: 2px solid #ddd; border-collapse: collapse">
<thead>
<tr>
<th rowspan="3" style="border-right: 2px solid #ddd; padding: 12px 20px">Model</th>
<th colspan="3" style="text-align: center; border-right: 2px solid #ddd; padding: 12px 20px">Instance Segmentation</th>
<th colspan="5" style="text-align: center; padding: 12px 20px">Box Detection</th>
</tr>
<tr>
<th colspan="2" style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">LVIS</th>
<th style="text-align: center; border-right: 2px solid #ddd; padding: 12px 20px">SA-Co/Gold</th>
<th colspan="2" style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">LVIS</th>
<th colspan="2" style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">COCO</th>
<th style="text-align: center; padding: 12px 20px">SA-Co/Gold</th>
</tr>
<tr>
<th style="text-align: center; padding: 12px 20px">cgF1</th>
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">AP</th>
<th style="text-align: center; border-right: 2px solid #ddd; padding: 12px 20px">cgF1</th>
<th style="text-align: center; padding: 12px 20px">cgF1</th>
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">AP</th>
<th style="text-align: center; padding: 12px 20px">AP</th>
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">AP<sub>o</sub>
</th>
<th style="text-align: center; padding: 12px 20px">cgF1</th>
</tr>
</thead>
<tbody>
<tr>
<td style="border-right: 2px solid #ddd; padding: 10px 20px">Human</td>
<td style="text-align: center; padding: 10px 20px">-</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
<td style="text-align: center; border-right: 2px solid #ddd; padding: 10px 20px">72.8</td>
<td style="text-align: center; padding: 10px 20px">-</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
<td style="text-align: center; padding: 10px 20px">-</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
<td style="text-align: center; padding: 10px 20px">74.0</td>
</tr>
<tr>
<td style="border-right: 2px solid #ddd; padding: 10px 20px">OWLv2*</td>
<td style="text-align: center; padding: 10px 20px; color: #999">29.3</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px; color: #999">43.4</td>
<td style="text-align: center; border-right: 2px solid #ddd; padding: 10px 20px">24.6</td>
<td style="text-align: center; padding: 10px 20px; color: #999">30.2</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px; color: #999">45.5</td>
<td style="text-align: center; padding: 10px 20px">46.1</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">23.9</td>
<td style="text-align: center; padding: 10px 20px">24.5</td>
</tr>
<tr>
<td style="border-right: 2px solid #ddd; padding: 10px 20px">DINO-X</td>
<td style="text-align: center; padding: 10px 20px">-</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">38.5</td>
<td style="text-align: center; border-right: 2px solid #ddd; padding: 10px 20px">21.3</td>
<td style="text-align: center; padding: 10px 20px">-</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">52.4</td>
<td style="text-align: center; padding: 10px 20px">56.0</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
<td style="text-align: center; padding: 10px 20px">22.5</td>
</tr>
<tr>
<td style="border-right: 2px solid #ddd; padding: 10px 20px">Gemini 2.5</td>
<td style="text-align: center; padding: 10px 20px">13.4</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
<td style="text-align: center; border-right: 2px solid #ddd; padding: 10px 20px">13.0</td>
<td style="text-align: center; padding: 10px 20px">16.1</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
<td style="text-align: center; padding: 10px 20px">-</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
<td style="text-align: center; padding: 10px 20px">14.4</td>
</tr>
<tr style="border-top: 2px solid #b19c9cff">
<td style="border-right: 2px solid #ddd; padding: 10px 20px">SAM 3</td>
<td style="text-align: center; padding: 10px 20px">37.2</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">48.5</td>
<td style="text-align: center; border-right: 2px solid #ddd; padding: 10px 20px">54.1</td>
<td style="text-align: center; padding: 10px 20px">40.6</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">53.6</td>
<td style="text-align: center; padding: 10px 20px">56.4</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">55.7</td>
<td style="text-align: center; padding: 10px 20px">55.7</td>
</tr>
</tbody>
</table>
<p style="text-align: center; margin-top: 10px; font-size: 0.9em; color: #ddd;">* Partially trained on LVIS, AP<sub>o</sub> refers to COCO-O accuracy</p>
</div>
## Video Results
<div align="center">
<table style="min-width: 80%; border: 2px solid #ddd; border-collapse: collapse">
<thead>
<tr>
<th rowspan="2" style="border-right: 2px solid #ddd; padding: 12px 20px">Model</th>
<th colspan="2" style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">SA-V test</th>
<th colspan="2" style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">YT-Temporal-1B test</th>
<th colspan="2" style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">SmartGlasses test</th>
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">LVVIS test</th>
<th style="text-align: center; padding: 12px 20px">BURST test</th>
</tr>
<tr>
<th style="text-align: center; padding: 12px 20px">cgF1</th>
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">pHOTA</th>
<th style="text-align: center; padding: 12px 20px">cgF1</th>
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">pHOTA</th>
<th style="text-align: center; padding: 12px 20px">cgF1</th>
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">pHOTA</th>
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">mAP</th>
<th style="text-align: center; padding: 12px 20px">HOTA</th>
</tr>
</thead>
<tbody>
<tr>
<td style="border-right: 2px solid #ddd; padding: 10px 20px">Human</td>
<td style="text-align: center; padding: 10px 20px">53.1</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">70.5</td>
<td style="text-align: center; padding: 10px 20px">71.2</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">78.4</td>
<td style="text-align: center; padding: 10px 20px">58.5</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">72.3</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
<td style="text-align: center; padding: 10px 20px">-</td>
</tr>
<tr style="border-top: 2px solid #b19c9cff">
<td style="border-right: 2px solid #ddd; padding: 10px 20px">SAM 3</td>
<td style="text-align: center; padding: 10px 20px">30.3</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">58.0</td>
<td style="text-align: center; padding: 10px 20px">50.8</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">69.9</td>
<td style="text-align: center; padding: 10px 20px">36.4</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">63.6</td>
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">36.3</td>
<td style="text-align: center; padding: 10px 20px">44.5</td>
</tr>
</tbody>
</table>
</div>
## SA-Co Dataset
We release 2 image benchmarks, [SA-Co/Gold](scripts/eval/gold/README.md) and
[SA-Co/Silver](scripts/eval/silver/README.md), and a video benchmark
[SA-Co/VEval](scripts/eval/veval/README.md). The datasets contain images (or videos) with annotated noun phrases. Each image/video and noun phrase pair is annotated with instance masks and unique IDs of each object matching the phrase. Phrases that have no matching objects (negative prompts) have no masks, shown in red font in the figure. See the linked READMEs for more details on how to download and run evaluations on the datasets.
* HuggingFace host: [SA-Co/Gold](https://huggingface.co/datasets/facebook/SACo-Gold), [SA-Co/Silver](https://huggingface.co/datasets/facebook/SACo-Silver) and [SA-Co/VEval](https://huggingface.co/datasets/facebook/SACo-VEval)
* Roboflow host: [SA-Co/Gold](https://universe.roboflow.com/sa-co-gold), [SA-Co/Silver](https://universe.roboflow.com/sa-co-silver) and [SA-Co/VEval](https://universe.roboflow.com/sa-co-veval)
![SA-Co dataset](assets/sa_co_dataset.jpg?raw=true)
## Development
To set up the development environment:
```bash
pip install -e ".[dev,train]"
```
To format the code:
```bash
ufmt format .
```
## Contributing
See [contributing](CONTRIBUTING.md) and the
[code of conduct](CODE_OF_CONDUCT.md).
## License
This project is licensed under the SAM License - see the [LICENSE](LICENSE) file
for details.
## Acknowledgements
We would like to thank the following people for their contributions to the SAM 3 project: Alex He, Alexander Kirillov,
Alyssa Newcomb, Ana Paula Kirschner Mofarrej, Andrea Madotto, Andrew Westbury, Ashley Gabriel, Azita Shokpour,
Ben Samples, Bernie Huang, Carleigh Wood, Ching-Feng Yeh, Christian Puhrsch, Claudette Ward, Daniel Bolya,
Daniel Li, Facundo Figueroa, Fazila Vhora, George Orlin, Hanzi Mao, Helen Klein, Hu Xu, Ida Cheng, Jake Kinney,
Jiale Zhi, Jo Sampaio, Joel Schlosser, Justin Johnson, Kai Brown, Karen Bergan, Karla Martucci, Kenny Lehmann,
Maddie Mintz, Mallika Malhotra, Matt Ward, Michelle Chan, Michelle Restrepo, Miranda Hartley, Muhammad Maaz,
Nisha Deo, Peter Park, Phillip Thomas, Raghu Nayani, Rene Martinez Doehner, Robbie Adkins, Ross Girshick, Sasha
Mitts, Shashank Jain, Spencer Whitehead, Ty Toledano, Valentin Gabeur, Vincent Cho, Vivian Lee, William Ngan,
Xuehai He, Yael Yungster, Ziqi Pang, Ziyi Dou, Zoe Quake.
## Citing SAM 3
If you use SAM 3 or the SA-Co dataset in your research, please use the following BibTeX entry.
```bibtex
@misc{carion2025sam3segmentconcepts,
title={SAM 3: Segment Anything with Concepts},
author={Nicolas Carion and Laura Gustafson and Yuan-Ting Hu and Shoubhik Debnath and Ronghang Hu and Didac Suris and Chaitanya Ryali and Kalyan Vasudev Alwala and Haitham Khedr and Andrew Huang and Jie Lei and Tengyu Ma and Baishan Guo and Arpit Kalla and Markus Marks and Joseph Greer and Meng Wang and Peize Sun and Roman Rädle and Triantafyllos Afouras and Effrosyni Mavroudi and Katherine Xu and Tsung-Han Wu and Yu Zhou and Liliane Momeni and Rishi Hazra and Shuangrui Ding and Sagar Vaze and Francois Porcher and Feng Li and Siyuan Li and Aishwarya Kamath and Ho Kei Cheng and Piotr Dollár and Nikhila Ravi and Kate Saenko and Pengchuan Zhang and Christoph Feichtenhofer},
year={2025},
eprint={2511.16719},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2511.16719},
}
```

190
sam3/README_TRAIN.md Normal file
View File

@ -0,0 +1,190 @@
# Training
This repository supports finetuning SAM3 models on custom datasets, either locally or in a multi-node cluster setup. The training script is located at `sam3/train/train.py` and uses Hydra configuration management to handle complex training setups.
## Installation
```bash
cd sam3
pip install -e ".[train]"
```
### Training Script Usage
The main training script is located at `sam3/train/train.py`. It uses Hydra configuration management to handle complex training setups.
#### Basic Usage
```bash
# Example: Train on Roboflow dataset
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml
# Example: Train on ODinW13 dataset
python sam3/train/train.py -c configs/odinw13/odinw_text_only_train.yaml
```
Follow [`Roboflow 100-VL`](https://github.com/roboflow/rf100-vl/) to download the Roboflow 100-VL datasets, and follow [`GLIP`](https://github.com/microsoft/GLIP) to download the ODinW datasets. Organize the data folders as shown below, then set `roboflow_vl_100_root` and `odinw_data_root` in the job configs to point to them.
```
roboflow_vl_100_root:
13-lkc01
train
valid
test
2024-frc
actions
...
odinw_data_root:
AerialMaritimeDrone
large
train
valid
test
Aquarium
...
```
#### Command Line Arguments
The training script supports several command line arguments:
```bash
python sam3/train/train.py \
-c CONFIG_NAME \
[--use-cluster 0|1] \
[--partition PARTITION_NAME] \
[--account ACCOUNT_NAME] \
[--qos QOS_NAME] \
[--num-gpus NUM_GPUS] \
[--num-nodes NUM_NODES]
```
**Arguments:**
- `-c, --config`: **Required.** Path to the configuration file (e.g., `configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml`)
- `--use-cluster`: Whether to launch on a cluster (0: local, 1: cluster). Default: uses config setting
- `--partition`: SLURM partition name for cluster execution
- `--account`: SLURM account name for cluster execution
- `--qos`: SLURM QOS (Quality of Service) setting
- `--num-gpus`: Number of GPUs per node. Default: uses config setting
- `--num-nodes`: Number of nodes for distributed training. Default: uses config setting
#### Local Training Examples
```bash
# Single GPU training
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml --use-cluster 0 --num-gpus 1
# Multi-GPU training on a single node
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml --use-cluster 0 --num-gpus 4
# Force local execution even if the config enables cluster execution
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml --use-cluster 0
```
#### Cluster Training Examples
```bash
# Basic cluster training with default settings from config
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml --use-cluster 1
# Cluster training with specific SLURM settings
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml \
--use-cluster 1 \
--partition gpu_partition \
--account my_account \
--qos high_priority \
--num-gpus 8 \
--num-nodes 2
```
### Configuration Files
Training configurations are stored in `sam3/train/configs/`. The configuration files use Hydra's YAML format and support:
- **Dataset Configuration**: Data paths, transforms, and loading parameters
- **Model Configuration**: Architecture settings, checkpoint paths, and model parameters
- **Training Configuration**: Batch sizes, learning rates, optimization settings
- **Launcher Configuration**: Distributed training and cluster settings
- **Logging Configuration**: TensorBoard, experiment tracking, and output directories
#### Key Configuration Sections
```yaml
# Paths to datasets and checkpoints
paths:
bpe_path: /path/to/bpe/file
dataset_root: /path/to/dataset
experiment_log_dir: /path/to/logs
# Launcher settings for local/cluster execution
launcher:
num_nodes: 1
gpus_per_node: 2
experiment_log_dir: ${paths.experiment_log_dir}
# Cluster execution settings
submitit:
use_cluster: True
timeout_hour: 72
cpus_per_task: 10
partition: null
account: null
```
### Monitoring Training
The training script automatically sets up logging and saves outputs to the experiment directory:
```bash
# Logs are saved to the experiment_log_dir specified in config
experiment_log_dir/
├── config.yaml # Original configuration
├── config_resolved.yaml # Resolved configuration with all variables expanded
├── checkpoints/ # Model checkpoints (if skip_checkpointing=False)
├── tensorboard/ # TensorBoard logs
├── logs/ # Text logs
└── submitit_logs/ # Cluster job logs (if using cluster)
```
You can monitor training progress using TensorBoard:
```bash
tensorboard --logdir /path/to/experiment_log_dir/tensorboard
```
### Job Arrays for Dataset Sweeps
The Roboflow and ODinW configurations support job arrays for training multiple models on different datasets.
This feature is enabled via the following settings:
```yaml
submitit:
job_array:
num_tasks: 100
task_index: 0
```
The configuration includes a complete list of 100 Roboflow supercategories, and the `submitit.job_array.task_index` automatically selects which dataset to use based on the array job index.
```bash
# Submit job array to train on different Roboflow datasets
# The job array index selects which dataset from all_roboflow_supercategories
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml \
--use-cluster 1
```
### Reproduce ODinW13 10-shot results
Running the following job gives the ODinW13 results for seed 300 (see `odinw_train.train_file: fewshot_train_shot10_seed300` in the config file).
```bash
# Example: Train on ODinW13 dataset
python sam3/train/train.py -c configs/odinw13/odinw_text_only_train.yaml
```
Change `odinw_train.train_file` to `fewshot_train_shot10_seed30` and `fewshot_train_shot10_seed3` to get the results for the other two seeds. Final results are aggregated from the three seeds. Notice that a small number of jobs may diverge during training, in which case we just use the last checkpoint's result before it diverges.
### Eval Script Usage
With a setup similar to the training config, the training script `sam3/train/train.py` can also be used for evaluation by setting `trainer.mode = val` in the job config. Running the following jobs gives the zero-shot results on the RF100-VL and ODinW13 datasets.
```bash
# Example: Evaluate on Roboflow dataset
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_eval.yaml
# Example: Evaluate on ODinW13 dataset
python sam3/train/train.py -c configs/odinw13/odinw_text_only.yaml
```

BIN
sam3/assets/dog.gif Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.8 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 164 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 265 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 707 KiB

BIN
sam3/assets/player.gif Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.2 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 991 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.8 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 141 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 138 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 134 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 112 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 114 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 111 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 111 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 110 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 112 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 110 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 112 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 111 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 114 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 136 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 113 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 113 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 112 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 113 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 111 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 110 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 110 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 109 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 107 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 105 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 134 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 106 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 104 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 104 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 106 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 108 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 107 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 109 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 105 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 104 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 102 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 136 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 104 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 102 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 103 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 104 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 106 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 103 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 103 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 101 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 102 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 100 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 134 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 99 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 101 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 101 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 103 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 103 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 104 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 102 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 101 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 99 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 97 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 133 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 98 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 99 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 102 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 104 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 107 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 108 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 108 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 109 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 106 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 103 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 131 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 102 KiB

Some files were not shown because too many files have changed in this diff Show More