#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""PyQt5 GUI that classifies MNIST-style digits with a pre-trained CNN.

Two input modes are offered:

* MODE_MNIST - show a random image from the MNIST test set and classify it.
* MODE_WRITE - draw a digit with the mouse on a paint board and classify it.
"""
import os
import sys

import numpy as np
from dataset.mnist import load_mnist
from PIL import Image, ImageQt

from qt.layout import Ui_MainWindow
from qt.paintboard import PaintBoard

from PyQt5.QtWidgets import QMainWindow, QDesktopWidget, QApplication
from PyQt5.QtWidgets import QLabel, QMessageBox, QPushButton, QFrame
from PyQt5.QtGui import QPainter, QPen, QPixmap, QColor, QImage
from PyQt5.QtCore import Qt, QPoint, QSize

from simple_convnet import SimpleConvNet
from common.functions import softmax
from deep_convnet import DeepConvNet


MODE_MNIST = 1    # randomly sample an image from the MNIST test set
MODE_WRITE = 2    # hand-written input drawn with the mouse

# Confidence threshold for the recognition result.
# NOTE(review): not referenced anywhere else in this file.
Thresh = 0.5

INPUT_SIZE = 28      # side length (pixels) of the network input images
DISPLAY_SIZE = 224   # side length (pixels) of the on-screen input area

# Load the MNIST data set; only the test images are used by this demo.
(_, _), (x_test, _) = load_mnist(normalize=True, flatten=False, one_hot_label=False)

# Build the network and load its trained weights.
# Network 1: simple CNN
#   conv - relu - pool - affine - relu - affine - softmax
network = SimpleConvNet(input_dim=(1, INPUT_SIZE, INPUT_SIZE),
                        conv_param={'filter_num': 30, 'filter_size': 5,
                                    'pad': 0, 'stride': 1},
                        hidden_size=100, output_size=10, weight_init_std=0.01)
network.load_params("params.pkl")

# Network 2: deep CNN (alternative; uncomment to use it instead)
# network = DeepConvNet()
# network.load_params("deep_convnet_params.pkl")


class MainWindow(QMainWindow, Ui_MainWindow):
    """Main window: input area, mode selector, and prediction display."""

    def __init__(self):
        super(MainWindow, self).__init__()

        # Runtime state
        self.mode = MODE_MNIST
        self.result = [0, 0]   # [predicted digit, confidence]

        # Build the UI generated from the Designer layout
        self.setupUi(self)
        self.center()

        # Paint board for mouse input; it is transparent (with a transparent
        # pen) until MODE_WRITE is selected in the mode combo box.
        self.paintBoard = PaintBoard(self, Size=QSize(DISPLAY_SIZE, DISPLAY_SIZE),
                                     Fill=QColor(0, 0, 0, 0))
        self.paintBoard.setPenColor(QColor(0, 0, 0, 0))
        self.dArea_Layout.addWidget(self.paintBoard)

        self.clearDataArea()

    def center(self):
        """Move the window to the centre of the available screen area."""
        frame_geometry = self.frameGeometry()
        screen_center = QDesktopWidget().availableGeometry().center()
        frame_geometry.moveCenter(screen_center)
        self.move(frame_geometry.topLeft())

    def closeEvent(self, event):
        """Ask for confirmation before the window is closed."""
        reply = QMessageBox.question(self, 'Message', "Are you sure to quit?",
                                     QMessageBox.Yes | QMessageBox.No, QMessageBox.Yes)
        if reply == QMessageBox.Yes:
            event.accept()
        else:
            event.ignore()

    def clearDataArea(self):
        """Clear the input area, the result labels and the stored result."""
        self.paintBoard.Clear()
        self.lbDataArea.clear()
        self.lbResult.clear()
        self.lbCofidence.clear()
        self.result = [0, 0]

    # ------------------------------------------------------------------
    # Callbacks
    # ------------------------------------------------------------------

    def cbBox_Mode_Callback(self, text):
        """Switch input mode according to the combo-box entry *text*."""
        if text == '1:MINIST随机抽取':
            self.mode = MODE_MNIST
            self.clearDataArea()
            self.pbtGetMnist.setEnabled(True)
            # Transparent board and pen: the label underneath shows the sample.
            self.paintBoard.setBoardFill(QColor(0, 0, 0, 0))
            self.paintBoard.setPenColor(QColor(0, 0, 0, 0))

        elif text == '2:鼠标手写输入':
            self.mode = MODE_WRITE
            self.clearDataArea()
            self.pbtGetMnist.setEnabled(False)
            # White strokes on an opaque black background, like MNIST digits.
            self.paintBoard.setBoardFill(QColor(0, 0, 0, 255))
            self.paintBoard.setPenColor(QColor(255, 255, 255, 255))

    def pbtClear_Callback(self):
        """Clear the input area."""
        self.clearDataArea()

    def _grab_input_image(self):
        """Return the image to classify, as a PIL image, for the current mode."""
        if self.mode == MODE_MNIST:
            pixmap = self.lbDataArea.pixmap()
            # pixmap() returns None when the label holds no image; treat a
            # null pixmap the same way.
            if pixmap is None or pixmap.isNull():
                # Nothing sampled yet: classify an all-black image instead.
                return Image.new('L', (DISPLAY_SIZE, DISPLAY_SIZE), 0)
            return ImageQt.fromqimage(pixmap.toImage())
        # MODE_WRITE: take whatever was drawn on the paint board.
        return ImageQt.fromqimage(self.paintBoard.getContentAsQImage())

    def pbtPredict_Callback(self):
        """Classify the current input and display the digit and its confidence."""
        pil_img = self._grab_input_image()

        # Downscale to the network input size, convert to grey levels and
        # normalise to [0, 1] with shape [1, 1, 28, 28].
        # Image.LANCZOS is the filter formerly aliased as Image.ANTIALIAS,
        # which was removed in Pillow 10.
        pil_img = pil_img.resize((INPUT_SIZE, INPUT_SIZE), Image.LANCZOS)
        img_array = np.array(pil_img.convert('L')).reshape(
            1, 1, INPUT_SIZE, INPUT_SIZE) / 255.0

        scores = network.predict(img_array)   # shape: [1, 10]
        probs = softmax(scores)               # turn scores into probabilities

        digit = int(np.argmax(probs))
        self.result[0] = digit                # predicted digit
        self.result[1] = probs[0, digit]      # confidence

        self.lbResult.setText("%d" % (self.result[0]))
        self.lbCofidence.setText("%.8f" % (self.result[1]))

    def pbtGetMnist_Callback(self):
        """Show a randomly chosen MNIST test image, enlarged to fill the input area."""
        self.clearDataArea()

        # randint's upper bound is exclusive, so len(x_test) covers every image.
        img = x_test[np.random.randint(len(x_test))]   # shape: [1, 28, 28]
        img = img.reshape(INPUT_SIZE, INPUT_SIZE)       # shape: [28, 28]
        # Undo the [0, 1] normalisation; round so that e.g. 254.99 becomes 255
        # rather than being truncated to 254 by the uint8 cast.
        img = np.rint(img * 0xff)
        pil_img = Image.fromarray(np.uint8(img))
        pil_img = pil_img.resize((DISPLAY_SIZE, DISPLAY_SIZE))   # enlarge for display

        # PIL image -> QImage -> QPixmap shown in the label
        qimage = ImageQt.ImageQt(pil_img)
        self.lbDataArea.setPixmap(QPixmap.fromImage(qimage))


if __name__ == "__main__":
    app = QApplication(sys.argv)
    Gui = MainWindow()
    Gui.show()
    sys.exit(app.exec_())