更新第二版readme

master
li.chengmeng 3 years ago
parent 2bbb508312
commit 32b55dbbba

@ -1,5 +1,7 @@
# num_mnist 手写数字识别
@ -94,7 +96,7 @@ $$
#### 实现过程
构建网络模型
```python
# 构建网络 开始
@ -116,6 +118,73 @@ model.compile(optimizer='adam', loss=tf.losses.sparse_categorical_crossentropy,
```
训练并评估
```python
# 训练
history = model.fit(
train_images_scaled.reshape(-1, 28, 28, 1),
train_labels,
epochs=8,
validation_data=(test_images.reshape(-1, 28, 28, 1), test_labels),
callbacks=[cp_callback]
)
# 评估
results = model.evaluate(test_images.reshape(-1, 28, 28, 1), test_labels)
```
将训练后的模型保存
```python
# 保存模型
model.save(model_path)
```
加载模型
```python
# 读取训练好的模型
model = load_model(model_path)
```
可视化预测效果
```
# 可视化预测效果
show_num = 200
testShow = test_labels[:show_num]
pred = model.predict(test_images.reshape(-1, 28, 28, 1))
predict = []
for item in pred:
predict.append(np.argmax(item))
# 显示折线图
plt.figure()
plt.title('Conv Predict')
plt.ylabel('true number')
plt.xlabel('img num')
plt.plot(range(testShow.size), predict[:show_num], marker='^', color='coral', label='predict')
plt.plot(range(testShow.size), testShow, marker='o', color='deepskyblue', label='result')
plt.legend()
# 挑选出预测错误的图片,并显示预测值
wrongImg = []
wrongNum = []
for i in range(testShow.size):
if (predict[i] != testShow[i]):
wrongImg.append(test_images[i])
wrongNum.append(predict[i])
for i in range(len(wrongImg)):
plt.figure()
plt.title(str(wrongNum[i]))
plt.imshow(wrongImg[i])
plt.show()
```
### 系统测试

@ -15,7 +15,7 @@ sess = tf.compat.v1.Session(config=config)
num_mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = num_mnist.load_data()
# 读取训练好模型
# 读取训练好模型
model = load_model(model_path)
# 打印网络结构

Loading…
Cancel
Save