|
|
import sys
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
from PIL import Image, ImageQt
|
|
|
from PyQt5.QtWidgets import QApplication, QMainWindow, QVBoxLayout, QWidget, QPushButton, QLabel
|
|
|
from PyQt5.QtGui import QPainter, QPen, QImage, QPixmap, QColor
|
|
|
from PyQt5.QtCore import Qt, QPoint
|
|
|
from mm import CNN
|
|
|
|
|
|
|
|
|
class PaintBoard(QWidget):
    """Fixed-size square canvas on which the user draws white strokes on black.

    Strokes are rendered into an off-screen ``QImage`` (``self.image``) so the
    drawing can be read back with :meth:`getContentAsQImage`; ``paintEvent``
    only blits that image onto the widget.

    Args:
        size: Width and height of the canvas in pixels (default 224).
        pen_width: Stroke width in pixels (default 20).
    """

    def __init__(self, size=224, pen_width=20):
        super().__init__()
        self._size = size
        # Built once here instead of on every mouse-move event.
        self._pen = QPen(Qt.white, pen_width, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin)
        self.image = QImage(size, size, QImage.Format_RGB32)
        self.image.fill(Qt.black)
        # Previous cursor position; each move event draws a segment from here.
        self.last_point = QPoint()
        self.initUI()

    def initUI(self):
        """Pin the widget to the canvas size and give it a black background."""
        self.setGeometry(0, 0, self._size, self._size)
        self.setFixedSize(self._size, self._size)
        self.setStyleSheet("background-color: black;")

    def clear(self):
        """Erase the canvas back to solid black and schedule a repaint."""
        self.image.fill(Qt.black)
        self.update()

    def getContentAsQImage(self):
        """Return the backing ``QImage`` (the live object, not a copy)."""
        return self.image

    def mousePressEvent(self, event):
        """Start a stroke on left-button press and paint a dot at the press point."""
        if event.button() == Qt.LeftButton:
            self.last_point = event.pos()
            # Without this, a click that is not followed by a drag leaves no mark.
            self._stroke(self.last_point)

    def mouseMoveEvent(self, event):
        """While the left button is held, extend the stroke to the cursor."""
        if event.buttons() & Qt.LeftButton:
            self._stroke(self.last_point, event.pos())
            self.last_point = event.pos()

    def _stroke(self, start, end=None):
        """Draw a point at *start* (or a line to *end*) into the image and repaint."""
        painter = QPainter(self.image)
        try:
            painter.setPen(self._pen)
            if end is None:
                painter.drawPoint(start)
            else:
                painter.drawLine(start, end)
        finally:
            # End explicitly rather than relying on garbage collection to do it.
            painter.end()
        self.update()

    def paintEvent(self, event):
        """Blit the backing image onto the widget."""
        canvas_painter = QPainter(self)
        canvas_painter.drawImage(self.rect(), self.image, self.image.rect())
        canvas_painter.end()
|
|
|
|
|
|
|
|
|
class MainWindow(QMainWindow):
    """Main window: a drawing canvas, a prediction label and Clear/Predict buttons.

    Args:
        model: A torch module that is fed a ``(1, 1, 28, 28)`` float32 tensor
            with values in ``[0, 1]``; ``argmax`` over dim 1 of its output is
            shown as the predicted digit. It is switched to eval mode here.
        device: Device to run inference on. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).
            This replaces the old dependency on a module-level ``device``
            global that only existed when the file was run as a script.
    """

    def __init__(self, model, device=None):
        super().__init__()
        self.model = model
        self.model.eval()
        if device is None:
            first_param = next(self.model.parameters(), None)
            device = first_param.device if first_param is not None else torch.device("cpu")
        self.device = device

        self.setWindowTitle("Handwritten Digit Recognition")
        self.setGeometry(100, 100, 300, 400)

        self.paint_board = PaintBoard()

        # No setGeometry on the child widgets: the layout below positions them.
        self.result_label = QLabel("Prediction: ", self)
        self.result_label.setStyleSheet("font-size: 18px;")

        clear_button = QPushButton("Clear", self)
        clear_button.clicked.connect(self.paint_board.clear)

        predict_button = QPushButton("Predict", self)
        predict_button.clicked.connect(self.predict)

        # Create a vertical layout and place the widgets in it for display.
        layout = QVBoxLayout()
        layout.addWidget(self.paint_board)
        layout.addWidget(self.result_label)
        layout.addWidget(clear_button)
        layout.addWidget(predict_button)

        # A plain QWidget serves as the main container; the vertical layout is
        # installed on it and it becomes the window's central widget.
        container = QWidget()
        container.setLayout(layout)
        self.setCentralWidget(container)

    def predict(self):
        """Run the model on the current drawing and show the predicted digit."""
        try:
            print("Starting prediction...")
            img_tensor = self._canvas_to_tensor()
            with torch.no_grad():
                output = self.model(img_tensor)
            prediction = torch.argmax(output, dim=1).item()
        except Exception as e:
            # UI boundary: report the failure instead of letting it escape
            # into the Qt event loop.
            print(f"Error during prediction: {e}")
            self.result_label.setText("Prediction: error")
            return
        self.result_label.setText(f"Prediction: {prediction}")
        print(f"Prediction: {prediction}")

    def _canvas_to_tensor(self):
        """Return the canvas as a (1, 1, 28, 28) float32 tensor in [0, 1] on self.device."""
        # Read grayscale pixels straight from Qt. PIL.ImageQt is avoided because
        # Pillow 10 dropped its PyQt5 support.
        gray = self.paint_board.getContentAsQImage().convertToFormat(QImage.Format_Grayscale8)
        width, height, stride = gray.width(), gray.height(), gray.bytesPerLine()
        raw = gray.constBits().asstring(stride * height)
        # Rows may be padded to `stride` bytes; keep only the visible columns.
        pixels = np.frombuffer(raw, dtype=np.uint8).reshape(height, stride)[:, :width]
        # Image.ANTIALIAS (an alias of LANCZOS) was removed in Pillow 10.
        lanczos = getattr(Image, "Resampling", Image).LANCZOS
        small = Image.fromarray(np.ascontiguousarray(pixels)).resize((28, 28), lanczos)
        # NOTE(review): no mean/std normalisation is applied; confirm this
        # matches how the model was trained.
        img_array = np.asarray(small, dtype=np.float32).reshape(1, 1, 28, 28) / 255.0
        return torch.from_numpy(img_array).to(self.device)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Use the GPU when torch can see one. `device` and `model` stay
    # module-level names because they may be read as globals elsewhere in
    # this file.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    model = CNN().to(device)
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
    state_dict = torch.load('model.pth', map_location=device)
    model.load_state_dict(state_dict)

    app = QApplication(sys.argv)
    main_window = MainWindow(model)
    main_window.show()

    # Block in the Qt event loop, then propagate its exit status.
    exit_code = app.exec_()
    sys.exit(exit_code)
|