You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
159 lines
5.1 KiB
159 lines
5.1 KiB
import os, sys
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
import torch
|
|
from torch import nn
|
|
from torch.nn import functional as F
|
|
# from .types_ import *
|
|
|
|
|
|
class VanillaVAE(nn.Module):
    """Convolutional encoder/decoder derived from a vanilla VAE.

    NOTE(review): despite the name, every stochastic part (``reparameterize``,
    the KL term, the ``fc_var`` head's output) is commented out, so this
    currently behaves as a *deterministic* autoencoder: ``encode`` returns only
    the mean vector and ``forward`` decodes it directly.
    """

    def __init__(self, args,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims=None,
                 **kwargs) -> None:
        """Build the encoder/decoder stacks.

        :param args: unused here; kept for signature compatibility with callers.
        :param in_channels: number of channels in the input images.
        :param latent_dim: size of the latent vector; defaults to 512 when None.
        :param hidden_dims: per-stage channel widths for the encoder
                            (defaults to [32, 64, 128, 256, 512]).
        """
        super(VanillaVAE, self).__init__()

        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        else:
            # Copy so the reverse() below does not mutate the caller's list.
            hidden_dims = list(hidden_dims)

        if latent_dim is None:
            latent_dim = 512
        # Bug fix: assign AFTER applying the default, so self.latent_dim is
        # never left as None while the layers silently use 512.
        self.latent_dim = latent_dim

        # Build Encoder: each stage halves the spatial resolution (stride 2).
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        # The factor of 4 assumes the encoder output is a 2x2 feature map,
        # i.e. the input resolution is 2 ** (len(hidden_dims) + 1) —
        # TODO(review): confirm against the training input size.
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Deepest encoder width, needed by decode() to un-flatten.
        # Bug fix: this was hard-coded to 512 in decode(), which broke any
        # custom hidden_dims.
        self._deepest_channels = hidden_dims[-1]

        # Build Decoder (mirror image of the encoder).
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()

        modules = []
        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        # One last 2x upsample back to the input resolution, then project to
        # 3 output channels squashed into [-1, 1] by tanh.
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1],
                               hidden_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=3,
                      kernel_size=3, padding=1),
            nn.Tanh())

    def encode(self, input):
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) Latent mean codes [N x D]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Only the mean head is used; the variance head is dead weight while
        # the reparameterization trick stays disabled.
        mu = self.fc_mu(result)
        # log_var = self.fc_var(result)

        return mu

    def decode(self, z):
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        # Bug fix: un-flatten with the stored deepest channel width instead of
        # a hard-coded 512 so custom hidden_dims work.
        result = result.view(-1, self._deepest_channels, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    # def reparameterize(self, mu, logvar):
    #     """
    #     Reparameterization trick to sample from N(mu, var) from N(0,1).
    #     :param mu: (Tensor) Mean of the latent Gaussian [B x D]
    #     :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
    #     :return: (Tensor) [B x D]
    #     """
    #     std = torch.exp(0.5 * logvar)
    #     eps = torch.randn_like(std)
    #     return eps * std + mu

    def forward(self, input, **kwargs):
        """Deterministic reconstruction: decode the latent mean directly."""
        mu = self.encode(input)
        # z = self.reparameterize(mu, log_var)  # sampling disabled
        return self.decode(mu)

    def loss_function(self,
                      *args,
                      **kwargs):
        """
        Computes the (reduced) VAE loss function.
        KL(N(\\mu, \\sigma), N(0, 1)) = \\log \\frac{1}{\\sigma} + \\frac{\\sigma^2 + \\mu^2}{2} - \\frac{1}{2}
        :param args: args[0] = reconstruction, args[1] = input target.
        :param kwargs: unused ('M_N' KLD weight is disabled with the KL term).
        :return: (Tensor) scalar reconstruction (MSE) loss.
        """
        recons = args[0]
        input = args[1]
        # mu = args[2]
        # log_var = args[3]

        # kld_weight = kwargs['M_N']  # minibatch weight, unused while KL is off
        recons_loss = F.mse_loss(recons, input)

        # kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

        # With the KL term disabled the loss is pure reconstruction error.
        loss = recons_loss
        return loss

    def generate(self, x, **kwargs):
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """
        # Bug fix: forward() returns a single tensor here (not the original
        # [recons, input, mu, log_var] list), so the old `[0]` index silently
        # dropped the batch dimension. Return the full batch as documented.
        return self.forward(x)