"""senet in pytorch
|
|
|
|
|
|
|
|
[1] Jie Hu, Li Shen, Samuel Albanie, Gang Sun, Enhua Wu
|
|
|
|
Squeeze-and-Excitation Networks
|
|
https://arxiv.org/abs/1709.01507
|
|
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicResidualSEBlock(nn.Module):
    """Basic residual block (two 3x3 convs) with squeeze-and-excitation.

    Args:
        in_channels: channels of the incoming feature map
        out_channels: base channel count of the block
        stride: stride of the first conv (downsamples spatially when > 1)
        r: reduction ratio of the SE bottleneck (default 16)
    """

    expansion = 1

    def __init__(self, in_channels, out_channels, stride, r=16):
        super().__init__()

        expanded = out_channels * self.expansion

        # main path: conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN -> ReLU
        self.residual = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, expanded, 3, padding=1),
            nn.BatchNorm2d(expanded),
            nn.ReLU(inplace=True)
        )

        # 1x1 projection shortcut whenever spatial size or channel count changes;
        # otherwise an identity (empty Sequential)
        if stride != 1 or in_channels != expanded:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, expanded, 1, stride=stride),
                nn.BatchNorm2d(expanded)
            )
        else:
            self.shortcut = nn.Sequential()

        # squeeze: global average pool -> one scalar per channel
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        # excitation: bottleneck MLP producing per-channel gates in (0, 1)
        self.excitation = nn.Sequential(
            nn.Linear(expanded, expanded // r),
            nn.ReLU(inplace=True),
            nn.Linear(expanded // r, expanded),
            nn.Sigmoid()
        )

    def forward(self, x):
        identity = self.shortcut(x)
        out = self.residual(x)

        # squeeze spatial dims, then compute channel gates
        pooled = self.squeeze(out).flatten(1)
        scale = self.excitation(pooled).view(out.size(0), out.size(1), 1, 1)

        # recalibrate channels (broadcast multiply), add shortcut, final ReLU
        return F.relu(out * scale + identity)
class BottleneckResidualSEBlock(nn.Module):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1 convs) with squeeze-and-excitation.

    Output channel count is ``out_channels * expansion`` (i.e. 4x).

    Args:
        in_channels: channels of the incoming feature map
        out_channels: bottleneck channel count of the block
        stride: stride of the middle 3x3 conv (downsamples spatially when > 1)
        r: reduction ratio of the SE bottleneck (default 16)
    """

    expansion = 4

    def __init__(self, in_channels, out_channels, stride, r=16):
        super().__init__()

        expanded = out_channels * self.expansion

        # main path: 1x1 reduce -> 3x3(stride) -> 1x1 expand, BN + ReLU after each
        self.residual = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, expanded, 1),
            nn.BatchNorm2d(expanded),
            nn.ReLU(inplace=True)
        )

        # squeeze: global average pool -> one scalar per channel
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        # excitation: bottleneck MLP producing per-channel gates in (0, 1)
        self.excitation = nn.Sequential(
            nn.Linear(expanded, expanded // r),
            nn.ReLU(inplace=True),
            nn.Linear(expanded // r, expanded),
            nn.Sigmoid()
        )

        # 1x1 projection shortcut whenever spatial size or channel count changes;
        # otherwise an identity (empty Sequential)
        if stride != 1 or in_channels != expanded:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, expanded, 1, stride=stride),
                nn.BatchNorm2d(expanded)
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        identity = self.shortcut(x)
        out = self.residual(x)

        # squeeze spatial dims, then compute channel gates
        pooled = self.squeeze(out).flatten(1)
        scale = self.excitation(pooled).view(out.size(0), out.size(1), 1, 1)

        # recalibrate channels (broadcast multiply), add shortcut, final ReLU
        return F.relu(out * scale + identity)
class SEResNet(nn.Module):
    """SE-ResNet for small (32x32) inputs: 3x3 stem, four stages, linear head.

    Args:
        block: residual block class (e.g. BasicResidualSEBlock or
            BottleneckResidualSEBlock); must expose ``expansion`` and accept
            ``(in_channels, out_channels, stride)``.
        block_num: sequence of four ints — number of blocks per stage.
        class_num: output classes of the final linear layer.
            NOTE(review): default of 1 looks like a truncation of 100
            (CIFAR-100 variants) — confirm against callers.
    """

    def __init__(self, block, block_num, class_num=1):
        super().__init__()

        # running channel count, advanced by _make_stage as stages are built
        self.in_channels = 64

        # stem: single 3x3 conv keeps full resolution (no 7x7 conv / max-pool
        # as in the ImageNet ResNet stem)
        self.pre = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # channel counts double each stage: 64 -> 128 -> 256 -> 512
        # (fix: stage4 previously used 516, a typo breaking the doubling scheme)
        self.stage1 = self._make_stage(block, block_num[0], 64, 1)
        self.stage2 = self._make_stage(block, block_num[1], 128, 2)
        self.stage3 = self._make_stage(block, block_num[2], 256, 2)
        self.stage4 = self._make_stage(block, block_num[3], 512, 2)

        # self.in_channels now equals 512 * block.expansion
        self.linear = nn.Linear(self.in_channels, class_num)

    def forward(self, x):
        """Return class logits of shape (N, class_num)."""
        x = self.pre(x)

        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)

        # global average pool -> flatten -> classifier
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.size(0), -1)

        return self.linear(x)

    def _make_stage(self, block, num, out_channels, stride):
        """Stack ``num`` blocks; only the first may downsample (given stride)."""
        layers = [block(self.in_channels, out_channels, stride)]
        self.in_channels = out_channels * block.expansion

        # remaining blocks keep stride 1 and the (now expanded) channel count
        for _ in range(num - 1):
            layers.append(block(self.in_channels, out_channels, 1))

        return nn.Sequential(*layers)
def seresnet18():
    """SE-ResNet-18: basic SE blocks, two per stage."""
    blocks_per_stage = [2, 2, 2, 2]
    return SEResNet(BasicResidualSEBlock, blocks_per_stage)
def seresnet34():
    """SE-ResNet-34: basic SE blocks, [3, 4, 6, 3] per stage."""
    blocks_per_stage = [3, 4, 6, 3]
    return SEResNet(BasicResidualSEBlock, blocks_per_stage)
def seresnet50():
    """SE-ResNet-50: bottleneck SE blocks, [3, 4, 6, 3] per stage."""
    blocks_per_stage = [3, 4, 6, 3]
    return SEResNet(BottleneckResidualSEBlock, blocks_per_stage)
def seresnet101():
    """SE-ResNet-101: bottleneck SE blocks, [3, 4, 23, 3] per stage."""
    blocks_per_stage = [3, 4, 23, 3]
    return SEResNet(BottleneckResidualSEBlock, blocks_per_stage)
def seresnet152():
    """SE-ResNet-152: bottleneck SE blocks, [3, 8, 36, 3] per stage."""
    blocks_per_stage = [3, 8, 36, 3]
    return SEResNet(BottleneckResidualSEBlock, blocks_per_stage)