You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

159 lines
4.3 KiB

import sys
import numpy as np
from PIL import Image, ImageQt
from PyQt5.QtCore import QSize
from PyQt5.QtGui import QPixmap, QColor
from PyQt5.QtWidgets import QMainWindow, QDesktopWidget, QApplication
from PyQt5.QtWidgets import QMessageBox
from common.functions import softmax
from dataset.mnist import load_mnist
from qt.layout import Ui_MainWindow
from qt.paintboard import PaintBoard
from simple_convnet import SimpleConvNet
# Input mode: handwriting on the paint board.
MODE_WRITE = 2
# Confidence threshold for recognition results (not referenced by the code
# visible in this file).
Thresh = 0.5

# Load the MNIST dataset; only the test images are kept. The training split
# and all labels are discarded. With flatten=False the images presumably keep
# their (1, 28, 28) channel/height/width layout -- confirm in dataset.mnist.
(_, _), (x_test, _) = load_mnist(flatten=False, normalize=True, one_hot_label=False)

# Build the simple CNN and restore trained weights. These hyperparameters
# presumably have to match the architecture params.pkl was trained with.
network = SimpleConvNet(
    input_dim=(1, 28, 28),
    conv_param={
        'filter_num': 30,
        'filter_size': 5,
        'pad': 0,
        'stride': 1,
    },
    hidden_size=100,
    output_size=10,
    weight_init_std=0.01,
)
network.load_params("params.pkl")
class MainWindow(QMainWindow, Ui_MainWindow):
    """Main window of the handwritten-digit recognition demo.

    Hosts a ``PaintBoard`` the user draws on, plus labels from the generated
    ``Ui_MainWindow`` layout. Those labels show the predicted digit and its
    softmax confidence. Predictions come from the module-level ``network``.
    """

    def __init__(self):
        super().__init__()
        # Last prediction as [digit, confidence]; reset by clearDataArea().
        self.result = [0, 0]
        # Build the widgets declared in the generated layout, then center.
        self.setupUi(self)
        self.center()
        # 224x224 drawing surface. Fill and pen start fully transparent; they
        # become visible once cbBox_Mode_Callback switches to handwriting.
        self.paintBoard = PaintBoard(self, Size=QSize(224, 224), Fill=QColor(0, 0, 0, 0))
        self.paintBoard.setPenColor(QColor(0, 0, 0, 0))
        self.dArea_Layout.addWidget(self.paintBoard)
        self.clearDataArea()

    def center(self):
        """Move the window so it is centered on the available screen area."""
        frame = self.frameGeometry()
        screen_center = QDesktopWidget().availableGeometry().center()
        frame.moveCenter(screen_center)
        self.move(frame.topLeft())

    def closeEvent(self, event):
        """Ask the user to confirm before closing; ignore the close if declined."""
        reply = QMessageBox.question(self, '消息',
                                     "确定退出吗?", QMessageBox.Yes |
                                     QMessageBox.No, QMessageBox.Yes)
        if reply == QMessageBox.Yes:
            event.accept()
        else:
            event.ignore()

    def clearDataArea(self):
        """Clear the drawing board, the data area and the result labels."""
        self.paintBoard.Clear()
        self.lbDataArea.clear()
        self.lbResult.clear()
        self.lbCofidence.clear()
        self.result = [0, 0]

    # ---- Qt signal callbacks -------------------------------------------

    def cbBox_Mode_Callback(self, text):
        """Handle a selection change in the mode combo box.

        Only handwriting mode exists. The board is cleared and switched to an
        opaque black background with a white pen (presumably to mimic MNIST's
        white-on-black digits). ``text`` (the selected item) is not used.
        """
        self.mode = MODE_WRITE
        self.clearDataArea()
        self.paintBoard.setBoardFill(QColor(0, 0, 0, 255))
        self.paintBoard.setPenColor(QColor(255, 255, 255, 255))

    def pbtClear_Callback(self):
        """Clear button: reset the board and all result fields."""
        self.clearDataArea()

    def pbtPredict_Callback(self):
        """Recognize button: classify the current drawing.

        Downsamples the board to 28x28 grayscale and runs the network. The
        predicted digit and its softmax probability are stored in
        ``self.result`` and shown in the result labels.
        """
        qimage = self.paintBoard.getContentAsQImage()
        pil_img = ImageQt.fromqimage(qimage)
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
        # (ANTIALIAS was only an alias for it) and exists in older versions too.
        pil_img = pil_img.resize((28, 28), Image.LANCZOS)
        # Reshape to the network input layout (1, 1, 28, 28) and scale to
        # [0, 1], presumably matching load_mnist(normalize=True).
        img_array = np.array(pil_img.convert('L')).reshape(1, 1, 28, 28) / 255.0

        scores = network.predict(img_array)  # shape: [1, 10]
        probs = softmax(scores)
        digit = int(np.argmax(probs))
        confidence = probs[0, digit]

        self.result[0] = digit
        self.result[1] = confidence
        self.lbResult.setText("%d" % digit)
        self.lbCofidence.setText("%.8f" % confidence)

    # Disabled: show a random MNIST test image, enlarged, in the data area.
    # def pbtGetMnist_Callback(self):
    #     self.clearDataArea()
    #
    #     # Pick a random test-set image (randint's upper bound is exclusive,
    #     # so len(x_test) makes every image eligible).
    #     img = x_test[np.random.randint(0, len(x_test))]  # shape: [1, 28, 28]
    #     img = img.reshape(28, 28)  # shape: [28, 28]
    #
    #     img = img * 0xff  # restore 0-255 grayscale values
    #     pil_img = Image.fromarray(np.uint8(img))
    #     pil_img = pil_img.resize((224, 224))  # enlarge for display
    #
    #     # Convert the PIL image to a QImage
    #     qimage = ImageQt.ImageQt(pil_img)
    #
    #     # Show the QImage in the label
    #     pix = QPixmap.fromImage(qimage)
    #     self.lbDataArea.setPixmap(pix)
if __name__ == "__main__":
    # Standard Qt bootstrap: create the application, show the main window,
    # run the event loop, and pass its exit status back to the shell.
    app = QApplication(sys.argv)
    main_window = MainWindow()
    main_window.show()
    exit_code = app.exec_()
    sys.exit(exit_code)