|
|
import keras
|
|
|
from keras.datasets import mnist
|
|
|
import numpy as np
|
|
|
from PIL import Image
|
|
|
import matplotlib.pyplot as plt
|
|
|
from keras.models import Sequential
|
|
|
from keras.layers import Dense,Conv2D,MaxPooling2D,Flatten
|
|
|
from keras.losses import categorical_crossentropy
|
|
|
from keras.utils import np_utils
|
|
|
import matplotlib
|
|
|
|
|
|
import os
|
|
|
# Allow a duplicate copy of the Intel OpenMP runtime to be loaded into the
# process; per the original author, the script errors out without this.
# NOTE(review): this variable is usually set *before* numpy/keras are
# imported; here it runs after those imports -- confirm it still takes effect.
os.environ.update(KMP_DUPLICATE_LIB_OK='True')

# Use the Tk GUI backend so the plt.show() calls open interactive windows.
matplotlib.use('TkAgg')

# Hyper-parameters: samples per training batch, and number of digit classes.
batch_size, num_classes = 32, 10

# Load the MNIST train/test splits.
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Print the shapes of each split: images first, then labels.
print(*(array.shape for array in (train_images, train_labels)))
print(*(array.shape for array in (test_images, test_labels)))
|
|
|
|
|
|
def show_mnist(train_image, train_labels, n=3, m=3):
    """Show the first ``n * m`` images in a grid, each titled with its label.

    Args:
        train_image: Images indexed along their first axis. An image with a
            trailing single-channel axis (such as the ``(28, 28, 1)`` arrays
            this script builds) is squeezed to 2-D before plotting.
        train_labels: Labels aligned with ``train_image``. A vector label
            (one-hot or class scores) is shown as its ``argmax``; a scalar
            class index is shown as-is.
        n: Number of grid rows (default 3).
        m: Number of grid columns (default 3).

    Only ``min(n * m, len(train_image), len(train_labels))`` cells are filled,
    so passing fewer samples than grid cells does not raise ``IndexError``.
    """
    count = min(n * m, len(train_image), len(train_labels))
    plt.figure()  # start a fresh figure for this grid
    for index in range(count):
        # Subplot positions are 1-based and row-major: cell (i, j) is
        # i * m + j + 1, which is simply index + 1.
        plt.subplot(n, m, index + 1)
        image = np.asarray(train_image[index])
        if image.ndim == 3 and image.shape[-1] == 1:
            image = image[..., 0]  # (H, W, 1) -> (H, W) for a colormapped plot
        label = np.asarray(train_labels[index])
        # np.argmax of a scalar is always 0, so only reduce vector labels.
        digit = int(np.argmax(label)) if label.ndim else int(label)
        plt.title(digit)
        plt.imshow(image, cmap='Greys')
        plt.axis('off')
    plt.show()
|
|
|
# ---------------------------------------------------------------------------
# Data preprocessing
# ---------------------------------------------------------------------------

# Image geometry: 28x28 greyscale digits, i.e. a single channel.
img_row, img_col, channel = 28, 28, 1
mnist_input_shape = (img_row, img_col, channel)

# Add an explicit channel axis: (n, 28, 28) -> (n, 28, 28, 1), where n is the
# number of samples (load_data returns (n, 28, 28) image arrays). The
# channel-last layout matches input_shape below -- this assumes Keras' default
# "channels_last" image data format.
train_images = train_images.reshape(train_images.shape[0], img_row, img_col, channel)
test_images = test_images.reshape(test_images.shape[0], img_row, img_col, channel)

# Convert pixels to float32 and scale the 0-255 range into [0, 1];
# normalized inputs help training converge faster.
train_images = train_images.astype("float32")
test_images = test_images.astype("float32")
train_images /= 255
test_images /= 255

# One-hot encode the integer class labels,
# e.g. 5 -> [0 0 0 0 0 1 0 0 0 0] (index 5 is set).
train_labels = np_utils.to_categorical(train_labels, num_classes)
test_labels = np_utils.to_categorical(test_labels, num_classes)

# ---------------------------------------------------------------------------
# Network architecture
# ---------------------------------------------------------------------------
# Conv2D defaults to 'valid' padding, so each 3x3 convolution trims one pixel
# from every border:
#   (28, 28, 1) -> Conv2D 32 -> (26, 26, 32) -> Conv2D 16 -> (24, 24, 16)
#   -> MaxPooling 2x2 -> (12, 12, 16) -> Flatten -> (2304,)
#   -> Dense 32 -> Dense num_classes (softmax)
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),
                 activation="relu",
                 input_shape=mnist_input_shape))
model.add(Conv2D(16, kernel_size=(3, 3),
                 activation="relu"))
model.add(MaxPooling2D(pool_size=(2, 2)))  # halves the spatial resolution
# Flatten collapses each sample's feature maps into a 1-D vector so they can
# feed the fully connected layers; the batch dimension is unaffected.
model.add(Flatten())
model.add(Dense(32, activation="relu"))  # fully connected hidden layer
model.add(Dense(num_classes, activation='softmax'))  # per-class probabilities

# ---------------------------------------------------------------------------
# Compile, train and evaluate
# ---------------------------------------------------------------------------
epochs = 500  # number of full passes over the training set

# NOTE(review): the default learning rate behind the "adadelta" string differs
# between Keras releases (1.0 in standalone Keras, 0.001 in tf.keras); with
# the smaller value training is very slow -- confirm which is intended.
model.compile(loss=categorical_crossentropy,  # matches the one-hot labels
              optimizer="adadelta",
              metrics=['accuracy'])

# The data are in-memory NumPy arrays, so generator/Sequence-only options such
# as `workers` / `use_multiprocessing` would have no effect and are not passed.
# NOTE(review): the test set doubles as validation data, so the per-epoch
# val_* metrics and the final evaluation below measure the same samples.
model.fit(train_images,
          train_labels,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,  # progress-bar logging
          validation_data=(test_images, test_labels),
          shuffle=True)  # reshuffle the training set at the start of each epoch

# With a single 'accuracy' metric, evaluate() yields [loss, accuracy].
score = model.evaluate(test_images, test_labels, verbose=1)
print('test loss:', score[0])
print('test accuracy:', score[1])

# ---------------------------------------------------------------------------
# Save, reload and predict
# ---------------------------------------------------------------------------
model.save("mnist.h5")

# Reload the model from the file just written (presumably to demonstrate the
# save/load round trip); the predictions below come from this reloaded copy.
from keras.models import load_model

model = load_model("mnist.h5")

# One row of class probabilities per test image; np.argmax(predict[i]) is the
# predicted digit for test_images[i].
predict = model.predict(test_images)
|
|
|
|
|
|
def show_predict(test_images, test_labels, predict, n=3, m=3):
    """Show the first ``n * m`` test images titled with predicted and true digit.

    Args:
        test_images: Images indexed along their first axis. An image with a
            trailing single-channel axis is squeezed to 2-D before plotting.
        test_labels: True labels aligned with ``test_images``. Vector labels
            (one-hot) are reduced with ``argmax``; scalar labels are used as-is.
        predict: Per-image class scores, e.g. the output of ``model.predict``;
            the displayed prediction is each row's ``argmax``.
        n: Number of grid rows (default 3).
        m: Number of grid columns (default 3).

    Titles of misclassified images are drawn in red. Only as many cells as
    there are images, labels and predictions are filled.
    """
    count = min(n * m, len(test_images), len(test_labels), len(predict))
    plt.figure(figsize=(10, 10))
    for index in range(count):
        # 1-based row-major subplot index: cell (i, j) is i * m + j + 1.
        plt.subplot(n, m, index + 1)
        image = np.asarray(test_images[index])
        if image.ndim == 3 and image.shape[-1] == 1:
            image = image[..., 0]  # (H, W, 1) -> (H, W) for a colormapped plot
        predicted = int(np.argmax(predict[index]))
        truth = np.asarray(test_labels[index])
        # np.argmax of a scalar is always 0, so only reduce vector labels.
        actual = int(np.argmax(truth)) if truth.ndim else int(truth)
        # Highlight mistakes; correct predictions keep the default title colour.
        style = {} if predicted == actual else {"color": "red"}
        plt.title("predict:{} label:{}".format(predicted, actual),
                  fontsize=16, **style)
        plt.imshow(image, cmap='Greys')
        plt.axis('off')
    plt.show()
|
|
|
# Preview a grid of test digits with their labels, then the same digits
# alongside the reloaded model's predictions.
show_mnist(train_image=test_images, train_labels=test_labels)
show_predict(test_images=test_images, test_labels=test_labels, predict=predict)