# Listing metadata carried over from the original paste: 128 lines, 5.0 KiB, Python.

import torch
import torch.nn as nn
import torch.nn.functional as F
from .backbone import build_backbone
from .modules import TransformerDecoder, Transformer
from einops import rearrange
class token_encoder(nn.Module):
    """Pool a feature map into a short token sequence and refine it.

    A 1x1 convolution produces ``token_len`` spatial attention maps. Each map
    is softmax-normalised over all pixels and used as weights for a sum of the
    input features, yielding one ``in_chan``-wide token per map. A learned
    positional embedding is added and the tokens are passed through
    ``Transformer``.

    Args:
        in_chan: Channel count of the input feature map (also the token width).
        token_len: Number of tokens produced per sample.
        heads: Number of attention heads forwarded to ``Transformer``.
    """

    def __init__(self, in_chan=32, token_len=4, heads=8):
        super(token_encoder, self).__init__()
        self.token_len = token_len
        # 1x1 conv: one attention logit map per token.
        self.conv_a = nn.Conv2d(in_chan, token_len, kernel_size=1, padding=0)
        self.pos_embedding = nn.Parameter(torch.randn(1, token_len, in_chan))
        self.transformer = Transformer(dim=in_chan, depth=1, heads=heads, dim_head=64, mlp_dim=64, dropout=0)

    def forward(self, x):
        """Turn a ``(b, c, h, w)`` feature map into refined tokens.

        Args:
            x: 4-D feature map; ``c`` must equal the ``in_chan`` used at
                construction. It does not have to be contiguous.

        Returns:
            The output of ``self.transformer`` applied to a
            ``(b, token_len, c)`` token tensor.
        """
        b, c, _, _ = x.shape

        # ``reshape`` instead of ``view``: ``view`` raises when the spatial
        # dims cannot be merged without a copy (e.g. cropped or transposed
        # maps); ``reshape`` copies only in that case.
        attention = self.conv_a(x).reshape(b, self.token_len, -1)
        # Normalise each token's map over the flattened pixel axis.
        attention = torch.softmax(attention, dim=-1)
        pixels = x.reshape(b, c, -1)

        # tokens[b, l, c] = sum_n attention[b, l, n] * pixels[b, c, n]
        tokens = torch.einsum('bln,bcn->blc', attention, pixels)
        tokens = tokens + self.pos_embedding
        return self.transformer(tokens)
class token_decoder(nn.Module):
    """Feed token context back into a dense feature map.

    A learned per-pixel positional embedding is added to the feature map, the
    map is flattened to a pixel sequence, and ``TransformerDecoder`` is run
    with that sequence as its first argument and the tokens as its second.
    The result is folded back to ``(b, c, h, w)``.

    Args:
        in_chan: Channel count of the feature map.
        size: Side length of the positional embedding; inputs must broadcast
            against ``(1, in_chan, size, size)``.
        heads: Number of attention heads forwarded to ``TransformerDecoder``.
    """

    def __init__(self, in_chan=32, size=32, heads=8):
        super().__init__()
        self.pos_embedding_decoder = nn.Parameter(torch.randn(1, in_chan, size, size))
        # NOTE(review): ``dim_head=True`` and ``softmax=in_chan`` look like
        # they sit in each other's slots. They are kept exactly as-is because
        # changing them would alter the decoder's configuration - confirm
        # against ``TransformerDecoder`` before touching.
        self.transformer_decoder = TransformerDecoder(
            dim=in_chan,
            depth=1,
            heads=heads,
            dim_head=True,
            mlp_dim=in_chan * 2,
            dropout=0,
            softmax=in_chan,
        )

    def forward(self, x, m):
        """Decode feature map ``x`` using tokens ``m``.

        Args:
            x: ``(b, c, h, w)`` feature map.
            m: Tokens passed as the second argument of the decoder.

        Returns:
            Tensor rearranged back to ``(b, c, h, w)`` layout, using the
            height of the incoming ``x``.
        """
        # Height is taken from the input before the embedding is added.
        _, _, height, _ = x.shape
        pixel_seq = rearrange(x + self.pos_embedding_decoder, 'b c h w -> b (h w) c')
        decoded = self.transformer_decoder(pixel_seq, m)
        return rearrange(decoded, 'b (h w) c -> b c h w', h=height)
class context_aggregator(nn.Module):
    """Summarise a feature map into tokens and decode them back into it.

    The feature map is first compressed to four tokens by ``token_encoder``;
    ``token_decoder`` then receives the original feature map together with
    those tokens.

    Args:
        in_chan: Channel count of the feature map. It is used for both the
            encoder and the decoder so that the two always agree.
        size: Spatial side length passed to ``token_decoder``.
    """

    def __init__(self, in_chan=32, size=32):
        super(context_aggregator, self).__init__()
        self.token_encoder = token_encoder(in_chan=in_chan, token_len=4)
        # The decoder width was previously hard-coded to 32, which only
        # matched the encoder for the default ``in_chan``.
        self.token_decoder = token_decoder(in_chan=in_chan, size=size, heads=8)

    def forward(self, feature):
        """Return ``token_decoder(feature, token_encoder(feature))``."""
        token = self.token_encoder(feature)
        out = self.token_decoder(feature, token)
        return out
class Classifier(nn.Module):
    """Two-layer 3x3 convolutional head producing per-pixel class scores.

    The head expects ``2 * in_chan`` input channels, reduces them to
    ``in_chan`` (bias-free conv, batch norm, ReLU) and then maps to
    ``n_class`` output channels. Both convolutions use stride 1 and padding 1,
    so the spatial size is unchanged.

    Args:
        in_chan: Hidden width; the input must have twice this many channels.
        n_class: Number of output channels.
    """

    def __init__(self, in_chan=32, n_class=2):
        super().__init__()
        stages = [
            nn.Conv2d(in_chan * 2, in_chan, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(in_chan),
            nn.ReLU(),
            nn.Conv2d(in_chan, n_class, kernel_size=3, padding=1, stride=1),
        ]
        # Same layer order as before, so state_dict keys stay ``head.0`` .. ``head.3``.
        self.head = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the head to ``x`` and return the ``n_class``-channel result."""
        return self.head(x)
class CDNet(nn.Module):
    """Two-image network with context aggregation at three scales.

    Both images go through the same backbone and the same aggregators. The
    coarsest features are aggregated first; each finer scale is the backbone
    feature fused with the 2x-upsampled aggregated output of the scale
    below it. At every scale the two images' aggregated features are
    concatenated and classified at the input resolution.

    Args:
        backbone: Backbone name passed to ``build_backbone``.
        output_stride: Passed to ``build_backbone``.
        img_size: Input side length; the aggregators are sized for
            ``img_size // 16``, ``img_size // 8`` and ``img_size // 4``.
        img_chan: Passed to ``build_backbone``.
        chan_num: Feature width used by the aggregators, fusion convs and
            classifier heads.
            NOTE(review): assumes the backbone emits ``chan_num`` channels at
            every scale - confirm against ``build_backbone``.
        n_class: Number of output channels of each classifier head.
    """

    def __init__(self, backbone='resnet18', output_stride=16, img_size=512, img_chan=3, chan_num=32, n_class=2):
        super(CDNet, self).__init__()
        BatchNorm = nn.BatchNorm2d
        self.backbone = build_backbone(backbone, output_stride, BatchNorm, img_chan)

        self.CA_s16 = context_aggregator(in_chan=chan_num, size=img_size // 16)
        self.CA_s8 = context_aggregator(in_chan=chan_num, size=img_size // 8)
        self.CA_s4 = context_aggregator(in_chan=chan_num, size=img_size // 4)

        # Fuse (upsampled coarser output, backbone feature) back to chan_num channels.
        self.conv_s8 = nn.Conv2d(chan_num * 2, chan_num, kernel_size=3, padding=1)
        self.conv_s4 = nn.Conv2d(chan_num * 2, chan_num, kernel_size=3, padding=1)

        self.upsamplex2 = nn.Upsample(scale_factor=2, mode="bicubic", align_corners=True)

        # Heads follow ``chan_num``; they previously fell back to the default
        # width of 32 regardless of ``chan_num``.
        self.classifier1 = Classifier(in_chan=chan_num, n_class=n_class)
        self.classifier2 = Classifier(in_chan=chan_num, n_class=n_class)
        self.classifier3 = Classifier(in_chan=chan_num, n_class=n_class)

    def _predict(self, feat1, feat2, classifier, out_size):
        """Concatenate two feature maps on channels, resize to ``out_size``, classify."""
        fused = torch.cat([feat1, feat2], dim=1)
        fused = F.interpolate(fused, size=out_size, mode='bicubic', align_corners=True)
        return classifier(fused)

    def forward(self, img1, img2):
        """Run both images through the network.

        Args:
            img1: First image batch, ``(b, c, h, w)``.
            img2: Second image batch, processed by the same modules.

        Returns:
            Tuple ``(x, x8, x16)`` of classifier outputs from the finest to
            the coarsest aggregation scale, each resized to ``img1``'s
            spatial size before classification.
        """
        out_size = img1.shape[2:]

        # CNN backbone, feature extractor
        out1_s16, out1_s8, out1_s4 = self.backbone(img1)
        out2_s16, out2_s8, out2_s4 = self.backbone(img2)

        # context aggregate (scale 16, scale 8, scale 4)
        x1_s16 = self.CA_s16(out1_s16)
        x2_s16 = self.CA_s16(out2_s16)
        x16 = self._predict(x1_s16, x2_s16, self.classifier1, out_size)

        out1_s8 = self.conv_s8(torch.cat([self.upsamplex2(x1_s16), out1_s8], dim=1))
        out2_s8 = self.conv_s8(torch.cat([self.upsamplex2(x2_s16), out2_s8], dim=1))
        x1_s8 = self.CA_s8(out1_s8)
        x2_s8 = self.CA_s8(out2_s8)
        x8 = self._predict(x1_s8, x2_s8, self.classifier2, out_size)

        out1_s4 = self.conv_s4(torch.cat([self.upsamplex2(x1_s8), out1_s4], dim=1))
        out2_s4 = self.conv_s4(torch.cat([self.upsamplex2(x2_s8), out2_s4], dim=1))
        x1 = self.CA_s4(out1_s4)
        x2 = self.CA_s4(out2_s4)
        x = self._predict(x1, x2, self.classifier3, out_size)

        return x, x8, x16

    def freeze_bn(self):
        """Switch every ``nn.BatchNorm2d`` submodule to eval mode."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()