You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
26 lines
932 B
26 lines
932 B
6 years ago
|
import tensorflow as tf
|
||
|
from tensorflow.examples.tutorials.mnist import input_data
|
||
|
|
||
|
# Single-layer softmax classifier for MNIST using the TensorFlow 1.x graph
# API: the graph is built once, then mini-batch gradient descent is run and
# the test-set accuracy is printed periodically.

# Training hyper-parameters.
LEARNING_RATE = 0.01
BATCH_SIZE = 100
NUM_STEPS = 10000
EVAL_EVERY = 50  # print test accuracy every this many training steps

# Load the MNIST data sets from MNIST_data/ with one-hot encoded labels.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Inputs: rows of 784 features and rows of 10 label entries; the leading
# (batch) dimension is left unspecified so any batch size can be fed.
xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])

# Model parameters, zero-initialised.
Weight = tf.Variable(tf.zeros([784, 10]))
biases = tf.Variable(tf.zeros([10]))

logits = tf.matmul(xs, Weight) + biases
y = tf.nn.softmax(logits)

# Cross-entropy summed over the batch. It is computed from the logits rather
# than as -sum(ys * log(y)): once a softmax output underflows to 0, log(y) is
# -inf and 0 * -inf turns the whole loss into NaN.
loss = tf.reduce_sum(
    tf.nn.softmax_cross_entropy_with_logits(labels=ys, logits=logits))
train = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(loss)

# Accuracy ops are built once, up front. Creating them inside the training
# loop would add a fresh copy of these nodes to the graph at every evaluation.
correct_prediction = tf.equal(tf.argmax(ys, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

init = tf.global_variables_initializer()

# The context manager closes the session on exit, including when training is
# interrupted by an exception.
with tf.Session() as sess:
    sess.run(init)

    for step in range(NUM_STEPS):
        # batch[0] holds the images, batch[1] the matching one-hot labels.
        batch = mnist.train.next_batch(BATCH_SIZE)
        sess.run(train, feed_dict={xs: batch[0], ys: batch[1]})

        if step % EVAL_EVERY == 0:
            # Evaluate on the full test split.
            print(sess.run(accuracy, feed_dict={xs: mnist.test.images,
                                                ys: mnist.test.labels}))
|