129 lines
5.1 KiB
Python
129 lines
5.1 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
# from lib.pvtv2 import pvt_v2_b0, pvt_v2_b1, pvt_v2_b2, pvt_v2_b3, pvt_v2_b4, pvt_v2_b5
|
|
from mmseg.models.backbones.emcadblock.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
|
|
from mmseg.models.backbones.emcadblock.decoders import EMCAD
|
|
|
|
|
|
from mmseg.registry import MODELS
|
|
@MODELS.register_module()
class EMCADNet(nn.Module):
    """Segmentation network: ResNet backbone + EMCAD decoder + four 1x1 prediction heads.

    The backbone returns four feature maps ``(x1, x2, x3, x4)``.  The decoder is
    called as ``self.decoder(x4, [x3, x2, x1])``.  Each of its first four outputs
    is passed through its own 1x1 convolution head producing ``num_classes``
    channels.

    Args:
        num_classes: Output channels of every prediction head.
        kernel_sizes: Forwarded to ``EMCAD``.  A fresh list is passed so the
            mutable default value is never shared with (or mutated by) the decoder.
        expansion_factor, dw_parallel, add, lgag_ks, activation: Forwarded
            unchanged to ``EMCAD``.
        encoder: Backbone name; one of the keys of ``_ENCODERS``.
        pretrain: Forwarded as ``pretrained=`` to the ResNet constructor.
        pretrained_dir: Unused.  Kept for signature compatibility with the
            PVTv2 backbones, which are disabled in this file.

    Raises:
        ValueError: If ``encoder`` is not a supported backbone name.
    """

    # encoder name -> (constructor, feature widths ordered deepest stage first,
    # the order in which EMCAD and out_head4..out_head1 consume them).
    # PVTv2 variants (pvt_v2_b0..b5, widths (512, 320, 128, 64) or
    # (256, 160, 64, 32) for b0) are unavailable because the ``lib.pvtv2``
    # import is commented out.
    _ENCODERS = {
        'resnet18': (resnet18, (512, 256, 128, 64)),
        'resnet34': (resnet34, (512, 256, 128, 64)),
        'resnet50': (resnet50, (2048, 1024, 512, 256)),
        'resnet101': (resnet101, (2048, 1024, 512, 256)),
        'resnet152': (resnet152, (2048, 1024, 512, 256)),
    }

    def __init__(self, num_classes=1, kernel_sizes=[1, 3, 5], expansion_factor=2, dw_parallel=True, add=True,
                 lgag_ks=3, activation='relu', encoder='resnet50', pretrain=False,
                 pretrained_dir='./pretrained_pth/pvt/'):
        super().__init__()

        # 1x1 conv block that lifts a single-channel (grayscale) input to 3 channels.
        self.conv = nn.Sequential(
            nn.Conv2d(1, 3, kernel_size=1),
            nn.BatchNorm2d(3),
            nn.ReLU(inplace=True),
        )

        # Fail fast on an unknown encoder instead of silently continuing
        # without a backbone and crashing later with an AttributeError.
        try:
            build_backbone, widths = self._ENCODERS[encoder]
        except KeyError:
            raise ValueError(
                f"Unsupported encoder {encoder!r}; expected one of {sorted(self._ENCODERS)}"
            ) from None
        self.backbone = build_backbone(pretrained=pretrain)
        channels = list(widths)

        print('Model %s created, param count: %d' %
              (encoder + ' backbone: ', sum(p.numel() for p in self.backbone.parameters())))

        # Decoder initialization; copy kernel_sizes so the shared default list is never handed out.
        self.decoder = EMCAD(channels=channels, kernel_sizes=list(kernel_sizes), expansion_factor=expansion_factor,
                             dw_parallel=dw_parallel, add=add, lgag_ks=lgag_ks, activation=activation)

        print('Model %s created, param count: %d' %
              ('EMCAD decoder: ', sum(p.numel() for p in self.decoder.parameters())))

        # One 1x1 prediction head per decoder stage, deepest stage first.
        self.out_head4 = nn.Conv2d(channels[0], num_classes, 1)
        self.out_head3 = nn.Conv2d(channels[1], num_classes, 1)
        self.out_head2 = nn.Conv2d(channels[2], num_classes, 1)
        self.out_head1 = nn.Conv2d(channels[3], num_classes, 1)

    def forward(self, x, mode='test'):
        """Run the network and return four predictions at the input's spatial size.

        Args:
            x: Image batch, presumably shaped (N, C, H, W).  If ``x.shape[1] == 1``
                it is first lifted to 3 channels by ``self.conv``.
            mode: Accepted for API compatibility; the output does not depend on it.

        Returns:
            ``[p4, p3, p2, p1]``: outputs of ``out_head4``..``out_head1``, each
            bilinearly resized to the input's last two dimensions.
        """
        # Target size for every prediction.  Resizing to it (rather than by fixed
        # factors 32/16/8/4) keeps the outputs aligned with the input even when
        # H or W is not a multiple of 32.  For multiples of 32 (and backbone
        # strides 32/16/8/4) the result is identical to fixed-factor upsampling.
        out_size = x.shape[-2:]

        # If grayscale input, convert to 3 channels.
        if x.shape[1] == 1:
            x = self.conv(x)

        # Encoder.
        x1, x2, x3, x4 = self.backbone(x)

        # Decoder: deepest feature plus the remaining stages, deepest first.
        dec_outs = self.decoder(x4, [x3, x2, x1])

        # Prediction heads.
        logits = [
            self.out_head4(dec_outs[0]),
            self.out_head3(dec_outs[1]),
            self.out_head2(dec_outs[2]),
            self.out_head1(dec_outs[3]),
        ]

        return [F.interpolate(p, size=out_size, mode='bilinear', align_corners=False) for p in logits]
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build the default model, run one random 352x352 RGB image
    # through it and print the shapes of the four predictions.
    # Fall back to CPU so the script also runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = EMCADNet().to(device)
    input_tensor = torch.randn(1, 3, 352, 352, device=device)

    # Gradients are not needed for a shape check.
    with torch.no_grad():
        P = model(input_tensor)
    print(P[0].size(), P[1].size(), P[2].size(), P[3].size())
|
|
|