parent
3fc3678987
commit
aefddf00a5
@ -0,0 +1,48 @@
|
|||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import model
|
||||||
|
import read_img
|
||||||
|
import utils
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
|
def test():
|
||||||
|
x_test, y_test, raw_names = read_img.read_flowers(False)
|
||||||
|
idxs = [random.randint(0, x_test.shape[0] - 1) for _ in range(200)]
|
||||||
|
|
||||||
|
pics = []
|
||||||
|
labels = []
|
||||||
|
names = []
|
||||||
|
|
||||||
|
for i in idxs:
|
||||||
|
pics.append(x_test[i])
|
||||||
|
labels.append(y_test[i])
|
||||||
|
names.append(raw_names[i])
|
||||||
|
|
||||||
|
x, _, _, result = model.get_model(is_train=False, keep_prob=1)
|
||||||
|
|
||||||
|
with tf.Session() as sess:
|
||||||
|
saver = tf.train.Saver()
|
||||||
|
saver.restore(sess, "./result/result.ckpt")
|
||||||
|
|
||||||
|
dists = result.eval(feed_dict={x: pics})
|
||||||
|
|
||||||
|
right_count = 0
|
||||||
|
for i in range(len(dists)):
|
||||||
|
print(i)
|
||||||
|
dist = dists[i]
|
||||||
|
pred_result = np.argmax(dist) == labels[i]
|
||||||
|
if pred_result:
|
||||||
|
right_count += 1
|
||||||
|
|
||||||
|
print("{}: {} is {}, result is {}".format(pred_result, names[i],
|
||||||
|
utils.get_traffic_name(labels[i]),
|
||||||
|
utils.get_traffic_name(np.argmax(dist))))
|
||||||
|
|
||||||
|
print("accuracy is {}".format(right_count / len(dists)))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test()
|
Loading…
Reference in new issue