# coding: utf-8
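"""Train a fast style transfer network.

Reads hyperparameters from a YAML conf file and optimizes the image
transform network in model.py against perceptual losses computed by a
fixed, pre-trained loss network.

Example usage (assuming this file is saved as train.py):
    python train.py -c conf/mosaic.yml
"""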
from __future__ import print_function
from __future__ import division

import tensorflow as tf
from nets import nets_factory
from preprocessing import preprocessing_factory
import reader
import model
import time
import losses
import utils
import os
import argparse

slim = tf.contrib.slim


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--conf', default='conf/mosaic.yml', help='the path to the conf file')
    return parser.parse_args()


def main(FLAGS):
    style_features_t = losses.get_style_features(FLAGS)

    # Make sure the training path exists.
    training_path = os.path.join(FLAGS.model_path, FLAGS.naming)
    if not os.path.exists(training_path):
        os.makedirs(training_path)

    with tf.Graph().as_default():
        with tf.Session() as sess:
            """Build Network"""
            network_fn = nets_factory.get_network_fn(
                FLAGS.loss_model,
                num_classes=1,
                is_training=False)

            image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
                FLAGS.loss_model,
                is_training=False)
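
            # reader.image builds a queue-based input pipeline yielding batches
            # of preprocessed training images from 'train2014/' (the MS-COCO
            # train2014 set in the original setup) for FLAGS.epoch passes.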
            processed_images = reader.image(FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size,
                                            'train2014/', image_preprocessing_fn, epochs=FLAGS.epoch)
            generated = model.net(processed_images, training=True)
            processed_generated = [image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size)
                                   for image in tf.unstack(generated, axis=0, num=FLAGS.batch_size)]
            processed_generated = tf.stack(processed_generated)
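
            # Run the generated and original images through the loss network in
            # a single batch; each tensor in endpoints_dict then holds the
            # activations of both, which the loss functions below compare.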
            _, endpoints_dict = network_fn(tf.concat([processed_generated, processed_images], 0), spatial_squeeze=False)

            # Log the structure of the loss network.
            tf.logging.info('Loss network layers (you can define them in "content_layers" and "style_layers"):')
            for key in endpoints_dict:
                tf.logging.info(key)
"""Build Losses"""
content_loss = losses.content_loss(endpoints_dict, FLAGS.content_layers)
style_loss, style_loss_summary = losses.style_loss(endpoints_dict, style_features_t, FLAGS.style_layers)
tv_loss = losses.total_variation_loss(generated) # use the unprocessed image
loss = FLAGS.style_weight * style_loss + FLAGS.content_weight * content_loss + FLAGS.tv_weight * tv_loss

            # Add summaries for visualization in TensorBoard.
            """Add Summary"""
            tf.summary.scalar('losses/content_loss', content_loss)
            tf.summary.scalar('losses/style_loss', style_loss)
            tf.summary.scalar('losses/regularizer_loss', tv_loss)

            tf.summary.scalar('weighted_losses/weighted_content_loss', content_loss * FLAGS.content_weight)
            tf.summary.scalar('weighted_losses/weighted_style_loss', style_loss * FLAGS.style_weight)
            tf.summary.scalar('weighted_losses/weighted_regularizer_loss', tv_loss * FLAGS.tv_weight)
            tf.summary.scalar('total_loss', loss)

            for layer in FLAGS.style_layers:
                tf.summary.scalar('style_losses/' + layer, style_loss_summary[layer])
            tf.summary.image('generated', generated)
            # tf.summary.image('processed_generated', processed_generated)  # May be better?
            tf.summary.image('origin', tf.stack([
                image_unprocessing_fn(image) for image in tf.unstack(processed_images, axis=0, num=FLAGS.batch_size)
            ]))

            summary = tf.summary.merge_all()
            writer = tf.summary.FileWriter(training_path)
"""Prepare to Train"""
global_step = tf.Variable(0, name="global_step", trainable=False)
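
            # Train only the transform network: skip every variable that
            # belongs to the frozen loss network.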
            variable_to_train = []
            for variable in tf.trainable_variables():
                if not variable.name.startswith(FLAGS.loss_model):
                    variable_to_train.append(variable)
            train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step, var_list=variable_to_train)
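
            # Checkpoint only the transform network's variables; the loss
            # network is restored separately from its pre-trained weights.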
            variables_to_restore = []
            for v in tf.global_variables():
                if not v.name.startswith(FLAGS.loss_model):
                    variables_to_restore.append(v)
            saver = tf.train.Saver(variables_to_restore, write_version=tf.train.SaverDef.V1)

            sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])

            # Restore variables for the loss network.
            init_func = utils._get_init_fn(FLAGS)
            init_func(sess)

            # Restore the training model if a checkpoint file exists.
            last_file = tf.train.latest_checkpoint(training_path)
            if last_file:
                tf.logging.info('Restoring model from {}'.format(last_file))
                saver.restore(sess, last_file)
"""Start Training"""
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            start_time = time.time()

            try:
                while not coord.should_stop():
                    _, loss_t, step = sess.run([train_op, loss, global_step])
                    elapsed_time = time.time() - start_time
                    start_time = time.time()

                    """logging"""
                    if step % 10 == 0:
                        tf.logging.info('step: %d, total loss: %f, secs/step: %f' % (step, loss_t, elapsed_time))

                    """summary"""
                    if step % 25 == 0:
                        tf.logging.info('adding summary...')
                        summary_str = sess.run(summary)
                        writer.add_summary(summary_str, step)
                        writer.flush()

                    """checkpoint"""
                    if step % 1000 == 0:
                        saver.save(sess, os.path.join(training_path, 'fast-style-model.ckpt'), global_step=step)
            except tf.errors.OutOfRangeError:
                saver.save(sess, os.path.join(training_path, 'fast-style-model.ckpt-done'))
                tf.logging.info('Done training -- epoch limit reached')
            finally:
                coord.request_stop()
            coord.join(threads)


if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    args = parse_args()
    FLAGS = utils.read_conf_file(args.conf)
    main(FLAGS)