|
|
|
@ -22,31 +22,100 @@ from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ScalarAdd(x, y):
|
|
|
|
|
"""
|
|
|
|
|
实现标量加法运算。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x (float): 第一个加数。
|
|
|
|
|
y (float): 第二个加数。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
float: x 和 y 的和。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Implement `scalar_add`."""
|
|
|
|
|
return x + y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ScalarMul(x, y):
|
|
|
|
|
"""
|
|
|
|
|
标量乘法函数。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x (float): 第一个标量。
|
|
|
|
|
y (float): 第二个标量。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
float: 两个标量的乘积。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Implement `scalar_mul`."""
|
|
|
|
|
return x * y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ScalarMod(x, y):
|
|
|
|
|
"""
|
|
|
|
|
对两个数进行模运算。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x (int or float): 被模数。
|
|
|
|
|
y (int or float): 模数。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
int or float: x 对 y 取模的结果。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Implement `scalar_mul`."""
|
|
|
|
|
return x % y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ScalarSub(x, y):
|
|
|
|
|
"""
|
|
|
|
|
实现标量减法运算。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x (float): 第一个标量值。
|
|
|
|
|
y (float): 第二个标量值。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
float: x 和 y 的差值。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Implement `scalar_sub`."""
|
|
|
|
|
return x - y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ScalarUsub(x):
|
|
|
|
|
"""
|
|
|
|
|
对给定的标量 x 进行取反操作。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x (float or int): 需要取反的标量。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
float or int: 取反后的标量。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Implement `scalar_usub`."""
|
|
|
|
|
return -x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def TupleGetItem(x, index):
|
|
|
|
|
"""
|
|
|
|
|
从给定对象中获取指定索引处的元素。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x (Union[Tensor, dict]): 输入对象,可以是Tensor类型或字典类型。
|
|
|
|
|
index (int): 要获取的元素的索引。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Union[Tensor, Any]: 如果输入是Tensor类型,则返回Tensor类型的元素;
|
|
|
|
|
如果输入是字典类型,则返回字典中对应索引的值;
|
|
|
|
|
否则,返回输入对象中对应索引的元素。
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
IndexError: 如果索引超出范围。
|
|
|
|
|
"""
|
|
|
|
|
"""Implement `tuple_getitem`."""
|
|
|
|
|
if isinstance(x, Tensor):
|
|
|
|
|
x = x.asnumpy()
|
|
|
|
@ -64,36 +133,111 @@ def TupleGetItem(x, index):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def scalar_gt(x, y):
|
|
|
|
|
"""
|
|
|
|
|
判断两个标量值x和y的大小关系。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x (float or int): 第一个标量值。
|
|
|
|
|
y (float or int): 第二个标量值。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
bool: 如果x大于y,则返回True;否则返回False。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Implement `scalar_gt`."""
|
|
|
|
|
return x > y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def scalar_ne(x, y):
|
|
|
|
|
"""
|
|
|
|
|
比较两个标量值是否不相等。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x (float): 第一个标量值。
|
|
|
|
|
y (float): 第二个标量值。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
bool: 如果 x 不等于 y,则返回 True;否则返回 False。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Implement `scalar_ne`."""
|
|
|
|
|
return x != y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def scalar_eq(x, y):
|
|
|
|
|
"""
|
|
|
|
|
判断两个标量值是否相等。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x (Any): 第一个标量值。
|
|
|
|
|
y (Any): 第二个标量值。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
bool: 如果 x 和 y 相等,返回 True;否则返回 False。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Implement `scalar_eq`."""
|
|
|
|
|
return x == y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def scalar_le(x, y):
|
|
|
|
|
"""
|
|
|
|
|
判断标量 x 是否小于等于标量 y。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x (float): 第一个标量值。
|
|
|
|
|
y (float): 第二个标量值。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
bool: 如果 x 小于等于 y,则返回 True;否则返回 False。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Implement `scalar_le`."""
|
|
|
|
|
return x <= y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def scalar_lt(x, y):
|
|
|
|
|
"""
|
|
|
|
|
判断两个标量值的大小关系。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x (float): 第一个标量值。
|
|
|
|
|
y (float): 第二个标量值。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
bool: 如果 x 小于 y,则返回 True;否则返回 False。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Implement `scalar_lt`."""
|
|
|
|
|
return x < y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def identity(x):
|
|
|
|
|
"""
|
|
|
|
|
返回输入参数本身。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x: 任何类型的输入参数。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
返回输入参数本身。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Implement `identity`."""
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def zeros_like_tensor(x):
|
|
|
|
|
"""
|
|
|
|
|
根据给定的张量x创建一个形状相同但所有元素为零的新张量。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x (Tensor): 输入的张量,用于确定新张量的形状。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tensor: 一个与输入张量x形状相同但所有元素为零的新张量。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Implement `zeros_like_tensor`."""
|
|
|
|
|
x = x.asnumpy()
|
|
|
|
|
value = Tensor(np.zeros(x.shape).astype(np.float32))
|
|
|
|
@ -101,61 +245,201 @@ def zeros_like_tensor(x):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def Switch(c, x, y):
|
|
|
|
|
"""
|
|
|
|
|
实现 `switch` 功能。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
c (bool): 条件值,如果为 True,则返回 x,否则返回 y。
|
|
|
|
|
x (Any): 条件为 True 时返回的值。
|
|
|
|
|
y (Any): 条件为 False 时返回的值。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Any: 根据条件 c 返回 x 或 y。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Implement `switch`."""
|
|
|
|
|
return x if c else y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def list_getitem(data, item):
|
|
|
|
|
"""
|
|
|
|
|
从列表中获取指定索引处的元素。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
data (list): 待查询的列表。
|
|
|
|
|
item (int): 要获取的元素的索引。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
返回列表中索引为item的元素。
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
IndexError: 如果索引超出列表范围。
|
|
|
|
|
"""
|
|
|
|
|
"""Implement `list_getitem`."""
|
|
|
|
|
return data[item]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def bool_not(x):
|
|
|
|
|
"""
|
|
|
|
|
对输入值取反。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x (bool): 要取反的布尔值。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
bool: x 的逻辑非值。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Implement `bool_not`."""
|
|
|
|
|
return not x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def bool_and(x, y):
|
|
|
|
|
"""
|
|
|
|
|
对两个布尔值进行逻辑与操作。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x (bool): 第一个布尔值。
|
|
|
|
|
y (bool): 第二个布尔值。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
bool: 返回两个布尔值进行逻辑与操作后的结果。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Implement `bool_and`."""
|
|
|
|
|
return x and y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def bool_or(x, y):
|
|
|
|
|
"""
|
|
|
|
|
实现布尔或运算。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x (bool): 第一个布尔值。
|
|
|
|
|
y (bool): 第二个布尔值。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
bool: 如果 x 或 y 为 True,则返回 True,否则返回 False。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Implement `bool_or`."""
|
|
|
|
|
return x or y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def make_list(*xs):
|
|
|
|
|
"""
|
|
|
|
|
将不定数量的参数转换为一个列表。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
*xs: 不定数量的参数,可以是任意类型。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
list: 包含所有传入参数的列表。
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> make_list(1, 2, 3)
|
|
|
|
|
[1, 2, 3]
|
|
|
|
|
>>> make_list('a', 'b', 'c')
|
|
|
|
|
['a', 'b', 'c']
|
|
|
|
|
>>> make_list(1, 'a', [1, 2, 3])
|
|
|
|
|
[1, 'a', [1, 2, 3]]
|
|
|
|
|
"""
|
|
|
|
|
"""Implement `make_list`."""
|
|
|
|
|
return list(xs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def list_len(x):
|
|
|
|
|
"""
|
|
|
|
|
计算列表的长度。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x (list): 需要计算长度的列表。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
int: 列表的长度。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Implement `list_len`."""
|
|
|
|
|
return len(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def Depend(value, expr):
|
|
|
|
|
"""
|
|
|
|
|
依赖函数,根据给定的表达式返回相应的值。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
value (Any): 要返回的值。
|
|
|
|
|
expr (Any): 表达式,该参数在当前实现中被忽略。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Any: 返回与输入相同的值。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Implement `Depend`."""
|
|
|
|
|
return value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def UpdateState(monad, *exprs):
|
|
|
|
|
"""
|
|
|
|
|
更新状态。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
monad (Monad): 一个符合 Monad 类型的对象。
|
|
|
|
|
*exprs (Any): 需要更新的表达式,可以为任意类型。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Monad: 更新后的 Monad 对象。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Implement `UpdateState`."""
|
|
|
|
|
return monad
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def Load(value, u=None):
|
|
|
|
|
"""
|
|
|
|
|
加载指定的值。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
value (Any): 要加载的值。
|
|
|
|
|
u (Optional[Any], optional): 可选参数,默认为None。当前版本未使用,保留以便未来扩展。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Any: 返回加载的值。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Implement `Load`."""
|
|
|
|
|
return value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# only used in PyNative mode
|
|
|
|
|
def make_ref(key, value, ref):
|
|
|
|
|
"""
|
|
|
|
|
创建一个引用对象。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
key (str): 键名,用于标识引用的对象。
|
|
|
|
|
value (Any): 引用对象的值。
|
|
|
|
|
ref (Any): 引用对象,可以为任意类型。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Any: 返回引用的值。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
return value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def scalar_cast(x, t):
|
|
|
|
|
"""
|
|
|
|
|
将标量值x转换为指定的NumPy数据类型t。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x (float, int): 要转换的标量值。
|
|
|
|
|
t (np.dtype): 目标NumPy数据类型。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Any: 转换后的标量值,类型为t。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Implement scalar_cast."""
|
|
|
|
|
np_type = dtype_to_nptype(t)
|
|
|
|
|
value = np_type(x)
|
|
|
|
@ -164,16 +448,46 @@ def scalar_cast(x, t):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def typeof(x):
|
|
|
|
|
"""
|
|
|
|
|
实现 typeof 函数。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x (Any): 要获取类型的对象。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
str: 返回传入对象的Python类型名称。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Implement typeof."""
|
|
|
|
|
return get_py_obj_dtype(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def tuple_to_array(x):
|
|
|
|
|
"""
|
|
|
|
|
将元组转换为数组。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x (tuple): 待转换的元组。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tensor: 转换后的数组。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Implement `tuple_to_array`."""
|
|
|
|
|
return Tensor(np.array(x))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def stop_gradient(x):
|
|
|
|
|
"""
|
|
|
|
|
停止梯度传播。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x (Tensor): 需要停止梯度传播的张量。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tensor: 停止梯度传播的张量。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Implement `stop_gradient`."""
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
@ -182,6 +496,20 @@ hyper_map = C.HyperMap()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def mixed_precision_cast(dst_type, x):
|
|
|
|
|
"""
|
|
|
|
|
实现混合精度转换函数。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
dst_type (mstype.Type): 目标数据类型。
|
|
|
|
|
x (Union[Tensor, list, tuple]): 需要进行类型转换的数据,可以是单个Tensor,也可以是一个包含Tensor的列表或元组。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Union[Tensor, list, tuple]: 转换后的数据,类型与输入一致。
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
TypeError: 如果输入数据类型不支持,将引发TypeError异常。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Implement `mixed_precision_cast`."""
|
|
|
|
|
|
|
|
|
|
def cast_inner(data):
|
|
|
|
|