You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
31 lines
970 B
31 lines
970 B
import torch
|
|
import torch.nn as nn
|
|
|
|
class Autoencoder(nn.Module):
    """Convolutional autoencoder that reconstructs its input image.

    The encoder applies two ``Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2)``
    stages (``in_channels -> 32 -> 64`` channels), shrinking each spatial
    dimension by a factor of 4. The decoder mirrors it with
    ``Conv2d -> ReLU -> 2x nearest upsample`` stages and a final ``Conv2d``
    back to ``in_channels``, followed by a Sigmoid so every output value
    lies in [0, 1].

    Each MaxPool2d floors odd sizes, so the output has the same spatial size
    as the input only when H and W are both divisible by 4. For example,
    28x28 -> 7x7 -> 28x28 works, but a 30x30 input comes back as 28x28.

    Args:
        in_channels: Number of channels in the input and in the reconstructed
            output. The default of 1 reproduces the original single-channel
            architecture and its ``state_dict`` layout.
    """

    def __init__(self, in_channels: int = 1) -> None:
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),  # H, W -> H//2, W//2
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),  # -> H//4, W//4
        )

        # Layer order/indices match the original, so the parameterised layers
        # keep their state_dict keys (decoder.0, decoder.3, decoder.6).
        # nn.Upsample(mode="nearest") is the non-deprecated equivalent of
        # nn.UpsamplingNearest2d and, like it, has no parameters.
        self.decoder = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode="nearest"),  # -> 2x spatial
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode="nearest"),  # -> 4x spatial
            nn.Conv2d(32, in_channels, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),  # squash the reconstruction into [0, 1]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and decode it back into a reconstruction.

        Args:
            x: Image tensor with ``in_channels`` channels, as accepted by
                ``nn.Conv2d``, i.e. (N, C, H, W) or unbatched (C, H, W).

        Returns:
            A tensor with the same channel count as ``x`` and values in
            [0, 1]. Its spatial size equals the input's when H and W are
            multiples of 4.
        """
        latent = self.encoder(x)
        return self.decoder(latent)
|