try 5 months ago
parent 3baa80ee49
commit c2719708ef

@@ -1,2 +0,0 @@
# written_number_classfication

@@ -0,0 +1,8 @@
# Files ignored by default
/shelf/
/workspace.xml
# Editor-based HTTP client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

@@ -0,0 +1,12 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="PyUnresolvedReferencesInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredIdentifiers">
<list>
<option value="mybkauth.views.conn" />
</list>
</option>
</inspection_tool>
</profile>
</component>

@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="C:\Users\28326\anaconda3" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="C:\Users\28326\anaconda3" project-jdk-type="Python SDK" />
</project>

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/识别.iml" filepath="$PROJECT_DIR$/.idea/识别.iml" />
</modules>
</component>
</project>

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="C:\Users\28326\anaconda3" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

@@ -0,0 +1,110 @@
import sys
import numpy as np
import torch
from PIL import Image, ImageQt
from PyQt5.QtWidgets import QApplication, QMainWindow, QVBoxLayout, QWidget, QPushButton, QLabel
from PyQt5.QtGui import QPainter, QPen, QImage, QPixmap, QColor
from PyQt5.QtCore import Qt, QPoint
from mm import CNN
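# CNN (defined in mm.py) matches the architecture used for training, so the model.pth state_dict below loads directly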
class PaintBoard(QWidget):
def __init__(self):
super().__init__()
self.image = QImage(224, 224, QImage.Format_RGB32)
self.image.fill(Qt.black)
self.last_point = QPoint()
self.initUI()
def initUI(self):
self.setGeometry(0, 0, 224, 224)
self.setFixedSize(224, 224)
self.setStyleSheet("background-color: black;")
def clear(self):
self.image.fill(Qt.black)
self.update()
def getContentAsQImage(self):
return self.image
def mousePressEvent(self, event):
if event.button() == Qt.LeftButton:
self.last_point = event.pos()
def mouseMoveEvent(self, event):
if event.buttons() & Qt.LeftButton:
painter = QPainter(self.image)
pen = QPen(Qt.white, 20, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin)
painter.setPen(pen)
painter.drawLine(self.last_point, event.pos())
self.last_point = event.pos()
self.update()
def paintEvent(self, event):
canvas_painter = QPainter(self)
canvas_painter.drawImage(self.rect(), self.image, self.image.rect())
class MainWindow(QMainWindow):
def __init__(self, model):
super().__init__()
self.model = model
self.model.eval()
self.setWindowTitle("Handwritten Digit Recognition")
self.setGeometry(100, 100, 300, 400)
self.paint_board = PaintBoard()
self.result_label = QLabel("Prediction: ", self)
self.result_label.setGeometry(50, 250, 200, 50)
self.result_label.setStyleSheet("font-size: 18px;")
clear_button = QPushButton("Clear", self)
clear_button.setGeometry(50, 300, 100, 50)
clear_button.clicked.connect(self.paint_board.clear)
predict_button = QPushButton("Predict", self)
predict_button.setGeometry(150, 300, 100, 50)
predict_button.clicked.connect(self.predict)
        # Create a vertical layout manager and add the widgets to it so they are displayed
layout = QVBoxLayout()
layout.addWidget(self.paint_board)
layout.addWidget(self.result_label)
layout.addWidget(clear_button)
layout.addWidget(predict_button)
        # Create a QWidget as the main container and set the vertical layout on it
container = QWidget()
container.setLayout(layout)
self.setCentralWidget(container)
def predict(self):
try:
print("Starting prediction...")
image = self.paint_board.getContentAsQImage()
            pil_image = ImageQt.fromqimage(image)
            # Downscale the 224x224 canvas to MNIST's 28x28 input size
            # (Image.ANTIALIAS still works with the pinned Pillow 9.1.0; newer Pillow uses Image.Resampling.LANCZOS)
            pil_image = pil_image.resize((28, 28), Image.ANTIALIAS)
            # Convert to grayscale, then to a (1, 1, 28, 28) float array normalized to [0, 1]
            pil_image = pil_image.convert('L')
            img_array = np.array(pil_image).reshape(1, 1, 28, 28).astype(np.float32) / 255.0
img_tensor = torch.tensor(img_array).to(device)
with torch.no_grad():
output = self.model(img_tensor)
prediction = torch.argmax(output, dim=1).item()
self.result_label.setText(f"Prediction: {prediction}")
print(f"Prediction: {prediction}")
except Exception as e:
print(f"Error during prediction: {e}")
if __name__ == "__main__":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN().to(device)
model.load_state_dict(torch.load('model.pth', map_location=device))
app = QApplication(sys.argv)
main_window = MainWindow(model)
main_window.show()
sys.exit(app.exec_())

@@ -0,0 +1,105 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
import numpy as np
input_size = 28
num_classes = 10
num_epochs = 40
batch_size = 64
# Use CUDA if it is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_dataset = torchvision.datasets.MNIST(
root='./data', train=True, transform=torchvision.transforms.ToTensor(), download=True
)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=torchvision.transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=True)
# Define a CNN class as the neural network model
# in_channels: number of input channels. For the first conv layer the MNIST images are grayscale, so in_channels=1.
# out_channels: number of output feature maps, i.e. the number of convolution kernels; each kernel produces one feature map.
# kernel_size: size of the convolution kernel, e.g. kernel_size=5 means a 5x5 kernel.
# stride: step size of the kernel; set to 1 here, so the kernel moves one pixel at a time.
# padding: number of padded pixels; set to 2 here, so the input is padded with 2 pixels on each side.
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(1, 16, 5, 1, 2),
nn.ReLU(),
nn.MaxPool2d(2)
)
self.conv2 = nn.Sequential(
nn.Conv2d(16, 32, 5, 1, 2),
nn.ReLU(),
nn.MaxPool2d(2)
)
self.conv3 = nn.Sequential(
nn.Conv2d(32, 64, 5, 1, 2),
nn.ReLU(),
nn.MaxPool2d(2)
)
self.conv4 = nn.Sequential(
nn.Conv2d(64, 32, 5, 1, 2),
nn.ReLU(),
nn.MaxPool2d(2)
)
        self.out = nn.Linear(32, num_classes)  # 32 flattened features -> num_classes logits
    def forward(self, x):
        # The model and its inputs are already on `device`, so no per-layer .to(device) calls are needed
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 32)
        x = self.out(x)
        return x
def accuracy(prediction, label):
pred = torch.max(prediction.data, 1)[1]
rights = pred.eq(label.data.view_as(pred)).sum()
return rights, len(label)
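# accuracy() returns a (correct, total) pair for one batch; these pairs are summed below to report running accuracy.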
net = CNN().to(device)  # move the model to the GPU
criterion = nn.CrossEntropyLoss().to(device)  # move the loss function to the GPU as well
optimizer = optim.Adam(net.parameters(), lr=0.001)
tr = []  # per-batch training losses, used for the plot below
for epoch in range(num_epochs):
train_rights = []
for batch_id, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)  # move the data to the GPU
optimizer.zero_grad()
output = net(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
right = accuracy(output, target)
train_rights.append(right)
        tr.append(loss.item())  # store the scalar loss rather than the tensor, so the computation graph is not kept alive
if batch_id % 100 == 0:
val_right = []
for (data, target) in test_loader:
                data, target = data.to(device), target.to(device)  # move the data to the GPU
output = net(data)
right = accuracy(output, target)
val_right.append(right)
train_r = (sum(t[0] for t in train_rights), sum(t[1] for t in train_rights))
val_r = (sum(t[0] for t in val_right), sum(t[1] for t in val_right))
print("train_acc {}, test_acc {}".format(train_r[0] / train_r[1], val_r[0] / val_r[1]))
# 注意绘图时不需要将数据迁移到GPU
plt.plot(list(range(len([tensor.detach().cpu().numpy() for tensor in tr]))), [tensor.detach().cpu().numpy() for tensor in tr])
plt.show()
torch.save(net.state_dict(), 'model.pth')
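# The GUI script reloads this checkpoint with model.load_state_dict(torch.load('model.pth', map_location=device)).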

@@ -0,0 +1,42 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
import numpy as np
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(1, 16, 5, 1, 2),
nn.ReLU(),
nn.MaxPool2d(2)
)
self.conv2 = nn.Sequential(
nn.Conv2d(16, 32, 5, 1, 2),
nn.ReLU(),
nn.MaxPool2d(2)
)
self.conv3 = nn.Sequential(
nn.Conv2d(32, 64, 5, 1, 2),
nn.ReLU(),
nn.MaxPool2d(2)
)
self.conv4 = nn.Sequential(
nn.Conv2d(64, 32, 5, 1, 2),
nn.ReLU(),
nn.MaxPool2d(2)
)
        self.out = nn.Linear(32, 10)  # 32 flattened features -> 10 digit classes
    def forward(self, x):
        # The model and its inputs are already on `device`, so no per-layer .to(device) calls are needed
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 32)
        x = self.out(x)
        return x

Binary file not shown.

@@ -0,0 +1,28 @@
certifi==2024.6.2
charset-normalizer==3.3.2
contourpy==1.2.1
cycler==0.12.1
filelock==3.15.1
fonttools==4.53.0
idna==3.7
Jinja2==3.1.4
kiwisolver==1.4.5
MarkupSafe==2.1.5
matplotlib==3.6.3
mpmath==1.3.0
networkx==3.2.1
numpy==1.22.4
packaging==24.1
Pillow==9.1.0
pyparsing==3.1.2
PyQt5==5.15.6
PyQt5-Qt5==5.15.2
PyQt5-sip==12.13.0
python-dateutil==2.9.0.post0
requests==2.32.3
six==1.16.0
sympy==1.12.1
torch==1.13.1
torchvision==0.14.0
typing_extensions==4.12.2
urllib3==2.2.1