import os
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
import numpy as np

# Path to the trained model.
model_path = "./myModel/myModel.h5"

# Number of leading test samples to visualize.
show_num = 200

# Let TensorFlow allocate GPU memory on demand instead of reserving it all
# up front, to avoid out-of-memory errors.
# Under TF2 eager execution Keras does not run inside a tf.compat.v1.Session,
# so passing GPUOptions(allow_growth=True) to an unused v1 Session does not
# affect load_model/evaluate/predict. Memory growth is instead set per
# physical GPU, and this must happen before any GPU has been initialized.
for gpu in tf.config.experimental.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(gpu, True)

# Load the MNIST handwritten-digit dataset (only the test split is used below).
num_mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = num_mnist.load_data()

# Load the trained model.
model = load_model(model_path)

# Print the network architecture.
model.summary()

# Evaluate on the full test set; images are reshaped to (N, 28, 28, 1),
# i.e. a single trailing channel axis.
# NOTE(review): pixels are passed as raw 0-255 values. Confirm the model was
# trained on unnormalized inputs; otherwise scale by 1/255 here and below.
model.evaluate(test_images.reshape(-1, 28, 28, 1), test_labels)

# Visualize predictions for the first `show_num` test samples. Only this
# subset is predicted, because no other predictions are used.
show_images = test_images[:show_num]
show_labels = test_labels[:show_num]
pred = model.predict(show_images.reshape(-1, 28, 28, 1))
# Predicted class = index of the highest score for each sample (each sample's
# scores are flattened first, matching a per-sample np.argmax).
predict = np.argmax(pred.reshape(len(pred), -1), axis=1)

# Line chart: predicted vs. true digit for each displayed sample.
x = range(show_labels.size)
plt.figure()
plt.title('Conv Predict')
plt.ylabel('true number')
plt.xlabel('img num')
plt.plot(x, predict, marker='^', color='coral', label='predict')
plt.plot(x, show_labels, marker='o', color='deepskyblue', label='result')
plt.legend()

# Show each misclassified image in its own figure, titled with the
# (incorrect) predicted digit.
wrong = predict != show_labels
for img, wrong_num in zip(show_images[wrong], predict[wrong]):
    plt.figure()
    plt.title(str(wrong_num))
    plt.imshow(img, cmap='gray')  # MNIST digits are single-channel grayscale

# Display all figures together (the chart is shown even if nothing is wrong).
plt.show()