pull/2/head
xiangguo 7 months ago
parent ea602b4c9f
commit 7fa1f5ab19

@ -17,65 +17,71 @@ from ..._checkparam import Validator as validator
from ...common import dtype as mstype from ...common import dtype as mstype
from ..primitive import prim_attr_register, PrimitiveWithCheck from ..primitive import prim_attr_register, PrimitiveWithCheck
from .. import signature as sig from .. import signature as sig
class UpdateCache(PrimitiveWithCheck): class UpdateCache(PrimitiveWithCheck):
""" """
Update the value fo input_x, similar to ScatterNdUpdate. 更新 input_x 的值类似于 ScatterNdUpdate
The difference is that UpdateCache will not update when indices < 0 or indices >= max_num. 不同之处在于UpdateCache indices < 0 indices >= max_num 时不会更新
Inputs: Inputs:
- **input_x** (Parameter) - Parameter which is going to be updated. - **input_x** (Parameter) - 将要更新的参数
- **indices** (Tensor) - Update indices of input_x. - **indices** (Tensor) - input_x 的更新索引
- **updates** (Tensor) - The update values. - **updates** (Tensor) - 更新值
Outputs: Outputs:
- **out** (Tensor) - Returns a [1] Tensor, which is not useful. - **out** (Tensor) - 返回一个 [1] 的张量这个张量没有用处
""" """
# 定义函数签名,指定输入参数的类型和读写权限
__mindspore_signature__ = ( __mindspore_signature__ = (
# 定义输入参数input_x类型为T读写权限为写
sig.make_sig('input_x', sig.sig_rw.RW_WRITE, sig.make_sig('input_x', sig.sig_rw.RW_WRITE,
dtype=sig.sig_dtype.T), dtype=sig.sig_dtype.T),
# 定义输入参数indices类型为T1
sig.make_sig('indices', dtype=sig.sig_dtype.T1), sig.make_sig('indices', dtype=sig.sig_dtype.T1),
# 定义输入参数updates类型为T
sig.make_sig('updates', dtype=sig.sig_dtype.T), sig.make_sig('updates', dtype=sig.sig_dtype.T),
# 定义输入参数max_num类型为T1
sig.make_sig('max_num', dtype=sig.sig_dtype.T1) sig.make_sig('max_num', dtype=sig.sig_dtype.T1)
) )
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""init UpdateCache""" """初始化 UpdateCache"""
# 初始化输入和输出名称
self.init_prim_io_names(inputs=['input_x', 'indices', 'update', 'max_num'], self.init_prim_io_names(inputs=['input_x', 'indices', 'update', 'max_num'],
outputs=['out']) outputs=['out'])
def check_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape): def check_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape):
# 检查输入形状
return [1] return [1]
def check_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype): def check_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype):
# 检查输入数据类型
validator.check_tensor_dtype_valid( validator.check_tensor_dtype_valid(
"indices", indices_dtype, mstype.int_type, self.name) "indices", indices_dtype, mstype.int_type, self.name)
return input_x_dtype return input_x_dtype
class SubAndFilter(PrimitiveWithCheck): class SubAndFilter(PrimitiveWithCheck):
""" """
Dynamic kernel, sub an offset and 动态内核减去一个偏移量并返回在范围 [0, max_num) 内的元素
return the elements which in range [0, max_num).
Inputs: Inputs:
- **input_x** (Tensor) - Input tensor. - **input_x** (Tensor) - 输入张量
- **max_num** (Int) - The max value of element that after sub `offset`. - **max_num** (Int) - 减去 `offset` 后元素的最大值
- **offset** (int) - Specifies the offset value of this `input_x`. - **offset** (int) - 指定此 `input_x` 的偏移值
Outputs: Outputs:
tuple(Tensor), tuple of 2 tensors, filter_res and filter_idx. tuple(Tensor), 2 个张量组成的元组filter_res filter_idx
- **filter_res** (Tensor) - The result that `input_x` minus `offset`, - **filter_res** (Tensor) - `input_x` 减去 `offset` 的结果
and return which in the range [0, max_num). 并返回在范围 [0, max_num) 内的值
- **filter_idx** (Tensor) - A tensor containing indices of elements in the input - **filter_idx** (Tensor) - 一个张量包含与输出张量对应的输入元素的索引
coressponding to the output tensor.
Supported Platforms: Supported Platforms:
`CPU` `CPU`
Examples: Examples:
>>> x = Tensor(np.array([1, 3, 5, 8, 9, 16]), mindspore.int32) >>> x = Tensor(np.array([1, 3, 5, 8, 9, 16]), mindspore.int32)
>>> max_num = 10 >>> max_num = 10
@ -87,35 +93,38 @@ class SubAndFilter(PrimitiveWithCheck):
""" """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""init SubAndFilter""" """初始化 SubAndFilter"""
# 初始化输入和输出名称
self.init_prim_io_names(inputs=['input_x', 'max_num', 'offset'], self.init_prim_io_names(inputs=['input_x', 'max_num', 'offset'],
outputs=['sub_res', 'sub_idx']) outputs=['sub_res', 'sub_idx'])
def check_shape(self, input_x_shape, max_num_shape, offset_shape): def check_shape(self, input_x_shape, max_num_shape, offset_shape):
# 检查输入形状
return ((-1,), (-1,)) return ((-1,), (-1,))
def check_dtype(self, input_x_dtype, max_num_dtype, offset_dtype): def check_dtype(self, input_x_dtype, max_num_dtype, offset_dtype):
# 检查输入数据类型
validator.check_tensor_dtype_valid( validator.check_tensor_dtype_valid(
"input_x", input_x_dtype, mstype.int_type, self.name) "input_x", input_x_dtype, mstype.int_type, self.name)
return input_x_dtype return input_x_dtype
class MapUniform(PrimitiveWithCheck): class MapUniform(PrimitiveWithCheck):
""" """
Map a tensor by using fomula : value = key % `group_num` * `per_group_size` + key // `group_num`. 通过公式映射一个张量value = key % `group_num` * `per_group_size` + key // `group_num`
Inputs: Inputs:
- **input** (Tensor) - Input Tensor. - **input** (Tensor) - 输入张量
- **per_group_size** (int) - The size of each group. - **per_group_size** (int) - 每个组的大小
- **group_num** (int) - The number of group. - **group_num** (int) - 组的数量
Outputs: Outputs:
Tensor, has the same dtype and shape as the `input`. Tensor具有与 `input` 相同的 dtype 和形状
Supported Platforms: Supported Platforms:
`CPU` `CPU`
Examples: Examples:
>>> input_x = Tensor(np.array([0, 1, 2, 3, 4, 5, 6, 7])) >>> input_x = Tensor(np.array([0, 1, 2, 3, 4, 5, 6, 7]))
>>> per_group_size = 4 >>> per_group_size = 4
@ -125,33 +134,34 @@ class MapUniform(PrimitiveWithCheck):
>>> print(output) >>> print(output)
[0, 4, 1, 5, 2, 6, 3, 7] [0, 4, 1, 5, 2, 6, 3, 7]
""" """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""init MapUniform""" """初始化 MapUniform"""
self.init_prim_io_names(inputs=['input', 'per_group_size', 'group_num'], self.init_prim_io_names(inputs=['input', 'per_group_size', 'group_num'],
outputs=['output']) outputs=['output'])
def check_dtype(self, input_dtype, per_group_size_dtype, group_num_dtype): def check_dtype(self, input_dtype, per_group_size_dtype, group_num_dtype):
"""检查输入数据类型"""
validator.check_tensor_dtype_valid( validator.check_tensor_dtype_valid(
"input", input_dtype, mstype.int_type, self.name) "input", input_dtype, mstype.int_type, self.name)
validator.check_value_type( validator.check_value_type(
'per_group_size', per_group_size_dtype, [mstype.Int], self.name) 'per_group_size', per_group_size_dtype, [mstype.Int], self.name)
validator.check_value_type( validator.check_value_type(
'group_num', group_num_dtype, [mstype.Int], self.name) 'group_num', group_num_dtype, [mstype.Int], self.name)
class CacheSwapTable(PrimitiveWithCheck): class CacheSwapTable(PrimitiveWithCheck):
""" """
Delete a hashmap entry,and insert a new key to hashmap, return the key and value of delete entry. 删除一个哈希映射条目并插入一个新键到哈希映射中返回删除条目的键和值
Inputs: Inputs:
- **cache_table** (Parameter) - The cache table which is on device. - **cache_table** (Parameter) - 在设备上的缓存表
- **swap_cache_idx** (Tensor) - The index of table which need to swap. -1 is skipped. - **swap_cache_idx** (Tensor) - 需要交换的表索引-1 被跳过
- **miss_value** (int) - The values which arg going to swap into cache table. - **miss_value** (int) - 将要交换到缓存表的值
Outputs: Outputs:
- **old_value** (Tensor) - The values which are swapped out. - **old_value** (Tensor) - 被交换出去的值
""" """
__mindspore_signature__ = ( __mindspore_signature__ = (
sig.make_sig('cache_table', sig.sig_rw.RW_WRITE, sig.make_sig('cache_table', sig.sig_rw.RW_WRITE,
@ -159,31 +169,35 @@ class CacheSwapTable(PrimitiveWithCheck):
sig.make_sig('swap_cache_idx', dtype=sig.sig_dtype.T1), sig.make_sig('swap_cache_idx', dtype=sig.sig_dtype.T1),
sig.make_sig('miss_value', dtype=sig.sig_dtype.T) sig.make_sig('miss_value', dtype=sig.sig_dtype.T)
) )
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""init CacheSwapTable""" """初始化 CacheSwapTable"""
self.init_prim_io_names(inputs=['cache_table', 'swap_cache_idx', 'miss_value'], self.init_prim_io_names(inputs=['cache_table', 'swap_cache_idx', 'miss_value'],
outputs=['old_value']) outputs=['old_value'])
def check_shape(self, cache_table_shape, swap_cache_idx_shape, miss_value_shape): def check_shape(self, cache_table_shape, swap_cache_idx_shape, miss_value_shape):
# 检查cache_table_shape的长度是否为2如果不是则抛出ValueError异常
if len(cache_table_shape) != 2: if len(cache_table_shape) != 2:
raise ValueError( raise ValueError(
"cache table shape must be 2, but got %d" % len(cache_table_shape)) "cache table shape must be 2, but got %d" % len(cache_table_shape))
# 返回miss_value_shape
return miss_value_shape return miss_value_shape
def check_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype): def check_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype):
# 检查swap_cache_idx_dtype是否为mstype.int_type如果不是则抛出ValueError异常
validator.check_tensor_dtype_valid( validator.check_tensor_dtype_valid(
"swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name) "swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name)
# 返回miss_value_dtype
return miss_value_dtype return miss_value_dtype
class MapCacheIdx(PrimitiveWithCheck): class MapCacheIdx(PrimitiveWithCheck):
""" """
MapCacheIdx merge SearchCacheIdx, CacheSwapHashmap, UpdateCache together. MapCacheIdx SearchCacheIdxCacheSwapHashmap UpdateCache 合并在一起
When input an indices tensor, it will output the cache indices which search in hashmap. 当输入一个索引张量时它将输出在哈希映射中搜索的缓存索引
""" """
__mindspore_signature__ = ( __mindspore_signature__ = (
sig.make_sig('hashmap', sig.sig_rw.RW_WRITE, sig.make_sig('hashmap', sig.sig_rw.RW_WRITE,
@ -193,56 +207,69 @@ class MapCacheIdx(PrimitiveWithCheck):
sig.make_sig('emb_max_num', dtype=sig.sig_dtype.T), sig.make_sig('emb_max_num', dtype=sig.sig_dtype.T),
sig.make_sig('cache_max_num', dtype=sig.sig_dtype.T) sig.make_sig('cache_max_num', dtype=sig.sig_dtype.T)
) )
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""init MapCacheIdx""" """初始化 MapCacheIdx"""
self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'offset'], self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'offset'],
outputs=['cache_idx', 'old_emb_idx', 'miss_emb_idx', 'swap_cache_idx']) outputs=['cache_idx', 'old_emb_idx', 'miss_emb_idx', 'swap_cache_idx'])
def __check__(self, hashmap, indices, step, emb_max_num, offset): def __check__(self, hashmap, indices, step, emb_max_num, offset):
# 获取hashmap的形状
hashmap_shape = hashmap['shape'] hashmap_shape = hashmap['shape']
# 如果hashmap的维度不是2则抛出异常
if len(hashmap_shape) != 2: if len(hashmap_shape) != 2:
raise ValueError("The dimension of 'hashmap' in SearchCacheIdx must be 2, " raise ValueError("The dimension of 'hashmap' in SearchCacheIdx must be 2, "
"but got %d." % len(hashmap_shape)) "but got %d." % len(hashmap_shape))
# 设置输出的形状
out_shape = (indices['shape'], -1, -1, -1) out_shape = (indices['shape'], -1, -1, -1)
# 获取hashmap和indices的数据类型
hashmap_dtype = hashmap['dtype'] hashmap_dtype = hashmap['dtype']
indices_dtype = indices['dtype'] indices_dtype = indices['dtype']
# 将数据类型存入字典
args = {"hashmap": hashmap_dtype, "indices": indices_dtype} args = {"hashmap": hashmap_dtype, "indices": indices_dtype}
# 检查数据类型是否相同且有效
validator.check_tensors_dtypes_same_and_valid( validator.check_tensors_dtypes_same_and_valid(
args, mstype.int_type, self.name) args, mstype.int_type, self.name)
# 设置输出的数据类型
out_dtype = (hashmap_dtype, hashmap_dtype, out_dtype = (hashmap_dtype, hashmap_dtype,
hashmap_dtype, hashmap_dtype) hashmap_dtype, hashmap_dtype)
# 设置输出的字典
out = {'shape': out_shape, out = {'shape': out_shape,
'dtype': out_dtype, 'dtype': out_dtype,
'value': None} 'value': None}
# 如果indices中有max_shape则设置输出的max_shape
if 'max_shape' in indices: if 'max_shape' in indices:
out['max_shape'] = (indices['max_shape'], indices['max_shape'], out['max_shape'] = (indices['max_shape'], indices['max_shape'],
indices['max_shape'], indices['max_shape']) indices['max_shape'], indices['max_shape'])
# 否则设置输出的max_shape为indices的形状
else: else:
out['max_shape'] = (indices['shape'], indices['shape'], out['max_shape'] = (indices['shape'], indices['shape'],
indices['shape'], indices['shape']) indices['shape'], indices['shape'])
# 如果indices中有min_shape则设置输出的min_shape
if 'min_shape' in indices: if 'min_shape' in indices:
out['min_shape'] = (indices['min_shape'], 0, 0, 0) out['min_shape'] = (indices['min_shape'], 0, 0, 0)
# 否则设置输出的min_shape为(0, 0, 0, 0)
else: else:
out['min_shape'] = (0, 0, 0, 0) out['min_shape'] = (0, 0, 0, 0)
# 返回输出的字典
return out return out
class DynamicAssign(PrimitiveWithCheck): class DynamicAssign(PrimitiveWithCheck):
""" """
Assigns `Parameter` with a value, the `value` can have a dynamic shape. `Parameter` 与值分配`value` 可以具有动态形状
Inputs: Inputs:
- **variable** (Parameter) - The `Parameter`. - **variable** (Parameter) - `Parameter`
- **value** (Tensor) - The value to be assigned. - **value** (Tensor) - 要分配的值
Outputs: Outputs:
Tensor, has the same type as original `variable`. Tensor具有与原始 `variable` 相同的类型
Supported Platforms: Supported Platforms:
`CPU` `CPU`
""" """
@ -250,41 +277,42 @@ class DynamicAssign(PrimitiveWithCheck):
sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('value', dtype=sig.sig_dtype.T) sig.make_sig('value', dtype=sig.sig_dtype.T)
) )
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
self.init_prim_io_names(inputs=['ref', 'value'], outputs=['output']) self.init_prim_io_names(inputs=['ref', 'value'], outputs=['output'])
def check_dtype(self, variable, value): def check_dtype(self, variable, value):
# 检查变量是否为mstype.type_refkey
if variable != mstype.type_refkey: if variable != mstype.type_refkey:
# 检查变量是否为mstype.number_type类型
validator.check_tensor_dtype_valid( validator.check_tensor_dtype_valid(
"variable", variable, mstype.number_type, self.name) "variable", variable, mstype.number_type, self.name)
# 检查value是否为mstype.number_type类型
validator.check_scalar_or_tensor_types_same( validator.check_scalar_or_tensor_types_same(
{"value": value}, mstype.number_type, self.name) {"value": value}, mstype.number_type, self.name)
class PadAndShift(PrimitiveWithCheck): class PadAndShift(PrimitiveWithCheck):
""" """
Pad a tensor with -1, and shift with a length. -1 填充张量并按长度进行移位
Inputs: Inputs:
- **input_x** (Tensor) - The input Tensor, which will be copied - **input_x** (Tensor) - 输入张量将被复制到 `output`
to `output`. - **cum_sum_arr** (Tensor) - cum_sum_arr 的最后一个值是输出张量的填充长度
- **cum_sum_arr** (Tensor) - The last value of cum_sum_arr is cum_sum_arr[shift_idx] 是开始移位cum_sum_arr[shift_idx+1] 是结束
the pad length of output tensor, cum_sum_arr[shift_idx] is - **shift_idx** (Int) - cum_sum_arr 的索引
the start to shift, and cum_sum_arr[shift_idx+1] is the end. 如果使用 PythonPadAndShift
- **shift_idx** (Int) - The idx of cum_sum_arr.
if use python, PadAndShift is:
output = [-1] * cum_sum_arr[-1] output = [-1] * cum_sum_arr[-1]
start = cum_sum_arr[shift_idx] start = cum_sum_arr[shift_idx]
end = cum_sum_arr[shift_idx + 1] end = cum_sum_arr[shift_idx + 1]
output[start:end] = input_x[:(end-start)] output[start:end] = input_x[:(end-start)]
Outputs: Outputs:
Tensor, has the same type as original `variable`. Tensor具有与原始 `variable` 相同的类型
Supported Platforms: Supported Platforms:
`CPU` `CPU`
Examples: Examples:
>>> input_x = Tensor(np.array([9, 13, -1, -1, -1, -1, -1, -1]), mstype.int32) >>> input_x = Tensor(np.array([9, 13, -1, -1, -1, -1, -1, -1]), mstype.int32)
>>> cum_sum_arr = Tensor(np.array([0, 3, 5]), mstype.int32) >>> cum_sum_arr = Tensor(np.array([0, 3, 5]), mstype.int32)
@ -296,11 +324,14 @@ class PadAndShift(PrimitiveWithCheck):
""" """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
# 初始化输入输出名称
self.init_prim_io_names( self.init_prim_io_names(
inputs=['input_x', 'cum_sum_arr', 'shift_idx'], outputs=['output']) inputs=['input_x', 'cum_sum_arr', 'shift_idx'], outputs=['output'])
def check_shape(self, input_x_shape, cum_sum_arr_shape, shift_idx_shape): def check_shape(self, input_x_shape, cum_sum_arr_shape, shift_idx_shape):
# 检查输入形状
return input_x_shape return input_x_shape
def check_dtype(self, input_x_dtype, cum_sum_arr_dtype, shift_idx_dtype): def check_dtype(self, input_x_dtype, cum_sum_arr_dtype, shift_idx_dtype):
return input_x_dtype # 检查输入数据类型
return input_x_dtype

@ -12,39 +12,39 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Operators for TensorArray.""" """Operators for TensorArray."""
import mindspore as ms import mindspore as ms
from ..._checkparam import Validator as validator from ..._checkparam import Validator as validator
from ..._checkparam import Rel from ..._checkparam import Rel
from ...common import dtype as mstype from ...common import dtype as mstype
from ..primitive import prim_attr_register, PrimitiveWithInfer, Primitive from ..primitive import prim_attr_register, PrimitiveWithInfer, Primitive
class TensorArray(PrimitiveWithInfer): class TensorArray(PrimitiveWithInfer):
r""" r"""
TensorArrayCreate used to create a TensorArray and return an unique handle. TensorArrayCreate used to create a TensorArray and return an unique handle.
.. warning:: .. warning::
This is an experimental prototype that is subject to change and/or deletion. This is an experimental prototype that is subject to change and/or deletion.
Args: Args:
dtype (mindspore.dtype): the data type in the TensorArray. dtype (mindspore.dtype): the data type in the TensorArray.
element_shape (tuple[int]): the shape of each tensor in a TensorArray. element_shape (tuple[int]): the shape of each tensor in a TensorArray.
dynamic_size (bool): If true the TensorArray can increase the size. Default: True. dynamic_size (bool): If true the TensorArray can increase the size. Default: True.
size (int): The size of the TensorArray if dynamic_size = False. size (int): The size of the TensorArray if dynamic_size = False.
name (string): the name of this TensorArray. Default: "TA". name (string): the name of this TensorArray. Default: "TA".
Inputs: Inputs:
None. None.
Outputs: Outputs:
- **output** (Tensor[mindspore.int64]) - an unique handle binded to the TensorArray. - **output** (Tensor[mindspore.int64]) - an unique handle binded to the TensorArray.
Supported Platforms: Supported Platforms:
``GPU`` ``CPU`` ``GPU`` ``CPU``
Examples: Examples:
>>> import mindspore >>> import mindspore
>>> import mindspore.ops as ops >>> import mindspore.ops as ops
@ -55,6 +55,7 @@ class TensorArray(PrimitiveWithInfer):
""" """
@prim_attr_register @prim_attr_register
def __init__(self, dtype, element_shape, dynamic_size=True, size=0, name="TA"): def __init__(self, dtype, element_shape, dynamic_size=True, size=0, name="TA"):
"""初始化TensorArray类设置参数和属性."""
validator.check_type_name("dtype", dtype, mstype.number_type + (mstype.bool_,), self.name) validator.check_type_name("dtype", dtype, mstype.number_type + (mstype.bool_,), self.name)
validator.check_int(size, 0, Rel.GE, "size", self.name) validator.check_int(size, 0, Rel.GE, "size", self.name)
self.add_prim_attr('dtype', dtype) self.add_prim_attr('dtype', dtype)
@ -63,32 +64,34 @@ class TensorArray(PrimitiveWithInfer):
self.add_prim_attr('size', size) self.add_prim_attr('size', size)
self.add_prim_attr('side_effect_mem', True) self.add_prim_attr('side_effect_mem', True)
self.add_prim_attr('name', name) self.add_prim_attr('name', name)
def infer_shape(self): def infer_shape(self):
"""推断输出形状."""
return () return ()
def infer_dtype(self): def infer_dtype(self):
"""推断输出数据类型."""
return mstype.int64 return mstype.int64
class TensorArrayWrite(PrimitiveWithInfer): class TensorArrayWrite(PrimitiveWithInfer):
r""" r"""
TensorArrayWrite used to write tensor into a created TensorArray. TensorArrayWrite used to write tensor into a created TensorArray.
.. warning:: .. warning::
This is an experimental prototype that is subject to change and/or deletion. This is an experimental prototype that is subject to change and/or deletion.
Inputs: Inputs:
- **index** (Tensor[int64]) - The position to write. - **index** (Tensor[int64]) - The position to write.
- **value** (Tensor) - The value to add into the TensorArray. - **value** (Tensor) - The value to add into the TensorArray.
- **handle** (Tensor[int64]) - The handle pointed to the TensorArray. - **handle** (Tensor[int64]) - The handle pointed to the TensorArray.
Outputs: Outputs:
None. None.
Supported Platforms: Supported Platforms:
``GPU`` ``CPU`` ``GPU`` ``CPU``
Examples: Examples:
>>> import mindspore >>> import mindspore
>>> import mindspore.ops as ops >>> import mindspore.ops as ops
@ -99,39 +102,42 @@ class TensorArrayWrite(PrimitiveWithInfer):
""" """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""初始化TensorArrayWrite类."""
self.add_prim_attr('side_effect_mem', True) self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, handle_shape, index_shape, value_shape): def infer_shape(self, handle_shape, index_shape, value_shape):
"""推断输出形状."""
return () return ()
def infer_dtype(self, handle_type, index_type, value_type): def infer_dtype(self, handle_type, index_type, value_type):
"""推断输出数据类型."""
validator.check_type_name("handle", handle_type, (ms.int64), self.name) validator.check_type_name("handle", handle_type, (ms.int64), self.name)
validator.check_type_name("index", index_type, (int, ms.int64), self.name) validator.check_type_name("index", index_type, (int, ms.int64), self.name)
validator.check_type_name("value", value_type, mstype.number_type + (mstype.bool_,), self.name) validator.check_type_name("value", value_type, mstype.number_type + (mstype.bool_,), self.name)
return mstype.int64 return mstype.int64
class TensorArrayRead(PrimitiveWithInfer): class TensorArrayRead(PrimitiveWithInfer):
r""" r"""
TensorArrayRead used to read tensor from a created TensorArray by the given index. TensorArrayRead used to read tensor from a created TensorArray by the given index.
.. warning:: .. warning::
This is an experimental prototype that is subject to change and/or deletion. This is an experimental prototype that is subject to change and/or deletion.
Args: Args:
dtype (mindspore.dtype): the data type in the TensorArray. dtype (mindspore.dtype): the data type in the TensorArray.
element_shape (tuple[int]): the shape of each tensor in a TensorArray. element_shape (tuple[int]): the shape of each tensor in a TensorArray.
Inputs: Inputs:
- **index** (Tensor[int64]) - The position to read. - **index** (Tensor[int64]) - The position to read.
- **handle** (mindspore.int64) - The handle pointed to the TensorArray. - **handle** (mindspore.int64) - The handle pointed to the TensorArray.
Outputs: Outputs:
- **output** (Tensor) - the value in position index. - **output** (Tensor) - the value in position index.
Supported Platforms: Supported Platforms:
``GPU`` ``CPU`` ``GPU`` ``CPU``
Examples: Examples:
>>> import mindspore >>> import mindspore
>>> import mindspore.ops as ops >>> import mindspore.ops as ops
@ -146,38 +152,41 @@ class TensorArrayRead(PrimitiveWithInfer):
""" """
@prim_attr_register @prim_attr_register
def __init__(self, dtype, element_shape): def __init__(self, dtype, element_shape):
"""初始化TensorArrayRead类设置参数和属性."""
validator.check_type_name("dtype", dtype, mstype.number_type + (mstype.bool_,), self.name) validator.check_type_name("dtype", dtype, mstype.number_type + (mstype.bool_,), self.name)
self.add_prim_attr('dtype', dtype) self.add_prim_attr('dtype', dtype)
self.add_prim_attr('element_shape', element_shape) self.add_prim_attr('element_shape', element_shape)
self.add_prim_attr('side_effect_mem', True) self.add_prim_attr('side_effect_mem', True)
self.dtype = dtype self.dtype = dtype
self.shape = element_shape self.shape = element_shape
def infer_shape(self, handle_shape, index_shape): def infer_shape(self, handle_shape, index_shape):
"""推断输出形状."""
return self.shape return self.shape
def infer_dtype(self, handle_type, index_type): def infer_dtype(self, handle_type, index_type):
"""推断输出数据类型."""
validator.check_type_name("handle", handle_type, (ms.int64), self.name) validator.check_type_name("handle", handle_type, (ms.int64), self.name)
validator.check_type_name("index", index_type, (int, ms.int64), self.name) validator.check_type_name("index", index_type, (int, ms.int64), self.name)
return self.dtype return self.dtype
class TensorArrayClose(PrimitiveWithInfer): class TensorArrayClose(PrimitiveWithInfer):
r""" r"""
TensorArrayClose used to close the created TensorArray. The resources in TensorArray will be deleted. TensorArrayClose used to close the created TensorArray. The resources in TensorArray will be deleted.
.. warning:: .. warning::
This is an experimental prototype that is subject to change and/or deletion. This is an experimental prototype that is subject to change and/or deletion.
Inputs: Inputs:
- **handle** (mindspore.int64) - The handle pointed to the TensorArray. - **handle** (mindspore.int64) - The handle pointed to the TensorArray.
Outputs: Outputs:
None. None.
Supported Platforms: Supported Platforms:
``GPU`` ``CPU`` ``GPU`` ``CPU``
Examples: Examples:
>>> import mindspore >>> import mindspore
>>> import mindspore.ops as ops >>> import mindspore.ops as ops
@ -188,32 +197,35 @@ class TensorArrayClose(PrimitiveWithInfer):
""" """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""初始化TensorArrayClose类."""
self.add_prim_attr('side_effect_mem', True) self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, handle_shape): def infer_shape(self, handle_shape):
"""推断输出形状."""
return () return ()
def infer_dtype(self, handle_type): def infer_dtype(self, handle_type):
"""推断输出数据类型."""
validator.check_type_name("handle", handle_type, (ms.int64), self.name) validator.check_type_name("handle", handle_type, (ms.int64), self.name)
return mstype.int64 return mstype.int64
class TensorArrayClear(PrimitiveWithInfer): class TensorArrayClear(PrimitiveWithInfer):
r""" r"""
TensorArrayClear used to reset the created TensorArray. The instance of TensorArray is still aviliable. TensorArrayClear used to reset the created TensorArray. The instance of TensorArray is still aviliable.
.. warning:: .. warning::
This is an experimental prototype that is subject to change and/or deletion. This is an experimental prototype that is subject to change and/or deletion.
Inputs: Inputs:
- **handle** (mindspore.int64) - The handle pointed to the TensorArray. - **handle** (mindspore.int64) - The handle pointed to the TensorArray.
Outputs: Outputs:
None. None.
Supported Platforms: Supported Platforms:
``GPU`` ``CPU`` ``GPU`` ``CPU``
Examples: Examples:
>>> import mindspore >>> import mindspore
>>> import mindspore.ops as ops >>> import mindspore.ops as ops
@ -224,36 +236,39 @@ class TensorArrayClear(PrimitiveWithInfer):
""" """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""初始化TensorArrayClear类."""
self.add_prim_attr('side_effect_mem', True) self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, handle_shape): def infer_shape(self, handle_shape):
"""推断输出形状."""
return () return ()
def infer_dtype(self, handle_type): def infer_dtype(self, handle_type):
"""推断输出数据类型."""
validator.check_type_name("handle", handle_type, (ms.int64), self.name) validator.check_type_name("handle", handle_type, (ms.int64), self.name)
return mstype.int64 return mstype.int64
class TensorArrayStack(Primitive): class TensorArrayStack(Primitive):
r""" r"""
TensorArrayStack used to stack the tensors in a created TensorArray into one tensor. TensorArrayStack used to stack the tensors in a created TensorArray into one tensor.
.. warning:: .. warning::
This is an experimental prototype that is subject to change and/or deletion. This is an experimental prototype that is subject to change and/or deletion.
Args: Args:
dtype (mindspore.dtype): the data type in the TensorArray. dtype (mindspore.dtype): the data type in the TensorArray.
element_shape (tuple[int]): the shape of each tensor in a TensorArray. element_shape (tuple[int]): the shape of each tensor in a TensorArray.
Inputs: Inputs:
- **handle** (mindspore.int64) - The handle pointed to the TensorArray. - **handle** (mindspore.int64) - The handle pointed to the TensorArray.
Outputs: Outputs:
- **output** (Tensor) - the stacked value from the TensorArray. - **output** (Tensor) - the stacked value from the TensorArray.
Supported Platforms: Supported Platforms:
``GPU`` ``CPU`` ``GPU`` ``CPU``
Examples: Examples:
>>> import mindspore >>> import mindspore
>>> import mindspore.ops as ops >>> import mindspore.ops as ops
@ -269,31 +284,31 @@ class TensorArrayStack(Primitive):
""" """
@prim_attr_register @prim_attr_register
def __init__(self, dtype, element_shape, dynamic_size, size): def __init__(self, dtype, element_shape, dynamic_size, size):
"""Initialize TensorArrayStack""" """初始化TensorArrayStack类设置参数和属性."""
self.init_prim_io_names(inputs=[''], outputs=['output']) self.init_prim_io_names(inputs=[''], outputs=['output'])
self.add_prim_attr('dtype', dtype) self.add_prim_attr('dtype', dtype)
self.add_prim_attr('element_shape', element_shape) self.add_prim_attr('element_shape', element_shape)
self.add_prim_attr('is_dynamic_shape', dynamic_size) self.add_prim_attr('is_dynamic_shape', dynamic_size)
self.add_prim_attr('size', size) self.add_prim_attr('size', size)
self.add_prim_attr('side_effect_mem', True) self.add_prim_attr('side_effect_mem', True)
class TensorArraySize(PrimitiveWithInfer): class TensorArraySize(PrimitiveWithInfer):
r""" r"""
TensorArraySize used to get the logical size of the created TensorArray. TensorArraySize used to get the logical size of the created TensorArray.
.. warning:: .. warning::
This is an experimental prototype that is subject to change and/or deletion. This is an experimental prototype that is subject to change and/or deletion.
Inputs: Inputs:
- **handle** (mindspore.int64) - The handle pointed to the TensorArray. - **handle** (mindspore.int64) - The handle pointed to the TensorArray.
Outputs: Outputs:
- **output** (Tensor[mindspore.int64]) - the logical size of the TensorArray. - **output** (Tensor[mindspore.int64]) - the logical size of the TensorArray.
Supported Platforms: Supported Platforms:
``GPU`` ``CPU`` ``GPU`` ``CPU``
Examples: Examples:
>>> import mindspore >>> import mindspore
>>> import mindspore.ops as ops >>> import mindspore.ops as ops
@ -304,34 +319,37 @@ class TensorArraySize(PrimitiveWithInfer):
""" """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""初始化TensorArraySize类."""
self.add_prim_attr('side_effect_mem', True) self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, handle_shape): def infer_shape(self, handle_shape):
"""推断输出形状."""
return () return ()
def infer_dtype(self, handle_type): def infer_dtype(self, handle_type):
"""推断输出数据类型."""
validator.check_type_name("handle", handle_type, (ms.int64), self.name) validator.check_type_name("handle", handle_type, (ms.int64), self.name)
return mstype.int64 return mstype.int64
class TensorArrayGather(PrimitiveWithInfer): class TensorArrayGather(PrimitiveWithInfer):
r""" r"""
TensorArrayGather used to gather specified elements from the created TensorArray. TensorArrayGather used to gather specified elements from the created TensorArray.
.. warning:: .. warning::
This is an experimental prototype that is subject to change and/or deletion. This is an experimental prototype that is subject to change and/or deletion.
Args: Args:
dtype (mindspore.dtype): the data type in the TensorArray. dtype (mindspore.dtype): the data type in the TensorArray.
element_shape (tuple[int]): the shape of each tensor in a TensorArray. element_shape (tuple[int]): the shape of each tensor in a TensorArray.
Inputs: Inputs:
- **handle** (mindspore.int64) - The handle pointed to the TensorArray. - **handle** (mindspore.int64) - The handle pointed to the TensorArray.
- **indices** (mindspore.int32) - The locations of the gathered elements. - **indices** (mindspore.int32) - The locations of the gathered elements.
Outputs: Outputs:
- **output** (Tensor) - The gathered value from the TensorArray. - **output** (Tensor) - The gathered value from the TensorArray.
Examples: Examples:
>>> import mindspore >>> import mindspore
>>> import mindspore.ops as ops >>> import mindspore.ops as ops
@ -344,17 +362,20 @@ class TensorArrayGather(PrimitiveWithInfer):
""" """
@prim_attr_register @prim_attr_register
def __init__(self, dtype, element_shape): def __init__(self, dtype, element_shape):
"""初始化TensorArrayGather类设置参数和属性."""
self.init_prim_io_names(inputs=['handle', 'indices'], outputs=['value']) self.init_prim_io_names(inputs=['handle', 'indices'], outputs=['value'])
self.add_prim_attr("side_effect_mem", True) self.add_prim_attr("side_effect_mem", True)
self.dtype = dtype self.dtype = dtype
self.element_shape = element_shape self.element_shape = element_shape
def infer_shape(self, handle, indices): def infer_shape(self, handle, indices):
"""推断输出形状."""
if len(indices) != 1: if len(indices) != 1:
return ValueError("indices dimension should be equal to 1") return ValueError("indices dimension should be equal to 1")
return [indices[0]] + list(self.element_shape) return [indices[0]] + list(self.element_shape)
def infer_dtype(self, handle, indices): def infer_dtype(self, handle, indices):
"""推断输出数据类型."""
validator.check_type_name("handle", handle, (ms.int64), self.name) validator.check_type_name("handle", handle, (ms.int64), self.name)
validator.check_type_name("indices", indices, (ms.int32), self.name) validator.check_type_name("indices", indices, (ms.int32), self.name)
return self.dtype return self.dtype

@ -366,31 +366,46 @@ class ModelCheckpoint(Callback):
""" """
def __init__(self, prefix='CKP', directory=None, config=None): def __init__(self, prefix='CKP', directory=None, config=None):
# 初始化函数,设置前缀、目录、配置等参数
super(ModelCheckpoint, self).__init__() super(ModelCheckpoint, self).__init__()
# 调用父类的初始化函数
self._latest_ckpt_file_name = "" self._latest_ckpt_file_name = ""
# 初始化最新检查点文件名为空字符串
self._init_time = time.time() self._init_time = time.time()
# 初始化初始化时间为当前时间
self._last_time = time.time() self._last_time = time.time()
# 初始化最后时间时间为当前时间
self._last_time_for_keep = time.time() self._last_time_for_keep = time.time()
# 初始化最后保存时间为当前时间
self._last_triggered_step = 0 self._last_triggered_step = 0
# 初始化最后触发的步数为0
# 检查前缀是否为字符串且不包含'/'
if not isinstance(prefix, str) or prefix.find('/') >= 0: if not isinstance(prefix, str) or prefix.find('/') >= 0:
raise ValueError("For 'ModelCheckpoint', the argument 'prefix' " raise ValueError("For 'ModelCheckpoint', the argument 'prefix' "
"for checkpoint file name is invalid, it must be " "for checkpoint file name is invalid, it must be "
"string and does not contain '/', but got {}.".format(prefix)) "string and does not contain '/', but got {}.".format(prefix))
self._prefix = prefix self._prefix = prefix
# 设置前缀
self._exception_prefix = prefix self._exception_prefix = prefix
# 设置异常前缀
# 如果目录不为空,则创建目录
if directory is not None: if directory is not None:
self._directory = _make_directory(directory) self._directory = _make_directory(directory)
else: else:
self._directory = _cur_dir self._directory = _cur_dir
# 否则,使用当前目录
# 如果启用了恢复上下文,则设置检查点路径
if _get_recovery_context("enable_recovery"): if _get_recovery_context("enable_recovery"):
_set_recovery_context(ckpt_path=self._directory) _set_recovery_context(ckpt_path=self._directory)
# 如果config为None则使用默认的CheckpointConfig
if config is None: if config is None:
self._config = CheckpointConfig() self._config = CheckpointConfig()
else: else:
# 如果config不是CheckpointConfig类型则抛出TypeError异常
if not isinstance(config, CheckpointConfig): if not isinstance(config, CheckpointConfig):
raise TypeError("For 'ModelCheckpoint', the type of argument 'config' should be " raise TypeError("For 'ModelCheckpoint', the type of argument 'config' should be "
"'CheckpointConfig', " "'CheckpointConfig', "
@ -398,11 +413,17 @@ class ModelCheckpoint(Callback):
self._config = config self._config = config
# get existing checkpoint files # get existing checkpoint files
# 创建CheckpointManager对象
self._manager = CheckpointManager() self._manager = CheckpointManager()
# 如果存在相同名称的文件,则更改文件名
self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix) self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
# 获取配置中的append_dict参数如果没有则设置为空字典
self._append_dict = self._config.append_dict or {} self._append_dict = self._config.append_dict or {}
# 获取append_dict中的epoch_num参数如果没有则设置为0
self._append_epoch_num = self._append_dict["epoch_num"] if "epoch_num" in self._append_dict else 0 self._append_epoch_num = self._append_dict["epoch_num"] if "epoch_num" in self._append_dict else 0
# 获取append_dict中的step_num参数如果没有则设置为0
self._append_step_num = self._append_dict["step_num"] if "step_num" in self._append_dict else 0 self._append_step_num = self._append_dict["step_num"] if "step_num" in self._append_dict else 0
# 标记是否已经保存了图
self._graph_saved = False self._graph_saved = False
self._need_flush_from_cache = True self._need_flush_from_cache = True
@ -413,6 +434,7 @@ class ModelCheckpoint(Callback):
Args: Args:
run_context (RunContext): Context of the train running. run_context (RunContext): Context of the train running.
""" """
# If the role is PServer, add the role name and rank to the prefix
if _is_role_pserver(): if _is_role_pserver():
self._prefix = "PServer_" + str(_get_ps_mode_rank()) + "_" + self._prefix self._prefix = "PServer_" + str(_get_ps_mode_rank()) + "_" + self._prefix
cb_params = run_context.original_args() cb_params = run_context.original_args()
@ -423,18 +445,23 @@ class ModelCheckpoint(Callback):
self._last_triggered_step = cb_params.last_save_ckpt_step self._last_triggered_step = cb_params.last_save_ckpt_step
cb_params.last_save_ckpt_step = None cb_params.last_save_ckpt_step = None
# Create the directory if it doesn't exist
_make_directory(self._directory) _make_directory(self._directory)
# save graph (only once) # save graph (only once)
if not self._graph_saved: if not self._graph_saved:
graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta') graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
# If the graph file already exists and the mode is GRAPH_MODE, remove it
if os.path.isfile(graph_file_name) and context.get_context("mode") == context.GRAPH_MODE: if os.path.isfile(graph_file_name) and context.get_context("mode") == context.GRAPH_MODE:
os.remove(graph_file_name) os.remove(graph_file_name)
# Save the graph
_save_graph(cb_params.train_network, graph_file_name) _save_graph(cb_params.train_network, graph_file_name)
self._graph_saved = True self._graph_saved = True
# Wait for any asynchronous checkpoint saving threads to finish
thread_list = threading.enumerate() thread_list = threading.enumerate()
for thread in thread_list: for thread in thread_list:
if thread.getName() == "asyn_save_ckpt": if thread.getName() == "asyn_save_ckpt":
thread.join() thread.join()
# Save the checkpoint
self._save_ckpt(cb_params) self._save_ckpt(cb_params)
def end(self, run_context): def end(self, run_context):
@ -444,44 +471,63 @@ class ModelCheckpoint(Callback):
Args: Args:
run_context (RunContext): Context of the train running. run_context (RunContext): Context of the train running.
""" """
# 获取训练的参数
cb_params = run_context.original_args() cb_params = run_context.original_args()
# 设置保存最后一个checkpoint的标志为True
_to_save_last_ckpt = True _to_save_last_ckpt = True
# 保存最后一个checkpoint
self._save_ckpt(cb_params, _to_save_last_ckpt) self._save_ckpt(cb_params, _to_save_last_ckpt)
# 获取当前线程列表
thread_list = threading.enumerate() thread_list = threading.enumerate()
# 遍历线程列表
for thread in thread_list: for thread in thread_list:
# 如果线程名为"asyn_save_ckpt",则等待该线程结束
if thread.getName() == "asyn_save_ckpt": if thread.getName() == "asyn_save_ckpt":
thread.join() thread.join()
# 销毁所有gather cell
destroy_allgather_cell() destroy_allgather_cell()
def _check_save_ckpt(self, cb_params, force_to_save): def _check_save_ckpt(self, cb_params, force_to_save):
"""Check whether save checkpoint files or not.""" """Check whether save checkpoint files or not."""
# 如果配置了保存检查点步数且步数大于0
if self._config.save_checkpoint_steps and self._config.save_checkpoint_steps > 0: if self._config.save_checkpoint_steps and self._config.save_checkpoint_steps > 0:
# 如果当前步数大于等于上次触发保存检查点步数加上保存检查点步数,或者强制保存检查点
if cb_params.cur_step_num >= self._last_triggered_step + self._config.save_checkpoint_steps \ if cb_params.cur_step_num >= self._last_triggered_step + self._config.save_checkpoint_steps \
or force_to_save is True: or force_to_save is True:
return True return True
# 如果配置了保存检查点秒数且秒数大于0
elif self._config.save_checkpoint_seconds and self._config.save_checkpoint_seconds > 0: elif self._config.save_checkpoint_seconds and self._config.save_checkpoint_seconds > 0:
# 获取当前时间
self._cur_time = time.time() self._cur_time = time.time()
# 如果当前时间减去上次时间大于保存检查点秒数,或者强制保存检查点
if (self._cur_time - self._last_time) > self._config.save_checkpoint_seconds or force_to_save: if (self._cur_time - self._last_time) > self._config.save_checkpoint_seconds or force_to_save:
# 更新上次时间
self._last_time = self._cur_time self._last_time = self._cur_time
return True return True
# 返回False
return False return False
def _save_ckpt(self, cb_params, force_to_save=False): def _save_ckpt(self, cb_params, force_to_save=False):
"""Save checkpoint files.""" """Save checkpoint files."""
# 如果当前步骤数等于最后触发的步骤数,则返回
if cb_params.cur_step_num == self._last_triggered_step: if cb_params.cur_step_num == self._last_triggered_step:
return return
# if param is cache enable, flush data from cache to host before save_ckpt # if param is cache enable, flush data from cache to host before save_ckpt
# 如果需要从缓存中刷新数据则调用_flush_from_cache方法
if self._need_flush_from_cache: if self._need_flush_from_cache:
self._flush_from_cache(cb_params) self._flush_from_cache(cb_params)
# 检查是否需要保存检查点如果force_to_save为True则强制保存
save_ckpt = self._check_save_ckpt(cb_params, force_to_save) save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
# 计算当前步数在epoch中的位置
step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1) step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
# 如果需要保存检查点,则创建当前检查点的文件名
if save_ckpt: if save_ckpt:
cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \ cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
+ str(step_num_in_epoch) + ".ckpt" + str(step_num_in_epoch) + ".ckpt"
@ -489,43 +535,68 @@ class ModelCheckpoint(Callback):
self._manager.update_ckpoint_filelist(self._directory, self._prefix) self._manager.update_ckpoint_filelist(self._directory, self._prefix)
# keep checkpoint files number equal max number. # keep checkpoint files number equal max number.
if self._config.keep_checkpoint_max and 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num: if self._config.keep_checkpoint_max and 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num:
# 如果keep_checkpoint_max配置存在且大于0且小于等于当前checkpoint文件数量则删除最旧的checkpoint文件
self._manager.remove_oldest_ckpoint_file() self._manager.remove_oldest_ckpoint_file()
elif self._config.keep_checkpoint_per_n_minutes and self._config.keep_checkpoint_per_n_minutes > 0: elif self._config.keep_checkpoint_per_n_minutes and self._config.keep_checkpoint_per_n_minutes > 0:
# 如果keep_checkpoint_per_n_minutes配置存在且大于0则记录当前时间
self._cur_time_for_keep = time.time() self._cur_time_for_keep = time.time()
# 如果当前时间与上次记录的时间之差小于keep_checkpoint_per_n_minutes配置的分钟数乘以60则保留每个分钟的一个checkpoint文件
if (self._cur_time_for_keep - self._last_time_for_keep) \ if (self._cur_time_for_keep - self._last_time_for_keep) \
< self._config.keep_checkpoint_per_n_minutes * 60: < self._config.keep_checkpoint_per_n_minutes * 60:
self._manager.keep_one_ckpoint_per_minutes(self._config.keep_checkpoint_per_n_minutes, self._manager.keep_one_ckpoint_per_minutes(self._config.keep_checkpoint_per_n_minutes,
self._cur_time_for_keep) self._cur_time_for_keep)
# generate the new checkpoint file and rename it. # generate the new checkpoint file and rename it.
# 定义全局变量_save_dir并将其赋值为self._directory
global _save_dir global _save_dir
_save_dir = self._directory _save_dir = self._directory
# 获取当前checkpoint文件的路径
cur_file = os.path.join(self._directory, cur_ckpoint_file) cur_file = os.path.join(self._directory, cur_ckpoint_file)
# 记录当前时间
self._last_time_for_keep = time.time() self._last_time_for_keep = time.time()
# 记录当前触发步数
self._last_triggered_step = cb_params.cur_step_num self._last_triggered_step = cb_params.cur_step_num
# 如果启用了GEGraph Execution
if context.get_context("enable_ge"): if context.get_context("enable_ge"):
# 设置当前网络
set_cur_net(cb_params.train_network) set_cur_net(cb_params.train_network)
# 执行checkpoint图
cb_params.train_network.exec_checkpoint_graph() cb_params.train_network.exec_checkpoint_graph()
# 如果_append_dict中包含"epoch_num"
if "epoch_num" in self._append_dict: if "epoch_num" in self._append_dict:
# 将_append_epoch_num加上当前epoch数赋值给"epoch_num"
self._append_dict["epoch_num"] = self._append_epoch_num + cb_params.cur_epoch_num self._append_dict["epoch_num"] = self._append_epoch_num + cb_params.cur_epoch_num
# 如果_append_dict中包含"step_num"
if "step_num" in self._append_dict: if "step_num" in self._append_dict:
# 将_append_step_num加上当前step数赋值给"step_num"
self._append_dict["step_num"] = self._append_step_num + cb_params.cur_step_num self._append_dict["step_num"] = self._append_step_num + cb_params.cur_step_num
# 获取保存的网络如果self._config.saved_network不为None则使用self._config.saved_network否则使用cb_params.train_network
network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network
# 保存checkpoint
save_checkpoint(network, cur_file, self._config.integrated_save, self._config.async_save, save_checkpoint(network, cur_file, self._config.integrated_save, self._config.async_save,
self._append_dict, self._config.enc_key, self._config.enc_mode) self._append_dict, self._config.enc_key, self._config.enc_mode)
# 记录最新的checkpoint文件名
self._latest_ckpt_file_name = cur_file self._latest_ckpt_file_name = cur_file
def _flush_from_cache(self, cb_params): def _flush_from_cache(self, cb_params):
"""Flush cache data to host if tensor is cache enable.""" """Flush cache data to host if tensor is cache enable."""
# 初始化has_cache_params为False
has_cache_params = False has_cache_params = False
# 获取训练网络中的参数
params = cb_params.train_network.get_parameters() params = cb_params.train_network.get_parameters()
# 遍历参数
for param in params: for param in params:
# 如果参数的cache_enable为True
if param.cache_enable: if param.cache_enable:
# 设置has_cache_params为True
has_cache_params = True has_cache_params = True
# 将参数的Tensor数据从缓存中刷新到主机
Tensor(param).flush_from_cache() Tensor(param).flush_from_cache()
# 如果没有参数的cache_enable为True
if not has_cache_params: if not has_cache_params:
# 设置_need_flush_from_cache为False
self._need_flush_from_cache = False self._need_flush_from_cache = False
@property @property
@ -535,63 +606,88 @@ class ModelCheckpoint(Callback):
class CheckpointManager: class CheckpointManager:
"""Manage checkpoint files according to train_config of checkpoint.""" """管理检查点文件,根据训练配置进行管理。"""
def __init__(self): def __init__(self):
"""初始化检查点管理器,创建空的检查点文件列表。"""
self._ckpoint_filelist = [] self._ckpoint_filelist = []
@property @property
def ckpoint_filelist(self): def ckpoint_filelist(self):
"""Get all the related checkpoint files managed here.""" """获取当前管理的所有检查点文件列表。"""
return self._ckpoint_filelist return self._ckpoint_filelist
@property @property
def ckpoint_num(self): def ckpoint_num(self):
"""Get the number of the related checkpoint files managed here.""" """获取当前管理的检查点文件数量。"""
return len(self._ckpoint_filelist) return len(self._ckpoint_filelist)
def update_ckpoint_filelist(self, directory, prefix): def update_ckpoint_filelist(self, directory, prefix):
"""Update the checkpoint file list.""" """更新检查点文件列表,根据目录和前缀筛选符合条件的检查点文件。"""
# 初始化一个空列表用于存储ckpt文件
self._ckpoint_filelist = [] self._ckpoint_filelist = []
# 获取指定目录下的所有文件
files = os.listdir(directory) files = os.listdir(directory)
# 遍历所有文件
for filename in files: for filename in files:
# 判断文件是否以指定前缀开头,并且以.ckpt结尾
if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix + "-"): if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix + "-"):
# 获取文件名中间部分
mid_name = filename[len(prefix):-5] mid_name = filename[len(prefix):-5]
# 判断中间部分是否包含字母
flag = not (True in [char.isalpha() for char in mid_name]) flag = not (True in [char.isalpha() for char in mid_name])
# 如果不包含字母,则将文件路径添加到列表中
if flag: if flag:
self._ckpoint_filelist.append(os.path.join(directory, filename)) self._ckpoint_filelist.append(os.path.join(directory, filename))
def remove_ckpoint_file(self, file_name): def remove_ckpoint_file(self, file_name):
"""Remove the specified checkpoint file from this checkpoint manager and also from the directory.""" """从检查点管理器中移除指定的检查点文件,并从目录中删除该文件。"""
try: try:
# 修改文件权限为可写
os.chmod(file_name, stat.S_IWRITE) os.chmod(file_name, stat.S_IWRITE)
# 删除文件
os.remove(file_name) os.remove(file_name)
# 从ckpoint文件列表中移除该文件
self._ckpoint_filelist.remove(file_name) self._ckpoint_filelist.remove(file_name)
except OSError: except OSError:
# 捕获OSError异常并记录警告日志
logger.warning("OSError, failed to remove the older ckpt file %s.", file_name) logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
except ValueError: except ValueError:
# 捕获ValueError异常并记录警告日志
logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name) logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
def remove_oldest_ckpoint_file(self): def remove_oldest_ckpoint_file(self):
"""Remove the oldest checkpoint file from this checkpoint manager and also from the directory.""" """移除检查点管理器中最早的检查点文件,并从目录中删除该文件。"""
# 获取所有checkpoint文件并按修改时间排序
ckpoint_files = sorted(self._ckpoint_filelist, key=os.path.getmtime) ckpoint_files = sorted(self._ckpoint_filelist, key=os.path.getmtime)
# 删除最早修改的checkpoint文件
self.remove_ckpoint_file(ckpoint_files[0]) self.remove_ckpoint_file(ckpoint_files[0])
def keep_one_ckpoint_per_minutes(self, minutes, cur_time): def keep_one_ckpoint_per_minutes(self, minutes, cur_time):
"""Only keep the latest one ckpt file per minutes, remove other files generated in [last_time, cur_time].""" """保留每分钟生成的最新检查点文件,移除在指定时间范围内生成的其他文件。"""
# 定义一个空列表,用于存储需要删除的文件
del_list = [] del_list = []
# 定义一个空字符串,用于存储最旧的文件名
oldest_file = '' oldest_file = ''
# 定义一个变量,用于存储当前时间
oldest_time = cur_time oldest_time = cur_time
# 遍历_ckpoint_filelist中的文件
for ck_file in self._ckpoint_filelist: for ck_file in self._ckpoint_filelist:
# 获取文件的修改时间
modify_time = os.path.getmtime(ck_file) modify_time = os.path.getmtime(ck_file)
# 如果当前时间减去文件的修改时间小于60*minutes则将文件添加到del_list中
if cur_time - modify_time < 60 * minutes: if cur_time - modify_time < 60 * minutes:
del_list.append(ck_file) del_list.append(ck_file)
# 如果文件的修改时间小于oldest_time则更新oldest_time和oldest_file
if modify_time < oldest_time: if modify_time < oldest_time:
oldest_time = modify_time oldest_time = modify_time
oldest_file = ck_file oldest_file = ck_file
# 遍历del_list中的文件
for mv_file in del_list: for mv_file in del_list:
# 如果文件是最旧的文件,则跳过
if mv_file == oldest_file: if mv_file == oldest_file:
continue continue
self.remove_ckpoint_file(mv_file) # 调用remove_ckpoint_file方法删除文件
self.remove_ckpoint_file(mv_file)
Loading…
Cancel
Save