########MNIST handwritten digit recognition########

########Neural network with backpropagation########

# A neural network with 784 inputs, one 30-neuron hidden layer, and 10 outputs

# NOTE: this script targets the TensorFlow 1.x graph API
# (tf.placeholder, tf.Session, tf.assign, ...)
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np


def sigmaprime(x):
    """
    Derivative of the sigmoid function: sigma'(x) = sigma(x) * (1 - sigma(x)).
    The manual backpropagation below uses it to compute the per-layer error signals.
    :param x: pre-activation tensor
    :return: element-wise sigmoid derivative of x
    """
    return tf.multiply(tf.sigmoid(x), tf.subtract(tf.constant(1.0), tf.sigmoid(x)))
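# Quick illustrative check (assumes a TF 1.x session; safe to delete):
# sigmoid(0) = 0.5, so sigmaprime(0) = 0.5 * (1 - 0.5) = 0.25, e.g.
#   with tf.Session() as s:
#       print(s.run(sigmaprime(tf.constant(0.0))))  # 0.25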

# Load the MNIST dataset.
# MNIST contains 70000 samples (60000 training + 10000 test);
# each sample is a 28*28 grayscale image.
mnist = tf.keras.datasets.mnist
(train_x, train_y), (test_x, test_y) = mnist.load_data()

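# The tf.keras loader returns raw uint8 arrays and, unlike the old TF1
# tutorial reader, has no next_batch helper. The helper arrays below
# (train_images, train_labels, test_images, test_labels) are a minimal
# preparation step assumed by the rest of the script: flatten each 28*28
# image to a 784-vector, scale pixels to [0, 1], and one-hot encode the
# labels to match the 10-way output layer.
train_images = train_x.reshape(-1, 784).astype(np.float32) / 255.0
test_images = test_x.reshape(-1, 784).astype(np.float32) / 255.0
train_labels = np.eye(10, dtype=np.float32)[train_y]
test_labels = np.eye(10, dtype=np.float32)[test_y]
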
# Define the model

# Constants
n_input = 784     # MNIST input size (28*28)
n_classes = 10    # MNIST classes (digits 0-9)

# Hyperparameters
max_epochs = 10000     # number of training steps (one mini-batch per step)
learning_rate = 0.5    # learning rate
batch_size = 10        # mini-batch size
seed = 0               # random seed
n_hidden = 30          # number of neurons in the hidden layer

# Placeholders for a mini-batch of images and its one-hot labels
x_in = tf.placeholder(tf.float32, [None, n_input])
y = tf.placeholder(tf.float32, [None, n_classes])


# Build the model
def multilayer_perceptron(x, weights, biases):
    # Hidden layer with a sigmoid activation
    h_layer_1 = tf.add(tf.matmul(x, weights['h1']), biases['h1'])
    out_layer_1 = tf.sigmoid(h_layer_1)

    # Output layer, also squashed through a sigmoid for the final prediction
    h_out = tf.matmul(out_layer_1, weights['out']) + biases['out']
    # Return the activations and pre-activations of both layers;
    # the manual backpropagation below needs all four tensors
    return tf.sigmoid(h_out), h_out, out_layer_1, h_layer_1
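
# Shape walk-through for a batch of B images (illustrative, B = batch_size):
#   x:            (B, 784)
#   h_layer_1:    (B, 784) @ (784, 30) + (1, 30) broadcast -> (B, 30)
#   out_layer_1:  (B, 30)
#   h_out, y_hat: (B, 30) @ (30, 10) + (1, 10) broadcast -> (B, 10)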

# Weights
weights = {
    'h1': tf.Variable(tf.random_normal([n_input, n_hidden], seed=seed)),
    'out': tf.Variable(tf.random_normal([n_hidden, n_classes], seed=seed))}

# Biases
biases = {
    'h1': tf.Variable(tf.random_normal([1, n_hidden], seed=seed)),
    'out': tf.Variable(tf.random_normal([1, n_classes], seed=seed))}

# Forward pass
y_hat, h_2, o_1, h_1 = multilayer_perceptron(x_in, weights, biases)

# Loss: mean squared error between the one-hot labels and the predictions
err = y - y_hat
loss = tf.reduce_mean(tf.square(err), name='loss')

# Backpropagation, implemented manually with graph ops instead of an optimizer
delta_2 = tf.multiply(err, sigmaprime(h_2))
delta_w_2 = tf.matmul(tf.transpose(o_1), delta_2)

wtd_error = tf.matmul(delta_2, tf.transpose(weights['out']))
delta_1 = tf.multiply(wtd_error, sigmaprime(h_1))
delta_w_1 = tf.matmul(tf.transpose(x_in), delta_1)
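
# In equation form, what the ops above compute for the squared-error loss
# (sigma' = sigmaprime, @ = matrix product, ^T = transpose):
#   delta_2 = (y - y_hat) * sigma'(h_2)           output-layer error signal
#   dW_out  = o_1^T @ delta_2                     step for weights['out']
#   delta_1 = (delta_2 @ W_out^T) * sigma'(h_1)   error pushed back one layer
#   dW_h1   = x_in^T @ delta_1                    step for weights['h1']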

eta = tf.constant(learning_rate)

# Update the weights and biases. Because err = y - y_hat already carries the
# minus sign of the MSE gradient, the descent step appears as an addition.
train = [
    tf.assign(weights['h1'], tf.add(weights['h1'], tf.multiply(eta, delta_w_1))),
    tf.assign(biases['h1'], tf.add(biases['h1'], tf.multiply(eta, tf.reduce_mean(delta_1, axis=[0])))),
    tf.assign(weights['out'], tf.add(weights['out'], tf.multiply(eta, delta_w_2))),
    tf.assign(biases['out'], tf.add(biases['out'], tf.multiply(eta, tf.reduce_mean(delta_2, axis=[0]))))
]

# Accuracy: the number of correct predictions (a raw count, not a ratio)
acct_mat = tf.equal(tf.argmax(y_hat, 1), tf.argmax(y, 1))
accuracy = tf.reduce_sum(tf.cast(acct_mat, tf.float32))

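# Worked example: for a prediction row y_hat = [0.1, 0.7, 0.2, ...] and
# one-hot label [0, 1, 0, ...], both argmax values are 1, so the row
# counts as correct.
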
# Training
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(max_epochs):
        # Draw a random mini-batch (replaces the removed mnist.train.next_batch helper)
        idx = np.random.randint(0, train_images.shape[0], batch_size)
        batch_xs, batch_ys = train_images[idx], train_labels[idx]
        _, loss1 = sess.run([train, loss], feed_dict={x_in: batch_xs, y: batch_ys})
        if epoch % 1000 == 0:
            print('Epoch: {0} loss: {1}'.format(epoch, loss1))

    acc_test = sess.run(accuracy, feed_dict={x_in: test_images, y: test_labels})
    acc_train = sess.run(accuracy, feed_dict={x_in: train_images, y: train_labels})
    # Evaluation: accuracy is a raw count, so divide by dataset_size / 100
    # (600 for the 60000 train images, 100 for the 10000 test images)
    # to report percentages
    print('Accuracy Train%: {0} Accuracy Test%: {1}'
          .format(acc_train / 600, acc_test / 100))

    # Cache the network's outputs on the test set; the visualization below
    # runs after the session has closed
    test_scores = sess.run(y_hat, feed_dict={x_in: test_images})

print('--------------')


# Visualize the results: show 10 random test digits with their true labels
# and the network's predictions
plt.figure(figsize=(10, 3))
for i in range(10):
    num = np.random.randint(0, 10000)  # random index into the test set

    plt.subplot(2, 5, i + 1)   # 2-row, 5-column grid
    plt.axis('off')            # hide the axes
    plt.imshow(test_x[num], cmap='gray')
    y_pred = np.argmax(test_scores[num])  # network's predicted digit
    plt.title('Label: ' + str(test_y[num]) + '\nPrediction: ' + str(y_pred))

plt.show()