# File-listing metadata captured with this source: 423 lines, 15 KiB, Python

import torch
import torch.nn as nn
from functools import partial
import math
# from timm.models.layers import trunc_normal_tf_
from timm.models.helpers import named_apply
def gcd(a, b):
    """Return the greatest common divisor of the integers *a* and *b*.

    Delegates to :func:`math.gcd`, which always returns a non-negative value
    (``gcd(4, -6) == 2``; a hand-rolled Euclid loop built on Python's ``%``
    would return ``-2``). ``gcd(x, 0) == abs(x)`` and ``gcd(0, 0) == 0``.
    """
    return math.gcd(a, b)
def _init_weights(module, name, scheme=''):
    """Initialise the parameters of a single module in place.

    Meant to be applied to every submodule via ``named_apply``; ``name`` is
    accepted to fit that callback signature but is not used.

    Args:
        module: module to initialise. Only ``nn.Conv2d``/``nn.Conv3d``,
            ``nn.BatchNorm2d``/``nn.BatchNorm3d`` and ``nn.LayerNorm`` are
            touched; any other module is left unchanged.
        name: qualified name of ``module`` (unused).
        scheme: weight-init scheme for convolutions:
            ``'normal'`` -> N(0, 0.02);
            ``'xavier_normal'`` -> Xavier/Glorot normal;
            ``'kaiming_normal'`` -> Kaiming normal (fan_out, ReLU gain);
            anything else -> EfficientNet-style N(0, sqrt(2 / fan_out)).
            Convolution biases are always zeroed. (A ``'trunc_normal'`` scheme
            based on timm's ``trunc_normal_tf_`` is currently disabled.)
    """
    if isinstance(module, (nn.Conv2d, nn.Conv3d)):
        if scheme == 'normal':
            nn.init.normal_(module.weight, std=.02)
        elif scheme == 'xavier_normal':
            nn.init.xavier_normal_(module.weight)
        elif scheme == 'kaiming_normal':
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
        else:
            # EfficientNet-like. fan_out must include *every* kernel dimension,
            # otherwise Conv3d (three kernel dims) gets too large a variance.
            fan_out = math.prod(module.kernel_size) * module.out_channels
            fan_out //= module.groups
            nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out))
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.LayerNorm)):
        # Norm layers built without affine parameters have weight/bias == None;
        # skip those instead of crashing.
        if module.weight is not None:
            nn.init.constant_(module.weight, 1)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)
    # Other types of layers (e.g. nn.Linear) could be handled here.
def act_layer(act, inplace=False, neg_slope=0.2, n_prelu=1):
    """Build an activation module from its (case-insensitive) name.

    Args:
        act: one of 'relu', 'relu6', 'leakyrelu', 'prelu', 'gelu', 'hswish'
            (any capitalisation).
        inplace: forwarded to activations that support in-place operation
            (ReLU, ReLU6, LeakyReLU, Hardswish).
        neg_slope: negative slope for LeakyReLU; initial slope for PReLU.
        n_prelu: number of learnable PReLU slopes.

    Returns:
        A freshly constructed ``nn.Module``.

    Raises:
        NotImplementedError: if the name is not recognised.
    """
    key = act.lower()
    factories = {
        'relu': lambda: nn.ReLU(inplace),
        'relu6': lambda: nn.ReLU6(inplace),
        'leakyrelu': lambda: nn.LeakyReLU(neg_slope, inplace),
        'prelu': lambda: nn.PReLU(num_parameters=n_prelu, init=neg_slope),
        'gelu': lambda: nn.GELU(),
        'hswish': lambda: nn.Hardswish(inplace),
    }
    factory = factories.get(key)
    if factory is None:
        raise NotImplementedError('activation layer [%s] is not found' % key)
    return factory()
def channel_shuffle(x, groups):
    """Interleave the channels of *x* across ``groups`` groups (ShuffleNet-style).

    The ``C`` channels are viewed as a ``(groups, C // groups)`` grid and
    transposed. With ``cpg = C // groups``, the new channel order is
    ``0, cpg, 2*cpg, ..., 1, cpg + 1, ...``. With ``groups == 1`` or
    ``groups == C`` the channel order is unchanged.

    Args:
        x: 4-D tensor of shape ``(batch, channels, height, width)``.
        groups: number of groups; must be a positive divisor of ``channels``.

    Returns:
        A contiguous tensor with the same shape as ``x``.

    Raises:
        ValueError: if ``groups`` is not a positive divisor of the channel count.
    """
    batchsize, num_channels, height, width = x.size()
    if groups <= 0 or num_channels % groups != 0:
        raise ValueError(
            f'groups ({groups}) must be a positive divisor of the channel count ({num_channels})'
        )
    channels_per_group = num_channels // groups
    # (B, C, H, W) -> (B, g, C/g, H, W) -> swap group/channel axes -> flatten back.
    # reshape (not view) so non-contiguous inputs are accepted as well.
    x = x.reshape(batchsize, groups, channels_per_group, height, width)
    x = x.transpose(1, 2).contiguous()
    return x.view(batchsize, -1, height, width)
# Multi-scale depth-wise convolution (MSDC)
class MSDC(nn.Module):
    """Multi-scale depth-wise convolution (MSDC).

    Builds one depth-wise ``Conv2d -> BatchNorm2d -> activation`` branch per
    entry of ``kernel_sizes`` and returns the list of branch outputs.

    Args:
        in_channels: input (and output) channels of every branch.
        kernel_sizes: iterable of kernel sizes, one branch each; each branch
            uses padding ``k // 2``.
        stride: stride shared by all branches.
        activation: activation name understood by ``act_layer``.
        dw_parallel: if true, every branch receives the same input. If false,
            branches run sequentially and each branch output is added to the
            input of the next branch (``x = x + branch(x)``). That requires the
            branch output to have the same shape as ``x`` (e.g. stride 1 with
            odd kernel sizes).
    """

    def __init__(self, in_channels, kernel_sizes, stride, activation='relu6', dw_parallel=True):
        super().__init__()
        self.in_channels = in_channels
        self.kernel_sizes = kernel_sizes
        self.activation = activation
        self.dw_parallel = dw_parallel
        self.dwconvs = nn.ModuleList([
            nn.Sequential(
                # groups == in_channels -> depth-wise convolution
                nn.Conv2d(self.in_channels, self.in_channels, kernel_size, stride, kernel_size // 2,
                          groups=self.in_channels, bias=False),
                nn.BatchNorm2d(self.in_channels),
                act_layer(self.activation, inplace=True),
            )
            for kernel_size in self.kernel_sizes
        ])
        self.init_weights('normal')

    def init_weights(self, scheme=''):
        """Re-initialise every submodule with ``_init_weights`` using *scheme*."""
        named_apply(partial(_init_weights, scheme=scheme), self)

    def forward(self, x):
        """Return the list of branch outputs, in ``kernel_sizes`` order."""
        outputs = []
        for dwconv in self.dwconvs:
            dw_out = dwconv(x)
            outputs.append(dw_out)
            if not self.dw_parallel:
                # Sequential mode: the next branch sees the running residual sum.
                x = x + dw_out
        return outputs
class MSCB(nn.Module):
    """
    Multi-scale convolution block (MSCB).

    Pipeline:
    1. 1x1 expansion (``pconv1``).
    2. Multi-scale depth-wise convolutions (``msdc``).
    3. Combine the branches: sum them, or concatenate them.
    4. Channel shuffle.
    5. 1x1 projection (``pconv2``).

    With ``stride == 1`` the input is added to the result. When the channel
    counts differ, the residual goes through a bias-free 1x1 conv first.

    Args:
        in_channels: input channels.
        out_channels: output channels.
        stride: 1 or 2 (depth-wise convolution stride).
        kernel_sizes: depth-wise kernel sizes, one MSDC branch each.
        expansion_factor: width multiplier for the expanded (hidden) channels.
        dw_parallel: forwarded to ``MSDC``.
        add: if true, sum the branch outputs; otherwise concatenate them.
        activation: activation name understood by ``act_layer``.

    Raises:
        ValueError: if ``stride`` is not 1 or 2.
    """

    def __init__(self, in_channels, out_channels, stride, kernel_sizes=(1, 3, 5), expansion_factor=2,
                 dw_parallel=True, add=True, activation='relu6'):
        super().__init__()
        # Explicit check: an `assert` would be stripped under `python -O`.
        if stride not in (1, 2):
            raise ValueError(f'stride must be 1 or 2, got {stride!r}')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.kernel_sizes = kernel_sizes
        self.expansion_factor = expansion_factor
        self.dw_parallel = dw_parallel
        self.add = add
        self.activation = activation
        self.n_scales = len(self.kernel_sizes)
        # Residual connection only when the depth-wise stage keeps the stride at 1.
        self.use_skip_connection = self.stride == 1
        # Expanded (hidden) width.
        self.ex_channels = int(self.in_channels * self.expansion_factor)
        self.pconv1 = nn.Sequential(
            # pointwise expansion
            nn.Conv2d(self.in_channels, self.ex_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(self.ex_channels),
            act_layer(self.activation, inplace=True),
        )
        self.msdc = MSDC(self.ex_channels, self.kernel_sizes, self.stride, self.activation,
                         dw_parallel=self.dw_parallel)
        # Summing keeps the expanded width; concatenation stacks one copy per scale.
        self.combined_channels = self.ex_channels * (1 if self.add else self.n_scales)
        self.pconv2 = nn.Sequential(
            # pointwise projection (no activation)
            nn.Conv2d(self.combined_channels, self.out_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(self.out_channels),
        )
        if self.use_skip_connection and self.in_channels != self.out_channels:
            # Match channel counts on the residual path.
            self.conv1x1 = nn.Conv2d(self.in_channels, self.out_channels, 1, 1, 0, bias=False)
        self.init_weights('normal')

    def init_weights(self, scheme=''):
        """Re-initialise every submodule with ``_init_weights`` using *scheme*."""
        named_apply(partial(_init_weights, scheme=scheme), self)

    def forward(self, x):
        pout1 = self.pconv1(x)
        msdc_outs = self.msdc(pout1)
        if self.add:
            # element-wise sum of the branch outputs
            dout = sum(msdc_outs)
        else:
            dout = torch.cat(msdc_outs, dim=1)
        # The group count divides combined_channels, so the shuffle is always valid.
        dout = channel_shuffle(dout, gcd(self.combined_channels, self.out_channels))
        out = self.pconv2(dout)
        if self.use_skip_connection:
            if self.in_channels != self.out_channels:
                x = self.conv1x1(x)
            return x + out
        return out
# Stack of multi-scale convolution blocks (MSCB)
def MSCBLayer(in_channels, out_channels, n=1, stride=1, kernel_sizes=(1, 3, 5), expansion_factor=2,
              dw_parallel=True, add=True, activation='relu6'):
    """
    Create a series of multi-scale convolution blocks.

    The first block maps ``in_channels -> out_channels`` with the given
    ``stride``. The remaining ``n - 1`` blocks map ``out_channels ->
    out_channels`` with stride 1. At least one block is always created, so
    ``n <= 1`` yields a single block.

    Returns:
        ``nn.Sequential`` of ``MSCB`` modules.
    """
    block_kwargs = dict(kernel_sizes=kernel_sizes, expansion_factor=expansion_factor,
                        dw_parallel=dw_parallel, add=add, activation=activation)
    convs = [MSCB(in_channels, out_channels, stride, **block_kwargs)]
    convs.extend(MSCB(out_channels, out_channels, 1, **block_kwargs) for _ in range(1, n))
    return nn.Sequential(*convs)
# Efficient up-convolution block (EUCB)
class EUCB(nn.Module):
    """Efficient up-convolution block (EUCB).

    Pipeline:
    1. ``nn.Upsample(scale_factor=2)``.
    2. Depth-wise ``kernel_size`` conv.
    3. BatchNorm and activation.
    4. Channel shuffle.
    5. 1x1 conv (with bias) from ``in_channels`` to ``out_channels``.

    Args:
        in_channels: channels of the input, kept through the depth-wise stage.
        out_channels: channels produced by the final 1x1 convolution.
        kernel_size: depth-wise kernel size (padding ``kernel_size // 2``).
        stride: depth-wise convolution stride.
        activation: activation name understood by ``act_layer``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, activation='relu'):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        pad = kernel_size // 2
        upsample = nn.Upsample(scale_factor=2)
        depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride,
                              padding=pad, groups=in_channels, bias=False)
        norm = nn.BatchNorm2d(in_channels)
        act = act_layer(activation, inplace=True)
        self.up_dwc = nn.Sequential(upsample, depthwise, norm, act)
        pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True)
        self.pwc = nn.Sequential(pointwise)
        self.init_weights('normal')

    def init_weights(self, scheme=''):
        """Re-initialise every submodule with ``_init_weights`` using *scheme*."""
        named_apply(partial(_init_weights, scheme=scheme), self)

    def forward(self, x):
        upsampled = self.up_dwc(x)
        # groups == in_channels means one channel per group, so this shuffle
        # leaves the channel order unchanged.
        shuffled = channel_shuffle(upsampled, self.in_channels)
        return self.pwc(shuffled)
# Large-kernel grouped attention gate (LGAG)
class LGAG(nn.Module):
    """Large-kernel grouped attention gate (LGAG).

    Steps:
    1. Project the gating signal ``g`` and the features ``x`` to ``F_int``
       channels, each with a grouped ``kernel_size`` conv followed by BN.
    2. Sum the two projections and apply the activation.
    3. Reduce to a single-channel sigmoid map ``psi`` (1x1 conv, BN, sigmoid).

    Returns ``x * psi``, with ``psi`` broadcast over the channels of ``x``.

    Args:
        F_g: channels of ``g``.
        F_l: channels of ``x``.
        F_int: intermediate channels.
        kernel_size: kernel size of the two projections. When it is 1,
            ``groups`` is forced to 1.
        groups: conv groups; must divide ``F_g``, ``F_l`` and ``F_int``.
        activation: activation name understood by ``act_layer``.
    """

    def __init__(self, F_g, F_l, F_int, kernel_size=3, groups=1, activation='relu'):
        super().__init__()
        conv_groups = 1 if kernel_size == 1 else groups
        pad = kernel_size // 2
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=kernel_size, stride=1, padding=pad, groups=conv_groups,
                      bias=True),
            nn.BatchNorm2d(F_int),
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=kernel_size, stride=1, padding=pad, groups=conv_groups,
                      bias=True),
            nn.BatchNorm2d(F_int),
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.activation = act_layer(activation, inplace=True)
        self.init_weights('normal')

    def init_weights(self, scheme=''):
        """Re-initialise every submodule with ``_init_weights`` using *scheme*."""
        named_apply(partial(_init_weights, scheme=scheme), self)

    def forward(self, g, x):
        gate = self.activation(self.W_g(g) + self.W_x(x))
        return x * self.psi(gate)
# Channel attention block (CAB)
class CAB(nn.Module):
    """Channel attention block (CAB).

    Computes ``sigmoid(MLP(avgpool(x)) + MLP(maxpool(x)))``. The shared MLP is
    two bias-free 1x1 convs with a bottleneck of ``in_channels // ratio``
    channels. For a ``(B, C, H, W)`` input the result is the attention map of
    shape ``(B, out_channels, 1, 1)``, not the re-weighted input.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the attention map; defaults to ``in_channels``.
        ratio: bottleneck reduction ratio, capped at ``in_channels``.
        activation: activation name understood by ``act_layer``.
    """

    def __init__(self, in_channels, out_channels=None, ratio=16, activation='relu'):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        # Cap the ratio so small inputs still get a non-empty bottleneck.
        ratio = min(ratio, self.in_channels)
        self.reduced_channels = self.in_channels // ratio
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.activation = act_layer(activation, inplace=True)
        self.fc1 = nn.Conv2d(self.in_channels, self.reduced_channels, 1, bias=False)
        self.fc2 = nn.Conv2d(self.reduced_channels, self.out_channels, 1, bias=False)
        self.sigmoid = nn.Sigmoid()
        self.init_weights('normal')

    def init_weights(self, scheme=''):
        """Re-initialise every submodule with ``_init_weights`` using *scheme*."""
        named_apply(partial(_init_weights, scheme=scheme), self)

    def _shared_mlp(self, pooled):
        """Apply the shared bottleneck MLP (fc1 -> activation -> fc2)."""
        return self.fc2(self.activation(self.fc1(pooled)))

    def forward(self, x):
        avg_out = self._shared_mlp(self.avg_pool(x))
        max_out = self._shared_mlp(self.max_pool(x))
        return self.sigmoid(avg_out + max_out)
# Spatial attention block (SAB)
class SAB(nn.Module):
    """Spatial attention block (SAB).

    Steps:
    1. Pool the input across channels into a mean map and a max map.
    2. Concatenate the two maps.
    3. Apply a single bias-free ``kernel_size`` conv, then a sigmoid.

    For a ``(B, C, H, W)`` input the result is the attention map of shape
    ``(B, 1, H, W)``, not the re-weighted input.

    Args:
        kernel_size: 3, 7 or 11. Padding ``kernel_size // 2`` keeps ``H, W``.

    Raises:
        ValueError: for any other kernel size.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        # Explicit check: an `assert` would be stripped under `python -O`.
        if kernel_size not in (3, 7, 11):
            raise ValueError('kernel must be 3 or 7 or 11')
        padding = kernel_size // 2
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()
        self.init_weights('normal')

    def init_weights(self, scheme=''):
        """Re-initialise every submodule with ``_init_weights`` using *scheme*."""
        named_apply(partial(_init_weights, scheme=scheme), self)

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        pooled = torch.cat([avg_out, max_out], dim=1)
        return self.sigmoid(self.conv(pooled))
# Efficient multi-scale convolutional attention decoding (EMCAD)
class EMCAD(nn.Module):
    """Efficient multi-scale convolutional attention decoder (EMCAD).

    Four-stage decoder. ``channels`` lists the stage widths from the deepest
    to the shallowest; ``channels[0]`` is the width of the input ``x``.

    Every stage ends with an MSCAM: CAB re-weighting, then SAB re-weighting,
    then an MSCB layer. A single SAB instance is shared by all four stages.
    Stages 3, 2 and 1 additionally start with:
    1. An EUCB that upsamples by 2x and changes the width to ``channels[i]``.
    2. An LGAG-gated skip connection, added to the upsampled features.

    Args:
        channels: sequence of (at least) four channel widths.
        kernel_sizes: forwarded to every MSCB layer.
        expansion_factor: forwarded to every MSCB layer.
        dw_parallel: forwarded to every MSCB layer.
        add: forwarded to every MSCB layer.
        activation: forwarded to every MSCB layer.
        lgag_ks: kernel size of the LGAG convolutions.
    """

    def __init__(self, channels=(512, 320, 128, 64), kernel_sizes=(1, 3, 5), expansion_factor=6,
                 dw_parallel=True, add=True, lgag_ks=3, activation='relu6'):
        super().__init__()
        eucb_ks = 3  # kernel size for EUCB; its stride eucb_ks // 2 equals 1
        mscb_kwargs = dict(n=1, stride=1, kernel_sizes=kernel_sizes, expansion_factor=expansion_factor,
                           dw_parallel=dw_parallel, add=add, activation=activation)
        # Construction order is kept stage by stage so parameter names and the
        # order of random initialisation stay stable.
        self.mscb4 = MSCBLayer(channels[0], channels[0], **mscb_kwargs)
        self.eucb3 = EUCB(in_channels=channels[0], out_channels=channels[1], kernel_size=eucb_ks,
                          stride=eucb_ks // 2)
        self.lgag3 = LGAG(F_g=channels[1], F_l=channels[1], F_int=channels[1] // 2, kernel_size=lgag_ks,
                          groups=channels[1] // 2)
        self.mscb3 = MSCBLayer(channels[1], channels[1], **mscb_kwargs)
        self.eucb2 = EUCB(in_channels=channels[1], out_channels=channels[2], kernel_size=eucb_ks,
                          stride=eucb_ks // 2)
        self.lgag2 = LGAG(F_g=channels[2], F_l=channels[2], F_int=channels[2] // 2, kernel_size=lgag_ks,
                          groups=channels[2] // 2)
        self.mscb2 = MSCBLayer(channels[2], channels[2], **mscb_kwargs)
        self.eucb1 = EUCB(in_channels=channels[2], out_channels=channels[3], kernel_size=eucb_ks,
                          stride=eucb_ks // 2)
        self.lgag1 = LGAG(F_g=channels[3], F_l=channels[3], F_int=channels[3] // 2, kernel_size=lgag_ks,
                          groups=channels[3] // 2)
        self.mscb1 = MSCBLayer(channels[3], channels[3], **mscb_kwargs)
        self.cab4 = CAB(channels[0])
        self.cab3 = CAB(channels[1])
        self.cab2 = CAB(channels[2])
        self.cab1 = CAB(channels[3])
        self.sab = SAB()

    def _mscam(self, d, cab, mscb):
        """Apply the multi-scale convolutional attention module (CAB, SAB, MSCB)."""
        d = cab(d) * d
        d = self.sab(d) * d
        return mscb(d)

    def _up_stage(self, d, skip, eucb, lgag, cab, mscb):
        """Upsample *d*, add the LGAG-gated *skip*, then apply MSCAM."""
        d = eucb(d)
        d = d + lgag(g=d, x=skip)  # additive aggregation
        return self._mscam(d, cab, mscb)

    def forward(self, x, skips):
        """Decode ``x`` using the skip features.

        Args:
            x: deepest feature map, with ``channels[0]`` channels.
            skips: sequence of at least three feature maps. ``skips[i]`` must
                have ``channels[i + 1]`` channels and the spatial size of the
                upsampled features at that stage.

        Returns:
            ``[d4, d3, d2, d1]``: the output of every stage, deepest first.
        """
        d4 = self._mscam(x, self.cab4, self.mscb4)
        d3 = self._up_stage(d4, skips[0], self.eucb3, self.lgag3, self.cab3, self.mscb3)
        d2 = self._up_stage(d3, skips[1], self.eucb2, self.lgag2, self.cab2, self.mscb2)
        d1 = self._up_stage(d2, skips[2], self.eucb1, self.lgag1, self.cab1, self.mscb1)
        return [d4, d3, d2, d1]