diff --git a/README.md b/README.md index 699d33a..5cc53c3 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ + + # num_mnist 手写数字识别 @@ -94,7 +96,7 @@ $$ #### 实现过程 - +构建网络模型 ```python # 构建网络 开始 @@ -116,6 +118,73 @@ model.compile(optimizer='adam', loss=tf.losses.sparse_categorical_crossentropy, ``` +训练并评估 + +```python +# 训练 +history = model.fit( + train_images_scaled.reshape(-1, 28, 28, 1), + train_labels, + epochs=8, + validation_data=(test_images.reshape(-1, 28, 28, 1), test_labels), + callbacks=[cp_callback] +) + +# 评估 +results = model.evaluate(test_images.reshape(-1, 28, 28, 1), test_labels) +``` + +将训练后的模型保存 + +```python +# 保存模型 +model.save(model_path) +``` + +加载模型 + +```python +# 读取训练好的模型 +model = load_model(model_path) +``` + +可视化预测效果 + +``` +# 可视化预测效果 +show_num = 200 +testShow = test_labels[:show_num] + +pred = model.predict(test_images.reshape(-1, 28, 28, 1)) +predict = [] +for item in pred: + predict.append(np.argmax(item)) + +# 显示折线图 +plt.figure() +plt.title('Conv Predict') +plt.ylabel('true number') +plt.xlabel('img num') +plt.plot(range(testShow.size), predict[:show_num], marker='^', color='coral', label='predict') +plt.plot(range(testShow.size), testShow, marker='o', color='deepskyblue', label='result') +plt.legend() + +# 挑选出预测错误的图片,并显示预测值 +wrongImg = [] +wrongNum = [] +for i in range(testShow.size): + if (predict[i] != testShow[i]): + wrongImg.append(test_images[i]) + wrongNum.append(predict[i]) + +for i in range(len(wrongImg)): + plt.figure() + plt.title(str(wrongNum[i])) + plt.imshow(wrongImg[i]) + +plt.show() +``` + ### 系统测试 diff --git a/predict.py b/predict.py index 347be77..e52c8e3 100644 --- a/predict.py +++ b/predict.py @@ -15,7 +15,7 @@ sess = tf.compat.v1.Session(config=config) num_mnist = tf.keras.datasets.mnist (train_images, train_labels), (test_images, test_labels) = num_mnist.load_data() -# 读取训练好点模型 +# 读取训练好的模型 model = load_model(model_path) # 打印网络结构