parent
aefddf00a5
commit
2b7a1a178b
@ -0,0 +1,47 @@
|
||||
import tensorflow.compat.v1 as tf
|
||||
import model
|
||||
import utils
|
||||
import read_img
|
||||
|
||||
|
||||
def get_loss(result, y):
|
||||
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=result, labels=y)
|
||||
return tf.reduce_mean(cross_entropy)
|
||||
|
||||
|
||||
def get_optimizer(loss):
|
||||
train_variables = tf.trainable_variables()
|
||||
optimizer = tf.train.AdamOptimizer(learning_rate=utils.lr).minimize(loss, var_list=train_variables)
|
||||
return optimizer
|
||||
|
||||
|
||||
def train():
|
||||
x_train, y_train, _ = read_img.read_flowers(True)
|
||||
x_test, y_test, _ = read_img.read_flowers(False)
|
||||
|
||||
train_batches = x_train.shape[0]
|
||||
|
||||
x, y, one_hot, result = model.get_model(is_train=True)
|
||||
loss = get_loss(result, one_hot)
|
||||
optimizer = get_optimizer(loss)
|
||||
|
||||
saver = tf.train.Saver()
|
||||
with tf.Session() as sess:
|
||||
sess.run(tf.global_variables_initializer())
|
||||
|
||||
for epoch in range(utils.epochs):
|
||||
for batch in range(train_batches // utils.batch_size):
|
||||
start = batch * utils.batch_size
|
||||
next_x = x_train[start:start + utils.batch_size]
|
||||
next_y = y_train[start:start + utils.batch_size]
|
||||
|
||||
sess.run(optimizer, feed_dict={x: next_x, y: next_y})
|
||||
|
||||
loss_result = sess.run(loss, feed_dict={x: x_test, y: y_test})
|
||||
print("epoch: {}, loss: {}".format(epoch, loss_result))
|
||||
|
||||
saver.save(sess, "./result/result.ckpt")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
train()
|
Loading…
Reference in new issue