You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

96 lines
2.8 KiB

from PIL import Image
import numpy as np
import tensorflow as tf
import mnist_cnn3,mnist_cnn,mnist_dense
def predeiction(str, model):
    """Dispatch a digit-image prediction to one of the three model wrappers.

    Args:
        str: Path to the image file, passed through unchanged to the
            selected ``prediction_N`` function. (The name shadows the
            ``str`` builtin; it is kept so keyword callers keep working.)
        model: Model selector: ``1`` -> ``prediction_1``,
            ``2`` -> ``prediction_2``, ``3`` -> ``prediction_3``.

    Returns:
        Whatever the selected ``prediction_N`` function returns.

    Raises:
        ValueError: If ``model`` is not one of 1, 2 or 3.
    """
    if model == 1:
        return prediction_1(str)
    if model == 2:
        return prediction_2(str)
    if model == 3:
        return prediction_3(str)
    # Previously an unknown selector left the result variable unbound and
    # surfaced as an UnboundLocalError; report the bad argument explicitly.
    raise ValueError(f"model must be 1, 2 or 3, got {model!r}")
# preNum = int(input("input the number of test pictures:"))
def prediction_1(str):
    """Classify one image with the model built by ``mnist_dense.creat_model()``.

    The model is rebuilt and its weights are reloaded from
    ``./checkpoint/mnist.ckpt`` on every call.

    Args:
        str: Path to the image file. (The name shadows the ``str`` builtin;
            it is kept for compatibility with keyword callers.)

    Returns:
        The index of the highest-scoring output (``argmax`` over axis 1 of
        the model's prediction) for the single input image.
    """
    model_save_path = "./checkpoint/mnist.ckpt"
    model = mnist_dense.creat_model()
    model.load_weights(model_save_path)

    # The context manager guarantees the file handle opened by Image.open is
    # closed, including for multi-frame formats or if decoding raises.
    with Image.open(str) as img:
        resized = img.resize((28, 28), Image.Resampling.LANCZOS)
        # Mode 'L' is single-channel grayscale, so this is a (28, 28) array.
        img_arr = np.array(resized.convert('L'))

    # Scale pixel values from [0, 255] to [0, 1].
    img_arr = img_arr / 255.0
    # Add a leading batch axis: (28, 28) -> (1, 28, 28).
    x_predict = img_arr[tf.newaxis, ...]
    result = model.predict(x_predict)
    pred = tf.argmax(result, axis=1)
    return pred.numpy()[0]
def prediction_2(str):
    """Classify one image with the model built by ``mnist_cnn.creat_model()``.

    The model is rebuilt and its weights are reloaded from
    ``./checkpoint/mnist_cnn1.ckpt`` on every call. The input gets an extra
    trailing channel axis, giving shape (1, 28, 28, 1).

    Args:
        str: Path to the image file. (The name shadows the ``str`` builtin;
            it is kept for compatibility with keyword callers.)

    Returns:
        The index of the highest-scoring output (``argmax`` over axis 1 of
        the model's prediction) for the single input image.
    """
    model_save_path = "./checkpoint/mnist_cnn1.ckpt"
    model = mnist_cnn.creat_model()
    model.load_weights(model_save_path)

    # The context manager guarantees the file handle opened by Image.open is
    # closed, including for multi-frame formats or if decoding raises.
    with Image.open(str) as img:
        resized = img.resize((28, 28), Image.Resampling.LANCZOS)
        # Mode 'L' is single-channel grayscale, so this is a (28, 28) array.
        img_arr = np.array(resized.convert('L'))

    # Scale pixel values from [0, 255] to [0, 1].
    img_arr = img_arr / 255.0
    # Batch axis first, then a trailing channel axis: (1, 28, 28, 1).
    x_predict = img_arr[tf.newaxis, ...]
    x_predict = tf.expand_dims(x_predict, -1)
    # The former debug print(result) was dropped; the sibling wrappers don't print.
    result = model.predict(x_predict)
    pred = tf.argmax(result, axis=1)
    return pred.numpy()[0]
def prediction_3(str):
    """Classify one image with the model built by ``mnist_cnn3.creat_model()``.

    The model is rebuilt and its weights are reloaded from
    ``./checkpoint/mnist_cnn3.ckpt`` on every call.

    Args:
        str: Path to the image file. (The name shadows the ``str`` builtin;
            it is kept for compatibility with keyword callers.)

    Returns:
        The index of the highest-scoring output (``argmax`` over axis 1 of
        the model's prediction) for the single input image.
    """
    model_save_path = "./checkpoint/mnist_cnn3.ckpt"
    model = mnist_cnn3.creat_model()
    model.load_weights(model_save_path)

    # The context manager guarantees the file handle opened by Image.open is
    # closed, including for multi-frame formats or if decoding raises.
    with Image.open(str) as img:
        resized = img.resize((28, 28), Image.Resampling.LANCZOS)
        # Mode 'L' is single-channel grayscale, so this is a (28, 28) array.
        img_arr = np.array(resized.convert('L'))

    # Scale pixel values from [0, 255] to [0, 1].
    img_arr = img_arr / 255.0
    # Add a leading batch axis: (28, 28) -> (1, 28, 28).
    # NOTE(review): no trailing channel axis is added here, so the model
    # receives (1, 28, 28). Presumably mnist_cnn3.creat_model() accepts that
    # shape (e.g. via an internal reshape); confirm.
    x_predict = img_arr[tf.newaxis, ...]
    result = model.predict(x_predict)
    pred = tf.argmax(result, axis=1)
    return pred.numpy()[0]