parent
85da7e3f67
commit
d0804c693d
@ -0,0 +1,113 @@
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import torch
|
||||
from modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian
|
||||
|
||||
|
||||
class DenseMotionNetwork(nn.Module):
|
||||
"""
|
||||
Module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving
|
||||
"""
|
||||
|
||||
def __init__(self, block_expansion, num_blocks, max_features, num_kp, num_channels, estimate_occlusion_map=False,
|
||||
scale_factor=1, kp_variance=0.01):
|
||||
super(DenseMotionNetwork, self).__init__()
|
||||
self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp + 1) * (num_channels + 1),
|
||||
max_features=max_features, num_blocks=num_blocks)
|
||||
|
||||
self.mask = nn.Conv2d(self.hourglass.out_filters, num_kp + 1, kernel_size=(7, 7), padding=(3, 3))
|
||||
|
||||
if estimate_occlusion_map:
|
||||
self.occlusion = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3))
|
||||
else:
|
||||
self.occlusion = None
|
||||
|
||||
self.num_kp = num_kp
|
||||
self.scale_factor = scale_factor
|
||||
self.kp_variance = kp_variance
|
||||
|
||||
if self.scale_factor != 1:
|
||||
self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)
|
||||
|
||||
def create_heatmap_representations(self, source_image, kp_driving, kp_source):
|
||||
"""
|
||||
Eq 6. in the paper H_k(z)
|
||||
"""
|
||||
spatial_size = source_image.shape[2:]
|
||||
gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=self.kp_variance)
|
||||
gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=self.kp_variance)
|
||||
heatmap = gaussian_driving - gaussian_source
|
||||
|
||||
#adding background feature
|
||||
zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type(heatmap.type())
|
||||
heatmap = torch.cat([zeros, heatmap], dim=1)
|
||||
heatmap = heatmap.unsqueeze(2)
|
||||
return heatmap
|
||||
|
||||
def create_sparse_motions(self, source_image, kp_driving, kp_source):
|
||||
"""
|
||||
Eq 4. in the paper T_{s<-d}(z)
|
||||
"""
|
||||
bs, _, h, w = source_image.shape
|
||||
identity_grid = make_coordinate_grid((h, w), type=kp_source['value'].type())
|
||||
identity_grid = identity_grid.view(1, 1, h, w, 2)
|
||||
coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 2)
|
||||
if 'jacobian' in kp_driving:
|
||||
jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian']))
|
||||
jacobian = jacobian.unsqueeze(-3).unsqueeze(-3)
|
||||
jacobian = jacobian.repeat(1, 1, h, w, 1, 1)
|
||||
coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1))
|
||||
coordinate_grid = coordinate_grid.squeeze(-1)
|
||||
|
||||
driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 2)
|
||||
|
||||
#adding background feature
|
||||
identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1)
|
||||
sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1)
|
||||
return sparse_motions
|
||||
|
||||
def create_deformed_source_image(self, source_image, sparse_motions):
|
||||
"""
|
||||
Eq 7. in the paper \hat{T}_{s<-d}(z)
|
||||
"""
|
||||
bs, _, h, w = source_image.shape
|
||||
source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp + 1, 1, 1, 1, 1)
|
||||
source_repeat = source_repeat.view(bs * (self.num_kp + 1), -1, h, w)
|
||||
sparse_motions = sparse_motions.view((bs * (self.num_kp + 1), h, w, -1))
|
||||
sparse_deformed = F.grid_sample(source_repeat, sparse_motions)
|
||||
sparse_deformed = sparse_deformed.view((bs, self.num_kp + 1, -1, h, w))
|
||||
return sparse_deformed
|
||||
|
||||
def forward(self, source_image, kp_driving, kp_source):
|
||||
if self.scale_factor != 1:
|
||||
source_image = self.down(source_image)
|
||||
|
||||
bs, _, h, w = source_image.shape
|
||||
|
||||
out_dict = dict()
|
||||
heatmap_representation = self.create_heatmap_representations(source_image, kp_driving, kp_source)
|
||||
sparse_motion = self.create_sparse_motions(source_image, kp_driving, kp_source)
|
||||
deformed_source = self.create_deformed_source_image(source_image, sparse_motion)
|
||||
out_dict['sparse_deformed'] = deformed_source
|
||||
|
||||
input = torch.cat([heatmap_representation, deformed_source], dim=2)
|
||||
input = input.view(bs, -1, h, w)
|
||||
|
||||
prediction = self.hourglass(input)
|
||||
|
||||
mask = self.mask(prediction)
|
||||
mask = F.softmax(mask, dim=1)
|
||||
out_dict['mask'] = mask
|
||||
mask = mask.unsqueeze(2)
|
||||
sparse_motion = sparse_motion.permute(0, 1, 4, 2, 3)
|
||||
deformation = (sparse_motion * mask).sum(dim=1)
|
||||
deformation = deformation.permute(0, 2, 3, 1)
|
||||
|
||||
out_dict['deformation'] = deformation
|
||||
|
||||
# Sec. 3.2 in the paper
|
||||
if self.occlusion:
|
||||
occlusion_map = torch.sigmoid(self.occlusion(prediction))
|
||||
out_dict['occlusion_map'] = occlusion_map
|
||||
|
||||
return out_dict
|
||||
@ -0,0 +1,95 @@
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from modules.util import kp2gaussian
|
||||
import torch
|
||||
|
||||
|
||||
class DownBlock2d(nn.Module):
|
||||
"""
|
||||
Simple block for processing video (encoder).
|
||||
"""
|
||||
|
||||
def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
|
||||
super(DownBlock2d, self).__init__()
|
||||
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)
|
||||
|
||||
if sn:
|
||||
self.conv = nn.utils.spectral_norm(self.conv)
|
||||
|
||||
if norm:
|
||||
self.norm = nn.InstanceNorm2d(out_features, affine=True)
|
||||
else:
|
||||
self.norm = None
|
||||
self.pool = pool
|
||||
|
||||
def forward(self, x):
|
||||
out = x
|
||||
out = self.conv(out)
|
||||
if self.norm:
|
||||
out = self.norm(out)
|
||||
out = F.leaky_relu(out, 0.2)
|
||||
if self.pool:
|
||||
out = F.avg_pool2d(out, (2, 2))
|
||||
return out
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
"""
|
||||
Discriminator similar to Pix2Pix
|
||||
"""
|
||||
|
||||
def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,
|
||||
sn=False, use_kp=False, num_kp=10, kp_variance=0.01, **kwargs):
|
||||
super(Discriminator, self).__init__()
|
||||
|
||||
down_blocks = []
|
||||
for i in range(num_blocks):
|
||||
down_blocks.append(
|
||||
DownBlock2d(num_channels + num_kp * use_kp if i == 0 else min(max_features, block_expansion * (2 ** i)),
|
||||
min(max_features, block_expansion * (2 ** (i + 1))),
|
||||
norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn))
|
||||
|
||||
self.down_blocks = nn.ModuleList(down_blocks)
|
||||
self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)
|
||||
if sn:
|
||||
self.conv = nn.utils.spectral_norm(self.conv)
|
||||
self.use_kp = use_kp
|
||||
self.kp_variance = kp_variance
|
||||
|
||||
def forward(self, x, kp=None):
|
||||
feature_maps = []
|
||||
out = x
|
||||
if self.use_kp:
|
||||
heatmap = kp2gaussian(kp, x.shape[2:], self.kp_variance)
|
||||
out = torch.cat([out, heatmap], dim=1)
|
||||
|
||||
for down_block in self.down_blocks:
|
||||
feature_maps.append(down_block(out))
|
||||
out = feature_maps[-1]
|
||||
prediction_map = self.conv(out)
|
||||
|
||||
return feature_maps, prediction_map
|
||||
|
||||
|
||||
class MultiScaleDiscriminator(nn.Module):
|
||||
"""
|
||||
Multi-scale (scale) discriminator
|
||||
"""
|
||||
|
||||
def __init__(self, scales=(), **kwargs):
|
||||
super(MultiScaleDiscriminator, self).__init__()
|
||||
self.scales = scales
|
||||
discs = {}
|
||||
for scale in scales:
|
||||
discs[str(scale).replace('.', '-')] = Discriminator(**kwargs)
|
||||
self.discs = nn.ModuleDict(discs)
|
||||
|
||||
def forward(self, x, kp=None):
|
||||
out_dict = {}
|
||||
for scale, disc in self.discs.items():
|
||||
scale = str(scale).replace('-', '.')
|
||||
key = 'prediction_' + scale
|
||||
feature_maps, prediction_map = disc(x[key], kp)
|
||||
out_dict['feature_maps_' + scale] = feature_maps
|
||||
out_dict['prediction_map_' + scale] = prediction_map
|
||||
return out_dict
|
||||
@ -0,0 +1,97 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d
|
||||
from modules.dense_motion import DenseMotionNetwork
|
||||
|
||||
|
||||
class OcclusionAwareGenerator(nn.Module):
|
||||
"""
|
||||
Generator that given source image and and keypoints try to transform image according to movement trajectories
|
||||
induced by keypoints. Generator follows Johnson architecture.
|
||||
"""
|
||||
|
||||
def __init__(self, num_channels, num_kp, block_expansion, max_features, num_down_blocks,
|
||||
num_bottleneck_blocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):
|
||||
super(OcclusionAwareGenerator, self).__init__()
|
||||
|
||||
if dense_motion_params is not None:
|
||||
self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, num_channels=num_channels,
|
||||
estimate_occlusion_map=estimate_occlusion_map,
|
||||
**dense_motion_params)
|
||||
else:
|
||||
self.dense_motion_network = None
|
||||
|
||||
self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3))
|
||||
|
||||
down_blocks = []
|
||||
for i in range(num_down_blocks):
|
||||
in_features = min(max_features, block_expansion * (2 ** i))
|
||||
out_features = min(max_features, block_expansion * (2 ** (i + 1)))
|
||||
down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
|
||||
self.down_blocks = nn.ModuleList(down_blocks)
|
||||
|
||||
up_blocks = []
|
||||
for i in range(num_down_blocks):
|
||||
in_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i)))
|
||||
out_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i - 1)))
|
||||
up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
|
||||
self.up_blocks = nn.ModuleList(up_blocks)
|
||||
|
||||
self.bottleneck = torch.nn.Sequential()
|
||||
in_features = min(max_features, block_expansion * (2 ** num_down_blocks))
|
||||
for i in range(num_bottleneck_blocks):
|
||||
self.bottleneck.add_module('r' + str(i), ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1)))
|
||||
|
||||
self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3))
|
||||
self.estimate_occlusion_map = estimate_occlusion_map
|
||||
self.num_channels = num_channels
|
||||
|
||||
def deform_input(self, inp, deformation):
|
||||
_, h_old, w_old, _ = deformation.shape
|
||||
_, _, h, w = inp.shape
|
||||
if h_old != h or w_old != w:
|
||||
deformation = deformation.permute(0, 3, 1, 2)
|
||||
deformation = F.interpolate(deformation, size=(h, w), mode='bilinear')
|
||||
deformation = deformation.permute(0, 2, 3, 1)
|
||||
return F.grid_sample(inp, deformation)
|
||||
|
||||
def forward(self, source_image, kp_driving, kp_source):
|
||||
# Encoding (downsampling) part
|
||||
out = self.first(source_image)
|
||||
for i in range(len(self.down_blocks)):
|
||||
out = self.down_blocks[i](out)
|
||||
|
||||
# Transforming feature representation according to deformation and occlusion
|
||||
output_dict = {}
|
||||
if self.dense_motion_network is not None:
|
||||
dense_motion = self.dense_motion_network(source_image=source_image, kp_driving=kp_driving,
|
||||
kp_source=kp_source)
|
||||
output_dict['mask'] = dense_motion['mask']
|
||||
output_dict['sparse_deformed'] = dense_motion['sparse_deformed']
|
||||
|
||||
if 'occlusion_map' in dense_motion:
|
||||
occlusion_map = dense_motion['occlusion_map']
|
||||
output_dict['occlusion_map'] = occlusion_map
|
||||
else:
|
||||
occlusion_map = None
|
||||
deformation = dense_motion['deformation']
|
||||
out = self.deform_input(out, deformation)
|
||||
|
||||
if occlusion_map is not None:
|
||||
if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
|
||||
occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
|
||||
out = out * occlusion_map
|
||||
|
||||
output_dict["deformed"] = self.deform_input(source_image, deformation)
|
||||
|
||||
# Decoding part
|
||||
out = self.bottleneck(out)
|
||||
for i in range(len(self.up_blocks)):
|
||||
out = self.up_blocks[i](out)
|
||||
out = self.final(out)
|
||||
out = F.sigmoid(out)
|
||||
|
||||
output_dict["prediction"] = out
|
||||
|
||||
return output_dict
|
||||
@ -0,0 +1,75 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from modules.util import Hourglass, make_coordinate_grid, AntiAliasInterpolation2d
|
||||
|
||||
|
||||
class KPDetector(nn.Module):
|
||||
"""
|
||||
Detecting a keypoints. Return keypoint position and jacobian near each keypoint.
|
||||
"""
|
||||
|
||||
def __init__(self, block_expansion, num_kp, num_channels, max_features,
|
||||
num_blocks, temperature, estimate_jacobian=False, scale_factor=1,
|
||||
single_jacobian_map=False, pad=0):
|
||||
super(KPDetector, self).__init__()
|
||||
|
||||
self.predictor = Hourglass(block_expansion, in_features=num_channels,
|
||||
max_features=max_features, num_blocks=num_blocks)
|
||||
|
||||
self.kp = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=(7, 7),
|
||||
padding=pad)
|
||||
|
||||
if estimate_jacobian:
|
||||
self.num_jacobian_maps = 1 if single_jacobian_map else num_kp
|
||||
self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters,
|
||||
out_channels=4 * self.num_jacobian_maps, kernel_size=(7, 7), padding=pad)
|
||||
self.jacobian.weight.data.zero_()
|
||||
self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float))
|
||||
else:
|
||||
self.jacobian = None
|
||||
|
||||
self.temperature = temperature
|
||||
self.scale_factor = scale_factor
|
||||
if self.scale_factor != 1:
|
||||
self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor)
|
||||
|
||||
def gaussian2kp(self, heatmap):
|
||||
"""
|
||||
Extract the mean and from a heatmap
|
||||
"""
|
||||
shape = heatmap.shape
|
||||
heatmap = heatmap.unsqueeze(-1)
|
||||
grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0)
|
||||
value = (heatmap * grid).sum(dim=(2, 3))
|
||||
kp = {'value': value}
|
||||
|
||||
return kp
|
||||
|
||||
def forward(self, x):
|
||||
if self.scale_factor != 1:
|
||||
x = self.down(x)
|
||||
|
||||
feature_map = self.predictor(x)
|
||||
prediction = self.kp(feature_map)
|
||||
|
||||
final_shape = prediction.shape
|
||||
heatmap = prediction.view(final_shape[0], final_shape[1], -1)
|
||||
heatmap = F.softmax(heatmap / self.temperature, dim=2)
|
||||
heatmap = heatmap.view(*final_shape)
|
||||
|
||||
out = self.gaussian2kp(heatmap)
|
||||
|
||||
if self.jacobian is not None:
|
||||
jacobian_map = self.jacobian(feature_map)
|
||||
jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2],
|
||||
final_shape[3])
|
||||
heatmap = heatmap.unsqueeze(2)
|
||||
|
||||
jacobian = heatmap * jacobian_map
|
||||
jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1)
|
||||
jacobian = jacobian.sum(dim=-1)
|
||||
jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2)
|
||||
out['jacobian'] = jacobian
|
||||
|
||||
return out
|
||||
@ -0,0 +1,259 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from modules.util import AntiAliasInterpolation2d, make_coordinate_grid
|
||||
from torchvision import models
|
||||
import numpy as np
|
||||
from torch.autograd import grad
|
||||
|
||||
|
||||
class Vgg19(torch.nn.Module):
|
||||
"""
|
||||
Vgg19 network for perceptual loss. See Sec 3.3.
|
||||
"""
|
||||
def __init__(self, requires_grad=False):
|
||||
super(Vgg19, self).__init__()
|
||||
vgg_pretrained_features = models.vgg19(pretrained=True).features
|
||||
self.slice1 = torch.nn.Sequential()
|
||||
self.slice2 = torch.nn.Sequential()
|
||||
self.slice3 = torch.nn.Sequential()
|
||||
self.slice4 = torch.nn.Sequential()
|
||||
self.slice5 = torch.nn.Sequential()
|
||||
for x in range(2):
|
||||
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(2, 7):
|
||||
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(7, 12):
|
||||
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(12, 21):
|
||||
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(21, 30):
|
||||
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
||||
|
||||
self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
|
||||
requires_grad=False)
|
||||
self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
|
||||
requires_grad=False)
|
||||
|
||||
if not requires_grad:
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, X):
|
||||
X = (X - self.mean) / self.std
|
||||
h_relu1 = self.slice1(X)
|
||||
h_relu2 = self.slice2(h_relu1)
|
||||
h_relu3 = self.slice3(h_relu2)
|
||||
h_relu4 = self.slice4(h_relu3)
|
||||
h_relu5 = self.slice5(h_relu4)
|
||||
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
|
||||
return out
|
||||
|
||||
|
||||
class ImagePyramide(torch.nn.Module):
|
||||
"""
|
||||
Create image pyramide for computing pyramide perceptual loss. See Sec 3.3
|
||||
"""
|
||||
def __init__(self, scales, num_channels):
|
||||
super(ImagePyramide, self).__init__()
|
||||
downs = {}
|
||||
for scale in scales:
|
||||
downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
|
||||
self.downs = nn.ModuleDict(downs)
|
||||
|
||||
def forward(self, x):
|
||||
out_dict = {}
|
||||
for scale, down_module in self.downs.items():
|
||||
out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
|
||||
return out_dict
|
||||
|
||||
|
||||
class Transform:
|
||||
"""
|
||||
Random tps transformation for equivariance constraints. See Sec 3.3
|
||||
"""
|
||||
def __init__(self, bs, **kwargs):
|
||||
noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3]))
|
||||
self.theta = noise + torch.eye(2, 3).view(1, 2, 3)
|
||||
self.bs = bs
|
||||
|
||||
if ('sigma_tps' in kwargs) and ('points_tps' in kwargs):
|
||||
self.tps = True
|
||||
self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type())
|
||||
self.control_points = self.control_points.unsqueeze(0)
|
||||
self.control_params = torch.normal(mean=0,
|
||||
std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2]))
|
||||
else:
|
||||
self.tps = False
|
||||
|
||||
def transform_frame(self, frame):
|
||||
grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0)
|
||||
grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
|
||||
grid = self.warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2)
|
||||
return F.grid_sample(frame, grid, padding_mode="reflection")
|
||||
|
||||
def warp_coordinates(self, coordinates):
|
||||
theta = self.theta.type(coordinates.type())
|
||||
theta = theta.unsqueeze(1)
|
||||
transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]
|
||||
transformed = transformed.squeeze(-1)
|
||||
|
||||
if self.tps:
|
||||
control_points = self.control_points.type(coordinates.type())
|
||||
control_params = self.control_params.type(coordinates.type())
|
||||
distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)
|
||||
distances = torch.abs(distances).sum(-1)
|
||||
|
||||
result = distances ** 2
|
||||
result = result * torch.log(distances + 1e-6)
|
||||
result = result * control_params
|
||||
result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)
|
||||
transformed = transformed + result
|
||||
|
||||
return transformed
|
||||
|
||||
def jacobian(self, coordinates):
|
||||
new_coordinates = self.warp_coordinates(coordinates)
|
||||
grad_x = grad(new_coordinates[..., 0].sum(), coordinates, create_graph=True)
|
||||
grad_y = grad(new_coordinates[..., 1].sum(), coordinates, create_graph=True)
|
||||
jacobian = torch.cat([grad_x[0].unsqueeze(-2), grad_y[0].unsqueeze(-2)], dim=-2)
|
||||
return jacobian
|
||||
|
||||
|
||||
def detach_kp(kp):
|
||||
return {key: value.detach() for key, value in kp.items()}
|
||||
|
||||
|
||||
class GeneratorFullModel(torch.nn.Module):
|
||||
"""
|
||||
Merge all generator related updates into single model for better multi-gpu usage
|
||||
"""
|
||||
|
||||
def __init__(self, kp_extractor, generator, discriminator, train_params):
|
||||
super(GeneratorFullModel, self).__init__()
|
||||
self.kp_extractor = kp_extractor
|
||||
self.generator = generator
|
||||
self.discriminator = discriminator
|
||||
self.train_params = train_params
|
||||
self.scales = train_params['scales']
|
||||
self.disc_scales = self.discriminator.scales
|
||||
self.pyramid = ImagePyramide(self.scales, generator.num_channels)
|
||||
if torch.cuda.is_available():
|
||||
self.pyramid = self.pyramid.cuda()
|
||||
|
||||
self.loss_weights = train_params['loss_weights']
|
||||
|
||||
if sum(self.loss_weights['perceptual']) != 0:
|
||||
self.vgg = Vgg19()
|
||||
if torch.cuda.is_available():
|
||||
self.vgg = self.vgg.cuda()
|
||||
|
||||
def forward(self, x):
|
||||
kp_source = self.kp_extractor(x['source'])
|
||||
kp_driving = self.kp_extractor(x['driving'])
|
||||
|
||||
generated = self.generator(x['source'], kp_source=kp_source, kp_driving=kp_driving)
|
||||
generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})
|
||||
|
||||
loss_values = {}
|
||||
|
||||
pyramide_real = self.pyramid(x['driving'])
|
||||
pyramide_generated = self.pyramid(generated['prediction'])
|
||||
|
||||
if sum(self.loss_weights['perceptual']) != 0:
|
||||
value_total = 0
|
||||
for scale in self.scales:
|
||||
x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
|
||||
y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])
|
||||
|
||||
for i, weight in enumerate(self.loss_weights['perceptual']):
|
||||
value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
|
||||
value_total += self.loss_weights['perceptual'][i] * value
|
||||
loss_values['perceptual'] = value_total
|
||||
|
||||
if self.loss_weights['generator_gan'] != 0:
|
||||
discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))
|
||||
discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))
|
||||
value_total = 0
|
||||
for scale in self.disc_scales:
|
||||
key = 'prediction_map_%s' % scale
|
||||
value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
|
||||
value_total += self.loss_weights['generator_gan'] * value
|
||||
loss_values['gen_gan'] = value_total
|
||||
|
||||
if sum(self.loss_weights['feature_matching']) != 0:
|
||||
value_total = 0
|
||||
for scale in self.disc_scales:
|
||||
key = 'feature_maps_%s' % scale
|
||||
for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])):
|
||||
if self.loss_weights['feature_matching'][i] == 0:
|
||||
continue
|
||||
value = torch.abs(a - b).mean()
|
||||
value_total += self.loss_weights['feature_matching'][i] * value
|
||||
loss_values['feature_matching'] = value_total
|
||||
|
||||
if (self.loss_weights['equivariance_value'] + self.loss_weights['equivariance_jacobian']) != 0:
|
||||
transform = Transform(x['driving'].shape[0], **self.train_params['transform_params'])
|
||||
transformed_frame = transform.transform_frame(x['driving'])
|
||||
transformed_kp = self.kp_extractor(transformed_frame)
|
||||
|
||||
generated['transformed_frame'] = transformed_frame
|
||||
generated['transformed_kp'] = transformed_kp
|
||||
|
||||
## Value loss part
|
||||
if self.loss_weights['equivariance_value'] != 0:
|
||||
value = torch.abs(kp_driving['value'] - transform.warp_coordinates(transformed_kp['value'])).mean()
|
||||
loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value
|
||||
|
||||
## jacobian loss part
|
||||
if self.loss_weights['equivariance_jacobian'] != 0:
|
||||
jacobian_transformed = torch.matmul(transform.jacobian(transformed_kp['value']),
|
||||
transformed_kp['jacobian'])
|
||||
|
||||
normed_driving = torch.inverse(kp_driving['jacobian'])
|
||||
normed_transformed = jacobian_transformed
|
||||
value = torch.matmul(normed_driving, normed_transformed)
|
||||
|
||||
eye = torch.eye(2).view(1, 1, 2, 2).type(value.type())
|
||||
|
||||
value = torch.abs(eye - value).mean()
|
||||
loss_values['equivariance_jacobian'] = self.loss_weights['equivariance_jacobian'] * value
|
||||
|
||||
return loss_values, generated
|
||||
|
||||
|
||||
class DiscriminatorFullModel(torch.nn.Module):
|
||||
"""
|
||||
Merge all discriminator related updates into single model for better multi-gpu usage
|
||||
"""
|
||||
|
||||
def __init__(self, kp_extractor, generator, discriminator, train_params):
|
||||
super(DiscriminatorFullModel, self).__init__()
|
||||
self.kp_extractor = kp_extractor
|
||||
self.generator = generator
|
||||
self.discriminator = discriminator
|
||||
self.train_params = train_params
|
||||
self.scales = self.discriminator.scales
|
||||
self.pyramid = ImagePyramide(self.scales, generator.num_channels)
|
||||
if torch.cuda.is_available():
|
||||
self.pyramid = self.pyramid.cuda()
|
||||
|
||||
self.loss_weights = train_params['loss_weights']
|
||||
|
||||
def forward(self, x, generated):
|
||||
pyramide_real = self.pyramid(x['driving'])
|
||||
pyramide_generated = self.pyramid(generated['prediction'].detach())
|
||||
|
||||
kp_driving = generated['kp_driving']
|
||||
discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))
|
||||
discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))
|
||||
|
||||
loss_values = {}
|
||||
value_total = 0
|
||||
for scale in self.scales:
|
||||
key = 'prediction_map_%s' % scale
|
||||
value = (1 - discriminator_maps_real[key]) ** 2 + discriminator_maps_generated[key] ** 2
|
||||
value_total += self.loss_weights['discriminator_gan'] * value.mean()
|
||||
loss_values['disc_gan'] = value_total
|
||||
|
||||
return loss_values
|
||||
@ -0,0 +1,245 @@
|
||||
from torch import nn
|
||||
|
||||
import torch.nn.functional as F
|
||||
import torch
|
||||
|
||||
from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
|
||||
|
||||
|
||||
def kp2gaussian(kp, spatial_size, kp_variance):
|
||||
"""
|
||||
Transform a keypoint into gaussian like representation
|
||||
"""
|
||||
mean = kp['value']
|
||||
|
||||
coordinate_grid = make_coordinate_grid(spatial_size, mean.type())
|
||||
number_of_leading_dimensions = len(mean.shape) - 1
|
||||
shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
|
||||
coordinate_grid = coordinate_grid.view(*shape)
|
||||
repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1)
|
||||
coordinate_grid = coordinate_grid.repeat(*repeats)
|
||||
|
||||
# Preprocess kp shape
|
||||
shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 2)
|
||||
mean = mean.view(*shape)
|
||||
|
||||
mean_sub = (coordinate_grid - mean)
|
||||
|
||||
out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def make_coordinate_grid(spatial_size, type):
|
||||
"""
|
||||
Create a meshgrid [-1,1] x [-1,1] of given spatial_size.
|
||||
"""
|
||||
h, w = spatial_size
|
||||
x = torch.arange(w).type(type)
|
||||
y = torch.arange(h).type(type)
|
||||
|
||||
x = (2 * (x / (w - 1)) - 1)
|
||||
y = (2 * (y / (h - 1)) - 1)
|
||||
|
||||
yy = y.view(-1, 1).repeat(1, w)
|
||||
xx = x.view(1, -1).repeat(h, 1)
|
||||
|
||||
meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
|
||||
|
||||
return meshed
|
||||
|
||||
|
||||
class ResBlock2d(nn.Module):
|
||||
"""
|
||||
Res block, preserve spatial resolution.
|
||||
"""
|
||||
|
||||
def __init__(self, in_features, kernel_size, padding):
|
||||
super(ResBlock2d, self).__init__()
|
||||
self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
|
||||
padding=padding)
|
||||
self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
|
||||
padding=padding)
|
||||
self.norm1 = BatchNorm2d(in_features, affine=True)
|
||||
self.norm2 = BatchNorm2d(in_features, affine=True)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.norm1(x)
|
||||
out = F.relu(out)
|
||||
out = self.conv1(out)
|
||||
out = self.norm2(out)
|
||||
out = F.relu(out)
|
||||
out = self.conv2(out)
|
||||
out += x
|
||||
return out
|
||||
|
||||
|
||||
class UpBlock2d(nn.Module):
|
||||
"""
|
||||
Upsampling block for use in decoder.
|
||||
"""
|
||||
|
||||
def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
|
||||
super(UpBlock2d, self).__init__()
|
||||
|
||||
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
|
||||
padding=padding, groups=groups)
|
||||
self.norm = BatchNorm2d(out_features, affine=True)
|
||||
|
||||
def forward(self, x):
|
||||
out = F.interpolate(x, scale_factor=2)
|
||||
out = self.conv(out)
|
||||
out = self.norm(out)
|
||||
out = F.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class DownBlock2d(nn.Module):
|
||||
"""
|
||||
Downsampling block for use in encoder.
|
||||
"""
|
||||
|
||||
def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
|
||||
super(DownBlock2d, self).__init__()
|
||||
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
|
||||
padding=padding, groups=groups)
|
||||
self.norm = BatchNorm2d(out_features, affine=True)
|
||||
self.pool = nn.AvgPool2d(kernel_size=(2, 2))
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv(x)
|
||||
out = self.norm(out)
|
||||
out = F.relu(out)
|
||||
out = self.pool(out)
|
||||
return out
|
||||
|
||||
|
||||
class SameBlock2d(nn.Module):
|
||||
"""
|
||||
Simple block, preserve spatial resolution.
|
||||
"""
|
||||
|
||||
def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1):
|
||||
super(SameBlock2d, self).__init__()
|
||||
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,
|
||||
kernel_size=kernel_size, padding=padding, groups=groups)
|
||||
self.norm = BatchNorm2d(out_features, affine=True)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv(x)
|
||||
out = self.norm(out)
|
||||
out = F.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
"""
|
||||
Hourglass Encoder
|
||||
"""
|
||||
|
||||
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
|
||||
super(Encoder, self).__init__()
|
||||
|
||||
down_blocks = []
|
||||
for i in range(num_blocks):
|
||||
down_blocks.append(DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
|
||||
min(max_features, block_expansion * (2 ** (i + 1))),
|
||||
kernel_size=3, padding=1))
|
||||
self.down_blocks = nn.ModuleList(down_blocks)
|
||||
|
||||
def forward(self, x):
|
||||
outs = [x]
|
||||
for down_block in self.down_blocks:
|
||||
outs.append(down_block(outs[-1]))
|
||||
return outs
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""
|
||||
Hourglass Decoder
|
||||
"""
|
||||
|
||||
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
|
||||
super(Decoder, self).__init__()
|
||||
|
||||
up_blocks = []
|
||||
|
||||
for i in range(num_blocks)[::-1]:
|
||||
in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
|
||||
out_filters = min(max_features, block_expansion * (2 ** i))
|
||||
up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1))
|
||||
|
||||
self.up_blocks = nn.ModuleList(up_blocks)
|
||||
self.out_filters = block_expansion + in_features
|
||||
|
||||
def forward(self, x):
|
||||
out = x.pop()
|
||||
for up_block in self.up_blocks:
|
||||
out = up_block(out)
|
||||
skip = x.pop()
|
||||
out = torch.cat([out, skip], dim=1)
|
||||
return out
|
||||
|
||||
|
||||
class Hourglass(nn.Module):
|
||||
"""
|
||||
Hourglass architecture.
|
||||
"""
|
||||
|
||||
def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
|
||||
super(Hourglass, self).__init__()
|
||||
self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
|
||||
self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
|
||||
self.out_filters = self.decoder.out_filters
|
||||
|
||||
def forward(self, x):
|
||||
return self.decoder(self.encoder(x))
|
||||
|
||||
|
||||
class AntiAliasInterpolation2d(nn.Module):
|
||||
"""
|
||||
Band-limited downsampling, for better preservation of the input signal.
|
||||
"""
|
||||
def __init__(self, channels, scale):
|
||||
super(AntiAliasInterpolation2d, self).__init__()
|
||||
sigma = (1 / scale - 1) / 2
|
||||
kernel_size = 2 * round(sigma * 4) + 1
|
||||
self.ka = kernel_size // 2
|
||||
self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
|
||||
|
||||
kernel_size = [kernel_size, kernel_size]
|
||||
sigma = [sigma, sigma]
|
||||
# The gaussian kernel is the product of the
|
||||
# gaussian function of each dimension.
|
||||
kernel = 1
|
||||
meshgrids = torch.meshgrid(
|
||||
[
|
||||
torch.arange(size, dtype=torch.float32)
|
||||
for size in kernel_size
|
||||
]
|
||||
)
|
||||
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
|
||||
mean = (size - 1) / 2
|
||||
kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
|
||||
|
||||
# Make sure sum of values in gaussian kernel equals 1.
|
||||
kernel = kernel / torch.sum(kernel)
|
||||
# Reshape to depthwise convolutional weight
|
||||
kernel = kernel.view(1, 1, *kernel.size())
|
||||
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
|
||||
|
||||
self.register_buffer('weight', kernel)
|
||||
self.groups = channels
|
||||
self.scale = scale
|
||||
inv_scale = 1 / scale
|
||||
self.int_inv_scale = int(inv_scale)
|
||||
|
||||
def forward(self, input):
|
||||
if self.scale == 1.0:
|
||||
return input
|
||||
|
||||
out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
|
||||
out = F.conv2d(out, weight=self.weight, groups=self.groups)
|
||||
out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]
|
||||
|
||||
return out
|
||||
|
After Width: | Height: | Size: 70 KiB |
|
After Width: | Height: | Size: 1.9 KiB |
|
After Width: | Height: | Size: 128 KiB |
|
After Width: | Height: | Size: 398 KiB |
Loading…
Reference in new issue