增加sam3,集成并通讯成功
@ -1,54 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
|
||||||
# NO CHECKED-IN PROTOBUF GENCODE
|
|
||||||
# source: grpc_proto/check_grpc/check_grpc.proto
|
|
||||||
# Protobuf Python Version: 6.31.1
|
|
||||||
"""Generated protocol buffer code."""
|
|
||||||
from google.protobuf import descriptor as _descriptor
|
|
||||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
|
||||||
from google.protobuf import runtime_version as _runtime_version
|
|
||||||
from google.protobuf import symbol_database as _symbol_database
|
|
||||||
from google.protobuf.internal import builder as _builder
|
|
||||||
_runtime_version.ValidateProtobufRuntimeVersion(
|
|
||||||
_runtime_version.Domain.PUBLIC,
|
|
||||||
6,
|
|
||||||
31,
|
|
||||||
1,
|
|
||||||
'',
|
|
||||||
'grpc_proto/check_grpc/check_grpc.proto'
|
|
||||||
)
|
|
||||||
# @@protoc_insertion_point(imports)
|
|
||||||
|
|
||||||
_sym_db = _symbol_database.Default()
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n&grpc_proto/check_grpc/check_grpc.proto\x12\x04task\"%\n\x12HealthCheckRequest\x12\x0f\n\x07service\x18\x01 \x01(\t\"\x9f\x01\n\x13HealthCheckResponse\x12\x37\n\x06status\x18\x01 \x01(\x0e\x32\'.task.HealthCheckResponse.ServingStatus\"O\n\rServingStatus\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x0b\n\x07SERVING\x10\x01\x12\x0f\n\x0bNOT_SERVING\x10\x02\x12\x13\n\x0fSERVICE_UNKNOWN\x10\x03\"S\n\x0bTaskRequest\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\n\n\x02sn\x18\x02 \x01(\t\x12\'\n\x0c\x63ontent_body\x18\x03 \x01(\x0b\x32\x11.task.ContentBody\"\xab\x01\n\x0b\x43ontentBody\x12\x10\n\x08org_code\x18\x01 \x01(\t\x12\x0f\n\x07\x66unc_id\x18\x02 \x03(\x05\x12\x12\n\nsource_url\x18\x03 \x01(\t\x12\x10\n\x08push_url\x18\x04 \x01(\t\x12\x12\n\nconfidence\x18\x05 \x01(\x02\x12!\n\tpara_list\x18\x06 \x03(\x0b\x32\x0e.task.ParaList\x12\x1c\n\x06invade\x18\x07 \x01(\x0b\x32\x0c.task.Invade\"7\n\x08ParaList\x12\x0f\n\x07\x66unc_id\x18\x01 \x01(\x05\x12\x1a\n\x12para_invade_enable\x18\x02 \x01(\x08\"6\n\x06Invade\x12\x13\n\x0binvade_file\x18\x01 \x01(\t\x12\x17\n\x0f\x63\x61mera_para_url\x18\x02 \x01(\t\"A\n\x0cTaskResponse\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x0f\n\x07message\x18\x03 \x01(\t2C\n\x0bTaskService\x12\x34\n\x0bProcessTask\x12\x11.task.TaskRequest\x1a\x12.task.TaskResponse2K\n\x0bHealthCheck\x12<\n\x05\x43heck\x12\x18.task.HealthCheckRequest\x1a\x19.task.HealthCheckResponseb\x06proto3')
|
|
||||||
|
|
||||||
_globals = globals()
|
|
||||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
|
||||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'grpc_proto.check_grpc.check_grpc_pb2', _globals)
|
|
||||||
if not _descriptor._USE_C_DESCRIPTORS:
|
|
||||||
DESCRIPTOR._loaded_options = None
|
|
||||||
_globals['_HEALTHCHECKREQUEST']._serialized_start=48
|
|
||||||
_globals['_HEALTHCHECKREQUEST']._serialized_end=85
|
|
||||||
_globals['_HEALTHCHECKRESPONSE']._serialized_start=88
|
|
||||||
_globals['_HEALTHCHECKRESPONSE']._serialized_end=247
|
|
||||||
_globals['_HEALTHCHECKRESPONSE_SERVINGSTATUS']._serialized_start=168
|
|
||||||
_globals['_HEALTHCHECKRESPONSE_SERVINGSTATUS']._serialized_end=247
|
|
||||||
_globals['_TASKREQUEST']._serialized_start=249
|
|
||||||
_globals['_TASKREQUEST']._serialized_end=332
|
|
||||||
_globals['_CONTENTBODY']._serialized_start=335
|
|
||||||
_globals['_CONTENTBODY']._serialized_end=506
|
|
||||||
_globals['_PARALIST']._serialized_start=508
|
|
||||||
_globals['_PARALIST']._serialized_end=563
|
|
||||||
_globals['_INVADE']._serialized_start=565
|
|
||||||
_globals['_INVADE']._serialized_end=619
|
|
||||||
_globals['_TASKRESPONSE']._serialized_start=621
|
|
||||||
_globals['_TASKRESPONSE']._serialized_end=686
|
|
||||||
_globals['_TASKSERVICE']._serialized_start=688
|
|
||||||
_globals['_TASKSERVICE']._serialized_end=755
|
|
||||||
_globals['_HEALTHCHECK']._serialized_start=757
|
|
||||||
_globals['_HEALTHCHECK']._serialized_end=832
|
|
||||||
# @@protoc_insertion_point(module_scope)
|
|
||||||
76
grpc_util/grpc_proto_demo/async_check_grpc_client.py
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
import asyncio
|
||||||
|
import grpc
|
||||||
|
from grpc_util.grpc_proto_demo.check_grpc import check_grpc_pb2_grpc, check_grpc_pb2
|
||||||
|
|
||||||
|
|
||||||
|
async def async_check_server_status(channel):
|
||||||
|
try:
|
||||||
|
health_stub = check_grpc_pb2_grpc.HealthCheckStub(channel)
|
||||||
|
# 注意:这里不需要await,因为Check方法不是异步的
|
||||||
|
response = health_stub.Check(check_grpc_pb2.HealthCheckRequest(service="TaskService"))
|
||||||
|
return response.status == check_grpc_pb2.HealthCheckResponse.ServingStatus.SERVING
|
||||||
|
except grpc.RpcError as e:
|
||||||
|
print(f"Health check failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def async_check_grpc_request(max_retries=3, delay=5):
|
||||||
|
retries = 0
|
||||||
|
channel = None
|
||||||
|
while retries < max_retries:
|
||||||
|
try:
|
||||||
|
channel = grpc.insecure_channel('localhost:50051')
|
||||||
|
|
||||||
|
if not await async_check_server_status(channel):
|
||||||
|
raise Exception("Server is not healthy")
|
||||||
|
|
||||||
|
stub = check_grpc_pb2_grpc.TaskServiceStub(channel)
|
||||||
|
request = check_grpc_pb2.TaskRequest(
|
||||||
|
task_id="d6118954-a170-4e1c-84bd-ddbd3114b354",
|
||||||
|
sn="8UUXN6S00A0CK7",
|
||||||
|
content_body=check_grpc_pb2.ContentBody(
|
||||||
|
org_code="HMZHB",
|
||||||
|
func_id=[101204],
|
||||||
|
source_url="xxxxxxxxxx",
|
||||||
|
push_url="",
|
||||||
|
confidence=0.4,
|
||||||
|
para_list=[
|
||||||
|
check_grpc_pb2.ParaList(
|
||||||
|
func_id=101204,
|
||||||
|
para_invade_enable=True
|
||||||
|
)
|
||||||
|
],
|
||||||
|
invade=check_grpc_pb2.Invade(
|
||||||
|
invade_file="meta_data/高压线-0826.geojson",
|
||||||
|
camera_para_url="meta_data/camera_para/hami_camera_para .txt"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# ProcessTask可能也不是异步的,所以不需要await
|
||||||
|
response = stub.ProcessTask(request)
|
||||||
|
print(f"Response: task_id={response.task_id}, success={response.success}, message={response.message}")
|
||||||
|
|
||||||
|
if channel:
|
||||||
|
channel.close() # 同步关闭,不需要await
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except grpc.RpcError as e:
|
||||||
|
retries += 1
|
||||||
|
print(f"RPC error occurred (attempt {retries}/{max_retries}): {e}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error occurred: {e}")
|
||||||
|
retries += 1
|
||||||
|
|
||||||
|
if channel:
|
||||||
|
channel.close() # 确保在重试前关闭连接
|
||||||
|
channel = None
|
||||||
|
|
||||||
|
await asyncio.sleep(delay) # 异步等待
|
||||||
|
|
||||||
|
print("All retry attempts failed")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
asyncio.run(async_check_grpc_request())
|
||||||
54
grpc_util/grpc_proto_demo/check_grpc/check_grpc_pb2.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||||
|
# NO CHECKED-IN PROTOBUF GENCODE
|
||||||
|
# source: check_grpc.proto
|
||||||
|
# Protobuf Python Version: 6.31.1
|
||||||
|
"""Generated protocol buffer code."""
|
||||||
|
from google.protobuf import descriptor as _descriptor
|
||||||
|
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||||
|
from google.protobuf import runtime_version as _runtime_version
|
||||||
|
from google.protobuf import symbol_database as _symbol_database
|
||||||
|
from google.protobuf.internal import builder as _builder
|
||||||
|
_runtime_version.ValidateProtobufRuntimeVersion(
|
||||||
|
_runtime_version.Domain.PUBLIC,
|
||||||
|
6,
|
||||||
|
31,
|
||||||
|
1,
|
||||||
|
'',
|
||||||
|
'check_grpc.proto'
|
||||||
|
)
|
||||||
|
# @@protoc_insertion_point(imports)
|
||||||
|
|
||||||
|
_sym_db = _symbol_database.Default()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x63heck_grpc.proto\x12\x04task\"%\n\x12HealthCheckRequest\x12\x0f\n\x07service\x18\x01 \x01(\t\"\x9f\x01\n\x13HealthCheckResponse\x12\x37\n\x06status\x18\x01 \x01(\x0e\x32\'.task.HealthCheckResponse.ServingStatus\"O\n\rServingStatus\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x0b\n\x07SERVING\x10\x01\x12\x0f\n\x0bNOT_SERVING\x10\x02\x12\x13\n\x0fSERVICE_UNKNOWN\x10\x03\"S\n\x0bTaskRequest\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\n\n\x02sn\x18\x02 \x01(\t\x12\'\n\x0c\x63ontent_body\x18\x03 \x01(\x0b\x32\x11.task.ContentBody\"\xab\x01\n\x0b\x43ontentBody\x12\x10\n\x08org_code\x18\x01 \x01(\t\x12\x0f\n\x07\x66unc_id\x18\x02 \x03(\x05\x12\x12\n\nsource_url\x18\x03 \x01(\t\x12\x10\n\x08push_url\x18\x04 \x01(\t\x12\x12\n\nconfidence\x18\x05 \x01(\x02\x12!\n\tpara_list\x18\x06 \x03(\x0b\x32\x0e.task.ParaList\x12\x1c\n\x06invade\x18\x07 \x01(\x0b\x32\x0c.task.Invade\"7\n\x08ParaList\x12\x0f\n\x07\x66unc_id\x18\x01 \x01(\x05\x12\x1a\n\x12para_invade_enable\x18\x02 \x01(\x08\"6\n\x06Invade\x12\x13\n\x0binvade_file\x18\x01 \x01(\t\x12\x17\n\x0f\x63\x61mera_para_url\x18\x02 \x01(\t\"A\n\x0cTaskResponse\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x0f\n\x07message\x18\x03 \x01(\t2C\n\x0bTaskService\x12\x34\n\x0bProcessTask\x12\x11.task.TaskRequest\x1a\x12.task.TaskResponse2K\n\x0bHealthCheck\x12<\n\x05\x43heck\x12\x18.task.HealthCheckRequest\x1a\x19.task.HealthCheckResponseb\x06proto3')
|
||||||
|
|
||||||
|
_globals = globals()
|
||||||
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||||
|
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'check_grpc_pb2', _globals)
|
||||||
|
if not _descriptor._USE_C_DESCRIPTORS:
|
||||||
|
DESCRIPTOR._loaded_options = None
|
||||||
|
_globals['_HEALTHCHECKREQUEST']._serialized_start=26
|
||||||
|
_globals['_HEALTHCHECKREQUEST']._serialized_end=63
|
||||||
|
_globals['_HEALTHCHECKRESPONSE']._serialized_start=66
|
||||||
|
_globals['_HEALTHCHECKRESPONSE']._serialized_end=225
|
||||||
|
_globals['_HEALTHCHECKRESPONSE_SERVINGSTATUS']._serialized_start=146
|
||||||
|
_globals['_HEALTHCHECKRESPONSE_SERVINGSTATUS']._serialized_end=225
|
||||||
|
_globals['_TASKREQUEST']._serialized_start=227
|
||||||
|
_globals['_TASKREQUEST']._serialized_end=310
|
||||||
|
_globals['_CONTENTBODY']._serialized_start=313
|
||||||
|
_globals['_CONTENTBODY']._serialized_end=484
|
||||||
|
_globals['_PARALIST']._serialized_start=486
|
||||||
|
_globals['_PARALIST']._serialized_end=541
|
||||||
|
_globals['_INVADE']._serialized_start=543
|
||||||
|
_globals['_INVADE']._serialized_end=597
|
||||||
|
_globals['_TASKRESPONSE']._serialized_start=599
|
||||||
|
_globals['_TASKRESPONSE']._serialized_end=664
|
||||||
|
_globals['_TASKSERVICE']._serialized_start=666
|
||||||
|
_globals['_TASKSERVICE']._serialized_end=733
|
||||||
|
_globals['_HEALTHCHECK']._serialized_start=735
|
||||||
|
_globals['_HEALTHCHECK']._serialized_end=810
|
||||||
|
# @@protoc_insertion_point(module_scope)
|
||||||
@ -3,7 +3,7 @@
|
|||||||
import grpc
|
import grpc
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from grpc_proto.check_grpc import check_grpc_pb2 as grpc__proto_dot_check__grpc_dot_check__grpc__pb2
|
import grpc_util.grpc_proto_demo.check_grpc.check_grpc_pb2 as check__grpc__pb2
|
||||||
|
|
||||||
GRPC_GENERATED_VERSION = '1.76.0'
|
GRPC_GENERATED_VERSION = '1.76.0'
|
||||||
GRPC_VERSION = grpc.__version__
|
GRPC_VERSION = grpc.__version__
|
||||||
@ -18,7 +18,7 @@ except ImportError:
|
|||||||
if _version_not_supported:
|
if _version_not_supported:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f'The grpc package installed is at version {GRPC_VERSION},'
|
f'The grpc package installed is at version {GRPC_VERSION},'
|
||||||
+ ' but the generated code in grpc_proto/check_grpc/check_grpc_pb2_grpc.py depends on'
|
+ ' but the generated code in check_grpc_pb2_grpc.py depends on'
|
||||||
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
|
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
|
||||||
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
|
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
|
||||||
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
|
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
|
||||||
@ -36,8 +36,8 @@ class TaskServiceStub(object):
|
|||||||
"""
|
"""
|
||||||
self.ProcessTask = channel.unary_unary(
|
self.ProcessTask = channel.unary_unary(
|
||||||
'/task.TaskService/ProcessTask',
|
'/task.TaskService/ProcessTask',
|
||||||
request_serializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.TaskRequest.SerializeToString,
|
request_serializer=check__grpc__pb2.TaskRequest.SerializeToString,
|
||||||
response_deserializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.TaskResponse.FromString,
|
response_deserializer=check__grpc__pb2.TaskResponse.FromString,
|
||||||
_registered_method=True)
|
_registered_method=True)
|
||||||
|
|
||||||
|
|
||||||
@ -55,8 +55,8 @@ def add_TaskServiceServicer_to_server(servicer, server):
|
|||||||
rpc_method_handlers = {
|
rpc_method_handlers = {
|
||||||
'ProcessTask': grpc.unary_unary_rpc_method_handler(
|
'ProcessTask': grpc.unary_unary_rpc_method_handler(
|
||||||
servicer.ProcessTask,
|
servicer.ProcessTask,
|
||||||
request_deserializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.TaskRequest.FromString,
|
request_deserializer=check__grpc__pb2.TaskRequest.FromString,
|
||||||
response_serializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.TaskResponse.SerializeToString,
|
response_serializer=check__grpc__pb2.TaskResponse.SerializeToString,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
generic_handler = grpc.method_handlers_generic_handler(
|
generic_handler = grpc.method_handlers_generic_handler(
|
||||||
@ -84,8 +84,8 @@ class TaskService(object):
|
|||||||
request,
|
request,
|
||||||
target,
|
target,
|
||||||
'/task.TaskService/ProcessTask',
|
'/task.TaskService/ProcessTask',
|
||||||
grpc__proto_dot_check__grpc_dot_check__grpc__pb2.TaskRequest.SerializeToString,
|
check__grpc__pb2.TaskRequest.SerializeToString,
|
||||||
grpc__proto_dot_check__grpc_dot_check__grpc__pb2.TaskResponse.FromString,
|
check__grpc__pb2.TaskResponse.FromString,
|
||||||
options,
|
options,
|
||||||
channel_credentials,
|
channel_credentials,
|
||||||
insecure,
|
insecure,
|
||||||
@ -109,8 +109,8 @@ class HealthCheckStub(object):
|
|||||||
"""
|
"""
|
||||||
self.Check = channel.unary_unary(
|
self.Check = channel.unary_unary(
|
||||||
'/task.HealthCheck/Check',
|
'/task.HealthCheck/Check',
|
||||||
request_serializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.HealthCheckRequest.SerializeToString,
|
request_serializer=check__grpc__pb2.HealthCheckRequest.SerializeToString,
|
||||||
response_deserializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.HealthCheckResponse.FromString,
|
response_deserializer=check__grpc__pb2.HealthCheckResponse.FromString,
|
||||||
_registered_method=True)
|
_registered_method=True)
|
||||||
|
|
||||||
|
|
||||||
@ -129,8 +129,8 @@ def add_HealthCheckServicer_to_server(servicer, server):
|
|||||||
rpc_method_handlers = {
|
rpc_method_handlers = {
|
||||||
'Check': grpc.unary_unary_rpc_method_handler(
|
'Check': grpc.unary_unary_rpc_method_handler(
|
||||||
servicer.Check,
|
servicer.Check,
|
||||||
request_deserializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.HealthCheckRequest.FromString,
|
request_deserializer=check__grpc__pb2.HealthCheckRequest.FromString,
|
||||||
response_serializer=grpc__proto_dot_check__grpc_dot_check__grpc__pb2.HealthCheckResponse.SerializeToString,
|
response_serializer=check__grpc__pb2.HealthCheckResponse.SerializeToString,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
generic_handler = grpc.method_handlers_generic_handler(
|
generic_handler = grpc.method_handlers_generic_handler(
|
||||||
@ -159,8 +159,8 @@ class HealthCheck(object):
|
|||||||
request,
|
request,
|
||||||
target,
|
target,
|
||||||
'/task.HealthCheck/Check',
|
'/task.HealthCheck/Check',
|
||||||
grpc__proto_dot_check__grpc_dot_check__grpc__pb2.HealthCheckRequest.SerializeToString,
|
check__grpc__pb2.HealthCheckRequest.SerializeToString,
|
||||||
grpc__proto_dot_check__grpc_dot_check__grpc__pb2.HealthCheckResponse.FromString,
|
check__grpc__pb2.HealthCheckResponse.FromString,
|
||||||
options,
|
options,
|
||||||
channel_credentials,
|
channel_credentials,
|
||||||
insecure,
|
insecure,
|
||||||
@ -1,7 +1,9 @@
|
|||||||
import grpc
|
import grpc
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from grpc_proto.check_grpc import check_grpc_pb2_grpc, check_grpc_pb2
|
from grpc_util.grpc_proto_demo.check_grpc import check_grpc_pb2_grpc, check_grpc_pb2
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def check_server_status(channel):
|
def check_server_status(channel):
|
||||||
@ -1,8 +1,7 @@
|
|||||||
from concurrent import futures
|
from concurrent import futures
|
||||||
import grpc
|
import grpc
|
||||||
import time
|
|
||||||
|
|
||||||
from grpc_proto.check_grpc import check_grpc_pb2_grpc, check_grpc_pb2
|
from grpc_util.grpc_proto_demo.check_grpc import check_grpc_pb2, check_grpc_pb2_grpc
|
||||||
|
|
||||||
|
|
||||||
class TaskServiceServicer(check_grpc_pb2_grpc.TaskServiceServicer):
|
class TaskServiceServicer(check_grpc_pb2_grpc.TaskServiceServicer):
|
||||||
@ -36,12 +35,21 @@ def serve():
|
|||||||
server.start()
|
server.start()
|
||||||
print("Server started, listening on port 50051...")
|
print("Server started, listening on port 50051...")
|
||||||
|
|
||||||
try:
|
# try:
|
||||||
while True:
|
# while True:
|
||||||
time.sleep(86400) # 保持运行
|
# time.sleep(86400) # 保持运行
|
||||||
except KeyboardInterrupt:
|
# except KeyboardInterrupt:
|
||||||
|
# server.stop(0)
|
||||||
|
# 信号处理
|
||||||
|
import signal
|
||||||
|
def shutdown_handler(signum, frame):
|
||||||
|
print(f"Received signal {signum}, shutting down...")
|
||||||
server.stop(0)
|
server.stop(0)
|
||||||
|
|
||||||
|
signal.signal(signal.SIGINT, shutdown_handler)
|
||||||
|
signal.signal(signal.SIGTERM, shutdown_handler)
|
||||||
|
|
||||||
|
server.wait_for_termination()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
serve()
|
serve()
|
||||||
8
grpc_util/grpc_proto_demo/readme
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
1、参考check_grpc.proto 手写 proto 相关结构
|
||||||
|
2、使用编译命令生成grpc通讯相关,命令如下
|
||||||
|
python -m grpc_tools.protoc --proto_path=proto_dir \
|
||||||
|
--python_out=gen_dir \
|
||||||
|
--grpc_python_out=gen_dir \
|
||||||
|
proto_dir/check_grpc.proto
|
||||||
|
只会生成requeset、response
|
||||||
|
3、拷贝当前文件下的check_grpc_client.py 、check_grpc_server.py,重写逻辑代码
|
||||||
0
grpc_util/grpc_sam3/__init__.py
Normal file
89
grpc_util/grpc_sam3/async_sam3_grpc_client.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
import asyncio
|
||||||
|
import grpc
|
||||||
|
# from grpc_util.grpc_sam3 import grpc_sam3_img_pb2_grpc, grpc_sam3_img_pb2
|
||||||
|
# from grpc_util.grpc_sam3.grpc_sam3_img_pb2_grpc import TaskServiceStub, HealthCheckStub
|
||||||
|
# from grpc_util.grpc_sam3.grpc_sam3_img_pb2 import TaskRequest, TaskResponse, HealthCheckRequest, HealthCheckResponse
|
||||||
|
# from . import grpc_sam3_img_pb2_grpc, grpc_sam3_img_pb2
|
||||||
|
from grpc_util.grpc_sam3 import grpc_sam3_img_pb2_grpc, grpc_sam3_img_pb2
|
||||||
|
|
||||||
|
|
||||||
|
async def async_check_server_status(channel):
|
||||||
|
try:
|
||||||
|
health_stub = grpc_sam3_img_pb2_grpc.HealthCheckStub(channel)
|
||||||
|
# 注意:这里不需要await,因为Check方法不是异步的
|
||||||
|
response = health_stub.Check(grpc_sam3_img_pb2.HealthCheckRequest(service="TaskService"))
|
||||||
|
return response.status == grpc_sam3_img_pb2.HealthCheckResponse.ServingStatus.SERVING
|
||||||
|
except grpc.RpcError as e:
|
||||||
|
print(f"Health check failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def grpc_sam3_pic_predict(
|
||||||
|
task_id:str,
|
||||||
|
sn:str,
|
||||||
|
img_url:str,
|
||||||
|
prompt:str,
|
||||||
|
confidence:float,
|
||||||
|
mqtt_ip:str,
|
||||||
|
mqtt_port:int,
|
||||||
|
mqtt_topic:str,
|
||||||
|
max_retries=3, delay=5):
|
||||||
|
retries = 0
|
||||||
|
channel = None
|
||||||
|
while retries < max_retries:
|
||||||
|
try:
|
||||||
|
channel = grpc.insecure_channel('0.0.0.0:50051')
|
||||||
|
|
||||||
|
if not await async_check_server_status(channel):
|
||||||
|
raise Exception("Server is not healthy")
|
||||||
|
|
||||||
|
stub = grpc_sam3_img_pb2_grpc.TaskServiceStub(channel)
|
||||||
|
request = grpc_sam3_img_pb2.TaskRequest(
|
||||||
|
task_id=task_id,
|
||||||
|
sn=sn,
|
||||||
|
content_body=grpc_sam3_img_pb2.ContentBody(
|
||||||
|
img_url=img_url,
|
||||||
|
prompt=prompt,
|
||||||
|
confidence=0.5,
|
||||||
|
mqtt_ip=mqtt_ip,
|
||||||
|
mqtt_port=mqtt_port,
|
||||||
|
mqtt_topic=mqtt_topic
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# ProcessTask可能也不是异步的,所以不需要await
|
||||||
|
response = stub.ProcessTask(request)
|
||||||
|
print(f"Response: task_id={response.task_id}, success={response.success}, message={response.message}")
|
||||||
|
|
||||||
|
if channel:
|
||||||
|
channel.close() # 同步关闭,不需要await
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except grpc.RpcError as e:
|
||||||
|
retries += 1
|
||||||
|
print(f"RPC error occurred (attempt {retries}/{max_retries}): {e}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error occurred: {e}")
|
||||||
|
retries += 1
|
||||||
|
|
||||||
|
if channel:
|
||||||
|
channel.close() # 确保在重试前关闭连接
|
||||||
|
channel = None
|
||||||
|
|
||||||
|
await asyncio.sleep(delay) # 异步等待
|
||||||
|
|
||||||
|
print("All retry attempts failed")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
asyncio.run(grpc_sam3_pic_predict(
|
||||||
|
task_id="d6118954-a170-4e1c-84bd-ddbd3114b354111",
|
||||||
|
sn="8UUXN6S00A0CK7",
|
||||||
|
img_url="demo/03.png",
|
||||||
|
prompt="cat",
|
||||||
|
confidence=0.5,
|
||||||
|
mqtt_ip="47.108.62.6",
|
||||||
|
mqtt_port=12503,
|
||||||
|
mqtt_topic="thing/product/ai/events",
|
||||||
|
))
|
||||||
48
grpc_util/grpc_sam3/grpc_sam3_img.proto
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package grpc_sam3_img;
|
||||||
|
|
||||||
|
service TaskService {
|
||||||
|
rpc ProcessTask (TaskRequest) returns (TaskResponse);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加健康检查服务
|
||||||
|
service HealthCheck {
|
||||||
|
rpc Check (HealthCheckRequest) returns (HealthCheckResponse);
|
||||||
|
}
|
||||||
|
|
||||||
|
message HealthCheckRequest {
|
||||||
|
string service = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message HealthCheckResponse {
|
||||||
|
enum ServingStatus {
|
||||||
|
UNKNOWN = 0;
|
||||||
|
SERVING = 1;
|
||||||
|
NOT_SERVING = 2;
|
||||||
|
SERVICE_UNKNOWN = 3;
|
||||||
|
}
|
||||||
|
ServingStatus status = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TaskRequest {
|
||||||
|
string task_id = 1;
|
||||||
|
string sn = 2;
|
||||||
|
ContentBody content_body = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ContentBody {
|
||||||
|
string img_url = 1;
|
||||||
|
string prompt = 2;
|
||||||
|
float confidence =3;
|
||||||
|
string mqtt_ip = 4;
|
||||||
|
int32 mqtt_port = 5;
|
||||||
|
string mqtt_topic = 6;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
message TaskResponse {
|
||||||
|
string task_id = 1;
|
||||||
|
bool success = 2;
|
||||||
|
string message = 3;
|
||||||
|
}
|
||||||
50
grpc_util/grpc_sam3/grpc_sam3_img_pb2.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||||
|
# NO CHECKED-IN PROTOBUF GENCODE
|
||||||
|
# source: grpc_sam3_img.proto
|
||||||
|
# Protobuf Python Version: 6.31.1
|
||||||
|
"""Generated protocol buffer code."""
|
||||||
|
from google.protobuf import descriptor as _descriptor
|
||||||
|
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||||
|
from google.protobuf import runtime_version as _runtime_version
|
||||||
|
from google.protobuf import symbol_database as _symbol_database
|
||||||
|
from google.protobuf.internal import builder as _builder
|
||||||
|
_runtime_version.ValidateProtobufRuntimeVersion(
|
||||||
|
_runtime_version.Domain.PUBLIC,
|
||||||
|
6,
|
||||||
|
31,
|
||||||
|
1,
|
||||||
|
'',
|
||||||
|
'grpc_sam3_img.proto'
|
||||||
|
)
|
||||||
|
# @@protoc_insertion_point(imports)
|
||||||
|
|
||||||
|
_sym_db = _symbol_database.Default()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13grpc_sam3_img.proto\x12\rgrpc_sam3_img\"%\n\x12HealthCheckRequest\x12\x0f\n\x07service\x18\x01 \x01(\t\"\xa8\x01\n\x13HealthCheckResponse\x12@\n\x06status\x18\x01 \x01(\x0e\x32\x30.grpc_sam3_img.HealthCheckResponse.ServingStatus\"O\n\rServingStatus\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x0b\n\x07SERVING\x10\x01\x12\x0f\n\x0bNOT_SERVING\x10\x02\x12\x13\n\x0fSERVICE_UNKNOWN\x10\x03\"\\\n\x0bTaskRequest\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\n\n\x02sn\x18\x02 \x01(\t\x12\x30\n\x0c\x63ontent_body\x18\x03 \x01(\x0b\x32\x1a.grpc_sam3_img.ContentBody\"z\n\x0b\x43ontentBody\x12\x0f\n\x07img_url\x18\x01 \x01(\t\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x12\n\nconfidence\x18\x03 \x01(\x02\x12\x0f\n\x07mqtt_ip\x18\x04 \x01(\t\x12\x11\n\tmqtt_port\x18\x05 \x01(\x05\x12\x12\n\nmqtt_topic\x18\x06 \x01(\t\"A\n\x0cTaskResponse\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x0f\n\x07message\x18\x03 \x01(\t2U\n\x0bTaskService\x12\x46\n\x0bProcessTask\x12\x1a.grpc_sam3_img.TaskRequest\x1a\x1b.grpc_sam3_img.TaskResponse2]\n\x0bHealthCheck\x12N\n\x05\x43heck\x12!.grpc_sam3_img.HealthCheckRequest\x1a\".grpc_sam3_img.HealthCheckResponseb\x06proto3')
|
||||||
|
|
||||||
|
_globals = globals()
|
||||||
|
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||||
|
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'grpc_sam3_img_pb2', _globals)
|
||||||
|
if not _descriptor._USE_C_DESCRIPTORS:
|
||||||
|
DESCRIPTOR._loaded_options = None
|
||||||
|
_globals['_HEALTHCHECKREQUEST']._serialized_start=38
|
||||||
|
_globals['_HEALTHCHECKREQUEST']._serialized_end=75
|
||||||
|
_globals['_HEALTHCHECKRESPONSE']._serialized_start=78
|
||||||
|
_globals['_HEALTHCHECKRESPONSE']._serialized_end=246
|
||||||
|
_globals['_HEALTHCHECKRESPONSE_SERVINGSTATUS']._serialized_start=167
|
||||||
|
_globals['_HEALTHCHECKRESPONSE_SERVINGSTATUS']._serialized_end=246
|
||||||
|
_globals['_TASKREQUEST']._serialized_start=248
|
||||||
|
_globals['_TASKREQUEST']._serialized_end=340
|
||||||
|
_globals['_CONTENTBODY']._serialized_start=342
|
||||||
|
_globals['_CONTENTBODY']._serialized_end=464
|
||||||
|
_globals['_TASKRESPONSE']._serialized_start=466
|
||||||
|
_globals['_TASKRESPONSE']._serialized_end=531
|
||||||
|
_globals['_TASKSERVICE']._serialized_start=533
|
||||||
|
_globals['_TASKSERVICE']._serialized_end=618
|
||||||
|
_globals['_HEALTHCHECK']._serialized_start=620
|
||||||
|
_globals['_HEALTHCHECK']._serialized_end=713
|
||||||
|
# @@protoc_insertion_point(module_scope)
|
||||||
172
grpc_util/grpc_sam3/grpc_sam3_img_pb2_grpc.py
Normal file
@ -0,0 +1,172 @@
|
|||||||
|
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||||
|
"""Client and server classes corresponding to protobuf-defined services."""
|
||||||
|
import grpc
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import grpc_util.grpc_sam3.grpc_sam3_img_pb2 as grpc__sam3__img__pb2
|
||||||
|
|
||||||
|
GRPC_GENERATED_VERSION = '1.76.0'
|
||||||
|
GRPC_VERSION = grpc.__version__
|
||||||
|
_version_not_supported = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
from grpc._utilities import first_version_is_lower
|
||||||
|
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
|
||||||
|
except ImportError:
|
||||||
|
_version_not_supported = True
|
||||||
|
|
||||||
|
if _version_not_supported:
|
||||||
|
raise RuntimeError(
|
||||||
|
f'The grpc package installed is at version {GRPC_VERSION},'
|
||||||
|
+ ' but the generated code in grpc_sam3_img_pb2_grpc.py depends on'
|
||||||
|
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
|
||||||
|
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
|
||||||
|
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TaskServiceStub(object):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
|
||||||
|
def __init__(self, channel):
|
||||||
|
"""Constructor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel: A grpc.Channel.
|
||||||
|
"""
|
||||||
|
self.ProcessTask = channel.unary_unary(
|
||||||
|
'/grpc_sam3_img.TaskService/ProcessTask',
|
||||||
|
request_serializer=grpc__sam3__img__pb2.TaskRequest.SerializeToString,
|
||||||
|
response_deserializer=grpc__sam3__img__pb2.TaskResponse.FromString,
|
||||||
|
_registered_method=True)
|
||||||
|
|
||||||
|
|
||||||
|
class TaskServiceServicer(object):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
|
||||||
|
def ProcessTask(self, request, context):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
|
||||||
|
def add_TaskServiceServicer_to_server(servicer, server):
|
||||||
|
rpc_method_handlers = {
|
||||||
|
'ProcessTask': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.ProcessTask,
|
||||||
|
request_deserializer=grpc__sam3__img__pb2.TaskRequest.FromString,
|
||||||
|
response_serializer=grpc__sam3__img__pb2.TaskResponse.SerializeToString,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
generic_handler = grpc.method_handlers_generic_handler(
|
||||||
|
'grpc_sam3_img.TaskService', rpc_method_handlers)
|
||||||
|
server.add_generic_rpc_handlers((generic_handler,))
|
||||||
|
server.add_registered_method_handlers('grpc_sam3_img.TaskService', rpc_method_handlers)
|
||||||
|
|
||||||
|
|
||||||
|
# This class is part of an EXPERIMENTAL API.
|
||||||
|
class TaskService(object):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def ProcessTask(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(
|
||||||
|
request,
|
||||||
|
target,
|
||||||
|
'/grpc_sam3_img.TaskService/ProcessTask',
|
||||||
|
grpc__sam3__img__pb2.TaskRequest.SerializeToString,
|
||||||
|
grpc__sam3__img__pb2.TaskResponse.FromString,
|
||||||
|
options,
|
||||||
|
channel_credentials,
|
||||||
|
insecure,
|
||||||
|
call_credentials,
|
||||||
|
compression,
|
||||||
|
wait_for_ready,
|
||||||
|
timeout,
|
||||||
|
metadata,
|
||||||
|
_registered_method=True)
|
||||||
|
|
||||||
|
|
||||||
|
class HealthCheckStub(object):
|
||||||
|
"""添加健康检查服务
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, channel):
|
||||||
|
"""Constructor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channel: A grpc.Channel.
|
||||||
|
"""
|
||||||
|
self.Check = channel.unary_unary(
|
||||||
|
'/grpc_sam3_img.HealthCheck/Check',
|
||||||
|
request_serializer=grpc__sam3__img__pb2.HealthCheckRequest.SerializeToString,
|
||||||
|
response_deserializer=grpc__sam3__img__pb2.HealthCheckResponse.FromString,
|
||||||
|
_registered_method=True)
|
||||||
|
|
||||||
|
|
||||||
|
class HealthCheckServicer(object):
|
||||||
|
"""添加健康检查服务
|
||||||
|
"""
|
||||||
|
|
||||||
|
def Check(self, request, context):
|
||||||
|
"""Missing associated documentation comment in .proto file."""
|
||||||
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||||
|
context.set_details('Method not implemented!')
|
||||||
|
raise NotImplementedError('Method not implemented!')
|
||||||
|
|
||||||
|
|
||||||
|
def add_HealthCheckServicer_to_server(servicer, server):
|
||||||
|
rpc_method_handlers = {
|
||||||
|
'Check': grpc.unary_unary_rpc_method_handler(
|
||||||
|
servicer.Check,
|
||||||
|
request_deserializer=grpc__sam3__img__pb2.HealthCheckRequest.FromString,
|
||||||
|
response_serializer=grpc__sam3__img__pb2.HealthCheckResponse.SerializeToString,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
generic_handler = grpc.method_handlers_generic_handler(
|
||||||
|
'grpc_sam3_img.HealthCheck', rpc_method_handlers)
|
||||||
|
server.add_generic_rpc_handlers((generic_handler,))
|
||||||
|
server.add_registered_method_handlers('grpc_sam3_img.HealthCheck', rpc_method_handlers)
|
||||||
|
|
||||||
|
|
||||||
|
# This class is part of an EXPERIMENTAL API.
|
||||||
|
class HealthCheck(object):
|
||||||
|
"""添加健康检查服务
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def Check(request,
|
||||||
|
target,
|
||||||
|
options=(),
|
||||||
|
channel_credentials=None,
|
||||||
|
call_credentials=None,
|
||||||
|
insecure=False,
|
||||||
|
compression=None,
|
||||||
|
wait_for_ready=None,
|
||||||
|
timeout=None,
|
||||||
|
metadata=None):
|
||||||
|
return grpc.experimental.unary_unary(
|
||||||
|
request,
|
||||||
|
target,
|
||||||
|
'/grpc_sam3_img.HealthCheck/Check',
|
||||||
|
grpc__sam3__img__pb2.HealthCheckRequest.SerializeToString,
|
||||||
|
grpc__sam3__img__pb2.HealthCheckResponse.FromString,
|
||||||
|
options,
|
||||||
|
channel_credentials,
|
||||||
|
insecure,
|
||||||
|
call_credentials,
|
||||||
|
compression,
|
||||||
|
wait_for_ready,
|
||||||
|
timeout,
|
||||||
|
metadata,
|
||||||
|
_registered_method=True)
|
||||||
72
grpc_util/grpc_sam3/sam3_grpc_client.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
import grpc
|
||||||
|
import time
|
||||||
|
|
||||||
|
from grpc_util.grpc_sam3 import grpc_sam3_img_pb2_grpc, grpc_sam3_img_pb2
|
||||||
|
|
||||||
|
|
||||||
|
def check_server_status(channel):
|
||||||
|
try:
|
||||||
|
health_stub = grpc_sam3_img_pb2_grpc.HealthCheckStub(channel)
|
||||||
|
response = health_stub.Check(grpc_sam3_img_pb2.HealthCheckRequest(service="TaskService"))
|
||||||
|
return response.status == grpc_sam3_img_pb2.HealthCheckResponse.ServingStatus.SERVING
|
||||||
|
except grpc.RpcError as e:
|
||||||
|
print(f"Health check failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def check_grpc_request(max_retries=3, delay=5):
|
||||||
|
channel = None
|
||||||
|
retries = 0
|
||||||
|
|
||||||
|
while retries < max_retries:
|
||||||
|
try:
|
||||||
|
# 创建通道
|
||||||
|
channel = grpc.insecure_channel('192.168.110.187:9999')
|
||||||
|
|
||||||
|
# 检查服务器状态
|
||||||
|
if not check_server_status(channel):
|
||||||
|
raise Exception("Server is not healthy")
|
||||||
|
|
||||||
|
stub = grpc_sam3_img_pb2_grpc.TaskServiceStub(channel)
|
||||||
|
|
||||||
|
# 创建请求消息
|
||||||
|
request = grpc_sam3_img_pb2.TaskRequest(
|
||||||
|
task_id="d6118954-a170-4e1c-84bd-ddbd3114b354111",
|
||||||
|
sn="8UUXN6S00A0CK7",
|
||||||
|
content_body=grpc_sam3_img_pb2.ContentBody(
|
||||||
|
img_url="demo/03.png",
|
||||||
|
prompt="cat",
|
||||||
|
confidence=0.5,
|
||||||
|
mqtt_ip="47.108.62.6",
|
||||||
|
mqtt_port=12503,
|
||||||
|
mqtt_topic="thing/product/ai/events"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 调用远程方法
|
||||||
|
response = stub.ProcessTask(request)
|
||||||
|
print(f"Response: task_id={response.task_id}, success={response.success}, message={response.message}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except grpc.RpcError as e:
|
||||||
|
retries += 1
|
||||||
|
print(f"RPC error occurred (attempt {retries}/{max_retries}): {e}")
|
||||||
|
if retries < max_retries:
|
||||||
|
print(f"Retrying in {delay} seconds...")
|
||||||
|
time.sleep(delay)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error occurred: {e}")
|
||||||
|
retries += 1
|
||||||
|
if retries < max_retries:
|
||||||
|
print(f"Retrying in {delay} seconds...")
|
||||||
|
time.sleep(delay)
|
||||||
|
finally:
|
||||||
|
if channel:
|
||||||
|
channel.close()
|
||||||
|
|
||||||
|
print("All retry attempts failed")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
check_grpc_request()
|
||||||
369
grpc_util/grpc_sam3/sam3_grpc_server.py
Normal file
@ -0,0 +1,369 @@
|
|||||||
|
import json
|
||||||
|
from concurrent import futures
|
||||||
|
import grpc
|
||||||
|
import threading
|
||||||
|
import queue
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
from typing import Dict, Optional
|
||||||
|
from grpc_util.grpc_sam3 import grpc_sam3_img_pb2_grpc, grpc_sam3_img_pb2
|
||||||
|
from middleware.MQTTService import MQTTService
|
||||||
|
from middleware.minio_util import downFile, upload_file
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from middleware.util import get_current_date_and_milliseconds
|
||||||
|
|
||||||
|
print(sys.executable)
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import sam3.sam3
|
||||||
|
from PIL import Image
|
||||||
|
from sam3.sam3 import build_sam3_image_model
|
||||||
|
from sam3.sam3 import build_sam3_image_model_0228
|
||||||
|
from sam3.sam3.model.box_ops import box_xywh_to_cxcywh
|
||||||
|
from sam3.sam3.model.sam3_image_processor import Sam3Processor
|
||||||
|
from sam3.sam3.visualization_utils import draw_box_on_image, normalize_bbox, plot_results,plot_results_savepic
|
||||||
|
|
||||||
|
sam3_root = os.path.join(os.path.dirname(sam3.sam3.__file__), "..")
|
||||||
|
import torch
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
|
||||||
|
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TaskQueue:
|
||||||
|
"""任务队列管理类"""
|
||||||
|
|
||||||
|
def __init__(self, max_size: int = 1000):
|
||||||
|
self.queue = queue.Queue(maxsize=max_size)
|
||||||
|
self.task_status: Dict[str, dict] = {} # 存储任务状态
|
||||||
|
self.lock = threading.Lock()
|
||||||
|
self.stop_event = threading.Event()
|
||||||
|
|
||||||
|
def add_task(self, task_id: str, request_data: dict) -> bool:
|
||||||
|
"""添加任务到队列"""
|
||||||
|
try:
|
||||||
|
task_item = {
|
||||||
|
'task_id': task_id,
|
||||||
|
'data': request_data,
|
||||||
|
'timestamp': time.time(),
|
||||||
|
'status': 'pending'
|
||||||
|
}
|
||||||
|
|
||||||
|
with self.lock:
|
||||||
|
self.task_status[task_id] = task_item
|
||||||
|
|
||||||
|
# 非阻塞方式放入队列
|
||||||
|
self.queue.put(task_item, block=False)
|
||||||
|
logger.info(f"任务 {task_id} 已添加到队列,当前队列大小: {self.queue.qsize()}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except queue.Full:
|
||||||
|
logger.warning(f"队列已满,任务 {task_id} 被拒绝")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"添加任务失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_task(self, timeout: float = 1.0) -> Optional[dict]:
|
||||||
|
"""从队列获取任务"""
|
||||||
|
try:
|
||||||
|
return self.queue.get(timeout=timeout)
|
||||||
|
except queue.Empty:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def update_task_status(self, task_id: str, status: str, result: dict = None):
|
||||||
|
"""更新任务状态"""
|
||||||
|
with self.lock:
|
||||||
|
if task_id in self.task_status:
|
||||||
|
self.task_status[task_id]['status'] = status
|
||||||
|
if result:
|
||||||
|
self.task_status[task_id]['result'] = result
|
||||||
|
self.task_status[task_id]['completed_time'] = time.time()
|
||||||
|
|
||||||
|
def get_task_status(self, task_id: str) -> Optional[dict]:
|
||||||
|
"""获取任务状态"""
|
||||||
|
with self.lock:
|
||||||
|
return self.task_status.get(task_id)
|
||||||
|
|
||||||
|
def cleanup_old_tasks(self, max_age_seconds: int = 3600):
|
||||||
|
"""清理旧任务"""
|
||||||
|
with self.lock:
|
||||||
|
current_time = time.time()
|
||||||
|
to_delete = []
|
||||||
|
for task_id, task in self.task_status.items():
|
||||||
|
if 'completed_time' in task:
|
||||||
|
age = current_time - task['completed_time']
|
||||||
|
if age > max_age_seconds:
|
||||||
|
to_delete.append(task_id)
|
||||||
|
|
||||||
|
for task_id in to_delete:
|
||||||
|
del self.task_status[task_id]
|
||||||
|
logger.info(f"清理旧任务: {task_id}")
|
||||||
|
|
||||||
|
|
||||||
|
class TaskWorker(threading.Thread):
|
||||||
|
"""工作线程,从队列中取任务并处理"""
|
||||||
|
|
||||||
|
def __init__(self, worker_id: int, task_queue: TaskQueue, stop_event: threading.Event):
|
||||||
|
super().__init__(daemon=True)
|
||||||
|
self.worker_id = worker_id
|
||||||
|
self.task_queue = task_queue
|
||||||
|
self.stop_event = stop_event
|
||||||
|
self.processed_count = 0
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
logger.info(f"工作线程 {self.worker_id} 启动")
|
||||||
|
while not self.stop_event.is_set():
|
||||||
|
try:
|
||||||
|
# 从队列获取任务
|
||||||
|
task = self.task_queue.get_task(timeout=0.5)
|
||||||
|
if not task:
|
||||||
|
continue
|
||||||
|
|
||||||
|
task_id = task['task_id']
|
||||||
|
request_data = task['data']
|
||||||
|
|
||||||
|
logger.info(f"工作线程 {self.worker_id} 开始处理任务: {task_id}")
|
||||||
|
|
||||||
|
# 更新任务状态为处理中
|
||||||
|
self.task_queue.update_task_status(task_id, 'processing')
|
||||||
|
|
||||||
|
# 这里是你的实际处理逻辑
|
||||||
|
result = self.process_task(task_id, request_data)
|
||||||
|
|
||||||
|
# 更新任务状态为完成
|
||||||
|
self.task_queue.update_task_status(
|
||||||
|
task_id,
|
||||||
|
'completed' if result.get('success') else 'failed',
|
||||||
|
result
|
||||||
|
)
|
||||||
|
|
||||||
|
self.processed_count += 1
|
||||||
|
logger.info(f"工作线程 {self.worker_id} 完成任务: {task_id}, 处理总数: {self.processed_count}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"工作线程 {self.worker_id} 处理任务失败: {e}")
|
||||||
|
if task:
|
||||||
|
self.task_queue.update_task_status(
|
||||||
|
task['task_id'],
|
||||||
|
'failed',
|
||||||
|
{'error': str(e)}
|
||||||
|
)
|
||||||
|
|
||||||
|
def process_task(self, task_id: str, request_data: dict) -> dict:
|
||||||
|
"""模拟耗时任务处理"""
|
||||||
|
# 这里替换为你的实际处理逻辑
|
||||||
|
# time.sleep(10) # 模拟10秒处理时间
|
||||||
|
task_id=request_data["task_id"]
|
||||||
|
sn=request_data["sn"]
|
||||||
|
img_url=request_data["img_url"]
|
||||||
|
prompt=request_data["prompt"]
|
||||||
|
confidence=request_data["confidence"]
|
||||||
|
mqtt_ip=request_data["mqtt_ip"]
|
||||||
|
mqtt_port=request_data["mqtt_port"]
|
||||||
|
mqtt_topic=request_data["mqtt_topic"]
|
||||||
|
local_image_path=downFile(img_url)
|
||||||
|
|
||||||
|
bpe_path = f"/home/beidou/test0623/sam3/sam3/sam3/assets/bpe_simple_vocab_16e6.txt.gz"
|
||||||
|
config_path = f"/home/beidou/test0623/sam3/sam3/sam3-weight/config.json" # 替换为本地路径
|
||||||
|
checkpoint_path = f"/home/beidou/test0623/sam3/sam3/sam3-weight/sam3.pt" # 替换为本地路径
|
||||||
|
# model = build_sam3_image_model(bpe_path=bpe_path)
|
||||||
|
|
||||||
|
# 2. 构建模型(从本地加载)
|
||||||
|
model = build_sam3_image_model_0228(
|
||||||
|
bpe_path=bpe_path,
|
||||||
|
checkpoint_path=checkpoint_path,
|
||||||
|
config_path=config_path, # 可选
|
||||||
|
load_from_HF=False,
|
||||||
|
device="cuda",
|
||||||
|
eval_mode=True,
|
||||||
|
)
|
||||||
|
formatted_date, milliseconds_timestamp = get_current_date_and_milliseconds()
|
||||||
|
img_name=os.path.basename(local_image_path)
|
||||||
|
dir_name=os.path.dirname(local_image_path)
|
||||||
|
predict_save_path=os.path.join(dir_name,str(milliseconds_timestamp)+img_name)
|
||||||
|
|
||||||
|
# image = Image.open(image_path)
|
||||||
|
image = Image.open(local_image_path).convert("RGB")
|
||||||
|
width, height = image.size
|
||||||
|
processor = Sam3Processor(model, confidence_threshold=0.5)
|
||||||
|
inference_state = processor.set_image(image)
|
||||||
|
|
||||||
|
processor.reset_all_prompts(inference_state)
|
||||||
|
inference_state = processor.set_text_prompt(state=inference_state, prompt="road")
|
||||||
|
img0 = Image.open(local_image_path)
|
||||||
|
plot_results_savepic(img0, inference_state, save_path=predict_save_path)
|
||||||
|
object_name, _=upload_file(predict_save_path,None)
|
||||||
|
|
||||||
|
mqtt = MQTTService(mqtt_ip, port=mqtt_port)
|
||||||
|
message = {
|
||||||
|
'success': True,
|
||||||
|
"task_id":task_id,
|
||||||
|
'object_name': object_name
|
||||||
|
}
|
||||||
|
mqtt.publish_sync(mqtt_topic, json.dumps(message, ensure_ascii=False))
|
||||||
|
# 删除本地文件
|
||||||
|
if os.path.exists(local_image_path):
|
||||||
|
os.remove(local_image_path)
|
||||||
|
if os.path.exists(predict_save_path):
|
||||||
|
os.remove(predict_save_path)
|
||||||
|
# 模拟处理结果
|
||||||
|
return {
|
||||||
|
'success': True,
|
||||||
|
'message': f'任务 {task_id} 处理完成',
|
||||||
|
'data': {'result': 'some_result'}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TaskServiceServicer(grpc_sam3_img_pb2_grpc.TaskServiceServicer):
|
||||||
|
def __init__(self, task_queue: TaskQueue, max_workers: int = 1):
|
||||||
|
self.task_queue = task_queue
|
||||||
|
self.max_workers = max_workers
|
||||||
|
self.stop_event = threading.Event()
|
||||||
|
self.workers = []
|
||||||
|
|
||||||
|
# 启动工作线程
|
||||||
|
self.start_workers()
|
||||||
|
|
||||||
|
def start_workers(self):
|
||||||
|
"""启动工作线程池"""
|
||||||
|
for i in range(self.max_workers):
|
||||||
|
worker = TaskWorker(i, self.task_queue, self.stop_event)
|
||||||
|
worker.start()
|
||||||
|
self.workers.append(worker)
|
||||||
|
logger.info(f"启动了 {self.max_workers} 个工作线程")
|
||||||
|
|
||||||
|
def ProcessTask(self, request, context):
|
||||||
|
"""处理任务请求 - 将任务放入队列后立即返回"""
|
||||||
|
try:
|
||||||
|
# 检查队列是否已满
|
||||||
|
if self.task_queue.queue.full():
|
||||||
|
logger.warning(f"队列已满,拒绝任务: {request.task_id}")
|
||||||
|
return grpc_sam3_img_pb2.TaskResponse(
|
||||||
|
task_id=request.task_id,
|
||||||
|
success=False,
|
||||||
|
message="服务器忙,请稍后重试"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 准备任务数据
|
||||||
|
task_data = {
|
||||||
|
'task_id': request.task_id,
|
||||||
|
'sn': request.sn,
|
||||||
|
'img_url': request.content_body.img_url,
|
||||||
|
'prompt': request.content_body.prompt,
|
||||||
|
'confidence': request.content_body.confidence,
|
||||||
|
'mqtt_ip': request.content_body.mqtt_ip,
|
||||||
|
'mqtt_port': request.content_body.mqtt_port,
|
||||||
|
'mqtt_topic': request.content_body.mqtt_topic
|
||||||
|
}
|
||||||
|
|
||||||
|
# 将任务添加到队列
|
||||||
|
if self.task_queue.add_task(request.task_id, task_data):
|
||||||
|
return grpc_sam3_img_pb2.TaskResponse(
|
||||||
|
task_id=request.task_id,
|
||||||
|
success=True,
|
||||||
|
message=f"任务已接收,正在排队处理。当前队列位置: {self.task_queue.queue.qsize()}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return grpc_sam3_img_pb2.TaskResponse(
|
||||||
|
task_id=request.task_id,
|
||||||
|
success=False,
|
||||||
|
message="任务提交失败"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理任务请求失败: {e}")
|
||||||
|
return grpc_sam3_img_pb2.TaskResponse(
|
||||||
|
task_id=request.task_id,
|
||||||
|
success=False,
|
||||||
|
message=f"服务器内部错误: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""停止工作线程"""
|
||||||
|
self.stop_event.set()
|
||||||
|
for worker in self.workers:
|
||||||
|
worker.join(timeout=2)
|
||||||
|
logger.info("所有工作线程已停止")
|
||||||
|
|
||||||
|
|
||||||
|
class HealthCheckServicer(grpc_sam3_img_pb2_grpc.HealthCheckServicer):
|
||||||
|
def __init__(self, task_queue: TaskQueue):
|
||||||
|
self.task_queue = task_queue
|
||||||
|
|
||||||
|
def Check(self, request, context):
|
||||||
|
"""健康检查,包含队列状态"""
|
||||||
|
queue_size = self.task_queue.queue.qsize()
|
||||||
|
|
||||||
|
if queue_size > 50: # 队列过长
|
||||||
|
return grpc_sam3_img_pb2.HealthCheckResponse(
|
||||||
|
status=grpc_sam3_img_pb2.HealthCheckResponse.ServingStatus.NOT_SERVING
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return grpc_sam3_img_pb2.HealthCheckResponse(
|
||||||
|
status=grpc_sam3_img_pb2.HealthCheckResponse.ServingStatus.SERVING
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def serve():
|
||||||
|
# 创建任务队列
|
||||||
|
task_queue = TaskQueue(max_size=20)
|
||||||
|
|
||||||
|
# 创建服务实例
|
||||||
|
task_service = TaskServiceServicer(task_queue, max_workers=1) # 10个工作线程
|
||||||
|
health_service = HealthCheckServicer(task_queue)
|
||||||
|
|
||||||
|
# 创建gRPC服务器
|
||||||
|
server = grpc.server(futures.ThreadPoolExecutor(max_workers=2)) # 处理RPC请求的线程
|
||||||
|
|
||||||
|
# 注册服务
|
||||||
|
grpc_sam3_img_pb2_grpc.add_TaskServiceServicer_to_server(task_service, server)
|
||||||
|
grpc_sam3_img_pb2_grpc.add_HealthCheckServicer_to_server(health_service, server)
|
||||||
|
|
||||||
|
# 启动服务器
|
||||||
|
server.add_insecure_port('[::]:50051')
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
logger.info("服务器已启动,监听端口: 50051")
|
||||||
|
logger.info(f"工作线程数: 1, 队列最大容量: 20")
|
||||||
|
|
||||||
|
# 定时清理旧任务
|
||||||
|
def cleanup_loop():
|
||||||
|
while True:
|
||||||
|
time.sleep(300) # 每5分钟清理一次
|
||||||
|
task_queue.cleanup_old_tasks()
|
||||||
|
|
||||||
|
cleanup_thread = threading.Thread(target=cleanup_loop, daemon=True)
|
||||||
|
cleanup_thread.start()
|
||||||
|
|
||||||
|
# 优雅关闭处理
|
||||||
|
def shutdown():
|
||||||
|
logger.info("收到关闭信号,正在停止服务器...")
|
||||||
|
task_service.stop()
|
||||||
|
server.stop(5) # 5秒宽限期
|
||||||
|
logger.info("服务器已停止")
|
||||||
|
|
||||||
|
import signal
|
||||||
|
signal.signal(signal.SIGINT, lambda s, f: shutdown())
|
||||||
|
signal.signal(signal.SIGTERM, lambda s, f: shutdown())
|
||||||
|
|
||||||
|
# 保持服务器运行
|
||||||
|
try:
|
||||||
|
server.wait_for_termination()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
serve()
|
||||||
@ -51,6 +51,7 @@ class MQTTService:
|
|||||||
self._message_task = None
|
self._message_task = None
|
||||||
self._connection_lock = asyncio.Lock()
|
self._connection_lock = asyncio.Lock()
|
||||||
self.os_type = sys.platform.lower()
|
self.os_type = sys.platform.lower()
|
||||||
|
self._loop = None # 保存事件循环
|
||||||
|
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
async with self._connection_lock:
|
async with self._connection_lock:
|
||||||
@ -122,6 +123,25 @@ class MQTTService:
|
|||||||
await self.reconnect()
|
await self.reconnect()
|
||||||
await self.client.publish(topic, payload, qos=qos, retain=retain)
|
await self.client.publish(topic, payload, qos=qos, retain=retain)
|
||||||
|
|
||||||
|
def publish_sync(self, topic, payload, qos=0, retain=False):
|
||||||
|
"""同步发布消息(适用于同步代码调用)"""
|
||||||
|
if self._loop is None or self._loop.is_closed():
|
||||||
|
# 如果没有事件循环,创建一个新的
|
||||||
|
self._loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(self._loop)
|
||||||
|
|
||||||
|
# 在事件循环中运行异步publish
|
||||||
|
if not self._loop.is_running():
|
||||||
|
return self._loop.run_until_complete(
|
||||||
|
self.publish(topic, payload, qos=qos, retain=retain)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 如果事件循环已经在运行,创建一个任务
|
||||||
|
asyncio.ensure_future(
|
||||||
|
self.publish(topic, payload, qos=qos, retain=retain)
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
async def subscribe(self, topic, callback=None, qos=0):
|
async def subscribe(self, topic, callback=None, qos=0):
|
||||||
if not self.is_connected:
|
if not self.is_connected:
|
||||||
await self.connect()
|
await self.connect()
|
||||||
|
|||||||
153
sam3/.gitignore
vendored
Normal file
@ -0,0 +1,153 @@
|
|||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
*-Copy*.ipynb
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# celery beat schedule file
|
||||||
|
celerybeat-schedule
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
.idea/
|
||||||
|
|
||||||
|
# VS Code
|
||||||
|
.vscode/
|
||||||
|
*.code-workspace
|
||||||
|
|
||||||
|
# Model weights and checkpoints
|
||||||
|
*.pth
|
||||||
|
*.pt
|
||||||
|
*.bin
|
||||||
|
*.ckpt
|
||||||
|
*.safetensors
|
||||||
|
weights/
|
||||||
|
checkpoints/
|
||||||
|
sam3_logs/
|
||||||
|
|
||||||
|
# Data files
|
||||||
|
*.h5
|
||||||
|
*.hdf5
|
||||||
|
*.pkl
|
||||||
|
*.pickle
|
||||||
|
*.npy
|
||||||
|
*.npz
|
||||||
|
|
||||||
|
# Logs
|
||||||
|
logs/
|
||||||
|
runs/
|
||||||
|
tensorboard/
|
||||||
|
|
||||||
|
# OS specific
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
|
|
||||||
|
# BPE vocabulary files
|
||||||
|
*.bpe
|
||||||
|
*.vocab
|
||||||
80
sam3/CODE_OF_CONDUCT.md
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
# Code of Conduct
|
||||||
|
|
||||||
|
## Our Pledge
|
||||||
|
|
||||||
|
In the interest of fostering an open and welcoming environment, we as
|
||||||
|
contributors and maintainers pledge to make participation in our project and
|
||||||
|
our community a harassment-free experience for everyone, regardless of age, body
|
||||||
|
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
||||||
|
level of experience, education, socio-economic status, nationality, personal
|
||||||
|
appearance, race, religion, or sexual identity and orientation.
|
||||||
|
|
||||||
|
## Our Standards
|
||||||
|
|
||||||
|
Examples of behavior that contributes to creating a positive environment
|
||||||
|
include:
|
||||||
|
|
||||||
|
* Using welcoming and inclusive language
|
||||||
|
* Being respectful of differing viewpoints and experiences
|
||||||
|
* Gracefully accepting constructive criticism
|
||||||
|
* Focusing on what is best for the community
|
||||||
|
* Showing empathy towards other community members
|
||||||
|
|
||||||
|
Examples of unacceptable behavior by participants include:
|
||||||
|
|
||||||
|
* The use of sexualized language or imagery and unwelcome sexual attention or
|
||||||
|
advances
|
||||||
|
* Trolling, insulting/derogatory comments, and personal or political attacks
|
||||||
|
* Public or private harassment
|
||||||
|
* Publishing others' private information, such as a physical or electronic
|
||||||
|
address, without explicit permission
|
||||||
|
* Other conduct which could reasonably be considered inappropriate in a
|
||||||
|
professional setting
|
||||||
|
|
||||||
|
## Our Responsibilities
|
||||||
|
|
||||||
|
Project maintainers are responsible for clarifying the standards of acceptable
|
||||||
|
behavior and are expected to take appropriate and fair corrective action in
|
||||||
|
response to any instances of unacceptable behavior.
|
||||||
|
|
||||||
|
Project maintainers have the right and responsibility to remove, edit, or
|
||||||
|
reject comments, commits, code, wiki edits, issues, and other contributions
|
||||||
|
that are not aligned to this Code of Conduct, or to ban temporarily or
|
||||||
|
permanently any contributor for other behaviors that they deem inappropriate,
|
||||||
|
threatening, offensive, or harmful.
|
||||||
|
|
||||||
|
## Scope
|
||||||
|
|
||||||
|
This Code of Conduct applies within all project spaces, and it also applies when
|
||||||
|
an individual is representing the project or its community in public spaces.
|
||||||
|
Examples of representing a project or community include using an official
|
||||||
|
project e-mail address, posting via an official social media account, or acting
|
||||||
|
as an appointed representative at an online or offline event. Representation of
|
||||||
|
a project may be further defined and clarified by project maintainers.
|
||||||
|
|
||||||
|
This Code of Conduct also applies outside the project spaces when there is a
|
||||||
|
reasonable belief that an individual's behavior may have a negative impact on
|
||||||
|
the project or its community.
|
||||||
|
|
||||||
|
## Enforcement
|
||||||
|
|
||||||
|
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||||
|
reported by contacting the project team at <opensource-conduct@meta.com>. All
|
||||||
|
complaints will be reviewed and investigated and will result in a response that
|
||||||
|
is deemed necessary and appropriate to the circumstances. The project team is
|
||||||
|
obligated to maintain confidentiality with regard to the reporter of an incident.
|
||||||
|
Further details of specific enforcement policies may be posted separately.
|
||||||
|
|
||||||
|
Project maintainers who do not follow or enforce the Code of Conduct in good
|
||||||
|
faith may face temporary or permanent repercussions as determined by other
|
||||||
|
members of the project's leadership.
|
||||||
|
|
||||||
|
## Attribution
|
||||||
|
|
||||||
|
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
||||||
|
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
||||||
|
|
||||||
|
[homepage]: https://www.contributor-covenant.org
|
||||||
|
|
||||||
|
For answers to common questions about this code of conduct, see
|
||||||
|
https://www.contributor-covenant.org/faq
|
||||||
30
sam3/CONTRIBUTING.md
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
# Contributing to sam3
|
||||||
|
We want to make contributing to this project as easy and transparent as
|
||||||
|
possible.
|
||||||
|
|
||||||
|
## Pull Requests
|
||||||
|
We actively welcome your pull requests.
|
||||||
|
|
||||||
|
1. Fork the repo and create your branch from `main`.
|
||||||
|
2. If you've added code that should be tested, add tests.
|
||||||
|
3. If you've changed APIs, update the documentation.
|
||||||
|
4. Make sure your code lints.
|
||||||
|
5. If you haven't already, complete the Contributor License Agreement ("CLA").
|
||||||
|
|
||||||
|
## Contributor License Agreement ("CLA")
|
||||||
|
In order to accept your pull request, we need you to submit a CLA. You only need
|
||||||
|
to do this once to work on any of Facebook's open source projects.
|
||||||
|
|
||||||
|
Complete your CLA here: <https://code.facebook.com/cla>
|
||||||
|
|
||||||
|
## Issues
|
||||||
|
We use GitHub issues to track public bugs. Please ensure your description is
|
||||||
|
clear and has sufficient instructions to be able to reproduce the issue.
|
||||||
|
|
||||||
|
Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
|
||||||
|
disclosure of security bugs. In those cases, please go through the process
|
||||||
|
outlined on that page and do not file a public issue.
|
||||||
|
|
||||||
|
## License
|
||||||
|
By contributing to sam3, you agree that your contributions will be licensed
|
||||||
|
under the LICENSE file in the root directory of this source tree.
|
||||||
61
sam3/LICENSE
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
SAM License
|
||||||
|
Last Updated: November 19, 2025
|
||||||
|
|
||||||
|
“Agreement” means the terms and conditions for use, reproduction, distribution and modification of the SAM Materials set forth herein.
|
||||||
|
|
||||||
|
|
||||||
|
“SAM Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, and other elements of the foregoing distributed by Meta and made available under this Agreement.
|
||||||
|
|
||||||
|
“Documentation” means the specifications, manuals and documentation accompanying
|
||||||
|
SAM Materials distributed by Meta.
|
||||||
|
|
||||||
|
|
||||||
|
“Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
|
||||||
|
|
||||||
|
|
||||||
|
“Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) or Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland).
|
||||||
|
|
||||||
|
|
||||||
|
“Sanctions” means any economic or trade sanctions or restrictions administered or enforced by the United States (including the Office of Foreign Assets Control of the U.S. Department of the Treasury (“OFAC”), the U.S. Department of State and the U.S. Department of Commerce), the United Nations, the European Union, or the United Kingdom.
|
||||||
|
|
||||||
|
|
||||||
|
“Trade Controls” means any of the following: Sanctions and applicable export and import controls.
|
||||||
|
|
||||||
|
By using or distributing any portion or element of the SAM Materials, you agree to be bound by this Agreement.
|
||||||
|
|
||||||
|
|
||||||
|
1. License Rights and Redistribution.
|
||||||
|
|
||||||
|
|
||||||
|
a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the SAM Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the SAM Materials.
|
||||||
|
|
||||||
|
b. Redistribution and Use.
|
||||||
|
i. Distribution of SAM Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the SAM Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement and you shall provide a copy of this Agreement with any such SAM Materials.
|
||||||
|
|
||||||
|
|
||||||
|
ii. If you submit for publication the results of research you perform on, using, or otherwise in connection with SAM Materials, you must acknowledge the use of SAM Materials in your publication.
|
||||||
|
|
||||||
|
|
||||||
|
iii. Your use of the SAM Materials must comply with applicable laws and regulations, including Trade Control Laws and applicable privacy and data protection laws.
|
||||||
|
iv. Your use of the SAM Materials will not involve or encourage others to reverse engineer, decompile or discover the underlying components of the SAM Materials.
|
||||||
|
v. You are not the target of Trade Controls and your use of SAM Materials must comply with Trade Controls. You agree not to use, or permit others to use, SAM Materials for any activities subject to the International Traffic in Arms Regulations (ITAR) or end uses prohibited by Trade Controls, including those related to military or warfare purposes, nuclear industries or applications, espionage, or the development or use of guns or illegal weapons.
|
||||||
|
2. User Support. Your use of the SAM Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the SAM Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
|
||||||
|
|
||||||
|
|
||||||
|
3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE SAM MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SAM MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SAM MATERIALS AND ANY OUTPUT AND RESULTS.
|
||||||
|
|
||||||
|
4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
||||||
|
|
||||||
|
5. Intellectual Property.
|
||||||
|
|
||||||
|
|
||||||
|
a. Subject to Meta’s ownership of SAM Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the SAM Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
|
||||||
|
|
||||||
|
b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the SAM Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the SAM Materials.
|
||||||
|
|
||||||
|
6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the SAM Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the SAM Materials. Sections 3, 4 and 7 shall survive the termination of this Agreement.
|
||||||
|
|
||||||
|
7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
|
||||||
|
|
||||||
|
|
||||||
|
8. Modifications and Amendments. Meta may modify this Agreement from time to time; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the SAM Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.
|
||||||
6
sam3/MANIFEST.in
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
include LICENSE
|
||||||
|
include README.md
|
||||||
|
recursive-include examples *.py
|
||||||
|
recursive-include examples *.ipynb
|
||||||
|
recursive-include examples *.md
|
||||||
|
recursive-include tests *.py
|
||||||
395
sam3/README.md
Normal file
@ -0,0 +1,395 @@
|
|||||||
|
# SAM 3: Segment Anything with Concepts
|
||||||
|
|
||||||
|
Meta Superintelligence Labs
|
||||||
|
|
||||||
|
[Nicolas Carion](https://www.nicolascarion.com/)\*,
|
||||||
|
[Laura Gustafson](https://scholar.google.com/citations?user=c8IpF9gAAAAJ&hl=en)\*,
|
||||||
|
[Yuan-Ting Hu](https://scholar.google.com/citations?user=E8DVVYQAAAAJ&hl=en)\*,
|
||||||
|
[Shoubhik Debnath](https://scholar.google.com/citations?user=fb6FOfsAAAAJ&hl=en)\*,
|
||||||
|
[Ronghang Hu](https://ronghanghu.com/)\*,
|
||||||
|
[Didac Suris](https://www.didacsuris.com/)\*,
|
||||||
|
[Chaitanya Ryali](https://scholar.google.com/citations?user=4LWx24UAAAAJ&hl=en)\*,
|
||||||
|
[Kalyan Vasudev Alwala](https://scholar.google.co.in/citations?user=m34oaWEAAAAJ&hl=en)\*,
|
||||||
|
[Haitham Khedr](https://hkhedr.com/)\*, Andrew Huang,
|
||||||
|
[Jie Lei](https://jayleicn.github.io/),
|
||||||
|
[Tengyu Ma](https://scholar.google.com/citations?user=VeTSl0wAAAAJ&hl=en),
|
||||||
|
[Baishan Guo](https://scholar.google.com/citations?user=BC5wDu8AAAAJ&hl=en),
|
||||||
|
Arpit Kalla, [Markus Marks](https://damaggu.github.io/),
|
||||||
|
[Joseph Greer](https://scholar.google.com/citations?user=guL96CkAAAAJ&hl=en),
|
||||||
|
Meng Wang, [Peize Sun](https://peizesun.github.io/),
|
||||||
|
[Roman Rädle](https://scholar.google.com/citations?user=Tpt57v0AAAAJ&hl=en),
|
||||||
|
[Triantafyllos Afouras](https://www.robots.ox.ac.uk/~afourast/),
|
||||||
|
[Effrosyni Mavroudi](https://scholar.google.com/citations?user=vYRzGGEAAAAJ&hl=en),
|
||||||
|
[Katherine Xu](https://k8xu.github.io/)°,
|
||||||
|
[Tsung-Han Wu](https://patrickthwu.com/)°,
|
||||||
|
[Yu Zhou](https://yu-bryan-zhou.github.io/)°,
|
||||||
|
[Liliane Momeni](https://scholar.google.com/citations?user=Lb-KgVYAAAAJ&hl=en)°,
|
||||||
|
[Rishi Hazra](https://rishihazra.github.io/)°,
|
||||||
|
[Shuangrui Ding](https://mark12ding.github.io/)°,
|
||||||
|
[Sagar Vaze](https://sgvaze.github.io/)°,
|
||||||
|
[Francois Porcher](https://scholar.google.com/citations?user=LgHZ8hUAAAAJ&hl=en)°,
|
||||||
|
[Feng Li](https://fengli-ust.github.io/)°,
|
||||||
|
[Siyuan Li](https://siyuanliii.github.io/)°,
|
||||||
|
[Aishwarya Kamath](https://ashkamath.github.io/)°,
|
||||||
|
[Ho Kei Cheng](https://hkchengrex.com/)°,
|
||||||
|
[Piotr Dollar](https://pdollar.github.io/)†,
|
||||||
|
[Nikhila Ravi](https://nikhilaravi.com/)†,
|
||||||
|
[Kate Saenko](https://ai.bu.edu/ksaenko.html)†,
|
||||||
|
[Pengchuan Zhang](https://pzzhang.github.io/pzzhang/)†,
|
||||||
|
[Christoph Feichtenhofer](https://feichtenhofer.github.io/)†
|
||||||
|
|
||||||
|
\* core contributor, ° intern, † project lead, order is random within groups
|
||||||
|
|
||||||
|
[[`Paper`](https://ai.meta.com/research/publications/sam-3-segment-anything-with-concepts/)]
|
||||||
|
[[`Project`](https://ai.meta.com/sam3)]
|
||||||
|
[[`Demo`](https://segment-anything.com/)]
|
||||||
|
[[`Blog`](https://ai.meta.com/blog/segment-anything-model-3/)]
|
||||||
|
[[`BibTeX`](#citing-sam-3)]
|
||||||
|
|
||||||
|
 SAM 3 is a unified foundation model for promptable segmentation in images and videos. It can detect, segment, and track objects using text or visual prompts such as points, boxes, and masks. Compared to its predecessor [SAM 2](https://github.com/facebookresearch/sam2), SAM 3 introduces the ability to exhaustively segment all instances of an open-vocabulary concept specified by a short text phrase or exemplars. Unlike prior work, SAM 3 can handle a vastly larger set of open-vocabulary prompts. It achieves 75-80% of human performance on our new [SA-CO benchmark](https://github.com/facebookresearch/sam3?tab=readme-ov-file#sa-co-dataset) which contains 270K unique concepts, over 50 times more than existing benchmarks.
|
||||||
|
|
||||||
|
This breakthrough is driven by an innovative data engine that has automatically annotated over 4 million unique concepts, creating the largest high-quality open-vocabulary segmentation dataset to date. In addition, SAM 3 introduces a new model architecture featuring a presence token that improves discrimination between closely related text prompts (e.g., “a player in white” vs. “a player in red”), as well as a decoupled detector–tracker design that minimizes task interference and scales efficiently with data.
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<img src="assets/dog.gif" width=380 />
|
||||||
|
<img src="assets/player.gif" width=380 />
|
||||||
|
</p>
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
- Python 3.12 or higher
|
||||||
|
- PyTorch 2.7 or higher
|
||||||
|
- CUDA-compatible GPU with CUDA 12.6 or higher
|
||||||
|
|
||||||
|
1. **Create a new Conda environment:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
conda create -n sam3 python=3.12
|
||||||
|
conda deactivate
|
||||||
|
conda activate sam3
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Install PyTorch with CUDA support:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Clone the repository and install the package:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/facebookresearch/sam3.git
|
||||||
|
cd sam3
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
4. **Install additional dependencies for example notebooks or development:**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# For running example notebooks
|
||||||
|
pip install -e ".[notebooks]"
|
||||||
|
|
||||||
|
# For development
|
||||||
|
pip install -e ".[train,dev]"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Getting Started
|
||||||
|
|
||||||
|
⚠️ Before using SAM 3, please request access to the checkpoints on the SAM 3
|
||||||
|
Hugging Face [repo](https://huggingface.co/facebook/sam3). Once accepted, you
|
||||||
|
need to be authenticated to download the checkpoints. You can do this by running
|
||||||
|
the following [steps](https://huggingface.co/docs/huggingface_hub/en/quick-start#authentication)
|
||||||
|
(e.g. `hf auth login` after generating an access token.)
|
||||||
|
|
||||||
|
### Basic Usage
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
#################################### For Image ####################################
|
||||||
|
from PIL import Image
|
||||||
|
from sam3.model_builder import build_sam3_image_model
|
||||||
|
from sam3.model.sam3_image_processor import Sam3Processor
|
||||||
|
# Load the model
|
||||||
|
model = build_sam3_image_model()
|
||||||
|
processor = Sam3Processor(model)
|
||||||
|
# Load an image
|
||||||
|
image = Image.open("<YOUR_IMAGE_PATH.jpg>")
|
||||||
|
inference_state = processor.set_image(image)
|
||||||
|
# Prompt the model with text
|
||||||
|
output = processor.set_text_prompt(state=inference_state, prompt="<YOUR_TEXT_PROMPT>")
|
||||||
|
|
||||||
|
# Get the masks, bounding boxes, and scores
|
||||||
|
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
|
||||||
|
|
||||||
|
#################################### For Video ####################################
|
||||||
|
|
||||||
|
from sam3.model_builder import build_sam3_video_predictor
|
||||||
|
|
||||||
|
video_predictor = build_sam3_video_predictor()
|
||||||
|
video_path = "<YOUR_VIDEO_PATH>" # a JPEG folder or an MP4 video file
|
||||||
|
# Start a session
|
||||||
|
response = video_predictor.handle_request(
|
||||||
|
request=dict(
|
||||||
|
type="start_session",
|
||||||
|
resource_path=video_path,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
response = video_predictor.handle_request(
|
||||||
|
request=dict(
|
||||||
|
type="add_prompt",
|
||||||
|
session_id=response["session_id"],
|
||||||
|
frame_index=0, # Arbitrary frame index
|
||||||
|
text="<YOUR_TEXT_PROMPT>",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
output = response["outputs"]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
The `examples` directory contains notebooks demonstrating how to use SAM3 with
|
||||||
|
various types of prompts:
|
||||||
|
|
||||||
|
- [`sam3_image_predictor_example.ipynb`](examples/sam3_image_predictor_example.ipynb)
|
||||||
|
: Demonstrates how to prompt SAM 3 with text and visual box prompts on images.
|
||||||
|
- [`sam3_video_predictor_example.ipynb`](examples/sam3_video_predictor_example.ipynb)
|
||||||
|
: Demonstrates how to prompt SAM 3 with text prompts on videos, and doing
|
||||||
|
further interactive refinements with points.
|
||||||
|
- [`sam3_image_batched_inference.ipynb`](examples/sam3_image_batched_inference.ipynb)
|
||||||
|
: Demonstrates how to run batched inference with SAM 3 on images.
|
||||||
|
- [`sam3_agent.ipynb`](examples/sam3_agent.ipynb): Demonsterates the use of SAM
|
||||||
|
3 Agent to segment complex text prompt on images.
|
||||||
|
- [`saco_gold_silver_vis_example.ipynb`](examples/saco_gold_silver_vis_example.ipynb)
|
||||||
|
: Shows a few examples from SA-Co image evaluation set.
|
||||||
|
- [`saco_veval_vis_example.ipynb`](examples/saco_veval_vis_example.ipynb) :
|
||||||
|
Shows a few examples from SA-Co video evaluation set.
|
||||||
|
|
||||||
|
There are additional notebooks in the examples directory that demonstrate how to
|
||||||
|
use SAM 3 for interactive instance segmentation in images and videos (SAM 1/2
|
||||||
|
tasks), or as a tool for an MLLM, and how to run evaluations on the SA-Co
|
||||||
|
dataset.
|
||||||
|
|
||||||
|
To run the Jupyter notebook examples:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Make sure you have the notebooks dependencies installed
|
||||||
|
pip install -e ".[notebooks]"
|
||||||
|
|
||||||
|
# Start Jupyter notebook
|
||||||
|
jupyter notebook examples/sam3_image_predictor_example.ipynb
|
||||||
|
```
|
||||||
|
|
||||||
|
## Model
|
||||||
|
|
||||||
|
SAM 3 consists of a detector and a tracker that share a vision encoder. It has 848M parameters. The
|
||||||
|
detector is a DETR-based model conditioned on text, geometry, and image
|
||||||
|
exemplars. The tracker inherits the SAM 2 transformer encoder-decoder
|
||||||
|
architecture, supporting video segmentation and interactive refinement.
|
||||||
|
|
||||||
|
## Image Results
|
||||||
|
|
||||||
|
<div align="center">
|
||||||
|
<table style="min-width: 80%; border: 2px solid #ddd; border-collapse: collapse">
|
||||||
|
<thead>
|
||||||
|
<tr>
|
||||||
|
<th rowspan="3" style="border-right: 2px solid #ddd; padding: 12px 20px">Model</th>
|
||||||
|
<th colspan="3" style="text-align: center; border-right: 2px solid #ddd; padding: 12px 20px">Instance Segmentation</th>
|
||||||
|
<th colspan="5" style="text-align: center; padding: 12px 20px">Box Detection</th>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<th colspan="2" style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">LVIS</th>
|
||||||
|
<th style="text-align: center; border-right: 2px solid #ddd; padding: 12px 20px">SA-Co/Gold</th>
|
||||||
|
<th colspan="2" style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">LVIS</th>
|
||||||
|
<th colspan="2" style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">COCO</th>
|
||||||
|
<th style="text-align: center; padding: 12px 20px">SA-Co/Gold</th>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<th style="text-align: center; padding: 12px 20px">cgF1</th>
|
||||||
|
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">AP</th>
|
||||||
|
<th style="text-align: center; border-right: 2px solid #ddd; padding: 12px 20px">cgF1</th>
|
||||||
|
<th style="text-align: center; padding: 12px 20px">cgF1</th>
|
||||||
|
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">AP</th>
|
||||||
|
<th style="text-align: center; padding: 12px 20px">AP</th>
|
||||||
|
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">AP<sub>o</sub>
|
||||||
|
</th>
|
||||||
|
<th style="text-align: center; padding: 12px 20px">cgF1</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
<tr>
|
||||||
|
<td style="border-right: 2px solid #ddd; padding: 10px 20px">Human</td>
|
||||||
|
<td style="text-align: center; padding: 10px 20px">-</td>
|
||||||
|
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
|
||||||
|
<td style="text-align: center; border-right: 2px solid #ddd; padding: 10px 20px">72.8</td>
|
||||||
|
<td style="text-align: center; padding: 10px 20px">-</td>
|
||||||
|
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
|
||||||
|
<td style="text-align: center; padding: 10px 20px">-</td>
|
||||||
|
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
|
||||||
|
<td style="text-align: center; padding: 10px 20px">74.0</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td style="border-right: 2px solid #ddd; padding: 10px 20px">OWLv2*</td>
|
||||||
|
<td style="text-align: center; padding: 10px 20px; color: #999">29.3</td>
|
||||||
|
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px; color: #999">43.4</td>
|
||||||
|
<td style="text-align: center; border-right: 2px solid #ddd; padding: 10px 20px">24.6</td>
|
||||||
|
<td style="text-align: center; padding: 10px 20px; color: #999">30.2</td>
|
||||||
|
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px; color: #999">45.5</td>
|
||||||
|
<td style="text-align: center; padding: 10px 20px">46.1</td>
|
||||||
|
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">23.9</td>
|
||||||
|
<td style="text-align: center; padding: 10px 20px">24.5</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td style="border-right: 2px solid #ddd; padding: 10px 20px">DINO-X</td>
|
||||||
|
<td style="text-align: center; padding: 10px 20px">-</td>
|
||||||
|
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">38.5</td>
|
||||||
|
<td style="text-align: center; border-right: 2px solid #ddd; padding: 10px 20px">21.3</td>
|
||||||
|
<td style="text-align: center; padding: 10px 20px">-</td>
|
||||||
|
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">52.4</td>
|
||||||
|
<td style="text-align: center; padding: 10px 20px">56.0</td>
|
||||||
|
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
|
||||||
|
<td style="text-align: center; padding: 10px 20px">22.5</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td style="border-right: 2px solid #ddd; padding: 10px 20px">Gemini 2.5</td>
|
||||||
|
<td style="text-align: center; padding: 10px 20px">13.4</td>
|
||||||
|
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
|
||||||
|
<td style="text-align: center; border-right: 2px solid #ddd; padding: 10px 20px">13.0</td>
|
||||||
|
<td style="text-align: center; padding: 10px 20px">16.1</td>
|
||||||
|
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
|
||||||
|
<td style="text-align: center; padding: 10px 20px">-</td>
|
||||||
|
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
|
||||||
|
<td style="text-align: center; padding: 10px 20px">14.4</td>
|
||||||
|
</tr>
|
||||||
|
<tr style="border-top: 2px solid #b19c9cff">
|
||||||
|
<td style="border-right: 2px solid #ddd; padding: 10px 20px">SAM 3</td>
|
||||||
|
<td style="text-align: center; padding: 10px 20px">37.2</td>
|
||||||
|
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">48.5</td>
|
||||||
|
<td style="text-align: center; border-right: 2px solid #ddd; padding: 10px 20px">54.1</td>
|
||||||
|
<td style="text-align: center; padding: 10px 20px">40.6</td>
|
||||||
|
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">53.6</td>
|
||||||
|
<td style="text-align: center; padding: 10px 20px">56.4</td>
|
||||||
|
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">55.7</td>
|
||||||
|
<td style="text-align: center; padding: 10px 20px">55.7</td>
|
||||||
|
</tr>
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
|
||||||
|
<p style="text-align: center; margin-top: 10px; font-size: 0.9em; color: #ddd;">* Partially trained on LVIS, AP<sub>o</sub> refers to COCO-O accuracy</p>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
|
||||||
|
## Video Results
|
||||||
|
|
||||||
|
<div align="center">
|
||||||
|
<table style="min-width: 80%; border: 2px solid #ddd; border-collapse: collapse">
|
||||||
|
<thead>
|
||||||
|
<tr>
|
||||||
|
<th rowspan="2" style="border-right: 2px solid #ddd; padding: 12px 20px">Model</th>
|
||||||
|
<th colspan="2" style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">SA-V test</th>
|
||||||
|
<th colspan="2" style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">YT-Temporal-1B test</th>
|
||||||
|
<th colspan="2" style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">SmartGlasses test</th>
|
||||||
|
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">LVVIS test</th>
|
||||||
|
<th style="text-align: center; padding: 12px 20px">BURST test</th>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<th style="text-align: center; padding: 12px 20px">cgF1</th>
|
||||||
|
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">pHOTA</th>
|
||||||
|
<th style="text-align: center; padding: 12px 20px">cgF1</th>
|
||||||
|
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">pHOTA</th>
|
||||||
|
<th style="text-align: center; padding: 12px 20px">cgF1</th>
|
||||||
|
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">pHOTA</th>
|
||||||
|
<th style="text-align: center; border-right: 1px solid #eee; padding: 12px 20px">mAP</th>
|
||||||
|
<th style="text-align: center; padding: 12px 20px">HOTA</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
<tr>
|
||||||
|
<td style="border-right: 2px solid #ddd; padding: 10px 20px">Human</td>
|
||||||
|
<td style="text-align: center; padding: 10px 20px">53.1</td>
|
||||||
|
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">70.5</td>
|
||||||
|
<td style="text-align: center; padding: 10px 20px">71.2</td>
|
||||||
|
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">78.4</td>
|
||||||
|
<td style="text-align: center; padding: 10px 20px">58.5</td>
|
||||||
|
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">72.3</td>
|
||||||
|
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">-</td>
|
||||||
|
<td style="text-align: center; padding: 10px 20px">-</td>
|
||||||
|
</tr>
|
||||||
|
<tr style="border-top: 2px solid #b19c9cff">
|
||||||
|
<td style="border-right: 2px solid #ddd; padding: 10px 20px">SAM 3</td>
|
||||||
|
<td style="text-align: center; padding: 10px 20px">30.3</td>
|
||||||
|
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">58.0</td>
|
||||||
|
<td style="text-align: center; padding: 10px 20px">50.8</td>
|
||||||
|
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">69.9</td>
|
||||||
|
<td style="text-align: center; padding: 10px 20px">36.4</td>
|
||||||
|
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">63.6</td>
|
||||||
|
<td style="text-align: center; border-right: 1px solid #eee; padding: 10px 20px">36.3</td>
|
||||||
|
<td style="text-align: center; padding: 10px 20px">44.5</td>
|
||||||
|
</tr>
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
## SA-Co Dataset
|
||||||
|
|
||||||
|
We release 2 image benchmarks, [SA-Co/Gold](scripts/eval/gold/README.md) and
|
||||||
|
[SA-Co/Silver](scripts/eval/silver/README.md), and a video benchmark
|
||||||
|
[SA-Co/VEval](scripts/eval/veval/README.md). The datasets contain images (or videos) with annotated noun phrases. Each image/video and noun phrase pair is annotated with instance masks and unique IDs of each object matching the phrase. Phrases that have no matching objects (negative prompts) have no masks, shown in red font in the figure. See the linked READMEs for more details on how to download and run evaluations on the datasets.
|
||||||
|
|
||||||
|
* HuggingFace host: [SA-Co/Gold](https://huggingface.co/datasets/facebook/SACo-Gold), [SA-Co/Silver](https://huggingface.co/datasets/facebook/SACo-Silver) and [SA-Co/VEval](https://huggingface.co/datasets/facebook/SACo-VEval)
|
||||||
|
* Roboflow host: [SA-Co/Gold](https://universe.roboflow.com/sa-co-gold), [SA-Co/Silver](https://universe.roboflow.com/sa-co-silver) and [SA-Co/VEval](https://universe.roboflow.com/sa-co-veval)
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
## Development
|
||||||
|
|
||||||
|
To set up the development environment:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -e ".[dev,train]"
|
||||||
|
```
|
||||||
|
|
||||||
|
To format the code:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ufmt format .
|
||||||
|
```
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
See [contributing](CONTRIBUTING.md) and the
|
||||||
|
[code of conduct](CODE_OF_CONDUCT.md).
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
This project is licensed under the SAM License - see the [LICENSE](LICENSE) file
|
||||||
|
for details.
|
||||||
|
|
||||||
|
## Acknowledgements
|
||||||
|
|
||||||
|
We would like to thank the following people for their contributions to the SAM 3 project: Alex He, Alexander Kirillov,
|
||||||
|
Alyssa Newcomb, Ana Paula Kirschner Mofarrej, Andrea Madotto, Andrew Westbury, Ashley Gabriel, Azita Shokpour,
|
||||||
|
Ben Samples, Bernie Huang, Carleigh Wood, Ching-Feng Yeh, Christian Puhrsch, Claudette Ward, Daniel Bolya,
|
||||||
|
Daniel Li, Facundo Figueroa, Fazila Vhora, George Orlin, Hanzi Mao, Helen Klein, Hu Xu, Ida Cheng, Jake Kinney,
|
||||||
|
Jiale Zhi, Jo Sampaio, Joel Schlosser, Justin Johnson, Kai Brown, Karen Bergan, Karla Martucci, Kenny Lehmann,
|
||||||
|
Maddie Mintz, Mallika Malhotra, Matt Ward, Michelle Chan, Michelle Restrepo, Miranda Hartley, Muhammad Maaz,
|
||||||
|
Nisha Deo, Peter Park, Phillip Thomas, Raghu Nayani, Rene Martinez Doehner, Robbie Adkins, Ross Girshik, Sasha
|
||||||
|
Mitts, Shashank Jain, Spencer Whitehead, Ty Toledano, Valentin Gabeur, Vincent Cho, Vivian Lee, William Ngan,
|
||||||
|
Xuehai He, Yael Yungster, Ziqi Pang, Ziyi Dou, Zoe Quake.
|
||||||
|
|
||||||
|
## Citing SAM 3
|
||||||
|
|
||||||
|
If you use SAM 3 or the SA-Co dataset in your research, please use the following BibTeX entry.
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@misc{carion2025sam3segmentconcepts,
|
||||||
|
title={SAM 3: Segment Anything with Concepts},
|
||||||
|
author={Nicolas Carion and Laura Gustafson and Yuan-Ting Hu and Shoubhik Debnath and Ronghang Hu and Didac Suris and Chaitanya Ryali and Kalyan Vasudev Alwala and Haitham Khedr and Andrew Huang and Jie Lei and Tengyu Ma and Baishan Guo and Arpit Kalla and Markus Marks and Joseph Greer and Meng Wang and Peize Sun and Roman Rädle and Triantafyllos Afouras and Effrosyni Mavroudi and Katherine Xu and Tsung-Han Wu and Yu Zhou and Liliane Momeni and Rishi Hazra and Shuangrui Ding and Sagar Vaze and Francois Porcher and Feng Li and Siyuan Li and Aishwarya Kamath and Ho Kei Cheng and Piotr Dollár and Nikhila Ravi and Kate Saenko and Pengchuan Zhang and Christoph Feichtenhofer},
|
||||||
|
year={2025},
|
||||||
|
eprint={2511.16719},
|
||||||
|
archivePrefix={arXiv},
|
||||||
|
primaryClass={cs.CV},
|
||||||
|
url={https://arxiv.org/abs/2511.16719},
|
||||||
|
}
|
||||||
|
```
|
||||||
190
sam3/README_TRAIN.md
Normal file
@ -0,0 +1,190 @@
|
|||||||
|
# Training
|
||||||
|
|
||||||
|
This repository supports finetuning SAM3 models on custom datasets in multi-node setup or local execution. The training script is located at `sam3/train.py` and uses Hydra configuration management to handle complex training setups.
|
||||||
|
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd sam3
|
||||||
|
pip install -e ".[train]"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Training Script Usage
|
||||||
|
|
||||||
|
The main training script is located at `sam3/train.py`. It uses Hydra configuration management to handle complex training setups.
|
||||||
|
|
||||||
|
#### Basic Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Example: Train on Roboflow dataset
|
||||||
|
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml
|
||||||
|
# Example: Train on ODinW13 dataset
|
||||||
|
python sam3/train/train.py -c configs/odinw13/odinw_text_only_train.yaml
|
||||||
|
```
|
||||||
|
Follow [`Roboflow 100-VL`](https://github.com/roboflow/rf100-vl/) to download the roboflow 100-vl datasets. Follow [`GLIP`](https://github.com/microsoft/GLIP) to download the ODinW datasets. The data folder should be organized as follows, and put your roboflow_vl_100_root and odinw_data_root in the job configs.
|
||||||
|
```
|
||||||
|
roboflow_vl_100_root:
|
||||||
|
13-lkc01
|
||||||
|
train
|
||||||
|
valid
|
||||||
|
test
|
||||||
|
2024-frc
|
||||||
|
actions
|
||||||
|
...
|
||||||
|
odinw_data_root:
|
||||||
|
AerialMaritimeDrone
|
||||||
|
large
|
||||||
|
train
|
||||||
|
valid
|
||||||
|
test
|
||||||
|
Aquarium
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Command Line Arguments
|
||||||
|
|
||||||
|
The training script supports several command line arguments:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python sam3/train/train.py \
|
||||||
|
-c CONFIG_NAME \
|
||||||
|
[--use-cluster 0|1] \
|
||||||
|
[--partition PARTITION_NAME] \
|
||||||
|
[--account ACCOUNT_NAME] \
|
||||||
|
[--qos QOS_NAME] \
|
||||||
|
[--num-gpus NUM_GPUS] \
|
||||||
|
[--num-nodes NUM_NODES]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Arguments:**
|
||||||
|
- `-c, --config`: **Required.** Path to the configuration file (e.g., `sam3/train/configs/roboflow_v100_full_ft_100_images.yaml`)
|
||||||
|
- `--use-cluster`: Whether to launch on a cluster (0: local, 1: cluster). Default: uses config setting
|
||||||
|
- `--partition`: SLURM partition name for cluster execution
|
||||||
|
- `--account`: SLURM account name for cluster execution
|
||||||
|
- `--qos`: SLURM QOS (Quality of Service) setting
|
||||||
|
- `--num-gpus`: Number of GPUs per node. Default: uses config setting
|
||||||
|
- `--num-nodes`: Number of nodes for distributed training. Default: uses config setting
|
||||||
|
|
||||||
|
#### Local Training Examples
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Single GPU training
|
||||||
|
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml --use-cluster 0 --num-gpus 1
|
||||||
|
|
||||||
|
# Multi-GPU training on a single node
|
||||||
|
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml --use-cluster 0 --num-gpus 4
|
||||||
|
|
||||||
|
# Force local execution even if config specifies GPUs
|
||||||
|
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml --use-cluster 0
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Cluster Training Examples
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Basic cluster training with default settings from config
|
||||||
|
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml --use-cluster 1
|
||||||
|
|
||||||
|
# Cluster training with specific SLURM settings
|
||||||
|
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml \
|
||||||
|
--use-cluster 1 \
|
||||||
|
--partition gpu_partition \
|
||||||
|
--account my_account \
|
||||||
|
--qos high_priority \
|
||||||
|
--num-gpus 8 \
|
||||||
|
--num-nodes 2
|
||||||
|
```
|
||||||
|
|
||||||
|
### Configuration Files
|
||||||
|
|
||||||
|
Training configurations are stored in `sam3/train/configs/`. The configuration files use Hydra's YAML format and support:
|
||||||
|
|
||||||
|
- **Dataset Configuration**: Data paths, transforms, and loading parameters
|
||||||
|
- **Model Configuration**: Architecture settings, checkpoint paths, and model parameters
|
||||||
|
- **Training Configuration**: Batch sizes, learning rates, optimization settings
|
||||||
|
- **Launcher Configuration**: Distributed training and cluster settings
|
||||||
|
- **Logging Configuration**: TensorBoard, experiment tracking, and output directories
|
||||||
|
|
||||||
|
#### Key Configuration Sections
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Paths to datasets and checkpoints
|
||||||
|
paths:
|
||||||
|
bpe_path: /path/to/bpe/file
|
||||||
|
dataset_root: /path/to/dataset
|
||||||
|
experiment_log_dir: /path/to/logs
|
||||||
|
|
||||||
|
# Launcher settings for local/cluster execution
|
||||||
|
launcher:
|
||||||
|
num_nodes: 1
|
||||||
|
gpus_per_node: 2
|
||||||
|
experiment_log_dir: ${paths.experiment_log_dir}
|
||||||
|
|
||||||
|
# Cluster execution settings
|
||||||
|
submitit:
|
||||||
|
use_cluster: True
|
||||||
|
timeout_hour: 72
|
||||||
|
cpus_per_task: 10
|
||||||
|
partition: null
|
||||||
|
account: null
|
||||||
|
```
|
||||||
|
|
||||||
|
### Monitoring Training
|
||||||
|
|
||||||
|
The training script automatically sets up logging and saves outputs to the experiment directory:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Logs are saved to the experiment_log_dir specified in config
|
||||||
|
experiment_log_dir/
|
||||||
|
├── config.yaml # Original configuration
|
||||||
|
├── config_resolved.yaml # Resolved configuration with all variables expanded
|
||||||
|
├── checkpoints/ # Model checkpoints (if skip_checkpointing=False)
|
||||||
|
├── tensorboard/ # TensorBoard logs
|
||||||
|
├── logs/ # Text logs
|
||||||
|
└── submitit_logs/ # Cluster job logs (if using cluster)
|
||||||
|
```
|
||||||
|
|
||||||
|
You can monitor training progress using TensorBoard:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
tensorboard --logdir /path/to/experiment_log_dir/tensorboard
|
||||||
|
```
|
||||||
|
|
||||||
|
### Job Arrays for Dataset Sweeps
|
||||||
|
|
||||||
|
The Roboflow and ODinW configuration supports job arrays for training multiple models on different datasets:
|
||||||
|
|
||||||
|
This feature is specifically enabled via,
|
||||||
|
```yaml
|
||||||
|
submitit:
|
||||||
|
job_array:
|
||||||
|
num_tasks: 100
|
||||||
|
task_index: 0
|
||||||
|
```
|
||||||
|
|
||||||
|
The configuration includes a complete list of 100 Roboflow supercategories, and the `submitit.job_array.task_index` automatically selects which dataset to use based on the array job index.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Submit job array to train on different Roboflow datasets
|
||||||
|
# The job array index selects which dataset from all_roboflow_supercategories
|
||||||
|
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_full_ft_100_images.yaml \
|
||||||
|
--use-cluster 1
|
||||||
|
```
|
||||||
|
|
||||||
|
### Reproduce ODinW13 10-shot results
|
||||||
|
Running the following job will give the results on the ODinW13 seed 300, see `odinw_train.train_file: fewshot_train_shot10_seed300` in the config file.
|
||||||
|
```bash
|
||||||
|
# Example: Train on ODinW13 dataset
|
||||||
|
python sam3/train/train.py -c configs/odinw13/odinw_text_only_train.yaml
|
||||||
|
```
|
||||||
|
Change `odinw_train.train_file` to `fewshot_train_shot10_seed30` and `fewshot_train_shot10_seed3` to get the results for the other two seeds. Final results are aggregated from the three seeds. Notice that a small number of jobs may diverge during training, in which case we just use the last checkpoint's result before it diverges.
|
||||||
|
|
||||||
|
|
||||||
|
### Eval Script Usage
|
||||||
|
With a similar setup as the training config, the training script `sam3/train.py` can also be used for evaluation, too, when setting `trainer.mode = val` in the job config. Run the following job will give the results on the zero-shot results on RF100-VL and ODinW13 datasets.
|
||||||
|
```bash
|
||||||
|
# Example: Evaluate on Roboflow dataset
|
||||||
|
python sam3/train/train.py -c configs/roboflow_v100/roboflow_v100_eval.yaml
|
||||||
|
# Example: Evaluate on ODinW13 dataset
|
||||||
|
python sam3/train/train.py -c configs/odinw13/odinw_text_only.yaml
|
||||||
|
```
|
||||||
BIN
sam3/assets/dog.gif
Normal file
|
After Width: | Height: | Size: 6.8 MiB |
BIN
sam3/assets/images/groceries.jpg
Normal file
|
After Width: | Height: | Size: 164 KiB |
BIN
sam3/assets/images/truck.jpg
Normal file
|
After Width: | Height: | Size: 265 KiB |
BIN
sam3/assets/model_diagram.png
Normal file
|
After Width: | Height: | Size: 707 KiB |
BIN
sam3/assets/player.gif
Normal file
|
After Width: | Height: | Size: 4.2 MiB |
BIN
sam3/assets/sa_co_dataset.jpg
Normal file
|
After Width: | Height: | Size: 991 KiB |
BIN
sam3/assets/saco_gold_annotation.png
Normal file
|
After Width: | Height: | Size: 3.8 MiB |
BIN
sam3/assets/videos/0001/0.jpg
Normal file
|
After Width: | Height: | Size: 141 KiB |
BIN
sam3/assets/videos/0001/1.jpg
Normal file
|
After Width: | Height: | Size: 138 KiB |
BIN
sam3/assets/videos/0001/10.jpg
Normal file
|
After Width: | Height: | Size: 134 KiB |
BIN
sam3/assets/videos/0001/100.jpg
Normal file
|
After Width: | Height: | Size: 112 KiB |
BIN
sam3/assets/videos/0001/101.jpg
Normal file
|
After Width: | Height: | Size: 114 KiB |
BIN
sam3/assets/videos/0001/102.jpg
Normal file
|
After Width: | Height: | Size: 111 KiB |
BIN
sam3/assets/videos/0001/103.jpg
Normal file
|
After Width: | Height: | Size: 111 KiB |
BIN
sam3/assets/videos/0001/104.jpg
Normal file
|
After Width: | Height: | Size: 110 KiB |
BIN
sam3/assets/videos/0001/105.jpg
Normal file
|
After Width: | Height: | Size: 112 KiB |
BIN
sam3/assets/videos/0001/106.jpg
Normal file
|
After Width: | Height: | Size: 110 KiB |
BIN
sam3/assets/videos/0001/107.jpg
Normal file
|
After Width: | Height: | Size: 112 KiB |
BIN
sam3/assets/videos/0001/108.jpg
Normal file
|
After Width: | Height: | Size: 111 KiB |
BIN
sam3/assets/videos/0001/109.jpg
Normal file
|
After Width: | Height: | Size: 114 KiB |
BIN
sam3/assets/videos/0001/11.jpg
Normal file
|
After Width: | Height: | Size: 136 KiB |
BIN
sam3/assets/videos/0001/110.jpg
Normal file
|
After Width: | Height: | Size: 113 KiB |
BIN
sam3/assets/videos/0001/111.jpg
Normal file
|
After Width: | Height: | Size: 113 KiB |
BIN
sam3/assets/videos/0001/112.jpg
Normal file
|
After Width: | Height: | Size: 112 KiB |
BIN
sam3/assets/videos/0001/113.jpg
Normal file
|
After Width: | Height: | Size: 113 KiB |
BIN
sam3/assets/videos/0001/114.jpg
Normal file
|
After Width: | Height: | Size: 111 KiB |
BIN
sam3/assets/videos/0001/115.jpg
Normal file
|
After Width: | Height: | Size: 110 KiB |
BIN
sam3/assets/videos/0001/116.jpg
Normal file
|
After Width: | Height: | Size: 110 KiB |
BIN
sam3/assets/videos/0001/117.jpg
Normal file
|
After Width: | Height: | Size: 109 KiB |
BIN
sam3/assets/videos/0001/118.jpg
Normal file
|
After Width: | Height: | Size: 107 KiB |
BIN
sam3/assets/videos/0001/119.jpg
Normal file
|
After Width: | Height: | Size: 105 KiB |
BIN
sam3/assets/videos/0001/12.jpg
Normal file
|
After Width: | Height: | Size: 134 KiB |
BIN
sam3/assets/videos/0001/120.jpg
Normal file
|
After Width: | Height: | Size: 106 KiB |
BIN
sam3/assets/videos/0001/121.jpg
Normal file
|
After Width: | Height: | Size: 104 KiB |
BIN
sam3/assets/videos/0001/122.jpg
Normal file
|
After Width: | Height: | Size: 104 KiB |
BIN
sam3/assets/videos/0001/123.jpg
Normal file
|
After Width: | Height: | Size: 106 KiB |
BIN
sam3/assets/videos/0001/124.jpg
Normal file
|
After Width: | Height: | Size: 108 KiB |
BIN
sam3/assets/videos/0001/125.jpg
Normal file
|
After Width: | Height: | Size: 107 KiB |
BIN
sam3/assets/videos/0001/126.jpg
Normal file
|
After Width: | Height: | Size: 109 KiB |
BIN
sam3/assets/videos/0001/127.jpg
Normal file
|
After Width: | Height: | Size: 105 KiB |
BIN
sam3/assets/videos/0001/128.jpg
Normal file
|
After Width: | Height: | Size: 104 KiB |
BIN
sam3/assets/videos/0001/129.jpg
Normal file
|
After Width: | Height: | Size: 102 KiB |
BIN
sam3/assets/videos/0001/13.jpg
Normal file
|
After Width: | Height: | Size: 136 KiB |
BIN
sam3/assets/videos/0001/130.jpg
Normal file
|
After Width: | Height: | Size: 104 KiB |
BIN
sam3/assets/videos/0001/131.jpg
Normal file
|
After Width: | Height: | Size: 102 KiB |
BIN
sam3/assets/videos/0001/132.jpg
Normal file
|
After Width: | Height: | Size: 103 KiB |
BIN
sam3/assets/videos/0001/133.jpg
Normal file
|
After Width: | Height: | Size: 104 KiB |
BIN
sam3/assets/videos/0001/134.jpg
Normal file
|
After Width: | Height: | Size: 106 KiB |
BIN
sam3/assets/videos/0001/135.jpg
Normal file
|
After Width: | Height: | Size: 103 KiB |
BIN
sam3/assets/videos/0001/136.jpg
Normal file
|
After Width: | Height: | Size: 103 KiB |
BIN
sam3/assets/videos/0001/137.jpg
Normal file
|
After Width: | Height: | Size: 101 KiB |
BIN
sam3/assets/videos/0001/138.jpg
Normal file
|
After Width: | Height: | Size: 102 KiB |
BIN
sam3/assets/videos/0001/139.jpg
Normal file
|
After Width: | Height: | Size: 100 KiB |
BIN
sam3/assets/videos/0001/14.jpg
Normal file
|
After Width: | Height: | Size: 134 KiB |
BIN
sam3/assets/videos/0001/140.jpg
Normal file
|
After Width: | Height: | Size: 99 KiB |
BIN
sam3/assets/videos/0001/141.jpg
Normal file
|
After Width: | Height: | Size: 101 KiB |
BIN
sam3/assets/videos/0001/142.jpg
Normal file
|
After Width: | Height: | Size: 101 KiB |
BIN
sam3/assets/videos/0001/143.jpg
Normal file
|
After Width: | Height: | Size: 103 KiB |
BIN
sam3/assets/videos/0001/144.jpg
Normal file
|
After Width: | Height: | Size: 103 KiB |
BIN
sam3/assets/videos/0001/145.jpg
Normal file
|
After Width: | Height: | Size: 104 KiB |
BIN
sam3/assets/videos/0001/146.jpg
Normal file
|
After Width: | Height: | Size: 102 KiB |
BIN
sam3/assets/videos/0001/147.jpg
Normal file
|
After Width: | Height: | Size: 101 KiB |
BIN
sam3/assets/videos/0001/148.jpg
Normal file
|
After Width: | Height: | Size: 99 KiB |
BIN
sam3/assets/videos/0001/149.jpg
Normal file
|
After Width: | Height: | Size: 97 KiB |
BIN
sam3/assets/videos/0001/15.jpg
Normal file
|
After Width: | Height: | Size: 133 KiB |
BIN
sam3/assets/videos/0001/150.jpg
Normal file
|
After Width: | Height: | Size: 98 KiB |
BIN
sam3/assets/videos/0001/151.jpg
Normal file
|
After Width: | Height: | Size: 99 KiB |
BIN
sam3/assets/videos/0001/152.jpg
Normal file
|
After Width: | Height: | Size: 102 KiB |
BIN
sam3/assets/videos/0001/153.jpg
Normal file
|
After Width: | Height: | Size: 104 KiB |
BIN
sam3/assets/videos/0001/154.jpg
Normal file
|
After Width: | Height: | Size: 107 KiB |
BIN
sam3/assets/videos/0001/155.jpg
Normal file
|
After Width: | Height: | Size: 108 KiB |
BIN
sam3/assets/videos/0001/156.jpg
Normal file
|
After Width: | Height: | Size: 108 KiB |
BIN
sam3/assets/videos/0001/157.jpg
Normal file
|
After Width: | Height: | Size: 109 KiB |
BIN
sam3/assets/videos/0001/158.jpg
Normal file
|
After Width: | Height: | Size: 106 KiB |
BIN
sam3/assets/videos/0001/159.jpg
Normal file
|
After Width: | Height: | Size: 103 KiB |
BIN
sam3/assets/videos/0001/16.jpg
Normal file
|
After Width: | Height: | Size: 131 KiB |
BIN
sam3/assets/videos/0001/160.jpg
Normal file
|
After Width: | Height: | Size: 102 KiB |