import functools

import torch
import torch.nn as nn
from torchvision import models


def _build_alexnet(pretrained=True):
    """Return a freshly constructed torchvision AlexNet.

    Uses the ``weights=`` API when torchvision provides ``AlexNet_Weights``
    (torchvision >= 0.13, where ``pretrained=`` is deprecated), and falls back
    to the legacy ``pretrained=`` flag on older releases.

    Args:
        pretrained: If True, load the ImageNet weights; otherwise return a
            randomly initialised network (no download).
    """
    weights_enum = getattr(models, "AlexNet_Weights", None)
    if weights_enum is None:
        return models.alexnet(pretrained=pretrained)
    return models.alexnet(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)


@functools.lru_cache(maxsize=1)
def _shared_alexnet():
    """Build the legacy module-level pretrained AlexNet once, on first access."""
    return _build_alexnet(pretrained=True)


def __getattr__(name):
    # PEP 562 module attribute hook. ``alexnet_model`` used to be created
    # eagerly at import time, which downloaded weights as a side effect of
    # importing this module. It remains available for backward compatibility,
    # but is only built when someone actually asks for it.
    if name == "alexnet_model":
        return _shared_alexnet()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


class AlexNetPlusLatent(nn.Module):
    """AlexNet backbone followed by a sigmoid latent layer and a classifier.

    The backbone's convolutional ``features`` and every classifier layer except
    the last one are kept. The resulting 4096-d vector is projected to ``bits``
    units, squashed with a sigmoid, and fed to a linear classifier.

    Args:
        bits: Width of the sigmoid latent layer.
        num_classes: Number of output logits (default 10, the original value).
        pretrained: Initialise the backbone from ImageNet weights (default
            True, the original behaviour). Pass False to skip the download.
    """

    def __init__(self, bits, num_classes=10, pretrained=True):
        super().__init__()
        self.bits = bits
        # Each instance gets its own backbone. Previously the layers were taken
        # from a single shared module-level model, so all instances trained
        # the very same parameter tensors.
        backbone = _build_alexnet(pretrained=pretrained)
        self.features = nn.Sequential(*backbone.features.children())
        # Exactly the identity for 224x224 inputs (features are already 6x6).
        # Other input sizes are pooled to the 256*6*6 vector the classifier
        # expects. This layer has no parameters, so state_dict keys are
        # unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        # Drop the final classifier layer; Linear1 takes over from its 4096-d input.
        self.remain = nn.Sequential(*list(backbone.classifier.children())[:-1])
        self.Linear1 = nn.Linear(4096, self.bits)
        self.sigmoid = nn.Sigmoid()
        self.Linear2 = nn.Linear(self.bits, num_classes)

    def forward(self, x):
        """Run the network.

        Args:
            x: Image batch; presumably (N, 3, H, W) as torchvision AlexNet
                expects — confirm against the data pipeline.

        Returns:
            A tuple ``(features, result)``. ``features`` is the (N, bits)
            sigmoid latent code. ``result`` is the (N, num_classes) logits.
        """
        x = self.features(x)
        x = self.avgpool(x)
        # flatten (unlike .view) also works on non-contiguous tensors.
        x = torch.flatten(x, 1)
        x = self.remain(x)
        features = self.sigmoid(self.Linear1(x))
        result = self.Linear2(features)
        return features, result