master
ptjlmfa8p 4 years ago
parent ea70ba1517
commit 8ce42bb1b2

@ -0,0 +1,102 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms
class Rnn(nn.Module):
def __init__(self, in_dim, hidden_dim, n_layer, n_classes): # (输入层维度,隐藏层维度,循环网络个数,分类个数)
super(Rnn, self).__init__()
# 存入self
self.n_layer = n_layer
self.hidden_dim = hidden_dim
self.lstm = nn.LSTM(in_dim, hidden_dim, n_layer, batch_first=True) # LSTM输入层隐藏层网络个数批数量置于第一位
self.classifier = nn.Linear(hidden_dim, n_classes) # (输入大小,输出大小)
def forward(self, x):
out, (h_n, c_n) = self.lstm(x)
# 取得最后一层隐藏层h_t
x = h_n[-1, :, :]
# 进行分类,并返回结果
x = self.classifier(x)
return x
# 数据预处理模块
transform = transforms.Compose([
transforms.ToTensor(), # 将numpy转换为tensor
transforms.Normalize([0.5], [0.5]), # 将tensor正则化
])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform) # 设置数据集下载地址并下载训练集
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True) # 载入训练集设置批大小为128
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform) # 下载测试集
testloader = torch.utils.data.DataLoader(testset, batch_size=512, shuffle=False) # 设置测试集批大小为512
# 对用户是否能够使用cuda进行判断
if torch.cuda.is_available():
device = torch.device('cuda:0')
else:
device = torch.device('cpu')
net = Rnn(28, 10, 2, 10) # 构造网络
net = net.to(device) # 将网络载入设备
criterion = nn.CrossEntropyLoss() # 使用CEL作为此次训练的损失函数
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9) # 构造SGD随机梯度下降优化器
# Training
def train(epoch):
print('\nEpoch: %d' % epoch)
net.train() # 设置为train模式
train_loss = 0
correct = 0
total = 0
for batch_idx, (inputs, targets) in enumerate(trainloader):
inputs, targets = inputs.to(device), targets.to(device) # 将数据集取出并载入设备
optimizer.zero_grad() # 将优化器导数清零
outputs = net(torch.squeeze(inputs, 1)) # 将数据集改造成网络输入要求形式
loss = criterion(outputs, targets) # 使用损失函数计算loss
loss.backward() # 反向传播
optimizer.step() # 优化
train_loss += loss.item() # 累加loss数值
_, predicted = outputs.max(1) # 取出当前网络在第0维度输出的最高值
total += targets.size(0) # 获取当前所有训练数据的label
correct += predicted.eq(targets).sum().item() # 计算错误
print(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
% (train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
def test(epoch):
global best_acc
net.eval()
test_loss = 0
correct = 0
total = 0
with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(testloader):
inputs, targets = inputs.to(device), targets.to(device)
outputs = net(torch.squeeze(inputs, 1))
loss = criterion(outputs, targets)
test_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
print(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
% (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))
print(net)
for epoch in range(100):
train(epoch)
test(epoch)
Loading…
Cancel
Save