You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
92 lines
2.7 KiB
92 lines
2.7 KiB
import torch
|
|
from torch import nn
|
|
from torch.nn import functional as F
|
|
from torch import optim
|
|
|
|
import torchvision
|
|
from matplotlib import pyplot as plt
|
|
import utils
|
|
|
|
def plot_curve(data):
    """Plot a sequence of values against its step index and show the figure.

    Args:
        data: sequence of numbers (for example the per-step training losses);
            element ``k`` is drawn at x position ``k``.
    """
    # Open a new figure so this curve is drawn on a fresh canvas.
    plt.figure()

    steps = range(len(data))
    plt.plot(steps, data, color='blue')

    # Decorations: legend in the upper-right corner plus axis labels.
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')

    plt.show()
|
|
|
|
|
|
def plot_image(img, label, name, mean=0.1307, std=0.3081):
    """Show up to six images of a batch in a 2x3 grid, each titled with its label.

    Args:
        img: batch indexed as ``img[i, 0]``; presumably an (N, 1, H, W) tensor
            of normalized images (TODO confirm at callers).
        label: per-image values; ``label[i].item()`` is shown in each title.
        name: text prefix for every subplot title.
        mean: mean that was subtracted during normalization; added back for display.
        std: standard deviation that was divided out during normalization.
    """
    plt.figure()
    # A short (e.g. final) batch may contain fewer than six images.
    count = min(6, len(img))
    for i in range(count):
        plt.subplot(2, 3, i + 1)
        # Invert x_norm = (x - mean) / std so pixels are shown on their original scale.
        plt.imshow(img[i, 0] * std + mean, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    # Lay out once, after every subplot and title exists, so all titles are
    # taken into account (calling it inside the loop missed the last title).
    plt.tight_layout()
    plt.show()
|
|
|
|
def one_hot(label, depth=10):
    """Encode integer class indices as one-hot float rows.

    Args:
        label: tensor of class indices in ``[0, depth)``; any integer dtype is
            accepted and it is flattened to one index per row.
        depth: number of classes, i.e. the width of each one-hot row.

    Returns:
        A float32 tensor of shape ``(num_labels, depth)`` on the same device
        as ``label``, with a single 1.0 per row at the label's index.
    """
    # as_tensor converts any integer dtype to int64 (no copy when it already is);
    # the legacy torch.LongTensor(t) constructor raised on e.g. int32 input.
    # reshape (unlike view) also copes with non-contiguous input.
    idx = torch.as_tensor(label, dtype=torch.long).reshape(-1, 1)
    # Allocate the result next to the indices so scatter_ never mixes devices.
    out = torch.zeros(idx.size(0), depth, device=idx.device)
    out.scatter_(dim=1, index=idx, value=1)
    return out
|
|
|
|
batch_size = 512

# Both MNIST splits live under one root directory and share one preprocessing
# pipeline: ToTensor(), then Normalize with mean 0.1307 / std 0.3081 (the same
# constants plot_image uses to un-normalize images for display).
mnist_root = 'mnist_data'
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(mnist_root, train=True, download=True,
                               transform=mnist_transform),
    batch_size=batch_size,
    shuffle=True,
)

# Shuffling the test split does not affect the accuracy computed over all of
# its batches; it only makes the final visualized batch a random one.
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(mnist_root, train=False, download=True,
                               transform=mnist_transform),
    batch_size=batch_size,
    shuffle=True,
)

# Sanity check on one training batch: tensor shapes and post-normalization range.
x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, 'image_sample')
|
|
|
|
class Net(nn.Module):
    """Fully connected classifier: 28*28 -> 256 -> 64 -> 10.

    Expects flattened inputs with 28*28 features per sample and returns 10
    raw scores per sample (no softmax is applied).
    """

    def __init__(self):
        super().__init__()
        # Created in the order fc1, fc2, fc3 so parameter registration and
        # random initialization happen in the same sequence.
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return unnormalized class scores for a batch of flattened images."""
        # The two hidden layers are each followed by ReLU; the head is linear.
        for hidden_layer in (self.fc1, self.fc2):
            x = F.relu(hidden_layer(x))
        return self.fc3(x)
|
|
|
|
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
train_loss = []  # one scalar loss per optimization step, plotted after training

for epoch in range(3):
    for batch_idx, (x, y) in enumerate(train_loader):
        # Flatten each sample into 28*28 features for the fully connected net.
        x = x.view(x.size(0), 28 * 28)
        out = net(x)

        # Regress the raw scores onto one-hot targets with mean squared error.
        y_onehot = one_hot(y)
        loss = F.mse_loss(out, y_onehot)

        # Clear stale gradients, backpropagate, then update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        step_loss = loss.item()
        train_loss.append(step_loss)
        if not batch_idx % 10:
            print(epoch, batch_idx, step_loss)

plot_curve(train_loss)
|
|
|
|
# Accuracy over the whole test set. No gradients are needed here, so autograd
# is disabled to avoid building and holding a graph for every test batch.
total_correct = 0
with torch.no_grad():
    for x, y in test_loader:
        x = x.view(x.size(0), 28 * 28)
        out = net(x)
        pred = out.argmax(dim=1)  # index of the highest score per sample
        correct = pred.eq(y).sum().float().item()
        total_correct += correct

total_num = len(test_loader.dataset)
acc = total_correct / total_num
print("test acc:", acc)
|
|
|
|
# Show the model's predictions for one test batch; inference only, so run the
# forward pass without recording an autograd graph.
x, y = next(iter(test_loader))
with torch.no_grad():
    out = net(x.view(x.size(0), 28 * 28))
pred = out.argmax(dim=1)
plot_image(x, pred, 'test')