From dbf175a62e8361123341cdba46e8e913204dba31 Mon Sep 17 00:00:00 2001
From: pl63o9ejz <2318715650@qq.com>
Date: Wed, 28 Apr 2021 16:12:38 +0800
Subject: [PATCH] mnist_cnn_gui_main.py

---
 mnist-master/mnist_cnn_gui_main.py | 192 +++++++++++++++++++++++++++++
 1 file changed, 192 insertions(+)
 create mode 100644 mnist-master/mnist_cnn_gui_main.py

diff --git a/mnist-master/mnist_cnn_gui_main.py b/mnist-master/mnist_cnn_gui_main.py
new file mode 100644
index 0000000..b3a4cc2
--- /dev/null
+++ b/mnist-master/mnist_cnn_gui_main.py
@@ -0,0 +1,192 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+
+import sys, os
+import numpy as np
+from dataset.mnist import load_mnist
+from PIL import Image, ImageQt
+
+
+from qt.layout import Ui_MainWindow
+from qt.paintboard import PaintBoard
+
+from PyQt5.QtWidgets import QMainWindow, QDesktopWidget, QApplication
+from PyQt5.QtWidgets import QLabel, QMessageBox, QPushButton, QFrame
+from PyQt5.QtGui import QPainter, QPen, QPixmap, QColor, QImage
+from PyQt5.QtCore import Qt, QPoint, QSize
+
+from simple_convnet import SimpleConvNet
+from common.functions import softmax
+from deep_convnet import DeepConvNet
+
+
+
+MODE_MNIST = 1    # MNIST随机抽取
+MODE_WRITE = 2    # 手写输入
+
+Thresh = 0.5      # 识别结果置信度阈值
+
+
+
+# 读取MNIST数据集
+(_, _), (x_test, _) = load_mnist(normalize=True, flatten=False, one_hot_label=False)
+
+
+# 初始化网络
+
+# 网络1:简单CNN
+"""
+conv - relu - pool - affine - relu - affine - softmax
+"""
+network = SimpleConvNet(input_dim=(1,28,28), 
+                        conv_param = {'filter_num': 30, 'filter_size': 5, 'pad': 0, 'stride': 1},
+                        hidden_size=100, output_size=10, weight_init_std=0.01)
+network.load_params("params.pkl")
+
+# 网络2:深度CNN
+# network = DeepConvNet()
+# network.load_params("deep_convnet_params.pkl")
+
+
+class MainWindow(QMainWindow,Ui_MainWindow):
+    def __init__(self):
+        super(MainWindow,self).__init__()
+    
+        # 初始化参数
+        self.mode = MODE_MNIST
+        self.result = [0, 0]
+
+        # 初始化UI
+        self.setupUi(self)
+        self.center()
+
+        # 初始化画板
+        self.paintBoard = PaintBoard(self, Size = QSize(224, 224), Fill = QColor(0,0,0,0))
+        self.paintBoard.setPenColor(QColor(0,0,0,0))
+        self.dArea_Layout.addWidget(self.paintBoard)
+
+        self.clearDataArea()
+
+    # 窗口居中
+    def center(self):
+        # 获得窗口
+        framePos = self.frameGeometry()
+        # 获得屏幕中心点
+        scPos = QDesktopWidget().availableGeometry().center()
+        # 显示到屏幕中心
+        framePos.moveCenter(scPos)
+        self.move(framePos.topLeft())
+    
+    
+    # 窗口关闭事件
+    def closeEvent(self, event):
+        reply = QMessageBox.question(self, 'Message',
+            "Are you sure to quit?", QMessageBox.Yes | 
+            QMessageBox.No, QMessageBox.Yes)
+
+        if reply == QMessageBox.Yes:
+            event.accept()
+        else:
+            event.ignore()   
+    
+    # 清除数据待输入区
+    def clearDataArea(self):
+        self.paintBoard.Clear()
+        self.lbDataArea.clear()
+        self.lbResult.clear()
+        self.lbCofidence.clear()
+        self.result = [0, 0]
+
+    """
+    回调函数
+    """
+    # 模式下拉列表回调
+    def cbBox_Mode_Callback(self, text):
+        if text == '1:MINIST随机抽取':
+            self.mode = MODE_MNIST
+            self.clearDataArea()
+            self.pbtGetMnist.setEnabled(True)
+
+            self.paintBoard.setBoardFill(QColor(0,0,0,0))
+            self.paintBoard.setPenColor(QColor(0,0,0,0))
+
+        elif text == '2:鼠标手写输入':
+            self.mode = MODE_WRITE
+            self.clearDataArea()
+            self.pbtGetMnist.setEnabled(False)
+
+            # 更改背景
+            self.paintBoard.setBoardFill(QColor(0,0,0,255))
+            self.paintBoard.setPenColor(QColor(255,255,255,255))
+
+
+    # 数据清除
+    def pbtClear_Callback(self):
+        self.clearDataArea()
+ 
+
+    # 识别
+    def pbtPredict_Callback(self):
+        __img, img_array =[],[]      # 将图像统一从qimage->pil image -> np.array [1, 1, 28, 28]
+        
+        # 获取qimage格式图像
+        if self.mode == MODE_MNIST:
+            __img = self.lbDataArea.pixmap()  # label内若无图像返回None
+            if __img == None:   # 无图像则用纯黑代替
+                # __img = QImage(224, 224, QImage.Format_Grayscale8)
+                __img = ImageQt.ImageQt(Image.fromarray(np.uint8(np.zeros([224,224]))))
+            else: __img = __img.toImage()
+        elif self.mode == MODE_WRITE:
+            __img = self.paintBoard.getContentAsQImage()
+
+        # 转换成pil image类型处理
+        pil_img = ImageQt.fromqimage(__img)
+        pil_img = pil_img.resize((28, 28), Image.ANTIALIAS)
+        
+        # pil_img.save('test.png')
+
+        img_array = np.array(pil_img.convert('L')).reshape(1,1,28, 28) / 255.0
+        # img_array = np.where(img_array>0.5, 1, 0)
+    
+        # reshape成网络输入类型 
+        __result = network.predict(img_array)      # shape:[1, 10]
+
+        # print (__result)
+
+        # 将预测结果使用softmax输出
+        __result = softmax(__result)
+       
+        self.result[0] = np.argmax(__result)          # 预测的数字
+        self.result[1] = __result[0, self.result[0]]     # 置信度
+
+        self.lbResult.setText("%d" % (self.result[0]))
+        self.lbCofidence.setText("%.8f" % (self.result[1]))
+
+
+    # 随机抽取
+    def pbtGetMnist_Callback(self):
+        self.clearDataArea()
+        
+        # 随机抽取一张测试集图片,放大后显示
+        img = x_test[np.random.randint(0, 9999)]    # shape:[1,28,28] 
+        img = img.reshape(28, 28)                   # shape:[28,28]  
+
+        img = img * 0xff      # 恢复灰度值大小 
+        pil_img = Image.fromarray(np.uint8(img))
+        pil_img = pil_img.resize((224, 224))        # 图像放大显示
+
+        # 将pil图像转换成qimage类型
+        qimage = ImageQt.ImageQt(pil_img)
+        
+        # 将qimage类型图像显示在label
+        pix = QPixmap.fromImage(qimage)
+        self.lbDataArea.setPixmap(pix)
+
+
+
+if __name__ == "__main__":
+    app = QApplication(sys.argv)
+    Gui = MainWindow()
+    Gui.show()
+
+    sys.exit(app.exec_())
\ No newline at end of file