|
|
|
@ -1,5 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# num_mnist 手写数字识别
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -94,7 +96,7 @@ $$
|
|
|
|
|
|
|
|
|
|
#### 实现过程
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
构建网络模型
|
|
|
|
|
|
|
|
|
|
```python
|
|
|
|
|
# 构建网络 开始
|
|
|
|
@ -116,6 +118,73 @@ model.compile(optimizer='adam', loss=tf.losses.sparse_categorical_crossentropy,
|
|
|
|
|
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
训练并评估
|
|
|
|
|
|
|
|
|
|
```python
|
|
|
|
|
# 训练
|
|
|
|
|
history = model.fit(
|
|
|
|
|
train_images_scaled.reshape(-1, 28, 28, 1),
|
|
|
|
|
train_labels,
|
|
|
|
|
epochs=8,
|
|
|
|
|
validation_data=(test_images.reshape(-1, 28, 28, 1), test_labels),
|
|
|
|
|
callbacks=[cp_callback]
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# 评估
|
|
|
|
|
results = model.evaluate(test_images.reshape(-1, 28, 28, 1), test_labels)
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
将训练后的模型保存
|
|
|
|
|
|
|
|
|
|
```python
|
|
|
|
|
# 保存模型
|
|
|
|
|
model.save(model_path)
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
加载模型
|
|
|
|
|
|
|
|
|
|
```python
|
|
|
|
|
# 读取训练好的模型
|
|
|
|
|
model = load_model(model_path)
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
可视化预测效果
|
|
|
|
|
|
|
|
|
|
```
|
|
|
|
|
# 可视化预测效果
|
|
|
|
|
show_num = 200
|
|
|
|
|
testShow = test_labels[:show_num]
|
|
|
|
|
|
|
|
|
|
pred = model.predict(test_images.reshape(-1, 28, 28, 1))
|
|
|
|
|
predict = []
|
|
|
|
|
for item in pred:
|
|
|
|
|
predict.append(np.argmax(item))
|
|
|
|
|
|
|
|
|
|
# 显示折线图
|
|
|
|
|
plt.figure()
|
|
|
|
|
plt.title('Conv Predict')
|
|
|
|
|
plt.ylabel('true number')
|
|
|
|
|
plt.xlabel('img num')
|
|
|
|
|
plt.plot(range(testShow.size), predict[:show_num], marker='^', color='coral', label='predict')
|
|
|
|
|
plt.plot(range(testShow.size), testShow, marker='o', color='deepskyblue', label='result')
|
|
|
|
|
plt.legend()
|
|
|
|
|
|
|
|
|
|
# 挑选出预测错误的图片,并显示预测值
|
|
|
|
|
wrongImg = []
|
|
|
|
|
wrongNum = []
|
|
|
|
|
for i in range(testShow.size):
|
|
|
|
|
if (predict[i] != testShow[i]):
|
|
|
|
|
wrongImg.append(test_images[i])
|
|
|
|
|
wrongNum.append(predict[i])
|
|
|
|
|
|
|
|
|
|
for i in range(len(wrongImg)):
|
|
|
|
|
plt.figure()
|
|
|
|
|
plt.title(str(wrongNum[i]))
|
|
|
|
|
plt.imshow(wrongImg[i])
|
|
|
|
|
|
|
|
|
|
plt.show()
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
### 系统测试
|
|
|
|
|