You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

104 lines
4.5 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from PIL import Image
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import forward
from batchdealing import get_files
# 获取一张图片
def get_one_image(train):
# 输入参数train,训练图片的路径
# 返回参数image从训练图片中随机抽取一张图片
n = len(train)
ind = np.random.randint(0, n)
img_dir = train[ind] # 随机选择测试的图片
img = Image.open(img_dir)
# 显示图片在jupyter notebook下当然也可以不用plt.show()
plt.imshow(img)
plt.show(img)
imag = img.resize([64, 64]) # 由于图片在预处理阶段以及resize因此该命令可略
image = np.array(imag)
return image
# 测试图片
def evaluate_one_image(image_array):
with tf.Graph().as_default():
BATCH_SIZE = 1
N_CLASSES = 4
image = tf.cast(image_array, tf.float32)
# 线性缩放图像以具有零均值和单位范数。
image = tf.image.per_image_standardization(image)
image = tf.reshape(image, [1, 64, 64, 3])
# 构建卷积神经网络
logit = forward.inference(image, BATCH_SIZE, N_CLASSES)
# softmax函数的作用就是归一化
# 输入: 全连接层往往是模型的最后一层的值一般代码中叫做logits。
# 输出: 归一化的值含义是属于该位置的概率一般代码叫做probs。
logit = tf.nn.softmax(logit)
x = tf.compat.v1.placeholder(tf.float32, shape=[64, 64, 3])
# you need to change the directories to yours. /Users/leixinhong/PycharmProjects/classification/teethimg/Re_train/
logs_train_dir = '/Users/leixinhong/PycharmProjects/classification/teethimg/Re_train/'
# tf.train.Saver() 保存和加载模型
saver = tf.compat.v1.train.Saver()
with tf.compat.v1.Session() as sess:
print("Reading checkpoints...")
ckpt = tf.train.get_checkpoint_state(logs_train_dir)
if ckpt and ckpt.model_checkpoint_path:
global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
saver.restore(sess, ckpt.model_checkpoint_path)
print('Loading success, global_step is %s' % global_step)
else:
print('No checkpoint file found')
# feed_dict的作用是给使用placeholder创建出来的tensor赋值。
# 其实他的作用更加广泛feed使用一个值临时替换一个op的输出结果。
# 你可以提供feed数据作为run()调用的参数。
# feed只在调用它的方法内有效,方法结束,feed就会消失。
# 当我们构建完图后需要在一个会话中启动图启动的第一步是创建一个Session对象。
# 为了取回Fetch操作的输出内容可以在使用Session对象的run()调用执行图时,
# 传入一些tensor这些tensor会帮助你取回结果。
prediction = sess.run(logit, feed_dict={x: image_array})
# 取出prediction中元素最大值所对应的索引也就是最大的可能
max_index = np.argmax(prediction)
if max_index == 0:
print('This could be the first category %.6f' % prediction[:, 0])
elif max_index == 1:
print('This could be the second category %.6f' % prediction[:, 1])
elif max_index == 2:
print('This could be the third category %.6f' % prediction[:, 2])
elif max_index == 3:
print('This could be the fourth category %.6f' % prediction[:, 3])
elif max_index == 4:
print('This could be the fifth category %.6f' % prediction[:, 4])
elif max_index == 5:
print('This could be the sixth category %.6f' % prediction[:, 5])
elif max_index == 6:
print('This could be the seventh category %.6f' % prediction[:, 6])
elif max_index == 7:
print('This could be the eighth category %.6f' % prediction[:, 7])
else:
print('This could be the ninth category %.6f' % prediction[:, 9])
if __name__ == '__main__':
train_dir = '/Users/leixinhong/PycharmProjects/classification/teethimg/Re_train/'
train, train_label, val, val_label = get_files(train_dir, 0.3)
img = get_one_image(val) # 通过改变参数train or val进而验证训练集或测试集
evaluate_one_image(img)