Dong Haoqian's code annotations

branch-donghaoqian
donghaoqian 7 months ago
parent 69ae5ccbe1
commit db2c8b9de0

@@ -1,47 +1,117 @@
# Copyright 2020-2022 Huawei Technologies Co., Ltd
# Copyright notice: this code was developed by Huawei Technologies Co., Ltd between 2020 and 2022
# Licensed under the Apache License, Version 2.0 (the "License");
# States that this code is released under version 2.0 of the Apache License
# you may not use this file except in compliance with the License.
# States that this file may not be used except in compliance with the License
# You may obtain a copy of the License at
# Gives the address where the License can be obtained
# http://www.apache.org/licenses/LICENSE-2.0
# The actual address of the License
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing
# distributed under the License is distributed on an "AS IS" BASIS,
# the software is distributed under the License on an "AS IS" basis
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# without warranties or conditions of any kind, either express or implied
# See the License for the specific language governing permissions and
# See the License for the specific permissions and
# limitations under the License.
# limitations it imposes
# ============================================================================
# Standard separator line between the license header and the code
"""cell"""
# Module docstring: the module is named cell
import gc
# Import the garbage collection module, used for memory management
import inspect
# Import the inspect module, used to get information about live objects
import os
# Import the os module, used to interact with the operating system
import time
# Import the time module, used for time-related operations
from collections import OrderedDict
# Import the OrderedDict class from the collections module, used to create ordered dictionaries
from types import FunctionType, MethodType
# Import the FunctionType and MethodType classes from the types module, used for type checks
import numpy
# Import the numpy module, used for scientific computing
from mindspore._checkparam import args_type_check
# Import the args_type_check function from mindspore._checkparam, used to check the types of function arguments
from mindspore import log as logger
# Import mindspore's log module as logger, used for logging
from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
# Import the PARAMETER_NAME_DEFAULT constant from mindspore.common.parameter, used as the default parameter name
from mindspore.common.hook_handle import HookHandle
# Import the HookHandle class from mindspore.common.hook_handle, used to manage hook handles
from mindspore.context import ParallelMode
# Import the ParallelMode class from mindspore.context, used to configure the parallel mode
from mindspore.ops.composite import Shard
# Import the Shard class from mindspore.ops.composite, used for sharding operations
from .. import context
# Import the relative context module, used for context configuration
from .._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
# Import several functions and classes from the relative _c_expression module: pipeline initialization, func-graph hyper-parameter update, the Cell_ base class, the FuncGraph class, and the mixed-precision type
from .._checkparam import Validator
# Import the Validator class from the relative _checkparam module, used for argument validation
from ..common import dtype as mstype
# Import the dtype module from the relative common module as mstype, used for data-type definitions
from ..common.api import _cell_graph_executor, _pynative_executor, _check_all_tensor, cells_compile_cache
# Import from the relative common.api module the cell graph executor, the PyNative executor, the all-tensor check, and the cell compile cache
from ..common.parameter import Parameter, ParameterTuple
# Import the Parameter and ParameterTuple classes from the relative common.parameter module
from ..common.variable import Variable
# Import the Variable class from the relative common.variable module
from ..common.tensor import Tensor, CSRTensor, COOTensor
# Import the Tensor, CSRTensor and COOTensor classes from the relative common.tensor module
from ..ops.operations import Cast
# Import the Cast class from the relative ops.operations module, used for type-cast operations
from ..ops.primitive import Primitive
# Import the Primitive class from the relative ops.primitive module, the base of primitive operations
from ..ops.operations import _inner_ops as inner
# Import _inner_ops from the relative ops.operations module as inner, used for internal operations
from ..parallel._tensor import _load_tensor_by_layout
# Import the _load_tensor_by_layout function from the relative parallel._tensor module, used to load tensors according to a layout
class Cell(Cell_):
# Define the Cell class, which inherits from Cell_; it is the basic building block of neural networks in MindSpore
"""
The basic building block of neural networks in MindSpore. The model or neural network layer should inherit this
base class.
@@ -81,6 +151,8 @@ class Cell(Cell_):
... # the parameter's name will be 'net.weight'.
[Parameter (name=weight, shape=(240, 120, 4, 4), dtype=Float32, requires_grad=True)]
"""
# Class docstring: explains the purpose of the Cell class, its inheritance, arguments, supported platforms and examples
class _CellGuard:
"""Detecting whether the cell is a top-level cell with the 'with statement'."""

@@ -22,61 +22,68 @@ from ...common.api import ms_function
class _FirstGrad(Cell):
# Cell that computes the first-order gradient
def __init__(self, fn):
super(_FirstGrad, self).__init__()
self.first_grad_op = C.GradOperation(sens_param=True, get_all=True)
self.fn = fn
def construct(self, u, first_grad_input):
# construct method, used to compute the gradient
return self.first_grad_op(self.fn)(*first_grad_input, u)
class _JvpFirstGrad(Cell):
# Cell that computes the first-order gradient for the Jacobian-vector product
def __init__(self):
super(_JvpFirstGrad, self).__init__()
self.first_grad_op = C.GradOperation(sens_param=True, get_all=True)
def construct(self, u, fn, first_grad_input):
# construct method, used to compute the first-order gradient of the JVP
return self.first_grad_op(fn)(*first_grad_input, u)
class _FirstGradSingleValue(Cell):
# Cell that computes the gradient for a single value
def __init__(self, fn):
super(_FirstGradSingleValue, self).__init__()
self.first_grad_single_value_op = C.GradOperation(sens_param=True)
self.fn = fn
def construct(self, u, first_grad_single_value_input):
# construct method, used to compute the single-value gradient
return self.first_grad_single_value_op(self.fn)(*first_grad_single_value_input, u)
class _JvpFirstGradSingleValue(Cell):
# Cell that computes the single-value gradient for the Jacobian-vector product
def __init__(self):
super(_JvpFirstGradSingleValue, self).__init__()
self.first_grad_single_value_op = C.GradOperation(sens_param=True)
def construct(self, u, fn, first_grad_single_value_input):
# construct method, used to compute the single-value gradient of the JVP
return self.first_grad_single_value_op(fn)(*first_grad_single_value_input, u)
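All of these helpers wrap C.GradOperation (the public mindspore.ops.GradOperation). As a hedged sketch of what sens_param=True and get_all=True mean, assuming an installed MindSpore with the default execution mode:

import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor

class Square(nn.Cell):
    def construct(self, x):
        return x * x

# get_all=True: return gradients for every input; sens_param=True: the caller
# supplies the sensitivity vector u as the last argument, exactly as above.
grad_op = ops.GradOperation(get_all=True, sens_param=True)
x = Tensor(np.array([1.0, 2.0, 3.0], dtype=np.float32))
u = Tensor(np.array([1.0, 1.0, 1.0], dtype=np.float32))
print(grad_op(Square())(x, u))  # expected: (Tensor([2., 4., 6.]),)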
class Jvp(Cell):
"""
Compute the jacobian-vector-product of the given fn. Jvp is equivalent to forward mode autodiff.
Args:
fn (Cell): The fn that takes Tensor inputs and returns a tuple of Tensors or a Tensor.
Inputs:
- **inputs** (Tensors) - The inputs to `fn`.
- **v** (Tensors or Tuple of Tensors) - The vector for which the Jacobian vector product is computed.
Must have the same size as the input of `fn`.
Outputs:
A tuple with 2 Tensors or Tuple of Tensors:
- **net_output** (Tensors or Tuple of Tensors) - The output of `fn(inputs)`.
- **jvp** (Tensors or Tuple of Tensors) - The result of the jacobian vector product.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
@@ -113,6 +120,7 @@ class Jvp(Cell):
@ms_function
def construct(self, *args):
# construct method, used to compute the JVP
jvp_input = args[0:-1]
v = args[-1]
output = self.fn(*jvp_input)
@@ -135,8 +143,8 @@ class Jvp(Cell):
class _JvpInner(Cell):
"""
Compute the jacobian-vector-product of the given network. Jvp is equivalent to forward mode autodiff.
This class implements the inner process of function jvp.
"""
def __init__(self):
super(_JvpInner, self).__init__()
@@ -152,6 +160,7 @@ class _JvpInner(Cell):
self.tuple_len = Primitive("tuple_len")
def construct(self, *args):
# construct method, used to compute the inner JVP
fn = args[0]
v = args[1]
jvp_input = args[2:]
@@ -175,22 +184,21 @@ class _JvpInner(Cell):
class Vjp(Cell):
"""
Computes the dot product between a vector `v` and the Jacobian of the given fn at the point
given by the inputs.
Args:
fn (Cell): The fn that takes Tensor inputs and returns a tuple of Tensors or a Tensor.
Inputs:
- **inputs** (Tensors) - The inputs to `fn`. Must be a tuple or a list.
- **v** (Tensors or Tuple of Tensors) - The vector for which the vector Jacobian product is computed.
Must have the same size as the output of `fn`.
Outputs:
A tuple with 2 Tensors or Tuple of Tensors:
- **net_output** (Tensors or Tuple of Tensors) - The output of `fn(inputs)`.
- **vjp** (Tensors or Tuple of Tensors) - The result of the dot product.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
@@ -226,6 +234,7 @@ class Vjp(Cell):
@ms_function
def construct(self, *args):
# construct method, used to compute the VJP
front_input = args[0:-1]
output = self.fn(*front_input)
if self.tuple_len(front_input) == 1:
@@ -237,8 +246,8 @@ class Vjp(Cell):
class _VjpInner(Cell):
"""
Computes the dot product between a vector `v` and the Jacobian of the given network at the point
given by the inputs. This class implements the inner process of function vjp.
"""
def __init__(self):
@@ -248,6 +257,7 @@ class _VjpInner(Cell):
self.tuple_len = Primitive("tuple_len")
def construct(self, *args):
# construct method, used to compute the inner VJP
fn = args[0]
front_input = args[1:-1]
input_with_v = args[1:]
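The hunks above add comments to the Jvp and Vjp cells. As a hedged usage sketch, assuming these classes are exported as mindspore.nn.Jvp and mindspore.nn.Vjp (as in MindSpore 1.x):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

class Net(nn.Cell):
    def construct(self, x):
        return x ** 3

x = Tensor(np.array([1.0, 2.0, 3.0], dtype=np.float32))
v = Tensor(np.array([1.0, 1.0, 1.0], dtype=np.float32))
out, jvp = nn.Jvp(Net())(x, v)   # forward mode: Jacobian applied to v
out, vjp = nn.Vjp(Net())(x, v)   # reverse mode: v applied to the Jacobian
# for this element-wise fn both products equal 3 * x**2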

File diff suppressed because it is too large.

@@ -85,58 +85,84 @@ def array(obj, dtype=None, copy=True, ndmin=0):
>>> print(np.array([1,2,3]))
[1 2 3]
"""
if dtype is not None:  # If the user specified a dtype, check it and convert it to a mindspore dtype
dtype = _check_dtype(dtype)
res = asarray(obj, dtype)  # Convert the input object to a tensor
if ndmin > res.ndim:  # If the requested minimum number of dimensions exceeds the tensor's, prepend dimensions
if res.size == 0:  # An empty tensor cannot be expanded, so raise an error
_raise_value_error("Empty tensor cannot be expanded beyond the current dimension.")
res = _expand(res, ndmin)  # Expand the tensor's dimensions
if copy and isinstance(obj, Tensor):  # If copy is True and the input is already a tensor, return a copy of it
res = copy_(res)
elif dtype is not None and dtype != res.dtype:  # If a dtype was given and differs from the tensor's, cast to it
res = res.astype(dtype)
return res  # Return the resulting tensor
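A hedged usage sketch of this public API (mindspore.numpy.array, whose behaviour mirrors NumPy):

import mindspore.numpy as mnp

a = mnp.array([1, 2, 3], ndmin=2)   # extra leading dimensions are prepended: shape (1, 3)
b = mnp.array(a, copy=True)         # a fresh tensor rather than the same object
print(a.shape, b is a)              # expected: (1, 3) False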
@constexpr
def asarray_const(a, dtype=None):
# The constexpr decorator marks this as a compile-time constant function
"""Converts the input to tensor. Note here `a` cannot be tensor itself."""
# Docstring: converts the input to a tensor; note that `a` cannot itself be a tensor here
_check_input_for_asarray(a)
# Check that the input `a` is acceptable for asarray
if dtype is not None:
# If dtype is not None
dtype = _check_dtype(dtype)
# Check and confirm that dtype is a valid data type
if isinstance(a, (float, int, bool)) and dtype is None:
# If `a` is a float, int or bool and no dtype was specified
dtype = _get_dtype_from_scalar(a)
# Derive the dtype from the scalar `a`
if isinstance(a, (list, tuple)):
# If `a` is a list or tuple
# Convert all tuple/nested tuples to lists
a = _deep_list(a)
# Convert every tuple, including nested ones, to a list
# Convert all tensor sub-elements to numpy arrays
a = _deep_tensor_to_nparray(a)
# Convert every tensor sub-element to a numpy array
a = onp.asarray(a)
# Use numpy's asarray to convert `a` to a numpy array
if a.dtype is onp.dtype('object'):
# If the resulting numpy array has dtype object
raise ValueError('Input array must have the same size across all dimensions.')
# Raise a ValueError: the input array must have the same size across all dimensions
# If dtype is not specified, we keep consistent with numpy decision
# only exceptions are: we use int/float32
if dtype is None:
# If dtype was not specified
dtype = mstype.pytype_to_dtype(a.dtype)
# Convert the numpy array's dtype to a mindspore dtype
if dtype == mstype.float64:
# If the dtype is float64
dtype = mstype.float32
# Narrow it to float32
elif dtype == mstype.int64:
# If the dtype is int64
dtype = mstype.int32
# Narrow it to int32
if isinstance(a, onp.ndarray) and dtype is None:
# If `a` is a numpy array and no dtype was specified
if a.dtype is onp.dtype('object'):
# If the numpy array's dtype is object
raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.")
# Raise a TypeError: the input data contains an unsupported element
dtype = mstype.pytype_to_dtype(a.dtype)
# Convert the numpy array's dtype to a mindspore dtype
a = Tensor.from_numpy(a)
# Convert the numpy array to a mindspore Tensor
return Tensor(a, dtype=dtype)
# Return a Tensor with the specified dtype
def asarray(a, dtype=None):
@@ -168,29 +194,46 @@ def asarray(a, dtype=None):
[1 2 3]
"""
if dtype is not None:
# If dtype is not None
dtype = _check_dtype(dtype)
# Check and confirm that dtype is a valid data type
if isinstance(a, Tensor):
# If `a` is already a Tensor
if dtype is None or dtype == a.dtype:
# If no dtype was specified, or it equals the tensor's dtype
return a
# Return `a` unchanged
return a.astype(dtype)
# Otherwise cast `a` to the requested dtype and return it
return asarray_const(a, dtype)
# If `a` is not a Tensor, convert it with asarray_const and return the result
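A hedged usage sketch (mindspore.numpy.asarray; the int64-to-int32 narrowing follows the code above):

import numpy as onp
import mindspore.numpy as mnp
from mindspore import Tensor

t = Tensor(onp.ones((2, 2), dtype=onp.float32))
print(mnp.asarray(t) is t)            # True: an existing tensor with a matching dtype is returned as-is
print(mnp.asarray([1, 2, 3]).dtype)   # Int32: 64-bit defaults are narrowed to 32-bit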
@constexpr
def asfarray_const(a, dtype=mstype.float32):
"""Converts the input to tensor. Note here `a` cannot be tensor itself."""
# Docstring: converts the input to a tensor; note that `a` cannot itself be a tensor here
_check_input_for_asarray(a)
# Check that the input `a` is acceptable for asarray
if isinstance(a, (list, tuple)):
# If `a` is a list or tuple
# Convert all tuple/nested tuples to lists
a = _deep_list(a)
# Convert every tuple, including nested ones, to a list
# Convert all tensor sub-elements to numpy arrays
a = _deep_tensor_to_nparray(a)
# Convert every tensor sub-element to a numpy array
a = onp.asarray(a)
# Use numpy's asarray to convert `a` to a numpy array
if a.dtype is onp.dtype('object'):
# If the resulting numpy array has dtype object
raise ValueError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.")
# Raise a ValueError: the input data contains an unsupported element
a = Tensor.from_numpy(a)
# Convert the numpy array to a mindspore Tensor
return Tensor(a, dtype)
# Return a Tensor with the specified dtype
def asfarray(a, dtype=mstype.float32):
@@ -206,7 +249,6 @@ def asfarray(a, dtype=mstype.float32):
be in format of np.int32, or \'int32\'. If dtype is :class:`None`, the data type
of the new tensor will be inferred from `a`. Default is :class:`mindspore.float32`.
Returns:
Tensor, generated tensor with the specified float dtype.
@@ -223,16 +265,24 @@ def asfarray(a, dtype=mstype.float32):
[1. 2. 3.]
"""
if dtype is None:
# If no dtype was specified
return asarray(a)
# Convert `a` with asarray and return the result
dtype = _check_dtype(dtype)
# Check and confirm that dtype is a valid data type
if dtype not in (mstype.float16, mstype.float32, mstype.float64):
# If dtype is not float16, float32 or float64
dtype = mstype.float32
# Fall back to float32
if isinstance(a, Tensor):
# If `a` is already a Tensor
return a.astype(dtype)
# Cast `a` to the requested dtype and return it
return asfarray_const(a, dtype)
# If `a` is not a Tensor, convert it with asfarray_const and return the result
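A hedged usage sketch (mindspore.numpy.asfarray, with mindspore.dtype used for the dtype argument):

import mindspore.numpy as mnp
from mindspore import dtype as mstype

print(mnp.asfarray([1, 2, 3]))                       # [1. 2. 3.], float32 by default
print(mnp.asfarray([1, 2, 3], mstype.int32).dtype)   # Float32: non-float dtypes fall back to float32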
def copy_(a):
@@ -261,7 +311,9 @@ def copy_(a):
[1. 1.]]
"""
a = asarray(a)
# Convert `a` to a Tensor with asarray
return a.copy()
# Return a copy of `a`
def ones(shape, dtype=mstype.float32):
@@ -290,11 +342,17 @@ def ones(shape, dtype=mstype.float32):
[1. 1.]]
"""
shape = _check_shape(shape)
# Check and confirm that shape is a valid shape
dtype = _check_dtype(dtype)
# Check and confirm that dtype is a valid data type
if _is_shape_empty(shape):
# If the shape describes an empty tensor
return full(shape, 1.0, dtype)
# Create a tensor of the given shape and dtype filled with 1.0 via full
output = F.fill(dtype, shape, 1)
# Create a tensor of the given shape and dtype filled with 1 via F.fill
return output
# Return the created tensor
def zeros(shape, dtype=mstype.float32):
@@ -323,11 +381,17 @@ def zeros(shape, dtype=mstype.float32):
[0. 0.]]
"""
shape = _check_shape(shape)
# Check and confirm that shape is a valid shape
dtype = _check_dtype(dtype)
# Check and confirm that dtype is a valid data type
if _is_shape_empty(shape):
# If the shape describes an empty tensor
return full(shape, 0.0, dtype)
# Create a tensor of the given shape and dtype filled with 0.0 via full
output = F.fill(dtype, shape, 0)
# Create a tensor of the given shape and dtype filled with 0 via F.fill
return output
# Return the created tensor
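A hedged usage sketch of these two public creation functions:

import mindspore.numpy as mnp
from mindspore import dtype as mstype

print(mnp.ones((2, 2)))                  # 2x2 tensor of 1.0, float32 by default
print(mnp.zeros((2, 2), mstype.int32))   # 2x2 tensor of 0 with an explicit dtype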
def full(shape, fill_value, dtype=None):
@@ -360,24 +424,42 @@ def full(shape, fill_value, dtype=None):
[True True]]
"""
shape = _check_shape(shape)
# Check and confirm that shape is a valid shape
if not isinstance(fill_value, ARRAY_TYPES):
# If fill_value is not an int, float, bool, list, tuple or Tensor
_raise_type_error("fill value should be int, float, bool, list, tuple, Tensor, but got", fill_value)
# Raise a TypeError: the type of fill_value is not supported
if dtype is not None:
# If dtype is not None
dtype = _check_dtype(dtype)
# Check and confirm that dtype is a valid data type
else:
# If dtype is None
if isinstance(fill_value, (int, float, bool)):
# If fill_value is an int, float or bool
dtype = _get_dtype_from_scalar(fill_value)
# Derive the dtype from the scalar fill_value
if isinstance(fill_value, Tensor):
# If fill_value is a Tensor
dtype = fill_value.dtype
# Take the dtype from the Tensor fill_value
if not _is_shape_empty(shape):
# If the shape is not empty
if isinstance(fill_value, (int, float, bool)):
# If fill_value is an int, float or bool
return F.fill(dtype, shape, fill_value)
# Create a tensor of the given shape and dtype filled with fill_value via F.fill
if isinstance(fill_value, (list, tuple)):
# If fill_value is a list or tuple
fill_value = asarray_const(fill_value)
# Convert fill_value to a Tensor with asarray_const
return broadcast_to(fill_value, shape)
# Broadcast fill_value to the requested shape and return the result
# if shape contains zero, use c.Tensor()
return _convert_64_to_32(empty_compile(dtype, shape))
# If the shape contains a zero, create an empty tensor with empty_compile and narrow 64-bit types to 32-bit with _convert_64_to_32
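A hedged usage sketch (mindspore.numpy.full; the list fill value exercises the broadcasting branch above):

import mindspore.numpy as mnp

print(mnp.full((2, 2), True))         # dtype inferred from the scalar: bool
print(mnp.full((2, 2), [1.0, 2.0]))   # a list fill value is broadcast across the rows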
@constexpr

@@ -22,185 +22,171 @@ from ..common import dtype as mstype
from .utils_const import _type_convert, _raise_value_error, _callable_const, _super_check, pack
from ..ops.composite import GradOperation
grad = GradOperation(get_all=False, get_by_list=False, sens_param=False)  # Define a gradient operation that takes the gradient only of the first input, does not fetch gradients by parameter list, and has no sensitivity parameter
_eps_net = ops.Eps()  # Define an operation that computes the numerical precision (machine epsilon) of a value
def _convert_64_to_32(tensor):  # Define a function that converts a tensor from float64/int64 to float32/int32
"""Convert Tensor with float64/int64 types to float32/int32."""
if tensor.dtype == mstype.float64:  # If the tensor's dtype is float64
return tensor.astype("float32")  # Cast it to float32
if tensor.dtype == mstype.int64:  # If the tensor's dtype is int64
return tensor.astype("int32")  # Cast it to int32
return tensor  # Otherwise return the tensor unchanged
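A hedged illustration of the same narrowing using the public Tensor API (not this private helper):

import numpy as onp
from mindspore import Tensor

t = Tensor(onp.arange(3, dtype=onp.int64))
print(t.dtype, t.astype("int32").dtype)   # Int64 Int32: the helper applies exactly this cast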
def _to_tensor(*args, dtype=None):  # Define a function that converts its arguments to tensors
"""Returns each input as Tensor"""
res = ()  # Initialize an empty tuple to hold the results
for arg in args:  # Iterate over every input argument
if isinstance(arg, (int, float, bool, list, tuple)):  # If the argument is an int, float, bool, list or tuple
arg = _type_convert(Tensor, arg)  # Convert it to a Tensor
if dtype is None:  # If no dtype was specified
arg = _convert_64_to_32(arg)  # Narrow 64-bit types with _convert_64_to_32
else:  # If a dtype was specified
arg = arg.astype(dtype)  # Cast the tensor to that dtype
elif not isinstance(arg, Tensor):  # If the argument is not a Tensor either
_raise_value_error("Expect input to be array like.")  # Raise an error: the input should be array-like
res += (arg,)  # Append the converted tensor to the result tuple
if len(res) == 1:  # If the result tuple holds a single element
return res[0]  # Return that element directly
return res  # Otherwise return the whole tuple
def _to_scalar(arr):  # Define a function that converts a scalar Tensor or ndarray to a Python scalar
"""Convert a scalar Tensor or ndarray to a scalar."""
if isinstance(arr, (int, float, bool)):  # If the input is already an int, float or bool
return arr  # Return it unchanged
if isinstance(arr, Tensor):  # If the input is a Tensor
if arr.shape:  # If the tensor has a non-empty shape, i.e. it is not a scalar
return arr  # Return the whole tensor
return arr.asnumpy().item()  # For a scalar tensor, convert it to numpy and return the scalar value
raise ValueError("{} are not supported.".format(type(arr)))  # For any other type, raise an error: the type is not supported
def _eps(x):  # Define a function that computes the numerical precision (eps) of the input tensor
return _eps_net(x[(0,) * x.ndim])  # Apply ops.Eps to x[(0,) * x.ndim], which picks the first element so a scalar is passed in
def _safe_normalize(x, threshold=None):  # Define a function that normalizes a tensor, zeroing the result when the norm is very small
"""Normalize method that cast very small results to zero."""
x_sum2 = F.reduce_sum(F.pows(x, 2.0))  # Sum the squares of the tensor's elements
norm = F.pows(x_sum2, 1. / 2.0)  # Take the square root of that sum to get the norm
if threshold is None:  # If no threshold was given
if x.dtype in (mstype.float32, mstype.float64):  # If the tensor's dtype is float32 or float64
# pick the first element of x to get the eps
threshold = _eps(x)  # Use the machine epsilon as the threshold
else:  # If the dtype is not float32 or float64
threshold = 0  # Use 0 as the threshold
use_norm = greater(norm, threshold)  # Compare the norm against the threshold to get a boolean mask
x_norm = x / norm  # Normalize the tensor by its norm
normalized_x = where(use_norm, x_norm, zeros_like(x))  # Keep the normalized tensor where the norm exceeds the threshold, otherwise use zeros
norm = where(use_norm, norm, zeros_like(norm))  # Keep the norm where it exceeds the threshold, otherwise use zero
return normalized_x, norm  # Return the normalized tensor and its norm
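A hedged NumPy re-statement of the same idea, to make the threshold behaviour concrete (illustration only, not the MindSpore implementation):

import numpy as onp

def safe_normalize(x, threshold=1e-7):
    # Compute the L2 norm and zero out the result when it is negligibly small
    norm = onp.sqrt(onp.sum(x ** 2))
    if norm > threshold:
        return x / norm, norm
    return onp.zeros_like(x), 0.0

print(safe_normalize(onp.array([3.0, 4.0])))     # ([0.6, 0.8], 5.0)
print(safe_normalize(onp.array([0.0, 1e-12])))   # zeros, 0.0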
def sparse_dot(a, b):  # Define a function that computes the dot product of a CSRTensor and a generic Tensor (vector)
"""Returns the dot product of CSRTensor and generic Tensor(vector)."""
b_aligned = F.reshape(b, (b.shape[0], -1))  # Reshape b to (b.shape[0], -1) so it can be multiplied with the sparse matrix
res = F.csr_mv(a, b_aligned)  # Use csr_mv to multiply the sparse matrix a with b_aligned
res = F.reshape(res, a.shape[:-1] + b.shape[1:])  # Reshape the result to a.shape[:-1] + b.shape[1:]
return res  # Return the result
def _normalize_matvec(f):  # Define a function that normalizes an argument for computing matrix-vector products
"""Normalize an argument for computing matrix-vector products."""
if isinstance(f, Tensor):  # If the argument is a Tensor
return F.partial(dot, f)  # Return a partial application of dot with the matrix argument f bound
if isinstance(f, CSRTensor):  # If the argument is a CSRTensor
return F.partial(sparse_dot, f)  # Return a partial application of sparse_dot with the sparse matrix argument f bound
return f  # For any other argument, return it unchanged
def _norm(x, ord_=None):  # Define a function that computes the norm of the input tensor
if ord_ == mnp.inf:  # If ord_ is infinity, i.e. the max norm is requested
res = mnp.max(mnp.abs(x))  # Take the maximum absolute value of the tensor
else:  # Otherwise
res = mnp.sqrt(mnp.sum(x ** 2))  # Take the square root of the sum of squares, i.e. the L2 norm
return res  # Return the result
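A hedged NumPy illustration of the two norms this helper distinguishes:

import numpy as onp

x = onp.array([3.0, -4.0])
print(onp.max(onp.abs(x)))         # infinity norm: 4.0
print(onp.sqrt(onp.sum(x ** 2)))   # L2 norm: 5.0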
def _nd_transpose(a):  # Define a function that transposes a tensor by swapping its last two dimensions
dims = a.ndim  # Get the number of dimensions of the tensor
if dims < 2:  # If the tensor has fewer than 2 dimensions
_raise_value_error("to do _nd_transpose for input a's ndim is not greater or equal to 2d, which is invalid.")  # Raise an error: the input must be at least 2-dimensional
axes = ops.make_range(0, dims)  # Build the sequence 0 .. dims-1 of axis indices
axes = axes[:-2] + (axes[-1],) + (axes[-2],)  # Swap the last two entries of the sequence
return ops.transpose(a, axes)  # Transpose the tensor with the permuted axes
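A hedged NumPy illustration of the axis permutation this produces:

import numpy as onp

a = onp.arange(24).reshape(2, 3, 4)
print(onp.swapaxes(a, -1, -2).shape)   # (2, 4, 3): only the last two axes are exchanged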
def _value_check(func_name, arg1, arg2, arg_name='', attr_name='', op="in", fmt="attr", msg=None):  # Define a function that checks whether an argument's value meets expectations
return _super_check(pack(arg1, arg2), (func_name, arg_name, attr_name), op, fmt, msg, True)  # Delegate the check to _super_check
def _type_check(func_name, arg1, arg2, arg_name='', op="isinstance", fmt="type", msg=None):  # Define a function that checks whether an argument's type meets expectations
return _super_check(pack(arg1, arg2), (func_name, arg_name), op, fmt, msg, False)  # Delegate the check to _super_check
def _mstype_check(func_name, arg, arg_mstype, arg_name='a'):  # Define a function that checks whether an argument's mstype meets expectations
return _super_check((F.typeof(arg), arg_mstype), pack(arg, arg_mstype, func_name, arg_name), "isinstance", "mstype",
None, False)  # Delegate the check to _super_check
def _dtype_check(func_name, arg, arg_dtype, arg_name='a'):  # Define a function that checks whether an argument's data type meets expectations
return _super_check((F.dtype(arg), arg_dtype), (func_name, arg_name, "data type"), "in", "attr", None, False)  # Delegate the check to _super_check
def _square_check(func_name, arg, arg_name='a'):  # Define a function that checks whether an argument is a square matrix
arg_shape = arg.shape  # Get the shape of the argument
_super_check((len(arg_shape), 2), (func_name, arg_name, 'dimension'), '==', 'attr', None, True)  # Check that the argument is 2-dimensional
_super_check(arg_shape, (func_name, arg_name), '==', 'square', None, True)  # Check that the argument's shape is square
return arg  # Return the checked argument
def _solve_check(func_name, arg1, arg2, arg1_name='a', arg2_name='b', sparse=False):  # Define a function that checks the inputs when solving a linear system
arg1_shape, arg1_dtype = arg1.shape, F.dtype(arg1)  # Get the shape and dtype of the first argument
arg2_shape, arg2_dtype = arg2.shape, F.dtype(arg2)  # Get the shape and dtype of the second argument
_square_check(func_name, arg1, arg1_name)  # Check that the first argument is a square matrix
_super_check((len(arg2_shape), (1, 2)), (func_name, arg2_name, 'dimension'), 'in', 'attr', None, True)  # Check that the second argument is 1- or 2-dimensional
_super_check((arg1_shape, arg2_shape), (func_name, arg1_name, arg2_name, sparse), 'solve', 'solve', None, True)  # Check that the two shapes are compatible for solving a linear system
_super_check((arg1_dtype, arg2_dtype), (func_name, arg1_name, arg2_name, 'data type'), '==', 'match', None, False)  # Check that the two data types match
return arg1, arg2  # Return the two checked arguments
def _sparse_check(func_name, a, m, b, x0):  # Define a function that checks the inputs of the sparse solvers cg, bicgstab and gmres
"""Used for cg, bicgstab and gmres method.""" """Used for cg, bicgstab and gmres method."""
def _check_right(arg, arg_name): def _check_right(arg, arg_name): # 定义一个内部函数用于检查右侧参数b或x0
if arg is None: if arg is None: # 如果参数为None
return mnp.zeros_like(b) # x0 same as b return mnp.zeros_like(b) # x0 same as b # 返回与b形状相同元素为零的tensor
# Type # Type
_mstype_check(func_name, arg, mstype.tensor_type, arg_name) _mstype_check(func_name, arg, mstype.tensor_type, arg_name) # 检查参数的mstype是否为tensor_type
# DType # DType
_dtype_check(func_name, arg, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], arg_name) _dtype_check(func_name, arg, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], arg_name) # 检查参数的数据类型是否在指定的类型列表中
# Shape # Shape
if (arg.ndim != 1 and arg.ndim != 2) or (arg.ndim == 2 and arg.shape[1] != 1): if (arg.ndim != 1 and arg.ndim != 2) or (arg.ndim == 2 and arg.shape[1] != 1): # 检查参数的形状是否为(N,)或(N, 1)
_raise_value_error("For: '", func_name, "', the shape of '", arg_name, _raise_value_error("For: '", func_name, "', the shape of '", arg_name, # 如果不满足条件,抛出错误
"' should be like (N,) or (N, 1), bug got ", arg.shape, ".") "' should be like (N,) or (N, 1), bug got ", arg.shape, ".")
return arg return arg # 返回检查后的参数
b = _check_right(b, 'b') b = _check_right(b, 'b') # 检查参数b
x0 = _check_right(x0, 'x0') x0 = _check_right(x0, 'x0') # 检查参数x0
def _check_left(arg, arg_name): def _check_left(arg, arg_name): # 定义一个内部函数用于检查左侧参数a或m
if arg is None: if arg is None: # 如果参数为None
return lambda x: x # identity function return lambda x: x # identity function # 返回一个恒等函数
# Type # Type
_mstype_check(func_name, arg, [mstype.function_type, mstype.tensor_type, mstype.csr_tensor_type], arg_name) _mstype_check(func_name, arg, [mstype.function_type, mstype.tensor_type, mstype.csr_tensor_type], arg_name) # 检查参数的mstype是否为function_type, tensor_type或csr_tensor_type
if _callable_const(F.typeof(arg)): if _callable_const(F.typeof(arg)): # 如果参数是一个可调用的常量(即函数)
return arg return arg # 返回该参数
# DType # DType
if isinstance(arg, CSRTensor): if isinstance(arg, CSRTensor): # 如果参数是CSRTensor类型
_dtype_check(func_name, arg.indptr, [mstype.int32], arg_name) _dtype_check(func_name, arg.indptr, [mstype.int32], arg_name) # 检查CSRTensor的indptr数据类型是否为int32
_dtype_check(func_name, arg.indices, [mstype.int32], arg_name) _dtype_check(func_name, arg.indices, [mstype.int32], arg_name) # 检查CSRTensor的indices数据类型是否为int32
_dtype_check(func_name, arg.values, [mstype.float32], arg_name) _dtype_check(func_name, arg.values, [mstype.float32], arg_name) # 检查CSRTensor的values数据类型是否为float32
else: else: # 如果参数不是CSRTensor类型
_dtype_check(func_name, arg, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], arg_name) _dtype_check(func_name, arg, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], arg_name) # 检查参数的数据类型是否在指定的类型列表中
# Shape # Shape
_solve_check(func_name, arg, b, arg_name, 'b', True) _solve_check(func_name, arg, b, arg_name, 'b', True) # 检查参数a和b的形状是否可以用于求解线性方程组
_solve_check(func_name, arg, x0, arg_name, 'x0', True) _solve_check(func_name, arg, x0, arg_name, 'x0', True) # 检查参数a和x0的形状是否可以用于求解线性方程组
if isinstance(arg, Tensor) and F.dtype(arg) in (mstype.int32, mstype.int64): if isinstance(arg, Tensor) and F.dtype(arg) in (mstype.int32, mstype.int64): # 如果参数是Tensor类型且数据类型为int32或int64
arg = F.cast(arg, mstype.float64) arg = F.cast(arg, mstype.float64) # 将其转换为float64类型
return arg return arg # 返回检查后的参数
a = _check_left(a, 'A') a = _check_left(a, 'A') # 检查参数a
m = _check_left(m, 'M') m = _check_left(m, 'M') # 检查参数m
b = b.flatten() b = b.flatten() # 将参数b展平为一维的tensor
x0 = x0.flatten() x0 = x0.flatten() # 将参数x0展平为一维的tensor
if F.dtype(b) in (mstype.int32, mstype.int64): if F.dtype(b) in (mstype.int32, mstype.int64): # 如果参数b的数据类型为int32或int64
b = F.cast(b, mstype.float64) b = F.cast(b, mstype.float64) # 将其转换为float64类型
x0 = F.cast(x0, mstype.float64) x0 = F.cast(x0, mstype.float64) # 将其转换为float64类型
return a, m, b, x0 return a, m, b, x0 # 返回检查并转换后的参数
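These checks back the sparse solvers in mindspore.scipy. A hedged usage sketch, assuming mindspore.scipy.sparse.linalg.cg is available with an (x, info) return as in MindSpore 1.x (the module path and return convention are assumptions here):

import numpy as onp
from mindspore import Tensor
from mindspore.scipy.sparse.linalg import cg

a = Tensor(onp.array([[4.0, 1.0], [1.0, 3.0]], dtype=onp.float32))
b = Tensor(onp.array([1.0, 2.0], dtype=onp.float32))
x, info = cg(a, b)   # _sparse_check validates a, b and the optional M/x0 before solving
print(x)             # expected to be close to [0.0909, 0.6364]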
