import time
import math
import copy
from functools import partial
from typing import Optional, Callable

import timm
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange, repeat
from timm.models.layers import DropPath, trunc_normal_
# from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count, parameter_count
from mmseg.registry import MODELS

DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"

# import mamba_ssm.selective_scan_fn (in which causal_conv1d is needed)
try:
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
except:
    pass

# an alternative for mamba_ssm
try:
    from selective_scan import selective_scan_fn as selective_scan_fn_v1
    from selective_scan import selective_scan_ref as selective_scan_ref_v1
except:
    pass


def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False):
    """
    u: r(B D L)
    delta: r(B D L)
    A: r(D N)
    B: r(B N L)
    C: r(B N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32

    ignores:
        [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
    """
    import numpy as np

    # fvcore.nn.jit_handles
    def get_flops_einsum(input_shapes, equation):
        np_arrs = [np.zeros(s) for s in input_shapes]
        optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1]
        for line in optim.split("\n"):
            if "optimized flop" in line.lower():
                # divided by 2 because we count MAC (multiply-add counted as one flop)
                flop = float(np.floor(float(line.split(":")[-1]) / 2))
                return flop

    assert not with_complex

    flops = 0  # the reference code quoted inside the `if False:` blocks below is what these counts correspond to
    if False:
        ...
        """
        dtype_in = u.dtype
        u = u.float()
        delta = delta.float()
        if delta_bias is not None:
            delta = delta + delta_bias[..., None].float()
        if delta_softplus:
            delta = F.softplus(delta)
        batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
        is_variable_B = B.dim() >= 3
        is_variable_C = C.dim() >= 3
        if A.is_complex():
            if is_variable_B:
                B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
            if is_variable_C:
                C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
        else:
            B = B.float()
            C = C.float()
        x = A.new_zeros((batch, dim, dstate))
        ys = []
        """

    flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln")
    if with_Group:
        flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln")
    else:
        flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln")
    if False:
        ...
        """
        deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
        if not is_variable_B:
            deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
        else:
            if B.dim() == 3:
                deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
            else:
                B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
                deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
        if is_variable_C and C.dim() == 4:
            C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
        last_state = None
        """

    in_for_flops = B * D * N
    if with_Group:
        in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd")
    else:
        in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd")
    flops += L * in_for_flops
    if False:
        ...
        """
        for i in range(u.shape[2]):
            x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
            if not is_variable_C:
                y = torch.einsum('bdn,dn->bd', x, C)
            else:
                if C.dim() == 3:
                    y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
                else:
                    y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
            if i == u.shape[2] - 1:
                last_state = x
            if y.is_complex():
                y = y.real * 2
            ys.append(y)
        y = torch.stack(ys, dim=2)  # (batch dim L)
        """

    if with_D:
        flops += B * D * L
    if with_Z:
        flops += B * D * L
    if False:
        ...
        """
        out = y if D is None else y + u * rearrange(D, "d -> d 1")
        if z is not None:
            out = out * F.silu(z)
        out = out.to(dtype=dtype_in)
        """

    return flops
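

# Illustrative usage sketch (not part of the original model): estimate the cost of a single
# selective-scan call for a hypothetical feature map, e.g. D=96 channels on a 64x64 grid.
def _demo_selective_scan_flops():
    flops = flops_selective_scan_ref(B=1, L=64 * 64, D=96, N=16, with_D=True, with_Z=False, with_Group=True)
    print(f"reference selective-scan cost: {flops / 1e9:.3f} GFLOPs")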


def selective_scan_flop_jit(inputs, outputs):
    # xs, dts, As, Bs, Cs, Ds (skip), z (skip), dt_projs_bias (skip)
    assert inputs[0].debugName().startswith("xs")  # (B, D, L)
    assert inputs[2].debugName().startswith("As")  # (D, N)
    assert inputs[3].debugName().startswith("Bs")  # (B, N, L) or (B, G, N, L) when grouped
    with_Group = len(inputs[3].type().sizes()) == 4
    with_D = inputs[5].debugName().startswith("Ds")
    if not with_D:
        with_z = inputs[5].debugName().startswith("z")
    else:
        with_z = inputs[6].debugName().startswith("z")
    B, D, L = inputs[0].type().sizes()
    N = inputs[2].type().sizes()[1]
    flops = flops_selective_scan_ref(B=B, L=L, D=D, N=N, with_D=with_D, with_Z=with_z, with_Group=with_Group)
    return flops


class PatchEmbed2D(nn.Module):
    r""" Image to Patch Embedding
    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, **kwargs):
        super().__init__()
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        x = self.proj(x).permute(0, 2, 3, 1)
        if self.norm is not None:
            x = self.norm(x)
        return x
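

# Illustrative shape check (a sketch, not part of the original model): PatchEmbed2D maps an image
# to channels-last patch tokens, e.g. (1, 3, 224, 224) -> (1, 56, 56, 96) with patch_size=4.
def _demo_patch_embed_shapes():
    embed = PatchEmbed2D(patch_size=4, in_chans=3, embed_dim=96)
    return embed(torch.randn(1, 3, 224, 224)).shape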


class PatchMerging2D(nn.Module):
    r""" Patch Merging Layer.
    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        B, H, W, C = x.shape

        SHAPE_FIX = [-1, -1]
        if (W % 2 != 0) or (H % 2 != 0):
            print(f"Warning: x.shape {x.shape} has an odd H or W; cropping to an even size ===========", flush=True)
            SHAPE_FIX[0] = H // 2
            SHAPE_FIX[1] = W // 2

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C

        if SHAPE_FIX[0] > 0:
            x0 = x0[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
            x1 = x1[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
            x2 = x2[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
            x3 = x3[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]

        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, H // 2, W // 2, 4 * C)  # B H/2 W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x
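

# Illustrative shape check (a sketch, not part of the original model): PatchMerging2D halves the
# spatial size and doubles the channels, e.g. (1, 64, 64, 96) -> (1, 32, 32, 192).
def _demo_patch_merging_shapes():
    merge = PatchMerging2D(dim=96)
    y = merge(torch.randn(1, 64, 64, 96))
    assert y.shape == (1, 32, 32, 192)
    return y.shape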


# class PatchExpand(nn.Module):
#     def __init__(self, dim, dim_scale=2, norm_layer=nn.LayerNorm):
#         super().__init__()
#         # self.dim = dim
#         self.dim_scale = dim_scale
#         self.conv = nn.Conv2d(dim, dim // dim_scale, 1, 1, 0, bias=False)

#         # # Assuming dim_scale is 2, which means increasing spatial dimensions by a factor of 2
#         # if dim_scale == 2:
#         #     self.pixel_shuffle = nn.PixelShuffle(upscale_factor=2)
#         #     # Adjusting the number of channels after pixel shuffle
#         #     self.adjust_channels = nn.Conv2d(dim // 4, dim // dim_scale, kernel_size=1, stride=1, padding=0, bias=False)
#         # else:
#         #     # If no scaling is needed, use an identity mapping
#         #     self.pixel_shuffle = nn.Identity()
#         #     self.adjust_channels = nn.Identity()

#         # self.norm = norm_layer(dim // dim_scale)

#     def forward(self, x):
#         # Pixel shuffle expects the input in the format (B, C, H, W)
#         x = rearrange(x, 'b h w c -> b c h w')
#         if self.dim_scale == 2:
#             x = F.interpolate(x, scale_factor=self.dim_scale, mode='bilinear', align_corners=True)
#             x = self.conv(x)
#         # else:
#         #     x = F.interpolate(x, scale_factor=self.dim_scale, mode='bilinear', align_corners=True)

#         x = rearrange(x, 'b c h w -> b h w c')
#         return x


class PatchExpand(nn.Module):
    def __init__(self, dim, dim_scale=2, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.dim_scale = dim_scale
        # Assuming dim_scale is 2, which means increasing spatial dimensions by a factor of 2
        if dim_scale == 2:
            self.pixel_shuffle = nn.PixelShuffle(upscale_factor=2)
            # Adjusting the number of channels after pixel shuffle
            self.adjust_channels = nn.Conv2d(dim // 4, dim // dim_scale, kernel_size=1, stride=1, padding=0, bias=False)
        else:
            # If no scaling is needed, use an identity mapping
            self.pixel_shuffle = nn.Identity()
            self.adjust_channels = nn.Identity()
        self.norm = norm_layer(dim // dim_scale)

    def forward(self, x):
        # Pixel shuffle expects the input in the format (B, C, H, W)
        x = rearrange(x, 'b h w c -> b c h w')
        if self.dim_scale == 2:
            x = self.pixel_shuffle(x)
            x = self.adjust_channels(x)
        # Convert back to the original format for normalization
        x = rearrange(x, 'b c h w -> b h w c')
        x = self.norm(x)
        return x
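

# Illustrative shape check (a sketch, not part of the original model): with dim_scale=2 the
# PixelShuffle-based PatchExpand doubles H and W and halves C, e.g. (1, 32, 32, 192) -> (1, 64, 64, 96).
def _demo_patch_expand_shapes():
    expand = PatchExpand(dim=192, dim_scale=2)
    y = expand(torch.randn(1, 32, 32, 192))
    assert y.shape == (1, 64, 64, 96)
    return y.shape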


# class PatchExpand(nn.Module):
#     def __init__(self, dim, dim_scale=2, norm_layer=nn.LayerNorm):
#         super().__init__()
#         self.dim = dim

#         self.expand = nn.Linear(
#             dim, 2*dim, bias=False) if dim_scale == 2 else nn.Identity()

#         self.norm = norm_layer(dim // dim_scale)

#     def forward(self, x):

#         x = self.expand(x)

#         B, H, W, C = x.shape

#         x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4)
#         x = self.norm(x)
#         return x


# class FinalPatchExpand_X4(nn.Module):
#     def __init__(self, dim, dim_scale=4, norm_layer=nn.LayerNorm):
#         super().__init__()
#         # self.dim = dim
#         # self.dim_scale = dim_scale
#         # # Assuming dim_scale is 2, which means increasing spatial dimensions by a factor of 2
#         # if dim_scale == 4:
#         #     self.pixel_shuffle = nn.PixelShuffle(upscale_factor=4)
#         #     # Adjusting the number of channels after pixel shuffle
#         #     self.adjust_channels = nn.Conv2d(dim // 16, dim, kernel_size=1, stride=1, padding=0, bias=False)
#         # else:
#         #     # If no scaling is needed, use an identity mapping
#         #     self.pixel_shuffle = nn.Identity()
#         #     self.adjust_channels = nn.Identity()
#         # self.norm = norm_layer(dim)

#     def forward(self, x):
#         x = rearrange(x, 'b h w c -> b c h w')
#         x = F.interpolate(x, scale_factor=4, mode='bilinear', align_corners=True)
#         x = rearrange(x, 'b c h w -> b h w c')
#         return x


class FinalPatchExpand_X4(nn.Module):
    def __init__(self, dim, dim_scale=4, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.dim_scale = dim_scale
        self.expand = nn.Linear(dim, 16 * dim, bias=False)
        self.output_dim = dim
        self.norm = norm_layer(self.output_dim)

    def forward(self, x):
        x = self.expand(x)
        B, H, W, C = x.shape
        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale,
                      c=C // (self.dim_scale ** 2))
        x = self.norm(x)
        return x
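

# Illustrative shape check (a sketch, not part of the original model): FinalPatchExpand_X4 upsamples
# 4x in each spatial dimension while keeping the channel count, e.g. (1, 16, 16, 96) -> (1, 64, 64, 96).
def _demo_final_patch_expand_shapes():
    up4 = FinalPatchExpand_X4(dim=96)
    y = up4(torch.randn(1, 16, 16, 96))
    assert y.shape == (1, 64, 64, 96)
    return y.shape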


class SS2D(nn.Module):
    def __init__(
            self,
            d_model,
            d_state=16,
            # d_state="auto", # 20240109
            d_conv=3,
            expand=0.5,
            dt_rank="auto",
            dt_min=0.001,
            dt_max=0.1,
            dt_init="random",
            dt_scale=1.0,
            dt_init_floor=1e-4,
            dropout=0.,
            conv_bias=True,
            bias=False,
            device=None,
            dtype=None,
            **kwargs,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        # self.d_state = math.ceil(self.d_model / 6) if d_state == "auto" else d_model # 20240109
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank

        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
        self.conv2d = nn.Conv2d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            groups=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            padding=(d_conv - 1) // 2,
            **factory_kwargs,
        )

        self.channel_attention = ChannelAttentionModule(self.d_inner)
        self.spatial_attention = SpatialAttentionModule()

        self.act = nn.SiLU()

        self.x_proj = (
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
        )
        self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0))  # (K=4, N, inner)
        del self.x_proj

        self.dt_projs = (
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
                         **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
                         **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
                         **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
                         **factory_kwargs),
        )
        self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0))  # (K=4, inner, rank)
        self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0))  # (K=4, inner)
        del self.dt_projs

        self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True)  # (K=4 * D, N)
        self.Ds = self.D_init(self.d_inner, copies=4, merge=True)  # (K=4 * D)

        self.forward_core = self.forward_core_windows
        # self.forward_core = self.forward_corev0_seq
        # self.forward_core = self.forward_corev1

        self.out_norm = nn.LayerNorm(self.d_inner)
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
        self.dropout = nn.Dropout(dropout) if dropout > 0. else None

    @staticmethod
    def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4,
                **factory_kwargs):
        dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = dt_rank ** -0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        dt_proj.bias._no_reinit = True

        return dt_proj

    @staticmethod
    def A_log_init(d_state, d_inner, copies=1, device=None, merge=True):
        # S4D real initialization
        A = repeat(
            torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        if copies > 1:
            A_log = repeat(A_log, "d n -> r d n", r=copies)
            if merge:
                A_log = A_log.flatten(0, 1)
        A_log = nn.Parameter(A_log)
        A_log._no_weight_decay = True
        return A_log

    @staticmethod
    def D_init(d_inner, copies=1, device=None, merge=True):
        # D "skip" parameter
        D = torch.ones(d_inner, device=device)
        if copies > 1:
            D = repeat(D, "n1 -> r n1", r=copies)
            if merge:
                D = D.flatten(0, 1)
        D = nn.Parameter(D)  # Keep in fp32
        D._no_weight_decay = True
        return D

    def forward_core_windows(self, x: torch.Tensor, layer=1):
        # NOTE: the early return below means only the plain four-direction scan is used;
        # the multi-scale / windowed variants that follow are kept for reference but never run.
        return self.forward_corev0(x)
        if layer == 1:
            return self.forward_corev0(x)
        downsampled_4 = F.avg_pool2d(x, kernel_size=2, stride=2)
        processed_4 = self.forward_corev0(downsampled_4)
        processed_4 = processed_4.permute(0, 3, 1, 2)
        restored_4 = F.interpolate(processed_4, scale_factor=2, mode='nearest')
        restored_4 = restored_4.permute(0, 2, 3, 1)
        if layer == 2:
            output = (self.forward_corev0(x) + restored_4) / 2.0

        downsampled_8 = F.avg_pool2d(x, kernel_size=4, stride=4)
        processed_8 = self.forward_corev0(downsampled_8)
        processed_8 = processed_8.permute(0, 3, 1, 2)
        restored_8 = F.interpolate(processed_8, scale_factor=4, mode='nearest')
        restored_8 = restored_8.permute(0, 2, 3, 1)

        output = (self.forward_corev0(x) + restored_4 + restored_8) / 3.0
        return output
        # B C H W

        num_splits = 2 ** layer
        split_size = x.shape[2] // num_splits  # Assuming H == W and is divisible by 2**layer

        # Use unfold to create windows
        x_unfolded = x.unfold(2, split_size, split_size).unfold(3, split_size, split_size)
        x_unfolded = x_unfolded.contiguous().view(-1, x.size(1), split_size, split_size)

        # Process all splits at once
        processed_splits = self.forward_corev0(x_unfolded)
        processed_splits = processed_splits.permute(0, 3, 1, 2)
        # Reshape to get the splits back into their original positions and then permute to align dimensions
        processed_splits = processed_splits.view(x.size(0), num_splits, num_splits, x.size(1), split_size, split_size)
        processed_splits = processed_splits.permute(0, 3, 1, 4, 2, 5).contiguous()
        processed_splits = processed_splits.view(x.size(0), x.size(1), x.size(2), x.size(3))
        processed_splits = processed_splits.permute(0, 2, 3, 1)

        return processed_splits

        # num_splits = 2 ** layer
        # split_size = x.shape[2] // num_splits  # Assuming H == W and is divisible by 2**layer
        # outputs = []
        # for i in range(num_splits):
        #     row_outputs = []
        #     for j in range(num_splits):
        #         sub_x = x[:, :, i*split_size:(i+1)*split_size, j*split_size:(j+1)*split_size].contiguous()
        #         processed = self.forward_corev0(sub_x)
        #         row_outputs.append(processed)
        #     # Concatenate all column splits for current row
        #     outputs.append(torch.cat(row_outputs, dim=2))
        # # Concatenate all rows
        # final_output = torch.cat(outputs, dim=1)

        # return final_output  # unreachable: final_output only exists in the commented loop above

    def forward_corev0(self, x: torch.Tensor):
        self.selective_scan = selective_scan_fn
        B, C, H, W = x.shape
        L = H * W

        K = 4
        # Build the four scan orders: row-major, column-major, and their reversed counterparts.
        x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)],
                             dim=1).view(B, 2, -1, L)
        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)  # (b, k, d, l)

        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
        # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
        dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)

        xs = xs.float().view(B, -1, L)  # (b, k * d, l)
        dts = dts.contiguous().float().view(B, -1, L)  # (b, k * d, l)
        Bs = Bs.float().view(B, K, -1, L)  # (b, k, d_state, l)
        Cs = Cs.float().view(B, K, -1, L)  # (b, k, d_state, l)

        Ds = self.Ds.float().view(-1)  # (k * d)
        As = -torch.exp(self.A_logs.float()).view(-1, self.d_state)  # (k * d, d_state)
        dt_projs_bias = self.dt_projs_bias.float().view(-1)  # (k * d)

        out_y = self.selective_scan(
            xs, dts,
            As, Bs, Cs, Ds, z=None,
            delta_bias=dt_projs_bias,
            delta_softplus=True,
            return_last_state=False,
        ).view(B, K, -1, L)
        assert out_y.dtype == torch.float

        # Undo the flips / transposes so all four scan results are aligned, then merge them.
        inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
        wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)

        y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y
        y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1)
        y = self.out_norm(y).to(x.dtype)

        return y

    def forward_corev0_seq(self, x: torch.Tensor):
        self.selective_scan = selective_scan_fn

        B, C, H, W = x.shape
        L = H * W
        K = 4

        x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)],
                             dim=1).view(B, 2, -1, L)
        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)  # (b, k, d, l)

        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
        # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
        dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)

        # keep the per-direction axis K here, since the loop below scans each direction separately
        xs = xs.float().view(B, K, -1, L)  # (b, k, d, l)
        dts = dts.contiguous().float().view(B, K, -1, L)  # (b, k, d, l)
        Bs = Bs.float().view(B, K, -1, L)  # (b, k, d_state, l)
        Cs = Cs.float().view(B, K, -1, L)  # (b, k, d_state, l)

        Ds = self.Ds.float().view(K, -1)  # (k, d)
        As = -torch.exp(self.A_logs.float()).view(K, -1, self.d_state)  # (k, d, d_state)
        dt_projs_bias = self.dt_projs_bias.float().view(K, -1)  # (k, d)

        out_y = []
        for i in range(4):
            yi = self.selective_scan(
                xs[:, i], dts[:, i],
                As[i], Bs[:, i], Cs[:, i], Ds[i],
                delta_bias=dt_projs_bias[i],
                delta_softplus=True,
            ).view(B, -1, L)
            out_y.append(yi)
        out_y = torch.stack(out_y, dim=1)
        assert out_y.dtype == torch.float

        inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
        wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y
        y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1)
        y = self.out_norm(y).to(x.dtype)

        return y

    def forward_corev1(self, x: torch.Tensor):
        self.selective_scan = selective_scan_fn_v1

        B, C, H, W = x.shape
        L = H * W
        K = 4

        x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)],
                             dim=1).view(B, 2, -1, L)
        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)  # (b, k, d, l)

        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
        # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
        dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
        # dts = dts + self.dt_projs_bias.view(1, K, -1, 1)

        xs = xs.view(B, -1, L)  # (b, k * d, l)
        dts = dts.contiguous().view(B, -1, L)  # (b, k * d, l)
        Bs = Bs.view(B, K, -1, L)  # (b, k, d_state, l)
        Cs = Cs.view(B, K, -1, L)  # (b, k, d_state, l)

        As = -torch.exp(self.A_logs.float()).view(-1, self.d_state)  # (k * d, d_state)
        Ds = self.Ds.view(-1)  # (k * d)
        dt_projs_bias = self.dt_projs_bias.view(-1)  # (k * d)

        # print(self.Ds.dtype, self.A_logs.dtype, self.dt_projs_bias.dtype, flush=True) # fp16, fp16, fp16

        out_y = self.selective_scan(
            xs, dts,
            As, Bs, Cs, Ds,
            delta_bias=dt_projs_bias,
            delta_softplus=True,
        ).view(B, K, -1, L)
        assert out_y.dtype == torch.float16

        inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
        wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        y = out_y[:, 0].float() + inv_y[:, 0].float() + wh_y.float() + invwh_y.float()
        y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1)
        y = self.out_norm(y).to(x.dtype)

        return y

    def forward(self, x: torch.Tensor, layer=1, **kwargs):
        B, H, W, C = x.shape

        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1)  # (b, h, w, d)

        z = z.permute(0, 3, 1, 2)
        z = self.channel_attention(z) * z
        z = self.spatial_attention(z) * z
        z = z.permute(0, 2, 3, 1).contiguous()

        x = x.permute(0, 3, 1, 2).contiguous()
        x = self.act(self.conv2d(x))  # (b, d, h, w)

        y = self.forward_core(x, layer)

        y = y * F.silu(z)

        out = self.out_proj(y)
        if self.dropout is not None:
            out = self.dropout(out)
        return out
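

# Illustrative usage sketch (not part of the original model): SS2D consumes channels-last (B, H, W, C)
# features and returns the same shape. Running it requires a CUDA device and the selective-scan
# kernel imported at the top of this file.
def _demo_ss2d():
    block = SS2D(d_model=96, d_state=16, expand=0.5).cuda()
    y = block(torch.randn(2, 32, 32, 96).cuda())
    return y.shape  # torch.Size([2, 32, 32, 96])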


class VSSBlock(nn.Module):
    def __init__(
            self,
            hidden_dim: int = 0,
            drop_path: float = 0,
            norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
            attn_drop_rate: float = 0,
            d_state: int = 16,
            layer: int = 1,
            **kwargs,
    ):
        super().__init__()
        factor = 2.0
        d_model = int(hidden_dim // factor)
        self.down = nn.Linear(hidden_dim, d_model)
        self.up = nn.Linear(d_model, hidden_dim)
        self.ln_1 = norm_layer(d_model)
        self.self_attention = SS2D(d_model=d_model, dropout=attn_drop_rate, d_state=d_state, **kwargs)
        self.drop_path = DropPath(drop_path)
        self.layer = layer

    def forward(self, input: torch.Tensor):
        input_x = self.down(input)
        input_x = input_x + self.drop_path(self.self_attention(self.ln_1(input_x), self.layer))
        x = self.up(input_x) + input
        return x
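

# Illustrative usage sketch (not part of the original model): VSSBlock wraps SS2D in a
# down-project / SS2D / up-project residual, so the hidden dimension is preserved.
# Like SS2D itself, it needs a CUDA device and the selective-scan kernel.
def _demo_vss_block():
    blk = VSSBlock(hidden_dim=128, drop_path=0.1, d_state=16).cuda()
    y = blk(torch.randn(2, 32, 32, 128).cuda())
    return y.shape  # torch.Size([2, 32, 32, 128])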


class VSSLayer(nn.Module):
    """ A basic VSS (visual state space) layer for one stage.
    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(
            self,
            dim,
            depth,
            attn_drop=0.,
            drop_path=0.,
            norm_layer=nn.LayerNorm,
            downsample=None,
            use_checkpoint=False,
            d_state=16,
            **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.use_checkpoint = use_checkpoint

        self.blocks = nn.ModuleList([
            VSSBlock(
                hidden_dim=dim,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
                attn_drop_rate=attn_drop,
                d_state=d_state,
                # expand=0.25,
            )
            for i in range(depth)])

        if True:  # is this really applied? Yes, but it is overridden later in VSSM!
            def _init_weights(module: nn.Module):
                for name, p in module.named_parameters():
                    if name in ["out_proj.weight"]:
                        p = p.clone().detach_()  # fake init, just to keep the seed ....
                        nn.init.kaiming_uniform_(p, a=math.sqrt(5))

            self.apply(_init_weights)

        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)

        if self.downsample is not None:
            x = self.downsample(x)

        return x


class VSSLayer_up(nn.Module):
    """ A basic VSS (visual state space) decoder layer for one stage.
    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        upsample (nn.Module | None, optional): Upsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(
            self,
            dim,
            depth,
            attn_drop=0.,
            drop_path=0.,
            norm_layer=nn.LayerNorm,
            upsample=None,
            use_checkpoint=False,
            d_state=16,
            layer=1,
            **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.use_checkpoint = use_checkpoint
        print('layer: ', layer, ' dstate: ', d_state)

        self.blocks = nn.ModuleList([
            VSSBlock(
                hidden_dim=dim,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
                attn_drop_rate=attn_drop,
                d_state=d_state,
                # expand=1.0,
                layer=layer,
            )
            for i in range(depth)])

        if True:  # is this really applied? Yes, but it is overridden later in VSSM!
            def _init_weights(module: nn.Module):
                for name, p in module.named_parameters():
                    if name in ["out_proj.weight"]:
                        p = p.clone().detach_()  # fake init, just to keep the seed ....
                        nn.init.kaiming_uniform_(p, a=math.sqrt(5))

            self.apply(_init_weights)

        if upsample is not None:
            self.upsample = PatchExpand(dim, dim_scale=2, norm_layer=nn.LayerNorm)
        else:
            self.upsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)

        if self.upsample is not None:
            x = self.upsample(x)

        return x


class ConvBNReLU(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, norm_layer=nn.BatchNorm2d,
                 bias=False):
        super(ConvBNReLU, self).__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
                      dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2),
            norm_layer(out_channels),
            nn.ReLU6()
        )


class Conv(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, bias=False):
        super(Conv, self).__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
                      dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2)
        )


class AuxHead(nn.Module):

    def __init__(self, in_channels=64, num_classes=8):
        super().__init__()
        self.conv = ConvBNReLU(in_channels, in_channels)
        self.drop = nn.Dropout(0.1)
        self.conv_out = Conv(in_channels, num_classes, kernel_size=1)

    def forward(self, x, h, w):
        feat = self.conv(x)
        feat = self.drop(feat)
        feat = self.conv_out(feat)
        feat = F.interpolate(feat, size=(h, w), mode='bilinear', align_corners=False)
        return feat


class BasicConv(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False):
        super(BasicConv, self).__init__()
        if bias and norm:
            bias = False

        padding = kernel_size // 2
        layers = list()
        if transpose:
            padding = kernel_size // 2 - 1
            layers.append(
                nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
        else:
            layers.append(
                nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
        if norm:
            layers.append(nn.BatchNorm2d(out_channel))
        if relu:
            layers.append(nn.ReLU(inplace=True))
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x)


class AFF(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(AFF, self).__init__()
        self.conv = nn.Sequential(
            BasicConv(in_channel, out_channel, kernel_size=1, stride=1, relu=True),
            BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
        )

    def forward(self, x1, x2, x4):
        x = torch.cat([x1, x2, x4], dim=1)

        return self.conv(x)


class ChannelAttentionModule(nn.Module):
    def __init__(self, in_channels, reduction=4):
        super(ChannelAttentionModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)


class SpatialAttentionModule(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttentionModule, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)
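

# Illustrative sketch (not part of the original model): the two modules above are applied CBAM-style
# as multiplicative gates, which mirrors how SS2D.forward re-weights its gating branch z.
def _demo_cbam_gating():
    feat = torch.randn(2, 48, 32, 32)
    ca, sa = ChannelAttentionModule(48), SpatialAttentionModule()
    feat = ca(feat) * feat  # channel weights of shape (2, 48, 1, 1), broadcast over H and W
    feat = sa(feat) * feat  # spatial weights of shape (2, 1, 32, 32), broadcast over channels
    return feat.shape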


class FusionConv(nn.Module):
    def __init__(self, in_channels, out_channels, factor=4.0):
        super(FusionConv, self).__init__()
        dim = int(out_channels // factor)
        self.down = nn.Conv2d(in_channels, dim, kernel_size=1, stride=1)
        self.conv_3x3 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
        self.conv_5x5 = nn.Conv2d(dim, dim, kernel_size=5, stride=1, padding=2)
        self.conv_7x7 = nn.Conv2d(dim, dim, kernel_size=7, stride=1, padding=3)
        self.spatial_attention = SpatialAttentionModule()
        self.channel_attention = ChannelAttentionModule(dim)
        self.up = nn.Conv2d(dim, out_channels, kernel_size=1, stride=1)
        self.down_2 = nn.Conv2d(in_channels, dim, kernel_size=1, stride=1)

    def forward(self, x1, x2, x4):
        x_fused = torch.cat([x1, x2, x4], dim=1)
        x_fused = self.down(x_fused)
        x_fused_c = x_fused * self.channel_attention(x_fused)
        x_3x3 = self.conv_3x3(x_fused)
        x_5x5 = self.conv_5x5(x_fused)
        x_7x7 = self.conv_7x7(x_fused)
        x_fused_s = x_3x3 + x_5x5 + x_7x7
        x_fused_s = x_fused_s * self.spatial_attention(x_fused_s)

        x_out = self.up(x_fused_s + x_fused_c)

        return x_out


class DownFusion(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DownFusion, self).__init__()
        self.fusion_conv = FusionConv(in_channels, out_channels)
        self.CAM = ChannelAttentionModule(out_channels)

    def forward(self, x1, x2):
        x_fused = torch.cat([x1, x2], dim=1)
        x_fused = self.fusion_conv(x_fused)

        x_fused = + x_fused
        return x_fused


class MSAA(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(MSAA, self).__init__()
        self.fusion_conv = FusionConv(in_channels, out_channels)

    def forward(self, x1, x2, x4, last=False):
        # # x2 flows from low to high resolution and x4 from high to low: x2 carries semantic
        # # information while x4 supplements edge/detail features.
        # x_1_2_fusion = self.fusion_1x2(x1, x2)
        # x_1_4_fusion = self.fusion_1x4(x1, x4)
        # x_fused = x_1_2_fusion + x_1_4_fusion
        x_fused = self.fusion_conv(x1, x2, x4)
        return x_fused
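

# Illustrative usage sketch (not part of the original model): MSAA fuses three same-resolution tensors
# whose channel counts sum to in_channels. In CMUnet the transfer convs emit 16/32/64 channels, so
# in_channels = 7 * 16 = 112 for the first stage.
def _demo_msaa():
    msaa = MSAA(in_channels=112, out_channels=64)
    x1, x2, x4 = torch.randn(1, 16, 64, 64), torch.randn(1, 32, 64, 64), torch.randn(1, 64, 64, 64)
    return msaa(x1, x2, x4).shape  # torch.Size([1, 64, 64, 64])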


@MODELS.register_module()
class CMUnet(nn.Module):
    def __init__(self, patch_size=4, in_chans=1, num_classes=4, depths=[2, 2, 9, 2],
                 dims=[96, 192, 384, 768], d_state=16, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, patch_norm=True,
                 use_checkpoint=False, final_upsample="expand_first", **kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.num_layers = len(depths)

        # resnet18
        dims = [64, 128, 256, 512]
        # dims = [64, 64, 64, 64]
        self.backbone = timm.create_model("swsl_resnet18", features_only=True, output_stride=32,
                                          out_indices=(1, 2, 3, 4),
                                          pretrained=False)  # , pretrained_cfg_overlay=dict(file='~/.cache/huggingface/hub/models–timm–resnet18.a1_in1k/pytorch_model.bin'))  # build decoder layers
        base_dims = 64
        # state_dict = torch.load('/media/cm/1T_SSD/UHR_Comparative_Exp3/checkpoint/semi_weakly_supervised_resnet18-118f1556.pth')
        # msg = self.backbone.load_state_dict(state_dict, strict=False)
        # print('[INFO] ', msg)

        if isinstance(dims, int):
            dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)]
        self.embed_dim = dims[0]
        self.num_features = dims[-1]
        self.num_features_up = int(dims[0] * 2)
        self.dims = dims
        self.final_upsample = final_upsample

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        self.layers_up = nn.ModuleList()
        self.concat_back_dim = nn.ModuleList()
        for i_layer in range(self.num_layers):
            concat_linear = nn.Linear(2 * int(dims[0] * 2 ** (self.num_layers - 1 - i_layer)),
                                      int(dims[0] * 2 ** (
                                              self.num_layers - 1 - i_layer))) if i_layer > 0 else nn.Identity()
            if i_layer == 0:
                layer_up = nn.Sequential(
                    VSSLayer(
                        dim=int(dims[0] * 2 ** (self.num_layers - 1 - i_layer)),
                        depth=2,
                        d_state=math.ceil(dims[0] / 6) if d_state is None else d_state,  # 20240109
                        drop=drop_rate,
                        attn_drop=attn_drop_rate,
                        drop_path=dpr[sum(depths[:-1]):sum(depths[:])],
                        norm_layer=norm_layer,
                        downsample=None,
                        use_checkpoint=use_checkpoint),
                    PatchExpand(dim=int(self.embed_dim * 2 ** (self.num_layers - 1 - i_layer)), dim_scale=2,
                                norm_layer=norm_layer)
                )
                # layer_up = PatchExpand(dim=int(self.embed_dim * 2 ** (self.num_layers-1-i_layer)), dim_scale=2, norm_layer=norm_layer)
            else:
                layer_up = VSSLayer_up(
                    dim=int(dims[0] * 2 ** (self.num_layers - 1 - i_layer)),
                    depth=depths[(self.num_layers - 1 - i_layer)],
                    d_state=math.ceil(dims[0] / 6) if d_state is None else d_state,  # 20240109
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[sum(depths[:(self.num_layers - 1 - i_layer)]):sum(
                        depths[:(self.num_layers - 1 - i_layer) + 1])],
                    norm_layer=norm_layer,
                    upsample=PatchExpand if (i_layer < self.num_layers - 1) else None,
                    use_checkpoint=use_checkpoint,
                    layer=i_layer,
                )
            self.layers_up.append(layer_up)
            self.concat_back_dim.append(concat_linear)

        self.norm = norm_layer(self.num_features)
        self.norm_up = norm_layer(self.embed_dim)

        if self.final_upsample == "expand_first":
            print("---final upsample expand_first---")
            self.up = FinalPatchExpand_X4(dim_scale=4, dim=self.embed_dim)
            self.output = nn.Conv2d(in_channels=self.embed_dim, out_channels=self.num_classes, kernel_size=1,
                                    bias=False)

        if self.training:
            self.conv4 = nn.Conv2d(base_dims * 2, num_classes, 1, bias=False)
            self.conv3 = nn.Conv2d(base_dims, num_classes, 1, bias=False)
            self.conv2 = nn.Conv2d(base_dims, num_classes, 1, bias=False)
            # self.conv1 = nn.Conv2d(base_dims, num_classes, 1, bias=False)

        hidden_dim = int(base_dims // 4)
        self.AFFs = nn.ModuleList([
            MSAA(hidden_dim * 7, base_dims),
            MSAA(hidden_dim * 7, base_dims * 2),
            MSAA(hidden_dim * 7, base_dims * 4),
        ])

        self.transfer = nn.ModuleList(
            [
                nn.Conv2d(base_dims, hidden_dim, 1, bias=False),
                nn.Conv2d(base_dims * 2, hidden_dim * 2, 1, bias=False),
                nn.Conv2d(base_dims * 4, hidden_dim * 4, 1, bias=False),
            ]
        )

        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module):
        """
        out_proj.weight, which was previously initialized in VSSBlock, would be re-initialized here as nn.Linear.
        No fc.weight is found among the model parameters.
        No nn.Embedding is found among the model parameters.
        So in practice the VSSBlock initialization is useless.

        Conv2D is not initialized !!!
        """
        # print(m, getattr(getattr(m, "weight", nn.Identity()), "INIT", None), isinstance(m, nn.Linear), "======================")
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    # Encoder and Bottleneck
    def forward_features(self, x):
        x = self.patch_embed(x)

        x_downsample = []
        for layer in self.layers:
            x_downsample.append(x)
            x = layer(x)
        x = self.norm(x)  # B H W C
        return x, x_downsample

    # def forward_backbone(self, x):
    #     x = self.patch_embed(x)

    #     for layer in self.layers:
    #         x = layer(x)
    #     return x

    # Decoder and skip connection

    def forward_up_features(self, x, x_downsample, h, w):

        for inx, layer_up in enumerate(self.layers_up):
            if inx == 0:
                x = layer_up(x)
            else:
                x = torch.cat([x, x_downsample[3 - inx]], -1)
                x = self.concat_back_dim[inx](x)
                x = layer_up(x)

            if self.training and inx == 1:
                tmp = torch.permute(x, (0, 3, 1, 2))
                # h4 = self.up4(tmp)
                # h4 = self.conv4(tmp)
                h4 = tmp
            if self.training and inx == 2:
                tmp = torch.permute(x, (0, 3, 1, 2))
                # h3 = self.up3(tmp)
                # h3 = self.conv3(tmp)
                h3 = tmp
            if self.training and inx == 3:
                tmp = torch.permute(x, (0, 3, 1, 2))
                # h2 = self.up2(tmp)
                # h2 = self.conv2(tmp)
                h2 = tmp
        if self.training:
            # ah = h4 + h3 + h2
            # ah = self.aux_head(ah, h, w)
            ah = [h2, h3, h4]
            x = self.norm_up(x)  # B H W C

            return x, ah
        else:
            x = self.norm_up(x)  # B H W C
            return x

    def up_x4(self, x, h, w):
        B, H, W, C = x.shape
        x = x.permute(0, 3, 1, 2)
        x = self.output(x)
        x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=False)
        return x

    def forward_downfeatures(self, x_downsample):
        x_down_last = x_downsample[-1]
        x_downsample_2 = x_downsample
        x_downsample = []
        for idx, feat in enumerate(x_downsample_2[:-1]):
            feat = torch.permute(feat, (0, 3, 1, 2))
            feat = self.transfer[idx](feat)
            x_downsample.append(feat)

        x_down_3_2 = F.interpolate(x_downsample[1], scale_factor=2.0, mode="bilinear", align_corners=True)
        x_down_4_2 = F.interpolate(x_downsample[2], scale_factor=4.0, mode="bilinear", align_corners=True)

        x_down_4_3 = F.interpolate(x_downsample[2], scale_factor=2.0, mode="bilinear", align_corners=True)
        x_down_2_3 = F.interpolate(x_downsample[0], scale_factor=0.5, mode="bilinear", align_corners=True)

        x_down_2_4 = F.interpolate(x_downsample[0], scale_factor=0.25, mode="bilinear", align_corners=True)
        x_down_3_4 = F.interpolate(x_downsample[1], scale_factor=0.5, mode="bilinear", align_corners=True)

        x_down_2 = self.AFFs[0](x_downsample[0], x_down_3_2, x_down_4_2)
        x_down_3 = self.AFFs[1](x_downsample[1], x_down_2_3, x_down_4_3)
        x_down_4 = self.AFFs[2](x_downsample[2], x_down_3_4, x_down_2_4)

        x_down_2 = torch.permute(x_down_2, (0, 2, 3, 1))
        x_down_3 = torch.permute(x_down_3, (0, 2, 3, 1))
        x_down_4 = torch.permute(x_down_4, (0, 2, 3, 1))

        return [x_down_2, x_down_3, x_down_4, x_down_last]

    def forward_resnet(self, x):
        h, w = x.size()[-2:]
        res1, res2, res3, res4 = self.backbone(x)
        res1 = res1.permute(0, 2, 3, 1)
        res2 = res2.permute(0, 2, 3, 1)
        res3 = res3.permute(0, 2, 3, 1)
        res4 = res4.permute(0, 2, 3, 1)
        x_downsample = [res1, res2, res3, res4]
        x = res4  # the last stage output, used to compute the loss
        return x, x_downsample

    def forward(self, x):
        b, c, h, w = x.size()
        # x, x_downsample = self.forward_features(x)
        x, x_downsample = self.forward_resnet(x)
        x_downsample = self.forward_downfeatures(x_downsample)

        if self.training:
            x, ah = self.forward_up_features(x, x_downsample, h, w)
            x = self.up_x4(x, h, w)
            return x, ah
        else:
            x = self.forward_up_features(x, x_downsample, h, w)
            x = self.up_x4(x, h, w)
            return x
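

# Illustrative usage sketch (an assumption, not a config shipped with this file): because CMUnet is
# registered in the mmseg MODELS registry above, it can also be constructed from a config dict.
def _demo_build_cmunet_from_cfg():
    cfg = dict(type='CMUnet', num_classes=4, drop_path_rate=0.1)
    return MODELS.build(cfg)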


#
# def flops(self, shape=(3, 224, 224)):
#     # shape = self.__input_shape__[1:]
#     supported_ops = {
#         "aten::silu": None,  # as relu is in _IGNORED_OPS
#         "aten::neg": None,  # as relu is in _IGNORED_OPS
#         "aten::exp": None,  # as relu is in _IGNORED_OPS
#         "aten::flip": None,  # as permute is in _IGNORED_OPS
#         "prim::PythonOp.SelectiveScanFn": selective_scan_flop_jit,  # latter
#     }
#
#     model = copy.deepcopy(self)
#     model.cuda().eval()
#
#     input = torch.randn((1, *shape), device=next(model.parameters()).device)
#     params = parameter_count(model)[""]
#     Gflops, unsupported = flop_count(model=model, inputs=(input,), supported_ops=supported_ops)
#
#     del model, input
#     return sum(Gflops.values()) * 1e9
#     return f"params {params} GFLOPs {sum(Gflops.values())}"


# # APIs with VMamba2Dp =================
# def check_vssm_equals_vmambadp():
#     from bak.vmamba_bak1 import VMamba2Dp
#
#     # test 1 True =================================
#     torch.manual_seed(time.time());
#     torch.cuda.manual_seed(time.time())
#     oldvss = VMamba2Dp(depths=[2, 2, 6, 2]).half().cuda()
#     newvss = VSSM(depths=[2, 2, 6, 2]).half().cuda()
#     newvss.load_state_dict(oldvss.state_dict())
#     input = torch.randn((12, 3, 224, 224)).half().cuda()
#     torch.cuda.manual_seed(0)
#     with torch.cuda.amp.autocast():
#         y1 = oldvss.forward_backbone(input)
#     torch.cuda.manual_seed(0)
#     with torch.cuda.amp.autocast():
#         y2 = newvss.forward_backbone(input)
#     print((y1 - y2).abs().sum())  # tensor(0., device='cuda:0', grad_fn=<SumBackward0>)
#
#     # test 2 True ==========================================
#     torch.manual_seed(0);
#     torch.cuda.manual_seed(0)
#     oldvss = VMamba2Dp(depths=[2, 2, 6, 2]).cuda()
#     torch.manual_seed(0);
#     torch.cuda.manual_seed(0)
#     newvss = VSSM(depths=[2, 2, 6, 2]).cuda()
#
#     miss_align = 0
#     for k, v in oldvss.state_dict().items():
#         same = (oldvss.state_dict()[k] == newvss.state_dict()[k]).all()
#         if not same:
#             print(k, same)
#             miss_align += 1
#     print("init miss align", miss_align)  # init miss align 0


if __name__ == "__main__":
    # check_vssm_equals_vmambadp()
    model = CMUnet().to('cuda:4')
    model.eval()  # in eval mode forward returns only the segmentation logits
    x = torch.randn(2, 3, 2048, 2048).to('cuda:4')
    out = model(x)
    # for res in out:
    #     print(res.size())
    print(out.size())

    # device = torch.device('cuda:4')
    # model = CMUnet().to('cuda:4')
    # model.eval()
    # model.to(device)
    # iterations = None
    #
    # input = torch.randn(1, 3, 4928, 4928).to('cuda:4')
    # with torch.no_grad():
    #     for _ in range(10):
    #         model(input)
    #
    #     if iterations is None:
    #         elapsed_time = 0
    #         iterations = 100
    #         while elapsed_time < 1:
    #             torch.cuda.synchronize()
    #             torch.cuda.synchronize()
    #             t_start = time.time()
    #             for _ in range(iterations):
    #                 model(input)
    #             torch.cuda.synchronize()
    #             torch.cuda.synchronize()
    #             elapsed_time = time.time() - t_start
    #             iterations *= 2
    #         FPS = iterations / elapsed_time
    #         iterations = int(FPS * 6)
    #
    #     print('=========Speed Testing=========')
    #     torch.cuda.synchronize()
    #     torch.cuda.synchronize()
    #     t_start = time.time()
    #     for _ in range(iterations):
    #         model(input)
    #     torch.cuda.synchronize()
    #     torch.cuda.synchronize()
    #     elapsed_time = time.time() - t_start
    #     latency = elapsed_time / iterations * 1000
    #     torch.cuda.empty_cache()
    #     FPS = 1000 / latency
    #     print(FPS)