import time
import math
import copy
from functools import partial
from typing import Optional, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange, repeat
from timm.models.layers import DropPath, trunc_normal_
# from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count, parameter_count

DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"

# import mamba_ssm.selective_scan_fn (in which causal_conv1d is needed)
try:
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
except Exception:
    pass

#
# # an alternative for mamba_ssm
# try:
#     from selective_scan import selective_scan_fn as selective_scan_fn_v1
#     from selective_scan import selective_scan_ref as selective_scan_ref_v1
# except:
#     pass

"""
|
|||
|
|
参考isdnet的主干,将其后三个阶段特征都不进行下采样
|
|||
|
|
以2 3 306 306为例
|
|||
|
|
|
|||
|
|
"""
from mmseg.registry import MODELS


class PatchEmbed2D(nn.Module):
    r""" Image to Patch Embedding

    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, **kwargs):
        super().__init__()
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, padding=1)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        x = self.proj(x).permute(0, 2, 3, 1)  # (B, C, H, W) -> (B, H', W', embed_dim), channels-last
        if self.norm is not None:
            x = self.norm(x)
        return x
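

# A minimal shape check for PatchEmbed2D (a reader's sketch, not part of the original model;
# the helper name below is hypothetical). With the defaults above, the 4x4 convolution with
# stride 4 and padding 1 maps a 306x306 image to a 77x77 grid of 96-channel tokens.
def _patch_embed2d_shape_demo():
    embed = PatchEmbed2D(patch_size=4, in_chans=3, embed_dim=96, norm_layer=nn.LayerNorm)
    tokens = embed(torch.randn(2, 3, 306, 306))
    assert tokens.shape == (2, 77, 77, 96)  # H' = W' = (306 + 2 * 1 - 4) // 4 + 1 = 77
    return tokens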


class PatchMerging2D(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    # x: e.g. (16, 77, 77, 96)
    def forward(self, x):
        # x = F.pad(x, pad=(0, 0, 1, 1, 1, 1), mode="constant", value=0)  # F.pad pairs run from the last dim backwards: (0, 0) pads C, the middle (1, 1) pads W, the leading (1, 1) pads H
        B, H, W, C = x.shape

        SHAPE_FIX = [-1, -1]
        if (W % 2 != 0) or (H % 2 != 0):
            # print(f"Warning, x.shape {x.shape} is not evenly divisible by 2 ===========", flush=True)
            SHAPE_FIX[0] = H // 2
            SHAPE_FIX[1] = W // 2
        # fold spatial information into the channel dimension
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C

        if SHAPE_FIX[0] > 0:
            x0 = x0[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
            x1 = x1[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
            x2 = x2[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
            x3 = x3[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]

        x = torch.cat([x0, x1, x2, x3], -1)   # B H/2 W/2 4*C
        x = x.view(B, H // 2, W // 2, 4 * C)  # B H/2 W/2 4*C

        x = self.norm(x)       # 4*C
        x = self.reduction(x)  # 2*C

        return x
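

# A minimal shape check for PatchMerging2D (a reader's sketch, not part of the original model;
# the helper name below is hypothetical). Odd spatial sizes are truncated via SHAPE_FIX, so a
# 77x77 map merges to 38x38 while the channel count doubles (4 * C is reduced to 2 * C).
def _patch_merging2d_shape_demo():
    merge = PatchMerging2D(dim=96)
    merged = merge(torch.randn(2, 77, 77, 96))
    assert merged.shape == (2, 38, 38, 192)  # 77 // 2 = 38, channels 96 -> 192
    return merged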


from .mamba_block.mambair import mambairBlock
# from .mamba_block.vmamba import Vmamba_Block
from .mamba_block.vdb import VDB_Block
from .mamba_block.lfssblock import LFSSBlock
# from .mamba_block.my_efficent_block import VSSBlock
# from .mamba_block.my_efficent_block_td import VSSBlock
from .mamba_block.moganet import VSSBlock


class VSSLayer(nn.Module):
    """ A basic VSS layer for one stage (a stack of Mamba blocks).

    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(
        self,
        dim,
        depth,
        attn_drop=0.,
        drop_path=0.,
        norm_layer=nn.LayerNorm,
        downsample=None,
        use_checkpoint=False,
        d_state=16,
        vssm_mode='',
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.use_checkpoint = use_checkpoint
        self.vssm_mode = vssm_mode

        self.blocks = nn.ModuleList([
            mambairBlock(
                hidden_dim=dim,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
                attn_drop_rate=attn_drop,
                d_state=d_state,
                vssm_mode=self.vssm_mode,
            )
            for i in range(depth)])

        if True:  # is this really applied? Yes, but it is overridden later in VSSM.
            def _init_weights(module: nn.Module):
                for name, p in module.named_parameters():
                    if name in ["out_proj.weight"]:
                        p = p.clone().detach_()  # fake init, just to keep the seed ....
                        nn.init.kaiming_uniform_(p, a=math.sqrt(5))

            self.apply(_init_weights)

        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)

        if self.downsample is not None:
            x = self.downsample(x)

        return x
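

# Reader's note (not from the original code): a VSSLayer applies `depth` mambairBlock modules
# at a fixed channels-last resolution (B, H, W, dim); when `downsample=PatchMerging2D` is given,
# the stage output is halved spatially and doubled in channels. `use_checkpoint=True` trades
# extra compute for lower activation memory by recomputing each block during the backward pass.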


@MODELS.register_module()
class VSSM(nn.Module):  # original SSM backbone
    def __init__(self, patch_size=4, in_chans=3, num_classes=4, depths=[2, 2, 9, 2],
                 dims=[96, 192, 384, 768], d_state=16, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, patch_norm=True, vssm_mode='vssm',
                 use_checkpoint=False, final_upsample="expand_first", pretrain_model="", **kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.num_layers = len(depths)
        if isinstance(dims, int):
            dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)]
        self.embed_dim = dims[0]
        self.num_features = dims[-1]
        self.num_features_up = int(dims[0] * 2)
        self.dims = dims
        self.final_upsample = final_upsample

        self.patch_embed = PatchEmbed2D(patch_size=patch_size, in_chans=in_chans, embed_dim=self.embed_dim,
                                        norm_layer=norm_layer if patch_norm else None)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        self.vssm_mode = vssm_mode
        # build encoder and bottleneck layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):  # downsampling stages
            layer = VSSLayer(
                # dim=dims[i_layer], #int(embed_dim * 2 ** i_layer)
                dim=int(dims[0] * 2 ** i_layer),
                depth=depths[i_layer],
                d_state=math.ceil(dims[0] / 6) if d_state is None else d_state,  # 20240109
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging2D if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
                vssm_mode=self.vssm_mode,
            )
            self.layers.append(layer)

        # decoder layers are not built here; only the final encoder norm is kept
        self.norm = norm_layer(self.num_features)

        if pretrain_model:
            print('using pretrained model {}'.format(pretrain_model))
            self.apply(self._init_weights)
            self.load_from(pretrain_model)  # load after _init_weights so the loaded weights are not re-initialized
        else:
            self.apply(self._init_weights)

    def load_from(self, config):
        # pretrained_path = config.MODEL.PRETRAIN_CKPT
        pretrained_path = config
        if pretrained_path is not None:
            print("pretrained_path:{}".format(pretrained_path))
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            pretrained_dict = torch.load(pretrained_path, map_location=device)
            if "model" not in pretrained_dict:
                print("---start loading pretrained model by splitting---")
                pretrained_dict = {k[17:]: v for k, v in pretrained_dict.items()}  # drop the leading 17-character key prefix
                for k in list(pretrained_dict.keys()):
                    if "output" in k:
                        print("delete key:{}".format(k))
                        del pretrained_dict[k]
                msg = self.load_state_dict(pretrained_dict, strict=False)
                # print(msg)
                return
            pretrained_dict = pretrained_dict['model']
            print("---start loading pretrained model of swin encoder---")

            model_dict = self.state_dict()
            full_dict = copy.deepcopy(pretrained_dict)
            for k, v in pretrained_dict.items():
                if "layers." in k:
                    current_layer_num = 3 - int(k[7:8])
                    current_k = "layers_up." + str(current_layer_num) + k[8:]
                    full_dict.update({current_k: v})
            for k in list(full_dict.keys()):
                if k in model_dict:
                    if full_dict[k].shape != model_dict[k].shape:
                        print("delete:{};shape pretrain:{};shape model:{}".format(k, full_dict[k].shape, model_dict[k].shape))
                        del full_dict[k]

            msg = self.load_state_dict(full_dict, strict=False)
            # print(msg)
        else:
            print("no pretrained weights")

    def init_weight(self, pretrain_model):
        state_dict = torch.load(pretrain_model)["state_dict"]  # weights are stored under the "state_dict" key
        self_state_dict = self.state_dict()
        for k, v in state_dict.items():
            if k == 'features.0.conv.weight' and self.in_channels != 3:
                v = torch.cat([v, v], dim=1)
            self_state_dict.update({k: v})
        self.load_state_dict(self_state_dict, strict=False)

    def _init_weights(self, m: nn.Module):
        """
        out_proj.weight, which was previously initialized in VSSBlock, would be cleared in nn.Linear.
        No fc.weight is found in any of the model parameters.
        No nn.Embedding is found in any of the model parameters.
        So the VSSBlock initialization is effectively useless.

        Conv2D is not initialized !!!
        """
        # print(m, getattr(getattr(m, "weight", nn.Identity()), "INIT", None), isinstance(m, nn.Linear), "======================")
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    # Encoder and Bottleneck
    def forward_features(self, x):
        x = self.patch_embed(x)  # 4x downsampling, channels become 96, e.g. (24, 56, 56, 96)

        x_downsample = []
        for layer in self.layers:
            x_downsample.append(x)
            x = layer(x)
        x = self.norm(x)  # B H W C

        x_middle = x_downsample[0:3]
        x_middle.append(x)
        return x, x_downsample, x_middle

    # def forward_backbone(self, x):
    #     x = self.patch_embed(x)

    #     for layer in self.layers:
    #         x = layer(x)
    #     return x

    # Decoder and skip connection
    # NOTE: the decoder path below (forward_up_features / up_x4) is not used by forward();
    # it relies on self.layers_up / self.concat_back_dim / self.norm_up / self.up / self.output,
    # which are not built in this backbone variant.
    def forward_up_features(self, x, x_downsample):
        for inx, layer_up in enumerate(self.layers_up):
            if inx == 0:
                x = layer_up(x)
            else:
                x = torch.cat([x, x_downsample[3 - inx]], -1)
                x = self.concat_back_dim[inx](x)
                x = layer_up(x)

        x = self.norm_up(x)  # B H W C

        return x

    def up_x4(self, x):
        if self.final_upsample == "expand_first":
            B, H, W, C = x.shape
            x = self.up(x)
            x = x.view(B, 4 * H, 4 * W, -1)
            x = x.permute(0, 3, 1, 2)  # B, C, H, W
            x = self.output(x)

        return x

    def forward(self, x):  # e.g. (24, 3, 224, 224)
        x, x_downsample, x_middle = self.forward_features(x)
        # x = self.forward_up_features(x, x_downsample)
        # x = self.up_x4(x)
        x_final = []
        for feat_every in x_middle:
            feat = feat_every.permute(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W)
            x_final.append(feat)
        return x_final
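
# Reader's note on the output shapes (not part of the original code): forward() returns the four
# stage features permuted to (B, C, H, W). With the default dims and the (2, 3, 306, 306) example
# from __main__, this is roughly a stride-4/8/16/32 pyramid, e.g.
# [(2, 96, 77, 77), (2, 192, 38, 38), (2, 384, 19, 19), (2, 768, 9, 9)].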


if __name__ == "__main__":
    # check_vssm_equals_vmambadp()
    # from torchstat import stat
    #
    # from thop import profile

    weight = "/media/cm/D450B76750B74ECC/Fusion_model/checkpoint/vmamba_tiny_e292.pth"
    # model = MambaUnet_backbone(weight).to('cuda')
    # # model = VSSM(use_checkpoint=weight).to('cuda')
    # model.load_from(weight)
    # int = torch.randn(2,1,224,224).cuda()
    # out = model(int)
    #
    # # flops, params = profile(model, inputs=(int,))
    # # print('flops:{}'.format(flops))
    # # print('params:{}'.format(params))
    # print(out.shape)

    model = VSSM(pretrain_model=weight).to('cuda')
    inp = torch.randn(2, 3, 306, 306).cuda()
    out = model(inp)

    # flops, params = profile(model, inputs=(inp,))
    # print('flops:{}'.format(flops))
    # print('params:{}'.format(params))

    # `out` is a list of four feature maps; print each shape
    for o in out:
        print(o.shape)