diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/builtin_operations.py b/src/mindspore2022/mindspore/python/mindspore/_extends/builtin_operations.py index 1b0e4e9a..d07dbf71 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/builtin_operations.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/builtin_operations.py @@ -22,31 +22,100 @@ from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype def ScalarAdd(x, y): + """ + 实现标量加法运算。 + + Args: + x (float): 第一个加数。 + y (float): 第二个加数。 + + Returns: + float: x 和 y 的和。 + + """ """Implement `scalar_add`.""" return x + y def ScalarMul(x, y): + """ + 标量乘法函数。 + + Args: + x (float): 第一个标量。 + y (float): 第二个标量。 + + Returns: + float: 两个标量的乘积。 + + """ """Implement `scalar_mul`.""" return x * y def ScalarMod(x, y): + """ + 对两个数进行模运算。 + + Args: + x (int or float): 被模数。 + y (int or float): 模数。 + + Returns: + int or float: x 对 y 取模的结果。 + + """ """Implement `scalar_mul`.""" return x % y def ScalarSub(x, y): + """ + 实现标量减法运算。 + + Args: + x (float): 第一个标量值。 + y (float): 第二个标量值。 + + Returns: + float: x 和 y 的差值。 + + """ """Implement `scalar_sub`.""" return x - y def ScalarUsub(x): + """ + 对给定的标量 x 进行取反操作。 + + Args: + x (float or int): 需要取反的标量。 + + Returns: + float or int: 取反后的标量。 + + """ """Implement `scalar_usub`.""" return -x def TupleGetItem(x, index): + """ + 从给定对象中获取指定索引处的元素。 + + Args: + x (Union[Tensor, dict]): 输入对象,可以是Tensor类型或字典类型。 + index (int): 要获取的元素的索引。 + + Returns: + Union[Tensor, Any]: 如果输入是Tensor类型,则返回Tensor类型的元素; + 如果输入是字典类型,则返回字典中对应索引的值; + 否则,返回输入对象中对应索引的元素。 + + Raises: + IndexError: 如果索引超出范围。 + """ """Implement `tuple_getitem`.""" if isinstance(x, Tensor): x = x.asnumpy() @@ -64,36 +133,111 @@ def TupleGetItem(x, index): def scalar_gt(x, y): + """ + 判断两个标量值x和y的大小关系。 + + Args: + x (float or int): 第一个标量值。 + y (float or int): 第二个标量值。 + + Returns: + bool: 如果x大于y,则返回True;否则返回False。 + + """ """Implement `scalar_gt`.""" return x > y def scalar_ne(x, y): + """ + 比较两个标量值是否不相等。 + + Args: + x (float): 第一个标量值。 + y (float): 第二个标量值。 + + Returns: + bool: 如果 x 不等于 y,则返回 True;否则返回 False。 + + """ """Implement `scalar_ne`.""" return x != y def scalar_eq(x, y): + """ + 判断两个标量值是否相等。 + + Args: + x (Any): 第一个标量值。 + y (Any): 第二个标量值。 + + Returns: + bool: 如果 x 和 y 相等,返回 True;否则返回 False。 + + """ """Implement `scalar_eq`.""" return x == y def scalar_le(x, y): + """ + 判断标量 x 是否小于等于标量 y。 + + Args: + x (float): 第一个标量值。 + y (float): 第二个标量值。 + + Returns: + bool: 如果 x 小于等于 y,则返回 True;否则返回 False。 + + """ """Implement `scalar_le`.""" return x <= y def scalar_lt(x, y): + """ + 判断两个标量值的大小关系。 + + Args: + x (float): 第一个标量值。 + y (float): 第二个标量值。 + + Returns: + bool: 如果 x 小于 y,则返回 True;否则返回 False。 + + """ """Implement `scalar_lt`.""" return x < y def identity(x): + """ + 返回输入参数本身。 + + Args: + x: 任何类型的输入参数。 + + Returns: + 返回输入参数本身。 + + """ """Implement `identity`.""" return x def zeros_like_tensor(x): + """ + 根据给定的张量x创建一个形状相同但所有元素为零的新张量。 + + Args: + x (Tensor): 输入的张量,用于确定新张量的形状。 + + Returns: + Tensor: 一个与输入张量x形状相同但所有元素为零的新张量。 + + """ """Implement `zeros_like_tensor`.""" x = x.asnumpy() value = Tensor(np.zeros(x.shape).astype(np.float32)) @@ -101,61 +245,201 @@ def zeros_like_tensor(x): def Switch(c, x, y): + """ + 实现 `switch` 功能。 + + Args: + c (bool): 条件值,如果为 True,则返回 x,否则返回 y。 + x (Any): 条件为 True 时返回的值。 + y (Any): 条件为 False 时返回的值。 + + Returns: + Any: 根据条件 c 返回 x 或 y。 + + """ """Implement `switch`.""" return x if c else y def list_getitem(data, item): + """ + 从列表中获取指定索引处的元素。 + + Args: + data (list): 待查询的列表。 + item (int): 要获取的元素的索引。 + + Returns: + 返回列表中索引为item的元素。 + + Raises: + IndexError: 如果索引超出列表范围。 + """ """Implement `list_getitem`.""" return data[item] def bool_not(x): + """ + 对输入值取反。 + + Args: + x (bool): 要取反的布尔值。 + + Returns: + bool: x 的逻辑非值。 + + """ """Implement `bool_not`.""" return not x def bool_and(x, y): + """ + 对两个布尔值进行逻辑与操作。 + + Args: + x (bool): 第一个布尔值。 + y (bool): 第二个布尔值。 + + Returns: + bool: 返回两个布尔值进行逻辑与操作后的结果。 + + """ """Implement `bool_and`.""" return x and y def bool_or(x, y): + """ + 实现布尔或运算。 + + Args: + x (bool): 第一个布尔值。 + y (bool): 第二个布尔值。 + + Returns: + bool: 如果 x 或 y 为 True,则返回 True,否则返回 False。 + + """ """Implement `bool_or`.""" return x or y def make_list(*xs): + """ + 将不定数量的参数转换为一个列表。 + + Args: + *xs: 不定数量的参数,可以是任意类型。 + + Returns: + list: 包含所有传入参数的列表。 + + Examples: + >>> make_list(1, 2, 3) + [1, 2, 3] + >>> make_list('a', 'b', 'c') + ['a', 'b', 'c'] + >>> make_list(1, 'a', [1, 2, 3]) + [1, 'a', [1, 2, 3]] + """ """Implement `make_list`.""" return list(xs) def list_len(x): + """ + 计算列表的长度。 + + Args: + x (list): 需要计算长度的列表。 + + Returns: + int: 列表的长度。 + + """ """Implement `list_len`.""" return len(x) def Depend(value, expr): + """ + 依赖函数,根据给定的表达式返回相应的值。 + + Args: + value (Any): 要返回的值。 + expr (Any): 表达式,该参数在当前实现中被忽略。 + + Returns: + Any: 返回与输入相同的值。 + + """ """Implement `Depend`.""" return value def UpdateState(monad, *exprs): + """ + 更新状态。 + + Args: + monad (Monad): 一个符合 Monad 类型的对象。 + *exprs (Any): 需要更新的表达式,可以为任意类型。 + + Returns: + Monad: 更新后的 Monad 对象。 + + """ """Implement `UpdateState`.""" return monad def Load(value, u=None): + """ + 加载指定的值。 + + Args: + value (Any): 要加载的值。 + u (Optional[Any], optional): 可选参数,默认为None。当前版本未使用,保留以便未来扩展。 + + Returns: + Any: 返回加载的值。 + + """ """Implement `Load`.""" return value # only used in PyNative mode def make_ref(key, value, ref): + """ + 创建一个引用对象。 + + Args: + key (str): 键名,用于标识引用的对象。 + value (Any): 引用对象的值。 + ref (Any): 引用对象,可以为任意类型。 + + Returns: + Any: 返回引用的值。 + + """ return value def scalar_cast(x, t): + """ + 将标量值x转换为指定的NumPy数据类型t。 + + Args: + x (float, int): 要转换的标量值。 + t (np.dtype): 目标NumPy数据类型。 + + Returns: + Any: 转换后的标量值,类型为t。 + + """ """Implement scalar_cast.""" np_type = dtype_to_nptype(t) value = np_type(x) @@ -164,16 +448,46 @@ def scalar_cast(x, t): def typeof(x): + """ + 实现 typeof 函数。 + + Args: + x (Any): 要获取类型的对象。 + + Returns: + str: 返回传入对象的Python类型名称。 + + """ """Implement typeof.""" return get_py_obj_dtype(x) def tuple_to_array(x): + """ + 将元组转换为数组。 + + Args: + x (tuple): 待转换的元组。 + + Returns: + Tensor: 转换后的数组。 + + """ """Implement `tuple_to_array`.""" return Tensor(np.array(x)) def stop_gradient(x): + """ + 停止梯度传播。 + + Args: + x (Tensor): 需要停止梯度传播的张量。 + + Returns: + Tensor: 停止梯度传播的张量。 + + """ """Implement `stop_gradient`.""" return x @@ -182,6 +496,20 @@ hyper_map = C.HyperMap() def mixed_precision_cast(dst_type, x): + """ + 实现混合精度转换函数。 + + Args: + dst_type (mstype.Type): 目标数据类型。 + x (Union[Tensor, list, tuple]): 需要进行类型转换的数据,可以是单个Tensor,也可以是一个包含Tensor的列表或元组。 + + Returns: + Union[Tensor, list, tuple]: 转换后的数据,类型与输入一致。 + + Raises: + TypeError: 如果输入数据类型不支持,将引发TypeError异常。 + + """ """Implement `mixed_precision_cast`.""" def cast_inner(data):