import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms


class Rnn(nn.Module):
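    # LSTM classifier: each 28x28 MNIST image is treated as a sequence of 28 rows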
    def __init__(self, in_dim, hidden_dim, n_layer, n_classes):  # (input dimension, hidden dimension, number of LSTM layers, number of classes)
        super(Rnn, self).__init__()
        # keep the hyperparameters on self
        self.n_layer = n_layer
        self.hidden_dim = hidden_dim
        self.lstm = nn.LSTM(in_dim, hidden_dim, n_layer, batch_first=True)  # LSTM(input size, hidden size, num layers, batch dimension first)
        self.classifier = nn.Linear(hidden_dim, n_classes)  # Linear(in_features, out_features)

    def forward(self, x):
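        # x: (batch, seq_len=28, in_dim=28)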
        out, (h_n, c_n) = self.lstm(x)  # out: (batch, seq_len, hidden_dim); h_n: (n_layer, batch, hidden_dim)
        # take the hidden state of the last layer at the final timestep
        x = h_n[-1, :, :]
        # classify and return the logits
        x = self.classifier(x)
        return x


# data preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),               # convert the image to a tensor with values in [0, 1]
    transforms.Normalize([0.5], [0.5]),  # normalize the tensor to mean 0.5, std 0.5
])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)  # download the MNIST training set to ./data
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)  # training loader with batch size 128

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)  # download the test set
testloader = torch.utils.data.DataLoader(testset, batch_size=512, shuffle=False)  # test loader with batch size 512

# use CUDA if it is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

net = Rnn(28, 10, 2, 10)  # build the network: 28 inputs per timestep, hidden size 10, 2 LSTM layers, 10 classes
net = net.to(device)      # move the network onto the selected device

criterion = nn.CrossEntropyLoss()  # cross-entropy loss for classification
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)  # SGD optimizer with momentum


# Training
def train(epoch):
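    # one full pass over the training set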
    print('\nEpoch: %d' % epoch)
    net.train()  # switch to training mode
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)  # move the batch onto the selected device
        optimizer.zero_grad()  # clear the gradients from the previous step

        outputs = net(torch.squeeze(inputs, 1))  # drop the channel dim: (B, 1, 28, 28) -> (B, 28, 28), one image row per timestep

        loss = criterion(outputs, targets)  # compute the loss
        loss.backward()   # backpropagate
        optimizer.step()  # update the parameters

        train_loss += loss.item()  # accumulate the loss
        _, predicted = outputs.max(1)  # index of the largest logit along dim 1, i.e. the predicted class
        total += targets.size(0)  # number of samples seen so far
        correct += predicted.eq(targets).sum().item()  # number of correct predictions

        print(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
              % (train_loss / (batch_idx + 1), 100. * correct / total, correct, total))


def test(epoch):
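    # evaluate the current model on the test set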
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(torch.squeeze(inputs, 1))
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            print(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                  % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))


print(net)
for epoch in range(100):
    train(epoch)
    test(epoch)