"""Load the trained MNIST CNN and classify a single digit image.

The model is built by ``mnist_cnn3.creat_model()`` and its weights are loaded
once, at import time, from ``model_save_path``.
"""
from PIL import Image
import numpy as np
import tensorflow as tf
import mnist_cnn3

# Checkpoint location (Windows-specific absolute path).
model_save_path = 'E:\\Python_touge\\MNIST\\venv\\checkpoint\\mnist_cnn3.ckpt'
model = mnist_cnn3.creat_model()
model.load_weights(model_save_path)

# preNum = int(input("input the number of test pictures:"))


def prediction(str, threshold=0.8):
    """Classify the image stored at path *str* and return the predicted class.

    Args:
        str: Path of the image file to classify. (The name shadows the builtin
            ``str``; it is kept so existing keyword callers keep working.)
        threshold: The top-class score must be strictly greater than this for
            a class to be returned. Defaults to 0.8, the previously hard-coded
            value. NOTE(review): assumes the model outputs probabilities
            (e.g. softmax) — confirm against mnist_cnn3.

    Returns:
        The index of the highest-scoring class (a NumPy integer) when its score
        exceeds *threshold*. Otherwise, after printing the raw scores, the
        message string "无法判断,请重新输入!" ("cannot decide, please try again").
    """
    image_path = str
    # The context manager guarantees the file handle is closed. The resize and
    # convert happen inside it because PIL loads pixel data lazily.
    with Image.open(image_path) as img:
        # 28x28 presumably matches the network's input size.
        resized = img.resize((28, 28), Image.Resampling.LANCZOS)
        # 'L' gives a single 8-bit grayscale channel -> 2-D array.
        img_arr = np.array(resized.convert('L'))

    # NOTE: an earlier, disabled experiment binarised and inverted the image
    # here (pixels < 200 -> 255, otherwise 0).

    img_arr = img_arr / 255.0                 # scale 0..255 -> 0.0..1.0
    x_predict = img_arr[np.newaxis, ...]      # add a batch axis: (1, 28, 28)
    result = model.predict(x_predict)

    # Plain scalar indexing. The original indexed the NumPy result with a
    # TF tensor, which yielded a 1-element array rather than a scalar.
    pred = np.argmax(result, axis=1)[0]
    confidence = result[0, pred]
    if confidence <= threshold:
        print(result)
        return "无法判断,请重新输入!"
    return pred