You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

127 lines
6.2 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
import argparse
import mindspore.dataset as ds
import mindspore.nn as nn
from mindspore import context, Model, load_checkpoint, load_param_into_net
from mindspore.common.initializer import Normal
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.vision import Inter
from mindspore.nn.metrics import Accuracy
from mindspore import dtype as mstype
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from utils.dataset import download_dataset
def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=1):
    """Build the MNIST input pipeline used for both training and evaluation.

    Stages, in order:
      * label: cast to int32;
      * image: resize to 32x32 (bilinear), multiply by 1/255, standardize as
        (x - 0.1307) / 0.3081, then transpose HWC -> CHW;
      * shuffle with a 10000-element buffer, batch (dropping the incomplete
        final batch), and repeat.

    Args:
        data_path: directory handed to ``ds.MnistDataset``.
        batch_size: number of samples per batch.
        repeat_size: how many times the batched dataset is repeated.
        num_parallel_workers: worker count used by every ``map`` stage.

    Returns:
        The configured MindSpore dataset.
    """
    dataset = ds.MnistDataset(data_path)

    # Each Rescale computes x * scale + shift.
    pixel_scale, pixel_shift = 1.0 / 255.0, 0.0
    std_scale, std_shift = 1 / 0.3081, -1 * 0.1307 / 0.3081

    image_ops = [
        CV.Resize((32, 32), interpolation=Inter.LINEAR),
        CV.Rescale(pixel_scale, pixel_shift),
        CV.Rescale(std_scale, std_shift),
        CV.HWC2CHW(),  # the network consumes channel-first tensors
    ]
    # One map() call per operation, label first, preserving the original order.
    stages = [("label", C.TypeCast(mstype.int32))]
    stages.extend(("image", op) for op in image_ops)
    for column, op in stages:
        dataset = dataset.map(operations=op, input_columns=column,
                              num_parallel_workers=num_parallel_workers)

    # Shuffle and batch before repeating so that, per the original author's
    # intent, data is not duplicated within one epoch.
    return (dataset
            .shuffle(buffer_size=10000)
            .batch(batch_size, drop_remainder=True)
            .repeat(repeat_size))
class LeNet5(nn.Cell):
    """LeNet-5 style convolutional classifier.

    Two feature stages of conv(5x5, 'valid') -> ReLU -> 2x2 max-pool, followed
    by three dense layers (120 -> 84 -> ``num_class``) with ReLU between them.
    ``fc1`` expects 16 * 5 * 5 flattened features, which is what a 32x32 input
    yields after the two conv/pool stages.

    Args:
        num_class: number of output logits.
        num_channel: number of input image channels.
    """

    def __init__(self, num_class=10, num_channel=1):
        super().__init__()
        # Attribute names are the parameter/checkpoint keys: keep them stable.
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        """Return unnormalized class logits for input batch ``x``."""
        features = self.max_pool2d(self.relu(self.conv1(x)))
        features = self.max_pool2d(self.relu(self.conv2(features)))
        hidden = self.relu(self.fc1(self.flatten(features)))
        hidden = self.relu(self.fc2(hidden))
        return self.fc3(hidden)
def train_net(network_model, epoch_size, data_path, repeat_size, ckpoint_cb, sink_mode):
    """Train ``network_model`` on the ``train`` split found under ``data_path``.

    Args:
        network_model: MindSpore ``Model`` to train.
        epoch_size: number of training epochs.
        data_path: root directory containing the ``train`` split.
        repeat_size: dataset repeat count forwarded to ``create_dataset``.
        ckpoint_cb: checkpoint callback; runs alongside a ``LossMonitor``.
        sink_mode: value passed as ``dataset_sink_mode`` to ``Model.train``.
    """
    print("============== Starting Training ==============")
    train_dir = os.path.join(data_path, "train")
    train_data = create_dataset(train_dir, batch_size=32, repeat_size=repeat_size)
    callbacks = [ckpoint_cb, LossMonitor()]
    network_model.train(epoch_size, train_data, callbacks=callbacks,
                        dataset_sink_mode=sink_mode)
def test_net(network, network_model, data_path, ckpt_path=None):
    """Load a checkpoint into ``network`` and print accuracy on the test split.

    Args:
        network: Cell whose parameters are replaced by the checkpoint contents.
        network_model: ``Model`` wrapping ``network``; its ``eval`` is used.
        data_path: root directory containing the ``test`` split.
        ckpt_path: checkpoint file to load. ``None`` (the default) keeps the
            historical fixed name ``checkpoint_lenet-1_1875.ckpt``. Prefer
            passing the file the training run actually wrote (for example
            ``ModelCheckpoint.latest_ckpt_file_name``). When a file with the
            same prefix already exists, ModelCheckpoint saves the new one
            under a suffixed prefix, so the fixed name would load a stale
            checkpoint from an earlier run.
    """
    print("============== Starting Testing ==============")
    if ckpt_path is None:
        ckpt_path = "checkpoint_lenet-1_1875.ckpt"
    # Load the saved parameters and copy them into the network.
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(network, param_dict)
    # Evaluate on the test split without dataset sinking.
    ds_eval = create_dataset(os.path.join(data_path, "test"))
    acc = network_model.eval(ds_eval, dataset_sink_mode=False)
    print("============== Accuracy:{} ==============".format(acc))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
    parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
                        help='device where the code will be implemented (default: CPU)')
    args = parser.parse_args()
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
    # Dataset sinking is disabled on CPU and enabled on Ascend/GPU.
    dataset_sink_mode = args.device_target != "CPU"
    # Call download_dataset() here if ./MNIST_Data has not been fetched yet.
    # download_dataset()
    # Optimizer hyper-parameters.
    lr = 0.01
    momentum = 0.9
    # Repeat count applied to the batched training dataset.
    repeat_size = 1
    mnist_path = "./MNIST_Data"
    # Softmax cross-entropy over integer (sparse) labels, averaged over the batch.
    net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    train_epoch = 1
    net = LeNet5()
    net_opt = nn.Momentum(net.trainable_params(), lr, momentum)
    # 1875 steps is one epoch, assuming the standard 60,000-image MNIST
    # training split at batch size 32 (incomplete final batch dropped).
    config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)
    # Save model parameters so they can be reloaded for evaluation/fine-tuning.
    ckpoint = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
    # Model bundles network, loss, optimizer and metrics for train/eval.
    model = Model(net, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
    train_net(model, train_epoch, mnist_path, repeat_size, ckpoint, dataset_sink_mode)
    # Evaluate the checkpoint written by this run rather than a fixed file name
    # (None falls back to the legacy default inside test_net).
    test_net(net, model, mnist_path, ckpt_path=ckpoint.latest_ckpt_file_name)