master
qidangge 3 years ago
parent b8dad55c70
commit 726785cf81

@ -0,0 +1,3 @@
# 默认忽略的文件
/shelf/
/workspace.xml

@ -0,0 +1 @@
mnist_model1.py

@ -0,0 +1,10 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$">
<excludeFolder url="file://$MODULE_DIR$/venv" />
</content>
<orderEntry type="jdk" jdkName="Python 3.7 (tf2)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7 (tf2)" project-jdk-type="Python SDK" />
</project>

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/MNIST.iml" filepath="$PROJECT_DIR$/.idea/MNIST.iml" />
</modules>
</component>
</project>

@ -1,162 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ChangeListManager">
<list default="true" id="8de3a41d-bd06-410a-83fb-215f3a5dfdde" name="默认变更列表" comment="" />
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
<option name="LAST_RESOLUTION" value="IGNORE" />
</component>
<component name="FileTemplateManagerImpl">
<option name="RECENT_TEMPLATES">
<list>
<option value="Python Script" />
</list>
</option>
</component>
<component name="GitSEFilterConfiguration">
<file-type-list>
<filtered-out-file-type name="LOCAL_BRANCH" />
<filtered-out-file-type name="REMOTE_BRANCH" />
<filtered-out-file-type name="TAG" />
<filtered-out-file-type name="COMMIT_BY_MESSAGE" />
</file-type-list>
</component>
<component name="ProjectId" id="27IAlN7gLP6kkoWBytmc7WwZmCf" />
<component name="ProjectViewState">
<option name="hideEmptyMiddlePackages" value="true" />
<option name="showLibraryContents" value="true" />
</component>
<component name="PropertiesComponent">
<property name="RunOnceActivity.OpenProjectViewOnStart" value="true" />
<property name="RunOnceActivity.ShowReadmeOnStart" value="true" />
<property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
</component>
<component name="RunManager" selected="Python.sc2">
<configuration name="main" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="MNIST" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/main.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="mnist_cnn" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="MNIST" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/mnist_cnn.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="mnist_cnn3 (1)" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="MNIST" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/mnist_cnn3.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="mnist_cnn3" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="MNIST" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/venv" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/venv/mnist_cnn3.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="sc2" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="MNIST" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/sc2.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<recent_temporary>
<list>
<item itemvalue="Python.sc2" />
<item itemvalue="Python.main" />
<item itemvalue="Python.mnist_cnn" />
<item itemvalue="Python.mnist_cnn3 (1)" />
<item itemvalue="Python.mnist_cnn3" />
</list>
</recent_temporary>
</component>
<component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="应用程序级" UseSingleDictionary="true" transferred="true" />
<component name="TaskManager">
<task active="true" id="Default" summary="默认任务">
<changelist id="8de3a41d-bd06-410a-83fb-215f3a5dfdde" name="默认变更列表" comment="" />
<created>1648997377373</created>
<option name="number" value="Default" />
<option name="presentableId" value="Default" />
<updated>1648997377373</updated>
</task>
<servers />
</component>
</project>

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.9 KiB

@ -0,0 +1,15 @@
def isPrime(n):
# 判断数字是否为素数
# 请在此处添加代码 #
# *************begin************#
if n<2:
return False
if n==2:
return True
if n%2 == 0:
return False
for i in range(2,n):
if n%i == 0:
return False
return True
print(isPrime(10))

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.6 KiB

@ -0,0 +1,2 @@
model_checkpoint_path: "mnist.ckpt"
all_model_checkpoint_paths: "mnist.ckpt"

Binary file not shown.

@ -0,0 +1,116 @@
from tkinter import *
import cv2
from PIL import ImageGrab
from tkinter import filedialog
import read_image
model=3
def model_1():
global model
model = 1
print(model)
def model_2():
global model
model = 2
print(model)
def model_3():
global model
model = 3
print(model)
def paint(event):
x1, y1 = (event.x - 20), (event.y - 20)
x2, y2 = (event.x + 20), (event.y + 20)
w.create_oval(x1, y1, x2, y2, fill="white", outline='white')
def open_image():
image_name = filedialog.askopenfilename(title='打开图片', filetypes=[('jpg,jpeg', '*.jpg')])
image_show = cv2.imread(image_name, cv2.IMREAD_GRAYSCALE)
cv2.imshow("image", image_show)
print(model)
result = read_image.predeiction(image_name, model)
text.set(str(result))
def screenshot(*args):
a = root.winfo_x()
b = root.winfo_y()
a = a + 10
b = b + 35
bbox = (a, b, a + 395, b + 395)
im = ImageGrab.grab(bbox)
im.save('1.jpg')
print("在使用%d模型"%model)
result = read_image.predeiction('1.jpg', model)
text.set(str(result))
def clear_canvas(event):
x1, y1 = (event.x - 2800), (event.y - 2800)
x2, y2 = (event.x + 2800), (event.y + 2800)
w.create_oval(x1, y1, x2, y2, fill="black", outline='black')
def reset_canvas():
a = root.winfo_x()
b = root.winfo_y()
x1, y1 = (a - 2800), (b - 2800)
x2, y2 = (a + 2800), (b + 2800)
w.create_oval(x1, y1, x2, y2, fill="black", outline='black')
root = Tk()
root.geometry('600x400') # 规定窗口大小600*400像素
root.resizable(False, False) # 规定窗口不可缩放
root.title('数字识别')
col_count, row_count = root.grid_size()
for col in range(col_count):
root.grid_columnconfigure(col, minsize=10)
for row in range(row_count):
root.grid_rowconfigure(row, minsize=20)
text = StringVar()
text.set('')
w = Canvas(root, width=400, height=400, bg='black')
w.grid(row=0, column=0, rowspan=6)
label_1 = Label(root, text=' 识别的结果为:', font=('', 20))
label_1.grid(row=0, column=1)
result_label = Label(root, textvariable=text, font=('', 25), height=2, fg='red')
result_label.grid(row=1, column=1)
try_button = Button(root, text='模型1', width=7, height=2, command=model_1)
try_button.grid(row=2, column=1,sticky=W)
try_button = Button(root, text='模型2', width=7, height=2, command=model_2)
try_button.grid(row=2, column=1)
try_button = Button(root, text='模型3', width=7, height=2, command=model_3)
try_button.grid(row=2, column=1,sticky=E)
try_button = Button(root, text='开始识别', width=15, height=2, command=screenshot)
try_button.grid(row=3, column=1)
clear_button = Button(root, text='清空画布', width=15, height=2, command=reset_canvas)
clear_button.grid(row=4, column=1)
load_image_button = Button(root, text='来自图片', width=15, height=2, command=open_image)
load_image_button.grid(row=5, column=1)
w.bind("<B1-Motion>", paint)
w.bind("<Button-3>", screenshot)
w.bind("<Double-Button-1>", clear_canvas)
mainloop()

@ -0,0 +1,66 @@
import tensorflow as tf
import os
from matplotlib import pyplot as plt
from PIL import Image
import numpy as np
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train/255.0, x_test/255.0
def creat_model():
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=["sparse_categorical_accuracy"])
return model
def model_fit(model,check_save_path):
if os.path.exists(check_save_path+'.index'):
print("load modals...")
model.load_weights(check_save_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=check_save_path,
save_weights_only=True,
save_best_only=True)
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_split=0.15, validation_freq=1, callbacks=[cp_callback])
model.summary()
final_loss, final_acc = model.evaluate(x_test, y_test, verbose=2)
print("Model accuracy: ", final_acc, ", model loss: ", final_loss)
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
def eva_acc(str, model):
model.load_weights(str)
final_loss, final_acc = model.evaluate(x_test, y_test, verbose=2)
print("Model accuracy: ", final_acc, ", model loss: ", final_loss)
if __name__=="__main__":
check_save_path = "./checkpoint/mnist.ckpt"
model = creat_model()
model_fit(model, check_save_path)
eva_acc(check_save_path, model)

@ -0,0 +1,120 @@
import tensorflow.keras as keras
import numpy as np
import os
import tensorflow as tf
import matplotlib.pyplot as plt
mnist = keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train/255.0
x_test = x_test/255.0
x_train = tf.expand_dims(x_train, -1)
x_test = tf.expand_dims(x_test, -1)
print("train shape:", x_train.shape)
print("test shape:", x_test.shape)
# 使用此类进行图形增强
datagen = keras.preprocessing.image.ImageDataGenerator(
rotation_range=20, # 整数。随机旋转的度数范围。
width_shift_range=0.20, # 浮点数,图片宽度的某个比例,数据提升时图片随机水平偏移的幅度。
shear_range=15, # 浮点数,剪切强度(逆时针方向的剪切变换角度)。是用来进行剪切变换的程度。
zoom_range=0.10, # 浮点数或形如[lower,upper]的列表,随机缩放的幅度,若为浮点数,则相当于[lower,upper] = [1 - zoom_range, 1+zoom_range]。用来进行随机的放大。
validation_split=0.15, # 浮点型。保留用于验证集的图像比例严格在0,1之间
horizontal_flip=False # 布尔值,随机水平翻转。
)
train_datagen = datagen.flow(
x_train,
y_train,
batch_size=256,
subset="training"
)
validation_genetor = datagen.flow(
x_train,
y_train,
batch_size=64,
subset="validation"
)
def creat_model():
model = keras.Sequential([
keras.layers.Reshape((28, 28, 1)),
keras.layers.Conv2D(filters=32, kernel_size=(5, 5), activation="relu", padding="same",
input_shape=(28, 28, 1)),
keras.layers.MaxPool2D((2, 2)),
keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation="relu", padding="same"),
keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation="relu", padding="same"),
keras.layers.MaxPool2D((2, 2)),
keras.layers.Conv2D(filters=128, kernel_size=(3, 3), activation="relu", padding="same"),
keras.layers.Conv2D(filters=128, kernel_size=(3, 3), activation="relu", padding="same"),
keras.layers.MaxPool2D((2, 2)),
keras.layers.Flatten(),
keras.layers.Dense(512, activation="sigmoid"),
keras.layers.Dropout(0.25),
keras.layers.Dense(512, activation="sigmoid"),
keras.layers.Dropout(0.25),
keras.layers.Dense(256, activation="sigmoid"),
keras.layers.Dropout(0.1),
keras.layers.Dense(10, activation="sigmoid")
])
model.compile(optimizer='adam',
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=["sparse_categorical_accuracy"])
return model
def model_fit(model, check_save_path):
if os.path.exists(check_save_path+'.index'):
print("load modals...")
model.load_weights(check_save_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=check_save_path,
save_weights_only=True,
save_best_only=True)
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
factor=0.1,
patience=5,
min_lr=0.000001,
verbose=1)
history = model.fit(train_datagen, epochs=1, validation_data=validation_genetor, callbacks=[reduce_lr,cp_callback],verbose=1)
model.summary()
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
def model_valtest(model, check_save_path):
model.load_weights(check_save_path)
final_loss, final_acc = model.evaluate(x_test, y_test, verbose=2)
print("Model accuracy: ", final_acc, ", model loss: ", final_loss)
if __name__ == "__main__":
check_save_path = "./checkpoint/mnist_cnn3.ckpt"
model = creat_model()
# model_fit(model, check_save_path)
model_valtest(model, check_save_path)

@ -0,0 +1,66 @@
import tensorflow as tf
import os
from matplotlib import pyplot as plt
from PIL import Image
import numpy as np
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train/255.0, x_test/255.0
def creat_model():
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=["sparse_categorical_accuracy"])
return model
def model_fit(model,check_save_path):
if os.path.exists(check_save_path+'.index'):
print("load modals...")
model.load_weights(check_save_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=check_save_path,
save_weights_only=True,
save_best_only=True)
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_split=0.15, validation_freq=1, callbacks=[cp_callback])
model.summary()
final_loss, final_acc = model.evaluate(x_test, y_test, verbose=2)
print("Model accuracy: ", final_acc, ", model loss: ", final_loss)
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
def eva_acc(str, model):
model.load_weights(str)
final_loss, final_acc = model.evaluate(x_test, y_test, verbose=2)
print("Model accuracy: ", final_acc, ", model loss: ", final_loss)
if __name__=="__main__":
check_save_path = "./checkpoint/mnist.ckpt"
model = creat_model()
model_fit(model, check_save_path)
eva_acc(check_save_path, model)

@ -0,0 +1,11 @@
import tensorflow as tf
import os
mnist = tf.keras.datasets.mnist
from matplotlib import pyplot as plt
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train/255.0, x_test/255.0
print(y_test)
plt.imshow(x_train[0], cmap="gray")
plt.show()

@ -0,0 +1,40 @@
from PIL import Image
import numpy as np
import tensorflow as tf
model_save_path = './checkpoint/mnist.ckpt'
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')])
model.load_weights(model_save_path)
preNum = int(input("input the number of test pictures:"))
for i in range(preNum):
image_path = input("the path of test picture:")
img = Image.open(image_path)
img = img.resize((28, 28), Image.Resampling.LANCZOS)
img_arr = np.array(img.convert('L'))
#
# for i in range(28):
# for j in range(28):
# if img_arr[i][j] < 200:
# img_arr[i][j] = 255
# else:
# img_arr[i][j] = 0
img_arr = img_arr / 255.0
x_predict = img_arr[tf.newaxis, ...]
result = model.predict(x_predict)
pred = tf.argmax(result, axis=1)
if result[0][pred] <= 0.3:
print(result)
print("无法判断,请重新输入!")
else:
print(result)
tf.print(pred[0])

@ -0,0 +1,95 @@
from PIL import Image
import numpy as np
import tensorflow as tf
import mnist_cnn3,mnist_cnn,mnist_dense
def predeiction(str, model):
if model==1:
res = prediction_1(str)
elif model==2:
res = prediction_2(str)
elif model==3:
res = prediction_3(str)
return res
# preNum = int(input("input the number of test pictures:"))
def prediction_1(str):
model_save_path = "./checkpoint/mnist.ckpt"
model = mnist_dense.creat_model()
model.load_weights(model_save_path)
for i in range(1):
image_path = str
img = Image.open(image_path)
img = img.resize((28, 28), Image.Resampling.LANCZOS)
img_arr = np.array(img.convert('L'))
#
# for i in range(28):
# for j in range(28):
# if img_arr[i][j] < 200:
# img_arr[i][j] = 255
# else:
# img_arr[i][j] = 0
img_arr = img_arr / 255.0
x_predict = img_arr[tf.newaxis, ...]
result = model.predict(x_predict)
pred = tf.argmax(result, axis=1)
return pred.numpy()[0]
def prediction_2(str):
model_save_path = "./checkpoint/mnist_cnn1.ckpt"
model = mnist_cnn.creat_model()
model.load_weights(model_save_path)
for i in range(1):
image_path = str
img = Image.open(image_path)
img = img.resize((28, 28), Image.Resampling.LANCZOS)
img_arr = np.array(img.convert('L'))
#
# for i in range(28):
# for j in range(28):
# if img_arr[i][j] < 200:
# img_arr[i][j] = 255
# else:
# img_arr[i][j] = 0
img_arr = img_arr / 255.0
x_predict = img_arr[tf.newaxis, ...]
x_predict = tf.expand_dims(x_predict, -1)
result = model.predict(x_predict)
print(result)
pred = tf.argmax(result, axis=1)
return pred.numpy()[0]
def prediction_3(str):
model_save_path = "./checkpoint/mnist_cnn3.ckpt"
model = mnist_cnn3.creat_model()
model.load_weights(model_save_path)
for i in range(1):
image_path = str
img = Image.open(image_path)
img = img.resize((28, 28), Image.Resampling.LANCZOS)
img_arr = np.array(img.convert('L'))
#
# for i in range(28):
# for j in range(28):
# if img_arr[i][j] < 200:
# img_arr[i][j] = 255
# else:
# img_arr[i][j] = 0
img_arr = img_arr / 255.0
x_predict = img_arr[tf.newaxis, ...]
result = model.predict(x_predict)
pred = tf.argmax(result, axis=1)
return pred.numpy()[0]

@ -0,0 +1,35 @@
from PIL import Image
import numpy as np
import tensorflow as tf
import mnist_cnn3
model_save_path = 'E:\\Python_touge\\MNIST\\venv\\checkpoint\\mnist_cnn3.ckpt'
model = mnist_cnn3.creat_model()
model.load_weights(model_save_path)
# preNum = int(input("input the number of test pictures:"))
def prediction(str):
for i in range(1):
image_path = str
img = Image.open(image_path)
img = img.resize((28, 28), Image.Resampling.LANCZOS)
img_arr = np.array(img.convert('L'))
#
# for i in range(28):
# for j in range(28):
# if img_arr[i][j] < 200:
# img_arr[i][j] = 255
# else:
# img_arr[i][j] = 0
img_arr = img_arr / 255.0
x_predict = img_arr[tf.newaxis, ...]
result = model.predict(x_predict)
pred = tf.argmax(result, axis=1)
if result[(0, pred)] <= 0.8:
str = "无法判断,请重新输入!"
print(result)
return str
else:
return pred.numpy()[0]
Loading…
Cancel
Save