diff --git a/mnist_cnn_gui_main.py b/mnist_cnn_gui_main.py new file mode 100644 index 0000000..b19673d --- /dev/null +++ b/mnist_cnn_gui_main.py @@ -0,0 +1,159 @@ + +import sys + +import numpy as np +from PIL import Image, ImageQt +from PyQt5.QtCore import QSize +from PyQt5.QtGui import QPixmap, QColor +from PyQt5.QtWidgets import QMainWindow, QDesktopWidget, QApplication +from PyQt5.QtWidgets import QMessageBox + +from common.functions import softmax +from dataset.mnist import load_mnist +from qt.layout import Ui_MainWindow +from qt.paintboard import PaintBoard +from simple_convnet import SimpleConvNet + +MODE_WRITE = 2 # 手写输入 + +Thresh = 0.5 # 识别结果置信度阈值 + + + +# 读取MNIST数据集 +(_, _), (x_test, _) = load_mnist(normalize=True, flatten=False, one_hot_label=False) + + +# 初始化网络 + +# 简单CNN + +network = SimpleConvNet(input_dim=(1,28,28), + conv_param = {'filter_num': 30, 'filter_size': 5, 'pad': 0, 'stride': 1}, + hidden_size=100, output_size=10, weight_init_std=0.01) +network.load_params("params.pkl") + + + + +class MainWindow(QMainWindow,Ui_MainWindow): + def __init__(self): + super(MainWindow,self).__init__() + + # 初始化参数 + self.result = [0, 0] + + # 初始化UI + self.setupUi(self) + self.center() + + # 初始化画板 + self.paintBoard = PaintBoard(self, Size = QSize(224, 224), Fill = QColor(0,0,0,0)) + self.paintBoard.setPenColor(QColor(0,0,0,0)) + self.dArea_Layout.addWidget(self.paintBoard) + + self.clearDataArea() + + # 窗口居中 + def center(self): + # 获得窗口 + framePos = self.frameGeometry() + # 获得屏幕中心点 + scPos = QDesktopWidget().availableGeometry().center() + # 显示到屏幕中心 + framePos.moveCenter(scPos) + self.move(framePos.topLeft()) + + + # 窗口关闭事件 + def closeEvent(self, event): + reply = QMessageBox.question(self, '消息', + "确定退出吗?", QMessageBox.Yes | + QMessageBox.No, QMessageBox.Yes) + + if reply == QMessageBox.Yes: + event.accept() + else: + event.ignore() + + # 清除数据待输入区 + def clearDataArea(self): + self.paintBoard.Clear() + self.lbDataArea.clear() + self.lbResult.clear() + self.lbCofidence.clear() + self.result = [0, 0] + + """ + 回调函数 + """ + # 模式下拉列表回调 + def cbBox_Mode_Callback(self, text): + self.mode = MODE_WRITE + self.clearDataArea() + + # 更改背景 + self.paintBoard.setBoardFill(QColor(0,0,0,255)) + self.paintBoard.setPenColor(QColor(255,255,255,255)) + + + # 数据清除 + def pbtClear_Callback(self): + self.clearDataArea() + + + # 识别 + def pbtPredict_Callback(self): + + __img = self.paintBoard.getContentAsQImage() + + # 转换成pil image类型处理 + pil_img = ImageQt.fromqimage(__img) + pil_img = pil_img.resize((28, 28), Image.ANTIALIAS) + + + img_array = np.array(pil_img.convert('L')).reshape(1,1,28, 28) / 255.0 + # img_array = np.where(img_array>0.5, 1, 0) + + # reshape成网络输入类型 + __result = network.predict(img_array) # shape:[1, 10] + + # print (__result) + + # 将预测结果使用softmax输出 + __result = softmax(__result) + + self.result[0] = np.argmax(__result) # 预测的数字 + self.result[1] = __result[0, self.result[0]] # 置信度 + + self.lbResult.setText("%d" % (self.result[0])) + self.lbCofidence.setText("%.8f" % (self.result[1])) + + + # # 随机抽取 + # def pbtGetMnist_Callback(self): + # self.clearDataArea() + # + # # 随机抽取一张测试集图片,放大后显示 + # img = x_test[np.random.randint(0, 9999)] # shape:[1,28,28] + # img = img.reshape(28, 28) # shape:[28,28] + # + # img = img * 0xff # 恢复灰度值大小 + # pil_img = Image.fromarray(np.uint8(img)) + # pil_img = pil_img.resize((224, 224)) # 图像放大显示 + # + # # 将pil图像转换成qimage类型 + # qimage = ImageQt.ImageQt(pil_img) + # + # # 将qimage类型图像显示在label + # pix = QPixmap.fromImage(qimage) + # self.lbDataArea.setPixmap(pix) + # + + +if __name__ == "__main__": + app = QApplication(sys.argv) + Gui = MainWindow() + Gui.show() + + sys.exit(app.exec_()) \ No newline at end of file