You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

94 lines
4.2 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import tensorflow as tf
import os
from matplotlib import pyplot as plt
from PIL import Image
import numpy as np
# 数据准备
mnist = tf.keras.datasets.mnist# 导入MNIST数据集
# 指定要投入网络的训练集和测试集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train/255.0, x_test/255.0
def creat_model():
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(), # 将输入层的数据压成一维数据
# 全连接层
# tf.keras.layers.Densen(神经元个数,activation='激活函数')
tf.keras.layers.Dense(128, activation='relu'), # 128个神经元激活函数为'relu'
tf.keras.layers.Dense(10, activation='softmax') # 10个神经元激活函数为'softmax'
])
# 使用compile()配置神经网络训练方法
# model.compile(optimizer=优化器,
# loss=损失函数,
# metrics=["评测指标"]
# )
model.compile(optimizer='adam', # 优化器为'adam'
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
# 损失函数为'sparse_categorical_crossentropy'
metrics=["sparse_categorical_accuracy"]) # y_以数值形式给出y以独热码形式给出
return model
def model_fit(model,check_save_path):
if os.path.exists(check_save_path+'.index'):
print("load modals...")
model.load_weights(check_save_path)
# 使用callbacks.ModelCheckpoint()保存模型
# tf.keras.callbacks.ModelCheckpoint(filepath=保存模型的路径,
# save_weights_only=若设置为True则仅保存模型权重,设置为False则保存整个模型
# save_best_only=若设置为True将只保存在验证集上性能最好的模型
# )
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=check_save_path, # 保存路径
save_weights_only=True, # 只保存模型权重
save_best_only=True) # 只保存在验证集上性能最好的模型
# 模型训练
# 在fit()中执行训练
# model.fit(训练集的输入特征,训练集的标签,
# batch_size=,epochs=,
# validation_split=从训练集中划分多少比例给测试集,
# validation_freq=多少次epoch测试一次
# )
history = model.fit(x_train, y_train,
batch_size=32, epochs=5, # 每一批batch大小为32,迭代次数epochs为5
validation_split=0.15, # 从训练集中划分15%给验证集
validation_freq=1, # 测试的间隔次数为1
callbacks=[cp_callback]) # 回调参数
model.summary() # 打印网络结构和参数统计
final_loss, final_acc = model.evaluate(x_test, y_test, verbose=2)
print("Model accuracy: ", final_acc, ", model loss: ", final_loss)
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
def eva_acc(str, model):
model.load_weights(str)
# 准确率测试
final_loss, final_acc = model.evaluate(x_test, y_test, verbose=2) # 输入数据和标签,输出损失和精确度,verbose = 2输出一行记录
print("Model accuracy: ", final_acc, ", model loss: ", final_loss)
if __name__=="__main__":
check_save_path = "./checkpoint/mnist.ckpt"
model = creat_model()
model_fit(model, check_save_path)
eva_acc(check_save_path, model)