diff --git a/test.py b/test.py deleted file mode 100644 index 81d2c02..0000000 --- a/test.py +++ /dev/null @@ -1,103 +0,0 @@ -from PIL import Image -import numpy as np -import tensorflow as tf -import matplotlib.pyplot as plt -import forward -from batchdealing import get_files - - -# 获取一张图片 -def get_one_image(train): - # 输入参数:train,训练图片的路径 - # 返回参数:image,从训练图片中随机抽取一张图片 - n = len(train) - ind = np.random.randint(0, n) - img_dir = train[ind] # 随机选择测试的图片 - - img = Image.open(img_dir) - - # 显示图片,在jupyter notebook下当然也可以不用plt.show() - plt.imshow(img) - plt.show(img) - imag = img.resize([64, 64]) # 由于图片在预处理阶段以及resize,因此该命令可略 - image = np.array(imag) - return image - - -# 测试图片 -def evaluate_one_image(image_array): - with tf.Graph().as_default(): - BATCH_SIZE = 1 - N_CLASSES = 4 - - image = tf.cast(image_array, tf.float32) - - # 线性缩放图像以具有零均值和单位范数。 - image = tf.image.per_image_standardization(image) - image = tf.reshape(image, [1, 64, 64, 3]) - - # 构建卷积神经网络 - logit = forward.inference(image, BATCH_SIZE, N_CLASSES) - - # softmax函数的作用就是归一化 - # 输入: 全连接层(往往是模型的最后一层)的值,一般代码中叫做logits。 - # 输出: 归一化的值,含义是属于该位置的概率,一般代码叫做probs。 - logit = tf.nn.softmax(logit) - - x = tf.compat.v1.placeholder(tf.float32, shape=[64, 64, 3]) - - # you need to change the directories to yours. /Users/leixinhong/PycharmProjects/classification/teethimg/Re_train/ - logs_train_dir = '/Users/leixinhong/PycharmProjects/classification/teethimg/Re_train/' - - # tf.train.Saver() 保存和加载模型 - saver = tf.compat.v1.train.Saver() - - with tf.compat.v1.Session() as sess: - - print("Reading checkpoints...") - ckpt = tf.train.get_checkpoint_state(logs_train_dir) - if ckpt and ckpt.model_checkpoint_path: - global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] - saver.restore(sess, ckpt.model_checkpoint_path) - print('Loading success, global_step is %s' % global_step) - else: - print('No checkpoint file found') - - # feed_dict的作用是给使用placeholder创建出来的tensor赋值。 - # 其实,他的作用更加广泛:feed使用一个值临时替换一个op的输出结果。 - # 你可以提供feed数据作为run()调用的参数。 - # feed只在调用它的方法内有效,方法结束,feed就会消失。 - # 当我们构建完图后,需要在一个会话中启动图,启动的第一步是创建一个Session对象。 - # 为了取回(Fetch)操作的输出内容,可以在使用Session对象的run()调用执行图时, - # 传入一些tensor,这些tensor会帮助你取回结果。 - prediction = sess.run(logit, feed_dict={x: image_array}) - - # 取出prediction中元素最大值所对应的索引,也就是最大的可能 - max_index = np.argmax(prediction) - - if max_index == 0: - print('This could be the first category %.6f' % prediction[:, 0]) - elif max_index == 1: - print('This could be the second category %.6f' % prediction[:, 1]) - elif max_index == 2: - print('This could be the third category %.6f' % prediction[:, 2]) - elif max_index == 3: - print('This could be the fourth category %.6f' % prediction[:, 3]) - elif max_index == 4: - print('This could be the fifth category %.6f' % prediction[:, 4]) - elif max_index == 5: - print('This could be the sixth category %.6f' % prediction[:, 5]) - elif max_index == 6: - print('This could be the seventh category %.6f' % prediction[:, 6]) - elif max_index == 7: - print('This could be the eighth category %.6f' % prediction[:, 7]) - else: - print('This could be the ninth category %.6f' % prediction[:, 9]) - - -if __name__ == '__main__': - train_dir = '/Users/leixinhong/PycharmProjects/classification/teethimg/Re_train/' - train, train_label, val, val_label = get_files(train_dir, 0.3) - img = get_one_image(val) # 通过改变参数train or val,进而验证训练集或测试集 - evaluate_one_image(img) -