Compare commits

..

64 Commits

Author SHA1 Message Date
donghaoqian 8ff195e3f3 泛读报告格式调整后再次提交
2 months ago
donghaoqian 40500495da 提交修改后的泛读报告
2 months ago
donghaoqian 1319480deb 删除无关文件
2 months ago
donghaoqian 79516a3321 Merge branch 'main' of https://bdgit.educoder.net/pstluih63/mindspore_group_2
2 months ago
donghaoqian eeea480af9 更新泛读报告
2 months ago
pupb4l8rm ab1f7b5097 Merge pull request 'ruiqin' (#6) from Branch_ruiqin into main
2 months ago
zhang 759ad9d09c communication_ruiqin2
2 months ago
liuwenhao bdf240c986 Merge branch 'main' of https://bdgit.educoder.net/pstluih63/mindspore_group_2 into branch-lwh
2 months ago
pupb4l8rm b12adc5753 Merge pull request 'ruiqin' (#5) from Branch_ruiqin into main
2 months ago
liuwenhao ae1d86521d Merge branch 'main' of https://bdgit.educoder.net/pstluih63/mindspore_group_2 into branch-lwh
2 months ago
zhang 542726fc1a run_check,ruiqin
2 months ago
zhang ddf9ac7477 graph_utils,ruiqin
2 months ago
zhang 61450586a7 communication,ruiqin
2 months ago
pstluih63 6d63bbbd8c Merge pull request '提交mindspore泛读报告' (#4) from branch-donghaoqian into main
2 months ago
donghaoqian b47832ea8c 提交mindspore泛读报告
2 months ago
pbp2cnvu4 93f0af2f73 Merge pull request 'merge zwt' (#3) from branch_zwt into main
2 months ago
zouwentao 3668e69ed5 注释
2 months ago
zouwentao a672592d99 comment
2 months ago
donghaoqian 74aff73449 Merge branch 'main' of https://bdgit.educoder.net/pstluih63/mindspore_group_2 into branch-donghaoqian
2 months ago
donghaoqian db2c8b9de0 董昊千代码注读
2 months ago
pptw92c8a c8a0ad4d29 Merge pull request 'xiangguo' (#2) from branch-xg into main
2 months ago
donghaoqian 69ae5ccbe1 董昊千注读/nn/loss/loss.py
2 months ago
liuwenhao 5880e5be32 Merge branch 'main' of https://bdgit.educoder.net/pstluih63/mindspore_group_2 into branch-lwh
2 months ago
pc8nkf3eg 0c93bc5640 Merge pull request '_extend文件夹泛读注释' (#1) from branch-yixin into main
2 months ago
yixin 510df7cf31 feat(boost): 增加对新模块的导入及中文注释和文档翻译
2 months ago
yixin c7222ba353 (add comments and docstrings to resources.py for clarity)
2 months ago
yixin eff398104d (add comments and docstrings to namespace.py for clarity)
2 months ago
yixin b4d46a2542 _extends\parse\__init__.py
2 months ago
yixin 569be2fa4e _extends\parallel_compile\tbe_compiler
2 months ago
yixin 0012f23abf _extends\parallel_compile\akg_compiler
2 months ago
yixin cd3f01ab90 _extends\graph_kernel\model\graph_parallel.py
2 months ago
yixin ffdf6162c7 _extends\graph_kernel\model
2 months ago
yixin f8389e877f add comments for _extends\graph_kernel\expanders\complex
2 months ago
yixin f1e159c4b4 add comments for expanders
2 months ago
yixin 95f50332ac _extends\graph_kernel\expanders\gelu.py
2 months ago
yixin ace1dccc7a _extends\graph_kernel\expanders\gelu_grad.py
2 months ago
yixin 07b3545276 _extends\graph_kernel\expanders\gather.py
2 months ago
yixin efae72cc05 _extends\graph_kernel\expanders\fused_mul_add.py
2 months ago
yixin 17d5bbbe72 _extends\graph_kernel\expanders\fused_adam.py
2 months ago
yixin 3e1df14f3b _extends\graph_kernel\expanders\fused_adam_weight_decay.py
2 months ago
yixin 60f231dc1c _extends\graph_kernel\expanders\expand_dims.py
2 months ago
yixin 979f67f6fa _extends\graph_kernel\expanders\erfc.py
2 months ago
yixin 0015c083a7 _extends\graph_kernel\expanders\equal_count.py
2 months ago
yixin c64739f456 _extends\graph_kernel\expanders\dropout_grad.py
2 months ago
yixin e13c287655 _extends\graph_kernel\expanders\conv2d.py
2 months ago
yixin 8388b30ee2 add comments for _extends\graph_kernel\expanders\clip_by_norm_no_div_sum.py
2 months ago
yixin 2c4a524a6a add comments for _extends\graph_kernel\expanders\bias_add_grad.py
2 months ago
yixin b0c7662155 add comments for _extends\graph_kernel\expanders\batchnorm.py
2 months ago
yixin eef69d070e add comments for _extends\graph_kernel\expanders\batchnorm_grad.py
2 months ago
yixin e414d6025d add comments for _extends\graph_kernel\expanders\addn.py
2 months ago
yixin 1531f33582 add comments for _extends\graph_kernel\expanders\_utils.py
2 months ago
yixin 500f5d3348 add comments for _extends\graph_kernel\expanders\__init__.py
2 months ago
yixin ffeb853318 add comments for _extends\graph_kernel\utils.py
2 months ago
yixin e5d2d82400 add comments for _extends\graph_kernel\splitter.py
2 months ago
yixin 5437cea5c8 add comments for _extends/graph_kernel/parrellel_estimate.py
2 months ago
yixin bfa789008e add comments for _extends/graph_kernel/expander.py
2 months ago
yixin afeb43c388 add comments for _extends/graph_kernel/_init_.py
2 months ago
yixin 318e60dcb7 add comments for utils.py
2 months ago
yixin 4e2d1b2b99 add comments for builtin_operations.py
2 months ago
zhang 4e49598ac9 ruiqn
2 months ago
donghaoqian ade7c451d5 创建branch-donghaoqian
2 months ago
liuwenhao 65ca9afacc 注释
2 months ago
yixin 9618cd0672 command mindspore_test.py
2 months ago
yixin dfcf88b2c6 command
2 months ago

@ -1,2 +0,0 @@
# mindspore_group_2

@ -22,31 +22,100 @@ from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
def ScalarAdd(x, y): def ScalarAdd(x, y):
"""
实现标量加法运算
Args:
x (float): 第一个加数
y (float): 第二个加数
Returns:
float: x y 的和
"""
"""Implement `scalar_add`.""" """Implement `scalar_add`."""
return x + y return x + y
def ScalarMul(x, y): def ScalarMul(x, y):
"""
标量乘法函数
Args:
x (float): 第一个标量
y (float): 第二个标量
Returns:
float: 两个标量的乘积
"""
"""Implement `scalar_mul`.""" """Implement `scalar_mul`."""
return x * y return x * y
def ScalarMod(x, y): def ScalarMod(x, y):
"""
对两个数进行模运算
Args:
x (int or float): 被模数
y (int or float): 模数
Returns:
int or float: x y 取模的结果
"""
"""Implement `scalar_mul`.""" """Implement `scalar_mul`."""
return x % y return x % y
def ScalarSub(x, y): def ScalarSub(x, y):
"""
实现标量减法运算
Args:
x (float): 第一个标量值
y (float): 第二个标量值
Returns:
float: x y 的差值
"""
"""Implement `scalar_sub`.""" """Implement `scalar_sub`."""
return x - y return x - y
def ScalarUsub(x): def ScalarUsub(x):
"""
对给定的标量 x 进行取反操作
Args:
x (float or int): 需要取反的标量
Returns:
float or int: 取反后的标量
"""
"""Implement `scalar_usub`.""" """Implement `scalar_usub`."""
return -x return -x
def TupleGetItem(x, index): def TupleGetItem(x, index):
"""
从给定对象中获取指定索引处的元素
Args:
x (Union[Tensor, dict]): 输入对象可以是Tensor类型或字典类型
index (int): 要获取的元素的索引
Returns:
Union[Tensor, Any]: 如果输入是Tensor类型则返回Tensor类型的元素
如果输入是字典类型则返回字典中对应索引的值
否则返回输入对象中对应索引的元素
Raises:
IndexError: 如果索引超出范围
"""
"""Implement `tuple_getitem`.""" """Implement `tuple_getitem`."""
if isinstance(x, Tensor): if isinstance(x, Tensor):
x = x.asnumpy() x = x.asnumpy()
@ -64,36 +133,111 @@ def TupleGetItem(x, index):
def scalar_gt(x, y): def scalar_gt(x, y):
"""
判断两个标量值x和y的大小关系
Args:
x (float or int): 第一个标量值
y (float or int): 第二个标量值
Returns:
bool: 如果x大于y则返回True否则返回False
"""
"""Implement `scalar_gt`.""" """Implement `scalar_gt`."""
return x > y return x > y
def scalar_ne(x, y): def scalar_ne(x, y):
"""
比较两个标量值是否不相等
Args:
x (float): 第一个标量值
y (float): 第二个标量值
Returns:
bool: 如果 x 不等于 y则返回 True否则返回 False
"""
"""Implement `scalar_ne`.""" """Implement `scalar_ne`."""
return x != y return x != y
def scalar_eq(x, y): def scalar_eq(x, y):
"""
判断两个标量值是否相等
Args:
x (Any): 第一个标量值
y (Any): 第二个标量值
Returns:
bool: 如果 x y 相等返回 True否则返回 False
"""
"""Implement `scalar_eq`.""" """Implement `scalar_eq`."""
return x == y return x == y
def scalar_le(x, y): def scalar_le(x, y):
"""
判断标量 x 是否小于等于标量 y
Args:
x (float): 第一个标量值
y (float): 第二个标量值
Returns:
bool: 如果 x 小于等于 y则返回 True否则返回 False
"""
"""Implement `scalar_le`.""" """Implement `scalar_le`."""
return x <= y return x <= y
def scalar_lt(x, y): def scalar_lt(x, y):
"""
判断两个标量值的大小关系
Args:
x (float): 第一个标量值
y (float): 第二个标量值
Returns:
bool: 如果 x 小于 y则返回 True否则返回 False
"""
"""Implement `scalar_lt`.""" """Implement `scalar_lt`."""
return x < y return x < y
def identity(x): def identity(x):
"""
返回输入参数本身
Args:
x: 任何类型的输入参数
Returns:
返回输入参数本身
"""
"""Implement `identity`.""" """Implement `identity`."""
return x return x
def zeros_like_tensor(x): def zeros_like_tensor(x):
"""
根据给定的张量x创建一个形状相同但所有元素为零的新张量
Args:
x (Tensor): 输入的张量用于确定新张量的形状
Returns:
Tensor: 一个与输入张量x形状相同但所有元素为零的新张量
"""
"""Implement `zeros_like_tensor`.""" """Implement `zeros_like_tensor`."""
x = x.asnumpy() x = x.asnumpy()
value = Tensor(np.zeros(x.shape).astype(np.float32)) value = Tensor(np.zeros(x.shape).astype(np.float32))
@ -101,61 +245,201 @@ def zeros_like_tensor(x):
def Switch(c, x, y): def Switch(c, x, y):
"""
实现 `switch` 功能
Args:
c (bool): 条件值如果为 True则返回 x否则返回 y
x (Any): 条件为 True 时返回的值
y (Any): 条件为 False 时返回的值
Returns:
Any: 根据条件 c 返回 x y
"""
"""Implement `switch`.""" """Implement `switch`."""
return x if c else y return x if c else y
def list_getitem(data, item): def list_getitem(data, item):
"""
从列表中获取指定索引处的元素
Args:
data (list): 待查询的列表
item (int): 要获取的元素的索引
Returns:
返回列表中索引为item的元素
Raises:
IndexError: 如果索引超出列表范围
"""
"""Implement `list_getitem`.""" """Implement `list_getitem`."""
return data[item] return data[item]
def bool_not(x): def bool_not(x):
"""
对输入值取反
Args:
x (bool): 要取反的布尔值
Returns:
bool: x 的逻辑非值
"""
"""Implement `bool_not`.""" """Implement `bool_not`."""
return not x return not x
def bool_and(x, y): def bool_and(x, y):
"""
对两个布尔值进行逻辑与操作
Args:
x (bool): 第一个布尔值
y (bool): 第二个布尔值
Returns:
bool: 返回两个布尔值进行逻辑与操作后的结果
"""
"""Implement `bool_and`.""" """Implement `bool_and`."""
return x and y return x and y
def bool_or(x, y): def bool_or(x, y):
"""
实现布尔或运算
Args:
x (bool): 第一个布尔值
y (bool): 第二个布尔值
Returns:
bool: 如果 x y True则返回 True否则返回 False
"""
"""Implement `bool_or`.""" """Implement `bool_or`."""
return x or y return x or y
def make_list(*xs): def make_list(*xs):
"""
将不定数量的参数转换为一个列表
Args:
*xs: 不定数量的参数可以是任意类型
Returns:
list: 包含所有传入参数的列表
Examples:
>>> make_list(1, 2, 3)
[1, 2, 3]
>>> make_list('a', 'b', 'c')
['a', 'b', 'c']
>>> make_list(1, 'a', [1, 2, 3])
[1, 'a', [1, 2, 3]]
"""
"""Implement `make_list`.""" """Implement `make_list`."""
return list(xs) return list(xs)
def list_len(x): def list_len(x):
"""
计算列表的长度
Args:
x (list): 需要计算长度的列表
Returns:
int: 列表的长度
"""
"""Implement `list_len`.""" """Implement `list_len`."""
return len(x) return len(x)
def Depend(value, expr): def Depend(value, expr):
"""
依赖函数根据给定的表达式返回相应的值
Args:
value (Any): 要返回的值
expr (Any): 表达式该参数在当前实现中被忽略
Returns:
Any: 返回与输入相同的值
"""
"""Implement `Depend`.""" """Implement `Depend`."""
return value return value
def UpdateState(monad, *exprs): def UpdateState(monad, *exprs):
"""
更新状态
Args:
monad (Monad): 一个符合 Monad 类型的对象
*exprs (Any): 需要更新的表达式可以为任意类型
Returns:
Monad: 更新后的 Monad 对象
"""
"""Implement `UpdateState`.""" """Implement `UpdateState`."""
return monad return monad
def Load(value, u=None): def Load(value, u=None):
"""
加载指定的值
Args:
value (Any): 要加载的值
u (Optional[Any], optional): 可选参数默认为None当前版本未使用保留以便未来扩展
Returns:
Any: 返回加载的值
"""
"""Implement `Load`.""" """Implement `Load`."""
return value return value
# only used in PyNative mode # only used in PyNative mode
def make_ref(key, value, ref): def make_ref(key, value, ref):
"""
创建一个引用对象
Args:
key (str): 键名用于标识引用的对象
value (Any): 引用对象的值
ref (Any): 引用对象可以为任意类型
Returns:
Any: 返回引用的值
"""
return value return value
def scalar_cast(x, t): def scalar_cast(x, t):
"""
将标量值x转换为指定的NumPy数据类型t
Args:
x (float, int): 要转换的标量值
t (np.dtype): 目标NumPy数据类型
Returns:
Any: 转换后的标量值类型为t
"""
"""Implement scalar_cast.""" """Implement scalar_cast."""
np_type = dtype_to_nptype(t) np_type = dtype_to_nptype(t)
value = np_type(x) value = np_type(x)
@ -164,16 +448,46 @@ def scalar_cast(x, t):
def typeof(x): def typeof(x):
"""
实现 typeof 函数
Args:
x (Any): 要获取类型的对象
Returns:
str: 返回传入对象的Python类型名称
"""
"""Implement typeof.""" """Implement typeof."""
return get_py_obj_dtype(x) return get_py_obj_dtype(x)
def tuple_to_array(x): def tuple_to_array(x):
"""
将元组转换为数组
Args:
x (tuple): 待转换的元组
Returns:
Tensor: 转换后的数组
"""
"""Implement `tuple_to_array`.""" """Implement `tuple_to_array`."""
return Tensor(np.array(x)) return Tensor(np.array(x))
def stop_gradient(x): def stop_gradient(x):
"""
停止梯度传播
Args:
x (Tensor): 需要停止梯度传播的张量
Returns:
Tensor: 停止梯度传播的张量
"""
"""Implement `stop_gradient`.""" """Implement `stop_gradient`."""
return x return x
@ -182,6 +496,20 @@ hyper_map = C.HyperMap()
def mixed_precision_cast(dst_type, x): def mixed_precision_cast(dst_type, x):
"""
实现混合精度转换函数
Args:
dst_type (mstype.Type): 目标数据类型
x (Union[Tensor, list, tuple]): 需要进行类型转换的数据可以是单个Tensor也可以是一个包含Tensor的列表或元组
Returns:
Union[Tensor, list, tuple]: 转换后的数据类型与输入一致
Raises:
TypeError: 如果输入数据类型不支持将引发TypeError异常
"""
"""Implement `mixed_precision_cast`.""" """Implement `mixed_precision_cast`."""
def cast_inner(data): def cast_inner(data):

@ -13,6 +13,9 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""init""" """init"""
# 从splitter模块中导入split_with_json函数
from .splitter import split_with_json from .splitter import split_with_json
# 从expander模块中导入get_op_expander函数
from .expander import get_op_expander from .expander import get_op_expander
# 从parallel_estimate模块中导入estimate_calculation_amount和estimate_ops函数
from .parallel_estimate import estimate_calculation_amount, estimate_ops from .parallel_estimate import estimate_calculation_amount, estimate_ops

@ -22,8 +22,32 @@ from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedEx
def create_expander(expand_info): def create_expander(expand_info):
"""
根据操作符名称创建一个扩展器
Args:
expand_info (dict): 包含操作符名称及其他相关信息的字典
Returns:
Any: 调用指定操作符名称的扩展器后返回的结果
Raises:
GraphKernelUnsupportedException: 如果指定的操作符名称在扩展器模块中不存在则抛出此异常
"""
"""Create an expander according to op name""" """Create an expander according to op name"""
def call_func(func, arg): def call_func(func, arg):
"""
调用给定的函数并返回其结果
Args:
func (callable): 要调用的函数
arg: 要传递给函数的参数
Returns:
调用给定函数后的返回值
"""
return func(arg) return func(arg)
op_name = str(expand_info['name']) op_name = str(expand_info['name'])
if not hasattr(expanders, op_name): if not hasattr(expanders, op_name):
@ -33,6 +57,21 @@ def create_expander(expand_info):
def extract_expand_info(kernel_info): def extract_expand_info(kernel_info):
"""
将json格式的kernel信息转换为更友好的格式
Args:
kernel_info (dict): 包含kernel信息的字典
Returns:
dict: 转换后的kernel信息字典包含以下键
- name (str): kernel的名称
- input_desc (list): 输入描述列表
- output_desc (list): 输出描述列表
- attr (dict): 属性字典键为属性名值为属性值
- process (str): 处理过程的描述
"""
"""Convert the json into a more friendly format""" """Convert the json into a more friendly format"""
input_desc = [] input_desc = []
if 'input_desc' in kernel_info and kernel_info['input_desc']: if 'input_desc' in kernel_info and kernel_info['input_desc']:
@ -53,6 +92,20 @@ def extract_expand_info(kernel_info):
def get_op_expander(json_str: str): def get_op_expander(json_str: str):
"""
通过json信息获取操作扩展器
Args:
json_str (str): 包含操作扩展器信息的json字符串
Returns:
str: 返回扩展后的操作图的json描述
Raises:
jd.JSONDecodeError: 如果输入的json字符串解码失败
GraphKernelUnsupportedException: 如果操作图不支持的操作类型
"""
"""get op expander by json info""" """get op expander by json info"""
try: try:
kernel_info = json.loads(json_str) kernel_info = json.loads(json_str)

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""expanders init. Deprecated, please add the new operators in the c++ file""" """expanders init. Deprecated, please add the new operators in the c++ file"""
"""扩展器初始化。已弃用请在新运算符中添加C++文件"""
from .addn import AddN from .addn import AddN

@ -27,6 +27,19 @@ class Expander:
__metaclass__ = ABCMeta __metaclass__ = ABCMeta
def __init__(self, expand_info): def __init__(self, expand_info):
"""
初始化方法
Args:
expand_info (dict): 包含模型信息的字典包括模型名称输入描述输出描述属性处理函数等
Attributes:
name (str): 模型名称
inputs (list): 输入描述列表
outputs (list): 输出描述列表
attrs (dict): 模型属性字典
processor (callable): 处理函数
"""
self.name = expand_info["name"] self.name = expand_info["name"]
self.inputs = expand_info["input_desc"] self.inputs = expand_info["input_desc"]
self.outputs = expand_info["output_desc"] self.outputs = expand_info["output_desc"]
@ -34,6 +47,19 @@ class Expander:
self.processor = expand_info["process"] self.processor = expand_info["process"]
def run(self): def run(self):
"""
将操作扩展为图
Args:
Returns:
返回扩展后的图对象
Raises:
GraphKernelUnsupportedException: 如果检查失败则引发此异常
"""
""" """
Expand the operator to a graph. Expand the operator to a graph.
@ -58,9 +84,31 @@ class Expander:
return graph return graph
def _check(self): def _check(self):
"""
检查输入
Args:
Returns:
Raises:
ValueError: 如果输入不符合要求则引发此异常
"""
"""Check inputs""" """Check inputs"""
def _check_output_same(self, outputs): def _check_output_same(self, outputs):
"""
检查输出是否与预期一致
Args:
outputs (list): 实际输出值的列表
Raises:
GKException: 如果实际输出值与预期不一致则抛出异常
"""
for index, value in enumerate(self.outputs): for index, value in enumerate(self.outputs):
if list(outputs[index].shape) != list(value['shape']): if list(outputs[index].shape) != list(value['shape']):
raise GKException("{} 's output shape {} is wrong. Expected:{}".format( raise GKException("{} 's output shape {} is wrong. Expected:{}".format(
@ -74,6 +122,18 @@ class Expander:
@abstractmethod @abstractmethod
def _expand(self, graph_builder): def _expand(self, graph_builder):
"""
Expand 操作符此函数应在子类中重写
Args:
graph_builder (GraphBuilder): 图构建器对象
Raises:
Exception: 如果子类未重写此方法则抛出异常提示 "_expand() is not implemented in {}".
Returns:
None
"""
"""Expand operator, this function should be overridden in subclass""" """Expand operator, this function should be overridden in subclass"""
raise Exception("_expand() is not implemented in {}".format(self.__class__.__name__)) raise Exception("_expand() is not implemented in {}".format(self.__class__.__name__))
@ -82,10 +142,34 @@ class ExpanderInfoValidator:
"""ExpanderInfoValidator is the utility class which defines the validator decorator for expanders""" """ExpanderInfoValidator is the utility class which defines the validator decorator for expanders"""
def __init__(self): def __init__(self):
"""
初始化方法
Args:
Returns:
"""
"""Init""" """Init"""
@staticmethod @staticmethod
def _add_check_function(kls, func): def _add_check_function(kls, func):
"""
向类 Expander 中的 `_check` 函数添加新的检查函数 `func`
Args:
kls (type): 需要被修改的类对象应继承自 Expander
func (callable): 要添加的新检查函数该函数接受一个对象作为参数
Returns:
None
Raises:
AttributeError: 如果 kls 类中不存在 `_check` 方法
"""
""" """
Rewrite the function `_check` in class Expander Rewrite the function `_check` in class Expander
to append the new `func` after the original checks. to append the new `func` after the original checks.
@ -93,6 +177,21 @@ class ExpanderInfoValidator:
old_check = getattr(kls, "_check") old_check = getattr(kls, "_check")
def new_check(obj): def new_check(obj):
"""
执行新的检查函数
Args:
obj (Any): 需要检查的对象
Returns:
None
Raises:
None
这个函数首先调用旧版本的检查函数 `old_check` 对传入的对象 `obj` 进行检查
然后调用自定义的函数 `func` 对该对象进行处理
"""
old_check(obj) old_check(obj)
func(obj) func(obj)
@ -103,6 +202,34 @@ class ExpanderInfoValidator:
""" """
Add new supported format for the operator Add new supported format for the operator
Args:
*input_format: A variable number of arguments representing the new supported formats.
Returns:
A wrapper function that adds the specified formats to the operator's supported formats list.
Raises:
GKException: Raised if the length of the registered format list does not match the length of the input formats,
or if the input formats do not match any registered format.
Exception: Raised if the wrapped class is not a subclass of Expander.
Description:
This function adds a list `__supported_formats` to the expander,
which contains the whitelist of formats supported by the operator.
It also rewrites the `_check` function to check the formats.
Example:
python
@add_format("text", "image")
class MyOperator(Expander):
pass
```
In this example, `MyOperator` will support the "text" and "image" formats.
"""
"""
Add new supported format for the operator
this function will add a list `__supported_formats` into the expander, this function will add a list `__supported_formats` into the expander,
saving the whitelist of formats that this op supports. saving the whitelist of formats that this op supports.
it also rewrites the `_check` function to check the formats. it also rewrites the `_check` function to check the formats.
@ -110,6 +237,19 @@ class ExpanderInfoValidator:
format_list_name = "__supported_formats" format_list_name = "__supported_formats"
def _check_format(obj): def _check_format(obj):
"""
检查对象的输入格式是否与已注册的格式匹配
Args:
obj (object): 需要检查的对象
Raises:
GKException: 如果输入格式与已注册的格式不匹配则引发异常
Returns:
None
"""
inp_formats = [inp['format'] for inp in obj.inputs] inp_formats = [inp['format'] for inp in obj.inputs]
for formats in getattr(obj, format_list_name): for formats in getattr(obj, format_list_name):
if len(formats) != len(inp_formats): if len(formats) != len(inp_formats):
@ -120,6 +260,18 @@ class ExpanderInfoValidator:
raise GKException("Unregistered format ({}) for op {}".format(','.join(inp_formats), obj.name)) raise GKException("Unregistered format ({}) for op {}".format(','.join(inp_formats), obj.name))
def wrapper(cls): def wrapper(cls):
"""
为给定的类添加包装功能
Args:
cls: 需要被包装的类必须继承自 Expander
Returns:
返回包装后的类
Raises:
Exception: 如果 cls 不是 Expander 的子类则抛出异常
"""
if not issubclass(cls, Expander): if not issubclass(cls, Expander):
raise Exception("{} should be subclass of Expander.".format(cls.__name__)) raise Exception("{} should be subclass of Expander.".format(cls.__name__))
if not hasattr(cls, format_list_name): if not hasattr(cls, format_list_name):
@ -132,11 +284,49 @@ class ExpanderInfoValidator:
@staticmethod @staticmethod
def check_all_formats_same(kls): def check_all_formats_same(kls):
"""
检查所有格式是否相同
Args:
kls: 待检查的类
Returns:
返回传入的类 kls并在类上注册一个检查函数用于验证该类所有输入格式是否一致
Raises:
Exception: 如果传入的类 kls 不是 Expander 的子类则抛出异常
GKException: 如果 kls 类中的输入格式不一致则抛出异常并显示不匹配格式的信息
"""
"""Check that all formats are the same""" """Check that all formats are the same"""
# Ensure no args case can return a class # Ensure no args case can return a class
def _check(*args): def _check(*args):
"""
检查操作输入格式是否一致的装饰器
Args:
*args: 可变参数装饰器可以接收任意数量的参数
Returns:
wrapper: 返回一个装饰器函数用于包装类
Raises:
GKException: 如果所有输入的格式不一致抛出GKException异常
Exception: 如果被装饰的类不是Expander的子类抛出异常
"""
def _check_format(obj): def _check_format(obj):
"""
检查输入格式是否一致
Args:
obj (Any): 包含输入信息的对象
Raises:
GKException: 如果所有输入格式不一致则抛出异常并包含不匹配格式的具体信息
"""
inp_formats = [inp['format'] for inp in obj.inputs] inp_formats = [inp['format'] for inp in obj.inputs]
if all((fmt == inp_formats[0] for fmt in inp_formats[1:])): if all((fmt == inp_formats[0] for fmt in inp_formats[1:])):
return return
@ -144,6 +334,19 @@ class ExpanderInfoValidator:
','.join(inp_formats), obj.name)) ','.join(inp_formats), obj.name))
def wrapper(cls): def wrapper(cls):
"""
将给定类包装为 Expander 的子类并进行格式检查
Args:
cls (class): 需要包装的类
Returns:
class: 包装后的类
Raises:
Exception: 如果 cls 不是 Expander 的子类则抛出异常
"""
if not issubclass(cls, Expander): if not issubclass(cls, Expander):
raise Exception("{} should be subclass of Expander.".format(cls.__name__)) raise Exception("{} should be subclass of Expander.".format(cls.__name__))
ExpanderInfoValidator._add_check_function(cls, _check_format) ExpanderInfoValidator._add_check_function(cls, _check_format)
@ -155,14 +358,53 @@ class ExpanderInfoValidator:
@staticmethod @staticmethod
def check_attrs(*args): def check_attrs(*args):
"""
检查属性是否存在
Args:
*args: 一个或多个属性名用于检查对象是否具有这些属性
Returns:
一个装饰器函数该装饰器函数用于验证类是否具有指定的属性
Raises:
GKException: 如果对象不具有指定的属性则抛出该异常
Exception: 如果被装饰的类不是 Expander 的子类则抛出该异常
"""
"""Check the attrs exist""" """Check the attrs exist"""
def _check_attr(obj): def _check_attr(obj):
"""
检查对象是否具有指定的属性
Args:
obj (object): 要检查的对象
Raises:
GKException: 如果对象不具有指定的属性则抛出异常
Returns:
None
"""
for a in args: for a in args:
if a not in obj.attrs: if a not in obj.attrs:
raise GKException("attr '{}' does not exist.".format(a)) raise GKException("attr '{}' does not exist.".format(a))
def wrapper(cls): def wrapper(cls):
"""
对类进行包装确保该类是 Expander 的子类并添加属性检查功能
Args:
cls (class): 需要包装的类
Returns:
class: 包装后的类
Raises:
Exception: 如果 cls 不是 Expander 的子类则抛出异常
"""
if not issubclass(cls, Expander): if not issubclass(cls, Expander):
raise Exception("{} should be subclass of Expander.".format(cls.__name__)) raise Exception("{} should be subclass of Expander.".format(cls.__name__))
ExpanderInfoValidator._add_check_function(cls, _check_attr) ExpanderInfoValidator._add_check_function(cls, _check_attr)
@ -172,6 +414,21 @@ class ExpanderInfoValidator:
def to_frac_z_axis(ori_shape, ori_axis): def to_frac_z_axis(ori_shape, ori_axis):
"""
判断是否为分形NZ格式
Args:
----
ori_shape: list or tuple
输入的原始形状
ori_axis: list or tuple
操作的原始形状的轴
Returns:
-------
output: list
分形Nz形状的轴
"""
""" """
judge the format is fractal NZ judge the format is fractal NZ
Parameters Parameters
@ -208,6 +465,16 @@ def to_frac_z_axis(ori_shape, ori_axis):
def infer_shape_from_fractalnz(fractal): def infer_shape_from_fractalnz(fractal):
"""
从fractalnz形状推断原始形状
Args:
fractal (list): fractalnz形状一个包含形状的列表
Returns:
list: 推断出的原始形状
"""
"get original shape from fractalnz shape" "get original shape from fractalnz shape"
shape = [] shape = []
dims = len(fractal) dims = len(fractal)
@ -222,6 +489,17 @@ def infer_shape_from_fractalnz(fractal):
def get_reduced_ori_shape(shape, axis): def get_reduced_ori_shape(shape, axis):
"""
获取基于原始形状的降维后的形状
Args:
shape (List[int]): 原始形状是一个整数列表
axis (List[int]): 需要降维的轴索引列表
Returns:
List[int]: 降维后的形状是一个整数列表
"""
"get shape after reduced which is based on original shape" "get shape after reduced which is based on original shape"
reduced_ori_shape = [] reduced_ori_shape = []
for i, value in enumerate(shape): for i, value in enumerate(shape):
@ -233,6 +511,25 @@ def get_reduced_ori_shape(shape, axis):
def get_reduce_axis_shape(shape, data_format, axis): def get_reduce_axis_shape(shape, data_format, axis):
"""
根据给定的输入形状数据格式和轴获取在指定格式下的归约轴和原始的归约形状
Args:
-----
shape: list or tuple
输入的形状
data_format: str
输入的数据格式
axis: None, int, list or tuple
在原始形状下的归约轴
Returns:
--------
reduce_axis: list
在指定数据格式下的归约轴
ori_reduced_shape: list
原始的归约形状
"""
""" """
Get the reduce axis under format `data_format` and original reduced shape. Get the reduce axis under format `data_format` and original reduced shape.
Parameters Parameters

@ -13,20 +13,47 @@
# limitations under the License. # limitations under the License.
# =========================================================================== # ===========================================================================
"""generate json desc for addn""" """generate json desc for addn"""
# 导入GraphKernelUnsupportedException异常类
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
# 导入Expander和ExpanderInfoValidator类
from ._utils import Expander, ExpanderInfoValidator as VLD from ._utils import Expander, ExpanderInfoValidator as VLD
# 使用VLD.check_all_formats_same装饰器确保所有输入格式相同
@VLD.check_all_formats_same @VLD.check_all_formats_same
class AddN(Expander): class AddN(Expander):
"""Expand AddN to multiple Adds""" """Expand AddN to multiple Adds"""
# 检查输入数量是否大于1
def _check(self): def _check(self):
"""
检查输入的数量是否满足要求
Args:
Returns:
Raises:
GKException: 如果输入的数量小于2则抛出GKException异常
"""
if len(self.inputs) < 2: if len(self.inputs) < 2:
raise GKException("For 'AddN', the inputs num should be greater than 1, but got {}" raise GKException("For 'AddN', the inputs num should be greater than 1, but got {}"
.format(len(self.inputs))) .format(len(self.inputs)))
# 将AddN展开为多个Add操作
def _expand(self, graph_builder): def _expand(self, graph_builder):
"""
对输入张量进行逐元素加法运算
Args:
graph_builder (GraphBuilder): 图构建器对象用于生成图节点
Returns:
Tensor: 逐元素加法运算后的结果张量
"""
result = self.inputs[0] result = self.inputs[0]
for inp in self.inputs[1:]: for inp in self.inputs[1:]:
result = graph_builder.emit('Add', [result, inp]) result = graph_builder.emit('Add', [result, inp])

@ -36,15 +36,19 @@ class BatchNorm(Expander):
input_x_ori_type = input_x.dtype input_x_ori_type = input_x.dtype
input_x_new_type = input_x.dtype input_x_new_type = input_x.dtype
# 如果输入数据的类型为float16而scale、offset、mean、variance的类型为float32则将输入数据类型转换为float32
if input_x.dtype == "float16" and input_scale.dtype == "float32" and input_offset.dtype == "float32" and \ if input_x.dtype == "float16" and input_scale.dtype == "float32" and input_offset.dtype == "float32" and \
input_mean.dtype == "float32" and input_variance.dtype == "float32": input_mean.dtype == "float32" and input_variance.dtype == "float32":
input_x_new_type = "float32" input_x_new_type = "float32"
# 如果输入数据类型与原始类型不同,则进行类型转换
if input_x_new_type != input_x_ori_type: if input_x_new_type != input_x_ori_type:
input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': input_x_new_type}) input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': input_x_new_type})
# 如果是训练模式
if self.attrs['is_training']: if self.attrs['is_training']:
self.inputs[0] = input_x self.inputs[0] = input_x
res_y, mean_res, variance_res, mean_muls, y_sqrt_rec = self._bn_train(graph_builder) res_y, mean_res, variance_res, mean_muls, y_sqrt_rec = self._bn_train(graph_builder)
# 如果输入数据类型与原始类型不同,则将输出数据类型转换为原始类型
if input_x_new_type != input_x_ori_type: if input_x_new_type != input_x_ori_type:
res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type}) res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type})
return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec
@ -70,21 +74,42 @@ class BatchNorm(Expander):
return res_y, var_add, var_add, var_add, var_add return res_y, var_add, var_add, var_add, var_add
def _bn_train(self, graph_builder): def _bn_train(self, graph_builder):
"""
在训练模式下扩展BatchNorm
Args:
graph_builder (GraphBuilder): 图构建器实例
Returns:
tuple: 包含以下内容的元组:
- res_y (Tensor): 归一化后的输出
- mean_res (Tensor): 更新后的移动均值
- variance_res (Tensor): 更新后的移动方差
- mean_muls (Tensor): 输入数据的均值
- y_sqrt_rec (Tensor): 1 / sqrt(方差 + epsilon)用于反向传播
"""
"""expand BatchNorm for training mode""" """expand BatchNorm for training mode"""
# 获取输入数据
input_x = self.inputs[0] input_x = self.inputs[0]
input_scale = self.inputs[1] input_scale = self.inputs[1]
input_offset = self.inputs[2] input_offset = self.inputs[2]
input_mean = self.inputs[3] input_mean = self.inputs[3]
input_variance = self.inputs[4] input_variance = self.inputs[4]
# 获取epsilon值
epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon']) epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon'])
# 获取reduce轴
reduce_axis = () reduce_axis = ()
# 获取输入数据的形状
shape_x = input_x.shape shape_x = input_x.shape
# 根据输入数据的格式设置reduce轴和num值
if input_x.data_format == DF.NHWC: if input_x.data_format == DF.NHWC:
reduce_axis = (0, 1, 2) reduce_axis = (0, 1, 2)
num = shape_x[0] * shape_x[1] * shape_x[2] num = shape_x[0] * shape_x[1] * shape_x[2]
else: else:
reduce_axis = (0, 2, 3) reduce_axis = (0, 2, 3)
num = shape_x[0] * shape_x[2] * shape_x[3] num = shape_x[0] * shape_x[2] * shape_x[3]
# 计算num的倒数
num_rec = 1.0 / num num_rec = 1.0 / num
num_rec_v = graph_builder.value(input_scale.dtype, num_rec) num_rec_v = graph_builder.value(input_scale.dtype, num_rec)
@ -112,41 +137,67 @@ class BatchNorm(Expander):
y_sqrt_rec = graph_builder.emit('RealDiv', [scalar_one_v, y_sqrt]) y_sqrt_rec = graph_builder.emit('RealDiv', [scalar_one_v, y_sqrt])
# compute res_y # compute res_y
# 计算输入x和mean_muls_expand的差值
tmp_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand]) tmp_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand])
# 如果输入x的数据格式为DF.DEFAULT或DF.NCHW则对y_sqrt_rec进行reshape操作
if input_x.data_format in (DF.DEFAULT, DF.NCHW): if input_x.data_format in (DF.DEFAULT, DF.NCHW):
y_sqrt_rec_expand = graph_builder.emit( y_sqrt_rec_expand = graph_builder.emit(
'Reshape', [y_sqrt_rec], attrs={'shape': ExpandDims.infer_shape(y_sqrt_rec.shape, [-1, -1])}) 'Reshape', [y_sqrt_rec], attrs={'shape': ExpandDims.infer_shape(y_sqrt_rec.shape, [-1, -1])})
# 否则y_sqrt_rec保持不变
else: else:
y_sqrt_rec_expand = y_sqrt_rec y_sqrt_rec_expand = y_sqrt_rec
# 计算tmp_sub和y_sqrt_rec_expand的乘积
y_norm = graph_builder.emit('Mul', [tmp_sub, y_sqrt_rec_expand]) y_norm = graph_builder.emit('Mul', [tmp_sub, y_sqrt_rec_expand])
# 如果输入x的数据格式为DF.DEFAULT或DF.NCHW则对input_scale进行reshape操作
if input_x.data_format in (DF.DEFAULT, DF.NCHW): if input_x.data_format in (DF.DEFAULT, DF.NCHW):
input_scale_expand = graph_builder.emit( input_scale_expand = graph_builder.emit(
'Reshape', [input_scale], attrs={'shape': ExpandDims.infer_shape(input_scale.shape, [-1, -1])}) 'Reshape', [input_scale], attrs={'shape': ExpandDims.infer_shape(input_scale.shape, [-1, -1])})
# 否则input_scale保持不变
else: else:
input_scale_expand = input_scale input_scale_expand = input_scale
# 计算input_scale_expand和y_norm的乘积
res_y_mul = graph_builder.emit('Mul', [input_scale_expand, y_norm]) res_y_mul = graph_builder.emit('Mul', [input_scale_expand, y_norm])
# 如果输入x的数据格式为DF.DEFAULT或DF.NCHW则对input_offset进行reshape操作
if input_x.data_format in (DF.DEFAULT, DF.NCHW): if input_x.data_format in (DF.DEFAULT, DF.NCHW):
input_offset_expand = graph_builder.emit( input_offset_expand = graph_builder.emit(
'Reshape', [input_offset], attrs={'shape': ExpandDims.infer_shape(input_offset.shape, [-1, -1])}) 'Reshape', [input_offset], attrs={'shape': ExpandDims.infer_shape(input_offset.shape, [-1, -1])})
# 否则input_offset保持不变
else: else:
input_offset_expand = input_offset input_offset_expand = input_offset
# 计算res_y_mul和input_offset_expand的和
res_y = graph_builder.emit('Add', [res_y_mul, input_offset_expand]) res_y = graph_builder.emit('Add', [res_y_mul, input_offset_expand])
# compute mean_res # compute mean_res
# 计算动量减去1的值
momentum_sub = scalar_one - self.attrs['momentum'] momentum_sub = scalar_one - self.attrs['momentum']
# 将动量减去1的值转换为输入数据的类型
momentum_v_sub = graph_builder.value(input_scale.dtype, momentum_sub) momentum_v_sub = graph_builder.value(input_scale.dtype, momentum_sub)
# 计算新的移动平均值的临时值
# 计算新的running_mean_tmp
new_running_mean_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_mean]) new_running_mean_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_mean])
# 计算momentum_v
momentum_v = graph_builder.value(input_scale.dtype, self.attrs['momentum']) momentum_v = graph_builder.value(input_scale.dtype, self.attrs['momentum'])
# 计算current_mean_tmp
current_mean_tmp = graph_builder.emit('Mul', [momentum_v, mean_muls]) current_mean_tmp = graph_builder.emit('Mul', [momentum_v, mean_muls])
# 计算updated_moving_mean
updated_moving_mean = graph_builder.emit('Add', [new_running_mean_tmp, current_mean_tmp]) updated_moving_mean = graph_builder.emit('Add', [new_running_mean_tmp, current_mean_tmp])
# 将updated_moving_mean赋值给input_mean
mean_res = graph_builder.emit('Assign', [input_mean, updated_moving_mean]) mean_res = graph_builder.emit('Assign', [input_mean, updated_moving_mean])
# variance_res is calculated by sample variance, and need to multiply by num / (num - 1) # variance_res is calculated by sample variance, and need to multiply by num / (num - 1)
# 计算方差
var_num = float(num) / (num - 1) var_num = float(num) / (num - 1)
# 将方差转换为输入数据的类型
var_num_v = graph_builder.value(input_scale.dtype, var_num) var_num_v = graph_builder.value(input_scale.dtype, var_num)
# 计算方差乘积
var_mul_update = graph_builder.emit('Mul', [var_num_v, var_mul]) var_mul_update = graph_builder.emit('Mul', [var_num_v, var_mul])
# 计算新的移动方差
new_running_var_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_variance]) new_running_var_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_variance])
# 计算当前移动方差
current_var_tmp = graph_builder.emit('Mul', [momentum_v, var_mul_update]) current_var_tmp = graph_builder.emit('Mul', [momentum_v, var_mul_update])
# 更新移动方差
updated_moving_variance = graph_builder.emit('Add', [new_running_var_tmp, current_var_tmp]) updated_moving_variance = graph_builder.emit('Add', [new_running_var_tmp, current_var_tmp])
# 将更新后的移动方差赋值给输入方差
variance_res = graph_builder.emit('Assign', [input_variance, updated_moving_variance]) variance_res = graph_builder.emit('Assign', [input_variance, updated_moving_variance])
# 返回结果
return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec

@ -11,46 +11,59 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# =========================================================================== # ========================================================================
"""generate json desc for BatchNormGrad""" # ===
# 版权声明
# 根据Apache License 2.0授权
# 除非遵守许可,否则不得使用此文件
"""
为BatchNormGrad生成json描述BatchNormGrad是用于计算Batch Normalization层梯度的类
"""
# 导入必要的模块和类
from mindspore._extends.graph_kernel.model.model import DataFormat as DF from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD from ._utils import Expander, ExpanderInfoValidator as VLD
from .expand_dims import ExpandDims from .expand_dims import ExpandDims
# 定义BatchNormGrad类继承自Expander
@VLD.add_format(DF.NHWC, DF.NHWC, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT) @VLD.add_format(DF.NHWC, DF.NHWC, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.NCHW, DF.NCHW, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT) @VLD.add_format(DF.NCHW, DF.NCHW, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT) @VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('is_training', 'epsilon') @VLD.check_attrs('is_training', 'epsilon')
class BatchNormGrad(Expander): class BatchNormGrad(Expander):
"""BatchNormGrad expander""" """BatchNormGrad扩展器用于计算Batch Normalization层的梯度"""
# 定义扩展方法该方法将被调用来执行BatchNormGrad的计算
def _expand(self, graph_builder): def _expand(self, graph_builder):
# get op info # 获取操作信息,包括梯度、输入数据、尺度、保存的均值和倒数方差
input_dy = self.inputs[0] input_dy = self.inputs[0] # 输入数据的梯度
input_x = self.inputs[1] input_x = self.inputs[1] # 输入数据
input_scale = self.inputs[2] input_scale = self.inputs[2] # 输入数据的尺度
input_save_mean = self.inputs[3] input_save_mean = self.inputs[3] # 保存的均值
input_save_inv_variance = self.inputs[4] input_save_inv_variance = self.inputs[4] # 保存的倒数方差
# 根据输入数据的格式计算reduce_axis用于后续的ReduceSum操作
reduce_axis = () reduce_axis = ()
shape_x = input_x.shape shape_x = input_x.shape
if input_x.data_format == DF.NHWC: if input_x.data_format == DF.NHWC: # 如果数据格式为NHWC
reduce_axis = (0, 1, 2) reduce_axis = (0, 1, 2) # 指定ReduceSum的轴
num = shape_x[0] * shape_x[1] * shape_x[2] num = shape_x[0] * shape_x[1] * shape_x[2] # 计算元素总数
else: else: # 否则假设数据格式为NCHW
reduce_axis = (0, 2, 3) reduce_axis = (0, 2, 3) # 指定ReduceSum的轴
num = shape_x[0] * shape_x[2] * shape_x[3] num = shape_x[0] * shape_x[2] * shape_x[3] # 计算元素总数
ori_type = input_x.dtype ori_type = input_x.dtype # 原始数据类型
# 如果原始数据类型为float16则转换为float32进行计算以避免精度损失
if ori_type == 'float16': if ori_type == 'float16':
input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'}) input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'})
if input_dy.dtype == 'float16': if input_dy.dtype == 'float16':
input_dy = graph_builder.emit('Cast', [input_dy], attrs={'dst_type': 'float32'}) input_dy = graph_builder.emit('Cast', [input_dy], attrs={'dst_type': 'float32'})
num_rec = -1.0 / num num_rec = -1.0 / num # 计算倒数
num_rec_v = graph_builder.value(input_scale.dtype, num_rec) num_rec_v = graph_builder.value(input_scale.dtype, num_rec) # 创建倒数的值
dbeta = graph_builder.emit('ReduceSum', [input_dy], attrs={'reduce_axis': reduce_axis, 'keep_dims': False}) dbeta = graph_builder.emit('ReduceSum', [input_dy], attrs={'reduce_axis': reduce_axis, 'keep_dims': False}) # 计算dbeta即beta的梯度
# in training input_save_inv_variance means 1 / sqrt(variance + epsilon), which is calculated in forward pass # 根据是否在训练中计算inv_variance倒数方差
if self.attrs['is_training']: if self.attrs['is_training']:
inv_variance = input_save_inv_variance inv_variance = input_save_inv_variance
else: else:
@ -61,7 +74,7 @@ class BatchNormGrad(Expander):
scalar_one_v = graph_builder.value(input_scale.dtype, scalar_one) scalar_one_v = graph_builder.value(input_scale.dtype, scalar_one)
inv_variance = graph_builder.emit('RealDiv', [scalar_one_v, sqrt_var_eps]) inv_variance = graph_builder.emit('RealDiv', [scalar_one_v, sqrt_var_eps])
# compute dgamma # 计算dgammagamma的梯度
if input_x.data_format in (DF.DEFAULT, DF.NCHW): if input_x.data_format in (DF.DEFAULT, DF.NCHW):
input_save_mean = graph_builder.emit( input_save_mean = graph_builder.emit(
'Reshape', [input_save_mean], attrs={'shape': ExpandDims.infer_shape(input_save_mean.shape, [-1, -1])}) 'Reshape', [input_save_mean], attrs={'shape': ExpandDims.infer_shape(input_save_mean.shape, [-1, -1])})
@ -69,13 +82,13 @@ class BatchNormGrad(Expander):
'Reshape', [inv_variance], attrs={'shape': ExpandDims.infer_shape(inv_variance.shape, [-1, -1])}) 'Reshape', [inv_variance], attrs={'shape': ExpandDims.infer_shape(inv_variance.shape, [-1, -1])})
input_scale = graph_builder.emit( input_scale = graph_builder.emit(
'Reshape', [input_scale], attrs={'shape': ExpandDims.infer_shape(input_scale.shape, [-1, -1])}) 'Reshape', [input_scale], attrs={'shape': ExpandDims.infer_shape(input_scale.shape, [-1, -1])})
x_sub_mean = graph_builder.emit('Sub', [input_x, input_save_mean]) x_sub_mean = graph_builder.emit('Sub', [input_x, input_save_mean]) # 计算x减去均值
x_div = graph_builder.emit('Mul', [x_sub_mean, inv_variance]) x_div = graph_builder.emit('Mul', [x_sub_mean, inv_variance]) # 计算x除以倒数方差
dgamma_param = graph_builder.emit('Mul', [input_dy, x_div]) dgamma_param = graph_builder.emit('Mul', [input_dy, x_div]) # 计算dgamma参数
dgamma = graph_builder.emit( dgamma = graph_builder.emit(
'ReduceSum', [dgamma_param], attrs={'reduce_axis': reduce_axis, 'keep_dims': False}) 'ReduceSum', [dgamma_param], attrs={'reduce_axis': reduce_axis, 'keep_dims': False}) # 计算dgamma
# compute dx # 计算dxx的梯度
if self.attrs['is_training']: if self.attrs['is_training']:
tmp_b = graph_builder.emit('Mul', [num_rec_v, dbeta]) tmp_b = graph_builder.emit('Mul', [num_rec_v, dbeta])
if input_x.data_format in (DF.DEFAULT, DF.NCHW): if input_x.data_format in (DF.DEFAULT, DF.NCHW):
@ -95,11 +108,12 @@ class BatchNormGrad(Expander):
y_scale = graph_builder.emit('Mul', [input_scale, input_dy]) y_scale = graph_builder.emit('Mul', [input_scale, input_dy])
dx = graph_builder.emit('Mul', [inv_variance, y_scale]) dx = graph_builder.emit('Mul', [inv_variance, y_scale])
if ori_type == 'float16': if ori_type == 'float16':
dx = graph_builder.emit('Cast', [dx], attrs={'dst_type': 'float16'}) dx = graph_builder.emit('Cast', [dx], attrs={'dst_type': 'float16'}) # 如果原始数据类型为float16则转换回float16
# set output tensors' data_format # 设置输出张量的数据格式
dx.data_format = self.outputs[0]['format'] dx.data_format = self.outputs[0]['format']
dgamma.data_format = self.outputs[1]['format'] dgamma.data_format = self.outputs[1]['format']
dbeta.data_format = self.outputs[2]['format'] dbeta.data_format = self.outputs[2]['format']
# 返回计算结果
return dx, dgamma, dbeta return dx, dgamma, dbeta

@ -13,37 +13,64 @@
# limitations under the License. # limitations under the License.
# =========================================================================== # ===========================================================================
"""generate json desc for bias_add""" """generate json desc for bias_add"""
# 导入MindSpore的DataFormat类
from mindspore._extends.graph_kernel.model.model import DataFormat as DF from mindspore._extends.graph_kernel.model.model import DataFormat as DF
# 导入Expander和ExpanderInfoValidator类
from ._utils import Expander, ExpanderInfoValidator as VLD from ._utils import Expander, ExpanderInfoValidator as VLD
# 为BiasAddGrad类添加DF.DEFAULT、DF.NHWC、DF.NCHW、DF.FRAC_NZ格式的验证
@VLD.add_format(DF.DEFAULT) @VLD.add_format(DF.DEFAULT)
@VLD.add_format(DF.NHWC) @VLD.add_format(DF.NHWC)
@VLD.add_format(DF.NCHW) @VLD.add_format(DF.NCHW)
@VLD.add_format(DF.FRAC_NZ) @VLD.add_format(DF.FRAC_NZ)
# 定义BiasAddGrad类继承自Expander类
class BiasAddGrad(Expander): class BiasAddGrad(Expander):
"""BiasAddGrad expander""" """BiasAddGrad expander"""
def _expand(self, graph_builder): def _expand(self, graph_builder):
"""
内部方法用于扩展输入张量的维度
Args:
graph_builder (GraphBuilder): 图构建器对象用于生成图操作
Returns:
Tensor: 扩展维度后的张量
"""
# 获取输入张量
x = self.inputs[0] x = self.inputs[0]
# 定义reduce_axis用于指定求和的维度
reduce_axis = () reduce_axis = ()
# 如果输入张量的数据格式为NHWC则reduce_axis为(0, 1, 2)
if x.data_format == DF.NHWC: if x.data_format == DF.NHWC:
reduce_axis = (0, 1, 2) reduce_axis = (0, 1, 2)
# 如果输入张量的数据格式为NCHW则reduce_axis为(0, 2, 3)
elif x.data_format == DF.NCHW: elif x.data_format == DF.NCHW:
reduce_axis = (0, 2, 3) reduce_axis = (0, 2, 3)
# 如果输入张量的数据格式为FRAC_NZ则reduce_axis为(-2, -3)
elif x.data_format == DF.FRAC_NZ: elif x.data_format == DF.FRAC_NZ:
reduce_axis = (-2, -3) reduce_axis = (-2, -3)
# 如果输入张量的数据格式为DefaultFormat则根据shape的长度确定reduce_axis
else: else:
# DefaultFormat shape's length should be from 2 to 4 # DefaultFormat shape's length should be from 2 to 4
# 如果x的维度为2则reduce_axis为(0,)
if len(x.shape) == 2: if len(x.shape) == 2:
reduce_axis = (0,) reduce_axis = (0,)
# 如果x的维度为3则reduce_axis为(0, 1)
elif len(x.shape) == 3: elif len(x.shape) == 3:
reduce_axis = (0, 1) reduce_axis = (0, 1)
# 否则reduce_axis为(0, 2, 3)
else: else:
reduce_axis = (0, 2, 3) reduce_axis = (0, 2, 3)
# 发射ReduceSum操作计算x的reduce_sumreduce_axis为reduce_axiskeep_dims为False
result = graph_builder.emit('ReduceSum', [x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False}) result = graph_builder.emit('ReduceSum', [x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
# 如果x的数据格式为DF.FRAC_NZ则将result的shape改为x.shape[:-4] + [x.shape[-1] * x.shape[-4]]
if x.data_format == DF.FRAC_NZ: if x.data_format == DF.FRAC_NZ:
out_shape = x.shape[:-4] + [x.shape[-1] * x.shape[-4]] out_shape = x.shape[:-4] + [x.shape[-1] * x.shape[-4]]
# 发射Reshape操作将result的shape改为out_shape
result = graph_builder.emit('Reshape', [result], attrs={'shape': out_shape}) result = graph_builder.emit('Reshape', [result], attrs={'shape': out_shape})
# 返回result
return result return result

@ -21,13 +21,28 @@ class ClipByNormNoDivSum(Expander):
"""ClipByNormNoDivSum expander""" """ClipByNormNoDivSum expander"""
def _expand(self, graph_builder): def _expand(self, graph_builder):
"""
对输入的张量进行计算返回计算结果
Args:
graph_builder (GraphBuilder): 图构建器对象用于生成计算图
Returns:
Tensor: 计算结果张量
"""
input_x0, input_x1, input_x2, input_x3 = self.inputs input_x0, input_x1, input_x2, input_x3 = self.inputs
# cal result # cal result
# 计算大于结果
greater_res = graph_builder.emit('Greater', [input_x0, input_x1]) greater_res = graph_builder.emit('Greater', [input_x0, input_x1])
# 根据大于结果选择input_x0或input_x2
select_res0 = graph_builder.emit('Select', [greater_res, input_x0, input_x2]) select_res0 = graph_builder.emit('Select', [greater_res, input_x0, input_x2])
# 计算select_res0的平方根
sqrt_res = graph_builder.emit('Sqrt', [select_res0]) sqrt_res = graph_builder.emit('Sqrt', [select_res0])
# 根据大于结果选择sqrt_res或input_x0
select_res1 = graph_builder.emit('Select', [greater_res, sqrt_res, input_x0]) select_res1 = graph_builder.emit('Select', [greater_res, sqrt_res, input_x0])
# 计算select_res1和input_x3的最大值
result = graph_builder.emit('Maximum', [select_res1, input_x3]) result = graph_builder.emit('Maximum', [select_res1, input_x3])
return result return result

@ -14,8 +14,13 @@
# ============================================================================ # ============================================================================
"""complex expanders init""" """complex expanders init"""
# 从当前目录下的abs模块中导入CAbs类
from .abs import CAbs from .abs import CAbs
# 从当前目录下的add模块中导入CAdd类
from .add import CAdd from .add import CAdd
# 从当前目录下的div模块中导入CDiv类
from .div import CDiv from .div import CDiv
# 从当前目录下的mul模块中导入CMul类
from .mul import CMul from .mul import CMul
# 从当前目录下的sub模块中导入CSub类
from .sub import CSub from .sub import CSub

@ -20,11 +20,23 @@ class CAbs(Expander):
"""CAbs expander""" """CAbs expander"""
def _expand(self, graph_builder): def _expand(self, graph_builder):
# 获取输入的第一个元素
input_x = self.inputs[0] input_x = self.inputs[0]
# 发射指令CReal将输入x的实部提取出来
x_real = graph_builder.emit('CReal', [input_x]) x_real = graph_builder.emit('CReal', [input_x])
# 发射指令CImag将输入x的虚部提取出来
x_imag = graph_builder.emit('CImag', [input_x]) x_imag = graph_builder.emit('CImag', [input_x])
# 发射指令Mul计算x的实部的平方
squre_x_real = graph_builder.emit('Mul', [x_real, x_real]) squre_x_real = graph_builder.emit('Mul', [x_real, x_real])
# 发射指令Mul计算x的虚部的平方
squre_x_imag = graph_builder.emit('Mul', [x_imag, x_imag]) squre_x_imag = graph_builder.emit('Mul', [x_imag, x_imag])
# 发射指令Add计算实部和虚部的平方和
squre_sum = graph_builder.emit('Add', [squre_x_real, squre_x_imag]) squre_sum = graph_builder.emit('Add', [squre_x_real, squre_x_imag])
# 发射指令Sqrt计算平方和的平方根
result = graph_builder.emit('Sqrt', [squre_sum]) result = graph_builder.emit('Sqrt', [squre_sum])
return result return result

@ -22,12 +22,35 @@ class CAdd(Expander):
"""CAdd expander""" """CAdd expander"""
def _expand(self, graph_builder): def _expand(self, graph_builder):
"""
将两个复数相加
Args:
graph_builder (GraphBuilder): 图构建器对象
Returns:
Node: 相加后的复数结果
"""
# 获取输入参数
input_x, input_y = self.inputs input_x, input_y = self.inputs
# 将输入参数x转换为实数部分
x_real = graph_builder.emit('CReal', [input_x]) x_real = graph_builder.emit('CReal', [input_x])
# 将输入参数y转换为实数部分
y_real = graph_builder.emit('CReal', [input_y]) y_real = graph_builder.emit('CReal', [input_y])
# 将输入参数x转换为虚数部分
x_imag = graph_builder.emit('CImag', [input_x]) x_imag = graph_builder.emit('CImag', [input_x])
# 将输入参数y转换为虚数部分
y_imag = graph_builder.emit('CImag', [input_y]) y_imag = graph_builder.emit('CImag', [input_y])
# 将x和y的实数部分相加
result_real = graph_builder.emit('Add', [x_real, y_real]) result_real = graph_builder.emit('Add', [x_real, y_real])
# 将x和y的虚数部分相加
result_imag = graph_builder.emit('Add', [x_imag, y_imag]) result_imag = graph_builder.emit('Add', [x_imag, y_imag])
# 将相加后的实数部分和虚数部分组合为复数
result = graph_builder.emit('Complex', [result_real, result_imag]) result = graph_builder.emit('Complex', [result_real, result_imag])
return result return result

@ -22,22 +22,43 @@ class CDiv(Expander):
"""CDiv expander""" """CDiv expander"""
def _expand(self, graph_builder): def _expand(self, graph_builder):
"""
CDiv Implementation
Args:
graph_builder: 图构建器对象用于构建计算图
Returns:
返回复数除法结果
实现复数除法CDiv操作
获取两个输入的复数分别计算它们的实部和虚部
然后计算分母和分子的实部和虚部并进行除法运算
最后将得到的商的实部和虚部合并为复数结果返回
"""
"""CDiv Implementation""" """CDiv Implementation"""
# 获取输入的两个复数
input_x, input_y = self.inputs input_x, input_y = self.inputs
x_real = graph_builder.emit('CReal', [input_x]) # 获取输入复数的实部
y_real = graph_builder.emit('CReal', [input_y]) x_real = graph_builder.emit('CReal', [input_x]) # 发射 CReal 操作获取 input_x 的实部
x_imag = graph_builder.emit('CImag', [input_x]) y_real = graph_builder.emit('CReal', [input_y]) # 发射 CReal 操作获取 input_y 的实部
y_imag = graph_builder.emit('CImag', [input_y]) # 获取输入复数的虚部
squre_y_real = graph_builder.emit('Mul', [y_real, y_real]) x_imag = graph_builder.emit('CImag', [input_x]) # 发射 CImag 操作获取 input_x 的虚部
squre_y_imag = graph_builder.emit('Mul', [y_imag, y_imag]) y_imag = graph_builder.emit('CImag', [input_y]) # 发射 CImag 操作获取 input_y 的虚部
final_denominator = graph_builder.emit('Add', [squre_y_real, squre_y_imag]) # 计算分母
x_real_mul_y_real = graph_builder.emit('Mul', [x_real, y_real]) squre_y_real = graph_builder.emit('Mul', [y_real, y_real]) # 计算 y_real 的平方
x_imag_mul_y_imag = graph_builder.emit('Mul', [x_imag, y_imag]) squre_y_imag = graph_builder.emit('Mul', [y_imag, y_imag]) # 计算 y_imag 的平方
x_real_mul_y_imag = graph_builder.emit('Mul', [x_real, y_imag]) final_denominator = graph_builder.emit('Add', [squre_y_real, squre_y_imag]) # 计算分母
x_imag_mul_y_real = graph_builder.emit('Mul', [x_imag, y_real]) # 计算分子
final_numerator_real = graph_builder.emit('Add', [x_real_mul_y_real, x_imag_mul_y_imag]) x_real_mul_y_real = graph_builder.emit('Mul', [x_real, y_real]) # 计算 x_real 和 y_real 的乘积
final_numerator_imag = graph_builder.emit('Sub', [x_imag_mul_y_real, x_real_mul_y_imag]) x_imag_mul_y_imag = graph_builder.emit('Mul', [x_imag, y_imag]) # 计算 x_imag 和 y_imag 的乘积
result_real = graph_builder.emit('RealDiv', [final_numerator_real, final_denominator]) x_real_mul_y_imag = graph_builder.emit('Mul', [x_real, y_imag]) # 计算 x_real 和 y_imag 的乘积
result_imag = graph_builder.emit('RealDiv', [final_numerator_imag, final_denominator]) x_imag_mul_y_real = graph_builder.emit('Mul', [x_imag, y_real]) # 计算 x_imag 和 y_real 的乘积
result = graph_builder.emit('Complex', [result_real, result_imag]) final_numerator_real = graph_builder.emit('Add', [x_real_mul_y_real, x_imag_mul_y_imag]) # 计算分子的实部
final_numerator_imag = graph_builder.emit('Sub', [x_imag_mul_y_real, x_real_mul_y_imag]) # 计算分子的虚部
# 计算商
result_real = graph_builder.emit('RealDiv', [final_numerator_real, final_denominator]) # 计算商的实部
result_imag = graph_builder.emit('RealDiv', [final_numerator_imag, final_denominator]) # 计算商的虚部
# 将商合并为复数结果
result = graph_builder.emit('Complex', [result_real, result_imag]) # 将实部和虚部合并为复数结果
return result return result

@ -22,17 +22,45 @@ class CMul(Expander):
"""CMul expander""" """CMul expander"""
def _expand(self, graph_builder): def _expand(self, graph_builder):
"""
计算两个复数的乘积
Args:
graph_builder (GraphBuilder): 图构建器对象用于在图中生成操作节点
Returns:
Result: 计算得到的复数乘积结果
"""
"""CMul Implementation""" """CMul Implementation"""
# 获取输入的两个复数
input_x, input_y = self.inputs input_x, input_y = self.inputs
x_real = graph_builder.emit('CReal', [input_x])
y_real = graph_builder.emit('CReal', [input_y]) # 获取输入复数的实部
x_imag = graph_builder.emit('CImag', [input_x]) x_real = graph_builder.emit('CReal', [input_x]) # 发射指令获取input_x的实部
y_imag = graph_builder.emit('CImag', [input_y]) y_real = graph_builder.emit('CReal', [input_y]) # 发射指令获取input_y的实部
x_real_mul_y_real = graph_builder.emit('Mul', [x_real, y_real])
x_imag_mul_y_imag = graph_builder.emit('Mul', [x_imag, y_imag]) # 获取输入复数的虚部
x_real_mul_y_imag = graph_builder.emit('Mul', [x_real, y_imag]) x_imag = graph_builder.emit('CImag', [input_x]) # 发射指令获取input_x的虚部
x_imag_mul_y_real = graph_builder.emit('Mul', [x_imag, y_real]) y_imag = graph_builder.emit('CImag', [input_y]) # 发射指令获取input_y的虚部
result_real = graph_builder.emit('Sub', [x_real_mul_y_real, x_imag_mul_y_imag])
result_imag = graph_builder.emit('Add', [x_real_mul_y_imag, x_imag_mul_y_real]) # 计算实部与实部的乘积
result = graph_builder.emit('Complex', [result_real, result_imag]) x_real_mul_y_real = graph_builder.emit('Mul', [x_real, y_real]) # 发射指令计算x_real与y_real的乘积
# 计算虚部与虚部的乘积
x_imag_mul_y_imag = graph_builder.emit('Mul', [x_imag, y_imag]) # 发射指令计算x_imag与y_imag的乘积
# 计算实部与虚部的乘积
x_real_mul_y_imag = graph_builder.emit('Mul', [x_real, y_imag]) # 发射指令计算x_real与y_imag的乘积
x_imag_mul_y_real = graph_builder.emit('Mul', [x_imag, y_real]) # 发射指令计算x_imag与y_real的乘积
# 计算复数的实部
result_real = graph_builder.emit('Sub', [x_real_mul_y_real, x_imag_mul_y_imag]) # 发射指令计算实部结果
# 计算复数的虚部
result_imag = graph_builder.emit('Add', [x_real_mul_y_imag, x_imag_mul_y_real]) # 发射指令计算虚部结果
# 构造复数结果
result = graph_builder.emit('Complex', [result_real, result_imag]) # 发射指令构造复数结果
return result return result

@ -22,12 +22,38 @@ class CSub(Expander):
"""CSub expander""" """CSub expander"""
def _expand(self, graph_builder): def _expand(self, graph_builder):
"""
计算两个复数的差
Args:
graph_builder (GraphBuilder): 图构建器对象用于生成计算图
Returns:
Tensor: 计算得到的复数差的结果
"""
# 获取输入
input_x, input_y = self.inputs input_x, input_y = self.inputs
# 提取输入x的实部
x_real = graph_builder.emit('CReal', [input_x]) x_real = graph_builder.emit('CReal', [input_x])
# 提取输入y的实部
y_real = graph_builder.emit('CReal', [input_y]) y_real = graph_builder.emit('CReal', [input_y])
# 提取输入x的虚部
x_imag = graph_builder.emit('CImag', [input_x]) x_imag = graph_builder.emit('CImag', [input_x])
# 提取输入y的虚部
y_imag = graph_builder.emit('CImag', [input_y]) y_imag = graph_builder.emit('CImag', [input_y])
# 计算实部之差
result_real = graph_builder.emit('Sub', [x_real, y_real]) result_real = graph_builder.emit('Sub', [x_real, y_real])
# 计算虚部之差
result_imag = graph_builder.emit('Sub', [x_imag, y_imag]) result_imag = graph_builder.emit('Sub', [x_imag, y_imag])
# 将实部和虚部组合成复数结果
result = graph_builder.emit('Complex', [result_real, result_imag]) result = graph_builder.emit('Complex', [result_real, result_imag])
return result return result

@ -18,6 +18,7 @@ from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
from ._utils import Expander, ExpanderInfoValidator as VLD from ._utils import Expander, ExpanderInfoValidator as VLD
# 定义常量
M_ALIGN = 32 M_ALIGN = 32
N_ALIGN = 32 N_ALIGN = 32
K_ALIGN = 16 K_ALIGN = 16
@ -29,6 +30,7 @@ C_CHANNEL_ALIGN = 16
OUT_NHW_ALIGN = 128 OUT_NHW_ALIGN = 128
# 添加格式验证
@VLD.add_format(DF.DEFAULT, DF.DEFAULT) @VLD.add_format(DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.NHWC, DF.NHWC) @VLD.add_format(DF.NHWC, DF.NHWC)
@VLD.check_attrs('format', 'pad_list', 'pad_mode', 'groups', 'group', 'kernel_size', 'stride', 'dilation') @VLD.check_attrs('format', 'pad_list', 'pad_mode', 'groups', 'group', 'kernel_size', 'stride', 'dilation')
@ -47,6 +49,16 @@ class Conv2D(Expander):
""" """
def __init__(self, expand_info): def __init__(self, expand_info):
"""
类的构造函数
Args:
expand_info (dict): 扩展信息字典包含一些扩展的配置参数
Returns:
None
"""
super().__init__(expand_info) super().__init__(expand_info)
self.dst_type = self.outputs[0]['data_type'] self.dst_type = self.outputs[0]['data_type']
self.dst_format = self.outputs[0]['format'] self.dst_format = self.outputs[0]['format']
@ -59,6 +71,19 @@ class Conv2D(Expander):
self.k = 0 self.k = 0
def _optimize_to_matmul(self): def _optimize_to_matmul(self):
"""
检查是否可以将Conv2D优化为MatMul
Args:
Returns:
bool: 如果可以将Conv2D优化为MatMul则返回True否则返回False
"""
"""
Check if the Conv2D can be optimized to MatMul.
"""
stride = self.attrs['stride'] stride = self.attrs['stride']
dilation = self.attrs['dilation'] dilation = self.attrs['dilation']
_, h, w, _ = self.inputs[1]['shape'] _, h, w, _ = self.inputs[1]['shape']
@ -68,6 +93,18 @@ class Conv2D(Expander):
return False return False
def _common_check(self): def _common_check(self):
"""
对输入和属性的通用检查
Args:
Returns:
Raises:
GKException: 如果输入数据类型不是 float16或者输入格式不是 NHWC或者属性 groups group 不是 1或者属性 dilation 不是 [1, 1, 1, 1]抛出异常
"""
"""common check for inputs and attrs""" """common check for inputs and attrs"""
type_0 = self.inputs[0]['data_type'] type_0 = self.inputs[0]['data_type']
type_1 = self.inputs[1]['data_type'] type_1 = self.inputs[1]['data_type']
@ -91,26 +128,52 @@ class Conv2D(Expander):
.format(dilation)) .format(dilation))
def _check(self): def _check(self):
"""
检查卷积2D操作的参数和输入是否合法
Args:
Raises:
GKException: 当输入参数或输入维度不满足要求时抛出异常
Returns:
"""
# 调用_common_check()方法
self._common_check() self._common_check()
# 获取pad_list
pad_list = self.attrs['pad_list'] pad_list = self.attrs['pad_list']
# 检查pad_list的维度是否为4
check_nd(pad_list, 4) check_nd(pad_list, 4)
# 调用conv_had_pad()方法判断是否有pad
self.has_pad = conv_had_pad(pad_list, self.attrs['pad_mode']) self.has_pad = conv_had_pad(pad_list, self.attrs['pad_mode'])
# 获取输入的shape
shape_0 = self.inputs[0]['shape'] shape_0 = self.inputs[0]['shape']
shape_1 = self.inputs[1]['shape'] shape_1 = self.inputs[1]['shape']
# 获取stride
stride = self.attrs['stride'] stride = self.attrs['stride']
# 检查shape_0的维度是否为4
check_nd(shape_0, 4) check_nd(shape_0, 4)
# 检查shape_1的维度是否为4
check_nd(shape_1, 4) check_nd(shape_1, 4)
# 检查stride的维度是否为4
check_nd(stride, 4) check_nd(stride, 4)
# 获取shape_0的各个维度
n0, h0, w0, c0 = shape_0 n0, h0, w0, c0 = shape_0
# 获取shape_1的各个维度
n1, h1, w1, c1 = shape_1 n1, h1, w1, c1 = shape_1
# 检查n0是否为N0_CHANNEL_ALIGN的倍数
if (n0 % N0_CHANNEL_ALIGN) != 0: if (n0 % N0_CHANNEL_ALIGN) != 0:
raise GKException("For 'Conv2D', N channel of first input should be multiples of {}, but got {}" raise GKException("For 'Conv2D', N channel of first input should be multiples of {}, but got {}"
.format(N0_CHANNEL_ALIGN, n0)) .format(N0_CHANNEL_ALIGN, n0))
# 检查n1是否为N1_CHANNEL_ALIGN的倍数
if (n1 % N1_CHANNEL_ALIGN) != 0: if (n1 % N1_CHANNEL_ALIGN) != 0:
raise GKException("For 'Conv2D', N channel of second input should be multiples of {}, but got {}" raise GKException("For 'Conv2D', N channel of second input should be multiples of {}, but got {}"
.format(N1_CHANNEL_ALIGN, n1)) .format(N1_CHANNEL_ALIGN, n1))
# 检查c0和c1是否相等并且是否为C_CHANNEL_ALIGN的倍数
if c0 != c1 or (c0 % C_CHANNEL_ALIGN) != 0: if c0 != c1 or (c0 % C_CHANNEL_ALIGN) != 0:
raise GKException("For 'Conv2D', C channel of inputs should be same and also be multiples of {}, but got " raise GKException("For 'Conv2D', C channel of inputs should be same and also be multiples of {}, but got "
"{} and {}".format(C_CHANNEL_ALIGN, c0, c1)) "{} and {}".format(C_CHANNEL_ALIGN, c0, c1))
@ -130,68 +193,106 @@ class Conv2D(Expander):
# check if can optimize to matmul # check if can optimize to matmul
self.m, self.n, self.k = n0 * h0 * w0, n1, c1 self.m, self.n, self.k = n0 * h0 * w0, n1, c1
# 调用_optimize_to_matmul()方法判断是否可以优化为matmul
self.can_optimize_to_matmul = self._optimize_to_matmul() self.can_optimize_to_matmul = self._optimize_to_matmul()
# requirements # requirements
if self.can_optimize_to_matmul: if self.can_optimize_to_matmul:
# 如果可以优化为matmul检查k是否大于K_LIMIT
if self.k > K_LIMIT: if self.k > K_LIMIT:
raise GKException("For 'Conv2D', if transformed to 'MatMul', C0 should not be larger than {}, but got " raise GKException("For 'Conv2D', if transformed to 'MatMul', C0 should not be larger than {}, but got "
"{}".format(K_LIMIT, self.k)) "{}".format(K_LIMIT, self.k))
# 如果可以优化为matmul检查m*n*k的总大小是否大于MNK_LIMIT
if self.m * self.n * self.k >= MNK_LIMIT: if self.m * self.n * self.k >= MNK_LIMIT:
raise GKException("For 'Conv2D', if transformed to 'MatMul', The total size should not be larger than " raise GKException("For 'Conv2D', if transformed to 'MatMul', The total size should not be larger than "
"{}, but got {}".format(MNK_LIMIT, self.m * self.n * self.k)) "{}, but got {}".format(MNK_LIMIT, self.m * self.n * self.k))
else: else:
# 如果不能优化为matmul计算输出的大小
out_h, out_w = (h0 - h1) // stride[-2] + 1, (w0 - w1) // stride[-1] + 1 out_h, out_w = (h0 - h1) // stride[-2] + 1, (w0 - w1) // stride[-1] + 1
# 检查n0*out_h*out_w是否为OUT_NHW_ALIGN的倍数
if ((n0 * out_h * out_w) % OUT_NHW_ALIGN) != 0: if ((n0 * out_h * out_w) % OUT_NHW_ALIGN) != 0:
raise GKException("For 'Conv2D', N({}) * H({}) * W({}) of output should be multiplies of {}" raise GKException("For 'Conv2D', N({}) * H({}) * W({}) of output should be multiplies of {}"
.format(n0, out_h, out_w, OUT_NHW_ALIGN)) .format(n0, out_h, out_w, OUT_NHW_ALIGN))
# 检查stride是否为[1, 1, 2, 2]
if stride != [1, 1, 2, 2]: if stride != [1, 1, 2, 2]:
raise GKException("For 'Conv2D', value of attr 'stride' should be [1, 1, 2, 2], but got {}" raise GKException("For 'Conv2D', value of attr 'stride' should be [1, 1, 2, 2], but got {}"
.format(stride)) .format(stride))
# 保存pad后的shape
self.shape_0_pad = [n0, h0, w0, c0] self.shape_0_pad = [n0, h0, w0, c0]
self.shape_1_pad = [n1, h1, w1, c1] self.shape_1_pad = [n1, h1, w1, c1]
def _expand(self, graph_builder): def _expand(self, graph_builder):
"""
对输入进行扩展处理
Args:
graph_builder (GraphBuilder): 图构建器对象
Returns:
Tensor: 扩展处理后的结果
"""
# 获取输入0
input_0 = self.inputs[0] input_0 = self.inputs[0]
# 获取输入1
input_1 = self.inputs[1] input_1 = self.inputs[1]
# 获取输入0的形状
n0, _, _, c0 = input_0.shape n0, _, _, c0 = input_0.shape
# 获取输入1的形状
n1, _, _, c1 = input_1.shape n1, _, _, c1 = input_1.shape
# 获取输入0的填充形状
n0_p, h0_p, w0_p, c0_p = self.shape_0_pad n0_p, h0_p, w0_p, c0_p = self.shape_0_pad
# 获取输入1的填充形状
n1_p, _, _, c1_p = self.shape_1_pad n1_p, _, _, c1_p = self.shape_1_pad
pad_value = 0 pad_value = 0
# input0 pad # input0 pad
# 初始化输入0的填充前后的值
input_0_pad_before = [0, 0, 0, 0] input_0_pad_before = [0, 0, 0, 0]
input_0_pad_after = [0, 0, 0, 0] input_0_pad_after = [0, 0, 0, 0]
# 如果有填充,则获取填充列表
if self.has_pad: if self.has_pad:
pad_list = self.attrs['pad_list'] pad_list = self.attrs['pad_list']
# 设置输入0的填充前后的值
input_0_pad_before = [0, pad_list[0], pad_list[2], 0] input_0_pad_before = [0, pad_list[0], pad_list[2], 0]
input_0_pad_after = [0, pad_list[1], pad_list[3], 0] input_0_pad_after = [0, pad_list[1], pad_list[3], 0]
# 设置输入0的填充后的值
input_0_pad_after[0] = n0_p - n0 input_0_pad_after[0] = n0_p - n0
input_0_pad_after[3] = c0_p - c0 input_0_pad_after[3] = c0_p - c0
# 如果输入0的填充前后的值不为默认值则进行填充操作
if input_0_pad_before != [0, 0, 0, 0] or input_0_pad_after != [0, 0, 0, 0]: if input_0_pad_before != [0, 0, 0, 0] or input_0_pad_after != [0, 0, 0, 0]:
# 发射填充操作
input_0 = graph_builder.emit('PadAkg', [input_0], attrs={'head': input_0_pad_before, input_0 = graph_builder.emit('PadAkg', [input_0], attrs={'head': input_0_pad_before,
'tail': input_0_pad_after, 'tail': input_0_pad_after,
'pad_val': pad_value}) 'pad_val': pad_value})
# input1 pad # input1 pad
# 计算input_1的pad值
input_1_pad_after = [n1_p - n1, 0, 0, c1_p - c1] input_1_pad_after = [n1_p - n1, 0, 0, c1_p - c1]
# 如果input_1的pad值不为0则进行pad操作
if input_1_pad_after != [0, 0, 0, 0]: if input_1_pad_after != [0, 0, 0, 0]:
input_1 = graph_builder.emit('PadAkg', [input_1], attrs={'head': [0, 0, 0, 0], input_1 = graph_builder.emit('PadAkg', [input_1], attrs={'head': [0, 0, 0, 0],
'tail': input_1_pad_after, 'tail': input_1_pad_after,
'pad_val': pad_value}) 'pad_val': pad_value})
# 如果可以优化为matmul操作则进行matmul操作
if self.can_optimize_to_matmul: if self.can_optimize_to_matmul:
# 将input_0和input_1进行reshape操作
a = graph_builder.emit('Reshape', [input_0], attrs={'shape': [self.m, self.k]}) a = graph_builder.emit('Reshape', [input_0], attrs={'shape': [self.m, self.k]})
b = graph_builder.emit('Reshape', [input_1], attrs={'shape': [self.n, self.k]}) b = graph_builder.emit('Reshape', [input_1], attrs={'shape': [self.n, self.k]})
# 进行matmul操作
c = graph_builder.emit('MatMul', [a, b], attrs={'transpose_a': False, c = graph_builder.emit('MatMul', [a, b], attrs={'transpose_a': False,
'transpose_b': True, 'transpose_b': True,
'dst_type': self.dst_type}) 'dst_type': self.dst_type})
# 将结果进行reshape操作
result = graph_builder.emit('Reshape', [c], attrs={'shape': [n0_p, h0_p, w0_p, n1_p], result = graph_builder.emit('Reshape', [c], attrs={'shape': [n0_p, h0_p, w0_p, n1_p],
'format': self.dst_format}) 'format': self.dst_format})
# 否则进行Conv2D操作
else: else:
# 设置Conv2D操作的属性
attrs = self.attrs attrs = self.attrs
attrs['pad_list'] = [0, 0, 0, 0] attrs['pad_list'] = [0, 0, 0, 0]
attrs['dst_type'] = self.dst_type attrs['dst_type'] = self.dst_type
# 进行Conv2D操作
result = graph_builder.emit('Conv2D', [input_0, input_1], attrs=attrs) result = graph_builder.emit('Conv2D', [input_0, input_1], attrs=attrs)
# unpad # unpad
unpad_after = [input_0_pad_after[0], 0, 0, input_1_pad_after[0]] unpad_after = [input_0_pad_after[0], 0, 0, input_1_pad_after[0]]

@ -13,18 +13,36 @@
# limitations under the License. # limitations under the License.
# =========================================================================== # ===========================================================================
"""generate json desc for DropoutGrad""" """generate json desc for DropoutGrad"""
# 导入Expander和ExpanderInfoValidator类
from ._utils import Expander, ExpanderInfoValidator as VLD from ._utils import Expander, ExpanderInfoValidator as VLD
# 定义DropoutGrad类继承自Expander类
@VLD.check_all_formats_same
@VLD.check_attrs('keep_prob')
class DropoutGrad(Expander):
    """DropoutGrad expander"""

    def _expand(self, graph_builder):
        """Expand DropoutGrad as ``dx = dy * (1 / keep_prob) * mask``.

        Args:
            graph_builder (GraphBuilder): builder used to emit graph ops.

        Returns:
            The emitted output tensor of the expanded graph.
        """
        dy, mask = self.inputs
        keep_prob = self.attrs['keep_prob']
        # Scale the incoming gradient by the reciprocal of keep_prob,
        # then zero out the positions that were dropped in the forward pass.
        inv_keep_prob = graph_builder.value(dy.dtype, 1.0 / keep_prob)
        scaled_dy = graph_builder.emit('Mul', [dy, inv_keep_prob])
        return graph_builder.emit('Mul', [scaled_dy, mask])

@ -17,34 +17,84 @@ from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedEx
from ._utils import Expander, ExpanderInfoValidator as VLD from ._utils import Expander, ExpanderInfoValidator as VLD
# @VLD.check_all_formats_same检查所有格式的相同性
@VLD.check_all_formats_same
class EqualCount(Expander):
    """EqualCount expander"""

    def __init__(self, expand_info):
        """Cache shapes and dtypes of both inputs for validation.

        Args:
            expand_info (dict): expander description passed to the base class.
        """
        super().__init__(expand_info)
        info_x, info_y = self.inputs[0], self.inputs[1]
        self.shape_x = info_x['shape']
        self.shape_y = info_y['shape']
        self.dtype_x = info_x['data_type']
        self.dtype_y = info_y['data_type']

    def _check(self):
        """Validate that both inputs share one shape and one data type.

        Raises:
            GKException: if the input shapes or data types differ.
        """
        if self.shape_x != self.shape_y:
            raise GKException("For 'EqualCount', the inputs shape should be same, but got {} and {}"
                              .format(self.shape_x, self.shape_y))
        if self.dtype_x != self.dtype_y:
            raise GKException("For 'EqualCount', the inputs data type should be same, but got {} and {}"
                              .format(self.dtype_x, self.dtype_y))

    def _expand(self, graph_builder):
        """Expand EqualCount: count elementwise-equal positions of two tensors.

        Args:
            graph_builder (GraphBuilder): builder used to emit graph ops.

        Returns:
            The emitted output tensor (sum over all axes, dims kept).
        """
        x, y = self.inputs[0], self.inputs[1]
        # Equal yields bool; cast to float32 so the values can be summed.
        equal = graph_builder.emit('Equal', [x, y])
        equal_f32 = graph_builder.emit('Cast', [equal], attrs={'dst_type': 'float32'})
        all_axes = list(range(len(x.shape)))
        total = graph_builder.emit('ReduceSum', [equal_f32], attrs={'reduce_axis': all_axes, 'keep_dims': True})
        # Cast back so the output dtype matches the input dtype.
        if total.dtype != x.dtype:
            total = graph_builder.emit('Cast', [total], attrs={'dst_type': x.dtype})
        return total

@ -16,20 +16,44 @@
from ._utils import Expander from ._utils import Expander
# 定义一个Erfc类继承自Expander类
class Erfc(Expander):
    """Erfc expander: expands Erfc(x) into ``1 - Erf(x)``."""

    def _expand(self, graph_builder):
        """Expand erfc as ``1 - erf(x)``.

        float16 inputs are computed in float32 and cast back to float16
        at the end, as in the original two-branch implementation.

        The original code duplicated the Erf/Sub sequence in both the
        float16 branch and the default path and carried a dead
        ``result = None`` assignment; this version emits the shared
        computation once and only keeps the dtype casts conditional.

        Args:
            graph_builder (GraphBuilder): builder used to emit graph ops.

        Returns:
            The emitted output tensor, in the original input dtype.
        """
        input_x = self.inputs[0]
        ori_dtype = input_x.dtype
        if ori_dtype == "float16":
            # Compute in float32; the cast output's dtype drives const_one below.
            input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': "float32"})
        const_one = graph_builder.value(input_x.dtype, 1)
        erf_result = graph_builder.emit('Erf', [input_x])
        result = graph_builder.emit('Sub', [const_one, erf_result])
        if ori_dtype == "float16":
            result = graph_builder.emit('Cast', [result], attrs={'dst_type': "float16"})
        return result

@ -21,6 +21,16 @@ class ExpandDims(Expander):
"""ExpandDims expander""" """ExpandDims expander"""
def _expand(self, graph_builder): def _expand(self, graph_builder):
"""
对输入数据进行维度扩展
Args:
graph_builder (GraphBuilder): 图构建器对象
Returns:
Tensor: 扩展后的数据
"""
input_x = self.inputs[0] input_x = self.inputs[0]
shape = self.infer_shape(input_x.shape, self.attrs['axis']) shape = self.infer_shape(input_x.shape, self.attrs['axis'])
result = graph_builder.emit('Reshape', [input_x], attrs={'shape': shape}) result = graph_builder.emit('Reshape', [input_x], attrs={'shape': shape})
@ -29,8 +39,35 @@ class ExpandDims(Expander):
@staticmethod @staticmethod
def infer_shape(shape, axis): def infer_shape(shape, axis):
"""
根据给定的轴位置推断新的形状
Args:
shape (list or tuple): 原始形状表示一个多维数组的尺寸
axis (int, list or tuple): 指定要插入新维度的轴位置如果为整数表示在指定位置插入一个维度如果为列表或元组则按顺序在指定位置插入多个维度
Returns:
list: 插入新维度后的新形状
Raises:
ValueError: 如果axis的值或类型不符合要求时抛出
"""
"""infer shape for expand_dims""" """infer shape for expand_dims"""
def insert_axis(shape, axis): def insert_axis(shape, axis):
"""
在指定轴上插入一个新的维度
Args:
shape (list): 原始数组的形状类型为列表
axis (int): 要插入新维度的轴的位置
Returns:
list: 插入新维度后的数组形状
Raises:
ValueError: 如果axis的类型不是int或者axis的值不在合法范围内将抛出异常
"""
if not isinstance(axis, int) or axis > len(shape) or axis < -len(shape) - 1: if not isinstance(axis, int) or axis > len(shape) or axis < -len(shape) - 1:
raise ValueError("For 'ExpandDims', value of attr 'axis' should be of type int and in the range [{}, " raise ValueError("For 'ExpandDims', value of attr 'axis' should be of type int and in the range [{}, "
"{}], but got {} with type {}".format(-len(shape) - 1, len(shape), axis, type(axis))) "{}], but got {} with type {}".format(-len(shape) - 1, len(shape), axis, type(axis)))

@ -21,24 +21,51 @@ class FusedAdam(Expander):
"""FusedAdam expander""" """FusedAdam expander"""
def _expand(self, graph_builder):
    """Expand FusedAdam into elementwise graph ops.

    Emits the Adam moment updates
        m' = beta1 * m + (1 - beta1) * grad
        v' = beta2 * v + (1 - beta2) * grad^2
    and the parameter update
        param' = param - lr * m' / (sqrt(v') + eps)
    then writes param', m' and v' back in place.

    Args:
        graph_builder (GraphBuilder): builder used to emit graph ops.

    Returns:
        The fake output of the final InplaceAssign.
    """
    (beta_1, one_sub_beta_1, beta_2, one_sub_beta_2,
     eps, lr, param, m, v, gradient) = self.inputs

    # First moment: m' = beta1 * m + (1 - beta1) * grad
    m_scaled = graph_builder.emit('Mul', [beta_1, m])
    grad_scaled = graph_builder.emit('Mul', [one_sub_beta_1, gradient])
    next_m = graph_builder.emit('Add', [m_scaled, grad_scaled])

    # Second moment: v' = beta2 * v + (1 - beta2) * grad^2
    v_scaled = graph_builder.emit('Mul', [beta_2, v])
    grad_sq = graph_builder.emit('Mul', [gradient, gradient])
    grad_sq_scaled = graph_builder.emit('Mul', [one_sub_beta_2, grad_sq])
    next_v = graph_builder.emit('Add', [v_scaled, grad_sq_scaled])

    # Step: lr * m' / (sqrt(v') + eps)
    v_sqrt = graph_builder.emit('Sqrt', [next_v])
    denom = graph_builder.emit('Add', [v_sqrt, eps])
    step = graph_builder.emit('RealDiv', [next_m, denom])
    step_with_lr = graph_builder.emit('Mul', [lr, step])
    next_param = graph_builder.emit('Sub', [param, step_with_lr])

    # Write param, m and v back in place, chaining through fake outputs.
    out = graph_builder.emit(
        'InplaceAssign', [param, next_param, next_param], attrs={'fake_output': True})
    out = graph_builder.emit('InplaceAssign', [m, next_m, out], attrs={'fake_output': True})
    out = graph_builder.emit('InplaceAssign', [v, next_v, out], attrs={'fake_output': True})
    return out

@ -21,27 +21,54 @@ class FusedAdamWeightDecay(Expander):
"""FusedAdamWeightDecay expander""" """FusedAdamWeightDecay expander"""
def _expand(self, graph_builder):
    """Expand FusedAdamWeightDecay into elementwise graph ops.

    Same as FusedAdam, except the update direction gains a decoupled
    weight-decay term before the learning-rate scaling:
        step = m' / (sqrt(v') + eps) + weight_decay * param
        param' = param - lr * step

    Args:
        graph_builder (GraphBuilder): builder used to emit graph ops.

    Returns:
        The fake output of the final InplaceAssign.
    """
    (beta_1, one_sub_beta_1, beta_2, one_sub_beta_2, eps, lr,
     param, m, v, gradient, weight_decay) = self.inputs

    # First moment: m' = beta1 * m + (1 - beta1) * grad
    m_scaled = graph_builder.emit('Mul', [beta_1, m])
    grad_scaled = graph_builder.emit('Mul', [one_sub_beta_1, gradient])
    next_m = graph_builder.emit('Add', [m_scaled, grad_scaled])

    # Second moment: v' = beta2 * v + (1 - beta2) * grad^2
    v_scaled = graph_builder.emit('Mul', [beta_2, v])
    grad_sq = graph_builder.emit('Mul', [gradient, gradient])
    grad_sq_scaled = graph_builder.emit('Mul', [one_sub_beta_2, grad_sq])
    next_v = graph_builder.emit('Add', [v_scaled, grad_sq_scaled])

    # Step: m' / (sqrt(v') + eps) plus decoupled weight decay, scaled by lr.
    v_sqrt = graph_builder.emit('Sqrt', [next_v])
    denom = graph_builder.emit('Add', [v_sqrt, eps])
    step = graph_builder.emit('RealDiv', [next_m, denom])
    decay_term = graph_builder.emit('Mul', [weight_decay, param])
    step = graph_builder.emit('Add', [step, decay_term])
    step_with_lr = graph_builder.emit('Mul', [lr, step])
    next_param = graph_builder.emit('Sub', [param, step_with_lr])

    # Write param, m and v back in place, chaining through fake outputs.
    out = graph_builder.emit(
        'InplaceAssign', [param, next_param, next_param], attrs={'fake_output': True})
    out = graph_builder.emit('InplaceAssign', [m, next_m, out], attrs={'fake_output': True})
    out = graph_builder.emit('InplaceAssign', [v, next_v, out], attrs={'fake_output': True})
    return out

@ -20,9 +20,23 @@ class FusedMulAdd(Expander):
"""FusedMulAdd expander""" """FusedMulAdd expander"""
def _expand(self, graph_builder):
    """Expand FusedMulAdd as ``x * y + z``.

    Args:
        graph_builder (GraphBuilder): builder used to emit graph ops.

    Returns:
        The emitted output tensor.
    """
    x, y, z = self.inputs
    product = graph_builder.emit('Mul', [x, y])
    return graph_builder.emit('Add', [product, z])

@ -22,22 +22,47 @@ class Gather(Expander):
"""Expand Gather""" """Expand Gather"""
def _expand(self, graph_builder):
    """Expand Gather, supporting multi-dimensional indices via reshape.

    A 1-D index tensor maps directly onto the Gather primitive. For
    higher-rank indices, the indices are flattened to 1-D, gathered,
    and the result is reshaped so the gathered axis carries the
    original index shape.

    Args:
        graph_builder (GraphBuilder): builder used to emit graph ops.

    Returns:
        The emitted output tensor.
    """
    data, indices = self.inputs
    axis = self.attrs['axis']
    # Normalize a negative axis against the input rank.
    if axis < 0:
        axis += len(data.shape)

    if len(indices.shape) == 1:
        return graph_builder.emit('Gather', [data, indices], attrs={'axis': axis})

    idx_shape = indices.shape
    flat_size = 1
    for dim in idx_shape:
        flat_size *= dim
    flat_indices = graph_builder.emit('Reshape', [indices], attrs={'shape': [flat_size]})
    gathered = graph_builder.emit('Gather', [data, flat_indices], attrs={'axis': axis})
    # Splice the original index shape into the output shape in place of
    # the gathered axis, then drop the flattened dimension that follows it.
    out_shape = list(data.shape)
    out_shape[axis:axis] = idx_shape
    del out_shape[axis + len(idx_shape)]
    return graph_builder.emit('Reshape', [gathered], attrs={'shape': out_shape})

@ -22,6 +22,16 @@ class GeLU(Expander):
CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564 # np.sqrt(2/np.pi) CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564 # np.sqrt(2/np.pi)
def _expand(self, graph_builder): def _expand(self, graph_builder):
"""
计算输入张量的GELU激活函数值
Args:
graph_builder (GraphBuilder): 图构建器对象用于生成计算图
Returns:
Tensor: 输入张量的GELU激活函数值
"""
# cal formula are: # cal formula are:
# gelu of x is 0.5 * x * (1.0 + tanh(y)) # gelu of x is 0.5 * x * (1.0 + tanh(y))
# y is 'sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)' # y is 'sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)'
@ -29,20 +39,33 @@ class GeLU(Expander):
input_x = self.inputs[0] input_x = self.inputs[0]
# cal y # cal y
# 计算 input_x 的平方
mul_0 = graph_builder.emit('Mul', [input_x, input_x]) mul_0 = graph_builder.emit('Mul', [input_x, input_x])
# 计算 input_x 的立方
pow_0 = graph_builder.emit('Mul', [mul_0, input_x]) pow_0 = graph_builder.emit('Mul', [mul_0, input_x])
# 创建一个 CSVALUE 常量
const_csvalue = graph_builder.value(pow_0.dtype, self.CSVALUE) const_csvalue = graph_builder.value(pow_0.dtype, self.CSVALUE)
# 计算 pow_0 和 CSVALUE 的乘积
mul_1 = graph_builder.emit('Mul', [pow_0, const_csvalue]) mul_1 = graph_builder.emit('Mul', [pow_0, const_csvalue])
# 计算 input_x 和 mul_1 的和
tanh_res = graph_builder.emit('Add', [input_x, mul_1]) tanh_res = graph_builder.emit('Add', [input_x, mul_1])
# 创建一个 CSVALUE_SQRT_TWO_DIV_PI 常量
const_csvalue_sqrt_two_div_pi = graph_builder.value(tanh_res.dtype, self.CSVALUE_SQRT_TWO_DIV_PI) const_csvalue_sqrt_two_div_pi = graph_builder.value(tanh_res.dtype, self.CSVALUE_SQRT_TWO_DIV_PI)
# 计算 tanh_res 和 CSVALUE_SQRT_TWO_DIV_PI 的乘积
y = graph_builder.emit('Mul', [tanh_res, const_csvalue_sqrt_two_div_pi]) y = graph_builder.emit('Mul', [tanh_res, const_csvalue_sqrt_two_div_pi])
# cal gelu(x) # cal gelu(x)
# 计算 y 的 tanh 值
tanh_y = graph_builder.emit('Tanh', [y]) tanh_y = graph_builder.emit('Tanh', [y])
# 创建一个 1 常量
const_one = graph_builder.value(tanh_y.dtype, 1) const_one = graph_builder.value(tanh_y.dtype, 1)
# 创建一个 0.5 常量
const_half = graph_builder.value(tanh_y.dtype, 0.5) const_half = graph_builder.value(tanh_y.dtype, 0.5)
# 计算 tanh_y 和 1 的和
tanh_y_add_one = graph_builder.emit('Add', [tanh_y, const_one]) tanh_y_add_one = graph_builder.emit('Add', [tanh_y, const_one])
# 计算 input_x 和 tanh_y_add_one 的乘积
mul_x = graph_builder.emit('Mul', [input_x, tanh_y_add_one]) mul_x = graph_builder.emit('Mul', [input_x, tanh_y_add_one])
# 计算 const_half 和 mul_x 的乘积
result = graph_builder.emit('Mul', [const_half, mul_x]) result = graph_builder.emit('Mul', [const_half, mul_x])
return result return result

@ -19,11 +19,29 @@ from ._utils import Expander, ExpanderInfoValidator as VLD
@VLD.check_all_formats_same @VLD.check_all_formats_same
class GeLUGrad(Expander): class GeLUGrad(Expander):
"""GeLUGrad expander""" """GeLUGrad expander"""
CSVALUE = 0.044715
# CSVALUE = 0.044715
CSVALUE = 0.044715 # CSVALUE的值为0.044715
CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564 # np.sqrt(2/np.pi) CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564 # np.sqrt(2/np.pi)
CSVALUE_TRI = 0.134141 # CSVALUE * 3 CSVALUE_TRI = 0.134141 # CSVALUE * 3
def _expand(self, graph_builder): def _expand(self, graph_builder):
"""
计算GELU函数的梯度
Args:
graph_builder (GraphBuilder): 图构建器对象用于生成计算图
Returns:
Tensor: GELU函数的梯度
计算公式如下
GELU的梯度dy和x是dy * y'
y' = 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(para)) * mul_right
tanh_para = sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)
mul_right = sqrt(2.0 / pi) * (1 + 3 * 0.044715 * x * x)
"""
# cal formula are: # cal formula are:
# gelu_grad of dy and x is dy * y' # gelu_grad of dy and x is dy * y'
# y' is 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(para)) * mul_right # y' is 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(para)) * mul_right
@ -33,21 +51,33 @@ class GeLUGrad(Expander):
input_dy, input_x, _ = self.inputs input_dy, input_x, _ = self.inputs
# create some const var # create some const var
# 创建一个常量值为self.CSVALUE数据类型为input_dy.dtype
const_csvalue = graph_builder.value(input_dy.dtype, self.CSVALUE) const_csvalue = graph_builder.value(input_dy.dtype, self.CSVALUE)
# 创建一个常量值为self.CSVALUE_SQRT_TWO_DIV_PI数据类型为input_dy.dtype
const_csvalue_sqrt_two_div_pi = graph_builder.value(input_dy.dtype, self.CSVALUE_SQRT_TWO_DIV_PI) const_csvalue_sqrt_two_div_pi = graph_builder.value(input_dy.dtype, self.CSVALUE_SQRT_TWO_DIV_PI)
# 创建一个常量值为self.CSVALUE_TRI数据类型为input_dy.dtype
const_csvalue_tri = graph_builder.value(input_dy.dtype, self.CSVALUE_TRI) const_csvalue_tri = graph_builder.value(input_dy.dtype, self.CSVALUE_TRI)
# 创建一个常量值为1数据类型为input_dy.dtype
const_one = graph_builder.value(input_dy.dtype, 1) const_one = graph_builder.value(input_dy.dtype, 1)
# 创建一个常量值为0.5数据类型为input_dy.dtype
const_half = graph_builder.value(input_dy.dtype, 0.5) const_half = graph_builder.value(input_dy.dtype, 0.5)
# cal mul_right # cal mul_right
# 计算input_x的平方
mul_double = graph_builder.emit('Mul', [input_x, input_x]) mul_double = graph_builder.emit('Mul', [input_x, input_x])
# 将const_csvalue_tri与mul_double相乘
mul_double_mul_tri = graph_builder.emit('Mul', [const_csvalue_tri, mul_double]) mul_double_mul_tri = graph_builder.emit('Mul', [const_csvalue_tri, mul_double])
# 将const_one与mul_double_mul_tri相加
mul_add_one = graph_builder.emit('Add', [const_one, mul_double_mul_tri]) mul_add_one = graph_builder.emit('Add', [const_one, mul_double_mul_tri])
# 将const_csvalue_sqrt_two_div_pi与mul_add_one相乘
mul_right = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_one]) mul_right = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_one])
# cal tanh_para # cal tanh_para
# 计算input_x和mul_double的乘积
mul_triple = graph_builder.emit('Mul', [input_x, mul_double]) mul_triple = graph_builder.emit('Mul', [input_x, mul_double])
# 计算const_csvalue和mul_triple的乘积
mul_triple_mul_csvalue = graph_builder.emit('Mul', [const_csvalue, mul_triple]) mul_triple_mul_csvalue = graph_builder.emit('Mul', [const_csvalue, mul_triple])
# 计算input_x和mul_triple_mul_csvalue的和
mul_add_x = graph_builder.emit('Add', [input_x, mul_triple_mul_csvalue]) mul_add_x = graph_builder.emit('Add', [input_x, mul_triple_mul_csvalue])
tanh_para = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_x]) tanh_para = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_x])

@ -22,12 +22,27 @@ class GkDropout(Expander):
"""GkDropout expander""" """GkDropout expander"""
def _expand(self, graph_builder): def _expand(self, graph_builder):
"""
对输入数据进行dropout操作
Args:
graph_builder (GraphBuilder): 图构建器对象
Returns:
tuple: 包含两个元素第一个是执行dropout操作后的结果第二个是生成的掩码
"""
# 获取输入数据和掩码
input_x, input_mask = self.inputs input_x, input_mask = self.inputs
# 获取保持概率
keep_prob = self.attrs['keep_prob'] keep_prob = self.attrs['keep_prob']
# 计算保持概率的倒数
r_keep_prob = graph_builder.value(input_x.dtype, 1.0 / keep_prob) r_keep_prob = graph_builder.value(input_x.dtype, 1.0 / keep_prob)
# 计算保持概率
keep_prob = graph_builder.value(input_x.dtype, keep_prob) keep_prob = graph_builder.value(input_x.dtype, keep_prob)
# 如果掩码的数据类型与输入数据类型不同,则进行类型转换
if input_mask.dtype != input_x.dtype: if input_mask.dtype != input_x.dtype:
input_mask = graph_builder.emit('Cast', [input_mask], attrs={'dst_type': input_x.dtype}) input_mask = graph_builder.emit('Cast', [input_mask], attrs={'dst_type': input_x.dtype})
mask = graph_builder.emit('LessEqual', [input_mask, keep_prob]) # output is bool type mask = graph_builder.emit('LessEqual', [input_mask, keep_prob]) # output is bool type

@ -20,6 +20,16 @@ class Identity(Expander):
"""Identity expander""" """Identity expander"""
def _expand(self, graph_builder):
    """Expand Identity as a shape-preserving Reshape.

    Args:
        graph_builder (GraphBuilder): builder used to emit graph ops.

    Returns:
        The emitted output tensor (same shape as the input).
    """
    input_x = self.inputs[0]
    # A Reshape to the input's own shape is the graph-level identity.
    return graph_builder.emit('Reshape', [input_x], attrs={'shape': input_x.shape})

@ -25,67 +25,107 @@ class LayerNorm(Expander):
"""LayerNorm expander""" """LayerNorm expander"""
def _expand(self, graph_builder): def _expand(self, graph_builder):
"""
对输入进行扩展处理包括批量归一化操作
Args:
graph_builder (GraphBuilder): 图构建器对象用于构建计算图
Returns:
tuple: 包含三个元素的元组分别是处理后的输入均值和方差
- res (Tensor): 处理后的输入张量
- mean (Tensor): 输入的均值张量
- variance (Tensor): 输入的方差张量
"""
# 获取输入数据
input_x, input_gamma, input_beta = self.inputs input_x, input_gamma, input_beta = self.inputs
# 获取处理器类型
processor = self.processor processor = self.processor
# 获取归一化开始轴
begin_norm_axis = self.attrs['begin_norm_axis'] begin_norm_axis = self.attrs['begin_norm_axis']
# 获取epsilon值
epsilon = self.attrs['epsilon'] epsilon = self.attrs['epsilon']
# 获取输入数据的原始数据类型
ori_dtype = input_x.dtype ori_dtype = input_x.dtype
# 如果处理器类型为aicore且输入数据类型为float16则将输入数据类型转换为float32
if processor == 'aicore' and ori_dtype == 'float16': if processor == 'aicore' and ori_dtype == 'float16':
input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'}) input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'})
input_gamma = graph_builder.emit('Cast', [input_gamma], attrs={'dst_type': 'float32'}) input_gamma = graph_builder.emit('Cast', [input_gamma], attrs={'dst_type': 'float32'})
input_beta = graph_builder.emit('Cast', [input_beta], attrs={'dst_type': 'float32'}) input_beta = graph_builder.emit('Cast', [input_beta], attrs={'dst_type': 'float32'})
# 获取输入数据的原始形状
ori_shape_x = input_x.shape ori_shape_x = input_x.shape
# 如果输入数据的格式为FRAC_NZ则根据FRAC_NZ格式获取输入数据的形状
if input_x.data_format == DF.FRAC_NZ: if input_x.data_format == DF.FRAC_NZ:
ori_shape_x = infer_shape_from_fractalnz(input_x.shape) ori_shape_x = infer_shape_from_fractalnz(input_x.shape)
# Calculate the scaling ratio of the average # Calculate the scaling ratio of the average
# 如果begin_norm_axis小于0则将其加上ori_shape_x的长度
if begin_norm_axis < 0: if begin_norm_axis < 0:
begin_norm_axis += len(ori_shape_x) begin_norm_axis += len(ori_shape_x)
# 定义reduce_axis用于存储需要归一化的维度
reduce_axis = () reduce_axis = ()
# 遍历ori_shape_x如果维度大于begin_norm_axis或者等于begin_norm_axis则将其加入reduce_axis
for i, _ in enumerate(ori_shape_x): for i, _ in enumerate(ori_shape_x):
if i > begin_norm_axis or i == begin_norm_axis: if i > begin_norm_axis or i == begin_norm_axis:
reduce_axis = reduce_axis + (i,) reduce_axis = reduce_axis + (i,)
# 计算reduce_elts即需要归一化的维度上的元素个数
reduce_elts = 1.0 reduce_elts = 1.0
for i in reduce_axis: for i in reduce_axis:
reduce_elts *= ori_shape_x[i] reduce_elts *= ori_shape_x[i]
# after reduced # after reduced
# 获取归一化后的ori_shape_x
ori_reduced_shape_x = get_reduced_ori_shape(ori_shape_x, reduce_axis) ori_reduced_shape_x = get_reduced_ori_shape(ori_shape_x, reduce_axis)
# 定义axis用于存储归一化的维度
axis = reduce_axis axis = reduce_axis
# 如果input_x的数据格式为DF.FRAC_NZ则将axis转换为frac_z轴
if input_x.data_format == DF.FRAC_NZ: if input_x.data_format == DF.FRAC_NZ:
axis = to_frac_z_axis(ori_shape_x, reduce_axis) axis = to_frac_z_axis(ori_shape_x, reduce_axis)
# 计算mean_cof_v即归一化系数
mean_cof_v = graph_builder.value(input_x.dtype, 1.0 / reduce_elts) mean_cof_v = graph_builder.value(input_x.dtype, 1.0 / reduce_elts)
# Calculate mean # Calculate mean
# 计算输入张量的均值
mean_red = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True}) mean_red = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True})
# 将均值乘以系数
mean = graph_builder.emit('Mul', [mean_red, mean_cof_v]) mean = graph_builder.emit('Mul', [mean_red, mean_cof_v])
# 如果输入张量的数据格式为DF.FRAC_NZ则对均值进行重整
if input_x.data_format == DF.FRAC_NZ: if input_x.data_format == DF.FRAC_NZ:
mean = graph_builder.emit('Reshape', [mean], attrs={'shape': ori_reduced_shape_x}) mean = graph_builder.emit('Reshape', [mean], attrs={'shape': ori_reduced_shape_x})
# Calculate variance # Calculate variance
variance_sub = graph_builder.emit('Sub', [input_x, mean]) # 计算方差
variance_mul = graph_builder.emit('Mul', [variance_sub, variance_sub]) variance_sub = graph_builder.emit('Sub', [input_x, mean]) # 计算输入与均值的差值
variance_red = graph_builder.emit('ReduceSum', [variance_mul], attrs={'reduce_axis': axis, 'keep_dims': True}) variance_mul = graph_builder.emit('Mul', [variance_sub, variance_sub]) # 计算差值的平方
variance = graph_builder.emit('Mul', [variance_red, mean_cof_v]) variance_red = graph_builder.emit('ReduceSum', [variance_mul], attrs={'reduce_axis': axis, 'keep_dims': True}) # 对差值的平方求和
variance = graph_builder.emit('Mul', [variance_red, mean_cof_v]) # 计算方差
# 如果输入数据的格式为DF.FRAC_NZ则对方差进行reshape操作
if input_x.data_format == DF.FRAC_NZ: if input_x.data_format == DF.FRAC_NZ:
variance = graph_builder.emit('Reshape', [variance], attrs={'shape': ori_reduced_shape_x}) variance = graph_builder.emit('Reshape', [variance], attrs={'shape': ori_reduced_shape_x})
# Calculate normalize # Calculate normalize
# 计算输入x与均值之间的差值
normalize_sub = graph_builder.emit('Sub', [input_x, mean]) normalize_sub = graph_builder.emit('Sub', [input_x, mean])
# 创建一个epsilon值用于防止除零错误
epsilon_v = graph_builder.value(input_x.dtype, epsilon) epsilon_v = graph_builder.value(input_x.dtype, epsilon)
# 计算方差加上epsilon的值
normalize_add = graph_builder.emit('Add', [variance, epsilon_v]) normalize_add = graph_builder.emit('Add', [variance, epsilon_v])
normlize_rsqrt = graph_builder.emit('Rsqrt', [normalize_add]) normlize_rsqrt = graph_builder.emit('Rsqrt', [normalize_add])
normalize_mul = graph_builder.emit('Mul', [normalize_sub, normlize_rsqrt]) normalize_mul = graph_builder.emit('Mul', [normalize_sub, normlize_rsqrt])
# Calculate scale and translate # Calculate scale and translate
# 计算归一化后的乘积
scale_mul = graph_builder.emit('Mul', [normalize_mul, input_gamma]) scale_mul = graph_builder.emit('Mul', [normalize_mul, input_gamma])
# 计算最终结果
res = graph_builder.emit('Add', [scale_mul, input_beta]) res = graph_builder.emit('Add', [scale_mul, input_beta])
# 如果处理器为aicore且原始数据类型为float16则将结果、均值和方差转换为float16
if processor == 'aicore' and ori_dtype == 'float16': if processor == 'aicore' and ori_dtype == 'float16':
res = graph_builder.emit('Cast', [res], attrs={'dst_type': 'float16'}) res = graph_builder.emit('Cast', [res], attrs={'dst_type': 'float16'})
mean = graph_builder.emit('Cast', [mean], attrs={'dst_type': 'float16'}) mean = graph_builder.emit('Cast', [mean], attrs={'dst_type': 'float16'})

@ -23,13 +23,33 @@ class LayerNormGrad(Expander):
"""LayerNormGrad expander""" """LayerNormGrad expander"""
def _expand(self, graph_builder):
    """Expand LayerNormGrad into basic graph ops.

    Args:
        graph_builder (GraphBuilder): builder used to emit primitive ops.

    Returns:
        tuple: (dx, dg, db) — gradients w.r.t. the input x, gamma and beta.
    """
    x, dy, variance, mean, gamma = self.inputs
    processor = self.processor
    begin_norm_axis = self.attrs['begin_norm_axis']
    begin_params_axis = self.attrs['begin_params_axis']
    # epsilon defaults to 1e-12 when the attribute is absent
    epsilon = self.attrs['epsilon'] if 'epsilon' in self.attrs else 1e-12

    ori_dtype = x.dtype
    # aicore computes in float32 for float16 inputs, casting back at the end
    if processor == 'aicore' and ori_dtype == 'float16':
        x = graph_builder.emit('Cast', [x], attrs={'dst_type': 'float32'})
        dy = graph_builder.emit('Cast', [dy], attrs={'dst_type': 'float32'})
        # NOTE(review): this cast line was hidden under a diff-hunk header in the
        # scraped source; restored to match the casts of the sibling inputs — confirm.
        variance = graph_builder.emit('Cast', [variance], attrs={'dst_type': 'float32'})
        mean = graph_builder.emit('Cast', [mean], attrs={'dst_type': 'float32'})
        gamma = graph_builder.emit('Cast', [gamma], attrs={'dst_type': 'float32'})

    # normalize negative axes
    if begin_norm_axis < 0:
        begin_norm_axis += len(x.shape)
    if begin_params_axis < 0:
        begin_params_axis += len(x.shape)
    norm_axis = tuple(range(begin_norm_axis, len(x.shape)))
    param_axis = tuple(range(0, begin_params_axis))

    # number of elements reduced over the normalized axes
    reduce_size = 1.0
    for i in norm_axis:
        reduce_size *= x.shape[i]

    # set some constant val.
    eps = graph_builder.value(x.dtype, epsilon)
    const_neg_half = graph_builder.value(x.dtype, -0.5)
    const_neg_two = graph_builder.value(x.dtype, -2.0)
    const_two = graph_builder.value(x.dtype, 2.0)
    const_neg_one = graph_builder.value(x.dtype, -1.0)
    mean_cof = graph_builder.value(x.dtype, (1.0 / reduce_size))

    # cal dg db
    var_eps = graph_builder.emit('Add', [variance, eps])
    var_eps_log = graph_builder.emit('Log', [var_eps])
    var_eps_mul = graph_builder.emit('Mul', [var_eps_log, const_neg_half])
    # exp(-0.5 * log(var + eps)) == rsqrt(var + eps)
    rsqrt_var_eps = graph_builder.emit('Exp', [var_eps_mul])
    x_sub_mean = graph_builder.emit('Sub', [x, mean])
    x_sub_mean_mul_rsqrt_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, x_sub_mean])
    dg_mul = graph_builder.emit('Mul', [dy, x_sub_mean_mul_rsqrt_var_eps])
    dg = graph_builder.emit('ReduceSum', [dg_mul], attrs={'reduce_axis': param_axis, 'keep_dims': False})
    db = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': param_axis, 'keep_dims': False})

    # pd_var
    tmp_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, rsqrt_var_eps])
    r_tmp_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, tmp_var_eps])
    dy_mul_gamma = graph_builder.emit('Mul', [dy, gamma])
    tmp_mul = graph_builder.emit('Mul', [dy_mul_gamma, x_sub_mean])
    padvar_mul1 = graph_builder.emit('ReduceSum', [tmp_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
    padvar_mul3 = graph_builder.emit('Mul', [padvar_mul1, r_tmp_var_eps])
    pd_var = graph_builder.emit('Mul', [padvar_mul3, const_neg_half])

    # pd_mean
    pdmean1_sum = graph_builder.emit('ReduceSum', [dy_mul_gamma],
                                     attrs={'reduce_axis': norm_axis, 'keep_dims': True})
    neg_rsqrt_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, const_neg_one])
    pd_mean_1 = graph_builder.emit('Mul', [neg_rsqrt_var_eps, pdmean1_sum])
    pdmean2_mul1 = graph_builder.emit('Mul', [const_neg_two, x_sub_mean])
    pdmean2_sum = graph_builder.emit('ReduceSum', [pdmean2_mul1],
                                     attrs={'reduce_axis': norm_axis, 'keep_dims': True})
    pdmean2_mul3 = graph_builder.emit('Mul', [pdmean2_sum, mean_cof])
    pd_mean_2 = graph_builder.emit('Mul', [pdmean2_mul3, pd_var])
    pd_mean = graph_builder.emit('Add', [pd_mean_1, pd_mean_2])

    # cal dx
    pd_x_1 = graph_builder.emit('Mul', [dy_mul_gamma, rsqrt_var_eps])
    pdx2_mul = graph_builder.emit('Mul', [pd_var, x_sub_mean])
    pdx2_mul_two = graph_builder.emit('Mul', [pdx2_mul, const_two])
    pd_x_2 = graph_builder.emit('Mul', [pdx2_mul_two, mean_cof])
    pd_x_3 = graph_builder.emit('Mul', [pd_mean, mean_cof])
    dx_tmp = graph_builder.emit('Add', [pd_x_1, pd_x_2])
    dx = graph_builder.emit('Add', [dx_tmp, pd_x_3])

    # cast results back to the caller's float16 when computation ran in float32
    if processor == 'aicore' and ori_dtype == 'float16':
        dx = graph_builder.emit('Cast', [dx], attrs={'dst_type': 'float16'})
        dg = graph_builder.emit('Cast', [dg], attrs={'dst_type': 'float16'})
        db = graph_builder.emit('Cast', [db], attrs={'dst_type': 'float16'})
    return dx, dg, db

@ -23,24 +23,49 @@ class LogSoftmax(Expander):
"""LogSoftmax expander""" """LogSoftmax expander"""
def _expand(self, graph_builder):
    """Expand LogSoftmax as (x - max(x)) - log(sum(exp(x - max(x)))).

    Args:
        graph_builder (GraphBuilder): builder used to emit primitive ops.

    Returns:
        The log-softmax of the input along the configured axis.
    """
    input_x = self.inputs[0]
    axis = self.attrs['axis']
    processor = self.processor
    # ReduceMax/ReduceSum take a tuple of axes
    if isinstance(axis, int):
        axis = (axis,)
    ori_dtype = input_x.dtype
    # on aicore, ReduceMax runs in float16; cast in and back out around it
    if ori_dtype != "float16" and processor == "aicore":
        input_x_f16 = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float16'})
        max_x_f16 = graph_builder.emit('ReduceMax', [input_x_f16], attrs={'reduce_axis': axis, 'keep_dims': True})
        max_x = graph_builder.emit('Cast', [max_x_f16], attrs={'dst_type': ori_dtype})
    else:
        max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True})
    # subtract the per-axis max for numerical stability before exponentiating
    data_sub = graph_builder.emit('Sub', [input_x, max_x])
    data_exp = graph_builder.emit('Exp', [data_sub])
    data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': axis, 'keep_dims': True})
    log_expsum = graph_builder.emit('Log', [data_expsum])
    result = graph_builder.emit('Sub', [data_sub, log_expsum])
    return result

@ -23,14 +23,32 @@ class LogSoftmaxGrad(Expander):
"""LogSoftmaxGrad expander""" """LogSoftmaxGrad expander"""
def _expand(self, graph_builder):
    """Expand LogSoftmaxGrad: dy - exp(logits) * sum(dy, axis).

    Args:
        graph_builder (GraphBuilder): builder used to emit primitive ops.

    Returns:
        The gradient w.r.t. the log-softmax input.
    """
    input_logits, input_dy = self.inputs
    axis = self.attrs['axis']
    # ReduceSum takes a tuple of axes
    if isinstance(axis, int):
        axis = (axis,)
    # exp(log_softmax(x)) recovers softmax(x)
    softmax = graph_builder.emit('Exp', [input_logits])
    dy_sum = graph_builder.emit('ReduceSum', [input_dy], attrs={'reduce_axis': axis, 'keep_dims': True})
    mul_result = graph_builder.emit('Mul', [softmax, dy_sum])
    result = graph_builder.emit('Sub', [input_dy, mul_result])
    return result

@ -25,48 +25,139 @@ class MatMul(Expander):
""" """
def __init__(self, expand_info):
    """Initialize the MatMul expander.

    Args:
        expand_info (dict): expansion info consumed by the Expander base class.

    Caches the transpose flags, data formats and input shapes used by
    ``_optimize_to_mul`` and ``_expand``.
    """
    super(MatMul, self).__init__(expand_info)
    self.transpose_a = self.attrs['transpose_a']
    self.transpose_b = self.attrs['transpose_b']
    self.left_format = self.attrs['left_format']
    self.right_format = self.attrs['right_format']
    self.shape_a = self.inputs[0]['shape']
    self.shape_b = self.inputs[1]['shape']
def _optimize_to_mul(self):
    """check if matmul can be replaced by mul

    Returns:
        bool: True when the contraction dimension k is 1 on both operands
        (so MatMul degenerates to an elementwise Mul), False otherwise.
    """
    # only the aicore backend with default data formats is eligible
    if self.processor != 'aicore' or self.left_format != DF.DEFAULT or self.right_format != DF.DEFAULT:
        return False
    # k is the contraction dim; its position depends on the transpose flags
    k_a = self.shape_a[-2] if self.transpose_a else self.shape_a[-1]
    k_b = self.shape_b[-1] if self.transpose_b else self.shape_b[-2]
    return k_a == 1 and k_b == 1
def _check(self):
    """Validate that MatMul received at least two inputs.

    Raises:
        GKException: if fewer than two inputs were provided.
    """
    input_num = len(self.inputs)
    if input_num < 2:
        raise GKException("For 'MatMul', inputs number should bigger than 1, but got {}.".format(input_num))
def _expand(self, graph_builder):
    """Replace a degenerate MatMul/BatchMatMul (k == 1) with an elementwise Mul.

    Args:
        graph_builder (GraphBuilder): builder used to emit primitive ops.

    Returns:
        The Mul (optionally Cast) result node.

    Raises:
        GKException: when the MatMul does not qualify for the Mul replacement.
    """
    def transpose(shape):
        # swap the last two dimensions of a shape
        trans_shape = list(shape)
        trans_shape[-2] = shape[-1]
        trans_shape[-1] = shape[-2]
        return trans_shape

    if not self._optimize_to_mul():
        raise GKException("MatMul/BatchMatMul do not need to be replaced by Mul")
    # Matmul is replaced by Mul([b m k], [b k n]) when k==1
    input_a = self.inputs[0]
    input_b = self.inputs[1]
    # a transposed operand with k==1 is just a reshape of the original data
    if self.transpose_a:
        shape_a_trans = transpose(self.shape_a)
        input_a = graph_builder.emit('Reshape', [input_a], attrs={'shape': shape_a_trans})
    if self.transpose_b:
        shape_b_trans = transpose(self.shape_b)
        input_b = graph_builder.emit('Reshape', [input_b], attrs={'shape': shape_b_trans})
    result = graph_builder.emit('Mul', [input_a, input_b])
    # honor an output dtype request that differs from the input dtype
    if 'dst_type' in self.attrs and self.inputs[0].dtype != self.attrs['dst_type']:
        result = graph_builder.emit('Cast', [result], attrs={'dst_type': self.attrs['dst_type']})
    return result

@ -23,35 +23,76 @@ class MaximumGrad(Expander):
"""MaximumGrad expander""" """MaximumGrad expander"""
def _check(self): def _check(self):
"""
检查MaximumGrad的属性是否符合要求
Args:
Returns:
返回父类的检查结果
Raises:
GKException: 'grad_x' 'grad_y' 的值都为 False 时抛出异常
"""
# 如果attr 'grad_x'和'grad_y'的值都为False则抛出异常
if not self.attrs.get('grad_x', True) and not self.attrs.get('grad_y', True): if not self.attrs.get('grad_x', True) and not self.attrs.get('grad_y', True):
raise GKException("For 'MaximumGrad', value of attr 'grad_x' and 'grad_y' should be False, but got {} and " raise GKException("For 'MaximumGrad', value of attr 'grad_x' and 'grad_y' should be False, but got {} and "
"{}".format(self.attrs.get('grad_x'), self.attrs.get('grad_y'))) "{}".format(self.attrs.get('grad_x'), self.attrs.get('grad_y')))
# 调用父类的方法
return super()._check() return super()._check()
def _expand(self, graph_builder):
    """Expand MaximumGrad: route dout to whichever input was the maximum.

    Args:
        graph_builder (GraphBuilder): builder used to emit primitive ops.

    Returns:
        tuple: (dx_out, dy_out) — gradients w.r.t. the two inputs, reduced
        back to the original input shapes when broadcasting occurred.
    """
    input_x, input_y, input_dout = self.inputs
    # mask of positions where x >= y, cast to x's dtype for arithmetic
    ge_result = graph_builder.emit('GreaterEqual', [input_x, input_y])
    ge_result = graph_builder.emit('Cast', [ge_result], attrs={'dst_type': input_x.dtype})
    dx = graph_builder.emit('Mul', [ge_result, input_dout])
    # the remainder of dout goes to y
    dy = graph_builder.emit('Sub', [input_dout, dx])

    # broadcasting may have grown the grads beyond the input shapes;
    # reduce them back so each grad matches its input
    reduce_axis_x = MinimumGrad.get_reduce_axis(input_x.shape, dx.shape)
    reduce_axis_y = MinimumGrad.get_reduce_axis(input_y.shape, dy.shape)
    if reduce_axis_x:
        dx_reduce = graph_builder.emit('ReduceSum', [dx], attrs={'reduce_axis': reduce_axis_x, 'keep_dims': False})
        if dx_reduce.shape != input_x.shape:
            dx_out = graph_builder.emit('Reshape', [dx_reduce], attrs={'shape': input_x.shape})
        else:
            dx_out = dx_reduce
    else:
        dx_out = dx
    if reduce_axis_y:
        dy_reduce = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': reduce_axis_y, 'keep_dims': False})
        if dy_reduce.shape != input_y.shape:
            dy_out = graph_builder.emit('Reshape', [dy_reduce], attrs={'shape': input_y.shape})
        else:
            dy_out = dy_reduce
    else:
        dy_out = dy
    # NOTE(review): the scraped source is cut right before this line; the sibling
    # MinimumGrad._expand returns both grads here — confirm against upstream.
    return dx_out, dy_out

@ -22,59 +22,117 @@ class MinimumGrad(Expander):
"""MinimumGrad expander""" """MinimumGrad expander"""
def _check(self): def _check(self):
"""
检查MinimumGrad类的属性是否满足要求
Args:
Returns:
bool: 如果属性符合要求则返回True否则抛出GKException异常
Raises:
GKException: 如果MinimumGrad类的属性'grad_x''grad_y'均为False则抛出此异常
"""
# 如果attr 'grad_x'和'grad_y'的值都为False则抛出异常
if not self.attrs.get('grad_x', True) and not self.attrs.get('grad_y', True): if not self.attrs.get('grad_x', True) and not self.attrs.get('grad_y', True):
raise GKException("For 'MinimumGrad', value of attr 'grad_x' and 'grad_y' should be False, but got {} and " raise GKException("For 'MinimumGrad', value of attr 'grad_x' and 'grad_y' should be False, but got {} and "
"{}".format(self.attrs.get('grad_x'), self.attrs.get('grad_y'))) "{}".format(self.attrs.get('grad_x'), self.attrs.get('grad_y')))
# 调用父类的方法
return super(MinimumGrad, self)._check() return super(MinimumGrad, self)._check()
def _expand(self, graph_builder):
    """Expand MinimumGrad: route dout to whichever input was the minimum.

    Args:
        graph_builder (GraphBuilder): builder used to emit primitive ops.

    Returns:
        tuple: (dx_out, dy_out) — gradients w.r.t. the two inputs, reduced
        back to the original input shapes when broadcasting occurred.
    """
    input_x, input_y, input_dout = self.inputs
    # mask of positions where x <= y, cast to x's dtype for arithmetic
    le_result = graph_builder.emit('LessEqual', [input_x, input_y])
    le_result = graph_builder.emit('Cast', [le_result], attrs={'dst_type': input_x.dtype})
    dx = graph_builder.emit('Mul', [le_result, input_dout])
    # the remainder of dout goes to y
    dy = graph_builder.emit('Sub', [input_dout, dx])

    # for minimumgrad op, output_shape should be equal to input_shape,
    # but some elementwise operating may broadcast input_shape
    # then output_shape not equal to original input_shape, so need to reduce output to let them equal
    reduce_axis_x = self.get_reduce_axis(input_x.shape, dx.shape)
    reduce_axis_y = self.get_reduce_axis(input_y.shape, dy.shape)
    if reduce_axis_x:
        dx_reduce = graph_builder.emit('ReduceSum', [dx], attrs={'reduce_axis': reduce_axis_x, 'keep_dims': False})
        if dx_reduce.shape != input_x.shape:
            dx_out = graph_builder.emit('Reshape', [dx_reduce], attrs={'shape': input_x.shape})
        else:
            dx_out = dx_reduce
    else:
        dx_out = dx
    if reduce_axis_y:
        dy_reduce = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': reduce_axis_y, 'keep_dims': False})
        if dy_reduce.shape != input_y.shape:
            dy_out = graph_builder.emit('Reshape', [dy_reduce], attrs={'shape': input_y.shape})
        else:
            dy_out = dy_reduce
    else:
        dy_out = dy
    # output two results, regardless of grad_x and grad_y
    return dx_out, dy_out
@staticmethod @staticmethod
def get_reduce_axis(original_shape, broadcast_shape): def get_reduce_axis(original_shape, broadcast_shape):
"""
计算最终输出形状的归约轴
Args:
original_shape (tuple of int): 原始形状一个包含整数的元组
broadcast_shape (tuple of int): 广播形状一个包含整数的元组
Returns:
list of int: 归约轴列表表示在最终输出形状中需要归约的轴索引
Raises:
ValueError: 如果original_shape的长度大于broadcast_shape的长度或者original_shape和broadcast_shape无法广播
"""
"""compute reduce axis for final output_shape""" """compute reduce axis for final output_shape"""
# 如果original_shape的长度大于broadcast_shape的长度
if len(original_shape) > len(broadcast_shape): if len(original_shape) > len(broadcast_shape):
raise ValueError("For 'MinimumGrad', the length of original_shape should be less than or equal to the " raise ValueError("For 'MinimumGrad', the length of original_shape should be less than or equal to the "
"length of broadcast_shape, but got {} and {}".format(original_shape, broadcast_shape)) "length of broadcast_shape, but got {} and {}".format(original_shape, broadcast_shape))
# 创建一个tmp_shape列表长度为broadcast_shape的长度前面填充1后面填充original_shape的元素
tmp_shape = [1] * (len(broadcast_shape) - len(original_shape)) + original_shape tmp_shape = [1] * (len(broadcast_shape) - len(original_shape)) + original_shape
reduce_axis = [] reduce_axis = []
# 遍历tmp_shape中的每个元素
for idx, _ in enumerate(tmp_shape): for idx, _ in enumerate(tmp_shape):
# 如果tmp_shape中的元素与broadcast_shape中的对应元素不相等
if tmp_shape[idx] != broadcast_shape[idx]: if tmp_shape[idx] != broadcast_shape[idx]:
# 如果tmp_shape中的元素为1
if tmp_shape[idx] == 1: if tmp_shape[idx] == 1:
# 将当前索引添加到reduce_axis列表中
reduce_axis.append(idx) reduce_axis.append(idx)
else: else:
# 抛出异常表示original_shape和broadcast_shape无法广播
raise ValueError("For 'MinimumGrad', original_shape {} and broadcast_shape {} can not broadcast." raise ValueError("For 'MinimumGrad', original_shape {} and broadcast_shape {} can not broadcast."
.format(original_shape, broadcast_shape)) .format(original_shape, broadcast_shape))
return reduce_axis return reduce_axis

@ -20,7 +20,24 @@ class OnesLike(Expander):
"""OnesLike expander""" """OnesLike expander"""
def _expand(self, graph_builder):
    """Expand OnesLike: broadcast a scalar 1 to the input's shape.

    Args:
        graph_builder (GraphBuilder): builder used to emit primitive ops.

    Returns:
        A tensor of ones with the same shape and dtype as the input.
    """
    input_x = self.inputs[0]
    # scalar constant 1 in the input's dtype
    const_one = graph_builder.value(input_x.dtype, 1)
    result = graph_builder.emit('BroadcastTo', [const_one], attrs={'shape': input_x.shape})
    return result

@ -23,21 +23,40 @@ class ReduceMean(Expander):
"""ReduceMean expander""" """ReduceMean expander"""
def _expand(self, graph_builder):
    """Expand ReduceMean as ReduceSum followed by division by the element count.

    Args:
        graph_builder (GraphBuilder): builder used to emit primitive ops.

    Returns:
        The mean of the input over the configured axes.
    """
    x = self.inputs[0]
    axis = self.attrs['axis']
    keep_dims = self.attrs['keep_dims']
    # normalize axis: scalar -> 1-tuple; empty -> all dimensions
    if not isinstance(axis, (tuple, list)):
        axis = (axis,)
    elif not axis:
        axis = list(range(len(x.shape)))
    # number of elements reduced over
    reduce_size = 1.0
    for idx in axis:
        reduce_size *= x.shape[idx]
    reduce_size_value = graph_builder.value(x.dtype, reduce_size)
    sum_x = graph_builder.emit('ReduceSum', [x], attrs={'reduce_axis': axis, 'keep_dims': keep_dims})
    result = graph_builder.emit('RealDiv', [sum_x, reduce_size_value])
    return result

@ -21,12 +21,28 @@ class ReluGrad(Expander):
"""ReLU expander""" """ReLU expander"""
def _expand(self, graph_builder):
    """Expand ReluGrad: pass the incoming gradient through where y > 0.

    Args:
        graph_builder (GraphBuilder): builder used to emit primitive ops.

    Returns:
        The gradient masked by (y > 0).
    """
    input_x = self.inputs[0]
    input_y = self.inputs[1]
    # scalar 0 in y's dtype for the comparison
    const_zero = graph_builder.value(input_y.dtype, 0)
    # boolean mask of positions where y is strictly positive
    ge_result = graph_builder.emit('Greater', [input_y, const_zero])
    ge_result = graph_builder.emit('Cast', [ge_result], attrs={'dst_type': input_x.dtype})
    result = graph_builder.emit('Mul', [ge_result, input_x])
    return result

@ -21,21 +21,46 @@ class SigmoidCrossEntropyWithLogits(Expander):
"""SigmoidCrossEntropyWithLogits expander""" """SigmoidCrossEntropyWithLogits expander"""
def _expand(self, graph_builder):
    """Expand SigmoidCrossEntropyWithLogits.

    Args:
        graph_builder (GraphBuilder): builder used to emit primitive ops.

    Returns:
        The elementwise sigmoid cross-entropy loss.
    """
    logits, labels = self.inputs
    # Calculate sigmoid_cross_entropy_with_logits(logits, labels)
    # formula of sigmoid_cross_entropy_with_logits is:
    # -(labels * log(sigmoid(logits)) + (1 - labels) * log(1 - sigmoid(logits)))
    # To ensure stability and avoid overflow, the formula equal to :
    # max(logits, 0) - logits * labels + log(1 + exp(-abs(logits)))
    const_one = graph_builder.value(logits.dtype, 1.0)
    const_zero = graph_builder.value(logits.dtype, 0.0)
    max_logits = graph_builder.emit('Maximum', [logits, const_zero])
    logits_mul_labels = graph_builder.emit('Mul', [logits, labels])
    abs_logits = graph_builder.emit('Abs', [logits])
    neg_abs_logits = graph_builder.emit('Neg', [abs_logits])
    exp_neg_abs_logits = graph_builder.emit('Exp', [neg_abs_logits])
    one_add_exp_neg_abs_logits = graph_builder.emit('Add', [const_one, exp_neg_abs_logits])
    log_one_add_exp_neg_abs_logits = graph_builder.emit('Log', [one_add_exp_neg_abs_logits])
    res_tmp = graph_builder.emit('Sub', [max_logits, logits_mul_labels])
    res = graph_builder.emit('Add', [res_tmp, log_one_add_exp_neg_abs_logits])
    return res

@ -21,15 +21,29 @@ class SigmoidCrossEntropyWithLogitsGrad(Expander):
"""SigmoidCrossEntropyWithLogitsGrad expander""" """SigmoidCrossEntropyWithLogitsGrad expander"""
def _expand(self, graph_builder): def _expand(self, graph_builder):
"""
计算sigmoid交叉熵损失的梯度
Args:
graph_builder: 图构建器对象用于构建计算图
Returns:
计算得到的梯度值
"""
logits, label, dout = self.inputs logits, label, dout = self.inputs
# Calculate sigmoid_cross_entropy_with_logits_grad(logits, label, dout) # 计算sigmoid_cross_entropy_with_logits_grad(logits, label, dout)
# formula of sigmoid_cross_entropy_with_logits_grad is : # sigmoid_cross_entropy_with_logits_grad的公式为:
# (sigmoid(logits) - label) * dout # (sigmoid(logits) - label) * dout
# 计算sigmoid(logits)
# Calculate sigmoid(logits)
const_one = graph_builder.value(logits.dtype, 1.0) const_one = graph_builder.value(logits.dtype, 1.0)
neg_x = graph_builder.emit('Neg', [logits]) neg_x = graph_builder.emit('Neg', [logits]) # 计算-logits
exp_neg_x = graph_builder.emit('Exp', [neg_x]) exp_neg_x = graph_builder.emit('Exp', [neg_x]) # 计算e^(-logits)
add_exp = graph_builder.emit('Add', [const_one, exp_neg_x]) add_exp = graph_builder.emit('Add', [const_one, exp_neg_x]) # 计算1 + e^(-logits)
sigmoid_res = graph_builder.emit('RealDiv', [const_one, add_exp]) sigmoid_res = graph_builder.emit('RealDiv', [const_one, add_exp]) # 计算1 / (1 + e^(-logits))即sigmoid(logits)
# 计算(sigmoid(logits) - label)
sigmoid_res_sub_label = graph_builder.emit('Sub', [sigmoid_res, label]) sigmoid_res_sub_label = graph_builder.emit('Sub', [sigmoid_res, label])
res = graph_builder.emit('Mul', [sigmoid_res_sub_label, dout]) # 计算最终结果
res = graph_builder.emit('Mul', [sigmoid_res_sub_label, dout]) # 计算(sigmoid(logits) - label) * dout
return res return res

@ -16,16 +16,31 @@
from ._utils import Expander, ExpanderInfoValidator as VLD from ._utils import Expander, ExpanderInfoValidator as VLD
@VLD.check_all_formats_same @VLD.check_all_formats_same # 定义一个SigmoidGrad类继承自Expander类
class SigmoidGrad(Expander): class SigmoidGrad(Expander):
"""SigmoidGrad expander""" """SigmoidGrad expander"""
def _expand(self, graph_builder): def _expand(self, graph_builder):
"""
计算 sigmoid 函数的梯度
Args:
graph_builder (GraphBuilder): 图构建器对象用于在图中添加操作
Returns:
Tensor: 计算得到的 sigmoid 梯度
"""
input_y, dy = self.inputs input_y, dy = self.inputs
# 计算 sigmoid_grad(y, dy)
# sigmoid_grad 的公式是: (1 - y) * y * dy
# Calculate sigmoid_grad(y, dy) # Calculate sigmoid_grad(y, dy)
# formula of sigmoid_grad is : (1 - y) * y * dy # formula of sigmoid_grad is : (1 - y) * y * dy
const_one = graph_builder.value(input_y.dtype, 1.0) const_one = graph_builder.value(input_y.dtype, 1.0)
# 1 - y
one_mins_y = graph_builder.emit('Sub', [const_one, input_y]) one_mins_y = graph_builder.emit('Sub', [const_one, input_y])
# y * dy
y_mul_dy = graph_builder.emit('Mul', [input_y, dy]) y_mul_dy = graph_builder.emit('Mul', [input_y, dy])
# (1 - y) * (y * dy)
res = graph_builder.emit('Mul', [one_mins_y, y_mul_dy]) res = graph_builder.emit('Mul', [one_mins_y, y_mul_dy])
return res return res

@ -21,15 +21,41 @@ class Slice(Expander):
"""Slice expander""" """Slice expander"""
def _expand(self, graph_builder): def _expand(self, graph_builder):
"""
在图中扩展输入张量
Args:
graph_builder (GraphBuilder): 图构建器对象
Returns:
Tensor: 扩展后的输出张量
"""
# 获取输入张量
input_x = self.inputs[0] input_x = self.inputs[0]
# 获取开始索引
begin = self.attrs['begin'] begin = self.attrs['begin']
# 获取切片大小
size = self.attrs['size'] size = self.attrs['size']
# 初始化结束索引列表
end = [] end = []
# 初始化步长列表
strides = [] strides = []
# 遍历每个维度,计算结束索引和步长
for i, begin_idx in enumerate(begin): for i, begin_idx in enumerate(begin):
# 步长设置为1
strides.append(1) strides.append(1)
# 计算结束索引
end.append(begin_idx + size[i]) end.append(begin_idx + size[i])
# 创建一个新的张量作为输出
output = graph_builder.tensor(size, input_x.dtype, input_x.data_format) output = graph_builder.tensor(size, input_x.dtype, input_x.data_format)
# 执行StridedSlice操作对输入张量进行切片
graph_builder.op('StridedSlice', output, [input_x], attrs={'begin': begin, 'end': end, 'strides': strides}) graph_builder.op('StridedSlice', output, [input_x], attrs={'begin': begin, 'end': end, 'strides': strides})
# 返回输出张量
return output return output

@ -25,45 +25,75 @@ class Softmax(Expander):
"""Softmax expander""" """Softmax expander"""
def _expand(self, graph_builder): def _expand(self, graph_builder):
"""
计算Softmax函数值
Args:
graph_builder: 图构建器对象
Returns:
Softmax函数的计算结果
"""
# 获取输入数据
input_x = self.inputs[0] input_x = self.inputs[0]
# 获取处理器
processor = self.processor processor = self.processor
# 获取轴信息
axis = self.attrs['axis'] axis = self.attrs['axis']
# 获取输入数据的原始形状
ori_shape = input_x.shape ori_shape = input_x.shape
# 如果输入数据格式为FRAC_NZ则推断其形状
if input_x.data_format == DF.FRAC_NZ: if input_x.data_format == DF.FRAC_NZ:
ori_shape = infer_shape_from_fractalnz(input_x.shape) ori_shape = infer_shape_from_fractalnz(input_x.shape)
# 遍历轴信息,处理负数轴索引
for i, _ in enumerate(list(axis)): for i, _ in enumerate(list(axis)):
if axis[i] < 0: if axis[i] < 0:
axis[i] += len(ori_shape) axis[i] += len(ori_shape)
# 获取减少维度后的原始形状
ori_reduced_shape = get_reduced_ori_shape(ori_shape, axis) ori_reduced_shape = get_reduced_ori_shape(ori_shape, axis)
# 获取减少的轴
reduce_axis = axis reduce_axis = axis
# 如果输入数据格式为FRAC_NZ则转换轴
if input_x.data_format == DF.FRAC_NZ: if input_x.data_format == DF.FRAC_NZ:
reduce_axis = to_frac_z_axis(ori_shape, axis) reduce_axis = to_frac_z_axis(ori_shape, axis)
# 获取输入数据的原始数据类型
ori_dtype = input_x.dtype ori_dtype = input_x.dtype
# 如果原始数据类型不是float16且处理器为aicore则进行类型转换
if ori_dtype != "float16" and processor == "aicore": if ori_dtype != "float16" and processor == "aicore":
input_x_f16 = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float16'}) input_x_f16 = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float16'})
max_x_f16 = graph_builder.emit('ReduceMax', [input_x_f16], attrs={'reduce_axis': reduce_axis, max_x_f16 = graph_builder.emit('ReduceMax', [input_x_f16], attrs={'reduce_axis': reduce_axis,
'keep_dims': True}) 'keep_dims': True})
max_x = graph_builder.emit('Cast', [max_x_f16], attrs={'dst_type': ori_dtype}) max_x = graph_builder.emit('Cast', [max_x_f16], attrs={'dst_type': ori_dtype})
else: else:
# 计算最大值
max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': True}) max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
# 如果原始数据类型为float16且处理器为aicore则进行类型转换
if ori_dtype == "float16" and processor == "aicore": if ori_dtype == "float16" and processor == "aicore":
max_x = graph_builder.emit('Cast', [max_x], attrs={'dst_type': "float32"}) max_x = graph_builder.emit('Cast', [max_x], attrs={'dst_type': "float32"})
input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': "float32"}) input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': "float32"})
# 如果输入数据格式为FRAC_NZ则重新调整最大值的形状
if input_x.data_format == DF.FRAC_NZ: if input_x.data_format == DF.FRAC_NZ:
max_x = graph_builder.emit('Reshape', [max_x], attrs={'shape': ori_reduced_shape}) max_x = graph_builder.emit('Reshape', [max_x], attrs={'shape': ori_reduced_shape})
# 计算输入数据减去最大值的差值
data_sub = graph_builder.emit('Sub', [input_x, max_x]) data_sub = graph_builder.emit('Sub', [input_x, max_x])
# 计算差值的指数
data_exp = graph_builder.emit('Exp', [data_sub]) data_exp = graph_builder.emit('Exp', [data_sub])
# 计算指数的和
data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': reduce_axis, 'keep_dims': True}) data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
# 如果输入数据格式为FRAC_NZ则重新调整指数和的形状
if input_x.data_format == DF.FRAC_NZ: if input_x.data_format == DF.FRAC_NZ:
data_expsum = graph_builder.emit('Reshape', [data_expsum], attrs={'shape': ori_reduced_shape}) data_expsum = graph_builder.emit('Reshape', [data_expsum], attrs={'shape': ori_reduced_shape})
# 计算Softmax值
result = graph_builder.emit('RealDiv', [data_exp, data_expsum]) result = graph_builder.emit('RealDiv', [data_exp, data_expsum])
# 如果原始数据类型为float16且处理器为aicore则进行类型转换
if ori_dtype == "float16" and processor == "aicore": if ori_dtype == "float16" and processor == "aicore":
result = graph_builder.emit('Cast', [result], attrs={'dst_type': ori_dtype}) result = graph_builder.emit('Cast', [result], attrs={'dst_type': ori_dtype})

@ -22,21 +22,45 @@ class SoftmaxCrossEntropyWithLogits(Expander):
"""SoftmaxCrossEntropyWithLogits expander""" """SoftmaxCrossEntropyWithLogits expander"""
def _expand(self, graph_builder): def _expand(self, graph_builder):
"""
计算损失值和 logits 的梯度
Args:
graph_builder (GraphBuilder): 图构建器对象
Returns:
Tuple[Tensor, Tensor]: 损失值和 logits 的梯度
"""
logits, label = self.inputs logits, label = self.inputs
# 计算 softmax_cross_entropy_with_logits(logits, label)
# softmax_cross_entropy_with_logits 的公式是: -reduce_sum(label * log(softmax(logits)))
# Calculate softmax_cross_entropy_with_logits(logits, label) # Calculate softmax_cross_entropy_with_logits(logits, label)
# formula of softmax_cross_entropy_with_logits is : -reduce_sum(label * log(softmax(logits))) # formula of softmax_cross_entropy_with_logits is : -reduce_sum(label * log(softmax(logits)))
axis = (-1,) axis = (-1,)
max_x = graph_builder.emit('ReduceMax', [logits], attrs={'reduce_axis': axis, 'keep_dims': True}) max_x = graph_builder.emit('ReduceMax', [logits], attrs={'reduce_axis': axis, 'keep_dims': True})
# 计算 logits 的最大值
data_sub = graph_builder.emit('Sub', [logits, max_x]) data_sub = graph_builder.emit('Sub', [logits, max_x])
# logits 减去最大值
data_exp = graph_builder.emit('Exp', [data_sub]) data_exp = graph_builder.emit('Exp', [data_sub])
# 对上一步结果进行指数运算
data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': axis, 'keep_dims': True}) data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': axis, 'keep_dims': True})
# 对指数运算结果求和
data_softmax = graph_builder.emit('RealDiv', [data_exp, data_expsum]) data_softmax = graph_builder.emit('RealDiv', [data_exp, data_expsum])
# 计算 softmax
const_eps = graph_builder.value(logits.dtype, 0.000001) const_eps = graph_builder.value(logits.dtype, 0.000001)
# 定义一个极小的常数,用于防止除以零的错误
data_softmax_safety = graph_builder.emit("Maximum", [data_softmax, const_eps]) data_softmax_safety = graph_builder.emit("Maximum", [data_softmax, const_eps])
# 确保 softmax 的值不为零
softmax_log = graph_builder.emit('Log', [data_softmax_safety]) softmax_log = graph_builder.emit('Log', [data_softmax_safety])
# 对 softmax 结果取对数
label_mul_log = graph_builder.emit('Mul', [label, softmax_log]) label_mul_log = graph_builder.emit('Mul', [label, softmax_log])
# 将 label 与 softmax 的对数相乘
tmp_res = data_expsum = graph_builder.emit('ReduceSum', [label_mul_log], attrs={ tmp_res = data_expsum = graph_builder.emit('ReduceSum', [label_mul_log], attrs={
'reduce_axis': axis, 'keep_dims': False}) 'reduce_axis': axis, 'keep_dims': False})
# 对上一步结果进行求和
loss = graph_builder.emit('Neg', [tmp_res]) loss = graph_builder.emit('Neg', [tmp_res])
# 计算损失值,即上一步结果的负值
dlogits = graph_builder.emit('Sub', [data_softmax, label]) dlogits = graph_builder.emit('Sub', [data_softmax, label])
# 计算 logits 的梯度
return loss, dlogits return loss, dlogits

@ -13,29 +13,48 @@
# limitations under the License. # limitations under the License.
# =========================================================================== # ===========================================================================
"""generate json desc for SoftmaxGradExt""" """generate json desc for SoftmaxGradExt"""
from mindspore._extends.graph_kernel.model.model import DataFormat as DF from mindspore._extends.graph_kernel.model.model import DataFormat as DF # 导入DataFormat类
from ._utils import Expander, ExpanderInfoValidator as VLD from ._utils import Expander, ExpanderInfoValidator as VLD # 导入Expander和ExpanderInfoValidator类
from ._utils import get_reduce_axis_shape from ._utils import get_reduce_axis_shape # 导入get_reduce_axis_shape函数
@VLD.add_format(DF.FRAC_NZ, DF.FRAC_NZ, DF.DEFAULT) @VLD.add_format(DF.FRAC_NZ, DF.FRAC_NZ, DF.DEFAULT) # 使用ExpanderInfoValidator类添加FRAC_NZ格式
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT) @VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('axis') @VLD.check_attrs('axis')
class SoftmaxGradExt(Expander): class SoftmaxGradExt(Expander):
"""SoftmaxGradExt expander""" """SoftmaxGradExt expander"""
def _expand(self, graph_builder): def _expand(self, graph_builder):
"""
对输入数据进行扩展处理
Args:
graph_builder (GraphBuilder): 图构建器对象
Returns:
Tensor: 处理后的数据
"""
# 获取输入参数
x, y, z = self.inputs x, y, z = self.inputs
# 获取指定的轴
axis = self.attrs['axis'] axis = self.attrs['axis']
# 获取需要减少的轴和原始减少的形状
reduce_axis, ori_reduced_shape = get_reduce_axis_shape(x.shape, x.data_format, axis) reduce_axis, ori_reduced_shape = get_reduce_axis_shape(x.shape, x.data_format, axis)
# 将x和y相乘
data_mul = graph_builder.emit('Mul', [x, y]) data_mul = graph_builder.emit('Mul', [x, y])
# 对乘积进行求和,并保留维度
data_sum = graph_builder.emit('ReduceSum', [data_mul], data_sum = graph_builder.emit('ReduceSum', [data_mul],
attrs={'reduce_axis': reduce_axis, 'keep_dims': True, 'reduce_output_fuse': True}) attrs={'reduce_axis': reduce_axis, 'keep_dims': True, 'reduce_output_fuse': True})
# 如果x的数据格式为FRAC_NZ则对求和结果进行重塑
if x.data_format == DF.FRAC_NZ: if x.data_format == DF.FRAC_NZ:
data_sum = graph_builder.emit('Reshape', [data_sum], attrs={'shape': ori_reduced_shape}) data_sum = graph_builder.emit('Reshape', [data_sum], attrs={'shape': ori_reduced_shape})
# 从x中减去求和结果
data_sub = graph_builder.emit('Sub', [x, data_sum]) data_sub = graph_builder.emit('Sub', [x, data_sum])
# 将减法结果与y相乘
data_mul2 = graph_builder.emit('Mul', [data_sub, y]) data_mul2 = graph_builder.emit('Mul', [data_sub, y])
# 将结果与z相乘得到最终结果
result = graph_builder.emit('Mul', [data_mul2, z]) result = graph_builder.emit('Mul', [data_mul2, z])
return result return result

@ -21,9 +21,24 @@ class SqrtGrad(Expander):
"""SqrtGrad expander""" """SqrtGrad expander"""
def _expand(self, graph_builder): def _expand(self, graph_builder):
"""
计算并返回给定输入 x 的平方根的梯度
Args:
graph_builder (GraphBuilder): 图构建器对象用于构建计算图
Returns:
Tensor: 返回给定输入 x 的平方根的梯度
"""
# 获取输入 x 和梯度 dout
# formula of sqrt_grad is dout / (2 * x) # formula of sqrt_grad is dout / (2 * x)
x, dout = self.inputs x, dout = self.inputs
# 创建一个常数值 2
const_two = graph_builder.value(x.dtype, 2) const_two = graph_builder.value(x.dtype, 2)
# 计算 2 * x
dividend = graph_builder.emit('Mul', [x, const_two]) dividend = graph_builder.emit('Mul', [x, const_two])
# 计算梯度dout / (2 * x)
result = graph_builder.emit('RealDiv', [dout, dividend]) result = graph_builder.emit('RealDiv', [dout, dividend])
# 返回计算结果
return result return result

@ -21,24 +21,57 @@ class SquareSumAll(Expander):
"""SquareSumAll expander""" """SquareSumAll expander"""
def _check(self): def _check(self):
"""
检查输入是否合法
Args:
Returns:
Raises:
GKException: 如果输入的数量不等于2则抛出GKException异常
"""
"""check inputs""" """check inputs"""
# 获取输入的数量
input_num = len(self.inputs) input_num = len(self.inputs)
if input_num != 2: if input_num != 2:
# 如果输入的数量不等于2则抛出异常
raise GKException("For 'SquareSumAll', the inputs number should be 2, but got {}.".format(input_num)) raise GKException("For 'SquareSumAll', the inputs number should be 2, but got {}.".format(input_num))
def _expand(self, graph_builder): def _expand(self, graph_builder):
"""
对输入的两个变量进行扩展操作
Args:
graph_builder (GraphBuilder): 图构建器对象用于在图构建过程中发射操作
Returns:
tuple: 包含两个元素的元组每个元素为扩展操作的结果
"""
"""do expand""" """do expand"""
# 获取输入的两个变量
x0 = self.inputs[0] x0 = self.inputs[0]
x1 = self.inputs[1] x1 = self.inputs[1]
# 获取x0的形状
ori_shape = x0.shape ori_shape = x0.shape
# 初始化一个空列表,用于存储维度索引
axis = [] axis = []
# 遍历ori_shape将每个维度的索引添加到axis列表中
for i, _ in enumerate(ori_shape): for i, _ in enumerate(ori_shape):
axis.append(i) axis.append(i)
# 对x0进行平方运算
square_res0 = graph_builder.emit('Mul', [x0, x0]) square_res0 = graph_builder.emit('Mul', [x0, x0])
# 对x1进行平方运算
square_res1 = graph_builder.emit('Mul', [x1, x1]) square_res1 = graph_builder.emit('Mul', [x1, x1])
# 对square_res0进行求和运算求和的维度为axis并保持维度不变
result0 = graph_builder.emit('ReduceSum', [square_res0], attrs={'reduce_axis': axis, 'keep_dims': False}) result0 = graph_builder.emit('ReduceSum', [square_res0], attrs={'reduce_axis': axis, 'keep_dims': False})
# 对square_res1进行求和运算求和的维度为axis并保持维度不变
result1 = graph_builder.emit('ReduceSum', [square_res1], attrs={'reduce_axis': axis, 'keep_dims': False}) result1 = graph_builder.emit('ReduceSum', [square_res1], attrs={'reduce_axis': axis, 'keep_dims': False})
return result0, result1 return result0, result1

@ -25,13 +25,30 @@ class SquareSumV1(Expander):
"""Square expander""" """Square expander"""
def _expand(self, graph_builder): def _expand(self, graph_builder):
"""
计算输入张量的平方并沿指定轴进行求和
Args:
graph_builder (GraphBuilder): 图构建器对象
Returns:
Tensor: 计算得到的张量
"""
# 获取输入的第一个元素
x = self.inputs[0] x = self.inputs[0]
# 获取属性中的axis值
axis = self.attrs['axis'] axis = self.attrs['axis']
# 获取需要reduce的axis和原始的reduced shape
reduce_axis, ori_reduced_shape = get_reduce_axis_shape(x.shape, x.data_format, axis) reduce_axis, ori_reduced_shape = get_reduce_axis_shape(x.shape, x.data_format, axis)
# 计算x的平方
square_res = graph_builder.emit('Mul', [x, x]) square_res = graph_builder.emit('Mul', [x, x])
# 对平方结果进行ReduceSum操作
result = graph_builder.emit('ReduceSum', [square_res], attrs={'reduce_axis': reduce_axis, 'keep_dims': False}) result = graph_builder.emit('ReduceSum', [square_res], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
# 如果数据格式为FRAC_NZ则对结果进行Reshape操作
if x.data_format == DF.FRAC_NZ: if x.data_format == DF.FRAC_NZ:
result = graph_builder.emit('Reshape', [result], attrs={'shape': ori_reduced_shape}) result = graph_builder.emit('Reshape', [result], attrs={'shape': ori_reduced_shape})
# 返回最终结果
return result return result

@ -21,10 +21,24 @@ class SquaredDifference(Expander):
"""SquaredDifference expander""" """SquaredDifference expander"""
def _expand(self, graph_builder): def _expand(self, graph_builder):
"""
根据输入的两个输入值计算并返回它们的平方差的计算结果
Args:
graph_builder (GraphBuilder): 图构建器对象用于在图中生成节点和边
Returns:
Node: 计算结果节点
"""
# 获取输入的第一个值
input_x = self.inputs[0] input_x = self.inputs[0]
# 获取输入的第二个值
input_y = self.inputs[1] input_y = self.inputs[1]
# 使用图构建器计算输入值的差值
sub_val = graph_builder.emit('Sub', [input_x, input_y]) sub_val = graph_builder.emit('Sub', [input_x, input_y])
# 使用图构建器计算差值的平方
result = graph_builder.emit('Mul', [sub_val, sub_val]) result = graph_builder.emit('Mul', [sub_val, sub_val])
return result return result

@ -21,27 +21,67 @@ class Squeeze(Expander):
"""Squeeze expander""" """Squeeze expander"""
def _expand(self, graph_builder): def _expand(self, graph_builder):
"""
扩展输入的维度
Args:
graph_builder (GraphBuilder): 图构建器对象用于构建图结构
Returns:
Tensor: 扩展维度后的输入
"""
# 获取输入的第一个元素
input_x = self.inputs[0] input_x = self.inputs[0]
# 根据输入的shape和axis属性推断输出shape
out_shape = self.infer_shape(input_x.shape, self.attrs['axis']) out_shape = self.infer_shape(input_x.shape, self.attrs['axis'])
# 使用graph_builder发射Reshape操作并设置shape属性为out_shape
result = graph_builder.emit('Reshape', [input_x], attrs={'shape': out_shape}) result = graph_builder.emit('Reshape', [input_x], attrs={'shape': out_shape})
# 返回结果
return result return result
@staticmethod @staticmethod
def infer_shape(shape, axis): def infer_shape(shape, axis):
"""
根据指定的axis推断squeeze后的shape
Args:
shape (list, tuple): 原始数据的shape
axis (int, list, tuple): 需要被squeeze的维度如果为int则只squeeze该维度
如果为list或tuple则squeeze列表或元组中的每个维度如果为空则squeeze所有维度为1的维度
Returns:
list: squeeze后的shape
Raises:
ValueError: 如果输入的axis类型不符合要求抛出异常
"""
"""infer shape for squeeze""" """infer shape for squeeze"""
def squeeze_axis(shape, axis): def squeeze_axis(shape, axis):
# 如果axis为空移除shape中所有值为1的维度
if not axis: if not axis:
out_shape = list(d for d in shape if d != 1) out_shape = list(d for d in shape if d != 1)
else: else:
# 获取shape的维度数量
ndim = len(shape) ndim = len(shape)
# 移除shape中指定的axis维度
out_shape = list(shape[i] for i in range(ndim) if not (i in axis or (i - ndim) in axis)) out_shape = list(shape[i] for i in range(ndim) if not (i in axis or (i - ndim) in axis))
# 如果out_shape为空则将其设置为[1]
if not out_shape: if not out_shape:
out_shape = [1] out_shape = [1]
return out_shape return out_shape
# 如果shape是列表或元组类型
if isinstance(shape, (list, tuple)): if isinstance(shape, (list, tuple)):
# 如果axis是整数类型则将其转换为列表
if isinstance(axis, int): if isinstance(axis, int):
axis = [axis] axis = [axis]
# 如果axis是列表或元组类型则调用squeeze_axis函数处理
if isinstance(axis, (list, tuple)): if isinstance(axis, (list, tuple)):
return squeeze_axis(shape, axis) return squeeze_axis(shape, axis)
# 如果输入不符合要求,则抛出异常
raise ValueError("Invalid axis for Squeeze.") raise ValueError("Invalid axis for Squeeze.")

@ -21,11 +21,29 @@ class TanhGrad(Expander):
"""TanhGrad expander""" """TanhGrad expander"""
def _expand(self, graph_builder): def _expand(self, graph_builder):
"""
计算1减去输入值的平方后再与输入的导数相乘
Args:
graph_builder (GraphBuilder): 图构建器对象用于构建计算图
Returns:
Tensor: 计算结果类型为Tensor
"""
# 获取输入值
input_y, input_dy = self.inputs input_y, input_dy = self.inputs
# 创建一个值为1的常量数据类型与input_y相同
const_one = graph_builder.value(input_y.dtype, 1) const_one = graph_builder.value(input_y.dtype, 1)
# 计算input_y的平方
double_y = graph_builder.emit('Mul', [input_y, input_y]) double_y = graph_builder.emit('Mul', [input_y, input_y])
# 计算1减去input_y的平方
one_sub_double_y = graph_builder.emit('Sub', [const_one, double_y]) one_sub_double_y = graph_builder.emit('Sub', [const_one, double_y])
# 计算input_dy与1减去input_y的平方的乘积
result = graph_builder.emit('Mul', [input_dy, one_sub_double_y]) result = graph_builder.emit('Mul', [input_dy, one_sub_double_y])
return result return result

@ -25,30 +25,48 @@ class Tile(Expander):
def _get_output_shape(self): def _get_output_shape(self):
"""Get output shape""" """Get output shape"""
# 获取输入形状的列表
shape = list(self.inputs[0].shape) shape = list(self.inputs[0].shape)
# 获取属性"multiples"的列表
multiples = list(self.attrs["multiples"]) multiples = list(self.attrs["multiples"])
# 计算"multiples"和输入形状的长度差
diff_len = len(multiples) - len(shape) diff_len = len(multiples) - len(shape)
# 如果长度差小于0抛出异常
if diff_len < 0: if diff_len < 0:
raise GKException("For 'Tile', dimensions of attr 'multiples' should be greater than or equal to " raise GKException("For 'Tile', dimensions of attr 'multiples' should be greater than or equal to "
"dimensions of input shape, but got {} and {}".format(multiples, shape)) "dimensions of input shape, but got {} and {}".format(multiples, shape))
# 如果长度差大于0则扩展输入形状的列表
if diff_len > 0: if diff_len > 0:
for _ in range(diff_len): for _ in range(diff_len):
shape.insert(0, 1) shape.insert(0, 1)
# 初始化输出形状的列表
output_shape = [] output_shape = []
# 遍历输入形状和multiples的元组
for sh, mul in list(zip(shape, multiples)): for sh, mul in list(zip(shape, multiples)):
# 如果输入形状和multiples的值都不为1则抛出异常
if sh != 1 and mul != 1: if sh != 1 and mul != 1:
raise GKException("For 'Tile', input shape{} and attr 'multiples'{} can not broadcast." raise GKException("For 'Tile', input shape{} and attr 'multiples'{} can not broadcast."
.format(self.inputs[0].shape, multiples)) .format(self.inputs[0].shape, multiples))
# 计算维度
dim = sh * mul dim = sh * mul
# 将计算得到的维度添加到输出形状的列表中
output_shape.append(dim) output_shape.append(dim)
# 返回输出形状的列表
return output_shape return output_shape
def _expand(self, graph_builder): def _expand(self, graph_builder):
# 获取输入的第一个元素
input_x = self.inputs[0] input_x = self.inputs[0]
# 获取输出形状
output_shape = self._get_output_shape() output_shape = self._get_output_shape()
# 使用graph_builder的emit方法生成BroadcastTo操作
# 参数为[input_x]和输出形状
result = graph_builder.emit('BroadcastTo', [input_x], attrs={'shape': output_shape}) result = graph_builder.emit('BroadcastTo', [input_x], attrs={'shape': output_shape})
# 返回结果
return result return result

@ -14,6 +14,9 @@
# =========================================================================== # ===========================================================================
"""GraphKernel cost model init""" """GraphKernel cost model init"""
# 导入split模块
from .graph_split import split from .graph_split import split
# 导入GraphBuilder和load_composite模块
from .model_builder import GraphBuilder, load_composite from .model_builder import GraphBuilder, load_composite
# 导入parallel_estimate模块
from .graph_parallel import parallel_estimate from .graph_parallel import parallel_estimate

@ -20,128 +20,299 @@ class ParalGain:
"""Paral Gain""" """Paral Gain"""
def __init__(self, fusion_type, bottleneck, gain, block_assign, type_info): def __init__(self, fusion_type, bottleneck, gain, block_assign, type_info):
"""
类的构造函数
Args:
fusion_type (str): 融合类型
bottleneck (int): 瓶颈层的大小
gain (float): 增益值
block_assign (list): 块分配列表
type_info (dict): 类型信息字典
Returns:
None
"""
# 初始化融合类型
self.fusion_type = fusion_type self.fusion_type = fusion_type
# 初始化瓶颈层
self.bottleneck = bottleneck self.bottleneck = bottleneck
# 初始化增益
self.gain = gain self.gain = gain
# 初始化块分配
self.block_assign = block_assign self.block_assign = block_assign
# 初始化类型信息
self.type_info = type_info self.type_info = type_info
class ScheduleAnalyzer: class ScheduleAnalyzer:
"""schedule analyzer""" """schedule analyzer"""
# 定义一个常量表示wrap的大小
WRAP_SIZE = 32 WRAP_SIZE = 32
# 定义一个常量表示最大SM数量
MAX_SM = 80 # Volta MAX_SM = 80 # Volta
# 定义一个常量,表示最大线程数量
MAX_NUM_THREADS = 1024 MAX_NUM_THREADS = 1024
# 定义一个常量表示最大block数量
MAX_BLOCK = 256 MAX_BLOCK = 256
# 定义一个常量,表示流水线操作的阈值
PIPELINE_OP_THREADHOLD = 5 PIPELINE_OP_THREADHOLD = 5
def __init__(self, graph): def __init__(self, graph):
"""
初始化图处理类
Args:
graph (Graph): 图对象用于存储图的结构和参数
Attributes:
graph (Graph): 图对象存储图的结构和参数
block_num (int): 块的数量初始值为0
block_weight (float): 块的权重初始值为0
ops (List[Operation]): 图的操作列表
dom_op (List[Operation]): 输出的每个操作对应的操作列表
"""
# 将传入的图对象赋值给实例变量graph
self.graph = graph self.graph = graph
# 初始化block数量为0
self.block_num = 0 self.block_num = 0
# 初始化block权重为0
self.block_weight = 0 self.block_weight = 0
# 通过图对象的deduce_parameters方法获取参数并赋值给outputs变量
_, outputs = graph.deduce_parameters() _, outputs = graph.deduce_parameters()
# 将图对象的操作列表赋值给实例变量ops
self.ops = graph.ops self.ops = graph.ops
# 将outputs中的每个输出对应的操作收集到一个列表中并赋值给实例变量dom_op
self.dom_op = list(out.op for out in outputs) self.dom_op = list(out.op for out in outputs)
@staticmethod @staticmethod
def prod(shape): def prod(shape):
"""
计算形状乘积
Args:
shape (list): 一个包含整数的列表表示形状
Returns:
int: 形状乘积的结果
"""
"""Compute shape product""" """Compute shape product"""
# 初始化结果变量为shape的第一个元素
res = shape[0] res = shape[0]
# 遍历shape列表从第二个元素开始
for i in range(1, len(shape)): for i in range(1, len(shape)):
# 将当前结果与shape的下一个元素相乘
res = res * shape[i] res = res * shape[i]
# 返回计算后的结果
return res return res
def _cal_weight(self, ops): def _cal_weight(self, ops):
"""
计算给定操作列表的总权重
Args:
ops (list): 包含多个操作对象的列表
Returns:
int: 所有操作的权重总和
"""
weight = 0 weight = 0
for op in ops: for op in ops:
# 遍历每个操作
weight += self.prod(op.output.shape) * \ weight += self.prod(op.output.shape) * \
PrimLib.dtype_bytes(op.output.dtype) PrimLib.dtype_bytes(op.output.dtype) # 计算op的输出数据类型的字节数
return weight return weight
def injective_analyze(self): def injective_analyze(self):
"""
分析单射情况
Args:
Returns:
"""
"""analyze injective case""" """analyze injective case"""
# 计算常量大小
const_size = max((self.prod(op.output.shape) for op in self.dom_op)) const_size = max((self.prod(op.output.shape) for op in self.dom_op))
# 调整常量大小确保是MAX_NUM_THREADS的倍数
const_size = (const_size + self.MAX_NUM_THREADS - const_size = (const_size + self.MAX_NUM_THREADS -
1) // self.MAX_NUM_THREADS * self.MAX_NUM_THREADS 1) // self.MAX_NUM_THREADS * self.MAX_NUM_THREADS
# 计算总权重
total_weight = self._cal_weight(self.ops) total_weight = self._cal_weight(self.ops)
# 计算总块数
total_block = (const_size + self.MAX_NUM_THREADS - total_block = (const_size + self.MAX_NUM_THREADS -
1) // self.MAX_NUM_THREADS 1) // self.MAX_NUM_THREADS
# 判断是否需要分割块
need_block_split = const_size > self.MAX_BLOCK * self.MAX_NUM_THREADS need_block_split = const_size > self.MAX_BLOCK * self.MAX_NUM_THREADS
if need_block_split: if need_block_split:
# 如果需要分割块设置块数为MAX_BLOCK
self.block_num = self.MAX_BLOCK self.block_num = self.MAX_BLOCK
# 计算波数
waves = (total_block + self.MAX_BLOCK - 1) // self.MAX_BLOCK waves = (total_block + self.MAX_BLOCK - 1) // self.MAX_BLOCK
# 计算块权重
self.block_weight = total_weight // total_block * waves self.block_weight = total_weight // total_block * waves
else: else:
# 如果不需要分割块,设置块数为总块数
self.block_num = total_block self.block_num = total_block
# 计算块权重
self.block_weight = total_weight // self.block_num self.block_weight = total_weight // self.block_num
def reduce_analyze(self): def reduce_analyze(self):
"""
分析reduce操作
Args:
Returns:
Raises:
RuntimeError: 如果并行融合不支持多个reduce操作或者没有找到reduce操作
"""
"""analyze reduce case""" """analyze reduce case"""
# 定义线程数
thread_x, thread_y = 32, 32 thread_x, thread_y = 32, 32
reduce_op = None reduce_op = None
for op in self.ops: for op in self.ops:
# 判断操作类型是否为reduce
if PrimLib.iter_type(op) == PrimLib.REDUCE: if PrimLib.iter_type(op) == PrimLib.REDUCE:
# 如果已经存在reduce操作则抛出异常
if reduce_op: if reduce_op:
raise RuntimeError("Parallel fusion does not support multiple reduce op now.") raise RuntimeError("Parallel fusion does not support multiple reduce op now.")
reduce_op = op reduce_op = op
# 如果没有找到reduce操作则抛出异常
if not reduce_op: if not reduce_op:
raise RuntimeError("Parallel fusion does not find a reduce op.") raise RuntimeError("Parallel fusion does not find a reduce op.")
# 获取reduce操作的输入形状
shape = reduce_op.inputs[0].shape shape = reduce_op.inputs[0].shape
# 获取reduce操作的reduce轴
reduce_axis = reduce_op.attrs['reduce_axis'] reduce_axis = reduce_op.attrs['reduce_axis']
# 计算总空间
total_space = self.prod(shape) total_space = self.prod(shape)
# 计算reduce空间
red_space = shape[reduce_axis[0]] red_space = shape[reduce_axis[0]]
for i in range(1, len(reduce_axis)): for i in range(1, len(reduce_axis)):
red_space *= shape[reduce_axis[i]] red_space *= shape[reduce_axis[i]]
# 获取数据类型大小
dtype_size = PrimLib.dtype_bytes(reduce_op.output.dtype) dtype_size = PrimLib.dtype_bytes(reduce_op.output.dtype)
# 计算权重
weight = self._cal_weight(self.ops) # reduce + injective weight = self._cal_weight(self.ops) # reduce + injective
# 计算block_x
block_x = (total_space // red_space + thread_y - 1) // thread_y block_x = (total_space // red_space + thread_y - 1) // thread_y
# 计算block_w
block_w = (weight + block_x - 1) // block_x block_w = (weight + block_x - 1) // block_x
# 计算waves
waves = (block_x + self.MAX_BLOCK - 1) // self.MAX_BLOCK waves = (block_x + self.MAX_BLOCK - 1) // self.MAX_BLOCK
# 设置block_num
self.block_num = min(self.MAX_BLOCK, block_x) self.block_num = min(self.MAX_BLOCK, block_x)
# 定义all_reduce
all_reduce = 10 # 1 reduce init + 3 sync + 5 bin + 1 write all_reduce = 10 # 1 reduce init + 3 sync + 5 bin + 1 write
# 计算block_weight
self.block_weight = (block_w + all_reduce * self.block_weight = (block_w + all_reduce *
dtype_size * thread_x * thread_y) * waves dtype_size * thread_x * thread_y) * waves
def default_analyze(self): def default_analyze(self):
"""
默认分析函数
Args:
Returns:
Raises:
"""
"""analyze default case""" """analyze default case"""
# 定义一个内部函数,用于计算默认空间
def _cal_default_space(op): def _cal_default_space(op):
# 计算op的输出空间
space = self.prod(op.output.shape) space = self.prod(op.output.shape)
# 遍历op的所有输入
for t in op.inputs: for t in op.inputs:
# 计算输入的空间
size = self.prod(t.shape) size = self.prod(t.shape)
# 如果输入空间大于当前空间,则更新空间
if size > space: if size > space:
space = size space = size
# 返回计算出的空间
return space return space
# 计算所有操作中的最大空间
space = max((_cal_default_space(op) for op in self.dom_op)) space = max((_cal_default_space(op) for op in self.dom_op))
# each sm least 4 wrap # 每个sm至少包含4个wrap
# 计算所需的block数量
block = (space + (self.WRAP_SIZE * 4) - 1) // (self.WRAP_SIZE * 4) block = (space + (self.WRAP_SIZE * 4) - 1) // (self.WRAP_SIZE * 4)
# 将block数量限制在最大block数量之内
self.block_num = min(self.MAX_BLOCK, block) self.block_num = min(self.MAX_BLOCK, block)
# 计算每个block的权重
self.block_weight = self._cal_weight(self.ops) // self.block_num self.block_weight = self._cal_weight(self.ops) // self.block_num
def analyze(self): def analyze(self):
"""analyze ops""" """analyze ops"""
def _ops_type(ops, dom_op): def _ops_type(ops, dom_op):
"""
判断操作列表中是否包含reduce操作
Args:
ops (list): 操作列表
dom_op (list): 操作列表
Returns:
bool: 如果操作列表中包含reduce操作则返回True否则返回False
"""
# 检查ops列表中是否有reduce操作
have_reduce = any( have_reduce = any(
# 如果op的类型是PrimLib.REDUCE则返回True
(PrimLib.iter_type(op) == PrimLib.REDUCE for op in ops)) (PrimLib.iter_type(op) == PrimLib.REDUCE for op in ops))
if have_reduce: if have_reduce:
# 如果有reduce操作返回True
return True return True
# 否则返回dom_op[0]的类型
return PrimLib.iter_type(dom_op[0]) return PrimLib.iter_type(dom_op[0])
# 调用_ops_type函数获取dom_op的类型
dom_type = _ops_type(self.ops, self.dom_op) dom_type = _ops_type(self.ops, self.dom_op)
# 如果dom_type是PrimLib.ELEMWISE或PrimLib.BROADCAST类型
if dom_type in (PrimLib.ELEMWISE, PrimLib.BROADCAST): if dom_type in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
# 调用injective_analyze方法
self.injective_analyze() self.injective_analyze()
# 如果dom_type是PrimLib.REDUCE类型
elif dom_type == PrimLib.REDUCE: elif dom_type == PrimLib.REDUCE:
# 调用reduce_analyze方法
self.reduce_analyze() self.reduce_analyze()
# 如果dom_type是其他类型
else: else:
# 调用default_analyze方法
self.default_analyze() self.default_analyze()
def suitable_to_pipeline(self): def suitable_to_pipeline(self):
"""judge whether is suitable to be pipeline optimized""" """judge whether is suitable to be pipeline optimized"""
# 判断是否适合进行流水线优化
# Reduce操作不适合
# Reduce is not suitable # Reduce is not suitable
def _contain_reduce(ops): def _contain_reduce(ops):
for op in ops: for op in ops:
# Reduce操作可能导致分片效果差
# Reduce may make the tiling bad. # Reduce may make the tiling bad.
if PrimLib.primtives.get(op.prim, None) == PrimLib.REDUCE: if PrimLib.primtives.get(op.prim, None) == PrimLib.REDUCE:
return True return True
@ -149,6 +320,7 @@ class ScheduleAnalyzer:
suitable = True suitable = True
if _contain_reduce(self.ops): if _contain_reduce(self.ops):
# 如果包含Reduce操作则不适合进行流水线优化
suitable = False suitable = False
return suitable return suitable
@ -166,13 +338,16 @@ class ScheduleAnalyzer:
classes (list[list[int]]): The list of clusters. Each cluster is a list of indices. classes (list[list[int]]): The list of clusters. Each cluster is a list of indices.
""" """
def _cal_mean(classes): def _cal_mean(classes):
# 计算每个聚类的均值
class_datas = list(list(data[cid] for cid in cls) for cls in classes) class_datas = list(list(data[cid] for cid in cls) for cls in classes)
return list(sum(cls) / len(cls) if cls else float('inf') for cls in class_datas) return list(sum(cls) / len(cls) if cls else float('inf') for cls in class_datas)
def _cal_distance(a, b): def _cal_distance(a, b):
# 计算两个元素之间的距离
return abs(a - b) return abs(a - b)
def _check_different(old_classes, new_classes): def _check_different(old_classes, new_classes):
# 检查新旧聚类是否不同
for o, n in zip(old_classes, new_classes): for o, n in zip(old_classes, new_classes):
if o != n: if o != n:
return True return True
@ -201,31 +376,39 @@ class ScheduleAnalyzer:
min_idx = i if min_dis > cur_dis else min_idx min_idx = i if min_dis > cur_dis else min_idx
min_dis = cur_dis if min_dis > cur_dis else min_dis min_dis = cur_dis if min_dis > cur_dis else min_dis
new_classes[min_idx].append(idx) new_classes[min_idx].append(idx)
# 检查聚类是否发生变化
changed = _check_different(classes, new_classes) changed = _check_different(classes, new_classes)
# 更新聚类
classes = new_classes classes = new_classes
return classes return classes
@staticmethod @staticmethod
def pipeline_fusion_analyze(blocks, op_sizes, exclude_id): def pipeline_fusion_analyze(blocks, op_sizes, exclude_id):
"""analyze whether the segments can be pipeline optimized""" """analyze whether the segments can be pipeline optimized"""
# op size first, block second. # op size first, block second。
# 操作大小在前,块在后
def _simple_factor(block, op_size): def _simple_factor(block, op_size):
return block + 5 * op_size return block + 5 * op_size
def _take_second(elem): def _take_second(elem):
return elem[1] return elem[1]
# 计算每个块的简单因子
simple_indicators = list(_simple_factor(b, s) simple_indicators = list(_simple_factor(b, s)
for b, s in zip(blocks, op_sizes)) for b, s in zip(blocks, op_sizes))
# 2 classes, one heavy, the other light # 2 classes, one heavy, the other light
# 两类,一类重,一类轻
classes = ScheduleAnalyzer.k_mean(simple_indicators, 2, exclude_id) classes = ScheduleAnalyzer.k_mean(simple_indicators, 2, exclude_id)
if not classes: if not classes:
return [] return []
# 计算每类的均值
means = list(sum([simple_indicators[idx] for idx in cls]) / means = list(sum([simple_indicators[idx] for idx in cls]) /
len(cls) if cls else float('inf') for cls in classes) len(cls) if cls else float('inf') for cls in classes)
# The target two clusters should be a heavy one and a light one. # The target two clusters should be a heavy one and a light one.
# 目标两类应该是一类重的和一类轻的
# The light one maybe suitable to run with pipeline optimized. # The light one maybe suitable to run with pipeline optimized.
# 轻的一类可能适合进行流水线优化
classes_infos = list([cls, m] for cls, m in zip(classes, means)) classes_infos = list([cls, m] for cls, m in zip(classes, means))
classes_infos.sort(key=_take_second) classes_infos.sort(key=_take_second)
pipeline_target = None pipeline_target = None
@ -234,6 +417,7 @@ class ScheduleAnalyzer:
pipeline_target = ci pipeline_target = ci
break break
pipeline_gids, pipeline_mean = pipeline_target pipeline_gids, pipeline_mean = pipeline_target
# 如果轻的一类的均值大于某个阈值,则返回空列表
if pipeline_mean > _simple_factor(float(ScheduleAnalyzer.MAX_SM) / len(blocks), if pipeline_mean > _simple_factor(float(ScheduleAnalyzer.MAX_SM) / len(blocks),
ScheduleAnalyzer.PIPELINE_OP_THREADHOLD): ScheduleAnalyzer.PIPELINE_OP_THREADHOLD):
return [] return []
@ -241,6 +425,7 @@ class ScheduleAnalyzer:
pipeline_blocks = [] pipeline_blocks = []
pipeline_weight = len(pipeline_gids) pipeline_weight = len(pipeline_gids)
# Try to make two paralleled at least. # Try to make two paralleled at least.
# 至少尝试两个并行
if pipeline_weight > 3 and pipeline_weight > len(blocks) / 2: if pipeline_weight > 3 and pipeline_weight > len(blocks) / 2:
if len(pipeline_gids[:pipeline_weight // 2]) > 1: if len(pipeline_gids[:pipeline_weight // 2]) > 1:
pipeline_blocks.append(pipeline_gids[:pipeline_weight // 2]) pipeline_blocks.append(pipeline_gids[:pipeline_weight // 2])
@ -252,49 +437,114 @@ class ScheduleAnalyzer:
@staticmethod @staticmethod
def fusion_consult(blocks, op_sizes, exclude_gid): def fusion_consult(blocks, op_sizes, exclude_gid):
"""
获取并行融合的建议
Args:
blocks (list): 包含多个计算块的列表
op_sizes (list): 每个操作的尺寸列表
exclude_gid (int): 需要排除的组ID
Returns:
tuple: 包含融合类型和类型信息的元组
Raises:
"""
"""get a recommendation for parallel fusion""" """get a recommendation for parallel fusion"""
# 默认是块融合
# Default is block fusion # Default is block fusion
fusion_type = "block_fusion" fusion_type = "block_fusion"
type_info = None type_info = None
# 禁用管道优化
activate_pipeline_optimization = False # Disable pipeline optimization for now. activate_pipeline_optimization = False # Disable pipeline optimization for now.
# 如果启用管道优化
if activate_pipeline_optimization: if activate_pipeline_optimization:
# 对块、操作大小和排除组ID进行管道融合分析
pipeline_info = ScheduleAnalyzer.pipeline_fusion_analyze( pipeline_info = ScheduleAnalyzer.pipeline_fusion_analyze(
blocks, op_sizes, exclude_gid) blocks, op_sizes, exclude_gid)
# 如果存在管道信息
if pipeline_info: if pipeline_info:
# 融合类型为块管道融合
fusion_type = "block_pipeline_fusion" fusion_type = "block_pipeline_fusion"
# 设置类型信息为管道信息
type_info = pipeline_info type_info = pipeline_info
return fusion_type, type_info return fusion_type, type_info
def block_parallel_estimate(graphs):
    """Estimate the gain of running the given graphs block-parallel.

    Args:
        graphs (list): graphs to analyze.

    Returns:
        ParalGain: the estimated parallel gain.
    """
    analyzers = []
    for g in graphs:
        analyzer = ScheduleAnalyzer(g)
        analyzer.analyze()
        analyzers.append(analyzer)

    blocks = [a.block_num for a in analyzers]
    op_sizes = [len(a.ops) for a in analyzers]
    weights = [a.block_weight for a in analyzers]
    max_weight = max(weights) if weights else 0
    sum_weight = sum(weights)
    exclude_gid = [gid for gid, a in enumerate(analyzers)
                   if not a.suitable_to_pipeline()]

    # Too many blocks to co-schedule: report no parallel gain.
    if sum(blocks) > ScheduleAnalyzer.MAX_SM * 32:
        return ParalGain("none", sum_weight, 0, [0] * len(graphs), None)
    fusion_type, type_info = ScheduleAnalyzer.fusion_consult(
        blocks, op_sizes, tuple(exclude_gid))
    return ParalGain(fusion_type, max_weight, sum_weight - max_weight, blocks, type_info)
def parallel_estimate(graphs, target):
    """Estimate parallel gain for the given target.

    Args:
        graphs (list): graphs to analyze.
        target (str): backend target, e.g. "aicore".

    Returns:
        ParalGain: the estimated parallel gain.
    """
    if target != "aicore":
        return block_parallel_estimate(graphs)
    # No real estimator for aicore yet: report a fixed fake gain with
    # one block per graph.
    fake_estimate = 1000
    fake_blocks = [1 for _ in graphs]
    return ParalGain("block_fusion", fake_estimate, fake_estimate, fake_blocks, None)

@ -24,39 +24,61 @@ class Utils:
@staticmethod @staticmethod
def get_attr_type(attr): def get_attr_type(attr):
"""Get attr type""" """Get attr type"""
# 判断attr是否为bool类型
if isinstance(attr, bool): if isinstance(attr, bool):
return 'bool' return 'bool'
# 判断attr是否为str类型
if isinstance(attr, str): if isinstance(attr, str):
return 'str' return 'str'
# 判断attr是否为int类型
if isinstance(attr, int): if isinstance(attr, int):
return 'int' return 'int'
# 判断attr是否为float类型
if isinstance(attr, float): if isinstance(attr, float):
return 'float' return 'float'
# 判断attr是否为list或tuple类型
if isinstance(attr, (list, tuple)): if isinstance(attr, (list, tuple)):
# 判断attr是否为空
if not attr: if not attr:
raise ValueError("attr is invalid: the length of attr is 0") raise ValueError("attr is invalid: the length of attr is 0")
# 判断attr的第一个元素是否为int类型
if isinstance(attr[0], int): if isinstance(attr[0], int):
return 'listInt' return 'listInt'
# 判断attr的第一个元素是否为str类型
if isinstance(attr[0], str): if isinstance(attr[0], str):
return 'listStr' return 'listStr'
# 如果attr的类型不在支持的列表中则抛出异常
raise ValueError("attr {} type {} is not in supported list ['bool', 'str', 'int', 'float', 'int' list, " raise ValueError("attr {} type {} is not in supported list ['bool', 'str', 'int', 'float', 'int' list, "
"'str' list]".format(attr, type(attr))) "'str' list]".format(attr, type(attr)))
class DataFormat: class DataFormat:
"""DataFormat""" """DataFormat"""
# 默认格式
DEFAULT = "DefaultFormat" DEFAULT = "DefaultFormat"
# NC1KHKWHWC0格式
NC1KHKWHWC0 = "NC1KHKWHWC0" NC1KHKWHWC0 = "NC1KHKWHWC0"
# ND格式
ND = "ND" ND = "ND"
# NCHW格式
NCHW = "NCHW" NCHW = "NCHW"
# NHWC格式
NHWC = "NHWC" NHWC = "NHWC"
# HWCN格式
HWCN = "HWCN" HWCN = "HWCN"
# NC1HWC0格式
NC1HWC0 = "NC1HWC0" NC1HWC0 = "NC1HWC0"
# FRAC_Z格式
FRAC_Z = "FracZ" FRAC_Z = "FracZ"
# FRAC_NZ格式
FRAC_NZ = "FRACTAL_NZ" FRAC_NZ = "FRACTAL_NZ"
# C1HWNCOC0格式
C1HWNCOC0 = "C1HWNCoC0" C1HWNCOC0 = "C1HWNCoC0"
# NC1HWC0_C04格式
NC1HWC0_C04 = "NC1HWC0_C04" NC1HWC0_C04 = "NC1HWC0_C04"
# FRACTAL_Z_C04格式
FRACTAL_Z_C04 = "FRACTAL_Z_C04" FRACTAL_Z_C04 = "FRACTAL_Z_C04"
# NDHWC格式
NDHWC = "NDHWC" NDHWC = "NDHWC"
def __init__(self): def __init__(self):
@ -65,29 +87,47 @@ class DataFormat:
class DataType:
    """Canonical data-type name strings used in op/graph descriptions."""
    # floating-point types
    FLOAT = "float"
    FLOAT16 = "float16"
    FLOAT32 = "float32"
    FLOAT64 = "float64"
    # signed integer types
    INT = "int"
    INT8 = "int8"
    INT16 = "int16"
    INT32 = "int32"
    INT64 = "int64"
    # unsigned integer types
    UINT = "uint"
    UINT8 = "uint8"
    UINT16 = "uint16"
    UINT32 = "uint32"
    UINT64 = "uint64"
    # boolean type
    BOOL = "bool"

    def __init__(self):
        pass
class PrimLib: class PrimLib:
"""Prim lib""" """Prim lib"""
# 定义PrimLib类中的常量
UNKNOWN = 0 UNKNOWN = 0
RESHAPE = 1 RESHAPE = 1
ELEMWISE = 2 ELEMWISE = 2
@ -102,53 +142,73 @@ class PrimLib:
"""Prim""" """Prim"""
def __init__(self, iter_type, calibrate=1, relation_func=None): def __init__(self, iter_type, calibrate=1, relation_func=None):
# 初始化Prim类设置iter_type、calibrate和relation_func属性
self.iter_type = iter_type self.iter_type = iter_type
self.calibrate = calibrate self.calibrate = calibrate
self.relation_func = relation_func self.relation_func = relation_func
if relation_func is None: if relation_func is None:
# 如果relation_func为None则设置默认的relation_func
self.relation_func = lambda *x: self.default_relation_func[iter_type](self, *x) self.relation_func = lambda *x: self.default_relation_func[iter_type](self, *x)
def default_reshape_relation(self, op, input_idx): def default_reshape_relation(self, op, input_idx):
"""Process reshape relation""" """Process reshape relation"""
# 处理reshape关系
axis_relation, elem_relation = self.unknown_relation(op, input_idx) axis_relation, elem_relation = self.unknown_relation(op, input_idx)
# 将elem_relation设置为PrimLib.RESHAPE
elem_relation = [PrimLib.RESHAPE] * len(elem_relation) elem_relation = [PrimLib.RESHAPE] * len(elem_relation)
return axis_relation, elem_relation return axis_relation, elem_relation
def default_elemwise_broadcast_relation(self, op, input_idx): def default_elemwise_broadcast_relation(self, op, input_idx):
"""Process elemwise and broadcast relation""" """Process elemwise and broadcast relation"""
# 处理elemwise和broadcast关系
out_shape = op.output.shape out_shape = op.output.shape
in_shape = op.inputs[input_idx].shape in_shape = op.inputs[input_idx].shape
# 如果输出形状的长度小于输入形状的长度,则抛出异常
if len(out_shape) < len(in_shape): if len(out_shape) < len(in_shape):
raise ValueError("For '{}', the input/output size is abnormal, as the length of output shape{} is less " raise ValueError("For '{}', the input/output size is abnormal, as the length of output shape{} is less "
"than the length of input shape{}".format(op.prim, out_shape, in_shape)) "than the length of input shape{}".format(op.prim, out_shape, in_shape))
axis_relation, elem_relation = [], [] axis_relation, elem_relation = [], []
# 计算输出形状和输入形状的长度差
delta = len(out_shape) - len(in_shape) delta = len(out_shape) - len(in_shape)
if delta > 0: if delta > 0:
# 如果输出形状的长度大于输入形状的长度则在axis_relation和elem_relation中添加None
for i in range(0, delta): for i in range(0, delta):
axis_relation.append(None) axis_relation.append(None)
elem_relation.append(None) elem_relation.append(None)
# 遍历输入形状的每个元素
for i, _ in enumerate(in_shape): for i, _ in enumerate(in_shape):
# 在axis_relation中添加当前元素的索引
axis_relation.append(i) axis_relation.append(i)
# 如果输出形状的对应元素等于输入形状的对应元素则elem_relation添加PrimLib.ELEMWISE否则添加PrimLib.BROADCAST
elem_relation.append( elem_relation.append(
PrimLib.ELEMWISE if out_shape[i + delta] == in_shape[i] else PrimLib.BROADCAST) PrimLib.ELEMWISE if out_shape[i + delta] == in_shape[i] else PrimLib.BROADCAST)
return axis_relation, elem_relation return axis_relation, elem_relation
def default_reduce_relation(self, op, input_idx): def default_reduce_relation(self, op, input_idx):
"""Process reduce relation""" """Process reduce relation"""
# 处理reduce关系
axis_relation, elem_relation = self.default_elemwise_broadcast_relation(op, input_idx) axis_relation, elem_relation = self.default_elemwise_broadcast_relation(op, input_idx)
# 遍历reduce_axis中的每个元素
for i in op.attrs['reduce_axis']: for i in op.attrs['reduce_axis']:
# 将elem_relation中对应元素的值设置为PrimLib.REDUCE
elem_relation[i] = PrimLib.REDUCE elem_relation[i] = PrimLib.REDUCE
return axis_relation, elem_relation return axis_relation, elem_relation
def unknown_relation(self, op, input_idx): def unknown_relation(self, op, input_idx):
"""Process unknown relation""" """Process unknown relation"""
# 获取输出和输入的形状
out_shape = op.output.shape out_shape = op.output.shape
in_shape = op.inputs[input_idx].shape in_shape = op.inputs[input_idx].shape
# 获取所有可能的轴关系
all_relation = list(range(len(in_shape))) all_relation = list(range(len(in_shape)))
# 初始化轴关系列表
axis_relation = [all_relation for i in range(0, len(out_shape))] axis_relation = [all_relation for i in range(0, len(out_shape))]
# 初始化元素关系列表
elem_relation = [PrimLib.UNKNOWN for i in range(0, len(out_shape))] elem_relation = [PrimLib.UNKNOWN for i in range(0, len(out_shape))]
# 返回轴关系和元素关系
return axis_relation, elem_relation return axis_relation, elem_relation
# 默认的关系函数列表
default_relation_func = [ default_relation_func = [
unknown_relation, unknown_relation,
default_reshape_relation, default_reshape_relation,
@ -158,6 +218,7 @@ class PrimLib:
unknown_relation, unknown_relation,
] ]
# 定义基本操作
primtives = { primtives = {
'Add': Prim(ELEMWISE), 'Add': Prim(ELEMWISE),
'Abs': Prim(ELEMWISE), 'Abs': Prim(ELEMWISE),
@ -239,7 +300,9 @@ class PrimLib:
@classmethod @classmethod
def get_prim(cls, op): def get_prim(cls, op):
"""Get op primtive""" """Get op primtive"""
# 从cls.primtives中获取op.prim对应的prim
prim = cls.primtives.get(op.prim, None) prim = cls.primtives.get(op.prim, None)
# 如果prim为None则打印警告信息并返回cls.default_primtive
if prim is None: if prim is None:
print('[WARN] primtive is not registered: ' + op.prim) print('[WARN] primtive is not registered: ' + op.prim)
prim = cls.default_primtive prim = cls.default_primtive
@ -248,50 +311,65 @@ class PrimLib:
@classmethod @classmethod
def input_relation(cls, op, input_idx): def input_relation(cls, op, input_idx):
"""Get op's input_relation according to input_idx""" """Get op's input_relation according to input_idx"""
# 调用cls.get_prim(op)获取op对应的prim然后调用prim的relation_func方法获取op的input_relation
return cls.get_prim(op).relation_func(op, input_idx) return cls.get_prim(op).relation_func(op, input_idx)
@classmethod @classmethod
def iter_type(cls, op): def iter_type(cls, op):
"""Get op's iter type""" """Get op's iter type"""
# 调用cls.get_prim(op)获取op对应的prim然后返回prim的iter_type
return cls.get_prim(op).iter_type return cls.get_prim(op).iter_type
@classmethod @classmethod
def is_reduce(cls, op): def is_reduce(cls, op):
"""Check whether op's iter type is reduce""" """Check whether op's iter type is reduce"""
# 调用cls.get_prim(op)获取op对应的prim然后判断prim的iter_type是否为cls.REDUCE
return cls.get_prim(op).iter_type == cls.REDUCE return cls.get_prim(op).iter_type == cls.REDUCE
@classmethod @classmethod
def calibrate_iter_size(cls, op, iter_size): def calibrate_iter_size(cls, op, iter_size):
"""Get calibrate_iter_size""" """Get calibrate_iter_size"""
# 调用cls.get_prim(op)获取op对应的prim然后返回prim的calibrate乘以iter_size
return cls.get_prim(op).calibrate * iter_size return cls.get_prim(op).calibrate * iter_size
@classmethod @classmethod
def dtype_bytes(cls, dtype): def dtype_bytes(cls, dtype):
"""Get dtype bytes""" """Get dtype bytes"""
# 初始化bits和unit为1
bits, unit = 1, 1 bits, unit = 1, 1
# 从dtype的最后一个字符开始向前遍历
for i in range(len(dtype) - 1, 0, -1): for i in range(len(dtype) - 1, 0, -1):
# 如果当前字符是数字则将bits加上当前字符对应的数字乘以unit并将unit乘以10
if dtype[i].isdecimal(): if dtype[i].isdecimal():
bits += int(dtype[i]) * unit bits += int(dtype[i]) * unit
unit *= 10 unit *= 10
# 如果当前字符不是数字,则跳出循环
else: else:
break break
# 返回bits除以8的结果
return bits // 8 return bits // 8
@classmethod @classmethod
def inplace_reuse(cls, op, input_idx, start_axis=0): def inplace_reuse(cls, op, input_idx, start_axis=0):
"""Check whether op is inplace reuse""" """Check whether op is inplace reuse"""
# 如果op.output.dtype的字节数大于op.inputs[input_idx].dtype的字节数则返回False
if cls.dtype_bytes(op.output.dtype) > cls.dtype_bytes(op.inputs[input_idx].dtype): if cls.dtype_bytes(op.output.dtype) > cls.dtype_bytes(op.inputs[input_idx].dtype):
return False return False
# 调用cls.get_prim(op)获取op对应的prim然后调用prim的relation_func方法获取op的input_relation
_, elem_relation = cls.get_prim(op).relation_func(op, input_idx) _, elem_relation = cls.get_prim(op).relation_func(op, input_idx)
# 从start_axis开始遍历elem_relation
for i in range(start_axis, len(elem_relation)): for i in range(start_axis, len(elem_relation)):
# 如果elem_relation中的元素不等于cls.ELEMWISE则返回False
if elem_relation[i] != cls.ELEMWISE: if elem_relation[i] != cls.ELEMWISE:
return False return False
# 如果以上条件都不满足则返回True
return True return True
class Tensor: class Tensor:
"""Tensor""" """Tensor"""
# 参数类型常量
PARA_NONE = 0 PARA_NONE = 0
PARA_INPUT = 1 PARA_INPUT = 1
PARA_OUTPUT = 2 PARA_OUTPUT = 2
@ -303,6 +381,7 @@ class Tensor:
self.members = [leader] self.members = [leader]
def __init__(self, name, shape, dtype, data_format=DataFormat.DEFAULT, para_type=0): def __init__(self, name, shape, dtype, data_format=DataFormat.DEFAULT, para_type=0):
# 初始化Tensor对象
self.name = name self.name = name
self.shape = shape self.shape = shape
self.dtype = dtype self.dtype = dtype
@ -313,13 +392,16 @@ class Tensor:
self.buddy = None self.buddy = None
def __str__(self): def __str__(self):
# 返回Tensor对象的字符串表示
return self.name + str(list(self.shape)) return self.name + str(list(self.shape))
def __repr__(self): def __repr__(self):
# 返回Tensor对象的字符串表示
return "%s.%s%s" % (self.name, self.dtype, str(list(self.shape))) return "%s.%s%s" % (self.name, self.dtype, str(list(self.shape)))
def get_size(self): def get_size(self):
"""Get size""" """Get size"""
# 获取Tensor对象的大小
size = PrimLib.dtype_bytes(self.dtype) size = PrimLib.dtype_bytes(self.dtype)
for i in self.shape: for i in self.shape:
size *= i size *= i
@ -327,6 +409,7 @@ class Tensor:
def add_buddy(self, tensor): def add_buddy(self, tensor):
"""Add buddy""" """Add buddy"""
# 添加buddy
if self.buddy is None: if self.buddy is None:
self.buddy = self.Buddy(self) self.buddy = self.Buddy(self)
self.buddy.members.append(tensor) self.buddy.members.append(tensor)
@ -337,6 +420,7 @@ class Value:
"""Value""" """Value"""
def __init__(self, name, dtype, value, data_format=DataFormat.DEFAULT): def __init__(self, name, dtype, value, data_format=DataFormat.DEFAULT):
# 初始化Value对象
self.name = name self.name = name
self.shape = [1] self.shape = [1]
self.dtype = dtype self.dtype = dtype
@ -344,14 +428,17 @@ class Value:
self.data_format = data_format self.data_format = data_format
def __str__(self): def __str__(self):
# 返回Value对象的字符串表示
return self.name + str(list(self.shape)) return self.name + str(list(self.shape))
def __repr__(self): def __repr__(self):
# 返回Value对象的字符串表示
return "%s.%s%s" % (self.name, self.dtype, str(list(self.shape))) return "%s.%s%s" % (self.name, self.dtype, str(list(self.shape)))
@staticmethod @staticmethod
def get_size(): def get_size():
"""Get size""" """Get size"""
# 获取Value对象的大小
return 1 return 1
@ -359,23 +446,31 @@ class Operator:
"""Operator""" """Operator"""
def __init__(self, primtive, inputs, output, attrs): def __init__(self, primtive, inputs, output, attrs):
# 初始化Operator对象
self.prim = primtive self.prim = primtive
self.inputs = inputs self.inputs = inputs
self.output = output self.output = output
self.attrs = attrs self.attrs = attrs
# 将当前Operator对象添加到每个输入的to_ops列表中
for t in inputs: for t in inputs:
t.to_ops.append(self) t.to_ops.append(self)
# 如果输出的op属性为None则将当前Operator对象赋值给输出的op属性
if output.op is None: if output.op is None:
output.op = self output.op = self
# 初始化all_inputs列表用于存储Tensor输入和Value输入
self.all_inputs = [] # include Tensor inputs and Value inputs. self.all_inputs = [] # include Tensor inputs and Value inputs.
def __str__(self): def __str__(self):
# 将self.all_inputs中的元素转换为字符串并用逗号连接起来
args = ', '.join((str(t) for t in self.all_inputs)) args = ', '.join((str(t) for t in self.all_inputs))
# 构造表达式字符串
expr = "%s = %s.%s(%s) id:%s" % ( expr = "%s = %s.%s(%s) id:%s" % (
str(self.output), self.prim, self.output.dtype, args, id(self)) str(self.output), self.prim, self.output.dtype, args, id(self))
# 如果self.attrs不为空则返回表达式字符串和self.attrs的字符串连接
return expr if not self.attrs else '%s // %s' % (expr, str(self.attrs)) return expr if not self.attrs else '%s // %s' % (expr, str(self.attrs))
def __repr__(self): def __repr__(self):
# 返回self的字符串表示
return str(self) return str(self)
@ -383,6 +478,7 @@ class Graph:
"""Graph""" """Graph"""
def __init__(self, name, ops, stitch_info=None, recompute_ops=None): def __init__(self, name, ops, stitch_info=None, recompute_ops=None):
# 初始化Graph对象
self.name = name self.name = name
self.ops = ops # in topo order, can not use set self.ops = ops # in topo order, can not use set
self.inputs = [] self.inputs = []
@ -393,10 +489,12 @@ class Graph:
def set_processor(self, processor): def set_processor(self, processor):
"""Set processor""" """Set processor"""
# 设置处理器
self.processor = processor self.processor = processor
def add(self, ops): def add(self, ops):
"""Add ops""" """Add ops"""
# 添加操作
if isinstance(ops, Operator): if isinstance(ops, Operator):
self.ops.append(ops) self.ops.append(ops)
else: else:
@ -404,101 +502,148 @@ class Graph:
def extract_subgraph(self, graph_name, tensor_names, difference=False): def extract_subgraph(self, graph_name, tensor_names, difference=False):
"""Extract subgraph from this graph""" """Extract subgraph from this graph"""
# 从当前图中提取子图
graph = Graph(graph_name, []) graph = Graph(graph_name, [])
outputs = set(tensor_names) outputs = set(tensor_names)
if difference: if difference:
# 如果difference为True则提取不在outputs中的操作
for op in self.ops: for op in self.ops:
if op.output.name not in outputs: if op.output.name not in outputs:
graph.add(op) graph.add(op)
else: else:
# 如果difference为False则提取在outputs中的操作
for op in self.ops: for op in self.ops:
if op.output.name in outputs: if op.output.name in outputs:
graph.add(op) graph.add(op)
outputs.remove(op.output.name) outputs.remove(op.output.name)
# 如果outputs中还有元素则抛出异常
for name in outputs: for name in outputs:
raise ValueError("Invalid input tensor : {}, can not find it in graph".format(name)) raise ValueError("Invalid input tensor : {}, can not find it in graph".format(name))
return graph return graph
def deduce_parameters(self): def deduce_parameters(self):
"""Deduce parameters""" """Deduce parameters"""
# 初始化输入和输出列表
inputs, outputs = [], [] inputs, outputs = [], []
# 遍历所有操作
for op in self.ops: for op in self.ops:
# 遍历操作的所有输入
for t in op.inputs: for t in op.inputs:
# 如果输入不在输入列表中,且输入的操作不在操作列表中,则将输入添加到输入列表中
if t not in inputs and t.op not in self.ops: if t not in inputs and t.op not in self.ops:
inputs.append(t) inputs.append(t)
# 如果操作输出已经在输出列表中,则跳过
if op.output in outputs: if op.output in outputs:
continue continue
# 如果操作输出是输出参数类型,或者操作输出没有后续操作,则将操作输出添加到输出列表中
if op.output.para_type == Tensor.PARA_OUTPUT or not op.output.to_ops: if op.output.para_type == Tensor.PARA_OUTPUT or not op.output.to_ops:
outputs.append(op.output) outputs.append(op.output)
continue continue
# 如果操作输出的后续操作不在操作列表中,则将操作输出添加到输出列表中
if any((succ not in self.ops for succ in op.output.to_ops)): if any((succ not in self.ops for succ in op.output.to_ops)):
outputs.append(op.output) outputs.append(op.output)
# 如果有指定的输入,则将指定的输入赋值给输入列表
if self.inputs: if self.inputs:
inputs = self.inputs inputs = self.inputs
# 如果有指定的输出,则将指定的输出赋值给输出列表
if self.outputs: if self.outputs:
outputs = self.outputs outputs = self.outputs
# 返回输入和输出列表
return inputs, outputs return inputs, outputs
def __str__(self): def __str__(self):
# 调用deduce_parameters方法获取输入和输出列表
inputs, outputs = self.deduce_parameters() inputs, outputs = self.deduce_parameters()
# 将输入列表转换为字符串
para_str = ', '.join((repr(t) for t in inputs)) para_str = ', '.join((repr(t) for t in inputs))
# 将输出列表转换为字符串
out_str = ', '.join((repr(t) for t in outputs)) out_str = ', '.join((repr(t) for t in outputs))
# 初始化行列表
lines = [] lines = []
# 添加操作名称、输入和输出到行列表中
lines.append("%s(%s) -> %s {" % (self.name, para_str, out_str)) lines.append("%s(%s) -> %s {" % (self.name, para_str, out_str))
# 如果有拼接信息,则添加拼接操作和拼接原子操作到行列表中
if self.stitch_info: if self.stitch_info:
if self.stitch_info.stitch_ops: if self.stitch_info.stitch_ops:
lines.append(' stitch -> ' + str(self.stitch_info.stitch_ops)) lines.append(' stitch -> ' + str(self.stitch_info.stitch_ops))
if self.stitch_info.stitch_atomic_ops: if self.stitch_info.stitch_atomic_ops:
lines.append(' stitch_atomic_ops-> ' + str(self.stitch_info.stitch_atomic_ops)) lines.append(' stitch_atomic_ops-> ' + str(self.stitch_info.stitch_atomic_ops))
# 遍历所有操作,将操作添加到行列表中
for op in self.ops: for op in self.ops:
lines.append(' ' + str(op)) lines.append(' ' + str(op))
# 添加结束符号到行列表中
lines.append('}') lines.append('}')
# 将行列表转换为字符串并返回
return '\n'.join(lines) return '\n'.join(lines)
def __repr__(self): def __repr__(self):
# 返回对象的字符串表示
return str(self) return str(self)
def dump(self): def dump(self):
"""Dump Graph to json""" """Dump Graph to json"""
# 将Graph转换为json格式
attr_name = {'reduce_axis': 'axis'} attr_name = {'reduce_axis': 'axis'}
# 获取Graph的输入和输出参数
inputs, outputs = self.deduce_parameters() inputs, outputs = self.deduce_parameters()
input_desc, output_desc, op_desc = [], [], [] input_desc, output_desc, op_desc = [], [], []
# 遍历输入参数
for t in inputs: for t in inputs:
# 将输入参数转换为字典格式
input_desc.append([{'data_type': t.dtype, 'shape': t.shape, input_desc.append([{'data_type': t.dtype, 'shape': t.shape,
'tensor_name': t.name, 'format': t.data_format}]) 'tensor_name': t.name, 'format': t.data_format}])
# 遍历输出参数
for t in outputs: for t in outputs:
# 将输出参数转换为字典格式
output_desc.append({'data_type': t.dtype, 'shape': t.shape, output_desc.append({'data_type': t.dtype, 'shape': t.shape,
'tensor_name': t.name, 'format': t.data_format}) 'tensor_name': t.name, 'format': t.data_format})
# 遍历Graph中的操作
for op in self.ops: for op in self.ops:
attrs, in_desc = [], [] attrs, in_desc = [], []
# 遍历操作中的属性
for a in op.attrs: for a in op.attrs:
# 获取属性名
name = attr_name.get(a, a) name = attr_name.get(a, a)
# 将属性转换为字典格式
attrs.append( attrs.append(
{'name': name, 'value': op.attrs[a], 'data_type': Utils.get_attr_type(op.attrs[a])}) {'name': name, 'value': op.attrs[a], 'data_type': Utils.get_attr_type(op.attrs[a])})
# 遍历操作中的输入
for t in op.all_inputs: for t in op.all_inputs:
# 如果输入是Tensor类型
if isinstance(t, Tensor): if isinstance(t, Tensor):
# 将输入转换为字典格式
in_desc.append([{'data_type': t.dtype, 'name': '', 'shape': t.shape, in_desc.append([{'data_type': t.dtype, 'name': '', 'shape': t.shape,
'tensor_name': t.name, 'format': t.data_format}]) 'tensor_name': t.name, 'format': t.data_format}])
else: else:
# 将输入转换为字典格式
in_desc.append([{'data_type': t.dtype, 'value': t.value, 'name': '', 'shape': t.shape, in_desc.append([{'data_type': t.dtype, 'value': t.value, 'name': '', 'shape': t.shape,
'tensor_name': t.name, 'format': t.data_format}]) 'tensor_name': t.name, 'format': t.data_format}])
# 将操作输出转换为字典格式
out_desc = [{'data_type': op.output.dtype, 'name': '', 'shape': op.output.shape, out_desc = [{'data_type': op.output.dtype, 'name': '', 'shape': op.output.shape,
'tensor_name': op.output.name, 'format': op.output.data_format}] 'tensor_name': op.output.name, 'format': op.output.data_format}]
# 将操作转换为字典格式
op_desc.append({'attr': attrs, 'impl_path': '', op_desc.append({'attr': attrs, 'impl_path': '',
'input_desc': in_desc, 'name': op.prim, 'output_desc': out_desc}) 'input_desc': in_desc, 'name': op.prim, 'output_desc': out_desc})
# 将Graph转换为字典格式
graph_desc = {'composite': True, 'composite_graph': '', 'id': 0, graph_desc = {'composite': True, 'composite_graph': '', 'id': 0,
'input_desc': input_desc, 'op': self.name, 'op_desc': op_desc, 'output_desc': output_desc, 'input_desc': input_desc, 'op': self.name, 'op_desc': op_desc, 'output_desc': output_desc,
'platform': 'AKG', 'process': self.processor} 'platform': 'AKG', 'process': self.processor}
# 如果Graph中有stitch信息
if self.stitch_info and self.stitch_info.stitch_ops: if self.stitch_info and self.stitch_info.stitch_ops:
# 将stitch信息转换为字典格式
buffer_stitch = {'stitch_op': list(self.stitch_info.stitch_ops)} buffer_stitch = {'stitch_op': list(self.stitch_info.stitch_ops)}
# 如果有stitch_atomic_ops
if self.stitch_info.stitch_atomic_ops: if self.stitch_info.stitch_atomic_ops:
# 将stitch_atomic_ops转换为字典格式
buffer_stitch['stitch_atomic_op'] = list(self.stitch_info.stitch_atomic_ops) buffer_stitch['stitch_atomic_op'] = list(self.stitch_info.stitch_atomic_ops)
# 将stitch信息添加到Graph字典中
graph_desc['buffer_stitch'] = buffer_stitch graph_desc['buffer_stitch'] = buffer_stitch
# 返回Graph字典
return graph_desc return graph_desc
class GraphVisitor:
    """Base class for graph visitors.

    Subclasses implement ``visit(op)``; ``visit_graph`` drives the traversal.
    """

    def __init__(self, forward=True):
        # Traversal direction: True walks ops in graph order, False reversed.
        self.forward = forward

    def visit_graph(self, graph):
        """Visit every op of `graph`, honoring the traversal direction."""
        ops = graph.ops if self.forward else reversed(graph.ops)
        for node in ops:
            self.visit(node)
@ -528,12 +676,18 @@ class AlignShape(GraphVisitor):
def visit(op):
    """Pad leading 1s onto the op's output shape so its rank matches its inputs.

    Only ELEMWISE/BROADCAST/REDUCE ops take part in shape alignment.
    """
    prim = PrimLib.get_prim(op)
    if prim.iter_type not in (PrimLib.ELEMWISE, PrimLib.BROADCAST, PrimLib.REDUCE):
        return
    out_rank = len(op.output.shape)
    # Target rank is the largest rank among output and all inputs.
    target_rank = max([out_rank] + [len(t.shape) for t in op.inputs])
    if target_rank > out_rank:
        # NOTE(review): assumes shape is a plain list (list concatenation).
        op.output.shape = [1] * (target_rank - out_rank) + op.output.shape

@ -25,90 +25,140 @@ class GraphBuilder:
"""Graph wrapper""" """Graph wrapper"""
def __init__(self, name):
    """Wrap a new, empty Graph named `name`."""
    # `Graph` is a project type; the wrapper starts with no ops.
    self.graph = Graph(name, [])
def set_input(self, *para):
    """Mark each given tensor as a graph input and register it."""
    for tensor in para:
        tensor.para_type = Tensor.PARA_INPUT
        self.graph.inputs.append(tensor)
def set_output(self, *para):
    """Mark each given tensor as a graph output and register it."""
    for tensor in para:
        tensor.para_type = Tensor.PARA_OUTPUT
        self.graph.outputs.append(tensor)
def __init__(self):
    # Finished graphs, the graph currently under construction, and a
    # monotonically increasing id used to allocate tensor names.
    self.graphs = []
    self.current = None
    self.name_id = 0
def _alloc_tensor_name(self):
    """Return a fresh tensor name of the form ``t<N>``."""
    tid, self.name_id = self.name_id, self.name_id + 1
    return "t%d" % (tid)
def graph_scope(self, name):
    """Return a context manager that builds a graph named `name`."""

    class GraphScope:
        """On enter, yields the current wrapper; on exit, archives its graph."""

        def __init__(self, gb):
            self.gb = gb

        def __enter__(self):
            return self.gb.current

        def __exit__(self, ptype, value, trace):
            # Move the finished graph into the result list and reset state.
            self.gb.graphs.append(self.gb.current.graph)
            self.gb.current = None

    # Nested scopes are not supported: a scope must not already be open.
    if self.current is not None:
        raise ValueError("self.current is not None!")
    self.current = self.GraphWrapper(name)
    return GraphScope(self)
def tensor(self, shape, dtype, data_format="DefaultFormat", name=None, para_type=Tensor.PARA_NONE):
    """Create a new Tensor, allocating a name when none is given."""
    if name in (None, ''):
        name = self._alloc_tensor_name()
    # Scalars are modeled with shape [1].
    return Tensor(name, shape or [1], dtype, data_format, para_type=para_type)
def value(self, dtype, value, name=None):
    """Create a new Value, allocating a name when none is given."""
    if name in (None, ''):
        name = self._alloc_tensor_name()
    return Value(name, dtype, value)
def op(self, prim, output, inputs, attrs=None):
    """Insert an operator node into the current graph."""
    if attrs is None:
        attrs = {}
    if isinstance(inputs, Tensor):
        inputs = [inputs]
    # Only true Tensors become dataflow edges; the full argument list
    # (e.g. including scalar values) is preserved on `all_inputs`.
    tensor_inputs = [x for x in inputs if isinstance(x, Tensor)]
    node = Operator(prim, tensor_inputs, output, attrs)
    node.all_inputs = inputs
    self.current.graph.add(node)
def emit(self, prim, inputs, name=None, attrs=None):
    """Create an operation, infer its output tensor, and add it to the graph."""
    if attrs is None:
        attrs = {}
    if isinstance(inputs, (Tensor, Value)):
        inputs = [inputs]
    tensor_inputs = [x for x in inputs if isinstance(x, (Tensor, Value))]
    # The result's shape/dtype/format are deduced from the inputs and attrs.
    out_shape, out_dtype, out_format = op_infer.infer(prim, tensor_inputs, attrs)
    output = self.tensor(out_shape, out_dtype, out_format, name)
    self.op(prim, output, inputs, attrs)
    return output
def get(self):
    """Return all graphs built so far."""
    return self.graphs
@ -116,16 +166,21 @@ class CompositeGraph:
"""Composite Graph""" """Composite Graph"""
def __init__(self):
    # Parsed Graph object, its original json description, and a
    # name -> Tensor lookup populated during load().
    self.graph = None
    self.desc = None
    self.tensors = {}
def refine(self):
    """Refine the graph by aligning op output ranks."""
    aligner = AlignShape()
    aligner.visit_graph(self.graph)
def load(self, desc): def load(self, desc):
"""Load Graph from json""" """Load Graph from json"""
# 定义一个内部函数,用于处理操作属性
def _attr_of(op): def _attr_of(op):
if not op['attr']: if not op['attr']:
return dict() return dict()
@ -134,21 +189,29 @@ class CompositeGraph:
if a['name'] == 'axis' and op['name'] in ('ReduceSum', 'ReduceMax', 'ReduceMin', 'Argmax', 'Argmin'): if a['name'] == 'axis' and op['name'] in ('ReduceSum', 'ReduceMax', 'ReduceMin', 'Argmax', 'Argmin'):
attr['reduce_axis'] = a['value'] attr['reduce_axis'] = a['value']
else: else:
# 将属性添加到字典中
attr[a['name']] = a['value'] attr[a['name']] = a['value']
return attr return attr
# 创建GraphBuilder对象
builder = GraphBuilder() builder = GraphBuilder()
# 在描述的操作范围内构建图
with builder.graph_scope(desc['op']): with builder.graph_scope(desc['op']):
# 遍历输入描述并构建输入张量
for in_desc in desc['input_desc'] if desc['input_desc'] is not None else []: for in_desc in desc['input_desc'] if desc['input_desc'] is not None else []:
name, shape, dtype, data_format = in_desc[0]['tensor_name'], in_desc[ name, shape, dtype, data_format = in_desc[0]['tensor_name'], in_desc[
0]['shape'], in_desc[0]['data_type'], in_desc[0]['format'] 0]['shape'], in_desc[0]['data_type'], in_desc[0]['format']
# 将输入张量添加到tensors字典中
self.tensors[name] = builder.tensor( self.tensors[name] = builder.tensor(
shape, dtype, data_format, name=name, para_type=Tensor.PARA_INPUT) shape, dtype, data_format, name=name, para_type=Tensor.PARA_INPUT)
# 遍历输出描述并构建输出张量
for out_desc in desc['output_desc']: for out_desc in desc['output_desc']:
name, shape, dtype, data_format = out_desc['tensor_name'], out_desc[ name, shape, dtype, data_format = out_desc['tensor_name'], out_desc[
'shape'], out_desc['data_type'], out_desc['format'] 'shape'], out_desc['data_type'], out_desc['format']
# 将输出张量添加到tensors字典中
self.tensors[name] = builder.tensor( self.tensors[name] = builder.tensor(
shape, dtype, data_format, name=name, para_type=Tensor.PARA_OUTPUT) shape, dtype, data_format, name=name, para_type=Tensor.PARA_OUTPUT)
# 遍历操作描述并构建操作
for op in desc['op_desc']: for op in desc['op_desc']:
inputs = [self.tensors[d['tensor_name']] for x in op['input_desc'] for d in x if 'value' not in d] inputs = [self.tensors[d['tensor_name']] for x in op['input_desc'] for d in x if 'value' not in d]
out_desc = op['output_desc'] out_desc = op['output_desc']
@ -164,52 +227,77 @@ class CompositeGraph:
if not output: if not output:
output = builder.tensor(shape, dtype, data_format, name=name) output = builder.tensor(shape, dtype, data_format, name=name)
self.tensors[name] = output self.tensors[name] = output
# 构建操作并添加到图中
builder.op(op['name'], output, inputs, attrs=_attr_of(op)) builder.op(op['name'], output, inputs, attrs=_attr_of(op))
# 获取构建好的图
self.graph = builder.get()[0] self.graph = builder.get()[0]
self.desc = desc self.desc = desc
def add_stitch_info(self, subgraph, desc):
    """Copy buffer-stitch info from `subgraph` into the json `desc`."""
    info = subgraph.stitch_info
    if info and info.stitch_ops:
        buffer_stitch = {'stitch_op': list(info.stitch_ops)}
        if info.stitch_atomic_ops:
            buffer_stitch['stitch_atomic_op'] = list(info.stitch_atomic_ops)
        desc['buffer_stitch'] = buffer_stitch
    return desc
def add_recompute_ops(self, subgraph, desc):
    """Record the output names of the subgraph's recompute ops in `desc`."""
    if subgraph.recompute_ops:
        desc['recompute_ops'] = [node.output.name for node in subgraph.recompute_ops]
    return desc
def _pre_dump(self, outputs):
    """Collect InplaceAssign renaming info before dumping.

    Returns a dict mapping the 'y' input tensor name of each InplaceAssign
    op to that op's output name, plus the (last) output tensor that is not
    itself an InplaceAssign target, or None if there is none.
    """
    inplace_assign = {}  # y_name -> output_name
    inplace_assign_z = None
    for op in self.desc['op_desc']:
        if op['name'] == 'InplaceAssign':
            y_name = op['input_desc'][1][0]['tensor_name']
            inplace_assign[y_name] = op['output_desc'][0]['tensor_name']
    if inplace_assign:
        for t in outputs:
            if t.name not in inplace_assign:
                inplace_assign_z = t
    return inplace_assign, inplace_assign_z
def dump(self, subgraph): def dump(self, subgraph):
"""Dump Graph to json""" """Dump Graph to json"""
desc = {} desc = {}
# 获取输入和输出参数
inputs, outputs = subgraph.deduce_parameters() inputs, outputs = subgraph.deduce_parameters()
# 获取图中的所有操作
graph_ops = set(subgraph.ops) graph_ops = set(subgraph.ops)
# 预处理输出参数
inplace_assign, inplace_assign_z = self._pre_dump(outputs) inplace_assign, inplace_assign_z = self._pre_dump(outputs)
def dump_output(t): def dump_output(t):
# 如果输出参数是原地赋值操作的结果
if t.name in inplace_assign: if t.name in inplace_assign:
z = inplace_assign_z if inplace_assign_z is not None else self.tensors[t.name] z = inplace_assign_z if inplace_assign_z is not None else self.tensors[t.name]
# 返回包含数据类型、形状和张量名称的字典
return {'data_type': z.dtype, 'shape': z.shape, 'tensor_name': inplace_assign.get(t.name)} return {'data_type': z.dtype, 'shape': z.shape, 'tensor_name': inplace_assign.get(t.name)}
# 返回包含数据类型、形状和张量名称的字典
return {'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name} return {'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name}
def dump_op_desc(d): def dump_op_desc(d):
# 如果操作是原地赋值操作
if d['name'] == 'InplaceAssign': if d['name'] == 'InplaceAssign':
y = d['input_desc'][1][0]['tensor_name'] y = d['input_desc'][1][0]['tensor_name']
if self.tensors[y].op in graph_ops: if self.tensors[y].op in graph_ops:
@ -222,33 +310,50 @@ class CompositeGraph:
z_desc['tensor_name'] = z.name z_desc['tensor_name'] = z.name
out_desc['shape'] = z.shape out_desc['shape'] = z.shape
out_desc['data_type'] = z.dtype out_desc['data_type'] = z.dtype
# 返回处理后的原地赋值操作描述
return inplace_desc return inplace_desc
# 获取操作对应的张量
op = self.tensors[d['output_desc'][0]['tensor_name']].op op = self.tensors[d['output_desc'][0]['tensor_name']].op
# 如果操作在图操作集或重新计算操作集中
if op in graph_ops or op in subgraph.recompute_ops: if op in graph_ops or op in subgraph.recompute_ops:
# 返回操作描述
return d return d
# 返回None
return None return None
for key in self.desc.keys(): for key in self.desc.keys():
if key == 'input_desc': if key == 'input_desc':
# 处理输入描述
desc[key] = [[{'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name}] for t in inputs] desc[key] = [[{'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name}] for t in inputs]
elif key == 'output_desc': elif key == 'output_desc':
# 处理输出描述
desc[key] = list(map(dump_output, outputs)) desc[key] = list(map(dump_output, outputs))
elif key == 'op_desc': elif key == 'op_desc':
# 处理操作描述
op_desc = map(dump_op_desc, self.desc[key]) op_desc = map(dump_op_desc, self.desc[key])
desc[key] = [d for d in op_desc if d is not None] desc[key] = [d for d in op_desc if d is not None]
elif key == 'op': elif key == 'op':
# 处理操作名称
desc[key] = subgraph.name desc[key] = subgraph.name
else: else:
# 处理其他描述
desc[key] = self.desc[key] desc[key] = self.desc[key]
# 添加缝合信息
desc = self.add_stitch_info(subgraph, desc) desc = self.add_stitch_info(subgraph, desc)
# 添加重新计算操作信息
desc = self.add_recompute_ops(subgraph, desc) desc = self.add_recompute_ops(subgraph, desc)
# 返回最终描述
return desc return desc
def load_composite(desc):
    """Load a composite kernel description into a refined CompositeGraph."""
    graph = CompositeGraph()
    graph.load(desc)
    graph.refine()
    return graph

def infer(op_name, inputs, attrs):
    """Infer output shape, dtype and format for `op_name`."""

    def _create_opinfer():
        # Prefer an infer class named exactly after the op, if one exists.
        self_module = sys.modules.get(__name__, None)
        if self_module is None:
            raise GKException("OpInfo does not support op {}".format(op_name))
        if hasattr(self_module, op_name):
            op_cls = getattr(self_module, op_name)
            return op_cls(op_name, inputs, attrs)
        # Otherwise fall back to a generic class keyed by iterator type.
        class_name_map = {
            PrimLib.ELEMWISE: "_Elemwise",
            PrimLib.REDUCE: "_Reduce",
        }
        iter_type = PrimLib.primtives.get(op_name, PrimLib.default_primtive).iter_type
        cls_name = class_name_map.get(iter_type, None)
        if not cls_name:
            raise GKException("OpInfo does not support op {}".format(op_name))
        return getattr(self_module, cls_name)(op_name, inputs, attrs)

    return _create_opinfer().infer()
@ -55,49 +63,65 @@ class OpInfer:
""" """
def __init__(self, name, inputs, attrs):
    # Op name, its input tensors, and its attribute dict.
    self.name = name
    self.inputs = inputs
    self.attrs = attrs
def infer(self):
    """Validate the inputs, then infer (shape, dtype, format)."""
    self._check()
    return self._infer_shape(), self._infer_type(), self._infer_format()
def _infer_shape(self):
    # Default: same shape as the first input.
    return self.inputs[0].shape

def _infer_type(self):
    # Default: same dtype as the first input.
    return self.inputs[0].dtype

def _infer_format(self):
    # Default: same data format as the first input.
    return self.inputs[0].data_format

def _check(self):
    # Run all structural checks before inference.
    self._check_shape()
    self._check_type()
    self._check_format()

def _check_shape(self):
    # No generic shape constraint; subclasses may override.
    pass
def _check_type(self):
    """All inputs must share one dtype."""
    dtype = self.inputs[0].dtype
    for i, tensor in enumerate(self.inputs[1:]):
        if tensor.dtype != dtype:
            raise GKException(
                "Incompatible data type between input {}({}) and {}({})".format(0, dtype, i + 1, tensor.dtype))
def _check_format(self):
    """Formats must agree; only DefaultFormat is compatible with others."""
    result = self.inputs[0].data_format
    i = 0
    for j, tensor in enumerate(self.inputs[1:]):
        if tensor.data_format == result:
            continue
        # Two different concrete (non-default) formats cannot mix.
        if DF.DEFAULT not in (result, tensor.data_format):
            raise GKException("Incompatible format between input {}({}) and {}({})".format(
                i, result, j + 1, tensor.data_format))
        if result == DF.DEFAULT:
            # Adopt the first concrete format as the reference.
            result = tensor.data_format
            i = j + 1
@ -109,17 +133,26 @@ class _Elemwise(OpInfer):
@staticmethod @staticmethod
def broadcast_shape(shapes): def broadcast_shape(shapes):
"""deduce broadcast shape using same rules as numpy""" """deduce broadcast shape using same rules as numpy"""
# 计算所有shape的最大维度
dim_size = max(len(shape) for shape in shapes) dim_size = max(len(shape) for shape in shapes)
# 将所有shape扩展到最大维度不足的部分用1填充
align_shapes = [[1] * (dim_size - len(shape)) + shape for shape in shapes] align_shapes = [[1] * (dim_size - len(shape)) + shape for shape in shapes]
# 初始化输出shape为全1
out_shape = [1] * dim_size out_shape = [1] * dim_size
# 遍历每个维度
for i in range(dim_size): for i in range(dim_size):
# 遍历每个shape
for align_shape in align_shapes: for align_shape in align_shapes:
# 如果当前维度为1则跳过
if align_shape[i] == 1: if align_shape[i] == 1:
continue continue
# 如果输出shape当前维度为1则将输出shape当前维度设置为当前shape当前维度的值
if out_shape[i] == 1: if out_shape[i] == 1:
out_shape[i] = align_shape[i] out_shape[i] = align_shape[i]
# 如果输出shape当前维度和当前shape当前维度不相等则抛出异常
elif out_shape[i] != align_shape[i]: elif out_shape[i] != align_shape[i]:
raise GKException("Input shapes {} can not broadcast.".format(shapes)) raise GKException("Input shapes {} can not broadcast.".format(shapes))
# 返回输出shape
return out_shape return out_shape
@staticmethod @staticmethod
@ -174,9 +207,12 @@ class _Elemwise(OpInfer):
.format(inputs_format)) .format(inputs_format))
def _infer_format(self):
    # The first non-default input format wins; otherwise DefaultFormat.
    for tensor in self.inputs:
        fmt = tensor.data_format
        if fmt != DF.DEFAULT:
            return fmt
    return DF.DEFAULT
@ -184,6 +220,7 @@ class _Reduce(OpInfer):
"""Common infer for reduction operators""" """Common infer for reduction operators"""
def _check(self): def _check(self):
# 调用父类的方法
super(_Reduce, self)._check() super(_Reduce, self)._check()
# check reduce axis in the range [-len, len) # check reduce axis in the range [-len, len)
shape_len = len(self.inputs[0].shape) shape_len = len(self.inputs[0].shape)
@ -195,21 +232,29 @@ class _Reduce(OpInfer):
"Reduce axis should be in range [{},{}) but got {}".format(-shape_len, shape_len, axis)) "Reduce axis should be in range [{},{}) but got {}".format(-shape_len, shape_len, axis))
def _infer_shape(self): def _infer_shape(self):
# 深度拷贝输入的形状
shape = copy.deepcopy(self.inputs[0].shape) shape = copy.deepcopy(self.inputs[0].shape)
# 获取reduce_axis属性
axis = self.attrs['reduce_axis'] axis = self.attrs['reduce_axis']
# 如果axis是整数则将其转换为列表
if isinstance(axis, int): if isinstance(axis, int):
axis = [axis] axis = [axis]
# 如果axis中的元素小于0则将其转换为非负数
if any(i < 0 for i in axis): if any(i < 0 for i in axis):
# change the axis to non-negative number. # change the axis to non-negative number.
# 将axis中的负数转换为正数
axis = list(map(lambda i: i + len(shape) if i < 0 else i, axis)) axis = list(map(lambda i: i + len(shape) if i < 0 else i, axis))
# 将axis排序
self.attrs['reduce_axis'] = sorted(axis) self.attrs['reduce_axis'] = sorted(axis)
# 如果keep_dims为True则将axis中的维度设置为1
if self.attrs['keep_dims']: if self.attrs['keep_dims']:
for i in axis: for i in axis:
shape[i] = 1 shape[i] = 1
return shape return shape
# 如果keep_dims为False则将axis中的维度从shape中移除
real_shape = [] real_shape = []
for i, s in enumerate(shape): for i, s in enumerate(shape):
if i not in axis: if i not in axis:
@ -223,10 +268,14 @@ class _Reduce(OpInfer):
class _Reshape(OpInfer):
    """Common infer for reshape operators; must not be instantiated directly."""

    def _infer_shape(self):
        # Subclasses decide the output shape.
        raise GKException("_infer_shape should be implemented by subclass")

    def _infer_format(self):
        # Honor an explicit "format" attribute; otherwise keep the default.
        return self.attrs["format"] if "format" in self.attrs else DF.DEFAULT
@ -234,14 +283,20 @@ class Reshape(_Reshape):
"""Reshape op infer""" """Reshape op infer"""
def _check_shape(self):
    """Element counts before and after the reshape must match."""
    input_shape = self.inputs[0].shape
    output_shape = self.attrs["shape"]
    size_in = prod_reduce(lambda x, y: x * y, input_shape)
    size_out = prod_reduce(lambda x, y: x * y, output_shape)
    if size_in != size_out:
        raise GKException("For 'Reshape', can not reshape {} to {}".format(input_shape, output_shape))
def _infer_shape(self):
    # The target shape is given explicitly as an attribute.
    return self.attrs["shape"]
@ -249,6 +304,7 @@ class Cast(_Elemwise):
"""Cast op infer""" """Cast op infer"""
def _infer_type(self):
    # Cast's output dtype is the requested destination type.
    return self.attrs["dst_type"]
@ -256,29 +312,38 @@ class InplaceAssign(_Elemwise):
"""InplaceAssign op infer""" """InplaceAssign op infer"""
def _infer_shape(self):
    # InplaceAssign forwards its third input ('z') unchanged.
    return self.inputs[2].shape

def _infer_type(self):
    return self.inputs[2].dtype

def _infer_format(self):
    return self.inputs[2].data_format
class BroadcastTo(OpInfer):
    """BroadcastTo op infer."""

    def _infer_shape(self):
        # The target shape is an explicit attribute.
        return self.attrs["shape"]

    def _infer_format(self):
        # Keep the input's data format.
        return self.inputs[0].data_format
class _CompareOp(_Elemwise):
    """Comparison operators always produce boolean outputs."""

    def _infer_type(self):
        return "bool"
@ -286,11 +351,14 @@ class CImag(OpInfer):
"""CImag op infer""" """CImag op infer"""
def _check_type(self):
    # CImag only accepts a complex64 input.
    if self.inputs[0].dtype != "complex64":
        raise GKException(
            "For 'CImag', input[0] should be of type complex64, but got {}".format(self.inputs[0].dtype))

def _infer_type(self):
    # The imaginary part of a complex64 is a float32.
    return "float32"
@ -298,11 +366,14 @@ class CReal(OpInfer):
"""CReal op infer""" """CReal op infer"""
def _check_type(self):
    # CReal only accepts a complex64 input.
    if self.inputs[0].dtype != "complex64":
        raise GKException(
            "For 'CReal', input[0] should be of type complex64, but got {}".format(self.inputs[0].dtype))

def _infer_type(self):
    # The real part of a complex64 is a float32.
    return "float32"
@ -310,14 +381,17 @@ class Complex(OpInfer):
"""Complex op infer""" """Complex op infer"""
def _check_type(self):
    # Both component inputs must be float32 and agree with each other.
    if self.inputs[0].dtype != "float32":
        raise GKException(
            "For 'Complex', input[0] should be of type float32, but got {}".format(self.inputs[0].dtype))
    if self.inputs[0].dtype != self.inputs[1].dtype:
        raise GKException("For 'Complex', inputs data type mismatch ({} vs {})"
                          .format(self.inputs[0].dtype, self.inputs[1].dtype))

def _infer_type(self):
    return "complex64"
@ -345,40 +419,53 @@ class Select(_Elemwise):
"""Select op infer""" """Select op infer"""
def _check_type(self):
    # The condition must be bool and both value branches must share a dtype.
    if self.inputs[0].dtype != "bool":
        raise GKException("For 'Select', input[0] should be of type bool, but got {}".format(self.inputs[0].dtype))
    if self.inputs[1].dtype != self.inputs[2].dtype:
        raise GKException("For 'Select', input[1] and input[2] data type mismatch ({} vs {})"
                          .format(self.inputs[1].dtype, self.inputs[2].dtype))

def _infer_type(self):
    # The result takes the dtype of the value branches.
    return self.inputs[1].dtype
def check_format_any(formats, checked_format):
    """Raise unless `checked_format` appears in the `formats` list/tuple."""
    if not isinstance(formats, (list, tuple)):
        raise GKException("formats {} should be of type list or tuple, but got {}.".format(formats, type(formats)))
    if checked_format not in formats:
        raise GKException("Check {} failed: can not find it in {}".format(checked_format, formats))
def check_nd(data, nd):
    """Raise unless `data` is a list/tuple of exactly `nd` elements."""
    if not isinstance(data, (list, tuple)) or len(data) != nd:
        raise GKException("input should be {}D list or tuple, but got {}.".format(nd, data))
def conv_had_pad(pad_list, pad_mode):
    """Return True when the convolution needs an explicit pad op.

    Args:
        pad_list: 4-element list/tuple (top, bottom, left, right).
        pad_mode: pad-mode string; "VALID"/"valid" never pads implicitly.

    Raises:
        GKException: if `pad_list` is not a 4-element list or tuple.
    """
    if not isinstance(pad_list, (list, tuple)) or len(pad_list) != 4:
        raise GKException("pad_list should be 4D list or tuple, but got {}".format(pad_list))
    # Asymmetric padding always requires an explicit pad op.
    if pad_list[0] != pad_list[1] or pad_list[2] != pad_list[3]:
        return True
    # In non-VALID modes, any non-zero pad entry requires padding.
    # (Idiom fix: the original looped with enumerate and discarded the index.)
    if pad_mode not in ("VALID", "valid"):
        return any(pad != 0 for pad in pad_list)
    return False
@ -386,38 +473,50 @@ class Conv2D(OpInfer):
"""Conv2D infer""" """Conv2D infer"""
def _infer_type(self): def _infer_type(self):
# 如果attrs是dict类型且包含"dst_type"键,则返回"dst_type"的值否则返回输入的第一个元素的dtype
if isinstance(self.attrs, dict) and "dst_type" in self.attrs: if isinstance(self.attrs, dict) and "dst_type" in self.attrs:
return self.attrs["dst_type"] return self.attrs["dst_type"]
return self.inputs[0].dtype return self.inputs[0].dtype
def _infer_shape(self): def _infer_shape(self):
# 将输入的第一个和第二个元素的shape转换为list
shape_0 = list(self.inputs[0].shape) shape_0 = list(self.inputs[0].shape)
shape_1 = list(self.inputs[1].shape) shape_1 = list(self.inputs[1].shape)
# 检查shape_0和shape_1的维度是否为4
check_nd(shape_0, 4) check_nd(shape_0, 4)
check_nd(shape_1, 4) check_nd(shape_1, 4)
# 检查输入的data_format是否为NHWC
formats = [self.inputs[0].data_format, self.inputs[1].data_format, self.attrs["format"]] formats = [self.inputs[0].data_format, self.inputs[1].data_format, self.attrs["format"]]
check_format_any(formats, DF.NHWC) check_format_any(formats, DF.NHWC)
# 获取输入的n、h、w和out_channel
n, h, w, out_channel = shape_0[0], shape_0[1], shape_0[2], shape_1[0] n, h, w, out_channel = shape_0[0], shape_0[1], shape_0[2], shape_1[0]
# 获取pad_list和pad_mode
pad_list = self.attrs["pad_list"] pad_list = self.attrs["pad_list"]
pad_mode = self.attrs["pad_mode"] pad_mode = self.attrs["pad_mode"]
# 获取kernel_size、stride和dilation
kernel_size = self.attrs["kernel_size"] kernel_size = self.attrs["kernel_size"]
stride = self.attrs["stride"] stride = self.attrs["stride"]
dilation = self.attrs["dilation"] dilation = self.attrs["dilation"]
# 检查pad_list、kernel_size、stride和dilation的维度是否为4、2、4和4
check_nd(pad_list, 4) check_nd(pad_list, 4)
check_nd(kernel_size, 2) check_nd(kernel_size, 2)
check_nd(stride, 4) check_nd(stride, 4)
check_nd(dilation, 4) check_nd(dilation, 4)
# 调用conv_had_pad函数判断是否需要pad
has_pad = conv_had_pad(pad_list, pad_mode) has_pad = conv_had_pad(pad_list, pad_mode)
# 如果不需要pad则将pad_list设置为[0, 0, 0, 0]
if not has_pad: if not has_pad:
pad_list = [0, 0, 0, 0] pad_list = [0, 0, 0, 0]
# 计算输出的h和w
k_h = (kernel_size[0] - 1) * dilation[-2] + 1 k_h = (kernel_size[0] - 1) * dilation[-2] + 1
k_w = (kernel_size[1] - 1) * dilation[-1] + 1 k_w = (kernel_size[1] - 1) * dilation[-1] + 1
out_h = (h + pad_list[0] + pad_list[1] - k_h) // stride[-2] + 1 out_h = (h + pad_list[0] + pad_list[1] - k_h) // stride[-2] + 1
out_w = (w + pad_list[2] + pad_list[3] - k_w) // stride[-1] + 1 out_w = (w + pad_list[2] + pad_list[3] - k_w) // stride[-1] + 1
# 返回输出的shape
return [n, out_h, out_w, out_channel] return [n, out_h, out_w, out_channel]
@ -425,23 +524,31 @@ class MatMul(OpInfer):
"""MatMul infer""" """MatMul infer"""
def _infer_type(self): def _infer_type(self):
# 如果attrs是dict类型且包含"dst_type"键,则返回"dst_type"的值否则返回输入的第一个元素的dtype
if isinstance(self.attrs, dict) and "dst_type" in self.attrs: if isinstance(self.attrs, dict) and "dst_type" in self.attrs:
return self.attrs["dst_type"] return self.attrs["dst_type"]
return self.inputs[0].dtype return self.inputs[0].dtype
def _infer_shape(self): def _infer_shape(self):
# 将输入的第一个和第二个元素的shape转换为list
shape_0 = list(self.inputs[0].shape) shape_0 = list(self.inputs[0].shape)
shape_1 = list(self.inputs[1].shape) shape_1 = list(self.inputs[1].shape)
# 检查shape_0和shape_1的维度是否为2
if len(shape_0) != 2 or len(shape_1) != 2: if len(shape_0) != 2 or len(shape_1) != 2:
raise GKException("For 'MatMul', inputs shape must be 2D, but got {}, {}" raise GKException("For 'MatMul', inputs shape must be 2D, but got {}, {}"
.format(shape_0, shape_1)) .format(shape_0, shape_1))
# 获取transpose_a和transpose_b
transpose_a = self.attrs["transpose_a"] transpose_a = self.attrs["transpose_a"]
transpose_b = self.attrs["transpose_b"] transpose_b = self.attrs["transpose_b"]
# 根据transpose_a和transpose_b获取m、k1、k2和n
m, k1 = (shape_0[-1], shape_0[-2]) if transpose_a else (shape_0[-2], shape_0[-1]) m, k1 = (shape_0[-1], shape_0[-2]) if transpose_a else (shape_0[-2], shape_0[-1])
k2, n = (shape_1[-1], shape_1[-2]) if transpose_b else (shape_1[-2], shape_1[-1]) k2, n = (shape_1[-1], shape_1[-2]) if transpose_b else (shape_1[-2], shape_1[-1])
# 如果k1和k2不相等则抛出异常
if k1 != k2: if k1 != k2:
raise GKException("For 'MatMul', inputs have different k value: {} vs {}".format(k1, k2)) raise GKException("For 'MatMul', inputs have different k value: {} vs {}".format(k1, k2))
# 计算输出的shape
output_shape = [m, n] output_shape = [m, n]
# 返回输出的shape
return output_shape return output_shape
@ -449,14 +556,20 @@ class PadAkg(OpInfer):
"""PadAkg infer""" """PadAkg infer"""
def _infer_shape(self): def _infer_shape(self):
# 将输入的第一个元素的shape转换为list
shape = list(self.inputs[0].shape) shape = list(self.inputs[0].shape)
# 获取输入的维度
n = len(shape) n = len(shape)
# 获取pad_before和pad_after
pad_before = list(self.attrs["head"]) pad_before = list(self.attrs["head"])
pad_after = list(self.attrs["tail"]) pad_after = list(self.attrs["tail"])
# 检查pad_before和pad_after的维度是否与输入的维度相等
if len(pad_before) != n or len(pad_after) != n: if len(pad_before) != n or len(pad_after) != n:
raise GKException("For 'PadAkg', input dimension and pad mismatch: {}d vs {}d vs {}d" raise GKException("For 'PadAkg', input dimension and pad mismatch: {}d vs {}d vs {}d"
.format(n, len(pad_before), len(pad_after))) .format(n, len(pad_before), len(pad_after)))
# 计算输出的shape
out_shape = [shape[i] + pad_before[i] + pad_after[i] for i in range(n)] out_shape = [shape[i] + pad_before[i] + pad_after[i] for i in range(n)]
# 返回输出的shape
return out_shape return out_shape
@ -464,13 +577,19 @@ class UnPadAkg(OpInfer):
"""UnPadAkg infer""" """UnPadAkg infer"""
def _infer_shape(self): def _infer_shape(self):
# 将输入的第一个元素的shape转换为list
shape = list(self.inputs[0].shape) shape = list(self.inputs[0].shape)
# 获取输入的维度
n = len(shape) n = len(shape)
# 获取unpad_after
unpad_after = list(self.attrs["tail"]) unpad_after = list(self.attrs["tail"])
# 检查unpad_after的维度是否与输入的维度相等
if len(unpad_after) != n: if len(unpad_after) != n:
raise GKException("For 'UnPadAkg', input dimension and pad mismatch: {}d vs {}d" raise GKException("For 'UnPadAkg', input dimension and pad mismatch: {}d vs {}d"
.format(n, len(unpad_after))) .format(n, len(unpad_after)))
# 计算输出的shape
out_shape = [shape[i] - unpad_after[i] for i in range(n)] out_shape = [shape[i] - unpad_after[i] for i in range(n)]
# 返回输出的shape
return out_shape return out_shape
@ -478,23 +597,32 @@ class Gather(OpInfer):
"""Gather infer""" """Gather infer"""
def _infer_shape(self): def _infer_shape(self):
# 获取输入的第一个和第二个元素的shape
input_shape = self.inputs[0].shape input_shape = self.inputs[0].shape
indices_shape = self.inputs[1].shape indices_shape = self.inputs[1].shape
# 获取axis
axis = self.attrs['axis'] axis = self.attrs['axis']
# 将输出的shape设置为输入的第一个元素的shape
output_shape = input_shape output_shape = input_shape
# 计算indices_shape的维度
indices_shape_one_dim = 1 indices_shape_one_dim = 1
for dim in indices_shape: for dim in indices_shape:
indices_shape_one_dim *= dim indices_shape_one_dim *= dim
# 将输出的shape的axis维度设置为indices_shape的维度
output_shape[axis] = indices_shape_one_dim output_shape[axis] = indices_shape_one_dim
# 返回输出的shape
return output_shape return output_shape
def _infer_type(self): def _infer_type(self):
# 返回输入的第一个元素的dtype
return self.inputs[0].dtype return self.inputs[0].dtype
def _infer_format(self): def _infer_format(self):
# 返回输入的第一个元素的data_format
return self.inputs[0].data_format return self.inputs[0].data_format
def _check_type(self): def _check_type(self):
# 检查输入的第二个元素的dtype是否为int32如果不是则抛出异常
if self.inputs[1].dtype != "int32": if self.inputs[1].dtype != "int32":
raise GKException("For 'Gather', inputs[1] should be of type int32, but got {}" raise GKException("For 'Gather', inputs[1] should be of type int32, but got {}"
.format(self.inputs[1].dtype)) .format(self.inputs[1].dtype))

@ -21,24 +21,48 @@ from . import model
def estimate_ops(json_str): def estimate_ops(json_str):
"""
估计操作数
Args:
json_str (str): 包含图描述的json字符串
Returns:
tuple: 包含估计结果的元组包括块分配增益融合类型和类型信息的元组
Raises:
JSONDecodeError: 如果输入的json字符串无法解码将引发此异常
"""
"""Call cost model to estimate ops.""" """Call cost model to estimate ops."""
try: try:
# 将json字符串转换为json对象
json_obj = json.loads(json_str) json_obj = json.loads(json_str)
# 获取json对象中的graph_desc
graph_descs = json_obj["graph_desc"] graph_descs = json_obj["graph_desc"]
# 初始化graphs和target
graphs = [] graphs = []
target = None target = None
# 遍历graph_descs
for gd in graph_descs: for gd in graph_descs:
# 如果target为空则将gd['process']赋值给target
if target is None: if target is None:
target = gd['process'] target = gd['process']
# 如果target不为空且gd['process']与target不同则输出错误信息
elif target != gd['process']: elif target != gd['process']:
logger.error("Parallel fusion does not support multi-target({} and {})".format(target, gd['process'])) logger.error("Parallel fusion does not support multi-target({} and {})".format(target, gd['process']))
return None return None
# 将model.load_composite(gd).graph添加到graphs中
graphs.append(model.load_composite(gd).graph) graphs.append(model.load_composite(gd).graph)
# 调用model.parallel_estimate函数传入graphs和target获取estimation
estimation = model.parallel_estimate(graphs, target) estimation = model.parallel_estimate(graphs, target)
# 将estimation的block_assign、gain、fusion_type和type_info赋值给res
res = (estimation.block_assign, estimation.gain, res = (estimation.block_assign, estimation.gain,
estimation.fusion_type, estimation.type_info) estimation.fusion_type, estimation.type_info)
# 返回res
return res return res
except jd.JSONDecodeError: except jd.JSONDecodeError:
# 如果出现JSONDecodeError则输出错误信息
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return None return None
finally: finally:
@ -46,14 +70,33 @@ def estimate_ops(json_str):
def estimate_calculation_amount(json_str): def estimate_calculation_amount(json_str):
"""
估计操作计算量的函数
Args:
json_str (str): 包含操作描述的JSON字符串
Returns:
int: 计算量的估计值如果解析JSON字符串失败则返回-1
Raises:
"""
"""Call cost model to estimate calculation amount of op.""" """Call cost model to estimate calculation amount of op."""
try: try:
# 将json字符串转换为json对象
graph_desc = json.loads(json_str) graph_desc = json.loads(json_str)
# 获取json对象中的process
target = graph_desc['process'] target = graph_desc['process']
# 调用model.load_composite函数传入graph_desc获取comp
comp = model.load_composite(graph_desc) comp = model.load_composite(graph_desc)
# 调用model.parallel_estimate函数传入comp.graph和target获取estimation
estimation = model.parallel_estimate([comp.graph], target) estimation = model.parallel_estimate([comp.graph], target)
# 返回estimation的bottleneck
return estimation.bottleneck return estimation.bottleneck
except jd.JSONDecodeError: except jd.JSONDecodeError:
# 如果出现JSONDecodeError则输出错误信息
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return -1 return -1
finally: finally:

@ -24,6 +24,20 @@ from . import utils
def split_with_json(json_str, flags_str): def split_with_json(json_str, flags_str):
"""
根据JSON字符串分割GraphKernel
Args:
json_str (str): 包含GraphKernel描述的JSON字符串
flags_str (str): 包含分割标志的JSON字符串
Returns:
str: 包含分割结果的JSON字符串
Raises:
jd.JSONDecodeError: 如果json_str或flags_str无法被解析为JSON格式将引发此异常
"""
"""Call cost model to split GraphKernel""" """Call cost model to split GraphKernel"""
try: try:
graph_desc = json.loads(json_str) graph_desc = json.loads(json_str)
@ -45,6 +59,21 @@ def split_with_json(json_str, flags_str):
def _reset_graphmode_for_inplaceassign(graph_list, graph_mode): def _reset_graphmode_for_inplaceassign(graph_list, graph_mode):
"""
重置具有 InplaceAssign 操作符的图模式
Args:
graph_list (list): 包含图的列表每个图都是一个包含操作描述的字典
graph_mode (list): 图模式列表每个元素表示对应图的模式
Returns:
None
Notes:
具有 InplaceAssign 操作符的操作应始终为复合操作
对于包含 InplaceAssign 操作符的图将其模式设置为 'composite'
"""
"""Operator with InplaceAssign should always be composite op""" """Operator with InplaceAssign should always be composite op"""
for i, g in enumerate(graph_list): for i, g in enumerate(graph_list):
if any((op['name'] == 'InplaceAssign' for op in g['op_desc'])): if any((op['name'] == 'InplaceAssign' for op in g['op_desc'])):
@ -52,6 +81,20 @@ def _reset_graphmode_for_inplaceassign(graph_list, graph_mode):
def _dump_split_info(flags, graph_json, graph_desc, subgraphs, graph_mode): def _dump_split_info(flags, graph_json, graph_desc, subgraphs, graph_mode):
"""
将分割信息以文本形式输出
Args:
flags (dict): 包含配置信息的字典
graph_json (str): 图结构的JSON字符串
graph_desc (object): 图描述对象
subgraphs (list): 子图列表
graph_mode (list): 图模式列表
Returns:
None
"""
"""Dump split info as text""" """Dump split info as text"""
if not flags.get("dump_as_text", False): if not flags.get("dump_as_text", False):
return return

@ -19,6 +19,19 @@ GRAPH_KERNEL_DUMP_PATH = "graph_kernel_dump"
def create_dir(pathname): def create_dir(pathname):
"""
尝试创建目录
Args:
pathname (str): 要创建的目录的路径
Returns:
None
Raises:
不显式抛出异常
"""
"""Try to create directory""" """Try to create directory"""
if os.path.exists(pathname): if os.path.exists(pathname):
return return

@ -16,4 +16,8 @@
Extension functions. Extension functions.
Python functions that will be called in the c++ parts of MindSpore. Python functions that will be called in the c++ parts of MindSpore.
扩展函数
这些Python函数将在MindSpore的C++部分中被调用
""" """

@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""akg process""" """akg process"""
import os import os
import json import json
import subprocess import subprocess
@ -24,10 +26,10 @@ from mindspore._extends.parallel_compile.akg_compiler.get_file_path import get_a
def _compile_akg_task_default(json_strs, attrs): def _compile_akg_task_default(json_strs, attrs):
""" """
compile func called in single process 编译函数在单个进程中调用
Parameters: 参数
json_strs: list. List contains multiple kernel infos, suitable for json compile api. json_strs列表包含多个内核信息的列表适用于json编译API
""" """
sys.path.insert(0, get_akg_path()) sys.path.insert(0, get_akg_path())
@ -37,15 +39,15 @@ def _compile_akg_task_default(json_strs, attrs):
for json_str in json_strs: for json_str in json_strs:
res = func(json_str, attrs) res = func(json_str, attrs)
if not res: if not res:
raise ValueError("Compile error, args: {}! build attrs: {}".format(json_str, attrs)) raise ValueError("编译错误,参数:{}!构建属性:{}".format(json_str, attrs))
def _compile_akg_task_ascend(json_strs, attrs): def _compile_akg_task_ascend(json_strs, attrs):
""" """
compile func called in single process 编译函数在单个进程中调用
Parameters: 参数
json_strs: list. List contains multiple kernel infos, suitable for json compile api. json_strs列表包含多个内核信息的列表适用于json编译API
""" """
if attrs is None: if attrs is None:
attrs = "{}" attrs = "{}"
@ -56,35 +58,33 @@ def _compile_akg_task_ascend(json_strs, attrs):
if compile_result.returncode: if compile_result.returncode:
json_dict = json.loads(json_str) json_dict = json.loads(json_str)
if not json_dict.get("composite"): if not json_dict.get("composite"):
raise ValueError("Compile error, json str: {}! build attrs: {}".format(json_str, attrs)) raise ValueError("编译错误json字符串{}!构建属性:{}".format(json_str, attrs))
logger.debug("Will try to split, json str: {}! build attrs: {}".format(json_str, attrs)) logger.debug("将尝试拆分json字符串{}!构建属性:{}".format(json_str, attrs))
def create_akg_parallel_process(process_num, wait_time, platform): def create_akg_parallel_process(process_num, wait_time, platform):
""" """
create AkgParallelCompiler object 创建AkgParallelCompiler对象
Returns: 返回
AkgParallelCompiler AkgParallelCompiler
""" """
return AkgProcess(process_num, wait_time, platform) return AkgProcess(process_num, wait_time, platform)
class AkgProcess: class AkgProcess:
"""akg kernel parallel process""" """akg内核并行进程"""
def __init__(self, process_num, wait_time, platform): def __init__(self, process_num, wait_time, platform):
""" """
Args: 参数
process_num: int. processes number process_numint进程数量
wait_time: int. max time the function blocked wait_timeint函数阻塞的最大时间
""" """
if not isinstance(process_num, int): if not isinstance(process_num, int):
raise ValueError("AKG kernel compiling process number must be of type int, but got {} with type {}" raise ValueError("AKG内核编译进程数量必须是int类型但得到的是{},类型为{}".format(process_num, type(wait_time)))
.format(process_num, type(wait_time)))
if not isinstance(wait_time, int): if not isinstance(wait_time, int):
raise ValueError("AKG kernel compiling wait time must be of type int, but got {} with type {}" raise ValueError("AKG内核编译等待时间必须是int类型但得到的是{},类型为{}".format(wait_time, type(wait_time)))
.format(wait_time, type(wait_time)))
if process_num == 0: if process_num == 0:
process_num = 1 process_num = 1
max_proc_num = 16 max_proc_num = 16
@ -96,13 +96,12 @@ class AkgProcess:
def compile(self, attrs=None): def compile(self, attrs=None):
""" """
compile kernel by multi processes 多进程编译内核
Return: 返回
True for all compile success, False for some failed. 所有编译成功返回True部分失败返回False
""" """
if self.argc == 0: if self.argc == 0:
raise ValueError("In AKG kernel compiling, the number of kernel json that need to be compiled can " raise ValueError("在AKG内核编译中需要编译的内核json数量不能为零。")
"not be zero.")
args = list((arg, attrs) for arg in self.args) args = list((arg, attrs) for arg in self.args)
if self.platform == "ASCEND": if self.platform == "ASCEND":
with Pool(processes=self.process_num) as pool: with Pool(processes=self.process_num) as pool:
@ -116,12 +115,11 @@ class AkgProcess:
def accept_json(self, json_str): def accept_json(self, json_str):
""" """
accept json data before compile 在编译前接受内核的json数据
Args: 参数
json_str: str. kernel info. json_strstr内核信息
""" """
if not isinstance(json_str, str): if not isinstance(json_str, str):
raise ValueError("In AKG kernel compiling, the kernel json must be of type str, but got {} with type {}" raise ValueError("在AKG内核编译中内核json必须是str类型但得到的是{},类型为{}".format(json_str, type(json_str)))
.format(json, type(json)))
self.args[self.argc % self.process_num].append(json_str) self.args[self.argc % self.process_num].append(json_str)
self.argc += 1 self.argc += 1

@ -28,16 +28,23 @@ def run_compiler(op_json, attrs=None):
None None
""" """
from get_file_path import get_akg_path from get_file_path import get_akg_path
# 将akg路径添加到sys.path中
sys.path.insert(0, get_akg_path()) sys.path.insert(0, get_akg_path())
# 导入akg模块
p = __import__("akg", globals(), locals(), ['ms'], 0) p = __import__("akg", globals(), locals(), ['ms'], 0)
# 获取akg.ms.compilewithjson函数
func = getattr(p.ms, "compilewithjson") func = getattr(p.ms, "compilewithjson")
# 调用akg.ms.compilewithjson函数进行编译
res = func(op_json, attrs) res = func(op_json, attrs)
# 如果编译失败,抛出异常
if not res: if not res:
raise ValueError("Compile error") raise ValueError("Compile error")
if __name__ == "__main__": if __name__ == "__main__":
# 如果命令行参数大于2则调用run_compiler函数传入op_json和attrs
if len(sys.argv) > 2: if len(sys.argv) > 2:
run_compiler(sys.argv[1], sys.argv[2]) run_compiler(sys.argv[1], sys.argv[2])
# 否则只传入op_json
else: else:
run_compiler(sys.argv[1]) run_compiler(sys.argv[1])

@ -19,18 +19,27 @@ import os
def get_akg_path(): def get_akg_path():
"""get akg directory base path""" """get akg directory base path"""
# 提示信息如果找不到mindspore模块请检查1MindSpore是否成功编译。2MindSpore是否成功安装使用pip install安装或设置环境变量PYTHONPATH为${mindspore_build_dir}/package
hint = "Please check: 1) whether MindSpore is compiled successfully. " \ hint = "Please check: 1) whether MindSpore is compiled successfully. " \
"2) Whether MindSpore is installed successfully with pip install or " \ "2) Whether MindSpore is installed successfully with pip install or " \
"the path ${mindspore_build_dir}/package is set in env PYTHONPATH." "the path ${mindspore_build_dir}/package is set in env PYTHONPATH."
# 查找mindspore模块
search_res = importlib.util.find_spec("mindspore") search_res = importlib.util.find_spec("mindspore")
if search_res is None: if search_res is None:
# 如果找不到mindspore模块抛出异常
raise RuntimeError("Cannot find mindspore module! {}".format(hint)) raise RuntimeError("Cannot find mindspore module! {}".format(hint))
# 获取mindspore模块的路径
res_path = search_res.origin res_path = search_res.origin
# 在路径中查找__init__.py文件
find_pos = res_path.find("__init__.py") find_pos = res_path.find("__init__.py")
if find_pos == -1: if find_pos == -1:
# 如果找不到__init__.py文件抛出异常
raise RuntimeError("Find module mindspore origin file failed! {}".format(hint)) raise RuntimeError("Find module mindspore origin file failed! {}".format(hint))
# 获取akg路径
akg_path = "{}_akg".format(res_path[:find_pos]) akg_path = "{}_akg".format(res_path[:find_pos])
# 如果akg路径不存在抛出异常
if not os.path.isdir(akg_path): if not os.path.isdir(akg_path):
raise RuntimeError("Cannot find akg from mindspore module! {}".format(hint)) raise RuntimeError("Cannot find akg from mindspore module! {}".format(hint))
# 返回akg路径
return akg_path return akg_path

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""tbe adapter to adapt te/topi/auto-tune python api """ """tbe adapter to adapt te/topi/auto-tune python api """
# 导入必要的库和模块
import json import json
import os import os
import shutil import shutil
@ -20,33 +21,62 @@ import sys
import traceback import traceback
from datetime import datetime from datetime import datetime
# 导入TBE相关的库和模块
from tbe.common.rl_bank.bank_manager import set_current_op_name from tbe.common.rl_bank.bank_manager import set_current_op_name
from tbe.common.repository_manager.interface import cann_kb_unload, cann_kb_load from tbe.common.repository_manager.interface import cann_kb_unload, cann_kb_load
from tbe.common.rl_bank.bank_cfg import LocalLock from tbe.common.rl_bank.bank_cfg import LocalLock
from te.platform.cce_conf import te_set_version from te.platform.cce_conf import te_set_version
from te.platform.cce_policy import set_L1_info from te.platform.cce_policy import set_L1_info
from te_fusion.compile_task_manager import dispatch_prebuild_task, dispatch_single_op_compile_task, import_py_module, \ from te_fusion.compile_task_manager import (
dispatch_fusion_op_compile_task, dispatch_autotune_task, sync_op_tune_params dispatch_prebuild_task,
from te_fusion.compile_task_manager import sync_syspath dispatch_single_op_compile_task,
from te_fusion.fusion_manager import call_op_func, clear_fusion_params, check_op_impl_mode, \ import_py_module,
save_op_params, build_single_op_from_c, op_params_to_json dispatch_fusion_op_compile_task,
dispatch_autotune_task,
sync_op_tune_params,
sync_syspath
)
from te_fusion.fusion_manager import (
call_op_func,
clear_fusion_params,
check_op_impl_mode,
save_op_params,
build_single_op_from_c,
op_params_to_json
)
from te_fusion.fusion_util import dump_fusion_json from te_fusion.fusion_util import dump_fusion_json
from te_fusion.parallel_compilation import init_multi_process_env, start_ga_multi_process, deinit_multi_process_env, \ from te_fusion.parallel_compilation import (
init_multi_process_env,
start_ga_multi_process,
deinit_multi_process_env,
get_finished_compilation_task get_finished_compilation_task
)
from .tbe_helper import get_soc_info, assemble_op_args, get_compute_op_list, get_options_info, get_fuzz_build_info, \ from .tbe_helper import (
adjust_custom_op_info, pack_op_args, get_module_name, get_real_op_debug_level get_soc_info,
assemble_op_args,
get_compute_op_list,
get_options_info,
get_fuzz_build_info,
adjust_custom_op_info,
pack_op_args,
get_module_name,
get_real_op_debug_level
)
from .tbe_job import TbeJob, JobStatus from .tbe_job import TbeJob, JobStatus
PLATFORM_FLAG = ["Ascend310", "Ascend910", "Hi3796CV300ES", "Ascend710", "Ascend610", "Hi3796CV300CS", "SD3403"] # 定义支持的平台标志
PLATFORM_FLAG = [
"Ascend310", "Ascend910", "Hi3796CV300ES", "Ascend710", "Ascend610", "Hi3796CV300CS", "SD3403"
]
# 定义Tune初始化函数
def _tune_init(job: TbeJob): def _tune_init(job: TbeJob):
""" """
Tune Initialize Tune初始化
:param job: :param job: TbeJob对象包含任务信息
:return: :return: 初始化是否成功
""" """
# 提取Soc信息和Tune信息
auto_tiling_mode = job.content["SocInfo"]["autoTilingMode"] auto_tiling_mode = job.content["SocInfo"]["autoTilingMode"]
offline_tune = job.content["SocInfo"]["offlineTune"] offline_tune = job.content["SocInfo"]["offlineTune"]
op_bank_update = job.content["SocInfo"]["op_bank_update"] op_bank_update = job.content["SocInfo"]["op_bank_update"]
@ -54,11 +84,14 @@ def _tune_init(job: TbeJob):
tune_bank_path = job.content["TuneInfo"]["tune_bank_path"] tune_bank_path = job.content["TuneInfo"]["tune_bank_path"]
need_ga = bool("GA" in auto_tiling_mode) need_ga = bool("GA" in auto_tiling_mode)
need_rl = bool("RL" in auto_tiling_mode) need_rl = bool("RL" in auto_tiling_mode)
# 设置环境变量
if offline_tune: if offline_tune:
os.environ["ENABLE_TUNE_DUMP"] = "TRUE" os.environ["ENABLE_TUNE_DUMP"] = "TRUE"
if op_bank_update: if op_bank_update:
sync_op_tune_params("tbe.common.tiling.tiling_api", "reset_repository", False, "") sync_op_tune_params("tbe.common.tiling.tiling_api", "reset_repository", False, "")
# 初始化Tune环境
if need_ga or need_rl or offline_tune: if need_ga or need_rl or offline_tune:
res = __init_tune_env(job, need_ga) res = __init_tune_env(job, need_ga)
if not res: if not res:
@ -66,6 +99,7 @@ def _tune_init(job: TbeJob):
else: else:
return True return True
# 设置Tune路径
if tune_dump_path: if tune_dump_path:
os.environ["TUNE_DUMP_PATH"] = str(tune_dump_path) os.environ["TUNE_DUMP_PATH"] = str(tune_dump_path)
if tune_bank_path: if tune_bank_path:
@ -73,12 +107,12 @@ def _tune_init(job: TbeJob):
res = _creating_custom_path(job) res = _creating_custom_path(job)
return res return res
# 定义CANN知识库加载函数
def _cann_kb_load(job: TbeJob): def _cann_kb_load(job: TbeJob):
""" """
database load 加载CANN知识库
:param job: :param job: TbeJob对象包含任务信息
:return: :return: 加载是否成功
""" """
soc_version = job.soc_version soc_version = job.soc_version
core_num = job.core_num core_num = job.core_num
@ -87,12 +121,12 @@ def _cann_kb_load(job: TbeJob):
res = cann_kb_load(soc_version, core_num, op_bank_path, kb_type) res = cann_kb_load(soc_version, core_num, op_bank_path, kb_type)
return res return res
# 定义CANN知识库卸载函数
def _cann_kb_unload(job: TbeJob): def _cann_kb_unload(job: TbeJob):
""" """
database unload 卸载CANN知识库
:param job: :param job: TbeJob对象包含任务信息
:return: :return: 卸载是否成功
""" """
if job is None: if job is None:
return 0 return 0
@ -102,12 +136,12 @@ def _cann_kb_unload(job: TbeJob):
res = cann_kb_unload(soc_version, core_num, kb_type) res = cann_kb_unload(soc_version, core_num, kb_type)
return res return res
# 定义移除缓存文件函数
def _remove_cache(job: TbeJob): def _remove_cache(job: TbeJob):
""" """
:param job: remove cache file:[*.json, *.o, *.info, *.cce] when "op_debug_level" is "0" 移除缓存文件
op_debug_level: representation the env MS_COMPILER_OP_LEVEL :param job: TbeJob对象包含任务信息
:return: :return:
""" """
op_debug_level = job.content["SocInfo"]["op_debug_level"] op_debug_level = job.content["SocInfo"]["op_debug_level"]
op_debug_dir = job.content["SocInfo"]["op_debug_dir"] op_debug_dir = job.content["SocInfo"]["op_debug_dir"]
@ -118,24 +152,30 @@ def _remove_cache(job: TbeJob):
real_path = os.path.join(root_path, "kernel_meta/") real_path = os.path.join(root_path, "kernel_meta/")
shutil.rmtree(real_path) shutil.rmtree(real_path)
# 定义创建目录函数
def __directory_creation(path, concat_path): def __directory_creation(path, concat_path):
""" """
Create directory 创建目录
:param path: 基础路径
:param concat_path: 需要连接的路径
:return: 创建后的完整路径
""" """
path = os.path.join(path, concat_path) path = os.path.join(path, concat_path)
if not os.path.isdir(path): if not os.path.isdir(path):
os.makedirs(path, 0o750) os.makedirs(path, 0o750)
return path return path
# 定义初始化Tune环境函数
def __init_tune_env(job, need_ga): def __init_tune_env(job, need_ga):
""" """
Initialize tune env 初始化Tune环境
:param job: TbeJob对象包含任务信息
:param need_ga: 是否需要GA
:return: 初始化是否成功
""" """
try: try:
import auto_tune.auto_tune_main as at_atm import auto_tune.auto_tune_main as at_atm
from schedule_search.rl_online_tune import rl_tune_init # pylint: disable=unused-import from schedule_search.rl_online_tune import rl_tune_init
if need_ga: if need_ga:
res = at_atm.ga_tune_init() res = at_atm.ga_tune_init()
if not res: if not res:
@ -157,10 +197,13 @@ def __init_tune_env(job, need_ga):
finally: finally:
pass pass
# 定义创建默认自定义路径函数
def __creating_default_custom_path(auto_tiling_mode, base_custom_path): def __creating_default_custom_path(auto_tiling_mode, base_custom_path):
""" """
Create default custom path 创建默认自定义路径
:param auto_tiling_mode: 自动平铺模式
:param base_custom_path: 基础自定义路径
:return:
""" """
base_custom_path = __directory_creation(base_custom_path, "data") base_custom_path = __directory_creation(base_custom_path, "data")
tune_flag = [] tune_flag = []
@ -179,27 +222,40 @@ def __creating_default_custom_path(auto_tiling_mode, base_custom_path):
def _creating_custom_path(job): def _creating_custom_path(job):
""" """
Create custom path 创建自定义路径用于存储和检索自定义算子的调优参数
Args:
job (TbeJob): 包含任务信息的TbeJob对象
Returns:
bool: 自定义路径创建是否成功
""" """
# 获取自动平铺模式
auto_tiling_mode = job.content["SocInfo"]["autoTilingMode"] auto_tiling_mode = job.content["SocInfo"]["autoTilingMode"]
# 如果模式中包含"NO_TUNE",则不需要创建自定义路径
if "NO_TUNE" in auto_tiling_mode: if "NO_TUNE" in auto_tiling_mode:
return True return True
# 获取调优参数的基础路径
base_custom_path = job.content["TuneInfo"]["tune_bank_path"] base_custom_path = job.content["TuneInfo"]["tune_bank_path"]
tune_bank_flag = True tune_bank_flag = True
# 如果基础路径不存在则尝试从auto_tune模块获取
if not base_custom_path: if not base_custom_path:
import auto_tune import auto_tune
base_custom_path = os.path.dirname(os.path.realpath(auto_tune.__file__)) base_custom_path = os.path.dirname(os.path.realpath(auto_tune.__file__))
base_custom_path = os.path.realpath(os.path.join(base_custom_path, "../../../")) base_custom_path = os.path.realpath(os.path.join(base_custom_path, "../../../"))
tune_bank_flag = False tune_bank_flag = False
# 检查基础路径是否存在
if not os.path.isdir(base_custom_path): if not os.path.isdir(base_custom_path):
job.error("Check whether the tuning path [{}] exists.".format(base_custom_path)) job.error("Check whether the tuning path [{}] exists.".format(base_custom_path))
return False return False
# 检查基础路径的权限
if not os.access(base_custom_path, os.R_OK | os.W_OK | os.X_OK): if not os.access(base_custom_path, os.R_OK | os.W_OK | os.X_OK):
job.error("Check whether the permission on the tuning path [{}] is correct.".format(base_custom_path)) job.error("Check whether the permission on the tuning path [{}] is correct.".format(base_custom_path))
return False return False
# 如果不需要创建调优参数库,则直接返回成功
if not tune_bank_flag: if not tune_bank_flag:
return __creating_default_custom_path(auto_tiling_mode, base_custom_path) return __creating_default_custom_path(auto_tiling_mode, base_custom_path)
return True return True
@ -207,22 +263,34 @@ def _creating_custom_path(job):
def _parallel_compilation_init(initialize: TbeJob): def _parallel_compilation_init(initialize: TbeJob):
""" """
Tbe parallel compilation initialize 初始化TBE并行编译环境
:param initialize:
:return: Args:
initialize (TbeJob): 包含任务信息的TbeJob对象
Returns:
bool: 并行编译环境初始化是否成功
""" """
# 设置并行编译器的环境变量
os.environ["TE_PARALLEL_COMPILER"] = str(initialize.content["process_num"]) os.environ["TE_PARALLEL_COMPILER"] = str(initialize.content["process_num"])
# 获取SoC信息
soc_info = get_soc_info(initialize.content) soc_info = get_soc_info(initialize.content)
# 获取实际的调试级别
real_debug_level = get_real_op_debug_level(initialize.content) real_debug_level = get_real_op_debug_level(initialize.content)
# 获取自动平铺模式
auto_tiling_mode = initialize.content["SocInfo"]["autoTilingMode"] auto_tiling_mode = initialize.content["SocInfo"]["autoTilingMode"]
# 获取是否需要离线调优
offline_tune = initialize.content["SocInfo"]["offlineTune"] offline_tune = initialize.content["SocInfo"]["offlineTune"]
# 生成进程ID和时间戳的组合字符串
pid_ts = "{}_pid{}".format(datetime.now().strftime('%Y%m%d_%H%M%S%f')[:-3], os.getpid()) pid_ts = "{}_pid{}".format(datetime.now().strftime('%Y%m%d_%H%M%S%f')[:-3], os.getpid())
# 初始化多进程环境
ret = init_multi_process_env(False, soc_info, auto_tiling_mode, real_debug_level, ret = init_multi_process_env(False, soc_info, auto_tiling_mode, real_debug_level,
None, 1, pid_ts) None, 1, pid_ts)
if ret is None: if ret is None:
initialize.error("Init multiprocess env failed") initialize.error("Init multiprocess env failed")
return False return False
initialize.info("Init multiprocess env success with {} process".format(ret[0])) initialize.info("Init multiprocess env success with {} process".format(ret[0]))
# 如果需要RL或离线调优则初始化RL环境
if "RL" in auto_tiling_mode or offline_tune: if "RL" in auto_tiling_mode or offline_tune:
res_queue = ret[1] res_queue = ret[1]
live_checker = ret[2] live_checker = ret[2]
@ -234,6 +302,7 @@ def _parallel_compilation_init(initialize: TbeJob):
initialize.error("RL env init failed!") initialize.error("RL env init failed!")
return False return False
initialize.info("RL Tune init success.") initialize.info("RL Tune init success.")
# 如果需要GA则启动GA多进程
if "GA" in auto_tiling_mode: if "GA" in auto_tiling_mode:
start_ga_multi_process(auto_tiling_mode) start_ga_multi_process(auto_tiling_mode)
initialize.info("GA Tune init success.") initialize.info("GA Tune init success.")
@ -242,31 +311,44 @@ def _parallel_compilation_init(initialize: TbeJob):
def tbe_initialize(job: TbeJob): def tbe_initialize(job: TbeJob):
""" """
Tbe Initialize 初始化TBE环境
:param job:
:return: Args:
job (TbeJob): 包含任务信息的TbeJob对象
Returns:
bool: TBE环境初始化是否成功
""" """
# 设置上下文模型编译环境变量
os.environ["CONTEXT_MODELCOMPILING"] = "TRUE" os.environ["CONTEXT_MODELCOMPILING"] = "TRUE"
# 获取SoC信息
soc_info = get_soc_info(job.content) soc_info = get_soc_info(job.content)
# 设置版本
res = te_set_version(*soc_info) res = te_set_version(*soc_info)
if not res: if not res:
job.error("Set version failed") job.error("Set version failed")
# 初始化调优环境
res = _tune_init(job) res = _tune_init(job)
if not res: if not res:
job.error("Tune init failed") job.error("Tune init failed")
# 创建锁文件
lock_file = os.path.join(job.content["SocInfo"]["op_debug_dir"], "kernel_meta", "file.lock") lock_file = os.path.join(job.content["SocInfo"]["op_debug_dir"], "kernel_meta", "file.lock")
local_lock = LocalLock(lock_file) local_lock = LocalLock(lock_file)
try: try:
# 加锁
local_lock.lock() local_lock.lock()
# 加载CANN知识库
res = _cann_kb_load(job) res = _cann_kb_load(job)
if res == 1: if res == 1:
job.error("Cann kb load failed") job.error("Cann kb load failed")
# 初始化并行编译
res = _parallel_compilation_init(job) res = _parallel_compilation_init(job)
if not res: if not res:
job.error("Parallel compilation failed") job.error("Parallel compilation failed")
except RuntimeError: except RuntimeError:
job.error("Initialize failed with RuntimeError") job.error("Initialize failed with RuntimeError")
finally: finally:
# 解锁
local_lock.unlock() local_lock.unlock()
job.result = "Success" job.result = "Success"
return res return res
@ -274,9 +356,13 @@ def tbe_initialize(job: TbeJob):
def get_auto_tune_support_op_list(job: TbeJob): def get_auto_tune_support_op_list(job: TbeJob):
""" """
Get GA tune supported op list 获取支持自动调优的算子列表
:param job:
:return: Args:
job (TbeJob): 包含任务信息的TbeJob对象
Returns:
list: 支持自动调优的算子列表
""" """
from auto_tune_main import enable_auto_tune_support from auto_tune_main import enable_auto_tune_support
auto_tune_op_list = enable_auto_tune_support() auto_tune_op_list = enable_auto_tune_support()
@ -286,10 +372,14 @@ def get_auto_tune_support_op_list(job: TbeJob):
def _normalize_module_name(module_name, py_module_path): def _normalize_module_name(module_name, py_module_path):
""" """
Normalize module name 规范化模块名称
:param module_name:
:param py_module_path: Args:
:return: module_name (str): 模块名称
py_module_path (str): Python模块路径
Returns:
None
""" """
if py_module_path not in sys.path: if py_module_path not in sys.path:
sys.path.insert(0, py_module_path) sys.path.insert(0, py_module_path)
@ -298,9 +388,13 @@ def _normalize_module_name(module_name, py_module_path):
def check_support(job: TbeJob): def check_support(job: TbeJob):
""" """
Check support 检查算子是否受支持
:param job:
:return: Args:
job (TbeJob): 包含任务信息的TbeJob对象
Returns:
bool: 算子是否受支持
""" """
op_compute_info_list = get_compute_op_list(job.content) op_compute_info_list = get_compute_op_list(job.content)
if len(op_compute_info_list) != 1: if len(op_compute_info_list) != 1:
@ -341,21 +435,37 @@ def check_support(job: TbeJob):
def select_op_format(job: TbeJob): def select_op_format(job: TbeJob):
""" """
Select op format Select op format
:param job: 根据计算操作信息选择操作的格式
:return:
Args:
job (TbeJob): 包含任务信息的TbeJob对象
Returns:
bool: 操作格式选择是否成功
""" """
# 获取计算操作列表
compute_op_info_list = get_compute_op_list(job.content) compute_op_info_list = get_compute_op_list(job.content)
# 检查计算操作数量是否为1
if len(compute_op_info_list) != 1: if len(compute_op_info_list) != 1:
job.error("Invalid op compute num ({}) in check_support".format(len(compute_op_info_list))) job.error("Invalid op compute num ({}) in check_support".format(len(compute_op_info_list)))
return False return False
# 获取第一个计算操作信息
compute_op_info = compute_op_info_list[0] compute_op_info = compute_op_info_list[0]
# 调整自定义操作信息
adjust_custom_op_info(compute_op_info) adjust_custom_op_info(compute_op_info)
# 组装操作参数
inputs, outputs, attrs = assemble_op_args(compute_op_info) inputs, outputs, attrs = assemble_op_args(compute_op_info)
# 获取操作模块名称
op_module_name = get_module_name(compute_op_info) op_module_name = get_module_name(compute_op_info)
# 获取Python模块路径
py_module_path = compute_op_info["py_module_path"] py_module_path = compute_op_info["py_module_path"]
# 规范化模块名称
_normalize_module_name(op_module_name, py_module_path) _normalize_module_name(op_module_name, py_module_path)
# 设置操作选择格式的函数名称
op_func_name = "op_select_format" op_func_name = "op_select_format"
# 调用操作函数选择格式
res = call_op_func((inputs, outputs, attrs), op_module_name, op_func_name) res = call_op_func((inputs, outputs, attrs), op_module_name, op_func_name)
# 设置操作格式选择结果
job.result = str(res) job.result = str(res)
return True return True
@ -363,15 +473,25 @@ def select_op_format(job: TbeJob):
def parallel_pre_compile_op(job: TbeJob): def parallel_pre_compile_op(job: TbeJob):
""" """
Parallel pre compile op Parallel pre compile op
:param job: 并行预编译操作
:return:
Args:
job (TbeJob): 包含任务信息的TbeJob对象
Returns:
bool: 预编译操作是否成功
""" """
# 获取计算操作列表
compute_op_info_list = get_compute_op_list(job.content) compute_op_info_list = get_compute_op_list(job.content)
# 检查计算操作数量是否为1
if len(compute_op_info_list) != 1: if len(compute_op_info_list) != 1:
job.error("Invalid op compute num ({}) in pre compile op".format(len(compute_op_info_list))) job.error("Invalid op compute num ({}) in pre compile op".format(len(compute_op_info_list)))
return False return False
# 获取第一个计算操作信息
compute_op_info = compute_op_info_list[0] compute_op_info = compute_op_info_list[0]
# 调整自定义操作信息
adjust_custom_op_info(compute_op_info) adjust_custom_op_info(compute_op_info)
# 预构建计算操作信息
_pre_build_compute_op_info(compute_op_info, job) _pre_build_compute_op_info(compute_op_info, job)
return True return True
@ -379,35 +499,60 @@ def parallel_pre_compile_op(job: TbeJob):
def _pre_build_compute_op_info(compute_op, job): def _pre_build_compute_op_info(compute_op, job):
""" """
Prebuild by compute op info Prebuild by compute op info
:param compute_op: 根据计算操作信息预构建操作
:param job:
:return: Args:
compute_op (dict): 计算操作信息
job (TbeJob): 包含任务信息的TbeJob对象
Returns:
None
""" """
# 获取L1缓存大小
l1_size = job.content["l1_size"] l1_size = job.content["l1_size"]
# 如果L1缓存大小不为-1则设置L1缓存信息
if l1_size != -1: if l1_size != -1:
set_L1_info("op_L1_space", -1) set_L1_info("op_L1_space", -1)
# 组装操作参数
inputs, outputs, attrs = assemble_op_args(compute_op, is_single_op_build=True) inputs, outputs, attrs = assemble_op_args(compute_op, is_single_op_build=True)
# 获取操作模块名称
op_module_name = get_module_name(compute_op) op_module_name = get_module_name(compute_op)
# 获取Python模块路径
py_module_path = compute_op["py_module_path"] py_module_path = compute_op["py_module_path"]
# 获取操作函数名称
op_func_name = compute_op["func_name"] op_func_name = compute_op["func_name"]
# 获取操作类型
op_type = compute_op["type"] op_type = compute_op["type"]
# 获取操作名称
op_name = compute_op["op_name"] op_name = compute_op["op_name"]
# 保存操作参数
save_op_params(op_name, "prebuild", (outputs, attrs)) save_op_params(op_name, "prebuild", (outputs, attrs))
l1_size = job.content["l1_size"] # 设置L1缓存信息
set_L1_info("op_L1_space", l1_size) set_L1_info("op_L1_space", l1_size)
# 规范化模块名称
_normalize_module_name(op_module_name, py_module_path) _normalize_module_name(op_module_name, py_module_path)
# 获取未知形状信息
unknown_shape = compute_op["unknown_shape"] unknown_shape = compute_op["unknown_shape"]
# 获取int64模式信息
int64_mode = compute_op["int64mode"] int64_mode = compute_op["int64mode"]
# 检查操作实现模式
res = check_op_impl_mode(op_module_name, op_func_name) res = check_op_impl_mode(op_module_name, op_func_name)
# 获取操作实现模式
op_impl_mode = job.content["SocInfo"]["op_impl_mode"] op_impl_mode = job.content["SocInfo"]["op_impl_mode"]
# 获取操作实现模式列表
op_impl_mode_list = job.content["SocInfo"]["op_impl_mode_list"] op_impl_mode_list = job.content["SocInfo"]["op_impl_mode_list"]
# 获取完整操作名称
op_full_name = job.content["full_name"] op_full_name = job.content["full_name"]
# 如果操作不支持实现模式,则发出警告
if not res: if not res:
if op_impl_mode_list: if op_impl_mode_list:
job.warning("The op {} do NOT support op_impl_mode, current op_impl_mode:{}".format(op_type, op_impl_mode)) job.warning("The op {} do NOT support op_impl_mode, current op_impl_mode:{}".format(op_type, op_impl_mode))
else: else:
# 否则,记录操作支持实现模式的信息
job.info("OpType {} support op_impl_mode, current op_impl_mode:{}".format(op_type, op_impl_mode)) job.info("OpType {} support op_impl_mode, current op_impl_mode:{}".format(op_type, op_impl_mode))
# 获取选项信息
options = get_options_info(job.content) options = get_options_info(job.content)
# 分派预构建任务
dispatch_prebuild_task(job.source_id, job.id, l1_size, op_module_name, op_full_name, dispatch_prebuild_task(job.source_id, job.id, l1_size, op_module_name, op_full_name,
op_type, op_func_name, unknown_shape, op_type, op_func_name, unknown_shape,
(inputs, outputs, attrs, options), int64_mode, unknown_shape, (inputs, outputs, attrs, options), int64_mode, unknown_shape,
@ -416,13 +561,22 @@ def _pre_build_compute_op_info(compute_op, job):
def get_prebuild_output(op_name): def get_prebuild_output(op_name):
""" """
get prebuild output Get prebuild output
:param op_name: 获取预构建输出
Args:
op_name (str): 操作名称
Returns:
dict: 预构建输出
""" """
# 将操作参数转换为JSON字符串
params_str = op_params_to_json(op_name) params_str = op_params_to_json(op_name)
try: try:
# 尝试解析JSON字符串
res = json.loads(params_str) res = json.loads(params_str)
except ValueError: except ValueError:
# 如果解析失败,则返回空字典
res = {} res = {}
finally: finally:
pass pass
@ -432,9 +586,15 @@ def get_prebuild_output(op_name):
def do_fuzz_build_tbe_op(job: TbeJob): def do_fuzz_build_tbe_op(job: TbeJob):
""" """
Fuzzy build op Fuzzy build op
:param job: 模糊构建操作
:return:
Args:
job (TbeJob): 包含任务信息的TbeJob对象
Returns:
bool: 模糊构建操作是否成功
""" """
# 设置操作结果为"NOT_CHANGED"
job.result = "NOT_CHANGED" job.result = "NOT_CHANGED"
return True return True
@ -442,9 +602,15 @@ def do_fuzz_build_tbe_op(job: TbeJob):
def _dump_fusion_op_info_to_json_file(job: TbeJob): def _dump_fusion_op_info_to_json_file(job: TbeJob):
""" """
Dump fusion op info to json file Dump fusion op info to json file
:param job: 将融合操作信息转储到JSON文件
:return:
Args:
job (TbeJob): 包含任务信息的TbeJob对象
Returns:
None
""" """
# 如果系统参数调试路径不为空,则转储融合操作信息
if not job.sys_para_debug_path or job.sys_para_debug_path == "\0": if not job.sys_para_debug_path or job.sys_para_debug_path == "\0":
return return
dump_fusion_json(json.dumps(job.content), job.sys_para_debug_path) dump_fusion_json(json.dumps(job.content), job.sys_para_debug_path)
@ -453,30 +619,55 @@ def _dump_fusion_op_info_to_json_file(job: TbeJob):
def build_single_pre_op(job: TbeJob): def build_single_pre_op(job: TbeJob):
""" """
Build single op Build single op
:param job: 构建单个操作的预处理过程
:return:
Args:
job (TbeJob): 包含任务信息的TbeJob对象
Returns:
bool: 构建过程是否成功
""" """
# 执行构建前的处理工作
before_build_process(job) before_build_process(job)
# 获取计算操作列表
compute_op_info_list = get_compute_op_list(job.content) compute_op_info_list = get_compute_op_list(job.content)
# 确保只有一个计算操作
if len(compute_op_info_list) != 1: if len(compute_op_info_list) != 1:
job.error("Invalid op compute num ({}) in build single op".format(len(compute_op_info_list))) job.error("Invalid op compute num ({}) in build single op".format(len(compute_op_info_list)))
return False return False
# 获取单个计算操作信息
compute_op_info = compute_op_info_list[0] compute_op_info = compute_op_info_list[0]
# 调整自定义操作信息
adjust_custom_op_info(compute_op_info) adjust_custom_op_info(compute_op_info)
# 组装操作的输入、输出和属性
inputs, outputs, attrs = assemble_op_args(compute_op_info, is_single_op_build=True) inputs, outputs, attrs = assemble_op_args(compute_op_info, is_single_op_build=True)
# 获取操作类型
op_type = compute_op_info["type"] op_type = compute_op_info["type"]
# 获取L1缓存大小
l1_size = job.content["l1_size"] l1_size = job.content["l1_size"]
# 获取操作模块名称
op_module_name = get_module_name(compute_op_info) op_module_name = get_module_name(compute_op_info)
# 获取操作内核名称
op_kernel_name = compute_op_info["op_name"] op_kernel_name = compute_op_info["op_name"]
# 获取Python模块路径
py_module_path = compute_op_info["py_module_path"] py_module_path = compute_op_info["py_module_path"]
# 获取完整操作名称
op_name = job.content["full_name"] op_name = job.content["full_name"]
# 获取操作函数名称
op_func_name = compute_op_info["func_name"] op_func_name = compute_op_info["func_name"]
# 规范化模块名称
_normalize_module_name(op_module_name, py_module_path) _normalize_module_name(op_module_name, py_module_path)
# 获取未知形状信息
unknown_shape = compute_op_info["unknown_shape"] unknown_shape = compute_op_info["unknown_shape"]
# 获取int64模式信息
int64_mode = compute_op_info["int64mode"] int64_mode = compute_op_info["int64mode"]
# 获取操作模式
op_pattern = compute_op_info["pattern"] op_pattern = compute_op_info["pattern"]
# 获取选项信息
options = get_options_info(job.content) options = get_options_info(job.content)
# 获取模糊构建信息
fuzz_build_info = get_fuzz_build_info(job.content) fuzz_build_info = get_fuzz_build_info(job.content)
# 分派单个操作编译任务
dispatch_single_op_compile_task(job.source_id, job.id, l1_size, op_module_name, op_name, op_type, op_func_name, dispatch_single_op_compile_task(job.source_id, job.id, l1_size, op_module_name, op_name, op_type, op_func_name,
op_kernel_name, unknown_shape, (inputs, outputs, attrs, options), int64_mode, op_kernel_name, unknown_shape, (inputs, outputs, attrs, options), int64_mode,
None, None, unknown_shape, op_pattern, None, None, unknown_shape, op_pattern,
@ -487,13 +678,22 @@ def build_single_pre_op(job: TbeJob):
def before_build_process(job: TbeJob): def before_build_process(job: TbeJob):
""" """
Processing before build Processing before build
:param job: 在构建前进行处理
:return:
Args:
job (TbeJob): 包含任务信息的TbeJob对象
Returns:
None
""" """
# 获取L1缓存大小并设置
l1_size = job.content["l1_size"] l1_size = job.content["l1_size"]
set_L1_info("op_L1_space", l1_size) set_L1_info("op_L1_space", l1_size)
# 将融合操作信息转储到JSON文件
_dump_fusion_op_info_to_json_file(job) _dump_fusion_op_info_to_json_file(job)
# 获取是否需要离线调优
offline_tune = job.sys_offline_tune offline_tune = job.sys_offline_tune
# 如果需要离线调优则将融合操作信息转储到JSON文件
if offline_tune: if offline_tune:
dump_fusion_json(json.dumps(job.content), job.sys_tune_dump_path) dump_fusion_json(json.dumps(job.content), job.sys_tune_dump_path)
@ -501,20 +701,29 @@ def before_build_process(job: TbeJob):
def sync_fusion_env(fusion_need_sync, module_list): def sync_fusion_env(fusion_need_sync, module_list):
""" """
Sync fusion env Sync fusion env
:param fusion_need_sync: 同步融合环境
:param module_list:
:return: Args:
fusion_need_sync (int): 是否需要同步融合环境
module_list (dict): 模块列表
Returns:
bool: 同步是否成功
""" """
# 如果不需要同步,则直接返回成功
if fusion_need_sync == 0: if fusion_need_sync == 0:
return True return True
# 准备使用的模块列表
module_using = [] module_using = []
for key, value in module_list.items(): for key, value in module_list.items():
if value > 0: if value > 0:
module_using.append(str(key)) module_using.append(str(key))
module_list[key] = 0 module_list[key] = 0
# 将使用的模块列表转换为字符串
module_str = ",".join(module_using) module_str = ",".join(module_using)
# 导入使用的模块
import_py_module(module_str) import_py_module(module_str)
return True return True
@ -522,13 +731,23 @@ def sync_fusion_env(fusion_need_sync, module_list):
def parallel_compile_fusion_op(job: TbeJob): def parallel_compile_fusion_op(job: TbeJob):
""" """
Compile fusion op in parallel compiler Compile fusion op in parallel compiler
:param job: 在并行编译器中编译融合操作
:return:
Args:
job (TbeJob): 包含任务信息的TbeJob对象
Returns:
bool: 编译过程是否成功
""" """
# 获取L1缓存大小
l1_size = job.content["l1_size"] l1_size = job.content["l1_size"]
# 获取选项信息
options = get_options_info(job.content) options = get_options_info(job.content)
# 获取融合操作内核名称
op_kernel_name = job.content["fusion_op_name"] op_kernel_name = job.content["fusion_op_name"]
# 获取完整操作名称
op_name = job.content["full_name"] op_name = job.content["full_name"]
# 分派融合操作编译任务
dispatch_fusion_op_compile_task(job.source_id, job.id, l1_size, json.dumps(job.content), op_kernel_name, None, None, dispatch_fusion_op_compile_task(job.source_id, job.id, l1_size, json.dumps(job.content), op_kernel_name, None, None,
options, None, job.pass_list, op_name) options, None, job.pass_list, op_name)
return True return True
@ -537,112 +756,185 @@ def parallel_compile_fusion_op(job: TbeJob):
def ga_tune(job: TbeJob): def ga_tune(job: TbeJob):
""" """
GA tune GA tune
:param job: 使用遗传算法进行调优
:return:
Args:
job (TbeJob): 包含任务信息的TbeJob对象
Returns:
bool: 调优过程是否成功
""" """
# 获取L1缓存大小
l1_size = job.content["l1_size"] l1_size = job.content["l1_size"]
# 获取融合操作内核名称
op_kernel_name = job.content["fusion_op_name"] op_kernel_name = job.content["fusion_op_name"]
# 获取完整操作名称
op_name = job.content["full_name"] op_name = job.content["full_name"]
# 分派自动调优任务
dispatch_autotune_task(job.source_id, job.id, l1_size, json.dumps(job.content), {}, op_kernel_name, op_name) dispatch_autotune_task(job.source_id, job.id, l1_size, json.dumps(job.content), {}, op_kernel_name, op_name)
# 设置任务状态为运行中
job.status = JobStatus.JOB_RUNNING job.status = JobStatus.JOB_RUNNING
return True return True
def rl_tune_single_op(job: TbeJob): def rl_tune_single_op(job: TbeJob):
""" """
RL tune single op Perform RL (Reinforcement Learning) tuning for a single operation.
:param job:
:return: This function is responsible for tuning a single operation using RL techniques.
It retrieves the operation's information, performs the tuning, and handles any exceptions that may occur during the process.
Args:
job (TbeJob): An object containing job information, including the operation to be tuned.
Returns:
bool: True if the RL tuning is successful, False otherwise.
""" """
# Retrieve the list of compute operations from the job content
compute_op_info_list = get_compute_op_list(job.content) compute_op_info_list = get_compute_op_list(job.content)
# Check if there is exactly one compute operation
if len(compute_op_info_list) != 1: if len(compute_op_info_list) != 1:
job.error("Invalid op compute num ({}) in rl tune single op".format(len(compute_op_info_list))) job.error("Invalid op compute num ({}) in rl tune single op".format(len(compute_op_info_list)))
return False return False
# Get the first (and only) compute operation info
compute_op_info = compute_op_info_list[0] compute_op_info = compute_op_info_list[0]
# Assemble the operation's input, output, and attributes
inputs, outputs, attrs = assemble_op_args(compute_op_info) inputs, outputs, attrs = assemble_op_args(compute_op_info)
# Get the operation type
op_type = compute_op_info["type"] op_type = compute_op_info["type"]
# Get the L1 size from the job content
l1_size = job.content["l1_size"] l1_size = job.content["l1_size"]
# Get the operation module name
op_module_name = get_module_name(compute_op_info) op_module_name = get_module_name(compute_op_info)
# Get the operation kernel name
op_kernel_name = compute_op_info["op_name"] op_kernel_name = compute_op_info["op_name"]
# Get the full name of the operation
full_name = compute_op_info["name"] full_name = compute_op_info["name"]
# Get the Python module path
py_module_path = compute_op_info["py_module_path"] py_module_path = compute_op_info["py_module_path"]
# Get the operation function name
op_func_name = compute_op_info["func_name"] op_func_name = compute_op_info["func_name"]
# Normalize the module name
_normalize_module_name(op_module_name, py_module_path) _normalize_module_name(op_module_name, py_module_path)
# Set the current operation name
set_current_op_name(op_kernel_name) set_current_op_name(op_kernel_name)
# Get the unknown shape information
unknown_shape = compute_op_info["unknown_shape"] unknown_shape = compute_op_info["unknown_shape"]
# Get the int64 mode information
int64_mode = compute_op_info["int64mode"] int64_mode = compute_op_info["int64mode"]
# Get the operation pattern
op_pattern = compute_op_info["pattern"] op_pattern = compute_op_info["pattern"]
# Get the fuzz build information
fuzz_build_info = get_fuzz_build_info(job.content) fuzz_build_info = get_fuzz_build_info(job.content)
# Get the auto tiling mode
auto_tiling_mode = job.content["SocInfo"]["autoTilingMode"] auto_tiling_mode = job.content["SocInfo"]["autoTilingMode"]
# Get the device ID
device_id = job.content["SocInfo"]["deviceId"] device_id = job.content["SocInfo"]["deviceId"]
# Get the options information
options = get_options_info(job.content) options = get_options_info(job.content)
try: try:
# Build the single operation from C code
build_single_op_from_c(op_module_name, op_func_name, op_type, "build", unknown_shape, build_single_op_from_c(op_module_name, op_func_name, op_type, "build", unknown_shape,
(inputs, outputs, attrs), int64_mode, unknown_shape, options, (inputs, outputs, attrs), int64_mode, unknown_shape, options,
op_pattern, auto_tiling_mode, device_id, json.dumps(fuzz_build_info)) op_pattern, auto_tiling_mode, device_id, json.dumps(fuzz_build_info))
# pylint: disable=broad-except
except Exception: except Exception:
# If an exception occurs, log the error and return False
job.error( job.error(
"Single op {} build failed, no need to do rl tune, json string:{}".format(op_kernel_name, job.json_string)) "Single op {} build failed, no need to do rl tune, json string:{}".format(op_kernel_name, job.json_string))
exc_type, exc_value, _ = sys.exc_info() exc_type, exc_value, _ = sys.exc_info()
job.error( job.error(
"exc_type:{}, exc_value:{}, exc_traceback:{}".format(exc_type, exc_value, traceback.format_exc())) "exc_type:{}, exc_value:{}, exc_traceback:{}".format(exc_type, exc_value, traceback.format_exc()))
return False return False
finally: # Prepare the tuning operation module name
pass
tune_op_module_name = op_module_name + "@" + py_module_path tune_op_module_name = op_module_name + "@" + py_module_path
# Get the base kernel path
base_kernel = job.content["SocInfo"]["op_debug_dir"] + "/kernel_meta/" + op_kernel_name + ".o" base_kernel = job.content["SocInfo"]["op_debug_dir"] + "/kernel_meta/" + op_kernel_name + ".o"
# Dispatch the single tune task
from schedule_search.rl_online_tune import dispatch_single_tune_task from schedule_search.rl_online_tune import dispatch_single_tune_task
pack_args = pack_op_args(inputs, outputs, attrs) pack_args = pack_op_args(inputs, outputs, attrs)
res = dispatch_single_tune_task(job.source_id, job.id, l1_size, base_kernel, op_kernel_name, full_name, res = dispatch_single_tune_task(job.source_id, job.id, l1_size, base_kernel, op_kernel_name, full_name,
tune_op_module_name, op_func_name, op_type, pack_args) tune_op_module_name, op_func_name, op_type, pack_args)
# Process the RL tune result
return _process_rl_tune_result(job, op_type, res) return _process_rl_tune_result(job, op_type, res)
def rl_tune_fusion_op(job: TbeJob): def rl_tune_fusion_op(job: TbeJob):
""" """
rl tune fusion op Perform RL tuning for a fusion operation.
:param job:
:return: This function is responsible for tuning a fusion operation using RL techniques.
It compiles the operation using multiprocessing and handles any exceptions that may occur during the process.
Args:
job (TbeJob): An object containing job information, including the fusion operation to be tuned.
Returns:
bool: True if the RL tuning is successful, False otherwise.
""" """
# Get the fusion operation kernel name
op_kernel_name = job.content["fusion_op_name"] op_kernel_name = job.content["fusion_op_name"]
# Set the current operation name
set_current_op_name(op_kernel_name) set_current_op_name(op_kernel_name)
try: try:
# Compile the operation using multiprocessing
from schedule_search.rl_online_tune import compile_op_by_mp from schedule_search.rl_online_tune import compile_op_by_mp
compile_op_by_mp(json.dumps(job.content)) compile_op_by_mp(json.dumps(job.content))
# pylint: disable=broad-except # pylint: disable=broad-except
except Exception: except Exception:
# If an exception occurs, log the error and return False
job.error( job.error(
"Fusion op {} build failed, no need to do rl tune, json string:{}".format(op_kernel_name, job.json_string)) "Fusion op {} build failed, no need to do rl tune, json string:{}".format(op_kernel_name, job.json_string))
exc_type, exc_value, _ = sys.exc_info() exc_type, exc_value, _ = sys.exc_info()
job.error( job.error(
"exc_type:{}, exc_value:{}, exc_traceback:{}".format(exc_type, exc_value, traceback.format_exc())) "exc_type:{}, exc_value:{}, exc_traceback:{}".format(exc_type, exc_value, traceback.format_exc()))
return False return False
finally: # Get the L1 size
pass
l1_size = job.content["l1_size"] l1_size = job.content["l1_size"]
# Get the base kernel path
base_kernel = job.content["SocInfo"]["op_debug_dir"] + "/kernel_meta/" + op_kernel_name + ".o" base_kernel = job.content["SocInfo"]["op_debug_dir"] + "/kernel_meta/" + op_kernel_name + ".o"
# Get the list of compute operations
compute_op_list = get_compute_op_list(job.content) compute_op_list = get_compute_op_list(job.content)
# Prepare the operation module names string
op_module_names_str = "" op_module_names_str = ""
op_type_set = set() op_type_set = set()
for op in compute_op_list: for op in compute_op_list:
op_module_names_str = ','.join([op_module_names_str, get_module_name(op)]) op_module_names_str = ','.join([op_module_names_str, get_module_name(op)])
op_type_set.add(op["type"]) op_type_set.add(op["type"])
# Remove the leading comma from the operation module names string
op_module_names_str = op_module_names_str[1:] op_module_names_str = op_module_names_str[1:]
# Join the operation types with double underscore
op_type = "__".join(list(op_type_set)) op_type = "__".join(list(op_type_set))
# Dispatch the fusion tune task
from schedule_search.rl_online_tune import dispatch_fusion_tune_task from schedule_search.rl_online_tune import dispatch_fusion_tune_task
res = dispatch_fusion_tune_task(job.source_id, job.id, l1_size, base_kernel, op_kernel_name, op_module_names_str, res = dispatch_fusion_tune_task(job.source_id, job.id, l1_size, base_kernel, op_kernel_name, op_module_names_str,
json.dumps(job.content)) json.dumps(job.content))
# Process the RL tune result
return _process_rl_tune_result(job, op_type, res) return _process_rl_tune_result(job, op_type, res)
def _process_rl_tune_result(job, op_type, res): def _process_rl_tune_result(job, op_type, res):
"""
Process the result of RL tuning.
If the tuning result is False, it checks if the operation type is in the black list or if the job is set to offline tune.
If the tuning result is True, it sets the job status to running.
Args:
job (TbeJob): An object containing job information.
op_type (str): The type of the operation.
res (bool): The result of RL tuning.
Returns:
bool: The processed result of RL tuning.
"""
if not res: if not res:
# Check if the operation type is in the black list or if the job is set to offline tune
from schedule_search.tune_util import filter_black_op_type from schedule_search.tune_util import filter_black_op_type
res = bool(job.sys_offline_tune or os.getenv("REPEAT_TUNE", "False").lower() != "true" or filter_black_op_type( res = bool(job.sys_offline_tune or os.getenv("REPEAT_TUNE", "False").lower() != "true" or filter_black_op_type(
op_type)) op_type))
else: else:
# Set the job status to running
job.status = JobStatus.JOB_RUNNING job.status = JobStatus.JOB_RUNNING
res = True res = True
return res return res
@ -650,8 +942,13 @@ def _process_rl_tune_result(job, op_type, res):
def get_finish_tasks(source_id): def get_finish_tasks(source_id):
""" """
Get finish task from parallel compilation framework Get the list of finished tasks from the parallel compilation framework.
:return task info list
Args:
source_id (int): The source ID of the tasks.
Returns:
list: A list of finished task information.
""" """
return get_finished_compilation_task(source_id) return get_finished_compilation_task(source_id)
@ -664,14 +961,21 @@ def tbe_finalize(auto_tiling_mode, offline_tune, job: TbeJob):
:param job: TbeJob :param job: TbeJob
:return: None :return: None
""" """
# 释放多进程环境
deinit_multi_process_env() deinit_multi_process_env()
# 如果自动切分模式为RL或者离线调优则释放RL调优
if "RL" in auto_tiling_mode or offline_tune: if "RL" in auto_tiling_mode or offline_tune:
from schedule_search.rl_online_tune import rl_tune_deinit from schedule_search.rl_online_tune import rl_tune_deinit
rl_tune_deinit() rl_tune_deinit()
# 卸载Cann kb
res = _cann_kb_unload(job) res = _cann_kb_unload(job)
# 如果卸载失败则返回False
if res == 1: if res == 1:
job.error("Cann kb unload failed") job.error("Cann kb unload failed")
return False return False
# 清除融合参数
clear_fusion_params() clear_fusion_params()
# 删除缓存
_remove_cache(job) _remove_cache(job)
# 返回True
return True return True

@ -26,6 +26,7 @@ class BuildType(Enum):
ACCURATELY = "accurately" ACCURATELY = "accurately"
# 获取JobType枚举类中的所有值
job_type_list = [job_type.value for _, job_type in JobType.__members__.items()] job_type_list = [job_type.value for _, job_type in JobType.__members__.items()]
@ -35,14 +36,19 @@ def check_job_json(job_info):
:param job_info:tne compilation job json :param job_info:tne compilation job json
:return: raise value error if wrong :return: raise value error if wrong
""" """
# 检查job_info中是否包含source_id
if 'source_id' not in job_info: if 'source_id' not in job_info:
raise ValueError("Json string Errors, key:source_id not found.") raise ValueError("Json string Errors, key:source_id not found.")
# 检查job_info中是否包含job_id
if 'job_id' not in job_info: if 'job_id' not in job_info:
raise ValueError("Json string Errors, key:job_id not found.") raise ValueError("Json string Errors, key:job_id not found.")
# 检查job_info中是否包含job_type
if 'job_type' not in job_info or not job_info['job_type']: if 'job_type' not in job_info or not job_info['job_type']:
raise ValueError("Json string Errors, key:job_type not found.") raise ValueError("Json string Errors, key:job_type not found.")
# 检查job_info中job_type是否在job_type_list中
if job_info['job_type'] not in job_type_list: if job_info['job_type'] not in job_type_list:
raise ValueError("Invalid job type: {}.".format(job_info['job_type'])) raise ValueError("Invalid job type: {}.".format(job_info['job_type']))
# 检查job_info中是否包含job_content
if 'job_content' not in job_info: if 'job_content' not in job_info:
raise ValueError("Json string Errors, key:job_content not found.") raise ValueError("Json string Errors, key:job_content not found.")
@ -52,6 +58,7 @@ def reset_op_debug_level_in_soc_info(level):
:param level: op_debug_level, if level is 3 or 4, replace it with 0 :param level: op_debug_level, if level is 3 or 4, replace it with 0
:return: op_debug_level :return: op_debug_level
""" """
# 如果level为3或4则将其替换为0
if level in ("3", "4"): if level in ("3", "4"):
level = "0" level = "0"
return level return level
@ -62,6 +69,7 @@ def get_real_op_debug_level(initialize_job_info):
:param initialize_job_info: initialize_job_info :param initialize_job_info: initialize_job_info
:return: origin op_debug_level for init_multi_process_env :return: origin op_debug_level for init_multi_process_env
""" """
# 返回initialize_job_info中op_debug_level的值
return initialize_job_info["SocInfo"]["op_debug_level"] return initialize_job_info["SocInfo"]["op_debug_level"]
@ -72,21 +80,35 @@ def get_soc_info(initialize_job_info):
:return: soc info :return: soc info
""" """
soc_param = dict() soc_param = dict()
# 获取soc_info中的op_impl_mode
soc_param["op_impl_mode"] = initialize_job_info["SocInfo"]["op_impl_mode"] soc_param["op_impl_mode"] = initialize_job_info["SocInfo"]["op_impl_mode"]
# 获取soc_info中的op_debug_level并调用reset_op_debug_level_in_soc_info函数进行处理
soc_param["op_debug_level"] = reset_op_debug_level_in_soc_info(initialize_job_info["SocInfo"]["op_debug_level"]) soc_param["op_debug_level"] = reset_op_debug_level_in_soc_info(initialize_job_info["SocInfo"]["op_debug_level"])
# 获取soc_info中的op_impl_mode_list
soc_param["op_impl_mode_list"] = initialize_job_info["SocInfo"]["op_impl_mode_list"] soc_param["op_impl_mode_list"] = initialize_job_info["SocInfo"]["op_impl_mode_list"]
# 获取soc_info中的op_debug_dir
soc_param["op_debug_dir"] = initialize_job_info["SocInfo"]["op_debug_dir"] soc_param["op_debug_dir"] = initialize_job_info["SocInfo"]["op_debug_dir"]
# 获取soc_info中的vector_fp_ceiling
soc_param["vector_fp_ceiling"] = initialize_job_info["SocInfo"]["vector_fp_ceiling"] soc_param["vector_fp_ceiling"] = initialize_job_info["SocInfo"]["vector_fp_ceiling"]
# 获取soc_info中的mdl_bank_path
soc_param['mdl_bank_path'] = initialize_job_info["SocInfo"]["mdl_bank_path"] soc_param['mdl_bank_path'] = initialize_job_info["SocInfo"]["mdl_bank_path"]
# 获取soc_info中的op_bank_path
soc_param['op_bank_path'] = initialize_job_info["SocInfo"]["op_bank_path"] soc_param['op_bank_path'] = initialize_job_info["SocInfo"]["op_bank_path"]
soc_info = list() soc_info = list()
# 获取soc_info中的socVersion
soc_info.append(initialize_job_info["SocInfo"]["socVersion"]) soc_info.append(initialize_job_info["SocInfo"]["socVersion"])
# 获取soc_info中的coreType
soc_info.append(initialize_job_info["SocInfo"]["coreType"]) soc_info.append(initialize_job_info["SocInfo"]["coreType"])
# 获取soc_info中的coreNum
soc_info.append(initialize_job_info["SocInfo"]["coreNum"]) soc_info.append(initialize_job_info["SocInfo"]["coreNum"])
# 获取soc_info中的l1Fusion
soc_info.append(initialize_job_info["SocInfo"]["l1Fusion"]) soc_info.append(initialize_job_info["SocInfo"]["l1Fusion"])
# 获取soc_info中的l2Mode
soc_info.append(initialize_job_info["SocInfo"]["l2Mode"]) soc_info.append(initialize_job_info["SocInfo"]["l2Mode"])
# 获取soc_info中的l2Fusion
soc_info.append(initialize_job_info["SocInfo"]["l2Fusion"]) soc_info.append(initialize_job_info["SocInfo"]["l2Fusion"])
# 将soc_param添加到soc_info中
soc_info.append(soc_param) soc_info.append(soc_param)
return soc_info return soc_info
@ -98,16 +120,22 @@ def check_arg_info(io_info):
:param io_info:A dict, to be checked. :param io_info:A dict, to be checked.
:return: Exception: If specific keyword is not found. :return: Exception: If specific keyword is not found.
""" """
# 检查io_info中是否包含shape
if 'shape' not in io_info: if 'shape' not in io_info:
raise ValueError("Json string Errors, key:shape not found.") raise ValueError("Json string Errors, key:shape not found.")
# 检查io_info中是否包含ori_shape
if 'ori_shape' not in io_info: if 'ori_shape' not in io_info:
raise ValueError("Json string Errors, key:ori_shape not found.") raise ValueError("Json string Errors, key:ori_shape not found.")
# 检查io_info中是否包含format
if 'format' not in io_info or not io_info['format']: if 'format' not in io_info or not io_info['format']:
raise ValueError("Json string Errors, key:format not found.") raise ValueError("Json string Errors, key:format not found.")
# 检查io_info中是否包含ori_format
if 'ori_format' not in io_info or not io_info['ori_format']: if 'ori_format' not in io_info or not io_info['ori_format']:
raise ValueError("Json string Errors, key:ori_format not found.") raise ValueError("Json string Errors, key:ori_format not found.")
# 检查io_info中是否包含dtype
if 'dtype' not in io_info or not io_info['dtype']: if 'dtype' not in io_info or not io_info['dtype']:
raise ValueError("Json string Errors, key:dtype not found.") raise ValueError("Json string Errors, key:dtype not found.")
# 检查io_info中是否包含param_type
if 'param_type' not in io_info or not io_info['param_type']: if 'param_type' not in io_info or not io_info['param_type']:
raise ValueError("Json string Errors, key:param_type not found.") raise ValueError("Json string Errors, key:param_type not found.")
@ -119,18 +147,28 @@ def get_input_output_args(io_info):
:return:input/output args :return:input/output args
""" """
args = [] args = []
# 如果io_info为空则返回空列表
if io_info is None: if io_info is None:
return args return args
# 遍历io_info中的每个元素
for item in io_info: for item in io_info:
# 如果元素是字典类型
if isinstance(item, dict): if isinstance(item, dict):
# 调用get_single_io_arg函数获取单个输入/输出参数
arg = get_single_io_arg(item) arg = get_single_io_arg(item)
args.append(arg) args.append(arg)
elif isinstance(item, list): elif isinstance(item, list):
# 如果元素是列表类型
dyn_arg = [] dyn_arg = []
# 创建一个空列表dyn_arg
for info in item: for info in item:
# 遍历列表中的每个元素
arg = get_single_io_arg(info) arg = get_single_io_arg(info)
# 调用get_single_io_arg函数获取单个输入/输出参数
dyn_arg.append(arg) dyn_arg.append(arg)
# 将参数添加到dyn_arg列表中
args.append(tuple(dyn_arg)) args.append(tuple(dyn_arg))
# 将dyn_arg列表添加到args列表中
return args return args
@ -142,19 +180,30 @@ def get_single_io_arg(info):
""" """
if 'valid' not in info: if 'valid' not in info:
raise ValueError("Json string Errors, key:valid not found.") raise ValueError("Json string Errors, key:valid not found.")
# 检查info中是否包含valid
if info['valid']: if info['valid']:
check_arg_info(info) check_arg_info(info)
# 如果valid为True
del info['valid'] del info['valid']
# 调用check_arg_info函数检查参数的有效性
del info['name'] del info['name']
# 删除info中的valid和name键值对
if 'range' in info: if 'range' in info:
for i in range(len(info['range'])): for i in range(len(info['range'])):
# 如果info中包含range
if info['range'][i][1] == -1: if info['range'][i][1] == -1:
# 遍历range中的每个元素
info['range'][i][1] = None info['range'][i][1] = None
# 如果range中的元素值为-1则将其替换为None
res = info res = info
else: else:
# 将info赋值给res
res = None res = None
# 如果valid为False
return res return res
# 将res赋值为None
# 返回res
def assemble_op_args(compute_op_info, is_single_op_build=False): def assemble_op_args(compute_op_info, is_single_op_build=False):
""" """
@ -165,20 +214,32 @@ def assemble_op_args(compute_op_info, is_single_op_build=False):
""" """
inputs_info = compute_op_info["input_desc"] if "input_desc" in compute_op_info.keys() else None inputs_info = compute_op_info["input_desc"] if "input_desc" in compute_op_info.keys() else None
outputs_info = compute_op_info["output_desc"] if "output_desc" in compute_op_info.keys() else None outputs_info = compute_op_info["output_desc"] if "output_desc" in compute_op_info.keys() else None
# 如果compute_op_info中包含input_desc则将其赋值给inputs_info
if is_single_op_build: if is_single_op_build:
# 如果compute_op_info中包含output_desc则将其赋值给outputs_info
attrs = [] attrs = []
# 如果is_single_op_build为True
attrs_info = compute_op_info["attrs"] if "attrs" in compute_op_info.keys() else [] attrs_info = compute_op_info["attrs"] if "attrs" in compute_op_info.keys() else []
# 创建一个空列表attrs
for item in attrs_info: for item in attrs_info:
# 如果compute_op_info中包含attrs则将其赋值给attrs_info
if item["valid"] and item["name"] != "isRef": if item["valid"] and item["name"] != "isRef":
# 遍历attrs_info中的每个元素
attrs.append(item) attrs.append(item)
# 如果元素的valid为True且name不为isRef则将其添加到attrs列表中
else: else:
attrs = compute_op_info["attr_desc"] if "attr_desc" in compute_op_info.keys() else [] attrs = compute_op_info["attr_desc"] if "attr_desc" in compute_op_info.keys() else []
inputs = get_input_output_args(inputs_info) inputs = get_input_output_args(inputs_info)
outputs = get_input_output_args(outputs_info) outputs = get_input_output_args(outputs_info)
# 如果compute_op_info中包含attr_desc则将其赋值给attrs
attrs.append(compute_op_info["op_name"]) attrs.append(compute_op_info["op_name"])
# 调用get_output_args函数获取输入参数
return inputs, outputs, attrs return inputs, outputs, attrs
# 调用get_input_output_args函数获取输出参数
# 将compute_op_info中的op_name添加到attrs列表中
# 返回inputs、outputs、attrs
def get_compute_op_list(job_content): def get_compute_op_list(job_content):
""" """
Get compute op info list from job content info Get compute op info list from job content info
@ -188,12 +249,16 @@ def get_compute_op_list(job_content):
op_list = job_content["op_list"] op_list = job_content["op_list"]
op_compute_list = [] op_compute_list = []
for op in op_list: for op in op_list:
# 获取job_content中的op_list
if op["type"] != "Data": if op["type"] != "Data":
# 创建一个空列表op_compute_list
op_compute_list.append(op) op_compute_list.append(op)
return op_compute_list return op_compute_list
# 如果元素的typeData则将其添加到op_compute_list列表中
def get_options_info(job_content): def get_options_info(job_content):
# 返回op_compute_list列表
""" """
Get options info Get options info
:param job_content: :param job_content:
@ -203,17 +268,29 @@ def get_options_info(job_content):
options["socVersion"] = job_content["SocInfo"]["socVersion"] options["socVersion"] = job_content["SocInfo"]["socVersion"]
options["coreType"] = job_content["SocInfo"]["coreType"] options["coreType"] = job_content["SocInfo"]["coreType"]
options["coreNum"] = job_content["SocInfo"]["coreNum"] options["coreNum"] = job_content["SocInfo"]["coreNum"]
# 创建一个空字典options
options["l1Fusion"] = job_content["SocInfo"]["l1Fusion"] options["l1Fusion"] = job_content["SocInfo"]["l1Fusion"]
# 获取job_content中的socVersion
options["l2Fusion"] = job_content["SocInfo"]["l2Fusion"] options["l2Fusion"] = job_content["SocInfo"]["l2Fusion"]
# 获取job_content中的coreType
options["l2Mode"] = job_content["SocInfo"]["l2Mode"] options["l2Mode"] = job_content["SocInfo"]["l2Mode"]
# 获取job_content中的coreNum
options["op_debug_level"] = reset_op_debug_level_in_soc_info(job_content["SocInfo"]["op_debug_level"]) options["op_debug_level"] = reset_op_debug_level_in_soc_info(job_content["SocInfo"]["op_debug_level"])
# 获取job_content中的l1Fusion
options["op_impl_mode"] = job_content["SocInfo"]["op_impl_mode"] options["op_impl_mode"] = job_content["SocInfo"]["op_impl_mode"]
# 获取job_content中的l2Fusion
options["op_debug_dir"] = job_content["SocInfo"]["op_debug_dir"] options["op_debug_dir"] = job_content["SocInfo"]["op_debug_dir"]
# 获取job_content中的l2Mode
options["mdl_bank_path"] = job_content["SocInfo"]["mdl_bank_path"] options["mdl_bank_path"] = job_content["SocInfo"]["mdl_bank_path"]
# 获取job_content中的op_debug_level并调用reset_op_debug_level_in_soc_info函数进行处理
options["op_bank_path"] = job_content["SocInfo"]["op_bank_path"] options["op_bank_path"] = job_content["SocInfo"]["op_bank_path"]
# 获取job_content中的op_impl_mode
options["deviceId"] = job_content["SocInfo"]["deviceId"] options["deviceId"] = job_content["SocInfo"]["deviceId"]
# 从job_content中获取deviceId并将其赋值给options字典的deviceId键
options["autoTilingMode"] = job_content["SocInfo"]["autoTilingMode"] options["autoTilingMode"] = job_content["SocInfo"]["autoTilingMode"]
# 从job_content中获取autoTilingMode并将其赋值给options字典的autoTilingMode键
options["op_impl_mode_list"] = job_content["SocInfo"]["op_impl_mode_list"] options["op_impl_mode_list"] = job_content["SocInfo"]["op_impl_mode_list"]
# 从job_content中获取op_impl_mode_list并将其赋值给options字典的op_impl_mode_list键
return options return options
@ -223,15 +300,22 @@ def get_fuzz_build_info(job_content):
:param job_content: job content info :param job_content: job content info
:return: fuzz build info :return: fuzz build info
""" """
# 从job_content中获取计算操作列表
op_compute_info = get_compute_op_list(job_content)[0] op_compute_info = get_compute_op_list(job_content)[0]
# 初始化fuzz_build_info字典
fuzz_build_info = dict() fuzz_build_info = dict()
# 根据op_compute_info中的build_type判断编译类型
fuzz_build_info["compile_type"] = "fuzzily_build" if op_compute_info["build_type"] == BuildType.FUZZILY.value \ fuzz_build_info["compile_type"] = "fuzzily_build" if op_compute_info["build_type"] == BuildType.FUZZILY.value \
else "accurately_build" else "accurately_build"
# 获取miss_support_info
fuzz_build_info["miss_support_info"] = op_compute_info["miss_support_info"] fuzz_build_info["miss_support_info"] = op_compute_info["miss_support_info"]
# 获取max_kernel_id
fuzz_build_info["max_kernel_id"] = op_compute_info["max_kernel_id"] fuzz_build_info["max_kernel_id"] = op_compute_info["max_kernel_id"]
# 如果build_type为FUZZILY则获取incremental_link
fuzz_build_info["incremental_link"] = os.path.realpath( fuzz_build_info["incremental_link"] = os.path.realpath(
job_content["SocInfo"]["op_debug_dir"] + "/kernel_meta/" + op_compute_info["name"] + ".json") if \ job_content["SocInfo"]["op_debug_dir"] + "/kernel_meta/" + op_compute_info["name"] + ".json") if \
op_compute_info["build_type"] == BuildType.FUZZILY.value else "" op_compute_info["build_type"] == BuildType.FUZZILY.value else ""
# 返回fuzz_build_info
return fuzz_build_info return fuzz_build_info
@ -241,10 +325,14 @@ def get_func_names(job_content):
:param job_content: job content info :param job_content: job content info
:return: function names :return: function names
""" """
# 初始化func_names列表
func_names = [] func_names = []
# 遍历job_content中的op_list
for op in job_content["op_list"]: for op in job_content["op_list"]:
# 如果op中包含func_name则将其添加到func_names列表中
if "func_name" in op: if "func_name" in op:
func_names.append(op["func_name"]) func_names.append(op["func_name"])
# 返回func_names
return func_names return func_names
@ -254,12 +342,16 @@ def get_module_name(compute_op_info):
:param compute_op_info: :param compute_op_info:
:return: :return:
""" """
# 获取compute_op_info中的dynamic_compile_static和unknown_shape
dynamic_compile_static = compute_op_info["dynamic_compile_static"] dynamic_compile_static = compute_op_info["dynamic_compile_static"]
unknown_shape = compute_op_info["unknown_shape"] unknown_shape = compute_op_info["unknown_shape"]
# 获取compute_op_info中的module_name
op_module_name = compute_op_info["module_name"] op_module_name = compute_op_info["module_name"]
# 如果dynamic_compile_static或unknown_shape为True则将module_name中的第一个和最后一个"."之间的字符串替换为".dynamic."
if dynamic_compile_static or unknown_shape: if dynamic_compile_static or unknown_shape:
d = ".dynamic." d = ".dynamic."
op_module_name = d.join((op_module_name.split(".")[0], op_module_name.split(".")[-1])) op_module_name = d.join((op_module_name.split(".")[0], op_module_name.split(".")[-1]))
# 返回替换后的module_name
return op_module_name return op_module_name
@ -269,10 +361,14 @@ def adjust_custom_op_info(compute_op_info):
:param compute_op_info: :param compute_op_info:
:return: :return:
""" """
# 获取compute_op_info中的py_module_path
py_module_path = compute_op_info["py_module_path"] py_module_path = compute_op_info["py_module_path"]
# 如果py_module_path是一个文件则获取其路径和文件名
if os.path.isfile(py_module_path): if os.path.isfile(py_module_path):
py_module_path, file_name = os.path.split(py_module_path) py_module_path, file_name = os.path.split(py_module_path)
# 获取文件名中的模块名
module_name, _ = os.path.splitext(file_name) module_name, _ = os.path.splitext(file_name)
# 将py_module_path和module_name更新到compute_op_info中
compute_op_info["py_module_path"] = py_module_path compute_op_info["py_module_path"] = py_module_path
compute_op_info["module_name"] = module_name compute_op_info["module_name"] = module_name
@ -281,5 +377,6 @@ def pack_op_args(inputs, outputs, attrs):
""" """
flatten inputs outputs attrs flatten inputs outputs attrs
""" """
# 将inputs、outputs、attrs展开为一个列表
op_args = (inputs, outputs, attrs) op_args = (inputs, outputs, attrs)
return [item for arg in op_args for item in arg] return [item for arg in op_args for item in arg]

@ -20,14 +20,23 @@ from enum import Enum
class JobType(Enum): class JobType(Enum):
""" Job Type """ """ Job Type """
# 初始化任务
INITIALIZE_JOB = 'Initialize' INITIALIZE_JOB = 'Initialize'
# 结束任务
FINALIZE_JOB = 'Finalize' FINALIZE_JOB = 'Finalize'
# 检查支持任务
CHECK_JOB = 'CheckSupport' CHECK_JOB = 'CheckSupport'
# 选择格式任务
SELECT_JOB = 'SelectFormat' SELECT_JOB = 'SelectFormat'
# 预编译任务
PRECOMPILE_JOB = 'PreCompile' PRECOMPILE_JOB = 'PreCompile'
# 编译任务
COMPILE_JOB = 'Compile' COMPILE_JOB = 'Compile'
# 融合编译任务
FUSION_COMPILE_JOB = 'FusionOpCompile' FUSION_COMPILE_JOB = 'FusionOpCompile'
# 调优任务
TUNE_JOB = 'Tune' TUNE_JOB = 'Tune'
# 查询任务
QUERY_JOB = 'Query' QUERY_JOB = 'Query'
@ -51,9 +60,13 @@ class JobStatus(Enum):
class LogMessage: class LogMessage:
""" Log message """ """ Log message """
# 初始化函数,用于创建一个对象
def __init__(self, index, level, info): def __init__(self, index, level, info):
# 将传入的index参数赋值给对象的index属性
self.index = index self.index = index
# 将传入的level参数赋值给对象的level属性
self.level = level self.level = level
# 将传入的info参数赋值给对象的info属性
self.info = info self.info = info
@ -74,29 +87,50 @@ class TbeJob:
""" Tbe compilation job """ """ Tbe compilation job """
def __init__(self, source_id, job_id, job_type, content, fusion_op_name, json_str, sys_info): def __init__(self, source_id, job_id, job_type, content, fusion_op_name, json_str, sys_info):
# 初始化函数用于创建一个Job对象
self.source_id = source_id self.source_id = source_id
# 源ID
self.id = job_id self.id = job_id
# 任务ID
self.type = JobType(job_type) self.type = JobType(job_type)
# 任务类型
self.status = JobStatus.JOB_INITIAL self.status = JobStatus.JOB_INITIAL
# 任务状态
self.content = content self.content = content
# 任务内容
self.fusion_op_name = fusion_op_name self.fusion_op_name = fusion_op_name
# 融合操作名称
self.result = "" self.result = ""
# 任务结果
self.process_info = [] self.process_info = []
# 任务处理信息
self.json_string = json_str self.json_string = json_str
# JSON字符串
self._sys_logger = sys_info["logger"] self._sys_logger = sys_info["logger"]
# 系统日志
self.sys_offline_tune = sys_info["offline_tune"] self.sys_offline_tune = sys_info["offline_tune"]
# 离线调优
self.sys_tune_dump_path = sys_info["tune_dump_path"] self.sys_tune_dump_path = sys_info["tune_dump_path"]
# 调优转储路径
self.sys_para_debug_path = sys_info["para_debug_path"] self.sys_para_debug_path = sys_info["para_debug_path"]
# 参数调试路径
# license info # license info
self.rl_tune_switch = sys_info["rl_tune_switch"] self.rl_tune_switch = sys_info["rl_tune_switch"]
# 强化学习调优开关
self.rl_tune_list = sys_info["rl_tune_list"] self.rl_tune_list = sys_info["rl_tune_list"]
# 强化学习调优列表
self.op_tune_switch = sys_info["op_tune_switch"] self.op_tune_switch = sys_info["op_tune_switch"]
# 操作调优开关
self.op_tune_list = sys_info["op_tune_list"] self.op_tune_list = sys_info["op_tune_list"]
# 操作调优列表
self.pass_list = sys_info["pass_list"] self.pass_list = sys_info["pass_list"]
# 通过列表
# soc info # soc info
self.soc_version = sys_info["socVersion"] self.soc_version = sys_info["socVersion"]
# SoC版本
self.core_num = sys_info["coreNum"] self.core_num = sys_info["coreNum"]
# 核心数量
self.op_bank_path = sys_info["op_bank_path"] self.op_bank_path = sys_info["op_bank_path"]
def debug(self, msg, *args, **kwargs): def debug(self, msg, *args, **kwargs):
@ -106,9 +140,13 @@ class TbeJob:
:param args: :param args:
:return: :return:
""" """
# 获取处理后的消息
processed_msg = _get_message(msg, args) processed_msg = _get_message(msg, args)
# 创建日志消息对象
message = LogMessage(len(self.process_info), LogLevel.DEBUG, processed_msg) message = LogMessage(len(self.process_info), LogLevel.DEBUG, processed_msg)
# 将日志消息对象添加到process_info列表中
self.process_info.append(message) self.process_info.append(message)
# 使用系统日志记录器记录日志
self._sys_logger.debug(msg, *args, **kwargs) self._sys_logger.debug(msg, *args, **kwargs)
def info(self, msg, *args, **kwargs): def info(self, msg, *args, **kwargs):
@ -118,9 +156,13 @@ class TbeJob:
:param args: :param args:
:return: :return:
""" """
# 获取处理后的消息
processed_msg = _get_message(msg, args) processed_msg = _get_message(msg, args)
# 创建日志消息对象
message = LogMessage(len(self.process_info), LogLevel.INFO, processed_msg) message = LogMessage(len(self.process_info), LogLevel.INFO, processed_msg)
# 将日志消息对象添加到process_info列表中
self.process_info.append(message) self.process_info.append(message)
# 使用系统日志记录器记录日志
self._sys_logger.info(msg, *args, **kwargs) self._sys_logger.info(msg, *args, **kwargs)
def warning(self, msg, *args, **kwargs): def warning(self, msg, *args, **kwargs):
@ -130,9 +172,13 @@ class TbeJob:
:param args: :param args:
:return: :return:
""" """
# 获取处理后的消息
processed_msg = _get_message(msg, args) processed_msg = _get_message(msg, args)
# 创建日志消息对象
message = LogMessage(len(self.process_info), LogLevel.WARNING, processed_msg) message = LogMessage(len(self.process_info), LogLevel.WARNING, processed_msg)
# 将日志消息对象添加到process_info列表中
self.process_info.append(message) self.process_info.append(message)
# 使用系统日志记录器记录警告信息
self._sys_logger.warning(msg, *args, **kwargs) self._sys_logger.warning(msg, *args, **kwargs)
def error(self, msg, *args, **kwargs): def error(self, msg, *args, **kwargs):
@ -142,9 +188,13 @@ class TbeJob:
:param args: :param args:
:return: :return:
""" """
# 获取处理后的消息
processed_msg = _get_message(msg, args) processed_msg = _get_message(msg, args)
# 创建一个LogMessage对象包含消息的长度、日志级别和消息内容
message = LogMessage(len(self.process_info), LogLevel.ERROR, processed_msg) message = LogMessage(len(self.process_info), LogLevel.ERROR, processed_msg)
# 将LogMessage对象添加到process_info列表中
self.process_info.append(message) self.process_info.append(message)
# 使用_sys_logger记录错误日志msg为原始消息args和kwargs为参数
self._sys_logger.error(msg, *args, **kwargs) self._sys_logger.error(msg, *args, **kwargs)
def error_manager(self, msg, *args, **kwargs): def error_manager(self, msg, *args, **kwargs):
@ -154,30 +204,50 @@ class TbeJob:
:param args: :param args:
:return: :return:
""" """
# 如果msg为空则输出警告信息并返回
if not msg: if not msg:
self.warning("Get empty error manager message, op_name: {}".format(self.fusion_op_name)) self.warning("Get empty error manager message, op_name: {}".format(self.fusion_op_name))
return return
# 初始化异常信息为None
exception_info = None exception_info = None
# 获取融合操作名称
op_name = self.fusion_op_name op_name = self.fusion_op_name
# 如果msg是Exception类型
if isinstance(msg, Exception): if isinstance(msg, Exception):
# 遍历msg的参数
for arg in msg.args: for arg in msg.args:
# 如果参数是字典类型且包含"errCode"键
if isinstance(arg, dict) and "errCode" in arg: if isinstance(arg, dict) and "errCode" in arg:
# 将异常信息赋值给exception_info
exception_info = arg exception_info = arg
break break
# 如果没有找到异常信息
if not exception_info: if not exception_info:
# 输出错误信息
self.error("Exception message:{}".format(msg)) self.error("Exception message:{}".format(msg))
return return
# 如果msg不是Exception类型
else: else:
# 将msg的第一个元素赋值给异常信息
exception_info = msg[0] exception_info = msg[0]
# 如果msg的长度大于等于2
if len(msg) >= 2: if len(msg) >= 2:
# 将msg的第二个元素赋值给融合操作名称
op_name = msg[1] op_name = msg[1]
# 如果异常信息不是字典类型或为空
if not isinstance(exception_info, dict) or not exception_info: if not isinstance(exception_info, dict) or not exception_info:
# 输出警告信息
self.warning("Get illegal error manager message, op_name: {}".format(self.fusion_op_name)) self.warning("Get illegal error manager message, op_name: {}".format(self.fusion_op_name))
return return
# 将异常信息中的op_name字段赋值为融合操作名称
exception_info["op_name"] = op_name exception_info["op_name"] = op_name
# 将异常信息转换为JSON格式
processed_msg = json.dumps(exception_info) processed_msg = json.dumps(exception_info)
# 创建LogMessage对象
message = LogMessage(len(self.process_info), LogLevel.ERROR_MANAGER, processed_msg) message = LogMessage(len(self.process_info), LogLevel.ERROR_MANAGER, processed_msg)
# 将LogMessage对象添加到process_info列表中
self.process_info.append(message) self.process_info.append(message)
# 输出异常信息
self._sys_logger.exception(msg, *args, **kwargs) self._sys_logger.exception(msg, *args, **kwargs)
def get_result(self): def get_result(self):
@ -186,15 +256,26 @@ class TbeJob:
:return: job process result string :return: job process result string
""" """
result = dict() result = dict()
# 获取任务状态
result["status"] = self.status.value result["status"] = self.status.value
# 获取任务源ID
result["source_id"] = self.source_id result["source_id"] = self.source_id
# 获取任务ID
result["job_id"] = self.id result["job_id"] = self.id
# 获取任务类型
result["job_type"] = self.type.value result["job_type"] = self.type.value
# 获取融合操作名称
result["fusion_op_name"] = self.fusion_op_name result["fusion_op_name"] = self.fusion_op_name
# 获取任务结果
result["result"] = self.result result["result"] = self.result
process_info = [] process_info = []
# 遍历任务处理信息
for info in self.process_info: for info in self.process_info:
# 构造消息字典
msg = {"index": info.index, "level": info.level.value, "message": info.info} msg = {"index": info.index, "level": info.level.value, "message": info.info}
# 将消息字典添加到处理信息列表中
process_info.append(msg) process_info.append(msg)
# 将处理信息列表添加到结果字典中
result["process_info"] = process_info result["process_info"] = process_info
# 将结果字典转换为JSON字符串并返回
return json.dumps(result) return json.dumps(result)

@ -29,6 +29,7 @@ class TbeJobManager:
""" TBE compiler job manager """ """ TBE compiler job manager """
def __init__(self): def __init__(self):
# 定义一个字典,用于存储不同类型的任务及其对应的处理函数
self.job_handlers = { self.job_handlers = {
JobType.INITIALIZE_JOB: self.initialize_handler, JobType.INITIALIZE_JOB: self.initialize_handler,
JobType.FINALIZE_JOB: self.finalize_handler, JobType.FINALIZE_JOB: self.finalize_handler,
@ -41,24 +42,43 @@ class TbeJobManager:
JobType.QUERY_JOB: self.query_handler JobType.QUERY_JOB: self.query_handler
} }
# 定义一个字典,用于存储所有任务
self._all_jobs = {} self._all_jobs = {}
# 定义一个字典,用于存储已完成任务
self._finished_jobs = {} self._finished_jobs = {}
# 定义一个字典,用于存储正在运行的任务
self._running_jobs = {} self._running_jobs = {}
# 定义一个字典,用于存储原始完成任务
self._raw_finish_jobs = {} self._raw_finish_jobs = {}
# 定义一个布尔值用于判断TBE是否初始化
self.tbe_initialize = False self.tbe_initialize = False
# 定义一个变量,用于存储初始化缓存
self.init_cache = None self.init_cache = None
# 定义一个字符串,用于存储参数调试路径
self.para_debug_path = "" self.para_debug_path = ""
# 定义一个字符串,用于存储自动调优模式
self.auto_tiling_mode = "" self.auto_tiling_mode = ""
# 定义一个布尔值,用于判断是否离线调优
self.offline_tune = False self.offline_tune = False
# 定义一个列表,用于存储调优操作
self.tune_op_list = [] self.tune_op_list = []
# 定义一个字符串,用于存储调优输出路径
self.tune_dump_path = "" self.tune_dump_path = ""
# 定义一个字符串,用于存储调优库路径
self.tune_bank_path = "" self.tune_bank_path = ""
# 定义一个列表,用于存储自动调优操作
self.auto_tune_op_list = [] self.auto_tune_op_list = []
# 定义一个字典,用于存储预编译操作
self.pre_build_ops = {} self.pre_build_ops = {}
# 定义一个整数,用于存储融合编译需要同步的次数
self.fusion_need_sync = 0 self.fusion_need_sync = 0
# 定义一个字典,用于存储导入的模块
self.imported_module = {} self.imported_module = {}
# 定义一个字符串用于存储SoC版本
self.soc_version = "" self.soc_version = ""
# 定义一个整数,用于存储核心数量
self.core_num = 0 self.core_num = 0
# 定义一个字符串,用于存储操作库路径
self.op_bank_path = "" self.op_bank_path = ""
# license info # license info
self.rl_tune_switch = "" self.rl_tune_switch = ""
@ -68,6 +88,7 @@ class TbeJobManager:
self.pass_list = "" self.pass_list = ""
def __del__(self): def __del__(self):
# 删除对象时调用reset方法
self.reset() self.reset()
def reset(self): def reset(self):
@ -75,22 +96,38 @@ class TbeJobManager:
Reset the job manager Reset the job manager
:return: None :return: None
""" """
# 重置所有任务
self._all_jobs = {} self._all_jobs = {}
# 重置已完成任务
self._finished_jobs = {} self._finished_jobs = {}
# 重置正在运行的任务
self._running_jobs = {} self._running_jobs = {}
# 重置原始已完成任务
self._raw_finish_jobs = {} self._raw_finish_jobs = {}
# 重置调试路径
self.para_debug_path = "" self.para_debug_path = ""
# 重置自动切分模式
self.auto_tiling_mode = "" self.auto_tiling_mode = ""
# 重置离线调优
self.offline_tune = False self.offline_tune = False
# 重置调优操作列表
self.tune_op_list = [] self.tune_op_list = []
# 重置调优导出路径
self.tune_dump_path = "" self.tune_dump_path = ""
# 重置调优银行路径
self.tune_bank_path = "" self.tune_bank_path = ""
# 重置自动调优操作列表
self.auto_tune_op_list = [] self.auto_tune_op_list = []
# 重置预构建操作
self.pre_build_ops = [] self.pre_build_ops = []
# 重置融合需要同步
self.fusion_need_sync = 0 self.fusion_need_sync = 0
# 重置导入模块
self.imported_module = {} self.imported_module = {}
# 如果tbe_initialize为True则调用tbe_finalize方法
if self.tbe_initialize: if self.tbe_initialize:
tbe_finalize(self.auto_tiling_mode, self.offline_tune, self.init_cache) tbe_finalize(self.auto_tiling_mode, self.offline_tune, self.init_cache)
# 重置tbe_initialize
self.tbe_initialize = False self.tbe_initialize = False
self.init_cache = None self.init_cache = None
self.soc_version = "" self.soc_version = ""
@ -105,11 +142,17 @@ class TbeJobManager:
""" """
job = None job = None
try: try:
# 将job_str转换为json格式
job_json = json.loads(job_str) job_json = json.loads(job_str)
# 检查job_json的合法性
check_job_json(job_json) check_job_json(job_json)
# 获取job_id
job_id = job_json["job_id"] job_id = job_json["job_id"]
# 获取source_id
source_id = job_json["source_id"] source_id = job_json["source_id"]
# 获取job_type
job_type = job_json["job_type"] job_type = job_json["job_type"]
# 获取系统信息
sys_info = self._get_job_sys_info() sys_info = self._get_job_sys_info()
fusion_op_name = "NA" if "fusion_op_name" not in job_json["job_content"] else job_json["job_content"][ fusion_op_name = "NA" if "fusion_op_name" not in job_json["job_content"] else job_json["job_content"][
"fusion_op_name"] "fusion_op_name"]
@ -140,173 +183,260 @@ class TbeJobManager:
def initialize_handler(self, job: TbeJob): def initialize_handler(self, job: TbeJob):
""" Initialize job handler """ """ Initialize job handler """
# 初始化系统信息
self._init_sys_info(job) self._init_sys_info(job)
# 调用tbe_initialize函数初始化job
res = tbe_initialize(job) res = tbe_initialize(job)
# 如果初始化失败记录错误信息并将job状态设置为JOB_FAILED
if not res: if not res:
job.error("Process Initialize Job failed, job json string:{}".format(job.json_string)) job.error("Process Initialize Job failed, job json string:{}".format(job.json_string))
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED) return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
# 如果auto_tiling_mode中包含"GA",则获取自动调优支持的操作列表
if "GA" in self.auto_tiling_mode: if "GA" in self.auto_tiling_mode:
self.auto_tune_op_list = get_auto_tune_support_op_list(job) self.auto_tune_op_list = get_auto_tune_support_op_list(job)
# 设置tbe_initialize为True
self.tbe_initialize = True self.tbe_initialize = True
# 将job保存到init_cache中
self.init_cache = job self.init_cache = job
# 将job状态设置为JOB_SUCCESS
return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS) return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS)
def finalize_handler(self, job: TbeJob): def finalize_handler(self, job: TbeJob):
""" Finalize job handler """ """ Finalize job handler """
# 如果tbe_initialize为False则直接将job状态设置为JOB_SUCCESS
if not self.tbe_initialize: if not self.tbe_initialize:
return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS) return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS)
# 调用tbe_finalize函数传入auto_tiling_mode和offline_tune参数
res = tbe_finalize(self.auto_tiling_mode, self.offline_tune, job) res = tbe_finalize(self.auto_tiling_mode, self.offline_tune, job)
# 如果finalize失败记录错误信息并将job状态设置为JOB_FAILED
if not res: if not res:
job.error("Process Finalize Job failed, job json string:{}".format(job.json_string)) job.error("Process Finalize Job failed, job json string:{}".format(job.json_string))
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED) return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
# 将job状态设置为JOB_SUCCESS
return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS) return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS)
def check_support_handler(self, job: TbeJob): def check_support_handler(self, job: TbeJob):
""" Check Support job handler """ """ Check Support job handler """
# 调用check_support函数检查job是否支持
res = check_support(job) res = check_support(job)
# 如果不支持记录错误信息并将job状态设置为JOB_FAILED
if not res: if not res:
job.error("Process CheckSupport Job failed, job json string:{}".format(job.json_string)) job.error("Process CheckSupport Job failed, job json string:{}".format(job.json_string))
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED) return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
# 更新导入的操作模块
self._update_imported_op_module(job) self._update_imported_op_module(job)
# 将job状态设置为JOB_SUCCESS
return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS) return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS)
def select_format_handler(self, job: TbeJob): def select_format_handler(self, job: TbeJob):
""" Select Format job handler """ """ Select Format job handler """
# 调用select_op_format函数选择操作格式
res = select_op_format(job) res = select_op_format(job)
# 如果选择失败记录错误信息并将job状态设置为JOB_FAILED
if not res: if not res:
job.error("Process SelectFormat Job failed, job json string:{}".format(job.json_string)) job.error("Process SelectFormat Job failed, job json string:{}".format(job.json_string))
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED) return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
# 将job状态设置为JOB_SUCCESS
return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS) return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS)
def pre_compile_handler(self, job: TbeJob): def pre_compile_handler(self, job: TbeJob):
""" Pre Compile job handler """ """ Pre Compile job handler """
# 调用parallel_pre_compile_op函数对job进行预处理
res = parallel_pre_compile_op(job) res = parallel_pre_compile_op(job)
# 如果预处理失败则记录错误信息并将job状态设置为JOB_FAILED
if not res: if not res:
job.error("Process PreCompile Job failed, job json string:{}".format(job.json_string)) job.error("Process PreCompile Job failed, job json string:{}".format(job.json_string))
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED) return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
# 将job添加到pre_build_ops字典中以fusion_op_name为键
self.pre_build_ops[job.content["fusion_op_name"]] = job self.pre_build_ops[job.content["fusion_op_name"]] = job
# 将job状态设置为JOB_RUNNING
return self.add_to_running_jobs(job) return self.add_to_running_jobs(job)
def compile_handler(self, job: TbeJob): def compile_handler(self, job: TbeJob):
""" Compile job handler """ """ Compile job handler """
# 获取job中的compute_op_list
compute_op_list = get_compute_op_list(job.content) compute_op_list = get_compute_op_list(job.content)
# 如果compute_op_list只有一个元素则调用single_op_compile函数进行编译
if len(compute_op_list) == 1: # pylint: disable=no-else-return if len(compute_op_list) == 1: # pylint: disable=no-else-return
return self.single_op_compile(job) return self.single_op_compile(job)
else: else:
# 调用before_build_process函数对job进行预处理
before_build_process(job) before_build_process(job)
# 如果需要同步fusion则调用sync_fusion_env函数进行同步
if self.fusion_need_sync: if self.fusion_need_sync:
sync_fusion_env(self.fusion_need_sync, self.imported_module) sync_fusion_env(self.fusion_need_sync, self.imported_module)
self.fusion_need_sync = 0 self.fusion_need_sync = 0
# 调用parallel_compile_fusion_op函数对job进行编译
res = parallel_compile_fusion_op(job) res = parallel_compile_fusion_op(job)
# 如果编译失败则记录错误信息并将job状态设置为JOB_FAILED
if not res: if not res:
job.error("Parallel_compile_fusion_op Job failed, job json string:{}".format(job.json_string)) job.error("Parallel_compile_fusion_op Job failed, job json string:{}".format(job.json_string))
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED) return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
# 将job状态设置为JOB_RUNNING
return self.add_to_running_jobs(job) return self.add_to_running_jobs(job)
def single_op_compile(self, job: TbeJob): def single_op_compile(self, job: TbeJob):
"""Single operator compile""" """Single operator compile"""
# 调用do_fuzz_build_tbe_op函数对job进行编译
res = do_fuzz_build_tbe_op(job) res = do_fuzz_build_tbe_op(job)
# 如果编译失败则记录错误信息并将job状态设置为JOB_FAILED
if not res: if not res:
job.error("Process do fuzz build tbe op failed, job json string:{}".format(job.json_string)) job.error("Process do fuzz build tbe op failed, job json string:{}".format(job.json_string))
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED) return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
# 如果job.result为"NOT_CHANGED"则调用before_build_process函数进行预处理并调用build_single_pre_op函数进行编译
if job.result == "NOT_CHANGED": if job.result == "NOT_CHANGED":
job.result = "" job.result = ""
before_build_process(job) before_build_process(job)
res = build_single_pre_op(job) res = build_single_pre_op(job)
# 如果编译失败则记录错误信息并将job状态设置为JOB_FAILED
if not res: if not res:
job.error("Process build single pre op failed, job json string:{}".format(job.json_string)) job.error("Process build single pre op failed, job json string:{}".format(job.json_string))
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED) return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
# 将job状态设置为JOB_RUNNING
return self.add_to_running_jobs(job) return self.add_to_running_jobs(job)
# 如果job.result为"SUCCESS"则将job状态设置为JOB_SUCCESS
if job.result == "SUCCESS": if job.result == "SUCCESS":
return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS) return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS)
# 如果编译失败则记录错误信息并将job状态设置为JOB_FAILED
job.error("Process do fuzz build tbe op failed, job json string:{}".format(job.json_string)) job.error("Process do fuzz build tbe op failed, job json string:{}".format(job.json_string))
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED) return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
def tune_handler(self, job: TbeJob): def tune_handler(self, job: TbeJob):
""" Tune job handler """ """ Tune job handler """
before_build_process(job) before_build_process(job)
# 选择调优模式
tune_mode = self._select_tune_mode(job) tune_mode = self._select_tune_mode(job)
# 如果调优模式为不调优,则直接调用编译处理函数
if tune_mode == TuneMode.NO_TUNE: if tune_mode == TuneMode.NO_TUNE:
return self.compile_handler(job) return self.compile_handler(job)
# 获取计算操作列表
compute_op_list = get_compute_op_list(job.content) compute_op_list = get_compute_op_list(job.content)
# 如果计算操作列表只有一个,则调用单操作调优函数
if len(compute_op_list) == 1: if len(compute_op_list) == 1:
return self.single_op_tune(job) return self.single_op_tune(job)
# 否则调用融合操作调优函数
return self.fusion_op_tune(job) return self.fusion_op_tune(job)
def single_op_tune(self, job: TbeJob): def single_op_tune(self, job: TbeJob):
"""Single operator tune""" """Single operator tune"""
# 选择调优模式
tune_mode = self._select_tune_mode(job) tune_mode = self._select_tune_mode(job)
# 如果调优模式为强化学习调优
if tune_mode == TuneMode.RL_TUNE: if tune_mode == TuneMode.RL_TUNE:
# 调用强化学习单操作调优函数
res = rl_tune_single_op(job) res = rl_tune_single_op(job)
# 如果调优失败,则记录错误信息,并将任务状态设置为失败
if not res: if not res:
job.error( job.error(
"Tune Job failed, tune type {}, job json string:{}".format(tune_mode, job.json_string)) "Tune Job failed, tune type {}, job json string:{}".format(tune_mode, job.json_string))
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED) return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
# 否则,如果需要同步融合环境,则调用同步融合环境函数
else: else:
if self.fusion_need_sync: if self.fusion_need_sync:
sync_fusion_env(self.fusion_need_sync, self.imported_module) sync_fusion_env(self.fusion_need_sync, self.imported_module)
self.fusion_need_sync = 0 self.fusion_need_sync = 0
# 调用遗传算法调优函数
res = ga_tune(job) res = ga_tune(job)
# 如果调优失败,则记录错误信息,并调用编译处理函数
if not res: if not res:
job.error("ga tune Job failed, job json string:{}".format(job.json_string)) job.error("ga tune Job failed, job json string:{}".format(job.json_string))
return self.compile_handler(job) return self.compile_handler(job)
# 如果任务状态为运行中
if job.status == JobStatus.JOB_RUNNING: if job.status == JobStatus.JOB_RUNNING:
# 如果调优模式为强化学习调优,则更新导入的操作模块
if tune_mode == TuneMode.RL_TUNE: if tune_mode == TuneMode.RL_TUNE:
self._update_imported_op_module(job) self._update_imported_op_module(job)
# 将任务添加到运行中任务列表
return self.add_to_running_jobs(job) return self.add_to_running_jobs(job)
# 否则将任务添加到已完成任务列表,并设置任务状态为成功
return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS) return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS)
def fusion_op_tune(self, job: TbeJob): def fusion_op_tune(self, job: TbeJob):
"""Fusion operator tune""" """Fusion operator tune"""
# 选择调优模式
tune_mode = self._select_tune_mode(job) tune_mode = self._select_tune_mode(job)
# 如果需要同步融合环境,则调用同步融合环境函数
if self.fusion_need_sync: if self.fusion_need_sync:
sync_fusion_env(self.fusion_need_sync, self.imported_module) sync_fusion_env(self.fusion_need_sync, self.imported_module)
self.fusion_need_sync = 0 self.fusion_need_sync = 0
# 如果调优模式为强化学习调优,则调用强化学习融合操作调优函数
if tune_mode == TuneMode.RL_TUNE: if tune_mode == TuneMode.RL_TUNE:
res = rl_tune_fusion_op(job) res = rl_tune_fusion_op(job)
# 否则调用遗传算法调优函数
else: else:
res = ga_tune(job) res = ga_tune(job)
# 如果调优失败,则记录错误信息,并将任务状态设置为失败
if not res: if not res:
job.error( job.error(
"Tune Job failed, tune type {}, job json string:{}".format(tune_mode, job.json_string)) "Tune Job failed, tune type {}, job json string:{}".format(tune_mode, job.json_string))
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED) return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
# 如果任务状态为运行中,则将任务添加到运行中任务列表
if job.status == JobStatus.JOB_RUNNING: if job.status == JobStatus.JOB_RUNNING:
return self.add_to_running_jobs(job) return self.add_to_running_jobs(job)
# 否则将任务添加到已完成任务列表,并设置任务状态为成功
return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS) return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS)
def query_handler(self, query_job: TbeJob): def query_handler(self, query_job: TbeJob):
""" Query job handler """ """ Query job handler """
# 获取查询任务的source_id和job_id
target_source_id = query_job.content["source_id"] target_source_id = query_job.content["source_id"]
target_job_id = query_job.content["job_id"] target_job_id = query_job.content["job_id"]
# 根据source_id和job_id获取已完成的任务
target_job = get_job(self._finished_jobs, target_source_id, target_job_id) target_job = get_job(self._finished_jobs, target_source_id, target_job_id)
# 如果找到了已完成的任务
if target_job: if target_job:
# 记录警告信息
query_job.warning("Query a finished job: {}".format(query_job.content)) query_job.warning("Query a finished job: {}".format(query_job.content))
# 将查询任务的结果设置为已完成任务的结果
query_job.result = target_job.get_result() query_job.result = target_job.get_result()
# 将查询任务添加到已完成任务列表中,并返回成功状态
return self.add_to_finished_jobs(query_job, JobStatus.JOB_SUCCESS) return self.add_to_finished_jobs(query_job, JobStatus.JOB_SUCCESS)
# 根据source_id和job_id获取未完成的任务
target_job = get_job(self._raw_finish_jobs, target_source_id, target_job_id) target_job = get_job(self._raw_finish_jobs, target_source_id, target_job_id)
# 如果未找到未完成的任务
if not target_job: if not target_job:
# 更新未完成的任务列表
self.update_raw_finished_jobs(query_job) self.update_raw_finished_jobs(query_job)
# 再次根据source_id和job_id获取未完成的任务
target_job = get_job(self._raw_finish_jobs, target_source_id, target_job_id) target_job = get_job(self._raw_finish_jobs, target_source_id, target_job_id)
# 如果找到了未完成的任务
if target_job: if target_job:
# 记录调试信息
query_job.debug("Found job in raw finished jobs, source_id:{}, job_id:{}".format(target_source_id, query_job.debug("Found job in raw finished jobs, source_id:{}, job_id:{}".format(target_source_id,
target_job_id)) target_job_id))
# 将查询任务的结果设置为未完成任务的结果
query_job.result = target_job.get_result() query_job.result = target_job.get_result()
# 从未完成任务列表中删除该任务
del_job(self._raw_finish_jobs, target_job.source_id, target_job.id) del_job(self._raw_finish_jobs, target_job.source_id, target_job.id)
# 将未完成任务添加到已完成任务列表中,并返回成功状态
self.add_to_finished_jobs(target_job, target_job.status) self.add_to_finished_jobs(target_job, target_job.status)
return self.add_to_finished_jobs(query_job, JobStatus.JOB_SUCCESS) return self.add_to_finished_jobs(query_job, JobStatus.JOB_SUCCESS)
# 根据source_id和job_id获取正在运行的任务
target_job = get_job(self._running_jobs, target_source_id, target_job_id) target_job = get_job(self._running_jobs, target_source_id, target_job_id)
# 如果找到了正在运行的任务
if target_job: if target_job:
# 将查询任务的结果设置为正在运行任务的结果
query_job.result = target_job.get_result() query_job.result = target_job.get_result()
# 将查询任务添加到已完成任务列表中,并返回成功状态
return self.add_to_finished_jobs(query_job, JobStatus.JOB_SUCCESS) return self.add_to_finished_jobs(query_job, JobStatus.JOB_SUCCESS)
# 根据source_id和job_id获取所有任务
target_job = get_job(self._all_jobs, target_source_id, target_job_id) target_job = get_job(self._all_jobs, target_source_id, target_job_id)
# 如果找到了所有任务
if target_job: if target_job:
# 记录调试信息
query_job.debug("Found job in all jobs, source_id:{}, job_id:{}".format(target_source_id, query_job.debug("Found job in all jobs, source_id:{}, job_id:{}".format(target_source_id,
target_job_id)) target_job_id))
# 记录调试信息
target_job.debug("Be Queried") target_job.debug("Be Queried")
# 将查询任务的结果设置为所有任务的结果
query_job.result = target_job.get_result() query_job.result = target_job.get_result()
# 将查询任务添加到已完成任务列表中,并返回成功状态
return self.add_to_finished_jobs(query_job, JobStatus.JOB_SUCCESS) return self.add_to_finished_jobs(query_job, JobStatus.JOB_SUCCESS)
# 如果没有找到任何任务,记录错误信息
query_job.error("Can't find job in finished/raw_finished/running jobs, source_id: {}".format(target_source_id)) query_job.error("Can't find job in finished/raw_finished/running jobs, source_id: {}".format(target_source_id))
# 将查询任务的结果设置为空
query_job.result = "" query_job.result = ""
# 将查询任务添加到已完成任务列表中,并返回失败状态
return self.add_to_finished_jobs(query_job, JobStatus.JOB_FAILED) return self.add_to_finished_jobs(query_job, JobStatus.JOB_FAILED)
def _get_job_sys_info(self): def _get_job_sys_info(self):
@ -314,10 +444,15 @@ class TbeJobManager:
Get job manager system info Get job manager system info
:return: system info :return: system info
""" """
# 创建一个字典,用于存储系统信息
sys_info = dict() sys_info = dict()
# 将DummyLogger添加到系统信息中
sys_info["logger"] = DummyLogger sys_info["logger"] = DummyLogger
# 将para_debug_path添加到系统信息中
sys_info["para_debug_path"] = self.para_debug_path sys_info["para_debug_path"] = self.para_debug_path
# 将tune_dump_path添加到系统信息中
sys_info["tune_dump_path"] = self.tune_dump_path sys_info["tune_dump_path"] = self.tune_dump_path
# 将offline_tune添加到系统信息中
sys_info["offline_tune"] = self.offline_tune sys_info["offline_tune"] = self.offline_tune
# license info # license info
sys_info["rl_tune_switch"] = self.rl_tune_switch sys_info["rl_tune_switch"] = self.rl_tune_switch
@ -362,12 +497,17 @@ class TbeJobManager:
:param job: :param job:
:return: :return:
""" """
# 获取计算操作列表
compute_op_info = get_compute_op_list(job.content)[0] compute_op_info = get_compute_op_list(job.content)[0]
# 获取操作模块名称
op_module_name = compute_op_info["module_name"] op_module_name = compute_op_info["module_name"]
# 如果操作模块名称在已导入模块中,则增加引用次数
if op_module_name in self.imported_module.keys(): if op_module_name in self.imported_module.keys():
self.imported_module[op_module_name] = self.imported_module[op_module_name] + 1 self.imported_module[op_module_name] = self.imported_module[op_module_name] + 1
# 否则将操作模块名称添加到已导入模块中并设置引用次数为1
else: else:
self.imported_module[op_module_name] = 1 self.imported_module[op_module_name] = 1
# 增加融合需要同步的次数
self.fusion_need_sync = self.fusion_need_sync + 1 self.fusion_need_sync = self.fusion_need_sync + 1
def _select_tune_mode(self, job): def _select_tune_mode(self, job):
@ -376,18 +516,25 @@ class TbeJobManager:
:param job: tbe tune job :param job: tbe tune job
:return: NO_TUNE RL_TUNE or GA_TUNE :return: NO_TUNE RL_TUNE or GA_TUNE
""" """
# 获取job的SocInfo中的autoTilingMode和offlineTune
auto_tiling_mode = job.content["SocInfo"]["autoTilingMode"] auto_tiling_mode = job.content["SocInfo"]["autoTilingMode"]
offline_tune = job.content["SocInfo"]["offlineTune"] offline_tune = job.content["SocInfo"]["offlineTune"]
# 获取job的full_name
full_name = job.content["full_name"] full_name = job.content["full_name"]
# 获取job的func_names
func_names = get_func_names(job.content) func_names = get_func_names(job.content)
# 如果self.tune_op_list不为空且full_name不在self.tune_op_list中则返回TuneMode.NO_TUNE
if self.tune_op_list and full_name not in self.tune_op_list: if self.tune_op_list and full_name not in self.tune_op_list:
return TuneMode.NO_TUNE return TuneMode.NO_TUNE
# 如果offline_tune为True则返回TuneMode.RL_TUNE
if offline_tune: if offline_tune:
return TuneMode.RL_TUNE return TuneMode.RL_TUNE
# 如果auto_tiling_mode中包含TuneMode.GA_TUNE.value则遍历func_names如果func_name.lower()在self.auto_tune_op_list中则返回TuneMode.GA_TUNE
if TuneMode.GA_TUNE.value in auto_tiling_mode: if TuneMode.GA_TUNE.value in auto_tiling_mode:
for func_name in func_names: for func_name in func_names:
if func_name.lower() in self.auto_tune_op_list: if func_name.lower() in self.auto_tune_op_list:
return TuneMode.GA_TUNE return TuneMode.GA_TUNE
# 如果auto_tiling_mode中包含TuneMode.RL_TUNE.value则返回TuneMode.RL_TUNE
if TuneMode.RL_TUNE.value in auto_tiling_mode: if TuneMode.RL_TUNE.value in auto_tiling_mode:
return TuneMode.RL_TUNE return TuneMode.RL_TUNE
return TuneMode.NO_TUNE return TuneMode.NO_TUNE
@ -398,15 +545,22 @@ class TbeJobManager:
:param query_job: query job :param query_job: query job
:return: Node :return: Node
""" """
# 获取已完成任务
new_finished_jobs = get_finish_tasks(query_job.source_id) new_finished_jobs = get_finish_tasks(query_job.source_id)
# 遍历已完成任务
for new_job in new_finished_jobs: for new_job in new_finished_jobs:
# 获取任务ID
source_id = new_job["graph_id"] source_id = new_job["graph_id"]
job_id = new_job["task_id"] job_id = new_job["task_id"]
# 获取任务
target_job = get_job(self._running_jobs, source_id, job_id) target_job = get_job(self._running_jobs, source_id, job_id)
# 如果任务不存在,则报错
if not target_job: if not target_job:
query_job.error("Can't get job, source id:{}, job id:{}".format(source_id, job_id)) query_job.error("Can't get job, source id:{}, job id:{}".format(source_id, job_id))
continue continue
# 设置任务结果
target_job.result = new_job["op_res"] if "op_res" in new_job else new_job["result"] target_job.result = new_job["op_res"] if "op_res" in new_job else new_job["result"]
# 如果任务类型为预编译任务,则进行预编译
if target_job.type == JobType.PRECOMPILE_JOB: if target_job.type == JobType.PRECOMPILE_JOB:
op_name = target_job.content["fusion_op_name"] op_name = target_job.content["fusion_op_name"]
op_params = get_prebuild_output(op_name) op_params = get_prebuild_output(op_name)
@ -415,13 +569,17 @@ class TbeJobManager:
pre_compile_result["op_params"] = op_params pre_compile_result["op_params"] = op_params
pre_compile_result["core_type"] = new_job["core_type"] if "core_type" in new_job else "" pre_compile_result["core_type"] = new_job["core_type"] if "core_type" in new_job else ""
target_job.result = json.dumps(pre_compile_result) target_job.result = json.dumps(pre_compile_result)
# 输出任务结果
target_job.info("Query result:{}".format(new_job["result"])) target_job.info("Query result:{}".format(new_job["result"]))
# 如果任务状态码为0则任务成功
if new_job["status_code"] == 0: if new_job["status_code"] == 0:
target_job.status = JobStatus.JOB_SUCCESS target_job.status = JobStatus.JOB_SUCCESS
target_job.info("Query info_msg:{}".format(new_job["info_msg"])) target_job.info("Query info_msg:{}".format(new_job["info_msg"]))
# 否则任务失败
else: else:
target_job.status = JobStatus.JOB_FAILED target_job.status = JobStatus.JOB_FAILED
target_job.error("Query info_msg:{}".format(new_job["info_msg"])) target_job.error("Query info_msg:{}".format(new_job["info_msg"]))
# 输出错误信息
if "err_args" in new_job: if "err_args" in new_job:
target_job.error("Query err_args:{}".format(new_job["err_args"])) target_job.error("Query err_args:{}".format(new_job["err_args"]))
if "except_msg" in new_job: if "except_msg" in new_job:
@ -429,7 +587,9 @@ class TbeJobManager:
if "except_tuple_msg" in new_job: if "except_tuple_msg" in new_job:
target_job.error_manager(new_job["except_tuple_msg"]) target_job.error_manager(new_job["except_tuple_msg"])
target_job.error("\nOriginal compile json: \n {}\n".format(target_job.json_string)) target_job.error("\nOriginal compile json: \n {}\n".format(target_job.json_string))
# 将任务添加到已完成任务列表
post_job(self._raw_finish_jobs, target_job) post_job(self._raw_finish_jobs, target_job)
# 从运行中任务列表中删除任务
del_job(self._running_jobs, target_job.source_id, target_job.id) del_job(self._running_jobs, target_job.source_id, target_job.id)
def add_to_finished_jobs(self, job, status): def add_to_finished_jobs(self, job, status):
@ -456,8 +616,11 @@ class TbeJobManager:
class TuneMode(Enum): class TuneMode(Enum):
"""Class of tune mode: NO_TUNE, GA, RL""" """Class of tune mode: NO_TUNE, GA, RL"""
# 不调优模式
NO_TUNE = "NO_TUNE" NO_TUNE = "NO_TUNE"
# 遗传算法调优模式
GA_TUNE = "GA" GA_TUNE = "GA"
# 强化学习调优模式
RL_TUNE = "RL" RL_TUNE = "RL"
@ -469,18 +632,22 @@ class DummyLogger:
@staticmethod @staticmethod
def debug(msg, *args, **kwargs): def debug(msg, *args, **kwargs):
"""Debug级别日志"""
pass pass
@staticmethod @staticmethod
def info(msg, *args, **kwargs): def info(msg, *args, **kwargs):
"""Info级别日志"""
pass pass
@staticmethod @staticmethod
def warning(msg, *args, **kwargs): def warning(msg, *args, **kwargs):
"""Warning级别日志"""
pass pass
@staticmethod @staticmethod
def error(msg, *args, **kwargs): def error(msg, *args, **kwargs):
"""Error级别日志"""
pass pass
@staticmethod @staticmethod
@ -497,10 +664,13 @@ def get_job(jobs, source_id, job_id):
:return: job instance if found in job list :return: job instance if found in job list
None if not found in job list None if not found in job list
""" """
# 如果source_id不在jobs的键中返回None
if source_id not in jobs.keys(): if source_id not in jobs.keys():
return None return None
# 如果job_id不在jobs[source_id]的键中返回None
if job_id not in jobs[source_id].keys(): if job_id not in jobs[source_id].keys():
return None return None
# 返回jobs[source_id][job_id]
return jobs[source_id][job_id] return jobs[source_id][job_id]
@ -526,9 +696,15 @@ def del_job(jobs, source_id, job_id):
:param job_id: target job's job_id :param job_id: target job's job_id
:return: bool True or False :return: bool True or False
""" """
# 判断source_id是否在jobs字典中
if source_id not in jobs.keys(): if source_id not in jobs.keys():
# 如果不在返回False
return False return False
# 判断job_id是否在jobs[source_id]字典中
if job_id not in jobs[source_id].keys(): if job_id not in jobs[source_id].keys():
# 如果不在返回False
return False return False
# 删除jobs[source_id]字典中的job_id键值对
del jobs[source_id][job_id] del jobs[source_id][job_id]
# 返回True
return True return True

@ -26,6 +26,7 @@ from .parser import (Parser, create_instance, is_supported_create_instance_type,
get_object_description, get_class_attr_namespace_symbol, get_ms_class_name, get_object_description, get_class_attr_namespace_symbol, get_ms_class_name,
get_ms_class_attr) get_ms_class_attr)
# 导入parser模块中的所有函数和类
__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol', __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_ast_type', 'get_node_type', 'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_ast_type', 'get_node_type',
'get_args_default_values', 'get_ast_namespace_symbol', 'get_operation_namespace_symbol', 'get_args_default_values', 'get_ast_namespace_symbol', 'get_operation_namespace_symbol',

@ -16,131 +16,136 @@
# ============================================================================ # ============================================================================
"""Define the namespace of parse.""" """Define the namespace of parse."""
import builtins import builtins # 导入内置模块builtins
from mindspore import log as logger
from mindspore import log as logger # 从mindspore库导入log模块并重命名为logger
class Namespace: class Namespace:
""" """
Base class of namespace for resolve variables. 基类用于解析变量命名空间
Args: Args:
name (str): The namespace's name. name (str): 命名空间的名称
dicts (dict): A list of dict containing the namespace's variable. dicts (dict): 包含命名空间变量的字典列表
""" """
def __init__(self, name, *dicts): def __init__(self, name, *dicts):
self.name = name self.name = name # 初始化命名空间名称
self.dicts = dicts self.dicts = dicts # 初始化包含变量的字典列表
def __contains__(self, name): def __contains__(self, name):
# 检查命名空间中是否包含指定名称的变量
for d in self.dicts: for d in self.dicts:
if name in d: if name in d:
return True return True
return False return False
def __getitem__(self, name): def __getitem__(self, name):
# 获取命名空间中指定名称的变量
for d in self.dicts: for d in self.dicts:
if name in d: if name in d:
return d[name] return d[name]
raise NameError(name) raise NameError(name) # 如果未找到抛出NameError
def __repr__(self): def __repr__(self):
# 返回命名空间的字符串表示
return f'Namespace:{self.name}' return f'Namespace:{self.name}'
class CellNamespace(Namespace): class CellNamespace(Namespace):
""" """
Namespace for Cell object. Cell对象的命名空间
Args: Args:
name (str): Valid module name, it can be imported. name (str): 可导入的有效模块名称
""" """
def __init__(self, name): def __init__(self, name):
mod_dict = vars(__import__(name, fromlist=['_'])) mod_dict = vars(__import__(name, fromlist=['_'])) # 导入模块并获取其变量字典
builtins_dict = vars(builtins) builtins_dict = vars(builtins) # 获取内置模块的变量字典
super().__init__(name, mod_dict, builtins_dict) super().__init__(name, mod_dict, builtins_dict) # 调用父类初始化
def __getstate__(self): def __getstate__(self):
# 获取对象的状态,用于序列化
return (self.name,) return (self.name,)
def __setstate__(self, state): def __setstate__(self, state):
# 设置对象的状态,用于反序列化
name, = state name, = state
mod_dict = vars(__import__(name, fromlist=['_'])) mod_dict = vars(__import__(name, fromlist=['_'])) # 重新导入模块
builtins_dict = vars(builtins) builtins_dict = vars(builtins) # 重新获取内置模块字典
super().__init__(name, mod_dict, builtins_dict) super().__init__(name, mod_dict, builtins_dict) # 重新初始化父类
class ClosureNamespace(Namespace): class ClosureNamespace(Namespace):
""" """
Namespace for function closure. 函数闭包的命名空间
Args: Args:
fn (Function): A python function. fn (Function): 一个Python函数
""" """
def __init__(self, fn): def __init__(self, fn):
name = f'{fn.__module__}..<{fn.__name__}>' name = f'{fn.__module__}..<{fn.__name__}>' # 构造命名空间名称
names = fn.__code__.co_freevars names = fn.__code__.co_freevars # 获取函数的自由变量名称
cells = fn.__closure__ cells = fn.__closure__ # 获取函数的闭包
ns = dict(zip(names, cells or ())) ns = dict(zip(names, cells or ())) # 构造命名空间字典
super().__init__(name, ns) super().__init__(name, ns) # 调用父类初始化
def __getitem__(self, name): def __getitem__(self, name):
# 获取命名空间中指定名称的变量
d, = self.dicts d, = self.dicts
try: try:
return d[name].cell_contents return d[name].cell_contents # 返回闭包内容
except ValueError: except ValueError:
raise UnboundLocalError(name) raise UnboundLocalError(name) # 如果未找到抛出UnboundLocalError
class ClassMemberNamespace(Namespace): class ClassMemberNamespace(Namespace):
""" """
Namespace of a class's closure. 类闭包的命名空间
Args: Args:
obj (Object): A python class object. obj (Object): 一个Python类对象
""" """
def __init__(self, obj): def __init__(self, obj):
self.__class_member_namespace__ = True self.__class_member_namespace__ = True # 标记为类成员命名空间
label = f'{obj.__module__}..<{obj.__class__.__name__}::{id(obj)}>' label = f'{obj.__module__}..<{obj.__class__.__name__}::{id(obj)}>' # 构造命名空间标签
super().__init__(label, obj) super().__init__(label, obj) # 调用父类初始化
def __getitem__(self, name): def __getitem__(self, name):
# 获取命名空间中指定名称的变量
d, = self.dicts d, = self.dicts
if name == "self": if name == "self":
return d return d # 如果名称是self返回对象本身
if name == "namespace": if name == "namespace":
return self return self # 如果名称是namespace返回命名空间对象
try: try:
if hasattr(d, name): if hasattr(d, name):
return getattr(d, name) return getattr(d, name) # 如果对象有该属性,返回属性值
return d.__dict__[name] return d.__dict__[name] # 否则从对象字典中获取
except ValueError: except ValueError:
raise UnboundLocalError(name) raise UnboundLocalError(name) # 如果未找到抛出UnboundLocalError
except KeyError: except KeyError:
logger.info(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}', so will return None.") logger.info(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}', so will return None.")
raise AttributeError(name) raise AttributeError(name) # 如果未找到属性记录日志并抛出AttributeError
class ClassAttrNamespace(Namespace): class ClassAttrNamespace(Namespace):
""" """
Namespace of a class. 类的命名空间
Args: Args:
obj (Object): A python class object. obj (Object): 一个Python类对象
""" """
def __init__(self, obj): def __init__(self, obj):
name = f'{obj.__module__}..<{obj.__class__.__name__}::{id(obj)}>' name = f'{obj.__module__}..<{obj.__class__.__name__}::{id(obj)}>' # 构造命名空间名称
super().__init__(name, obj) super().__init__(name, obj) # 调用父类初始化
def __getattr__(self, name): def __getattr__(self, name):
# 获取命名空间中指定名称的属性
d, = self.dicts d, = self.dicts
try: try:
if hasattr(d, name): if hasattr(d, name):
return getattr(d, name) return getattr(d, name) # 如果对象有该属性,返回属性值
return d.__dict__[name] return d.__dict__[name] # 否则从对象字典中获取
except ValueError: except ValueError:
raise UnboundLocalError(name) raise UnboundLocalError(name) # 如果未找到抛出UnboundLocalError
except KeyError: except KeyError:
logger.info(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}', so will return None.") logger.info(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}', so will return None.")
raise AttributeError(name) raise AttributeError(name) # 如果未找到属性记录日志并抛出AttributeError

@ -34,7 +34,7 @@ trope_ns = CellNamespace('mindspore._extends.parse.trope')
NO_IMPLEMENT = None # not implemented NO_IMPLEMENT = None # not implemented
SYMBOL_UNDEFINE = 0xFF # Undefined var and function SYMBOL_UNDEFINE = 0xFF # Undefined var and function
# Some space set aside for readability of code # 一些空间设置以提高代码可读性
parse_object_map = { parse_object_map = {
# ast grammar # ast grammar
ast.Add: (trope_ns, 'add'), ast.Add: (trope_ns, 'add'),
@ -93,8 +93,8 @@ ops_symbol_map = {
SYMBOL_UNDEFINE: '', SYMBOL_UNDEFINE: '',
} }
# Escape an object to another object, eg: system function(len,xxx) # 将一个对象转为另一个对象,例如:系统函数(len,xxx)
# Some space set aside for readability of code # 一些空间设置以提高代码可读性
convert_object_map = { convert_object_map = {
T.add: multitype_ops.add, T.add: multitype_ops.add,
T.sub: multitype_ops.sub, T.sub: multitype_ops.sub,
@ -162,5 +162,6 @@ convert_object_map = {
CSRTensor: F.make_csr_tensor CSRTensor: F.make_csr_tensor
} }
# 如果不启用安全性,则将 T.print 映射到 F.print_
if not security.enable_security(): if not security.enable_security():
convert_object_map[T.print] = F.print_ convert_object_map[T.print] = F.print_

@ -50,55 +50,45 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt',
def MakeTuple(*elts): # pragma: no cover def MakeTuple(*elts): # pragma: no cover
"""Tuple builder.""" """Tuple builder.""" # 创建元组的构造函数
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')
def make_dict(key, value): # pragma: no cover def make_dict(key, value): # pragma: no cover
"""Dict builder.""" """Dict builder.""" # 创建字典的构造函数
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')
def make_list(*elts): # pragma: no cover def make_list(*elts): # pragma: no cover
"""List builder.""" """List builder.""" # 创建列表的构造函数
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')
def make_slice(*elts): # pragma: no cover def make_slice(*elts): # pragma: no cover
"""Slice builder.""" """Slice builder.""" # 创建切片的构造函数
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')
def make_range(*elts): # pragma: no cover def make_range(*elts): # pragma: no cover
"""Range tuple builder.""" """Range tuple builder.""" # 创建范围元组的构造函数
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')
def switch(cond, tb, fb): # pragma: no cover def switch(cond, tb, fb): # pragma: no cover
"""Switch statement, returns one of the two values.""" """Switch statement, returns one of the two values.""" # 返回两个值中的一个的开关语句
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')
def hasnext(it): # pragma: no cover def hasnext(it): # pragma: no cover
"""Hasnext function.""" """Hasnext function.""" # 判断是否有下一个元素的函数
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')
def to_array(x): def to_array(x):
"""The to_array function.""" """The to_array function.""" # 将输入转换为数组的函数
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')
def not_contains(x): # pragma: no cover def not_contains(x): # pragma: no cover
"""Not in function.""" """Not in function.""" # 判断元素是否不在集合中的函数
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')
def while_cond(x): # pragma: no cover def while_cond(x): # pragma: no cover
"""Not in function.""" """Not in function.""" # 判断条件是否成立的函数
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')
def bool_(x): # pragma: no cover def bool_(x): # pragma: no cover
"""judge true function.""" """judge true function.""" # 判断一个值是否为真值的函数
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')

@ -23,6 +23,13 @@ class Messager:
'''Messager''' '''Messager'''
def __init__(self, fdin, fdout): def __init__(self, fdin, fdout):
"""
初始化 Messager
Args:
fdin: 输入文件描述符
fdout: 输出文件描述符
"""
self.fdin = fdin self.fdin = fdin
self.fdout = fdout self.fdout = fdout
self.fin = os.fdopen(fdin, "r") self.fin = os.fdopen(fdin, "r")
@ -30,12 +37,15 @@ class Messager:
self.message = '' self.message = ''
def __del__(self): def __del__(self):
"""
删除 Messager 实例时关闭文件描述符
"""
os.close(self.fdin) os.close(self.fdin)
os.close(self.fdout) os.close(self.fdout)
def get_message(self): def get_message(self):
""" """
Get message from remote 从远程获取消息
Returns: Returns:
message message
@ -61,10 +71,10 @@ class Messager:
def send_res(self, res, keep_format=True): def send_res(self, res, keep_format=True):
""" """
Send result to remote 发送结果到远程
Args: Args:
keep_format: True or False keep_format: True False
""" """
logger.debug(f"[OUT] {str(res)}") logger.debug(f"[OUT] {str(res)}")
if keep_format: if keep_format:
@ -85,10 +95,10 @@ class Messager:
def send_ack(self, success=True): def send_ack(self, success=True):
""" """
Send ack to remote 发送确认消息到远程
Args: Args:
success: True or False success: True False
""" """
if success: if success:
self.send_res('ACK') self.send_res('ACK')
@ -97,29 +107,30 @@ class Messager:
def loop(self): def loop(self):
""" """
Messaging loop 消息循环
""" """
while True: while True:
self.handle() self.handle()
def run(self): def run(self):
"""运行消息循环"""
self.loop() self.loop()
def handle(self): def handle(self):
""" """
A interface communicates with remote. 与远程通信的接口
Note: Note:
All subclasses should override this interface. 所有子类应该重写此接口
""" """
raise NotImplementedError raise NotImplementedError
def exit(self): def exit(self):
""" """
A interface handles the procedure before exit. 处理退出之前的程序
Note: Note:
All subclasses should override this interface. 所有子类应该重写此接口
""" """
raise NotImplementedError raise NotImplementedError
@ -128,23 +139,29 @@ class AkgBuilder():
"""Akg building wrapper""" """Akg building wrapper"""
def __init__(self, platform): def __init__(self, platform):
"""
初始化 AkgBuilder
Args:
platform: 平台标识
"""
self.platform = platform self.platform = platform
self.attrs = None self.attrs = None
def create(self, process_num, waitime): def create(self, process_num, waitime):
""" Create akg processor""" """ 创建 akg 处理器"""
self.akg_processor = create_akg_parallel_process(process_num, waitime, self.platform) self.akg_processor = create_akg_parallel_process(process_num, waitime, self.platform)
def accept_json(self, json): def accept_json(self, json):
""" Accept json""" """ 接受 json 数据"""
return self.akg_processor.accept_json(json) return self.akg_processor.accept_json(json)
def compile(self): def compile(self):
"""Compile""" """编译"""
return self.akg_processor.compile(self.attrs) return self.akg_processor.compile(self.attrs)
def handle(self, messager, arg): def handle(self, messager, arg):
"""Handle message about akg""" """处理关于 akg 的消息"""
if arg == 'AKG/START': if arg == 'AKG/START':
messager.send_ack() messager.send_ack()
process_num_str = messager.get_message() process_num_str = messager.get_message()
@ -175,4 +192,5 @@ class AkgBuilder():
def get_logger(): def get_logger():
"""获取日志记录器"""
return logger return logger

@ -20,19 +20,24 @@ from mindspore._extends.remote.kernel_build_server import Messager, get_logger,
class AkgMessager(Messager): class AkgMessager(Messager):
''' '''
Default Messager for akg kernels. 默认的 akg 内核消息处理器
It works as a server, communicating with c++ client. 它作为一个服务器 C++ 客户端进行通信
''' '''
def __init__(self, fdin, fdout): def __init__(self, fdin, fdout):
"""
初始化 AkgMessager 实例
:param fdin: 输入文件描述符
:param fdout: 输出文件描述符
"""
super().__init__(fdin, fdout) super().__init__(fdin, fdout)
get_logger().info("[TRACE] Akg Messager init...") get_logger().info("[TRACE] Akg Messager init...")
self.akg_builder = AkgBuilder("default") self.akg_builder = AkgBuilder("default")
def handle(self): def handle(self):
""" """
Communicate with remote client. 与远程客户端进行通信
Reference protocol between them at PR#4063 它们之间的参考协议见 PR#4063。
""" """
arg = self.get_message() arg = self.get_message()
if "AKG" in arg: if "AKG" in arg:
@ -42,11 +47,18 @@ class AkgMessager(Messager):
self.exit() self.exit()
def exit(self): def exit(self):
"""
退出 AkgMessager
"""
get_logger().info("[TRACE] Akg Messager Exit...") get_logger().info("[TRACE] Akg Messager Exit...")
exit() exit()
if __name__ == '__main__': if __name__ == '__main__':
"""
程序入口
检查命令行参数并初始化 AkgMessager 实例
"""
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
if len(sys.argv) != 3: if len(sys.argv) != 3:
raise Exception('Incorrect argv: {}'.format(sys.argv)) raise Exception('Incorrect argv: {}'.format(sys.argv))

@ -26,13 +26,14 @@ class AscendMessager(Messager):
Ascend Messager Ascend Messager
It works as a server, communicating with c++ client. It works as a server, communicating with c++ client.
""" """
# 初始化方法
def __init__(self, fdin, fdout): def __init__(self, fdin, fdout):
super().__init__(fdin, fdout) super().__init__(fdin, fdout)
get_logger().info("[TRACE] Ascend Messager init...") get_logger().info("[TRACE] Ascend Messager init...")
self.tbe_builder = TbeJobManager() self.tbe_builder = TbeJobManager()
self.akg_builder = AkgBuilder("ASCEND") self.akg_builder = AkgBuilder("ASCEND")
# 处理与远程客户端的通信
def handle(self): def handle(self):
""" """
Communicate with remote client. Communicate with remote client.
@ -60,6 +61,7 @@ class AscendMessager(Messager):
self.send_ack(False) self.send_ack(False)
self.exit() self.exit()
# 退出方法
def exit(self): def exit(self):
self.tbe_builder.reset() self.tbe_builder.reset()
get_logger().info("[TRACE] Ascend Messager Exit...") get_logger().info("[TRACE] Ascend Messager Exit...")

@ -22,6 +22,21 @@ def cell_attr_register(fn=None, attrs=None):
""" """
Cell init attributes register. Cell init attributes register.
Args:
fn (function, optional): The __init__ function of the cell. Defaults to None.
attrs (list(string) | string, optional): A list of attributes to register.
Can be a list of strings or a single string. Defaults to None.
Returns:
function: The original function wrapped with attribute registration.
该函数用于注册cell类的初始化属性
通过装饰器模式将cell类的__init__函数的参数保存为operator的属性
如果未提供fn参数则返回装饰器函数wrap_cell否则返回包装后的__init__函数
"""
"""
Cell init attributes register.
Registering the decorator of the built-in operator cell __init__ Registering the decorator of the built-in operator cell __init__
function will add save all the parameters of __init__ as operator attributes. function will add save all the parameters of __init__ as operator attributes.
@ -34,8 +49,38 @@ def cell_attr_register(fn=None, attrs=None):
""" """
def wrap_cell(fn): def wrap_cell(fn):
"""
装饰器函数用于记录类的初始化参数
Args:
fn (function): 需要被装饰的函数
Returns:
function: 返回一个新的函数该函数在调用时会记录传递给fn函数的参数
"""
@wraps(fn) @wraps(fn)
def deco(self, *args, **kwargs): def deco(self, *args, **kwargs):
"""
这是一个装饰器函数用于记录类的初始化参数
Args:
self: 类实例对象
*args: 传递给被装饰函数的可变位置参数
**kwargs: 传递给被装饰函数的可变关键字参数
attrs: 可选参数指定要记录的属性可以是字符串或字符串列表
Returns:
None
Raises:
ValueError: 如果attrs不是字符串或字符串列表或者attrs中的元素不是字符串时抛出
该函数的主要作用是在类实例初始化时记录传递给__init__方法的参数
如果attrs为None则记录所有传递给__init__方法的参数不包括self
如果attrs为字符串或字符串列表则只记录指定的属性
记录的参数将被保存为实例的cell_init_args属性格式为"类名+参数列表"
"""
arguments = [] arguments = []
if attrs is None: if attrs is None:
bound_args = inspect.signature(fn).bind(self, *args, **kwargs) bound_args = inspect.signature(fn).bind(self, *args, **kwargs)

@ -19,16 +19,25 @@ accumulation and so on.
Note: Note:
This feature is a beta feature, and we are still improving its functionality. This feature is a beta feature, and we are still improving its functionality.
""" """
# 从当前包的boost模块导入AutoBoost类
from .boost import AutoBoost from .boost import AutoBoost
# 从当前包的base模块导入OptimizerProcess和ParameterProcess类
from .base import OptimizerProcess, ParameterProcess from .base import OptimizerProcess, ParameterProcess
# 从当前包的boost_cell_wrapper模块导入BoostTrainOneStepCell和BoostTrainOneStepWithLossScaleCell类
from .boost_cell_wrapper import BoostTrainOneStepCell, BoostTrainOneStepWithLossScaleCell from .boost_cell_wrapper import BoostTrainOneStepCell, BoostTrainOneStepWithLossScaleCell
# 从当前包的less_batch_normalization模块导入LessBN类
from .less_batch_normalization import LessBN from .less_batch_normalization import LessBN
# 从当前包的grad_freeze模块导入GradientFreeze, FreezeOpt和freeze_cell类或函数
from .grad_freeze import GradientFreeze, FreezeOpt, freeze_cell from .grad_freeze import GradientFreeze, FreezeOpt, freeze_cell
# 从当前包的grad_accumulation模块导入GradientAccumulation类
from .grad_accumulation import GradientAccumulation from .grad_accumulation import GradientAccumulation
# 从当前包的adasum模块导入AdaSum类
from .adasum import AdaSum from .adasum import AdaSum
# 从当前包的dim_reduce模块导入DimReduce类
from .dim_reduce import DimReduce from .dim_reduce import DimReduce
# 定义一个列表,包含所有要公开的模块成员
__all__ = ['AutoBoost', __all__ = ['AutoBoost',
'OptimizerProcess', 'ParameterProcess', 'OptimizerProcess', 'ParameterProcess',
'BoostTrainOneStepCell', 'BoostTrainOneStepWithLossScaleCell', 'BoostTrainOneStepCell', 'BoostTrainOneStepWithLossScaleCell',

@ -35,6 +35,7 @@ _update_parameters = C.MultitypeFuncGraph("update_parameters")
@_update_parameters.register("Tensor", "Tensor", "Tensor", "Tensor") @_update_parameters.register("Tensor", "Tensor", "Tensor", "Tensor")
def _update_parameters_after_broadcast(delta_weight, update_delta_weight, parameter, old_parameter): def _update_parameters_after_broadcast(delta_weight, update_delta_weight, parameter, old_parameter):
"""更新参数的函数在广播后应用delta_weight来更新参数."""
shape = F.shape(delta_weight) shape = F.shape(delta_weight)
update_delta_weight = P.Reshape()(update_delta_weight, shape) update_delta_weight = P.Reshape()(update_delta_weight, shape)
new_parameter = old_parameter - update_delta_weight new_parameter = old_parameter - update_delta_weight
@ -42,18 +43,20 @@ def _update_parameters_after_broadcast(delta_weight, update_delta_weight, parame
def _send_before_receive(send_part, send, recv): def _send_before_receive(send_part, send, recv):
"""在接收之前发送数据的辅助函数."""
send_ok = send(send_part) send_ok = send(send_part)
return recv(send_ok) return recv(send_ok)
def _receive_before_send(send_part, send, recv): def _receive_before_send(send_part, send, recv):
"""在发送之前接收数据的辅助函数."""
receive_ok = recv(send_part) receive_ok = recv(send_part)
send_part = F.depend(send_part, receive_ok) send_part = F.depend(send_part, receive_ok)
return F.depend(receive_ok, send(send_part)) return F.depend(receive_ok, send(send_part))
def _send_recv_res(left_send, recv_part, local_part, allreduce, parameter_divisibility, allreduce_node_num): def _send_recv_res(left_send, recv_part, local_part, allreduce, parameter_divisibility, allreduce_node_num):
"""send result and receive result.""" """发送结果并接收结果的辅助函数."""
if parameter_divisibility: if parameter_divisibility:
recv_part = P.Squeeze()(recv_part) recv_part = P.Squeeze()(recv_part)
local_part = F.depend(local_part, recv_part) local_part = F.depend(local_part, recv_part)
@ -83,7 +86,7 @@ _adasum_opt_forward = C.MultitypeFuncGraph("adasum_opt_forward")
@_adasum_opt_forward.register("Bool", "Function", "Bool", "Int64", "Function", "Function", "Tensor") @_adasum_opt_forward.register("Bool", "Function", "Bool", "Int64", "Function", "Function", "Tensor")
def _adasum_opt_forward_process(left_send, allreduce, parameter_divisibility, allreduce_node_num, send, recv, delta_w): def _adasum_opt_forward_process(left_send, allreduce, parameter_divisibility, allreduce_node_num, send, recv, delta_w):
"""adasum optimizer process.""" """adaSum优化器的前向过程处理函数."""
if parameter_divisibility: if parameter_divisibility:
delta_w = P.Squeeze()(delta_w) delta_w = P.Squeeze()(delta_w)
ori_len = F.shape(delta_w)[0] ori_len = F.shape(delta_w)[0]
@ -117,7 +120,7 @@ _adasum_opt_rollback = C.MultitypeFuncGraph("adasum_opt_rollback")
@_adasum_opt_rollback.register("Bool", "Bool", "Tensor", "Function", "Function") @_adasum_opt_rollback.register("Bool", "Bool", "Tensor", "Function", "Function")
def _adasum_opt_rollback_process(left_send, parameter_divisibility, delta_w, send, recv): def _adasum_opt_rollback_process(left_send, parameter_divisibility, delta_w, send, recv):
"""adasum optimizer rollback process.""" """adaSum优化器的回滚处理函数."""
if parameter_divisibility: if parameter_divisibility:
if left_send: if left_send:
recv_part = _send_before_receive(delta_w, send, recv) recv_part = _send_before_receive(delta_w, send, recv)
@ -139,24 +142,24 @@ def _adasum_opt_rollback_process(left_send, parameter_divisibility, delta_w, sen
class AdaSum(Cell): class AdaSum(Cell):
r""" r"""
The Adaptive Summation, or AdaSum, is a novel algorithm for improving distributed data 自适应加法AdaSum是一种新算法用于改善深度学习模型的分布式数据并行训练
parallel training of Deep Learning models.
Args: Args:
rank (int): Rank number. rank (int): 排名编号
device_number (int): Device number. device_number (int): 设备数量
group_number (int): Group number. group_number (int): 组数量
parameter_tuple (Tuple(Parameter)): Tuple of parameters. parameter_tuple (Tuple(Parameter)): 参数元组
Inputs: Inputs:
- **delta_weights** (Tuple(Tensor)) - Tuple of gradients. - **delta_weights** (Tuple(Tensor)) - 梯度的元组
- **parameters** (Tuple(Parameter)) - Tuple of current parameters. - **parameters** (Tuple(Parameter)) - 当前参数的元组
- **old_parameters** (Tuple(Parameter)) - Tuple of last parameters. - **old_parameters** (Tuple(Parameter)) - 上一参数的元组
Outputs: Outputs:
- **adasum_parameters** (Tuple(Tensor)) - Tuple of parameters after adasum process. - **adasum_parameters** (Tuple(Tensor)) - 经过adasum处理后的参数元组
""" """
def __init__(self, rank, device_number, group_number, parameter_tuple): def __init__(self, rank, device_number, group_number, parameter_tuple):
"""AdaSum类的初始化函数."""
super(AdaSum, self).__init__() super(AdaSum, self).__init__()
self.rank = rank self.rank = rank
self.device_number = device_number self.device_number = device_number
@ -166,7 +169,7 @@ class AdaSum(Cell):
self.hyper_map = C.HyperMap() self.hyper_map = C.HyperMap()
def _generate_communication_op(self): def _generate_communication_op(self):
"""generate communication op.""" """生成通信操作的私有方法."""
self.calc_times = int(math.log(self.group_number, 2)) self.calc_times = int(math.log(self.group_number, 2))
self.send_node = [] self.send_node = []
self.send_list_forward = [] self.send_list_forward = []
@ -267,7 +270,7 @@ class AdaSum(Cell):
self.sync_barrier = P.AllReduce("sum", group=broadcast_group_name) self.sync_barrier = P.AllReduce("sum", group=broadcast_group_name)
def _get_delta_weights_info(self, last_delta_weights): def _get_delta_weights_info(self, last_delta_weights):
"""get delta weights info.""" """获取delta权重信息的私有方法."""
half_delta_weights = [] half_delta_weights = []
if last_delta_weights: if last_delta_weights:
half_delta_weights = last_delta_weights half_delta_weights = last_delta_weights
@ -294,12 +297,14 @@ class AdaSum(Cell):
return left_delta_weights, right_delta_weights, delta_weights_divisibility return left_delta_weights, right_delta_weights, delta_weights_divisibility
def _hash(self, step, target, weights_index): def _hash(self, step, target, weights_index):
"""计算哈希值的私有方法."""
target = "tag" + str(step) + str(target) + str(weights_index) target = "tag" + str(step) + str(target) + str(weights_index)
target_hash = hashlib.sha1(target.encode()).hexdigest() target_hash = hashlib.sha1(target.encode()).hexdigest()
hash_res = int(int(target_hash, 16) % MAX_NUM_HASH) hash_res = int(int(target_hash, 16) % MAX_NUM_HASH)
return hash_res return hash_res
def construct(self, delta_weights, parameters, old_parameters): def construct(self, delta_weights, parameters, old_parameters):
"""构建方法用于执行adaSum优化过程."""
forward_weights = [delta_weights] forward_weights = [delta_weights]
for i in range(self.calc_times): for i in range(self.calc_times):
process_weights = self.hyper_map(F.partial(_adasum_opt_forward, self.send_node[i], self.allreduce_list[i]), process_weights = self.hyper_map(F.partial(_adasum_opt_forward, self.send_node[i], self.allreduce_list[i]),

@ -19,23 +19,30 @@ from mindspore import log as logger
from ._hccl_management import load_lib as hccl_load_lib from ._hccl_management import load_lib as hccl_load_lib
from .._c_expression import get_rank_id, get_rank_size from .._c_expression import get_rank_id, get_rank_size
# 检查HCCL是否可用
_HCCL_AVAILABLE = False _HCCL_AVAILABLE = False
# 检查HCCL测试是否可用
_HCCL_TEST_AVAILABLE = False _HCCL_TEST_AVAILABLE = False
# 检查NCCL是否可用
_NCCL_AVAILABLE = False _NCCL_AVAILABLE = False
# 检查MPI是否可用
_MPI_AVAILABLE = False _MPI_AVAILABLE = False
try: try:
# 尝试导入mindspore._ms_mpi如果成功则NCCL可用
import mindspore._ms_mpi as mpi import mindspore._ms_mpi as mpi
_NCCL_AVAILABLE = True _NCCL_AVAILABLE = True
except ImportError: except ImportError:
# 如果导入失败则NCCL不可用
_NCCL_AVAILABLE = False _NCCL_AVAILABLE = False
# 尝试加载 HCCL 库,如果成功则设置 _HCCL_AVAILABLE 为 True否则捕获 RuntimeError 并设置为 False
try: try:
hccl_load_lib() hccl_load_lib()
_HCCL_AVAILABLE = True _HCCL_AVAILABLE = True
except RuntimeError: except RuntimeError:
_HCCL_AVAILABLE = False _HCCL_AVAILABLE = False
# 如果 HCCL 可用,则导入 _hccl_management 并尝试导入 mindspore._ascend_mpi如果成功则设置 _MPI_AVAILABLE 为 True否则捕获 ImportError 并设置为 False
if _HCCL_AVAILABLE: if _HCCL_AVAILABLE:
from . import _hccl_management as hccl from . import _hccl_management as hccl
try: try:
@ -43,6 +50,7 @@ if _HCCL_AVAILABLE:
_MPI_AVAILABLE = True _MPI_AVAILABLE = True
except ImportError: except ImportError:
_MPI_AVAILABLE = False _MPI_AVAILABLE = False
# 如果 HCCL 不可用,则尝试导入 hccl_test.manage.api如果成功则设置 _HCCL_AVAILABLE 和 _HCCL_TEST_AVAILABLE 为 True否则捕获 ImportError 并设置 _HCCL_AVAILABLE 为 False
else: else:
try: try:
import hccl_test.manage.api as hccl import hccl_test.manage.api as hccl
@ -51,11 +59,10 @@ else:
except ImportError: except ImportError:
_HCCL_AVAILABLE = False _HCCL_AVAILABLE = False
# 定义 HCCL 和 NCCL 的通信组名称常量
HCCL_WORLD_COMM_GROUP = "hccl_world_group" HCCL_WORLD_COMM_GROUP = "hccl_world_group"
NCCL_WORLD_COMM_GROUP = "nccl_world_group" NCCL_WORLD_COMM_GROUP = "nccl_world_group"
class Backend: class Backend:
""" """
Class for available backends. Class for available backends.
@ -82,13 +89,17 @@ class Backend:
def __new__(cls, name): def __new__(cls, name):
"""Create instance object of Backend.""" """Create instance object of Backend."""
# 检查传入的name是否为字符串类型
if not isinstance(name, str): if not isinstance(name, str):
raise TypeError("For 'Backend', the class variable 'name' must be a string, " raise TypeError("For 'Backend', the class variable 'name' must be a string, "
"but got the type : {}".format(type(name))) "but got the type : {}".format(type(name)))
# 获取对应name的大写形式的Backend类属性值如果不存在则返回Backend.UNDEFINED
value = getattr(Backend, name.upper(), Backend.UNDEFINED) value = getattr(Backend, name.upper(), Backend.UNDEFINED)
# 如果获取到的值是Backend.UNDEFINED说明传入的name不被支持
if value == Backend.UNDEFINED: if value == Backend.UNDEFINED:
raise ValueError("For 'Backend', the class variable 'name' {} is not supported, " raise ValueError("For 'Backend', the class variable 'name' {} is not supported, "
"please use hccl or nccl.".format(name)) "please use hccl or nccl.".format(name))
# 返回获取到的Backend类属性值
return value return value
DEFAULT_BACKEND = Backend("hccl") DEFAULT_BACKEND = Backend("hccl")
@ -103,36 +114,35 @@ class GlobalComm:
""" """
BACKEND = DEFAULT_BACKEND BACKEND = DEFAULT_BACKEND
WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
INITED = False INITED = False # 标记全局通信是否已初始化
CHECK_ENVS = True CHECK_ENVS = True # 标记是否需要检查通信环境变量
class _ExistingGroup: class _ExistingGroup:
""" """
The communication groups which exist in the progress. 用于表示在程序运行过程中存在的通信组
""" """
ITEMS = {} ITEMS = {} # 存储通信组的字典,键为通信组的标识符,值为通信组对象
def is_hccl_available(): def is_hccl_available():
""" """
Check HCCL api is available. 检查HCCL API是否可用
Returns: Returns:
Boolean. Return whether HCCL is available or not. Boolean: 返回HCCL是否可用
""" """
return _HCCL_AVAILABLE return _HCCL_AVAILABLE # 返回一个布尔值指示HCCL是否可用
def is_mpi_available(): def is_mpi_available():
""" """
Check HCCL & MPI api is available. 检查HCCL和MPI API是否可用
Returns: Returns:
Boolean. Return whether HCCL & MPI is available or not. Boolean: 返回HCCL和MPI是否同时可用
""" """
return _MPI_AVAILABLE return _MPI_AVAILABLE # 返回一个布尔值指示HCCL和MPI是否同时可用
def is_nccl_available(): def is_nccl_available():
""" """
@ -158,19 +168,25 @@ def check_parameter_available(func):
Wrapper. If not available, raise Error. Wrapper. If not available, raise Error.
""" """
def wrapper(*args, **kargs): def wrapper(*args, **kargs):
# 如果当前角色是参数服务器或者调度器,直接调用被装饰的函数
if _is_role_pserver() or _is_role_sched(): if _is_role_pserver() or _is_role_sched():
return func(*args, **kargs) return func(*args, **kargs)
# 检查分布式通信是否已经初始化,未初始化则抛出异常
if not GlobalComm.INITED: if not GlobalComm.INITED:
raise RuntimeError("Distributed Communication has not been inited") raise RuntimeError("Distributed Communication has not been inited")
# 获取参数组默认为None
group = None group = None
# 检查关键字参数中是否包含"group",并进行类型检查
if "group" in kargs.keys(): if "group" in kargs.keys():
group = kargs.get("group") group = kargs.get("group")
if group is not None and not isinstance(group, str): if group is not None and not isinstance(group, str):
raise TypeError("The parameter 'group' should be str or None, " raise TypeError("The parameter 'group' should be str or None, "
"but got the type : {}".format(type(group))) "but got the type : {}".format(type(group)))
# 获取后端默认为None
if "backend" in kargs.keys(): if "backend" in kargs.keys():
backend = kargs.get("backend") backend = kargs.get("backend")
# 检查后端是否可用,不可用则抛出异常
if backend is Backend.HCCL and not is_hccl_available(): if backend is Backend.HCCL and not is_hccl_available():
raise RuntimeError("Distributed Communication doesn't have HCCL built in") raise RuntimeError("Distributed Communication doesn't have HCCL built in")
if backend is Backend.HCCL_MPI and not is_mpi_available(): if backend is Backend.HCCL_MPI and not is_mpi_available():
@ -178,15 +194,16 @@ def check_parameter_available(func):
if backend is Backend.NCCL and not is_nccl_available(): if backend is Backend.NCCL and not is_nccl_available():
raise RuntimeError("Distributed Communication doesn't have NCCL built in") raise RuntimeError("Distributed Communication doesn't have NCCL built in")
# 如果未指定group根据backend设置默认的group
if group is None: if group is None:
if backend is Backend.HCCL or Backend.HCCL_MPI: if backend is Backend.HCCL or backend is Backend.HCCL_MPI:
group = HCCL_WORLD_COMM_GROUP group = HCCL_WORLD_COMM_GROUP
elif backend is Backend.NCCL: elif backend is Backend.NCCL:
group = NCCL_WORLD_COMM_GROUP group = NCCL_WORLD_COMM_GROUP
# 调用被装饰的函数
return func(*args, **kargs) return func(*args, **kargs)
return wrapper return wrapper
@check_parameter_available @check_parameter_available
def _get_rank_helper(group, backend): def _get_rank_helper(group, backend):
""" """
@ -202,10 +219,13 @@ def _get_rank_helper(group, backend):
Returns: Returns:
Integer. The local rank id of the calling process. Integer. The local rank id of the calling process.
""" """
# 辅助函数,用于根据不同的后端和组获取 rank_id
# 获取当前角色的rank_id如果是参数服务器或调度器角色rank_id设为0并返回
rank_id = None rank_id = None
if _is_role_pserver() or _is_role_sched(): if _is_role_pserver() or _is_role_sched():
rank_id = 0 rank_id = 0
return rank_id return rank_id
# 根据不同的后端获取rank_id
if backend == Backend.HCCL_MPI: if backend == Backend.HCCL_MPI:
rank_id = mpi.get_rank_id(group) rank_id = mpi.get_rank_id(group)
elif backend == Backend.HCCL: elif backend == Backend.HCCL:
@ -216,6 +236,7 @@ def _get_rank_helper(group, backend):
elif backend == Backend.NCCL: elif backend == Backend.NCCL:
rank_id = get_rank_id(group) rank_id = get_rank_id(group)
else: else:
# 如果后端不被支持抛出ValueError异常
raise ValueError("For '_get_rank_helper', the argument 'backend' {} is not supported, " raise ValueError("For '_get_rank_helper', the argument 'backend' {} is not supported, "
"please use hccl_mpi, hccl or nccl.".format(backend)) "please use hccl_mpi, hccl or nccl.".format(backend))
return rank_id return rank_id
@ -236,22 +257,30 @@ def _get_local_rank_helper(group, backend):
Returns: Returns:
Integer. The local rank id of the calling process. Integer. The local rank id of the calling process.
""" """
# 获取当前进程的rank id根据不同的后端和组进行处理
rank_id = None rank_id = None
# 根据不同的后端选择获取rank_id的方法
if backend == Backend.HCCL_MPI: if backend == Backend.HCCL_MPI:
# 使用HCCL MPI后端时通过mpi.get_rank_id获取rank_id
rank_id = mpi.get_rank_id(group) rank_id = mpi.get_rank_id(group)
elif backend == Backend.HCCL: elif backend == Backend.HCCL:
# 使用HCCL后端时根据group的不同选择获取rank_id的方法
if group == HCCL_WORLD_COMM_GROUP: if group == HCCL_WORLD_COMM_GROUP:
# 如果group是HCCL_WORLD_COMM_GROUP则使用hccl.get_local_rank_id获取rank_id
rank_id = hccl.get_local_rank_id() rank_id = hccl.get_local_rank_id()
else: else:
# 对于其他group同样使用hccl.get_local_rank_id获取rank_id
rank_id = hccl.get_local_rank_id(group) rank_id = hccl.get_local_rank_id(group)
elif backend == Backend.NCCL: elif backend == Backend.NCCL:
# 如果使用NCCL后端当前不支持get_local_rank_id方法抛出异常
raise RuntimeError("Nccl doesn't support get_local_rank_id now.") raise RuntimeError("Nccl doesn't support get_local_rank_id now.")
else: else:
# 如果backend既不是HCCL_MPI也不是HCCL抛出异常表示不支持的backend
raise ValueError("For '_get_local_rank_helper', the argument 'backend' {} is not supported, " raise ValueError("For '_get_local_rank_helper', the argument 'backend' {} is not supported, "
"please use hccl_mpi or hccl.".format(backend)) "please use hccl_mpi or hccl.".format(backend))
# 返回获取到的rank_id
return rank_id return rank_id
@check_parameter_available @check_parameter_available
def _get_size_helper(group, backend): def _get_size_helper(group, backend):
""" """
@ -268,9 +297,12 @@ def _get_size_helper(group, backend):
Integer. The rank size of specified group. Integer. The rank size of specified group.
""" """
size = None size = None
# 如果当前角色是参数服务器或调度器则将size设为1并返回
if _is_role_pserver() or _is_role_sched(): if _is_role_pserver() or _is_role_sched():
size = 1 size = 1
return size return size
# 根据不同的后端设置size的值
# 根据不同的后端获取组的大小
if backend == Backend.HCCL_MPI: if backend == Backend.HCCL_MPI:
size = mpi.get_rank_size(group) size = mpi.get_rank_size(group)
elif backend == Backend.HCCL: elif backend == Backend.HCCL:
@ -302,19 +334,23 @@ def _get_local_size_helper(group, backend):
Integer. The local rank size where the calling process is being within specified group. Integer. The local rank size where the calling process is being within specified group.
""" """
size = None size = None
# 根据不同的后端获取本地进程组的大小
if backend == Backend.HCCL: if backend == Backend.HCCL:
# 如果组是全局通信组,则获取全局通信组的本地排名大小
if group == HCCL_WORLD_COMM_GROUP: if group == HCCL_WORLD_COMM_GROUP:
size = hccl.get_local_rank_size() size = hccl.get_local_rank_size()
# 否则获取指定组的本地排名大小
else: else:
size = hccl.get_local_rank_size(group) size = hccl.get_local_rank_size(group)
# NCCL后端不支持获取本地排名大小抛出异常
elif backend == Backend.NCCL: elif backend == Backend.NCCL:
raise RuntimeError("Nccl doesn't support get_local_rank_size now.") raise RuntimeError("Nccl doesn't support get_local_rank_size now.")
# 对于不支持的后端,抛出异常
else: else:
raise ValueError("For '_get_local_size_helper', the argument 'backend' {} is not supported, " raise ValueError("For '_get_local_size_helper', the argument 'backend' {} is not supported, "
"please use hccl.".format(backend)) "please use hccl.".format(backend))
return size return size
@check_parameter_available @check_parameter_available
def _get_world_rank_from_group_rank_helper(group, group_rank_id, backend): def _get_world_rank_from_group_rank_helper(group, group_rank_id, backend):
""" """
@ -333,21 +369,26 @@ def _get_world_rank_from_group_rank_helper(group, group_rank_id, backend):
Integer. A rank id in world communication group. Integer. A rank id in world communication group.
""" """
world_rank_id = None world_rank_id = None
# 检查 group_rank_id 是否为整数类型,如果不是则抛出 TypeError
if not isinstance(group_rank_id, int): if not isinstance(group_rank_id, int):
raise TypeError("For 'get_world_rank_from_group_rank', the argument 'group_rank_id' must be" raise TypeError("For 'get_world_rank_from_group_rank', the argument 'group_rank_id' must be"
" type of int, but got 'group_rank_id' type : {}.".format(type(group_rank_id))) " type of int, but got 'group_rank_id' type : {}.".format(type(group_rank_id)))
# 根据不同的后端选择不同的逻辑处理方式
if backend == Backend.HCCL: if backend == Backend.HCCL:
# 如果在 GPU 上使用 HCCL但 group 参数为 HCCL_WORLD_COMM_GROUP则抛出 ValueError
if group == HCCL_WORLD_COMM_GROUP: if group == HCCL_WORLD_COMM_GROUP:
raise ValueError("For 'get_world_rank_from_group_rank' on GPU, the argument 'group' " raise ValueError("For 'get_world_rank_from_group_rank' on GPU, the argument 'group' "
"should be 'NCCL_WORLD_COMM_GROUP', but got 'HCCL_WORLD_COMM_GROUP'.") "should be 'NCCL_WORLD_COMM_GROUP', but got 'HCCL_WORLD_COMM_GROUP'.")
# 调用 hccl.get_world_rank_from_group_rank 方法获取 world_rank_id
world_rank_id = hccl.get_world_rank_from_group_rank(group, group_rank_id) world_rank_id = hccl.get_world_rank_from_group_rank(group, group_rank_id)
elif backend == Backend.NCCL: elif backend == Backend.NCCL:
# 如果使用 NCCL则抛出 RuntimeError 表示不支持该操作
raise RuntimeError("Nccl doesn't support get_world_rank_from_group_rank now.") raise RuntimeError("Nccl doesn't support get_world_rank_from_group_rank now.")
else: else:
# 如果 backend 参数不支持,则抛出 ValueError 表示不支持该后端
raise ValueError("The argument 'backend' {} is not supported, please use hccl.".format(backend)) raise ValueError("The argument 'backend' {} is not supported, please use hccl.".format(backend))
# 返回获取的 world_rank_id
return world_rank_id return world_rank_id
@check_parameter_available @check_parameter_available
def _get_group_rank_from_world_rank_helper(world_rank_id, group, backend): def _get_group_rank_from_world_rank_helper(world_rank_id, group, backend):
""" """
@ -366,21 +407,27 @@ def _get_group_rank_from_world_rank_helper(world_rank_id, group, backend):
Integer. A rank id in user communication group. Integer. A rank id in user communication group.
""" """
group_rank_id = None group_rank_id = None
# 检查 world_rank_id 是否为整数类型,如果不是则抛出 TypeError
if not isinstance(world_rank_id, int): if not isinstance(world_rank_id, int):
raise TypeError("For 'get_group_rank_from_world_rank', the argument 'world_rank_id' must be type of int, " raise TypeError("For 'get_group_rank_from_world_rank', the argument 'world_rank_id' must be type of int, "
"but got 'world_rank_id' type : {}.".format(type(world_rank_id))) "but got 'world_rank_id' type : {}.".format(type(world_rank_id)))
# 根据不同的后端处理获取 group_rank_id 的逻辑
if backend == Backend.HCCL: if backend == Backend.HCCL:
# 检查 GPU 后端的 group 参数是否正确,如果不正确则抛出 ValueError
if group == HCCL_WORLD_COMM_GROUP: if group == HCCL_WORLD_COMM_GROUP:
raise ValueError("For 'get_group_rank_from_world_rank' on GPU, the argument 'group' " raise ValueError("For 'get_group_rank_from_world_rank' on GPU, the argument 'group' "
"should be 'NCCL_WORLD_COMM_GROUP', but got 'HCCL_WORLD_COMM_GROUP'.") "should be 'NCCL_WORLD_COMM_GROUP', but got 'HCCL_WORLD_COMM_GROUP'.")
# 调用 hccl 模块的函数获取 group_rank_id
group_rank_id = hccl.get_group_rank_from_world_rank(world_rank_id, group) group_rank_id = hccl.get_group_rank_from_world_rank(world_rank_id, group)
# NCCL 后端不支持此操作,抛出 RuntimeError
elif backend == Backend.NCCL: elif backend == Backend.NCCL:
raise RuntimeError("Nccl doesn't support get_group_rank_from_world_rank now.") raise RuntimeError("Nccl doesn't support get_group_rank_from_world_rank now.")
# 如果后端不是 HCCL 或 NCCL则抛出 ValueError 表示不支持的后端
else: else:
raise ValueError("The argument 'backend' {} is not supported, please use hccl.".format(backend)) raise ValueError("The argument 'backend' {} is not supported, please use hccl.".format(backend))
# 返回获取到的 group_rank_id
return group_rank_id return group_rank_id
@check_parameter_available @check_parameter_available
def _create_group_helper(group, rank_ids, backend): def _create_group_helper(group, rank_ids, backend):
""" """
@ -395,34 +442,46 @@ def _create_group_helper(group, rank_ids, backend):
TypeError: If rank_ids is not a list. TypeError: If rank_ids is not a list.
ValueError: If rank_ids size is not larger than 1 or rank_ids has duplicate data or backend is invalid. ValueError: If rank_ids size is not larger than 1 or rank_ids has duplicate data or backend is invalid.
""" """
# 检查组是否已存在
if group in _ExistingGroup.ITEMS.keys(): if group in _ExistingGroup.ITEMS.keys():
# 如果组已存在且提供的rank_ids与存储的不一致抛出异常
if rank_ids != _ExistingGroup.ITEMS[group]: if rank_ids != _ExistingGroup.ITEMS[group]:
raise ValueError("The group {} has been created, the rank_list is {}, " raise ValueError("The group {} has been created, the rank_list is {}, "
"but current rank_list for the group is {}". "but current rank_list for the group is {}".
format(group, _ExistingGroup.ITEMS[group], rank_ids)) format(group, _ExistingGroup.ITEMS[group], rank_ids))
# 记录警告信息,提示组已存在
logger.warning("%r group has existed.", group) logger.warning("%r group has existed.", group)
return return
# 根据不同的后端创建组
if backend == Backend.HCCL: if backend == Backend.HCCL:
# 检查rank_ids是否为列表类型
if not isinstance(rank_ids, list): if not isinstance(rank_ids, list):
raise TypeError("For 'create_group', the argument 'rank_ids' must be type of list, " raise TypeError("For 'create_group', the argument 'rank_ids' must be type of list, "
"but got 'rank_ids' type : {}.".format(type(rank_ids))) "but got 'rank_ids' type : {}.".format(type(rank_ids)))
# 检查rank_ids的长度是否大于1
rank_size = len(rank_ids) rank_size = len(rank_ids)
if rank_size < 1: if rank_size < 1:
raise ValueError("For 'create_group', the argument 'rank_ids' size should be greater than 1, " raise ValueError("For 'create_group', the argument 'rank_ids' size should be greater than 1, "
"but got 'rank_ids' size : {}.".format(len(rank_ids))) "but got 'rank_ids' size : {}.".format(len(rank_ids)))
# 检查rank_ids中是否有重复的元素
if len(rank_ids) - len(list(set(rank_ids))) > 0: if len(rank_ids) - len(list(set(rank_ids))) > 0:
raise ValueError("List rank_ids in Group {} has duplicate data!".format(group)) raise ValueError("List rank_ids in Group {} has duplicate data!".format(group))
# 使用HCCL创建组
hccl.create_group(group, rank_size, rank_ids) hccl.create_group(group, rank_size, rank_ids)
elif backend == Backend.HCCL_MPI: elif backend == Backend.HCCL_MPI:
# 使用HCCL_MPI创建组
mpi.create_group(group, rank_ids) mpi.create_group(group, rank_ids)
elif backend == Backend.NCCL: elif backend == Backend.NCCL:
# NCCL暂不支持创建组抛出异常
raise RuntimeError("Nccl doesn't support create_group now.") raise RuntimeError("Nccl doesn't support create_group now.")
else: else:
# 如果后端不支持,抛出异常
raise ValueError("The context configuration parameter 'backend' {} is not supported, " raise ValueError("The context configuration parameter 'backend' {} is not supported, "
"please use hccl.".format(backend)) "please use hccl.".format(backend))
_ExistingGroup.ITEMS[group] = rank_ids
# 将新创建的组及其rank_ids添加到_existingGroup中
_ExistingGroup.ITEMS[group] = rank_ids
@check_parameter_available @check_parameter_available
def _destroy_group_helper(group, backend): def _destroy_group_helper(group, backend):
""" """
@ -435,12 +494,17 @@ def _destroy_group_helper(group, backend):
Raises: Raises:
ValueError: If group is "hccl_world_group" or backend is invalid. ValueError: If group is "hccl_world_group" or backend is invalid.
""" """
# 根据后端类型销毁通信组
if backend == Backend.HCCL: if backend == Backend.HCCL:
# 检查是否为 HCCL 的全局通信组
if group == HCCL_WORLD_COMM_GROUP: if group == HCCL_WORLD_COMM_GROUP:
raise ValueError("The hccl_world_group does not support destruction.") raise ValueError("The hccl_world_group does not support destruction.")
# 销毁指定的 HCCL 通信组
hccl.destroy_group(group) hccl.destroy_group(group)
elif backend == Backend.NCCL: elif backend == Backend.NCCL:
# 当前 NCCL 后端不支持销毁通信组
raise RuntimeError("Nccl doesn't support destroy_group now.") raise RuntimeError("Nccl doesn't support destroy_group now.")
else: else:
# 抛出错误,表示不支持的后端类型
raise ValueError("The context configuration parameter 'backend' {} is not supported, " raise ValueError("The context configuration parameter 'backend' {} is not supported, "
"please use hccl.".format(backend)) "please use hccl.".format(backend))

@ -24,7 +24,7 @@ MAX_RANK_NUM = 4096
HCCL_LIB = 'libhccl_plugin.so' HCCL_LIB = 'libhccl_plugin.so'
HCCL_LIB_CTYPES = "" HCCL_LIB_CTYPES = ""
# 检查集体通信组的名称是否合法
def check_group(group): def check_group(group):
""" """
A function that check if a collection communication group is legal. A function that check if a collection communication group is legal.
@ -41,23 +41,31 @@ def check_group(group):
raise TypeError("The type of communication group name must be type of string, " raise TypeError("The type of communication group name must be type of string, "
"but got 'group' type : {}.".format(type(group))) "but got 'group' type : {}.".format(type(group)))
# 检查集体通信中的排名编号是否合法
def check_rank_num(rank_num): def check_rank_num(rank_num):
""" """
A function that check if a collection communication rank number is legal.If not raise error. 检查通信集合中的排名编号是否合法如果不合法则抛出错误
参数:
rank_num: 需要检查的排名编号预期为整数类型
Returns: Returns:
None None: 该函数不返回任何值但可能会抛出异常
""" """
if isinstance(rank_num, (int)): # 检查 rank_num 是否为整数类型
if isinstance(rank_num, int):
# 检查 rank_num 是否在合法范围内大于0且小于等于 MAX_RANK_NUM
if rank_num > MAX_RANK_NUM or rank_num <= 0: if rank_num > MAX_RANK_NUM or rank_num <= 0:
raise ValueError("For 'create_group', the size of argument 'rand_ids' should be greater than 0 and" # 如果 rank_num 不在合法范围内,抛出 ValueError 异常,并提供详细的错误信息
"less than {}, but got the size of 'rank_ids' : {}.".format(MAX_RANK_NUM, rank_num)) raise ValueError("对于 'create_group' 函数,参数 'rank_ids' 的大小必须大于0且"
"小于等于 {},但得到的 'rank_ids' 的大小为: {}".format(MAX_RANK_NUM, rank_num))
else: else:
raise TypeError("The argument 'rank_num' must be type of int, " # 如果 rank_num 不是整数类型,抛出 TypeError 异常,并提供详细的错误信息
"but got 'rank_num' type : {}.".format(type(rank_num))) raise TypeError("参数 'rank_num' 必须为整数类型,"
"但得到的 'rank_num' 类型为: {}".format(type(rank_num)))
#检查集体通信中的排名标识(rank id)是否合法
def check_rank_id(rank_id): def check_rank_id(rank_id):
""" """
A function that check if a collection communication rank id is legal.If not raise error. A function that check if a collection communication rank id is legal.If not raise error.
@ -65,30 +73,38 @@ def check_rank_id(rank_id):
Returns: Returns:
None None
""" """
# 检查rank_id是否为整数类型
if isinstance(rank_id, (int)): if isinstance(rank_id, (int)):
# 检查rank_id是否在有效范围内
if rank_id >= MAX_RANK_NUM or rank_id < 0: if rank_id >= MAX_RANK_NUM or rank_id < 0:
raise ValueError("The rand id in the communication group must be greater or equal 0 and " raise ValueError("The rand id in the communication group must be greater or equal 0 and "
"less than {}, but got type value : {}.".format(MAX_RANK_NUM, rank_id)) "less than {}, but got type value : {}.".format(MAX_RANK_NUM, rank_id))
else: else:
# 如果rank_id不是整数类型抛出类型错误
raise TypeError("The rand id in the communication group must be must be type of int, " raise TypeError("The rand id in the communication group must be must be type of int, "
"but got type value : {}.".format(type(rank_id))) "but got type value : {}.".format(type(rank_id)))
# 加载 HCCLHuawei Cloud Communication Library
def load_lib(): def load_lib():
"""load hccl lib""" """load hccl lib"""
try: try:
# 获取当前文件所在的目录
base_dir = os.path.dirname(os.path.realpath(__file__)) base_dir = os.path.dirname(os.path.realpath(__file__))
# 构建库文件的路径
lib_path = os.path.join(base_dir, "../lib", HCCL_LIB) lib_path = os.path.join(base_dir, "../lib", HCCL_LIB)
# 加载库文件
hccl_lib = ctypes.CDLL(lib_path) hccl_lib = ctypes.CDLL(lib_path)
except Exception: except Exception:
# 如果加载失败则抛出运行时错误
raise RuntimeError('Get hccl lib error.') raise RuntimeError('Get hccl lib error.')
# 将加载的库文件设置为全局变量
global HCCL_LIB_CTYPES global HCCL_LIB_CTYPES
HCCL_LIB_CTYPES = hccl_lib HCCL_LIB_CTYPES = hccl_lib
def c_str(string): def c_str(string):
"""Convert a python string to C string.""" """Convert a python string to C string."""
# 将字符串转换为C风格字符串
if not isinstance(string, str): if not isinstance(string, str):
string = string.decode('ascii') string = string.decode('ascii')
return ctypes.c_char_p(string.encode('utf-8')) return ctypes.c_char_p(string.encode('utf-8'))
@ -96,9 +112,9 @@ def c_str(string):
def c_array(ctype, values): def c_array(ctype, values):
"""Create ctypes array from a python array.""" """Create ctypes array from a python array."""
# 从Python数组创建ctypes数组
return (ctype * len(values))(*values) return (ctype * len(values))(*values)
#用于创建包含指定数量和ID的HCCL通信组但不能创建世界组。
def create_group(group, rank_num, rank_ids): def create_group(group, rank_num, rank_ids):
""" """
Create group. Create group.
@ -112,28 +128,38 @@ def create_group(group, rank_num, rank_ids):
Returns: Returns:
None None
""" """
# 检查组的有效性
check_group(group) check_group(group)
# 检查排名数量的有效性
check_rank_num(rank_num) check_rank_num(rank_num)
# 检查rank_ids是否为列表类型
if isinstance(rank_ids, (list)): if isinstance(rank_ids, (list)):
# 确保rank_num与rank_ids的长度一致
if rank_num != len(rank_ids): if rank_num != len(rank_ids):
raise ValueError("The argument 'rank_num' number should be equal to the length " raise ValueError("The argument 'rank_num' number should be equal to the length "
"of rank_ids, but got 'rank_num' value : {} and 'rank_ids' value : {}." "of rank_ids, but got 'rank_num' value : {} and 'rank_ids' value : {}."
.format(rank_num, rank_ids)) .format(rank_num, rank_ids))
# 检查rank_ids中的每个元素是否为非负整数
for rank_id in rank_ids: for rank_id in rank_ids:
if not isinstance(rank_id, (int)) or rank_id < 0: if not isinstance(rank_id, (int)) or rank_id < 0:
raise ValueError("The elements of argument 'rank_ids' must be " raise ValueError("The elements of argument 'rank_ids' must be "
"unsigned integer, but got the type : {}".format(type(rank_id))) "unsigned integer, but got the type : {}".format(type(rank_id)))
# 将rank_ids转换为C类型的数组
c_array_rank_ids = c_array(ctypes.c_uint, rank_ids) c_array_rank_ids = c_array(ctypes.c_uint, rank_ids)
# 将rank_num转换为C类型的无符号整数
c_rank_num = ctypes.c_uint(rank_num) c_rank_num = ctypes.c_uint(rank_num)
# 将group转换为C类型的字符串
c_group = c_str(group) c_group = c_str(group)
# 调用HCCL库创建组
ret = HCCL_LIB_CTYPES.HcomCreateGroup(c_group, c_rank_num, c_array_rank_ids) ret = HCCL_LIB_CTYPES.HcomCreateGroup(c_group, c_rank_num, c_array_rank_ids)
# 检查组创建是否成功
if ret != 0: if ret != 0:
raise RuntimeError('Create group error, the error code is {}.'.format(ret)) raise RuntimeError('Create group error, the error code is {}.'.format(ret))
else: else:
# 如果rank_ids不是列表类型抛出类型错误
raise TypeError("For 'create_group', the argument 'rank_ids' must be type of list, " raise TypeError("For 'create_group', the argument 'rank_ids' must be type of list, "
"but got 'rank_ids' type : {}.".format(type(rank_ids))) "but got 'rank_ids' type : {}.".format(type(rank_ids)))
#用于销毁用户创建的HCCL组
def destroy_group(group): def destroy_group(group):
""" """
A function that destroy the group which created by user. A function that destroy the group which created by user.
@ -144,13 +170,18 @@ def destroy_group(group):
Returns: Returns:
None None
""" """
# 检查传入的组是否有效
check_group(group) check_group(group)
# 将组名转换为C风格的字符串
c_group = c_str(group) c_group = c_str(group)
# 调用HCCL库中的函数销毁指定的组
ret = HCCL_LIB_CTYPES.HcomDestroyGroup(c_group) ret = HCCL_LIB_CTYPES.HcomDestroyGroup(c_group)
# 如果返回值不为0说明销毁组时发生了错误抛出异常
if ret != 0: if ret != 0:
raise RuntimeError('Destroy group error.') raise RuntimeError('Destroy group error.')
def get_rank_size(group="hccl_world_group"): def get_rank_size(group="hccl_world_group"):
""" """
A function that returns the number of ranks within the given collection communication group. A function that returns the number of ranks within the given collection communication group.
@ -162,16 +193,23 @@ def get_rank_size(group="hccl_world_group"):
An integer scalar with the num of ranks. An integer scalar with the num of ranks.
""" """
# 根据上下文的模式判断是否为PYNATIVE_MODE模式若是则直接返回HCCL的rank size
if context.get_context("mode") == context.PYNATIVE_MODE: if context.get_context("mode") == context.PYNATIVE_MODE:
return get_hccl_rank_size() return get_hccl_rank_size()
# 检查给定的组是否有效
check_group(group) check_group(group)
# 将组名转换为C字符串格式
c_group = c_str(group) c_group = c_str(group)
# 定义一个C类型的无符号整数用于存储rank size
c_rank_size = ctypes.c_uint() c_rank_size = ctypes.c_uint()
# 调用HCCL库的HcomGetRankSize函数获取组内的rank size
ret = HCCL_LIB_CTYPES.HcomGetRankSize(c_group, ctypes.byref(c_rank_size)) ret = HCCL_LIB_CTYPES.HcomGetRankSize(c_group, ctypes.byref(c_rank_size))
# 如果返回值不为0表示获取rank size失败抛出运行时错误
if ret != 0: if ret != 0:
raise RuntimeError('Get rank size error.') raise RuntimeError('Get rank size error.')
# 返回获取到的rank size值
return c_rank_size.value return c_rank_size.value
@ -186,17 +224,22 @@ def get_rank_id(group="hccl_world_group"):
if context.get_context("mode") == context.PYNATIVE_MODE: if context.get_context("mode") == context.PYNATIVE_MODE:
return get_hccl_rank_id() return get_hccl_rank_id()
# 检查组的有效性
check_group(group) check_group(group)
# 将组转换为 C 字符串格式
c_group = c_str(group) c_group = c_str(group)
# 定义一个用于存储 rank id 的 ctypes 无符号整数
c_rank_id = ctypes.c_uint() c_rank_id = ctypes.c_uint()
# 调用 HCCL 库获取当前进程的 rank id
ret = HCCL_LIB_CTYPES.HcomGetRankId(c_group, ctypes.byref(c_rank_id)) ret = HCCL_LIB_CTYPES.HcomGetRankId(c_group, ctypes.byref(c_rank_id))
# 如果返回值不为 0表示获取 rank id 出错,抛出 RuntimeError 异常
if ret != 0: if ret != 0:
raise RuntimeError('Get rank id error.') raise RuntimeError('Get rank id error.')
# 返回获取到的 rank id 值
return c_rank_id.value return c_rank_id.value
def get_local_rank_size(group="hccl_world_group"): def get_local_rank_size(group="hccl_world_group"):
""" """
A function that returns the number of local ranks within the given collection communication group. A function that returns the number of local ranks within the given collection communication group.
@ -207,19 +250,25 @@ def get_local_rank_size(group="hccl_world_group"):
Returns: Returns:
An integer scalar with the num of local ranks. An integer scalar with the num of local ranks.
""" """
# 检查当前运行模式是否为PYNATIVE_MODE如果是则抛出异常
if context.get_context("mode") is context.PYNATIVE_MODE: if context.get_context("mode") is context.PYNATIVE_MODE:
raise RuntimeError("The function 'get_local_rank_size' is not supported in PYNATIVE_MODE, " raise RuntimeError("The function 'get_local_rank_size' is not supported in PYNATIVE_MODE, "
"'get_local_rank_size' only support GRAPH_MODE") "'get_local_rank_size' only support GRAPH_MODE")
# 验证传入的组是否有效
check_group(group) check_group(group)
# 将组名称转换为C字符串格式
c_group = c_str(group) c_group = c_str(group)
# 定义一个ctypes的无符号整数变量用于存储本地排名大小
c_local_rank_size = ctypes.c_uint() c_local_rank_size = ctypes.c_uint()
# 调用HCCL库中的HcomGetLocalRankSize函数获取本地排名大小
ret = HCCL_LIB_CTYPES.HcomGetLocalRankSize(c_group, ctypes.byref(c_local_rank_size)) ret = HCCL_LIB_CTYPES.HcomGetLocalRankSize(c_group, ctypes.byref(c_local_rank_size))
# 如果返回值不为0说明获取本地排名大小时出错抛出异常
if ret != 0: if ret != 0:
raise RuntimeError('Get local rank size error.') raise RuntimeError('Get local rank size error.')
# 返回获取到的本地排名大小
return c_local_rank_size.value return c_local_rank_size.value
def get_local_rank_id(group="hccl_world_group"): def get_local_rank_id(group="hccl_world_group"):
""" """
Get local rank id. Get local rank id.
@ -230,16 +279,23 @@ def get_local_rank_id(group="hccl_world_group"):
An integer scalar with the local rank id of the calling process. An integer scalar with the local rank id of the calling process.
""" """
# 检查当前运行模式是否为PYNATIVE_MODE如果是则抛出异常
if context.get_context("mode") is context.PYNATIVE_MODE: if context.get_context("mode") is context.PYNATIVE_MODE:
raise RuntimeError("The function 'get_local_rank_id' is not supported in PYNATIVE_MODE, " raise RuntimeError("The function 'get_local_rank_id' is not supported in PYNATIVE_MODE, "
"'get_local_rank_id' only support GRAPH_MODE") "'get_local_rank_id' only support GRAPH_MODE")
# 验证群组的有效性
check_group(group) check_group(group)
# 将群组名称转换为C字符串格式
c_group = c_str(group) c_group = c_str(group)
# 定义一个无符号整型的C类型变量来存储本地排名ID
c_local_rank_id = ctypes.c_uint() c_local_rank_id = ctypes.c_uint()
# 调用HCCL库的HcomGetLocalRankId函数获取本地排名ID
ret = HCCL_LIB_CTYPES.HcomGetLocalRankId(c_group, ctypes.byref(c_local_rank_id)) ret = HCCL_LIB_CTYPES.HcomGetLocalRankId(c_group, ctypes.byref(c_local_rank_id))
# 如果返回值不为0表示获取本地排名ID失败抛出异常
if ret != 0: if ret != 0:
raise RuntimeError('Get local rank id error.') raise RuntimeError('Get local rank id error.')
# 返回获取到的本地排名ID值
return c_local_rank_id.value return c_local_rank_id.value
@ -256,18 +312,25 @@ def get_world_rank_from_group_rank(group, group_rank_id):
if context.get_context("mode") is context.PYNATIVE_MODE: if context.get_context("mode") is context.PYNATIVE_MODE:
raise RuntimeError("The function 'get_world_rank_from_group_rank' is not supported in PYNATIVE_MODE, " raise RuntimeError("The function 'get_world_rank_from_group_rank' is not supported in PYNATIVE_MODE, "
"'get_world_rank_from_group_rank' only support GRAPH_MODE") "'get_world_rank_from_group_rank' only support GRAPH_MODE")
# 检查组名是否有效
check_group(group) check_group(group)
# 检查组内rank ID是否有效
check_rank_id(group_rank_id) check_rank_id(group_rank_id)
# 将组名转换为C字符串格式
c_group = c_str(group) c_group = c_str(group)
# 将组内rank ID转换为C的无符号整数类型
c_group_rank_id = ctypes.c_uint(group_rank_id) c_group_rank_id = ctypes.c_uint(group_rank_id)
# 声明一个用于存储世界rank ID的C的无符号整数类型变量
c_world_rank_id = ctypes.c_uint() c_world_rank_id = ctypes.c_uint()
# 调用HCCL库中的HcomGetWorldRankFromGroupRank函数根据组名和组内rank ID获取对应的世界rank ID
ret = HCCL_LIB_CTYPES.HcomGetWorldRankFromGroupRank(c_group, c_group_rank_id, ctypes.byref(c_world_rank_id)) ret = HCCL_LIB_CTYPES.HcomGetWorldRankFromGroupRank(c_group, c_group_rank_id, ctypes.byref(c_world_rank_id))
# 如果返回值不为0说明函数调用出错抛出RuntimeError异常
if ret != 0: if ret != 0:
raise RuntimeError('Get world rank from group rank error.') raise RuntimeError('根据组内rank ID获取世界rank ID时出错。')
# 返回获取到的世界rank ID的值
return c_world_rank_id.value return c_world_rank_id.value
def get_group_rank_from_world_rank(world_rank_id, group): def get_group_rank_from_world_rank(world_rank_id, group):
""" """
Get group rank from world rank. Get group rank from world rank.
@ -281,13 +344,21 @@ def get_group_rank_from_world_rank(world_rank_id, group):
if context.get_context("mode") is context.PYNATIVE_MODE: if context.get_context("mode") is context.PYNATIVE_MODE:
raise RuntimeError("The function 'get_group_rank_from_world_rank' is not supported in PYNATIVE_MODE, " raise RuntimeError("The function 'get_group_rank_from_world_rank' is not supported in PYNATIVE_MODE, "
"'get_group_rank_from_world_rank' only support GRAPH_MODE") "'get_group_rank_from_world_rank' only support GRAPH_MODE")
# 检查组的有效性
check_group(group) check_group(group)
# 检查世界排名ID的有效性
check_rank_id(world_rank_id) check_rank_id(world_rank_id)
# 将组转换为C字符串
c_group = c_str(group) c_group = c_str(group)
# 将世界排名ID转换为C无符号整数
c_world_rank_id = ctypes.c_uint(world_rank_id) c_world_rank_id = ctypes.c_uint(world_rank_id)
# 创建一个用于存储组排名ID的C无符号整数
c_group_rank_id = ctypes.c_uint() c_group_rank_id = ctypes.c_uint()
# 从世界排名ID获取组排名ID
ret = HCCL_LIB_CTYPES.HcomGetGroupRankFromWorldRank(c_world_rank_id, c_group, ctypes.byref(c_group_rank_id)) ret = HCCL_LIB_CTYPES.HcomGetGroupRankFromWorldRank(c_world_rank_id, c_group, ctypes.byref(c_group_rank_id))
# 如果返回值不为0抛出运行时错误
if ret != 0: if ret != 0:
raise RuntimeError('Get group rank from world rank error.') raise RuntimeError('Get group rank from world rank error.')
# 返回获取到的组排名ID的值
return c_group_rank_id.value return c_group_rank_id.value

@ -20,40 +20,51 @@ from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \
_create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \ _create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \
_get_local_rank_helper, _get_local_size_helper, GlobalComm _get_local_rank_helper, _get_local_size_helper, GlobalComm
from .._c_expression import init_hccl, finalize_hccl, init_gpu_collective from .._c_expression import init_hccl, finalize_hccl, init_gpu_collective
# 导入mindspore上下文模块
# 导入mindspore的进程服务器和调度器角色判断函数
# 导入通信辅助模块包括后端类型、获取和设置rank及size的辅助函数、创建和销毁组的辅助函数、以及世界通信组和本地通信组的相关标识
# 导入C语言表达式模块包含HCCL和NCCL的初始化和终结化函数以及GPU集体通信的初始化函数
__all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size", __all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size",
"get_local_rank_size", "get_world_rank_from_group_rank", "get_local_rank_size", "get_world_rank_from_group_rank",
"get_group_rank_from_world_rank", "create_group", "destroy_group", "get_group_rank_from_world_rank", "create_group", "destroy_group",
"HCCL_WORLD_COMM_GROUP", "NCCL_WORLD_COMM_GROUP"] "HCCL_WORLD_COMM_GROUP", "NCCL_WORLD_COMM_GROUP"]
# 默认的世界通信组
DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
def _get_group(group): def _get_group(group):
"""Return the world communication group if the `group` is `DEFAULT_WORLD_COMM_GROUP`.""" """根据输入的`group`参数返回相应的通信组。
如果`group`等于`DEFAULT_WORLD_COMM_GROUP`则返回全局通信组`GlobalComm.WORLD_COMM_GROUP`
否则返回传入的`group`参数
"""
if group == DEFAULT_WORLD_COMM_GROUP: if group == DEFAULT_WORLD_COMM_GROUP:
return GlobalComm.WORLD_COMM_GROUP return GlobalComm.WORLD_COMM_GROUP # 返回默认的世界通信组
return group return group # 返回传入的通信组参数
def _check_task_sink_envs(): def _check_task_sink_envs():
""" """检查任务接收器task_sink相关的环境变量是否已导出。
Check whether task_sink environment variables have been exported or not.
return True if task_sink environment variables have been exported, False otherwise. 该函数通过检查环境变量`GRAPH_OP_RUN`来判断任务接收器的环境变量是否已导出
返回值
- 如果环境变量`GRAPH_OP_RUN`已导出且其值可以转换为整数1则返回False表示环境变量未正确设置为启用状态
- 如果环境变量`GRAPH_OP_RUN`未导出或其值不能转换为整数1则返回True表示环境变量未导出或设置有误
""" """
import os import os
task_sink = os.getenv("GRAPH_OP_RUN") task_sink = os.getenv("GRAPH_OP_RUN") # 获取名为"GRAPH_OP_RUN"的环境变量
if task_sink: if task_sink:
try: try:
if int(task_sink) == 1: if int(task_sink) == 1: # 尝试将环境变量的值转换为整数并检查是否等于1
return False return False # 如果等于1返回False表示环境变量已导出但设置为禁用状态非预期情况
except ValueError: except ValueError:
return True return True # 如果转换为整数失败返回True表示环境变量设置有误
finally: finally:
pass pass # finally块中的代码在这里是空操作通常用于清理操作
return True return True # 如果环境变量未导出返回True
def _check_parallel_envs(): def _check_parallel_envs():
@ -63,25 +74,31 @@ def _check_parallel_envs():
Raises: Raises:
RuntimeError: If parallel environment variables have not been exported or have been exported to wrong values. RuntimeError: If parallel environment variables have not been exported or have been exported to wrong values.
""" """
# 检查是否需要进行环境验证,如果不进行则直接返回
if not GlobalComm.CHECK_ENVS: if not GlobalComm.CHECK_ENVS:
return return
import os import os
# 获取环境变量RANK_ID的值
rank_id_str = os.getenv("RANK_ID") rank_id_str = os.getenv("RANK_ID")
# 如果RANK_ID未设置抛出运行时错误
if not rank_id_str: if not rank_id_str:
raise RuntimeError("Environment variables RANK_ID has not been exported, please export variables 'RANK_ID'.") raise RuntimeError("Environment variables RANK_ID has not been exported, please export variables 'RANK_ID'.")
try: try:
# 尝试将RANK_ID转换为整数
int(rank_id_str) int(rank_id_str)
except ValueError: except ValueError:
# 如果转换失败,打印错误信息
print("Environment variables 'RANK_ID' should be number, but got the type : {}".format(type(rank_id_str))) print("Environment variables 'RANK_ID' should be number, but got the type : {}".format(type(rank_id_str)))
finally: finally:
# 无论是否发生异常,此块为空操作
pass pass
# 获取环境变量MINDSPORE_HCCL_CONFIG_PATH和RANK_TABLE_FILE的值
rank_table_file_str = os.getenv("MINDSPORE_HCCL_CONFIG_PATH") rank_table_file_str = os.getenv("MINDSPORE_HCCL_CONFIG_PATH")
rank_table_file_str_old = os.getenv("RANK_TABLE_FILE") rank_table_file_str_old = os.getenv("RANK_TABLE_FILE")
# 如果两个环境变量都未设置,抛出运行时错误
if not rank_table_file_str and not rank_table_file_str_old: if not rank_table_file_str and not rank_table_file_str_old:
raise RuntimeError("Get hccl rank_table_file failed, " raise RuntimeError("Get hccl rank_table_file failed, "
"please export MINDSPORE_HCCL_CONFIG_PATH or RANK_TABLE_FILE.") "please export MINDSPORE_HCCL_CONFIG_PATH or RANK_TABLE_FILE.")
def init(backend_name=None): def init(backend_name=None):
""" """
Initialize distributed backend, e.g. HCCL/NCCL, it is required before using the communication service. Initialize distributed backend, e.g. HCCL/NCCL, it is required before using the communication service.
@ -109,15 +126,20 @@ def init(backend_name=None):
>>> from mindspore.communication import init >>> from mindspore.communication import init
>>> init() >>> init()
""" """
# 检查当前角色是否为参数服务器或调度器,如果是则直接返回
if _is_role_pserver() or _is_role_sched(): if _is_role_pserver() or _is_role_sched():
return return
# 检查任务接收环境变量,获取设备目标和模式
task_sink = _check_task_sink_envs() task_sink = _check_task_sink_envs()
device_target = context.get_context("device_target") device_target = context.get_context("device_target")
mode = context.get_context("mode") mode = context.get_context("mode")
mpi_init = False mpi_init = False
# 如果没有任务接收且模式为图模式设置mpi_init为True
if not task_sink and mode == context.GRAPH_MODE: if not task_sink and mode == context.GRAPH_MODE:
mpi_init = True mpi_init = True
# 根据设备目标选择后端名称,如果不支持则抛出异常
# 根据设备目标设置默认的后端名称
if backend_name is None: if backend_name is None:
if device_target == "Ascend": if device_target == "Ascend":
backend_name = "hccl" backend_name = "hccl"
@ -126,28 +148,34 @@ def init(backend_name=None):
else: else:
raise RuntimeError("For 'set_context', the argument 'device_target' {} is not supported in " raise RuntimeError("For 'set_context', the argument 'device_target' {} is not supported in "
"parallel initialization, please use Ascend or GPU.".format(device_target)) "parallel initialization, please use Ascend or GPU.".format(device_target))
# 检查后端名称是否为字符串,如果不是则抛出异常
if not isinstance(backend_name, str): if not isinstance(backend_name, str):
raise TypeError("For 'init', the argument 'backend_name' must be a string, " raise TypeError("For 'init', the argument 'backend_name' must be a string, "
"but got the type : {}".format(type(backend_name))) "but got the type : {}".format(type(backend_name)))
# 根据后端名称初始化通信环境
if backend_name == "hccl": if backend_name == "hccl":
# 如果设备目标不是Ascend抛出异常
if device_target != "Ascend": if device_target != "Ascend":
raise RuntimeError("For 'init', the argument 'backend_name' should be 'Ascend' to init hccl, " raise RuntimeError("For 'init', the argument 'backend_name' should be 'Ascend' to init hccl, "
"but got {}".format(device_target)) "but got {}".format(device_target))
# 如果不需要MPI初始化检查并行环境并设置后端名称
if not mpi_init: if not mpi_init:
_check_parallel_envs() _check_parallel_envs()
GlobalComm.BACKEND = Backend("hccl") GlobalComm.BACKEND = Backend("hccl")
else: else:
GlobalComm.BACKEND = Backend("hccl_mpi") GlobalComm.BACKEND = Backend("hccl_mpi")
# 初始化HCCL并设置全局通信组和初始化状态
init_hccl() init_hccl()
GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
GlobalComm.INITED = True GlobalComm.INITED = True
elif backend_name == "nccl": elif backend_name == "nccl":
# 初始化GPU集体通信并设置后端名称、全局通信组和初始化状态
init_gpu_collective() init_gpu_collective()
GlobalComm.BACKEND = Backend("nccl") GlobalComm.BACKEND = Backend("nccl")
GlobalComm.WORLD_COMM_GROUP = NCCL_WORLD_COMM_GROUP GlobalComm.WORLD_COMM_GROUP = NCCL_WORLD_COMM_GROUP
GlobalComm.INITED = True GlobalComm.INITED = True
else: else:
# 如果后端名称不支持,抛出异常
raise RuntimeError("For 'init', the argument 'backend_name' must be nccl while 'device_target' is GPU, " raise RuntimeError("For 'init', the argument 'backend_name' must be nccl while 'device_target' is GPU, "
"but got the 'backend_name' : hccl.") "but got the 'backend_name' : hccl.")
@ -165,10 +193,11 @@ def release():
Examples: Examples:
>>> from mindspore.communication import init, release >>> from mindspore.communication import init, release
>>> init() >>> init() # 初始化分布式通信环境
>>> release() >>> release() # 释放分布式通信资源
""" """
finalize_hccl() finalize_hccl()# 结束 HCCL 的使用,释放相关资源
def get_rank(group=GlobalComm.WORLD_COMM_GROUP): def get_rank(group=GlobalComm.WORLD_COMM_GROUP):
@ -197,12 +226,18 @@ def get_rank(group=GlobalComm.WORLD_COMM_GROUP):
>>> print(rank_id) >>> print(rank_id)
>>> # the result is the rank_id in world_group >>> # the result is the rank_id in world_group
""" """
# 检查传入的group参数是否为字符串类型
if not isinstance(group, str): if not isinstance(group, str):
# 如果group参数不是字符串类型则抛出TypeError异常
raise TypeError("For 'get_rank', the argument 'group' must be type of string, " raise TypeError("For 'get_rank', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group))) "but got 'group' type : {}.".format(type(group)))
# 调用_get_rank_helper函数获取指定group的rank ID
# _get_group函数用于解析group参数
# GlobalComm.BACKEND变量存储了当前使用的通信后端
return _get_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND) return _get_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP): def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP):
""" """
Gets local rank ID for current device in specified collective communication group. Gets local rank ID for current device in specified collective communication group.
@ -234,9 +269,11 @@ def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP):
>>> print("local_rank is: {}, world_rank is {}".format(local_rank, world_rank)) >>> print("local_rank is: {}, world_rank is {}".format(local_rank, world_rank))
local_rank is: 1, world_rank is 9 local_rank is: 1, world_rank is 9
""" """
# 检查传入的 'group' 参数是否为字符串类型,如果不是则抛出 TypeError
if not isinstance(group, str): if not isinstance(group, str):
raise TypeError("For 'get_local_rank', the argument 'group' must be type of string, " raise TypeError("For 'get_local_rank', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group))) "but got 'group' type : {}.".format(type(group)))
# 调用辅助函数 _get_local_rank_helper 来获取本地排名,传入的参数为解析后的组和全局通信后端
return _get_local_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND) return _get_local_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
@ -270,9 +307,11 @@ def get_group_size(group=GlobalComm.WORLD_COMM_GROUP):
>>> print("group_size is: ", group_size) >>> print("group_size is: ", group_size)
group_size is: 8 group_size is: 8
""" """
# 检查传入的参数 'group' 是否为字符串类型
if not isinstance(group, str): if not isinstance(group, str):
raise TypeError("For 'get_group_size', the argument 'group' must be type of string, " raise TypeError("For 'get_group_size', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group))) "but got 'group' type : {}.".format(type(group)))
# 返回指定组的大小,使用辅助函数 _get_size_helper 和全局通信后端
return _get_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND) return _get_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
@ -306,9 +345,11 @@ def get_local_rank_size(group=GlobalComm.WORLD_COMM_GROUP):
>>> print("local_rank_size is: ", local_rank_size) >>> print("local_rank_size is: ", local_rank_size)
local_rank_size is: 8 local_rank_size is: 8
""" """
# 检查传入的 'group' 参数是否为字符串类型,如果不是则抛出 TypeError
if not isinstance(group, str): if not isinstance(group, str):
raise TypeError("For 'get_local_rank_size', the argument 'group' must be type of string, " raise TypeError("For 'get_local_rank_size', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group))) "but got 'group' type : {}.".format(type(group)))
# 调用辅助函数获取本地组的大小,使用 _get_group 函数获取组,并使用 GlobalComm.BACKEND 作为后端
return _get_local_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND) return _get_local_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
@ -347,12 +388,12 @@ def get_world_rank_from_group_rank(group, group_rank_id):
>>> print("world_rank_id is: ", world_rank_id) >>> print("world_rank_id is: ", world_rank_id)
world_rank_id is: 4 world_rank_id is: 4
""" """
# 检查传入的 group 参数是否为字符串类型
if not isinstance(group, str): if not isinstance(group, str):
raise TypeError("For 'get_world_rank_from_group_rank', the argument 'group' must be type of string, " raise TypeError("For 'get_world_rank_from_group_rank', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group))) "but got 'group' type : {}.".format(type(group)))
# 返回根据组和组排名获取的世界排名的帮助函数的结果
return _get_world_rank_from_group_rank_helper(group=group, group_rank_id=group_rank_id, backend=GlobalComm.BACKEND) return _get_world_rank_from_group_rank_helper(group=group, group_rank_id=group_rank_id, backend=GlobalComm.BACKEND)
def get_group_rank_from_world_rank(world_rank_id, group): def get_group_rank_from_world_rank(world_rank_id, group):
""" """
Get the rank ID in the specified user communication group corresponding to Get the rank ID in the specified user communication group corresponding to
@ -389,12 +430,13 @@ def get_group_rank_from_world_rank(world_rank_id, group):
>>> print("group_rank_id is: ", group_rank_id) >>> print("group_rank_id is: ", group_rank_id)
group_rank_id is: 1 group_rank_id is: 1
""" """
# 检查输入参数 'group' 是否为字符串类型,如果不是则抛出 TypeError
if not isinstance(group, str): if not isinstance(group, str):
raise TypeError("For 'get_group_rank_from_world_rank', the argument 'group' must be type of string, " raise TypeError("For 'get_group_rank_from_world_rank', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group))) "but got 'group' type : {}.".format(type(group)))
# 调用辅助函数 _get_group_rank_from_world_rank_helper 来获取组的排名,并返回结果
return _get_group_rank_from_world_rank_helper(world_rank_id=world_rank_id, group=group, backend=GlobalComm.BACKEND) return _get_group_rank_from_world_rank_helper(world_rank_id=world_rank_id, group=group, backend=GlobalComm.BACKEND)
#创建一个用户集体通信组通过传入通信组名称和设备ID列表来实现。
def create_group(group, rank_ids): def create_group(group, rank_ids):
""" """
Create a user collective communication group. Create a user collective communication group.
@ -427,9 +469,11 @@ def create_group(group, rank_ids):
>>> create_group(group, rank_ids) >>> create_group(group, rank_ids)
>>> allreduce = ops.AllReduce(group) >>> allreduce = ops.AllReduce(group)
""" """
# 检查传入的 group 参数是否为字符串类型,如果不是则抛出 TypeError
if not isinstance(group, str): if not isinstance(group, str):
raise TypeError("For 'create_group', the argument 'group' must be type of string, " raise TypeError("For 'create_group', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group))) "but got 'group' type : {}.".format(type(group)))
# 调用辅助函数 _create_group_helper 来创建组,使用指定的 rank_ids 和后端
_create_group_helper(group, rank_ids, backend=GlobalComm.BACKEND) _create_group_helper(group, rank_ids, backend=GlobalComm.BACKEND)
@ -450,7 +494,9 @@ def destroy_group(group):
ValueError: If group is "hccl_world_group" or backend is invalid. ValueError: If group is "hccl_world_group" or backend is invalid.
RuntimeError: If HCCL is not available or MindSpore is GPU version. RuntimeError: If HCCL is not available or MindSpore is GPU version.
""" """
# 检查传入的 group 参数是否为字符串类型,如果不是则抛出 TypeError
if not isinstance(group, str): if not isinstance(group, str):
raise TypeError("For 'destroy_group', the argument 'group' must be type of string, " raise TypeError("For 'destroy_group', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group))) "but got 'group' type : {}.".format(type(group)))
# 调用辅助函数 _destroy_group_helper 来销毁指定的组,并使用全局通信后端
_destroy_group_helper(group, backend=GlobalComm.BACKEND) _destroy_group_helper(group, backend=GlobalComm.BACKEND)

@ -45,14 +45,15 @@ class OneOf(OneOf_):
TypeError: raise type error for invalid inputs. TypeError: raise type error for invalid inputs.
""" """
self.patterns = patterns self.patterns = patterns
# 检查 patterns 是否是 Pattern 类的实例
if isinstance(patterns, Pattern): if isinstance(patterns, Pattern):
OneOf_.__init__(self, [patterns]) OneOf_.__init__(self, [patterns])
# 检查 patterns 是否是 tuple 或 list 类型,并且其中所有元素都是 Pattern 类的实例
elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns): elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns):
OneOf_.__init__(self, patterns) OneOf_.__init__(self, patterns)
# 如果 patterns 不符合上述两种情况,则抛出 TypeError 异常
else: else:
raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}") raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}")
class Prim(Prim_): class Prim(Prim_):
r""" r"""
Express a pattern of certain primitive type(s). Express a pattern of certain primitive type(s).
@ -76,25 +77,33 @@ class Prim(Prim_):
Raises: Raises:
TypeError: raise type error for invalid argument. TypeError: raise type error for invalid argument.
""" """
# 检查name是否为字符串类型如果不是则抛出TypeError
if name is not None and not isinstance(name, str): if name is not None and not isinstance(name, str):
raise TypeError(f"Expect string, got : {name}") raise TypeError(f"Expect string, got : {name}")
self.name = name self.name = name
# 如果types是字符串类型则将其按'|'分割成列表
if isinstance(types, str): if isinstance(types, str):
if self.name is None: if self.name is None:
self.name = types self.name = types
self.types = types.split('|') self.types = types.split('|')
# 如果types是Primitive类型则直接将其放入列表中
elif isinstance(types, Primitive): elif isinstance(types, Primitive):
if self.name is None: if self.name is None:
self.name = types.name self.name = types.name
self.types = [types] self.types = [types]
# 如果 types 是元组或列表,并且其中所有元素都是 Primitive 类型
elif isinstance(types, (tuple, list)) and all(isinstance(tp, Primitive) for tp in types): elif isinstance(types, (tuple, list)) and all(isinstance(tp, Primitive) for tp in types):
# 如果 self.name 为 None则初始化为空字符串并拼接所有 Primitive 的 name
if self.name is None: if self.name is None:
self.name = "" self.name = ""
for prim in types: for prim in types:
self.name += prim.name self.name += prim.name
# 设置 self.types 为传入的 types
self.types = types self.types = types
# 如果 types 不符合预期类型,抛出 TypeError
else: else:
raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}") raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}")
# 调用基类 Prim_ 的初始化方法,传入 self.types 和 self.name
Prim_.__init__(self, self.types, self.name) Prim_.__init__(self, self.types, self.name)
@ -115,16 +124,22 @@ class Call(Call_):
Raises: Raises:
TypeError: raise type error for invalid argument. TypeError: raise type error for invalid argument.
""" """
# 检查 prim_pattern 是否为 Pattern, str 或 Primitive 类型,如果不是则抛出 TypeError
if not isinstance(prim_pattern, (Pattern, str, Primitive)): if not isinstance(prim_pattern, (Pattern, str, Primitive)):
raise TypeError(f"Expect prim_pattern to be Pattern, Primitive or string, got : {prim_pattern}") raise TypeError(f"Expect prim_pattern to be Pattern, Primitive or string, got : {prim_pattern}")
self.prim_pattern = prim_pattern self.prim_pattern = prim_pattern
# 初始化 inputs 列表
self.inputs = [] self.inputs = []
# 如果 inputs 为 None则不做任何操作
if inputs is None: if inputs is None:
pass pass
# 如果 inputs 是 tuple 或 list 并且其中所有元素都是 Pattern 类型,则赋值给 self.inputs
elif isinstance(inputs, (tuple, list)) and all(isinstance(input, Pattern) for input in inputs): elif isinstance(inputs, (tuple, list)) and all(isinstance(input, Pattern) for input in inputs):
self.inputs = inputs self.inputs = inputs
# 如果 inputs 不符合上述条件,则抛出 TypeError
else: else:
raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}") raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}")
# 调用父类 Call_ 的初始化方法,传入 self.prim_pattern 和 self.inputs
Call_.__init__(self, self.prim_pattern, self.inputs) Call_.__init__(self, self.prim_pattern, self.inputs)
@ -145,6 +160,7 @@ class NoneOf(NoneOf_):
TypeError: raise type error for invalid argument. TypeError: raise type error for invalid argument.
""" """
self.patterns = patterns self.patterns = patterns
# 根据 patterns 的类型初始化 NoneOf_ 类
if patterns is None: if patterns is None:
NoneOf_.__init__(self, ()) NoneOf_.__init__(self, ())
elif isinstance(patterns, Pattern): elif isinstance(patterns, Pattern):
@ -154,7 +170,6 @@ class NoneOf(NoneOf_):
else: else:
raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}") raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}")
class NewTensor(NewTensor_): class NewTensor(NewTensor_):
r""" r"""
New Tensor to be used in the target. New Tensor to be used in the target.
@ -167,13 +182,16 @@ class NewTensor(NewTensor_):
Raises: Raises:
TypeError: raise type error for invalid argument. TypeError: raise type error for invalid argument.
""" """
# 初始化输入张量
self.input_tensor = input_tensor self.input_tensor = input_tensor
# 检查输入是否为 Tensor 类型
if isinstance(input_tensor, Tensor): if isinstance(input_tensor, Tensor):
# 如果是 Tensor 类型,则调用 NewTensor_ 的初始化方法
NewTensor_.__init__(self, input_tensor) NewTensor_.__init__(self, input_tensor)
else: else:
# 如果不是 Tensor 类型,则抛出 TypeError 异常
raise TypeError(f"Expect input_tensor to be a Tensor got : {input_tensor}") raise TypeError(f"Expect input_tensor to be a Tensor got : {input_tensor}")
class NewParameter(NewParameter_): class NewParameter(NewParameter_):
r""" r"""
New Parameter to be used in the target. New Parameter to be used in the target.
@ -193,11 +211,14 @@ class NewParameter(NewParameter_):
self.default_tensor = default_tensor self.default_tensor = default_tensor
self.requires_grad = requires_grad self.requires_grad = requires_grad
self.layerwise_parallel = layerwise_parallel self.layerwise_parallel = layerwise_parallel
# 检查参数类型是否符合预期
if isinstance(para_name, str) and isinstance(default_tensor, Tensor) and isinstance(requires_grad, bool) and\ if isinstance(para_name, str) and isinstance(default_tensor, Tensor) and isinstance(requires_grad, bool) and\
isinstance(layerwise_parallel, bool): isinstance(layerwise_parallel, bool):
# 初始化 NewParameter_ 类
NewParameter_.__init__(self, self.para_name, self.default_tensor, self.requires_grad, NewParameter_.__init__(self, self.para_name, self.default_tensor, self.requires_grad,
self.layerwise_parallel) self.layerwise_parallel)
else: else:
# 如果参数类型不符合预期,抛出 TypeError
raise TypeError(f"Expect para_name(str), default_tensor(Tensor), requires_grad(bool), \ raise TypeError(f"Expect para_name(str), default_tensor(Tensor), requires_grad(bool), \
layerwise_parallel(bool) got : {para_name}, {default_tensor}, \ layerwise_parallel(bool) got : {para_name}, {default_tensor}, \
{requires_grad}, {layerwise_parallel}") {requires_grad}, {layerwise_parallel}")

@ -39,6 +39,7 @@ class PyPassManager(PyPassManager_):
TypeError: If argument has invalid type. TypeError: If argument has invalid type.
""" """
def __init__(self, requires_grad=True, run_only_once=False): def __init__(self, requires_grad=True, run_only_once=False):
# 初始化方法,接收两个布尔参数,设置实例的属性并调用父类的初始化方法
if not isinstance(requires_grad, bool): if not isinstance(requires_grad, bool):
raise TypeError(f"Expect bool, got : ({type(requires_grad)}){requires_grad}") raise TypeError(f"Expect bool, got : ({type(requires_grad)}){requires_grad}")
if not isinstance(run_only_once, bool): if not isinstance(run_only_once, bool):
@ -48,17 +49,20 @@ class PyPassManager(PyPassManager_):
PyPassManager_.__init__(self) PyPassManager_.__init__(self)
def register(self, py_pass): def register(self, py_pass):
# 注册一个Python pass检查其是否为函数类型并获取其模式和目标
if not isfunction(py_pass): if not isfunction(py_pass):
raise TypeError(f"Expect function pass, got : ({type(py_pass)}){py_pass}") raise TypeError(f"Expect function pass, got : ({type(py_pass)}){py_pass}")
pattern, target = py_pass() pattern, target = py_pass()
pass_name = py_pass.__name__ pass_name = py_pass.__name__
# 检查模式和目标是否为Pattern类型
if not isinstance(pattern, Pattern): if not isinstance(pattern, Pattern):
raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}") raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}")
if not isinstance(target, Pattern): if not isinstance(target, Pattern):
raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}") raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}")
# 调用父类的register方法注册pass及其相关信息
super().register(pass_name, pattern, target, self.requires_grad, self.run_only_once_) super().register(pass_name, pattern, target, self.requires_grad, self.run_only_once_)
def unregister(self, py_pass): def unregister(self, py_pass):
# 从注册表中移除指定的Python传递对象可以是字符串形式的名称或函数对象
if isinstance(py_pass, str): if isinstance(py_pass, str):
super().unregister(py_pass) super().unregister(py_pass)
return return
@ -68,25 +72,28 @@ class PyPassManager(PyPassManager_):
raise TypeError(f"Expect py_pass to be string or function, got ({type(py_pass)}){py_pass}") raise TypeError(f"Expect py_pass to be string or function, got ({type(py_pass)}){py_pass}")
def __call__(self, py_pass): def __call__(self, py_pass):
# 将Python传递对象注册到注册表中并返回该对象
self.register(py_pass) self.register(py_pass)
return py_pass return py_pass
def gen_new_parameter(self, pattern): def gen_new_parameter(self, pattern):
# 根据给定的模式生成新的参数模式必须是NewParameter类型
if not isinstance(pattern, NewParameter): if not isinstance(pattern, NewParameter):
raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}") raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}")
super().gen_new_parameter(pattern) super().gen_new_parameter(pattern)
def set_renorm(self, should_renorm): def set_renorm(self, should_renorm):
# 设置是否进行重归一化操作,参数必须是布尔值
if not isinstance(should_renorm, bool): if not isinstance(should_renorm, bool):
raise TypeError(f"Expect should_renorm to be a bool, got {should_renorm}") raise TypeError(f"Expect should_renorm to be a bool, got {should_renorm}")
super().set_renorm(should_renorm) super().set_renorm(should_renorm)
def set_reopt(self, do_reopt): def set_reopt(self, do_reopt):
# 设置是否进行重新优化操作,参数必须是布尔值
if not isinstance(do_reopt, bool): if not isinstance(do_reopt, bool):
raise TypeError(f"Expect do_reopt to be a bool, got {do_reopt}") raise TypeError(f"Expect do_reopt to be a bool, got {do_reopt}")
super().set_reopt(do_reopt) super().set_reopt(do_reopt)
def register_pass(requires_grad=True, run_only_once=False): def register_pass(requires_grad=True, run_only_once=False):
""" """
Register python pass to specified pipeline phase which would be used in compilation. Register python pass to specified pipeline phase which would be used in compilation.
@ -165,12 +172,13 @@ def cancel_new_parameter(pattern):
>>> # some compilations >>> # some compilations
>>> cancel_new_parameter(abc) >>> cancel_new_parameter(abc)
""" """
# 检查传入的pattern是否为NewParameter的实例
if not isinstance(pattern, NewParameter): if not isinstance(pattern, NewParameter):
raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}") raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}")
# 创建一个PyPassManager对象
ppm = PyPassManager() ppm = PyPassManager()
# 从PyPassManager中注销指定名称的参数
ppm.unregister(pattern.para_name) ppm.unregister(pattern.para_name)
def set_renorm(should_renorm): def set_renorm(should_renorm):
""" """
Set whether or not to do renormalization after modified graph in python pass(es). Set whether or not to do renormalization after modified graph in python pass(es).

@ -1,47 +1,117 @@
# Copyright 2020-2022 Huawei Technologies Co., Ltd # Copyright 2020-2022 Huawei Technologies Co., Ltd
# # 代码版权声明说明此代码由华为技术有限公司在2020-2022年间开发
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# 说明此代码使用Apache License 2.0版本的许可证
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# 说明除非遵守许可证,否则不得使用此文件
# You may obtain a copy of the License at # You may obtain a copy of the License at
# # 提供许可证的获取地址
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# # 许可证的具体地址
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# 除非适用法律要求或书面同意
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# 许可证在“现状”基础上进行分发
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 不附带任何形式的明示或暗示的担保或条件
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# 请参阅许可证了解特定的权限和
# limitations under the License. # limitations under the License.
# 限制条件
# ============================================================================ # ============================================================================
# 标准分割线,通常用于分隔许可证部分与代码部分
"""cell""" """cell"""
# 文档字符串模块的名称为cell
import gc import gc
# 导入垃圾回收模块,用于管理内存
import inspect import inspect
# 导入inspect模块用于获取活对象的信息
import os import os
# 导入os模块用于与操作系统进行交互
import time import time
# 导入time模块用于处理时间相关操作
from collections import OrderedDict from collections import OrderedDict
# 从collections模块导入OrderedDict类用于创建有序字典
from types import FunctionType, MethodType from types import FunctionType, MethodType
# 从types模块导入FunctionType和MethodType类用于类型检查
import numpy import numpy
# 导入numpy模块用于科学计算
from mindspore._checkparam import args_type_check from mindspore._checkparam import args_type_check
# 从mindspore._checkparam模块导入args_type_check函数用于检查函数参数的类型
from mindspore import log as logger from mindspore import log as logger
# 从mindspore模块导入log模块并命名为logger用于日志记录
from mindspore.common.parameter import PARAMETER_NAME_DEFAULT from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
# 从mindspore.common.parameter模块导入PARAMETER_NAME_DEFAULT常量用于默认参数名称
from mindspore.common.hook_handle import HookHandle from mindspore.common.hook_handle import HookHandle
# 从mindspore.common.hook_handle模块导入HookHandle类用于管理钩子处理
from mindspore.context import ParallelMode from mindspore.context import ParallelMode
# 从mindspore.context模块导入ParallelMode类用于并行模式配置
from mindspore.ops.composite import Shard from mindspore.ops.composite import Shard
# 从mindspore.ops.composite模块导入Shard类用于分片操作
from .. import context from .. import context
# 导入相对路径的context模块用于上下文配置
from .._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType from .._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
# 从相对路径的_c_expression模块导入多个函数和类用于初始化管道、更新函数图超参数、Cell的基础类、函数图类、混合精度类型
from .._checkparam import Validator from .._checkparam import Validator
# 从相对路径的_checkparam模块导入Validator类用于参数验证
from ..common import dtype as mstype from ..common import dtype as mstype
# 从相对路径的common模块导入dtype模块并重命名为mstype用于数据类型定义
from ..common.api import _cell_graph_executor, _pynative_executor, _check_all_tensor, cells_compile_cache from ..common.api import _cell_graph_executor, _pynative_executor, _check_all_tensor, cells_compile_cache
# 从相对路径的common.api模块导入多个函数和类用于单元图执行器、原生模式执行器、检查所有张量、编译缓存
from ..common.parameter import Parameter, ParameterTuple from ..common.parameter import Parameter, ParameterTuple
# 从相对路径的common.parameter模块导入Parameter类和ParameterTuple类用于参数和参数元组
from ..common.variable import Variable from ..common.variable import Variable
# 从相对路径的common.variable模块导入Variable类用于变量表示
from ..common.tensor import Tensor, CSRTensor, COOTensor from ..common.tensor import Tensor, CSRTensor, COOTensor
# 从相对路径的common.tensor模块导入Tensor类、CSRTensor类和COOTensor类用于张量表示
from ..ops.operations import Cast from ..ops.operations import Cast
# 从相对路径的ops.operations模块导入Cast类用于类型转换操作
from ..ops.primitive import Primitive from ..ops.primitive import Primitive
# 从相对路径的ops.primitive模块导入Primitive类用于基础操作
from ..ops.operations import _inner_ops as inner from ..ops.operations import _inner_ops as inner
from ..parallel._tensor import _load_tensor_by_layout # 从相对路径的ops.operations模块导入_inner_ops并重命名为inner用于内部操作
from ..parallel._tensor import _load_tensor_by_layout
# 从相对路径的parallel._tensor模块导入_load_tensor_by_layout函数用于按布局加载张量
class Cell(Cell_): class Cell(Cell_):
# 定义Cell类继承自Cell_类这是MindSpore中神经网络的基本构建单元
""" """
The basic building block of neural networks in MindSpore. The model or neural network layer should inherit this The basic building block of neural networks in MindSpore. The model or neural network layer should inherit this
base class. base class.
@ -81,6 +151,8 @@ class Cell(Cell_):
... # the parameter's name will be 'net.weight'. ... # the parameter's name will be 'net.weight'.
[Parameter (name=weight, shape=(240, 120, 4, 4), dtype=Float32, requires_grad=True)] [Parameter (name=weight, shape=(240, 120, 4, 4), dtype=Float32, requires_grad=True)]
""" """
# 类文档字符串解释Cell类的作用、继承关系、参数、支持平台及示例
class _CellGuard: class _CellGuard:
"""Detecting whether the cell is a top-level cell with the 'with statement'.""" """Detecting whether the cell is a top-level cell with the 'with statement'."""

@ -22,61 +22,68 @@ from ...common.api import ms_function
class _FirstGrad(Cell): class _FirstGrad(Cell):
# 计算第一个梯度的类
def __init__(self, fn): def __init__(self, fn):
super(_FirstGrad, self).__init__() super(_FirstGrad, self).__init__()
self.first_grad_op = C.GradOperation(sens_param=True, get_all=True) self.first_grad_op = C.GradOperation(sens_param=True, get_all=True)
self.fn = fn self.fn = fn
def construct(self, u, first_grad_input): def construct(self, u, first_grad_input):
# 构造方法,用于计算梯度
return self.first_grad_op(self.fn)(*first_grad_input, u) return self.first_grad_op(self.fn)(*first_grad_input, u)
class _JvpFirstGrad(Cell): class _JvpFirstGrad(Cell):
# 计算Jacobian-Vector-Product的第一个梯度的类
def __init__(self): def __init__(self):
super(_JvpFirstGrad, self).__init__() super(_JvpFirstGrad, self).__init__()
self.first_grad_op = C.GradOperation(sens_param=True, get_all=True) self.first_grad_op = C.GradOperation(sens_param=True, get_all=True)
def construct(self, u, fn, first_grad_input): def construct(self, u, fn, first_grad_input):
# 构造方法用于计算JVP的第一个梯度
return self.first_grad_op(fn)(*first_grad_input, u) return self.first_grad_op(fn)(*first_grad_input, u)
class _FirstGradSingleValue(Cell): class _FirstGradSingleValue(Cell):
# 计算单值梯度的类
def __init__(self, fn): def __init__(self, fn):
super(_FirstGradSingleValue, self).__init__() super(_FirstGradSingleValue, self).__init__()
self.first_grad_single_value_op = C.GradOperation(sens_param=True) self.first_grad_single_value_op = C.GradOperation(sens_param=True)
self.fn = fn self.fn = fn
def construct(self, u, first_grad_single_value_input): def construct(self, u, first_grad_single_value_input):
# 构造方法,用于计算单值梯度
return self.first_grad_single_value_op(self.fn)(*first_grad_single_value_input, u) return self.first_grad_single_value_op(self.fn)(*first_grad_single_value_input, u)
class _JvpFirstGradSingleValue(Cell): class _JvpFirstGradSingleValue(Cell):
# 计算Jacobian-Vector-Product的单值梯度的类
def __init__(self): def __init__(self):
super(_JvpFirstGradSingleValue, self).__init__() super(_JvpFirstGradSingleValue, self).__init__()
self.first_grad_single_value_op = C.GradOperation(sens_param=True) self.first_grad_single_value_op = C.GradOperation(sens_param=True)
def construct(self, u, fn, first_grad_single_value_input): def construct(self, u, fn, first_grad_single_value_input):
# 构造方法用于计算JVP的单值梯度
return self.first_grad_single_value_op(fn)(*first_grad_single_value_input, u) return self.first_grad_single_value_op(fn)(*first_grad_single_value_input, u)
class Jvp(Cell): class Jvp(Cell):
""" """
Compute the jacobian-vector-product of the given fn. Jvp is equivalent to forward mode autodiff. 计算给定fn的雅可比向量积Jvp等同于前向模式自动微分
Args: Args:
fn (Cell): The fn that takes Tensor inputs and returns a tuple of Tensors or a Tensor. fn (Cell): 接受Tensor输入并返回Tensor元组或Tensor的fn
Inputs: Inputs:
- **inputs** (Tensors) - The inputs to `fn`. - **inputs** (Tensors) - `fn`的输入
- **v** (Tensors or Tuple of Tensors) - The vector for which the Jacobian vector product is computed. - **v** (Tensors Tensor元组) - 用于计算雅可比向量积的向量
Must have the same size as the input of `fn`. 必须与`fn`的输入大小相同
Outputs: Outputs:
A tuple with 2 Tensors or Tuple of Tensors: 包含2个Tensors或Tensor元组的元组
- **net_output** (Tensors or Tuple of Tensors) - The output of `fn(inputs)`. - **net_output** (Tensors Tensor元组) - `fn(inputs)`的输出
- **jvp** (Tensors or Tuple of Tensors) - The result of the jacobian vector product. - **jvp** (Tensors Tensor元组) - 雅可比向量积的结果
Supported Platforms: Supported Platforms:
``Ascend`` ``GPU`` ``CPU`` ``Ascend`` ``GPU`` ``CPU``
@ -113,6 +120,7 @@ class Jvp(Cell):
@ms_function @ms_function
def construct(self, *args): def construct(self, *args):
# 构造方法用于计算JVP
jvp_input = args[0:-1] jvp_input = args[0:-1]
v = args[-1] v = args[-1]
output = self.fn(*jvp_input) output = self.fn(*jvp_input)
@ -135,8 +143,8 @@ class Jvp(Cell):
class _JvpInner(Cell): class _JvpInner(Cell):
""" """
Compute the jacobian-vector-product of the given network. Jvp is equivalent to forward mode autodiff. 计算给定网络的雅可比向量积Jvp等同于前向模式自动微分
This class implements the inner process of function jvp. 该类实现了JVP的内部过程
""" """
def __init__(self): def __init__(self):
super(_JvpInner, self).__init__() super(_JvpInner, self).__init__()
@ -152,6 +160,7 @@ class _JvpInner(Cell):
self.tuple_len = Primitive("tuple_len") self.tuple_len = Primitive("tuple_len")
def construct(self, *args): def construct(self, *args):
# 构造方法用于计算内部JVP
fn = args[0] fn = args[0]
v = args[1] v = args[1]
jvp_input = args[2:] jvp_input = args[2:]
@ -175,22 +184,21 @@ class _JvpInner(Cell):
class Vjp(Cell): class Vjp(Cell):
""" """
Computes the dot product between a vector `v` and the Jacobian of the given fn at the point 计算给定向量`v`与给定fn在输入点处的雅可比的点积
given by the inputs.
Args: Args:
fn (Cell): The fn that takes Tensor inputs and returns a tuple of Tensors or a Tensor. fn (Cell): 接受Tensor输入并返回Tensor元组或Tensor的fn
Inputs: Inputs:
- **inputs** (Tensors) - The inputs to `fn`. Must be a tuple or a list. - **inputs** (Tensors) - `fn`的输入必须是元组或列表
- **v** (Tensors or Tuple of Tensors) - The vector for which the vector Jacobian product is computed. - **v** (Tensors Tensor元组) - 用于计算向量雅可比积的向量
Must have the same size as the output of `fn`. 必须与`fn`的输出大小相同
Outputs: Outputs:
A tuple with 2 Tensors or Tuple of Tensors: 包含2个Tensors或Tensor元组的元组
- **net_output** (Tensors or Tuple of Tensors) - The output of `fn(inputs)`. - **net_output** (Tensors Tensor元组) - `fn(inputs)`的输出
- **vjp** (Tensors or Tuple of Tensors) - The result of the dot product. - **vjp** (Tensors Tensor元组) - 点积的结果
Supported Platforms: Supported Platforms:
``Ascend`` ``GPU`` ``CPU`` ``Ascend`` ``GPU`` ``CPU``
@ -226,6 +234,7 @@ class Vjp(Cell):
@ms_function @ms_function
def construct(self, *args): def construct(self, *args):
# 构造方法用于计算VJP
front_input = args[0:-1] front_input = args[0:-1]
output = self.fn(*front_input) output = self.fn(*front_input)
if self.tuple_len(front_input) == 1: if self.tuple_len(front_input) == 1:
@ -237,8 +246,8 @@ class Vjp(Cell):
class _VjpInner(Cell): class _VjpInner(Cell):
""" """
Computes the dot product between a vector `v` and the Jacobian of the given network at the point 计算给定向量`v`与给定网络在输入点处的雅可比的点积
given by the inputs. This class implements the inner process of function vjp. 该类实现了VJP的内部过程
""" """
def __init__(self): def __init__(self):
@ -248,6 +257,7 @@ class _VjpInner(Cell):
self.tuple_len = Primitive("tuple_len") self.tuple_len = Primitive("tuple_len")
def construct(self, *args): def construct(self, *args):
# 构造方法用于计算内部VJP
fn = args[0] fn = args[0]
front_input = args[1:-1] front_input = args[1:-1]
input_with_v = args[1:] input_with_v = args[1:]

@ -48,24 +48,23 @@ class LossBase(Cell):
""" """
def __init__(self, reduction='mean'): def __init__(self, reduction='mean'):
"""Initialize Loss.""" """Initialize Loss.""" # 初始化LossBase类接收一个参数reduction默认值为'mean'
super(LossBase, self).__init__() super(LossBase, self).__init__() # 调用父类Cell的初始化方法
if reduction not in ('mean', 'sum', 'none'): if reduction not in ('mean', 'sum', 'none'): # 检查reduction参数是否为'mean', 'sum', 'none'中的一个
raise ValueError(f"For '{self.cls_name}', the 'reduction' should be in ['mean', 'sum', 'none'], " raise ValueError(f"For '{self.cls_name}', the 'reduction' should be in ['mean', 'sum', 'none'], " # 如果参数不在允许的范围内抛出ValueError
f"but got {reduction}.") f"but got {reduction}.")
self.average = True # 设置average属性为True默认进行平均
self.average = True self.reduce = True # 设置reduce属性为True默认进行降维
self.reduce = True if reduction == 'sum': # 如果reduction参数为'sum'
if reduction == 'sum': self.average = False # 设置average属性为False不进行平均
self.average = False if reduction == 'none': # 如果reduction参数为'none'
if reduction == 'none': self.reduce = False # 设置reduce属性为False不进行降维
self.reduce = False
self.reduce_mean = P.ReduceMean() # 定义reduce_mean操作用于计算平均损失
self.reduce_mean = P.ReduceMean() self.reduce_sum = P.ReduceSum() # 定义reduce_sum操作用于计算总损失
self.reduce_sum = P.ReduceSum() self.mul = P.Mul() # 定义mul操作用于权重乘法
self.mul = P.Mul() self.cast = P.Cast() # 定义cast操作用于数据类型转换
self.cast = P.Cast()
def get_axis(self, x): def get_axis(self, x):
""" """
@ -98,10 +97,10 @@ class LossBase(Cell):
>>> print(output) >>> print(output)
(0, 1) (0, 1)
""" """
shape = F.shape(x) shape = F.shape(x) # 获取输入张量x的形状
length = F.tuple_len(shape) length = F.tuple_len(shape) # 获取形状的长度(即维度数量)
perm = F.make_range(0, length) perm = F.make_range(0, length) # 生成一个从0到length-1的元组表示所有轴
return perm return perm # 返回这个元组
def get_loss(self, x, weights=1.0): def get_loss(self, x, weights=1.0):
""" """
@ -141,20 +140,19 @@ class LossBase(Cell):
>>> print(output) >>> print(output)
0.11111111 0.11111111
""" """
input_dtype = x.dtype input_dtype = x.dtype # 获取输入张量x的数据类型
x = self.cast(x, mstype.float32) x = self.cast(x, mstype.float32) # 将输入张量x的数据类型转换为float32
weights = self.cast(weights, mstype.float32) weights = self.cast(weights, mstype.float32) # 将权重weights的数据类型转换为float32
x = self.mul(weights, x) x = self.mul(weights, x) # 将权重weights与输入张量x相乘
if self.reduce and self.average: if self.reduce and self.average: # 如果需要降维且进行平均
x = self.reduce_mean(x, self.get_axis(x)) x = self.reduce_mean(x, self.get_axis(x)) # 计算平均损失
if self.reduce and not self.average: if self.reduce and not self.average: # 如果需要降维但不进行平均
x = self.reduce_sum(x, self.get_axis(x)) x = self.reduce_sum(x, self.get_axis(x)) # 计算总损失
x = self.cast(x, input_dtype) x = self.cast(x, input_dtype) # 将损失x的数据类型转换回输入张量x的原始数据类型
return x return x # 返回计算得到的损失
def construct(self, logits, labels): def construct(self, logits, labels):
raise NotImplementedError raise NotImplementedError # 这是一个抽象方法,需要在子类中实现
class _Loss(LossBase): class _Loss(LossBase):
""" """
@ -162,23 +160,21 @@ class _Loss(LossBase):
""" """
def __init__(self, reduction='mean'): def __init__(self, reduction='mean'):
"""Initialize _Loss.""" """Initialize _Loss.""" # 初始化_Loss类接收一个参数reduction默认值为'mean'
log.warning("'_Loss' is deprecated from version 1.3 and " log.warning("'_Loss' is deprecated from version 1.3 and " # 输出警告信息提示_Loss类已过时
"will be removed in a future version, use 'LossBase' instead.") "will be removed in a future version, use 'LossBase' instead.")
super(_Loss, self).__init__(reduction) super(_Loss, self).__init__(reduction) # 调用父类LossBase的初始化方法
def construct(self, logits, labels): def construct(self, logits, labels):
raise NotImplementedError raise NotImplementedError # 这是一个抽象方法,需要在子类中实现
@constexpr @constexpr
def _check_is_tensor(param_name, input_data, cls_name): def _check_is_tensor(param_name, input_data, cls_name):
"""Internal function, used to check whether the input data is Tensor.""" """Internal function, used to check whether the input data is Tensor.""" # 定义一个内部函数用于检查输入数据是否为Tensor
if input_data is not None and not isinstance(F.typeof(input_data), mstype.tensor_type): if input_data is not None and not isinstance(F.typeof(input_data), mstype.tensor_type): # 如果输入数据不为None且类型不是Tensor
raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{mstype.tensor_type}', " raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{mstype.tensor_type}', " # 抛出TypeError
f"but got '{F.typeof(input_data)}'") f"but got '{F.typeof(input_data)}'")
class L1Loss(LossBase): class L1Loss(LossBase):
r""" r"""
L1Loss is used to calculate the mean absolute error between the predicted value and the target value. L1Loss is used to calculate the mean absolute error between the predicted value and the target value.
@ -238,16 +234,15 @@ class L1Loss(LossBase):
""" """
def __init__(self, reduction='mean'): def __init__(self, reduction='mean'):
"""Initialize L1Loss.""" """Initialize L1Loss.""" # 初始化L1Loss类接收一个参数reduction默认值为'mean'
super(L1Loss, self).__init__(reduction) super(L1Loss, self).__init__(reduction) # 调用父类LossBase的初始化方法
self.abs = P.Abs() self.abs = P.Abs() # 定义abs操作用于计算绝对值
def construct(self, logits, labels): def construct(self, logits, labels):
_check_is_tensor('logits', logits, self.cls_name) _check_is_tensor('logits', logits, self.cls_name) # 检查logits是否为Tensor
_check_is_tensor('labels', labels, self.cls_name) _check_is_tensor('labels', labels, self.cls_name) # 检查labels是否为Tensor
x = self.abs(logits - labels) x = self.abs(logits - labels) # 计算logits与labels的差的绝对值
return self.get_loss(x) return self.get_loss(x) # 使用self.get_loss方法计算加权损失并返回
class MSELoss(LossBase): class MSELoss(LossBase):
r""" r"""
@ -308,10 +303,10 @@ class MSELoss(LossBase):
""" """
def construct(self, logits, labels): def construct(self, logits, labels):
_check_is_tensor('logits', logits, self.cls_name) _check_is_tensor('logits', logits, self.cls_name) # 检查logits是否为Tensor
_check_is_tensor('labels', labels, self.cls_name) _check_is_tensor('labels', labels, self.cls_name) # 检查labels是否为Tensor
x = F.square(logits - labels) x = F.square(logits - labels) # 计算logits与labels的差的平方
return self.get_loss(x) return self.get_loss(x) # 使用self.get_loss方法计算加权损失并返回
class RMSELoss(LossBase): class RMSELoss(LossBase):
@ -356,15 +351,14 @@ class RMSELoss(LossBase):
""" """
def __init__(self): def __init__(self):
"""Initialize RMSELoss.""" """Initialize RMSELoss.""" # 初始化RMSELoss类
super(RMSELoss, self).__init__() super(RMSELoss, self).__init__() # 调用父类LossBase的初始化方法
self.MSELoss = MSELoss() self.MSELoss = MSELoss() # 初始化MSELoss对象用于计算均方误差
def construct(self, logits, label): def construct(self, logits, label):
rmse_loss = F.sqrt(self.MSELoss(logits, label)) rmse_loss = F.sqrt(self.MSELoss(logits, label)) # 计算均方误差损失然后取平方根得到RMSE损失
return rmse_loss
return rmse_loss # 返回计算得到的RMSE损失
class MAELoss(LossBase): class MAELoss(LossBase):
r""" r"""
@ -426,16 +420,15 @@ class MAELoss(LossBase):
""" """
def __init__(self, reduction='mean'): def __init__(self, reduction='mean'):
"""Initialize MAELoss.""" """Initialize MAELoss.""" # 初始化MAELoss类接收一个参数reduction默认值为'mean'
super(MAELoss, self).__init__(reduction) super(MAELoss, self).__init__(reduction) # 调用父类LossBase的初始化方法
self.abs = P.Abs() self.abs = P.Abs() # 定义abs操作用于计算绝对值
def construct(self, logits, label): def construct(self, logits, label):
_check_is_tensor('logits', logits, self.cls_name) _check_is_tensor('logits', logits, self.cls_name) # 检查logits是否为Tensor
_check_is_tensor('labels', label, self.cls_name) _check_is_tensor('labels', label, self.cls_name) # 检查labels是否为Tensor
x = self.abs(logits - label) x = self.abs(logits - label) # 计算logits与labels的差的绝对值
return self.get_loss(x) return self.get_loss(x) # 使用self.get_loss方法计算加权损失并返回
class SmoothL1Loss(LossBase): class SmoothL1Loss(LossBase):
r""" r"""
@ -491,16 +484,15 @@ class SmoothL1Loss(LossBase):
""" """
def __init__(self, beta=1.0): def __init__(self, beta=1.0):
"""Initialize SmoothL1Loss.""" """Initialize SmoothL1Loss.""" # 初始化SmoothL1Loss类接收一个参数beta默认值为1.0
super(SmoothL1Loss, self).__init__() super(SmoothL1Loss, self).__init__() # 调用父类LossBase的初始化方法
self.beta = beta self.beta = beta # 设置beta属性表示平滑阈值
self.smooth_l1_loss = P.SmoothL1Loss(self.beta) self.smooth_l1_loss = P.SmoothL1Loss(self.beta) # 定义smooth_l1_loss操作用于计算平滑L1损失
def construct(self, logits, labels): def construct(self, logits, labels):
_check_is_tensor('logits', logits, self.cls_name) _check_is_tensor('logits', logits, self.cls_name) # 检查logits是否为Tensor
_check_is_tensor('labels', labels, self.cls_name) _check_is_tensor('labels', labels, self.cls_name) # 检查labels是否为Tensor
return self.smooth_l1_loss(logits, labels) return self.smooth_l1_loss(logits, labels) # 使用self.smooth_l1_loss计算平滑L1损失并返回
class SoftMarginLoss(LossBase): class SoftMarginLoss(LossBase):
r""" r"""
@ -545,12 +537,11 @@ class SoftMarginLoss(LossBase):
""" """
def __init__(self, reduction='mean'): def __init__(self, reduction='mean'):
super(SoftMarginLoss, self).__init__() super(SoftMarginLoss, self).__init__() # 调用父类LossBase的初始化方法
self.soft_margin_loss = P.SoftMarginLoss(reduction) self.soft_margin_loss = P.SoftMarginLoss(reduction) # 定义soft_margin_loss操作用于计算SoftMargin损失
def construct(self, logits, labels): def construct(self, logits, labels):
return self.soft_margin_loss(logits, labels) return self.soft_margin_loss(logits, labels) # 使用self.soft_margin_loss计算SoftMargin损失并返回
class SoftmaxCrossEntropyWithLogits(LossBase): class SoftmaxCrossEntropyWithLogits(LossBase):
r""" r"""
@ -619,27 +610,28 @@ class SoftmaxCrossEntropyWithLogits(LossBase):
def __init__(self, def __init__(self,
sparse=False, sparse=False,
reduction='none'): reduction='none'):
"""Initialize SoftmaxCrossEntropyWithLogits.""" """Initialize SoftmaxCrossEntropyWithLogits.""" # 初始化SoftmaxCrossEntropyWithLogits类接收sparse和reduction两个参数
super(SoftmaxCrossEntropyWithLogits, self).__init__(reduction) super(SoftmaxCrossEntropyWithLogits, self).__init__(reduction) # 调用父类LossBase的初始化方法传入reduction参数
self.sparse = validator.check_bool(sparse, "sparse", self.cls_name) self.sparse = validator.check_bool(sparse, "sparse", self.cls_name) # 检查sparse参数是否为布尔值如果是则赋值给self.sparse
self.reduction = reduction self.reduction = reduction # 设置reduction属性表示减少类型
self.softmax_cross_entropy = P.SoftmaxCrossEntropyWithLogits() self.softmax_cross_entropy = P.SoftmaxCrossEntropyWithLogits() # 定义softmax_cross_entropy操作用于计算Softmax交叉熵损失
self.one_hot = P.OneHot() self.one_hot = P.OneHot() # 定义one_hot操作用于将标签转换为OneHot编码
self.on_value = Tensor(1.0, mstype.float32) self.on_value = Tensor(1.0, mstype.float32) # 定义on_value属性表示OneHot编码中正类的值
self.off_value = Tensor(0., mstype.float32) self.off_value = Tensor(0., mstype.float32) # 定义off_value属性表示OneHot编码中负类的值
self.is_cpugpu = context.get_context('device_target') in ["CPU", "GPU"] self.is_cpugpu = context.get_context('device_target') in ["CPU", "GPU"] # 定义is_cpugpu属性表示是否在CPU或GPU上运行
self.sparse_softmax_cross_entropy = P.SparseSoftmaxCrossEntropyWithLogits() self.sparse_softmax_cross_entropy = P.SparseSoftmaxCrossEntropyWithLogits() # 定义sparse_softmax_cross_entropy操作用于计算稀疏Softmax交叉熵损失
def construct(self, logits, labels): def construct(self, logits, labels):
_check_is_tensor('logits', logits, self.cls_name) _check_is_tensor('logits', logits, self.cls_name) # 检查logits是否为Tensor
_check_is_tensor('labels', labels, self.cls_name) _check_is_tensor('labels', labels, self.cls_name) # 检查labels是否为Tensor
if self.sparse: if self.sparse: # 如果使用稀疏标签格式
if self.reduction == 'mean': if self.reduction == 'mean': # 如果reduction为'mean'
x = self.sparse_softmax_cross_entropy(logits, labels) x = self.sparse_softmax_cross_entropy(logits, labels) # 使用稀疏Softmax交叉熵损失计算损失
return x return x # 返回计算得到的损失
labels = self.one_hot(labels, F.shape(logits)[-1], self.on_value, self.off_value) labels = self.one_hot(labels, F.shape(logits)[-1], self.on_value, self.off_value) # 将labels转换为OneHot编码
x = self.softmax_cross_entropy(logits, labels)[0] x = self.softmax_cross_entropy(logits, labels)[0] # 计算Softmax交叉熵损失取第一个返回值
return self.get_loss(x) return self.get_loss(x) # 使用self.get_loss方法计算加权损失并返回
@constexpr @constexpr

@ -85,58 +85,84 @@ def array(obj, dtype=None, copy=True, ndmin=0):
>>> print(np.array([1,2,3])) >>> print(np.array([1,2,3]))
[1 2 3] [1 2 3]
""" """
if dtype is not None: if dtype is not None: # 如果用户指定了数据类型则检查并转换为mindspore的数据类型
dtype = _check_dtype(dtype) dtype = _check_dtype(dtype)
res = asarray(obj, dtype) res = asarray(obj, dtype) # 将输入对象转换为tensor
if ndmin > res.ndim: if ndmin > res.ndim: # 如果用户指定的最小维度大于转换后的tensor维度则在tensor的前面添加维度
if res.size == 0: if res.size == 0: # 如果tensor为空抛出异常
_raise_value_error("Empty tensor cannot be expanded beyond the current dimension.") _raise_value_error("Empty tensor cannot be expanded beyond the current dimension.")
res = _expand(res, ndmin) res = _expand(res, ndmin) # 扩展tensor的维度
if copy and isinstance(obj, Tensor): if copy and isinstance(obj, Tensor): # 如果copy为True且输入对象已经是tensor则创建其副本
res = copy_(res) res = copy_(res)
elif dtype is not None and dtype != res.dtype: elif dtype is not None and dtype != res.dtype: # 如果用户指定了数据类型且与转换后的tensor数据类型不同则转换数据类型
res = res.astype(dtype) res = res.astype(dtype)
return res return res # 返回最终生成的tensor
@constexpr @constexpr
def asarray_const(a, dtype=None): def asarray_const(a, dtype=None):
# 标记此函数为constexpr意味着它是一个编译时常量函数
"""Converts the input to tensor. Note here `a` cannot be tensor itself.""" """Converts the input to tensor. Note here `a` cannot be tensor itself."""
# 文档字符串解释函数作用将输入转换为张量注意这里的a不能是张量本身
_check_input_for_asarray(a) _check_input_for_asarray(a)
# 检查输入a是否符合asarray函数的输入要求
if dtype is not None: if dtype is not None:
# 如果dtype不为None
dtype = _check_dtype(dtype) dtype = _check_dtype(dtype)
# 检查并确认dtype是一个有效的数据类型
if isinstance(a, (float, int, bool)) and dtype is None: if isinstance(a, (float, int, bool)) and dtype is None:
# 如果a是float、int或bool类型并且dtype未指定
dtype = _get_dtype_from_scalar(a) dtype = _get_dtype_from_scalar(a)
# 从标量a中获取数据类型并赋值给dtype
if isinstance(a, (list, tuple)): if isinstance(a, (list, tuple)):
# 如果a是list或tuple类型
# Convert all tuple/nested tuples to lists # Convert all tuple/nested tuples to lists
a = _deep_list(a) a = _deep_list(a)
# 将所有tuple及其嵌套的tuple转换为list
# Convert all tensor sub-elements to numpy arrays # Convert all tensor sub-elements to numpy arrays
a = _deep_tensor_to_nparray(a) a = _deep_tensor_to_nparray(a)
# 将所有tensor子元素转换为numpy数组
a = onp.asarray(a) a = onp.asarray(a)
# 使用numpy的asarray函数将a转换为numpy数组
if a.dtype is onp.dtype('object'): if a.dtype is onp.dtype('object'):
# 如果转换后的numpy数组的数据类型是object
raise ValueError('Input array must have the same size across all dimensions.') raise ValueError('Input array must have the same size across all dimensions.')
# 抛出ValueError表示输入数组在所有维度上必须具有相同的大小
# If dtype is not specified, we keep consistent with numpy decision # If dtype is not specified, we keep consistent with numpy decision
# only exceptions are: we use int/float32 # only exceptions are: we use int/float32
if dtype is None: if dtype is None:
# 如果dtype未指定
dtype = mstype.pytype_to_dtype(a.dtype) dtype = mstype.pytype_to_dtype(a.dtype)
# 将numpy数组的数据类型转换为mindspore的dtype
if dtype == mstype.float64: if dtype == mstype.float64:
# 如果dtype是float64
dtype = mstype.float32 dtype = mstype.float32
# 将dtype改为float32
elif dtype == mstype.int64: elif dtype == mstype.int64:
# 如果dtype是int64
dtype = mstype.int32 dtype = mstype.int32
# 将dtype改为int32
if isinstance(a, onp.ndarray) and dtype is None: if isinstance(a, onp.ndarray) and dtype is None:
# 如果a是numpy数组并且dtype未指定
if a.dtype is onp.dtype('object'): if a.dtype is onp.dtype('object'):
# 如果numpy数组的数据类型是object
raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.") raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.")
# 抛出TypeError表示输入数据包含不支持的元素
dtype = mstype.pytype_to_dtype(a.dtype) dtype = mstype.pytype_to_dtype(a.dtype)
# 将numpy数组的数据类型转换为mindspore的dtype
a = Tensor.from_numpy(a) a = Tensor.from_numpy(a)
# 将numpy数组转换为mindspore的Tensor
return Tensor(a, dtype=dtype) return Tensor(a, dtype=dtype)
# 返回一个具有指定dtype的Tensor
def asarray(a, dtype=None): def asarray(a, dtype=None):
@ -168,29 +194,46 @@ def asarray(a, dtype=None):
[1 2 3] [1 2 3]
""" """
if dtype is not None: if dtype is not None:
# 如果dtype不为None
dtype = _check_dtype(dtype) dtype = _check_dtype(dtype)
# 检查并确认dtype是一个有效的数据类型
if isinstance(a, Tensor): if isinstance(a, Tensor):
# 如果a是Tensor类型
if dtype is None or dtype == a.dtype: if dtype is None or dtype == a.dtype:
# 如果dtype未指定或指定的数据类型与a的数据类型相同
return a return a
# 直接返回a
return a.astype(dtype) return a.astype(dtype)
# 如果指定的数据类型与a的数据类型不同将a的数据类型转换为指定的dtype并返回
return asarray_const(a, dtype) return asarray_const(a, dtype)
# 如果a不是Tensor类型调用asarray_const函数将其转换为Tensor并返回
@constexpr @constexpr
def asfarray_const(a, dtype=mstype.float32): def asfarray_const(a, dtype=mstype.float32):
"""Converts the input to tensor. Note here `a` cannot be tensor itself.""" """Converts the input to tensor. Note here `a` cannot be tensor itself."""
# 文档字符串解释函数作用将输入转换为张量注意这里的a不能是张量本身
_check_input_for_asarray(a) _check_input_for_asarray(a)
# 检查输入a是否符合asarray函数的输入要求
if isinstance(a, (list, tuple)): if isinstance(a, (list, tuple)):
# 如果a是list或tuple类型
# Convert all tuple/nested tuples to lists # Convert all tuple/nested tuples to lists
a = _deep_list(a) a = _deep_list(a)
# 将所有tuple及其嵌套的tuple转换为list
# Convert all tensor sub-elements to numpy arrays # Convert all tensor sub-elements to numpy arrays
a = _deep_tensor_to_nparray(a) a = _deep_tensor_to_nparray(a)
# 将所有tensor子元素转换为numpy数组
a = onp.asarray(a) a = onp.asarray(a)
# 使用numpy的asarray函数将a转换为numpy数组
if a.dtype is onp.dtype('object'): if a.dtype is onp.dtype('object'):
# 如果转换后的numpy数组的数据类型是object
raise ValueError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.") raise ValueError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.")
# 抛出ValueError表示输入数组在所有维度上必须具有相同的大小
a = Tensor.from_numpy(a) a = Tensor.from_numpy(a)
# 将numpy数组转换为mindspore的Tensor
return Tensor(a, dtype) return Tensor(a, dtype)
# 返回一个具有指定dtype的Tensor
def asfarray(a, dtype=mstype.float32): def asfarray(a, dtype=mstype.float32):
@ -206,7 +249,6 @@ def asfarray(a, dtype=mstype.float32):
be in format of np.int32, or \'int32\'. If dtype is :class:`None`, the data type be in format of np.int32, or \'int32\'. If dtype is :class:`None`, the data type
of the new tensor will be inferred from `a`. Default is :class:`mindspore.float32`. of the new tensor will be inferred from `a`. Default is :class:`mindspore.float32`.
Returns: Returns:
Tensor, generated tensor with the specified float dtype. Tensor, generated tensor with the specified float dtype.
@ -223,16 +265,24 @@ def asfarray(a, dtype=mstype.float32):
[1. 2. 3.] [1. 2. 3.]
""" """
if dtype is None: if dtype is None:
# 如果dtype未指定
return asarray(a) return asarray(a)
# 调用asarray函数将a转换为Tensor并返回
dtype = _check_dtype(dtype) dtype = _check_dtype(dtype)
# 检查并确认dtype是一个有效的数据类型
if dtype not in (mstype.float16, mstype.float32, mstype.float64): if dtype not in (mstype.float16, mstype.float32, mstype.float64):
# 如果dtype不是float16、float32或float64
dtype = mstype.float32 dtype = mstype.float32
# 将dtype改为float32
if isinstance(a, Tensor): if isinstance(a, Tensor):
# 如果a是Tensor类型
return a.astype(dtype) return a.astype(dtype)
# 将a的数据类型转换为指定的dtype并返回
return asfarray_const(a, dtype) return asfarray_const(a, dtype)
# 如果a不是Tensor类型调用asfarray_const函数将其转换为Tensor并返回
def copy_(a): def copy_(a):
@ -261,7 +311,9 @@ def copy_(a):
[1. 1.]] [1. 1.]]
""" """
a = asarray(a) a = asarray(a)
# 使用asarray函数将a转换为Tensor
return a.copy() return a.copy()
# 返回a的副本
def ones(shape, dtype=mstype.float32): def ones(shape, dtype=mstype.float32):
@ -290,11 +342,17 @@ def ones(shape, dtype=mstype.float32):
[1. 1.]] [1. 1.]]
""" """
shape = _check_shape(shape) shape = _check_shape(shape)
# 检查并确认shape是一个有效的形状
dtype = _check_dtype(dtype) dtype = _check_dtype(dtype)
# 检查并确认dtype是一个有效的数据类型
if _is_shape_empty(shape): if _is_shape_empty(shape):
# 如果shape表示的形状是空的
return full(shape, 1.0, dtype) return full(shape, 1.0, dtype)
# 使用full函数创建一个指定形状、数据类型并用1.0填充的Tensor
output = F.fill(dtype, shape, 1) output = F.fill(dtype, shape, 1)
# 使用F.fill函数创建一个指定形状、数据类型并用1填充的Tensor
return output return output
# 返回创建的Tensor
def zeros(shape, dtype=mstype.float32): def zeros(shape, dtype=mstype.float32):
@ -323,11 +381,17 @@ def zeros(shape, dtype=mstype.float32):
[0. 0.]] [0. 0.]]
""" """
shape = _check_shape(shape) shape = _check_shape(shape)
# 检查并确认shape是一个有效的形状
dtype = _check_dtype(dtype) dtype = _check_dtype(dtype)
# 检查并确认dtype是一个有效的数据类型
if _is_shape_empty(shape): if _is_shape_empty(shape):
# 如果shape表示的形状是空的
return full(shape, 0.0, dtype) return full(shape, 0.0, dtype)
# 使用full函数创建一个指定形状、数据类型并用0.0填充的Tensor
output = F.fill(dtype, shape, 0) output = F.fill(dtype, shape, 0)
# 使用F.fill函数创建一个指定形状、数据类型并用0填充的Tensor
return output return output
# 返回创建的Tensor
def full(shape, fill_value, dtype=None): def full(shape, fill_value, dtype=None):
@ -360,24 +424,42 @@ def full(shape, fill_value, dtype=None):
[True True]] [True True]]
""" """
shape = _check_shape(shape) shape = _check_shape(shape)
# 检查并确认shape是一个有效的形状
if not isinstance(fill_value, ARRAY_TYPES): if not isinstance(fill_value, ARRAY_TYPES):
# 如果fill_value不是int、float、bool、list、tuple、Tensor类型
_raise_type_error("fill value should be int, float, bool, list, tuple, Tensor, but got", fill_value) _raise_type_error("fill value should be int, float, bool, list, tuple, Tensor, but got", fill_value)
# 抛出TypeError表示fill_value类型不支持
if dtype is not None: if dtype is not None:
# 如果dtype不为None
dtype = _check_dtype(dtype) dtype = _check_dtype(dtype)
# 检查并确认dtype是一个有效的数据类型
else: else:
# 如果dtype为None
if isinstance(fill_value, (int, float, bool)): if isinstance(fill_value, (int, float, bool)):
# 如果fill_value是int、float或bool类型
dtype = _get_dtype_from_scalar(fill_value) dtype = _get_dtype_from_scalar(fill_value)
# 从标量fill_value中获取数据类型并赋值给dtype
if isinstance(fill_value, Tensor): if isinstance(fill_value, Tensor):
# 如果fill_value是Tensor类型
dtype = fill_value.dtype dtype = fill_value.dtype
# 从Tensor fill_value中获取数据类型并赋值给dtype
if not _is_shape_empty(shape): if not _is_shape_empty(shape):
# 如果shape表示的形状不是空的
if isinstance(fill_value, (int, float, bool)): if isinstance(fill_value, (int, float, bool)):
# 如果fill_value是int、float或bool类型
return F.fill(dtype, shape, fill_value) return F.fill(dtype, shape, fill_value)
# 使用F.fill函数创建一个指定形状、数据类型并用fill_value填充的Tensor
if isinstance(fill_value, (list, tuple)): if isinstance(fill_value, (list, tuple)):
# 如果fill_value是list或tuple类型
fill_value = asarray_const(fill_value) fill_value = asarray_const(fill_value)
# 使用asarray_const函数将fill_value转换为Tensor
return broadcast_to(fill_value, shape) return broadcast_to(fill_value, shape)
# 使用broadcast_to函数将fill_value广播到指定的shape并返回结果
# if shape contains zero, use c.Tensor() # if shape contains zero, use c.Tensor()
return _convert_64_to_32(empty_compile(dtype, shape)) return _convert_64_to_32(empty_compile(dtype, shape))
# 如果shape包含零使用empty_compile函数创建一个空的Tensor并使用_convert_64_to_32函数将数据类型从float64转换为float32
@constexpr @constexpr

@ -30,10 +30,12 @@ class AllGatherCell(Cell):
def __init__(self, group): def __init__(self, group):
super(AllGatherCell, self).__init__(auto_prefix=False) super(AllGatherCell, self).__init__(auto_prefix=False)
# 创建AllGather操作对象
self.allgather = AllGather(group) self.allgather = AllGather(group)
@ms_function() @ms_function()
def construct(self, x): def construct(self, x):
# 执行AllGather操作
x = self.allgather(x) x = self.allgather(x)
return x return x
@ -50,10 +52,12 @@ class SaveOptShardCkptCell(Cell):
""" """
def __init__(self, group): def __init__(self, group):
super(SaveOptShardCkptCell, self).__init__(auto_prefix=False) super(SaveOptShardCkptCell, self).__init__(auto_prefix=False)
# 创建AllGather操作对象
self.allgather1 = AllGather(group) self.allgather1 = AllGather(group)
self.allgather2 = AllGather() self.allgather2 = AllGather()
def construct(self, x): def construct(self, x):
# 执行AllGather操作
x = self.allgather1(x) x = self.allgather1(x)
x = self.allgather2(x) x = self.allgather2(x)
@ -64,11 +68,14 @@ def get_allgather_cell(group, need_merge_twice=False):
"""Get AllGatherCell object.""" """Get AllGatherCell object."""
global _allgather_cell global _allgather_cell
if need_merge_twice: if need_merge_twice:
# 如果需要两次合并则创建SaveOptShardCkptCell对象
_allgather_cell = SaveOptShardCkptCell(group) _allgather_cell = SaveOptShardCkptCell(group)
else: else:
if group: if group:
# 如果有指定的设备组则创建AllGatherCell对象
_allgather_cell = AllGatherCell(group) _allgather_cell = AllGatherCell(group)
else: else:
# 否则创建AllGatherCell对象使用全局通信组
_allgather_cell = AllGatherCell(GlobalComm.WORLD_COMM_GROUP) _allgather_cell = AllGatherCell(GlobalComm.WORLD_COMM_GROUP)
return _allgather_cell return _allgather_cell
@ -77,4 +84,5 @@ def destroy_allgather_cell():
"""Destroy AllGatherCell object.""" """Destroy AllGatherCell object."""
global _allgather_cell global _allgather_cell
if _allgather_cell: if _allgather_cell:
# 销毁AllGatherCell对象
_allgather_cell = None _allgather_cell = None

@ -28,13 +28,16 @@ def parse_args():
Examples: Examples:
>>> parse_args() >>> parse_args()
""" """
# 创建一个ArgumentParser对象用于解析命令行参数描述信息为"MindSpore dependency packages version checker."
parser = ArgumentParser(description="MindSpore dependency packages version checker.") parser = ArgumentParser(description="MindSpore dependency packages version checker.")
# 添加一个命令行参数--mindspore_version类型为字符串帮助信息为"MindSpore version."
parser.add_argument("--mindspore_version", type=str, help="MindSpore version.") parser.add_argument("--mindspore_version", type=str, help="MindSpore version.")
# 添加一个命令行参数--supported_version类型为字符串可以多次指定帮助信息为"Supported environment version."
parser.add_argument("--supported_version", type=str, action='append', help="Supported environment version.") parser.add_argument("--supported_version", type=str, action='append', help="Supported environment version.")
# 解析命令行参数并返回结果
args = parser.parse_args() args = parser.parse_args()
return args return args
def check_deps_version(mindspore_version, supported_version): def check_deps_version(mindspore_version, supported_version):
""" """
check te/hccl/topi version check te/hccl/topi version
@ -46,6 +49,7 @@ def check_deps_version(mindspore_version, supported_version):
Returns: Returns:
void void
""" """
# 尝试导入并检查 hccl、te 和 topi 包的版本是否与支持的版本匹配
try: try:
from hccl import sys_version as hccl_version from hccl import sys_version as hccl_version
v = '.'.join(hccl_version.__sys_version__.split('.')[0:2]) v = '.'.join(hccl_version.__sys_version__.split('.')[0:2])
@ -63,18 +67,20 @@ def check_deps_version(mindspore_version, supported_version):
print(f"MindSpore version {mindspore_version} and \"topi\" wheel package version {v} does not " print(f"MindSpore version {mindspore_version} and \"topi\" wheel package version {v} does not "
"match, reference to the match info on: https://www.mindspore.cn/install") "match, reference to the match info on: https://www.mindspore.cn/install")
# 捕获导入错误并打印相应的检查失败信息
except ImportError as e: except ImportError as e:
print("CheckFailed: ", e.args) print("CheckFailed: ", e.args)
print("MindSpore relies on the 3 whl packages of \"te\", \"topi\" and \"hccl\" in the \"fwkacllib\" " print("MindSpore relies on the 3 whl packages of \"te\", \"topi\" and \"hccl\" in the \"fwkacllib\" "
"folder of the Ascend AI software package (Ascend Data Center Solution), please check whether they are " "folder of the Ascend AI software package (Ascend Data Center Solution), please check whether they are "
"installed correctly or not, reference to the match info on: https://www.mindspore.cn/install") "installed correctly or not, reference to the match info on: https://www.mindspore.cn/install")
def main(): def main():
# 解析命令行参数
args = parse_args() args = parse_args()
# 检查 mindspore 的版本是否在支持的版本范围内
check_deps_version(args.mindspore_version, args.supported_version) check_deps_version(args.mindspore_version, args.supported_version)
if __name__ == "__main__": if __name__ == "__main__":
sys.path = sys.path[1:] # avoid the impact of relative path env, only affect this process # 避免相对路径环境的影响,仅影响当前进程
sys.path = sys.path[1:]
main() main()

@ -31,47 +31,58 @@ class EnvChecker(metaclass=ABCMeta):
@abstractmethod @abstractmethod
def check_env(self, e): def check_env(self, e):
pass """检查环境是否符合要求"""
@abstractmethod @abstractmethod
def set_env(self): def set_env(self):
pass """设置环境"""
@abstractmethod @abstractmethod
def check_version(self): def check_version(self):
pass """检查版本是否符合要求"""
class GPUEnvChecker(EnvChecker): class GPUEnvChecker(EnvChecker):
"""GPU environment check.""" """GPU environment check."""
def __init__(self): def __init__(self):
# 初始化版本列表
self.version = ["10.1", "11.1"] self.version = ["10.1", "11.1"]
# 初始化库键到库名的映射字典
self.lib_key_to_lib_name = {'libcu': 'libcuda.so'} self.lib_key_to_lib_name = {'libcu': 'libcuda.so'}
# env # env
# 获取系统环境变量 PATH 的值
self.path = os.getenv("PATH") self.path = os.getenv("PATH")
# 获取系统环境变量 LD_LIBRARY_PATH 的值
self.ld_lib_path = os.getenv("LD_LIBRARY_PATH") self.ld_lib_path = os.getenv("LD_LIBRARY_PATH")
# check # check
# 初始化版本号为 "0"
self.v = "0" self.v = "0"
# 获取 CUDA 库的路径
self.cuda_lib_path = self._get_lib_path("libcu") self.cuda_lib_path = self._get_lib_path("libcu")
# 获取 CUDA 可执行文件的路径
self.cuda_bin_path = self._get_bin_path("cuda") self.cuda_bin_path = self._get_bin_path("cuda")
# 获取 cuDNN 库的路径
self.cudnn_lib_path = self._get_lib_path("libcudnn") self.cudnn_lib_path = self._get_lib_path("libcudnn")
def check_env(self, e): def check_env(self, e):
# 抛出传入的异常 e
raise e raise e
def set_env(self): def set_env(self):
# 设置环境变量,当前实现为空
return return
def _get_bin_path(self, bin_name): def _get_bin_path(self, bin_name):
"""Get bin path by bin name.""" """Get bin path by bin name."""
# 如果二进制名称为 "cuda",则调用获取 CUDA 二进制路径的方法
if bin_name == "cuda": if bin_name == "cuda":
return self._get_cuda_bin_path() return self._get_cuda_bin_path()
# 否则返回空列表
return [] return []
def _get_cuda_bin_path(self): def _get_cuda_bin_path(self):
"""Get cuda bin path by lib path.""" # Get cuda bin path by lib path.
path_list = [] path_list = []
for path in self.cuda_lib_path: for path in self.cuda_lib_path:
path = os.path.abspath(path.strip()+"/bin/") path = os.path.abspath(path.strip()+"/bin/")
@ -81,56 +92,87 @@ class GPUEnvChecker(EnvChecker):
def _get_nvcc_version(self, is_set_env): def _get_nvcc_version(self, is_set_env):
"""Get cuda version by nvcc command.""" """Get cuda version by nvcc command."""
# 运行 nvcc 命令获取 CUDA 版本信息
nvcc_result = subprocess.run(["nvcc", "--version | grep release"], nvcc_result = subprocess.run(["nvcc", "--version | grep release"],
timeout=3, text=True, capture_output=True, check=False) timeout=3, text=True, capture_output=True, check=False)
# 如果命令返回非零值,表示命令执行失败
if nvcc_result.returncode: if nvcc_result.returncode:
# 如果尚未设置环境变量
if not is_set_env: if not is_set_env:
# 遍历预设的 CUDA 二进制路径
for path in self.cuda_bin_path: for path in self.cuda_bin_path:
# 检查路径中是否存在 nvcc 文件
if Path(path + "/nvcc").is_file(): if Path(path + "/nvcc").is_file():
# 将路径添加到环境变量 PATH 中
os.environ['PATH'] = path + ":" + os.environ['PATH'] os.environ['PATH'] = path + ":" + os.environ['PATH']
# 递归调用以重新尝试获取版本信息
return self._get_nvcc_version(True) return self._get_nvcc_version(True)
# 如果命令执行失败且未找到 nvcc 文件,返回空字符串
return "" return ""
# 获取命令输出结果
result = nvcc_result.stdout result = nvcc_result.stdout
# 遍历输出结果的每一行
for line in result.split('\n'): for line in result.split('\n'):
if line: if line:
# 提取并返回 CUDA 版本号
return line.strip().split("release")[1].split(",")[0].strip() return line.strip().split("release")[1].split(",")[0].strip()
# 如果未找到版本信息,返回空字符串
return "" return ""
def _get_cudnn_version(self): def _get_cudnn_version(self):
"""Get cudnn version by libcudnn.so.""" """Get cudnn version by libcudnn.so."""
# 初始化cudnn版本列表为空
cudnn_version = [] cudnn_version = []
# 遍历cudnn库路径
for path in self.cudnn_lib_path: for path in self.cudnn_lib_path:
# 查找路径下所有的libcudnn.so文件
real_path = glob.glob(path + "/lib*/libcudnn.so.*.*") real_path = glob.glob(path + "/lib*/libcudnn.so.*.*")
# 如果没有找到对应的文件,继续下一个路径
if real_path == []: if real_path == []:
continue continue
# 使用ls命令获取文件信息
ls_cudnn = subprocess.run(["ls", real_path[0]], timeout=10, text=True, ls_cudnn = subprocess.run(["ls", real_path[0]], timeout=10, text=True,
capture_output=True, check=False) capture_output=True, check=False)
# 如果ls命令执行成功解析输出以获取版本号
if ls_cudnn.returncode == 0: if ls_cudnn.returncode == 0:
cudnn_version = ls_cudnn.stdout.split('/')[-1].strip('libcudnn.so.').strip().split('.') cudnn_version = ls_cudnn.stdout.split('/')[-1].strip('libcudnn.so.').strip().split('.')
# 如果版本号只有两个部分,添加一个'.0'作为第三部分
if len(cudnn_version) == 2: if len(cudnn_version) == 2:
cudnn_version.append('0') cudnn_version.append('0')
# 找到版本号后跳出循环
break break
# 将版本号列表转换为字符串
version_str = ''.join([n for n in cudnn_version]) version_str = ''.join([n for n in cudnn_version])
# 返回版本号的前三位
return version_str[0:3] return version_str[0:3]
def _get_cudart_version(self): def _get_cudart_version(self):
"""Get cuda runtime version by libcudart.so.""" """Get cuda runtime version by libcudart.so."""
# 遍历可能的 CUDA 库路径
for path in self.cuda_lib_path: for path in self.cuda_lib_path:
# 查找路径下所有可能的 libcudart.so 文件
real_path = glob.glob(path + "/lib*/libcudart.so.*.*.*") real_path = glob.glob(path + "/lib*/libcudart.so.*.*.*")
# 如果没有找到任何文件,则跳过当前路径
if real_path == []: if real_path == []:
continue continue
# 获取文件名信息以确定 CUDA 版本
ls_cudart = subprocess.run(["ls", real_path[0]], timeout=10, text=True, ls_cudart = subprocess.run(["ls", real_path[0]], timeout=10, text=True,
capture_output=True, check=False) capture_output=True, check=False)
# 如果命令成功执行,则解析输出以提取版本号
if ls_cudart.returncode == 0: if ls_cudart.returncode == 0:
self.v = ls_cudart.stdout.split('/')[-1].strip('libcudart.so.').strip() self.v = ls_cudart.stdout.split('/')[-1].strip('libcudart.so.').strip()
# 找到版本号后跳出循环
break break
# 返回找到的 CUDA 版本号
return self.v return self.v
def check_version(self): def check_version(self):
"""Check cuda version.""" """Check cuda version."""
version_match = False version_match = False
# 调用私有方法检查版本是否匹配并根据结果设置version_match标志
if self._check_version(): if self._check_version():
version_match = True version_match = True
# 如果版本不匹配根据CUDA版本号输出不同的警告信息
if not version_match: if not version_match:
if self.v == "0": if self.v == "0":
logger.warning("Can not found cuda libs, please confirm that the correct " logger.warning("Can not found cuda libs, please confirm that the correct "
@ -140,17 +182,20 @@ class GPUEnvChecker(EnvChecker):
logger.warning(f"MindSpore version {__version__} and cuda version {self.v} does not match, " logger.warning(f"MindSpore version {__version__} and cuda version {self.v} does not match, "
"please refer to the installation guide for version matching " "please refer to the installation guide for version matching "
"information: https://www.mindspore.cn/install") "information: https://www.mindspore.cn/install")
# 获取nvcc版本号并检查是否与MindSpore支持的版本匹配
nvcc_version = self._get_nvcc_version(False) nvcc_version = self._get_nvcc_version(False)
if nvcc_version and (nvcc_version not in self.version): if nvcc_version and (nvcc_version not in self.version):
logger.warning(f"MindSpore version {__version__} and nvcc(cuda bin) version {nvcc_version} " logger.warning(f"MindSpore version {__version__} and nvcc(cuda bin) version {nvcc_version} "
"does not match, please refer to the installation guide for version matching " "does not match, please refer to the installation guide for version matching "
"information: https://www.mindspore.cn/install") "information: https://www.mindspore.cn/install")
# 获取cudnn版本号并检查是否符合最低要求
cudnn_version = self._get_cudnn_version() cudnn_version = self._get_cudnn_version()
if cudnn_version and int(cudnn_version) < 760: if cudnn_version and int(cudnn_version) < 760:
logger.warning(f"MindSpore version {__version__} and cudDNN version {cudnn_version} " logger.warning(f"MindSpore version {__version__} and cudDNN version {cudnn_version} "
"does not match, please refer to the installation guide for version matching " "does not match, please refer to the installation guide for version matching "
"information: https://www.mindspore.cn/install. The recommended version is " "information: https://www.mindspore.cn/install. The recommended version is "
"CUDA10.1 with cuDNN7.6.x and CUDA11.1 with cuDNN8.0.x") "CUDA10.1 with cuDNN7.6.x and CUDA11.1 with cuDNN8.0.x")
# 检查cudnn版本号与CUDA版本号的兼容性对于CUDA 11.0以上版本cudnn版本需要至少为8.0
if cudnn_version and int(cudnn_version) < 800 and int(str(self.v).split('.')[0]) > 10: if cudnn_version and int(cudnn_version) < 800 and int(str(self.v).split('.')[0]) > 10:
logger.warning(f"CUDA version {self.v} and cuDNN version {cudnn_version} " logger.warning(f"CUDA version {self.v} and cuDNN version {cudnn_version} "
"does not match, please refer to the installation guide for version matching " "does not match, please refer to the installation guide for version matching "
@ -159,45 +204,58 @@ class GPUEnvChecker(EnvChecker):
def _check_version(self): def _check_version(self):
"""Check cuda version""" """Check cuda version"""
# 获取 CUDA 运行时版本
v = self._get_cudart_version() v = self._get_cudart_version()
# 解析版本字符串为版本对象
v = version.parse(v) v = version.parse(v)
# 构造版本号字符串,格式为 "主版本.次版本"
v_str = str(v.major) + "." + str(v.minor) v_str = str(v.major) + "." + str(v.minor)
# 检查构造的版本号字符串是否在预定义的版本列表中
if v_str not in self.version: if v_str not in self.version:
return False return False
# 版本号匹配,返回 True
return True return True
def _get_lib_path(self, lib_name): def _get_lib_path(self, lib_name):
"""Get gpu lib path by ldd command.""" """通过ldd命令获取gpu库路径。"""
path_list = [] path_list = [] # 初始化一个空列表用于存储路径
current_path = os.path.split(os.path.realpath(__file__))[0] current_path = os.path.split(os.path.realpath(__file__))[0] # 获取当前文件的绝对路径并分割以获取目录部分
mindspore_path = os.path.join(current_path, "../") mindspore_path = os.path.join(current_path, "../") # 构建mindspore路径通常是当前文件的上一级目录
try: try:
# 使用glob模块查找mindspore_path目录下所有以_c_expression.so开头的文件路径
real_path = glob.glob(mindspore_path + "/_c_expression*.so*") real_path = glob.glob(mindspore_path + "/_c_expression*.so*")
if real_path == []: if real_path == []: # 如果没有找到任何文件
logger.error(f"{self.lib_key_to_lib_name[lib_name]} (need by mindspore-gpu) is not found, please " # 记录错误日志提示用户确认_c_expression.so文件是否存在以及是否安装了正确的cuda版本
f"confirm that _c_expression.so is in directory:{mindspore_path} and the correct cuda " logger.error(f"{self.lib_key_to_lib_name[lib_name]} (mindspore-gpu所需的库) 未找到,请确认 "
"version has been installed, you can refer to the installation " f"_c_expression.so是否位于目录:{mindspore_path}并且已安装正确的cuda版本"
"guidelines: https://www.mindspore.cn/install") "您可以参考安装指南https://www.mindspore.cn/install")
return path_list return path_list # 返回空路径列表
# 使用subprocess.Popen执行ldd命令以获取依赖库的信息
ldd_r = subprocess.Popen(['ldd', real_path[0]], stdout=subprocess.PIPE) ldd_r = subprocess.Popen(['ldd', real_path[0]], stdout=subprocess.PIPE)
# 使用subprocess.Popen的stdin参数从ldd_r.stdout接收输出并执行grep命令以过滤出包含指定库名的信息
ldd_result = subprocess.Popen(['grep', lib_name], stdin=ldd_r.stdout, stdout=subprocess.PIPE) ldd_result = subprocess.Popen(['grep', lib_name], stdin=ldd_r.stdout, stdout=subprocess.PIPE)
# 获取grep命令的输出结果并解码为字符串
result = ldd_result.communicate()[0].decode() result = ldd_result.communicate()[0].decode()
for i in result.split('\n'): for i in result.split('\n'): # 按行分割结果字符串
# 使用partition方法从每一行中提取出库文件的路径
path = i.partition("=>")[2] path = i.partition("=>")[2]
if path.lower().find("not found") > 0: if path.lower().find("not found") > 0: # 如果路径中包含"not found"
logger.warning(f"Cuda {self.version} version(need by mindspore-gpu) is not found, please confirm " # 记录警告日志提示用户确认cuda路径是否已添加到环境变量LD_LIBRARY_PATH中
"that the path of cuda is set to the env LD_LIBRARY_PATH, please refer to the " logger.warning(f"Cuda {self.version}版本(由mindspore-gpu要求的) 未找到请确认cuda路径已设置到环境变量LD_LIBRARY_PATH中"
"installation guidelines: https://www.mindspore.cn/install") "您可以参考安装指南https://www.mindspore.cn/install")
continue continue # 继续下一次循环
# 从路径中去除库名部分
path = path.partition(lib_name)[0] path = path.partition(lib_name)[0]
if path: if path: # 如果路径非空
# 将路径的绝对路径并去除末尾斜杠后添加到path_list中
path_list.append(os.path.abspath(path.strip() + "../")) path_list.append(os.path.abspath(path.strip() + "../"))
# 返回path_list中唯一的路径
return np.unique(path_list) return np.unique(path_list)
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired: # 捕获subprocess.TimeoutExpired异常
logger.warning("Failed to check cuda version due to the ldd command timeout, please confirm that " # 记录警告日志提示用户确认cuda版本是否正确安装因为ldd命令超时
"the correct cuda version has been installed, you can refer to the " logger.warning("由于ldd命令超时无法检查cuda版本请确认已安装正确的cuda版本"
"installation guidelines: https://www.mindspore.cn/install") "您可以参考安装指南:https://www.mindspore.cn/install")
return path_list return path_list # 返回空路径列表
def _read_version(self, file_path): def _read_version(self, file_path):
"""Get gpu version info in version.txt.""" """Get gpu version info in version.txt."""
@ -211,70 +269,80 @@ class GPUEnvChecker(EnvChecker):
class AscendEnvChecker(EnvChecker): class AscendEnvChecker(EnvChecker):
"""ascend environment check""" """Ascend 环境检查类"""
def __init__(self): def __init__(self):
# 初始化 Ascend 环境检查器的版本列表
self.version = ["1.81"] self.version = ["1.81"]
# 定义不同路径下的 version.info 文件位置
atlas_nnae_version = "/usr/local/Ascend/nnae/latest/fwkacllib/version.info" atlas_nnae_version = "/usr/local/Ascend/nnae/latest/fwkacllib/version.info"
atlas_toolkit_version = "/usr/local/Ascend/ascend-toolkit/latest/fwkacllib/version.info" atlas_toolkit_version = "/usr/local/Ascend/ascend-toolkit/latest/fwkacllib/version.info"
hisi_fwk_version = "/usr/local/Ascend/latest/fwkacllib/version.info" hisi_fwk_version = "/usr/local/Ascend/latest/fwkacllib/version.info"
# 检查 Atlas NNAE 环境是否存在
if os.path.exists(atlas_nnae_version): if os.path.exists(atlas_nnae_version):
# atlas default path # 如果存在,设置默认路径
self.fwk_path = "/usr/local/Ascend/nnae/latest/fwkacllib" self.fwk_path = "/usr/local/Ascend/nnae/latest/fwkacllib" # Framework 路径
self.op_impl_path = "/usr/local/Ascend/nnae/latest/opp/op_impl/built-in/ai_core/tbe" self.op_impl_path = "/usr/local/Ascend/nnae/latest/opp/op_impl/built-in/ai_core/tbe" # Operator 实现路径
self.tbe_path = self.fwk_path + "/lib64" self.tbe_path = self.fwk_path + "/lib64" # TBE 库路径
self.cce_path = self.fwk_path + "/ccec_compiler/bin" self.cce_path = self.fwk_path + "/ccec_compiler/bin" # CCE 编译器路径
self.fwk_version = atlas_nnae_version self.fwk_version = atlas_nnae_version # Framework 版本文件路径
self.op_path = "/usr/local/Ascend/nnae/latest/opp" self.op_path = "/usr/local/Ascend/nnae/latest/opp" # Operator 路径
self.aicpu_path = "/usr/local/Ascend/nnae/latest" self.aicpu_path = "/usr/local/Ascend/nnae/latest" # AI CPU 路径
# 检查 Atlas Toolkit 环境是否存在
elif os.path.exists(atlas_toolkit_version): elif os.path.exists(atlas_toolkit_version):
# atlas default path # 如果存在,设置默认路径
self.fwk_path = "/usr/local/Ascend/ascend-toolkit/latest/fwkacllib" self.fwk_path = "/usr/local/Ascend/ascend-toolkit/latest/fwkacllib" # Framework 路径
self.op_impl_path = "/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe" self.op_impl_path = "/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe" # Operator 实现路径
self.tbe_path = self.fwk_path + "/lib64" self.tbe_path = self.fwk_path + "/lib64" # TBE 库路径
self.cce_path = self.fwk_path + "/ccec_compiler/bin" self.cce_path = self.fwk_path + "/ccec_compiler/bin" # CCE 编译器路径
self.fwk_version = atlas_toolkit_version self.fwk_version = atlas_toolkit_version # Framework 版本文件路径
self.op_path = "/usr/local/Ascend/ascend-toolkit/latest/opp" self.op_path = "/usr/local/Ascend/ascend-toolkit/latest/opp" # Operator 路径
self.aicpu_path = "/usr/local/Ascend/ascend-toolkit/latest" self.aicpu_path = "/usr/local/Ascend/ascend-toolkit/latest" # AI CPU 路径
# 检查 Hisi 环境是否存在
elif os.path.exists(hisi_fwk_version): elif os.path.exists(hisi_fwk_version):
# hisi default path # 如果存在,设置默认路径
self.fwk_path = "/usr/local/Ascend/latest/fwkacllib" self.fwk_path = "/usr/local/Ascend/latest/fwkacllib" # Framework 路径
self.op_impl_path = "/usr/local/Ascend/latest/opp/op_impl/built-in/ai_core/tbe" self.op_impl_path = "/usr/local/Ascend/latest/opp/op_impl/built-in/ai_core/tbe" # Operator 实现路径
self.tbe_path = self.fwk_path + "/lib64" self.tbe_path = self.fwk_path + "/lib64" # TBE 库路径
self.cce_path = self.fwk_path + "/ccec_compiler/bin" self.cce_path = self.fwk_path + "/ccec_compiler/bin" # CCE 编译器路径
self.fwk_version = hisi_fwk_version self.fwk_version = hisi_fwk_version # Framework 版本文件路径
self.op_path = "/usr/local/Ascend/latest/opp" self.op_path = "/usr/local/Ascend/latest/opp" # Operator 路径
self.aicpu_path = "/usr/local/Ascend/latest" self.aicpu_path = "/usr/local/Ascend/latest" # AI CPU 路径
else:
# custom or unknown environment
self.fwk_path = ""
self.op_impl_path = ""
self.tbe_path = ""
self.cce_path = ""
self.fwk_version = ""
self.op_path = ""
self.aicpu_path = ""
# env else:
# 如果以上环境都不存在,设置为空路径
self.fwk_path = "" # Framework 路径
self.op_impl_path = "" # Operator 实现路径
self.tbe_path = "" # TBE 库路径
self.cce_path = "" # CCE 编译器路径
self.fwk_version = "" # Framework 版本文件路径
self.op_path = "" # Operator 路径
self.aicpu_path = "" # AI CPU 路径
# 初始化环境变量
self.path = os.getenv("PATH") self.path = os.getenv("PATH")
self.python_path = os.getenv("PYTHONPATH") self.python_path = os.getenv("PYTHONPATH")
self.ld_lib_path = os.getenv("LD_LIBRARY_PATH") self.ld_lib_path = os.getenv("LD_LIBRARY_PATH")
self.ascend_opp_path = os.getenv("ASCEND_OPP_PATH") self.ascend_opp_path = os.getenv("ASCEND_OPP_PATH")
self.ascend_aicpu_path = os.getenv("ASCEND_AICPU_PATH") self.ascend_aicpu_path = os.getenv("ASCEND_AICPU_PATH")
# check content # 设置需要检查的路径内容
self.path_check = "/fwkacllib/ccec_compiler/bin" self.path_check = "/fwkacllib/ccec_compiler/bin"
self.python_path_check = "opp/op_impl/built-in/ai_core/tbe" self.python_path_check = "opp/op_impl/built-in/ai_core/tbe"
self.ld_lib_path_check_fwk = "/fwkacllib/lib64" self.ld_lib_path_check_fwk = "/fwkacllib/lib64"
self.ld_lib_path_check_addons = "/add-ons" self.ld_lib_path_check_addons = "/add-ons"
self.ascend_opp_path_check = "/op" self.ascend_opp_path_check = "/op"
self.v = "" self.v = ""
def check_env(self, e): def check_env(self, e):
self._check_env() self._check_env()
raise e raise e
def check_version(self): def check_version(self):
# 检查指定路径的版本文件是否存在,如果不存在则跳过版本检查
if not Path(self.fwk_version).is_file(): if not Path(self.fwk_version).is_file():
logger.warning("Using custom Ascend AI software package (Ascend Data Center Solution) path, package " logger.warning("Using custom Ascend AI software package (Ascend Data Center Solution) path, package "
"version checking is skipped, please make sure Ascend AI software package (Ascend Data " "version checking is skipped, please make sure Ascend AI software package (Ascend Data "
@ -282,36 +350,43 @@ class AscendEnvChecker(EnvChecker):
"https://www.mindspore.cn/install") "https://www.mindspore.cn/install")
return return
# 读取版本文件中的版本信息
v = self._read_version(self.fwk_version) v = self._read_version(self.fwk_version)
# 如果读取的版本不在支持的版本列表中,则记录警告信息
if v not in self.version: if v not in self.version:
v_list = str([x for x in self.version]) v_list = str([x for x in self.version])
logger.warning(f"MindSpore version {__version__} and Ascend AI software package (Ascend Data Center " logger.warning(f"MindSpore version {__version__} and Ascend AI software package (Ascend Data Center "
f"Solution)version {v} does not match, the version of software package expect one of " f"Solution)version {v} does not match, the version of software package expect one of "
f"{v_list}, please reference to the match info on: https://www.mindspore.cn/install") f"{v_list}, please reference to the match info on: https://www.mindspore.cn/install")
def check_deps_version(self): def check_deps_version(self):
""" """
te, topi, hccl wheel package version check te, topi, hccl wheel package version check
in order to update the change of 'LD_LIBRARY_PATH' env, run a sub process in order to update the change of 'LD_LIBRARY_PATH' env, run a sub process
""" """
# 构建输入参数列表包含mindspore版本和受支持的版本列表
input_args = ["--mindspore_version=" + __version__] input_args = ["--mindspore_version=" + __version__]
for v in self.version: for v in self.version:
input_args.append("--supported_version=" + v) input_args.append("--supported_version=" + v)
# 获取依赖版本检查脚本的路径
deps_version_checker = os.path.join(os.path.split(os.path.realpath(__file__))[0], deps_version_checker = os.path.join(os.path.split(os.path.realpath(__file__))[0],
"_check_deps_version.py") "_check_deps_version.py")
# 构建调用命令包括python解释器路径、脚本路径和输入参数
call_cmd = [sys.executable, deps_version_checker] + input_args call_cmd = [sys.executable, deps_version_checker] + input_args
try: try:
# 运行子进程进行版本检查设置超时时间为3秒并捕获输出
process = subprocess.run(call_cmd, timeout=3, text=True, capture_output=True, check=False) process = subprocess.run(call_cmd, timeout=3, text=True, capture_output=True, check=False)
# 如果子进程的输出不为空,则记录警告信息并进行倒计时提醒
if process.stdout.strip() != "": if process.stdout.strip() != "":
logger.warning(process.stdout.strip()) logger.warning(process.stdout.strip())
warning_countdown = 3 warning_countdown = 3
for i in range(warning_countdown, 0, -1): for i in range(warning_countdown, 0, -1):
logger.warning(f"Please pay attention to the above warning, countdown: {i}") logger.warning(f"Please pay attention to the above warning, countdown: {i}")
time.sleep(1) time.sleep(1)
# 如果版本检查超时,则记录信息并跳过
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
logger.info("Package te, topi, hccl version check timed out, skip.") logger.info("Package te, topi, hccl version check timed out, skip.")
def set_env(self): def set_env(self):
# 设置Ascend环境变量
if not self.tbe_path: if not self.tbe_path:
self._check_env() self._check_env()
return return
@ -330,24 +405,26 @@ class AscendEnvChecker(EnvChecker):
f"No such directory: {self.tbe_path}, Please check if Ascend AI software package (Ascend Data " f"No such directory: {self.tbe_path}, Please check if Ascend AI software package (Ascend Data "
"Center Solution) is installed correctly.") "Center Solution) is installed correctly.")
# check te version after set te env # 检查te版本
self.check_deps_version() self.check_deps_version()
# 设置op实现路径环境变量
if Path(self.op_impl_path).is_dir(): if Path(self.op_impl_path).is_dir():
# python path for sub process # python路径用于子进程
if os.getenv('PYTHONPATH'): if os.getenv('PYTHONPATH'):
os.environ['PYTHONPATH'] = self.op_impl_path + ":" + os.environ['PYTHONPATH'] os.environ['PYTHONPATH'] = self.op_impl_path + ":" + os.environ['PYTHONPATH']
else: else:
os.environ['PYTHONPATH'] = self.op_impl_path os.environ['PYTHONPATH'] = self.op_impl_path
# sys path for this process # sys路径用于当前进程
sys.path.append(self.op_impl_path) sys.path.append(self.op_impl_path)
os.environ['TBE_IMPL_PATH'] = self.op_impl_path os.environ['TBE_IMPL_PATH'] = self.op_impl_path
else: else:
raise EnvironmentError( raise EnvironmentError(
f"No such directory: {self.op_impl_path}, Please check if Ascend AI software package (Ascend Data " f"No such directory: {self.op_impl_path}, Please check if Ascend AI software package (Ascend Data Center "
"Center Solution) is installed correctly.") "Solution) is installed correctly.")
# 设置CCE路径环境变量
if Path(self.cce_path).is_dir(): if Path(self.cce_path).is_dir():
os.environ['PATH'] = self.cce_path + ":" + os.environ['PATH'] os.environ['PATH'] = self.cce_path + ":" + os.environ['PATH']
else: else:
@ -355,6 +432,7 @@ class AscendEnvChecker(EnvChecker):
f"No such directory: {self.cce_path}, Please check if Ascend AI software package (Ascend Data Center " f"No such directory: {self.cce_path}, Please check if Ascend AI software package (Ascend Data Center "
"Solution) is installed correctly.") "Solution) is installed correctly.")
# 设置OP路径环境变量
if self.op_path is None: if self.op_path is None:
pass pass
elif Path(self.op_path).is_dir(): elif Path(self.op_path).is_dir():
@ -364,6 +442,7 @@ class AscendEnvChecker(EnvChecker):
f"No such directory: {self.op_path}, Please check if Ascend AI software package (Ascend Data Center " f"No such directory: {self.op_path}, Please check if Ascend AI software package (Ascend Data Center "
"Solution) is installed correctly.") "Solution) is installed correctly.")
# 设置AICPU路径环境变量
if self.aicpu_path is None: if self.aicpu_path is None:
pass pass
elif Path(self.aicpu_path).is_dir(): elif Path(self.aicpu_path).is_dir():
@ -375,41 +454,51 @@ class AscendEnvChecker(EnvChecker):
def _check_env(self): def _check_env(self):
"""ascend dependence path check""" """ascend dependence path check"""
# 检查是否设置正确的PATH环境变量
if self.path is None or self.path_check not in self.path: if self.path is None or self.path_check not in self.path:
logger.warning("Can not find ccec_compiler(need by mindspore-ascend), please check if you have set env " logger.warning("Can not find ccec_compiler(need by mindspore-ascend), please check if you have set env "
"PATH, you can reference to the installation guidelines https://www.mindspore.cn/install") "PATH, you can reference to the installation guidelines https://www.mindspore.cn/install")
# 检查是否设置正确的PYTHONPATH环境变量
if self.python_path is None or self.python_path_check not in self.python_path: if self.python_path is None or self.python_path_check not in self.python_path:
logger.warning( logger.warning(
"Can not find tbe op implement(need by mindspore-ascend), please check if you have set env " "Can not find tbe op implement(need by mindspore-ascend), please check if you have set env "
"PYTHONPATH, you can reference to the installation guidelines " "PYTHONPATH, you can reference to the installation guidelines "
"https://www.mindspore.cn/install") "https://www.mindspore.cn/install")
# 检查是否设置正确的LD_LIBRARY_PATH环境变量
if self.ld_lib_path is None or not (self.ld_lib_path_check_fwk in self.ld_lib_path and if self.ld_lib_path is None or not (self.ld_lib_path_check_fwk in self.ld_lib_path and
self.ld_lib_path_check_addons in self.ld_lib_path): self.ld_lib_path_check_addons in self.ld_lib_path):
logger.warning("Can not find driver so(need by mindspore-ascend), please check if you have set env " logger.warning("Can not find driver so(need by mindspore-ascend), please check if you have set env "
"LD_LIBRARY_PATH, you can reference to the installation guidelines " "LD_LIBRARY_PATH, you can reference to the installation guidelines "
"https://www.mindspore.cn/install") "https://www.mindspore.cn/install")
# 检查是否设置正确的ASCEND_OPP_PATH环境变量
if self.ascend_opp_path is None or self.ascend_opp_path_check not in self.ascend_opp_path: if self.ascend_opp_path is None or self.ascend_opp_path_check not in self.ascend_opp_path:
logger.warning( logger.warning(
"Can not find opp path (need by mindspore-ascend), please check if you have set env ASCEND_OPP_PATH, " "Can not find opp path (need by mindspore-ascend), please check if you have set env ASCEND_OPP_PATH, "
"you can reference to the installation guidelines https://www.mindspore.cn/install") "you can reference to the installation guidelines https://www.mindspore.cn/install")
def _read_version(self, file_path): def _read_version(self, file_path):
"""get ascend version info""" """get ascend version info"""
with open(file_path, 'r') as f: with open(file_path, 'r') as f:
all_info = f.readlines() all_info = f.readlines()
# 遍历文件中的每一行
for line in all_info: for line in all_info:
# 检查行是否以 "Version=" 开头
if line.startswith("Version="): if line.startswith("Version="):
# 去除行末的换行符并按 "=" 分割, 获取版本号
full_version = line.strip().split("=")[1] full_version = line.strip().split("=")[1]
# 提取主版本号和次版本号, 并用 "." 连接
self.v = '.'.join(full_version.split('.')[0:2]) self.v = '.'.join(full_version.split('.')[0:2])
# 返回版本号
return self.v return self.v
# 如果未找到版本信息, 返回 None 或默认值
return self.v return self.v
def check_version_and_env_config(): def check_version_and_env_config():
"""check version and env config""" """检查版本和环境配置"""
# 检查包名以确定使用哪种环境检查器
if __package_name__.lower() == "mindspore-ascend": if __package_name__.lower() == "mindspore-ascend":
env_checker = AscendEnvChecker() env_checker = AscendEnvChecker()
# Note: pre-load libgomp.so to solve error like "cannot allocate memory in statis TLS block" # Note: pre-load libgomp.so to solve error like "cannot allocate memory in statis TLS block"
@ -425,19 +514,21 @@ def check_version_and_env_config():
else: else:
logger.info(f"Package version {__package_name__} does not need to check any environment variable, skipping.") logger.info(f"Package version {__package_name__} does not need to check any environment variable, skipping.")
return return
# 检查是否关闭版本检查,如果已关闭则直接返回
if os.getenv("MS_DEV_CLOSE_VERSION_CHECK") == "ON": if os.getenv("MS_DEV_CLOSE_VERSION_CHECK") == "ON":
return return
# 设置环境变量以关闭版本检查
os.environ["MS_DEV_CLOSE_VERSION_CHECK"] = "ON" os.environ["MS_DEV_CLOSE_VERSION_CHECK"] = "ON"
try: try:
# check version of ascend site or cuda # 检查 ascend site 或 cuda 的版本
env_checker.check_version() env_checker.check_version()
from .. import _c_expression # pylint: disable=unused-import from .. import _c_expression # pylint: disable=unused-import
# 设置环境
env_checker.set_env() env_checker.set_env()
except ImportError as e: except ImportError as e:
# 处理导入错误,检查环境
env_checker.check_env(e) env_checker.check_env(e)
def _set_pb_env(): def _set_pb_env():
"""Set env variable `PROTOCOL_BUFFERS` to prevent memory overflow.""" """Set env variable `PROTOCOL_BUFFERS` to prevent memory overflow."""
if os.getenv("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION") == "cpp": if os.getenv("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION") == "cpp":
@ -450,6 +541,8 @@ def _set_pb_env():
"during save or load checkpoint file.") "during save or load checkpoint file.")
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
# 检查版本和环境配置
check_version_and_env_config() check_version_and_env_config()
# 设置协议缓冲区的环境变量, 防止内存溢出
_set_pb_env() _set_pb_env()

@ -34,15 +34,18 @@ def _check_mul():
finally: finally:
pass pass
# 打印MindSpore版本信息
print(f"MindSpore version: ", ms.__version__) print(f"MindSpore version: ", ms.__version__)
# 创建两个MindSpore张量分别包含数组[1.0, 2.0, 3.0]和[4.0, 5.0, 6.0]
input_x = ms.Tensor(np.array([1.0, 2.0, 3.0]), ms.float32) input_x = ms.Tensor(np.array([1.0, 2.0, 3.0]), ms.float32)
input_y = ms.Tensor(np.array([4.0, 5.0, 6.0]), ms.float32) input_y = ms.Tensor(np.array([4.0, 5.0, 6.0]), ms.float32)
# 创建一个乘法操作对象
mul = ms.ops.Mul() mul = ms.ops.Mul()
# 执行乘法操作
mul(input_x, input_y) mul(input_x, input_y)
# 打印乘法计算结果正确MindSpore安装成功的信息
print(f"The result of multiplication calculation is correct, MindSpore has been installed successfully!") print(f"The result of multiplication calculation is correct, MindSpore has been installed successfully!")
def run_check(): def run_check():
""" """
Provide a convenient API to check if the installation is successful or failed. Provide a convenient API to check if the installation is successful or failed.
@ -55,10 +58,13 @@ def run_check():
The result of multiplication calculation is correct, MindSpore has been installed successfully! The result of multiplication calculation is correct, MindSpore has been installed successfully!
""" """
try: try:
# 尝试执行检查乘法操作的函数
_check_mul() _check_mul()
# pylint: disable=broad-except # pylint: disable=broad-except
# 捕获所有异常并打印错误信息
except Exception as e: except Exception as e:
print("MindSpore running check failed.") print("MindSpore running check failed.")
print(str(e)) print(str(e))
finally: finally:
pass # 无论是否发生异常,都会执行此部分代码
pass # 执行乘法检查的函数,并处理可能的异常情况。如果检查失败,打印错误信息。

@ -22,185 +22,171 @@ from ..common import dtype as mstype
from .utils_const import _type_convert, _raise_value_error, _callable_const, _super_check, pack from .utils_const import _type_convert, _raise_value_error, _callable_const, _super_check, pack
from ..ops.composite import GradOperation from ..ops.composite import GradOperation
grad = GradOperation(get_all=False, get_by_list=False, sens_param=False) grad = GradOperation(get_all=False, get_by_list=False, sens_param=False) # 定义一个求梯度操作,设置为只求第一个参数的梯度,不通过列表获取,且不使用敏感参数
_eps_net = ops.Eps() _eps_net = ops.Eps() # 定义一个计算数值精度的操作
def _convert_64_to_32(tensor): # 定义一个函数将输入的tensor从float64或int64类型转换为float32或int32类型
def _convert_64_to_32(tensor):
"""Convert Tensor with float64/int64 types to float32/int32.""" """Convert Tensor with float64/int64 types to float32/int32."""
if tensor.dtype == mstype.float64: if tensor.dtype == mstype.float64: # 如果tensor的数据类型是float64
return tensor.astype("float32") return tensor.astype("float32") # 将其转换为float32类型
if tensor.dtype == mstype.int64: if tensor.dtype == mstype.int64: # 如果tensor的数据类型是int64
return tensor.astype("int32") return tensor.astype("int32") # 将其转换为int32类型
return tensor return tensor # 如果不是以上两种类型则直接返回原tensor
def _to_tensor(*args, dtype=None): def _to_tensor(*args, dtype=None): # 定义一个函数将输入的参数转换为tensor
"""Returns each input as Tensor""" """Returns each input as Tensor"""
res = () res = () # 初始化一个空元组用于存储结果
for arg in args: for arg in args: # 遍历每一个输入参数
if isinstance(arg, (int, float, bool, list, tuple)): if isinstance(arg, (int, float, bool, list, tuple)): # 如果参数是整数、浮点数、布尔值、列表或元组
arg = _type_convert(Tensor, arg) arg = _type_convert(Tensor, arg) # 将其转换为Tensor类型
if dtype is None: if dtype is None: # 如果没有指定dtype
arg = _convert_64_to_32(arg) arg = _convert_64_to_32(arg) # 调用_convert_64_to_32函数进行类型转换
else: else: # 如果指定了dtype
arg = arg.astype(dtype) arg = arg.astype(dtype) # 将tensor转换为指定的dtype
elif not isinstance(arg, Tensor): elif not isinstance(arg, Tensor): # 如果参数不是Tensor类型
_raise_value_error("Expect input to be array like.") _raise_value_error("Expect input to be array like.") # 抛出错误,提示输入应为数组形式
res += (arg,) res += (arg,) # 将转换后的tensor添加到结果元组中
if len(res) == 1: if len(res) == 1: # 如果结果元组中只有一个元素
return res[0] return res[0] # 直接返回该元素
return res return res # 否则返回整个元组
def _to_scalar(arr): # 定义一个函数将输入的Tensor或ndarray转换为标量值
def _to_scalar(arr):
"""Convert a scalar Tensor or ndarray to a scalar.""" """Convert a scalar Tensor or ndarray to a scalar."""
if isinstance(arr, (int, float, bool)): if isinstance(arr, (int, float, bool)): # 如果输入参数是整数、浮点数或布尔值
return arr return arr # 直接返回该参数
if isinstance(arr, Tensor): if isinstance(arr, Tensor): # 如果输入参数是Tensor类型
if arr.shape: if arr.shape: # 如果tensor的形状不是空的即不是标量
return arr return arr # 返回整个tensor
return arr.asnumpy().item() return arr.asnumpy().item() # 如果是标量将其转换为numpy数组并返回标量值
raise ValueError("{} are not supported.".format(type(arr))) raise ValueError("{} are not supported.".format(type(arr))) # 如果输入参数不是以上两种类型,抛出错误,提示不支持该类型
def _eps(x): # 定义一个函数计算输入tensor的数值精度
def _eps(x): return _eps_net(x[(0,) * x.ndim]) # 使用_ops.Eps操作计算数值精度x[(0,) * x.ndim]确保输入的是一个标量
return _eps_net(x[(0,) * x.ndim])
def _safe_normalize(x, threshold=None): # 定义一个函数对输入的tensor进行归一化如果归一化结果非常小则设置为零
def _safe_normalize(x, threshold=None):
"""Normalize method that cast very small results to zero.""" """Normalize method that cast very small results to zero."""
x_sum2 = F.reduce_sum(F.pows(x, 2.0)) x_sum2 = F.reduce_sum(F.pows(x, 2.0)) # 计算tensor元素平方的和
norm = F.pows(x_sum2, 1. / 2.0) norm = F.pows(x_sum2, 1. / 2.0) # 计算上述和的平方根得到norm
if threshold is None: if threshold is None: # 如果没有指定threshold
if x.dtype in (mstype.float32, mstype.float64): if x.dtype in (mstype.float32, mstype.float64): # 如果tensor的dtype是float32或float64
# pick the first element of x to get the eps # pick the first element of x to get the eps # 获取eps来作为threshold
threshold = _eps(x) threshold = _eps(x)
else: else: # 如果tensor的dtype不是float32或float64
threshold = 0 threshold = 0 # 设置threshold为0
use_norm = greater(norm, threshold) use_norm = greater(norm, threshold) # 比较norm和threshold得到一个布尔mask
x_norm = x / norm x_norm = x / norm # 使用norm对tensor进行归一化
normalized_x = where(use_norm, x_norm, zeros_like(x)) normalized_x = where(use_norm, x_norm, zeros_like(x)) # 如果norm大于threshold则使用归一化后的tensor否则使用零
norm = where(use_norm, norm, zeros_like(norm)) norm = where(use_norm, norm, zeros_like(norm)) # 如果norm大于threshold则保留norm否则使用零
return normalized_x, norm return normalized_x, norm # 返回归一化后的tensor及其对应的norm
def sparse_dot(a, b): # 定义一个函数计算稀疏矩阵CSRTensor与向量generic Tensor的点积
def sparse_dot(a, b):
"""Returns the dot product of CSRTensor and generic Tensor(vector).""" """Returns the dot product of CSRTensor and generic Tensor(vector)."""
b_aligned = F.reshape(b, (b.shape[0], -1)) b_aligned = F.reshape(b, (b.shape[0], -1)) # 将向量b重塑为(b.shape[0], -1)的形状,使其可以与稀疏矩阵相乘
res = F.csr_mv(a, b_aligned) res = F.csr_mv(a, b_aligned) # 使用csr_mv操作计算稀疏矩阵a与向量b_aligned的点积
res = F.reshape(res, a.shape[:-1] + b.shape[1:]) res = F.reshape(res, a.shape[:-1] + b.shape[1:]) # 将计算结果重新塑形为a.shape[:-1] + b.shape[1:]的形状
return res return res # 返回结果
def _normalize_matvec(f): def _normalize_matvec(f): # 定义一个函数,对输入的矩阵或向量进行归一化处理
"""Normalize an argument for computing matrix-vector products.""" """Normalize an argument for computing matrix-vector products."""
if isinstance(f, Tensor): if isinstance(f, Tensor): # 如果输入参数是Tensor类型
return F.partial(dot, f) return F.partial(dot, f) # 返回一个带有矩阵参数f的dot函数的部分应用
if isinstance(f, CSRTensor):
return F.partial(sparse_dot, f)
return f
if isinstance(f, CSRTensor): # 如果输入参数是CSRTensor类型
return F.partial(sparse_dot, f) # 返回一个带有稀疏矩阵参数f的sparse_dot函数的部分应用
def _norm(x, ord_=None): return f # 如果输入参数不是上述两种类型,则直接返回原参数
if ord_ == mnp.inf:
res = mnp.max(mnp.abs(x))
else:
res = mnp.sqrt(mnp.sum(x ** 2))
return res
def _norm(x, ord_=None): # 定义一个函数计算输入tensor的范数
if ord_ == mnp.inf: # 如果ord_为无穷大实际为最大值
res = mnp.max(mnp.abs(x)) # 返回tensor绝对值的最大值
else: # 如果ord_不是无穷大
res = mnp.sqrt(mnp.sum(x ** 2)) # 返回tensor元素平方和的平方根即L2范数
return res # 返回结果
def _nd_transpose(a): def _nd_transpose(a): # 定义一个函数对输入的tensor进行转置最后一个维度与倒数第二个维度互换
dims = a.ndim dims = a.ndim # 获取tensor的维度数
if dims < 2: if dims < 2: # 如果tensor的维度小于2
_raise_value_error("to do _nd_transpose for input a's ndim is not greater or equal to 2d, which is invalid.") _raise_value_error("to do _nd_transpose for input a's ndim is not greater or equal to 2d, which is invalid.") # 抛出错误提示输入tensor的维度应大于等于2
axes = ops.make_range(0, dims) axes = ops.make_range(0, dims) # 生成一个从0到tensor维度数的序列
axes = axes[:-2] + (axes[-1],) + (axes[-2],) axes = axes[:-2] + (axes[-1],) + (axes[-2],) # 将序列中的倒数第二个和最后一个元素互换位置
return ops.transpose(a, axes) return ops.transpose(a, axes) # 使用transpose操作对tensor进行转置
def _value_check(func_name, arg1, arg2, arg_name='', attr_name='', op="in", fmt="attr", msg=None): # 定义一个函数,用于检查输入参数的值是否符合预期
return _super_check(pack(arg1, arg2), (func_name, arg_name, attr_name), op, fmt, msg, True) # 调用_super_check函数进行检查
def _value_check(func_name, arg1, arg2, arg_name='', attr_name='', op="in", fmt="attr", msg=None): def _type_check(func_name, arg1, arg2, arg_name='', op="isinstance", fmt="type", msg=None): # 定义一个函数,用于检查输入参数的类型是否符合预期
return _super_check(pack(arg1, arg2), (func_name, arg_name, attr_name), op, fmt, msg, True) return _super_check(pack(arg1, arg2), (func_name, arg_name), op, fmt, msg, False) # 调用_super_check函数进行检查
def _mstype_check(func_name, arg, arg_mstype, arg_name='a'): # 定义一个函数用于检查输入参数的mstype是否符合预期
def _type_check(func_name, arg1, arg2, arg_name='', op="isinstance", fmt="type", msg=None): return _super_check((F.typeof(arg), arg_mstype), pack(arg, arg_mstype, func_name, arg_name), "isinstance", "mstype", # 调用_super_check函数进行检查
return _super_check(pack(arg1, arg2), (func_name, arg_name), op, fmt, msg, False)
def _mstype_check(func_name, arg, arg_mstype, arg_name='a'):
return _super_check((F.typeof(arg), arg_mstype), pack(arg, arg_mstype, func_name, arg_name), "isinstance", "mstype",
None, False) None, False)
def _dtype_check(func_name, arg, arg_dtype, arg_name='a'): # 定义一个函数,用于检查输入参数的数据类型是否符合预期
return _super_check((F.dtype(arg), arg_dtype), (func_name, arg_name, "data type"), "in", "attr", # 调用_super_check函数进行检查
None, False)
def _dtype_check(func_name, arg, arg_dtype, arg_name='a'): def _square_check(func_name, arg, arg_name='a'): # 定义一个函数,用于检查输入参数是否为方阵
return _super_check((F.dtype(arg), arg_dtype), (func_name, arg_name, "data type"), "in", "attr", None, False) arg_shape = arg.shape # 获取输入参数的形状
_super_check((len(arg_shape), 2), (func_name, arg_name, 'dimension'), '==', 'attr', None, True) # 检查输入参数的维度是否为2
_super_check(arg_shape, (func_name, arg_name), '==', 'square', None, True) # 检查输入参数的形状是否为方阵
def _square_check(func_name, arg, arg_name='a'): return arg # 返回检查后的参数
arg_shape = arg.shape
_super_check((len(arg_shape), 2), (func_name, arg_name, 'dimension'), '==', 'attr', None, True) def _solve_check(func_name, arg1, arg2, arg1_name='a', arg2_name='b', sparse=False): # 定义一个函数,用于在求解线性方程组时检查输入参数
_super_check(arg_shape, (func_name, arg_name), '==', 'square', None, True) arg1_shape, arg1_dtype = arg1.shape, F.dtype(arg1) # 获取第一个参数的形状和数据类型
return arg arg2_shape, arg2_dtype = arg2.shape, F.dtype(arg2) # 获取第二个参数的形状和数据类型
_square_check(func_name, arg1, arg1_name) # 检查第一个参数是否为方阵
_super_check((len(arg2_shape), (1, 2)), (func_name, arg2_name, 'dimension'), 'in', 'attr', None, True) # 检查第二个参数的维度是否为1或2
def _solve_check(func_name, arg1, arg2, arg1_name='a', arg2_name='b', sparse=False): _super_check((arg1_shape, arg2_shape), (func_name, arg1_name, arg2_name, sparse), 'solve', 'solve', None, True) # 检查第一个参数和第二个参数的形状是否可以用于求解线性方程组
arg1_shape, arg1_dtype = arg1.shape, F.dtype(arg1) _super_check((arg1_dtype, arg2_dtype), (func_name, arg1_name, arg2_name, 'data type'), '==', 'match', None, False) # 检查第一个参数和第二个参数的数据类型是否匹配
arg2_shape, arg2_dtype = arg2.shape, F.dtype(arg2) return arg1, arg2 # 返回检查后的两个参数
_square_check(func_name, arg1, arg1_name)
_super_check((len(arg2_shape), (1, 2)), (func_name, arg2_name, 'dimension'), 'in', 'attr', None, True) def _sparse_check(func_name, a, m, b, x0): # 定义一个函数用于在稀疏求解器如cg, bicgstab和gmres中检查输入参数
_super_check((arg1_shape, arg2_shape), (func_name, arg1_name, arg2_name, sparse), 'solve', 'solve', None, True)
_super_check((arg1_dtype, arg2_dtype), (func_name, arg1_name, arg2_name, 'data type'), '==', 'match', None, False)
return arg1, arg2
def _sparse_check(func_name, a, m, b, x0):
"""Used for cg, bicgstab and gmres method.""" """Used for cg, bicgstab and gmres method."""
def _check_right(arg, arg_name): def _check_right(arg, arg_name): # 定义一个内部函数用于检查右侧参数b或x0
if arg is None: if arg is None: # 如果参数为None
return mnp.zeros_like(b) # x0 same as b return mnp.zeros_like(b) # x0 same as b # 返回与b形状相同元素为零的tensor
# Type # Type
_mstype_check(func_name, arg, mstype.tensor_type, arg_name) _mstype_check(func_name, arg, mstype.tensor_type, arg_name) # 检查参数的mstype是否为tensor_type
# DType # DType
_dtype_check(func_name, arg, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], arg_name) _dtype_check(func_name, arg, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], arg_name) # 检查参数的数据类型是否在指定的类型列表中
# Shape # Shape
if (arg.ndim != 1 and arg.ndim != 2) or (arg.ndim == 2 and arg.shape[1] != 1): if (arg.ndim != 1 and arg.ndim != 2) or (arg.ndim == 2 and arg.shape[1] != 1): # 检查参数的形状是否为(N,)或(N, 1)
_raise_value_error("For: '", func_name, "', the shape of '", arg_name, _raise_value_error("For: '", func_name, "', the shape of '", arg_name, # 如果不满足条件,抛出错误
"' should be like (N,) or (N, 1), bug got ", arg.shape, ".") "' should be like (N,) or (N, 1), bug got ", arg.shape, ".")
return arg return arg # 返回检查后的参数
b = _check_right(b, 'b') b = _check_right(b, 'b') # 检查参数b
x0 = _check_right(x0, 'x0') x0 = _check_right(x0, 'x0') # 检查参数x0
def _check_left(arg, arg_name): def _check_left(arg, arg_name): # 定义一个内部函数用于检查左侧参数a或m
if arg is None: if arg is None: # 如果参数为None
return lambda x: x # identity function return lambda x: x # identity function # 返回一个恒等函数
# Type # Type
_mstype_check(func_name, arg, [mstype.function_type, mstype.tensor_type, mstype.csr_tensor_type], arg_name) _mstype_check(func_name, arg, [mstype.function_type, mstype.tensor_type, mstype.csr_tensor_type], arg_name) # 检查参数的mstype是否为function_type, tensor_type或csr_tensor_type
if _callable_const(F.typeof(arg)): if _callable_const(F.typeof(arg)): # 如果参数是一个可调用的常量(即函数)
return arg return arg # 返回该参数
# DType # DType
if isinstance(arg, CSRTensor): if isinstance(arg, CSRTensor): # 如果参数是CSRTensor类型
_dtype_check(func_name, arg.indptr, [mstype.int32], arg_name) _dtype_check(func_name, arg.indptr, [mstype.int32], arg_name) # 检查CSRTensor的indptr数据类型是否为int32
_dtype_check(func_name, arg.indices, [mstype.int32], arg_name) _dtype_check(func_name, arg.indices, [mstype.int32], arg_name) # 检查CSRTensor的indices数据类型是否为int32
_dtype_check(func_name, arg.values, [mstype.float32], arg_name) _dtype_check(func_name, arg.values, [mstype.float32], arg_name) # 检查CSRTensor的values数据类型是否为float32
else: else: # 如果参数不是CSRTensor类型
_dtype_check(func_name, arg, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], arg_name) _dtype_check(func_name, arg, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], arg_name) # 检查参数的数据类型是否在指定的类型列表中
# Shape # Shape
_solve_check(func_name, arg, b, arg_name, 'b', True) _solve_check(func_name, arg, b, arg_name, 'b', True) # 检查参数a和b的形状是否可以用于求解线性方程组
_solve_check(func_name, arg, x0, arg_name, 'x0', True) _solve_check(func_name, arg, x0, arg_name, 'x0', True) # 检查参数a和x0的形状是否可以用于求解线性方程组
if isinstance(arg, Tensor) and F.dtype(arg) in (mstype.int32, mstype.int64): if isinstance(arg, Tensor) and F.dtype(arg) in (mstype.int32, mstype.int64): # 如果参数是Tensor类型且数据类型为int32或int64
arg = F.cast(arg, mstype.float64) arg = F.cast(arg, mstype.float64) # 将其转换为float64类型
return arg return arg # 返回检查后的参数
a = _check_left(a, 'A') a = _check_left(a, 'A') # 检查参数a
m = _check_left(m, 'M') m = _check_left(m, 'M') # 检查参数m
b = b.flatten() b = b.flatten() # 将参数b展平为一维的tensor
x0 = x0.flatten() x0 = x0.flatten() # 将参数x0展平为一维的tensor
if F.dtype(b) in (mstype.int32, mstype.int64): if F.dtype(b) in (mstype.int32, mstype.int64): # 如果参数b的数据类型为int32或int64
b = F.cast(b, mstype.float64) b = F.cast(b, mstype.float64) # 将其转换为float64类型
x0 = F.cast(x0, mstype.float64) x0 = F.cast(x0, mstype.float64) # 将其转换为float64类型
return a, m, b, x0 return a, m, b, x0 # 返回检查并转换后的参数

@ -1,3 +1,6 @@
这段代码是一个Python类的实现名为`MindData`它是一个用于模拟MindSpore框架中数据集处理的桩Stub下面是对这段代码的逐行注释
```python
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -12,77 +15,94 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
'''Remove after MindData merge to MindSpore ''' '''Remove after MindData merge to MindSpore '''
import numpy as np import numpy as np
from mindspore import Tensor from mindspore import Tensor
class MindData: class MindData:
""" Stub for MindData """ """ Stub for MindData """
# 构造函数初始化MindData类的实例
def __init__(self, size=1, batch_size=None, repeat_count=1, def __init__(self, size=1, batch_size=None, repeat_count=1,
np_types=None, output_shapes=None, input_indexs=()): np_types=None, output_shapes=None, input_indexs=()):
self._size = size self._size = size # 数据集的大小
self._batch_size = batch_size self._batch_size = batch_size # 批处理大小
self._repeat_count = repeat_count self._repeat_count = repeat_count # 重复次数
self._np_types = np_types self._np_types = np_types # NumPy数据类型
self._output_shapes = output_shapes self._output_shapes = output_shapes # 输出形状
self._input_indexs = input_indexs self._input_indexs = input_indexs # 输入索引
self._iter_num = 0 self._iter_num = 0 # 迭代次数计数器
self.dynamic_setting = [False, None] self.dynamic_setting = [False, None] # 动态设置标志和值
# 获取数据集大小
def get_dataset_size(self): def get_dataset_size(self):
return self._size return self._size
# 获取重复次数
def get_repeat_count(self): def get_repeat_count(self):
return self._repeat_count return self._repeat_count
# 获取批处理大小
def get_batch_size(self): def get_batch_size(self):
return self._batch_size return self._batch_size
# 获取输出数据类型
def output_types(self): def output_types(self):
return self._np_types return self._np_types
# 获取输出形状
def output_shapes(self): def output_shapes(self):
return self._output_shapes return self._output_shapes
# 输入索引属性
@property @property
def input_indexs(self): def input_indexs(self):
return self._input_indexs return self._input_indexs
# 设备队列设置
def device_que(self, send_epoch_end=True, create_data_info_queue=False): def device_que(self, send_epoch_end=True, create_data_info_queue=False):
self.queue_name = '6ba41974-209e-11ea-88b0-a24efeb2c736' self.queue_name = '6ba41974-209e-11ea-88b0-a24efeb2c736'
self.send_epoch_end = send_epoch_end self.send_epoch_end = send_epoch_end
return self return self
# 创建元组迭代器
def create_tuple_iterator(self, num_epochs=-1, do_copy=True): def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
return self.__iter__() return self.__iter__()
# 发送数据
def send(self, num_epochs=-1): def send(self, num_epochs=-1):
pass pass
# 停止发送数据
def stop_send(self): def stop_send(self):
pass pass
# 释放资源
def release(self): def release(self):
pass pass
# 继续发送数据
def continue_send(self): def continue_send(self):
pass pass
# 获取数据信息
def get_data_info(self): def get_data_info(self):
pass pass
# 动态最小最大形状
def dynamic_min_max_shapes(self): def dynamic_min_max_shapes(self):
pass pass
# 获取长度
def __len__(self): def __len__(self):
return self._size return self._size
# 迭代器
def __iter__(self): def __iter__(self):
return self return self
# 获取下一个元素
def __next__(self): def __next__(self):
if self._size < self._iter_num: if self._size < self._iter_num:
raise StopIteration raise StopIteration
@ -90,11 +110,13 @@ class MindData:
next_value = [] next_value = []
for shape, typ in zip(self._output_shapes, self._np_types): for shape, typ in zip(self._output_shapes, self._np_types):
next_value.append(Tensor(np.ndarray(shape, typ))) next_value.append(Tensor(np.ndarray(shape, typ)))
return tuple(next_value) return tuple(next_value)
# 下一个元素
def next(self): def next(self):
return self.__next__() return self.__next__()
# 重置迭代器
def reset(self): def reset(self):
self._iter_num = 0 self._iter_num = 0

@ -17,5 +17,7 @@ import mindspore.context as context
def setup_module(module): def setup_module(module):
# 禁用pylint对未使用参数的警告
# pylint: disable=unused-argument # pylint: disable=unused-argument
# 设置上下文模式为图模式
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)

@ -26,25 +26,28 @@ from .utils import keyword
def mindspore_test(verification_pipeline): def mindspore_test(verification_pipeline):
""" """
Run verification pipeline. 运行验证流水线
Args: Args:
verification_pipeline (list): Pipeline designed to do verification. verification_pipeline (list): 设计的验证流水线
Returns: Returns:
""" """
def decorate(get_verification_set): def decorate(get_verification_set):
# 获取验证集
verification_set = get_verification_set() verification_set = get_verification_set()
facade_components = [] # 初始化组件列表
data_components = [] facade_components = [] # 外观组件列表
builder_components = [] data_components = [] # 数据组件列表
executor_components = [] builder_components = [] # 构建组件列表
verifier_components = [] executor_components = [] # 执行组件列表
fi_policy_components = [] verifier_components = [] # 验证组件列表
er_policy_components = [] fi_policy_components = [] # FI策略组件列表
er_policy_components = [] # ER策略组件列表
for component in verification_pipeline: for component in verification_pipeline:
# 判断组件类型并添加到对应列表
if issubclass(component, IFacadeComponent): if issubclass(component, IFacadeComponent):
facade_components.append(component) facade_components.append(component)
elif issubclass(component, IDataComponent): elif issubclass(component, IDataComponent):
@ -62,68 +65,90 @@ def mindspore_test(verification_pipeline):
else: else:
raise Exception(f'{component} is not an instance of {IComponent}') raise Exception(f'{component} is not an instance of {IComponent}')
# 依次处理外观组件
for component in facade_components: for component in facade_components:
fc = component(verification_set) fc = component(verification_set)
verification_set = fc() verification_set = fc()
# 初始化输入列表
inputs = [] inputs = []
# 依次处理数据组件
for component in data_components: for component in data_components:
dc = component(verification_set) dc = component(verification_set)
item = dc() item = dc()
inputs.extend(item) inputs.extend(item)
# 如果输入列表为空,记录警告
if not inputs: if not inputs:
logging.warning("Inputs set is empty.") logging.warning("Inputs set is empty.")
# 初始化函数列表
functions = [] functions = []
# 依次处理构建组件
for component in builder_components: for component in builder_components:
bc = component(verification_set) bc = component(verification_set)
f = bc() f = bc()
functions.extend(f) functions.extend(f)
# 如果函数列表为空,记录警告
if not functions: if not functions:
logging.warning("Function set is empty.") logging.warning("Function set is empty.")
# 初始化函数输入对列表
fis = [] fis = []
# 依次处理FI策略组件
for component in fi_policy_components: for component in fi_policy_components:
fipc = component(verification_set, functions, inputs) fipc = component(verification_set, functions, inputs)
result = fipc() result = fipc()
fis.extend(result) fis.extend(result)
# 如果函数输入对列表为空,记录警告
if not fis: if not fis:
logging.warning("Function inputs pair set is empty.") logging.warning("Function inputs pair set is empty.")
# 定义测试用例函数
def test_case(args): def test_case(args):
# 提取系统待测和输入参数
sut, inputs = args sut, inputs = args
# 初始化结果列表
results = [] results = []
# 依次处理执行组件
for component in executor_components: for component in executor_components:
ec = component(verification_set, sut, inputs) ec = component(verification_set, sut, inputs)
result = ec() result = ec()
results.append(result) results.append(result)
# 如果结果列表为空,记录警告
if not results: if not results:
logging.warning("Result set is empty.") logging.warning("Result set is empty.")
# 初始化期望实际结果对列表
expect_actuals = [] expect_actuals = []
# 依次处理ER策略组件
for component in er_policy_components: for component in er_policy_components:
erpc = component(verification_set, verification_set['expect'], results) erpc = component(verification_set, verification_set['expect'], results)
result = erpc() result = erpc()
expect_actuals.extend(result) expect_actuals.extend(result)
# 如果期望实际结果对列表为空,记录警告
if not expect_actuals: if not expect_actuals:
logging.warning("Expect Result pair set is empty.") logging.warning("Expect Result pair set is empty.")
# 依次处理验证组件
for ea in expect_actuals: for ea in expect_actuals:
for component in verifier_components: for component in verifier_components:
vc = component(verification_set, *ea) vc = component(verification_set, *ea)
vc() vc()
# 定义测试用例名称生成函数
def get_tc_name(f, inputs): def get_tc_name(f, inputs):
# 拼接测试用例ID和组名
tc_id = f[keyword.id] + '-' + inputs[keyword.id] tc_id = f[keyword.id] + '-' + inputs[keyword.id]
group = f[keyword.group] + '-' + inputs[keyword.group] group = f[keyword.group] + '-' + inputs[keyword.group]
return 'Group_' + group + '-' + 'Id_' + tc_id return 'Group_' + group + '-' + 'Id_' + tc_id
# 如果存在函数输入对,则生成测试用例
if fis: if fis:
m = pytest.mark.parametrize('args', fis, ids=lambda fi: get_tc_name(*fi))(test_case) m = pytest.mark.parametrize('args', fis, ids=lambda fi: get_tc_name(*fi))(test_case)
m.__orig__ = get_verification_set m.__orig__ = get_verification_set

Loading…
Cancel
Save