You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
83 lines
2.1 KiB
83 lines
2.1 KiB
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class ResBlock(nn.Module):
    """Residual block: ``relu(main(x) + shortcut(x))``.

    The main branch is conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN.

    When the block changes resolution (``stride != 1``) or width
    (``in_channels != out_channels``), the shortcut is a strided 3x3 conv + BN.
    Its output then has the same shape as the main branch. Otherwise the
    shortcut is an empty ``nn.Sequential``, which returns its input unchanged.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both branches.
        stride: Stride of the first main-branch conv and of the projection
            shortcut.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
        super().__init__()

        self.layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
        )

        # A projection is needed whenever the identity path would not match the
        # main branch's output shape.
        shape_changes = stride != 1 or in_channels != out_channels
        # NOTE(review): the projection uses a 3x3 conv (padding=1) rather than
        # the 1x1 conv common in ResNet designs; shapes still agree because it
        # uses the same kernel/stride/padding as the first main-branch conv.
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(out_channels),
            )
            if shape_changes
            else nn.Sequential()
        )

    def forward(self, x):
        # Sum both branches, then apply the post-addition ReLU.
        return F.relu(self.layer(x) + self.shortcut(x))
|
|
|
|
|
|
class ResNet(nn.Module):
    """ResNet-style network: stem conv, four stride-2 residual stages, global
    average pool, dropout and a linear head.

    Args:
        in_channels: Channels of the input images (default 1).
        num_classes: Width of the final linear layer, presumably the number of
            classes (default 2, matching the previously hard-coded head).

    ``forward`` takes a 4-D ``(batch, in_channels, H, W)`` tensor (as required
    by ``nn.Conv2d``) and returns a ``(batch, num_classes)`` tensor.
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 2):
        super().__init__()

        # Stem: 5x5 conv, stride 3, no padding, then a 2x2 max-pool.
        # Spatial size: H -> floor((H - 5) / 3) + 1, then floor-halved by the pool.
        # The original comments assumed the stem yields (32, 159, 159), e.g. from
        # a 960x960 input -- TODO confirm the real input size.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=5, stride=3, padding=0),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2),
            nn.ReLU(),
        )

        # Each stride-2 ResBlock (3x3 conv, padding=1) maps H -> ceil(H / 2).
        # Shapes below follow from the (32, 159, 159) stem example above.
        self.layers = nn.ModuleList([
            ResBlock(32, 64, stride=2),    # (64, 80, 80)
            ResBlock(64, 128, stride=2),   # (128, 40, 40)
            ResBlock(128, 256, stride=2),  # (256, 20, 20)
            ResBlock(256, 512, stride=2),  # (512, 10, 10)
        ])

        # Collapses any spatial size to 1x1, so the head is input-size agnostic.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=0.5)
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        out = self.conv(x)
        for layer in self.layers:
            out = layer(out)

        # NOTE(review): dropout is applied element-wise to the 4-D feature map
        # *before* global average pooling, so much of its noise is averaged
        # away. Placing it after the pool is more common. Kept here to preserve
        # existing training behaviour.
        out = self.dropout(out)

        out = self.pool(out)               # (batch, 512, 1, 1)
        out = out.view(out.size(0), -1)    # (batch, 512)
        return self.fc(out)                # (batch, num_classes)
|