import torch from torch import nn from torch.nn import Parameter import torch.nn.functional as F import torchvision.models as models import copy import numpy as np # import lpips #################################################################################################### # spectral normalization layer to decouple the magnitude of a weight tensor #################################################################################################### class VGG16(nn.Module): def __init__(self): super(VGG16, self).__init__() features = models.vgg16(pretrained=True).features self.relu1_1 = torch.nn.Sequential() self.relu1_2 = torch.nn.Sequential() self.relu2_1 = torch.nn.Sequential() self.relu2_2 = torch.nn.Sequential() self.relu3_1 = torch.nn.Sequential() self.relu3_2 = torch.nn.Sequential() self.relu3_3 = torch.nn.Sequential() self.max3 = torch.nn.Sequential() self.relu4_1 = torch.nn.Sequential() self.relu4_2 = torch.nn.Sequential() self.relu4_3 = torch.nn.Sequential() self.relu5_1 = torch.nn.Sequential() self.relu5_2 = torch.nn.Sequential() self.relu5_3 = torch.nn.Sequential() for x in range(2): self.relu1_1.add_module(str(x), features[x]) for x in range(2, 4): self.relu1_2.add_module(str(x), features[x]) for x in range(4, 7): self.relu2_1.add_module(str(x), features[x]) for x in range(7, 9): self.relu2_2.add_module(str(x), features[x]) for x in range(9, 12): self.relu3_1.add_module(str(x), features[x]) for x in range(12, 14): self.relu3_2.add_module(str(x), features[x]) for x in range(14, 16): self.relu3_3.add_module(str(x), features[x]) for x in range(16, 17): self.max3.add_module(str(x), features[x]) for x in range(17, 19): self.relu4_1.add_module(str(x), features[x]) for x in range(19, 21): self.relu4_2.add_module(str(x), features[x]) for x in range(21, 23): self.relu4_3.add_module(str(x), features[x]) for x in range(23, 26): self.relu5_1.add_module(str(x), features[x]) for x in range(26, 28): self.relu5_2.add_module(str(x), features[x]) for x in range(28, 30): self.relu5_3.add_module(str(x), features[x]) # don't need the gradients, just want the features for param in self.parameters(): param.requires_grad = False def forward(self, x): relu1_1 = self.relu1_1(x) relu1_2 = self.relu1_2(relu1_1) relu2_1 = self.relu2_1(relu1_2) relu2_2 = self.relu2_2(relu2_1) relu3_1 = self.relu3_1(relu2_2) relu3_2 = self.relu3_2(relu3_1) relu3_3 = self.relu3_3(relu3_2) max_3 = self.max3(relu3_3) relu4_1 = self.relu4_1(max_3) relu4_2 = self.relu4_2(relu4_1) relu4_3 = self.relu4_3(relu4_2) relu5_1 = self.relu5_1(relu4_3) relu5_2 = self.relu5_1(relu5_1) relu5_3 = self.relu5_1(relu5_2) out = { "relu1_1": relu1_1, "relu1_2": relu1_2, "relu2_1": relu2_1, "relu2_2": relu2_2, "relu3_1": relu3_1, "relu3_2": relu3_2, "relu3_3": relu3_3, "max_3": max_3, "relu4_1": relu4_1, "relu4_2": relu4_2, "relu4_3": relu4_3, "relu5_1": relu5_1, "relu5_2": relu5_2, "relu5_3": relu5_3, } return out def l2normalize(v, eps=1e-12): return v / (v.norm() + eps) class SpectralNorm(nn.Module): """ spectral normalization code and idea originally from Takeru Miyato's work 'Spectral Normalization for GAN' https://github.com/christiancosgrove/pytorch-spectral-normalization-gan """ def __init__(self, module, name='weight', power_iterations=1): super(SpectralNorm, self).__init__() self.module = module self.name = name self.power_iterations = power_iterations if not self._made_params(): self._make_params() def _update_u_v(self): u = getattr(self.module, self.name + "_u") v = getattr(self.module, self.name + "_v") w = getattr(self.module, self.name + "_bar") height = w.data.shape[0] for _ in range(self.power_iterations): v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data)) u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data)) sigma = u.dot(w.view(height, -1).mv(v)) setattr(self.module, self.name, w / sigma.expand_as(w)) def _made_params(self): try: u = getattr(self.module, self.name + "_u") v = getattr(self.module, self.name + "_v") w = getattr(self.module, self.name + "_bar") return True except AttributeError: return False def _make_params(self): w = getattr(self.module, self.name) height = w.data.shape[0] width = w.view(height, -1).data.shape[1] u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False) u.data = l2normalize(u.data) v.data = l2normalize(v.data) w_bar = Parameter(w.data) del self.module._parameters[self.name] self.module.register_parameter(self.name + "_u", u) self.module.register_parameter(self.name + "_v", v) self.module.register_parameter(self.name + "_bar", w_bar) def forward(self, *args): self._update_u_v() return self.module.forward(*args) #################################################################################################### # adversarial loss for different gan mode #################################################################################################### class GANLoss(nn.Module): """Define different GAN objectives. The GANLoss class abstracts away the need to create the target label tensor that has the same size as the input. """ def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0): """ Initialize the GANLoss class. Parameters: gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. target_real_label (bool) - - label for a real image target_fake_label (bool) - - label of a fake image Note: Do not use sigmoid as the last layer of Discriminator. LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss. """ super(GANLoss, self).__init__() self.register_buffer('real_label', torch.tensor(target_real_label)) self.register_buffer('fake_label', torch.tensor(target_fake_label)) self.gan_mode = gan_mode if gan_mode == 'lsgan': self.loss = nn.MSELoss() elif gan_mode == 'vanilla': self.loss = nn.BCEWithLogitsLoss() elif gan_mode == 'hinge': self.loss = nn.ReLU() elif gan_mode == 'wgangp': self.loss = None elif gan_mode == 'wgandiv': self.loss = None else: raise NotImplementedError('gan mode %s not implemented' % gan_mode) def __call__(self, prediction, target_is_real, is_disc=False): """Calculate loss given Discriminator's output and grount truth labels. Parameters: prediction (tensor) - - tpyically the prediction output from a discriminator target_is_real (bool) - - if the ground truth label is for real images or fake images Returns: the calculated loss. """ if self.gan_mode in ['lsgan', 'vanilla']: labels = (self.real_label if target_is_real else self.fake_label).expand_as(prediction).type_as(prediction) loss = self.loss(prediction, labels) elif self.gan_mode in ['hinge', 'wgangp']: if is_disc: if target_is_real: prediction = -prediction if self.gan_mode == 'hinge': loss = self.loss(1 + prediction).mean() elif self.gan_mode == 'wgangp': loss = prediction.mean() else: loss = -prediction.mean() elif self.gan_mode in ['wgandiv']: loss = prediction.mean() return loss class PD_Loss(nn.Module): def __init__(self): super(PD_Loss, self).__init__() self.criterion = torch.nn.L1Loss() def __call__(self, x, y): # Compute features pd_loss = 0.0 pd_loss = self.criterion(x, y) return pd_loss class TV_Loss(nn.Module): def __init__(self): super(TV_Loss, self).__init__() def __call__(self, image, mask, method): hole_mask = 1 - mask b, ch, h, w = hole_mask.shape dilation_conv = nn.Conv2d(ch, ch, 3, padding=1, bias=False).to(hole_mask) torch.nn.init.constant_(dilation_conv.weight, 1.0) with torch.no_grad(): output_mask = dilation_conv(hole_mask) updated_holes = output_mask != 0 dilated_holes = updated_holes.float() colomns_in_Pset = dilated_holes[:, :, :, 1:] * dilated_holes[:, :, :, :-1] rows_in_Pset = dilated_holes[:, :, 1:, :] * dilated_holes[:, :, :-1:, :] if method == "sum": loss = torch.sum( torch.abs(colomns_in_Pset * (image[:, :, :, 1:] - image[:, :, :, :-1])) ) + torch.sum( torch.abs(rows_in_Pset * (image[:, :, :1, :] - image[:, :, -1:, :])) ) else: loss = torch.mean( torch.abs(colomns_in_Pset * (image[:, :, :, 1:] - image[:, :, :, :-1])) ) + torch.mean( torch.abs(rows_in_Pset * (image[:, :, :1, :] - image[:, :, -1:, :])) ) return loss def cal_gradient_penalty(netD, real_data, fake_data, type='mixed', constant=1.0, lambda_gp=10.0): """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028 Arguments: netD (network) -- discriminator network real_data (tensor array) -- real images fake_data (tensor array) -- generated images from the generator type (str) -- if we mix real and fake data or not [real | fake | mixed]. constant (float) -- the constant used in formula ( | |gradient||_2 - constant)^2 lambda_gp (float) -- weight for this loss Returns the gradient penalty loss """ if lambda_gp > 0.0: if type == 'real': # either use real images, fake images, or a linear interpolation of two. interpolatesv = real_data elif type == 'fake': interpolatesv = fake_data elif type == 'mixed': alpha = torch.rand(real_data.shape[0], 1) alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape) alpha = alpha.type_as(real_data) interpolatesv = alpha * real_data + ((1 - alpha) * fake_data) else: raise NotImplementedError('{} not implemented'.format(type)) interpolatesv.requires_grad_(True) disc_interpolates = netD(interpolatesv) gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv, grad_outputs=torch.ones(disc_interpolates.size()).type_as(real_data), create_graph=True, retain_graph=True, only_inputs=True) gradients = gradients[0].view(real_data.size(0), -1) # flat the data gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps return gradient_penalty, gradients else: return 0.0, None def cal_gradient_penalty_div(netD, real_data, fake_data, type='mixed', const_power=6.0, const_kappa=2.0): if type == 'real': # either use real images, fake images, or a linear interpolation of two. interpolatesv = real_data elif type == 'fake': interpolatesv = fake_data elif type == 'mixed': alpha = torch.rand(real_data.shape[0], 1) alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view( *real_data.shape) alpha = alpha.type_as(real_data) interpolatesv = (1 - alpha) * real_data + alpha * fake_data else: raise NotImplementedError('{} not implemented'.format(type)) interpolatesv.requires_grad_(True) disc_interpolates = netD(interpolatesv) gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv, grad_outputs=torch.ones(disc_interpolates.size()).type_as(real_data), create_graph=True, retain_graph=True, only_inputs=True) gradients = gradients[0].view(real_data.size(0), -1) # flat the data gradients_penalty_div = torch.pow((gradients + 1e-16).norm(2, dim=1), const_power).mean() * const_kappa return gradients_penalty_div #################################################################################################### # neural style transform loss from neural_style_tutorial of pytorch #################################################################################################### def ContentLoss(input, target): target = target.detach() loss = F.l1_loss(input, target) return loss def GramMatrix(input): s = input.size() features = input.view(s[0], s[1], s[2]*s[3]) features_t = torch.transpose(features, 1, 2) G = torch.bmm(features, features_t).div(s[1]*s[2]*s[3]) return G def StyleLoss(input, target): target = GramMatrix(target).detach() input = GramMatrix(input) loss = F.l1_loss(input, target) return loss def img_crop(input, size=224): input_cropped = F.upsample(input, size=(size, size), mode='bilinear', align_corners=True) return input_cropped class Normalization(nn.Module): def __init__(self, mean, std): super(Normalization, self).__init__() self.mean = mean.view(-1, 1, 1) self.std = std.view(-1, 1, 1) def forward(self, input): return (input-self.mean) / self.std class get_features(nn.Module): def __init__(self, cnn): super(get_features, self).__init__() vgg = copy.deepcopy(cnn) self.conv1 = nn.Sequential(vgg[0], vgg[1], vgg[2], vgg[3], vgg[4]) self.conv2 = nn.Sequential(vgg[5], vgg[6], vgg[7], vgg[8], vgg[9]) self.conv3 = nn.Sequential(vgg[10], vgg[11], vgg[12], vgg[13], vgg[14], vgg[15], vgg[16]) self.conv4 = nn.Sequential(vgg[17], vgg[18], vgg[19], vgg[20], vgg[21], vgg[22], vgg[23]) self.conv5 = nn.Sequential(vgg[24], vgg[25], vgg[26], vgg[27], vgg[28], vgg[29], vgg[30]) def forward(self, input, layers): input = img_crop(input) output = [] for i in range(1, layers): layer = getattr(self, 'conv'+str(i)) input = layer(input) output.append(input) return output class GroupNorm(nn.Module): def __init__(self, num_features, num_groups=32, eps=1e-5): super(GroupNorm, self).__init__() self.weight = nn.Parameter(torch.ones(1, num_features, 1, 1)) self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1)) self.num_groups = num_groups self.eps = eps def forward(self, x): N, C, H, W = x.size() G = self.num_groups # assert C % G == 0 x = x.view(N,G,-1) mean = x.mean(-1, keepdim=True) var = x.var(-1, keepdim=True) x = (x-mean) / (var+self.eps).sqrt() x = x.view(N, C, H, W) return x * self.weight + self.bias class FullAttention(nn.Module): """ Layer implements my version of the self-attention module it is mostly same as self attention, but generalizes to (k x k) convolutions instead of (1 x 1) args: in_channels: number of input channels out_channels: number of output channels activation: activation function to be applied (default: lrelu(0.2)) kernel_size: kernel size for convolution (default: (1 x 1)) transpose_conv: boolean denoting whether to use convolutions or transpose convolutions squeeze_factor: squeeze factor for query and keys (default: 8) stride: stride for the convolutions (default: 1) padding: padding for the applied convolutions (default: 1) bias: whether to apply bias or not (default: True) """ def __init__(self, in_channels, out_channels, activation=nn.LeakyReLU(0.2), kernel_size=(3, 3), transpose_conv=False, use_spectral_norm=True, use_batch_norm=True, squeeze_factor=8, stride=1, padding=1, bias=True): """ constructor for the layer """ from torch.nn import Conv2d, Parameter, \ Softmax, ConvTranspose2d, BatchNorm2d, InstanceNorm2d # base constructor call super().__init__() # state of the layer self.activation = activation self.gamma = Parameter(torch.zeros(1)) self.in_channels = in_channels self.out_channels = out_channels self.squeezed_channels = in_channels // squeeze_factor self.use_batch_norm = use_batch_norm # Modules required for computations if transpose_conv: self.query_conv = ConvTranspose2d( # query convolution in_channels=in_channels, out_channels=in_channels // squeeze_factor, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias ) self.key_conv = ConvTranspose2d( in_channels=in_channels, out_channels=in_channels // squeeze_factor, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias ) self.value_conv = ConvTranspose2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias ) self.residual_conv = ConvTranspose2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias ) if not use_spectral_norm else SpectralNorm( ConvTranspose2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias ) ) else: self.query_conv = Conv2d( # query convolution in_channels=in_channels, out_channels=in_channels // squeeze_factor, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias ) self.key_conv = Conv2d( in_channels=in_channels, out_channels=in_channels // squeeze_factor, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias ) self.value_conv = Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias ) self.residual_conv = Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias ) if not use_spectral_norm else SpectralNorm( Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias ) ) # softmax module for applying attention self.softmax = Softmax(dim=-1) # self.batch_norm = BatchNorm2d(out_channels) self.batch_norm = InstanceNorm2d(out_channels) def forward(self, x): """ forward computations of the layer :param x: input feature maps (B x C x H x W) :return: out: self attention value + input feature (B x O x H x W) attention: attention map (B x C x H x W) """ # extract the batch size of the input tensor m_batchsize, _, _, _ = x.size() # create the query projection proj_query = self.query_conv(x).view( m_batchsize, self.squeezed_channels, -1).permute(0, 2, 1) # B x (N) x C # create the key projection proj_key = self.key_conv(x).view( m_batchsize, self.squeezed_channels, -1) # B x C x (N) # calculate the attention maps energy = torch.bmm(proj_query, proj_key) # energy attention = self.softmax(energy) # attention (B x (N) x (N)) # create the value projection proj_value = self.value_conv(x).view( m_batchsize, self.out_channels, -1) # B X C X N # calculate the output out = torch.bmm(proj_value, attention.permute(0, 2, 1)) # calculate the residual output res_out = self.residual_conv(x) out = out.view(m_batchsize, self.out_channels, res_out.shape[-2], res_out.shape[-1]) attention = attention.view(m_batchsize, -1, res_out.shape[-2], res_out.shape[-1]) if self.use_batch_norm: res_out = self.batch_norm(res_out) if self.activation is not None: out = self.activation(out) res_out = self.activation(res_out) # apply the residual connections out = (self.gamma * out) + ((1 - self.gamma) * res_out) return out, attention class Diversityloss(nn.Module): def __init__(self): super(Diversityloss, self).__init__() self.vgg = VGG16().cuda() self.criterion = nn.L1Loss() self.weights = [1.0, 1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0] def forward(self, x, y): x_vgg, y_vgg = self.vgg(x), self.vgg(y) diversity_loss = 0.0 diversity_loss += self.weights[4] * self.criterion( x_vgg["relu4_1"], y_vgg["relu4_1"] ) return diversity_loss class PerceptualLoss(nn.Module): r""" Perceptual loss, VGG-based https://arxiv.org/abs/1603.08155 https://github.com/dxyang/StyleTransfer/blob/master/utils.py """ def __init__(self, weights=[1.0 / 2, 1.0 / 4, 1.0 / 8, 1.0 / 16, 1.0]): super(PerceptualLoss, self).__init__() self.add_module("vgg", VGG16().cuda()) self.criterion = torch.nn.L1Loss() self.weights = weights def __call__(self, x, y): # Compute features x_vgg, y_vgg = self.vgg(x), self.vgg(y) content_loss = 0.0 content_loss += self.weights[0] * self.criterion( x_vgg["relu1_1"], y_vgg["relu1_1"] ) content_loss += self.weights[1] * self.criterion( x_vgg["relu2_1"], y_vgg["relu2_1"] ) content_loss += self.weights[2] * self.criterion( x_vgg["relu3_1"], y_vgg["relu3_1"] ) content_loss += self.weights[3] * self.criterion( x_vgg["relu4_1"], y_vgg["relu4_1"] ) content_loss += self.weights[4] * self.criterion( x_vgg["relu5_1"], y_vgg["relu5_1"] ) return content_loss class StyleLoss(nn.Module): r""" Perceptual loss, VGG-based https://arxiv.org/abs/1603.08155 https://github.com/dxyang/StyleTransfer/blob/master/utils.py """ def __init__(self): super(StyleLoss, self).__init__() self.add_module("vgg", VGG16().cuda()) self.criterion = torch.nn.L1Loss() def compute_gram(self, x): b, ch, h, w = x.size() f = x.view(b, ch, w * h) f_T = f.transpose(1, 2) G = f.bmm(f_T) / (h * w * ch) return G def __call__(self, x, y): # Compute features x_vgg, y_vgg = self.vgg(x), self.vgg(y) # Compute loss style_loss = 0.0 style_loss += self.criterion( self.compute_gram(x_vgg["relu2_2"]), self.compute_gram(y_vgg["relu2_2"]) ) style_loss += self.criterion( self.compute_gram(x_vgg["relu3_3"]), self.compute_gram(y_vgg["relu3_3"]) ) style_loss += self.criterion( self.compute_gram(x_vgg["relu4_3"]), self.compute_gram(y_vgg["relu4_3"]) ) return style_loss def reduce_mean(x, axis=None, keepdim=False): if not axis: axis = range(len(x.shape)) for i in sorted(axis, reverse=True): x = torch.mean(x, dim=i, keepdim=keepdim) return x def same_padding(images, ksizes, strides, rates): assert len(images.size()) == 4 batch_size, channel, rows, cols = images.size() out_rows = (rows + strides[0] - 1) // strides[0] out_cols = (cols + strides[1] - 1) // strides[1] effective_k_row = (ksizes[0] - 1) * rates[0] + 1 effective_k_col = (ksizes[1] - 1) * rates[1] + 1 padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows) padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols) # Pad the input padding_top = int(padding_rows / 2.) padding_left = int(padding_cols / 2.) padding_bottom = padding_rows - padding_top padding_right = padding_cols - padding_left paddings = (padding_left, padding_right, padding_top, padding_bottom) images = torch.nn.ZeroPad2d(paddings)(images) return images def extract_image_patches(images, ksizes, strides, rates, padding='same'): """ Extract patches from images and put them in the C output dimension. :param padding: :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for each dimension of images :param strides: [stride_rows, stride_cols] :param rates: [dilation_rows, dilation_cols] :return: A Tensor """ assert len(images.size()) == 4 assert padding in ['same', 'valid'] batch_size, channel, height, width = images.size() if padding == 'same': images = same_padding(images, ksizes, strides, rates) elif padding == 'valid': pass else: raise NotImplementedError('Unsupported padding type: {}.\ Only "same" or "valid" are supported.'.format(padding)) unfold = torch.nn.Unfold(kernel_size=ksizes, dilation=rates, padding=0, stride=strides) patches = unfold(images) return patches # [N, C*k*k, L], L is the total number of such blocks def flow_to_image(flow): """Transfer flow map to image. Part of code forked from flownet. """ out = [] maxu = -999. maxv = -999. minu = 999. minv = 999. maxrad = -1 for i in range(flow.shape[0]): u = flow[i, :, :, 0] v = flow[i, :, :, 1] idxunknow = (abs(u) > 1e7) | (abs(v) > 1e7) u[idxunknow] = 0 v[idxunknow] = 0 maxu = max(maxu, np.max(u)) minu = min(minu, np.min(u)) maxv = max(maxv, np.max(v)) minv = min(minv, np.min(v)) rad = np.sqrt(u ** 2 + v ** 2) maxrad = max(maxrad, np.max(rad)) u = u / (maxrad + np.finfo(float).eps) v = v / (maxrad + np.finfo(float).eps) img = compute_color(u, v) out.append(img) return np.float32(np.uint8(out)) def make_color_wheel(): RY, YG, GC, CB, BM, MR = (15, 6, 4, 11, 13, 6) ncols = RY + YG + GC + CB + BM + MR colorwheel = np.zeros([ncols, 3]) col = 0 # RY colorwheel[0:RY, 0] = 255 colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY)) col += RY # YG colorwheel[col:col + YG, 0] = 255 - np.transpose(np.floor(255 * np.arange(0, YG) / YG)) colorwheel[col:col + YG, 1] = 255 col += YG # GC colorwheel[col:col + GC, 1] = 255 colorwheel[col:col + GC, 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC)) col += GC # CB colorwheel[col:col + CB, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, CB) / CB)) colorwheel[col:col + CB, 2] = 255 col += CB # BM colorwheel[col:col + BM, 2] = 255 colorwheel[col:col + BM, 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM)) col += + BM # MR colorwheel[col:col + MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR)) colorwheel[col:col + MR, 0] = 255 return colorwheel def compute_color(u, v): h, w = u.shape img = np.zeros([h, w, 3]) nanIdx = np.isnan(u) | np.isnan(v) u[nanIdx] = 0 v[nanIdx] = 0 # colorwheel = COLORWHEEL colorwheel = make_color_wheel() ncols = np.size(colorwheel, 0) rad = np.sqrt(u ** 2 + v ** 2) a = np.arctan2(-v, -u) / np.pi fk = (a + 1) / 2 * (ncols - 1) + 1 k0 = np.floor(fk).astype(int) k1 = k0 + 1 k1[k1 == ncols + 1] = 1 f = fk - k0 for i in range(np.size(colorwheel, 1)): tmp = colorwheel[:, i] col0 = tmp[k0 - 1] / 255 col1 = tmp[k1 - 1] / 255 col = (1 - f) * col0 + f * col1 idx = rad <= 1 col[idx] = 1 - rad[idx] * (1 - col[idx]) notidx = np.logical_not(idx) col[notidx] *= 0.75 img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx))) return img def reduce_sum(x, axis=None, keepdim=False): if not axis: axis = range(len(x.shape)) for i in sorted(axis, reverse=True): x = torch.sum(x, dim=i, keepdim=keepdim) return x class ContextualAttention(nn.Module): def __init__(self, ksize=3, stride=1, rate=1, fuse_k=3, softmax_scale=10, fuse=False, use_cuda=False, device_ids=None): super(ContextualAttention, self).__init__() self.ksize = ksize self.stride = stride self.rate = rate self.fuse_k = fuse_k self.softmax_scale = softmax_scale self.fuse = fuse self.use_cuda = use_cuda self.device_ids = device_ids def forward(self, f, b, mask=None): """ Contextual attention layer implementation. Contextual attention is first introduced in publication: Generative Image Inpainting with Contextual Attention, Yu et al. Args: f: Input feature to match (foreground). b: Input feature for match (background). mask: Input mask for b, indicating patches not available. ksize: Kernel size for contextual attention. stride: Stride for extracting patches from b. rate: Dilation for matching. softmax_scale: Scaled softmax for attention. Returns: torch.tensor: output """ # get shapes raw_int_fs = list(f.size()) # b*c*h*w raw_int_bs = list(b.size()) # b*c*h*w # extract patches from background with stride and rate kernel = 2 * self.rate # raw_w is extracted for reconstruction raw_w = extract_image_patches(b, ksizes=[kernel, kernel], strides=[self.rate*self.stride, self.rate*self.stride], rates=[1, 1], padding='same') # [N, C*k*k, L] # raw_shape: [N, C, k, k, L] raw_w = raw_w.view(raw_int_bs[0], raw_int_bs[1], kernel, kernel, -1) raw_w = raw_w.permute(0, 4, 1, 2, 3) # raw_shape: [N, L, C, k, k] raw_w_groups = torch.split(raw_w, 1, dim=0) # downscaling foreground option: downscaling both foreground and # background for matching and use original background for reconstruction. f = F.interpolate(f, scale_factor=1./self.rate, mode='nearest') b = F.interpolate(b, scale_factor=1./self.rate, mode='nearest') int_fs = list(f.size()) # b*c*h*w int_bs = list(b.size()) f_groups = torch.split(f, 1, dim=0) # split tensors along the batch dimension # w shape: [N, C*k*k, L] w = extract_image_patches(b, ksizes=[self.ksize, self.ksize], strides=[self.stride, self.stride], rates=[1, 1], padding='same') # w shape: [N, C, k, k, L] w = w.view(int_bs[0], int_bs[1], self.ksize, self.ksize, -1) w = w.permute(0, 4, 1, 2, 3) # w shape: [N, L, C, k, k] w_groups = torch.split(w, 1, dim=0) # process mask if mask is None: mask = torch.zeros([int_bs[0], 1, int_bs[2], int_bs[3]]) if self.use_cuda: mask = mask.cuda() else: mask = F.interpolate(mask, scale_factor=1./(4*self.rate), mode='nearest') int_ms = list(mask.size()) # m shape: [N, C*k*k, L] m = extract_image_patches(mask, ksizes=[self.ksize, self.ksize], strides=[self.stride, self.stride], rates=[1, 1], padding='same') # m shape: [N, C, k, k, L] m = m.view(int_ms[0], int_ms[1], self.ksize, self.ksize, -1) m = m.permute(0, 4, 1, 2, 3) # m shape: [N, L, C, k, k] m = m[0] # m shape: [L, C, k, k] # mm shape: [L, 1, 1, 1] mm = (reduce_mean(m, axis=[1, 2, 3], keepdim=True) == 0.).to(torch.float32) mm = mm.permute(1, 0, 2, 3) # mm shape: [1, L, 1, 1] y = [] offsets = [] k = self.fuse_k scale = self.softmax_scale # to fit the PyTorch tensor image value range fuse_weight = torch.eye(k).view(1, 1, k, k) # 1*1*k*k if self.use_cuda: fuse_weight = fuse_weight.cuda() for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups): ''' O => output channel as a conv filter I => input channel as a conv filter xi : separated tensor along batch dimension of front; (B=1, C=128, H=32, W=32) wi : separated patch tensor along batch dimension of back; (B=1, O=32*32, I=128, KH=3, KW=3) raw_wi : separated tensor along batch dimension of back; (B=1, I=32*32, O=128, KH=4, KW=4) ''' # conv for compare escape_NaN = torch.FloatTensor([1e-4]) if self.use_cuda: escape_NaN = escape_NaN.cuda() wi = wi[0] # [L, C, k, k] max_wi = torch.sqrt(reduce_sum(torch.pow(wi, 2) + escape_NaN, axis=[1, 2, 3], keepdim=True)) wi_normed = wi / max_wi # xi shape: [1, C, H, W], yi shape: [1, L, H, W] xi = same_padding(xi, [self.ksize, self.ksize], [1, 1], [1, 1]) # xi: 1*c*H*W yi = F.conv2d(xi, wi_normed, stride=1) # [1, L, H, W] # conv implementation for fuse scores to encourage large patches if self.fuse: # make all of depth to spatial resolution yi = yi.view(1, 1, int_bs[2]*int_bs[3], int_fs[2]*int_fs[3]) # (B=1, I=1, H=32*32, W=32*32) yi = same_padding(yi, [k, k], [1, 1], [1, 1]) yi = F.conv2d(yi, fuse_weight, stride=1) # (B=1, C=1, H=32*32, W=32*32) yi = yi.contiguous().view(1, int_bs[2], int_bs[3], int_fs[2], int_fs[3]) # (B=1, 32, 32, 32, 32) yi = yi.permute(0, 2, 1, 4, 3) yi = yi.contiguous().view(1, 1, int_bs[2]*int_bs[3], int_fs[2]*int_fs[3]) yi = same_padding(yi, [k, k], [1, 1], [1, 1]) yi = F.conv2d(yi, fuse_weight, stride=1) yi = yi.contiguous().view(1, int_bs[3], int_bs[2], int_fs[3], int_fs[2]) yi = yi.permute(0, 2, 1, 4, 3).contiguous() yi = yi.view(1, int_bs[2] * int_bs[3], int_fs[2], int_fs[3]) # (B=1, C=32*32, H=32, W=32) # softmax to match yi = yi * mm yi = F.softmax(yi*scale, dim=1) yi = yi * mm # [1, L, H, W] offset = torch.argmax(yi, dim=1, keepdim=True) # 1*1*H*W if int_bs != int_fs: # Normalize the offset value to match foreground dimension times = float(int_fs[2] * int_fs[3]) / float(int_bs[2] * int_bs[3]) offset = ((offset + 1).float() * times - 1).to(torch.int64) offset = torch.cat([offset//int_fs[3], offset%int_fs[3]], dim=1) # 1*2*H*W # deconv for patch pasting wi_center = raw_wi[0] # yi = F.pad(yi, [0, 1, 0, 1]) # here may need conv_transpose same padding yi = F.conv_transpose2d(yi, wi_center, stride=self.rate, padding=1) / 4. # (B=1, C=128, H=64, W=64) y.append(yi) offsets.append(offset) y = torch.cat(y, dim=0) # back to the mini-batch y.contiguous().view(raw_int_fs) offsets = torch.cat(offsets, dim=0) offsets = offsets.view(int_fs[0], 2, *int_fs[2:]) # case1: visualize optical flow: minus current position h_add = torch.arange(int_fs[2]).view([1, 1, int_fs[2], 1]).expand(int_fs[0], -1, -1, int_fs[3]) w_add = torch.arange(int_fs[3]).view([1, 1, 1, int_fs[3]]).expand(int_fs[0], -1, int_fs[2], -1) ref_coordinate = torch.cat([h_add, w_add], dim=1) if self.use_cuda: ref_coordinate = ref_coordinate.cuda() offsets = offsets - ref_coordinate # flow = pt_flow_to_image(offsets) flow = torch.from_numpy(flow_to_image(offsets.permute(0, 2, 3, 1).cpu().data.numpy())) / 255. flow = flow.permute(0, 3, 1, 2) if self.use_cuda: flow = flow.cuda() # case2: visualize which pixels are attended # flow = torch.from_numpy(highlight_flow((offsets * mask.long()).cpu().data.numpy())) if self.rate != 1: flow = F.interpolate(flow, scale_factor=self.rate*4, mode='nearest') return y, flow