You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
41 lines
1.1 KiB
41 lines
1.1 KiB
from PIL import Image
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
|
|
# Interactive MNIST digit recognizer: rebuilds the dense classifier, restores
# its weights from a checkpoint, then classifies pictures whose paths are typed
# in one at a time.

# Checkpoint the weights are restored from.
model_save_path = './checkpoint/mnist.ckpt'

# A prediction whose top softmax probability is at or below this value is
# reported as "cannot determine" instead of printing a digit.
CONFIDENCE_THRESHOLD = 0.3

# Grayscale level (0-255) below which a pixel counts as "ink" when binarizing.
BINARIZE_THRESHOLD = 200

# When True, pictures are inverted and thresholded before prediction (dark
# pixels -> 255, light pixels -> 0). This replaces the previously commented-out
# nested loop; it is off by default so behavior matches the original script.
BINARIZE_INPUT = False


def _load_image(image_path):
    """Return the picture at *image_path* as a 28x28 8-bit grayscale array.

    The file is opened in a ``with`` block so its handle is released as soon
    as the pixels have been read.

    Raises:
        OSError: if the file is missing or cannot be decoded by Pillow
            (``FileNotFoundError`` and ``PIL.UnidentifiedImageError`` are
            both ``OSError`` subclasses).
    """
    with Image.open(image_path) as img:
        resized = img.resize((28, 28), Image.Resampling.LANCZOS)
        return np.array(resized.convert('L'))


def _to_model_input(img_arr, binarize=False):
    """Scale grayscale pixels to [0, 1] and prepend a batch axis of size 1.

    Args:
        img_arr: 2-D array of grayscale pixel values in 0-255.
        binarize: if true, pixels below BINARIZE_THRESHOLD become 255 and all
            others become 0 before scaling.

    Returns:
        A float array of shape ``(1,) + img_arr.shape``.
    """
    if binarize:
        img_arr = np.where(img_arr < BINARIZE_THRESHOLD, 255, 0)
    return (img_arr / 255.0)[np.newaxis, ...]


model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')])

model.load_weights(model_save_path)

preNum = int(input("input the number of test pictures:"))

for i in range(preNum):
    image_path = input("the path of test picture:")
    try:
        img_arr = _load_image(image_path)
    except OSError as err:
        # Report the bad path and move on instead of aborting the session.
        print(f"cannot read {image_path!r}: {err}")
        continue

    x_predict = _to_model_input(img_arr, binarize=BINARIZE_INPUT)
    result = model.predict(x_predict)
    print(result)

    # One picture per batch, so row 0 holds the 10 class probabilities.
    # A plain int index avoids indexing a NumPy array with a tf.Tensor.
    pred = int(np.argmax(result[0]))
    if result[0][pred] <= CONFIDENCE_THRESHOLD:
        print("无法判断,请重新输入!")
    else:
        # Plain print keeps the answer on stdout with the probabilities;
        # tf.print would have sent it to stderr.
        print(pred)
|
|
|
|
|