diff --git a/modules/dense_motion.py b/modules/dense_motion.py new file mode 100644 index 0000000..06f7039 --- /dev/null +++ b/modules/dense_motion.py @@ -0,0 +1,113 @@ +from torch import nn +import torch.nn.functional as F +import torch +from modules.util import Hourglass, AntiAliasInterpolation2d, make_coordinate_grid, kp2gaussian + + +class DenseMotionNetwork(nn.Module): + """ + Module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving + """ + + def __init__(self, block_expansion, num_blocks, max_features, num_kp, num_channels, estimate_occlusion_map=False, + scale_factor=1, kp_variance=0.01): + super(DenseMotionNetwork, self).__init__() + self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp + 1) * (num_channels + 1), + max_features=max_features, num_blocks=num_blocks) + + self.mask = nn.Conv2d(self.hourglass.out_filters, num_kp + 1, kernel_size=(7, 7), padding=(3, 3)) + + if estimate_occlusion_map: + self.occlusion = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3)) + else: + self.occlusion = None + + self.num_kp = num_kp + self.scale_factor = scale_factor + self.kp_variance = kp_variance + + if self.scale_factor != 1: + self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor) + + def create_heatmap_representations(self, source_image, kp_driving, kp_source): + """ + Eq 6. in the paper H_k(z) + """ + spatial_size = source_image.shape[2:] + gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=self.kp_variance) + gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=self.kp_variance) + heatmap = gaussian_driving - gaussian_source + + #adding background feature + zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1]).type(heatmap.type()) + heatmap = torch.cat([zeros, heatmap], dim=1) + heatmap = heatmap.unsqueeze(2) + return heatmap + + def create_sparse_motions(self, source_image, kp_driving, kp_source): + """ + Eq 4. in the paper T_{s<-d}(z) + """ + bs, _, h, w = source_image.shape + identity_grid = make_coordinate_grid((h, w), type=kp_source['value'].type()) + identity_grid = identity_grid.view(1, 1, h, w, 2) + coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 2) + if 'jacobian' in kp_driving: + jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian'])) + jacobian = jacobian.unsqueeze(-3).unsqueeze(-3) + jacobian = jacobian.repeat(1, 1, h, w, 1, 1) + coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1)) + coordinate_grid = coordinate_grid.squeeze(-1) + + driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 2) + + #adding background feature + identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1) + sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1) + return sparse_motions + + def create_deformed_source_image(self, source_image, sparse_motions): + """ + Eq 7. in the paper \hat{T}_{s<-d}(z) + """ + bs, _, h, w = source_image.shape + source_repeat = source_image.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp + 1, 1, 1, 1, 1) + source_repeat = source_repeat.view(bs * (self.num_kp + 1), -1, h, w) + sparse_motions = sparse_motions.view((bs * (self.num_kp + 1), h, w, -1)) + sparse_deformed = F.grid_sample(source_repeat, sparse_motions) + sparse_deformed = sparse_deformed.view((bs, self.num_kp + 1, -1, h, w)) + return sparse_deformed + + def forward(self, source_image, kp_driving, kp_source): + if self.scale_factor != 1: + source_image = self.down(source_image) + + bs, _, h, w = source_image.shape + + out_dict = dict() + heatmap_representation = self.create_heatmap_representations(source_image, kp_driving, kp_source) + sparse_motion = self.create_sparse_motions(source_image, kp_driving, kp_source) + deformed_source = self.create_deformed_source_image(source_image, sparse_motion) + out_dict['sparse_deformed'] = deformed_source + + input = torch.cat([heatmap_representation, deformed_source], dim=2) + input = input.view(bs, -1, h, w) + + prediction = self.hourglass(input) + + mask = self.mask(prediction) + mask = F.softmax(mask, dim=1) + out_dict['mask'] = mask + mask = mask.unsqueeze(2) + sparse_motion = sparse_motion.permute(0, 1, 4, 2, 3) + deformation = (sparse_motion * mask).sum(dim=1) + deformation = deformation.permute(0, 2, 3, 1) + + out_dict['deformation'] = deformation + + # Sec. 3.2 in the paper + if self.occlusion: + occlusion_map = torch.sigmoid(self.occlusion(prediction)) + out_dict['occlusion_map'] = occlusion_map + + return out_dict diff --git a/modules/discriminator.py b/modules/discriminator.py new file mode 100644 index 0000000..8356493 --- /dev/null +++ b/modules/discriminator.py @@ -0,0 +1,95 @@ +from torch import nn +import torch.nn.functional as F +from modules.util import kp2gaussian +import torch + + +class DownBlock2d(nn.Module): + """ + Simple block for processing video (encoder). + """ + + def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False): + super(DownBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size) + + if sn: + self.conv = nn.utils.spectral_norm(self.conv) + + if norm: + self.norm = nn.InstanceNorm2d(out_features, affine=True) + else: + self.norm = None + self.pool = pool + + def forward(self, x): + out = x + out = self.conv(out) + if self.norm: + out = self.norm(out) + out = F.leaky_relu(out, 0.2) + if self.pool: + out = F.avg_pool2d(out, (2, 2)) + return out + + +class Discriminator(nn.Module): + """ + Discriminator similar to Pix2Pix + """ + + def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512, + sn=False, use_kp=False, num_kp=10, kp_variance=0.01, **kwargs): + super(Discriminator, self).__init__() + + down_blocks = [] + for i in range(num_blocks): + down_blocks.append( + DownBlock2d(num_channels + num_kp * use_kp if i == 0 else min(max_features, block_expansion * (2 ** i)), + min(max_features, block_expansion * (2 ** (i + 1))), + norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn)) + + self.down_blocks = nn.ModuleList(down_blocks) + self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1) + if sn: + self.conv = nn.utils.spectral_norm(self.conv) + self.use_kp = use_kp + self.kp_variance = kp_variance + + def forward(self, x, kp=None): + feature_maps = [] + out = x + if self.use_kp: + heatmap = kp2gaussian(kp, x.shape[2:], self.kp_variance) + out = torch.cat([out, heatmap], dim=1) + + for down_block in self.down_blocks: + feature_maps.append(down_block(out)) + out = feature_maps[-1] + prediction_map = self.conv(out) + + return feature_maps, prediction_map + + +class MultiScaleDiscriminator(nn.Module): + """ + Multi-scale (scale) discriminator + """ + + def __init__(self, scales=(), **kwargs): + super(MultiScaleDiscriminator, self).__init__() + self.scales = scales + discs = {} + for scale in scales: + discs[str(scale).replace('.', '-')] = Discriminator(**kwargs) + self.discs = nn.ModuleDict(discs) + + def forward(self, x, kp=None): + out_dict = {} + for scale, disc in self.discs.items(): + scale = str(scale).replace('-', '.') + key = 'prediction_' + scale + feature_maps, prediction_map = disc(x[key], kp) + out_dict['feature_maps_' + scale] = feature_maps + out_dict['prediction_map_' + scale] = prediction_map + return out_dict diff --git a/modules/generator.py b/modules/generator.py new file mode 100644 index 0000000..ec66570 --- /dev/null +++ b/modules/generator.py @@ -0,0 +1,97 @@ +import torch +from torch import nn +import torch.nn.functional as F +from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d +from modules.dense_motion import DenseMotionNetwork + + +class OcclusionAwareGenerator(nn.Module): + """ + Generator that given source image and and keypoints try to transform image according to movement trajectories + induced by keypoints. Generator follows Johnson architecture. + """ + + def __init__(self, num_channels, num_kp, block_expansion, max_features, num_down_blocks, + num_bottleneck_blocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False): + super(OcclusionAwareGenerator, self).__init__() + + if dense_motion_params is not None: + self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, num_channels=num_channels, + estimate_occlusion_map=estimate_occlusion_map, + **dense_motion_params) + else: + self.dense_motion_network = None + + self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3)) + + down_blocks = [] + for i in range(num_down_blocks): + in_features = min(max_features, block_expansion * (2 ** i)) + out_features = min(max_features, block_expansion * (2 ** (i + 1))) + down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) + self.down_blocks = nn.ModuleList(down_blocks) + + up_blocks = [] + for i in range(num_down_blocks): + in_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i))) + out_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i - 1))) + up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) + self.up_blocks = nn.ModuleList(up_blocks) + + self.bottleneck = torch.nn.Sequential() + in_features = min(max_features, block_expansion * (2 ** num_down_blocks)) + for i in range(num_bottleneck_blocks): + self.bottleneck.add_module('r' + str(i), ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1))) + + self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3)) + self.estimate_occlusion_map = estimate_occlusion_map + self.num_channels = num_channels + + def deform_input(self, inp, deformation): + _, h_old, w_old, _ = deformation.shape + _, _, h, w = inp.shape + if h_old != h or w_old != w: + deformation = deformation.permute(0, 3, 1, 2) + deformation = F.interpolate(deformation, size=(h, w), mode='bilinear') + deformation = deformation.permute(0, 2, 3, 1) + return F.grid_sample(inp, deformation) + + def forward(self, source_image, kp_driving, kp_source): + # Encoding (downsampling) part + out = self.first(source_image) + for i in range(len(self.down_blocks)): + out = self.down_blocks[i](out) + + # Transforming feature representation according to deformation and occlusion + output_dict = {} + if self.dense_motion_network is not None: + dense_motion = self.dense_motion_network(source_image=source_image, kp_driving=kp_driving, + kp_source=kp_source) + output_dict['mask'] = dense_motion['mask'] + output_dict['sparse_deformed'] = dense_motion['sparse_deformed'] + + if 'occlusion_map' in dense_motion: + occlusion_map = dense_motion['occlusion_map'] + output_dict['occlusion_map'] = occlusion_map + else: + occlusion_map = None + deformation = dense_motion['deformation'] + out = self.deform_input(out, deformation) + + if occlusion_map is not None: + if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]: + occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear') + out = out * occlusion_map + + output_dict["deformed"] = self.deform_input(source_image, deformation) + + # Decoding part + out = self.bottleneck(out) + for i in range(len(self.up_blocks)): + out = self.up_blocks[i](out) + out = self.final(out) + out = F.sigmoid(out) + + output_dict["prediction"] = out + + return output_dict diff --git a/modules/keypoint_detector.py b/modules/keypoint_detector.py new file mode 100644 index 0000000..33f9f1d --- /dev/null +++ b/modules/keypoint_detector.py @@ -0,0 +1,75 @@ +from torch import nn +import torch +import torch.nn.functional as F +from modules.util import Hourglass, make_coordinate_grid, AntiAliasInterpolation2d + + +class KPDetector(nn.Module): + """ + Detecting a keypoints. Return keypoint position and jacobian near each keypoint. + """ + + def __init__(self, block_expansion, num_kp, num_channels, max_features, + num_blocks, temperature, estimate_jacobian=False, scale_factor=1, + single_jacobian_map=False, pad=0): + super(KPDetector, self).__init__() + + self.predictor = Hourglass(block_expansion, in_features=num_channels, + max_features=max_features, num_blocks=num_blocks) + + self.kp = nn.Conv2d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=(7, 7), + padding=pad) + + if estimate_jacobian: + self.num_jacobian_maps = 1 if single_jacobian_map else num_kp + self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters, + out_channels=4 * self.num_jacobian_maps, kernel_size=(7, 7), padding=pad) + self.jacobian.weight.data.zero_() + self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float)) + else: + self.jacobian = None + + self.temperature = temperature + self.scale_factor = scale_factor + if self.scale_factor != 1: + self.down = AntiAliasInterpolation2d(num_channels, self.scale_factor) + + def gaussian2kp(self, heatmap): + """ + Extract the mean and from a heatmap + """ + shape = heatmap.shape + heatmap = heatmap.unsqueeze(-1) + grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0) + value = (heatmap * grid).sum(dim=(2, 3)) + kp = {'value': value} + + return kp + + def forward(self, x): + if self.scale_factor != 1: + x = self.down(x) + + feature_map = self.predictor(x) + prediction = self.kp(feature_map) + + final_shape = prediction.shape + heatmap = prediction.view(final_shape[0], final_shape[1], -1) + heatmap = F.softmax(heatmap / self.temperature, dim=2) + heatmap = heatmap.view(*final_shape) + + out = self.gaussian2kp(heatmap) + + if self.jacobian is not None: + jacobian_map = self.jacobian(feature_map) + jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 4, final_shape[2], + final_shape[3]) + heatmap = heatmap.unsqueeze(2) + + jacobian = heatmap * jacobian_map + jacobian = jacobian.view(final_shape[0], final_shape[1], 4, -1) + jacobian = jacobian.sum(dim=-1) + jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 2, 2) + out['jacobian'] = jacobian + + return out diff --git a/modules/model.py b/modules/model.py new file mode 100644 index 0000000..7ee07c0 --- /dev/null +++ b/modules/model.py @@ -0,0 +1,259 @@ +from torch import nn +import torch +import torch.nn.functional as F +from modules.util import AntiAliasInterpolation2d, make_coordinate_grid +from torchvision import models +import numpy as np +from torch.autograd import grad + + +class Vgg19(torch.nn.Module): + """ + Vgg19 network for perceptual loss. See Sec 3.3. + """ + def __init__(self, requires_grad=False): + super(Vgg19, self).__init__() + vgg_pretrained_features = models.vgg19(pretrained=True).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + for x in range(2): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(2, 7): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(7, 12): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(12, 21): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(21, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + + self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))), + requires_grad=False) + self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))), + requires_grad=False) + + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + X = (X - self.mean) / self.std + h_relu1 = self.slice1(X) + h_relu2 = self.slice2(h_relu1) + h_relu3 = self.slice3(h_relu2) + h_relu4 = self.slice4(h_relu3) + h_relu5 = self.slice5(h_relu4) + out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] + return out + + +class ImagePyramide(torch.nn.Module): + """ + Create image pyramide for computing pyramide perceptual loss. See Sec 3.3 + """ + def __init__(self, scales, num_channels): + super(ImagePyramide, self).__init__() + downs = {} + for scale in scales: + downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale) + self.downs = nn.ModuleDict(downs) + + def forward(self, x): + out_dict = {} + for scale, down_module in self.downs.items(): + out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x) + return out_dict + + +class Transform: + """ + Random tps transformation for equivariance constraints. See Sec 3.3 + """ + def __init__(self, bs, **kwargs): + noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3])) + self.theta = noise + torch.eye(2, 3).view(1, 2, 3) + self.bs = bs + + if ('sigma_tps' in kwargs) and ('points_tps' in kwargs): + self.tps = True + self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type()) + self.control_points = self.control_points.unsqueeze(0) + self.control_params = torch.normal(mean=0, + std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2])) + else: + self.tps = False + + def transform_frame(self, frame): + grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0) + grid = grid.view(1, frame.shape[2] * frame.shape[3], 2) + grid = self.warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2) + return F.grid_sample(frame, grid, padding_mode="reflection") + + def warp_coordinates(self, coordinates): + theta = self.theta.type(coordinates.type()) + theta = theta.unsqueeze(1) + transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:] + transformed = transformed.squeeze(-1) + + if self.tps: + control_points = self.control_points.type(coordinates.type()) + control_params = self.control_params.type(coordinates.type()) + distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2) + distances = torch.abs(distances).sum(-1) + + result = distances ** 2 + result = result * torch.log(distances + 1e-6) + result = result * control_params + result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1) + transformed = transformed + result + + return transformed + + def jacobian(self, coordinates): + new_coordinates = self.warp_coordinates(coordinates) + grad_x = grad(new_coordinates[..., 0].sum(), coordinates, create_graph=True) + grad_y = grad(new_coordinates[..., 1].sum(), coordinates, create_graph=True) + jacobian = torch.cat([grad_x[0].unsqueeze(-2), grad_y[0].unsqueeze(-2)], dim=-2) + return jacobian + + +def detach_kp(kp): + return {key: value.detach() for key, value in kp.items()} + + +class GeneratorFullModel(torch.nn.Module): + """ + Merge all generator related updates into single model for better multi-gpu usage + """ + + def __init__(self, kp_extractor, generator, discriminator, train_params): + super(GeneratorFullModel, self).__init__() + self.kp_extractor = kp_extractor + self.generator = generator + self.discriminator = discriminator + self.train_params = train_params + self.scales = train_params['scales'] + self.disc_scales = self.discriminator.scales + self.pyramid = ImagePyramide(self.scales, generator.num_channels) + if torch.cuda.is_available(): + self.pyramid = self.pyramid.cuda() + + self.loss_weights = train_params['loss_weights'] + + if sum(self.loss_weights['perceptual']) != 0: + self.vgg = Vgg19() + if torch.cuda.is_available(): + self.vgg = self.vgg.cuda() + + def forward(self, x): + kp_source = self.kp_extractor(x['source']) + kp_driving = self.kp_extractor(x['driving']) + + generated = self.generator(x['source'], kp_source=kp_source, kp_driving=kp_driving) + generated.update({'kp_source': kp_source, 'kp_driving': kp_driving}) + + loss_values = {} + + pyramide_real = self.pyramid(x['driving']) + pyramide_generated = self.pyramid(generated['prediction']) + + if sum(self.loss_weights['perceptual']) != 0: + value_total = 0 + for scale in self.scales: + x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)]) + y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)]) + + for i, weight in enumerate(self.loss_weights['perceptual']): + value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean() + value_total += self.loss_weights['perceptual'][i] * value + loss_values['perceptual'] = value_total + + if self.loss_weights['generator_gan'] != 0: + discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving)) + discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving)) + value_total = 0 + for scale in self.disc_scales: + key = 'prediction_map_%s' % scale + value = ((1 - discriminator_maps_generated[key]) ** 2).mean() + value_total += self.loss_weights['generator_gan'] * value + loss_values['gen_gan'] = value_total + + if sum(self.loss_weights['feature_matching']) != 0: + value_total = 0 + for scale in self.disc_scales: + key = 'feature_maps_%s' % scale + for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])): + if self.loss_weights['feature_matching'][i] == 0: + continue + value = torch.abs(a - b).mean() + value_total += self.loss_weights['feature_matching'][i] * value + loss_values['feature_matching'] = value_total + + if (self.loss_weights['equivariance_value'] + self.loss_weights['equivariance_jacobian']) != 0: + transform = Transform(x['driving'].shape[0], **self.train_params['transform_params']) + transformed_frame = transform.transform_frame(x['driving']) + transformed_kp = self.kp_extractor(transformed_frame) + + generated['transformed_frame'] = transformed_frame + generated['transformed_kp'] = transformed_kp + + ## Value loss part + if self.loss_weights['equivariance_value'] != 0: + value = torch.abs(kp_driving['value'] - transform.warp_coordinates(transformed_kp['value'])).mean() + loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value + + ## jacobian loss part + if self.loss_weights['equivariance_jacobian'] != 0: + jacobian_transformed = torch.matmul(transform.jacobian(transformed_kp['value']), + transformed_kp['jacobian']) + + normed_driving = torch.inverse(kp_driving['jacobian']) + normed_transformed = jacobian_transformed + value = torch.matmul(normed_driving, normed_transformed) + + eye = torch.eye(2).view(1, 1, 2, 2).type(value.type()) + + value = torch.abs(eye - value).mean() + loss_values['equivariance_jacobian'] = self.loss_weights['equivariance_jacobian'] * value + + return loss_values, generated + + +class DiscriminatorFullModel(torch.nn.Module): + """ + Merge all discriminator related updates into single model for better multi-gpu usage + """ + + def __init__(self, kp_extractor, generator, discriminator, train_params): + super(DiscriminatorFullModel, self).__init__() + self.kp_extractor = kp_extractor + self.generator = generator + self.discriminator = discriminator + self.train_params = train_params + self.scales = self.discriminator.scales + self.pyramid = ImagePyramide(self.scales, generator.num_channels) + if torch.cuda.is_available(): + self.pyramid = self.pyramid.cuda() + + self.loss_weights = train_params['loss_weights'] + + def forward(self, x, generated): + pyramide_real = self.pyramid(x['driving']) + pyramide_generated = self.pyramid(generated['prediction'].detach()) + + kp_driving = generated['kp_driving'] + discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving)) + discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving)) + + loss_values = {} + value_total = 0 + for scale in self.scales: + key = 'prediction_map_%s' % scale + value = (1 - discriminator_maps_real[key]) ** 2 + discriminator_maps_generated[key] ** 2 + value_total += self.loss_weights['discriminator_gan'] * value.mean() + loss_values['disc_gan'] = value_total + + return loss_values diff --git a/modules/util.py b/modules/util.py new file mode 100644 index 0000000..831fea4 --- /dev/null +++ b/modules/util.py @@ -0,0 +1,245 @@ +from torch import nn + +import torch.nn.functional as F +import torch + +from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d + + +def kp2gaussian(kp, spatial_size, kp_variance): + """ + Transform a keypoint into gaussian like representation + """ + mean = kp['value'] + + coordinate_grid = make_coordinate_grid(spatial_size, mean.type()) + number_of_leading_dimensions = len(mean.shape) - 1 + shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape + coordinate_grid = coordinate_grid.view(*shape) + repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1) + coordinate_grid = coordinate_grid.repeat(*repeats) + + # Preprocess kp shape + shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 2) + mean = mean.view(*shape) + + mean_sub = (coordinate_grid - mean) + + out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance) + + return out + + +def make_coordinate_grid(spatial_size, type): + """ + Create a meshgrid [-1,1] x [-1,1] of given spatial_size. + """ + h, w = spatial_size + x = torch.arange(w).type(type) + y = torch.arange(h).type(type) + + x = (2 * (x / (w - 1)) - 1) + y = (2 * (y / (h - 1)) - 1) + + yy = y.view(-1, 1).repeat(1, w) + xx = x.view(1, -1).repeat(h, 1) + + meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2) + + return meshed + + +class ResBlock2d(nn.Module): + """ + Res block, preserve spatial resolution. + """ + + def __init__(self, in_features, kernel_size, padding): + super(ResBlock2d, self).__init__() + self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, + padding=padding) + self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size, + padding=padding) + self.norm1 = BatchNorm2d(in_features, affine=True) + self.norm2 = BatchNorm2d(in_features, affine=True) + + def forward(self, x): + out = self.norm1(x) + out = F.relu(out) + out = self.conv1(out) + out = self.norm2(out) + out = F.relu(out) + out = self.conv2(out) + out += x + return out + + +class UpBlock2d(nn.Module): + """ + Upsampling block for use in decoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(UpBlock2d, self).__init__() + + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups) + self.norm = BatchNorm2d(out_features, affine=True) + + def forward(self, x): + out = F.interpolate(x, scale_factor=2) + out = self.conv(out) + out = self.norm(out) + out = F.relu(out) + return out + + +class DownBlock2d(nn.Module): + """ + Downsampling block for use in encoder. + """ + + def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1): + super(DownBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size, + padding=padding, groups=groups) + self.norm = BatchNorm2d(out_features, affine=True) + self.pool = nn.AvgPool2d(kernel_size=(2, 2)) + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = F.relu(out) + out = self.pool(out) + return out + + +class SameBlock2d(nn.Module): + """ + Simple block, preserve spatial resolution. + """ + + def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1): + super(SameBlock2d, self).__init__() + self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, + kernel_size=kernel_size, padding=padding, groups=groups) + self.norm = BatchNorm2d(out_features, affine=True) + + def forward(self, x): + out = self.conv(x) + out = self.norm(out) + out = F.relu(out) + return out + + +class Encoder(nn.Module): + """ + Hourglass Encoder + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Encoder, self).__init__() + + down_blocks = [] + for i in range(num_blocks): + down_blocks.append(DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)), + min(max_features, block_expansion * (2 ** (i + 1))), + kernel_size=3, padding=1)) + self.down_blocks = nn.ModuleList(down_blocks) + + def forward(self, x): + outs = [x] + for down_block in self.down_blocks: + outs.append(down_block(outs[-1])) + return outs + + +class Decoder(nn.Module): + """ + Hourglass Decoder + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Decoder, self).__init__() + + up_blocks = [] + + for i in range(num_blocks)[::-1]: + in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1))) + out_filters = min(max_features, block_expansion * (2 ** i)) + up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1)) + + self.up_blocks = nn.ModuleList(up_blocks) + self.out_filters = block_expansion + in_features + + def forward(self, x): + out = x.pop() + for up_block in self.up_blocks: + out = up_block(out) + skip = x.pop() + out = torch.cat([out, skip], dim=1) + return out + + +class Hourglass(nn.Module): + """ + Hourglass architecture. + """ + + def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256): + super(Hourglass, self).__init__() + self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features) + self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features) + self.out_filters = self.decoder.out_filters + + def forward(self, x): + return self.decoder(self.encoder(x)) + + +class AntiAliasInterpolation2d(nn.Module): + """ + Band-limited downsampling, for better preservation of the input signal. + """ + def __init__(self, channels, scale): + super(AntiAliasInterpolation2d, self).__init__() + sigma = (1 / scale - 1) / 2 + kernel_size = 2 * round(sigma * 4) + 1 + self.ka = kernel_size // 2 + self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka + + kernel_size = [kernel_size, kernel_size] + sigma = [sigma, sigma] + # The gaussian kernel is the product of the + # gaussian function of each dimension. + kernel = 1 + meshgrids = torch.meshgrid( + [ + torch.arange(size, dtype=torch.float32) + for size in kernel_size + ] + ) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2)) + + # Make sure sum of values in gaussian kernel equals 1. + kernel = kernel / torch.sum(kernel) + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) + + self.register_buffer('weight', kernel) + self.groups = channels + self.scale = scale + inv_scale = 1 / scale + self.int_inv_scale = int(inv_scale) + + def forward(self, input): + if self.scale == 1.0: + return input + + out = F.pad(input, (self.ka, self.kb, self.ka, self.kb)) + out = F.conv2d(out, weight=self.weight, groups=self.groups) + out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale] + + return out diff --git a/resources/img/test1.jpg b/resources/img/test1.jpg new file mode 100644 index 0000000..756a1f5 Binary files /dev/null and b/resources/img/test1.jpg differ diff --git a/resources/img/white board.jpg b/resources/img/white board.jpg new file mode 100644 index 0000000..6b88f15 Binary files /dev/null and b/resources/img/white board.jpg differ diff --git a/resources/output/change.png b/resources/output/change.png new file mode 100644 index 0000000..112309e Binary files /dev/null and b/resources/output/change.png differ diff --git a/resources/output/input.png b/resources/output/input.png new file mode 100644 index 0000000..951519c Binary files /dev/null and b/resources/output/input.png differ