import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from ..builder import HEADS
from .cascade_decode_head import BaseCascadeDecodeHead
from .shallow_head import ShallowNet
from ..losses import accuracy
from mmseg.models.utils.wrappers import resize
from mmseg.models.backbones.mambaunet import VSSM


class SegmentationHead(nn.Module):
    """Conv-BN-ReLU block followed by a 1x1 classifier convolution."""

    def __init__(self, conv_cfg, norm_cfg, act_cfg, in_channels, mid_channels, n_classes, *args, **kwargs):
        super(SegmentationHead, self).__init__()
        self.conv_bn_relu = ConvModule(in_channels, mid_channels, 3,
                                       stride=1,
                                       padding=1,
                                       conv_cfg=conv_cfg,
                                       norm_cfg=norm_cfg,
                                       act_cfg=act_cfg)
        self.conv_out = nn.Conv2d(mid_channels, n_classes, kernel_size=1, bias=True)

    def forward(self, x):
        x = self.conv_bn_relu(x)
        x = self.conv_out(x)
        return x
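
# A minimal usage sketch (mmcv-style cfg dicts assumed; illustrative only,
# not part of the original file):
# seg_head = SegmentationHead(conv_cfg=None, norm_cfg=dict(type='BN'),
#                             act_cfg=dict(type='ReLU'), in_channels=128,
#                             mid_channels=64, n_classes=7)
# logits = seg_head(torch.randn(2, 128, 153, 153))  # -> (2, 7, 153, 153)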


class Lap_Pyramid_Conv(nn.Module):
    """Laplacian pyramid decomposition with a fixed 5x5 Gaussian kernel."""

    def __init__(self, num_high=3, gauss_chl=3):
        super(Lap_Pyramid_Conv, self).__init__()
        self.num_high = num_high
        self.gauss_chl = gauss_chl
        self.kernel = self.gauss_kernel()

    def gauss_kernel(self):
        # 5x5 binomial (Gaussian) kernel, normalized to sum to 1 and
        # replicated per channel for depthwise convolution. The kernel is
        # kept on CPU here and moved to the input's device in conv_gauss.
        kernel = torch.tensor([[1., 4., 6., 4., 1.],
                               [4., 16., 24., 16., 4.],
                               [6., 24., 36., 24., 6.],
                               [4., 16., 24., 16., 4.],
                               [1., 4., 6., 4., 1.]])
        kernel /= 256.
        kernel = kernel.repeat(self.gauss_chl, 1, 1, 1)
        return kernel

    def downsample(self, x):
        return x[:, :, ::2, ::2]

    def upsample(self, x):
        # Zero-insertion upsampling by 2 in both spatial dims, followed by
        # Gaussian smoothing (kernel scaled by 4 to preserve brightness).
        cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3], device=x.device)], dim=3)
        cc = cc.view(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[3])
        cc = cc.permute(0, 1, 3, 2)
        cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2, device=x.device)], dim=3)
        cc = cc.view(x.shape[0], x.shape[1], x.shape[3] * 2, x.shape[2] * 2)
        x_up = cc.permute(0, 1, 3, 2)
        return self.conv_gauss(x_up, 4 * self.kernel)

    def conv_gauss(self, img, kernel):
        img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode='reflect')
        out = torch.nn.functional.conv2d(img, kernel.to(img.device), groups=img.shape[1])
        return out

    def pyramid_decom(self, img):
        # Returns only the high-frequency residuals (the low-frequency base
        # is discarded); pyr[0] is the finest level.
        current = img
        pyr = []
        for _ in range(self.num_high):
            filtered = self.conv_gauss(current, self.kernel)
            down = self.downsample(filtered)
            up = self.upsample(down)
            if up.shape[2] != current.shape[2] or up.shape[3] != current.shape[3]:
                up = nn.functional.interpolate(up, size=(current.shape[2], current.shape[3]))
            diff = current - up
            pyr.append(diff)
            current = down
        return pyr
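
# A minimal sanity-check sketch (input size taken from the forward comments
# further below; illustrative only):
# lap = Lap_Pyramid_Conv(num_high=2)
# pyr = lap.pyramid_decom(torch.randn(2, 3, 1224, 1224))
# [p.shape for p in pyr]  # -> [(2, 3, 1224, 1224), (2, 3, 612, 612)]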


class SRDecoder(nn.Module):
    """Super-resolution decoder: upsamples deep features back to image space."""

    def __init__(self, conv_cfg, norm_cfg, act_cfg, channels=128, up_lists=[2, 2, 2]):
        super(SRDecoder, self).__init__()
        self.conv1 = ConvModule(channels, channels // 2, 3, stride=1, padding=1, conv_cfg=conv_cfg,
                                norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.up1 = nn.Upsample(scale_factor=up_lists[0])
        self.conv2 = ConvModule(channels // 2, channels // 2, 3, stride=1, padding=1, conv_cfg=conv_cfg,
                                norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.up2 = nn.Upsample(scale_factor=up_lists[1])
        self.conv3 = ConvModule(channels // 2, channels, 3, stride=1, padding=1, conv_cfg=conv_cfg,
                                norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.up3 = nn.Upsample(scale_factor=up_lists[2])
        self.conv_sr = SegmentationHead(conv_cfg, norm_cfg, act_cfg, channels, channels // 2, 3, kernel_size=1)

    def forward(self, x, fa=False):  # input x: e.g. (2, 128, 39, 39)
        x = self.up1(x)
        x = self.conv1(x)
        x = self.up2(x)
        x = self.conv2(x)
        x = self.up3(x)
        feats = self.conv3(x)
        outs = self.conv_sr(feats)  # e.g. (2, 3, 624, 624)
        if fa:
            return feats, outs
        else:
            return outs
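
# A minimal shape sketch (cfg dicts assumed, matching the comments above;
# illustrative only):
# sr_decoder = SRDecoder(None, dict(type='BN'), dict(type='ReLU'),
#                        channels=128, up_lists=[4, 2, 2])
# feats, outs = sr_decoder(torch.randn(2, 128, 39, 39), fa=True)
# feats.shape, outs.shape  # -> (2, 128, 624, 624), (2, 3, 624, 624)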

# Note: channel counts below are hard-coded.
class Reducer(nn.Module):
    """Reduce the channel count (typically to 128) with optional BN + ReLU."""

    def __init__(self, in_channels=512, reduce=128, bn_relu=True):
        super(Reducer, self).__init__()
        self.bn_relu = bn_relu
        self.conv1 = nn.Conv2d(in_channels, reduce, 1, bias=False)
        if self.bn_relu:
            self.bn1 = nn.BatchNorm2d(reduce)

    def forward(self, x):
        x = self.conv1(x)
        if self.bn_relu:
            x = self.bn1(x)
            x = F.relu(x)
        return x

from .fusion_module.my_res_pag import RelationAwareFusion_coordatt_res_pag
from .fusion_module.coord_pag_raf import RelationAwareFusion_coordatt_pag_raf
from .fusion_module.raf import RelationAwareFusion
from .my_fusion_module.add import Vssm_cnn_Add1, Vssm_cnn_Add2
from .my_fusion_module.cat import Vssm_cnn_Cat1, Vssm_cnn_Cat2
from .my_fusion_module.sp_con_pag import RelationAwareFusion_coordatt_pag_raf as RCPR
from .CNN_module.RFD import ShallowNet_RFD
from .CNN_module.convnext import convnext_supertiny, convnext_tiny
from .CNN_module.efficientnetv2 import efficientnetv2_s, efficientnetv2_m
from .CNN_module.segnext import segnext_t
from .CNN_module.resnet import resnet50, resnet18
from .CNN_module.model_v2 import MobileNetV2
from .CNN_module.model_v3 import mobilenet_v3_large
from .CNN_module.pkinet import PKINet
from .fusion_module.FAM import FeatureAggregationModule as FAM
from .fusion_module.FFM import FFM
from .fusion_module.GFM import Gated_Fusion as GFM
from .fusion_module.CLM import CrossFusionModule as CFM
from .fusion_module.CMX import FeatureFusionModule as cmxffm
# -------------------------------------------------------------#


@HEADS.register_module()
class ISDHead(BaseCascadeDecodeHead):

    def __init__(self, down_ratio, prev_channels, reduce=False, fusion_mode='raf', consist=False,
                 model_cls='mamba', dims=[48, 96, 192, 384], depths=[1, 1, 2, 1],
                 shallow_model_inchan=6, lap=True,
                 **kwargs):
        super(ISDHead, self).__init__(**kwargs)
        self.down_ratio = down_ratio
        self.sr_decoder = SRDecoder(self.conv_cfg, self.norm_cfg, self.act_cfg,
                                    channels=self.channels, up_lists=[4, 2, 2])
        # shallow branch
        self.model_cls = model_cls
        self.dims = dims
        self.depth = depths
        self.shallow_model_inchan = shallow_model_inchan
        self.lap = lap
        if self.model_cls == "mamba":
            self.raf_channel = dims[0]
            self.fuse8 = RelationAwareFusion(self.raf_channel, self.channels, self.conv_cfg,
                                             self.norm_cfg, self.act_cfg, ext=4)
            self.fuse16 = RelationAwareFusion(self.raf_channel, self.channels, self.conv_cfg,
                                              self.norm_cfg, self.act_cfg, ext=8)
            self.stdc_net = VSSM(
                patch_size=4,
                in_chans=self.shallow_model_inchan,
                dims=self.dims,  # e.g. [96, 192, 384, 768]
                depths=self.depth,
                pretrain_model="/media/cm/2c5e0c44-80c0-4ab7-b8af-c5a0997b2a7f/zjb/UHR_Model/checkpoint/vmamba_tiny_e292.pth")
        elif self.model_cls == "ShallowNet":
            self.raf_channel = 128
            self.fuse8 = RCPR(self.raf_channel, self.channels, self.conv_cfg, self.norm_cfg,
                              self.act_cfg, ext=2)
            self.fuse16 = RCPR(self.raf_channel, self.channels, self.conv_cfg, self.norm_cfg,
                               self.act_cfg, ext=4)
            # Alternative fusion modules tried during development:
            # self.fuse8 = RelationAwareFusion(self.raf_channel, self.channels, self.conv_cfg,
            #                                  self.norm_cfg, self.act_cfg, ext=2)
            # self.fuse16 = RelationAwareFusion(self.raf_channel, self.channels, self.conv_cfg,
            #                                   self.norm_cfg, self.act_cfg, ext=4)
            # self.fuse8 = FAM(256, 128)
            # self.fuse16 = FAM(512, 128)
            # self.fuse8 = FFM(256, 128)
            # self.fuse16 = FFM(512, 128)
            # self.fuse8 = GFM(256, 256)
            # self.fuse16 = GFM(512, 256)
            # self.fuse8 = CFM(256, 128)
            # self.fuse16 = CFM(512, 128)
            # self.fuse8 = cmxffm(dim=128, reduction=1, num_heads=8, smooth_channels=256)
            # self.fuse16 = cmxffm(dim=128, reduction=1, num_heads=8, smooth_channels=512)
            # self.fuse8 = Vssm_cnn_Add2(self.raf_channel, self.channels, self.conv_cfg, self.norm_cfg, self.act_cfg)
            # self.fuse16 = Vssm_cnn_Add1(self.raf_channel, self.channels, self.conv_cfg, self.norm_cfg, self.act_cfg)
            # self.fuse8 = Vssm_cnn_Cat2(self.raf_channel, self.channels, self.conv_cfg, self.norm_cfg, self.act_cfg)
            # self.fuse16 = Vssm_cnn_Cat1(self.raf_channel, self.channels, self.conv_cfg, self.norm_cfg, self.act_cfg)
            self.stdc_net = ShallowNet_RFD(in_channels=self.shallow_model_inchan,
                                           # pretrain_model="/media/cm/2c5e0c44-80c0-4ab7-b8af-c5a0997b2a7f/zjb/UHR_Model/checkpoint/STDCNet813M_73.91.tar")
                                           pretrain_model="")
            # Alternative shallow backbones tried during development
            # (trailing numbers are notes carried over from the original):
            # self.stdc_net = ShallowNet(in_channels=self.shallow_model_inchan, pretrain_model="")  # 64.41
            # self.stdc_net = efficientnetv2_m()  # 0.625 1.375
            # self.stdc_net = efficientnetv2_s()  # 0.5 1.25
            # self.stdc_net = segnext_t()  # 0.5 1.25
            # self.stdc_net = resnet50()  # 4 8
            # self.stdc_net = resnet18()  # 1 2
            # self.stdc_net = MobileNetV2()  # 0.25 0.5
            # self.stdc_net = mobilenet_v3_large()  # 0.1875 0.625
            # self.stdc_net = convnext_tiny()  # 1.5 3
            # self.stdc_net = PKINet('SST')  # 0.5 1
            # self.stdc_net = PKINet('SS')  # 1 2
        self.lap_prymaid_conv = Lap_Pyramid_Conv(num_high=2)
        self.conv_seg_aux_16 = SegmentationHead(self.conv_cfg, self.norm_cfg, self.act_cfg, self.channels,
                                                self.channels // 2, self.num_classes, kernel_size=1)
        self.conv_seg_aux_8 = SegmentationHead(self.conv_cfg, self.norm_cfg, self.act_cfg, self.channels,
                                               self.channels // 2, self.num_classes, kernel_size=1)
        self.conv_seg = SegmentationHead(self.conv_cfg, self.norm_cfg, self.act_cfg, self.channels,
                                         self.channels // 2, self.num_classes, kernel_size=1)
        self.reduce = Reducer() if reduce else None
        self.fusion_mode = fusion_mode
        self.channel_reduce1 = Reducer(768, 128)
        self.consist = consist
        if self.consist:
            # Single-channel reducers used by the consistency losses in
            # forward(); restored from an earlier version of this head, since
            # the consist=True branch references them.
            self.reduce384_1 = Reducer(384, 1)
            self.reduce128_1 = Reducer(128, 1)
            self.reduce256_1 = Reducer(256, 1)
            self.reduce512_1 = Reducer(512, 1)

    def forward(self, inputs, prev_output, input_16, train_flag=True):
        """Forward function."""
        # inputs: e.g. (2, 3, 1224, 1224)
        if self.lap:
            prymaid_results = self.lap_prymaid_conv.pyramid_decom(inputs)  # (2, 3, 1224, 1224); (2, 3, 612, 612)
            high_residual_1 = prymaid_results[0]  # (2, 3, 1224, 1224)
            high_residual_2 = F.interpolate(prymaid_results[1], prymaid_results[0].size()[2:], mode='bilinear',
                                            align_corners=False)  # (2, 3, 1224, 1224)
            high_residual_input = torch.cat([high_residual_1, high_residual_2], dim=1)
        else:
            high_residual_input = inputs
        if self.model_cls == "mamba":
            feature_all = self.stdc_net(high_residual_input)
            shallow_feat8, shallow_feat16 = feature_all[2], feature_all[3]
        else:
            shallow_feat8, shallow_feat16 = self.stdc_net(high_residual_input)  # (2, 256, 153, 153); (2, 512, 77, 77)
        deep_feat = prev_output[0]  # (2, 128, 39, 39): pre-classification features from the deeplabv3 branch
        deep_feat16 = input_16
        if self.reduce is not None:
            deep_feat = self.reduce(deep_feat)
        # stage 1: first fusion of the deeplabv3 features with the 1/16 shallow features
        _, aux_feat16, fused_feat_16 = self.fuse16(shallow_feat16, deep_feat)
        # stage 2: second fusion of the stage-1 result with the 1/8 shallow features
        _, aux_feat8, fused_feat_8 = self.fuse8(shallow_feat8, fused_feat_16)
        output = self.cls_seg(fused_feat_8)  # (2, 7, 153, 153)
        if train_flag:  # structural-distillation auxiliary outputs and losses
            output_aux16 = self.conv_seg_aux_16(aux_feat8)  # aux_feat8: upsampled stage-2 features -> (2, 7, 153, 153)
            output_aux8 = self.conv_seg_aux_8(aux_feat16)  # aux_feat16: upsampled relation-aware fusion features -> (2, 7, 77, 77)
            feats, output_sr = self.sr_decoder(deep_feat, True)  # (2, 128, 624, 624); (2, 3, 624, 624)
            # Reconstruction loss between the Laplacian-pyramid input and the
            # image decoded from the deeplab features by the SR decoder.
            losses_re = self.image_recon_loss(high_residual_input, output_sr, re_weight=0.1)
            if self.consist:  # consistency constraint between branch features
                loss_consist1 = self.gauss_spatial_consistency_loss(self.reduce256_1(shallow_feat8),
                                                                    self.reduce384_1(deep_feat16))
                loss_consist2 = self.gauss_spatial_consistency_loss(self.reduce512_1(shallow_feat16),
                                                                    self.reduce128_1(deep_feat))
                # Affinity loss between the SR-decoded features and the deeplab features.
                losses_fa = self.feature_affinity_loss(deep_feat, feats)
                return output, output_aux16, output_aux8, losses_re, losses_fa, loss_consist1, loss_consist2
            else:
                losses_fa = self.feature_affinity_loss(deep_feat, feats)
                return output, output_aux16, output_aux8, losses_re, losses_fa
        else:
            return output

    def image_recon_loss(self, img, pred, re_weight=0.5):
        loss = dict()
        if pred.size()[2:] != img.size()[2:]:
            pred = F.interpolate(pred, img.size()[2:], mode='bilinear', align_corners=False)
        recon_loss = F.mse_loss(pred, img) * re_weight  # mean squared error
        loss['recon_losses'] = recon_loss
        return loss

    def gaussian_similarity(self, tensor1, tensor2, sigma=1.0):
        distance = torch.sum((tensor1 - tensor2) ** 2, dim=[1, 2, 3], keepdim=True)
        similarity = torch.exp(-distance / (2 * sigma ** 2))
        return similarity

    def gauss_spatial_consistency_loss(self, img, pred, re_weight=0.5):
        loss = dict()
        if pred.size()[2:] != img.size()[2:]:
            pred = F.interpolate(pred, img.size()[2:], mode='bilinear', align_corners=False)
        # similarity = self.gaussian_similarity(img, pred)
        recon_loss = F.mse_loss(pred, img) * re_weight  # mean squared error
        loss['gauss_consistency_losses'] = recon_loss
        return loss

    # Affinity (self-similarity) regression loss between the deeplab features
    # and the SR-decoder features.
    def feature_affinity_loss(self, seg_feats, sr_feats, fa_weight=1., eps=1e-6):
        # seg_feats: deeplab features, e.g. (2, 128, 39, 39);
        # sr_feats: SR-decoder features, e.g. (2, 128, 624, 624).
        if seg_feats.size()[2:] != sr_feats.size()[2:]:
            # Downsample the SR features back to the deeplab resolution.
            sr_feats = F.interpolate(sr_feats, seg_feats.size()[2:], mode='bilinear', align_corners=False)
        loss = dict()
        # flatten the spatial dims: (2, 128, 1521)
        seg_feats_flatten = torch.flatten(seg_feats, start_dim=2)
        sr_feats_flatten = torch.flatten(sr_feats, start_dim=2)
        # L2-normalize each channel over the spatial dim: norms are (2, 128, 1)
        seg_norm = torch.norm(seg_feats_flatten, p=2, dim=2, keepdim=True)
        sr_norm = torch.norm(sr_feats_flatten, p=2, dim=2, keepdim=True)
        seg_feats_flatten = seg_feats_flatten / (seg_norm + eps)
        sr_feats_flatten = sr_feats_flatten / (sr_norm + eps)
        # Pixel-to-pixel similarity via dot products over channels, giving an
        # attention-like (HW x HW) affinity matrix per image.
        seg_sim = torch.matmul(seg_feats_flatten.permute(0, 2, 1), seg_feats_flatten)
        sr_sim = torch.matmul(sr_feats_flatten.permute(0, 2, 1), sr_feats_flatten)
        # L1 loss between the affinity matrices; sr_sim is detached so no
        # gradient flows into the SR branch here.
        loss['fa_loss'] = F.l1_loss(seg_sim, sr_sim.detach()) * fa_weight
        return loss
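
    # A minimal shape sketch of the affinity computation above (random
    # tensors, illustrative only; not from the original file):
    # a = torch.randn(2, 128, 39, 39)
    # f = torch.flatten(a, start_dim=2)                  # (2, 128, 1521)
    # f = f / (f.norm(p=2, dim=2, keepdim=True) + 1e-6)  # per-channel L2 norm
    # sim = f.permute(0, 2, 1) @ f                       # (2, 1521, 1521)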

    def _stack_batch_gt(self, batch_data_samples):
        gt_semantic_segs = [
            data_sample.gt_sem_seg.data for data_sample in batch_data_samples
        ]
        return torch.stack(gt_semantic_segs, dim=0)

    def loss_by_feat(self, seg_logits, batch_data_samples) -> dict:
        """Compute segmentation loss.

        Args:
            seg_logits (Tensor): The output from decode head forward function.
            batch_data_samples (List[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `metainfo` and `gt_sem_seg`.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_label = self._stack_batch_gt(batch_data_samples)  # stacked into a (bs, 1, H, W) label tensor
        loss = dict()
        seg_logits = resize(
            input=seg_logits,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)  # upsample the head output to label size
        if self.sampler is not None:
            seg_weight = self.sampler.sample(seg_logits, seg_label)
        else:
            seg_weight = None
        seg_label = seg_label.squeeze(1)  # drop the channel dim
        if not isinstance(self.loss_decode, nn.ModuleList):
            losses_decode = [self.loss_decode]
        else:
            losses_decode = self.loss_decode
        for loss_decode in losses_decode:  # apply each configured loss
            if loss_decode.loss_name not in loss:
                loss[loss_decode.loss_name] = loss_decode(
                    seg_logits,
                    seg_label,
                    weight=seg_weight,
                    ignore_index=self.ignore_index)
            else:
                loss[loss_decode.loss_name] += loss_decode(
                    seg_logits,
                    seg_label,
                    weight=seg_weight,
                    ignore_index=self.ignore_index)
        loss['acc_seg'] = accuracy(
            seg_logits, seg_label, ignore_index=self.ignore_index)
        return loss

    def forward_train(self, inputs, prev_output, input_16, gt_semantic_seg):
        if self.consist:
            (seg_logits, seg_logits_aux16, seg_logits_aux8, losses_recon,
             losses_fa, loss_consist1, loss_consist2) = self.forward(inputs, prev_output, input_16)
            losses = self.loss_by_feat(seg_logits, gt_semantic_seg)
            losses_aux16 = self.loss_by_feat(seg_logits_aux16, gt_semantic_seg)
            losses_aux8 = self.loss_by_feat(seg_logits_aux8, gt_semantic_seg)
            return losses, losses_aux16, losses_aux8, losses_recon, losses_fa, loss_consist1, loss_consist2
        else:
            seg_logits, seg_logits_aux16, seg_logits_aux8, losses_recon, losses_fa = self.forward(
                inputs, prev_output, input_16)
            losses = self.loss_by_feat(seg_logits, gt_semantic_seg)
            losses_aux16 = self.loss_by_feat(seg_logits_aux16, gt_semantic_seg)
            losses_aux8 = self.loss_by_feat(seg_logits_aux8, gt_semantic_seg)
            return losses, losses_aux16, losses_aux8, losses_recon, losses_fa

    def forward_test(self, inputs, prev_output, input_16):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            input_16 (Tensor): 1/16-scale features from the previous stage.

        Returns:
            Tensor: Output segmentation map.
        """
        return self.forward(inputs, prev_output, input_16, False)
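
# A hedged construction sketch (commented out; BaseCascadeDecodeHead keyword
# arguments such as in_channels/channels/num_classes are assumed from mmseg
# and may differ across versions; values below are illustrative only):
# head = ISDHead(down_ratio=8, prev_channels=128, model_cls='ShallowNet',
#                in_channels=128, channels=128, num_classes=7)
# out = head(inputs, prev_output=[deep_feat], input_16=deep_feat16,
#            train_flag=False)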