You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

111 lines
3.8 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import sys
import numpy as np
import torch
from PIL import Image, ImageQt
from PyQt5.QtWidgets import QApplication, QMainWindow, QVBoxLayout, QWidget, QPushButton, QLabel
from PyQt5.QtGui import QPainter, QPen, QImage, QPixmap, QColor
from PyQt5.QtCore import Qt, QPoint
from mm import CNN
class PaintBoard(QWidget):
    """Fixed-size drawing canvas: white strokes on a black background.

    Strokes are rendered into an off-screen ``QImage`` (``self.image``).
    ``paintEvent`` copies that image to the screen, and
    ``getContentAsQImage`` hands the same image out to callers.
    """

    # Side length of the square canvas, in pixels.
    CANVAS_SIZE = 224
    # Brush width used for strokes, in pixels.
    PEN_WIDTH = 20

    def __init__(self):
        super().__init__()
        self.image = QImage(self.CANVAS_SIZE, self.CANVAS_SIZE, QImage.Format_RGB32)
        self.image.fill(Qt.black)
        self.last_point = QPoint()
        # Built once and reused for every stroke segment.
        self._pen = QPen(Qt.white, self.PEN_WIDTH, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin)
        self.initUI()

    def initUI(self):
        """Pin the widget to the canvas size and give it a black background."""
        self.setGeometry(0, 0, self.CANVAS_SIZE, self.CANVAS_SIZE)
        self.setFixedSize(self.CANVAS_SIZE, self.CANVAS_SIZE)
        self.setStyleSheet("background-color: black;")

    def clear(self):
        """Erase all strokes by refilling the canvas with black, then repaint."""
        self.image.fill(Qt.black)
        self.update()

    def getContentAsQImage(self):
        """Return the backing QImage itself (not a copy)."""
        return self.image

    def mousePressEvent(self, event):
        """Remember where a left-button drag starts."""
        if event.button() == Qt.LeftButton:
            self.last_point = event.pos()

    def mouseMoveEvent(self, event):
        """While the left button is held, draw from the last point to the cursor."""
        if event.buttons() & Qt.LeftButton:
            current = event.pos()
            painter = QPainter(self.image)
            try:
                painter.setPen(self._pen)
                painter.drawLine(self.last_point, current)
            finally:
                # End explicitly instead of relying on garbage collection:
                # a paint device may only have one active QPainter at a time.
                painter.end()
            self.last_point = current
            self.update()

    def paintEvent(self, event):
        """Blit the off-screen canvas image onto the widget."""
        canvas_painter = QPainter(self)
        canvas_painter.drawImage(self.rect(), self.image, self.image.rect())
        canvas_painter.end()
class MainWindow(QMainWindow):
    """Main window: a drawing canvas, a result label and Clear/Predict buttons.

    The supplied ``model`` is put into evaluation mode once here. ``predict``
    then uses it to classify whatever is currently drawn on the canvas.
    """

    # Side length of the square image fed to the model. The tensor is
    # reshaped to (1, 1, INPUT_SIZE, INPUT_SIZE): one single-channel image.
    INPUT_SIZE = 28

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.model.eval()
        self.setWindowTitle("Handwritten Digit Recognition")
        self.setGeometry(100, 100, 300, 400)

        self.paint_board = PaintBoard()

        self.result_label = QLabel("Prediction: ", self)
        self.result_label.setStyleSheet("font-size: 18px;")

        clear_button = QPushButton("Clear", self)
        clear_button.clicked.connect(self.paint_board.clear)

        predict_button = QPushButton("Predict", self)
        predict_button.clicked.connect(self.predict)

        # Stack the widgets vertically. The layout owns their geometry, so
        # absolute setGeometry() calls on them would just be overridden.
        layout = QVBoxLayout()
        for widget in (self.paint_board, self.result_label, clear_button, predict_button):
            layout.addWidget(widget)

        # QMainWindow takes its content through a single central widget, so
        # wrap the layout in a plain container QWidget.
        container = QWidget()
        container.setLayout(layout)
        self.setCentralWidget(container)

    def _model_device(self):
        """Return the device of the model's first parameter (CPU if it has none)."""
        first_param = next(iter(self.model.parameters()), None)
        return first_param.device if first_param is not None else torch.device("cpu")

    def _canvas_to_tensor(self):
        """Convert the canvas into a float32 tensor of shape (1, 1, N, N) scaled to [0, 1].

        Qt converts the QImage to 8-bit grayscale, and the pixels are then
        copied straight into NumPy. This avoids ``PIL.ImageQt``, which no
        longer supports PyQt5 as of Pillow 10.

        Pixel values are only divided by 255, with no further normalization.
        This matches the previous preprocessing; it presumably matches the
        model's training pipeline too (TODO confirm).
        """
        gray_image = self.paint_board.getContentAsQImage().convertToFormat(QImage.Format_Grayscale8)
        width, height = gray_image.width(), gray_image.height()
        stride = gray_image.bytesPerLine()
        ptr = gray_image.constBits()
        ptr.setsize(stride * height)
        # Qt may pad each scanline, so slice the padding off. Copy the pixels
        # so the array no longer references Qt-owned memory.
        pixels = np.frombuffer(ptr, dtype=np.uint8).reshape(height, stride)[:, :width].copy()

        size = self.INPUT_SIZE
        # Image.LANCZOS is the same filter the removed Image.ANTIALIAS alias used.
        pil_image = Image.fromarray(pixels).resize((size, size), Image.LANCZOS)
        img_array = np.asarray(pil_image, dtype=np.float32).reshape(1, 1, size, size) / 255.0
        return torch.from_numpy(img_array).to(self._model_device())

    def predict(self):
        """Classify the current drawing and show the predicted digit in the label."""
        # Broad catch on purpose: this method is a Qt slot. Under PyQt5 >= 5.5,
        # an exception escaping a slot aborts the whole application.
        try:
            print("Starting prediction...")
            img_tensor = self._canvas_to_tensor()
            with torch.no_grad():
                output = self.model(img_tensor)
            prediction = torch.argmax(output, dim=1).item()
        except Exception as e:
            print(f"Error during prediction: {e}")
            return
        self.result_label.setText(f"Prediction: {prediction}")
        print(f"Prediction: {prediction}")
if __name__ == "__main__":
    # Run on the GPU when PyTorch can see one, otherwise on the CPU.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Build the network on that device, then restore its trained weights.
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    model = CNN().to(device)
    checkpoint = torch.load('model.pth', map_location=device)
    model.load_state_dict(checkpoint)

    # The QApplication must exist before any widget is created.
    app = QApplication(sys.argv)
    main_window = MainWindow(model)
    main_window.show()

    # Run the Qt event loop and hand its exit status back to the shell.
    exit_status = app.exec_()
    sys.exit(exit_status)