import copy
import math
import time
from functools import partial
from typing import Callable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange, repeat
from timm.models.layers import DropPath, trunc_normal_
from mmseg.registry import MODELS

from .mamba_block.mambair import mambairBlock
from .mamba_block.vdb import VDB_Block
from .mamba_block.lfssblock import LFSSBlock
# Alternative block implementations (swap in as needed):
# from .mamba_block.vmamba import Vmamba_Block
# from .mamba_block.my_efficent_block import VSSBlock
# from .mamba_block.my_efficent_block_td import VSSBlock
from .mamba_block.moganet import VSSBlock

# from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count, parameter_count

DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"

# Optional: mamba_ssm's selective scan (which needs causal_conv1d). Nothing in
# this file references these names directly; the import is best-effort so the
# module still loads when mamba_ssm (or its compiled kernels) is unavailable.
try:
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
except Exception:
    pass

# Backbone modelled on ISDNet's: the original author's stated intent is that
# the features of the last three stages are NOT downsampled, illustrated with
# a (2, 3, 306, 306) input.
# NOTE(review): as written, VSSM still places PatchMerging2D after every stage
# except the last, so stages 2-4 ARE downsampled -- confirm which is intended.


class PatchEmbed2D(nn.Module):
    r"""Image to patch embedding with channels-last output.

    Args:
        patch_size (int | tuple[int, int]): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None.

    Note:
        The projection conv uses ``padding=1`` on top of
        ``kernel_size == stride == patch_size``; e.g. a 306x306 input with
        patch_size 4 gives (306 + 2 - 4) // 4 + 1 = 77 tokens per side.
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, **kwargs):
        super().__init__()
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size,
                              stride=patch_size, padding=1)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Project (B, in_chans, H, W) to (B, H', W', embed_dim), then optionally normalise."""
        x = self.proj(x).permute(0, 2, 3, 1)
        if self.norm is not None:
            x = self.norm(x)
        return x


class PatchMerging2D(nn.Module):
    r"""Patch merging layer: halves H and W, doubles the channel count.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm.
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """Merge 2x2 neighbourhoods of a (B, H, W, C) tensor into (B, H//2, W//2, 2*C).

        For odd H or W the trailing row/column is dropped, e.g. a
        (16, 77, 77, 96) input becomes (16, 38, 38, 192).
        """
        B, H, W, C = x.shape
        H2, W2 = H // 2, W // 2
        # Crop to an even size first so all four strided views below have
        # exactly (H//2, W//2) positions. Cropping unconditionally is a no-op
        # for even sizes and also handles H == 1 or W == 1.
        x = x[:, :2 * H2, :2 * W2, :]
        # Fold spatial information into the channel dimension.
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, H2, W2, 4 * C)
        x = self.norm(x)  # 4C
        x = self.reduction(x)  # 2C
        return x


class VSSLayer(nn.Module):
    """One encoder stage: ``depth`` mambairBlocks followed by an optional downsample.

    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float], optional): Stochastic depth rate; a list
            gives one rate per block. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (type | None, optional): Downsample layer class applied at the
            end of the stage, constructed as ``downsample(dim=dim, norm_layer=...)``.
            Default: None
        use_checkpoint (bool): Whether to use activation checkpointing to save
            memory. Default: False.
        d_state (int): Passed to each mambairBlock. Default: 16.
        vssm_mode (str): Passed to each mambairBlock. Default: ''.
        **kwargs: Accepted and ignored (e.g. ``drop`` passed by VSSM).
    """

    def __init__(self, dim, depth, attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm,
                 downsample=None, use_checkpoint=False, d_state=16, vssm_mode='', **kwargs):
        super().__init__()
        self.dim = dim
        self.use_checkpoint = use_checkpoint
        self.vssm_mode = vssm_mode

        self.blocks = nn.ModuleList([
            mambairBlock(
                hidden_dim=dim,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
                attn_drop_rate=attn_drop,
                d_state=d_state,
                vssm_mode=self.vssm_mode,
            )
            for i in range(depth)
        ])

        # This "init" runs kaiming_uniform_ on a detached *clone* of each
        # parameter named "out_proj.weight", so the real weights are untouched;
        # it only consumes RNG so the seed sequence matches upstream code.
        # (VSSM later re-initialises every nn.Linear anyway.)
        def _init_weights(module: nn.Module):
            for name, p in module.named_parameters():
                if name in ["out_proj.weight"]:
                    p = p.clone().detach_()  # fake init, just to keep the seed ....
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))

        self.apply(_init_weights)

        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        """Apply every block (optionally checkpointed), then the downsample if present."""
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x


@MODELS.register_module()
class VSSM(nn.Module):  # original SSM variant
    """Hierarchical Mamba encoder registered as an mmseg backbone.

    Builds ``len(depths)`` VSSLayer stages. Stage ``i`` works at
    ``dims[0] * 2**i`` channels, and every stage except the last ends with a
    PatchMerging2D. ``forward`` returns a list of channels-first feature maps
    (see ``forward_features`` for which tensors are included).

    Args:
        patch_size (int): Patch size of the stem. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        num_classes (int): Stored only; unused by the backbone. Default: 4.
        depths (sequence[int]): Number of blocks per stage. Default: (2, 2, 9, 2).
        dims (int | sequence[int]): Channel widths; an int ``d`` expands to
            ``[d * 2**i ...]``. Only ``dims[0]`` and ``dims[-1]`` are read below.
            Default: (96, 192, 384, 768).
        d_state (int | None): SSM state size; None means ``ceil(dims[0] / 6)``.
        drop_rate (float): Passed to VSSLayer, which ignores it.
        attn_drop_rate (float): Attention dropout rate.
        drop_path_rate (float): Maximum stochastic-depth rate (linear schedule).
        norm_layer (type): Normalization layer. Default: nn.LayerNorm.
        patch_norm (bool): Normalise the patch embedding. Default: True.
        vssm_mode (str): Passed to every block. Default: 'vssm'.
        use_checkpoint (bool): Activation checkpointing in each stage.
        final_upsample (str): Used only by the unused ``up_x4``.
        pretrain_model (str): Checkpoint path; if non-empty, it is loaded via
            ``load_from`` after weight initialisation.
    """

    def __init__(self, patch_size=4, in_chans=3, num_classes=4, depths=(2, 2, 9, 2),
                 dims=(96, 192, 384, 768), d_state=16, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0.1, norm_layer=nn.LayerNorm, patch_norm=True,
                 vssm_mode='vssm', use_checkpoint=False, final_upsample="expand_first",
                 pretrain_model="", **kwargs):
        super().__init__()
        depths = list(depths)
        self.num_classes = num_classes
        self.num_layers = len(depths)
        if isinstance(dims, int):
            dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)]
        else:
            dims = list(dims)
        self.in_channels = in_chans  # read by init_weight
        self.embed_dim = dims[0]
        self.num_features = dims[-1]
        self.num_features_up = int(dims[0] * 2)
        self.dims = dims
        self.final_upsample = final_upsample

        self.patch_embed = PatchEmbed2D(patch_size=patch_size, in_chans=in_chans,
                                        embed_dim=self.embed_dim,
                                        norm_layer=norm_layer if patch_norm else None)

        # Stochastic depth decay rule: one linearly increasing rate per block.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        self.vssm_mode = vssm_mode

        # Build encoder and bottleneck layers.
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            first_block = sum(depths[:i_layer])
            last_block = sum(depths[:i_layer + 1])
            layer = VSSLayer(
                # Width doubles per stage to match PatchMerging2D's 2x channel
                # output; dims[i_layer] is intentionally not used here.
                dim=int(dims[0] * 2 ** i_layer),
                depth=depths[i_layer],
                d_state=math.ceil(dims[0] / 6) if d_state is None else d_state,  # 20240109
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[first_block:last_block],
                norm_layer=norm_layer,
                downsample=PatchMerging2D if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
                vssm_mode=self.vssm_mode,
            )
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)

        # Initialise BEFORE loading: _init_weights re-initialises every Linear
        # and LayerNorm, so running it after load_from would discard the
        # pretrained weights for those layers.
        self.apply(self._init_weights)
        if pretrain_model:
            print('use pretrain model {}'.format(pretrain_model))
            self.load_from(pretrain_model)

    def load_from(self, config):
        """Load encoder weights from the checkpoint file at path ``config``.

        Two layouts are handled:

        * no top-level "model" key: the first 17 characters of every key are
          stripped (presumably a fixed wrapper prefix -- confirm against the
          checkpoint) and keys containing "output" are dropped;
        * a "model" key: its nested dict is used, keys containing "layers."
          are also duplicated as "layers_up.<3 - i>..." (a decoder naming
          convention; this backbone has no ``layers_up`` so those entries are
          ignored), and entries whose shape disagrees with this model are
          dropped.

        Both paths call ``load_state_dict(strict=False)``: missing or
        unexpected keys are ignored. In the first path a shape mismatch
        still raises.
        """
        pretrained_path = config
        if pretrained_path is None:
            print("none pretrain")
            return

        print("pretrained_path:{}".format(pretrained_path))
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        pretrained_dict = torch.load(pretrained_path, map_location=device)

        if "model" not in pretrained_dict:
            print("---start load pretrained modle by splitting---")
            pretrained_dict = {k[17:]: v for k, v in pretrained_dict.items()}
            for k in list(pretrained_dict.keys()):
                if "output" in k:
                    print("delete key:{}".format(k))
                    del pretrained_dict[k]
            self.load_state_dict(pretrained_dict, strict=False)
            return

        pretrained_dict = pretrained_dict['model']
        print("---start load pretrained modle of swin encoder---")

        model_dict = self.state_dict()
        # Shallow copy is enough: keys are added/removed, tensors never modified.
        full_dict = dict(pretrained_dict)
        for k, v in pretrained_dict.items():
            if "layers." in k:
                current_layer_num = 3 - int(k[7:8])
                current_k = "layers_up." + str(current_layer_num) + k[8:]
                full_dict.update({current_k: v})
        for k in list(full_dict.keys()):
            if k in model_dict and full_dict[k].shape != model_dict[k].shape:
                print("delete:{};shape pretrain:{};shape model:{}".format(
                    k, full_dict[k].shape, model_dict[k].shape))
                del full_dict[k]
        self.load_state_dict(full_dict, strict=False)

    def init_weight(self, pretrain_model):
        """Overlay the checkpoint's "state_dict" entries onto this model's state.

        When ``in_channels != 3``, the 'features.0.conv.weight' tensor is
        concatenated with itself along dim 1 before loading.
        """
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict copies values onto each parameter's own device.
        state_dict = torch.load(pretrain_model, map_location="cpu")["state_dict"]  # weight dict key
        self_state_dict = self.state_dict()
        for k, v in state_dict.items():
            if k == 'features.0.conv.weight' and self.in_channels != 3:
                v = torch.cat([v, v], dim=1)
            self_state_dict.update({k: v})
        self.load_state_dict(self_state_dict, strict=False)

    def _init_weights(self, m: nn.Module):
        """Initialise Linear (truncated normal std 0.02, zero bias) and LayerNorm (1, 0).

        out_proj.weight, which is previously initialised in the blocks, is
        cleared here as it is an nn.Linear. No fc.weight or nn.Embedding exists
        in the model, so the blocks' own initialisation is effectively useless.
        Conv2d is not initialised here.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # Guard for elementwise_affine=False, where weight/bias are None.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    # Encoder and bottleneck
    def forward_features(self, x):
        """Run the encoder.

        Returns:
            x: output of the last stage after ``self.norm`` (channels-last).
            x_downsample: the input to each stage, i.e. [patch embedding,
                stage-0 output, stage-1 output, ...].
            x_middle: ``x_downsample[0:3]`` followed by the normalised final
                output; these are the features returned by ``forward``.
        """
        x = self.patch_embed(x)  # 4x downsample, channels -> dims[0]
        x_downsample = []
        for layer in self.layers:
            x_downsample.append(x)
            x = layer(x)
        x = self.norm(x)  # B H W C
        x_middle = x_downsample[0:3]
        x_middle.append(x)
        return x, x_downsample, x_middle

    # Decoder and skip connection
    def forward_up_features(self, x, x_downsample):
        """Leftover U-Net decoder path; not used by this backbone.

        It references ``self.layers_up``, ``self.concat_back_dim`` and
        ``self.norm_up``, which ``__init__`` never defines, so calling it
        raises AttributeError.
        """
        for inx, layer_up in enumerate(self.layers_up):
            if inx == 0:
                x = layer_up(x)
            else:
                x = torch.cat([x, x_downsample[3 - inx]], -1)
                x = self.concat_back_dim[inx](x)
                x = layer_up(x)
        x = self.norm_up(x)  # B H W C
        return x

    def up_x4(self, x):
        """Leftover final-upsample head; not used by this backbone.

        With ``final_upsample == "expand_first"`` it needs ``self.up`` and
        ``self.output``, which ``__init__`` never defines (AttributeError).
        Otherwise it returns ``x`` unchanged.
        """
        if self.final_upsample == "expand_first":
            B, H, W, C = x.shape
            x = self.up(x)
            x = x.view(B, 4 * H, 4 * W, -1)
            x = x.permute(0, 3, 1, 2)  # B,C,H,W
            x = self.output(x)
        return x

    def forward(self, x):
        """Return the ``x_middle`` encoder features, each permuted to (B, C, H, W)."""
        _, _, x_middle = self.forward_features(x)
        return [feat.permute(0, 3, 1, 2) for feat in x_middle]


if __name__ == "__main__":
    # Smoke test: build the backbone from a local VMamba-tiny checkpoint and run
    # a dummy batch through it (requires CUDA and the checkpoint path below).
    weight = "/media/cm/D450B76750B74ECC/Fusion_model/checkpoint/vmamba_tiny_e292.pth"
    model = VSSM(pretrain_model=weight).to('cuda')
    dummy_input = torch.randn(2, 3, 306, 306).cuda()
    out = model(dummy_input)
    print(out)