import time
import math
import copy
from functools import partial
from typing import Optional, Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange, repeat
from timm.models.layers import DropPath, trunc_normal_
# from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count, parameter_count
DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"
# import mamba_ssm.selective_scan_fn (in which causal_conv1d is needed)
try:
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
except ImportError:
    pass

# # an alternative for mamba_ssm
# try:
#     from selective_scan import selective_scan_fn as selective_scan_fn_v1
#     from selective_scan import selective_scan_ref as selective_scan_ref_v1
# except ImportError:
#     pass
"""
参考isdnet的主干将其后三个阶段特征都不进行下采样
以2 3 306 306为例
"""
from mmseg.registry import MODELS


class PatchEmbed2D(nn.Module):
    r""" Image to Patch Embedding

    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, **kwargs):
        super().__init__()
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, padding=1)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        x = self.proj(x).permute(0, 2, 3, 1)
        if self.norm is not None:
            x = self.norm(x)
        return x
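
# Illustrative shape check for PatchEmbed2D (a minimal sketch, not part of the
# original model; the 306x306 input follows the example in the module docstring).
# With patch_size=4 and padding=1, the Conv2d maps (B, 3, 306, 306) to
# (B, 96, 77, 77), which forward() permutes to channels-last:
#
#     embed = PatchEmbed2D(patch_size=4, in_chans=3, embed_dim=96, norm_layer=nn.LayerNorm)
#     demo = torch.randn(2, 3, 306, 306)
#     print(embed(demo).shape)  # torch.Size([2, 77, 77, 96])
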

class PatchMerging2D(nn.Module):
    r""" Patch Merging Layer.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    # example: x is (16, 77, 77, 96)
    def forward(self, x):
        # x = F.pad(x, pad=(0, 0, 1, 1, 1, 1), mode="constant", value=0)
        # (F.pad fills from the last dim backwards: (0, 0) pads C, the middle (1, 1) pads W, the last (1, 1) pads H)
        B, H, W, C = x.shape

        SHAPE_FIX = [-1, -1]
        if (W % 2 != 0) or (H % 2 != 0):
            # print(f"Warning, x.shape {x.shape} is not match even ===========", flush=True)
            SHAPE_FIX[0] = H // 2
            SHAPE_FIX[1] = W // 2

        # fold spatial information into the channel dimension
        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C

        if SHAPE_FIX[0] > 0:
            x0 = x0[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
            x1 = x1[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
            x2 = x2[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
            x3 = x3[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]

        x = torch.cat([x0, x1, x2, x3], -1)    # B H/2 W/2 4*C
        x = x.view(B, H // 2, W // 2, 4 * C)   # B H/2 W/2 4*C
        x = self.norm(x)       # 4*C
        x = self.reduction(x)  # 2*C

        return x
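
# Illustrative shape check for PatchMerging2D (a minimal sketch, not part of the
# original model). Each 2x2 spatial neighbourhood is stacked onto the channel
# axis (4*C) and then linearly reduced to 2*C; odd-sized inputs are cropped by
# one row/column first:
#
#     merge = PatchMerging2D(dim=96)
#     demo = torch.randn(2, 77, 77, 96)
#     print(merge(demo).shape)  # torch.Size([2, 38, 38, 192])
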
from .mamba_block.mambair import mambairBlock
# from .mamba_block.vmamba import Vmamba_Block
from .mamba_block.vdb import VDB_Block
from .mamba_block.lfssblock import LFSSBlock
# from .mamba_block.my_efficent_block import VSSBlock
# from .mamba_block.my_efficent_block_td import VSSBlock
from .mamba_block.moganet import VSSBlock


class VSSLayer(nn.Module):
    """ A basic VSSM (visual state space) layer for one stage, following the Swin Transformer stage layout.

    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(
            self,
            dim,
            depth,
            attn_drop=0.,
            drop_path=0.,
            norm_layer=nn.LayerNorm,
            downsample=None,
            use_checkpoint=False,
            d_state=16,
            vssm_mode='',
            **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.use_checkpoint = use_checkpoint
        self.vssm_mode = vssm_mode

        self.blocks = nn.ModuleList([
            mambairBlock(
                hidden_dim=dim,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
                attn_drop_rate=attn_drop,
                d_state=d_state,
                vssm_mode=self.vssm_mode,
            )
            for i in range(depth)])

        if True:  # is this really applied? Yes, but it is overridden later in VSSM!
            def _init_weights(module: nn.Module):
                for name, p in module.named_parameters():
                    if name in ["out_proj.weight"]:
                        p = p.clone().detach_()  # fake init, just to keep the seed ....
                        nn.init.kaiming_uniform_(p, a=math.sqrt(5))
            self.apply(_init_weights)

        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)

        if self.downsample is not None:
            x = self.downsample(x)

        return x
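
# Usage sketch for VSSLayer (illustrative only; it assumes the mambairBlock
# import above resolves and that the block preserves the channels-last
# (B, H, W, C) shape). A stage only changes the spatial size when a downsample
# module is given:
#
#     stage = VSSLayer(dim=96, depth=2, d_state=16, vssm_mode='vssm', downsample=PatchMerging2D)
#     demo = torch.randn(2, 77, 77, 96)
#     print(stage(demo).shape)  # expected torch.Size([2, 38, 38, 192])
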
@MODELS.register_module()
class VSSM(nn.Module):  # original SSM backbone
    def __init__(self, patch_size=4, in_chans=3, num_classes=4, depths=[2, 2, 9, 2],
                 dims=[96, 192, 384, 768], d_state=16, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, patch_norm=True, vssm_mode='vssm',
                 use_checkpoint=False, final_upsample="expand_first", pretrain_model="", **kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.num_layers = len(depths)
        if isinstance(dims, int):
            dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)]
        self.embed_dim = dims[0]
        self.num_features = dims[-1]
        self.num_features_up = int(dims[0] * 2)
        self.dims = dims
        self.final_upsample = final_upsample

        self.patch_embed = PatchEmbed2D(patch_size=patch_size, in_chans=in_chans, embed_dim=self.embed_dim,
                                        norm_layer=norm_layer if patch_norm else None)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        self.vssm_mode = vssm_mode

        # build encoder and bottleneck layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):  # downsampling stages
            layer = VSSLayer(
                # dim=dims[i_layer],  # int(embed_dim * 2 ** i_layer)
                dim=int(dims[0] * 2 ** i_layer),
                depth=depths[i_layer],
                d_state=math.ceil(dims[0] / 6) if d_state is None else d_state,  # 20240109
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging2D if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
                vssm_mode=self.vssm_mode,
            )
            self.layers.append(layer)

        # the decoder of the original U-shaped model is not built here; only the final norm remains
        self.norm = norm_layer(self.num_features)

        if pretrain_model:
            print('use pretrain model {}'.format(pretrain_model))
            self.load_from(pretrain_model)
            self.apply(self._init_weights)
        else:
            self.apply(self._init_weights)

    def load_from(self, config):
        # pretrained_path = config.MODEL.PRETRAIN_CKPT
        pretrained_path = config
        if pretrained_path is not None:
            print("pretrained_path:{}".format(pretrained_path))
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            pretrained_dict = torch.load(pretrained_path, map_location=device)
            if "model" not in pretrained_dict:
                print("---start loading pretrained model by splitting---")
                pretrained_dict = {k[17:]: v for k, v in pretrained_dict.items()}
                for k in list(pretrained_dict.keys()):
                    if "output" in k:
                        print("delete key:{}".format(k))
                        del pretrained_dict[k]
                msg = self.load_state_dict(pretrained_dict, strict=False)
                # print(msg)
                return
            pretrained_dict = pretrained_dict['model']
            print("---start loading pretrained model of swin encoder---")
            model_dict = self.state_dict()
            full_dict = copy.deepcopy(pretrained_dict)
            for k, v in pretrained_dict.items():
                if "layers." in k:
                    current_layer_num = 3 - int(k[7:8])
                    current_k = "layers_up." + str(current_layer_num) + k[8:]
                    full_dict.update({current_k: v})
            for k in list(full_dict.keys()):
                if k in model_dict:
                    if full_dict[k].shape != model_dict[k].shape:
                        print("delete:{};shape pretrain:{};shape model:{}".format(k, full_dict[k].shape, model_dict[k].shape))
                        del full_dict[k]
            msg = self.load_state_dict(full_dict, strict=False)
            # print(msg)
        else:
            print("none pretrain")

    def init_weight(self, pretrain_model):
        state_dict = torch.load(pretrain_model)["state_dict"]  # key of the weight dict in the checkpoint
        self_state_dict = self.state_dict()
        for k, v in state_dict.items():
            if k == 'features.0.conv.weight' and self.in_channels != 3:
                v = torch.cat([v, v], dim=1)
            self_state_dict.update({k: v})
        self.load_state_dict(self_state_dict, strict=False)

    def _init_weights(self, m: nn.Module):
        """
        out_proj.weight, which was previously initialized in VSSBlock, would be cleared in nn.Linear;
        no fc.weight is found in any of the model parameters;
        no nn.Embedding is found in any of the model parameters;
        so the VSSBlock initialization is useless.
        Conv2D is not initialized!
        """
        # print(m, getattr(getattr(m, "weight", nn.Identity()), "INIT", None), isinstance(m, nn.Linear), "======================")
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    # Encoder and Bottleneck
    def forward_features(self, x):
        x = self.patch_embed(x)  # 4x downsampling, channels become 96, e.g. (24, 56, 56, 96)
        x_downsample = []

        for layer in self.layers:
            x_downsample.append(x)
            x = layer(x)

        x = self.norm(x)  # B H W C
        x_middle = x_downsample[0:3]
        x_middle.append(x)
        return x, x_downsample, x_middle

    # def forward_backbone(self, x):
    #     x = self.patch_embed(x)
    #     for layer in self.layers:
    #         x = layer(x)
    #     return x

    # Decoder and skip connection
    # (kept from the original U-shaped model; self.layers_up / self.concat_back_dim /
    #  self.norm_up / self.up / self.output are not built in this backbone, so the
    #  following two methods are unused here)
    def forward_up_features(self, x, x_downsample):
        for inx, layer_up in enumerate(self.layers_up):
            if inx == 0:
                x = layer_up(x)
            else:
                x = torch.cat([x, x_downsample[3 - inx]], -1)
                x = self.concat_back_dim[inx](x)
                x = layer_up(x)

        x = self.norm_up(x)  # B H W C
        return x

    def up_x4(self, x):
        if self.final_upsample == "expand_first":
            B, H, W, C = x.shape
            x = self.up(x)
            x = x.view(B, 4 * H, 4 * W, -1)
            x = x.permute(0, 3, 1, 2)  # B,C,H,W
            x = self.output(x)
        return x

    def forward(self, x):  # e.g. (24, 3, 224, 224)
        x, x_downsample, x_middle = self.forward_features(x)
        # x = self.forward_up_features(x, x_downsample)
        # x = self.up_x4(x)
        x_final = []
        for feat_every in x_middle:
            feat = feat_every.permute(0, 3, 1, 2)  # B H W C -> B C H W
            x_final.append(feat)
        return x_final
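
# Note on usage (a sketch, not taken from this repository): since VSSM is
# registered via @MODELS.register_module(), an MMSegmentation config could
# reference it as a backbone by its type name. The field names other than
# `type` simply mirror the constructor arguments above; the surrounding config
# keys and the checkpoint path are assumptions.
#
#     backbone=dict(
#         type='VSSM',
#         patch_size=4,
#         in_chans=3,
#         depths=[2, 2, 9, 2],
#         dims=[96, 192, 384, 768],
#         vssm_mode='vssm',
#         pretrain_model='/path/to/vmamba_tiny_e292.pth',
#     ),
#
# forward() returns a list of four (B, C, H, W) feature maps: the inputs to the
# first three stages plus the normalized output of the last stage.
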
if __name__ == "__main__":
    # check_vssm_equals_vmambadp()
    # from torchstat import stat
    #
    # from thop import profile
    weight = "/media/cm/D450B76750B74ECC/Fusion_model/checkpoint/vmamba_tiny_e292.pth"
    # model = MambaUnet_backbone(weight).to('cuda')
    # # model = VSSM(use_checkpoint=weight).to('cuda')
    # model.load_from(weight)
    # inp = torch.randn(2, 1, 224, 224).cuda()
    # out = model(inp)
    #
    # # flops, params = profile(model, inputs=(inp,))
    # # print('flops:{}'.format(flops))
    # # print('params:{}'.format(params))
    # print(out.shape)
    model = VSSM(pretrain_model=weight).to('cuda')
    inp = torch.randn(2, 3, 306, 306).cuda()  # renamed from `int` to avoid shadowing the builtin
    out = model(inp)
    # flops, params = profile(model, inputs=(inp,))
    # print('flops:{}'.format(flops))
    # print('params:{}'.format(params))
    # print(out.shape)
    print(out)