2025-07-10 09:41:26 +08:00

150 lines
6.4 KiB
Python

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet34, ResNet34_Weights
class ConvBlock(nn.Module):
    """Two stacked (3x3 Conv2d -> BatchNorm2d -> ReLU) stages.

    Both convolutions use padding=1 and stride 1, so the spatial size of the
    input is preserved; only the channel count changes (in -> out -> out).
    The convolutions have no bias because each is followed by BatchNorm.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second out -> out.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # Kept as a single nn.Sequential named `conv` so state_dict keys
        # (conv.0.weight, conv.1.weight, ...) stay stable.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        out = self.conv(x)
        return out
class UNetPlusPlus(nn.Module):
    """UNet++ model for semantic segmentation of large-scale remote-sensing imagery.

    Encoder: a freshly initialized full-resolution 3x3 stem (``encoder0``)
    followed by the four residual stages of a torchvision ResNet34
    (``encoder1``..``encoder4``, 64/128/256/512 channels). Decoder: the UNet++
    nested dense skip pathways ``decoderI_J``.

    Node naming follows the UNet++ paper: ``xI_J`` is the node at resolution
    level ``I`` and column ``J`` (column 0 is the encoder). Every node
    concatenates all earlier nodes of its level with the upsampled node
    ``x(I+1)_(J-1)`` from the level below. Upsampling targets the exact size of
    the skip tensor, so input height/width need not be multiples of 16.

    Args:
        num_classes: Output channels of the 1x1 prediction head(s).
        in_channels: Channels of the input image. Only ``encoder0`` depends on
            it, so inputs other than RGB can be used.
        deep_supervision: If True, ``forward`` returns four logit maps, one per
            top-row node x0_1..x0_4; otherwise it returns one map from x0_4.
        pretrained: Load ``ResNet34_Weights.DEFAULT`` into the ResNet34 stages.
    """

    def __init__(self, num_classes, in_channels=3, deep_supervision=False, pretrained=True):
        super().__init__()
        self.deep_supervision = deep_supervision

        # ResNet34 backbone: only its residual stages layer1..layer4 are reused.
        weights = ResNet34_Weights.DEFAULT if pretrained else None
        backbone = resnet34(weights=weights)

        # ---- Encoder ----
        # Full-resolution stem used instead of the ResNet stem, so x0_0 keeps the
        # input resolution and any number of input channels is accepted.
        self.encoder0 = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        # ResNet34's layer1 does not downsample. Without this pool, x1_0 would
        # have the same size as x0_0 and the level-0 concatenations could not
        # line up. MaxPool2d has no parameters, so state_dict keys are unchanged.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # nn.Sequential wrappers keep the original state_dict key layout.
        self.encoder1 = nn.Sequential(backbone.layer1)  # x1_0: 64 ch
        self.encoder2 = nn.Sequential(backbone.layer2)  # x2_0: 128 ch
        self.encoder3 = nn.Sequential(backbone.layer3)  # x3_0: 256 ch
        self.encoder4 = nn.Sequential(backbone.layer4)  # x4_0: 512 ch

        # ---- Nested decoder ----
        # in_channels = sum(same-level predecessors) + channels of the upsampled
        # node from the level below (see the torch.cat calls in forward()).
        self.decoder0_1 = ConvBlock(64 + 64, 64)
        self.decoder1_1 = ConvBlock(64 + 128, 64)
        self.decoder2_1 = ConvBlock(128 + 256, 128)
        self.decoder3_1 = ConvBlock(256 + 512, 256)
        self.decoder0_2 = ConvBlock(64 + 64 + 64, 64)             # x0_0, x0_1, up(x1_1)
        self.decoder1_2 = ConvBlock(64 + 64 + 128, 64)
        self.decoder2_2 = ConvBlock(128 + 128 + 256, 128)
        self.decoder0_3 = ConvBlock(64 + 64 + 64 + 64, 64)        # x0_0..x0_2, up(x1_2)
        self.decoder1_3 = ConvBlock(64 + 64 + 64 + 128, 64)
        self.decoder0_4 = ConvBlock(64 + 64 + 64 + 64 + 64, 64)   # x0_0..x0_3, up(x1_3)

        # ---- Output head(s): 1x1 convs producing raw logits ----
        if self.deep_supervision:
            self.final1 = nn.Conv2d(64, num_classes, 1)
            self.final2 = nn.Conv2d(64, num_classes, 1)
            self.final3 = nn.Conv2d(64, num_classes, 1)
            self.final4 = nn.Conv2d(64, num_classes, 1)
        else:
            self.final = nn.Conv2d(64, num_classes, 1)

        self._init_weight()

    def forward(self, x):
        """Compute segmentation logits (no activation is applied).

        Args:
            x: Input batch, presumably (N, in_channels, H, W) since it feeds a
                Conv2d directly.

        Returns:
            One logit map with ``num_classes`` channels at the input resolution.
            With deep supervision, a list of four such maps computed from x0_1,
            x0_2, x0_3 and x0_4, in that order.
        """
        up = self._upsample_to

        # Encoder (column 0). Each level halves the resolution of the previous one.
        x0_0 = self.encoder0(x)
        x1_0 = self.encoder1(self.pool(x0_0))
        x2_0 = self.encoder2(x1_0)
        x3_0 = self.encoder3(x2_0)
        x4_0 = self.encoder4(x3_0)

        # Column 1
        x0_1 = self.decoder0_1(torch.cat([x0_0, up(x1_0, x0_0)], 1))
        x1_1 = self.decoder1_1(torch.cat([x1_0, up(x2_0, x1_0)], 1))
        x2_1 = self.decoder2_1(torch.cat([x2_0, up(x3_0, x2_0)], 1))
        x3_1 = self.decoder3_1(torch.cat([x3_0, up(x4_0, x3_0)], 1))
        # Column 2
        x0_2 = self.decoder0_2(torch.cat([x0_0, x0_1, up(x1_1, x0_0)], 1))
        x1_2 = self.decoder1_2(torch.cat([x1_0, x1_1, up(x2_1, x1_0)], 1))
        x2_2 = self.decoder2_2(torch.cat([x2_0, x2_1, up(x3_1, x2_0)], 1))
        # Column 3
        x0_3 = self.decoder0_3(torch.cat([x0_0, x0_1, x0_2, up(x1_2, x0_0)], 1))
        x1_3 = self.decoder1_3(torch.cat([x1_0, x1_1, x1_2, up(x2_2, x1_0)], 1))
        # Column 4
        x0_4 = self.decoder0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, up(x1_3, x0_0)], 1))

        if self.deep_supervision:
            return [
                self.final1(x0_1),
                self.final2(x0_2),
                self.final3(x0_3),
                self.final4(x0_4),
            ]
        return self.final(x0_4)

    @staticmethod
    def _upsample_to(x, ref):
        """Bilinearly resize ``x`` to the spatial size of ``ref``.

        Targeting the exact skip size, rather than using ``scale_factor=2``,
        keeps the concatenations valid when a pooling or stride-2 stage rounded
        an odd size. With align_corners=True the result is identical to 2x
        scaling when the sizes divide evenly.
        """
        return F.interpolate(x, size=ref.shape[-2:], mode='bilinear', align_corners=True)

    def _backbone_modules(self):
        """Return the encoder modules that the param-group and freeze helpers treat as backbone."""
        return [self.encoder0, self.encoder1, self.encoder2, self.encoder3, self.encoder4]

    def _decoder_modules(self):
        """Return the decoder blocks plus the output head(s) that exist for this configuration."""
        modules = [
            self.decoder0_1, self.decoder1_1, self.decoder2_1, self.decoder3_1,
            self.decoder0_2, self.decoder1_2, self.decoder2_2,
            self.decoder0_3, self.decoder1_3,
            self.decoder0_4,
        ]
        if self.deep_supervision:
            modules.extend([self.final1, self.final2, self.final3, self.final4])
        else:
            modules.append(self.final)
        return modules

    @staticmethod
    def _trainable_conv_bn_params(modules):
        """Yield trainable parameters of every Conv2d/BatchNorm2d nested in ``modules``."""
        for module in modules:
            for m in module.modules():
                if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                    for p in m.parameters():
                        if p.requires_grad:
                            yield p

    def _init_weight(self):
        """Initialize only the layers this class creates itself.

        The ResNet34 stages (encoder1..encoder4) are deliberately skipped. They
        already hold either the pretrained weights or torchvision's own
        initialization, and re-initializing them would silently discard
        ``pretrained=True``.
        """
        for module in (self.encoder0, *self._decoder_modules()):
            for m in module.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.ones_(m.weight)
                    nn.init.zeros_(m.bias)

    def get_backbone_params(self):
        """Yield trainable Conv2d/BatchNorm2d parameters of encoder0..encoder4.

        NOTE(review): encoder0 is freshly initialized, not pretrained, yet it is
        grouped here with the pretrained stages (e.g. for a lower learning
        rate). Confirm this grouping is intended.
        """
        yield from self._trainable_conv_bn_params(self._backbone_modules())

    def get_decoder_params(self):
        """Yield trainable Conv2d/BatchNorm2d parameters of the decoder and output head(s)."""
        yield from self._trainable_conv_bn_params(self._decoder_modules())

    def freeze_backbone(self):
        """Set requires_grad=False on every parameter of encoder0..encoder4.

        NOTE(review): this also freezes the randomly initialized encoder0. It
        does not stop BatchNorm running statistics from updating in train mode;
        call .eval() on those modules if that is also wanted.
        """
        for module in self._backbone_modules():
            for param in module.parameters():
                param.requires_grad = False

    def unfreeze_backbone(self):
        """Set requires_grad=True on every parameter of encoder0..encoder4."""
        for module in self._backbone_modules():
            for param in module.parameters():
                param.requires_grad = True