"""Handwritten digit recognition demo: draw a digit, classify it with a CNN."""

import sys

import numpy as np
import torch
from PIL import Image, ImageQt
from PyQt5.QtWidgets import QApplication, QMainWindow, QVBoxLayout, QWidget, QPushButton, QLabel
from PyQt5.QtGui import QPainter, QPen, QImage, QPixmap, QColor
from PyQt5.QtCore import Qt, QPoint

from mm import CNN

# Side length (pixels) of the square drawing canvas.
CANVAS_SIZE = 224
# Side length (pixels) of the square image fed to the model.
MODEL_INPUT_SIZE = 28
# Width (pixels) of the white pen used on the canvas.
PEN_WIDTH = 20


def _qimage_to_gray_array(image):
    """Return *image* as a 2-D ``uint8`` NumPy array of grayscale values.

    The conversion goes through ``QImage.Format_Grayscale8`` and reads the
    raw pixel buffer directly, so it does not depend on ``PIL.ImageQt``
    (which no longer supports PyQt5 in Pillow 10+).

    Args:
        image: The ``QImage`` to convert.

    Returns:
        A ``(height, width)`` array owning its own copy of the pixel data.
    """
    gray = image.convertToFormat(QImage.Format_Grayscale8)
    height = gray.height()
    width = gray.width()
    # Scanlines may be padded, so the row stride can exceed the width.
    stride = gray.bytesPerLine()
    raw = gray.constBits().asstring(stride * height)
    rows = np.frombuffer(raw, dtype=np.uint8).reshape(height, stride)
    return rows[:, :width].copy()


def _model_device(model):
    """Return the device holding *model*'s parameters, defaulting to CPU.

    Falls back to CPU when the model exposes no parameters.
    """
    try:
        return next(model.parameters()).device
    except (StopIteration, AttributeError):
        return torch.device("cpu")


class PaintBoard(QWidget):
    """Fixed-size black canvas on which the user draws with a white pen."""

    def __init__(self):
        super().__init__()
        # Off-screen image that accumulates the strokes; paintEvent blits it.
        self.image = QImage(CANVAS_SIZE, CANVAS_SIZE, QImage.Format_RGB32)
        self.image.fill(Qt.black)
        self.last_point = QPoint()
        self.initUI()

    def initUI(self):
        """Fix the widget size to the canvas size and set a black background."""
        self.setGeometry(0, 0, CANVAS_SIZE, CANVAS_SIZE)
        self.setFixedSize(CANVAS_SIZE, CANVAS_SIZE)
        self.setStyleSheet("background-color: black;")

    def clear(self):
        """Erase all strokes and repaint the widget."""
        self.image.fill(Qt.black)
        self.update()

    def getContentAsQImage(self):
        """Return the backing ``QImage`` holding the current drawing."""
        return self.image

    def mousePressEvent(self, event):
        """Remember where a left-button stroke starts."""
        if event.button() == Qt.LeftButton:
            self.last_point = event.pos()

    def mouseMoveEvent(self, event):
        """Draw a line segment from the previous point while the left button is held."""
        if not (event.buttons() & Qt.LeftButton):
            return
        painter = QPainter(self.image)
        pen = QPen(Qt.white, PEN_WIDTH, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin)
        painter.setPen(pen)
        painter.drawLine(self.last_point, event.pos())
        # Release the paint device explicitly instead of relying on
        # garbage collection of the painter object.
        painter.end()
        self.last_point = event.pos()
        self.update()

    def paintEvent(self, event):
        """Blit the backing image onto the widget."""
        canvas_painter = QPainter(self)
        canvas_painter.drawImage(self.rect(), self.image, self.image.rect())


class MainWindow(QMainWindow):
    """Main window: drawing canvas, prediction label, Clear and Predict buttons."""

    def __init__(self, model):
        """Build the UI around *model*.

        Args:
            model: The classifier called in :meth:`predict`. It is switched
                to evaluation mode here.
        """
        super().__init__()
        self.model = model
        self.model.eval()

        self.setWindowTitle("Handwritten Digit Recognition")
        self.setGeometry(100, 100, 300, 400)

        self.paint_board = PaintBoard()

        self.result_label = QLabel("Prediction: ", self)
        self.result_label.setGeometry(50, 250, 200, 50)
        self.result_label.setStyleSheet("font-size: 18px;")

        clear_button = QPushButton("Clear", self)
        clear_button.setGeometry(50, 300, 100, 50)
        clear_button.clicked.connect(self.paint_board.clear)

        predict_button = QPushButton("Predict", self)
        predict_button.setGeometry(150, 300, 100, 50)
        predict_button.clicked.connect(self.predict)

        # Create a vertical layout manager and add the widgets to it so
        # they are displayed.
        layout = QVBoxLayout()
        layout.addWidget(self.paint_board)
        layout.addWidget(self.result_label)
        layout.addWidget(clear_button)
        layout.addWidget(predict_button)

        # Create a QWidget as the main container and install the vertical
        # layout on it.
        container = QWidget()
        container.setLayout(layout)
        self.setCentralWidget(container)

    def predict(self):
        """Classify the current drawing and show the result in the label.

        The canvas is converted to grayscale, downscaled to
        ``MODEL_INPUT_SIZE`` x ``MODEL_INPUT_SIZE``, scaled to ``[0, 1]`` and
        passed to the model as a ``(1, 1, H, W)`` float32 tensor. Any error
        is printed rather than raised, so a failed prediction does not
        propagate out of this button handler.
        """
        try:
            print("Starting prediction...")
            canvas = self.paint_board.getContentAsQImage()
            gray = _qimage_to_gray_array(canvas)
            # Image.LANCZOS is the same filter as the removed
            # Image.ANTIALIAS alias.
            small = Image.fromarray(gray).resize(
                (MODEL_INPUT_SIZE, MODEL_INPUT_SIZE), Image.LANCZOS
            )
            img_array = (
                np.array(small)
                .reshape(1, 1, MODEL_INPUT_SIZE, MODEL_INPUT_SIZE)
                .astype(np.float32)
                / 255.0
            )
            # Use the model's own device rather than a module-level global,
            # so the class also works when imported from another module.
            img_tensor = torch.tensor(img_array).to(_model_device(self.model))

            with torch.no_grad():
                output = self.model(img_tensor)
                prediction = torch.argmax(output, dim=1).item()

            self.result_label.setText(f"Prediction: {prediction}")
            print(f"Prediction: {prediction}")
        except Exception as e:
            print(f"Error during prediction: {e}")


def main():
    """Load the trained model, show the main window and run the event loop.

    Returns:
        The exit code returned by the Qt event loop.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CNN().to(device)
    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    model.load_state_dict(torch.load('model.pth', map_location=device))

    app = QApplication(sys.argv)
    main_window = MainWindow(model)
    main_window.show()
    return app.exec_()


if __name__ == "__main__":
    sys.exit(main())