import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
import torchvision.models as models
import copy
import numpy as np
# import lpips


####################################################################################################
# pretrained VGG16 feature extractor (used by the perceptual / style losses below) and
# spectral normalization layer to decouple the magnitude of a weight tensor
####################################################################################################

class VGG16(nn.Module):
    """Pretrained VGG16 split into named slices so that the intermediate
    activations (relu1_1 ... relu5_3) can be reused by the losses below."""

    def __init__(self):
        super(VGG16, self).__init__()
        features = models.vgg16(pretrained=True).features
        self.relu1_1 = torch.nn.Sequential()
        self.relu1_2 = torch.nn.Sequential()

        self.relu2_1 = torch.nn.Sequential()
        self.relu2_2 = torch.nn.Sequential()

        self.relu3_1 = torch.nn.Sequential()
        self.relu3_2 = torch.nn.Sequential()
        self.relu3_3 = torch.nn.Sequential()
        self.max3 = torch.nn.Sequential()

        self.relu4_1 = torch.nn.Sequential()
        self.relu4_2 = torch.nn.Sequential()
        self.relu4_3 = torch.nn.Sequential()

        self.relu5_1 = torch.nn.Sequential()
        self.relu5_2 = torch.nn.Sequential()
        self.relu5_3 = torch.nn.Sequential()

        for x in range(2):
            self.relu1_1.add_module(str(x), features[x])

        for x in range(2, 4):
            self.relu1_2.add_module(str(x), features[x])

        for x in range(4, 7):
            self.relu2_1.add_module(str(x), features[x])

        for x in range(7, 9):
            self.relu2_2.add_module(str(x), features[x])

        for x in range(9, 12):
            self.relu3_1.add_module(str(x), features[x])

        for x in range(12, 14):
            self.relu3_2.add_module(str(x), features[x])

        for x in range(14, 16):
            self.relu3_3.add_module(str(x), features[x])

        for x in range(16, 17):
            self.max3.add_module(str(x), features[x])

        for x in range(17, 19):
            self.relu4_1.add_module(str(x), features[x])

        for x in range(19, 21):
            self.relu4_2.add_module(str(x), features[x])

        for x in range(21, 23):
            self.relu4_3.add_module(str(x), features[x])

        for x in range(23, 26):
            self.relu5_1.add_module(str(x), features[x])

        for x in range(26, 28):
            self.relu5_2.add_module(str(x), features[x])

        for x in range(28, 30):
            self.relu5_3.add_module(str(x), features[x])

        # don't need the gradients, just want the features
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        relu1_1 = self.relu1_1(x)
        relu1_2 = self.relu1_2(relu1_1)

        relu2_1 = self.relu2_1(relu1_2)
        relu2_2 = self.relu2_2(relu2_1)

        relu3_1 = self.relu3_1(relu2_2)
        relu3_2 = self.relu3_2(relu3_1)
        relu3_3 = self.relu3_3(relu3_2)
        max_3 = self.max3(relu3_3)

        relu4_1 = self.relu4_1(max_3)
        relu4_2 = self.relu4_2(relu4_1)
        relu4_3 = self.relu4_3(relu4_2)

        relu5_1 = self.relu5_1(relu4_3)
        relu5_2 = self.relu5_2(relu5_1)
        relu5_3 = self.relu5_3(relu5_2)
        out = {
            "relu1_1": relu1_1,
            "relu1_2": relu1_2,
            "relu2_1": relu2_1,
            "relu2_2": relu2_2,
            "relu3_1": relu3_1,
            "relu3_2": relu3_2,
            "relu3_3": relu3_3,
            "max_3": max_3,
            "relu4_1": relu4_1,
            "relu4_2": relu4_2,
            "relu4_3": relu4_3,
            "relu5_1": relu5_1,
            "relu5_2": relu5_2,
            "relu5_3": relu5_3,
        }
        return out

def l2normalize(v, eps=1e-12):
    return v / (v.norm() + eps)

class SpectralNorm(nn.Module):
    """
    spectral normalization
    code and idea originally from Takeru Miyato's work 'Spectral Normalization for GAN'
    https://github.com/christiancosgrove/pytorch-spectral-normalization-gan
    """

    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))

        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        try:
            u = getattr(self.module, self.name + "_u")
            v = getattr(self.module, self.name + "_v")
            w = getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False

    def _make_params(self):
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)
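
# Usage sketch (illustrative, not part of the original module): wrapping a layer in SpectralNorm
# re-estimates its largest singular value by power iteration and rescales the weight on every forward.
#
#     sn_conv = SpectralNorm(nn.Conv2d(64, 128, kernel_size=3, padding=1))
#     y = sn_conv(torch.randn(1, 64, 32, 32))   # y: (1, 128, 32, 32)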

####################################################################################################
# adversarial losses for the different gan modes
####################################################################################################

class GANLoss(nn.Module):
    """Define different GAN objectives.

    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """ Initialize the GANLoss class.

        Parameters:
            gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, hinge, wgangp and wgandiv.
            target_real_label (float) - - label for a real image
            target_fake_label (float) - - label for a fake image

        Note: Do not use sigmoid as the last layer of the Discriminator.
        LSGAN needs no sigmoid; vanilla GANs handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'hinge':
            self.loss = nn.ReLU()
        elif gan_mode == 'wgangp':
            self.loss = None
        elif gan_mode == 'wgandiv':
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def __call__(self, prediction, target_is_real, is_disc=False):
        """Calculate loss given the Discriminator's output and ground-truth labels.

        Parameters:
            prediction (tensor) - - typically the prediction output from a discriminator
            target_is_real (bool) - - whether the ground-truth label is for real or fake images
            is_disc (bool) - - whether the loss is computed for the discriminator (hinge / wgangp only)

        Returns:
            the calculated loss.
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            labels = (self.real_label if target_is_real else self.fake_label).expand_as(prediction).type_as(prediction)
            loss = self.loss(prediction, labels)
        elif self.gan_mode in ['hinge', 'wgangp']:
            if is_disc:
                if target_is_real:
                    prediction = -prediction
                if self.gan_mode == 'hinge':
                    loss = self.loss(1 + prediction).mean()
                elif self.gan_mode == 'wgangp':
                    loss = prediction.mean()
            else:
                loss = -prediction.mean()
        elif self.gan_mode in ['wgandiv']:
            loss = prediction.mean()

        return loss
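
# Usage sketch (illustrative; `netD`, `real` and `fake` are placeholders, not defined in this file):
#
#     criterion_gan = GANLoss('hinge')
#     # discriminator step
#     loss_d = criterion_gan(netD(real), True, is_disc=True) + \
#              criterion_gan(netD(fake.detach()), False, is_disc=True)
#     # generator step
#     loss_g = criterion_gan(netD(fake), True, is_disc=False)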

class PD_Loss(nn.Module):

    def __init__(self):
        super(PD_Loss, self).__init__()
        self.criterion = torch.nn.L1Loss()

    def __call__(self, x, y):
        # Compute features
        pd_loss = self.criterion(x, y)
        return pd_loss

class TV_Loss(nn.Module):
    """Total-variation loss evaluated on a 1-pixel dilation of the hole region of the mask."""

    def __init__(self):
        super(TV_Loss, self).__init__()

    def __call__(self, image, mask, method):
        hole_mask = 1 - mask
        b, ch, h, w = hole_mask.shape
        # an all-ones 3x3 convolution dilates the hole region by one pixel
        dilation_conv = nn.Conv2d(ch, ch, 3, padding=1, bias=False).to(hole_mask)
        torch.nn.init.constant_(dilation_conv.weight, 1.0)
        with torch.no_grad():
            output_mask = dilation_conv(hole_mask)
        updated_holes = output_mask != 0
        dilated_holes = updated_holes.float()
        columns_in_Pset = dilated_holes[:, :, :, 1:] * dilated_holes[:, :, :, :-1]
        rows_in_Pset = dilated_holes[:, :, 1:, :] * dilated_holes[:, :, :-1, :]
        # absolute differences between horizontally / vertically adjacent pixels inside the dilated holes
        if method == "sum":
            loss = torch.sum(
                torch.abs(columns_in_Pset * (image[:, :, :, 1:] - image[:, :, :, :-1]))
            ) + torch.sum(
                torch.abs(rows_in_Pset * (image[:, :, 1:, :] - image[:, :, :-1, :]))
            )
        else:
            loss = torch.mean(
                torch.abs(columns_in_Pset * (image[:, :, :, 1:] - image[:, :, :, :-1]))
            ) + torch.mean(
                torch.abs(rows_in_Pset * (image[:, :, 1:, :] - image[:, :, :-1, :]))
            )
        return loss
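
# Usage sketch (illustrative): `mask` is 1 on known pixels and 0 inside the hole; the loss is
# evaluated only on a 1-pixel dilation of the hole region.
#
#     tv = TV_Loss()
#     image = torch.rand(2, 3, 64, 64)
#     mask = torch.ones(2, 3, 64, 64)
#     mask[:, :, 16:48, 16:48] = 0
#     loss = tv(image, mask, method="mean")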

def cal_gradient_penalty(netD, real_data, fake_data, type='mixed', constant=1.0, lambda_gp=10.0):
    """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028

    Arguments:
        netD (network)           -- discriminator network
        real_data (tensor array) -- real images
        fake_data (tensor array) -- generated images from the generator
        type (str)               -- whether we mix real and fake data or not [real | fake | mixed]
        constant (float)         -- the constant used in the formula (||gradient||_2 - constant)^2
        lambda_gp (float)        -- weight for this loss

    Returns the gradient penalty loss
    """
    if lambda_gp > 0.0:
        if type == 'real':   # either use real images, fake images, or a linear interpolation of the two
            interpolatesv = real_data
        elif type == 'fake':
            interpolatesv = fake_data
        elif type == 'mixed':
            alpha = torch.rand(real_data.shape[0], 1)
            alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
            alpha = alpha.type_as(real_data)
            interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
        else:
            raise NotImplementedError('{} not implemented'.format(type))
        interpolatesv.requires_grad_(True)
        disc_interpolates = netD(interpolatesv)
        gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
                                        grad_outputs=torch.ones(disc_interpolates.size()).type_as(real_data),
                                        create_graph=True, retain_graph=True, only_inputs=True)
        gradients = gradients[0].view(real_data.size(0), -1)  # flatten the data
        gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp  # added eps
        return gradient_penalty, gradients
    else:
        return 0.0, None
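
# Usage sketch (illustrative; `netD`, `real` and `fake` are placeholders) for a WGAN-GP critic step:
#
#     gp, _ = cal_gradient_penalty(netD, real, fake.detach(), type='mixed', lambda_gp=10.0)
#     loss_d = netD(fake.detach()).mean() - netD(real).mean() + gp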

def cal_gradient_penalty_div(netD, real_data, fake_data, type='mixed', const_power=6.0, const_kappa=2.0):
    """Gradient penalty term of WGAN-div (https://arxiv.org/abs/1712.01026):
    kappa * E[||gradient||_2 ** p] with p = const_power and kappa = const_kappa."""
    if type == 'real':   # either use real images, fake images, or a linear interpolation of the two
        interpolatesv = real_data
    elif type == 'fake':
        interpolatesv = fake_data
    elif type == 'mixed':
        alpha = torch.rand(real_data.shape[0], 1)
        alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(
            *real_data.shape)
        alpha = alpha.type_as(real_data)
        interpolatesv = (1 - alpha) * real_data + alpha * fake_data
    else:
        raise NotImplementedError('{} not implemented'.format(type))
    interpolatesv.requires_grad_(True)
    disc_interpolates = netD(interpolatesv)
    gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
                                    grad_outputs=torch.ones(disc_interpolates.size()).type_as(real_data),
                                    create_graph=True, retain_graph=True, only_inputs=True)
    gradients = gradients[0].view(real_data.size(0), -1)  # flatten the data
    gradients_penalty_div = torch.pow((gradients + 1e-16).norm(2, dim=1), const_power).mean() * const_kappa
    return gradients_penalty_div

####################################################################################################
# neural style transfer losses from the neural_style_tutorial of pytorch
####################################################################################################

def ContentLoss(input, target):
    target = target.detach()
    loss = F.l1_loss(input, target)
    return loss


def GramMatrix(input):
    s = input.size()
    features = input.view(s[0], s[1], s[2] * s[3])
    features_t = torch.transpose(features, 1, 2)
    G = torch.bmm(features, features_t).div(s[1] * s[2] * s[3])
    return G


def StyleLoss(input, target):
    target = GramMatrix(target).detach()
    input = GramMatrix(input)
    loss = F.l1_loss(input, target)
    return loss
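
# Usage sketch (illustrative; the feature maps are random stand-ins for activations from get_features):
#
#     feat_x = torch.randn(1, 128, 56, 56)
#     feat_y = torch.randn(1, 128, 56, 56)
#     c_loss = ContentLoss(feat_x, feat_y)
#     s_loss = StyleLoss(feat_x, feat_y)   # note: this function is shadowed by the StyleLoss class below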

def img_crop(input, size=224):
    # despite the name, this resizes (rather than crops) the input to size x size for the fixed VGG input
    input_cropped = F.interpolate(input, size=(size, size), mode='bilinear', align_corners=True)
    return input_cropped

class Normalization(nn.Module):
    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        self.mean = mean.view(-1, 1, 1)
        self.std = std.view(-1, 1, 1)

    def forward(self, input):
        return (input - self.mean) / self.std


class get_features(nn.Module):
    def __init__(self, cnn):
        super(get_features, self).__init__()

        vgg = copy.deepcopy(cnn)

        self.conv1 = nn.Sequential(vgg[0], vgg[1], vgg[2], vgg[3], vgg[4])
        self.conv2 = nn.Sequential(vgg[5], vgg[6], vgg[7], vgg[8], vgg[9])
        self.conv3 = nn.Sequential(vgg[10], vgg[11], vgg[12], vgg[13], vgg[14], vgg[15], vgg[16])
        self.conv4 = nn.Sequential(vgg[17], vgg[18], vgg[19], vgg[20], vgg[21], vgg[22], vgg[23])
        self.conv5 = nn.Sequential(vgg[24], vgg[25], vgg[26], vgg[27], vgg[28], vgg[29], vgg[30])

    def forward(self, input, layers):
        input = img_crop(input)
        output = []
        for i in range(1, layers):
            layer = getattr(self, 'conv' + str(i))
            input = layer(input)
            output.append(input)
        return output

class GroupNorm(nn.Module):
    def __init__(self, num_features, num_groups=32, eps=1e-5):
        super(GroupNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.num_groups = num_groups
        self.eps = eps

    def forward(self, x):
        N, C, H, W = x.size()
        G = self.num_groups
        # assert C % G == 0

        x = x.view(N, G, -1)
        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, keepdim=True)

        x = (x - mean) / (var + self.eps).sqrt()
        x = x.view(N, C, H, W)
        return x * self.weight + self.bias
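
# Usage sketch (illustrative): the channel count should be divisible by num_groups.
#
#     gn = GroupNorm(num_features=64, num_groups=32)
#     y = gn(torch.randn(4, 64, 16, 16))   # y: (4, 64, 16, 16)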

class FullAttention(nn.Module):
    """
    Layer implementing a generalized self-attention module.
    It is mostly the same as standard self-attention, but generalizes to
    (k x k) convolutions instead of (1 x 1).

    args:
        in_channels: number of input channels
        out_channels: number of output channels
        activation: activation function to be applied (default: lrelu(0.2))
        kernel_size: kernel size for the convolutions (default: (3 x 3))
        transpose_conv: whether to use transposed convolutions instead of convolutions
        use_spectral_norm: whether to wrap the residual convolution in SpectralNorm (default: True)
        use_batch_norm: whether to normalize the residual branch (default: True)
        squeeze_factor: squeeze factor for queries and keys (default: 8)
        stride: stride for the convolutions (default: 1)
        padding: padding for the applied convolutions (default: 1)
        bias: whether to apply bias or not (default: True)
    """

    def __init__(self, in_channels, out_channels,
                 activation=nn.LeakyReLU(0.2), kernel_size=(3, 3), transpose_conv=False,
                 use_spectral_norm=True, use_batch_norm=True,
                 squeeze_factor=8, stride=1, padding=1, bias=True):
        """ constructor for the layer """

        from torch.nn import Conv2d, Parameter, \
            Softmax, ConvTranspose2d, BatchNorm2d, InstanceNorm2d

        # base constructor call
        super().__init__()

        # state of the layer
        self.activation = activation
        self.gamma = Parameter(torch.zeros(1))

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.squeezed_channels = in_channels // squeeze_factor
        self.use_batch_norm = use_batch_norm

        # Modules required for computations
        if transpose_conv:
            self.query_conv = ConvTranspose2d(  # query convolution
                in_channels=in_channels,
                out_channels=in_channels // squeeze_factor,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias
            )

            self.key_conv = ConvTranspose2d(
                in_channels=in_channels,
                out_channels=in_channels // squeeze_factor,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias
            )

            self.value_conv = ConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias
            )

            self.residual_conv = ConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias
            ) if not use_spectral_norm else SpectralNorm(
                ConvTranspose2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    bias=bias
                )
            )

        else:
            self.query_conv = Conv2d(  # query convolution
                in_channels=in_channels,
                out_channels=in_channels // squeeze_factor,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias
            )

            self.key_conv = Conv2d(
                in_channels=in_channels,
                out_channels=in_channels // squeeze_factor,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias
            )

            self.value_conv = Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias
            )

            self.residual_conv = Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias
            ) if not use_spectral_norm else SpectralNorm(
                Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    bias=bias
                )
            )

        # softmax module for applying attention
        self.softmax = Softmax(dim=-1)
        # self.batch_norm = BatchNorm2d(out_channels)
        self.batch_norm = InstanceNorm2d(out_channels)

    def forward(self, x):
        """
        forward computations of the layer
        :param x: input feature maps (B x C x H x W)
        :return:
            out: self-attention value blended with the residual branch (B x O x H x W)
            attention: attention map (B x N x H x W, with N = H*W)
        """

        # extract the batch size of the input tensor
        m_batchsize, _, _, _ = x.size()

        # create the query projection
        proj_query = self.query_conv(x).view(
            m_batchsize, self.squeezed_channels, -1).permute(0, 2, 1)  # B x (N) x C

        # create the key projection
        proj_key = self.key_conv(x).view(
            m_batchsize, self.squeezed_channels, -1)  # B x C x (N)

        # calculate the attention maps
        energy = torch.bmm(proj_query, proj_key)  # energy
        attention = self.softmax(energy)  # attention (B x (N) x (N))

        # create the value projection
        proj_value = self.value_conv(x).view(
            m_batchsize, self.out_channels, -1)  # B x C x N

        # calculate the output
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))

        # calculate the residual output
        res_out = self.residual_conv(x)

        out = out.view(m_batchsize, self.out_channels,
                       res_out.shape[-2], res_out.shape[-1])

        attention = attention.view(m_batchsize, -1,
                                   res_out.shape[-2], res_out.shape[-1])

        if self.use_batch_norm:
            res_out = self.batch_norm(res_out)

        if self.activation is not None:
            out = self.activation(out)
            res_out = self.activation(res_out)

        # apply the residual connections
        out = (self.gamma * out) + ((1 - self.gamma) * res_out)
        return out, attention
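
# Usage sketch (illustrative): in_channels should be divisible by squeeze_factor (default 8).
#
#     attn = FullAttention(in_channels=64, out_channels=64)
#     out, attention = attn(torch.randn(2, 64, 32, 32))
#     # out: (2, 64, 32, 32), attention: (2, 32*32, 32, 32)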

class Diversityloss(nn.Module):
    def __init__(self):
        super(Diversityloss, self).__init__()
        self.vgg = VGG16().cuda()
        self.criterion = nn.L1Loss()
        self.weights = [1.0, 1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0]

    def forward(self, x, y):
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        diversity_loss = 0.0
        diversity_loss += self.weights[4] * self.criterion(
            x_vgg["relu4_1"], y_vgg["relu4_1"]
        )
        return diversity_loss

class PerceptualLoss(nn.Module):
    r"""
    Perceptual loss, VGG-based
    https://arxiv.org/abs/1603.08155
    https://github.com/dxyang/StyleTransfer/blob/master/utils.py
    """

    def __init__(self, weights=[1.0 / 2, 1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0]):
        super(PerceptualLoss, self).__init__()
        self.add_module("vgg", VGG16().cuda())
        self.criterion = torch.nn.L1Loss()
        self.weights = weights

    def __call__(self, x, y):
        # Compute features
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)

        content_loss = 0.0
        content_loss += self.weights[0] * self.criterion(
            x_vgg["relu1_1"], y_vgg["relu1_1"]
        )
        content_loss += self.weights[1] * self.criterion(
            x_vgg["relu2_1"], y_vgg["relu2_1"]
        )
        content_loss += self.weights[2] * self.criterion(
            x_vgg["relu3_1"], y_vgg["relu3_1"]
        )
        content_loss += self.weights[3] * self.criterion(
            x_vgg["relu4_1"], y_vgg["relu4_1"]
        )
        content_loss += self.weights[4] * self.criterion(
            x_vgg["relu5_1"], y_vgg["relu5_1"]
        )
        return content_loss

class StyleLoss(nn.Module):
    r"""
    Style loss, VGG-based (L1 distance between Gram matrices of feature maps)
    https://arxiv.org/abs/1603.08155
    https://github.com/dxyang/StyleTransfer/blob/master/utils.py
    """

    def __init__(self):
        super(StyleLoss, self).__init__()
        self.add_module("vgg", VGG16().cuda())
        self.criterion = torch.nn.L1Loss()

    def compute_gram(self, x):
        b, ch, h, w = x.size()
        f = x.view(b, ch, w * h)
        f_T = f.transpose(1, 2)
        G = f.bmm(f_T) / (h * w * ch)

        return G

    def __call__(self, x, y):
        # Compute features
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)

        # Compute loss
        style_loss = 0.0
        style_loss += self.criterion(
            self.compute_gram(x_vgg["relu2_2"]), self.compute_gram(y_vgg["relu2_2"])
        )
        style_loss += self.criterion(
            self.compute_gram(x_vgg["relu3_3"]), self.compute_gram(y_vgg["relu3_3"])
        )
        style_loss += self.criterion(
            self.compute_gram(x_vgg["relu4_3"]), self.compute_gram(y_vgg["relu4_3"])
        )
        return style_loss
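
# Usage sketch (illustrative; needs a CUDA device and the pretrained torchvision VGG16 weights,
# since both losses move VGG16 to .cuda()). The relative weighting below is a placeholder.
#
#     perceptual = PerceptualLoss()
#     style = StyleLoss()
#     pred = torch.rand(1, 3, 256, 256).cuda()
#     gt = torch.rand(1, 3, 256, 256).cuda()
#     loss = perceptual(pred, gt) + 250 * style(pred, gt)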

def reduce_mean(x, axis=None, keepdim=False):
    if axis is None:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.mean(x, dim=i, keepdim=keepdim)
    return x

def same_padding(images, ksizes, strides, rates):
    assert len(images.size()) == 4
    batch_size, channel, rows, cols = images.size()
    out_rows = (rows + strides[0] - 1) // strides[0]
    out_cols = (cols + strides[1] - 1) // strides[1]
    effective_k_row = (ksizes[0] - 1) * rates[0] + 1
    effective_k_col = (ksizes[1] - 1) * rates[1] + 1
    padding_rows = max(0, (out_rows - 1) * strides[0] + effective_k_row - rows)
    padding_cols = max(0, (out_cols - 1) * strides[1] + effective_k_col - cols)
    # Pad the input
    padding_top = int(padding_rows / 2.)
    padding_left = int(padding_cols / 2.)
    padding_bottom = padding_rows - padding_top
    padding_right = padding_cols - padding_left
    paddings = (padding_left, padding_right, padding_top, padding_bottom)
    images = torch.nn.ZeroPad2d(paddings)(images)
    return images

def extract_image_patches(images, ksizes, strides, rates, padding='same'):
    """
    Extract patches from images and put them in the C output dimension.
    :param images: A 4-D Tensor with shape [batch, channels, in_rows, in_cols]
    :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
                   each dimension of images
    :param strides: [stride_rows, stride_cols]
    :param rates: [dilation_rows, dilation_cols]
    :param padding: 'same' or 'valid'
    :return: A Tensor
    """
    assert len(images.size()) == 4
    assert padding in ['same', 'valid']
    batch_size, channel, height, width = images.size()

    if padding == 'same':
        images = same_padding(images, ksizes, strides, rates)
    elif padding == 'valid':
        pass
    else:
        raise NotImplementedError('Unsupported padding type: {}.\
                Only "same" or "valid" are supported.'.format(padding))

    unfold = torch.nn.Unfold(kernel_size=ksizes,
                             dilation=rates,
                             padding=0,
                             stride=strides)
    patches = unfold(images)
    return patches  # [N, C*k*k, L], L is the total number of such blocks
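
# Usage sketch (illustrative): 3x3 patches at stride 1 with 'same' padding.
#
#     imgs = torch.randn(2, 16, 32, 32)
#     patches = extract_image_patches(imgs, ksizes=[3, 3], strides=[1, 1], rates=[1, 1])
#     # patches: (2, 16*3*3, 32*32)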

def flow_to_image(flow):
    """Transfer flow map to image.
    Part of code forked from flownet.
    """
    out = []
    maxu = -999.
    maxv = -999.
    minu = 999.
    minv = 999.
    maxrad = -1
    for i in range(flow.shape[0]):
        u = flow[i, :, :, 0]
        v = flow[i, :, :, 1]
        idxunknow = (abs(u) > 1e7) | (abs(v) > 1e7)
        u[idxunknow] = 0
        v[idxunknow] = 0
        maxu = max(maxu, np.max(u))
        minu = min(minu, np.min(u))
        maxv = max(maxv, np.max(v))
        minv = min(minv, np.min(v))
        rad = np.sqrt(u ** 2 + v ** 2)
        maxrad = max(maxrad, np.max(rad))
        u = u / (maxrad + np.finfo(float).eps)
        v = v / (maxrad + np.finfo(float).eps)
        img = compute_color(u, v)
        out.append(img)
    return np.float32(np.uint8(out))


def make_color_wheel():
    RY, YG, GC, CB, BM, MR = (15, 6, 4, 11, 13, 6)
    ncols = RY + YG + GC + CB + BM + MR
    colorwheel = np.zeros([ncols, 3])
    col = 0
    # RY
    colorwheel[0:RY, 0] = 255
    colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY))
    col += RY
    # YG
    colorwheel[col:col + YG, 0] = 255 - np.transpose(np.floor(255 * np.arange(0, YG) / YG))
    colorwheel[col:col + YG, 1] = 255
    col += YG
    # GC
    colorwheel[col:col + GC, 1] = 255
    colorwheel[col:col + GC, 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC))
    col += GC
    # CB
    colorwheel[col:col + CB, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, CB) / CB))
    colorwheel[col:col + CB, 2] = 255
    col += CB
    # BM
    colorwheel[col:col + BM, 2] = 255
    colorwheel[col:col + BM, 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM))
    col += BM
    # MR
    colorwheel[col:col + MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
    colorwheel[col:col + MR, 0] = 255
    return colorwheel


def compute_color(u, v):
    h, w = u.shape
    img = np.zeros([h, w, 3])
    nanIdx = np.isnan(u) | np.isnan(v)
    u[nanIdx] = 0
    v[nanIdx] = 0
    # colorwheel = COLORWHEEL
    colorwheel = make_color_wheel()
    ncols = np.size(colorwheel, 0)
    rad = np.sqrt(u ** 2 + v ** 2)
    a = np.arctan2(-v, -u) / np.pi
    fk = (a + 1) / 2 * (ncols - 1) + 1
    k0 = np.floor(fk).astype(int)
    k1 = k0 + 1
    k1[k1 == ncols + 1] = 1
    f = fk - k0
    for i in range(np.size(colorwheel, 1)):
        tmp = colorwheel[:, i]
        col0 = tmp[k0 - 1] / 255
        col1 = tmp[k1 - 1] / 255
        col = (1 - f) * col0 + f * col1
        idx = rad <= 1
        col[idx] = 1 - rad[idx] * (1 - col[idx])
        notidx = np.logical_not(idx)
        col[notidx] *= 0.75
        img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx)))
    return img

def reduce_sum(x, axis=None, keepdim=False):
    if axis is None:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.sum(x, dim=i, keepdim=keepdim)
    return x

class ContextualAttention(nn.Module):
    def __init__(self, ksize=3, stride=1, rate=1, fuse_k=3, softmax_scale=10,
                 fuse=False, use_cuda=False, device_ids=None):
        super(ContextualAttention, self).__init__()
        self.ksize = ksize
        self.stride = stride
        self.rate = rate
        self.fuse_k = fuse_k
        self.softmax_scale = softmax_scale
        self.fuse = fuse
        self.use_cuda = use_cuda
        self.device_ids = device_ids

    def forward(self, f, b, mask=None):
        """ Contextual attention layer implementation.
        Contextual attention was first introduced in the publication:
        Generative Image Inpainting with Contextual Attention, Yu et al.

        Args:
            f: Input feature to match (foreground).
            b: Input feature for match (background).
            mask: Input mask for b, indicating patches not available.
            ksize: Kernel size for contextual attention.
            stride: Stride for extracting patches from b.
            rate: Dilation for matching.
            softmax_scale: Scaled softmax for attention.

        Returns:
            torch.tensor: output
        """
        # get shapes
        raw_int_fs = list(f.size())   # b*c*h*w
        raw_int_bs = list(b.size())   # b*c*h*w

        # extract patches from background with stride and rate
        kernel = 2 * self.rate
        # raw_w is extracted for reconstruction
        raw_w = extract_image_patches(b, ksizes=[kernel, kernel],
                                      strides=[self.rate * self.stride,
                                               self.rate * self.stride],
                                      rates=[1, 1],
                                      padding='same')  # [N, C*k*k, L]
        # raw_shape: [N, C, k, k, L]
        raw_w = raw_w.view(raw_int_bs[0], raw_int_bs[1], kernel, kernel, -1)
        raw_w = raw_w.permute(0, 4, 1, 2, 3)   # raw_shape: [N, L, C, k, k]
        raw_w_groups = torch.split(raw_w, 1, dim=0)

        # downscaling foreground option: downscaling both foreground and
        # background for matching and use original background for reconstruction.
        f = F.interpolate(f, scale_factor=1. / self.rate, mode='nearest')
        b = F.interpolate(b, scale_factor=1. / self.rate, mode='nearest')
        int_fs = list(f.size())   # b*c*h*w
        int_bs = list(b.size())
        f_groups = torch.split(f, 1, dim=0)   # split tensors along the batch dimension
        # w shape: [N, C*k*k, L]
        w = extract_image_patches(b, ksizes=[self.ksize, self.ksize],
                                  strides=[self.stride, self.stride],
                                  rates=[1, 1],
                                  padding='same')
        # w shape: [N, C, k, k, L]
        w = w.view(int_bs[0], int_bs[1], self.ksize, self.ksize, -1)
        w = w.permute(0, 4, 1, 2, 3)   # w shape: [N, L, C, k, k]
        w_groups = torch.split(w, 1, dim=0)

        # process mask
        if mask is None:
            mask = torch.zeros([int_bs[0], 1, int_bs[2], int_bs[3]])
            if self.use_cuda:
                mask = mask.cuda()
        else:
            mask = F.interpolate(mask, scale_factor=1. / (4 * self.rate), mode='nearest')
        int_ms = list(mask.size())
        # m shape: [N, C*k*k, L]
        m = extract_image_patches(mask, ksizes=[self.ksize, self.ksize],
                                  strides=[self.stride, self.stride],
                                  rates=[1, 1],
                                  padding='same')
        # m shape: [N, C, k, k, L]
        m = m.view(int_ms[0], int_ms[1], self.ksize, self.ksize, -1)
        m = m.permute(0, 4, 1, 2, 3)   # m shape: [N, L, C, k, k]
        m = m[0]   # m shape: [L, C, k, k]
        # mm shape: [L, 1, 1, 1]
        mm = (reduce_mean(m, axis=[1, 2, 3], keepdim=True) == 0.).to(torch.float32)
        mm = mm.permute(1, 0, 2, 3)   # mm shape: [1, L, 1, 1]

        y = []
        offsets = []
        k = self.fuse_k
        scale = self.softmax_scale   # to fit the PyTorch tensor image value range
        fuse_weight = torch.eye(k).view(1, 1, k, k)   # 1*1*k*k
        if self.use_cuda:
            fuse_weight = fuse_weight.cuda()

        for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups):
            '''
            O => output channel as a conv filter
            I => input channel as a conv filter
            xi : separated tensor along batch dimension of front; (B=1, C=128, H=32, W=32)
            wi : separated patch tensor along batch dimension of back; (B=1, O=32*32, I=128, KH=3, KW=3)
            raw_wi : separated tensor along batch dimension of back; (B=1, I=32*32, O=128, KH=4, KW=4)
            '''
            # conv for compare
            escape_NaN = torch.FloatTensor([1e-4])
            if self.use_cuda:
                escape_NaN = escape_NaN.cuda()
            wi = wi[0]   # [L, C, k, k]
            max_wi = torch.sqrt(reduce_sum(torch.pow(wi, 2) + escape_NaN, axis=[1, 2, 3], keepdim=True))
            wi_normed = wi / max_wi
            # xi shape: [1, C, H, W], yi shape: [1, L, H, W]
            xi = same_padding(xi, [self.ksize, self.ksize], [1, 1], [1, 1])   # xi: 1*c*H*W
            yi = F.conv2d(xi, wi_normed, stride=1)   # [1, L, H, W]
            # conv implementation for fuse scores to encourage large patches
            if self.fuse:
                # make all of depth to spatial resolution
                yi = yi.view(1, 1, int_bs[2] * int_bs[3], int_fs[2] * int_fs[3])   # (B=1, I=1, H=32*32, W=32*32)
                yi = same_padding(yi, [k, k], [1, 1], [1, 1])
                yi = F.conv2d(yi, fuse_weight, stride=1)   # (B=1, C=1, H=32*32, W=32*32)
                yi = yi.contiguous().view(1, int_bs[2], int_bs[3], int_fs[2], int_fs[3])   # (B=1, 32, 32, 32, 32)
                yi = yi.permute(0, 2, 1, 4, 3)
                yi = yi.contiguous().view(1, 1, int_bs[2] * int_bs[3], int_fs[2] * int_fs[3])
                yi = same_padding(yi, [k, k], [1, 1], [1, 1])
                yi = F.conv2d(yi, fuse_weight, stride=1)
                yi = yi.contiguous().view(1, int_bs[3], int_bs[2], int_fs[3], int_fs[2])
                yi = yi.permute(0, 2, 1, 4, 3).contiguous()
            yi = yi.view(1, int_bs[2] * int_bs[3], int_fs[2], int_fs[3])   # (B=1, C=32*32, H=32, W=32)
            # softmax to match
            yi = yi * mm
            yi = F.softmax(yi * scale, dim=1)
            yi = yi * mm   # [1, L, H, W]

            offset = torch.argmax(yi, dim=1, keepdim=True)   # 1*1*H*W

            if int_bs != int_fs:
                # Normalize the offset value to match foreground dimension
                times = float(int_fs[2] * int_fs[3]) / float(int_bs[2] * int_bs[3])
                offset = ((offset + 1).float() * times - 1).to(torch.int64)
            offset = torch.cat([offset // int_fs[3], offset % int_fs[3]], dim=1)   # 1*2*H*W

            # deconv for patch pasting
            wi_center = raw_wi[0]
            # yi = F.pad(yi, [0, 1, 0, 1])    # here may need conv_transpose same padding
            yi = F.conv_transpose2d(yi, wi_center, stride=self.rate, padding=1) / 4.   # (B=1, C=128, H=64, W=64)
            y.append(yi)
            offsets.append(offset)

        y = torch.cat(y, dim=0)   # back to the mini-batch
        y = y.contiguous().view(raw_int_fs)

        offsets = torch.cat(offsets, dim=0)
        offsets = offsets.view(int_fs[0], 2, *int_fs[2:])

        # case1: visualize optical flow: minus current position
        h_add = torch.arange(int_fs[2]).view([1, 1, int_fs[2], 1]).expand(int_fs[0], -1, -1, int_fs[3])
        w_add = torch.arange(int_fs[3]).view([1, 1, 1, int_fs[3]]).expand(int_fs[0], -1, int_fs[2], -1)
        ref_coordinate = torch.cat([h_add, w_add], dim=1)
        if self.use_cuda:
            ref_coordinate = ref_coordinate.cuda()

        offsets = offsets - ref_coordinate
        # flow = pt_flow_to_image(offsets)

        flow = torch.from_numpy(flow_to_image(offsets.permute(0, 2, 3, 1).cpu().data.numpy())) / 255.
        flow = flow.permute(0, 3, 1, 2)
        if self.use_cuda:
            flow = flow.cuda()
        # case2: visualize which pixels are attended
        # flow = torch.from_numpy(highlight_flow((offsets * mask.long()).cpu().data.numpy()))

        if self.rate != 1:
            flow = F.interpolate(flow, scale_factor=self.rate * 4, mode='nearest')

        return y, flow
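
# Usage sketch (illustrative; the sizes assume 64x64 feature maps taken from a 256x256 image, so that
# the optional mask, given at image resolution, downscales to the background matching grid):
#
#     ca = ContextualAttention(ksize=3, stride=1, rate=2, fuse=True, use_cuda=False)
#     f = torch.randn(1, 128, 64, 64)      # foreground features
#     b = torch.randn(1, 128, 64, 64)      # background features
#     mask = torch.zeros(1, 1, 256, 256)   # 1 marks unavailable (hole) pixels
#     y, flow = ca(f, b, mask)             # y: (1, 128, 64, 64); flow visualizes the attended offsets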