import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet34, ResNet34_Weights


class ConvBlock(nn.Module):
    """Basic decoder block: two 3x3 conv + BN + ReLU layers."""

    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


class UNetPlusPlus(nn.Module):
    """UNet++ model for semantic segmentation of large-scale remote sensing images.

    Input height and width should be divisible by 16 so that the four encoder
    downsampling stages and the scale_factor=2 upsampling in the decoder align exactly.
    """

    def __init__(self, num_classes, in_channels=3, deep_supervision=False, pretrained=True):
        super(UNetPlusPlus, self).__init__()
        self.deep_supervision = deep_supervision

        # Use an (optionally pretrained) ResNet34 as the encoder
        if pretrained:
            backbone = resnet34(weights=ResNet34_Weights.DEFAULT)
        else:
            backbone = resnet34(weights=None)

        # Encoder. The ResNet stem (7x7 conv + maxpool) is replaced by a 3x3 conv so the
        # first skip connection keeps full resolution; a 2x2 max-pool is placed before
        # layer1 so that every encoder stage halves the resolution, which is what the
        # scale_factor=2 upsampling in the decoder assumes.
        self.encoder0 = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.encoder1 = nn.Sequential(nn.MaxPool2d(2, 2), backbone.layer1)  # 64 ch,  1/2
        self.encoder2 = nn.Sequential(backbone.layer2)                      # 128 ch, 1/4
        self.encoder3 = nn.Sequential(backbone.layer3)                      # 256 ch, 1/8
        self.encoder4 = nn.Sequential(backbone.layer4)                      # 512 ch, 1/16

        # Decoder (nested skip connections). The input channels of each block equal the
        # sum of the channels of all feature maps concatenated at that node.
        self.decoder0_1 = ConvBlock(64 + 64, 64)
        self.decoder1_1 = ConvBlock(64 + 128, 64)
        self.decoder2_1 = ConvBlock(128 + 256, 128)
        self.decoder3_1 = ConvBlock(256 + 512, 256)

        self.decoder0_2 = ConvBlock(64 + 64 + 64, 64)
        self.decoder1_2 = ConvBlock(64 + 64 + 128, 64)
        self.decoder2_2 = ConvBlock(128 + 128 + 256, 128)

        self.decoder0_3 = ConvBlock(64 + 64 + 64 + 64, 64)
        self.decoder1_3 = ConvBlock(64 + 64 + 64 + 128, 64)

        self.decoder0_4 = ConvBlock(64 + 64 + 64 + 64 + 64, 64)

        # Output heads
        if self.deep_supervision:
            self.final1 = nn.Conv2d(64, num_classes, 1)
            self.final2 = nn.Conv2d(64, num_classes, 1)
            self.final3 = nn.Conv2d(64, num_classes, 1)
            self.final4 = nn.Conv2d(64, num_classes, 1)
        else:
            self.final = nn.Conv2d(64, num_classes, 1)

        # Initialize the weights of the newly added layers
        self._init_weight()

    def forward(self, x):
        # Encoder
        x0_0 = self.encoder0(x)
        x1_0 = self.encoder1(x0_0)
        x2_0 = self.encoder2(x1_0)
        x3_0 = self.encoder3(x2_0)
        x4_0 = self.encoder4(x3_0)

        # Decoder
        x0_1 = self.decoder0_1(torch.cat([x0_0, F.interpolate(x1_0, scale_factor=2, mode='bilinear', align_corners=True)], 1))
        x1_1 = self.decoder1_1(torch.cat([x1_0, F.interpolate(x2_0, scale_factor=2, mode='bilinear', align_corners=True)], 1))
        x2_1 = self.decoder2_1(torch.cat([x2_0, F.interpolate(x3_0, scale_factor=2, mode='bilinear', align_corners=True)], 1))
        x3_1 = self.decoder3_1(torch.cat([x3_0, F.interpolate(x4_0, scale_factor=2, mode='bilinear', align_corners=True)], 1))

        x0_2 = self.decoder0_2(torch.cat([x0_0, x0_1, F.interpolate(x1_1, scale_factor=2, mode='bilinear', align_corners=True)], 1))
        x1_2 = self.decoder1_2(torch.cat([x1_0, x1_1, F.interpolate(x2_1, scale_factor=2, mode='bilinear', align_corners=True)], 1))
        x2_2 = self.decoder2_2(torch.cat([x2_0, x2_1, F.interpolate(x3_1, scale_factor=2, mode='bilinear', align_corners=True)], 1))

        x0_3 = self.decoder0_3(torch.cat([x0_0, x0_1, x0_2, F.interpolate(x1_2, scale_factor=2, mode='bilinear', align_corners=True)], 1))
        x1_3 = self.decoder1_3(torch.cat([x1_0, x1_1, x1_2, F.interpolate(x2_2, scale_factor=2, mode='bilinear', align_corners=True)], 1))

        x0_4 = self.decoder0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, F.interpolate(x1_3, scale_factor=2, mode='bilinear', align_corners=True)], 1))

        # Output
        if self.deep_supervision:
            output1 = self.final1(x0_1)
            output2 = self.final2(x0_2)
            output3 = self.final3(x0_3)
            output4 = self.final4(x0_4)
            return [output1, output2, output3, output4]
        else:
            output = self.final(x0_4)
            return output

    def _init_weight(self):
        # Only initialize the layers that do not come from the pretrained backbone;
        # re-initializing encoder1..encoder4 would discard the loaded ResNet34 weights.
        if self.deep_supervision:
            heads = [self.final1, self.final2, self.final3, self.final4]
        else:
            heads = [self.final]
        modules = [self.encoder0,
                   self.decoder0_1, self.decoder1_1, self.decoder2_1, self.decoder3_1,
                   self.decoder0_2, self.decoder1_2, self.decoder2_2,
                   self.decoder0_3, self.decoder1_3, self.decoder0_4] + heads
        for module in modules:
            for m in module.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight)
                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()

    def get_backbone_params(self):
        # Encoder parameters, e.g. for a smaller learning rate when fine-tuning.
        modules = [self.encoder0, self.encoder1, self.encoder2, self.encoder3, self.encoder4]
        for module in modules:
            for m in module.modules():
                if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                    for p in m.parameters():
                        if p.requires_grad:
                            yield p

    def get_decoder_params(self):
        # Decoder and output-head parameters, trained from scratch.
        modules = [self.decoder0_1, self.decoder1_1, self.decoder2_1, self.decoder3_1,
                   self.decoder0_2, self.decoder1_2, self.decoder2_2,
                   self.decoder0_3, self.decoder1_3, self.decoder0_4]
        if self.deep_supervision:
            modules.extend([self.final1, self.final2, self.final3, self.final4])
        else:
            modules.append(self.final)
        for module in modules:
            for m in module.modules():
                if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                    for p in m.parameters():
                        if p.requires_grad:
                            yield p

    def freeze_backbone(self):
        modules = [self.encoder0, self.encoder1, self.encoder2, self.encoder3, self.encoder4]
        for module in modules:
            for param in module.parameters():
                param.requires_grad = False

    def unfreeze_backbone(self):
        modules = [self.encoder0, self.encoder1, self.encoder2, self.encoder3, self.encoder4]
        for module in modules:
            for param in module.parameters():
                param.requires_grad = True
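

# --- Usage sketch (not part of the original module; batch size, tile size, class count
# --- and learning rates are illustrative assumptions). It shows a quick shape check on
# --- a dummy batch and how the parameter-group helpers can feed an optimizer with a
# --- smaller learning rate for the pretrained encoder than for the decoder.
if __name__ == "__main__":
    # pretrained=False avoids downloading ImageNet weights for this smoke test
    model = UNetPlusPlus(num_classes=6, in_channels=3, deep_supervision=False, pretrained=False)

    # Forward a dummy batch of 512x512 tiles: the output is a per-pixel logit map
    # at the input resolution.
    dummy = torch.randn(2, 3, 512, 512)
    with torch.no_grad():
        logits = model(dummy)
    print(logits.shape)  # expected: torch.Size([2, 6, 512, 512])

    # Assumed fine-tuning setup: lower learning rate for the pretrained backbone,
    # higher for the decoder and output heads that are trained from scratch.
    optimizer = torch.optim.AdamW([
        {"params": model.get_backbone_params(), "lr": 1e-4},
        {"params": model.get_decoder_params(), "lr": 1e-3},
    ], weight_decay=1e-4)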