You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

67 lines
2.3 KiB

import tensorflow as tf
import os
from matplotlib import pyplot as plt
from PIL import Image
import numpy as np
# Load the MNIST dataset through Keras as (train, test) pairs of
# (inputs, labels).
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Rescale the inputs by 1/255 (presumably 8-bit pixel intensities, which
# would map them into [0, 1] -- confirm against the dataset documentation).
x_train = x_train / 255.0
x_test = x_test / 255.0
def creat_model():
    """Build and compile a small fully connected classifier.

    The network flattens its input, applies a 128-unit ReLU layer and ends
    in a 10-unit softmax layer.  It is compiled with the Adam optimizer,
    sparse categorical cross-entropy and sparse categorical accuracy.

    Returns:
        The compiled ``tf.keras`` Sequential model.
    """
    layers = tf.keras.layers
    stack = [
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dense(10, activation='softmax'),
    ]
    network = tf.keras.models.Sequential(stack)

    # The last layer already applies softmax, so the loss receives
    # probabilities rather than raw logits.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
    network.compile(
        optimizer='adam',
        loss=loss_fn,
        metrics=["sparse_categorical_accuracy"],
    )
    return network
def model_fit(model, check_save_path):
    """Train ``model``, checkpoint its weights, evaluate it and plot curves.

    Trains on the module-level ``x_train`` / ``y_train`` arrays, evaluates
    on ``x_test`` / ``y_test``, prints the test metrics and shows a
    two-panel figure with the accuracy and loss histories.

    Args:
        model: Compiled Keras model to train.
        check_save_path: Path used both to restore existing weights and as
            the ``filepath`` of the checkpoint callback.
    """
    # Resume from previously saved weights when "<check_save_path>.index"
    # is present on disk.
    index_file = check_save_path + '.index'
    if os.path.exists(index_file):
        print("load modals...")
        model.load_weights(check_save_path)

    # Save weights only, and only when the callback's monitored quantity
    # improves.
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        filepath=check_save_path,
        save_weights_only=True,
        save_best_only=True,
    )

    # 15% of the training data is held out for validation after each epoch.
    history = model.fit(
        x_train,
        y_train,
        batch_size=32,
        epochs=5,
        validation_split=0.15,
        validation_freq=1,
        callbacks=[checkpoint],
    )
    model.summary()

    test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
    print("Model accuracy: ", test_acc, ", model loss: ", test_loss)

    # One panel per quantity: (training curve, label, validation curve,
    # label, panel title).  All four histories are read before any plotting.
    logs = history.history
    panels = (
        (
            logs['sparse_categorical_accuracy'],
            'Training Accuracy',
            logs['val_sparse_categorical_accuracy'],
            'Validation Accuracy',
            'Training and Validation Accuracy',
        ),
        (
            logs['loss'],
            'Training Loss',
            logs['val_loss'],
            'Validation Loss',
            'Training and Validation Loss',
        ),
    )
    for position, panel in enumerate(panels, start=1):
        train_curve, train_label, val_curve, val_label, title = panel
        plt.subplot(1, 2, position)
        plt.plot(train_curve, label=train_label)
        plt.plot(val_curve, label=val_label)
        plt.title(title)
        plt.legend()
    plt.show()
def eva_acc(str, model):
    """Load saved weights into ``model`` and print its test-set metrics.

    Args:
        str: Path passed to ``model.load_weights``.  The name shadows the
            builtin ``str`` inside this function; it is kept so existing
            callers keep working.
        model: Compiled Keras model that receives the weights and is then
            evaluated on the module-level ``x_test`` / ``y_test`` arrays.
    """
    weights_path = str
    model.load_weights(weights_path)

    test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
    print("Model accuracy: ", test_acc, ", model loss: ", test_loss)
if __name__ == "__main__":
    # Script entry point: build the model, train it while checkpointing to
    # check_save_path, then reload the saved weights and evaluate again.
    check_save_path = "./checkpoint/mnist.ckpt"
    model = creat_model()
    model_fit(model=model, check_save_path=check_save_path)
    eva_acc(check_save_path, model)