24 lines
800 B
Python
24 lines
800 B
Python
# 模型初始化文件
|
|
|
|
from .deeplabv3plus import DeepLabV3Plus
|
|
from .unetpp import UNetPlusPlus
|
|
|
|
def get_model(model_type, num_classes, **kwargs):
|
|
"""获取指定类型的模型
|
|
|
|
Args:
|
|
model_type (str): 模型类型,可选 'deeplabv3plus' 或 'unetpp'
|
|
num_classes (int): 类别数量
|
|
**kwargs: 其他模型参数
|
|
|
|
Returns:
|
|
nn.Module: 模型实例
|
|
"""
|
|
model_type = model_type.lower()
|
|
|
|
if model_type == 'deeplabv3plus':
|
|
return DeepLabV3Plus(num_classes=num_classes, **kwargs)
|
|
elif model_type == 'unetpp':
|
|
return UNetPlusPlus(num_classes=num_classes, **kwargs)
|
|
else:
|
|
raise ValueError(f"不支持的模型类型: {model_type},请选择 'deeplabv3plus', 'deeplabv3plus_optimized' 或 'unetpp'") |