From 66a5007298d9951cb4cda875559344ce4d786260 Mon Sep 17 00:00:00 2001 From: p4w2aybsf <2363061197@qq.com> Date: Thu, 29 Apr 2021 16:57:55 +0800 Subject: [PATCH] =?UTF-8?q?=E6=97=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mnist_cnn_gui_main.py | 192 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 192 insertions(+) create mode 100644 mnist_cnn_gui_main.py diff --git a/mnist_cnn_gui_main.py b/mnist_cnn_gui_main.py new file mode 100644 index 0000000..b3a4cc2 --- /dev/null +++ b/mnist_cnn_gui_main.py @@ -0,0 +1,192 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- + +import sys, os +import numpy as np +from dataset.mnist import load_mnist +from PIL import Image, ImageQt + + +from qt.layout import Ui_MainWindow +from qt.paintboard import PaintBoard + +from PyQt5.QtWidgets import QMainWindow, QDesktopWidget, QApplication +from PyQt5.QtWidgets import QLabel, QMessageBox, QPushButton, QFrame +from PyQt5.QtGui import QPainter, QPen, QPixmap, QColor, QImage +from PyQt5.QtCore import Qt, QPoint, QSize + +from simple_convnet import SimpleConvNet +from common.functions import softmax +from deep_convnet import DeepConvNet + + + +MODE_MNIST = 1 # MNIST随机抽取 +MODE_WRITE = 2 # 手写输入 + +Thresh = 0.5 # 识别结果置信度阈值 + + + +# 读取MNIST数据集 +(_, _), (x_test, _) = load_mnist(normalize=True, flatten=False, one_hot_label=False) + + +# 初始化网络 + +# 网络1:简单CNN +""" +conv - relu - pool - affine - relu - affine - softmax +""" +network = SimpleConvNet(input_dim=(1,28,28), + conv_param = {'filter_num': 30, 'filter_size': 5, 'pad': 0, 'stride': 1}, + hidden_size=100, output_size=10, weight_init_std=0.01) +network.load_params("params.pkl") + +# 网络2:深度CNN +# network = DeepConvNet() +# network.load_params("deep_convnet_params.pkl") + + +class MainWindow(QMainWindow,Ui_MainWindow): + def __init__(self): + super(MainWindow,self).__init__() + + # 初始化参数 + self.mode = MODE_MNIST + self.result = [0, 0] + + # 初始化UI + self.setupUi(self) + self.center() + + # 初始化画板 + self.paintBoard = PaintBoard(self, Size = QSize(224, 224), Fill = QColor(0,0,0,0)) + self.paintBoard.setPenColor(QColor(0,0,0,0)) + self.dArea_Layout.addWidget(self.paintBoard) + + self.clearDataArea() + + # 窗口居中 + def center(self): + # 获得窗口 + framePos = self.frameGeometry() + # 获得屏幕中心点 + scPos = QDesktopWidget().availableGeometry().center() + # 显示到屏幕中心 + framePos.moveCenter(scPos) + self.move(framePos.topLeft()) + + + # 窗口关闭事件 + def closeEvent(self, event): + reply = QMessageBox.question(self, 'Message', + "Are you sure to quit?", QMessageBox.Yes | + QMessageBox.No, QMessageBox.Yes) + + if reply == QMessageBox.Yes: + event.accept() + else: + event.ignore() + + # 清除数据待输入区 + def clearDataArea(self): + self.paintBoard.Clear() + self.lbDataArea.clear() + self.lbResult.clear() + self.lbCofidence.clear() + self.result = [0, 0] + + """ + 回调函数 + """ + # 模式下拉列表回调 + def cbBox_Mode_Callback(self, text): + if text == '1:MINIST随机抽取': + self.mode = MODE_MNIST + self.clearDataArea() + self.pbtGetMnist.setEnabled(True) + + self.paintBoard.setBoardFill(QColor(0,0,0,0)) + self.paintBoard.setPenColor(QColor(0,0,0,0)) + + elif text == '2:鼠标手写输入': + self.mode = MODE_WRITE + self.clearDataArea() + self.pbtGetMnist.setEnabled(False) + + # 更改背景 + self.paintBoard.setBoardFill(QColor(0,0,0,255)) + self.paintBoard.setPenColor(QColor(255,255,255,255)) + + + # 数据清除 + def pbtClear_Callback(self): + self.clearDataArea() + + + # 识别 + def pbtPredict_Callback(self): + __img, img_array =[],[] # 将图像统一从qimage->pil image -> np.array [1, 1, 28, 28] + + # 获取qimage格式图像 + if self.mode == MODE_MNIST: + __img = self.lbDataArea.pixmap() # label内若无图像返回None + if __img == None: # 无图像则用纯黑代替 + # __img = QImage(224, 224, QImage.Format_Grayscale8) + __img = ImageQt.ImageQt(Image.fromarray(np.uint8(np.zeros([224,224])))) + else: __img = __img.toImage() + elif self.mode == MODE_WRITE: + __img = self.paintBoard.getContentAsQImage() + + # 转换成pil image类型处理 + pil_img = ImageQt.fromqimage(__img) + pil_img = pil_img.resize((28, 28), Image.ANTIALIAS) + + # pil_img.save('test.png') + + img_array = np.array(pil_img.convert('L')).reshape(1,1,28, 28) / 255.0 + # img_array = np.where(img_array>0.5, 1, 0) + + # reshape成网络输入类型 + __result = network.predict(img_array) # shape:[1, 10] + + # print (__result) + + # 将预测结果使用softmax输出 + __result = softmax(__result) + + self.result[0] = np.argmax(__result) # 预测的数字 + self.result[1] = __result[0, self.result[0]] # 置信度 + + self.lbResult.setText("%d" % (self.result[0])) + self.lbCofidence.setText("%.8f" % (self.result[1])) + + + # 随机抽取 + def pbtGetMnist_Callback(self): + self.clearDataArea() + + # 随机抽取一张测试集图片,放大后显示 + img = x_test[np.random.randint(0, 9999)] # shape:[1,28,28] + img = img.reshape(28, 28) # shape:[28,28] + + img = img * 0xff # 恢复灰度值大小 + pil_img = Image.fromarray(np.uint8(img)) + pil_img = pil_img.resize((224, 224)) # 图像放大显示 + + # 将pil图像转换成qimage类型 + qimage = ImageQt.ImageQt(pil_img) + + # 将qimage类型图像显示在label + pix = QPixmap.fromImage(qimage) + self.lbDataArea.setPixmap(pix) + + + +if __name__ == "__main__": + app = QApplication(sys.argv) + Gui = MainWindow() + Gui.show() + + sys.exit(app.exec_()) \ No newline at end of file