parent
03572669b7
commit
32612d773c
@ -1,160 +0,0 @@
|
||||
# encoding: UTF-8
|
||||
# Copyright 2017 Google.com
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.contrib import layers
|
||||
from tensorflow.contrib import rnn # rnn stuff temporarily in contrib, moving back to code in TF 1.1
|
||||
|
||||
from code import my_txtutils as txt
|
||||
|
||||
tf.set_random_seed(0)
|
||||
|
||||
# Full comments in rnn_train.py
|
||||
# This file implements the exact same model but using the state_is_tuple=True
|
||||
# option in tf.nn.rnn_cell.MultiRNNCell. This option is enabled by default.
|
||||
# It produces faster code (by ~10%) but handling the state as a tuple is bit
|
||||
# more cumbersome. Search for comments containing "state_is_tuple=True" for
|
||||
# details.
|
||||
|
||||
SEQLEN = 30
|
||||
BATCHSIZE = 100
|
||||
ALPHASIZE = txt.ALPHASIZE
|
||||
INTERNALSIZE = 512
|
||||
NLAYERS = 3
|
||||
learning_rate = 0.001 # fixed learning rate
|
||||
|
||||
# load data, either shakespeare, or the Python source of Tensorflow itself
|
||||
shakedir = "shakespeare/*.txt"
|
||||
# shakedir = "../tensorflow/**/*.py"
|
||||
codetext, valitext, bookranges = txt.read_data_files(shakedir, validation=False)
|
||||
|
||||
# display some stats on the data
|
||||
epoch_size = len(codetext) // (BATCHSIZE * SEQLEN)
|
||||
txt.print_data_stats(len(codetext), len(valitext), epoch_size)
|
||||
|
||||
#
|
||||
# the model
|
||||
#
|
||||
lr = tf.placeholder(tf.float32, name='lr') # learning rate
|
||||
batchsize = tf.placeholder(tf.int32, name='batchsize')
|
||||
|
||||
# inputs
|
||||
X = tf.placeholder(tf.uint8, [None, None], name='X') # [ BATCHSIZE, SEQLEN ]
|
||||
Xo = tf.one_hot(X, ALPHASIZE, 1.0, 0.0) # [ BATCHSIZE, SEQLEN, ALPHASIZE ]
|
||||
# expected outputs = same sequence shifted by 1 since we are trying to predict the next character
|
||||
Y_ = tf.placeholder(tf.uint8, [None, None], name='Y_') # [ BATCHSIZE, SEQLEN ]
|
||||
Yo_ = tf.one_hot(Y_, ALPHASIZE, 1.0, 0.0) # [ BATCHSIZE, SEQLEN, ALPHASIZE ]
|
||||
|
||||
cells = [rnn.GRUCell(INTERNALSIZE) for _ in range(NLAYERS)]
|
||||
multicell = rnn.MultiRNNCell(cells, state_is_tuple=True)
|
||||
|
||||
# When using state_is_tuple=True, you must use multicell.zero_state
|
||||
# to create a tuple of placeholders for the input states (one state per layer).
|
||||
# When executed using session.run(zerostate), this also returns the correctly
|
||||
# shaped initial zero state to use when starting your training loop.
|
||||
zerostate = multicell.zero_state(BATCHSIZE, dtype=tf.float32)
|
||||
|
||||
Yr, H = tf.nn.dynamic_rnn(multicell, Xo, dtype=tf.float32, initial_state=zerostate)
|
||||
# Yr: [ BATCHSIZE, SEQLEN, INTERNALSIZE ]
|
||||
# H: [ BATCHSIZE, INTERNALSIZE*NLAYERS ] # this is the last state in the sequence
|
||||
|
||||
H = tf.identity(H, name='H') # just to give it a name
|
||||
|
||||
# Softmax layer implementation:
|
||||
# Flatten the first two dimension of the output [ BATCHSIZE, SEQLEN, ALPHASIZE ] => [ BATCHSIZE x SEQLEN, ALPHASIZE ]
|
||||
# then apply softmax readout layer. This way, the weights and biases are shared across unrolled time steps.
|
||||
# From the readout point of view, a value coming from a cell or a minibatch is the same thing
|
||||
|
||||
Yflat = tf.reshape(Yr, [-1, INTERNALSIZE]) # [ BATCHSIZE x SEQLEN, INTERNALSIZE ]
|
||||
Ylogits = layers.linear(Yflat, ALPHASIZE) # [ BATCHSIZE x SEQLEN, ALPHASIZE ]
|
||||
Yflat_ = tf.reshape(Yo_, [-1, ALPHASIZE]) # [ BATCHSIZE x SEQLEN, ALPHASIZE ]
|
||||
loss = tf.nn.softmax_cross_entropy_with_logits(logits=Ylogits, labels=Yflat_) # [ BATCHSIZE x SEQLEN ]
|
||||
loss = tf.reshape(loss, [batchsize, -1]) # [ BATCHSIZE, SEQLEN ]
|
||||
Yo = tf.nn.softmax(Ylogits, name='Yo') # [ BATCHSIZE x SEQLEN, ALPHASIZE ]
|
||||
Y = tf.argmax(Yo, 1) # [ BATCHSIZE x SEQLEN ]
|
||||
Y = tf.reshape(Y, [batchsize, -1], name="Y") # [ BATCHSIZE, SEQLEN ]
|
||||
train_step = tf.train.AdamOptimizer(lr).minimize(loss)
|
||||
|
||||
# stats for display
|
||||
seqloss = tf.reduce_mean(loss, 1)
|
||||
batchloss = tf.reduce_mean(seqloss)
|
||||
accuracy = tf.reduce_mean(tf.cast(tf.equal(Y_, tf.cast(Y, tf.uint8)), tf.float32))
|
||||
loss_summary = tf.summary.scalar("batch_loss", batchloss)
|
||||
acc_summary = tf.summary.scalar("batch_accuracy", accuracy)
|
||||
summaries = tf.summary.merge([loss_summary, acc_summary])
|
||||
|
||||
# Init Tensorboard stuff. This will save Tensorboard information into a different
|
||||
# folder at each run named 'log/<timestamp>/'.
|
||||
timestamp = str(math.trunc(time.time()))
|
||||
summary_writer = tf.summary.FileWriter("log/" + timestamp + "-training")
|
||||
|
||||
# Init for saving models. They will be saved into a directory named 'checkpoints'.
|
||||
# Only the last checkpoint is kept.
|
||||
if not os.path.exists("checkpoints"):
|
||||
os.mkdir("checkpoints")
|
||||
saver = tf.train.Saver(max_to_keep=1)
|
||||
|
||||
# for display: init the progress bar
|
||||
DISPLAY_FREQ = 50
|
||||
_50_BATCHES = DISPLAY_FREQ * BATCHSIZE * SEQLEN
|
||||
progress = txt.Progress(DISPLAY_FREQ, size=111+2, msg="Training on next "+str(DISPLAY_FREQ)+" batches")
|
||||
|
||||
# init
|
||||
init = tf.global_variables_initializer()
|
||||
sess = tf.Session()
|
||||
sess.run(init)
|
||||
step = 0
|
||||
|
||||
# training loop
|
||||
istate = sess.run(zerostate) # initial zero input state (a tuple)
|
||||
for x, y_, epoch in txt.rnn_minibatch_sequencer(codetext, BATCHSIZE, SEQLEN, nb_epochs=1000):
|
||||
|
||||
# train on one minibatch
|
||||
feed_dict = {X: x, Y_: y_, lr: learning_rate, batchsize: BATCHSIZE}
|
||||
# This is how you add the input state to feed dictionary when state_is_tuple=True.
|
||||
# zerostate is a tuple of the placeholders for the NLAYERS=3 input states of our
|
||||
# multi-layer RNN cell. Those placeholders must be used as keys in feed_dict.
|
||||
# istate is a tuple holding the actual values of the input states (one per layer).
|
||||
# Iterate on the input state placeholders and use them as keys in the dictionary
|
||||
# to add actual input state values.
|
||||
for i, v in enumerate(zerostate):
|
||||
feed_dict[v] = istate[i]
|
||||
_, y, ostate, smm = sess.run([train_step, Y, H, summaries], feed_dict=feed_dict)
|
||||
|
||||
# save training data for Tensorboard
|
||||
summary_writer.add_summary(smm, step)
|
||||
|
||||
# display a visual validation of progress (every 50 batches)
|
||||
if step % _50_BATCHES == 0:
|
||||
feed_dict = {X: x, Y_: y_, batchsize: BATCHSIZE} # no dropout for validation
|
||||
for i, v in enumerate(zerostate):
|
||||
feed_dict[v] = istate[i]
|
||||
y, l, bl, acc = sess.run([Y, seqloss, batchloss, accuracy], feed_dict=feed_dict)
|
||||
txt.print_learning_learned_comparison(x[:5], y, l, bookranges, bl, acc, epoch_size, step, epoch)
|
||||
|
||||
# save a checkpoint (every 500 batches)
|
||||
if step // 10 % _50_BATCHES == 0:
|
||||
saver.save(sess, 'checkpoints/rnn_train_' + timestamp, global_step=step)
|
||||
|
||||
# display progress bar
|
||||
progress.step(reset=step % _50_BATCHES == 0)
|
||||
|
||||
# loop state around
|
||||
istate = ostate
|
||||
step += BATCHSIZE * SEQLEN
|
Loading…
Reference in new issue