diff --git "a/C:\\Users\\zzb\\Desktop\\识别" "b/C:\\Users\\zzb\\Desktop\\识别"
deleted file mode 100644
index e69de29..0000000
diff --git a/README.md b/README.md
deleted file mode 100644
index 0116b1f..0000000
--- a/README.md
+++ /dev/null
@@ -1,2 +0,0 @@
-# written_number_classfication
-
diff --git a/识别/.idea/.gitignore b/识别/.idea/.gitignore
new file mode 100644
index 0000000..35410ca
--- /dev/null
+++ b/识别/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/识别/.idea/inspectionProfiles/Project_Default.xml b/识别/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..63a980c
--- /dev/null
+++ b/识别/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,12 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/识别/.idea/inspectionProfiles/profiles_settings.xml b/识别/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/识别/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/识别/.idea/misc.xml b/识别/.idea/misc.xml
new file mode 100644
index 0000000..33a23b2
--- /dev/null
+++ b/识别/.idea/misc.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/识别/.idea/modules.xml b/识别/.idea/modules.xml
new file mode 100644
index 0000000..550b8d6
--- /dev/null
+++ b/识别/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/识别/.idea/识别.iml b/识别/.idea/识别.iml
new file mode 100644
index 0000000..140eb74
--- /dev/null
+++ b/识别/.idea/识别.iml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/识别/__pycache__/mm.cpython-311.pyc b/识别/__pycache__/mm.cpython-311.pyc
new file mode 100644
index 0000000..954b13e
Binary files /dev/null and b/识别/__pycache__/mm.cpython-311.pyc differ
diff --git a/识别/__pycache__/mm.cpython-312.pyc b/识别/__pycache__/mm.cpython-312.pyc
new file mode 100644
index 0000000..5202aa1
Binary files /dev/null and b/识别/__pycache__/mm.cpython-312.pyc differ
diff --git a/识别/another-main.py b/识别/another-main.py
new file mode 100644
index 0000000..bba7893
--- /dev/null
+++ b/识别/another-main.py
@@ -0,0 +1,110 @@
+import sys
+import numpy as np
+import torch
+from PIL import Image, ImageQt
+from PyQt5.QtWidgets import QApplication, QMainWindow, QVBoxLayout, QWidget, QPushButton, QLabel
+from PyQt5.QtGui import QPainter, QPen, QImage, QPixmap, QColor
+from PyQt5.QtCore import Qt, QPoint
+from mm import CNN
+
+
+class PaintBoard(QWidget):
+ def __init__(self):
+ super().__init__()
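+        # A black 224x224 canvas; strokes are drawn in white to mimic MNIST's
+        # white-on-black digits, and the content is later downsampled to 28x28 for the model.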
+ self.image = QImage(224, 224, QImage.Format_RGB32)
+ self.image.fill(Qt.black)
+ self.last_point = QPoint()
+ self.initUI()
+
+ def initUI(self):
+ self.setGeometry(0, 0, 224, 224)
+ self.setFixedSize(224, 224)
+ self.setStyleSheet("background-color: black;")
+
+ def clear(self):
+ self.image.fill(Qt.black)
+ self.update()
+
+ def getContentAsQImage(self):
+ return self.image
+
+ def mousePressEvent(self, event):
+ if event.button() == Qt.LeftButton:
+ self.last_point = event.pos()
+
+ def mouseMoveEvent(self, event):
+ if event.buttons() & Qt.LeftButton:
+ painter = QPainter(self.image)
+ pen = QPen(Qt.white, 20, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin)
+ painter.setPen(pen)
+ painter.drawLine(self.last_point, event.pos())
+ self.last_point = event.pos()
+ self.update()
+
+ def paintEvent(self, event):
+ canvas_painter = QPainter(self)
+ canvas_painter.drawImage(self.rect(), self.image, self.image.rect())
+
+
+class MainWindow(QMainWindow):
+ def __init__(self, model):
+ super().__init__()
+ self.model = model
+ self.model.eval()
+
+ self.setWindowTitle("Handwritten Digit Recognition")
+ self.setGeometry(100, 100, 300, 400)
+ self.paint_board = PaintBoard()
+ self.result_label = QLabel("Prediction: ", self)
+ self.result_label.setGeometry(50, 250, 200, 50)
+ self.result_label.setStyleSheet("font-size: 18px;")
+
+ clear_button = QPushButton("Clear", self)
+ clear_button.setGeometry(50, 300, 100, 50)
+ clear_button.clicked.connect(self.paint_board.clear)
+
+ predict_button = QPushButton("Predict", self)
+ predict_button.setGeometry(150, 300, 100, 50)
+ predict_button.clicked.connect(self.predict)
+
+        # Create a vertical layout and add the widgets to it
+ layout = QVBoxLayout()
+ layout.addWidget(self.paint_board)
+ layout.addWidget(self.result_label)
+ layout.addWidget(clear_button)
+ layout.addWidget(predict_button)
+        # Create a QWidget as the main container and set the vertical layout on it
+ container = QWidget()
+ container.setLayout(layout)
+ self.setCentralWidget(container)
+
+ def predict(self):
+ try:
+ print("Starting prediction...")
+ image = self.paint_board.getContentAsQImage()
+ pil_image = ImageQt.fromqimage(image)
+            pil_image = pil_image.resize((28, 28), Image.LANCZOS)  # Image.ANTIALIAS is deprecated; LANCZOS is the same filter
+ pil_image = pil_image.convert('L')
+
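+            # Reshape to (batch, channel, height, width) = (1, 1, 28, 28) and scale pixels
+            # to [0, 1], matching the format of the MNIST training data.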
+ img_array = np.array(pil_image).reshape(1, 1, 28, 28).astype(np.float32) / 255.0
+ img_tensor = torch.tensor(img_array).to(device)
+
+ with torch.no_grad():
+ output = self.model(img_tensor)
+ prediction = torch.argmax(output, dim=1).item()
+
+ self.result_label.setText(f"Prediction: {prediction}")
+ print(f"Prediction: {prediction}")
+ except Exception as e:
+ print(f"Error during prediction: {e}")
+
+
+if __name__ == "__main__":
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
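+    # `device` and `model` are module-level globals here; MainWindow.predict reads `device` directly.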
+ model = CNN().to(device)
+ model.load_state_dict(torch.load('model.pth', map_location=device))
+
+ app = QApplication(sys.argv)
+ main_window = MainWindow(model)
+ main_window.show()
+ sys.exit(app.exec_())
diff --git a/识别/basic.py b/识别/basic.py
new file mode 100644
index 0000000..23c81d6
--- /dev/null
+++ b/识别/basic.py
@@ -0,0 +1,105 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+import torchvision
+import matplotlib.pyplot as plt
+import numpy as np
+
+input_size = 28
+num_classes = 10
+num_epochs = 40
+batch_size = 64
+
+# Use CUDA if it is available, otherwise fall back to the CPU
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+train_dataset = torchvision.datasets.MNIST(
+ root='./data', train=True, transform=torchvision.transforms.ToTensor(), download=True
+)
+test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=torchvision.transforms.ToTensor())
+
+train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
+ batch_size=batch_size,
+ shuffle=True)
+test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
+ batch_size=batch_size,
+ shuffle=True)
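+# Shuffling the test loader is not required for evaluation, but it does not affect the accuracy numbers.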
+
+# The CNN class below defines the neural network model.
+# in_channels: number of input channels; MNIST images are grayscale, so the first conv layer uses in_channels=1.
+# out_channels: number of output feature maps, i.e. the number of convolution kernels; each kernel produces one feature map.
+# kernel_size: size of the convolution kernel, e.g. kernel_size=5 means a 5x5 kernel.
+# stride: step size of the kernel; set to 1 here, so the kernel moves one pixel at a time.
+# padding: number of pixels of zero-padding; set to 2 here, so 2 pixels are added around the input.
+class CNN(nn.Module):
+ def __init__(self):
+ super(CNN, self).__init__()
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(1, 16, 5, 1, 2),
+ nn.ReLU(),
+ nn.MaxPool2d(2)
+ )
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(16, 32, 5, 1, 2),
+ nn.ReLU(),
+ nn.MaxPool2d(2)
+ )
+ self.conv3 = nn.Sequential(
+ nn.Conv2d(32, 64, 5, 1, 2),
+ nn.ReLU(),
+ nn.MaxPool2d(2)
+ )
+ self.conv4 = nn.Sequential(
+ nn.Conv2d(64, 32, 5, 1, 2),
+ nn.ReLU(),
+ nn.MaxPool2d(2)
+ )
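+        # Feature-map size after each block ("same" convolutions, 2x2 max-pools): 28 -> 14 -> 7 -> 3 -> 1,
+        # so the flattened vector has 32 channels * 1 * 1 = 32 features, matching the Linear layer below.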
+        self.out = nn.Linear(32, num_classes)  # maps the 32 flattened features to the 10 digit classes
+
+    def forward(self, x):
+        # The model and its inputs are already moved to `device` by the caller,
+        # so the per-layer .to(device) calls are unnecessary.
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.conv3(x)
+        x = self.conv4(x)
+        x = x.view(x.size(0), -1)  # flatten to (batch, 32)
+        x = self.out(x)
+        return x
+
+def accuracy(prediction, label):
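+    # Returns (number of correct predictions, number of samples) for one batch.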
+ pred = torch.max(prediction.data, 1)[1]
+ rights = pred.eq(label.data.view_as(pred)).sum()
+ return rights, len(label)
+
+net = CNN().to(device)  # move the model to the selected device
+criterion = nn.CrossEntropyLoss().to(device)  # move the loss function to the selected device
+optimizer = optim.Adam(net.parameters(), lr=0.001)
+
+tr = []  # per-batch training losses, collected for plotting
+for epoch in range(num_epochs):
+    train_rights = []
+    for batch_id, (data, target) in enumerate(train_loader):
+        data, target = data.to(device), target.to(device)  # move the batch to the selected device
+        optimizer.zero_grad()
+        output = net(data)
+        loss = criterion(output, target)
+        loss.backward()
+        optimizer.step()
+        right = accuracy(output, target)
+        train_rights.append(right)
+        tr.append(loss.item())  # store a plain float so the computation graph is not kept alive
+        if batch_id % 100 == 0:
+            val_right = []
+            with torch.no_grad():  # gradients are not needed for evaluation
+                for (data, target) in test_loader:
+                    data, target = data.to(device), target.to(device)  # move the batch to the selected device
+                    output = net(data)
+                    right = accuracy(output, target)
+                    val_right.append(right)
+            train_r = (sum(t[0] for t in train_rights), sum(t[1] for t in train_rights))
+            val_r = (sum(t[0] for t in val_right), sum(t[1] for t in val_right))
+            print("train_acc {}, test_acc {}".format(train_r[0] / train_r[1], val_r[0] / val_r[1]))
+
+# tr already holds plain floats, so it can be plotted directly without moving anything off the GPU
+plt.plot(range(len(tr)), tr)
+plt.show()
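+# Save the trained weights; another-main.py loads this file for inference.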
+torch.save(net.state_dict(), 'model.pth')
diff --git a/识别/data/MNIST/raw/t10k-images-idx3-ubyte b/识别/data/MNIST/raw/t10k-images-idx3-ubyte
new file mode 100644
index 0000000..1170b2c
Binary files /dev/null and b/识别/data/MNIST/raw/t10k-images-idx3-ubyte differ
diff --git a/识别/data/MNIST/raw/t10k-images-idx3-ubyte.gz b/识别/data/MNIST/raw/t10k-images-idx3-ubyte.gz
new file mode 100644
index 0000000..5ace8ea
Binary files /dev/null and b/识别/data/MNIST/raw/t10k-images-idx3-ubyte.gz differ
diff --git a/识别/data/MNIST/raw/t10k-labels-idx1-ubyte b/识别/data/MNIST/raw/t10k-labels-idx1-ubyte
new file mode 100644
index 0000000..d1c3a97
Binary files /dev/null and b/识别/data/MNIST/raw/t10k-labels-idx1-ubyte differ
diff --git a/识别/data/MNIST/raw/t10k-labels-idx1-ubyte.gz b/识别/data/MNIST/raw/t10k-labels-idx1-ubyte.gz
new file mode 100644
index 0000000..a7e1415
Binary files /dev/null and b/识别/data/MNIST/raw/t10k-labels-idx1-ubyte.gz differ
diff --git a/识别/data/MNIST/raw/train-images-idx3-ubyte b/识别/data/MNIST/raw/train-images-idx3-ubyte
new file mode 100644
index 0000000..bbce276
Binary files /dev/null and b/识别/data/MNIST/raw/train-images-idx3-ubyte differ
diff --git a/识别/data/MNIST/raw/train-images-idx3-ubyte.gz b/识别/data/MNIST/raw/train-images-idx3-ubyte.gz
new file mode 100644
index 0000000..b50e4b6
Binary files /dev/null and b/识别/data/MNIST/raw/train-images-idx3-ubyte.gz differ
diff --git a/识别/data/MNIST/raw/train-labels-idx1-ubyte b/识别/data/MNIST/raw/train-labels-idx1-ubyte
new file mode 100644
index 0000000..d6b4c5d
Binary files /dev/null and b/识别/data/MNIST/raw/train-labels-idx1-ubyte differ
diff --git a/识别/data/MNIST/raw/train-labels-idx1-ubyte.gz b/识别/data/MNIST/raw/train-labels-idx1-ubyte.gz
new file mode 100644
index 0000000..707a576
Binary files /dev/null and b/识别/data/MNIST/raw/train-labels-idx1-ubyte.gz differ
diff --git a/识别/mm.py b/识别/mm.py
new file mode 100644
index 0000000..bb75439
--- /dev/null
+++ b/识别/mm.py
@@ -0,0 +1,42 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+import torchvision
+import matplotlib.pyplot as plt
+import numpy as np
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
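+# This CNN must stay identical to the architecture trained in basic.py, because
+# another-main.py loads model.pth into it via `from mm import CNN`.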
+class CNN(nn.Module):
+ def __init__(self):
+ super(CNN, self).__init__()
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(1, 16, 5, 1, 2),
+ nn.ReLU(),
+ nn.MaxPool2d(2)
+ )
+ self.conv2 = nn.Sequential(
+ nn.Conv2d(16, 32, 5, 1, 2),
+ nn.ReLU(),
+ nn.MaxPool2d(2)
+ )
+ self.conv3 = nn.Sequential(
+ nn.Conv2d(32, 64, 5, 1, 2),
+ nn.ReLU(),
+ nn.MaxPool2d(2)
+ )
+ self.conv4 = nn.Sequential(
+ nn.Conv2d(64, 32, 5, 1, 2),
+ nn.ReLU(),
+ nn.MaxPool2d(2)
+ )
+        self.out = nn.Linear(32, 10)  # maps the 32 flattened features to the 10 digit classes
+
+    def forward(self, x):
+        # The caller moves the model and its inputs to `device`, so per-layer .to(device) calls are unnecessary.
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.conv3(x)
+        x = self.conv4(x)
+        x = x.view(x.size(0), -1)  # flatten to (batch, 32)
+        x = self.out(x)
+        return x
\ No newline at end of file
diff --git a/识别/model.pth b/识别/model.pth
new file mode 100644
index 0000000..44d47e2
Binary files /dev/null and b/识别/model.pth differ
diff --git a/识别/requirements.txt b/识别/requirements.txt
new file mode 100644
index 0000000..c01cfab
--- /dev/null
+++ b/识别/requirements.txt
@@ -0,0 +1,28 @@
+certifi==2024.6.2
+charset-normalizer==3.3.2
+contourpy==1.2.1
+cycler==0.12.1
+filelock==3.15.1
+fonttools==4.53.0
+idna==3.7
+Jinja2==3.1.4
+kiwisolver==1.4.5
+MarkupSafe==2.1.5
+matplotlib==3.6.3
+mpmath==1.3.0
+networkx==3.2.1
+numpy==1.22.4
+packaging==24.1
+Pillow==9.1.0
+pyparsing==3.1.2
+PyQt5==5.15.6
+PyQt5-Qt5==5.15.2
+PyQt5-sip==12.13.0
+python-dateutil==2.9.0.post0
+requests==2.32.3
+six==1.16.0
+sympy==1.12.1
+torch==1.13.1
+torchvision==0.14.0
+typing_extensions==4.12.2
+urllib3==2.2.1