parent
b88983bd58
commit
057eb31d26
@ -0,0 +1,92 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch import optim
|
||||
|
||||
import torchvision
|
||||
from matplotlib import pyplot as plt
|
||||
import utils
|
||||
|
||||
def plot_curve(data):
|
||||
fig = plt.figure()
|
||||
plt.plot(range(len(data)), data, color='blue')
|
||||
plt.legend(['value'], loc='upper right')
|
||||
plt.xlabel('step')
|
||||
plt.ylabel('value')
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_image(img, label, name):
|
||||
|
||||
fig = plt.figure()
|
||||
for i in range(6):
|
||||
plt.subplot(2, 3, i+1)
|
||||
plt.tight_layout()
|
||||
plt.imshow(img[i, 0]*0.3081+0.1307, cmap='gray', interpolation='none')
|
||||
plt.title("{}: {}".format(name,label[i].item()))
|
||||
plt.xticks([])
|
||||
plt.yticks([])
|
||||
plt.show()
|
||||
|
||||
def one_hot(label, depth=10):
|
||||
out = torch.zeros(label.size(0), depth)
|
||||
idx = torch.LongTensor(label).view(-1, 1)
|
||||
out.scatter_(dim=1, index=idx, value=1)
|
||||
return out
|
||||
|
||||
batch_size=512
|
||||
train_loader=torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data',train=True,download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))])),batch_size=batch_size,shuffle=True)
|
||||
|
||||
test_loader=torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data/',train=False,download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))])),batch_size=batch_size,shuffle=True)
|
||||
|
||||
x,y=next(iter(train_loader))
|
||||
print(x.shape,y.shape,x.min(),x.max())
|
||||
plot_image(x,y,'image_sample')
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self):
|
||||
super(Net,self).__init__()
|
||||
self.fc1=nn.Linear(28*28,256)
|
||||
self.fc2=nn.Linear(256,64)
|
||||
self.fc3=nn.Linear(64,10)
|
||||
|
||||
def forward(self,x):
|
||||
x=F.relu(self.fc1(x))
|
||||
x=F.relu(self.fc2(x))
|
||||
x=self.fc3(x)
|
||||
return x
|
||||
|
||||
net=Net()
|
||||
optimizer=optim.SGD(net.parameters(),lr=0.01,momentum=0.9)
|
||||
train_loss=[]
|
||||
for epoch in range(3):
|
||||
for batch_idx,(x,y) in enumerate(train_loader):
|
||||
x=x.view(x.size(0),28*28)
|
||||
out=net(x)
|
||||
y_onehot=one_hot(y)
|
||||
loss=F.mse_loss(out,y_onehot)
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
train_loss.append(loss.item())
|
||||
|
||||
if batch_idx %10==0:
|
||||
print(epoch,batch_idx,loss.item())
|
||||
|
||||
plot_curve(train_loss)
|
||||
|
||||
total_correct=0
|
||||
for x,y in test_loader:
|
||||
x=x.view(x.size(0),28*28)
|
||||
out=net(x)
|
||||
pred=out.argmax(dim=1)
|
||||
correct=pred.eq(y).sum().float().item()
|
||||
total_correct+=correct
|
||||
total_num=len(test_loader.dataset)
|
||||
acc=total_correct/total_num
|
||||
print("test acc:",acc)
|
||||
|
||||
x,y=next(iter(test_loader))
|
||||
out=net(x.view(x.size(0),28*28))
|
||||
pred=out.argmax(dim=1)
|
||||
plot_image(x,pred,'test')
|
||||
Loading…
Reference in new issue