import os
import argparse
import mindspore.dataset as ds
import mindspore.nn as nn
from mindspore import context, Model, load_checkpoint, load_param_into_net
from mindspore.common.initializer import Normal
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.vision import Inter
from mindspore.nn.metrics import Accuracy
from mindspore import dtype as mstype
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from utils.dataset import download_dataset


def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=1):
    """Create a dataset for training or testing."""
    # Define the dataset
    mnist_ds = ds.MnistDataset(data_path)

    # Define the operation parameters
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081
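    # A worked check of the constants above (the standard MNIST mean 0.1307 and std
    # 0.3081): CV.Rescale(a, b) computes pixel * a + b, so after the 1/255 rescale,
    #   pixel * (1 / 0.3081) + (-0.1307 / 0.3081) == (pixel - 0.1307) / 0.3081,
    # which is the usual (x - mean) / std standardization.
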
    # Define the map operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  # resize images to (32, 32)
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)  # standardize the images to improve training efficiency
    rescale_op = CV.Rescale(rescale, shift)  # rescale pixel values into [0, 1]
    hwc2chw_op = CV.HWC2CHW()  # transpose image tensors from height x width x channel (HWC) to channel x height x width (CHW), the layout the network expects
    type_cast_op = C.TypeCast(mstype.int32)  # cast the labels to int32 to fit the network

    # Apply the map operations to the dataset columns
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    # Other processing operations: shuffle, batch, and repeat
    buffer_size = 10000  # shuffle buffer size
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # shuffle samples through an in-memory buffer holding up to 10000 images
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)  # group batch_size samples (32 here) into each batch; drop any incomplete final batch
    mnist_ds = mnist_ds.repeat(repeat_size)  # replicate the batched data; repeat_size is the number of dataset copies
    # Shuffle and batch first, then repeat, so that no sample repeats within one epoch.

    return mnist_ds

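
# A minimal usage sketch of the pipeline above (hypothetical snippet; assumes the MNIST
# train files sit under ./MNIST_Data/train). Left commented out so the script's behavior
# is unchanged; uncommented, it prints one batch's image shape and label dtype:
#
#   sample_ds = create_dataset("./MNIST_Data/train", batch_size=32)
#   for batch in sample_ds.create_dict_iterator(num_epochs=1):
#       print(batch["image"].shape, batch["label"].dtype)  # expected: (32, 1, 32, 32) Int32
#       break

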
class LeNet5(nn.Cell):
    """LeNet-5 network structure."""
    # Define the operations needed
    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    # Build the network using the operations defined above
    def construct(self, x):
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

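# Shape trace for a single 1x32x32 input through LeNet5 (valid padding, 5x5 kernels,
# 2x2 max-pooling), which is why fc1 expects 16 * 5 * 5 = 400 input features:
#   conv1: 1x32x32 -> 6x28x28,  pool -> 6x14x14
#   conv2: 6x14x14 -> 16x10x10, pool -> 16x5x5
#   flatten -> 400, fc1 -> 120, fc2 -> 84, fc3 -> num_class logits

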
def train_net(network_model, epoch_size, data_path, repeat_size, ckpoint_cb, sink_mode):
    """Define the training method."""
    print("============== Starting Training ==============")
    # Load the training dataset
    ds_train = create_dataset(os.path.join(data_path, "train"), 32, repeat_size)
    # Run the training
    network_model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=sink_mode)

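
# With the settings in the main block below (1 epoch, 60000 training images / batch
# size 32 = 1875 steps, save_checkpoint_steps=1875), ModelCheckpoint names checkpoints
# prefix-epoch_step, producing "checkpoint_lenet-1_1875.ckpt", the file test_net() loads.

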
def test_net(network, network_model, data_path):
    """Define the evaluation method."""
    print("============== Starting Testing ==============")
    # Load the saved checkpoint for evaluation
    param_dict = load_checkpoint("checkpoint_lenet-1_1875.ckpt")
    # Load the parameters into the network
    load_param_into_net(network, param_dict)
    # Load the test dataset
    ds_eval = create_dataset(os.path.join(data_path, "test"))
    acc = network_model.eval(ds_eval, dataset_sink_mode=False)
    print("============== Accuracy:{} ==============".format(acc))

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
    parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
                        help='device where the code will be implemented (default: CPU)')
    args = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    dataset_sink_mode = args.device_target != "CPU"
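    # Note: dataset sink mode streams batches directly to the device to speed up
    # training; in the MindSpore versions this tutorial targets it is not supported
    # on CPU, hence sink mode is enabled only for Ascend/GPU above.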
    # Download the MNIST dataset (left commented out; enable on first run)
    # download_dataset()
    # Learning rate setting
    lr = 0.01
    momentum = 0.9
    dataset_size = 1
    mnist_path = "./MNIST_Data"
    # Define the loss function; sparse=True takes class-index labels (0-9) rather than
    # one-hot vectors, and reduction='mean' averages the loss over the batch
    net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    train_epoch = 1
    # Build the network
    net = LeNet5()
    # Define the optimizer
    net_opt = nn.Momentum(net.trainable_params(), lr, momentum)
    config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)
    # Save the network model and parameters for subsequent fine-tuning
    ckpoint = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
    # Group the layers into an object with training and evaluation features
    model = Model(net, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    train_net(model, train_epoch, mnist_path, dataset_size, ckpoint, dataset_sink_mode)
    test_net(net, model, mnist_path)
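
# A hypothetical invocation (assumes this script is saved as lenet.py and the MNIST
# archives are extracted into ./MNIST_Data/train and ./MNIST_Data/test):
#
#   python lenet.py --device_target=CPU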