You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
85 lines
3.3 KiB
85 lines
3.3 KiB
import os
|
|
import random
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.parallel
|
|
import torch.backends.cudnn as cudnn
|
|
import torch.optim as optim
|
|
import torch.utils.data
|
|
import torchvision.datasets as dset
|
|
import torchvision.transforms as transforms
|
|
import torchvision.utils as vutils
|
|
import numpy as np
|
|
|
|
# Legacy DCGAN discriminator (kept for reference): outputs a probability via
# Sigmoid, uses BatchNorm — replaced below by the WGAN-GP critic.
# class Discriminator(nn.Module):
#     def __init__(self, ngpu, nc = 3, ndf = 64):
#         super(Discriminator, self).__init__()
#         self.ngpu = ngpu
#         self.main = nn.Sequential(
#             # input is (nc) x 64 x 64
#             nn.Conv2d(nc, ndf, 4, 4, 1, bias=False),
#             nn.LeakyReLU(0.2, inplace=True),
#             # state size. (ndf) x 32 x 32
#             nn.Conv2d(ndf, ndf * 2, 4, 4, 1, bias=False),
#             nn.BatchNorm2d(ndf * 2),
#             nn.LeakyReLU(0.2, inplace=True),
#             # state size. (ndf*2) x 16 x 16
#             nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
#             nn.BatchNorm2d(ndf * 4),
#             nn.LeakyReLU(0.2, inplace=True),
#             # state size. (ndf*4) x 8 x 8
#             nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
#             nn.BatchNorm2d(ndf * 8),
#             nn.LeakyReLU(0.2, inplace=True),
#             # state size. (ndf*8) x 4 x 4
#             nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
#             nn.Sigmoid()
#         )
#
#     def forward(self, input):
#         return self.main(input)
|
|
|
|
|
|
|
|
class Discriminator(torch.nn.Module):
    """WGAN-GP critic for C x 32 x 32 images.

    Three stride-2 convolution blocks (256 -> 512 -> 1024 filters) halve
    the spatial size each time, reducing the input to a 1024 x 4 x 4
    feature map; a final 4x4 convolution collapses that map to a single
    unbounded scalar score per sample.  No sigmoid is applied: a
    Wasserstein critic outputs a score, not a probability.

    Batch normalization is deliberately avoided — the gradient penalty
    is computed with respect to each input independently, not the whole
    batch, so per-instance normalization (``nn.InstanceNorm2d``) is used
    instead.
    """

    def __init__(self, channels):
        """Build the critic.

        Args:
            channels: number of channels C of the input images.
        """
        super().__init__()
        # Filter widths of the three downsampling blocks; each block
        # halves the spatial resolution (32 -> 16 -> 8 -> 4).
        feature_widths = (256, 512, 1024)
        layers = []
        in_ch = channels
        for out_ch in feature_widths:
            layers.append(nn.Conv2d(in_channels=in_ch, out_channels=out_ch,
                                    kernel_size=4, stride=2, padding=1))
            layers.append(nn.InstanceNorm2d(out_ch, affine=True))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            in_ch = out_ch
        # Output of main_module: a (1024 x 4 x 4) state for 32x32 input.
        self.main_module = nn.Sequential(*layers)

        # Valid (no padding) 4x4 convolution collapses the 4x4 map to a
        # single-channel 1x1 score.  Intentionally no sigmoid here.
        self.output = nn.Sequential(
            nn.Conv2d(in_channels=1024, out_channels=1,
                      kernel_size=4, stride=1, padding=0))

    def forward(self, x):
        """Return the critic score, shape (N, 1, 1, 1) for 32x32 input."""
        features = self.main_module(x)
        return self.output(features)

    def feature_extraction(self, x):
        """Run the conv trunk and flatten to a (N, 16384) feature vector.

        16384 = 1024 * 4 * 4, i.e. the flattened trunk output for a
        32x32 input.
        """
        features = self.main_module(x)
        return features.view(-1, 1024 * 4 * 4)
|
|
|
|
|