You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

96 lines
3.7 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import matplotlib
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
matplotlib.use('TkAgg')
# 神经网络类 主体
class Net(torch.nn.Module):
def __init__(self):
super().__init__()
# 四个全连接层
self.fc1 = torch.nn.Linear(28 * 28, 64) # 输入28*28像素的图像 64个节点
self.fc2 = torch.nn.Linear(64, 64)
self.fc3 = torch.nn.Linear(64, 64)
self.fc4 = torch.nn.Linear(64, 10) # 输出十个数字类别
def forward(self, x): # x图像输入
# 每层首先做全连接线性计算 self.fc1 再套上激活函数torch.nn.functional.relu
x = torch.nn.functional.relu(self.fc1(x))
x = torch.nn.functional.relu(self.fc2(x))
x = torch.nn.functional.relu(self.fc3(x))
# 输出层通过softmax进行归一化
x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)
return x
# 下载MNIST数据集 导入数据
def get_data_loader(is_train):
# tensor 多维数组 张量
to_tensor = transforms.Compose([transforms.ToTensor()])
"""
第一个参数:''标识数据集下载位置 空表示下载到当前目录
is_train:指定导入训练集
"""
data_set = MNIST('', is_train, transform=to_tensor, download=True)
# batch_size一个批次包含15张图片 shuffle=True数据随机打乱
return DataLoader(data_set, batch_size=15, shuffle=True) # 返回数据加载器DataLoader
# 评估神经网络识别正确率
def evaluate(test_data, net):
# 正确预测数量
n_correct = 0
# 预测总数
n_total = 0
with torch.no_grad():
# 从测试集依次取出数据
for (x, y) in test_data:
# 计算神经网络预测值
outputs = net.forward(x.view(-1, 28 * 28))
# 对批次结果进行比较
for i, output in enumerate(outputs):
# 累加正确预测的数量 argmax 计算数列中最大结果,即预测的手写数字结果
if torch.argmax(output) == y[i]:
n_correct += 1
n_total += 1
# 返回预测正确率
return n_correct / n_total
def main():
# 导入训练集
train_data = get_data_loader(is_train=True)
# 导入测试集
test_data = get_data_loader(is_train=False)
# 初始化神经网络
net = Net()
# 打印最开始的时候的正确率 约为0.1左右
print('初始神经网络正确率:', evaluate(test_data, net))
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
# epoch轮次 反复使用train_data进行训练神经网络提高训练集利用率
for epoch in range(2):
for (x, y) in train_data:
net.zero_grad() # 初始化
output = net.forward(x.view(-1, 28 * 28)) # 正向传播
# nll_lose 对数损失函数 与前的softmax对数计算进行对应
loss = torch.nn.functional.nll_loss(output, y) # 计算差值
loss.backward() # 反向误差传播
optimizer.step() # 优化网络参数
print('', epoch+1, '次训练后准确率:', evaluate(test_data, net))
# 随机抽取三张图象显示网络预测结果
for (n, (x, _)) in enumerate(test_data):
if n > 2:
break
predict = torch.argmax(net.forward(x[0].view(-1, 28 * 28)))
plt.figure(n)
plt.imshow(x[0].view(28, 28))
plt.title('prediction:' + str(int(predict)))
plt.show()
if __name__ == '__main__':
main()