import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def conv2d(x, input_filters, output_filters, kernel, strides, mode='REFLECT'):
    with tf.variable_scope('conv'):
        shape = [kernel, kernel, input_filters, output_filters]
        weight = tf.Variable(tf.truncated_normal(shape, stddev=0.1), name='weight')
        # Pad manually (REFLECT by default) so the VALID convolution below does not shrink the feature map.
        x_padded = tf.pad(x, [[0, 0], [int(kernel / 2), int(kernel / 2)], [int(kernel / 2), int(kernel / 2)], [0, 0]], mode=mode)
        return tf.nn.conv2d(x_padded, weight, strides=[1, strides, strides, 1], padding='VALID', name='conv')


def conv2d_transpose(x, input_filters, output_filters, kernel, strides):
    with tf.variable_scope('conv_transpose'):
        # tf.nn.conv2d_transpose expects the filter as [kernel, kernel, output_filters, input_filters].
        shape = [kernel, kernel, output_filters, input_filters]
        weight = tf.Variable(tf.truncated_normal(shape, stddev=0.1), name='weight')

        batch_size = tf.shape(x)[0]
        height = tf.shape(x)[1] * strides
        width = tf.shape(x)[2] * strides
        output_shape = tf.stack([batch_size, height, width, output_filters])
        return tf.nn.conv2d_transpose(x, weight, output_shape, strides=[1, strides, strides, 1], name='conv_transpose')


def resize_conv2d(x, input_filters, output_filters, kernel, strides, training):
    '''
    An alternative to transposed convolution where we first resize, then convolve.
    See http://distill.pub/2016/deconv-checkerboard/

    The input is resized by a factor of strides * 2 and then convolved with stride
    `strides`, so the net effect is roughly a 2x spatial upsampling.

    For some reason the shape needs to be statically known for gradient propagation
    through tf.image.resize_images, but we only know that for a fixed image size, so we
    plumb through a "training" argument.
    '''
    with tf.variable_scope('conv_transpose'):
        height = x.get_shape()[1].value if training else tf.shape(x)[1]
        width = x.get_shape()[2].value if training else tf.shape(x)[2]

        new_height = height * strides * 2
        new_width = width * strides * 2

        x_resized = tf.image.resize_images(x, [new_height, new_width], tf.image.ResizeMethod.NEAREST_NEIGHBOR)

        # shape = [kernel, kernel, input_filters, output_filters]
        # weight = tf.Variable(tf.truncated_normal(shape, stddev=0.1), name='weight')
        return conv2d(x_resized, input_filters, output_filters, kernel, strides)


def instance_norm(x):
    epsilon = 1e-9

    # Normalize each example and channel independently over the spatial dimensions.
    mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)

    return tf.div(tf.subtract(x, mean), tf.sqrt(tf.add(var, epsilon)))


def batch_norm(x, size, training, decay=0.999):
    beta = tf.Variable(tf.zeros([size]), name='beta')
    scale = tf.Variable(tf.ones([size]), name='scale')
    pop_mean = tf.Variable(tf.zeros([size]))
    pop_var = tf.Variable(tf.ones([size]))
    epsilon = 1e-3

    batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2])
    # Exponential moving averages of the batch statistics, used at inference time.
    train_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
    train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))

    def batch_statistics():
        with tf.control_dependencies([train_mean, train_var]):
            return tf.nn.batch_normalization(x, batch_mean, batch_var, beta, scale, epsilon, name='batch_norm')

    def population_statistics():
        return tf.nn.batch_normalization(x, pop_mean, pop_var, beta, scale, epsilon, name='batch_norm')

    return tf.cond(training, batch_statistics, population_statistics)


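# Usage sketch for batch_norm (the names below are placeholders, not defined in this
# file): `training` is consumed by tf.cond above, so it should be a boolean tensor
# rather than a Python bool, e.g.
#
#     is_training = tf.placeholder(tf.bool, name='is_training')
#     h = batch_norm(conv_out, size=64, training=is_training)
#
# fed with True while training and False at inference.

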
def relu(input):
    relu = tf.nn.relu(input)
    # Convert NaN to zero (NaN != NaN).
    nan_to_zero = tf.where(tf.equal(relu, relu), relu, tf.zeros_like(relu))
    return nan_to_zero


def residual(x, filters, kernel, strides):
    with tf.variable_scope('residual'):
        conv1 = conv2d(x, filters, filters, kernel, strides)
        conv2 = conv2d(relu(conv1), filters, filters, kernel, strides)

        residual = x + conv2

        return residual


def net(image, training):
    # Pad a little before passing through the network to reduce border effects.
    image = tf.pad(image, [[0, 0], [10, 10], [10, 10], [0, 0]], mode='REFLECT')

    with tf.variable_scope('conv1'):
        conv1 = relu(instance_norm(conv2d(image, 3, 32, 9, 1)))
    with tf.variable_scope('conv2'):
        conv2 = relu(instance_norm(conv2d(conv1, 32, 64, 3, 2)))
    with tf.variable_scope('conv3'):
        conv3 = relu(instance_norm(conv2d(conv2, 64, 128, 3, 2)))
    with tf.variable_scope('res1'):
        res1 = residual(conv3, 128, 3, 1)
    with tf.variable_scope('res2'):
        res2 = residual(res1, 128, 3, 1)
    with tf.variable_scope('res3'):
        res3 = residual(res2, 128, 3, 1)
    with tf.variable_scope('res4'):
        res4 = residual(res3, 128, 3, 1)
    with tf.variable_scope('res5'):
        res5 = residual(res4, 128, 3, 1)
    # print(res5.get_shape())
    with tf.variable_scope('deconv1'):
        # deconv1 = relu(instance_norm(conv2d_transpose(res5, 128, 64, 3, 2)))
        deconv1 = relu(instance_norm(resize_conv2d(res5, 128, 64, 3, 2, training)))
    with tf.variable_scope('deconv2'):
        # deconv2 = relu(instance_norm(conv2d_transpose(deconv1, 64, 32, 3, 2)))
        deconv2 = relu(instance_norm(resize_conv2d(deconv1, 64, 32, 3, 2, training)))
    with tf.variable_scope('deconv3'):
        # deconv_test = relu(instance_norm(conv2d(deconv2, 32, 32, 2, 1)))
        deconv3 = tf.nn.tanh(instance_norm(conv2d(deconv2, 32, 3, 9, 1)))

    # Map the tanh output from [-1, 1] back to [0, 255].
    y = (deconv3 + 1) * 127.5

    # Remove the border effect by cropping off the padding.
    height = tf.shape(y)[1]
    width = tf.shape(y)[2]
    y = tf.slice(y, [0, 10, 10, 0], tf.stack([-1, height - 20, width - 20, -1]))

    return y
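

# Minimal smoke test: a sketch of how net() can be exercised, assuming the usual
# TF1-style graph/session workflow. The placeholder name and the 1x256x256x3 input
# size are arbitrary choices for illustration, not anything this file requires.
if __name__ == '__main__':
    import numpy as np

    image = tf.placeholder(tf.float32, shape=[1, 256, 256, 3], name='image')
    generated = net(image, training=True)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        dummy = np.random.rand(1, 256, 256, 3).astype(np.float32) * 255.0
        out = sess.run(generated, feed_dict={image: dummy})
        print('output shape:', out.shape)  # expected: (1, 256, 256, 3)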