# 模型初始化文件 from .deeplabv3plus import DeepLabV3Plus from .unetpp import UNetPlusPlus def get_model(model_type, num_classes, **kwargs): """获取指定类型的模型 Args: model_type (str): 模型类型,可选 'deeplabv3plus' 或 'unetpp' num_classes (int): 类别数量 **kwargs: 其他模型参数 Returns: nn.Module: 模型实例 """ model_type = model_type.lower() if model_type == 'deeplabv3plus': return DeepLabV3Plus(num_classes=num_classes, **kwargs) elif model_type == 'unetpp': return UNetPlusPlus(num_classes=num_classes, **kwargs) else: raise ValueError(f"不支持的模型类型: {model_type},请选择 'deeplabv3plus', 'deeplabv3plus_optimized' 或 'unetpp'")