You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
27 lines
936 B
27 lines
936 B
import torch.nn as nn
|
|
from torchvision import models
|
|
|
|
# Pre-trained AlexNet from torchvision, built once at import time.
# AlexNetPlusLatent below draws its `features` and `classifier` layers from this model.
# NOTE(review): loading pretrained weights may download them on first use, and
# newer torchvision releases deprecate `pretrained=` in favor of `weights=` —
# confirm against the installed torchvision version.
alexnet_model = models.alexnet(pretrained=True)
|
|
|
|
class AlexNetPlusLatent(nn.Module):
    """AlexNet backbone followed by a sigmoid latent layer and a linear classifier.

    The convolutional ``features`` stack and every layer of the ``classifier``
    stack except the last are taken from an AlexNet-style backbone. A
    ``bits``-wide linear layer followed by a sigmoid produces the latent codes,
    and a second linear layer maps those codes to class scores.

    Subclasses ``nn.Module``, the base class for all PyTorch network modules.

    Args:
        bits: Width of the latent (sigmoid) layer.
        num_classes: Number of class scores produced by ``Linear2``.
            Defaults to 10.
        backbone: Object exposing ``features`` and ``classifier`` module
            containers (e.g. a torchvision AlexNet). Defaults to the
            module-level ``alexnet_model``.

    NOTE(review): submodules are reused, not copied, from ``backbone``. Every
    instance built from the same backbone therefore shares (and jointly
    trains) those layers. Pass a ``copy.deepcopy`` of the backbone when
    independent instances are needed.
    """

    def __init__(self, bits, num_classes=10, backbone=None):
        super(AlexNetPlusLatent, self).__init__()
        if backbone is None:
            backbone = alexnet_model
        self.bits = bits
        # Reuse the backbone's convolutional stack unchanged.
        self.features = nn.Sequential(*list(backbone.features.children()))
        # Pool the feature map to 6x6 so the flatten in forward() always yields
        # 256*6*6 values. When the backbone already emits 6x6 maps (the size
        # the flatten assumes), this pooling is an identity. It has no
        # parameters, so state_dict keys are unaffected.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        # Keep all classifier layers except the final one; Linear1 expects
        # their output to be 4096 wide.
        self.remain = nn.Sequential(*list(backbone.classifier.children())[:-1])
        self.Linear1 = nn.Linear(4096, self.bits)
        self.sigmoid = nn.Sigmoid()
        self.Linear2 = nn.Linear(self.bits, num_classes)

    def forward(self, x):
        """Run the network and return ``(features, result)``.

        Args:
            x: Input image batch; assumed (N, 3, H, W) — TODO confirm with callers.

        Returns:
            features: Sigmoid latent activations of shape (N, bits).
            result: Class scores of shape (N, num_classes).
        """
        x = self.features(x)
        x = self.avgpool(x)
        # reshape (rather than view) also copes with non-contiguous feature maps.
        x = x.reshape(x.size(0), 256 * 6 * 6)
        x = self.remain(x)
        x = self.Linear1(x)
        features = self.sigmoid(x)
        result = self.Linear2(features)
        return features, result
|