Merge pull request 'xiangguo' (#2) from branch-xg into main

branch-donghaoqian
pptw92c8a 7 months ago
commit c8a0ad4d29

@ -20,8 +20,11 @@ function(find_submodule_lib module name path)
) )
endfunction() endfunction()
# protobuf
function(ge_protobuf_generate c_var h_var) function(ge_protobuf_generate c_var h_var)
# common_protobuf_generateprotobuf
common_protobuf_generate(${CMAKE_BINARY_DIR}/proto/ge/proto ${c_var} ${h_var} ${ARGN}) common_protobuf_generate(${CMAKE_BINARY_DIR}/proto/ge/proto ${c_var} ${h_var} ${ARGN})
# chc_varh_var
set(${c_var} ${${c_var}} PARENT_SCOPE) set(${c_var} ${${c_var}} PARENT_SCOPE)
set(${h_var} ${${h_var}} PARENT_SCOPE) set(${h_var} ${${h_var}} PARENT_SCOPE)
endfunction() endfunction()

@ -452,11 +452,14 @@ class _GeneratorWorkerMp(multiprocessing.Process):
""" """
def __init__(self, dataset, eof, max_rowsize, queue_size, ppid): def __init__(self, dataset, eof, max_rowsize, queue_size, ppid):
# 初始化一个多进程队列,用于存储索引
self.idx_queue = multiprocessing.Queue(queue_size) self.idx_queue = multiprocessing.Queue(queue_size)
# 如果启用了共享内存,则初始化一个共享队列,否则初始化一个多进程队列
if get_enable_shared_mem(): if get_enable_shared_mem():
self.res_queue = _SharedQueue(queue_size, max_rowsize=max_rowsize) self.res_queue = _SharedQueue(queue_size, max_rowsize=max_rowsize)
else: else:
self.res_queue = multiprocessing.Queue(queue_size) self.res_queue = multiprocessing.Queue(queue_size)
# 设置队列的_joincancelled属性为True表示在进程退出时队列不会阻塞
self.idx_queue._joincancelled = True # pylint: disable=W0212 self.idx_queue._joincancelled = True # pylint: disable=W0212
self.res_queue._joincancelled = True # pylint: disable=W0212 self.res_queue._joincancelled = True # pylint: disable=W0212
super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, True, ppid)) super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, True, ppid))
@ -465,6 +468,7 @@ class _GeneratorWorkerMp(multiprocessing.Process):
""" """
Put function for worker index queue. Never block. Raise queue.Full on failure. Put function for worker index queue. Never block. Raise queue.Full on failure.
""" """
# 将item放入idx_queue队列中不阻塞如果失败则抛出queue.Full异常
self.idx_queue.put_nowait(item) self.idx_queue.put_nowait(item)
def get(self): def get(self):
@ -476,12 +480,19 @@ class _GeneratorWorkerMp(multiprocessing.Process):
return self.res_queue.get(timeout=30) return self.res_queue.get(timeout=30)
def queue_empty(self): def queue_empty(self):
# 检查idx_queue是否为空
if not self.idx_queue.empty(): if not self.idx_queue.empty():
# 如果不为空,记录警告日志
logger.warning("idx_queue is not empty.") logger.warning("idx_queue is not empty.")
# 返回False
return False return False
# 检查res_queue是否为空
if not self.res_queue.empty(): if not self.res_queue.empty():
# 如果不为空,记录警告日志
logger.warning("res_queue is not empty.") logger.warning("res_queue is not empty.")
# 返回False
return False return False
# 如果两个队列都为空返回True
return True return True
def __del__(self): def __del__(self):
@ -632,14 +643,17 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None, def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None,
num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None, num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None,
python_multiprocessing=True, max_rowsize=6): python_multiprocessing=True, max_rowsize=6):
# 调用父类的初始化方法
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id) shuffle=shuffle, num_shards=num_shards, shard_id=shard_id)
# 如果source是zip类型则将其转换为列表
if isinstance(source, builtins.zip): if isinstance(source, builtins.zip):
# Although zip is iteratable, it does not have the feature of repeated iteration, so pass it to the array. # Although zip is iteratable, it does not have the feature of repeated iteration, so pass it to the array.
self.source = [item for item in source] self.source = [item for item in source]
else: else:
self.source = source self.source = source
self.prepared_source = None # source to be sent to C++ self.prepared_source = None # source to be sent to C++
# 如果self.operator_mixed属性为True则将num_parallel_workers设置为1
if hasattr(self, 'operator_mixed') and getattr(self, 'operator_mixed') is True: if hasattr(self, 'operator_mixed') and getattr(self, 'operator_mixed') is True:
self.num_parallel_workers = 1 self.num_parallel_workers = 1
logger.warning( logger.warning(
@ -650,56 +664,78 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
self.python_multiprocessing = python_multiprocessing self.python_multiprocessing = python_multiprocessing
# 将column_names转换为列表
self.column_names = to_list(column_names) self.column_names = to_list(column_names)
# 如果column_types不为空则将其转换为detypelist类型
if column_types is not None: if column_types is not None:
self.column_types = mstypelist_to_detypelist(column_types) self.column_types = mstypelist_to_detypelist(column_types)
else: else:
self.column_types = [] self.column_types = []
self.schema = schema self.schema = schema
# 如果schema不为空则将其转换为Schema类型
if schema is not None: if schema is not None:
# 如果schema不为空则将其赋值给self.schema
self.schema = schema self.schema = schema
# 如果schema不是Schema类型则将其转换为Schema类型
if not isinstance(schema, Schema): if not isinstance(schema, Schema):
self.schema = Schema(schema) self.schema = Schema(schema)
# Move get dataset_size by len from parse to here, because self.source will # Move get dataset_size by len from parse to here, because self.source will
# lose attribution of '__len__' after deepcopy. # lose attribution of '__len__' after deepcopy.
self.source_len = -1 # unknown self.source_len = -1 # unknown
# 如果self.source有__len__属性则获取self.source的长度
if hasattr(self.source, "__len__"): if hasattr(self.source, "__len__"):
self.source_len = len(self.source) self.source_len = len(self.source)
# 设置最大行大小
self.max_rowsize = max_rowsize self.max_rowsize = max_rowsize
# 设置采样函数为None
self.sample_fn = None self.sample_fn = None
def __deepcopy__(self, memodict): def __deepcopy__(self, memodict):
# 深度复制当前对象并传入一个字典memodict用于存储已经复制的对象
if id(self) in memodict: if id(self) in memodict:
# 如果当前对象的id已经在memodict中则直接返回该对象
return memodict[id(self)] return memodict[id(self)]
# 否则调用__safe_deepcopy__方法进行深度复制并传入memodict和exclude参数
new_op = self.__safe_deepcopy__(memodict, exclude=("source", "__transfer_dataset__")) new_op = self.__safe_deepcopy__(memodict, exclude=("source", "__transfer_dataset__"))
sample_fn = None sample_fn = None
# 如果新对象的sampler属性不为空并且self.source对象具有__getitem__方法
if new_op.sampler is not None and hasattr(self.source, "__getitem__"): if new_op.sampler is not None and hasattr(self.source, "__getitem__"):
# The reason why there is a try catch here is because when the new op is being constructed with shared # The reason why there is a try catch here is because when the new op is being constructed with shared
# memory enabled, there will be an exception thrown if there is not enough shared memory available # memory enabled, there will be an exception thrown if there is not enough shared memory available
# 如果self.source_len为-1则抛出RuntimeError异常因为尝试构造一个随机访问的数据集需要__len__方法
if self.source_len == -1: if self.source_len == -1:
raise RuntimeError("Attempt to construct a random access dataset, '__len__' method is required!") raise RuntimeError("Attempt to construct a random access dataset, '__len__' method is required!")
try: try:
# 如果新对象的num_parallel_workers大于1则调用__validate_memory_usage方法进行内存使用验证
if new_op.num_parallel_workers > 1: if new_op.num_parallel_workers > 1:
self.__validate_memory_usage() self.__validate_memory_usage()
# 创建一个SamplerFn对象用于并行采样
sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing, sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing,
self.max_rowsize) self.max_rowsize)
# 将新对象的prepared_source属性设置为_cpp_sampler_fn_mp函数用于并行采样
new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn)) new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn))
else: else:
# 否则将新对象的prepared_source属性设置为_cpp_sampler_fn函数用于单线程采样
new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source)) new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source))
# 将新对象的sample_fn属性设置为sample_fn
new_op.sample_fn = sample_fn new_op.sample_fn = sample_fn
except RuntimeError as e: except RuntimeError as e:
# 如果抛出RuntimeError异常则抛出Exception异常并传入异常信息
raise Exception(str(e)) raise Exception(str(e))
else: else:
try: try:
# 否则将新对象的sampler属性设置为Nonesample_fn属性设置为sample_fn
new_op.sampler = None new_op.sampler = None
new_op.sample_fn = sample_fn new_op.sample_fn = sample_fn
# 将新对象的source_len属性设置为min(new_op.source_len, new_op.num_samples)如果new_op.num_samples不为0否则设置为new_op.source_len
new_op.source_len = min(new_op.source_len, new_op.source_len = min(new_op.source_len,
new_op.num_samples) if new_op.num_samples != 0 else new_op.source_len new_op.num_samples) if new_op.num_samples != 0 else new_op.source_len
# 遍历self.source对象
iter(self.source) iter(self.source)
except TypeError: except TypeError:
# Use generator function if input callable # Use generator function if input callable
@ -711,19 +747,26 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
return new_op return new_op
# 判断是否被洗牌
def is_shuffled(self): def is_shuffled(self):
return self.sampler.is_shuffled() return self.sampler.is_shuffled()
# 判断是否被分片
def is_sharded(self): def is_sharded(self):
return self.sampler.is_sharded() return self.sampler.is_sharded()
# 解析
def parse(self, children=None): def parse(self, children=None):
# 如果schema为空则返回GeneratorNode对象
if self.schema is None: if self.schema is None:
return cde.GeneratorNode(self.prepared_source, self.column_names, self.column_types, self.source_len, return cde.GeneratorNode(self.prepared_source, self.column_names, self.column_types, self.source_len,
self.sampler, self.num_parallel_workers) self.sampler, self.num_parallel_workers)
# 获取schema
schema = self.schema schema = self.schema
# 如果schema是Schema类型则获取cpp_schema
if isinstance(schema, Schema): if isinstance(schema, Schema):
schema = self.schema.cpp_schema schema = self.schema.cpp_schema
# 返回GeneratorNode对象
return cde.GeneratorNode(self.prepared_source, schema, self.source_len, self.sampler, return cde.GeneratorNode(self.prepared_source, schema, self.source_len, self.sampler,
self.num_parallel_workers) self.num_parallel_workers)
@ -735,24 +778,37 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
# if use num_parallel_workers is to large when python_multiprocessing=True which would cause # if use num_parallel_workers is to large when python_multiprocessing=True which would cause
# OOM error get the num_shards # OOM error get the num_shards
valid_num_shards = 1 valid_num_shards = 1
# 判断self.sampler是否为samplers.DistributedSampler类型
if isinstance(self.sampler, samplers.DistributedSampler): if isinstance(self.sampler, samplers.DistributedSampler):
# 如果是则将self.sampler的num_shards赋值给valid_num_shards
valid_num_shards = self.sampler.num_shards valid_num_shards = self.sampler.num_shards
# 否则判断self.num_shards是否为None
elif self.num_shards is not None: elif self.num_shards is not None:
# 如果不是则将self.num_shards赋值给valid_num_shards
valid_num_shards = self.num_shards valid_num_shards = self.num_shards
# get process memory usage # get process memory usage
# 获取当前进程
process = psutil.Process(os.getpid()) process = psutil.Process(os.getpid())
# 获取当前进程的内存信息
process_memory = process.memory_info().rss process_memory = process.memory_info().rss
# 获取系统内存的空闲量
sys_memory_free = psutil.virtual_memory().free sys_memory_free = psutil.virtual_memory().free
# 计算可能使用的总内存量
total_memory_maybe_used = process_memory * self.num_parallel_workers * valid_num_shards total_memory_maybe_used = process_memory * self.num_parallel_workers * valid_num_shards
# 如果总内存可能使用的内存量除以系统可用内存大于0.85
if total_memory_maybe_used / sys_memory_free > 0.85: if total_memory_maybe_used / sys_memory_free > 0.85:
# 计算有效的worker数量即系统可用内存乘以0.85除以有效的shards数量再除以每个进程的内存
valid_num_worker = math.floor(sys_memory_free * 0.85 / valid_num_shards / process_memory) valid_num_worker = math.floor(sys_memory_free * 0.85 / valid_num_shards / process_memory)
# 如果有效的worker数量小于等于0则将其设置为1
valid_num_worker = 1 if valid_num_worker <= 0 else valid_num_worker valid_num_worker = 1 if valid_num_worker <= 0 else valid_num_worker
# 构造警告信息提示用户num_parallel_workers设置过大可能会导致内存占用过高或OOM建议将其减小到valid_num_worker或更小
info = "GeneratorDataset's num_parallel_workers: {} is too large which may cause a lot of memory " \ info = "GeneratorDataset's num_parallel_workers: {} is too large which may cause a lot of memory " \
"occupation (>85%) or out of memory(OOM) during multiprocessing. Therefore, it is recommended " \ "occupation (>85%) or out of memory(OOM) during multiprocessing. Therefore, it is recommended " \
"to reduce num_parallel_workers to {} or smaller.".format(self.num_parallel_workers, "to reduce num_parallel_workers to {} or smaller.".format(self.num_parallel_workers,
valid_num_worker) valid_num_worker)
# 打印警告信息
logger.warning(info) logger.warning(info)
@ -764,37 +820,55 @@ class _NumpySlicesDataset:
def __init__(self, data, column_list=None): def __init__(self, data, column_list=None):
self.column_list = None self.column_list = None
# Convert dict data into tuple # Convert dict data into tuple
# 判断data是否为字典类型
if isinstance(data, dict): if isinstance(data, dict):
# 如果是字典类型则调用process_dict方法处理
data = self.process_dict(data) data = self.process_dict(data)
# 判断data是否为元组类型
if isinstance(data, tuple): if isinstance(data, tuple):
# 如果是元组类型则将self.data初始化为空元组
self.data = () self.data = ()
# 获取data的长度
data_len = len(data) data_len = len(data)
# 遍历data中的每个元素
for i in range(data_len): for i in range(data_len):
# 将data中的每个元素转换为numpy数组并添加到self.data中
self.data = self.data + (np.array(data[i]),) self.data = self.data + (np.array(data[i]),)
else: else:
# 如果data不是元组类型则将data转换为numpy数组并添加到self.data中
self.data = (np.array(data),) self.data = (np.array(data),)
# check whether the data length in each column is equal # check whether the data length in each column is equal
# 获取每个data_item的长度
data_len = [len(data_item) for data_item in self.data] data_len = [len(data_item) for data_item in self.data]
# 如果每个data_item的长度不相等则抛出ValueError异常
if data_len[1:] != data_len[:-1]: if data_len[1:] != data_len[:-1]:
raise ValueError("Data length in each column is not equal.") raise ValueError("Data length in each column is not equal.")
# Init column_name # Init column_name
# 如果column_list不为空则将self.column_list赋值为column_list
if column_list is not None: if column_list is not None:
self.column_list = column_list self.column_list = column_list
# 如果self.column_list为空则将self.column_list赋值为空列表
elif self.column_list is None: elif self.column_list is None:
self.column_list = [] self.column_list = []
# 获取data的列数
column_num = len(self.data) column_num = len(self.data)
# 遍历列数,将"column_" + str(i)添加到self.column_list中
for i in range(column_num): for i in range(column_num):
self.column_list.append("column_" + str(i)) self.column_list.append("column_" + str(i))
def __getitem__(self, index): def __getitem__(self, index):
# 获取指定索引的数据行
data_row = [d[index, ...] for d in self.data] data_row = [d[index, ...] for d in self.data]
# 将数据行转换为元组
data_res = tuple(data_row) data_res = tuple(data_row)
# 返回数据行
return data_res return data_res
def __len__(self): def __len__(self):
# 返回data的第一个元素的长度
return len(self.data[0]) return len(self.data[0])
def process_dict(self, input_data): def process_dict(self, input_data):
@ -802,24 +876,29 @@ class _NumpySlicesDataset:
Convert the dict like data into tuple format, when input is a tuple of dicts then compose it into a dict first. Convert the dict like data into tuple format, when input is a tuple of dicts then compose it into a dict first.
""" """
# Convert pandas like dict(has "values" column) into General dict # Convert pandas like dict(has "values" column) into General dict
# 将pandas样式的字典有"values"列)转换为通用字典
data_keys = list(input_data.keys()) data_keys = list(input_data.keys())
# 获取字典的第一个键对应的值
data_col = input_data[data_keys[0]] data_col = input_data[data_keys[0]]
# 如果值有values属性则将其转换为通用字典
if hasattr(data_col, "values"): if hasattr(data_col, "values"):
new_dict = {} new_dict = {}
for key in data_keys: for key in data_keys:
# 将字典中的键对应的值转换为列表
item1 = input_data.pop(key) item1 = input_data.pop(key)
new_dict[key] = item1.values new_dict[key] = item1.values
# 将转换后的字典赋值给input_data
input_data = new_dict input_data = new_dict
# Convert the data in dict into tuple # Convert the data in dict into tuple
data = () data = () # 初始化一个空元组
keys = list(input_data.keys()) keys = list(input_data.keys()) # 将输入数据的键转换为列表
self.column_list = keys self.column_list = keys # 将键列表赋值给实例变量column_list
for key in keys: for key in keys: # 遍历键列表
value = input_data[key] value = input_data[key] # 获取键对应的值
data = data + (list(value),) data = data + (list(value),) # 将值转换为列表,并添加到元组中
return data return data # 返回元组
class NumpySlicesDataset(GeneratorDataset): class NumpySlicesDataset(GeneratorDataset):
@ -909,7 +988,9 @@ class NumpySlicesDataset(GeneratorDataset):
@check_numpyslicesdataset @check_numpyslicesdataset
def __init__(self, data, column_names=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, def __init__(self, data, column_names=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None,
num_shards=None, shard_id=None): num_shards=None, shard_id=None):
# 创建一个_NumpySlicesDataset对象传入data和column_names参数
dataset = _NumpySlicesDataset(data, column_names) dataset = _NumpySlicesDataset(data, column_names)
# 调用父类的__init__方法传入dataset、column_names、num_samples、num_parallel_workers、shuffle、sampler、num_shards和shard_id参数
super().__init__(dataset, column_names=dataset.column_list, num_samples=num_samples, super().__init__(dataset, column_names=dataset.column_list, num_samples=num_samples,
num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler, num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler,
num_shards=num_shards, shard_id=shard_id) num_shards=num_shards, shard_id=shard_id)

@ -17,65 +17,71 @@ from ..._checkparam import Validator as validator
from ...common import dtype as mstype from ...common import dtype as mstype
from ..primitive import prim_attr_register, PrimitiveWithCheck from ..primitive import prim_attr_register, PrimitiveWithCheck
from .. import signature as sig from .. import signature as sig
class UpdateCache(PrimitiveWithCheck): class UpdateCache(PrimitiveWithCheck):
""" """
Update the value fo input_x, similar to ScatterNdUpdate. 更新 input_x 的值类似于 ScatterNdUpdate
The difference is that UpdateCache will not update when indices < 0 or indices >= max_num. 不同之处在于UpdateCache indices < 0 indices >= max_num 时不会更新
Inputs: Inputs:
- **input_x** (Parameter) - Parameter which is going to be updated. - **input_x** (Parameter) - 将要更新的参数
- **indices** (Tensor) - Update indices of input_x. - **indices** (Tensor) - input_x 的更新索引
- **updates** (Tensor) - The update values. - **updates** (Tensor) - 更新值
Outputs: Outputs:
- **out** (Tensor) - Returns a [1] Tensor, which is not useful. - **out** (Tensor) - 返回一个 [1] 的张量这个张量没有用处
""" """
# 定义函数签名,指定输入参数的类型和读写权限
__mindspore_signature__ = ( __mindspore_signature__ = (
# 定义输入参数input_x类型为T读写权限为写
sig.make_sig('input_x', sig.sig_rw.RW_WRITE, sig.make_sig('input_x', sig.sig_rw.RW_WRITE,
dtype=sig.sig_dtype.T), dtype=sig.sig_dtype.T),
# 定义输入参数indices类型为T1
sig.make_sig('indices', dtype=sig.sig_dtype.T1), sig.make_sig('indices', dtype=sig.sig_dtype.T1),
# 定义输入参数updates类型为T
sig.make_sig('updates', dtype=sig.sig_dtype.T), sig.make_sig('updates', dtype=sig.sig_dtype.T),
# 定义输入参数max_num类型为T1
sig.make_sig('max_num', dtype=sig.sig_dtype.T1) sig.make_sig('max_num', dtype=sig.sig_dtype.T1)
) )
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""init UpdateCache""" """初始化 UpdateCache"""
# 初始化输入和输出名称
self.init_prim_io_names(inputs=['input_x', 'indices', 'update', 'max_num'], self.init_prim_io_names(inputs=['input_x', 'indices', 'update', 'max_num'],
outputs=['out']) outputs=['out'])
def check_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape): def check_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape):
# 检查输入形状
return [1] return [1]
def check_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype): def check_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype):
# 检查输入数据类型
validator.check_tensor_dtype_valid( validator.check_tensor_dtype_valid(
"indices", indices_dtype, mstype.int_type, self.name) "indices", indices_dtype, mstype.int_type, self.name)
return input_x_dtype return input_x_dtype
class SubAndFilter(PrimitiveWithCheck): class SubAndFilter(PrimitiveWithCheck):
""" """
Dynamic kernel, sub an offset and 动态内核减去一个偏移量并返回在范围 [0, max_num) 内的元素
return the elements which in range [0, max_num).
Inputs: Inputs:
- **input_x** (Tensor) - Input tensor. - **input_x** (Tensor) - 输入张量
- **max_num** (Int) - The max value of element that after sub `offset`. - **max_num** (Int) - 减去 `offset` 后元素的最大值
- **offset** (int) - Specifies the offset value of this `input_x`. - **offset** (int) - 指定此 `input_x` 的偏移值
Outputs: Outputs:
tuple(Tensor), tuple of 2 tensors, filter_res and filter_idx. tuple(Tensor), 2 个张量组成的元组filter_res filter_idx
- **filter_res** (Tensor) - The result that `input_x` minus `offset`, - **filter_res** (Tensor) - `input_x` 减去 `offset` 的结果
and return which in the range [0, max_num). 并返回在范围 [0, max_num) 内的值
- **filter_idx** (Tensor) - A tensor containing indices of elements in the input - **filter_idx** (Tensor) - 一个张量包含与输出张量对应的输入元素的索引
coressponding to the output tensor.
Supported Platforms: Supported Platforms:
`CPU` `CPU`
Examples: Examples:
>>> x = Tensor(np.array([1, 3, 5, 8, 9, 16]), mindspore.int32) >>> x = Tensor(np.array([1, 3, 5, 8, 9, 16]), mindspore.int32)
>>> max_num = 10 >>> max_num = 10
@ -87,35 +93,38 @@ class SubAndFilter(PrimitiveWithCheck):
""" """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""init SubAndFilter""" """初始化 SubAndFilter"""
# 初始化输入和输出名称
self.init_prim_io_names(inputs=['input_x', 'max_num', 'offset'], self.init_prim_io_names(inputs=['input_x', 'max_num', 'offset'],
outputs=['sub_res', 'sub_idx']) outputs=['sub_res', 'sub_idx'])
def check_shape(self, input_x_shape, max_num_shape, offset_shape): def check_shape(self, input_x_shape, max_num_shape, offset_shape):
# 检查输入形状
return ((-1,), (-1,)) return ((-1,), (-1,))
def check_dtype(self, input_x_dtype, max_num_dtype, offset_dtype): def check_dtype(self, input_x_dtype, max_num_dtype, offset_dtype):
# 检查输入数据类型
validator.check_tensor_dtype_valid( validator.check_tensor_dtype_valid(
"input_x", input_x_dtype, mstype.int_type, self.name) "input_x", input_x_dtype, mstype.int_type, self.name)
return input_x_dtype return input_x_dtype
class MapUniform(PrimitiveWithCheck): class MapUniform(PrimitiveWithCheck):
""" """
Map a tensor by using fomula : value = key % `group_num` * `per_group_size` + key // `group_num`. 通过公式映射一个张量value = key % `group_num` * `per_group_size` + key // `group_num`
Inputs: Inputs:
- **input** (Tensor) - Input Tensor. - **input** (Tensor) - 输入张量
- **per_group_size** (int) - The size of each group. - **per_group_size** (int) - 每个组的大小
- **group_num** (int) - The number of group. - **group_num** (int) - 组的数量
Outputs: Outputs:
Tensor, has the same dtype and shape as the `input`. Tensor具有与 `input` 相同的 dtype 和形状
Supported Platforms: Supported Platforms:
`CPU` `CPU`
Examples: Examples:
>>> input_x = Tensor(np.array([0, 1, 2, 3, 4, 5, 6, 7])) >>> input_x = Tensor(np.array([0, 1, 2, 3, 4, 5, 6, 7]))
>>> per_group_size = 4 >>> per_group_size = 4
@ -125,33 +134,34 @@ class MapUniform(PrimitiveWithCheck):
>>> print(output) >>> print(output)
[0, 4, 1, 5, 2, 6, 3, 7] [0, 4, 1, 5, 2, 6, 3, 7]
""" """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""init MapUniform""" """初始化 MapUniform"""
self.init_prim_io_names(inputs=['input', 'per_group_size', 'group_num'], self.init_prim_io_names(inputs=['input', 'per_group_size', 'group_num'],
outputs=['output']) outputs=['output'])
def check_dtype(self, input_dtype, per_group_size_dtype, group_num_dtype): def check_dtype(self, input_dtype, per_group_size_dtype, group_num_dtype):
"""检查输入数据类型"""
validator.check_tensor_dtype_valid( validator.check_tensor_dtype_valid(
"input", input_dtype, mstype.int_type, self.name) "input", input_dtype, mstype.int_type, self.name)
validator.check_value_type( validator.check_value_type(
'per_group_size', per_group_size_dtype, [mstype.Int], self.name) 'per_group_size', per_group_size_dtype, [mstype.Int], self.name)
validator.check_value_type( validator.check_value_type(
'group_num', group_num_dtype, [mstype.Int], self.name) 'group_num', group_num_dtype, [mstype.Int], self.name)
class CacheSwapTable(PrimitiveWithCheck): class CacheSwapTable(PrimitiveWithCheck):
""" """
Delete a hashmap entry,and insert a new key to hashmap, return the key and value of delete entry. 删除一个哈希映射条目并插入一个新键到哈希映射中返回删除条目的键和值
Inputs: Inputs:
- **cache_table** (Parameter) - The cache table which is on device. - **cache_table** (Parameter) - 在设备上的缓存表
- **swap_cache_idx** (Tensor) - The index of table which need to swap. -1 is skipped. - **swap_cache_idx** (Tensor) - 需要交换的表索引-1 被跳过
- **miss_value** (int) - The values which arg going to swap into cache table. - **miss_value** (int) - 将要交换到缓存表的值
Outputs: Outputs:
- **old_value** (Tensor) - The values which are swapped out. - **old_value** (Tensor) - 被交换出去的值
""" """
__mindspore_signature__ = ( __mindspore_signature__ = (
sig.make_sig('cache_table', sig.sig_rw.RW_WRITE, sig.make_sig('cache_table', sig.sig_rw.RW_WRITE,
@ -159,31 +169,35 @@ class CacheSwapTable(PrimitiveWithCheck):
sig.make_sig('swap_cache_idx', dtype=sig.sig_dtype.T1), sig.make_sig('swap_cache_idx', dtype=sig.sig_dtype.T1),
sig.make_sig('miss_value', dtype=sig.sig_dtype.T) sig.make_sig('miss_value', dtype=sig.sig_dtype.T)
) )
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""init CacheSwapTable""" """初始化 CacheSwapTable"""
self.init_prim_io_names(inputs=['cache_table', 'swap_cache_idx', 'miss_value'], self.init_prim_io_names(inputs=['cache_table', 'swap_cache_idx', 'miss_value'],
outputs=['old_value']) outputs=['old_value'])
def check_shape(self, cache_table_shape, swap_cache_idx_shape, miss_value_shape): def check_shape(self, cache_table_shape, swap_cache_idx_shape, miss_value_shape):
# 检查cache_table_shape的长度是否为2如果不是则抛出ValueError异常
if len(cache_table_shape) != 2: if len(cache_table_shape) != 2:
raise ValueError( raise ValueError(
"cache table shape must be 2, but got %d" % len(cache_table_shape)) "cache table shape must be 2, but got %d" % len(cache_table_shape))
# 返回miss_value_shape
return miss_value_shape return miss_value_shape
def check_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype): def check_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype):
# 检查swap_cache_idx_dtype是否为mstype.int_type如果不是则抛出ValueError异常
validator.check_tensor_dtype_valid( validator.check_tensor_dtype_valid(
"swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name) "swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name)
# 返回miss_value_dtype
return miss_value_dtype return miss_value_dtype
class MapCacheIdx(PrimitiveWithCheck): class MapCacheIdx(PrimitiveWithCheck):
""" """
MapCacheIdx merge SearchCacheIdx, CacheSwapHashmap, UpdateCache together. MapCacheIdx SearchCacheIdxCacheSwapHashmap UpdateCache 合并在一起
When input an indices tensor, it will output the cache indices which search in hashmap. 当输入一个索引张量时它将输出在哈希映射中搜索的缓存索引
""" """
__mindspore_signature__ = ( __mindspore_signature__ = (
sig.make_sig('hashmap', sig.sig_rw.RW_WRITE, sig.make_sig('hashmap', sig.sig_rw.RW_WRITE,
@ -193,56 +207,69 @@ class MapCacheIdx(PrimitiveWithCheck):
sig.make_sig('emb_max_num', dtype=sig.sig_dtype.T), sig.make_sig('emb_max_num', dtype=sig.sig_dtype.T),
sig.make_sig('cache_max_num', dtype=sig.sig_dtype.T) sig.make_sig('cache_max_num', dtype=sig.sig_dtype.T)
) )
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""init MapCacheIdx""" """初始化 MapCacheIdx"""
self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'offset'], self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'offset'],
outputs=['cache_idx', 'old_emb_idx', 'miss_emb_idx', 'swap_cache_idx']) outputs=['cache_idx', 'old_emb_idx', 'miss_emb_idx', 'swap_cache_idx'])
def __check__(self, hashmap, indices, step, emb_max_num, offset): def __check__(self, hashmap, indices, step, emb_max_num, offset):
# 获取hashmap的形状
hashmap_shape = hashmap['shape'] hashmap_shape = hashmap['shape']
# 如果hashmap的维度不是2则抛出异常
if len(hashmap_shape) != 2: if len(hashmap_shape) != 2:
raise ValueError("The dimension of 'hashmap' in SearchCacheIdx must be 2, " raise ValueError("The dimension of 'hashmap' in SearchCacheIdx must be 2, "
"but got %d." % len(hashmap_shape)) "but got %d." % len(hashmap_shape))
# 设置输出的形状
out_shape = (indices['shape'], -1, -1, -1) out_shape = (indices['shape'], -1, -1, -1)
# 获取hashmap和indices的数据类型
hashmap_dtype = hashmap['dtype'] hashmap_dtype = hashmap['dtype']
indices_dtype = indices['dtype'] indices_dtype = indices['dtype']
# 将数据类型存入字典
args = {"hashmap": hashmap_dtype, "indices": indices_dtype} args = {"hashmap": hashmap_dtype, "indices": indices_dtype}
# 检查数据类型是否相同且有效
validator.check_tensors_dtypes_same_and_valid( validator.check_tensors_dtypes_same_and_valid(
args, mstype.int_type, self.name) args, mstype.int_type, self.name)
# 设置输出的数据类型
out_dtype = (hashmap_dtype, hashmap_dtype, out_dtype = (hashmap_dtype, hashmap_dtype,
hashmap_dtype, hashmap_dtype) hashmap_dtype, hashmap_dtype)
# 设置输出的字典
out = {'shape': out_shape, out = {'shape': out_shape,
'dtype': out_dtype, 'dtype': out_dtype,
'value': None} 'value': None}
# 如果indices中有max_shape则设置输出的max_shape
if 'max_shape' in indices: if 'max_shape' in indices:
out['max_shape'] = (indices['max_shape'], indices['max_shape'], out['max_shape'] = (indices['max_shape'], indices['max_shape'],
indices['max_shape'], indices['max_shape']) indices['max_shape'], indices['max_shape'])
# 否则设置输出的max_shape为indices的形状
else: else:
out['max_shape'] = (indices['shape'], indices['shape'], out['max_shape'] = (indices['shape'], indices['shape'],
indices['shape'], indices['shape']) indices['shape'], indices['shape'])
# 如果indices中有min_shape则设置输出的min_shape
if 'min_shape' in indices: if 'min_shape' in indices:
out['min_shape'] = (indices['min_shape'], 0, 0, 0) out['min_shape'] = (indices['min_shape'], 0, 0, 0)
# 否则设置输出的min_shape为(0, 0, 0, 0)
else: else:
out['min_shape'] = (0, 0, 0, 0) out['min_shape'] = (0, 0, 0, 0)
# 返回输出的字典
return out return out
class DynamicAssign(PrimitiveWithCheck): class DynamicAssign(PrimitiveWithCheck):
""" """
Assigns `Parameter` with a value, the `value` can have a dynamic shape. `Parameter` 与值分配`value` 可以具有动态形状
Inputs: Inputs:
- **variable** (Parameter) - The `Parameter`. - **variable** (Parameter) - `Parameter`
- **value** (Tensor) - The value to be assigned. - **value** (Tensor) - 要分配的值
Outputs: Outputs:
Tensor, has the same type as original `variable`. Tensor具有与原始 `variable` 相同的类型
Supported Platforms: Supported Platforms:
`CPU` `CPU`
""" """
@ -250,41 +277,42 @@ class DynamicAssign(PrimitiveWithCheck):
sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('value', dtype=sig.sig_dtype.T) sig.make_sig('value', dtype=sig.sig_dtype.T)
) )
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
self.init_prim_io_names(inputs=['ref', 'value'], outputs=['output']) self.init_prim_io_names(inputs=['ref', 'value'], outputs=['output'])
def check_dtype(self, variable, value): def check_dtype(self, variable, value):
# 检查变量是否为mstype.type_refkey
if variable != mstype.type_refkey: if variable != mstype.type_refkey:
# 检查变量是否为mstype.number_type类型
validator.check_tensor_dtype_valid( validator.check_tensor_dtype_valid(
"variable", variable, mstype.number_type, self.name) "variable", variable, mstype.number_type, self.name)
# 检查value是否为mstype.number_type类型
validator.check_scalar_or_tensor_types_same( validator.check_scalar_or_tensor_types_same(
{"value": value}, mstype.number_type, self.name) {"value": value}, mstype.number_type, self.name)
class PadAndShift(PrimitiveWithCheck): class PadAndShift(PrimitiveWithCheck):
""" """
Pad a tensor with -1, and shift with a length. -1 填充张量并按长度进行移位
Inputs: Inputs:
- **input_x** (Tensor) - The input Tensor, which will be copied - **input_x** (Tensor) - 输入张量将被复制到 `output`
to `output`. - **cum_sum_arr** (Tensor) - cum_sum_arr 的最后一个值是输出张量的填充长度
- **cum_sum_arr** (Tensor) - The last value of cum_sum_arr is cum_sum_arr[shift_idx] 是开始移位cum_sum_arr[shift_idx+1] 是结束
the pad length of output tensor, cum_sum_arr[shift_idx] is - **shift_idx** (Int) - cum_sum_arr 的索引
the start to shift, and cum_sum_arr[shift_idx+1] is the end. 如果使用 PythonPadAndShift
- **shift_idx** (Int) - The idx of cum_sum_arr.
if use python, PadAndShift is:
output = [-1] * cum_sum_arr[-1] output = [-1] * cum_sum_arr[-1]
start = cum_sum_arr[shift_idx] start = cum_sum_arr[shift_idx]
end = cum_sum_arr[shift_idx + 1] end = cum_sum_arr[shift_idx + 1]
output[start:end] = input_x[:(end-start)] output[start:end] = input_x[:(end-start)]
Outputs: Outputs:
Tensor, has the same type as original `variable`. Tensor具有与原始 `variable` 相同的类型
Supported Platforms: Supported Platforms:
`CPU` `CPU`
Examples: Examples:
>>> input_x = Tensor(np.array([9, 13, -1, -1, -1, -1, -1, -1]), mstype.int32) >>> input_x = Tensor(np.array([9, 13, -1, -1, -1, -1, -1, -1]), mstype.int32)
>>> cum_sum_arr = Tensor(np.array([0, 3, 5]), mstype.int32) >>> cum_sum_arr = Tensor(np.array([0, 3, 5]), mstype.int32)
@ -296,11 +324,14 @@ class PadAndShift(PrimitiveWithCheck):
""" """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
# 初始化输入输出名称
self.init_prim_io_names( self.init_prim_io_names(
inputs=['input_x', 'cum_sum_arr', 'shift_idx'], outputs=['output']) inputs=['input_x', 'cum_sum_arr', 'shift_idx'], outputs=['output'])
def check_shape(self, input_x_shape, cum_sum_arr_shape, shift_idx_shape): def check_shape(self, input_x_shape, cum_sum_arr_shape, shift_idx_shape):
# 检查输入形状
return input_x_shape return input_x_shape
def check_dtype(self, input_x_dtype, cum_sum_arr_dtype, shift_idx_dtype): def check_dtype(self, input_x_dtype, cum_sum_arr_dtype, shift_idx_dtype):
return input_x_dtype # 检查输入数据类型
return input_x_dtype

@ -12,39 +12,39 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Operators for TensorArray.""" """Operators for TensorArray."""
import mindspore as ms import mindspore as ms
from ..._checkparam import Validator as validator from ..._checkparam import Validator as validator
from ..._checkparam import Rel from ..._checkparam import Rel
from ...common import dtype as mstype from ...common import dtype as mstype
from ..primitive import prim_attr_register, PrimitiveWithInfer, Primitive from ..primitive import prim_attr_register, PrimitiveWithInfer, Primitive
class TensorArray(PrimitiveWithInfer): class TensorArray(PrimitiveWithInfer):
r""" r"""
TensorArrayCreate used to create a TensorArray and return an unique handle. TensorArrayCreate used to create a TensorArray and return an unique handle.
.. warning:: .. warning::
This is an experimental prototype that is subject to change and/or deletion. This is an experimental prototype that is subject to change and/or deletion.
Args: Args:
dtype (mindspore.dtype): the data type in the TensorArray. dtype (mindspore.dtype): the data type in the TensorArray.
element_shape (tuple[int]): the shape of each tensor in a TensorArray. element_shape (tuple[int]): the shape of each tensor in a TensorArray.
dynamic_size (bool): If true the TensorArray can increase the size. Default: True. dynamic_size (bool): If true the TensorArray can increase the size. Default: True.
size (int): The size of the TensorArray if dynamic_size = False. size (int): The size of the TensorArray if dynamic_size = False.
name (string): the name of this TensorArray. Default: "TA". name (string): the name of this TensorArray. Default: "TA".
Inputs: Inputs:
None. None.
Outputs: Outputs:
- **output** (Tensor[mindspore.int64]) - an unique handle binded to the TensorArray. - **output** (Tensor[mindspore.int64]) - an unique handle binded to the TensorArray.
Supported Platforms: Supported Platforms:
``GPU`` ``CPU`` ``GPU`` ``CPU``
Examples: Examples:
>>> import mindspore >>> import mindspore
>>> import mindspore.ops as ops >>> import mindspore.ops as ops
@ -55,6 +55,7 @@ class TensorArray(PrimitiveWithInfer):
""" """
@prim_attr_register @prim_attr_register
def __init__(self, dtype, element_shape, dynamic_size=True, size=0, name="TA"): def __init__(self, dtype, element_shape, dynamic_size=True, size=0, name="TA"):
"""初始化TensorArray类设置参数和属性."""
validator.check_type_name("dtype", dtype, mstype.number_type + (mstype.bool_,), self.name) validator.check_type_name("dtype", dtype, mstype.number_type + (mstype.bool_,), self.name)
validator.check_int(size, 0, Rel.GE, "size", self.name) validator.check_int(size, 0, Rel.GE, "size", self.name)
self.add_prim_attr('dtype', dtype) self.add_prim_attr('dtype', dtype)
@ -63,32 +64,34 @@ class TensorArray(PrimitiveWithInfer):
self.add_prim_attr('size', size) self.add_prim_attr('size', size)
self.add_prim_attr('side_effect_mem', True) self.add_prim_attr('side_effect_mem', True)
self.add_prim_attr('name', name) self.add_prim_attr('name', name)
def infer_shape(self): def infer_shape(self):
"""推断输出形状."""
return () return ()
def infer_dtype(self): def infer_dtype(self):
"""推断输出数据类型."""
return mstype.int64 return mstype.int64
class TensorArrayWrite(PrimitiveWithInfer): class TensorArrayWrite(PrimitiveWithInfer):
r""" r"""
TensorArrayWrite used to write tensor into a created TensorArray. TensorArrayWrite used to write tensor into a created TensorArray.
.. warning:: .. warning::
This is an experimental prototype that is subject to change and/or deletion. This is an experimental prototype that is subject to change and/or deletion.
Inputs: Inputs:
- **index** (Tensor[int64]) - The position to write. - **index** (Tensor[int64]) - The position to write.
- **value** (Tensor) - The value to add into the TensorArray. - **value** (Tensor) - The value to add into the TensorArray.
- **handle** (Tensor[int64]) - The handle pointed to the TensorArray. - **handle** (Tensor[int64]) - The handle pointed to the TensorArray.
Outputs: Outputs:
None. None.
Supported Platforms: Supported Platforms:
``GPU`` ``CPU`` ``GPU`` ``CPU``
Examples: Examples:
>>> import mindspore >>> import mindspore
>>> import mindspore.ops as ops >>> import mindspore.ops as ops
@ -99,39 +102,42 @@ class TensorArrayWrite(PrimitiveWithInfer):
""" """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""初始化TensorArrayWrite类."""
self.add_prim_attr('side_effect_mem', True) self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, handle_shape, index_shape, value_shape): def infer_shape(self, handle_shape, index_shape, value_shape):
"""推断输出形状."""
return () return ()
def infer_dtype(self, handle_type, index_type, value_type): def infer_dtype(self, handle_type, index_type, value_type):
"""推断输出数据类型."""
validator.check_type_name("handle", handle_type, (ms.int64), self.name) validator.check_type_name("handle", handle_type, (ms.int64), self.name)
validator.check_type_name("index", index_type, (int, ms.int64), self.name) validator.check_type_name("index", index_type, (int, ms.int64), self.name)
validator.check_type_name("value", value_type, mstype.number_type + (mstype.bool_,), self.name) validator.check_type_name("value", value_type, mstype.number_type + (mstype.bool_,), self.name)
return mstype.int64 return mstype.int64
class TensorArrayRead(PrimitiveWithInfer): class TensorArrayRead(PrimitiveWithInfer):
r""" r"""
TensorArrayRead used to read tensor from a created TensorArray by the given index. TensorArrayRead used to read tensor from a created TensorArray by the given index.
.. warning:: .. warning::
This is an experimental prototype that is subject to change and/or deletion. This is an experimental prototype that is subject to change and/or deletion.
Args: Args:
dtype (mindspore.dtype): the data type in the TensorArray. dtype (mindspore.dtype): the data type in the TensorArray.
element_shape (tuple[int]): the shape of each tensor in a TensorArray. element_shape (tuple[int]): the shape of each tensor in a TensorArray.
Inputs: Inputs:
- **index** (Tensor[int64]) - The position to read. - **index** (Tensor[int64]) - The position to read.
- **handle** (mindspore.int64) - The handle pointed to the TensorArray. - **handle** (mindspore.int64) - The handle pointed to the TensorArray.
Outputs: Outputs:
- **output** (Tensor) - the value in position index. - **output** (Tensor) - the value in position index.
Supported Platforms: Supported Platforms:
``GPU`` ``CPU`` ``GPU`` ``CPU``
Examples: Examples:
>>> import mindspore >>> import mindspore
>>> import mindspore.ops as ops >>> import mindspore.ops as ops
@ -146,38 +152,41 @@ class TensorArrayRead(PrimitiveWithInfer):
""" """
@prim_attr_register @prim_attr_register
def __init__(self, dtype, element_shape): def __init__(self, dtype, element_shape):
"""初始化TensorArrayRead类设置参数和属性."""
validator.check_type_name("dtype", dtype, mstype.number_type + (mstype.bool_,), self.name) validator.check_type_name("dtype", dtype, mstype.number_type + (mstype.bool_,), self.name)
self.add_prim_attr('dtype', dtype) self.add_prim_attr('dtype', dtype)
self.add_prim_attr('element_shape', element_shape) self.add_prim_attr('element_shape', element_shape)
self.add_prim_attr('side_effect_mem', True) self.add_prim_attr('side_effect_mem', True)
self.dtype = dtype self.dtype = dtype
self.shape = element_shape self.shape = element_shape
def infer_shape(self, handle_shape, index_shape): def infer_shape(self, handle_shape, index_shape):
"""推断输出形状."""
return self.shape return self.shape
def infer_dtype(self, handle_type, index_type): def infer_dtype(self, handle_type, index_type):
"""推断输出数据类型."""
validator.check_type_name("handle", handle_type, (ms.int64), self.name) validator.check_type_name("handle", handle_type, (ms.int64), self.name)
validator.check_type_name("index", index_type, (int, ms.int64), self.name) validator.check_type_name("index", index_type, (int, ms.int64), self.name)
return self.dtype return self.dtype
class TensorArrayClose(PrimitiveWithInfer): class TensorArrayClose(PrimitiveWithInfer):
r""" r"""
TensorArrayClose used to close the created TensorArray. The resources in TensorArray will be deleted. TensorArrayClose used to close the created TensorArray. The resources in TensorArray will be deleted.
.. warning:: .. warning::
This is an experimental prototype that is subject to change and/or deletion. This is an experimental prototype that is subject to change and/or deletion.
Inputs: Inputs:
- **handle** (mindspore.int64) - The handle pointed to the TensorArray. - **handle** (mindspore.int64) - The handle pointed to the TensorArray.
Outputs: Outputs:
None. None.
Supported Platforms: Supported Platforms:
``GPU`` ``CPU`` ``GPU`` ``CPU``
Examples: Examples:
>>> import mindspore >>> import mindspore
>>> import mindspore.ops as ops >>> import mindspore.ops as ops
@ -188,32 +197,35 @@ class TensorArrayClose(PrimitiveWithInfer):
""" """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""初始化TensorArrayClose类."""
self.add_prim_attr('side_effect_mem', True) self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, handle_shape): def infer_shape(self, handle_shape):
"""推断输出形状."""
return () return ()
def infer_dtype(self, handle_type): def infer_dtype(self, handle_type):
"""推断输出数据类型."""
validator.check_type_name("handle", handle_type, (ms.int64), self.name) validator.check_type_name("handle", handle_type, (ms.int64), self.name)
return mstype.int64 return mstype.int64
class TensorArrayClear(PrimitiveWithInfer): class TensorArrayClear(PrimitiveWithInfer):
r""" r"""
TensorArrayClear used to reset the created TensorArray. The instance of TensorArray is still aviliable. TensorArrayClear used to reset the created TensorArray. The instance of TensorArray is still aviliable.
.. warning:: .. warning::
This is an experimental prototype that is subject to change and/or deletion. This is an experimental prototype that is subject to change and/or deletion.
Inputs: Inputs:
- **handle** (mindspore.int64) - The handle pointed to the TensorArray. - **handle** (mindspore.int64) - The handle pointed to the TensorArray.
Outputs: Outputs:
None. None.
Supported Platforms: Supported Platforms:
``GPU`` ``CPU`` ``GPU`` ``CPU``
Examples: Examples:
>>> import mindspore >>> import mindspore
>>> import mindspore.ops as ops >>> import mindspore.ops as ops
@ -224,36 +236,39 @@ class TensorArrayClear(PrimitiveWithInfer):
""" """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""初始化TensorArrayClear类."""
self.add_prim_attr('side_effect_mem', True) self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, handle_shape): def infer_shape(self, handle_shape):
"""推断输出形状."""
return () return ()
def infer_dtype(self, handle_type): def infer_dtype(self, handle_type):
"""推断输出数据类型."""
validator.check_type_name("handle", handle_type, (ms.int64), self.name) validator.check_type_name("handle", handle_type, (ms.int64), self.name)
return mstype.int64 return mstype.int64
class TensorArrayStack(Primitive): class TensorArrayStack(Primitive):
r""" r"""
TensorArrayStack used to stack the tensors in a created TensorArray into one tensor. TensorArrayStack used to stack the tensors in a created TensorArray into one tensor.
.. warning:: .. warning::
This is an experimental prototype that is subject to change and/or deletion. This is an experimental prototype that is subject to change and/or deletion.
Args: Args:
dtype (mindspore.dtype): the data type in the TensorArray. dtype (mindspore.dtype): the data type in the TensorArray.
element_shape (tuple[int]): the shape of each tensor in a TensorArray. element_shape (tuple[int]): the shape of each tensor in a TensorArray.
Inputs: Inputs:
- **handle** (mindspore.int64) - The handle pointed to the TensorArray. - **handle** (mindspore.int64) - The handle pointed to the TensorArray.
Outputs: Outputs:
- **output** (Tensor) - the stacked value from the TensorArray. - **output** (Tensor) - the stacked value from the TensorArray.
Supported Platforms: Supported Platforms:
``GPU`` ``CPU`` ``GPU`` ``CPU``
Examples: Examples:
>>> import mindspore >>> import mindspore
>>> import mindspore.ops as ops >>> import mindspore.ops as ops
@ -269,31 +284,31 @@ class TensorArrayStack(Primitive):
""" """
@prim_attr_register @prim_attr_register
def __init__(self, dtype, element_shape, dynamic_size, size): def __init__(self, dtype, element_shape, dynamic_size, size):
"""Initialize TensorArrayStack""" """初始化TensorArrayStack类设置参数和属性."""
self.init_prim_io_names(inputs=[''], outputs=['output']) self.init_prim_io_names(inputs=[''], outputs=['output'])
self.add_prim_attr('dtype', dtype) self.add_prim_attr('dtype', dtype)
self.add_prim_attr('element_shape', element_shape) self.add_prim_attr('element_shape', element_shape)
self.add_prim_attr('is_dynamic_shape', dynamic_size) self.add_prim_attr('is_dynamic_shape', dynamic_size)
self.add_prim_attr('size', size) self.add_prim_attr('size', size)
self.add_prim_attr('side_effect_mem', True) self.add_prim_attr('side_effect_mem', True)
class TensorArraySize(PrimitiveWithInfer): class TensorArraySize(PrimitiveWithInfer):
r""" r"""
TensorArraySize used to get the logical size of the created TensorArray. TensorArraySize used to get the logical size of the created TensorArray.
.. warning:: .. warning::
This is an experimental prototype that is subject to change and/or deletion. This is an experimental prototype that is subject to change and/or deletion.
Inputs: Inputs:
- **handle** (mindspore.int64) - The handle pointed to the TensorArray. - **handle** (mindspore.int64) - The handle pointed to the TensorArray.
Outputs: Outputs:
- **output** (Tensor[mindspore.int64]) - the logical size of the TensorArray. - **output** (Tensor[mindspore.int64]) - the logical size of the TensorArray.
Supported Platforms: Supported Platforms:
``GPU`` ``CPU`` ``GPU`` ``CPU``
Examples: Examples:
>>> import mindspore >>> import mindspore
>>> import mindspore.ops as ops >>> import mindspore.ops as ops
@ -304,34 +319,37 @@ class TensorArraySize(PrimitiveWithInfer):
""" """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""初始化TensorArraySize类."""
self.add_prim_attr('side_effect_mem', True) self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, handle_shape): def infer_shape(self, handle_shape):
"""推断输出形状."""
return () return ()
def infer_dtype(self, handle_type): def infer_dtype(self, handle_type):
"""推断输出数据类型."""
validator.check_type_name("handle", handle_type, (ms.int64), self.name) validator.check_type_name("handle", handle_type, (ms.int64), self.name)
return mstype.int64 return mstype.int64
class TensorArrayGather(PrimitiveWithInfer): class TensorArrayGather(PrimitiveWithInfer):
r""" r"""
TensorArrayGather used to gather specified elements from the created TensorArray. TensorArrayGather used to gather specified elements from the created TensorArray.
.. warning:: .. warning::
This is an experimental prototype that is subject to change and/or deletion. This is an experimental prototype that is subject to change and/or deletion.
Args: Args:
dtype (mindspore.dtype): the data type in the TensorArray. dtype (mindspore.dtype): the data type in the TensorArray.
element_shape (tuple[int]): the shape of each tensor in a TensorArray. element_shape (tuple[int]): the shape of each tensor in a TensorArray.
Inputs: Inputs:
- **handle** (mindspore.int64) - The handle pointed to the TensorArray. - **handle** (mindspore.int64) - The handle pointed to the TensorArray.
- **indices** (mindspore.int32) - The locations of the gathered elements. - **indices** (mindspore.int32) - The locations of the gathered elements.
Outputs: Outputs:
- **output** (Tensor) - The gathered value from the TensorArray. - **output** (Tensor) - The gathered value from the TensorArray.
Examples: Examples:
>>> import mindspore >>> import mindspore
>>> import mindspore.ops as ops >>> import mindspore.ops as ops
@ -344,17 +362,20 @@ class TensorArrayGather(PrimitiveWithInfer):
""" """
@prim_attr_register @prim_attr_register
def __init__(self, dtype, element_shape): def __init__(self, dtype, element_shape):
"""初始化TensorArrayGather类设置参数和属性."""
self.init_prim_io_names(inputs=['handle', 'indices'], outputs=['value']) self.init_prim_io_names(inputs=['handle', 'indices'], outputs=['value'])
self.add_prim_attr("side_effect_mem", True) self.add_prim_attr("side_effect_mem", True)
self.dtype = dtype self.dtype = dtype
self.element_shape = element_shape self.element_shape = element_shape
def infer_shape(self, handle, indices): def infer_shape(self, handle, indices):
"""推断输出形状."""
if len(indices) != 1: if len(indices) != 1:
return ValueError("indices dimension should be equal to 1") return ValueError("indices dimension should be equal to 1")
return [indices[0]] + list(self.element_shape) return [indices[0]] + list(self.element_shape)
def infer_dtype(self, handle, indices): def infer_dtype(self, handle, indices):
"""推断输出数据类型."""
validator.check_type_name("handle", handle, (ms.int64), self.name) validator.check_type_name("handle", handle, (ms.int64), self.name)
validator.check_type_name("indices", indices, (ms.int32), self.name) validator.check_type_name("indices", indices, (ms.int32), self.name)
return self.dtype return self.dtype

@ -366,31 +366,46 @@ class ModelCheckpoint(Callback):
""" """
def __init__(self, prefix='CKP', directory=None, config=None): def __init__(self, prefix='CKP', directory=None, config=None):
# 初始化函数,设置前缀、目录、配置等参数
super(ModelCheckpoint, self).__init__() super(ModelCheckpoint, self).__init__()
# 调用父类的初始化函数
self._latest_ckpt_file_name = "" self._latest_ckpt_file_name = ""
# 初始化最新检查点文件名为空字符串
self._init_time = time.time() self._init_time = time.time()
# 初始化初始化时间为当前时间
self._last_time = time.time() self._last_time = time.time()
# 初始化最后时间时间为当前时间
self._last_time_for_keep = time.time() self._last_time_for_keep = time.time()
# 初始化最后保存时间为当前时间
self._last_triggered_step = 0 self._last_triggered_step = 0
# 初始化最后触发的步数为0
# 检查前缀是否为字符串且不包含'/'
if not isinstance(prefix, str) or prefix.find('/') >= 0: if not isinstance(prefix, str) or prefix.find('/') >= 0:
raise ValueError("For 'ModelCheckpoint', the argument 'prefix' " raise ValueError("For 'ModelCheckpoint', the argument 'prefix' "
"for checkpoint file name is invalid, it must be " "for checkpoint file name is invalid, it must be "
"string and does not contain '/', but got {}.".format(prefix)) "string and does not contain '/', but got {}.".format(prefix))
self._prefix = prefix self._prefix = prefix
# 设置前缀
self._exception_prefix = prefix self._exception_prefix = prefix
# 设置异常前缀
# 如果目录不为空,则创建目录
if directory is not None: if directory is not None:
self._directory = _make_directory(directory) self._directory = _make_directory(directory)
else: else:
self._directory = _cur_dir self._directory = _cur_dir
# 否则,使用当前目录
# 如果启用了恢复上下文,则设置检查点路径
if _get_recovery_context("enable_recovery"): if _get_recovery_context("enable_recovery"):
_set_recovery_context(ckpt_path=self._directory) _set_recovery_context(ckpt_path=self._directory)
# 如果config为None则使用默认的CheckpointConfig
if config is None: if config is None:
self._config = CheckpointConfig() self._config = CheckpointConfig()
else: else:
# 如果config不是CheckpointConfig类型则抛出TypeError异常
if not isinstance(config, CheckpointConfig): if not isinstance(config, CheckpointConfig):
raise TypeError("For 'ModelCheckpoint', the type of argument 'config' should be " raise TypeError("For 'ModelCheckpoint', the type of argument 'config' should be "
"'CheckpointConfig', " "'CheckpointConfig', "
@ -398,11 +413,17 @@ class ModelCheckpoint(Callback):
self._config = config self._config = config
# get existing checkpoint files # get existing checkpoint files
# 创建CheckpointManager对象
self._manager = CheckpointManager() self._manager = CheckpointManager()
# 如果存在相同名称的文件,则更改文件名
self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix) self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
# 获取配置中的append_dict参数如果没有则设置为空字典
self._append_dict = self._config.append_dict or {} self._append_dict = self._config.append_dict or {}
# 获取append_dict中的epoch_num参数如果没有则设置为0
self._append_epoch_num = self._append_dict["epoch_num"] if "epoch_num" in self._append_dict else 0 self._append_epoch_num = self._append_dict["epoch_num"] if "epoch_num" in self._append_dict else 0
# 获取append_dict中的step_num参数如果没有则设置为0
self._append_step_num = self._append_dict["step_num"] if "step_num" in self._append_dict else 0 self._append_step_num = self._append_dict["step_num"] if "step_num" in self._append_dict else 0
# 标记是否已经保存了图
self._graph_saved = False self._graph_saved = False
self._need_flush_from_cache = True self._need_flush_from_cache = True
@ -413,6 +434,7 @@ class ModelCheckpoint(Callback):
Args: Args:
run_context (RunContext): Context of the train running. run_context (RunContext): Context of the train running.
""" """
# If the role is PServer, add the role name and rank to the prefix
if _is_role_pserver(): if _is_role_pserver():
self._prefix = "PServer_" + str(_get_ps_mode_rank()) + "_" + self._prefix self._prefix = "PServer_" + str(_get_ps_mode_rank()) + "_" + self._prefix
cb_params = run_context.original_args() cb_params = run_context.original_args()
@ -423,18 +445,23 @@ class ModelCheckpoint(Callback):
self._last_triggered_step = cb_params.last_save_ckpt_step self._last_triggered_step = cb_params.last_save_ckpt_step
cb_params.last_save_ckpt_step = None cb_params.last_save_ckpt_step = None
# Create the directory if it doesn't exist
_make_directory(self._directory) _make_directory(self._directory)
# save graph (only once) # save graph (only once)
if not self._graph_saved: if not self._graph_saved:
graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta') graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
# If the graph file already exists and the mode is GRAPH_MODE, remove it
if os.path.isfile(graph_file_name) and context.get_context("mode") == context.GRAPH_MODE: if os.path.isfile(graph_file_name) and context.get_context("mode") == context.GRAPH_MODE:
os.remove(graph_file_name) os.remove(graph_file_name)
# Save the graph
_save_graph(cb_params.train_network, graph_file_name) _save_graph(cb_params.train_network, graph_file_name)
self._graph_saved = True self._graph_saved = True
# Wait for any asynchronous checkpoint saving threads to finish
thread_list = threading.enumerate() thread_list = threading.enumerate()
for thread in thread_list: for thread in thread_list:
if thread.getName() == "asyn_save_ckpt": if thread.getName() == "asyn_save_ckpt":
thread.join() thread.join()
# Save the checkpoint
self._save_ckpt(cb_params) self._save_ckpt(cb_params)
def end(self, run_context): def end(self, run_context):
@ -444,44 +471,63 @@ class ModelCheckpoint(Callback):
Args: Args:
run_context (RunContext): Context of the train running. run_context (RunContext): Context of the train running.
""" """
# 获取训练的参数
cb_params = run_context.original_args() cb_params = run_context.original_args()
# 设置保存最后一个checkpoint的标志为True
_to_save_last_ckpt = True _to_save_last_ckpt = True
# 保存最后一个checkpoint
self._save_ckpt(cb_params, _to_save_last_ckpt) self._save_ckpt(cb_params, _to_save_last_ckpt)
# 获取当前线程列表
thread_list = threading.enumerate() thread_list = threading.enumerate()
# 遍历线程列表
for thread in thread_list: for thread in thread_list:
# 如果线程名为"asyn_save_ckpt",则等待该线程结束
if thread.getName() == "asyn_save_ckpt": if thread.getName() == "asyn_save_ckpt":
thread.join() thread.join()
# 销毁所有gather cell
destroy_allgather_cell() destroy_allgather_cell()
def _check_save_ckpt(self, cb_params, force_to_save): def _check_save_ckpt(self, cb_params, force_to_save):
"""Check whether save checkpoint files or not.""" """Check whether save checkpoint files or not."""
# 如果配置了保存检查点步数且步数大于0
if self._config.save_checkpoint_steps and self._config.save_checkpoint_steps > 0: if self._config.save_checkpoint_steps and self._config.save_checkpoint_steps > 0:
# 如果当前步数大于等于上次触发保存检查点步数加上保存检查点步数,或者强制保存检查点
if cb_params.cur_step_num >= self._last_triggered_step + self._config.save_checkpoint_steps \ if cb_params.cur_step_num >= self._last_triggered_step + self._config.save_checkpoint_steps \
or force_to_save is True: or force_to_save is True:
return True return True
# 如果配置了保存检查点秒数且秒数大于0
elif self._config.save_checkpoint_seconds and self._config.save_checkpoint_seconds > 0: elif self._config.save_checkpoint_seconds and self._config.save_checkpoint_seconds > 0:
# 获取当前时间
self._cur_time = time.time() self._cur_time = time.time()
# 如果当前时间减去上次时间大于保存检查点秒数,或者强制保存检查点
if (self._cur_time - self._last_time) > self._config.save_checkpoint_seconds or force_to_save: if (self._cur_time - self._last_time) > self._config.save_checkpoint_seconds or force_to_save:
# 更新上次时间
self._last_time = self._cur_time self._last_time = self._cur_time
return True return True
# 返回False
return False return False
def _save_ckpt(self, cb_params, force_to_save=False): def _save_ckpt(self, cb_params, force_to_save=False):
"""Save checkpoint files.""" """Save checkpoint files."""
# 如果当前步骤数等于最后触发的步骤数,则返回
if cb_params.cur_step_num == self._last_triggered_step: if cb_params.cur_step_num == self._last_triggered_step:
return return
# if param is cache enable, flush data from cache to host before save_ckpt # if param is cache enable, flush data from cache to host before save_ckpt
# 如果需要从缓存中刷新数据则调用_flush_from_cache方法
if self._need_flush_from_cache: if self._need_flush_from_cache:
self._flush_from_cache(cb_params) self._flush_from_cache(cb_params)
# 检查是否需要保存检查点如果force_to_save为True则强制保存
save_ckpt = self._check_save_ckpt(cb_params, force_to_save) save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
# 计算当前步数在epoch中的位置
step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1) step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
# 如果需要保存检查点,则创建当前检查点的文件名
if save_ckpt: if save_ckpt:
cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \ cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
+ str(step_num_in_epoch) + ".ckpt" + str(step_num_in_epoch) + ".ckpt"
@ -489,43 +535,68 @@ class ModelCheckpoint(Callback):
self._manager.update_ckpoint_filelist(self._directory, self._prefix) self._manager.update_ckpoint_filelist(self._directory, self._prefix)
# keep checkpoint files number equal max number. # keep checkpoint files number equal max number.
if self._config.keep_checkpoint_max and 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num: if self._config.keep_checkpoint_max and 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num:
# 如果keep_checkpoint_max配置存在且大于0且小于等于当前checkpoint文件数量则删除最旧的checkpoint文件
self._manager.remove_oldest_ckpoint_file() self._manager.remove_oldest_ckpoint_file()
elif self._config.keep_checkpoint_per_n_minutes and self._config.keep_checkpoint_per_n_minutes > 0: elif self._config.keep_checkpoint_per_n_minutes and self._config.keep_checkpoint_per_n_minutes > 0:
# 如果keep_checkpoint_per_n_minutes配置存在且大于0则记录当前时间
self._cur_time_for_keep = time.time() self._cur_time_for_keep = time.time()
# 如果当前时间与上次记录的时间之差小于keep_checkpoint_per_n_minutes配置的分钟数乘以60则保留每个分钟的一个checkpoint文件
if (self._cur_time_for_keep - self._last_time_for_keep) \ if (self._cur_time_for_keep - self._last_time_for_keep) \
< self._config.keep_checkpoint_per_n_minutes * 60: < self._config.keep_checkpoint_per_n_minutes * 60:
self._manager.keep_one_ckpoint_per_minutes(self._config.keep_checkpoint_per_n_minutes, self._manager.keep_one_ckpoint_per_minutes(self._config.keep_checkpoint_per_n_minutes,
self._cur_time_for_keep) self._cur_time_for_keep)
# generate the new checkpoint file and rename it. # generate the new checkpoint file and rename it.
# 定义全局变量_save_dir并将其赋值为self._directory
global _save_dir global _save_dir
_save_dir = self._directory _save_dir = self._directory
# 获取当前checkpoint文件的路径
cur_file = os.path.join(self._directory, cur_ckpoint_file) cur_file = os.path.join(self._directory, cur_ckpoint_file)
# 记录当前时间
self._last_time_for_keep = time.time() self._last_time_for_keep = time.time()
# 记录当前触发步数
self._last_triggered_step = cb_params.cur_step_num self._last_triggered_step = cb_params.cur_step_num
# 如果启用了GEGraph Execution
if context.get_context("enable_ge"): if context.get_context("enable_ge"):
# 设置当前网络
set_cur_net(cb_params.train_network) set_cur_net(cb_params.train_network)
# 执行checkpoint图
cb_params.train_network.exec_checkpoint_graph() cb_params.train_network.exec_checkpoint_graph()
# 如果_append_dict中包含"epoch_num"
if "epoch_num" in self._append_dict: if "epoch_num" in self._append_dict:
# 将_append_epoch_num加上当前epoch数赋值给"epoch_num"
self._append_dict["epoch_num"] = self._append_epoch_num + cb_params.cur_epoch_num self._append_dict["epoch_num"] = self._append_epoch_num + cb_params.cur_epoch_num
# 如果_append_dict中包含"step_num"
if "step_num" in self._append_dict: if "step_num" in self._append_dict:
# 将_append_step_num加上当前step数赋值给"step_num"
self._append_dict["step_num"] = self._append_step_num + cb_params.cur_step_num self._append_dict["step_num"] = self._append_step_num + cb_params.cur_step_num
# 获取保存的网络如果self._config.saved_network不为None则使用self._config.saved_network否则使用cb_params.train_network
network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network
# 保存checkpoint
save_checkpoint(network, cur_file, self._config.integrated_save, self._config.async_save, save_checkpoint(network, cur_file, self._config.integrated_save, self._config.async_save,
self._append_dict, self._config.enc_key, self._config.enc_mode) self._append_dict, self._config.enc_key, self._config.enc_mode)
# 记录最新的checkpoint文件名
self._latest_ckpt_file_name = cur_file self._latest_ckpt_file_name = cur_file
def _flush_from_cache(self, cb_params): def _flush_from_cache(self, cb_params):
"""Flush cache data to host if tensor is cache enable.""" """Flush cache data to host if tensor is cache enable."""
# 初始化has_cache_params为False
has_cache_params = False has_cache_params = False
# 获取训练网络中的参数
params = cb_params.train_network.get_parameters() params = cb_params.train_network.get_parameters()
# 遍历参数
for param in params: for param in params:
# 如果参数的cache_enable为True
if param.cache_enable: if param.cache_enable:
# 设置has_cache_params为True
has_cache_params = True has_cache_params = True
# 将参数的Tensor数据从缓存中刷新到主机
Tensor(param).flush_from_cache() Tensor(param).flush_from_cache()
# 如果没有参数的cache_enable为True
if not has_cache_params: if not has_cache_params:
# 设置_need_flush_from_cache为False
self._need_flush_from_cache = False self._need_flush_from_cache = False
@property @property
@ -535,63 +606,88 @@ class ModelCheckpoint(Callback):
class CheckpointManager: class CheckpointManager:
"""Manage checkpoint files according to train_config of checkpoint.""" """管理检查点文件,根据训练配置进行管理。"""
def __init__(self): def __init__(self):
"""初始化检查点管理器,创建空的检查点文件列表。"""
self._ckpoint_filelist = [] self._ckpoint_filelist = []
@property @property
def ckpoint_filelist(self): def ckpoint_filelist(self):
"""Get all the related checkpoint files managed here.""" """获取当前管理的所有检查点文件列表。"""
return self._ckpoint_filelist return self._ckpoint_filelist
@property @property
def ckpoint_num(self): def ckpoint_num(self):
"""Get the number of the related checkpoint files managed here.""" """获取当前管理的检查点文件数量。"""
return len(self._ckpoint_filelist) return len(self._ckpoint_filelist)
def update_ckpoint_filelist(self, directory, prefix): def update_ckpoint_filelist(self, directory, prefix):
"""Update the checkpoint file list.""" """更新检查点文件列表,根据目录和前缀筛选符合条件的检查点文件。"""
# 初始化一个空列表用于存储ckpt文件
self._ckpoint_filelist = [] self._ckpoint_filelist = []
# 获取指定目录下的所有文件
files = os.listdir(directory) files = os.listdir(directory)
# 遍历所有文件
for filename in files: for filename in files:
# 判断文件是否以指定前缀开头,并且以.ckpt结尾
if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix + "-"): if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix + "-"):
# 获取文件名中间部分
mid_name = filename[len(prefix):-5] mid_name = filename[len(prefix):-5]
# 判断中间部分是否包含字母
flag = not (True in [char.isalpha() for char in mid_name]) flag = not (True in [char.isalpha() for char in mid_name])
# 如果不包含字母,则将文件路径添加到列表中
if flag: if flag:
self._ckpoint_filelist.append(os.path.join(directory, filename)) self._ckpoint_filelist.append(os.path.join(directory, filename))
def remove_ckpoint_file(self, file_name): def remove_ckpoint_file(self, file_name):
"""Remove the specified checkpoint file from this checkpoint manager and also from the directory.""" """从检查点管理器中移除指定的检查点文件,并从目录中删除该文件。"""
try: try:
# 修改文件权限为可写
os.chmod(file_name, stat.S_IWRITE) os.chmod(file_name, stat.S_IWRITE)
# 删除文件
os.remove(file_name) os.remove(file_name)
# 从ckpoint文件列表中移除该文件
self._ckpoint_filelist.remove(file_name) self._ckpoint_filelist.remove(file_name)
except OSError: except OSError:
# 捕获OSError异常并记录警告日志
logger.warning("OSError, failed to remove the older ckpt file %s.", file_name) logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
except ValueError: except ValueError:
# 捕获ValueError异常并记录警告日志
logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name) logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
def remove_oldest_ckpoint_file(self): def remove_oldest_ckpoint_file(self):
"""Remove the oldest checkpoint file from this checkpoint manager and also from the directory.""" """移除检查点管理器中最早的检查点文件,并从目录中删除该文件。"""
# 获取所有checkpoint文件并按修改时间排序
ckpoint_files = sorted(self._ckpoint_filelist, key=os.path.getmtime) ckpoint_files = sorted(self._ckpoint_filelist, key=os.path.getmtime)
# 删除最早修改的checkpoint文件
self.remove_ckpoint_file(ckpoint_files[0]) self.remove_ckpoint_file(ckpoint_files[0])
def keep_one_ckpoint_per_minutes(self, minutes, cur_time): def keep_one_ckpoint_per_minutes(self, minutes, cur_time):
"""Only keep the latest one ckpt file per minutes, remove other files generated in [last_time, cur_time].""" """保留每分钟生成的最新检查点文件,移除在指定时间范围内生成的其他文件。"""
# 定义一个空列表,用于存储需要删除的文件
del_list = [] del_list = []
# 定义一个空字符串,用于存储最旧的文件名
oldest_file = '' oldest_file = ''
# 定义一个变量,用于存储当前时间
oldest_time = cur_time oldest_time = cur_time
# 遍历_ckpoint_filelist中的文件
for ck_file in self._ckpoint_filelist: for ck_file in self._ckpoint_filelist:
# 获取文件的修改时间
modify_time = os.path.getmtime(ck_file) modify_time = os.path.getmtime(ck_file)
# 如果当前时间减去文件的修改时间小于60*minutes则将文件添加到del_list中
if cur_time - modify_time < 60 * minutes: if cur_time - modify_time < 60 * minutes:
del_list.append(ck_file) del_list.append(ck_file)
# 如果文件的修改时间小于oldest_time则更新oldest_time和oldest_file
if modify_time < oldest_time: if modify_time < oldest_time:
oldest_time = modify_time oldest_time = modify_time
oldest_file = ck_file oldest_file = ck_file
# 遍历del_list中的文件
for mv_file in del_list: for mv_file in del_list:
# 如果文件是最旧的文件,则跳过
if mv_file == oldest_file: if mv_file == oldest_file:
continue continue
self.remove_ckpoint_file(mv_file) # 调用remove_ckpoint_file方法删除文件
self.remove_ckpoint_file(mv_file)

@ -256,36 +256,53 @@ class DatasetHelper:
""" """
def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1): def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1):
# 检查dataset_sink_mode是否为布尔值
dataset_sink_mode = Validator.check_bool(dataset_sink_mode) dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
# 检查sink_size是否为整数
Validator.check_is_int(sink_size) Validator.check_is_int(sink_size)
# 如果sink_size小于-1或者等于0抛出异常
if sink_size < -1 or sink_size == 0: if sink_size < -1 or sink_size == 0:
raise ValueError("The 'sink_size' must be -1 or positive, but got sink_size {}.".format(sink_size)) raise ValueError("The 'sink_size' must be -1 or positive, but got sink_size {}.".format(sink_size))
# 如果sink_size等于-1则将其设置为dataset的dataset_size
if sink_size == -1: if sink_size == -1:
sink_size = dataset.get_dataset_size() sink_size = dataset.get_dataset_size()
# 如果dataset_sink_mode为True则根据不同的设备类型选择不同的迭代器
if dataset_sink_mode: if dataset_sink_mode:
# 如果启用了GE则使用GE的迭代器
if context.get_context("enable_ge"): if context.get_context("enable_ge"):
iterclass = _DatasetIterGE iterclass = _DatasetIterGE
else: else:
# 如果当前模式为GRAPH_MODE则根据角色选择不同的迭代器
if context.get_context("mode") == context.GRAPH_MODE: if context.get_context("mode") == context.GRAPH_MODE:
# 如果当前角色为调度器或者参数服务器,则使用参数服务器的迭代器
if _is_role_sched() or _is_role_pserver(): if _is_role_sched() or _is_role_pserver():
iterclass = _DatasetIterPSServer iterclass = _DatasetIterPSServer
# 如果当前角色为工作节点并且是参数服务器模式,则使用参数服务器工作节点的迭代器
elif _is_role_worker() and _is_ps_mode(): elif _is_role_worker() and _is_ps_mode():
iterclass = _DatasetIterPSWork iterclass = _DatasetIterPSWork
# 如果当前设备类型为Ascend或者GPU则使用多线程循环的迭代器
elif (context.get_context("device_target") == "Ascend") or \ elif (context.get_context("device_target") == "Ascend") or \
(context.get_context("device_target") == "GPU"): (context.get_context("device_target") == "GPU"):
iterclass = _DatasetIterMSLoopSink iterclass = _DatasetIterMSLoopSink
# 如果当前设备类型为CPU则抛出异常因为CPU不支持数据集下沉模式
elif context.get_context("device_target") == "CPU": elif context.get_context("device_target") == "CPU":
raise RuntimeError("Currently dataset sink mode is not supported when the device " raise RuntimeError("Currently dataset sink mode is not supported when the device "
"target is CPU, please set dataset sink mode to False.") "target is CPU, please set dataset sink mode to False.")
# 如果当前模式不是GRAPH_MODE则使用PyNative的迭代器
else: else:
iterclass = _DatasetIterPyNative iterclass = _DatasetIterPyNative
# 创建迭代器
self.iter = iterclass(dataset, sink_size, epoch_num) self.iter = iterclass(dataset, sink_size, epoch_num)
# 如果dataset_sink_mode为False则使用普通的迭代器
else: else:
# 如果不是分布式训练则使用_DatasetIterNormal类
iterclass = _DatasetIterNormal iterclass = _DatasetIterNormal
# 初始化迭代器
self.iter = iterclass(dataset, epoch_num=epoch_num) self.iter = iterclass(dataset, epoch_num=epoch_num)
def __iter__(self): def __iter__(self):
# 返回self.iter的迭代器
return self.iter.__iter__() return self.iter.__iter__()
# A temp solution for loop sink. Delete later # A temp solution for loop sink. Delete later
@ -301,6 +318,7 @@ class DatasetHelper:
>>> >>>
>>> types, shapes = dataset_helper.types_shapes() >>> types, shapes = dataset_helper.types_shapes()
""" """
# 从当前配置的dataset中获取类型和形状
return self.iter.types_shapes() return self.iter.types_shapes()
def sink_size(self): def sink_size(self):
@ -316,18 +334,22 @@ class DatasetHelper:
>>> # if sink_size==-1, then will return the full size of source dataset. >>> # if sink_size==-1, then will return the full size of source dataset.
>>> sink_size = dataset_helper.sink_size() >>> sink_size = dataset_helper.sink_size()
""" """
# 返回迭代器的接收缓冲区大小
return self.iter.get_sink_size() return self.iter.get_sink_size()
def stop_send(self): def stop_send(self):
"""Stop send data about data sink.""" """Stop send data about data sink."""
# 停止发送关于数据接收器的数据
self.iter.stop_send() self.iter.stop_send()
def release(self): def release(self):
"""Free up resources about data sink.""" """Free up resources about data sink."""
# 释放数据接收器的资源
self.iter.release() self.iter.release()
def continue_send(self): def continue_send(self):
"""Continue to send data to device at the beginning of epoch.""" """Continue to send data to device at the beginning of epoch."""
# 在每个epoch的开始处继续向设备发送数据
self.iter.continue_send() self.iter.continue_send()
def _reset(self, step): def _reset(self, step):
@ -339,6 +361,7 @@ class DatasetHelper:
In sink mode, it returns the types and shapes of the current data. In sink mode, it returns the types and shapes of the current data.
Generally, it works in dynamic shape scenarios. Generally, it works in dynamic shape scenarios.
""" """
# 返回迭代器的数据信息
return self.iter.get_data_info() return self.iter.get_data_info()
def dynamic_min_max_shapes(self): def dynamic_min_max_shapes(self):
@ -355,6 +378,7 @@ class DatasetHelper:
>>> >>>
>>> min_shapes, max_shapes = dataset_helper.dynamic_min_max_shapes() >>> min_shapes, max_shapes = dataset_helper.dynamic_min_max_shapes()
""" """
# 返回self.iter的dynamic_min_max_shapes方法
return self.iter.dynamic_min_max_shapes() return self.iter.dynamic_min_max_shapes()
@ -362,20 +386,27 @@ class _DatasetIter:
"""Base iter for dataset helper""" """Base iter for dataset helper"""
def __init__(self, dataset, sink_size, epoch_num): def __init__(self, dataset, sink_size, epoch_num):
# 初始化函数传入数据集、sink大小和epoch数量
self.dataset = dataset self.dataset = dataset
self.sink_size = sink_size self.sink_size = sink_size
self.sink_count = self.get_sink_count(dataset) self.sink_count = self.get_sink_count(dataset)
# 如果数据集没有__transfer_dataset__属性
if not hasattr(dataset, '__transfer_dataset__'): if not hasattr(dataset, '__transfer_dataset__'):
# 如果数据集有__loop_size__属性
if hasattr(dataset, '__loop_size__'): if hasattr(dataset, '__loop_size__'):
# PS mode does not support loop sink and need get the real sink size. # PS mode does not support loop sink and need get the real sink size.
# 如果不是worker角色或者不是ps模式则设置sink_size为dataset的循环大小
if not (_is_role_worker() and _is_ps_mode()): if not (_is_role_worker() and _is_ps_mode()):
self.sink_size = dataset.__loop_size__ self.sink_size = dataset.__loop_size__
# 如果sink_size为1sink_count为1dataset的大小不为1并且设备目标为Ascend则创建数据信息队列
create_data_info_queue = (sink_size == 1 and self.sink_count == 1 and dataset.get_dataset_size() != 1 create_data_info_queue = (sink_size == 1 and self.sink_count == 1 and dataset.get_dataset_size() != 1
and context.get_context("device_target") == "Ascend") and context.get_context("device_target") == "Ascend")
# 执行数据图并将sink_size和create_data_info_queue作为参数传入
dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size, dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size,
create_data_info_queue=create_data_info_queue) create_data_info_queue=create_data_info_queue)
# 如果dataset没有__no_send__属性则发送数据
if not hasattr(dataset, '__no_send__'): if not hasattr(dataset, '__no_send__'):
_send_data(dataset, epoch_num) _send_data(dataset, epoch_num)
else: else:
@ -384,33 +415,48 @@ class _DatasetIter:
_cell_graph_executor.set_queue_name(dataset.__transfer_dataset__.queue_name) _cell_graph_executor.set_queue_name(dataset.__transfer_dataset__.queue_name)
_send_data_no_flag(dataset, epoch_num) _send_data_no_flag(dataset, epoch_num)
# 获取dataset的stop_send方法
self.stop_send = dataset.__transfer_dataset__.stop_send self.stop_send = dataset.__transfer_dataset__.stop_send
# 获取dataset的release方法
self.release = dataset.__transfer_dataset__.release self.release = dataset.__transfer_dataset__.release
# 获取dataset的continue_send方法
self.continue_send = dataset.__transfer_dataset__.continue_send self.continue_send = dataset.__transfer_dataset__.continue_send
# 获取dataset的get_data_info方法
self.get_data_info = dataset.__transfer_dataset__.get_data_info self.get_data_info = dataset.__transfer_dataset__.get_data_info
# 获取dataset的dynamic_min_max_shapes属性
self.dynamic_min_max_shapes = dataset.dynamic_min_max_shapes self.dynamic_min_max_shapes = dataset.dynamic_min_max_shapes
# 获取dataset的数据类型和数据形状
self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset) self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)
# 如果dataset的__transfer_dataset__属性中有_reset方法则获取该_reset方法
if hasattr(dataset.__transfer_dataset__, "_reset"): if hasattr(dataset.__transfer_dataset__, "_reset"):
self._reset = dataset.__transfer_dataset__._reset # pylint: disable=W0212 self._reset = dataset.__transfer_dataset__._reset # pylint: disable=W0212
def __iter__(self): def __iter__(self):
# 初始化索引为0
self.index = 0 self.index = 0
# 返回self
return self return self
# 迭代器的下一项
def __next__(self): def __next__(self):
# 如果索引大于等于sink_count抛出StopIteration异常
if self.index >= self.sink_count: if self.index >= self.sink_count:
raise StopIteration() raise StopIteration()
# 索引加1
self.index += 1 self.index += 1
# 返回op()的返回值
return self.op() return self.op()
def types_shapes(self): def types_shapes(self):
""" """
Return the types and shapes of the dataset. The type and shape of each data in the dataset 返回数据集的类型和形状数据集中每个数据的类型和形状应该是一致的
should be consistent.
""" """
return self.dataset_types, self.dataset_shapes return self.dataset_types, self.dataset_shapes
def get_sink_count(self, dataset): def get_sink_count(self, dataset):
"""
获取数据集的sink次数
:param dataset: 数据集对象
:return: sink次数
"""
sink_count = 1 sink_count = 1
if hasattr(dataset, '__loop_size__'): if hasattr(dataset, '__loop_size__'):
loop_size = dataset.__loop_size__ loop_size = dataset.__loop_size__
@ -421,7 +467,10 @@ class _DatasetIter:
return sink_count return sink_count
def get_sink_size(self): def get_sink_size(self):
"""get sink_size to device""" """
获取设备的sink大小
:return: sink大小
"""
sink_size = 1 sink_size = 1
if hasattr(self.dataset, '__loop_size__'): if hasattr(self.dataset, '__loop_size__'):
sink_size = self.dataset.__loop_size__ sink_size = self.dataset.__loop_size__

@ -23,47 +23,59 @@ from setuptools import setup, find_packages
from setuptools.command.egg_info import egg_info from setuptools.command.egg_info import egg_info
from setuptools.command.build_py import build_py from setuptools.command.build_py import build_py
# 获取环境变量
backend_policy = os.getenv('BACKEND_POLICY') backend_policy = os.getenv('BACKEND_POLICY')
device_target = os.getenv('BACKEND_TARGET') device_target = os.getenv('BACKEND_TARGET')
commit_id = os.getenv('COMMIT_ID').replace("\n", "") commit_id = os.getenv('COMMIT_ID').replace("\n", "")
package_name = os.getenv('MS_PACKAGE_NAME').replace("\n", "") package_name = os.getenv('MS_PACKAGE_NAME').replace("\n", "")
build_path = os.getenv('BUILD_PATH') build_path = os.getenv('BUILD_PATH')
# 获取当前文件路径
pwd = os.path.dirname(os.path.realpath(__file__)) pwd = os.path.dirname(os.path.realpath(__file__))
# 获取包目录路径
pkg_dir = os.path.join(build_path, 'package') pkg_dir = os.path.join(build_path, 'package')
def _read_file(filename): def _read_file(filename):
"""读取文件内容"""
with open(os.path.join(pwd, filename), encoding='UTF-8') as f: with open(os.path.join(pwd, filename), encoding='UTF-8') as f:
return f.read() return f.read()
# 读取版本号
version = _read_file('version.txt').replace("\n", "") version = _read_file('version.txt').replace("\n", "")
# 读取README.md文件内容
readme = _read_file('README.md') readme = _read_file('README.md')
def _write_version(file): def _write_version(file):
"""写入版本号"""
file.write("__version__ = '{}'\n".format(version)) file.write("__version__ = '{}'\n".format(version))
def _write_config(file): def _write_config(file):
"""写入后端策略"""
file.write("__backend__ = '{}'\n".format(backend_policy)) file.write("__backend__ = '{}'\n".format(backend_policy))
def _write_commit_file(file): def _write_commit_file(file):
"""写入commit_id"""
file.write("__commit_id__ = '{}'\n".format(commit_id)) file.write("__commit_id__ = '{}'\n".format(commit_id))
def _write_package_name(file): def _write_package_name(file):
"""写入包名"""
file.write("__package_name__ = '{}'\n".format(package_name)) file.write("__package_name__ = '{}'\n".format(package_name))
def _write_device_target(file): def _write_device_target(file):
"""写入设备目标"""
file.write("__device_target__ = '{}'\n".format(device_target)) file.write("__device_target__ = '{}'\n".format(device_target))
def build_dependencies(): def build_dependencies():
"""generate python file""" """generate python file"""
# 生成version.py文件
version_file = os.path.join(pkg_dir, 'mindspore', 'version.py') version_file = os.path.join(pkg_dir, 'mindspore', 'version.py')
with open(version_file, 'w') as f: with open(version_file, 'w') as f:
_write_version(f) _write_version(f)
@ -72,6 +84,7 @@ def build_dependencies():
with open(version_file, 'w') as f: with open(version_file, 'w') as f:
_write_version(f) _write_version(f)
# 生成default_config.py文件
config_file = os.path.join(pkg_dir, 'mindspore', 'default_config.py') config_file = os.path.join(pkg_dir, 'mindspore', 'default_config.py')
with open(config_file, 'w') as f: with open(config_file, 'w') as f:
_write_config(f) _write_config(f)
@ -80,6 +93,7 @@ def build_dependencies():
with open(config_file, 'w') as f: with open(config_file, 'w') as f:
_write_config(f) _write_config(f)
# 向default_config.py文件中追加device_target
target = os.path.join(pkg_dir, 'mindspore', 'default_config.py') target = os.path.join(pkg_dir, 'mindspore', 'default_config.py')
with open(target, 'a') as f: with open(target, 'a') as f:
_write_device_target(f) _write_device_target(f)
@ -88,6 +102,7 @@ def build_dependencies():
with open(target, 'a') as f: with open(target, 'a') as f:
_write_device_target(f) _write_device_target(f)
# 向default_config.py文件中追加package_name
package_info = os.path.join(pkg_dir, 'mindspore', 'default_config.py') package_info = os.path.join(pkg_dir, 'mindspore', 'default_config.py')
with open(package_info, 'a') as f: with open(package_info, 'a') as f:
_write_package_name(f) _write_package_name(f)
@ -96,6 +111,7 @@ def build_dependencies():
with open(package_info, 'a') as f: with open(package_info, 'a') as f:
_write_package_name(f) _write_package_name(f)
# 生成.commit_id文件
commit_file = os.path.join(pkg_dir, 'mindspore', '.commit_id') commit_file = os.path.join(pkg_dir, 'mindspore', '.commit_id')
with open(commit_file, 'w') as f: with open(commit_file, 'w') as f:
_write_commit_file(f) _write_commit_file(f)
@ -145,16 +161,24 @@ def update_permissions(path):
Args: Args:
path (str): Target directory path. path (str): Target directory path.
""" """
# 判断操作系统是否为Windows
if platform.system() == "Windows": if platform.system() == "Windows":
return return
# 遍历目标目录下的所有文件和文件夹
for dirpath, dirnames, filenames in os.walk(path): for dirpath, dirnames, filenames in os.walk(path):
# 遍历文件夹
for dirname in dirnames: for dirname in dirnames:
# 获取文件夹的完整路径
dir_fullpath = os.path.join(dirpath, dirname) dir_fullpath = os.path.join(dirpath, dirname)
# 更新文件夹的权限
os.chmod(dir_fullpath, stat.S_IREAD | stat.S_IWRITE | os.chmod(dir_fullpath, stat.S_IREAD | stat.S_IWRITE |
stat.S_IEXEC | stat.S_IRGRP | stat.S_IXGRP) stat.S_IEXEC | stat.S_IRGRP | stat.S_IXGRP)
# 遍历文件
for filename in filenames: for filename in filenames:
# 获取文件的完整路径
file_fullpath = os.path.join(dirpath, filename) file_fullpath = os.path.join(dirpath, filename)
# 更新文件的权限
os.chmod(file_fullpath, stat.S_IREAD) os.chmod(file_fullpath, stat.S_IREAD)
@ -163,7 +187,9 @@ class EggInfo(egg_info):
def run(self): def run(self):
super().run() super().run()
# 获取egg-info目录的路径
egg_info_dir = os.path.join(pkg_dir, 'mindspore.egg-info') egg_info_dir = os.path.join(pkg_dir, 'mindspore.egg-info')
# 更新egg-info目录的权限
update_permissions(egg_info_dir) update_permissions(egg_info_dir)
@ -172,41 +198,64 @@ class BuildPy(build_py):
def run(self): def run(self):
super().run() super().run()
# 获取build目录下的lib/mindspore目录的路径
mindspore_dir = os.path.join(pkg_dir, 'build', 'lib', 'mindspore') mindspore_dir = os.path.join(pkg_dir, 'build', 'lib', 'mindspore')
# 更新lib/mindspore目录的权限
update_permissions(mindspore_dir) update_permissions(mindspore_dir)
# 获取build目录下的lib/mindspore/_akg目录的路径
mindspore_dir = os.path.join(pkg_dir, 'build', 'lib', 'mindspore', '_akg') mindspore_dir = os.path.join(pkg_dir, 'build', 'lib', 'mindspore', '_akg')
# 更新lib/mindspore/_akg目录的权限
update_permissions(mindspore_dir) update_permissions(mindspore_dir)
# 设置包的名称
setup( setup(
name=package_name, name=package_name,
# 设置包的版本
version=version, version=version,
# 设置包的作者
author='The MindSpore Authors', author='The MindSpore Authors',
# 设置包的作者邮箱
author_email='contact@mindspore.cn', author_email='contact@mindspore.cn',
# 设置包的网址
url='https://www.mindspore.cn', url='https://www.mindspore.cn',
# 设置包的下载网址
download_url='https://github.com/mindspore-ai/mindspore/tags', download_url='https://github.com/mindspore-ai/mindspore/tags',
# 设置包的源代码网址
project_urls={ project_urls={
'Sources': 'https://github.com/mindspore-ai/mindspore', 'Sources': 'https://github.com/mindspore-ai/mindspore',
# 设置包的问题追踪网址
'Issue Tracker': 'https://github.com/mindspore-ai/mindspore/issues', 'Issue Tracker': 'https://github.com/mindspore-ai/mindspore/issues',
}, },
# 设置包的描述
description='MindSpore is a new open source deep learning training/inference ' description='MindSpore is a new open source deep learning training/inference '
'framework that could be used for mobile, edge and cloud scenarios.', 'framework that could be used for mobile, edge and cloud scenarios.',
# 读取readme文件作为包的详细描述
long_description=readme, long_description=readme,
# 设置详细描述的格式
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
# 查找包中的所有模块
packages=find_packages(), packages=find_packages(),
# 设置包的数据
package_data=package_data, package_data=package_data,
# 包含包中的所有数据
include_package_data=True, include_package_data=True,
# 设置自定义的命令类
cmdclass={ cmdclass={
'egg_info': EggInfo, 'egg_info': EggInfo,
'build_py': BuildPy, 'build_py': BuildPy,
}, },
# 设置包的入口点
entry_points={ entry_points={
'console_scripts': [ 'console_scripts': [
'cache_admin=mindspore.dataset.engine.cache_admin:main', 'cache_admin=mindspore.dataset.engine.cache_admin:main',
], ],
}, },
# 设置包的Python版本要求
python_requires='>=3.7', python_requires='>=3.7',
# 设置包的依赖
install_requires=required_package, install_requires=required_package,
# 设置包的分类器
classifiers=[ classifiers=[
'Development Status :: 4 - Beta', 'Development Status :: 4 - Beta',
'Environment :: Console', 'Environment :: Console',
@ -223,6 +272,8 @@ setup(
'Topic :: Software Development :: Libraries', 'Topic :: Software Development :: Libraries',
'Topic :: Software Development :: Libraries :: Python Modules', 'Topic :: Software Development :: Libraries :: Python Modules',
], ],
# 设置包的许可证
license='Apache 2.0', license='Apache 2.0',
# 设置包的关键词
keywords='mindspore machine learning', keywords='mindspore machine learning',
) )

Loading…
Cancel
Save