|
|
import os
|
|
|
import tensorflow as tf
|
|
|
import matplotlib.pyplot as plt
|
|
|
from tensorflow import keras
|
|
|
import numpy as np
|
|
|
import shutil
|
|
|
|
|
|
# 检查点路径与模型保存路径
|
|
|
checkpoint_path = "./checkpoint/cp.ckpt"
|
|
|
model_path = "./myModel/myModel.h5"
|
|
|
checkpoint_dir = os.path.dirname(checkpoint_path)
|
|
|
model_dir = os.path.dirname(model_path)
|
|
|
|
|
|
# 抑制tensorflow,以防显存占用过多报错
|
|
|
config = tf.compat.v1.ConfigProto(gpu_options=tf.compat.v1.GPUOptions(allow_growth=True))
|
|
|
sess = tf.compat.v1.Session(config=config)
|
|
|
|
|
|
# 读取手写数字数据集
|
|
|
num_mnist = keras.datasets.mnist
|
|
|
(train_images, train_labels), (test_images, test_labels) = num_mnist.load_data()
|
|
|
# 图片归一化处理
|
|
|
train_images_scaled = train_images / 255
|
|
|
|
|
|
# 构建网络 开始
|
|
|
'''
|
|
|
卷积 池化 卷积 池化 全连接 全连接
|
|
|
'''
|
|
|
model = keras.Sequential()
|
|
|
model.add(keras.layers.Conv2D(8, (3, 3), activation='relu', input_shape=(28, 28, 1)))
|
|
|
model.add(keras.layers.MaxPooling2D(2, 2))
|
|
|
model.add(keras.layers.Conv2D(8, (3, 3), activation='relu'))
|
|
|
model.add(keras.layers.MaxPooling2D(2, 2))
|
|
|
|
|
|
model.add(keras.layers.Flatten()) # 扁平化处理,实现从卷积到全连接的过渡
|
|
|
model.add(keras.layers.Dense(128, activation=tf.nn.relu))
|
|
|
model.add(keras.layers.Dense(10, activation=tf.nn.softmax))
|
|
|
|
|
|
model.compile(optimizer='adam', loss=tf.losses.sparse_categorical_crossentropy, metrics=['accuracy'])
|
|
|
# 构建网络 结束
|
|
|
|
|
|
# 打印网络结构
|
|
|
model.summary()
|
|
|
|
|
|
# 回调函数,用于训练中保存模型
|
|
|
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
|
|
|
save_weights_only=True,
|
|
|
verbose=1)
|
|
|
|
|
|
# 发现检查点中存在保存的模型权值,则认为上次训练被中断,读取权值继续训练
|
|
|
if (os.path.exists(checkpoint_dir)):
|
|
|
print('检测到未完成的训练,已读取检查点权值继续训练')
|
|
|
model.load_weights(checkpoint_path)
|
|
|
|
|
|
# 训练
|
|
|
history = model.fit(
|
|
|
train_images_scaled.reshape(-1, 28, 28, 1),
|
|
|
train_labels,
|
|
|
epochs=8,
|
|
|
validation_data=(test_images.reshape(-1, 28, 28, 1), test_labels),
|
|
|
callbacks=[cp_callback]
|
|
|
)
|
|
|
|
|
|
# 保存模型
|
|
|
model.save(model_path)
|
|
|
|
|
|
# 删除检查点文件
|
|
|
shutil.rmtree(checkpoint_dir)
|
|
|
print("已删除检查点文件")
|
|
|
|
|
|
# 评估
|
|
|
results = model.evaluate(test_images.reshape(-1, 28, 28, 1), test_labels)
|