You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
|
|
|
|
import os
|
|
|
|
|
import tensorflow as tf
|
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
from tensorflow.keras.models import load_model
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
# Path to the trained Keras model that is loaded further down in this script.
model_path = "./myModel/myModel.h5"

# Enable GPU memory growth so TensorFlow allocates GPU memory on demand instead
# of reserving it all up front (avoids out-of-memory errors on a shared GPU).
# The previous approach built a tf.compat.v1.Session with allow_growth=True that
# was never used by Keras; tf.config is the documented mechanism in TF 2.x.
# NOTE(review): assumes TensorFlow 2.x (the file uses tf.keras) — confirm.
for gpu in tf.config.experimental.list_physical_devices("GPU"):
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        # Memory growth can only be set before the GPUs have been initialized;
        # report instead of aborting so the script still runs.
        print(err)
|
|
|
|
|
|
|
|
|
|
# Load the MNIST handwritten-digit dataset. load_data() yields a
# (train, test) pair, each of which is an (images, labels) tuple.
num_mnist = tf.keras.datasets.mnist
mnist_train, mnist_test = num_mnist.load_data()
train_images, train_labels = mnist_train
test_images, test_labels = mnist_test
|
|
|
|
|
|
|
|
|
|
# Restore the previously trained network stored at model_path.
model = load_model(model_path)

# Print the layer-by-layer architecture of the network.
model.summary()

# Evaluate on the test split. A trailing single-channel axis is added so the
# input is shaped (-1, 28, 28, 1), as the model is fed elsewhere in this script.
eval_inputs = test_images.reshape(-1, 28, 28, 1)
model.evaluate(x=eval_inputs, y=test_labels)
|
|
|
|
|
|
|
|
|
|
# Visualize predictions: plot the predicted digit and the ground-truth label
# for each of the first `show_num` test samples.
show_num = 300
testShow = test_labels[:show_num]

# Only the plotted samples need predictions, so run inference on just those
# instead of the whole test set (previously all test images were predicted and
# everything beyond the first `show_num` results was discarded).
pred = model.predict(test_images[:show_num].reshape(-1, 28, 28, 1))

# Predicted digit = index of the highest score in each row of `pred`.
# NOTE(review): assumes the model outputs one score per class, i.e. `pred`
# is 2-D (samples, classes) — confirm against the model's output layer.
predict = np.argmax(pred, axis=1)

plt.figure()
plt.title('Conv Predict')
plt.xlabel('sample index')
plt.ylabel('number')
plt.plot(range(testShow.size), predict, label='predict')
plt.plot(range(testShow.size), testShow, label='result')
plt.legend()
plt.show()
|