You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
578 lines
24 KiB
578 lines
24 KiB
import math
|
|
|
|
from .base_function import *
|
|
from .external_function import SpectralNorm, FullAttention
|
|
import torch.nn.functional as F
|
|
|
|
|
|
##############################################################################################################
|
|
# Network function
|
|
##############################################################################################################
|
|
def define_e(input_nc=3, ngf=64, z_nc=512, img_f=512, L=6, layers=5, norm='none', activation='ReLU', use_spect=True,
|
|
use_coord=False, init_type='orthogonal', gpu_ids=[]):
|
|
|
|
net = ResEncoder(input_nc, ngf, z_nc, img_f, L, layers, norm, activation, use_spect, use_coord)
|
|
|
|
return init_net(net, init_type, activation, gpu_ids)
|
|
|
|
|
|
def define_g(output_nc=3, ngf=64, z_nc=512, img_f=512, L=1, layers=5, norm='instance', activation='ReLU', output_scale=1,
|
|
use_spect=True, use_coord=False, use_attn=True, init_type='orthogonal', gpu_ids=[]):
|
|
|
|
net = ResGenerator(output_nc, ngf, z_nc, img_f, L, layers, norm, activation, output_scale, use_spect, use_coord, use_attn)
|
|
|
|
return init_net(net, init_type, activation, gpu_ids)
|
|
|
|
|
|
def define_pd_g(output_nc=3, ngf=64, z_nc=512, img_f=512, L=1, layers=5, norm='instance', activation='ReLU', output_scale=1,
|
|
use_spect=True, use_coord=False, use_attn=True, use_gated=False, init_type='orthogonal', gpu_ids=[]):
|
|
|
|
net = PD_Generator(output_nc, ngf, z_nc, img_f, L, layers, norm, activation, output_scale, use_spect, use_coord, use_attn=use_attn, use_gated=use_gated)
|
|
|
|
return init_net(net, init_type, activation, gpu_ids)
|
|
|
|
|
|
def define_d(input_nc=3, ndf=64, img_f=512, layers=6, norm='none', activation='LeakyReLU', use_spect=True, use_coord=False,
|
|
use_attn=True, model_type='ResDis', init_type='orthogonal', gpu_ids=[]):
|
|
|
|
if model_type == 'ResDis':
|
|
net = ResDiscriminator(input_nc, ndf, img_f, layers, norm, activation, use_spect, use_coord, use_attn)
|
|
elif model_type == 'PatchDis':
|
|
net = PatchDiscriminator(input_nc, ndf, img_f, layers, norm, activation, use_spect, use_coord, use_attn)
|
|
|
|
return init_net(net, init_type, activation, gpu_ids)
|
|
|
|
|
|
def define_pd_d(input_nc=3, ndf=64, img_f=512, layers=6, norm='none', activation='LeakyReLU', use_spect=True, use_coord=False,
|
|
use_attn=True, use_gated=False, model_type='ResDis', init_type='orthogonal', gpu_ids=[]):
|
|
|
|
if model_type == 'ResDis':
|
|
net = PD_Discriminator(input_nc, ndf, img_f, layers, norm, activation, use_spect, use_coord, use_attn, use_gated=use_gated)
|
|
elif model_type == 'PatchDis':
|
|
net = PatchDiscriminator(input_nc, ndf, img_f, layers, norm, activation, use_spect, use_coord)
|
|
else:
|
|
net = MultiPatchCoordDiscriminator(input_nc, ndf, img_f, layers, norm, activation, use_spect, use_coord)
|
|
|
|
return init_net(net, init_type, activation, gpu_ids)
|
|
|
|
#############################################################################################################
|
|
# Network structure
|
|
#############################################################################################################
|
|
class ResEncoder(nn.Module):
|
|
"""
|
|
ResNet Encoder Network
|
|
:param input_nc: number of channels in input
|
|
:param ngf: base filter channel
|
|
:param z_nc: latent channels
|
|
:param img_f: the largest feature channels
|
|
:param L: Number of refinements of density
|
|
:param layers: down and up sample layers
|
|
:param norm: normalization function 'instance, batch, group'
|
|
:param activation: activation function 'ReLU, SELU, LeakyReLU, PReLU'
|
|
"""
|
|
def __init__(self, input_nc=3, ngf=64, z_nc=128, img_f=1024, L=6, layers=6, norm='none', activation='ReLU',
|
|
use_spect=True, use_coord=False):
|
|
super(ResEncoder, self).__init__()
|
|
|
|
self.layers = layers
|
|
self.z_nc = z_nc
|
|
self.L = L
|
|
|
|
norm_layer = get_norm_layer(norm_type=norm)
|
|
nonlinearity = get_nonlinearity_layer(activation_type=activation)
|
|
# encoder part
|
|
self.block0 = ResBlockEncoderOptimized(input_nc, ngf, norm_layer, nonlinearity, use_spect, use_coord)
|
|
|
|
mult = 1
|
|
for i in range(layers-1):
|
|
mult_prev = mult
|
|
mult = min(2 ** (i + 1), img_f // ngf)
|
|
block = ResBlock(ngf * mult_prev, ngf * mult, ngf * mult_prev, norm_layer, nonlinearity, 'down', use_spect, use_coord)
|
|
setattr(self, 'encoder' + str(i), block)
|
|
|
|
# inference part
|
|
for i in range(self.L):
|
|
block = ResBlock(ngf * mult, ngf * mult, ngf *mult, norm_layer, nonlinearity, 'none', use_spect, use_coord)
|
|
setattr(self, 'infer_prior' + str(i), block)
|
|
|
|
self.posterior = ResBlock(ngf * mult, 2*z_nc, ngf * mult, norm_layer, nonlinearity, 'none', use_spect, use_coord)
|
|
self.prior = ResBlock(ngf * mult, 2*z_nc, ngf * mult, norm_layer, nonlinearity, 'none', use_spect, use_coord)
|
|
|
|
def forward(self, img_m, img_c=None):
|
|
"""
|
|
:param img_m: image with mask regions I_m
|
|
:param img_c: complement of I_m, the mask regions
|
|
:return distribution: distribution of mask regions, for training we have two paths, testing one path
|
|
:return feature: the conditional feature f_m, and the previous f_pre for auto context attention
|
|
"""
|
|
|
|
if type(img_c) != type(None):
|
|
img = torch.cat([img_m, img_c], dim=0)
|
|
else:
|
|
img = img_m
|
|
|
|
# encoder part
|
|
out = self.block0(img)
|
|
feature = [out]
|
|
for i in range(self.layers-1):
|
|
model = getattr(self, 'encoder' + str(i))
|
|
out = model(out)
|
|
feature.append(out)
|
|
|
|
# infer part
|
|
# during the training, we have two paths, during the testing, we only have one paths
|
|
if type(img_c) != type(None):
|
|
distribution = self.two_paths(out)
|
|
return distribution, feature
|
|
else:
|
|
distribution = self.one_path(out)
|
|
return distribution, feature
|
|
|
|
def one_path(self, f_in):
|
|
"""one path for baseline training or testing"""
|
|
f_m = f_in
|
|
distribution = []
|
|
|
|
# infer state
|
|
for i in range(self.L):
|
|
infer_prior = getattr(self, 'infer_prior' + str(i))
|
|
f_m = infer_prior(f_m)
|
|
|
|
# get distribution
|
|
o = self.prior(f_m)
|
|
q_mu, q_std = torch.split(o, self.z_nc, dim=1)
|
|
distribution.append([q_mu, F.softplus(q_std)])
|
|
|
|
return distribution
|
|
|
|
def two_paths(self, f_in):
|
|
"""two paths for the training"""
|
|
f_m, f_c = f_in.chunk(2)
|
|
distributions = []
|
|
|
|
# get distribution
|
|
o = self.posterior(f_c)
|
|
p_mu, p_std = torch.split(o, self.z_nc, dim=1)
|
|
distribution = self.one_path(f_m)
|
|
distributions.append([p_mu, F.softplus(p_std), distribution[0][0], distribution[0][1]])
|
|
|
|
return distributions
|
|
|
|
|
|
class ResGenerator(nn.Module):
|
|
"""
|
|
ResNet Generator Network
|
|
:param output_nc: number of channels in output
|
|
:param ngf: base filter channel
|
|
:param z_nc: latent channels
|
|
:param img_f: the largest feature channels
|
|
:param L: Number of refinements of density
|
|
:param layers: down and up sample layers
|
|
:param norm: normalization function 'instance, batch, group'
|
|
:param activation: activation function 'ReLU, SELU, LeakyReLU, PReLU'
|
|
:param output_scale: Different output scales
|
|
"""
|
|
def __init__(self, output_nc=3, ngf=64, z_nc=128, img_f=1024, L=1, layers=6, norm='batch', activation='ReLU',
|
|
output_scale=1, use_spect=True, use_coord=False, use_attn=True):
|
|
super(ResGenerator, self).__init__()
|
|
|
|
self.layers = layers
|
|
self.L = L
|
|
self.output_scale = output_scale
|
|
self.use_attn = use_attn
|
|
|
|
norm_layer = get_norm_layer(norm_type=norm)
|
|
nonlinearity = get_nonlinearity_layer(activation_type=activation)
|
|
# latent z to feature
|
|
mult = min(2 ** (layers-1), img_f // ngf)
|
|
self.generator = ResBlock(z_nc, ngf * mult, ngf * mult, None, nonlinearity, 'none', use_spect, use_coord)
|
|
|
|
# transform
|
|
for i in range(self.L):
|
|
block = ResBlock(ngf * mult, ngf * mult, ngf * mult, None, nonlinearity, 'none', use_spect, use_coord)
|
|
setattr(self, 'generator' + str(i), block)
|
|
|
|
# decoder part
|
|
for i in range(layers):
|
|
mult_prev = mult
|
|
mult = min(2 ** (layers - i - 1), img_f // ngf)
|
|
if i > layers - output_scale:
|
|
# upconv = ResBlock(ngf * mult_prev + output_nc, ngf * mult, ngf * mult, norm_layer, nonlinearity, 'up', True)
|
|
upconv = ResBlockDecoder(ngf * mult_prev + output_nc, ngf * mult, ngf * mult, norm_layer, nonlinearity, use_spect, use_coord)
|
|
else:
|
|
# upconv = ResBlock(ngf * mult_prev, ngf * mult, ngf * mult, norm_layer, nonlinearity, 'up', True)
|
|
upconv = ResBlockDecoder(ngf * mult_prev, ngf * mult, ngf * mult, norm_layer, nonlinearity, use_spect, use_coord)
|
|
setattr(self, 'decoder' + str(i), upconv)
|
|
# output part
|
|
if i > layers - output_scale - 1:
|
|
outconv = Output(ngf * mult, output_nc, 3, None, nonlinearity, use_spect, use_coord)
|
|
setattr(self, 'out' + str(i), outconv)
|
|
# short+long term attention part
|
|
if i == 1 and use_attn:
|
|
# simam = SimAM()
|
|
# setattr(self, 'simam' + str(i), simam)
|
|
attn = Auto_Attn(ngf*mult, None)
|
|
setattr(self, 'attn' + str(i), attn)
|
|
|
|
def forward(self, z, f_m=None, f_e=None, mask=None):
|
|
"""
|
|
ResNet Generator Network
|
|
:param z: latent vector
|
|
:param f_m: feature of valid regions for conditional VAG-GAN
|
|
:param f_e: previous encoder feature for short+long term attention layer
|
|
:return results: different scale generation outputs
|
|
"""
|
|
|
|
f = self.generator(z)
|
|
for i in range(self.L):
|
|
generator = getattr(self, 'generator' + str(i))
|
|
f = generator(f)
|
|
|
|
# the features come from mask regions and valid regions, we directly add them together
|
|
out = f_m + f
|
|
results= []
|
|
attn = 0
|
|
for i in range(self.layers):
|
|
model = getattr(self, 'decoder' + str(i))
|
|
out = model(out)
|
|
if i == 1 and self.use_attn:
|
|
# auto attention
|
|
# model = getattr(self, 'simam' + str(i))
|
|
# out = model(out)
|
|
model = getattr(self, 'attn' + str(i))
|
|
out, attn = model(out, f_e, mask)
|
|
if i > self.layers - self.output_scale - 1:
|
|
model = getattr(self, 'out' + str(i))
|
|
output = model(out)
|
|
results.append(output)
|
|
out = torch.cat([out, output], dim=1)
|
|
|
|
return results, attn
|
|
|
|
|
|
class ResDiscriminator(nn.Module):
|
|
"""
|
|
ResNet Discriminator Network
|
|
:param input_nc: number of channels in input
|
|
:param ndf: base filter channel
|
|
:param layers: down and up sample layers
|
|
:param img_f: the largest feature channels
|
|
:param norm: normalization function 'instance, batch, group'
|
|
:param activation: activation function 'ReLU, SELU, LeakyReLU, PReLU'
|
|
"""
|
|
def __init__(self, input_nc=3, ndf=64, img_f=1024, layers=6, norm='none', activation='LeakyReLU', use_spect=True,
|
|
use_coord=False, use_attn=True):
|
|
super(ResDiscriminator, self).__init__()
|
|
|
|
self.layers = layers
|
|
self.use_attn = use_attn
|
|
|
|
norm_layer = get_norm_layer(norm_type=norm)
|
|
nonlinearity = get_nonlinearity_layer(activation_type=activation)
|
|
self.nonlinearity = nonlinearity
|
|
|
|
# encoder part
|
|
self.block0 = ResBlockEncoderOptimized(input_nc, ndf,norm_layer, nonlinearity, use_spect, use_coord)
|
|
|
|
mult = 1
|
|
for i in range(layers - 1):
|
|
mult_prev = mult
|
|
mult = min(2 ** (i + 1), img_f // ndf)
|
|
# self-attention
|
|
if i == 2 and use_attn:
|
|
# simam = SimAM()
|
|
# setattr(self, 'simam' + str(i), simam)
|
|
attn = Auto_Attn(ndf * mult_prev, norm_layer)
|
|
setattr(self, 'attn' + str(i), attn)
|
|
block = ResBlock(ndf * mult_prev, ndf * mult, ndf * mult_prev, norm_layer, nonlinearity, 'down', use_spect, use_coord)
|
|
setattr(self, 'encoder' + str(i), block)
|
|
|
|
self.block1 = ResBlock(ndf * mult, ndf * mult, ndf * mult, norm_layer, nonlinearity, 'none', use_spect, use_coord)
|
|
self.conv = SpectralNorm(nn.Conv2d(ndf * mult, 1, 3))
|
|
|
|
def forward(self, x):
|
|
out = self.block0(x)
|
|
for i in range(self.layers - 1):
|
|
if i == 2 and self.use_attn:
|
|
# simam = getattr(self, 'simam' + str(i))
|
|
# out = simam(out)
|
|
attn = getattr(self, 'attn' + str(i))
|
|
out, attention = attn(out)
|
|
model = getattr(self, 'encoder' + str(i))
|
|
out = model(out)
|
|
out = self.block1(out)
|
|
out = self.conv(self.nonlinearity(out))
|
|
return out
|
|
|
|
|
|
class PatchDiscriminator(nn.Module):
|
|
"""
|
|
Patch Discriminator Network for Local 70*70 fake/real
|
|
:param input_nc: number of channels in input
|
|
:param ndf: base filter channel
|
|
:param img_f: the largest channel for the model
|
|
:param layers: down sample layers
|
|
:param norm: normalization function 'instance, batch, group'
|
|
:param activation: activation function 'ReLU, SELU, LeakyReLU, PReLU'
|
|
:param use_spect: use spectral normalization or not
|
|
:param use_coord: use CoordConv or nor
|
|
:param use_attn: use short+long attention or not
|
|
"""
|
|
def __init__(self, input_nc=3, ndf=64, img_f=512, layers=3, norm='batch', activation='LeakyReLU', use_spect=True,
|
|
use_coord=False):
|
|
super(PatchDiscriminator, self).__init__()
|
|
|
|
norm_layer = get_norm_layer(norm_type=norm)
|
|
nonlinearity = get_nonlinearity_layer(activation_type=activation)
|
|
|
|
kwargs = {'kernel_size': 4, 'stride': 2, 'padding': 1, 'bias': False}
|
|
sequence = [
|
|
coord_conv(input_nc, ndf, use_spect, use_coord, **kwargs),
|
|
nonlinearity,
|
|
]
|
|
|
|
mult = 1
|
|
for i in range(1, layers):
|
|
mult_prev = mult
|
|
mult = min(2 ** i, img_f // ndf)
|
|
sequence +=[
|
|
coord_conv(ndf * mult_prev, ndf * mult, use_spect, use_coord, **kwargs),
|
|
nonlinearity,
|
|
]
|
|
|
|
mult_prev = mult
|
|
mult = min(2 ** i, img_f // ndf)
|
|
kwargs = {'kernel_size': 4, 'stride': 1, 'padding': 1, 'bias': False}
|
|
sequence += [
|
|
coord_conv(ndf * mult_prev, ndf * mult, use_spect, use_coord, **kwargs),
|
|
nonlinearity,
|
|
coord_conv(ndf * mult, 1, use_spect, use_coord, **kwargs),
|
|
]
|
|
|
|
self.model = nn.Sequential(*sequence)
|
|
|
|
def forward(self, x):
|
|
# print(x.shape)
|
|
out = self.model(x)
|
|
# print(out.shape)
|
|
return out
|
|
|
|
|
|
class PD_Generator(nn.Module):
|
|
"""
|
|
PD Generator Network
|
|
:param output_nc: number of channels in output
|
|
:param ngf: base filter channel
|
|
:param z_nc: latent channels
|
|
:param img_f: the largest feature channels
|
|
:param L: Number of refinements of density
|
|
:param layers: down and up sample layers
|
|
:param norm: normalization function 'instance, batch, group'
|
|
:param activation: activation function 'ReLU, SELU, LeakyReLU, PReLU'
|
|
:param output_scale: Different output scales
|
|
"""
|
|
def __init__(self, output_nc=3, ngf=64, z_nc=128, img_f=1024, L=1, layers=6, norm='batch', activation='ReLU',
|
|
output_scale=1, k=4, n=2, use_spect=True, use_coord=False, use_attn=True, use_gated=False):
|
|
super(PD_Generator, self).__init__()
|
|
|
|
self.layers = layers
|
|
self.L = L
|
|
self.output_scale = output_scale
|
|
self.use_attn = use_attn
|
|
|
|
self.n = n
|
|
self.k = k
|
|
|
|
norm_layer = get_norm_layer(norm_type=norm)
|
|
nonlinearity = get_nonlinearity_layer(activation_type=activation)
|
|
# latent z to feature
|
|
mult = min(2 ** (layers-1), img_f // ngf)
|
|
self.generator = ResBlock(z_nc, ngf * mult, ngf * mult, None, nonlinearity, 'none', use_spect, use_coord, use_gated=use_gated)
|
|
|
|
# transform
|
|
for i in range(self.L):
|
|
block = ResBlock(ngf * mult, ngf * mult, ngf * mult, None, nonlinearity, 'none', use_spect, use_coord, use_gated=use_gated)
|
|
setattr(self, 'generator' + str(i), block)
|
|
|
|
# decoder part
|
|
for i in range(layers):
|
|
if i > 1:
|
|
self.n = 4
|
|
mult_prev = mult # 4
|
|
mult = min(2 ** (layers - i - 1), img_f // ngf)
|
|
if i > layers - output_scale:
|
|
upconv = ResBlock(ngf * mult_prev + output_nc, ngf * mult, ngf * mult, norm_layer, nonlinearity, 'up', True, use_gated=use_gated)
|
|
# upconv = ResBlockDecoder(ngf * mult_prev + output_nc, ngf * mult, ngf * mult, norm_layer, nonlinearity, use_spect, use_coord)
|
|
else:
|
|
upconv = ResBlock(ngf * mult_prev, ngf * mult, ngf * mult, norm_layer, nonlinearity, 'up', True, use_gated=use_gated)
|
|
# upconv = ResBlockDecoder(ngf * mult_prev, ngf * mult, ngf * mult, norm_layer, nonlinearity, use_spect, use_coord)
|
|
# block_s = ResBlockSPDNorm(ngf * mult, 3, int(math.log2(256 / 2 ** (i + 4))), self.n, self.k)
|
|
if i <= 2:
|
|
block_s = SPDNormResnetBlock(ngf * mult, ngf * mult, 2, 3)
|
|
else:
|
|
block_s = SPDNormResnetBlock(ngf * mult, ngf * mult, 4, 3)
|
|
setattr(self, 'decoder' + str(i), upconv)
|
|
setattr(self, 'SPDNorm' + str(i), block_s)
|
|
# output part
|
|
if i > layers - output_scale - 1:
|
|
outconv = Output(ngf * mult, output_nc, 3, None, nonlinearity, use_spect, use_coord, use_gated=use_gated)
|
|
setattr(self, 'out' + str(i), outconv)
|
|
# short+long term attention part
|
|
if i == 1 and use_attn:
|
|
# simam = SimAM()
|
|
# setattr(self, 'simam' + str(i), simam)
|
|
# attn = Auto_Attn(ngf*mult, None)
|
|
# attn = FullAttention(ngf*mult, ngf*mult)
|
|
attn = MyAttention(ngf*mult, ngf*mult)
|
|
setattr(self, 'attn' + str(i), attn)
|
|
|
|
def forward(self, z, mask=None, img_p=None):
|
|
"""
|
|
PD Generator Network
|
|
:param z: latent vector
|
|
:param f_m: feature of valid regions for conditional VAG-GAN
|
|
:param f_e: previous encoder feature for short+long term attention layer
|
|
:return results: different scale generation outputs
|
|
"""
|
|
|
|
f = self.generator(z)
|
|
for i in range(self.L):
|
|
generator = getattr(self, 'generator' + str(i))
|
|
f = generator(f)
|
|
|
|
# the features come from mask regions and valid regions, we directly add them together
|
|
out = f
|
|
# [8, 128, 8, 8]
|
|
results = []
|
|
attn = 0
|
|
for i in range(self.layers):
|
|
model = getattr(self, 'decoder' + str(i))
|
|
out_d = model(out)
|
|
# print("decoder" + str(i) + str(out.shape))
|
|
model = getattr(self, 'SPDNorm' + str(i))
|
|
out, featrues = model(out_d, img_p, mask)
|
|
# print("SPDNorm" + str(i) + str(out.shape))
|
|
if i == 1 and self.use_attn:
|
|
# auto attention
|
|
# model = getattr(self, 'simam' + str(i))
|
|
# out = model(out)
|
|
model = getattr(self, 'attn' + str(i))
|
|
# out, attn = model(out, mask=mask)
|
|
out, attn = model(out, featrues[-1])
|
|
# out, attn = model(out)
|
|
# print("attn" + str(i) + str(out.shape))
|
|
if i > self.layers - self.output_scale - 1:
|
|
model = getattr(self, 'out' + str(i))
|
|
output = model(out)
|
|
# print("output" + str(output.shape))
|
|
results.append(output)
|
|
out = torch.cat([out, output], dim=1)
|
|
# print("result", results)
|
|
return results, attn
|
|
|
|
|
|
class PD_Discriminator(nn.Module):
|
|
"""
|
|
ResNet Discriminator Network
|
|
:param input_nc: number of channels in input
|
|
:param ndf: base filter channel
|
|
:param layers: down and up sample layers
|
|
:param img_f: the largest feature channels
|
|
:param norm: normalization function 'instance, batch, group'
|
|
:param activation: activation function 'ReLU, SELU, LeakyReLU, PReLU'
|
|
"""
|
|
def __init__(self, input_nc=3, ndf=64, img_f=1024, layers=6, norm='none', activation='LeakyReLU', use_spect=True,
|
|
use_coord=False, use_attn=True, use_gated=False):
|
|
super(PD_Discriminator, self).__init__()
|
|
|
|
self.layers = layers
|
|
self.use_attn = use_attn
|
|
|
|
norm_layer = get_norm_layer(norm_type=norm)
|
|
nonlinearity = get_nonlinearity_layer(activation_type=activation)
|
|
self.nonlinearity = nonlinearity
|
|
|
|
# encoder part
|
|
self.block0 = ResBlockEncoderOptimized(input_nc, ndf, norm_layer, nonlinearity, use_spect, use_coord, use_gated=use_gated)
|
|
|
|
mult = 1
|
|
for i in range(layers - 1):
|
|
mult_prev = mult
|
|
mult = min(2 ** (i + 1), img_f // ndf)
|
|
# self-attention
|
|
if i == 2 and use_attn:
|
|
# simam = SimAM()
|
|
# setattr(self, 'simam' + str(i), simam)
|
|
# attn = Auto_Attn(ndf * mult_prev, norm_layer)
|
|
attn = FullAttention(ndf * mult_prev, ndf * mult_prev)
|
|
setattr(self, 'attn' + str(i), attn)
|
|
block = ResBlock(ndf * mult_prev, ndf * mult, ndf * mult_prev, norm_layer, nonlinearity, 'down', use_spect, use_coord, use_gated=use_gated)
|
|
setattr(self, 'encoder' + str(i), block)
|
|
|
|
self.block1 = ResBlock(ndf * mult, ndf * mult, ndf * mult, norm_layer, nonlinearity, 'none', use_spect, use_coord, use_gated=use_gated)
|
|
self.conv = SpectralNorm(nn.Conv2d(ndf * mult, 1, 3))
|
|
|
|
def forward(self, x):
|
|
out = self.block0(x)
|
|
# print("block0" + str(out.shape))
|
|
for i in range(self.layers - 1):
|
|
if i == 2 and self.use_attn:
|
|
# simam = getattr(self, 'simam' + str(i))
|
|
# out = simam(out)
|
|
attn = getattr(self, 'attn' + str(i))
|
|
# out, attention = attn(out)
|
|
out, attn = attn(out)
|
|
# print("attn" + str(i) + str(out.shape))
|
|
model = getattr(self, 'encoder' + str(i))
|
|
out = model(out)
|
|
# print("encoder" + str(i) + str(out.shape))
|
|
out = self.block1(out)
|
|
out = self.conv(self.nonlinearity(out))
|
|
return out
|
|
|
|
|
|
class MultiPatchCoordDiscriminator(nn.Module):
|
|
def __init__(self, input_nc=3, ndf=64, img_f=1024, layers=6, norm='none', activation='LeakyReLU', use_spect=True,
|
|
use_coord=False):
|
|
super(MultiPatchCoordDiscriminator, self).__init__()
|
|
|
|
self.module_d1 = PatchDiscriminator(input_nc, ndf, img_f, layers, norm, activation, use_spect, use_coord)
|
|
self.module_d2 = PatchDiscriminator(input_nc, ndf, img_f, layers, norm, activation, use_spect, use_coord)
|
|
self.module_d3 = PatchDiscriminator(input_nc, ndf, img_f, layers, norm, activation, use_spect, use_coord)
|
|
|
|
def downsample(self, input):
|
|
return F.avg_pool2d(
|
|
input, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False
|
|
)
|
|
|
|
# Returns list of lists of discriminator outputs.
|
|
# The final result is of size cfg.num_D x cfg.n_layers_D
|
|
def forward(self, input, mask):
|
|
result = []
|
|
local = input * mask
|
|
size = local.size()
|
|
x = int(size[2] / 4)
|
|
y = int(size[3] / 4)
|
|
range_x = int(size[2] * 3 / 4)
|
|
range_y = int(size[3] * 3 / 4)
|
|
# print("local" + str(local.shape))
|
|
local_re = local[:, :, x:range_x, y:range_y]
|
|
# print("re" + str(local_re.shape))
|
|
out_g1 = self.module_d1(input)
|
|
# print("out_g1" + str(out_g1.shape))
|
|
input_g = self.downsample(input)
|
|
out_g2 = self.module_d2(input_g)
|
|
|
|
# print(local_re.shape)
|
|
out_l1 = self.module_d3(local_re)
|
|
# print(out_g1.shape)
|
|
# print(out_g2.shape)
|
|
# print(out_l1.shape)
|
|
# print(out_l2.shape)
|
|
|
|
result.append(out_g1)
|
|
result.append(out_g2)
|
|
result.append(out_l1)
|
|
|
|
return result
|