import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet50_Weights


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling (ASPP) module."""
    def __init__(self, in_channels, out_channels, rates):
        super(ASPP, self).__init__()

        self.aspp = nn.ModuleList()

        # 1x1 convolution branch
        self.aspp.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        ))

        # Atrous (dilated) convolution branches, one per rate
        for rate in rates:
            self.aspp.append(nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, padding=rate, dilation=rate, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            ))

        # Global average pooling branch
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        # Output projection over the concatenated branches
        self.output = nn.Sequential(
            nn.Conv2d(out_channels * (len(rates) + 2), out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5)
        )

    def forward(self, x):
        size = x.size()

        outputs = []
        for module in self.aspp:
            outputs.append(module(x))

        # Global average pooling branch, upsampled back to the input resolution
        global_feat = self.global_avg_pool(x)
        global_feat = F.interpolate(global_feat, size=(size[2], size[3]), mode='bilinear', align_corners=True)
        outputs.append(global_feat)

        # Concatenate all branches along the channel dimension
        x = torch.cat(outputs, dim=1)

        return self.output(x)


class DeepLabV3Plus(nn.Module):
    """DeepLabV3+ model for semantic segmentation of large-scale remote sensing imagery."""
    def __init__(self, num_classes, backbone='resnet50', output_stride=16, pretrained=True):
        super(DeepLabV3Plus, self).__init__()

        # Dilation rates for the atrous convolutions in ASPP
        if output_stride == 16:
            rates = [6, 12, 18]
        elif output_stride == 8:
            rates = [12, 24, 36]
        else:
            raise ValueError("output_stride must be 8 or 16!")

        # Backbone network
        if backbone == 'resnet50':
            if pretrained:
                self.backbone = resnet50(weights=ResNet50_Weights.DEFAULT)
            else:
                self.backbone = resnet50(weights=None)

            self.low_level_features = self.backbone.layer1
            self.high_level_features = nn.Sequential(
                self.backbone.layer2,
                self.backbone.layer3,
                self.backbone.layer4
            )
            low_level_channels = 256
            high_level_channels = 2048
        else:
            raise ValueError(f"Unsupported backbone: {backbone}")

        # ASPP module
        self.aspp = ASPP(high_level_channels, 256, rates)

        # Low-level feature projection (256 -> 48 channels)
        self.low_level_conv = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True)
        )

        # Decoder: fuses 256 ASPP channels with 48 low-level channels (304 in total)
        self.decoder = nn.Sequential(
            nn.Conv2d(304, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Conv2d(256, num_classes, 1)
        )

        # Initialize weights of the newly added modules
        self._init_weight()

    def forward(self, x):
        size = x.size()

        # Stem: initial convolution, batch norm, ReLU and max pooling
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        # Low-level features (backbone layer1)
        low_level_feat = self.low_level_features(x)

        # High-level features (layer2-layer4) followed by ASPP
        x = self.high_level_features(low_level_feat)
        x = self.aspp(x)

        # Upsample the ASPP output to the low-level feature resolution
        x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True)

        # Project the low-level features
        low_level_feat = self.low_level_conv(low_level_feat)

        # Concatenate the decoder inputs
        x = torch.cat((x, low_level_feat), dim=1)

        # Decode
        x = self.decoder(x)

        # Upsample to the original input size
        x = F.interpolate(x, size=size[2:], mode='bilinear', align_corners=True)

        return x

    def _init_weight(self):
        # Initialize only the newly added modules (ASPP, low-level projection, decoder),
        # so that optionally pretrained backbone weights are left untouched.
        for module in [self.aspp, self.low_level_conv, self.decoder]:
            for m in module.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight)
                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()

    def get_backbone_params(self):
        """Yield trainable backbone parameters (typically trained with a smaller learning rate)."""
        modules = [self.backbone.conv1, self.backbone.bn1, self.backbone.layer1,
                   self.backbone.layer2, self.backbone.layer3, self.backbone.layer4]
        for module in modules:
            for _, m in module.named_modules():
                if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                    for p in m.parameters():
                        if p.requires_grad:
                            yield p

    def get_decoder_params(self):
        """Yield trainable parameters of the ASPP module and the decoder head."""
        modules = [self.aspp, self.low_level_conv, self.decoder]
        for module in modules:
            for _, m in module.named_modules():
                if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                    for p in m.parameters():
                        if p.requires_grad:
                            yield p

    def freeze_backbone(self):
        for param in self.backbone.parameters():
            param.requires_grad = False

    def unfreeze_backbone(self):
        for param in self.backbone.parameters():
            param.requires_grad = True
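

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). The input size, class count
    # and learning rates below are illustrative assumptions, not values prescribed by this file.
    model = DeepLabV3Plus(num_classes=6, backbone='resnet50', output_stride=16, pretrained=False)
    model.eval()

    # Forward pass on a dummy 512x512 RGB image; logits come back at the input resolution.
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 512, 512))
    print(logits.shape)  # torch.Size([1, 6, 512, 512])

    # Typical fine-tuning setup: a smaller learning rate for the (pretrained) backbone
    # than for the randomly initialized ASPP/decoder heads.
    optimizer = torch.optim.SGD([
        {'params': model.get_backbone_params(), 'lr': 1e-3},
        {'params': model.get_decoder_params(), 'lr': 1e-2},
    ], momentum=0.9, weight_decay=1e-4)

    # Optionally train the decoder alone first, then unfreeze the backbone.
    model.freeze_backbone()
    # ... warm-up epochs ...
    model.unfreeze_backbone()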