|
|
import tensorflow as tf
|
|
|
import tensorflow.keras as keras
|
|
|
import os
|
|
|
from matplotlib import pyplot as plt
|
|
|
from PIL import Image
|
|
|
import numpy as np
|
|
|
|
|
|
mnist = keras.datasets.mnist

# Load the MNIST dataset as (train, test) splits of (images, labels).
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Normalise the samples: map each pixel value from [0, 255] to [0, 1].
# Cast to float32 before dividing: dividing the raw uint8 arrays by 255.0
# directly promotes them to float64, which doubles memory for the whole
# dataset even though Keras trains in float32 by default.
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Add a trailing channel axis, because Conv2D expects
# (batch, height, width, channels) input.
x_train = tf.expand_dims(x_train, -1)
x_test = tf.expand_dims(x_test, -1)
|
|
|
|
|
|
def creat_model(input_shape=(28, 28, 1), num_classes=10):
    """Build and compile the CNN classifier.

    The network is a single convolution stage (Conv2D -> BatchNorm -> ReLU ->
    2x2 max-pool -> dropout) followed by a dense classifier head.

    Args:
        input_shape: Shape of one input sample, excluding the batch axis.
            Declaring it up front builds the model immediately, so its weights
            exist before ``load_weights``/``summary`` are called. The default
            matches the channel-expanded MNIST images.
        num_classes: Number of output classes (width of the softmax layer).

    Returns:
        A compiled ``keras.Sequential`` model whose final layer outputs
        softmax probabilities. It is trained with sparse categorical
        cross-entropy, i.e. on integer class labels.
    """
    model = keras.Sequential([
        keras.Input(shape=input_shape),
        # Convolutional feature extractor.
        keras.layers.Conv2D(32, 3, padding="SAME"),
        keras.layers.BatchNormalization(),
        keras.layers.Activation("relu"),
        keras.layers.MaxPool2D(pool_size=(2, 2), strides=2),
        keras.layers.Dropout(0.2),
        # Classifier head.
        keras.layers.Flatten(),
        keras.layers.Dense(128, activation="relu"),
        keras.layers.Dense(num_classes, activation="softmax"),
    ])

    model.compile(
        optimizer="adam",  # Adam optimiser
        # The last layer already applies softmax, so the loss receives
        # probabilities rather than logits.
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        # Evaluation metric. Its name also determines the
        # "val_sparse_categorical_accuracy" key the training callbacks monitor.
        metrics=["sparse_categorical_accuracy"],
    )
    return model
|
|
|
|
|
|
|
|
|
def eva_acc(str, model, x=None, y=None):
    """Restore weights from a checkpoint and report evaluation metrics.

    Args:
        str: Path of the weights checkpoint passed to ``model.load_weights``.
            The name shadows the builtin ``str`` inside this function. It is
            kept unchanged so existing keyword callers keep working.
        model: A compiled Keras model whose architecture matches the
            checkpoint.
        x: Evaluation inputs; defaults to the module-level ``x_test``.
        y: Evaluation labels; defaults to the module-level ``y_test``.

    Returns:
        ``(loss, accuracy)`` from ``model.evaluate``. This assumes the model
        was compiled with exactly one metric, as ``creat_model`` does.
    """
    # Fall back to the module-level test split when no data is supplied.
    if x is None:
        x = x_test
    if y is None:
        y = y_test

    model.load_weights(str)
    final_loss, final_acc = model.evaluate(x, y, verbose=2)
    print("Model accuracy: ", final_acc, ", model loss: ", final_loss)
    return final_loss, final_acc
|
|
|
|
|
|
|
|
|
|
|
|
def model_fit(model, check_save_path, batch_size=32, epochs=5):
    """Train ``model`` on the module-level training split and plot curves.

    If a checkpoint already exists at ``check_save_path``, its weights are
    restored first, so training resumes from the previous run.

    Args:
        model: A compiled Keras model (e.g. from ``creat_model``).
        check_save_path: Path prefix of the weights checkpoint. Weights are
            written here whenever the monitored validation accuracy improves.
        batch_size: Mini-batch size passed to ``model.fit``.
        epochs: Maximum number of epochs; early stopping may end sooner.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    # The presence of the "<prefix>.index" companion file is used as the
    # signal that a previously saved checkpoint exists.
    if os.path.exists(check_save_path + '.index'):
        print("load modals...")
        model.load_weights(check_save_path)

    # Checkpointing and early stopping must rank epochs by the same quantity.
    # Otherwise the checkpoint (previously chosen by the default "val_loss")
    # is not the model that early stopping judged best.
    monitor = "val_sparse_categorical_accuracy"
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=check_save_path,
        monitor=monitor,
        mode='max',
        save_weights_only=True,
        save_best_only=True)
    es_callback = keras.callbacks.EarlyStopping(
        monitor=monitor,  # quantity to watch
        min_delta=0.001,  # only a change larger than this counts as an improvement
        patience=1,       # epochs without improvement tolerated before stopping
        mode='max')       # accuracy should go up; set explicitly rather than 'auto'

    history = model.fit(x_train, y_train,
                        batch_size=batch_size, epochs=epochs,
                        validation_split=0.15, validation_freq=1,
                        callbacks=[cp_callback, es_callback])

    # Print per-layer output shapes and parameter counts.
    model.summary()

    _plot_history(history)
    return history


def _plot_history(history):
    """Show training/validation accuracy and loss curves side by side."""
    metrics = history.history
    acc = metrics['sparse_categorical_accuracy']
    val_acc = metrics['val_sparse_categorical_accuracy']
    loss = metrics['loss']
    val_loss = metrics['val_loss']

    plt.subplot(1, 2, 1)
    plt.plot(acc, label='Training Accuracy')
    plt.plot(val_acc, label='Validation Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.legend()

    plt.show()
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the network, train it (resuming from an
    # existing checkpoint if one is present), then reload the saved weights
    # and report test-set metrics.
    model, check_save_path = creat_model(), "./checkpoint/mnist_cnn1.ckpt"
    model_fit(model=model, check_save_path=check_save_path)
    eva_acc(check_save_path, model=model)
|