You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
36 lines
1.1 KiB
36 lines
1.1 KiB
from PIL import Image
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
import mnist_cnn3
|
|
|
|
|
|
import os

# Path to the trained weights for the mnist_cnn3 network. The original
# absolute Windows path stays the default; set the MNIST_CNN3_CKPT
# environment variable to point at the checkpoint on another machine.
model_save_path = os.environ.get(
    'MNIST_CNN3_CKPT',
    'E:\\Python_touge\\MNIST\\venv\\checkpoint\\mnist_cnn3.ckpt',
)

# Rebuild the network architecture, then restore its trained weights.
# This runs once at import time, so `prediction` can reuse `model` directly.
model = mnist_cnn3.creat_model()
model.load_weights(model_save_path)
|
|
|
|
# preNum = int(input("input the number of test pictures:"))
|
|
def prediction(str, *, threshold=0.8, binarize=False):
    """Classify the digit in an image file using the module-level ``model``.

    Preprocessing steps:

    1. Resize the image to 28x28 with Lanczos resampling.
    2. Convert it to 8-bit grayscale (``'L'``).
    3. Optionally binarise it.
    4. Scale it to [0, 1].
    5. Add a leading batch axis before calling ``model.predict``.

    Args:
        str: Path to the image file. The name shadows the builtin ``str``.
            It is kept so that existing callers passing it by keyword keep
            working; the builtin is not needed inside this function.
        threshold: The top class's score must be strictly greater than
            this value for its index to be returned. Defaults to 0.8, the
            previously hard-coded cut-off. NOTE(review): assumes the model
            ends in a softmax, so scores behave like probabilities; confirm
            against mnist_cnn3.
        binarize: If True, pixels darker than 200 become 255 and all other
            pixels become 0. This turns a dark-on-light drawing into
            light-on-dark, presumably to match the MNIST training images.
            Defaults to False, which leaves the pixels unchanged.

    Returns:
        The index of the highest-scoring class (a NumPy integer from
        ``np.argmax``) when its score exceeds ``threshold``. Otherwise the
        message "无法判断,请重新输入!" ("cannot tell, please try again"),
        after printing the raw scores.
    """
    image_path = str

    # The context manager releases the file handle once the pixels are read.
    with Image.open(image_path) as src:
        gray = src.resize((28, 28), Image.Resampling.LANCZOS).convert('L')
        img_arr = np.array(gray)

    if binarize:
        # Vectorised equivalent of the old per-pixel loop: dark -> 255, light -> 0.
        img_arr = np.where(img_arr < 200, 255, 0)

    img_arr = img_arr / 255.0
    x_predict = img_arr[np.newaxis, ...]  # add the batch dimension

    result = model.predict(x_predict)
    # Index the top class with a plain scalar so the confidence check
    # compares a single number, not a one-element array.
    pred = np.argmax(result, axis=1)[0]

    if result[0, pred] <= threshold:
        # Low confidence: show the raw scores to help with debugging.
        print(result)
        return "无法判断,请重新输入!"
    return pred
|