import matplotlib import torch from torch.utils.data import DataLoader from torchvision import transforms from torchvision.datasets import MNIST import matplotlib.pyplot as plt matplotlib.use('TkAgg') # 神经网络类 主体 class Net(torch.nn.Module): def __init__(self): super().__init__() # 四个全连接层 self.fc1 = torch.nn.Linear(28 * 28, 64) # 输入28*28像素的图像 64个节点 self.fc2 = torch.nn.Linear(64, 64) self.fc3 = torch.nn.Linear(64, 64) self.fc4 = torch.nn.Linear(64, 10) # 输出十个数字类别 def forward(self, x): # x:图像输入 # 每层首先做全连接线性计算 self.fc1 再套上激活函数torch.nn.functional.relu x = torch.nn.functional.relu(self.fc1(x)) x = torch.nn.functional.relu(self.fc2(x)) x = torch.nn.functional.relu(self.fc3(x)) # 输出层通过softmax进行归一化 x = torch.nn.functional.log_softmax(self.fc4(x), dim=1) return x # 下载MNIST数据集 导入数据 def get_data_loader(is_train): # tensor 多维数组 张量 to_tensor = transforms.Compose([transforms.ToTensor()]) """ 第一个参数:''标识数据集下载位置 空表示下载到当前目录 is_train:指定导入训练集 """ data_set = MNIST('', is_train, transform=to_tensor, download=True) # batch_size:一个批次包含15张图片 shuffle=True:数据随机打乱 return DataLoader(data_set, batch_size=15, shuffle=True) # 返回数据加载器DataLoader() # 评估神经网络识别正确率 def evaluate(test_data, net): # 正确预测数量 n_correct = 0 # 预测总数 n_total = 0 with torch.no_grad(): # 从测试集依次取出数据 for (x, y) in test_data: # 计算神经网络预测值 outputs = net.forward(x.view(-1, 28 * 28)) # 对批次结果进行比较 for i, output in enumerate(outputs): # 累加正确预测的数量 argmax 计算数列中最大结果,即预测的手写数字结果 if torch.argmax(output) == y[i]: n_correct += 1 n_total += 1 # 返回预测正确率 return n_correct / n_total def main(): # 导入训练集 train_data = get_data_loader(is_train=True) # 导入测试集 test_data = get_data_loader(is_train=False) # 初始化神经网络 net = Net() # 打印最开始的时候的正确率 约为0.1左右 print('初始神经网络正确率:', evaluate(test_data, net)) optimizer = torch.optim.Adam(net.parameters(), lr=0.001) # epoch:轮次 反复使用train_data进行训练神经网络,提高训练集利用率 for epoch in range(2): for (x, y) in train_data: net.zero_grad() # 初始化 output = net.forward(x.view(-1, 28 * 28)) # 正向传播 # nll_lose 对数损失函数 与前的softmax对数计算进行对应 loss = torch.nn.functional.nll_loss(output, y) # 计算差值 loss.backward() # 反向误差传播 optimizer.step() # 优化网络参数 print('第', epoch+1, '次训练后准确率:', evaluate(test_data, net)) # 随机抽取三张图象显示网络预测结果 for (n, (x, _)) in enumerate(test_data): if n > 2: break predict = torch.argmax(net.forward(x[0].view(-1, 28 * 28))) plt.figure(n) plt.imshow(x[0].view(28, 28)) plt.title('prediction:' + str(int(predict))) plt.show() if __name__ == '__main__': main()