You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
96 lines
2.8 KiB
96 lines
2.8 KiB
from PIL import Image
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
import mnist_cnn3,mnist_cnn,mnist_dense
|
|
|
|
|
|
def predeiction(str, model):
    """Classify the image at path ``str`` with the model selected by ``model``.

    The (misspelled) function name and the ``str`` parameter name, which
    shadows the builtin, are kept unchanged so existing callers keep working.

    Args:
        str: Path of the image file to classify.
        model: Model selector --
            1 -> ``prediction_1`` (model from ``mnist_dense``),
            2 -> ``prediction_2`` (model from ``mnist_cnn``),
            3 -> ``prediction_3`` (model from ``mnist_cnn3``).

    Returns:
        Whatever the selected ``prediction_N`` helper returns (the arg-max
        class index for the image).

    Raises:
        ValueError: If ``model`` is not 1, 2 or 3. Previously this case fell
            through every branch and failed with ``UnboundLocalError``.
    """
    if model == 1:
        return prediction_1(str)
    if model == 2:
        return prediction_2(str)
    if model == 3:
        return prediction_3(str)
    raise ValueError(f"unknown model selector {model!r}; expected 1, 2 or 3")
|
|
|
|
|
|
# preNum = int(input("input the number of test pictures:"))
|
|
def _load_digit_image(image_path):
    """Load an image file as a 28x28 grayscale array scaled to [0, 1].

    The file is opened in a ``with`` block so the underlying handle is
    closed once the pixels have been read (the previous code never closed it).

    Args:
        image_path: Path of an image file readable by PIL.

    Returns:
        A 2-D numpy float array of shape (28, 28) with values in [0, 1].
    """
    with Image.open(image_path) as img:
        # resize() reads the pixel data and returns an independent image,
        # so it stays valid after the source file is closed.
        resized = img.resize((28, 28), Image.Resampling.LANCZOS)
    gray = np.array(resized.convert('L'))
    # NOTE(review): an optional binarisation/inversion step (pixel < 200 ->
    # 255, else 0) was previously present here as commented-out code; it is
    # intentionally not applied.
    return gray / 255.0


def _classify(create_model, model_save_path, image_path, *,
              add_channel_axis=False, print_scores=False):
    """Build a model, restore its weights and classify a single image file.

    Args:
        create_model: Zero-argument factory returning an object that exposes
            ``load_weights(path)`` and ``predict(batch)``.
        model_save_path: Checkpoint path passed to ``model.load_weights``.
        image_path: Path of the image to classify.
        add_channel_axis: If true, append a trailing axis so the input batch
            becomes (1, 28, 28, 1) instead of (1, 28, 28).
        print_scores: If true, print the raw ``model.predict`` output before
            taking the arg-max.

    Returns:
        The index of the highest-scoring class (``tf.argmax`` over axis 1)
        for the single sample in the batch.
    """
    # The model is rebuilt and its weights reloaded on every call, as in the
    # original code; cache it if prediction becomes a hot path.
    model = create_model()
    model.load_weights(model_save_path)

    # Add a leading batch axis: (28, 28) -> (1, 28, 28).
    x_predict = _load_digit_image(image_path)[tf.newaxis, ...]
    if add_channel_axis:
        x_predict = tf.expand_dims(x_predict, -1)

    result = model.predict(x_predict)
    if print_scores:
        print(result)
    return tf.argmax(result, axis=1).numpy()[0]


def prediction_1(str):
    """Classify the image at path ``str`` with the ``mnist_dense`` model.

    Weights are restored from ``./checkpoint/mnist.ckpt``. The input is fed
    as a (1, 28, 28) batch.

    Args:
        str: Path of the image file (the name shadows the builtin and is kept
            for compatibility with keyword callers).

    Returns:
        The predicted class index.
    """
    return _classify(mnist_dense.creat_model, "./checkpoint/mnist.ckpt", str)


def prediction_2(str):
    """Classify the image at path ``str`` with the ``mnist_cnn`` model.

    Weights are restored from ``./checkpoint/mnist_cnn1.ckpt``. The input is
    fed as a (1, 28, 28, 1) batch (an explicit channel axis is added). The raw
    prediction scores are printed, as before.

    Args:
        str: Path of the image file (the name shadows the builtin and is kept
            for compatibility with keyword callers).

    Returns:
        The predicted class index.
    """
    # NOTE(review): print_scores reproduces the original debug print; consider
    # disabling it once nothing relies on that console output.
    return _classify(mnist_cnn.creat_model, "./checkpoint/mnist_cnn1.ckpt", str,
                     add_channel_axis=True, print_scores=True)


def prediction_3(str):
    """Classify the image at path ``str`` with the ``mnist_cnn3`` model.

    Weights are restored from ``./checkpoint/mnist_cnn3.ckpt``. The input is
    fed as a (1, 28, 28) batch.

    Args:
        str: Path of the image file (the name shadows the builtin and is kept
            for compatibility with keyword callers).

    Returns:
        The predicted class index.
    """
    # NOTE(review): unlike prediction_2, no channel axis is added here;
    # presumably mnist_cnn3's model reshapes its input itself -- confirm.
    return _classify(mnist_cnn3.creat_model, "./checkpoint/mnist_cnn3.ckpt", str)
|