from PIL import Image import numpy as np import tensorflow as tf import matplotlib.pyplot as plt import forward from batchdealing import get_files # 获取一张图片 def get_one_image(train): # 输入参数:train,训练图片的路径 # 返回参数:image,从训练图片中随机抽取一张图片 n = len(train) ind = np.random.randint(0, n) img_dir = train[ind] # 随机选择测试的图片 img = Image.open(img_dir) # 显示图片,在jupyter notebook下当然也可以不用plt.show() plt.imshow(img) plt.show(img) imag = img.resize([64, 64]) # 由于图片在预处理阶段以及resize,因此该命令可略 image = np.array(imag) return image # 测试图片 def evaluate_one_image(image_array): with tf.Graph().as_default(): BATCH_SIZE = 1 N_CLASSES = 4 image = tf.cast(image_array, tf.float32) # 线性缩放图像以具有零均值和单位范数。 image = tf.image.per_image_standardization(image) image = tf.reshape(image, [1, 64, 64, 3]) # 构建卷积神经网络 logit = forward.inference(image, BATCH_SIZE, N_CLASSES) # softmax函数的作用就是归一化 # 输入: 全连接层(往往是模型的最后一层)的值,一般代码中叫做logits。 # 输出: 归一化的值,含义是属于该位置的概率,一般代码叫做probs。 logit = tf.nn.softmax(logit) x = tf.compat.v1.placeholder(tf.float32, shape=[64, 64, 3]) # you need to change the directories to yours. /Users/leixinhong/PycharmProjects/classification/teethimg/Re_train/ logs_train_dir = '/Users/leixinhong/PycharmProjects/classification/teethimg/Re_train/' # tf.train.Saver() 保存和加载模型 saver = tf.compat.v1.train.Saver() with tf.compat.v1.Session() as sess: print("Reading checkpoints...") ckpt = tf.train.get_checkpoint_state(logs_train_dir) if ckpt and ckpt.model_checkpoint_path: global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] saver.restore(sess, ckpt.model_checkpoint_path) print('Loading success, global_step is %s' % global_step) else: print('No checkpoint file found') # feed_dict的作用是给使用placeholder创建出来的tensor赋值。 # 其实,他的作用更加广泛:feed使用一个值临时替换一个op的输出结果。 # 你可以提供feed数据作为run()调用的参数。 # feed只在调用它的方法内有效,方法结束,feed就会消失。 # 当我们构建完图后,需要在一个会话中启动图,启动的第一步是创建一个Session对象。 # 为了取回(Fetch)操作的输出内容,可以在使用Session对象的run()调用执行图时, # 传入一些tensor,这些tensor会帮助你取回结果。 prediction = sess.run(logit, feed_dict={x: image_array}) # 取出prediction中元素最大值所对应的索引,也就是最大的可能 max_index = np.argmax(prediction) if max_index == 0: print('This could be the first category %.6f' % prediction[:, 0]) elif max_index == 1: print('This could be the second category %.6f' % prediction[:, 1]) elif max_index == 2: print('This could be the third category %.6f' % prediction[:, 2]) elif max_index == 3: print('This could be the fourth category %.6f' % prediction[:, 3]) elif max_index == 4: print('This could be the fifth category %.6f' % prediction[:, 4]) elif max_index == 5: print('This could be the sixth category %.6f' % prediction[:, 5]) elif max_index == 6: print('This could be the seventh category %.6f' % prediction[:, 6]) elif max_index == 7: print('This could be the eighth category %.6f' % prediction[:, 7]) else: print('This could be the ninth category %.6f' % prediction[:, 9]) if __name__ == '__main__': train_dir = '/Users/leixinhong/PycharmProjects/classification/teethimg/Re_train/' train, train_label, val, val_label = get_files(train_dir, 0.3) img = get_one_image(val) # 通过改变参数train or val,进而验证训练集或测试集 evaluate_one_image(img)