You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
48 lines
1.4 KiB
48 lines
1.4 KiB
3 years ago
|
import tensorflow.compat.v1 as tf
|
||
|
import model
|
||
|
import utils
|
||
|
import read_img
|
||
|
|
||
|
|
||
|
def get_loss(result, y):
    """Return the batch-averaged softmax cross-entropy of ``result`` against ``y``.

    Args:
        result: Unnormalized logits tensor produced by the model.
        y: Target tensor with the same shape as ``result`` (``train`` passes
            the model's ``one_hot`` labels here).

    Returns:
        A scalar tensor: the mean of the per-example cross-entropy values.
    """
    # NOTE(review): the v1 op is deprecated in favour of
    # softmax_cross_entropy_with_logits_v2; kept as-is to preserve gradients.
    per_example = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=result)
    mean_loss = tf.reduce_mean(per_example)
    return mean_loss
|
||
|
|
||
|
|
||
|
def get_optimizer(loss):
    """Build an Adam training op that minimizes ``loss``.

    The learning rate is read from ``utils.lr``. All variables currently in
    the graph's trainable-variables collection are optimized.

    Args:
        loss: Scalar loss tensor to minimize.

    Returns:
        The op returned by ``AdamOptimizer.minimize``; each ``sess.run`` of it
        performs one parameter update.
    """
    params = tf.trainable_variables()
    adam = tf.train.AdamOptimizer(learning_rate=utils.lr)
    train_op = adam.minimize(loss, var_list=params)
    return train_op
|
||
|
|
||
|
|
||
|
def train(save_path="./result/result.ckpt"):
    """Train the model on the flower data and save a checkpoint.

    Loads the training and test splits via ``read_img.read_flowers``, builds the
    graph with ``model.get_model(is_train=True)``, runs ``utils.epochs`` epochs of
    mini-batch updates (``utils.batch_size`` samples per step), prints the
    test-set loss after each epoch, and finally saves all variables with
    ``tf.train.Saver``.

    Args:
        save_path: Checkpoint prefix passed to ``Saver.save``. Defaults to the
            path that was previously hard-coded. Its parent directory is created
            if missing, because ``Saver.save`` refuses to write into a
            non-existent directory.
    """
    import os  # only needed to create the checkpoint directory

    x_train, y_train, _ = read_img.read_flowers(True)
    x_test, y_test, _ = read_img.read_flowers(False)

    # Number of training *samples* (formerly misnamed `train_batches`).
    num_train = x_train.shape[0]
    batch_size = utils.batch_size

    x, y, one_hot, result = model.get_model(is_train=True)
    loss = get_loss(result, one_hot)
    optimizer = get_optimizer(loss)

    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for epoch in range(utils.epochs):
            # Step through every sample, including the trailing partial batch;
            # iterating `num_train // batch_size` batches skipped the last
            # `num_train % batch_size` samples on every epoch. A variable batch
            # size is already required, since the whole test set is fed below.
            for start in range(0, num_train, batch_size):
                next_x = x_train[start:start + batch_size]
                next_y = y_train[start:start + batch_size]
                sess.run(optimizer, feed_dict={x: next_x, y: next_y})

            # NOTE(review): the graph was built with is_train=True, so any
            # training-only behaviour inside model.get_model (e.g. dropout) is
            # also active for this evaluation — confirm that is intended.
            loss_result = sess.run(loss, feed_dict={x: x_test, y: y_test})
            print("epoch: {}, loss: {}".format(epoch, loss_result))

        # Saver.save fails if the parent directory does not exist; create it.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        saver.save(sess, save_path)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Executing this file directly starts a full training run.
    train()
|