diff --git "a/C:\\Users\\zzb\\Desktop\\识别" "b/C:\\Users\\zzb\\Desktop\\识别" deleted file mode 100644 index e69de29..0000000 diff --git a/README.md b/README.md deleted file mode 100644 index 0116b1f..0000000 --- a/README.md +++ /dev/null @@ -1,2 +0,0 @@ -# written_number_classfication - diff --git a/识别/.idea/.gitignore b/识别/.idea/.gitignore new file mode 100644 index 0000000..35410ca --- /dev/null +++ b/识别/.idea/.gitignore @@ -0,0 +1,8 @@ +# 默认忽略的文件 +/shelf/ +/workspace.xml +# 基于编辑器的 HTTP 客户端请求 +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/识别/.idea/inspectionProfiles/Project_Default.xml b/识别/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..63a980c --- /dev/null +++ b/识别/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,12 @@ + + + + \ No newline at end of file diff --git a/识别/.idea/inspectionProfiles/profiles_settings.xml b/识别/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/识别/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/识别/.idea/misc.xml b/识别/.idea/misc.xml new file mode 100644 index 0000000..33a23b2 --- /dev/null +++ b/识别/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/识别/.idea/modules.xml b/识别/.idea/modules.xml new file mode 100644 index 0000000..550b8d6 --- /dev/null +++ b/识别/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/识别/.idea/识别.iml b/识别/.idea/识别.iml new file mode 100644 index 0000000..140eb74 --- /dev/null +++ b/识别/.idea/识别.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/识别/__pycache__/mm.cpython-311.pyc b/识别/__pycache__/mm.cpython-311.pyc new file mode 100644 index 0000000..954b13e Binary files /dev/null and b/识别/__pycache__/mm.cpython-311.pyc differ diff --git a/识别/__pycache__/mm.cpython-312.pyc b/识别/__pycache__/mm.cpython-312.pyc new file mode 100644 index 0000000..5202aa1 Binary files /dev/null and b/识别/__pycache__/mm.cpython-312.pyc differ diff --git a/识别/another-main.py b/识别/another-main.py new file mode 100644 index 0000000..bba7893 --- /dev/null +++ b/识别/another-main.py @@ -0,0 +1,110 @@ +import sys +import numpy as np +import torch +from PIL import Image, ImageQt +from PyQt5.QtWidgets import QApplication, QMainWindow, QVBoxLayout, QWidget, QPushButton, QLabel +from PyQt5.QtGui import QPainter, QPen, QImage, QPixmap, QColor +from PyQt5.QtCore import Qt, QPoint +from mm import CNN + + +class PaintBoard(QWidget): + def __init__(self): + super().__init__() + self.image = QImage(224, 224, QImage.Format_RGB32) + self.image.fill(Qt.black) + self.last_point = QPoint() + self.initUI() + + def initUI(self): + self.setGeometry(0, 0, 224, 224) + self.setFixedSize(224, 224) + self.setStyleSheet("background-color: black;") + + def clear(self): + self.image.fill(Qt.black) + self.update() + + def getContentAsQImage(self): + return self.image + + def mousePressEvent(self, event): + if event.button() == Qt.LeftButton: + self.last_point = event.pos() + + def mouseMoveEvent(self, event): + if event.buttons() & Qt.LeftButton: + painter = QPainter(self.image) + pen = QPen(Qt.white, 20, Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin) + painter.setPen(pen) + painter.drawLine(self.last_point, event.pos()) + self.last_point = event.pos() + self.update() + + def paintEvent(self, event): + canvas_painter = QPainter(self) + canvas_painter.drawImage(self.rect(), self.image, self.image.rect()) + + +class MainWindow(QMainWindow): + def __init__(self, model): + super().__init__() + self.model = model + self.model.eval() + + self.setWindowTitle("Handwritten Digit Recognition") + self.setGeometry(100, 100, 300, 400) + self.paint_board = PaintBoard() + self.result_label = QLabel("Prediction: ", self) + self.result_label.setGeometry(50, 250, 200, 50) + self.result_label.setStyleSheet("font-size: 18px;") + + clear_button = QPushButton("Clear", self) + clear_button.setGeometry(50, 300, 100, 50) + clear_button.clicked.connect(self.paint_board.clear) + + predict_button = QPushButton("Predict", self) + predict_button.setGeometry(150, 300, 100, 50) + predict_button.clicked.connect(self.predict) + + #创建垂直布局管理器,把组件放到布局里展示出来 + layout = QVBoxLayout() + layout.addWidget(self.paint_board) + layout.addWidget(self.result_label) + layout.addWidget(clear_button) + layout.addWidget(predict_button) + #创建一个QWidget对象作为主容器,然后将垂直布局管理器放在这个主容器中 + container = QWidget() + container.setLayout(layout) + self.setCentralWidget(container) + + def predict(self): + try: + print("Starting prediction...") + image = self.paint_board.getContentAsQImage() + pil_image = ImageQt.fromqimage(image) + pil_image = pil_image.resize((28, 28), Image.ANTIALIAS) + pil_image = pil_image.convert('L') + + img_array = np.array(pil_image).reshape(1, 1, 28, 28).astype(np.float32) / 255.0 + img_tensor = torch.tensor(img_array).to(device) + + with torch.no_grad(): + output = self.model(img_tensor) + prediction = torch.argmax(output, dim=1).item() + + self.result_label.setText(f"Prediction: {prediction}") + print(f"Prediction: {prediction}") + except Exception as e: + print(f"Error during prediction: {e}") + + +if __name__ == "__main__": + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = CNN().to(device) + model.load_state_dict(torch.load('model.pth', map_location=device)) + + app = QApplication(sys.argv) + main_window = MainWindow(model) + main_window.show() + sys.exit(app.exec_()) diff --git a/识别/basic.py b/识别/basic.py new file mode 100644 index 0000000..23c81d6 --- /dev/null +++ b/识别/basic.py @@ -0,0 +1,105 @@ +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +import torchvision +import matplotlib.pyplot as plt +import numpy as np + +input_size = 28 +num_classes = 10 +num_epochs = 40 +batch_size = 64 + +# 确保 CUDA 可用并且使用它 +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +train_dataset = torchvision.datasets.MNIST( + root='./data', train=True, transform=torchvision.transforms.ToTensor(), download=True +) +test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=torchvision.transforms.ToTensor()) + +train_loader = torch.utils.data.DataLoader(dataset=train_dataset, + batch_size=batch_size, + shuffle=True) +test_loader = torch.utils.data.DataLoader(dataset=test_dataset, + batch_size=batch_size, + shuffle=True) + +#定义了一个CNN类作为神经网络模型 +# in_channels:输入数据的通道数。在第一层卷积中,输入的MNIST图像是灰度图像,因此 in_channels=1。 +# out_channels:输出的特征图通道数,即卷积核的数量。每个卷积核生成一个特征图。 +# kernel_size:卷积核的大小。例如,kernel_size=5 表示卷积核大小为5x5。 +# stride:卷积核的步幅。在代码中设置为1,表示卷积核每次移动1个像素。 +# padding:填充的像素数。在代码中设置为2,表示在输入数据周围填充2个像素。 +class CNN(nn.Module): + def __init__(self): + super(CNN, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(1, 16, 5, 1, 2), + nn.ReLU(), + nn.MaxPool2d(2) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(16, 32, 5, 1, 2), + nn.ReLU(), + nn.MaxPool2d(2) + ) + self.conv3 = nn.Sequential( + nn.Conv2d(32, 64, 5, 1, 2), + nn.ReLU(), + nn.MaxPool2d(2) + ) + self.conv4 = nn.Sequential( + nn.Conv2d(64, 32, 5, 1, 2), + nn.ReLU(), + nn.MaxPool2d(2) + ) + self.out = nn.Linear(32, num_classes) # 确保输出类别数正确 + + def forward(self, x): + x = self.conv1(x).to(device) # 将数据迁移到GPU + x = self.conv2(x).to(device) + x=self.conv3(x).to(device) + x = self.conv4(x).to(device) + x = x.view(x.size(0), -1).to(device) + x = self.out(x).to(device) + return x + +def accuracy(prediction, label): + pred = torch.max(prediction.data, 1)[1] + rights = pred.eq(label.data.view_as(pred)).sum() + return rights, len(label) + +net = CNN().to(device) # 将模型迁移到GPU +criterion = nn.CrossEntropyLoss().to(device) # 将损失函数迁移到GPU +optimizer = optim.Adam(net.parameters(), lr=0.001) + +tr = [] +for epoch in range(40): + train_rights = [] + for batch_id, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) # 将数据迁移到GPU + optimizer.zero_grad() + output = net(data) + loss = criterion(output, target) + loss.backward() + optimizer.step() + right = accuracy(output, target) + train_rights.append(right) + tr.append(loss) + if batch_id % 100 == 0: + val_right = [] + for (data, target) in test_loader: + data, target = data.to(device), target.to(device) # 将数据迁移到GPU + output = net(data) + right = accuracy(output, target) + val_right.append(right) + train_r = (sum(t[0] for t in train_rights), sum(t[1] for t in train_rights)) + val_r = (sum(t[0] for t in val_right), sum(t[1] for t in val_right)) + print("train_acc {}, test_acc {}".format(train_r[0] / train_r[1], val_r[0] / val_r[1])) + +# 注意,绘图时不需要将数据迁移到GPU +plt.plot(list(range(len([tensor.detach().cpu().numpy() for tensor in tr]))), [tensor.detach().cpu().numpy() for tensor in tr]) +plt.show() +torch.save(net.state_dict(), 'model.pth') diff --git a/识别/data/MNIST/raw/t10k-images-idx3-ubyte b/识别/data/MNIST/raw/t10k-images-idx3-ubyte new file mode 100644 index 0000000..1170b2c Binary files /dev/null and b/识别/data/MNIST/raw/t10k-images-idx3-ubyte differ diff --git a/识别/data/MNIST/raw/t10k-images-idx3-ubyte.gz b/识别/data/MNIST/raw/t10k-images-idx3-ubyte.gz new file mode 100644 index 0000000..5ace8ea Binary files /dev/null and b/识别/data/MNIST/raw/t10k-images-idx3-ubyte.gz differ diff --git a/识别/data/MNIST/raw/t10k-labels-idx1-ubyte b/识别/data/MNIST/raw/t10k-labels-idx1-ubyte new file mode 100644 index 0000000..d1c3a97 Binary files /dev/null and b/识别/data/MNIST/raw/t10k-labels-idx1-ubyte differ diff --git a/识别/data/MNIST/raw/t10k-labels-idx1-ubyte.gz b/识别/data/MNIST/raw/t10k-labels-idx1-ubyte.gz new file mode 100644 index 0000000..a7e1415 Binary files /dev/null and b/识别/data/MNIST/raw/t10k-labels-idx1-ubyte.gz differ diff --git a/识别/data/MNIST/raw/train-images-idx3-ubyte b/识别/data/MNIST/raw/train-images-idx3-ubyte new file mode 100644 index 0000000..bbce276 Binary files /dev/null and b/识别/data/MNIST/raw/train-images-idx3-ubyte differ diff --git a/识别/data/MNIST/raw/train-images-idx3-ubyte.gz b/识别/data/MNIST/raw/train-images-idx3-ubyte.gz new file mode 100644 index 0000000..b50e4b6 Binary files /dev/null and b/识别/data/MNIST/raw/train-images-idx3-ubyte.gz differ diff --git a/识别/data/MNIST/raw/train-labels-idx1-ubyte b/识别/data/MNIST/raw/train-labels-idx1-ubyte new file mode 100644 index 0000000..d6b4c5d Binary files /dev/null and b/识别/data/MNIST/raw/train-labels-idx1-ubyte differ diff --git a/识别/data/MNIST/raw/train-labels-idx1-ubyte.gz b/识别/data/MNIST/raw/train-labels-idx1-ubyte.gz new file mode 100644 index 0000000..707a576 Binary files /dev/null and b/识别/data/MNIST/raw/train-labels-idx1-ubyte.gz differ diff --git a/识别/mm.py b/识别/mm.py new file mode 100644 index 0000000..bb75439 --- /dev/null +++ b/识别/mm.py @@ -0,0 +1,42 @@ +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +import torchvision +import matplotlib.pyplot as plt +import numpy as np + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +class CNN(nn.Module): + def __init__(self): + super(CNN, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(1, 16, 5, 1, 2), + nn.ReLU(), + nn.MaxPool2d(2) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(16, 32, 5, 1, 2), + nn.ReLU(), + nn.MaxPool2d(2) + ) + self.conv3 = nn.Sequential( + nn.Conv2d(32, 64, 5, 1, 2), + nn.ReLU(), + nn.MaxPool2d(2) + ) + self.conv4 = nn.Sequential( + nn.Conv2d(64, 32, 5, 1, 2), + nn.ReLU(), + nn.MaxPool2d(2) + ) + self.out = nn.Linear(32, 10) # 确保输出类别数正确 + + def forward(self, x): + x = self.conv1(x).to(device) # 将数据迁移到GPU + x = self.conv2(x).to(device) + x=self.conv3(x).to(device) + x = self.conv4(x).to(device) + x = x.view(x.size(0), -1).to(device) + x = self.out(x).to(device) + return x \ No newline at end of file diff --git a/识别/model.pth b/识别/model.pth new file mode 100644 index 0000000..44d47e2 Binary files /dev/null and b/识别/model.pth differ diff --git a/识别/requirements.txt b/识别/requirements.txt new file mode 100644 index 0000000..c01cfab --- /dev/null +++ b/识别/requirements.txt @@ -0,0 +1,28 @@ +certifi==2024.6.2 +charset-normalizer==3.3.2 +contourpy==1.2.1 +cycler==0.12.1 +filelock==3.15.1 +fonttools==4.53.0 +idna==3.7 +Jinja2==3.1.4 +kiwisolver==1.4.5 +MarkupSafe==2.1.5 +matplotlib==3.6.3 +mpmath==1.3.0 +networkx==3.2.1 +numpy==1.22.4 +packaging==24.1 +Pillow==9.1.0 +pyparsing==3.1.2 +PyQt5==5.15.6 +PyQt5-Qt5==5.15.2 +PyQt5-sip==12.13.0 +python-dateutil==2.9.0.post0 +requests==2.32.3 +six==1.16.0 +sympy==1.12.1 +torch==1.13.1 +torchvision==0.14.0 +typing_extensions==4.12.2 +urllib3==2.2.1