diff --git a/test.py b/test.py new file mode 100644 index 0000000..81d2c02 --- /dev/null +++ b/test.py @@ -0,0 +1,103 @@ +from PIL import Image +import numpy as np +import tensorflow as tf +import matplotlib.pyplot as plt +import forward +from batchdealing import get_files + + +# 获取一张图片 +def get_one_image(train): + # 输入参数:train,训练图片的路径 + # 返回参数:image,从训练图片中随机抽取一张图片 + n = len(train) + ind = np.random.randint(0, n) + img_dir = train[ind] # 随机选择测试的图片 + + img = Image.open(img_dir) + + # 显示图片,在jupyter notebook下当然也可以不用plt.show() + plt.imshow(img) + plt.show(img) + imag = img.resize([64, 64]) # 由于图片在预处理阶段以及resize,因此该命令可略 + image = np.array(imag) + return image + + +# 测试图片 +def evaluate_one_image(image_array): + with tf.Graph().as_default(): + BATCH_SIZE = 1 + N_CLASSES = 4 + + image = tf.cast(image_array, tf.float32) + + # 线性缩放图像以具有零均值和单位范数。 + image = tf.image.per_image_standardization(image) + image = tf.reshape(image, [1, 64, 64, 3]) + + # 构建卷积神经网络 + logit = forward.inference(image, BATCH_SIZE, N_CLASSES) + + # softmax函数的作用就是归一化 + # 输入: 全连接层(往往是模型的最后一层)的值,一般代码中叫做logits。 + # 输出: 归一化的值,含义是属于该位置的概率,一般代码叫做probs。 + logit = tf.nn.softmax(logit) + + x = tf.compat.v1.placeholder(tf.float32, shape=[64, 64, 3]) + + # you need to change the directories to yours. /Users/leixinhong/PycharmProjects/classification/teethimg/Re_train/ + logs_train_dir = '/Users/leixinhong/PycharmProjects/classification/teethimg/Re_train/' + + # tf.train.Saver() 保存和加载模型 + saver = tf.compat.v1.train.Saver() + + with tf.compat.v1.Session() as sess: + + print("Reading checkpoints...") + ckpt = tf.train.get_checkpoint_state(logs_train_dir) + if ckpt and ckpt.model_checkpoint_path: + global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] + saver.restore(sess, ckpt.model_checkpoint_path) + print('Loading success, global_step is %s' % global_step) + else: + print('No checkpoint file found') + + # feed_dict的作用是给使用placeholder创建出来的tensor赋值。 + # 其实,他的作用更加广泛:feed使用一个值临时替换一个op的输出结果。 + # 你可以提供feed数据作为run()调用的参数。 + # feed只在调用它的方法内有效,方法结束,feed就会消失。 + # 当我们构建完图后,需要在一个会话中启动图,启动的第一步是创建一个Session对象。 + # 为了取回(Fetch)操作的输出内容,可以在使用Session对象的run()调用执行图时, + # 传入一些tensor,这些tensor会帮助你取回结果。 + prediction = sess.run(logit, feed_dict={x: image_array}) + + # 取出prediction中元素最大值所对应的索引,也就是最大的可能 + max_index = np.argmax(prediction) + + if max_index == 0: + print('This could be the first category %.6f' % prediction[:, 0]) + elif max_index == 1: + print('This could be the second category %.6f' % prediction[:, 1]) + elif max_index == 2: + print('This could be the third category %.6f' % prediction[:, 2]) + elif max_index == 3: + print('This could be the fourth category %.6f' % prediction[:, 3]) + elif max_index == 4: + print('This could be the fifth category %.6f' % prediction[:, 4]) + elif max_index == 5: + print('This could be the sixth category %.6f' % prediction[:, 5]) + elif max_index == 6: + print('This could be the seventh category %.6f' % prediction[:, 6]) + elif max_index == 7: + print('This could be the eighth category %.6f' % prediction[:, 7]) + else: + print('This could be the ninth category %.6f' % prediction[:, 9]) + + +if __name__ == '__main__': + train_dir = '/Users/leixinhong/PycharmProjects/classification/teethimg/Re_train/' + train, train_label, val, val_label = get_files(train_dir, 0.3) + img = get_one_image(val) # 通过改变参数train or val,进而验证训练集或测试集 + evaluate_one_image(img) +