import torch
import torch.nn as nn
import torch.nn.functional as F

# NOTE: PVTv2 encoders (lib.pvtv2: pvt_v2_b0 ... pvt_v2_b5) are not wired up in
# this copy; only the ResNet backbones below are supported.
from mmseg.models.backbones.emcadblock.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
from mmseg.models.backbones.emcadblock.decoders import EMCAD
from mmseg.registry import MODELS

# Supported encoders: backbone constructor plus the channel widths of the four
# backbone feature maps, listed in the order (x4, x3, x2, x1) in which they are
# handed to the EMCAD decoder and to out_head4 .. out_head1.
_ENCODERS = {
    'resnet18': (resnet18, [512, 256, 128, 64]),
    'resnet34': (resnet34, [512, 256, 128, 64]),
    'resnet50': (resnet50, [2048, 1024, 512, 256]),
    'resnet101': (resnet101, [2048, 1024, 512, 256]),
    'resnet152': (resnet152, [2048, 1024, 512, 256]),
}


@MODELS.register_module()
class EMCADNet(nn.Module):
    """ResNet encoder + EMCAD decoder with four 1x1 prediction heads.

    The backbone returns four feature maps (x1..x4). The EMCAD decoder consumes
    them and returns four decoded maps. Each decoded map is projected to
    ``num_classes`` channels by its own 1x1 convolution and then bilinearly
    upsampled. The class is registered in the mmseg ``MODELS`` registry.

    Args:
        num_classes: Output channels of every prediction head.
        kernel_sizes: Forwarded to ``EMCAD``. ``None`` (the default) means
            ``[1, 3, 5]``; a fresh list is created per instance so the default
            is never shared between models.
        expansion_factor, dw_parallel, add, lgag_ks, activation: Forwarded
            unchanged to ``EMCAD``.
        encoder: One of ``'resnet18'``, ``'resnet34'``, ``'resnet50'``,
            ``'resnet101'``, ``'resnet152'``.
        pretrain: Forwarded as ``pretrained=`` to the ResNet constructor.
        pretrained_dir: Unused. It only served the disabled PVTv2 weight
            loading path and is kept so existing configs keep working.

    Raises:
        ValueError: If ``encoder`` is not one of the supported names.
    """

    def __init__(
        self,
        num_classes=1,
        kernel_sizes=None,
        expansion_factor=2,
        dw_parallel=True,
        add=True,
        lgag_ks=3,
        activation='relu',
        encoder='resnet50',
        pretrain=False,
        pretrained_dir='./pretrained_pth/pvt/',
    ):
        super().__init__()
        if kernel_sizes is None:
            kernel_sizes = [1, 3, 5]
        # Fail fast. Previously an unknown name printed a "continuing with
        # pvt_v2_b2" message but never assigned self.backbone, so construction
        # crashed later with an unrelated AttributeError.
        if encoder not in _ENCODERS:
            raise ValueError(
                f"Unsupported encoder {encoder!r}; expected one of {sorted(_ENCODERS)}."
            )

        # 1x1 conv block that lifts single-channel input to 3 channels
        # (applied in forward only when the input has exactly one channel).
        self.conv = nn.Sequential(
            nn.Conv2d(1, 3, kernel_size=1),
            nn.BatchNorm2d(3),
            nn.ReLU(inplace=True),
        )

        build_backbone, channels = _ENCODERS[encoder]
        channels = list(channels)  # per-instance copy of the shared table entry
        self.backbone = build_backbone(pretrained=pretrain)
        print('Model %s created, param count: %d' % (
            encoder + ' backbone: ',
            sum(p.numel() for p in self.backbone.parameters()),
        ))

        self.decoder = EMCAD(
            channels=channels,
            kernel_sizes=kernel_sizes,
            expansion_factor=expansion_factor,
            dw_parallel=dw_parallel,
            add=add,
            lgag_ks=lgag_ks,
            activation=activation,
        )
        print('Model %s created, param count: %d' % (
            'EMCAD decoder: ',
            sum(p.numel() for p in self.decoder.parameters()),
        ))

        # One 1x1 head per decoder output; head k reads channels[4 - k].
        self.out_head4 = nn.Conv2d(channels[0], num_classes, 1)
        self.out_head3 = nn.Conv2d(channels[1], num_classes, 1)
        self.out_head2 = nn.Conv2d(channels[2], num_classes, 1)
        self.out_head1 = nn.Conv2d(channels[3], num_classes, 1)

    def forward(self, x, mode='test'):
        """Run encoder, decoder and heads.

        Args:
            x: Input batch. A single-channel input (``x.size(1) == 1``) is first
                lifted to 3 channels by ``self.conv``.
            mode: Accepted for caller compatibility; it does not change the
                output.

        Returns:
            ``[p4, p3, p2, p1]``: head outputs upsampled by 32, 16, 8 and 4.
            This assumes the decoder outputs sit at those strides relative to
            the input (TODO confirm against the backbone).
        """
        if x.size(1) == 1:
            x = self.conv(x)

        x1, x2, x3, x4 = self.backbone(x)
        dec_outs = self.decoder(x4, [x3, x2, x1])

        p4 = self.out_head4(dec_outs[0])
        p3 = self.out_head3(dec_outs[1])
        p2 = self.out_head2(dec_outs[2])
        p1 = self.out_head1(dec_outs[3])

        # align_corners=False is the implicit default for bilinear; spelled out
        # so the numerics are explicit.
        p4 = F.interpolate(p4, scale_factor=32, mode='bilinear', align_corners=False)
        p3 = F.interpolate(p3, scale_factor=16, mode='bilinear', align_corners=False)
        p2 = F.interpolate(p2, scale_factor=8, mode='bilinear', align_corners=False)
        p1 = F.interpolate(p1, scale_factor=4, mode='bilinear', align_corners=False)

        return [p4, p3, p2, p1]


if __name__ == '__main__':
    # Smoke test; falls back to CPU when no GPU is available.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = EMCADNet().to(device)
    input_tensor = torch.randn(1, 3, 352, 352, device=device)
    P = model(input_tensor)
    print(P[0].size(), P[1].size(), P[2].size(), P[3].size())