import time
import math
import copy
from functools import partial
from typing import Optional, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange, repeat
from timm.models.layers import DropPath, trunc_normal_

# from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count, parameter_count
DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"

# import mamba_ssm.selective_scan_fn (in which causal_conv1d is needed)
try:
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
except Exception:
    # Keep this module importable without mamba_ssm (or its compiled extension);
    # SS2D raises a clear ImportError only when the scan kernel is actually used.
    selective_scan_fn = None
    selective_scan_ref = None

# # an alternative for mamba_ssm
# try:
#     from selective_scan import selective_scan_fn as selective_scan_fn_v1
#     from selective_scan import selective_scan_ref as selective_scan_ref_v1
# except:
#     pass
# The alternative import above is disabled; define the name so that
# SS2D.forward_corev1 fails with a clear ImportError instead of a NameError.
selective_scan_fn_v1 = None

# Module note (translated): backbone modelled on ISDNet's, intended to keep the
# last three stages without downsampling; example input shape (2, 3, 306, 306).
"""
参考isdnet的主干,将其后三个阶段特征都不进行下采样
以2 3 306 306为例
"""
from mmseg.registry import MODELS


def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False):
    """Estimate the MAC count of ``selective_scan_ref`` for the given sizes.

    Each einsum term is costed with ``numpy.einsum_path`` (halved, so a
    multiply-add counts as one flop); the ``if False`` string blocks below show
    the reference code each term corresponds to.

    Tensor shapes assumed by the estimate:
        u: r(B D L)
        delta: r(B D L)
        A: r(D N)
        B: r(B N L)
        C: r(B N L)
        D: r(D)
        z: r(B D L)
        delta_bias: r(D), fp32

    ignores:
        [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]

    Returns:
        float-valued flop estimate (int arithmetic is mixed in for the
        elementwise D / z terms).
    """
    import numpy as np

    # fvcore.nn.jit_handles
    def get_flops_einsum(input_shapes, equation):
        np_arrs = [np.zeros(s) for s in input_shapes]
        optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1]
        for line in optim.split("\n"):
            if "optimized flop" in line.lower():
                # divided by 2 because we count MAC (multiply-add counted as one flop)
                flop = float(np.floor(float(line.split(":")[-1]) / 2))
                return flop

    assert not with_complex

    flops = 0  # below code flops = 0
    if False:
        ...
        """
        dtype_in = u.dtype
        u = u.float()
        delta = delta.float()
        if delta_bias is not None:
            delta = delta + delta_bias[..., None].float()
        if delta_softplus:
            delta = F.softplus(delta)
        batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
        is_variable_B = B.dim() >= 3
        is_variable_C = C.dim() >= 3
        if A.is_complex():
            if is_variable_B:
                B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
            if is_variable_C:
                C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
        else:
            B = B.float()
            C = C.float()
        x = A.new_zeros((batch, dim, dstate))
        ys = []
        """

    flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln")
    if with_Group:
        flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln")
    else:
        flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln")

    if False:
        ...
        """
        deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
        if not is_variable_B:
            deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
        else:
            if B.dim() == 3:
                deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
            else:
                B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
                deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
        if is_variable_C and C.dim() == 4:
            C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
        last_state = None
        """

    # Per-timestep recurrence: elementwise state update plus the C readout.
    in_for_flops = B * D * N
    if with_Group:
        in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd")
    else:
        in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd")
    flops += L * in_for_flops
    if False:
        ...
        """
        for i in range(u.shape[2]):
            x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
            if not is_variable_C:
                y = torch.einsum('bdn,dn->bd', x, C)
            else:
                if C.dim() == 3:
                    y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
                else:
                    y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
            if i == u.shape[2] - 1:
                last_state = x
            if y.is_complex():
                y = y.real * 2
            ys.append(y)
        y = torch.stack(ys, dim=2) # (batch dim L)
        """

    if with_D:
        flops += B * D * L
    if with_Z:
        flops += B * D * L
    if False:
        ...
        """
        out = y if D is None else y + u * rearrange(D, "d -> d 1")
        if z is not None:
            out = out * F.silu(z)
        out = out.to(dtype=dtype_in)
        """

    return flops


def selective_scan_flop_jit(inputs, outputs):
    """fvcore-style JIT flop handle for the selective-scan op.

    Reads tensor sizes from the traced graph inputs (named xs, dts, As, Bs, Cs,
    then optionally Ds and z) and delegates to ``flops_selective_scan_ref``.
    """
    # xs, dts, As, Bs, Cs, Ds (skip), z (skip), dt_projs_bias (skip)
    assert inputs[0].debugName().startswith("xs")  # (B, D, L)
    assert inputs[2].debugName().startswith("As")  # (D, N)
    assert inputs[3].debugName().startswith("Bs")  # 3-D (B, N, L) or 4-D grouped (B, G, N, L)
    with_Group = len(inputs[3].type().sizes()) == 4
    with_D = inputs[5].debugName().startswith("Ds")
    if not with_D:
        with_z = inputs[5].debugName().startswith("z")
    else:
        with_z = inputs[6].debugName().startswith("z")
    B, D, L = inputs[0].type().sizes()
    N = inputs[2].type().sizes()[1]
    flops = flops_selective_scan_ref(B=B, L=L, D=D, N=N, with_D=with_D, with_Z=with_z, with_Group=with_Group)
    return flops


class PatchEmbed2D(nn.Module):
    r"""Image to Patch Embedding (output is channels-last).

    Uses a strided convolution with ``kernel_size == stride == patch_size`` and
    ``padding=1``, so the output spatial size is
    ``floor((H + 2 - patch_size) / patch_size) + 1`` (e.g. 306 -> 77 for patch 4).

    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, **kwargs):
        super().__init__()
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, padding=1)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """(B, in_chans, H, W) -> (B, H', W', embed_dim)."""
        x = self.proj(x).permute(0, 2, 3, 1)
        if self.norm is not None:
            x = self.norm(x)
        return x


def _merge_2x2(x):
    """Fold each 2x2 spatial neighbourhood of a channels-last tensor into channels.

    (B, H, W, C) -> (B, H // 2, W // 2, 4 * C). For odd H or W the trailing
    row / column is dropped. Channel groups are ordered
    [(even row, even col), (odd row, even col), (even row, odd col), (odd row, odd col)].
    """
    B, H, W, C = x.shape
    h, w = H // 2, W // 2
    x0 = x[:, 0::2, 0::2, :]
    x1 = x[:, 1::2, 0::2, :]
    x2 = x[:, 0::2, 1::2, :]
    x3 = x[:, 1::2, 1::2, :]
    # Crop every quadrant to (h, w); a no-op when H and W are even.
    x = torch.cat([t[:, :h, :w, :] for t in (x0, x1, x2, x3)], -1)
    return x.view(B, h, w, 4 * C)


class PatchMerging2D(nn.Module):
    r"""Patch Merging Layer: 2x spatial downsampling, channels C -> 2C.

    Input (B, H, W, C) -> output (B, H // 2, W // 2, 2 * C). Odd H / W lose
    their last row / column.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        # x = F.pad(x, pad=(0, 0, 1, 1, 1, 1), mode="constant", value=0)
        x = _merge_2x2(x)  # B H/2 W/2 4*C (spatial info folded into channels)
        x = self.norm(x)  # 4C
        x = self.reduction(x)  # 2C
        return x


class PatchMerging2D_expand(nn.Module):
    r"""Patch Merging Layer that zero-pads H and W by one pixel on each side first.

    Input (B, H, W, C) -> output (B, (H + 2) // 2, (W + 2) // 2, 2 * C).

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        # F.pad pairs run from the last dim backwards: C by (0, 0), W by (1, 1), H by (1, 1).
        x = F.pad(x, pad=(0, 0, 1, 1, 1, 1), mode="constant", value=0)
        x = _merge_2x2(x)  # B H/2 W/2 4*C
        x = self.norm(x)  # 4C
        x = self.reduction(x)  # 2C
        return x


class PatchMerging2D_nodown(nn.Module):
    r"""Channel-doubling layer without spatial downsampling.

    Input (B, H, W, C) -> output (B, H, W, 2 * C): the input is duplicated
    along channels, normalized, then linearly projected.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(2 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(2 * dim)

    def forward(self, x):
        x = torch.cat([x, x], -1)  # B H W 2*C
        x = self.norm(x)  # author's note: plays the role of a BN
        x = self.reduction(x)  # author's note: plays the role of a 1x1 conv
        return x


class PatchExpand(nn.Module):
    """2x spatial upsampling via a linear expansion and pixel-shuffle-style rearrange.

    With ``dim_scale=2``: (B, H, W, dim) -> (B, 2H, 2W, dim // 2).
    """

    def __init__(self, dim, dim_scale=2, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.expand = nn.Linear(
            dim, 2 * dim, bias=False) if dim_scale == 2 else nn.Identity()
        self.norm = norm_layer(dim // dim_scale)

    def forward(self, x):
        x = self.expand(x)
        B, H, W, C = x.shape
        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4)
        x = self.norm(x)

        return x


class FinalPatchExpand_X4(nn.Module):
    """``dim_scale``x spatial upsampling keeping ``dim`` channels.

    (B, H, W, dim) -> (B, H * dim_scale, W * dim_scale, dim); the expansion
    Linear is sized ``16 * dim``, which matches the default ``dim_scale=4``.
    """

    def __init__(self, dim, dim_scale=4, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.dim_scale = dim_scale
        self.expand = nn.Linear(dim, 16 * dim, bias=False)
        self.output_dim = dim
        self.norm = norm_layer(self.output_dim)

    def forward(self, x):
        x = self.expand(x)
        B, H, W, C = x.shape
        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale,
                      c=C // (self.dim_scale ** 2))
        x = self.norm(x)

        return x


class SS2D(nn.Module):
    """2-D selective-scan (Mamba) token mixer operating on channels-last maps.

    forward: (B, H, W, d_model) -> (B, H, W, d_model). The input is projected
    to 2 * d_inner, split into a scan branch (depthwise conv + SiLU + four-direction
    selective scan) and a gate branch z, combined as ``scan * silu(z)`` and
    projected back to d_model.
    """

    def __init__(
        self,
        d_model,
        d_state=16,
        # d_state="auto", # 20240109
        d_conv=3,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        dropout=0.,
        conv_bias=True,
        bias=False,
        device=None,
        dtype=None,
        **kwargs,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        # self.d_state = math.ceil(self.d_model / 6) if d_state == "auto" else d_model # 20240109
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank

        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
        self.conv2d = nn.Conv2d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            groups=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            padding=(d_conv - 1) // 2,
            **factory_kwargs,
        )
        self.act = nn.SiLU()

        # One projection per scan direction, created in order so RNG use matches upstream.
        self.x_proj = tuple(
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs)
            for _ in range(4)
        )
        self.x_proj_weight = nn.Parameter(
            torch.stack([t.weight for t in self.x_proj], dim=0))  # (K=4, dt_rank + 2 * d_state, d_inner)
        del self.x_proj

        self.dt_projs = tuple(
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
                         **factory_kwargs)
            for _ in range(4)
        )
        self.dt_projs_weight = nn.Parameter(
            torch.stack([t.weight for t in self.dt_projs], dim=0))  # (K=4, d_inner, dt_rank)
        self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0))  # (K=4, d_inner)
        del self.dt_projs

        self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True)  # (K * d_inner, d_state)
        self.Ds = self.D_init(self.d_inner, copies=4, merge=True)  # (K * d_inner,)

        self.forward_core = self.forward_corev0
        # self.forward_core = self.forward_corev0_seq
        # self.forward_core = self.forward_corev1

        self.out_norm = nn.LayerNorm(self.d_inner)
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
        self.dropout = nn.Dropout(dropout) if dropout > 0. else None

    @staticmethod
    def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4,
                **factory_kwargs):
        """Build the dt projection Linear(dt_rank -> d_inner) with Mamba's dt initialization."""
        dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = dt_rank ** -0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        dt_proj.bias._no_reinit = True

        return dt_proj

    @staticmethod
    def A_log_init(d_state, d_inner, copies=1, device=None, merge=True):
        """S4D-real init: log(1..d_state) per channel, optionally repeated ``copies`` times and flattened."""
        A = repeat(
            torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        if copies > 1:
            A_log = repeat(A_log, "d n -> r d n", r=copies)
            if merge:
                A_log = A_log.flatten(0, 1)
        A_log = nn.Parameter(A_log)
        A_log._no_weight_decay = True
        return A_log

    @staticmethod
    def D_init(d_inner, copies=1, device=None, merge=True):
        """D "skip" parameter initialized to ones, optionally repeated ``copies`` times and flattened."""
        D = torch.ones(d_inner, device=device)
        if copies > 1:
            D = repeat(D, "n1 -> r n1", r=copies)
            if merge:
                D = D.flatten(0, 1)
        D = nn.Parameter(D)  # Keep in fp32
        D._no_weight_decay = True
        return D

    @staticmethod
    def _checked_scan(fn, name):
        """Return ``fn``; raise a clear ImportError if the optional scan kernel was not importable."""
        if fn is None:
            raise ImportError(
                f"{name} is not available: the package providing the selective-scan kernel could not be imported."
            )
        return fn

    def forward_corev0(self, x: torch.Tensor):
        """Four-direction selective scan over a (B, D, H, W) map in a single kernel call.

        Directions: row-major, column-major (transposed) and the reverse of each.
        Returns the summed, ``out_norm``-ed result as (B, H, W, D) in ``x.dtype``.
        """
        self.selective_scan = self._checked_scan(selective_scan_fn, "selective_scan_fn (mamba_ssm)")

        B, C, H, W = x.shape
        L = H * W
        K = 4

        x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)],
                             dim=1).view(B, 2, -1, L)
        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)  # (b, k, d, l)

        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
        # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
        dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)

        xs = xs.float().view(B, -1, L)  # (b, k * d, l)
        dts = dts.contiguous().float().view(B, -1, L)  # (b, k * d, l)
        Bs = Bs.float().view(B, K, -1, L)  # (b, k, d_state, l)
        Cs = Cs.float().view(B, K, -1, L)  # (b, k, d_state, l)
        Ds = self.Ds.float().view(-1)  # (k * d)
        As = -torch.exp(self.A_logs.float()).view(-1, self.d_state)  # (k * d, d_state)
        dt_projs_bias = self.dt_projs_bias.float().view(-1)  # (k * d)

        out_y = self.selective_scan(
            xs, dts,
            As, Bs, Cs, Ds, z=None,
            delta_bias=dt_projs_bias,
            delta_softplus=True,
            return_last_state=False,
        ).view(B, K, -1, L)
        assert out_y.dtype == torch.float

        # Undo the flip of directions 2/3 and the transpose of directions 1/3 so
        # every output is back in row-major order before summing.
        inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
        wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y
        y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1)
        y = self.out_norm(y).to(x.dtype)

        return y

    def forward_corev0_seq(self, x: torch.Tensor):
        """Same result layout as ``forward_corev0`` but one kernel call per scan direction.

        Per-direction tensors keep an explicit K axis so that index ``i`` selects
        a direction: u/delta (B, D, L), A (D, N), B/C (B, N, L), D/bias (D,).
        """
        self.selective_scan = self._checked_scan(selective_scan_fn, "selective_scan_fn (mamba_ssm)")

        B, C, H, W = x.shape
        L = H * W
        K = 4

        x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)],
                             dim=1).view(B, 2, -1, L)
        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)  # (b, k, d, l)

        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
        # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
        dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)

        # Keep the K axis (do NOT flatten to k * d): the loop indexes directions.
        xs = xs.float()  # (b, k, d, l)
        dts = dts.contiguous().float()  # (b, k, d, l)
        Bs = Bs.float()  # (b, k, d_state, l)
        Cs = Cs.float()  # (b, k, d_state, l)
        As = -torch.exp(self.A_logs.float()).view(K, -1, self.d_state)  # (k, d, d_state)
        Ds = self.Ds.float().view(K, -1)  # (k, d)
        dt_projs_bias = self.dt_projs_bias.float().view(K, -1)  # (k, d)

        out_y = []
        for i in range(K):
            yi = self.selective_scan(
                xs[:, i], dts[:, i],
                As[i], Bs[:, i], Cs[:, i], Ds[i],
                delta_bias=dt_projs_bias[i],
                delta_softplus=True,
            ).view(B, -1, L)
            out_y.append(yi)
        out_y = torch.stack(out_y, dim=1)
        assert out_y.dtype == torch.float

        inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
        wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y
        y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1)
        y = self.out_norm(y).to(x.dtype)

        return y

    def forward_corev1(self, x: torch.Tensor):
        """Variant for the alternative ``selective_scan`` package (expects fp16 kernel output)."""
        self.selective_scan = self._checked_scan(selective_scan_fn_v1, "selective_scan_fn_v1 (selective_scan)")

        B, C, H, W = x.shape
        L = H * W
        K = 4

        x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)],
                             dim=1).view(B, 2, -1, L)
        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)  # (b, k, d, l)

        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
        # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
        dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
        # dts = dts + self.dt_projs_bias.view(1, K, -1, 1)

        xs = xs.view(B, -1, L)  # (b, k * d, l)
        dts = dts.contiguous().view(B, -1, L)  # (b, k * d, l)
        Bs = Bs.view(B, K, -1, L)  # (b, k, d_state, l)
        Cs = Cs.view(B, K, -1, L)  # (b, k, d_state, l)

        As = -torch.exp(self.A_logs.float()).view(-1, self.d_state)  # (k * d, d_state)
        Ds = self.Ds.view(-1)  # (k * d)
        dt_projs_bias = self.dt_projs_bias.view(-1)  # (k * d)

        # print(self.Ds.dtype, self.A_logs.dtype, self.dt_projs_bias.dtype, flush=True) # fp16, fp16, fp16

        out_y = self.selective_scan(
            xs, dts,
            As, Bs, Cs, Ds,
            delta_bias=dt_projs_bias,
            delta_softplus=True,
        ).view(B, K, -1, L)
        assert out_y.dtype == torch.float16

        inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
        wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        y = out_y[:, 0].float() + inv_y[:, 0].float() + wh_y.float() + invwh_y.float()
        y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1)
        y = self.out_norm(y).to(x.dtype)

        return y

    def forward(self, x: torch.Tensor, **kwargs):
        """(B, H, W, d_model) -> (B, H, W, d_model)."""
        B, H, W, C = x.shape  # channels-last input

        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)  # (b, h, w, d)

        x = x.permute(0, 3, 1, 2).contiguous()
        x = self.act(self.conv2d(x))  # (b, d, h, w)
        y = self.forward_core(x)
        y = y * F.silu(z)
        out = self.out_proj(y)
        if self.dropout is not None:
            out = self.dropout(out)
        return out


class ChannelAttention(nn.Module):
    """Channel attention used in RCAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        squeeze_factor (int): Channel squeeze factor. Default: 16.
    """

    def __init__(self, num_feat, squeeze_factor=16):
        super(ChannelAttention, self).__init__()
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0),
            nn.Sigmoid())

    def forward(self, x):
        y = self.attention(x)
        return x * y


class CAB(nn.Module):
    """Channel-attention block: conv3x3 -> GELU -> conv3x3 -> ChannelAttention (NCHW).

    ``is_light_sr`` is accepted but unused.
    """

    def __init__(self, num_feat, is_light_sr=False, compress_ratio=3, squeeze_factor=30):
        super(CAB, self).__init__()
        self.cab = nn.Sequential(
            nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1),
            ChannelAttention(num_feat, squeeze_factor)
        )

    def forward(self, x):
        return self.cab(x)


class VSSBlock(nn.Module):
    """Residual SS2D block on channels-last (B, H, W, C) tensors.

    vssm_mode:
        'vssm'     -> x = input * skip_scale + DropPath(SS2D(ln_1(input)))
        'vssm_cab' -> x = input + DropPath(SS2D(ln_1(input)));
                      x = x * skip_scale2 + CAB(ln_2(x))   (MambaIR-style)
    ``conv_blk``, ``ln_2`` and ``skip_scale2`` are built in both modes.
    """

    def __init__(
        self,
        hidden_dim: int = 0,
        drop_path: float = 0,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        attn_drop_rate: float = 0,
        d_state: int = 16,
        vssm_mode='vssm_cab',
        **kwargs,
    ):
        super().__init__()
        self.vssm_mode = vssm_mode
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = SS2D(d_model=hidden_dim, dropout=attn_drop_rate, d_state=d_state, **kwargs)
        self.drop_path = DropPath(drop_path)
        # gate
        self.skip_scale = nn.Parameter(torch.ones(hidden_dim))
        self.conv_blk = CAB(hidden_dim)
        self.ln_2 = nn.LayerNorm(hidden_dim)
        self.skip_scale2 = nn.Parameter(torch.ones(hidden_dim))

    def forward(self, input: torch.Tensor):
        if self.vssm_mode == 'vssm':
            x = input * self.skip_scale + self.drop_path(self.self_attention(self.ln_1(input)))
        elif self.vssm_mode == 'vssm_cab':
            # Author's note: this ungated residual worked better than the gated one.
            x = input + self.drop_path(self.self_attention(self.ln_1(input)))
            x = x * self.skip_scale2 + self.conv_blk(
                self.ln_2(x).permute(0, 3, 1, 2).contiguous()).permute(0, 2, 3, 1).contiguous()
        else:
            raise ValueError(f"Unknown vssm_mode {self.vssm_mode!r}; expected 'vssm' or 'vssm_cab'")
        return x


class VSSLayer(nn.Module):
    """One encoder stage: ``depth`` VSSBlocks followed by an optional downsample layer.

    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        attn_drop (float, optional): Dropout rate inside SS2D. Default: 0.0
        drop_path (float | list[float], optional): Stochastic depth rate; a list gives one rate per block. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (type | None, optional): Class built as ``downsample(dim=dim, norm_layer=norm_layer)``
            and applied after the blocks. Default: None
        use_checkpoint (bool): Whether to use gradient checkpointing to save memory. Default: False.
        d_state (int): SSM state size for every block. Default: 16
        vssm_mode (str): VSSBlock variant, 'vssm' or 'vssm_cab'. The default '' matches neither,
            so forward raises ValueError unless a mode is given.
        **kwargs: Ignored (e.g. ``drop``).
    """

    def __init__(
        self,
        dim,
        depth,
        attn_drop=0.,
        drop_path=0.,
        norm_layer=nn.LayerNorm,
        downsample=None,
        use_checkpoint=False,
        d_state=16,
        vssm_mode='',
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.use_checkpoint = use_checkpoint
        self.vssm_mode = vssm_mode
        self.blocks = nn.ModuleList([
            VSSBlock(
                hidden_dim=dim,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
                attn_drop_rate=attn_drop,
                d_state=d_state,
                vssm_mode=self.vssm_mode,
            )
            for i in range(depth)])

        if True:  # is this really applied? Yes, but been overriden later in VSSM!
            def _init_weights(module: nn.Module):
                for name, p in module.named_parameters():
                    if name in ["out_proj.weight"]:
                        p = p.clone().detach_()  # fake init, just to keep the seed ....
                        nn.init.kaiming_uniform_(p, a=math.sqrt(5))
            self.apply(_init_weights)

        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)

        if self.downsample is not None:
            x = self.downsample(x)

        return x


class VSSLayer_up(nn.Module):
    """One decoder stage: ``depth`` VSSBlocks (default 'vssm_cab' mode) plus optional 2x upsampling.

    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        attn_drop (float, optional): Dropout rate inside SS2D. Default: 0.0
        drop_path (float | list[float], optional): Stochastic depth rate; a list gives one rate per block. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        upsample (object | None, optional): If not None, a ``PatchExpand(dim, dim_scale=2)`` is applied
            after the blocks (the given value itself is not used). Default: None
        use_checkpoint (bool): Whether to use gradient checkpointing to save memory. Default: False.
        d_state (int): SSM state size for every block. Default: 16
    """

    def __init__(
        self,
        dim,
        depth,
        attn_drop=0.,
        drop_path=0.,
        norm_layer=nn.LayerNorm,
        upsample=None,
        use_checkpoint=False,
        d_state=16,
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.use_checkpoint = use_checkpoint

        self.blocks = nn.ModuleList([
            VSSBlock(
                hidden_dim=dim,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
                attn_drop_rate=attn_drop,
                d_state=d_state,
            )
            for i in range(depth)])

        if True:  # is this really applied? Yes, but been overriden later in VSSM!
            def _init_weights(module: nn.Module):
                for name, p in module.named_parameters():
                    if name in ["out_proj.weight"]:
                        p = p.clone().detach_()  # fake init, just to keep the seed ....
                        nn.init.kaiming_uniform_(p, a=math.sqrt(5))
            self.apply(_init_weights)

        if upsample is not None:
            self.upsample = PatchExpand(dim, dim_scale=2, norm_layer=nn.LayerNorm)
        else:
            self.upsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)

        if self.upsample is not None:
            x = self.upsample(x)

        return x


# @MODELS.register_module()
class VSSM(nn.Module):
    """VMamba-style hierarchical encoder (original SSM) returning a feature pyramid.

    forward(x: (B, in_chans, H, W)) returns a list of NCHW tensors:
    [patch-embed output, stage-0 output, stage-1 output, normed last-stage output]
    (for the default four stages).

    If ``pretrain_model`` is a non-empty path, weights are loaded with
    ``load_from`` AFTER the default initialization, so they are not overwritten.
    """

    def __init__(self, patch_size=4, in_chans=3, num_classes=4, depths=[2, 2, 9, 2], dims=[64, 128, 256, 512],
                 d_state=16, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm,
                 patch_norm=True, vssm_mode='vssm', use_checkpoint=False, pretrain_model="", **kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.num_layers = len(depths)
        if isinstance(dims, int):
            dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)]
        self.embed_dim = dims[0]
        self.num_features = dims[-1]
        self.num_features_up = int(dims[0] * 2)
        self.dims = dims

        self.patch_embed = PatchEmbed2D(patch_size=patch_size, in_chans=in_chans, embed_dim=self.embed_dim,
                                        norm_layer=norm_layer if patch_norm else None)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        self.vssm_mode = vssm_mode

        # build encoder and bottleneck layers (each but the last ends with a downsample)
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = VSSLayer(
                # NOTE: width is dims[0] * 2**i (PatchMerging2D doubles channels),
                # not dims[i]; self.norm below uses dims[-1], so `dims` must double.
                # dim=dims[i_layer], #int(embed_dim * 2 ** i_layer)
                dim=int(dims[0] * 2 ** i_layer),
                depth=depths[i_layer],
                d_state=math.ceil(dims[0] / 6) if d_state is None else d_state,  # 20240109
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging2D if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
                vssm_mode=self.vssm_mode,
            )
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)

        # Initialize first, then load: applying _init_weights after load_from
        # would overwrite every pretrained nn.Linear / nn.LayerNorm weight.
        self.apply(self._init_weights)
        if pretrain_model:
            print('use pretrain model {}'.format(pretrain_model))
            self.load_from(pretrain_model)

    def load_from(self, config):
        """Load a checkpoint from path ``config`` non-strictly.

        Checkpoints without a "model" key have the first 17 characters of every
        key stripped and keys containing "output" dropped. Otherwise the
        "model" dict is used; "layers.i." keys are additionally copied to
        "layers_up.(3-i)." and entries whose shape differs from this model's are
        dropped before loading.
        """
        # pretrained_path = config.MODEL.PRETRAIN_CKPT
        pretrained_path = config
        if pretrained_path is not None:
            print("pretrained_path:{}".format(pretrained_path))
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            pretrained_dict = torch.load(pretrained_path, map_location=device)
            if "model" not in pretrained_dict:
                print("---start load pretrained modle by splitting---")
                pretrained_dict = {k[17:]: v for k, v in pretrained_dict.items()}
                for k in list(pretrained_dict.keys()):
                    if "output" in k:
                        print("delete key:{}".format(k))
                        del pretrained_dict[k]
                self.load_state_dict(pretrained_dict, strict=False)  # returns missing / unexpected keys
                return
            pretrained_dict = pretrained_dict['model']
            print("---start load pretrained modle of swin encoder---")

            model_dict = self.state_dict()
            full_dict = copy.deepcopy(pretrained_dict)
            for k, v in pretrained_dict.items():
                if "layers." in k:
                    current_layer_num = 3 - int(k[7:8])
                    current_k = "layers_up." + str(current_layer_num) + k[8:]
                    full_dict.update({current_k: v})
            for k in list(full_dict.keys()):
                if k in model_dict:
                    if full_dict[k].shape != model_dict[k].shape:
                        print("delete:{};shape pretrain:{};shape model:{}".format(
                            k, full_dict[k].shape, model_dict[k].shape))
                        del full_dict[k]

            self.load_state_dict(full_dict, strict=False)  # returns missing / unexpected keys
        else:
            print("none pretrain")

    def init_weight(self, pretrain_model):
        """Non-strictly load ``torch.load(pretrain_model)["state_dict"]`` into this model.

        'features.0.conv.weight' is duplicated along dim 1 when the patch
        embedding does not take 3 input channels.
        """
        state_dict = torch.load(pretrain_model)["state_dict"]  # weight dict
        self_state_dict = self.state_dict()
        in_chans = self.patch_embed.proj.in_channels
        for k, v in state_dict.items():
            if k == 'features.0.conv.weight' and in_chans != 3:
                v = torch.cat([v, v], dim=1)
            self_state_dict.update({k: v})
        self.load_state_dict(self_state_dict, strict=False)

    def _init_weights(self, m: nn.Module):
        """Truncated-normal init for nn.Linear weights, zero biases, unit/zero nn.LayerNorm.

        out_proj.weight which is previously initialized in VSSBlock, would be cleared in nn.Linear;
        no fc.weight found in the any of the model parameters;
        no nn.Embedding found in the any of the model parameters;
        so the thing is, VSSBlock initialization is useless.

        Conv2D is not initialized !!!
        """
        # print(m, getattr(getattr(m, "weight", nn.Identity()), "INIT", None), isinstance(m, nn.Linear), "======================")
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    # Encoder and Bottleneck
    def forward_features(self, x):
        """Return (final normed features, per-stage inputs, pyramid list), all channels-last."""
        x = self.patch_embed(x)  # patch embedding: spatial /patch_size, channels -> embed_dim

        x_downsample = []
        for layer in self.layers:
            x_downsample.append(x)  # input of each stage
            x = layer(x)
        x = self.norm(x)  # B H W C
        x_middle = x_downsample[0:3]
        x_middle.append(x)
        return x, x_downsample, x_middle

    def forward(self, x):
        """(B, in_chans, H, W) -> list of NCHW feature maps (see class docstring)."""
        x, x_downsample, x_middle = self.forward_features(x)
        # x = self.forward_up_features(x,x_downsample)
        # x = self.up_x4(x)
        x_final = []
        for feat_every in x_middle:
            feat = feat_every.permute(0, 3, 1, 2)
            x_final.append(feat)
        return x_final


if __name__ == "__main__":
    # Quick smoke run of the backbone with a pretrained checkpoint (author's paths/devices).
    # from thop import profile
    weight = "/media/cm/D450B76750B74ECC/Fusion_model/checkpoint/vmamba_tiny_e292.pth"
    model = VSSM(pretrain_model=weight).to('cuda:5')
    inp = torch.randn(1, 3, 1224, 1224).to('cuda:5')
    out = model(inp)
    # flops, params = profile(model, inputs=(inp,))
    # print('flops:{}'.format(flops))
    # print('params:{}'.format(params))
    print(out)