diff --git a/main.py b/main.py index 8a133b1..68fa9d0 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,8 @@ +import os import tensorflow as tf +import matplotlib.pyplot as plt from tensorflow import keras +import numpy as np config = tf.compat.v1.ConfigProto(gpu_options=tf.compat.v1.GPUOptions(allow_growth=True)) sess = tf.compat.v1.Session(config=config) @@ -19,13 +22,39 @@ model.add(keras.layers.MaxPooling2D(2,2)) model.add(keras.layers.Flatten()) model.add(keras.layers.Dense(128, activation = tf.nn.relu)) -model.add(keras.layers.Dense(36, activation = tf.nn.softmax)) +model.add(keras.layers.Dense(10, activation = tf.nn.softmax)) + +checkpoint_path = "training_1/cp.ckpt" +checkpoint_dir = os.path.dirname(checkpoint_path) + +cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, + save_weights_only=True, + verbose=1) train_images_scaled = train_images/255 model.compile(optimizer = 'adam', loss = tf.losses.sparse_categorical_crossentropy, metrics = ['accuracy']) -history = model.fit(train_images_scaled.reshape(-1, 28, 28 ,1), train_labels, epochs = 8) +history = model.fit( + train_images_scaled.reshape(-1, 28, 28 ,1), + train_labels, + epochs = 8, + validation_data=(test_images.reshape(-1, 28, 28 ,1), test_labels), + callbacks=[cp_callback] + ) results = model.evaluate(test_images.reshape(-1, 28, 28 ,1), test_labels) +testShow = test_labels[:100] + +pred = model.predict(test_images.reshape(-1, 28, 28 ,1)) +predict = [] +for item in pred: + predict.append(np.argmax(item)) +plt.figure() +plt.title('Conv Predict') +plt.ylabel('number') +plt.plot( range( testShow.size ), predict[:100], label='predict') +plt.plot( range( testShow.size ), testShow, label='result') +plt.legend() +plt.show() diff --git a/training_1/cp.ckpt.data-00000-of-00001 b/training_1/cp.ckpt.data-00000-of-00001 index 77aa021..7dda239 100644 Binary files a/training_1/cp.ckpt.data-00000-of-00001 and b/training_1/cp.ckpt.data-00000-of-00001 differ diff --git a/training_1/cp.ckpt.index b/training_1/cp.ckpt.index index 53acb63..511c6a8 100644 Binary files a/training_1/cp.ckpt.index and b/training_1/cp.ckpt.index differ