# MNIST handwritten-digit classifier: a small CNN trained with on-the-fly
# image augmentation (Keras/TensorFlow), with checkpointing and evaluation.
import tensorflow.keras as keras
import numpy as np
import os
import tensorflow as tf
import matplotlib.pyplot as plt
# Load the MNIST handwritten-digit dataset.
mnist = keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale pixel values from [0, 255] into [0, 1].
x_train = x_train / 255.0
x_test = x_test / 255.0

# Append a trailing channel axis so the arrays match the Conv2D input layout.
x_train = tf.expand_dims(x_train, -1)
x_test = tf.expand_dims(x_test, -1)

# On-the-fly augmentation applied while training.
datagen = keras.preprocessing.image.ImageDataGenerator(
    rotation_range=20,        # degrees of random rotation
    width_shift_range=0.20,   # random horizontal shift, fraction of image width
    shear_range=15,           # shear angle in degrees (counter-clockwise)
    zoom_range=0.10,          # random zoom in [1 - 0.10, 1 + 0.10]
    validation_split=0.15,    # fraction of the images held out for validation
    horizontal_flip=False,    # digits must not be mirrored
)

# Generator yielding augmented training batches.
train_datagen = datagen.flow(
    x_train,
    y_train,
    batch_size=256,
    subset="training",
)

# Generator yielding the held-out validation batches.
validation_genetor = datagen.flow(
    x_train,
    y_train,
    batch_size=64,
    subset="validation",
)
def creat_model():
    """Build and compile the CNN classifier for 28x28x1 MNIST images.

    Architecture: three conv/pool stages followed by a fully-connected
    classifier head with dropout.

    :return: a compiled ``keras.Sequential`` model.
    """
    model = keras.Sequential([
        # Ensure the input is shaped (28, 28, 1) whatever the caller fed in.
        keras.layers.Reshape((28, 28, 1)),
        # Stage 1: 5x5 convolution + 2x2 max pooling.
        keras.layers.Conv2D(filters=32, kernel_size=(5, 5), activation="relu", padding="same",
                            input_shape=(28, 28, 1)),
        keras.layers.MaxPool2D((2, 2)),
        # Stage 2: two 3x3 convolutions + pooling.
        keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation="relu", padding="same"),
        keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation="relu", padding="same"),
        keras.layers.MaxPool2D((2, 2)),
        # Stage 3: two 3x3 convolutions + pooling.
        keras.layers.Conv2D(filters=128, kernel_size=(3, 3), activation="relu", padding="same"),
        keras.layers.Conv2D(filters=128, kernel_size=(3, 3), activation="relu", padding="same"),
        keras.layers.MaxPool2D((2, 2)),
        # Classifier head: several fully-connected layers with dropout.
        keras.layers.Flatten(),
        keras.layers.Dense(512, activation="sigmoid"),
        keras.layers.Dropout(0.25),
        keras.layers.Dense(512, activation="sigmoid"),
        keras.layers.Dropout(0.25),
        keras.layers.Dense(256, activation="sigmoid"),
        keras.layers.Dropout(0.1),
        # BUG FIX: the output layer previously used "sigmoid", whose 10
        # independent outputs do not form a probability distribution.
        # SparseCategoricalCrossentropy(from_logits=False) expects normalized
        # class probabilities, so the output activation must be softmax.
        keras.layers.Dense(10, activation="softmax")
    ])
    # Configure the training procedure: Adam optimizer, cross-entropy loss on
    # integer labels, and per-batch accuracy as the reported metric.
    model.compile(optimizer='adam',
                  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=["sparse_categorical_accuracy"])
    return model
def model_fit(model, check_save_path):
    """Train *model* on the module-level augmentation generators.

    Resumes from ``check_save_path`` if a checkpoint exists, trains with
    checkpointing, LR reduction and early stopping, then plots the
    accuracy/loss curves.

    :param model: compiled Keras model to train.
    :param check_save_path: checkpoint path used for loading/saving weights.
    """
    # Resume from a previous run when checkpoint files are present.
    if os.path.exists(check_save_path + '.index'):
        print("load modals...")
        model.load_weights(check_save_path)

    # Persist only the weights, and only when validation performance improves.
    checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(filepath=check_save_path,
                                                       save_weights_only=True,
                                                       save_best_only=True)

    # Shrink the learning rate (lr = lr * factor) once val_loss has stalled
    # for `patience` epochs, never going below `min_lr`.
    lr_schedule_cb = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
                                                          factor=0.1,
                                                          patience=2,
                                                          min_lr=0.000001,
                                                          verbose=1)

    # Stop training once validation accuracy has not improved by at least
    # `min_delta` for `patience` consecutive epochs.
    early_stop_cb = keras.callbacks.EarlyStopping(
        monitor='val_sparse_categorical_accuracy',
        min_delta=0.0001,
        patience=2)

    hist = model.fit(train_datagen, epochs=10, validation_data=validation_genetor, callbacks=[lr_schedule_cb, checkpoint_cb, early_stop_cb], verbose=1)
    model.summary()

    # Visualize the training history: accuracy on the left, loss on the right.
    train_acc = hist.history['sparse_categorical_accuracy']
    valid_acc = hist.history['val_sparse_categorical_accuracy']
    train_loss = hist.history['loss']
    valid_loss = hist.history['val_loss']

    plt.subplot(1, 2, 1)
    plt.plot(train_acc, label='Training Accuracy')
    plt.plot(valid_acc, label='Validation Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(train_loss, label='Training Loss')
    plt.plot(valid_loss, label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.show()
# Measure final accuracy on the held-out test set.
def model_valtest(model, check_save_path):
    """Restore the checkpointed weights and evaluate on (x_test, y_test)."""
    model.load_weights(check_save_path)
    loss_value, acc_value = model.evaluate(x_test, y_test, verbose=2)
    print("Model accuracy: ", acc_value, ", model loss: ", loss_value)
if __name__ == "__main__":
    # Checkpoint path used by both training and evaluation.
    check_save_path = "./checkpoint/mnist_cnn3.ckpt"
    model = creat_model()
    # Training is disabled here; uncomment to (re)train before evaluating.
    # model_fit(model, check_save_path)
    model_valtest(model, check_save_path)