# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
#
# 83 lines
# 2.1 KiB

import torch.nn as nn
import torch.nn.functional as F
class ResBlock(nn.Module):
    """Residual block: a two-convolution main branch summed with a shortcut.

    Main branch (``self.layer``): 3x3 conv (carrying ``stride``) -> batch norm
    -> ReLU -> 3x3 conv (stride 1) -> batch norm.

    Shortcut branch (``self.shortcut``): an empty ``nn.Sequential`` (returns
    its input unchanged) when ``stride == 1`` and the channel counts match;
    otherwise a strided 3x3 conv followed by batch norm, so that both branches
    produce tensors of the same shape before they are added.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first main-branch conv and of the shortcut conv.
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
        super().__init__()

        # Main branch. Only the first convolution may change resolution.
        main_branch = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
        ]
        self.layer = nn.Sequential(*main_branch)

        # A pass-through shortcut only works when the main branch keeps both
        # the resolution and the channel count; otherwise project the input.
        shape_preserved = stride == 1 and in_channels == out_channels
        if shape_preserved:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Return ``relu(layer(x) + shortcut(x))``.

        ``x`` is presumably a batched feature map of shape
        (N, in_channels, H, W) -- confirm against the caller.
        """
        return F.relu(self.layer(x) + self.shortcut(x))
class ResNet(nn.Module):
    """Small residual CNN: conv stem, four stride-2 ``ResBlock`` stages, linear head.

    Pipeline: stem (5x5 conv with stride 3 -> batch norm -> 2x2 max pool ->
    ReLU), then ResBlocks widening 32 -> 64 -> 128 -> 256 -> 512 channels,
    dropout, global average pooling, and a fully connected layer.

    Args:
        in_channels: Number of channels of the input image.
        num_classes: Number of output logits (defaults to 2, the value that
            was previously hard-coded).
        dropout: Dropout probability (defaults to 0.5, the value that was
            previously hard-coded).
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 2, dropout: float = 0.5):
        super().__init__()

        # Stem. The original comment gives its output as (32, 159, 159); the
        # input size that implies is not visible here -- TODO confirm.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=5, stride=3, padding=0),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2),
            nn.ReLU(),
        )

        # Each stage doubles the channels with a stride-2 ResBlock. Assuming a
        # 159x159 stem output, the Conv2d size formula
        # floor((H + 2*1 - 3) / 2) + 1 gives 80, 40, 20 and 10 for the stages.
        widths = (32, 64, 128, 256, 512)
        self.layers = nn.ModuleList(
            [ResBlock(c_in, c_out, stride=2) for c_in, c_out in zip(widths, widths[1:])]
        )

        self.pool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(widths[-1], num_classes)

    def forward(self, x):
        """Return logits of shape (N, num_classes) for a batch of images ``x``.

        ``x`` is presumably (N, in_channels, H, W) -- confirm against the caller.
        """
        out = self.conv(x)
        for stage in self.layers:
            out = stage(out)
        # NOTE(review): dropout acts on the 4-D feature map before pooling, as
        # in the original; confirm this ordering is intended.
        out = self.dropout(out)
        out = self.pool(out)
        # (N, C, 1, 1) -> (N, C) for the linear head.
        out = out.flatten(1)
        return self.fc(out)