You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

36 lines
1.1 KiB

import os

import numpy as np
import tensorflow as tf
from PIL import Image

import mnist_cnn3
# Checkpoint holding the trained weights for the mnist_cnn3 network. The
# default is the original author's machine-specific Windows path; set the
# MNIST_CNN3_CKPT environment variable to load a checkpoint from elsewhere
# without editing this file.
model_save_path = os.environ.get(
    'MNIST_CNN3_CKPT',
    'E:\\Python_touge\\MNIST\\venv\\checkpoint\\mnist_cnn3.ckpt',
)
# Build the architecture defined in mnist_cnn3 (presumably a Keras model —
# confirm in that module) and restore the trained weights once at import time,
# so every call to prediction() reuses the same loaded model.
model = mnist_cnn3.creat_model()
model.load_weights(model_save_path)
def _preprocess(image_path):
    """Load an image file and turn it into a batch of one model input.

    The image is resized to 28x28 with Lanczos resampling, converted to
    8-bit grayscale ('L' mode), scaled from [0, 255] to [0.0, 1.0], and given
    a leading batch axis, producing a float array of shape (1, 28, 28).
    """
    # The context manager closes the underlying file handle; previously it
    # stayed open until the Image object was garbage-collected.
    with Image.open(image_path) as img:
        gray = img.resize((28, 28), Image.Resampling.LANCZOS).convert('L')
        img_arr = np.array(gray)
    # NOTE(review): a disabled step here used to binarize and invert the image
    # (pixels < 200 -> 255, otherwise 0). It presumably converted dark-on-light
    # drawings into MNIST-style light-on-dark digits — confirm whether inputs
    # need it before re-enabling.
    img_arr = img_arr / 255.0
    return img_arr[np.newaxis, ...]


def prediction(str, threshold=0.8):
    """Classify the handwritten digit in the image file at *str*.

    Args:
        str: Path to the image file. The name shadows the built-in ``str``
            but is kept so existing keyword callers keep working.
        threshold: Minimum top-class score required to accept the prediction.
            Defaults to 0.8, the value that was previously hard-coded. The
            score is presumably a softmax probability — confirm in mnist_cnn3.

    Returns:
        The predicted class index (presumably the digit 0-9) as an ``int``
        when the model's top score is greater than *threshold*. Otherwise the
        message string "无法判断,请重新输入!" ("Cannot determine, please
        re-enter!"), after printing the raw model output.
    """
    image_path = str  # descriptive alias for the builtin-shadowing parameter
    x_predict = _preprocess(image_path)
    result = model.predict(x_predict)
    # Take the score vector of the single sample in the batch and read both
    # the winning class and its score as plain scalars. The old code
    # fancy-indexed with a TF tensor and got a 1-element array back.
    scores = np.asarray(result)[0]
    pred = int(np.argmax(scores))
    # Low-confidence predictions are rejected rather than returned as a guess.
    if scores[pred] <= threshold:
        print(result)  # dump the full score vector for debugging
        return "无法判断,请重新输入!"
    return pred