diff --git a/code/.idea/.gitignore b/code/.idea/.gitignore
new file mode 100644
index 0000000..50d9d22
--- /dev/null
+++ b/code/.idea/.gitignore
@@ -0,0 +1,3 @@
+# Default ignored files
+/shelf/
+/workspace.xml
diff --git a/code/.idea/.name b/code/.idea/.name
new file mode 100644
index 0000000..546302d
--- /dev/null
+++ b/code/.idea/.name
@@ -0,0 +1 @@
+mnist_model1.py
\ No newline at end of file
diff --git a/code/.idea/MNIST.iml b/code/.idea/MNIST.iml
new file mode 100644
index 0000000..32048ae
--- /dev/null
+++ b/code/.idea/MNIST.iml
@@ -0,0 +1,10 @@
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/code/.idea/inspectionProfiles/profiles_settings.xml b/code/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/code/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/code/.idea/misc.xml b/code/.idea/misc.xml
new file mode 100644
index 0000000..628f4b0
--- /dev/null
+++ b/code/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/code/.idea/modules.xml b/code/.idea/modules.xml
new file mode 100644
index 0000000..e38aea9
--- /dev/null
+++ b/code/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/code/.idea/workspace.xml b/code/.idea/workspace.xml
deleted file mode 100644
index 0ab5a71..0000000
--- a/code/.idea/workspace.xml
+++ /dev/null
@@ -1,162 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- 1648997377373
-
-
- 1648997377373
-
-
-
-
\ No newline at end of file
diff --git a/code/1.jpg b/code/1.jpg
new file mode 100644
index 0000000..71fe1bb
Binary files /dev/null and b/code/1.jpg differ
diff --git a/code/1.py b/code/1.py
new file mode 100644
index 0000000..d7db9e7
--- /dev/null
+++ b/code/1.py
@@ -0,0 +1,15 @@
+def isPrime(n):
+    # Determine whether n is prime
+    # Add your code here
+    # *************begin************#
+    if n < 2:
+        return False
+    if n == 2:
+        return True
+    if n % 2 == 0:
+        return False
+    for i in range(3, int(n ** 0.5) + 1, 2):  # even divisors are already ruled out
+        if n % i == 0:
+            return False
+    return True
+print(isPrime(10))
\ No newline at end of file
diff --git a/code/2.png b/code/2.png
new file mode 100644
index 0000000..cc4713e
Binary files /dev/null and b/code/2.png differ
diff --git a/code/__pycache__/mnist_cnn.cpython-37.pyc b/code/__pycache__/mnist_cnn.cpython-37.pyc
new file mode 100644
index 0000000..e81c16b
Binary files /dev/null and b/code/__pycache__/mnist_cnn.cpython-37.pyc differ
diff --git a/code/__pycache__/mnist_cnn3.cpython-37.pyc b/code/__pycache__/mnist_cnn3.cpython-37.pyc
new file mode 100644
index 0000000..7c3dc0d
Binary files /dev/null and b/code/__pycache__/mnist_cnn3.cpython-37.pyc differ
diff --git a/code/__pycache__/mnist_dense.cpython-37.pyc b/code/__pycache__/mnist_dense.cpython-37.pyc
new file mode 100644
index 0000000..7c73fa5
Binary files /dev/null and b/code/__pycache__/mnist_dense.cpython-37.pyc differ
diff --git a/code/__pycache__/read_image.cpython-37.pyc b/code/__pycache__/read_image.cpython-37.pyc
new file mode 100644
index 0000000..9fa8fbc
Binary files /dev/null and b/code/__pycache__/read_image.cpython-37.pyc differ
diff --git a/code/__pycache__/test1.cpython-37.pyc b/code/__pycache__/test1.cpython-37.pyc
new file mode 100644
index 0000000..017d96d
Binary files /dev/null and b/code/__pycache__/test1.cpython-37.pyc differ
diff --git a/code/__pycache__/test_cnn3.cpython-37.pyc b/code/__pycache__/test_cnn3.cpython-37.pyc
new file mode 100644
index 0000000..7a50e8d
Binary files /dev/null and b/code/__pycache__/test_cnn3.cpython-37.pyc differ
diff --git a/code/checkpoint/checkpoint b/code/checkpoint/checkpoint
new file mode 100644
index 0000000..9beb43e
--- /dev/null
+++ b/code/checkpoint/checkpoint
@@ -0,0 +1,2 @@
+model_checkpoint_path: "mnist.ckpt"
+all_model_checkpoint_paths: "mnist.ckpt"
diff --git a/code/checkpoint/mnist.ckpt.data-00000-of-00001 b/code/checkpoint/mnist.ckpt.data-00000-of-00001
new file mode 100644
index 0000000..23d9b55
Binary files /dev/null and b/code/checkpoint/mnist.ckpt.data-00000-of-00001 differ
diff --git a/code/checkpoint/mnist.ckpt.index b/code/checkpoint/mnist.ckpt.index
new file mode 100644
index 0000000..2d65419
Binary files /dev/null and b/code/checkpoint/mnist.ckpt.index differ
diff --git a/code/checkpoint/mnist_cnn1.ckpt.data-00000-of-00001 b/code/checkpoint/mnist_cnn1.ckpt.data-00000-of-00001
new file mode 100644
index 0000000..bdb4112
Binary files /dev/null and b/code/checkpoint/mnist_cnn1.ckpt.data-00000-of-00001 differ
diff --git a/code/checkpoint/mnist_cnn1.ckpt.index b/code/checkpoint/mnist_cnn1.ckpt.index
new file mode 100644
index 0000000..bbc3924
Binary files /dev/null and b/code/checkpoint/mnist_cnn1.ckpt.index differ
diff --git a/code/checkpoint/mnist_cnn3.ckpt.data-00000-of-00001 b/code/checkpoint/mnist_cnn3.ckpt.data-00000-of-00001
new file mode 100644
index 0000000..32d49d1
Binary files /dev/null and b/code/checkpoint/mnist_cnn3.ckpt.data-00000-of-00001 differ
diff --git a/code/checkpoint/mnist_cnn3.ckpt.index b/code/checkpoint/mnist_cnn3.ckpt.index
new file mode 100644
index 0000000..6baf7dc
Binary files /dev/null and b/code/checkpoint/mnist_cnn3.ckpt.index differ
diff --git a/code/main.py b/code/main.py
new file mode 100644
index 0000000..e172123
--- /dev/null
+++ b/code/main.py
@@ -0,0 +1,116 @@
+from tkinter import *
+
+import cv2
+from PIL import ImageGrab
+from tkinter import filedialog
+
+import read_image
+
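+# Currently selected model (1, 2, or 3); set by the model buttons and passed to read_image.predeiction().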
+model = 3
+def model_1():
+    global model
+    model = 1
+    print(model)
+
+def model_2():
+    global model
+    model = 2
+    print(model)
+
+def model_3():
+    global model
+    model = 3
+    print(model)
+
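+# Paint a 40x40 white oval on the canvas wherever the mouse is dragged.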
+def paint(event):
+    x1, y1 = (event.x - 20), (event.y - 20)
+    x2, y2 = (event.x + 20), (event.y + 20)
+    w.create_oval(x1, y1, x2, y2, fill="white", outline='white')
+
+
+def open_image():
+    image_name = filedialog.askopenfilename(title='Open image', filetypes=[('jpg,jpeg', '*.jpg')])
+    image_show = cv2.imread(image_name, cv2.IMREAD_GRAYSCALE)
+    cv2.imshow("image", image_show)
+    print(model)
+    result = read_image.predeiction(image_name, model)
+    text.set(str(result))
+
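+# Capture the canvas region of the window, save it as 1.jpg, and run the selected model on it.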
+def screenshot(*args):
+    a = root.winfo_x()
+    b = root.winfo_y()
+    a = a + 10
+    b = b + 35
+    bbox = (a, b, a + 395, b + 395)
+    im = ImageGrab.grab(bbox)
+    im.save('1.jpg')
+    print("Using model %d" % model)
+    result = read_image.predeiction('1.jpg', model)
+
+    text.set(str(result))
+
+
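+# Cover the whole canvas with a large black oval to erase the drawing.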
+def clear_canvas(event):
+    x1, y1 = (event.x - 2800), (event.y - 2800)
+    x2, y2 = (event.x + 2800), (event.y + 2800)
+    w.create_oval(x1, y1, x2, y2, fill="black", outline='black')
+
+
+
+def reset_canvas():
+    a = root.winfo_x()
+    b = root.winfo_y()
+    x1, y1 = (a - 2800), (b - 2800)
+    x2, y2 = (a + 2800), (b + 2800)
+    w.create_oval(x1, y1, x2, y2, fill="black", outline='black')
+
+
+root = Tk()
+root.geometry('600x400')  # fix the window size to 600x400 pixels
+root.resizable(False, False)  # make the window non-resizable
+root.title('Digit Recognition')
+
+col_count, row_count = root.grid_size()
+
+for col in range(col_count):
+    root.grid_columnconfigure(col, minsize=10)
+
+for row in range(row_count):
+    root.grid_rowconfigure(row, minsize=20)
+
+
+text = StringVar()
+text.set('')
+w = Canvas(root, width=400, height=400, bg='black')
+w.grid(row=0, column=0, rowspan=6)
+label_1 = Label(root, text=' Recognition result:', font=('', 20))
+label_1.grid(row=0, column=1)
+result_label = Label(root, textvariable=text, font=('', 25), height=2, fg='red')
+result_label.grid(row=1, column=1)
+
+model1_button = Button(root, text='Model 1', width=7, height=2, command=model_1)
+model1_button.grid(row=2, column=1, sticky=W)
+model2_button = Button(root, text='Model 2', width=7, height=2, command=model_2)
+model2_button.grid(row=2, column=1)
+model3_button = Button(root, text='Model 3', width=7, height=2, command=model_3)
+model3_button.grid(row=2, column=1, sticky=E)
+
+recognize_button = Button(root, text='Recognize', width=15, height=2, command=screenshot)
+recognize_button.grid(row=3, column=1)
+
+clear_button = Button(root, text='Clear canvas', width=15, height=2, command=reset_canvas)
+clear_button.grid(row=4, column=1)
+
+load_image_button = Button(root, text='From image', width=15, height=2, command=open_image)
+load_image_button.grid(row=5, column=1)
+
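+# Canvas bindings: left-drag paints, releasing the button runs recognition, right-click clears.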
+w.bind("", paint)
+w.bind("", screenshot)
+w.bind("", clear_canvas)
+
+mainloop()
+
+
+
+
+
+
diff --git a/code/mnist_cnn.py b/code/mnist_cnn.py
new file mode 100644
index 0000000..4a944ac
--- /dev/null
+++ b/code/mnist_cnn.py
@@ -0,0 +1,66 @@
+import tensorflow as tf
+import os
+from matplotlib import pyplot as plt
+from PIL import Image
+import numpy as np
+
+mnist = tf.keras.datasets.mnist
+
+(x_train, y_train), (x_test, y_test) = mnist.load_data()
+x_train, x_test = x_train/255.0, x_test/255.0
+
+def creat_model():
+    model = tf.keras.models.Sequential([
+        tf.keras.layers.Flatten(),
+        tf.keras.layers.Dense(128, activation='relu'),
+        tf.keras.layers.Dense(10, activation='softmax')
+    ])
+
+    model.compile(optimizer='adam',
+                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
+                  metrics=["sparse_categorical_accuracy"])
+    return model
+def model_fit(model, check_save_path):
+
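+    # Resume training from the checkpoint if one already exists.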
+    if os.path.exists(check_save_path + '.index'):
+        print("loading model weights...")
+        model.load_weights(check_save_path)
+
+    cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=check_save_path,
+                                                     save_weights_only=True,
+                                                     save_best_only=True)
+
+    history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_split=0.15, validation_freq=1, callbacks=[cp_callback])
+    model.summary()
+
+    final_loss, final_acc = model.evaluate(x_test, y_test, verbose=2)
+    print("Model accuracy: ", final_acc, ", model loss: ", final_loss)
+
+    acc = history.history['sparse_categorical_accuracy']
+    val_acc = history.history['val_sparse_categorical_accuracy']
+    loss = history.history['loss']
+    val_loss = history.history['val_loss']
+
+    plt.subplot(1, 2, 1)
+    plt.plot(acc, label='Training Accuracy')
+    plt.plot(val_acc, label='Validation Accuracy')
+    plt.title('Training and Validation Accuracy')
+    plt.legend()
+
+    plt.subplot(1, 2, 2)
+    plt.plot(loss, label='Training Loss')
+    plt.plot(val_loss, label='Validation Loss')
+    plt.title('Training and Validation Loss')
+    plt.legend()
+    plt.show()
+
+def eva_acc(weights_path, model):
+    model.load_weights(weights_path)
+    final_loss, final_acc = model.evaluate(x_test, y_test, verbose=2)
+    print("Model accuracy: ", final_acc, ", model loss: ", final_loss)
+
+if __name__ == "__main__":
+    check_save_path = "./checkpoint/mnist.ckpt"
+    model = creat_model()
+    model_fit(model, check_save_path)
+    eva_acc(check_save_path, model)
diff --git a/code/mnist_cnn2.py b/code/mnist_cnn2.py
new file mode 100644
index 0000000..e69de29
diff --git a/code/mnist_cnn3.py b/code/mnist_cnn3.py
new file mode 100644
index 0000000..cc10dfc
--- /dev/null
+++ b/code/mnist_cnn3.py
@@ -0,0 +1,120 @@
+import tensorflow.keras as keras
+import numpy as np
+import os
+import tensorflow as tf
+import matplotlib.pyplot as plt
+mnist = keras.datasets.mnist
+(x_train, y_train), (x_test, y_test) = mnist.load_data()
+x_train = x_train/255.0
+x_test = x_test/255.0
+
+x_train = tf.expand_dims(x_train, -1)
+x_test = tf.expand_dims(x_test, -1)
+
+print("train shape:", x_train.shape)
+print("test shape:", x_test.shape)
+
+# Use ImageDataGenerator for data augmentation
+datagen = keras.preprocessing.image.ImageDataGenerator(
+    rotation_range=20,       # int: degree range for random rotations
+    width_shift_range=0.20,  # float: fraction of total width for random horizontal shifts
+    shear_range=15,          # float: shear intensity (counter-clockwise shear angle in degrees)
+    zoom_range=0.10,         # float or [lower, upper]: random zoom range; a float means [1 - zoom_range, 1 + zoom_range]
+    validation_split=0.15,   # float: fraction of images reserved for the validation subset (strictly between 0 and 1)
+    horizontal_flip=False    # bool: randomly flip inputs horizontally
+)
+
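+# The same generator feeds both streams; validation_split above reserves 15% of the data for the "validation" subset.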
+train_datagen = datagen.flow(
+    x_train,
+    y_train,
+    batch_size=256,
+    subset="training"
+)
+
+validation_generator = datagen.flow(
+    x_train,
+    y_train,
+    batch_size=64,
+    subset="validation"
+)
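+
+# Three convolution blocks (32 -> 64 -> 128 filters) followed by sigmoid dense layers with dropout and a 10-way output.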
+def creat_model():
+    model = keras.Sequential([
+        keras.layers.Reshape((28, 28, 1)),
+        keras.layers.Conv2D(filters=32, kernel_size=(5, 5), activation="relu", padding="same",
+                            input_shape=(28, 28, 1)),
+        keras.layers.MaxPool2D((2, 2)),
+
+        keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation="relu", padding="same"),
+        keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation="relu", padding="same"),
+        keras.layers.MaxPool2D((2, 2)),
+
+        keras.layers.Conv2D(filters=128, kernel_size=(3, 3), activation="relu", padding="same"),
+        keras.layers.Conv2D(filters=128, kernel_size=(3, 3), activation="relu", padding="same"),
+        keras.layers.MaxPool2D((2, 2)),
+
+        keras.layers.Flatten(),
+        keras.layers.Dense(512, activation="sigmoid"),
+        keras.layers.Dropout(0.25),
+
+        keras.layers.Dense(512, activation="sigmoid"),
+        keras.layers.Dropout(0.25),
+
+        keras.layers.Dense(256, activation="sigmoid"),
+        keras.layers.Dropout(0.1),
+
+        keras.layers.Dense(10, activation="sigmoid")
+    ])
+
+    model.compile(optimizer='adam',
+                  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
+                  metrics=["sparse_categorical_accuracy"])
+    return model
+
+def model_fit(model, check_save_path):
+    if os.path.exists(check_save_path + '.index'):
+        print("loading model weights...")
+        model.load_weights(check_save_path)
+
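+    # Save the best weights and lower the learning rate when the validation loss stops improving.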
+    cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=check_save_path,
+                                                     save_weights_only=True,
+                                                     save_best_only=True)
+    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
+                                                     factor=0.1,
+                                                     patience=5,
+                                                     min_lr=0.000001,
+                                                     verbose=1)
+
+    history = model.fit(train_datagen, epochs=1, validation_data=validation_generator,
+                        callbacks=[reduce_lr, cp_callback], verbose=1)
+    model.summary()
+
+    acc = history.history['sparse_categorical_accuracy']
+    val_acc = history.history['val_sparse_categorical_accuracy']
+    loss = history.history['loss']
+    val_loss = history.history['val_loss']
+
+    plt.subplot(1, 2, 1)
+    plt.plot(acc, label='Training Accuracy')
+    plt.plot(val_acc, label='Validation Accuracy')
+    plt.title('Training and Validation Accuracy')
+    plt.legend()
+
+    plt.subplot(1, 2, 2)
+    plt.plot(loss, label='Training Loss')
+    plt.plot(val_loss, label='Validation Loss')
+    plt.title('Training and Validation Loss')
+    plt.legend()
+    plt.show()
+
+
+def model_valtest(model, check_save_path):
+    model.load_weights(check_save_path)
+    final_loss, final_acc = model.evaluate(x_test, y_test, verbose=2)
+    print("Model accuracy: ", final_acc, ", model loss: ", final_loss)
+
+if __name__ == "__main__":
+
+    check_save_path = "./checkpoint/mnist_cnn3.ckpt"
+    model = creat_model()
+    # model_fit(model, check_save_path)
+    model_valtest(model, check_save_path)
\ No newline at end of file
diff --git a/code/mnist_dense.py b/code/mnist_dense.py
new file mode 100644
index 0000000..4a944ac
--- /dev/null
+++ b/code/mnist_dense.py
@@ -0,0 +1,66 @@
+import tensorflow as tf
+import os
+from matplotlib import pyplot as plt
+from PIL import Image
+import numpy as np
+
+mnist = tf.keras.datasets.mnist
+
+(x_train, y_train), (x_test, y_test) = mnist.load_data()
+x_train, x_test = x_train/255.0, x_test/255.0
+
+def creat_model():
+    model = tf.keras.models.Sequential([
+        tf.keras.layers.Flatten(),
+        tf.keras.layers.Dense(128, activation='relu'),
+        tf.keras.layers.Dense(10, activation='softmax')
+    ])
+
+    model.compile(optimizer='adam',
+                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
+                  metrics=["sparse_categorical_accuracy"])
+    return model
+def model_fit(model, check_save_path):
+
+    if os.path.exists(check_save_path + '.index'):
+        print("loading model weights...")
+        model.load_weights(check_save_path)
+
+    cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=check_save_path,
+                                                     save_weights_only=True,
+                                                     save_best_only=True)
+
+    history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_split=0.15, validation_freq=1, callbacks=[cp_callback])
+    model.summary()
+
+    final_loss, final_acc = model.evaluate(x_test, y_test, verbose=2)
+    print("Model accuracy: ", final_acc, ", model loss: ", final_loss)
+
+    acc = history.history['sparse_categorical_accuracy']
+    val_acc = history.history['val_sparse_categorical_accuracy']
+    loss = history.history['loss']
+    val_loss = history.history['val_loss']
+
+    plt.subplot(1, 2, 1)
+    plt.plot(acc, label='Training Accuracy')
+    plt.plot(val_acc, label='Validation Accuracy')
+    plt.title('Training and Validation Accuracy')
+    plt.legend()
+
+    plt.subplot(1, 2, 2)
+    plt.plot(loss, label='Training Loss')
+    plt.plot(val_loss, label='Validation Loss')
+    plt.title('Training and Validation Loss')
+    plt.legend()
+    plt.show()
+
+def eva_acc(weights_path, model):
+    model.load_weights(weights_path)
+    final_loss, final_acc = model.evaluate(x_test, y_test, verbose=2)
+    print("Model accuracy: ", final_acc, ", model loss: ", final_loss)
+
+if __name__ == "__main__":
+    check_save_path = "./checkpoint/mnist.ckpt"
+    model = creat_model()
+    model_fit(model, check_save_path)
+    eva_acc(check_save_path, model)
diff --git a/code/mnist_show.py b/code/mnist_show.py
new file mode 100644
index 0000000..dee4aad
--- /dev/null
+++ b/code/mnist_show.py
@@ -0,0 +1,11 @@
+import tensorflow as tf
+from matplotlib import pyplot as plt
+
+mnist = tf.keras.datasets.mnist
+
+(x_train, y_train), (x_test, y_test) = mnist.load_data()
+x_train, x_test = x_train/255.0, x_test/255.0
+print(y_test)
+
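+# Show the first training digit as a quick sanity check.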
+plt.imshow(x_train[0], cmap="gray")
+plt.show()
\ No newline at end of file
diff --git a/code/mnist_test.py b/code/mnist_test.py
new file mode 100644
index 0000000..9e5643b
--- /dev/null
+++ b/code/mnist_test.py
@@ -0,0 +1,40 @@
+from PIL import Image
+import numpy as np
+import tensorflow as tf
+
+model_save_path = './checkpoint/mnist.ckpt'
+
+model = tf.keras.models.Sequential([
+    tf.keras.layers.Flatten(),
+    tf.keras.layers.Dense(128, activation='relu'),
+    tf.keras.layers.Dense(10, activation='softmax')])
+
+model.load_weights(model_save_path)
+
+preNum = int(input("input the number of test pictures:"))
+
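+# Classify each user-supplied image; reject predictions whose top probability is at or below 0.3.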
+for i in range(preNum):
+    image_path = input("the path of test picture:")
+    img = Image.open(image_path)
+    img = img.resize((28, 28), Image.Resampling.LANCZOS)
+    img_arr = np.array(img.convert('L'))
+    #
+    # for i in range(28):
+    #     for j in range(28):
+    #         if img_arr[i][j] < 200:
+    #             img_arr[i][j] = 255
+    #         else:
+    #             img_arr[i][j] = 0
+
+    img_arr = img_arr / 255.0
+    x_predict = img_arr[tf.newaxis, ...]
+    result = model.predict(x_predict)
+    pred = tf.argmax(result, axis=1)
+    if result[0][int(pred[0])] <= 0.3:
+        print(result)
+        print("Cannot determine the digit; please try again!")
+    else:
+        print(result)
+        tf.print(pred[0])
+
+
diff --git a/code/read_image.py b/code/read_image.py
new file mode 100644
index 0000000..f0c5a4d
--- /dev/null
+++ b/code/read_image.py
@@ -0,0 +1,95 @@
+from PIL import Image
+import numpy as np
+import tensorflow as tf
+import mnist_cnn3, mnist_cnn, mnist_dense
+
+
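+# Dispatch to the predictor that matches the selected model (1 = dense, 2 = mnist_cnn, 3 = mnist_cnn3).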
+def predeiction(image_path, model):
+    if model == 1:
+        res = prediction_1(image_path)
+    elif model == 2:
+        res = prediction_2(image_path)
+    else:
+        res = prediction_3(image_path)
+    return res
+
+
+# preNum = int(input("input the number of test pictures:"))
+def prediction_1(image_path):
+    model_save_path = "./checkpoint/mnist.ckpt"
+    model = mnist_dense.creat_model()
+
+    model.load_weights(model_save_path)
+    img = Image.open(image_path)
+    img = img.resize((28, 28), Image.Resampling.LANCZOS)
+    img_arr = np.array(img.convert('L'))
+    #
+    # for i in range(28):
+    #     for j in range(28):
+    #         if img_arr[i][j] < 200:
+    #             img_arr[i][j] = 255
+    #         else:
+    #             img_arr[i][j] = 0
+
+    img_arr = img_arr / 255.0
+    x_predict = img_arr[tf.newaxis, ...]
+    result = model.predict(x_predict)
+    pred = tf.argmax(result, axis=1)
+    return pred.numpy()[0]
+
+
+
+def prediction_2(image_path):
+    model_save_path = "./checkpoint/mnist_cnn1.ckpt"
+    model = mnist_cnn.creat_model()
+
+    model.load_weights(model_save_path)
+    img = Image.open(image_path)
+    img = img.resize((28, 28), Image.Resampling.LANCZOS)
+    img_arr = np.array(img.convert('L'))
+    #
+    # for i in range(28):
+    #     for j in range(28):
+    #         if img_arr[i][j] < 200:
+    #             img_arr[i][j] = 255
+    #         else:
+    #             img_arr[i][j] = 0
+
+    img_arr = img_arr / 255.0
+    x_predict = img_arr[tf.newaxis, ...]
+    x_predict = tf.expand_dims(x_predict, -1)
+    result = model.predict(x_predict)
+    print(result)
+    pred = tf.argmax(result, axis=1)
+    return pred.numpy()[0]
+
+
+def prediction_3(image_path):
+    model_save_path = "./checkpoint/mnist_cnn3.ckpt"
+    model = mnist_cnn3.creat_model()
+
+    model.load_weights(model_save_path)
+    img = Image.open(image_path)
+    img = img.resize((28, 28), Image.Resampling.LANCZOS)
+    img_arr = np.array(img.convert('L'))
+    #
+    # for i in range(28):
+    #     for j in range(28):
+    #         if img_arr[i][j] < 200:
+    #             img_arr[i][j] = 255
+    #         else:
+    #             img_arr[i][j] = 0
+
+    img_arr = img_arr / 255.0
+    x_predict = img_arr[tf.newaxis, ...]
+    result = model.predict(x_predict)
+
+    pred = tf.argmax(result, axis=1)
+    return pred.numpy()[0]
diff --git a/code/test_cnn3.py b/code/test_cnn3.py
new file mode 100644
index 0000000..c3a291c
--- /dev/null
+++ b/code/test_cnn3.py
@@ -0,0 +1,35 @@
+from PIL import Image
+import numpy as np
+import tensorflow as tf
+import mnist_cnn3
+
+
+model_save_path = 'E:\\Python_touge\\MNIST\\venv\\checkpoint\\mnist_cnn3.ckpt'
+model = mnist_cnn3.creat_model()
+model.load_weights(model_save_path)
+
+# preNum = int(input("input the number of test pictures:"))
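+# Predict a single image with the mnist_cnn3 model; return a warning string when the top probability is at or below 0.8.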
+def prediction(image_path):
+    img = Image.open(image_path)
+    img = img.resize((28, 28), Image.Resampling.LANCZOS)
+    img_arr = np.array(img.convert('L'))
+    #
+    # for i in range(28):
+    #     for j in range(28):
+    #         if img_arr[i][j] < 200:
+    #             img_arr[i][j] = 255
+    #         else:
+    #             img_arr[i][j] = 0
+
+    img_arr = img_arr / 255.0
+    x_predict = img_arr[tf.newaxis, ...]
+    result = model.predict(x_predict)
+    pred = tf.argmax(result, axis=1)
+    if result[0][int(pred[0])] <= 0.8:
+        print(result)
+        return "Cannot determine the digit; please try again!"
+    else:
+        return pred.numpy()[0]