import time
import math
import copy
from functools import partial
from typing import Optional, Callable
import timm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange, repeat
from timm.models.layers import DropPath, trunc_normal_
# from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count, parameter_count
from mmseg.registry import MODELS
DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"
# import mamba_ssm.selective_scan_fn (in which causal_conv1d is needed)
try:
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
except Exception:
pass
# an alternative for mamba_ssm
try:
from selective_scan import selective_scan_fn as selective_scan_fn_v1
from selective_scan import selective_scan_ref as selective_scan_ref_v1
except Exception:
pass
def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False):
"""
u: r(B D L)
delta: r(B D L)
A: r(D N)
B: r(B N L)
C: r(B N L)
D: r(D)
z: r(B D L)
delta_bias: r(D), fp32
ignores:
[.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
"""
import numpy as np
# fvcore.nn.jit_handles
def get_flops_einsum(input_shapes, equation):
np_arrs = [np.zeros(s) for s in input_shapes]
optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1]
for line in optim.split("\n"):
if "optimized flop" in line.lower():
# divided by 2 because we count MAC (multiply-add counted as one flop)
flop = float(np.floor(float(line.split(":")[-1]) / 2))
return flop
assert not with_complex
flops = 0  # the reference implementation below is kept only as documentation (inside "if False" blocks and docstrings)
if False:
...
"""
dtype_in = u.dtype
u = u.float()
delta = delta.float()
if delta_bias is not None:
delta = delta + delta_bias[..., None].float()
if delta_softplus:
delta = F.softplus(delta)
batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
is_variable_B = B.dim() >= 3
is_variable_C = C.dim() >= 3
if A.is_complex():
if is_variable_B:
B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
if is_variable_C:
C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
else:
B = B.float()
C = C.float()
x = A.new_zeros((batch, dim, dstate))
ys = []
"""
flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln")
if with_Group:
flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln")
else:
flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln")
if False:
...
"""
deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
if not is_variable_B:
deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
else:
if B.dim() == 3:
deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
else:
B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
if is_variable_C and C.dim() == 4:
C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
last_state = None
"""
in_for_flops = B * D * N
if with_Group:
in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd")
else:
in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd")
flops += L * in_for_flops
if False:
...
"""
for i in range(u.shape[2]):
x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
if not is_variable_C:
y = torch.einsum('bdn,dn->bd', x, C)
else:
if C.dim() == 3:
y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
else:
y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
if i == u.shape[2] - 1:
last_state = x
if y.is_complex():
y = y.real * 2
ys.append(y)
y = torch.stack(ys, dim=2) # (batch dim L)
"""
if with_D:
flops += B * D * L
if with_Z:
flops += B * D * L
if False:
...
"""
out = y if D is None else y + u * rearrange(D, "d -> d 1")
if z is not None:
out = out * F.silu(z)
out = out.to(dtype=dtype_in)
"""
return flops
def selective_scan_flop_jit(inputs, outputs):
# xs, dts, As, Bs, Cs, Ds (skip), z (skip), dt_projs_bias (skip)
assert inputs[0].debugName().startswith("xs") # (B, D, L)
assert inputs[2].debugName().startswith("As") # (D, N)
assert inputs[3].debugName().startswith("Bs") # (B, N, L) or (B, G, N, L)
with_Group = len(inputs[3].type().sizes()) == 4
with_D = inputs[5].debugName().startswith("Ds")
if not with_D:
with_z = inputs[5].debugName().startswith("z")
else:
with_z = inputs[6].debugName().startswith("z")
B, D, L = inputs[0].type().sizes()
N = inputs[2].type().sizes()[1]
flops = flops_selective_scan_ref(B=B, L=L, D=D, N=N, with_D=with_D, with_Z=with_z, with_Group=with_Group)
return flops
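# Illustrative sketch (not part of the original code): flops_selective_scan_ref can be
# called directly with the shapes produced by SS2D.forward_corev0, where the K=4 scan
# directions are folded into D. The helper below is hypothetical and never called here.
def _demo_selective_scan_flops():
    # e.g. a 56x56 feature map with d_inner=48 gives (B, K*D, L) = (1, 192, 3136)
    flops = flops_selective_scan_ref(B=1, L=56 * 56, D=4 * 48, N=16,
                                     with_D=True, with_Z=False, with_Group=True)
    print(f"estimated selective-scan compute: {flops / 1e9:.2f} GFLOPs")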
class PatchEmbed2D(nn.Module):
r""" Image to Patch Embedding
Args:
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, **kwargs):
super().__init__()
if isinstance(patch_size, int):
patch_size = (patch_size, patch_size)
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
x = self.proj(x).permute(0, 2, 3, 1)
if self.norm is not None:
x = self.norm(x)
return x
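# Shape sketch (illustrative, not part of the original model): PatchEmbed2D maps a
# channels-first image (B, C_in, H, W) to channels-last patch tokens
# (B, H/patch, W/patch, embed_dim). Hypothetical helper, never called in this file.
def _demo_patch_embed():
    pe = PatchEmbed2D(patch_size=4, in_chans=3, embed_dim=96, norm_layer=nn.LayerNorm)
    x = torch.randn(2, 3, 224, 224)
    assert pe(x).shape == (2, 56, 56, 96)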
class PatchMerging2D(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
B, H, W, C = x.shape
SHAPE_FIX = [-1, -1]
if (W % 2 != 0) or (H % 2 != 0):
print(f"Warning, x.shape {x.shape} is not match even ===========", flush=True)
SHAPE_FIX[0] = H // 2
SHAPE_FIX[1] = W // 2
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
if SHAPE_FIX[0] > 0:
x0 = x0[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
x1 = x1[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
x2 = x2[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
x3 = x3[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, H // 2, W // 2, 4 * C)  # B H/2 W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
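# Shape sketch (illustrative): PatchMerging2D halves the spatial resolution and doubles
# the channel count, (B, H, W, C) -> (B, H/2, W/2, 2C). Hypothetical helper, never called.
def _demo_patch_merging():
    pm = PatchMerging2D(dim=96)
    x = torch.randn(2, 56, 56, 96)
    assert pm(x).shape == (2, 28, 28, 192)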
# class PatchExpand(nn.Module):
# def __init__(self, dim, dim_scale=2, norm_layer=nn.LayerNorm):
# super().__init__()
# # self.dim = dim
# self.dim_scale = dim_scale
# self.conv = nn.Conv2d(dim, dim // dim_scale, 1, 1, 0, bias=False)
# # # Assuming dim_scale is 2, which means increasing spatial dimensions by a factor of 2
# # if dim_scale == 2:
# # self.pixel_shuffle = nn.PixelShuffle(upscale_factor=2)
# # # Adjusting the number of channels after pixel shuffle
# # self.adjust_channels = nn.Conv2d(dim // 4, dim // dim_scale, kernel_size=1, stride=1, padding=0, bias=False)
# # else:
# # # If no scaling is needed, use an identity mapping
# # self.pixel_shuffle = nn.Identity()
# # self.adjust_channels = nn.Identity()
# # self.norm = norm_layer(dim // dim_scale)
# def forward(self, x):
# # Pixel shuffle expects the input in the format (B, C, H, W)
# x = rearrange(x, 'b h w c -> b c h w')
# if self.dim_scale == 2:
# x = F.interpolate(x, scale_factor=self.dim_scale, mode='bilinear', align_corners=True)
# x = self.conv(x)
# # else:
# # x = F.interpolate(x, scale_factor=self.dim_scale, mode='bilinear', align_corners=True)
# x = rearrange(x, 'b c h w -> b h w c')
# return x
class PatchExpand(nn.Module):
def __init__(self, dim, dim_scale=2, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.dim_scale = dim_scale
# Assuming dim_scale is 2, which means increasing spatial dimensions by a factor of 2
if dim_scale == 2:
self.pixel_shuffle = nn.PixelShuffle(upscale_factor=2)
# Adjusting the number of channels after pixel shuffle
self.adjust_channels = nn.Conv2d(dim // 4, dim // dim_scale, kernel_size=1, stride=1, padding=0, bias=False)
else:
# If no scaling is needed, use an identity mapping
self.pixel_shuffle = nn.Identity()
self.adjust_channels = nn.Identity()
self.norm = norm_layer(dim // dim_scale)
def forward(self, x):
# Pixel shuffle expects the input in the format (B, C, H, W)
x = rearrange(x, 'b h w c -> b c h w')
if self.dim_scale == 2:
x = self.pixel_shuffle(x)
x = self.adjust_channels(x)
# Convert back to the original format for normalization
x = rearrange(x, 'b c h w -> b h w c')
x = self.norm(x)
return x
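# Shape sketch (illustrative): with dim_scale=2, PixelShuffle trades 4x channels for 2x
# spatial size and adjust_channels maps C/4 -> C/2, so (B, H, W, C) -> (B, 2H, 2W, C/2).
# Hypothetical helper, never called in this file.
def _demo_patch_expand():
    up = PatchExpand(dim=192, dim_scale=2)
    x = torch.randn(2, 28, 28, 192)
    assert up(x).shape == (2, 56, 56, 96)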
# class PatchExpand(nn.Module):
# def __init__(self, dim, dim_scale=2, norm_layer=nn.LayerNorm):
# super().__init__()
# self.dim = dim
# self.expand = nn.Linear(
# dim, 2*dim, bias=False) if dim_scale == 2 else nn.Identity()
# self.norm = norm_layer(dim // dim_scale)
# def forward(self, x):
# x = self.expand(x)
# B, H, W, C = x.shape
# x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4)
# x= self.norm(x)
# return x
# class FinalPatchExpand_X4(nn.Module):
# def __init__(self, dim, dim_scale=4, norm_layer=nn.LayerNorm):
# super().__init__()
# # self.dim = dim
# # self.dim_scale = dim_scale
# # # Assuming dim_scale is 2, which means increasing spatial dimensions by a factor of 2
# # if dim_scale == 4:
# # self.pixel_shuffle = nn.PixelShuffle(upscale_factor=4)
# # # Adjusting the number of channels after pixel shuffle
# # self.adjust_channels = nn.Conv2d(dim // 16, dim, kernel_size=1, stride=1, padding=0, bias=False)
# # else:
# # # If no scaling is needed, use an identity mapping
# # self.pixel_shuffle = nn.Identity()
# # self.adjust_channels = nn.Identity()
# # self.norm = norm_layer(dim)
# def forward(self, x):
# x = rearrange(x, 'b h w c -> b c h w')
# x = F.interpolate(x, scale_factor=4, mode='bilinear', align_corners=True)
# x = rearrange(x, 'b c h w -> b h w c')
# return x
class FinalPatchExpand_X4(nn.Module):
def __init__(self, dim, dim_scale=4, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.dim_scale = dim_scale
self.expand = nn.Linear(dim, 16 * dim, bias=False)
self.output_dim = dim
self.norm = norm_layer(self.output_dim)
def forward(self, x):
x = self.expand(x)
B, H, W, C = x.shape
x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale,
c=C // (self.dim_scale ** 2))
x = self.norm(x)
return x
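# Shape sketch (illustrative): FinalPatchExpand_X4 expands channels 16x with a Linear
# layer and folds them into a 4x4 spatial block, so (B, H, W, C) -> (B, 4H, 4W, C).
# Hypothetical helper, never called in this file.
def _demo_final_patch_expand():
    up4 = FinalPatchExpand_X4(dim=64, dim_scale=4)
    x = torch.randn(2, 56, 56, 64)
    assert up4(x).shape == (2, 224, 224, 64)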
class SS2D(nn.Module):
def __init__(
self,
d_model,
d_state=16,
# d_state="auto", # 20240109
d_conv=3,
expand=0.5,
dt_rank="auto",
dt_min=0.001,
dt_max=0.1,
dt_init="random",
dt_scale=1.0,
dt_init_floor=1e-4,
dropout=0.,
conv_bias=True,
bias=False,
device=None,
dtype=None,
**kwargs,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.d_model = d_model
self.d_state = d_state
# self.d_state = math.ceil(self.d_model / 6) if d_state == "auto" else d_model # 20240109
self.d_conv = d_conv
self.expand = expand
self.d_inner = int(self.expand * self.d_model)
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
self.conv2d = nn.Conv2d(
in_channels=self.d_inner,
out_channels=self.d_inner,
groups=self.d_inner,
bias=conv_bias,
kernel_size=d_conv,
padding=(d_conv - 1) // 2,
**factory_kwargs,
)
self.channel_attention = ChannelAttentionModule(self.d_inner)
self.spatial_attention = SpatialAttentionModule()
self.act = nn.SiLU()
self.x_proj = (
nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
)
self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K=4, N, inner)
del self.x_proj
self.dt_projs = (
self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
**factory_kwargs),
self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
**factory_kwargs),
self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
**factory_kwargs),
self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
**factory_kwargs),
)
self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K=4, inner, rank)
self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K=4, inner)
del self.dt_projs
self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True) # (K=4 * D, N)
self.Ds = self.D_init(self.d_inner, copies=4, merge=True) # (K=4 * D)
self.forward_core = self.forward_core_windows
# self.forward_core = self.forward_corev0_seq
# self.forward_core = self.forward_corev1
self.out_norm = nn.LayerNorm(self.d_inner)
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
self.dropout = nn.Dropout(dropout) if dropout > 0. else None
@staticmethod
def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4,
**factory_kwargs):
dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)
# Initialize special dt projection to preserve variance at initialization
dt_init_std = dt_rank ** -0.5 * dt_scale
if dt_init == "constant":
nn.init.constant_(dt_proj.weight, dt_init_std)
elif dt_init == "random":
nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
else:
raise NotImplementedError
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
dt = torch.exp(
torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
).clamp(min=dt_init_floor)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
with torch.no_grad():
dt_proj.bias.copy_(inv_dt)
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
dt_proj.bias._no_reinit = True
return dt_proj
@staticmethod
def A_log_init(d_state, d_inner, copies=1, device=None, merge=True):
# S4D real initialization
A = repeat(
torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
"n -> d n",
d=d_inner,
).contiguous()
A_log = torch.log(A) # Keep A_log in fp32
if copies > 1:
A_log = repeat(A_log, "d n -> r d n", r=copies)
if merge:
A_log = A_log.flatten(0, 1)
A_log = nn.Parameter(A_log)
A_log._no_weight_decay = True
return A_log
@staticmethod
def D_init(d_inner, copies=1, device=None, merge=True):
# D "skip" parameter
D = torch.ones(d_inner, device=device)
if copies > 1:
D = repeat(D, "n1 -> r n1", r=copies)
if merge:
D = D.flatten(0, 1)
D = nn.Parameter(D) # Keep in fp32
D._no_weight_decay = True
return D
def forward_core_windows(self, x: torch.Tensor, layer=1):
# NOTE: this early return keeps only the single-scale scan; the multi-scale
# window logic that follows is currently dead code.
return self.forward_corev0(x)
if layer == 1:
return self.forward_corev0(x)
downsampled_4 = F.avg_pool2d(x, kernel_size=2, stride=2)
processed_4 = self.forward_corev0(downsampled_4)
processed_4 = processed_4.permute(0, 3, 1, 2)
restored_4 = F.interpolate(processed_4, scale_factor=2, mode='nearest')
restored_4 = restored_4.permute(0, 2, 3, 1)
if layer == 2:
output = (self.forward_corev0(x) + restored_4) / 2.0
downsampled_8 = F.avg_pool2d(x, kernel_size=4, stride=4)
processed_8 = self.forward_corev0(downsampled_8)
processed_8 = processed_8.permute(0, 3, 1, 2)
restored_8 = F.interpolate(processed_8, scale_factor=4, mode='nearest')
restored_8 = restored_8.permute(0, 2, 3, 1)
output = (self.forward_corev0(x) + restored_4 + restored_8) / 3.0
return output
# B C H W
num_splits = 2 ** layer
split_size = x.shape[2] // num_splits # Assuming H == W and is divisible by 2**layer
# Use unfold to create windows
x_unfolded = x.unfold(2, split_size, split_size).unfold(3, split_size, split_size)
x_unfolded = x_unfolded.contiguous().view(-1, x.size(1), split_size, split_size)
# Process all splits at once
processed_splits = self.forward_corev0(x_unfolded)
processed_splits = processed_splits.permute(0, 3, 1, 2)
# Reshape to get the splits back into their original positions and then permute to align dimensions
processed_splits = processed_splits.view(x.size(0), num_splits, num_splits, x.size(1), split_size, split_size)
processed_splits = processed_splits.permute(0, 3, 1, 4, 2, 5).contiguous()
processed_splits = processed_splits.view(x.size(0), x.size(1), x.size(2), x.size(3))
processed_splits = processed_splits.permute(0, 2, 3, 1)
return processed_splits
# num_splits = 2 ** layer
# split_size = x.shape[2] // num_splits # Assuming H == W and is divisible by 2**layer
# outputs = []
# for i in range(num_splits):
# row_outputs = []
# for j in range(num_splits):
# sub_x = x[:, :, i*split_size:(i+1)*split_size, j*split_size:(j+1)*split_size].contiguous()
# processed = self.forward_corev0(sub_x)
# row_outputs.append(processed)
# # Concatenate all column splits for current row
# outputs.append(torch.cat(row_outputs, dim=2))
# # Concatenate all rows
# final_output = torch.cat(outputs, dim=1)
# return final_output
def forward_corev0(self, x: torch.Tensor):
self.selective_scan = selective_scan_fn
B, C, H, W = x.shape
L = H * W
K = 4
x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)],
dim=1).view(B, 2, -1, L)
xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l)
x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
# x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
xs = xs.float().view(B, -1, L) # (b, k * d, l)
dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l)
Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l)
Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l)
Ds = self.Ds.float().view(-1) # (k * d)
As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state)
dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d)
out_y = self.selective_scan(
xs, dts,
As, Bs, Cs, Ds, z=None,
delta_bias=dt_projs_bias,
delta_softplus=True,
return_last_state=False,
).view(B, K, -1, L)
assert out_y.dtype == torch.float
inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y
y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1)
y = self.out_norm(y).to(x.dtype)
return y
def forward_corev0_seq(self, x: torch.Tensor):
self.selective_scan = selective_scan_fn
B, C, H, W = x.shape
L = H * W
K = 4
x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)],
dim=1).view(B, 2, -1, L)
xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l)
x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
# x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
xs = xs.float().view(B, -1, L) # (b, k * d, l)
dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l)
Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l)
Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l)
Ds = self.Ds.float().view(-1) # (k * d)
As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state)
dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d)
out_y = []
for i in range(4):
yi = self.selective_scan(
xs[:, i], dts[:, i],
As[i], Bs[:, i], Cs[:, i], Ds[i],
delta_bias=dt_projs_bias[i],
delta_softplus=True,
).view(B, -1, L)
out_y.append(yi)
out_y = torch.stack(out_y, dim=1)
assert out_y.dtype == torch.float
inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y
y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1)
y = self.out_norm(y).to(x.dtype)
return y
def forward_corev1(self, x: torch.Tensor):
self.selective_scan = selective_scan_fn_v1
B, C, H, W = x.shape
L = H * W
K = 4
x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)],
dim=1).view(B, 2, -1, L)
xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l)
x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
# x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
# dts = dts + self.dt_projs_bias.view(1, K, -1, 1)
xs = xs.view(B, -1, L) # (b, k * d, l)
dts = dts.contiguous().view(B, -1, L) # (b, k * d, l)
Bs = Bs.view(B, K, -1, L) # (b, k, d_state, l)
Cs = Cs.view(B, K, -1, L) # (b, k, d_state, l)
As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state)
Ds = self.Ds.view(-1) # (k * d)
dt_projs_bias = self.dt_projs_bias.view(-1) # (k * d)
# print(self.Ds.dtype, self.A_logs.dtype, self.dt_projs_bias.dtype, flush=True) # fp16, fp16, fp16
out_y = self.selective_scan(
xs, dts,
As, Bs, Cs, Ds,
delta_bias=dt_projs_bias,
delta_softplus=True,
).view(B, K, -1, L)
assert out_y.dtype == torch.float16
inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
y = out_y[:, 0].float() + inv_y[:, 0].float() + wh_y.float() + invwh_y.float()
y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1)
y = self.out_norm(y).to(x.dtype)
return y
def forward(self, x: torch.Tensor, layer=1, **kwargs):
B, H, W, C = x.shape
xz = self.in_proj(x)
x, z = xz.chunk(2, dim=-1) # (b, h, w, d)
z = z.permute(0, 3, 1, 2)
z = self.channel_attention(z) * z
z = self.spatial_attention(z) * z
z = z.permute(0, 2, 3, 1).contiguous()
x = x.permute(0, 3, 1, 2).contiguous()
x = self.act(self.conv2d(x)) # (b, d, h, w)
y = self.forward_core(x, layer)
y = y * F.silu(z)
out = self.out_proj(y)
if self.dropout is not None:
out = self.dropout(out)
return out
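# Scan-order sketch (illustrative, not part of the original model): forward_corev0
# flattens the feature map along four paths -- row-major, column-major, and their
# reverses -- before handing them to the selective scan. The hypothetical helper below
# mirrors that construction with plain torch ops and is never called in this file.
def _demo_four_direction_scan_paths():
    B, C, H, W = 1, 8, 4, 4
    L = H * W
    x = torch.randn(B, C, H, W)
    x_hwwh = torch.stack(
        [x.view(B, -1, L),
         torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)],
        dim=1,
    ).view(B, 2, -1, L)
    xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)
    assert xs.shape == (B, 4, C, L)  # K=4 directional sequences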
class VSSBlock(nn.Module):
def __init__(
self,
hidden_dim: int = 0,
drop_path: float = 0,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
attn_drop_rate: float = 0,
d_state: int = 16,
layer: int = 1,
**kwargs,
):
super().__init__()
factor = 2.0
d_model = int(hidden_dim // factor)
self.down = nn.Linear(hidden_dim, d_model)
self.up = nn.Linear(d_model, hidden_dim)
self.ln_1 = norm_layer(d_model)
self.self_attention = SS2D(d_model=d_model, dropout=attn_drop_rate, d_state=d_state, **kwargs)
self.drop_path = DropPath(drop_path)
self.layer = layer
def forward(self, input: torch.Tensor):
input_x = self.down(input)
input_x = input_x + self.drop_path(self.self_attention(self.ln_1(input_x), self.layer))
x = self.up(input_x) + input
return x
class VSSLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
depth (int): Number of blocks.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(
self,
dim,
depth,
attn_drop=0.,
drop_path=0.,
norm_layer=nn.LayerNorm,
downsample=None,
use_checkpoint=False,
d_state=16,
**kwargs,
):
super().__init__()
self.dim = dim
self.use_checkpoint = use_checkpoint
self.blocks = nn.ModuleList([
VSSBlock(
hidden_dim=dim,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
attn_drop_rate=attn_drop,
d_state=d_state,
# expand=0.25,
)
for i in range(depth)])
if True: # is this really applied? Yes, but it is overridden later by VSSM's _init_weights
def _init_weights(module: nn.Module):
for name, p in module.named_parameters():
if name in ["out_proj.weight"]:
p = p.clone().detach_() # fake init, just to keep the seed ....
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
self.apply(_init_weights)
if downsample is not None:
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
if self.downsample is not None:
x = self.downsample(x)
return x
class VSSLayer_up(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
depth (int): Number of blocks.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
Upsample (nn.Module | None, optional): Upsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(
self,
dim,
depth,
attn_drop=0.,
drop_path=0.,
norm_layer=nn.LayerNorm,
upsample=None,
use_checkpoint=False,
d_state=16,
layer=1,
**kwargs,
):
super().__init__()
self.dim = dim
self.use_checkpoint = use_checkpoint
print('layer: ', layer, ' dstate: ', d_state)
self.blocks = nn.ModuleList([
VSSBlock(
hidden_dim=dim,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
attn_drop_rate=attn_drop,
d_state=d_state,
# expand=1.0,
layer=layer,
)
for i in range(depth)])
if True: # is this really applied? Yes, but it is overridden later by VSSM's _init_weights
def _init_weights(module: nn.Module):
for name, p in module.named_parameters():
if name in ["out_proj.weight"]:
p = p.clone().detach_() # fake init, just to keep the seed ....
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
self.apply(_init_weights)
if upsample is not None:
self.upsample = PatchExpand(dim, dim_scale=2, norm_layer=nn.LayerNorm)
else:
self.upsample = None
def forward(self, x):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
if self.upsample is not None:
x = self.upsample(x)
return x
class ConvBNReLU(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, norm_layer=nn.BatchNorm2d,
bias=False):
super(ConvBNReLU, self).__init__(
nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2),
norm_layer(out_channels),
nn.ReLU6()
)
class Conv(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1, bias=False):
super(Conv, self).__init__(
nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=bias,
dilation=dilation, stride=stride, padding=((stride - 1) + dilation * (kernel_size - 1)) // 2)
)
class AuxHead(nn.Module):
def __init__(self, in_channels=64, num_classes=8):
super().__init__()
self.conv = ConvBNReLU(in_channels, in_channels)
self.drop = nn.Dropout(0.1)
self.conv_out = Conv(in_channels, num_classes, kernel_size=1)
def forward(self, x, h, w):
feat = self.conv(x)
feat = self.drop(feat)
feat = self.conv_out(feat)
feat = F.interpolate(feat, size=(h, w), mode='bilinear', align_corners=False)
return feat
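# Usage sketch (illustrative): AuxHead turns an intermediate (B, C, h', w') feature into
# per-class logits at full image resolution, e.g. for deep supervision.
# Hypothetical helper, never called in this file.
def _demo_aux_head():
    head = AuxHead(in_channels=64, num_classes=8)
    feat = torch.randn(2, 64, 32, 32)
    assert head(feat, 256, 256).shape == (2, 8, 256, 256)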
class BasicConv(nn.Module):
def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False):
super(BasicConv, self).__init__()
if bias and norm:
bias = False
padding = kernel_size // 2
layers = list()
if transpose:
padding = kernel_size // 2 - 1
layers.append(
nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
else:
layers.append(
nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
if norm:
layers.append(nn.BatchNorm2d(out_channel))
if relu:
layers.append(nn.ReLU(inplace=True))
self.main = nn.Sequential(*layers)
def forward(self, x):
return self.main(x)
class AFF(nn.Module):
def __init__(self, in_channel, out_channel):
super(AFF, self).__init__()
self.conv = nn.Sequential(
BasicConv(in_channel, out_channel, kernel_size=1, stride=1, relu=True),
BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
)
def forward(self, x1, x2, x4):
x = torch.cat([x1, x2, x4], dim=1)
return self.conv(x)
class ChannelAttentionModule(nn.Module):
def __init__(self, in_channels, reduction=4):
super(ChannelAttentionModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.fc = nn.Sequential(
nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False)
)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = self.fc(self.avg_pool(x))
max_out = self.fc(self.max_pool(x))
out = avg_out + max_out
return self.sigmoid(out)
class SpatialAttentionModule(nn.Module):
def __init__(self, kernel_size=7):
super(SpatialAttentionModule, self).__init__()
self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = torch.mean(x, dim=1, keepdim=True)
max_out, _ = torch.max(x, dim=1, keepdim=True)
x = torch.cat([avg_out, max_out], dim=1)
x = self.conv1(x)
return self.sigmoid(x)
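# Usage sketch (illustrative): the two CBAM-style modules produce multiplicative gates;
# SS2D.forward applies them sequentially to its gating branch z. Hypothetical helper,
# never called in this file.
def _demo_cbam_attention():
    feat = torch.randn(2, 48, 32, 32)
    feat = ChannelAttentionModule(48)(feat) * feat   # (B, C, 1, 1) gate
    feat = SpatialAttentionModule()(feat) * feat     # (B, 1, H, W) gate
    assert feat.shape == (2, 48, 32, 32)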
class FusionConv(nn.Module):
def __init__(self, in_channels, out_channels, factor=4.0):
super(FusionConv, self).__init__()
dim = int(out_channels // factor)
self.down = nn.Conv2d(in_channels, dim, kernel_size=1, stride=1)
self.conv_3x3 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1)
self.conv_5x5 = nn.Conv2d(dim, dim, kernel_size=5, stride=1, padding=2)
self.conv_7x7 = nn.Conv2d(dim, dim, kernel_size=7, stride=1, padding=3)
self.spatial_attention = SpatialAttentionModule()
self.channel_attention = ChannelAttentionModule(dim)
self.up = nn.Conv2d(dim, out_channels, kernel_size=1, stride=1)
self.down_2 = nn.Conv2d(in_channels, dim, kernel_size=1, stride=1)
def forward(self, x1, x2, x4):
x_fused = torch.cat([x1, x2, x4], dim=1)
x_fused = self.down(x_fused)
x_fused_c = x_fused * self.channel_attention(x_fused)
x_3x3 = self.conv_3x3(x_fused)
x_5x5 = self.conv_5x5(x_fused)
x_7x7 = self.conv_7x7(x_fused)
x_fused_s = x_3x3 + x_5x5 + x_7x7
x_fused_s = x_fused_s * self.spatial_attention(x_fused_s)
x_out = self.up(x_fused_s + x_fused_c)
return x_out
class DownFusion(nn.Module):
def __init__(self, in_channels, out_channels):
super(DownFusion, self).__init__()
self.fusion_conv = FusionConv(in_channels, out_channels)
self.CAM = ChannelAttentionModule(out_channels)
def forward(self, x1, x2):
x_fused = torch.cat([x1, x2], dim=1)
x_fused = self.fusion_conv(x_fused)  # NOTE: FusionConv.forward expects three inputs; DownFusion appears unused in this file
x_fused = self.CAM(x_fused) * x_fused  # the original "x_fused = + x_fused" was a no-op; gate with channel attention instead
return x_fused
class MSAA(nn.Module):
def __init__(self, in_channels, out_channels):
super(MSAA, self).__init__()
self.fusion_conv = FusionConv(in_channels, out_channels)
def forward(self, x1, x2, x4, last=False):
# # x2 flows bottom-up and x4 top-down: x2 carries semantic information, x4 supplements edge/detail features
# x_1_2_fusion = self.fusion_1x2(x1, x2)
# x_1_4_fusion = self.fusion_1x4(x1, x4)
# x_fused = x_1_2_fusion + x_1_4_fusion
x_fused = self.fusion_conv(x1, x2, x4)
return x_fused
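# Fusion sketch (illustrative): MSAA expects three feature maps that already share the
# same spatial size; CMUnet.forward_downfeatures resizes them with F.interpolate and the
# channel counts sum to in_channels (16 + 32 + 64 = 112 = 7 * hidden_dim when base_dims=64).
# Hypothetical helper, never called in this file.
def _demo_msaa_fusion():
    hidden = 16
    x1 = torch.randn(2, hidden, 56, 56)
    x2 = torch.randn(2, hidden * 2, 56, 56)
    x4 = torch.randn(2, hidden * 4, 56, 56)
    fused = MSAA(hidden * 7, 64)(x1, x2, x4)
    assert fused.shape == (2, 64, 56, 56)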
@MODELS.register_module()
class CMUnet(nn.Module):
def __init__(self, patch_size=4, in_chans=1, num_classes=4, depths=[2, 2, 9, 2],
dims=[96, 192, 384, 768], d_state=16, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, patch_norm=True,
use_checkpoint=False, final_upsample="expand_first", **kwargs):
super().__init__()
self.num_classes = num_classes
self.num_layers = len(depths)
# resnet18
dims = [64, 128, 256, 512]  # override the `dims` argument with the ResNet-18 feature widths
# dims = [64, 64, 64, 64]
self.backbone = timm.create_model("swsl_resnet18", features_only=True, output_stride=32,
out_indices=(1, 2, 3, 4),
pretrained=False) # , pretrained_cfg_overlay=dict(file='~/.cache/huggingface/hub/modelstimmresnet18.a1_in1k/pytorch_model.bin')) # build decoder layers
base_dims = 64
# state_dict = torch.load('/media/cm/1T_SSD/UHR_Comparative_Exp3/checkpoint/semi_weakly_supervised_resnet18-118f1556.pth')
# msg = self.backbone.load_state_dict(state_dict, strict=False)
# print('[INFO] ', msg)
if isinstance(dims, int):
dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)]
self.embed_dim = dims[0]
self.num_features = dims[-1]
self.num_features_up = int(dims[0] * 2)
self.dims = dims
self.final_upsample = final_upsample
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
self.layers_up = nn.ModuleList()
self.concat_back_dim = nn.ModuleList()
for i_layer in range(self.num_layers):
concat_linear = nn.Linear(2 * int(dims[0] * 2 ** (self.num_layers - 1 - i_layer)),
int(dims[0] * 2 ** (
self.num_layers - 1 - i_layer))) if i_layer > 0 else nn.Identity()
if i_layer == 0:
layer_up = nn.Sequential(
VSSLayer(
dim=int(dims[0] * 2 ** (self.num_layers - 1 - i_layer)),
depth=2,
d_state=math.ceil(dims[0] / 6) if d_state is None else d_state, # 20240109
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:-1]):sum(depths[:])],
norm_layer=norm_layer,
downsample=None,
use_checkpoint=use_checkpoint),
PatchExpand(dim=int(self.embed_dim * 2 ** (self.num_layers - 1 - i_layer)), dim_scale=2,
norm_layer=norm_layer)
)
# layer_up =PatchExpand(dim=int(self.embed_dim * 2 ** (self.num_layers-1-i_layer)), dim_scale=2, norm_layer=norm_layer)
else:
layer_up = VSSLayer_up(
dim=int(dims[0] * 2 ** (self.num_layers - 1 - i_layer)),
depth=depths[(self.num_layers - 1 - i_layer)],
d_state=math.ceil(dims[0] / 6) if d_state is None else d_state, # 20240109
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:(self.num_layers - 1 - i_layer)]):sum(
depths[:(self.num_layers - 1 - i_layer) + 1])],
norm_layer=norm_layer,
upsample=PatchExpand if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint,
layer=i_layer,
)
self.layers_up.append(layer_up)
self.concat_back_dim.append(concat_linear)
self.norm = norm_layer(self.num_features)
self.norm_up = norm_layer(self.embed_dim)
if self.final_upsample == "expand_first":
print("---final upsample expand_first---")
self.up = FinalPatchExpand_X4(dim_scale=4, dim=self.embed_dim)
self.output = nn.Conv2d(in_channels=self.embed_dim, out_channels=self.num_classes, kernel_size=1,
bias=False)
if self.training:  # nn.Module is in training mode during __init__, so these auxiliary heads are always created
self.conv4 = nn.Conv2d(base_dims * 2, num_classes, 1, bias=False)
self.conv3 = nn.Conv2d(base_dims, num_classes, 1, bias=False)
self.conv2 = nn.Conv2d(base_dims, num_classes, 1, bias=False)
# self.conv1 = nn.Conv2d(base_dims, num_classes, 1, bias=False)
hidden_dim = int(base_dims // 4)
self.AFFs = nn.ModuleList([
MSAA(hidden_dim * 7, base_dims),
MSAA(hidden_dim * 7, base_dims * 2),
MSAA(hidden_dim * 7, base_dims * 4),
])
self.transfer = nn.ModuleList(
[
nn.Conv2d(base_dims, hidden_dim, 1, bias=False),
nn.Conv2d(base_dims * 2, hidden_dim * 2, 1, bias=False),
nn.Conv2d(base_dims * 4, hidden_dim * 4, 1, bias=False),
]
)
self.apply(self._init_weights)
def _init_weights(self, m: nn.Module):
"""
out_proj.weight which is previously initilized in VSSBlock, would be cleared in nn.Linear
no fc.weight found in the any of the model parameters
no nn.Embedding found in the any of the model parameters
so the thing is, VSSBlock initialization is useless
Conv2D is not intialized !!!
"""
# print(m, getattr(getattr(m, "weight", nn.Identity()), "INIT", None), isinstance(m, nn.Linear), "======================")
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
# Encoder and Bottleneck
def forward_features(self, x):  # NOTE: unused here -- CMUnet uses forward_resnet(); self.patch_embed / self.layers are not defined in this class
x = self.patch_embed(x)
x_downsample = []
for layer in self.layers:
x_downsample.append(x)
x = layer(x)
x = self.norm(x) # B H W C
return x, x_downsample
# def forward_backbone(self, x):
# x = self.patch_embed(x)
# for layer in self.layers:
# x = layer(x)
# return x
# Dencoder and Skip connection
def forward_up_features(self, x, x_downsample, h, w):
for inx, layer_up in enumerate(self.layers_up):
if inx == 0:
x = layer_up(x)
else:
x = torch.cat([x, x_downsample[3 - inx]], -1)
x = self.concat_back_dim[inx](x)
x = layer_up(x)
if self.training and inx == 1:
tmp = torch.permute(x, (0, 3, 1, 2))
# h4 = self.up4(tmp)
# h4 = self.conv4(tmp)
h4 = tmp
if self.training and inx == 2:
tmp = torch.permute(x, (0, 3, 1, 2))
# h3 = self.up3(tmp)
# h3 = self.conv3(tmp)
h3 = tmp
if self.training and inx == 3:
tmp = torch.permute(x, (0, 3, 1, 2))
# h2 = self.up2(tmp)
# h2 = self.conv2(tmp)
h2 = tmp
if self.training:
# ah = h4 + h3 + h2
# ah = self.aux_head(ah, h, w)
ah = [h2, h3, h4]
x = self.norm_up(x) # B H W C
return x, ah
else:
x = self.norm_up(x) # B H W C
return x
def up_x4(self, x, h, w):
B, H, W, C = x.shape
x = x.permute(0, 3, 1, 2)
x = self.output(x)
x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=False)
return x
def forward_downfeatures(self, x_downsample):
x_down_last = x_downsample[-1]
x_downsample_2 = x_downsample
x_downsample = []
for idx, feat in enumerate(x_downsample_2[:-1]):
feat = torch.permute(feat, (0, 3, 1, 2))
feat = self.transfer[idx](feat)
x_downsample.append(feat)
x_down_3_2 = F.interpolate(x_downsample[1], scale_factor=2.0, mode="bilinear", align_corners=True)
x_down_4_2 = F.interpolate(x_downsample[2], scale_factor=4.0, mode="bilinear", align_corners=True)
x_down_4_3 = F.interpolate(x_downsample[2], scale_factor=2.0, mode="bilinear", align_corners=True)
x_down_2_3 = F.interpolate(x_downsample[0], scale_factor=0.5, mode="bilinear", align_corners=True)
x_down_2_4 = F.interpolate(x_downsample[0], scale_factor=0.25, mode="bilinear", align_corners=True)
x_down_3_4 = F.interpolate(x_downsample[1], scale_factor=0.5, mode="bilinear", align_corners=True)
x_down_2 = self.AFFs[0](x_downsample[0], x_down_3_2, x_down_4_2)
x_down_3 = self.AFFs[1](x_downsample[1], x_down_2_3, x_down_4_3)
x_down_4 = self.AFFs[2](x_downsample[2], x_down_3_4, x_down_2_4)
x_down_2 = torch.permute(x_down_2, (0, 2, 3, 1))
x_down_3 = torch.permute(x_down_3, (0, 2, 3, 1))
x_down_4 = torch.permute(x_down_4, (0, 2, 3, 1))
return [x_down_2, x_down_3, x_down_4, x_down_last]
def forward_resnet(self, x):
h, w = x.size()[-2:]
res1, res2, res3, res4 = self.backbone(x)
res1 = res1.permute(0, 2, 3, 1)
res2 = res2.permute(0, 2, 3, 1)
res3 = res3.permute(0, 2, 3, 1)
res4 = res4.permute(0, 2, 3, 1)
# x_downsample = [res1, res2, res3, res4]
# x = res4
x_downsample = [res1, res2, res3, res4]
x = res4  # the last-stage feature, used for the decoder / loss
return x, x_downsample
def forward(self, x):
b, c, h, w = x.size()
# x,x_downsample = self.forward_features(x)
x, x_downsample = self.forward_resnet(x)
x_downsample = self.forward_downfeatures(x_downsample)
if self.training:
x, ah = self.forward_up_features(x, x_downsample, h, w)
x = self.up_x4(x, h, w)
return x, ah
else:
x = self.forward_up_features(x, x_downsample, h, w)
x = self.up_x4(x, h, w)
return x
#
# def flops(self, shape=(3, 224, 224)):
# # shape = self.__input_shape__[1:]
# supported_ops = {
# "aten::silu": None, # as relu is in _IGNORED_OPS
# "aten::neg": None, # as relu is in _IGNORED_OPS
# "aten::exp": None, # as relu is in _IGNORED_OPS
# "aten::flip": None, # as permute is in _IGNORED_OPS
# "prim::PythonOp.SelectiveScanFn": selective_scan_flop_jit, # latter
# }
#
# model = copy.deepcopy(self)
# model.cuda().eval()
#
# input = torch.randn((1, *shape), device=next(model.parameters()).device)
# params = parameter_count(model)[""]
# Gflops, unsupported = flop_count(model=model, inputs=(input,), supported_ops=supported_ops)
#
# del model, input
# return sum(Gflops.values()) * 1e9
# return f"params {params} GFLOPs {sum(Gflops.values())}"
# # APIs with VMamba2Dp =================
# def check_vssm_equals_vmambadp():
# from bak.vmamba_bak1 import VMamba2Dp
#
# # test 1 True =================================
# torch.manual_seed(time.time());
# torch.cuda.manual_seed(time.time())
# oldvss = VMamba2Dp(depths=[2, 2, 6, 2]).half().cuda()
# newvss = VSSM(depths=[2, 2, 6, 2]).half().cuda()
# newvss.load_state_dict(oldvss.state_dict())
# input = torch.randn((12, 3, 224, 224)).half().cuda()
# torch.cuda.manual_seed(0)
# with torch.cuda.amp.autocast():
# y1 = oldvss.forward_backbone(input)
# torch.cuda.manual_seed(0)
# with torch.cuda.amp.autocast():
# y2 = newvss.forward_backbone(input)
# print((y1 - y2).abs().sum()) # tensor(0., device='cuda:0', grad_fn=<SumBackward0>)
#
# # test 2 True ==========================================
# torch.manual_seed(0);
# torch.cuda.manual_seed(0)
# oldvss = VMamba2Dp(depths=[2, 2, 6, 2]).cuda()
# torch.manual_seed(0);
# torch.cuda.manual_seed(0)
# newvss = VSSM(depths=[2, 2, 6, 2]).cuda()
#
# miss_align = 0
# for k, v in oldvss.state_dict().items():
# same = (oldvss.state_dict()[k] == newvss.state_dict()[k]).all()
# if not same:
# print(k, same)
# miss_align += 1
# print("init miss align", miss_align) # init miss align 0
if __name__ == "__main__":
# check_vssm_equals_vmambadp()
model = CMUnet().to('cuda:4')
model.eval()  # in training mode forward() returns (out, aux_feats); eval mode returns a single tensor
x = torch.randn(2, 3, 2048, 2048).to('cuda:4')  # renamed from `int` to avoid shadowing the builtin
with torch.no_grad():
    out = model(x)
# for res in out:
#     print(res.size())
print(out.size())
# device = torch.device('cuda:4')
# model = CMUnet().to('cuda:4')
# model.eval()
# model.to(device)
# iterations = None
#
# input = torch.randn(1, 3, 4928, 4928).to('cuda:4')
# with torch.no_grad():
# for _ in range(10):
# model(input)
#
# if iterations is None:
# elapsed_time = 0
# iterations = 100
# while elapsed_time < 1:
# torch.cuda.synchronize()
# torch.cuda.synchronize()
# t_start = time.time()
# for _ in range(iterations):
# model(input)
# torch.cuda.synchronize()
# torch.cuda.synchronize()
# elapsed_time = time.time() - t_start
# iterations *= 2
# FPS = iterations / elapsed_time
# iterations = int(FPS * 6)
#
# print('=========Speed Testing=========')
# torch.cuda.synchronize()
# torch.cuda.synchronize()
# t_start = time.time()
# for _ in range(iterations):
# model(input)
# torch.cuda.synchronize()
# torch.cuda.synchronize()
# elapsed_time = time.time() - t_start
# latency = elapsed_time / iterations * 1000
# torch.cuda.empty_cache()
# FPS = 1000 / latency
# print(FPS)