import os
import time
import math
import copy
from functools import partial
from typing import Optional, Callable, Any
from collections import OrderedDict
from mmseg.registry import MODELS
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange, repeat
from timm.models.layers import DropPath, trunc_normal_

# from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count, parameter_count

# Compact repr so printed model summaries show only the drop probability.
DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"

# import selective scan ==============================
# Every compiled kernel module is optional: an import failure is swallowed here, and any
# later reference to a module that failed to import raises NameError at call time.
try:
    import selective_scan_cuda_oflex
except Exception:
    pass
    # print(f"WARNING: can not import selective_scan_cuda_oflex.", flush=True)
    # print(e, flush=True)

try:
    import selective_scan_cuda_core
except Exception:
    pass
    # print(f"WARNING: can not import selective_scan_cuda_core.", flush=True)
    # print(e, flush=True)

try:
    import selective_scan_cuda
except Exception:
    pass
    # print(f"WARNING: can not import selective_scan_cuda.", flush=True)
    # print(e, flush=True)


# fvcore flops =======================================
def flops_selective_scan_fn(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_complex=False):
    """Closed-form FLOP estimate for one fused selective-scan kernel call.

    Argument shapes of the kernel being modelled (``r`` = real-valued):
        u: r(B D L), delta: r(B D L), A: r(D N), B: r(B N L), C: r(B N L),
        D: r(D), z: r(B D L), delta_bias: r(D), fp32.
    Ignored ops: [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]

    Returns ``9 * B * L * D * N`` for the scan, plus ``B * D * L`` for each of the
    optional ``D`` skip term and ``z`` gate that is enabled.
    """
    assert not with_complex
    # 9 flops per (b, l, d, n) element, see https://github.com/state-spaces/mamba/issues/110
    total = 9 * B * L * D * N
    for enabled in (with_D, with_Z):
        if enabled:
            total += B * D * L
    return total


# this is only for selective_scan_ref...
def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False):
    """FLOP estimate of the reference (pure-einsum) selective scan, using numpy's einsum planner.

    Argument shapes of the modelled op (``r`` = real-valued):
        u: r(B D L), delta: r(B D L), A: r(D N), B: r(B N L), C: r(B N L),
        D: r(D), z: r(B D L), delta_bias: r(D), fp32.
    Ignored ops: [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
    """
    import numpy as np

    # fvcore.nn.jit_handles
    def get_flops_einsum(input_shapes, equation):
        operands = [np.zeros(shape) for shape in input_shapes]
        report = np.einsum_path(equation, *operands, optimize="optimal")[1]
        for row in report.split("\n"):
            if "optimized flop" in row.lower():
                # divided by 2 because we count MAC (multiply-add counted as one flop)
                return float(np.floor(float(row.split(":")[-1]) / 2))

    assert not with_complex

    if with_Group:
        discretize_shapes, discretize_eq = [[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln"
        readout_shapes, readout_eq = [[B, D, N], [B, D, N]], "bdn,bdn->bd"
    else:
        discretize_shapes, discretize_eq = [[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln"
        readout_shapes, readout_eq = [[B, D, N], [B, N]], "bdn,bn->bd"

    flops = 0
    flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln")
    flops += get_flops_einsum(discretize_shapes, discretize_eq)

    # work repeated at each of the L sequential steps
    per_step = B * D * N
    per_step += get_flops_einsum(readout_shapes, readout_eq)
    flops += L * per_step

    if with_D:
        flops += B * D * L
    if with_Z:
        flops += B * D * L
    return flops


def print_jit_input_names(inputs):
    """Print the debug names of (at most) the first 10 traced JIT inputs on one line."""
    print("input params: ", end=" ", flush=True)
    try:
        for idx in range(10):
            name = inputs[idx].debugName()
            print(name, end=" ", flush=True)
    except Exception:
        # Best effort: stop quietly once the inputs run out.
        pass
    print("", flush=True)


# cross selective scan ===============================
class SelectiveScanMamba(torch.autograd.Function):
    """Autograd bridge to ``selective_scan_cuda`` using the SelectiveScanCore/Oflex signature.

    ``nrows``, ``backnrows`` and ``oflex`` are accepted only for signature compatibility and are
    not used. ``None`` is always passed for the kernel's ``z`` argument.
    """
    # comment all checks if inside cross_selective_scan

    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True):
        # assert nrows in [1, 2, 3, 4], f"{nrows}" # 8+ is too slow to compile
        # assert u.shape[1] % (B.shape[1] * nrows) == 0, f"{nrows}, {u.shape}, {B.shape}"
        ctx.delta_softplus = delta_softplus
        # all in float
        # if u.stride(-1) != 1:
        #     u = u.contiguous()
        # if delta.stride(-1) != 1:
        #     delta = delta.contiguous()
        # if D is not None and D.stride(-1) != 1:
        #     D = D.contiguous()
        # if B.stride(-1) != 1:
        #     B = B.contiguous()
        # if C.stride(-1) != 1:
        #     C = C.contiguous()
        # if B.dim() == 3:
        #     B = B.unsqueeze(dim=1)
        #     ctx.squeeze_B = True
        # if C.dim() == 3:
        #     C = C.unsqueeze(dim=1)
        #     ctx.squeeze_C = True
        out, scan_states, *_ = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus)
        ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, scan_states)
        return out

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, dout, *args):
        u, delta, A, B, C, D, delta_bias, scan_states = ctx.saved_tensors
        # the kernel expects a unit stride on the last axis
        dout = dout if dout.stride(-1) == 1 else dout.contiguous()
        grads = selective_scan_cuda.bwd(
            u, delta, A, B, C, D, None, delta_bias, dout, scan_states, None, None, ctx.delta_softplus, False
        )
        du, ddelta, dA, dB, dC, dD, ddelta_bias = grads[:7]
        # dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
        # dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
        # No gradient for delta_softplus, nrows, backnrows, oflex.
        return (du, ddelta, dA, dB, dC, dD, ddelta_bias) + (None,) * 4


# class SelectiveScanCore(torch.autograd.Function):
#     # comment all checks if inside cross_selective_scan
#     @staticmethod
#     @torch.cuda.amp.custom_fwd
#     def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True):
#         ctx.delta_softplus = delta_softplus
#         out, x, *rest = selective_scan_cuda_core.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1)
#         ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
#         return out
#
#     @staticmethod
#     @torch.cuda.amp.custom_bwd
#     def backward(ctx, dout, *args):
#         u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
#         if dout.stride(-1) != 1:
#             dout = dout.contiguous()
#         du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_core.bwd(
#             u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
#         )
#         return (du,
#                 ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None)
#
#
# class SelectiveScanOflex(torch.autograd.Function):
#     # comment all checks if inside cross_selective_scan
#     @staticmethod
#     @torch.cuda.amp.custom_fwd
#     def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True):
#         ctx.delta_softplus = delta_softplus
#         out, x, *rest = selective_scan_cuda_oflex.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1, oflex)
#         ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
#         return out
#
#     @staticmethod
#     @torch.cuda.amp.custom_bwd
#     def backward(ctx, dout, *args):
#         u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
#         if dout.stride(-1) != 1:
#             dout = dout.contiguous()
#         du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_oflex.bwd(
#             u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1
#         )
#         return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None)
#
#
# class SelectiveScanFake(torch.autograd.Function):
#     # comment all checks if inside cross_selective_scan
#     @staticmethod
#     @torch.cuda.amp.custom_fwd
#     def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True):
#         ctx.delta_softplus = delta_softplus
#         ctx.backnrows = backnrows
#         x = delta
#         out = u
#         ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
#         return out
#
#     @staticmethod
#     @torch.cuda.amp.custom_bwd
#     def backward(ctx, dout, *args):
#         u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
#         if dout.stride(-1) != 1:
#             dout = dout.contiguous()
#         du, ddelta, dA, dB, dC, dD, ddelta_bias = u * 0, delta * 0, A * 0, B * 0, C * 0, C * 0, (D * 0 if D else None), (delta_bias * 0 if delta_bias else None)
#         return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None)

# =============


def _wrapped_diagonal_index(H, W, device, anti=False):
    """Return the (H, W) column-index map used for wrapped (anti-)diagonal reordering.

    Row ``h`` is ``[0, 1, ..., W - 1]`` cyclically shifted by ``h``: ``(w + h) % W`` for the
    diagonal order, ``(w - h) % W`` for the antidiagonal order. Every row is a permutation of
    the column indices, so gathering and scattering with this map are exact inverses.
    """
    shift = torch.arange(H, device=device).unsqueeze(1)  # column vector, (H, 1)
    cols = torch.arange(W, device=device)  # (W,)
    index = cols - shift if anti else cols + shift  # broadcasts to (H, W)
    return index % W


def _gather_wrapped_diagonals(tensor, anti):
    """Flatten a (B, C, H, W) tensor into (B, C, H * W) along wrapped (anti-)diagonals."""
    B, C, H, W = tensor.size()
    index = _wrapped_diagonal_index(H, W, tensor.device, anti).expand(B, C, H, W)
    # gathered[..., h, w] = tensor[..., h, index[h, w]], so column w of ``gathered`` is one
    # wrapped diagonal read top to bottom; flattening the transpose emits the columns in turn.
    return tensor.gather(3, index).transpose(-1, -2).reshape(B, C, H * W)


def _scatter_wrapped_diagonals(tensor_flat, original_shape, anti):
    """Inverse of ``_gather_wrapped_diagonals``: rebuild the (B, C, H, W) tensor."""
    B, C, H, W = original_shape
    index = _wrapped_diagonal_index(H, W, tensor_flat.device, anti).expand(B, C, H, W)
    result = torch.zeros(B, C, H, W, device=tensor_flat.device, dtype=tensor_flat.dtype)
    # Undo the transpose-then-flatten done by the gather, then put each element back in its column.
    src = tensor_flat.reshape(B, C, W, H).transpose(-1, -2)
    result.scatter_(3, index, src)
    return result


def antidiagonal_gather(tensor):
    """Concatenate all wrapped antidiagonals of a (B, C, H, W) tensor into (B, C, H * W).

    Sequence ``w`` reads ``tensor[..., h, (w - h) % W]`` for ``h = 0 .. H - 1``.
    """
    return _gather_wrapped_diagonals(tensor, anti=True)


def diagonal_gather(tensor):
    """Concatenate all wrapped (main-direction) diagonals of a (B, C, H, W) tensor into (B, C, H * W).

    Sequence ``w`` reads ``tensor[..., h, (w + h) % W]`` for ``h = 0 .. H - 1``.
    """
    return _gather_wrapped_diagonals(tensor, anti=False)


def diagonal_scatter(tensor_flat, original_shape):
    """Restore the (B, C, H, W) matrix from the flat output of ``diagonal_gather``."""
    return _scatter_wrapped_diagonals(tensor_flat, original_shape, anti=False)


def antidiagonal_scatter(tensor_flat, original_shape):
    """Restore the (B, C, H, W) matrix from the flat output of ``antidiagonal_gather``."""
    return _scatter_wrapped_diagonals(tensor_flat, original_shape, anti=True)


class CrossScan(torch.autograd.Function):
    """Unfold a (B, C, H, W) map into 8 directional sequences of shape (B, 8, C, H * W).

    Slots: 0 row-major, 1 column-major, 2/3 reversed 0/1,
    4 wrapped diagonal, 5 wrapped antidiagonal, 6/7 reversed 4/5.
    ZSJ: this is where the image is flattened along each scan direction; change the scan
    directions here (and mirror the change in ``CrossMerge``).
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        B, C, H, W = x.shape
        ctx.shape = (B, C, H, W)
        # The original 4-direction version allocated 4 slots; 8 adds the diagonal family.
        xs = x.new_empty((B, 8, C, H * W))
        # horizontal and vertical scans
        xs[:, 0] = x.flatten(2, 3)
        xs[:, 1] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
        xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
        # diagonal and antidiagonal scans
        xs[:, 4] = diagonal_gather(x)
        xs[:, 5] = antidiagonal_gather(x)
        xs[:, 6:8] = torch.flip(xs[:, 4:6], dims=[-1])
        return xs

    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        # ys: (b, k, d, l), gradient w.r.t. the forward output
        B, C, H, W = ctx.shape
        L = H * W
        # Un-reverse slots 2/3 and add them to slots 0/1.
        y_rb = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
        # Map the column-major slot back to row-major order, sum, and restore the (H, W) layout.
        y_rb = y_rb[:, 0] + y_rb[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
        y_rb = y_rb.view(B, -1, H, W)
        # Same for the diagonal family: un-reverse 6/7 into 4/5, then scatter both back to (H, W).
        y_da = ys[:, 4:6] + ys[:, 6:8].flip(dims=[-1]).view(B, 2, -1, L)
        y_da = diagonal_scatter(y_da[:, 0], (B, C, H, W)) + antidiagonal_scatter(y_da[:, 1], (B, C, H, W))
        return y_rb + y_da


class CrossMerge(torch.autograd.Function):
    """Fold 8 directional sequences (B, 8, D, H, W) back into one row-major map (B, D, H * W).

    Each slot is mapped back to row-major order (undoing the ``CrossScan`` ordering) and the
    eight results are summed.
    """

    @staticmethod
    def forward(ctx, ys: torch.Tensor):
        B, K, D, H, W = ys.shape
        ctx.shape = (H, W)
        ys = ys.view(B, K, D, -1)
        # Un-reverse slots 2/3 and add them to slots 0/1.
        y_rb = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
        # Turn the column-major slot back into row-major order, add, and restore the (H, W) layout.
        y_rb = y_rb[:, 0] + y_rb[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
        y_rb = y_rb.view(B, -1, H, W)
        # Un-reverse slots 6/7 and add them to the diagonal/antidiagonal slots 4/5.
        y_da = ys[:, 4:6] + ys[:, 6:8].flip(dims=[-1]).view(B, 2, D, -1)
        # Scatter both diagonal families back to the (H, W) layout and add them.
        y_da = diagonal_scatter(y_da[:, 0], (B, D, H, W)) + antidiagonal_scatter(y_da[:, 1], (B, D, H, W))
        y_res = y_rb + y_da
        return y_res.view(B, D, -1)

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        # x: (B, D, L) gradient; returns (b, k, d, h, w)
        H, W = ctx.shape
        B, C, L = x.shape
        x_2d = x.view(B, C, H, W)
        xs = x.new_empty((B, 8, C, L))
        # horizontal and vertical scans
        xs[:, 0] = x
        xs[:, 1] = x_2d.transpose(dim0=2, dim1=3).flatten(2, 3)
        xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
        # diagonal and antidiagonal scans
        xs[:, 4] = diagonal_gather(x_2d)
        xs[:, 5] = antidiagonal_gather(x_2d)
        xs[:, 6:8] = torch.flip(xs[:, 4:6], dims=[-1])
        return xs.view(B, 8, C, H, W)


""" ablation exp """


# these are for ablations =============
# class CrossScan_Ab_2direction(torch.autograd.Function):
#     @staticmethod
#     def forward(ctx, x: torch.Tensor):
#         B, C, H, W = x.shape
#         ctx.shape = (B, C, H, W)
#         xs = x.new_empty((B, 4, C, H * W))
#         xs[:, 0] = x.flatten(2, 3)
#         xs[:, 1] = x.flatten(2, 3)
#         xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1])
#         return xs
#
#     @staticmethod
#     def backward(ctx, ys: torch.Tensor):
#         # out: (b, k, d, l)
#         B, C, H, W = ctx.shape
#         L = H * W
#         ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
#         y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
#         return y.view(B, -1, H, W)
#
#
# class CrossMerge_Ab_2direction(torch.autograd.Function):
#     @staticmethod
#     def forward(ctx, ys: torch.Tensor):
#         B, K, D, H, W = ys.shape
#         ctx.shape = (H, W)
#         ys = ys.view(B, K, D, -1)
#         ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
#         y = ys.sum(dim=1)
#         return y
#
#     @staticmethod
#     def backward(ctx, x: torch.Tensor):
#         # B, D, L = x.shape
#         # out: (b, k, d, l)
#         H, W = ctx.shape
#         B, C, L = x.shape
#         xs = x.new_empty((B, 4, C, L))
#         xs[:,
0] = x # xs[:, 1] = x # xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1]) # xs = xs.view(B, 4, C, H, W) # return xs # # # class CrossScan_Ab_1direction(torch.autograd.Function): # @staticmethod # def forward(ctx, x: torch.Tensor): # B, C, H, W = x.shape # ctx.shape = (B, C, H, W) # xs = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1).contiguous() # return xs # # @staticmethod # def backward(ctx, ys: torch.Tensor): # # out: (b, k, d, l) # B, C, H, W = ctx.shape # y = ys.sum(dim=1).view(B, C, H, W) # return y # # # class CrossMerge_Ab_1direction(torch.autograd.Function): # @staticmethod # def forward(ctx, ys: torch.Tensor): # B, K, D, H, W = ys.shape # ctx.shape = (H, W) # y = ys.sum(dim=1).view(B, D, H * W) # return y # # @staticmethod # def backward(ctx, x: torch.Tensor): # # B, D, L = x.shape # # out: (b, k, d, l) # H, W = ctx.shape # B, C, L = x.shape # xs = x.view(B, 1, C, L).repeat(1, 4, 1, 1).contiguous().view(B, 4, C, H, W) # return xs # # ============= # ZSJ 这里是mamba的具体内容,要增加扫描方向就在这里改 class SelectiveScanFn(torch.autograd.Function): @staticmethod def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False): if u.stride(-1) != 1: u = u.contiguous() if delta.stride(-1) != 1: delta = delta.contiguous() if D is not None: D = D.contiguous() if B.stride(-1) != 1: B = B.contiguous() if C.stride(-1) != 1: C = C.contiguous() if z is not None and z.stride(-1) != 1: z = z.contiguous() if B.dim() == 3: B = rearrange(B, "b dstate l -> b 1 dstate l") ctx.squeeze_B = True if C.dim() == 3: C = rearrange(C, "b dstate l -> b 1 dstate l") ctx.squeeze_C = True out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus) ctx.delta_softplus = delta_softplus ctx.has_z = z is not None last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) if not ctx.has_z: ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) return out if not return_last_state else (out, last_state) else: ctx.save_for_backward(u, 
delta, A, B, C, D, z, delta_bias, x, out) out_z = rest[0] return out_z if not return_last_state else (out_z, last_state) @staticmethod def backward(ctx, dout, *args): if not ctx.has_z: u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors z = None out = None else: u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors if dout.stride(-1) != 1: dout = dout.contiguous() # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the # backward of selective_scan_cuda with the backward of chunk). # Here we just pass in None and dz will be allocated in the C++ code. du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus, False # option to recompute out_z, not used here ) dz = rest[0] if ctx.has_z else None dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC return (du, ddelta, dA, dB, dC, dD if D is not None else None, dz, ddelta_bias if delta_bias is not None else None, None, None) def cross_selective_scan( x: torch.Tensor = None, x_proj_weight: torch.Tensor = None, x_proj_bias: torch.Tensor = None, dt_projs_weight: torch.Tensor = None, dt_projs_bias: torch.Tensor = None, A_logs: torch.Tensor = None, Ds: torch.Tensor = None, delta_softplus=True, out_norm: torch.nn.Module = None, out_norm_shape="v0", # ============================== to_dtype=True, # True: final out to dtype force_fp32=False, # True: input fp32 # ============================== nrows=-1, # for SelectiveScanNRow; 0: auto; -1: disable; backnrows=-1, # for SelectiveScanNRow; 0: auto; -1: disable; ssoflex=True, # True: out fp32 in SSOflex; else, SSOflex is the same as SSCore # ============================== # SelectiveScan=None, CrossScan=CrossScan, CrossMerge=CrossMerge, ): # out_norm: whatever fits (B, L, C); LayerNorm; Sigmoid; Softmax(dim=1);... 
B, D, H, W = x.shape D, N = A_logs.shape K, D, R = dt_projs_weight.shape L = H * W if nrows == 0: if D % 4 == 0: nrows = 4 elif D % 3 == 0: nrows = 3 elif D % 2 == 0: nrows = 2 else: nrows = 1 if backnrows == 0: if D % 4 == 0: backnrows = 4 elif D % 3 == 0: backnrows = 3 elif D % 2 == 0: backnrows = 2 else: backnrows = 1 # sacn jiushizhege bata huancheng houmian nage banben # def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True): # return SelectiveScan.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows, backnrows, ssoflex) def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False): """if return_last_state is True, returns (out, last_state) last_state has shape (batch, dim, dstate). Note that the gradient of the last state is not considered in the backward pass. """ return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state) xs = CrossScan.apply(x) x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, x_proj_weight) if x_proj_bias is not None: x_dbl = x_dbl + x_proj_bias.view(1, K, -1, 1) dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2) dts = torch.einsum("b k r l, k d r -> b k d l", dts, dt_projs_weight) xs = xs.view(B, -1, L) dts = dts.contiguous().view(B, -1, L) As = -torch.exp(A_logs.to(torch.float)) # (k * c, d_state) Bs = Bs.contiguous() Cs = Cs.contiguous() Ds = Ds.to(torch.float) # (K * c) delta_bias = dt_projs_bias.view(-1).to(torch.float) if force_fp32: xs = xs.to(torch.float) dts = dts.to(torch.float) Bs = Bs.to(torch.float) Cs = Cs.to(torch.float) # ZSJ 这里把矩阵拆分成不同方向的序列,并进行扫描 ys: torch.Tensor = selective_scan_fn( xs, dts, As, Bs, Cs, Ds, z=None, delta_bias=delta_bias, delta_softplus=True, return_last_state=False, ).view(B, K, -1, H, W) # ZSJ 这里把处理之后的序列融合起来,并还原回原来的矩阵形式 y: torch.Tensor = CrossMerge.apply(ys) if out_norm_shape in ["v1"]: # (B, C, H, W) y = out_norm(y.view(B, -1, H, W)).permute(0, 2, 3, 1) # (B, H, 
W, C) else: # (B, L, C) y = y.transpose(dim0=1, dim1=2).contiguous() # (B, L, C) y = out_norm(y).view(B, H, W, -1) return (y.to(x.dtype) if to_dtype else y) def selective_scan_flop_jit(inputs, outputs): print_jit_input_names(inputs) B, D, L = inputs[0].type().sizes() N = inputs[2].type().sizes()[1] flops = flops_selective_scan_fn(B=B, L=L, D=D, N=N, with_D=True, with_Z=False) return flops # ===================================================== class PatchMerging2D(nn.Module): def __init__(self, dim, out_dim=-1, norm_layer=nn.LayerNorm): super().__init__() self.dim = dim self.reduction = nn.Linear(4 * dim, (2 * dim) if out_dim < 0 else out_dim, bias=False) self.norm = norm_layer(4 * dim) @staticmethod def _patch_merging_pad(x: torch.Tensor): H, W, _ = x.shape[-3:] if (W % 2 != 0) or (H % 2 != 0): x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C x = torch.cat([x0, x1, x2, x3], -1) # ... 
H/2 W/2 4*C return x def forward(self, x): x = self._patch_merging_pad(x) x = self.norm(x) x = self.reduction(x) return x class OSSM(nn.Module): def __init__( self, # basic dims =========== d_model=96, d_state=16, ssm_ratio=2.0, dt_rank="auto", act_layer=nn.SiLU, # dwconv =============== d_conv=3, # < 2 means no conv conv_bias=True, # ====================== dropout=0.0, bias=False, # dt init ============== dt_min=0.001, dt_max=0.1, dt_init="random", dt_scale=1.0, dt_init_floor=1e-4, initialize="v0", # ====================== forward_type="v2", # ====================== **kwargs, ): factory_kwargs = {"device": None, "dtype": None} super().__init__() d_inner = int(ssm_ratio * d_model) dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank self.d_conv = d_conv # tags for forward_type ============================== def checkpostfix(tag, value): ret = value[-len(tag):] == tag if ret: value = value[:-len(tag)] return ret, value self.disable_force32, forward_type = checkpostfix("no32", forward_type) self.disable_z, forward_type = checkpostfix("noz", forward_type) self.disable_z_act, forward_type = checkpostfix("nozact", forward_type) # softmax | sigmoid | dwconv | norm =========================== if forward_type[-len("none"):] == "none": forward_type = forward_type[:-len("none")] self.out_norm = nn.Identity() elif forward_type[-len("dwconv3"):] == "dwconv3": forward_type = forward_type[:-len("dwconv3")] self.out_norm = nn.Conv2d(d_inner, d_inner, kernel_size=3, padding=1, groups=d_inner, bias=False) self.out_norm_shape = "v1" elif forward_type[-len("softmax"):] == "softmax": forward_type = forward_type[:-len("softmax")] self.out_norm = nn.Softmax(dim=1) elif forward_type[-len("sigmoid"):] == "sigmoid": forward_type = forward_type[:-len("sigmoid")] self.out_norm = nn.Sigmoid() else: self.out_norm = nn.LayerNorm(d_inner) # # forward_type debug ======================================= # FORWARD_TYPES = dict( # v0=self.forward_corev0, # # 
v2=partial(self.forward_corev2, force_fp32=(not self.disable_force32), SelectiveScan=SelectiveScanCore), # v2=partial(self.forward_corev2, force_fp32=True, SelectiveScan=SelectiveScanCore), # v3=partial(self.forward_corev2, force_fp32=False, SelectiveScan=SelectiveScanOflex), # v31d=partial(self.forward_corev2, force_fp32=False, SelectiveScan=SelectiveScanOflex, cross_selective_scan=partial( # cross_selective_scan, CrossScan=CrossScan_Ab_1direction, CrossMerge=CrossMerge_Ab_1direction, # )), # v32d=partial(self.forward_corev2, force_fp32=False, SelectiveScan=SelectiveScanOflex, cross_selective_scan=partial( # cross_selective_scan, CrossScan=CrossScan_Ab_2direction, CrossMerge=CrossMerge_Ab_2direction, # )), # # =============================== # fake=partial(self.forward_corev2, force_fp32=(not self.disable_force32), SelectiveScan=SelectiveScanFake), # v1=partial(self.forward_corev2, force_fp32=True, SelectiveScan=SelectiveScanOflex), # v01=partial(self.forward_corev2, force_fp32=(not self.disable_force32), SelectiveScan=SelectiveScanMamba), # ) # if forward_type.startswith("debug"): # from .ss2d_ablations import SS2D_ForwardCoreSpeedAblations, SS2D_ForwardCoreModeAblations, cross_selective_scanv2 # FORWARD_TYPES.update(dict( # debugforward_core_mambassm_seq=partial(SS2D_ForwardCoreSpeedAblations.forward_core_mambassm_seq, self), # debugforward_core_mambassm=partial(SS2D_ForwardCoreSpeedAblations.forward_core_mambassm, self), # debugforward_core_mambassm_fp16=partial(SS2D_ForwardCoreSpeedAblations.forward_core_mambassm_fp16, self), # debugforward_core_mambassm_fusecs=partial(SS2D_ForwardCoreSpeedAblations.forward_core_mambassm_fusecs, self), # debugforward_core_mambassm_fusecscm=partial(SS2D_ForwardCoreSpeedAblations.forward_core_mambassm_fusecscm, self), # debugforward_core_sscore_fusecscm=partial(SS2D_ForwardCoreSpeedAblations.forward_core_sscore_fusecscm, self), # 
debugforward_core_sscore_fusecscm_fwdnrow=partial(SS2D_ForwardCoreSpeedAblations.forward_core_ssnrow_fusecscm_fwdnrow, self), # debugforward_core_sscore_fusecscm_bwdnrow=partial(SS2D_ForwardCoreSpeedAblations.forward_core_ssnrow_fusecscm_bwdnrow, self), # debugforward_core_sscore_fusecscm_fbnrow=partial(SS2D_ForwardCoreSpeedAblations.forward_core_ssnrow_fusecscm_fbnrow, self), # debugforward_core_ssoflex_fusecscm=partial(SS2D_ForwardCoreSpeedAblations.forward_core_ssoflex_fusecscm, self), # debugforward_core_ssoflex_fusecscm_i16o32=partial(SS2D_ForwardCoreSpeedAblations.forward_core_ssoflex_fusecscm_i16o32, self), # debugscan_sharessm=partial(self.forward_corev2, force_fp32=False, SelectiveScan=SelectiveScanOflex, cross_selective_scan=cross_selective_scanv2), # )) # self.forward_core = FORWARD_TYPES.get(forward_type, None) # ZSJ k_group 指的是扫描的方向 # k_group = 4 if forward_type not in ["debugscan_sharessm"] else 1 k_group = 8 if forward_type not in ["debugscan_sharessm"] else 1 # in proj ======================================= d_proj = d_inner if self.disable_z else (d_inner * 2) self.in_proj = nn.Linear(d_model, d_proj, bias=bias, **factory_kwargs) self.act: nn.Module = act_layer() # conv ======================================= if d_conv > 1: self.conv2d = nn.Conv2d( in_channels=d_inner, out_channels=d_inner, groups=d_inner, bias=conv_bias, kernel_size=d_conv, padding=(d_conv - 1) // 2, **factory_kwargs, ) # x proj ============================ self.x_proj = [ nn.Linear(d_inner, (dt_rank + d_state * 2), bias=False, **factory_kwargs) for _ in range(k_group) ] self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner) del self.x_proj # out proj ======================================= self.out_proj = nn.Linear(d_inner, d_model, bias=bias, **factory_kwargs) self.dropout = nn.Dropout(dropout) if dropout > 0. 
else nn.Identity() if initialize in ["v0"]: # dt proj ============================ self.dt_projs = [ self.dt_init(dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs) for _ in range(k_group) ] self.dt_projs_weight = nn.Parameter( torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K, inner, rank) self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K, inner) del self.dt_projs # A, D ======================================= self.A_logs = self.A_log_init(d_state, d_inner, copies=k_group, merge=True) # (K * D, N) self.Ds = self.D_init(d_inner, copies=k_group, merge=True) # (K * D) elif initialize in ["v1"]: # simple init dt_projs, A_logs, Ds self.Ds = nn.Parameter(torch.ones((k_group * d_inner))) self.A_logs = nn.Parameter( torch.randn((k_group * d_inner, d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1 self.dt_projs_weight = nn.Parameter(torch.randn((k_group, d_inner, dt_rank))) self.dt_projs_bias = nn.Parameter(torch.randn((k_group, d_inner))) elif initialize in ["v2"]: # simple init dt_projs, A_logs, Ds self.Ds = nn.Parameter(torch.ones((k_group * d_inner))) self.A_logs = nn.Parameter( torch.zeros((k_group * d_inner, d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1 self.dt_projs_weight = nn.Parameter(torch.randn((k_group, d_inner, dt_rank))) self.dt_projs_bias = nn.Parameter(torch.randn((k_group, d_inner))) @staticmethod def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs): dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs) # Initialize special dt projection to preserve variance at initialization dt_init_std = dt_rank ** -0.5 * dt_scale if dt_init == "constant": nn.init.constant_(dt_proj.weight, dt_init_std) elif dt_init == "random": nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) else: raise NotImplementedError # Initialize dt bias so that F.softplus(dt_bias) is between 
dt_min and dt_max dt = torch.exp( torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) ).clamp(min=dt_init_floor) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) with torch.no_grad(): dt_proj.bias.copy_(inv_dt) # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit # dt_proj.bias._no_reinit = True return dt_proj @staticmethod def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True): # S4D real initialization A = repeat( torch.arange(1, d_state + 1, dtype=torch.float32, device=device), "n -> d n", d=d_inner, ).contiguous() A_log = torch.log(A) # Keep A_log in fp32 if copies > 0: A_log = repeat(A_log, "d n -> r d n", r=copies) if merge: A_log = A_log.flatten(0, 1) A_log = nn.Parameter(A_log) A_log._no_weight_decay = True return A_log @staticmethod def D_init(d_inner, copies=-1, device=None, merge=True): # D "skip" parameter D = torch.ones(d_inner, device=device) if copies > 0: D = repeat(D, "n1 -> r n1", r=copies) if merge: D = D.flatten(0, 1) D = nn.Parameter(D) # Keep in fp32 D._no_weight_decay = True return D # only used to run previous version # def forward_corev0(self, x: torch.Tensor, to_dtype=False, channel_first=False): # def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True, nrows=1): # return SelectiveScanCore.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows, False) # # if not channel_first: # x = x.permute(0, 3, 1, 2).contiguous() # B, D, H, W = x.shape # D, N = self.A_logs.shape # K, D, R = self.dt_projs_weight.shape # L = H * W # # # ZSJ 这里进行data expand操作,也就是把相同的数据在不同方向展开成一维,并拼接起来,但是这个函数只用在旧版本 # # 把横向和竖向拼接在K维度 # x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L) # # torch.flip把横向和竖向两个方向都进行反向操作 # xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l) # # 
x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight) # # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1) # dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2) # dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight) # # xs = xs.float().view(B, -1, L) # (b, k * d, l) # dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l) # Bs = Bs.float() # (b, k, d_state, l) # Cs = Cs.float() # (b, k, d_state, l) # # As = -torch.exp(self.A_logs.float()) # (k * d, d_state) # Ds = self.Ds.float() # (k * d) # dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d) # # # assert len(xs.shape) == 3 and len(dts.shape) == 3 and len(Bs.shape) == 4 and len(Cs.shape) == 4 # # assert len(As.shape) == 2 and len(Ds.shape) == 1 and len(dt_projs_bias.shape) == 1 # # out_y = selective_scan( # xs, dts, # As, Bs, Cs, Ds, # delta_bias=dt_projs_bias, # delta_softplus=True, # ).view(B, K, -1, L) # # assert out_y.dtype == torch.float # # inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) # wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) # invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) # y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y # y = y.transpose(dim0=1, dim1=2).contiguous() # (B, L, C) # y = self.out_norm(y).view(B, H, W, -1) # # return (y.to(x.dtype) if to_dtype else y) def forward_corev2(self, x: torch.Tensor, channel_first=False, cross_selective_scan=cross_selective_scan, force_fp32=None): if not channel_first: x = x.permute(0, 3, 1, 2).contiguous() # ZSJ V2版本使用的mamba,要改扫描方向在这里改 x = cross_selective_scan( x, self.x_proj_weight, None, self.dt_projs_weight, self.dt_projs_bias, self.A_logs, self.Ds, delta_softplus=True, out_norm=getattr(self, "out_norm", None), out_norm_shape=getattr(self, "out_norm_shape", "v0"), force_fp32=force_fp32, ) return x def forward(self, x: torch.Tensor, **kwargs): with_dconv = (self.d_conv > 1) x = 
self.in_proj(x) if not self.disable_z: x, z = x.chunk(2, dim=-1) # (b, h, w, d) if not self.disable_z_act: z = self.act(z) if with_dconv: x = x.permute(0, 3, 1, 2).contiguous() x = self.conv2d(x) # (b, d, h, w) x = self.act(x) y = self.forward_corev2(x, channel_first=with_dconv) if not self.disable_z: y = y * z out = self.dropout(self.out_proj(y)) return out class Permute(nn.Module): def __init__(self, *args): super().__init__() self.args = args def forward(self, x: torch.Tensor): return x.permute(*self.args) class Mlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., channels_first=False): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features Linear = partial(nn.Conv2d, kernel_size=1, padding=0) if channels_first else nn.Linear self.fc1 = Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class OSSBlock(nn.Module): def __init__( self, hidden_dim: int = 0, drop_path: float = 0, norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), # ============================= ssm_d_state: int = 16, ssm_ratio=2.0, ssm_dt_rank: Any = "auto", ssm_act_layer=nn.SiLU, ssm_conv: int = 3, ssm_conv_bias=True, ssm_drop_rate: float = 0, ssm_init="v0", forward_type="v2", # ============================= mlp_ratio=4.0, mlp_act_layer=nn.GELU, mlp_drop_rate: float = 0.0, # ============================= use_checkpoint: bool = False, post_norm: bool = False, **kwargs, ): super().__init__() self.ssm_branch = ssm_ratio > 0 self.mlp_branch = mlp_ratio > 0 self.use_checkpoint = use_checkpoint self.post_norm = post_norm try: from ss2d_ablations import SS2DDev _OSSM = SS2DDev if forward_type.startswith("dev") else OSSM except: _OSSM = OSSM if self.ssm_branch: 
self.norm = norm_layer(hidden_dim) self.op = _OSSM( d_model=hidden_dim, d_state=ssm_d_state, ssm_ratio=ssm_ratio, dt_rank=ssm_dt_rank, act_layer=ssm_act_layer, # ========================== d_conv=ssm_conv, conv_bias=ssm_conv_bias, # ========================== dropout=ssm_drop_rate, # bias=False, # ========================== # dt_min=0.001, # dt_max=0.1, # dt_init="random", # dt_scale="random", # dt_init_floor=1e-4, initialize=ssm_init, # ========================== forward_type=forward_type, ) self.drop_path = DropPath(drop_path) if self.mlp_branch: self.norm2 = norm_layer(hidden_dim) mlp_hidden_dim = int(hidden_dim * mlp_ratio) self.mlp = Mlp(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=mlp_act_layer, drop=mlp_drop_rate, channels_first=False) def _forward(self, input: torch.Tensor): if self.ssm_branch: if self.post_norm: x = input + self.drop_path(self.norm(self.op(input))) else: x = input + self.drop_path(self.op(self.norm(input))) if self.mlp_branch: if self.post_norm: x = x + self.drop_path(self.norm2(self.mlp(x))) # FFN else: x = x + self.drop_path(self.mlp(self.norm2(x))) # FFN return x def forward(self, input: torch.Tensor): if self.use_checkpoint: return checkpoint.checkpoint(self._forward, input) else: return self._forward(input) class Decoder_Block(nn.Module): """Basic block in decoder.""" def __init__(self, in_channel, out_channel): super().__init__() assert out_channel == in_channel // 2, 'the out_channel is not in_channel//2 in decoder block' self.up = nn.Upsample(scale_factor=2, mode='nearest') self.fuse = nn.Sequential(nn.Conv2d(in_channels=in_channel + out_channel, out_channels=out_channel, kernel_size=1, padding=0, bias=False), nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True), ) def forward(self, de, en): de = self.up(de) output = torch.cat([de, en], dim=1) output = self.fuse(output) return output @MODELS.register_module() class RSM_SS(nn.Module): def __init__( self, patch_size=4, in_chans=3, num_classes=1000, depths=[2, 2, 9, 
2], dims=[96, 192, 384, 768], # ========================= ssm_d_state=16, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu", ssm_conv=3, ssm_conv_bias=True, ssm_drop_rate=0.0, ssm_init="v0", forward_type="v2", # ========================= mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, # ========================= drop_path_rate=0.1, patch_norm=True, norm_layer="LN", use_checkpoint=False, **kwargs, ): super().__init__() self.num_classes = num_classes self.num_layers = len(depths) if isinstance(dims, int): dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)] self.num_features = dims[-1] self.dims = dims dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule _NORMLAYERS = dict( ln=nn.LayerNorm, bn=nn.BatchNorm2d, ) _ACTLAYERS = dict( silu=nn.SiLU, gelu=nn.GELU, relu=nn.ReLU, sigmoid=nn.Sigmoid, ) if isinstance(norm_layer, str) and norm_layer.lower() in ["ln"]: norm_layer: nn.Module = _NORMLAYERS[norm_layer.lower()] if isinstance(ssm_act_layer, str) and ssm_act_layer.lower() in ["silu", "gelu", "relu"]: ssm_act_layer: nn.Module = _ACTLAYERS[ssm_act_layer.lower()] if isinstance(mlp_act_layer, str) and mlp_act_layer.lower() in ["silu", "gelu", "relu"]: mlp_act_layer: nn.Module = _ACTLAYERS[mlp_act_layer.lower()] _make_patch_embed = self._make_patch_embed_v2 self.patch_embed = _make_patch_embed(in_chans, dims[0], patch_size, patch_norm, norm_layer) _make_downsample = self._make_downsample_v3 # self.encoder_layers = [nn.ModuleList()] * self.num_layers self.encoder_layers = [] self.decoder_layers = [] for i_layer in range(self.num_layers): # downsample = _make_downsample( # self.dims[i_layer], # self.dims[i_layer + 1], # norm_layer=norm_layer, # ) if (i_layer < self.num_layers - 1) else nn.Identity() downsample = _make_downsample( self.dims[i_layer - 1], self.dims[i_layer], norm_layer=norm_layer, ) if (i_layer != 0) else nn.Identity() # ZSJ 修改为i_layer != 0,也就是第一层不下采样,和论文的图保持一致,也方便我取出每个尺度处理好的特征 
self.encoder_layers.append(self._make_layer( dim=self.dims[i_layer], drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], use_checkpoint=use_checkpoint, norm_layer=norm_layer, downsample=downsample, # ================= ssm_d_state=ssm_d_state, ssm_ratio=ssm_ratio, ssm_dt_rank=ssm_dt_rank, ssm_act_layer=ssm_act_layer, ssm_conv=ssm_conv, ssm_conv_bias=ssm_conv_bias, ssm_drop_rate=ssm_drop_rate, ssm_init=ssm_init, forward_type=forward_type, # ================= mlp_ratio=mlp_ratio, mlp_act_layer=mlp_act_layer, mlp_drop_rate=mlp_drop_rate, )) if i_layer != 0: self.decoder_layers.append( Decoder_Block(in_channel=self.dims[i_layer], out_channel=self.dims[i_layer - 1])) self.encoder_block1, self.encoder_block2, self.encoder_block3, self.encoder_block4 = self.encoder_layers self.deocder_block1, self.deocder_block2, self.deocder_block3 = self.decoder_layers self.upsample_x4 = nn.Sequential( nn.Conv2d(self.dims[0], self.dims[0] // 2, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(self.dims[0] // 2), nn.ReLU(inplace=True), nn.UpsamplingBilinear2d(scale_factor=2), # nn.Conv2d(self.dims[0] // 2, 8, kernel_size=3, stride=1, padding=1), nn.Conv2d(self.dims[0] // 2, self.dims[0] // 2, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(48), nn.ReLU(inplace=True), nn.UpsamplingBilinear2d(scale_factor=2) ) # self.conv_out_seg = nn.Conv2d(8, 1, kernel_size=7, stride=1, padding=3) self.apply(self._init_weights) def _init_weights(self, m: nn.Module): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) @staticmethod def _make_patch_embed_v2(in_chans=3, embed_dim=96, patch_size=4, patch_norm=True, norm_layer=nn.LayerNorm): assert patch_size == 4 return nn.Sequential( nn.Conv2d(in_chans, embed_dim // 2, kernel_size=3, stride=2, padding=1), (Permute(0, 2, 3, 1) if patch_norm else 
nn.Identity()), (norm_layer(embed_dim // 2) if patch_norm else nn.Identity()), (Permute(0, 3, 1, 2) if patch_norm else nn.Identity()), nn.GELU(), nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1), Permute(0, 2, 3, 1), (norm_layer(embed_dim) if patch_norm else nn.Identity()), ) @staticmethod def _make_downsample_v3(dim=96, out_dim=192, norm_layer=nn.LayerNorm): return nn.Sequential( Permute(0, 3, 1, 2), nn.Conv2d(dim, out_dim, kernel_size=3, stride=2, padding=1), Permute(0, 2, 3, 1), norm_layer(out_dim), ) @staticmethod def _make_layer( dim=96, drop_path=[0.1, 0.1], use_checkpoint=False, norm_layer=nn.LayerNorm, downsample=nn.Identity(), # =========================== ssm_d_state=16, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer=nn.SiLU, ssm_conv=3, ssm_conv_bias=True, ssm_drop_rate=0.0, ssm_init="v0", forward_type="v2", # =========================== mlp_ratio=4.0, mlp_act_layer=nn.GELU, mlp_drop_rate=0.0, **kwargs, ): depth = len(drop_path) blocks = [] for d in range(depth): blocks.append(OSSBlock( hidden_dim=dim, drop_path=drop_path[d], norm_layer=norm_layer, ssm_d_state=ssm_d_state, ssm_ratio=ssm_ratio, ssm_dt_rank=ssm_dt_rank, ssm_act_layer=ssm_act_layer, ssm_conv=ssm_conv, ssm_conv_bias=ssm_conv_bias, ssm_drop_rate=ssm_drop_rate, ssm_init=ssm_init, forward_type=forward_type, mlp_ratio=mlp_ratio, mlp_act_layer=mlp_act_layer, mlp_drop_rate=mlp_drop_rate, use_checkpoint=use_checkpoint, )) return nn.Sequential(OrderedDict( # ZSJ 把downsample放到前面来,方便我取出encoder中每个尺度处理好的图像,而不是刚刚下采样完的图像 downsample=downsample, blocks=nn.Sequential(*blocks, ), )) def forward(self, x1: torch.Tensor): x1 = self.patch_embed(x1) x1_1 = self.encoder_block1(x1) x1_2 = self.encoder_block2(x1_1) x1_3 = self.encoder_block3(x1_2) x1_4 = self.encoder_block4(x1_3) # b,h,w,c x1_1 = rearrange(x1_1, "b h w c -> b c h w").contiguous() x1_2 = rearrange(x1_2, "b h w c -> b c h w").contiguous() x1_3 = rearrange(x1_3, "b h w c -> b c h w").contiguous() x1_4 = rearrange(x1_4, "b h w c 
-> b c h w").contiguous() decode_3 = self.deocder_block3(x1_4, x1_3) decode_2 = self.deocder_block2(decode_3, x1_2) decode_1 = self.deocder_block1(decode_2, x1_1) output = self.upsample_x4(decode_1) # output = self.conv_out_seg(output) return output # if __name__=="__main__": # net=RSM_SS().to("cuda:1") # img=torch.randn(2,3,512,512).to("cuda:1") # out=net(img) # print(out.size()) if __name__ == '__main__': # Comment batchnorms here and in model_utils before testing speed since the batchnorm could be integrated into conv operation # (do not comment all, just the batchnorm following its corresponding conv layer) device = torch.device('cuda:0') model = RSM_SS() model.eval() model.to(device) iterations = None input = torch.randn(1, 3, 512, 512).to(device) with torch.no_grad(): for _ in range(10): model(input) if iterations is None: elapsed_time = 0 iterations = 100 while elapsed_time < 1: torch.cuda.synchronize() torch.cuda.synchronize() t_start = time.time() for _ in range(iterations): model(input) torch.cuda.synchronize() torch.cuda.synchronize() elapsed_time = time.time() - t_start iterations *= 2 FPS = iterations / elapsed_time iterations = int(FPS * 6) print('=========Speed Testing=========') torch.cuda.synchronize() torch.cuda.synchronize() t_start = time.time() for _ in range(iterations): model(input) torch.cuda.synchronize() torch.cuda.synchronize() elapsed_time = time.time() - t_start latency = elapsed_time / iterations * 1000 torch.cuda.empty_cache() FPS = 1000 / latency print(FPS)