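"""Minimal LSTM classifier for MNIST in PyTorch.

Each 28x28 image is fed to the LSTM as a sequence of 28 rows with 28 features
per step; the final hidden state of the last layer goes through a linear
classifier over the 10 digit classes. Trained with SGD for 100 epochs.
"""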
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms

class Rnn(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_layer, n_classes):  # (input dim, hidden dim, number of LSTM layers, number of classes)
        super(Rnn, self).__init__()
        # store hyperparameters on self
        self.n_layer = n_layer
        self.hidden_dim = hidden_dim
        self.lstm = nn.LSTM(in_dim, hidden_dim, n_layer, batch_first=True)  # LSTM(input size, hidden size, num layers); batch dimension first
        self.classifier = nn.Linear(hidden_dim, n_classes)  # (input features, output features)

    def forward(self, x):
        out, (h_n, c_n) = self.lstm(x)
        # take the final hidden state of the last LSTM layer
        x = h_n[-1, :, :]
        # classify and return the result
        x = self.classifier(x)
        return x
# data preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),               # convert the PIL image / numpy array to a tensor
    transforms.Normalize([0.5], [0.5]),  # normalize the tensor
])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)  # download the training set to ./data
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)  # load the training set with batch size 128
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)  # download the test set
testloader = torch.utils.data.DataLoader(testset, batch_size=512, shuffle=False)  # load the test set with batch size 512

# check whether CUDA is available
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

net = Rnn(28, 10, 2, 10)  # build the network: 28 input features, 10 hidden units, 2 LSTM layers, 10 classes
net = net.to(device)      # move the network to the device
criterion = nn.CrossEntropyLoss()  # use cross-entropy loss for this training run
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)  # SGD optimizer with momentum
# Training
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()  # switch to training mode
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)  # move the batch to the device
        optimizer.zero_grad()  # clear accumulated gradients
        outputs = net(torch.squeeze(inputs, 1))  # drop the channel dim: (B, 1, 28, 28) -> (B, 28, 28), the shape the LSTM expects
        loss = criterion(outputs, targets)  # compute the loss
        loss.backward()   # backpropagate
        optimizer.step()  # update the parameters
        train_loss += loss.item()  # accumulate the loss
        _, predicted = outputs.max(1)  # predicted class = argmax along dimension 1 (the class dimension)
        total += targets.size(0)  # accumulate the number of samples seen
        correct += predicted.eq(targets).sum().item()  # count correct predictions
        print(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
              % (train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
def test(epoch):
    net.eval()  # switch to evaluation mode
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(torch.squeeze(inputs, 1))
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            print(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                  % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))
print(net)
for epoch in range(100):
    train(epoch)
    test(epoch)