#!/usr/bin/env python
# coding=utf-8
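
# Gradio demo: animate a single source image with the motion of a driving
# video. The model components (OcclusionAwareGenerator, KPDetector) and the
# vox-adv checkpoint names below match the First Order Motion Model codebase,
# from which this script appears to be adapted.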

import uuid

import gradio as gr
import matplotlib

matplotlib.use('Agg')  # headless-safe backend; set before any pyplot import
import sys
import yaml
from argparse import ArgumentParser
from tqdm import tqdm

import imageio
import numpy as np
from skimage.transform import resize
from skimage import img_as_ubyte
import torch
from sync_batchnorm import DataParallelWithCallback

from modules.generator import OcclusionAwareGenerator
from modules.keypoint_detector import KPDetector
from animate import normalize_kp
from scipy.spatial import ConvexHull

if sys.version_info[0] < 3:
    raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")


def load_checkpoints(config_path, checkpoint_path, cpu=True):
    """Build the generator and keypoint detector, then load trained weights."""
    with open(config_path) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
                                        **config['model_params']['common_params'])
    if not cpu:
        generator.cuda()

    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])
    if not cpu:
        kp_detector.cuda()

    if cpu:
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    else:
        checkpoint = torch.load(checkpoint_path)

    generator.load_state_dict(checkpoint['generator'])
    kp_detector.load_state_dict(checkpoint['kp_detector'])

    if not cpu:
        # The sync-batchnorm DataParallel wrapper is only useful on GPU
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    return generator, kp_detector
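

# A minimal usage sketch; the paths mirror the defaults used by h_interface
# below and are assumptions about this repo's layout:
#
#     generator, kp_detector = load_checkpoints(
#         config_path="./config/vox-adv-256.yaml",
#         checkpoint_path="./checkpoints/vox-adv-cpk.pth.tar",
#         cpu=True)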


def make_animation(source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True,
                   cpu=True):
    """Render one output frame per driving frame, warping source_image."""
    with torch.no_grad():
        predictions = []
        # (H, W, C) -> (1, C, H, W)
        source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
        if not cpu:
            source = source.cuda()
        # (T, H, W, C) -> (1, C, T, H, W)
        driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
        kp_source = kp_detector(source)
        kp_driving_initial = kp_detector(driving[:, :, 0])

        for frame_idx in tqdm(range(driving.shape[2])):
            driving_frame = driving[:, :, frame_idx]
            if not cpu:
                driving_frame = driving_frame.cuda()
            kp_driving = kp_detector(driving_frame)
            # Express the driving motion relative to the first driving frame
            kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
                                   kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
                                   use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale)
            out = generator(source, kp_source=kp_source, kp_driving=kp_norm)

            # (1, C, H, W) -> (H, W, C)
            predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
    return predictions
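

# Usage sketch for make_animation, with frames as float arrays in [0, 1] of
# shape (256, 256, 3), as produced by the resize(...) calls in h_interface:
#
#     frames = make_animation(source_image, driving_video, generator,
#                             kp_detector, relative=True, cpu=True)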


def find_best_frame(source, driving, cpu=False):
    """Return the index of the driving frame whose pose best matches the source."""
    import face_alignment

    # Local helper; note it shadows the normalize_kp imported from animate
    def normalize_kp(kp):
        kp = kp - kp.mean(axis=0, keepdims=True)
        # For a 2D hull, ConvexHull.volume is the enclosed area
        area = ConvexHull(kp[:, :2]).volume
        area = np.sqrt(area)
        kp[:, :2] = kp[:, :2] / area
        return kp

    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
                                      device='cpu' if cpu else 'cuda')
    kp_source = fa.get_landmarks(255 * source)[0]
    kp_source = normalize_kp(kp_source)
    norm = float('inf')
    frame_num = 0
    for i, image in tqdm(enumerate(driving)):
        kp_driving = fa.get_landmarks(255 * image)[0]
        kp_driving = normalize_kp(kp_driving)
        new_norm = (np.abs(kp_source - kp_driving) ** 2).sum()
        if new_norm < norm:
            norm = new_norm
            frame_num = i
    return frame_num
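

# find_best_frame is only invoked when opt.find_best_frame or opt.best_frame
# is set in h_interface below; it needs the optional face_alignment package
# (and a CUDA device unless cpu=True is passed).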


def h_interface(input_image: np.ndarray):
    # Build the options namespace; parse_args([]) keeps argparse from trying
    # to consume whatever CLI flags the Gradio process was started with.
    parser = ArgumentParser()
    opt = parser.parse_args([])
    opt.config = "./config/vox-adv-256.yaml"
    opt.checkpoint = "./checkpoints/vox-adv-cpk.pth.tar"
    opt.source_image = input_image
    opt.driving_video = "./data/chuck.mp4"
    # Unique name per request so concurrent calls don't overwrite each other
    opt.result_video = "./data/result-{}.mp4".format(uuid.uuid1().hex)
    opt.relative = True
    opt.adapt_scale = True
    opt.cpu = True
    opt.find_best_frame = False
    opt.best_frame = None  # was False, which made `best_frame is not None` always true

    source_image = opt.source_image
    reader = imageio.get_reader(opt.driving_video)
    fps = reader.get_meta_data()['fps']
    driving_video = []
    try:
        for im in reader:
            driving_video.append(im)
    except RuntimeError:
        pass
    reader.close()

    # The model expects 256x256 RGB frames with values in [0, 1]
    source_image = resize(source_image, (256, 256))[..., :3]
    driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
    generator, kp_detector = load_checkpoints(config_path=opt.config, checkpoint_path=opt.checkpoint, cpu=opt.cpu)

    if opt.find_best_frame or opt.best_frame is not None:
        i = opt.best_frame if opt.best_frame is not None else find_best_frame(source_image, driving_video, cpu=opt.cpu)
        print("Best frame: " + str(i))
        # Animate forward and backward from the best-matching frame, then stitch
        driving_forward = driving_video[i:]
        driving_backward = driving_video[:(i + 1)][::-1]
        predictions_forward = make_animation(source_image, driving_forward, generator, kp_detector,
                                             relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
        predictions_backward = make_animation(source_image, driving_backward, generator, kp_detector,
                                              relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
        predictions = predictions_backward[::-1] + predictions_forward[1:]
    else:
        predictions = make_animation(source_image, driving_video, generator, kp_detector, relative=opt.relative,
                                     adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
    imageio.mimsave(opt.result_video, [img_as_ubyte(frame) for frame in predictions], fps=fps)
    return opt.result_video


if __name__ == "__main__":
    demo = gr.Interface(
        fn=h_interface,
        inputs=gr.Image(type="numpy", label="Input Image"),
        outputs=gr.Video(label="Output Video"),
    )
    demo.launch()