from PIL import Image
import numpy as np
import tensorflow as tf
import mnist_cnn3, mnist_cnn, mnist_dense


def predeiction(str, model):
    """Classify the digit image at path ``str`` with one of three models.

    Args:
        str: Path to the image file. (The name shadows the builtin ``str``;
            it is kept unchanged so keyword callers keep working.)
        model: Which predictor to use: 1 -> ``prediction_1`` (mnist_dense),
            2 -> ``prediction_2`` (mnist_cnn), 3 -> ``prediction_3``
            (mnist_cnn3).

    Returns:
        The first element of the argmax over axis 1 of the model output,
        i.e. the predicted class index.

    Raises:
        ValueError: If ``model`` is not 1, 2 or 3 (previously this fell
            through and failed with ``UnboundLocalError`` on ``res``).
    """
    predictors = {1: prediction_1, 2: prediction_2, 3: prediction_3}
    try:
        predictor = predictors[model]
    except KeyError:
        raise ValueError(f"model must be 1, 2 or 3, got {model!r}") from None
    return predictor(str)


def _load_image(image_path):
    """Load an image as a (1, 28, 28) grayscale float array scaled to [0, 1]."""
    # The context manager closes the underlying file, which the original
    # code left open.
    with Image.open(image_path) as img:
        resized = img.resize((28, 28), Image.Resampling.LANCZOS)
        img_arr = np.array(resized.convert('L'))
    # NOTE: the original kept a disabled thresholding/inversion step here
    # (pixels < 200 -> 255, otherwise 0); re-add it if inputs need it.
    img_arr = img_arr / 255.0
    # Prepend a batch axis: (28, 28) -> (1, 28, 28).
    return img_arr[tf.newaxis, ...]


def _top_class(result):
    """Return the argmax class index of the first row of a prediction batch."""
    pred = tf.argmax(result, axis=1)
    return pred.numpy()[0]


def prediction_1(str):
    """Predict the digit in image file ``str`` with the mnist_dense model."""
    model_save_path = "./checkpoint/mnist.ckpt"
    model = mnist_dense.creat_model()
    model.load_weights(model_save_path)
    x_predict = _load_image(str)
    result = model.predict(x_predict)
    return _top_class(result)


def prediction_2(str):
    """Predict the digit in image file ``str`` with the mnist_cnn model."""
    model_save_path = "./checkpoint/mnist_cnn1.ckpt"
    model = mnist_cnn.creat_model()
    model.load_weights(model_save_path)
    # Append a trailing channel axis: (1, 28, 28) -> (1, 28, 28, 1).
    x_predict = tf.expand_dims(_load_image(str), -1)
    result = model.predict(x_predict)
    # NOTE(review): debug print kept from the original; consider removing.
    print(result)
    return _top_class(result)


def prediction_3(str):
    """Predict the digit in image file ``str`` with the mnist_cnn3 model."""
    model_save_path = "./checkpoint/mnist_cnn3.ckpt"
    model = mnist_cnn3.creat_model()
    model.load_weights(model_save_path)
    # NOTE(review): unlike prediction_2, no channel axis is added here;
    # presumably mnist_cnn3's model reshapes its input itself - confirm.
    x_predict = _load_image(str)
    result = model.predict(x_predict)
    return _top_class(result)