### Copyright (C) 2020 Roy Or-El. All rights reserved.
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import functools
from torch.autograd import grad as Grad
from torch.autograd import Function
import numpy as np
from math import sqrt
from pdb import set_trace as st
###############################################################################
# Functions
###############################################################################
def weights_init(init_type='gaussian'):
def init_fun(m):
classname = m.__class__.__name__
if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'):
if init_type == 'gaussian':
init.normal_(m.weight.data, 0.0, 0.02)
elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=sqrt(2))
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=sqrt(2))
elif init_type == 'default':
pass
else:
assert 0, "Unsupported initialization: {}".format(init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
return init_fun
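# Usage sketch (illustrative, mirrors how define_G/define_D use it below): the returned
# closure is passed to nn.Module.apply, which re-initializes every Conv*/Linear layer, e.g.
#   net = nn.Sequential(nn.Conv2d(3, 64, 3), nn.LeakyReLU(0.2))
#   net.apply(weights_init('kaiming'))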
def get_norm_layer(norm_type='instance'):
if norm_type == 'instance':
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
elif norm_type == 'pixel':
norm_layer = PixelNorm
else:
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
return norm_layer
def define_G(input_nc, output_nc, ngf, n_downsample_global=2,
id_enc_norm='pixel', gpu_ids=[], padding_type='reflect',
style_dim=50, init_type='gaussian',
conv_weight_norm=False, decoder_norm='pixel', activation='lrelu',
adaptive_blocks=4, normalize_mlp=False, modulated_conv=False):
id_enc_norm = get_norm_layer(norm_type=id_enc_norm)
netG = Generator(input_nc, output_nc, ngf, n_downsampling=n_downsample_global,
id_enc_norm=id_enc_norm, padding_type=padding_type, style_dim=style_dim,
conv_weight_norm=conv_weight_norm, decoder_norm=decoder_norm,
actvn=activation, adaptive_blocks=adaptive_blocks,
normalize_mlp=normalize_mlp, modulated_conv=modulated_conv)
print(netG)
if len(gpu_ids) > 0:
assert(torch.cuda.is_available())
netG.cuda(gpu_ids[0])
netG.apply(weights_init(init_type))
return netG
def define_D(input_nc, ndf, n_layers=6, numClasses=2, gpu_ids=[],
init_type='gaussian'):
netD = StyleGANDiscriminator(input_nc, ndf=ndf, n_layers=n_layers,
numClasses=numClasses)
print(netD)
if len(gpu_ids) > 0:
assert(torch.cuda.is_available())
netD.cuda(gpu_ids[0])
netD.apply(weights_init('gaussian'))
return netD
def print_network(net):
if isinstance(net, list):
net = net[0]
num_params = 0
for param in net.parameters():
num_params += param.numel()
print(net)
print('Total number of parameters: %d' % num_params)
##############################################################################
# Data parallel wrapper
##############################################################################
class _CustomDataParallel(nn.DataParallel):
def __init__(self, model):
super(_CustomDataParallel, self).__init__(model)
def __getattr__(self, name):
try:
return super(_CustomDataParallel, self).__getattr__(name)
except AttributeError:
print(name)
return getattr(self.module, name)
##############################################################################
# Losses
##############################################################################
class FeatureConsistency(nn.Module):
def __init__(self):
super(FeatureConsistency, self).__init__()
def __call__(self,input,target):
return torch.mean(torch.abs(input - target))
class R1_reg(nn.Module):
def __init__(self, lambda_r1=10.0):
super(R1_reg, self).__init__()
self.lambda_r1 = lambda_r1
def __call__(self, d_out, d_in):
"""Compute gradient penalty: (L2_norm(dy/dx))**2."""
b = d_in.shape[0]
dydx = torch.autograd.grad(outputs=d_out.mean(),
inputs=d_in,
retain_graph=True,
create_graph=True,
only_inputs=True)[0]
dydx_sq = dydx.pow(2)
assert (dydx_sq.size() == d_in.size())
r1_reg = dydx_sq.sum() / b
return r1_reg * self.lambda_r1
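# Hedged usage sketch (not called anywhere in this file): the R1 penalty is typically
# evaluated on real samples whose requires_grad flag is enabled, so that gradients of
# the discriminator output w.r.t. its input exist. `discriminator` and `real_images`
# are hypothetical stand-ins for the actual training objects.
def _r1_reg_example(discriminator, real_images):
    real_images = real_images.detach().requires_grad_(True)  # enable input gradients
    d_out = discriminator(real_images)
    return R1_reg(lambda_r1=10.0)(d_out, real_images)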
class SelectiveClassesNonSatGANLoss(nn.Module):
def __init__(self):
super(SelectiveClassesNonSatGANLoss, self).__init__()
        self.softplus = nn.Softplus()
def __call__(self, input, target_classes, target_is_real, is_gen=False):
bSize = input.shape[0]
b_ind = torch.arange(bSize).long()
relevant_inputs = input[b_ind, target_classes, :, :]
        if target_is_real:
            loss = self.softplus(-relevant_inputs).mean()
        else:
            loss = self.softplus(relevant_inputs).mean()
return loss
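# Note: softplus(-x) = -log(sigmoid(x)) and softplus(x) = -log(1 - sigmoid(x)), so this is
# the standard non-saturating GAN loss, evaluated only on the output map belonging to each
# sample's target class; the other class maps do not contribute to the loss.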
##############################################################################
# Generator
##############################################################################
class EqualLR:
def __init__(self, name):
self.name = name
def compute_weight(self, module):
weight = getattr(module, self.name + '_orig')
fan_in = weight.data.size(1) * weight.data[0][0].numel()
return weight * sqrt(2 / fan_in)
@staticmethod
def apply(module, name):
fn = EqualLR(name)
weight = getattr(module, name)
del module._parameters[name]
module.register_parameter(name + '_orig', nn.Parameter(weight.data))
module.register_forward_pre_hook(fn)
return fn
def __call__(self, module, input):
weight = self.compute_weight(module)
setattr(module, self.name, weight)
def equal_lr(module, name='weight'):
EqualLR.apply(module, name)
return module
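# Usage sketch: equal_lr implements the equalized-learning-rate trick -- the weight is
# stored (as `weight_orig`) at unit variance and rescaled by sqrt(2 / fan_in) on every
# forward pass through a pre-forward hook. EqualConv2d / EqualLinear below are thin
# wrappers around it, e.g.
#   linear = nn.Linear(256, 50)
#   linear.weight.data.normal_()
#   linear = equal_lr(linear)  # forward() now uses weight_orig * sqrt(2 / fan_in)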
class PixelNorm(nn.Module):
def __init__(self, num_channels=None):
super().__init__()
# num_channels is only used to match function signature with other normalization layers
# it has no actual use
def forward(self, input):
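        # normalize each spatial position's feature vector to (roughly) unit length:
        # y = x / sqrt(mean_over_channels(x^2) + 1e-5)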
return input / torch.sqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-5)
class ModulatedConv2d(nn.Module):
def __init__(self, fin, fout, kernel_size, padding_type='reflect', upsample=False, downsample=False, latent_dim=256, normalize_mlp=False):
super(ModulatedConv2d, self).__init__()
self.in_channels = fin
self.out_channels = fout
self.kernel_size = kernel_size
self.upsample = upsample
self.downsample = downsample
padding_size = kernel_size // 2
        if kernel_size == 1:
            self.demodulate = False
        else:
            self.demodulate = True
self.weight = nn.Parameter(torch.Tensor(fout, fin, kernel_size, kernel_size))
self.bias = nn.Parameter(torch.Tensor(1, fout, 1, 1))
self.conv = F.conv2d
if normalize_mlp:
self.mlp_class_std = nn.Sequential(EqualLinear(latent_dim, fin), PixelNorm())
else:
self.mlp_class_std = EqualLinear(latent_dim, fin)
self.blur = Blur(fout)
if padding_type == 'reflect':
self.padding = nn.ReflectionPad2d(padding_size)
else:
self.padding = nn.ZeroPad2d(padding_size)
if self.upsample:
self.upsampler = nn.Upsample(scale_factor=2, mode='nearest')
if self.downsample:
self.downsampler = nn.AvgPool2d(2)
self.weight.data.normal_()
self.bias.data.zero_()
def forward(self, input, latent):
fan_in = self.weight.data.size(1) * self.weight.data[0][0].numel()
weight = self.weight * sqrt(2 / fan_in)
weight = weight.view(1, self.out_channels, self.in_channels, self.kernel_size, self.kernel_size)
s = 1 + self.mlp_class_std(latent).view(-1, 1, self.in_channels, 1, 1)
weight = s * weight
        if self.demodulate:
d = torch.rsqrt((weight ** 2).sum(4).sum(3).sum(2) + 1e-5).view(-1, self.out_channels, 1, 1, 1)
weight = (d * weight).view(-1, self.in_channels, self.kernel_size, self.kernel_size)
else:
weight = weight.view(-1, self.in_channels, self.kernel_size, self.kernel_size)
if self.upsample:
input = self.upsampler(input)
if self.downsample:
input = self.blur(input)
b,_,h,w = input.shape
input = input.view(1,-1,h,w)
input = self.padding(input)
out = self.conv(input, weight, groups=b).view(b, self.out_channels, h, w) + self.bias
if self.downsample:
out = self.downsampler(out)
if self.upsample:
out = self.blur(out)
return out
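# Hedged usage sketch (illustrative shapes, not taken from the training code): the layer
# takes one latent/style vector per sample alongside the feature map; the per-sample
# modulated weights are applied with a single grouped convolution (groups = batch size).
def _modulated_conv_example():
    conv = ModulatedConv2d(fin=64, fout=64, kernel_size=3, latent_dim=256)
    x = torch.randn(4, 64, 32, 32)   # (batch, channels, height, width)
    w = torch.randn(4, 256)          # one 256-dim latent per sample
    return conv(x, w)                # -> (4, 64, 32, 32)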
class EqualConv2d(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
conv = nn.Conv2d(*args, **kwargs)
conv.weight.data.normal_()
conv.bias.data.zero_()
self.conv = equal_lr(conv)
def forward(self, input):
return self.conv(input)
class EqualLinear(nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
linear = nn.Linear(in_dim, out_dim)
linear.weight.data.normal_()
linear.bias.data.zero_()
self.linear = equal_lr(linear)
def forward(self, input):
return self.linear(input)
class BlurFunctionBackward(Function):
@staticmethod
def forward(ctx, grad_output, kernel, kernel_flip):
ctx.save_for_backward(kernel, kernel_flip)
grad_input = F.conv2d(
grad_output, kernel_flip, padding=1, groups=grad_output.shape[1]
)
return grad_input
@staticmethod
def backward(ctx, gradgrad_output):
kernel, kernel_flip = ctx.saved_tensors
grad_input = F.conv2d(
gradgrad_output, kernel, padding=1, groups=gradgrad_output.shape[1]
)
return grad_input, None, None
class BlurFunction(Function):
@staticmethod
def forward(ctx, input, kernel, kernel_flip):
ctx.save_for_backward(kernel, kernel_flip)
output = F.conv2d(input, kernel, padding=1, groups=input.shape[1])
return output
@staticmethod
def backward(ctx, grad_output):
kernel, kernel_flip = ctx.saved_tensors
grad_input = BlurFunctionBackward.apply(grad_output, kernel, kernel_flip)
return grad_input, None, None
blur = BlurFunction.apply
class Blur(nn.Module):
def __init__(self, channel):
super().__init__()
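        # separable binomial kernel [1, 2, 1] x [1, 2, 1], normalized to sum to 1 and
        # applied depthwise via the grouped convolution in BlurFunction (one copy per channel)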
weight = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32)
weight = weight.view(1, 1, 3, 3)
weight = weight / weight.sum()
weight_flip = torch.flip(weight, [2, 3])
self.register_buffer('weight', weight.repeat(channel, 1, 1, 1))
self.register_buffer('weight_flip', weight_flip.repeat(channel, 1, 1, 1))
def forward(self, input):
return blur(input, self.weight, self.weight_flip)
class MLP(nn.Module):
def __init__(self, input_dim, out_dim, fc_dim, n_fc,
weight_norm=False, activation='relu', normalize_mlp=False):#, pixel_norm=False):
super(MLP, self).__init__()
if weight_norm:
linear = EqualLinear
else:
linear = nn.Linear
if activation == 'lrelu':
actvn = nn.LeakyReLU(0.2,True)
elif activation == 'blrelu':
actvn = BidirectionalLeakyReLU()
else:
actvn = nn.ReLU(True)
self.input_dim = input_dim
self.model = []
# normalize input
if normalize_mlp:
self.model += [PixelNorm()]
# set the first layer
self.model += [linear(input_dim, fc_dim),
actvn]
if normalize_mlp:
self.model += [PixelNorm()]
# set the inner layers
for i in range(n_fc - 2):
self.model += [linear(fc_dim, fc_dim),
actvn]
if normalize_mlp:
self.model += [PixelNorm()]
# set the last layer
self.model += [linear(fc_dim, out_dim)] # no output activations
# normalize output
if normalize_mlp:
self.model += [PixelNorm()]
self.model = nn.Sequential(*self.model)
def forward(self, input):
out = self.model(input)
return out
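# Note (informal): StyledDecoder below instantiates this as an 8-layer mapping network,
# MLP(style_dim, latent_dim, 256, 8, weight_norm=True), turning the style_dim-dim age code
# into the latent_dim-dim vector consumed by the styled conv blocks.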
class StyledConvBlock(nn.Module):
def __init__(self, fin, fout, latent_dim=256, padding='reflect', upsample=False, downsample=False,
actvn='lrelu', use_pixel_norm=False, normalize_affine_output=False, modulated_conv=False):
super(StyledConvBlock, self).__init__()
if not modulated_conv:
if padding == 'reflect':
padding_layer = nn.ReflectionPad2d
else:
padding_layer = nn.ZeroPad2d
if modulated_conv:
conv2d = ModulatedConv2d
else:
conv2d = EqualConv2d
if modulated_conv:
self.actvn_gain = sqrt(2)
else:
self.actvn_gain = 1.0
self.use_pixel_norm = use_pixel_norm
self.upsample = upsample
self.downsample = downsample
self.modulated_conv = modulated_conv
if actvn == 'relu':
activation = nn.ReLU(True)
else:
activation = nn.LeakyReLU(0.2,True)
        if self.upsample:
            self.upsampler = nn.Upsample(scale_factor=2)  # used by the non-modulated upsample branch below
        if self.downsample:
            self.downsampler = nn.AvgPool2d(2)
if self.modulated_conv:
self.conv0 = conv2d(fin, fout, kernel_size=3, padding_type=padding, upsample=upsample,
latent_dim=latent_dim, normalize_mlp=normalize_affine_output)
else:
conv0 = conv2d(fin, fout, kernel_size=3)
if self.upsample:
seq0 = [self.upsampler, padding_layer(1), conv0, Blur(fout)]
else:
seq0 = [padding_layer(1), conv0]
self.conv0 = nn.Sequential(*seq0)
if use_pixel_norm:
self.pxl_norm0 = PixelNorm()
self.actvn0 = activation
if self.modulated_conv:
self.conv1 = conv2d(fout, fout, kernel_size=3, padding_type=padding, downsample=downsample,
latent_dim=latent_dim, normalize_mlp=normalize_affine_output)
else:
conv1 = conv2d(fout, fout, kernel_size=3)
if self.downsample:
seq1 = [Blur(fout), padding_layer(1), conv1, self.downsampler]
else:
seq1 = [padding_layer(1), conv1]
self.conv1 = nn.Sequential(*seq1)
if use_pixel_norm:
self.pxl_norm1 = PixelNorm()
self.actvn1 = activation
def forward(self, input, latent=None):
if self.modulated_conv:
out = self.conv0(input,latent)
else:
out = self.conv0(input)
out = self.actvn0(out) * self.actvn_gain
if self.use_pixel_norm:
out = self.pxl_norm0(out)
if self.modulated_conv:
out = self.conv1(out,latent)
else:
out = self.conv1(out)
out = self.actvn1(out) * self.actvn_gain
if self.use_pixel_norm:
out = self.pxl_norm1(out)
return out
class IdentityEncoder(nn.Module):
def __init__(self, input_nc, ngf=64, n_downsampling=3, n_blocks=7,
norm_layer=PixelNorm, padding_type='reflect',
conv_weight_norm=False, actvn='relu'):
assert(n_blocks >= 0)
super(IdentityEncoder, self).__init__()
if padding_type == 'reflect':
padding_layer = nn.ReflectionPad2d
else:
padding_layer = nn.ZeroPad2d
if conv_weight_norm:
conv2d = EqualConv2d
else:
conv2d = nn.Conv2d
if actvn == 'lrelu':
activation = nn.LeakyReLU(0.2, True)
else:
activation = nn.ReLU(True)
encoder = [padding_layer(3), conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
### downsample
for i in range(n_downsampling):
mult = 2**i
encoder += [padding_layer(1),
conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=0),
norm_layer(ngf * mult * 2), activation]
### resnet blocks
mult = 2**n_downsampling
for i in range(n_blocks):
encoder += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation,
norm_layer=norm_layer, conv_weight_norm=conv_weight_norm)]
self.encoder = nn.Sequential(*encoder)
def forward(self, input):
return self.encoder(input)
class AgeEncoder(nn.Module):
def __init__(self, input_nc, ngf=64, n_downsampling=4, style_dim=50, padding_type='reflect',
conv_weight_norm=False, actvn='lrelu'):
super(AgeEncoder, self).__init__()
if padding_type == 'reflect':
padding_layer = nn.ReflectionPad2d
else:
padding_layer = nn.ZeroPad2d
if conv_weight_norm:
conv2d = EqualConv2d
else:
conv2d = nn.Conv2d
if actvn == 'lrelu':
activation = nn.LeakyReLU(0.2, True)
else:
activation = nn.ReLU(True)
encoder = [padding_layer(3), conv2d(input_nc, ngf, kernel_size=7, padding=0), activation]
### downsample
for i in range(n_downsampling):
mult = 2**i
encoder += [padding_layer(1),
conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=0),
activation]
encoder += [conv2d(ngf * mult * 2, style_dim, kernel_size=1, stride=1, padding=0)]
self.encoder = nn.Sequential(*encoder)
def forward(self, input):
features = self.encoder(input)
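        # global average pooling over the spatial dimensions: one style_dim vector per image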
latent = features.mean(dim=3).mean(dim=2)
return latent
class StyledDecoder(nn.Module):
def __init__(self, output_nc, ngf=64, style_dim=50, latent_dim=256, n_downsampling=2,
padding_type='reflect', actvn='lrelu', use_tanh=True, use_pixel_norm=False,
normalize_mlp=False, modulated_conv=False):
super(StyledDecoder, self).__init__()
if padding_type == 'reflect':
padding_layer = nn.ReflectionPad2d
else:
padding_layer = nn.ZeroPad2d
mult = 2**n_downsampling
last_upconv_out_layers = ngf * mult // 4
self.StyledConvBlock_0 = StyledConvBlock(ngf * mult, ngf * mult, latent_dim=latent_dim,
padding=padding_type, actvn=actvn,
use_pixel_norm=use_pixel_norm,
normalize_affine_output=normalize_mlp,
modulated_conv=modulated_conv)
self.StyledConvBlock_1 = StyledConvBlock(ngf * mult, ngf * mult, latent_dim=latent_dim,
padding=padding_type, actvn=actvn,
use_pixel_norm=use_pixel_norm,
normalize_affine_output=normalize_mlp,
modulated_conv=modulated_conv)
self.StyledConvBlock_2 = StyledConvBlock(ngf * mult, ngf * mult, latent_dim=latent_dim,
padding=padding_type, actvn=actvn,
use_pixel_norm=use_pixel_norm,
normalize_affine_output=normalize_mlp,
modulated_conv=modulated_conv)
self.StyledConvBlock_3 = StyledConvBlock(ngf * mult, ngf * mult, latent_dim=latent_dim,
padding=padding_type, actvn=actvn,
use_pixel_norm=use_pixel_norm,
normalize_affine_output=normalize_mlp,
modulated_conv=modulated_conv)
self.StyledConvBlock_up0 = StyledConvBlock(ngf * mult, ngf * mult // 2, latent_dim=latent_dim,
padding=padding_type, upsample=True, actvn=actvn,
use_pixel_norm=use_pixel_norm,
normalize_affine_output=normalize_mlp,
modulated_conv=modulated_conv)
self.StyledConvBlock_up1 = StyledConvBlock(ngf * mult // 2, last_upconv_out_layers, latent_dim=latent_dim,
padding=padding_type, upsample=True, actvn=actvn,
use_pixel_norm=use_pixel_norm,
normalize_affine_output=normalize_mlp,
modulated_conv=modulated_conv)
self.conv_img = nn.Sequential(EqualConv2d(last_upconv_out_layers, output_nc, 1), nn.Tanh())
self.mlp = MLP(style_dim, latent_dim, 256, 8, weight_norm=True, activation=actvn, normalize_mlp=normalize_mlp)
def forward(self, id_features, target_age=None, traverse=False, deploy=False, interp_step=0.5):
if target_age is not None:
if traverse:
alphas = torch.arange(1,0,step=-interp_step).view(-1,1).cuda()
interps = len(alphas)
orig_class_num = target_age.shape[0]
output_classes = interps * (orig_class_num - 1) + 1
temp_latent = self.mlp(target_age)
latent = temp_latent.new_zeros((output_classes, temp_latent.shape[1]))
else:
latent = self.mlp(target_age)
else:
latent = None
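        # traverse mode: linearly interpolate between the latent codes of consecutive target
        # classes (interp_step controls the spacing) to render a smooth aging sequence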
if traverse:
id_features = id_features.repeat(output_classes,1,1,1)
for i in range(orig_class_num-1):
latent[interps*i:interps*(i+1), :] = alphas * temp_latent[i,:] + (1 - alphas) * temp_latent[i+1,:]
latent[-1,:] = temp_latent[-1,:]
elif deploy:
output_classes = target_age.shape[0]
id_features = id_features.repeat(output_classes,1,1,1)
out = self.StyledConvBlock_0(id_features, latent)
out = self.StyledConvBlock_1(out, latent)
out = self.StyledConvBlock_2(out, latent)
out = self.StyledConvBlock_3(out, latent)
out = self.StyledConvBlock_up0(out, latent)
out = self.StyledConvBlock_up1(out, latent)
out = self.conv_img(out)
return out
class Generator(nn.Module):
def __init__(self, input_nc, output_nc, ngf=64, style_dim=50, n_downsampling=2,
n_blocks=4, adaptive_blocks=4, id_enc_norm=PixelNorm,
padding_type='reflect', conv_weight_norm=False,
decoder_norm='pixel', actvn='lrelu', normalize_mlp=False,
modulated_conv=False):
super(Generator, self).__init__()
self.id_encoder = IdentityEncoder(input_nc, ngf, n_downsampling, n_blocks, id_enc_norm,
padding_type, conv_weight_norm=conv_weight_norm,
actvn='relu') # replacing relu with leaky relu here causes nans and the entire training to collapse immediately
self.age_encoder = AgeEncoder(input_nc, ngf=ngf, n_downsampling=4, style_dim=style_dim,
padding_type=padding_type, actvn=actvn,
conv_weight_norm=conv_weight_norm)
use_pixel_norm = decoder_norm == 'pixel'
self.decoder = StyledDecoder(output_nc, ngf=ngf, style_dim=style_dim,
n_downsampling=n_downsampling, actvn=actvn,
use_pixel_norm=use_pixel_norm,
normalize_mlp=normalize_mlp,
modulated_conv=modulated_conv)
def encode(self, input):
if torch.is_tensor(input):
id_features = self.id_encoder(input)
age_features = self.age_encoder(input)
return id_features, age_features
else:
return None, None
def decode(self, id_features, target_age_features, traverse=False, deploy=False, interp_step=0.5):
if torch.is_tensor(id_features):
return self.decoder(id_features, target_age_features, traverse=traverse, deploy=deploy, interp_step=interp_step)
else:
return None
#parallel forward
def forward(self, input, target_age_code, cyc_age_code, source_age_code, disc_pass=False):
orig_id_features = self.id_encoder(input)
orig_age_features = self.age_encoder(input)
if disc_pass:
rec_out = None
else:
rec_out = self.decode(orig_id_features, source_age_code)
gen_out = self.decode(orig_id_features, target_age_code)
if disc_pass:
fake_id_features = None
fake_age_features = None
cyc_out = None
else:
fake_id_features = self.id_encoder(gen_out)
fake_age_features = self.age_encoder(gen_out)
cyc_out = self.decode(fake_id_features, cyc_age_code)
return rec_out, gen_out, cyc_out, orig_id_features, orig_age_features, fake_id_features, fake_age_features
def infer(self, input, target_age_features, traverse=False, deploy=False, interp_step=0.5):
id_features = self.id_encoder(input)
out = self.decode(id_features, target_age_features, traverse=traverse, deploy=deploy, interp_step=interp_step)
return out
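# Hedged usage sketch (not called anywhere in this file): a minimal single-image inference
# pass. The 256x256 resolution and the shape of the target age code are illustrative
# assumptions, not values prescribed by this file.
def _generator_infer_example():
    netG = Generator(input_nc=3, output_nc=3, ngf=64, style_dim=50)
    img = torch.randn(1, 3, 256, 256)    # hypothetical 256x256 RGB input
    age_code = torch.zeros(1, 50)
    age_code[:, :10] = 1.0               # hypothetical target age code (style_dim=50)
    with torch.no_grad():
        out = netG.infer(img, age_code)  # -> (1, 3, 256, 256)
    return out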
# Define a resnet block
class ResnetBlock(nn.Module):
def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True),
conv_weight_norm=False, use_pixel_norm=False):
super(ResnetBlock, self).__init__()
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation,
conv_weight_norm, use_pixel_norm)
def build_conv_block(self, dim, padding_type, norm_layer, activation, conv_weight_norm, use_pixel_norm):
conv_block = []
p = 0
if padding_type == 'reflect':
conv_block += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
if conv_weight_norm:
conv2d = EqualConv2d
else:
conv2d = nn.Conv2d
self.use_pixel_norm = use_pixel_norm
if self.use_pixel_norm:
self.pixel_norm = PixelNorm()
conv_block += [conv2d(dim, dim, kernel_size=3, padding=p),
norm_layer(dim),
activation]
p = 0
if padding_type == 'reflect':
conv_block += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv_block += [conv2d(dim, dim, kernel_size=3, padding=p),
norm_layer(dim)]
return nn.Sequential(*conv_block)
def forward(self, x):
out = x + self.conv_block(x)
return out
##############################################################################
# Discriminator
##############################################################################
class StyleGANDiscriminator(nn.Module):
def __init__(self, input_nc, ndf=64, n_layers=6, numClasses=2, padding_type='reflect'):
super(StyleGANDiscriminator, self).__init__()
self.n_layers = n_layers
if padding_type == 'reflect':
padding_layer = nn.ReflectionPad2d
else:
padding_layer = nn.ZeroPad2d
activation = nn.LeakyReLU(0.2,True)
sequence = [padding_layer(0), EqualConv2d(input_nc, ndf, kernel_size=1), activation]
nf = ndf
for n in range(n_layers):
nf_prev = nf
nf = min(nf * 2, 512)
sequence += [StyledConvBlock(nf_prev, nf, downsample=True, actvn=activation)]
self.model = nn.Sequential(*sequence)
output_nc = numClasses
self.gan_head = nn.Sequential(padding_layer(1), EqualConv2d(nf+1, nf, kernel_size=3), activation,
EqualConv2d(nf, output_nc, kernel_size=4), activation)
def minibatch_stdev(self, input):
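        # minibatch standard deviation trick: take the std of every feature across the batch,
        # average it to a single scalar and append it as an extra constant channel, giving the
        # discriminator a cue about sample diversity within the minibatch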
out_std = torch.sqrt(input.var(0, unbiased=False) + 1e-8)
mean_std = out_std.mean()
mean_std = mean_std.expand(input.size(0), 1, input.size(2), input.size(3))
out = torch.cat((input, mean_std), 1)
return out
def forward(self, input):
features = self.model(input)
gan_out = self.gan_head(self.minibatch_stdev(features))
return gan_out
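# Hedged usage sketch (not called anywhere in this file): discriminator output shapes
# assuming a 256x256 input and the default 6 downsampling layers; both are illustrative.
def _discriminator_example():
    netD = StyleGANDiscriminator(input_nc=3, ndf=64, n_layers=6, numClasses=2)
    img = torch.randn(2, 3, 256, 256)
    logits = netD(img)                     # -> (2, 2, 1, 1): one logit map per class
    criterion = SelectiveClassesNonSatGANLoss()
    target_classes = torch.tensor([0, 1])  # per-sample target class index
    return criterion(logits, target_classes, target_is_real=True)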