# You can not select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
#
# 33 lines
# 1.4 KiB

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import models
import matplotlib.pyplot as plt
# Number of random test images shown in the prediction grid at the end of the script.
num_samples_to_visualize = 10

# Load MNIST: per the Keras docs, uint8 grayscale images of shape (N, 28, 28)
# with integer digit labels.
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()

# Add a trailing channel axis for Conv2D and scale pixels to [0, 1].
# Cast to float32 *before* dividing: `uint8 / 255.0` would produce float64
# arrays, doubling memory for no benefit since Keras computes in float32
# under its default dtype policy.
X_train = X_train.reshape(-1, 28, 28, 1).astype('float32') / 255.0
X_test = X_test.reshape(-1, 28, 28, 1).astype('float32') / 255.0
# LeNet-5-style CNN: two convolution + max-pooling stages, then three dense layers.
model = models.Sequential([
    # Feature extractor: 6 and then 16 filters of size 5x5, each stage
    # downsampled by 2x2 max pooling. Input is a single-channel 28x28 image.
    layers.Conv2D(6, (5, 5), activation='relu', input_shape=(28, 28, 1)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(16, (5, 5), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    # Classifier head: flatten the feature maps, then 120 -> 84 -> 10-way softmax.
    layers.Flatten(),
    layers.Dense(120, activation='relu'),
    layers.Dense(84, activation='relu'),
    layers.Dense(10, activation='softmax'),
])

# Labels are integer class ids, hence the *sparse* categorical cross-entropy;
# 'accuracy' is tracked so evaluate() reports it alongside the loss.
model.compile(optimizer='SGD', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# Train for 5 epochs with mini-batches of 64 images.
model.fit(X_train, y_train, epochs=5, batch_size=64)

# Measure performance on the held-out test set. With metrics=['accuracy'] in
# compile(), evaluate() returns [loss, accuracy].
test_loss, test_acc = model.evaluate(X_test, y_test)
print(f'Test accuracy: {test_acc}')
# Per-class softmax scores for every test image; model.predict returns a NumPy
# array, so argmax over the class axis gives the predicted digit as plain ints
# (indexing a tf.Tensor here would put a Tensor repr into the plot titles).
predictions = model.predict(X_test)
predicted_labels = predictions.argmax(axis=1)

# Pick distinct test images: shuffle all indices and keep a prefix. Sampling
# with tf.random.uniform could select the same image more than once.
random_indices = tf.random.shuffle(tf.range(len(X_test)))[:num_samples_to_visualize].numpy()

# Grid of at most 5 columns, with as many rows as needed, so changing
# num_samples_to_visualize never overflows a fixed-size subplot layout.
n_samples = len(random_indices)
n_cols = max(1, min(5, n_samples))
n_rows = max(1, -(-n_samples // n_cols))  # ceiling division
plt.figure(figsize=(2 * n_cols, 2.5 * n_rows))

for i, idx in enumerate(random_indices):
    image = X_test[idx].squeeze()  # drop the channel axis for imshow
    true_label = int(y_test[idx])
    predicted_label = int(predicted_labels[idx])
    plt.subplot(n_rows, n_cols, i + 1)
    plt.imshow(image, cmap='gray')
    # Show both labels; green title for a correct prediction, red for a mistake.
    plt.title(f'Pred: {predicted_label}\nTrue: {true_label}',
              color='green' if predicted_label == true_label else 'red')
    plt.axis('off')

plt.tight_layout()
plt.show()