import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from ..builder import HEADS
from .cascade_decode_head import BaseCascadeDecodeHead
from .shallow_head import ShallowNet
from ..losses import accuracy
from mmseg.models.utils.wrappers import resize
from mmseg.models.backbones.mambaunet import VSSM
class SegmentationHead(nn.Module):
def __init__(self, conv_cfg, norm_cfg, act_cfg, in_channels, mid_channels, n_classes, *args, **kwargs):
super(SegmentationHead, self).__init__()
self.conv_bn_relu = ConvModule(in_channels, mid_channels, 3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.conv_out = nn.Conv2d(mid_channels, n_classes, kernel_size=1, bias=True)
def forward(self, x):
x = self.conv_bn_relu(x)
x = self.conv_out(x)
return x
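# A SegmentationHead is a 3x3 ConvModule (conv-norm-act) followed by a 1x1
# projection: (B, in_channels, H, W) -> (B, n_classes, H, W). Note that extra
# keyword arguments such as kernel_size=1 passed by callers below are absorbed
# by *args/**kwargs and have no effect.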
class Lap_Pyramid_Conv(nn.Module):
def __init__(self, num_high=3, gauss_chl=3):
super(Lap_Pyramid_Conv, self).__init__()
self.num_high = num_high
self.gauss_chl = gauss_chl
self.kernel = self.gauss_kernel()
    def gauss_kernel(self):
        # Normalized 5x5 binomial (Gaussian) kernel, replicated per channel for
        # depthwise convolution. It is created on CPU and moved to the input's
        # device in conv_gauss, so the module does not assume CUDA.
        kernel = torch.tensor([[1., 4., 6., 4., 1.],
                               [4., 16., 24., 16., 4.],
                               [6., 24., 36., 24., 6.],
                               [4., 16., 24., 16., 4.],
                               [1., 4., 6., 4., 1.]])
        kernel /= 256.
        kernel = kernel.repeat(self.gauss_chl, 1, 1, 1)
        return kernel
def downsample(self, x):
return x[:, :, ::2, ::2]
    def upsample(self, x):
        # Zero-insertion upsampling: interleave zeros along both spatial axes,
        # then blur with 4x the Gaussian kernel to fill in the holes.
        cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3], device=x.device)], dim=3)
        cc = cc.view(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[3])
        cc = cc.permute(0, 1, 3, 2)
        cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2, device=x.device)], dim=3)
        cc = cc.view(x.shape[0], x.shape[1], x.shape[3] * 2, x.shape[2] * 2)
        x_up = cc.permute(0, 1, 3, 2)
        return self.conv_gauss(x_up, 4 * self.kernel)
    def conv_gauss(self, img, kernel):
        # Depthwise Gaussian blur with reflect padding; move the kernel to the
        # input's device so CPU and GPU inputs both work.
        kernel = kernel.to(img.device)
        img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode='reflect')
        out = torch.nn.functional.conv2d(img, kernel, groups=img.shape[1])
        return out
    def pyramid_decom(self, img):
        # Laplacian decomposition: at each level, store the high-frequency
        # residual (current - blur-down-up) and recurse on the downsampled image.
        current = img
        pyr = []
        for _ in range(self.num_high):
            filtered = self.conv_gauss(current, self.kernel)
            down = self.downsample(filtered)
            up = self.upsample(down)
            if up.shape[2] != current.shape[2] or up.shape[3] != current.shape[3]:
                up = nn.functional.interpolate(up, size=(current.shape[2], current.shape[3]))
            diff = current - up
            pyr.append(diff)
            current = down
        return pyr
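# Usage sketch (shapes taken from the annotations in ISDHead.forward below;
# illustrative only):
#   lap = Lap_Pyramid_Conv(num_high=2)
#   pyr = lap.pyramid_decom(img)  # img: (B, 3, 1224, 1224)
#   pyr[0]  # (B, 3, 1224, 1224) finest high-frequency residual
#   pyr[1]  # (B, 3, 612, 612) residual at the next level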
class SRDecoder(nn.Module):
# super resolution decoder
def __init__(self, conv_cfg, norm_cfg, act_cfg, channels=128, up_lists=[2, 2, 2]):
super(SRDecoder, self).__init__()
self.conv1 = ConvModule(channels, channels // 2, 3, stride=1, padding=1, conv_cfg=conv_cfg,
norm_cfg=norm_cfg, act_cfg=act_cfg)
self.up1 = nn.Upsample(scale_factor=up_lists[0])
self.conv2 = ConvModule(channels // 2, channels // 2, 3, stride=1, padding=1, conv_cfg=conv_cfg,
norm_cfg=norm_cfg, act_cfg=act_cfg)
self.up2 = nn.Upsample(scale_factor=up_lists[1])
self.conv3 = ConvModule(channels // 2, channels, 3, stride=1, padding=1, conv_cfg=conv_cfg,
norm_cfg=norm_cfg, act_cfg=act_cfg)
self.up3 = nn.Upsample(scale_factor=up_lists[2])
self.conv_sr = SegmentationHead(conv_cfg, norm_cfg, act_cfg, channels, channels // 2, 3, kernel_size=1)
    def forward(self, x, fa=False):  # input x: (2, 128, 39, 39)
        x = self.up1(x)
        x = self.conv1(x)
        x = self.up2(x)
        x = self.conv2(x)
        x = self.up3(x)
        feats = self.conv3(x)
        outs = self.conv_sr(feats)  # (2, 3, 624, 624)
        if fa:
            return feats, outs
        else:
            return outs
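# With up_lists=[4, 2, 2] the SRDecoder upsamples by a factor of 16 overall,
# e.g. (B, 128, 39, 39) features -> (B, 3, 624, 624) reconstruction, matching
# the shape annotations in forward().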
# Channel widths below are hard-coded.
class Reducer(nn.Module):
# Reduce channel (typically to 128)
def __init__(self, in_channels=512, reduce=128, bn_relu=True):
super(Reducer, self).__init__()
self.bn_relu = bn_relu
self.conv1 = nn.Conv2d(in_channels, reduce, 1, bias=False)
if self.bn_relu:
self.bn1 = nn.BatchNorm2d(reduce)
def forward(self, x):
x = self.conv1(x)
if self.bn_relu:
x = self.bn1(x)
x = F.relu(x)
return x
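# Example: Reducer(512, 128) is a 1x1 conv (plus BN+ReLU by default) mapping
# (B, 512, H, W) -> (B, 128, H, W); Reducer(384, 1) collapses a feature map to
# a single channel, as used by the spatial consistency loss below.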
from .fusion_module.my_res_pag import RelationAwareFusion_coordatt_res_pag
from .fusion_module.coord_pag_raf import RelationAwareFusion_coordatt_pag_raf
from .fusion_module.raf import RelationAwareFusion
from .my_fusion_module.add import Vssm_cnn_Add1, Vssm_cnn_Add2
from .my_fusion_module.cat import Vssm_cnn_Cat1, Vssm_cnn_Cat2
from .my_fusion_module.sp_con_pag import RelationAwareFusion_coordatt_pag_raf as RCPR
from .CNN_module.RFD import ShallowNet_RFD
from .CNN_module.convnext import convnext_supertiny, convnext_tiny
from .CNN_module.efficientnetv2 import efficientnetv2_s, efficientnetv2_m
from .CNN_module.segnext import segnext_t
from .CNN_module.resnet import resnet50, resnet18
from .CNN_module.model_v2 import MobileNetV2
from .CNN_module.model_v3 import mobilenet_v3_large
from .CNN_module.pkinet import PKINet
from .fusion_module.FAM import FeatureAggregationModule as FAM
from .fusion_module.FFM import FFM
from .fusion_module.GFM import Gated_Fusion as GFM
from .fusion_module.CLM import CrossFusionModule as CFM
from .fusion_module.CMX import FeatureFusionModule as cmxffm
#-------------------------------------------------------------#
@HEADS.register_module()
class ISDHead(BaseCascadeDecodeHead):
    def __init__(self, down_ratio, prev_channels, reduce=False, fusion_mode='raf', consist=False,
                 model_cls='mamba', dims=[48, 96, 192, 384], depths=[1, 1, 2, 1],
                 shallow_model_inchan=6, lap=True,
                 **kwargs):
        super(ISDHead, self).__init__(**kwargs)
        self.down_ratio = down_ratio
        self.sr_decoder = SRDecoder(self.conv_cfg, self.norm_cfg, self.act_cfg,
                                    channels=self.channels, up_lists=[4, 2, 2])
        # shallow branch
        self.model_cls = model_cls
        self.dims = dims
        self.depth = depths
        self.shallow_model_inchan = shallow_model_inchan
        self.lap = lap
        if self.model_cls == "mamba":
            self.raf_channel = dims[0]
            self.fuse8 = RelationAwareFusion(self.raf_channel, self.channels, self.conv_cfg, self.norm_cfg, self.act_cfg, ext=4)
            self.fuse16 = RelationAwareFusion(self.raf_channel, self.channels, self.conv_cfg, self.norm_cfg, self.act_cfg, ext=8)
            self.stdc_net = VSSM(
                patch_size=4,
                in_chans=self.shallow_model_inchan,
                dims=self.dims,  # e.g. [96, 192, 384, 768]
                depths=self.depth,
                pretrain_model="/media/cm/2c5e0c44-80c0-4ab7-b8af-c5a0997b2a7f/zjb/UHR_Model/checkpoint/vmamba_tiny_e292.pth")
        elif self.model_cls == "ShallowNet":
            self.raf_channel = 128
            self.fuse8 = RCPR(self.raf_channel, self.channels, self.conv_cfg, self.norm_cfg,
                              self.act_cfg, ext=2)
            self.fuse16 = RCPR(self.raf_channel, self.channels, self.conv_cfg, self.norm_cfg,
                               self.act_cfg, ext=4)
            # Alternative fusion modules kept for reference:
            # self.fuse8 = RelationAwareFusion(self.raf_channel, self.channels, self.conv_cfg, self.norm_cfg,
            #                                  self.act_cfg, ext=2)
            # self.fuse16 = RelationAwareFusion(self.raf_channel, self.channels, self.conv_cfg, self.norm_cfg,
            #                                   self.act_cfg, ext=4)
            # self.fuse8 = FAM(256, 128);  self.fuse16 = FAM(512, 128)
            # self.fuse8 = FFM(256, 128);  self.fuse16 = FFM(512, 128)
            # self.fuse8 = GFM(256, 256);  self.fuse16 = GFM(512, 256)
            # self.fuse8 = CFM(256, 128);  self.fuse16 = CFM(512, 128)
            # self.fuse8 = cmxffm(dim=128, reduction=1, num_heads=8, smooth_channels=256)
            # self.fuse16 = cmxffm(dim=128, reduction=1, num_heads=8, smooth_channels=512)
            # self.fuse8 = Vssm_cnn_Add2(self.raf_channel, self.channels, self.conv_cfg, self.norm_cfg, self.act_cfg)
            # self.fuse16 = Vssm_cnn_Add1(self.raf_channel, self.channels, self.conv_cfg, self.norm_cfg, self.act_cfg)
            # self.fuse8 = Vssm_cnn_Cat2(self.raf_channel, self.channels, self.conv_cfg, self.norm_cfg, self.act_cfg)
            # self.fuse16 = Vssm_cnn_Cat1(self.raf_channel, self.channels, self.conv_cfg, self.norm_cfg, self.act_cfg)
            self.stdc_net = ShallowNet_RFD(in_channels=self.shallow_model_inchan,
                                           # pretrain_model="/media/cm/2c5e0c44-80c0-4ab7-b8af-c5a0997b2a7f/zjb/UHR_Model/checkpoint/STDCNet813M_73.91.tar"
                                           pretrain_model="")
            # Alternative shallow backbones kept for reference:
            # self.stdc_net = ShallowNet(in_channels=self.shallow_model_inchan, pretrain_model="")  # 64.41
            # self.stdc_net = efficientnetv2_m()    # 0.625 1.375
            # self.stdc_net = efficientnetv2_s()    # 0.5 1.25
            # self.stdc_net = segnext_t()           # 0.5 1.25
            # self.stdc_net = resnet50()            # 4 8
            # self.stdc_net = resnet18()            # 1 2
            # self.stdc_net = MobileNetV2()         # 0.25 0.5
            # self.stdc_net = mobilenet_v3_large()  # 0.1875 0.625
            # self.stdc_net = convnext_tiny()       # 1.5 3
            # self.stdc_net = PKINet('SST')         # 0.5 1
            # self.stdc_net = PKINet('SS')          # 1 2
        self.lap_pyramid_conv = Lap_Pyramid_Conv(num_high=2)
        self.conv_seg_aux_16 = SegmentationHead(self.conv_cfg, self.norm_cfg, self.act_cfg, self.channels,
                                                self.channels // 2, self.num_classes, kernel_size=1)
        self.conv_seg_aux_8 = SegmentationHead(self.conv_cfg, self.norm_cfg, self.act_cfg, self.channels,
                                               self.channels // 2, self.num_classes, kernel_size=1)
        self.conv_seg = SegmentationHead(self.conv_cfg, self.norm_cfg, self.act_cfg, self.channels,
                                         self.channels // 2, self.num_classes, kernel_size=1)
        self.reduce = Reducer() if reduce else None
        self.fusion_mode = fusion_mode
        self.channel_reduce1 = Reducer(768, 128)
        self.consist = consist
        if self.consist:
            # Single-channel reducers referenced by the spatial consistency
            # loss in forward() when consist=True.
            self.reduce384_1 = Reducer(384, 1)
            self.reduce128_1 = Reducer(128, 1)
            self.reduce256_1 = Reducer(256, 1)
            self.reduce512_1 = Reducer(512, 1)
    def forward(self, inputs, prev_output, input_16, train_flag=True):
        """Forward function."""
        # inputs: (2, 3, 1224, 1224)
        if self.lap:
            pyramid_results = self.lap_pyramid_conv.pyramid_decom(inputs)  # (2, 3, 1224, 1224); (2, 3, 612, 612)
            high_residual_1 = pyramid_results[0]  # (2, 3, 1224, 1224)
            high_residual_2 = F.interpolate(pyramid_results[1], pyramid_results[0].size()[2:], mode='bilinear',
                                            align_corners=False)  # (2, 3, 1224, 1224)
            high_residual_input = torch.cat([high_residual_1, high_residual_2], dim=1)
        else:
            high_residual_input = inputs
        if self.model_cls == "mamba":
            feature_all = self.stdc_net(high_residual_input)
            shallow_feat8, shallow_feat16 = feature_all[2], feature_all[3]
        else:
            shallow_feat8, shallow_feat16 = self.stdc_net(high_residual_input)  # (2, 256, 153, 153); (2, 512, 77, 77)
        deep_feat = prev_output[0]  # (2, 128, 39, 39), pre-classification features from the deeplabv3 branch
        deep_feat16 = input_16
        # ----------------------------------------------------#
        if self.reduce is not None:
            deep_feat = self.reduce(deep_feat)
        # stage 1: first fusion of the deeplabv3 features with the 1/16 shallow features
        _, aux_feat16, fused_feat_16 = self.fuse16(shallow_feat16, deep_feat)
        # stage 2: second fusion of the stage-1 result with the 1/8 shallow features
        _, aux_feat8, fused_feat_8 = self.fuse8(shallow_feat8, fused_feat_16)
        output = self.cls_seg(fused_feat_8)  # (2, 7, 153, 153)
        if train_flag:
            # structure distillation losses
            output_aux16 = self.conv_seg_aux_16(aux_feat8)  # aux_feat8: upsampled fused feature, (2, 7, 153, 153)
            output_aux8 = self.conv_seg_aux_8(aux_feat16)   # aux_feat16: relation-aware fused feature, (2, 7, 77, 77)
            feats, output_sr = self.sr_decoder(deep_feat, True)  # (2, 128, 624, 624); (2, 3, 624, 624)
            # reconstruction loss between the Laplacian high-frequency input
            # and the decoding of the deeplab-encoded features
            losses_re = self.image_recon_loss(high_residual_input, output_sr, re_weight=0.1)
            if self.consist:
                # consistency constraints between shallow and deep features
                loss_consist1 = self.gauss_spatial_consistency_loss(self.reduce256_1(shallow_feat8), self.reduce384_1(deep_feat16))
                loss_consist2 = self.gauss_spatial_consistency_loss(self.reduce512_1(shallow_feat16), self.reduce128_1(deep_feat))
                # affinity loss between the SR-decoded features and the deeplab features
                losses_fa = self.feature_affinity_loss(deep_feat, feats)
                return output, output_aux16, output_aux8, losses_re, losses_fa, loss_consist1, loss_consist2
            else:
                losses_fa = self.feature_affinity_loss(deep_feat, feats)
                return output, output_aux16, output_aux8, losses_re, losses_fa
        else:
            return output
    def image_recon_loss(self, img, pred, re_weight=0.5):
        # Mean-squared-error reconstruction loss between the SR decoder output
        # and the high-frequency input, after matching spatial sizes.
        loss = dict()
        if pred.size()[2:] != img.size()[2:]:
            pred = F.interpolate(pred, img.size()[2:], mode='bilinear', align_corners=False)
        recon_loss = F.mse_loss(pred, img) * re_weight
        loss['recon_losses'] = recon_loss
        return loss
    def gaussian_similarity(self, tensor1, tensor2, sigma=1.0):
        # RBF (Gaussian) similarity between two tensors; currently unused (see
        # the commented call in gauss_spatial_consistency_loss).
        distance = torch.sum((tensor1 - tensor2) ** 2, dim=[1, 2, 3], keepdim=True)
        similarity = torch.exp(-distance / (2 * sigma ** 2))
        return similarity
    def gauss_spatial_consistency_loss(self, img, pred, re_weight=0.5):
        # Spatial consistency between single-channel projections of shallow and
        # deep features; currently a weighted MSE (the Gaussian similarity path
        # is disabled).
        loss = dict()
        if pred.size()[2:] != img.size()[2:]:
            pred = F.interpolate(pred, img.size()[2:], mode='bilinear', align_corners=False)
        # similarity = self.gaussian_similarity(img, pred)
        recon_loss = F.mse_loss(pred, img) * re_weight
        loss['gauss_consistency_losses'] = recon_loss
        return loss
    # Feature-affinity loss: a regression-style loss that compares the pairwise
    # pixel-similarity structure of the two feature maps.
    def feature_affinity_loss(self, seg_feats, sr_feats, fa_weight=1., eps=1e-6):
        # sr_feats: upsampled SR-decoder features, e.g. (2, 128, 624, 624);
        # seg_feats: deeplab features, e.g. (2, 128, 39, 39).
        if seg_feats.size()[2:] != sr_feats.size()[2:]:
            # resample the SR features back down to the segmentation resolution
            sr_feats = F.interpolate(sr_feats, seg_feats.size()[2:], mode='bilinear', align_corners=False)
        loss = dict()
        # flatten spatial dims: (B, C, HW)
        seg_feats_flatten = torch.flatten(seg_feats, start_dim=2)  # (2, 128, 1521)
        sr_feats_flatten = torch.flatten(sr_feats, start_dim=2)    # (2, 128, 1521)
        # L2 norm over the spatial dimension of each channel
        seg_norm = torch.norm(seg_feats_flatten, p=2, dim=2, keepdim=True)  # (2, 128, 1)
        sr_norm = torch.norm(sr_feats_flatten, p=2, dim=2, keepdim=True)    # (2, 128, 1)
        # similarity
        seg_feats_flatten = seg_feats_flatten / (seg_norm + eps)  # (2, 128, 1521)
        sr_feats_flatten = sr_feats_flatten / (sr_norm + eps)     # (2, 128, 1521)
        # (B, HW, HW) Gram matrices of pairwise pixel similarities, much like a
        # self-attention map
        seg_sim = torch.matmul(seg_feats_flatten.permute(0, 2, 1), seg_feats_flatten)
        sr_sim = torch.matmul(sr_feats_flatten.permute(0, 2, 1), sr_feats_flatten)
        # L1 loss; detach stops gradients flowing through the SR-branch affinity
        loss['fa_loss'] = F.l1_loss(seg_sim, sr_sim.detach()) * fa_weight
        return loss
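    # In matrix form (a restatement of the loss above, not an additional term):
    # with F in R^{B x C x N}, N = H*W, and each channel's spatial vector scaled
    # to unit L2 norm, the affinity matrices are S = F^T F in R^{B x N x N} and
    # fa_loss = fa_weight * mean(|S_seg - detach(S_sr)|).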
    def _stack_batch_gt(self, batch_data_samples):
        gt_semantic_segs = [
            data_sample.gt_sem_seg.data for data_sample in batch_data_samples
        ]
        return torch.stack(gt_semantic_segs, dim=0)
    def loss_by_feat(self, seg_logits, batch_data_samples) -> dict:
        """Compute segmentation loss.

        Args:
            seg_logits (Tensor): The output from decode head forward function.
            batch_data_samples (List[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `metainfo` and `gt_sem_seg`.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_label = self._stack_batch_gt(batch_data_samples)  # labels stacked into a (bs, 1, H, W) tensor
        loss = dict()
        seg_logits = resize(
            input=seg_logits,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)  # upsample the head output to label resolution
        if self.sampler is not None:
            seg_weight = self.sampler.sample(seg_logits, seg_label)
        else:
            seg_weight = None
        seg_label = seg_label.squeeze(1)  # drop the singleton channel dimension
        if not isinstance(self.loss_decode, nn.ModuleList):
            losses_decode = [self.loss_decode]
        else:
            losses_decode = self.loss_decode
        for loss_decode in losses_decode:
            if loss_decode.loss_name not in loss:
                loss[loss_decode.loss_name] = loss_decode(
                    seg_logits,
                    seg_label,
                    weight=seg_weight,
                    ignore_index=self.ignore_index)
            else:
                loss[loss_decode.loss_name] += loss_decode(
                    seg_logits,
                    seg_label,
                    weight=seg_weight,
                    ignore_index=self.ignore_index)
        loss['acc_seg'] = accuracy(
            seg_logits, seg_label, ignore_index=self.ignore_index)
        return loss
    def forward_train(self, inputs, prev_output, input_16, gt_semantic_seg):
        if self.consist:
            (seg_logits, seg_logits_aux16, seg_logits_aux8, losses_recon,
             losses_fa, loss_consist1, loss_consist2) = self.forward(inputs, prev_output, input_16)
            losses = self.loss_by_feat(seg_logits, gt_semantic_seg)
            losses_aux16 = self.loss_by_feat(seg_logits_aux16, gt_semantic_seg)
            losses_aux8 = self.loss_by_feat(seg_logits_aux8, gt_semantic_seg)
            return losses, losses_aux16, losses_aux8, losses_recon, losses_fa, loss_consist1, loss_consist2
        else:
            seg_logits, seg_logits_aux16, seg_logits_aux8, losses_recon, losses_fa = self.forward(
                inputs, prev_output, input_16)
            losses = self.loss_by_feat(seg_logits, gt_semantic_seg)
            losses_aux16 = self.loss_by_feat(seg_logits_aux16, gt_semantic_seg)
            losses_aux8 = self.loss_by_feat(seg_logits_aux8, gt_semantic_seg)
            return losses, losses_aux16, losses_aux8, losses_recon, losses_fa
    def forward_test(self, inputs, prev_output, input_16):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            input_16 (Tensor): The 1/16-resolution features from the deep
                branch (used only by the consistency loss during training).

        Returns:
            Tensor: Output segmentation map.
        """
        return self.forward(inputs, prev_output, input_16, False)
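

if __name__ == '__main__':
    # Minimal smoke test for the Laplacian pyramid (an illustrative sketch, not
    # part of the training pipeline). Because of the relative imports above, it
    # must be run in package context (python -m ...) with the project's
    # dependencies installed; it runs on CPU since conv_gauss moves the kernel
    # to the input's device.
    lap = Lap_Pyramid_Conv(num_high=2)
    img = torch.randn(2, 3, 1224, 1224)
    pyr = lap.pyramid_decom(img)
    assert pyr[0].shape == (2, 3, 1224, 1224)
    assert pyr[1].shape == (2, 3, 612, 612)
    print([tuple(p.shape) for p in pyr])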