You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
98 lines
4.5 KiB
98 lines
4.5 KiB
5 months ago
|
import torch
|
||
|
from torch import nn
|
||
|
import torch.nn.functional as F
|
||
|
from modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d
|
||
|
from modules.dense_motion import DenseMotionNetwork
|
||
|
|
||
|
|
||
|
class OcclusionAwareGenerator(nn.Module):
|
||
|
"""
|
||
|
Generator that given source image and and keypoints try to transform image according to movement trajectories
|
||
|
induced by keypoints. Generator follows Johnson architecture.
|
||
|
"""
|
||
|
|
||
|
def __init__(self, num_channels, num_kp, block_expansion, max_features, num_down_blocks,
|
||
|
num_bottleneck_blocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):
|
||
|
super(OcclusionAwareGenerator, self).__init__()
|
||
|
|
||
|
if dense_motion_params is not None:
|
||
|
self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, num_channels=num_channels,
|
||
|
estimate_occlusion_map=estimate_occlusion_map,
|
||
|
**dense_motion_params)
|
||
|
else:
|
||
|
self.dense_motion_network = None
|
||
|
|
||
|
self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3))
|
||
|
|
||
|
down_blocks = []
|
||
|
for i in range(num_down_blocks):
|
||
|
in_features = min(max_features, block_expansion * (2 ** i))
|
||
|
out_features = min(max_features, block_expansion * (2 ** (i + 1)))
|
||
|
down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
|
||
|
self.down_blocks = nn.ModuleList(down_blocks)
|
||
|
|
||
|
up_blocks = []
|
||
|
for i in range(num_down_blocks):
|
||
|
in_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i)))
|
||
|
out_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i - 1)))
|
||
|
up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
|
||
|
self.up_blocks = nn.ModuleList(up_blocks)
|
||
|
|
||
|
self.bottleneck = torch.nn.Sequential()
|
||
|
in_features = min(max_features, block_expansion * (2 ** num_down_blocks))
|
||
|
for i in range(num_bottleneck_blocks):
|
||
|
self.bottleneck.add_module('r' + str(i), ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1)))
|
||
|
|
||
|
self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3))
|
||
|
self.estimate_occlusion_map = estimate_occlusion_map
|
||
|
self.num_channels = num_channels
|
||
|
|
||
|
def deform_input(self, inp, deformation):
|
||
|
_, h_old, w_old, _ = deformation.shape
|
||
|
_, _, h, w = inp.shape
|
||
|
if h_old != h or w_old != w:
|
||
|
deformation = deformation.permute(0, 3, 1, 2)
|
||
|
deformation = F.interpolate(deformation, size=(h, w), mode='bilinear')
|
||
|
deformation = deformation.permute(0, 2, 3, 1)
|
||
|
return F.grid_sample(inp, deformation)
|
||
|
|
||
|
def forward(self, source_image, kp_driving, kp_source):
|
||
|
# Encoding (downsampling) part
|
||
|
out = self.first(source_image)
|
||
|
for i in range(len(self.down_blocks)):
|
||
|
out = self.down_blocks[i](out)
|
||
|
|
||
|
# Transforming feature representation according to deformation and occlusion
|
||
|
output_dict = {}
|
||
|
if self.dense_motion_network is not None:
|
||
|
dense_motion = self.dense_motion_network(source_image=source_image, kp_driving=kp_driving,
|
||
|
kp_source=kp_source)
|
||
|
output_dict['mask'] = dense_motion['mask']
|
||
|
output_dict['sparse_deformed'] = dense_motion['sparse_deformed']
|
||
|
|
||
|
if 'occlusion_map' in dense_motion:
|
||
|
occlusion_map = dense_motion['occlusion_map']
|
||
|
output_dict['occlusion_map'] = occlusion_map
|
||
|
else:
|
||
|
occlusion_map = None
|
||
|
deformation = dense_motion['deformation']
|
||
|
out = self.deform_input(out, deformation)
|
||
|
|
||
|
if occlusion_map is not None:
|
||
|
if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
|
||
|
occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
|
||
|
out = out * occlusion_map
|
||
|
|
||
|
output_dict["deformed"] = self.deform_input(source_image, deformation)
|
||
|
|
||
|
# Decoding part
|
||
|
out = self.bottleneck(out)
|
||
|
for i in range(len(self.up_blocks)):
|
||
|
out = self.up_blocks[i](out)
|
||
|
out = self.final(out)
|
||
|
out = F.sigmoid(out)
|
||
|
|
||
|
output_dict["prediction"] = out
|
||
|
|
||
|
return output_dict
|