You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

48 lines
1.4 KiB

import tensorflow.compat.v1 as tf
import model
import utils
import read_img
def get_loss(result, y):
    """Return the mean softmax cross-entropy of ``result`` against ``y``.

    Args:
        result: Tensor passed as ``logits`` to
            ``tf.nn.softmax_cross_entropy_with_logits``.
        y: Tensor passed as ``labels`` to the same op.

    Returns:
        The ``tf.reduce_mean`` of the cross-entropy tensor.
    """
    entropy = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=result)
    mean_loss = tf.reduce_mean(entropy)
    return mean_loss
def get_optimizer(loss):
    """Build the Adam training op that minimizes ``loss``.

    The learning rate is read from ``utils.lr``. The variable list is
    captured from ``tf.trainable_variables()`` when this function is called.

    Args:
        loss: Tensor to minimize.

    Returns:
        The op returned by ``AdamOptimizer.minimize``.
    """
    trainable = tf.trainable_variables()
    adam = tf.train.AdamOptimizer(learning_rate=utils.lr)
    return adam.minimize(loss, var_list=trainable)
def train():
    """Train the model and save a checkpoint to ``./result/result.ckpt``.

    Loads two datasets through ``read_img.read_flowers``, builds the graph
    with ``model.get_model``, and runs ``utils.epochs`` passes of Adam
    updates over consecutive ``utils.batch_size`` slices of the first
    dataset. After each epoch the loss on the second dataset is printed.
    The session variables are saved once training finishes.
    """
    import os  # local import: only needed to prepare the checkpoint folder

    checkpoint_path = "./result/result.ckpt"
    # Create the output directory before training so that a missing folder
    # cannot make the final save fail after all epochs have already run.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    # presumably True selects the training split and False the test split;
    # verify against read_img.read_flowers.
    x_train, y_train, _ = read_img.read_flowers(True)
    x_test, y_test, _ = read_img.read_flowers(False)

    x, y, one_hot, result = model.get_model(is_train=True)
    loss = get_loss(result, one_hot)
    optimizer = get_optimizer(loss)
    saver = tf.train.Saver()

    num_samples = x_train.shape[0]
    # Integer division: samples beyond the last full batch are not used.
    steps_per_epoch = num_samples // utils.batch_size

    # NOTE(review): the nesting below (evaluate once per epoch, save once
    # after the last epoch) is reconstructed — confirm against the original.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(utils.epochs):
            for batch in range(steps_per_epoch):
                start = batch * utils.batch_size
                stop = start + utils.batch_size
                next_x = x_train[start:stop]
                next_y = y_train[start:stop]
                sess.run(optimizer, feed_dict={x: next_x, y: next_y})
            # The whole second dataset is fed in a single run call.
            loss_result = sess.run(loss, feed_dict={x: x_test, y: y_test})
            print("epoch: {}, loss: {}".format(epoch, loss_result))
        saver.save(sess, checkpoint_path)
# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    train()