import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet50_Weights


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling (ASPP) module."""

    def __init__(self, in_channels, out_channels, rates):
        super(ASPP, self).__init__()
        self.aspp = nn.ModuleList()
        # 1x1 convolution branch
        self.aspp.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        ))
        # Atrous (dilated) convolution branches, one per rate
        for rate in rates:
            self.aspp.append(nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, padding=rate,
                          dilation=rate, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            ))
        # Image-level (global average pooling) branch
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        # Fuse the concatenated branches: 1x1 branch + len(rates) atrous
        # branches + global branch = len(rates) + 2 inputs
        self.output = nn.Sequential(
            nn.Conv2d(out_channels * (len(rates) + 2), out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5)
        )

    def forward(self, x):
        size = x.size()
        outputs = [module(x) for module in self.aspp]
        # Global average pooling, upsampled back to the input resolution
        global_feat = self.global_avg_pool(x)
        global_feat = F.interpolate(global_feat, size=size[2:],
                                    mode='bilinear', align_corners=True)
        outputs.append(global_feat)
        # Concatenate all branches along the channel dimension
        x = torch.cat(outputs, dim=1)
        return self.output(x)


class DeepLabV3Plus(nn.Module):
    """DeepLabV3+ model for semantic segmentation of large-scale remote sensing imagery."""

    def __init__(self, num_classes, backbone='resnet50', output_stride=16, pretrained=True):
        super(DeepLabV3Plus, self).__init__()
        # Both the atrous rates and the backbone dilation depend on the output stride
        if output_stride == 16:
            rates = [6, 12, 18]
            replace_stride = [False, False, True]
        elif output_stride == 8:
            rates = [12, 24, 36]
            replace_stride = [False, True, True]
        else:
            raise ValueError("output_stride must be 8 or 16!")

        # Backbone; replace_stride_with_dilation swaps the strided stages for
        # dilated convolutions so the backbone actually produces features at
        # the requested output stride instead of a fixed stride of 32
        if backbone == 'resnet50':
            weights = ResNet50_Weights.DEFAULT if pretrained else None
            self.backbone = resnet50(weights=weights,
                                     replace_stride_with_dilation=replace_stride)
            self.low_level_features = self.backbone.layer1
            self.high_level_features = nn.Sequential(
                self.backbone.layer2,
                self.backbone.layer3,
                self.backbone.layer4
            )
            low_level_channels = 256
            high_level_channels = 2048
        else:
            raise ValueError(f"Unsupported backbone: {backbone}")

        # ASPP module on top of the high-level features
        self.aspp = ASPP(high_level_channels, 256, rates)

        # Low-level feature projection (reduced to 48 channels)
        self.low_level_conv = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True)
        )

        # Decoder: 256 (ASPP) + 48 (low-level) = 304 input channels
        self.decoder = nn.Sequential(
            nn.Conv2d(304, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Conv2d(256, num_classes, 1)
        )

        # Initialize only the newly added heads; initializing every module
        # would overwrite the pretrained backbone weights
        self._init_weight()

    def forward(self, x):
        size = x.size()
        # Backbone stem
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)
        # Low-level features (output stride 4)
        low_level_feat = self.low_level_features(x)
        # High-level features and ASPP
        x = self.high_level_features(low_level_feat)
        x = self.aspp(x)
        # Upsample the ASPP output to the low-level feature resolution
        x = F.interpolate(x, size=low_level_feat.size()[2:],
                          mode='bilinear', align_corners=True)
        # Project the low-level features and fuse
        low_level_feat = self.low_level_conv(low_level_feat)
        x = torch.cat((x, low_level_feat), dim=1)
        # Decode and upsample to the input resolution
        x = self.decoder(x)
        x = F.interpolate(x, size=size[2:], mode='bilinear', align_corners=True)
        return x

    def _init_weight(self):
        # Decoder-side modules only; the backbone keeps its pretrained weights
        for head in (self.aspp, self.low_level_conv, self.decoder):
            for m in head.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def get_backbone_params(self):
        modules = [self.backbone.conv1, self.backbone.bn1,
                   self.backbone.layer1, self.backbone.layer2,
                   self.backbone.layer3, self.backbone.layer4]
        for module in modules:
            for m in module.modules():
                if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                    for p in m.parameters():
                        if p.requires_grad:
                            yield p

    def get_decoder_params(self):
        for module in (self.aspp, self.low_level_conv, self.decoder):
            for m in module.modules():
                if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                    for p in m.parameters():
                        if p.requires_grad:
                            yield p

    def freeze_backbone(self):
        for param in self.backbone.parameters():
            param.requires_grad = False

    def unfreeze_backbone(self):
        for param in self.backbone.parameters():
            param.requires_grad = True
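

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): builds the model, pairs the parameter-group
# helpers above with an optimizer that fine-tunes the pretrained backbone at a
# lower learning rate than the freshly initialized decoder, and runs a dummy
# forward pass. The class count (6), tile size (512x512), and learning rates
# are placeholder values for demonstration, not values prescribed by the model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = DeepLabV3Plus(num_classes=6, backbone='resnet50',
                          output_stride=16, pretrained=True)

    # Differential learning rates via the parameter-group generators
    optimizer = torch.optim.SGD([
        {'params': model.get_backbone_params(), 'lr': 1e-3},
        {'params': model.get_decoder_params(), 'lr': 1e-2},
    ], momentum=0.9, weight_decay=1e-4)

    # Dummy batch: 2 RGB tiles of 512x512 pixels
    dummy = torch.randn(2, 3, 512, 512)
    logits = model(dummy)
    print(logits.shape)  # torch.Size([2, 6, 512, 512])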