|
|
import tensorflow.keras as keras
|
|
|
import numpy as np
|
|
|
import os
|
|
|
import tensorflow as tf
|
|
|
import matplotlib.pyplot as plt
|
|
|
# --- Load and prepare MNIST -------------------------------------------------
mnist = keras.datasets.mnist

# x_*: uint8 images of shape (N, 28, 28); y_*: integer class labels 0-9.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale pixel values from [0, 255] to [0, 1].
x_train = x_train / 255.0

x_test = x_test / 255.0

# Add a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1).
# FIX: use np.expand_dims instead of tf.expand_dims so the data stays a
# numpy array — ImageDataGenerator.flow() below is a numpy-based pipeline
# and would otherwise have to convert the eager Tensors back to numpy.
x_train = np.expand_dims(x_train, -1)

x_test = np.expand_dims(x_test, -1)

print("train shape:", x_train.shape)

print("test shape:", x_test.shape)
|
|
|
|
|
|
# Data augmentation: random rotations, shifts, shears and zooms make the
# classifier more robust to handwriting variation.  validation_split holds
# back 15% of x_train as a validation subset.
datagen = keras.preprocessing.image.ImageDataGenerator(
    rotation_range=20,        # degrees of random rotation
    width_shift_range=0.20,   # horizontal shift, as a fraction of image width
    shear_range=15,           # shear angle (counter-clockwise, in degrees)
    zoom_range=0.10,          # random zoom within [1 - 0.10, 1 + 0.10]
    validation_split=0.15,    # fraction of images reserved for validation
    horizontal_flip=False,    # digits are not mirror-symmetric, so no flips
)

# Augmented training stream drawn from the 85% "training" subset.
train_datagen = datagen.flow(x_train, y_train, batch_size=256, subset="training")

# Validation stream from the held-out 15% subset.
# (Name kept as-is — model_fit() below references `validation_genetor`.)
validation_genetor = datagen.flow(x_train, y_train, batch_size=64, subset="validation")
|
|
|
def creat_model():
    """Build and compile the MNIST CNN classifier.

    Architecture: three conv stages (32 -> 64x2 -> 128x2 filters), each
    followed by 2x2 max pooling, then three fully-connected layers with
    dropout, ending in a 10-way softmax.

    Returns:
        A compiled ``keras.Sequential`` model that accepts inputs
        reshapeable to (28, 28, 1) and outputs per-class probabilities.
    """
    model = keras.Sequential([
        # Accept either (28, 28) or (28, 28, 1) inputs.
        keras.layers.Reshape((28, 28, 1)),

        keras.layers.Conv2D(filters=32, kernel_size=(5, 5), activation="relu", padding="same",
                            input_shape=(28, 28, 1)),
        keras.layers.MaxPool2D((2, 2)),

        keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation="relu", padding="same"),
        keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation="relu", padding="same"),
        keras.layers.MaxPool2D((2, 2)),

        keras.layers.Conv2D(filters=128, kernel_size=(3, 3), activation="relu", padding="same"),
        keras.layers.Conv2D(filters=128, kernel_size=(3, 3), activation="relu", padding="same"),
        keras.layers.MaxPool2D((2, 2)),

        keras.layers.Flatten(),
        # NOTE(review): sigmoid hidden activations are unusual for a CNN
        # head (relu is conventional); kept unchanged to preserve behavior.
        keras.layers.Dense(512, activation="sigmoid"),
        keras.layers.Dropout(0.25),

        keras.layers.Dense(512, activation="sigmoid"),
        keras.layers.Dropout(0.25),

        keras.layers.Dense(256, activation="sigmoid"),
        keras.layers.Dropout(0.1),

        # BUG FIX: the output layer must be softmax — the loss below uses
        # from_logits=False, i.e. it expects a probability distribution over
        # the 10 classes, and independent sigmoids do not sum to 1.
        keras.layers.Dense(10, activation="softmax")
    ])

    model.compile(optimizer='adam',
                  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=["sparse_categorical_accuracy"])

    return model
|
|
|
|
|
|
def model_fit(model, check_save_path, epochs=1):
    """Train *model* on the augmented generators and plot the learning curves.

    Resumes from *check_save_path* if a checkpoint already exists, keeps only
    the best-performing weights during training, and shows accuracy/loss
    curves for both the training and the validation streams.

    Args:
        model: A compiled Keras model (see ``creat_model``).
        check_save_path: Path prefix for the TF-format weight checkpoint.
        epochs: Number of training epochs (default 1, matching the
            previous hard-coded value).
    """
    # Resume from a previous run: TF-format checkpoints are accompanied by
    # a ".index" file, so its presence signals saved weights.
    if os.path.exists(check_save_path + '.index'):
        print("loading saved weights...")
        model.load_weights(check_save_path)

    # Persist only the weights of the epoch with the best validation loss.
    cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=check_save_path,
                                                     save_weights_only=True,
                                                     save_best_only=True)

    # Shrink the learning rate by 10x when val_loss plateaus for 5 epochs.
    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
                                                     factor=0.1,
                                                     patience=5,
                                                     min_lr=0.000001,
                                                     verbose=1)

    # NOTE(review): train_datagen / validation_genetor are module-level
    # generators defined above; this function requires them to exist.
    history = model.fit(train_datagen,
                        epochs=epochs,
                        validation_data=validation_genetor,
                        callbacks=[reduce_lr, cp_callback],
                        verbose=1)

    model.summary()

    # Metric histories, one entry per epoch.
    acc = history.history['sparse_categorical_accuracy']
    val_acc = history.history['val_sparse_categorical_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    # Left panel: accuracy curves.
    plt.subplot(1, 2, 1)
    plt.plot(acc, label='Training Accuracy')
    plt.plot(val_acc, label='Validation Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.legend()

    # Right panel: loss curves.
    plt.subplot(1, 2, 2)
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.show()
|
|
|
|
|
|
|
|
|
def model_valtest(model, check_save_path):
    """Restore the checkpointed weights and report test-set loss/accuracy."""
    model.load_weights(check_save_path)
    # evaluate() returns metrics in compile() order: [loss, accuracy].
    metrics = model.evaluate(x_test, y_test, verbose=2)
    final_loss, final_acc = metrics
    print("Model accuracy: ", final_acc, ", model loss: ", final_loss)
|
|
|
|
|
|
if __name__ == "__main__":

    # Path prefix for the TF-format weight checkpoint (".index"/".data" files).
    check_save_path = "./checkpoint/mnist_cnn3.ckpt"

    model = creat_model()

    # Uncomment to (re)train; left disabled so the script only evaluates
    # previously saved weights.  Evaluation below requires the checkpoint
    # to already exist at check_save_path.
    # model_fit(model, check_save_path)

    model_valtest(model, check_save_path)