import os
import time
import math
import copy
from functools import partial
from typing import Optional, Callable, Any
from collections import OrderedDict
from mmseg.registry import MODELS
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange, repeat
from timm.models.layers import DropPath, trunc_normal_

# from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count, parameter_count
DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"


# import selective scan ==============================
try:
    import selective_scan_cuda_oflex
except Exception as e:
    ...
    # print(f"WARNING: can not import selective_scan_cuda_oflex.", flush=True)
    # print(e, flush=True)

try:
    import selective_scan_cuda_core
except Exception as e:
    ...
    # print(f"WARNING: can not import selective_scan_cuda_core.", flush=True)
    # print(e, flush=True)

try:
    import selective_scan_cuda
except Exception as e:
    ...
    # print(f"WARNING: can not import selective_scan_cuda.", flush=True)
    # print(e, flush=True)


# fvcore flops =======================================
def flops_selective_scan_fn(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_complex=False):
    """
    u: r(B D L)
    delta: r(B D L)
    A: r(D N)
    B: r(B N L)
    C: r(B N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32

    ignores:
    [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
    """
    assert not with_complex
    # https://github.com/state-spaces/mamba/issues/110
    flops = 9 * B * L * D * N
    if with_D:
        flops += B * D * L
    if with_Z:
        flops += B * D * L
    return flops

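# Quick illustrative check of the estimate above: with the default arguments
# (B=1, L=256, D=768, N=16, with_D=True, with_Z=False) the count is the scan
# term plus the D skip term, 9*1*256*768*16 + 1*768*256 = 28_508_160, i.e.
#
#   >>> flops_selective_scan_fn()
#   28508160
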
# this is only for selective_scan_ref...
|
|||
|
|
def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False):
|
|||
|
|
"""
|
|||
|
|
u: r(B D L)
|
|||
|
|
delta: r(B D L)
|
|||
|
|
A: r(D N)
|
|||
|
|
B: r(B N L)
|
|||
|
|
C: r(B N L)
|
|||
|
|
D: r(D)
|
|||
|
|
z: r(B D L)
|
|||
|
|
delta_bias: r(D), fp32
|
|||
|
|
|
|||
|
|
ignores:
|
|||
|
|
[.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
|
|||
|
|
"""
|
|||
|
|
import numpy as np
|
|||
|
|
|
|||
|
|
# fvcore.nn.jit_handles
|
|||
|
|
def get_flops_einsum(input_shapes, equation):
|
|||
|
|
np_arrs = [np.zeros(s) for s in input_shapes]
|
|||
|
|
optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1]
|
|||
|
|
for line in optim.split("\n"):
|
|||
|
|
if "optimized flop" in line.lower():
|
|||
|
|
# divided by 2 because we count MAC (multiply-add counted as one flop)
|
|||
|
|
flop = float(np.floor(float(line.split(":")[-1]) / 2))
|
|||
|
|
return flop
|
|||
|
|
|
|||
|
|
assert not with_complex
|
|||
|
|
|
|||
|
|
flops = 0 # below code flops = 0
|
|||
|
|
|
|||
|
|
flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln")
|
|||
|
|
if with_Group:
|
|||
|
|
flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln")
|
|||
|
|
else:
|
|||
|
|
flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln")
|
|||
|
|
|
|||
|
|
in_for_flops = B * D * N
|
|||
|
|
if with_Group:
|
|||
|
|
in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd")
|
|||
|
|
else:
|
|||
|
|
in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd")
|
|||
|
|
flops += L * in_for_flops
|
|||
|
|
if with_D:
|
|||
|
|
flops += B * D * L
|
|||
|
|
if with_Z:
|
|||
|
|
flops += B * D * L
|
|||
|
|
return flops
|
|||
|
|
|
|||
|
|
|
|||
|
|
def print_jit_input_names(inputs):
|
|||
|
|
print("input params: ", end=" ", flush=True)
|
|||
|
|
try:
|
|||
|
|
for i in range(10):
|
|||
|
|
print(inputs[i].debugName(), end=" ", flush=True)
|
|||
|
|
except Exception as e:
|
|||
|
|
pass
|
|||
|
|
print("", flush=True)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# cross selective scan ===============================
|
|||
|
|
class SelectiveScanMamba(torch.autograd.Function):
|
|||
|
|
# comment all checks if inside cross_selective_scan
|
|||
|
|
@staticmethod
|
|||
|
|
@torch.cuda.amp.custom_fwd
|
|||
|
|
def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1,
|
|||
|
|
oflex=True):
|
|||
|
|
# assert nrows in [1, 2, 3, 4], f"{nrows}" # 8+ is too slow to compile
|
|||
|
|
# assert u.shape[1] % (B.shape[1] * nrows) == 0, f"{nrows}, {u.shape}, {B.shape}"
|
|||
|
|
ctx.delta_softplus = delta_softplus
|
|||
|
|
# all in float
|
|||
|
|
# if u.stride(-1) != 1:
|
|||
|
|
# u = u.contiguous()
|
|||
|
|
# if delta.stride(-1) != 1:
|
|||
|
|
# delta = delta.contiguous()
|
|||
|
|
# if D is not None and D.stride(-1) != 1:
|
|||
|
|
# D = D.contiguous()
|
|||
|
|
# if B.stride(-1) != 1:
|
|||
|
|
# B = B.contiguous()
|
|||
|
|
# if C.stride(-1) != 1:
|
|||
|
|
# C = C.contiguous()
|
|||
|
|
# if B.dim() == 3:
|
|||
|
|
# B = B.unsqueeze(dim=1)
|
|||
|
|
# ctx.squeeze_B = True
|
|||
|
|
# if C.dim() == 3:
|
|||
|
|
# C = C.unsqueeze(dim=1)
|
|||
|
|
# ctx.squeeze_C = True
|
|||
|
|
|
|||
|
|
out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus)
|
|||
|
|
ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
|
|||
|
|
return out
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
@torch.cuda.amp.custom_bwd
|
|||
|
|
def backward(ctx, dout, *args):
|
|||
|
|
u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
|
|||
|
|
if dout.stride(-1) != 1:
|
|||
|
|
dout = dout.contiguous()
|
|||
|
|
|
|||
|
|
du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
|
|||
|
|
u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus,
|
|||
|
|
False
|
|||
|
|
)
|
|||
|
|
# dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
|
|||
|
|
# dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
|
|||
|
|
return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# class SelectiveScanCore(torch.autograd.Function):
|
|||
|
|
# # comment all checks if inside cross_selective_scan
|
|||
|
|
# @staticmethod
|
|||
|
|
# @torch.cuda.amp.custom_fwd
|
|||
|
|
# def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True):
|
|||
|
|
# ctx.delta_softplus = delta_softplus
|
|||
|
|
# out, x, *rest = selective_scan_cuda_core.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1)
|
|||
|
|
# ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
|
|||
|
|
# return out
|
|||
|
|
#
|
|||
|
|
# @staticmethod
|
|||
|
|
# @torch.cuda.amp.custom_bwd
|
|||
|
|
# def backward(ctx, dout, *args):
|
|||
|
|
# u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
|
|||
|
|
# if dout.stride(-1) != 1:
|
|||
|
|
# dout = dout.contiguous()
|
|||
|
|
# du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_core.bwd(
|
|||
|
|
# u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
|
|||
|
|
# )
|
|||
|
|
# return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None)
|
|||
|
|
#
|
|||
|
|
#
|
|||
|
|
# class SelectiveScanOflex(torch.autograd.Function):
|
|||
|
|
# # comment all checks if inside cross_selective_scan
|
|||
|
|
# @staticmethod
|
|||
|
|
# @torch.cuda.amp.custom_fwd
|
|||
|
|
# def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True):
|
|||
|
|
# ctx.delta_softplus = delta_softplus
|
|||
|
|
# out, x, *rest = selective_scan_cuda_oflex.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1, oflex)
|
|||
|
|
# ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
|
|||
|
|
# return out
|
|||
|
|
#
|
|||
|
|
# @staticmethod
|
|||
|
|
# @torch.cuda.amp.custom_bwd
|
|||
|
|
# def backward(ctx, dout, *args):
|
|||
|
|
# u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
|
|||
|
|
# if dout.stride(-1) != 1:
|
|||
|
|
# dout = dout.contiguous()
|
|||
|
|
# du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_oflex.bwd(
|
|||
|
|
# u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
|
|||
|
|
# )
|
|||
|
|
# return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None)
|
|||
|
|
#
|
|||
|
|
#
|
|||
|
|
# class SelectiveScanFake(torch.autograd.Function):
|
|||
|
|
# # comment all checks if inside cross_selective_scan
|
|||
|
|
# @staticmethod
|
|||
|
|
# @torch.cuda.amp.custom_fwd
|
|||
|
|
# def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True):
|
|||
|
|
# ctx.delta_softplus = delta_softplus
|
|||
|
|
# ctx.backnrows = backnrows
|
|||
|
|
# x = delta
|
|||
|
|
# out = u
|
|||
|
|
# ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
|
|||
|
|
# return out
|
|||
|
|
#
|
|||
|
|
# @staticmethod
|
|||
|
|
# @torch.cuda.amp.custom_bwd
|
|||
|
|
# def backward(ctx, dout, *args):
|
|||
|
|
# u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
|
|||
|
|
# if dout.stride(-1) != 1:
|
|||
|
|
# dout = dout.contiguous()
|
|||
|
|
# du, ddelta, dA, dB, dC, dD, ddelta_bias = u * 0, delta * 0, A * 0, B * 0, C * 0, C * 0, (D * 0 if D else None), (delta_bias * 0 if delta_bias else None)
|
|||
|
|
# return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None)
|
|||
|
|
|
|||
|
|
# =============
def antidiagonal_gather(tensor):
    # Gather all anti-diagonal elements of the matrix and concatenate them
    B, C, H, W = tensor.size()
    shift = torch.arange(H, device=tensor.device).unsqueeze(1)  # column vector [H, 1]
    index = (torch.arange(W, device=tensor.device) - shift) % W  # index matrix [H, W] built via broadcasting
    # expand the index to cover the B and C dimensions
    expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
    # select the elements with gather
    return tensor.gather(3, expanded_index).transpose(-1, -2).reshape(B, C, H * W)


def diagonal_gather(tensor):
    # Gather all diagonal elements of the matrix and concatenate them
    B, C, H, W = tensor.size()
    shift = torch.arange(H, device=tensor.device).unsqueeze(1)  # column vector [H, 1]
    index = (shift + torch.arange(W, device=tensor.device)) % W  # index matrix [H, W] built via broadcasting
    # expand the index to cover the B and C dimensions
    expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
    # select the elements with gather
    return tensor.gather(3, expanded_index).transpose(-1, -2).reshape(B, C, H * W)


def diagonal_scatter(tensor_flat, original_shape):
    # Restore the flattened diagonal sequence back to its original matrix form
    B, C, H, W = original_shape
    shift = torch.arange(H, device=tensor_flat.device).unsqueeze(1)  # column vector [H, 1]
    index = (shift + torch.arange(W, device=tensor_flat.device)) % W  # index matrix [H, W] built via broadcasting
    # expand the index to cover the B and C dimensions
    expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
    # empty tensor that receives the scattered-back result
    result_tensor = torch.zeros(B, C, H, W, device=tensor_flat.device, dtype=tensor_flat.dtype)
    # reshape the flat tensor to [B, C, W, H] and swap H and W back with transpose
    tensor_reshaped = tensor_flat.reshape(B, C, W, H).transpose(-1, -2)
    # scatter_ puts the elements back to their original positions according to expanded_index
    result_tensor.scatter_(3, expanded_index, tensor_reshaped)
    return result_tensor


def antidiagonal_scatter(tensor_flat, original_shape):
    # Restore the flattened anti-diagonal sequence back to its original matrix form
    B, C, H, W = original_shape
    shift = torch.arange(H, device=tensor_flat.device).unsqueeze(1)  # column vector [H, 1]
    index = (torch.arange(W, device=tensor_flat.device) - shift) % W  # index matrix [H, W] built via broadcasting
    expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1)
    # all-zero tensor with the same shape as the original
    result_tensor = torch.zeros(B, C, H, W, device=tensor_flat.device, dtype=tensor_flat.dtype)
    # reshape the flat tensor to [B, C, W, H] and swap the last two dims, since gathering ran along the last dimension
    tensor_reshaped = tensor_flat.reshape(B, C, W, H).transpose(-1, -2)
    # scatter_ puts the elements back to their original positions according to the index
    result_tensor.scatter_(3, expanded_index, tensor_reshaped)
    return result_tensor


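# Illustrative round-trip check (a small sketch, not exhaustive): for each row h,
# the gather index (h + w) % W (resp. (w - h) % W) is a permutation of the columns,
# so the matching scatter inverts the gather exactly.
#
#   >>> t = torch.arange(9.).view(1, 1, 3, 3)
#   >>> torch.equal(diagonal_scatter(diagonal_gather(t), t.shape), t)
#   True
#   >>> torch.equal(antidiagonal_scatter(antidiagonal_gather(t), t.shape), t)
#   True

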
class CrossScan(torch.autograd.Function):
    # ZSJ: this is where the feature map is flattened along the chosen scan directions;
    # modify the scan directions here if needed
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        B, C, H, W = x.shape
        ctx.shape = (B, C, H, W)
        # xs = x.new_empty((B, 4, C, H * W))
        xs = x.new_empty((B, 8, C, H * W))
        # horizontal and vertical scans
        xs[:, 0] = x.flatten(2, 3)
        xs[:, 1] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
        xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])

        # diagonal and anti-diagonal scans
        xs[:, 4] = diagonal_gather(x)
        xs[:, 5] = antidiagonal_gather(x)
        xs[:, 6:8] = torch.flip(xs[:, 4:6], dims=[-1])

        return xs

    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        # out: (b, k, d, l)
        B, C, H, W = ctx.shape
        L = H * W
        # flip the reversed horizontal/vertical parts back and add them to the forward ones
        # ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
        y_rb = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
        # turn the vertical part back into row-major layout, add it, and restore the original matrix form
        # y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
        y_rb = y_rb[:, 0] + y_rb[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
        y_rb = y_rb.view(B, -1, H, W)

        # flip the reversed diagonal/anti-diagonal parts back and add them to the forward ones
        y_da = ys[:, 4:6] + ys[:, 6:8].flip(dims=[-1]).view(B, 2, -1, L)
        # scatter the diagonal and anti-diagonal parts back into matrix form, then add them
        y_da = diagonal_scatter(y_da[:, 0], (B, C, H, W)) + antidiagonal_scatter(y_da[:, 1], (B, C, H, W))

        y_res = y_rb + y_da
        # return y.view(B, -1, H, W)
        return y_res


class CrossMerge(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ys: torch.Tensor):
        B, K, D, H, W = ys.shape
        ctx.shape = (H, W)
        ys = ys.view(B, K, D, -1)
        # ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
        # y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)

        y_rb = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
        # turn the vertical part back into row-major layout, add it, and restore the original matrix form
        y_rb = y_rb[:, 0] + y_rb[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
        y_rb = y_rb.view(B, -1, H, W)

        # flip the reversed diagonal/anti-diagonal parts back and add them to the forward ones
        y_da = ys[:, 4:6] + ys[:, 6:8].flip(dims=[-1]).view(B, 2, D, -1)
        # scatter the diagonal and anti-diagonal parts back into matrix form, then add them
        y_da = diagonal_scatter(y_da[:, 0], (B, D, H, W)) + antidiagonal_scatter(y_da[:, 1], (B, D, H, W))

        y_res = y_rb + y_da
        return y_res.view(B, D, -1)
        # return y

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        # B, D, L = x.shape
        # out: (b, k, d, l)
        H, W = ctx.shape
        B, C, L = x.shape
        # xs = x.new_empty((B, 4, C, L))
        xs = x.new_empty((B, 8, C, L))

        # horizontal and vertical scans
        xs[:, 0] = x
        xs[:, 1] = x.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3)
        xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
        # xs = xs.view(B, 4, C, H, W)

        # diagonal and anti-diagonal scans
        xs[:, 4] = diagonal_gather(x.view(B, C, H, W))
        xs[:, 5] = antidiagonal_gather(x.view(B, C, H, W))
        xs[:, 6:8] = torch.flip(xs[:, 4:6], dims=[-1])

        # return xs
        return xs.view(B, 8, C, H, W)

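# Shape sketch (illustrative): CrossScan expands one (B, C, H, W) map into the
# eight directional sequences, and CrossMerge folds eight scanned maps back into
# a single flattened feature map.
#
#   >>> x = torch.randn(2, 16, 8, 8)
#   >>> CrossScan.apply(x).shape        # (B, 8, C, H*W)
#   torch.Size([2, 8, 16, 64])
#   >>> ys = torch.randn(2, 8, 16, 8, 8)
#   >>> CrossMerge.apply(ys).shape      # (B, C, H*W)
#   torch.Size([2, 16, 64])
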

"""
ablation exp
"""


# these are for ablations =============
|
|||
|
|
# class CrossScan_Ab_2direction(torch.autograd.Function):
|
|||
|
|
# @staticmethod
|
|||
|
|
# def forward(ctx, x: torch.Tensor):
|
|||
|
|
# B, C, H, W = x.shape
|
|||
|
|
# ctx.shape = (B, C, H, W)
|
|||
|
|
# xs = x.new_empty((B, 4, C, H * W))
|
|||
|
|
# xs[:, 0] = x.flatten(2, 3)
|
|||
|
|
# xs[:, 1] = x.flatten(2, 3)
|
|||
|
|
# xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
|
|||
|
|
# return xs
|
|||
|
|
#
|
|||
|
|
# @staticmethod
|
|||
|
|
# def backward(ctx, ys: torch.Tensor):
|
|||
|
|
# # out: (b, k, d, l)
|
|||
|
|
# B, C, H, W = ctx.shape
|
|||
|
|
# L = H * W
|
|||
|
|
# ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
|
|||
|
|
# y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
|
|||
|
|
# return y.view(B, -1, H, W)
|
|||
|
|
#
|
|||
|
|
#
|
|||
|
|
# class CrossMerge_Ab_2direction(torch.autograd.Function):
|
|||
|
|
# @staticmethod
|
|||
|
|
# def forward(ctx, ys: torch.Tensor):
|
|||
|
|
# B, K, D, H, W = ys.shape
|
|||
|
|
# ctx.shape = (H, W)
|
|||
|
|
# ys = ys.view(B, K, D, -1)
|
|||
|
|
# ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
|
|||
|
|
# y = ys.sum(dim=1)
|
|||
|
|
# return y
|
|||
|
|
#
|
|||
|
|
# @staticmethod
|
|||
|
|
# def backward(ctx, x: torch.Tensor):
|
|||
|
|
# # B, D, L = x.shape
|
|||
|
|
# # out: (b, k, d, l)
|
|||
|
|
# H, W = ctx.shape
|
|||
|
|
# B, C, L = x.shape
|
|||
|
|
# xs = x.new_empty((B, 4, C, L))
|
|||
|
|
# xs[:, 0] = x
|
|||
|
|
# xs[:, 1] = x
|
|||
|
|
# xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
|
|||
|
|
# xs = xs.view(B, 4, C, H, W)
|
|||
|
|
# return xs
|
|||
|
|
#
|
|||
|
|
#
|
|||
|
|
# class CrossScan_Ab_1direction(torch.autograd.Function):
|
|||
|
|
# @staticmethod
|
|||
|
|
# def forward(ctx, x: torch.Tensor):
|
|||
|
|
# B, C, H, W = x.shape
|
|||
|
|
# ctx.shape = (B, C, H, W)
|
|||
|
|
# xs = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1).contiguous()
|
|||
|
|
# return xs
|
|||
|
|
#
|
|||
|
|
# @staticmethod
|
|||
|
|
# def backward(ctx, ys: torch.Tensor):
|
|||
|
|
# # out: (b, k, d, l)
|
|||
|
|
# B, C, H, W = ctx.shape
|
|||
|
|
# y = ys.sum(dim=1).view(B, C, H, W)
|
|||
|
|
# return y
|
|||
|
|
#
|
|||
|
|
#
|
|||
|
|
# class CrossMerge_Ab_1direction(torch.autograd.Function):
|
|||
|
|
# @staticmethod
|
|||
|
|
# def forward(ctx, ys: torch.Tensor):
|
|||
|
|
# B, K, D, H, W = ys.shape
|
|||
|
|
# ctx.shape = (H, W)
|
|||
|
|
# y = ys.sum(dim=1).view(B, D, H * W)
|
|||
|
|
# return y
|
|||
|
|
#
|
|||
|
|
# @staticmethod
|
|||
|
|
# def backward(ctx, x: torch.Tensor):
|
|||
|
|
# # B, D, L = x.shape
|
|||
|
|
# # out: (b, k, d, l)
|
|||
|
|
# H, W = ctx.shape
|
|||
|
|
# B, C, L = x.shape
|
|||
|
|
# xs = x.view(B, 1, C, L).repeat(1, 4, 1, 1).contiguous().view(B, 4, C, H, W)
|
|||
|
|
# return xs
|
|||
|
|
#
|
|||
|
|
|
|||
|
|
# =============
# ZSJ: this is the concrete Mamba scan; add extra scan directions here if needed


class SelectiveScanFn(torch.autograd.Function):
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
|
|||
|
|
return_last_state=False):
|
|||
|
|
if u.stride(-1) != 1:
|
|||
|
|
u = u.contiguous()
|
|||
|
|
if delta.stride(-1) != 1:
|
|||
|
|
delta = delta.contiguous()
|
|||
|
|
if D is not None:
|
|||
|
|
D = D.contiguous()
|
|||
|
|
if B.stride(-1) != 1:
|
|||
|
|
B = B.contiguous()
|
|||
|
|
if C.stride(-1) != 1:
|
|||
|
|
C = C.contiguous()
|
|||
|
|
if z is not None and z.stride(-1) != 1:
|
|||
|
|
z = z.contiguous()
|
|||
|
|
if B.dim() == 3:
|
|||
|
|
B = rearrange(B, "b dstate l -> b 1 dstate l")
|
|||
|
|
ctx.squeeze_B = True
|
|||
|
|
if C.dim() == 3:
|
|||
|
|
C = rearrange(C, "b dstate l -> b 1 dstate l")
|
|||
|
|
ctx.squeeze_C = True
|
|||
|
|
out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
|
|||
|
|
ctx.delta_softplus = delta_softplus
|
|||
|
|
ctx.has_z = z is not None
|
|||
|
|
last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
|
|||
|
|
if not ctx.has_z:
|
|||
|
|
ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
|
|||
|
|
return out if not return_last_state else (out, last_state)
|
|||
|
|
else:
|
|||
|
|
ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
|
|||
|
|
out_z = rest[0]
|
|||
|
|
return out_z if not return_last_state else (out_z, last_state)
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def backward(ctx, dout, *args):
|
|||
|
|
if not ctx.has_z:
|
|||
|
|
u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
|
|||
|
|
z = None
|
|||
|
|
out = None
|
|||
|
|
else:
|
|||
|
|
u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
|
|||
|
|
if dout.stride(-1) != 1:
|
|||
|
|
dout = dout.contiguous()
|
|||
|
|
# The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
|
|||
|
|
# backward of selective_scan_cuda with the backward of chunk).
|
|||
|
|
# Here we just pass in None and dz will be allocated in the C++ code.
|
|||
|
|
du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
|
|||
|
|
u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus,
|
|||
|
|
False # option to recompute out_z, not used here
|
|||
|
|
)
|
|||
|
|
dz = rest[0] if ctx.has_z else None
|
|||
|
|
dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
|
|||
|
|
dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
|
|||
|
|
return (du, ddelta, dA, dB, dC,
|
|||
|
|
dD if D is not None else None,
|
|||
|
|
dz,
|
|||
|
|
ddelta_bias if delta_bias is not None else None,
|
|||
|
|
None,
|
|||
|
|
None)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def cross_selective_scan(
    x: torch.Tensor = None,
    x_proj_weight: torch.Tensor = None,
    x_proj_bias: torch.Tensor = None,
    dt_projs_weight: torch.Tensor = None,
    dt_projs_bias: torch.Tensor = None,
    A_logs: torch.Tensor = None,
    Ds: torch.Tensor = None,
    delta_softplus=True,
    out_norm: torch.nn.Module = None,
    out_norm_shape="v0",
    # ==============================
    to_dtype=True,  # True: final out to dtype
    force_fp32=False,  # True: input fp32
    # ==============================
    nrows=-1,  # for SelectiveScanNRow; 0: auto; -1: disable;
    backnrows=-1,  # for SelectiveScanNRow; 0: auto; -1: disable;
    ssoflex=True,  # True: out fp32 in SSOflex; else, SSOflex is the same as SSCore
    # ==============================
    # SelectiveScan=None,
    CrossScan=CrossScan,
    CrossMerge=CrossMerge,
):
    # out_norm: whatever fits (B, L, C); LayerNorm; Sigmoid; Softmax(dim=1);...

    B, D, H, W = x.shape
    D, N = A_logs.shape
    K, D, R = dt_projs_weight.shape
    L = H * W

    if nrows == 0:
        if D % 4 == 0:
            nrows = 4
        elif D % 3 == 0:
            nrows = 3
        elif D % 2 == 0:
            nrows = 2
        else:
            nrows = 1

    if backnrows == 0:
        if D % 4 == 0:
            backnrows = 4
        elif D % 3 == 0:
            backnrows = 3
        elif D % 2 == 0:
            backnrows = 2
        else:
            backnrows = 1

    # scan: the old SelectiveScan-based version below was swapped for the selective_scan_fn defined next
    # def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True):
    #     return SelectiveScan.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows, backnrows, ssoflex)

    def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                          return_last_state=False):
        """if return_last_state is True, returns (out, last_state)
        last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
        not considered in the backward pass.
        """
        return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)

    xs = CrossScan.apply(x)

    x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, x_proj_weight)
    if x_proj_bias is not None:
        x_dbl = x_dbl + x_proj_bias.view(1, K, -1, 1)
    dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
    dts = torch.einsum("b k r l, k d r -> b k d l", dts, dt_projs_weight)
    xs = xs.view(B, -1, L)
    dts = dts.contiguous().view(B, -1, L)
    As = -torch.exp(A_logs.to(torch.float))  # (k * c, d_state)
    Bs = Bs.contiguous()
    Cs = Cs.contiguous()
    Ds = Ds.to(torch.float)  # (K * c)
    delta_bias = dt_projs_bias.view(-1).to(torch.float)

    if force_fp32:
        xs = xs.to(torch.float)
        dts = dts.to(torch.float)
        Bs = Bs.to(torch.float)
        Cs = Cs.to(torch.float)
    # ZSJ: split the feature map into sequences along the different scan directions and run the selective scan
    ys: torch.Tensor = selective_scan_fn(
        xs, dts,
        As, Bs, Cs, Ds, z=None,
        delta_bias=delta_bias,
        delta_softplus=True,
        return_last_state=False,
    ).view(B, K, -1, H, W)
    # ZSJ: merge the scanned sequences and restore the original matrix layout
    y: torch.Tensor = CrossMerge.apply(ys)

    if out_norm_shape in ["v1"]:  # (B, C, H, W)
        y = out_norm(y.view(B, -1, H, W)).permute(0, 2, 3, 1)  # (B, H, W, C)
    else:  # (B, L, C)
        y = y.transpose(dim0=1, dim1=2).contiguous()  # (B, L, C)
        y = out_norm(y).view(B, H, W, -1)

    return (y.to(x.dtype) if to_dtype else y)

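# Expected parameter shapes (a summary derived from the code above; K = number of
# scan directions, D = inner channels, N = state size, R = dt rank):
#   x:               (B, D, H, W)
#   x_proj_weight:   (K, R + 2*N, D)
#   dt_projs_weight: (K, D, R)
#   dt_projs_bias:   (K, D)
#   A_logs:          (K*D, N)
#   Ds:              (K*D,)
# The result comes back channels-last as (B, H, W, D).
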
def selective_scan_flop_jit(inputs, outputs):
    print_jit_input_names(inputs)
    B, D, L = inputs[0].type().sizes()
    N = inputs[2].type().sizes()[1]
    flops = flops_selective_scan_fn(B=B, L=L, D=D, N=N, with_D=True, with_Z=False)
    return flops


# =====================================================
|
|||
|
|
|
|||
|
|
class PatchMerging2D(nn.Module):
|
|||
|
|
def __init__(self, dim, out_dim=-1, norm_layer=nn.LayerNorm):
|
|||
|
|
super().__init__()
|
|||
|
|
self.dim = dim
|
|||
|
|
self.reduction = nn.Linear(4 * dim, (2 * dim) if out_dim < 0 else out_dim, bias=False)
|
|||
|
|
self.norm = norm_layer(4 * dim)
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def _patch_merging_pad(x: torch.Tensor):
|
|||
|
|
H, W, _ = x.shape[-3:]
|
|||
|
|
if (W % 2 != 0) or (H % 2 != 0):
|
|||
|
|
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
|
|||
|
|
x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C
|
|||
|
|
x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C
|
|||
|
|
x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C
|
|||
|
|
x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C
|
|||
|
|
x = torch.cat([x0, x1, x2, x3], -1) # ... H/2 W/2 4*C
|
|||
|
|
return x
|
|||
|
|
|
|||
|
|
def forward(self, x):
|
|||
|
|
x = self._patch_merging_pad(x)
|
|||
|
|
x = self.norm(x)
|
|||
|
|
x = self.reduction(x)
|
|||
|
|
|
|||
|
|
return x
|
|||
|
|
|
|||
|
|
|
|||
|
|
class OSSM(nn.Module):
|
|||
|
|
def __init__(
|
|||
|
|
self,
|
|||
|
|
# basic dims ===========
|
|||
|
|
d_model=96,
|
|||
|
|
d_state=16,
|
|||
|
|
ssm_ratio=2.0,
|
|||
|
|
dt_rank="auto",
|
|||
|
|
act_layer=nn.SiLU,
|
|||
|
|
# dwconv ===============
|
|||
|
|
d_conv=3, # < 2 means no conv
|
|||
|
|
conv_bias=True,
|
|||
|
|
# ======================
|
|||
|
|
dropout=0.0,
|
|||
|
|
bias=False,
|
|||
|
|
# dt init ==============
|
|||
|
|
dt_min=0.001,
|
|||
|
|
dt_max=0.1,
|
|||
|
|
dt_init="random",
|
|||
|
|
dt_scale=1.0,
|
|||
|
|
dt_init_floor=1e-4,
|
|||
|
|
initialize="v0",
|
|||
|
|
# ======================
|
|||
|
|
forward_type="v2",
|
|||
|
|
# ======================
|
|||
|
|
**kwargs,
|
|||
|
|
):
|
|||
|
|
factory_kwargs = {"device": None, "dtype": None}
|
|||
|
|
super().__init__()
|
|||
|
|
d_inner = int(ssm_ratio * d_model)
|
|||
|
|
dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank
|
|||
|
|
self.d_conv = d_conv
|
|||
|
|
|
|||
|
|
# tags for forward_type ==============================
|
|||
|
|
def checkpostfix(tag, value):
|
|||
|
|
ret = value[-len(tag):] == tag
|
|||
|
|
if ret:
|
|||
|
|
value = value[:-len(tag)]
|
|||
|
|
return ret, value
|
|||
|
|
|
|||
|
|
self.disable_force32, forward_type = checkpostfix("no32", forward_type)
|
|||
|
|
self.disable_z, forward_type = checkpostfix("noz", forward_type)
|
|||
|
|
self.disable_z_act, forward_type = checkpostfix("nozact", forward_type)
|
|||
|
|
|
|||
|
|
# softmax | sigmoid | dwconv | norm ===========================
|
|||
|
|
if forward_type[-len("none"):] == "none":
|
|||
|
|
forward_type = forward_type[:-len("none")]
|
|||
|
|
self.out_norm = nn.Identity()
|
|||
|
|
elif forward_type[-len("dwconv3"):] == "dwconv3":
|
|||
|
|
forward_type = forward_type[:-len("dwconv3")]
|
|||
|
|
self.out_norm = nn.Conv2d(d_inner, d_inner, kernel_size=3, padding=1, groups=d_inner, bias=False)
|
|||
|
|
self.out_norm_shape = "v1"
|
|||
|
|
elif forward_type[-len("softmax"):] == "softmax":
|
|||
|
|
forward_type = forward_type[:-len("softmax")]
|
|||
|
|
self.out_norm = nn.Softmax(dim=1)
|
|||
|
|
elif forward_type[-len("sigmoid"):] == "sigmoid":
|
|||
|
|
forward_type = forward_type[:-len("sigmoid")]
|
|||
|
|
self.out_norm = nn.Sigmoid()
|
|||
|
|
else:
|
|||
|
|
self.out_norm = nn.LayerNorm(d_inner)
|
|||
|
|
|
|||
|
|
# # forward_type debug =======================================
|
|||
|
|
# FORWARD_TYPES = dict(
|
|||
|
|
# v0=self.forward_corev0,
|
|||
|
|
# # v2=partial(self.forward_corev2, force_fp32=(not self.disable_force32), SelectiveScan=SelectiveScanCore),
|
|||
|
|
# v2=partial(self.forward_corev2, force_fp32=True, SelectiveScan=SelectiveScanCore),
|
|||
|
|
# v3=partial(self.forward_corev2, force_fp32=False, SelectiveScan=SelectiveScanOflex),
|
|||
|
|
# v31d=partial(self.forward_corev2, force_fp32=False, SelectiveScan=SelectiveScanOflex, cross_selective_scan=partial(
|
|||
|
|
# cross_selective_scan, CrossScan=CrossScan_Ab_1direction, CrossMerge=CrossMerge_Ab_1direction,
|
|||
|
|
# )),
|
|||
|
|
# v32d=partial(self.forward_corev2, force_fp32=False, SelectiveScan=SelectiveScanOflex, cross_selective_scan=partial(
|
|||
|
|
# cross_selective_scan, CrossScan=CrossScan_Ab_2direction, CrossMerge=CrossMerge_Ab_2direction,
|
|||
|
|
# )),
|
|||
|
|
# # ===============================
|
|||
|
|
# fake=partial(self.forward_corev2, force_fp32=(not self.disable_force32), SelectiveScan=SelectiveScanFake),
|
|||
|
|
# v1=partial(self.forward_corev2, force_fp32=True, SelectiveScan=SelectiveScanOflex),
|
|||
|
|
# v01=partial(self.forward_corev2, force_fp32=(not self.disable_force32), SelectiveScan=SelectiveScanMamba),
|
|||
|
|
# )
|
|||
|
|
# if forward_type.startswith("debug"):
|
|||
|
|
# from .ss2d_ablations import SS2D_ForwardCoreSpeedAblations, SS2D_ForwardCoreModeAblations, cross_selective_scanv2
|
|||
|
|
# FORWARD_TYPES.update(dict(
|
|||
|
|
# debugforward_core_mambassm_seq=partial(SS2D_ForwardCoreSpeedAblations.forward_core_mambassm_seq, self),
|
|||
|
|
# debugforward_core_mambassm=partial(SS2D_ForwardCoreSpeedAblations.forward_core_mambassm, self),
|
|||
|
|
# debugforward_core_mambassm_fp16=partial(SS2D_ForwardCoreSpeedAblations.forward_core_mambassm_fp16, self),
|
|||
|
|
# debugforward_core_mambassm_fusecs=partial(SS2D_ForwardCoreSpeedAblations.forward_core_mambassm_fusecs, self),
|
|||
|
|
# debugforward_core_mambassm_fusecscm=partial(SS2D_ForwardCoreSpeedAblations.forward_core_mambassm_fusecscm, self),
|
|||
|
|
# debugforward_core_sscore_fusecscm=partial(SS2D_ForwardCoreSpeedAblations.forward_core_sscore_fusecscm, self),
|
|||
|
|
# debugforward_core_sscore_fusecscm_fwdnrow=partial(SS2D_ForwardCoreSpeedAblations.forward_core_ssnrow_fusecscm_fwdnrow, self),
|
|||
|
|
# debugforward_core_sscore_fusecscm_bwdnrow=partial(SS2D_ForwardCoreSpeedAblations.forward_core_ssnrow_fusecscm_bwdnrow, self),
|
|||
|
|
# debugforward_core_sscore_fusecscm_fbnrow=partial(SS2D_ForwardCoreSpeedAblations.forward_core_ssnrow_fusecscm_fbnrow, self),
|
|||
|
|
# debugforward_core_ssoflex_fusecscm=partial(SS2D_ForwardCoreSpeedAblations.forward_core_ssoflex_fusecscm, self),
|
|||
|
|
# debugforward_core_ssoflex_fusecscm_i16o32=partial(SS2D_ForwardCoreSpeedAblations.forward_core_ssoflex_fusecscm_i16o32, self),
|
|||
|
|
# debugscan_sharessm=partial(self.forward_corev2, force_fp32=False, SelectiveScan=SelectiveScanOflex, cross_selective_scan=cross_selective_scanv2),
|
|||
|
|
# ))
|
|||
|
|
# self.forward_core = FORWARD_TYPES.get(forward_type, None)
|
|||
|
|
# ZSJ: k_group is the number of scan directions
|
|||
|
|
# k_group = 4 if forward_type not in ["debugscan_sharessm"] else 1
|
|||
|
|
k_group = 8 if forward_type not in ["debugscan_sharessm"] else 1
|
|||
|
|
|
|||
|
|
# in proj =======================================
|
|||
|
|
d_proj = d_inner if self.disable_z else (d_inner * 2)
|
|||
|
|
self.in_proj = nn.Linear(d_model, d_proj, bias=bias, **factory_kwargs)
|
|||
|
|
self.act: nn.Module = act_layer()
|
|||
|
|
|
|||
|
|
# conv =======================================
|
|||
|
|
if d_conv > 1:
|
|||
|
|
self.conv2d = nn.Conv2d(
|
|||
|
|
in_channels=d_inner,
|
|||
|
|
out_channels=d_inner,
|
|||
|
|
groups=d_inner,
|
|||
|
|
bias=conv_bias,
|
|||
|
|
kernel_size=d_conv,
|
|||
|
|
padding=(d_conv - 1) // 2,
|
|||
|
|
**factory_kwargs,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# x proj ============================
|
|||
|
|
self.x_proj = [
|
|||
|
|
nn.Linear(d_inner, (dt_rank + d_state * 2), bias=False, **factory_kwargs)
|
|||
|
|
for _ in range(k_group)
|
|||
|
|
]
|
|||
|
|
self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner)
|
|||
|
|
del self.x_proj
|
|||
|
|
|
|||
|
|
# out proj =======================================
|
|||
|
|
self.out_proj = nn.Linear(d_inner, d_model, bias=bias, **factory_kwargs)
|
|||
|
|
self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()
|
|||
|
|
|
|||
|
|
if initialize in ["v0"]:
|
|||
|
|
# dt proj ============================
|
|||
|
|
self.dt_projs = [
|
|||
|
|
self.dt_init(dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs)
|
|||
|
|
for _ in range(k_group)
|
|||
|
|
]
|
|||
|
|
self.dt_projs_weight = nn.Parameter(
|
|||
|
|
torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K, inner, rank)
|
|||
|
|
self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K, inner)
|
|||
|
|
del self.dt_projs
|
|||
|
|
|
|||
|
|
# A, D =======================================
|
|||
|
|
self.A_logs = self.A_log_init(d_state, d_inner, copies=k_group, merge=True) # (K * D, N)
|
|||
|
|
self.Ds = self.D_init(d_inner, copies=k_group, merge=True) # (K * D)
|
|||
|
|
elif initialize in ["v1"]:
|
|||
|
|
# simple init dt_projs, A_logs, Ds
|
|||
|
|
self.Ds = nn.Parameter(torch.ones((k_group * d_inner)))
|
|||
|
|
self.A_logs = nn.Parameter(
|
|||
|
|
torch.randn((k_group * d_inner, d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
|
|||
|
|
self.dt_projs_weight = nn.Parameter(torch.randn((k_group, d_inner, dt_rank)))
|
|||
|
|
self.dt_projs_bias = nn.Parameter(torch.randn((k_group, d_inner)))
|
|||
|
|
elif initialize in ["v2"]:
|
|||
|
|
# simple init dt_projs, A_logs, Ds
|
|||
|
|
self.Ds = nn.Parameter(torch.ones((k_group * d_inner)))
|
|||
|
|
self.A_logs = nn.Parameter(
|
|||
|
|
torch.zeros((k_group * d_inner, d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1
|
|||
|
|
self.dt_projs_weight = nn.Parameter(torch.randn((k_group, d_inner, dt_rank)))
|
|||
|
|
self.dt_projs_bias = nn.Parameter(torch.randn((k_group, d_inner)))
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4,
|
|||
|
|
**factory_kwargs):
|
|||
|
|
dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)
|
|||
|
|
|
|||
|
|
# Initialize special dt projection to preserve variance at initialization
|
|||
|
|
dt_init_std = dt_rank ** -0.5 * dt_scale
|
|||
|
|
if dt_init == "constant":
|
|||
|
|
nn.init.constant_(dt_proj.weight, dt_init_std)
|
|||
|
|
elif dt_init == "random":
|
|||
|
|
nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
|
|||
|
|
else:
|
|||
|
|
raise NotImplementedError
|
|||
|
|
|
|||
|
|
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
|
|||
|
|
dt = torch.exp(
|
|||
|
|
torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
|
|||
|
|
+ math.log(dt_min)
|
|||
|
|
).clamp(min=dt_init_floor)
|
|||
|
|
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
|
|||
|
|
inv_dt = dt + torch.log(-torch.expm1(-dt))
|
|||
|
|
with torch.no_grad():
|
|||
|
|
dt_proj.bias.copy_(inv_dt)
|
|||
|
|
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
|
|||
|
|
# dt_proj.bias._no_reinit = True
|
|||
|
|
|
|||
|
|
return dt_proj
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True):
|
|||
|
|
# S4D real initialization
|
|||
|
|
A = repeat(
|
|||
|
|
torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
|
|||
|
|
"n -> d n",
|
|||
|
|
d=d_inner,
|
|||
|
|
).contiguous()
|
|||
|
|
A_log = torch.log(A) # Keep A_log in fp32
|
|||
|
|
if copies > 0:
|
|||
|
|
A_log = repeat(A_log, "d n -> r d n", r=copies)
|
|||
|
|
if merge:
|
|||
|
|
A_log = A_log.flatten(0, 1)
|
|||
|
|
A_log = nn.Parameter(A_log)
|
|||
|
|
A_log._no_weight_decay = True
|
|||
|
|
return A_log
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def D_init(d_inner, copies=-1, device=None, merge=True):
|
|||
|
|
# D "skip" parameter
|
|||
|
|
D = torch.ones(d_inner, device=device)
|
|||
|
|
if copies > 0:
|
|||
|
|
D = repeat(D, "n1 -> r n1", r=copies)
|
|||
|
|
if merge:
|
|||
|
|
D = D.flatten(0, 1)
|
|||
|
|
D = nn.Parameter(D) # Keep in fp32
|
|||
|
|
D._no_weight_decay = True
|
|||
|
|
return D
|
|||
|
|
|
|||
|
|
# only used to run previous version
|
|||
|
|
# def forward_corev0(self, x: torch.Tensor, to_dtype=False, channel_first=False):
|
|||
|
|
# def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True, nrows=1):
|
|||
|
|
# return SelectiveScanCore.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows, False)
|
|||
|
|
#
|
|||
|
|
# if not channel_first:
|
|||
|
|
# x = x.permute(0, 3, 1, 2).contiguous()
|
|||
|
|
# B, D, H, W = x.shape
|
|||
|
|
# D, N = self.A_logs.shape
|
|||
|
|
# K, D, R = self.dt_projs_weight.shape
|
|||
|
|
# L = H * W
|
|||
|
|
#
|
|||
|
|
# # ZSJ: data-expansion step: flatten the same feature map along different directions and concatenate; only used by the old version
# # concatenate the horizontal and vertical scans along the K dimension
# x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
# # torch.flip reverses both the horizontal and the vertical sequences
|
|||
|
|
# xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l)
|
|||
|
|
#
|
|||
|
|
# x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight)
|
|||
|
|
# # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
|
|||
|
|
# dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
|
|||
|
|
# dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight)
|
|||
|
|
#
|
|||
|
|
# xs = xs.float().view(B, -1, L) # (b, k * d, l)
|
|||
|
|
# dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l)
|
|||
|
|
# Bs = Bs.float() # (b, k, d_state, l)
|
|||
|
|
# Cs = Cs.float() # (b, k, d_state, l)
|
|||
|
|
#
|
|||
|
|
# As = -torch.exp(self.A_logs.float()) # (k * d, d_state)
|
|||
|
|
# Ds = self.Ds.float() # (k * d)
|
|||
|
|
# dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d)
|
|||
|
|
#
|
|||
|
|
# # assert len(xs.shape) == 3 and len(dts.shape) == 3 and len(Bs.shape) == 4 and len(Cs.shape) == 4
|
|||
|
|
# # assert len(As.shape) == 2 and len(Ds.shape) == 1 and len(dt_projs_bias.shape) == 1
|
|||
|
|
#
|
|||
|
|
# out_y = selective_scan(
|
|||
|
|
# xs, dts,
|
|||
|
|
# As, Bs, Cs, Ds,
|
|||
|
|
# delta_bias=dt_projs_bias,
|
|||
|
|
# delta_softplus=True,
|
|||
|
|
# ).view(B, K, -1, L)
|
|||
|
|
# # assert out_y.dtype == torch.float
|
|||
|
|
#
|
|||
|
|
# inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
|
|||
|
|
# wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
|
|||
|
|
# invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
|
|||
|
|
# y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y
|
|||
|
|
# y = y.transpose(dim0=1, dim1=2).contiguous() # (B, L, C)
|
|||
|
|
# y = self.out_norm(y).view(B, H, W, -1)
|
|||
|
|
#
|
|||
|
|
# return (y.to(x.dtype) if to_dtype else y)
|
|||
|
|
|
|||
|
|
def forward_corev2(self, x: torch.Tensor, channel_first=False, cross_selective_scan=cross_selective_scan,
|
|||
|
|
force_fp32=None):
|
|||
|
|
|
|||
|
|
if not channel_first:
|
|||
|
|
x = x.permute(0, 3, 1, 2).contiguous()
|
|||
|
|
|
|||
|
|
# ZSJ: the Mamba core used by the v2 forward; change the scan directions here if needed
|
|||
|
|
x = cross_selective_scan(
|
|||
|
|
x, self.x_proj_weight, None, self.dt_projs_weight, self.dt_projs_bias,
|
|||
|
|
self.A_logs, self.Ds, delta_softplus=True,
|
|||
|
|
out_norm=getattr(self, "out_norm", None),
|
|||
|
|
out_norm_shape=getattr(self, "out_norm_shape", "v0"),
|
|||
|
|
force_fp32=force_fp32,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
return x
|
|||
|
|
|
|||
|
|
def forward(self, x: torch.Tensor, **kwargs):
|
|||
|
|
with_dconv = (self.d_conv > 1)
|
|||
|
|
x = self.in_proj(x)
|
|||
|
|
if not self.disable_z:
|
|||
|
|
x, z = x.chunk(2, dim=-1) # (b, h, w, d)
|
|||
|
|
if not self.disable_z_act:
|
|||
|
|
z = self.act(z)
|
|||
|
|
if with_dconv:
|
|||
|
|
x = x.permute(0, 3, 1, 2).contiguous()
|
|||
|
|
x = self.conv2d(x) # (b, d, h, w)
|
|||
|
|
x = self.act(x)
|
|||
|
|
y = self.forward_corev2(x, channel_first=with_dconv)
|
|||
|
|
if not self.disable_z:
|
|||
|
|
y = y * z
|
|||
|
|
out = self.dropout(self.out_proj(y))
|
|||
|
|
return out
|
|||
|
|
|
|||
|
|
|
|||
|
|
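# Usage sketch for OSSM (illustrative only; it assumes a CUDA device and one of the
# compiled selective_scan_cuda* extensions imported above, otherwise the scan call fails):
#
#   >>> m = OSSM(d_model=96).cuda()
#   >>> x = torch.randn(2, 32, 32, 96).cuda()   # (B, H, W, C), channels-last
#   >>> m(x).shape
#   torch.Size([2, 32, 32, 96])
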
class Permute(nn.Module):
|
|||
|
|
def __init__(self, *args):
|
|||
|
|
super().__init__()
|
|||
|
|
self.args = args
|
|||
|
|
|
|||
|
|
def forward(self, x: torch.Tensor):
|
|||
|
|
return x.permute(*self.args)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class Mlp(nn.Module):
|
|||
|
|
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,
|
|||
|
|
channels_first=False):
|
|||
|
|
super().__init__()
|
|||
|
|
out_features = out_features or in_features
|
|||
|
|
hidden_features = hidden_features or in_features
|
|||
|
|
|
|||
|
|
Linear = partial(nn.Conv2d, kernel_size=1, padding=0) if channels_first else nn.Linear
|
|||
|
|
self.fc1 = Linear(in_features, hidden_features)
|
|||
|
|
self.act = act_layer()
|
|||
|
|
self.fc2 = Linear(hidden_features, out_features)
|
|||
|
|
self.drop = nn.Dropout(drop)
|
|||
|
|
|
|||
|
|
def forward(self, x):
|
|||
|
|
x = self.fc1(x)
|
|||
|
|
x = self.act(x)
|
|||
|
|
x = self.drop(x)
|
|||
|
|
x = self.fc2(x)
|
|||
|
|
x = self.drop(x)
|
|||
|
|
return x
|
|||
|
|
|
|||
|
|
|
|||
|
|
class OSSBlock(nn.Module):
|
|||
|
|
def __init__(
|
|||
|
|
self,
|
|||
|
|
hidden_dim: int = 0,
|
|||
|
|
drop_path: float = 0,
|
|||
|
|
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
|
|||
|
|
# =============================
|
|||
|
|
ssm_d_state: int = 16,
|
|||
|
|
ssm_ratio=2.0,
|
|||
|
|
ssm_dt_rank: Any = "auto",
|
|||
|
|
ssm_act_layer=nn.SiLU,
|
|||
|
|
ssm_conv: int = 3,
|
|||
|
|
ssm_conv_bias=True,
|
|||
|
|
ssm_drop_rate: float = 0,
|
|||
|
|
ssm_init="v0",
|
|||
|
|
forward_type="v2",
|
|||
|
|
# =============================
|
|||
|
|
mlp_ratio=4.0,
|
|||
|
|
mlp_act_layer=nn.GELU,
|
|||
|
|
mlp_drop_rate: float = 0.0,
|
|||
|
|
# =============================
|
|||
|
|
use_checkpoint: bool = False,
|
|||
|
|
post_norm: bool = False,
|
|||
|
|
**kwargs,
|
|||
|
|
):
|
|||
|
|
super().__init__()
|
|||
|
|
self.ssm_branch = ssm_ratio > 0
|
|||
|
|
self.mlp_branch = mlp_ratio > 0
|
|||
|
|
self.use_checkpoint = use_checkpoint
|
|||
|
|
self.post_norm = post_norm
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
from ss2d_ablations import SS2DDev
|
|||
|
|
_OSSM = SS2DDev if forward_type.startswith("dev") else OSSM
|
|||
|
|
except:
|
|||
|
|
_OSSM = OSSM
|
|||
|
|
|
|||
|
|
if self.ssm_branch:
|
|||
|
|
self.norm = norm_layer(hidden_dim)
|
|||
|
|
self.op = _OSSM(
|
|||
|
|
d_model=hidden_dim,
|
|||
|
|
d_state=ssm_d_state,
|
|||
|
|
ssm_ratio=ssm_ratio,
|
|||
|
|
dt_rank=ssm_dt_rank,
|
|||
|
|
act_layer=ssm_act_layer,
|
|||
|
|
# ==========================
|
|||
|
|
d_conv=ssm_conv,
|
|||
|
|
conv_bias=ssm_conv_bias,
|
|||
|
|
# ==========================
|
|||
|
|
dropout=ssm_drop_rate,
|
|||
|
|
# bias=False,
|
|||
|
|
# ==========================
|
|||
|
|
# dt_min=0.001,
|
|||
|
|
# dt_max=0.1,
|
|||
|
|
# dt_init="random",
|
|||
|
|
# dt_scale="random",
|
|||
|
|
# dt_init_floor=1e-4,
|
|||
|
|
initialize=ssm_init,
|
|||
|
|
# ==========================
|
|||
|
|
forward_type=forward_type,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
self.drop_path = DropPath(drop_path)
|
|||
|
|
|
|||
|
|
if self.mlp_branch:
|
|||
|
|
self.norm2 = norm_layer(hidden_dim)
|
|||
|
|
mlp_hidden_dim = int(hidden_dim * mlp_ratio)
|
|||
|
|
self.mlp = Mlp(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=mlp_act_layer,
|
|||
|
|
drop=mlp_drop_rate, channels_first=False)
|
|||
|
|
|
|||
|
|
def _forward(self, input: torch.Tensor):
|
|||
|
|
if self.ssm_branch:
|
|||
|
|
if self.post_norm:
|
|||
|
|
x = input + self.drop_path(self.norm(self.op(input)))
|
|||
|
|
else:
|
|||
|
|
x = input + self.drop_path(self.op(self.norm(input)))
|
|||
|
|
if self.mlp_branch:
|
|||
|
|
if self.post_norm:
|
|||
|
|
x = x + self.drop_path(self.norm2(self.mlp(x))) # FFN
|
|||
|
|
else:
|
|||
|
|
x = x + self.drop_path(self.mlp(self.norm2(x))) # FFN
|
|||
|
|
return x
|
|||
|
|
|
|||
|
|
def forward(self, input: torch.Tensor):
|
|||
|
|
if self.use_checkpoint:
|
|||
|
|
return checkpoint.checkpoint(self._forward, input)
|
|||
|
|
else:
|
|||
|
|
return self._forward(input)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class Decoder_Block(nn.Module):
|
|||
|
|
"""Basic block in decoder."""
|
|||
|
|
|
|||
|
|
def __init__(self, in_channel, out_channel):
|
|||
|
|
super().__init__()
|
|||
|
|
|
|||
|
|
assert out_channel == in_channel // 2, 'the out_channel is not in_channel//2 in decoder block'
|
|||
|
|
self.up = nn.Upsample(scale_factor=2, mode='nearest')
|
|||
|
|
self.fuse = nn.Sequential(nn.Conv2d(in_channels=in_channel + out_channel, out_channels=out_channel,
|
|||
|
|
kernel_size=1, padding=0, bias=False),
|
|||
|
|
nn.BatchNorm2d(out_channel),
|
|||
|
|
nn.ReLU(inplace=True),
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
def forward(self, de, en):
|
|||
|
|
de = self.up(de)
|
|||
|
|
output = torch.cat([de, en], dim=1)
|
|||
|
|
output = self.fuse(output)
|
|||
|
|
|
|||
|
|
return output
|
|||
|
|
|
|||
|
|
|
|||
|
|
@MODELS.register_module()
|
|||
|
|
class RSM_SS(nn.Module):
|
|||
|
|
def __init__(
|
|||
|
|
self,
|
|||
|
|
patch_size=4,
|
|||
|
|
in_chans=3,
|
|||
|
|
num_classes=1000,
|
|||
|
|
depths=[2, 2, 9, 2],
|
|||
|
|
dims=[96, 192, 384, 768],
|
|||
|
|
# =========================
|
|||
|
|
ssm_d_state=16,
|
|||
|
|
ssm_ratio=2.0,
|
|||
|
|
ssm_dt_rank="auto",
|
|||
|
|
ssm_act_layer="silu",
|
|||
|
|
ssm_conv=3,
|
|||
|
|
ssm_conv_bias=True,
|
|||
|
|
ssm_drop_rate=0.0,
|
|||
|
|
ssm_init="v0",
|
|||
|
|
forward_type="v2",
|
|||
|
|
# =========================
|
|||
|
|
mlp_ratio=4.0,
|
|||
|
|
mlp_act_layer="gelu",
|
|||
|
|
mlp_drop_rate=0.0,
|
|||
|
|
# =========================
|
|||
|
|
drop_path_rate=0.1,
|
|||
|
|
patch_norm=True,
|
|||
|
|
norm_layer="LN",
|
|||
|
|
use_checkpoint=False,
|
|||
|
|
**kwargs,
|
|||
|
|
):
|
|||
|
|
super().__init__()
|
|||
|
|
self.num_classes = num_classes
|
|||
|
|
self.num_layers = len(depths)
|
|||
|
|
if isinstance(dims, int):
|
|||
|
|
dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)]
|
|||
|
|
self.num_features = dims[-1]
|
|||
|
|
self.dims = dims
|
|||
|
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
|||
|
|
|
|||
|
|
_NORMLAYERS = dict(
|
|||
|
|
ln=nn.LayerNorm,
|
|||
|
|
bn=nn.BatchNorm2d,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
_ACTLAYERS = dict(
|
|||
|
|
silu=nn.SiLU,
|
|||
|
|
gelu=nn.GELU,
|
|||
|
|
relu=nn.ReLU,
|
|||
|
|
sigmoid=nn.Sigmoid,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
if isinstance(norm_layer, str) and norm_layer.lower() in ["ln"]:
|
|||
|
|
norm_layer: nn.Module = _NORMLAYERS[norm_layer.lower()]
|
|||
|
|
|
|||
|
|
if isinstance(ssm_act_layer, str) and ssm_act_layer.lower() in ["silu", "gelu", "relu"]:
|
|||
|
|
ssm_act_layer: nn.Module = _ACTLAYERS[ssm_act_layer.lower()]
|
|||
|
|
|
|||
|
|
if isinstance(mlp_act_layer, str) and mlp_act_layer.lower() in ["silu", "gelu", "relu"]:
|
|||
|
|
mlp_act_layer: nn.Module = _ACTLAYERS[mlp_act_layer.lower()]
|
|||
|
|
|
|||
|
|
_make_patch_embed = self._make_patch_embed_v2
|
|||
|
|
self.patch_embed = _make_patch_embed(in_chans, dims[0], patch_size, patch_norm, norm_layer)
|
|||
|
|
|
|||
|
|
_make_downsample = self._make_downsample_v3
|
|||
|
|
|
|||
|
|
# self.encoder_layers = [nn.ModuleList()] * self.num_layers
|
|||
|
|
self.encoder_layers = []
|
|||
|
|
self.decoder_layers = []
|
|||
|
|
|
|||
|
|
for i_layer in range(self.num_layers):
|
|||
|
|
# downsample = _make_downsample(
|
|||
|
|
# self.dims[i_layer],
|
|||
|
|
# self.dims[i_layer + 1],
|
|||
|
|
# norm_layer=norm_layer,
|
|||
|
|
# ) if (i_layer < self.num_layers - 1) else nn.Identity()
|
|||
|
|
|
|||
|
|
downsample = _make_downsample(
|
|||
|
|
self.dims[i_layer - 1],
|
|||
|
|
self.dims[i_layer],
|
|||
|
|
norm_layer=norm_layer,
|
|||
|
|
) if (i_layer != 0) else nn.Identity()  # ZSJ: changed to i_layer != 0 so the first stage does not downsample, matching the figure in the paper and making it easy to take out the processed features at every scale
|
|||
|
|
|
|||
|
|
self.encoder_layers.append(self._make_layer(
|
|||
|
|
dim=self.dims[i_layer],
|
|||
|
|
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
|||
|
|
use_checkpoint=use_checkpoint,
|
|||
|
|
norm_layer=norm_layer,
|
|||
|
|
downsample=downsample,
|
|||
|
|
# =================
|
|||
|
|
ssm_d_state=ssm_d_state,
|
|||
|
|
ssm_ratio=ssm_ratio,
|
|||
|
|
ssm_dt_rank=ssm_dt_rank,
|
|||
|
|
ssm_act_layer=ssm_act_layer,
|
|||
|
|
ssm_conv=ssm_conv,
|
|||
|
|
ssm_conv_bias=ssm_conv_bias,
|
|||
|
|
ssm_drop_rate=ssm_drop_rate,
|
|||
|
|
ssm_init=ssm_init,
|
|||
|
|
forward_type=forward_type,
|
|||
|
|
# =================
|
|||
|
|
mlp_ratio=mlp_ratio,
|
|||
|
|
mlp_act_layer=mlp_act_layer,
|
|||
|
|
mlp_drop_rate=mlp_drop_rate,
|
|||
|
|
))
|
|||
|
|
if i_layer != 0:
|
|||
|
|
self.decoder_layers.append(
|
|||
|
|
Decoder_Block(in_channel=self.dims[i_layer], out_channel=self.dims[i_layer - 1]))
|
|||
|
|
|
|||
|
|
self.encoder_block1, self.encoder_block2, self.encoder_block3, self.encoder_block4 = self.encoder_layers
|
|||
|
|
self.deocder_block1, self.deocder_block2, self.deocder_block3 = self.decoder_layers
|
|||
|
|
|
|||
|
|
self.upsample_x4 = nn.Sequential(
|
|||
|
|
nn.Conv2d(self.dims[0], self.dims[0] // 2, kernel_size=3, stride=1, padding=1),
|
|||
|
|
nn.BatchNorm2d(self.dims[0] // 2),
|
|||
|
|
nn.ReLU(inplace=True),
|
|||
|
|
nn.UpsamplingBilinear2d(scale_factor=2),
|
|||
|
|
# nn.Conv2d(self.dims[0] // 2, 8, kernel_size=3, stride=1, padding=1),
|
|||
|
|
nn.Conv2d(self.dims[0] // 2, self.dims[0] // 2, kernel_size=3, stride=1, padding=1),
|
|||
|
|
nn.BatchNorm2d(self.dims[0] // 2),  # match the dims[0] // 2 channels produced by the conv above
|
|||
|
|
nn.ReLU(inplace=True),
|
|||
|
|
nn.UpsamplingBilinear2d(scale_factor=2)
|
|||
|
|
)
|
|||
|
|
# self.conv_out_seg = nn.Conv2d(8, 1, kernel_size=7, stride=1, padding=3)
|
|||
|
|
|
|||
|
|
self.apply(self._init_weights)
|
|||
|
|
|
|||
|
|
def _init_weights(self, m: nn.Module):
|
|||
|
|
if isinstance(m, nn.Linear):
|
|||
|
|
trunc_normal_(m.weight, std=.02)
|
|||
|
|
if isinstance(m, nn.Linear) and m.bias is not None:
|
|||
|
|
nn.init.constant_(m.bias, 0)
|
|||
|
|
elif isinstance(m, nn.LayerNorm):
|
|||
|
|
nn.init.constant_(m.bias, 0)
|
|||
|
|
nn.init.constant_(m.weight, 1.0)
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def _make_patch_embed_v2(in_chans=3, embed_dim=96, patch_size=4, patch_norm=True, norm_layer=nn.LayerNorm):
|
|||
|
|
assert patch_size == 4
|
|||
|
|
return nn.Sequential(
|
|||
|
|
nn.Conv2d(in_chans, embed_dim // 2, kernel_size=3, stride=2, padding=1),
|
|||
|
|
(Permute(0, 2, 3, 1) if patch_norm else nn.Identity()),
|
|||
|
|
(norm_layer(embed_dim // 2) if patch_norm else nn.Identity()),
|
|||
|
|
(Permute(0, 3, 1, 2) if patch_norm else nn.Identity()),
|
|||
|
|
nn.GELU(),
|
|||
|
|
nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1),
|
|||
|
|
Permute(0, 2, 3, 1),
|
|||
|
|
(norm_layer(embed_dim) if patch_norm else nn.Identity()),
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def _make_downsample_v3(dim=96, out_dim=192, norm_layer=nn.LayerNorm):
|
|||
|
|
return nn.Sequential(
|
|||
|
|
Permute(0, 3, 1, 2),
|
|||
|
|
nn.Conv2d(dim, out_dim, kernel_size=3, stride=2, padding=1),
|
|||
|
|
Permute(0, 2, 3, 1),
|
|||
|
|
norm_layer(out_dim),
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
@staticmethod
|
|||
|
|
def _make_layer(
|
|||
|
|
dim=96,
|
|||
|
|
drop_path=[0.1, 0.1],
|
|||
|
|
use_checkpoint=False,
|
|||
|
|
norm_layer=nn.LayerNorm,
|
|||
|
|
downsample=nn.Identity(),
|
|||
|
|
# ===========================
|
|||
|
|
ssm_d_state=16,
|
|||
|
|
ssm_ratio=2.0,
|
|||
|
|
ssm_dt_rank="auto",
|
|||
|
|
ssm_act_layer=nn.SiLU,
|
|||
|
|
ssm_conv=3,
|
|||
|
|
ssm_conv_bias=True,
|
|||
|
|
ssm_drop_rate=0.0,
|
|||
|
|
ssm_init="v0",
|
|||
|
|
forward_type="v2",
|
|||
|
|
# ===========================
|
|||
|
|
mlp_ratio=4.0,
|
|||
|
|
mlp_act_layer=nn.GELU,
|
|||
|
|
mlp_drop_rate=0.0,
|
|||
|
|
**kwargs,
|
|||
|
|
):
|
|||
|
|
depth = len(drop_path)
|
|||
|
|
blocks = []
|
|||
|
|
for d in range(depth):
|
|||
|
|
blocks.append(OSSBlock(
|
|||
|
|
hidden_dim=dim,
|
|||
|
|
drop_path=drop_path[d],
|
|||
|
|
norm_layer=norm_layer,
|
|||
|
|
ssm_d_state=ssm_d_state,
|
|||
|
|
ssm_ratio=ssm_ratio,
|
|||
|
|
ssm_dt_rank=ssm_dt_rank,
|
|||
|
|
ssm_act_layer=ssm_act_layer,
|
|||
|
|
ssm_conv=ssm_conv,
|
|||
|
|
ssm_conv_bias=ssm_conv_bias,
|
|||
|
|
ssm_drop_rate=ssm_drop_rate,
|
|||
|
|
ssm_init=ssm_init,
|
|||
|
|
forward_type=forward_type,
|
|||
|
|
mlp_ratio=mlp_ratio,
|
|||
|
|
mlp_act_layer=mlp_act_layer,
|
|||
|
|
mlp_drop_rate=mlp_drop_rate,
|
|||
|
|
use_checkpoint=use_checkpoint,
|
|||
|
|
))
|
|||
|
|
|
|||
|
|
return nn.Sequential(OrderedDict(
|
|||
|
|
# ZSJ: put the downsample first so each encoder stage yields the fully processed feature map at its scale, not the map right after downsampling
|
|||
|
|
downsample=downsample,
|
|||
|
|
blocks=nn.Sequential(*blocks, ),
|
|||
|
|
))
|
|||
|
|
|
|||
|
|
def forward(self, x1: torch.Tensor):
|
|||
|
|
x1 = self.patch_embed(x1)
|
|||
|
|
|
|||
|
|
x1_1 = self.encoder_block1(x1)
|
|||
|
|
x1_2 = self.encoder_block2(x1_1)
|
|||
|
|
x1_3 = self.encoder_block3(x1_2)
|
|||
|
|
x1_4 = self.encoder_block4(x1_3) # b,h,w,c
|
|||
|
|
|
|||
|
|
x1_1 = rearrange(x1_1, "b h w c -> b c h w").contiguous()
|
|||
|
|
x1_2 = rearrange(x1_2, "b h w c -> b c h w").contiguous()
|
|||
|
|
x1_3 = rearrange(x1_3, "b h w c -> b c h w").contiguous()
|
|||
|
|
x1_4 = rearrange(x1_4, "b h w c -> b c h w").contiguous()
|
|||
|
|
|
|||
|
|
decode_3 = self.deocder_block3(x1_4, x1_3)
|
|||
|
|
decode_2 = self.deocder_block2(decode_3, x1_2)
|
|||
|
|
decode_1 = self.deocder_block1(decode_2, x1_1)
|
|||
|
|
|
|||
|
|
output = self.upsample_x4(decode_1)
|
|||
|
|
# output = self.conv_out_seg(output)
|
|||
|
|
|
|||
|
|
return output
|
|||
|
|
|
|||
|
|
|
|||
|
|
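# Output-shape note (derived from the default configuration above): patch_embed
# reduces the input by 4x, the decoder fuses the stages back to stride 4, and
# upsample_x4 restores full resolution with dims[0] // 2 channels, so a
# (2, 3, 512, 512) input yields a (2, 48, 512, 512) feature map, e.g.: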
# if __name__=="__main__":
|
|||
|
|
# net=RSM_SS().to("cuda:1")
|
|||
|
|
# img=torch.randn(2,3,512,512).to("cuda:1")
|
|||
|
|
# out=net(img)
|
|||
|
|
# print(out.size())
|
|||
|
|
|
|||
|
|
if __name__ == '__main__':
|
|||
|
|
|
|||
|
|
# Comment batchnorms here and in model_utils before testing speed since the batchnorm could be integrated into conv operation
|
|||
|
|
# (do not comment all, just the batchnorm following its corresponding conv layer)
|
|||
|
|
device = torch.device('cuda:0')
|
|||
|
|
model = RSM_SS()
|
|||
|
|
model.eval()
|
|||
|
|
model.to(device)
|
|||
|
|
iterations = None
|
|||
|
|
|
|||
|
|
input = torch.randn(1, 3, 512, 512).to(device)
|
|||
|
|
with torch.no_grad():
|
|||
|
|
for _ in range(10):
|
|||
|
|
model(input)
|
|||
|
|
|
|||
|
|
if iterations is None:
|
|||
|
|
elapsed_time = 0
|
|||
|
|
iterations = 100
|
|||
|
|
while elapsed_time < 1:
|
|||
|
|
torch.cuda.synchronize()
|
|||
|
|
torch.cuda.synchronize()
|
|||
|
|
t_start = time.time()
|
|||
|
|
for _ in range(iterations):
|
|||
|
|
model(input)
|
|||
|
|
torch.cuda.synchronize()
|
|||
|
|
torch.cuda.synchronize()
|
|||
|
|
elapsed_time = time.time() - t_start
|
|||
|
|
iterations *= 2
|
|||
|
|
FPS = iterations / elapsed_time
|
|||
|
|
iterations = int(FPS * 6)
|
|||
|
|
|
|||
|
|
print('=========Speed Testing=========')
|
|||
|
|
torch.cuda.synchronize()
|
|||
|
|
torch.cuda.synchronize()
|
|||
|
|
t_start = time.time()
|
|||
|
|
for _ in range(iterations):
|
|||
|
|
model(input)
|
|||
|
|
torch.cuda.synchronize()
|
|||
|
|
torch.cuda.synchronize()
|
|||
|
|
elapsed_time = time.time() - t_start
|
|||
|
|
latency = elapsed_time / iterations * 1000
|
|||
|
|
torch.cuda.empty_cache()
|
|||
|
|
FPS = 1000 / latency
|
|||
|
|
print(FPS)
|
|||
|
|
|
|||
|
|
|
|||
|
|
|