You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

93 lines
3.6 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import tensorflow as tf
import tensorflow.keras as keras
import os
from matplotlib import pyplot as plt
from PIL import Image
import numpy as np
mnist = keras.datasets.mnist
# Load the MNIST train/test splits.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Normalize pixel values from [0, 255] to [0, 1].
# Cast to float32 *before* dividing: dividing the raw integer arrays by 255.0
# would produce float64 arrays (twice the memory), which Keras layers (float32
# by default) would then have to cast back down on every batch.
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
# Append a trailing axis so each sample gains a channel dimension for Conv2D
# (presumably the single grayscale channel; confirm if the dataset changes).
x_train = tf.expand_dims(x_train, -1)
x_test = tf.expand_dims(x_test, -1)
def creat_model(input_shape=None, num_classes=10):
    """Build and compile a small convolutional classifier.

    Architecture: one 3x3 convolution block (Conv2D -> BatchNorm -> ReLU ->
    2x2 max-pool -> dropout 0.2) followed by a dense head
    (Flatten -> Dense(128, relu) -> Dense(num_classes, softmax)).

    Args:
        input_shape: Optional per-sample input shape, e.g. ``(28, 28, 1)``.
            When given, the model is built immediately, so ``summary()`` can
            be called before the first ``fit``. When ``None`` (the default),
            no input layer is added and Keras builds the model lazily from
            the first data it sees, exactly as before.
        num_classes: Number of output units in the final softmax layer
            (default 10).

    Returns:
        A compiled ``keras.Sequential`` model (Adam optimizer, sparse
        categorical cross-entropy loss, sparse categorical accuracy metric).
    """
    layer_stack = []
    if input_shape is not None:
        layer_stack.append(keras.Input(shape=input_shape))
    layer_stack += [
        # Convolutional feature extractor.
        keras.layers.Conv2D(32, 3, padding="SAME"),
        keras.layers.BatchNormalization(),
        keras.layers.Activation("relu"),
        keras.layers.MaxPool2D(pool_size=(2, 2), strides=2),
        keras.layers.Dropout(0.2),
        # Classifier head.
        keras.layers.Flatten(),
        keras.layers.Dense(128, activation="relu"),
        keras.layers.Dense(num_classes, activation="softmax"),
    ]
    model = keras.Sequential(layer_stack)
    model.compile(
        optimizer="adam",  # Adam optimizer for the per-batch weight updates.
        # from_logits=False because the last layer already applies softmax,
        # so the model outputs probabilities rather than raw logits.
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        # Evaluation metric; labels are integer class ids ("sparse").
        metrics=["sparse_categorical_accuracy"],
    )
    return model
def eva_acc(str, model, x=None, y=None):
    """Load saved weights into ``model`` and report its performance on test data.

    Args:
        str: Path of the weights checkpoint passed to ``model.load_weights``.
            (This name shadows the builtin ``str``; it is kept unchanged so
            existing keyword callers keep working.)
        model: A compiled Keras model whose architecture matches the checkpoint.
        x: Optional evaluation inputs; defaults to the module-level ``x_test``.
        y: Optional evaluation labels; defaults to the module-level ``y_test``.

    Returns:
        ``(loss, metric)``: the two values returned by ``model.evaluate``
        (the loss, then the single compiled metric, e.g. accuracy).
    """
    model.load_weights(str)
    # Fall back to the module-level test split when no data is supplied.
    if x is None:
        x = x_test
    if y is None:
        y = y_test
    final_loss, final_acc = model.evaluate(x, y, verbose=2)
    print("Model accuracy: ", final_acc, ", model loss: ", final_loss)
    return final_loss, final_acc
def _make_callbacks(check_save_path):
    """Return the [checkpoint, early-stopping] callbacks used by ``model_fit``."""
    # Save weights only when the monitored value improves. ``val_loss`` is
    # ModelCheckpoint's default monitor; it is spelled out here because it is
    # NOT the metric EarlyStopping watches below, so the saved "best" weights
    # are the lowest-validation-loss epoch.
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=check_save_path,
        monitor="val_loss",
        save_weights_only=True,
        save_best_only=True,
    )
    es_callback = keras.callbacks.EarlyStopping(
        monitor="val_sparse_categorical_accuracy",  # quantity to watch
        min_delta=0.001,  # changes smaller than this do not count as improvement
        patience=1,  # epochs without improvement tolerated before stopping
        mode="max",  # higher accuracy is better
    )
    return [cp_callback, es_callback]


def _plot_history(history):
    """Plot training/validation accuracy and loss curves side by side."""
    curves = history.history
    # Start a fresh figure so repeated calls don't draw over an old plot.
    plt.figure()

    plt.subplot(1, 2, 1)
    plt.plot(curves['sparse_categorical_accuracy'], label='Training Accuracy')
    plt.plot(curves['val_sparse_categorical_accuracy'], label='Validation Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(curves['loss'], label='Training Loss')
    plt.plot(curves['val_loss'], label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.legend()

    plt.show()


def model_fit(model, check_save_path, epochs=5, batch_size=32):
    """Train ``model`` on the module-level training data, resuming if possible.

    If weights were previously saved at ``check_save_path`` they are loaded
    first. Training uses a 15% validation split, checkpoints the best weights
    to ``check_save_path``, stops early when validation accuracy plateaus,
    prints the model summary, and plots the learning curves.

    Args:
        model: A compiled Keras model (e.g. from ``creat_model``).
        check_save_path: Path prefix for the weights checkpoint.
        epochs: Maximum number of training epochs (default 5).
        batch_size: Mini-batch size (default 32).

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    # A TensorFlow-format weights checkpoint writes a ``<path>.index`` file;
    # its presence means a previous run saved weights we can resume from.
    if os.path.exists(check_save_path + '.index'):
        print("load modals...")
        model.load_weights(check_save_path)

    # Note: Keras takes the validation split from the *last* 15% of the
    # samples, before any shuffling.
    history = model.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs,
        validation_split=0.15,
        validation_freq=1,
        callbacks=_make_callbacks(check_save_path),
    )
    # Print layer/parameter summary (the model is guaranteed built after fit).
    model.summary()
    _plot_history(history)
    return history
if __name__ == "__main__":
    # Train (resuming from an existing checkpoint if present), then reload
    # the best saved weights and evaluate them on the test split.
    ckpt_prefix = "./checkpoint/mnist_cnn1.ckpt"
    cnn = creat_model()
    model_fit(cnn, ckpt_prefix)
    eva_acc(ckpt_prefix, cnn)