diff --git a/code/.idea/.gitignore b/code/.idea/.gitignore new file mode 100644 index 0000000..50d9d22 --- /dev/null +++ b/code/.idea/.gitignore @@ -0,0 +1,3 @@ +# 默认忽略的文件 +/shelf/ +/workspace.xml diff --git a/code/.idea/.name b/code/.idea/.name new file mode 100644 index 0000000..546302d --- /dev/null +++ b/code/.idea/.name @@ -0,0 +1 @@ +mnist_model1.py \ No newline at end of file diff --git a/code/.idea/MNIST.iml b/code/.idea/MNIST.iml new file mode 100644 index 0000000..32048ae --- /dev/null +++ b/code/.idea/MNIST.iml @@ -0,0 +1,10 @@ + + + + + + + + + + \ No newline at end of file diff --git a/code/.idea/inspectionProfiles/profiles_settings.xml b/code/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/code/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/code/.idea/misc.xml b/code/.idea/misc.xml new file mode 100644 index 0000000..628f4b0 --- /dev/null +++ b/code/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/code/.idea/modules.xml b/code/.idea/modules.xml new file mode 100644 index 0000000..e38aea9 --- /dev/null +++ b/code/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/code/.idea/workspace.xml b/code/.idea/workspace.xml deleted file mode 100644 index 0ab5a71..0000000 --- a/code/.idea/workspace.xml +++ /dev/null @@ -1,162 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 1648997377373 - - - - \ No newline at end of file diff --git a/code/1.jpg b/code/1.jpg new file mode 100644 index 0000000..71fe1bb Binary files /dev/null and b/code/1.jpg differ diff --git a/code/1.py b/code/1.py new file mode 100644 index 0000000..d7db9e7 --- /dev/null +++ b/code/1.py @@ -0,0 +1,15 @@ +def isPrime(n): + # 判断数字是否为素数 + # 请在此处添加代码 # + # *************begin************# + if n<2: + return False + if n==2: + return True + if n%2 == 0: + return False + for i in range(2,n): + if n%i == 0: + return False + return True +print(isPrime(10)) \ No newline at end of file diff --git a/code/2.png b/code/2.png new file mode 100644 index 0000000..cc4713e Binary files /dev/null and b/code/2.png differ diff --git a/code/__pycache__/mnist_cnn.cpython-37.pyc b/code/__pycache__/mnist_cnn.cpython-37.pyc new file mode 100644 index 0000000..e81c16b Binary files /dev/null and b/code/__pycache__/mnist_cnn.cpython-37.pyc differ diff --git a/code/__pycache__/mnist_cnn3.cpython-37.pyc b/code/__pycache__/mnist_cnn3.cpython-37.pyc new file mode 100644 index 0000000..7c3dc0d Binary files /dev/null and b/code/__pycache__/mnist_cnn3.cpython-37.pyc differ diff --git a/code/__pycache__/mnist_dense.cpython-37.pyc b/code/__pycache__/mnist_dense.cpython-37.pyc new file mode 100644 index 0000000..7c73fa5 Binary files /dev/null and b/code/__pycache__/mnist_dense.cpython-37.pyc differ diff --git a/code/__pycache__/read_image.cpython-37.pyc b/code/__pycache__/read_image.cpython-37.pyc new file mode 100644 index 0000000..9fa8fbc Binary files /dev/null and b/code/__pycache__/read_image.cpython-37.pyc differ diff --git a/code/__pycache__/test1.cpython-37.pyc b/code/__pycache__/test1.cpython-37.pyc new file mode 100644 index 0000000..017d96d Binary files /dev/null and b/code/__pycache__/test1.cpython-37.pyc differ diff --git a/code/__pycache__/test_cnn3.cpython-37.pyc b/code/__pycache__/test_cnn3.cpython-37.pyc new file mode 100644 index 0000000..7a50e8d Binary files /dev/null and b/code/__pycache__/test_cnn3.cpython-37.pyc differ diff --git a/code/checkpoint/checkpoint b/code/checkpoint/checkpoint new file mode 100644 index 0000000..9beb43e --- /dev/null +++ b/code/checkpoint/checkpoint @@ -0,0 +1,2 @@ +model_checkpoint_path: "mnist.ckpt" +all_model_checkpoint_paths: "mnist.ckpt" diff --git a/code/checkpoint/mnist.ckpt.data-00000-of-00001 b/code/checkpoint/mnist.ckpt.data-00000-of-00001 new file mode 100644 index 0000000..23d9b55 Binary files /dev/null and b/code/checkpoint/mnist.ckpt.data-00000-of-00001 differ diff --git a/code/checkpoint/mnist.ckpt.index b/code/checkpoint/mnist.ckpt.index new file mode 100644 index 0000000..2d65419 Binary files /dev/null and b/code/checkpoint/mnist.ckpt.index differ diff --git a/code/checkpoint/mnist_cnn1.ckpt.data-00000-of-00001 b/code/checkpoint/mnist_cnn1.ckpt.data-00000-of-00001 new file mode 100644 index 0000000..bdb4112 Binary files /dev/null and b/code/checkpoint/mnist_cnn1.ckpt.data-00000-of-00001 differ diff --git a/code/checkpoint/mnist_cnn1.ckpt.index b/code/checkpoint/mnist_cnn1.ckpt.index new file mode 100644 index 0000000..bbc3924 Binary files /dev/null and b/code/checkpoint/mnist_cnn1.ckpt.index differ diff --git a/code/checkpoint/mnist_cnn3.ckpt.data-00000-of-00001 b/code/checkpoint/mnist_cnn3.ckpt.data-00000-of-00001 new file mode 100644 index 0000000..32d49d1 Binary files /dev/null and b/code/checkpoint/mnist_cnn3.ckpt.data-00000-of-00001 differ diff --git a/code/checkpoint/mnist_cnn3.ckpt.index b/code/checkpoint/mnist_cnn3.ckpt.index new file mode 100644 index 0000000..6baf7dc Binary files /dev/null and b/code/checkpoint/mnist_cnn3.ckpt.index differ diff --git a/code/main.py b/code/main.py new file mode 100644 index 0000000..e172123 --- /dev/null +++ b/code/main.py @@ -0,0 +1,116 @@ +from tkinter import * + +import cv2 +from PIL import ImageGrab +from tkinter import filedialog + +import read_image + +model=3 +def model_1(): + global model + model = 1 + print(model) + +def model_2(): + global model + model = 2 + print(model) +def model_3(): + global model + model = 3 + print(model) + +def paint(event): + x1, y1 = (event.x - 20), (event.y - 20) + x2, y2 = (event.x + 20), (event.y + 20) + w.create_oval(x1, y1, x2, y2, fill="white", outline='white') + + +def open_image(): + image_name = filedialog.askopenfilename(title='打开图片', filetypes=[('jpg,jpeg', '*.jpg')]) + image_show = cv2.imread(image_name, cv2.IMREAD_GRAYSCALE) + cv2.imshow("image", image_show) + print(model) + result = read_image.predeiction(image_name, model) + text.set(str(result)) + +def screenshot(*args): + a = root.winfo_x() + b = root.winfo_y() + a = a + 10 + b = b + 35 + bbox = (a, b, a + 395, b + 395) + im = ImageGrab.grab(bbox) + im.save('1.jpg') + print("在使用%d模型"%model) + result = read_image.predeiction('1.jpg', model) + + text.set(str(result)) + + +def clear_canvas(event): + x1, y1 = (event.x - 2800), (event.y - 2800) + x2, y2 = (event.x + 2800), (event.y + 2800) + w.create_oval(x1, y1, x2, y2, fill="black", outline='black') + + + +def reset_canvas(): + a = root.winfo_x() + b = root.winfo_y() + x1, y1 = (a - 2800), (b - 2800) + x2, y2 = (a + 2800), (b + 2800) + w.create_oval(x1, y1, x2, y2, fill="black", outline='black') + + +root = Tk() +root.geometry('600x400') # 规定窗口大小600*400像素 +root.resizable(False, False) # 规定窗口不可缩放 +root.title('数字识别') + +col_count, row_count = root.grid_size() + +for col in range(col_count): + root.grid_columnconfigure(col, minsize=10) + +for row in range(row_count): + root.grid_rowconfigure(row, minsize=20) + + +text = StringVar() +text.set('') +w = Canvas(root, width=400, height=400, bg='black') +w.grid(row=0, column=0, rowspan=6) +label_1 = Label(root, text=' 识别的结果为:', font=('', 20)) +label_1.grid(row=0, column=1) +result_label = Label(root, textvariable=text, font=('', 25), height=2, fg='red') +result_label.grid(row=1, column=1) + +try_button = Button(root, text='模型1', width=7, height=2, command=model_1) +try_button.grid(row=2, column=1,sticky=W) +try_button = Button(root, text='模型2', width=7, height=2, command=model_2) +try_button.grid(row=2, column=1) +try_button = Button(root, text='模型3', width=7, height=2, command=model_3) +try_button.grid(row=2, column=1,sticky=E) + +try_button = Button(root, text='开始识别', width=15, height=2, command=screenshot) +try_button.grid(row=3, column=1) + +clear_button = Button(root, text='清空画布', width=15, height=2, command=reset_canvas) +clear_button.grid(row=4, column=1) + +load_image_button = Button(root, text='来自图片', width=15, height=2, command=open_image) +load_image_button.grid(row=5, column=1) + +w.bind("", paint) +w.bind("", screenshot) +w.bind("", clear_canvas) + +mainloop() + + + + + + diff --git a/code/mnist_cnn.py b/code/mnist_cnn.py new file mode 100644 index 0000000..4a944ac --- /dev/null +++ b/code/mnist_cnn.py @@ -0,0 +1,66 @@ +import tensorflow as tf +import os +from matplotlib import pyplot as plt +from PIL import Image +import numpy as np + +mnist = tf.keras.datasets.mnist + +(x_train, y_train), (x_test, y_test) = mnist.load_data() +x_train, x_test = x_train/255.0, x_test/255.0 + +def creat_model(): + model = tf.keras.models.Sequential([ + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(128, activation='relu'), + tf.keras.layers.Dense(10, activation='softmax') + ]) + + model.compile(optimizer='adam', + loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), + metrics=["sparse_categorical_accuracy"]) + return model +def model_fit(model,check_save_path): + + if os.path.exists(check_save_path+'.index'): + print("load modals...") + model.load_weights(check_save_path) + + cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=check_save_path, + save_weights_only=True, + save_best_only=True) + + history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_split=0.15, validation_freq=1, callbacks=[cp_callback]) + model.summary() + + final_loss, final_acc = model.evaluate(x_test, y_test, verbose=2) + print("Model accuracy: ", final_acc, ", model loss: ", final_loss) + + acc = history.history['sparse_categorical_accuracy'] + val_acc = history.history['val_sparse_categorical_accuracy'] + loss = history.history['loss'] + val_loss = history.history['val_loss'] + + plt.subplot(1, 2, 1) + plt.plot(acc, label='Training Accuracy') + plt.plot(val_acc, label='Validation Accuracy') + plt.title('Training and Validation Accuracy') + plt.legend() + + plt.subplot(1, 2, 2) + plt.plot(loss, label='Training Loss') + plt.plot(val_loss, label='Validation Loss') + plt.title('Training and Validation Loss') + plt.legend() + plt.show() + +def eva_acc(str, model): + model.load_weights(str) + final_loss, final_acc = model.evaluate(x_test, y_test, verbose=2) + print("Model accuracy: ", final_acc, ", model loss: ", final_loss) + +if __name__=="__main__": + check_save_path = "./checkpoint/mnist.ckpt" + model = creat_model() + model_fit(model, check_save_path) + eva_acc(check_save_path, model) diff --git a/code/mnist_cnn2.py b/code/mnist_cnn2.py new file mode 100644 index 0000000..e69de29 diff --git a/code/mnist_cnn3.py b/code/mnist_cnn3.py new file mode 100644 index 0000000..cc10dfc --- /dev/null +++ b/code/mnist_cnn3.py @@ -0,0 +1,120 @@ +import tensorflow.keras as keras +import numpy as np +import os +import tensorflow as tf +import matplotlib.pyplot as plt +mnist = keras.datasets.mnist +(x_train, y_train), (x_test, y_test) = mnist.load_data() +x_train = x_train/255.0 +x_test = x_test/255.0 + +x_train = tf.expand_dims(x_train, -1) +x_test = tf.expand_dims(x_test, -1) + +print("train shape:", x_train.shape) +print("test shape:", x_test.shape) + +# 使用此类进行图形增强 +datagen = keras.preprocessing.image.ImageDataGenerator( + rotation_range=20, # 整数。随机旋转的度数范围。 + width_shift_range=0.20, # 浮点数,图片宽度的某个比例,数据提升时图片随机水平偏移的幅度。 + shear_range=15, # 浮点数,剪切强度(逆时针方向的剪切变换角度)。是用来进行剪切变换的程度。 + zoom_range=0.10, # 浮点数或形如[lower,upper]的列表,随机缩放的幅度,若为浮点数,则相当于[lower,upper] = [1 - zoom_range, 1+zoom_range]。用来进行随机的放大。 + validation_split=0.15, # 浮点型。保留用于验证集的图像比例(严格在0,1之间) + horizontal_flip=False # 布尔值,随机水平翻转。 +) + +train_datagen = datagen.flow( + x_train, + y_train, + batch_size=256, + subset="training" +) + +validation_genetor = datagen.flow( + x_train, + y_train, + batch_size=64, + subset="validation" +) +def creat_model(): + model = keras.Sequential([ + keras.layers.Reshape((28, 28, 1)), + keras.layers.Conv2D(filters=32, kernel_size=(5, 5), activation="relu", padding="same", + input_shape=(28, 28, 1)), + keras.layers.MaxPool2D((2, 2)), + + keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation="relu", padding="same"), + keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation="relu", padding="same"), + keras.layers.MaxPool2D((2, 2)), + + keras.layers.Conv2D(filters=128, kernel_size=(3, 3), activation="relu", padding="same"), + keras.layers.Conv2D(filters=128, kernel_size=(3, 3), activation="relu", padding="same"), + keras.layers.MaxPool2D((2, 2)), + + keras.layers.Flatten(), + keras.layers.Dense(512, activation="sigmoid"), + keras.layers.Dropout(0.25), + + keras.layers.Dense(512, activation="sigmoid"), + keras.layers.Dropout(0.25), + + keras.layers.Dense(256, activation="sigmoid"), + keras.layers.Dropout(0.1), + + keras.layers.Dense(10, activation="sigmoid") + ]) + + model.compile(optimizer='adam', + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False), + metrics=["sparse_categorical_accuracy"]) + return model + +def model_fit(model, check_save_path): + if os.path.exists(check_save_path+'.index'): + print("load modals...") + model.load_weights(check_save_path) + + cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=check_save_path, + save_weights_only=True, + save_best_only=True) + reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', + factor=0.1, + patience=5, + min_lr=0.000001, + verbose=1) + + + history = model.fit(train_datagen, epochs=1, validation_data=validation_genetor, callbacks=[reduce_lr,cp_callback],verbose=1) + model.summary() + + acc = history.history['sparse_categorical_accuracy'] + val_acc = history.history['val_sparse_categorical_accuracy'] + loss = history.history['loss'] + val_loss = history.history['val_loss'] + + plt.subplot(1, 2, 1) + plt.plot(acc, label='Training Accuracy') + plt.plot(val_acc, label='Validation Accuracy') + plt.title('Training and Validation Accuracy') + plt.legend() + + plt.subplot(1, 2, 2) + plt.plot(loss, label='Training Loss') + plt.plot(val_loss, label='Validation Loss') + plt.title('Training and Validation Loss') + plt.legend() + plt.show() + + +def model_valtest(model, check_save_path): + model.load_weights(check_save_path) + final_loss, final_acc = model.evaluate(x_test, y_test, verbose=2) + print("Model accuracy: ", final_acc, ", model loss: ", final_loss) + +if __name__ == "__main__": + + check_save_path = "./checkpoint/mnist_cnn3.ckpt" + model = creat_model() + # model_fit(model, check_save_path) + model_valtest(model, check_save_path) \ No newline at end of file diff --git a/code/mnist_dense.py b/code/mnist_dense.py new file mode 100644 index 0000000..4a944ac --- /dev/null +++ b/code/mnist_dense.py @@ -0,0 +1,66 @@ +import tensorflow as tf +import os +from matplotlib import pyplot as plt +from PIL import Image +import numpy as np + +mnist = tf.keras.datasets.mnist + +(x_train, y_train), (x_test, y_test) = mnist.load_data() +x_train, x_test = x_train/255.0, x_test/255.0 + +def creat_model(): + model = tf.keras.models.Sequential([ + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(128, activation='relu'), + tf.keras.layers.Dense(10, activation='softmax') + ]) + + model.compile(optimizer='adam', + loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), + metrics=["sparse_categorical_accuracy"]) + return model +def model_fit(model,check_save_path): + + if os.path.exists(check_save_path+'.index'): + print("load modals...") + model.load_weights(check_save_path) + + cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=check_save_path, + save_weights_only=True, + save_best_only=True) + + history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_split=0.15, validation_freq=1, callbacks=[cp_callback]) + model.summary() + + final_loss, final_acc = model.evaluate(x_test, y_test, verbose=2) + print("Model accuracy: ", final_acc, ", model loss: ", final_loss) + + acc = history.history['sparse_categorical_accuracy'] + val_acc = history.history['val_sparse_categorical_accuracy'] + loss = history.history['loss'] + val_loss = history.history['val_loss'] + + plt.subplot(1, 2, 1) + plt.plot(acc, label='Training Accuracy') + plt.plot(val_acc, label='Validation Accuracy') + plt.title('Training and Validation Accuracy') + plt.legend() + + plt.subplot(1, 2, 2) + plt.plot(loss, label='Training Loss') + plt.plot(val_loss, label='Validation Loss') + plt.title('Training and Validation Loss') + plt.legend() + plt.show() + +def eva_acc(str, model): + model.load_weights(str) + final_loss, final_acc = model.evaluate(x_test, y_test, verbose=2) + print("Model accuracy: ", final_acc, ", model loss: ", final_loss) + +if __name__=="__main__": + check_save_path = "./checkpoint/mnist.ckpt" + model = creat_model() + model_fit(model, check_save_path) + eva_acc(check_save_path, model) diff --git a/code/mnist_show.py b/code/mnist_show.py new file mode 100644 index 0000000..dee4aad --- /dev/null +++ b/code/mnist_show.py @@ -0,0 +1,11 @@ +import tensorflow as tf +import os +mnist = tf.keras.datasets.mnist +from matplotlib import pyplot as plt + +(x_train, y_train), (x_test, y_test) = mnist.load_data() +x_train, x_test = x_train/255.0, x_test/255.0 +print(y_test) + +plt.imshow(x_train[0], cmap="gray") +plt.show() \ No newline at end of file diff --git a/code/mnist_test.py b/code/mnist_test.py new file mode 100644 index 0000000..9e5643b --- /dev/null +++ b/code/mnist_test.py @@ -0,0 +1,40 @@ +from PIL import Image +import numpy as np +import tensorflow as tf + +model_save_path = './checkpoint/mnist.ckpt' + +model = tf.keras.models.Sequential([ + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(128, activation='relu'), + tf.keras.layers.Dense(10, activation='softmax')]) + +model.load_weights(model_save_path) + +preNum = int(input("input the number of test pictures:")) + +for i in range(preNum): + image_path = input("the path of test picture:") + img = Image.open(image_path) + img = img.resize((28, 28), Image.Resampling.LANCZOS) + img_arr = np.array(img.convert('L')) + # + # for i in range(28): + # for j in range(28): + # if img_arr[i][j] < 200: + # img_arr[i][j] = 255 + # else: + # img_arr[i][j] = 0 + + img_arr = img_arr / 255.0 + x_predict = img_arr[tf.newaxis, ...] + result = model.predict(x_predict) + pred = tf.argmax(result, axis=1) + if result[0][pred] <= 0.3: + print(result) + print("无法判断,请重新输入!") + else: + print(result) + tf.print(pred[0]) + + diff --git a/code/read_image.py b/code/read_image.py new file mode 100644 index 0000000..f0c5a4d --- /dev/null +++ b/code/read_image.py @@ -0,0 +1,95 @@ +from PIL import Image +import numpy as np +import tensorflow as tf +import mnist_cnn3,mnist_cnn,mnist_dense + + +def predeiction(str, model): + if model==1: + + res = prediction_1(str) + elif model==2: + res = prediction_2(str) + elif model==3: + res = prediction_3(str) + return res + + +# preNum = int(input("input the number of test pictures:")) +def prediction_1(str): + model_save_path = "./checkpoint/mnist.ckpt" + model = mnist_dense.creat_model() + + model.load_weights(model_save_path) + for i in range(1): + image_path = str + img = Image.open(image_path) + img = img.resize((28, 28), Image.Resampling.LANCZOS) + img_arr = np.array(img.convert('L')) + # + # for i in range(28): + # for j in range(28): + # if img_arr[i][j] < 200: + # img_arr[i][j] = 255 + # else: + # img_arr[i][j] = 0 + + img_arr = img_arr / 255.0 + x_predict = img_arr[tf.newaxis, ...] + result = model.predict(x_predict) + pred = tf.argmax(result, axis=1) + return pred.numpy()[0] + + + +def prediction_2(str): + model_save_path = "./checkpoint/mnist_cnn1.ckpt" + model = mnist_cnn.creat_model() + + model.load_weights(model_save_path) + for i in range(1): + image_path = str + img = Image.open(image_path) + img = img.resize((28, 28), Image.Resampling.LANCZOS) + img_arr = np.array(img.convert('L')) + # + # for i in range(28): + # for j in range(28): + # if img_arr[i][j] < 200: + # img_arr[i][j] = 255 + # else: + # img_arr[i][j] = 0 + + img_arr = img_arr / 255.0 + x_predict = img_arr[tf.newaxis, ...] + x_predict = tf.expand_dims(x_predict, -1) + result = model.predict(x_predict) + print(result) + pred = tf.argmax(result, axis=1) + return pred.numpy()[0] + + +def prediction_3(str): + model_save_path = "./checkpoint/mnist_cnn3.ckpt" + model = mnist_cnn3.creat_model() + + model.load_weights(model_save_path) + for i in range(1): + image_path = str + img = Image.open(image_path) + img = img.resize((28, 28), Image.Resampling.LANCZOS) + img_arr = np.array(img.convert('L')) + # + # for i in range(28): + # for j in range(28): + # if img_arr[i][j] < 200: + # img_arr[i][j] = 255 + # else: + # img_arr[i][j] = 0 + + img_arr = img_arr / 255.0 + x_predict = img_arr[tf.newaxis, ...] + result = model.predict(x_predict) + + pred = tf.argmax(result, axis=1) + return pred.numpy()[0] diff --git a/code/test_cnn3.py b/code/test_cnn3.py new file mode 100644 index 0000000..c3a291c --- /dev/null +++ b/code/test_cnn3.py @@ -0,0 +1,35 @@ +from PIL import Image +import numpy as np +import tensorflow as tf +import mnist_cnn3 + + +model_save_path = 'E:\\Python_touge\\MNIST\\venv\\checkpoint\\mnist_cnn3.ckpt' +model = mnist_cnn3.creat_model() +model.load_weights(model_save_path) + +# preNum = int(input("input the number of test pictures:")) +def prediction(str): + for i in range(1): + image_path = str + img = Image.open(image_path) + img = img.resize((28, 28), Image.Resampling.LANCZOS) + img_arr = np.array(img.convert('L')) + # + # for i in range(28): + # for j in range(28): + # if img_arr[i][j] < 200: + # img_arr[i][j] = 255 + # else: + # img_arr[i][j] = 0 + + img_arr = img_arr / 255.0 + x_predict = img_arr[tf.newaxis, ...] + result = model.predict(x_predict) + pred = tf.argmax(result, axis=1) + if result[(0, pred)] <= 0.8: + str = "无法判断,请重新输入!" + print(result) + return str + else: + return pred.numpy()[0]