import torch
import torch.nn as nn
from functools import partial

import math

# from timm.models.layers import trunc_normal_tf_
from timm.models.helpers import named_apply


def gcd(a, b):
    while b:
        a, b = b, a % b
    return a


# Other types of layers can go here (e.g., nn.Linear, etc.)
def _init_weights(module, name, scheme=''):
    if isinstance(module, (nn.Conv2d, nn.Conv3d)):
        if scheme == 'normal':
            nn.init.normal_(module.weight, std=.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        # elif scheme == 'trunc_normal':
        #     trunc_normal_tf_(module.weight, std=.02)
        #     if module.bias is not None:
        #         nn.init.zeros_(module.bias)
        elif scheme == 'xavier_normal':
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif scheme == 'kaiming_normal':
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        else:
            # efficientnet-like fan-out initialization
            fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
            fan_out //= module.groups
            nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out))
            if module.bias is not None:
                nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d)):
        nn.init.constant_(module.weight, 1)
        nn.init.constant_(module.bias, 0)
    elif isinstance(module, nn.LayerNorm):
        nn.init.constant_(module.weight, 1)
        nn.init.constant_(module.bias, 0)


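# Usage sketch (illustrative; the Conv2d sizes below are arbitrary assumptions,
# not part of the original module):
#
#   conv = nn.Conv2d(16, 32, kernel_size=3)
#   _init_weights(conv, name='conv', scheme='kaiming_normal')
#
# The blocks below apply it recursively over all submodules via
# named_apply(partial(_init_weights, scheme='normal'), self).

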
def act_layer(act, inplace=False, neg_slope=0.2, n_prelu=1):
    # activation layer
    act = act.lower()
    if act == 'relu':
        layer = nn.ReLU(inplace)
    elif act == 'relu6':
        layer = nn.ReLU6(inplace)
    elif act == 'leakyrelu':
        layer = nn.LeakyReLU(neg_slope, inplace)
    elif act == 'prelu':
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    elif act == 'gelu':
        layer = nn.GELU()
    elif act == 'hswish':
        layer = nn.Hardswish(inplace)
    else:
        raise NotImplementedError('activation layer [%s] is not found' % act)
    return layer


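# Usage sketch (illustrative): act_layer('relu6', inplace=True) returns
# nn.ReLU6(True); an unsupported name raises NotImplementedError.

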
def channel_shuffle(x, groups):
    batchsize, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups
    # reshape
    x = x.view(batchsize, groups,
               channels_per_group, height, width)
    x = torch.transpose(x, 1, 2).contiguous()
    # flatten
    x = x.view(batchsize, -1, height, width)
    return x


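# Usage sketch (illustrative; the tensor shape is an assumption): for
# x = torch.randn(2, 8, 32, 32), channel_shuffle(x, groups=4) keeps the shape
# (2, 8, 32, 32) but interleaves channels across the 4 groups, so information
# mixes between the group-wise convolution branches that follow.

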
# Multi-scale depth-wise convolution (MSDC)
class MSDC(nn.Module):
    def __init__(self, in_channels, kernel_sizes, stride, activation='relu6', dw_parallel=True):
        super(MSDC, self).__init__()

        self.in_channels = in_channels
        self.kernel_sizes = kernel_sizes
        self.activation = activation
        self.dw_parallel = dw_parallel

        self.dwconvs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(self.in_channels, self.in_channels, kernel_size, stride, kernel_size // 2,
                          groups=self.in_channels, bias=False),
                nn.BatchNorm2d(self.in_channels),
                act_layer(self.activation, inplace=True)
            )
            for kernel_size in self.kernel_sizes
        ])

        self.init_weights('normal')

    def init_weights(self, scheme=''):
        named_apply(partial(_init_weights, scheme=scheme), self)

    def forward(self, x):
        # Apply the convolution layers in a loop
        outputs = []
        for dwconv in self.dwconvs:
            dw_out = dwconv(x)
            outputs.append(dw_out)
            if not self.dw_parallel:
                x = x + dw_out
        # You can return outputs based on what you intend to do with them
        return outputs


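# Usage sketch (illustrative; shapes are assumptions): MSDC(64, [1, 3, 5], stride=1)
# applied to a (B, 64, H, W) tensor returns a list of three (B, 64, H, W) tensors,
# one per depth-wise kernel size. With dw_parallel=False the input to each later
# branch is updated to x + dw_out, making the branches sequential rather than parallel.

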
class MSCB(nn.Module):
    """
    Multi-scale convolution block (MSCB)
    """

    def __init__(self, in_channels, out_channels, stride, kernel_sizes=[1, 3, 5], expansion_factor=2, dw_parallel=True,
                 add=True, activation='relu6'):
        super(MSCB, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.kernel_sizes = kernel_sizes
        self.expansion_factor = expansion_factor
        self.dw_parallel = dw_parallel
        self.add = add
        self.activation = activation
        self.n_scales = len(self.kernel_sizes)
        # check stride value
        assert self.stride in [1, 2]
        # Skip connection if stride is 1
        self.use_skip_connection = (self.stride == 1)

        # expansion factor
        self.ex_channels = int(self.in_channels * self.expansion_factor)
        self.pconv1 = nn.Sequential(
            # pointwise convolution
            nn.Conv2d(self.in_channels, self.ex_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(self.ex_channels),
            act_layer(self.activation, inplace=True)
        )
        self.msdc = MSDC(self.ex_channels, self.kernel_sizes, self.stride, self.activation,
                         dw_parallel=self.dw_parallel)
        if self.add:
            self.combined_channels = self.ex_channels * 1
        else:
            self.combined_channels = self.ex_channels * self.n_scales
        self.pconv2 = nn.Sequential(
            # pointwise convolution
            nn.Conv2d(self.combined_channels, self.out_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(self.out_channels),
        )
        if self.use_skip_connection and (self.in_channels != self.out_channels):
            self.conv1x1 = nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0, bias=False)
        self.init_weights('normal')

    def init_weights(self, scheme=''):
        named_apply(partial(_init_weights, scheme=scheme), self)

    def forward(self, x):
        pout1 = self.pconv1(x)
        msdc_outs = self.msdc(pout1)
        if self.add:
            dout = 0
            for dwout in msdc_outs:
                dout = dout + dwout
        else:
            dout = torch.cat(msdc_outs, dim=1)
        dout = channel_shuffle(dout, gcd(self.combined_channels, self.out_channels))
        out = self.pconv2(dout)
        if self.use_skip_connection:
            if self.in_channels != self.out_channels:
                x = self.conv1x1(x)
            return x + out
        else:
            return out


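# Usage sketch (illustrative; channel counts are assumptions): with in_channels=64,
# out_channels=64, stride=1, expansion_factor=2, the block expands to 128 channels,
# runs the multi-scale depth-wise branches, sums them (add=True), shuffles with
# gcd(128, 64) = 64 groups, projects back to 64 channels, and adds the identity skip.

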
# Multi-scale convolution block (MSCB)
def MSCBLayer(in_channels, out_channels, n=1, stride=1, kernel_sizes=[1, 3, 5], expansion_factor=2, dw_parallel=True,
              add=True, activation='relu6'):
    """
    create a series of multi-scale convolution blocks.
    """
    convs = []
    mscb = MSCB(in_channels, out_channels, stride, kernel_sizes=kernel_sizes, expansion_factor=expansion_factor,
                dw_parallel=dw_parallel, add=add, activation=activation)
    convs.append(mscb)
    if n > 1:
        for i in range(1, n):
            mscb = MSCB(out_channels, out_channels, 1, kernel_sizes=kernel_sizes, expansion_factor=expansion_factor,
                        dw_parallel=dw_parallel, add=add, activation=activation)
            convs.append(mscb)
    conv = nn.Sequential(*convs)
    return conv


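# Usage sketch (illustrative; sizes are assumptions):
#
#   layer = MSCBLayer(64, 128, n=2, stride=2)
#   y = layer(torch.randn(1, 64, 56, 56))   # -> (1, 128, 28, 28)
#
# Only the first block applies the requested stride and channel change; the
# remaining n-1 blocks keep stride 1 and map out_channels -> out_channels.

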
# Efficient up-convolution block (EUCB)
class EUCB(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, activation='relu'):
        super(EUCB, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.up_dwc = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(self.in_channels, self.in_channels, kernel_size=kernel_size, stride=stride,
                      padding=kernel_size // 2, groups=self.in_channels, bias=False),
            nn.BatchNorm2d(self.in_channels),
            act_layer(activation, inplace=True)
        )
        self.pwc = nn.Sequential(
            nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, stride=1, padding=0, bias=True)
        )
        self.init_weights('normal')

    def init_weights(self, scheme=''):
        named_apply(partial(_init_weights, scheme=scheme), self)

    def forward(self, x):
        x = self.up_dwc(x)
        x = channel_shuffle(x, self.in_channels)
        x = self.pwc(x)
        return x


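# Usage sketch (illustrative; shapes are assumptions): EUCB(320, 128) maps a
# (B, 320, 14, 14) feature to (B, 128, 28, 28): 2x nearest-neighbor upsample,
# 3x3 depth-wise conv, channel shuffle, then a 1x1 point-wise projection.

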
# Large-kernel grouped attention gate (LGAG)
class LGAG(nn.Module):
    def __init__(self, F_g, F_l, F_int, kernel_size=3, groups=1, activation='relu'):
        super(LGAG, self).__init__()

        if kernel_size == 1:
            groups = 1
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=kernel_size, stride=1, padding=kernel_size // 2, groups=groups,
                      bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=kernel_size, stride=1, padding=kernel_size // 2, groups=groups,
                      bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.activation = act_layer(activation, inplace=True)

        self.init_weights('normal')

    def init_weights(self, scheme=''):
        named_apply(partial(_init_weights, scheme=scheme), self)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.activation(g1 + x1)
        psi = self.psi(psi)

        return x * psi


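# Usage sketch (illustrative; shapes are assumptions): with F_g=F_l=128 and
# F_int=64, lgag(g, x) expects g (the upsampled decoder feature) and x (the
# encoder skip) both shaped (B, 128, H, W), and returns x scaled by a
# per-pixel gate in [0, 1].

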
# Channel attention block (CAB)
class CAB(nn.Module):
    def __init__(self, in_channels, out_channels=None, ratio=16, activation='relu'):
        super(CAB, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        if self.in_channels < ratio:
            ratio = self.in_channels
        self.reduced_channels = self.in_channels // ratio
        if self.out_channels is None:
            self.out_channels = in_channels

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.activation = act_layer(activation, inplace=True)
        self.fc1 = nn.Conv2d(self.in_channels, self.reduced_channels, 1, bias=False)
        self.fc2 = nn.Conv2d(self.reduced_channels, self.out_channels, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

        self.init_weights('normal')

    def init_weights(self, scheme=''):
        named_apply(partial(_init_weights, scheme=scheme), self)

    def forward(self, x):
        avg_pool_out = self.avg_pool(x)
        avg_out = self.fc2(self.activation(self.fc1(avg_pool_out)))

        max_pool_out = self.max_pool(x)
        max_out = self.fc2(self.activation(self.fc1(max_pool_out)))

        out = avg_out + max_out
        return self.sigmoid(out)


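# Usage sketch (illustrative): cab = CAB(128); cab(x) returns a (B, 128, 1, 1)
# channel-attention map in [0, 1], which the decoder multiplies element-wise
# with x (see EMCAD.forward below).

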
# Spatial attention block (SAB)
class SAB(nn.Module):
    def __init__(self, kernel_size=7):
        super(SAB, self).__init__()

        assert kernel_size in (3, 7, 11), 'kernel must be 3 or 7 or 11'
        padding = kernel_size // 2

        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)

        self.sigmoid = nn.Sigmoid()

        self.init_weights('normal')

    def init_weights(self, scheme=''):
        named_apply(partial(_init_weights, scheme=scheme), self)

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv(x)
        return self.sigmoid(x)


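# Usage sketch (illustrative): sab = SAB(); sab(x) reduces x over channels
# (mean and max), convolves the resulting 2-channel map with a 7x7 kernel, and
# returns a (B, 1, H, W) spatial-attention map in [0, 1].

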
# Efficient multi-scale convolutional attention decoding (EMCAD)
class EMCAD(nn.Module):
    def __init__(self, channels=[512, 320, 128, 64], kernel_sizes=[1, 3, 5], expansion_factor=6, dw_parallel=True,
                 add=True, lgag_ks=3, activation='relu6'):
        super(EMCAD, self).__init__()
        eucb_ks = 3  # kernel size for eucb
        self.mscb4 = MSCBLayer(channels[0], channels[0], n=1, stride=1, kernel_sizes=kernel_sizes,
                               expansion_factor=expansion_factor, dw_parallel=dw_parallel, add=add,
                               activation=activation)

        self.eucb3 = EUCB(in_channels=channels[0], out_channels=channels[1], kernel_size=eucb_ks, stride=eucb_ks // 2)
        self.lgag3 = LGAG(F_g=channels[1], F_l=channels[1], F_int=channels[1] // 2, kernel_size=lgag_ks,
                          groups=channels[1] // 2)
        self.mscb3 = MSCBLayer(channels[1], channels[1], n=1, stride=1, kernel_sizes=kernel_sizes,
                               expansion_factor=expansion_factor, dw_parallel=dw_parallel, add=add,
                               activation=activation)

        self.eucb2 = EUCB(in_channels=channels[1], out_channels=channels[2], kernel_size=eucb_ks, stride=eucb_ks // 2)
        self.lgag2 = LGAG(F_g=channels[2], F_l=channels[2], F_int=channels[2] // 2, kernel_size=lgag_ks,
                          groups=channels[2] // 2)
        self.mscb2 = MSCBLayer(channels[2], channels[2], n=1, stride=1, kernel_sizes=kernel_sizes,
                               expansion_factor=expansion_factor, dw_parallel=dw_parallel, add=add,
                               activation=activation)

        self.eucb1 = EUCB(in_channels=channels[2], out_channels=channels[3], kernel_size=eucb_ks, stride=eucb_ks // 2)
        self.lgag1 = LGAG(F_g=channels[3], F_l=channels[3], F_int=int(channels[3] / 2), kernel_size=lgag_ks,
                          groups=int(channels[3] / 2))
        self.mscb1 = MSCBLayer(channels[3], channels[3], n=1, stride=1, kernel_sizes=kernel_sizes,
                               expansion_factor=expansion_factor, dw_parallel=dw_parallel, add=add,
                               activation=activation)

        self.cab4 = CAB(channels[0])
        self.cab3 = CAB(channels[1])
        self.cab2 = CAB(channels[2])
        self.cab1 = CAB(channels[3])

        self.sab = SAB()

    def forward(self, x, skips):
        # MSCAM4
        d4 = self.cab4(x) * x
        d4 = self.sab(d4) * d4
        d4 = self.mscb4(d4)

        # EUCB3
        d3 = self.eucb3(d4)

        # LGAG3
        x3 = self.lgag3(g=d3, x=skips[0])

        # Additive aggregation 3
        d3 = d3 + x3

        # MSCAM3
        d3 = self.cab3(d3) * d3
        d3 = self.sab(d3) * d3
        d3 = self.mscb3(d3)

        # EUCB2
        d2 = self.eucb2(d3)

        # LGAG2
        x2 = self.lgag2(g=d2, x=skips[1])

        # Additive aggregation 2
        d2 = d2 + x2

        # MSCAM2
        d2 = self.cab2(d2) * d2
        d2 = self.sab(d2) * d2
        d2 = self.mscb2(d2)

        # EUCB1
        d1 = self.eucb1(d2)

        # LGAG1
        x1 = self.lgag1(g=d1, x=skips[2])

        # Additive aggregation 1
        d1 = d1 + x1

        # MSCAM1
        d1 = self.cab1(d1) * d1
        d1 = self.sab(d1) * d1
        d1 = self.mscb1(d1)

        return [d4, d3, d2, d1]
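

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch, not part of the original module;
# requires torch and timm). The batch size, spatial sizes, and channel list
# below are assumptions chosen to mimic a typical 4-stage encoder with a
# 224x224 input; adjust them to match the backbone feeding this decoder.
if __name__ == '__main__':
    decoder = EMCAD(channels=[512, 320, 128, 64])
    x = torch.randn(1, 512, 7, 7)        # deepest encoder feature
    skips = [
        torch.randn(1, 320, 14, 14),     # skip for stage 3
        torch.randn(1, 128, 28, 28),     # skip for stage 2
        torch.randn(1, 64, 56, 56),      # skip for stage 1
    ]
    d4, d3, d2, d1 = decoder(x, skips)
    print([t.shape for t in (d4, d3, d2, d1)])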