diff --git a/README.md b/README.md index c4dcd66..45f0655 100644 --- a/README.md +++ b/README.md @@ -108,3 +108,45 @@ model.evaluate(input_fn) # 预测单个图像 n_images = 6 +# 从数据集得到测试图像 +test_images = mnist.test.images[:n_images] +# 准备输入数据 +input_fn = tf.estimator.inputs.numpy_input_fn( + x={'images': test_images}, shuffle=False) +# 用训练好的模型预测图片类别 +preds = list(model.predict(input_fn)) + +# 可视化显示 +for i in range(n_images): + plt.imshow(np.reshape(test_images[i], [28, 28]), cmap='gray') + plt.show() + print("Model prediction:", preds[i]) + + + # 从数据集得到测试图像 +test_images = mnist.test.images[:n_images] +# 准备输入数据 +input_fn = tf.estimator.inputs.numpy_input_fn( + x={'images': test_images}, shuffle=False) +# 用训练好的模型预测图片类别 +preds = list(model.predict(input_fn)) + +# 可视化显示 +for i in range(n_images): + plt.imshow(np.reshape(test_images[i], [28, 28]), cmap='gray') + plt.show() + print("Model prediction:", preds[i]) + +# 从数据集得到测试图像 +test_images = mnist.test.images[:n_images] +# 准备输入数据 +input_fn = tf.estimator.inputs.numpy_input_fn( + x={'images': test_images}, shuffle=False) +# 用训练好的模型预测图片类别 +preds = list(model.predict(input_fn)) + +# 可视化显示 +for i in range(n_images): + plt.imshow(np.reshape(test_images[i], [28, 28]), cmap='gray') + plt.show() + print("Model prediction:", preds[i])