# Source listing metadata: 150 lines, 6.4 KiB, Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torchvision.models import resnet34, ResNet34_Weights
|
|
|
|
class ConvBlock(nn.Module):
    """Two stacked 3x3 convolution stages, each followed by BatchNorm and ReLU.

    Both convolutions use ``padding=1`` with the default stride, so the
    spatial size of the input is preserved; only the channel count changes
    (``in_channels`` -> ``out_channels``).

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The convolutions carry no bias: the BatchNorm layer that follows
        # each of them has its own learnable shift.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply the conv -> BN -> ReLU stack twice and return the result."""
        return self.conv(x)
|
|
|
|
class UNetPlusPlus(nn.Module):
    """UNet++ model for semantic segmentation of large-scale remote-sensing images.

    Encoder level 0 is a single 3x3 conv stage at the input resolution;
    levels 1-4 reuse ``layer1`` .. ``layer4`` of a torchvision ResNet34.
    The decoder is the dense UNet++ grid: node ``x{i}_{j}`` consumes every
    earlier node on level ``i`` plus the upsampled node ``x{i+1}_{j-1}``.

    Args:
        num_classes: Number of output channels of the 1x1 classification head(s).
        in_channels: Number of channels of the input image.
        deep_supervision: If True, ``forward`` returns the logits of all four
            level-0 decoder nodes; otherwise only those of the last node.
        pretrained: If True, the ResNet34 layers are created with
            ``ResNet34_Weights.DEFAULT`` and are left untouched by the
            weight initialisation.
    """

    def __init__(self, num_classes: int, in_channels: int = 3,
                 deep_supervision: bool = False, pretrained: bool = True):
        super().__init__()
        self.deep_supervision = deep_supervision

        # Use a (optionally pretrained) ResNet34 as the encoder.
        weights = ResNet34_Weights.DEFAULT if pretrained else None
        backbone = resnet34(weights=weights)

        # Encoder
        self.encoder0 = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        # ResNet's layer1 does not downsample, but the decoder upsamples
        # level 1 by 2x before fusing it with level 0. A parameter-free pool
        # in front of layer1 provides the missing stride without changing
        # any state_dict key.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder1 = nn.Sequential(backbone.layer1)
        self.encoder2 = nn.Sequential(backbone.layer2)
        self.encoder3 = nn.Sequential(backbone.layer3)
        self.encoder4 = nn.Sequential(backbone.layer4)

        # Decoder. Input channels = (number of same-level skips) * level width
        # + width of the upsampled node coming from the level below.
        self.decoder0_1 = ConvBlock(64 + 64, 64)
        self.decoder1_1 = ConvBlock(64 + 128, 64)
        self.decoder2_1 = ConvBlock(128 + 256, 128)
        self.decoder3_1 = ConvBlock(256 + 512, 256)

        self.decoder0_2 = ConvBlock(64 * 2 + 64, 64)
        self.decoder1_2 = ConvBlock(64 * 2 + 128, 64)
        self.decoder2_2 = ConvBlock(128 * 2 + 256, 128)

        self.decoder0_3 = ConvBlock(64 * 3 + 64, 64)
        self.decoder1_3 = ConvBlock(64 * 3 + 128, 64)

        self.decoder0_4 = ConvBlock(64 * 4 + 64, 64)

        # Output layer(s)
        if self.deep_supervision:
            self.final1 = nn.Conv2d(64, num_classes, 1)
            self.final2 = nn.Conv2d(64, num_classes, 1)
            self.final3 = nn.Conv2d(64, num_classes, 1)
            self.final4 = nn.Conv2d(64, num_classes, 1)
        else:
            self.final = nn.Conv2d(64, num_classes, 1)

        # Initialise weights; keep the loaded ResNet34 weights when pretrained.
        self._init_weight(skip_backbone=pretrained)

    @staticmethod
    def _up_cat(skips, deeper):
        """Upsample ``deeper`` to the size of ``skips`` and concatenate on channels."""
        # Resizing to the skip's exact size (instead of a fixed 2x factor)
        # keeps the concat valid when a dimension was rounded on the way down.
        upsampled = F.interpolate(deeper, size=skips[0].shape[2:],
                                  mode='bilinear', align_corners=True)
        return torch.cat([*skips, upsampled], 1)

    def forward(self, x):
        """Run the network.

        Returns:
            A list ``[output1, output2, output3, output4]`` of logits when
            ``deep_supervision`` is enabled, otherwise a single logits tensor.
        """
        # Encoder
        x0_0 = self.encoder0(x)
        x1_0 = self.encoder1(self.pool(x0_0))
        x2_0 = self.encoder2(x1_0)
        x3_0 = self.encoder3(x2_0)
        x4_0 = self.encoder4(x3_0)

        # Decoder, first column
        x0_1 = self.decoder0_1(self._up_cat([x0_0], x1_0))
        x1_1 = self.decoder1_1(self._up_cat([x1_0], x2_0))
        x2_1 = self.decoder2_1(self._up_cat([x2_0], x3_0))
        x3_1 = self.decoder3_1(self._up_cat([x3_0], x4_0))

        # Second column
        x0_2 = self.decoder0_2(self._up_cat([x0_0, x0_1], x1_1))
        x1_2 = self.decoder1_2(self._up_cat([x1_0, x1_1], x2_1))
        x2_2 = self.decoder2_2(self._up_cat([x2_0, x2_1], x3_1))

        # Third column
        x0_3 = self.decoder0_3(self._up_cat([x0_0, x0_1, x0_2], x1_2))
        x1_3 = self.decoder1_3(self._up_cat([x1_0, x1_1, x1_2], x2_2))

        # Fourth column
        x0_4 = self.decoder0_4(self._up_cat([x0_0, x0_1, x0_2, x0_3], x1_3))

        # Output
        if self.deep_supervision:
            output1 = self.final1(x0_1)
            output2 = self.final2(x0_2)
            output3 = self.final3(x0_3)
            output4 = self.final4(x0_4)
            return [output1, output2, output3, output4]
        return self.final(x0_4)

    def _init_weight(self, skip_backbone=False):
        """Kaiming-initialise conv weights and reset BatchNorm affine parameters.

        Args:
            skip_backbone: If True, ``encoder1`` .. ``encoder4`` (the ResNet34
                layers) are left as they are so loaded weights are not
                overwritten.
        """
        resnet_children = {"encoder1", "encoder2", "encoder3", "encoder4"}
        for name, child in self.named_children():
            if skip_backbone and name in resnet_children:
                continue
            for m in child.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.ones_(m.weight)
                    nn.init.zeros_(m.bias)

    def _backbone_modules(self):
        """Return the five encoder stages, shallowest first."""
        return [self.encoder0, self.encoder1, self.encoder2, self.encoder3, self.encoder4]

    @staticmethod
    def _trainable_conv_bn_params(modules):
        """Yield trainable parameters of every Conv2d / BatchNorm2d inside ``modules``."""
        for module in modules:
            for layer in module.modules():
                if isinstance(layer, (nn.Conv2d, nn.BatchNorm2d)):
                    for p in layer.parameters():
                        if p.requires_grad:
                            yield p

    def get_backbone_params(self):
        """Yield the trainable Conv2d/BatchNorm2d parameters of the encoder stages."""
        yield from self._trainable_conv_bn_params(self._backbone_modules())

    def get_decoder_params(self):
        """Yield the trainable Conv2d/BatchNorm2d parameters of the decoder and head(s)."""
        modules = [self.decoder0_1, self.decoder1_1, self.decoder2_1, self.decoder3_1,
                   self.decoder0_2, self.decoder1_2, self.decoder2_2,
                   self.decoder0_3, self.decoder1_3,
                   self.decoder0_4]
        if self.deep_supervision:
            modules.extend([self.final1, self.final2, self.final3, self.final4])
        else:
            modules.append(self.final)
        yield from self._trainable_conv_bn_params(modules)

    def freeze_backbone(self):
        """Disable gradients for every encoder parameter.

        NOTE: only ``requires_grad`` is toggled; the modules' train/eval mode
        is not changed here.
        """
        for module in self._backbone_modules():
            for param in module.parameters():
                param.requires_grad = False

    def unfreeze_backbone(self):
        """Re-enable gradients for every encoder parameter."""
        for module in self._backbone_modules():
            for param in module.parameters():
                param.requires_grad = True