# (web-page scrape residue -- not part of the original source file)
# "You can not select more than 25 topics"
# "Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long."
# "505 lines / 21 KiB"
### Copyright (C) 2020 Roy Or-El. All rights reserved.
|
|
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import re
|
|
import functools
|
|
from collections import OrderedDict
|
|
from .base_model import BaseModel
|
|
import util.util as util
|
|
from . import networks
|
|
from pdb import set_trace as st
|
|
|
|
class LATS(BaseModel): #Lifetime Age Transformation Synthesis
    def name(self):
        # Model identifier used by the framework for logging / checkpoint paths.
        return 'LATS'
|
|
|
|
    def initialize(self, opt):
        """Build networks, loss functions and optimizers from the parsed options.

        Called once by the framework right after construction; `opt` is the
        parsed train/test options object.
        """
        BaseModel.initialize(self, opt)

        # if opt.resize_or_crop != 'none': # when training at full res this causes OOM
        torch.backends.cudnn.benchmark = True

        # determine mode of operation [train, test, deploy, traverse (latent interpolation)]
        self.isTrain = opt.isTrain
        self.traverse = (not self.isTrain) and opt.traverse

        # mode to generate Fig. 15 in the paper
        self.compare_to_trained_outputs = (not self.isTrain) and opt.compare_to_trained_outputs
        if self.compare_to_trained_outputs:
            self.compare_to_trained_class = opt.compare_to_trained_class
            self.trained_class_jump = opt.trained_class_jump

        self.deploy = (not self.isTrain) and opt.deploy
        # seed every RNG for reproducible test-time outputs (-1 leaves them unseeded)
        if not self.isTrain and opt.random_seed != -1:
            torch.manual_seed(opt.random_seed)
            torch.cuda.manual_seed_all(opt.random_seed)
            np.random.seed(opt.random_seed)

        # network architecture parameters
        self.nb = opt.batchSize
        self.size = opt.fineSize
        self.ngf = opt.ngf
        self.ngf_global = self.ngf

        self.numClasses = opt.numClasses
        self.use_moving_avg = not opt.no_moving_avg

        self.no_cond_noise = opt.no_cond_noise
        # the condition (style) vector reserves gen_dim_per_style slots per age class
        style_dim = opt.gen_dim_per_style * self.numClasses
        self.duplicate = opt.gen_dim_per_style

        self.cond_length = style_dim

        # self.active_classes_mapping = opt.active_classes_mapping

        if not self.isTrain:
            self.debug_mode = opt.debug_mode
        else:
            self.debug_mode = False

        ##### define networks
        # Generators
        self.netG = self.parallelize(networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.n_downsample,
                                     id_enc_norm=opt.id_enc_norm, gpu_ids=self.gpu_ids, padding_type='reflect', style_dim=style_dim,
                                     init_type='kaiming', conv_weight_norm=opt.conv_weight_norm,
                                     decoder_norm=opt.decoder_norm, activation=opt.activation,
                                     adaptive_blocks=opt.n_adaptive_blocks, normalize_mlp=opt.normalize_mlp,
                                     modulated_conv=opt.use_modulated_conv))
        if self.isTrain and self.use_moving_avg:
            # g_running keeps an exponential moving average of netG's weights
            # (see accumulate()); it is only used for evaluation, never trained.
            self.g_running = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.n_downsample,
                             id_enc_norm=opt.id_enc_norm, gpu_ids=self.gpu_ids, padding_type='reflect', style_dim=style_dim,
                             init_type='kaiming', conv_weight_norm=opt.conv_weight_norm,
                             decoder_norm=opt.decoder_norm, activation=opt.activation,
                             adaptive_blocks=opt.n_adaptive_blocks, normalize_mlp=opt.normalize_mlp,
                             modulated_conv=opt.use_modulated_conv)
            self.g_running.train(False)
            self.requires_grad(self.g_running, flag=False)
            self.accumulate(self.g_running, self.netG, decay=0)  # decay=0 -> start as an exact copy of netG

        # Discriminator network
        if self.isTrain:
            self.netD = self.parallelize(networks.define_D(opt.output_nc, opt.ndf, n_layers=opt.n_layers_D,
                                         numClasses=self.numClasses, gpu_ids=self.gpu_ids,
                                         init_type='kaiming'))

        if self.opt.verbose:
            print('---------- Networks initialized -------------')

        # load networks
        if (not self.isTrain) or opt.continue_train or opt.load_pretrain:
            # empty path -> load_network falls back to the default checkpoint dir
            pretrained_path = '' if (not self.isTrain) or (self.isTrain and opt.continue_train) else opt.load_pretrain
            if self.isTrain:
                # NOTE(review): loaded under 'G_tex'/'D_tex' but save() writes
                # 'G'/'D' -- confirm checkpoint naming before relying on continue_train.
                self.load_network(self.netG, 'G_tex', opt.which_epoch, pretrained_path)
                self.load_network(self.netD, 'D_tex', opt.which_epoch, pretrained_path)
                if self.use_moving_avg:
                    self.load_network(self.g_running, 'g_running', opt.which_epoch, pretrained_path)
            elif self.use_moving_avg:
                # at test time prefer the moving-average generator weights
                self.load_network(self.netG, 'g_running', opt.which_epoch, pretrained_path)
            else:
                self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)

        # set loss functions and optimizers
        if self.isTrain:
            # define loss functions
            self.criterionGAN = self.parallelize(networks.SelectiveClassesNonSatGANLoss())
            self.R1_reg = networks.R1_reg()
            self.age_reconst_criterion = self.parallelize(networks.FeatureConsistency())
            self.identity_reconst_criterion = self.parallelize(networks.FeatureConsistency())
            self.criterionCycle = self.parallelize(networks.FeatureConsistency()) #torch.nn.L1Loss()
            self.criterionRec = self.parallelize(networks.FeatureConsistency()) #torch.nn.L1Loss()

            # initialize optimizers
            self.old_lr = opt.lr

            # set optimizer G
            paramsG = []
            params_dict_G = dict(self.netG.named_parameters())
            # set the MLP learning rate to 0.01 or the global learning rate
            for key, value in params_dict_G.items():
                decay_cond = ('decoder.mlp' in key)
                if opt.decay_adain_affine_layers:
                    decay_cond = decay_cond or ('class_std' in key) or ('class_mean' in key)
                if decay_cond:
                    # 'mult' is read back by update_learning_rate() so the
                    # reduced-LR groups keep their 0.01 ratio after every decay
                    paramsG += [{'params':[value],'lr':opt.lr * 0.01,'mult':0.01}]
                else:
                    paramsG += [{'params':[value],'lr':opt.lr}]

            self.optimizer_G = torch.optim.Adam(paramsG, lr=opt.lr, betas=(opt.beta1, opt.beta2))

            # set optimizer D
            paramsD = list(self.netD.parameters())
            self.optimizer_D = torch.optim.Adam(paramsD, lr=opt.lr, betas=(opt.beta1, opt.beta2))
|
|
|
|
|
|
def parallelize(self, model):
|
|
# parallelize a network
|
|
if self.isTrain and len(self.gpu_ids) > 0:
|
|
return networks._CustomDataParallel(model)
|
|
else:
|
|
return model
|
|
|
|
|
|
def requires_grad(self, model, flag=True):
|
|
# freeze network weights
|
|
for p in model.parameters():
|
|
p.requires_grad = flag
|
|
|
|
|
|
def accumulate(self, model1, model2, decay=0.999):
|
|
# implements exponential moving average
|
|
params1 = dict(model1.named_parameters())
|
|
params2 = dict(model2.named_parameters())
|
|
model1_parallel = isinstance(model1, nn.DataParallel)
|
|
model2_parallel = isinstance(model2, nn.DataParallel)
|
|
|
|
for k in params1.keys():
|
|
if model2_parallel and not model1_parallel:
|
|
k2 = 'module.' + k
|
|
elif model1_parallel and not model2_parallel:
|
|
k2 = re.sub('module.', '', k)
|
|
else:
|
|
k2 = k
|
|
params1[k].data.mul_(decay).add_(1 - decay, params2[k2].data)
|
|
|
|
|
|
def set_inputs(self, data, mode='train'):
|
|
# set input data to feed to the network
|
|
if mode == 'train':
|
|
real_A = data['A']
|
|
real_B = data['B']
|
|
|
|
self.class_A = data['A_class']
|
|
self.class_B = data['B_class']
|
|
|
|
self.reals = torch.cat((real_A, real_B), 0)
|
|
|
|
if len(self.gpu_ids) > 0:
|
|
self.reals = self.reals.cuda()
|
|
|
|
else:
|
|
inputs = data['Imgs']
|
|
if inputs.dim() > 4:
|
|
inputs = inputs.squeeze(0)
|
|
|
|
self.class_A = data['Classes']
|
|
if self.class_A.dim() > 1:
|
|
self.class_A = self.class_A.squeeze(0)
|
|
|
|
if torch.is_tensor(data['Valid']):
|
|
self.valid = data['Valid'].bool()
|
|
else:
|
|
self.valid = torch.ones(1, dtype=torch.bool)
|
|
|
|
if self.valid.dim() > 1:
|
|
self.valid = self.valid.squeeze(0)
|
|
|
|
if isinstance(data['Paths'][0], tuple):
|
|
self.image_paths = [path[0] for path in data['Paths']]
|
|
else:
|
|
self.image_paths = data['Paths']
|
|
|
|
self.isEmpty = False if any(self.valid) else True
|
|
if not self.isEmpty:
|
|
available_idx = torch.arange(len(self.class_A))
|
|
select_idx = torch.masked_select(available_idx, self.valid).long()
|
|
inputs = torch.index_select(inputs, 0, select_idx)
|
|
|
|
self.class_A = torch.index_select(self.class_A, 0, select_idx)
|
|
self.image_paths = [val for i, val in enumerate(self.image_paths) if self.valid[i] == 1]
|
|
|
|
self.reals = inputs
|
|
|
|
if len(self.gpu_ids) > 0:
|
|
self.reals = self.reals.cuda()
|
|
|
|
|
|
    def get_conditions(self, mode='train'):
        """Build the (optionally noisy) one-hot-block condition vectors that
        select the generator's target age class.

        Sets self.gen_conditions, and (outside traverse/deploy) also
        self.cyc_conditions and self.orig_conditions.
        """
        # number of condition rows needed for this pass
        if mode == 'train':
            nb = self.reals.shape[0] // 2  # reals holds A and B concatenated
        elif self.traverse or self.deploy:
            if self.traverse and self.compare_to_trained_outputs:
                nb = 2
            else:
                nb = self.numClasses
        else:
            nb = self.numValid

        #tex condition mapping
        condG_A_gen = self.Tensor(nb, self.cond_length)
        condG_B_gen = self.Tensor(nb, self.cond_length)
        condG_A_orig = self.Tensor(nb, self.cond_length)
        condG_B_orig = self.Tensor(nb, self.cond_length)

        # conditions are a one-hot block per class, optionally perturbed
        # by gaussian noise of std 0.2
        if self.no_cond_noise:
            noise_sigma = 0
        else:
            noise_sigma = 0.2

        for i in range(nb):
            # NOTE(review): .cuda() is called unconditionally here, so a GPU is
            # assumed (unlike set_inputs, which checks gpu_ids) -- confirm.
            condG_A_gen[i, :] = (noise_sigma * torch.randn(1, self.cond_length)).cuda()
            # +1 on the `duplicate`-wide slot block belonging to the target class
            condG_A_gen[i, self.class_B[i]*self.duplicate:(self.class_B[i] + 1)*self.duplicate] += 1
            if not (self.traverse or self.deploy):
                condG_B_gen[i, :] = (noise_sigma * torch.randn(1, self.cond_length)).cuda()
                condG_B_gen[i, self.class_A[i]*self.duplicate:(self.class_A[i] + 1)*self.duplicate] += 1

                condG_A_orig[i, :] = (noise_sigma * torch.randn(1, self.cond_length)).cuda()
                condG_A_orig[i, self.class_A[i]*self.duplicate:(self.class_A[i] + 1)*self.duplicate] += 1

                condG_B_orig[i, :] = (noise_sigma * torch.randn(1, self.cond_length)).cuda()
                condG_B_orig[i, self.class_B[i]*self.duplicate:(self.class_B[i] + 1)*self.duplicate] += 1

        if mode == 'train':
            self.gen_conditions = torch.cat((condG_A_gen, condG_B_gen), 0) #torch.cat((self.class_B, self.class_A), 0)
            # if the results are not good this might be the issue!!!! uncomment and update code respectively
            self.cyc_conditions = torch.cat((condG_B_gen, condG_A_gen), 0)
            self.orig_conditions = torch.cat((condG_A_orig, condG_B_orig),0)
        else:
            self.gen_conditions = condG_A_gen #self.class_B
            if not (self.traverse or self.deploy):
                # if the results are not good this might be the issue!!!! uncomment and update code respectively
                self.cyc_conditions = condG_B_gen #self.class_A
                self.orig_conditions = condG_A_orig
|
|
|
|
|
|
def update_G(self, infer=False):
|
|
# Generator optimization setp
|
|
self.optimizer_G.zero_grad()
|
|
self.get_conditions()
|
|
|
|
############### multi GPU ###############
|
|
rec_images, gen_images, cyc_images, orig_id_features, \
|
|
orig_age_features, fake_id_features, fake_age_features = \
|
|
self.netG(self.reals, self.gen_conditions, self.cyc_conditions, self.orig_conditions)
|
|
|
|
#discriminator pass
|
|
disc_out = self.netD(gen_images)
|
|
|
|
#self-reconstruction loss
|
|
if self.opt.lambda_rec > 0:
|
|
loss_G_Rec = self.criterionRec(rec_images, self.reals) * self.opt.lambda_rec
|
|
else:
|
|
loss_G_Rec = torch.zeros(1).cuda()
|
|
|
|
#cycle loss
|
|
if self.opt.lambda_cyc > 0:
|
|
loss_G_Cycle = self.criterionCycle(cyc_images, self.reals) * self.opt.lambda_cyc
|
|
else:
|
|
loss_G_Cycle = torch.zeros(1).cuda()
|
|
|
|
# identity feature loss
|
|
loss_G_identity_reconst = self.identity_reconst_criterion(fake_id_features, orig_id_features) * self.opt.lambda_id
|
|
# age feature loss
|
|
loss_G_age_reconst = self.age_reconst_criterion(fake_age_features, self.gen_conditions) * self.opt.lambda_age
|
|
# orig age feature loss
|
|
loss_G_age_reconst += self.age_reconst_criterion(orig_age_features, self.orig_conditions) * self.opt.lambda_age
|
|
|
|
# adversarial loss
|
|
target_classes = torch.cat((self.class_B,self.class_A),0)
|
|
loss_G_GAN = self.criterionGAN(disc_out, target_classes, True, is_gen=True)
|
|
|
|
# overall loss
|
|
loss_G = (loss_G_GAN + loss_G_Rec + loss_G_Cycle + \
|
|
loss_G_identity_reconst + loss_G_age_reconst).mean()
|
|
|
|
loss_G.backward()
|
|
self.optimizer_G.step()
|
|
|
|
# update exponential moving average
|
|
if self.use_moving_avg:
|
|
self.accumulate(self.g_running, self.netG)
|
|
|
|
# generate images for visdom
|
|
if infer:
|
|
if self.use_moving_avg:
|
|
with torch.no_grad():
|
|
orig_id_features_out, _ = self.g_running.encode(self.reals)
|
|
#within domain decode
|
|
if self.opt.lambda_rec > 0:
|
|
rec_images_out = self.g_running.decode(orig_id_features_out, self.orig_conditions)
|
|
|
|
#cross domain decode
|
|
gen_images_out = self.g_running.decode(orig_id_features_out, self.gen_conditions)
|
|
#encode generated
|
|
fake_id_features_out, _ = self.g_running.encode(gen_images)
|
|
#decode generated
|
|
if self.opt.lambda_cyc > 0:
|
|
cyc_images_out = self.g_running.decode(fake_id_features_out, self.cyc_conditions)
|
|
else:
|
|
gen_images_out = gen_images
|
|
if self.opt.lambda_rec > 0:
|
|
rec_images_out = rec_images
|
|
if self.opt.lambda_cyc > 0:
|
|
cyc_images_out = cyc_images
|
|
|
|
loss_dict = {'loss_G_Adv': loss_G_GAN.mean(), 'loss_G_Cycle': loss_G_Cycle.mean(),
|
|
'loss_G_Rec': loss_G_Rec.mean(), 'loss_G_identity_reconst': loss_G_identity_reconst.mean(),
|
|
'loss_G_age_reconst': loss_G_age_reconst.mean()}
|
|
|
|
return [loss_dict,
|
|
None if not infer else self.reals,
|
|
None if not infer else gen_images_out,
|
|
None if not infer else rec_images_out,
|
|
None if not infer else cyc_images_out]
|
|
|
|
|
|
    def update_D(self):
        """Discriminator optimization step (non-saturating GAN loss + R1 reg).

        Returns a dict of the mean real / fake / regularization losses.
        """
        self.optimizer_D.zero_grad()
        self.get_conditions()

        ############### multi GPU ###############
        # disc_pass=True lets the generator skip branches not needed here
        _, gen_images, _, _, _, _, _ = self.netG(self.reals, self.gen_conditions, None, None, disc_pass=True)

        #fake discriminator pass
        fake_disc_in = gen_images.detach()  # don't backprop into the generator
        fake_disc_out = self.netD(fake_disc_in)

        #real discriminator pass
        real_disc_in = self.reals

        # necessary for R1 regularization
        real_disc_in.requires_grad_()

        real_disc_out = self.netD(real_disc_in)

        #Fake GAN loss
        # generated images were conditioned on the swapped classes (B then A)
        fake_target_classes = torch.cat((self.class_B,self.class_A),0)
        loss_D_fake = self.criterionGAN(fake_disc_out, fake_target_classes, False, is_gen=False)

        #Real GAN loss
        real_target_classes = torch.cat((self.class_A,self.class_B),0)
        loss_D_real = self.criterionGAN(real_disc_out, real_target_classes, True, is_gen=False)

        # R1 regularization
        loss_D_reg = self.R1_reg(real_disc_out, real_disc_in)

        loss_D = (loss_D_fake + loss_D_real + loss_D_reg).mean()
        loss_D.backward()
        self.optimizer_D.step()

        return {'loss_D_real': loss_D_real.mean(), 'loss_D_fake': loss_D_fake.mean(), 'loss_D_reg': loss_D_reg.mean()}
|
|
|
|
|
|
    def inference(self, data):
        """Test-time forward pass.

        Translates each valid input image to every target age class (or along
        a latent traversal / deploy path) and returns the visuals produced by
        get_visuals(). Returns None if no sample in `data` is valid.
        """
        self.set_inputs(data, mode='test')
        if self.isEmpty:
            return

        self.numValid = self.valid.sum().item()
        sz = self.reals.size()
        # outputs are indexed as [target_class, sample, C, H, W]
        self.fake_B = self.Tensor(self.numClasses, sz[0], sz[1], sz[2], sz[3])
        self.cyc_A = self.Tensor(self.numClasses, sz[0], sz[1], sz[2], sz[3])

        with torch.no_grad():
            if self.traverse or self.deploy:
                if self.traverse and self.compare_to_trained_outputs:
                    start = self.compare_to_trained_class - self.trained_class_jump
                    end = start + (self.trained_class_jump * 2) * 2 #arange is between [start, end), end is always omitted
                    self.class_B = torch.arange(start, end, step=self.trained_class_jump*2, dtype=self.class_A.dtype)
                else:
                    self.class_B = torch.arange(self.numClasses, dtype=self.class_A.dtype)

                self.get_conditions(mode='test')

                self.fake_B = self.netG.infer(self.reals, self.gen_conditions, traverse=self.traverse, deploy=self.deploy, interp_step=self.opt.interp_step)
            else:
                # translate every sample to each of the numClasses target classes
                for i in range(self.numClasses):
                    self.class_B = self.Tensor(self.numValid).long().fill_(i)
                    self.get_conditions(mode='test')

                    # when called during training, evaluate with the EMA generator
                    if self.isTrain:
                        self.fake_B[i, :, :, :, :] = self.g_running.infer(self.reals, self.gen_conditions)
                    else:
                        self.fake_B[i, :, :, :, :] = self.netG.infer(self.reals, self.gen_conditions)

                    # cycle the translated images back to the source class
                    # (only shown in the visuals when debug_mode is on)
                    cyc_input = self.fake_B[i, :, :, :, :]

                    if self.isTrain:
                        self.cyc_A[i, :, :, :, :] = self.g_running.infer(cyc_input, self.cyc_conditions)
                    else:
                        self.cyc_A[i, :, :, :, :] = self.netG.infer(cyc_input, self.cyc_conditions)

        visuals = self.get_visuals()

        return visuals
|
|
|
|
|
|
    def save(self, which_epoch):
        """Serialize the network weights for epoch/label `which_epoch`."""
        # NOTE(review): saved under 'G'/'D', but initialize() reloads training
        # checkpoints as 'G_tex'/'D_tex' -- confirm naming before relying on
        # continue_train.
        self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
        self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
        if self.use_moving_avg:
            self.save_network(self.g_running, 'g_running', which_epoch, self.gpu_ids)
|
|
|
|
|
|
def update_learning_rate(self):
|
|
lr = self.old_lr * self.opt.decay_gamma
|
|
for param_group in self.optimizer_D.param_groups:
|
|
param_group['lr'] = lr
|
|
for param_group in self.optimizer_G.param_groups:
|
|
mult = param_group.get('mult', 1.0)
|
|
param_group['lr'] = lr * mult
|
|
if self.opt.verbose:
|
|
print('update learning rate: %f -> %f' % (self.old_lr, lr))
|
|
self.old_lr = lr
|
|
|
|
|
|
    def get_visuals(self):
        """Assemble one OrderedDict of numpy images per valid sample.

        Each dict maps 'orig_img...' / 'tex_trans_to_class_j' (and, in debug
        mode, 'tex_rec_from_class_j') keys to HxWx3 images.
        """
        return_dicts = [OrderedDict() for i in range(self.numValid)]

        real_A = util.tensor2im(self.reals.data)
        fake_B_tex = util.tensor2im(self.fake_B.data)

        if self.debug_mode:
            rec_A_tex = util.tensor2im(self.cyc_A.data[:,:,:,:,:])

        # NOTE(review): presumably tensor2im drops the batch axis for a single
        # image, hence the re-expansion -- confirm against util.tensor2im
        if self.numValid == 1:
            real_A = np.expand_dims(real_A, axis=0)

        for i in range(self.numValid):
            # get the original image and the results for the current samples
            curr_real_A = real_A[i, :, :, :]
            real_A_img = curr_real_A[:, :, :3]

            # start with age progression/regression images
            if self.traverse or self.deploy:
                curr_fake_B_tex = fake_B_tex
                orig_dict = OrderedDict([('orig_img', real_A_img)])
            else:
                curr_fake_B_tex = fake_B_tex[:, i, :, :, :]
                orig_dict = OrderedDict([('orig_img_cls_' + str(self.class_A[i].item()), real_A_img)])

            return_dicts[i].update(orig_dict)

            # set output classes number
            if self.traverse:
                out_classes = curr_fake_B_tex.shape[0]
            else:
                out_classes = self.numClasses

            for j in range(out_classes):
                fake_res_tex = curr_fake_B_tex[j, :, :, :3]
                fake_dict_tex = OrderedDict([('tex_trans_to_class_' + str(j), fake_res_tex)])
                return_dicts[i].update(fake_dict_tex)

            if not (self.traverse or self.deploy):
                if self.debug_mode:
                    # continue with tex reconstructions
                    curr_rec_A_tex = rec_A_tex[:, i, :, :, :]
                    orig_dict = OrderedDict([('orig_img2', real_A_img)])
                    return_dicts[i].update(orig_dict)
                    for j in range(self.numClasses):
                        rec_res_tex = curr_rec_A_tex[j, :, :, :3]
                        rec_dict_tex = OrderedDict([('tex_rec_from_class_' + str(j), rec_res_tex)])
                        return_dicts[i].update(rec_dict_tex)

        return return_dicts
|
|
|
|
|
|
class InferenceModel(LATS):
    # Thin wrapper so the framework's standard forward() call runs
    # test-time inference (see LATS.inference / LATS.set_inputs mode='test').
    def forward(self, data):
        return self.inference(data)
|