import os

import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
from matplotlib import pyplot as plt
from PIL import Image

mnist = keras.datasets.mnist
# Load the MNIST dataset.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Normalize the samples: scale every pixel value from [0, 255] to [0, 1].
x_train = x_train / 255.0
x_test = x_test / 255.0
# Append a trailing axis so each image carries an explicit channel dimension
# for the Conv2D input.
x_train = tf.expand_dims(x_train, -1)
x_test = tf.expand_dims(x_test, -1)


def creat_model():
    """Build and compile the CNN digit classifier.

    Architecture: one convolutional block (32 filters of size 3, "SAME"
    padding, batch normalization, ReLU, 2x2 max-pooling with stride 2,
    dropout 0.2) followed by a classifier (flatten -> 128-unit ReLU dense
    layer -> 10-unit softmax).

    Returns:
        A ``keras.Sequential`` model compiled with the Adam optimizer,
        sparse categorical cross-entropy on probabilities (the last layer
        already applies softmax, hence ``from_logits=False``), and the
        ``sparse_categorical_accuracy`` metric.
    """
    model = keras.Sequential([
        # Convolutional feature extractor.
        keras.layers.Conv2D(32, 3, padding="SAME"),
        keras.layers.BatchNormalization(),
        keras.layers.Activation("relu"),
        keras.layers.MaxPool2D(pool_size=(2, 2), strides=2),
        keras.layers.Dropout(0.2),
        # Classifier head.
        keras.layers.Flatten(),
        keras.layers.Dense(128, activation="relu"),
        keras.layers.Dense(10, activation="softmax"),
    ])

    model.compile(
        optimizer='adam',
        # Integer labels + softmax outputs -> sparse CE on probabilities.
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=["sparse_categorical_accuracy"],
    )
    return model


def eva_acc(str, model):
    """Load saved weights into *model* and print its test-set accuracy and loss.

    Args:
        str: Path of the weights checkpoint to load. The name shadows the
            builtin ``str``; it is kept unchanged so existing keyword
            callers keep working.
        model: A compiled model with the same architecture as the one the
            weights were saved from.
    """
    model.load_weights(str)
    final_loss, final_acc = model.evaluate(x_test, y_test, verbose=2)
    print("Model accuracy: ", final_acc, ", model loss: ", final_loss)


def _plot_history(history):
    """Plot training/validation accuracy and loss side by side, then show it."""
    acc = history.history['sparse_categorical_accuracy']
    val_acc = history.history['val_sparse_categorical_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    plt.subplot(1, 2, 1)
    plt.plot(acc, label='Training Accuracy')
    plt.plot(val_acc, label='Validation Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.show()


def model_fit(model, check_save_path):
    """Train *model* on the MNIST training split and plot the learning curves.

    If a checkpoint index file (``check_save_path + '.index'``) already
    exists, those weights are loaded first so training continues from them.
    During training, the weights of the best epoch (by validation accuracy)
    are saved to *check_save_path*. Training stops early when validation
    accuracy stops improving.

    Args:
        model: A compiled model, e.g. from ``creat_model()``.
        check_save_path: Checkpoint path prefix used both to resume from and
            to save the best weights to.
    """
    if os.path.exists(check_save_path + '.index'):
        print("load modals...")
        model.load_weights(check_save_path)

    # Both callbacks judge "best" by the same metric and direction. Left
    # unset, ModelCheckpoint monitors val_loss, so the weights saved (and
    # later reloaded by eva_acc) could come from a different epoch than the
    # one early stopping considers best.
    monitor = "val_sparse_categorical_accuracy"
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=check_save_path,
        monitor=monitor,
        mode='max',
        save_weights_only=True,
        save_best_only=True,
    )
    ES_callback = keras.callbacks.EarlyStopping(
        monitor=monitor,  # quantity to watch
        min_delta=0.001,  # only a change larger than this counts as an improvement
        patience=1,  # epochs without improvement tolerated before stopping
        mode='max',  # one of 'auto', 'min', 'max'; set explicitly since higher accuracy is better
    )

    history = model.fit(
        x_train, y_train,
        batch_size=32,
        epochs=5,
        validation_split=0.15,
        validation_freq=1,
        callbacks=[cp_callback, ES_callback],
    )

    # Print the layer/parameter summary (the model is built after fit).
    model.summary()
    _plot_history(history)


if __name__ == "__main__":
    check_save_path = "./checkpoint/mnist_cnn1.ckpt"
    model = creat_model()
    model_fit(model, check_save_path)
    eva_acc(check_save_path, model)