From 28326bb560755b09fb1aba1ac8d9ec56eee186ef Mon Sep 17 00:00:00 2001 From: p4w2aybsf <2363061197@qq.com> Date: Thu, 29 Apr 2021 17:10:30 +0800 Subject: [PATCH] Add 'train_deepnet.py' --- train_deepnet.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 train_deepnet.py diff --git a/train_deepnet.py b/train_deepnet.py new file mode 100644 index 0000000..9cdf3fb --- /dev/null +++ b/train_deepnet.py @@ -0,0 +1,21 @@ +# coding: utf-8 +import sys, os +sys.path.append(os.pardir) # 为了导入父目录而进行的设定 +import numpy as np +import matplotlib.pyplot as plt +from dataset.mnist import load_mnist +from deep_convnet import DeepConvNet +from common.trainer import Trainer + +(x_train, t_train), (x_test, t_test) = load_mnist(flatten=False) + +network = DeepConvNet() +trainer = Trainer(network, x_train, t_train, x_test, t_test, + epochs=20, mini_batch_size=100, + optimizer='Adam', optimizer_param={'lr':0.001}, + evaluate_sample_num_per_epoch=1000) +trainer.train() + +# 保存参数 +network.save_params("deep_convnet_params.pkl") +print("Saved Network Parameters!")