You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
79 lines
3.1 KiB
79 lines
3.1 KiB
# coding: utf-8
|
|
import sys, os
|
|
sys.path.append(os.pardir) # 为了导入父目录的文件而进行的设定
|
|
import numpy as np
|
|
from common.optimizer import *
|
|
|
|
class Trainer:
    """Class that trains a neural network with mini-batch updates.

    Each ``train_step`` draws one random mini-batch (indices sampled with
    replacement), computes gradients with ``network.gradient``, applies them
    through the selected optimizer, and records the mini-batch loss.  Once per
    epoch it also records train and test accuracy.

    Parameters
    ----------
    network : object
        Must provide ``params``, ``gradient(x, t)``, ``loss(x, t)`` and
        ``accuracy(x, t)``.
    x_train, t_train, x_test, t_test : array-like
        Training / test inputs and labels.  They must support ``.shape``
        (``x_train`` only), integer-array indexing and slicing along the
        first axis (e.g. NumPy arrays).
    epochs : int
        Number of epochs to train for.
    mini_batch_size : int
        Number of samples per mini-batch.
    optimizer : str
        Case-insensitive optimizer name: 'SGD', 'Momentum', 'Nesterov',
        'AdaGrad', 'RMSprop' or 'Adam'.
    optimizer_param : dict or None
        Keyword arguments passed to the optimizer constructor.  ``None``
        (the default) means ``{'lr': 0.01}``.
    evaluate_sample_num_per_epoch : int or None
        If set, per-epoch accuracy is measured on only the first this-many
        train and test samples.  ``None`` uses the full sets.
    verbose : bool
        If true, print the loss every iteration and accuracy every epoch.

    Raises
    ------
    KeyError
        If ``optimizer`` is not one of the supported names.
    """

    def __init__(self, network, x_train, t_train, x_test, t_test,
                 epochs=20, mini_batch_size=100,
                 optimizer='SGD', optimizer_param=None,
                 evaluate_sample_num_per_epoch=None, verbose=True):
        # Avoid a shared mutable default; None keeps the historical default.
        if optimizer_param is None:
            optimizer_param = {'lr': 0.01}

        self.network = network
        self.verbose = verbose
        self.x_train = x_train
        self.t_train = t_train
        self.x_test = x_test
        self.t_test = t_test
        self.epochs = epochs
        self.batch_size = mini_batch_size
        self.evaluate_sample_num_per_epoch = evaluate_sample_num_per_epoch

        # Optimizer lookup is case-insensitive.  'rmsprpo' was the key this
        # class originally used by mistake.  It is kept as an alias so
        # existing callers keep working.
        optimizer_class_dict = {
            'sgd': SGD, 'momentum': Momentum, 'nesterov': Nesterov,
            'adagrad': AdaGrad, 'rmsprop': RMSprop, 'rmsprpo': RMSprop,
            'adam': Adam,
        }
        try:
            optimizer_class = optimizer_class_dict[optimizer.lower()]
        except KeyError:
            raise KeyError(
                "unknown optimizer {!r}; expected one of SGD, Momentum, "
                "Nesterov, AdaGrad, RMSprop, Adam".format(optimizer)) from None
        self.optimizer = optimizer_class(**optimizer_param)

        self.train_size = x_train.shape[0]
        # May be fractional (e.g. 60000 / 128 = 468.75).  Clamped to >= 1 so a
        # training set smaller than one mini-batch still yields progress.
        self.iter_per_epoch = max(self.train_size / mini_batch_size, 1)
        self.max_iter = int(epochs * self.iter_per_epoch)
        self.current_iter = 0
        self.current_epoch = 0

        self.train_loss_list = []
        self.train_acc_list = []
        self.test_acc_list = []

    def train_step(self):
        """Run one optimization step on a random mini-batch.

        Appends the mini-batch loss to ``train_loss_list``.  On the first
        iteration of each epoch, it increments ``current_epoch`` and records
        train and test accuracy.  This includes iteration 0, i.e. right after
        the very first update.
        """
        # Indices are drawn with replacement, so a batch may repeat samples.
        batch_mask = np.random.choice(self.train_size, self.batch_size)
        x_batch = self.x_train[batch_mask]
        t_batch = self.t_train[batch_mask]

        grads = self.network.gradient(x_batch, t_batch)
        self.optimizer.update(self.network.params, grads)

        # The loss is measured after the update, on the same mini-batch.
        loss = self.network.loss(x_batch, t_batch)
        self.train_loss_list.append(loss)
        if self.verbose:
            print("train loss:" + str(loss))

        # iter_per_epoch can be fractional, so the old test
        # `current_iter % iter_per_epoch == 0` almost never fired after
        # iteration 0.  Instead, start a new epoch on the first iteration at
        # or past the next epoch boundary.  For an integer iter_per_epoch this
        # fires at exactly the same iterations as the modulo test.
        if self.current_iter >= self.current_epoch * self.iter_per_epoch:
            self.current_epoch += 1
            self._record_accuracy()

        self.current_iter += 1

    def _record_accuracy(self):
        """Append train/test accuracy (optionally on a prefix subsample) to the history lists."""
        x_train_sample, t_train_sample = self.x_train, self.t_train
        x_test_sample, t_test_sample = self.x_test, self.t_test
        if self.evaluate_sample_num_per_epoch is not None:
            # Evaluate on only the first n samples of each set, to save time.
            n = self.evaluate_sample_num_per_epoch
            x_train_sample, t_train_sample = self.x_train[:n], self.t_train[:n]
            x_test_sample, t_test_sample = self.x_test[:n], self.t_test[:n]

        train_acc = self.network.accuracy(x_train_sample, t_train_sample)
        test_acc = self.network.accuracy(x_test_sample, t_test_sample)
        self.train_acc_list.append(train_acc)
        self.test_acc_list.append(test_acc)

        if self.verbose:
            print("=== epoch:" + str(self.current_epoch) + ", train acc:" + str(train_acc) + ", test acc:" + str(test_acc) + " ===")

    def train(self):
        """Run ``max_iter`` training steps, then measure accuracy on the full test set.

        The final test accuracy is printed when ``verbose`` is true.  It is
        neither stored nor returned.
        """
        for _ in range(self.max_iter):
            self.train_step()

        test_acc = self.network.accuracy(self.x_test, self.t_test)

        if self.verbose:
            print("=============== Final Test Accuracy ===============")
            print("test acc:" + str(test_acc))
|
|
|