import sys

import numpy as np
from PIL import Image, ImageQt
from PyQt5.QtCore import QSize
from PyQt5.QtGui import QPixmap, QColor
from PyQt5.QtWidgets import QMainWindow, QDesktopWidget, QApplication
from PyQt5.QtWidgets import QMessageBox

from common.functions import softmax
from dataset.mnist import load_mnist
from qt.layout import Ui_MainWindow
from qt.paintboard import PaintBoard
from simple_convnet import SimpleConvNet

MODE_WRITE = 2  # handwriting input mode
Thresh = 0.5  # confidence threshold for recognition results
# NOTE(review): Thresh is not referenced anywhere in this file — confirm whether
# it was meant to gate the displayed result.

# Load the MNIST dataset; only the test images are kept (used by the disabled
# "random sample" callback at the bottom of the class).
(_, _), (x_test, _) = load_mnist(normalize=True, flatten=False, one_hot_label=False)

# Build the simple CNN and load its trained weights from params.pkl.
network = SimpleConvNet(
    input_dim=(1, 28, 28),
    conv_param={'filter_num': 30, 'filter_size': 5, 'pad': 0, 'stride': 1},
    hidden_size=100,
    output_size=10,
    weight_init_std=0.01,
)
network.load_params("params.pkl")


class MainWindow(QMainWindow, Ui_MainWindow):
    """Main window: a paint board to draw a digit plus labels for the CNN prediction."""

    def __init__(self):
        super(MainWindow, self).__init__()

        # [predicted digit, confidence] of the most recent prediction.
        self.result = [0, 0]
        # Input mode; assigned by cbBox_Mode_Callback. Initialized here so the
        # attribute always exists, even before the mode combo box is used.
        self.mode = None

        # Build the UI generated in qt.layout and center the window on screen.
        self.setupUi(self)
        self.center()

        # Paint board: 224x224, with a fully transparent fill and pen until a
        # mode is selected (cbBox_Mode_Callback switches to white-on-black).
        self.paintBoard = PaintBoard(self, Size=QSize(224, 224), Fill=QColor(0, 0, 0, 0))
        self.paintBoard.setPenColor(QColor(0, 0, 0, 0))
        self.dArea_Layout.addWidget(self.paintBoard)

        self.clearDataArea()

    def center(self):
        """Move the window so it is centered in the available screen area."""
        frame_geometry = self.frameGeometry()
        screen_center = QDesktopWidget().availableGeometry().center()
        frame_geometry.moveCenter(screen_center)
        self.move(frame_geometry.topLeft())

    def closeEvent(self, event):
        """Ask for confirmation before closing; ignore the close if the user declines."""
        reply = QMessageBox.question(self, '消息', "确定退出吗?",
                                     QMessageBox.Yes | QMessageBox.No, QMessageBox.Yes)
        if reply == QMessageBox.Yes:
            event.accept()
        else:
            event.ignore()

    def clearDataArea(self):
        """Clear the paint board, the data/result labels, and the stored result."""
        self.paintBoard.Clear()
        self.lbDataArea.clear()
        self.lbResult.clear()
        self.lbCofidence.clear()
        self.result = [0, 0]

    # ----- Callbacks -----

    def cbBox_Mode_Callback(self, text):
        """Mode combo-box callback: switch to handwriting mode (white pen on black board)."""
        self.mode = MODE_WRITE
        self.clearDataArea()

        # Change the board background and pen colour.
        self.paintBoard.setBoardFill(QColor(0, 0, 0, 255))
        self.paintBoard.setPenColor(QColor(255, 255, 255, 255))

    def pbtClear_Callback(self):
        """Clear button callback."""
        self.clearDataArea()

    def pbtPredict_Callback(self):
        """Classify the digit drawn on the paint board and display digit + confidence."""
        qimg = self.paintBoard.getContentAsQImage()

        # Convert to a PIL image and downscale to the 28x28 network input size.
        # Image.LANCZOS is the filter Image.ANTIALIAS used to alias; ANTIALIAS
        # was removed in Pillow 10 and would raise AttributeError here.
        pil_img = ImageQt.fromqimage(qimg)
        pil_img = pil_img.resize((28, 28), Image.LANCZOS)

        # Grayscale, scale to [0, 1] and reshape to match input_dim=(1, 28, 28)
        # with a leading batch axis of 1.
        img_array = np.array(pil_img.convert('L')).reshape(1, 1, 28, 28) / 255.0
        # img_array = np.where(img_array > 0.5, 1, 0)  # optional binarization (disabled)

        scores = network.predict(img_array)  # shape: [1, 10]
        # Turn the raw scores into probabilities.
        probs = softmax(scores)

        self.result[0] = np.argmax(probs)  # predicted digit
        self.result[1] = probs[0, self.result[0]]  # its confidence

        self.lbResult.setText("%d" % (self.result[0]))
        self.lbCofidence.setText("%.8f" % (self.result[1]))

    # # Random sample (disabled)
    # def pbtGetMnist_Callback(self):
    #     self.clearDataArea()
    #
    #     # Pick a random test-set image and show it enlarged.
    #     img = x_test[np.random.randint(0, 9999)]  # shape: [1, 28, 28]
    #     img = img.reshape(28, 28)  # shape: [28, 28]
    #
    #     # img = img * 0xff  # restore grayscale range
    #     pil_img = Image.fromarray(np.uint8(img))
    #     pil_img = pil_img.resize((224, 224))  # enlarge for display
    #
    #     # Convert the PIL image to a QImage.
    #     qimage = ImageQt.ImageQt(pil_img)
    #
    #     # Show the QImage in the label.
    #     pix = QPixmap.fromImage(qimage)
    #     self.lbDataArea.setPixmap(pix)


if __name__ == "__main__":
    app = QApplication(sys.argv)
    Gui = MainWindow()
    Gui.show()
    sys.exit(app.exec_())