You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
109 lines
4.1 KiB
109 lines
4.1 KiB
# coding: utf-8
|
|
from __future__ import print_function
|
|
import tensorflow as tf
|
|
from nets import nets_factory
|
|
from preprocessing import preprocessing_factory
|
|
import utils
|
|
import os
|
|
|
|
# Shorthand alias for TF-Slim (lives in tf.contrib, so TF 1.x only).
# NOTE(review): not referenced by any function in this module as shown.
slim = tf.contrib.slim
|
|
|
|
|
|
def gram(layer):
    """Return the normalized Gram matrix of each feature map in a batch.

    ``layer`` is indexed as a rank-4 tensor: axis 0 is the batch, axes 1 and 2
    are treated as spatial, and axis 3 as channels (NHWC-style layout assumed).
    Each image is flattened to a (pixels x channels) matrix F, and F^T F is
    divided by ``dim1 * dim2 * channels``.

    Returns:
        A tensor of shape [batch, channels, channels].
    """
    dims = tf.shape(layer)
    batch, rows, cols, channels = dims[0], dims[1], dims[2], dims[3]

    # Collapse both spatial axes so every image becomes (rows*cols, channels).
    flat = tf.reshape(layer, tf.stack([batch, -1, channels]))

    normalizer = tf.to_float(rows * cols * channels)
    return tf.matmul(flat, flat, transpose_a=True) / normalizer
|
|
|
|
|
|
def get_style_features(FLAGS):
    """Evaluate the target Gram matrices of the style image.

    Builds a fresh graph containing ``FLAGS.loss_model`` (non-training mode),
    reads ``FLAGS.style_image``, and runs it through the loss model's
    preprocessing at ``FLAGS.image_size`` x ``FLAGS.image_size``. The original
    notes describe that preprocessing as: 1. resize the shorter side to
    ``FLAGS.image_size``; 2. apply a central crop. The preprocessed image is
    also written to ``generated/target_style_<FLAGS.naming>.jpg``.

    Args:
        FLAGS: flags object. Reads ``loss_model``, ``image_size``,
            ``style_image``, ``style_layers`` and ``naming``, and is passed to
            ``utils._get_init_fn`` to restore the loss-network variables.

    Returns:
        A list with one evaluated Gram matrix (batch dimension removed) per
        entry of ``FLAGS.style_layers``, in the same order.
    """
    with tf.Graph().as_default():
        network_fn = nets_factory.get_network_fn(
            FLAGS.loss_model,
            num_classes=1,
            is_training=False)
        image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
            FLAGS.loss_model,
            is_training=False)

        # Get the style image data.
        size = FLAGS.image_size
        img_bytes = tf.read_file(FLAGS.style_image)
        # Force 3 channels. Without it an RGBA PNG decodes to 4 channels and a
        # grayscale image to 1, while the rest of the pipeline (and the JPEG
        # written below) works with 3-channel RGB images.
        if FLAGS.style_image.lower().endswith('png'):
            image = tf.image.decode_png(img_bytes, channels=3)
        else:
            image = tf.image.decode_jpeg(img_bytes, channels=3)

        # Preprocess, then add the batch dimension.
        images = tf.expand_dims(image_preprocessing_fn(image, size, size), 0)

        _, endpoints_dict = network_fn(images, spatial_squeeze=False)
        features = []
        for layer in FLAGS.style_layers:
            feature = endpoints_dict[layer]
            # Remove the batch dimension: the batch holds the single style image.
            features.append(tf.squeeze(gram(feature), [0]))

        with tf.Session() as sess:
            # Restore variables for the loss network.
            init_func = utils._get_init_fn(FLAGS)
            init_func(sess)

            # Make sure the 'generated' directory exists.
            if not os.path.exists('generated'):
                os.makedirs('generated')
            # Path for the cropped/preprocessed style image.
            save_file = 'generated/target_style_' + FLAGS.naming + '.jpg'

            # Encode the preprocessed style image back to JPEG bytes. Evaluate
            # it BEFORE opening the file so a failing run cannot leave behind
            # an empty or truncated image.
            target_image = image_unprocessing_fn(images[0, :])
            encoded = tf.image.encode_jpeg(tf.cast(target_image, tf.uint8))
            image_data = sess.run(encoded)
            with open(save_file, 'wb') as f:
                f.write(image_data)
            tf.logging.info('Target style pattern is saved to: %s.' % save_file)

            # Return the Gram matrices of the layers used for the style loss.
            return sess.run(features)
|
|
|
|
|
|
def style_loss(endpoints_dict, style_features_t, style_layers):
    """Sum the per-layer style losses of the generated images.

    For each layer, ``endpoints_dict[layer]`` is split in half along axis 0,
    and only the first half (the generated images) is used; the second half
    is ignored here. The layer's loss is
    ``l2_loss(gram(generated) - target_gram) * 2 / size(generated)``, where
    ``tf.nn.l2_loss(t) == sum(t ** 2) / 2``.

    Args:
        endpoints_dict: mapping from layer name to activation tensor.
        style_features_t: target Gram matrices, positionally paired with
            ``style_layers``. Pairing stops at the shorter of the two, as
            with ``zip``.
        style_layers: names of the layers to evaluate.

    Returns:
        ``(total, per_layer)``. ``total`` is the summed loss (plain ``0`` when
        there are no layers). ``per_layer`` maps each layer name to its own
        loss tensor.
    """
    per_layer = {}
    total = 0
    for layer_name, target_gram in zip(style_layers, style_features_t):
        generated, _ = tf.split(endpoints_dict[layer_name], 2, 0)
        num_elements = tf.size(generated)
        diff = gram(generated) - target_gram
        layer_loss = tf.nn.l2_loss(diff) * 2 / tf.to_float(num_elements)
        per_layer[layer_name] = layer_loss
        total += layer_loss
    return total, per_layer
|
|
|
|
|
|
def content_loss(endpoints_dict, content_layers):
    """Sum the content losses over ``content_layers``.

    For each layer, ``endpoints_dict[layer]`` is split in half along axis 0
    into ``(generated, content)``. The layer contributes
    ``l2_loss(generated - content) * 2 / size(generated)``, which is the mean
    squared difference between the two halves. Per the original comment,
    this normalization is kept the same as in the paper.

    Returns:
        The summed loss, or plain ``0`` when ``content_layers`` is empty.
    """
    total = 0
    for layer_name in content_layers:
        generated, content = tf.split(endpoints_dict[layer_name], 2, 0)
        num_elements = tf.size(generated)
        total += tf.nn.l2_loss(generated - content) * 2 / tf.to_float(num_elements)
    return total
|
|
|
|
|
|
def total_variation_loss(layer):
    """Total-variation penalty on a rank-4 tensor.

    Axes 1 and 2 are treated as the two spatial axes. Neighbour differences
    are taken along each of them. The result is
    ``l2_loss(dx) / size(dx) + l2_loss(dy) / size(dy)``, i.e. half the mean
    squared neighbour difference along each axis, summed.
    """
    # Axis-1 neighbours: positions [0, n-2] minus positions [1, n-1].
    dy = layer[:, :-1, :, :] - layer[:, 1:, :, :]
    # Axis-2 neighbours: positions [0, m-2] minus positions [1, m-1].
    dx = layer[:, :, :-1, :] - layer[:, :, 1:, :]

    dx_term = tf.nn.l2_loss(dx) / tf.to_float(tf.size(dx))
    dy_term = tf.nn.l2_loss(dy) / tf.to_float(tf.size(dy))
    return dx_term + dy_term
|