diff --git a/myModel/myModel.h5 b/myModel/myModel.h5 index 3353b3f..f9f081b 100644 Binary files a/myModel/myModel.h5 and b/myModel/myModel.h5 differ diff --git a/predict.py b/predict.py index 42e2afa..b2699e6 100644 --- a/predict.py +++ b/predict.py @@ -25,7 +25,7 @@ model.summary() model.evaluate(test_images.reshape(-1, 28, 28, 1), test_labels) # 可视化预测效果 -show_num = 170 +show_num = 200 testShow = test_labels[:show_num] pred = model.predict(test_images.reshape(-1, 28, 28, 1)) @@ -33,10 +33,25 @@ predict = [] for item in pred: predict.append(np.argmax(item)) +# 显示折线图 plt.figure() plt.title('Conv Predict') plt.ylabel('number') -plt.plot(range(testShow.size), predict[:show_num], marker='^', color='coral',label='predict') -plt.plot(range(testShow.size), testShow, marker='o',color='deepskyblue',label='result') +plt.plot(range(testShow.size), predict[:show_num], marker='^', color='coral', label='predict') +plt.plot(range(testShow.size), testShow, marker='o', color='deepskyblue', label='result') plt.legend() + +# 挑选出预测错误的图片,并显示预测值 +wrongImg = [] +wrongNum = [] +for i in range(testShow.size): + if (predict[i] != testShow[i]): + wrongImg.append(test_images[i]) + wrongNum.append(predict[i]) + +for i in range(len(wrongImg)): + plt.figure() + plt.title(str(wrongNum[i])) + plt.imshow(wrongImg[i]) + plt.show()