"""Train, checkpoint, and evaluate a small dense classifier on MNIST.

Run as a script, this builds the model and resumes from an existing
checkpoint if one is present. It trains for a few epochs while keeping the
best weights on disk, plots the learning curves, and finally re-evaluates
the best saved weights on the test set.
"""
import os

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from PIL import Image

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Scale the 8-bit pixel values into the [0, 1] range.
x_train, x_test = x_train / 255.0, x_test / 255.0


def creat_model():
    """Build and compile the MNIST classifier.

    The network flattens each image, applies one 128-unit ReLU hidden layer,
    and ends in a 10-way softmax. The output layer already applies softmax,
    so the loss is configured with ``from_logits=False``.

    Returns:
        A compiled ``tf.keras.models.Sequential`` model.
    """
    model = tf.keras.models.Sequential([
        # Declare the per-sample input shape (taken from the data itself) so
        # the model is built immediately rather than lazily on the first
        # fit(). Passing it via ``input_shape`` keeps the layer/checkpoint
        # structure identical to a lazily built model, so previously saved
        # checkpoints still restore.
        tf.keras.layers.Flatten(input_shape=x_train.shape[1:]),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax'),
    ])
    model.compile(
        optimizer='adam',
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=["sparse_categorical_accuracy"],
    )
    return model


# Correctly spelled alias; ``creat_model`` is kept for existing callers.
create_model = creat_model


def _evaluate_and_report(model):
    """Evaluate *model* on the test split, print, and return (loss, accuracy)."""
    final_loss, final_acc = model.evaluate(x_test, y_test, verbose=2)
    print("Model accuracy: ", final_acc, ", model loss: ", final_loss)
    return final_loss, final_acc


def _plot_history(history):
    """Plot training/validation accuracy and loss curves side by side."""
    metrics = history.history

    plt.subplot(1, 2, 1)
    plt.plot(metrics['sparse_categorical_accuracy'], label='Training Accuracy')
    plt.plot(metrics['val_sparse_categorical_accuracy'],
             label='Validation Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(metrics['loss'], label='Training Loss')
    plt.plot(metrics['val_loss'], label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.show()


def model_fit(model, check_save_path):
    """Train *model* on MNIST, checkpointing the best weights, then report.

    If a weights checkpoint already exists at *check_save_path* (detected by
    its ``.index`` file), it is loaded first, so training resumes from it.
    After training, the model summary and test metrics are printed and the
    learning curves are plotted.

    Args:
        model: A compiled model, e.g. from ``creat_model()``.
        check_save_path: Checkpoint path prefix used for both loading and
            saving weights.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    # TensorFlow-format weight checkpoints write a "<prefix>.index" file.
    if os.path.exists(check_save_path + '.index'):
        print("load modals...")
        model.load_weights(check_save_path)

    # Only overwrite the checkpoint when the monitored validation metric
    # improves, so the file on disk always holds the best epoch seen.
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=check_save_path,
        save_weights_only=True,
        save_best_only=True,
    )
    history = model.fit(
        x_train, y_train,
        batch_size=32,
        epochs=5,
        validation_split=0.15,  # hold out 15% of the training data
        validation_freq=1,
        callbacks=[cp_callback],
    )
    model.summary()
    _evaluate_and_report(model)
    _plot_history(history)
    return history


def eva_acc(str, model):
    """Load weights from checkpoint path *str* into *model* and report test metrics.

    The parameter name shadows the builtin ``str``. It is kept unchanged only
    so that existing keyword callers keep working.

    Returns:
        A ``(loss, accuracy)`` tuple measured on the test split.
    """
    model.load_weights(str)
    return _evaluate_and_report(model)


if __name__ == "__main__":
    check_save_path = "./checkpoint/mnist.ckpt"
    model = creat_model()
    model_fit(model, check_save_path)
    eva_acc(check_save_path, model)