522 lines
26 KiB
Python
522 lines
26 KiB
Python
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
|
||
from mmcv.cnn import ConvModule
|
||
from ..builder import HEADS
|
||
from .cascade_decode_head import BaseCascadeDecodeHead
|
||
from .shallow_head import ShallowNet
|
||
from ..losses import accuracy
|
||
from mmseg.models.utils.wrappers import resize
|
||
from mmseg.models.backbones.mambaunet import VSSM
|
||
|
||
|
||
class SegmentationHead(nn.Module):
    """Two-layer prediction head: a 3x3 ``ConvModule`` then a 1x1 conv.

    Args:
        conv_cfg: Convolution config forwarded to ``ConvModule``.
        norm_cfg: Normalisation config forwarded to ``ConvModule``.
        act_cfg: Activation config forwarded to ``ConvModule``.
        in_channels: Channels of the incoming feature map.
        mid_channels: Channels produced by the 3x3 ``ConvModule``.
        n_classes: Channels produced by the final 1x1 convolution.
        *args, **kwargs: Accepted for call-site compatibility and ignored
            (callers in this file pass ``kernel_size=1``; it has no effect).
    """

    def __init__(self, conv_cfg, norm_cfg, act_cfg, in_channels, mid_channels, n_classes, *args, **kwargs):
        super().__init__()

        # 3x3, stride 1, padding 1: spatial size is preserved.
        self.conv_bn_relu = ConvModule(
            in_channels,
            mid_channels,
            3,
            stride=1,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
        )
        # 1x1 projection to the requested number of output channels.
        self.conv_out = nn.Conv2d(mid_channels, n_classes, kernel_size=1, bias=True)

    def forward(self, x):
        """Apply the 3x3 block and then the 1x1 projection to ``x``."""
        hidden = self.conv_bn_relu(x)
        return self.conv_out(hidden)
|
||
|
||
|
||
class Lap_Pyramid_Conv(nn.Module):
    """Laplacian-pyramid decomposition using a fixed 5x5 binomial kernel.

    The kernel is stored as a non-persistent buffer, so it follows the module
    across ``.to()`` / ``.cuda()`` calls without adding a key to the
    ``state_dict``. It is created on the CPU, which means the module can be
    constructed on machines without CUDA.

    Args:
        num_high (int): Number of high-frequency (residual) levels returned by
            :meth:`pyramid_decom`.
        gauss_chl (int): Number of channels the depth-wise kernel is built for.
            It must equal the channel count of the tensors being filtered.
    """

    def __init__(self, num_high=3, gauss_chl=3):
        super(Lap_Pyramid_Conv, self).__init__()

        self.num_high = num_high
        self.gauss_chl = gauss_chl
        # persistent=False keeps existing checkpoints loadable (no new key).
        self.register_buffer('kernel', self.gauss_kernel(), persistent=False)

    def gauss_kernel(self, device=None):
        """Build the ``(gauss_chl, 1, 5, 5)`` depth-wise smoothing kernel.

        Args:
            device: Optional target device. ``None`` (default) leaves the
                kernel on the CPU.

        Returns:
            Tensor: The normalised kernel (entries sum to 1 per channel).
        """
        kernel = torch.tensor([[1., 4., 6., 4., 1.],
                               [4., 16., 24., 16., 4.],
                               [6., 24., 36., 24., 6.],
                               [4., 16., 24., 16., 4.],
                               [1., 4., 6., 4., 1.]])
        kernel /= 256.
        kernel = kernel.repeat(self.gauss_chl, 1, 1, 1)
        if device is not None:
            kernel = kernel.to(device)
        return kernel

    def downsample(self, x):
        """Keep every second row and column of ``x`` (N, C, H, W)."""
        return x[:, :, ::2, ::2]

    def upsample(self, x):
        """Double the spatial size by zero insertion followed by smoothing.

        Input values are placed at even row/column positions of a zero tensor
        twice as large, then filtered with ``4 * kernel``. The factor of 4
        compensates for the three quarters of positions that are zero.
        """
        batch, channels, height, width = x.shape
        x_up = x.new_zeros(batch, channels, height * 2, width * 2)
        x_up[:, :, ::2, ::2] = x
        return self.conv_gauss(x_up, 4 * self.kernel)

    def conv_gauss(self, img, kernel):
        """Depth-wise filter ``img`` with ``kernel`` using reflect padding.

        The kernel is aligned to the image's device and dtype first, so the
        call also works when the module was not moved together with its input
        or when the input is not float32.
        """
        kernel = kernel.to(device=img.device, dtype=img.dtype)
        img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode='reflect')
        out = torch.nn.functional.conv2d(img, kernel, groups=img.shape[1])
        return out

    def pyramid_decom(self, img):
        """Return the ``num_high`` high-frequency residuals of ``img``.

        Each level is ``current - upsample(downsample(blur(current)))``. The
        first entry has the resolution of ``img``; each following entry is
        half the size of the previous one. The final low-frequency image is
        not returned.
        """
        current = img
        pyr = []
        for _ in range(self.num_high):
            filtered = self.conv_gauss(current, self.kernel)
            down = self.downsample(filtered)
            up = self.upsample(down)
            # Odd sizes make the up-sampled tensor one pixel too large.
            if up.shape[2] != current.shape[2] or up.shape[3] != current.shape[3]:
                up = nn.functional.interpolate(up, size=(current.shape[2], current.shape[3]))
            diff = current - up
            pyr.append(diff)
            current = down
        return pyr
|
||
|
||
class SRDecoder(nn.Module):
    """Super-resolution decoder: three upsample + conv stages and an RGB head.

    Args:
        conv_cfg: Convolution config forwarded to every ``ConvModule``.
        norm_cfg: Normalisation config forwarded to every ``ConvModule``.
        act_cfg: Activation config forwarded to every ``ConvModule``.
        channels (int): Channel count of the input feature map and of the
            feature map returned when ``fa=True``.
        up_lists (Sequence[int]): Scale factors of the three ``nn.Upsample``
            stages, in order. Only the first three entries are used.
    """

    def __init__(self, conv_cfg, norm_cfg, act_cfg, channels=128, up_lists=(2, 2, 2)):
        super(SRDecoder, self).__init__()
        self.conv1 = ConvModule(channels, channels // 2, 3, stride=1, padding=1, conv_cfg=conv_cfg,
                                norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.up1 = nn.Upsample(scale_factor=up_lists[0])
        self.conv2 = ConvModule(channels // 2, channels // 2, 3, stride=1, padding=1, conv_cfg=conv_cfg,
                                norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.up2 = nn.Upsample(scale_factor=up_lists[1])
        self.conv3 = ConvModule(channels // 2, channels, 3, stride=1, padding=1, conv_cfg=conv_cfg,
                                norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.up3 = nn.Upsample(scale_factor=up_lists[2])
        # 3 output channels; the kernel_size keyword is ignored by
        # SegmentationHead and is kept only to mirror the original call.
        self.conv_sr = SegmentationHead(conv_cfg, norm_cfg, act_cfg, channels, channels // 2, 3, kernel_size=1)

    def forward(self, x, fa=False):
        """Decode ``x`` into a 3-channel image.

        Args:
            x (Tensor): Input features (author's example shape: 2x128x39x39).
            fa (bool): When True, also return the features fed to the head.

        Returns:
            Tensor or tuple[Tensor, Tensor]: ``outs`` alone, or
            ``(feats, outs)`` when ``fa`` is True (author's example shape for
            ``outs``: 2x3x624x624).
        """
        x = self.up1(x)
        x = self.conv1(x)
        x = self.up2(x)
        x = self.conv2(x)
        x = self.up3(x)
        feats = self.conv3(x)
        outs = self.conv_sr(feats)
        if fa:
            return feats, outs
        return outs
|
||
|
||
|
||
|
||
# NOTE: the channel counts are hard-coded.
|
||
|
||
class Reducer(nn.Module):
    """1x1 bias-free convolution that changes the channel count.

    With ``bn_relu`` enabled, the projection is followed by batch
    normalisation and a ReLU.

    Args:
        in_channels (int): Channels of the incoming tensor.
        reduce (int): Channels of the produced tensor.
        bn_relu (bool): Whether to apply BatchNorm + ReLU after the conv.
    """

    def __init__(self, in_channels=512, reduce=128, bn_relu=True):
        super().__init__()
        self.bn_relu = bn_relu
        self.conv1 = nn.Conv2d(in_channels, reduce, 1, bias=False)
        # The norm layer only exists when it is actually used.
        if bn_relu:
            self.bn1 = nn.BatchNorm2d(reduce)

    def forward(self, x):
        """Project ``x`` and, if enabled, normalise and rectify it."""
        projected = self.conv1(x)
        if not self.bn_relu:
            return projected
        return F.relu(self.bn1(projected))
|
||
|
||
|
||
|
||
|
||
|
||
from .fusion_module.my_res_pag import RelationAwareFusion_coordatt_res_pag
|
||
from .fusion_module.coord_pag_raf import RelationAwareFusion_coordatt_pag_raf
|
||
from .fusion_module.raf import RelationAwareFusion
|
||
from .my_fusion_module.add import Vssm_cnn_Add1,Vssm_cnn_Add2
|
||
from .my_fusion_module.cat import Vssm_cnn_Cat1,Vssm_cnn_Cat2
|
||
from .my_fusion_module.sp_con_pag import RelationAwareFusion_coordatt_pag_raf as RCPR
|
||
from .CNN_module.RFD import ShallowNet_RFD
|
||
from .CNN_module.convnext import convnext_supertiny
|
||
from .CNN_module.efficientnetv2 import efficientnetv2_s,efficientnetv2_m
|
||
from .CNN_module.segnext import segnext_t
|
||
from .CNN_module.resnet import resnet50,resnet18
|
||
from .CNN_module.model_v2 import MobileNetV2
|
||
from .CNN_module.model_v3 import mobilenet_v3_large
|
||
from .CNN_module.convnext import convnext_tiny
|
||
from .CNN_module.pkinet import PKINet
|
||
from .fusion_module.FAM import FeatureAggregationModule as FAM
|
||
from .fusion_module.FFM import FFM
|
||
from .fusion_module.GFM import Gated_Fusion as GFM
|
||
from .fusion_module.CLM import CrossFusionModule as CFM
|
||
from .fusion_module.CMX import FeatureFusionModule as cmxffm
|
||
#-------------------------------------------------------------#
|
||
#-------------------------------------------------------------#
|
||
@HEADS.register_module()
class ISDHead(BaseCascadeDecodeHead):
    """Cascade decode head fusing a shallow high-resolution branch with deep features.

    The head optionally decomposes the input image with a Laplacian pyramid,
    runs the result through a shallow network (``stdc_net``), and fuses the
    shallow features with the previous head's output in two stages
    (``fuse16`` then ``fuse8``). During training it also returns auxiliary
    logits plus reconstruction / feature-affinity (and optionally
    consistency) losses.

    Args:
        down_ratio: Stored as ``self.down_ratio``; not used inside this class.
        prev_channels: Accepted for config compatibility; unused here.
        reduce (bool): If True, pass the deep feature through a default
            ``Reducer()`` (512 -> 128 channels) before fusion.
        fusion_mode (str): Stored as ``self.fusion_mode``; not used inside
            this class.
        consist (bool): If True, the training forward additionally returns two
            consistency losses.
        model_cls (str): Shallow branch type, ``'mamba'`` (VSSM) or
            ``'ShallowNet'`` (ShallowNet_RFD).
        dims (Sequence[int]): Stage widths passed to ``VSSM``; ``dims[0]`` is
            also the fusion channel base for the ``'mamba'`` branch.
        depths (Sequence[int]): Stage depths passed to ``VSSM``.
        shallow_model_inchan (int): Input channels of the shallow network.
        lap (bool): If True, feed the shallow network with the concatenated
            Laplacian residuals instead of the raw input.
        shallow_pretrain (str | None): Checkpoint path for the shallow
            network. ``None`` keeps the historical defaults (the VSSM path in
            ``_DEFAULT_VSSM_PRETRAIN`` for ``'mamba'``, ``""`` for
            ``'ShallowNet'``).
        **kwargs: Forwarded to ``BaseCascadeDecodeHead``.

    Raises:
        ValueError: If ``model_cls`` is not one of the supported values.
    """

    # Historical default checkpoint of the VSSM shallow branch. It is a
    # machine-specific path; override it with ``shallow_pretrain``.
    _DEFAULT_VSSM_PRETRAIN = (
        "/media/cm/2c5e0c44-80c0-4ab7-b8af-c5a0997b2a7f/zjb/UHR_Model/"
        "checkpoint/vmamba_tiny_e292.pth")

    def __init__(self, down_ratio, prev_channels, reduce=False, fusion_mode='raf', consist=False,
                 model_cls='mamba', dims=(48, 96, 192, 384), depths=(1, 1, 2, 1),
                 shallow_model_inchan=6, lap=True, shallow_pretrain=None,
                 **kwargs):
        super(ISDHead, self).__init__(**kwargs)
        self.down_ratio = down_ratio

        self.sr_decoder = SRDecoder(self.conv_cfg, self.norm_cfg, self.act_cfg,
                                    channels=self.channels, up_lists=[4, 2, 2])

        # Shallow branch configuration. Copies are taken so that the
        # (immutable) defaults and caller-owned sequences are never shared.
        self.model_cls = model_cls
        self.dims = list(dims)
        self.depth = list(depths)
        self.shallow_model_inchan = shallow_model_inchan
        self.lap = lap

        if self.model_cls == "mamba":
            self.raf_channel = self.dims[0]
            self.fuse8 = RelationAwareFusion(self.raf_channel, self.channels, self.conv_cfg,
                                             self.norm_cfg, self.act_cfg, ext=4)
            self.fuse16 = RelationAwareFusion(self.raf_channel, self.channels, self.conv_cfg,
                                              self.norm_cfg, self.act_cfg, ext=8)
            self.stdc_net = VSSM(
                patch_size=4,
                in_chans=self.shallow_model_inchan,
                dims=self.dims,
                depths=self.depth,
                pretrain_model=(self._DEFAULT_VSSM_PRETRAIN
                                if shallow_pretrain is None else shallow_pretrain),
            )
        elif self.model_cls == "ShallowNet":
            self.raf_channel = 128
            # Other fusion modules / backbones imported by this file (FAM,
            # FFM, GFM, CFM, CMX, Vssm_cnn_Add*/Cat*, EfficientNetV2, ...)
            # were tried here in the past; RCPR + ShallowNet_RFD is the
            # active choice.
            self.fuse8 = RCPR(self.raf_channel, self.channels, self.conv_cfg, self.norm_cfg,
                              self.act_cfg, ext=2)
            self.fuse16 = RCPR(self.raf_channel, self.channels, self.conv_cfg, self.norm_cfg,
                               self.act_cfg, ext=4)
            self.stdc_net = ShallowNet_RFD(
                in_channels=self.shallow_model_inchan,
                pretrain_model="" if shallow_pretrain is None else shallow_pretrain)
        else:
            # Without this check the head would be built without stdc_net /
            # fuse8 / fuse16 and fail later with an opaque AttributeError.
            raise ValueError(
                "Unsupported model_cls {!r}; expected 'mamba' or 'ShallowNet'.".format(model_cls))

        self.lap_prymaid_conv = Lap_Pyramid_Conv(num_high=2)
        # The kernel_size keyword is ignored by SegmentationHead.
        self.conv_seg_aux_16 = SegmentationHead(self.conv_cfg, self.norm_cfg, self.act_cfg, self.channels,
                                                self.channels // 2, self.num_classes, kernel_size=1)
        self.conv_seg_aux_8 = SegmentationHead(self.conv_cfg, self.norm_cfg, self.act_cfg, self.channels,
                                               self.channels // 2, self.num_classes, kernel_size=1)
        self.conv_seg = SegmentationHead(self.conv_cfg, self.norm_cfg, self.act_cfg, self.channels,
                                         self.channels // 2, self.num_classes, kernel_size=1)

        self.reduce = Reducer() if reduce else None
        self.fusion_mode = fusion_mode

        # Not used in forward; kept so existing checkpoints keep their keys.
        self.channel_reduce1 = Reducer(768, 128)
        self.consist = consist

        if self.consist:
            # Single-channel projections used by the consistency losses. They
            # are created only when requested so that models built with
            # consist=False keep their parameter set unchanged. The input
            # widths (256/384/512/128) are the ones used by the earlier
            # revision of this head.
            # NOTE(review): they presumably match the 'ShallowNet' branch and
            # a 384-channel ``input_16`` - confirm before using with 'mamba'.
            self.reduce384_1 = Reducer(384, 1)
            self.reduce128_1 = Reducer(128, 1)
            self.reduce256_1 = Reducer(256, 1)
            self.reduce512_1 = Reducer(512, 1)

    def forward(self, inputs, prev_output, input_16, train_flag=True):
        """Forward function.

        Args:
            inputs (Tensor): Input image batch (author's example: 2x3x1224x1224).
            prev_output: Output of the previous head; element ``0`` is used as
                the deep feature (author's example: 2x128x39x39).
            input_16 (Tensor): Additional deep feature, only used by the
                consistency loss.
            train_flag (bool): If False, only the segmentation logits are
                returned.

        Returns:
            Tensor or tuple: ``output`` when ``train_flag`` is False;
            otherwise ``(output, output_aux16, output_aux8, losses_re,
            losses_fa)``, extended with ``loss_consist1, loss_consist2`` when
            ``self.consist`` is True.
        """
        if self.lap:
            pyramid = self.lap_prymaid_conv.pyramid_decom(inputs)
            high_residual_1 = pyramid[0]
            # Bring the coarser residual to the resolution of the finest one.
            high_residual_2 = F.interpolate(pyramid[1], pyramid[0].size()[2:], mode='bilinear',
                                            align_corners=False)
            high_residual_input = torch.cat([high_residual_1, high_residual_2], dim=1)
            # The SR decoder predicts 3 channels, so the reconstruction target
            # is the 3-channel sum of the residuals, not their concatenation
            # (which would not broadcast in the MSE loss).
            recon_target = high_residual_1 + high_residual_2
        else:
            high_residual_input = inputs
            recon_target = inputs

        if self.model_cls == "mamba":
            feature_all = self.stdc_net(high_residual_input)
            shallow_feat8, shallow_feat16 = feature_all[2], feature_all[3]
        else:
            # Author's example shapes: 2x256x153x153 and 2x512x77x77.
            shallow_feat8, shallow_feat16 = self.stdc_net(high_residual_input)

        deep_feat = prev_output[0]
        deep_feat16 = input_16
        if self.reduce is not None:
            deep_feat = self.reduce(deep_feat)

        # Stage 1: fuse the deep feature with the coarser shallow feature.
        _, aux_feat16, fused_feat_16 = self.fuse16(shallow_feat16, deep_feat)
        # Stage 2: fuse the stage-1 result with the finer shallow feature.
        _, aux_feat8, fused_feat_8 = self.fuse8(shallow_feat8, fused_feat_16)
        output = self.cls_seg(fused_feat_8)

        if not train_flag:
            return output

        # NOTE: the aux heads are intentionally fed as in the original code
        # (the "16" head receives aux_feat8 and the "8" head aux_feat16).
        output_aux16 = self.conv_seg_aux_16(aux_feat8)
        output_aux8 = self.conv_seg_aux_8(aux_feat16)
        feats, output_sr = self.sr_decoder(deep_feat, True)

        # Reconstruction loss between the SR decoder output and the target.
        losses_re = self.image_recon_loss(recon_target, output_sr, re_weight=0.1)
        # Affinity loss between the deep feature and the decoded features.
        losses_fa = self.feature_affinity_loss(deep_feat, feats)

        if self.consist:
            loss_consist1 = self.gauss_spatial_consistency_loss(
                self.reduce256_1(shallow_feat8), self.reduce384_1(deep_feat16))
            loss_consist2 = self.gauss_spatial_consistency_loss(
                self.reduce512_1(shallow_feat16), self.reduce128_1(deep_feat))
            return (output, output_aux16, output_aux8, losses_re, losses_fa,
                    loss_consist1, loss_consist2)
        return output, output_aux16, output_aux8, losses_re, losses_fa

    def image_recon_loss(self, img, pred, re_weight=0.5):
        """Weighted MSE between ``pred`` and ``img``.

        ``pred`` is bilinearly resized to the spatial size of ``img`` when the
        sizes differ.

        Returns:
            dict: ``{'recon_losses': Tensor}``.
        """
        loss = dict()
        if pred.size()[2:] != img.size()[2:]:
            pred = F.interpolate(pred, img.size()[2:], mode='bilinear', align_corners=False)
        recon_loss = F.mse_loss(pred, img) * re_weight
        loss['recon_losses'] = recon_loss
        return loss

    def gaussion_similarity(self, tensor1, tensor2, sigma=1.0):
        """Return ``exp(-||t1 - t2||^2 / (2 * sigma^2))`` per sample.

        The squared distance is summed over dims 1-3 with ``keepdim=True``.
        This helper is not called anywhere in this class.
        """
        distance = torch.sum((tensor1 - tensor2) ** 2, dim=[1, 2, 3], keepdim=True)
        similarity = torch.exp(-distance / (2 * sigma ** 2))
        return similarity

    def gauss_spatial_consistency_loss(self, img, pred, re_weight=0.5):
        """Weighted MSE between ``pred`` and ``img`` used as consistency loss.

        ``pred`` is bilinearly resized to the spatial size of ``img`` when the
        sizes differ. Despite the name, no Gaussian weighting is applied.

        Returns:
            dict: ``{'gauss_consistency_losses': Tensor}``.
        """
        loss = dict()
        if pred.size()[2:] != img.size()[2:]:
            pred = F.interpolate(pred, img.size()[2:], mode='bilinear', align_corners=False)
        recon_loss = F.mse_loss(pred, img) * re_weight
        loss['gauss_consistency_losses'] = recon_loss
        return loss

    def feature_affinity_loss(self, seg_feats, sr_feats, fa_weight=1., eps=1e-6):
        """L1 distance between the pixel-affinity matrices of two feature maps.

        Args:
            seg_feats (Tensor): Segmentation-branch features (author's
                example: 2x128x39x39).
            sr_feats (Tensor): Decoded features (author's example:
                2x128x624x624); resized to ``seg_feats``' spatial size when
                the sizes differ.
            fa_weight (float): Multiplier applied to the loss.
            eps (float): Added to the norms to avoid division by zero.

        Returns:
            dict: ``{'fa_loss': Tensor}``.
        """
        if seg_feats.size()[2:] != sr_feats.size()[2:]:
            sr_feats = F.interpolate(sr_feats, seg_feats.size()[2:], mode='bilinear', align_corners=False)
        loss = dict()
        # Flatten the spatial dims: (N, C, H, W) -> (N, C, H*W).
        seg_feats_flatten = torch.flatten(seg_feats, start_dim=2)
        sr_feats_flatten = torch.flatten(sr_feats, start_dim=2)
        # L2 norm of every channel over its flattened spatial axis: (N, C, 1).
        seg_norm = torch.norm(seg_feats_flatten, p=2, dim=2, keepdim=True)
        sr_norm = torch.norm(sr_feats_flatten, p=2, dim=2, keepdim=True)
        seg_feats_flatten = seg_feats_flatten / (seg_norm + eps)
        sr_feats_flatten = sr_feats_flatten / (sr_norm + eps)
        # (N, H*W, C) @ (N, C, H*W): pairwise similarity between positions.
        seg_sim = torch.matmul(seg_feats_flatten.permute(0, 2, 1), seg_feats_flatten)
        sr_sim = torch.matmul(sr_feats_flatten.permute(0, 2, 1), sr_feats_flatten)
        # The SR affinity is detached, so gradients flow only via seg_feats.
        loss['fa_loss'] = F.l1_loss(seg_sim, sr_sim.detach()) * fa_weight
        return loss

    def _stack_batch_gt(self, batch_data_samples):
        """Stack ``gt_sem_seg.data`` of every data sample along a new dim 0."""
        gt_semantic_segs = [
            data_sample.gt_sem_seg.data for data_sample in batch_data_samples
        ]
        return torch.stack(gt_semantic_segs, dim=0)

    def loss_by_feat(self, seg_logits,
                     batch_data_samples) -> dict:
        """Compute segmentation loss.

        Args:
            seg_logits (Tensor): The output from decode head forward function.
            batch_data_samples (List[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `metainfo` and `gt_sem_seg`.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_label = self._stack_batch_gt(batch_data_samples)
        loss = dict()
        # Upsample the logits to the label resolution.
        seg_logits = resize(
            input=seg_logits,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        if self.sampler is not None:
            seg_weight = self.sampler.sample(seg_logits, seg_label)
        else:
            seg_weight = None
        seg_label = seg_label.squeeze(1)

        if not isinstance(self.loss_decode, nn.ModuleList):
            losses_decode = [self.loss_decode]
        else:
            losses_decode = self.loss_decode
        for loss_decode in losses_decode:
            value = loss_decode(
                seg_logits,
                seg_label,
                weight=seg_weight,
                ignore_index=self.ignore_index)
            # Losses sharing a name are accumulated.
            if loss_decode.loss_name not in loss:
                loss[loss_decode.loss_name] = value
            else:
                loss[loss_decode.loss_name] += value

        loss['acc_seg'] = accuracy(
            seg_logits, seg_label, ignore_index=self.ignore_index)
        return loss

    def forward_train(self, inputs, prev_output, input_16, gt_semantic_seg):
        """Run the training forward pass and compute the segmentation losses.

        Returns:
            tuple: ``(losses, losses_aux16, losses_aux8, losses_recon,
            losses_fa)``, extended with ``loss_consist1, loss_consist2`` when
            ``self.consist`` is True.
        """
        outputs = self.forward(inputs, prev_output, input_16)
        seg_logits, seg_logits_aux16, seg_logits_aux8 = outputs[:3]
        losses = self.loss_by_feat(seg_logits, gt_semantic_seg)
        losses_aux16 = self.loss_by_feat(seg_logits_aux16, gt_semantic_seg)
        losses_aux8 = self.loss_by_feat(seg_logits_aux8, gt_semantic_seg)
        # The remaining entries (reconstruction / affinity / consistency
        # losses) are passed through unchanged.
        return (losses, losses_aux16, losses_aux8) + tuple(outputs[3:])

    def forward_test(self, inputs, prev_output, input_16, ):
        """Forward function for testing.

        Args:
            inputs (Tensor): Input passed to :meth:`forward`.
            prev_output: The output of previous decode head.
            input_16 (Tensor): Additional deep feature passed to
                :meth:`forward`.

        Returns:
            Tensor: Output segmentation map.
        """
        return self.forward(inputs, prev_output, input_16, False)