"""Evaluate a saved MNIST Keras model and plot its predictions.

Loads the model stored at ``model_path`` and prints its architecture. It then
evaluates the model on the MNIST test split. Finally, it plots the predicted
digit against the true digit for the first ``show_num`` test images.
"""
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model

# Path of the trained model to load.
model_path = "./myModel/myModel.h5"

# Let TensorFlow allocate GPU memory on demand rather than reserving it all up
# front, so the script does not fail with out-of-memory errors. In TF2 this has
# to be set through tf.config before any GPU is initialized. A config passed to
# an unused tf.compat.v1.Session is not guaranteed to affect the eager runtime
# that Keras uses.
for gpu in tf.config.experimental.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(gpu, True)

# Load the MNIST handwritten-digit dataset.
num_mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = num_mnist.load_data()

# The model is fed 28x28 single-channel images, so add the channel axis once.
# NOTE(review): pixel values are passed unscaled, exactly as before; confirm
# this matches the preprocessing used when the model was trained.
x_test = test_images.reshape(-1, 28, 28, 1)

# Load the trained model.
model = load_model(model_path)

# Print the network architecture.
model.summary()

# Evaluate on the full test split (Keras prints the results).
model.evaluate(x_test, test_labels)

# Visualize predictions for the first `show_num` test images.
show_num = min(170, len(test_labels))
testShow = test_labels[:show_num]

# Only the plotted images need predicting. Taking argmax over the last axis
# gives the predicted digit per image.
# Assumes the model outputs one score per class, i.e. shape (N, num_classes);
# TODO confirm.
pred = model.predict(x_test[:show_num])
predict = np.argmax(pred, axis=1)

sample_idx = range(len(testShow))
plt.figure()
plt.title('Conv Predict')
plt.ylabel('number')
plt.plot(sample_idx, predict, marker='^', color='coral', label='predict')
# 'result' is the ground-truth label series.
plt.plot(sample_idx, testShow, marker='o', color='deepskyblue', label='result')
plt.legend()
plt.show()