24 lines
800 B
Python

# 模型初始化文件
from .deeplabv3plus import DeepLabV3Plus
from .unetpp import UNetPlusPlus
def get_model(model_type, num_classes, **kwargs):
"""获取指定类型的模型
Args:
model_type (str): 模型类型,可选 'deeplabv3plus''unetpp'
num_classes (int): 类别数量
**kwargs: 其他模型参数
Returns:
nn.Module: 模型实例
"""
model_type = model_type.lower()
if model_type == 'deeplabv3plus':
return DeepLabV3Plus(num_classes=num_classes, **kwargs)
elif model_type == 'unetpp':
return UNetPlusPlus(num_classes=num_classes, **kwargs)
else:
raise ValueError(f"不支持的模型类型: {model_type},请选择 'deeplabv3plus', 'deeplabv3plus_optimized''unetpp'")