# coding: utf-8
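"""Train a fast neural-style transfer model.

Builds a feed-forward image transformation network (model.net) and trains it
against a fixed, pre-trained loss network using content, style, and total
variation losses. All options are read from a YAML conf file (see parse_args).
"""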
from __future__ import print_function
from __future__ import division

import argparse
import os
import time

import tensorflow as tf

from nets import nets_factory
from preprocessing import preprocessing_factory

import losses
import model
import reader
import utils

slim = tf.contrib.slim


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--conf', default='conf/mosaic.yml',
                        help='the path to the conf file')
    return parser.parse_args()
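
# Typical invocation (assuming this script is saved as train.py next to a
# conf/ directory; conf/mosaic.yml is only the default):
#
#     python train.py -c conf/mosaic.yml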


def main(FLAGS):
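    # losses.get_style_features (defined in losses.py, not shown here) is
    # expected to push the style image through the loss network once and
    # return the target style-layer features (Gram matrices) for training.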
    style_features_t = losses.get_style_features(FLAGS)

    # Make sure the training path exists.
    training_path = os.path.join(FLAGS.model_path, FLAGS.naming)
    if not os.path.exists(training_path):
        os.makedirs(training_path)

    with tf.Graph().as_default():
        with tf.Session() as sess:
"""Build Network"""
|
|
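            # The loss network (selected by FLAGS.loss_model, e.g. a slim VGG)
            # is built with is_training=False: it only supplies features for
            # the perceptual losses and its weights are never updated.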
            network_fn = nets_factory.get_network_fn(
                FLAGS.loss_model,
                num_classes=1,
                is_training=False)

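            # Input pipeline: reader.image (reader.py) decodes images from the
            # hard-coded 'train2014/' folder (presumably MS COCO train2014),
            # applies the loss network's preprocessing, and yields batches for
            # FLAGS.epoch epochs.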
            image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
                FLAGS.loss_model,
                is_training=False)
            processed_images = reader.image(FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size,
                                            'train2014/', image_preprocessing_fn, epochs=FLAGS.epoch)
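            # model.net is the feed-forward image transformation network being
            # trained. Its outputs are re-preprocessed so that generated and
            # original images live in the same input space as the loss network.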
            generated = model.net(processed_images, training=True)
            processed_generated = [
                image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size)
                for image in tf.unstack(generated, axis=0, num=FLAGS.batch_size)
            ]
            processed_generated = tf.stack(processed_generated)
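            # A single forward pass through the loss network for both halves of
            # the concatenated batch; endpoints_dict maps layer names to their
            # activations.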
            _, endpoints_dict = network_fn(tf.concat([processed_generated, processed_images], 0),
                                           spatial_squeeze=False)

            # Log the structure of the loss network.
            tf.logging.info('Loss network layers (you can define them in "content_layers" and "style_layers"):')
            for key in endpoints_dict:
                tf.logging.info(key)

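            # Perceptual losses in the spirit of Johnson et al. (2016): content
            # loss compares loss-network activations of generated vs. original
            # images, style loss compares Gram matrices against the precomputed
            # style targets, and total variation loss encourages smoothness.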
"""Build Losses"""
|
|
content_loss = losses.content_loss(endpoints_dict, FLAGS.content_layers)
|
|
style_loss, style_loss_summary = losses.style_loss(endpoints_dict, style_features_t, FLAGS.style_layers)
|
|
tv_loss = losses.total_variation_loss(generated) # use the unprocessed image
|
|
|
|
loss = FLAGS.style_weight * style_loss + FLAGS.content_weight * content_loss + FLAGS.tv_weight * tv_loss
|
|
|
|
            # Add summaries for visualization in TensorBoard.
            """Add Summary"""
            tf.summary.scalar('losses/content_loss', content_loss)
            tf.summary.scalar('losses/style_loss', style_loss)
            tf.summary.scalar('losses/regularizer_loss', tv_loss)

            tf.summary.scalar('weighted_losses/weighted_content_loss', content_loss * FLAGS.content_weight)
            tf.summary.scalar('weighted_losses/weighted_style_loss', style_loss * FLAGS.style_weight)
            tf.summary.scalar('weighted_losses/weighted_regularizer_loss', tv_loss * FLAGS.tv_weight)
            tf.summary.scalar('total_loss', loss)

            for layer in FLAGS.style_layers:
                tf.summary.scalar('style_losses/' + layer, style_loss_summary[layer])
            tf.summary.image('generated', generated)
            # tf.summary.image('processed_generated', processed_generated)  # May be better?
            tf.summary.image('origin', tf.stack([
                image_unprocessing_fn(image) for image in tf.unstack(processed_images, axis=0, num=FLAGS.batch_size)
            ]))
            summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(training_path)

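            # Only the transformation network is trained: any variable whose
            # name starts with the loss-model scope (e.g. 'vgg_16') is excluded
            # from the optimizer.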
"""Prepare to Train"""
|
|
global_step = tf.Variable(0, name="global_step", trainable=False)
|
|
|
|
variable_to_train = []
|
|
for variable in tf.trainable_variables():
|
|
if not(variable.name.startswith(FLAGS.loss_model)):
|
|
variable_to_train.append(variable)
|
|
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step, var_list=variable_to_train)
|
|
|
|
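            # The saver checkpoints only the transformation network (plus
            # global_step); loss-network weights are restored separately below
            # and are never written to the checkpoint.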
            variables_to_restore = []
            for v in tf.global_variables():
                if not v.name.startswith(FLAGS.loss_model):
                    variables_to_restore.append(v)
            saver = tf.train.Saver(variables_to_restore, write_version=tf.train.SaverDef.V1)

            sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

            # Restore variables for the loss network.
            init_func = utils._get_init_fn(FLAGS)
            init_func(sess)

            # Restore variables for the training model if a checkpoint exists.
            last_file = tf.train.latest_checkpoint(training_path)
            if last_file:
                tf.logging.info('Restoring model from {}'.format(last_file))
                saver.restore(sess, last_file)

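            # reader.image feeds the graph via TF input queues, so queue
            # runners must be started. When the pipeline exhausts FLAGS.epoch
            # epochs it raises OutOfRangeError, which ends training below.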
"""Start Training"""
|
|
coord = tf.train.Coordinator()
|
|
threads = tf.train.start_queue_runners(coord=coord)
|
|
start_time = time.time()
|
|
try:
|
|
while not coord.should_stop():
|
|
_, loss_t, step = sess.run([train_op, loss, global_step])
|
|
elapsed_time = time.time() - start_time
|
|
start_time = time.time()
|
|
"""logging"""
|
|
# print(step)
|
|
if step % 10 == 0:
|
|
tf.logging.info('step: %d, total Loss %f, secs/step: %f' % (step, loss_t, elapsed_time))
|
|
"""summary"""
|
|
if step % 25 == 0:
|
|
tf.logging.info('adding summary...')
|
|
summary_str = sess.run(summary)
|
|
writer.add_summary(summary_str, step)
|
|
writer.flush()
|
|
"""checkpoint"""
|
|
if step % 1000 == 0:
|
|
saver.save(sess, os.path.join(training_path, 'fast-style-model.ckpt'), global_step=step)
|
|
except tf.errors.OutOfRangeError:
|
|
saver.save(sess, os.path.join(training_path, 'fast-style-model.ckpt-done'))
|
|
tf.logging.info('Done training -- epoch limit reached')
|
|
finally:
|
|
coord.request_stop()
|
|
coord.join(threads)
|
|
|
|
|
|
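# NOTE: despite the FLAGS name, options come from the YAML conf file parsed by
# utils.read_conf_file (utils.py), not from tf.app.flags.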
if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    args = parse_args()
    FLAGS = utils.read_conf_file(args.conf)
    main(FLAGS)