1403 lines
55 KiB
Python
Raw Normal View History

import os
import time
import math
import copy
from functools import partial
from typing import Optional, Callable, Any
from collections import OrderedDict
from mmseg.registry import MODELS
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange, repeat
from timm.models.layers import DropPath, trunc_normal_
# from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count, parameter_count
DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"
# import selective scan ==============================
try:
import selective_scan_cuda_oflex
except Exception as e:
...
# print(f"WARNING: can not import selective_scan_cuda_oflex.", flush=True)
# print(e, flush=True)
try:
import selective_scan_cuda_core
except Exception as e:
...
# print(f"WARNING: can not import selective_scan_cuda_core.", flush=True)
# print(e, flush=True)
try:
import selective_scan_cuda
except Exception as e:
...
# print(f"WARNING: can not import selective_scan_cuda.", flush=True)
# print(e, flush=True)
# fvcore flops =======================================
def flops_selective_scan_fn(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_complex=False):
"""
u: r(B D L)
delta: r(B D L)
A: r(D N)
B: r(B N L)
C: r(B N L)
D: r(D)
z: r(B D L)
delta_bias: r(D), fp32
ignores:
[.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
"""
assert not with_complex
# https://github.com/state-spaces/mamba/issues/110
flops = 9 * B * L * D * N
if with_D:
flops += B * D * L
if with_Z:
flops += B * D * L
return flops
# this is only for selective_scan_ref...
def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False):
"""
u: r(B D L)
delta: r(B D L)
A: r(D N)
B: r(B N L)
C: r(B N L)
D: r(D)
z: r(B D L)
delta_bias: r(D), fp32
ignores:
[.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
"""
import numpy as np
# fvcore.nn.jit_handles
def get_flops_einsum(input_shapes, equation):
np_arrs = [np.zeros(s) for s in input_shapes]
optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1]
for line in optim.split("\n"):
if "optimized flop" in line.lower():
# divided by 2 because we count MAC (multiply-add counted as one flop)
flop = float(np.floor(float(line.split(":")[-1]) / 2))
return flop
assert not with_complex
flops = 0 # below code flops = 0
flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln")
if with_Group:
flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln")
else:
flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln")
in_for_flops = B * D * N
if with_Group:
in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd")
else:
in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd")
flops += L * in_for_flops
if with_D:
flops += B * D * L
if with_Z:
flops += B * D * L
return flops
def print_jit_input_names(inputs):
print("input params: ", end=" ", flush=True)
try:
for i in range(10):
print(inputs[i].debugName(), end=" ", flush=True)
except Exception as e:
pass
print("", flush=True)
# cross selective scan ===============================
class SelectiveScanMamba(torch.autograd.Function):
# comment all checks if inside cross_selective_scan
@staticmethod
@torch.cuda.amp.custom_fwd
def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1,
oflex=True):
# assert nrows in [1, 2, 3, 4], f"{nrows}" # 8+ is too slow to compile
# assert u.shape[1] % (B.shape[1] * nrows) == 0, f"{nrows}, {u.shape}, {B.shape}"
ctx.delta_softplus = delta_softplus
# all in float
# if u.stride(-1) != 1:
# u = u.contiguous()
# if delta.stride(-1) != 1:
# delta = delta.contiguous()
# if D is not None and D.stride(-1) != 1:
# D = D.contiguous()
# if B.stride(-1) != 1:
# B = B.contiguous()
# if C.stride(-1) != 1:
# C = C.contiguous()
# if B.dim() == 3:
# B = B.unsqueeze(dim=1)
# ctx.squeeze_B = True
# if C.dim() == 3:
# C = C.unsqueeze(dim=1)
# ctx.squeeze_C = True
out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus)
ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
return out
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx, dout, *args):
u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
if dout.stride(-1) != 1:
dout = dout.contiguous()
du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus,
False
)
# dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
# dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None)
# class SelectiveScanCore(torch.autograd.Function):
# # comment all checks if inside cross_selective_scan
# @staticmethod
# @torch.cuda.amp.custom_fwd
# def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True):
# ctx.delta_softplus = delta_softplus
# out, x, *rest = selective_scan_cuda_core.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1)
# ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
# return out
#
# @staticmethod
# @torch.cuda.amp.custom_bwd
# def backward(ctx, dout, *args):
# u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
# if dout.stride(-1) != 1:
# dout = dout.contiguous()
# du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_core.bwd(
# u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
# )
# return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None)
#
#
# class SelectiveScanOflex(torch.autograd.Function):
# # comment all checks if inside cross_selective_scan
# @staticmethod
# @torch.cuda.amp.custom_fwd
# def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True):
# ctx.delta_softplus = delta_softplus
# out, x, *rest = selective_scan_cuda_oflex.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1, oflex)
# ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
# return out
#
# @staticmethod
# @torch.cuda.amp.custom_bwd
# def backward(ctx, dout, *args):
# u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
# if dout.stride(-1) != 1:
# dout = dout.contiguous()
# du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_oflex.bwd(
# u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
# )
# return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None)
#
#
# class SelectiveScanFake(torch.autograd.Function):
# # comment all checks if inside cross_selective_scan
# @staticmethod
# @torch.cuda.amp.custom_fwd
# def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True):
# ctx.delta_softplus = delta_softplus
# ctx.backnrows = backnrows
# x = delta
# out = u
# ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
# return out
#
# @staticmethod
# @torch.cuda.amp.custom_bwd
# def backward(ctx, dout, *args):
# u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
# if dout.stride(-1) != 1:
# dout = dout.contiguous()
# du, ddelta, dA, dB, dC, dD, ddelta_bias = u * 0, delta * 0, A * 0, B * 0, C * 0, C * 0, (D * 0 if D else None), (delta_bias * 0 if delta_bias else None)
# return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None)
# =============
def antidiagonal_gather(tensor):
# 取出矩阵所有反斜向的元素并拼接
B, C, H, W = tensor.size()
shift = torch.arange(H, device=tensor.device).unsqueeze(1) # 创建一个列向量[H, 1]
index = (torch.arange(W, device=tensor.device) - shift) % W # 利用广播创建索引矩阵[H, W]
# 扩展索引以适应B和C维度
expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
# 使用gather进行索引选择
return tensor.gather(3, expanded_index).transpose(-1, -2).reshape(B, C, H * W)
def diagonal_gather(tensor):
# 取出矩阵所有反斜向的元素并拼接
B, C, H, W = tensor.size()
shift = torch.arange(H, device=tensor.device).unsqueeze(1) # 创建一个列向量[H, 1]
index = (shift + torch.arange(W, device=tensor.device)) % W # 利用广播创建索引矩阵[H, W]
# 扩展索引以适应B和C维度
expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
# 使用gather进行索引选择
return tensor.gather(3, expanded_index).transpose(-1, -2).reshape(B, C, H * W)
def diagonal_scatter(tensor_flat, original_shape):
# 把斜向元素拼接起来的一维向量还原为最初的矩阵形式
B, C, H, W = original_shape
shift = torch.arange(H, device=tensor_flat.device).unsqueeze(1) # 创建一个列向量[H, 1]
index = (shift + torch.arange(W, device=tensor_flat.device)) % W # 利用广播创建索引矩阵[H, W]
# 扩展索引以适应B和C维度
expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
# 创建一个空的张量来存储反向散布的结果
result_tensor = torch.zeros(B, C, H, W, device=tensor_flat.device, dtype=tensor_flat.dtype)
# 将平铺的张量重新变形为[B, C, H, W]考虑到需要使用transpose将H和W调换
tensor_reshaped = tensor_flat.reshape(B, C, W, H).transpose(-1, -2)
# 使用scatter_根据expanded_index将元素放回原位
result_tensor.scatter_(3, expanded_index, tensor_reshaped)
return result_tensor
def antidiagonal_scatter(tensor_flat, original_shape):
# 把反斜向元素拼接起来的一维向量还原为最初的矩阵形式
B, C, H, W = original_shape
shift = torch.arange(H, device=tensor_flat.device).unsqueeze(1) # 创建一个列向量[H, 1]
index = (torch.arange(W, device=tensor_flat.device) - shift) % W # 利用广播创建索引矩阵[H, W]
expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
# 初始化一个与原始张量形状相同、元素全为0的张量
result_tensor = torch.zeros(B, C, H, W, device=tensor_flat.device, dtype=tensor_flat.dtype)
# 将平铺的张量重新变形为[B, C, W, H],因为操作是沿最后一个维度收集的,需要调整形状并交换维度
tensor_reshaped = tensor_flat.reshape(B, C, W, H).transpose(-1, -2)
# 使用scatter_将元素根据索引放回原位
result_tensor.scatter_(3, expanded_index, tensor_reshaped)
return result_tensor
class CrossScan(torch.autograd.Function):
# ZSJ 这里是把图像按照特定方向展平的地方,改变扫描方向可以在这里修改
@staticmethod
def forward(ctx, x: torch.Tensor):
B, C, H, W = x.shape
ctx.shape = (B, C, H, W)
# xs = x.new_empty((B, 4, C, H * W))
xs = x.new_empty((B, 8, C, H * W))
# 添加横向和竖向的扫描
xs[:, 0] = x.flatten(2, 3)
xs[:, 1] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
# 提供斜向和反斜向的扫描
xs[:, 4] = diagonal_gather(x)
xs[:, 5] = antidiagonal_gather(x)
xs[:, 6:8] = torch.flip(xs[:, 4:6], dims=[-1])
return xs
@staticmethod
def backward(ctx, ys: torch.Tensor):
# out: (b, k, d, l)
B, C, H, W = ctx.shape
L = H * W
# 把横向和竖向的反向部分再反向回来,并和原来的横向和竖向相加
# ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
y_rb = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
# 把竖向的部分转成横向,然后再相加,再转回最初是的矩阵形式
# y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
y_rb = y_rb[:, 0] + y_rb[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
y_rb = y_rb.view(B, -1, H, W)
# 把斜向和反斜向的反向部分再反向回来,并和原来的斜向和反斜向相加
y_da = ys[:, 4:6] + ys[:, 6:8].flip(dims=[-1]).view(B, 2, -1, L)
# 把斜向和反斜向的部分都转成原来的最初的矩阵形式,再相加
y_da = diagonal_scatter(y_da[:, 0], (B, C, H, W)) + antidiagonal_scatter(y_da[:, 1], (B, C, H, W))
y_res = y_rb + y_da
# return y.view(B, -1, H, W)
return y_res
class CrossMerge(torch.autograd.Function):
@staticmethod
def forward(ctx, ys: torch.Tensor):
B, K, D, H, W = ys.shape
ctx.shape = (H, W)
ys = ys.view(B, K, D, -1)
# ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
# y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
y_rb = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
# 把竖向的部分转成横向,然后再相加,再转回最初是的矩阵形式
y_rb = y_rb[:, 0] + y_rb[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
y_rb = y_rb.view(B, -1, H, W)
# 把斜向和反斜向的反向部分再反向回来,并和原来的斜向和反斜向相加
y_da = ys[:, 4:6] + ys[:, 6:8].flip(dims=[-1]).view(B, 2, D, -1)
# 把斜向和反斜向的部分都转成原来的最初的矩阵形式,再相加
y_da = diagonal_scatter(y_da[:, 0], (B, D, H, W)) + antidiagonal_scatter(y_da[:, 1], (B, D, H, W))
y_res = y_rb + y_da
return y_res.view(B, D, -1)
# return y
@staticmethod
def backward(ctx, x: torch.Tensor):
# B, D, L = x.shape
# out: (b, k, d, l)
H, W = ctx.shape
B, C, L = x.shape
# xs = x.new_empty((B, 4, C, L))
xs = x.new_empty((B, 8, C, L))
# 横向和竖向扫描
xs[:, 0] = x
xs[:, 1] = x.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3)
xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
# xs = xs.view(B, 4, C, H, W)
# 提供斜向和反斜向的扫描
xs[:, 4] = diagonal_gather(x.view(B, C, H, W))
xs[:, 5] = antidiagonal_gather(x.view(B, C, H, W))
xs[:, 6:8] = torch.flip(xs[:, 4:6], dims=[-1])
# return xs
return xs.view(B, 8, C, H, W)
"""
ablation exp
"""
# these are for ablations =============
# class CrossScan_Ab_2direction(torch.autograd.Function):
# @staticmethod
# def forward(ctx, x: torch.Tensor):
# B, C, H, W = x.shape
# ctx.shape = (B, C, H, W)
# xs = x.new_empty((B, 4, C, H * W))
# xs[:, 0] = x.flatten(2, 3)
# xs[:, 1] = x.flatten(2, 3)
# xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
# return xs
#
# @staticmethod
# def backward(ctx, ys: torch.Tensor):
# # out: (b, k, d, l)
# B, C, H, W = ctx.shape
# L = H * W
# ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
# y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
# return y.view(B, -1, H, W)
#
#
# class CrossMerge_Ab_2direction(torch.autograd.Function):
# @staticmethod
# def forward(ctx, ys: torch.Tensor):
# B, K, D, H, W = ys.shape
# ctx.shape = (H, W)
# ys = ys.view(B, K, D, -1)
# ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
# y = ys.sum(dim=1)
# return y
#
# @staticmethod
# def backward(ctx, x: torch.Tensor):
# # B, D, L = x.shape
# # out: (b, k, d, l)
# H, W = ctx.shape
# B, C, L = x.shape
# xs = x.new_empty((B, 4, C, L))
# xs[:, 0] = x
# xs[:, 1] = x
# xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
# xs = xs.view(B, 4, C, H, W)
# return xs
#
#
# class CrossScan_Ab_1direction(torch.autograd.Function):
# @staticmethod
# def forward(ctx, x: torch.Tensor):
# B, C, H, W = x.shape
# ctx.shape = (B, C, H, W)
# xs = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1).contiguous()
# return xs
#
# @staticmethod
# def backward(ctx, ys: torch.Tensor):
# # out: (b, k, d, l)
# B, C, H, W = ctx.shape
# y = ys.sum(dim=1).view(B, C, H, W)
# return y
#
#
# class CrossMerge_Ab_1direction(torch.autograd.Function):
# @staticmethod
# def forward(ctx, ys: torch.Tensor):
# B, K, D, H, W = ys.shape
# ctx.shape = (H, W)
# y = ys.sum(dim=1).view(B, D, H * W)
# return y
#
# @staticmethod
# def backward(ctx, x: torch.Tensor):
# # B, D, L = x.shape
# # out: (b, k, d, l)
# H, W = ctx.shape
# B, C, L = x.shape
# xs = x.view(B, 1, C, L).repeat(1, 4, 1, 1).contiguous().view(B, 4, C, H, W)
# return xs
#
# =============
# ZSJ 这里是mamba的具体内容要增加扫描方向就在这里改
class SelectiveScanFn(torch.autograd.Function):
@staticmethod
def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
return_last_state=False):
if u.stride(-1) != 1:
u = u.contiguous()
if delta.stride(-1) != 1:
delta = delta.contiguous()
if D is not None:
D = D.contiguous()
if B.stride(-1) != 1:
B = B.contiguous()
if C.stride(-1) != 1:
C = C.contiguous()
if z is not None and z.stride(-1) != 1:
z = z.contiguous()
if B.dim() == 3:
B = rearrange(B, "b dstate l -> b 1 dstate l")
ctx.squeeze_B = True
if C.dim() == 3:
C = rearrange(C, "b dstate l -> b 1 dstate l")
ctx.squeeze_C = True
out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
ctx.delta_softplus = delta_softplus
ctx.has_z = z is not None
last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
if not ctx.has_z:
ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
return out if not return_last_state else (out, last_state)
else:
ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
out_z = rest[0]
return out_z if not return_last_state else (out_z, last_state)
@staticmethod
def backward(ctx, dout, *args):
if not ctx.has_z:
u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
z = None
out = None
else:
u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
if dout.stride(-1) != 1:
dout = dout.contiguous()
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
# backward of selective_scan_cuda with the backward of chunk).
# Here we just pass in None and dz will be allocated in the C++ code.
du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus,
False # option to recompute out_z, not used here
)
dz = rest[0] if ctx.has_z else None
dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
return (du, ddelta, dA, dB, dC,
dD if D is not None else None,
dz,
ddelta_bias if delta_bias is not None else None,
None,
None)
def cross_selective_scan(
x: torch.Tensor = None,
x_proj_weight: torch.Tensor = None,
x_proj_bias: torch.Tensor = None,
dt_projs_weight: torch.Tensor = None,
dt_projs_bias: torch.Tensor = None,
A_logs: torch.Tensor = None,
Ds: torch.Tensor = None,
delta_softplus=True,
out_norm: torch.nn.Module = None,
out_norm_shape="v0",
# ==============================
to_dtype=True, # True: final out to dtype
force_fp32=False, # True: input fp32
# ==============================
nrows=-1, # for SelectiveScanNRow; 0: auto; -1: disable;
backnrows=-1, # for SelectiveScanNRow; 0: auto; -1: disable;
ssoflex=True, # True: out fp32 in SSOflex; else, SSOflex is the same as SSCore
# ==============================
# SelectiveScan=None,
CrossScan=CrossScan,
CrossMerge=CrossMerge,
):
# out_norm: whatever fits (B, L, C); LayerNorm; Sigmoid; Softmax(dim=1);...
B, D, H, W = x.shape
D, N = A_logs.shape
K, D, R = dt_projs_weight.shape
L = H * W
if nrows == 0:
if D % 4 == 0:
nrows = 4
elif D % 3 == 0:
nrows = 3
elif D % 2 == 0:
nrows = 2
else:
nrows = 1
if backnrows == 0:
if D % 4 == 0:
backnrows = 4
elif D % 3 == 0:
backnrows = 3
elif D % 2 == 0:
backnrows = 2
else:
backnrows = 1
# sacn jiushizhege bata huancheng houmian nage banben
# def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True):
# return SelectiveScan.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows, backnrows, ssoflex)
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
return_last_state=False):
"""if return_last_state is True, returns (out, last_state)
last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
not considered in the backward pass.
"""
return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
xs = CrossScan.apply(x)
x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, x_proj_weight)
if x_proj_bias is not None:
x_dbl = x_dbl + x_proj_bias.view(1, K, -1, 1)
dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
dts = torch.einsum("b k r l, k d r -> b k d l", dts, dt_projs_weight)
xs = xs.view(B, -1, L)
dts = dts.contiguous().view(B, -1, L)
As = -torch.exp(A_logs.to(torch.float)) # (k * c, d_state)
Bs = Bs.contiguous()
Cs = Cs.contiguous()
Ds = Ds.to(torch.float) # (K * c)
delta_bias = dt_projs_bias.view(-1).to(torch.float)
if force_fp32:
xs = xs.to(torch.float)
dts = dts.to(torch.float)
Bs = Bs.to(torch.float)
Cs = Cs.to(torch.float)
# ZSJ 这里把矩阵拆分成不同方向的序列,并进行扫描
ys: torch.Tensor = selective_scan_fn(
xs, dts,
As, Bs, Cs, Ds, z=None,
delta_bias=delta_bias,
delta_softplus=True,
return_last_state=False,
).view(B, K, -1, H, W)
# ZSJ 这里把处理之后的序列融合起来,并还原回原来的矩阵形式
y: torch.Tensor = CrossMerge.apply(ys)
if out_norm_shape in ["v1"]: # (B, C, H, W)
y = out_norm(y.view(B, -1, H, W)).permute(0, 2, 3, 1) # (B, H, W, C)
else: # (B, L, C)
y = y.transpose(dim0=1, dim1=2).contiguous() # (B, L, C)
y = out_norm(y).view(B, H, W, -1)
return (y.to(x.dtype) if to_dtype else y)
def selective_scan_flop_jit(inputs, outputs):
print_jit_input_names(inputs)
B, D, L = inputs[0].type().sizes()
N = inputs[2].type().sizes()[1]
flops = flops_selective_scan_fn(B=B, L=L, D=D, N=N, with_D=True, with_Z=False)
return flops
# =====================================================
class PatchMerging2D(nn.Module):
def __init__(self, dim, out_dim=-1, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, (2 * dim) if out_dim < 0 else out_dim, bias=False)
self.norm = norm_layer(4 * dim)
@staticmethod
def _patch_merging_pad(x: torch.Tensor):
H, W, _ = x.shape[-3:]
if (W % 2 != 0) or (H % 2 != 0):
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C
x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C
x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C
x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # ... H/2 W/2 4*C
return x
def forward(self, x):
x = self._patch_merging_pad(x)
x = self.norm(x)
x = self.reduction(x)
return x
class OSSM(nn.Module):
def __init__(
self,
# basic dims ===========
d_model=96,
d_state=16,
ssm_ratio=2.0,
dt_rank="auto",
act_layer=nn.SiLU,
# dwconv ===============
d_conv=3, # < 2 means no conv
conv_bias=True,
# ======================
dropout=0.0,
bias=False,
# dt init ==============
dt_min=0.001,
dt_max=0.1,
dt_init="random",
dt_scale=1.0,
dt_init_floor=1e-4,
initialize="v0",
# ======================
forward_type="v2",
# ======================
**kwargs,
):
factory_kwargs = {"device": None, "dtype": None}
super().__init__()
d_inner = int(ssm_ratio * d_model)
dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank
self.d_conv = d_conv
# tags for forward_type ==============================
def checkpostfix(tag, value):
ret = value[-len(tag):] == tag
if ret:
value = value[:-len(tag)]
return ret, value
self.disable_force32, forward_type = checkpostfix("no32", forward_type)
self.disable_z, forward_type = checkpostfix("noz", forward_type)
self.disable_z_act, forward_type = checkpostfix("nozact", forward_type)
# softmax | sigmoid | dwconv | norm ===========================
if forward_type[-len("none"):] == "none":
forward_type = forward_type[:-len("none")]
self.out_norm = nn.Identity()
elif forward_type[-len("dwconv3"):] == "dwconv3":
forward_type = forward_type[:-len("dwconv3")]
self.out_norm = nn.Conv2d(d_inner, d_inner, kernel_size=3, padding=1, groups=d_inner, bias=False)
self.out_norm_shape = "v1"
elif forward_type[-len("softmax"):] == "softmax":
forward_type = forward_type[:-len("softmax")]
self.out_norm = nn.Softmax(dim=1)
elif forward_type[-len("sigmoid"):] == "sigmoid":
forward_type = forward_type[:-len("sigmoid")]
self.out_norm = nn.Sigmoid()
else:
self.out_norm = nn.LayerNorm(d_inner)
# # forward_type debug =======================================
# FORWARD_TYPES = dict(
# v0=self.forward_corev0,
# # v2=partial(self.forward_corev2, force_fp32=(not self.disable_force32), SelectiveScan=SelectiveScanCore),
# v2=partial(self.forward_corev2, force_fp32=True, SelectiveScan=SelectiveScanCore),
# v3=partial(self.forward_corev2, force_fp32=False, SelectiveScan=SelectiveScanOflex),
# v31d=partial(self.forward_corev2, force_fp32=False, SelectiveScan=SelectiveScanOflex, cross_selective_scan=partial(
# cross_selective_scan, CrossScan=CrossScan_Ab_1direction, CrossMerge=CrossMerge_Ab_1direction,
# )),
# v32d=partial(self.forward_corev2, force_fp32=False, SelectiveScan=SelectiveScanOflex, cross_selective_scan=partial(
# cross_selective_scan, CrossScan=CrossScan_Ab_2direction, CrossMerge=CrossMerge_Ab_2direction,
# )),
# # ===============================
# fake=partial(self.forward_corev2, force_fp32=(not self.disable_force32), SelectiveScan=SelectiveScanFake),
# v1=partial(self.forward_corev2, force_fp32=True, SelectiveScan=SelectiveScanOflex),
# v01=partial(self.forward_corev2, force_fp32=(not self.disable_force32), SelectiveScan=SelectiveScanMamba),
# )
# if forward_type.startswith("debug"):
# from .ss2d_ablations import SS2D_ForwardCoreSpeedAblations, SS2D_ForwardCoreModeAblations, cross_selective_scanv2
# FORWARD_TYPES.update(dict(
# debugforward_core_mambassm_seq=partial(SS2D_ForwardCoreSpeedAblations.forward_core_mambassm_seq, self),
# debugforward_core_mambassm=partial(SS2D_ForwardCoreSpeedAblations.forward_core_mambassm, self),
# debugforward_core_mambassm_fp16=partial(SS2D_ForwardCoreSpeedAblations.forward_core_mambassm_fp16, self),
# debugforward_core_mambassm_fusecs=partial(SS2D_ForwardCoreSpeedAblations.forward_core_mambassm_fusecs, self),
# debugforward_core_mambassm_fusecscm=partial(SS2D_ForwardCoreSpeedAblations.forward_core_mambassm_fusecscm, self),
# debugforward_core_sscore_fusecscm=partial(SS2D_ForwardCoreSpeedAblations.forward_core_sscore_fusecscm, self),
# debugforward_core_sscore_fusecscm_fwdnrow=partial(SS2D_ForwardCoreSpeedAblations.forward_core_ssnrow_fusecscm_fwdnrow, self),
# debugforward_core_sscore_fusecscm_bwdnrow=partial(SS2D_ForwardCoreSpeedAblations.forward_core_ssnrow_fusecscm_bwdnrow, self),
# debugforward_core_sscore_fusecscm_fbnrow=partial(SS2D_ForwardCoreSpeedAblations.forward_core_ssnrow_fusecscm_fbnrow, self),
# debugforward_core_ssoflex_fusecscm=partial(SS2D_ForwardCoreSpeedAblations.forward_core_ssoflex_fusecscm, self),
# debugforward_core_ssoflex_fusecscm_i16o32=partial(SS2D_ForwardCoreSpeedAblations.forward_core_ssoflex_fusecscm_i16o32, self),
# debugscan_sharessm=partial(self.forward_corev2, force_fp32=False, SelectiveScan=SelectiveScanOflex, cross_selective_scan=cross_selective_scanv2),
# ))
# self.forward_core = FORWARD_TYPES.get(forward_type, None)
# ZSJ k_group 指的是扫描的方向
# k_group = 4 if forward_type not in ["debugscan_sharessm"] else 1
k_group = 8 if forward_type not in ["debugscan_sharessm"] else 1
# in proj =======================================
d_proj = d_inner if self.disable_z else (d_inner * 2)
self.in_proj = nn.Linear(d_model, d_proj, bias=bias, **factory_kwargs)
self.act: nn.Module = act_layer()
# conv =======================================
if d_conv > 1:
self.conv2d = nn.Conv2d(
in_channels=d_inner,
out_channels=d_inner,
groups=d_inner,
bias=conv_bias,
kernel_size=d_conv,
padding=(d_conv - 1) // 2,
**factory_kwargs,
)
# x proj ============================
self.x_proj = [
nn.Linear(d_inner, (dt_rank + d_state * 2), bias=False, **factory_kwargs)
for _ in range(k_group)
]
self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner)
del self.x_proj
# out proj =======================================
self.out_proj = nn.Linear(d_inner, d_model, bias=bias, **factory_kwargs)
self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()
if initialize in ["v0"]:
# dt proj ============================
self.dt_projs = [
self.dt_init(dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs)
for _ in range(k_group)
]
self.dt_projs_weight = nn.Parameter(
torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K, inner, rank)
self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K, inner)
del self.dt_projs
# A, D =======================================
self.A_logs = self.A_log_init(d_state, d_inner, copies=k_group, merge=True) # (K * D, N)
self.Ds = self.D_init(d_inner, copies=k_group, merge=True) # (K * D)
elif initialize in ["v1"]:
# simple init dt_projs, A_logs, Ds
self.Ds = nn.Parameter(torch.ones((k_group * d_inner)))
self.A_logs = nn.Parameter(
torch.randn((k_group * d_inner, d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
self.dt_projs_weight = nn.Parameter(torch.randn((k_group, d_inner, dt_rank)))
self.dt_projs_bias = nn.Parameter(torch.randn((k_group, d_inner)))
elif initialize in ["v2"]:
# simple init dt_projs, A_logs, Ds
self.Ds = nn.Parameter(torch.ones((k_group * d_inner)))
self.A_logs = nn.Parameter(
torch.zeros((k_group * d_inner, d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
self.dt_projs_weight = nn.Parameter(torch.randn((k_group, d_inner, dt_rank)))
self.dt_projs_bias = nn.Parameter(torch.randn((k_group, d_inner)))
@staticmethod
def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4,
**factory_kwargs):
dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)
# Initialize special dt projection to preserve variance at initialization
dt_init_std = dt_rank ** -0.5 * dt_scale
if dt_init == "constant":
nn.init.constant_(dt_proj.weight, dt_init_std)
elif dt_init == "random":
nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
else:
raise NotImplementedError
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
dt = torch.exp(
torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
).clamp(min=dt_init_floor)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
with torch.no_grad():
dt_proj.bias.copy_(inv_dt)
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
# dt_proj.bias._no_reinit = True
return dt_proj
@staticmethod
def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True):
# S4D real initialization
A = repeat(
torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
"n -> d n",
d=d_inner,
).contiguous()
A_log = torch.log(A) # Keep A_log in fp32
if copies > 0:
A_log = repeat(A_log, "d n -> r d n", r=copies)
if merge:
A_log = A_log.flatten(0, 1)
A_log = nn.Parameter(A_log)
A_log._no_weight_decay = True
return A_log
@staticmethod
def D_init(d_inner, copies=-1, device=None, merge=True):
# D "skip" parameter
D = torch.ones(d_inner, device=device)
if copies > 0:
D = repeat(D, "n1 -> r n1", r=copies)
if merge:
D = D.flatten(0, 1)
D = nn.Parameter(D) # Keep in fp32
D._no_weight_decay = True
return D
# only used to run previous version
# def forward_corev0(self, x: torch.Tensor, to_dtype=False, channel_first=False):
# def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True, nrows=1):
# return SelectiveScanCore.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows, False)
#
# if not channel_first:
# x = x.permute(0, 3, 1, 2).contiguous()
# B, D, H, W = x.shape
# D, N = self.A_logs.shape
# K, D, R = self.dt_projs_weight.shape
# L = H * W
#
# # ZSJ 这里进行data expand操作也就是把相同的数据在不同方向展开成一维并拼接起来,但是这个函数只用在旧版本
# # 把横向和竖向拼接在K维度
# x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
# # torch.flip把横向和竖向两个方向都进行反向操作
# xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l)
#
# x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight)
# # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
# dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
# dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight)
#
# xs = xs.float().view(B, -1, L) # (b, k * d, l)
# dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l)
# Bs = Bs.float() # (b, k, d_state, l)
# Cs = Cs.float() # (b, k, d_state, l)
#
# As = -torch.exp(self.A_logs.float()) # (k * d, d_state)
# Ds = self.Ds.float() # (k * d)
# dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d)
#
# # assert len(xs.shape) == 3 and len(dts.shape) == 3 and len(Bs.shape) == 4 and len(Cs.shape) == 4
# # assert len(As.shape) == 2 and len(Ds.shape) == 1 and len(dt_projs_bias.shape) == 1
#
# out_y = selective_scan(
# xs, dts,
# As, Bs, Cs, Ds,
# delta_bias=dt_projs_bias,
# delta_softplus=True,
# ).view(B, K, -1, L)
# # assert out_y.dtype == torch.float
#
# inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
# wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
# invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
# y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y
# y = y.transpose(dim0=1, dim1=2).contiguous() # (B, L, C)
# y = self.out_norm(y).view(B, H, W, -1)
#
# return (y.to(x.dtype) if to_dtype else y)
def forward_corev2(self, x: torch.Tensor, channel_first=False, cross_selective_scan=cross_selective_scan,
force_fp32=None):
if not channel_first:
x = x.permute(0, 3, 1, 2).contiguous()
# ZSJ V2版本使用的mamba要改扫描方向在这里改
x = cross_selective_scan(
x, self.x_proj_weight, None, self.dt_projs_weight, self.dt_projs_bias,
self.A_logs, self.Ds, delta_softplus=True,
out_norm=getattr(self, "out_norm", None),
out_norm_shape=getattr(self, "out_norm_shape", "v0"),
force_fp32=force_fp32,
)
return x
def forward(self, x: torch.Tensor, **kwargs):
with_dconv = (self.d_conv > 1)
x = self.in_proj(x)
if not self.disable_z:
x, z = x.chunk(2, dim=-1) # (b, h, w, d)
if not self.disable_z_act:
z = self.act(z)
if with_dconv:
x = x.permute(0, 3, 1, 2).contiguous()
x = self.conv2d(x) # (b, d, h, w)
x = self.act(x)
y = self.forward_corev2(x, channel_first=with_dconv)
if not self.disable_z:
y = y * z
out = self.dropout(self.out_proj(y))
return out
class Permute(nn.Module):
def __init__(self, *args):
super().__init__()
self.args = args
def forward(self, x: torch.Tensor):
return x.permute(*self.args)
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,
channels_first=False):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
Linear = partial(nn.Conv2d, kernel_size=1, padding=0) if channels_first else nn.Linear
self.fc1 = Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class OSSBlock(nn.Module):
def __init__(
self,
hidden_dim: int = 0,
drop_path: float = 0,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
# =============================
ssm_d_state: int = 16,
ssm_ratio=2.0,
ssm_dt_rank: Any = "auto",
ssm_act_layer=nn.SiLU,
ssm_conv: int = 3,
ssm_conv_bias=True,
ssm_drop_rate: float = 0,
ssm_init="v0",
forward_type="v2",
# =============================
mlp_ratio=4.0,
mlp_act_layer=nn.GELU,
mlp_drop_rate: float = 0.0,
# =============================
use_checkpoint: bool = False,
post_norm: bool = False,
**kwargs,
):
super().__init__()
self.ssm_branch = ssm_ratio > 0
self.mlp_branch = mlp_ratio > 0
self.use_checkpoint = use_checkpoint
self.post_norm = post_norm
try:
from ss2d_ablations import SS2DDev
_OSSM = SS2DDev if forward_type.startswith("dev") else OSSM
except:
_OSSM = OSSM
if self.ssm_branch:
self.norm = norm_layer(hidden_dim)
self.op = _OSSM(
d_model=hidden_dim,
d_state=ssm_d_state,
ssm_ratio=ssm_ratio,
dt_rank=ssm_dt_rank,
act_layer=ssm_act_layer,
# ==========================
d_conv=ssm_conv,
conv_bias=ssm_conv_bias,
# ==========================
dropout=ssm_drop_rate,
# bias=False,
# ==========================
# dt_min=0.001,
# dt_max=0.1,
# dt_init="random",
# dt_scale="random",
# dt_init_floor=1e-4,
initialize=ssm_init,
# ==========================
forward_type=forward_type,
)
self.drop_path = DropPath(drop_path)
if self.mlp_branch:
self.norm2 = norm_layer(hidden_dim)
mlp_hidden_dim = int(hidden_dim * mlp_ratio)
self.mlp = Mlp(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=mlp_act_layer,
drop=mlp_drop_rate, channels_first=False)
def _forward(self, input: torch.Tensor):
if self.ssm_branch:
if self.post_norm:
x = input + self.drop_path(self.norm(self.op(input)))
else:
x = input + self.drop_path(self.op(self.norm(input)))
if self.mlp_branch:
if self.post_norm:
x = x + self.drop_path(self.norm2(self.mlp(x))) # FFN
else:
x = x + self.drop_path(self.mlp(self.norm2(x))) # FFN
return x
def forward(self, input: torch.Tensor):
if self.use_checkpoint:
return checkpoint.checkpoint(self._forward, input)
else:
return self._forward(input)
class Decoder_Block(nn.Module):
"""Basic block in decoder."""
def __init__(self, in_channel, out_channel):
super().__init__()
assert out_channel == in_channel // 2, 'the out_channel is not in_channel//2 in decoder block'
self.up = nn.Upsample(scale_factor=2, mode='nearest')
self.fuse = nn.Sequential(nn.Conv2d(in_channels=in_channel + out_channel, out_channels=out_channel,
kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(out_channel),
nn.ReLU(inplace=True),
)
def forward(self, de, en):
de = self.up(de)
output = torch.cat([de, en], dim=1)
output = self.fuse(output)
return output
@MODELS.register_module()
class RSM_SS(nn.Module):
def __init__(
self,
patch_size=4,
in_chans=3,
num_classes=1000,
depths=[2, 2, 9, 2],
dims=[96, 192, 384, 768],
# =========================
ssm_d_state=16,
ssm_ratio=2.0,
ssm_dt_rank="auto",
ssm_act_layer="silu",
ssm_conv=3,
ssm_conv_bias=True,
ssm_drop_rate=0.0,
ssm_init="v0",
forward_type="v2",
# =========================
mlp_ratio=4.0,
mlp_act_layer="gelu",
mlp_drop_rate=0.0,
# =========================
drop_path_rate=0.1,
patch_norm=True,
norm_layer="LN",
use_checkpoint=False,
**kwargs,
):
super().__init__()
self.num_classes = num_classes
self.num_layers = len(depths)
if isinstance(dims, int):
dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)]
self.num_features = dims[-1]
self.dims = dims
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
_NORMLAYERS = dict(
ln=nn.LayerNorm,
bn=nn.BatchNorm2d,
)
_ACTLAYERS = dict(
silu=nn.SiLU,
gelu=nn.GELU,
relu=nn.ReLU,
sigmoid=nn.Sigmoid,
)
if isinstance(norm_layer, str) and norm_layer.lower() in ["ln"]:
norm_layer: nn.Module = _NORMLAYERS[norm_layer.lower()]
if isinstance(ssm_act_layer, str) and ssm_act_layer.lower() in ["silu", "gelu", "relu"]:
ssm_act_layer: nn.Module = _ACTLAYERS[ssm_act_layer.lower()]
if isinstance(mlp_act_layer, str) and mlp_act_layer.lower() in ["silu", "gelu", "relu"]:
mlp_act_layer: nn.Module = _ACTLAYERS[mlp_act_layer.lower()]
_make_patch_embed = self._make_patch_embed_v2
self.patch_embed = _make_patch_embed(in_chans, dims[0], patch_size, patch_norm, norm_layer)
_make_downsample = self._make_downsample_v3
# self.encoder_layers = [nn.ModuleList()] * self.num_layers
self.encoder_layers = []
self.decoder_layers = []
for i_layer in range(self.num_layers):
# downsample = _make_downsample(
# self.dims[i_layer],
# self.dims[i_layer + 1],
# norm_layer=norm_layer,
# ) if (i_layer < self.num_layers - 1) else nn.Identity()
downsample = _make_downsample(
self.dims[i_layer - 1],
self.dims[i_layer],
norm_layer=norm_layer,
) if (i_layer != 0) else nn.Identity() # ZSJ 修改为i_layer != 0也就是第一层不下采样和论文的图保持一致也方便我取出每个尺度处理好的特征
self.encoder_layers.append(self._make_layer(
dim=self.dims[i_layer],
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
use_checkpoint=use_checkpoint,
norm_layer=norm_layer,
downsample=downsample,
# =================
ssm_d_state=ssm_d_state,
ssm_ratio=ssm_ratio,
ssm_dt_rank=ssm_dt_rank,
ssm_act_layer=ssm_act_layer,
ssm_conv=ssm_conv,
ssm_conv_bias=ssm_conv_bias,
ssm_drop_rate=ssm_drop_rate,
ssm_init=ssm_init,
forward_type=forward_type,
# =================
mlp_ratio=mlp_ratio,
mlp_act_layer=mlp_act_layer,
mlp_drop_rate=mlp_drop_rate,
))
if i_layer != 0:
self.decoder_layers.append(
Decoder_Block(in_channel=self.dims[i_layer], out_channel=self.dims[i_layer - 1]))
self.encoder_block1, self.encoder_block2, self.encoder_block3, self.encoder_block4 = self.encoder_layers
self.deocder_block1, self.deocder_block2, self.deocder_block3 = self.decoder_layers
self.upsample_x4 = nn.Sequential(
nn.Conv2d(self.dims[0], self.dims[0] // 2, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(self.dims[0] // 2),
nn.ReLU(inplace=True),
nn.UpsamplingBilinear2d(scale_factor=2),
# nn.Conv2d(self.dims[0] // 2, 8, kernel_size=3, stride=1, padding=1),
nn.Conv2d(self.dims[0] // 2, self.dims[0] // 2, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(48),
nn.ReLU(inplace=True),
nn.UpsamplingBilinear2d(scale_factor=2)
)
# self.conv_out_seg = nn.Conv2d(8, 1, kernel_size=7, stride=1, padding=3)
self.apply(self._init_weights)
def _init_weights(self, m: nn.Module):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@staticmethod
def _make_patch_embed_v2(in_chans=3, embed_dim=96, patch_size=4, patch_norm=True, norm_layer=nn.LayerNorm):
assert patch_size == 4
return nn.Sequential(
nn.Conv2d(in_chans, embed_dim // 2, kernel_size=3, stride=2, padding=1),
(Permute(0, 2, 3, 1) if patch_norm else nn.Identity()),
(norm_layer(embed_dim // 2) if patch_norm else nn.Identity()),
(Permute(0, 3, 1, 2) if patch_norm else nn.Identity()),
nn.GELU(),
nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1),
Permute(0, 2, 3, 1),
(norm_layer(embed_dim) if patch_norm else nn.Identity()),
)
@staticmethod
def _make_downsample_v3(dim=96, out_dim=192, norm_layer=nn.LayerNorm):
return nn.Sequential(
Permute(0, 3, 1, 2),
nn.Conv2d(dim, out_dim, kernel_size=3, stride=2, padding=1),
Permute(0, 2, 3, 1),
norm_layer(out_dim),
)
@staticmethod
def _make_layer(
dim=96,
drop_path=[0.1, 0.1],
use_checkpoint=False,
norm_layer=nn.LayerNorm,
downsample=nn.Identity(),
# ===========================
ssm_d_state=16,
ssm_ratio=2.0,
ssm_dt_rank="auto",
ssm_act_layer=nn.SiLU,
ssm_conv=3,
ssm_conv_bias=True,
ssm_drop_rate=0.0,
ssm_init="v0",
forward_type="v2",
# ===========================
mlp_ratio=4.0,
mlp_act_layer=nn.GELU,
mlp_drop_rate=0.0,
**kwargs,
):
depth = len(drop_path)
blocks = []
for d in range(depth):
blocks.append(OSSBlock(
hidden_dim=dim,
drop_path=drop_path[d],
norm_layer=norm_layer,
ssm_d_state=ssm_d_state,
ssm_ratio=ssm_ratio,
ssm_dt_rank=ssm_dt_rank,
ssm_act_layer=ssm_act_layer,
ssm_conv=ssm_conv,
ssm_conv_bias=ssm_conv_bias,
ssm_drop_rate=ssm_drop_rate,
ssm_init=ssm_init,
forward_type=forward_type,
mlp_ratio=mlp_ratio,
mlp_act_layer=mlp_act_layer,
mlp_drop_rate=mlp_drop_rate,
use_checkpoint=use_checkpoint,
))
return nn.Sequential(OrderedDict(
# ZSJ 把downsample放到前面来方便我取出encoder中每个尺度处理好的图像而不是刚刚下采样完的图像
downsample=downsample,
blocks=nn.Sequential(*blocks, ),
))
def forward(self, x1: torch.Tensor):
x1 = self.patch_embed(x1)
x1_1 = self.encoder_block1(x1)
x1_2 = self.encoder_block2(x1_1)
x1_3 = self.encoder_block3(x1_2)
x1_4 = self.encoder_block4(x1_3) # b,h,w,c
x1_1 = rearrange(x1_1, "b h w c -> b c h w").contiguous()
x1_2 = rearrange(x1_2, "b h w c -> b c h w").contiguous()
x1_3 = rearrange(x1_3, "b h w c -> b c h w").contiguous()
x1_4 = rearrange(x1_4, "b h w c -> b c h w").contiguous()
decode_3 = self.deocder_block3(x1_4, x1_3)
decode_2 = self.deocder_block2(decode_3, x1_2)
decode_1 = self.deocder_block1(decode_2, x1_1)
output = self.upsample_x4(decode_1)
# output = self.conv_out_seg(output)
return output
# if __name__=="__main__":
# net=RSM_SS().to("cuda:1")
# img=torch.randn(2,3,512,512).to("cuda:1")
# out=net(img)
# print(out.size())
if __name__ == '__main__':
# Comment batchnorms here and in model_utils before testing speed since the batchnorm could be integrated into conv operation
# (do not comment all, just the batchnorm following its corresponding conv layer)
device = torch.device('cuda:0')
model = RSM_SS()
model.eval()
model.to(device)
iterations = None
input = torch.randn(1, 3, 512, 512).to(device)
with torch.no_grad():
for _ in range(10):
model(input)
if iterations is None:
elapsed_time = 0
iterations = 100
while elapsed_time < 1:
torch.cuda.synchronize()
torch.cuda.synchronize()
t_start = time.time()
for _ in range(iterations):
model(input)
torch.cuda.synchronize()
torch.cuda.synchronize()
elapsed_time = time.time() - t_start
iterations *= 2
FPS = iterations / elapsed_time
iterations = int(FPS * 6)
print('=========Speed Testing=========')
torch.cuda.synchronize()
torch.cuda.synchronize()
t_start = time.time()
for _ in range(iterations):
model(input)
torch.cuda.synchronize()
torch.cuda.synchronize()
elapsed_time = time.time() - t_start
latency = elapsed_time / iterations * 1000
torch.cuda.empty_cache()
FPS = 1000 / latency
print(FPS)