parent
235db2d4fb
commit
34a1947052
@ -0,0 +1,159 @@
|
||||
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image, ImageQt
|
||||
from PyQt5.QtCore import QSize
|
||||
from PyQt5.QtGui import QPixmap, QColor
|
||||
from PyQt5.QtWidgets import QMainWindow, QDesktopWidget, QApplication
|
||||
from PyQt5.QtWidgets import QMessageBox
|
||||
|
||||
from common.functions import softmax
|
||||
from dataset.mnist import load_mnist
|
||||
from qt.layout import Ui_MainWindow
|
||||
from qt.paintboard import PaintBoard
|
||||
from simple_convnet import SimpleConvNet
|
||||
|
||||
MODE_WRITE = 2 # 手写输入
|
||||
|
||||
Thresh = 0.5 # 识别结果置信度阈值
|
||||
|
||||
|
||||
|
||||
# 读取MNIST数据集
|
||||
(_, _), (x_test, _) = load_mnist(normalize=True, flatten=False, one_hot_label=False)
|
||||
|
||||
|
||||
# 初始化网络
|
||||
|
||||
# 简单CNN
|
||||
|
||||
network = SimpleConvNet(input_dim=(1,28,28),
|
||||
conv_param = {'filter_num': 30, 'filter_size': 5, 'pad': 0, 'stride': 1},
|
||||
hidden_size=100, output_size=10, weight_init_std=0.01)
|
||||
network.load_params("params.pkl")
|
||||
|
||||
|
||||
|
||||
|
||||
class MainWindow(QMainWindow,Ui_MainWindow):
|
||||
def __init__(self):
|
||||
super(MainWindow,self).__init__()
|
||||
|
||||
# 初始化参数
|
||||
self.result = [0, 0]
|
||||
|
||||
# 初始化UI
|
||||
self.setupUi(self)
|
||||
self.center()
|
||||
|
||||
# 初始化画板
|
||||
self.paintBoard = PaintBoard(self, Size = QSize(224, 224), Fill = QColor(0,0,0,0))
|
||||
self.paintBoard.setPenColor(QColor(0,0,0,0))
|
||||
self.dArea_Layout.addWidget(self.paintBoard)
|
||||
|
||||
self.clearDataArea()
|
||||
|
||||
# 窗口居中
|
||||
def center(self):
|
||||
# 获得窗口
|
||||
framePos = self.frameGeometry()
|
||||
# 获得屏幕中心点
|
||||
scPos = QDesktopWidget().availableGeometry().center()
|
||||
# 显示到屏幕中心
|
||||
framePos.moveCenter(scPos)
|
||||
self.move(framePos.topLeft())
|
||||
|
||||
|
||||
# 窗口关闭事件
|
||||
def closeEvent(self, event):
|
||||
reply = QMessageBox.question(self, '消息',
|
||||
"确定退出吗?", QMessageBox.Yes |
|
||||
QMessageBox.No, QMessageBox.Yes)
|
||||
|
||||
if reply == QMessageBox.Yes:
|
||||
event.accept()
|
||||
else:
|
||||
event.ignore()
|
||||
|
||||
# 清除数据待输入区
|
||||
def clearDataArea(self):
|
||||
self.paintBoard.Clear()
|
||||
self.lbDataArea.clear()
|
||||
self.lbResult.clear()
|
||||
self.lbCofidence.clear()
|
||||
self.result = [0, 0]
|
||||
|
||||
"""
|
||||
回调函数
|
||||
"""
|
||||
# 模式下拉列表回调
|
||||
def cbBox_Mode_Callback(self, text):
|
||||
self.mode = MODE_WRITE
|
||||
self.clearDataArea()
|
||||
|
||||
# 更改背景
|
||||
self.paintBoard.setBoardFill(QColor(0,0,0,255))
|
||||
self.paintBoard.setPenColor(QColor(255,255,255,255))
|
||||
|
||||
|
||||
# 数据清除
|
||||
def pbtClear_Callback(self):
|
||||
self.clearDataArea()
|
||||
|
||||
|
||||
# 识别
|
||||
def pbtPredict_Callback(self):
|
||||
|
||||
__img = self.paintBoard.getContentAsQImage()
|
||||
|
||||
# 转换成pil image类型处理
|
||||
pil_img = ImageQt.fromqimage(__img)
|
||||
pil_img = pil_img.resize((28, 28), Image.ANTIALIAS)
|
||||
|
||||
|
||||
img_array = np.array(pil_img.convert('L')).reshape(1,1,28, 28) / 255.0
|
||||
# img_array = np.where(img_array>0.5, 1, 0)
|
||||
|
||||
# reshape成网络输入类型
|
||||
__result = network.predict(img_array) # shape:[1, 10]
|
||||
|
||||
# print (__result)
|
||||
|
||||
# 将预测结果使用softmax输出
|
||||
__result = softmax(__result)
|
||||
|
||||
self.result[0] = np.argmax(__result) # 预测的数字
|
||||
self.result[1] = __result[0, self.result[0]] # 置信度
|
||||
|
||||
self.lbResult.setText("%d" % (self.result[0]))
|
||||
self.lbCofidence.setText("%.8f" % (self.result[1]))
|
||||
|
||||
|
||||
# # 随机抽取
|
||||
# def pbtGetMnist_Callback(self):
|
||||
# self.clearDataArea()
|
||||
#
|
||||
# # 随机抽取一张测试集图片,放大后显示
|
||||
# img = x_test[np.random.randint(0, 9999)] # shape:[1,28,28]
|
||||
# img = img.reshape(28, 28) # shape:[28,28]
|
||||
#
|
||||
# img = img * 0xff # 恢复灰度值大小
|
||||
# pil_img = Image.fromarray(np.uint8(img))
|
||||
# pil_img = pil_img.resize((224, 224)) # 图像放大显示
|
||||
#
|
||||
# # 将pil图像转换成qimage类型
|
||||
# qimage = ImageQt.ImageQt(pil_img)
|
||||
#
|
||||
# # 将qimage类型图像显示在label
|
||||
# pix = QPixmap.fromImage(qimage)
|
||||
# self.lbDataArea.setPixmap(pix)
|
||||
#
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = QApplication(sys.argv)
|
||||
Gui = MainWindow()
|
||||
Gui.show()
|
||||
|
||||
sys.exit(app.exec_())
|
Loading…
Reference in new issue