import timm
import re
import time
import math
import numpy as np
from functools import partial
from typing import Optional, Union, Type, List, Tuple, Callable, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange, repeat
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

# Requires the mamba_ssm package (CUDA selective-scan kernel); SS2D.forward_core
# calls self.selective_scan, so this import must succeed at runtime.
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref

DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"

from mmseg.registry import MODELS


class PatchEmbed2D(nn.Module):
    r""" Image to Patch Embedding

    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, **kwargs):
        super().__init__()
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        x = self.proj(x).permute(0, 2, 3, 1)
        if self.norm is not None:
            x = self.norm(x)
        return x


class PatchMerging2D(nn.Module):
    r""" Patch Merging Layer.

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        B, H, W, C = x.shape

        SHAPE_FIX = [-1, -1]
        if (W % 2 != 0) or (H % 2 != 0):
            print(f"Warning: x.shape {x.shape} has odd H or W; cropping to an even size ===========", flush=True)
            SHAPE_FIX[0] = H // 2
            SHAPE_FIX[1] = W // 2

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C

        if SHAPE_FIX[0] > 0:
            x0 = x0[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
            x1 = x1[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
            x2 = x2[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
            x3 = x3[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]

        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, H // 2, W // 2, 4 * C)  # B H/2 W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x
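# --------------------------------------------------------------------------- #
# Illustrative sketch (not part of the original model and never called by it):
# a minimal shape check for PatchEmbed2D / PatchMerging2D, assuming a 3-channel
# 256x256 input. PatchEmbed2D returns a channels-last token map; PatchMerging2D
# halves the spatial size and doubles the channels. The helper name is ours.
# --------------------------------------------------------------------------- #
def _demo_patch_embed_shapes():
    x = torch.randn(1, 3, 256, 256)
    tokens = PatchEmbed2D(patch_size=4, in_chans=3, embed_dim=96, norm_layer=nn.LayerNorm)(x)
    merged = PatchMerging2D(dim=96)(tokens)
    return tokens.shape, merged.shape  # (1, 64, 64, 96), (1, 32, 32, 192)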
class SS2D(nn.Module):
    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=3,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        dropout=0.,
        conv_bias=True,
        bias=False,
        device=None,
        dtype=None,
        **kwargs,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank

        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
        self.conv2d = nn.Conv2d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            groups=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            padding=(d_conv - 1) // 2,
            **factory_kwargs,
        )
        self.act = nn.SiLU()

        self.x_proj = (
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
        )
        self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0))  # (K=4, N, inner)
        del self.x_proj

        self.dt_projs = (
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
        )
        self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0))  # (K=4, inner, rank)
        self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0))  # (K=4, inner)
        del self.dt_projs

        self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True)  # (K=4 * D, N)
        self.Ds = self.D_init(self.d_inner, copies=4, merge=True)  # (K=4 * D)

        self.selective_scan = selective_scan_fn  # CUDA kernel from mamba_ssm

        self.out_norm = nn.LayerNorm(self.d_inner)
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
        self.dropout = nn.Dropout(dropout) if dropout > 0. else None

    @staticmethod
    def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4,
                **factory_kwargs):
        dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = dt_rank ** -0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        dt_proj.bias._no_reinit = True

        return dt_proj

    @staticmethod
    def A_log_init(d_state, d_inner, copies=1, device=None, merge=True):
        # S4D real initialization
        A = repeat(
            torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        if copies > 1:
            A_log = repeat(A_log, "d n -> r d n", r=copies)
            if merge:
                A_log = A_log.flatten(0, 1)
        A_log = nn.Parameter(A_log)
        A_log._no_weight_decay = True
        return A_log

    @staticmethod
    def D_init(d_inner, copies=1, device=None, merge=True):
        # D "skip" parameter
        D = torch.ones(d_inner, device=device)
        if copies > 1:
            D = repeat(D, "n1 -> r n1", r=copies)
            if merge:
                D = D.flatten(0, 1)
        D = nn.Parameter(D)  # Keep in fp32
        D._no_weight_decay = True
        return D

    def forward_core(self, x: torch.Tensor):
        B, C, H, W = x.shape
        L = H * W
        K = 4

        # Four traversal orders: row-major, column-major, and their reverses.
        x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)],
                             dim=1).view(B, 2, -1, L)
        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)  # (b, k, d, l)

        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
        dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)

        xs = xs.float().view(B, -1, L)  # (b, k * d, l)
        dts = dts.contiguous().float().view(B, -1, L)  # (b, k * d, l)
        Bs = Bs.float().view(B, K, -1, L)  # (b, k, d_state, l)
        Cs = Cs.float().view(B, K, -1, L)  # (b, k, d_state, l)
        Ds = self.Ds.float().view(-1)  # (k * d)
        As = -torch.exp(self.A_logs.float()).view(-1, self.d_state)  # (k * d, d_state)
        dt_projs_bias = self.dt_projs_bias.float().view(-1)  # (k * d)

        out_y = self.selective_scan(
            xs, dts,
            As, Bs, Cs, Ds, z=None,
            delta_bias=dt_projs_bias,
            delta_softplus=True,
            return_last_state=False,
        ).view(B, K, -1, L)
        assert out_y.dtype == torch.float

        inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
        wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)

        return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y

    def forward(self, x: torch.Tensor, **kwargs):
        B, H, W, C = x.shape

        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)  # (b, h, w, d)

        x = x.permute(0, 3, 1, 2).contiguous()
        x = self.act(self.conv2d(x))  # (b, d, h, w)
        y1, y2, y3, y4 = self.forward_core(x)
        assert y1.dtype == torch.float32
        y = y1 + y2 + y3 + y4
        y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1)
        y = self.out_norm(y)
        y = y * F.silu(z)
        out = self.out_proj(y)
        if self.dropout is not None:
            out = self.dropout(out)
        return out
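# --------------------------------------------------------------------------- #
# Illustrative sketch (not part of the original model and never called by it):
# mirrors the cross-scan construction at the top of SS2D.forward_core on a tiny
# 2x2 map so the four traversal orders are easy to read off. Helper name is ours.
# --------------------------------------------------------------------------- #
def _demo_cross_scan_orders():
    x = torch.arange(4.0).view(1, 1, 2, 2)  # [[0, 1], [2, 3]]
    B, C, H, W = x.shape
    L = H * W
    x_hwwh = torch.stack(
        [x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)],
        dim=1,
    ).view(B, 2, -1, L)
    xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)
    # Sequences: [0,1,2,3] (row-major), [0,2,1,3] (column-major), plus both reversed.
    return xs.view(4, L)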
class VSSBlock(nn.Module):
    def __init__(
        self,
        hidden_dim: int = 0,
        drop_path: float = 0,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        attn_drop_rate: float = 0,
        d_state: int = 16,
        **kwargs,
    ):
        super().__init__()
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = SS2D(d_model=hidden_dim, dropout=attn_drop_rate, d_state=d_state, **kwargs)
        self.drop_path = DropPath(drop_path)

    def forward(self, input: torch.Tensor):
        x = input + self.drop_path(self.self_attention(self.ln_1(input)))
        return x


class VSSLayer(nn.Module):
    """ A basic layer for one stage.

    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(
        self,
        dim,
        depth,
        attn_drop=0.,
        drop_path=0.,
        norm_layer=nn.LayerNorm,
        downsample=None,
        use_checkpoint=False,
        d_state=16,
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.use_checkpoint = use_checkpoint

        self.blocks = nn.ModuleList([
            VSSBlock(
                hidden_dim=dim,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
                attn_drop_rate=attn_drop,
                d_state=d_state,
            )
            for i in range(depth)])

        if True:  # is this really applied? Yes, but it is overridden later by VSSMEncoder._init_weights.
            def _init_weights(module: nn.Module):
                for name, p in module.named_parameters():
                    if name in ["out_proj.weight"]:
                        p = p.clone().detach_()  # fake init on a detached copy, only to keep the RNG state consistent
                        nn.init.kaiming_uniform_(p, a=math.sqrt(5))

            self.apply(_init_weights)

        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)

        if self.downsample is not None:
            x = self.downsample(x)

        return x
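# --------------------------------------------------------------------------- #
# Illustrative sketch (not part of the original model and never called by it):
# builds one VSSLayer with the first-stage settings used below. Only the
# construction is shown; a forward pass requires the mamba_ssm CUDA
# selective-scan kernel. Helper name is ours.
# --------------------------------------------------------------------------- #
def _demo_vss_layer_config():
    layer = VSSLayer(dim=96, depth=2, d_state=16, drop_path=[0.0, 0.1])
    n_params = sum(p.numel() for p in layer.parameters())
    return len(layer.blocks), n_params  # 2 blocks, plus their total parameter count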
class VSSMEncoder(nn.Module):
    def __init__(self, patch_size=4, in_chans=3, depths=[2, 2, 9, 2],
                 dims=[96, 192, 384, 768], d_state=16, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.2,
                 norm_layer=nn.LayerNorm, patch_norm=True, use_checkpoint=False, **kwargs):
        super().__init__()
        self.num_layers = len(depths)
        if isinstance(dims, int):
            dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)]
        self.embed_dim = dims[0]
        self.num_features = dims[-1]
        self.dims = dims

        self.patch_embed = PatchEmbed2D(patch_size=patch_size, in_chans=in_chans, embed_dim=self.embed_dim,
                                        norm_layer=norm_layer if patch_norm else None)

        # absolute position embedding (unused: self.ape is hard-coded to False)
        self.ape = False
        if self.ape:
            self.patches_resolution = self.patch_embed.patches_resolution
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, *self.patches_resolution, self.embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        self.layers = nn.ModuleList()
        self.downsamples = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = VSSLayer(
                dim=dims[i_layer],
                depth=depths[i_layer],
                d_state=math.ceil(dims[0] / 6) if d_state is None else d_state,  # 20240109
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=None,
                use_checkpoint=use_checkpoint,
            )
            self.layers.append(layer)
            if i_layer < self.num_layers - 1:
                self.downsamples.append(PatchMerging2D(dim=dims[i_layer], norm_layer=norm_layer))

        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module):
        """Default weight initialization.

        out_proj.weight, which was initialized inside VSSBlock, is re-initialized
        here by the nn.Linear branch; no fc.weight and no nn.Embedding exist among
        the model parameters, so the VSSBlock initialization is effectively a no-op.
        Conv2d layers are not initialized here.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward(self, x):
        x_ret = []
        x_ret.append(x)

        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for s, layer in enumerate(self.layers):
            x = layer(x)
            x_ret.append(x.permute(0, 3, 1, 2))
            if s < len(self.downsamples):
                x = self.downsamples[s](x)

        return x_ret


class ConvBNReLU(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, norm_layer=nn.BatchNorm2d,
                 bias=False):
        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
                      dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2),
            norm_layer(out_channels),
            nn.ReLU6()
        )


class ConvBN(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, norm_layer=nn.BatchNorm2d,
                 bias=False):
        super(ConvBN, self).__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
                      dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2),
            norm_layer(out_channels)
        )


class Conv(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, bias=False):
        super(Conv, self).__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
                      dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2)
        )


class SeparableConvBNReLU(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1,
                 norm_layer=nn.BatchNorm2d):
        super(SeparableConvBNReLU, self).__init__(
            nn.Conv2d(in_channels, in_channels, kernel_size, stride=stride, dilation=dilation,
                      padding=((stride - 1) + dilation * (kernel_size - 1)) // 2,
                      groups=in_channels, bias=False),
            norm_layer(out_channels),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.ReLU6()
        )


class SeparableConvBN(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1,
                 norm_layer=nn.BatchNorm2d):
        super(SeparableConvBN, self).__init__(
            nn.Conv2d(in_channels, in_channels, kernel_size, stride=stride, dilation=dilation,
                      padding=((stride - 1) + dilation * (kernel_size - 1)) // 2,
                      groups=in_channels, bias=False),
            norm_layer(out_channels),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        )


class SeparableConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1):
        super(SeparableConv, self).__init__(
            nn.Conv2d(in_channels, in_channels, kernel_size, stride=stride, dilation=dilation,
                      padding=((stride - 1) + dilation * (kernel_size - 1)) // 2,
                      groups=in_channels, bias=False),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        )


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1, 0, bias=True)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1, 0, bias=True)
        self.drop = nn.Dropout(drop, inplace=True)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
class GlobalLocalAttention(nn.Module):
    def __init__(self,
                 dim=256,
                 num_heads=16,
                 qkv_bias=False,
                 window_size=8,
                 relative_pos_embedding=True
                 ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // self.num_heads
        self.scale = head_dim ** -0.5
        self.ws = window_size

        self.qkv = Conv(dim, 3 * dim, kernel_size=1, bias=qkv_bias)
        self.local1 = ConvBN(dim, dim, kernel_size=3)
        self.local2 = ConvBN(dim, dim, kernel_size=1)
        self.proj = SeparableConvBN(dim, dim, kernel_size=window_size)

        self.attn_x = nn.AvgPool2d(kernel_size=(window_size, 1), stride=1, padding=(window_size // 2 - 1, 0))
        self.attn_y = nn.AvgPool2d(kernel_size=(1, window_size), stride=1, padding=(0, window_size // 2 - 1))

        self.relative_pos_embedding = relative_pos_embedding

        if self.relative_pos_embedding:
            # define a parameter table of relative position bias
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(self.ws)
            coords_w = torch.arange(self.ws)
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += self.ws - 1  # shift to start from 0
            relative_coords[:, :, 1] += self.ws - 1
            relative_coords[:, :, 0] *= 2 * self.ws - 1
            relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            self.register_buffer("relative_position_index", relative_position_index)

            trunc_normal_(self.relative_position_bias_table, std=.02)

    def pad(self, x, ps):
        _, _, H, W = x.size()
        if W % ps != 0:
            x = F.pad(x, (0, ps - W % ps, 0, 0), mode='reflect')
        if H % ps != 0:
            x = F.pad(x, (0, 0, 0, ps - H % ps), mode='reflect')
        return x

    def pad_out(self, x):
        x = F.pad(x, pad=(0, 1, 0, 1), mode='reflect')
        return x

    def forward(self, x):
        B, C, H, W = x.shape

        local = self.local2(x) + self.local1(x)

        x = self.pad(x, self.ws)
        B, C, Hp, Wp = x.shape
        qkv = self.qkv(x)

        q, k, v = rearrange(qkv, 'b (qkv h d) (hh ws1) (ww ws2) -> qkv (b hh ww) h (ws1 ws2) d', h=self.num_heads,
                            d=C // self.num_heads, hh=Hp // self.ws, ww=Wp // self.ws, qkv=3, ws1=self.ws, ws2=self.ws)

        dots = (q @ k.transpose(-2, -1)) * self.scale

        if self.relative_pos_embedding:
            relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                self.ws * self.ws, self.ws * self.ws, -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            dots += relative_position_bias.unsqueeze(0)

        attn = dots.softmax(dim=-1)
        attn = attn @ v

        attn = rearrange(attn, '(b hh ww) h (ws1 ws2) d -> b (h d) (hh ws1) (ww ws2)', h=self.num_heads,
                         d=C // self.num_heads, hh=Hp // self.ws, ww=Wp // self.ws, ws1=self.ws, ws2=self.ws)

        attn = attn[:, :, :H, :W]

        out = self.attn_x(F.pad(attn, pad=(0, 0, 0, 1), mode='reflect')) + \
              self.attn_y(F.pad(attn, pad=(0, 1, 0, 0), mode='reflect'))

        out = out + local
        out = self.pad_out(out)
        out = self.proj(out)
        # print(out.size())
        out = out[:, :, :H, :W]

        return out


class Block(nn.Module):
    def __init__(self, dim=256, num_heads=16, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.ReLU6, norm_layer=nn.BatchNorm2d, window_size=8):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = GlobalLocalAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, window_size=window_size)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, out_features=dim, act_layer=act_layer,
                       drop=drop)
        self.norm2 = norm_layer(dim)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
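# --------------------------------------------------------------------------- #
# Illustrative sketch (not part of the original model and never called by it):
# GlobalLocalAttention pads its input to a multiple of the window size, runs
# windowed attention plus a local conv branch, and crops back, so arbitrary
# spatial sizes are preserved. eval() makes the BatchNorm layers use their
# running statistics. Helper name is ours.
# --------------------------------------------------------------------------- #
def _demo_global_local_attention():
    attn = GlobalLocalAttention(dim=64, num_heads=8, window_size=8).eval()
    x = torch.randn(1, 64, 30, 45)
    return attn(x).shape  # (1, 64, 30, 45)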
class WF(nn.Module):
    def __init__(self, in_channels=128, decode_channels=128, eps=1e-8):
        super(WF, self).__init__()
        self.pre_conv = Conv(in_channels, decode_channels, kernel_size=1)

        self.weights = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.eps = eps
        self.post_conv = ConvBNReLU(decode_channels, decode_channels, kernel_size=3)

    def forward(self, x, res):
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        weights = nn.ReLU()(self.weights)
        fuse_weights = weights / (torch.sum(weights, dim=0) + self.eps)
        x = fuse_weights[0] * self.pre_conv(res) + fuse_weights[1] * x
        x = self.post_conv(x)
        return x


class FeatureRefinementHead(nn.Module):
    def __init__(self, in_channels=64, decode_channels=64):
        super().__init__()
        self.pre_conv = Conv(in_channels, decode_channels, kernel_size=1)

        self.weights = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.eps = 1e-8
        self.post_conv = ConvBNReLU(decode_channels, decode_channels, kernel_size=3)

        self.pa = nn.Sequential(
            nn.Conv2d(decode_channels, decode_channels, kernel_size=3, padding=1, groups=decode_channels),
            nn.Sigmoid())
        self.ca = nn.Sequential(nn.AdaptiveAvgPool2d(1),
                                Conv(decode_channels, decode_channels // 16, kernel_size=1),
                                nn.ReLU6(),
                                Conv(decode_channels // 16, decode_channels, kernel_size=1),
                                nn.Sigmoid())

        self.shortcut = ConvBN(decode_channels, decode_channels, kernel_size=1)
        self.proj = SeparableConvBN(decode_channels, decode_channels, kernel_size=3)
        self.act = nn.ReLU6()

    def forward(self, x, res):
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        weights = nn.ReLU()(self.weights)
        fuse_weights = weights / (torch.sum(weights, dim=0) + self.eps)
        x = fuse_weights[0] * self.pre_conv(res) + fuse_weights[1] * x
        x = self.post_conv(x)
        shortcut = self.shortcut(x)
        pa = self.pa(x) * x
        ca = self.ca(x) * x
        x = pa + ca
        x = self.proj(x) + shortcut
        x = self.act(x)

        return x


class AuxHead(nn.Module):
    def __init__(self, in_channels=64, num_classes=8):
        super().__init__()
        self.conv = ConvBNReLU(in_channels, in_channels)
        self.drop = nn.Dropout(0.1)
        self.conv_out = Conv(in_channels, num_classes, kernel_size=1)

    def forward(self, x, h, w):
        feat = self.conv(x)
        feat = self.drop(feat)
        feat = self.conv_out(feat)
        feat = F.interpolate(feat, size=(h, w), mode='bilinear', align_corners=False)
        return feat


class Decoder(nn.Module):
    def __init__(self,
                 encoder_channels=(64, 128, 256, 512),
                 decode_channels=64,
                 dropout=0.1,
                 window_size=8,
                 num_classes=6):
        super(Decoder, self).__init__()

        self.pre_conv = ConvBN(encoder_channels[-1], decode_channels, kernel_size=1)
        self.b4 = Block(dim=decode_channels, num_heads=8, window_size=window_size)

        self.b3 = Block(dim=decode_channels, num_heads=8, window_size=window_size)
        self.p3 = WF(encoder_channels[-2], decode_channels)

        self.b2 = Block(dim=decode_channels, num_heads=8, window_size=window_size)
        self.p2 = WF(encoder_channels[-3], decode_channels)

        self.p1 = FeatureRefinementHead(encoder_channels[-4], decode_channels)

        # self.segmentation_head = nn.Sequential(ConvBNReLU(decode_channels, decode_channels),
        #                                        nn.Dropout2d(p=dropout, inplace=True),
        #                                        Conv(decode_channels, num_classes, kernel_size=1))
        self.segmentation_head = nn.Sequential(ConvBNReLU(decode_channels, decode_channels),
                                               nn.Dropout2d(p=dropout, inplace=True),
                                               )
        self.init_weight()

    def forward(self, res1, res2, res3, res4, h, w):
        x = self.b4(self.pre_conv(res4))
        x = self.p3(x, res3)
        x = self.b3(x)

        x = self.p2(x, res2)
        x = self.b2(x)

        x = self.p1(x, res1)

        x = self.segmentation_head(x)
        # x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=False)

        return x

    def init_weight(self):
        for m in self.children():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, a=1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
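# --------------------------------------------------------------------------- #
# Illustrative sketch (not part of the original model and never called by it):
# WF upsamples the coarse decoder feature by 2x and fuses it with the encoder
# skip feature using two learned, ReLU-normalized scalar weights. Helper name
# is ours.
# --------------------------------------------------------------------------- #
def _demo_weighted_fusion():
    wf = WF(in_channels=128, decode_channels=64).eval()
    coarse = torch.randn(1, 64, 16, 16)   # decoder feature (decode_channels)
    skip = torch.randn(1, 128, 32, 32)    # encoder skip feature (in_channels)
    return wf(coarse, skip).shape          # (1, 64, 32, 32)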
class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1,
                 relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
                              dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


class SoftPool2d(nn.Module):
    def __init__(self, kernel_size=2, stride=2):
        super(SoftPool2d, self).__init__()
        self.avgpool = nn.AvgPool2d(kernel_size, stride)

    def forward(self, x):
        x_exp = torch.exp(x)
        x_exp_pool = self.avgpool(x_exp)
        x = self.avgpool(x_exp * x)
        return x / x_exp_pool


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


class ChannelAtt(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=2, pool_types=['avg', 'max', 'soft']):
        super(ChannelAtt, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU()
            # nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )
        self.pool_types = pool_types
        self.incr = nn.Linear(gate_channels // reduction_ratio, gate_channels)

    def forward(self, x):
        channel_att_sum = None
        avg_pool = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
        max_pool = F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
        avgpoolmlp = self.mlp(avg_pool)
        maxpoolmlp = self.mlp(max_pool)
        pooladd = avgpoolmlp + maxpoolmlp

        self.pool = SoftPool2d(kernel_size=(x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
        soft_pool = self.mlp(self.pool(x))
        weightPool = soft_pool * pooladd
        # channel_att_sum = self.mlp(weightPool)
        channel_att_sum = self.incr(weightPool)
        Att = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return Att


class FusionAttention(nn.Module):
    def __init__(self,
                 dim=256,
                 ssmdims=256,
                 num_heads=16,
                 qkv_bias=False,
                 window_size=8,
                 relative_pos_embedding=True
                 ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // self.num_heads
        self.scale = head_dim ** -0.5
        self.ws = window_size

        self.qkv = Conv(dim, 3 * dim, kernel_size=1, bias=qkv_bias)
        self.local1 = ConvBN(ssmdims, dim, kernel_size=3)
        self.local2 = ConvBN(ssmdims, dim, kernel_size=1)
        self.proj = SeparableConvBN(dim, dim, kernel_size=window_size)

        self.attn_x = nn.AvgPool2d(kernel_size=(window_size, 1), stride=1, padding=(window_size // 2 - 1, 0))
        self.attn_y = nn.AvgPool2d(kernel_size=(1, window_size), stride=1, padding=(0, window_size // 2 - 1))

        self.relative_pos_embedding = relative_pos_embedding

        if self.relative_pos_embedding:
            # define a parameter table of relative position bias
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(self.ws)
            coords_w = torch.arange(self.ws)
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += self.ws - 1  # shift to start from 0
            relative_coords[:, :, 1] += self.ws - 1
            relative_coords[:, :, 0] *= 2 * self.ws - 1
            relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            self.register_buffer("relative_position_index", relative_position_index)

            trunc_normal_(self.relative_position_bias_table, std=.02)

    def pad(self, x, ps):
        _, _, H, W = x.size()
        if W % ps != 0:
            x = F.pad(x, (0, ps - W % ps, 0, 0), mode='reflect')
        if H % ps != 0:
            x = F.pad(x, (0, 0, 0, ps - H % ps), mode='reflect')
        return x

    def pad_out(self, x):
        x = F.pad(x, pad=(0, 1, 0, 1), mode='reflect')
        return x

    def forward(self, x, y):
        # x comes from the ResNet branch and supplies the global (windowed attention) term;
        # y comes from the VSSM branch and supplies the local convolutional term.
        B, C, H, W = x.shape

        local = self.local2(y) + self.local1(y)

        x = self.pad(x, self.ws)
        B, C, Hp, Wp = x.shape
        qkv = self.qkv(x)

        q, k, v = rearrange(qkv, 'b (qkv h d) (hh ws1) (ww ws2) -> qkv (b hh ww) h (ws1 ws2) d', h=self.num_heads,
                            d=C // self.num_heads, hh=Hp // self.ws, ww=Wp // self.ws, qkv=3, ws1=self.ws, ws2=self.ws)

        dots = (q @ k.transpose(-2, -1)) * self.scale

        if self.relative_pos_embedding:
            relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                self.ws * self.ws, self.ws * self.ws, -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            dots += relative_position_bias.unsqueeze(0)

        attn = dots.softmax(dim=-1)
        attn = attn @ v

        attn = rearrange(attn, '(b hh ww) h (ws1 ws2) d -> b (h d) (hh ws1) (ww ws2)', h=self.num_heads,
                         d=C // self.num_heads, hh=Hp // self.ws, ww=Wp // self.ws, ws1=self.ws, ws2=self.ws)

        attn = attn[:, :, :H, :W]

        out = self.attn_x(F.pad(attn, pad=(0, 0, 0, 1), mode='reflect')) + \
              self.attn_y(F.pad(attn, pad=(0, 1, 0, 0), mode='reflect'))

        out = out + local
        out = self.pad_out(out)
        out = self.proj(out)
        # print(out.size())
        out = out[:, :, :H, :W]

        return out
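# --------------------------------------------------------------------------- #
# Illustrative sketch (not part of the original model and never called by it):
# SoftPool2d computes avg(exp(x) * x) / avg(exp(x)), an exponentially weighted
# average that emphasizes large activations more than plain average pooling.
# Helper name is ours.
# --------------------------------------------------------------------------- #
def _demo_softpool():
    x = torch.tensor([[[[0.0, 1.0], [2.0, 3.0]]]])
    soft = SoftPool2d(kernel_size=2, stride=2)(x)   # ~2.49, pulled toward the max
    avg = F.avg_pool2d(x, 2)                        # 1.5
    return soft, avg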
class FusionBlock(nn.Module):
    def __init__(self, dim=256, ssmdims=256, num_heads=16, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.ReLU6, norm_layer=nn.BatchNorm2d, window_size=8):
        super().__init__()
        self.normx = norm_layer(dim)
        self.normy = norm_layer(ssmdims)
        self.attn = FusionAttention(dim, ssmdims, num_heads=num_heads, qkv_bias=qkv_bias, window_size=window_size)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, out_features=dim, act_layer=act_layer,
                       drop=drop)
        self.norm2 = norm_layer(dim)

    def forward(self, x, y):
        x = x + self.drop_path(self.attn(self.normx(x), self.normy(y)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
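# --------------------------------------------------------------------------- #
# Illustrative sketch (not part of the original model and never called by it):
# FusionBlock injects an auxiliary VSSM feature map (y) into the ResNet branch
# (x); y drives the local convolutional term while x drives the windowed
# attention, and the output keeps the ResNet branch's channel count and spatial
# size. Helper name is ours.
# --------------------------------------------------------------------------- #
def _demo_fusion_block():
    block = FusionBlock(dim=64, ssmdims=96, num_heads=8, window_size=8).eval()
    x = torch.randn(1, 64, 32, 32)   # ResNet stage feature
    y = torch.randn(1, 96, 32, 32)   # VSSM stage feature
    return block(x, y).shape          # (1, 64, 32, 32)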
@MODELS.register_module()
class RS3Mamba(nn.Module):
    def __init__(self,
                 decode_channels=64,
                 dropout=0.1,
                 backbone_name='swsl_resnet18',
                 pretrained=False,
                 window_size=8,
                 num_classes=7
                 ):
        super().__init__()
        self.backbone = timm.create_model(backbone_name, features_only=True, output_stride=32,
                                          out_indices=(1, 2, 3, 4), pretrained=pretrained)
        self.conv1 = self.backbone.conv1
        self.bn1 = self.backbone.bn1
        self.act1 = self.backbone.act1
        self.maxpool = self.backbone.maxpool
        self.layers = nn.ModuleList()
        self.layers.append(self.backbone.layer1)
        self.layers.append(self.backbone.layer2)
        self.layers.append(self.backbone.layer3)
        self.layers.append(self.backbone.layer4)

        self.stem = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=7, stride=2, padding=3),
            nn.InstanceNorm2d(48, eps=1e-5, affine=True),
        )
        self.vssm_encoder = VSSMEncoder(patch_size=2, in_chans=48)

        encoder_channels = self.backbone.feature_info.channels()
        ssm_dims = [96, 192, 384, 768]

        self.Fuse = nn.ModuleList()
        self.decoder = Decoder(encoder_channels, decode_channels, dropout, window_size, num_classes)
        for i in range(4):
            fuse = FusionBlock(encoder_channels[i], ssm_dims[i])
            self.Fuse.append(fuse)

    def forward(self, x):
        h, w = x.size()[-2:]
        ssmx = self.stem(x)
        # VSSM feature pyramid; e.g. for a 512x512 input:
        # 48x256x256, 96x128x128, 192x64x64, 384x32x32, 768x16x16
        vss_outs = self.vssm_encoder(ssmx)

        ress = []
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.maxpool(x)
        for i in range(len(self.layers)):
            x = self.layers[i](x)
            x = self.Fuse[i](x, vss_outs[i + 1])
            res = x
            ress.append(res)

        x = self.decoder(ress[0], ress[1], ress[2], ress[3], h, w)

        return x


if __name__ == "__main__":
    model = RS3Mamba().to('cuda:3')
    img = torch.randn(1, 3, 512, 512).to('cuda:3')
    out = model(img)
    print(out.size())