174 lines
4.3 KiB
Python

import torchvision
import torch.nn as nn
import torch
from torchvision import models
import torch.nn.functional as F
# Public names exported by `from <this module> import *`.
# NOTE(review): the resnext50_32x4d and resnet152 classes defined in this
# module are not listed, so a star-import will not expose them — confirm
# whether that omission is intentional.
__all__ = ["ResNet18", "ResNet50", "ResNet34"]
class ResNet18(nn.Module):
    """Multi-stage feature extractor built from torchvision's ``resnet18``.

    Only the stem (``conv1``, ``bn1``, ``relu``, ``maxpool``) and the four
    residual stages (``layer1`` .. ``layer4``) are copied from the torchvision
    model; every other submodule of that model is dropped. The submodule
    names are kept verbatim, so they double as the ``state_dict`` key prefixes.

    Args:
        pretrained: passed through unchanged as
            ``torchvision.models.resnet18(pretrained=pretrained)``.
    """

    # NOTE(review): presumably the channel width of layer4's output; it is
    # only declared here, not checked against the wrapped model.
    output_size = 512

    def __init__(self, pretrained=True):
        super().__init__()
        backbone = torchvision.models.resnet18(pretrained=pretrained)
        # Registration order matters for printing / state_dict ordering, so
        # keep stem first, then stages in depth order.
        kept = ("conv1", "bn1", "relu", "maxpool",
                "layer1", "layer2", "layer3", "layer4")
        for name in kept:
            self.add_module(name, getattr(backbone, name))

    def forward(self, x):
        """Run the stem, then each residual stage in sequence.

        Returns:
            A 4-tuple with the outputs of ``layer1``, ``layer2``, ``layer3``
            and ``layer4``, in that order.
        """
        h = self.conv1(x)
        h = self.bn1(h)
        h = self.relu(h)
        h = self.maxpool(h)
        features = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            h = stage(h)
            features.append(h)
        return tuple(features)
class ResNet34(nn.Module):
    """Multi-stage feature extractor built from torchvision's ``resnet34``.

    Keeps only the stem (``conv1``, ``bn1``, ``relu``, ``maxpool``) and the
    residual stages ``layer1`` .. ``layer4`` of the torchvision model, under
    their original names; anything else on that model is discarded.

    Args:
        pretrained: forwarded as
            ``torchvision.models.resnet34(pretrained=pretrained)``.
    """

    # NOTE(review): presumably layer4's output channel count; declared only.
    output_size = 512

    def __init__(self, pretrained=True):
        super().__init__()
        source = torchvision.models.resnet34(pretrained=pretrained)
        stem = ("conv1", "bn1", "relu", "maxpool")
        stages = ("layer1", "layer2", "layer3", "layer4")
        # Same names and order as the torchvision model -> same state_dict keys.
        for attr in stem + stages:
            self.add_module(attr, getattr(source, attr))

    def forward(self, x):
        """Return the outputs of ``layer1`` .. ``layer4`` as a 4-tuple."""
        out = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        collected = []
        for layer in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = layer(out)
            collected.append(out)
        return tuple(collected)
class ResNet50(nn.Module):
    """Multi-stage feature extractor built from torchvision's ``resnet50``.

    Only the stem (``conv1``, ``bn1``, ``relu``, ``maxpool``) and the four
    residual stages are copied across, under their torchvision names.

    Args:
        pretrained: forwarded as
            ``torchvision.models.resnet50(pretrained=pretrained)``. Note the
            default here is ``False``, unlike the ResNet18/ResNet34 wrappers.
    """

    # NOTE(review): presumably layer4's output channel count; declared only.
    output_size = 2048

    def __init__(self, pretrained=False):
        super().__init__()
        net = torchvision.models.resnet50(pretrained=pretrained)
        wanted = ["conv1", "bn1", "relu", "maxpool"]
        wanted += ["layer1", "layer2", "layer3", "layer4"]
        for key in wanted:
            self.add_module(key, getattr(net, key))

    def forward(self, x):
        """Run the stem and residual stages.

        Returns:
            ``(b1, b2, b3, b4)`` — the outputs of ``layer1`` .. ``layer4``.
        """
        activated = self.relu(self.bn1(self.conv1(x)))
        current = self.maxpool(activated)
        outputs = []
        for block in (self.layer1, self.layer2, self.layer3, self.layer4):
            current = block(current)
            outputs.append(current)
        return tuple(outputs)
class resnext50_32x4d(nn.Module):
    """Backbone wrapper around torchvision's ``resnext50_32x4d``.

    Copies the stem, the four residual stages and ``avgpool`` from the
    torchvision model (under the same names); other submodules are dropped.

    Args:
        pretrained: forwarded as
            ``torchvision.models.resnext50_32x4d(pretrained=pretrained)``.
    """

    # NOTE(review): presumably layer4's output channel count; declared only.
    output_size = 2048

    def __init__(self, pretrained=True):
        super().__init__()
        model = torchvision.models.resnext50_32x4d(pretrained=pretrained)
        names = (
            "conv1", "bn1", "relu", "maxpool",
            "layer1", "layer2", "layer3", "layer4",
            "avgpool",
        )
        for n in names:
            self.add_module(n, getattr(model, n))

    def forward(self, x, get_ha=False):
        """Run the backbone and pool the last stage.

        Args:
            x: input passed to ``conv1``.
            get_ha: when true, also return every stage output.

        Returns:
            ``avgpool(layer4_output)`` by default, or the 5-tuple
            ``(b1, b2, b3, b4, pool)`` when ``get_ha`` is true.
        """
        feat = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        per_stage = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feat = stage(feat)
            per_stage.append(feat)
        pooled = self.avgpool(feat)
        return (*per_stage, pooled) if get_ha else pooled
class resnet152(nn.Module):
    """Backbone wrapper around torchvision's ``resnet152``.

    Keeps the stem, ``layer1`` .. ``layer4`` and ``avgpool`` of the
    torchvision model under their original names; everything else is dropped.

    Args:
        pretrained: forwarded as
            ``torchvision.models.resnet152(pretrained=pretrained)``.
    """

    # NOTE(review): presumably layer4's output channel count; declared only.
    output_size = 2048

    def __init__(self, pretrained=True):
        super().__init__()
        tv_model = torchvision.models.resnet152(pretrained=pretrained)
        parts = ["conv1", "bn1", "relu", "maxpool"]
        parts.extend(["layer1", "layer2", "layer3", "layer4"])
        parts.append("avgpool")
        for part in parts:
            self.add_module(part, getattr(tv_model, part))

    def forward(self, x, get_ha=False):
        """Run the backbone and pool the final stage.

        Args:
            x: input passed to ``conv1``.
            get_ha: when true, also return the output of each stage.

        Returns:
            The pooled ``layer4`` output, or ``(b1, b2, b3, b4, pool)`` when
            ``get_ha`` is true.
        """
        h = self.conv1(x)
        h = self.maxpool(self.relu(self.bn1(h)))
        stage_outputs = []
        for layer in (self.layer1, self.layer2, self.layer3, self.layer4):
            h = layer(h)
            stage_outputs.append(h)
        pool = self.avgpool(h)
        if not get_ha:
            return pool
        return tuple(stage_outputs) + (pool,)
if __name__ == "__main__":
    # Smoke test: build ResNet50, print its structure and per-stage output
    # shapes, then report FLOPs/params via thop. thop is only needed here,
    # so it is imported lazily.
    from thop import profile

    # torch.autograd.Variable is a deprecated no-op wrapper since PyTorch 0.4;
    # a plain tensor behaves identically.
    x = torch.randn(1, 3, 512, 512)
    net = ResNet50()
    # Inference-only demo: use running BatchNorm stats instead of updating them.
    net.eval()
    print(net)
    # No gradients are needed, so skip building the autograd graph.
    with torch.no_grad():
        out = net(x)
    print(out[0].shape, out[1].shape, out[2].shape, out[3].shape)
    flops, params = profile(net, (x,))
    print('flops: ', flops, 'params: ', params)