You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

120 lines
4.8 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import tensorflow.keras as keras
import numpy as np
import os
import tensorflow as tf
import matplotlib.pyplot as plt
# Load the MNIST splits and divide the pixel values by 255.0
# (presumably mapping 8-bit intensities into [0, 1] -- confirm against the dataset).
mnist = keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train / 255.0
x_test = x_test / 255.0

# Append a trailing axis to both image sets (the model below reshapes to 28x28x1).
x_train = tf.expand_dims(x_train, axis=-1)
x_test = tf.expand_dims(x_test, axis=-1)
print("train shape:", x_train.shape)
print("test shape:", x_test.shape)

# Use this class for image augmentation.
datagen = keras.preprocessing.image.ImageDataGenerator(
    # Integer: degree range for random rotations.
    rotation_range=20,
    # Float: fraction of the image width used as the magnitude of random
    # horizontal shifts during augmentation.
    width_shift_range=0.20,
    # Float: shear intensity (shear angle in the counter-clockwise direction).
    shear_range=15,
    # Float or [lower, upper] list: range of the random zoom. A float means
    # [lower, upper] = [1 - zoom_range, 1 + zoom_range].
    zoom_range=0.10,
    # Boolean: randomly flip images horizontally (disabled here).
    horizontal_flip=False,
    # Float: fraction of images reserved for validation, strictly between 0 and 1.
    validation_split=0.15,
)

# Both iterators draw from the training arrays; `subset` selects which side
# of the validation split each one yields.
train_datagen = datagen.flow(
    x=x_train,
    y=y_train,
    batch_size=256,
    subset="training",
)
validation_genetor = datagen.flow(
    x=x_train,
    y=y_train,
    batch_size=64,
    subset="validation",
)
def creat_model():
    """Build and compile the CNN digit classifier.

    The network reshapes its input to 28x28x1, applies three
    convolution/max-pooling stages (32, 64+64 and 128+128 filters), flattens,
    and then passes through three sigmoid dense layers with dropout before a
    10-way softmax output.

    Returns:
        A ``keras.Sequential`` model compiled with the Adam optimizer,
        sparse categorical cross-entropy loss (``from_logits=False``) and the
        ``sparse_categorical_accuracy`` metric.
    """
    model = keras.Sequential([
        keras.layers.Reshape((28, 28, 1)),
        # NOTE(review): `input_shape` is given on the second layer, exactly as
        # in the original; it is presumably only honoured on the first layer
        # -- confirm before relying on it.
        keras.layers.Conv2D(filters=32, kernel_size=(5, 5), activation="relu", padding="same",
                            input_shape=(28, 28, 1)),
        keras.layers.MaxPool2D((2, 2)),
        keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation="relu", padding="same"),
        keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation="relu", padding="same"),
        keras.layers.MaxPool2D((2, 2)),
        keras.layers.Conv2D(filters=128, kernel_size=(3, 3), activation="relu", padding="same"),
        keras.layers.Conv2D(filters=128, kernel_size=(3, 3), activation="relu", padding="same"),
        keras.layers.MaxPool2D((2, 2)),
        keras.layers.Flatten(),
        keras.layers.Dense(512, activation="sigmoid"),
        keras.layers.Dropout(0.25),
        keras.layers.Dense(512, activation="sigmoid"),
        keras.layers.Dropout(0.25),
        keras.layers.Dense(256, activation="sigmoid"),
        keras.layers.Dropout(0.1),
        # Softmax (rather than sigmoid) so the 10 outputs form a probability
        # distribution, which is what the loss below expects when
        # from_logits=False. The layer has the same weights, so existing
        # checkpoints remain loadable.
        keras.layers.Dense(10, activation="softmax"),
    ])
    model.compile(optimizer='adam',
                  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=["sparse_categorical_accuracy"])
    return model
def model_fit(model, check_save_path, epochs=1):
    """Train *model* on the augmented generators, checkpointing and plotting.

    If a checkpoint already exists at *check_save_path*, its weights are
    loaded first so training resumes from them. After training, the model
    summary is printed and accuracy/loss curves are shown.

    Args:
        model: Compiled Keras model to train.
        check_save_path: Checkpoint path/prefix used both for resuming and
            for saving weights during training.
        epochs: Number of training epochs. Defaults to 1 (the original
            hard-coded value). With a single epoch the learning-rate
            scheduler's patience of 5 can never elapse, so pass a larger
            value to make it effective.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    # Existence is detected through the "<prefix>.index" file
    # (presumably written by TensorFlow-format checkpoints -- confirm).
    if os.path.exists(check_save_path + '.index'):
        print("load modals...")
        model.load_weights(check_save_path)

    # Save weights only, and only when the monitored quantity improves.
    cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=check_save_path,
                                                     save_weights_only=True,
                                                     save_best_only=True)
    # Multiply the learning rate by 0.1 after 5 epochs without val_loss
    # improvement, never going below 1e-6.
    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
                                                     factor=0.1,
                                                     patience=5,
                                                     min_lr=0.000001,
                                                     verbose=1)
    history = model.fit(train_datagen,
                        epochs=epochs,
                        validation_data=validation_genetor,
                        callbacks=[reduce_lr, cp_callback],
                        verbose=1)
    model.summary()
    _plot_history(history)
    return history


def _plot_history(history):
    """Show training/validation accuracy and loss curves side by side."""
    acc = history.history['sparse_categorical_accuracy']
    val_acc = history.history['val_sparse_categorical_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    plt.subplot(1, 2, 1)
    plt.plot(acc, label='Training Accuracy')
    plt.plot(val_acc, label='Validation Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.show()
def model_valtest(model, check_save_path):
    """Restore weights from *check_save_path* and print test-set metrics.

    Args:
        model: Compiled Keras model; its ``evaluate`` result is unpacked as
            a (loss, accuracy) pair.
        check_save_path: Checkpoint path/prefix handed to
            ``model.load_weights``.
    """
    model.load_weights(check_save_path)
    # Evaluate on the module-level held-out test arrays.
    metrics = model.evaluate(x=x_test, y=y_test, verbose=2)
    test_loss, test_accuracy = metrics
    print("Model accuracy: ", test_accuracy, ", model loss: ", test_loss)
if __name__ == "__main__":
    # Checkpoint path/prefix shared by the training and evaluation helpers.
    check_save_path = "./checkpoint/mnist_cnn3.ckpt"
    model = creat_model()
    # Training is currently disabled; uncomment the next line to (re)train
    # before evaluating.
    # model_fit(model, check_save_path)
    model_valtest(model=model, check_save_path=check_save_path)