You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
102 lines
4.0 KiB
102 lines
4.0 KiB
5 months ago
|
import os
|
||
|
from tqdm import tqdm
|
||
|
|
||
|
import torch
|
||
|
from torch.utils.data import DataLoader
|
||
|
|
||
|
from frames_dataset import PairedDataset
|
||
|
from logger import Logger, Visualizer
|
||
|
import imageio
|
||
|
from scipy.spatial import ConvexHull
|
||
|
import numpy as np
|
||
|
|
||
|
from sync_batchnorm import DataParallelWithCallback
|
||
|
|
||
|
|
||
|
def _movement_scale(kp_source_value, kp_driving_initial_value):
    """Return one movement-scale factor per sample.

    Each factor is ``sqrt(area(source hull)) / sqrt(area(initial driving hull))``.
    Both tensors are split along their first (batch) axis; every slice is passed
    to ``ConvexHull``, so each slice must be a 2-D (num_points, dim) array.
    (For 2-D points, ``ConvexHull.volume`` is the enclosed area.)

    Returns:
        1-D float64 numpy array with one factor per sample.
    """
    sources = kp_source_value.detach().cpu().numpy()
    drivings = kp_driving_initial_value.detach().cpu().numpy()
    return np.array([np.sqrt(ConvexHull(src).volume) / np.sqrt(ConvexHull(drv).volume)
                     for src, drv in zip(sources, drivings)])


def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,
                 use_relative_movement=False, use_relative_jacobian=False):
    """Re-express driving keypoints relative to the source keypoints.

    Args:
        kp_source: keypoint dict for the source frame. ``'value'`` is read when
            ``use_relative_movement`` is set. ``'jacobian'`` is read when
            ``use_relative_jacobian`` is also set.
        kp_driving: keypoint dict for the current driving frame.
        kp_driving_initial: keypoint dict for the first driving frame.
        adapt_movement_scale: if true (and ``use_relative_movement`` is set),
            scale the driving displacement per sample by
            ``sqrt(source hull area) / sqrt(initial driving hull area)``.
        use_relative_movement: if true, the output ``'value'`` is
            ``kp_source['value'] + scale * (kp_driving['value'] - kp_driving_initial['value'])``.
            Otherwise the driving values are returned unchanged.
        use_relative_jacobian: only consulted when ``use_relative_movement`` is
            true. Sets the output ``'jacobian'`` to
            ``J_driving @ inv(J_driving_initial) @ J_source``.

    Returns:
        A new dict, a shallow copy of ``kp_driving``, with ``'value'`` and/or
        ``'jacobian'`` replaced as described above. The input dicts are not
        mutated.

    Note:
        Values are presumably shaped (batch, num_kp, 2) -- TODO confirm. The
        per-sample hull computation requires each batch slice to be 2-D.
    """
    # Shallow copy: untouched entries still reference kp_driving's tensors,
    # but the caller's dict itself is never modified.
    kp_new = dict(kp_driving)

    if use_relative_movement:
        kp_value_diff = kp_driving['value'] - kp_driving_initial['value']

        if adapt_movement_scale:
            # Previously only batch element 0 was measured and its factor was
            # applied to every sample; compute one factor per sample instead.
            scales = _movement_scale(kp_source['value'], kp_driving_initial['value'])
            # Broadcast one factor per sample over all non-batch dimensions.
            scale_shape = (-1,) + (1,) * (kp_value_diff.dim() - 1)
            scale = torch.as_tensor(scales, dtype=kp_value_diff.dtype,
                                    device=kp_value_diff.device).view(scale_shape)
            kp_value_diff = kp_value_diff * scale

        kp_new['value'] = kp_value_diff + kp_source['value']

        if use_relative_jacobian:
            # Transfer the driving jacobian's change since the initial frame
            # onto the source jacobian.
            jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))
            kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])

    return kp_new
|
||
|
|
||
|
|
||
|
def animate(config, generator, kp_detector, checkpoint, log_dir, dataset):
    """Animate source frames with the motion of paired driving videos.

    ``dataset`` is wrapped in ``PairedDataset`` to form (source, driving)
    pairs. For each pair, the first source frame is re-rendered by
    ``generator`` once per driving frame, using keypoints normalized by
    ``normalize_kp``. Two files are written per pair, named
    ``"<driving_name>-<source_name>"``:

    * ``<log_dir>/animation/png/<name>.png``: all predicted frames
      concatenated along the width axis.
    * ``<log_dir>/animation/<name><format>``: the per-frame ``Visualizer``
      images, saved with ``imageio.mimsave``.

    Args:
        config: configuration mapping. Reads ``config['animate_params']``
            (keys ``'num_pairs'``, ``'normalization_params'``, ``'format'``)
            and ``config['visualizer_params']``.
        generator: model called as
            ``generator(source, kp_source=..., kp_driving=...)``. Its output
            dict must contain ``'prediction'``.
        kp_detector: model mapping a batch of frames to a keypoint dict.
        checkpoint: checkpoint passed to ``Logger.load_cpk``; required.
        log_dir: base output directory; results go to ``<log_dir>/animation``.
        dataset: dataset to draw (source, driving) pairs from.

    Raises:
        AttributeError: if ``checkpoint`` is None. This exception type is kept
            for backward compatibility with existing callers.
    """
    # Fail fast, before building the dataset or touching the filesystem.
    if checkpoint is None:
        raise AttributeError("Checkpoint should be specified for mode='animate'.")

    log_dir = os.path.join(log_dir, 'animation')
    png_dir = os.path.join(log_dir, 'png')
    animate_params = config['animate_params']

    dataset = PairedDataset(initial_dataset=dataset, number_of_pairs=animate_params['num_pairs'])
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

    Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)

    # png_dir is inside log_dir, so one call creates both; exist_ok avoids
    # the exists-then-create race.
    os.makedirs(png_dir, exist_ok=True)

    if torch.cuda.is_available():
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    # Loop-invariant: built once instead of once per frame.
    # NOTE(review): assumes Visualizer keeps no state between visualize() calls.
    visualizer = Visualizer(**config['visualizer_params'])
    normalization_params = animate_params['normalization_params']

    # Iterate the DataLoader directly so tqdm can use len() to show a total.
    for x in tqdm(dataloader):
        with torch.no_grad():
            predictions = []
            visualizations = []

            driving_video = x['driving_video']
            source_frame = x['source_video'][:, :, 0, :, :]

            kp_source = kp_detector(source_frame)
            kp_driving_initial = kp_detector(driving_video[:, :, 0])

            # Axis 2 of the video tensor indexes frames.
            for frame_idx in range(driving_video.shape[2]):
                driving_frame = driving_video[:, :, frame_idx]
                kp_driving = kp_detector(driving_frame)
                kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
                                       kp_driving_initial=kp_driving_initial, **normalization_params)
                out = generator(source_frame, kp_source=kp_source, kp_driving=kp_norm)

                out['kp_driving'] = kp_driving
                out['kp_source'] = kp_source
                out['kp_norm'] = kp_norm

                # Removed before visualization; pop() tolerates its absence.
                out.pop('sparse_deformed', None)

                # First batch element, converted from channels-first to
                # channels-last for image writing.
                predictions.append(np.transpose(out['prediction'].detach().cpu().numpy(), [0, 2, 3, 1])[0])

                visualizations.append(visualizer.visualize(source=source_frame,
                                                           driving=driving_frame, out=out))

            predictions = np.concatenate(predictions, axis=1)
            result_name = "-".join([x['driving_name'][0], x['source_name'][0]])
            # Predictions are presumably in [0, 1]. Clip so small overshoots
            # cannot wrap around in the uint8 cast.
            png_image = (255 * np.clip(predictions, 0, 1)).astype(np.uint8)
            imageio.imsave(os.path.join(png_dir, result_name + '.png'), png_image)

            image_name = result_name + animate_params['format']
            imageio.mimsave(os.path.join(log_dir, image_name), visualizations)
|