董昊千代码注读

branch-donghaoqian
donghaoqian 2 months ago
parent 69ae5ccbe1
commit db2c8b9de0

@ -1,47 +1,117 @@
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# 代码版权声明说明此代码由华为技术有限公司在2020-2022年间开发
# Licensed under the Apache License, Version 2.0 (the "License");
# 说明此代码使用Apache License 2.0版本的许可证
# you may not use this file except in compliance with the License.
# 说明除非遵守许可证,否则不得使用此文件
# You may obtain a copy of the License at
#
# 提供许可证的获取地址
# http://www.apache.org/licenses/LICENSE-2.0
#
# 许可证的具体地址
# Unless required by applicable law or agreed to in writing, software
# 除非适用法律要求或书面同意
# distributed under the License is distributed on an "AS IS" BASIS,
# 许可证在“现状”基础上进行分发
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 不附带任何形式的明示或暗示的担保或条件
# See the License for the specific language governing permissions and
# 请参阅许可证了解特定的权限和
# limitations under the License.
# 限制条件
# ============================================================================
# 标准分割线,通常用于分隔许可证部分与代码部分
"""cell"""
# 文档字符串模块的名称为cell
import gc
# 导入垃圾回收模块,用于管理内存
import inspect
# 导入inspect模块用于获取活对象的信息
import os
# 导入os模块用于与操作系统进行交互
import time
# 导入time模块用于处理时间相关操作
from collections import OrderedDict
# 从collections模块导入OrderedDict类用于创建有序字典
from types import FunctionType, MethodType
# 从types模块导入FunctionType和MethodType类用于类型检查
import numpy
# 导入numpy模块用于科学计算
from mindspore._checkparam import args_type_check
# 从mindspore._checkparam模块导入args_type_check函数用于检查函数参数的类型
from mindspore import log as logger
# 从mindspore模块导入log模块并命名为logger用于日志记录
from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
# 从mindspore.common.parameter模块导入PARAMETER_NAME_DEFAULT常量用于默认参数名称
from mindspore.common.hook_handle import HookHandle
# 从mindspore.common.hook_handle模块导入HookHandle类用于管理钩子处理
from mindspore.context import ParallelMode
# 从mindspore.context模块导入ParallelMode类用于并行模式配置
from mindspore.ops.composite import Shard
# 从mindspore.ops.composite模块导入Shard类用于分片操作
from .. import context
# 导入相对路径的context模块用于上下文配置
from .._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
# 从相对路径的_c_expression模块导入多个函数和类用于初始化管道、更新函数图超参数、Cell的基础类、函数图类、混合精度类型
from .._checkparam import Validator
# 从相对路径的_checkparam模块导入Validator类用于参数验证
from ..common import dtype as mstype
# 从相对路径的common模块导入dtype模块并重命名为mstype用于数据类型定义
from ..common.api import _cell_graph_executor, _pynative_executor, _check_all_tensor, cells_compile_cache
# 从相对路径的common.api模块导入多个函数和类用于单元图执行器、原生模式执行器、检查所有张量、编译缓存
from ..common.parameter import Parameter, ParameterTuple
# 从相对路径的common.parameter模块导入Parameter类和ParameterTuple类用于参数和参数元组
from ..common.variable import Variable
# 从相对路径的common.variable模块导入Variable类用于变量表示
from ..common.tensor import Tensor, CSRTensor, COOTensor
# 从相对路径的common.tensor模块导入Tensor类、CSRTensor类和COOTensor类用于张量表示
from ..ops.operations import Cast
# 从相对路径的ops.operations模块导入Cast类用于类型转换操作
from ..ops.primitive import Primitive
# 从相对路径的ops.primitive模块导入Primitive类用于基础操作
from ..ops.operations import _inner_ops as inner
from ..parallel._tensor import _load_tensor_by_layout
# 从相对路径的ops.operations模块导入_inner_ops并重命名为inner用于内部操作
from ..parallel._tensor import _load_tensor_by_layout
# 从相对路径的parallel._tensor模块导入_load_tensor_by_layout函数用于按布局加载张量
class Cell(Cell_):
# 定义Cell类继承自Cell_类这是MindSpore中神经网络的基本构建单元
"""
The basic building block of neural networks in MindSpore. The model or neural network layer should inherit this
base class.
@ -81,6 +151,8 @@ class Cell(Cell_):
... # the parameter's name will be 'net.weight'.
[Parameter (name=weight, shape=(240, 120, 4, 4), dtype=Float32, requires_grad=True)]
"""
# 类文档字符串解释Cell类的作用、继承关系、参数、支持平台及示例
class _CellGuard:
"""Detecting whether the cell is a top-level cell with the 'with statement'."""

@ -22,61 +22,68 @@ from ...common.api import ms_function
class _FirstGrad(Cell):
# 计算第一个梯度的类
def __init__(self, fn):
super(_FirstGrad, self).__init__()
self.first_grad_op = C.GradOperation(sens_param=True, get_all=True)
self.fn = fn
def construct(self, u, first_grad_input):
# 构造方法,用于计算梯度
return self.first_grad_op(self.fn)(*first_grad_input, u)
class _JvpFirstGrad(Cell):
# 计算Jacobian-Vector-Product的第一个梯度的类
def __init__(self):
super(_JvpFirstGrad, self).__init__()
self.first_grad_op = C.GradOperation(sens_param=True, get_all=True)
def construct(self, u, fn, first_grad_input):
# 构造方法用于计算JVP的第一个梯度
return self.first_grad_op(fn)(*first_grad_input, u)
class _FirstGradSingleValue(Cell):
# 计算单值梯度的类
def __init__(self, fn):
super(_FirstGradSingleValue, self).__init__()
self.first_grad_single_value_op = C.GradOperation(sens_param=True)
self.fn = fn
def construct(self, u, first_grad_single_value_input):
# 构造方法,用于计算单值梯度
return self.first_grad_single_value_op(self.fn)(*first_grad_single_value_input, u)
class _JvpFirstGradSingleValue(Cell):
# 计算Jacobian-Vector-Product的单值梯度的类
def __init__(self):
super(_JvpFirstGradSingleValue, self).__init__()
self.first_grad_single_value_op = C.GradOperation(sens_param=True)
def construct(self, u, fn, first_grad_single_value_input):
# 构造方法用于计算JVP的单值梯度
return self.first_grad_single_value_op(fn)(*first_grad_single_value_input, u)
class Jvp(Cell):
"""
Compute the jacobian-vector-product of the given fn. Jvp is equivalent to forward mode autodiff.
计算给定fn的雅可比向量积Jvp等同于前向模式自动微分
Args:
fn (Cell): The fn that takes Tensor inputs and returns a tuple of Tensors or a Tensor.
fn (Cell): 接受Tensor输入并返回Tensor元组或Tensor的fn
Inputs:
- **inputs** (Tensors) - The inputs to `fn`.
- **v** (Tensors or Tuple of Tensors) - The vector for which the Jacobian vector product is computed.
Must have the same size as the input of `fn`.
- **inputs** (Tensors) - `fn`的输入
- **v** (Tensors Tensor元组) - 用于计算雅可比向量积的向量
必须与`fn`的输入大小相同
Outputs:
A tuple with 2 Tensors or Tuple of Tensors:
包含2个Tensors或Tensor元组的元组
- **net_output** (Tensors or Tuple of Tensors) - The output of `fn(inputs)`.
- **jvp** (Tensors or Tuple of Tensors) - The result of the jacobian vector product.
- **net_output** (Tensors Tensor元组) - `fn(inputs)`的输出
- **jvp** (Tensors Tensor元组) - 雅可比向量积的结果
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
@ -113,6 +120,7 @@ class Jvp(Cell):
@ms_function
def construct(self, *args):
# 构造方法用于计算JVP
jvp_input = args[0:-1]
v = args[-1]
output = self.fn(*jvp_input)
@ -135,8 +143,8 @@ class Jvp(Cell):
class _JvpInner(Cell):
"""
Compute the jacobian-vector-product of the given network. Jvp is equivalent to forward mode autodiff.
This class implements the inner process of function jvp.
计算给定网络的雅可比向量积Jvp等同于前向模式自动微分
该类实现了JVP的内部过程
"""
def __init__(self):
super(_JvpInner, self).__init__()
@ -152,6 +160,7 @@ class _JvpInner(Cell):
self.tuple_len = Primitive("tuple_len")
def construct(self, *args):
# 构造方法用于计算内部JVP
fn = args[0]
v = args[1]
jvp_input = args[2:]
@ -175,22 +184,21 @@ class _JvpInner(Cell):
class Vjp(Cell):
"""
Computes the dot product between a vector `v` and the Jacobian of the given fn at the point
given by the inputs.
计算给定向量`v`与给定fn在输入点处的雅可比的点积
Args:
fn (Cell): The fn that takes Tensor inputs and returns a tuple of Tensors or a Tensor.
fn (Cell): 接受Tensor输入并返回Tensor元组或Tensor的fn
Inputs:
- **inputs** (Tensors) - The inputs to `fn`. Must be a tuple or a list.
- **v** (Tensors or Tuple of Tensors) - The vector for which the vector Jacobian product is computed.
Must have the same size as the output of `fn`.
- **inputs** (Tensors) - `fn`的输入必须是元组或列表
- **v** (Tensors Tensor元组) - 用于计算向量雅可比积的向量
必须与`fn`的输出大小相同
Outputs:
A tuple with 2 Tensors or Tuple of Tensors:
包含2个Tensors或Tensor元组的元组
- **net_output** (Tensors or Tuple of Tensors) - The output of `fn(inputs)`.
- **vjp** (Tensors or Tuple of Tensors) - The result of the dot product.
- **net_output** (Tensors Tensor元组) - `fn(inputs)`的输出
- **vjp** (Tensors Tensor元组) - 点积的结果
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
@ -226,6 +234,7 @@ class Vjp(Cell):
@ms_function
def construct(self, *args):
# 构造方法用于计算VJP
front_input = args[0:-1]
output = self.fn(*front_input)
if self.tuple_len(front_input) == 1:
@ -237,8 +246,8 @@ class Vjp(Cell):
class _VjpInner(Cell):
"""
Computes the dot product between a vector `v` and the Jacobian of the given network at the point
given by the inputs. This class implements the inner process of function vjp.
计算给定向量`v`与给定网络在输入点处的雅可比的点积
该类实现了VJP的内部过程
"""
def __init__(self):
@ -248,6 +257,7 @@ class _VjpInner(Cell):
self.tuple_len = Primitive("tuple_len")
def construct(self, *args):
# 构造方法用于计算内部VJP
fn = args[0]
front_input = args[1:-1]
input_with_v = args[1:]

File diff suppressed because it is too large Load Diff

@ -85,58 +85,84 @@ def array(obj, dtype=None, copy=True, ndmin=0):
>>> print(np.array([1,2,3]))
[1 2 3]
"""
if dtype is not None:
if dtype is not None: # 如果用户指定了数据类型则检查并转换为mindspore的数据类型
dtype = _check_dtype(dtype)
res = asarray(obj, dtype)
res = asarray(obj, dtype) # 将输入对象转换为tensor
if ndmin > res.ndim:
if res.size == 0:
if ndmin > res.ndim: # 如果用户指定的最小维度大于转换后的tensor维度则在tensor的前面添加维度
if res.size == 0: # 如果tensor为空抛出异常
_raise_value_error("Empty tensor cannot be expanded beyond the current dimension.")
res = _expand(res, ndmin)
res = _expand(res, ndmin) # 扩展tensor的维度
if copy and isinstance(obj, Tensor):
if copy and isinstance(obj, Tensor): # 如果copy为True且输入对象已经是tensor则创建其副本
res = copy_(res)
elif dtype is not None and dtype != res.dtype:
elif dtype is not None and dtype != res.dtype: # 如果用户指定了数据类型且与转换后的tensor数据类型不同则转换数据类型
res = res.astype(dtype)
return res
return res # 返回最终生成的tensor
@constexpr
def asarray_const(a, dtype=None):
# 标记此函数为constexpr意味着它是一个编译时常量函数
"""Converts the input to tensor. Note here `a` cannot be tensor itself."""
# 文档字符串解释函数作用将输入转换为张量注意这里的a不能是张量本身
_check_input_for_asarray(a)
# 检查输入a是否符合asarray函数的输入要求
if dtype is not None:
# 如果dtype不为None
dtype = _check_dtype(dtype)
# 检查并确认dtype是一个有效的数据类型
if isinstance(a, (float, int, bool)) and dtype is None:
# 如果a是float、int或bool类型并且dtype未指定
dtype = _get_dtype_from_scalar(a)
# 从标量a中获取数据类型并赋值给dtype
if isinstance(a, (list, tuple)):
# 如果a是list或tuple类型
# Convert all tuple/nested tuples to lists
a = _deep_list(a)
# 将所有tuple及其嵌套的tuple转换为list
# Convert all tensor sub-elements to numpy arrays
a = _deep_tensor_to_nparray(a)
# 将所有tensor子元素转换为numpy数组
a = onp.asarray(a)
# 使用numpy的asarray函数将a转换为numpy数组
if a.dtype is onp.dtype('object'):
# 如果转换后的numpy数组的数据类型是object
raise ValueError('Input array must have the same size across all dimensions.')
# 抛出ValueError表示输入数组在所有维度上必须具有相同的大小
# If dtype is not specified, we keep consistent with numpy decision
# only exceptions are: we use int/float32
if dtype is None:
# 如果dtype未指定
dtype = mstype.pytype_to_dtype(a.dtype)
# 将numpy数组的数据类型转换为mindspore的dtype
if dtype == mstype.float64:
# 如果dtype是float64
dtype = mstype.float32
# 将dtype改为float32
elif dtype == mstype.int64:
# 如果dtype是int64
dtype = mstype.int32
# 将dtype改为int32
if isinstance(a, onp.ndarray) and dtype is None:
# 如果a是numpy数组并且dtype未指定
if a.dtype is onp.dtype('object'):
# 如果numpy数组的数据类型是object
raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.")
# 抛出TypeError表示输入数据包含不支持的元素
dtype = mstype.pytype_to_dtype(a.dtype)
# 将numpy数组的数据类型转换为mindspore的dtype
a = Tensor.from_numpy(a)
# 将numpy数组转换为mindspore的Tensor
return Tensor(a, dtype=dtype)
# 返回一个具有指定dtype的Tensor
def asarray(a, dtype=None):
@ -168,29 +194,46 @@ def asarray(a, dtype=None):
[1 2 3]
"""
if dtype is not None:
# 如果dtype不为None
dtype = _check_dtype(dtype)
# 检查并确认dtype是一个有效的数据类型
if isinstance(a, Tensor):
# 如果a是Tensor类型
if dtype is None or dtype == a.dtype:
# 如果dtype未指定或指定的数据类型与a的数据类型相同
return a
# 直接返回a
return a.astype(dtype)
# 如果指定的数据类型与a的数据类型不同将a的数据类型转换为指定的dtype并返回
return asarray_const(a, dtype)
# 如果a不是Tensor类型调用asarray_const函数将其转换为Tensor并返回
@constexpr
def asfarray_const(a, dtype=mstype.float32):
"""Converts the input to tensor. Note here `a` cannot be tensor itself."""
# 文档字符串解释函数作用将输入转换为张量注意这里的a不能是张量本身
_check_input_for_asarray(a)
# 检查输入a是否符合asarray函数的输入要求
if isinstance(a, (list, tuple)):
# 如果a是list或tuple类型
# Convert all tuple/nested tuples to lists
a = _deep_list(a)
# 将所有tuple及其嵌套的tuple转换为list
# Convert all tensor sub-elements to numpy arrays
a = _deep_tensor_to_nparray(a)
# 将所有tensor子元素转换为numpy数组
a = onp.asarray(a)
# 使用numpy的asarray函数将a转换为numpy数组
if a.dtype is onp.dtype('object'):
# 如果转换后的numpy数组的数据类型是object
raise ValueError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.")
# 抛出ValueError表示输入数组在所有维度上必须具有相同的大小
a = Tensor.from_numpy(a)
# 将numpy数组转换为mindspore的Tensor
return Tensor(a, dtype)
# 返回一个具有指定dtype的Tensor
def asfarray(a, dtype=mstype.float32):
@ -206,7 +249,6 @@ def asfarray(a, dtype=mstype.float32):
be in format of np.int32, or \'int32\'. If dtype is :class:`None`, the data type
of the new tensor will be inferred from `a`. Default is :class:`mindspore.float32`.
Returns:
Tensor, generated tensor with the specified float dtype.
@ -223,16 +265,24 @@ def asfarray(a, dtype=mstype.float32):
[1. 2. 3.]
"""
if dtype is None:
# 如果dtype未指定
return asarray(a)
# 调用asarray函数将a转换为Tensor并返回
dtype = _check_dtype(dtype)
# 检查并确认dtype是一个有效的数据类型
if dtype not in (mstype.float16, mstype.float32, mstype.float64):
# 如果dtype不是float16、float32或float64
dtype = mstype.float32
# 将dtype改为float32
if isinstance(a, Tensor):
# 如果a是Tensor类型
return a.astype(dtype)
# 将a的数据类型转换为指定的dtype并返回
return asfarray_const(a, dtype)
# 如果a不是Tensor类型调用asfarray_const函数将其转换为Tensor并返回
def copy_(a):
@ -261,7 +311,9 @@ def copy_(a):
[1. 1.]]
"""
a = asarray(a)
# 使用asarray函数将a转换为Tensor
return a.copy()
# 返回a的副本
def ones(shape, dtype=mstype.float32):
@ -290,11 +342,17 @@ def ones(shape, dtype=mstype.float32):
[1. 1.]]
"""
shape = _check_shape(shape)
# 检查并确认shape是一个有效的形状
dtype = _check_dtype(dtype)
# 检查并确认dtype是一个有效的数据类型
if _is_shape_empty(shape):
# 如果shape表示的形状是空的
return full(shape, 1.0, dtype)
# 使用full函数创建一个指定形状、数据类型并用1.0填充的Tensor
output = F.fill(dtype, shape, 1)
# 使用F.fill函数创建一个指定形状、数据类型并用1填充的Tensor
return output
# 返回创建的Tensor
def zeros(shape, dtype=mstype.float32):
@ -323,11 +381,17 @@ def zeros(shape, dtype=mstype.float32):
[0. 0.]]
"""
shape = _check_shape(shape)
# 检查并确认shape是一个有效的形状
dtype = _check_dtype(dtype)
# 检查并确认dtype是一个有效的数据类型
if _is_shape_empty(shape):
# 如果shape表示的形状是空的
return full(shape, 0.0, dtype)
# 使用full函数创建一个指定形状、数据类型并用0.0填充的Tensor
output = F.fill(dtype, shape, 0)
# 使用F.fill函数创建一个指定形状、数据类型并用0填充的Tensor
return output
# 返回创建的Tensor
def full(shape, fill_value, dtype=None):
@ -360,24 +424,42 @@ def full(shape, fill_value, dtype=None):
[True True]]
"""
shape = _check_shape(shape)
# 检查并确认shape是一个有效的形状
if not isinstance(fill_value, ARRAY_TYPES):
# 如果fill_value不是int、float、bool、list、tuple、Tensor类型
_raise_type_error("fill value should be int, float, bool, list, tuple, Tensor, but got", fill_value)
# 抛出TypeError表示fill_value类型不支持
if dtype is not None:
# 如果dtype不为None
dtype = _check_dtype(dtype)
# 检查并确认dtype是一个有效的数据类型
else:
# 如果dtype为None
if isinstance(fill_value, (int, float, bool)):
# 如果fill_value是int、float或bool类型
dtype = _get_dtype_from_scalar(fill_value)
# 从标量fill_value中获取数据类型并赋值给dtype
if isinstance(fill_value, Tensor):
# 如果fill_value是Tensor类型
dtype = fill_value.dtype
# 从Tensor fill_value中获取数据类型并赋值给dtype
if not _is_shape_empty(shape):
# 如果shape表示的形状不是空的
if isinstance(fill_value, (int, float, bool)):
# 如果fill_value是int、float或bool类型
return F.fill(dtype, shape, fill_value)
# 使用F.fill函数创建一个指定形状、数据类型并用fill_value填充的Tensor
if isinstance(fill_value, (list, tuple)):
# 如果fill_value是list或tuple类型
fill_value = asarray_const(fill_value)
# 使用asarray_const函数将fill_value转换为Tensor
return broadcast_to(fill_value, shape)
# 使用broadcast_to函数将fill_value广播到指定的shape并返回结果
# if shape contains zero, use c.Tensor()
return _convert_64_to_32(empty_compile(dtype, shape))
# 如果shape包含零使用empty_compile函数创建一个空的Tensor并使用_convert_64_to_32函数将数据类型从float64转换为float32
@constexpr

@ -22,185 +22,171 @@ from ..common import dtype as mstype
from .utils_const import _type_convert, _raise_value_error, _callable_const, _super_check, pack
from ..ops.composite import GradOperation
grad = GradOperation(get_all=False, get_by_list=False, sens_param=False)
_eps_net = ops.Eps()
grad = GradOperation(get_all=False, get_by_list=False, sens_param=False) # 定义一个求梯度操作,设置为只求第一个参数的梯度,不通过列表获取,且不使用敏感参数
_eps_net = ops.Eps() # 定义一个计算数值精度的操作
def _convert_64_to_32(tensor):
def _convert_64_to_32(tensor): # 定义一个函数将输入的tensor从float64或int64类型转换为float32或int32类型
"""Convert Tensor with float64/int64 types to float32/int32."""
if tensor.dtype == mstype.float64:
return tensor.astype("float32")
if tensor.dtype == mstype.int64:
return tensor.astype("int32")
return tensor
if tensor.dtype == mstype.float64: # 如果tensor的数据类型是float64
return tensor.astype("float32") # 将其转换为float32类型
if tensor.dtype == mstype.int64: # 如果tensor的数据类型是int64
return tensor.astype("int32") # 将其转换为int32类型
return tensor # 如果不是以上两种类型则直接返回原tensor
def _to_tensor(*args, dtype=None):
def _to_tensor(*args, dtype=None): # 定义一个函数将输入的参数转换为tensor
"""Returns each input as Tensor"""
res = ()
for arg in args:
if isinstance(arg, (int, float, bool, list, tuple)):
arg = _type_convert(Tensor, arg)
if dtype is None:
arg = _convert_64_to_32(arg)
else:
arg = arg.astype(dtype)
elif not isinstance(arg, Tensor):
_raise_value_error("Expect input to be array like.")
res += (arg,)
if len(res) == 1:
return res[0]
return res
def _to_scalar(arr):
res = () # 初始化一个空元组用于存储结果
for arg in args: # 遍历每一个输入参数
if isinstance(arg, (int, float, bool, list, tuple)): # 如果参数是整数、浮点数、布尔值、列表或元组
arg = _type_convert(Tensor, arg) # 将其转换为Tensor类型
if dtype is None: # 如果没有指定dtype
arg = _convert_64_to_32(arg) # 调用_convert_64_to_32函数进行类型转换
else: # 如果指定了dtype
arg = arg.astype(dtype) # 将tensor转换为指定的dtype
elif not isinstance(arg, Tensor): # 如果参数不是Tensor类型
_raise_value_error("Expect input to be array like.") # 抛出错误,提示输入应为数组形式
res += (arg,) # 将转换后的tensor添加到结果元组中
if len(res) == 1: # 如果结果元组中只有一个元素
return res[0] # 直接返回该元素
return res # 否则返回整个元组
def _to_scalar(arr): # 定义一个函数将输入的Tensor或ndarray转换为标量值
"""Convert a scalar Tensor or ndarray to a scalar."""
if isinstance(arr, (int, float, bool)):
return arr
if isinstance(arr, Tensor):
if arr.shape:
return arr
return arr.asnumpy().item()
raise ValueError("{} are not supported.".format(type(arr)))
def _eps(x):
return _eps_net(x[(0,) * x.ndim])
def _safe_normalize(x, threshold=None):
if isinstance(arr, (int, float, bool)): # 如果输入参数是整数、浮点数或布尔值
return arr # 直接返回该参数
if isinstance(arr, Tensor): # 如果输入参数是Tensor类型
if arr.shape: # 如果tensor的形状不是空的即不是标量
return arr # 返回整个tensor
return arr.asnumpy().item() # 如果是标量将其转换为numpy数组并返回标量值
raise ValueError("{} are not supported.".format(type(arr))) # 如果输入参数不是以上两种类型,抛出错误,提示不支持该类型
def _eps(x): # 定义一个函数计算输入tensor的数值精度
return _eps_net(x[(0,) * x.ndim]) # 使用_ops.Eps操作计算数值精度x[(0,) * x.ndim]确保输入的是一个标量
def _safe_normalize(x, threshold=None): # 定义一个函数对输入的tensor进行归一化如果归一化结果非常小则设置为零
"""Normalize method that cast very small results to zero."""
x_sum2 = F.reduce_sum(F.pows(x, 2.0))
norm = F.pows(x_sum2, 1. / 2.0)
if threshold is None:
if x.dtype in (mstype.float32, mstype.float64):
# pick the first element of x to get the eps
x_sum2 = F.reduce_sum(F.pows(x, 2.0)) # 计算tensor元素平方的和
norm = F.pows(x_sum2, 1. / 2.0) # 计算上述和的平方根得到norm
if threshold is None: # 如果没有指定threshold
if x.dtype in (mstype.float32, mstype.float64): # 如果tensor的dtype是float32或float64
# pick the first element of x to get the eps # 获取eps来作为threshold
threshold = _eps(x)
else:
threshold = 0
use_norm = greater(norm, threshold)
x_norm = x / norm
normalized_x = where(use_norm, x_norm, zeros_like(x))
norm = where(use_norm, norm, zeros_like(norm))
return normalized_x, norm
def sparse_dot(a, b):
else: # 如果tensor的dtype不是float32或float64
threshold = 0 # 设置threshold为0
use_norm = greater(norm, threshold) # 比较norm和threshold得到一个布尔mask
x_norm = x / norm # 使用norm对tensor进行归一化
normalized_x = where(use_norm, x_norm, zeros_like(x)) # 如果norm大于threshold则使用归一化后的tensor否则使用零
norm = where(use_norm, norm, zeros_like(norm)) # 如果norm大于threshold则保留norm否则使用零
return normalized_x, norm # 返回归一化后的tensor及其对应的norm
def sparse_dot(a, b): # 定义一个函数计算稀疏矩阵CSRTensor与向量generic Tensor的点积
"""Returns the dot product of CSRTensor and generic Tensor(vector)."""
b_aligned = F.reshape(b, (b.shape[0], -1))
res = F.csr_mv(a, b_aligned)
res = F.reshape(res, a.shape[:-1] + b.shape[1:])
return res
b_aligned = F.reshape(b, (b.shape[0], -1)) # 将向量b重塑为(b.shape[0], -1)的形状,使其可以与稀疏矩阵相乘
res = F.csr_mv(a, b_aligned) # 使用csr_mv操作计算稀疏矩阵a与向量b_aligned的点积
res = F.reshape(res, a.shape[:-1] + b.shape[1:]) # 将计算结果重新塑形为a.shape[:-1] + b.shape[1:]的形状
return res # 返回结果
def _normalize_matvec(f):
def _normalize_matvec(f): # 定义一个函数,对输入的矩阵或向量进行归一化处理
"""Normalize an argument for computing matrix-vector products."""
if isinstance(f, Tensor):
return F.partial(dot, f)
if isinstance(f, CSRTensor):
return F.partial(sparse_dot, f)
return f
if isinstance(f, Tensor): # 如果输入参数是Tensor类型
return F.partial(dot, f) # 返回一个带有矩阵参数f的dot函数的部分应用
if isinstance(f, CSRTensor): # 如果输入参数是CSRTensor类型
return F.partial(sparse_dot, f) # 返回一个带有稀疏矩阵参数f的sparse_dot函数的部分应用
def _norm(x, ord_=None):
if ord_ == mnp.inf:
res = mnp.max(mnp.abs(x))
else:
res = mnp.sqrt(mnp.sum(x ** 2))
return res
return f # 如果输入参数不是上述两种类型,则直接返回原参数
def _norm(x, ord_=None): # 定义一个函数计算输入tensor的范数
if ord_ == mnp.inf: # 如果ord_为无穷大实际为最大值
res = mnp.max(mnp.abs(x)) # 返回tensor绝对值的最大值
else: # 如果ord_不是无穷大
res = mnp.sqrt(mnp.sum(x ** 2)) # 返回tensor元素平方和的平方根即L2范数
return res # 返回结果
def _nd_transpose(a):
dims = a.ndim
if dims < 2:
_raise_value_error("to do _nd_transpose for input a's ndim is not greater or equal to 2d, which is invalid.")
axes = ops.make_range(0, dims)
axes = axes[:-2] + (axes[-1],) + (axes[-2],)
return ops.transpose(a, axes)
def _nd_transpose(a): # 定义一个函数对输入的tensor进行转置最后一个维度与倒数第二个维度互换
dims = a.ndim # 获取tensor的维度数
if dims < 2: # 如果tensor的维度小于2
_raise_value_error("to do _nd_transpose for input a's ndim is not greater or equal to 2d, which is invalid.") # 抛出错误提示输入tensor的维度应大于等于2
axes = ops.make_range(0, dims) # 生成一个从0到tensor维度数的序列
axes = axes[:-2] + (axes[-1],) + (axes[-2],) # 将序列中的倒数第二个和最后一个元素互换位置
return ops.transpose(a, axes) # 使用transpose操作对tensor进行转置
def _value_check(func_name, arg1, arg2, arg_name='', attr_name='', op="in", fmt="attr", msg=None): # 定义一个函数,用于检查输入参数的值是否符合预期
return _super_check(pack(arg1, arg2), (func_name, arg_name, attr_name), op, fmt, msg, True) # 调用_super_check函数进行检查
def _value_check(func_name, arg1, arg2, arg_name='', attr_name='', op="in", fmt="attr", msg=None):
return _super_check(pack(arg1, arg2), (func_name, arg_name, attr_name), op, fmt, msg, True)
def _type_check(func_name, arg1, arg2, arg_name='', op="isinstance", fmt="type", msg=None): # 定义一个函数,用于检查输入参数的类型是否符合预期
return _super_check(pack(arg1, arg2), (func_name, arg_name), op, fmt, msg, False) # 调用_super_check函数进行检查
def _type_check(func_name, arg1, arg2, arg_name='', op="isinstance", fmt="type", msg=None):
return _super_check(pack(arg1, arg2), (func_name, arg_name), op, fmt, msg, False)
def _mstype_check(func_name, arg, arg_mstype, arg_name='a'):
return _super_check((F.typeof(arg), arg_mstype), pack(arg, arg_mstype, func_name, arg_name), "isinstance", "mstype",
def _mstype_check(func_name, arg, arg_mstype, arg_name='a'): # 定义一个函数用于检查输入参数的mstype是否符合预期
return _super_check((F.typeof(arg), arg_mstype), pack(arg, arg_mstype, func_name, arg_name), "isinstance", "mstype", # 调用_super_check函数进行检查
None, False)
def _dtype_check(func_name, arg, arg_dtype, arg_name='a'): # 定义一个函数,用于检查输入参数的数据类型是否符合预期
return _super_check((F.dtype(arg), arg_dtype), (func_name, arg_name, "data type"), "in", "attr", # 调用_super_check函数进行检查
None, False)
def _dtype_check(func_name, arg, arg_dtype, arg_name='a'):
return _super_check((F.dtype(arg), arg_dtype), (func_name, arg_name, "data type"), "in", "attr", None, False)
def _square_check(func_name, arg, arg_name='a'):
arg_shape = arg.shape
_super_check((len(arg_shape), 2), (func_name, arg_name, 'dimension'), '==', 'attr', None, True)
_super_check(arg_shape, (func_name, arg_name), '==', 'square', None, True)
return arg
def _solve_check(func_name, arg1, arg2, arg1_name='a', arg2_name='b', sparse=False):
arg1_shape, arg1_dtype = arg1.shape, F.dtype(arg1)
arg2_shape, arg2_dtype = arg2.shape, F.dtype(arg2)
_square_check(func_name, arg1, arg1_name)
_super_check((len(arg2_shape), (1, 2)), (func_name, arg2_name, 'dimension'), 'in', 'attr', None, True)
_super_check((arg1_shape, arg2_shape), (func_name, arg1_name, arg2_name, sparse), 'solve', 'solve', None, True)
_super_check((arg1_dtype, arg2_dtype), (func_name, arg1_name, arg2_name, 'data type'), '==', 'match', None, False)
return arg1, arg2
def _sparse_check(func_name, a, m, b, x0):
def _square_check(func_name, arg, arg_name='a'): # 定义一个函数,用于检查输入参数是否为方阵
arg_shape = arg.shape # 获取输入参数的形状
_super_check((len(arg_shape), 2), (func_name, arg_name, 'dimension'), '==', 'attr', None, True) # 检查输入参数的维度是否为2
_super_check(arg_shape, (func_name, arg_name), '==', 'square', None, True) # 检查输入参数的形状是否为方阵
return arg # 返回检查后的参数
def _solve_check(func_name, arg1, arg2, arg1_name='a', arg2_name='b', sparse=False): # 定义一个函数,用于在求解线性方程组时检查输入参数
arg1_shape, arg1_dtype = arg1.shape, F.dtype(arg1) # 获取第一个参数的形状和数据类型
arg2_shape, arg2_dtype = arg2.shape, F.dtype(arg2) # 获取第二个参数的形状和数据类型
_square_check(func_name, arg1, arg1_name) # 检查第一个参数是否为方阵
_super_check((len(arg2_shape), (1, 2)), (func_name, arg2_name, 'dimension'), 'in', 'attr', None, True) # 检查第二个参数的维度是否为1或2
_super_check((arg1_shape, arg2_shape), (func_name, arg1_name, arg2_name, sparse), 'solve', 'solve', None, True) # 检查第一个参数和第二个参数的形状是否可以用于求解线性方程组
_super_check((arg1_dtype, arg2_dtype), (func_name, arg1_name, arg2_name, 'data type'), '==', 'match', None, False) # 检查第一个参数和第二个参数的数据类型是否匹配
return arg1, arg2 # 返回检查后的两个参数
def _sparse_check(func_name, a, m, b, x0): # 定义一个函数用于在稀疏求解器如cg, bicgstab和gmres中检查输入参数
"""Used for cg, bicgstab and gmres method."""
def _check_right(arg, arg_name):
if arg is None:
return mnp.zeros_like(b) # x0 same as b
def _check_right(arg, arg_name): # 定义一个内部函数用于检查右侧参数b或x0
if arg is None: # 如果参数为None
return mnp.zeros_like(b) # x0 same as b # 返回与b形状相同元素为零的tensor
# Type
_mstype_check(func_name, arg, mstype.tensor_type, arg_name)
_mstype_check(func_name, arg, mstype.tensor_type, arg_name) # 检查参数的mstype是否为tensor_type
# DType
_dtype_check(func_name, arg, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], arg_name)
_dtype_check(func_name, arg, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], arg_name) # 检查参数的数据类型是否在指定的类型列表中
# Shape
if (arg.ndim != 1 and arg.ndim != 2) or (arg.ndim == 2 and arg.shape[1] != 1):
_raise_value_error("For: '", func_name, "', the shape of '", arg_name,
if (arg.ndim != 1 and arg.ndim != 2) or (arg.ndim == 2 and arg.shape[1] != 1): # 检查参数的形状是否为(N,)或(N, 1)
_raise_value_error("For: '", func_name, "', the shape of '", arg_name, # 如果不满足条件,抛出错误
"' should be like (N,) or (N, 1), bug got ", arg.shape, ".")
return arg
return arg # 返回检查后的参数
b = _check_right(b, 'b')
x0 = _check_right(x0, 'x0')
b = _check_right(b, 'b') # 检查参数b
x0 = _check_right(x0, 'x0') # 检查参数x0
def _check_left(arg, arg_name):
if arg is None:
return lambda x: x # identity function
def _check_left(arg, arg_name): # 定义一个内部函数用于检查左侧参数a或m
if arg is None: # 如果参数为None
return lambda x: x # identity function # 返回一个恒等函数
# Type
_mstype_check(func_name, arg, [mstype.function_type, mstype.tensor_type, mstype.csr_tensor_type], arg_name)
if _callable_const(F.typeof(arg)):
return arg
_mstype_check(func_name, arg, [mstype.function_type, mstype.tensor_type, mstype.csr_tensor_type], arg_name) # 检查参数的mstype是否为function_type, tensor_type或csr_tensor_type
if _callable_const(F.typeof(arg)): # 如果参数是一个可调用的常量(即函数)
return arg # 返回该参数
# DType
if isinstance(arg, CSRTensor):
_dtype_check(func_name, arg.indptr, [mstype.int32], arg_name)
_dtype_check(func_name, arg.indices, [mstype.int32], arg_name)
_dtype_check(func_name, arg.values, [mstype.float32], arg_name)
else:
_dtype_check(func_name, arg, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], arg_name)
if isinstance(arg, CSRTensor): # 如果参数是CSRTensor类型
_dtype_check(func_name, arg.indptr, [mstype.int32], arg_name) # 检查CSRTensor的indptr数据类型是否为int32
_dtype_check(func_name, arg.indices, [mstype.int32], arg_name) # 检查CSRTensor的indices数据类型是否为int32
_dtype_check(func_name, arg.values, [mstype.float32], arg_name) # 检查CSRTensor的values数据类型是否为float32
else: # 如果参数不是CSRTensor类型
_dtype_check(func_name, arg, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], arg_name) # 检查参数的数据类型是否在指定的类型列表中
# Shape
_solve_check(func_name, arg, b, arg_name, 'b', True)
_solve_check(func_name, arg, x0, arg_name, 'x0', True)
if isinstance(arg, Tensor) and F.dtype(arg) in (mstype.int32, mstype.int64):
arg = F.cast(arg, mstype.float64)
return arg
a = _check_left(a, 'A')
m = _check_left(m, 'M')
b = b.flatten()
x0 = x0.flatten()
if F.dtype(b) in (mstype.int32, mstype.int64):
b = F.cast(b, mstype.float64)
x0 = F.cast(x0, mstype.float64)
return a, m, b, x0
_solve_check(func_name, arg, b, arg_name, 'b', True) # 检查参数a和b的形状是否可以用于求解线性方程组
_solve_check(func_name, arg, x0, arg_name, 'x0', True) # 检查参数a和x0的形状是否可以用于求解线性方程组
if isinstance(arg, Tensor) and F.dtype(arg) in (mstype.int32, mstype.int64): # 如果参数是Tensor类型且数据类型为int32或int64
arg = F.cast(arg, mstype.float64) # 将其转换为float64类型
return arg # 返回检查后的参数
a = _check_left(a, 'A') # 检查参数a
m = _check_left(m, 'M') # 检查参数m
b = b.flatten() # 将参数b展平为一维的tensor
x0 = x0.flatten() # 将参数x0展平为一维的tensor
if F.dtype(b) in (mstype.int32, mstype.int64): # 如果参数b的数据类型为int32或int64
b = F.cast(b, mstype.float64) # 将其转换为float64类型
x0 = F.cast(x0, mstype.float64) # 将其转换为float64类型
return a, m, b, x0 # 返回检查并转换后的参数

Loading…
Cancel
Save