import torch
import torch.nn as nn


class Autoencoder(nn.Module):
    """Small convolutional autoencoder.

    The encoder applies two stages of (3x3 conv, ReLU, 2x2 max-pool). This maps
    ``in_channels -> 32 -> 64`` channels and shrinks each spatial side by 4x.
    The decoder mirrors it with (3x3 conv, ReLU, nearest-neighbour 2x upsample)
    twice, then a final 3x3 conv back to ``in_channels`` and a sigmoid. Every
    output value therefore lies in (0, 1). Presumably inputs are scaled to
    [0, 1] so they are comparable with the reconstruction; confirm with caller.

    Note: each ``MaxPool2d`` floors odd sizes. The reconstruction has the same
    height/width as the input only when both are divisible by 4 (e.g. 28x28).

    Args:
        in_channels: Number of channels in the input and reconstructed image.
            Defaults to 1, which matches the original single-channel model.
    """

    def __init__(self, in_channels: int = 1) -> None:
        super().__init__()
        # Layer order/indices are kept fixed so state_dict keys
        # (encoder.0/3, decoder.0/3/6) stay compatible with saved checkpoints.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
        )
        # nn.Upsample(mode="nearest") replaces the deprecated
        # nn.UpsamplingNearest2d; it is parameter-free and numerically identical.
        self.decoder = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(32, in_channels, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode then decode ``x``.

        Args:
            x: Image batch of shape (N, in_channels, H, W).

        Returns:
            Reconstruction with values in (0, 1). Its shape equals ``x.shape``
            when H and W are divisible by 4.
        """
        latent = self.encoder(x)
        return self.decoder(latent)