You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

27 lines
936 B

import copy

import torch.nn as nn
from torchvision import models
# ImageNet-pretrained AlexNet whose layers AlexNetPlusLatent builds on.
# Constructed once, at import time.
# torchvision >= 0.13 deprecates ``pretrained=True`` in favour of the
# ``weights=`` enum (IMAGENET1K_V1 is exactly what ``pretrained=True`` loads);
# older torchvision has no weight enums, so fall back to the legacy flag there.
_alexnet_weights_enum = getattr(models, "AlexNet_Weights", None)
if _alexnet_weights_enum is None:
    alexnet_model = models.alexnet(pretrained=True)
else:
    alexnet_model = models.alexnet(weights=_alexnet_weights_enum.IMAGENET1K_V1)
class AlexNetPlusLatent(nn.Module):
    """AlexNet backbone with a sigmoid latent layer and a classification head.

    The convolutional ``features`` and the fully connected ``classifier``
    (without its final layer) are copied from the module-level pretrained
    ``alexnet_model``. The 4096-d classifier output goes through
    ``Linear(4096, bits)`` and a sigmoid to give a ``bits``-dimensional latent
    code in (0, 1). ``Linear(bits, num_classes)`` then maps that code to class
    logits.

    Subclasses ``nn.Module``, the base class for all PyTorch network modules.

    Args:
        bits: Size of the latent layer.
        num_classes: Number of outputs of the classification head. Defaults
            to 10, the value that was previously hard-coded.
    """

    def __init__(self, bits, num_classes=10):
        super().__init__()
        self.bits = bits
        self.num_classes = num_classes
        # Deep-copy the pretrained layers. Reusing the module objects directly
        # would make every instance (and the global ``alexnet_model``) share,
        # and jointly update or move, the very same parameters.
        self.features = nn.Sequential(
            *copy.deepcopy(list(alexnet_model.features.children()))
        )
        # Pool to the 6x6 grid that the flattened 256*6*6 input of the
        # classifier requires. This is an exact no-op when the feature map is
        # already 6x6, and it lets other input sizes work instead of failing
        # in ``view``. It has no parameters, so state_dict keys are unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        # Pretrained classifier minus its last (ImageNet output) layer.
        self.remain = nn.Sequential(
            *copy.deepcopy(list(alexnet_model.classifier.children())[:-1])
        )
        self.Linear1 = nn.Linear(4096, self.bits)
        self.sigmoid = nn.Sigmoid()
        self.Linear2 = nn.Linear(self.bits, num_classes)

    def forward(self, x):
        """Compute the latent code and the class logits for a batch.

        Args:
            x: Image batch; presumably (batch, 3, H, W) preprocessed the way
                torchvision's pretrained AlexNet expects -- TODO confirm with
                the caller.

        Returns:
            Tuple ``(features, result)``. ``features`` is the sigmoid latent
            code of shape (batch, bits). ``result`` holds the logits of shape
            (batch, num_classes).
        """
        x = self.features(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.remain(x)
        features = self.sigmoid(self.Linear1(x))
        result = self.Linear2(features)
        return features, result