diff --git a/predict.py b/predict.py index 994ece3..42e2afa 100644 --- a/predict.py +++ b/predict.py @@ -25,7 +25,7 @@ model.summary() model.evaluate(test_images.reshape(-1, 28, 28, 1), test_labels) # 可视化预测效果 -show_num = 300 +show_num = 170 testShow = test_labels[:show_num] pred = model.predict(test_images.reshape(-1, 28, 28, 1)) @@ -36,7 +36,7 @@ for item in pred: plt.figure() plt.title('Conv Predict') plt.ylabel('number') -plt.plot(range(testShow.size), predict[:show_num], label='predict') -plt.plot(range(testShow.size), testShow, label='result') +plt.plot(range(testShow.size), predict[:show_num], marker='^', color='coral',label='predict') +plt.plot(range(testShow.size), testShow, marker='o',color='deepskyblue',label='result') plt.legend() plt.show()