|
|
|
@ -22,185 +22,171 @@ from ..common import dtype as mstype
|
|
|
|
|
from .utils_const import _type_convert, _raise_value_error, _callable_const, _super_check, pack
|
|
|
|
|
from ..ops.composite import GradOperation
|
|
|
|
|
|
|
|
|
|
grad = GradOperation(get_all=False, get_by_list=False, sens_param=False)
|
|
|
|
|
_eps_net = ops.Eps()
|
|
|
|
|
grad = GradOperation(get_all=False, get_by_list=False, sens_param=False) # 定义一个求梯度操作,设置为只求第一个参数的梯度,不通过列表获取,且不使用敏感参数
|
|
|
|
|
_eps_net = ops.Eps() # 定义一个计算数值精度的操作
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _convert_64_to_32(tensor):
|
|
|
|
|
def _convert_64_to_32(tensor): # 定义一个函数,将输入的tensor从float64或int64类型转换为float32或int32类型
|
|
|
|
|
"""Convert Tensor with float64/int64 types to float32/int32."""
|
|
|
|
|
if tensor.dtype == mstype.float64:
|
|
|
|
|
return tensor.astype("float32")
|
|
|
|
|
if tensor.dtype == mstype.int64:
|
|
|
|
|
return tensor.astype("int32")
|
|
|
|
|
return tensor
|
|
|
|
|
|
|
|
|
|
if tensor.dtype == mstype.float64: # 如果tensor的数据类型是float64
|
|
|
|
|
return tensor.astype("float32") # 将其转换为float32类型
|
|
|
|
|
if tensor.dtype == mstype.int64: # 如果tensor的数据类型是int64
|
|
|
|
|
return tensor.astype("int32") # 将其转换为int32类型
|
|
|
|
|
return tensor # 如果不是以上两种类型,则直接返回原tensor
|
|
|
|
|
|
|
|
|
|
def _to_tensor(*args, dtype=None):
|
|
|
|
|
def _to_tensor(*args, dtype=None): # 定义一个函数,将输入的参数转换为tensor
|
|
|
|
|
"""Returns each input as Tensor"""
|
|
|
|
|
res = ()
|
|
|
|
|
for arg in args:
|
|
|
|
|
if isinstance(arg, (int, float, bool, list, tuple)):
|
|
|
|
|
arg = _type_convert(Tensor, arg)
|
|
|
|
|
if dtype is None:
|
|
|
|
|
arg = _convert_64_to_32(arg)
|
|
|
|
|
else:
|
|
|
|
|
arg = arg.astype(dtype)
|
|
|
|
|
elif not isinstance(arg, Tensor):
|
|
|
|
|
_raise_value_error("Expect input to be array like.")
|
|
|
|
|
res += (arg,)
|
|
|
|
|
if len(res) == 1:
|
|
|
|
|
return res[0]
|
|
|
|
|
return res
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _to_scalar(arr):
|
|
|
|
|
res = () # 初始化一个空元组用于存储结果
|
|
|
|
|
for arg in args: # 遍历每一个输入参数
|
|
|
|
|
if isinstance(arg, (int, float, bool, list, tuple)): # 如果参数是整数、浮点数、布尔值、列表或元组
|
|
|
|
|
arg = _type_convert(Tensor, arg) # 将其转换为Tensor类型
|
|
|
|
|
if dtype is None: # 如果没有指定dtype
|
|
|
|
|
arg = _convert_64_to_32(arg) # 调用_convert_64_to_32函数进行类型转换
|
|
|
|
|
else: # 如果指定了dtype
|
|
|
|
|
arg = arg.astype(dtype) # 将tensor转换为指定的dtype
|
|
|
|
|
elif not isinstance(arg, Tensor): # 如果参数不是Tensor类型
|
|
|
|
|
_raise_value_error("Expect input to be array like.") # 抛出错误,提示输入应为数组形式
|
|
|
|
|
res += (arg,) # 将转换后的tensor添加到结果元组中
|
|
|
|
|
if len(res) == 1: # 如果结果元组中只有一个元素
|
|
|
|
|
return res[0] # 直接返回该元素
|
|
|
|
|
return res # 否则返回整个元组
|
|
|
|
|
|
|
|
|
|
def _to_scalar(arr): # 定义一个函数,将输入的Tensor或ndarray转换为标量值
|
|
|
|
|
"""Convert a scalar Tensor or ndarray to a scalar."""
|
|
|
|
|
if isinstance(arr, (int, float, bool)):
|
|
|
|
|
return arr
|
|
|
|
|
if isinstance(arr, Tensor):
|
|
|
|
|
if arr.shape:
|
|
|
|
|
return arr
|
|
|
|
|
return arr.asnumpy().item()
|
|
|
|
|
raise ValueError("{} are not supported.".format(type(arr)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _eps(x):
|
|
|
|
|
return _eps_net(x[(0,) * x.ndim])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _safe_normalize(x, threshold=None):
|
|
|
|
|
if isinstance(arr, (int, float, bool)): # 如果输入参数是整数、浮点数或布尔值
|
|
|
|
|
return arr # 直接返回该参数
|
|
|
|
|
if isinstance(arr, Tensor): # 如果输入参数是Tensor类型
|
|
|
|
|
if arr.shape: # 如果tensor的形状不是空的(即不是标量)
|
|
|
|
|
return arr # 返回整个tensor
|
|
|
|
|
return arr.asnumpy().item() # 如果是标量,将其转换为numpy数组并返回标量值
|
|
|
|
|
raise ValueError("{} are not supported.".format(type(arr))) # 如果输入参数不是以上两种类型,抛出错误,提示不支持该类型
|
|
|
|
|
|
|
|
|
|
def _eps(x): # 定义一个函数,计算输入tensor的数值精度
|
|
|
|
|
return _eps_net(x[(0,) * x.ndim]) # 使用_ops.Eps操作计算数值精度,x[(0,) * x.ndim]确保输入的是一个标量
|
|
|
|
|
|
|
|
|
|
def _safe_normalize(x, threshold=None): # 定义一个函数,对输入的tensor进行归一化,如果归一化结果非常小,则设置为零
|
|
|
|
|
"""Normalize method that cast very small results to zero."""
|
|
|
|
|
x_sum2 = F.reduce_sum(F.pows(x, 2.0))
|
|
|
|
|
norm = F.pows(x_sum2, 1. / 2.0)
|
|
|
|
|
if threshold is None:
|
|
|
|
|
if x.dtype in (mstype.float32, mstype.float64):
|
|
|
|
|
# pick the first element of x to get the eps
|
|
|
|
|
x_sum2 = F.reduce_sum(F.pows(x, 2.0)) # 计算tensor元素平方的和
|
|
|
|
|
norm = F.pows(x_sum2, 1. / 2.0) # 计算上述和的平方根,得到norm
|
|
|
|
|
if threshold is None: # 如果没有指定threshold
|
|
|
|
|
if x.dtype in (mstype.float32, mstype.float64): # 如果tensor的dtype是float32或float64
|
|
|
|
|
# pick the first element of x to get the eps # 获取eps来作为threshold
|
|
|
|
|
threshold = _eps(x)
|
|
|
|
|
else:
|
|
|
|
|
threshold = 0
|
|
|
|
|
use_norm = greater(norm, threshold)
|
|
|
|
|
x_norm = x / norm
|
|
|
|
|
normalized_x = where(use_norm, x_norm, zeros_like(x))
|
|
|
|
|
norm = where(use_norm, norm, zeros_like(norm))
|
|
|
|
|
return normalized_x, norm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sparse_dot(a, b):
|
|
|
|
|
else: # 如果tensor的dtype不是float32或float64
|
|
|
|
|
threshold = 0 # 设置threshold为0
|
|
|
|
|
use_norm = greater(norm, threshold) # 比较norm和threshold,得到一个布尔mask
|
|
|
|
|
x_norm = x / norm # 使用norm对tensor进行归一化
|
|
|
|
|
normalized_x = where(use_norm, x_norm, zeros_like(x)) # 如果norm大于threshold,则使用归一化后的tensor,否则使用零
|
|
|
|
|
norm = where(use_norm, norm, zeros_like(norm)) # 如果norm大于threshold,则保留norm,否则使用零
|
|
|
|
|
return normalized_x, norm # 返回归一化后的tensor及其对应的norm
|
|
|
|
|
|
|
|
|
|
def sparse_dot(a, b): # 定义一个函数,计算稀疏矩阵(CSRTensor)与向量(generic Tensor)的点积
|
|
|
|
|
"""Returns the dot product of CSRTensor and generic Tensor(vector)."""
|
|
|
|
|
b_aligned = F.reshape(b, (b.shape[0], -1))
|
|
|
|
|
res = F.csr_mv(a, b_aligned)
|
|
|
|
|
res = F.reshape(res, a.shape[:-1] + b.shape[1:])
|
|
|
|
|
return res
|
|
|
|
|
|
|
|
|
|
b_aligned = F.reshape(b, (b.shape[0], -1)) # 将向量b重塑为(b.shape[0], -1)的形状,使其可以与稀疏矩阵相乘
|
|
|
|
|
res = F.csr_mv(a, b_aligned) # 使用csr_mv操作计算稀疏矩阵a与向量b_aligned的点积
|
|
|
|
|
res = F.reshape(res, a.shape[:-1] + b.shape[1:]) # 将计算结果重新塑形为a.shape[:-1] + b.shape[1:]的形状
|
|
|
|
|
return res # 返回结果
|
|
|
|
|
|
|
|
|
|
def _normalize_matvec(f):
|
|
|
|
|
def _normalize_matvec(f): # 定义一个函数,对输入的矩阵或向量进行归一化处理
|
|
|
|
|
"""Normalize an argument for computing matrix-vector products."""
|
|
|
|
|
if isinstance(f, Tensor):
|
|
|
|
|
return F.partial(dot, f)
|
|
|
|
|
|
|
|
|
|
if isinstance(f, CSRTensor):
|
|
|
|
|
return F.partial(sparse_dot, f)
|
|
|
|
|
|
|
|
|
|
return f
|
|
|
|
|
if isinstance(f, Tensor): # 如果输入参数是Tensor类型
|
|
|
|
|
return F.partial(dot, f) # 返回一个带有矩阵参数f的dot函数的部分应用
|
|
|
|
|
|
|
|
|
|
if isinstance(f, CSRTensor): # 如果输入参数是CSRTensor类型
|
|
|
|
|
return F.partial(sparse_dot, f) # 返回一个带有稀疏矩阵参数f的sparse_dot函数的部分应用
|
|
|
|
|
|
|
|
|
|
def _norm(x, ord_=None):
|
|
|
|
|
if ord_ == mnp.inf:
|
|
|
|
|
res = mnp.max(mnp.abs(x))
|
|
|
|
|
else:
|
|
|
|
|
res = mnp.sqrt(mnp.sum(x ** 2))
|
|
|
|
|
return res
|
|
|
|
|
return f # 如果输入参数不是上述两种类型,则直接返回原参数
|
|
|
|
|
|
|
|
|
|
def _norm(x, ord_=None): # 定义一个函数,计算输入tensor的范数
|
|
|
|
|
if ord_ == mnp.inf: # 如果ord_为无穷大(实际为最大值)
|
|
|
|
|
res = mnp.max(mnp.abs(x)) # 返回tensor绝对值的最大值
|
|
|
|
|
else: # 如果ord_不是无穷大
|
|
|
|
|
res = mnp.sqrt(mnp.sum(x ** 2)) # 返回tensor元素平方和的平方根,即L2范数
|
|
|
|
|
return res # 返回结果
|
|
|
|
|
|
|
|
|
|
def _nd_transpose(a):
|
|
|
|
|
dims = a.ndim
|
|
|
|
|
if dims < 2:
|
|
|
|
|
_raise_value_error("to do _nd_transpose for input a's ndim is not greater or equal to 2d, which is invalid.")
|
|
|
|
|
axes = ops.make_range(0, dims)
|
|
|
|
|
axes = axes[:-2] + (axes[-1],) + (axes[-2],)
|
|
|
|
|
return ops.transpose(a, axes)
|
|
|
|
|
def _nd_transpose(a): # 定义一个函数,对输入的tensor进行转置,最后一个维度与倒数第二个维度互换
|
|
|
|
|
dims = a.ndim # 获取tensor的维度数
|
|
|
|
|
if dims < 2: # 如果tensor的维度小于2
|
|
|
|
|
_raise_value_error("to do _nd_transpose for input a's ndim is not greater or equal to 2d, which is invalid.") # 抛出错误,提示输入tensor的维度应大于等于2
|
|
|
|
|
axes = ops.make_range(0, dims) # 生成一个从0到tensor维度数的序列
|
|
|
|
|
axes = axes[:-2] + (axes[-1],) + (axes[-2],) # 将序列中的倒数第二个和最后一个元素互换位置
|
|
|
|
|
return ops.transpose(a, axes) # 使用transpose操作对tensor进行转置
|
|
|
|
|
|
|
|
|
|
def _value_check(func_name, arg1, arg2, arg_name='', attr_name='', op="in", fmt="attr", msg=None): # 定义一个函数,用于检查输入参数的值是否符合预期
|
|
|
|
|
return _super_check(pack(arg1, arg2), (func_name, arg_name, attr_name), op, fmt, msg, True) # 调用_super_check函数进行检查
|
|
|
|
|
|
|
|
|
|
def _value_check(func_name, arg1, arg2, arg_name='', attr_name='', op="in", fmt="attr", msg=None):
|
|
|
|
|
return _super_check(pack(arg1, arg2), (func_name, arg_name, attr_name), op, fmt, msg, True)
|
|
|
|
|
def _type_check(func_name, arg1, arg2, arg_name='', op="isinstance", fmt="type", msg=None): # 定义一个函数,用于检查输入参数的类型是否符合预期
|
|
|
|
|
return _super_check(pack(arg1, arg2), (func_name, arg_name), op, fmt, msg, False) # 调用_super_check函数进行检查
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _type_check(func_name, arg1, arg2, arg_name='', op="isinstance", fmt="type", msg=None):
|
|
|
|
|
return _super_check(pack(arg1, arg2), (func_name, arg_name), op, fmt, msg, False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _mstype_check(func_name, arg, arg_mstype, arg_name='a'):
|
|
|
|
|
return _super_check((F.typeof(arg), arg_mstype), pack(arg, arg_mstype, func_name, arg_name), "isinstance", "mstype",
|
|
|
|
|
def _mstype_check(func_name, arg, arg_mstype, arg_name='a'): # 定义一个函数,用于检查输入参数的mstype是否符合预期
|
|
|
|
|
return _super_check((F.typeof(arg), arg_mstype), pack(arg, arg_mstype, func_name, arg_name), "isinstance", "mstype", # 调用_super_check函数进行检查
|
|
|
|
|
None, False)
|
|
|
|
|
|
|
|
|
|
def _dtype_check(func_name, arg, arg_dtype, arg_name='a'): # 定义一个函数,用于检查输入参数的数据类型是否符合预期
|
|
|
|
|
return _super_check((F.dtype(arg), arg_dtype), (func_name, arg_name, "data type"), "in", "attr", # 调用_super_check函数进行检查
|
|
|
|
|
None, False)
|
|
|
|
|
|
|
|
|
|
def _dtype_check(func_name, arg, arg_dtype, arg_name='a'):
|
|
|
|
|
return _super_check((F.dtype(arg), arg_dtype), (func_name, arg_name, "data type"), "in", "attr", None, False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _square_check(func_name, arg, arg_name='a'):
|
|
|
|
|
arg_shape = arg.shape
|
|
|
|
|
_super_check((len(arg_shape), 2), (func_name, arg_name, 'dimension'), '==', 'attr', None, True)
|
|
|
|
|
_super_check(arg_shape, (func_name, arg_name), '==', 'square', None, True)
|
|
|
|
|
return arg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _solve_check(func_name, arg1, arg2, arg1_name='a', arg2_name='b', sparse=False):
|
|
|
|
|
arg1_shape, arg1_dtype = arg1.shape, F.dtype(arg1)
|
|
|
|
|
arg2_shape, arg2_dtype = arg2.shape, F.dtype(arg2)
|
|
|
|
|
_square_check(func_name, arg1, arg1_name)
|
|
|
|
|
_super_check((len(arg2_shape), (1, 2)), (func_name, arg2_name, 'dimension'), 'in', 'attr', None, True)
|
|
|
|
|
_super_check((arg1_shape, arg2_shape), (func_name, arg1_name, arg2_name, sparse), 'solve', 'solve', None, True)
|
|
|
|
|
_super_check((arg1_dtype, arg2_dtype), (func_name, arg1_name, arg2_name, 'data type'), '==', 'match', None, False)
|
|
|
|
|
return arg1, arg2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _sparse_check(func_name, a, m, b, x0):
|
|
|
|
|
def _square_check(func_name, arg, arg_name='a'): # 定义一个函数,用于检查输入参数是否为方阵
|
|
|
|
|
arg_shape = arg.shape # 获取输入参数的形状
|
|
|
|
|
_super_check((len(arg_shape), 2), (func_name, arg_name, 'dimension'), '==', 'attr', None, True) # 检查输入参数的维度是否为2
|
|
|
|
|
_super_check(arg_shape, (func_name, arg_name), '==', 'square', None, True) # 检查输入参数的形状是否为方阵
|
|
|
|
|
return arg # 返回检查后的参数
|
|
|
|
|
|
|
|
|
|
def _solve_check(func_name, arg1, arg2, arg1_name='a', arg2_name='b', sparse=False): # 定义一个函数,用于在求解线性方程组时检查输入参数
|
|
|
|
|
arg1_shape, arg1_dtype = arg1.shape, F.dtype(arg1) # 获取第一个参数的形状和数据类型
|
|
|
|
|
arg2_shape, arg2_dtype = arg2.shape, F.dtype(arg2) # 获取第二个参数的形状和数据类型
|
|
|
|
|
_square_check(func_name, arg1, arg1_name) # 检查第一个参数是否为方阵
|
|
|
|
|
_super_check((len(arg2_shape), (1, 2)), (func_name, arg2_name, 'dimension'), 'in', 'attr', None, True) # 检查第二个参数的维度是否为1或2
|
|
|
|
|
_super_check((arg1_shape, arg2_shape), (func_name, arg1_name, arg2_name, sparse), 'solve', 'solve', None, True) # 检查第一个参数和第二个参数的形状是否可以用于求解线性方程组
|
|
|
|
|
_super_check((arg1_dtype, arg2_dtype), (func_name, arg1_name, arg2_name, 'data type'), '==', 'match', None, False) # 检查第一个参数和第二个参数的数据类型是否匹配
|
|
|
|
|
return arg1, arg2 # 返回检查后的两个参数
|
|
|
|
|
|
|
|
|
|
def _sparse_check(func_name, a, m, b, x0): # 定义一个函数,用于在稀疏求解器(如cg, bicgstab和gmres)中检查输入参数
|
|
|
|
|
"""Used for cg, bicgstab and gmres method."""
|
|
|
|
|
|
|
|
|
|
def _check_right(arg, arg_name):
|
|
|
|
|
if arg is None:
|
|
|
|
|
return mnp.zeros_like(b) # x0 same as b
|
|
|
|
|
def _check_right(arg, arg_name): # 定义一个内部函数,用于检查右侧参数(b或x0)
|
|
|
|
|
if arg is None: # 如果参数为None
|
|
|
|
|
return mnp.zeros_like(b) # x0 same as b # 返回与b形状相同,元素为零的tensor
|
|
|
|
|
# Type
|
|
|
|
|
_mstype_check(func_name, arg, mstype.tensor_type, arg_name)
|
|
|
|
|
_mstype_check(func_name, arg, mstype.tensor_type, arg_name) # 检查参数的mstype是否为tensor_type
|
|
|
|
|
# DType
|
|
|
|
|
_dtype_check(func_name, arg, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], arg_name)
|
|
|
|
|
_dtype_check(func_name, arg, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], arg_name) # 检查参数的数据类型是否在指定的类型列表中
|
|
|
|
|
# Shape
|
|
|
|
|
if (arg.ndim != 1 and arg.ndim != 2) or (arg.ndim == 2 and arg.shape[1] != 1):
|
|
|
|
|
_raise_value_error("For: '", func_name, "', the shape of '", arg_name,
|
|
|
|
|
if (arg.ndim != 1 and arg.ndim != 2) or (arg.ndim == 2 and arg.shape[1] != 1): # 检查参数的形状是否为(N,)或(N, 1)
|
|
|
|
|
_raise_value_error("For: '", func_name, "', the shape of '", arg_name, # 如果不满足条件,抛出错误
|
|
|
|
|
"' should be like (N,) or (N, 1), bug got ", arg.shape, ".")
|
|
|
|
|
return arg
|
|
|
|
|
return arg # 返回检查后的参数
|
|
|
|
|
|
|
|
|
|
b = _check_right(b, 'b')
|
|
|
|
|
x0 = _check_right(x0, 'x0')
|
|
|
|
|
b = _check_right(b, 'b') # 检查参数b
|
|
|
|
|
x0 = _check_right(x0, 'x0') # 检查参数x0
|
|
|
|
|
|
|
|
|
|
def _check_left(arg, arg_name):
|
|
|
|
|
if arg is None:
|
|
|
|
|
return lambda x: x # identity function
|
|
|
|
|
def _check_left(arg, arg_name): # 定义一个内部函数,用于检查左侧参数(a或m)
|
|
|
|
|
if arg is None: # 如果参数为None
|
|
|
|
|
return lambda x: x # identity function # 返回一个恒等函数
|
|
|
|
|
# Type
|
|
|
|
|
_mstype_check(func_name, arg, [mstype.function_type, mstype.tensor_type, mstype.csr_tensor_type], arg_name)
|
|
|
|
|
if _callable_const(F.typeof(arg)):
|
|
|
|
|
return arg
|
|
|
|
|
_mstype_check(func_name, arg, [mstype.function_type, mstype.tensor_type, mstype.csr_tensor_type], arg_name) # 检查参数的mstype是否为function_type, tensor_type或csr_tensor_type
|
|
|
|
|
if _callable_const(F.typeof(arg)): # 如果参数是一个可调用的常量(即函数)
|
|
|
|
|
return arg # 返回该参数
|
|
|
|
|
# DType
|
|
|
|
|
if isinstance(arg, CSRTensor):
|
|
|
|
|
_dtype_check(func_name, arg.indptr, [mstype.int32], arg_name)
|
|
|
|
|
_dtype_check(func_name, arg.indices, [mstype.int32], arg_name)
|
|
|
|
|
_dtype_check(func_name, arg.values, [mstype.float32], arg_name)
|
|
|
|
|
else:
|
|
|
|
|
_dtype_check(func_name, arg, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], arg_name)
|
|
|
|
|
if isinstance(arg, CSRTensor): # 如果参数是CSRTensor类型
|
|
|
|
|
_dtype_check(func_name, arg.indptr, [mstype.int32], arg_name) # 检查CSRTensor的indptr数据类型是否为int32
|
|
|
|
|
_dtype_check(func_name, arg.indices, [mstype.int32], arg_name) # 检查CSRTensor的indices数据类型是否为int32
|
|
|
|
|
_dtype_check(func_name, arg.values, [mstype.float32], arg_name) # 检查CSRTensor的values数据类型是否为float32
|
|
|
|
|
else: # 如果参数不是CSRTensor类型
|
|
|
|
|
_dtype_check(func_name, arg, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], arg_name) # 检查参数的数据类型是否在指定的类型列表中
|
|
|
|
|
# Shape
|
|
|
|
|
_solve_check(func_name, arg, b, arg_name, 'b', True)
|
|
|
|
|
_solve_check(func_name, arg, x0, arg_name, 'x0', True)
|
|
|
|
|
if isinstance(arg, Tensor) and F.dtype(arg) in (mstype.int32, mstype.int64):
|
|
|
|
|
arg = F.cast(arg, mstype.float64)
|
|
|
|
|
return arg
|
|
|
|
|
|
|
|
|
|
a = _check_left(a, 'A')
|
|
|
|
|
m = _check_left(m, 'M')
|
|
|
|
|
|
|
|
|
|
b = b.flatten()
|
|
|
|
|
x0 = x0.flatten()
|
|
|
|
|
if F.dtype(b) in (mstype.int32, mstype.int64):
|
|
|
|
|
b = F.cast(b, mstype.float64)
|
|
|
|
|
x0 = F.cast(x0, mstype.float64)
|
|
|
|
|
return a, m, b, x0
|
|
|
|
|
_solve_check(func_name, arg, b, arg_name, 'b', True) # 检查参数a和b的形状是否可以用于求解线性方程组
|
|
|
|
|
_solve_check(func_name, arg, x0, arg_name, 'x0', True) # 检查参数a和x0的形状是否可以用于求解线性方程组
|
|
|
|
|
if isinstance(arg, Tensor) and F.dtype(arg) in (mstype.int32, mstype.int64): # 如果参数是Tensor类型且数据类型为int32或int64
|
|
|
|
|
arg = F.cast(arg, mstype.float64) # 将其转换为float64类型
|
|
|
|
|
return arg # 返回检查后的参数
|
|
|
|
|
|
|
|
|
|
a = _check_left(a, 'A') # 检查参数a
|
|
|
|
|
m = _check_left(m, 'M') # 检查参数m
|
|
|
|
|
|
|
|
|
|
b = b.flatten() # 将参数b展平为一维的tensor
|
|
|
|
|
x0 = x0.flatten() # 将参数x0展平为一维的tensor
|
|
|
|
|
if F.dtype(b) in (mstype.int32, mstype.int64): # 如果参数b的数据类型为int32或int64
|
|
|
|
|
b = F.cast(b, mstype.float64) # 将其转换为float64类型
|
|
|
|
|
x0 = F.cast(x0, mstype.float64) # 将其转换为float64类型
|
|
|
|
|
return a, m, b, x0 # 返回检查并转换后的参数
|
|
|
|
|
|
|
|
|
|