You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1038 lines
38 KiB

import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
import torchvision.models as models
import copy
import numpy as np
# import lpips
####################################################################################################
# spectral normalization layer to decouple the magnitude of a weight tensor
####################################################################################################
class VGG16(nn.Module):
    """Frozen VGG-16 feature extractor exposing intermediate activations.

    Wraps torchvision's pretrained ``vgg16().features`` and splits it into
    named stages so ``forward`` can return a dict of activations, which the
    perceptual / style / diversity losses below index by layer name.
    All parameters have ``requires_grad = False``.
    """

    def __init__(self):
        super(VGG16, self).__init__()
        features = models.vgg16(pretrained=True).features
        # (name, start, end): half-open slice of vgg.features for each stage.
        stage_slices = [
            ("relu1_1", 0, 2), ("relu1_2", 2, 4),
            ("relu2_1", 4, 7), ("relu2_2", 7, 9),
            ("relu3_1", 9, 12), ("relu3_2", 12, 14), ("relu3_3", 14, 16),
            ("max3", 16, 17),
            ("relu4_1", 17, 19), ("relu4_2", 19, 21), ("relu4_3", 21, 23),
            ("relu5_1", 23, 26), ("relu5_2", 26, 28), ("relu5_3", 28, 30),
        ]
        for name, start, end in stage_slices:
            stage = torch.nn.Sequential()
            for x in range(start, end):
                stage.add_module(str(x), features[x])
            setattr(self, name, stage)
        # don't need the gradients, just want the features
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return a dict mapping stage names to activations of input `x`."""
        relu1_1 = self.relu1_1(x)
        relu1_2 = self.relu1_2(relu1_1)
        relu2_1 = self.relu2_1(relu1_2)
        relu2_2 = self.relu2_2(relu2_1)
        relu3_1 = self.relu3_1(relu2_2)
        relu3_2 = self.relu3_2(relu3_1)
        relu3_3 = self.relu3_3(relu3_2)
        max_3 = self.max3(relu3_3)
        relu4_1 = self.relu4_1(max_3)
        relu4_2 = self.relu4_2(relu4_1)
        relu4_3 = self.relu4_3(relu4_2)
        relu5_1 = self.relu5_1(relu4_3)
        # BUGFIX: the original applied self.relu5_1 three times in a row,
        # so the relu5_2 / relu5_3 stages built in __init__ were never used.
        relu5_2 = self.relu5_2(relu5_1)
        relu5_3 = self.relu5_3(relu5_2)
        out = {
            "relu1_1": relu1_1,
            "relu1_2": relu1_2,
            "relu2_1": relu2_1,
            "relu2_2": relu2_2,
            "relu3_1": relu3_1,
            "relu3_2": relu3_2,
            "relu3_3": relu3_3,
            "max_3": max_3,
            "relu4_1": relu4_1,
            "relu4_2": relu4_2,
            "relu4_3": relu4_3,
            "relu5_1": relu5_1,
            "relu5_2": relu5_2,
            "relu5_3": relu5_3,
        }
        return out
def l2normalize(v, eps=1e-12):
    """Scale `v` to unit L2 norm; `eps` guards against division by zero."""
    norm = v.norm()
    return v / (norm + eps)
class SpectralNorm(nn.Module):
    """
    spectral normalization
    code and idea originally from Takeru Miyato's work 'Spectral Normalization for GAN'
    https://github.com/christiancosgrove/pytorch-spectral-normalization-gan

    Wraps `module` and, on every forward pass, re-estimates the largest
    singular value (spectral norm) of the parameter named `name` via power
    iteration, then exposes `module.<name>` as the weight divided by it.
    The raw weight is stored as `<name>_bar`; `<name>_u` / `<name>_v` hold
    the left/right singular-vector estimates (frozen Parameters).
    """
    def __init__(self, module, name='weight', power_iterations=1):
        # module: wrapped layer whose parameter `name` gets normalized.
        # power_iterations: power-iteration steps per forward pass.
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()
    def _update_u_v(self):
        # Run power iteration on the weight viewed as a (height x -1) matrix
        # and set module.<name> = w / sigma, where sigma ~= top singular value.
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")
        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            # v <- normalize(W^T u); u <- normalize(W v). Updating .data keeps
            # the iteration itself out of the autograd graph.
            v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))
        # sigma = u^T W v — estimated spectral norm (differentiable w.r.t. w).
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))
    def _made_params(self):
        # True iff the u/v/bar parameters were already registered on `module`.
        try:
            u = getattr(self.module, self.name + "_u")
            v = getattr(self.module, self.name + "_v")
            w = getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False
    def _make_params(self):
        # Replace module.<name> with <name>_bar and register frozen u/v
        # singular-vector estimates, randomly initialized and normalized.
        w = getattr(self.module, self.name)
        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]
        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)
        # Remove the original parameter so <name> can be set as a plain
        # (normalized) tensor attribute in _update_u_v.
        del self.module._parameters[self.name]
        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)
    def forward(self, *args):
        # Refresh the normalized weight, then delegate to the wrapped module.
        self._update_u_v()
        return self.module.forward(*args)
####################################################################################################
# adversarial loss for different gan mode
####################################################################################################
class GANLoss(nn.Module):
    """Define different GAN objectives.

    Abstracts away creating a target-label tensor of the same size as the
    discriminator output. Supported modes: 'lsgan', 'vanilla', 'hinge',
    'wgangp', 'wgandiv'.
    """
    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """Initialize the GANLoss class.

        gan_mode: objective type ('lsgan' | 'vanilla' | 'hinge' | 'wgangp' | 'wgandiv').
        target_real_label / target_fake_label: label values for real / fake.

        Note: the discriminator must NOT end with a sigmoid — LSGAN needs
        none and vanilla uses BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        # Per-mode criterion; Wasserstein variants use the raw output directly.
        criteria = {
            'lsgan': nn.MSELoss(),
            'vanilla': nn.BCEWithLogitsLoss(),
            'hinge': nn.ReLU(),
            'wgangp': None,
            'wgandiv': None,
        }
        if gan_mode not in criteria:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)
        self.loss = criteria[gan_mode]
    def __call__(self, prediction, target_is_real, is_disc=False):
        """Compute the loss for a discriminator prediction.

        prediction: discriminator output tensor.
        target_is_real: whether the ground truth is "real".
        is_disc: True when computing the discriminator's own loss.
        """
        if self.gan_mode in ('lsgan', 'vanilla'):
            target = self.real_label if target_is_real else self.fake_label
            target = target.expand_as(prediction).type_as(prediction)
            return self.loss(prediction, target)
        if self.gan_mode in ('hinge', 'wgangp'):
            if not is_disc:
                # Generator objective: push the critic output up.
                return -prediction.mean()
            signed = -prediction if target_is_real else prediction
            if self.gan_mode == 'hinge':
                return self.loss(1 + signed).mean()
            return signed.mean()
        if self.gan_mode == 'wgandiv':
            return prediction.mean()
class PD_Loss(nn.Module):
    """Pixel-distance loss: plain L1 (mean absolute error) between tensors."""

    def __init__(self):
        super(PD_Loss, self).__init__()
        self.criterion = torch.nn.L1Loss()

    def __call__(self, x, y):
        """Return the mean absolute difference between `x` and `y`."""
        return self.criterion(x, y)
class TV_Loss(nn.Module):
    """Total-variation loss restricted to the (dilated) hole region of a mask."""

    def __init__(self):
        super(TV_Loss, self).__init__()

    def __call__(self, image, mask, method):
        """Compute TV loss of `image` inside the dilated hole region.

        image: (B, C, H, W) tensor.
        mask: (B, C, H, W) tensor; hole region is where mask == 0.
        method: "sum" for summed TV, anything else for mean TV.
        """
        hole_mask = 1 - mask
        b, ch, h, w = hole_mask.shape
        # Dilate the hole mask by one pixel (all-ones 3x3 conv) so variation
        # across the hole boundary is also penalized.
        dilation_conv = nn.Conv2d(ch, ch, 3, padding=1, bias=False).to(hole_mask)
        torch.nn.init.constant_(dilation_conv.weight, 1.0)
        with torch.no_grad():
            output_mask = dilation_conv(hole_mask)
        updated_holes = output_mask != 0
        dilated_holes = updated_holes.float()
        # Adjacent-pixel pairs that both lie inside the dilated hole region.
        colomns_in_Pset = dilated_holes[:, :, :, 1:] * dilated_holes[:, :, :, :-1]
        rows_in_Pset = dilated_holes[:, :, 1:, :] * dilated_holes[:, :, :-1, :]
        # BUGFIX: the vertical term used image[:, :, :1, :] - image[:, :, -1:, :]
        # (first row minus last row, broadcast over the mask) instead of
        # adjacent-row differences — not a total-variation term.
        col_diff = torch.abs(colomns_in_Pset * (image[:, :, :, 1:] - image[:, :, :, :-1]))
        row_diff = torch.abs(rows_in_Pset * (image[:, :, 1:, :] - image[:, :, :-1, :]))
        if method == "sum":
            loss = torch.sum(col_diff) + torch.sum(row_diff)
        else:
            loss = torch.mean(col_diff) + torch.mean(row_diff)
        return loss
def cal_gradient_penalty(netD, real_data, fake_data, type='mixed', constant=1.0, lambda_gp=10.0):
    """Calculate the WGAN-GP gradient penalty (https://arxiv.org/abs/1704.00028).

    netD: discriminator network.
    real_data / fake_data: real and generated image batches.
    type: where to evaluate the gradient — 'real', 'fake', or 'mixed'
          (a random per-sample interpolation of the two).
    constant: target gradient norm in (||grad||_2 - constant)^2.
    lambda_gp: penalty weight; when <= 0 the penalty is skipped.

    Returns (penalty, gradients) — gradients flattened per sample —
    or (0.0, None) when lambda_gp <= 0.
    """
    if lambda_gp <= 0.0:
        return 0.0, None
    if type == 'real':
        interpolatesv = real_data
    elif type == 'fake':
        interpolatesv = fake_data
    elif type == 'mixed':
        batch = real_data.shape[0]
        # One random mixing coefficient per sample, broadcast to full shape.
        alpha = torch.rand(batch, 1)
        alpha = alpha.expand(batch, real_data.nelement() // batch)
        alpha = alpha.contiguous().view(*real_data.shape).type_as(real_data)
        interpolatesv = alpha * real_data + (1 - alpha) * fake_data
    else:
        raise NotImplementedError('{} not implemented'.format(type))
    interpolatesv.requires_grad_(True)
    disc_out = netD(interpolatesv)
    grad_outputs = torch.ones(disc_out.size()).type_as(real_data)
    grads = torch.autograd.grad(outputs=disc_out, inputs=interpolatesv,
                                grad_outputs=grad_outputs,
                                create_graph=True, retain_graph=True,
                                only_inputs=True)[0]
    grads = grads.view(real_data.size(0), -1)  # flatten per sample
    # eps inside the norm avoids a NaN gradient at exactly zero.
    penalty = (((grads + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp
    return penalty, grads
def cal_gradient_penalty_div(netD, real_data, fake_data, type='mixed', const_power=6.0, const_kappa=2.0):
    """WGAN-div gradient penalty: kappa * E[||grad||_2 ^ power].

    type: evaluation point — 'real', 'fake', or 'mixed' (per-sample random
    interpolation, note: (1 - alpha) * real + alpha * fake here).
    """
    if type == 'real':
        interpolatesv = real_data
    elif type == 'fake':
        interpolatesv = fake_data
    elif type == 'mixed':
        batch = real_data.shape[0]
        alpha = torch.rand(batch, 1)
        alpha = alpha.expand(batch, real_data.nelement() // batch)
        alpha = alpha.contiguous().view(*real_data.shape).type_as(real_data)
        interpolatesv = (1 - alpha) * real_data + alpha * fake_data
    else:
        raise NotImplementedError('{} not implemented'.format(type))
    interpolatesv.requires_grad_(True)
    disc_out = netD(interpolatesv)
    grad_outputs = torch.ones(disc_out.size()).type_as(real_data)
    grads = torch.autograd.grad(outputs=disc_out, inputs=interpolatesv,
                                grad_outputs=grad_outputs,
                                create_graph=True, retain_graph=True,
                                only_inputs=True)[0]
    grads = grads.view(real_data.size(0), -1)  # flatten per sample
    return torch.pow((grads + 1e-16).norm(2, dim=1), const_power).mean() * const_kappa
####################################################################################################
# neural style transform loss from neural_style_tutorial of pytorch
####################################################################################################
def ContentLoss(input, target):
    """L1 content loss against a detached (constant) target."""
    return F.l1_loss(input, target.detach())
def GramMatrix(input):
    """Batched Gram matrix of a (B, C, H, W) feature map, scaled by 1/(C*H*W)."""
    b, c, h, w = input.size()
    feats = input.view(b, c, h * w)
    gram = torch.bmm(feats, feats.transpose(1, 2))
    return gram.div(c * h * w)
def StyleLoss(input, target):
    """L1 distance between Gram matrices of input and (detached) target."""
    return F.l1_loss(GramMatrix(input), GramMatrix(target).detach())
def img_crop(input, size=224):
    """Bilinearly resize `input` to (size, size).

    FIX: F.upsample is deprecated; F.interpolate is the supported equivalent.
    """
    return F.interpolate(input, size=(size, size), mode='bilinear', align_corners=True)
class Normalization(nn.Module):
    """Channel-wise (x - mean) / std normalization for (B, C, H, W) input."""

    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        # Shape (C, 1, 1) so the stats broadcast over batch and spatial dims.
        self.mean = mean.view(-1, 1, 1)
        self.std = std.view(-1, 1, 1)

    def forward(self, input):
        return input.sub(self.mean).div(self.std)
class get_features(nn.Module):
    """Split a VGG `features` Sequential into five conv stages.

    forward(input, layers) resizes the input to 224x224 and returns the
    activations of stages conv1 .. conv(layers-1), in order.
    """

    def __init__(self, cnn):
        super(get_features, self).__init__()
        vgg = copy.deepcopy(cnn)
        # Half-open index ranges of the five stages within vgg.features.
        bounds = [(0, 5), (5, 10), (10, 17), (17, 24), (24, 31)]
        for idx, (lo, hi) in enumerate(bounds, start=1):
            stage = nn.Sequential(*[vgg[j] for j in range(lo, hi)])
            setattr(self, 'conv' + str(idx), stage)

    def forward(self, input, layers):
        feat = img_crop(input)
        output = []
        for i in range(1, layers):
            feat = getattr(self, 'conv' + str(i))(feat)
            output.append(feat)
        return output
class GroupNorm(nn.Module):
    """Group normalization with learnable per-channel scale and shift."""

    def __init__(self, num_features, num_groups=32, eps=1e-5):
        super(GroupNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.num_groups = num_groups
        self.eps = eps

    def forward(self, x):
        n, c, h, w = x.size()
        grouped = x.view(n, self.num_groups, -1)
        mu = grouped.mean(-1, keepdim=True)
        # NOTE: torch.var defaults to the unbiased estimator here (matches
        # the original implementation, unlike F.group_norm which is biased).
        sigma2 = grouped.var(-1, keepdim=True)
        normed = (grouped - mu) / torch.sqrt(sigma2 + self.eps)
        return normed.view(n, c, h, w) * self.weight + self.bias
class FullAttention(nn.Module):
    """
    Layer implements my version of the self-attention module
    it is mostly same as self attention, but generalizes to
    (k x k) convolutions instead of (1 x 1)
    args:
        in_channels: number of input channels
        out_channels: number of output channels
        activation: activation function to be applied (default: lrelu(0.2))
        kernel_size: kernel size for convolution (default: (1 x 1))
        transpose_conv: boolean denoting whether to use convolutions or transpose
                        convolutions
        squeeze_factor: squeeze factor for query and keys (default: 8)
        stride: stride for the convolutions (default: 1)
        padding: padding for the applied convolutions (default: 1)
        bias: whether to apply bias or not (default: True)
    """
    def __init__(self, in_channels, out_channels,
                 activation=nn.LeakyReLU(0.2), kernel_size=(3, 3), transpose_conv=False,
                 use_spectral_norm=True, use_batch_norm=True,
                 squeeze_factor=8, stride=1, padding=1, bias=True):
        """ constructor for the layer """
        # NOTE(review): the default `activation` module is created once at
        # definition time and shared by all instances — harmless for a
        # stateless LeakyReLU, but worth knowing.
        from torch.nn import Conv2d, Parameter, \
            Softmax, ConvTranspose2d, BatchNorm2d, InstanceNorm2d
        # base constructor call
        super().__init__()
        # state of the layer
        # gamma: learnable mix between the attention path and the residual
        # path; initialized to 0 so training starts as pure residual.
        self.activation = activation
        self.gamma = Parameter(torch.zeros(1))
        self.in_channels = in_channels
        self.out_channels = out_channels
        # query/key channels after the squeeze bottleneck
        self.squeezed_channels = in_channels // squeeze_factor
        self.use_batch_norm = use_batch_norm
        # Modules required for computations
        if transpose_conv:
            self.query_conv = ConvTranspose2d(  # query convolution
                in_channels=in_channels,
                out_channels=in_channels // squeeze_factor,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias
            )
            self.key_conv = ConvTranspose2d(
                in_channels=in_channels,
                out_channels=in_channels // squeeze_factor,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias
            )
            self.value_conv = ConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias
            )
            # residual path; optionally spectrally normalized
            self.residual_conv = ConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias
            ) if not use_spectral_norm else SpectralNorm(
                ConvTranspose2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    bias=bias
                )
            )
        else:
            self.query_conv = Conv2d(  # query convolution
                in_channels=in_channels,
                out_channels=in_channels // squeeze_factor,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias
            )
            self.key_conv = Conv2d(
                in_channels=in_channels,
                out_channels=in_channels // squeeze_factor,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias
            )
            self.value_conv = Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias
            )
            # residual path; optionally spectrally normalized
            self.residual_conv = Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=bias
            ) if not use_spectral_norm else SpectralNorm(
                Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    bias=bias
                )
            )
        # softmax module for applying attention
        self.softmax = Softmax(dim=-1)
        # self.batch_norm = BatchNorm2d(out_channels)
        # NOTE(review): despite the attribute name (and the
        # `use_batch_norm` flag), this is an InstanceNorm2d layer.
        self.batch_norm = InstanceNorm2d(out_channels)
    def forward(self, x):
        """
        forward computations of the layer
        :param x: input feature maps (B x C x H x W)
        :return:
            out: self attention value + input feature (B x O x H x W)
            attention: attention map (B x C x H x W)
        """
        # extract the batch size of the input tensor
        m_batchsize, _, _, _ = x.size()
        # create the query projection
        proj_query = self.query_conv(x).view(
            m_batchsize, self.squeezed_channels, -1).permute(0, 2, 1)  # B x (N) x C
        # create the key projection
        proj_key = self.key_conv(x).view(
            m_batchsize, self.squeezed_channels, -1)  # B x C x (N)
        # calculate the attention maps
        energy = torch.bmm(proj_query, proj_key)  # energy
        attention = self.softmax(energy)  # attention (B x (N) x (N))
        # create the value projection
        proj_value = self.value_conv(x).view(
            m_batchsize, self.out_channels, -1)  # B X C X N
        # calculate the output
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        # calculate the residual output
        res_out = self.residual_conv(x)
        # reshape attention output to the residual path's spatial size
        out = out.view(m_batchsize, self.out_channels,
                       res_out.shape[-2], res_out.shape[-1])
        attention = attention.view(m_batchsize, -1,
                                   res_out.shape[-2], res_out.shape[-1])
        if self.use_batch_norm:
            res_out = self.batch_norm(res_out)
        if self.activation is not None:
            out = self.activation(out)
            res_out = self.activation(res_out)
        # apply the residual connections: gamma-weighted mix of the two paths
        out = (self.gamma * out) + ((1 - self.gamma) * res_out)
        return out, attention
class Diversityloss(nn.Module):
    """L1 distance between VGG relu4_1 features of two images."""

    def __init__(self):
        super(Diversityloss, self).__init__()
        self.vgg = VGG16().cuda()
        self.criterion = nn.L1Loss()
        self.weights = [1.0, 1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0]

    def forward(self, x, y):
        feats_x, feats_y = self.vgg(x), self.vgg(y)
        # Only the relu4_1 term (weight index 4) contributes.
        return self.weights[4] * self.criterion(
            feats_x["relu4_1"], feats_y["relu4_1"]
        )
class PerceptualLoss(nn.Module):
    r"""
    Perceptual loss, VGG-based
    https://arxiv.org/abs/1603.08155
    https://github.com/dxyang/StyleTransfer/blob/master/utils.py
    """

    def __init__(self, weights=None):
        """weights: per-layer weights for relu1_1 .. relu5_1.

        FIX: the original used a mutable list as the default argument;
        default now constructed inside the body (same values).
        """
        super(PerceptualLoss, self).__init__()
        if weights is None:
            weights = [1.0 / 2, 1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0]
        self.add_module("vgg", VGG16().cuda())
        self.criterion = torch.nn.L1Loss()
        self.weights = weights

    def __call__(self, x, y):
        """Weighted L1 distance between VGG features of x and y."""
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        content_loss = 0.0
        layers = ["relu1_1", "relu2_1", "relu3_1", "relu4_1", "relu5_1"]
        for weight, layer in zip(self.weights, layers):
            content_loss += weight * self.criterion(x_vgg[layer], y_vgg[layer])
        return content_loss
class StyleLoss(nn.Module):
    r"""
    VGG-based style loss: L1 distance between Gram matrices of the
    relu2_2, relu3_3 and relu4_3 activations.
    https://arxiv.org/abs/1603.08155
    https://github.com/dxyang/StyleTransfer/blob/master/utils.py
    """

    def __init__(self):
        super(StyleLoss, self).__init__()
        self.add_module("vgg", VGG16().cuda())
        self.criterion = torch.nn.L1Loss()

    def compute_gram(self, x):
        """Gram matrix of a (B, C, H, W) feature map, scaled by 1/(C*H*W)."""
        batch, channels, height, width = x.size()
        feats = x.view(batch, channels, width * height)
        return feats.bmm(feats.transpose(1, 2)) / (height * width * channels)

    def __call__(self, x, y):
        x_feats, y_feats = self.vgg(x), self.vgg(y)
        style_loss = 0.0
        for layer in ("relu2_2", "relu3_3", "relu4_3"):
            style_loss += self.criterion(
                self.compute_gram(x_feats[layer]), self.compute_gram(y_feats[layer])
            )
        return style_loss
def reduce_mean(x, axis=None, keepdim=False):
    """Mean of `x` over the given axes (all axes when axis is None).

    axis: None, an int, or an iterable of ints.

    FIX: the original tested `if not axis`, which wrongly treated axis=0
    as "reduce over all axes", and an int axis crashed sorted().
    """
    if axis is None:
        axis = range(len(x.shape))
    elif isinstance(axis, int):
        axis = [axis]
    # Reduce highest axis first so remaining axis indices stay valid.
    for i in sorted(axis, reverse=True):
        x = torch.mean(x, dim=i, keepdim=keepdim)
    return x
def same_padding(images, ksizes, strides, rates):
    """Zero-pad `images` TF-'SAME' style for a dilated sliding window.

    images: (B, C, H, W) tensor.
    ksizes / strides / rates: [row, col] kernel size, stride and dilation.
    """
    assert len(images.size()) == 4
    _, _, rows, cols = images.size()
    out_rows = (rows + strides[0] - 1) // strides[0]
    out_cols = (cols + strides[1] - 1) // strides[1]
    # Effective kernel extent after dilation.
    eff_k_row = (ksizes[0] - 1) * rates[0] + 1
    eff_k_col = (ksizes[1] - 1) * rates[1] + 1
    pad_rows = max(0, (out_rows - 1) * strides[0] + eff_k_row - rows)
    pad_cols = max(0, (out_cols - 1) * strides[1] + eff_k_col - cols)
    # Split padding, putting the extra pixel at the bottom/right.
    top = pad_rows // 2
    left = pad_cols // 2
    paddings = (left, pad_cols - left, top, pad_rows - top)
    return torch.nn.ZeroPad2d(paddings)(images)
def extract_image_patches(images, ksizes, strides, rates, padding='same'):
    """
    Extract patches from images and put them in the C output dimension.
    :param padding: 'same' (zero-pad so the window tiles the image) or 'valid'.
    :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape
    :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
        each dimension of images
    :param strides: [stride_rows, stride_cols]
    :param rates: [dilation_rows, dilation_cols]
    :return: A Tensor [N, C*k*k, L], L is the total number of such blocks

    FIX: padding validation previously used both an `assert` and a dead
    (unreachable) `else: raise`; asserts are stripped under -O, so the
    raise now performs the validation.
    """
    assert len(images.size()) == 4
    if padding == 'same':
        images = same_padding(images, ksizes, strides, rates)
    elif padding != 'valid':
        raise NotImplementedError('Unsupported padding type: {}.\
                Only "same" or "valid" are supported.'.format(padding))
    unfold = torch.nn.Unfold(kernel_size=ksizes,
                             dilation=rates,
                             padding=0,
                             stride=strides)
    patches = unfold(images)
    return patches  # [N, C*k*k, L], L is the total number of such blocks
def flow_to_image(flow):
    """Transfer flow map to image.
    Part of code forked from flownet.

    flow: numpy array of shape (B, H, W, 2) holding per-pixel (u, v)
    displacements — assumed from the permute at the call site in
    ContextualAttention.forward; confirm for other callers.
    Returns a float32 array of uint8-quantized RGB images, one per batch item.

    NOTE(review): `u`/`v` are views into `flow`, so zeroing "unknown"
    entries mutates the input array in place; `maxrad` accumulates across
    the whole batch, so all images share one normalization scale.
    """
    out = []
    maxu = -999.
    maxv = -999.
    minu = 999.
    minv = 999.
    maxrad = -1
    for i in range(flow.shape[0]):
        u = flow[i, :, :, 0]
        v = flow[i, :, :, 1]
        # Treat implausibly large displacements as unknown and zero them.
        idxunknow = (abs(u) > 1e7) | (abs(v) > 1e7)
        u[idxunknow] = 0
        v[idxunknow] = 0
        maxu = max(maxu, np.max(u))
        minu = min(minu, np.min(u))
        maxv = max(maxv, np.max(v))
        minv = min(minv, np.min(v))
        rad = np.sqrt(u ** 2 + v ** 2)
        maxrad = max(maxrad, np.max(rad))
        # Normalize by the running maximum magnitude before colorizing.
        u = u / (maxrad + np.finfo(float).eps)
        v = v / (maxrad + np.finfo(float).eps)
        img = compute_color(u, v)
        out.append(img)
    return np.float32(np.uint8(out))
def make_color_wheel():
    """Build the 55x3 Middlebury optical-flow color wheel.

    Six hue segments (RY, YG, GC, CB, BM, MR); in each one channel is held
    at 255 while another ramps linearly up or down.
    """
    RY, YG, GC, CB, BM, MR = (15, 6, 4, 11, 13, 6)
    ncols = RY + YG + GC + CB + BM + MR
    colorwheel = np.zeros([ncols, 3])
    # (segment length, constant channel, ramping channel, ramp direction)
    segments = [
        (RY, 0, 1, +1),  # red -> yellow: G ramps up
        (YG, 1, 0, -1),  # yellow -> green: R ramps down
        (GC, 1, 2, +1),  # green -> cyan: B ramps up
        (CB, 2, 1, -1),  # cyan -> blue: G ramps down
        (BM, 2, 0, +1),  # blue -> magenta: R ramps up
        (MR, 0, 2, -1),  # magenta -> red: B ramps down
    ]
    col = 0
    for length, const_ch, ramp_ch, direction in segments:
        colorwheel[col:col + length, const_ch] = 255
        ramp = np.floor(255 * np.arange(0, length) / length)
        colorwheel[col:col + length, ramp_ch] = ramp if direction > 0 else 255 - ramp
        col += length
    return colorwheel
def compute_color(u, v):
    """Map normalized flow components (u, v) to an RGB image via the
    Middlebury color wheel.

    u, v: 2-D numpy arrays of equal shape, expected in roughly [-1, 1]
    (flow_to_image normalizes by the max magnitude before calling this).
    Returns an (H, W, 3) array of 0-255 channel values.
    NOTE(review): NaN entries are zeroed in place, mutating the inputs.
    """
    h, w = u.shape
    img = np.zeros([h, w, 3])
    nanIdx = np.isnan(u) | np.isnan(v)
    u[nanIdx] = 0
    v[nanIdx] = 0
    # colorwheel = COLORWHEEL
    colorwheel = make_color_wheel()
    ncols = np.size(colorwheel, 0)
    rad = np.sqrt(u ** 2 + v ** 2)
    # Flow angle in (-1, 1], then mapped to a fractional wheel index.
    a = np.arctan2(-v, -u) / np.pi
    fk = (a + 1) / 2 * (ncols - 1) + 1
    # k0/k1 are the two neighboring wheel entries (1-based, wrapping).
    k0 = np.floor(fk).astype(int)
    k1 = k0 + 1
    k1[k1 == ncols + 1] = 1
    f = fk - k0
    for i in range(np.size(colorwheel, 1)):
        tmp = colorwheel[:, i]
        col0 = tmp[k0 - 1] / 255
        col1 = tmp[k1 - 1] / 255
        # Linear interpolation between the two wheel colors.
        col = (1 - f) * col0 + f * col1
        idx = rad <= 1
        # Desaturate toward white as magnitude decreases (in-range pixels).
        col[idx] = 1 - rad[idx] * (1 - col[idx])
        notidx = np.logical_not(idx)
        # Out-of-range magnitudes are dimmed instead.
        col[notidx] *= 0.75
        img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx)))
    return img
def reduce_sum(x, axis=None, keepdim=False):
    """Sum of `x` over the given axes (all axes when axis is None).

    axis: None, an int, or an iterable of ints.

    FIX: the original tested `if not axis`, which wrongly treated axis=0
    as "reduce over all axes", and an int axis crashed sorted().
    """
    if axis is None:
        axis = range(len(x.shape))
    elif isinstance(axis, int):
        axis = [axis]
    # Reduce highest axis first so remaining axis indices stay valid.
    for i in sorted(axis, reverse=True):
        x = torch.sum(x, dim=i, keepdim=keepdim)
    return x
class ContextualAttention(nn.Module):
    """Contextual attention layer (Yu et al., "Generative Image Inpainting
    with Contextual Attention").

    Matches downscaled foreground patches against background patches via
    convolution, softmax-attends over the match scores (masking out hole
    patches), and reconstructs the output with a transposed convolution
    over the full-resolution background patches. Also returns a flow-style
    visualization of the attention offsets.
    """

    def __init__(self, ksize=3, stride=1, rate=1, fuse_k=3, softmax_scale=10,
                 fuse=False, use_cuda=False, device_ids=None):
        # ksize: patch size used for matching.
        # stride: stride for extracting matching patches from the background.
        # rate: dilation/downscale factor between matching and reconstruction.
        # fuse_k: kernel size of the score-fusion convolution.
        # softmax_scale: temperature applied before the attention softmax.
        super(ContextualAttention, self).__init__()
        self.ksize = ksize
        self.stride = stride
        self.rate = rate
        self.fuse_k = fuse_k
        self.softmax_scale = softmax_scale
        self.fuse = fuse
        self.use_cuda = use_cuda
        self.device_ids = device_ids

    def forward(self, f, b, mask=None):
        """ Contextual attention layer implementation.
        Contextual attention is first introduced in publication:
        Generative Image Inpainting with Contextual Attention, Yu et al.
        Args:
            f: Input feature to match (foreground).
            b: Input feature for match (background).
            mask: Input mask for b, indicating patches not available.
            ksize: Kernel size for contextual attention.
            stride: Stride for extracting patches from b.
            rate: Dilation for matching.
            softmax_scale: Scaled softmax for attention.
        Returns:
            torch.tensor: output
        """
        # get shapes
        raw_int_fs = list(f.size())  # b*c*h*w
        raw_int_bs = list(b.size())  # b*c*h*w
        # extract patches from background with stride and rate
        kernel = 2 * self.rate
        # raw_w is extracted for reconstruction
        raw_w = extract_image_patches(b, ksizes=[kernel, kernel],
                                      strides=[self.rate*self.stride,
                                               self.rate*self.stride],
                                      rates=[1, 1],
                                      padding='same')  # [N, C*k*k, L]
        # raw_shape: [N, C, k, k, L]
        raw_w = raw_w.view(raw_int_bs[0], raw_int_bs[1], kernel, kernel, -1)
        raw_w = raw_w.permute(0, 4, 1, 2, 3)  # raw_shape: [N, L, C, k, k]
        raw_w_groups = torch.split(raw_w, 1, dim=0)
        # downscaling foreground option: downscaling both foreground and
        # background for matching and use original background for reconstruction.
        f = F.interpolate(f, scale_factor=1./self.rate, mode='nearest')
        b = F.interpolate(b, scale_factor=1./self.rate, mode='nearest')
        int_fs = list(f.size())  # b*c*h*w
        int_bs = list(b.size())
        f_groups = torch.split(f, 1, dim=0)  # split tensors along the batch dimension
        # w shape: [N, C*k*k, L]
        w = extract_image_patches(b, ksizes=[self.ksize, self.ksize],
                                  strides=[self.stride, self.stride],
                                  rates=[1, 1],
                                  padding='same')
        # w shape: [N, C, k, k, L]
        w = w.view(int_bs[0], int_bs[1], self.ksize, self.ksize, -1)
        w = w.permute(0, 4, 1, 2, 3)  # w shape: [N, L, C, k, k]
        w_groups = torch.split(w, 1, dim=0)
        # process mask: all-valid when none is given, else downscale to match b
        if mask is None:
            mask = torch.zeros([int_bs[0], 1, int_bs[2], int_bs[3]])
            if self.use_cuda:
                mask = mask.cuda()
        else:
            mask = F.interpolate(mask, scale_factor=1./(4*self.rate), mode='nearest')
        int_ms = list(mask.size())
        # m shape: [N, C*k*k, L]
        m = extract_image_patches(mask, ksizes=[self.ksize, self.ksize],
                                  strides=[self.stride, self.stride],
                                  rates=[1, 1],
                                  padding='same')
        # m shape: [N, C, k, k, L]
        m = m.view(int_ms[0], int_ms[1], self.ksize, self.ksize, -1)
        m = m.permute(0, 4, 1, 2, 3)  # m shape: [N, L, C, k, k]
        m = m[0]  # m shape: [L, C, k, k]
        # mm: 1 for patches fully outside the hole (usable), else 0.
        # mm shape: [L, 1, 1, 1]
        mm = (reduce_mean(m, axis=[1, 2, 3], keepdim=True) == 0.).to(torch.float32)
        mm = mm.permute(1, 0, 2, 3)  # mm shape: [1, L, 1, 1]
        y = []
        offsets = []
        k = self.fuse_k
        scale = self.softmax_scale  # to fit the PyTorch tensor image value range
        fuse_weight = torch.eye(k).view(1, 1, k, k)  # 1*1*k*k
        if self.use_cuda:
            fuse_weight = fuse_weight.cuda()
        for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups):
            '''
            O => output channel as a conv filter
            I => input channel as a conv filter
            xi : separated tensor along batch dimension of front; (B=1, C=128, H=32, W=32)
            wi : separated patch tensor along batch dimension of back; (B=1, O=32*32, I=128, KH=3, KW=3)
            raw_wi : separated tensor along batch dimension of back; (B=1, I=32*32, O=128, KH=4, KW=4)
            '''
            # conv for compare
            escape_NaN = torch.FloatTensor([1e-4])
            if self.use_cuda:
                escape_NaN = escape_NaN.cuda()
            wi = wi[0]  # [L, C, k, k]
            # normalize each background patch to unit norm before matching
            max_wi = torch.sqrt(reduce_sum(torch.pow(wi, 2) + escape_NaN, axis=[1, 2, 3], keepdim=True))
            wi_normed = wi / max_wi
            # xi shape: [1, C, H, W], yi shape: [1, L, H, W]
            xi = same_padding(xi, [self.ksize, self.ksize], [1, 1], [1, 1])  # xi: 1*c*H*W
            yi = F.conv2d(xi, wi_normed, stride=1)  # [1, L, H, W]
            # conv implementation for fuse scores to encourage large patches
            if self.fuse:
                # make all of depth to spatial resolution
                yi = yi.view(1, 1, int_bs[2]*int_bs[3], int_fs[2]*int_fs[3])  # (B=1, I=1, H=32*32, W=32*32)
                yi = same_padding(yi, [k, k], [1, 1], [1, 1])
                yi = F.conv2d(yi, fuse_weight, stride=1)  # (B=1, C=1, H=32*32, W=32*32)
                yi = yi.contiguous().view(1, int_bs[2], int_bs[3], int_fs[2], int_fs[3])  # (B=1, 32, 32, 32, 32)
                yi = yi.permute(0, 2, 1, 4, 3)
                yi = yi.contiguous().view(1, 1, int_bs[2]*int_bs[3], int_fs[2]*int_fs[3])
                yi = same_padding(yi, [k, k], [1, 1], [1, 1])
                yi = F.conv2d(yi, fuse_weight, stride=1)
                yi = yi.contiguous().view(1, int_bs[3], int_bs[2], int_fs[3], int_fs[2])
                yi = yi.permute(0, 2, 1, 4, 3).contiguous()
            yi = yi.view(1, int_bs[2] * int_bs[3], int_fs[2], int_fs[3])  # (B=1, C=32*32, H=32, W=32)
            # softmax to match, with hole patches masked out before and after
            yi = yi * mm
            yi = F.softmax(yi*scale, dim=1)
            yi = yi * mm  # [1, L, H, W]
            offset = torch.argmax(yi, dim=1, keepdim=True)  # 1*1*H*W
            if int_bs != int_fs:
                # Normalize the offset value to match foreground dimension
                times = float(int_fs[2] * int_fs[3]) / float(int_bs[2] * int_bs[3])
                offset = ((offset + 1).float() * times - 1).to(torch.int64)
            offset = torch.cat([offset//int_fs[3], offset%int_fs[3]], dim=1)  # 1*2*H*W
            # deconv for patch pasting
            wi_center = raw_wi[0]
            # yi = F.pad(yi, [0, 1, 0, 1]) # here may need conv_transpose same padding
            yi = F.conv_transpose2d(yi, wi_center, stride=self.rate, padding=1) / 4.  # (B=1, C=128, H=64, W=64)
            y.append(yi)
            offsets.append(offset)
        y = torch.cat(y, dim=0)  # back to the mini-batch
        # BUGFIX: the original computed `y.contiguous().view(raw_int_fs)` and
        # discarded the result; assign it so y is reshaped to f's raw size.
        y = y.contiguous().view(raw_int_fs)
        offsets = torch.cat(offsets, dim=0)
        offsets = offsets.view(int_fs[0], 2, *int_fs[2:])
        # case1: visualize optical flow: minus current position
        h_add = torch.arange(int_fs[2]).view([1, 1, int_fs[2], 1]).expand(int_fs[0], -1, -1, int_fs[3])
        w_add = torch.arange(int_fs[3]).view([1, 1, 1, int_fs[3]]).expand(int_fs[0], -1, int_fs[2], -1)
        ref_coordinate = torch.cat([h_add, w_add], dim=1)
        if self.use_cuda:
            ref_coordinate = ref_coordinate.cuda()
        offsets = offsets - ref_coordinate
        # flow = pt_flow_to_image(offsets)
        flow = torch.from_numpy(flow_to_image(offsets.permute(0, 2, 3, 1).cpu().data.numpy())) / 255.
        flow = flow.permute(0, 3, 1, 2)
        if self.use_cuda:
            flow = flow.cuda()
        # case2: visualize which pixels are attended
        # flow = torch.from_numpy(highlight_flow((offsets * mask.long()).cpu().data.numpy()))
        if self.rate != 1:
            flow = F.interpolate(flow, scale_factor=self.rate*4, mode='nearest')
        return y, flow