first commit

main
Finch02567 5 months ago
parent 3b3e7008db
commit 7ea57dcc86

33
.gitignore vendored

@ -1,21 +1,18 @@
# ---> Vim # Build and Release Folders
# Swap bin-debug/
[._]*.s[a-v][a-z] bin-release/
!*.svg # comment out if you don't need vector files [Oo]bj/
[._]*.sw[a-p] [Bb]in/
[._]s[a-rt-v][a-z]
[._]ss[a-gi-z]
[._]sw[a-p]
# Session # Other files and folders
Session.vim .settings/
Sessionx.vim
# Temporary # Executables
.netrwhist *.swf
*~ *.air
# Auto-generated tag files *.ipa
tags *.apk
# Persistent undo
[._]*.un~
# Project files, i.e. `.project`, `.actionScriptProperties` and `.flexProperties`
# should NOT be excluded as they contain compiler settings and other important
# information for Eclipse / Flash Builder.

@ -1,2 +1,33 @@
# DigitImage # DigitalImages
#### 介绍
这是数字图像处理后端项目
#### 软件架构
Flask框架搭建
#### 安装教程
1. 导包 参照requirements.txt
2. 下载vox-adv-cpk.pth.tar文件 链接https://pan.baidu.com/s/1_HUbCt7TZO_k8Kp19o8oKw 提取码6beh
3. 在firstordermodel目录下新建checkpoints文件夹将下载好的.tar文件放于此处
4. 在导好包的环境中输入 python app.py 启动后端
#### 参与贡献
1. Fork 本仓库
2. 新建 Feat_xxx 分支
3. 提交代码
4. 新建 Pull Request
#### 特技
1. 使用 Readme\_XXX.md 来支持不同的语言,例如 Readme\_en.md, Readme\_zh.md
2. Gitee 官方博客 [blog.gitee.com](https://blog.gitee.com)
3. 你可以 [https://gitee.com/explore](https://gitee.com/explore) 这个地址来了解 Gitee 上的优秀开源项目
4. [GVP](https://gitee.com/gvp) 全称是 Gitee 最有价值开源项目,是综合评定出的优秀开源项目
5. Gitee 官方提供的使用手册 [https://gitee.com/help](https://gitee.com/help)
6. Gitee 封面人物是一档用来展示 Gitee 会员风采的栏目 [https://gitee.com/gitee-stars/](https://gitee.com/gitee-stars/)

172
app.py

@ -0,0 +1,172 @@
from flask import Flask, request, send_file, jsonify, send_from_directory
from flask_cors import CORS
from PIL import Image, ImageFilter, ImageEnhance, ImageOps
import io
import uuid
import os
import sys
import yaml
import torch
import imageio
import imageio_ffmpeg as ffmpeg
import numpy as np
from moviepy.editor import VideoFileClip
from skimage.transform import resize
from skimage import img_as_ubyte
from tqdm import tqdm
sys.path.append(os.path.abspath('./firstordermodel'))
sys.path.append(".")
from modules.generator import OcclusionAwareGenerator
from modules.keypoint_detector import KPDetector
from animate import normalize_kp
app = Flask(__name__)
CORS(app)
@app.route('/process-image', methods=['POST'])
def process_image():
file = request.files['file']
operation = request.form['operation']
parameter = request.form['parameter']
image = Image.open(file.stream)
if operation == 'rotate':
image = image.rotate(float(parameter))
elif operation == 'flip':
image = ImageOps.flip(image)
elif operation == 'scale':
scale_factor = float(parameter)
w, h = image.size
new_width = int(w*scale_factor)
new_height = int(h*scale_factor)
image = image.resize((new_width,new_height), Image.ANTIALIAS)
# blank=(new_width-new_height)*scale_factor
# image=image.crop((0,-blank,new_width,new_width-blank))
elif operation == 'filter':
if parameter == 'BLUR':
image = image.filter(ImageFilter.BLUR)
elif parameter == 'EMBOSS':
image = image.filter(ImageFilter.EMBOSS)
elif parameter == 'CONTOUR':
image = image.filter(ImageFilter.CONTOUR)
elif parameter == 'SHARPEN':
image = image.filter(ImageFilter.SHARPEN)
elif operation == 'color_adjust':
r, g, b = map(float, parameter.split(','))
r = r if r else 1.0
g = g if g else 1.0
b = b if b else 1.0
r_channel, g_channel, b_channel = image.split()
r_channel = r_channel.point(lambda i: i * r)
g_channel = g_channel.point(lambda i: i * g)
b_channel = b_channel.point(lambda i: i * b)
image = Image.merge('RGB', (r_channel, g_channel, b_channel))
elif operation == 'contrast':
enhancer = ImageEnhance.Contrast(image)
image = enhancer.enhance(2)
elif operation == 'smooth':
image = image.filter(ImageFilter.SMOOTH)
img_io = io.BytesIO()
image.save(img_io, 'JPEG')
img_io.seek(0)
return send_file(img_io, mimetype='image/jpeg')
def load_checkpoints(config_path, checkpoint_path, cpu=True):
with open(config_path) as f:
config = yaml.load(f, Loader=yaml.FullLoader)
generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
**config['model_params']['common_params'])
if not cpu:
generator.cuda()
kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
**config['model_params']['common_params'])
if not cpu:
kp_detector.cuda()
if cpu:
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
else:
checkpoint = torch.load(checkpoint_path)
generator.load_state_dict(checkpoint['generator'])
kp_detector.load_state_dict(checkpoint['kp_detector'])
if not cpu:
generator = DataParallelWithCallback(generator)
kp_detector = DataParallelWithCallback(kp_detector)
generator.eval()
kp_detector.eval()
return generator, kp_detector
def make_animation(source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=True):
with torch.no_grad():
predictions = []
source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
if not cpu:
source = source.cuda()
driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
kp_source = kp_detector(source)
kp_driving_initial = kp_detector(driving[:, :, 0])
for frame_idx in tqdm(range(driving.shape[2])):
driving_frame = driving[:, :, frame_idx]
if not cpu:
driving_frame = driving_frame.cuda()
kp_driving = kp_detector(driving_frame)
kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale)
out = generator(source, kp_source=kp_source, kp_driving=kp_norm)
predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
return predictions
@app.route('/motion-drive', methods=['POST'])
def motion_drive():
image_file = request.files['image']
video_file = request.files['video']
source_image = imageio.imread(image_file)
# 保存视频文件到临时路径
video_path = f"./data/{uuid.uuid1().hex}.mp4"
video_file.save(video_path)
reader = imageio.get_reader(video_path, 'ffmpeg')
fps = reader.get_meta_data()['fps']
driving_video = []
try:
for im in reader:
driving_video.append(im)
except RuntimeError:
pass
reader.close()
source_image = resize(source_image, (256, 256))[..., :3]
driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
config_path = "./firstordermodel/config/vox-adv-256.yaml"
checkpoint_path = "./firstordermodel/checkpoints/vox-adv-cpk.pth.tar"
generator, kp_detector = load_checkpoints(config_path=config_path, checkpoint_path=checkpoint_path, cpu=True)
predictions = make_animation(source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=True)
result_filename = f"result_{uuid.uuid1().hex}.mp4"
result_path = os.path.join('data', result_filename)
imageio.mimsave(result_path, [img_as_ubyte(frame) for frame in predictions], fps=fps)
return jsonify({"video_url": f"/data/{result_filename}"})
@app.route('/data/<path:filename>', methods=['GET'])
def download_file(filename):
return send_from_directory('data', filename)
if __name__ == "__main__":
app.run(debug=True)

@ -0,0 +1,3 @@
/venv
.git
__pycache__

@ -0,0 +1,3 @@
/.vscode
__pycache__
/venv

@ -0,0 +1,13 @@
FROM nvcr.io/nvidia/pytorch:21.02-py3
RUN DEBIAN_FRONTEND=noninteractive apt-get -qq update \
&& DEBIAN_FRONTEND=noninteractive apt-get -qqy install python3-pip ffmpeg git less nano libsm6 libxext6 libxrender-dev \
&& rm -rf /var/lib/apt/lists/*
COPY . /app/
WORKDIR /app
RUN pip3 install --upgrade pip
RUN pip3 install \
git+https://github.com/1adrianb/face-alignment \
-r requirements.txt

@ -0,0 +1,101 @@
import os
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from frames_dataset import PairedDataset
from logger import Logger, Visualizer
import imageio
from scipy.spatial import ConvexHull
import numpy as np
from sync_batchnorm import DataParallelWithCallback
def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,
use_relative_movement=False, use_relative_jacobian=False):
if adapt_movement_scale:
source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume
driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume
adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
else:
adapt_movement_scale = 1
kp_new = {k: v for k, v in kp_driving.items()}
if use_relative_movement:
kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])
kp_value_diff *= adapt_movement_scale
kp_new['value'] = kp_value_diff + kp_source['value']
if use_relative_jacobian:
jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))
kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])
return kp_new
def animate(config, generator, kp_detector, checkpoint, log_dir, dataset):
log_dir = os.path.join(log_dir, 'animation')
png_dir = os.path.join(log_dir, 'png')
animate_params = config['animate_params']
dataset = PairedDataset(initial_dataset=dataset, number_of_pairs=animate_params['num_pairs'])
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
if checkpoint is not None:
Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)
else:
raise AttributeError("Checkpoint should be specified for mode='animate'.")
if not os.path.exists(log_dir):
os.makedirs(log_dir)
if not os.path.exists(png_dir):
os.makedirs(png_dir)
if torch.cuda.is_available():
generator = DataParallelWithCallback(generator)
kp_detector = DataParallelWithCallback(kp_detector)
generator.eval()
kp_detector.eval()
for it, x in tqdm(enumerate(dataloader)):
with torch.no_grad():
predictions = []
visualizations = []
driving_video = x['driving_video']
source_frame = x['source_video'][:, :, 0, :, :]
kp_source = kp_detector(source_frame)
kp_driving_initial = kp_detector(driving_video[:, :, 0])
for frame_idx in range(driving_video.shape[2]):
driving_frame = driving_video[:, :, frame_idx]
kp_driving = kp_detector(driving_frame)
kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
kp_driving_initial=kp_driving_initial, **animate_params['normalization_params'])
out = generator(source_frame, kp_source=kp_source, kp_driving=kp_norm)
out['kp_driving'] = kp_driving
out['kp_source'] = kp_source
out['kp_norm'] = kp_norm
del out['sparse_deformed']
predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
visualization = Visualizer(**config['visualizer_params']).visualize(source=source_frame,
driving=driving_frame, out=out)
visualization = visualization
visualizations.append(visualization)
predictions = np.concatenate(predictions, axis=1)
result_name = "-".join([x['driving_name'][0], x['source_name'][0]])
imageio.imsave(os.path.join(png_dir, result_name + '.png'), (255 * predictions).astype(np.uint8))
image_name = result_name + animate_params['format']
imageio.mimsave(os.path.join(log_dir, image_name), visualizations)

@ -0,0 +1,345 @@
"""
Code from https://github.com/hassony2/torch_videovision
"""
import numbers
import random
import numpy as np
import PIL
from skimage.transform import resize, rotate
from numpy import pad
import torchvision
import warnings
from skimage import img_as_ubyte, img_as_float
def crop_clip(clip, min_h, min_w, h, w):
if isinstance(clip[0], np.ndarray):
cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]
elif isinstance(clip[0], PIL.Image.Image):
cropped = [
img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
]
else:
raise TypeError('Expected numpy.ndarray or PIL.Image' +
'but got list of {0}'.format(type(clip[0])))
return cropped
def pad_clip(clip, h, w):
im_h, im_w = clip[0].shape[:2]
pad_h = (0, 0) if h < im_h else ((h - im_h) // 2, (h - im_h + 1) // 2)
pad_w = (0, 0) if w < im_w else ((w - im_w) // 2, (w - im_w + 1) // 2)
return pad(clip, ((0, 0), pad_h, pad_w, (0, 0)), mode='edge')
def resize_clip(clip, size, interpolation='bilinear'):
if isinstance(clip[0], np.ndarray):
if isinstance(size, numbers.Number):
im_h, im_w, im_c = clip[0].shape
# Min spatial dim already matches minimal size
if (im_w <= im_h and im_w == size) or (im_h <= im_w
and im_h == size):
return clip
new_h, new_w = get_resize_sizes(im_h, im_w, size)
size = (new_w, new_h)
else:
size = size[1], size[0]
scaled = [
resize(img, size, order=1 if interpolation == 'bilinear' else 0, preserve_range=True,
mode='constant', anti_aliasing=True) for img in clip
]
elif isinstance(clip[0], PIL.Image.Image):
if isinstance(size, numbers.Number):
im_w, im_h = clip[0].size
# Min spatial dim already matches minimal size
if (im_w <= im_h and im_w == size) or (im_h <= im_w
and im_h == size):
return clip
new_h, new_w = get_resize_sizes(im_h, im_w, size)
size = (new_w, new_h)
else:
size = size[1], size[0]
if interpolation == 'bilinear':
pil_inter = PIL.Image.NEAREST
else:
pil_inter = PIL.Image.BILINEAR
scaled = [img.resize(size, pil_inter) for img in clip]
else:
raise TypeError('Expected numpy.ndarray or PIL.Image' +
'but got list of {0}'.format(type(clip[0])))
return scaled
def get_resize_sizes(im_h, im_w, size):
if im_w < im_h:
ow = size
oh = int(size * im_h / im_w)
else:
oh = size
ow = int(size * im_w / im_h)
return oh, ow
class RandomFlip(object):
def __init__(self, time_flip=False, horizontal_flip=False):
self.time_flip = time_flip
self.horizontal_flip = horizontal_flip
def __call__(self, clip):
if random.random() < 0.5 and self.time_flip:
return clip[::-1]
if random.random() < 0.5 and self.horizontal_flip:
return [np.fliplr(img) for img in clip]
return clip
class RandomResize(object):
"""Resizes a list of (H x W x C) numpy.ndarray to the final size
The larger the original image is, the more times it takes to
interpolate
Args:
interpolation (str): Can be one of 'nearest', 'bilinear'
defaults to nearest
size (tuple): (widht, height)
"""
def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'):
self.ratio = ratio
self.interpolation = interpolation
def __call__(self, clip):
scaling_factor = random.uniform(self.ratio[0], self.ratio[1])
if isinstance(clip[0], np.ndarray):
im_h, im_w, im_c = clip[0].shape
elif isinstance(clip[0], PIL.Image.Image):
im_w, im_h = clip[0].size
new_w = int(im_w * scaling_factor)
new_h = int(im_h * scaling_factor)
new_size = (new_w, new_h)
resized = resize_clip(
clip, new_size, interpolation=self.interpolation)
return resized
class RandomCrop(object):
"""Extract random crop at the same location for a list of videos
Args:
size (sequence or int): Desired output size for the
crop in format (h, w)
"""
def __init__(self, size):
if isinstance(size, numbers.Number):
size = (size, size)
self.size = size
def __call__(self, clip):
"""
Args:
img (PIL.Image or numpy.ndarray): List of videos to be cropped
in format (h, w, c) in numpy.ndarray
Returns:
PIL.Image or numpy.ndarray: Cropped list of videos
"""
h, w = self.size
if isinstance(clip[0], np.ndarray):
im_h, im_w, im_c = clip[0].shape
elif isinstance(clip[0], PIL.Image.Image):
im_w, im_h = clip[0].size
else:
raise TypeError('Expected numpy.ndarray or PIL.Image' +
'but got list of {0}'.format(type(clip[0])))
clip = pad_clip(clip, h, w)
im_h, im_w = clip.shape[1:3]
x1 = 0 if h == im_h else random.randint(0, im_w - w)
y1 = 0 if w == im_w else random.randint(0, im_h - h)
cropped = crop_clip(clip, y1, x1, h, w)
return cropped
class RandomRotation(object):
"""Rotate entire clip randomly by a random angle within
given bounds
Args:
degrees (sequence or int): Range of degrees to select from
If degrees is a number instead of sequence like (min, max),
the range of degrees, will be (-degrees, +degrees).
"""
def __init__(self, degrees):
if isinstance(degrees, numbers.Number):
if degrees < 0:
raise ValueError('If degrees is a single number,'
'must be positive')
degrees = (-degrees, degrees)
else:
if len(degrees) != 2:
raise ValueError('If degrees is a sequence,'
'it must be of len 2.')
self.degrees = degrees
def __call__(self, clip):
"""
Args:
img (PIL.Image or numpy.ndarray): List of videos to be cropped
in format (h, w, c) in numpy.ndarray
Returns:
PIL.Image or numpy.ndarray: Cropped list of videos
"""
angle = random.uniform(self.degrees[0], self.degrees[1])
if isinstance(clip[0], np.ndarray):
rotated = [rotate(image=img, angle=angle, preserve_range=True) for img in clip]
elif isinstance(clip[0], PIL.Image.Image):
rotated = [img.rotate(angle) for img in clip]
else:
raise TypeError('Expected numpy.ndarray or PIL.Image' +
'but got list of {0}'.format(type(clip[0])))
return rotated
class ColorJitter(object):
"""Randomly change the brightness, contrast and saturation and hue of the clip
Args:
brightness (float): How much to jitter brightness. brightness_factor
is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
contrast (float): How much to jitter contrast. contrast_factor
is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
saturation (float): How much to jitter saturation. saturation_factor
is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
hue(float): How much to jitter hue. hue_factor is chosen uniformly from
[-hue, hue]. Should be >=0 and <= 0.5.
"""
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
self.brightness = brightness
self.contrast = contrast
self.saturation = saturation
self.hue = hue
def get_params(self, brightness, contrast, saturation, hue):
if brightness > 0:
brightness_factor = random.uniform(
max(0, 1 - brightness), 1 + brightness)
else:
brightness_factor = None
if contrast > 0:
contrast_factor = random.uniform(
max(0, 1 - contrast), 1 + contrast)
else:
contrast_factor = None
if saturation > 0:
saturation_factor = random.uniform(
max(0, 1 - saturation), 1 + saturation)
else:
saturation_factor = None
if hue > 0:
hue_factor = random.uniform(-hue, hue)
else:
hue_factor = None
return brightness_factor, contrast_factor, saturation_factor, hue_factor
def __call__(self, clip):
"""
Args:
clip (list): list of PIL.Image
Returns:
list PIL.Image : list of transformed PIL.Image
"""
if isinstance(clip[0], np.ndarray):
brightness, contrast, saturation, hue = self.get_params(
self.brightness, self.contrast, self.saturation, self.hue)
# Create img transform function sequence
img_transforms = []
if brightness is not None:
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
if saturation is not None:
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
if hue is not None:
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
if contrast is not None:
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
random.shuffle(img_transforms)
img_transforms = [img_as_ubyte, torchvision.transforms.ToPILImage()] + img_transforms + [np.array,
img_as_float]
with warnings.catch_warnings():
warnings.simplefilter("ignore")
jittered_clip = []
for img in clip:
jittered_img = img
for func in img_transforms:
jittered_img = func(jittered_img)
jittered_clip.append(jittered_img.astype('float32'))
elif isinstance(clip[0], PIL.Image.Image):
brightness, contrast, saturation, hue = self.get_params(
self.brightness, self.contrast, self.saturation, self.hue)
# Create img transform function sequence
img_transforms = []
if brightness is not None:
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
if saturation is not None:
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
if hue is not None:
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
if contrast is not None:
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
random.shuffle(img_transforms)
# Apply to all videos
jittered_clip = []
for img in clip:
for func in img_transforms:
jittered_img = func(img)
jittered_clip.append(jittered_img)
else:
raise TypeError('Expected numpy.ndarray or PIL.Image' +
'but got list of {0}'.format(type(clip[0])))
return jittered_clip
class AllAugmentationTransform:
def __init__(self, resize_param=None, rotation_param=None, flip_param=None, crop_param=None, jitter_param=None):
self.transforms = []
if flip_param is not None:
self.transforms.append(RandomFlip(**flip_param))
if rotation_param is not None:
self.transforms.append(RandomRotation(**rotation_param))
if resize_param is not None:
self.transforms.append(RandomResize(**resize_param))
if crop_param is not None:
self.transforms.append(RandomCrop(**crop_param))
if jitter_param is not None:
self.transforms.append(ColorJitter(**jitter_param))
def __call__(self, clip):
for t in self.transforms:
clip = t(clip)
return clip

@ -0,0 +1,82 @@
dataset_params:
root_dir: data/bair
frame_shape: [256, 256, 3]
id_sampling: False
augmentation_params:
flip_param:
horizontal_flip: True
time_flip: True
jitter_param:
brightness: 0.1
contrast: 0.1
saturation: 0.1
hue: 0.1
model_params:
common_params:
num_kp: 10
num_channels: 3
estimate_jacobian: True
kp_detector_params:
temperature: 0.1
block_expansion: 32
max_features: 1024
scale_factor: 0.25
num_blocks: 5
generator_params:
block_expansion: 64
max_features: 512
num_down_blocks: 2
num_bottleneck_blocks: 6
estimate_occlusion_map: True
dense_motion_params:
block_expansion: 64
max_features: 1024
num_blocks: 5
scale_factor: 0.25
discriminator_params:
scales: [1]
block_expansion: 32
max_features: 512
num_blocks: 4
sn: True
train_params:
num_epochs: 20
num_repeats: 1
epoch_milestones: [12, 18]
lr_generator: 2.0e-4
lr_discriminator: 2.0e-4
lr_kp_detector: 2.0e-4
batch_size: 36
scales: [1, 0.5, 0.25, 0.125]
checkpoint_freq: 10
transform_params:
sigma_affine: 0.05
sigma_tps: 0.005
points_tps: 5
loss_weights:
generator_gan: 1
discriminator_gan: 1
feature_matching: [10, 10, 10, 10]
perceptual: [10, 10, 10, 10, 10]
equivariance_value: 10
equivariance_jacobian: 10
reconstruction_params:
num_videos: 1000
format: '.mp4'
animate_params:
num_pairs: 50
format: '.mp4'
normalization_params:
adapt_movement_scale: False
use_relative_movement: True
use_relative_jacobian: True
visualizer_params:
kp_size: 5
draw_border: True
colormap: 'gist_rainbow'

@ -0,0 +1,77 @@
dataset_params:
root_dir: data/fashion-png
frame_shape: [256, 256, 3]
id_sampling: False
augmentation_params:
flip_param:
horizontal_flip: True
time_flip: True
jitter_param:
hue: 0.1
model_params:
common_params:
num_kp: 10
num_channels: 3
estimate_jacobian: True
kp_detector_params:
temperature: 0.1
block_expansion: 32
max_features: 1024
scale_factor: 0.25
num_blocks: 5
generator_params:
block_expansion: 64
max_features: 512
num_down_blocks: 2
num_bottleneck_blocks: 6
estimate_occlusion_map: True
dense_motion_params:
block_expansion: 64
max_features: 1024
num_blocks: 5
scale_factor: 0.25
discriminator_params:
scales: [1]
block_expansion: 32
max_features: 512
num_blocks: 4
train_params:
num_epochs: 100
num_repeats: 50
epoch_milestones: [60, 90]
lr_generator: 2.0e-4
lr_discriminator: 2.0e-4
lr_kp_detector: 2.0e-4
batch_size: 27
scales: [1, 0.5, 0.25, 0.125]
checkpoint_freq: 50
transform_params:
sigma_affine: 0.05
sigma_tps: 0.005
points_tps: 5
loss_weights:
generator_gan: 1
discriminator_gan: 1
feature_matching: [10, 10, 10, 10]
perceptual: [10, 10, 10, 10, 10]
equivariance_value: 10
equivariance_jacobian: 10
reconstruction_params:
num_videos: 1000
format: '.mp4'
animate_params:
num_pairs: 50
format: '.mp4'
normalization_params:
adapt_movement_scale: False
use_relative_movement: True
use_relative_jacobian: True
visualizer_params:
kp_size: 5
draw_border: True
colormap: 'gist_rainbow'

@ -0,0 +1,84 @@
dataset_params:
root_dir: data/moving-gif
frame_shape: [256, 256, 3]
id_sampling: False
augmentation_params:
flip_param:
horizontal_flip: True
time_flip: True
crop_param:
size: [256, 256]
resize_param:
ratio: [0.9, 1.1]
jitter_param:
hue: 0.5
model_params:
common_params:
num_kp: 10
num_channels: 3
estimate_jacobian: True
kp_detector_params:
temperature: 0.1
block_expansion: 32
max_features: 1024
scale_factor: 0.25
num_blocks: 5
single_jacobian_map: True
generator_params:
block_expansion: 64
max_features: 512
num_down_blocks: 2
num_bottleneck_blocks: 6
estimate_occlusion_map: True
dense_motion_params:
block_expansion: 64
max_features: 1024
num_blocks: 5
scale_factor: 0.25
discriminator_params:
scales: [1]
block_expansion: 32
max_features: 512
num_blocks: 4
sn: True
train_params:
num_epochs: 100
num_repeats: 25
epoch_milestones: [60, 90]
lr_generator: 2.0e-4
lr_discriminator: 2.0e-4
lr_kp_detector: 2.0e-4
batch_size: 36
scales: [1, 0.5, 0.25, 0.125]
checkpoint_freq: 100
transform_params:
sigma_affine: 0.05
sigma_tps: 0.005
points_tps: 5
loss_weights:
generator_gan: 1
discriminator_gan: 1
feature_matching: [10, 10, 10, 10]
perceptual: [10, 10, 10, 10, 10]
equivariance_value: 10
equivariance_jacobian: 10
reconstruction_params:
num_videos: 1000
format: '.mp4'
animate_params:
num_pairs: 50
format: '.mp4'
normalization_params:
adapt_movement_scale: False
use_relative_movement: True
use_relative_jacobian: True
visualizer_params:
kp_size: 5
draw_border: True
colormap: 'gist_rainbow'

@ -0,0 +1,76 @@
dataset_params:
root_dir: data/nemo-png
frame_shape: [256, 256, 3]
id_sampling: False
augmentation_params:
flip_param:
horizontal_flip: True
time_flip: True
model_params:
common_params:
num_kp: 10
num_channels: 3
estimate_jacobian: True
kp_detector_params:
temperature: 0.1
block_expansion: 32
max_features: 1024
scale_factor: 0.25
num_blocks: 5
generator_params:
block_expansion: 64
max_features: 512
num_down_blocks: 2
num_bottleneck_blocks: 6
estimate_occlusion_map: True
dense_motion_params:
block_expansion: 64
max_features: 1024
num_blocks: 5
scale_factor: 0.25
discriminator_params:
scales: [1]
block_expansion: 32
max_features: 512
num_blocks: 4
sn: True
train_params:
num_epochs: 100
num_repeats: 8
epoch_milestones: [60, 90]
lr_generator: 2.0e-4
lr_discriminator: 2.0e-4
lr_kp_detector: 2.0e-4
batch_size: 36
scales: [1, 0.5, 0.25, 0.125]
checkpoint_freq: 50
transform_params:
sigma_affine: 0.05
sigma_tps: 0.005
points_tps: 5
loss_weights:
generator_gan: 1
discriminator_gan: 1
feature_matching: [10, 10, 10, 10]
perceptual: [10, 10, 10, 10, 10]
equivariance_value: 10
equivariance_jacobian: 10
reconstruction_params:
num_videos: 1000
format: '.mp4'
animate_params:
num_pairs: 50
format: '.mp4'
normalization_params:
adapt_movement_scale: False
use_relative_movement: True
use_relative_jacobian: True
visualizer_params:
kp_size: 5
draw_border: True
colormap: 'gist_rainbow'

@ -0,0 +1,157 @@
# Dataset parameters
# Each dataset should contain 2 folders train and test
# Each video can be represented as:
# - an image of concatenated frames
# - '.mp4' or '.gif'
# - folder with all frames from a specific video
# In case of Taichi. Same (youtube) video can be splitted in many parts (chunks). Each part has a following
# format (id)#other#info.mp4. For example '12335#adsbf.mp4' has an id 12335. In case of TaiChi id stands for youtube
# video id.
dataset_params:
# Path to data, data can be stored in several formats: .mp4 or .gif videos, stacked .png images or folders with frames.
root_dir: data/taichi-png
# Image shape, needed for staked .png format.
frame_shape: [256, 256, 3]
# In case of TaiChi single video can be splitted in many chunks, or the maybe several videos for single person.
# In this case epoch can be a pass over different videos (if id_sampling=True) or over different chunks (if id_sampling=False)
# If the name of the video '12335#adsbf.mp4' the id is assumed to be 12335
id_sampling: True
# List with pairs for animation, None for random pairs
pairs_list: data/taichi256.csv
# Augmentation parameters see augmentation.py for all posible augmentations
augmentation_params:
flip_param:
horizontal_flip: True
time_flip: True
jitter_param:
brightness: 0.1
contrast: 0.1
saturation: 0.1
hue: 0.1
# Defines model architecture
model_params:
common_params:
# Number of keypoint
num_kp: 10
# Number of channels per image
num_channels: 3
# Using first or zero order model
estimate_jacobian: True
kp_detector_params:
# Softmax temperature for keypoint heatmaps
temperature: 0.1
# Number of features mutliplier
block_expansion: 32
# Maximum allowed number of features
max_features: 1024
# Number of block in Unet. Can be increased or decreased depending or resolution.
num_blocks: 5
# Keypioint is predicted on smaller images for better performance,
# scale_factor=0.25 means that 256x256 image will be resized to 64x64
scale_factor: 0.25
generator_params:
# Number of features mutliplier
block_expansion: 64
# Maximum allowed number of features
max_features: 512
# Number of downsampling blocks in Jonson architecture.
# Can be increased or decreased depending or resolution.
num_down_blocks: 2
# Number of ResBlocks in Jonson architecture.
num_bottleneck_blocks: 6
# Use occlusion map or not
estimate_occlusion_map: True
dense_motion_params:
# Number of features mutliplier
block_expansion: 64
# Maximum allowed number of features
max_features: 1024
# Number of block in Unet. Can be increased or decreased depending or resolution.
num_blocks: 5
# Dense motion is predicted on smaller images for better performance,
# scale_factor=0.25 means that 256x256 image will be resized to 64x64
scale_factor: 0.25
discriminator_params:
# Discriminator can be multiscale, if you want 2 discriminator on original
# resolution and half of the original, specify scales: [1, 0.5]
scales: [1]
# Number of features mutliplier
block_expansion: 32
# Maximum allowed number of features
max_features: 512
# Number of blocks. Can be increased or decreased depending or resolution.
num_blocks: 4
# Parameters of training
train_params:
# Number of training epochs
num_epochs: 100
# For better i/o performance when number of videos is small number of epochs can be multiplied by this number.
# Thus effectivlly with num_repeats=100 each epoch is 100 times larger.
num_repeats: 150
# Drop learning rate by 10 times after this epochs
epoch_milestones: [60, 90]
# Initial learing rate for all modules
lr_generator: 2.0e-4
lr_discriminator: 2.0e-4
lr_kp_detector: 2.0e-4
batch_size: 30
# Scales for perceptual pyramide loss. If scales = [1, 0.5, 0.25, 0.125] and image resolution is 256x256,
# than the loss will be computer on resolutions 256x256, 128x128, 64x64, 32x32.
scales: [1, 0.5, 0.25, 0.125]
# Save checkpoint this frequently. If checkpoint_freq=50, checkpoint will be saved every 50 epochs.
checkpoint_freq: 50
# Parameters of transform for equivariance loss
transform_params:
# Sigma for affine part
sigma_affine: 0.05
# Sigma for deformation part
sigma_tps: 0.005
# Number of point in the deformation grid
points_tps: 5
loss_weights:
# Weight for LSGAN loss in generator, 0 for no adversarial loss.
generator_gan: 0
# Weight for LSGAN loss in discriminator
discriminator_gan: 1
# Weights for feature matching loss, the number should be the same as number of blocks in discriminator.
feature_matching: [10, 10, 10, 10]
# Weights for perceptual loss.
perceptual: [10, 10, 10, 10, 10]
# Weights for value equivariance.
equivariance_value: 10
# Weights for jacobian equivariance.
equivariance_jacobian: 10
# Parameters of reconstruction
reconstruction_params:
# Maximum number of videos for reconstruction
num_videos: 1000
# Format for visualization, note that results will be also stored in staked .png.
format: '.mp4'
# Parameters of animation
animate_params:
# Maximum number of pairs for animation, the pairs will be either taken from pairs_list or random.
num_pairs: 50
# Format for visualization, note that results will be also stored in staked .png.
format: '.mp4'
# Normalization of diriving keypoints
normalization_params:
# Increase or decrease relative movement scale depending on the size of the object
adapt_movement_scale: False
# Apply only relative displacement of the keypoint
use_relative_movement: True
# Apply only relative change in jacobian
use_relative_jacobian: True
# Visualization parameters
visualizer_params:
# Draw keypoints of this size, increase or decrease depending on resolution
kp_size: 5
# Draw white border around images
draw_border: True
# Color map for keypoints
colormap: 'gist_rainbow'

@ -0,0 +1,150 @@
# Dataset parameters
dataset_params:
# Path to data, data can be stored in several formats: .mp4 or .gif videos, stacked .png images or folders with frames.
root_dir: data/taichi-png
# Image shape, needed for staked .png format.
frame_shape: [256, 256, 3]
# In case of TaiChi single video can be splitted in many chunks, or the maybe several videos for single person.
# In this case epoch can be a pass over different videos (if id_sampling=True) or over different chunks (if id_sampling=False)
# If the name of the video '12335#adsbf.mp4' the id is assumed to be 12335
id_sampling: True
# List with pairs for animation, None for random pairs
pairs_list: data/taichi256.csv
# Augmentation parameters see augmentation.py for all posible augmentations
augmentation_params:
flip_param:
horizontal_flip: True
time_flip: True
jitter_param:
brightness: 0.1
contrast: 0.1
saturation: 0.1
hue: 0.1
# Defines model architecture
model_params:
common_params:
# Number of keypoint
num_kp: 10
# Number of channels per image
num_channels: 3
# Using first or zero order model
estimate_jacobian: True
kp_detector_params:
# Softmax temperature for keypoint heatmaps
temperature: 0.1
# Number of features mutliplier
block_expansion: 32
# Maximum allowed number of features
max_features: 1024
# Number of block in Unet. Can be increased or decreased depending or resolution.
num_blocks: 5
# Keypioint is predicted on smaller images for better performance,
# scale_factor=0.25 means that 256x256 image will be resized to 64x64
scale_factor: 0.25
generator_params:
# Number of features mutliplier
block_expansion: 64
# Maximum allowed number of features
max_features: 512
# Number of downsampling blocks in Jonson architecture.
# Can be increased or decreased depending or resolution.
num_down_blocks: 2
# Number of ResBlocks in Jonson architecture.
num_bottleneck_blocks: 6
# Use occlusion map or not
estimate_occlusion_map: True
dense_motion_params:
# Number of features mutliplier
block_expansion: 64
# Maximum allowed number of features
max_features: 1024
# Number of block in Unet. Can be increased or decreased depending or resolution.
num_blocks: 5
# Dense motion is predicted on smaller images for better performance,
# scale_factor=0.25 means that 256x256 image will be resized to 64x64
scale_factor: 0.25
discriminator_params:
# Discriminator can be multiscale, if you want 2 discriminator on original
# resolution and half of the original, specify scales: [1, 0.5]
scales: [1]
# Number of features mutliplier
block_expansion: 32
# Maximum allowed number of features
max_features: 512
# Number of blocks. Can be increased or decreased depending or resolution.
num_blocks: 4
use_kp: True
# Parameters of training
train_params:
# Number of training epochs
num_epochs: 150
# For better i/o performance when number of videos is small number of epochs can be multiplied by this number.
# Thus effectivlly with num_repeats=100 each epoch is 100 times larger.
num_repeats: 150
# Drop learning rate by 10 times after this epochs
epoch_milestones: []
# Initial learing rate for all modules
lr_generator: 2.0e-4
lr_discriminator: 2.0e-4
lr_kp_detector: 0
batch_size: 27
# Scales for perceptual pyramide loss. If scales = [1, 0.5, 0.25, 0.125] and image resolution is 256x256,
# than the loss will be computer on resolutions 256x256, 128x128, 64x64, 32x32.
scales: [1, 0.5, 0.25, 0.125]
# Save checkpoint this frequently. If checkpoint_freq=50, checkpoint will be saved every 50 epochs.
checkpoint_freq: 50
# Parameters of transform for equivariance loss
transform_params:
# Sigma for affine part
sigma_affine: 0.05
# Sigma for deformation part
sigma_tps: 0.005
# Number of point in the deformation grid
points_tps: 5
loss_weights:
# Weight for LSGAN loss in generator
generator_gan: 1
# Weight for LSGAN loss in discriminator
discriminator_gan: 1
# Weights for feature matching loss, the number should be the same as number of blocks in discriminator.
feature_matching: [10, 10, 10, 10]
# Weights for perceptual loss.
perceptual: [10, 10, 10, 10, 10]
# Weights for value equivariance.
equivariance_value: 10
# Weights for jacobian equivariance.
equivariance_jacobian: 10
# Parameters of reconstruction
reconstruction_params:
# Maximum number of videos for reconstruction
num_videos: 1000
# Format for visualization, note that results will be also stored in staked .png.
format: '.mp4'
# Parameters of animation
animate_params:
# Maximum number of pairs for animation, the pairs will be either taken from pairs_list or random.
num_pairs: 50
# Format for visualization, note that results will be also stored in staked .png.
format: '.mp4'
# Normalization of diriving keypoints
normalization_params:
# Increase or decrease relative movement scale depending on the size of the object
adapt_movement_scale: False
# Apply only relative displacement of the keypoint
use_relative_movement: True
# Apply only relative change in jacobian
use_relative_jacobian: True
# Visualization parameters
visualizer_params:
# Draw keypoints of this size, increase or decrease depending on resolution
kp_size: 5
# Draw white border around images
draw_border: True
# Color map for keypoints
colormap: 'gist_rainbow'

@ -0,0 +1,83 @@
dataset_params:
root_dir: data/vox-png
frame_shape: [256, 256, 3]
id_sampling: True
pairs_list: data/vox256.csv
augmentation_params:
flip_param:
horizontal_flip: True
time_flip: True
jitter_param:
brightness: 0.1
contrast: 0.1
saturation: 0.1
hue: 0.1
model_params:
common_params:
num_kp: 10
num_channels: 3
estimate_jacobian: True
kp_detector_params:
temperature: 0.1
block_expansion: 32
max_features: 1024
scale_factor: 0.25
num_blocks: 5
generator_params:
block_expansion: 64
max_features: 512
num_down_blocks: 2
num_bottleneck_blocks: 6
estimate_occlusion_map: True
dense_motion_params:
block_expansion: 64
max_features: 1024
num_blocks: 5
scale_factor: 0.25
discriminator_params:
scales: [1]
block_expansion: 32
max_features: 512
num_blocks: 4
sn: True
train_params:
num_epochs: 100
num_repeats: 75
epoch_milestones: [60, 90]
lr_generator: 2.0e-4
lr_discriminator: 2.0e-4
lr_kp_detector: 2.0e-4
batch_size: 40
scales: [1, 0.5, 0.25, 0.125]
checkpoint_freq: 50
transform_params:
sigma_affine: 0.05
sigma_tps: 0.005
points_tps: 5
loss_weights:
generator_gan: 0
discriminator_gan: 1
feature_matching: [10, 10, 10, 10]
perceptual: [10, 10, 10, 10, 10]
equivariance_value: 10
equivariance_jacobian: 10
reconstruction_params:
num_videos: 1000
format: '.mp4'
animate_params:
num_pairs: 50
format: '.mp4'
normalization_params:
adapt_movement_scale: False
use_relative_movement: True
use_relative_jacobian: True
visualizer_params:
kp_size: 5
draw_border: True
colormap: 'gist_rainbow'

@ -0,0 +1,84 @@
dataset_params:
root_dir: data/vox-png
frame_shape: [256, 256, 3]
id_sampling: True
pairs_list: data/vox256.csv
augmentation_params:
flip_param:
horizontal_flip: True
time_flip: True
jitter_param:
brightness: 0.1
contrast: 0.1
saturation: 0.1
hue: 0.1
model_params:
common_params:
num_kp: 10
num_channels: 3
estimate_jacobian: True
kp_detector_params:
temperature: 0.1
block_expansion: 32
max_features: 1024
scale_factor: 0.25
num_blocks: 5
generator_params:
block_expansion: 64
max_features: 512
num_down_blocks: 2
num_bottleneck_blocks: 6
estimate_occlusion_map: True
dense_motion_params:
block_expansion: 64
max_features: 1024
num_blocks: 5
scale_factor: 0.25
discriminator_params:
scales: [1]
block_expansion: 32
max_features: 512
num_blocks: 4
use_kp: True
train_params:
num_epochs: 150
num_repeats: 75
epoch_milestones: []
lr_generator: 2.0e-4
lr_discriminator: 2.0e-4
lr_kp_detector: 2.0e-4
batch_size: 36
scales: [1, 0.5, 0.25, 0.125]
checkpoint_freq: 50
transform_params:
sigma_affine: 0.05
sigma_tps: 0.005
points_tps: 5
loss_weights:
generator_gan: 1
discriminator_gan: 1
feature_matching: [10, 10, 10, 10]
perceptual: [10, 10, 10, 10, 10]
equivariance_value: 10
equivariance_jacobian: 10
reconstruction_params:
num_videos: 1000
format: '.mp4'
animate_params:
num_pairs: 50
format: '.mp4'
normalization_params:
adapt_movement_scale: False
use_relative_movement: True
use_relative_jacobian: True
visualizer_params:
kp_size: 5
draw_border: True
colormap: 'gist_rainbow'

@ -0,0 +1,158 @@
import face_alignment
import skimage.io
import numpy
from argparse import ArgumentParser
from skimage import img_as_ubyte
from skimage.transform import resize
from tqdm import tqdm
import os
import imageio
import numpy as np
import warnings
warnings.filterwarnings("ignore")
def extract_bbox(frame, fa):
if max(frame.shape[0], frame.shape[1]) > 640:
scale_factor = max(frame.shape[0], frame.shape[1]) / 640.0
frame = resize(frame, (int(frame.shape[0] / scale_factor), int(frame.shape[1] / scale_factor)))
frame = img_as_ubyte(frame)
else:
scale_factor = 1
frame = frame[..., :3]
bboxes = fa.face_detector.detect_from_image(frame[..., ::-1])
if len(bboxes) == 0:
return []
return np.array(bboxes)[:, :-1] * scale_factor
def bb_intersection_over_union(boxA, boxB):
xA = max(boxA[0], boxB[0])
yA = max(boxA[1], boxB[1])
xB = min(boxA[2], boxB[2])
yB = min(boxA[3], boxB[3])
interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
iou = interArea / float(boxAArea + boxBArea - interArea)
return iou
def join(tube_bbox, bbox):
xA = min(tube_bbox[0], bbox[0])
yA = min(tube_bbox[1], bbox[1])
xB = max(tube_bbox[2], bbox[2])
yB = max(tube_bbox[3], bbox[3])
return (xA, yA, xB, yB)
def compute_bbox(start, end, fps, tube_bbox, frame_shape, inp, image_shape, increase_area=0.1):
left, top, right, bot = tube_bbox
width = right - left
height = bot - top
#Computing aspect preserving bbox
width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width))
height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height))
left = int(left - width_increase * width)
top = int(top - height_increase * height)
right = int(right + width_increase * width)
bot = int(bot + height_increase * height)
top, bot, left, right = max(0, top), min(bot, frame_shape[0]), max(0, left), min(right, frame_shape[1])
h, w = bot - top, right - left
start = start / fps
end = end / fps
time = end - start
scale = f'{image_shape[0]}:{image_shape[1]}'
return f'ffmpeg -i {inp} -ss {start} -t {time} -filter:v "crop={w}:{h}:{left}:{top}, scale={scale}" crop.mp4'
def compute_bbox_trajectories(trajectories, fps, frame_shape, args):
commands = []
for i, (bbox, tube_bbox, start, end) in enumerate(trajectories):
if (end - start) > args.min_frames:
command = compute_bbox(start, end, fps, tube_bbox, frame_shape, inp=args.inp, image_shape=args.image_shape, increase_area=args.increase)
commands.append(command)
return commands
def process_video(args):
device = 'cpu' if args.cpu else 'cuda'
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False, device=device)
video = imageio.get_reader(args.inp)
trajectories = []
previous_frame = None
fps = video.get_meta_data()['fps']
commands = []
try:
for i, frame in tqdm(enumerate(video)):
frame_shape = frame.shape
bboxes = extract_bbox(frame, fa)
## For each trajectory check the criterion
not_valid_trajectories = []
valid_trajectories = []
for trajectory in trajectories:
tube_bbox = trajectory[0]
intersection = 0
for bbox in bboxes:
intersection = max(intersection, bb_intersection_over_union(tube_bbox, bbox))
if intersection > args.iou_with_initial:
valid_trajectories.append(trajectory)
else:
not_valid_trajectories.append(trajectory)
commands += compute_bbox_trajectories(not_valid_trajectories, fps, frame_shape, args)
trajectories = valid_trajectories
## Assign bbox to trajectories, create new trajectories
for bbox in bboxes:
intersection = 0
current_trajectory = None
for trajectory in trajectories:
tube_bbox = trajectory[0]
current_intersection = bb_intersection_over_union(tube_bbox, bbox)
if intersection < current_intersection and current_intersection > args.iou_with_initial:
intersection = bb_intersection_over_union(tube_bbox, bbox)
current_trajectory = trajectory
## Create new trajectory
if current_trajectory is None:
trajectories.append([bbox, bbox, i, i])
else:
current_trajectory[3] = i
current_trajectory[1] = join(current_trajectory[1], bbox)
except IndexError as e:
raise (e)
commands += compute_bbox_trajectories(trajectories, fps, frame_shape, args)
return commands
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--image_shape", default=(256, 256), type=lambda x: tuple(map(int, x.split(','))),
help="Image shape")
parser.add_argument("--increase", default=0.1, type=float, help='Increase bbox by this amount')
parser.add_argument("--iou_with_initial", type=float, default=0.25, help="The minimal allowed iou with inital bbox")
parser.add_argument("--inp", required=True, help='Input image or video')
parser.add_argument("--min_frames", type=int, default=150, help='Minimum number of frames')
parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
args = parser.parse_args()
commands = process_video(args)
for command in commands:
print (command)

@ -0,0 +1,51 @@
distance,source,driving,frame
0,000054.mp4,000048.mp4,0
0,000050.mp4,000063.mp4,0
0,000073.mp4,000007.mp4,0
0,000021.mp4,000010.mp4,0
0,000084.mp4,000046.mp4,0
0,000031.mp4,000102.mp4,0
0,000029.mp4,000111.mp4,0
0,000090.mp4,000112.mp4,0
0,000039.mp4,000010.mp4,0
0,000008.mp4,000069.mp4,0
0,000068.mp4,000076.mp4,0
0,000051.mp4,000052.mp4,0
0,000022.mp4,000098.mp4,0
0,000096.mp4,000032.mp4,0
0,000032.mp4,000099.mp4,0
0,000006.mp4,000053.mp4,0
0,000098.mp4,000020.mp4,0
0,000029.mp4,000066.mp4,0
0,000022.mp4,000007.mp4,0
0,000027.mp4,000065.mp4,0
0,000026.mp4,000059.mp4,0
0,000015.mp4,000112.mp4,0
0,000086.mp4,000123.mp4,0
0,000103.mp4,000052.mp4,0
0,000123.mp4,000103.mp4,0
0,000051.mp4,000005.mp4,0
0,000062.mp4,000125.mp4,0
0,000126.mp4,000111.mp4,0
0,000066.mp4,000090.mp4,0
0,000075.mp4,000106.mp4,0
0,000020.mp4,000010.mp4,0
0,000076.mp4,000028.mp4,0
0,000062.mp4,000002.mp4,0
0,000095.mp4,000127.mp4,0
0,000113.mp4,000072.mp4,0
0,000027.mp4,000104.mp4,0
0,000054.mp4,000124.mp4,0
0,000019.mp4,000089.mp4,0
0,000052.mp4,000072.mp4,0
0,000108.mp4,000033.mp4,0
0,000044.mp4,000118.mp4,0
0,000029.mp4,000086.mp4,0
0,000068.mp4,000066.mp4,0
0,000014.mp4,000036.mp4,0
0,000053.mp4,000071.mp4,0
0,000022.mp4,000094.mp4,0
0,000000.mp4,000121.mp4,0
0,000071.mp4,000079.mp4,0
0,000127.mp4,000005.mp4,0
0,000085.mp4,000023.mp4,0
1 distance source driving frame
2 0 000054.mp4 000048.mp4 0
3 0 000050.mp4 000063.mp4 0
4 0 000073.mp4 000007.mp4 0
5 0 000021.mp4 000010.mp4 0
6 0 000084.mp4 000046.mp4 0
7 0 000031.mp4 000102.mp4 0
8 0 000029.mp4 000111.mp4 0
9 0 000090.mp4 000112.mp4 0
10 0 000039.mp4 000010.mp4 0
11 0 000008.mp4 000069.mp4 0
12 0 000068.mp4 000076.mp4 0
13 0 000051.mp4 000052.mp4 0
14 0 000022.mp4 000098.mp4 0
15 0 000096.mp4 000032.mp4 0
16 0 000032.mp4 000099.mp4 0
17 0 000006.mp4 000053.mp4 0
18 0 000098.mp4 000020.mp4 0
19 0 000029.mp4 000066.mp4 0
20 0 000022.mp4 000007.mp4 0
21 0 000027.mp4 000065.mp4 0
22 0 000026.mp4 000059.mp4 0
23 0 000015.mp4 000112.mp4 0
24 0 000086.mp4 000123.mp4 0
25 0 000103.mp4 000052.mp4 0
26 0 000123.mp4 000103.mp4 0
27 0 000051.mp4 000005.mp4 0
28 0 000062.mp4 000125.mp4 0
29 0 000126.mp4 000111.mp4 0
30 0 000066.mp4 000090.mp4 0
31 0 000075.mp4 000106.mp4 0
32 0 000020.mp4 000010.mp4 0
33 0 000076.mp4 000028.mp4 0
34 0 000062.mp4 000002.mp4 0
35 0 000095.mp4 000127.mp4 0
36 0 000113.mp4 000072.mp4 0
37 0 000027.mp4 000104.mp4 0
38 0 000054.mp4 000124.mp4 0
39 0 000019.mp4 000089.mp4 0
40 0 000052.mp4 000072.mp4 0
41 0 000108.mp4 000033.mp4 0
42 0 000044.mp4 000118.mp4 0
43 0 000029.mp4 000086.mp4 0
44 0 000068.mp4 000066.mp4 0
45 0 000014.mp4 000036.mp4 0
46 0 000053.mp4 000071.mp4 0
47 0 000022.mp4 000094.mp4 0
48 0 000000.mp4 000121.mp4 0
49 0 000071.mp4 000079.mp4 0
50 0 000127.mp4 000005.mp4 0
51 0 000085.mp4 000023.mp4 0

Binary file not shown.

@ -0,0 +1,18 @@
# TaiChi dataset
The scripst for loading the TaiChi dataset.
We provide only the id of the corresponding video and the bounding box. Following script will download videos from youtube and crop them according to the provided bounding boxes.
1) Load youtube-dl:
```
wget https://yt-dl.org/downloads/latest/youtube-dl -O youtube-dl
chmod a+rx youtube-dl
```
2) Run script to download videos, there are 2 formats that can be used for storing videos one is .mp4 and another is folder with .png images. While .png images occupy significantly more space, the format is loss-less and have better i/o performance when training.
```
python load_videos.py --metadata taichi-metadata.csv --format .mp4 --out_folder taichi --workers 8
```
select number of workers based on number of cpu avaliable. Note .png format take aproximatly 80GB.

@ -0,0 +1,113 @@
import numpy as np
import pandas as pd
import imageio
import os
import subprocess
from multiprocessing import Pool
from itertools import cycle
import warnings
import glob
import time
from tqdm import tqdm
from argparse import ArgumentParser
from skimage import img_as_ubyte
from skimage.transform import resize
warnings.filterwarnings("ignore")
DEVNULL = open(os.devnull, 'wb')
def save(path, frames, format):
if format == '.mp4':
imageio.mimsave(path, frames)
elif format == '.png':
if os.path.exists(path):
print ("Warning: skiping video %s" % os.path.basename(path))
return
else:
os.makedirs(path)
for j, frame in enumerate(frames):
imageio.imsave(os.path.join(path, str(j).zfill(7) + '.png'), frames[j])
else:
print ("Unknown format %s" % format)
exit()
def download(video_id, args):
video_path = os.path.join(args.video_folder, video_id + ".mp4")
subprocess.call([args.youtube, '-f', "''best/mp4''", '--write-auto-sub', '--write-sub',
'--sub-lang', 'en', '--skip-unavailable-fragments',
"https://www.youtube.com/watch?v=" + video_id, "--output",
video_path], stdout=DEVNULL, stderr=DEVNULL)
return video_path
def run(data):
video_id, args = data
if not os.path.exists(os.path.join(args.video_folder, video_id.split('#')[0] + '.mp4')):
download(video_id.split('#')[0], args)
if not os.path.exists(os.path.join(args.video_folder, video_id.split('#')[0] + '.mp4')):
print ('Can not load video %s, broken link' % video_id.split('#')[0])
return
reader = imageio.get_reader(os.path.join(args.video_folder, video_id.split('#')[0] + '.mp4'))
fps = reader.get_meta_data()['fps']
df = pd.read_csv(args.metadata)
df = df[df['video_id'] == video_id]
all_chunks_dict = [{'start': df['start'].iloc[j], 'end': df['end'].iloc[j],
'bbox': list(map(int, df['bbox'].iloc[j].split('-'))), 'frames':[]} for j in range(df.shape[0])]
ref_fps = df['fps'].iloc[0]
ref_height = df['height'].iloc[0]
ref_width = df['width'].iloc[0]
partition = df['partition'].iloc[0]
try:
for i, frame in enumerate(reader):
for entry in all_chunks_dict:
if (i * ref_fps >= entry['start'] * fps) and (i * ref_fps < entry['end'] * fps):
left, top, right, bot = entry['bbox']
left = int(left / (ref_width / frame.shape[1]))
top = int(top / (ref_height / frame.shape[0]))
right = int(right / (ref_width / frame.shape[1]))
bot = int(bot / (ref_height / frame.shape[0]))
crop = frame[top:bot, left:right]
if args.image_shape is not None:
crop = img_as_ubyte(resize(crop, args.image_shape, anti_aliasing=True))
entry['frames'].append(crop)
except imageio.core.format.CannotReadFrameError:
None
for entry in all_chunks_dict:
first_part = '#'.join(video_id.split('#')[::-1])
path = first_part + '#' + str(entry['start']).zfill(6) + '#' + str(entry['end']).zfill(6) + '.mp4'
save(os.path.join(args.out_folder, partition, path), entry['frames'], args.format)
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--video_folder", default='youtube-taichi', help='Path to youtube videos')
parser.add_argument("--metadata", default='taichi-metadata-new.csv', help='Path to metadata')
parser.add_argument("--out_folder", default='taichi-png', help='Path to output')
parser.add_argument("--format", default='.png', help='Storing format')
parser.add_argument("--workers", default=1, type=int, help='Number of workers')
parser.add_argument("--youtube", default='./youtube-dl', help='Path to youtube-dl')
parser.add_argument("--image_shape", default=(256, 256), type=lambda x: tuple(map(int, x.split(','))),
help="Image shape, None for no resize")
args = parser.parse_args()
if not os.path.exists(args.video_folder):
os.makedirs(args.video_folder)
if not os.path.exists(args.out_folder):
os.makedirs(args.out_folder)
for partition in ['test', 'train']:
if not os.path.exists(os.path.join(args.out_folder, partition)):
os.makedirs(os.path.join(args.out_folder, partition))
df = pd.read_csv(args.metadata)
video_ids = set(df['video_id'])
pool = Pool(processes=args.workers)
args_list = cycle([args])
for chunks_data in tqdm(pool.imap_unordered(run, zip(video_ids, args_list))):
None

File diff suppressed because it is too large Load Diff

@ -0,0 +1,51 @@
distance,source,driving,frame
3.54437869822485,ab28GAufK8o#000261#000596.mp4,aDyyTMUBoLE#000164#000351.mp4,0
2.8639053254437887,DMEaUoA8EPE#000028#000354.mp4,0Q914by5A98#010440#010764.mp4,0
2.153846153846153,L82WHgYRq6I#000021#000479.mp4,0Q914by5A98#010440#010764.mp4,0
2.8994082840236666,oNkBx4CZuEg#000000#001024.mp4,DMEaUoA8EPE#000028#000354.mp4,0
3.3905325443786998,ab28GAufK8o#000261#000596.mp4,uEqWZ9S_-Lw#000089#000581.mp4,0
3.266272189349112,0Q914by5A98#010440#010764.mp4,ab28GAufK8o#000261#000596.mp4,0
2.7514792899408294,WlDYrq8K6nk#008186#008512.mp4,OiblkvkAHWM#014331#014459.mp4,0
3.0177514792899407,oNkBx4CZuEg#001024#002048.mp4,aDyyTMUBoLE#000375#000518.mp4,0
3.4792899408284064,aDyyTMUBoLE#000164#000351.mp4,w2awOCDRtrc#001729#002009.mp4,0
2.769230769230769,oNkBx4CZuEg#000000#001024.mp4,L82WHgYRq6I#000021#000479.mp4,0
3.8047337278106514,ab28GAufK8o#000261#000596.mp4,w2awOCDRtrc#001729#002009.mp4,0
3.4260355029585763,w2awOCDRtrc#001729#002009.mp4,oNkBx4CZuEg#000000#001024.mp4,0
3.313609467455621,DMEaUoA8EPE#000028#000354.mp4,WlDYrq8K6nk#005943#006135.mp4,0
3.8402366863905333,oNkBx4CZuEg#001024#002048.mp4,ab28GAufK8o#000261#000596.mp4,0
3.3254437869822504,aDyyTMUBoLE#000164#000351.mp4,oNkBx4CZuEg#000000#001024.mp4,0
1.2485207100591724,0Q914by5A98#010440#010764.mp4,aDyyTMUBoLE#000164#000351.mp4,0
3.804733727810652,OiblkvkAHWM#006251#006533.mp4,aDyyTMUBoLE#000375#000518.mp4,0
3.662721893491124,uEqWZ9S_-Lw#000089#000581.mp4,DMEaUoA8EPE#000028#000354.mp4,0
3.230769230769233,A3ZmT97hAWU#000095#000678.mp4,ab28GAufK8o#000261#000596.mp4,0
3.3668639053254434,w81Tr0Dp1K8#015329#015485.mp4,WlDYrq8K6nk#008186#008512.mp4,0
3.313609467455621,WlDYrq8K6nk#005943#006135.mp4,DMEaUoA8EPE#000028#000354.mp4,0
2.7514792899408294,OiblkvkAHWM#014331#014459.mp4,WlDYrq8K6nk#008186#008512.mp4,0
1.964497041420118,L82WHgYRq6I#000021#000479.mp4,DMEaUoA8EPE#000028#000354.mp4,0
3.78698224852071,FBuF0xOal9M#046824#047542.mp4,lCb5w6n8kPs#011879#012014.mp4,0
3.92307692307692,ab28GAufK8o#000261#000596.mp4,L82WHgYRq6I#000021#000479.mp4,0
3.8402366863905333,ab28GAufK8o#000261#000596.mp4,oNkBx4CZuEg#001024#002048.mp4,0
3.828402366863905,ab28GAufK8o#000261#000596.mp4,OiblkvkAHWM#006251#006533.mp4,0
2.041420118343196,L82WHgYRq6I#000021#000479.mp4,aDyyTMUBoLE#000164#000351.mp4,0
3.2485207100591724,0Q914by5A98#010440#010764.mp4,w2awOCDRtrc#001729#002009.mp4,0
3.2485207100591746,oNkBx4CZuEg#000000#001024.mp4,0Q914by5A98#010440#010764.mp4,0
1.964497041420118,DMEaUoA8EPE#000028#000354.mp4,L82WHgYRq6I#000021#000479.mp4,0
3.5266272189349115,kgvcI9oe3NI#001578#001763.mp4,lCb5w6n8kPs#004451#004631.mp4,0
3.005917159763317,A3ZmT97hAWU#000095#000678.mp4,0Q914by5A98#010440#010764.mp4,0
3.230769230769233,ab28GAufK8o#000261#000596.mp4,A3ZmT97hAWU#000095#000678.mp4,0
3.5266272189349115,lCb5w6n8kPs#004451#004631.mp4,kgvcI9oe3NI#001578#001763.mp4,0
2.769230769230769,L82WHgYRq6I#000021#000479.mp4,oNkBx4CZuEg#000000#001024.mp4,0
3.165680473372782,WlDYrq8K6nk#005943#006135.mp4,w81Tr0Dp1K8#001375#001516.mp4,0
2.8994082840236666,DMEaUoA8EPE#000028#000354.mp4,oNkBx4CZuEg#000000#001024.mp4,0
2.4556213017751523,0Q914by5A98#010440#010764.mp4,mndSqTrxpts#000000#000175.mp4,0
2.201183431952659,A3ZmT97hAWU#000095#000678.mp4,VMSqvTE90hk#007168#007312.mp4,0
3.8047337278106514,w2awOCDRtrc#001729#002009.mp4,ab28GAufK8o#000261#000596.mp4,0
3.769230769230769,uEqWZ9S_-Lw#000089#000581.mp4,0Q914by5A98#010440#010764.mp4,0
3.6568047337278102,A3ZmT97hAWU#000095#000678.mp4,aDyyTMUBoLE#000164#000351.mp4,0
3.7869822485207107,uEqWZ9S_-Lw#000089#000581.mp4,L82WHgYRq6I#000021#000479.mp4,0
3.78698224852071,lCb5w6n8kPs#011879#012014.mp4,FBuF0xOal9M#046824#047542.mp4,0
3.591715976331361,nAQEOC1Z10M#020177#020600.mp4,w81Tr0Dp1K8#004036#004218.mp4,0
3.8757396449704156,uEqWZ9S_-Lw#000089#000581.mp4,aDyyTMUBoLE#000164#000351.mp4,0
2.45562130177515,aDyyTMUBoLE#000164#000351.mp4,DMEaUoA8EPE#000028#000354.mp4,0
3.5502958579881647,uEqWZ9S_-Lw#000089#000581.mp4,OiblkvkAHWM#006251#006533.mp4,0
3.7928994082840224,aDyyTMUBoLE#000375#000518.mp4,ab28GAufK8o#000261#000596.mp4,0
1 distance source driving frame
2 3.54437869822485 ab28GAufK8o#000261#000596.mp4 aDyyTMUBoLE#000164#000351.mp4 0
3 2.8639053254437887 DMEaUoA8EPE#000028#000354.mp4 0Q914by5A98#010440#010764.mp4 0
4 2.153846153846153 L82WHgYRq6I#000021#000479.mp4 0Q914by5A98#010440#010764.mp4 0
5 2.8994082840236666 oNkBx4CZuEg#000000#001024.mp4 DMEaUoA8EPE#000028#000354.mp4 0
6 3.3905325443786998 ab28GAufK8o#000261#000596.mp4 uEqWZ9S_-Lw#000089#000581.mp4 0
7 3.266272189349112 0Q914by5A98#010440#010764.mp4 ab28GAufK8o#000261#000596.mp4 0
8 2.7514792899408294 WlDYrq8K6nk#008186#008512.mp4 OiblkvkAHWM#014331#014459.mp4 0
9 3.0177514792899407 oNkBx4CZuEg#001024#002048.mp4 aDyyTMUBoLE#000375#000518.mp4 0
10 3.4792899408284064 aDyyTMUBoLE#000164#000351.mp4 w2awOCDRtrc#001729#002009.mp4 0
11 2.769230769230769 oNkBx4CZuEg#000000#001024.mp4 L82WHgYRq6I#000021#000479.mp4 0
12 3.8047337278106514 ab28GAufK8o#000261#000596.mp4 w2awOCDRtrc#001729#002009.mp4 0
13 3.4260355029585763 w2awOCDRtrc#001729#002009.mp4 oNkBx4CZuEg#000000#001024.mp4 0
14 3.313609467455621 DMEaUoA8EPE#000028#000354.mp4 WlDYrq8K6nk#005943#006135.mp4 0
15 3.8402366863905333 oNkBx4CZuEg#001024#002048.mp4 ab28GAufK8o#000261#000596.mp4 0
16 3.3254437869822504 aDyyTMUBoLE#000164#000351.mp4 oNkBx4CZuEg#000000#001024.mp4 0
17 1.2485207100591724 0Q914by5A98#010440#010764.mp4 aDyyTMUBoLE#000164#000351.mp4 0
18 3.804733727810652 OiblkvkAHWM#006251#006533.mp4 aDyyTMUBoLE#000375#000518.mp4 0
19 3.662721893491124 uEqWZ9S_-Lw#000089#000581.mp4 DMEaUoA8EPE#000028#000354.mp4 0
20 3.230769230769233 A3ZmT97hAWU#000095#000678.mp4 ab28GAufK8o#000261#000596.mp4 0
21 3.3668639053254434 w81Tr0Dp1K8#015329#015485.mp4 WlDYrq8K6nk#008186#008512.mp4 0
22 3.313609467455621 WlDYrq8K6nk#005943#006135.mp4 DMEaUoA8EPE#000028#000354.mp4 0
23 2.7514792899408294 OiblkvkAHWM#014331#014459.mp4 WlDYrq8K6nk#008186#008512.mp4 0
24 1.964497041420118 L82WHgYRq6I#000021#000479.mp4 DMEaUoA8EPE#000028#000354.mp4 0
25 3.78698224852071 FBuF0xOal9M#046824#047542.mp4 lCb5w6n8kPs#011879#012014.mp4 0
26 3.92307692307692 ab28GAufK8o#000261#000596.mp4 L82WHgYRq6I#000021#000479.mp4 0
27 3.8402366863905333 ab28GAufK8o#000261#000596.mp4 oNkBx4CZuEg#001024#002048.mp4 0
28 3.828402366863905 ab28GAufK8o#000261#000596.mp4 OiblkvkAHWM#006251#006533.mp4 0
29 2.041420118343196 L82WHgYRq6I#000021#000479.mp4 aDyyTMUBoLE#000164#000351.mp4 0
30 3.2485207100591724 0Q914by5A98#010440#010764.mp4 w2awOCDRtrc#001729#002009.mp4 0
31 3.2485207100591746 oNkBx4CZuEg#000000#001024.mp4 0Q914by5A98#010440#010764.mp4 0
32 1.964497041420118 DMEaUoA8EPE#000028#000354.mp4 L82WHgYRq6I#000021#000479.mp4 0
33 3.5266272189349115 kgvcI9oe3NI#001578#001763.mp4 lCb5w6n8kPs#004451#004631.mp4 0
34 3.005917159763317 A3ZmT97hAWU#000095#000678.mp4 0Q914by5A98#010440#010764.mp4 0
35 3.230769230769233 ab28GAufK8o#000261#000596.mp4 A3ZmT97hAWU#000095#000678.mp4 0
36 3.5266272189349115 lCb5w6n8kPs#004451#004631.mp4 kgvcI9oe3NI#001578#001763.mp4 0
37 2.769230769230769 L82WHgYRq6I#000021#000479.mp4 oNkBx4CZuEg#000000#001024.mp4 0
38 3.165680473372782 WlDYrq8K6nk#005943#006135.mp4 w81Tr0Dp1K8#001375#001516.mp4 0
39 2.8994082840236666 DMEaUoA8EPE#000028#000354.mp4 oNkBx4CZuEg#000000#001024.mp4 0
40 2.4556213017751523 0Q914by5A98#010440#010764.mp4 mndSqTrxpts#000000#000175.mp4 0
41 2.201183431952659 A3ZmT97hAWU#000095#000678.mp4 VMSqvTE90hk#007168#007312.mp4 0
42 3.8047337278106514 w2awOCDRtrc#001729#002009.mp4 ab28GAufK8o#000261#000596.mp4 0
43 3.769230769230769 uEqWZ9S_-Lw#000089#000581.mp4 0Q914by5A98#010440#010764.mp4 0
44 3.6568047337278102 A3ZmT97hAWU#000095#000678.mp4 aDyyTMUBoLE#000164#000351.mp4 0
45 3.7869822485207107 uEqWZ9S_-Lw#000089#000581.mp4 L82WHgYRq6I#000021#000479.mp4 0
46 3.78698224852071 lCb5w6n8kPs#011879#012014.mp4 FBuF0xOal9M#046824#047542.mp4 0
47 3.591715976331361 nAQEOC1Z10M#020177#020600.mp4 w81Tr0Dp1K8#004036#004218.mp4 0
48 3.8757396449704156 uEqWZ9S_-Lw#000089#000581.mp4 aDyyTMUBoLE#000164#000351.mp4 0
49 2.45562130177515 aDyyTMUBoLE#000164#000351.mp4 DMEaUoA8EPE#000028#000354.mp4 0
50 3.5502958579881647 uEqWZ9S_-Lw#000089#000581.mp4 OiblkvkAHWM#006251#006533.mp4 0
51 3.7928994082840224 aDyyTMUBoLE#000375#000518.mp4 ab28GAufK8o#000261#000596.mp4 0

@ -0,0 +1,543 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "first-order-model-demo",
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/AliaksandrSiarohin/first-order-model/blob/master/demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>",
"<a href=\"https://kaggle.com/kernels/welcome?src=https://github.com/AliaksandrSiarohin/first-order-model/blob/master/demo.ipynb\" target=\"_parent\"><img alt=\"Kaggle\" title=\"Open in Kaggle\" src=\"https://kaggle.com/static/images/open-in-kaggle.svg\"></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cdO_RxQZLahB"
},
"source": [
"# Demo for paper \"First Order Motion Model for Image Animation\"\n",
"To try the demo, press the 2 play buttons in order and scroll to the bottom. Note that it may take several minutes to load."
]
},
{
"cell_type": "code",
"metadata": {
"id": "UCMFMJV7K-ag"
},
"source": [
"%%capture\n",
"%pip install ffmpeg-python imageio-ffmpeg\n",
"!git init .\n",
"!git remote add origin https://github.com/AliaksandrSiarohin/first-order-model\n",
"!git pull origin master\n",
"!git clone https://github.com/graphemecluster/first-order-model-demo demo"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Oxi6-riLOgnm"
},
"source": [
"import IPython.display\n",
"import PIL.Image\n",
"import cv2\n",
"import ffmpeg\n",
"import imageio\n",
"import io\n",
"import ipywidgets\n",
"import numpy\n",
"import os.path\n",
"import requests\n",
"import skimage.transform\n",
"import warnings\n",
"from base64 import b64encode\n",
"from demo import load_checkpoints, make_animation # type: ignore (local file)\n",
"from google.colab import files, output\n",
"from IPython.display import HTML, Javascript\n",
"from shutil import copyfileobj\n",
"from skimage import img_as_ubyte\n",
"from tempfile import NamedTemporaryFile\n",
"from tqdm.auto import tqdm\n",
"warnings.filterwarnings(\"ignore\")\n",
"os.makedirs(\"user\", exist_ok=True)\n",
"\n",
"display(HTML(\"\"\"\n",
"<style>\n",
".widget-box > * {\n",
"\tflex-shrink: 0;\n",
"}\n",
".widget-tab {\n",
"\tmin-width: 0;\n",
"\tflex: 1 1 auto;\n",
"}\n",
".widget-tab .p-TabBar-tabLabel {\n",
"\tfont-size: 15px;\n",
"}\n",
".widget-upload {\n",
"\tbackground-color: tan;\n",
"}\n",
".widget-button {\n",
"\tfont-size: 18px;\n",
"\twidth: 160px;\n",
"\theight: 34px;\n",
"\tline-height: 34px;\n",
"}\n",
".widget-dropdown {\n",
"\twidth: 250px;\n",
"}\n",
".widget-checkbox {\n",
"\twidth: 650px;\n",
"}\n",
".widget-checkbox + .widget-checkbox {\n",
"\tmargin-top: -6px;\n",
"}\n",
".input-widget .output_html {\n",
"\ttext-align: center;\n",
"\twidth: 266px;\n",
"\theight: 266px;\n",
"\tline-height: 266px;\n",
"\tcolor: lightgray;\n",
"\tfont-size: 72px;\n",
"}\n",
".title {\n",
"\tfont-size: 20px;\n",
"\tfont-weight: bold;\n",
"\tmargin: 12px 0 6px 0;\n",
"}\n",
".warning {\n",
"\tdisplay: none;\n",
"\tcolor: red;\n",
"\tmargin-left: 10px;\n",
"}\n",
".warn {\n",
"\tdisplay: initial;\n",
"}\n",
".resource {\n",
"\tcursor: pointer;\n",
"\tborder: 1px solid gray;\n",
"\tmargin: 5px;\n",
"\twidth: 160px;\n",
"\theight: 160px;\n",
"\tmin-width: 160px;\n",
"\tmin-height: 160px;\n",
"\tmax-width: 160px;\n",
"\tmax-height: 160px;\n",
"\t-webkit-box-sizing: initial;\n",
"\tbox-sizing: initial;\n",
"}\n",
".resource:hover {\n",
"\tborder: 6px solid crimson;\n",
"\tmargin: 0;\n",
"}\n",
".selected {\n",
"\tborder: 6px solid seagreen;\n",
"\tmargin: 0;\n",
"}\n",
".input-widget {\n",
"\twidth: 266px;\n",
"\theight: 266px;\n",
"\tborder: 1px solid gray;\n",
"}\n",
".input-button {\n",
"\twidth: 268px;\n",
"\tfont-size: 15px;\n",
"\tmargin: 2px 0 0;\n",
"}\n",
".output-widget {\n",
"\twidth: 256px;\n",
"\theight: 256px;\n",
"\tborder: 1px solid gray;\n",
"}\n",
".output-button {\n",
"\twidth: 258px;\n",
"\tfont-size: 15px;\n",
"\tmargin: 2px 0 0;\n",
"}\n",
".uploaded {\n",
"\twidth: 256px;\n",
"\theight: 256px;\n",
"\tborder: 6px solid seagreen;\n",
"\tmargin: 0;\n",
"}\n",
".label-or {\n",
"\talign-self: center;\n",
"\tfont-size: 20px;\n",
"\tmargin: 16px;\n",
"}\n",
".loading {\n",
"\talign-items: center;\n",
"\twidth: fit-content;\n",
"}\n",
".loader {\n",
"\tmargin: 32px 0 16px 0;\n",
"\twidth: 48px;\n",
"\theight: 48px;\n",
"\tmin-width: 48px;\n",
"\tmin-height: 48px;\n",
"\tmax-width: 48px;\n",
"\tmax-height: 48px;\n",
"\tborder: 4px solid whitesmoke;\n",
"\tborder-top-color: gray;\n",
"\tborder-radius: 50%;\n",
"\tanimation: spin 1.8s linear infinite;\n",
"}\n",
".loading-label {\n",
"\tcolor: gray;\n",
"}\n",
".video {\n",
"\tmargin: 0;\n",
"}\n",
".comparison-widget {\n",
"\twidth: 256px;\n",
"\theight: 256px;\n",
"\tborder: 1px solid gray;\n",
"\tmargin-left: 2px;\n",
"}\n",
".comparison-label {\n",
"\tcolor: gray;\n",
"\tfont-size: 14px;\n",
"\ttext-align: center;\n",
"\tposition: relative;\n",
"\tbottom: 3px;\n",
"}\n",
"@keyframes spin {\n",
"\tfrom { transform: rotate(0deg); }\n",
"\tto { transform: rotate(360deg); }\n",
"}\n",
"</style>\n",
"\"\"\"))\n",
"\n",
"def thumbnail(file):\n",
"\treturn imageio.get_reader(file, mode='I', format='FFMPEG').get_next_data()\n",
"\n",
"def create_image(i, j):\n",
"\timage_widget = ipywidgets.Image.from_file('demo/images/%d%d.png' % (i, j))\n",
"\timage_widget.add_class('resource')\n",
"\timage_widget.add_class('resource-image')\n",
"\timage_widget.add_class('resource-image%d%d' % (i, j))\n",
"\treturn image_widget\n",
"\n",
"def create_video(i):\n",
"\tvideo_widget = ipywidgets.Image(\n",
"\t\tvalue=cv2.imencode('.png', cv2.cvtColor(thumbnail('demo/videos/%d.mp4' % i), cv2.COLOR_RGB2BGR))[1].tostring(),\n",
"\t\tformat='png'\n",
"\t)\n",
"\tvideo_widget.add_class('resource')\n",
"\tvideo_widget.add_class('resource-video')\n",
"\tvideo_widget.add_class('resource-video%d' % i)\n",
"\treturn video_widget\n",
"\n",
"def create_title(title):\n",
"\ttitle_widget = ipywidgets.Label(title)\n",
"\ttitle_widget.add_class('title')\n",
"\treturn title_widget\n",
"\n",
"def download_output(button):\n",
"\tcomplete.layout.display = 'none'\n",
"\tloading.layout.display = ''\n",
"\tfiles.download('output.mp4')\n",
"\tloading.layout.display = 'none'\n",
"\tcomplete.layout.display = ''\n",
"\n",
"def convert_output(button):\n",
"\tcomplete.layout.display = 'none'\n",
"\tloading.layout.display = ''\n",
"\tffmpeg.input('output.mp4').output('scaled.mp4', vf='scale=1080x1080:flags=lanczos,pad=1920:1080:420:0').overwrite_output().run()\n",
"\tfiles.download('scaled.mp4')\n",
"\tloading.layout.display = 'none'\n",
"\tcomplete.layout.display = ''\n",
"\n",
"def back_to_main(button):\n",
"\tcomplete.layout.display = 'none'\n",
"\tmain.layout.display = ''\n",
"\n",
"label_or = ipywidgets.Label('or')\n",
"label_or.add_class('label-or')\n",
"\n",
"image_titles = ['Peoples', 'Cartoons', 'Dolls', 'Game of Thrones', 'Statues']\n",
"image_lengths = [8, 4, 8, 9, 4]\n",
"\n",
"image_tab = ipywidgets.Tab()\n",
"image_tab.children = [ipywidgets.HBox([create_image(i, j) for j in range(length)]) for i, length in enumerate(image_lengths)]\n",
"for i, title in enumerate(image_titles):\n",
"\timage_tab.set_title(i, title)\n",
"\n",
"input_image_widget = ipywidgets.Output()\n",
"input_image_widget.add_class('input-widget')\n",
"upload_input_image_button = ipywidgets.FileUpload(accept='image/*', button_style='primary')\n",
"upload_input_image_button.add_class('input-button')\n",
"image_part = ipywidgets.HBox([\n",
"\tipywidgets.VBox([input_image_widget, upload_input_image_button]),\n",
"\tlabel_or,\n",
"\timage_tab\n",
"])\n",
"\n",
"video_tab = ipywidgets.Tab()\n",
"video_tab.children = [ipywidgets.HBox([create_video(i) for i in range(5)])]\n",
"video_tab.set_title(0, 'All Videos')\n",
"\n",
"input_video_widget = ipywidgets.Output()\n",
"input_video_widget.add_class('input-widget')\n",
"upload_input_video_button = ipywidgets.FileUpload(accept='video/*', button_style='primary')\n",
"upload_input_video_button.add_class('input-button')\n",
"video_part = ipywidgets.HBox([\n",
"\tipywidgets.VBox([input_video_widget, upload_input_video_button]),\n",
"\tlabel_or,\n",
"\tvideo_tab\n",
"])\n",
"\n",
"model = ipywidgets.Dropdown(\n",
"\tdescription=\"Model:\",\n",
"\toptions=[\n",
"\t\t'vox',\n",
"\t\t'vox-adv',\n",
"\t\t'taichi',\n",
"\t\t'taichi-adv',\n",
"\t\t'nemo',\n",
"\t\t'mgif',\n",
"\t\t'fashion',\n",
"\t\t'bair'\n",
"\t]\n",
")\n",
"warning = ipywidgets.HTML('<b>Warning:</b> Upload your own images and videos (see README)')\n",
"warning.add_class('warning')\n",
"model_part = ipywidgets.HBox([model, warning])\n",
"\n",
"relative = ipywidgets.Checkbox(description=\"Relative keypoint displacement (Inherit object proporions from the video)\", value=True)\n",
"adapt_movement_scale = ipywidgets.Checkbox(description=\"Adapt movement scale (Dont touch unless you know want you are doing)\", value=True)\n",
"generate_button = ipywidgets.Button(description=\"Generate\", button_style='primary')\n",
"main = ipywidgets.VBox([\n",
"\tcreate_title('Choose Image'),\n",
"\timage_part,\n",
"\tcreate_title('Choose Video'),\n",
"\tvideo_part,\n",
"\tcreate_title('Settings'),\n",
"\tmodel_part,\n",
"\trelative,\n",
"\tadapt_movement_scale,\n",
"\tgenerate_button\n",
"])\n",
"\n",
"loader = ipywidgets.Label()\n",
"loader.add_class(\"loader\")\n",
"loading_label = ipywidgets.Label(\"This may take several minutes to process…\")\n",
"loading_label.add_class(\"loading-label\")\n",
"progress_bar = ipywidgets.Output()\n",
"loading = ipywidgets.VBox([loader, loading_label, progress_bar])\n",
"loading.add_class('loading')\n",
"\n",
"output_widget = ipywidgets.Output()\n",
"output_widget.add_class('output-widget')\n",
"download = ipywidgets.Button(description='Download', button_style='primary')\n",
"download.add_class('output-button')\n",
"download.on_click(download_output)\n",
"convert = ipywidgets.Button(description='Convert to 1920×1080', button_style='primary')\n",
"convert.add_class('output-button')\n",
"convert.on_click(convert_output)\n",
"back = ipywidgets.Button(description='Back', button_style='primary')\n",
"back.add_class('output-button')\n",
"back.on_click(back_to_main)\n",
"\n",
"comparison_widget = ipywidgets.Output()\n",
"comparison_widget.add_class('comparison-widget')\n",
"comparison_label = ipywidgets.Label('Comparison')\n",
"comparison_label.add_class('comparison-label')\n",
"complete = ipywidgets.HBox([\n",
"\tipywidgets.VBox([output_widget, download, convert, back]),\n",
"\tipywidgets.VBox([comparison_widget, comparison_label])\n",
"])\n",
"\n",
"display(ipywidgets.VBox([main, loading, complete]))\n",
"display(Javascript(\"\"\"\n",
"var images, videos;\n",
"function deselectImages() {\n",
"\timages.forEach(function(item) {\n",
"\t\titem.classList.remove(\"selected\");\n",
"\t});\n",
"}\n",
"function deselectVideos() {\n",
"\tvideos.forEach(function(item) {\n",
"\t\titem.classList.remove(\"selected\");\n",
"\t});\n",
"}\n",
"function invokePython(func) {\n",
"\tgoogle.colab.kernel.invokeFunction(\"notebook.\" + func, [].slice.call(arguments, 1), {});\n",
"}\n",
"setTimeout(function() {\n",
"\t(images = [].slice.call(document.getElementsByClassName(\"resource-image\"))).forEach(function(item) {\n",
"\t\titem.addEventListener(\"click\", function() {\n",
"\t\t\tdeselectImages();\n",
"\t\t\titem.classList.add(\"selected\");\n",
"\t\t\tinvokePython(\"select_image\", item.className.match(/resource-image(\\d\\d)/)[1]);\n",
"\t\t});\n",
"\t});\n",
"\timages[0].classList.add(\"selected\");\n",
"\t(videos = [].slice.call(document.getElementsByClassName(\"resource-video\"))).forEach(function(item) {\n",
"\t\titem.addEventListener(\"click\", function() {\n",
"\t\t\tdeselectVideos();\n",
"\t\t\titem.classList.add(\"selected\");\n",
"\t\t\tinvokePython(\"select_video\", item.className.match(/resource-video(\\d)/)[1]);\n",
"\t\t});\n",
"\t});\n",
"\tvideos[0].classList.add(\"selected\");\n",
"}, 1000);\n",
"\"\"\"))\n",
"\n",
"selected_image = None\n",
"def select_image(filename):\n",
"\tglobal selected_image\n",
"\tselected_image = resize(PIL.Image.open('demo/images/%s.png' % filename).convert(\"RGB\"))\n",
"\tinput_image_widget.clear_output(wait=True)\n",
"\twith input_image_widget:\n",
"\t\tdisplay(HTML('Image'))\n",
"\tinput_image_widget.remove_class('uploaded')\n",
"output.register_callback(\"notebook.select_image\", select_image)\n",
"\n",
"selected_video = None\n",
"def select_video(filename):\n",
"\tglobal selected_video\n",
"\tselected_video = 'demo/videos/%s.mp4' % filename\n",
"\tinput_video_widget.clear_output(wait=True)\n",
"\twith input_video_widget:\n",
"\t\tdisplay(HTML('Video'))\n",
"\tinput_video_widget.remove_class('uploaded')\n",
"output.register_callback(\"notebook.select_video\", select_video)\n",
"\n",
"def resize(image, size=(256, 256)):\n",
"\tw, h = image.size\n",
"\td = min(w, h)\n",
"\tr = ((w - d) // 2, (h - d) // 2, (w + d) // 2, (h + d) // 2)\n",
"\treturn image.resize(size, resample=PIL.Image.LANCZOS, box=r)\n",
"\n",
"def upload_image(change):\n",
"\tglobal selected_image\n",
"\tfor name, file_info in upload_input_image_button.value.items():\n",
"\t\tcontent = file_info['content']\n",
"\tif content is not None:\n",
"\t\tselected_image = resize(PIL.Image.open(io.BytesIO(content)).convert(\"RGB\"))\n",
"\t\tinput_image_widget.clear_output(wait=True)\n",
"\t\twith input_image_widget:\n",
"\t\t\tdisplay(selected_image)\n",
"\t\tinput_image_widget.add_class('uploaded')\n",
"\t\tdisplay(Javascript('deselectImages()'))\n",
"upload_input_image_button.observe(upload_image, names='value')\n",
"\n",
"def upload_video(change):\n",
"\tglobal selected_video\n",
"\tfor name, file_info in upload_input_video_button.value.items():\n",
"\t\tcontent = file_info['content']\n",
"\tif content is not None:\n",
"\t\tselected_video = 'user/' + name\n",
"\t\twith open(selected_video, 'wb') as video:\n",
"\t\t\tvideo.write(content)\n",
"\t\tpreview = resize(PIL.Image.fromarray(thumbnail(selected_video)).convert(\"RGB\"))\n",
"\t\tinput_video_widget.clear_output(wait=True)\n",
"\t\twith input_video_widget:\n",
"\t\t\tdisplay(preview)\n",
"\t\tinput_video_widget.add_class('uploaded')\n",
"\t\tdisplay(Javascript('deselectVideos()'))\n",
"upload_input_video_button.observe(upload_video, names='value')\n",
"\n",
"def change_model(change):\n",
"\tif model.value.startswith('vox'):\n",
"\t\twarning.remove_class('warn')\n",
"\telse:\n",
"\t\twarning.add_class('warn')\n",
"model.observe(change_model, names='value')\n",
"\n",
"def generate(button):\n",
"\tmain.layout.display = 'none'\n",
"\tloading.layout.display = ''\n",
"\tfilename = model.value + ('' if model.value == 'fashion' else '-cpk') + '.pth.tar'\n",
"\tif not os.path.isfile(filename):\n",
"\t\tresponse = requests.get('https://github.com/graphemecluster/first-order-model-demo/releases/download/checkpoints/' + filename, stream=True)\n",
"\t\twith progress_bar:\n",
"\t\t\twith tqdm.wrapattr(response.raw, 'read', total=int(response.headers.get('Content-Length', 0)), unit='B', unit_scale=True, unit_divisor=1024) as raw:\n",
"\t\t\t\twith open(filename, 'wb') as file:\n",
"\t\t\t\t\tcopyfileobj(raw, file)\n",
"\t\tprogress_bar.clear_output()\n",
"\treader = imageio.get_reader(selected_video, mode='I', format='FFMPEG')\n",
"\tfps = reader.get_meta_data()['fps']\n",
"\tdriving_video = []\n",
"\tfor frame in reader:\n",
"\t\tdriving_video.append(frame)\n",
"\tgenerator, kp_detector = load_checkpoints(config_path='config/%s-256.yaml' % model.value, checkpoint_path=filename)\n",
"\twith progress_bar:\n",
"\t\tpredictions = make_animation(\n",
"\t\t\tskimage.transform.resize(numpy.asarray(selected_image), (256, 256)),\n",
"\t\t\t[skimage.transform.resize(frame, (256, 256)) for frame in driving_video],\n",
"\t\t\tgenerator,\n",
"\t\t\tkp_detector,\n",
"\t\t\trelative=relative.value,\n",
"\t\t\tadapt_movement_scale=adapt_movement_scale.value\n",
"\t\t)\n",
"\tprogress_bar.clear_output()\n",
"\timageio.mimsave('output.mp4', [img_as_ubyte(frame) for frame in predictions], fps=fps)\n",
"\ttry:\n",
"\t\twith NamedTemporaryFile(suffix='.mp4') as output:\n",
"\t\t\tffmpeg.output(ffmpeg.input('output.mp4').video, ffmpeg.input(selected_video).audio, output.name, c='copy').run()\n",
"\t\t\twith open('output.mp4', 'wb') as result:\n",
"\t\t\t\tcopyfileobj(output, result)\n",
"\texcept ffmpeg.Error:\n",
"\t\tpass\n",
"\toutput_widget.clear_output(True)\n",
"\twith output_widget:\n",
"\t\tvideo_widget = ipywidgets.Video.from_file('output.mp4', autoplay=False, loop=False)\n",
"\t\tvideo_widget.add_class('video')\n",
"\t\tvideo_widget.add_class('video-left')\n",
"\t\tdisplay(video_widget)\n",
"\tcomparison_widget.clear_output(True)\n",
"\twith comparison_widget:\n",
"\t\tvideo_widget = ipywidgets.Video.from_file(selected_video, autoplay=False, loop=False, controls=False)\n",
"\t\tvideo_widget.add_class('video')\n",
"\t\tvideo_widget.add_class('video-right')\n",
"\t\tdisplay(video_widget)\n",
"\tdisplay(Javascript(\"\"\"\n",
"\tsetTimeout(function() {\n",
"\t\t(function(left, right) {\n",
"\t\t\tleft.addEventListener(\"play\", function() {\n",
"\t\t\t\tright.play();\n",
"\t\t\t});\n",
"\t\t\tleft.addEventListener(\"pause\", function() {\n",
"\t\t\t\tright.pause();\n",
"\t\t\t});\n",
"\t\t\tleft.addEventListener(\"seeking\", function() {\n",
"\t\t\t\tright.currentTime = left.currentTime;\n",
"\t\t\t});\n",
"\t\t\tright.muted = true;\n",
"\t\t})(document.getElementsByClassName(\"video-left\")[0], document.getElementsByClassName(\"video-right\")[0]);\n",
"\t}, 1000);\n",
"\t\"\"\"))\n",
"\tloading.layout.display = 'none'\n",
"\tcomplete.layout.display = ''\n",
"\n",
"generate_button.on_click(generate)\n",
"\n",
"loading.layout.display = 'none'\n",
"complete.layout.display = 'none'\n",
"select_image('00')\n",
"select_video('0')"
],
"execution_count": null,
"outputs": []
}
]
}

@ -0,0 +1,169 @@
#!/user/bin/env python
# coding=utf-8
import uuid
from typing import Optional
import gradio as gr
import matplotlib
matplotlib.use('Agg')
import os
import sys
import yaml
from argparse import ArgumentParser
from tqdm import tqdm
import imageio
import numpy as np
from skimage.transform import resize
from skimage import img_as_ubyte
import torch
from sync_batchnorm import DataParallelWithCallback
from modules.generator import OcclusionAwareGenerator
from modules.keypoint_detector import KPDetector
from animate import normalize_kp
from scipy.spatial import ConvexHull
if sys.version_info[0] < 3:
raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")
def load_checkpoints(config_path, checkpoint_path, cpu=True):
with open(config_path) as f:
config = yaml.load(f, Loader=yaml.FullLoader)
generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
**config['model_params']['common_params'])
if not cpu:
generator.cuda()
kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
**config['model_params']['common_params'])
if not cpu:
kp_detector.cuda()
if cpu:
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
else:
checkpoint = torch.load(checkpoint_path)
generator.load_state_dict(checkpoint['generator'])
kp_detector.load_state_dict(checkpoint['kp_detector'])
if not cpu:
generator = DataParallelWithCallback(generator)
kp_detector = DataParallelWithCallback(kp_detector)
generator.eval()
kp_detector.eval()
return generator, kp_detector
def make_animation(source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True,
cpu=True):
with torch.no_grad():
predictions = []
source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
if not cpu:
source = source.cuda()
driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
kp_source = kp_detector(source)
kp_driving_initial = kp_detector(driving[:, :, 0])
for frame_idx in tqdm(range(driving.shape[2])):
driving_frame = driving[:, :, frame_idx]
if not cpu:
driving_frame = driving_frame.cuda()
kp_driving = kp_detector(driving_frame)
kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale)
out = generator(source, kp_source=kp_source, kp_driving=kp_norm)
predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
return predictions
def find_best_frame(source, driving, cpu=False):
import face_alignment
def normalize_kp(kp):
kp = kp - kp.mean(axis=0, keepdims=True)
area = ConvexHull(kp[:, :2]).volume
area = np.sqrt(area)
kp[:, :2] = kp[:, :2] / area
return kp
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
device='cpu' if cpu else 'cuda')
kp_source = fa.get_landmarks(255 * source)[0]
kp_source = normalize_kp(kp_source)
norm = float('inf')
frame_num = 0
for i, image in tqdm(enumerate(driving)):
kp_driving = fa.get_landmarks(255 * image)[0]
kp_driving = normalize_kp(kp_driving)
new_norm = (np.abs(kp_source - kp_driving) ** 2).sum()
if new_norm < norm:
norm = new_norm
frame_num = i
return frame_num
def h_interface(input_image: np.ndarray):
parser = ArgumentParser()
opt = parser.parse_args()
opt.config = "./config/vox-adv-256.yaml"
opt.checkpoint = "./checkpoints/vox-adv-cpk.pth.tar"
opt.source_image = input_image
opt.driving_video = "./data/chuck.mp4"
opt.result_video = "./data/result.mp4".format(uuid.uuid1().hex)
opt.relative = True
opt.adapt_scale = True
opt.cpu = True
opt.find_best_frame = False
opt.best_frame = False
source_image = opt.source_image
reader = imageio.get_reader(opt.driving_video)
fps = reader.get_meta_data()['fps']
driving_video = []
try:
for im in reader:
driving_video.append(im)
except RuntimeError:
pass
reader.close()
source_image = resize(source_image, (256, 256))[..., :3]
driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
generator, kp_detector = load_checkpoints(config_path=opt.config, checkpoint_path=opt.checkpoint, cpu=opt.cpu)
if opt.find_best_frame or opt.best_frame is not None:
i = opt.best_frame if opt.best_frame is not None else find_best_frame(source_image, driving_video, cpu=opt.cpu)
print("Best frame: " + str(i))
driving_forward = driving_video[i:]
driving_backward = driving_video[:(i + 1)][::-1]
predictions_forward = make_animation(source_image, driving_forward, generator, kp_detector,
relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
predictions_backward = make_animation(source_image, driving_backward, generator, kp_detector,
relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
predictions = predictions_backward[::-1] + predictions_forward[1:]
else:
predictions = make_animation(source_image, driving_video, generator, kp_detector, relative=opt.relative,
adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
imageio.mimsave(opt.result_video, [img_as_ubyte(frame) for frame in predictions], fps=fps)
return opt.result_video
if __name__ == "__main__":
demo = gr.Interface(
fn=h_interface,
inputs=gr.Image(type="numpy", label="Input Image"),
outputs=gr.Video(label="Output Video")
)
demo.launch()

@ -0,0 +1,168 @@
import sys
import yaml
from argparse import ArgumentParser
from tqdm.auto import tqdm
import imageio
import numpy as np
from skimage.transform import resize
from skimage import img_as_ubyte
import torch
from sync_batchnorm import DataParallelWithCallback
from modules.generator import OcclusionAwareGenerator
from modules.keypoint_detector import KPDetector
from animate import normalize_kp
import ffmpeg
from os.path import splitext
from shutil import copyfileobj
from tempfile import NamedTemporaryFile
if sys.version_info[0] < 3:
raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")
def load_checkpoints(config_path, checkpoint_path, cpu=True):
with open(config_path) as f:
config = yaml.full_load(f)
generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
**config['model_params']['common_params'])
if not cpu:
generator.cuda()
kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
**config['model_params']['common_params'])
if not cpu:
kp_detector.cuda()
if cpu:
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
else:
checkpoint = torch.load(checkpoint_path)
generator.load_state_dict(checkpoint['generator'])
kp_detector.load_state_dict(checkpoint['kp_detector'])
if not cpu:
generator = DataParallelWithCallback(generator)
kp_detector = DataParallelWithCallback(kp_detector)
generator.eval()
kp_detector.eval()
return generator, kp_detector
def make_animation(source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=True):
with torch.no_grad():
predictions = []
source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
if not cpu:
source = source.cuda()
driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
kp_source = kp_detector(source)
kp_driving_initial = kp_detector(driving[:, :, 0])
for frame_idx in tqdm(range(driving.shape[2])):
driving_frame = driving[:, :, frame_idx]
if not cpu:
driving_frame = driving_frame.cuda()
kp_driving = kp_detector(driving_frame)
kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale)
out = generator(source, kp_source=kp_source, kp_driving=kp_norm)
predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
return predictions
def find_best_frame(source, driving, cpu=False):
import face_alignment # type: ignore (local file)
from scipy.spatial import ConvexHull
def normalize_kp(kp):
kp = kp - kp.mean(axis=0, keepdims=True)
area = ConvexHull(kp[:, :2]).volume
area = np.sqrt(area)
kp[:, :2] = kp[:, :2] / area
return kp
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
device='cpu' if cpu else 'cuda')
kp_source = fa.get_landmarks(255 * source)[0]
kp_source = normalize_kp(kp_source)
norm = float('inf')
frame_num = 0
for i, image in tqdm(enumerate(driving)):
kp_driving = fa.get_landmarks(255 * image)[0]
kp_driving = normalize_kp(kp_driving)
new_norm = (np.abs(kp_source - kp_driving) ** 2).sum()
if new_norm < norm:
norm = new_norm
frame_num = i
return frame_num
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--config", required=True, help="path to config")
parser.add_argument("--checkpoint", default='vox-cpk.pth.tar', help="path to checkpoint to restore")
parser.add_argument("--source_image", default='sup-mat/source.png', help="path to source image")
parser.add_argument("--driving_video", default='driving.mp4', help="path to driving video")
parser.add_argument("--result_video", default='result.mp4', help="path to output")
parser.add_argument("--relative", dest="relative", action="store_true", help="use relative or absolute keypoint coordinates")
parser.add_argument("--adapt_scale", dest="adapt_scale", action="store_true", help="adapt movement scale based on convex hull of keypoints")
parser.add_argument("--find_best_frame", dest="find_best_frame", action="store_true",
help="Generate from the frame that is the most alligned with source. (Only for faces, requires face_aligment lib)")
parser.add_argument("--best_frame", dest="best_frame", type=int, default=None, help="Set frame to start from.")
parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
parser.add_argument("--audio", dest="audio", action="store_true", help="copy audio to output from the driving video" )
parser.set_defaults(relative=False)
parser.set_defaults(adapt_scale=False)
parser.set_defaults(audio_on=False)
opt = parser.parse_args()
source_image = imageio.imread(opt.source_image)
reader = imageio.get_reader(opt.driving_video)
fps = reader.get_meta_data()['fps']
driving_video = []
try:
for im in reader:
driving_video.append(im)
except RuntimeError:
pass
reader.close()
source_image = resize(source_image, (256, 256))[..., :3]
driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
generator, kp_detector = load_checkpoints(config_path=opt.config, checkpoint_path=opt.checkpoint, cpu=opt.cpu)
if opt.find_best_frame or opt.best_frame is not None:
i = opt.best_frame if opt.best_frame is not None else find_best_frame(source_image, driving_video, cpu=opt.cpu)
print("Best frame: " + str(i))
driving_forward = driving_video[i:]
driving_backward = driving_video[:(i+1)][::-1]
predictions_forward = make_animation(source_image, driving_forward, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
predictions_backward = make_animation(source_image, driving_backward, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
predictions = predictions_backward[::-1] + predictions_forward[1:]
else:
predictions = make_animation(source_image, driving_video, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
imageio.mimsave(opt.result_video, [img_as_ubyte(frame) for frame in predictions], fps=fps)
if opt.audio:
try:
with NamedTemporaryFile(suffix=splitext(opt.result_video)[1]) as output:
ffmpeg.output(ffmpeg.input(opt.result_video).video, ffmpeg.input(opt.driving_video).audio, output.name, c='copy').run()
with open(opt.result_video, 'wb') as result:
copyfileobj(output, result)
except ffmpeg.Error:
print("Failed to copy audio: the driving video may have no audio track or the audio format is invalid.")

@ -0,0 +1,804 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/AliaksandrSiarohin/first-order-model/blob/master/demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a><a href=\"https://kaggle.com/kernels/welcome?src=https://github.com/AliaksandrSiarohin/first-order-model/blob/master/demo.ipynb\" target=\"_parent\"><img alt=\"Kaggle\" title=\"Open in Kaggle\" src=\"https://kaggle.com/static/images/open-in-kaggle.svg\"></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cdO_RxQZLahB"
},
"source": [
"# Demo for paper \"First Order Motion Model for Image Animation\"\n",
"To try the demo, press the 2 play buttons in order and scroll to the bottom. Note that it may take several minutes to load."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "UCMFMJV7K-ag"
},
"outputs": [],
"source": [
"%%capture\n",
"%pip install ffmpeg-python imageio-ffmpeg\n",
"!git init .\n",
"!git remote add origin https://github.com/AliaksandrSiarohin/first-order-model\n",
"!git pull origin master\n",
"!git clone https://github.com/graphemecluster/first-order-model-demo demo"
]
},
{
"cell_type": "code",
"execution_count": 104,
"metadata": {
"id": "Oxi6-riLOgnm"
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"<style>\n",
".widget-box > * {\n",
"\tflex-shrink: 0;\n",
"}\n",
".widget-tab {\n",
"\tmin-width: 0;\n",
"\tflex: 1 1 auto;\n",
"}\n",
".widget-tab .p-TabBar-tabLabel {\n",
"\tfont-size: 15px;\n",
"}\n",
".widget-upload {\n",
"\tbackground-color: tan;\n",
"}\n",
".widget-button {\n",
"\tfont-size: 18px;\n",
"\twidth: 160px;\n",
"\theight: 34px;\n",
"\tline-height: 34px;\n",
"}\n",
".widget-dropdown {\n",
"\twidth: 250px;\n",
"}\n",
".widget-checkbox {\n",
"\twidth: 650px;\n",
"}\n",
".widget-checkbox + .widget-checkbox {\n",
"\tmargin-top: -6px;\n",
"}\n",
".input-widget .output_html {\n",
"\ttext-align: center;\n",
"\twidth: 266px;\n",
"\theight: 266px;\n",
"\tline-height: 266px;\n",
"\tcolor: lightgray;\n",
"\tfont-size: 72px;\n",
"}\n",
".title {\n",
"\tfont-size: 20px;\n",
"\tfont-weight: bold;\n",
"\tmargin: 12px 0 6px 0;\n",
"}\n",
".warning {\n",
"\tdisplay: none;\n",
"\tcolor: red;\n",
"\tmargin-left: 10px;\n",
"}\n",
".warn {\n",
"\tdisplay: initial;\n",
"}\n",
".resource {\n",
"\tcursor: pointer;\n",
"\tborder: 1px solid gray;\n",
"\tmargin: 5px;\n",
"\twidth: 160px;\n",
"\theight: 160px;\n",
"\tmin-width: 160px;\n",
"\tmin-height: 160px;\n",
"\tmax-width: 160px;\n",
"\tmax-height: 160px;\n",
"\t-webkit-box-sizing: initial;\n",
"\tbox-sizing: initial;\n",
"}\n",
".resource:hover {\n",
"\tborder: 6px solid crimson;\n",
"\tmargin: 0;\n",
"}\n",
".selected {\n",
"\tborder: 6px solid seagreen;\n",
"\tmargin: 0;\n",
"}\n",
".input-widget {\n",
"\twidth: 266px;\n",
"\theight: 266px;\n",
"\tborder: 1px solid gray;\n",
"}\n",
".input-button {\n",
"\twidth: 268px;\n",
"\tfont-size: 15px;\n",
"\tmargin: 2px 0 0;\n",
"}\n",
".output-widget {\n",
"\twidth: 256px;\n",
"\theight: 256px;\n",
"\tborder: 1px solid gray;\n",
"}\n",
".output-button {\n",
"\twidth: 258px;\n",
"\tfont-size: 15px;\n",
"\tmargin: 2px 0 0;\n",
"}\n",
".uploaded {\n",
"\twidth: 256px;\n",
"\theight: 256px;\n",
"\tborder: 6px solid seagreen;\n",
"\tmargin: 0;\n",
"}\n",
".label-or {\n",
"\talign-self: center;\n",
"\tfont-size: 20px;\n",
"\tmargin: 16px;\n",
"}\n",
".loading {\n",
"\talign-items: center;\n",
"\twidth: fit-content;\n",
"}\n",
".loader {\n",
"\tmargin: 32px 0 16px 0;\n",
"\twidth: 48px;\n",
"\theight: 48px;\n",
"\tmin-width: 48px;\n",
"\tmin-height: 48px;\n",
"\tmax-width: 48px;\n",
"\tmax-height: 48px;\n",
"\tborder: 4px solid whitesmoke;\n",
"\tborder-top-color: gray;\n",
"\tborder-radius: 50%;\n",
"\tanimation: spin 1.8s linear infinite;\n",
"}\n",
".loading-label {\n",
"\tcolor: gray;\n",
"}\n",
".video {\n",
"\tmargin: 0;\n",
"}\n",
".comparison-widget {\n",
"\twidth: 256px;\n",
"\theight: 256px;\n",
"\tborder: 1px solid gray;\n",
"\tmargin-left: 2px;\n",
"}\n",
".comparison-label {\n",
"\tcolor: gray;\n",
"\tfont-size: 14px;\n",
"\ttext-align: center;\n",
"\tposition: relative;\n",
"\tbottom: 3px;\n",
"}\n",
"@keyframes spin {\n",
"\tfrom { transform: rotate(0deg); }\n",
"\tto { transform: rotate(360deg); }\n",
"}\n",
"</style>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f53c7ccd3ec34f7ea8491237d5bf03ff",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"VBox(children=(VBox(children=(Label(value='Choose Image', _dom_classes=('title',)), HBox(children=(VBox(childr…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"\n",
"var images, videos;\n",
"function deselectImages() {\n",
"\timages.forEach(function(item) {\n",
"\t\titem.classList.remove(\"selected\");\n",
"\t});\n",
"}\n",
"function deselectVideos() {\n",
"\tvideos.forEach(function(item) {\n",
"\t\titem.classList.remove(\"selected\");\n",
"\t});\n",
"}\n",
"function invokePython(func) {\n",
"\tgoogle.colab.kernel.invokeFunction(\"notebook.\" + func, [].slice.call(arguments, 1), {});\n",
"}\n",
"setTimeout(function() {\n",
"\t(images = [].slice.call(document.getElementsByClassName(\"resource-image\"))).forEach(function(item) {\n",
"\t\titem.addEventListener(\"click\", function() {\n",
"\t\t\tdeselectImages();\n",
"\t\t\titem.classList.add(\"selected\");\n",
"\t\t\tinvokePython(\"select_image\", item.className.match(/resource-image(\\d\\d)/)[1]);\n",
"\t\t});\n",
"\t});\n",
"\timages[0].classList.add(\"selected\");\n",
"\t(videos = [].slice.call(document.getElementsByClassName(\"resource-video\"))).forEach(function(item) {\n",
"\t\titem.addEventListener(\"click\", function() {\n",
"\t\t\tdeselectVideos();\n",
"\t\t\titem.classList.add(\"selected\");\n",
"\t\t\tinvokePython(\"select_video\", item.className.match(/resource-video(\\d)/)[1]);\n",
"\t\t});\n",
"\t});\n",
"\tvideos[0].classList.add(\"selected\");\n",
"}, 1000);\n"
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import IPython.display\n",
"import PIL.Image\n",
"import cv2\n",
"import ffmpeg\n",
"import imageio\n",
"import io\n",
"import ipywidgets\n",
"import numpy\n",
"import os.path\n",
"import requests\n",
"import skimage.transform\n",
"import warnings\n",
"from base64 import b64encode\n",
"from demo import load_checkpoints, make_animation # type: ignore (local file)\n",
"from IPython.display import HTML, Javascript\n",
"from shutil import copyfileobj\n",
"from skimage import img_as_ubyte\n",
"from tempfile import NamedTemporaryFile\n",
"import os\n",
"import ipywidgets as ipyw\n",
"from IPython.display import display, FileLink\n",
"warnings.filterwarnings(\"ignore\")\n",
"os.makedirs(\"user\", exist_ok=True)\n",
"\n",
"display(HTML(\"\"\"\n",
"<style>\n",
".widget-box > * {\n",
"\tflex-shrink: 0;\n",
"}\n",
".widget-tab {\n",
"\tmin-width: 0;\n",
"\tflex: 1 1 auto;\n",
"}\n",
".widget-tab .p-TabBar-tabLabel {\n",
"\tfont-size: 15px;\n",
"}\n",
".widget-upload {\n",
"\tbackground-color: tan;\n",
"}\n",
".widget-button {\n",
"\tfont-size: 18px;\n",
"\twidth: 160px;\n",
"\theight: 34px;\n",
"\tline-height: 34px;\n",
"}\n",
".widget-dropdown {\n",
"\twidth: 250px;\n",
"}\n",
".widget-checkbox {\n",
"\twidth: 650px;\n",
"}\n",
".widget-checkbox + .widget-checkbox {\n",
"\tmargin-top: -6px;\n",
"}\n",
".input-widget .output_html {\n",
"\ttext-align: center;\n",
"\twidth: 266px;\n",
"\theight: 266px;\n",
"\tline-height: 266px;\n",
"\tcolor: lightgray;\n",
"\tfont-size: 72px;\n",
"}\n",
".title {\n",
"\tfont-size: 20px;\n",
"\tfont-weight: bold;\n",
"\tmargin: 12px 0 6px 0;\n",
"}\n",
".warning {\n",
"\tdisplay: none;\n",
"\tcolor: red;\n",
"\tmargin-left: 10px;\n",
"}\n",
".warn {\n",
"\tdisplay: initial;\n",
"}\n",
".resource {\n",
"\tcursor: pointer;\n",
"\tborder: 1px solid gray;\n",
"\tmargin: 5px;\n",
"\twidth: 160px;\n",
"\theight: 160px;\n",
"\tmin-width: 160px;\n",
"\tmin-height: 160px;\n",
"\tmax-width: 160px;\n",
"\tmax-height: 160px;\n",
"\t-webkit-box-sizing: initial;\n",
"\tbox-sizing: initial;\n",
"}\n",
".resource:hover {\n",
"\tborder: 6px solid crimson;\n",
"\tmargin: 0;\n",
"}\n",
".selected {\n",
"\tborder: 6px solid seagreen;\n",
"\tmargin: 0;\n",
"}\n",
".input-widget {\n",
"\twidth: 266px;\n",
"\theight: 266px;\n",
"\tborder: 1px solid gray;\n",
"}\n",
".input-button {\n",
"\twidth: 268px;\n",
"\tfont-size: 15px;\n",
"\tmargin: 2px 0 0;\n",
"}\n",
".output-widget {\n",
"\twidth: 256px;\n",
"\theight: 256px;\n",
"\tborder: 1px solid gray;\n",
"}\n",
".output-button {\n",
"\twidth: 258px;\n",
"\tfont-size: 15px;\n",
"\tmargin: 2px 0 0;\n",
"}\n",
".uploaded {\n",
"\twidth: 256px;\n",
"\theight: 256px;\n",
"\tborder: 6px solid seagreen;\n",
"\tmargin: 0;\n",
"}\n",
".label-or {\n",
"\talign-self: center;\n",
"\tfont-size: 20px;\n",
"\tmargin: 16px;\n",
"}\n",
".loading {\n",
"\talign-items: center;\n",
"\twidth: fit-content;\n",
"}\n",
".loader {\n",
"\tmargin: 32px 0 16px 0;\n",
"\twidth: 48px;\n",
"\theight: 48px;\n",
"\tmin-width: 48px;\n",
"\tmin-height: 48px;\n",
"\tmax-width: 48px;\n",
"\tmax-height: 48px;\n",
"\tborder: 4px solid whitesmoke;\n",
"\tborder-top-color: gray;\n",
"\tborder-radius: 50%;\n",
"\tanimation: spin 1.8s linear infinite;\n",
"}\n",
".loading-label {\n",
"\tcolor: gray;\n",
"}\n",
".video {\n",
"\tmargin: 0;\n",
"}\n",
".comparison-widget {\n",
"\twidth: 256px;\n",
"\theight: 256px;\n",
"\tborder: 1px solid gray;\n",
"\tmargin-left: 2px;\n",
"}\n",
".comparison-label {\n",
"\tcolor: gray;\n",
"\tfont-size: 14px;\n",
"\ttext-align: center;\n",
"\tposition: relative;\n",
"\tbottom: 3px;\n",
"}\n",
"@keyframes spin {\n",
"\tfrom { transform: rotate(0deg); }\n",
"\tto { transform: rotate(360deg); }\n",
"}\n",
"</style>\n",
"\"\"\"))\n",
"\n",
"\n",
"def uploaded_file(change):\n",
" save_dir = 'uploads'\n",
" if not os.path.exists(save_dir): os.mkdir(save_dir)\n",
" \n",
" uploads = change['new']\n",
" for upload in uploads:\n",
" filename = upload['name']\n",
" content = upload['content']\n",
" with open(os.path.join(save_dir,filename), 'wb') as f:\n",
" f.write(content)\n",
" with out:\n",
" print(change)\n",
" \n",
"def create_uploader():\n",
" uploader = ipyw.FileUpload(multiple=True)\n",
" display(uploader)\n",
" uploader.description = '📂 Upload'\n",
" uploader.observe(uploaded_file, names='value')\n",
"\n",
"def download_file(filename='./face.mp4') -> HTML:\n",
" fl=FileLink(filename)\n",
" fl.html_link_str =\"<a href='%s' target='' class='downloadLink'>%s</a>\"\n",
" \n",
" display(fl)\n",
" display(HTML(f\"\"\"\n",
"<script>\n",
" var links = document.getElementsByClassName('downloadLink');\n",
" Array.from(links).map(e => e.setAttribute('download', ''))\n",
" var links = document.getElementsByClassName('downloadLink');\n",
" ['data-commandlinker-args','data-commandlinker-command'].map(e => {{ Array.from(links).map(i => {{ i.removeAttribute(e) }} ) }})\n",
" links[0].click()\n",
"</script>\n",
"\"\"\"))\n",
" \n",
"def thumbnail(file):\n",
"\treturn imageio.get_reader(file, mode='I', format='FFMPEG').get_next_data()\n",
"\n",
"def create_image(i, j):\n",
"\timage_widget = ipywidgets.Image.from_file('demo/images/%d%d.png' % (i, j))\n",
"\timage_widget.add_class('resource')\n",
"\timage_widget.add_class('resource-image')\n",
"\timage_widget.add_class('resource-image%d%d' % (i, j))\n",
"\treturn image_widget\n",
"\n",
"def create_video(i):\n",
"\tvideo_widget = ipywidgets.Image(\n",
"\t\tvalue=cv2.imencode('.png', cv2.cvtColor(thumbnail('demo/videos/%d.mp4' % i), cv2.COLOR_RGB2BGR))[1].tostring(),\n",
"\t\tformat='png'\n",
"\t)\n",
"\tvideo_widget.add_class('resource')\n",
"\tvideo_widget.add_class('resource-video')\n",
"\tvideo_widget.add_class('resource-video%d' % i)\n",
"\treturn video_widget\n",
"\n",
"def create_title(title):\n",
"\ttitle_widget = ipywidgets.Label(title)\n",
"\ttitle_widget.add_class('title')\n",
"\treturn title_widget\n",
"\n",
"def download_output(button):\n",
"\tcomplete.layout.display = 'none'\n",
"\tloading.layout.display = ''\n",
"\tdownload_file('./output.mp4')\n",
"\t# files.download('output.mp4')\n",
"\tloading.layout.display = 'none'\n",
"\tcomplete.layout.display = ''\n",
"\n",
"def convert_output(button):\n",
"\tcomplete.layout.display = 'none'\n",
"\tloading.layout.display = ''\n",
"\tffmpeg.input('output.mp4').output('scaled.mp4', vf='scale=1080x1080:flags=lanczos,pad=1920:1080:420:0').overwrite_output().run()\n",
"\tfiles.download('scaled.mp4')\n",
"\tloading.layout.display = 'none'\n",
"\tcomplete.layout.display = ''\n",
"\n",
"def back_to_main(button):\n",
"\tcomplete.layout.display = 'none'\n",
"\tmain.layout.display = ''\n",
"\n",
"label_or = ipywidgets.Label('or')\n",
"label_or.add_class('label-or')\n",
"\n",
"image_titles = ['Peoples', 'Cartoons', 'Dolls', 'Game of Thrones', 'Statues']\n",
"image_lengths = [8, 4, 8, 9, 4]\n",
"\n",
"image_tab = ipywidgets.Tab()\n",
"image_tab.children = [ipywidgets.HBox([create_image(i, j) for j in range(length)]) for i, length in enumerate(image_lengths)]\n",
"for i, title in enumerate(image_titles):\n",
"\timage_tab.set_title(i, title)\n",
"\n",
"input_image_widget = ipywidgets.Output()\n",
"input_image_widget.add_class('input-widget')\n",
"upload_input_image_button = ipywidgets.FileUpload(accept='image/*', button_style='primary')\n",
"upload_input_image_button.add_class('input-button')\n",
"image_part = ipywidgets.HBox([\n",
"\tipywidgets.VBox([input_image_widget, upload_input_image_button]),\n",
"\tlabel_or,\n",
"\timage_tab\n",
"])\n",
"\n",
"video_tab = ipywidgets.Tab()\n",
"video_tab.children = [ipywidgets.HBox([create_video(i) for i in range(5)])]\n",
"video_tab.set_title(0, 'All Videos')\n",
"\n",
"input_video_widget = ipywidgets.Output()\n",
"input_video_widget.add_class('input-widget')\n",
"upload_input_video_button = ipywidgets.FileUpload(accept='video/*', button_style='primary')\n",
"upload_input_video_button.add_class('input-button')\n",
"video_part = ipywidgets.HBox([\n",
"\tipywidgets.VBox([input_video_widget, upload_input_video_button]),\n",
"\tlabel_or,\n",
"\tvideo_tab\n",
"])\n",
"\n",
"model = ipywidgets.Dropdown(\n",
"\tdescription=\"Model:\",\n",
"\toptions=[\n",
"\t\t'vox',\n",
"\t\t'vox-adv',\n",
"\t\t'taichi',\n",
"\t\t'taichi-adv',\n",
"\t\t'nemo',\n",
"\t\t'mgif',\n",
"\t\t'fashion',\n",
"\t\t'bair'\n",
"\t]\n",
")\n",
"warning = ipywidgets.HTML('<b>Warning:</b> Upload your own images and videos (see README)')\n",
"warning.add_class('warning')\n",
"model_part = ipywidgets.HBox([model, warning])\n",
"\n",
"relative = ipywidgets.Checkbox(description=\"Relative keypoint displacement (Inherit object proporions from the video)\", value=True)\n",
"adapt_movement_scale = ipywidgets.Checkbox(description=\"Adapt movement scale (Dont touch unless you know want you are doing)\", value=True)\n",
"generate_button = ipywidgets.Button(description=\"Generate\", button_style='primary')\n",
"main = ipywidgets.VBox([\n",
"\tcreate_title('Choose Image'),\n",
"\timage_part,\n",
"\tcreate_title('Choose Video'),\n",
"\tvideo_part,\n",
"\tcreate_title('Settings'),\n",
"\tmodel_part,\n",
"\trelative,\n",
"\tadapt_movement_scale,\n",
"\tgenerate_button\n",
"])\n",
"\n",
"loader = ipywidgets.Label()\n",
"loader.add_class(\"loader\")\n",
"loading_label = ipywidgets.Label(\"This may take several minutes to process…\")\n",
"loading_label.add_class(\"loading-label\")\n",
"progress_bar = ipywidgets.Output()\n",
"loading = ipywidgets.VBox([loader, loading_label, progress_bar])\n",
"loading.add_class('loading')\n",
"\n",
"output_widget = ipywidgets.Output()\n",
"output_widget.add_class('output-widget')\n",
"download = ipywidgets.Button(description='Download', button_style='primary')\n",
"download.add_class('output-button')\n",
"download.on_click(download_output)\n",
"convert = ipywidgets.Button(description='Convert to 1920×1080', button_style='primary')\n",
"convert.add_class('output-button')\n",
"convert.on_click(convert_output)\n",
"back = ipywidgets.Button(description='Back', button_style='primary')\n",
"back.add_class('output-button')\n",
"back.on_click(back_to_main)\n",
"\n",
"comparison_widget = ipywidgets.Output()\n",
"comparison_widget.add_class('comparison-widget')\n",
"comparison_label = ipywidgets.Label('Comparison')\n",
"comparison_label.add_class('comparison-label')\n",
"complete = ipywidgets.HBox([\n",
"\tipywidgets.VBox([output_widget, download, convert, back]),\n",
"\tipywidgets.VBox([comparison_widget, comparison_label])\n",
"])\n",
"\n",
"display(ipywidgets.VBox([main, loading, complete]))\n",
"display(Javascript(\"\"\"\n",
"var images, videos;\n",
"function deselectImages() {\n",
"\timages.forEach(function(item) {\n",
"\t\titem.classList.remove(\"selected\");\n",
"\t});\n",
"}\n",
"function deselectVideos() {\n",
"\tvideos.forEach(function(item) {\n",
"\t\titem.classList.remove(\"selected\");\n",
"\t});\n",
"}\n",
"function invokePython(func) {\n",
"\tgoogle.colab.kernel.invokeFunction(\"notebook.\" + func, [].slice.call(arguments, 1), {});\n",
"}\n",
"setTimeout(function() {\n",
"\t(images = [].slice.call(document.getElementsByClassName(\"resource-image\"))).forEach(function(item) {\n",
"\t\titem.addEventListener(\"click\", function() {\n",
"\t\t\tdeselectImages();\n",
"\t\t\titem.classList.add(\"selected\");\n",
"\t\t\tinvokePython(\"select_image\", item.className.match(/resource-image(\\d\\d)/)[1]);\n",
"\t\t});\n",
"\t});\n",
"\timages[0].classList.add(\"selected\");\n",
"\t(videos = [].slice.call(document.getElementsByClassName(\"resource-video\"))).forEach(function(item) {\n",
"\t\titem.addEventListener(\"click\", function() {\n",
"\t\t\tdeselectVideos();\n",
"\t\t\titem.classList.add(\"selected\");\n",
"\t\t\tinvokePython(\"select_video\", item.className.match(/resource-video(\\d)/)[1]);\n",
"\t\t});\n",
"\t});\n",
"\tvideos[0].classList.add(\"selected\");\n",
"}, 1000);\n",
"\"\"\"))\n",
"\n",
"selected_image = None\n",
"def select_image(filename):\n",
"\tglobal selected_image\n",
"\tselected_image = resize(PIL.Image.open('demo/images/%s.png' % filename).convert(\"RGB\"))\n",
"\tinput_image_widget.clear_output(wait=True)\n",
"\twith input_image_widget:\n",
"\t\tdisplay(HTML('Image'))\n",
"\tinput_image_widget.remove_class('uploaded')\n",
"# output.register_callback(\"notebook.select_image\", select_image)\n",
"\n",
"selected_video = None\n",
"def select_video(filename):\n",
"\tglobal selected_video\n",
"\tselected_video = 'demo/videos/%s.mp4' % filename\n",
"\tinput_video_widget.clear_output(wait=True)\n",
"\twith input_video_widget:\n",
"\t\tdisplay(HTML('Video'))\n",
"\tinput_video_widget.remove_class('uploaded')\n",
"# output.register_callback(\"notebook.select_video\", select_video)\n",
"\n",
"def resize(image, size=(256, 256)):\n",
"\tw, h = image.size\n",
"\td = min(w, h)\n",
"\tr = ((w - d) // 2, (h - d) // 2, (w + d) // 2, (h + d) // 2)\n",
"\treturn image.resize(size, resample=PIL.Image.LANCZOS, box=r)\n",
"\n",
"def upload_image(change):\n",
"\tglobal selected_image\n",
"\tcontent = upload_input_image_button.value[0]['content']\n",
"\tname = upload_input_image_button.value[0]['name']\n",
" \n",
"\t# for name, file_info in upload_input_image_button.value.items():\n",
"\t\t# content = file_info['content']\n",
"\tif content is not None:\n",
"\t\tselected_image = resize(PIL.Image.open(io.BytesIO(content)).convert(\"RGB\"))\n",
"\t\tinput_image_widget.clear_output(wait=True)\n",
"\t\twith input_image_widget:\n",
"\t\t\tdisplay(selected_image)\n",
"\t\tinput_image_widget.add_class('uploaded')\n",
"\t\tdisplay(Javascript('deselectImages()'))\n",
"upload_input_image_button.observe(upload_image, names='value')\n",
"\n",
"def upload_video(change):\n",
"\tglobal selected_video\n",
"\t# for name, file_info in upload_input_video_button.value.items():\n",
"\t\t# content = file_info['content']\n",
"\tcontent = upload_input_video_button.value[0]['content']\n",
"\tname = upload_input_video_button.value[0]['name']\n",
"\tif content is not None:\n",
"\t\tselected_video = 'user/' + name\n",
"\t\tpreview = resize(PIL.Image.fromarray(thumbnail(content)).convert(\"RGB\"))\n",
"\t\tinput_video_widget.clear_output(wait=True)\n",
"\t\twith input_video_widget:\n",
"\t\t\tdisplay(preview)\n",
"\t\tinput_video_widget.add_class('uploaded')\n",
"\t\tdisplay(Javascript('deselectVideos()'))\n",
"\t\twith open(selected_video, 'wb') as video:\n",
"\t\t\tvideo.write(content)\n",
"upload_input_video_button.observe(upload_video, names='value')\n",
"\n",
"def change_model(change):\n",
"\tif model.value.startswith('vox'):\n",
"\t\twarning.remove_class('warn')\n",
"\telse:\n",
"\t\twarning.add_class('warn')\n",
"model.observe(change_model, names='value')\n",
"\n",
"def generate(button):\n",
"\tmain.layout.display = 'none'\n",
"\tloading.layout.display = ''\n",
"\tfilename = model.value + ('' if model.value == 'fashion' else '-cpk') + '.pth.tar'\n",
"\tif not os.path.isfile(filename):\n",
"\t\tdownload = requests.get(requests.get('https://cloud-api.yandex.net/v1/disk/public/resources/download?public_key=https://yadi.sk/d/lEw8uRm140L_eQ&path=/' + filename).json().get('href'))\n",
"\t\twith open(filename, 'wb') as checkpoint:\n",
"\t\t\tcheckpoint.write(download.content)\n",
"\treader = imageio.get_reader(selected_video, mode='I', format='FFMPEG')\n",
"\tfps = reader.get_meta_data()['fps']\n",
"\tdriving_video = []\n",
"\tfor frame in reader:\n",
"\t\tdriving_video.append(frame)\n",
"\tgenerator, kp_detector = load_checkpoints(config_path='config/%s-256.yaml' % model.value, checkpoint_path=filename)\n",
"\twith progress_bar:\n",
"\t\tpredictions = make_animation(\n",
"\t\t\tskimage.transform.resize(numpy.asarray(selected_image), (256, 256)),\n",
"\t\t\t[skimage.transform.resize(frame, (256, 256)) for frame in driving_video],\n",
"\t\t\tgenerator,\n",
"\t\t\tkp_detector,\n",
"\t\t\trelative=relative.value,\n",
"\t\t\tadapt_movement_scale=adapt_movement_scale.value\n",
"\t\t)\n",
"\tprogress_bar.clear_output()\n",
"\timageio.mimsave('output.mp4', [img_as_ubyte(frame) for frame in predictions], fps=fps)\n",
"\tif selected_video.startswith('user/') or selected_video == 'demo/videos/0.mp4':\n",
"\t\twith NamedTemporaryFile(suffix='.mp4') as output:\n",
"\t\t\tffmpeg.output(ffmpeg.input('output.mp4').video, ffmpeg.input(selected_video).audio, output.name, c='copy').overwrite_output().run()\n",
"\t\t\twith open('output.mp4', 'wb') as result:\n",
"\t\t\t\tcopyfileobj(output, result)\n",
"\twith output_widget:\n",
"\t\tvideo_widget = ipywidgets.Video.from_file('output.mp4', autoplay=False, loop=False)\n",
"\t\tvideo_widget.add_class('video')\n",
"\t\tvideo_widget.add_class('video-left')\n",
"\t\tdisplay(video_widget)\n",
"\twith comparison_widget:\n",
"\t\tvideo_widget = ipywidgets.Video.from_file(selected_video, autoplay=False, loop=False, controls=False)\n",
"\t\tvideo_widget.add_class('video')\n",
"\t\tvideo_widget.add_class('video-right')\n",
"\t\tdisplay(video_widget)\n",
"\tdisplay(Javascript(\"\"\"\n",
"\tsetTimeout(function() {\n",
"\t\t(function(left, right) {\n",
"\t\t\tleft.addEventListener(\"play\", function() {\n",
"\t\t\t\tright.play();\n",
"\t\t\t});\n",
"\t\t\tleft.addEventListener(\"pause\", function() {\n",
"\t\t\t\tright.pause();\n",
"\t\t\t});\n",
"\t\t\tleft.addEventListener(\"seeking\", function() {\n",
"\t\t\t\tright.currentTime = left.currentTime;\n",
"\t\t\t});\n",
"\t\t\tright.muted = true;\n",
"\t\t})(document.getElementsByClassName(\"video-left\")[0], document.getElementsByClassName(\"video-right\")[0]);\n",
"\t}, 1000);\n",
"\t\"\"\"))\n",
"\tloading.layout.display = 'none'\n",
"\tcomplete.layout.display = ''\n",
"\n",
"generate_button.on_click(generate)\n",
"\n",
"loading.layout.display = 'none'\n",
"complete.layout.display = 'none'\n",
"select_image('00')\n",
"select_video('0')"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "first-order-model-demo",
"provenance": []
},
"kernelspec": {
"display_name": "ldm",
"language": "python",
"name": "ldm"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 52 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 242 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1007 KiB

@ -0,0 +1,5 @@
Input Image,Output Video,flag,username,timestamp
flagged\Input Image\04312ccaaa34d31868cf\source_image.png,"{""video"": ""flagged\\Output Video\\4df32a504b02bc01fdbb\\result.mp4"", ""subtitles"": null}",,,2024-06-30 11:24:20.090888
flagged\Input Image\69f77a91286d9746662e\屏幕截图 2024-03-10 184951.png,"{""video"": ""flagged\\Output Video\\c8bb6b7c069b5e9c6003\\result.mp4"", ""subtitles"": null}",,,2024-06-30 11:31:27.486262
,,,,2024-06-30 11:42:15.283795
flagged\Input Image\d7caf8b9c6c17d272f25\h09456_tm.txt.67257e.jpg,,,,2024-06-30 11:49:00.228115
1 Input Image Output Video flag username timestamp
2 flagged\Input Image\04312ccaaa34d31868cf\source_image.png {"video": "flagged\\Output Video\\4df32a504b02bc01fdbb\\result.mp4", "subtitles": null} 2024-06-30 11:24:20.090888
3 flagged\Input Image\69f77a91286d9746662e\屏幕截图 2024-03-10 184951.png {"video": "flagged\\Output Video\\c8bb6b7c069b5e9c6003\\result.mp4", "subtitles": null} 2024-06-30 11:31:27.486262
4 2024-06-30 11:42:15.283795
5 flagged\Input Image\d7caf8b9c6c17d272f25\h09456_tm.txt.67257e.jpg 2024-06-30 11:49:00.228115

@ -0,0 +1,197 @@
import os
from skimage import io, img_as_float32
from skimage.color import gray2rgb
from sklearn.model_selection import train_test_split
from imageio import mimread
import numpy as np
from torch.utils.data import Dataset
import pandas as pd
from augmentation import AllAugmentationTransform
import glob
def read_video(name, frame_shape):
"""
Read video which can be:
- an image of concatenated frames
- '.mp4' and'.gif'
- folder with videos
"""
if os.path.isdir(name):
frames = sorted(os.listdir(name))
num_frames = len(frames)
video_array = np.array(
[img_as_float32(io.imread(os.path.join(name, frames[idx]))) for idx in range(num_frames)])
elif name.lower().endswith('.png') or name.lower().endswith('.jpg'):
image = io.imread(name)
if len(image.shape) == 2 or image.shape[2] == 1:
image = gray2rgb(image)
if image.shape[2] == 4:
image = image[..., :3]
image = img_as_float32(image)
video_array = np.moveaxis(image, 1, 0)
video_array = video_array.reshape((-1,) + frame_shape)
video_array = np.moveaxis(video_array, 1, 2)
elif name.lower().endswith('.gif') or name.lower().endswith('.mp4') or name.lower().endswith('.mov'):
video = np.array(mimread(name))
if len(video.shape) == 3:
video = np.array([gray2rgb(frame) for frame in video])
if video.shape[-1] == 4:
video = video[..., :3]
video_array = img_as_float32(video)
else:
raise Exception("Unknown file extensions %s" % name)
return video_array
class FramesDataset(Dataset):
"""
Dataset of videos, each video can be represented as:
- an image of concatenated frames
- '.mp4' or '.gif'
- folder with all frames
"""
def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True,
random_seed=0, pairs_list=None, augmentation_params=None):
self.root_dir = root_dir
self.videos = os.listdir(root_dir)
self.frame_shape = tuple(frame_shape)
self.pairs_list = pairs_list
self.id_sampling = id_sampling
if os.path.exists(os.path.join(root_dir, 'train')):
assert os.path.exists(os.path.join(root_dir, 'test'))
print("Use predefined train-test split.")
if id_sampling:
train_videos = {os.path.basename(video).split('#')[0] for video in
os.listdir(os.path.join(root_dir, 'train'))}
train_videos = list(train_videos)
else:
train_videos = os.listdir(os.path.join(root_dir, 'train'))
test_videos = os.listdir(os.path.join(root_dir, 'test'))
self.root_dir = os.path.join(self.root_dir, 'train' if is_train else 'test')
else:
print("Use random train-test split.")
train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2)
if is_train:
self.videos = train_videos
else:
self.videos = test_videos
self.is_train = is_train
if self.is_train:
self.transform = AllAugmentationTransform(**augmentation_params)
else:
self.transform = None
def __len__(self):
return len(self.videos)
def __getitem__(self, idx):
if self.is_train and self.id_sampling:
name = self.videos[idx]
path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4')))
else:
name = self.videos[idx]
path = os.path.join(self.root_dir, name)
video_name = os.path.basename(path)
if self.is_train and os.path.isdir(path):
frames = os.listdir(path)
num_frames = len(frames)
frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2))
video_array = [img_as_float32(io.imread(os.path.join(path, frames[idx]))) for idx in frame_idx]
else:
video_array = read_video(path, frame_shape=self.frame_shape)
num_frames = len(video_array)
frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) if self.is_train else range(
num_frames)
video_array = video_array[frame_idx]
if self.transform is not None:
video_array = self.transform(video_array)
out = {}
if self.is_train:
source = np.array(video_array[0], dtype='float32')
driving = np.array(video_array[1], dtype='float32')
out['driving'] = driving.transpose((2, 0, 1))
out['source'] = source.transpose((2, 0, 1))
else:
video = np.array(video_array, dtype='float32')
out['video'] = video.transpose((3, 0, 1, 2))
out['name'] = video_name
return out
class DatasetRepeater(Dataset):
"""
Pass several times over the same dataset for better i/o performance
"""
def __init__(self, dataset, num_repeats=100):
self.dataset = dataset
self.num_repeats = num_repeats
def __len__(self):
return self.num_repeats * self.dataset.__len__()
def __getitem__(self, idx):
return self.dataset[idx % self.dataset.__len__()]
class PairedDataset(Dataset):
"""
Dataset of pairs for animation.
"""
def __init__(self, initial_dataset, number_of_pairs, seed=0):
self.initial_dataset = initial_dataset
pairs_list = self.initial_dataset.pairs_list
np.random.seed(seed)
if pairs_list is None:
max_idx = min(number_of_pairs, len(initial_dataset))
nx, ny = max_idx, max_idx
xy = np.mgrid[:nx, :ny].reshape(2, -1).T
number_of_pairs = min(xy.shape[0], number_of_pairs)
self.pairs = xy.take(np.random.choice(xy.shape[0], number_of_pairs, replace=False), axis=0)
else:
videos = self.initial_dataset.videos
name_to_index = {name: index for index, name in enumerate(videos)}
pairs = pd.read_csv(pairs_list)
pairs = pairs[np.logical_and(pairs['source'].isin(videos), pairs['driving'].isin(videos))]
number_of_pairs = min(pairs.shape[0], number_of_pairs)
self.pairs = []
self.start_frames = []
for ind in range(number_of_pairs):
self.pairs.append(
(name_to_index[pairs['driving'].iloc[ind]], name_to_index[pairs['source'].iloc[ind]]))
def __len__(self):
return len(self.pairs)
def __getitem__(self, idx):
pair = self.pairs[idx]
first = self.initial_dataset[pair[0]]
second = self.initial_dataset[pair[1]]
first = {'driving_' + key: value for key, value in first.items()}
second = {'source_' + key: value for key, value in second.items()}
return {**first, **second}

@ -0,0 +1,211 @@
import numpy as np
import torch
import torch.nn.functional as F
import imageio
import os
from skimage.draw import disk
import matplotlib.pyplot as plt
import collections
class Logger:
def __init__(self, log_dir, checkpoint_freq=100, visualizer_params=None, zfill_num=8, log_file_name='log.txt'):
self.loss_list = []
self.cpk_dir = log_dir
self.visualizations_dir = os.path.join(log_dir, 'train-vis')
if not os.path.exists(self.visualizations_dir):
os.makedirs(self.visualizations_dir)
self.log_file = open(os.path.join(log_dir, log_file_name), 'a')
self.zfill_num = zfill_num
self.visualizer = Visualizer(**visualizer_params)
self.checkpoint_freq = checkpoint_freq
self.epoch = 0
self.best_loss = float('inf')
self.names = None
def log_scores(self, loss_names):
loss_mean = np.array(self.loss_list).mean(axis=0)
loss_string = "; ".join(["%s - %.5f" % (name, value) for name, value in zip(loss_names, loss_mean)])
loss_string = str(self.epoch).zfill(self.zfill_num) + ") " + loss_string
print(loss_string, file=self.log_file)
self.loss_list = []
self.log_file.flush()
def visualize_rec(self, inp, out):
image = self.visualizer.visualize(inp['driving'], inp['source'], out)
imageio.imsave(os.path.join(self.visualizations_dir, "%s-rec.png" % str(self.epoch).zfill(self.zfill_num)), image)
def save_cpk(self, emergent=False):
cpk = {k: v.state_dict() for k, v in self.models.items()}
cpk['epoch'] = self.epoch
cpk_path = os.path.join(self.cpk_dir, '%s-checkpoint.pth.tar' % str(self.epoch).zfill(self.zfill_num))
if not (os.path.exists(cpk_path) and emergent):
torch.save(cpk, cpk_path)
@staticmethod
def load_cpk(checkpoint_path, generator=None, discriminator=None, kp_detector=None,
optimizer_generator=None, optimizer_discriminator=None, optimizer_kp_detector=None):
if torch.cuda.is_available():
map_location = None
else:
map_location = 'cpu'
checkpoint = torch.load(checkpoint_path, map_location)
if generator is not None:
generator.load_state_dict(checkpoint['generator'])
if kp_detector is not None:
kp_detector.load_state_dict(checkpoint['kp_detector'])
if discriminator is not None:
try:
discriminator.load_state_dict(checkpoint['discriminator'])
except:
print ('No discriminator in the state-dict. Dicriminator will be randomly initialized')
if optimizer_generator is not None:
optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])
if optimizer_discriminator is not None:
try:
optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
except RuntimeError as e:
print ('No discriminator optimizer in the state-dict. Optimizer will be not initialized')
if optimizer_kp_detector is not None:
optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector'])
return checkpoint['epoch']
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if 'models' in self.__dict__:
self.save_cpk()
self.log_file.close()
def log_iter(self, losses):
losses = collections.OrderedDict(losses.items())
if self.names is None:
self.names = list(losses.keys())
self.loss_list.append(list(losses.values()))
def log_epoch(self, epoch, models, inp, out):
self.epoch = epoch
self.models = models
if (self.epoch + 1) % self.checkpoint_freq == 0:
self.save_cpk()
self.log_scores(self.names)
self.visualize_rec(inp, out)
class Visualizer:
def __init__(self, kp_size=5, draw_border=False, colormap='gist_rainbow'):
self.kp_size = kp_size
self.draw_border = draw_border
self.colormap = plt.get_cmap(colormap)
def draw_image_with_kp(self, image, kp_array):
image = np.copy(image)
spatial_size = np.array(image.shape[:2][::-1])[np.newaxis]
kp_array = spatial_size * (kp_array + 1) / 2
num_kp = kp_array.shape[0]
for kp_ind, kp in enumerate(kp_array):
rr, cc = disk(kp[1], kp[0], self.kp_size, shape=image.shape[:2])
image[rr, cc] = np.array(self.colormap(kp_ind / num_kp))[:3]
return image
def create_image_column_with_kp(self, images, kp):
image_array = np.array([self.draw_image_with_kp(v, k) for v, k in zip(images, kp)])
return self.create_image_column(image_array)
def create_image_column(self, images):
if self.draw_border:
images = np.copy(images)
images[:, :, [0, -1]] = (1, 1, 1)
return np.concatenate(list(images), axis=0)
def create_image_grid(self, *args):
out = []
for arg in args:
if type(arg) == tuple:
out.append(self.create_image_column_with_kp(arg[0], arg[1]))
else:
out.append(self.create_image_column(arg))
return np.concatenate(out, axis=1)
def visualize(self, driving, source, out):
images = []
# Source image with keypoints
source = source.data.cpu()
kp_source = out['kp_source']['value'].data.cpu().numpy()
source = np.transpose(source, [0, 2, 3, 1])
images.append((source, kp_source))
# Equivariance visualization
if 'transformed_frame' in out:
transformed = out['transformed_frame'].data.cpu().numpy()
transformed = np.transpose(transformed, [0, 2, 3, 1])
transformed_kp = out['transformed_kp']['value'].data.cpu().numpy()
images.append((transformed, transformed_kp))
# Driving image with keypoints
kp_driving = out['kp_driving']['value'].data.cpu().numpy()
driving = driving.data.cpu().numpy()
driving = np.transpose(driving, [0, 2, 3, 1])
images.append((driving, kp_driving))
# Deformed image
if 'deformed' in out:
deformed = out['deformed'].data.cpu().numpy()
deformed = np.transpose(deformed, [0, 2, 3, 1])
images.append(deformed)
# Result with and without keypoints
prediction = out['prediction'].data.cpu().numpy()
prediction = np.transpose(prediction, [0, 2, 3, 1])
if 'kp_norm' in out:
kp_norm = out['kp_norm']['value'].data.cpu().numpy()
images.append((prediction, kp_norm))
images.append(prediction)
## Occlusion map
if 'occlusion_map' in out:
occlusion_map = out['occlusion_map'].data.cpu().repeat(1, 3, 1, 1)
occlusion_map = F.interpolate(occlusion_map, size=source.shape[1:3]).numpy()
occlusion_map = np.transpose(occlusion_map, [0, 2, 3, 1])
images.append(occlusion_map)
# Deformed images according to each individual transform
if 'sparse_deformed' in out:
full_mask = []
for i in range(out['sparse_deformed'].shape[1]):
image = out['sparse_deformed'][:, i].data.cpu()
image = F.interpolate(image, size=source.shape[1:3])
mask = out['mask'][:, i:(i+1)].data.cpu().repeat(1, 3, 1, 1)
mask = F.interpolate(mask, size=source.shape[1:3])
image = np.transpose(image.numpy(), (0, 2, 3, 1))
mask = np.transpose(mask.numpy(), (0, 2, 3, 1))
if i != 0:
color = np.array(self.colormap((i - 1) / (out['sparse_deformed'].shape[1] - 1)))[:3]
else:
color = np.array((0, 0, 0))
color = color.reshape((1, 1, 1, 3))
images.append(image)
if i != 0:
images.append(mask * color)
else:
images.append(mask)
full_mask.append(mask * color)
images.append(sum(full_mask))
image = self.create_image_grid(*images)
image = (255 * image).astype(np.uint8)
return image

@ -0,0 +1,113 @@
from torch import nn
import torch.nn.functional as F
import torch
from modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian
class DenseMotionNetwork(nn.Module):
"""
Module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving
"""
def __init__(self, block_expansion, num_blocks, max_features, num_kp, num_channels, estimate_occlusion_map=False,
scale_factor=1, kp_variance=0.01):
super(DenseMotionNetwork, self).__init__()
self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp + 1) * (num_channels + 1),
max_features=max_features, num_blocks=num_blocks)
self.mask = nn.Conv2d(self.hourglass.out_filters, num_kp + 1, kernel_size=(7, 7), padding=(3, 3))
if estimate_occlusion_map:
self.occlusion = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3))
else:
self.occlusion = None
self.num_kp = num_kp
self.scale_factor = scale_factor
self.kp_variance = kp_variance
if self.scale_factor != 1:
self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)
def create_heatmap_representations(self, source_image, kp_driving, kp_source):
"""
Eq 6. in the paper H_k(z)
"""
spatial_size = source_image.shape[2:]
gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=self.kp_variance)
gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=self.kp_variance)
heatmap = gaussian_driving - gaussian_source
#adding background feature
zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type(heatmap.type())
heatmap = torch.cat([zeros, heatmap], dim=1)
heatmap = heatmap.unsqueeze(2)
return heatmap
def create_sparse_motions(self, source_image, kp_driving, kp_source):
"""
Eq 4. in the paper T_{s<-d}(z)
"""
bs, _, h, w = source_image.shape
identity_grid = make_coordinate_grid((h, w), type=kp_source['value'].type())
identity_grid = identity_grid.view(1, 1, h, w, 2)
coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 2)
if 'jacobian' in kp_driving:
jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian']))
jacobian = jacobian.unsqueeze(-3).unsqueeze(-3)
jacobian = jacobian.repeat(1, 1, h, w, 1, 1)
coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1))
coordinate_grid = coordinate_grid.squeeze(-1)
driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 2)
#adding background feature
identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1)
sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1)
return sparse_motions
def create_deformed_source_image(self, source_image, sparse_motions):
"""
Eq 7. in the paper \hat{T}_{s<-d}(z)
"""
bs, _, h, w = source_image.shape
source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp + 1, 1, 1, 1, 1)
source_repeat = source_repeat.view(bs * (self.num_kp + 1), -1, h, w)
sparse_motions = sparse_motions.view((bs * (self.num_kp + 1), h, w, -1))
sparse_deformed = F.grid_sample(source_repeat, sparse_motions)
sparse_deformed = sparse_deformed.view((bs, self.num_kp + 1, -1, h, w))
return sparse_deformed
def forward(self, source_image, kp_driving, kp_source):
if self.scale_factor != 1:
source_image = self.down(source_image)
bs, _, h, w = source_image.shape
out_dict = dict()
heatmap_representation = self.create_heatmap_representations(source_image, kp_driving, kp_source)
sparse_motion = self.create_sparse_motions(source_image, kp_driving, kp_source)
deformed_source = self.create_deformed_source_image(source_image, sparse_motion)
out_dict['sparse_deformed'] = deformed_source
input = torch.cat([heatmap_representation, deformed_source], dim=2)
input = input.view(bs, -1, h, w)
prediction = self.hourglass(input)
mask = self.mask(prediction)
mask = F.softmax(mask, dim=1)
out_dict['mask'] = mask
mask = mask.unsqueeze(2)
sparse_motion = sparse_motion.permute(0, 1, 4, 2, 3)
deformation = (sparse_motion * mask).sum(dim=1)
deformation = deformation.permute(0, 2, 3, 1)
out_dict['deformation'] = deformation
# Sec. 3.2 in the paper
if self.occlusion:
occlusion_map = torch.sigmoid(self.occlusion(prediction))
out_dict['occlusion_map'] = occlusion_map
return out_dict

@ -0,0 +1,95 @@
from torch import nn
import torch.nn.functional as F
from modules.util import kp2gaussian
import torch
class DownBlock2d(nn.Module):
"""
Simple block for processing video (encoder).
"""
def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
super(DownBlock2d, self).__init__()
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)
if sn:
self.conv = nn.utils.spectral_norm(self.conv)
if norm:
self.norm = nn.InstanceNorm2d(out_features, affine=True)
else:
self.norm = None
self.pool = pool
def forward(self, x):
out = x
out = self.conv(out)
if self.norm:
out = self.norm(out)
out = F.leaky_relu(out, 0.2)
if self.pool:
out = F.avg_pool2d(out, (2, 2))
return out
class Discriminator(nn.Module):
"""
Discriminator similar to Pix2Pix
"""
def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,
sn=False, use_kp=False, num_kp=10, kp_variance=0.01, **kwargs):
super(Discriminator, self).__init__()
down_blocks = []
for i in range(num_blocks):
down_blocks.append(
DownBlock2d(num_channels + num_kp * use_kp if i == 0 else min(max_features, block_expansion * (2 ** i)),
min(max_features, block_expansion * (2 ** (i + 1))),
norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn))
self.down_blocks = nn.ModuleList(down_blocks)
self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)
if sn:
self.conv = nn.utils.spectral_norm(self.conv)
self.use_kp = use_kp
self.kp_variance = kp_variance
def forward(self, x, kp=None):
feature_maps = []
out = x
if self.use_kp:
heatmap = kp2gaussian(kp, x.shape[2:], self.kp_variance)
out = torch.cat([out, heatmap], dim=1)
for down_block in self.down_blocks:
feature_maps.append(down_block(out))
out = feature_maps[-1]
prediction_map = self.conv(out)
return feature_maps, prediction_map
class MultiScaleDiscriminator(nn.Module):
"""
Multi-scale (scale) discriminator
"""
def __init__(self, scales=(), **kwargs):
super(MultiScaleDiscriminator, self).__init__()
self.scales = scales
discs = {}
for scale in scales:
discs[str(scale).replace('.', '-')] = Discriminator(**kwargs)
self.discs = nn.ModuleDict(discs)
def forward(self, x, kp=None):
out_dict = {}
for scale, disc in self.discs.items():
scale = str(scale).replace('-', '.')
key = 'prediction_' + scale
feature_maps, prediction_map = disc(x[key], kp)
out_dict['feature_maps_' + scale] = feature_maps
out_dict['prediction_map_' + scale] = prediction_map
return out_dict

@ -0,0 +1,97 @@
import torch
from torch import nn
import torch.nn.functional as F
from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d
from modules.dense_motion import DenseMotionNetwork
class OcclusionAwareGenerator(nn.Module):
"""
Generator that given source image and and keypoints try to transform image according to movement trajectories
induced by keypoints. Generator follows Johnson architecture.
"""
def __init__(self, num_channels, num_kp, block_expansion, max_features, num_down_blocks,
num_bottleneck_blocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):
super(OcclusionAwareGenerator, self).__init__()
if dense_motion_params is not None:
self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, num_channels=num_channels,
estimate_occlusion_map=estimate_occlusion_map,
**dense_motion_params)
else:
self.dense_motion_network = None
self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3))
down_blocks = []
for i in range(num_down_blocks):
in_features = min(max_features, block_expansion * (2 ** i))
out_features = min(max_features, block_expansion * (2 ** (i + 1)))
down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
self.down_blocks = nn.ModuleList(down_blocks)
up_blocks = []
for i in range(num_down_blocks):
in_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i)))
out_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i - 1)))
up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
self.up_blocks = nn.ModuleList(up_blocks)
self.bottleneck = torch.nn.Sequential()
in_features = min(max_features, block_expansion * (2 ** num_down_blocks))
for i in range(num_bottleneck_blocks):
self.bottleneck.add_module('r' + str(i), ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1)))
self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3))
self.estimate_occlusion_map = estimate_occlusion_map
self.num_channels = num_channels
def deform_input(self, inp, deformation):
_, h_old, w_old, _ = deformation.shape
_, _, h, w = inp.shape
if h_old != h or w_old != w:
deformation = deformation.permute(0, 3, 1, 2)
deformation = F.interpolate(deformation, size=(h, w), mode='bilinear')
deformation = deformation.permute(0, 2, 3, 1)
return F.grid_sample(inp, deformation)
def forward(self, source_image, kp_driving, kp_source):
# Encoding (downsampling) part
out = self.first(source_image)
for i in range(len(self.down_blocks)):
out = self.down_blocks[i](out)
# Transforming feature representation according to deformation and occlusion
output_dict = {}
if self.dense_motion_network is not None:
dense_motion = self.dense_motion_network(source_image=source_image, kp_driving=kp_driving,
kp_source=kp_source)
output_dict['mask'] = dense_motion['mask']
output_dict['sparse_deformed'] = dense_motion['sparse_deformed']
if 'occlusion_map' in dense_motion:
occlusion_map = dense_motion['occlusion_map']
output_dict['occlusion_map'] = occlusion_map
else:
occlusion_map = None
deformation = dense_motion['deformation']
out = self.deform_input(out, deformation)
if occlusion_map is not None:
if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
out = out * occlusion_map
output_dict["deformed"] = self.deform_input(source_image, deformation)
# Decoding part
out = self.bottleneck(out)
for i in range(len(self.up_blocks)):
out = self.up_blocks[i](out)
out = self.final(out)
out = F.sigmoid(out)
output_dict["prediction"] = out
return output_dict

@ -0,0 +1,75 @@
from torch import nn
import torch
import torch.nn.functional as F
from modules.util import Hourglass, make_coordinate_grid, AntiAliasInterpolation2d
class KPDetector(nn.Module):
"""
Detecting a keypoints. Return keypoint position and jacobian near each keypoint.
"""
def __init__(self, block_expansion, num_kp, num_channels, max_features,
num_blocks, temperature, estimate_jacobian=False, scale_factor=1,
single_jacobian_map=False, pad=0):
super(KPDetector, self).__init__()
self.predictor = Hourglass(block_expansion, in_features=num_channels,
max_features=max_features, num_blocks=num_blocks)
self.kp = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=(7, 7),
padding=pad)
if estimate_jacobian:
self.num_jacobian_maps = 1 if single_jacobian_map else num_kp
self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters,
out_channels=4 * self.num_jacobian_maps, kernel_size=(7, 7), padding=pad)
self.jacobian.weight.data.zero_()
self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float))
else:
self.jacobian = None
self.temperature = temperature
self.scale_factor = scale_factor
if self.scale_factor != 1:
self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)
def gaussian2kp(self, heatmap):
"""
Extract the mean and from a heatmap
"""
shape = heatmap.shape
heatmap = heatmap.unsqueeze(-1)
grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0)
value = (heatmap * grid).sum(dim=(2, 3))
kp = {'value': value}
return kp
def forward(self, x):
if self.scale_factor != 1:
x = self.down(x)
feature_map = self.predictor(x)
prediction = self.kp(feature_map)
final_shape = prediction.shape
heatmap = prediction.view(final_shape[0], final_shape[1], -1)
heatmap = F.softmax(heatmap / self.temperature, dim=2)
heatmap = heatmap.view(*final_shape)
out = self.gaussian2kp(heatmap)
if self.jacobian is not None:
jacobian_map = self.jacobian(feature_map)
jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2],
final_shape[3])
heatmap = heatmap.unsqueeze(2)
jacobian = heatmap * jacobian_map
jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1)
jacobian = jacobian.sum(dim=-1)
jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2)
out['jacobian'] = jacobian
return out

@ -0,0 +1,259 @@
from torch import nn
import torch
import torch.nn.functional as F
from modules.util import AntiAliasInterpolation2d, make_coordinate_grid
from torchvision import models
import numpy as np
from torch.autograd import grad
class Vgg19(torch.nn.Module):
"""
Vgg19 network for perceptual loss. See Sec 3.3.
"""
def __init__(self, requires_grad=False):
super(Vgg19, self).__init__()
vgg_pretrained_features = models.vgg19(pretrained=True).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
for x in range(2):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(2, 7):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(7, 12):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(12, 21):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(21, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
requires_grad=False)
self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
requires_grad=False)
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
X = (X - self.mean) / self.std
h_relu1 = self.slice1(X)
h_relu2 = self.slice2(h_relu1)
h_relu3 = self.slice3(h_relu2)
h_relu4 = self.slice4(h_relu3)
h_relu5 = self.slice5(h_relu4)
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
return out
class ImagePyramide(torch.nn.Module):
"""
Create image pyramide for computing pyramide perceptual loss. See Sec 3.3
"""
def __init__(self, scales, num_channels):
super(ImagePyramide, self).__init__()
downs = {}
for scale in scales:
downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
self.downs = nn.ModuleDict(downs)
def forward(self, x):
out_dict = {}
for scale, down_module in self.downs.items():
out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
return out_dict
class Transform:
"""
Random tps transformation for equivariance constraints. See Sec 3.3
"""
def __init__(self, bs, **kwargs):
noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3]))
self.theta = noise + torch.eye(2, 3).view(1, 2, 3)
self.bs = bs
if ('sigma_tps' in kwargs) and ('points_tps' in kwargs):
self.tps = True
self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type())
self.control_points = self.control_points.unsqueeze(0)
self.control_params = torch.normal(mean=0,
std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2]))
else:
self.tps = False
def transform_frame(self, frame):
grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0)
grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
grid = self.warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2)
return F.grid_sample(frame, grid, padding_mode="reflection")
def warp_coordinates(self, coordinates):
theta = self.theta.type(coordinates.type())
theta = theta.unsqueeze(1)
transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]
transformed = transformed.squeeze(-1)
if self.tps:
control_points = self.control_points.type(coordinates.type())
control_params = self.control_params.type(coordinates.type())
distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)
distances = torch.abs(distances).sum(-1)
result = distances ** 2
result = result * torch.log(distances + 1e-6)
result = result * control_params
result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)
transformed = transformed + result
return transformed
def jacobian(self, coordinates):
new_coordinates = self.warp_coordinates(coordinates)
grad_x = grad(new_coordinates[..., 0].sum(), coordinates, create_graph=True)
grad_y = grad(new_coordinates[..., 1].sum(), coordinates, create_graph=True)
jacobian = torch.cat([grad_x[0].unsqueeze(-2), grad_y[0].unsqueeze(-2)], dim=-2)
return jacobian
def detach_kp(kp):
return {key: value.detach() for key, value in kp.items()}
class GeneratorFullModel(torch.nn.Module):
"""
Merge all generator related updates into single model for better multi-gpu usage
"""
def __init__(self, kp_extractor, generator, discriminator, train_params):
super(GeneratorFullModel, self).__init__()
self.kp_extractor = kp_extractor
self.generator = generator
self.discriminator = discriminator
self.train_params = train_params
self.scales = train_params['scales']
self.disc_scales = self.discriminator.scales
self.pyramid = ImagePyramide(self.scales, generator.num_channels)
if torch.cuda.is_available():
self.pyramid = self.pyramid.cuda()
self.loss_weights = train_params['loss_weights']
if sum(self.loss_weights['perceptual']) != 0:
self.vgg = Vgg19()
if torch.cuda.is_available():
self.vgg = self.vgg.cuda()
def forward(self, x):
kp_source = self.kp_extractor(x['source'])
kp_driving = self.kp_extractor(x['driving'])
generated = self.generator(x['source'], kp_source=kp_source, kp_driving=kp_driving)
generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})
loss_values = {}
pyramide_real = self.pyramid(x['driving'])
pyramide_generated = self.pyramid(generated['prediction'])
if sum(self.loss_weights['perceptual']) != 0:
value_total = 0
for scale in self.scales:
x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])
for i, weight in enumerate(self.loss_weights['perceptual']):
value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
value_total += self.loss_weights['perceptual'][i] * value
loss_values['perceptual'] = value_total
if self.loss_weights['generator_gan'] != 0:
discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))
discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))
value_total = 0
for scale in self.disc_scales:
key = 'prediction_map_%s' % scale
value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
value_total += self.loss_weights['generator_gan'] * value
loss_values['gen_gan'] = value_total
if sum(self.loss_weights['feature_matching']) != 0:
value_total = 0
for scale in self.disc_scales:
key = 'feature_maps_%s' % scale
for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])):
if self.loss_weights['feature_matching'][i] == 0:
continue
value = torch.abs(a - b).mean()
value_total += self.loss_weights['feature_matching'][i] * value
loss_values['feature_matching'] = value_total
if (self.loss_weights['equivariance_value'] + self.loss_weights['equivariance_jacobian']) != 0:
transform = Transform(x['driving'].shape[0], **self.train_params['transform_params'])
transformed_frame = transform.transform_frame(x['driving'])
transformed_kp = self.kp_extractor(transformed_frame)
generated['transformed_frame'] = transformed_frame
generated['transformed_kp'] = transformed_kp
## Value loss part
if self.loss_weights['equivariance_value'] != 0:
value = torch.abs(kp_driving['value'] - transform.warp_coordinates(transformed_kp['value'])).mean()
loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value
## jacobian loss part
if self.loss_weights['equivariance_jacobian'] != 0:
jacobian_transformed = torch.matmul(transform.jacobian(transformed_kp['value']),
transformed_kp['jacobian'])
normed_driving = torch.inverse(kp_driving['jacobian'])
normed_transformed = jacobian_transformed
value = torch.matmul(normed_driving, normed_transformed)
eye = torch.eye(2).view(1, 1, 2, 2).type(value.type())
value = torch.abs(eye - value).mean()
loss_values['equivariance_jacobian'] = self.loss_weights['equivariance_jacobian'] * value
return loss_values, generated
class DiscriminatorFullModel(torch.nn.Module):
"""
Merge all discriminator related updates into single model for better multi-gpu usage
"""
def __init__(self, kp_extractor, generator, discriminator, train_params):
super(DiscriminatorFullModel, self).__init__()
self.kp_extractor = kp_extractor
self.generator = generator
self.discriminator = discriminator
self.train_params = train_params
self.scales = self.discriminator.scales
self.pyramid = ImagePyramide(self.scales, generator.num_channels)
if torch.cuda.is_available():
self.pyramid = self.pyramid.cuda()
self.loss_weights = train_params['loss_weights']
def forward(self, x, generated):
pyramide_real = self.pyramid(x['driving'])
pyramide_generated = self.pyramid(generated['prediction'].detach())
kp_driving = generated['kp_driving']
discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))
discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))
loss_values = {}
value_total = 0
for scale in self.scales:
key = 'prediction_map_%s' % scale
value = (1 - discriminator_maps_real[key]) ** 2 + discriminator_maps_generated[key] ** 2
value_total += self.loss_weights['discriminator_gan'] * value.mean()
loss_values['disc_gan'] = value_total
return loss_values

@ -0,0 +1,245 @@
from torch import nn
import torch.nn.functional as F
import torch
from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
def kp2gaussian(kp, spatial_size, kp_variance):
"""
Transform a keypoint into gaussian like representation
"""
mean = kp['value']
coordinate_grid = make_coordinate_grid(spatial_size, mean.type())
number_of_leading_dimensions = len(mean.shape) - 1
shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
coordinate_grid = coordinate_grid.view(*shape)
repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1)
coordinate_grid = coordinate_grid.repeat(*repeats)
# Preprocess kp shape
shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 2)
mean = mean.view(*shape)
mean_sub = (coordinate_grid - mean)
out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
return out
def make_coordinate_grid(spatial_size, type):
"""
Create a meshgrid [-1,1] x [-1,1] of given spatial_size.
"""
h, w = spatial_size
x = torch.arange(w).type(type)
y = torch.arange(h).type(type)
x = (2 * (x / (w - 1)) - 1)
y = (2 * (y / (h - 1)) - 1)
yy = y.view(-1, 1).repeat(1, w)
xx = x.view(1, -1).repeat(h, 1)
meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
return meshed
class ResBlock2d(nn.Module):
"""
Res block, preserve spatial resolution.
"""
def __init__(self, in_features, kernel_size, padding):
super(ResBlock2d, self).__init__()
self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
padding=padding)
self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
padding=padding)
self.norm1 = BatchNorm2d(in_features, affine=True)
self.norm2 = BatchNorm2d(in_features, affine=True)
def forward(self, x):
out = self.norm1(x)
out = F.relu(out)
out = self.conv1(out)
out = self.norm2(out)
out = F.relu(out)
out = self.conv2(out)
out += x
return out
class UpBlock2d(nn.Module):
"""
Upsampling block for use in decoder.
"""
def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
super(UpBlock2d, self).__init__()
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
padding=padding, groups=groups)
self.norm = BatchNorm2d(out_features, affine=True)
def forward(self, x):
out = F.interpolate(x, scale_factor=2)
out = self.conv(out)
out = self.norm(out)
out = F.relu(out)
return out
class DownBlock2d(nn.Module):
"""
Downsampling block for use in encoder.
"""
def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
super(DownBlock2d, self).__init__()
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
padding=padding, groups=groups)
self.norm = BatchNorm2d(out_features, affine=True)
self.pool = nn.AvgPool2d(kernel_size=(2, 2))
def forward(self, x):
out = self.conv(x)
out = self.norm(out)
out = F.relu(out)
out = self.pool(out)
return out
class SameBlock2d(nn.Module):
"""
Simple block, preserve spatial resolution.
"""
def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1):
super(SameBlock2d, self).__init__()
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,
kernel_size=kernel_size, padding=padding, groups=groups)
self.norm = BatchNorm2d(out_features, affine=True)
def forward(self, x):
out = self.conv(x)
out = self.norm(out)
out = F.relu(out)
return out
class Encoder(nn.Module):
"""
Hourglass Encoder
"""
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
super(Encoder, self).__init__()
down_blocks = []
for i in range(num_blocks):
down_blocks.append(DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
min(max_features, block_expansion * (2 ** (i + 1))),
kernel_size=3, padding=1))
self.down_blocks = nn.ModuleList(down_blocks)
def forward(self, x):
outs = [x]
for down_block in self.down_blocks:
outs.append(down_block(outs[-1]))
return outs
class Decoder(nn.Module):
"""
Hourglass Decoder
"""
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
super(Decoder, self).__init__()
up_blocks = []
for i in range(num_blocks)[::-1]:
in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
out_filters = min(max_features, block_expansion * (2 ** i))
up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1))
self.up_blocks = nn.ModuleList(up_blocks)
self.out_filters = block_expansion + in_features
def forward(self, x):
out = x.pop()
for up_block in self.up_blocks:
out = up_block(out)
skip = x.pop()
out = torch.cat([out, skip], dim=1)
return out
class Hourglass(nn.Module):
"""
Hourglass architecture.
"""
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
super(Hourglass, self).__init__()
self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
self.out_filters = self.decoder.out_filters
def forward(self, x):
return self.decoder(self.encoder(x))
class AntiAliasInterpolation2d(nn.Module):
"""
Band-limited downsampling, for better preservation of the input signal.
"""
def __init__(self, channels, scale):
super(AntiAliasInterpolation2d, self).__init__()
sigma = (1 / scale - 1) / 2
kernel_size = 2 * round(sigma * 4) + 1
self.ka = kernel_size // 2
self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
kernel_size = [kernel_size, kernel_size]
sigma = [sigma, sigma]
# The gaussian kernel is the product of the
# gaussian function of each dimension.
kernel = 1
meshgrids = torch.meshgrid(
[
torch.arange(size, dtype=torch.float32)
for size in kernel_size
]
)
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
mean = (size - 1) / 2
kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
# Make sure sum of values in gaussian kernel equals 1.
kernel = kernel / torch.sum(kernel)
# Reshape to depthwise convolutional weight
kernel = kernel.view(1, 1, *kernel.size())
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
self.register_buffer('weight', kernel)
self.groups = channels
self.scale = scale
inv_scale = 1 / scale
self.int_inv_scale = int(inv_scale)
def forward(self, input):
if self.scale == 1.0:
return input
out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
out = F.conv2d(out, weight=self.weight, groups=self.groups)
out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]
return out

File diff suppressed because it is too large Load Diff

@ -0,0 +1,67 @@
import os
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from logger import Logger, Visualizer
import numpy as np
import imageio
from sync_batchnorm import DataParallelWithCallback
def reconstruction(config, generator, kp_detector, checkpoint, log_dir, dataset):
png_dir = os.path.join(log_dir, 'reconstruction/png')
log_dir = os.path.join(log_dir, 'reconstruction')
if checkpoint is not None:
Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)
else:
raise AttributeError("Checkpoint should be specified for mode='reconstruction'.")
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
if not os.path.exists(log_dir):
os.makedirs(log_dir)
if not os.path.exists(png_dir):
os.makedirs(png_dir)
loss_list = []
if torch.cuda.is_available():
generator = DataParallelWithCallback(generator)
kp_detector = DataParallelWithCallback(kp_detector)
generator.eval()
kp_detector.eval()
for it, x in tqdm(enumerate(dataloader)):
if config['reconstruction_params']['num_videos'] is not None:
if it > config['reconstruction_params']['num_videos']:
break
with torch.no_grad():
predictions = []
visualizations = []
if torch.cuda.is_available():
x['video'] = x['video'].cuda()
kp_source = kp_detector(x['video'][:, :, 0])
for frame_idx in range(x['video'].shape[2]):
source = x['video'][:, :, 0]
driving = x['video'][:, :, frame_idx]
kp_driving = kp_detector(driving)
out = generator(source, kp_source=kp_source, kp_driving=kp_driving)
out['kp_source'] = kp_source
out['kp_driving'] = kp_driving
del out['sparse_deformed']
predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
visualization = Visualizer(**config['visualizer_params']).visualize(source=source,
driving=driving, out=out)
visualizations.append(visualization)
loss_list.append(torch.abs(out['prediction'] - driving).mean().cpu().numpy())
predictions = np.concatenate(predictions, axis=1)
imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'), (255 * predictions).astype(np.uint8))
image_name = x['name'][0] + config['reconstruction_params']['format']
imageio.mimsave(os.path.join(log_dir, image_name), visualizations)
print("Reconstruction loss: %s" % np.mean(loss_list))

@ -0,0 +1,15 @@
ffmpeg-python==0.2.0
imageio==2.22.0
imageio-ffmpeg==0.4.7
matplotlib==3.6.0
numpy==1.23.3
pandas==1.5.0
python-dateutil==2.8.2
pytz==2022.2.1
PyYAML==6.0
scikit-image==0.19.3
scikit-learn==1.1.2
scipy==1.9.1
torch==1.12.1
torchvision==0.13.1
tqdm==4.64.1

@ -0,0 +1,87 @@
import matplotlib
matplotlib.use('Agg')
import os, sys
import yaml
from argparse import ArgumentParser
from time import gmtime, strftime
from shutil import copy
from frames_dataset import FramesDataset
from modules.generator import OcclusionAwareGenerator
from modules.discriminator import MultiScaleDiscriminator
from modules.keypoint_detector import KPDetector
import torch
from train import train
from reconstruction import reconstruction
from animate import animate
if __name__ == "__main__":
if sys.version_info[0] < 3:
raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")
parser = ArgumentParser()
parser.add_argument("--config", required=True, help="path to config")
parser.add_argument("--mode", default="train", choices=["train", "reconstruction", "animate"])
parser.add_argument("--log_dir", default='log', help="path to log into")
parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore")
parser.add_argument("--device_ids", default="0", type=lambda x: list(map(int, x.split(','))),
help="Names of the devices comma separated.")
parser.add_argument("--verbose", dest="verbose", action="store_true", help="Print model architecture")
parser.set_defaults(verbose=False)
opt = parser.parse_args()
with open(opt.config) as f:
config = yaml.load(f)
if opt.checkpoint is not None:
log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1])
else:
log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0])
log_dir += ' ' + strftime("%d_%m_%y_%H.%M.%S", gmtime())
generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
**config['model_params']['common_params'])
if torch.cuda.is_available():
generator.to(opt.device_ids[0])
if opt.verbose:
print(generator)
discriminator = MultiScaleDiscriminator(**config['model_params']['discriminator_params'],
**config['model_params']['common_params'])
if torch.cuda.is_available():
discriminator.to(opt.device_ids[0])
if opt.verbose:
print(discriminator)
kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
**config['model_params']['common_params'])
if torch.cuda.is_available():
kp_detector.to(opt.device_ids[0])
if opt.verbose:
print(kp_detector)
dataset = FramesDataset(is_train=(opt.mode == 'train'), **config['dataset_params'])
if not os.path.exists(log_dir):
os.makedirs(log_dir)
if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))):
copy(opt.config, log_dir)
if opt.mode == 'train':
print("Training...")
train(config, generator, discriminator, kp_detector, opt.checkpoint, log_dir, dataset, opt.device_ids)
elif opt.mode == 'reconstruction':
print("Reconstruction...")
reconstruction(config, generator, kp_detector, opt.checkpoint, log_dir, dataset)
elif opt.mode == 'animate':
print("Animate...")
animate(config, generator, kp_detector, opt.checkpoint, log_dir, dataset)

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.3 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 14 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 11 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.0 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.3 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 38 MiB

@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
# File : __init__.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
from .replicate import DataParallelWithCallback, patch_replication_callback

@ -0,0 +1,315 @@
# -*- coding: utf-8 -*-
# File : batchnorm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import collections
import torch
import torch.nn.functional as F
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
from .comm import SyncMaster
__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
def _sum_ft(tensor):
"""sum over the first and last dimention"""
return tensor.sum(dim=0).sum(dim=-1)
def _unsqueeze_ft(tensor):
"""add new dementions at the front and the tail"""
return tensor.unsqueeze(0).unsqueeze(-1)
_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
class _SynchronizedBatchNorm(_BatchNorm):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
self._sync_master = SyncMaster(self._data_parallel_master)
self._is_parallel = False
self._parallel_id = None
self._slave_pipe = None
def forward(self, input):
# If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
if not (self._is_parallel and self.training):
return F.batch_norm(
input, self.running_mean, self.running_var, self.weight, self.bias,
self.training, self.momentum, self.eps)
# Resize the input to (B, C, -1).
input_shape = input.size()
input = input.view(input.size(0), self.num_features, -1)
# Compute the sum and square-sum.
sum_size = input.size(0) * input.size(2)
input_sum = _sum_ft(input)
input_ssum = _sum_ft(input ** 2)
# Reduce-and-broadcast the statistics.
if self._parallel_id == 0:
mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
else:
mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
# Compute the output.
if self.affine:
# MJY:: Fuse the multiplication for speed.
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
else:
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
# Reshape it.
return output.view(input_shape)
def __data_parallel_replicate__(self, ctx, copy_id):
self._is_parallel = True
self._parallel_id = copy_id
# parallel_id == 0 means master device.
if self._parallel_id == 0:
ctx.sync_master = self._sync_master
else:
self._slave_pipe = ctx.sync_master.register_slave(copy_id)
def _data_parallel_master(self, intermediates):
"""Reduce the sum and square-sum, compute the statistics, and broadcast it."""
# Always using same "device order" makes the ReduceAdd operation faster.
# Thanks to:: Tete Xiao (http://tetexiao.com/)
intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
to_reduce = [i[1][:2] for i in intermediates]
to_reduce = [j for i in to_reduce for j in i] # flatten
target_gpus = [i[1].sum.get_device() for i in intermediates]
sum_size = sum([i[1].sum_size for i in intermediates])
sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
outputs = []
for i, rec in enumerate(intermediates):
outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
return outputs
def _compute_mean_std(self, sum_, ssum, size):
"""Compute the mean and standard-deviation with sum and square-sum. This method
also maintains the moving average on the master device."""
assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
mean = sum_ / size
sumvar = ssum - sum_ * mean
unbias_var = sumvar / (size - 1)
bias_var = sumvar / size
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
return mean, bias_var.clamp(self.eps) ** -0.5
class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
mini-batch.
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
This module differs from the built-in PyTorch BatchNorm1d as the mean and
standard-deviation are reduced across all devices during training.
For example, when one uses `nn.DataParallel` to wrap the network during
training, PyTorch's implementation normalize the tensor on each device using
the statistics only on that device, which accelerated the computation and
is also easy to implement, but the statistics might be inaccurate.
Instead, in this synchronized version, the statistics will be computed
over all training samples distributed on multiple devices.
Note that, for one-GPU or CPU-only case, this module behaves exactly same
as the built-in PyTorch implementation.
The mean and standard-deviation are calculated per-dimension over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
During training, this layer keeps a running estimate of its computed mean
and variance. The running sum is kept with a default momentum of 0.1.
During evaluation, this running mean/variance is used for normalization.
Because the BatchNorm is done over the `C` dimension, computing statistics
on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
Args:
num_features: num_features from an expected input of size
`batch_size x num_features [x width]`
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Default: 0.1
affine: a boolean value that when set to ``True``, gives the layer learnable
affine parameters. Default: ``True``
Shape:
- Input: :math:`(N, C)` or :math:`(N, C, L)`
- Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
Examples:
>>> # With Learnable Parameters
>>> m = SynchronizedBatchNorm1d(100)
>>> # Without Learnable Parameters
>>> m = SynchronizedBatchNorm1d(100, affine=False)
>>> input = torch.autograd.Variable(torch.randn(20, 100))
>>> output = m(input)
"""
def _check_input_dim(self, input):
if input.dim() != 2 and input.dim() != 3:
raise ValueError('expected 2D or 3D input (got {}D input)'
.format(input.dim()))
super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
of 3d inputs
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
This module differs from the built-in PyTorch BatchNorm2d as the mean and
standard-deviation are reduced across all devices during training.
For example, when one uses `nn.DataParallel` to wrap the network during
training, PyTorch's implementation normalize the tensor on each device using
the statistics only on that device, which accelerated the computation and
is also easy to implement, but the statistics might be inaccurate.
Instead, in this synchronized version, the statistics will be computed
over all training samples distributed on multiple devices.
Note that, for one-GPU or CPU-only case, this module behaves exactly same
as the built-in PyTorch implementation.
The mean and standard-deviation are calculated per-dimension over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
During training, this layer keeps a running estimate of its computed mean
and variance. The running sum is kept with a default momentum of 0.1.
During evaluation, this running mean/variance is used for normalization.
Because the BatchNorm is done over the `C` dimension, computing statistics
on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
Args:
num_features: num_features from an expected input of
size batch_size x num_features x height x width
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Default: 0.1
affine: a boolean value that when set to ``True``, gives the layer learnable
affine parameters. Default: ``True``
Shape:
- Input: :math:`(N, C, H, W)`
- Output: :math:`(N, C, H, W)` (same shape as input)
Examples:
>>> # With Learnable Parameters
>>> m = SynchronizedBatchNorm2d(100)
>>> # Without Learnable Parameters
>>> m = SynchronizedBatchNorm2d(100, affine=False)
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
>>> output = m(input)
"""
def _check_input_dim(self, input):
if input.dim() != 4:
raise ValueError('expected 4D input (got {}D input)'
.format(input.dim()))
super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
of 4d inputs
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
This module differs from the built-in PyTorch BatchNorm3d as the mean and
standard-deviation are reduced across all devices during training.
For example, when one uses `nn.DataParallel` to wrap the network during
training, PyTorch's implementation normalize the tensor on each device using
the statistics only on that device, which accelerated the computation and
is also easy to implement, but the statistics might be inaccurate.
Instead, in this synchronized version, the statistics will be computed
over all training samples distributed on multiple devices.
Note that, for one-GPU or CPU-only case, this module behaves exactly same
as the built-in PyTorch implementation.
The mean and standard-deviation are calculated per-dimension over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
During training, this layer keeps a running estimate of its computed mean
and variance. The running sum is kept with a default momentum of 0.1.
During evaluation, this running mean/variance is used for normalization.
Because the BatchNorm is done over the `C` dimension, computing statistics
on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
or Spatio-temporal BatchNorm
Args:
num_features: num_features from an expected input of
size batch_size x num_features x depth x height x width
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Default: 0.1
affine: a boolean value that when set to ``True``, gives the layer learnable
affine parameters. Default: ``True``
Shape:
- Input: :math:`(N, C, D, H, W)`
- Output: :math:`(N, C, D, H, W)` (same shape as input)
Examples:
>>> # With Learnable Parameters
>>> m = SynchronizedBatchNorm3d(100)
>>> # Without Learnable Parameters
>>> m = SynchronizedBatchNorm3d(100, affine=False)
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
>>> output = m(input)
"""
def _check_input_dim(self, input):
if input.dim() != 5:
raise ValueError('expected 5D input (got {}D input)'
.format(input.dim()))
super(SynchronizedBatchNorm3d, self)._check_input_dim(input)

@ -0,0 +1,137 @@
# -*- coding: utf-8 -*-
# File : comm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import queue
import collections
import threading
__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
class FutureResult(object):
"""A thread-safe future implementation. Used only as one-to-one pipe."""
def __init__(self):
self._result = None
self._lock = threading.Lock()
self._cond = threading.Condition(self._lock)
def put(self, result):
with self._lock:
assert self._result is None, 'Previous result has\'t been fetched.'
self._result = result
self._cond.notify()
def get(self):
with self._lock:
if self._result is None:
self._cond.wait()
res = self._result
self._result = None
return res
_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
class SlavePipe(_SlavePipeBase):
"""Pipe for master-slave communication."""
def run_slave(self, msg):
self.queue.put((self.identifier, msg))
ret = self.result.get()
self.queue.put(True)
return ret
class SyncMaster(object):
"""An abstract `SyncMaster` object.
- During the replication, as the data parallel will trigger an callback of each module, all slave devices should
call `register(id)` and obtain an `SlavePipe` to communicate with the master.
- During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
and passed to a registered callback.
- After receiving the messages, the master device should gather the information and determine to message passed
back to each slave devices.
"""
def __init__(self, master_callback):
"""
Args:
master_callback: a callback to be invoked after having collected messages from slave devices.
"""
self._master_callback = master_callback
self._queue = queue.Queue()
self._registry = collections.OrderedDict()
self._activated = False
def __getstate__(self):
return {'master_callback': self._master_callback}
def __setstate__(self, state):
self.__init__(state['master_callback'])
def register_slave(self, identifier):
"""
Register an slave device.
Args:
identifier: an identifier, usually is the device id.
Returns: a `SlavePipe` object which can be used to communicate with the master device.
"""
if self._activated:
assert self._queue.empty(), 'Queue is not clean before next initialization.'
self._activated = False
self._registry.clear()
future = FutureResult()
self._registry[identifier] = _MasterRegistry(future)
return SlavePipe(identifier, self._queue, future)
def run_master(self, master_msg):
"""
Main entry for the master device in each forward pass.
The messages were first collected from each devices (including the master device), and then
an callback will be invoked to compute the message to be sent back to each devices
(including the master device).
Args:
master_msg: the message that the master want to send to itself. This will be placed as the first
message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
Returns: the message to be sent back to the master device.
"""
self._activated = True
intermediates = [(0, master_msg)]
for i in range(self.nr_slaves):
intermediates.append(self._queue.get())
results = self._master_callback(intermediates)
assert results[0][0] == 0, 'The first result should belongs to the master.'
for i, res in results:
if i == 0:
continue
self._registry[i].result.put(res)
for i in range(self.nr_slaves):
assert self._queue.get() is True
return results[0][1]
@property
def nr_slaves(self):
return len(self._registry)

@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
# File : replicate.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import functools
from torch.nn.parallel.data_parallel import DataParallel
__all__ = [
'CallbackContext',
'execute_replication_callbacks',
'DataParallelWithCallback',
'patch_replication_callback'
]
class CallbackContext(object):
pass
def execute_replication_callbacks(modules):
"""
Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
Note that, as all modules are isomorphism, we assign each sub-module with a context
(shared among multiple copies of this module on different devices).
Through this context, different copies can share some information.
We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
of any slave copies.
"""
master_copy = modules[0]
nr_modules = len(list(master_copy.modules()))
ctxs = [CallbackContext() for _ in range(nr_modules)]
for i, module in enumerate(modules):
for j, m in enumerate(module.modules()):
if hasattr(m, '__data_parallel_replicate__'):
m.__data_parallel_replicate__(ctxs[j], i)
class DataParallelWithCallback(DataParallel):
"""
Data Parallel with a replication callback.
An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
original `replicate` function.
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
Examples:
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
# sync_bn.__data_parallel_replicate__ will be invoked.
"""
def replicate(self, module, device_ids):
modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
execute_replication_callbacks(modules)
return modules
def patch_replication_callback(data_parallel):
"""
Monkey-patch an existing `DataParallel` object. Add the replication callback.
Useful when you have customized `DataParallel` implementation.
Examples:
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
> patch_replication_callback(sync_bn)
# this is equivalent to
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
"""
assert isinstance(data_parallel, DataParallel)
old_replicate = data_parallel.replicate
@functools.wraps(old_replicate)
def new_replicate(module, device_ids):
modules = old_replicate(module, device_ids)
execute_replication_callbacks(modules)
return modules
data_parallel.replicate = new_replicate

@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
# File : unittest.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import unittest
import numpy as np
from torch.autograd import Variable
def as_numpy(v):
if isinstance(v, Variable):
v = v.data
return v.cpu().numpy()
class TorchTestCase(unittest.TestCase):
def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
npa, npb = as_numpy(a), as_numpy(b)
self.assertTrue(
np.allclose(npa, npb, atol=atol),
'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
)

@ -0,0 +1,87 @@
from tqdm import trange
import torch
from torch.utils.data import DataLoader
from logger import Logger
from modules.model import GeneratorFullModel, DiscriminatorFullModel
from torch.optim.lr_scheduler import MultiStepLR
from sync_batchnorm import DataParallelWithCallback
from frames_dataset import DatasetRepeater
def train(config, generator, discriminator, kp_detector, checkpoint, log_dir, dataset, device_ids):
train_params = config['train_params']
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=train_params['lr_generator'], betas=(0.5, 0.999))
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=train_params['lr_discriminator'], betas=(0.5, 0.999))
optimizer_kp_detector = torch.optim.Adam(kp_detector.parameters(), lr=train_params['lr_kp_detector'], betas=(0.5, 0.999))
if checkpoint is not None:
start_epoch = Logger.load_cpk(checkpoint, generator, discriminator, kp_detector,
optimizer_generator, optimizer_discriminator,
None if train_params['lr_kp_detector'] == 0 else optimizer_kp_detector)
else:
start_epoch = 0
scheduler_generator = MultiStepLR(optimizer_generator, train_params['epoch_milestones'], gamma=0.1,
last_epoch=start_epoch - 1)
scheduler_discriminator = MultiStepLR(optimizer_discriminator, train_params['epoch_milestones'], gamma=0.1,
last_epoch=start_epoch - 1)
scheduler_kp_detector = MultiStepLR(optimizer_kp_detector, train_params['epoch_milestones'], gamma=0.1,
last_epoch=-1 + start_epoch * (train_params['lr_kp_detector'] != 0))
if 'num_repeats' in train_params or train_params['num_repeats'] != 1:
dataset = DatasetRepeater(dataset, train_params['num_repeats'])
dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True, num_workers=6, drop_last=True)
generator_full = GeneratorFullModel(kp_detector, generator, discriminator, train_params)
discriminator_full = DiscriminatorFullModel(kp_detector, generator, discriminator, train_params)
if torch.cuda.is_available():
generator_full = DataParallelWithCallback(generator_full, device_ids=device_ids)
discriminator_full = DataParallelWithCallback(discriminator_full, device_ids=device_ids)
with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'], checkpoint_freq=train_params['checkpoint_freq']) as logger:
for epoch in trange(start_epoch, train_params['num_epochs']):
for x in dataloader:
losses_generator, generated = generator_full(x)
loss_values = [val.mean() for val in losses_generator.values()]
loss = sum(loss_values)
loss.backward()
optimizer_generator.step()
optimizer_generator.zero_grad()
optimizer_kp_detector.step()
optimizer_kp_detector.zero_grad()
if train_params['loss_weights']['generator_gan'] != 0:
optimizer_discriminator.zero_grad()
losses_discriminator = discriminator_full(x, generated)
loss_values = [val.mean() for val in losses_discriminator.values()]
loss = sum(loss_values)
loss.backward()
optimizer_discriminator.step()
optimizer_discriminator.zero_grad()
else:
losses_discriminator = {}
losses_generator.update(losses_discriminator)
losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()}
logger.log_iter(losses=losses)
scheduler_generator.step()
scheduler_discriminator.step()
scheduler_kp_detector.step()
logger.log_epoch(epoch, {'generator': generator,
'discriminator': discriminator,
'kp_detector': kp_detector,
'optimizer_generator': optimizer_generator,
'optimizer_discriminator': optimizer_discriminator,
'optimizer_kp_detector': optimizer_kp_detector}, inp=x, out=generated)

Binary file not shown.

@ -0,0 +1,2 @@
body{-webkit-font-smoothing:antialiased;-moz-osx-font-smoothing:grayscale;font-family:-apple-system,BlinkMacSystemFont,Segoe UI,Roboto,Oxygen,Ubuntu,Cantarell,Fira Sans,Droid Sans,Helvetica Neue,sans-serif;margin:0}code{font-family:source-code-pro,Menlo,Monaco,Consolas,Courier New,monospace}
/*# sourceMappingURL=main.e6c13ad2.css.map*/

@ -0,0 +1 @@
{"version":3,"file":"static/css/main.e6c13ad2.css","mappings":"AAAA,KAKE,kCAAmC,CACnC,iCAAkC,CAJlC,mIAEY,CAHZ,QAMF,CAEA,KACE,uEAEF","sources":["index.css"],"sourcesContent":["body {\n margin: 0;\n font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen',\n 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue',\n sans-serif;\n -webkit-font-smoothing: antialiased;\n -moz-osx-font-smoothing: grayscale;\n}\n\ncode {\n font-family: source-code-pro, Menlo, Monaco, Consolas, 'Courier New',\n monospace;\n}\n"],"names":[],"sourceRoot":""}

@ -0,0 +1,2 @@
"use strict";(self.webpackChunkimage_processing_app=self.webpackChunkimage_processing_app||[]).push([[453],{6453:(e,t,n)=>{n.r(t),n.d(t,{getCLS:()=>y,getFCP:()=>g,getFID:()=>C,getLCP:()=>P,getTTFB:()=>D});var i,r,a,o,u=function(e,t){return{name:e,value:void 0===t?-1:t,delta:0,entries:[],id:"v2-".concat(Date.now(),"-").concat(Math.floor(8999999999999*Math.random())+1e12)}},c=function(e,t){try{if(PerformanceObserver.supportedEntryTypes.includes(e)){if("first-input"===e&&!("PerformanceEventTiming"in self))return;var n=new PerformanceObserver((function(e){return e.getEntries().map(t)}));return n.observe({type:e,buffered:!0}),n}}catch(e){}},s=function(e,t){var n=function n(i){"pagehide"!==i.type&&"hidden"!==document.visibilityState||(e(i),t&&(removeEventListener("visibilitychange",n,!0),removeEventListener("pagehide",n,!0)))};addEventListener("visibilitychange",n,!0),addEventListener("pagehide",n,!0)},f=function(e){addEventListener("pageshow",(function(t){t.persisted&&e(t)}),!0)},m=function(e,t,n){var i;return function(r){t.value>=0&&(r||n)&&(t.delta=t.value-(i||0),(t.delta||void 0===i)&&(i=t.value,e(t)))}},p=-1,v=function(){return"hidden"===document.visibilityState?0:1/0},d=function(){s((function(e){var t=e.timeStamp;p=t}),!0)},l=function(){return p<0&&(p=v(),d(),f((function(){setTimeout((function(){p=v(),d()}),0)}))),{get firstHiddenTime(){return p}}},g=function(e,t){var n,i=l(),r=u("FCP"),a=function(e){"first-contentful-paint"===e.name&&(s&&s.disconnect(),e.startTime<i.firstHiddenTime&&(r.value=e.startTime,r.entries.push(e),n(!0)))},o=window.performance&&performance.getEntriesByName&&performance.getEntriesByName("first-contentful-paint")[0],s=o?null:c("paint",a);(o||s)&&(n=m(e,r,t),o&&a(o),f((function(i){r=u("FCP"),n=m(e,r,t),requestAnimationFrame((function(){requestAnimationFrame((function(){r.value=performance.now()-i.timeStamp,n(!0)}))}))})))},h=!1,T=-1,y=function(e,t){h||(g((function(e){T=e.value})),h=!0);var n,i=function(t){T>-1&&e(t)},r=u("CLS",0),a=0,o=[],p=function(e){if(!e.hadRecentInput){var t=o[0],i=o[o.length-1];a&&e.startTime-i.startTime<1e3&&e.startTime-t.startTime<5e3?(a+=e.value,o.push(e)):(a=e.value,o=[e]),a>r.value&&(r.value=a,r.entries=o,n())}},v=c("layout-shift",p);v&&(n=m(i,r,t),s((function(){v.takeRecords().map(p),n(!0)})),f((function(){a=0,T=-1,r=u("CLS",0),n=m(i,r,t)})))},E={passive:!0,capture:!0},w=new Date,L=function(e,t){i||(i=t,r=e,a=new Date,F(removeEventListener),S())},S=function(){if(r>=0&&r<a-w){var e={entryType:"first-input",name:i.type,target:i.target,cancelable:i.cancelable,startTime:i.timeStamp,processingStart:i.timeStamp+r};o.forEach((function(t){t(e)})),o=[]}},b=function(e){if(e.cancelable){var t=(e.timeStamp>1e12?new Date:performance.now())-e.timeStamp;"pointerdown"==e.type?function(e,t){var n=function(){L(e,t),r()},i=function(){r()},r=function(){removeEventListener("pointerup",n,E),removeEventListener("pointercancel",i,E)};addEventListener("pointerup",n,E),addEventListener("pointercancel",i,E)}(t,e):L(t,e)}},F=function(e){["mousedown","keydown","touchstart","pointerdown"].forEach((function(t){return e(t,b,E)}))},C=function(e,t){var n,a=l(),p=u("FID"),v=function(e){e.startTime<a.firstHiddenTime&&(p.value=e.processingStart-e.startTime,p.entries.push(e),n(!0))},d=c("first-input",v);n=m(e,p,t),d&&s((function(){d.takeRecords().map(v),d.disconnect()}),!0),d&&f((function(){var a;p=u("FID"),n=m(e,p,t),o=[],r=-1,i=null,F(addEventListener),a=v,o.push(a),S()}))},k={},P=function(e,t){var n,i=l(),r=u("LCP"),a=function(e){var t=e.startTime;t<i.firstHiddenTime&&(r.value=t,r.entries.push(e),n())},o=c("largest-contentful-paint",a);if(o){n=m(e,r,t);var p=function(){k[r.id]||(o.takeRecords().map(a),o.disconnect(),k[r.id]=!0,n(!0))};["keydown","click"].forEach((function(e){addEventListener(e,p,{once:!0,capture:!0})})),s(p,!0),f((function(i){r=u("LCP"),n=m(e,r,t),requestAnimationFrame((function(){requestAnimationFrame((function(){r.value=performance.now()-i.timeStamp,k[r.id]=!0,n(!0)}))}))}))}},D=function(e){var t,n=u("TTFB");t=function(){try{var t=performance.getEntriesByType("navigation")[0]||function(){var e=performance.timing,t={entryType:"navigation",startTime:0};for(var n in e)"navigationStart"!==n&&"toJSON"!==n&&(t[n]=Math.max(e[n]-e.navigationStart,0));return t}();if(n.value=n.delta=t.responseStart,n.value<0||n.value>performance.now())return;n.entries=[t],e(n)}catch(e){}},"complete"===document.readyState?setTimeout(t,0):addEventListener("load",(function(){return setTimeout(t,0)}))}}}]);
//# sourceMappingURL=453.8beb5808.chunk.js.map

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

@ -0,0 +1,91 @@
/**
* @license React
* react-dom.production.min.js
*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
/**
* @license React
* react-is.production.min.js
*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
/**
* @license React
* react-jsx-runtime.production.min.js
*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
/**
* @license React
* react.production.min.js
*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
/**
* @license React
* scheduler.production.min.js
*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
/**
* @remix-run/router v1.17.0
*
* Copyright (c) Remix Software Inc.
*
* This source code is licensed under the MIT license found in the
* LICENSE.md file in the root directory of this source tree.
*
* @license MIT
*/
/**
* React Router DOM v6.24.0
*
* Copyright (c) Remix Software Inc.
*
* This source code is licensed under the MIT license found in the
* LICENSE.md file in the root directory of this source tree.
*
* @license MIT
*/
/**
* React Router v6.24.0
*
* Copyright (c) Remix Software Inc.
*
* This source code is licensed under the MIT license found in the
* LICENSE.md file in the root directory of this source tree.
*
* @license MIT
*/
/** @license React v16.13.1
* react-is.production.min.js
*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

File diff suppressed because one or more lines are too long
Loading…
Cancel
Save