You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

48 lines
1.2 KiB

import tensorflow.compat.v1 as tf
import numpy as np
import model
import read_img
import utils
import random
def test(num_samples=200, ckpt_path="./result/result.ckpt"):
    """Evaluate the saved model on a random subset of the flower test set.

    Restores the checkpoint at *ckpt_path*, runs inference on up to
    *num_samples* distinct test images, and prints one line per image:
    whether the prediction was correct, the image name, the true label name
    and the predicted label name. It then prints the overall accuracy.

    Args:
        num_samples: Maximum number of test images to evaluate. Images are
            drawn without replacement, so if the test set has fewer images
            than this, every image is evaluated exactly once.
        ckpt_path: Path of the TensorFlow checkpoint to restore.

    Returns:
        The fraction of evaluated images whose arg-max prediction equals
        the label.

    Raises:
        ValueError: If the test set is empty or ``num_samples`` < 1.
    """
    x_test, y_test, raw_names = read_img.read_flowers(False)
    total = len(x_test)
    if total == 0:
        raise ValueError("test set is empty")
    if num_samples < 1:
        raise ValueError("num_samples must be at least 1")

    # Sample WITHOUT replacement: a duplicated image would be scored more
    # than once and bias the accuracy estimate.
    idxs = random.sample(range(total), min(num_samples, total))
    pics = [x_test[i] for i in idxs]
    labels = [y_test[i] for i in idxs]
    names = [raw_names[i] for i in idxs]

    # Build the graph in evaluation mode (is_train=False, keep_prob=1).
    x, _, _, result = model.get_model(is_train=False, keep_prob=1)
    with tf.Session() as sess:
        saver = tf.train.Saver()
        saver.restore(sess, ckpt_path)
        # `sess` is the default session inside this `with`, so Tensor.eval
        # runs in it and returns a concrete array of per-image outputs.
        dists = result.eval(feed_dict={x: pics})

    right_count = 0
    for name, label, dist in zip(names, labels, dists):
        predicted = np.argmax(dist)
        correct = predicted == label
        if correct:
            right_count += 1
        # NOTE(review): this evaluates the *flowers* dataset but maps label
        # ids with utils.get_traffic_name -- confirm that is the intended
        # label-name lookup.
        print("{}: {} is {}, result is {}".format(
            correct, name,
            utils.get_traffic_name(label),
            utils.get_traffic_name(predicted)))

    accuracy = right_count / len(dists)
    print("accuracy is {}".format(accuracy))
    return accuracy
if __name__ == "__main__":
    # Run the evaluation only when this file is executed as a script,
    # not when it is imported.
    test()