You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
68 lines
2.7 KiB
68 lines
2.7 KiB
import os
|
|
from tqdm import tqdm
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
from logger import Logger, Visualizer
|
|
import numpy as np
|
|
import imageio
|
|
from sync_batchnorm import DataParallelWithCallback
|
|
|
|
|
|
def _reconstruct_video(video, generator, kp_detector, visualizer):
    """Reconstruct every frame of one video from its first frame.

    Frame 0 is used as the source; each frame in turn is used as the driving
    frame, and the generator is asked to render the source with the driving
    frame's keypoints.

    :param video: tensor indexed as ``video[:, :, frame_idx]`` with the frame
        count in ``video.shape[2]`` (presumably (batch, channels, frames, H, W)
        with batch size 1 -- TODO confirm).
    :param generator: callable as ``generator(source, kp_source=..., kp_driving=...)``
        returning a mutable mapping that contains ``'prediction'``.
    :param kp_detector: callable applied to a single frame slice.
    :param visualizer: object exposing ``visualize(source=, driving=, out=)``.
    :returns: ``(predictions, visualizations, losses)``: per-frame prediction
        arrays (axis 1 moved last, batch element 0), per-frame visualizations,
        and per-frame mean absolute error between prediction and driving frame.
    """
    predictions = []
    visualizations = []
    losses = []

    # The source frame and its keypoints are the same for every driving
    # frame, so compute them once per video.
    source = video[:, :, 0]
    kp_source = kp_detector(source)

    for frame_idx in range(video.shape[2]):
        driving = video[:, :, frame_idx]
        kp_driving = kp_detector(driving)

        out = generator(source, kp_source=kp_source, kp_driving=kp_driving)
        out['kp_source'] = kp_source
        out['kp_driving'] = kp_driving
        # Remove the sparse-deformed intermediate before visualizing; tolerate
        # generators that do not produce it instead of raising KeyError.
        out.pop('sparse_deformed', None)

        predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
        visualizations.append(visualizer.visualize(source=source, driving=driving, out=out))
        losses.append(torch.abs(out['prediction'] - driving).mean().cpu().numpy())

    return predictions, visualizations, losses


def reconstruction(config, generator, kp_detector, checkpoint, log_dir, dataset):
    """Run the reconstruction evaluation over ``dataset`` and save the outputs.

    For each video, every frame is reconstructed from frame 0. The predicted
    frames are concatenated into one PNG per video under
    ``<log_dir>/reconstruction/png``; the per-frame visualizations are saved
    under ``<log_dir>/reconstruction`` using the extension given by
    ``config['reconstruction_params']['format']``. The mean absolute
    reconstruction error over all processed frames is printed at the end.

    :param config: mapping with ``'reconstruction_params'`` (keys
        ``'num_videos'`` and ``'format'``) and ``'visualizer_params'``
        (keyword arguments for ``Visualizer``). ``num_videos=None`` means
        "process the whole dataset"; otherwise at most ``num_videos`` videos
        are processed.
    :param generator: generator model; its weights are loaded from ``checkpoint``.
    :param kp_detector: keypoint detector model; weights loaded from ``checkpoint``.
    :param checkpoint: checkpoint passed to ``Logger.load_cpk``; required.
    :param log_dir: root output directory.
    :param dataset: dataset whose items provide ``'video'`` and ``'name'``.
    :raises AttributeError: if ``checkpoint`` is None (exception type kept for
        backward compatibility with existing callers).
    """
    png_dir = os.path.join(log_dir, 'reconstruction/png')
    log_dir = os.path.join(log_dir, 'reconstruction')

    if checkpoint is None:
        raise AttributeError("Checkpoint should be specified for mode='reconstruction'.")
    Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)

    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(png_dir, exist_ok=True)

    if torch.cuda.is_available():
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    reconstruction_params = config['reconstruction_params']
    num_videos = reconstruction_params['num_videos']
    # Built once instead of once per frame.
    # NOTE(review): assumes Visualizer.visualize keeps no per-call state -- confirm.
    visualizer = Visualizer(**config['visualizer_params'])

    loss_list = []
    for it, x in tqdm(enumerate(dataloader)):
        # `it` is 0-based: stop after exactly `num_videos` videos.
        if num_videos is not None and it >= num_videos:
            break

        with torch.no_grad():
            video = x['video']
            if torch.cuda.is_available():
                video = video.cuda()
            predictions, visualizations, losses = _reconstruct_video(
                video, generator, kp_detector, visualizer)
        loss_list.extend(losses)

        name = x['name'][0]

        # Lay the predicted frames side by side along axis 1 (presumably the
        # width axis of (H, W, C) arrays). Clip before the uint8 cast so values
        # slightly outside [0, 1] cannot wrap around.
        strip = np.concatenate(predictions, axis=1)
        imageio.imsave(os.path.join(png_dir, name + '.png'),
                       (255 * np.clip(strip, 0, 1)).astype(np.uint8))

        image_name = name + reconstruction_params['format']
        imageio.mimsave(os.path.join(log_dir, image_name), visualizations)

    # np.mean([]) would print nan plus a RuntimeWarning; print nan directly.
    mean_loss = np.mean(loss_list) if loss_list else float('nan')
    print("Reconstruction loss: %s" % mean_loss)
|