From 3b8729781a30ae34edf7c2cfc37456a1c96b53f6 Mon Sep 17 00:00:00 2001
From: pl63o9ejz <2318715650@qq.com>
Date: Wed, 28 Apr 2021 16:17:12 +0800
Subject: [PATCH] Add common/trainer.py

---
 mnist-master/common/trainer.py | 78 ++++++++++++++++++++++++++++++++++
 1 file changed, 78 insertions(+)
 create mode 100644 mnist-master/common/trainer.py

diff --git a/mnist-master/common/trainer.py b/mnist-master/common/trainer.py
new file mode 100644
index 0000000..1878105
--- /dev/null
+++ b/mnist-master/common/trainer.py
@@ -0,0 +1,78 @@
+# coding: utf-8
+import sys, os
+sys.path.append(os.pardir)  # make files in the parent directory importable
+import numpy as np
+from common.optimizer import *
+
+class Trainer:
+    """A class that carries out training of a neural network.
+    """
+    def __init__(self, network, x_train, t_train, x_test, t_test,
+                 epochs=20, mini_batch_size=100,
+                 optimizer='SGD', optimizer_param={'lr': 0.01},
+                 evaluate_sample_num_per_epoch=None, verbose=True):
+        self.network = network
+        self.verbose = verbose
+        self.x_train = x_train
+        self.t_train = t_train
+        self.x_test = x_test
+        self.t_test = t_test
+        self.epochs = epochs
+        self.batch_size = mini_batch_size
+        self.evaluate_sample_num_per_epoch = evaluate_sample_num_per_epoch
+
+        # optimizer: map a lowercase name onto its class from common/optimizer.py
+        optimizer_class_dict = {'sgd': SGD, 'momentum': Momentum, 'nesterov': Nesterov,
+                                'adagrad': AdaGrad, 'rmsprop': RMSprop, 'adam': Adam}
+        self.optimizer = optimizer_class_dict[optimizer.lower()](**optimizer_param)
+
+        self.train_size = x_train.shape[0]
+        self.iter_per_epoch = max(self.train_size // mini_batch_size, 1)  # integer division keeps the epoch check below exact
+        self.max_iter = int(epochs * self.iter_per_epoch)
+        self.current_iter = 0
+        self.current_epoch = 0
+
+        self.train_loss_list = []
+        self.train_acc_list = []
+        self.test_acc_list = []
+
+    def train_step(self):
+        batch_mask = np.random.choice(self.train_size, self.batch_size)  # random mini-batch (with replacement)
+        x_batch = self.x_train[batch_mask]
+        t_batch = self.t_train[batch_mask]
+
+        grads = self.network.gradient(x_batch, t_batch)
+        self.optimizer.update(self.network.params, grads)
+
+        loss = self.network.loss(x_batch, t_batch)
+        self.train_loss_list.append(loss)
+        if self.verbose: print("train loss:" + str(loss))
+
+        if self.current_iter % self.iter_per_epoch == 0:  # epoch boundary (also fires on iteration 0)
+            self.current_epoch += 1
+
+            x_train_sample, t_train_sample = self.x_train, self.t_train
+            x_test_sample, t_test_sample = self.x_test, self.t_test
+            if self.evaluate_sample_num_per_epoch is not None:  # evaluate on a subsample for speed
+                t = self.evaluate_sample_num_per_epoch
+                x_train_sample, t_train_sample = self.x_train[:t], self.t_train[:t]
+                x_test_sample, t_test_sample = self.x_test[:t], self.t_test[:t]
+
+            train_acc = self.network.accuracy(x_train_sample, t_train_sample)
+            test_acc = self.network.accuracy(x_test_sample, t_test_sample)
+            self.train_acc_list.append(train_acc)
+            self.test_acc_list.append(test_acc)
+
+            if self.verbose: print("=== epoch:" + str(self.current_epoch) + ", train acc:" + str(train_acc) + ", test acc:" + str(test_acc) + " ===")
+        self.current_iter += 1
+
+    def train(self):
+        for i in range(self.max_iter):
+            self.train_step()
+
+        test_acc = self.network.accuracy(self.x_test, self.t_test)
+
+        if self.verbose:
+            print("=============== Final Test Accuracy ===============")
+            print("test acc:" + str(test_acc))
+
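
For reference, a minimal sketch of how this Trainer would be driven. It assumes
the surrounding project provides a load_mnist loader and a TwoLayerNet network
exposing params, gradient(x, t), loss(x, t), and accuracy(x, t) -- the interface
the Trainer above calls; both names are assumptions about the repo, not part of
this diff:

    # coding: utf-8
    import sys, os
    sys.path.append(os.pardir)
    from dataset.mnist import load_mnist    # assumed data loader
    from two_layer_net import TwoLayerNet   # assumed network class
    from common.trainer import Trainer

    # flattened, normalized MNIST images with one-hot labels
    (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)

    network = TwoLayerNet(input_size=784, hidden_size=50, output_size=10)
    trainer = Trainer(network, x_train, t_train, x_test, t_test,
                      epochs=10, mini_batch_size=100,
                      optimizer='adam', optimizer_param={'lr': 0.001},
                      evaluate_sample_num_per_epoch=1000)
    trainer.train()

Any key of optimizer_class_dict ('sgd', 'momentum', 'nesterov', 'adagrad',
'rmsprop', 'adam') is accepted for the optimizer argument, and optimizer_param
is passed through unchanged to that optimizer's constructor.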