parent
b8dad55c70
commit
726785cf81
@ -0,0 +1,3 @@
|
||||
# 默认忽略的文件
|
||||
/shelf/
|
||||
/workspace.xml
|
@ -0,0 +1 @@
|
||||
mnist_model1.py
|
@ -0,0 +1,10 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$">
|
||||
<excludeFolder url="file://$MODULE_DIR$/venv" />
|
||||
</content>
|
||||
<orderEntry type="jdk" jdkName="Python 3.7 (tf2)" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
@ -0,0 +1,6 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
@ -0,0 +1,4 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7 (tf2)" project-jdk-type="Python SDK" />
|
||||
</project>
|
@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/MNIST.iml" filepath="$PROJECT_DIR$/.idea/MNIST.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
@ -1,162 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ChangeListManager">
|
||||
<list default="true" id="8de3a41d-bd06-410a-83fb-215f3a5dfdde" name="默认变更列表" comment="" />
|
||||
<option name="SHOW_DIALOG" value="false" />
|
||||
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
||||
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
|
||||
<option name="LAST_RESOLUTION" value="IGNORE" />
|
||||
</component>
|
||||
<component name="FileTemplateManagerImpl">
|
||||
<option name="RECENT_TEMPLATES">
|
||||
<list>
|
||||
<option value="Python Script" />
|
||||
</list>
|
||||
</option>
|
||||
</component>
|
||||
<component name="GitSEFilterConfiguration">
|
||||
<file-type-list>
|
||||
<filtered-out-file-type name="LOCAL_BRANCH" />
|
||||
<filtered-out-file-type name="REMOTE_BRANCH" />
|
||||
<filtered-out-file-type name="TAG" />
|
||||
<filtered-out-file-type name="COMMIT_BY_MESSAGE" />
|
||||
</file-type-list>
|
||||
</component>
|
||||
<component name="ProjectId" id="27IAlN7gLP6kkoWBytmc7WwZmCf" />
|
||||
<component name="ProjectViewState">
|
||||
<option name="hideEmptyMiddlePackages" value="true" />
|
||||
<option name="showLibraryContents" value="true" />
|
||||
</component>
|
||||
<component name="PropertiesComponent">
|
||||
<property name="RunOnceActivity.OpenProjectViewOnStart" value="true" />
|
||||
<property name="RunOnceActivity.ShowReadmeOnStart" value="true" />
|
||||
<property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
|
||||
</component>
|
||||
<component name="RunManager" selected="Python.sc2">
|
||||
<configuration name="main" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
|
||||
<module name="MNIST" />
|
||||
<option name="INTERPRETER_OPTIONS" value="" />
|
||||
<option name="PARENT_ENVS" value="true" />
|
||||
<envs>
|
||||
<env name="PYTHONUNBUFFERED" value="1" />
|
||||
</envs>
|
||||
<option name="SDK_HOME" value="" />
|
||||
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
|
||||
<option name="IS_MODULE_SDK" value="true" />
|
||||
<option name="ADD_CONTENT_ROOTS" value="true" />
|
||||
<option name="ADD_SOURCE_ROOTS" value="true" />
|
||||
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/main.py" />
|
||||
<option name="PARAMETERS" value="" />
|
||||
<option name="SHOW_COMMAND_LINE" value="false" />
|
||||
<option name="EMULATE_TERMINAL" value="false" />
|
||||
<option name="MODULE_MODE" value="false" />
|
||||
<option name="REDIRECT_INPUT" value="false" />
|
||||
<option name="INPUT_FILE" value="" />
|
||||
<method v="2" />
|
||||
</configuration>
|
||||
<configuration name="mnist_cnn" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
|
||||
<module name="MNIST" />
|
||||
<option name="INTERPRETER_OPTIONS" value="" />
|
||||
<option name="PARENT_ENVS" value="true" />
|
||||
<envs>
|
||||
<env name="PYTHONUNBUFFERED" value="1" />
|
||||
</envs>
|
||||
<option name="SDK_HOME" value="" />
|
||||
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
|
||||
<option name="IS_MODULE_SDK" value="true" />
|
||||
<option name="ADD_CONTENT_ROOTS" value="true" />
|
||||
<option name="ADD_SOURCE_ROOTS" value="true" />
|
||||
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/mnist_cnn.py" />
|
||||
<option name="PARAMETERS" value="" />
|
||||
<option name="SHOW_COMMAND_LINE" value="false" />
|
||||
<option name="EMULATE_TERMINAL" value="false" />
|
||||
<option name="MODULE_MODE" value="false" />
|
||||
<option name="REDIRECT_INPUT" value="false" />
|
||||
<option name="INPUT_FILE" value="" />
|
||||
<method v="2" />
|
||||
</configuration>
|
||||
<configuration name="mnist_cnn3 (1)" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
|
||||
<module name="MNIST" />
|
||||
<option name="INTERPRETER_OPTIONS" value="" />
|
||||
<option name="PARENT_ENVS" value="true" />
|
||||
<envs>
|
||||
<env name="PYTHONUNBUFFERED" value="1" />
|
||||
</envs>
|
||||
<option name="SDK_HOME" value="" />
|
||||
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
|
||||
<option name="IS_MODULE_SDK" value="true" />
|
||||
<option name="ADD_CONTENT_ROOTS" value="true" />
|
||||
<option name="ADD_SOURCE_ROOTS" value="true" />
|
||||
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/mnist_cnn3.py" />
|
||||
<option name="PARAMETERS" value="" />
|
||||
<option name="SHOW_COMMAND_LINE" value="false" />
|
||||
<option name="EMULATE_TERMINAL" value="false" />
|
||||
<option name="MODULE_MODE" value="false" />
|
||||
<option name="REDIRECT_INPUT" value="false" />
|
||||
<option name="INPUT_FILE" value="" />
|
||||
<method v="2" />
|
||||
</configuration>
|
||||
<configuration name="mnist_cnn3" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
|
||||
<module name="MNIST" />
|
||||
<option name="INTERPRETER_OPTIONS" value="" />
|
||||
<option name="PARENT_ENVS" value="true" />
|
||||
<envs>
|
||||
<env name="PYTHONUNBUFFERED" value="1" />
|
||||
</envs>
|
||||
<option name="SDK_HOME" value="" />
|
||||
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/venv" />
|
||||
<option name="IS_MODULE_SDK" value="true" />
|
||||
<option name="ADD_CONTENT_ROOTS" value="true" />
|
||||
<option name="ADD_SOURCE_ROOTS" value="true" />
|
||||
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/venv/mnist_cnn3.py" />
|
||||
<option name="PARAMETERS" value="" />
|
||||
<option name="SHOW_COMMAND_LINE" value="false" />
|
||||
<option name="EMULATE_TERMINAL" value="false" />
|
||||
<option name="MODULE_MODE" value="false" />
|
||||
<option name="REDIRECT_INPUT" value="false" />
|
||||
<option name="INPUT_FILE" value="" />
|
||||
<method v="2" />
|
||||
</configuration>
|
||||
<configuration name="sc2" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
|
||||
<module name="MNIST" />
|
||||
<option name="INTERPRETER_OPTIONS" value="" />
|
||||
<option name="PARENT_ENVS" value="true" />
|
||||
<envs>
|
||||
<env name="PYTHONUNBUFFERED" value="1" />
|
||||
</envs>
|
||||
<option name="SDK_HOME" value="" />
|
||||
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
|
||||
<option name="IS_MODULE_SDK" value="true" />
|
||||
<option name="ADD_CONTENT_ROOTS" value="true" />
|
||||
<option name="ADD_SOURCE_ROOTS" value="true" />
|
||||
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/sc2.py" />
|
||||
<option name="PARAMETERS" value="" />
|
||||
<option name="SHOW_COMMAND_LINE" value="false" />
|
||||
<option name="EMULATE_TERMINAL" value="false" />
|
||||
<option name="MODULE_MODE" value="false" />
|
||||
<option name="REDIRECT_INPUT" value="false" />
|
||||
<option name="INPUT_FILE" value="" />
|
||||
<method v="2" />
|
||||
</configuration>
|
||||
<recent_temporary>
|
||||
<list>
|
||||
<item itemvalue="Python.sc2" />
|
||||
<item itemvalue="Python.main" />
|
||||
<item itemvalue="Python.mnist_cnn" />
|
||||
<item itemvalue="Python.mnist_cnn3 (1)" />
|
||||
<item itemvalue="Python.mnist_cnn3" />
|
||||
</list>
|
||||
</recent_temporary>
|
||||
</component>
|
||||
<component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="应用程序级" UseSingleDictionary="true" transferred="true" />
|
||||
<component name="TaskManager">
|
||||
<task active="true" id="Default" summary="默认任务">
|
||||
<changelist id="8de3a41d-bd06-410a-83fb-215f3a5dfdde" name="默认变更列表" comment="" />
|
||||
<created>1648997377373</created>
|
||||
<option name="number" value="Default" />
|
||||
<option name="presentableId" value="Default" />
|
||||
<updated>1648997377373</updated>
|
||||
</task>
|
||||
<servers />
|
||||
</component>
|
||||
</project>
|
After Width: | Height: | Size: 6.9 KiB |
@ -0,0 +1,15 @@
|
||||
def isPrime(n):
|
||||
# 判断数字是否为素数
|
||||
# 请在此处添加代码 #
|
||||
# *************begin************#
|
||||
if n<2:
|
||||
return False
|
||||
if n==2:
|
||||
return True
|
||||
if n%2 == 0:
|
||||
return False
|
||||
for i in range(2,n):
|
||||
if n%i == 0:
|
||||
return False
|
||||
return True
|
||||
print(isPrime(10))
|
After Width: | Height: | Size: 1.6 KiB |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,2 @@
|
||||
model_checkpoint_path: "mnist.ckpt"
|
||||
all_model_checkpoint_paths: "mnist.ckpt"
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,116 @@
|
||||
from tkinter import *
|
||||
|
||||
import cv2
|
||||
from PIL import ImageGrab
|
||||
from tkinter import filedialog
|
||||
|
||||
import read_image
|
||||
|
||||
model=3
|
||||
def model_1():
|
||||
global model
|
||||
model = 1
|
||||
print(model)
|
||||
|
||||
def model_2():
|
||||
global model
|
||||
model = 2
|
||||
print(model)
|
||||
def model_3():
|
||||
global model
|
||||
model = 3
|
||||
print(model)
|
||||
|
||||
def paint(event):
|
||||
x1, y1 = (event.x - 20), (event.y - 20)
|
||||
x2, y2 = (event.x + 20), (event.y + 20)
|
||||
w.create_oval(x1, y1, x2, y2, fill="white", outline='white')
|
||||
|
||||
|
||||
def open_image():
|
||||
image_name = filedialog.askopenfilename(title='打开图片', filetypes=[('jpg,jpeg', '*.jpg')])
|
||||
image_show = cv2.imread(image_name, cv2.IMREAD_GRAYSCALE)
|
||||
cv2.imshow("image", image_show)
|
||||
print(model)
|
||||
result = read_image.predeiction(image_name, model)
|
||||
text.set(str(result))
|
||||
|
||||
def screenshot(*args):
|
||||
a = root.winfo_x()
|
||||
b = root.winfo_y()
|
||||
a = a + 10
|
||||
b = b + 35
|
||||
bbox = (a, b, a + 395, b + 395)
|
||||
im = ImageGrab.grab(bbox)
|
||||
im.save('1.jpg')
|
||||
print("在使用%d模型"%model)
|
||||
result = read_image.predeiction('1.jpg', model)
|
||||
|
||||
text.set(str(result))
|
||||
|
||||
|
||||
def clear_canvas(event):
|
||||
x1, y1 = (event.x - 2800), (event.y - 2800)
|
||||
x2, y2 = (event.x + 2800), (event.y + 2800)
|
||||
w.create_oval(x1, y1, x2, y2, fill="black", outline='black')
|
||||
|
||||
|
||||
|
||||
def reset_canvas():
|
||||
a = root.winfo_x()
|
||||
b = root.winfo_y()
|
||||
x1, y1 = (a - 2800), (b - 2800)
|
||||
x2, y2 = (a + 2800), (b + 2800)
|
||||
w.create_oval(x1, y1, x2, y2, fill="black", outline='black')
|
||||
|
||||
|
||||
root = Tk()
|
||||
root.geometry('600x400') # 规定窗口大小600*400像素
|
||||
root.resizable(False, False) # 规定窗口不可缩放
|
||||
root.title('数字识别')
|
||||
|
||||
col_count, row_count = root.grid_size()
|
||||
|
||||
for col in range(col_count):
|
||||
root.grid_columnconfigure(col, minsize=10)
|
||||
|
||||
for row in range(row_count):
|
||||
root.grid_rowconfigure(row, minsize=20)
|
||||
|
||||
|
||||
text = StringVar()
|
||||
text.set('')
|
||||
w = Canvas(root, width=400, height=400, bg='black')
|
||||
w.grid(row=0, column=0, rowspan=6)
|
||||
label_1 = Label(root, text=' 识别的结果为:', font=('', 20))
|
||||
label_1.grid(row=0, column=1)
|
||||
result_label = Label(root, textvariable=text, font=('', 25), height=2, fg='red')
|
||||
result_label.grid(row=1, column=1)
|
||||
|
||||
try_button = Button(root, text='模型1', width=7, height=2, command=model_1)
|
||||
try_button.grid(row=2, column=1,sticky=W)
|
||||
try_button = Button(root, text='模型2', width=7, height=2, command=model_2)
|
||||
try_button.grid(row=2, column=1)
|
||||
try_button = Button(root, text='模型3', width=7, height=2, command=model_3)
|
||||
try_button.grid(row=2, column=1,sticky=E)
|
||||
|
||||
try_button = Button(root, text='开始识别', width=15, height=2, command=screenshot)
|
||||
try_button.grid(row=3, column=1)
|
||||
|
||||
clear_button = Button(root, text='清空画布', width=15, height=2, command=reset_canvas)
|
||||
clear_button.grid(row=4, column=1)
|
||||
|
||||
load_image_button = Button(root, text='来自图片', width=15, height=2, command=open_image)
|
||||
load_image_button.grid(row=5, column=1)
|
||||
|
||||
w.bind("<B1-Motion>", paint)
|
||||
w.bind("<Button-3>", screenshot)
|
||||
w.bind("<Double-Button-1>", clear_canvas)
|
||||
|
||||
mainloop()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -0,0 +1,66 @@
|
||||
import tensorflow as tf
|
||||
import os
|
||||
from matplotlib import pyplot as plt
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
mnist = tf.keras.datasets.mnist
|
||||
|
||||
(x_train, y_train), (x_test, y_test) = mnist.load_data()
|
||||
x_train, x_test = x_train/255.0, x_test/255.0
|
||||
|
||||
def creat_model():
|
||||
model = tf.keras.models.Sequential([
|
||||
tf.keras.layers.Flatten(),
|
||||
tf.keras.layers.Dense(128, activation='relu'),
|
||||
tf.keras.layers.Dense(10, activation='softmax')
|
||||
])
|
||||
|
||||
model.compile(optimizer='adam',
|
||||
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
|
||||
metrics=["sparse_categorical_accuracy"])
|
||||
return model
|
||||
def model_fit(model,check_save_path):
|
||||
|
||||
if os.path.exists(check_save_path+'.index'):
|
||||
print("load modals...")
|
||||
model.load_weights(check_save_path)
|
||||
|
||||
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=check_save_path,
|
||||
save_weights_only=True,
|
||||
save_best_only=True)
|
||||
|
||||
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_split=0.15, validation_freq=1, callbacks=[cp_callback])
|
||||
model.summary()
|
||||
|
||||
final_loss, final_acc = model.evaluate(x_test, y_test, verbose=2)
|
||||
print("Model accuracy: ", final_acc, ", model loss: ", final_loss)
|
||||
|
||||
acc = history.history['sparse_categorical_accuracy']
|
||||
val_acc = history.history['val_sparse_categorical_accuracy']
|
||||
loss = history.history['loss']
|
||||
val_loss = history.history['val_loss']
|
||||
|
||||
plt.subplot(1, 2, 1)
|
||||
plt.plot(acc, label='Training Accuracy')
|
||||
plt.plot(val_acc, label='Validation Accuracy')
|
||||
plt.title('Training and Validation Accuracy')
|
||||
plt.legend()
|
||||
|
||||
plt.subplot(1, 2, 2)
|
||||
plt.plot(loss, label='Training Loss')
|
||||
plt.plot(val_loss, label='Validation Loss')
|
||||
plt.title('Training and Validation Loss')
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
def eva_acc(str, model):
|
||||
model.load_weights(str)
|
||||
final_loss, final_acc = model.evaluate(x_test, y_test, verbose=2)
|
||||
print("Model accuracy: ", final_acc, ", model loss: ", final_loss)
|
||||
|
||||
if __name__=="__main__":
|
||||
check_save_path = "./checkpoint/mnist.ckpt"
|
||||
model = creat_model()
|
||||
model_fit(model, check_save_path)
|
||||
eva_acc(check_save_path, model)
|
@ -0,0 +1,66 @@
|
||||
import tensorflow as tf
|
||||
import os
|
||||
from matplotlib import pyplot as plt
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
mnist = tf.keras.datasets.mnist
|
||||
|
||||
(x_train, y_train), (x_test, y_test) = mnist.load_data()
|
||||
x_train, x_test = x_train/255.0, x_test/255.0
|
||||
|
||||
def creat_model():
|
||||
model = tf.keras.models.Sequential([
|
||||
tf.keras.layers.Flatten(),
|
||||
tf.keras.layers.Dense(128, activation='relu'),
|
||||
tf.keras.layers.Dense(10, activation='softmax')
|
||||
])
|
||||
|
||||
model.compile(optimizer='adam',
|
||||
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
|
||||
metrics=["sparse_categorical_accuracy"])
|
||||
return model
|
||||
def model_fit(model,check_save_path):
|
||||
|
||||
if os.path.exists(check_save_path+'.index'):
|
||||
print("load modals...")
|
||||
model.load_weights(check_save_path)
|
||||
|
||||
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=check_save_path,
|
||||
save_weights_only=True,
|
||||
save_best_only=True)
|
||||
|
||||
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_split=0.15, validation_freq=1, callbacks=[cp_callback])
|
||||
model.summary()
|
||||
|
||||
final_loss, final_acc = model.evaluate(x_test, y_test, verbose=2)
|
||||
print("Model accuracy: ", final_acc, ", model loss: ", final_loss)
|
||||
|
||||
acc = history.history['sparse_categorical_accuracy']
|
||||
val_acc = history.history['val_sparse_categorical_accuracy']
|
||||
loss = history.history['loss']
|
||||
val_loss = history.history['val_loss']
|
||||
|
||||
plt.subplot(1, 2, 1)
|
||||
plt.plot(acc, label='Training Accuracy')
|
||||
plt.plot(val_acc, label='Validation Accuracy')
|
||||
plt.title('Training and Validation Accuracy')
|
||||
plt.legend()
|
||||
|
||||
plt.subplot(1, 2, 2)
|
||||
plt.plot(loss, label='Training Loss')
|
||||
plt.plot(val_loss, label='Validation Loss')
|
||||
plt.title('Training and Validation Loss')
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
def eva_acc(str, model):
|
||||
model.load_weights(str)
|
||||
final_loss, final_acc = model.evaluate(x_test, y_test, verbose=2)
|
||||
print("Model accuracy: ", final_acc, ", model loss: ", final_loss)
|
||||
|
||||
if __name__=="__main__":
|
||||
check_save_path = "./checkpoint/mnist.ckpt"
|
||||
model = creat_model()
|
||||
model_fit(model, check_save_path)
|
||||
eva_acc(check_save_path, model)
|
@ -0,0 +1,11 @@
|
||||
import tensorflow as tf
|
||||
import os
|
||||
mnist = tf.keras.datasets.mnist
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
(x_train, y_train), (x_test, y_test) = mnist.load_data()
|
||||
x_train, x_test = x_train/255.0, x_test/255.0
|
||||
print(y_test)
|
||||
|
||||
plt.imshow(x_train[0], cmap="gray")
|
||||
plt.show()
|
@ -0,0 +1,40 @@
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
model_save_path = './checkpoint/mnist.ckpt'
|
||||
|
||||
model = tf.keras.models.Sequential([
|
||||
tf.keras.layers.Flatten(),
|
||||
tf.keras.layers.Dense(128, activation='relu'),
|
||||
tf.keras.layers.Dense(10, activation='softmax')])
|
||||
|
||||
model.load_weights(model_save_path)
|
||||
|
||||
preNum = int(input("input the number of test pictures:"))
|
||||
|
||||
for i in range(preNum):
|
||||
image_path = input("the path of test picture:")
|
||||
img = Image.open(image_path)
|
||||
img = img.resize((28, 28), Image.Resampling.LANCZOS)
|
||||
img_arr = np.array(img.convert('L'))
|
||||
#
|
||||
# for i in range(28):
|
||||
# for j in range(28):
|
||||
# if img_arr[i][j] < 200:
|
||||
# img_arr[i][j] = 255
|
||||
# else:
|
||||
# img_arr[i][j] = 0
|
||||
|
||||
img_arr = img_arr / 255.0
|
||||
x_predict = img_arr[tf.newaxis, ...]
|
||||
result = model.predict(x_predict)
|
||||
pred = tf.argmax(result, axis=1)
|
||||
if result[0][pred] <= 0.3:
|
||||
print(result)
|
||||
print("无法判断,请重新输入!")
|
||||
else:
|
||||
print(result)
|
||||
tf.print(pred[0])
|
||||
|
||||
|
@ -0,0 +1,95 @@
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import mnist_cnn3,mnist_cnn,mnist_dense
|
||||
|
||||
|
||||
def predeiction(str, model):
|
||||
if model==1:
|
||||
|
||||
res = prediction_1(str)
|
||||
elif model==2:
|
||||
res = prediction_2(str)
|
||||
elif model==3:
|
||||
res = prediction_3(str)
|
||||
return res
|
||||
|
||||
|
||||
# preNum = int(input("input the number of test pictures:"))
|
||||
def prediction_1(str):
|
||||
model_save_path = "./checkpoint/mnist.ckpt"
|
||||
model = mnist_dense.creat_model()
|
||||
|
||||
model.load_weights(model_save_path)
|
||||
for i in range(1):
|
||||
image_path = str
|
||||
img = Image.open(image_path)
|
||||
img = img.resize((28, 28), Image.Resampling.LANCZOS)
|
||||
img_arr = np.array(img.convert('L'))
|
||||
#
|
||||
# for i in range(28):
|
||||
# for j in range(28):
|
||||
# if img_arr[i][j] < 200:
|
||||
# img_arr[i][j] = 255
|
||||
# else:
|
||||
# img_arr[i][j] = 0
|
||||
|
||||
img_arr = img_arr / 255.0
|
||||
x_predict = img_arr[tf.newaxis, ...]
|
||||
result = model.predict(x_predict)
|
||||
pred = tf.argmax(result, axis=1)
|
||||
return pred.numpy()[0]
|
||||
|
||||
|
||||
|
||||
def prediction_2(str):
|
||||
model_save_path = "./checkpoint/mnist_cnn1.ckpt"
|
||||
model = mnist_cnn.creat_model()
|
||||
|
||||
model.load_weights(model_save_path)
|
||||
for i in range(1):
|
||||
image_path = str
|
||||
img = Image.open(image_path)
|
||||
img = img.resize((28, 28), Image.Resampling.LANCZOS)
|
||||
img_arr = np.array(img.convert('L'))
|
||||
#
|
||||
# for i in range(28):
|
||||
# for j in range(28):
|
||||
# if img_arr[i][j] < 200:
|
||||
# img_arr[i][j] = 255
|
||||
# else:
|
||||
# img_arr[i][j] = 0
|
||||
|
||||
img_arr = img_arr / 255.0
|
||||
x_predict = img_arr[tf.newaxis, ...]
|
||||
x_predict = tf.expand_dims(x_predict, -1)
|
||||
result = model.predict(x_predict)
|
||||
print(result)
|
||||
pred = tf.argmax(result, axis=1)
|
||||
return pred.numpy()[0]
|
||||
|
||||
|
||||
def prediction_3(str):
|
||||
model_save_path = "./checkpoint/mnist_cnn3.ckpt"
|
||||
model = mnist_cnn3.creat_model()
|
||||
|
||||
model.load_weights(model_save_path)
|
||||
for i in range(1):
|
||||
image_path = str
|
||||
img = Image.open(image_path)
|
||||
img = img.resize((28, 28), Image.Resampling.LANCZOS)
|
||||
img_arr = np.array(img.convert('L'))
|
||||
#
|
||||
# for i in range(28):
|
||||
# for j in range(28):
|
||||
# if img_arr[i][j] < 200:
|
||||
# img_arr[i][j] = 255
|
||||
# else:
|
||||
# img_arr[i][j] = 0
|
||||
|
||||
img_arr = img_arr / 255.0
|
||||
x_predict = img_arr[tf.newaxis, ...]
|
||||
result = model.predict(x_predict)
|
||||
|
||||
pred = tf.argmax(result, axis=1)
|
||||
return pred.numpy()[0]
|
@ -0,0 +1,35 @@
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import mnist_cnn3
|
||||
|
||||
|
||||
model_save_path = 'E:\\Python_touge\\MNIST\\venv\\checkpoint\\mnist_cnn3.ckpt'
|
||||
model = mnist_cnn3.creat_model()
|
||||
model.load_weights(model_save_path)
|
||||
|
||||
# preNum = int(input("input the number of test pictures:"))
|
||||
def prediction(str):
|
||||
for i in range(1):
|
||||
image_path = str
|
||||
img = Image.open(image_path)
|
||||
img = img.resize((28, 28), Image.Resampling.LANCZOS)
|
||||
img_arr = np.array(img.convert('L'))
|
||||
#
|
||||
# for i in range(28):
|
||||
# for j in range(28):
|
||||
# if img_arr[i][j] < 200:
|
||||
# img_arr[i][j] = 255
|
||||
# else:
|
||||
# img_arr[i][j] = 0
|
||||
|
||||
img_arr = img_arr / 255.0
|
||||
x_predict = img_arr[tf.newaxis, ...]
|
||||
result = model.predict(x_predict)
|
||||
pred = tf.argmax(result, axis=1)
|
||||
if result[(0, pred)] <= 0.8:
|
||||
str = "无法判断,请重新输入!"
|
||||
print(result)
|
||||
return str
|
||||
else:
|
||||
return pred.numpy()[0]
|
Loading…
Reference in new issue