@ -1,21 +1,18 @@
|
|||||||
# ---> Vim
|
# Build and Release Folders
|
||||||
# Swap
|
bin-debug/
|
||||||
[._]*.s[a-v][a-z]
|
bin-release/
|
||||||
!*.svg # comment out if you don't need vector files
|
[Oo]bj/
|
||||||
[._]*.sw[a-p]
|
[Bb]in/
|
||||||
[._]s[a-rt-v][a-z]
|
|
||||||
[._]ss[a-gi-z]
|
|
||||||
[._]sw[a-p]
|
|
||||||
|
|
||||||
# Session
|
# Other files and folders
|
||||||
Session.vim
|
.settings/
|
||||||
Sessionx.vim
|
|
||||||
|
|
||||||
# Temporary
|
# Executables
|
||||||
.netrwhist
|
*.swf
|
||||||
*~
|
*.air
|
||||||
# Auto-generated tag files
|
*.ipa
|
||||||
tags
|
*.apk
|
||||||
# Persistent undo
|
|
||||||
[._]*.un~
|
|
||||||
|
|
||||||
|
# Project files, i.e. `.project`, `.actionScriptProperties` and `.flexProperties`
|
||||||
|
# should NOT be excluded as they contain compiler settings and other important
|
||||||
|
# information for Eclipse / Flash Builder.
|
||||||
|
@ -0,0 +1,172 @@
|
|||||||
|
from flask import Flask, request, send_file, jsonify, send_from_directory
|
||||||
|
from flask_cors import CORS
|
||||||
|
from PIL import Image, ImageFilter, ImageEnhance, ImageOps
|
||||||
|
import io
|
||||||
|
import uuid
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import yaml
|
||||||
|
import torch
|
||||||
|
import imageio
|
||||||
|
import imageio_ffmpeg as ffmpeg
|
||||||
|
import numpy as np
|
||||||
|
from moviepy.editor import VideoFileClip
|
||||||
|
from skimage.transform import resize
|
||||||
|
from skimage import img_as_ubyte
|
||||||
|
from tqdm import tqdm
|
||||||
|
sys.path.append(os.path.abspath('./firstordermodel'))
|
||||||
|
sys.path.append(".")
|
||||||
|
from modules.generator import OcclusionAwareGenerator
|
||||||
|
from modules.keypoint_detector import KPDetector
|
||||||
|
from animate import normalize_kp
|
||||||
|
|
||||||
|
app = Flask(__name__)
|
||||||
|
CORS(app)
|
||||||
|
|
||||||
|
@app.route('/process-image', methods=['POST'])
|
||||||
|
def process_image():
|
||||||
|
file = request.files['file']
|
||||||
|
operation = request.form['operation']
|
||||||
|
parameter = request.form['parameter']
|
||||||
|
|
||||||
|
image = Image.open(file.stream)
|
||||||
|
|
||||||
|
if operation == 'rotate':
|
||||||
|
image = image.rotate(float(parameter))
|
||||||
|
elif operation == 'flip':
|
||||||
|
image = ImageOps.flip(image)
|
||||||
|
elif operation == 'scale':
|
||||||
|
scale_factor = float(parameter)
|
||||||
|
w, h = image.size
|
||||||
|
new_width = int(w*scale_factor)
|
||||||
|
new_height = int(h*scale_factor)
|
||||||
|
image = image.resize((new_width,new_height), Image.ANTIALIAS)
|
||||||
|
# blank=(new_width-new_height)*scale_factor
|
||||||
|
# image=image.crop((0,-blank,new_width,new_width-blank))
|
||||||
|
elif operation == 'filter':
|
||||||
|
if parameter == 'BLUR':
|
||||||
|
image = image.filter(ImageFilter.BLUR)
|
||||||
|
elif parameter == 'EMBOSS':
|
||||||
|
image = image.filter(ImageFilter.EMBOSS)
|
||||||
|
elif parameter == 'CONTOUR':
|
||||||
|
image = image.filter(ImageFilter.CONTOUR)
|
||||||
|
elif parameter == 'SHARPEN':
|
||||||
|
image = image.filter(ImageFilter.SHARPEN)
|
||||||
|
elif operation == 'color_adjust':
|
||||||
|
r, g, b = map(float, parameter.split(','))
|
||||||
|
r = r if r else 1.0
|
||||||
|
g = g if g else 1.0
|
||||||
|
b = b if b else 1.0
|
||||||
|
r_channel, g_channel, b_channel = image.split()
|
||||||
|
r_channel = r_channel.point(lambda i: i * r)
|
||||||
|
g_channel = g_channel.point(lambda i: i * g)
|
||||||
|
b_channel = b_channel.point(lambda i: i * b)
|
||||||
|
image = Image.merge('RGB', (r_channel, g_channel, b_channel))
|
||||||
|
elif operation == 'contrast':
|
||||||
|
enhancer = ImageEnhance.Contrast(image)
|
||||||
|
image = enhancer.enhance(2)
|
||||||
|
elif operation == 'smooth':
|
||||||
|
image = image.filter(ImageFilter.SMOOTH)
|
||||||
|
|
||||||
|
img_io = io.BytesIO()
|
||||||
|
image.save(img_io, 'JPEG')
|
||||||
|
img_io.seek(0)
|
||||||
|
|
||||||
|
return send_file(img_io, mimetype='image/jpeg')
|
||||||
|
|
||||||
|
|
||||||
|
def load_checkpoints(config_path, checkpoint_path, cpu=True):
|
||||||
|
with open(config_path) as f:
|
||||||
|
config = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
|
|
||||||
|
generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
|
||||||
|
**config['model_params']['common_params'])
|
||||||
|
if not cpu:
|
||||||
|
generator.cuda()
|
||||||
|
|
||||||
|
kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
|
||||||
|
**config['model_params']['common_params'])
|
||||||
|
if not cpu:
|
||||||
|
kp_detector.cuda()
|
||||||
|
|
||||||
|
if cpu:
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||||
|
else:
|
||||||
|
checkpoint = torch.load(checkpoint_path)
|
||||||
|
|
||||||
|
generator.load_state_dict(checkpoint['generator'])
|
||||||
|
kp_detector.load_state_dict(checkpoint['kp_detector'])
|
||||||
|
|
||||||
|
if not cpu:
|
||||||
|
generator = DataParallelWithCallback(generator)
|
||||||
|
kp_detector = DataParallelWithCallback(kp_detector)
|
||||||
|
|
||||||
|
generator.eval()
|
||||||
|
kp_detector.eval()
|
||||||
|
|
||||||
|
return generator, kp_detector
|
||||||
|
|
||||||
|
def make_animation(source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=True):
|
||||||
|
with torch.no_grad():
|
||||||
|
predictions = []
|
||||||
|
source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
|
||||||
|
if not cpu:
|
||||||
|
source = source.cuda()
|
||||||
|
driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
|
||||||
|
kp_source = kp_detector(source)
|
||||||
|
kp_driving_initial = kp_detector(driving[:, :, 0])
|
||||||
|
|
||||||
|
for frame_idx in tqdm(range(driving.shape[2])):
|
||||||
|
driving_frame = driving[:, :, frame_idx]
|
||||||
|
if not cpu:
|
||||||
|
driving_frame = driving_frame.cuda()
|
||||||
|
kp_driving = kp_detector(driving_frame)
|
||||||
|
kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
|
||||||
|
kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
|
||||||
|
use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale)
|
||||||
|
out = generator(source, kp_source=kp_source, kp_driving=kp_norm)
|
||||||
|
|
||||||
|
predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
|
||||||
|
return predictions
|
||||||
|
|
||||||
|
@app.route('/motion-drive', methods=['POST'])
|
||||||
|
def motion_drive():
|
||||||
|
image_file = request.files['image']
|
||||||
|
video_file = request.files['video']
|
||||||
|
|
||||||
|
source_image = imageio.imread(image_file)
|
||||||
|
|
||||||
|
# 保存视频文件到临时路径
|
||||||
|
video_path = f"./data/{uuid.uuid1().hex}.mp4"
|
||||||
|
video_file.save(video_path)
|
||||||
|
|
||||||
|
reader = imageio.get_reader(video_path, 'ffmpeg')
|
||||||
|
fps = reader.get_meta_data()['fps']
|
||||||
|
driving_video = []
|
||||||
|
try:
|
||||||
|
for im in reader:
|
||||||
|
driving_video.append(im)
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
reader.close()
|
||||||
|
|
||||||
|
source_image = resize(source_image, (256, 256))[..., :3]
|
||||||
|
driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
|
||||||
|
config_path = "./firstordermodel/config/vox-adv-256.yaml"
|
||||||
|
checkpoint_path = "./firstordermodel/checkpoints/vox-adv-cpk.pth.tar"
|
||||||
|
generator, kp_detector = load_checkpoints(config_path=config_path, checkpoint_path=checkpoint_path, cpu=True)
|
||||||
|
|
||||||
|
predictions = make_animation(source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=True)
|
||||||
|
|
||||||
|
result_filename = f"result_{uuid.uuid1().hex}.mp4"
|
||||||
|
result_path = os.path.join('data', result_filename)
|
||||||
|
imageio.mimsave(result_path, [img_as_ubyte(frame) for frame in predictions], fps=fps)
|
||||||
|
|
||||||
|
return jsonify({"video_url": f"/data/{result_filename}"})
|
||||||
|
|
||||||
|
@app.route('/data/<path:filename>', methods=['GET'])
|
||||||
|
def download_file(filename):
|
||||||
|
return send_from_directory('data', filename)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app.run(debug=True)
|
@ -0,0 +1,3 @@
|
|||||||
|
/venv
|
||||||
|
.git
|
||||||
|
__pycache__
|
@ -0,0 +1,3 @@
|
|||||||
|
/.vscode
|
||||||
|
__pycache__
|
||||||
|
/venv
|
@ -0,0 +1,13 @@
|
|||||||
|
FROM nvcr.io/nvidia/pytorch:21.02-py3
|
||||||
|
|
||||||
|
RUN DEBIAN_FRONTEND=noninteractive apt-get -qq update \
|
||||||
|
&& DEBIAN_FRONTEND=noninteractive apt-get -qqy install python3-pip ffmpeg git less nano libsm6 libxext6 libxrender-dev \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
COPY . /app/
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
RUN pip3 install --upgrade pip
|
||||||
|
RUN pip3 install \
|
||||||
|
git+https://github.com/1adrianb/face-alignment \
|
||||||
|
-r requirements.txt
|
@ -0,0 +1,101 @@
|
|||||||
|
import os
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from frames_dataset import PairedDataset
|
||||||
|
from logger import Logger, Visualizer
|
||||||
|
import imageio
|
||||||
|
from scipy.spatial import ConvexHull
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from sync_batchnorm import DataParallelWithCallback
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,
|
||||||
|
use_relative_movement=False, use_relative_jacobian=False):
|
||||||
|
if adapt_movement_scale:
|
||||||
|
source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume
|
||||||
|
driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume
|
||||||
|
adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
|
||||||
|
else:
|
||||||
|
adapt_movement_scale = 1
|
||||||
|
|
||||||
|
kp_new = {k: v for k, v in kp_driving.items()}
|
||||||
|
|
||||||
|
if use_relative_movement:
|
||||||
|
kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])
|
||||||
|
kp_value_diff *= adapt_movement_scale
|
||||||
|
kp_new['value'] = kp_value_diff + kp_source['value']
|
||||||
|
|
||||||
|
if use_relative_jacobian:
|
||||||
|
jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))
|
||||||
|
kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])
|
||||||
|
|
||||||
|
return kp_new
|
||||||
|
|
||||||
|
|
||||||
|
def animate(config, generator, kp_detector, checkpoint, log_dir, dataset):
|
||||||
|
log_dir = os.path.join(log_dir, 'animation')
|
||||||
|
png_dir = os.path.join(log_dir, 'png')
|
||||||
|
animate_params = config['animate_params']
|
||||||
|
|
||||||
|
dataset = PairedDataset(initial_dataset=dataset, number_of_pairs=animate_params['num_pairs'])
|
||||||
|
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
|
||||||
|
|
||||||
|
if checkpoint is not None:
|
||||||
|
Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)
|
||||||
|
else:
|
||||||
|
raise AttributeError("Checkpoint should be specified for mode='animate'.")
|
||||||
|
|
||||||
|
if not os.path.exists(log_dir):
|
||||||
|
os.makedirs(log_dir)
|
||||||
|
|
||||||
|
if not os.path.exists(png_dir):
|
||||||
|
os.makedirs(png_dir)
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
generator = DataParallelWithCallback(generator)
|
||||||
|
kp_detector = DataParallelWithCallback(kp_detector)
|
||||||
|
|
||||||
|
generator.eval()
|
||||||
|
kp_detector.eval()
|
||||||
|
|
||||||
|
for it, x in tqdm(enumerate(dataloader)):
|
||||||
|
with torch.no_grad():
|
||||||
|
predictions = []
|
||||||
|
visualizations = []
|
||||||
|
|
||||||
|
driving_video = x['driving_video']
|
||||||
|
source_frame = x['source_video'][:, :, 0, :, :]
|
||||||
|
|
||||||
|
kp_source = kp_detector(source_frame)
|
||||||
|
kp_driving_initial = kp_detector(driving_video[:, :, 0])
|
||||||
|
|
||||||
|
for frame_idx in range(driving_video.shape[2]):
|
||||||
|
driving_frame = driving_video[:, :, frame_idx]
|
||||||
|
kp_driving = kp_detector(driving_frame)
|
||||||
|
kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
|
||||||
|
kp_driving_initial=kp_driving_initial, **animate_params['normalization_params'])
|
||||||
|
out = generator(source_frame, kp_source=kp_source, kp_driving=kp_norm)
|
||||||
|
|
||||||
|
out['kp_driving'] = kp_driving
|
||||||
|
out['kp_source'] = kp_source
|
||||||
|
out['kp_norm'] = kp_norm
|
||||||
|
|
||||||
|
del out['sparse_deformed']
|
||||||
|
|
||||||
|
predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
|
||||||
|
|
||||||
|
visualization = Visualizer(**config['visualizer_params']).visualize(source=source_frame,
|
||||||
|
driving=driving_frame, out=out)
|
||||||
|
visualization = visualization
|
||||||
|
visualizations.append(visualization)
|
||||||
|
|
||||||
|
predictions = np.concatenate(predictions, axis=1)
|
||||||
|
result_name = "-".join([x['driving_name'][0], x['source_name'][0]])
|
||||||
|
imageio.imsave(os.path.join(png_dir, result_name + '.png'), (255 * predictions).astype(np.uint8))
|
||||||
|
|
||||||
|
image_name = result_name + animate_params['format']
|
||||||
|
imageio.mimsave(os.path.join(log_dir, image_name), visualizations)
|
@ -0,0 +1,345 @@
|
|||||||
|
"""
|
||||||
|
Code from https://github.com/hassony2/torch_videovision
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numbers
|
||||||
|
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
import PIL
|
||||||
|
|
||||||
|
from skimage.transform import resize, rotate
|
||||||
|
from numpy import pad
|
||||||
|
import torchvision
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from skimage import img_as_ubyte, img_as_float
|
||||||
|
|
||||||
|
|
||||||
|
def crop_clip(clip, min_h, min_w, h, w):
|
||||||
|
if isinstance(clip[0], np.ndarray):
|
||||||
|
cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]
|
||||||
|
|
||||||
|
elif isinstance(clip[0], PIL.Image.Image):
|
||||||
|
cropped = [
|
||||||
|
img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
raise TypeError('Expected numpy.ndarray or PIL.Image' +
|
||||||
|
'but got list of {0}'.format(type(clip[0])))
|
||||||
|
return cropped
|
||||||
|
|
||||||
|
|
||||||
|
def pad_clip(clip, h, w):
|
||||||
|
im_h, im_w = clip[0].shape[:2]
|
||||||
|
pad_h = (0, 0) if h < im_h else ((h - im_h) // 2, (h - im_h + 1) // 2)
|
||||||
|
pad_w = (0, 0) if w < im_w else ((w - im_w) // 2, (w - im_w + 1) // 2)
|
||||||
|
|
||||||
|
return pad(clip, ((0, 0), pad_h, pad_w, (0, 0)), mode='edge')
|
||||||
|
|
||||||
|
|
||||||
|
def resize_clip(clip, size, interpolation='bilinear'):
|
||||||
|
if isinstance(clip[0], np.ndarray):
|
||||||
|
if isinstance(size, numbers.Number):
|
||||||
|
im_h, im_w, im_c = clip[0].shape
|
||||||
|
# Min spatial dim already matches minimal size
|
||||||
|
if (im_w <= im_h and im_w == size) or (im_h <= im_w
|
||||||
|
and im_h == size):
|
||||||
|
return clip
|
||||||
|
new_h, new_w = get_resize_sizes(im_h, im_w, size)
|
||||||
|
size = (new_w, new_h)
|
||||||
|
else:
|
||||||
|
size = size[1], size[0]
|
||||||
|
|
||||||
|
scaled = [
|
||||||
|
resize(img, size, order=1 if interpolation == 'bilinear' else 0, preserve_range=True,
|
||||||
|
mode='constant', anti_aliasing=True) for img in clip
|
||||||
|
]
|
||||||
|
elif isinstance(clip[0], PIL.Image.Image):
|
||||||
|
if isinstance(size, numbers.Number):
|
||||||
|
im_w, im_h = clip[0].size
|
||||||
|
# Min spatial dim already matches minimal size
|
||||||
|
if (im_w <= im_h and im_w == size) or (im_h <= im_w
|
||||||
|
and im_h == size):
|
||||||
|
return clip
|
||||||
|
new_h, new_w = get_resize_sizes(im_h, im_w, size)
|
||||||
|
size = (new_w, new_h)
|
||||||
|
else:
|
||||||
|
size = size[1], size[0]
|
||||||
|
if interpolation == 'bilinear':
|
||||||
|
pil_inter = PIL.Image.NEAREST
|
||||||
|
else:
|
||||||
|
pil_inter = PIL.Image.BILINEAR
|
||||||
|
scaled = [img.resize(size, pil_inter) for img in clip]
|
||||||
|
else:
|
||||||
|
raise TypeError('Expected numpy.ndarray or PIL.Image' +
|
||||||
|
'but got list of {0}'.format(type(clip[0])))
|
||||||
|
return scaled
|
||||||
|
|
||||||
|
|
||||||
|
def get_resize_sizes(im_h, im_w, size):
|
||||||
|
if im_w < im_h:
|
||||||
|
ow = size
|
||||||
|
oh = int(size * im_h / im_w)
|
||||||
|
else:
|
||||||
|
oh = size
|
||||||
|
ow = int(size * im_w / im_h)
|
||||||
|
return oh, ow
|
||||||
|
|
||||||
|
|
||||||
|
class RandomFlip(object):
|
||||||
|
def __init__(self, time_flip=False, horizontal_flip=False):
|
||||||
|
self.time_flip = time_flip
|
||||||
|
self.horizontal_flip = horizontal_flip
|
||||||
|
|
||||||
|
def __call__(self, clip):
|
||||||
|
if random.random() < 0.5 and self.time_flip:
|
||||||
|
return clip[::-1]
|
||||||
|
if random.random() < 0.5 and self.horizontal_flip:
|
||||||
|
return [np.fliplr(img) for img in clip]
|
||||||
|
|
||||||
|
return clip
|
||||||
|
|
||||||
|
|
||||||
|
class RandomResize(object):
|
||||||
|
"""Resizes a list of (H x W x C) numpy.ndarray to the final size
|
||||||
|
The larger the original image is, the more times it takes to
|
||||||
|
interpolate
|
||||||
|
Args:
|
||||||
|
interpolation (str): Can be one of 'nearest', 'bilinear'
|
||||||
|
defaults to nearest
|
||||||
|
size (tuple): (widht, height)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, ratio=(3. / 4., 4. / 3.), interpolation='nearest'):
|
||||||
|
self.ratio = ratio
|
||||||
|
self.interpolation = interpolation
|
||||||
|
|
||||||
|
def __call__(self, clip):
|
||||||
|
scaling_factor = random.uniform(self.ratio[0], self.ratio[1])
|
||||||
|
|
||||||
|
if isinstance(clip[0], np.ndarray):
|
||||||
|
im_h, im_w, im_c = clip[0].shape
|
||||||
|
elif isinstance(clip[0], PIL.Image.Image):
|
||||||
|
im_w, im_h = clip[0].size
|
||||||
|
|
||||||
|
new_w = int(im_w * scaling_factor)
|
||||||
|
new_h = int(im_h * scaling_factor)
|
||||||
|
new_size = (new_w, new_h)
|
||||||
|
resized = resize_clip(
|
||||||
|
clip, new_size, interpolation=self.interpolation)
|
||||||
|
|
||||||
|
return resized
|
||||||
|
|
||||||
|
|
||||||
|
class RandomCrop(object):
|
||||||
|
"""Extract random crop at the same location for a list of videos
|
||||||
|
Args:
|
||||||
|
size (sequence or int): Desired output size for the
|
||||||
|
crop in format (h, w)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, size):
|
||||||
|
if isinstance(size, numbers.Number):
|
||||||
|
size = (size, size)
|
||||||
|
|
||||||
|
self.size = size
|
||||||
|
|
||||||
|
def __call__(self, clip):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
img (PIL.Image or numpy.ndarray): List of videos to be cropped
|
||||||
|
in format (h, w, c) in numpy.ndarray
|
||||||
|
Returns:
|
||||||
|
PIL.Image or numpy.ndarray: Cropped list of videos
|
||||||
|
"""
|
||||||
|
h, w = self.size
|
||||||
|
if isinstance(clip[0], np.ndarray):
|
||||||
|
im_h, im_w, im_c = clip[0].shape
|
||||||
|
elif isinstance(clip[0], PIL.Image.Image):
|
||||||
|
im_w, im_h = clip[0].size
|
||||||
|
else:
|
||||||
|
raise TypeError('Expected numpy.ndarray or PIL.Image' +
|
||||||
|
'but got list of {0}'.format(type(clip[0])))
|
||||||
|
|
||||||
|
clip = pad_clip(clip, h, w)
|
||||||
|
im_h, im_w = clip.shape[1:3]
|
||||||
|
x1 = 0 if h == im_h else random.randint(0, im_w - w)
|
||||||
|
y1 = 0 if w == im_w else random.randint(0, im_h - h)
|
||||||
|
cropped = crop_clip(clip, y1, x1, h, w)
|
||||||
|
|
||||||
|
return cropped
|
||||||
|
|
||||||
|
|
||||||
|
class RandomRotation(object):
|
||||||
|
"""Rotate entire clip randomly by a random angle within
|
||||||
|
given bounds
|
||||||
|
Args:
|
||||||
|
degrees (sequence or int): Range of degrees to select from
|
||||||
|
If degrees is a number instead of sequence like (min, max),
|
||||||
|
the range of degrees, will be (-degrees, +degrees).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, degrees):
|
||||||
|
if isinstance(degrees, numbers.Number):
|
||||||
|
if degrees < 0:
|
||||||
|
raise ValueError('If degrees is a single number,'
|
||||||
|
'must be positive')
|
||||||
|
degrees = (-degrees, degrees)
|
||||||
|
else:
|
||||||
|
if len(degrees) != 2:
|
||||||
|
raise ValueError('If degrees is a sequence,'
|
||||||
|
'it must be of len 2.')
|
||||||
|
|
||||||
|
self.degrees = degrees
|
||||||
|
|
||||||
|
def __call__(self, clip):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
img (PIL.Image or numpy.ndarray): List of videos to be cropped
|
||||||
|
in format (h, w, c) in numpy.ndarray
|
||||||
|
Returns:
|
||||||
|
PIL.Image or numpy.ndarray: Cropped list of videos
|
||||||
|
"""
|
||||||
|
angle = random.uniform(self.degrees[0], self.degrees[1])
|
||||||
|
if isinstance(clip[0], np.ndarray):
|
||||||
|
rotated = [rotate(image=img, angle=angle, preserve_range=True) for img in clip]
|
||||||
|
elif isinstance(clip[0], PIL.Image.Image):
|
||||||
|
rotated = [img.rotate(angle) for img in clip]
|
||||||
|
else:
|
||||||
|
raise TypeError('Expected numpy.ndarray or PIL.Image' +
|
||||||
|
'but got list of {0}'.format(type(clip[0])))
|
||||||
|
|
||||||
|
return rotated
|
||||||
|
|
||||||
|
|
||||||
|
class ColorJitter(object):
|
||||||
|
"""Randomly change the brightness, contrast and saturation and hue of the clip
|
||||||
|
Args:
|
||||||
|
brightness (float): How much to jitter brightness. brightness_factor
|
||||||
|
is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
|
||||||
|
contrast (float): How much to jitter contrast. contrast_factor
|
||||||
|
is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
|
||||||
|
saturation (float): How much to jitter saturation. saturation_factor
|
||||||
|
is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
|
||||||
|
hue(float): How much to jitter hue. hue_factor is chosen uniformly from
|
||||||
|
[-hue, hue]. Should be >=0 and <= 0.5.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
|
||||||
|
self.brightness = brightness
|
||||||
|
self.contrast = contrast
|
||||||
|
self.saturation = saturation
|
||||||
|
self.hue = hue
|
||||||
|
|
||||||
|
def get_params(self, brightness, contrast, saturation, hue):
|
||||||
|
if brightness > 0:
|
||||||
|
brightness_factor = random.uniform(
|
||||||
|
max(0, 1 - brightness), 1 + brightness)
|
||||||
|
else:
|
||||||
|
brightness_factor = None
|
||||||
|
|
||||||
|
if contrast > 0:
|
||||||
|
contrast_factor = random.uniform(
|
||||||
|
max(0, 1 - contrast), 1 + contrast)
|
||||||
|
else:
|
||||||
|
contrast_factor = None
|
||||||
|
|
||||||
|
if saturation > 0:
|
||||||
|
saturation_factor = random.uniform(
|
||||||
|
max(0, 1 - saturation), 1 + saturation)
|
||||||
|
else:
|
||||||
|
saturation_factor = None
|
||||||
|
|
||||||
|
if hue > 0:
|
||||||
|
hue_factor = random.uniform(-hue, hue)
|
||||||
|
else:
|
||||||
|
hue_factor = None
|
||||||
|
return brightness_factor, contrast_factor, saturation_factor, hue_factor
|
||||||
|
|
||||||
|
def __call__(self, clip):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
clip (list): list of PIL.Image
|
||||||
|
Returns:
|
||||||
|
list PIL.Image : list of transformed PIL.Image
|
||||||
|
"""
|
||||||
|
if isinstance(clip[0], np.ndarray):
|
||||||
|
brightness, contrast, saturation, hue = self.get_params(
|
||||||
|
self.brightness, self.contrast, self.saturation, self.hue)
|
||||||
|
|
||||||
|
# Create img transform function sequence
|
||||||
|
img_transforms = []
|
||||||
|
if brightness is not None:
|
||||||
|
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
|
||||||
|
if saturation is not None:
|
||||||
|
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
|
||||||
|
if hue is not None:
|
||||||
|
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
|
||||||
|
if contrast is not None:
|
||||||
|
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
|
||||||
|
random.shuffle(img_transforms)
|
||||||
|
img_transforms = [img_as_ubyte, torchvision.transforms.ToPILImage()] + img_transforms + [np.array,
|
||||||
|
img_as_float]
|
||||||
|
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
|
jittered_clip = []
|
||||||
|
for img in clip:
|
||||||
|
jittered_img = img
|
||||||
|
for func in img_transforms:
|
||||||
|
jittered_img = func(jittered_img)
|
||||||
|
jittered_clip.append(jittered_img.astype('float32'))
|
||||||
|
elif isinstance(clip[0], PIL.Image.Image):
|
||||||
|
brightness, contrast, saturation, hue = self.get_params(
|
||||||
|
self.brightness, self.contrast, self.saturation, self.hue)
|
||||||
|
|
||||||
|
# Create img transform function sequence
|
||||||
|
img_transforms = []
|
||||||
|
if brightness is not None:
|
||||||
|
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness))
|
||||||
|
if saturation is not None:
|
||||||
|
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation))
|
||||||
|
if hue is not None:
|
||||||
|
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue))
|
||||||
|
if contrast is not None:
|
||||||
|
img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast))
|
||||||
|
random.shuffle(img_transforms)
|
||||||
|
|
||||||
|
# Apply to all videos
|
||||||
|
jittered_clip = []
|
||||||
|
for img in clip:
|
||||||
|
for func in img_transforms:
|
||||||
|
jittered_img = func(img)
|
||||||
|
jittered_clip.append(jittered_img)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise TypeError('Expected numpy.ndarray or PIL.Image' +
|
||||||
|
'but got list of {0}'.format(type(clip[0])))
|
||||||
|
return jittered_clip
|
||||||
|
|
||||||
|
|
||||||
|
class AllAugmentationTransform:
|
||||||
|
def __init__(self, resize_param=None, rotation_param=None, flip_param=None, crop_param=None, jitter_param=None):
|
||||||
|
self.transforms = []
|
||||||
|
|
||||||
|
if flip_param is not None:
|
||||||
|
self.transforms.append(RandomFlip(**flip_param))
|
||||||
|
|
||||||
|
if rotation_param is not None:
|
||||||
|
self.transforms.append(RandomRotation(**rotation_param))
|
||||||
|
|
||||||
|
if resize_param is not None:
|
||||||
|
self.transforms.append(RandomResize(**resize_param))
|
||||||
|
|
||||||
|
if crop_param is not None:
|
||||||
|
self.transforms.append(RandomCrop(**crop_param))
|
||||||
|
|
||||||
|
if jitter_param is not None:
|
||||||
|
self.transforms.append(ColorJitter(**jitter_param))
|
||||||
|
|
||||||
|
def __call__(self, clip):
|
||||||
|
for t in self.transforms:
|
||||||
|
clip = t(clip)
|
||||||
|
return clip
|
@ -0,0 +1,82 @@
|
|||||||
|
dataset_params:
|
||||||
|
root_dir: data/bair
|
||||||
|
frame_shape: [256, 256, 3]
|
||||||
|
id_sampling: False
|
||||||
|
augmentation_params:
|
||||||
|
flip_param:
|
||||||
|
horizontal_flip: True
|
||||||
|
time_flip: True
|
||||||
|
jitter_param:
|
||||||
|
brightness: 0.1
|
||||||
|
contrast: 0.1
|
||||||
|
saturation: 0.1
|
||||||
|
hue: 0.1
|
||||||
|
|
||||||
|
|
||||||
|
model_params:
|
||||||
|
common_params:
|
||||||
|
num_kp: 10
|
||||||
|
num_channels: 3
|
||||||
|
estimate_jacobian: True
|
||||||
|
kp_detector_params:
|
||||||
|
temperature: 0.1
|
||||||
|
block_expansion: 32
|
||||||
|
max_features: 1024
|
||||||
|
scale_factor: 0.25
|
||||||
|
num_blocks: 5
|
||||||
|
generator_params:
|
||||||
|
block_expansion: 64
|
||||||
|
max_features: 512
|
||||||
|
num_down_blocks: 2
|
||||||
|
num_bottleneck_blocks: 6
|
||||||
|
estimate_occlusion_map: True
|
||||||
|
dense_motion_params:
|
||||||
|
block_expansion: 64
|
||||||
|
max_features: 1024
|
||||||
|
num_blocks: 5
|
||||||
|
scale_factor: 0.25
|
||||||
|
discriminator_params:
|
||||||
|
scales: [1]
|
||||||
|
block_expansion: 32
|
||||||
|
max_features: 512
|
||||||
|
num_blocks: 4
|
||||||
|
sn: True
|
||||||
|
|
||||||
|
train_params:
|
||||||
|
num_epochs: 20
|
||||||
|
num_repeats: 1
|
||||||
|
epoch_milestones: [12, 18]
|
||||||
|
lr_generator: 2.0e-4
|
||||||
|
lr_discriminator: 2.0e-4
|
||||||
|
lr_kp_detector: 2.0e-4
|
||||||
|
batch_size: 36
|
||||||
|
scales: [1, 0.5, 0.25, 0.125]
|
||||||
|
checkpoint_freq: 10
|
||||||
|
transform_params:
|
||||||
|
sigma_affine: 0.05
|
||||||
|
sigma_tps: 0.005
|
||||||
|
points_tps: 5
|
||||||
|
loss_weights:
|
||||||
|
generator_gan: 1
|
||||||
|
discriminator_gan: 1
|
||||||
|
feature_matching: [10, 10, 10, 10]
|
||||||
|
perceptual: [10, 10, 10, 10, 10]
|
||||||
|
equivariance_value: 10
|
||||||
|
equivariance_jacobian: 10
|
||||||
|
|
||||||
|
reconstruction_params:
|
||||||
|
num_videos: 1000
|
||||||
|
format: '.mp4'
|
||||||
|
|
||||||
|
animate_params:
|
||||||
|
num_pairs: 50
|
||||||
|
format: '.mp4'
|
||||||
|
normalization_params:
|
||||||
|
adapt_movement_scale: False
|
||||||
|
use_relative_movement: True
|
||||||
|
use_relative_jacobian: True
|
||||||
|
|
||||||
|
visualizer_params:
|
||||||
|
kp_size: 5
|
||||||
|
draw_border: True
|
||||||
|
colormap: 'gist_rainbow'
|
@ -0,0 +1,77 @@
|
|||||||
|
dataset_params:
|
||||||
|
root_dir: data/fashion-png
|
||||||
|
frame_shape: [256, 256, 3]
|
||||||
|
id_sampling: False
|
||||||
|
augmentation_params:
|
||||||
|
flip_param:
|
||||||
|
horizontal_flip: True
|
||||||
|
time_flip: True
|
||||||
|
jitter_param:
|
||||||
|
hue: 0.1
|
||||||
|
|
||||||
|
model_params:
|
||||||
|
common_params:
|
||||||
|
num_kp: 10
|
||||||
|
num_channels: 3
|
||||||
|
estimate_jacobian: True
|
||||||
|
kp_detector_params:
|
||||||
|
temperature: 0.1
|
||||||
|
block_expansion: 32
|
||||||
|
max_features: 1024
|
||||||
|
scale_factor: 0.25
|
||||||
|
num_blocks: 5
|
||||||
|
generator_params:
|
||||||
|
block_expansion: 64
|
||||||
|
max_features: 512
|
||||||
|
num_down_blocks: 2
|
||||||
|
num_bottleneck_blocks: 6
|
||||||
|
estimate_occlusion_map: True
|
||||||
|
dense_motion_params:
|
||||||
|
block_expansion: 64
|
||||||
|
max_features: 1024
|
||||||
|
num_blocks: 5
|
||||||
|
scale_factor: 0.25
|
||||||
|
discriminator_params:
|
||||||
|
scales: [1]
|
||||||
|
block_expansion: 32
|
||||||
|
max_features: 512
|
||||||
|
num_blocks: 4
|
||||||
|
|
||||||
|
train_params:
|
||||||
|
num_epochs: 100
|
||||||
|
num_repeats: 50
|
||||||
|
epoch_milestones: [60, 90]
|
||||||
|
lr_generator: 2.0e-4
|
||||||
|
lr_discriminator: 2.0e-4
|
||||||
|
lr_kp_detector: 2.0e-4
|
||||||
|
batch_size: 27
|
||||||
|
scales: [1, 0.5, 0.25, 0.125]
|
||||||
|
checkpoint_freq: 50
|
||||||
|
transform_params:
|
||||||
|
sigma_affine: 0.05
|
||||||
|
sigma_tps: 0.005
|
||||||
|
points_tps: 5
|
||||||
|
loss_weights:
|
||||||
|
generator_gan: 1
|
||||||
|
discriminator_gan: 1
|
||||||
|
feature_matching: [10, 10, 10, 10]
|
||||||
|
perceptual: [10, 10, 10, 10, 10]
|
||||||
|
equivariance_value: 10
|
||||||
|
equivariance_jacobian: 10
|
||||||
|
|
||||||
|
reconstruction_params:
|
||||||
|
num_videos: 1000
|
||||||
|
format: '.mp4'
|
||||||
|
|
||||||
|
animate_params:
|
||||||
|
num_pairs: 50
|
||||||
|
format: '.mp4'
|
||||||
|
normalization_params:
|
||||||
|
adapt_movement_scale: False
|
||||||
|
use_relative_movement: True
|
||||||
|
use_relative_jacobian: True
|
||||||
|
|
||||||
|
visualizer_params:
|
||||||
|
kp_size: 5
|
||||||
|
draw_border: True
|
||||||
|
colormap: 'gist_rainbow'
|
@ -0,0 +1,84 @@
|
|||||||
|
dataset_params:
|
||||||
|
root_dir: data/moving-gif
|
||||||
|
frame_shape: [256, 256, 3]
|
||||||
|
id_sampling: False
|
||||||
|
augmentation_params:
|
||||||
|
flip_param:
|
||||||
|
horizontal_flip: True
|
||||||
|
time_flip: True
|
||||||
|
crop_param:
|
||||||
|
size: [256, 256]
|
||||||
|
resize_param:
|
||||||
|
ratio: [0.9, 1.1]
|
||||||
|
jitter_param:
|
||||||
|
hue: 0.5
|
||||||
|
|
||||||
|
model_params:
|
||||||
|
common_params:
|
||||||
|
num_kp: 10
|
||||||
|
num_channels: 3
|
||||||
|
estimate_jacobian: True
|
||||||
|
kp_detector_params:
|
||||||
|
temperature: 0.1
|
||||||
|
block_expansion: 32
|
||||||
|
max_features: 1024
|
||||||
|
scale_factor: 0.25
|
||||||
|
num_blocks: 5
|
||||||
|
single_jacobian_map: True
|
||||||
|
generator_params:
|
||||||
|
block_expansion: 64
|
||||||
|
max_features: 512
|
||||||
|
num_down_blocks: 2
|
||||||
|
num_bottleneck_blocks: 6
|
||||||
|
estimate_occlusion_map: True
|
||||||
|
dense_motion_params:
|
||||||
|
block_expansion: 64
|
||||||
|
max_features: 1024
|
||||||
|
num_blocks: 5
|
||||||
|
scale_factor: 0.25
|
||||||
|
discriminator_params:
|
||||||
|
scales: [1]
|
||||||
|
block_expansion: 32
|
||||||
|
max_features: 512
|
||||||
|
num_blocks: 4
|
||||||
|
sn: True
|
||||||
|
|
||||||
|
train_params:
|
||||||
|
num_epochs: 100
|
||||||
|
num_repeats: 25
|
||||||
|
epoch_milestones: [60, 90]
|
||||||
|
lr_generator: 2.0e-4
|
||||||
|
lr_discriminator: 2.0e-4
|
||||||
|
lr_kp_detector: 2.0e-4
|
||||||
|
|
||||||
|
batch_size: 36
|
||||||
|
scales: [1, 0.5, 0.25, 0.125]
|
||||||
|
checkpoint_freq: 100
|
||||||
|
transform_params:
|
||||||
|
sigma_affine: 0.05
|
||||||
|
sigma_tps: 0.005
|
||||||
|
points_tps: 5
|
||||||
|
loss_weights:
|
||||||
|
generator_gan: 1
|
||||||
|
discriminator_gan: 1
|
||||||
|
feature_matching: [10, 10, 10, 10]
|
||||||
|
perceptual: [10, 10, 10, 10, 10]
|
||||||
|
equivariance_value: 10
|
||||||
|
equivariance_jacobian: 10
|
||||||
|
|
||||||
|
reconstruction_params:
|
||||||
|
num_videos: 1000
|
||||||
|
format: '.mp4'
|
||||||
|
|
||||||
|
animate_params:
|
||||||
|
num_pairs: 50
|
||||||
|
format: '.mp4'
|
||||||
|
normalization_params:
|
||||||
|
adapt_movement_scale: False
|
||||||
|
use_relative_movement: True
|
||||||
|
use_relative_jacobian: True
|
||||||
|
|
||||||
|
visualizer_params:
|
||||||
|
kp_size: 5
|
||||||
|
draw_border: True
|
||||||
|
colormap: 'gist_rainbow'
|
@ -0,0 +1,76 @@
|
|||||||
|
dataset_params:
|
||||||
|
root_dir: data/nemo-png
|
||||||
|
frame_shape: [256, 256, 3]
|
||||||
|
id_sampling: False
|
||||||
|
augmentation_params:
|
||||||
|
flip_param:
|
||||||
|
horizontal_flip: True
|
||||||
|
time_flip: True
|
||||||
|
|
||||||
|
model_params:
|
||||||
|
common_params:
|
||||||
|
num_kp: 10
|
||||||
|
num_channels: 3
|
||||||
|
estimate_jacobian: True
|
||||||
|
kp_detector_params:
|
||||||
|
temperature: 0.1
|
||||||
|
block_expansion: 32
|
||||||
|
max_features: 1024
|
||||||
|
scale_factor: 0.25
|
||||||
|
num_blocks: 5
|
||||||
|
generator_params:
|
||||||
|
block_expansion: 64
|
||||||
|
max_features: 512
|
||||||
|
num_down_blocks: 2
|
||||||
|
num_bottleneck_blocks: 6
|
||||||
|
estimate_occlusion_map: True
|
||||||
|
dense_motion_params:
|
||||||
|
block_expansion: 64
|
||||||
|
max_features: 1024
|
||||||
|
num_blocks: 5
|
||||||
|
scale_factor: 0.25
|
||||||
|
discriminator_params:
|
||||||
|
scales: [1]
|
||||||
|
block_expansion: 32
|
||||||
|
max_features: 512
|
||||||
|
num_blocks: 4
|
||||||
|
sn: True
|
||||||
|
|
||||||
|
train_params:
|
||||||
|
num_epochs: 100
|
||||||
|
num_repeats: 8
|
||||||
|
epoch_milestones: [60, 90]
|
||||||
|
lr_generator: 2.0e-4
|
||||||
|
lr_discriminator: 2.0e-4
|
||||||
|
lr_kp_detector: 2.0e-4
|
||||||
|
batch_size: 36
|
||||||
|
scales: [1, 0.5, 0.25, 0.125]
|
||||||
|
checkpoint_freq: 50
|
||||||
|
transform_params:
|
||||||
|
sigma_affine: 0.05
|
||||||
|
sigma_tps: 0.005
|
||||||
|
points_tps: 5
|
||||||
|
loss_weights:
|
||||||
|
generator_gan: 1
|
||||||
|
discriminator_gan: 1
|
||||||
|
feature_matching: [10, 10, 10, 10]
|
||||||
|
perceptual: [10, 10, 10, 10, 10]
|
||||||
|
equivariance_value: 10
|
||||||
|
equivariance_jacobian: 10
|
||||||
|
|
||||||
|
reconstruction_params:
|
||||||
|
num_videos: 1000
|
||||||
|
format: '.mp4'
|
||||||
|
|
||||||
|
animate_params:
|
||||||
|
num_pairs: 50
|
||||||
|
format: '.mp4'
|
||||||
|
normalization_params:
|
||||||
|
adapt_movement_scale: False
|
||||||
|
use_relative_movement: True
|
||||||
|
use_relative_jacobian: True
|
||||||
|
|
||||||
|
visualizer_params:
|
||||||
|
kp_size: 5
|
||||||
|
draw_border: True
|
||||||
|
colormap: 'gist_rainbow'
|
@ -0,0 +1,157 @@
|
|||||||
|
# Dataset parameters
|
||||||
|
# Each dataset should contain 2 folders train and test
|
||||||
|
# Each video can be represented as:
|
||||||
|
# - an image of concatenated frames
|
||||||
|
# - '.mp4' or '.gif'
|
||||||
|
# - folder with all frames from a specific video
|
||||||
|
# In case of Taichi. Same (youtube) video can be splitted in many parts (chunks). Each part has a following
|
||||||
|
# format (id)#other#info.mp4. For example '12335#adsbf.mp4' has an id 12335. In case of TaiChi id stands for youtube
|
||||||
|
# video id.
|
||||||
|
dataset_params:
|
||||||
|
# Path to data, data can be stored in several formats: .mp4 or .gif videos, stacked .png images or folders with frames.
|
||||||
|
root_dir: data/taichi-png
|
||||||
|
# Image shape, needed for staked .png format.
|
||||||
|
frame_shape: [256, 256, 3]
|
||||||
|
# In case of TaiChi single video can be splitted in many chunks, or the maybe several videos for single person.
|
||||||
|
# In this case epoch can be a pass over different videos (if id_sampling=True) or over different chunks (if id_sampling=False)
|
||||||
|
# If the name of the video '12335#adsbf.mp4' the id is assumed to be 12335
|
||||||
|
id_sampling: True
|
||||||
|
# List with pairs for animation, None for random pairs
|
||||||
|
pairs_list: data/taichi256.csv
|
||||||
|
# Augmentation parameters see augmentation.py for all posible augmentations
|
||||||
|
augmentation_params:
|
||||||
|
flip_param:
|
||||||
|
horizontal_flip: True
|
||||||
|
time_flip: True
|
||||||
|
jitter_param:
|
||||||
|
brightness: 0.1
|
||||||
|
contrast: 0.1
|
||||||
|
saturation: 0.1
|
||||||
|
hue: 0.1
|
||||||
|
|
||||||
|
# Defines model architecture
|
||||||
|
model_params:
|
||||||
|
common_params:
|
||||||
|
# Number of keypoint
|
||||||
|
num_kp: 10
|
||||||
|
# Number of channels per image
|
||||||
|
num_channels: 3
|
||||||
|
# Using first or zero order model
|
||||||
|
estimate_jacobian: True
|
||||||
|
kp_detector_params:
|
||||||
|
# Softmax temperature for keypoint heatmaps
|
||||||
|
temperature: 0.1
|
||||||
|
# Number of features mutliplier
|
||||||
|
block_expansion: 32
|
||||||
|
# Maximum allowed number of features
|
||||||
|
max_features: 1024
|
||||||
|
# Number of block in Unet. Can be increased or decreased depending or resolution.
|
||||||
|
num_blocks: 5
|
||||||
|
# Keypioint is predicted on smaller images for better performance,
|
||||||
|
# scale_factor=0.25 means that 256x256 image will be resized to 64x64
|
||||||
|
scale_factor: 0.25
|
||||||
|
generator_params:
|
||||||
|
# Number of features mutliplier
|
||||||
|
block_expansion: 64
|
||||||
|
# Maximum allowed number of features
|
||||||
|
max_features: 512
|
||||||
|
# Number of downsampling blocks in Jonson architecture.
|
||||||
|
# Can be increased or decreased depending or resolution.
|
||||||
|
num_down_blocks: 2
|
||||||
|
# Number of ResBlocks in Jonson architecture.
|
||||||
|
num_bottleneck_blocks: 6
|
||||||
|
# Use occlusion map or not
|
||||||
|
estimate_occlusion_map: True
|
||||||
|
|
||||||
|
dense_motion_params:
|
||||||
|
# Number of features mutliplier
|
||||||
|
block_expansion: 64
|
||||||
|
# Maximum allowed number of features
|
||||||
|
max_features: 1024
|
||||||
|
# Number of block in Unet. Can be increased or decreased depending or resolution.
|
||||||
|
num_blocks: 5
|
||||||
|
# Dense motion is predicted on smaller images for better performance,
|
||||||
|
# scale_factor=0.25 means that 256x256 image will be resized to 64x64
|
||||||
|
scale_factor: 0.25
|
||||||
|
discriminator_params:
|
||||||
|
# Discriminator can be multiscale, if you want 2 discriminator on original
|
||||||
|
# resolution and half of the original, specify scales: [1, 0.5]
|
||||||
|
scales: [1]
|
||||||
|
# Number of features mutliplier
|
||||||
|
block_expansion: 32
|
||||||
|
# Maximum allowed number of features
|
||||||
|
max_features: 512
|
||||||
|
# Number of blocks. Can be increased or decreased depending or resolution.
|
||||||
|
num_blocks: 4
|
||||||
|
|
||||||
|
# Parameters of training
|
||||||
|
train_params:
|
||||||
|
# Number of training epochs
|
||||||
|
num_epochs: 100
|
||||||
|
# For better i/o performance when number of videos is small number of epochs can be multiplied by this number.
|
||||||
|
# Thus effectivlly with num_repeats=100 each epoch is 100 times larger.
|
||||||
|
num_repeats: 150
|
||||||
|
# Drop learning rate by 10 times after this epochs
|
||||||
|
epoch_milestones: [60, 90]
|
||||||
|
# Initial learing rate for all modules
|
||||||
|
lr_generator: 2.0e-4
|
||||||
|
lr_discriminator: 2.0e-4
|
||||||
|
lr_kp_detector: 2.0e-4
|
||||||
|
batch_size: 30
|
||||||
|
# Scales for perceptual pyramide loss. If scales = [1, 0.5, 0.25, 0.125] and image resolution is 256x256,
|
||||||
|
# than the loss will be computer on resolutions 256x256, 128x128, 64x64, 32x32.
|
||||||
|
scales: [1, 0.5, 0.25, 0.125]
|
||||||
|
# Save checkpoint this frequently. If checkpoint_freq=50, checkpoint will be saved every 50 epochs.
|
||||||
|
checkpoint_freq: 50
|
||||||
|
# Parameters of transform for equivariance loss
|
||||||
|
transform_params:
|
||||||
|
# Sigma for affine part
|
||||||
|
sigma_affine: 0.05
|
||||||
|
# Sigma for deformation part
|
||||||
|
sigma_tps: 0.005
|
||||||
|
# Number of point in the deformation grid
|
||||||
|
points_tps: 5
|
||||||
|
loss_weights:
|
||||||
|
# Weight for LSGAN loss in generator, 0 for no adversarial loss.
|
||||||
|
generator_gan: 0
|
||||||
|
# Weight for LSGAN loss in discriminator
|
||||||
|
discriminator_gan: 1
|
||||||
|
# Weights for feature matching loss, the number should be the same as number of blocks in discriminator.
|
||||||
|
feature_matching: [10, 10, 10, 10]
|
||||||
|
# Weights for perceptual loss.
|
||||||
|
perceptual: [10, 10, 10, 10, 10]
|
||||||
|
# Weights for value equivariance.
|
||||||
|
equivariance_value: 10
|
||||||
|
# Weights for jacobian equivariance.
|
||||||
|
equivariance_jacobian: 10
|
||||||
|
|
||||||
|
# Parameters of reconstruction
|
||||||
|
reconstruction_params:
|
||||||
|
# Maximum number of videos for reconstruction
|
||||||
|
num_videos: 1000
|
||||||
|
# Format for visualization, note that results will be also stored in staked .png.
|
||||||
|
format: '.mp4'
|
||||||
|
|
||||||
|
# Parameters of animation
|
||||||
|
animate_params:
|
||||||
|
# Maximum number of pairs for animation, the pairs will be either taken from pairs_list or random.
|
||||||
|
num_pairs: 50
|
||||||
|
# Format for visualization, note that results will be also stored in staked .png.
|
||||||
|
format: '.mp4'
|
||||||
|
# Normalization of diriving keypoints
|
||||||
|
normalization_params:
|
||||||
|
# Increase or decrease relative movement scale depending on the size of the object
|
||||||
|
adapt_movement_scale: False
|
||||||
|
# Apply only relative displacement of the keypoint
|
||||||
|
use_relative_movement: True
|
||||||
|
# Apply only relative change in jacobian
|
||||||
|
use_relative_jacobian: True
|
||||||
|
|
||||||
|
# Visualization parameters
|
||||||
|
visualizer_params:
|
||||||
|
# Draw keypoints of this size, increase or decrease depending on resolution
|
||||||
|
kp_size: 5
|
||||||
|
# Draw white border around images
|
||||||
|
draw_border: True
|
||||||
|
# Color map for keypoints
|
||||||
|
colormap: 'gist_rainbow'
|
@ -0,0 +1,150 @@
|
|||||||
|
# Dataset parameters
|
||||||
|
dataset_params:
|
||||||
|
# Path to data, data can be stored in several formats: .mp4 or .gif videos, stacked .png images or folders with frames.
|
||||||
|
root_dir: data/taichi-png
|
||||||
|
# Image shape, needed for staked .png format.
|
||||||
|
frame_shape: [256, 256, 3]
|
||||||
|
# In case of TaiChi single video can be splitted in many chunks, or the maybe several videos for single person.
|
||||||
|
# In this case epoch can be a pass over different videos (if id_sampling=True) or over different chunks (if id_sampling=False)
|
||||||
|
# If the name of the video '12335#adsbf.mp4' the id is assumed to be 12335
|
||||||
|
id_sampling: True
|
||||||
|
# List with pairs for animation, None for random pairs
|
||||||
|
pairs_list: data/taichi256.csv
|
||||||
|
# Augmentation parameters see augmentation.py for all posible augmentations
|
||||||
|
augmentation_params:
|
||||||
|
flip_param:
|
||||||
|
horizontal_flip: True
|
||||||
|
time_flip: True
|
||||||
|
jitter_param:
|
||||||
|
brightness: 0.1
|
||||||
|
contrast: 0.1
|
||||||
|
saturation: 0.1
|
||||||
|
hue: 0.1
|
||||||
|
|
||||||
|
# Defines model architecture
|
||||||
|
model_params:
|
||||||
|
common_params:
|
||||||
|
# Number of keypoint
|
||||||
|
num_kp: 10
|
||||||
|
# Number of channels per image
|
||||||
|
num_channels: 3
|
||||||
|
# Using first or zero order model
|
||||||
|
estimate_jacobian: True
|
||||||
|
kp_detector_params:
|
||||||
|
# Softmax temperature for keypoint heatmaps
|
||||||
|
temperature: 0.1
|
||||||
|
# Number of features mutliplier
|
||||||
|
block_expansion: 32
|
||||||
|
# Maximum allowed number of features
|
||||||
|
max_features: 1024
|
||||||
|
# Number of block in Unet. Can be increased or decreased depending or resolution.
|
||||||
|
num_blocks: 5
|
||||||
|
# Keypioint is predicted on smaller images for better performance,
|
||||||
|
# scale_factor=0.25 means that 256x256 image will be resized to 64x64
|
||||||
|
scale_factor: 0.25
|
||||||
|
generator_params:
|
||||||
|
# Number of features mutliplier
|
||||||
|
block_expansion: 64
|
||||||
|
# Maximum allowed number of features
|
||||||
|
max_features: 512
|
||||||
|
# Number of downsampling blocks in Jonson architecture.
|
||||||
|
# Can be increased or decreased depending or resolution.
|
||||||
|
num_down_blocks: 2
|
||||||
|
# Number of ResBlocks in Jonson architecture.
|
||||||
|
num_bottleneck_blocks: 6
|
||||||
|
# Use occlusion map or not
|
||||||
|
estimate_occlusion_map: True
|
||||||
|
|
||||||
|
dense_motion_params:
|
||||||
|
# Number of features mutliplier
|
||||||
|
block_expansion: 64
|
||||||
|
# Maximum allowed number of features
|
||||||
|
max_features: 1024
|
||||||
|
# Number of block in Unet. Can be increased or decreased depending or resolution.
|
||||||
|
num_blocks: 5
|
||||||
|
# Dense motion is predicted on smaller images for better performance,
|
||||||
|
# scale_factor=0.25 means that 256x256 image will be resized to 64x64
|
||||||
|
scale_factor: 0.25
|
||||||
|
discriminator_params:
|
||||||
|
# Discriminator can be multiscale, if you want 2 discriminator on original
|
||||||
|
# resolution and half of the original, specify scales: [1, 0.5]
|
||||||
|
scales: [1]
|
||||||
|
# Number of features mutliplier
|
||||||
|
block_expansion: 32
|
||||||
|
# Maximum allowed number of features
|
||||||
|
max_features: 512
|
||||||
|
# Number of blocks. Can be increased or decreased depending or resolution.
|
||||||
|
num_blocks: 4
|
||||||
|
use_kp: True
|
||||||
|
|
||||||
|
# Parameters of training
|
||||||
|
train_params:
|
||||||
|
# Number of training epochs
|
||||||
|
num_epochs: 150
|
||||||
|
# For better i/o performance when number of videos is small number of epochs can be multiplied by this number.
|
||||||
|
# Thus effectivlly with num_repeats=100 each epoch is 100 times larger.
|
||||||
|
num_repeats: 150
|
||||||
|
# Drop learning rate by 10 times after this epochs
|
||||||
|
epoch_milestones: []
|
||||||
|
# Initial learing rate for all modules
|
||||||
|
lr_generator: 2.0e-4
|
||||||
|
lr_discriminator: 2.0e-4
|
||||||
|
lr_kp_detector: 0
|
||||||
|
batch_size: 27
|
||||||
|
# Scales for perceptual pyramide loss. If scales = [1, 0.5, 0.25, 0.125] and image resolution is 256x256,
|
||||||
|
# than the loss will be computer on resolutions 256x256, 128x128, 64x64, 32x32.
|
||||||
|
scales: [1, 0.5, 0.25, 0.125]
|
||||||
|
# Save checkpoint this frequently. If checkpoint_freq=50, checkpoint will be saved every 50 epochs.
|
||||||
|
checkpoint_freq: 50
|
||||||
|
# Parameters of transform for equivariance loss
|
||||||
|
transform_params:
|
||||||
|
# Sigma for affine part
|
||||||
|
sigma_affine: 0.05
|
||||||
|
# Sigma for deformation part
|
||||||
|
sigma_tps: 0.005
|
||||||
|
# Number of point in the deformation grid
|
||||||
|
points_tps: 5
|
||||||
|
loss_weights:
|
||||||
|
# Weight for LSGAN loss in generator
|
||||||
|
generator_gan: 1
|
||||||
|
# Weight for LSGAN loss in discriminator
|
||||||
|
discriminator_gan: 1
|
||||||
|
# Weights for feature matching loss, the number should be the same as number of blocks in discriminator.
|
||||||
|
feature_matching: [10, 10, 10, 10]
|
||||||
|
# Weights for perceptual loss.
|
||||||
|
perceptual: [10, 10, 10, 10, 10]
|
||||||
|
# Weights for value equivariance.
|
||||||
|
equivariance_value: 10
|
||||||
|
# Weights for jacobian equivariance.
|
||||||
|
equivariance_jacobian: 10
|
||||||
|
|
||||||
|
# Parameters of reconstruction
|
||||||
|
reconstruction_params:
|
||||||
|
# Maximum number of videos for reconstruction
|
||||||
|
num_videos: 1000
|
||||||
|
# Format for visualization, note that results will be also stored in staked .png.
|
||||||
|
format: '.mp4'
|
||||||
|
|
||||||
|
# Parameters of animation
|
||||||
|
animate_params:
|
||||||
|
# Maximum number of pairs for animation, the pairs will be either taken from pairs_list or random.
|
||||||
|
num_pairs: 50
|
||||||
|
# Format for visualization, note that results will be also stored in staked .png.
|
||||||
|
format: '.mp4'
|
||||||
|
# Normalization of diriving keypoints
|
||||||
|
normalization_params:
|
||||||
|
# Increase or decrease relative movement scale depending on the size of the object
|
||||||
|
adapt_movement_scale: False
|
||||||
|
# Apply only relative displacement of the keypoint
|
||||||
|
use_relative_movement: True
|
||||||
|
# Apply only relative change in jacobian
|
||||||
|
use_relative_jacobian: True
|
||||||
|
|
||||||
|
# Visualization parameters
|
||||||
|
visualizer_params:
|
||||||
|
# Draw keypoints of this size, increase or decrease depending on resolution
|
||||||
|
kp_size: 5
|
||||||
|
# Draw white border around images
|
||||||
|
draw_border: True
|
||||||
|
# Color map for keypoints
|
||||||
|
colormap: 'gist_rainbow'
|
@ -0,0 +1,83 @@
|
|||||||
|
dataset_params:
|
||||||
|
root_dir: data/vox-png
|
||||||
|
frame_shape: [256, 256, 3]
|
||||||
|
id_sampling: True
|
||||||
|
pairs_list: data/vox256.csv
|
||||||
|
augmentation_params:
|
||||||
|
flip_param:
|
||||||
|
horizontal_flip: True
|
||||||
|
time_flip: True
|
||||||
|
jitter_param:
|
||||||
|
brightness: 0.1
|
||||||
|
contrast: 0.1
|
||||||
|
saturation: 0.1
|
||||||
|
hue: 0.1
|
||||||
|
|
||||||
|
|
||||||
|
model_params:
|
||||||
|
common_params:
|
||||||
|
num_kp: 10
|
||||||
|
num_channels: 3
|
||||||
|
estimate_jacobian: True
|
||||||
|
kp_detector_params:
|
||||||
|
temperature: 0.1
|
||||||
|
block_expansion: 32
|
||||||
|
max_features: 1024
|
||||||
|
scale_factor: 0.25
|
||||||
|
num_blocks: 5
|
||||||
|
generator_params:
|
||||||
|
block_expansion: 64
|
||||||
|
max_features: 512
|
||||||
|
num_down_blocks: 2
|
||||||
|
num_bottleneck_blocks: 6
|
||||||
|
estimate_occlusion_map: True
|
||||||
|
dense_motion_params:
|
||||||
|
block_expansion: 64
|
||||||
|
max_features: 1024
|
||||||
|
num_blocks: 5
|
||||||
|
scale_factor: 0.25
|
||||||
|
discriminator_params:
|
||||||
|
scales: [1]
|
||||||
|
block_expansion: 32
|
||||||
|
max_features: 512
|
||||||
|
num_blocks: 4
|
||||||
|
sn: True
|
||||||
|
|
||||||
|
train_params:
|
||||||
|
num_epochs: 100
|
||||||
|
num_repeats: 75
|
||||||
|
epoch_milestones: [60, 90]
|
||||||
|
lr_generator: 2.0e-4
|
||||||
|
lr_discriminator: 2.0e-4
|
||||||
|
lr_kp_detector: 2.0e-4
|
||||||
|
batch_size: 40
|
||||||
|
scales: [1, 0.5, 0.25, 0.125]
|
||||||
|
checkpoint_freq: 50
|
||||||
|
transform_params:
|
||||||
|
sigma_affine: 0.05
|
||||||
|
sigma_tps: 0.005
|
||||||
|
points_tps: 5
|
||||||
|
loss_weights:
|
||||||
|
generator_gan: 0
|
||||||
|
discriminator_gan: 1
|
||||||
|
feature_matching: [10, 10, 10, 10]
|
||||||
|
perceptual: [10, 10, 10, 10, 10]
|
||||||
|
equivariance_value: 10
|
||||||
|
equivariance_jacobian: 10
|
||||||
|
|
||||||
|
reconstruction_params:
|
||||||
|
num_videos: 1000
|
||||||
|
format: '.mp4'
|
||||||
|
|
||||||
|
animate_params:
|
||||||
|
num_pairs: 50
|
||||||
|
format: '.mp4'
|
||||||
|
normalization_params:
|
||||||
|
adapt_movement_scale: False
|
||||||
|
use_relative_movement: True
|
||||||
|
use_relative_jacobian: True
|
||||||
|
|
||||||
|
visualizer_params:
|
||||||
|
kp_size: 5
|
||||||
|
draw_border: True
|
||||||
|
colormap: 'gist_rainbow'
|
@ -0,0 +1,84 @@
|
|||||||
|
dataset_params:
|
||||||
|
root_dir: data/vox-png
|
||||||
|
frame_shape: [256, 256, 3]
|
||||||
|
id_sampling: True
|
||||||
|
pairs_list: data/vox256.csv
|
||||||
|
augmentation_params:
|
||||||
|
flip_param:
|
||||||
|
horizontal_flip: True
|
||||||
|
time_flip: True
|
||||||
|
jitter_param:
|
||||||
|
brightness: 0.1
|
||||||
|
contrast: 0.1
|
||||||
|
saturation: 0.1
|
||||||
|
hue: 0.1
|
||||||
|
|
||||||
|
|
||||||
|
model_params:
|
||||||
|
common_params:
|
||||||
|
num_kp: 10
|
||||||
|
num_channels: 3
|
||||||
|
estimate_jacobian: True
|
||||||
|
kp_detector_params:
|
||||||
|
temperature: 0.1
|
||||||
|
block_expansion: 32
|
||||||
|
max_features: 1024
|
||||||
|
scale_factor: 0.25
|
||||||
|
num_blocks: 5
|
||||||
|
generator_params:
|
||||||
|
block_expansion: 64
|
||||||
|
max_features: 512
|
||||||
|
num_down_blocks: 2
|
||||||
|
num_bottleneck_blocks: 6
|
||||||
|
estimate_occlusion_map: True
|
||||||
|
dense_motion_params:
|
||||||
|
block_expansion: 64
|
||||||
|
max_features: 1024
|
||||||
|
num_blocks: 5
|
||||||
|
scale_factor: 0.25
|
||||||
|
discriminator_params:
|
||||||
|
scales: [1]
|
||||||
|
block_expansion: 32
|
||||||
|
max_features: 512
|
||||||
|
num_blocks: 4
|
||||||
|
use_kp: True
|
||||||
|
|
||||||
|
|
||||||
|
train_params:
|
||||||
|
num_epochs: 150
|
||||||
|
num_repeats: 75
|
||||||
|
epoch_milestones: []
|
||||||
|
lr_generator: 2.0e-4
|
||||||
|
lr_discriminator: 2.0e-4
|
||||||
|
lr_kp_detector: 2.0e-4
|
||||||
|
batch_size: 36
|
||||||
|
scales: [1, 0.5, 0.25, 0.125]
|
||||||
|
checkpoint_freq: 50
|
||||||
|
transform_params:
|
||||||
|
sigma_affine: 0.05
|
||||||
|
sigma_tps: 0.005
|
||||||
|
points_tps: 5
|
||||||
|
loss_weights:
|
||||||
|
generator_gan: 1
|
||||||
|
discriminator_gan: 1
|
||||||
|
feature_matching: [10, 10, 10, 10]
|
||||||
|
perceptual: [10, 10, 10, 10, 10]
|
||||||
|
equivariance_value: 10
|
||||||
|
equivariance_jacobian: 10
|
||||||
|
|
||||||
|
reconstruction_params:
|
||||||
|
num_videos: 1000
|
||||||
|
format: '.mp4'
|
||||||
|
|
||||||
|
animate_params:
|
||||||
|
num_pairs: 50
|
||||||
|
format: '.mp4'
|
||||||
|
normalization_params:
|
||||||
|
adapt_movement_scale: False
|
||||||
|
use_relative_movement: True
|
||||||
|
use_relative_jacobian: True
|
||||||
|
|
||||||
|
visualizer_params:
|
||||||
|
kp_size: 5
|
||||||
|
draw_border: True
|
||||||
|
colormap: 'gist_rainbow'
|
@ -0,0 +1,158 @@
|
|||||||
|
import face_alignment
|
||||||
|
import skimage.io
|
||||||
|
import numpy
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
from skimage import img_as_ubyte
|
||||||
|
from skimage.transform import resize
|
||||||
|
from tqdm import tqdm
|
||||||
|
import os
|
||||||
|
import imageio
|
||||||
|
import numpy as np
|
||||||
|
import warnings
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
|
def extract_bbox(frame, fa):
|
||||||
|
if max(frame.shape[0], frame.shape[1]) > 640:
|
||||||
|
scale_factor = max(frame.shape[0], frame.shape[1]) / 640.0
|
||||||
|
frame = resize(frame, (int(frame.shape[0] / scale_factor), int(frame.shape[1] / scale_factor)))
|
||||||
|
frame = img_as_ubyte(frame)
|
||||||
|
else:
|
||||||
|
scale_factor = 1
|
||||||
|
frame = frame[..., :3]
|
||||||
|
bboxes = fa.face_detector.detect_from_image(frame[..., ::-1])
|
||||||
|
if len(bboxes) == 0:
|
||||||
|
return []
|
||||||
|
return np.array(bboxes)[:, :-1] * scale_factor
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def bb_intersection_over_union(boxA, boxB):
|
||||||
|
xA = max(boxA[0], boxB[0])
|
||||||
|
yA = max(boxA[1], boxB[1])
|
||||||
|
xB = min(boxA[2], boxB[2])
|
||||||
|
yB = min(boxA[3], boxB[3])
|
||||||
|
interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
|
||||||
|
boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
|
||||||
|
boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
|
||||||
|
iou = interArea / float(boxAArea + boxBArea - interArea)
|
||||||
|
return iou
|
||||||
|
|
||||||
|
|
||||||
|
def join(tube_bbox, bbox):
|
||||||
|
xA = min(tube_bbox[0], bbox[0])
|
||||||
|
yA = min(tube_bbox[1], bbox[1])
|
||||||
|
xB = max(tube_bbox[2], bbox[2])
|
||||||
|
yB = max(tube_bbox[3], bbox[3])
|
||||||
|
return (xA, yA, xB, yB)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_bbox(start, end, fps, tube_bbox, frame_shape, inp, image_shape, increase_area=0.1):
|
||||||
|
left, top, right, bot = tube_bbox
|
||||||
|
width = right - left
|
||||||
|
height = bot - top
|
||||||
|
|
||||||
|
#Computing aspect preserving bbox
|
||||||
|
width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width))
|
||||||
|
height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height))
|
||||||
|
|
||||||
|
left = int(left - width_increase * width)
|
||||||
|
top = int(top - height_increase * height)
|
||||||
|
right = int(right + width_increase * width)
|
||||||
|
bot = int(bot + height_increase * height)
|
||||||
|
|
||||||
|
top, bot, left, right = max(0, top), min(bot, frame_shape[0]), max(0, left), min(right, frame_shape[1])
|
||||||
|
h, w = bot - top, right - left
|
||||||
|
|
||||||
|
start = start / fps
|
||||||
|
end = end / fps
|
||||||
|
time = end - start
|
||||||
|
|
||||||
|
scale = f'{image_shape[0]}:{image_shape[1]}'
|
||||||
|
|
||||||
|
return f'ffmpeg -i {inp} -ss {start} -t {time} -filter:v "crop={w}:{h}:{left}:{top}, scale={scale}" crop.mp4'
|
||||||
|
|
||||||
|
|
||||||
|
def compute_bbox_trajectories(trajectories, fps, frame_shape, args):
|
||||||
|
commands = []
|
||||||
|
for i, (bbox, tube_bbox, start, end) in enumerate(trajectories):
|
||||||
|
if (end - start) > args.min_frames:
|
||||||
|
command = compute_bbox(start, end, fps, tube_bbox, frame_shape, inp=args.inp, image_shape=args.image_shape, increase_area=args.increase)
|
||||||
|
commands.append(command)
|
||||||
|
return commands
|
||||||
|
|
||||||
|
|
||||||
|
def process_video(args):
|
||||||
|
device = 'cpu' if args.cpu else 'cuda'
|
||||||
|
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False, device=device)
|
||||||
|
video = imageio.get_reader(args.inp)
|
||||||
|
|
||||||
|
trajectories = []
|
||||||
|
previous_frame = None
|
||||||
|
fps = video.get_meta_data()['fps']
|
||||||
|
commands = []
|
||||||
|
try:
|
||||||
|
for i, frame in tqdm(enumerate(video)):
|
||||||
|
frame_shape = frame.shape
|
||||||
|
bboxes = extract_bbox(frame, fa)
|
||||||
|
## For each trajectory check the criterion
|
||||||
|
not_valid_trajectories = []
|
||||||
|
valid_trajectories = []
|
||||||
|
|
||||||
|
for trajectory in trajectories:
|
||||||
|
tube_bbox = trajectory[0]
|
||||||
|
intersection = 0
|
||||||
|
for bbox in bboxes:
|
||||||
|
intersection = max(intersection, bb_intersection_over_union(tube_bbox, bbox))
|
||||||
|
if intersection > args.iou_with_initial:
|
||||||
|
valid_trajectories.append(trajectory)
|
||||||
|
else:
|
||||||
|
not_valid_trajectories.append(trajectory)
|
||||||
|
|
||||||
|
commands += compute_bbox_trajectories(not_valid_trajectories, fps, frame_shape, args)
|
||||||
|
trajectories = valid_trajectories
|
||||||
|
|
||||||
|
## Assign bbox to trajectories, create new trajectories
|
||||||
|
for bbox in bboxes:
|
||||||
|
intersection = 0
|
||||||
|
current_trajectory = None
|
||||||
|
for trajectory in trajectories:
|
||||||
|
tube_bbox = trajectory[0]
|
||||||
|
current_intersection = bb_intersection_over_union(tube_bbox, bbox)
|
||||||
|
if intersection < current_intersection and current_intersection > args.iou_with_initial:
|
||||||
|
intersection = bb_intersection_over_union(tube_bbox, bbox)
|
||||||
|
current_trajectory = trajectory
|
||||||
|
|
||||||
|
## Create new trajectory
|
||||||
|
if current_trajectory is None:
|
||||||
|
trajectories.append([bbox, bbox, i, i])
|
||||||
|
else:
|
||||||
|
current_trajectory[3] = i
|
||||||
|
current_trajectory[1] = join(current_trajectory[1], bbox)
|
||||||
|
|
||||||
|
|
||||||
|
except IndexError as e:
|
||||||
|
raise (e)
|
||||||
|
|
||||||
|
commands += compute_bbox_trajectories(trajectories, fps, frame_shape, args)
|
||||||
|
return commands
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument("--image_shape", default=(256, 256), type=lambda x: tuple(map(int, x.split(','))),
|
||||||
|
help="Image shape")
|
||||||
|
parser.add_argument("--increase", default=0.1, type=float, help='Increase bbox by this amount')
|
||||||
|
parser.add_argument("--iou_with_initial", type=float, default=0.25, help="The minimal allowed iou with inital bbox")
|
||||||
|
parser.add_argument("--inp", required=True, help='Input image or video')
|
||||||
|
parser.add_argument("--min_frames", type=int, default=150, help='Minimum number of frames')
|
||||||
|
parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
|
||||||
|
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
commands = process_video(args)
|
||||||
|
for command in commands:
|
||||||
|
print (command)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,18 @@
|
|||||||
|
# TaiChi dataset
|
||||||
|
|
||||||
|
The scripst for loading the TaiChi dataset.
|
||||||
|
|
||||||
|
We provide only the id of the corresponding video and the bounding box. Following script will download videos from youtube and crop them according to the provided bounding boxes.
|
||||||
|
|
||||||
|
1) Load youtube-dl:
|
||||||
|
```
|
||||||
|
wget https://yt-dl.org/downloads/latest/youtube-dl -O youtube-dl
|
||||||
|
chmod a+rx youtube-dl
|
||||||
|
```
|
||||||
|
|
||||||
|
2) Run script to download videos, there are 2 formats that can be used for storing videos one is .mp4 and another is folder with .png images. While .png images occupy significantly more space, the format is loss-less and have better i/o performance when training.
|
||||||
|
|
||||||
|
```
|
||||||
|
python load_videos.py --metadata taichi-metadata.csv --format .mp4 --out_folder taichi --workers 8
|
||||||
|
```
|
||||||
|
select number of workers based on number of cpu avaliable. Note .png format take aproximatly 80GB.
|
@ -0,0 +1,113 @@
|
|||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import imageio
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
from multiprocessing import Pool
|
||||||
|
from itertools import cycle
|
||||||
|
import warnings
|
||||||
|
import glob
|
||||||
|
import time
|
||||||
|
from tqdm import tqdm
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
from skimage import img_as_ubyte
|
||||||
|
from skimage.transform import resize
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
|
DEVNULL = open(os.devnull, 'wb')
|
||||||
|
|
||||||
|
|
||||||
|
def save(path, frames, format):
|
||||||
|
if format == '.mp4':
|
||||||
|
imageio.mimsave(path, frames)
|
||||||
|
elif format == '.png':
|
||||||
|
if os.path.exists(path):
|
||||||
|
print ("Warning: skiping video %s" % os.path.basename(path))
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
os.makedirs(path)
|
||||||
|
for j, frame in enumerate(frames):
|
||||||
|
imageio.imsave(os.path.join(path, str(j).zfill(7) + '.png'), frames[j])
|
||||||
|
else:
|
||||||
|
print ("Unknown format %s" % format)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
|
||||||
|
def download(video_id, args):
|
||||||
|
video_path = os.path.join(args.video_folder, video_id + ".mp4")
|
||||||
|
subprocess.call([args.youtube, '-f', "''best/mp4''", '--write-auto-sub', '--write-sub',
|
||||||
|
'--sub-lang', 'en', '--skip-unavailable-fragments',
|
||||||
|
"https://www.youtube.com/watch?v=" + video_id, "--output",
|
||||||
|
video_path], stdout=DEVNULL, stderr=DEVNULL)
|
||||||
|
return video_path
|
||||||
|
|
||||||
|
|
||||||
|
def run(data):
|
||||||
|
video_id, args = data
|
||||||
|
if not os.path.exists(os.path.join(args.video_folder, video_id.split('#')[0] + '.mp4')):
|
||||||
|
download(video_id.split('#')[0], args)
|
||||||
|
|
||||||
|
if not os.path.exists(os.path.join(args.video_folder, video_id.split('#')[0] + '.mp4')):
|
||||||
|
print ('Can not load video %s, broken link' % video_id.split('#')[0])
|
||||||
|
return
|
||||||
|
reader = imageio.get_reader(os.path.join(args.video_folder, video_id.split('#')[0] + '.mp4'))
|
||||||
|
fps = reader.get_meta_data()['fps']
|
||||||
|
|
||||||
|
df = pd.read_csv(args.metadata)
|
||||||
|
df = df[df['video_id'] == video_id]
|
||||||
|
|
||||||
|
all_chunks_dict = [{'start': df['start'].iloc[j], 'end': df['end'].iloc[j],
|
||||||
|
'bbox': list(map(int, df['bbox'].iloc[j].split('-'))), 'frames':[]} for j in range(df.shape[0])]
|
||||||
|
ref_fps = df['fps'].iloc[0]
|
||||||
|
ref_height = df['height'].iloc[0]
|
||||||
|
ref_width = df['width'].iloc[0]
|
||||||
|
partition = df['partition'].iloc[0]
|
||||||
|
try:
|
||||||
|
for i, frame in enumerate(reader):
|
||||||
|
for entry in all_chunks_dict:
|
||||||
|
if (i * ref_fps >= entry['start'] * fps) and (i * ref_fps < entry['end'] * fps):
|
||||||
|
left, top, right, bot = entry['bbox']
|
||||||
|
left = int(left / (ref_width / frame.shape[1]))
|
||||||
|
top = int(top / (ref_height / frame.shape[0]))
|
||||||
|
right = int(right / (ref_width / frame.shape[1]))
|
||||||
|
bot = int(bot / (ref_height / frame.shape[0]))
|
||||||
|
crop = frame[top:bot, left:right]
|
||||||
|
if args.image_shape is not None:
|
||||||
|
crop = img_as_ubyte(resize(crop, args.image_shape, anti_aliasing=True))
|
||||||
|
entry['frames'].append(crop)
|
||||||
|
except imageio.core.format.CannotReadFrameError:
|
||||||
|
None
|
||||||
|
|
||||||
|
for entry in all_chunks_dict:
|
||||||
|
first_part = '#'.join(video_id.split('#')[::-1])
|
||||||
|
path = first_part + '#' + str(entry['start']).zfill(6) + '#' + str(entry['end']).zfill(6) + '.mp4'
|
||||||
|
save(os.path.join(args.out_folder, partition, path), entry['frames'], args.format)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = ArgumentParser()
|
||||||
|
parser.add_argument("--video_folder", default='youtube-taichi', help='Path to youtube videos')
|
||||||
|
parser.add_argument("--metadata", default='taichi-metadata-new.csv', help='Path to metadata')
|
||||||
|
parser.add_argument("--out_folder", default='taichi-png', help='Path to output')
|
||||||
|
parser.add_argument("--format", default='.png', help='Storing format')
|
||||||
|
parser.add_argument("--workers", default=1, type=int, help='Number of workers')
|
||||||
|
parser.add_argument("--youtube", default='./youtube-dl', help='Path to youtube-dl')
|
||||||
|
|
||||||
|
parser.add_argument("--image_shape", default=(256, 256), type=lambda x: tuple(map(int, x.split(','))),
|
||||||
|
help="Image shape, None for no resize")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
if not os.path.exists(args.video_folder):
|
||||||
|
os.makedirs(args.video_folder)
|
||||||
|
if not os.path.exists(args.out_folder):
|
||||||
|
os.makedirs(args.out_folder)
|
||||||
|
for partition in ['test', 'train']:
|
||||||
|
if not os.path.exists(os.path.join(args.out_folder, partition)):
|
||||||
|
os.makedirs(os.path.join(args.out_folder, partition))
|
||||||
|
|
||||||
|
df = pd.read_csv(args.metadata)
|
||||||
|
video_ids = set(df['video_id'])
|
||||||
|
pool = Pool(processes=args.workers)
|
||||||
|
args_list = cycle([args])
|
||||||
|
for chunks_data in tqdm(pool.imap_unordered(run, zip(video_ids, args_list))):
|
||||||
|
None
|
|
@ -0,0 +1,169 @@
|
|||||||
|
#!/user/bin/env python
|
||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
import matplotlib
|
||||||
|
|
||||||
|
matplotlib.use('Agg')
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import yaml
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import imageio
|
||||||
|
import numpy as np
|
||||||
|
from skimage.transform import resize
|
||||||
|
from skimage import img_as_ubyte
|
||||||
|
import torch
|
||||||
|
from sync_batchnorm import DataParallelWithCallback
|
||||||
|
|
||||||
|
from modules.generator import OcclusionAwareGenerator
|
||||||
|
from modules.keypoint_detector import KPDetector
|
||||||
|
from animate import normalize_kp
|
||||||
|
from scipy.spatial import ConvexHull
|
||||||
|
|
||||||
|
if sys.version_info[0] < 3:
|
||||||
|
raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")
|
||||||
|
|
||||||
|
|
||||||
|
def load_checkpoints(config_path, checkpoint_path, cpu=True):
|
||||||
|
with open(config_path) as f:
|
||||||
|
config = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
|
|
||||||
|
generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
|
||||||
|
**config['model_params']['common_params'])
|
||||||
|
if not cpu:
|
||||||
|
generator.cuda()
|
||||||
|
|
||||||
|
kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
|
||||||
|
**config['model_params']['common_params'])
|
||||||
|
if not cpu:
|
||||||
|
kp_detector.cuda()
|
||||||
|
|
||||||
|
if cpu:
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||||
|
else:
|
||||||
|
checkpoint = torch.load(checkpoint_path)
|
||||||
|
|
||||||
|
generator.load_state_dict(checkpoint['generator'])
|
||||||
|
kp_detector.load_state_dict(checkpoint['kp_detector'])
|
||||||
|
|
||||||
|
if not cpu:
|
||||||
|
generator = DataParallelWithCallback(generator)
|
||||||
|
kp_detector = DataParallelWithCallback(kp_detector)
|
||||||
|
|
||||||
|
generator.eval()
|
||||||
|
kp_detector.eval()
|
||||||
|
|
||||||
|
return generator, kp_detector
|
||||||
|
|
||||||
|
|
||||||
|
def make_animation(source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True,
|
||||||
|
cpu=True):
|
||||||
|
with torch.no_grad():
|
||||||
|
predictions = []
|
||||||
|
source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
|
||||||
|
if not cpu:
|
||||||
|
source = source.cuda()
|
||||||
|
driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
|
||||||
|
kp_source = kp_detector(source)
|
||||||
|
kp_driving_initial = kp_detector(driving[:, :, 0])
|
||||||
|
|
||||||
|
for frame_idx in tqdm(range(driving.shape[2])):
|
||||||
|
driving_frame = driving[:, :, frame_idx]
|
||||||
|
if not cpu:
|
||||||
|
driving_frame = driving_frame.cuda()
|
||||||
|
kp_driving = kp_detector(driving_frame)
|
||||||
|
kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
|
||||||
|
kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
|
||||||
|
use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale)
|
||||||
|
out = generator(source, kp_source=kp_source, kp_driving=kp_norm)
|
||||||
|
|
||||||
|
predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
|
||||||
|
return predictions
|
||||||
|
|
||||||
|
|
||||||
|
def find_best_frame(source, driving, cpu=False):
|
||||||
|
import face_alignment
|
||||||
|
|
||||||
|
def normalize_kp(kp):
|
||||||
|
kp = kp - kp.mean(axis=0, keepdims=True)
|
||||||
|
area = ConvexHull(kp[:, :2]).volume
|
||||||
|
area = np.sqrt(area)
|
||||||
|
kp[:, :2] = kp[:, :2] / area
|
||||||
|
return kp
|
||||||
|
|
||||||
|
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
|
||||||
|
device='cpu' if cpu else 'cuda')
|
||||||
|
kp_source = fa.get_landmarks(255 * source)[0]
|
||||||
|
kp_source = normalize_kp(kp_source)
|
||||||
|
norm = float('inf')
|
||||||
|
frame_num = 0
|
||||||
|
for i, image in tqdm(enumerate(driving)):
|
||||||
|
kp_driving = fa.get_landmarks(255 * image)[0]
|
||||||
|
kp_driving = normalize_kp(kp_driving)
|
||||||
|
new_norm = (np.abs(kp_source - kp_driving) ** 2).sum()
|
||||||
|
if new_norm < norm:
|
||||||
|
norm = new_norm
|
||||||
|
frame_num = i
|
||||||
|
return frame_num
|
||||||
|
|
||||||
|
|
||||||
|
def h_interface(input_image: np.ndarray):
|
||||||
|
parser = ArgumentParser()
|
||||||
|
opt = parser.parse_args()
|
||||||
|
opt.config = "./config/vox-adv-256.yaml"
|
||||||
|
opt.checkpoint = "./checkpoints/vox-adv-cpk.pth.tar"
|
||||||
|
opt.source_image = input_image
|
||||||
|
opt.driving_video = "./data/chuck.mp4"
|
||||||
|
opt.result_video = "./data/result.mp4".format(uuid.uuid1().hex)
|
||||||
|
opt.relative = True
|
||||||
|
opt.adapt_scale = True
|
||||||
|
opt.cpu = True
|
||||||
|
opt.find_best_frame = False
|
||||||
|
opt.best_frame = False
|
||||||
|
|
||||||
|
source_image = opt.source_image
|
||||||
|
reader = imageio.get_reader(opt.driving_video)
|
||||||
|
fps = reader.get_meta_data()['fps']
|
||||||
|
driving_video = []
|
||||||
|
try:
|
||||||
|
for im in reader:
|
||||||
|
driving_video.append(im)
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
reader.close()
|
||||||
|
|
||||||
|
source_image = resize(source_image, (256, 256))[..., :3]
|
||||||
|
driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
|
||||||
|
generator, kp_detector = load_checkpoints(config_path=opt.config, checkpoint_path=opt.checkpoint, cpu=opt.cpu)
|
||||||
|
|
||||||
|
if opt.find_best_frame or opt.best_frame is not None:
|
||||||
|
i = opt.best_frame if opt.best_frame is not None else find_best_frame(source_image, driving_video, cpu=opt.cpu)
|
||||||
|
print("Best frame: " + str(i))
|
||||||
|
driving_forward = driving_video[i:]
|
||||||
|
driving_backward = driving_video[:(i + 1)][::-1]
|
||||||
|
predictions_forward = make_animation(source_image, driving_forward, generator, kp_detector,
|
||||||
|
relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
|
||||||
|
predictions_backward = make_animation(source_image, driving_backward, generator, kp_detector,
|
||||||
|
relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
|
||||||
|
predictions = predictions_backward[::-1] + predictions_forward[1:]
|
||||||
|
else:
|
||||||
|
predictions = make_animation(source_image, driving_video, generator, kp_detector, relative=opt.relative,
|
||||||
|
adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
|
||||||
|
imageio.mimsave(opt.result_video, [img_as_ubyte(frame) for frame in predictions], fps=fps)
|
||||||
|
return opt.result_video
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
demo = gr.Interface(
|
||||||
|
fn=h_interface,
|
||||||
|
inputs=gr.Image(type="numpy", label="Input Image"),
|
||||||
|
outputs=gr.Video(label="Output Video")
|
||||||
|
)
|
||||||
|
|
||||||
|
demo.launch()
|
@ -0,0 +1,168 @@
|
|||||||
|
import sys
|
||||||
|
import yaml
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
|
import imageio
|
||||||
|
import numpy as np
|
||||||
|
from skimage.transform import resize
|
||||||
|
from skimage import img_as_ubyte
|
||||||
|
import torch
|
||||||
|
from sync_batchnorm import DataParallelWithCallback
|
||||||
|
|
||||||
|
from modules.generator import OcclusionAwareGenerator
|
||||||
|
from modules.keypoint_detector import KPDetector
|
||||||
|
from animate import normalize_kp
|
||||||
|
|
||||||
|
import ffmpeg
|
||||||
|
from os.path import splitext
|
||||||
|
from shutil import copyfileobj
|
||||||
|
from tempfile import NamedTemporaryFile
|
||||||
|
|
||||||
|
if sys.version_info[0] < 3:
|
||||||
|
raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")
|
||||||
|
|
||||||
|
def load_checkpoints(config_path, checkpoint_path, cpu=True):
|
||||||
|
|
||||||
|
with open(config_path) as f:
|
||||||
|
config = yaml.full_load(f)
|
||||||
|
|
||||||
|
generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
|
||||||
|
**config['model_params']['common_params'])
|
||||||
|
if not cpu:
|
||||||
|
generator.cuda()
|
||||||
|
|
||||||
|
kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
|
||||||
|
**config['model_params']['common_params'])
|
||||||
|
if not cpu:
|
||||||
|
kp_detector.cuda()
|
||||||
|
|
||||||
|
if cpu:
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||||
|
else:
|
||||||
|
checkpoint = torch.load(checkpoint_path)
|
||||||
|
|
||||||
|
generator.load_state_dict(checkpoint['generator'])
|
||||||
|
kp_detector.load_state_dict(checkpoint['kp_detector'])
|
||||||
|
|
||||||
|
if not cpu:
|
||||||
|
generator = DataParallelWithCallback(generator)
|
||||||
|
kp_detector = DataParallelWithCallback(kp_detector)
|
||||||
|
|
||||||
|
generator.eval()
|
||||||
|
kp_detector.eval()
|
||||||
|
|
||||||
|
return generator, kp_detector
|
||||||
|
|
||||||
|
|
||||||
|
def make_animation(source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=True):
|
||||||
|
with torch.no_grad():
|
||||||
|
predictions = []
|
||||||
|
source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
|
||||||
|
if not cpu:
|
||||||
|
source = source.cuda()
|
||||||
|
driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
|
||||||
|
kp_source = kp_detector(source)
|
||||||
|
kp_driving_initial = kp_detector(driving[:, :, 0])
|
||||||
|
|
||||||
|
for frame_idx in tqdm(range(driving.shape[2])):
|
||||||
|
driving_frame = driving[:, :, frame_idx]
|
||||||
|
if not cpu:
|
||||||
|
driving_frame = driving_frame.cuda()
|
||||||
|
kp_driving = kp_detector(driving_frame)
|
||||||
|
kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
|
||||||
|
kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
|
||||||
|
use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale)
|
||||||
|
out = generator(source, kp_source=kp_source, kp_driving=kp_norm)
|
||||||
|
|
||||||
|
predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
|
||||||
|
return predictions
|
||||||
|
|
||||||
|
def find_best_frame(source, driving, cpu=False):
|
||||||
|
import face_alignment # type: ignore (local file)
|
||||||
|
from scipy.spatial import ConvexHull
|
||||||
|
|
||||||
|
def normalize_kp(kp):
|
||||||
|
kp = kp - kp.mean(axis=0, keepdims=True)
|
||||||
|
area = ConvexHull(kp[:, :2]).volume
|
||||||
|
area = np.sqrt(area)
|
||||||
|
kp[:, :2] = kp[:, :2] / area
|
||||||
|
return kp
|
||||||
|
|
||||||
|
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
|
||||||
|
device='cpu' if cpu else 'cuda')
|
||||||
|
kp_source = fa.get_landmarks(255 * source)[0]
|
||||||
|
kp_source = normalize_kp(kp_source)
|
||||||
|
norm = float('inf')
|
||||||
|
frame_num = 0
|
||||||
|
for i, image in tqdm(enumerate(driving)):
|
||||||
|
kp_driving = fa.get_landmarks(255 * image)[0]
|
||||||
|
kp_driving = normalize_kp(kp_driving)
|
||||||
|
new_norm = (np.abs(kp_source - kp_driving) ** 2).sum()
|
||||||
|
if new_norm < norm:
|
||||||
|
norm = new_norm
|
||||||
|
frame_num = i
|
||||||
|
return frame_num
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = ArgumentParser()
|
||||||
|
parser.add_argument("--config", required=True, help="path to config")
|
||||||
|
parser.add_argument("--checkpoint", default='vox-cpk.pth.tar', help="path to checkpoint to restore")
|
||||||
|
|
||||||
|
parser.add_argument("--source_image", default='sup-mat/source.png', help="path to source image")
|
||||||
|
parser.add_argument("--driving_video", default='driving.mp4', help="path to driving video")
|
||||||
|
parser.add_argument("--result_video", default='result.mp4', help="path to output")
|
||||||
|
|
||||||
|
parser.add_argument("--relative", dest="relative", action="store_true", help="use relative or absolute keypoint coordinates")
|
||||||
|
parser.add_argument("--adapt_scale", dest="adapt_scale", action="store_true", help="adapt movement scale based on convex hull of keypoints")
|
||||||
|
|
||||||
|
parser.add_argument("--find_best_frame", dest="find_best_frame", action="store_true",
|
||||||
|
help="Generate from the frame that is the most alligned with source. (Only for faces, requires face_aligment lib)")
|
||||||
|
|
||||||
|
parser.add_argument("--best_frame", dest="best_frame", type=int, default=None, help="Set frame to start from.")
|
||||||
|
|
||||||
|
parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
|
||||||
|
|
||||||
|
parser.add_argument("--audio", dest="audio", action="store_true", help="copy audio to output from the driving video" )
|
||||||
|
|
||||||
|
parser.set_defaults(relative=False)
|
||||||
|
parser.set_defaults(adapt_scale=False)
|
||||||
|
parser.set_defaults(audio_on=False)
|
||||||
|
|
||||||
|
opt = parser.parse_args()
|
||||||
|
|
||||||
|
source_image = imageio.imread(opt.source_image)
|
||||||
|
reader = imageio.get_reader(opt.driving_video)
|
||||||
|
fps = reader.get_meta_data()['fps']
|
||||||
|
driving_video = []
|
||||||
|
try:
|
||||||
|
for im in reader:
|
||||||
|
driving_video.append(im)
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
reader.close()
|
||||||
|
|
||||||
|
source_image = resize(source_image, (256, 256))[..., :3]
|
||||||
|
driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
|
||||||
|
generator, kp_detector = load_checkpoints(config_path=opt.config, checkpoint_path=opt.checkpoint, cpu=opt.cpu)
|
||||||
|
|
||||||
|
if opt.find_best_frame or opt.best_frame is not None:
|
||||||
|
i = opt.best_frame if opt.best_frame is not None else find_best_frame(source_image, driving_video, cpu=opt.cpu)
|
||||||
|
print("Best frame: " + str(i))
|
||||||
|
driving_forward = driving_video[i:]
|
||||||
|
driving_backward = driving_video[:(i+1)][::-1]
|
||||||
|
predictions_forward = make_animation(source_image, driving_forward, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
|
||||||
|
predictions_backward = make_animation(source_image, driving_backward, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
|
||||||
|
predictions = predictions_backward[::-1] + predictions_forward[1:]
|
||||||
|
else:
|
||||||
|
predictions = make_animation(source_image, driving_video, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
|
||||||
|
imageio.mimsave(opt.result_video, [img_as_ubyte(frame) for frame in predictions], fps=fps)
|
||||||
|
|
||||||
|
if opt.audio:
|
||||||
|
try:
|
||||||
|
with NamedTemporaryFile(suffix=splitext(opt.result_video)[1]) as output:
|
||||||
|
ffmpeg.output(ffmpeg.input(opt.result_video).video, ffmpeg.input(opt.driving_video).audio, output.name, c='copy').run()
|
||||||
|
with open(opt.result_video, 'wb') as result:
|
||||||
|
copyfileobj(output, result)
|
||||||
|
except ffmpeg.Error:
|
||||||
|
print("Failed to copy audio: the driving video may have no audio track or the audio format is invalid.")
|
After Width: | Height: | Size: 52 KiB |
After Width: | Height: | Size: 242 KiB |
After Width: | Height: | Size: 1007 KiB |
|
@ -0,0 +1,197 @@
|
|||||||
|
import os
|
||||||
|
from skimage import io, img_as_float32
|
||||||
|
from skimage.color import gray2rgb
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
from imageio import mimread
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
import pandas as pd
|
||||||
|
from augmentation import AllAugmentationTransform
|
||||||
|
import glob
|
||||||
|
|
||||||
|
|
||||||
|
def read_video(name, frame_shape):
|
||||||
|
"""
|
||||||
|
Read video which can be:
|
||||||
|
- an image of concatenated frames
|
||||||
|
- '.mp4' and'.gif'
|
||||||
|
- folder with videos
|
||||||
|
"""
|
||||||
|
|
||||||
|
if os.path.isdir(name):
|
||||||
|
frames = sorted(os.listdir(name))
|
||||||
|
num_frames = len(frames)
|
||||||
|
video_array = np.array(
|
||||||
|
[img_as_float32(io.imread(os.path.join(name, frames[idx]))) for idx in range(num_frames)])
|
||||||
|
elif name.lower().endswith('.png') or name.lower().endswith('.jpg'):
|
||||||
|
image = io.imread(name)
|
||||||
|
|
||||||
|
if len(image.shape) == 2 or image.shape[2] == 1:
|
||||||
|
image = gray2rgb(image)
|
||||||
|
|
||||||
|
if image.shape[2] == 4:
|
||||||
|
image = image[..., :3]
|
||||||
|
|
||||||
|
image = img_as_float32(image)
|
||||||
|
|
||||||
|
video_array = np.moveaxis(image, 1, 0)
|
||||||
|
|
||||||
|
video_array = video_array.reshape((-1,) + frame_shape)
|
||||||
|
video_array = np.moveaxis(video_array, 1, 2)
|
||||||
|
elif name.lower().endswith('.gif') or name.lower().endswith('.mp4') or name.lower().endswith('.mov'):
|
||||||
|
video = np.array(mimread(name))
|
||||||
|
if len(video.shape) == 3:
|
||||||
|
video = np.array([gray2rgb(frame) for frame in video])
|
||||||
|
if video.shape[-1] == 4:
|
||||||
|
video = video[..., :3]
|
||||||
|
video_array = img_as_float32(video)
|
||||||
|
else:
|
||||||
|
raise Exception("Unknown file extensions %s" % name)
|
||||||
|
|
||||||
|
return video_array
|
||||||
|
|
||||||
|
|
||||||
|
class FramesDataset(Dataset):
|
||||||
|
"""
|
||||||
|
Dataset of videos, each video can be represented as:
|
||||||
|
- an image of concatenated frames
|
||||||
|
- '.mp4' or '.gif'
|
||||||
|
- folder with all frames
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, root_dir, frame_shape=(256, 256, 3), id_sampling=False, is_train=True,
|
||||||
|
random_seed=0, pairs_list=None, augmentation_params=None):
|
||||||
|
self.root_dir = root_dir
|
||||||
|
self.videos = os.listdir(root_dir)
|
||||||
|
self.frame_shape = tuple(frame_shape)
|
||||||
|
self.pairs_list = pairs_list
|
||||||
|
self.id_sampling = id_sampling
|
||||||
|
if os.path.exists(os.path.join(root_dir, 'train')):
|
||||||
|
assert os.path.exists(os.path.join(root_dir, 'test'))
|
||||||
|
print("Use predefined train-test split.")
|
||||||
|
if id_sampling:
|
||||||
|
train_videos = {os.path.basename(video).split('#')[0] for video in
|
||||||
|
os.listdir(os.path.join(root_dir, 'train'))}
|
||||||
|
train_videos = list(train_videos)
|
||||||
|
else:
|
||||||
|
train_videos = os.listdir(os.path.join(root_dir, 'train'))
|
||||||
|
test_videos = os.listdir(os.path.join(root_dir, 'test'))
|
||||||
|
self.root_dir = os.path.join(self.root_dir, 'train' if is_train else 'test')
|
||||||
|
else:
|
||||||
|
print("Use random train-test split.")
|
||||||
|
train_videos, test_videos = train_test_split(self.videos, random_state=random_seed, test_size=0.2)
|
||||||
|
|
||||||
|
if is_train:
|
||||||
|
self.videos = train_videos
|
||||||
|
else:
|
||||||
|
self.videos = test_videos
|
||||||
|
|
||||||
|
self.is_train = is_train
|
||||||
|
|
||||||
|
if self.is_train:
|
||||||
|
self.transform = AllAugmentationTransform(**augmentation_params)
|
||||||
|
else:
|
||||||
|
self.transform = None
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.videos)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
if self.is_train and self.id_sampling:
|
||||||
|
name = self.videos[idx]
|
||||||
|
path = np.random.choice(glob.glob(os.path.join(self.root_dir, name + '*.mp4')))
|
||||||
|
else:
|
||||||
|
name = self.videos[idx]
|
||||||
|
path = os.path.join(self.root_dir, name)
|
||||||
|
|
||||||
|
video_name = os.path.basename(path)
|
||||||
|
|
||||||
|
if self.is_train and os.path.isdir(path):
|
||||||
|
frames = os.listdir(path)
|
||||||
|
num_frames = len(frames)
|
||||||
|
frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2))
|
||||||
|
video_array = [img_as_float32(io.imread(os.path.join(path, frames[idx]))) for idx in frame_idx]
|
||||||
|
else:
|
||||||
|
video_array = read_video(path, frame_shape=self.frame_shape)
|
||||||
|
num_frames = len(video_array)
|
||||||
|
frame_idx = np.sort(np.random.choice(num_frames, replace=True, size=2)) if self.is_train else range(
|
||||||
|
num_frames)
|
||||||
|
video_array = video_array[frame_idx]
|
||||||
|
|
||||||
|
if self.transform is not None:
|
||||||
|
video_array = self.transform(video_array)
|
||||||
|
|
||||||
|
out = {}
|
||||||
|
if self.is_train:
|
||||||
|
source = np.array(video_array[0], dtype='float32')
|
||||||
|
driving = np.array(video_array[1], dtype='float32')
|
||||||
|
|
||||||
|
out['driving'] = driving.transpose((2, 0, 1))
|
||||||
|
out['source'] = source.transpose((2, 0, 1))
|
||||||
|
else:
|
||||||
|
video = np.array(video_array, dtype='float32')
|
||||||
|
out['video'] = video.transpose((3, 0, 1, 2))
|
||||||
|
|
||||||
|
out['name'] = video_name
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetRepeater(Dataset):
|
||||||
|
"""
|
||||||
|
Pass several times over the same dataset for better i/o performance
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dataset, num_repeats=100):
|
||||||
|
self.dataset = dataset
|
||||||
|
self.num_repeats = num_repeats
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.num_repeats * self.dataset.__len__()
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
return self.dataset[idx % self.dataset.__len__()]
|
||||||
|
|
||||||
|
|
||||||
|
class PairedDataset(Dataset):
|
||||||
|
"""
|
||||||
|
Dataset of pairs for animation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, initial_dataset, number_of_pairs, seed=0):
|
||||||
|
self.initial_dataset = initial_dataset
|
||||||
|
pairs_list = self.initial_dataset.pairs_list
|
||||||
|
|
||||||
|
np.random.seed(seed)
|
||||||
|
|
||||||
|
if pairs_list is None:
|
||||||
|
max_idx = min(number_of_pairs, len(initial_dataset))
|
||||||
|
nx, ny = max_idx, max_idx
|
||||||
|
xy = np.mgrid[:nx, :ny].reshape(2, -1).T
|
||||||
|
number_of_pairs = min(xy.shape[0], number_of_pairs)
|
||||||
|
self.pairs = xy.take(np.random.choice(xy.shape[0], number_of_pairs, replace=False), axis=0)
|
||||||
|
else:
|
||||||
|
videos = self.initial_dataset.videos
|
||||||
|
name_to_index = {name: index for index, name in enumerate(videos)}
|
||||||
|
pairs = pd.read_csv(pairs_list)
|
||||||
|
pairs = pairs[np.logical_and(pairs['source'].isin(videos), pairs['driving'].isin(videos))]
|
||||||
|
|
||||||
|
number_of_pairs = min(pairs.shape[0], number_of_pairs)
|
||||||
|
self.pairs = []
|
||||||
|
self.start_frames = []
|
||||||
|
for ind in range(number_of_pairs):
|
||||||
|
self.pairs.append(
|
||||||
|
(name_to_index[pairs['driving'].iloc[ind]], name_to_index[pairs['source'].iloc[ind]]))
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.pairs)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
pair = self.pairs[idx]
|
||||||
|
first = self.initial_dataset[pair[0]]
|
||||||
|
second = self.initial_dataset[pair[1]]
|
||||||
|
first = {'driving_' + key: value for key, value in first.items()}
|
||||||
|
second = {'source_' + key: value for key, value in second.items()}
|
||||||
|
|
||||||
|
return {**first, **second}
|
@ -0,0 +1,211 @@
|
|||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import imageio
|
||||||
|
|
||||||
|
import os
|
||||||
|
from skimage.draw import disk
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import collections
|
||||||
|
|
||||||
|
|
||||||
|
class Logger:
|
||||||
|
def __init__(self, log_dir, checkpoint_freq=100, visualizer_params=None, zfill_num=8, log_file_name='log.txt'):
|
||||||
|
|
||||||
|
self.loss_list = []
|
||||||
|
self.cpk_dir = log_dir
|
||||||
|
self.visualizations_dir = os.path.join(log_dir, 'train-vis')
|
||||||
|
if not os.path.exists(self.visualizations_dir):
|
||||||
|
os.makedirs(self.visualizations_dir)
|
||||||
|
self.log_file = open(os.path.join(log_dir, log_file_name), 'a')
|
||||||
|
self.zfill_num = zfill_num
|
||||||
|
self.visualizer = Visualizer(**visualizer_params)
|
||||||
|
self.checkpoint_freq = checkpoint_freq
|
||||||
|
self.epoch = 0
|
||||||
|
self.best_loss = float('inf')
|
||||||
|
self.names = None
|
||||||
|
|
||||||
|
def log_scores(self, loss_names):
|
||||||
|
loss_mean = np.array(self.loss_list).mean(axis=0)
|
||||||
|
|
||||||
|
loss_string = "; ".join(["%s - %.5f" % (name, value) for name, value in zip(loss_names, loss_mean)])
|
||||||
|
loss_string = str(self.epoch).zfill(self.zfill_num) + ") " + loss_string
|
||||||
|
|
||||||
|
print(loss_string, file=self.log_file)
|
||||||
|
self.loss_list = []
|
||||||
|
self.log_file.flush()
|
||||||
|
|
||||||
|
def visualize_rec(self, inp, out):
|
||||||
|
image = self.visualizer.visualize(inp['driving'], inp['source'], out)
|
||||||
|
imageio.imsave(os.path.join(self.visualizations_dir, "%s-rec.png" % str(self.epoch).zfill(self.zfill_num)), image)
|
||||||
|
|
||||||
|
def save_cpk(self, emergent=False):
|
||||||
|
cpk = {k: v.state_dict() for k, v in self.models.items()}
|
||||||
|
cpk['epoch'] = self.epoch
|
||||||
|
cpk_path = os.path.join(self.cpk_dir, '%s-checkpoint.pth.tar' % str(self.epoch).zfill(self.zfill_num))
|
||||||
|
if not (os.path.exists(cpk_path) and emergent):
|
||||||
|
torch.save(cpk, cpk_path)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_cpk(checkpoint_path, generator=None, discriminator=None, kp_detector=None,
|
||||||
|
optimizer_generator=None, optimizer_discriminator=None, optimizer_kp_detector=None):
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
map_location = None
|
||||||
|
else:
|
||||||
|
map_location = 'cpu'
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location)
|
||||||
|
if generator is not None:
|
||||||
|
generator.load_state_dict(checkpoint['generator'])
|
||||||
|
if kp_detector is not None:
|
||||||
|
kp_detector.load_state_dict(checkpoint['kp_detector'])
|
||||||
|
if discriminator is not None:
|
||||||
|
try:
|
||||||
|
discriminator.load_state_dict(checkpoint['discriminator'])
|
||||||
|
except:
|
||||||
|
print ('No discriminator in the state-dict. Dicriminator will be randomly initialized')
|
||||||
|
if optimizer_generator is not None:
|
||||||
|
optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])
|
||||||
|
if optimizer_discriminator is not None:
|
||||||
|
try:
|
||||||
|
optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
|
||||||
|
except RuntimeError as e:
|
||||||
|
print ('No discriminator optimizer in the state-dict. Optimizer will be not initialized')
|
||||||
|
if optimizer_kp_detector is not None:
|
||||||
|
optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector'])
|
||||||
|
|
||||||
|
return checkpoint['epoch']
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
if 'models' in self.__dict__:
|
||||||
|
self.save_cpk()
|
||||||
|
self.log_file.close()
|
||||||
|
|
||||||
|
def log_iter(self, losses):
|
||||||
|
losses = collections.OrderedDict(losses.items())
|
||||||
|
if self.names is None:
|
||||||
|
self.names = list(losses.keys())
|
||||||
|
self.loss_list.append(list(losses.values()))
|
||||||
|
|
||||||
|
def log_epoch(self, epoch, models, inp, out):
|
||||||
|
self.epoch = epoch
|
||||||
|
self.models = models
|
||||||
|
if (self.epoch + 1) % self.checkpoint_freq == 0:
|
||||||
|
self.save_cpk()
|
||||||
|
self.log_scores(self.names)
|
||||||
|
self.visualize_rec(inp, out)
|
||||||
|
|
||||||
|
|
||||||
|
class Visualizer:
|
||||||
|
def __init__(self, kp_size=5, draw_border=False, colormap='gist_rainbow'):
|
||||||
|
self.kp_size = kp_size
|
||||||
|
self.draw_border = draw_border
|
||||||
|
self.colormap = plt.get_cmap(colormap)
|
||||||
|
|
||||||
|
def draw_image_with_kp(self, image, kp_array):
|
||||||
|
image = np.copy(image)
|
||||||
|
spatial_size = np.array(image.shape[:2][::-1])[np.newaxis]
|
||||||
|
kp_array = spatial_size * (kp_array + 1) / 2
|
||||||
|
num_kp = kp_array.shape[0]
|
||||||
|
for kp_ind, kp in enumerate(kp_array):
|
||||||
|
rr, cc = disk(kp[1], kp[0], self.kp_size, shape=image.shape[:2])
|
||||||
|
image[rr, cc] = np.array(self.colormap(kp_ind / num_kp))[:3]
|
||||||
|
return image
|
||||||
|
|
||||||
|
def create_image_column_with_kp(self, images, kp):
|
||||||
|
image_array = np.array([self.draw_image_with_kp(v, k) for v, k in zip(images, kp)])
|
||||||
|
return self.create_image_column(image_array)
|
||||||
|
|
||||||
|
def create_image_column(self, images):
|
||||||
|
if self.draw_border:
|
||||||
|
images = np.copy(images)
|
||||||
|
images[:, :, [0, -1]] = (1, 1, 1)
|
||||||
|
return np.concatenate(list(images), axis=0)
|
||||||
|
|
||||||
|
def create_image_grid(self, *args):
|
||||||
|
out = []
|
||||||
|
for arg in args:
|
||||||
|
if type(arg) == tuple:
|
||||||
|
out.append(self.create_image_column_with_kp(arg[0], arg[1]))
|
||||||
|
else:
|
||||||
|
out.append(self.create_image_column(arg))
|
||||||
|
return np.concatenate(out, axis=1)
|
||||||
|
|
||||||
|
def visualize(self, driving, source, out):
|
||||||
|
images = []
|
||||||
|
|
||||||
|
# Source image with keypoints
|
||||||
|
source = source.data.cpu()
|
||||||
|
kp_source = out['kp_source']['value'].data.cpu().numpy()
|
||||||
|
source = np.transpose(source, [0, 2, 3, 1])
|
||||||
|
images.append((source, kp_source))
|
||||||
|
|
||||||
|
# Equivariance visualization
|
||||||
|
if 'transformed_frame' in out:
|
||||||
|
transformed = out['transformed_frame'].data.cpu().numpy()
|
||||||
|
transformed = np.transpose(transformed, [0, 2, 3, 1])
|
||||||
|
transformed_kp = out['transformed_kp']['value'].data.cpu().numpy()
|
||||||
|
images.append((transformed, transformed_kp))
|
||||||
|
|
||||||
|
# Driving image with keypoints
|
||||||
|
kp_driving = out['kp_driving']['value'].data.cpu().numpy()
|
||||||
|
driving = driving.data.cpu().numpy()
|
||||||
|
driving = np.transpose(driving, [0, 2, 3, 1])
|
||||||
|
images.append((driving, kp_driving))
|
||||||
|
|
||||||
|
# Deformed image
|
||||||
|
if 'deformed' in out:
|
||||||
|
deformed = out['deformed'].data.cpu().numpy()
|
||||||
|
deformed = np.transpose(deformed, [0, 2, 3, 1])
|
||||||
|
images.append(deformed)
|
||||||
|
|
||||||
|
# Result with and without keypoints
|
||||||
|
prediction = out['prediction'].data.cpu().numpy()
|
||||||
|
prediction = np.transpose(prediction, [0, 2, 3, 1])
|
||||||
|
if 'kp_norm' in out:
|
||||||
|
kp_norm = out['kp_norm']['value'].data.cpu().numpy()
|
||||||
|
images.append((prediction, kp_norm))
|
||||||
|
images.append(prediction)
|
||||||
|
|
||||||
|
|
||||||
|
## Occlusion map
|
||||||
|
if 'occlusion_map' in out:
|
||||||
|
occlusion_map = out['occlusion_map'].data.cpu().repeat(1, 3, 1, 1)
|
||||||
|
occlusion_map = F.interpolate(occlusion_map, size=source.shape[1:3]).numpy()
|
||||||
|
occlusion_map = np.transpose(occlusion_map, [0, 2, 3, 1])
|
||||||
|
images.append(occlusion_map)
|
||||||
|
|
||||||
|
# Deformed images according to each individual transform
|
||||||
|
if 'sparse_deformed' in out:
|
||||||
|
full_mask = []
|
||||||
|
for i in range(out['sparse_deformed'].shape[1]):
|
||||||
|
image = out['sparse_deformed'][:, i].data.cpu()
|
||||||
|
image = F.interpolate(image, size=source.shape[1:3])
|
||||||
|
mask = out['mask'][:, i:(i+1)].data.cpu().repeat(1, 3, 1, 1)
|
||||||
|
mask = F.interpolate(mask, size=source.shape[1:3])
|
||||||
|
image = np.transpose(image.numpy(), (0, 2, 3, 1))
|
||||||
|
mask = np.transpose(mask.numpy(), (0, 2, 3, 1))
|
||||||
|
|
||||||
|
if i != 0:
|
||||||
|
color = np.array(self.colormap((i - 1) / (out['sparse_deformed'].shape[1] - 1)))[:3]
|
||||||
|
else:
|
||||||
|
color = np.array((0, 0, 0))
|
||||||
|
|
||||||
|
color = color.reshape((1, 1, 1, 3))
|
||||||
|
|
||||||
|
images.append(image)
|
||||||
|
if i != 0:
|
||||||
|
images.append(mask * color)
|
||||||
|
else:
|
||||||
|
images.append(mask)
|
||||||
|
|
||||||
|
full_mask.append(mask * color)
|
||||||
|
|
||||||
|
images.append(sum(full_mask))
|
||||||
|
|
||||||
|
image = self.create_image_grid(*images)
|
||||||
|
image = (255 * image).astype(np.uint8)
|
||||||
|
return image
|
@ -0,0 +1,113 @@
|
|||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch
|
||||||
|
from modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian
|
||||||
|
|
||||||
|
|
||||||
|
class DenseMotionNetwork(nn.Module):
|
||||||
|
"""
|
||||||
|
Module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, block_expansion, num_blocks, max_features, num_kp, num_channels, estimate_occlusion_map=False,
|
||||||
|
scale_factor=1, kp_variance=0.01):
|
||||||
|
super(DenseMotionNetwork, self).__init__()
|
||||||
|
self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp + 1) * (num_channels + 1),
|
||||||
|
max_features=max_features, num_blocks=num_blocks)
|
||||||
|
|
||||||
|
self.mask = nn.Conv2d(self.hourglass.out_filters, num_kp + 1, kernel_size=(7, 7), padding=(3, 3))
|
||||||
|
|
||||||
|
if estimate_occlusion_map:
|
||||||
|
self.occlusion = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3))
|
||||||
|
else:
|
||||||
|
self.occlusion = None
|
||||||
|
|
||||||
|
self.num_kp = num_kp
|
||||||
|
self.scale_factor = scale_factor
|
||||||
|
self.kp_variance = kp_variance
|
||||||
|
|
||||||
|
if self.scale_factor != 1:
|
||||||
|
self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)
|
||||||
|
|
||||||
|
def create_heatmap_representations(self, source_image, kp_driving, kp_source):
|
||||||
|
"""
|
||||||
|
Eq 6. in the paper H_k(z)
|
||||||
|
"""
|
||||||
|
spatial_size = source_image.shape[2:]
|
||||||
|
gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=self.kp_variance)
|
||||||
|
gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=self.kp_variance)
|
||||||
|
heatmap = gaussian_driving - gaussian_source
|
||||||
|
|
||||||
|
#adding background feature
|
||||||
|
zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type(heatmap.type())
|
||||||
|
heatmap = torch.cat([zeros, heatmap], dim=1)
|
||||||
|
heatmap = heatmap.unsqueeze(2)
|
||||||
|
return heatmap
|
||||||
|
|
||||||
|
def create_sparse_motions(self, source_image, kp_driving, kp_source):
|
||||||
|
"""
|
||||||
|
Eq 4. in the paper T_{s<-d}(z)
|
||||||
|
"""
|
||||||
|
bs, _, h, w = source_image.shape
|
||||||
|
identity_grid = make_coordinate_grid((h, w), type=kp_source['value'].type())
|
||||||
|
identity_grid = identity_grid.view(1, 1, h, w, 2)
|
||||||
|
coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 2)
|
||||||
|
if 'jacobian' in kp_driving:
|
||||||
|
jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian']))
|
||||||
|
jacobian = jacobian.unsqueeze(-3).unsqueeze(-3)
|
||||||
|
jacobian = jacobian.repeat(1, 1, h, w, 1, 1)
|
||||||
|
coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1))
|
||||||
|
coordinate_grid = coordinate_grid.squeeze(-1)
|
||||||
|
|
||||||
|
driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 2)
|
||||||
|
|
||||||
|
#adding background feature
|
||||||
|
identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1)
|
||||||
|
sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1)
|
||||||
|
return sparse_motions
|
||||||
|
|
||||||
|
def create_deformed_source_image(self, source_image, sparse_motions):
|
||||||
|
"""
|
||||||
|
Eq 7. in the paper \hat{T}_{s<-d}(z)
|
||||||
|
"""
|
||||||
|
bs, _, h, w = source_image.shape
|
||||||
|
source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp + 1, 1, 1, 1, 1)
|
||||||
|
source_repeat = source_repeat.view(bs * (self.num_kp + 1), -1, h, w)
|
||||||
|
sparse_motions = sparse_motions.view((bs * (self.num_kp + 1), h, w, -1))
|
||||||
|
sparse_deformed = F.grid_sample(source_repeat, sparse_motions)
|
||||||
|
sparse_deformed = sparse_deformed.view((bs, self.num_kp + 1, -1, h, w))
|
||||||
|
return sparse_deformed
|
||||||
|
|
||||||
|
def forward(self, source_image, kp_driving, kp_source):
|
||||||
|
if self.scale_factor != 1:
|
||||||
|
source_image = self.down(source_image)
|
||||||
|
|
||||||
|
bs, _, h, w = source_image.shape
|
||||||
|
|
||||||
|
out_dict = dict()
|
||||||
|
heatmap_representation = self.create_heatmap_representations(source_image, kp_driving, kp_source)
|
||||||
|
sparse_motion = self.create_sparse_motions(source_image, kp_driving, kp_source)
|
||||||
|
deformed_source = self.create_deformed_source_image(source_image, sparse_motion)
|
||||||
|
out_dict['sparse_deformed'] = deformed_source
|
||||||
|
|
||||||
|
input = torch.cat([heatmap_representation, deformed_source], dim=2)
|
||||||
|
input = input.view(bs, -1, h, w)
|
||||||
|
|
||||||
|
prediction = self.hourglass(input)
|
||||||
|
|
||||||
|
mask = self.mask(prediction)
|
||||||
|
mask = F.softmax(mask, dim=1)
|
||||||
|
out_dict['mask'] = mask
|
||||||
|
mask = mask.unsqueeze(2)
|
||||||
|
sparse_motion = sparse_motion.permute(0, 1, 4, 2, 3)
|
||||||
|
deformation = (sparse_motion * mask).sum(dim=1)
|
||||||
|
deformation = deformation.permute(0, 2, 3, 1)
|
||||||
|
|
||||||
|
out_dict['deformation'] = deformation
|
||||||
|
|
||||||
|
# Sec. 3.2 in the paper
|
||||||
|
if self.occlusion:
|
||||||
|
occlusion_map = torch.sigmoid(self.occlusion(prediction))
|
||||||
|
out_dict['occlusion_map'] = occlusion_map
|
||||||
|
|
||||||
|
return out_dict
|
@ -0,0 +1,95 @@
|
|||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from modules.util import kp2gaussian
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class DownBlock2d(nn.Module):
|
||||||
|
"""
|
||||||
|
Simple block for processing video (encoder).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
|
||||||
|
super(DownBlock2d, self).__init__()
|
||||||
|
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)
|
||||||
|
|
||||||
|
if sn:
|
||||||
|
self.conv = nn.utils.spectral_norm(self.conv)
|
||||||
|
|
||||||
|
if norm:
|
||||||
|
self.norm = nn.InstanceNorm2d(out_features, affine=True)
|
||||||
|
else:
|
||||||
|
self.norm = None
|
||||||
|
self.pool = pool
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = x
|
||||||
|
out = self.conv(out)
|
||||||
|
if self.norm:
|
||||||
|
out = self.norm(out)
|
||||||
|
out = F.leaky_relu(out, 0.2)
|
||||||
|
if self.pool:
|
||||||
|
out = F.avg_pool2d(out, (2, 2))
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Discriminator(nn.Module):
|
||||||
|
"""
|
||||||
|
Discriminator similar to Pix2Pix
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,
|
||||||
|
sn=False, use_kp=False, num_kp=10, kp_variance=0.01, **kwargs):
|
||||||
|
super(Discriminator, self).__init__()
|
||||||
|
|
||||||
|
down_blocks = []
|
||||||
|
for i in range(num_blocks):
|
||||||
|
down_blocks.append(
|
||||||
|
DownBlock2d(num_channels + num_kp * use_kp if i == 0 else min(max_features, block_expansion * (2 ** i)),
|
||||||
|
min(max_features, block_expansion * (2 ** (i + 1))),
|
||||||
|
norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn))
|
||||||
|
|
||||||
|
self.down_blocks = nn.ModuleList(down_blocks)
|
||||||
|
self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)
|
||||||
|
if sn:
|
||||||
|
self.conv = nn.utils.spectral_norm(self.conv)
|
||||||
|
self.use_kp = use_kp
|
||||||
|
self.kp_variance = kp_variance
|
||||||
|
|
||||||
|
def forward(self, x, kp=None):
|
||||||
|
feature_maps = []
|
||||||
|
out = x
|
||||||
|
if self.use_kp:
|
||||||
|
heatmap = kp2gaussian(kp, x.shape[2:], self.kp_variance)
|
||||||
|
out = torch.cat([out, heatmap], dim=1)
|
||||||
|
|
||||||
|
for down_block in self.down_blocks:
|
||||||
|
feature_maps.append(down_block(out))
|
||||||
|
out = feature_maps[-1]
|
||||||
|
prediction_map = self.conv(out)
|
||||||
|
|
||||||
|
return feature_maps, prediction_map
|
||||||
|
|
||||||
|
|
||||||
|
class MultiScaleDiscriminator(nn.Module):
|
||||||
|
"""
|
||||||
|
Multi-scale (scale) discriminator
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, scales=(), **kwargs):
|
||||||
|
super(MultiScaleDiscriminator, self).__init__()
|
||||||
|
self.scales = scales
|
||||||
|
discs = {}
|
||||||
|
for scale in scales:
|
||||||
|
discs[str(scale).replace('.', '-')] = Discriminator(**kwargs)
|
||||||
|
self.discs = nn.ModuleDict(discs)
|
||||||
|
|
||||||
|
def forward(self, x, kp=None):
|
||||||
|
out_dict = {}
|
||||||
|
for scale, disc in self.discs.items():
|
||||||
|
scale = str(scale).replace('-', '.')
|
||||||
|
key = 'prediction_' + scale
|
||||||
|
feature_maps, prediction_map = disc(x[key], kp)
|
||||||
|
out_dict['feature_maps_' + scale] = feature_maps
|
||||||
|
out_dict['prediction_map_' + scale] = prediction_map
|
||||||
|
return out_dict
|
@ -0,0 +1,97 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d
|
||||||
|
from modules.dense_motion import DenseMotionNetwork
|
||||||
|
|
||||||
|
|
||||||
|
class OcclusionAwareGenerator(nn.Module):
|
||||||
|
"""
|
||||||
|
Generator that given source image and and keypoints try to transform image according to movement trajectories
|
||||||
|
induced by keypoints. Generator follows Johnson architecture.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, num_channels, num_kp, block_expansion, max_features, num_down_blocks,
|
||||||
|
num_bottleneck_blocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):
|
||||||
|
super(OcclusionAwareGenerator, self).__init__()
|
||||||
|
|
||||||
|
if dense_motion_params is not None:
|
||||||
|
self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, num_channels=num_channels,
|
||||||
|
estimate_occlusion_map=estimate_occlusion_map,
|
||||||
|
**dense_motion_params)
|
||||||
|
else:
|
||||||
|
self.dense_motion_network = None
|
||||||
|
|
||||||
|
self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3))
|
||||||
|
|
||||||
|
down_blocks = []
|
||||||
|
for i in range(num_down_blocks):
|
||||||
|
in_features = min(max_features, block_expansion * (2 ** i))
|
||||||
|
out_features = min(max_features, block_expansion * (2 ** (i + 1)))
|
||||||
|
down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
|
||||||
|
self.down_blocks = nn.ModuleList(down_blocks)
|
||||||
|
|
||||||
|
up_blocks = []
|
||||||
|
for i in range(num_down_blocks):
|
||||||
|
in_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i)))
|
||||||
|
out_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i - 1)))
|
||||||
|
up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
|
||||||
|
self.up_blocks = nn.ModuleList(up_blocks)
|
||||||
|
|
||||||
|
self.bottleneck = torch.nn.Sequential()
|
||||||
|
in_features = min(max_features, block_expansion * (2 ** num_down_blocks))
|
||||||
|
for i in range(num_bottleneck_blocks):
|
||||||
|
self.bottleneck.add_module('r' + str(i), ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1)))
|
||||||
|
|
||||||
|
self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3))
|
||||||
|
self.estimate_occlusion_map = estimate_occlusion_map
|
||||||
|
self.num_channels = num_channels
|
||||||
|
|
||||||
|
def deform_input(self, inp, deformation):
|
||||||
|
_, h_old, w_old, _ = deformation.shape
|
||||||
|
_, _, h, w = inp.shape
|
||||||
|
if h_old != h or w_old != w:
|
||||||
|
deformation = deformation.permute(0, 3, 1, 2)
|
||||||
|
deformation = F.interpolate(deformation, size=(h, w), mode='bilinear')
|
||||||
|
deformation = deformation.permute(0, 2, 3, 1)
|
||||||
|
return F.grid_sample(inp, deformation)
|
||||||
|
|
||||||
|
def forward(self, source_image, kp_driving, kp_source):
|
||||||
|
# Encoding (downsampling) part
|
||||||
|
out = self.first(source_image)
|
||||||
|
for i in range(len(self.down_blocks)):
|
||||||
|
out = self.down_blocks[i](out)
|
||||||
|
|
||||||
|
# Transforming feature representation according to deformation and occlusion
|
||||||
|
output_dict = {}
|
||||||
|
if self.dense_motion_network is not None:
|
||||||
|
dense_motion = self.dense_motion_network(source_image=source_image, kp_driving=kp_driving,
|
||||||
|
kp_source=kp_source)
|
||||||
|
output_dict['mask'] = dense_motion['mask']
|
||||||
|
output_dict['sparse_deformed'] = dense_motion['sparse_deformed']
|
||||||
|
|
||||||
|
if 'occlusion_map' in dense_motion:
|
||||||
|
occlusion_map = dense_motion['occlusion_map']
|
||||||
|
output_dict['occlusion_map'] = occlusion_map
|
||||||
|
else:
|
||||||
|
occlusion_map = None
|
||||||
|
deformation = dense_motion['deformation']
|
||||||
|
out = self.deform_input(out, deformation)
|
||||||
|
|
||||||
|
if occlusion_map is not None:
|
||||||
|
if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
|
||||||
|
occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
|
||||||
|
out = out * occlusion_map
|
||||||
|
|
||||||
|
output_dict["deformed"] = self.deform_input(source_image, deformation)
|
||||||
|
|
||||||
|
# Decoding part
|
||||||
|
out = self.bottleneck(out)
|
||||||
|
for i in range(len(self.up_blocks)):
|
||||||
|
out = self.up_blocks[i](out)
|
||||||
|
out = self.final(out)
|
||||||
|
out = F.sigmoid(out)
|
||||||
|
|
||||||
|
output_dict["prediction"] = out
|
||||||
|
|
||||||
|
return output_dict
|
@ -0,0 +1,75 @@
|
|||||||
|
from torch import nn
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from modules.util import Hourglass, make_coordinate_grid, AntiAliasInterpolation2d
|
||||||
|
|
||||||
|
|
||||||
|
class KPDetector(nn.Module):
|
||||||
|
"""
|
||||||
|
Detecting a keypoints. Return keypoint position and jacobian near each keypoint.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, block_expansion, num_kp, num_channels, max_features,
|
||||||
|
num_blocks, temperature, estimate_jacobian=False, scale_factor=1,
|
||||||
|
single_jacobian_map=False, pad=0):
|
||||||
|
super(KPDetector, self).__init__()
|
||||||
|
|
||||||
|
self.predictor = Hourglass(block_expansion, in_features=num_channels,
|
||||||
|
max_features=max_features, num_blocks=num_blocks)
|
||||||
|
|
||||||
|
self.kp = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=(7, 7),
|
||||||
|
padding=pad)
|
||||||
|
|
||||||
|
if estimate_jacobian:
|
||||||
|
self.num_jacobian_maps = 1 if single_jacobian_map else num_kp
|
||||||
|
self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters,
|
||||||
|
out_channels=4 * self.num_jacobian_maps, kernel_size=(7, 7), padding=pad)
|
||||||
|
self.jacobian.weight.data.zero_()
|
||||||
|
self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float))
|
||||||
|
else:
|
||||||
|
self.jacobian = None
|
||||||
|
|
||||||
|
self.temperature = temperature
|
||||||
|
self.scale_factor = scale_factor
|
||||||
|
if self.scale_factor != 1:
|
||||||
|
self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)
|
||||||
|
|
||||||
|
def gaussian2kp(self, heatmap):
|
||||||
|
"""
|
||||||
|
Extract the mean and from a heatmap
|
||||||
|
"""
|
||||||
|
shape = heatmap.shape
|
||||||
|
heatmap = heatmap.unsqueeze(-1)
|
||||||
|
grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0)
|
||||||
|
value = (heatmap * grid).sum(dim=(2, 3))
|
||||||
|
kp = {'value': value}
|
||||||
|
|
||||||
|
return kp
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.scale_factor != 1:
|
||||||
|
x = self.down(x)
|
||||||
|
|
||||||
|
feature_map = self.predictor(x)
|
||||||
|
prediction = self.kp(feature_map)
|
||||||
|
|
||||||
|
final_shape = prediction.shape
|
||||||
|
heatmap = prediction.view(final_shape[0], final_shape[1], -1)
|
||||||
|
heatmap = F.softmax(heatmap / self.temperature, dim=2)
|
||||||
|
heatmap = heatmap.view(*final_shape)
|
||||||
|
|
||||||
|
out = self.gaussian2kp(heatmap)
|
||||||
|
|
||||||
|
if self.jacobian is not None:
|
||||||
|
jacobian_map = self.jacobian(feature_map)
|
||||||
|
jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2],
|
||||||
|
final_shape[3])
|
||||||
|
heatmap = heatmap.unsqueeze(2)
|
||||||
|
|
||||||
|
jacobian = heatmap * jacobian_map
|
||||||
|
jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1)
|
||||||
|
jacobian = jacobian.sum(dim=-1)
|
||||||
|
jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2)
|
||||||
|
out['jacobian'] = jacobian
|
||||||
|
|
||||||
|
return out
|
@ -0,0 +1,259 @@
|
|||||||
|
from torch import nn
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from modules.util import AntiAliasInterpolation2d, make_coordinate_grid
|
||||||
|
from torchvision import models
|
||||||
|
import numpy as np
|
||||||
|
from torch.autograd import grad
|
||||||
|
|
||||||
|
|
||||||
|
class Vgg19(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Vgg19 network for perceptual loss. See Sec 3.3.
|
||||||
|
"""
|
||||||
|
def __init__(self, requires_grad=False):
|
||||||
|
super(Vgg19, self).__init__()
|
||||||
|
vgg_pretrained_features = models.vgg19(pretrained=True).features
|
||||||
|
self.slice1 = torch.nn.Sequential()
|
||||||
|
self.slice2 = torch.nn.Sequential()
|
||||||
|
self.slice3 = torch.nn.Sequential()
|
||||||
|
self.slice4 = torch.nn.Sequential()
|
||||||
|
self.slice5 = torch.nn.Sequential()
|
||||||
|
for x in range(2):
|
||||||
|
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
||||||
|
for x in range(2, 7):
|
||||||
|
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
||||||
|
for x in range(7, 12):
|
||||||
|
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
||||||
|
for x in range(12, 21):
|
||||||
|
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
||||||
|
for x in range(21, 30):
|
||||||
|
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
||||||
|
|
||||||
|
self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
|
||||||
|
requires_grad=False)
|
||||||
|
self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
|
||||||
|
requires_grad=False)
|
||||||
|
|
||||||
|
if not requires_grad:
|
||||||
|
for param in self.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
def forward(self, X):
|
||||||
|
X = (X - self.mean) / self.std
|
||||||
|
h_relu1 = self.slice1(X)
|
||||||
|
h_relu2 = self.slice2(h_relu1)
|
||||||
|
h_relu3 = self.slice3(h_relu2)
|
||||||
|
h_relu4 = self.slice4(h_relu3)
|
||||||
|
h_relu5 = self.slice5(h_relu4)
|
||||||
|
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class ImagePyramide(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Create image pyramide for computing pyramide perceptual loss. See Sec 3.3
|
||||||
|
"""
|
||||||
|
def __init__(self, scales, num_channels):
|
||||||
|
super(ImagePyramide, self).__init__()
|
||||||
|
downs = {}
|
||||||
|
for scale in scales:
|
||||||
|
downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
|
||||||
|
self.downs = nn.ModuleDict(downs)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out_dict = {}
|
||||||
|
for scale, down_module in self.downs.items():
|
||||||
|
out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
|
||||||
|
return out_dict
|
||||||
|
|
||||||
|
|
||||||
|
class Transform:
|
||||||
|
"""
|
||||||
|
Random tps transformation for equivariance constraints. See Sec 3.3
|
||||||
|
"""
|
||||||
|
def __init__(self, bs, **kwargs):
|
||||||
|
noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3]))
|
||||||
|
self.theta = noise + torch.eye(2, 3).view(1, 2, 3)
|
||||||
|
self.bs = bs
|
||||||
|
|
||||||
|
if ('sigma_tps' in kwargs) and ('points_tps' in kwargs):
|
||||||
|
self.tps = True
|
||||||
|
self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type())
|
||||||
|
self.control_points = self.control_points.unsqueeze(0)
|
||||||
|
self.control_params = torch.normal(mean=0,
|
||||||
|
std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2]))
|
||||||
|
else:
|
||||||
|
self.tps = False
|
||||||
|
|
||||||
|
def transform_frame(self, frame):
|
||||||
|
grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0)
|
||||||
|
grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
|
||||||
|
grid = self.warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2)
|
||||||
|
return F.grid_sample(frame, grid, padding_mode="reflection")
|
||||||
|
|
||||||
|
def warp_coordinates(self, coordinates):
|
||||||
|
theta = self.theta.type(coordinates.type())
|
||||||
|
theta = theta.unsqueeze(1)
|
||||||
|
transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]
|
||||||
|
transformed = transformed.squeeze(-1)
|
||||||
|
|
||||||
|
if self.tps:
|
||||||
|
control_points = self.control_points.type(coordinates.type())
|
||||||
|
control_params = self.control_params.type(coordinates.type())
|
||||||
|
distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)
|
||||||
|
distances = torch.abs(distances).sum(-1)
|
||||||
|
|
||||||
|
result = distances ** 2
|
||||||
|
result = result * torch.log(distances + 1e-6)
|
||||||
|
result = result * control_params
|
||||||
|
result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)
|
||||||
|
transformed = transformed + result
|
||||||
|
|
||||||
|
return transformed
|
||||||
|
|
||||||
|
def jacobian(self, coordinates):
|
||||||
|
new_coordinates = self.warp_coordinates(coordinates)
|
||||||
|
grad_x = grad(new_coordinates[..., 0].sum(), coordinates, create_graph=True)
|
||||||
|
grad_y = grad(new_coordinates[..., 1].sum(), coordinates, create_graph=True)
|
||||||
|
jacobian = torch.cat([grad_x[0].unsqueeze(-2), grad_y[0].unsqueeze(-2)], dim=-2)
|
||||||
|
return jacobian
|
||||||
|
|
||||||
|
|
||||||
|
def detach_kp(kp):
|
||||||
|
return {key: value.detach() for key, value in kp.items()}
|
||||||
|
|
||||||
|
|
||||||
|
class GeneratorFullModel(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Merge all generator related updates into single model for better multi-gpu usage
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, kp_extractor, generator, discriminator, train_params):
|
||||||
|
super(GeneratorFullModel, self).__init__()
|
||||||
|
self.kp_extractor = kp_extractor
|
||||||
|
self.generator = generator
|
||||||
|
self.discriminator = discriminator
|
||||||
|
self.train_params = train_params
|
||||||
|
self.scales = train_params['scales']
|
||||||
|
self.disc_scales = self.discriminator.scales
|
||||||
|
self.pyramid = ImagePyramide(self.scales, generator.num_channels)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
self.pyramid = self.pyramid.cuda()
|
||||||
|
|
||||||
|
self.loss_weights = train_params['loss_weights']
|
||||||
|
|
||||||
|
if sum(self.loss_weights['perceptual']) != 0:
|
||||||
|
self.vgg = Vgg19()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
self.vgg = self.vgg.cuda()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
kp_source = self.kp_extractor(x['source'])
|
||||||
|
kp_driving = self.kp_extractor(x['driving'])
|
||||||
|
|
||||||
|
generated = self.generator(x['source'], kp_source=kp_source, kp_driving=kp_driving)
|
||||||
|
generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})
|
||||||
|
|
||||||
|
loss_values = {}
|
||||||
|
|
||||||
|
pyramide_real = self.pyramid(x['driving'])
|
||||||
|
pyramide_generated = self.pyramid(generated['prediction'])
|
||||||
|
|
||||||
|
if sum(self.loss_weights['perceptual']) != 0:
|
||||||
|
value_total = 0
|
||||||
|
for scale in self.scales:
|
||||||
|
x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
|
||||||
|
y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])
|
||||||
|
|
||||||
|
for i, weight in enumerate(self.loss_weights['perceptual']):
|
||||||
|
value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
|
||||||
|
value_total += self.loss_weights['perceptual'][i] * value
|
||||||
|
loss_values['perceptual'] = value_total
|
||||||
|
|
||||||
|
if self.loss_weights['generator_gan'] != 0:
|
||||||
|
discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))
|
||||||
|
discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))
|
||||||
|
value_total = 0
|
||||||
|
for scale in self.disc_scales:
|
||||||
|
key = 'prediction_map_%s' % scale
|
||||||
|
value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
|
||||||
|
value_total += self.loss_weights['generator_gan'] * value
|
||||||
|
loss_values['gen_gan'] = value_total
|
||||||
|
|
||||||
|
if sum(self.loss_weights['feature_matching']) != 0:
|
||||||
|
value_total = 0
|
||||||
|
for scale in self.disc_scales:
|
||||||
|
key = 'feature_maps_%s' % scale
|
||||||
|
for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])):
|
||||||
|
if self.loss_weights['feature_matching'][i] == 0:
|
||||||
|
continue
|
||||||
|
value = torch.abs(a - b).mean()
|
||||||
|
value_total += self.loss_weights['feature_matching'][i] * value
|
||||||
|
loss_values['feature_matching'] = value_total
|
||||||
|
|
||||||
|
if (self.loss_weights['equivariance_value'] + self.loss_weights['equivariance_jacobian']) != 0:
|
||||||
|
transform = Transform(x['driving'].shape[0], **self.train_params['transform_params'])
|
||||||
|
transformed_frame = transform.transform_frame(x['driving'])
|
||||||
|
transformed_kp = self.kp_extractor(transformed_frame)
|
||||||
|
|
||||||
|
generated['transformed_frame'] = transformed_frame
|
||||||
|
generated['transformed_kp'] = transformed_kp
|
||||||
|
|
||||||
|
## Value loss part
|
||||||
|
if self.loss_weights['equivariance_value'] != 0:
|
||||||
|
value = torch.abs(kp_driving['value'] - transform.warp_coordinates(transformed_kp['value'])).mean()
|
||||||
|
loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value
|
||||||
|
|
||||||
|
## jacobian loss part
|
||||||
|
if self.loss_weights['equivariance_jacobian'] != 0:
|
||||||
|
jacobian_transformed = torch.matmul(transform.jacobian(transformed_kp['value']),
|
||||||
|
transformed_kp['jacobian'])
|
||||||
|
|
||||||
|
normed_driving = torch.inverse(kp_driving['jacobian'])
|
||||||
|
normed_transformed = jacobian_transformed
|
||||||
|
value = torch.matmul(normed_driving, normed_transformed)
|
||||||
|
|
||||||
|
eye = torch.eye(2).view(1, 1, 2, 2).type(value.type())
|
||||||
|
|
||||||
|
value = torch.abs(eye - value).mean()
|
||||||
|
loss_values['equivariance_jacobian'] = self.loss_weights['equivariance_jacobian'] * value
|
||||||
|
|
||||||
|
return loss_values, generated
|
||||||
|
|
||||||
|
|
||||||
|
class DiscriminatorFullModel(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Merge all discriminator related updates into single model for better multi-gpu usage
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, kp_extractor, generator, discriminator, train_params):
|
||||||
|
super(DiscriminatorFullModel, self).__init__()
|
||||||
|
self.kp_extractor = kp_extractor
|
||||||
|
self.generator = generator
|
||||||
|
self.discriminator = discriminator
|
||||||
|
self.train_params = train_params
|
||||||
|
self.scales = self.discriminator.scales
|
||||||
|
self.pyramid = ImagePyramide(self.scales, generator.num_channels)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
self.pyramid = self.pyramid.cuda()
|
||||||
|
|
||||||
|
self.loss_weights = train_params['loss_weights']
|
||||||
|
|
||||||
|
def forward(self, x, generated):
|
||||||
|
pyramide_real = self.pyramid(x['driving'])
|
||||||
|
pyramide_generated = self.pyramid(generated['prediction'].detach())
|
||||||
|
|
||||||
|
kp_driving = generated['kp_driving']
|
||||||
|
discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))
|
||||||
|
discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))
|
||||||
|
|
||||||
|
loss_values = {}
|
||||||
|
value_total = 0
|
||||||
|
for scale in self.scales:
|
||||||
|
key = 'prediction_map_%s' % scale
|
||||||
|
value = (1 - discriminator_maps_real[key]) ** 2 + discriminator_maps_generated[key] ** 2
|
||||||
|
value_total += self.loss_weights['discriminator_gan'] * value.mean()
|
||||||
|
loss_values['disc_gan'] = value_total
|
||||||
|
|
||||||
|
return loss_values
|
@ -0,0 +1,245 @@
|
|||||||
|
from torch import nn
|
||||||
|
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
|
||||||
|
|
||||||
|
|
||||||
|
def kp2gaussian(kp, spatial_size, kp_variance):
|
||||||
|
"""
|
||||||
|
Transform a keypoint into gaussian like representation
|
||||||
|
"""
|
||||||
|
mean = kp['value']
|
||||||
|
|
||||||
|
coordinate_grid = make_coordinate_grid(spatial_size, mean.type())
|
||||||
|
number_of_leading_dimensions = len(mean.shape) - 1
|
||||||
|
shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
|
||||||
|
coordinate_grid = coordinate_grid.view(*shape)
|
||||||
|
repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1)
|
||||||
|
coordinate_grid = coordinate_grid.repeat(*repeats)
|
||||||
|
|
||||||
|
# Preprocess kp shape
|
||||||
|
shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 2)
|
||||||
|
mean = mean.view(*shape)
|
||||||
|
|
||||||
|
mean_sub = (coordinate_grid - mean)
|
||||||
|
|
||||||
|
out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def make_coordinate_grid(spatial_size, type):
|
||||||
|
"""
|
||||||
|
Create a meshgrid [-1,1] x [-1,1] of given spatial_size.
|
||||||
|
"""
|
||||||
|
h, w = spatial_size
|
||||||
|
x = torch.arange(w).type(type)
|
||||||
|
y = torch.arange(h).type(type)
|
||||||
|
|
||||||
|
x = (2 * (x / (w - 1)) - 1)
|
||||||
|
y = (2 * (y / (h - 1)) - 1)
|
||||||
|
|
||||||
|
yy = y.view(-1, 1).repeat(1, w)
|
||||||
|
xx = x.view(1, -1).repeat(h, 1)
|
||||||
|
|
||||||
|
meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
|
||||||
|
|
||||||
|
return meshed
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock2d(nn.Module):
|
||||||
|
"""
|
||||||
|
Res block, preserve spatial resolution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_features, kernel_size, padding):
|
||||||
|
super(ResBlock2d, self).__init__()
|
||||||
|
self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
|
||||||
|
padding=padding)
|
||||||
|
self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
|
||||||
|
padding=padding)
|
||||||
|
self.norm1 = BatchNorm2d(in_features, affine=True)
|
||||||
|
self.norm2 = BatchNorm2d(in_features, affine=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = self.norm1(x)
|
||||||
|
out = F.relu(out)
|
||||||
|
out = self.conv1(out)
|
||||||
|
out = self.norm2(out)
|
||||||
|
out = F.relu(out)
|
||||||
|
out = self.conv2(out)
|
||||||
|
out += x
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class UpBlock2d(nn.Module):
|
||||||
|
"""
|
||||||
|
Upsampling block for use in decoder.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
|
||||||
|
super(UpBlock2d, self).__init__()
|
||||||
|
|
||||||
|
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
|
||||||
|
padding=padding, groups=groups)
|
||||||
|
self.norm = BatchNorm2d(out_features, affine=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = F.interpolate(x, scale_factor=2)
|
||||||
|
out = self.conv(out)
|
||||||
|
out = self.norm(out)
|
||||||
|
out = F.relu(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class DownBlock2d(nn.Module):
|
||||||
|
"""
|
||||||
|
Downsampling block for use in encoder.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
|
||||||
|
super(DownBlock2d, self).__init__()
|
||||||
|
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
|
||||||
|
padding=padding, groups=groups)
|
||||||
|
self.norm = BatchNorm2d(out_features, affine=True)
|
||||||
|
self.pool = nn.AvgPool2d(kernel_size=(2, 2))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = self.conv(x)
|
||||||
|
out = self.norm(out)
|
||||||
|
out = F.relu(out)
|
||||||
|
out = self.pool(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class SameBlock2d(nn.Module):
|
||||||
|
"""
|
||||||
|
Simple block, preserve spatial resolution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1):
|
||||||
|
super(SameBlock2d, self).__init__()
|
||||||
|
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,
|
||||||
|
kernel_size=kernel_size, padding=padding, groups=groups)
|
||||||
|
self.norm = BatchNorm2d(out_features, affine=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = self.conv(x)
|
||||||
|
out = self.norm(out)
|
||||||
|
out = F.relu(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder(nn.Module):
|
||||||
|
"""
|
||||||
|
Hourglass Encoder
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
|
||||||
|
super(Encoder, self).__init__()
|
||||||
|
|
||||||
|
down_blocks = []
|
||||||
|
for i in range(num_blocks):
|
||||||
|
down_blocks.append(DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
|
||||||
|
min(max_features, block_expansion * (2 ** (i + 1))),
|
||||||
|
kernel_size=3, padding=1))
|
||||||
|
self.down_blocks = nn.ModuleList(down_blocks)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
outs = [x]
|
||||||
|
for down_block in self.down_blocks:
|
||||||
|
outs.append(down_block(outs[-1]))
|
||||||
|
return outs
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder(nn.Module):
|
||||||
|
"""
|
||||||
|
Hourglass Decoder
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
|
||||||
|
super(Decoder, self).__init__()
|
||||||
|
|
||||||
|
up_blocks = []
|
||||||
|
|
||||||
|
for i in range(num_blocks)[::-1]:
|
||||||
|
in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
|
||||||
|
out_filters = min(max_features, block_expansion * (2 ** i))
|
||||||
|
up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1))
|
||||||
|
|
||||||
|
self.up_blocks = nn.ModuleList(up_blocks)
|
||||||
|
self.out_filters = block_expansion + in_features
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = x.pop()
|
||||||
|
for up_block in self.up_blocks:
|
||||||
|
out = up_block(out)
|
||||||
|
skip = x.pop()
|
||||||
|
out = torch.cat([out, skip], dim=1)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Hourglass(nn.Module):
|
||||||
|
"""
|
||||||
|
Hourglass architecture.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
|
||||||
|
super(Hourglass, self).__init__()
|
||||||
|
self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
|
||||||
|
self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
|
||||||
|
self.out_filters = self.decoder.out_filters
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.decoder(self.encoder(x))
|
||||||
|
|
||||||
|
|
||||||
|
class AntiAliasInterpolation2d(nn.Module):
|
||||||
|
"""
|
||||||
|
Band-limited downsampling, for better preservation of the input signal.
|
||||||
|
"""
|
||||||
|
def __init__(self, channels, scale):
|
||||||
|
super(AntiAliasInterpolation2d, self).__init__()
|
||||||
|
sigma = (1 / scale - 1) / 2
|
||||||
|
kernel_size = 2 * round(sigma * 4) + 1
|
||||||
|
self.ka = kernel_size // 2
|
||||||
|
self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
|
||||||
|
|
||||||
|
kernel_size = [kernel_size, kernel_size]
|
||||||
|
sigma = [sigma, sigma]
|
||||||
|
# The gaussian kernel is the product of the
|
||||||
|
# gaussian function of each dimension.
|
||||||
|
kernel = 1
|
||||||
|
meshgrids = torch.meshgrid(
|
||||||
|
[
|
||||||
|
torch.arange(size, dtype=torch.float32)
|
||||||
|
for size in kernel_size
|
||||||
|
]
|
||||||
|
)
|
||||||
|
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
|
||||||
|
mean = (size - 1) / 2
|
||||||
|
kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
|
||||||
|
|
||||||
|
# Make sure sum of values in gaussian kernel equals 1.
|
||||||
|
kernel = kernel / torch.sum(kernel)
|
||||||
|
# Reshape to depthwise convolutional weight
|
||||||
|
kernel = kernel.view(1, 1, *kernel.size())
|
||||||
|
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
|
||||||
|
|
||||||
|
self.register_buffer('weight', kernel)
|
||||||
|
self.groups = channels
|
||||||
|
self.scale = scale
|
||||||
|
inv_scale = 1 / scale
|
||||||
|
self.int_inv_scale = int(inv_scale)
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
if self.scale == 1.0:
|
||||||
|
return input
|
||||||
|
|
||||||
|
out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
|
||||||
|
out = F.conv2d(out, weight=self.weight, groups=self.groups)
|
||||||
|
out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]
|
||||||
|
|
||||||
|
return out
|
@ -0,0 +1,67 @@
|
|||||||
|
import os
|
||||||
|
from tqdm import tqdm
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from logger import Logger, Visualizer
|
||||||
|
import numpy as np
|
||||||
|
import imageio
|
||||||
|
from sync_batchnorm import DataParallelWithCallback
|
||||||
|
|
||||||
|
|
||||||
|
def reconstruction(config, generator, kp_detector, checkpoint, log_dir, dataset):
|
||||||
|
png_dir = os.path.join(log_dir, 'reconstruction/png')
|
||||||
|
log_dir = os.path.join(log_dir, 'reconstruction')
|
||||||
|
|
||||||
|
if checkpoint is not None:
|
||||||
|
Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)
|
||||||
|
else:
|
||||||
|
raise AttributeError("Checkpoint should be specified for mode='reconstruction'.")
|
||||||
|
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
|
||||||
|
|
||||||
|
if not os.path.exists(log_dir):
|
||||||
|
os.makedirs(log_dir)
|
||||||
|
|
||||||
|
if not os.path.exists(png_dir):
|
||||||
|
os.makedirs(png_dir)
|
||||||
|
|
||||||
|
loss_list = []
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
generator = DataParallelWithCallback(generator)
|
||||||
|
kp_detector = DataParallelWithCallback(kp_detector)
|
||||||
|
|
||||||
|
generator.eval()
|
||||||
|
kp_detector.eval()
|
||||||
|
|
||||||
|
for it, x in tqdm(enumerate(dataloader)):
|
||||||
|
if config['reconstruction_params']['num_videos'] is not None:
|
||||||
|
if it > config['reconstruction_params']['num_videos']:
|
||||||
|
break
|
||||||
|
with torch.no_grad():
|
||||||
|
predictions = []
|
||||||
|
visualizations = []
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
x['video'] = x['video'].cuda()
|
||||||
|
kp_source = kp_detector(x['video'][:, :, 0])
|
||||||
|
for frame_idx in range(x['video'].shape[2]):
|
||||||
|
source = x['video'][:, :, 0]
|
||||||
|
driving = x['video'][:, :, frame_idx]
|
||||||
|
kp_driving = kp_detector(driving)
|
||||||
|
out = generator(source, kp_source=kp_source, kp_driving=kp_driving)
|
||||||
|
out['kp_source'] = kp_source
|
||||||
|
out['kp_driving'] = kp_driving
|
||||||
|
del out['sparse_deformed']
|
||||||
|
predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
|
||||||
|
|
||||||
|
visualization = Visualizer(**config['visualizer_params']).visualize(source=source,
|
||||||
|
driving=driving, out=out)
|
||||||
|
visualizations.append(visualization)
|
||||||
|
|
||||||
|
loss_list.append(torch.abs(out['prediction'] - driving).mean().cpu().numpy())
|
||||||
|
|
||||||
|
predictions = np.concatenate(predictions, axis=1)
|
||||||
|
imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'), (255 * predictions).astype(np.uint8))
|
||||||
|
|
||||||
|
image_name = x['name'][0] + config['reconstruction_params']['format']
|
||||||
|
imageio.mimsave(os.path.join(log_dir, image_name), visualizations)
|
||||||
|
|
||||||
|
print("Reconstruction loss: %s" % np.mean(loss_list))
|
@ -0,0 +1,15 @@
|
|||||||
|
ffmpeg-python==0.2.0
|
||||||
|
imageio==2.22.0
|
||||||
|
imageio-ffmpeg==0.4.7
|
||||||
|
matplotlib==3.6.0
|
||||||
|
numpy==1.23.3
|
||||||
|
pandas==1.5.0
|
||||||
|
python-dateutil==2.8.2
|
||||||
|
pytz==2022.2.1
|
||||||
|
PyYAML==6.0
|
||||||
|
scikit-image==0.19.3
|
||||||
|
scikit-learn==1.1.2
|
||||||
|
scipy==1.9.1
|
||||||
|
torch==1.12.1
|
||||||
|
torchvision==0.13.1
|
||||||
|
tqdm==4.64.1
|
@ -0,0 +1,87 @@
|
|||||||
|
import matplotlib
|
||||||
|
|
||||||
|
matplotlib.use('Agg')
|
||||||
|
|
||||||
|
import os, sys
|
||||||
|
import yaml
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
from time import gmtime, strftime
|
||||||
|
from shutil import copy
|
||||||
|
|
||||||
|
from frames_dataset import FramesDataset
|
||||||
|
|
||||||
|
from modules.generator import OcclusionAwareGenerator
|
||||||
|
from modules.discriminator import MultiScaleDiscriminator
|
||||||
|
from modules.keypoint_detector import KPDetector
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from train import train
|
||||||
|
from reconstruction import reconstruction
|
||||||
|
from animate import animate
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
if sys.version_info[0] < 3:
|
||||||
|
raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")
|
||||||
|
|
||||||
|
parser = ArgumentParser()
|
||||||
|
parser.add_argument("--config", required=True, help="path to config")
|
||||||
|
parser.add_argument("--mode", default="train", choices=["train", "reconstruction", "animate"])
|
||||||
|
parser.add_argument("--log_dir", default='log', help="path to log into")
|
||||||
|
parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore")
|
||||||
|
parser.add_argument("--device_ids", default="0", type=lambda x: list(map(int, x.split(','))),
|
||||||
|
help="Names of the devices comma separated.")
|
||||||
|
parser.add_argument("--verbose", dest="verbose", action="store_true", help="Print model architecture")
|
||||||
|
parser.set_defaults(verbose=False)
|
||||||
|
|
||||||
|
opt = parser.parse_args()
|
||||||
|
with open(opt.config) as f:
|
||||||
|
config = yaml.load(f)
|
||||||
|
|
||||||
|
if opt.checkpoint is not None:
|
||||||
|
log_dir = os.path.join(*os.path.split(opt.checkpoint)[:-1])
|
||||||
|
else:
|
||||||
|
log_dir = os.path.join(opt.log_dir, os.path.basename(opt.config).split('.')[0])
|
||||||
|
log_dir += ' ' + strftime("%d_%m_%y_%H.%M.%S", gmtime())
|
||||||
|
|
||||||
|
generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
|
||||||
|
**config['model_params']['common_params'])
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
generator.to(opt.device_ids[0])
|
||||||
|
if opt.verbose:
|
||||||
|
print(generator)
|
||||||
|
|
||||||
|
discriminator = MultiScaleDiscriminator(**config['model_params']['discriminator_params'],
|
||||||
|
**config['model_params']['common_params'])
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
discriminator.to(opt.device_ids[0])
|
||||||
|
if opt.verbose:
|
||||||
|
print(discriminator)
|
||||||
|
|
||||||
|
kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
|
||||||
|
**config['model_params']['common_params'])
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
kp_detector.to(opt.device_ids[0])
|
||||||
|
|
||||||
|
if opt.verbose:
|
||||||
|
print(kp_detector)
|
||||||
|
|
||||||
|
dataset = FramesDataset(is_train=(opt.mode == 'train'), **config['dataset_params'])
|
||||||
|
|
||||||
|
if not os.path.exists(log_dir):
|
||||||
|
os.makedirs(log_dir)
|
||||||
|
if not os.path.exists(os.path.join(log_dir, os.path.basename(opt.config))):
|
||||||
|
copy(opt.config, log_dir)
|
||||||
|
|
||||||
|
if opt.mode == 'train':
|
||||||
|
print("Training...")
|
||||||
|
train(config, generator, discriminator, kp_detector, opt.checkpoint, log_dir, dataset, opt.device_ids)
|
||||||
|
elif opt.mode == 'reconstruction':
|
||||||
|
print("Reconstruction...")
|
||||||
|
reconstruction(config, generator, kp_detector, opt.checkpoint, log_dir, dataset)
|
||||||
|
elif opt.mode == 'animate':
|
||||||
|
print("Animate...")
|
||||||
|
animate(config, generator, kp_detector, opt.checkpoint, log_dir, dataset)
|
After Width: | Height: | Size: 5.3 MiB |
After Width: | Height: | Size: 14 MiB |
After Width: | Height: | Size: 11 MiB |
After Width: | Height: | Size: 3.0 MiB |
After Width: | Height: | Size: 5.3 MiB |
After Width: | Height: | Size: 38 MiB |
@ -0,0 +1,12 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# File : __init__.py
|
||||||
|
# Author : Jiayuan Mao
|
||||||
|
# Email : maojiayuan@gmail.com
|
||||||
|
# Date : 27/01/2018
|
||||||
|
#
|
||||||
|
# This file is part of Synchronized-BatchNorm-PyTorch.
|
||||||
|
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
||||||
|
# Distributed under MIT License.
|
||||||
|
|
||||||
|
from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
|
||||||
|
from .replicate import DataParallelWithCallback, patch_replication_callback
|
@ -0,0 +1,315 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# File : batchnorm.py
|
||||||
|
# Author : Jiayuan Mao
|
||||||
|
# Email : maojiayuan@gmail.com
|
||||||
|
# Date : 27/01/2018
|
||||||
|
#
|
||||||
|
# This file is part of Synchronized-BatchNorm-PyTorch.
|
||||||
|
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
||||||
|
# Distributed under MIT License.
|
||||||
|
|
||||||
|
import collections
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from torch.nn.modules.batchnorm import _BatchNorm
|
||||||
|
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
|
||||||
|
|
||||||
|
from .comm import SyncMaster
|
||||||
|
|
||||||
|
__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
|
||||||
|
|
||||||
|
|
||||||
|
def _sum_ft(tensor):
|
||||||
|
"""sum over the first and last dimention"""
|
||||||
|
return tensor.sum(dim=0).sum(dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def _unsqueeze_ft(tensor):
|
||||||
|
"""add new dementions at the front and the tail"""
|
||||||
|
return tensor.unsqueeze(0).unsqueeze(-1)
|
||||||
|
|
||||||
|
|
||||||
|
_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
|
||||||
|
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
|
||||||
|
|
||||||
|
|
||||||
|
class _SynchronizedBatchNorm(_BatchNorm):
|
||||||
|
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
|
||||||
|
super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
|
||||||
|
|
||||||
|
self._sync_master = SyncMaster(self._data_parallel_master)
|
||||||
|
|
||||||
|
self._is_parallel = False
|
||||||
|
self._parallel_id = None
|
||||||
|
self._slave_pipe = None
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
# If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
|
||||||
|
if not (self._is_parallel and self.training):
|
||||||
|
return F.batch_norm(
|
||||||
|
input, self.running_mean, self.running_var, self.weight, self.bias,
|
||||||
|
self.training, self.momentum, self.eps)
|
||||||
|
|
||||||
|
# Resize the input to (B, C, -1).
|
||||||
|
input_shape = input.size()
|
||||||
|
input = input.view(input.size(0), self.num_features, -1)
|
||||||
|
|
||||||
|
# Compute the sum and square-sum.
|
||||||
|
sum_size = input.size(0) * input.size(2)
|
||||||
|
input_sum = _sum_ft(input)
|
||||||
|
input_ssum = _sum_ft(input ** 2)
|
||||||
|
|
||||||
|
# Reduce-and-broadcast the statistics.
|
||||||
|
if self._parallel_id == 0:
|
||||||
|
mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
|
||||||
|
else:
|
||||||
|
mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
|
||||||
|
|
||||||
|
# Compute the output.
|
||||||
|
if self.affine:
|
||||||
|
# MJY:: Fuse the multiplication for speed.
|
||||||
|
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
|
||||||
|
else:
|
||||||
|
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
|
||||||
|
|
||||||
|
# Reshape it.
|
||||||
|
return output.view(input_shape)
|
||||||
|
|
||||||
|
def __data_parallel_replicate__(self, ctx, copy_id):
|
||||||
|
self._is_parallel = True
|
||||||
|
self._parallel_id = copy_id
|
||||||
|
|
||||||
|
# parallel_id == 0 means master device.
|
||||||
|
if self._parallel_id == 0:
|
||||||
|
ctx.sync_master = self._sync_master
|
||||||
|
else:
|
||||||
|
self._slave_pipe = ctx.sync_master.register_slave(copy_id)
|
||||||
|
|
||||||
|
def _data_parallel_master(self, intermediates):
|
||||||
|
"""Reduce the sum and square-sum, compute the statistics, and broadcast it."""
|
||||||
|
|
||||||
|
# Always using same "device order" makes the ReduceAdd operation faster.
|
||||||
|
# Thanks to:: Tete Xiao (http://tetexiao.com/)
|
||||||
|
intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
|
||||||
|
|
||||||
|
to_reduce = [i[1][:2] for i in intermediates]
|
||||||
|
to_reduce = [j for i in to_reduce for j in i] # flatten
|
||||||
|
target_gpus = [i[1].sum.get_device() for i in intermediates]
|
||||||
|
|
||||||
|
sum_size = sum([i[1].sum_size for i in intermediates])
|
||||||
|
sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
|
||||||
|
mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
|
||||||
|
|
||||||
|
broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
for i, rec in enumerate(intermediates):
|
||||||
|
outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def _compute_mean_std(self, sum_, ssum, size):
|
||||||
|
"""Compute the mean and standard-deviation with sum and square-sum. This method
|
||||||
|
also maintains the moving average on the master device."""
|
||||||
|
assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
|
||||||
|
mean = sum_ / size
|
||||||
|
sumvar = ssum - sum_ * mean
|
||||||
|
unbias_var = sumvar / (size - 1)
|
||||||
|
bias_var = sumvar / size
|
||||||
|
|
||||||
|
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
|
||||||
|
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
|
||||||
|
|
||||||
|
return mean, bias_var.clamp(self.eps) ** -0.5
|
||||||
|
|
||||||
|
|
||||||
|
class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
|
||||||
|
r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
|
||||||
|
mini-batch.
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
|
||||||
|
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
||||||
|
|
||||||
|
This module differs from the built-in PyTorch BatchNorm1d as the mean and
|
||||||
|
standard-deviation are reduced across all devices during training.
|
||||||
|
|
||||||
|
For example, when one uses `nn.DataParallel` to wrap the network during
|
||||||
|
training, PyTorch's implementation normalize the tensor on each device using
|
||||||
|
the statistics only on that device, which accelerated the computation and
|
||||||
|
is also easy to implement, but the statistics might be inaccurate.
|
||||||
|
Instead, in this synchronized version, the statistics will be computed
|
||||||
|
over all training samples distributed on multiple devices.
|
||||||
|
|
||||||
|
Note that, for one-GPU or CPU-only case, this module behaves exactly same
|
||||||
|
as the built-in PyTorch implementation.
|
||||||
|
|
||||||
|
The mean and standard-deviation are calculated per-dimension over
|
||||||
|
the mini-batches and gamma and beta are learnable parameter vectors
|
||||||
|
of size C (where C is the input size).
|
||||||
|
|
||||||
|
During training, this layer keeps a running estimate of its computed mean
|
||||||
|
and variance. The running sum is kept with a default momentum of 0.1.
|
||||||
|
|
||||||
|
During evaluation, this running mean/variance is used for normalization.
|
||||||
|
|
||||||
|
Because the BatchNorm is done over the `C` dimension, computing statistics
|
||||||
|
on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_features: num_features from an expected input of size
|
||||||
|
`batch_size x num_features [x width]`
|
||||||
|
eps: a value added to the denominator for numerical stability.
|
||||||
|
Default: 1e-5
|
||||||
|
momentum: the value used for the running_mean and running_var
|
||||||
|
computation. Default: 0.1
|
||||||
|
affine: a boolean value that when set to ``True``, gives the layer learnable
|
||||||
|
affine parameters. Default: ``True``
|
||||||
|
|
||||||
|
Shape:
|
||||||
|
- Input: :math:`(N, C)` or :math:`(N, C, L)`
|
||||||
|
- Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> # With Learnable Parameters
|
||||||
|
>>> m = SynchronizedBatchNorm1d(100)
|
||||||
|
>>> # Without Learnable Parameters
|
||||||
|
>>> m = SynchronizedBatchNorm1d(100, affine=False)
|
||||||
|
>>> input = torch.autograd.Variable(torch.randn(20, 100))
|
||||||
|
>>> output = m(input)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _check_input_dim(self, input):
|
||||||
|
if input.dim() != 2 and input.dim() != 3:
|
||||||
|
raise ValueError('expected 2D or 3D input (got {}D input)'
|
||||||
|
.format(input.dim()))
|
||||||
|
super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
|
||||||
|
|
||||||
|
|
||||||
|
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
|
||||||
|
r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
|
||||||
|
of 3d inputs
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
|
||||||
|
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
||||||
|
|
||||||
|
This module differs from the built-in PyTorch BatchNorm2d as the mean and
|
||||||
|
standard-deviation are reduced across all devices during training.
|
||||||
|
|
||||||
|
For example, when one uses `nn.DataParallel` to wrap the network during
|
||||||
|
training, PyTorch's implementation normalize the tensor on each device using
|
||||||
|
the statistics only on that device, which accelerated the computation and
|
||||||
|
is also easy to implement, but the statistics might be inaccurate.
|
||||||
|
Instead, in this synchronized version, the statistics will be computed
|
||||||
|
over all training samples distributed on multiple devices.
|
||||||
|
|
||||||
|
Note that, for one-GPU or CPU-only case, this module behaves exactly same
|
||||||
|
as the built-in PyTorch implementation.
|
||||||
|
|
||||||
|
The mean and standard-deviation are calculated per-dimension over
|
||||||
|
the mini-batches and gamma and beta are learnable parameter vectors
|
||||||
|
of size C (where C is the input size).
|
||||||
|
|
||||||
|
During training, this layer keeps a running estimate of its computed mean
|
||||||
|
and variance. The running sum is kept with a default momentum of 0.1.
|
||||||
|
|
||||||
|
During evaluation, this running mean/variance is used for normalization.
|
||||||
|
|
||||||
|
Because the BatchNorm is done over the `C` dimension, computing statistics
|
||||||
|
on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_features: num_features from an expected input of
|
||||||
|
size batch_size x num_features x height x width
|
||||||
|
eps: a value added to the denominator for numerical stability.
|
||||||
|
Default: 1e-5
|
||||||
|
momentum: the value used for the running_mean and running_var
|
||||||
|
computation. Default: 0.1
|
||||||
|
affine: a boolean value that when set to ``True``, gives the layer learnable
|
||||||
|
affine parameters. Default: ``True``
|
||||||
|
|
||||||
|
Shape:
|
||||||
|
- Input: :math:`(N, C, H, W)`
|
||||||
|
- Output: :math:`(N, C, H, W)` (same shape as input)
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> # With Learnable Parameters
|
||||||
|
>>> m = SynchronizedBatchNorm2d(100)
|
||||||
|
>>> # Without Learnable Parameters
|
||||||
|
>>> m = SynchronizedBatchNorm2d(100, affine=False)
|
||||||
|
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
|
||||||
|
>>> output = m(input)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _check_input_dim(self, input):
|
||||||
|
if input.dim() != 4:
|
||||||
|
raise ValueError('expected 4D input (got {}D input)'
|
||||||
|
.format(input.dim()))
|
||||||
|
super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
|
||||||
|
|
||||||
|
|
||||||
|
class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
|
||||||
|
r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
|
||||||
|
of 4d inputs
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
|
||||||
|
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
||||||
|
|
||||||
|
This module differs from the built-in PyTorch BatchNorm3d as the mean and
|
||||||
|
standard-deviation are reduced across all devices during training.
|
||||||
|
|
||||||
|
For example, when one uses `nn.DataParallel` to wrap the network during
|
||||||
|
training, PyTorch's implementation normalize the tensor on each device using
|
||||||
|
the statistics only on that device, which accelerated the computation and
|
||||||
|
is also easy to implement, but the statistics might be inaccurate.
|
||||||
|
Instead, in this synchronized version, the statistics will be computed
|
||||||
|
over all training samples distributed on multiple devices.
|
||||||
|
|
||||||
|
Note that, for one-GPU or CPU-only case, this module behaves exactly same
|
||||||
|
as the built-in PyTorch implementation.
|
||||||
|
|
||||||
|
The mean and standard-deviation are calculated per-dimension over
|
||||||
|
the mini-batches and gamma and beta are learnable parameter vectors
|
||||||
|
of size C (where C is the input size).
|
||||||
|
|
||||||
|
During training, this layer keeps a running estimate of its computed mean
|
||||||
|
and variance. The running sum is kept with a default momentum of 0.1.
|
||||||
|
|
||||||
|
During evaluation, this running mean/variance is used for normalization.
|
||||||
|
|
||||||
|
Because the BatchNorm is done over the `C` dimension, computing statistics
|
||||||
|
on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
|
||||||
|
or Spatio-temporal BatchNorm
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_features: num_features from an expected input of
|
||||||
|
size batch_size x num_features x depth x height x width
|
||||||
|
eps: a value added to the denominator for numerical stability.
|
||||||
|
Default: 1e-5
|
||||||
|
momentum: the value used for the running_mean and running_var
|
||||||
|
computation. Default: 0.1
|
||||||
|
affine: a boolean value that when set to ``True``, gives the layer learnable
|
||||||
|
affine parameters. Default: ``True``
|
||||||
|
|
||||||
|
Shape:
|
||||||
|
- Input: :math:`(N, C, D, H, W)`
|
||||||
|
- Output: :math:`(N, C, D, H, W)` (same shape as input)
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> # With Learnable Parameters
|
||||||
|
>>> m = SynchronizedBatchNorm3d(100)
|
||||||
|
>>> # Without Learnable Parameters
|
||||||
|
>>> m = SynchronizedBatchNorm3d(100, affine=False)
|
||||||
|
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
|
||||||
|
>>> output = m(input)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _check_input_dim(self, input):
|
||||||
|
if input.dim() != 5:
|
||||||
|
raise ValueError('expected 5D input (got {}D input)'
|
||||||
|
.format(input.dim()))
|
||||||
|
super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
|
@ -0,0 +1,137 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# File : comm.py
|
||||||
|
# Author : Jiayuan Mao
|
||||||
|
# Email : maojiayuan@gmail.com
|
||||||
|
# Date : 27/01/2018
|
||||||
|
#
|
||||||
|
# This file is part of Synchronized-BatchNorm-PyTorch.
|
||||||
|
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
||||||
|
# Distributed under MIT License.
|
||||||
|
|
||||||
|
import queue
|
||||||
|
import collections
|
||||||
|
import threading
|
||||||
|
|
||||||
|
__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
|
||||||
|
|
||||||
|
|
||||||
|
class FutureResult(object):
|
||||||
|
"""A thread-safe future implementation. Used only as one-to-one pipe."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._result = None
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
self._cond = threading.Condition(self._lock)
|
||||||
|
|
||||||
|
def put(self, result):
|
||||||
|
with self._lock:
|
||||||
|
assert self._result is None, 'Previous result has\'t been fetched.'
|
||||||
|
self._result = result
|
||||||
|
self._cond.notify()
|
||||||
|
|
||||||
|
def get(self):
|
||||||
|
with self._lock:
|
||||||
|
if self._result is None:
|
||||||
|
self._cond.wait()
|
||||||
|
|
||||||
|
res = self._result
|
||||||
|
self._result = None
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
|
||||||
|
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
|
||||||
|
|
||||||
|
|
||||||
|
class SlavePipe(_SlavePipeBase):
|
||||||
|
"""Pipe for master-slave communication."""
|
||||||
|
|
||||||
|
def run_slave(self, msg):
|
||||||
|
self.queue.put((self.identifier, msg))
|
||||||
|
ret = self.result.get()
|
||||||
|
self.queue.put(True)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
class SyncMaster(object):
|
||||||
|
"""An abstract `SyncMaster` object.
|
||||||
|
|
||||||
|
- During the replication, as the data parallel will trigger an callback of each module, all slave devices should
|
||||||
|
call `register(id)` and obtain an `SlavePipe` to communicate with the master.
|
||||||
|
- During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
|
||||||
|
and passed to a registered callback.
|
||||||
|
- After receiving the messages, the master device should gather the information and determine to message passed
|
||||||
|
back to each slave devices.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, master_callback):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
master_callback: a callback to be invoked after having collected messages from slave devices.
|
||||||
|
"""
|
||||||
|
self._master_callback = master_callback
|
||||||
|
self._queue = queue.Queue()
|
||||||
|
self._registry = collections.OrderedDict()
|
||||||
|
self._activated = False
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
return {'master_callback': self._master_callback}
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
self.__init__(state['master_callback'])
|
||||||
|
|
||||||
|
def register_slave(self, identifier):
|
||||||
|
"""
|
||||||
|
Register an slave device.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
identifier: an identifier, usually is the device id.
|
||||||
|
|
||||||
|
Returns: a `SlavePipe` object which can be used to communicate with the master device.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if self._activated:
|
||||||
|
assert self._queue.empty(), 'Queue is not clean before next initialization.'
|
||||||
|
self._activated = False
|
||||||
|
self._registry.clear()
|
||||||
|
future = FutureResult()
|
||||||
|
self._registry[identifier] = _MasterRegistry(future)
|
||||||
|
return SlavePipe(identifier, self._queue, future)
|
||||||
|
|
||||||
|
def run_master(self, master_msg):
|
||||||
|
"""
|
||||||
|
Main entry for the master device in each forward pass.
|
||||||
|
The messages were first collected from each devices (including the master device), and then
|
||||||
|
an callback will be invoked to compute the message to be sent back to each devices
|
||||||
|
(including the master device).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
master_msg: the message that the master want to send to itself. This will be placed as the first
|
||||||
|
message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
|
||||||
|
|
||||||
|
Returns: the message to be sent back to the master device.
|
||||||
|
|
||||||
|
"""
|
||||||
|
self._activated = True
|
||||||
|
|
||||||
|
intermediates = [(0, master_msg)]
|
||||||
|
for i in range(self.nr_slaves):
|
||||||
|
intermediates.append(self._queue.get())
|
||||||
|
|
||||||
|
results = self._master_callback(intermediates)
|
||||||
|
assert results[0][0] == 0, 'The first result should belongs to the master.'
|
||||||
|
|
||||||
|
for i, res in results:
|
||||||
|
if i == 0:
|
||||||
|
continue
|
||||||
|
self._registry[i].result.put(res)
|
||||||
|
|
||||||
|
for i in range(self.nr_slaves):
|
||||||
|
assert self._queue.get() is True
|
||||||
|
|
||||||
|
return results[0][1]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def nr_slaves(self):
|
||||||
|
return len(self._registry)
|
@ -0,0 +1,94 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# File : replicate.py
|
||||||
|
# Author : Jiayuan Mao
|
||||||
|
# Email : maojiayuan@gmail.com
|
||||||
|
# Date : 27/01/2018
|
||||||
|
#
|
||||||
|
# This file is part of Synchronized-BatchNorm-PyTorch.
|
||||||
|
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
||||||
|
# Distributed under MIT License.
|
||||||
|
|
||||||
|
import functools
|
||||||
|
|
||||||
|
from torch.nn.parallel.data_parallel import DataParallel
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'CallbackContext',
|
||||||
|
'execute_replication_callbacks',
|
||||||
|
'DataParallelWithCallback',
|
||||||
|
'patch_replication_callback'
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class CallbackContext(object):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def execute_replication_callbacks(modules):
|
||||||
|
"""
|
||||||
|
Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.
|
||||||
|
|
||||||
|
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
|
||||||
|
|
||||||
|
Note that, as all modules are isomorphism, we assign each sub-module with a context
|
||||||
|
(shared among multiple copies of this module on different devices).
|
||||||
|
Through this context, different copies can share some information.
|
||||||
|
|
||||||
|
We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
|
||||||
|
of any slave copies.
|
||||||
|
"""
|
||||||
|
master_copy = modules[0]
|
||||||
|
nr_modules = len(list(master_copy.modules()))
|
||||||
|
ctxs = [CallbackContext() for _ in range(nr_modules)]
|
||||||
|
|
||||||
|
for i, module in enumerate(modules):
|
||||||
|
for j, m in enumerate(module.modules()):
|
||||||
|
if hasattr(m, '__data_parallel_replicate__'):
|
||||||
|
m.__data_parallel_replicate__(ctxs[j], i)
|
||||||
|
|
||||||
|
|
||||||
|
class DataParallelWithCallback(DataParallel):
|
||||||
|
"""
|
||||||
|
Data Parallel with a replication callback.
|
||||||
|
|
||||||
|
An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
|
||||||
|
original `replicate` function.
|
||||||
|
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
||||||
|
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
|
||||||
|
# sync_bn.__data_parallel_replicate__ will be invoked.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def replicate(self, module, device_ids):
|
||||||
|
modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
|
||||||
|
execute_replication_callbacks(modules)
|
||||||
|
return modules
|
||||||
|
|
||||||
|
|
||||||
|
def patch_replication_callback(data_parallel):
|
||||||
|
"""
|
||||||
|
Monkey-patch an existing `DataParallel` object. Add the replication callback.
|
||||||
|
Useful when you have customized `DataParallel` implementation.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
||||||
|
> sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
|
||||||
|
> patch_replication_callback(sync_bn)
|
||||||
|
# this is equivalent to
|
||||||
|
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
||||||
|
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert isinstance(data_parallel, DataParallel)
|
||||||
|
|
||||||
|
old_replicate = data_parallel.replicate
|
||||||
|
|
||||||
|
@functools.wraps(old_replicate)
|
||||||
|
def new_replicate(module, device_ids):
|
||||||
|
modules = old_replicate(module, device_ids)
|
||||||
|
execute_replication_callbacks(modules)
|
||||||
|
return modules
|
||||||
|
|
||||||
|
data_parallel.replicate = new_replicate
|
@ -0,0 +1,29 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# File : unittest.py
|
||||||
|
# Author : Jiayuan Mao
|
||||||
|
# Email : maojiayuan@gmail.com
|
||||||
|
# Date : 27/01/2018
|
||||||
|
#
|
||||||
|
# This file is part of Synchronized-BatchNorm-PyTorch.
|
||||||
|
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
||||||
|
# Distributed under MIT License.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from torch.autograd import Variable
|
||||||
|
|
||||||
|
|
||||||
|
def as_numpy(v):
|
||||||
|
if isinstance(v, Variable):
|
||||||
|
v = v.data
|
||||||
|
return v.cpu().numpy()
|
||||||
|
|
||||||
|
|
||||||
|
class TorchTestCase(unittest.TestCase):
|
||||||
|
def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
|
||||||
|
npa, npb = as_numpy(a), as_numpy(b)
|
||||||
|
self.assertTrue(
|
||||||
|
np.allclose(npa, npb, atol=atol),
|
||||||
|
'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
|
||||||
|
)
|
@ -0,0 +1,87 @@
|
|||||||
|
from tqdm import trange
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from logger import Logger
|
||||||
|
from modules.model import GeneratorFullModel, DiscriminatorFullModel
|
||||||
|
|
||||||
|
from torch.optim.lr_scheduler import MultiStepLR
|
||||||
|
|
||||||
|
from sync_batchnorm import DataParallelWithCallback
|
||||||
|
|
||||||
|
from frames_dataset import DatasetRepeater
|
||||||
|
|
||||||
|
|
||||||
|
def train(config, generator, discriminator, kp_detector, checkpoint, log_dir, dataset, device_ids):
|
||||||
|
train_params = config['train_params']
|
||||||
|
|
||||||
|
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=train_params['lr_generator'], betas=(0.5, 0.999))
|
||||||
|
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=train_params['lr_discriminator'], betas=(0.5, 0.999))
|
||||||
|
optimizer_kp_detector = torch.optim.Adam(kp_detector.parameters(), lr=train_params['lr_kp_detector'], betas=(0.5, 0.999))
|
||||||
|
|
||||||
|
if checkpoint is not None:
|
||||||
|
start_epoch = Logger.load_cpk(checkpoint, generator, discriminator, kp_detector,
|
||||||
|
optimizer_generator, optimizer_discriminator,
|
||||||
|
None if train_params['lr_kp_detector'] == 0 else optimizer_kp_detector)
|
||||||
|
else:
|
||||||
|
start_epoch = 0
|
||||||
|
|
||||||
|
scheduler_generator = MultiStepLR(optimizer_generator, train_params['epoch_milestones'], gamma=0.1,
|
||||||
|
last_epoch=start_epoch - 1)
|
||||||
|
scheduler_discriminator = MultiStepLR(optimizer_discriminator, train_params['epoch_milestones'], gamma=0.1,
|
||||||
|
last_epoch=start_epoch - 1)
|
||||||
|
scheduler_kp_detector = MultiStepLR(optimizer_kp_detector, train_params['epoch_milestones'], gamma=0.1,
|
||||||
|
last_epoch=-1 + start_epoch * (train_params['lr_kp_detector'] != 0))
|
||||||
|
|
||||||
|
if 'num_repeats' in train_params or train_params['num_repeats'] != 1:
|
||||||
|
dataset = DatasetRepeater(dataset, train_params['num_repeats'])
|
||||||
|
dataloader = DataLoader(dataset, batch_size=train_params['batch_size'], shuffle=True, num_workers=6, drop_last=True)
|
||||||
|
|
||||||
|
generator_full = GeneratorFullModel(kp_detector, generator, discriminator, train_params)
|
||||||
|
discriminator_full = DiscriminatorFullModel(kp_detector, generator, discriminator, train_params)
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
generator_full = DataParallelWithCallback(generator_full, device_ids=device_ids)
|
||||||
|
discriminator_full = DataParallelWithCallback(discriminator_full, device_ids=device_ids)
|
||||||
|
|
||||||
|
with Logger(log_dir=log_dir, visualizer_params=config['visualizer_params'], checkpoint_freq=train_params['checkpoint_freq']) as logger:
|
||||||
|
for epoch in trange(start_epoch, train_params['num_epochs']):
|
||||||
|
for x in dataloader:
|
||||||
|
losses_generator, generated = generator_full(x)
|
||||||
|
|
||||||
|
loss_values = [val.mean() for val in losses_generator.values()]
|
||||||
|
loss = sum(loss_values)
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
optimizer_generator.step()
|
||||||
|
optimizer_generator.zero_grad()
|
||||||
|
optimizer_kp_detector.step()
|
||||||
|
optimizer_kp_detector.zero_grad()
|
||||||
|
|
||||||
|
if train_params['loss_weights']['generator_gan'] != 0:
|
||||||
|
optimizer_discriminator.zero_grad()
|
||||||
|
losses_discriminator = discriminator_full(x, generated)
|
||||||
|
loss_values = [val.mean() for val in losses_discriminator.values()]
|
||||||
|
loss = sum(loss_values)
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
optimizer_discriminator.step()
|
||||||
|
optimizer_discriminator.zero_grad()
|
||||||
|
else:
|
||||||
|
losses_discriminator = {}
|
||||||
|
|
||||||
|
losses_generator.update(losses_discriminator)
|
||||||
|
losses = {key: value.mean().detach().data.cpu().numpy() for key, value in losses_generator.items()}
|
||||||
|
logger.log_iter(losses=losses)
|
||||||
|
|
||||||
|
scheduler_generator.step()
|
||||||
|
scheduler_discriminator.step()
|
||||||
|
scheduler_kp_detector.step()
|
||||||
|
|
||||||
|
logger.log_epoch(epoch, {'generator': generator,
|
||||||
|
'discriminator': discriminator,
|
||||||
|
'kp_detector': kp_detector,
|
||||||
|
'optimizer_generator': optimizer_generator,
|
||||||
|
'optimizer_discriminator': optimizer_discriminator,
|
||||||
|
'optimizer_kp_detector': optimizer_kp_detector}, inp=x, out=generated)
|
@ -0,0 +1,2 @@
|
|||||||
|
body{-webkit-font-smoothing:antialiased;-moz-osx-font-smoothing:grayscale;font-family:-apple-system,BlinkMacSystemFont,Segoe UI,Roboto,Oxygen,Ubuntu,Cantarell,Fira Sans,Droid Sans,Helvetica Neue,sans-serif;margin:0}code{font-family:source-code-pro,Menlo,Monaco,Consolas,Courier New,monospace}
|
||||||
|
/*# sourceMappingURL=main.e6c13ad2.css.map*/
|
@ -0,0 +1 @@
|
|||||||
|
{"version":3,"file":"static/css/main.e6c13ad2.css","mappings":"AAAA,KAKE,kCAAmC,CACnC,iCAAkC,CAJlC,mIAEY,CAHZ,QAMF,CAEA,KACE,uEAEF","sources":["index.css"],"sourcesContent":["body {\n margin: 0;\n font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen',\n 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue',\n sans-serif;\n -webkit-font-smoothing: antialiased;\n -moz-osx-font-smoothing: grayscale;\n}\n\ncode {\n font-family: source-code-pro, Menlo, Monaco, Consolas, 'Courier New',\n monospace;\n}\n"],"names":[],"sourceRoot":""}
|
@ -0,0 +1,2 @@
|
|||||||
|
"use strict";(self.webpackChunkimage_processing_app=self.webpackChunkimage_processing_app||[]).push([[453],{6453:(e,t,n)=>{n.r(t),n.d(t,{getCLS:()=>y,getFCP:()=>g,getFID:()=>C,getLCP:()=>P,getTTFB:()=>D});var i,r,a,o,u=function(e,t){return{name:e,value:void 0===t?-1:t,delta:0,entries:[],id:"v2-".concat(Date.now(),"-").concat(Math.floor(8999999999999*Math.random())+1e12)}},c=function(e,t){try{if(PerformanceObserver.supportedEntryTypes.includes(e)){if("first-input"===e&&!("PerformanceEventTiming"in self))return;var n=new PerformanceObserver((function(e){return e.getEntries().map(t)}));return n.observe({type:e,buffered:!0}),n}}catch(e){}},s=function(e,t){var n=function n(i){"pagehide"!==i.type&&"hidden"!==document.visibilityState||(e(i),t&&(removeEventListener("visibilitychange",n,!0),removeEventListener("pagehide",n,!0)))};addEventListener("visibilitychange",n,!0),addEventListener("pagehide",n,!0)},f=function(e){addEventListener("pageshow",(function(t){t.persisted&&e(t)}),!0)},m=function(e,t,n){var i;return function(r){t.value>=0&&(r||n)&&(t.delta=t.value-(i||0),(t.delta||void 0===i)&&(i=t.value,e(t)))}},p=-1,v=function(){return"hidden"===document.visibilityState?0:1/0},d=function(){s((function(e){var t=e.timeStamp;p=t}),!0)},l=function(){return p<0&&(p=v(),d(),f((function(){setTimeout((function(){p=v(),d()}),0)}))),{get firstHiddenTime(){return p}}},g=function(e,t){var n,i=l(),r=u("FCP"),a=function(e){"first-contentful-paint"===e.name&&(s&&s.disconnect(),e.startTime<i.firstHiddenTime&&(r.value=e.startTime,r.entries.push(e),n(!0)))},o=window.performance&&performance.getEntriesByName&&performance.getEntriesByName("first-contentful-paint")[0],s=o?null:c("paint",a);(o||s)&&(n=m(e,r,t),o&&a(o),f((function(i){r=u("FCP"),n=m(e,r,t),requestAnimationFrame((function(){requestAnimationFrame((function(){r.value=performance.now()-i.timeStamp,n(!0)}))}))})))},h=!1,T=-1,y=function(e,t){h||(g((function(e){T=e.value})),h=!0);var n,i=function(t){T>-1&&e(t)},r=u("CLS",0),a=0,o=[],p=function(e){if(!e.hadRecentInput){var t=o[0],i=o[o.length-1];a&&e.startTime-i.startTime<1e3&&e.startTime-t.startTime<5e3?(a+=e.value,o.push(e)):(a=e.value,o=[e]),a>r.value&&(r.value=a,r.entries=o,n())}},v=c("layout-shift",p);v&&(n=m(i,r,t),s((function(){v.takeRecords().map(p),n(!0)})),f((function(){a=0,T=-1,r=u("CLS",0),n=m(i,r,t)})))},E={passive:!0,capture:!0},w=new Date,L=function(e,t){i||(i=t,r=e,a=new Date,F(removeEventListener),S())},S=function(){if(r>=0&&r<a-w){var e={entryType:"first-input",name:i.type,target:i.target,cancelable:i.cancelable,startTime:i.timeStamp,processingStart:i.timeStamp+r};o.forEach((function(t){t(e)})),o=[]}},b=function(e){if(e.cancelable){var t=(e.timeStamp>1e12?new Date:performance.now())-e.timeStamp;"pointerdown"==e.type?function(e,t){var n=function(){L(e,t),r()},i=function(){r()},r=function(){removeEventListener("pointerup",n,E),removeEventListener("pointercancel",i,E)};addEventListener("pointerup",n,E),addEventListener("pointercancel",i,E)}(t,e):L(t,e)}},F=function(e){["mousedown","keydown","touchstart","pointerdown"].forEach((function(t){return e(t,b,E)}))},C=function(e,t){var n,a=l(),p=u("FID"),v=function(e){e.startTime<a.firstHiddenTime&&(p.value=e.processingStart-e.startTime,p.entries.push(e),n(!0))},d=c("first-input",v);n=m(e,p,t),d&&s((function(){d.takeRecords().map(v),d.disconnect()}),!0),d&&f((function(){var a;p=u("FID"),n=m(e,p,t),o=[],r=-1,i=null,F(addEventListener),a=v,o.push(a),S()}))},k={},P=function(e,t){var n,i=l(),r=u("LCP"),a=function(e){var t=e.startTime;t<i.firstHiddenTime&&(r.value=t,r.entries.push(e),n())},o=c("largest-contentful-paint",a);if(o){n=m(e,r,t);var p=function(){k[r.id]||(o.takeRecords().map(a),o.disconnect(),k[r.id]=!0,n(!0))};["keydown","click"].forEach((function(e){addEventListener(e,p,{once:!0,capture:!0})})),s(p,!0),f((function(i){r=u("LCP"),n=m(e,r,t),requestAnimationFrame((function(){requestAnimationFrame((function(){r.value=performance.now()-i.timeStamp,k[r.id]=!0,n(!0)}))}))}))}},D=function(e){var t,n=u("TTFB");t=function(){try{var t=performance.getEntriesByType("navigation")[0]||function(){var e=performance.timing,t={entryType:"navigation",startTime:0};for(var n in e)"navigationStart"!==n&&"toJSON"!==n&&(t[n]=Math.max(e[n]-e.navigationStart,0));return t}();if(n.value=n.delta=t.responseStart,n.value<0||n.value>performance.now())return;n.entries=[t],e(n)}catch(e){}},"complete"===document.readyState?setTimeout(t,0):addEventListener("load",(function(){return setTimeout(t,0)}))}}}]);
|
||||||
|
//# sourceMappingURL=453.8beb5808.chunk.js.map
|
@ -0,0 +1,91 @@
|
|||||||
|
/**
|
||||||
|
* @license React
|
||||||
|
* react-dom.production.min.js
|
||||||
|
*
|
||||||
|
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
*
|
||||||
|
* This source code is licensed under the MIT license found in the
|
||||||
|
* LICENSE file in the root directory of this source tree.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @license React
|
||||||
|
* react-is.production.min.js
|
||||||
|
*
|
||||||
|
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
*
|
||||||
|
* This source code is licensed under the MIT license found in the
|
||||||
|
* LICENSE file in the root directory of this source tree.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @license React
|
||||||
|
* react-jsx-runtime.production.min.js
|
||||||
|
*
|
||||||
|
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
*
|
||||||
|
* This source code is licensed under the MIT license found in the
|
||||||
|
* LICENSE file in the root directory of this source tree.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @license React
|
||||||
|
* react.production.min.js
|
||||||
|
*
|
||||||
|
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
*
|
||||||
|
* This source code is licensed under the MIT license found in the
|
||||||
|
* LICENSE file in the root directory of this source tree.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @license React
|
||||||
|
* scheduler.production.min.js
|
||||||
|
*
|
||||||
|
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
*
|
||||||
|
* This source code is licensed under the MIT license found in the
|
||||||
|
* LICENSE file in the root directory of this source tree.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @remix-run/router v1.17.0
|
||||||
|
*
|
||||||
|
* Copyright (c) Remix Software Inc.
|
||||||
|
*
|
||||||
|
* This source code is licensed under the MIT license found in the
|
||||||
|
* LICENSE.md file in the root directory of this source tree.
|
||||||
|
*
|
||||||
|
* @license MIT
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* React Router DOM v6.24.0
|
||||||
|
*
|
||||||
|
* Copyright (c) Remix Software Inc.
|
||||||
|
*
|
||||||
|
* This source code is licensed under the MIT license found in the
|
||||||
|
* LICENSE.md file in the root directory of this source tree.
|
||||||
|
*
|
||||||
|
* @license MIT
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* React Router v6.24.0
|
||||||
|
*
|
||||||
|
* Copyright (c) Remix Software Inc.
|
||||||
|
*
|
||||||
|
* This source code is licensed under the MIT license found in the
|
||||||
|
* LICENSE.md file in the root directory of this source tree.
|
||||||
|
*
|
||||||
|
* @license MIT
|
||||||
|
*/
|
||||||
|
|
||||||
|
/** @license React v16.13.1
|
||||||
|
* react-is.production.min.js
|
||||||
|
*
|
||||||
|
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
*
|
||||||
|
* This source code is licensed under the MIT license found in the
|
||||||
|
* LICENSE file in the root directory of this source tree.
|
||||||
|
*/
|