import os import tensorflow as tf import matplotlib.pyplot as plt from tensorflow import keras import numpy as np import shutil # 检查点路径与模型保存路径 checkpoint_path = "./checkpoint/cp.ckpt" model_path = "./myModel/myModel.h5" checkpoint_dir = os.path.dirname(checkpoint_path) model_dir = os.path.dirname(model_path) # 抑制tensorflow,以防显存占用过多报错 config = tf.compat.v1.ConfigProto(gpu_options=tf.compat.v1.GPUOptions(allow_growth=True)) sess = tf.compat.v1.Session(config=config) # 读取手写数字数据集 num_mnist = keras.datasets.mnist (train_images, train_labels), (test_images, test_labels) = num_mnist.load_data() # 图片归一化处理 train_images_scaled = train_images / 255 # 构建网络 开始 ''' 卷积 池化 卷积 池化 全连接 全连接 ''' model = keras.Sequential() model.add(keras.layers.Conv2D(8, (3, 3), activation='relu', input_shape=(28, 28, 1))) model.add(keras.layers.MaxPooling2D(2, 2)) model.add(keras.layers.Conv2D(8, (3, 3), activation='relu')) model.add(keras.layers.MaxPooling2D(2, 2)) model.add(keras.layers.Flatten()) # 扁平化处理,实现从卷积到全连接的过渡 model.add(keras.layers.Dense(128, activation=tf.nn.relu)) model.add(keras.layers.Dense(10, activation=tf.nn.softmax)) model.compile(optimizer='adam', loss=tf.losses.sparse_categorical_crossentropy, metrics=['accuracy']) # 构建网络 结束 # 打印网络结构 model.summary() # 回调函数,用于训练中保存模型 cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, save_weights_only=True, verbose=1) # 发现检查点中存在保存的模型权值,则认为上次训练被中断,读取权值继续训练 if (os.path.exists(checkpoint_dir)): print('检测到未完成的训练,已读取检查点权值继续训练') model.load_weights(checkpoint_path) # 训练 history = model.fit( train_images_scaled.reshape(-1, 28, 28, 1), train_labels, epochs=8, validation_data=(test_images.reshape(-1, 28, 28, 1), test_labels), callbacks=[cp_callback] ) # 保存模型 model.save(model_path) # 删除检查点文件 shutil.rmtree(checkpoint_dir) print("已删除检查点文件") # 评估 results = model.evaluate(test_images.reshape(-1, 28, 28, 1), test_labels)