import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet50_Weights


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling (ASPP) module."""
    def __init__(self, in_channels, out_channels, rates):
        super(ASPP, self).__init__()
        self.aspp = nn.ModuleList()
        # 1x1 convolution branch
        self.aspp.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        ))
        # Atrous (dilated) 3x3 convolution branches, one per rate
        for rate in rates:
            self.aspp.append(nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, padding=rate, dilation=rate, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            ))
        # Image-level branch: global average pooling
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        # Fuse the concatenated branches back down to out_channels
        self.output = nn.Sequential(
            nn.Conv2d(out_channels * (len(rates) + 2), out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5)
        )

    def forward(self, x):
        size = x.size()
        outputs = [module(x) for module in self.aspp]
        # Global-average branch, upsampled back to the input spatial size
        global_feat = self.global_avg_pool(x)
        global_feat = F.interpolate(global_feat, size=(size[2], size[3]), mode='bilinear', align_corners=True)
        outputs.append(global_feat)
        # Concatenate all branches along the channel dimension
        x = torch.cat(outputs, dim=1)
        return self.output(x)
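
# A minimal shape check for ASPP in isolation (illustrative: the channel sizes
# and rates mirror the DeepLabV3+ defaults below, but any values would work).
# eval() avoids BatchNorm's single-value-per-channel error on the 1x1 pooled
# branch when the batch size is 1; the guard keeps it from running on import.
if __name__ == "__main__":
    _aspp = ASPP(in_channels=2048, out_channels=256, rates=[6, 12, 18]).eval()
    with torch.no_grad():
        assert _aspp(torch.randn(1, 2048, 32, 32)).shape == (1, 256, 32, 32)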


class DeepLabV3Plus(nn.Module):
    """DeepLabV3+ model for semantic segmentation of large-scale remote-sensing imagery."""
    def __init__(self, num_classes, backbone='resnet50', output_stride=16, pretrained=True):
        super(DeepLabV3Plus, self).__init__()
        # Atrous rates and backbone dilation depend on the output stride
        if output_stride == 16:
            rates = [6, 12, 18]
            replace_stride = [False, False, True]   # dilate layer4 only
        elif output_stride == 8:
            rates = [12, 24, 36]
            replace_stride = [False, True, True]    # dilate layer3 and layer4
        else:
            raise ValueError("output_stride must be 8 or 16!")
        # Backbone
        if backbone == 'resnet50':
            weights = ResNet50_Weights.DEFAULT if pretrained else None
            self.backbone = resnet50(weights=weights,
                                     replace_stride_with_dilation=replace_stride)
            self.low_level_features = self.backbone.layer1
            self.high_level_features = nn.Sequential(
                self.backbone.layer2,
                self.backbone.layer3,
                self.backbone.layer4
            )
            low_level_channels = 256
            high_level_channels = 2048
        else:
            raise ValueError(f"Unsupported backbone: {backbone}")
        # ASPP module
        self.aspp = ASPP(high_level_channels, 256, rates)
        # Low-level feature projection (reduced to 48 channels, as in the paper)
        self.low_level_conv = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True)
        )
        # Decoder: refine the concatenated features (256 + 48 = 304 channels)
        self.decoder = nn.Sequential(
            nn.Conv2d(304, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Conv2d(256, num_classes, 1)
        )
        # Initialize the newly added heads (the backbone keeps its pretrained weights)
        self._init_weight()

    def forward(self, x):
        size = x.size()
        # Backbone stem
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)
        # Low-level features (output stride 4)
        low_level_feat = self.low_level_features(x)
        # High-level features followed by ASPP
        x = self.high_level_features(low_level_feat)
        x = self.aspp(x)
        # Upsample the high-level features to the low-level resolution
        x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True)
        # Project the low-level features and concatenate
        low_level_feat = self.low_level_conv(low_level_feat)
        x = torch.cat((x, low_level_feat), dim=1)
        # Decode
        x = self.decoder(x)
        # Upsample to the input resolution
        x = F.interpolate(x, size=size[2:], mode='bilinear', align_corners=True)
        return x

    def _init_weight(self):
        # Initialize only the newly added heads; iterating over self.modules()
        # here would also re-initialize (and destroy) the pretrained backbone.
        for module in (self.aspp, self.low_level_conv, self.decoder):
            for m in module.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def get_backbone_params(self):
        """Yield backbone parameters (typically trained with a smaller learning rate)."""
        modules = [self.backbone.conv1, self.backbone.bn1, self.backbone.layer1,
                   self.backbone.layer2, self.backbone.layer3, self.backbone.layer4]
        for module in modules:
            for m in module.modules():
                if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                    for p in m.parameters():
                        if p.requires_grad:
                            yield p

    def get_decoder_params(self):
        """Yield parameters of the newly added heads (ASPP, low-level projection, decoder)."""
        for module in (self.aspp, self.low_level_conv, self.decoder):
            for m in module.modules():
                if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                    for p in m.parameters():
                        if p.requires_grad:
                            yield p

    def freeze_backbone(self):
        for param in self.backbone.parameters():
            param.requires_grad = False

    def unfreeze_backbone(self):
        for param in self.backbone.parameters():
            param.requires_grad = True
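
# Minimal usage sketch (illustrative only: num_classes, the input size, and the
# learning rates are arbitrary example values, not part of the original file).
if __name__ == "__main__":
    model = DeepLabV3Plus(num_classes=6, output_stride=16, pretrained=False).eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 512, 512))   # N x C x H x W
    print(logits.shape)  # expected: torch.Size([1, 6, 512, 512])

    # Typical fine-tuning setup: a smaller learning rate for the pretrained
    # backbone than for the freshly initialized ASPP/decoder heads.
    optimizer = torch.optim.SGD([
        {'params': model.get_backbone_params(), 'lr': 1e-4},
        {'params': model.get_decoder_params(), 'lr': 1e-3},
    ], momentum=0.9, weight_decay=1e-4)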