|
|
|
|
@ -108,17 +108,3 @@ model.evaluate(input_fn)
|
|
|
|
|
|
|
|
|
|
# 预测单个图像
|
|
|
|
|
n_images = 6
|
|
|
|
|
# 从数据集得到测试图像
|
|
|
|
|
test_images = mnist.test.images[:n_images]
|
|
|
|
|
# 准备输入数据
|
|
|
|
|
input_fn = tf.estimator.inputs.numpy_input_fn(
|
|
|
|
|
x={'images': test_images}, shuffle=False)
|
|
|
|
|
# 用训练好的模型预测图片类别
|
|
|
|
|
preds = list(model.predict(input_fn))
|
|
|
|
|
|
|
|
|
|
# 可视化显示
|
|
|
|
|
for i in range(n_images):
|
|
|
|
|
plt.imshow(np.reshape(test_images[i], [28, 28]), cmap='gray')
|
|
|
|
|
plt.show()
|
|
|
|
|
print("Model prediction:", preds[i])
|
|
|
|
|
|
|
|
|
|
|