340 lines
12 KiB
Python
340 lines
12 KiB
Python
|
||
import numpy as np
|
||
import torch
|
||
import torch.nn as nn
|
||
|
||
|
||
# original size to 4x downsampling layer
|
||
class SRFD(nn.Module):
    """Shallow feature downsampling stem: full resolution -> 1/4 resolution.

    The stem works in two stages, each of which halves the spatial size by
    running parallel downsampling branches and fusing them with a 1x1 conv:

    * stage 1 (1x -> 1/2): a strided depthwise conv branch ("ConvD") and a
      pixel-rearrangement branch (``Cut``, "CutD");
    * stage 2 (1/2 -> 1/4): ConvD, CutD and a max-pooling branch ("MaxD").

    In the shape comments below ``C`` is ``out_channels``. The stated spatial
    sizes follow from the stride/padding settings and assume H and W are
    divisible by 4.

    Args:
        in_channels: number of channels of the input image.
        out_channels: number of channels ``C`` of the 1/4-resolution output.
            The grouped convolutions and the fusion layers need
            ``C // 4 >= 1``, ``C`` even and ``C // 2`` divisible by
            ``C // 4`` (any positive multiple of 4 qualifies).

    Raises:
        ValueError: if ``out_channels`` cannot satisfy the constraints above.
    """

    def __init__(self, in_channels=3, out_channels=96):
        super().__init__()
        out_c14 = out_channels // 4  # C/4
        out_c12 = out_channels // 2  # C/2

        # Fail fast: without this check an unsupported width is either
        # rejected by a grouped conv with an unrelated message, or builds
        # fine and only crashes in forward() with a channel mismatch at
        # ``fusion1`` (which expects exactly 2 * (C // 2) == C channels).
        if out_c14 < 1 or out_channels != 2 * out_c12 or out_c12 % out_c14 != 0:
            raise ValueError(
                f"SRFD: unsupported out_channels={out_channels}; "
                "use a positive multiple of 4"
            )

        # 7x7 convolution with stride 1 for feature reinforcement: in_channels -> C/4.
        self.conv_init = nn.Conv2d(in_channels, out_c14, kernel_size=7, stride=1, padding=3)

        # Stage 1: original size -> 2x downsampling.
        self.conv_1 = nn.Conv2d(out_c14, out_c12, kernel_size=3, stride=1, padding=1, groups=out_c14)
        self.conv_x1 = nn.Conv2d(out_c12, out_c12, kernel_size=3, stride=2, padding=1, groups=out_c12)
        self.batch_norm_x1 = nn.BatchNorm2d(out_c12)
        self.cut_c = Cut(out_c14, out_c12)
        self.fusion1 = nn.Conv2d(out_channels, out_c12, kernel_size=1, stride=1)

        # Stage 2: 2x -> 4x downsampling.
        self.conv_2 = nn.Conv2d(out_c12, out_channels, kernel_size=3, stride=1, padding=1, groups=out_c12)
        self.conv_x2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1, groups=out_channels)
        self.batch_norm_x2 = nn.BatchNorm2d(out_channels)
        self.max_m = nn.MaxPool2d(kernel_size=2, stride=2)
        self.batch_norm_m = nn.BatchNorm2d(out_channels)
        self.cut_r = Cut(out_c12, out_channels)
        self.fusion2 = nn.Conv2d(out_channels * 3, out_channels, kernel_size=1, stride=1)

    def forward(self, x):
        """Map ``[B, in_channels, H, W]`` to ``[B, C, H/4, W/4]``."""
        # Feature reinforcement at full resolution.
        x = self.conv_init(x)  # [B, C/4, H, W]

        # ---- Stage 1: original size -> 2x downsampling ----
        c = x  # [B, C/4, H, W]
        # CutD branch
        c = self.cut_c(c)  # [B, C/4, H, W] --> [B, C/2, H/2, W/2]
        # ConvD branch: widen at full resolution, then strided depthwise conv
        x = self.conv_1(x)  # [B, C/4, H, W] --> [B, C/2, H, W]
        x = self.conv_x1(x)  # [B, C/2, H, W] --> [B, C/2, H/2, W/2]
        x = self.batch_norm_x1(x)
        # Concat + 1x1 conv
        x = torch.cat([x, c], dim=1)  # [B, C, H/2, W/2]
        x = self.fusion1(x)  # [B, C, H/2, W/2] --> [B, C/2, H/2, W/2]

        # ---- Stage 2: 2x -> 4x downsampling ----
        r = x  # [B, C/2, H/2, W/2]
        x = self.conv_2(x)  # [B, C/2, H/2, W/2] --> [B, C, H/2, W/2]
        m = x  # [B, C, H/2, W/2]
        # ConvD branch
        x = self.conv_x2(x)  # [B, C, H/2, W/2] --> [B, C, H/4, W/4]
        x = self.batch_norm_x2(x)
        # MaxD branch
        m = self.max_m(m)  # [B, C, H/2, W/2] --> [B, C, H/4, W/4]
        m = self.batch_norm_m(m)
        # CutD branch
        r = self.cut_r(r)  # [B, C/2, H/2, W/2] --> [B, C, H/4, W/4]
        # Concat + 1x1 conv
        x = torch.cat([x, r, m], dim=1)  # [B, 3C, H/4, W/4]
        x = self.fusion2(x)  # [B, 3C, H/4, W/4] --> [B, C, H/4, W/4]
        return x
|
||
|
||
|
||
# CutD
|
||
class Cut(nn.Module):
    """Downsample by 2x via pixel rearrangement followed by a 1x1 conv ("CutD").

    The input is split into its four 2x2 sampling phases (even/odd rows x
    even/odd columns), the phases are stacked along the channel axis, and a
    1x1 convolution plus batch norm maps the ``4 * in_channels`` stacked
    channels to ``out_channels``.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the downsampled output.
    """

    # (row offset, column offset) of each phase, in concatenation order.
    _PHASES = ((0, 0), (1, 0), (0, 1), (1, 1))

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_fusion = nn.Conv2d(4 * in_channels, out_channels, kernel_size=1, stride=1)
        self.batch_norm = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Map ``[B, C, H, W]`` to ``[B, out_channels, H/2, W/2]``.

        H and W must be even; otherwise the phases differ in size and the
        concatenation fails.
        """
        # Each phase is [B, C, H/2, W/2].
        phases = [x[:, :, row::2, col::2] for row, col in self._PHASES]
        stacked = torch.cat(phases, dim=1)  # [B, 4*C, H/2, W/2]
        fused = self.conv_fusion(stacked)  # [B, out_channels, H/2, W/2]
        return self.batch_norm(fused)
|
||
|
||
|
||
# Deep feature downsampling
|
||
class DRFD(nn.Module):
    """Deep feature downsampling: halves H and W and returns a token sequence.

    Three parallel 2x-downsampling branches are computed and fused by a 1x1
    convolution:

    * CutD: pixel rearrangement + 1x1 conv (``Cut``) on the raw input;
    * ConvD: strided depthwise 3x3 conv + GELU + batch norm;
    * MaxD: 2x2 max pooling + batch norm.

    ConvD and MaxD both start from a grouped 3x3 conv that changes the
    channel count from ``in_channels`` to ``out_channels`` (with
    ``groups=in_channels``, ``nn.Conv2d`` requires ``out_channels`` to be a
    multiple of ``in_channels``).

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output tokens.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.cut_c = Cut(in_channels=in_channels, out_channels=out_channels)
        self.conv = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=1, padding=1, groups=in_channels,
        )
        self.conv_x = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=2, padding=1, groups=out_channels,
        )
        self.act_x = nn.GELU()
        self.batch_norm_x = nn.BatchNorm2d(out_channels)
        self.batch_norm_m = nn.BatchNorm2d(out_channels)
        self.max_m = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fusion = nn.Conv2d(3 * out_channels, out_channels, kernel_size=1, stride=1)

    def forward(self, x):
        """Downsample ``x`` of shape ``[B, C_in, H, W]``.

        Returns:
            A tuple ``(tokens, h, w)`` where ``tokens`` has shape
            ``[B, h * w, C_out]`` and ``h``, ``w`` are the spatial sizes of
            the fused map (H/2 and W/2 for even H, W).
        """
        raw = x  # [B, C_in, H, W]
        widened = self.conv(x)  # [B, C_in, H, W] --> [B, C_out, H, W]

        # CutD branch (operates on the raw input).
        cut_branch = self.cut_c(raw)  # [B, C_out, H/2, W/2]

        # ConvD branch: strided depthwise conv, activation, then normalisation.
        conv_branch = self.batch_norm_x(self.act_x(self.conv_x(widened)))  # [B, C_out, H/2, W/2]

        # MaxD branch.
        pool_branch = self.batch_norm_m(self.max_m(widened))  # [B, C_out, H/2, W/2]

        # Concat + 1x1 conv.
        merged = torch.cat([cut_branch, conv_branch, pool_branch], dim=1)  # [B, 3*C_out, H/2, W/2]
        fused = self.fusion(merged)  # [B, C_out, H/2, W/2]

        # Flatten the spatial grid into a sequence: [B, C_out, h, w] -> [B, h*w, C_out].
        batch, channels, h, w = fused.shape
        tokens = fused.view(batch, channels, -1).transpose(1, 2)
        return tokens, h, w
|
||
|
||
|
||
from torch.nn import init
|
||
#-------------------------------------------------------------------------#
|
||
import time
|
||
import math
|
||
import copy
|
||
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
|
||
|
||
|
||
#-------------------------------------------------------------------------#
|
||
#-------------------------------------------------------------------------#
|
||
class ConvX(nn.Module):
    """Bias-free convolution followed by batch norm and in-place ReLU.

    Args:
        in_planes: input channels.
        out_planes: output channels.
        kernel: square kernel size; padding is ``kernel // 2``.
        stride: convolution stride.
        sync: use ``nn.SyncBatchNorm`` instead of ``nn.BatchNorm2d``.
    """

    def __init__(self, in_planes, out_planes, kernel=3, stride=1, sync=False):
        super().__init__()
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel,
            stride=stride,
            padding=kernel // 2,
            bias=False,
        )
        norm_layer = nn.SyncBatchNorm if sync else nn.BatchNorm2d
        self.bn = norm_layer(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> norm -> ReLU to ``x``."""
        y = self.conv(x)
        y = self.bn(y)
        return self.relu(y)
|
||
|
||
|
||
class CatBottleneck(nn.Module):
    """Bottleneck that concatenates the outputs of a chain of ``ConvX`` layers.

    The first layer is a 1x1 ``ConvX`` producing ``out_planes // 2`` channels.
    Each following layer consumes the previous layer's output; the channel
    count is halved at every step except the last one, which keeps its input
    width. For ``idx >= 2`` the kernel size is ``2 * idx + 1``. The outputs of
    all layers are concatenated along the channel axis.

    With ``stride == 2`` a strided depthwise 3x3 conv (``avd_layer``)
    downsamples the input of the second layer, and the first layer's output
    is average-pooled (``skip``) before concatenation so that all parts share
    the same spatial size.

    Args:
        in_planes: input channels.
        out_planes: nominal output channels (the concatenated width).
        block_num: number of ``ConvX`` layers; must be greater than 1.
        stride: 2 to downsample by 2x; any other value is passed to the
            second ``ConvX`` as its stride.

    Raises:
        AssertionError: if ``block_num <= 1``.
    """

    def __init__(self, in_planes, out_planes, block_num=3, stride=1):
        super().__init__()
        # The message must be a string: the previous ``print(...)`` expression
        # evaluated to None, so the raised AssertionError carried no message.
        assert block_num > 1, "block number should be larger than 1."
        self.conv_list = nn.ModuleList()
        self.stride = stride
        if stride == 2:
            half = out_planes // 2
            self.avd_layer = nn.Sequential(
                nn.Conv2d(half, half, kernel_size=3, stride=2, padding=1, groups=half, bias=False),
                nn.BatchNorm2d(half),
            )
            self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
            # Downsampling is done by ``avd_layer``; the ConvX chain itself
            # then runs at stride 1.
            stride = 1

        for idx in range(block_num):
            if idx == 0:
                layer = ConvX(in_planes, out_planes // 2, kernel=1)
            elif idx == 1 and block_num == 2:
                layer = ConvX(out_planes // 2, out_planes // 2, stride=stride)
            elif idx == 1 and block_num > 2:
                layer = ConvX(out_planes // 2, out_planes // 4, stride=stride)
            elif idx < block_num - 1:
                layer = ConvX(
                    out_planes // 2 ** idx,
                    out_planes // 2 ** (idx + 1),
                    kernel=2 * idx + 1,
                    stride=1,
                )
            else:
                # Last layer keeps its input width.
                layer = ConvX(
                    out_planes // 2 ** idx,
                    out_planes // 2 ** idx,
                    kernel=2 * idx + 1,
                    stride=1,
                )
            self.conv_list.append(layer)

    def forward(self, x):
        """Run the layer chain and concatenate every layer's output on dim 1."""
        out1 = self.conv_list[0](x)

        # Input of the second layer: downsampled first when stride == 2.
        out = self.avd_layer(out1) if self.stride == 2 else out1

        out_list = []
        for conv in self.conv_list[1:]:
            out = conv(out)
            out_list.append(out)

        # The first layer's output must match the downsampled resolution.
        if self.stride == 2:
            out1 = self.skip(out1)
        out_list.insert(0, out1)

        return torch.cat(out_list, dim=1)
|
||
|
||
from mmseg.registry import MODELS
|
||
@MODELS.register_module()
class RFD(nn.Module):
    """Backbone made of an ``SRFD`` stem followed by ``CatBottleneck`` stages.

    The stem reduces the input to 1/4 resolution with ``base`` channels. Each
    entry of ``layers`` then adds one stage whose first block has stride 2
    and whose remaining blocks have stride 1.

    Args:
        base: channel width of the stem output / first stage input.
        in_channels: number of channels of the input image.
        layers: number of bottleneck blocks per stage. NOTE: ``x8`` and
            ``x16`` are built from the fixed slices ``features[1:3]`` and
            ``features[3:5]``, i.e. they assume two stages of two blocks.
        block_num: number of ``ConvX`` layers inside each bottleneck.
        fuse_type: bottleneck fusion type; only ``"cat"`` is supported.
        dropout: accepted for configuration compatibility; not used here.
        pretrain_model: path of a checkpoint to load; when empty the
            parameters are initialised by ``init_params``.

    Raises:
        ValueError: if ``fuse_type`` is not ``"cat"``.
    """

    def __init__(self, base=64, in_channels=3, layers=(2, 2), block_num=4, fuse_type="cat", dropout=0.20, pretrain_model=''):
        super().__init__()
        # Previously an unsupported fuse_type left ``block`` unbound and
        # surfaced as an UnboundLocalError further down.
        if fuse_type != "cat":
            raise ValueError(f"unsupported fuse_type {fuse_type!r}; only 'cat' is available")
        block = CatBottleneck

        self.in_channels = in_channels
        self.features = self._make_layers(base, layers, block_num, block)
        # The stage views share their modules with ``self.features``.
        self.x4 = nn.Sequential(self.features[:1])
        self.x8 = nn.Sequential(self.features[1:3])
        self.x16 = nn.Sequential(self.features[3:5])

        if pretrain_model:
            print('use pretrain model {}'.format(pretrain_model))
            self.init_weight(pretrain_model)
        else:
            self.init_params()

    def init_weight(self, pretrain_model):
        """Load the ``"state_dict"`` entry of the checkpoint at ``pretrain_model``.

        Checkpoint entries overwrite the model's current state-dict entries
        and the result is loaded with ``strict=False``.
        """
        # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict copies values onto each parameter's own device.
        state_dict = torch.load(pretrain_model, map_location="cpu")["state_dict"]
        self_state_dict = self.state_dict()
        for k, v in state_dict.items():
            if k == 'features.0.conv.weight' and self.in_channels != 3:
                # Duplicate the first conv's weights along the input-channel
                # axis so that they accept twice as many input channels.
                # NOTE(review): this key presumably belongs to an earlier
                # ConvX stem; the SRFD stem's first conv is named
                # ``conv_init`` -- confirm whether this branch still applies.
                v = torch.cat([v, v], dim=1)
            self_state_dict[k] = v
        self.load_state_dict(self_state_dict, strict=False)

    def init_params(self):
        """Initialise conv, batch-norm and linear layers in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def _make_layers(self, base, layers, block_num, block):
        """Build the stem plus all bottleneck blocks as one ``nn.Sequential``."""
        # The stem width must follow ``base``: the first bottleneck expects
        # ``base`` input channels (this used to be hard-coded to 64).
        features = [SRFD(in_channels=self.in_channels, out_channels=base)]

        for i, layer in enumerate(layers):
            for j in range(layer):
                if i == 0 and j == 0:
                    features.append(block(base, base * 4, block_num, 2))
                elif j == 0:
                    features.append(block(base * 2 ** (i + 1), base * 2 ** (i + 2), block_num, 2))
                else:
                    width = base * 2 ** (i + 2)
                    features.append(block(width, width, block_num, 1))

        return nn.Sequential(*features)

    def forward(self, x, cas3=False):
        """Return multi-scale features of ``x``.

        Returns:
            ``(feat8, feat16)`` by default, or ``(feat4, feat8, feat16)``
            when ``cas3`` is true. The names refer to the nominal
            downsampling factor of each stage output.
        """
        feat4 = self.x4(x)
        feat8 = self.x8(feat4)
        feat16 = self.x16(feat8)

        if cas3:
            return feat4, feat8, feat16
        return feat8, feat16
|
||
|
||
|
||
if __name__ == "__main__":
    # Measures the peak GPU memory of one forward pass. Requires CUDA.
    import torch

    # Release cached GPU memory so leftovers do not skew the measurement.
    torch.cuda.empty_cache()

    # Build the model on the GPU. ``RFD`` is the backbone defined in this
    # file; the previously referenced ``ShallowNet_RFD`` is not defined here.
    model = RFD().to("cuda")

    # Synthetic input: batch of 1, 3 channels, 5000 x 5000 pixels.
    input_tensor = torch.randn(1, 3, 5000, 5000).cuda()

    # Warm-up pass so one-off initialisation cost is excluded.
    with torch.no_grad():
        _ = model(input_tensor)

    # Reset the peak-memory counters, then run the measured pass
    # (gradients disabled to save memory).
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        model(input_tensor)

    # Peak allocated memory is reported by PyTorch in bytes.
    peak_memory = torch.cuda.max_memory_allocated()
    print(f"推理峰值显存: {peak_memory / 1024 ** 2:.2f} MB")

    # Optional FLOPs / parameter count (needs the third-party `thop` package):
    #     from thop import profile
    #     flops, params = profile(model, (input_tensor,))
    #     print('flops: %.2f M, params: %.2f M' % (flops / 1000000.0, params / 1000000.0))
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|