|
|
|
@ -17,65 +17,71 @@ from ..._checkparam import Validator as validator
|
|
|
|
|
from ...common import dtype as mstype
|
|
|
|
|
from ..primitive import prim_attr_register, PrimitiveWithCheck
|
|
|
|
|
from .. import signature as sig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class UpdateCache(PrimitiveWithCheck):
    """
    Update the value of input_x, similar to ScatterNdUpdate.

    The difference is that UpdateCache will not update when
    indices < 0 or indices >= max_num.

    Inputs:
        - **input_x** (Parameter) - Parameter which is going to be updated.
        - **indices** (Tensor) - Update indices of input_x.
        - **updates** (Tensor) - The update values.

    Outputs:
        - **out** (Tensor) - Returns a [1] Tensor, which is not useful.
    """
    # input_x is written in place; input_x/updates share dtype T,
    # indices/max_num share dtype T1.
    __mindspore_signature__ = (
        sig.make_sig('input_x', sig.sig_rw.RW_WRITE,
                     dtype=sig.sig_dtype.T),
        sig.make_sig('indices', dtype=sig.sig_dtype.T1),
        sig.make_sig('updates', dtype=sig.sig_dtype.T),
        sig.make_sig('max_num', dtype=sig.sig_dtype.T1)
    )

    @prim_attr_register
    def __init__(self):
        """Initialize UpdateCache: register the primitive's I/O names."""
        self.init_prim_io_names(inputs=['input_x', 'indices', 'update', 'max_num'],
                                outputs=['out'])

    def check_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape):
        """The output is always a single-element tensor of shape [1]."""
        return [1]

    def check_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype):
        """Require integer indices; the output dtype follows input_x."""
        validator.check_tensor_dtype_valid(
            "indices", indices_dtype, mstype.int_type, self.name)
        return input_x_dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SubAndFilter(PrimitiveWithCheck):
|
|
|
|
|
"""
|
|
|
|
|
Dynamic kernel, sub an offset and
|
|
|
|
|
return the elements which in range [0, max_num).
|
|
|
|
|
|
|
|
|
|
动态内核,减去一个偏移量并返回在范围 [0, max_num) 内的元素。
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **input_x** (Tensor) - Input tensor.
|
|
|
|
|
- **max_num** (Int) - The max value of element that after sub `offset`.
|
|
|
|
|
- **offset** (int) - Specifies the offset value of this `input_x`.
|
|
|
|
|
|
|
|
|
|
- **input_x** (Tensor) - 输入张量。
|
|
|
|
|
- **max_num** (Int) - 减去 `offset` 后元素的最大值。
|
|
|
|
|
- **offset** (int) - 指定此 `input_x` 的偏移值。
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
tuple(Tensor), tuple of 2 tensors, filter_res and filter_idx.
|
|
|
|
|
- **filter_res** (Tensor) - The result that `input_x` minus `offset`,
|
|
|
|
|
and return which in the range [0, max_num).
|
|
|
|
|
- **filter_idx** (Tensor) - A tensor containing indices of elements in the input
|
|
|
|
|
coressponding to the output tensor.
|
|
|
|
|
|
|
|
|
|
tuple(Tensor), 由 2 个张量组成的元组,filter_res 和 filter_idx。
|
|
|
|
|
- **filter_res** (Tensor) - `input_x` 减去 `offset` 的结果,
|
|
|
|
|
并返回在范围 [0, max_num) 内的值。
|
|
|
|
|
- **filter_idx** (Tensor) - 一个张量,包含与输出张量对应的输入元素的索引。
|
|
|
|
|
|
|
|
|
|
Supported Platforms:
|
|
|
|
|
`CPU`
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> x = Tensor(np.array([1, 3, 5, 8, 9, 16]), mindspore.int32)
|
|
|
|
|
>>> max_num = 10
|
|
|
|
@ -87,35 +93,38 @@ class SubAndFilter(PrimitiveWithCheck):
|
|
|
|
|
"""
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self):
|
|
|
|
|
"""init SubAndFilter"""
|
|
|
|
|
|
|
|
|
|
"""初始化 SubAndFilter"""
|
|
|
|
|
|
|
|
|
|
# 初始化输入和输出名称
|
|
|
|
|
self.init_prim_io_names(inputs=['input_x', 'max_num', 'offset'],
|
|
|
|
|
outputs=['sub_res', 'sub_idx'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_shape(self, input_x_shape, max_num_shape, offset_shape):
|
|
|
|
|
# 检查输入形状
|
|
|
|
|
return ((-1,), (-1,))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_dtype(self, input_x_dtype, max_num_dtype, offset_dtype):
|
|
|
|
|
# 检查输入数据类型
|
|
|
|
|
validator.check_tensor_dtype_valid(
|
|
|
|
|
"input_x", input_x_dtype, mstype.int_type, self.name)
|
|
|
|
|
return input_x_dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MapUniform(PrimitiveWithCheck):
|
|
|
|
|
"""
|
|
|
|
|
Map a tensor by using fomula : value = key % `group_num` * `per_group_size` + key // `group_num`.
|
|
|
|
|
|
|
|
|
|
通过公式映射一个张量:value = key % `group_num` * `per_group_size` + key // `group_num`。
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **input** (Tensor) - Input Tensor.
|
|
|
|
|
- **per_group_size** (int) - The size of each group.
|
|
|
|
|
- **group_num** (int) - The number of group.
|
|
|
|
|
|
|
|
|
|
- **input** (Tensor) - 输入张量。
|
|
|
|
|
- **per_group_size** (int) - 每个组的大小。
|
|
|
|
|
- **group_num** (int) - 组的数量。
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
Tensor, has the same dtype and shape as the `input`.
|
|
|
|
|
|
|
|
|
|
Tensor,具有与 `input` 相同的 dtype 和形状。
|
|
|
|
|
|
|
|
|
|
Supported Platforms:
|
|
|
|
|
`CPU`
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> input_x = Tensor(np.array([0, 1, 2, 3, 4, 5, 6, 7]))
|
|
|
|
|
>>> per_group_size = 4
|
|
|
|
@ -125,33 +134,34 @@ class MapUniform(PrimitiveWithCheck):
|
|
|
|
|
>>> print(output)
|
|
|
|
|
[0, 4, 1, 5, 2, 6, 3, 7]
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self):
|
|
|
|
|
"""init MapUniform"""
|
|
|
|
|
"""初始化 MapUniform"""
|
|
|
|
|
self.init_prim_io_names(inputs=['input', 'per_group_size', 'group_num'],
|
|
|
|
|
outputs=['output'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_dtype(self, input_dtype, per_group_size_dtype, group_num_dtype):
|
|
|
|
|
"""检查输入数据类型"""
|
|
|
|
|
validator.check_tensor_dtype_valid(
|
|
|
|
|
"input", input_dtype, mstype.int_type, self.name)
|
|
|
|
|
validator.check_value_type(
|
|
|
|
|
'per_group_size', per_group_size_dtype, [mstype.Int], self.name)
|
|
|
|
|
validator.check_value_type(
|
|
|
|
|
'group_num', group_num_dtype, [mstype.Int], self.name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CacheSwapTable(PrimitiveWithCheck):
|
|
|
|
|
"""
|
|
|
|
|
Delete a hashmap entry,and insert a new key to hashmap, return the key and value of delete entry.
|
|
|
|
|
|
|
|
|
|
删除一个哈希映射条目,并插入一个新键到哈希映射中,返回删除条目的键和值。
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **cache_table** (Parameter) - The cache table which is on device.
|
|
|
|
|
- **swap_cache_idx** (Tensor) - The index of table which need to swap. -1 is skipped.
|
|
|
|
|
- **miss_value** (int) - The values which arg going to swap into cache table.
|
|
|
|
|
|
|
|
|
|
- **cache_table** (Parameter) - 在设备上的缓存表。
|
|
|
|
|
- **swap_cache_idx** (Tensor) - 需要交换的表索引,-1 被跳过。
|
|
|
|
|
- **miss_value** (int) - 将要交换到缓存表的值。
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
- **old_value** (Tensor) - The values which are swapped out.
|
|
|
|
|
- **old_value** (Tensor) - 被交换出去的值。
|
|
|
|
|
"""
|
|
|
|
|
__mindspore_signature__ = (
|
|
|
|
|
sig.make_sig('cache_table', sig.sig_rw.RW_WRITE,
|
|
|
|
@ -159,31 +169,35 @@ class CacheSwapTable(PrimitiveWithCheck):
|
|
|
|
|
sig.make_sig('swap_cache_idx', dtype=sig.sig_dtype.T1),
|
|
|
|
|
sig.make_sig('miss_value', dtype=sig.sig_dtype.T)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self):
|
|
|
|
|
"""init CacheSwapTable"""
|
|
|
|
|
|
|
|
|
|
"""初始化 CacheSwapTable"""
|
|
|
|
|
|
|
|
|
|
self.init_prim_io_names(inputs=['cache_table', 'swap_cache_idx', 'miss_value'],
|
|
|
|
|
outputs=['old_value'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_shape(self, cache_table_shape, swap_cache_idx_shape, miss_value_shape):
|
|
|
|
|
# 检查cache_table_shape的长度是否为2,如果不是,则抛出ValueError异常
|
|
|
|
|
if len(cache_table_shape) != 2:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"cache table shape must be 2, but got %d" % len(cache_table_shape))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 返回miss_value_shape
|
|
|
|
|
return miss_value_shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype):
|
|
|
|
|
# 检查swap_cache_idx_dtype是否为mstype.int_type,如果不是,则抛出ValueError异常
|
|
|
|
|
validator.check_tensor_dtype_valid(
|
|
|
|
|
"swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name)
|
|
|
|
|
# 返回miss_value_dtype
|
|
|
|
|
return miss_value_dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MapCacheIdx(PrimitiveWithCheck):
|
|
|
|
|
"""
|
|
|
|
|
MapCacheIdx merge SearchCacheIdx, CacheSwapHashmap, UpdateCache together.
|
|
|
|
|
When input an indices tensor, it will output the cache indices which search in hashmap.
|
|
|
|
|
MapCacheIdx 将 SearchCacheIdx、CacheSwapHashmap 和 UpdateCache 合并在一起。
|
|
|
|
|
当输入一个索引张量时,它将输出在哈希映射中搜索的缓存索引。
|
|
|
|
|
"""
|
|
|
|
|
__mindspore_signature__ = (
|
|
|
|
|
sig.make_sig('hashmap', sig.sig_rw.RW_WRITE,
|
|
|
|
@ -193,56 +207,69 @@ class MapCacheIdx(PrimitiveWithCheck):
|
|
|
|
|
sig.make_sig('emb_max_num', dtype=sig.sig_dtype.T),
|
|
|
|
|
sig.make_sig('cache_max_num', dtype=sig.sig_dtype.T)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self):
|
|
|
|
|
"""init MapCacheIdx"""
|
|
|
|
|
|
|
|
|
|
"""初始化 MapCacheIdx"""
|
|
|
|
|
|
|
|
|
|
self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'offset'],
|
|
|
|
|
outputs=['cache_idx', 'old_emb_idx', 'miss_emb_idx', 'swap_cache_idx'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    def __check__(self, hashmap, indices, step, emb_max_num, offset):
        """Validate inputs and build the abstract output dict (shape/dtype/value).

        All four outputs take the hashmap dtype; their shapes are dynamic,
        anchored on the shape of `indices`.
        """
        # The hashmap must be a 2-D table.
        hashmap_shape = hashmap['shape']
        if len(hashmap_shape) != 2:
            raise ValueError("The dimension of 'hashmap' in SearchCacheIdx must be 2, "
                             "but got %d." % len(hashmap_shape))
        # First output follows the indices shape; the other three are dynamic (-1).
        out_shape = (indices['shape'], -1, -1, -1)

        # hashmap and indices must share the same integer dtype.
        hashmap_dtype = hashmap['dtype']
        indices_dtype = indices['dtype']
        args = {"hashmap": hashmap_dtype, "indices": indices_dtype}
        validator.check_tensors_dtypes_same_and_valid(
            args, mstype.int_type, self.name)
        # Every output takes the hashmap dtype.
        out_dtype = (hashmap_dtype, hashmap_dtype,
                     hashmap_dtype, hashmap_dtype)

        # Abstract value of the output tuple.
        out = {'shape': out_shape,
               'dtype': out_dtype,
               'value': None}
        # Propagate max_shape from indices when available; otherwise fall back
        # to the static indices shape.
        if 'max_shape' in indices:
            out['max_shape'] = (indices['max_shape'], indices['max_shape'],
                                indices['max_shape'], indices['max_shape'])
        else:
            out['max_shape'] = (indices['shape'], indices['shape'],
                                indices['shape'], indices['shape'])
        # min_shape: only the first output inherits it; the rest may be empty.
        if 'min_shape' in indices:
            out['min_shape'] = (indices['min_shape'], 0, 0, 0)
        else:
            out['min_shape'] = (0, 0, 0, 0)
        return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DynamicAssign(PrimitiveWithCheck):
|
|
|
|
|
"""
|
|
|
|
|
Assigns `Parameter` with a value, the `value` can have a dynamic shape.
|
|
|
|
|
|
|
|
|
|
将 `Parameter` 与值分配,`value` 可以具有动态形状。
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **variable** (Parameter) - The `Parameter`.
|
|
|
|
|
- **value** (Tensor) - The value to be assigned.
|
|
|
|
|
|
|
|
|
|
- **variable** (Parameter) - `Parameter`。
|
|
|
|
|
- **value** (Tensor) - 要分配的值。
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
Tensor, has the same type as original `variable`.
|
|
|
|
|
|
|
|
|
|
Tensor,具有与原始 `variable` 相同的类型。
|
|
|
|
|
|
|
|
|
|
Supported Platforms:
|
|
|
|
|
`CPU`
|
|
|
|
|
"""
|
|
|
|
@ -250,41 +277,42 @@ class DynamicAssign(PrimitiveWithCheck):
|
|
|
|
|
sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
|
|
|
|
sig.make_sig('value', dtype=sig.sig_dtype.T)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    @prim_attr_register
    def __init__(self):
        """Initialize DynamicAssign: register the primitive's I/O names."""
        self.init_prim_io_names(inputs=['ref', 'value'], outputs=['output'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_dtype(self, variable, value):
|
|
|
|
|
# 检查变量是否为mstype.type_refkey
|
|
|
|
|
if variable != mstype.type_refkey:
|
|
|
|
|
# 检查变量是否为mstype.number_type类型
|
|
|
|
|
validator.check_tensor_dtype_valid(
|
|
|
|
|
"variable", variable, mstype.number_type, self.name)
|
|
|
|
|
# 检查value是否为mstype.number_type类型
|
|
|
|
|
validator.check_scalar_or_tensor_types_same(
|
|
|
|
|
{"value": value}, mstype.number_type, self.name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PadAndShift(PrimitiveWithCheck):
|
|
|
|
|
"""
|
|
|
|
|
Pad a tensor with -1, and shift with a length.
|
|
|
|
|
|
|
|
|
|
用 -1 填充张量,并按长度进行移位。
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **input_x** (Tensor) - The input Tensor, which will be copied
|
|
|
|
|
to `output`.
|
|
|
|
|
- **cum_sum_arr** (Tensor) - The last value of cum_sum_arr is
|
|
|
|
|
the pad length of output tensor, cum_sum_arr[shift_idx] is
|
|
|
|
|
the start to shift, and cum_sum_arr[shift_idx+1] is the end.
|
|
|
|
|
- **shift_idx** (Int) - The idx of cum_sum_arr.
|
|
|
|
|
if use python, PadAndShift is:
|
|
|
|
|
- **input_x** (Tensor) - 输入张量,将被复制到 `output`。
|
|
|
|
|
- **cum_sum_arr** (Tensor) - cum_sum_arr 的最后一个值是输出张量的填充长度,
|
|
|
|
|
cum_sum_arr[shift_idx] 是开始移位,cum_sum_arr[shift_idx+1] 是结束。
|
|
|
|
|
- **shift_idx** (Int) - cum_sum_arr 的索引。
|
|
|
|
|
如果使用 Python,PadAndShift 为:
|
|
|
|
|
output = [-1] * cum_sum_arr[-1]
|
|
|
|
|
start = cum_sum_arr[shift_idx]
|
|
|
|
|
end = cum_sum_arr[shift_idx + 1]
|
|
|
|
|
output[start:end] = input_x[:(end-start)]
|
|
|
|
|
Outputs:
|
|
|
|
|
Tensor, has the same type as original `variable`.
|
|
|
|
|
|
|
|
|
|
Tensor,具有与原始 `variable` 相同的类型。
|
|
|
|
|
|
|
|
|
|
Supported Platforms:
|
|
|
|
|
`CPU`
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> input_x = Tensor(np.array([9, 13, -1, -1, -1, -1, -1, -1]), mstype.int32)
|
|
|
|
|
>>> cum_sum_arr = Tensor(np.array([0, 3, 5]), mstype.int32)
|
|
|
|
@ -296,11 +324,14 @@ class PadAndShift(PrimitiveWithCheck):
|
|
|
|
|
"""
|
|
|
|
|
    @prim_attr_register
    def __init__(self):
        """Initialize PadAndShift."""
        # Register the primitive's input and output names.
        self.init_prim_io_names(
            inputs=['input_x', 'cum_sum_arr', 'shift_idx'], outputs=['output'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_shape(self, input_x_shape, cum_sum_arr_shape, shift_idx_shape):
|
|
|
|
|
# 检查输入形状
|
|
|
|
|
return input_x_shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_dtype(self, input_x_dtype, cum_sum_arr_dtype, shift_idx_dtype):
|
|
|
|
|
return input_x_dtype
|
|
|
|
|
# 检查输入数据类型
|
|
|
|
|
return input_x_dtype
|