Merge branch 'main' of https://bdgit.educoder.net/pstluih63/mindspore_group_2 into branch-donghaoqian

branch-donghaoqian
donghaoqian 7 months ago
commit 74aff73449

@ -20,8 +20,11 @@ function(find_submodule_lib module name path)
)
endfunction()
# protobuf
function(ge_protobuf_generate c_var h_var)
# Call common_protobuf_generate to generate the protobuf sources
common_protobuf_generate(${CMAKE_BINARY_DIR}/proto/ge/proto ${c_var} ${h_var} ${ARGN})
# Propagate c_var and h_var to the parent scope
set(${c_var} ${${c_var}} PARENT_SCOPE)
set(${h_var} ${${h_var}} PARENT_SCOPE)
endfunction()

@ -22,31 +22,100 @@ from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
def ScalarAdd(x, y):
"""
实现标量加法运算
Args:
x (float): 第一个加数
y (float): 第二个加数
Returns:
float: x y 的和
"""
"""Implement `scalar_add`."""
return x + y
def ScalarMul(x, y):
"""
标量乘法函数
Args:
x (float): 第一个标量
y (float): 第二个标量
Returns:
float: 两个标量的乘积
"""
"""Implement `scalar_mul`."""
return x * y
def ScalarMod(x, y):
"""
对两个数进行模运算
Args:
x (int or float): 被模数
y (int or float): 模数
Returns:
int or float: x y 取模的结果
"""
"""Implement `scalar_mul`."""
return x % y
def ScalarSub(x, y):
"""
实现标量减法运算
Args:
x (float): 第一个标量值
y (float): 第二个标量值
Returns:
float: x y 的差值
"""
"""Implement `scalar_sub`."""
return x - y
def ScalarUsub(x):
"""
对给定的标量 x 进行取反操作
Args:
x (float or int): 需要取反的标量
Returns:
float or int: 取反后的标量
"""
"""Implement `scalar_usub`."""
return -x
def TupleGetItem(x, index):
"""
从给定对象中获取指定索引处的元素
Args:
x (Union[Tensor, dict]): 输入对象可以是Tensor类型或字典类型
index (int): 要获取的元素的索引
Returns:
Union[Tensor, Any]: 如果输入是Tensor类型则返回Tensor类型的元素
如果输入是字典类型则返回字典中对应索引的值
否则返回输入对象中对应索引的元素
Raises:
IndexError: 如果索引超出范围
"""
"""Implement `tuple_getitem`."""
if isinstance(x, Tensor):
x = x.asnumpy()
@ -64,36 +133,111 @@ def TupleGetItem(x, index):
def scalar_gt(x, y):
"""
判断两个标量值x和y的大小关系
Args:
x (float or int): 第一个标量值
y (float or int): 第二个标量值
Returns:
bool: 如果x大于y则返回True否则返回False
"""
"""Implement `scalar_gt`."""
return x > y
def scalar_ne(x, y):
"""
比较两个标量值是否不相等
Args:
x (float): 第一个标量值
y (float): 第二个标量值
Returns:
bool: 如果 x 不等于 y则返回 True否则返回 False
"""
"""Implement `scalar_ne`."""
return x != y
def scalar_eq(x, y):
"""
判断两个标量值是否相等
Args:
x (Any): 第一个标量值
y (Any): 第二个标量值
Returns:
bool: 如果 x y 相等返回 True否则返回 False
"""
"""Implement `scalar_eq`."""
return x == y
def scalar_le(x, y):
"""
判断标量 x 是否小于等于标量 y
Args:
x (float): 第一个标量值
y (float): 第二个标量值
Returns:
bool: 如果 x 小于等于 y则返回 True否则返回 False
"""
"""Implement `scalar_le`."""
return x <= y
def scalar_lt(x, y):
"""
判断两个标量值的大小关系
Args:
x (float): 第一个标量值
y (float): 第二个标量值
Returns:
bool: 如果 x 小于 y则返回 True否则返回 False
"""
"""Implement `scalar_lt`."""
return x < y
def identity(x):
"""
返回输入参数本身
Args:
x: 任何类型的输入参数
Returns:
返回输入参数本身
"""
"""Implement `identity`."""
return x
def zeros_like_tensor(x):
"""
根据给定的张量x创建一个形状相同但所有元素为零的新张量
Args:
x (Tensor): 输入的张量用于确定新张量的形状
Returns:
Tensor: 一个与输入张量x形状相同但所有元素为零的新张量
"""
"""Implement `zeros_like_tensor`."""
x = x.asnumpy()
value = Tensor(np.zeros(x.shape).astype(np.float32))
@ -101,61 +245,201 @@ def zeros_like_tensor(x):
def Switch(c, x, y):
"""
实现 `switch` 功能
Args:
c (bool): 条件值如果为 True则返回 x否则返回 y
x (Any): 条件为 True 时返回的值
y (Any): 条件为 False 时返回的值
Returns:
Any: 根据条件 c 返回 x y
"""
"""Implement `switch`."""
return x if c else y
def list_getitem(data, item):
"""
从列表中获取指定索引处的元素
Args:
data (list): 待查询的列表
item (int): 要获取的元素的索引
Returns:
返回列表中索引为item的元素
Raises:
IndexError: 如果索引超出列表范围
"""
"""Implement `list_getitem`."""
return data[item]
def bool_not(x):
"""
对输入值取反
Args:
x (bool): 要取反的布尔值
Returns:
bool: x 的逻辑非值
"""
"""Implement `bool_not`."""
return not x
def bool_and(x, y):
"""
对两个布尔值进行逻辑与操作
Args:
x (bool): 第一个布尔值
y (bool): 第二个布尔值
Returns:
bool: 返回两个布尔值进行逻辑与操作后的结果
"""
"""Implement `bool_and`."""
return x and y
def bool_or(x, y):
"""
实现布尔或运算
Args:
x (bool): 第一个布尔值
y (bool): 第二个布尔值
Returns:
bool: 如果 x y True则返回 True否则返回 False
"""
"""Implement `bool_or`."""
return x or y
def make_list(*xs):
"""
将不定数量的参数转换为一个列表
Args:
*xs: 不定数量的参数可以是任意类型
Returns:
list: 包含所有传入参数的列表
Examples:
>>> make_list(1, 2, 3)
[1, 2, 3]
>>> make_list('a', 'b', 'c')
['a', 'b', 'c']
>>> make_list(1, 'a', [1, 2, 3])
[1, 'a', [1, 2, 3]]
"""
"""Implement `make_list`."""
return list(xs)
def list_len(x):
"""
计算列表的长度
Args:
x (list): 需要计算长度的列表
Returns:
int: 列表的长度
"""
"""Implement `list_len`."""
return len(x)
def Depend(value, expr):
"""
依赖函数根据给定的表达式返回相应的值
Args:
value (Any): 要返回的值
expr (Any): 表达式该参数在当前实现中被忽略
Returns:
Any: 返回与输入相同的值
"""
"""Implement `Depend`."""
return value
def UpdateState(monad, *exprs):
"""
更新状态
Args:
monad (Monad): 一个符合 Monad 类型的对象
*exprs (Any): 需要更新的表达式可以为任意类型
Returns:
Monad: 更新后的 Monad 对象
"""
"""Implement `UpdateState`."""
return monad
def Load(value, u=None):
"""
加载指定的值
Args:
value (Any): 要加载的值
u (Optional[Any], optional): 可选参数默认为None当前版本未使用保留以便未来扩展
Returns:
Any: 返回加载的值
"""
"""Implement `Load`."""
return value
# only used in PyNative mode
def make_ref(key, value, ref):
"""
创建一个引用对象
Args:
key (str): 键名用于标识引用的对象
value (Any): 引用对象的值
ref (Any): 引用对象可以为任意类型
Returns:
Any: 返回引用的值
"""
return value
def scalar_cast(x, t):
"""
将标量值x转换为指定的NumPy数据类型t
Args:
x (float, int): 要转换的标量值
t (np.dtype): 目标NumPy数据类型
Returns:
Any: 转换后的标量值类型为t
"""
"""Implement scalar_cast."""
np_type = dtype_to_nptype(t)
value = np_type(x)
@ -164,16 +448,46 @@ def scalar_cast(x, t):
def typeof(x):
"""
实现 typeof 函数
Args:
x (Any): 要获取类型的对象
Returns:
str: 返回传入对象的Python类型名称
"""
"""Implement typeof."""
return get_py_obj_dtype(x)
def tuple_to_array(x):
"""
将元组转换为数组
Args:
x (tuple): 待转换的元组
Returns:
Tensor: 转换后的数组
"""
"""Implement `tuple_to_array`."""
return Tensor(np.array(x))
def stop_gradient(x):
"""
停止梯度传播
Args:
x (Tensor): 需要停止梯度传播的张量
Returns:
Tensor: 停止梯度传播的张量
"""
"""Implement `stop_gradient`."""
return x
@ -182,6 +496,20 @@ hyper_map = C.HyperMap()
def mixed_precision_cast(dst_type, x):
"""
实现混合精度转换函数
Args:
dst_type (mstype.Type): 目标数据类型
x (Union[Tensor, list, tuple]): 需要进行类型转换的数据可以是单个Tensor也可以是一个包含Tensor的列表或元组
Returns:
Union[Tensor, list, tuple]: 转换后的数据类型与输入一致
Raises:
TypeError: 如果输入数据类型不支持将引发TypeError异常
"""
"""Implement `mixed_precision_cast`."""
def cast_inner(data):

@ -13,6 +13,9 @@
# limitations under the License.
# ============================================================================
"""init"""
# Import split_with_json from the splitter module
from .splitter import split_with_json
# Import get_op_expander from the expander module
from .expander import get_op_expander
# Import estimate_calculation_amount and estimate_ops from the parallel_estimate module
from .parallel_estimate import estimate_calculation_amount, estimate_ops

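To make the entry point concrete, here is a minimal sketch of how `get_op_expander` might be driven. The kernel-info field names follow `extract_expand_info` below, but the calling convention and return handling are assumptions for illustration, not a documented API.

```python
# Hypothetical driver (assumed call pattern): the caller passes a JSON kernel
# description and receives the expanded graph back as a JSON string.
import json

kernel_info = {
    "name": "AddN",          # op to expand
    "input_desc": [],        # list of input descriptions
    "output_desc": [],       # list of output descriptions
    "process": "aicore",     # target process
}
graph_json = get_op_expander(json.dumps(kernel_info))
```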
@ -22,8 +22,32 @@ from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedEx
def create_expander(expand_info):
"""
根据操作符名称创建一个扩展器
Args:
expand_info (dict): 包含操作符名称及其他相关信息的字典
Returns:
Any: 调用指定操作符名称的扩展器后返回的结果
Raises:
GraphKernelUnsupportedException: 如果指定的操作符名称在扩展器模块中不存在则抛出此异常
"""
"""Create an expander according to op name"""
def call_func(func, arg):
"""
调用给定的函数并返回其结果
Args:
func (callable): 要调用的函数
arg: 要传递给函数的参数
Returns:
调用给定函数后的返回值
"""
return func(arg)
op_name = str(expand_info['name'])
if not hasattr(expanders, op_name):
@ -33,6 +57,21 @@ def create_expander(expand_info):
def extract_expand_info(kernel_info):
"""
将json格式的kernel信息转换为更友好的格式
Args:
kernel_info (dict): 包含kernel信息的字典
Returns:
dict: 转换后的kernel信息字典包含以下键
- name (str): kernel的名称
- input_desc (list): 输入描述列表
- output_desc (list): 输出描述列表
- attr (dict): 属性字典键为属性名值为属性值
- process (str): 处理过程的描述
"""
"""Convert the json into a more friendly format"""
input_desc = []
if 'input_desc' in kernel_info and kernel_info['input_desc']:
@ -53,6 +92,20 @@ def extract_expand_info(kernel_info):
def get_op_expander(json_str: str):
"""
通过json信息获取操作扩展器
Args:
json_str (str): 包含操作扩展器信息的json字符串
Returns:
str: 返回扩展后的操作图的json描述
Raises:
jd.JSONDecodeError: 如果输入的json字符串解码失败
GraphKernelUnsupportedException: 如果操作图不支持的操作类型
"""
"""get op expander by json info"""
try:
kernel_info = json.loads(json_str)

@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""expanders init. Deprecated, please add the new operators in the c++ file"""
"""扩展器初始化。已弃用请在新运算符中添加C++文件"""
from .addn import AddN

@ -27,6 +27,19 @@ class Expander:
__metaclass__ = ABCMeta
def __init__(self, expand_info):
"""
初始化方法
Args:
expand_info (dict): 包含模型信息的字典包括模型名称输入描述输出描述属性处理函数等
Attributes:
name (str): 模型名称
inputs (list): 输入描述列表
outputs (list): 输出描述列表
attrs (dict): 模型属性字典
processor (callable): 处理函数
"""
self.name = expand_info["name"]
self.inputs = expand_info["input_desc"]
self.outputs = expand_info["output_desc"]
@ -34,6 +47,19 @@ class Expander:
self.processor = expand_info["process"]
def run(self):
"""
将操作扩展为图
Args:
Returns:
返回扩展后的图对象
Raises:
GraphKernelUnsupportedException: 如果检查失败则引发此异常
"""
"""
Expand the operator to a graph.
@ -58,9 +84,31 @@ class Expander:
return graph
def _check(self):
"""
检查输入
Args:
Returns:
Raises:
ValueError: 如果输入不符合要求则引发此异常
"""
"""Check inputs"""
def _check_output_same(self, outputs):
"""
检查输出是否与预期一致
Args:
outputs (list): 实际输出值的列表
Raises:
GKException: 如果实际输出值与预期不一致则抛出异常
"""
for index, value in enumerate(self.outputs):
if list(outputs[index].shape) != list(value['shape']):
raise GKException("{} 's output shape {} is wrong. Expected:{}".format(
@ -74,6 +122,18 @@ class Expander:
@abstractmethod
def _expand(self, graph_builder):
"""
Expand 操作符此函数应在子类中重写
Args:
graph_builder (GraphBuilder): 图构建器对象
Raises:
Exception: 如果子类未重写此方法则抛出异常提示 "_expand() is not implemented in {}".
Returns:
None
"""
"""Expand operator, this function should be overridden in subclass"""
raise Exception("_expand() is not implemented in {}".format(self.__class__.__name__))
@ -82,10 +142,34 @@ class ExpanderInfoValidator:
"""ExpanderInfoValidator is the utility class which defines the validator decorator for expanders"""
def __init__(self):
"""
初始化方法
Args:
Returns:
"""
"""Init"""
@staticmethod
def _add_check_function(kls, func):
"""
向类 Expander 中的 `_check` 函数添加新的检查函数 `func`
Args:
kls (type): 需要被修改的类对象应继承自 Expander
func (callable): 要添加的新检查函数该函数接受一个对象作为参数
Returns:
None
Raises:
AttributeError: 如果 kls 类中不存在 `_check` 方法
"""
"""
Rewrite the function `_check` in class Expander
to append the new `func` after the original checks.
@ -93,6 +177,21 @@ class ExpanderInfoValidator:
old_check = getattr(kls, "_check")
def new_check(obj):
"""
执行新的检查函数
Args:
obj (Any): 需要检查的对象
Returns:
None
Raises:
None
这个函数首先调用旧版本的检查函数 `old_check` 对传入的对象 `obj` 进行检查
然后调用自定义的函数 `func` 对该对象进行处理
"""
old_check(obj)
func(obj)
@ -103,6 +202,34 @@ class ExpanderInfoValidator:
"""
Add new supported format for the operator
Args:
*input_format: A variable number of arguments representing the new supported formats.
Returns:
A wrapper function that adds the specified formats to the operator's supported formats list.
Raises:
GKException: Raised if the length of the registered format list does not match the length of the input formats,
or if the input formats do not match any registered format.
Exception: Raised if the wrapped class is not a subclass of Expander.
Description:
This function adds a list `__supported_formats` to the expander,
which contains the whitelist of formats supported by the operator.
It also rewrites the `_check` function to check the formats.
Example:
python
@add_format("text", "image")
class MyOperator(Expander):
pass
```
In this example, `MyOperator` will support the "text" and "image" formats.
"""
"""
Add new supported format for the operator
this function will add a list `__supported_formats` into the expander,
saving the whitelist of formats that this op supports.
it also rewrites the `_check` function to check the formats.
@ -110,6 +237,19 @@ class ExpanderInfoValidator:
format_list_name = "__supported_formats"
def _check_format(obj):
"""
检查对象的输入格式是否与已注册的格式匹配
Args:
obj (object): 需要检查的对象
Raises:
GKException: 如果输入格式与已注册的格式不匹配则引发异常
Returns:
None
"""
inp_formats = [inp['format'] for inp in obj.inputs]
for formats in getattr(obj, format_list_name):
if len(formats) != len(inp_formats):
@ -120,6 +260,18 @@ class ExpanderInfoValidator:
raise GKException("Unregistered format ({}) for op {}".format(','.join(inp_formats), obj.name))
def wrapper(cls):
"""
为给定的类添加包装功能
Args:
cls: 需要被包装的类必须继承自 Expander
Returns:
返回包装后的类
Raises:
Exception: 如果 cls 不是 Expander 的子类则抛出异常
"""
if not issubclass(cls, Expander):
raise Exception("{} should be subclass of Expander.".format(cls.__name__))
if not hasattr(cls, format_list_name):
@ -132,11 +284,49 @@ class ExpanderInfoValidator:
@staticmethod
def check_all_formats_same(kls):
"""
检查所有格式是否相同
Args:
kls: 待检查的类
Returns:
返回传入的类 kls并在类上注册一个检查函数用于验证该类所有输入格式是否一致
Raises:
Exception: 如果传入的类 kls 不是 Expander 的子类则抛出异常
GKException: 如果 kls 类中的输入格式不一致则抛出异常并显示不匹配格式的信息
"""
"""Check that all formats are the same"""
# Ensure no args case can return a class
def _check(*args):
"""
检查操作输入格式是否一致的装饰器
Args:
*args: 可变参数装饰器可以接收任意数量的参数
Returns:
wrapper: 返回一个装饰器函数用于包装类
Raises:
GKException: 如果所有输入的格式不一致抛出GKException异常
Exception: 如果被装饰的类不是Expander的子类抛出异常
"""
def _check_format(obj):
"""
检查输入格式是否一致
Args:
obj (Any): 包含输入信息的对象
Raises:
GKException: 如果所有输入格式不一致则抛出异常并包含不匹配格式的具体信息
"""
inp_formats = [inp['format'] for inp in obj.inputs]
if all((fmt == inp_formats[0] for fmt in inp_formats[1:])):
return
@ -144,6 +334,19 @@ class ExpanderInfoValidator:
','.join(inp_formats), obj.name))
def wrapper(cls):
"""
将给定类包装为 Expander 的子类并进行格式检查
Args:
cls (class): 需要包装的类
Returns:
class: 包装后的类
Raises:
Exception: 如果 cls 不是 Expander 的子类则抛出异常
"""
if not issubclass(cls, Expander):
raise Exception("{} should be subclass of Expander.".format(cls.__name__))
ExpanderInfoValidator._add_check_function(cls, _check_format)
@ -155,14 +358,53 @@ class ExpanderInfoValidator:
@staticmethod
def check_attrs(*args):
"""
检查属性是否存在
Args:
*args: 一个或多个属性名用于检查对象是否具有这些属性
Returns:
一个装饰器函数该装饰器函数用于验证类是否具有指定的属性
Raises:
GKException: 如果对象不具有指定的属性则抛出该异常
Exception: 如果被装饰的类不是 Expander 的子类则抛出该异常
"""
"""Check the attrs exist"""
def _check_attr(obj):
"""
检查对象是否具有指定的属性
Args:
obj (object): 要检查的对象
Raises:
GKException: 如果对象不具有指定的属性则抛出异常
Returns:
None
"""
for a in args:
if a not in obj.attrs:
raise GKException("attr '{}' does not exist.".format(a))
def wrapper(cls):
"""
对类进行包装确保该类是 Expander 的子类并添加属性检查功能
Args:
cls (class): 需要包装的类
Returns:
class: 包装后的类
Raises:
Exception: 如果 cls 不是 Expander 的子类则抛出异常
"""
if not issubclass(cls, Expander):
raise Exception("{} should be subclass of Expander.".format(cls.__name__))
ExpanderInfoValidator._add_check_function(cls, _check_attr)
@ -172,6 +414,21 @@ class ExpanderInfoValidator:
def to_frac_z_axis(ori_shape, ori_axis):
"""
判断是否为分形NZ格式
Args:
----
ori_shape: list or tuple
输入的原始形状
ori_axis: list or tuple
操作的原始形状的轴
Returns:
-------
output: list
分形Nz形状的轴
"""
"""
judge the format is fractal NZ
Parameters
@ -208,6 +465,16 @@ def to_frac_z_axis(ori_shape, ori_axis):
def infer_shape_from_fractalnz(fractal):
"""
从fractalnz形状推断原始形状
Args:
fractal (list): fractalnz形状一个包含形状的列表
Returns:
list: 推断出的原始形状
"""
"get original shape from fractalnz shape"
shape = []
dims = len(fractal)
@ -222,6 +489,17 @@ def infer_shape_from_fractalnz(fractal):
def get_reduced_ori_shape(shape, axis):
"""
获取基于原始形状的降维后的形状
Args:
shape (List[int]): 原始形状是一个整数列表
axis (List[int]): 需要降维的轴索引列表
Returns:
List[int]: 降维后的形状是一个整数列表
"""
"get shape after reduced which is based on original shape"
reduced_ori_shape = []
for i, value in enumerate(shape):
@ -233,6 +511,25 @@ def get_reduced_ori_shape(shape, axis):
def get_reduce_axis_shape(shape, data_format, axis):
"""
根据给定的输入形状数据格式和轴获取在指定格式下的归约轴和原始的归约形状
Args:
-----
shape: list or tuple
输入的形状
data_format: str
输入的数据格式
axis: None, int, list or tuple
在原始形状下的归约轴
Returns:
--------
reduce_axis: list
在指定数据格式下的归约轴
ori_reduced_shape: list
原始的归约形状
"""
"""
Get the reduce axis under format `data_format` and original reduced shape.
Parameters

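As a worked example of the validator decorators above, a hypothetical expander could be declared as follows. `ScaledAdd` and its `alpha` attribute are illustrative only, not an existing op; the `graph_builder.value`/`emit` calls follow the pattern of the expanders in this commit.

```python
# Illustrative sketch only: wiring the validators to a toy expander.
@VLD.check_all_formats_same
@VLD.check_attrs('alpha')
class ScaledAdd(Expander):  # hypothetical op: x + alpha * y
    """Expand ScaledAdd to Mul + Add"""
    def _expand(self, graph_builder):
        input_x, input_y = self.inputs
        # alpha comes from the op attributes, validated by check_attrs above
        alpha = graph_builder.value(input_x.dtype, self.attrs['alpha'])
        scaled = graph_builder.emit('Mul', [alpha, input_y])
        return graph_builder.emit('Add', [input_x, scaled])
```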
@ -13,20 +13,47 @@
# limitations under the License.
# ===========================================================================
"""generate json desc for addn"""
# Import the GraphKernelUnsupportedException class
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
# Import the Expander and ExpanderInfoValidator classes
from ._utils import Expander, ExpanderInfoValidator as VLD
# The VLD.check_all_formats_same decorator ensures all input formats are the same
@VLD.check_all_formats_same
class AddN(Expander):
"""Expand AddN to multiple Adds"""
# Check that there is more than one input
def _check(self):
"""
检查输入的数量是否满足要求
Args:
Returns:
Raises:
GKException: 如果输入的数量小于2则抛出GKException异常
"""
if len(self.inputs) < 2:
raise GKException("For 'AddN', the inputs num should be greater than 1, but got {}"
.format(len(self.inputs)))
# Expand AddN into a chain of Add ops
def _expand(self, graph_builder):
"""
对输入张量进行逐元素加法运算
Args:
graph_builder (GraphBuilder): 图构建器对象用于生成图节点
Returns:
Tensor: 逐元素加法运算后的结果张量
"""
result = self.inputs[0]
for inp in self.inputs[1:]:
result = graph_builder.emit('Add', [result, inp])

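The loop above left-folds the inputs into a chain of `Add` nodes, so `AddN(a, b, c)` becomes `Add(Add(a, b), c)`. A rough functional equivalent of the unrolling (a sketch of the semantics, not graph code):

```python
from functools import reduce

def addn(*tensors):
    # AddN(a, b, c) -> Add(Add(a, b), c)
    return reduce(lambda acc, t: acc + t, tensors)
```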
@ -36,15 +36,19 @@ class BatchNorm(Expander):
input_x_ori_type = input_x.dtype
input_x_new_type = input_x.dtype
# If the input is float16 while scale, offset, mean and variance are float32, compute in float32
if input_x.dtype == "float16" and input_scale.dtype == "float32" and input_offset.dtype == "float32" and \
input_mean.dtype == "float32" and input_variance.dtype == "float32":
input_x_new_type = "float32"
# Cast the input if the working type differs from the original type
if input_x_new_type != input_x_ori_type:
input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': input_x_new_type})
# Training mode
if self.attrs['is_training']:
self.inputs[0] = input_x
res_y, mean_res, variance_res, mean_muls, y_sqrt_rec = self._bn_train(graph_builder)
# Cast the output back if the working type differs from the original type
if input_x_new_type != input_x_ori_type:
res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type})
return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec
@ -70,21 +74,42 @@ class BatchNorm(Expander):
return res_y, var_add, var_add, var_add, var_add
def _bn_train(self, graph_builder):
"""
在训练模式下扩展BatchNorm
Args:
graph_builder (GraphBuilder): 图构建器实例
Returns:
tuple: 包含以下内容的元组:
- res_y (Tensor): 归一化后的输出
- mean_res (Tensor): 更新后的移动均值
- variance_res (Tensor): 更新后的移动方差
- mean_muls (Tensor): 输入数据的均值
- y_sqrt_rec (Tensor): 1 / sqrt(方差 + epsilon)用于反向传播
"""
"""expand BatchNorm for training mode"""
# Get the inputs
input_x = self.inputs[0]
input_scale = self.inputs[1]
input_offset = self.inputs[2]
input_mean = self.inputs[3]
input_variance = self.inputs[4]
# Epsilon value
epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon'])
# Reduce axes
reduce_axis = ()
# Input shape
shape_x = input_x.shape
# Set the reduce axes and element count according to the data format
if input_x.data_format == DF.NHWC:
reduce_axis = (0, 1, 2)
num = shape_x[0] * shape_x[1] * shape_x[2]
else:
reduce_axis = (0, 2, 3)
num = shape_x[0] * shape_x[2] * shape_x[3]
# Reciprocal of num
num_rec = 1.0 / num
num_rec_v = graph_builder.value(input_scale.dtype, num_rec)
@ -112,41 +137,67 @@ class BatchNorm(Expander):
y_sqrt_rec = graph_builder.emit('RealDiv', [scalar_one_v, y_sqrt])
# compute res_y
# Compute the difference between input x and the expanded mean
tmp_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand])
# Reshape y_sqrt_rec when the data format is DF.DEFAULT or DF.NCHW
if input_x.data_format in (DF.DEFAULT, DF.NCHW):
y_sqrt_rec_expand = graph_builder.emit(
'Reshape', [y_sqrt_rec], attrs={'shape': ExpandDims.infer_shape(y_sqrt_rec.shape, [-1, -1])})
# otherwise keep y_sqrt_rec unchanged
else:
y_sqrt_rec_expand = y_sqrt_rec
# Multiply tmp_sub by y_sqrt_rec_expand
y_norm = graph_builder.emit('Mul', [tmp_sub, y_sqrt_rec_expand])
# Reshape input_scale when the data format is DF.DEFAULT or DF.NCHW
if input_x.data_format in (DF.DEFAULT, DF.NCHW):
input_scale_expand = graph_builder.emit(
'Reshape', [input_scale], attrs={'shape': ExpandDims.infer_shape(input_scale.shape, [-1, -1])})
# otherwise keep input_scale unchanged
else:
input_scale_expand = input_scale
# Multiply input_scale_expand by y_norm
res_y_mul = graph_builder.emit('Mul', [input_scale_expand, y_norm])
# Reshape input_offset when the data format is DF.DEFAULT or DF.NCHW
if input_x.data_format in (DF.DEFAULT, DF.NCHW):
input_offset_expand = graph_builder.emit(
'Reshape', [input_offset], attrs={'shape': ExpandDims.infer_shape(input_offset.shape, [-1, -1])})
# otherwise keep input_offset unchanged
else:
input_offset_expand = input_offset
# Add res_y_mul and input_offset_expand
res_y = graph_builder.emit('Add', [res_y_mul, input_offset_expand])
# compute mean_res
# Compute 1 - momentum
momentum_sub = scalar_one - self.attrs['momentum']
# Convert 1 - momentum to the input dtype
momentum_v_sub = graph_builder.value(input_scale.dtype, momentum_sub)
# Decayed part of the new running mean
new_running_mean_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_mean])
# Momentum as a graph value
momentum_v = graph_builder.value(input_scale.dtype, self.attrs['momentum'])
# Contribution of the current batch mean
current_mean_tmp = graph_builder.emit('Mul', [momentum_v, mean_muls])
# Combine them into the updated moving mean
updated_moving_mean = graph_builder.emit('Add', [new_running_mean_tmp, current_mean_tmp])
# Assign the updated moving mean to input_mean
mean_res = graph_builder.emit('Assign', [input_mean, updated_moving_mean])
# variance_res is calculated by sample variance, and need to multiply by num / (num - 1)
# Unbias the sample variance: multiply by num / (num - 1)
var_num = float(num) / (num - 1)
# Convert the factor to the input dtype
var_num_v = graph_builder.value(input_scale.dtype, var_num)
# Apply the correction factor
var_mul_update = graph_builder.emit('Mul', [var_num_v, var_mul])
# Decayed part of the new running variance
new_running_var_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_variance])
# Contribution of the current batch variance
current_var_tmp = graph_builder.emit('Mul', [momentum_v, var_mul_update])
# Combine them into the updated moving variance
updated_moving_variance = graph_builder.emit('Add', [new_running_var_tmp, current_var_tmp])
# Assign the updated moving variance to input_variance
variance_res = graph_builder.emit('Assign', [input_variance, updated_moving_variance])
# Return the results
return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec

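The moving-statistics update emitted above is the usual momentum rule, with the sample variance unbiased by num / (num - 1) before blending. In scalar form (a sketch of the math, not the graph code):

```python
def update_moving_stat(old_stat, batch_stat, momentum):
    # new_stat = (1 - momentum) * old_stat + momentum * batch_stat
    return (1.0 - momentum) * old_stat + momentum * batch_stat

def unbiased_variance(sample_var, num):
    # variance_res blends the unbiased estimate, not the raw sample variance
    return sample_var * float(num) / (num - 1)
```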
@ -11,46 +11,59 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for BatchNormGrad"""
# Import the required modules and classes
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD
from .expand_dims import ExpandDims
# BatchNormGrad class, derived from Expander
@VLD.add_format(DF.NHWC, DF.NHWC, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.NCHW, DF.NCHW, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('is_training', 'epsilon')
class BatchNormGrad(Expander):
"""BatchNormGrad expander"""
"""BatchNormGrad扩展器用于计算Batch Normalization层的梯度"""
# The expansion method that builds the BatchNormGrad computation
def _expand(self, graph_builder):
# get op info: output gradient, input, scale, saved mean and saved inverse variance
input_dy = self.inputs[0]  # gradient w.r.t. the output
input_x = self.inputs[1]  # input data
input_scale = self.inputs[2]  # scale (gamma)
input_save_mean = self.inputs[3]  # saved mean
input_save_inv_variance = self.inputs[4]  # saved 1 / sqrt(variance + epsilon)
# Compute reduce_axis for the subsequent ReduceSum according to the data format
reduce_axis = ()
shape_x = input_x.shape
if input_x.data_format == DF.NHWC:  # NHWC layout
reduce_axis = (0, 1, 2)  # axes for ReduceSum
num = shape_x[0] * shape_x[1] * shape_x[2]  # element count over those axes
else:  # otherwise assume NCHW
reduce_axis = (0, 2, 3)  # axes for ReduceSum
num = shape_x[0] * shape_x[2] * shape_x[3]  # element count over those axes
ori_type = input_x.dtype  # original data type
# Cast float16 inputs to float32 to avoid precision loss
if ori_type == 'float16':
input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'})
if input_dy.dtype == 'float16':
input_dy = graph_builder.emit('Cast', [input_dy], attrs={'dst_type': 'float32'})
num_rec = -1.0 / num  # negative reciprocal of the element count
num_rec_v = graph_builder.value(input_scale.dtype, num_rec)  # as a graph value
dbeta = graph_builder.emit('ReduceSum', [input_dy], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})  # dbeta: gradient of beta
# in training input_save_inv_variance means 1 / sqrt(variance + epsilon), which is calculated in forward pass
# Choose inv_variance according to whether we are in training
if self.attrs['is_training']:
inv_variance = input_save_inv_variance
else:
@ -61,7 +74,7 @@ class BatchNormGrad(Expander):
scalar_one_v = graph_builder.value(input_scale.dtype, scalar_one)
inv_variance = graph_builder.emit('RealDiv', [scalar_one_v, sqrt_var_eps])
# compute dgamma
if input_x.data_format in (DF.DEFAULT, DF.NCHW):
input_save_mean = graph_builder.emit(
'Reshape', [input_save_mean], attrs={'shape': ExpandDims.infer_shape(input_save_mean.shape, [-1, -1])})
@ -69,13 +82,13 @@ class BatchNormGrad(Expander):
'Reshape', [inv_variance], attrs={'shape': ExpandDims.infer_shape(inv_variance.shape, [-1, -1])})
input_scale = graph_builder.emit(
'Reshape', [input_scale], attrs={'shape': ExpandDims.infer_shape(input_scale.shape, [-1, -1])})
x_sub_mean = graph_builder.emit('Sub', [input_x, input_save_mean])  # x minus the saved mean
x_div = graph_builder.emit('Mul', [x_sub_mean, inv_variance])  # normalized x
dgamma_param = graph_builder.emit('Mul', [input_dy, x_div])  # elementwise dgamma terms
dgamma = graph_builder.emit(
'ReduceSum', [dgamma_param], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})  # reduce to dgamma
# compute dx
if self.attrs['is_training']:
tmp_b = graph_builder.emit('Mul', [num_rec_v, dbeta])
if input_x.data_format in (DF.DEFAULT, DF.NCHW):
@ -95,11 +108,12 @@ class BatchNormGrad(Expander):
y_scale = graph_builder.emit('Mul', [input_scale, input_dy])
dx = graph_builder.emit('Mul', [inv_variance, y_scale])
if ori_type == 'float16':
dx = graph_builder.emit('Cast', [dx], attrs={'dst_type': 'float16'})  # cast back to float16 when that was the original type
# set output tensors' data_format
dx.data_format = self.outputs[0]['format']
dgamma.data_format = self.outputs[1]['format']
dbeta.data_format = self.outputs[2]['format']
return dx, dgamma, dbeta

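In NumPy terms, the dbeta/dgamma reductions above amount to the following (a sketch under the convention that `inv_std` is the saved 1 / sqrt(variance + epsilon)):

```python
import numpy as np

def bn_grad_params(dy, x, mean, inv_std, axis):
    # dbeta = sum(dy); dgamma = sum(dy * normalized_x) over the reduce axes
    dbeta = np.sum(dy, axis=axis)
    dgamma = np.sum(dy * (x - mean) * inv_std, axis=axis)
    return dgamma, dbeta
```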
@ -13,37 +13,64 @@
# limitations under the License.
# ===========================================================================
"""generate json desc for bias_add"""
# Import MindSpore's DataFormat class
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
# Import the Expander and ExpanderInfoValidator classes
from ._utils import Expander, ExpanderInfoValidator as VLD
# Register DF.DEFAULT, DF.NHWC, DF.NCHW and DF.FRAC_NZ format validation for BiasAddGrad
@VLD.add_format(DF.DEFAULT)
@VLD.add_format(DF.NHWC)
@VLD.add_format(DF.NCHW)
@VLD.add_format(DF.FRAC_NZ)
# BiasAddGrad class, derived from Expander
class BiasAddGrad(Expander):
"""BiasAddGrad expander"""
def _expand(self, graph_builder):
"""
内部方法用于扩展输入张量的维度
Args:
graph_builder (GraphBuilder): 图构建器对象用于生成图操作
Returns:
Tensor: 扩展维度后的张量
"""
# 获取输入张量
x = self.inputs[0]
# 定义reduce_axis用于指定求和的维度
reduce_axis = ()
# 如果输入张量的数据格式为NHWC则reduce_axis为(0, 1, 2)
if x.data_format == DF.NHWC:
reduce_axis = (0, 1, 2)
# 如果输入张量的数据格式为NCHW则reduce_axis为(0, 2, 3)
elif x.data_format == DF.NCHW:
reduce_axis = (0, 2, 3)
# 如果输入张量的数据格式为FRAC_NZ则reduce_axis为(-2, -3)
elif x.data_format == DF.FRAC_NZ:
reduce_axis = (-2, -3)
# 如果输入张量的数据格式为DefaultFormat则根据shape的长度确定reduce_axis
else:
# DefaultFormat shape's length should be from 2 to 4
# 如果x的维度为2则reduce_axis为(0,)
if len(x.shape) == 2:
reduce_axis = (0,)
# 如果x的维度为3则reduce_axis为(0, 1)
elif len(x.shape) == 3:
reduce_axis = (0, 1)
# 否则reduce_axis为(0, 2, 3)
else:
reduce_axis = (0, 2, 3)
# 发射ReduceSum操作计算x的reduce_sumreduce_axis为reduce_axiskeep_dims为False
result = graph_builder.emit('ReduceSum', [x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
# 如果x的数据格式为DF.FRAC_NZ则将result的shape改为x.shape[:-4] + [x.shape[-1] * x.shape[-4]]
if x.data_format == DF.FRAC_NZ:
out_shape = x.shape[:-4] + [x.shape[-1] * x.shape[-4]]
# 发射Reshape操作将result的shape改为out_shape
result = graph_builder.emit('Reshape', [result], attrs={'shape': out_shape})
# 返回result
return result

@ -21,13 +21,28 @@ class ClipByNormNoDivSum(Expander):
"""ClipByNormNoDivSum expander"""
def _expand(self, graph_builder):
"""
对输入的张量进行计算返回计算结果
Args:
graph_builder (GraphBuilder): 图构建器对象用于生成计算图
Returns:
Tensor: 计算结果张量
"""
input_x0, input_x1, input_x2, input_x3 = self.inputs
# cal result
# 计算大于结果
greater_res = graph_builder.emit('Greater', [input_x0, input_x1])
# 根据大于结果选择input_x0或input_x2
select_res0 = graph_builder.emit('Select', [greater_res, input_x0, input_x2])
# 计算select_res0的平方根
sqrt_res = graph_builder.emit('Sqrt', [select_res0])
# 根据大于结果选择sqrt_res或input_x0
select_res1 = graph_builder.emit('Select', [greater_res, sqrt_res, input_x0])
# 计算select_res1和input_x3的最大值
result = graph_builder.emit('Maximum', [select_res1, input_x3])
return result

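The Greater/Select/Sqrt/Maximum chain emitted above corresponds to this NumPy sketch:

```python
import numpy as np

def clip_by_norm_no_div_sum(x0, x1, x2, x3):
    greater = x0 > x1
    s0 = np.where(greater, x0, x2)          # Select(greater, x0, x2)
    s1 = np.where(greater, np.sqrt(s0), x0) # Select(greater, Sqrt(s0), x0)
    return np.maximum(s1, x3)               # Maximum(s1, x3)
```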
@ -14,8 +14,13 @@
# ============================================================================
"""complex expanders init"""
# Import CAbs from the local abs module
from .abs import CAbs
# Import CAdd from the local add module
from .add import CAdd
# Import CDiv from the local div module
from .div import CDiv
# Import CMul from the local mul module
from .mul import CMul
# Import CSub from the local sub module
from .sub import CSub

@ -20,11 +20,23 @@ class CAbs(Expander):
"""CAbs expander"""
def _expand(self, graph_builder):
# Get the first input
input_x = self.inputs[0]
# Emit CReal to extract the real part of x
x_real = graph_builder.emit('CReal', [input_x])
# Emit CImag to extract the imaginary part of x
x_imag = graph_builder.emit('CImag', [input_x])
# Square the real part
squre_x_real = graph_builder.emit('Mul', [x_real, x_real])
# Square the imaginary part
squre_x_imag = graph_builder.emit('Mul', [x_imag, x_imag])
# Sum of the squares
squre_sum = graph_builder.emit('Add', [squre_x_real, squre_x_imag])
# Square root of the sum
result = graph_builder.emit('Sqrt', [squre_sum])
return result

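CAbs implements |z| = sqrt(re(z)^2 + im(z)^2); a quick plain-Python check of the identity:

```python
import math

z = complex(3.0, 4.0)
# both sides evaluate to 5.0
assert math.isclose(math.sqrt(z.real ** 2 + z.imag ** 2), abs(z))
```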
@ -22,12 +22,35 @@ class CAdd(Expander):
"""CAdd expander"""
def _expand(self, graph_builder):
"""
将两个复数相加
Args:
graph_builder (GraphBuilder): 图构建器对象
Returns:
Node: 相加后的复数结果
"""
# 获取输入参数
input_x, input_y = self.inputs
# 将输入参数x转换为实数部分
x_real = graph_builder.emit('CReal', [input_x])
# 将输入参数y转换为实数部分
y_real = graph_builder.emit('CReal', [input_y])
# 将输入参数x转换为虚数部分
x_imag = graph_builder.emit('CImag', [input_x])
# 将输入参数y转换为虚数部分
y_imag = graph_builder.emit('CImag', [input_y])
# 将x和y的实数部分相加
result_real = graph_builder.emit('Add', [x_real, y_real])
# 将x和y的虚数部分相加
result_imag = graph_builder.emit('Add', [x_imag, y_imag])
# 将相加后的实数部分和虚数部分组合为复数
result = graph_builder.emit('Complex', [result_real, result_imag])
return result

@ -22,22 +22,43 @@ class CDiv(Expander):
"""CDiv expander"""
def _expand(self, graph_builder):
"""
CDiv Implementation
Args:
graph_builder: 图构建器对象用于构建计算图
Returns:
返回复数除法结果
实现复数除法CDiv操作
获取两个输入的复数分别计算它们的实部和虚部
然后计算分母和分子的实部和虚部并进行除法运算
最后将得到的商的实部和虚部合并为复数结果返回
"""
"""CDiv Implementation"""
# 获取输入的两个复数
input_x, input_y = self.inputs
x_real = graph_builder.emit('CReal', [input_x])
y_real = graph_builder.emit('CReal', [input_y])
x_imag = graph_builder.emit('CImag', [input_x])
y_imag = graph_builder.emit('CImag', [input_y])
squre_y_real = graph_builder.emit('Mul', [y_real, y_real])
squre_y_imag = graph_builder.emit('Mul', [y_imag, y_imag])
final_denominator = graph_builder.emit('Add', [squre_y_real, squre_y_imag])
x_real_mul_y_real = graph_builder.emit('Mul', [x_real, y_real])
x_imag_mul_y_imag = graph_builder.emit('Mul', [x_imag, y_imag])
x_real_mul_y_imag = graph_builder.emit('Mul', [x_real, y_imag])
x_imag_mul_y_real = graph_builder.emit('Mul', [x_imag, y_real])
final_numerator_real = graph_builder.emit('Add', [x_real_mul_y_real, x_imag_mul_y_imag])
final_numerator_imag = graph_builder.emit('Sub', [x_imag_mul_y_real, x_real_mul_y_imag])
result_real = graph_builder.emit('RealDiv', [final_numerator_real, final_denominator])
result_imag = graph_builder.emit('RealDiv', [final_numerator_imag, final_denominator])
result = graph_builder.emit('Complex', [result_real, result_imag])
# 获取输入复数的实部
x_real = graph_builder.emit('CReal', [input_x]) # 发射 CReal 操作获取 input_x 的实部
y_real = graph_builder.emit('CReal', [input_y]) # 发射 CReal 操作获取 input_y 的实部
# 获取输入复数的虚部
x_imag = graph_builder.emit('CImag', [input_x]) # 发射 CImag 操作获取 input_x 的虚部
y_imag = graph_builder.emit('CImag', [input_y]) # 发射 CImag 操作获取 input_y 的虚部
# 计算分母
squre_y_real = graph_builder.emit('Mul', [y_real, y_real]) # 计算 y_real 的平方
squre_y_imag = graph_builder.emit('Mul', [y_imag, y_imag]) # 计算 y_imag 的平方
final_denominator = graph_builder.emit('Add', [squre_y_real, squre_y_imag]) # 计算分母
# 计算分子
x_real_mul_y_real = graph_builder.emit('Mul', [x_real, y_real]) # 计算 x_real 和 y_real 的乘积
x_imag_mul_y_imag = graph_builder.emit('Mul', [x_imag, y_imag]) # 计算 x_imag 和 y_imag 的乘积
x_real_mul_y_imag = graph_builder.emit('Mul', [x_real, y_imag]) # 计算 x_real 和 y_imag 的乘积
x_imag_mul_y_real = graph_builder.emit('Mul', [x_imag, y_real]) # 计算 x_imag 和 y_real 的乘积
final_numerator_real = graph_builder.emit('Add', [x_real_mul_y_real, x_imag_mul_y_imag]) # 计算分子的实部
final_numerator_imag = graph_builder.emit('Sub', [x_imag_mul_y_real, x_real_mul_y_imag]) # 计算分子的虚部
# 计算商
result_real = graph_builder.emit('RealDiv', [final_numerator_real, final_denominator]) # 计算商的实部
result_imag = graph_builder.emit('RealDiv', [final_numerator_imag, final_denominator]) # 计算商的虚部
# 将商合并为复数结果
result = graph_builder.emit('Complex', [result_real, result_imag]) # 将实部和虚部合并为复数结果
return result

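The expansion realizes the textbook identity (a+bi)/(c+di) = ((ac+bd) + (bc-ad)i) / (c^2+d^2); a small sanity check:

```python
a, b, c, d = 1.0, 2.0, 3.0, 4.0
den = c * c + d * d
got = complex((a * c + b * d) / den, (b * c - a * d) / den)
assert abs(got - complex(a, b) / complex(c, d)) < 1e-12
```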
@ -22,17 +22,45 @@ class CMul(Expander):
"""CMul expander"""
def _expand(self, graph_builder):
"""
计算两个复数的乘积
Args:
graph_builder (GraphBuilder): 图构建器对象用于在图中生成操作节点
Returns:
Result: 计算得到的复数乘积结果
"""
"""CMul Implementation"""
# 获取输入的两个复数
input_x, input_y = self.inputs
x_real = graph_builder.emit('CReal', [input_x])
y_real = graph_builder.emit('CReal', [input_y])
x_imag = graph_builder.emit('CImag', [input_x])
y_imag = graph_builder.emit('CImag', [input_y])
x_real_mul_y_real = graph_builder.emit('Mul', [x_real, y_real])
x_imag_mul_y_imag = graph_builder.emit('Mul', [x_imag, y_imag])
x_real_mul_y_imag = graph_builder.emit('Mul', [x_real, y_imag])
x_imag_mul_y_real = graph_builder.emit('Mul', [x_imag, y_real])
result_real = graph_builder.emit('Sub', [x_real_mul_y_real, x_imag_mul_y_imag])
result_imag = graph_builder.emit('Add', [x_real_mul_y_imag, x_imag_mul_y_real])
result = graph_builder.emit('Complex', [result_real, result_imag])
# 获取输入复数的实部
x_real = graph_builder.emit('CReal', [input_x]) # 发射指令获取input_x的实部
y_real = graph_builder.emit('CReal', [input_y]) # 发射指令获取input_y的实部
# 获取输入复数的虚部
x_imag = graph_builder.emit('CImag', [input_x]) # 发射指令获取input_x的虚部
y_imag = graph_builder.emit('CImag', [input_y]) # 发射指令获取input_y的虚部
# 计算实部与实部的乘积
x_real_mul_y_real = graph_builder.emit('Mul', [x_real, y_real]) # 发射指令计算x_real与y_real的乘积
# 计算虚部与虚部的乘积
x_imag_mul_y_imag = graph_builder.emit('Mul', [x_imag, y_imag]) # 发射指令计算x_imag与y_imag的乘积
# 计算实部与虚部的乘积
x_real_mul_y_imag = graph_builder.emit('Mul', [x_real, y_imag]) # 发射指令计算x_real与y_imag的乘积
x_imag_mul_y_real = graph_builder.emit('Mul', [x_imag, y_real]) # 发射指令计算x_imag与y_real的乘积
# 计算复数的实部
result_real = graph_builder.emit('Sub', [x_real_mul_y_real, x_imag_mul_y_imag]) # 发射指令计算实部结果
# 计算复数的虚部
result_imag = graph_builder.emit('Add', [x_real_mul_y_imag, x_imag_mul_y_real]) # 发射指令计算虚部结果
# 构造复数结果
result = graph_builder.emit('Complex', [result_real, result_imag]) # 发射指令构造复数结果
return result

@ -22,12 +22,38 @@ class CSub(Expander):
"""CSub expander"""
def _expand(self, graph_builder):
"""
计算两个复数的差
Args:
graph_builder (GraphBuilder): 图构建器对象用于生成计算图
Returns:
Tensor: 计算得到的复数差的结果
"""
# 获取输入
input_x, input_y = self.inputs
# 提取输入x的实部
x_real = graph_builder.emit('CReal', [input_x])
# 提取输入y的实部
y_real = graph_builder.emit('CReal', [input_y])
# 提取输入x的虚部
x_imag = graph_builder.emit('CImag', [input_x])
# 提取输入y的虚部
y_imag = graph_builder.emit('CImag', [input_y])
# 计算实部之差
result_real = graph_builder.emit('Sub', [x_real, y_real])
# 计算虚部之差
result_imag = graph_builder.emit('Sub', [x_imag, y_imag])
# 将实部和虚部组合成复数结果
result = graph_builder.emit('Complex', [result_real, result_imag])
return result

@ -18,6 +18,7 @@ from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
from ._utils import Expander, ExpanderInfoValidator as VLD
# Constants
M_ALIGN = 32
N_ALIGN = 32
K_ALIGN = 16
@ -29,6 +30,7 @@ C_CHANNEL_ALIGN = 16
OUT_NHW_ALIGN = 128
# Register format validation
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.NHWC, DF.NHWC)
@VLD.check_attrs('format', 'pad_list', 'pad_mode', 'groups', 'group', 'kernel_size', 'stride', 'dilation')
@ -47,6 +49,16 @@ class Conv2D(Expander):
"""
def __init__(self, expand_info):
"""
类的构造函数
Args:
expand_info (dict): 扩展信息字典包含一些扩展的配置参数
Returns:
None
"""
super().__init__(expand_info)
self.dst_type = self.outputs[0]['data_type']
self.dst_format = self.outputs[0]['format']
@ -59,6 +71,19 @@ class Conv2D(Expander):
self.k = 0
def _optimize_to_matmul(self):
"""
检查是否可以将Conv2D优化为MatMul
Args:
Returns:
bool: 如果可以将Conv2D优化为MatMul则返回True否则返回False
"""
"""
Check if the Conv2D can be optimized to MatMul.
"""
stride = self.attrs['stride']
dilation = self.attrs['dilation']
_, h, w, _ = self.inputs[1]['shape']
@ -68,6 +93,18 @@ class Conv2D(Expander):
return False
def _common_check(self):
"""
对输入和属性的通用检查
Args:
Returns:
Raises:
GKException: 如果输入数据类型不是 float16或者输入格式不是 NHWC或者属性 groups group 不是 1或者属性 dilation 不是 [1, 1, 1, 1]抛出异常
"""
"""common check for inputs and attrs"""
type_0 = self.inputs[0]['data_type']
type_1 = self.inputs[1]['data_type']
@ -91,26 +128,52 @@ class Conv2D(Expander):
.format(dilation))
def _check(self):
"""
检查卷积2D操作的参数和输入是否合法
Args:
Raises:
GKException: 当输入参数或输入维度不满足要求时抛出异常
Returns:
"""
# 调用_common_check()方法
self._common_check()
# 获取pad_list
pad_list = self.attrs['pad_list']
# 检查pad_list的维度是否为4
check_nd(pad_list, 4)
# 调用conv_had_pad()方法判断是否有pad
self.has_pad = conv_had_pad(pad_list, self.attrs['pad_mode'])
# 获取输入的shape
shape_0 = self.inputs[0]['shape']
shape_1 = self.inputs[1]['shape']
# 获取stride
stride = self.attrs['stride']
# 检查shape_0的维度是否为4
check_nd(shape_0, 4)
# 检查shape_1的维度是否为4
check_nd(shape_1, 4)
# 检查stride的维度是否为4
check_nd(stride, 4)
# 获取shape_0的各个维度
n0, h0, w0, c0 = shape_0
# 获取shape_1的各个维度
n1, h1, w1, c1 = shape_1
# 检查n0是否为N0_CHANNEL_ALIGN的倍数
if (n0 % N0_CHANNEL_ALIGN) != 0:
raise GKException("For 'Conv2D', N channel of first input should be multiples of {}, but got {}"
.format(N0_CHANNEL_ALIGN, n0))
# 检查n1是否为N1_CHANNEL_ALIGN的倍数
if (n1 % N1_CHANNEL_ALIGN) != 0:
raise GKException("For 'Conv2D', N channel of second input should be multiples of {}, but got {}"
.format(N1_CHANNEL_ALIGN, n1))
# 检查c0和c1是否相等并且是否为C_CHANNEL_ALIGN的倍数
if c0 != c1 or (c0 % C_CHANNEL_ALIGN) != 0:
raise GKException("For 'Conv2D', C channel of inputs should be same and also be multiples of {}, but got "
"{} and {}".format(C_CHANNEL_ALIGN, c0, c1))
@ -130,68 +193,106 @@ class Conv2D(Expander):
# check if can optimize to matmul
self.m, self.n, self.k = n0 * h0 * w0, n1, c1
# Call _optimize_to_matmul() to see whether a MatMul rewrite is possible
self.can_optimize_to_matmul = self._optimize_to_matmul()
# requirements
if self.can_optimize_to_matmul:
# When optimizing to MatMul, k must not exceed K_LIMIT
if self.k > K_LIMIT:
raise GKException("For 'Conv2D', if transformed to 'MatMul', C0 should not be larger than {}, but got "
"{}".format(K_LIMIT, self.k))
# and m * n * k must stay below MNK_LIMIT
if self.m * self.n * self.k >= MNK_LIMIT:
raise GKException("For 'Conv2D', if transformed to 'MatMul', The total size should not be larger than "
"{}, but got {}".format(MNK_LIMIT, self.m * self.n * self.k))
else:
# Otherwise compute the output size
out_h, out_w = (h0 - h1) // stride[-2] + 1, (w0 - w1) // stride[-1] + 1
# n0 * out_h * out_w must be a multiple of OUT_NHW_ALIGN
if ((n0 * out_h * out_w) % OUT_NHW_ALIGN) != 0:
raise GKException("For 'Conv2D', N({}) * H({}) * W({}) of output should be multiplies of {}"
.format(n0, out_h, out_w, OUT_NHW_ALIGN))
# and the stride must be [1, 1, 2, 2]
if stride != [1, 1, 2, 2]:
raise GKException("For 'Conv2D', value of attr 'stride' should be [1, 1, 2, 2], but got {}"
.format(stride))
# Save the padded shapes
self.shape_0_pad = [n0, h0, w0, c0]
self.shape_1_pad = [n1, h1, w1, c1]
def _expand(self, graph_builder):
"""
对输入进行扩展处理
Args:
graph_builder (GraphBuilder): 图构建器对象
Returns:
Tensor: 扩展处理后的结果
"""
# Inputs
input_0 = self.inputs[0]
input_1 = self.inputs[1]
# Shapes of the inputs
n0, _, _, c0 = input_0.shape
n1, _, _, c1 = input_1.shape
# Padded shapes
n0_p, h0_p, w0_p, c0_p = self.shape_0_pad
n1_p, _, _, c1_p = self.shape_1_pad
pad_value = 0
# input0 pad
# Initialize the before/after padding of input 0
input_0_pad_before = [0, 0, 0, 0]
input_0_pad_after = [0, 0, 0, 0]
# If padding is required, read the pad list
if self.has_pad:
pad_list = self.attrs['pad_list']
# H/W padding of input 0
input_0_pad_before = [0, pad_list[0], pad_list[2], 0]
input_0_pad_after = [0, pad_list[1], pad_list[3], 0]
# N/C padding of input 0 up to the aligned shape
input_0_pad_after[0] = n0_p - n0
input_0_pad_after[3] = c0_p - c0
# Emit the pad op if any padding is non-zero
if input_0_pad_before != [0, 0, 0, 0] or input_0_pad_after != [0, 0, 0, 0]:
input_0 = graph_builder.emit('PadAkg', [input_0], attrs={'head': input_0_pad_before,
'tail': input_0_pad_after,
'pad_val': pad_value})
# input1 pad
# Padding of input_1 up to the aligned shape
input_1_pad_after = [n1_p - n1, 0, 0, c1_p - c1]
# Emit the pad op if any padding is non-zero
if input_1_pad_after != [0, 0, 0, 0]:
input_1 = graph_builder.emit('PadAkg', [input_1], attrs={'head': [0, 0, 0, 0],
'tail': input_1_pad_after,
'pad_val': pad_value})
# If a MatMul rewrite is possible, use MatMul
if self.can_optimize_to_matmul:
# Reshape input_0 and input_1 into 2-D matrices
a = graph_builder.emit('Reshape', [input_0], attrs={'shape': [self.m, self.k]})
b = graph_builder.emit('Reshape', [input_1], attrs={'shape': [self.n, self.k]})
# Emit the MatMul
c = graph_builder.emit('MatMul', [a, b], attrs={'transpose_a': False,
'transpose_b': True,
'dst_type': self.dst_type})
# Reshape the result back to the output shape
result = graph_builder.emit('Reshape', [c], attrs={'shape': [n0_p, h0_p, w0_p, n1_p],
'format': self.dst_format})
# Otherwise emit a Conv2D
else:
# Set the Conv2D attributes
attrs = self.attrs
attrs['pad_list'] = [0, 0, 0, 0]
attrs['dst_type'] = self.dst_type
# Emit the Conv2D
result = graph_builder.emit('Conv2D', [input_0, input_1], attrs=attrs)
# unpad
unpad_after = [input_0_pad_after[0], 0, 0, input_1_pad_after[0]]

@ -13,18 +13,36 @@
# limitations under the License.
# ===========================================================================
"""generate json desc for DropoutGrad"""
# Import the Expander and ExpanderInfoValidator classes
from ._utils import Expander, ExpanderInfoValidator as VLD
# DropoutGrad class, derived from Expander
@VLD.check_all_formats_same
@VLD.check_attrs('keep_prob')
class DropoutGrad(Expander):
"""DropoutGrad expander"""
def _expand(self, graph_builder):
"""
对输入数据进行扩展操作
Args:
graph_builder (GraphBuilder): 图构建器对象
Returns:
Tensor: 扩展后的输入数据
"""
# 获取输入数据和掩码
input_dy, input_mask = self.inputs
# 获取保持概率
keep_prob = self.attrs['keep_prob']
# 计算保持概率的倒数
r_keep_prob = graph_builder.value(input_dy.dtype, 1.0 / keep_prob)
# 计算输入数据和保持概率的乘积
result = graph_builder.emit('Mul', [input_dy, r_keep_prob])
# 计算乘积和掩码的乘积
result = graph_builder.emit('Mul', [result, input_mask])
# 返回结果
return result

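In plain terms the two Mul ops compute dx = dy * mask / keep_prob, matching the forward pass's rescaling; as a sketch:

```python
def dropout_grad(dy, mask, keep_prob):
    # dx = dy * (1 / keep_prob) * mask
    return dy * (1.0 / keep_prob) * mask
```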
@ -17,34 +17,84 @@ from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedEx
from ._utils import Expander, ExpanderInfoValidator as VLD
# @VLD.check_all_formats_same checks that all input formats are the same
@VLD.check_all_formats_same
class EqualCount(Expander):
"""EqualCount expander"""
def __init__(self, expand_info):
"""
初始化方法
Args:
expand_info (dict): 扩展信息字典
Returns:
None
"""
# 调用父类的初始化方法
super().__init__(expand_info)
# 获取输入x的形状
self.shape_x = self.inputs[0]['shape']
# 获取输入y的形状
self.shape_y = self.inputs[1]['shape']
# 获取输入x的数据类型
self.dtype_x = self.inputs[0]['data_type']
# 获取输入y的数据类型
self.dtype_y = self.inputs[1]['data_type']
def _check(self):
"""
检查输入的两个张量是否具有相同的形状和数据类型
Args:
Returns:
Raises:
GKException: 如果两个张量的形状不同则引发异常异常信息中包含两个张量的形状
GKException: 如果两个张量的数据类型不同则引发异常异常信息中包含两个张量的数据类型
"""
# 判断输入的形状是否相同
if self.shape_x != self.shape_y:
# 如果不相同,抛出异常
raise GKException("For 'EqualCount', the inputs shape should be same, but got {} and {}"
.format(self.shape_x, self.shape_y))
# 判断输入的数据类型是否相同
if self.dtype_x != self.dtype_y:
# 如果不相同,抛出异常
raise GKException("For 'EqualCount', the inputs data type should be same, but got {} and {}"
.format(self.dtype_x, self.dtype_y))
def _expand(self, graph_builder):
"""
扩展输入维度的方法
Args:
graph_builder: 图构建器对象用于生成计算图
Returns:
扩展后的张量
"""
# 获取输入张量
input_x = self.inputs[0]
input_y = self.inputs[1]
# 比较输入张量是否相等
eql_val = graph_builder.emit('Equal', [input_x, input_y])
# 将比较结果转换为float32类型
cast_val = graph_builder.emit('Cast', [eql_val], attrs={'dst_type': 'float32'})
# 获取输入张量的维度
axis = list(range(len(input_x.shape)))
# 对比较结果进行求和
result = graph_builder.emit('ReduceSum', [cast_val], attrs={'reduce_axis': axis, 'keep_dims': True})
# 如果求和结果的数据类型与输入张量的数据类型不同,则将求和结果转换为输入张量的数据类型
if result.dtype != input_x.dtype:
result = graph_builder.emit('Cast', [result], attrs={'dst_type': input_x.dtype})
# 返回求和结果
return result

@ -16,20 +16,44 @@
from ._utils import Expander
# Erfc class, derived from Expander
class Erfc(Expander):
"""Erfc expander"""
def _expand(self, graph_builder):
"""
对输入数据进行扩展处理
Args:
graph_builder (GraphBuilder): 图构建器对象
Returns:
Tensor: 扩展处理后的结果
"""
# 获取输入数据
input_x = self.inputs[0]
# 初始化结果
result = None
# 如果输入数据的类型是float16
if input_x.dtype == "float16":
# 创建一个float32类型的常量1
const_one = graph_builder.value("float32", 1)
# 将输入数据转换为float32类型
input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': "float32"})
# 计算输入数据的erf值
erf_result = graph_builder.emit('Erf', [input_x])
# 计算结果
result = graph_builder.emit('Sub', [const_one, erf_result])
# 将结果转换为float16类型
result = graph_builder.emit('Cast', [result], attrs={'dst_type': "float16"})
# 返回结果
return result
# 创建一个与输入数据类型相同的常量1
const_one = graph_builder.value(input_x.dtype, 1)
# 计算输入数据的erf值
erf_result = graph_builder.emit('Erf', [input_x])
# 计算结果
result = graph_builder.emit('Sub', [const_one, erf_result])
# 返回结果
return result

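The expansion is just erfc(x) = 1 - erf(x), with the float16 branch computing in float32 for accuracy; a scalar reference:

```python
import math

def erfc(x):
    return 1.0 - math.erf(x)

assert math.isclose(erfc(0.0), 1.0)
```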
@ -21,6 +21,16 @@ class ExpandDims(Expander):
"""ExpandDims expander"""
def _expand(self, graph_builder):
"""
对输入数据进行维度扩展
Args:
graph_builder (GraphBuilder): 图构建器对象
Returns:
Tensor: 扩展后的数据
"""
input_x = self.inputs[0]
shape = self.infer_shape(input_x.shape, self.attrs['axis'])
result = graph_builder.emit('Reshape', [input_x], attrs={'shape': shape})
@ -29,8 +39,35 @@ class ExpandDims(Expander):
@staticmethod
def infer_shape(shape, axis):
"""
根据给定的轴位置推断新的形状
Args:
shape (list or tuple): 原始形状表示一个多维数组的尺寸
axis (int, list or tuple): 指定要插入新维度的轴位置如果为整数表示在指定位置插入一个维度如果为列表或元组则按顺序在指定位置插入多个维度
Returns:
list: 插入新维度后的新形状
Raises:
ValueError: 如果axis的值或类型不符合要求时抛出
"""
"""infer shape for expand_dims"""
def insert_axis(shape, axis):
"""
在指定轴上插入一个新的维度
Args:
shape (list): 原始数组的形状类型为列表
axis (int): 要插入新维度的轴的位置
Returns:
list: 插入新维度后的数组形状
Raises:
ValueError: 如果axis的类型不是int或者axis的值不在合法范围内将抛出异常
"""
if not isinstance(axis, int) or axis > len(shape) or axis < -len(shape) - 1:
raise ValueError("For 'ExpandDims', value of attr 'axis' should be of type int and in the range [{}, "
"{}], but got {} with type {}".format(-len(shape) - 1, len(shape), axis, type(axis)))

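`infer_shape` boils down to a `list.insert` of a unit dimension per requested axis. A sketch of the intended semantics (the negative-axis handling here is assumed to follow NumPy's `expand_dims` convention):

```python
def insert_axis(shape, axis):
    # insert a new unit dimension at position `axis`
    shape = list(shape)
    if axis < 0:
        axis += len(shape) + 1
    shape.insert(axis, 1)
    return shape

assert insert_axis((2, 3), 1) == [2, 1, 3]
assert insert_axis((2, 3), -1) == [2, 3, 1]
```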
@ -21,24 +21,51 @@ class FusedAdam(Expander):
"""FusedAdam expander"""
def _expand(self, graph_builder):
"""
使用图构建器对模型参数进行更新
Args:
graph_builder (GraphBuilder): 图构建器实例用于生成计算图
Returns:
Tensor: 更新后的参数结果
"""
# 获取输入参数
beta_1, one_sub_beta_1, beta_2, one_sub_beta_2, eps, lr, param, m, v, gradient = self.inputs
# 计算beta_1乘以m
beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m])
# 计算one_sub_beta_1乘以gradient
one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient])
# 计算next_m
next_m = graph_builder.emit('Add', [beta_1_mul_m, one_sub_beta_1_mul_grad])
# 计算beta_2乘以v
beta_2_mul_v = graph_builder.emit('Mul', [beta_2, v])
# 计算gradient的平方
grad_square = graph_builder.emit('Mul', [gradient, gradient])
# 计算one_sub_beta_2乘以gradient的平方
one_sub_beta_2_mul_grad_square = graph_builder.emit('Mul', [one_sub_beta_2, grad_square])
# 计算next_v
next_v = graph_builder.emit('Add', [beta_2_mul_v, one_sub_beta_2_mul_grad_square])
# 计算next_v的平方根
sqrt_next_v = graph_builder.emit('Sqrt', [next_v])
# 计算sqrt_next_v加上eps
sqrt_next_v_add_eps = graph_builder.emit('Add', [sqrt_next_v, eps])
# 计算更新值
update = graph_builder.emit('RealDiv', [next_m, sqrt_next_v_add_eps])
# 计算更新值乘以lr
update_with_lr = graph_builder.emit('Mul', [lr, update])
# 计算next_para
next_para = graph_builder.emit('Sub', [param, update_with_lr])
# 将next_para赋值给param
param_result = graph_builder.emit(
'InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True})
# 将next_m赋值给m
param_result = graph_builder.emit('InplaceAssign', [m, next_m, param_result], attrs={'fake_output': True})
# 将next_v赋值给v
param_result = graph_builder.emit('InplaceAssign', [v, next_v, param_result], attrs={'fake_output': True})
# 返回param_result
return param_result

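The emitted graph is the standard Adam step, with the (1 - beta) terms passed in as explicit inputs rather than derived; in scalar form (a sketch):

```python
def adam_step(param, m, v, grad, beta1, beta2, eps, lr):
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad * grad
    # update = m / (sqrt(v) + eps), applied with the learning rate
    return param - lr * m / (v ** 0.5 + eps), m, v
```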
@ -21,27 +21,54 @@ class FusedAdamWeightDecay(Expander):
"""FusedAdamWeightDecay expander"""
def _expand(self, graph_builder):
"""
对输入参数进行梯度下降更新
Args:
graph_builder (GraphBuilder): 图构建器对象用于在图中添加节点
Returns:
ParaResult: 更新后的参数结果节点
"""
beta_1, one_sub_beta_1, beta_2, one_sub_beta_2, eps, lr, param, m, v, gradient, weight_decay = self.inputs
# compute result
# 计算beta_1和m的乘积
beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m])
# 计算one_sub_beta_1和gradient的乘积
one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient])
# 计算next_m
next_m = graph_builder.emit('Add', [beta_1_mul_m, one_sub_beta_1_mul_grad])
# 计算beta_2和v的乘积
beta_2_mul_v = graph_builder.emit('Mul', [beta_2, v])
# 计算gradient的平方
grad_square = graph_builder.emit('Mul', [gradient, gradient])
# 计算one_sub_beta_2和grad_square的乘积
one_sub_beta_2_mul_grad_square = graph_builder.emit('Mul', [one_sub_beta_2, grad_square])
# 计算next_v
next_v = graph_builder.emit('Add', [beta_2_mul_v, one_sub_beta_2_mul_grad_square])
# 计算sqrt_next_v
sqrt_next_v = graph_builder.emit('Sqrt', [next_v])
# 计算sqrt_next_v和eps的和
sqrt_next_v_add_eps = graph_builder.emit('Add', [sqrt_next_v, eps])
# 计算update
update = graph_builder.emit('RealDiv', [next_m, sqrt_next_v_add_eps])
# 计算param_with_weight_decay
param_with_weight_decay = graph_builder.emit('Mul', [weight_decay, param])
# 计算update和param_with_weight_decay的和
update = graph_builder.emit('Add', [update, param_with_weight_decay])
# 计算update_with_lr
update_with_lr = graph_builder.emit('Mul', [lr, update])
# 计算next_para
next_para = graph_builder.emit('Sub', [param, update_with_lr])
# 将next_para赋值给param
para_result = graph_builder.emit(
'InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True})
# 将next_m赋值给m
para_result = graph_builder.emit('InplaceAssign', [m, next_m, para_result], attrs={'fake_output': True})
# 将next_v赋值给v
para_result = graph_builder.emit('InplaceAssign', [v, next_v, para_result], attrs={'fake_output': True})
return para_result

@ -20,9 +20,23 @@ class FusedMulAdd(Expander):
"""FusedMulAdd expander"""
def _expand(self, graph_builder):
"""
执行扩展操作
Args:
graph_builder (GraphBuilder): 图构建器对象
Returns:
Tensor: 执行加法操作后的结果
"""
# 获取输入
input_x, input_y, input_z = self.inputs
# 发射乘法操作
mul_res = graph_builder.emit('Mul', [input_x, input_y])
# 发射加法操作
result = graph_builder.emit('Add', [mul_res, input_z])
# 返回结果
return result

@ -22,22 +22,47 @@ class Gather(Expander):
"""Expand Gather"""
def _expand(self, graph_builder):
"""
对输入张量进行扩展操作
Args:
graph_builder (GraphBuilder): 图构建器对象
Returns:
Tensor: 扩展后的张量
"""
# 获取输入和索引
inputs, indices = self.inputs
# 获取轴
axis = self.attrs['axis']
# 如果轴小于0则将其转换为正数
if axis < 0:
axis += len(inputs.shape)
# 如果索引的维度为1则直接进行Gather操作
if len(indices.shape) == 1:
result = graph_builder.emit('Gather', [inputs, indices], attrs={'axis': axis})
# 否则对索引进行Reshape操作然后进行Gather操作最后再进行Reshape操作
else:
# 获取原始索引的形状
ori_indices_shape = indices.shape
# 计算索引的形状的乘积
indices_shape_one_dim = 1
for dim in ori_indices_shape:
indices_shape_one_dim *= dim
# 构造新的索引形状
new_indices_shape = [indices_shape_one_dim]
# 对索引进行Reshape操作
reshape_indices = graph_builder.emit('Reshape', [indices], attrs={'shape': new_indices_shape})
# 对输入和Reshape后的索引进行Gather操作
tmp_result = graph_builder.emit('Gather', [inputs, reshape_indices], attrs={'axis': axis})
# 获取输出的形状
output_shape = list(inputs.shape)
# 将索引的形状插入到输出的形状中
output_shape[axis:axis] = ori_indices_shape
# 删除输出的形状中多余的维度
del output_shape[axis + len(ori_indices_shape)]
# 对Gather操作的结果进行Reshape操作
result = graph_builder.emit('Reshape', [tmp_result], attrs={'shape': output_shape})
# 返回结果
return result

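The reshape-gather-reshape path above matches this NumPy sketch for multi-dimensional indices:

```python
import numpy as np

def gather(x, indices, axis):
    # flatten the indices, take along the axis, then restore the index dims
    flat = np.take(x, indices.reshape(-1), axis=axis)
    out_shape = list(x.shape)
    out_shape[axis:axis + 1] = list(indices.shape)
    return flat.reshape(out_shape)
```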
@ -22,6 +22,16 @@ class GeLU(Expander):
CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564 # np.sqrt(2/np.pi)
def _expand(self, graph_builder):
"""
计算输入张量的GELU激活函数值
Args:
graph_builder (GraphBuilder): 图构建器对象用于生成计算图
Returns:
Tensor: 输入张量的GELU激活函数值
"""
# cal formula are:
# gelu of x is 0.5 * x * (1.0 + tanh(y))
# y is 'sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)'
@ -29,20 +39,33 @@ class GeLU(Expander):
input_x = self.inputs[0]
# cal y
# Square of input_x
mul_0 = graph_builder.emit('Mul', [input_x, input_x])
# Cube of input_x
pow_0 = graph_builder.emit('Mul', [mul_0, input_x])
# CSVALUE constant
const_csvalue = graph_builder.value(pow_0.dtype, self.CSVALUE)
# pow_0 * CSVALUE
mul_1 = graph_builder.emit('Mul', [pow_0, const_csvalue])
# input_x + mul_1
tanh_res = graph_builder.emit('Add', [input_x, mul_1])
# CSVALUE_SQRT_TWO_DIV_PI constant
const_csvalue_sqrt_two_div_pi = graph_builder.value(tanh_res.dtype, self.CSVALUE_SQRT_TWO_DIV_PI)
# y = tanh_res * CSVALUE_SQRT_TWO_DIV_PI
y = graph_builder.emit('Mul', [tanh_res, const_csvalue_sqrt_two_div_pi])
# cal gelu(x)
# tanh(y)
tanh_y = graph_builder.emit('Tanh', [y])
# Constant 1
const_one = graph_builder.value(tanh_y.dtype, 1)
# Constant 0.5
const_half = graph_builder.value(tanh_y.dtype, 0.5)
# tanh(y) + 1
tanh_y_add_one = graph_builder.emit('Add', [tanh_y, const_one])
# input_x * (tanh(y) + 1)
mul_x = graph_builder.emit('Mul', [input_x, tanh_y_add_one])
# 0.5 * input_x * (tanh(y) + 1)
result = graph_builder.emit('Mul', [const_half, mul_x])
return result

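The constants above encode the familiar tanh approximation of GELU; as a scalar reference:

```python
import math

def gelu(x):
    # y = sqrt(2 / pi) * (x + 0.044715 * x^3)
    y = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(y))
```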
@ -19,11 +19,29 @@ from ._utils import Expander, ExpanderInfoValidator as VLD
@VLD.check_all_formats_same
class GeLUGrad(Expander):
"""GeLUGrad expander"""
CSVALUE = 0.044715
CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564 # np.sqrt(2/np.pi)
CSVALUE_TRI = 0.134141 # CSVALUE * 3
def _expand(self, graph_builder):
"""
计算GELU函数的梯度
Args:
graph_builder (GraphBuilder): 图构建器对象用于生成计算图
Returns:
Tensor: GELU函数的梯度
计算公式如下
GELU的梯度dy和x是dy * y'
y' = 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(para)) * mul_right
tanh_para = sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)
mul_right = sqrt(2.0 / pi) * (1 + 3 * 0.044715 * x * x)
"""
# cal formula are:
# gelu_grad of dy and x is dy * y'
# y' is 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(tanh_para)) * mul_right
@ -33,21 +51,33 @@ class GeLUGrad(Expander):
input_dy, input_x, _ = self.inputs
# create some const var
# Constant self.CSVALUE in input_dy.dtype
const_csvalue = graph_builder.value(input_dy.dtype, self.CSVALUE)
# Constant self.CSVALUE_SQRT_TWO_DIV_PI in input_dy.dtype
const_csvalue_sqrt_two_div_pi = graph_builder.value(input_dy.dtype, self.CSVALUE_SQRT_TWO_DIV_PI)
# Constant self.CSVALUE_TRI in input_dy.dtype
const_csvalue_tri = graph_builder.value(input_dy.dtype, self.CSVALUE_TRI)
# Constant 1 in input_dy.dtype
const_one = graph_builder.value(input_dy.dtype, 1)
# Constant 0.5 in input_dy.dtype
const_half = graph_builder.value(input_dy.dtype, 0.5)
# cal mul_right
# x squared
mul_double = graph_builder.emit('Mul', [input_x, input_x])
# 3 * CSVALUE * x^2
mul_double_mul_tri = graph_builder.emit('Mul', [const_csvalue_tri, mul_double])
# 1 + 3 * CSVALUE * x^2
mul_add_one = graph_builder.emit('Add', [const_one, mul_double_mul_tri])
# sqrt(2 / pi) * (1 + 3 * CSVALUE * x^2)
mul_right = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_one])
# cal tanh_para
# x cubed
mul_triple = graph_builder.emit('Mul', [input_x, mul_double])
# CSVALUE * x^3
mul_triple_mul_csvalue = graph_builder.emit('Mul', [const_csvalue, mul_triple])
# x + CSVALUE * x^3
mul_add_x = graph_builder.emit('Add', [input_x, mul_triple_mul_csvalue])
tanh_para = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_x])

@ -22,12 +22,27 @@ class GkDropout(Expander):
"""GkDropout expander"""
def _expand(self, graph_builder):
"""
对输入数据进行dropout操作
Args:
graph_builder (GraphBuilder): 图构建器对象
Returns:
tuple: 包含两个元素第一个是执行dropout操作后的结果第二个是生成的掩码
"""
# 获取输入数据和掩码
input_x, input_mask = self.inputs
# 获取保持概率
keep_prob = self.attrs['keep_prob']
# 计算保持概率的倒数
r_keep_prob = graph_builder.value(input_x.dtype, 1.0 / keep_prob)
# 计算保持概率
keep_prob = graph_builder.value(input_x.dtype, keep_prob)
# 如果掩码的数据类型与输入数据类型不同,则进行类型转换
if input_mask.dtype != input_x.dtype:
input_mask = graph_builder.emit('Cast', [input_mask], attrs={'dst_type': input_x.dtype})
mask = graph_builder.emit('LessEqual', [input_mask, keep_prob]) # output is bool type
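# A minimal NumPy sketch (illustrative) of the inverted-dropout scheme this
# expander builds: the mask keeps elements where a uniform sample is <= keep_prob,
# and survivors are rescaled by 1/keep_prob so the expectation is unchanged.
import numpy as np

def gk_dropout(x, uniform_mask, keep_prob):
    mask = (uniform_mask <= keep_prob).astype(x.dtype)
    return x * mask * (1.0 / keep_prob), mask

rng = np.random.default_rng(0)
x = np.ones(10000, dtype=np.float32)
u = rng.random(10000, dtype=np.float32)
out, mask = gk_dropout(x, u, keep_prob=0.8)
print(round(float(mask.mean()), 2), round(float(out.mean()), 2))  # ~0.8, ~1.0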

@ -20,6 +20,16 @@ class Identity(Expander):
"""Identity expander"""
def _expand(self, graph_builder):
"""
Reshape the input tensor to its own shape (identity operation).
Args:
graph_builder (GraphBuilder): graph builder used to build the computation graph.
Returns:
Tensor: the reshaped input.
"""
input_x = self.inputs[0]
result = graph_builder.emit('Reshape', [input_x], attrs={'shape': input_x.shape})
return result

@ -25,67 +25,107 @@ class LayerNorm(Expander):
"""LayerNorm expander"""
def _expand(self, graph_builder):
"""
对输入进行扩展处理包括批量归一化操作
Args:
graph_builder (GraphBuilder): 图构建器对象用于构建计算图
Returns:
tuple: 包含三个元素的元组分别是处理后的输入均值和方差
- res (Tensor): 处理后的输入张量
- mean (Tensor): 输入的均值张量
- variance (Tensor): 输入的方差张量
"""
# 获取输入数据
input_x, input_gamma, input_beta = self.inputs
# 获取处理器类型
processor = self.processor
# 获取归一化开始轴
begin_norm_axis = self.attrs['begin_norm_axis']
# 获取epsilon值
epsilon = self.attrs['epsilon']
# 获取输入数据的原始数据类型
ori_dtype = input_x.dtype
# 如果处理器类型为aicore且输入数据类型为float16则将输入数据类型转换为float32
if processor == 'aicore' and ori_dtype == 'float16':
input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'})
input_gamma = graph_builder.emit('Cast', [input_gamma], attrs={'dst_type': 'float32'})
input_beta = graph_builder.emit('Cast', [input_beta], attrs={'dst_type': 'float32'})
# 获取输入数据的原始形状
ori_shape_x = input_x.shape
# 如果输入数据的格式为FRAC_NZ则根据FRAC_NZ格式获取输入数据的形状
if input_x.data_format == DF.FRAC_NZ:
ori_shape_x = infer_shape_from_fractalnz(input_x.shape)
# Calculate the scaling ratio of the average
# 如果begin_norm_axis小于0则将其加上ori_shape_x的长度
if begin_norm_axis < 0:
begin_norm_axis += len(ori_shape_x)
# 定义reduce_axis用于存储需要归一化的维度
reduce_axis = ()
# 遍历ori_shape_x如果维度大于begin_norm_axis或者等于begin_norm_axis则将其加入reduce_axis
for i, _ in enumerate(ori_shape_x):
if i > begin_norm_axis or i == begin_norm_axis:
reduce_axis = reduce_axis + (i,)
# 计算reduce_elts即需要归一化的维度上的元素个数
reduce_elts = 1.0
for i in reduce_axis:
reduce_elts *= ori_shape_x[i]
# after reduced
# 获取归一化后的ori_shape_x
ori_reduced_shape_x = get_reduced_ori_shape(ori_shape_x, reduce_axis)
# 定义axis用于存储归一化的维度
axis = reduce_axis
# 如果input_x的数据格式为DF.FRAC_NZ则将axis转换为frac_z轴
if input_x.data_format == DF.FRAC_NZ:
axis = to_frac_z_axis(ori_shape_x, reduce_axis)
# 计算mean_cof_v即归一化系数
mean_cof_v = graph_builder.value(input_x.dtype, 1.0 / reduce_elts)
# Calculate mean
# 计算输入张量的均值
mean_red = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True})
# 将均值乘以系数
mean = graph_builder.emit('Mul', [mean_red, mean_cof_v])
# 如果输入张量的数据格式为DF.FRAC_NZ则对均值进行重整
if input_x.data_format == DF.FRAC_NZ:
mean = graph_builder.emit('Reshape', [mean], attrs={'shape': ori_reduced_shape_x})
# Calculate variance
variance_sub = graph_builder.emit('Sub', [input_x, mean])
variance_mul = graph_builder.emit('Mul', [variance_sub, variance_sub])
variance_red = graph_builder.emit('ReduceSum', [variance_mul], attrs={'reduce_axis': axis, 'keep_dims': True})
variance = graph_builder.emit('Mul', [variance_red, mean_cof_v])
# 计算方差
variance_sub = graph_builder.emit('Sub', [input_x, mean]) # 计算输入与均值的差值
variance_mul = graph_builder.emit('Mul', [variance_sub, variance_sub]) # 计算差值的平方
variance_red = graph_builder.emit('ReduceSum', [variance_mul], attrs={'reduce_axis': axis, 'keep_dims': True}) # 对差值的平方求和
variance = graph_builder.emit('Mul', [variance_red, mean_cof_v]) # 计算方差
# 如果输入数据的格式为DF.FRAC_NZ则对方差进行reshape操作
if input_x.data_format == DF.FRAC_NZ:
variance = graph_builder.emit('Reshape', [variance], attrs={'shape': ori_reduced_shape_x})
# Calculate normalize
# 计算输入x与均值之间的差值
normalize_sub = graph_builder.emit('Sub', [input_x, mean])
# 创建一个epsilon值用于防止除零错误
epsilon_v = graph_builder.value(input_x.dtype, epsilon)
# 计算方差加上epsilon的值
normalize_add = graph_builder.emit('Add', [variance, epsilon_v])
normlize_rsqrt = graph_builder.emit('Rsqrt', [normalize_add])
normalize_mul = graph_builder.emit('Mul', [normalize_sub, normlize_rsqrt])
# Calculate scale and translate
# 计算归一化后的乘积
scale_mul = graph_builder.emit('Mul', [normalize_mul, input_gamma])
# 计算最终结果
res = graph_builder.emit('Add', [scale_mul, input_beta])
# 如果处理器为aicore且原始数据类型为float16则将结果、均值和方差转换为float16
if processor == 'aicore' and ori_dtype == 'float16':
res = graph_builder.emit('Cast', [res], attrs={'dst_type': 'float16'})
mean = graph_builder.emit('Cast', [mean], attrs={'dst_type': 'float16'})
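# A NumPy reference (illustrative, eager rather than graph-based) of the math
# emitted above: mean/variance over the trailing axes, then scale and shift.
import numpy as np

def layer_norm_ref(x, gamma, beta, begin_norm_axis=-1, eps=1e-12):
    if begin_norm_axis < 0:
        begin_norm_axis += x.ndim
    axis = tuple(range(begin_norm_axis, x.ndim))
    mean = x.mean(axis=axis, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=axis, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * gamma + beta, mean, var

x = np.random.default_rng(0).standard_normal((2, 3, 4)).astype(np.float32)
res, mean, var = layer_norm_ref(x, np.ones(4, np.float32), np.zeros(4, np.float32))
print(np.abs(res.mean(axis=-1)).max())  # ~0: each normalized row has zero mean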

@ -23,13 +23,33 @@ class LayerNormGrad(Expander):
"""LayerNormGrad expander"""
def _expand(self, graph_builder):
"""
对输入进行扩展操作
Args:
graph_builder (GraphBuilder): 图构建器对象
Returns:
tuple: 包含dx, dg, db的元组
dx (Tensor): 梯度相对于输入x的导数
dg (Tensor): 梯度相对于gamma的导数
db (Tensor): 梯度相对于beta的导数
"""
# 获取输入参数
x, dy, variance, mean, gamma = self.inputs
# 获取处理器类型
processor = self.processor
# 获取归一化轴的起始位置
begin_norm_axis = self.attrs['begin_norm_axis']
# 获取参数轴的起始位置
begin_params_axis = self.attrs['begin_params_axis']
# 获取epsilon值默认为1e-12
epsilon = self.attrs['epsilon'] if 'epsilon' in self.attrs else 1e-12
# 获取输入数据的原始数据类型
ori_dtype = x.dtype
# 如果处理器类型为aicore且数据类型为float16则将输入数据转换为float32
if processor == 'aicore' and ori_dtype == 'float16':
x = graph_builder.emit('Cast', [x], attrs={'dst_type': 'float32'})
dy = graph_builder.emit('Cast', [dy], attrs={'dst_type': 'float32'})
@ -37,77 +57,121 @@ class LayerNormGrad(Expander):
mean = graph_builder.emit('Cast', [mean], attrs={'dst_type': 'float32'})
gamma = graph_builder.emit('Cast', [gamma], attrs={'dst_type': 'float32'})
# normalize negative axes
if begin_norm_axis < 0:
begin_norm_axis += len(x.shape)
if begin_params_axis < 0:
begin_params_axis += len(x.shape)
# ranges of the normalized axes and the parameter axes
norm_axis = tuple(range(begin_norm_axis, len(x.shape)))
param_axis = tuple(range(0, begin_params_axis))
# product of the normalized dimensions
reduce_size = 1.0
for i in norm_axis:
reduce_size *= x.shape[i]
# set some constant val.
eps = graph_builder.value(x.dtype, epsilon)
const_neg_half = graph_builder.value(x.dtype, -0.5)
const_neg_two = graph_builder.value(x.dtype, -2.0)
const_two = graph_builder.value(x.dtype, 2.0)
const_neg_one = graph_builder.value(x.dtype, -1.0)
mean_cof = graph_builder.value(x.dtype, (1.0 / reduce_size))
# cal dg db
# rsqrt(variance + eps), built as exp(-0.5 * log(variance + eps))
var_eps = graph_builder.emit('Add', [variance, eps])
var_eps_log = graph_builder.emit('Log', [var_eps])
var_eps_mul = graph_builder.emit('Mul', [var_eps_log, const_neg_half])
rsqrt_var_eps = graph_builder.emit('Exp', [var_eps_mul])
# x - mean
x_sub_mean = graph_builder.emit('Sub', [x, mean])
# (x - mean) * rsqrt(variance + eps)
x_sub_mean_mul_rsqrt_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, x_sub_mean])
dg_mul = graph_builder.emit('Mul', [dy, x_sub_mean_mul_rsqrt_var_eps])
# dg: reduce over the parameter axes
dg = graph_builder.emit('ReduceSum', [dg_mul], attrs={'reduce_axis': param_axis, 'keep_dims': False})
# db: reduce dy over the parameter axes
db = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': param_axis, 'keep_dims': False})
# pd_var
tmp_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, rsqrt_var_eps])
r_tmp_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, tmp_var_eps])
dy_mul_gamma = graph_builder.emit('Mul', [dy, gamma])
tmp_mul = graph_builder.emit('Mul', [dy_mul_gamma, x_sub_mean])
padvar_mul1 = graph_builder.emit('ReduceSum', [tmp_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True})
padvar_mul3 = graph_builder.emit('Mul', [padvar_mul1, r_tmp_var_eps])
pd_var = graph_builder.emit('Mul', [padvar_mul3, const_neg_half])
# pd_mean
pdmean1_sum = graph_builder.emit('ReduceSum', [dy_mul_gamma],
attrs={'reduce_axis': norm_axis, 'keep_dims': True})
neg_rsqrt_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, const_neg_one])
pd_mean_1 = graph_builder.emit('Mul', [neg_rsqrt_var_eps, pdmean1_sum])
pdmean2_mul1 = graph_builder.emit('Mul', [const_neg_two, x_sub_mean])
pdmean2_sum = graph_builder.emit('ReduceSum', [pdmean2_mul1],
attrs={'reduce_axis': norm_axis, 'keep_dims': True})
pdmean2_mul3 = graph_builder.emit('Mul', [pdmean2_sum, mean_cof])
pd_mean_2 = graph_builder.emit('Mul', [pdmean2_mul3, pd_var])
pd_mean = graph_builder.emit('Add', [pd_mean_1, pd_mean_2])
# cal dx
pd_x_1 = graph_builder.emit('Mul', [dy_mul_gamma, rsqrt_var_eps])
pdx2_mul = graph_builder.emit('Mul', [pd_var, x_sub_mean])
pdx2_mul_two = graph_builder.emit('Mul', [pdx2_mul, const_two])
pd_x_2 = graph_builder.emit('Mul', [pdx2_mul_two, mean_cof])
pd_x_3 = graph_builder.emit('Mul', [pd_mean, mean_cof])
dx_tmp = graph_builder.emit('Add', [pd_x_1, pd_x_2])
dx = graph_builder.emit('Add', [dx_tmp, pd_x_3])
# cast back to float16 when the computation was promoted to float32
if processor == 'aicore' and ori_dtype == 'float16':
dx = graph_builder.emit('Cast', [dx], attrs={'dst_type': 'float16'})
dg = graph_builder.emit('Cast', [dg], attrs={'dst_type': 'float16'})
db = graph_builder.emit('Cast', [db], attrs={'dst_type': 'float16'})
return dx, dg, db
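# A NumPy transcription (illustrative) of the pd_var / pd_mean decomposition
# above, validated against a finite-difference gradient on a tiny input.
import numpy as np

def ln_forward(x, gamma, eps=1e-12):
    mean = x.mean(-1, keepdims=True)
    var = ((x - mean) ** 2).mean(-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * gamma

def ln_grad_x(x, dy, gamma, eps=1e-12):
    n = x.shape[-1]
    mean = x.mean(-1, keepdims=True)
    var = ((x - mean) ** 2).mean(-1, keepdims=True)
    rsq = 1.0 / np.sqrt(var + eps)
    dyg = dy * gamma
    pd_var = -0.5 * (dyg * (x - mean)).sum(-1, keepdims=True) * rsq ** 3
    pd_mean = (-rsq * dyg.sum(-1, keepdims=True)
               + pd_var * (-2.0 * (x - mean)).sum(-1, keepdims=True) / n)
    return dyg * rsq + pd_var * 2.0 * (x - mean) / n + pd_mean / n

rng = np.random.default_rng(1)
x, dy, gamma = rng.standard_normal((2, 5)), rng.standard_normal((2, 5)), rng.standard_normal(5)
dx = ln_grad_x(x, dy, gamma)
step, num = 1e-6, np.zeros_like(x)
for idx in np.ndindex(*x.shape):
    xp, xm = x.copy(), x.copy()
    xp[idx] += step
    xm[idx] -= step
    num[idx] = ((dy * ln_forward(xp, gamma)).sum() - (dy * ln_forward(xm, gamma)).sum()) / (2 * step)
print(np.max(np.abs(dx - num)))  # ~1e-8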

@ -23,24 +23,49 @@ class LogSoftmax(Expander):
"""LogSoftmax expander"""
def _expand(self, graph_builder):
"""
对输入数据进行Softmax归一化
Args:
graph_builder (GraphBuilder): 图构建器对象用于构建计算图
Returns:
Tensor: Softmax归一化后的结果
"""
# 获取输入数据
input_x = self.inputs[0]
# 获取轴参数
axis = self.attrs['axis']
# 获取处理器类型
processor = self.processor
# 如果轴参数是整数,则将其转换为元组
if isinstance(axis, int):
axis = (axis,)
# 获取输入数据的原始数据类型
ori_dtype = input_x.dtype
# 如果原始数据类型不是float16且处理器类型是aicore则将输入数据转换为float16
if ori_dtype != "float16" and processor == "aicore":
input_x_f16 = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float16'})
# 对转换后的数据进行ReduceMax操作
max_x_f16 = graph_builder.emit('ReduceMax', [input_x_f16], attrs={'reduce_axis': axis, 'keep_dims': True})
# 将ReduceMax操作的结果转换回原始数据类型
max_x = graph_builder.emit('Cast', [max_x_f16], attrs={'dst_type': ori_dtype})
else:
# 对输入数据进行ReduceMax操作
max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True})
# 计算输入数据与ReduceMax操作结果的差值
data_sub = graph_builder.emit('Sub', [input_x, max_x])
# 计算差值的指数
data_exp = graph_builder.emit('Exp', [data_sub])
# 对指数结果进行ReduceSum操作
data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': axis, 'keep_dims': True})
# 计算ReduceSum结果的log
log_expsum = graph_builder.emit('Log', [data_expsum])
# 计算差值与log的差值
result = graph_builder.emit('Sub', [data_sub, log_expsum])
# 返回结果
return result
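# A NumPy sketch (illustrative) of why the max is subtracted first: the shifted
# form stays finite even when exp(x) itself would overflow.
import numpy as np

def log_softmax_ref(x, axis=-1):
    x_shift = x - x.max(axis=axis, keepdims=True)
    return x_shift - np.log(np.exp(x_shift).sum(axis=axis, keepdims=True))

x = np.array([[1000.0, 1001.0, 1002.0]])  # naive exp(x) overflows to inf here
print(log_softmax_ref(x))  # finite: [[-2.4076, -1.4076, -0.4076]]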

@ -23,14 +23,32 @@ class LogSoftmaxGrad(Expander):
"""LogSoftmaxGrad expander"""
def _expand(self, graph_builder):
"""
对输入进行扩展操作
Args:
graph_builder (GraphBuilder): 图构建器对象
Returns:
Tensor: 扩展操作的结果
"""
# 获取输入的logits和dy
input_logits, input_dy = self.inputs
# 获取axis参数
axis = self.attrs['axis']
# 如果axis是整数则将其转换为元组
if isinstance(axis, int):
axis = (axis,)
# 计算softmax
softmax = graph_builder.emit('Exp', [input_logits])
# 计算dy的sum
dy_sum = graph_builder.emit('ReduceSum', [input_dy], attrs={'reduce_axis': axis, 'keep_dims': True})
# 计算softmax和dy_sum的乘积
mul_result = graph_builder.emit('Mul', [softmax, dy_sum])
# 计算input_dy和mul_result的差
result = graph_builder.emit('Sub', [input_dy, mul_result])
# 返回结果
return result
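# A small NumPy check (illustrative) of the identity used above: with logp the
# forward log-softmax output, dx = dy - exp(logp) * sum(dy).
import numpy as np

def log_softmax(x, axis=-1):
    xs = x - x.max(axis=axis, keepdims=True)
    return xs - np.log(np.exp(xs).sum(axis=axis, keepdims=True))

rng = np.random.default_rng(0)
x, dy = rng.standard_normal((2, 4)), rng.standard_normal((2, 4))
dx = dy - np.exp(log_softmax(x)) * dy.sum(-1, keepdims=True)
step, num = 1e-6, np.zeros_like(x)
for idx in np.ndindex(*x.shape):
    xp, xm = x.copy(), x.copy()
    xp[idx] += step
    xm[idx] -= step
    num[idx] = ((dy * log_softmax(xp)).sum() - (dy * log_softmax(xm)).sum()) / (2 * step)
print(np.max(np.abs(dx - num)))  # ~1e-9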

@ -25,48 +25,139 @@ class MatMul(Expander):
"""
def __init__(self, expand_info):
"""
初始化MatMul类实例
Args:
expand_info (dict): 扩展信息字典包含操作所需的额外信息
Attributes:
transpose_a (bool): 是否对矩阵A进行转置
transpose_b (bool): 是否对矩阵B进行转置
left_format (str): 矩阵A的数据格式
right_format (str): 矩阵B的数据格式
shape_a (tuple): 矩阵A的形状
shape_b (tuple): 矩阵B的形状
"""
# 调用父类的初始化方法
super(MatMul, self).__init__(expand_info)
# 获取transpose_a属性
self.transpose_a = self.attrs['transpose_a']
# 获取transpose_b属性
self.transpose_b = self.attrs['transpose_b']
# 获取left_format属性
self.left_format = self.attrs['left_format']
# 获取right_format属性
self.right_format = self.attrs['right_format']
# 获取输入A的shape
self.shape_a = self.inputs[0]['shape']
# 获取输入B的shape
self.shape_b = self.inputs[1]['shape']
def _optimize_to_mul(self):
"""
检查是否可以用乘法mul替换矩阵乘法matmul
Args:
Returns:
bool: 如果可以用乘法替换矩阵乘法返回True否则返回False
"""
"""check if matmul can be replace by mul"""
# 如果处理器不是'aicore'或者左格式或右格式不是默认格式则返回False
if self.processor != 'aicore' or self.left_format != DF.DEFAULT or self.right_format != DF.DEFAULT:
return False
# 如果transpose_a为True则k_a为shape_a的倒数第二个维度否则为shape_a的倒数第一个维度
k_a = self.shape_a[-2] if self.transpose_a else self.shape_a[-1]
# 如果transpose_b为True则k_b为shape_b的倒数第一个维度否则为shape_b的倒数第二个维度
k_b = self.shape_b[-1] if self.transpose_b else self.shape_b[-2]
# 如果k_a或k_b不等于1则返回False
if k_a != 1 or k_b != 1:
return False
# 否则返回True
return True
def _check(self):
"""
检查输入个数是否满足矩阵乘法的要求
Args:
Returns:
Raises:
GKException: 如果输入的个数小于2则抛出GKException异常提示信息为 "For 'MatMul', inputs number should bigger than 1, but got {}."其中{}为输入的个数
"""
# 获取输入的个数
input_num = len(self.inputs)
# 如果输入的个数小于2抛出异常
if input_num < 2:
raise GKException("For 'MatMul', inputs number should bigger than 1, but got {}.".format(input_num))
def _expand(self, graph_builder):
"""
将MatMul或BatchMatMul操作替换为Mul操作
Args:
graph_builder (GraphBuilder): 图构建器对象
Returns:
Node: Mul操作的结果节点
Raises:
GKException: 如果不需要将MatMul/BatchMatMul替换为Mul操作则引发异常
"""
# 定义一个函数用于转置shape
def transpose(shape):
"""
将给定的shape进行转置操作
Args:
shape (tuple): 输入的shape为一个元组表示多维数组的形状
Returns:
list: 转置后的shape以列表形式返回
"""
# 将shape转换为列表
trans_shape = list(shape)
# 将shape的倒数第二个元素和倒数第一个元素交换位置
trans_shape[-2] = shape[-1]
trans_shape[-1] = shape[-2]
# 返回转置后的shape
return trans_shape
# 如果不需要优化为乘法,则抛出异常
if not self._optimize_to_mul():
raise GKException("MatMul/BatchMatMul do not need to be replaced by Mul")
# Matmul is replaced by Mul([b m k], [b k n]) when k==1
# 获取输入a
input_a = self.inputs[0]
# 获取输入b
input_b = self.inputs[1]
# 如果transpose_a为True则对输入a进行转置
if self.transpose_a:
# 获取输入a的转置形状
shape_a_trans = transpose(self.shape_a)
# 对输入a进行转置
input_a = graph_builder.emit('Reshape', [input_a], attrs={'shape': shape_a_trans})
# 如果transpose_b为True则对输入b进行转置
if self.transpose_b:
# 获取输入b的转置形状
shape_b_trans = transpose(self.shape_b)
# 对输入b进行转置
input_b = graph_builder.emit('Reshape', [input_b], attrs={'shape': shape_b_trans})
# 对输入a和输入b进行乘法运算
result = graph_builder.emit('Mul', [input_a, input_b])
# 如果dst_type在attrs中并且输入a的数据类型与dst_type不同则对结果进行类型转换
if 'dst_type' in self.attrs and self.inputs[0].dtype != self.attrs['dst_type']:
# 对结果进行类型转换
result = graph_builder.emit('Cast', [result], attrs={'dst_type': self.attrs['dst_type']})
return result
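# A NumPy confirmation (illustrative) of the rewrite: with k == 1, batched
# matmul of [b, m, 1] by [b, 1, n] is exactly a broadcasting multiply.
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((8, 5, 1))  # [b, m, k], k == 1
b = rng.standard_normal((8, 1, 7))  # [b, k, n], k == 1
print(np.allclose(a @ b, a * b))  # True: same products, no summation needed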

@ -23,35 +23,76 @@ class MaximumGrad(Expander):
"""MaximumGrad expander"""
def _check(self):
"""
检查MaximumGrad的属性是否符合要求
Args:
Returns:
返回父类的检查结果
Raises:
GKException: 'grad_x' 'grad_y' 的值都为 False 时抛出异常
"""
# 如果attr 'grad_x'和'grad_y'的值都为False则抛出异常
if not self.attrs.get('grad_x', True) and not self.attrs.get('grad_y', True):
raise GKException("For 'MaximumGrad', value of attr 'grad_x' and 'grad_y' should be False, but got {} and "
"{}".format(self.attrs.get('grad_x'), self.attrs.get('grad_y')))
# 调用父类的方法
return super()._check()
def _expand(self, graph_builder):
"""
根据输入计算梯度并返回两个梯度结果
Args:
graph_builder (GraphBuilder): 图构建器对象用于构建计算图
Returns:
tuple: 包含两个梯度结果的元组第一个元素为对输入x的梯度第二个元素为对输入y的梯度
"""
# 获取输入的x、y和dout
input_x, input_y, input_dout = self.inputs
# 比较x和y的大小返回一个布尔值
ge_result = graph_builder.emit('GreaterEqual', [input_x, input_y])
# 将布尔值转换为与x相同的类型
ge_result = graph_builder.emit('Cast', [ge_result], attrs={'dst_type': input_x.dtype})
# 计算dx即x的梯度
dx = graph_builder.emit('Mul', [ge_result, input_dout])
# 计算dy即y的梯度
dy = graph_builder.emit('Sub', [input_dout, dx])
# 获取dx和dy的reduce轴
reduce_axis_x = MinimumGrad.get_reduce_axis(input_x.shape, dx.shape)
reduce_axis_y = MinimumGrad.get_reduce_axis(input_y.shape, dy.shape)
# 如果dx有reduce轴
if reduce_axis_x:
# 对dx进行求和
dx_reduce = graph_builder.emit('ReduceSum', [dx], attrs={'reduce_axis': reduce_axis_x, 'keep_dims': False})
# 如果dx_reduce的形状与input_x的形状不同则进行reshape
if dx_reduce.shape != input_x.shape:
dx_out = graph_builder.emit('Reshape', [dx_reduce], attrs={'shape': input_x.shape})
# 否则dx_out等于dx_reduce
else:
dx_out = dx_reduce
# 否则dx_out等于dx
else:
dx_out = dx
# 如果dy有reduce轴
if reduce_axis_y:
# 对dy进行求和
dy_reduce = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': reduce_axis_y, 'keep_dims': False})
# 如果dy_reduce的形状与input_y的形状不同则进行reshape
if dy_reduce.shape != input_y.shape:
dy_out = graph_builder.emit('Reshape', [dy_reduce], attrs={'shape': input_y.shape})
# 否则dy_out等于dy_reduce
else:
dy_out = dy_reduce
# 否则dy_out等于dy
else:
dy_out = dy

@ -22,59 +22,117 @@ class MinimumGrad(Expander):
"""MinimumGrad expander"""
def _check(self):
"""
检查MinimumGrad类的属性是否满足要求
Args:
Returns:
bool: 如果属性符合要求则返回True否则抛出GKException异常
Raises:
GKException: 如果MinimumGrad类的属性'grad_x''grad_y'均为False则抛出此异常
"""
# 如果attr 'grad_x'和'grad_y'的值都为False则抛出异常
if not self.attrs.get('grad_x', True) and not self.attrs.get('grad_y', True):
raise GKException("For 'MinimumGrad', value of attr 'grad_x' and 'grad_y' should be False, but got {} and "
"{}".format(self.attrs.get('grad_x'), self.attrs.get('grad_y')))
# 调用父类的方法
return super(MinimumGrad, self)._check()
def _expand(self, graph_builder):
"""
计算两个输入的梯度
Args:
graph_builder (GraphBuilder): 图构建器对象用于在图中执行操作
Returns:
tuple: 包含两个梯度结果的元组
"""
# 输入参数
input_x, input_y, input_dout = self.inputs
le_result = graph_builder.emit('LessEqual', [input_x, input_y])
le_result = graph_builder.emit('Cast', [le_result], attrs={'dst_type': input_x.dtype})
dx = graph_builder.emit('Mul', [le_result, input_dout])
dy = graph_builder.emit('Sub', [input_dout, dx])
# 执行 LessEqual 操作
le_result = graph_builder.emit('LessEqual', [input_x, input_y]) # 执行 LessEqual 操作
# 将结果转换为与 input_x 相同的数据类型
le_result = graph_builder.emit('Cast', [le_result], attrs={'dst_type': input_x.dtype}) # 将结果转换为与 input_x 相同的数据类型
# 执行 Mul 操作,将 le_result 和 input_dout 相乘
dx = graph_builder.emit('Mul', [le_result, input_dout]) # 执行 Mul 操作,将 le_result 和 input_dout 相乘
# 执行 Sub 操作,用 input_dout 减去 dx
dy = graph_builder.emit('Sub', [input_dout, dx]) # 执行 Sub 操作,用 input_dout 减去 dx
# 对于 minimumgrad 操作,输出形状应与输入形状相同,
# 但某些元素级操作可能会广播输入形状,
# 导致输出形状不等于原始输入形状,因此需要减少输出来使它们相等
# for minimumgrad op, output_shape should be equal to input_shape,
# but some elementwise operating may broadcast input_shape
# then output_shape not equal to original input_shape, so need to reduce output to let them equal
reduce_axis_x = self.get_reduce_axis(input_x.shape, dx.shape)
reduce_axis_y = self.get_reduce_axis(input_y.shape, dy.shape)
reduce_axis_x = self.get_reduce_axis(input_x.shape, dx.shape) # 获取 x 的减少轴
reduce_axis_y = self.get_reduce_axis(input_y.shape, dy.shape) # 获取 y 的减少轴
if reduce_axis_x:
dx_reduce = graph_builder.emit('ReduceSum', [dx], attrs={'reduce_axis': reduce_axis_x, 'keep_dims': False})
# 如果存在减少轴,执行 ReduceSum 操作
dx_reduce = graph_builder.emit('ReduceSum', [dx], attrs={'reduce_axis': reduce_axis_x, 'keep_dims': False}) # 执行 ReduceSum 操作
if dx_reduce.shape != input_x.shape:
dx_out = graph_builder.emit('Reshape', [dx_reduce], attrs={'shape': input_x.shape})
# 如果减少后的形状不等于原始输入形状,则进行 Reshape 操作
dx_out = graph_builder.emit('Reshape', [dx_reduce], attrs={'shape': input_x.shape}) # 如果减少后的形状不等于原始输入形状,则进行 Reshape 操作
else:
dx_out = dx_reduce
dx_out = dx_reduce # 否则直接使用减少后的结果
else:
dx_out = dx
dx_out = dx # 如果没有减少轴,则直接使用 dx
if reduce_axis_y:
dy_reduce = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': reduce_axis_y, 'keep_dims': False})
# 如果存在减少轴,执行 ReduceSum 操作
dy_reduce = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': reduce_axis_y, 'keep_dims': False}) # 执行 ReduceSum 操作
if dy_reduce.shape != input_y.shape:
dy_out = graph_builder.emit('Reshape', [dy_reduce], attrs={'shape': input_y.shape})
# 如果减少后的形状不等于原始输入形状,则进行 Reshape 操作
dy_out = graph_builder.emit('Reshape', [dy_reduce], attrs={'shape': input_y.shape}) # 如果减少后的形状不等于原始输入形状,则进行 Reshape 操作
else:
dy_out = dy_reduce
dy_out = dy_reduce # 否则直接使用减少后的结果
else:
dy_out = dy
dy_out = dy # 如果没有减少轴,则直接使用 dy
# output two results, regardless of grad_x and grad_y
# 输出两个结果,
return dx_out, dy_out
@staticmethod
def get_reduce_axis(original_shape, broadcast_shape):
"""
计算最终输出形状的归约轴
Args:
original_shape (tuple of int): 原始形状一个包含整数的元组
broadcast_shape (tuple of int): 广播形状一个包含整数的元组
Returns:
list of int: 归约轴列表表示在最终输出形状中需要归约的轴索引
Raises:
ValueError: 如果original_shape的长度大于broadcast_shape的长度或者original_shape和broadcast_shape无法广播
"""
"""compute reduce axis for final output_shape"""
# 如果original_shape的长度大于broadcast_shape的长度
if len(original_shape) > len(broadcast_shape):
raise ValueError("For 'MinimumGrad', the length of original_shape should be less than or equal to the "
"length of broadcast_shape, but got {} and {}".format(original_shape, broadcast_shape))
# 创建一个tmp_shape列表长度为broadcast_shape的长度前面填充1后面填充original_shape的元素
tmp_shape = [1] * (len(broadcast_shape) - len(original_shape)) + original_shape
reduce_axis = []
# 遍历tmp_shape中的每个元素
for idx, _ in enumerate(tmp_shape):
# 如果tmp_shape中的元素与broadcast_shape中的对应元素不相等
if tmp_shape[idx] != broadcast_shape[idx]:
# 如果tmp_shape中的元素为1
if tmp_shape[idx] == 1:
# 将当前索引添加到reduce_axis列表中
reduce_axis.append(idx)
else:
# 抛出异常表示original_shape和broadcast_shape无法广播
raise ValueError("For 'MinimumGrad', original_shape {} and broadcast_shape {} can not broadcast."
.format(original_shape, broadcast_shape))
return reduce_axis
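# Hypothetical calls (assuming the class is in scope) showing how gradients
# widened by broadcasting are mapped back: padded or size-1 axes get reduced.
print(MinimumGrad.get_reduce_axis([3, 1, 5], [2, 3, 4, 5]))  # [0, 2]
print(MinimumGrad.get_reduce_axis([4, 5], [4, 5]))           # []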

@ -20,7 +20,24 @@ class OnesLike(Expander):
"""OnesLike expander"""
def _expand(self, graph_builder):
"""
将输入张量扩展至指定形状
Args:
graph_builder: 图构建器对象用于构建图结构
Returns:
扩展后的张量
"""
# 获取输入张量
input_x = self.inputs[0]
# 创建一个值为1的常量数据类型与输入张量相同
const_one = graph_builder.value(input_x.dtype, 1)
# 使用BroadcastTo操作将常量扩展至输入张量的形状
result = graph_builder.emit('BroadcastTo', [const_one], attrs={'shape': input_x.shape})
# 返回扩展后的张量
return result

@ -22,22 +22,41 @@ from ._utils import Expander, ExpanderInfoValidator as VLD
class ReduceMean(Expander):
"""ReduceMean expander"""
def _expand(self, graph_builder):
"""
Compute the mean along the given axes.
Args:
graph_builder (GraphBuilder): graph builder object.
Returns:
Tensor: the reduced mean.
"""
x = self.inputs[0]
axis = self.attrs['axis']
keep_dims = self.attrs['keep_dims']
# normalize axis to a tuple; an empty axis means all dimensions
if not isinstance(axis, (tuple, list)):
axis = (axis,)
elif not axis:
axis = list(range(len(x.shape)))
# number of reduced elements
reduce_size = 1.0
for idx in axis:
reduce_size *= x.shape[idx]
reduce_size_value = graph_builder.value(x.dtype, reduce_size)
# mean = sum over the axes, divided by the element count
sum_x = graph_builder.emit('ReduceSum', [x], attrs={'reduce_axis': axis, 'keep_dims': keep_dims})
result = graph_builder.emit('RealDiv', [sum_x, reduce_size_value])
return result

@ -21,12 +21,28 @@ class ReluGrad(Expander):
"""ReLU expander"""
def _expand(self, graph_builder):
"""
在指定的图构建器中扩展当前节点
Args:
graph_builder (GraphBuilder): 图构建器实例用于在图中生成新的节点
Returns:
Tensor: 返回计算后的结果张量
"""
# 获取输入张量
input_x = self.inputs[0]
input_y = self.inputs[1]
# 生成一个与input_y相同数据类型的0值张量
# 生成一个与input_y相同数据类型的0值张量
const_zero = graph_builder.value(input_y.dtype, 0)
# 判断input_y是否大于0生成布尔张量
ge_result = graph_builder.emit('Greater', [input_y, const_zero])
# 将布尔张量转换为与input_x相同数据类型的张量
ge_result = graph_builder.emit('Cast', [ge_result], attrs={'dst_type': input_x.dtype})
# 将转换后的张量与input_x相乘
result = graph_builder.emit('Mul', [ge_result, input_x])
return result

@ -21,21 +21,46 @@ class SigmoidCrossEntropyWithLogits(Expander):
"""SigmoidCrossEntropyWithLogits expander"""
def _expand(self, graph_builder):
"""
计算sigmoid交叉熵损失
Args:
graph_builder: 图构建器对象用于构建计算图
Returns:
计算得到的sigmoid交叉熵损失值
"""
logits, labels = self.inputs
# 计算 logits 和 labels 的 sigmoid_cross_entropy_with_logits
# Calculate sigmoid_cross_entropy_with_logits(logits, labels)
# formula of sigmoid_cross_entropy_with_logits is:
# -(labels * log(sigmoid(logits)) + (1 - labels) * log(1 - sigmoid(logits)))
# To ensure stability and avoid overflow, the formula equal to :
# max(logits, 0) - logits * labels + log(1 + exp(-abs(logits)))
# sigmoid_cross_entropy_with_logits 的公式为:
# -(labels * log(sigmoid(logits)) + (1 - labels) * log(1 - sigmoid(logits)))
# 为了确保稳定性并避免溢出,该公式等价于:
# max(logits, 0) - logits * labels + log(1 + exp(-abs(logits)))
# 创建一个值为 1.0 的常量
const_one = graph_builder.value(logits.dtype, 1.0)
# 创建一个值为 0.0 的常量
const_zero = graph_builder.value(logits.dtype, 0.0)
# 计算 logits 和 0 的最大值
max_logits = graph_builder.emit('Maximum', [logits, const_zero])
# 计算 logits 和 labels 的乘积
logits_mul_labels = graph_builder.emit('Mul', [logits, labels])
# 计算 logits 的绝对值
abs_logits = graph_builder.emit('Abs', [logits])
# 计算 logits 的负值
neg_abs_logits = graph_builder.emit('Neg', [abs_logits])
# 计算 exp(-abs(logits))
exp_neg_abs_logits = graph_builder.emit('Exp', [neg_abs_logits])
# 计算 1 + exp(-abs(logits))
one_add_exp_neg_abs_logits = graph_builder.emit('Add', [const_one, exp_neg_abs_logits])
# 计算 log(1 + exp(-abs(logits)))
log_one_add_exp_neg_abs_logits = graph_builder.emit('Log', [one_add_exp_neg_abs_logits])
# 计算 max(logits, 0) - logits * labels
res_tmp = graph_builder.emit('Sub', [max_logits, logits_mul_labels])
# 计算最终结果
res = graph_builder.emit('Add', [res_tmp, log_one_add_exp_neg_abs_logits])
return res
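# A NumPy sketch (illustrative) of the stable form built above; log1p(exp(-|z|))
# computes the same quantity as the emitted Log(1 + Exp(-Abs(z))).
import numpy as np

def bce_with_logits(z, y):
    return np.maximum(z, 0) - z * y + np.log1p(np.exp(-np.abs(z)))

def bce_naive(z, y):
    p = 1.0 / (1.0 + np.exp(-z))
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))

z, y = np.array([-3.0, 0.5, 2.0]), np.array([0.0, 1.0, 1.0])
print(np.max(np.abs(bce_with_logits(z, y) - bce_naive(z, y))))  # ~1e-16
print(bce_with_logits(np.array([1000.0]), np.array([0.0])))  # [1000.]; the naive form returns inf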

@ -21,15 +21,29 @@ class SigmoidCrossEntropyWithLogitsGrad(Expander):
"""SigmoidCrossEntropyWithLogitsGrad expander"""
def _expand(self, graph_builder):
"""
计算sigmoid交叉熵损失的梯度
Args:
graph_builder: 图构建器对象用于构建计算图
Returns:
计算得到的梯度值
"""
logits, label, dout = self.inputs
# Calculate sigmoid_cross_entropy_with_logits_grad(logits, label, dout)
# formula of sigmoid_cross_entropy_with_logits_grad is :
# 计算sigmoid_cross_entropy_with_logits_grad(logits, label, dout)
# sigmoid_cross_entropy_with_logits_grad的公式为:
# (sigmoid(logits) - label) * dout
# 计算sigmoid(logits)
# Calculate sigmoid(logits)
const_one = graph_builder.value(logits.dtype, 1.0)
neg_x = graph_builder.emit('Neg', [logits])
exp_neg_x = graph_builder.emit('Exp', [neg_x])
add_exp = graph_builder.emit('Add', [const_one, exp_neg_x])
sigmoid_res = graph_builder.emit('RealDiv', [const_one, add_exp])
neg_x = graph_builder.emit('Neg', [logits]) # 计算-logits
exp_neg_x = graph_builder.emit('Exp', [neg_x]) # 计算e^(-logits)
add_exp = graph_builder.emit('Add', [const_one, exp_neg_x]) # 计算1 + e^(-logits)
sigmoid_res = graph_builder.emit('RealDiv', [const_one, add_exp]) # 计算1 / (1 + e^(-logits))即sigmoid(logits)
# 计算(sigmoid(logits) - label)
sigmoid_res_sub_label = graph_builder.emit('Sub', [sigmoid_res, label])
res = graph_builder.emit('Mul', [sigmoid_res_sub_label, dout])
# 计算最终结果
res = graph_builder.emit('Mul', [sigmoid_res_sub_label, dout]) # 计算(sigmoid(logits) - label) * dout
return res

@ -16,16 +16,31 @@
from ._utils import Expander, ExpanderInfoValidator as VLD
@VLD.check_all_formats_same
class SigmoidGrad(Expander):
"""SigmoidGrad expander"""
def _expand(self, graph_builder):
"""
计算 sigmoid 函数的梯度
Args:
graph_builder (GraphBuilder): 图构建器对象用于在图中添加操作
Returns:
Tensor: 计算得到的 sigmoid 梯度
"""
input_y, dy = self.inputs
# 计算 sigmoid_grad(y, dy)
# sigmoid_grad 的公式是: (1 - y) * y * dy
# Calculate sigmoid_grad(y, dy)
# formula of sigmoid_grad is : (1 - y) * y * dy
const_one = graph_builder.value(input_y.dtype, 1.0)
# 1 - y
one_mins_y = graph_builder.emit('Sub', [const_one, input_y])
# y * dy
y_mul_dy = graph_builder.emit('Mul', [input_y, dy])
# (1 - y) * (y * dy)
res = graph_builder.emit('Mul', [one_mins_y, y_mul_dy])
return res

@ -21,15 +21,41 @@ class Slice(Expander):
"""Slice expander"""
def _expand(self, graph_builder):
"""
在图中扩展输入张量
Args:
graph_builder (GraphBuilder): 图构建器对象
Returns:
Tensor: 扩展后的输出张量
"""
# 获取输入张量
input_x = self.inputs[0]
# 获取开始索引
begin = self.attrs['begin']
# 获取切片大小
size = self.attrs['size']
# 初始化结束索引列表
end = []
# 初始化步长列表
strides = []
# 遍历每个维度,计算结束索引和步长
for i, begin_idx in enumerate(begin):
# 步长设置为1
strides.append(1)
# 计算结束索引
end.append(begin_idx + size[i])
# 创建一个新的张量作为输出
output = graph_builder.tensor(size, input_x.dtype, input_x.data_format)
# 执行StridedSlice操作对输入张量进行切片
graph_builder.op('StridedSlice', output, [input_x], attrs={'begin': begin, 'end': end, 'strides': strides})
# 返回输出张量
return output
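# The begin/size -> begin/end/strides mapping above, mirrored with NumPy basic
# slicing (illustrative):
import numpy as np

def slice_ref(x, begin, size):
    end = [b + s for b, s in zip(begin, size)]
    return x[tuple(slice(b, e, 1) for b, e in zip(begin, end))]

x = np.arange(24).reshape(2, 3, 4)
print(slice_ref(x, begin=[0, 1, 2], size=[2, 2, 2]).shape)  # (2, 2, 2)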

@ -25,45 +25,75 @@ class Softmax(Expander):
"""Softmax expander"""
def _expand(self, graph_builder):
"""
计算Softmax函数值
Args:
graph_builder: 图构建器对象
Returns:
Softmax函数的计算结果
"""
# 获取输入数据
input_x = self.inputs[0]
# 获取处理器
processor = self.processor
# 获取轴信息
axis = self.attrs['axis']
# 获取输入数据的原始形状
ori_shape = input_x.shape
# 如果输入数据格式为FRAC_NZ则推断其形状
if input_x.data_format == DF.FRAC_NZ:
ori_shape = infer_shape_from_fractalnz(input_x.shape)
# 遍历轴信息,处理负数轴索引
for i, _ in enumerate(list(axis)):
if axis[i] < 0:
axis[i] += len(ori_shape)
# 获取减少维度后的原始形状
ori_reduced_shape = get_reduced_ori_shape(ori_shape, axis)
# 获取减少的轴
reduce_axis = axis
# 如果输入数据格式为FRAC_NZ则转换轴
if input_x.data_format == DF.FRAC_NZ:
reduce_axis = to_frac_z_axis(ori_shape, axis)
# 获取输入数据的原始数据类型
ori_dtype = input_x.dtype
# 如果原始数据类型不是float16且处理器为aicore则进行类型转换
if ori_dtype != "float16" and processor == "aicore":
input_x_f16 = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float16'})
max_x_f16 = graph_builder.emit('ReduceMax', [input_x_f16], attrs={'reduce_axis': reduce_axis,
'keep_dims': True})
max_x = graph_builder.emit('Cast', [max_x_f16], attrs={'dst_type': ori_dtype})
else:
# 计算最大值
max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
# 如果原始数据类型为float16且处理器为aicore则进行类型转换
if ori_dtype == "float16" and processor == "aicore":
max_x = graph_builder.emit('Cast', [max_x], attrs={'dst_type': "float32"})
input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': "float32"})
# 如果输入数据格式为FRAC_NZ则重新调整最大值的形状
if input_x.data_format == DF.FRAC_NZ:
max_x = graph_builder.emit('Reshape', [max_x], attrs={'shape': ori_reduced_shape})
# 计算输入数据减去最大值的差值
data_sub = graph_builder.emit('Sub', [input_x, max_x])
# 计算差值的指数
data_exp = graph_builder.emit('Exp', [data_sub])
# 计算指数的和
data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': reduce_axis, 'keep_dims': True})
# 如果输入数据格式为FRAC_NZ则重新调整指数和的形状
if input_x.data_format == DF.FRAC_NZ:
data_expsum = graph_builder.emit('Reshape', [data_expsum], attrs={'shape': ori_reduced_shape})
# 计算Softmax值
result = graph_builder.emit('RealDiv', [data_exp, data_expsum])
# 如果原始数据类型为float16且处理器为aicore则进行类型转换
if ori_dtype == "float16" and processor == "aicore":
result = graph_builder.emit('Cast', [result], attrs={'dst_type': ori_dtype})

@ -22,21 +22,45 @@ class SoftmaxCrossEntropyWithLogits(Expander):
"""SoftmaxCrossEntropyWithLogits expander"""
def _expand(self, graph_builder):
"""
计算损失值和 logits 的梯度
Args:
graph_builder (GraphBuilder): 图构建器对象
Returns:
Tuple[Tensor, Tensor]: 损失值和 logits 的梯度
"""
logits, label = self.inputs
# 计算 softmax_cross_entropy_with_logits(logits, label)
# softmax_cross_entropy_with_logits 的公式是: -reduce_sum(label * log(softmax(logits)))
# Calculate softmax_cross_entropy_with_logits(logits, label)
# formula of softmax_cross_entropy_with_logits is : -reduce_sum(label * log(softmax(logits)))
axis = (-1,)
max_x = graph_builder.emit('ReduceMax', [logits], attrs={'reduce_axis': axis, 'keep_dims': True})
# 计算 logits 的最大值
data_sub = graph_builder.emit('Sub', [logits, max_x])
# logits 减去最大值
data_exp = graph_builder.emit('Exp', [data_sub])
# 对上一步结果进行指数运算
data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': axis, 'keep_dims': True})
# 对指数运算结果求和
data_softmax = graph_builder.emit('RealDiv', [data_exp, data_expsum])
# 计算 softmax
const_eps = graph_builder.value(logits.dtype, 0.000001)
# 定义一个极小的常数,用于防止除以零的错误
data_softmax_safety = graph_builder.emit("Maximum", [data_softmax, const_eps])
# 确保 softmax 的值不为零
softmax_log = graph_builder.emit('Log', [data_softmax_safety])
# 对 softmax 结果取对数
label_mul_log = graph_builder.emit('Mul', [label, softmax_log])
# 将 label 与 softmax 的对数相乘
tmp_res = data_expsum = graph_builder.emit('ReduceSum', [label_mul_log], attrs={
'reduce_axis': axis, 'keep_dims': False})
# 对上一步结果进行求和
loss = graph_builder.emit('Neg', [tmp_res])
# 计算损失值,即上一步结果的负值
dlogits = graph_builder.emit('Sub', [data_softmax, label])
# 计算 logits 的梯度
return loss, dlogits
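# A NumPy sketch (illustrative) of the fused forward/backward above; note the
# gradient comes out as softmax - label, which sums to zero for one-hot labels.
import numpy as np

def softmax_xent(logits, label, eps=1e-6):
    shifted = logits - logits.max(-1, keepdims=True)
    e = np.exp(shifted)
    softmax = e / e.sum(-1, keepdims=True)
    loss = -(label * np.log(np.maximum(softmax, eps))).sum(-1)
    return loss, softmax - label

loss, dlogits = softmax_xent(np.array([[2.0, 1.0, 0.1]]), np.array([[1.0, 0.0, 0.0]]))
print(float(loss[0]), float(dlogits.sum()))  # loss ~0.417, gradient sums to ~0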

@ -13,29 +13,48 @@
# limitations under the License.
# ===========================================================================
"""generate json desc for SoftmaxGradExt"""
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD
from ._utils import get_reduce_axis_shape
@VLD.add_format(DF.FRAC_NZ, DF.FRAC_NZ, DF.DEFAULT)
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('axis')
class SoftmaxGradExt(Expander):
"""SoftmaxGradExt expander"""
def _expand(self, graph_builder):
"""
对输入数据进行扩展处理
Args:
graph_builder (GraphBuilder): 图构建器对象
Returns:
Tensor: 处理后的数据
"""
# 获取输入参数
x, y, z = self.inputs
# 获取指定的轴
axis = self.attrs['axis']
# 获取需要减少的轴和原始减少的形状
reduce_axis, ori_reduced_shape = get_reduce_axis_shape(x.shape, x.data_format, axis)
# 将x和y相乘
data_mul = graph_builder.emit('Mul', [x, y])
# 对乘积进行求和,并保留维度
data_sum = graph_builder.emit('ReduceSum', [data_mul],
attrs={'reduce_axis': reduce_axis, 'keep_dims': True, 'reduce_output_fuse': True})
# 如果x的数据格式为FRAC_NZ则对求和结果进行重塑
if x.data_format == DF.FRAC_NZ:
data_sum = graph_builder.emit('Reshape', [data_sum], attrs={'shape': ori_reduced_shape})
# 从x中减去求和结果
data_sub = graph_builder.emit('Sub', [x, data_sum])
# 将减法结果与y相乘
data_mul2 = graph_builder.emit('Mul', [data_sub, y])
# 将结果与z相乘得到最终结果
result = graph_builder.emit('Mul', [data_mul2, z])
return result

@ -21,9 +21,24 @@ class SqrtGrad(Expander):
"""SqrtGrad expander"""
def _expand(self, graph_builder):
"""
计算并返回给定输入 x 的平方根的梯度
Args:
graph_builder (GraphBuilder): 图构建器对象用于构建计算图
Returns:
Tensor: 返回给定输入 x 的平方根的梯度
"""
# 获取输入 x 和梯度 dout
# formula of sqrt_grad is dout / (2 * x)
x, dout = self.inputs
# 创建一个常数值 2
const_two = graph_builder.value(x.dtype, 2)
# 计算 2 * x
dividend = graph_builder.emit('Mul', [x, const_two])
# 计算梯度dout / (2 * x)
result = graph_builder.emit('RealDiv', [dout, dividend])
# 返回计算结果
return result

@ -21,24 +21,57 @@ class SquareSumAll(Expander):
"""SquareSumAll expander"""
def _check(self):
"""
检查输入是否合法
Args:
Returns:
Raises:
GKException: 如果输入的数量不等于2则抛出GKException异常
"""
"""check inputs"""
# 获取输入的数量
input_num = len(self.inputs)
if input_num != 2:
# 如果输入的数量不等于2则抛出异常
raise GKException("For 'SquareSumAll', the inputs number should be 2, but got {}.".format(input_num))
def _expand(self, graph_builder):
"""
对输入的两个变量进行扩展操作
Args:
graph_builder (GraphBuilder): 图构建器对象用于在图构建过程中发射操作
Returns:
tuple: 包含两个元素的元组每个元素为扩展操作的结果
"""
"""do expand"""
# 获取输入的两个变量
x0 = self.inputs[0]
x1 = self.inputs[1]
# 获取x0的形状
ori_shape = x0.shape
# 初始化一个空列表,用于存储维度索引
axis = []
# 遍历ori_shape将每个维度的索引添加到axis列表中
for i, _ in enumerate(ori_shape):
axis.append(i)
# 对x0进行平方运算
square_res0 = graph_builder.emit('Mul', [x0, x0])
# 对x1进行平方运算
square_res1 = graph_builder.emit('Mul', [x1, x1])
# 对square_res0进行求和运算求和的维度为axis并保持维度不变
result0 = graph_builder.emit('ReduceSum', [square_res0], attrs={'reduce_axis': axis, 'keep_dims': False})
# 对square_res1进行求和运算求和的维度为axis并保持维度不变
result1 = graph_builder.emit('ReduceSum', [square_res1], attrs={'reduce_axis': axis, 'keep_dims': False})
return result0, result1

@ -25,13 +25,30 @@ class SquareSumV1(Expander):
"""Square expander"""
def _expand(self, graph_builder):
"""
计算输入张量的平方并沿指定轴进行求和
Args:
graph_builder (GraphBuilder): 图构建器对象
Returns:
Tensor: 计算得到的张量
"""
# 获取输入的第一个元素
x = self.inputs[0]
# 获取属性中的axis值
axis = self.attrs['axis']
# 获取需要reduce的axis和原始的reduced shape
reduce_axis, ori_reduced_shape = get_reduce_axis_shape(x.shape, x.data_format, axis)
# 计算x的平方
square_res = graph_builder.emit('Mul', [x, x])
# 对平方结果进行ReduceSum操作
result = graph_builder.emit('ReduceSum', [square_res], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
# 如果数据格式为FRAC_NZ则对结果进行Reshape操作
if x.data_format == DF.FRAC_NZ:
result = graph_builder.emit('Reshape', [result], attrs={'shape': ori_reduced_shape})
# 返回最终结果
return result

@ -21,10 +21,24 @@ class SquaredDifference(Expander):
"""SquaredDifference expander"""
def _expand(self, graph_builder):
"""
根据输入的两个输入值计算并返回它们的平方差的计算结果
Args:
graph_builder (GraphBuilder): 图构建器对象用于在图中生成节点和边
Returns:
Node: 计算结果节点
"""
# 获取输入的第一个值
input_x = self.inputs[0]
# 获取输入的第二个值
input_y = self.inputs[1]
# 使用图构建器计算输入值的差值
sub_val = graph_builder.emit('Sub', [input_x, input_y])
# 使用图构建器计算差值的平方
result = graph_builder.emit('Mul', [sub_val, sub_val])
return result

@ -21,27 +21,67 @@ class Squeeze(Expander):
"""Squeeze expander"""
def _expand(self, graph_builder):
"""
扩展输入的维度
Args:
graph_builder (GraphBuilder): 图构建器对象用于构建图结构
Returns:
Tensor: 扩展维度后的输入
"""
# 获取输入的第一个元素
input_x = self.inputs[0]
# 根据输入的shape和axis属性推断输出shape
out_shape = self.infer_shape(input_x.shape, self.attrs['axis'])
# 使用graph_builder发射Reshape操作并设置shape属性为out_shape
result = graph_builder.emit('Reshape', [input_x], attrs={'shape': out_shape})
# 返回结果
return result
@staticmethod
def infer_shape(shape, axis):
"""
根据指定的axis推断squeeze后的shape
Args:
shape (list, tuple): 原始数据的shape
axis (int, list, tuple): 需要被squeeze的维度如果为int则只squeeze该维度
如果为list或tuple则squeeze列表或元组中的每个维度如果为空则squeeze所有维度为1的维度
Returns:
list: squeeze后的shape
Raises:
ValueError: 如果输入的axis类型不符合要求抛出异常
"""
"""infer shape for squeeze"""
def squeeze_axis(shape, axis):
# 如果axis为空移除shape中所有值为1的维度
if not axis:
out_shape = list(d for d in shape if d != 1)
else:
# 获取shape的维度数量
ndim = len(shape)
# 移除shape中指定的axis维度
out_shape = list(shape[i] for i in range(ndim) if not (i in axis or (i - ndim) in axis))
# 如果out_shape为空则将其设置为[1]
if not out_shape:
out_shape = [1]
return out_shape
# 如果shape是列表或元组类型
if isinstance(shape, (list, tuple)):
# 如果axis是整数类型则将其转换为列表
if isinstance(axis, int):
axis = [axis]
# 如果axis是列表或元组类型则调用squeeze_axis函数处理
if isinstance(axis, (list, tuple)):
return squeeze_axis(shape, axis)
# 如果输入不符合要求,则抛出异常
raise ValueError("Invalid axis for Squeeze.")
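# Hypothetical calls (assuming the class is in scope) illustrating the shape
# inference above, including negative axes and the empty-axis case:
print(Squeeze.infer_shape([1, 3, 1, 5], []))       # [3, 5]  (drop every 1-sized dim)
print(Squeeze.infer_shape([1, 3, 1, 5], 2))        # [1, 3, 5]
print(Squeeze.infer_shape([1, 3, 1, 5], (0, -2)))  # [3, 5]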

@ -21,11 +21,29 @@ class TanhGrad(Expander):
"""TanhGrad expander"""
def _expand(self, graph_builder):
"""
计算1减去输入值的平方后再与输入的导数相乘
Args:
graph_builder (GraphBuilder): 图构建器对象用于构建计算图
Returns:
Tensor: 计算结果类型为Tensor
"""
# 获取输入值
input_y, input_dy = self.inputs
# 创建一个值为1的常量数据类型与input_y相同
const_one = graph_builder.value(input_y.dtype, 1)
# 计算input_y的平方
double_y = graph_builder.emit('Mul', [input_y, input_y])
# 计算1减去input_y的平方
one_sub_double_y = graph_builder.emit('Sub', [const_one, double_y])
# 计算input_dy与1减去input_y的平方的乘积
result = graph_builder.emit('Mul', [input_dy, one_sub_double_y])
return result

@ -25,30 +25,48 @@ class Tile(Expander):
def _get_output_shape(self):
"""Get output shape"""
shape = list(self.inputs[0].shape)
multiples = list(self.attrs["multiples"])
# 'multiples' must cover at least as many dimensions as the input shape
diff_len = len(multiples) - len(shape)
if diff_len < 0:
raise GKException("For 'Tile', dimensions of attr 'multiples' should be greater than or equal to "
"dimensions of input shape, but got {} and {}".format(multiples, shape))
# left-pad the shape with 1s to align with 'multiples'
if diff_len > 0:
for _ in range(diff_len):
shape.insert(0, 1)
output_shape = []
for sh, mul in list(zip(shape, multiples)):
# a dimension may only be tiled if it has size 1, so that tiling is a broadcast
if sh != 1 and mul != 1:
raise GKException("For 'Tile', input shape{} and attr 'multiples'{} can not broadcast."
.format(self.inputs[0].shape, multiples))
dim = sh * mul
output_shape.append(dim)
return output_shape
def _expand(self, graph_builder):
input_x = self.inputs[0]
output_shape = self._get_output_shape()
# under the size-1 constraint above, Tile is exactly a BroadcastTo
result = graph_builder.emit('BroadcastTo', [input_x], attrs={'shape': output_shape})
return result
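# A NumPy confirmation (illustrative) that, under the size-1 constraint enforced
# in _get_output_shape, tiling and broadcasting coincide:
import numpy as np

x = np.arange(3).reshape(1, 3)
multiples = (4, 1)
out_shape = [s * m for s, m in zip(x.shape, multiples)]
print(np.array_equal(np.tile(x, multiples), np.broadcast_to(x, out_shape)))  # True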

@ -14,6 +14,9 @@
# ===========================================================================
"""GraphKernel cost model init"""
from .graph_split import split
from .model_builder import GraphBuilder, load_composite
from .graph_parallel import parallel_estimate

@ -20,128 +20,299 @@ class ParalGain:
"""Paral Gain"""
def __init__(self, fusion_type, bottleneck, gain, block_assign, type_info):
"""
类的构造函数
Args:
fusion_type (str): 融合类型
bottleneck (int): 瓶颈层的大小
gain (float): 增益值
block_assign (list): 块分配列表
type_info (dict): 类型信息字典
Returns:
None
"""
# 初始化融合类型
self.fusion_type = fusion_type
# 初始化瓶颈层
self.bottleneck = bottleneck
# 初始化增益
self.gain = gain
# 初始化块分配
self.block_assign = block_assign
# 初始化类型信息
self.type_info = type_info
class ScheduleAnalyzer:
"""schedule analyzer"""
# warp size
WRAP_SIZE = 32
# maximum number of SMs
MAX_SM = 80 # Volta
# maximum number of threads per block
MAX_NUM_THREADS = 1024
# maximum number of blocks
MAX_BLOCK = 256
# threshold for pipeline optimization
PIPELINE_OP_THREADHOLD = 5
def __init__(self, graph):
"""
初始化图处理类
Args:
graph (Graph): 图对象用于存储图的结构和参数
Attributes:
graph (Graph): 图对象存储图的结构和参数
block_num (int): 块的数量初始值为0
block_weight (float): 块的权重初始值为0
ops (List[Operation]): 图的操作列表
dom_op (List[Operation]): 输出的每个操作对应的操作列表
"""
# 将传入的图对象赋值给实例变量graph
self.graph = graph
# 初始化block数量为0
self.block_num = 0
# 初始化block权重为0
self.block_weight = 0
# 通过图对象的deduce_parameters方法获取参数并赋值给outputs变量
_, outputs = graph.deduce_parameters()
# 将图对象的操作列表赋值给实例变量ops
self.ops = graph.ops
# 将outputs中的每个输出对应的操作收集到一个列表中并赋值给实例变量dom_op
self.dom_op = list(out.op for out in outputs)
@staticmethod
def prod(shape):
"""
计算形状乘积
Args:
shape (list): 一个包含整数的列表表示形状
Returns:
int: 形状乘积的结果
"""
"""Compute shape product"""
# 初始化结果变量为shape的第一个元素
res = shape[0]
# 遍历shape列表从第二个元素开始
for i in range(1, len(shape)):
# 将当前结果与shape的下一个元素相乘
res = res * shape[i]
# 返回计算后的结果
return res
def _cal_weight(self, ops):
"""
计算给定操作列表的总权重
Args:
ops (list): 包含多个操作对象的列表
Returns:
int: 所有操作的权重总和
"""
weight = 0
for op in ops:
# 遍历每个操作
weight += self.prod(op.output.shape) * \
PrimLib.dtype_bytes(op.output.dtype)
PrimLib.dtype_bytes(op.output.dtype) # 计算op的输出数据类型的字节数
return weight
def injective_analyze(self):
"""
分析单射情况
Args:
Returns:
"""
"""analyze injective case"""
# 计算常量大小
const_size = max((self.prod(op.output.shape) for op in self.dom_op))
# 调整常量大小确保是MAX_NUM_THREADS的倍数
const_size = (const_size + self.MAX_NUM_THREADS -
1) // self.MAX_NUM_THREADS * self.MAX_NUM_THREADS
# 计算总权重
total_weight = self._cal_weight(self.ops)
# 计算总块数
total_block = (const_size + self.MAX_NUM_THREADS -
1) // self.MAX_NUM_THREADS
# 判断是否需要分割块
need_block_split = const_size > self.MAX_BLOCK * self.MAX_NUM_THREADS
if need_block_split:
# 如果需要分割块设置块数为MAX_BLOCK
self.block_num = self.MAX_BLOCK
# 计算波数
waves = (total_block + self.MAX_BLOCK - 1) // self.MAX_BLOCK
# 计算块权重
self.block_weight = total_weight // total_block * waves
else:
# 如果不需要分割块,设置块数为总块数
self.block_num = total_block
# 计算块权重
self.block_weight = total_weight // self.block_num
def reduce_analyze(self):
"""
分析reduce操作
Args:
Returns:
Raises:
RuntimeError: 如果并行融合不支持多个reduce操作或者没有找到reduce操作
"""
"""analyze reduce case"""
# 定义线程数
thread_x, thread_y = 32, 32
reduce_op = None
for op in self.ops:
# 判断操作类型是否为reduce
if PrimLib.iter_type(op) == PrimLib.REDUCE:
# 如果已经存在reduce操作则抛出异常
if reduce_op:
raise RuntimeError("Parallel fusion does not support multiple reduce op now.")
reduce_op = op
# 如果没有找到reduce操作则抛出异常
if not reduce_op:
raise RuntimeError("Parallel fusion does not find a reduce op.")
# 获取reduce操作的输入形状
shape = reduce_op.inputs[0].shape
# 获取reduce操作的reduce轴
reduce_axis = reduce_op.attrs['reduce_axis']
# 计算总空间
total_space = self.prod(shape)
# 计算reduce空间
red_space = shape[reduce_axis[0]]
for i in range(1, len(reduce_axis)):
red_space *= shape[reduce_axis[i]]
# 获取数据类型大小
dtype_size = PrimLib.dtype_bytes(reduce_op.output.dtype)
# 计算权重
weight = self._cal_weight(self.ops) # reduce + injective
# 计算block_x
block_x = (total_space // red_space + thread_y - 1) // thread_y
# 计算block_w
block_w = (weight + block_x - 1) // block_x
# 计算waves
waves = (block_x + self.MAX_BLOCK - 1) // self.MAX_BLOCK
# 设置block_num
self.block_num = min(self.MAX_BLOCK, block_x)
# 定义all_reduce
all_reduce = 10 # 1 reduce init + 3 sync + 5 bin + 1 write
# 计算block_weight
self.block_weight = (block_w + all_reduce *
dtype_size * thread_x * thread_y) * waves
def default_analyze(self):
"""
默认分析函数
Args:
Returns:
Raises:
"""
"""analyze default case"""
# 定义一个内部函数,用于计算默认空间
def _cal_default_space(op):
# 计算op的输出空间
space = self.prod(op.output.shape)
# 遍历op的所有输入
for t in op.inputs:
# 计算输入的空间
size = self.prod(t.shape)
# 如果输入空间大于当前空间,则更新空间
if size > space:
space = size
# 返回计算出的空间
return space
# 计算所有操作中的最大空间
space = max((_cal_default_space(op) for op in self.dom_op))
# each sm least 4 wrap
# 每个sm至少包含4个wrap
# 计算所需的block数量
block = (space + (self.WRAP_SIZE * 4) - 1) // (self.WRAP_SIZE * 4)
# 将block数量限制在最大block数量之内
self.block_num = min(self.MAX_BLOCK, block)
# 计算每个block的权重
self.block_weight = self._cal_weight(self.ops) // self.block_num
def analyze(self):
"""analyze ops"""
def _ops_type(ops, dom_op):
"""
Determine the dominant iteration type, preferring reduce.
Args:
ops (list): all operations.
dom_op (list): output-producing operations.
Returns:
True if any op is a reduce, otherwise the iteration type of dom_op[0].
"""
have_reduce = any(
(PrimLib.iter_type(op) == PrimLib.REDUCE for op in ops))
if have_reduce:
return True
return PrimLib.iter_type(dom_op[0])
dom_type = _ops_type(self.ops, self.dom_op)
# dispatch to the matching analysis
if dom_type in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
self.injective_analyze()
elif dom_type == PrimLib.REDUCE:
self.reduce_analyze()
else:
self.default_analyze()
def suitable_to_pipeline(self):
"""judge whether is suitable to be pipeline optimized"""
# Reduce is not suitable
def _contain_reduce(ops):
for op in ops:
# Reduce may make the tiling bad.
if PrimLib.primtives.get(op.prim, None) == PrimLib.REDUCE:
return True
@ -149,6 +320,7 @@ class ScheduleAnalyzer:
suitable = True
if _contain_reduce(self.ops):
# pipeline optimization is not suitable when Reduce is present
suitable = False
return suitable
@ -166,13 +338,16 @@ class ScheduleAnalyzer:
classes (list[list[int]]): The list of clusters. Each cluster is a list of indices.
"""
def _cal_mean(classes):
# mean of each cluster
class_datas = list(list(data[cid] for cid in cls) for cls in classes)
return list(sum(cls) / len(cls) if cls else float('inf') for cls in class_datas)
def _cal_distance(a, b):
# absolute distance between two values
return abs(a - b)
def _check_different(old_classes, new_classes):
# check whether the clustering changed
for o, n in zip(old_classes, new_classes):
if o != n:
return True
@ -201,31 +376,39 @@ class ScheduleAnalyzer:
min_idx = i if min_dis > cur_dis else min_idx
min_dis = cur_dis if min_dis > cur_dis else min_dis
new_classes[min_idx].append(idx)
# update the clusters and track whether they changed
changed = _check_different(classes, new_classes)
classes = new_classes
return classes
@staticmethod
def pipeline_fusion_analyze(blocks, op_sizes, exclude_id):
"""analyze whether the segments can be pipeline optimized"""
# op size first, block second.
def _simple_factor(block, op_size):
return block + 5 * op_size
def _take_second(elem):
return elem[1]
# a simple load indicator per segment
simple_indicators = list(_simple_factor(b, s)
for b, s in zip(blocks, op_sizes))
# 2 classes, one heavy, the other light
classes = ScheduleAnalyzer.k_mean(simple_indicators, 2, exclude_id)
if not classes:
return []
# mean indicator of each class
means = list(sum([simple_indicators[idx] for idx in cls]) /
len(cls) if cls else float('inf') for cls in classes)
# The target two clusters should be a heavy one and a light one.
# The light one maybe suitable to run with pipeline optimized.
classes_infos = list([cls, m] for cls, m in zip(classes, means))
classes_infos.sort(key=_take_second)
pipeline_target = None
@ -234,6 +417,7 @@ class ScheduleAnalyzer:
pipeline_target = ci
break
pipeline_gids, pipeline_mean = pipeline_target
# bail out when even the light class is too heavy to pipeline
if pipeline_mean > _simple_factor(float(ScheduleAnalyzer.MAX_SM) / len(blocks),
ScheduleAnalyzer.PIPELINE_OP_THREADHOLD):
return []
@ -241,6 +425,7 @@ class ScheduleAnalyzer:
pipeline_blocks = []
pipeline_weight = len(pipeline_gids)
# Try to make two paralleled at least.
if pipeline_weight > 3 and pipeline_weight > len(blocks) / 2:
if len(pipeline_gids[:pipeline_weight // 2]) > 1:
pipeline_blocks.append(pipeline_gids[:pipeline_weight // 2])
@ -252,49 +437,114 @@ class ScheduleAnalyzer:
@staticmethod
def fusion_consult(blocks, op_sizes, exclude_gid):
"""
获取并行融合的建议
Args:
blocks (list): 包含多个计算块的列表
op_sizes (list): 每个操作的尺寸列表
exclude_gid (int): 需要排除的组ID
Returns:
tuple: 包含融合类型和类型信息的元组
Raises:
"""
"""get a recommendation for parallel fusion"""
# 默认是块融合
# Default is block fusion
fusion_type = "block_fusion"
type_info = None
# 禁用管道优化
activate_pipeline_optimization = False # Disable pipeline optimization for now.
# 如果启用管道优化
if activate_pipeline_optimization:
# 对块、操作大小和排除组ID进行管道融合分析
pipeline_info = ScheduleAnalyzer.pipeline_fusion_analyze(
blocks, op_sizes, exclude_gid)
# 如果存在管道信息
if pipeline_info:
# 融合类型为块管道融合
fusion_type = "block_pipeline_fusion"
# 设置类型信息为管道信息
type_info = pipeline_info
return fusion_type, type_info
def block_parallel_estimate(graphs):
"""
估计块并行增益
Args:
graphs (list): 图集合每个元素是一个图对象
Returns:
ParalGain: 包含块并行增益信息的ParalGain对象
"""
"""estimate block parallel gain"""
# 初始化变量
sum_block, max_weight, sum_weight, blocks, op_sizes, exclude_gid = 0, 0, 0, [], [], []
# 遍历图集合
for gid, g in enumerate(graphs):
# 创建ScheduleAnalyzer对象
s = ScheduleAnalyzer(g)
# 分析图
s.analyze()
# 累加块的数量
sum_block += s.block_num
# 更新最大权重
if s.block_weight > max_weight:
max_weight = s.block_weight
# 累加权重
sum_weight += s.block_weight
# 添加块的数量到blocks列表
blocks.append(s.block_num)
# 添加操作数量到op_sizes列表
op_sizes.append(len(s.ops))
# 如果不适合流水线处理将gid添加到exclude_gid列表
if not s.suitable_to_pipeline():
exclude_gid.append(gid)
# 如果块的数量大于ScheduleAnalyzer.MAX_SM * 32返回"none"
if sum_block > ScheduleAnalyzer.MAX_SM * 32:
return ParalGain("none", sum_weight, 0, list(0 for _ in graphs), None)
# 获取融合类型和类型信息
fusion_type, type_info = ScheduleAnalyzer.fusion_consult(blocks, op_sizes, tuple(exclude_gid))
# 返回ParalGain对象
return ParalGain(fusion_type, max_weight, sum_weight - max_weight, blocks, type_info)
def parallel_estimate(graphs, target):
"""
并行估计函数
Args:
graphs (list): 图结构列表
target (str): 目标类型例如"aicore"
Returns:
ParalGain: 并行增益对象
"""
"""Estimate parallel gain"""
# 如果目标是"aicore"
if target == "aicore":
# 融合类型为"block_fusion"
fusion_type = "block_fusion"
# 类型信息为空
type_info = None
# 假设估计值为1000
fake_estimate = 1000
# 生成一个与graphs长度相同的列表每个元素都是1
fake_blocks = list(1 for g in graphs)
# 返回ParalGain对象
return ParalGain(fusion_type, fake_estimate, fake_estimate, fake_blocks, type_info)
# 调用block_parallel_estimate函数进行并行估计
return block_parallel_estimate(graphs)

@ -24,39 +24,61 @@ class Utils:
@staticmethod
def get_attr_type(attr):
"""Get attr type"""
if isinstance(attr, bool):
return 'bool'
if isinstance(attr, str):
return 'str'
if isinstance(attr, int):
return 'int'
if isinstance(attr, float):
return 'float'
if isinstance(attr, (list, tuple)):
# an empty list/tuple carries no element type
if not attr:
raise ValueError("attr is invalid: the length of attr is 0")
# classify by the first element's type
if isinstance(attr[0], int):
return 'listInt'
if isinstance(attr[0], str):
return 'listStr'
raise ValueError("attr {} type {} is not in supported list ['bool', 'str', 'int', 'float', 'int' list, "
"'str' list]".format(attr, type(attr)))
class DataFormat:
"""DataFormat"""
DEFAULT = "DefaultFormat"
NC1KHKWHWC0 = "NC1KHKWHWC0"
ND = "ND"
NCHW = "NCHW"
NHWC = "NHWC"
HWCN = "HWCN"
NC1HWC0 = "NC1HWC0"
FRAC_Z = "FracZ"
FRAC_NZ = "FRACTAL_NZ"
C1HWNCOC0 = "C1HWNCoC0"
NC1HWC0_C04 = "NC1HWC0_C04"
FRACTAL_Z_C04 = "FRACTAL_Z_C04"
NDHWC = "NDHWC"
def __init__(self):
@ -65,29 +87,47 @@ class DataFormat:
class DataType:
"""Data Type"""
FLOAT = "float"
FLOAT16 = "float16"
FLOAT32 = "float32"
FLOAT64 = "float64"
INT = "int"
INT8 = "int8"
INT16 = "int16"
INT32 = "int32"
INT64 = "int64"
UINT = "uint"
UINT8 = "uint8"
UINT16 = "uint16"
UINT32 = "uint32"
UINT64 = "uint64"
BOOL = "bool"
def __init__(self):
pass
class PrimLib:
"""Prim lib"""
# iteration-type constants
UNKNOWN = 0
RESHAPE = 1
ELEMWISE = 2
@ -102,53 +142,73 @@ class PrimLib:
"""Prim"""
def __init__(self, iter_type, calibrate=1, relation_func=None):
# record the iteration type, calibration factor and relation function
self.iter_type = iter_type
self.calibrate = calibrate
self.relation_func = relation_func
if relation_func is None:
# fall back to the default relation function for this iter_type
self.relation_func = lambda *x: self.default_relation_func[iter_type](self, *x)
def default_reshape_relation(self, op, input_idx):
"""Process reshape relation"""
axis_relation, elem_relation = self.unknown_relation(op, input_idx)
# every element relation of a reshape is RESHAPE
elem_relation = [PrimLib.RESHAPE] * len(elem_relation)
return axis_relation, elem_relation
def default_elemwise_broadcast_relation(self, op, input_idx):
"""Process elemwise and broadcast relation"""
out_shape = op.output.shape
in_shape = op.inputs[input_idx].shape
# the output must have at least as many dimensions as the input
if len(out_shape) < len(in_shape):
raise ValueError("For '{}', the input/output size is abnormal, as the length of output shape{} is less "
"than the length of input shape{}".format(op.prim, out_shape, in_shape))
axis_relation, elem_relation = [], []
# leading output dimensions with no matching input dimension map to None
delta = len(out_shape) - len(in_shape)
if delta > 0:
for i in range(0, delta):
axis_relation.append(None)
elem_relation.append(None)
# matching dimensions are ELEMWISE when equal, BROADCAST otherwise
for i, _ in enumerate(in_shape):
axis_relation.append(i)
elem_relation.append(
PrimLib.ELEMWISE if out_shape[i + delta] == in_shape[i] else PrimLib.BROADCAST)
return axis_relation, elem_relation
def default_reduce_relation(self, op, input_idx):
"""Process reduce relation"""
axis_relation, elem_relation = self.default_elemwise_broadcast_relation(op, input_idx)
# reduced axes get the REDUCE element relation
for i in op.attrs['reduce_axis']:
elem_relation[i] = PrimLib.REDUCE
return axis_relation, elem_relation
def unknown_relation(self, op, input_idx):
"""Process unknown relation"""
out_shape = op.output.shape
in_shape = op.inputs[input_idx].shape
# every output axis may relate to any input axis
all_relation = list(range(len(in_shape)))
axis_relation = [all_relation for i in range(0, len(out_shape))]
elem_relation = [PrimLib.UNKNOWN for i in range(0, len(out_shape))]
return axis_relation, elem_relation
# default relation functions, indexed by iter_type
default_relation_func = [
unknown_relation,
default_reshape_relation,
@ -158,6 +218,7 @@ class PrimLib:
unknown_relation,
]
# primitive registry
primtives = {
'Add': Prim(ELEMWISE),
'Abs': Prim(ELEMWISE),
@ -239,7 +300,9 @@ class PrimLib:
@classmethod
def get_prim(cls, op):
"""Get op primtive"""
prim = cls.primtives.get(op.prim, None)
# fall back to the default primitive for unregistered ops
if prim is None:
print('[WARN] primtive is not registered: ' + op.prim)
prim = cls.default_primtive
@ -248,50 +311,65 @@ class PrimLib:
@classmethod
def input_relation(cls, op, input_idx):
"""Get op's input_relation according to input_idx"""
# Get the prim for op and call its relation_func
return cls.get_prim(op).relation_func(op, input_idx)
@classmethod
def iter_type(cls, op):
"""Get op's iter type"""
# Return the iter_type of op's prim
return cls.get_prim(op).iter_type
@classmethod
def is_reduce(cls, op):
"""Check whether op's iter type is reduce"""
# Check whether the iter_type of op's prim is REDUCE
return cls.get_prim(op).iter_type == cls.REDUCE
@classmethod
def calibrate_iter_size(cls, op, iter_size):
"""Get calibrate_iter_size"""
# Scale iter_size by the prim's calibrate factor
return cls.get_prim(op).calibrate * iter_size
@classmethod
def dtype_bytes(cls, dtype):
"""Get dtype bytes"""
# bits starts at 1; the offset is absorbed by the floor division below
bits, unit = 1, 1
# Scan the dtype name backwards from its last character
for i in range(len(dtype) - 1, 0, -1):
# Accumulate the trailing decimal digits (ones, tens, ...)
if dtype[i].isdecimal():
bits += int(dtype[i]) * unit
unit *= 10
# and stop at the first non-digit character
else:
break
# Convert bits to bytes
return bits // 8
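# A quick check of the digit-parsing loop above; the initial bits = 1 offset
# is absorbed by the floor division (e.g. "float32" accumulates 2*1 + 3*10,
# so bits = 33 and 33 // 8 = 4):
#
#     PrimLib.dtype_bytes("float16")  # -> 2
#     PrimLib.dtype_bytes("float32")  # -> 4
#     PrimLib.dtype_bytes("int64")    # -> 8
#     PrimLib.dtype_bytes("uint8")    # -> 1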
@classmethod
def inplace_reuse(cls, op, input_idx, start_axis=0):
"""Check whether op is inplace reuse"""
# Reuse is impossible when the output dtype is wider than the input dtype
if cls.dtype_bytes(op.output.dtype) > cls.dtype_bytes(op.inputs[input_idx].dtype):
return False
# Get the element relation between the op and this input
_, elem_relation = cls.get_prim(op).relation_func(op, input_idx)
# Check every axis from start_axis on:
for i in range(start_axis, len(elem_relation)):
# any non-elemwise axis breaks in-place reuse
if elem_relation[i] != cls.ELEMWISE:
return False
# Otherwise the buffer can be reused in place
return True
class Tensor:
"""Tensor"""
# Parameter type constants
PARA_NONE = 0
PARA_INPUT = 1
PARA_OUTPUT = 2
@ -303,6 +381,7 @@ class Tensor:
self.members = [leader]
def __init__(self, name, shape, dtype, data_format=DataFormat.DEFAULT, para_type=0):
# Initialize the Tensor object
self.name = name
self.shape = shape
self.dtype = dtype
@ -313,13 +392,16 @@ class Tensor:
self.buddy = None
def __str__(self):
# Short form: name plus shape
return self.name + str(list(self.shape))
def __repr__(self):
# Detailed form: name, dtype and shape
return "%s.%s%s" % (self.name, self.dtype, str(list(self.shape)))
def get_size(self):
"""Get size"""
# Size in bytes: dtype width times the element count
size = PrimLib.dtype_bytes(self.dtype)
for i in self.shape:
size *= i
@ -327,6 +409,7 @@ class Tensor:
def add_buddy(self, tensor):
"""Add buddy"""
# Add a buddy tensor, creating the buddy group on first use
if self.buddy is None:
self.buddy = self.Buddy(self)
self.buddy.members.append(tensor)
@ -337,6 +420,7 @@ class Value:
"""Value"""
def __init__(self, name, dtype, value, data_format=DataFormat.DEFAULT):
# Initialize the Value object
self.name = name
self.shape = [1]
self.dtype = dtype
@ -344,14 +428,17 @@ class Value:
self.data_format = data_format
def __str__(self):
# Short form: name plus shape
return self.name + str(list(self.shape))
def __repr__(self):
# Detailed form: name, dtype and shape
return "%s.%s%s" % (self.name, self.dtype, str(list(self.shape)))
@staticmethod
def get_size():
"""Get size"""
# A Value always has size 1
return 1
@ -359,23 +446,31 @@ class Operator:
"""Operator"""
def __init__(self, primtive, inputs, output, attrs):
# Initialize the Operator object
self.prim = primtive
self.inputs = inputs
self.output = output
self.attrs = attrs
# Register this operator as a consumer of each input
for t in inputs:
t.to_ops.append(self)
# Claim the output tensor when it has no producer yet
if output.op is None:
output.op = self
self.all_inputs = []  # include Tensor inputs and Value inputs.
def __str__(self):
# Join all inputs into a comma-separated string
args = ', '.join((str(t) for t in self.all_inputs))
# Build the expression string
expr = "%s = %s.%s(%s) id:%s" % (
str(self.output), self.prim, self.output.dtype, args, id(self))
# and append the attrs when present
return expr if not self.attrs else '%s // %s' % (expr, str(self.attrs))
def __repr__(self):
# Delegate to __str__
return str(self)
@ -383,6 +478,7 @@ class Graph:
"""Graph"""
def __init__(self, name, ops, stitch_info=None, recompute_ops=None):
# Initialize the Graph object
self.name = name
self.ops = ops # in topo order, can not use set
self.inputs = []
@ -393,10 +489,12 @@ class Graph:
def set_processor(self, processor):
"""Set processor"""
# Record the target processor
self.processor = processor
def add(self, ops):
"""Add ops"""
# Append a single op or a list of ops
if isinstance(ops, Operator):
self.ops.append(ops)
else:
@ -404,101 +502,148 @@ class Graph:
def extract_subgraph(self, graph_name, tensor_names, difference=False):
"""Extract subgraph from this graph"""
# Extract a subgraph from this graph
graph = Graph(graph_name, [])
outputs = set(tensor_names)
if difference:
# difference=True keeps the ops whose outputs are NOT in tensor_names
for op in self.ops:
if op.output.name not in outputs:
graph.add(op)
else:
# otherwise keep exactly the ops whose outputs are in tensor_names
for op in self.ops:
if op.output.name in outputs:
graph.add(op)
outputs.remove(op.output.name)
# Any name left over was never produced in this graph
for name in outputs:
raise ValueError("Invalid input tensor : {}, can not find it in graph".format(name))
return graph
def deduce_parameters(self):
"""Deduce parameters"""
# Collect the graph inputs and outputs
inputs, outputs = [], []
# Walk all ops
for op in self.ops:
# and each op's inputs
for t in op.inputs:
# A tensor produced outside this graph is a graph input
if t not in inputs and t.op not in self.ops:
inputs.append(t)
# Skip outputs that are already recorded
if op.output in outputs:
continue
# An output parameter, or a tensor with no consumers, is a graph output
if op.output.para_type == Tensor.PARA_OUTPUT or not op.output.to_ops:
outputs.append(op.output)
continue
# A tensor consumed outside this graph is also a graph output
if any((succ not in self.ops for succ in op.output.to_ops)):
outputs.append(op.output)
# Explicit inputs override the deduced ones
if self.inputs:
inputs = self.inputs
# and so do explicit outputs
if self.outputs:
outputs = self.outputs
# Return the deduced parameters
return inputs, outputs
def __str__(self):
# Deduce the graph parameters
inputs, outputs = self.deduce_parameters()
# Render the inputs
para_str = ', '.join((repr(t) for t in inputs))
# and the outputs
out_str = ', '.join((repr(t) for t in outputs))
lines = []
# Header: name, parameters and outputs
lines.append("%s(%s) -> %s {" % (self.name, para_str, out_str))
# Append the stitch info when present
if self.stitch_info:
if self.stitch_info.stitch_ops:
lines.append(' stitch -> ' + str(self.stitch_info.stitch_ops))
if self.stitch_info.stitch_atomic_ops:
lines.append(' stitch_atomic_ops-> ' + str(self.stitch_info.stitch_atomic_ops))
# One line per op
for op in self.ops:
lines.append(' ' + str(op))
# Closing brace
lines.append('}')
# Join and return
return '\n'.join(lines)
def __repr__(self):
# Delegate to __str__
return str(self)
def dump(self):
"""Dump Graph to json"""
# Convert the Graph to a JSON-style dict
attr_name = {'reduce_axis': 'axis'}
# Deduce the graph inputs and outputs
inputs, outputs = self.deduce_parameters()
input_desc, output_desc, op_desc = [], [], []
# Describe the inputs,
for t in inputs:
input_desc.append([{'data_type': t.dtype, 'shape': t.shape,
'tensor_name': t.name, 'format': t.data_format}])
# the outputs,
for t in outputs:
output_desc.append({'data_type': t.dtype, 'shape': t.shape,
'tensor_name': t.name, 'format': t.data_format})
# and every op in the graph
for op in self.ops:
attrs, in_desc = [], []
# Describe the op's attributes,
for a in op.attrs:
# mapping internal attribute names to their JSON names
name = attr_name.get(a, a)
attrs.append(
{'name': name, 'value': op.attrs[a], 'data_type': Utils.get_attr_type(op.attrs[a])})
# Describe the op's inputs:
for t in op.all_inputs:
# Tensor inputs carry no value,
if isinstance(t, Tensor):
in_desc.append([{'data_type': t.dtype, 'name': '', 'shape': t.shape,
'tensor_name': t.name, 'format': t.data_format}])
else:
# Value inputs carry their value
in_desc.append([{'data_type': t.dtype, 'value': t.value, 'name': '', 'shape': t.shape,
'tensor_name': t.name, 'format': t.data_format}])
# Describe the op's output
out_desc = [{'data_type': op.output.dtype, 'name': '', 'shape': op.output.shape,
'tensor_name': op.output.name, 'format': op.output.data_format}]
# and assemble the op description
op_desc.append({'attr': attrs, 'impl_path': '',
'input_desc': in_desc, 'name': op.prim, 'output_desc': out_desc})
# Assemble the graph description
graph_desc = {'composite': True, 'composite_graph': '', 'id': 0,
'input_desc': input_desc, 'op': self.name, 'op_desc': op_desc, 'output_desc': output_desc,
'platform': 'AKG', 'process': self.processor}
# Attach the stitch info when present:
if self.stitch_info and self.stitch_info.stitch_ops:
# the stitch ops,
buffer_stitch = {'stitch_op': list(self.stitch_info.stitch_ops)}
# and the stitch atomic ops, when present
if self.stitch_info.stitch_atomic_ops:
buffer_stitch['stitch_atomic_op'] = list(self.stitch_info.stitch_atomic_ops)
graph_desc['buffer_stitch'] = buffer_stitch
# Return the graph dict
return graph_desc
@ -506,13 +651,16 @@ class GraphVisitor:
"""Graph visitor"""
def __init__(self, forward=True):
# forward=True visits ops in topological order
self.forward = forward
def visit_graph(self, graph):
"""Visit graph"""
# Forward: visit ops in order
if self.forward:
for op in graph.ops:
self.visit(op)
# Backward: visit ops in reverse order
else:
for i in range(len(graph.ops)-1, -1, -1):
self.visit(graph.ops[i])
@ -528,12 +676,18 @@ class AlignShape(GraphVisitor):
def visit(op):
"""Visit op node"""
prim = PrimLib.get_prim(op)
# Shape alignment applies to ELEMWISE, BROADCAST and REDUCE ops
if prim.iter_type in (PrimLib.ELEMWISE, PrimLib.BROADCAST, PrimLib.REDUCE):
# The alignment rank starts at the output rank
out_dim = len(op.output.shape)
align_dim = out_dim
# and grows to the largest input rank
for t in op.inputs:
if len(t.shape) > align_dim:
align_dim = len(t.shape)
# Pad the output shape with leading 1s when needed
if align_dim > out_dim:
op.output.shape = [1] * (align_dim - out_dim) + op.output.shape

@ -25,90 +25,140 @@ class GraphBuilder:
"""Graph wrapper"""
def __init__(self, name):
"""
Initialize the wrapper.
Args:
name (str): name of the graph.
Attributes:
self.graph (Graph): the wrapped Graph, created with the given name.
"""
self.graph = Graph(name, [])
def set_input(self, *para):
"""set input to graph inputs"""
# Mark every given parameter as an input
for t in para:
t.para_type = Tensor.PARA_INPUT
# and record it in the graph inputs
self.graph.inputs.append(t)
def set_output(self, *para):
"""set output to graph inputs"""
# 遍历传入的参数
for t in para:
# 设置参数类型为输出参数
t.para_type = Tensor.PARA_OUTPUT
# 将参数添加到图的输出列表中
self.graph.outputs.append(t)
def __init__(self):
# Finished graphs
self.graphs = []
# No graph is under construction yet
self.current = None
# Counter used to generate tensor names
self.name_id = 0
def _alloc_tensor_name(self):
# Take the current id and advance the counter;
tid = self.name_id
self.name_id += 1
# tensor names look like "t0", "t1", ...
return "t%d" % (tid)
def graph_scope(self, name):
"""The graph scope to be processed"""
# Context manager wrapping the graph under construction
class GraphScope:
"""Graph Scope"""
def __init__(self, gb):
# Keep a reference to the owning GraphBuilder
self.gb = gb
def __enter__(self):
# Entering the scope exposes the current GraphWrapper
return self.gb.current
def __exit__(self, ptype, value, trace):
# Leaving the scope finalizes the graph
self.gb.graphs.append(self.gb.current.graph)
self.gb.current = None
# A graph must not already be under construction
if self.current is not None:
raise ValueError("self.current is not None!")
# Start building the named graph
self.current = self.GraphWrapper(name)
# and hand the scope to the caller
return GraphScope(self)
def tensor(self, shape, dtype, data_format="DefaultFormat", name=None, para_type=Tensor.PARA_NONE):
"""Create a new Tensor"""
"""创建一个新的张量"""
# 如果名称为空或None则分配一个新的张量名称
if name in (None, ''):
# 分配一个新的张量名称
name = self._alloc_tensor_name()
# 如果shape为空则默认设置为[1]
if not shape:
shape = [1]
# 返回创建好的张量对象
return Tensor(name, shape, dtype, data_format, para_type=para_type)
def value(self, dtype, value, name=None):
"""Create a new Value"""
# Allocate a name when none is given
if name in (None, ''):
name = self._alloc_tensor_name()
# Build the Value
v = Value(name, dtype, value)
# and return it
return v
def op(self, prim, output, inputs, attrs=None):
"""Insert an operator into graph"""
# Default attrs to an empty dict
if attrs is None:
attrs = {}
# A single Tensor input becomes a one-element list
if isinstance(inputs, Tensor):
inputs = [inputs]
# The Operator itself keeps only the Tensor inputs,
tensor_inputs = [t for t in inputs if isinstance(t, Tensor)]
node = Operator(prim, tensor_inputs, output, attrs)
# but all_inputs remembers every input, Values included
node.all_inputs = inputs
# Register the node in the graph under construction
self.current.graph.add(node)
def emit(self, prim, inputs, name=None, attrs=None):
"""Emit a new operation"""
# Default attrs to an empty dict
if attrs is None:
attrs = {}
# A single Tensor/Value input becomes a one-element list
if isinstance(inputs, (Tensor, Value)):
inputs = [inputs]
# Keep the Tensor and Value inputs
tensor_inputs = [t for t in inputs if isinstance(t, (Tensor, Value))]
# Infer the output shape, dtype and format,
out_shape, out_dtype, out_format = op_infer.infer(prim, tensor_inputs, attrs)
# create the output tensor,
output = self.tensor(out_shape, out_dtype, out_format, name)
# insert the operator,
self.op(prim, output, inputs, attrs)
# and return its output
return output
def get(self):
"""Get graphs"""
# Return the finished graphs
return self.graphs
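# A hypothetical end-to-end use of the builder ('Add' and the tensor shapes
# are assumptions for illustration; op_infer must know the emitted op):
#
#     gb = GraphBuilder()
#     with gb.graph_scope("simple_add") as g:
#         a = gb.tensor([16, 16], "float32", name="a")
#         b = gb.tensor([16, 16], "float32", name="b")
#         c = gb.emit("Add", [a, b])   # infers shape/dtype/format, inserts the op
#         g.set_input(a, b)
#         g.set_output(c)
#     graph = gb.get()[0]              # the finished Graph, ops in topo order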
@ -116,16 +166,21 @@ class CompositeGraph:
"""Composite Graph"""
def __init__(self):
# The loaded Graph, set by load()
self.graph = None
# The raw JSON description, set by load()
self.desc = None
self.tensors = {}  # name : Tensor
def refine(self):
"""Refine Graph"""
# Align shapes across the whole graph
AlignShape().visit_graph(self.graph)
def load(self, desc):
"""Load Graph from json"""
# Inner helper that extracts an op's attributes
def _attr_of(op):
if not op['attr']:
return dict()
@ -134,21 +189,29 @@ class CompositeGraph:
if a['name'] == 'axis' and op['name'] in ('ReduceSum', 'ReduceMax', 'ReduceMin', 'Argmax', 'Argmin'):
attr['reduce_axis'] = a['value']
else:
# Copy the attribute over
attr[a['name']] = a['value']
return attr
# Build the graph with a GraphBuilder,
builder = GraphBuilder()
# inside a scope named after the composite op
with builder.graph_scope(desc['op']):
# Build the input tensors from the input descriptions
for in_desc in desc['input_desc'] if desc['input_desc'] is not None else []:
name, shape, dtype, data_format = in_desc[0]['tensor_name'], in_desc[
0]['shape'], in_desc[0]['data_type'], in_desc[0]['format']
# and record them in the tensors table
self.tensors[name] = builder.tensor(
shape, dtype, data_format, name=name, para_type=Tensor.PARA_INPUT)
# Build the output tensors from the output descriptions
for out_desc in desc['output_desc']:
name, shape, dtype, data_format = out_desc['tensor_name'], out_desc[
'shape'], out_desc['data_type'], out_desc['format']
# and record them as well
self.tensors[name] = builder.tensor(
shape, dtype, data_format, name=name, para_type=Tensor.PARA_OUTPUT)
# Build the operators from the op descriptions
for op in desc['op_desc']:
inputs = [self.tensors[d['tensor_name']] for x in op['input_desc'] for d in x if 'value' not in d]
out_desc = op['output_desc']
@ -164,52 +227,77 @@ class CompositeGraph:
if not output:
output = builder.tensor(shape, dtype, data_format, name=name)
self.tensors[name] = output
# Insert the op into the graph
builder.op(op['name'], output, inputs, attrs=_attr_of(op))
# Keep the finished graph and the raw description
self.graph = builder.get()[0]
self.desc = desc
def add_stitch_info(self, subgraph, desc):
"""add stitch info to desc"""
# Only when the subgraph carries stitch ops:
if subgraph.stitch_info and subgraph.stitch_info.stitch_ops:
# record the stitch ops,
buffer_stitch = {'stitch_op': list(subgraph.stitch_info.stitch_ops)}
# and the stitch atomic ops when present
if subgraph.stitch_info.stitch_atomic_ops:
buffer_stitch['stitch_atomic_op'] = list(subgraph.stitch_info.stitch_atomic_ops)
# Attach the stitch info to the description
desc['buffer_stitch'] = buffer_stitch
return desc
def add_recompute_ops(self, subgraph, desc):
"""add recompute ops to desc"""
# When the subgraph has ops to recompute,
if subgraph.recompute_ops:
# record their output names in the description
desc['recompute_ops'] = [op.output.name for op in subgraph.recompute_ops]
return desc
def _pre_dump(self, outputs):
"""restore name to before load"""
# Map the y-input name of each InplaceAssign to its output name
inplace_assign = {}  # y_name : output_name
inplace_assign_z = None
# Walk the op descriptions
for op in self.desc['op_desc']:
# For every InplaceAssign op,
if op['name'] == 'InplaceAssign':
# record y's tensor name -> output tensor name
inplace_assign[op['input_desc'][1][0]['tensor_name']] = op['output_desc'][0]['tensor_name']
# When any InplaceAssign exists,
if inplace_assign:
# find the output that is not an InplaceAssign result
for t in outputs:
if t.name not in inplace_assign:
inplace_assign_z = t
# Return the mapping and that extra output
return inplace_assign, inplace_assign_z
def dump(self, subgraph):
"""Dump Graph to json"""
desc = {}
# Deduce the subgraph parameters
inputs, outputs = subgraph.deduce_parameters()
# All ops in the subgraph
graph_ops = set(subgraph.ops)
# Restore the pre-load names for the outputs
inplace_assign, inplace_assign_z = self._pre_dump(outputs)
def dump_output(t):
# An output written by InplaceAssign is dumped under its mapped name
if t.name in inplace_assign:
z = inplace_assign_z if inplace_assign_z is not None else self.tensors[t.name]
return {'data_type': z.dtype, 'shape': z.shape, 'tensor_name': inplace_assign.get(t.name)}
# Ordinary outputs are dumped as-is
return {'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name}
def dump_op_desc(d):
# InplaceAssign ops need special handling
if d['name'] == 'InplaceAssign':
y = d['input_desc'][1][0]['tensor_name']
if self.tensors[y].op in graph_ops:
@ -222,33 +310,50 @@ class CompositeGraph:
z_desc['tensor_name'] = z.name
out_desc['shape'] = z.shape
out_desc['data_type'] = z.dtype
# Return the rewritten InplaceAssign description
return inplace_desc
# The op that produced this output
op = self.tensors[d['output_desc'][0]['tensor_name']].op
# Keep descriptions of ops that are in the subgraph or recomputed,
if op in graph_ops or op in subgraph.recompute_ops:
return d
# drop everything else
return None
for key in self.desc.keys():
if key == 'input_desc':
# Dump the inputs
desc[key] = [[{'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name}] for t in inputs]
elif key == 'output_desc':
# Dump the outputs
desc[key] = list(map(dump_output, outputs))
elif key == 'op_desc':
# Dump the op descriptions, dropping the filtered ones
op_desc = map(dump_op_desc, self.desc[key])
desc[key] = [d for d in op_desc if d is not None]
elif key == 'op':
# Use the subgraph name as the op name
desc[key] = subgraph.name
else:
# Copy everything else verbatim
desc[key] = self.desc[key]
# Attach the stitch info
desc = self.add_stitch_info(subgraph, desc)
# and the recompute info
desc = self.add_recompute_ops(subgraph, desc)
# then return the final description
return desc
def load_composite(desc):
"""Load composite kernel"""
# Build a CompositeGraph,
composite = CompositeGraph()
# load the description,
composite.load(desc)
# and refine the loaded graph
composite.refine()
return composite

@ -25,24 +25,32 @@ def infer(op_name, inputs, attrs):
"""infer shape dtype and format"""
def _create_opinfer():
# Locate this module,
self_module = sys.modules.get(__name__, None)
# which must exist
if self_module is None:
raise GKException("OpInfo does not support op {}".format(op_name))
# An op-specific infer class takes priority
if hasattr(self_module, op_name):
op_cls = getattr(self_module, op_name)
return op_cls(op_name, inputs, attrs)
# common infer
# Map the common iter types to their generic infer classes
class_name_map = {
PrimLib.ELEMWISE: "_Elemwise",
PrimLib.REDUCE: "_Reduce",
}
# Look up the class name for this op's iter type
cls_name = class_name_map.get(PrimLib.primtives.get(op_name, PrimLib.default_primtive).iter_type, None)
# Unsupported iter types raise
if not cls_name:
raise GKException("OpInfo does not support op {}".format(op_name))
# Instantiate the generic infer class
op_cls = getattr(self_module, cls_name)
return op_cls(op_name, inputs, attrs)
# Build the op infer and run the inference
return _create_opinfer().infer()
@ -55,49 +63,65 @@ class OpInfer:
"""
def __init__(self, name, inputs, attrs):
# Record the op name, inputs and attributes
self.name = name
self.inputs = inputs
self.attrs = attrs
def infer(self):
"""Infer shape, type and format by op inputs"""
# Validate, then infer shape, type and format from the inputs
self._check()
return self._infer_shape(), self._infer_type(), self._infer_format()
def _infer_shape(self):
# Default: the shape of the first input
return self.inputs[0].shape
def _infer_type(self):
# Default: the dtype of the first input
return self.inputs[0].dtype
def _infer_format(self):
# Default: the data format of the first input
return self.inputs[0].data_format
def _check(self):
# Run the shape, type and format checks
self._check_shape()
self._check_type()
self._check_format()
def _check_shape(self):
# No shape check by default
pass
def _check_type(self):
"""check all dtypes are same"""
# The first input's dtype is the reference
dtype = self.inputs[0].dtype
# Every other input must match it
for i, t in enumerate(self.inputs[1:]):
if t.dtype != dtype:
raise GKException(
"Incompatible data type between input {}({}) and {}({})".format(0, dtype, i + 1, t.dtype))
def _check_format(self):
"""check formats are compatible. only DefaultFormat is compatible with others"""
# The first input's format is the reference
result = self.inputs[0].data_format
# i tracks which input currently defines the reference format
i = 0
# Walk the remaining inputs
for j, t in enumerate(self.inputs[1:]):
# A mismatch is only allowed when one side is DefaultFormat
if t.data_format != result:
if DF.DEFAULT not in (result, t.data_format):
raise GKException("Incompatible format between input {}({}) and {}({})".format(
i, result, j + 1, t.data_format))
# When the reference is DefaultFormat, adopt the concrete format
if result == DF.DEFAULT:
result = t.data_format
i = j + 1
@ -109,17 +133,26 @@ class _Elemwise(OpInfer):
@staticmethod
def broadcast_shape(shapes):
"""deduce broadcast shape using same rules as numpy"""
# The output rank is the largest input rank
dim_size = max(len(shape) for shape in shapes)
# Left-pad every shape with 1s up to that rank
align_shapes = [[1] * (dim_size - len(shape)) + shape for shape in shapes]
# Start from an all-1 output shape
out_shape = [1] * dim_size
# For every axis
for i in range(dim_size):
# and every shape:
for align_shape in align_shapes:
# a dim of 1 never constrains the output,
if align_shape[i] == 1:
continue
# the first non-1 dim fixes the output dim,
if out_shape[i] == 1:
out_shape[i] = align_shape[i]
# and any other non-1 dim must match it
elif out_shape[i] != align_shape[i]:
raise GKException("Input shapes {} can not broadcast.".format(shapes))
# Return the broadcast shape
return out_shape
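# The same rule on illustrative shapes: ranks are right-aligned, dims of 1
# stretch, and any other mismatch raises.
#
#     _Elemwise.broadcast_shape([[8, 1, 6], [7, 1], [8, 7, 6]])  # -> [8, 7, 6]
#     _Elemwise.broadcast_shape([[3, 4], [2, 4]])                # raises GKException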
@staticmethod
@ -174,9 +207,12 @@ class _Elemwise(OpInfer):
.format(inputs_format))
def _infer_format(self):
# The first non-default input format wins,
for tensor in self.inputs:
if tensor.data_format != DF.DEFAULT:
return tensor.data_format
# otherwise fall back to DefaultFormat
return DF.DEFAULT
@ -184,6 +220,7 @@ class _Reduce(OpInfer):
"""Common infer for reduction operators"""
def _check(self):
# Run the base checks first
super(_Reduce, self)._check()
# check reduce axis in the range [-len, len)
shape_len = len(self.inputs[0].shape)
@ -195,21 +232,29 @@ class _Reduce(OpInfer):
"Reduce axis should be in range [{},{}) but got {}".format(-shape_len, shape_len, axis))
def _infer_shape(self):
# Work on a copy of the input shape
shape = copy.deepcopy(self.inputs[0].shape)
# The reduce axes
axis = self.attrs['reduce_axis']
# A single axis becomes a one-element list
if isinstance(axis, int):
axis = [axis]
# Normalize negative axes to non-negative ones
if any(i < 0 for i in axis):
axis = list(map(lambda i: i + len(shape) if i < 0 else i, axis))
# and keep them sorted
self.attrs['reduce_axis'] = sorted(axis)
# With keep_dims, reduced axes collapse to 1
if self.attrs['keep_dims']:
for i in axis:
shape[i] = 1
return shape
# Without keep_dims, reduced axes are removed from the shape
real_shape = []
for i, s in enumerate(shape):
if i not in axis:
@ -223,10 +268,14 @@ class _Reduce(OpInfer):
class _Reshape(OpInfer):
"""Common infer for reshape operators, should not be instantiated"""
# Shape inference is left to the concrete subclass
def _infer_shape(self):
raise GKException("_infer_shape should be implemented by subclass")
# The format comes from the "format" attribute when present, else DefaultFormat
def _infer_format(self):
return DF.DEFAULT if "format" not in self.attrs else self.attrs["format"]
@ -234,14 +283,20 @@ class Reshape(_Reshape):
"""Reshape op infer"""
def _check_shape(self):
# The element count of the input shape
input_shape = self.inputs[0].shape
# must match that of the requested output shape
output_shape = self.attrs["shape"]
size_before_reshape = prod_reduce(lambda x, y: x * y, input_shape)
size_after_reshape = prod_reduce(lambda x, y: x * y, output_shape)
if size_before_reshape != size_after_reshape:
raise GKException("For 'Reshape', can not reshape {} to {}".format(input_shape, output_shape))
def _infer_shape(self):
# The inferred shape is the "shape" attribute
return self.attrs["shape"]
@ -249,6 +304,7 @@ class Cast(_Elemwise):
"""Cast op infer"""
def _infer_type(self):
# Cast outputs the "dst_type" attribute
return self.attrs["dst_type"]
@ -256,29 +312,38 @@ class InplaceAssign(_Elemwise):
"""InplaceAssign op infer"""
def _infer_shape(self):
# InplaceAssign mirrors its third input: shape,
return self.inputs[2].shape
def _infer_type(self):
# dtype,
return self.inputs[2].dtype
def _infer_format(self):
# and data format all come from inputs[2]
return self.inputs[2].data_format
class BroadcastTo(OpInfer):
"""BroadcastTo op infer"""
# The output shape is the "shape" attribute
def _infer_shape(self):
return self.attrs["shape"]
# The format follows the first input
def _infer_format(self):
return self.inputs[0].data_format
class _CompareOp(_Elemwise):
"""Compare operators"""
# Comparison operators always output bool
def _infer_type(self):
return "bool"
@ -286,11 +351,14 @@ class CImag(OpInfer):
"""CImag op infer"""
def _check_type(self):
# The input must be complex64
if self.inputs[0].dtype != "complex64":
raise GKException(
"For 'CImag', input[0] should be of type complex64, but got {}".format(self.inputs[0].dtype))
def _infer_type(self):
# and the output is float32
return "float32"
@ -298,11 +366,14 @@ class CReal(OpInfer):
"""CReal op infer"""
def _check_type(self):
# The input must be complex64
if self.inputs[0].dtype != "complex64":
raise GKException(
"For 'CReal', input[0] should be of type complex64, but got {}".format(self.inputs[0].dtype))
def _infer_type(self):
# and the output is float32
return "float32"
@ -310,14 +381,17 @@ class Complex(OpInfer):
"""Complex op infer"""
def _check_type(self):
# Both inputs must be float32
if self.inputs[0].dtype != "float32":
raise GKException(
"For 'Complex', input[0] should be of type float32, but got {}".format(self.inputs[0].dtype))
# and agree with each other
if self.inputs[0].dtype != self.inputs[1].dtype:
raise GKException("For 'Complex', inputs data type mismatch ({} vs {})"
.format(self.inputs[0].dtype, self.inputs[1].dtype))
def _infer_type(self):
# The output is complex64
return "complex64"
@ -345,40 +419,53 @@ class Select(_Elemwise):
"""Select op infer"""
def _check_type(self):
# The condition input must be bool
if self.inputs[0].dtype != "bool":
raise GKException("For 'Select', input[0] should be of type bool, but got {}".format(self.inputs[0].dtype))
# and the two value inputs must share a dtype
if self.inputs[1].dtype != self.inputs[2].dtype:
raise GKException("For 'Select', input[1] and input[2] data type mismatch ({} vs {})"
.format(self.inputs[1].dtype, self.inputs[2].dtype))
def _infer_type(self):
# The output dtype follows the value inputs
return self.inputs[1].dtype
def check_format_any(formats, checked_format):
"""Check whether input format in formats list"""
# formats must be a list or tuple
if not isinstance(formats, (list, tuple)):
raise GKException("formats {} should be of type list or tuple, but got {}.".format(formats, type(formats)))
# and must contain checked_format
if checked_format not in formats:
raise GKException("Check {} failed: can not find it in {}".format(checked_format, formats))
def check_nd(data, nd):
"""Check whether data are nd format"""
# data must be a list or tuple of exactly nd elements
if not isinstance(data, (list, tuple)) or len(data) != nd:
raise GKException("input should be {}D list or tuple, but got {}.".format(nd, data))
def conv_had_pad(pad_list, pad_mode):
"""Check whether conv need to add pad"""
# pad_list must be a 4D list or tuple
if not isinstance(pad_list, (list, tuple)) or len(pad_list) != 4:
raise GKException("pad_list should be 4D list or tuple, but got {}".format(pad_list))
# Asymmetric padding always requires an explicit pad
if pad_list[0] != pad_list[1] or pad_list[2] != pad_list[3]:
return True
# Outside VALID mode, any nonzero pad requires one as well
if pad_mode not in ["VALID", "valid"]:
for _, pad in enumerate(pad_list):
if pad != 0:
return True
# Otherwise no pad is needed
return False
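# A few illustrative calls of the decision above:
#
#     conv_had_pad([0, 0, 0, 0], "valid")  # -> False: symmetric, all-zero pad
#     conv_had_pad([1, 0, 0, 0], "valid")  # -> True: asymmetric top/bottom pad
#     conv_had_pad([1, 1, 1, 1], "same")   # -> True: nonzero pad outside VALID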
@ -386,38 +473,50 @@ class Conv2D(OpInfer):
"""Conv2D infer"""
def _infer_type(self):
# Prefer the "dst_type" attribute, else the first input's dtype
if isinstance(self.attrs, dict) and "dst_type" in self.attrs:
return self.attrs["dst_type"]
return self.inputs[0].dtype
def _infer_shape(self):
# Shapes of the data and the weight,
shape_0 = list(self.inputs[0].shape)
shape_1 = list(self.inputs[1].shape)
# both of which must be 4D
check_nd(shape_0, 4)
check_nd(shape_1, 4)
# Every format involved must be NHWC
formats = [self.inputs[0].data_format, self.inputs[1].data_format, self.attrs["format"]]
check_format_any(formats, DF.NHWC)
# Batch, spatial dims and output channels
n, h, w, out_channel = shape_0[0], shape_0[1], shape_0[2], shape_1[0]
# Padding configuration
pad_list = self.attrs["pad_list"]
pad_mode = self.attrs["pad_mode"]
# Kernel, stride and dilation
kernel_size = self.attrs["kernel_size"]
stride = self.attrs["stride"]
dilation = self.attrs["dilation"]
# Validate the attribute ranks (4, 2, 4 and 4 respectively)
check_nd(pad_list, 4)
check_nd(kernel_size, 2)
check_nd(stride, 4)
check_nd(dilation, 4)
# Decide whether padding applies;
has_pad = conv_had_pad(pad_list, pad_mode)
# without padding the pad list is all zeros
if not has_pad:
pad_list = [0, 0, 0, 0]
# Effective kernel size and output spatial dims
k_h = (kernel_size[0] - 1) * dilation[-2] + 1
k_w = (kernel_size[1] - 1) * dilation[-1] + 1
out_h = (h + pad_list[0] + pad_list[1] - k_h) // stride[-2] + 1
out_w = (w + pad_list[2] + pad_list[3] - k_w) // stride[-1] + 1
# The NHWC output shape
return [n, out_h, out_w, out_channel]
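# Worked example of the output-size arithmetic above (illustrative numbers):
# a 224x224 NHWC input with a 3x3 kernel, stride 2, dilation 1 and no padding:
#
#     k_h   = (3 - 1) * 1 + 1 = 3
#     out_h = (224 + 0 + 0 - 3) // 2 + 1 = 111
#
# so a [1, 224, 224, 3] input and 64 output channels give [1, 111, 111, 64].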
@ -425,23 +524,31 @@ class MatMul(OpInfer):
"""MatMul infer"""
def _infer_type(self):
# Prefer the "dst_type" attribute, else the first input's dtype
if isinstance(self.attrs, dict) and "dst_type" in self.attrs:
return self.attrs["dst_type"]
return self.inputs[0].dtype
def _infer_shape(self):
# Both input shapes,
shape_0 = list(self.inputs[0].shape)
shape_1 = list(self.inputs[1].shape)
# which must be 2D
if len(shape_0) != 2 or len(shape_1) != 2:
raise GKException("For 'MatMul', inputs shape must be 2D, but got {}, {}"
.format(shape_0, shape_1))
# Transposition flags
transpose_a = self.attrs["transpose_a"]
transpose_b = self.attrs["transpose_b"]
# Derive (m, k1) and (k2, n) according to the transposes
m, k1 = (shape_0[-1], shape_0[-2]) if transpose_a else (shape_0[-2], shape_0[-1])
k2, n = (shape_1[-1], shape_1[-2]) if transpose_b else (shape_1[-2], shape_1[-1])
# The contraction dims must agree
if k1 != k2:
raise GKException("For 'MatMul', inputs have different k value: {} vs {}".format(k1, k2))
# The output shape is [m, n]
output_shape = [m, n]
return output_shape
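# The transpose handling above, on illustrative shapes: with A = [64, 32],
# transpose_a=True and B = [64, 128], transpose_b=False,
#
#     m, k1 = 32, 64    # A is read as (k, m)
#     k2, n = 64, 128   # B is read as (k, n)
#
# k1 == k2 == 64, so the inferred output shape is [32, 128].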
@ -449,14 +556,20 @@ class PadAkg(OpInfer):
"""PadAkg infer"""
def _infer_shape(self):
# Input shape and rank
shape = list(self.inputs[0].shape)
n = len(shape)
# Leading and trailing pad,
pad_before = list(self.attrs["head"])
pad_after = list(self.attrs["tail"])
# which must both match the input rank
if len(pad_before) != n or len(pad_after) != n:
raise GKException("For 'PadAkg', input dimension and pad mismatch: {}d vs {}d vs {}d"
.format(n, len(pad_before), len(pad_after)))
# Output shape: input plus pad on both ends
out_shape = [shape[i] + pad_before[i] + pad_after[i] for i in range(n)]
return out_shape
@ -464,13 +577,19 @@ class UnPadAkg(OpInfer):
"""UnPadAkg infer"""
def _infer_shape(self):
# Input shape and rank
shape = list(self.inputs[0].shape)
n = len(shape)
# Trailing unpad amounts,
unpad_after = list(self.attrs["tail"])
# which must match the input rank
if len(unpad_after) != n:
raise GKException("For 'UnPadAkg', input dimension and pad mismatch: {}d vs {}d"
.format(n, len(unpad_after)))
# Output shape: input minus the trailing unpad
out_shape = [shape[i] - unpad_after[i] for i in range(n)]
return out_shape
@ -478,23 +597,32 @@ class Gather(OpInfer):
"""Gather infer"""
def _infer_shape(self):
# Shapes of the table and the indices
input_shape = self.inputs[0].shape
indices_shape = self.inputs[1].shape
# The gather axis
axis = self.attrs['axis']
# Start from the input shape
output_shape = input_shape
# Flatten the indices into a single dimension,
indices_shape_one_dim = 1
for dim in indices_shape:
indices_shape_one_dim *= dim
# which replaces the gathered axis
output_shape[axis] = indices_shape_one_dim
# Return the output shape
return output_shape
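# An illustrative instance of the rule above: a [5, 4] table gathered with
# [2, 3] indices on axis 0 flattens the indices to 2 * 3 = 6 entries, so the
# inferred output shape is [6, 4].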
def _infer_type(self):
# dtype follows the table input,
return self.inputs[0].dtype
def _infer_format(self):
# and so does the data format
return self.inputs[0].data_format
def _check_type(self):
# The indices must be int32
if self.inputs[1].dtype != "int32":
raise GKException("For 'Gather', inputs[1] should be of type int32, but got {}"
.format(self.inputs[1].dtype))

@ -21,24 +21,48 @@ from . import model
def estimate_ops(json_str):
"""
估计操作数
Args:
json_str (str): 包含图描述的json字符串
Returns:
tuple: 包含估计结果的元组包括块分配增益融合类型和类型信息的元组
Raises:
JSONDecodeError: 如果输入的json字符串无法解码将引发此异常
"""
"""Call cost model to estimate ops."""
try:
# Parse the JSON string
json_obj = json.loads(json_str)
# and take its graph descriptions
graph_descs = json_obj["graph_desc"]
# Graphs to estimate and their common target
graphs = []
target = None
# Walk the graph descriptions:
for gd in graph_descs:
# the first one fixes the target,
if target is None:
target = gd['process']
# every other one must match it
elif target != gd['process']:
logger.error("Parallel fusion does not support multi-target({} and {})".format(target, gd['process']))
return None
# Load each composite graph
graphs.append(model.load_composite(gd).graph)
# Run the parallel estimation
estimation = model.parallel_estimate(graphs, target)
# and collect its results
res = (estimation.block_assign, estimation.gain,
estimation.fusion_type, estimation.type_info)
return res
except jd.JSONDecodeError:
# Log JSON decoding failures and return None
logger.error(traceback.format_exc())
return None
finally:
@ -46,14 +70,33 @@ def estimate_ops(json_str):
def estimate_calculation_amount(json_str):
"""
估计操作计算量的函数
Args:
json_str (str): 包含操作描述的JSON字符串
Returns:
int: 计算量的估计值如果解析JSON字符串失败则返回-1
Raises:
"""
"""Call cost model to estimate calculation amount of op."""
try:
# Parse the JSON string
graph_desc = json.loads(json_str)
# and take its target process
target = graph_desc['process']
# Load the composite graph
comp = model.load_composite(graph_desc)
# and run the parallel estimation;
estimation = model.parallel_estimate([comp.graph], target)
# the bottleneck is the calculation amount
return estimation.bottleneck
except jd.JSONDecodeError:
# Log JSON decoding failures and return -1
logger.error(traceback.format_exc())
return -1
finally:

@ -24,6 +24,20 @@ from . import utils
def split_with_json(json_str, flags_str):
"""
根据JSON字符串分割GraphKernel
Args:
json_str (str): 包含GraphKernel描述的JSON字符串
flags_str (str): 包含分割标志的JSON字符串
Returns:
str: 包含分割结果的JSON字符串
Raises:
jd.JSONDecodeError: 如果json_str或flags_str无法被解析为JSON格式将引发此异常
"""
"""Call cost model to split GraphKernel"""
try:
graph_desc = json.loads(json_str)
@ -45,6 +59,21 @@ def split_with_json(json_str, flags_str):
def _reset_graphmode_for_inplaceassign(graph_list, graph_mode):
"""
重置具有 InplaceAssign 操作符的图模式
Args:
graph_list (list): 包含图的列表每个图都是一个包含操作描述的字典
graph_mode (list): 图模式列表每个元素表示对应图的模式
Returns:
None
Notes:
具有 InplaceAssign 操作符的操作应始终为复合操作
对于包含 InplaceAssign 操作符的图将其模式设置为 'composite'
"""
"""Operator with InplaceAssign should always be composite op"""
for i, g in enumerate(graph_list):
if any((op['name'] == 'InplaceAssign' for op in g['op_desc'])):
@ -52,6 +81,20 @@ def _reset_graphmode_for_inplaceassign(graph_list, graph_mode):
def _dump_split_info(flags, graph_json, graph_desc, subgraphs, graph_mode):
"""
将分割信息以文本形式输出
Args:
flags (dict): 包含配置信息的字典
graph_json (str): 图结构的JSON字符串
graph_desc (object): 图描述对象
subgraphs (list): 子图列表
graph_mode (list): 图模式列表
Returns:
None
"""
"""Dump split info as text"""
if not flags.get("dump_as_text", False):
return

@ -19,6 +19,19 @@ GRAPH_KERNEL_DUMP_PATH = "graph_kernel_dump"
def create_dir(pathname):
"""
尝试创建目录
Args:
pathname (str): 要创建的目录的路径
Returns:
None
Raises:
不显式抛出异常
"""
"""Try to create directory"""
if os.path.exists(pathname):
return

@ -16,4 +16,8 @@
Extension functions.
Python functions that will be called in the c++ parts of MindSpore.
"""

@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""akg process"""
import os
import json
import subprocess
@ -24,10 +26,10 @@ from mindspore._extends.parallel_compile.akg_compiler.get_file_path import get_a
def _compile_akg_task_default(json_strs, attrs):
"""
compile func called in single process
Parameters:
json_strs: list. List contains multiple kernel infos, suitable for json compile api.
"""
sys.path.insert(0, get_akg_path())
@ -37,15 +39,15 @@ def _compile_akg_task_default(json_strs, attrs):
for json_str in json_strs:
res = func(json_str, attrs)
if not res:
raise ValueError("Compile error, args: {}! build attrs: {}".format(json_str, attrs))
raise ValueError("编译错误,参数:{}!构建属性:{}".format(json_str, attrs))
def _compile_akg_task_ascend(json_strs, attrs):
"""
compile func called in single process
Parameters:
json_strs: list. List contains multiple kernel infos, suitable for json compile api.
"""
if attrs is None:
attrs = "{}"
@ -56,35 +58,33 @@ def _compile_akg_task_ascend(json_strs, attrs):
if compile_result.returncode:
json_dict = json.loads(json_str)
if not json_dict.get("composite"):
raise ValueError("Compile error, json str: {}! build attrs: {}".format(json_str, attrs))
logger.debug("Will try to split, json str: {}! build attrs: {}".format(json_str, attrs))
raise ValueError("编译错误json字符串{}!构建属性:{}".format(json_str, attrs))
logger.debug("将尝试拆分json字符串{}!构建属性:{}".format(json_str, attrs))
def create_akg_parallel_process(process_num, wait_time, platform):
"""
create AkgParallelCompiler object
Returns:
AkgParallelCompiler
"""
return AkgProcess(process_num, wait_time, platform)
class AkgProcess:
"""akg kernel parallel process"""
"""akg内核并行进程"""
def __init__(self, process_num, wait_time, platform):
"""
Args:
process_num: int. processes number
wait_time: int. max time the function blocked
参数
process_numint进程数量
wait_timeint函数阻塞的最大时间
"""
if not isinstance(process_num, int):
raise ValueError("AKG kernel compiling process number must be of type int, but got {} with type {}"
.format(process_num, type(wait_time)))
raise ValueError("AKG内核编译进程数量必须是int类型但得到的是{},类型为{}".format(process_num, type(wait_time)))
if not isinstance(wait_time, int):
raise ValueError("AKG kernel compiling wait time must be of type int, but got {} with type {}"
.format(wait_time, type(wait_time)))
raise ValueError("AKG内核编译等待时间必须是int类型但得到的是{},类型为{}".format(wait_time, type(wait_time)))
if process_num == 0:
process_num = 1
max_proc_num = 16
@ -96,13 +96,12 @@ class AkgProcess:
def compile(self, attrs=None):
"""
compile kernel by multi processes
Return:
True for all compile success, False for some failed.
"""
if self.argc == 0:
raise ValueError("In AKG kernel compiling, the number of kernel json that need to be compiled can "
"not be zero.")
args = list((arg, attrs) for arg in self.args)
if self.platform == "ASCEND":
with Pool(processes=self.process_num) as pool:
@ -116,12 +115,11 @@ class AkgProcess:
def accept_json(self, json_str):
"""
accept json data before compile
Args:
json_str: str. kernel info.
"""
if not isinstance(json_str, str):
raise ValueError("In AKG kernel compiling, the kernel json must be of type str, but got {} with type {}"
.format(json_str, type(json_str)))
self.args[self.argc % self.process_num].append(json_str)
self.argc += 1

@ -28,16 +28,23 @@ def run_compiler(op_json, attrs=None):
None
"""
from get_file_path import get_akg_path
# Make akg importable
sys.path.insert(0, get_akg_path())
# Import the akg module
p = __import__("akg", globals(), locals(), ['ms'], 0)
# and grab its compilewithjson entry point
func = getattr(p.ms, "compilewithjson")
# Compile the op json
res = func(op_json, attrs)
# and fail loudly on error
if not res:
raise ValueError("Compile error")
if __name__ == "__main__":
# With more than two CLI arguments, pass attrs along as well,
if len(sys.argv) > 2:
run_compiler(sys.argv[1], sys.argv[2])
# otherwise compile with the op json only
else:
run_compiler(sys.argv[1])

@ -19,18 +19,27 @@ import os
def get_akg_path():
"""get akg directory base path"""
# Hint shown when the mindspore module cannot be located
hint = "Please check: 1) whether MindSpore is compiled successfully. " \
"2) Whether MindSpore is installed successfully with pip install or " \
"the path ${mindspore_build_dir}/package is set in env PYTHONPATH."
# Locate the mindspore module,
search_res = importlib.util.find_spec("mindspore")
if search_res is None:
raise RuntimeError("Cannot find mindspore module! {}".format(hint))
# take its origin path,
res_path = search_res.origin
# which must point at its __init__.py
find_pos = res_path.find("__init__.py")
if find_pos == -1:
raise RuntimeError("Find module mindspore origin file failed! {}".format(hint))
# Derive the sibling akg path,
akg_path = "{}_akg".format(res_path[:find_pos])
# which must exist on disk
if not os.path.isdir(akg_path):
raise RuntimeError("Cannot find akg from mindspore module! {}".format(hint))
# Return the akg path
return akg_path
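# The lookup technique above, in isolation: importlib locates the installed
# package and a sibling directory is derived from its __init__.py path
# (the paths shown are illustrative):
#
#     import importlib.util
#     spec = importlib.util.find_spec("mindspore")
#     spec.origin                  # e.g. ".../site-packages/mindspore/__init__.py"
#     # stripping "__init__.py" and appending "_akg" -> ".../mindspore_akg"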

@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""tbe adapter to adapt te/topi/auto-tune python api """
# Standard library imports
import json
import os
import shutil
@ -20,33 +21,62 @@ import sys
import traceback
from datetime import datetime
# TBE-related imports
from tbe.common.rl_bank.bank_manager import set_current_op_name
from tbe.common.repository_manager.interface import cann_kb_unload, cann_kb_load
from tbe.common.rl_bank.bank_cfg import LocalLock
from te.platform.cce_conf import te_set_version
from te.platform.cce_policy import set_L1_info
from te_fusion.compile_task_manager import (
dispatch_prebuild_task,
dispatch_single_op_compile_task,
import_py_module,
dispatch_fusion_op_compile_task,
dispatch_autotune_task,
sync_op_tune_params,
sync_syspath
)
from te_fusion.fusion_manager import (
call_op_func,
clear_fusion_params,
check_op_impl_mode,
save_op_params,
build_single_op_from_c,
op_params_to_json
)
from te_fusion.fusion_util import dump_fusion_json
from te_fusion.parallel_compilation import (
init_multi_process_env,
start_ga_multi_process,
deinit_multi_process_env,
get_finished_compilation_task
)
from .tbe_helper import (
get_soc_info,
assemble_op_args,
get_compute_op_list,
get_options_info,
get_fuzz_build_info,
adjust_custom_op_info,
pack_op_args,
get_module_name,
get_real_op_debug_level
)
from .tbe_job import TbeJob, JobStatus
PLATFORM_FLAG = ["Ascend310", "Ascend910", "Hi3796CV300ES", "Ascend710", "Ascend610", "Hi3796CV300CS", "SD3403"]
# 定义支持的平台标志
PLATFORM_FLAG = [
"Ascend310", "Ascend910", "Hi3796CV300ES", "Ascend710", "Ascend610", "Hi3796CV300CS", "SD3403"
]
def _tune_init(job: TbeJob):
"""
Tune initialize
:param job: TbeJob object carrying the task info
:return: whether initialization succeeded
"""
# Pull the SoC and tune settings from the job content
auto_tiling_mode = job.content["SocInfo"]["autoTilingMode"]
offline_tune = job.content["SocInfo"]["offlineTune"]
op_bank_update = job.content["SocInfo"]["op_bank_update"]
@ -54,11 +84,14 @@ def _tune_init(job: TbeJob):
tune_bank_path = job.content["TuneInfo"]["tune_bank_path"]
need_ga = bool("GA" in auto_tiling_mode)
need_rl = bool("RL" in auto_tiling_mode)
# Set the environment variables accordingly
if offline_tune:
os.environ["ENABLE_TUNE_DUMP"] = "TRUE"
if op_bank_update:
sync_op_tune_params("tbe.common.tiling.tiling_api", "reset_repository", False, "")
# Initialize the tune environment when GA, RL or offline tuning is requested
if need_ga or need_rl or offline_tune:
res = __init_tune_env(job, need_ga)
if not res:
@ -66,6 +99,7 @@ def _tune_init(job: TbeJob):
else:
return True
# Set the tune dump/bank paths
if tune_dump_path:
os.environ["TUNE_DUMP_PATH"] = str(tune_dump_path)
if tune_bank_path:
@ -73,12 +107,12 @@ def _tune_init(job: TbeJob):
res = _creating_custom_path(job)
return res
def _cann_kb_load(job: TbeJob):
"""
Load the CANN knowledge base
:param job: TbeJob object carrying the task info
:return: whether loading succeeded
"""
soc_version = job.soc_version
core_num = job.core_num
@ -87,12 +121,12 @@ def _cann_kb_load(job: TbeJob):
res = cann_kb_load(soc_version, core_num, op_bank_path, kb_type)
return res
def _cann_kb_unload(job: TbeJob):
"""
Unload the CANN knowledge base
:param job: TbeJob object carrying the task info
:return: whether unloading succeeded
"""
if job is None:
return 0
@ -102,12 +136,12 @@ def _cann_kb_unload(job: TbeJob):
res = cann_kb_unload(soc_version, core_num, kb_type)
return res
def _remove_cache(job: TbeJob):
"""
Remove cache files [*.json, *.o, *.info, *.cce] when "op_debug_level" is "0"
(op_debug_level mirrors the env MS_COMPILER_OP_LEVEL)
:param job: TbeJob object carrying the task info
:return:
"""
op_debug_level = job.content["SocInfo"]["op_debug_level"]
op_debug_dir = job.content["SocInfo"]["op_debug_dir"]
@ -118,24 +152,30 @@ def _remove_cache(job: TbeJob):
real_path = os.path.join(root_path, "kernel_meta/")
shutil.rmtree(real_path)
def __directory_creation(path, concat_path):
"""
Create directory
:param path: base path
:param concat_path: path component to append
:return: the resulting full path
"""
path = os.path.join(path, concat_path)
if not os.path.isdir(path):
os.makedirs(path, 0o750)
return path
def __init_tune_env(job, need_ga):
"""
Initialize the tune env
:param job: TbeJob object carrying the task info
:param need_ga: whether GA tuning is needed
:return: whether initialization succeeded
"""
try:
import auto_tune.auto_tune_main as at_atm
from schedule_search.rl_online_tune import rl_tune_init  # pylint: disable=unused-import
if need_ga:
res = at_atm.ga_tune_init()
if not res:
@ -157,10 +197,13 @@ def __init_tune_env(job, need_ga):
finally:
pass
def __creating_default_custom_path(auto_tiling_mode, base_custom_path):
"""
Create the default custom path
:param auto_tiling_mode: auto tiling mode
:param base_custom_path: base custom path
:return:
"""
base_custom_path = __directory_creation(base_custom_path, "data")
tune_flag = []
@ -179,27 +222,40 @@ def __creating_default_custom_path(auto_tiling_mode, base_custom_path):
def _creating_custom_path(job):
"""
Create the custom path used to store and retrieve tuning parameters of custom ops
Args:
job (TbeJob): TbeJob object carrying the task info
Returns:
bool: whether the custom path was created successfully
"""
# The auto tiling mode
auto_tiling_mode = job.content["SocInfo"]["autoTilingMode"]
# NO_TUNE needs no custom path
if "NO_TUNE" in auto_tiling_mode:
return True
# Base path for the tuning bank
base_custom_path = job.content["TuneInfo"]["tune_bank_path"]
tune_bank_flag = True
# When unset, derive it from the location of the auto_tune module
if not base_custom_path:
import auto_tune
base_custom_path = os.path.dirname(os.path.realpath(auto_tune.__file__))
base_custom_path = os.path.realpath(os.path.join(base_custom_path, "../../../"))
tune_bank_flag = False
# The base path must exist
if not os.path.isdir(base_custom_path):
job.error("Check whether the tuning path [{}] exists.".format(base_custom_path))
return False
# and be readable, writable and executable
if not os.access(base_custom_path, os.R_OK | os.W_OK | os.X_OK):
job.error("Check whether the permission on the tuning path [{}] is correct.".format(base_custom_path))
return False
# Without a tuning bank, fall back to the default custom path
if not tune_bank_flag:
return __creating_default_custom_path(auto_tiling_mode, base_custom_path)
return True
@ -207,22 +263,34 @@ def _creating_custom_path(job):
def _parallel_compilation_init(initialize: TbeJob):
"""
Tbe parallel compilation initialize
:param initialize: TbeJob object carrying the task info
:return: whether the parallel compilation environment was initialized successfully
"""
# Tell the parallel compiler how many processes to use
os.environ["TE_PARALLEL_COMPILER"] = str(initialize.content["process_num"])
# Gather the SoC info,
soc_info = get_soc_info(initialize.content)
# the effective op debug level,
real_debug_level = get_real_op_debug_level(initialize.content)
# the auto tiling mode,
auto_tiling_mode = initialize.content["SocInfo"]["autoTilingMode"]
# and the offline-tune flag
offline_tune = initialize.content["SocInfo"]["offlineTune"]
# Combine a timestamp and the pid into a unique tag
pid_ts = "{}_pid{}".format(datetime.now().strftime('%Y%m%d_%H%M%S%f')[:-3], os.getpid())
# Initialize the multiprocess environment
ret = init_multi_process_env(False, soc_info, auto_tiling_mode, real_debug_level,
None, 1, pid_ts)
if ret is None:
initialize.error("Init multiprocess env failed")
return False
initialize.info("Init multiprocess env success with {} process".format(ret[0]))
# Initialize the RL environment when RL or offline tuning is requested
if "RL" in auto_tiling_mode or offline_tune:
res_queue = ret[1]
live_checker = ret[2]
@ -234,6 +302,7 @@ def _parallel_compilation_init(initialize: TbeJob):
initialize.error("RL env init failed!")
return False
initialize.info("RL Tune init success.")
# Start the GA processes when GA tuning is requested
if "GA" in auto_tiling_mode:
start_ga_multi_process(auto_tiling_mode)
initialize.info("GA Tune init success.")
@ -242,31 +311,44 @@ def _parallel_compilation_init(initialize: TbeJob):
def tbe_initialize(job: TbeJob):
"""
Tbe Initialize
:param job: TbeJob object carrying the task info
:return: whether the TBE environment was initialized successfully
"""
# Mark the model-compiling context
os.environ["CONTEXT_MODELCOMPILING"] = "TRUE"
# Gather the SoC info
soc_info = get_soc_info(job.content)
# and set the target version
res = te_set_version(*soc_info)
if not res:
job.error("Set version failed")
# Initialize the tuning environment
res = _tune_init(job)
if not res:
job.error("Tune init failed")
# Serialize initialization through a lock file
lock_file = os.path.join(job.content["SocInfo"]["op_debug_dir"], "kernel_meta", "file.lock")
local_lock = LocalLock(lock_file)
try:
# Take the lock,
local_lock.lock()
# load the CANN knowledge base
res = _cann_kb_load(job)
if res == 1:
job.error("Cann kb load failed")
# and initialize parallel compilation
res = _parallel_compilation_init(job)
if not res:
job.error("Parallel compilation failed")
except RuntimeError:
job.error("Initialize failed with RuntimeError")
finally:
# Release the lock
local_lock.unlock()
job.result = "Success"
return res
@ -274,9 +356,13 @@ def tbe_initialize(job: TbeJob):
def get_auto_tune_support_op_list(job: TbeJob):
"""
Get GA tune supported op list
:param job: TbeJob object carrying the task info
:return: list of ops that support auto tuning
"""
from auto_tune_main import enable_auto_tune_support
auto_tune_op_list = enable_auto_tune_support()
@ -286,10 +372,14 @@ def get_auto_tune_support_op_list(job: TbeJob):
def _normalize_module_name(module_name, py_module_path):
"""
Normalize module name
:param module_name: module name
:param py_module_path: Python module path
:return: None
"""
if py_module_path not in sys.path:
sys.path.insert(0, py_module_path)
@ -298,9 +388,13 @@ def _normalize_module_name(module_name, py_module_path):
def check_support(job: TbeJob):
"""
Check whether the op is supported
:param job: TbeJob object carrying the task info
:return: whether the op is supported
"""
op_compute_info_list = get_compute_op_list(job.content)
if len(op_compute_info_list) != 1:
@ -341,21 +435,37 @@ def check_support(job: TbeJob):
def select_op_format(job: TbeJob):
"""
Select the op format from its compute op info
:param job: TbeJob object carrying the task info
:return: whether format selection succeeded
"""
# The compute op list
compute_op_info_list = get_compute_op_list(job.content)
# must contain exactly one op
if len(compute_op_info_list) != 1:
job.error("Invalid op compute num ({}) in check_support".format(len(compute_op_info_list)))
return False
# Take that op
compute_op_info = compute_op_info_list[0]
# and adjust its custom op info
adjust_custom_op_info(compute_op_info)
# Assemble the op arguments
inputs, outputs, attrs = assemble_op_args(compute_op_info)
# Resolve the op module
op_module_name = get_module_name(compute_op_info)
py_module_path = compute_op_info["py_module_path"]
# and normalize the module name
_normalize_module_name(op_module_name, py_module_path)
# Call the op's op_select_format entry
op_func_name = "op_select_format"
res = call_op_func((inputs, outputs, attrs), op_module_name, op_func_name)
# Record the selected format
job.result = str(res)
return True
@ -363,15 +473,25 @@ def select_op_format(job: TbeJob):
def parallel_pre_compile_op(job: TbeJob):
"""
Parallel pre compile op
:param job: TbeJob object carrying the task info
:return: whether pre-compilation succeeded
"""
# The compute op list
compute_op_info_list = get_compute_op_list(job.content)
# must contain exactly one op
if len(compute_op_info_list) != 1:
job.error("Invalid op compute num ({}) in pre compile op".format(len(compute_op_info_list)))
return False
# Take that op,
compute_op_info = compute_op_info_list[0]
# adjust its custom op info
adjust_custom_op_info(compute_op_info)
# and pre-build it
_pre_build_compute_op_info(compute_op_info, job)
return True
@ -379,35 +499,60 @@ def parallel_pre_compile_op(job: TbeJob):
def _pre_build_compute_op_info(compute_op, job):
"""
Prebuild by compute op info
:param compute_op: dict of compute op info
:param job: TbeJob object carrying the task info
:return: None
"""
# The L1 cache size
l1_size = job.content["l1_size"]
# When a concrete L1 size is given, reset op_L1_space first
if l1_size != -1:
set_L1_info("op_L1_space", -1)
# Assemble the op arguments
inputs, outputs, attrs = assemble_op_args(compute_op, is_single_op_build=True)
# Resolve the op module name,
op_module_name = get_module_name(compute_op)
# the Python module path,
py_module_path = compute_op["py_module_path"]
# the op function name,
op_func_name = compute_op["func_name"]
# the op type,
op_type = compute_op["type"]
# and the op name
op_name = compute_op["op_name"]
# Save the op parameters for the prebuild stage
save_op_params(op_name, "prebuild", (outputs, attrs))
l1_size = job.content["l1_size"]
# Apply the L1 space info
set_L1_info("op_L1_space", l1_size)
# Normalize the module name
_normalize_module_name(op_module_name, py_module_path)
# Unknown-shape flag
unknown_shape = compute_op["unknown_shape"]
# and int64 mode
int64_mode = compute_op["int64mode"]
# Check whether the op supports impl modes
res = check_op_impl_mode(op_module_name, op_func_name)
# Current impl mode,
op_impl_mode = job.content["SocInfo"]["op_impl_mode"]
# impl mode list,
op_impl_mode_list = job.content["SocInfo"]["op_impl_mode_list"]
# and full op name
op_full_name = job.content["full_name"]
# Warn when the op does not support impl modes,
if not res:
if op_impl_mode_list:
job.warning("The op {} do NOT support op_impl_mode, current op_impl_mode:{}".format(op_type, op_impl_mode))
else:
# otherwise just log it
job.info("OpType {} support op_impl_mode, current op_impl_mode:{}".format(op_type, op_impl_mode))
# Gather the option info
options = get_options_info(job.content)
# and dispatch the prebuild task
dispatch_prebuild_task(job.source_id, job.id, l1_size, op_module_name, op_full_name,
op_type, op_func_name, unknown_shape,
(inputs, outputs, attrs, options), int64_mode, unknown_shape,
@ -416,13 +561,22 @@ def _pre_build_compute_op_info(compute_op, job):
def get_prebuild_output(op_name):
"""
Get prebuild output
:param op_name: op name
:return: dict holding the prebuild output
"""
# Serialize the op params to a JSON string
params_str = op_params_to_json(op_name)
try:
# and try to parse it back,
res = json.loads(params_str)
except ValueError:
# falling back to an empty dict on failure
res = {}
finally:
pass
@ -432,9 +586,15 @@ def get_prebuild_output(op_name):
def do_fuzz_build_tbe_op(job: TbeJob):
"""
Fuzzy build op
:param job: TbeJob object carrying the task info
:return: whether the fuzzy build succeeded
"""
# Fuzzy build currently leaves the op unchanged
job.result = "NOT_CHANGED"
return True
@ -442,9 +602,15 @@ def do_fuzz_build_tbe_op(job: TbeJob):
def _dump_fusion_op_info_to_json_file(job: TbeJob):
"""
Dump fusion op info to json file
:param job: TbeJob object carrying the task info
:return: None
"""
# Dump only when a debug path is configured
if not job.sys_para_debug_path or job.sys_para_debug_path == "\0":
return
dump_fusion_json(json.dumps(job.content), job.sys_para_debug_path)
@ -453,30 +619,55 @@ def _dump_fusion_op_info_to_json_file(job: TbeJob):
def build_single_pre_op(job: TbeJob):
"""
Build single op
:param job: TbeJob object carrying the task info
:return: whether the build succeeded
"""
# Run the pre-build processing
before_build_process(job)
# The compute op list
compute_op_info_list = get_compute_op_list(job.content)
# must contain exactly one op
if len(compute_op_info_list) != 1:
job.error("Invalid op compute num ({}) in build single op".format(len(compute_op_info_list)))
return False
# Take that op
compute_op_info = compute_op_info_list[0]
# and adjust its custom op info
adjust_custom_op_info(compute_op_info)
# Assemble the op inputs, outputs and attrs
inputs, outputs, attrs = assemble_op_args(compute_op_info, is_single_op_build=True)
# Gather everything the compile task needs: the op type,
op_type = compute_op_info["type"]
# the L1 size,
l1_size = job.content["l1_size"]
# the op module name,
op_module_name = get_module_name(compute_op_info)
# the kernel name,
op_kernel_name = compute_op_info["op_name"]
# the Python module path,
py_module_path = compute_op_info["py_module_path"]
# the full op name,
op_name = job.content["full_name"]
# and the op function name
op_func_name = compute_op_info["func_name"]
# Normalize the module name
_normalize_module_name(op_module_name, py_module_path)
# Unknown-shape flag,
unknown_shape = compute_op_info["unknown_shape"]
# int64 mode,
int64_mode = compute_op_info["int64mode"]
# op pattern,
op_pattern = compute_op_info["pattern"]
# option info,
options = get_options_info(job.content)
# and fuzz build info
fuzz_build_info = get_fuzz_build_info(job.content)
# Dispatch the single-op compile task
dispatch_single_op_compile_task(job.source_id, job.id, l1_size, op_module_name, op_name, op_type, op_func_name,
op_kernel_name, unknown_shape, (inputs, outputs, attrs, options), int64_mode,
None, None, unknown_shape, op_pattern,
@ -487,13 +678,22 @@ def build_single_pre_op(job: TbeJob):
def before_build_process(job: TbeJob):
"""
Processing before build
:param job:
:return:
在构建前进行处理
Args:
job (TbeJob): 包含任务信息的TbeJob对象
Returns:
None
"""
# 获取L1缓存大小并设置
l1_size = job.content["l1_size"]
set_L1_info("op_L1_space", l1_size)
# 将融合操作信息转储到JSON文件
_dump_fusion_op_info_to_json_file(job)
# 获取是否需要离线调优
offline_tune = job.sys_offline_tune
# 如果需要离线调优则将融合操作信息转储到JSON文件
if offline_tune:
dump_fusion_json(json.dumps(job.content), job.sys_tune_dump_path)
@ -501,20 +701,29 @@ def before_build_process(job: TbeJob):
def sync_fusion_env(fusion_need_sync, module_list):
"""
Sync fusion env
:param fusion_need_sync:
:param module_list:
:return:
同步融合环境
Args:
fusion_need_sync (int): 是否需要同步融合环境
module_list (dict): 模块列表
Returns:
bool: 同步是否成功
"""
# 如果不需要同步,则直接返回成功
if fusion_need_sync == 0:
return True
# 准备使用的模块列表
module_using = []
for key, value in module_list.items():
if value > 0:
module_using.append(str(key))
module_list[key] = 0
# 将使用的模块列表转换为字符串
module_str = ",".join(module_using)
# 导入使用的模块
import_py_module(module_str)
return True
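# A minimal illustration with hypothetical values (not part of the original
# source): only modules with a positive usage count are imported, and their
# counters are reset to zero afterwards.
#
#   modules = {"impl.dynamic.add": 2, "impl.mul": 0}
#   sync_fusion_env(1, modules)   # imports "impl.dynamic.add"
#   # modules == {"impl.dynamic.add": 0, "impl.mul": 0}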
@ -522,13 +731,23 @@ def sync_fusion_env(fusion_need_sync, module_list):
def parallel_compile_fusion_op(job: TbeJob):
"""
Compile fusion op in parallel compiler
:param job:
:return:
在并行编译器中编译融合操作
Args:
job (TbeJob): 包含任务信息的TbeJob对象
Returns:
bool: 编译过程是否成功
"""
# 获取L1缓存大小
l1_size = job.content["l1_size"]
# 获取选项信息
options = get_options_info(job.content)
# 获取融合操作内核名称
op_kernel_name = job.content["fusion_op_name"]
# 获取完整操作名称
op_name = job.content["full_name"]
# 分派融合操作编译任务
dispatch_fusion_op_compile_task(job.source_id, job.id, l1_size, json.dumps(job.content), op_kernel_name, None, None,
options, None, job.pass_list, op_name)
return True
@ -537,112 +756,185 @@ def parallel_compile_fusion_op(job: TbeJob):
def ga_tune(job: TbeJob):
"""
GA tune
:param job:
:return:
使用遗传算法进行调优
Args:
job (TbeJob): 包含任务信息的TbeJob对象
Returns:
bool: 调优过程是否成功
"""
# 获取L1缓存大小
l1_size = job.content["l1_size"]
# 获取融合操作内核名称
op_kernel_name = job.content["fusion_op_name"]
# 获取完整操作名称
op_name = job.content["full_name"]
# 分派自动调优任务
dispatch_autotune_task(job.source_id, job.id, l1_size, json.dumps(job.content), {}, op_kernel_name, op_name)
# 设置任务状态为运行中
job.status = JobStatus.JOB_RUNNING
return True
def rl_tune_single_op(job: TbeJob):
"""
RL tune single op
:param job:
:return:
Perform RL (Reinforcement Learning) tuning for a single operation.
This function is responsible for tuning a single operation using RL techniques.
It retrieves the operation's information, performs the tuning, and handles any exceptions that may occur during the process.
Args:
job (TbeJob): An object containing job information, including the operation to be tuned.
Returns:
bool: True if the RL tuning is successful, False otherwise.
"""
# Retrieve the list of compute operations from the job content
compute_op_info_list = get_compute_op_list(job.content)
# Check if there is exactly one compute operation
if len(compute_op_info_list) != 1:
job.error("Invalid op compute num ({}) in rl tune single op".format(len(compute_op_info_list)))
return False
# Get the first (and only) compute operation info
compute_op_info = compute_op_info_list[0]
# Assemble the operation's input, output, and attributes
inputs, outputs, attrs = assemble_op_args(compute_op_info)
# Get the operation type
op_type = compute_op_info["type"]
# Get the L1 size from the job content
l1_size = job.content["l1_size"]
# Get the operation module name
op_module_name = get_module_name(compute_op_info)
# Get the operation kernel name
op_kernel_name = compute_op_info["op_name"]
# Get the full name of the operation
full_name = compute_op_info["name"]
# Get the Python module path
py_module_path = compute_op_info["py_module_path"]
# Get the operation function name
op_func_name = compute_op_info["func_name"]
# Normalize the module name
_normalize_module_name(op_module_name, py_module_path)
# Set the current operation name
set_current_op_name(op_kernel_name)
# Get the unknown shape information
unknown_shape = compute_op_info["unknown_shape"]
# Get the int64 mode information
int64_mode = compute_op_info["int64mode"]
# Get the operation pattern
op_pattern = compute_op_info["pattern"]
# Get the fuzz build information
fuzz_build_info = get_fuzz_build_info(job.content)
# Get the auto tiling mode
auto_tiling_mode = job.content["SocInfo"]["autoTilingMode"]
# Get the device ID
device_id = job.content["SocInfo"]["deviceId"]
# Get the options information
options = get_options_info(job.content)
try:
# Build the single operation from C code
build_single_op_from_c(op_module_name, op_func_name, op_type, "build", unknown_shape,
(inputs, outputs, attrs), int64_mode, unknown_shape, options,
op_pattern, auto_tiling_mode, device_id, json.dumps(fuzz_build_info))
# pylint: disable=broad-except
except Exception:
# If an exception occurs, log the error and return False
job.error(
"Single op {} build failed, no need to do rl tune, json string:{}".format(op_kernel_name, job.json_string))
exc_type, exc_value, _ = sys.exc_info()
job.error(
"exc_type:{}, exc_value:{}, exc_traceback:{}".format(exc_type, exc_value, traceback.format_exc()))
return False
finally:
pass
# Prepare the tuning operation module name
tune_op_module_name = op_module_name + "@" + py_module_path
# Get the base kernel path
base_kernel = job.content["SocInfo"]["op_debug_dir"] + "/kernel_meta/" + op_kernel_name + ".o"
# Dispatch the single tune task
from schedule_search.rl_online_tune import dispatch_single_tune_task
pack_args = pack_op_args(inputs, outputs, attrs)
res = dispatch_single_tune_task(job.source_id, job.id, l1_size, base_kernel, op_kernel_name, full_name,
tune_op_module_name, op_func_name, op_type, pack_args)
# Process the RL tune result
return _process_rl_tune_result(job, op_type, res)
def rl_tune_fusion_op(job: TbeJob):
"""
rl tune fusion op
:param job:
:return:
Perform RL tuning for a fusion operation.
This function is responsible for tuning a fusion operation using RL techniques.
It compiles the operation using multiprocessing and handles any exceptions that may occur during the process.
Args:
job (TbeJob): An object containing job information, including the fusion operation to be tuned.
Returns:
bool: True if the RL tuning is successful, False otherwise.
"""
# Get the fusion operation kernel name
op_kernel_name = job.content["fusion_op_name"]
# Set the current operation name
set_current_op_name(op_kernel_name)
try:
# Compile the operation using multiprocessing
from schedule_search.rl_online_tune import compile_op_by_mp
compile_op_by_mp(json.dumps(job.content))
# pylint: disable=broad-except
except Exception:
# If an exception occurs, log the error and return False
job.error(
"Fusion op {} build failed, no need to do rl tune, json string:{}".format(op_kernel_name, job.json_string))
exc_type, exc_value, _ = sys.exc_info()
job.error(
"exc_type:{}, exc_value:{}, exc_traceback:{}".format(exc_type, exc_value, traceback.format_exc()))
return False
finally:
pass
# Get the L1 size
l1_size = job.content["l1_size"]
# Get the base kernel path
base_kernel = job.content["SocInfo"]["op_debug_dir"] + "/kernel_meta/" + op_kernel_name + ".o"
# Get the list of compute operations
compute_op_list = get_compute_op_list(job.content)
# Prepare the operation module names string
op_module_names_str = ""
op_type_set = set()
for op in compute_op_list:
op_module_names_str = ','.join([op_module_names_str, get_module_name(op)])
op_type_set.add(op["type"])
# Remove the leading comma from the operation module names string
op_module_names_str = op_module_names_str[1:]
# Join the operation types with double underscore
op_type = "__".join(list(op_type_set))
# Dispatch the fusion tune task
from schedule_search.rl_online_tune import dispatch_fusion_tune_task
res = dispatch_fusion_tune_task(job.source_id, job.id, l1_size, base_kernel, op_kernel_name, op_module_names_str,
json.dumps(job.content))
# Process the RL tune result
return _process_rl_tune_result(job, op_type, res)
def _process_rl_tune_result(job, op_type, res):
"""
Process the result of RL tuning.
If the tuning result is False, it checks if the operation type is in the black list or if the job is set to offline tune.
If the tuning result is True, it sets the job status to running.
Args:
job (TbeJob): An object containing job information.
op_type (str): The type of the operation.
res (bool): The result of RL tuning.
Returns:
bool: The processed result of RL tuning.
"""
if not res:
# Check if the operation type is in the black list or if the job is set to offline tune
from schedule_search.tune_util import filter_black_op_type
res = bool(job.sys_offline_tune or os.getenv("REPEAT_TUNE", "False").lower() != "true" or filter_black_op_type(
op_type))
else:
# Set the job status to running
job.status = JobStatus.JOB_RUNNING
res = True
return res
@ -650,8 +942,13 @@ def _process_rl_tune_result(job, op_type, res):
def get_finish_tasks(source_id):
"""
Get finish task from parallel compilation framework
:return task info list
Get the list of finished tasks from the parallel compilation framework.
Args:
source_id (int): The source ID of the tasks.
Returns:
list: A list of finished task information.
"""
return get_finished_compilation_task(source_id)
@ -664,14 +961,21 @@ def tbe_finalize(auto_tiling_mode, offline_tune, job: TbeJob):
:param job: TbeJob
:return: None
"""
# Release the multiprocessing environment
deinit_multi_process_env()
# If the auto-tiling mode includes RL or offline tuning is enabled, deinit RL tuning
if "RL" in auto_tiling_mode or offline_tune:
from schedule_search.rl_online_tune import rl_tune_deinit
rl_tune_deinit()
# Unload the CANN knowledge base
res = _cann_kb_unload(job)
# Return False if unloading fails
if res == 1:
job.error("Cann kb unload failed")
return False
# Clear the fusion params
clear_fusion_params()
# Remove the cache
_remove_cache(job)
return True

@ -26,6 +26,7 @@ class BuildType(Enum):
ACCURATELY = "accurately"
# Collect all values of the JobType enum
job_type_list = [job_type.value for _, job_type in JobType.__members__.items()]
@ -35,14 +36,19 @@ def check_job_json(job_info):
:param job_info: the compilation job json
:return: raise value error if wrong
"""
# Check that job_info contains source_id
if 'source_id' not in job_info:
raise ValueError("Json string Errors, key:source_id not found.")
# Check that job_info contains job_id
if 'job_id' not in job_info:
raise ValueError("Json string Errors, key:job_id not found.")
# Check that job_info contains a non-empty job_type
if 'job_type' not in job_info or not job_info['job_type']:
raise ValueError("Json string Errors, key:job_type not found.")
# Check that the job_type is one of the known job types
if job_info['job_type'] not in job_type_list:
raise ValueError("Invalid job type: {}.".format(job_info['job_type']))
# Check that job_info contains job_content
if 'job_content' not in job_info:
raise ValueError("Json string Errors, key:job_content not found.")
@ -52,6 +58,7 @@ def reset_op_debug_level_in_soc_info(level):
:param level: op_debug_level, if level is 3 or 4, replace it with 0
:return: op_debug_level
"""
# If level is 3 or 4, replace it with 0
if level in ("3", "4"):
level = "0"
return level
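# For example: reset_op_debug_level_in_soc_info("3") and
# reset_op_debug_level_in_soc_info("4") both return "0", while any other
# level such as "1" is returned unchanged.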
@ -62,6 +69,7 @@ def get_real_op_debug_level(initialize_job_info):
:param initialize_job_info: initialize_job_info
:return: origin op_debug_level for init_multi_process_env
"""
# Return op_debug_level from initialize_job_info
return initialize_job_info["SocInfo"]["op_debug_level"]
@ -72,21 +80,35 @@ def get_soc_info(initialize_job_info):
:return: soc info
"""
soc_param = dict()
# Copy the relevant SocInfo fields into soc_param
soc_param["op_impl_mode"] = initialize_job_info["SocInfo"]["op_impl_mode"]
# Normalize op_debug_level with reset_op_debug_level_in_soc_info
soc_param["op_debug_level"] = reset_op_debug_level_in_soc_info(initialize_job_info["SocInfo"]["op_debug_level"])
soc_param["op_impl_mode_list"] = initialize_job_info["SocInfo"]["op_impl_mode_list"]
soc_param["op_debug_dir"] = initialize_job_info["SocInfo"]["op_debug_dir"]
soc_param["vector_fp_ceiling"] = initialize_job_info["SocInfo"]["vector_fp_ceiling"]
soc_param['mdl_bank_path'] = initialize_job_info["SocInfo"]["mdl_bank_path"]
soc_param['op_bank_path'] = initialize_job_info["SocInfo"]["op_bank_path"]
soc_info = list()
# Assemble the soc info list: socVersion, coreType, coreNum, l1Fusion, l2Mode, l2Fusion, soc_param
soc_info.append(initialize_job_info["SocInfo"]["socVersion"])
soc_info.append(initialize_job_info["SocInfo"]["coreType"])
soc_info.append(initialize_job_info["SocInfo"]["coreNum"])
soc_info.append(initialize_job_info["SocInfo"]["l1Fusion"])
soc_info.append(initialize_job_info["SocInfo"]["l2Mode"])
soc_info.append(initialize_job_info["SocInfo"]["l2Fusion"])
soc_info.append(soc_param)
return soc_info
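# Sketch of the returned structure (values hypothetical): a seven-element
# list ending with the soc_param dict assembled above.
#
#   ["Ascend910", "AiCore", "32", "false", "2", "false",
#    {"op_impl_mode": "", "op_debug_level": "0", ...}]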
@ -98,16 +120,22 @@ def check_arg_info(io_info):
:param io_info:A dict, to be checked.
:return: Exception: If specific keyword is not found.
"""
# Check that io_info contains shape
if 'shape' not in io_info:
raise ValueError("Json string Errors, key:shape not found.")
# Check that io_info contains ori_shape
if 'ori_shape' not in io_info:
raise ValueError("Json string Errors, key:ori_shape not found.")
# Check that io_info contains a non-empty format
if 'format' not in io_info or not io_info['format']:
raise ValueError("Json string Errors, key:format not found.")
# Check that io_info contains a non-empty ori_format
if 'ori_format' not in io_info or not io_info['ori_format']:
raise ValueError("Json string Errors, key:ori_format not found.")
# Check that io_info contains a non-empty dtype
if 'dtype' not in io_info or not io_info['dtype']:
raise ValueError("Json string Errors, key:dtype not found.")
# Check that io_info contains a non-empty param_type
if 'param_type' not in io_info or not io_info['param_type']:
raise ValueError("Json string Errors, key:param_type not found.")
@ -119,18 +147,28 @@ def get_input_output_args(io_info):
:return:input/output args
"""
args = []
# Return an empty list if io_info is None
if io_info is None:
return args
for item in io_info:
# A dict item describes a single input/output
if isinstance(item, dict):
arg = get_single_io_arg(item)
args.append(arg)
# A list item describes a dynamic input/output
elif isinstance(item, list):
dyn_arg = []
# Collect the arg of every entry in the list
for info in item:
arg = get_single_io_arg(info)
dyn_arg.append(arg)
# Append the dynamic args as a tuple
args.append(tuple(dyn_arg))
return args
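# Shape of the result under hypothetical inputs: a dict item yields one arg,
# a list item yields a tuple of args (a dynamic input/output), so
# [{...}, [{...}, {...}]] becomes [arg0, (arg1, arg2)].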
@ -142,19 +180,30 @@ def get_single_io_arg(info):
"""
if 'valid' not in info:
raise ValueError("Json string Errors, key:valid not found.")
# If the info is valid, check it and strip the bookkeeping keys
if info['valid']:
check_arg_info(info)
del info['valid']
del info['name']
# Replace an open upper bound (-1) in 'range' with None
if 'range' in info:
for i in range(len(info['range'])):
if info['range'][i][1] == -1:
info['range'][i][1] = None
res = info
else:
# An invalid info yields None
res = None
return res
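# A hedged example (hypothetical values): a valid info has its 'valid' and
# 'name' keys stripped and open range bounds replaced with None.
#
#   info = {"valid": True, "name": "x", "shape": [-1], "ori_shape": [-1],
#           "format": "ND", "ori_format": "ND", "dtype": "float16",
#           "param_type": "required", "range": [[1, -1]]}
#   get_single_io_arg(info)               # -> {..., "range": [[1, None]]}
#   get_single_io_arg({"valid": False})   # -> None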
def assemble_op_args(compute_op_info, is_single_op_build=False):
"""
@ -165,20 +214,32 @@ def assemble_op_args(compute_op_info, is_single_op_build=False):
"""
inputs_info = compute_op_info["input_desc"] if "input_desc" in compute_op_info.keys() else None
outputs_info = compute_op_info["output_desc"] if "output_desc" in compute_op_info.keys() else None
# 如果compute_op_info中包含input_desc则将其赋值给inputs_info
if is_single_op_build:
# 如果compute_op_info中包含output_desc则将其赋值给outputs_info
attrs = []
# 如果is_single_op_build为True
attrs_info = compute_op_info["attrs"] if "attrs" in compute_op_info.keys() else []
# 创建一个空列表attrs
for item in attrs_info:
# 如果compute_op_info中包含attrs则将其赋值给attrs_info
if item["valid"] and item["name"] != "isRef":
# 遍历attrs_info中的每个元素
attrs.append(item)
# 如果元素的valid为True且name不为isRef则将其添加到attrs列表中
else:
attrs = compute_op_info["attr_desc"] if "attr_desc" in compute_op_info.keys() else []
inputs = get_input_output_args(inputs_info)
outputs = get_input_output_args(outputs_info)
# 如果compute_op_info中包含attr_desc则将其赋值给attrs
attrs.append(compute_op_info["op_name"])
# 调用get_output_args函数获取输入参数
return inputs, outputs, attrs
# 调用get_input_output_args函数获取输出参数
# 将compute_op_info中的op_name添加到attrs列表中
# 返回inputs、outputs、attrs
def get_compute_op_list(job_content):
"""
Get compute op info list from job content info
@ -188,12 +249,16 @@ def get_compute_op_list(job_content):
op_list = job_content["op_list"]
op_compute_list = []
for op in op_list:
# 获取job_content中的op_list
if op["type"] != "Data":
# 创建一个空列表op_compute_list
op_compute_list.append(op)
return op_compute_list
# 如果元素的typeData则将其添加到op_compute_list列表中
def get_options_info(job_content):
# 返回op_compute_list列表
"""
Get options info
:param job_content:
@ -203,17 +268,29 @@ def get_options_info(job_content):
options["socVersion"] = job_content["SocInfo"]["socVersion"]
options["coreType"] = job_content["SocInfo"]["coreType"]
options["coreNum"] = job_content["SocInfo"]["coreNum"]
# 创建一个空字典options
options["l1Fusion"] = job_content["SocInfo"]["l1Fusion"]
# 获取job_content中的socVersion
options["l2Fusion"] = job_content["SocInfo"]["l2Fusion"]
# 获取job_content中的coreType
options["l2Mode"] = job_content["SocInfo"]["l2Mode"]
# 获取job_content中的coreNum
options["op_debug_level"] = reset_op_debug_level_in_soc_info(job_content["SocInfo"]["op_debug_level"])
# 获取job_content中的l1Fusion
options["op_impl_mode"] = job_content["SocInfo"]["op_impl_mode"]
# 获取job_content中的l2Fusion
options["op_debug_dir"] = job_content["SocInfo"]["op_debug_dir"]
# 获取job_content中的l2Mode
options["mdl_bank_path"] = job_content["SocInfo"]["mdl_bank_path"]
# 获取job_content中的op_debug_level并调用reset_op_debug_level_in_soc_info函数进行处理
options["op_bank_path"] = job_content["SocInfo"]["op_bank_path"]
# 获取job_content中的op_impl_mode
options["deviceId"] = job_content["SocInfo"]["deviceId"]
# 从job_content中获取deviceId并将其赋值给options字典的deviceId键
options["autoTilingMode"] = job_content["SocInfo"]["autoTilingMode"]
# 从job_content中获取autoTilingMode并将其赋值给options字典的autoTilingMode键
options["op_impl_mode_list"] = job_content["SocInfo"]["op_impl_mode_list"]
# 从job_content中获取op_impl_mode_list并将其赋值给options字典的op_impl_mode_list键
return options
@ -223,15 +300,22 @@ def get_fuzz_build_info(job_content):
:param job_content: job content info
:return: fuzz build info
"""
# Get the first compute op from the job content
op_compute_info = get_compute_op_list(job_content)[0]
# Initialize the fuzz_build_info dict
fuzz_build_info = dict()
# Decide the compile type from the op's build_type
fuzz_build_info["compile_type"] = "fuzzily_build" if op_compute_info["build_type"] == BuildType.FUZZILY.value \
else "accurately_build"
# Record the miss_support_info
fuzz_build_info["miss_support_info"] = op_compute_info["miss_support_info"]
# Record the max_kernel_id
fuzz_build_info["max_kernel_id"] = op_compute_info["max_kernel_id"]
# For a fuzzily built op, record the incremental link path
fuzz_build_info["incremental_link"] = os.path.realpath(
job_content["SocInfo"]["op_debug_dir"] + "/kernel_meta/" + op_compute_info["name"] + ".json") if \
op_compute_info["build_type"] == BuildType.FUZZILY.value else ""
return fuzz_build_info
@ -241,10 +325,14 @@ def get_func_names(job_content):
:param job_content: job content info
:return: function names
"""
# Collect func_name from every op in the op_list
func_names = []
for op in job_content["op_list"]:
if "func_name" in op:
func_names.append(op["func_name"])
return func_names
@ -254,12 +342,16 @@ def get_module_name(compute_op_info):
:param compute_op_info:
:return:
"""
# Get the dynamic_compile_static and unknown_shape flags
dynamic_compile_static = compute_op_info["dynamic_compile_static"]
unknown_shape = compute_op_info["unknown_shape"]
# Get the module name
op_module_name = compute_op_info["module_name"]
# For a dynamic or unknown-shape op, rewrite the module name as "<first segment>.dynamic.<last segment>"
if dynamic_compile_static or unknown_shape:
d = ".dynamic."
op_module_name = d.join((op_module_name.split(".")[0], op_module_name.split(".")[-1]))
return op_module_name
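# For example (hypothetical module name): with unknown_shape=True,
# "impl.util.conv2d" becomes "impl.dynamic.conv2d" -- only the first and
# last dot-separated segments are kept around ".dynamic.".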
@ -269,10 +361,14 @@ def adjust_custom_op_info(compute_op_info):
:param compute_op_info:
:return:
"""
# Get the py_module_path
py_module_path = compute_op_info["py_module_path"]
# If py_module_path is a file, split it into its directory and module name
if os.path.isfile(py_module_path):
py_module_path, file_name = os.path.split(py_module_path)
module_name, _ = os.path.splitext(file_name)
# Write the adjusted path and module name back into compute_op_info
compute_op_info["py_module_path"] = py_module_path
compute_op_info["module_name"] = module_name
@ -281,5 +377,6 @@ def pack_op_args(inputs, outputs, attrs):
"""
flatten inputs outputs attrs
"""
# Flatten inputs, outputs and attrs into a single list
op_args = (inputs, outputs, attrs)
return [item for arg in op_args for item in arg]
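# For example: pack_op_args([1, 2], [3], ["attr"]) -> [1, 2, 3, "attr"].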

@ -20,14 +20,23 @@ from enum import Enum
class JobType(Enum):
""" Job Type """
# Initialize job
INITIALIZE_JOB = 'Initialize'
# Finalize job
FINALIZE_JOB = 'Finalize'
# Check-support job
CHECK_JOB = 'CheckSupport'
# Select-format job
SELECT_JOB = 'SelectFormat'
# Pre-compile job
PRECOMPILE_JOB = 'PreCompile'
# Compile job
COMPILE_JOB = 'Compile'
# Fusion-op compile job
FUSION_COMPILE_JOB = 'FusionOpCompile'
# Tune job
TUNE_JOB = 'Tune'
# Query job
QUERY_JOB = 'Query'
@ -51,9 +60,13 @@ class JobStatus(Enum):
class LogMessage:
""" Log message """
# Initialize a log message with its index, level and content
def __init__(self, index, level, info):
self.index = index
self.level = level
self.info = info
@ -74,29 +87,50 @@ class TbeJob:
""" Tbe compilation job """
def __init__(self, source_id, job_id, job_type, content, fusion_op_name, json_str, sys_info):
# Initialize a TbeJob with its identifiers, content and system info
self.source_id = source_id
self.id = job_id
self.type = JobType(job_type)
self.status = JobStatus.JOB_INITIAL
self.content = content
self.fusion_op_name = fusion_op_name
self.result = ""
self.process_info = []
self.json_string = json_str
self._sys_logger = sys_info["logger"]
self.sys_offline_tune = sys_info["offline_tune"]
self.sys_tune_dump_path = sys_info["tune_dump_path"]
self.sys_para_debug_path = sys_info["para_debug_path"]
# license info
self.rl_tune_switch = sys_info["rl_tune_switch"]
self.rl_tune_list = sys_info["rl_tune_list"]
self.op_tune_switch = sys_info["op_tune_switch"]
self.op_tune_list = sys_info["op_tune_list"]
self.pass_list = sys_info["pass_list"]
# soc info
self.soc_version = sys_info["socVersion"]
self.core_num = sys_info["coreNum"]
self.op_bank_path = sys_info["op_bank_path"]
def debug(self, msg, *args, **kwargs):
@ -106,9 +140,13 @@ class TbeJob:
:param args:
:return:
"""
# Format the message, record it, then forward it to the system logger
processed_msg = _get_message(msg, args)
message = LogMessage(len(self.process_info), LogLevel.DEBUG, processed_msg)
self.process_info.append(message)
self._sys_logger.debug(msg, *args, **kwargs)
def info(self, msg, *args, **kwargs):
@ -118,9 +156,13 @@ class TbeJob:
:param args:
:return:
"""
# Format the message, record it, then forward it to the system logger
processed_msg = _get_message(msg, args)
message = LogMessage(len(self.process_info), LogLevel.INFO, processed_msg)
self.process_info.append(message)
self._sys_logger.info(msg, *args, **kwargs)
def warning(self, msg, *args, **kwargs):
@ -130,9 +172,13 @@ class TbeJob:
:param args:
:return:
"""
# Format the message, record it, then forward it to the system logger
processed_msg = _get_message(msg, args)
message = LogMessage(len(self.process_info), LogLevel.WARNING, processed_msg)
self.process_info.append(message)
self._sys_logger.warning(msg, *args, **kwargs)
def error(self, msg, *args, **kwargs):
@ -142,9 +188,13 @@ class TbeJob:
:param args:
:return:
"""
# Format the message, record it, then forward it to the system logger
processed_msg = _get_message(msg, args)
message = LogMessage(len(self.process_info), LogLevel.ERROR, processed_msg)
self.process_info.append(message)
self._sys_logger.error(msg, *args, **kwargs)
def error_manager(self, msg, *args, **kwargs):
@ -154,30 +204,50 @@ class TbeJob:
:param args:
:return:
"""
# Warn and return if msg is empty
if not msg:
self.warning("Get empty error manager message, op_name: {}".format(self.fusion_op_name))
return
exception_info = None
op_name = self.fusion_op_name
# If msg is an Exception, search its args for a dict carrying "errCode"
if isinstance(msg, Exception):
for arg in msg.args:
if isinstance(arg, dict) and "errCode" in arg:
exception_info = arg
break
# If no exception info was found, log the raw message and return
if not exception_info:
self.error("Exception message:{}".format(msg))
return
else:
# Otherwise msg is a sequence: the first element is the exception info,
# the optional second element is the op name
exception_info = msg[0]
if len(msg) >= 2:
op_name = msg[1]
# Warn and return if the exception info is not a non-empty dict
if not isinstance(exception_info, dict) or not exception_info:
self.warning("Get illegal error manager message, op_name: {}".format(self.fusion_op_name))
return
# Record the op name, serialize the exception info and log it
exception_info["op_name"] = op_name
processed_msg = json.dumps(exception_info)
message = LogMessage(len(self.process_info), LogLevel.ERROR_MANAGER, processed_msg)
self.process_info.append(message)
self._sys_logger.exception(msg, *args, **kwargs)
def get_result(self):
@ -186,15 +256,26 @@ class TbeJob:
:return: job process result string
"""
result = dict()
# Collect the job status, identifiers, type, fusion op name and result
result["status"] = self.status.value
result["source_id"] = self.source_id
result["job_id"] = self.id
result["job_type"] = self.type.value
result["fusion_op_name"] = self.fusion_op_name
result["result"] = self.result
process_info = []
# Collect every recorded log message
for info in self.process_info:
msg = {"index": info.index, "level": info.level.value, "message": info.info}
process_info.append(msg)
result["process_info"] = process_info
# Serialize the result dict to a JSON string
return json.dumps(result)
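# Sketch of the serialized result (all values hypothetical):
#
#   {"status": "SUCCESS", "source_id": 1, "job_id": 2, "job_type": "Compile",
#    "fusion_op_name": "te_fused_op", "result": "",
#    "process_info": [{"index": 0, "level": 1, "message": "..."}]}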

@ -29,6 +29,7 @@ class TbeJobManager:
""" TBE compiler job manager """
def __init__(self):
# Map each job type to its handler
self.job_handlers = {
JobType.INITIALIZE_JOB: self.initialize_handler,
JobType.FINALIZE_JOB: self.finalize_handler,
@ -41,24 +42,43 @@ class TbeJobManager:
JobType.QUERY_JOB: self.query_handler
}
# Job bookkeeping: all, finished, running and raw-finished jobs
self._all_jobs = {}
self._finished_jobs = {}
self._running_jobs = {}
self._raw_finish_jobs = {}
# TBE initialization state
self.tbe_initialize = False
self.init_cache = None
self.para_debug_path = ""
# Tuning configuration
self.auto_tiling_mode = ""
self.offline_tune = False
self.tune_op_list = []
self.tune_dump_path = ""
self.tune_bank_path = ""
self.auto_tune_op_list = []
# Pre-build ops keyed by fusion op name
self.pre_build_ops = {}
# Fusion sync counter and imported modules
self.fusion_need_sync = 0
self.imported_module = {}
# SoC info
self.soc_version = ""
self.core_num = 0
self.op_bank_path = ""
# license info
self.rl_tune_switch = ""
@ -68,6 +88,7 @@ class TbeJobManager:
self.pass_list = ""
def __del__(self):
# Reset the manager when the object is deleted
self.reset()
def reset(self):
@ -75,22 +96,38 @@ class TbeJobManager:
Reset the job manager
:return: None
"""
# Reset the job bookkeeping dicts
self._all_jobs = {}
self._finished_jobs = {}
self._running_jobs = {}
self._raw_finish_jobs = {}
# Reset the debug path and tuning configuration
self.para_debug_path = ""
self.auto_tiling_mode = ""
self.offline_tune = False
self.tune_op_list = []
self.tune_dump_path = ""
self.tune_bank_path = ""
self.auto_tune_op_list = []
self.pre_build_ops = {}  # dict keyed by fusion op name, as in __init__
self.fusion_need_sync = 0
self.imported_module = {}
# Finalize TBE if it was initialized
if self.tbe_initialize:
tbe_finalize(self.auto_tiling_mode, self.offline_tune, self.init_cache)
self.tbe_initialize = False
self.init_cache = None
self.soc_version = ""
@ -105,11 +142,17 @@ class TbeJobManager:
"""
job = None
try:
# Parse the job string as JSON and validate it
job_json = json.loads(job_str)
check_job_json(job_json)
# Extract the job identifiers and type
job_id = job_json["job_id"]
source_id = job_json["source_id"]
job_type = job_json["job_type"]
# Get the system info
sys_info = self._get_job_sys_info()
fusion_op_name = "NA" if "fusion_op_name" not in job_json["job_content"] else job_json["job_content"][
"fusion_op_name"]
@ -140,173 +183,260 @@ class TbeJobManager:
def initialize_handler(self, job: TbeJob):
""" Initialize job handler """
# Initialize the system info
self._init_sys_info(job)
# Initialize TBE for the job
res = tbe_initialize(job)
# On failure, log the error and finish the job as failed
if not res:
job.error("Process Initialize Job failed, job json string:{}".format(job.json_string))
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
# If the auto-tiling mode contains "GA", get the auto-tune-supported op list
if "GA" in self.auto_tiling_mode:
self.auto_tune_op_list = get_auto_tune_support_op_list(job)
# Mark TBE as initialized and cache the job
self.tbe_initialize = True
self.init_cache = job
# Finish the job as successful
return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS)
def finalize_handler(self, job: TbeJob):
""" Finalize job handler """
# If TBE was never initialized, finish the job as successful directly
if not self.tbe_initialize:
return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS)
# Finalize TBE with the current auto-tiling mode and offline-tune flag
res = tbe_finalize(self.auto_tiling_mode, self.offline_tune, job)
# On failure, log the error and finish the job as failed
if not res:
job.error("Process Finalize Job failed, job json string:{}".format(job.json_string))
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
# Finish the job as successful
return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS)
def check_support_handler(self, job: TbeJob):
""" Check Support job handler """
# Check whether the op is supported
res = check_support(job)
# On failure, log the error and finish the job as failed
if not res:
job.error("Process CheckSupport Job failed, job json string:{}".format(job.json_string))
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
# Record the imported op module
self._update_imported_op_module(job)
# Finish the job as successful
return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS)
def select_format_handler(self, job: TbeJob):
""" Select Format job handler """
# Select the op format
res = select_op_format(job)
# On failure, log the error and finish the job as failed
if not res:
job.error("Process SelectFormat Job failed, job json string:{}".format(job.json_string))
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
# Finish the job as successful
return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS)
def pre_compile_handler(self, job: TbeJob):
""" Pre Compile job handler """
# Pre-compile the op in parallel
res = parallel_pre_compile_op(job)
# On failure, log the error and finish the job as failed
if not res:
job.error("Process PreCompile Job failed, job json string:{}".format(job.json_string))
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
# Track the job by its fusion op name and mark it as running
self.pre_build_ops[job.content["fusion_op_name"]] = job
return self.add_to_running_jobs(job)
def compile_handler(self, job: TbeJob):
""" Compile job handler """
# Get the compute op list from the job content
compute_op_list = get_compute_op_list(job.content)
# A single compute op goes through the single-op compile path
if len(compute_op_list) == 1:  # pylint: disable=no-else-return
return self.single_op_compile(job)
else:
# Run the pre-build processing
before_build_process(job)
# Sync the fusion environment if needed
if self.fusion_need_sync:
sync_fusion_env(self.fusion_need_sync, self.imported_module)
self.fusion_need_sync = 0
# Compile the fusion op in parallel
res = parallel_compile_fusion_op(job)
# On failure, log the error and finish the job as failed
if not res:
job.error("Parallel_compile_fusion_op Job failed, job json string:{}".format(job.json_string))
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
# Mark the job as running
return self.add_to_running_jobs(job)
def single_op_compile(self, job: TbeJob):
"""Single operator compile"""
# Try the fuzzy build first
res = do_fuzz_build_tbe_op(job)
# On failure, log the error and finish the job as failed
if not res:
job.error("Process do fuzz build tbe op failed, job json string:{}".format(job.json_string))
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
# "NOT_CHANGED" means the fuzzy build did nothing: run the normal single-op pre-build
if job.result == "NOT_CHANGED":
job.result = ""
before_build_process(job)
res = build_single_pre_op(job)
# On failure, log the error and finish the job as failed
if not res:
job.error("Process build single pre op failed, job json string:{}".format(job.json_string))
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
# Mark the job as running
return self.add_to_running_jobs(job)
# "SUCCESS" means the fuzzy build finished the job
if job.result == "SUCCESS":
return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS)
# Any other result is a failure
job.error("Process do fuzz build tbe op failed, job json string:{}".format(job.json_string))
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
def tune_handler(self, job: TbeJob):
""" Tune job handler """
before_build_process(job)
# Select the tune mode
tune_mode = self._select_tune_mode(job)
# With no tuning requested, fall back to the compile handler
if tune_mode == TuneMode.NO_TUNE:
return self.compile_handler(job)
# Get the compute op list
compute_op_list = get_compute_op_list(job.content)
# Tune a single compute op on its own; otherwise tune the fusion op
if len(compute_op_list) == 1:
return self.single_op_tune(job)
return self.fusion_op_tune(job)
def single_op_tune(self, job: TbeJob):
"""Single operator tune"""
# Select the tune mode
tune_mode = self._select_tune_mode(job)
# RL tuning path
if tune_mode == TuneMode.RL_TUNE:
res = rl_tune_single_op(job)
# On failure, log the error and finish the job as failed
if not res:
job.error(
"Tune Job failed, tune type {}, job json string:{}".format(tune_mode, job.json_string))
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
# GA tuning path: sync the fusion environment first if needed
else:
if self.fusion_need_sync:
sync_fusion_env(self.fusion_need_sync, self.imported_module)
self.fusion_need_sync = 0
res = ga_tune(job)
# If GA tuning fails, log the error and fall back to compiling
if not res:
job.error("ga tune Job failed, job json string:{}".format(job.json_string))
return self.compile_handler(job)
# A running job stays in the running list; for RL tuning record the imported module
if job.status == JobStatus.JOB_RUNNING:
if tune_mode == TuneMode.RL_TUNE:
self._update_imported_op_module(job)
return self.add_to_running_jobs(job)
# Otherwise finish the job as successful
return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS)
def fusion_op_tune(self, job: TbeJob):
"""Fusion operator tune"""
# Select the tune mode
tune_mode = self._select_tune_mode(job)
# Sync the fusion environment if needed
if self.fusion_need_sync:
sync_fusion_env(self.fusion_need_sync, self.imported_module)
self.fusion_need_sync = 0
# RL tuning in RL mode, GA tuning otherwise
if tune_mode == TuneMode.RL_TUNE:
res = rl_tune_fusion_op(job)
else:
res = ga_tune(job)
# On failure, log the error and finish the job as failed
if not res:
job.error(
"Tune Job failed, tune type {}, job json string:{}".format(tune_mode, job.json_string))
return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED)
# A running job stays in the running list
if job.status == JobStatus.JOB_RUNNING:
return self.add_to_running_jobs(job)
# Otherwise finish the job as successful
return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS)
def query_handler(self, query_job: TbeJob):
""" Query job handler """
# Get the source_id and job_id of the queried job
target_source_id = query_job.content["source_id"]
target_job_id = query_job.content["job_id"]
# Look for the job among the finished jobs
target_job = get_job(self._finished_jobs, target_source_id, target_job_id)
if target_job:
query_job.warning("Query a finished job: {}".format(query_job.content))
query_job.result = target_job.get_result()
return self.add_to_finished_jobs(query_job, JobStatus.JOB_SUCCESS)
# Look for the job among the raw finished jobs
target_job = get_job(self._raw_finish_jobs, target_source_id, target_job_id)
if not target_job:
# Pull newly finished tasks, then look again
self.update_raw_finished_jobs(query_job)
target_job = get_job(self._raw_finish_jobs, target_source_id, target_job_id)
if target_job:
query_job.debug("Found job in raw finished jobs, source_id:{}, job_id:{}".format(target_source_id,
target_job_id))
query_job.result = target_job.get_result()
# Move the job from the raw finished list to the finished list
del_job(self._raw_finish_jobs, target_job.source_id, target_job.id)
self.add_to_finished_jobs(target_job, target_job.status)
return self.add_to_finished_jobs(query_job, JobStatus.JOB_SUCCESS)
# Look for the job among the running jobs
target_job = get_job(self._running_jobs, target_source_id, target_job_id)
if target_job:
query_job.result = target_job.get_result()
return self.add_to_finished_jobs(query_job, JobStatus.JOB_SUCCESS)
# Look for the job among all jobs
target_job = get_job(self._all_jobs, target_source_id, target_job_id)
if target_job:
query_job.debug("Found job in all jobs, source_id:{}, job_id:{}".format(target_source_id,
target_job_id))
target_job.debug("Be Queried")
query_job.result = target_job.get_result()
return self.add_to_finished_jobs(query_job, JobStatus.JOB_SUCCESS)
# The job was not found anywhere: log the error and finish the query as failed
query_job.error("Can't find job in finished/raw_finished/running jobs, source_id: {}".format(target_source_id))
query_job.result = ""
return self.add_to_finished_jobs(query_job, JobStatus.JOB_FAILED)
def _get_job_sys_info(self):
@ -314,10 +444,15 @@ class TbeJobManager:
Get job manager system info
:return: system info
"""
# Assemble the system info for a job
sys_info = dict()
sys_info["logger"] = DummyLogger
sys_info["para_debug_path"] = self.para_debug_path
sys_info["tune_dump_path"] = self.tune_dump_path
sys_info["offline_tune"] = self.offline_tune
# license info
sys_info["rl_tune_switch"] = self.rl_tune_switch
@ -362,12 +497,17 @@ class TbeJobManager:
:param job:
:return:
"""
# Get the module name of the job's single compute op
compute_op_info = get_compute_op_list(job.content)[0]
op_module_name = compute_op_info["module_name"]
# Bump the module's reference count, inserting it on first use
if op_module_name in self.imported_module.keys():
self.imported_module[op_module_name] = self.imported_module[op_module_name] + 1
else:
self.imported_module[op_module_name] = 1
# One more fusion sync is now required
self.fusion_need_sync = self.fusion_need_sync + 1
def _select_tune_mode(self, job):
@ -376,18 +516,25 @@ class TbeJobManager:
:param job: tbe tune job
:return: NO_TUNE RL_TUNE or GA_TUNE
"""
# Read the auto-tiling mode and offline-tune flag from the SocInfo
auto_tiling_mode = job.content["SocInfo"]["autoTilingMode"]
offline_tune = job.content["SocInfo"]["offlineTune"]
# Get the op's full name and function names
full_name = job.content["full_name"]
func_names = get_func_names(job.content)
# Skip tuning for ops outside the configured tune_op_list
if self.tune_op_list and full_name not in self.tune_op_list:
return TuneMode.NO_TUNE
# Offline tuning always uses RL
if offline_tune:
return TuneMode.RL_TUNE
# Use GA when the mode allows it and a function name is in the auto-tune op list
if TuneMode.GA_TUNE.value in auto_tiling_mode:
for func_name in func_names:
if func_name.lower() in self.auto_tune_op_list:
return TuneMode.GA_TUNE
# Fall back to RL when the mode allows it
if TuneMode.RL_TUNE.value in auto_tiling_mode:
return TuneMode.RL_TUNE
return TuneMode.NO_TUNE
@ -398,15 +545,22 @@ class TbeJobManager:
:param query_job: query job
:return: Node
"""
# Get the newly finished tasks from the parallel compilation framework
new_finished_jobs = get_finish_tasks(query_job.source_id)
# Match every finished task against the running jobs
for new_job in new_finished_jobs:
source_id = new_job["graph_id"]
job_id = new_job["task_id"]
target_job = get_job(self._running_jobs, source_id, job_id)
# Report an error if the task has no matching running job
if not target_job:
query_job.error("Can't get job, source id:{}, job id:{}".format(source_id, job_id))
continue
# Record the task result
target_job.result = new_job["op_res"] if "op_res" in new_job else new_job["result"]
# For a pre-compile job, attach the prebuild output
if target_job.type == JobType.PRECOMPILE_JOB:
op_name = target_job.content["fusion_op_name"]
op_params = get_prebuild_output(op_name)
@ -415,13 +569,17 @@ class TbeJobManager:
pre_compile_result["op_params"] = op_params
pre_compile_result["core_type"] = new_job["core_type"] if "core_type" in new_job else ""
target_job.result = json.dumps(pre_compile_result)
# Log the task result
target_job.info("Query result:{}".format(new_job["result"]))
# A status code of 0 means the task succeeded
if new_job["status_code"] == 0:
target_job.status = JobStatus.JOB_SUCCESS
target_job.info("Query info_msg:{}".format(new_job["info_msg"]))
else:
# Otherwise the task failed
target_job.status = JobStatus.JOB_FAILED
target_job.error("Query info_msg:{}".format(new_job["info_msg"]))
# Log any extra error details
if "err_args" in new_job:
target_job.error("Query err_args:{}".format(new_job["err_args"]))
if "except_msg" in new_job:
@ -429,7 +587,9 @@ class TbeJobManager:
if "except_tuple_msg" in new_job:
target_job.error_manager(new_job["except_tuple_msg"])
target_job.error("\nOriginal compile json: \n {}\n".format(target_job.json_string))
# Move the job from the running list to the raw finished list
post_job(self._raw_finish_jobs, target_job)
del_job(self._running_jobs, target_job.source_id, target_job.id)
def add_to_finished_jobs(self, job, status):
@ -456,8 +616,11 @@ class TbeJobManager:
class TuneMode(Enum):
"""Class of tune mode: NO_TUNE, GA, RL"""
# No tuning
NO_TUNE = "NO_TUNE"
# Genetic-algorithm tuning
GA_TUNE = "GA"
# Reinforcement-learning tuning
RL_TUNE = "RL"
@ -469,18 +632,22 @@ class DummyLogger:
@staticmethod
def debug(msg, *args, **kwargs):
"""Debug级别日志"""
pass
@staticmethod
def info(msg, *args, **kwargs):
"""Info级别日志"""
pass
@staticmethod
def warning(msg, *args, **kwargs):
"""Warning级别日志"""
pass
@staticmethod
def error(msg, *args, **kwargs):
"""Error级别日志"""
pass
@staticmethod
@ -497,10 +664,13 @@ def get_job(jobs, source_id, job_id):
:return: job instance if found in job list
None if not found in job list
"""
# Return None if source_id is unknown
if source_id not in jobs.keys():
return None
# Return None if job_id is unknown under source_id
if job_id not in jobs[source_id].keys():
return None
# Return the matching job
return jobs[source_id][job_id]
@ -526,9 +696,15 @@ def del_job(jobs, source_id, job_id):
:param job_id: target job's job_id
:return: bool True or False
"""
# Return False if source_id is unknown
if source_id not in jobs.keys():
return False
# Return False if job_id is unknown under source_id
if job_id not in jobs[source_id].keys():
return False
# Delete the job and report success
del jobs[source_id][job_id]
return True
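# A hedged sketch of the nested job-dict layout these helpers assume
# (hypothetical IDs): jobs[source_id][job_id] -> job.
#
#   jobs = {1: {10: job_a, 11: job_b}}
#   get_job(jobs, 1, 10)   # -> job_a
#   get_job(jobs, 2, 10)   # -> None
#   del_job(jobs, 1, 11)   # -> True, removes job_b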

@ -26,6 +26,7 @@ from .parser import (Parser, create_instance, is_supported_create_instance_type,
get_object_description, get_class_attr_namespace_symbol, get_ms_class_name,
get_ms_class_attr)
# Symbols re-exported from the parser module
__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_ast_type', 'get_node_type',
'get_args_default_values', 'get_ast_namespace_symbol', 'get_operation_namespace_symbol',

@ -16,131 +16,136 @@
# ============================================================================
"""Define the namespace of parse."""
import builtins
from mindspore import log as logger
class Namespace:
"""
Base class of namespace for resolving variables.
Args:
name (str): The namespace's name.
dicts (dict): A list of dicts containing the namespace's variables.
"""
def __init__(self, name, *dicts):
self.name = name  # The namespace's name
self.dicts = dicts  # The dicts holding the namespace's variables
def __contains__(self, name):
# Check whether the namespace contains the given name
for d in self.dicts:
if name in d:
return True
return False
def __getitem__(self, name):
# Look up a variable by name
for d in self.dicts:
if name in d:
return d[name]
raise NameError(name)  # Not found in any dict
def __repr__(self):
# String representation of the namespace
return f'Namespace:{self.name}'
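# A minimal usage sketch (hypothetical names): membership and lookup walk
# the dicts in order, so the first dict containing the name wins.
#
#   ns = Namespace("demo", {"a": 1}, {"a": 0, "b": 2})
#   "a" in ns   # -> True
#   ns["a"]     # -> 1 (from the first dict)
#   ns["b"]     # -> 2
#   ns["c"]     # raises NameError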
class CellNamespace(Namespace):
"""
Namespace for a Cell object.
Args:
name (str): A valid module name that can be imported.
"""
def __init__(self, name):
mod_dict = vars(__import__(name, fromlist=['_']))  # Import the module and take its variable dict
builtins_dict = vars(builtins)  # The builtins variable dict
super().__init__(name, mod_dict, builtins_dict)
def __getstate__(self):
# State used when pickling
return (self.name,)
def __setstate__(self, state):
# Restore the state when unpickling
name, = state
mod_dict = vars(__import__(name, fromlist=['_']))  # Re-import the module
builtins_dict = vars(builtins)
super().__init__(name, mod_dict, builtins_dict)
class ClosureNamespace(Namespace):
"""
Namespace for a function closure.
Args:
fn (Function): A python function.
"""
def __init__(self, fn):
name = f'{fn.__module__}..<{fn.__name__}>'  # Namespace name
names = fn.__code__.co_freevars  # Free variable names of the function
cells = fn.__closure__  # Closure cells of the function
ns = dict(zip(names, cells or ()))  # Map free variable names to their cells
super().__init__(name, ns)
def __getitem__(self, name):
# Look up a free variable and return its cell contents
d, = self.dicts
try:
return d[name].cell_contents
except ValueError:
raise UnboundLocalError(name)
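# A hedged example (hypothetical function): the namespace exposes the free
# variables captured by a closure.
#
#   def outer():
#       x = 42
#       def inner():
#           return x
#       return inner
#
#   ns = ClosureNamespace(outer())
#   ns["x"]   # -> 42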
class ClassMemberNamespace(Namespace):
"""
Namespace of a class's closure.
Args:
obj (Object): A python class object.
"""
def __init__(self, obj):
self.__class_member_namespace__ = True  # Mark this as a class-member namespace
label = f'{obj.__module__}..<{obj.__class__.__name__}::{id(obj)}>'  # Namespace label
super().__init__(label, obj)
def __getitem__(self, name):
# Look up a member of the class instance by name
d, = self.dicts
if name == "self":
return d  # "self" resolves to the object itself
if name == "namespace":
return self  # "namespace" resolves to this namespace object
try:
if hasattr(d, name):
return getattr(d, name)  # Prefer the attribute when present
return d.__dict__[name]  # Fall back to the instance dict
except ValueError:
raise UnboundLocalError(name)
except KeyError:
logger.info(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}', so will return None.")
raise AttributeError(name)
class ClassAttrNamespace(Namespace):
"""
Namespace of a class.
Args:
obj (Object): A python class object.
"""
def __init__(self, obj):
name = f'{obj.__module__}..<{obj.__class__.__name__}::{id(obj)}>'  # Namespace name
super().__init__(name, obj)
def __getattr__(self, name):
# Look up an attribute of the class instance by name
d, = self.dicts
try:
if hasattr(d, name):
return getattr(d, name)  # Prefer the attribute when present
return d.__dict__[name]  # Fall back to the instance dict
except ValueError:
raise UnboundLocalError(name)
except KeyError:
logger.info(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}', so will return None.")
raise AttributeError(name)

@ -17,7 +17,7 @@
"""Resources for ast tree parse."""
import ast
import math
from mindspore import RowTensor, SparseTensor, COOTensor, CSRTensor
from mindspore.ops import functional as F, composite as C
from mindspore.ops.composite import multitype_ops
@ -25,16 +25,16 @@ from mindspore._c_expression import security
from . import standard_method as M
from . import trope as T
from .namespace import CellNamespace
# namespace define
functional_ns = CellNamespace('mindspore.ops.functional')
composite_ns = CellNamespace('mindspore.ops.composite')
trope_ns = CellNamespace('mindspore._extends.parse.trope')
NO_IMPLEMENT = None # not implemented
SYMBOL_UNDEFINE = 0xFF # Undefined var and function
# Some space set aside for readability of code
parse_object_map = {
# ast grammar
ast.Add: (trope_ns, 'add'),
@ -64,17 +64,17 @@ parse_object_map = {
ast.IsNot: (trope_ns, 'is_not'),
ast.In: (trope_ns, 'contains'),
ast.NotIn: (trope_ns, 'not_contains'),
# operation symbol type
'getitem': (composite_ns, 'getitem'),
'ms_iter': (composite_ns, 'ms_iter'),
'ms_next': (composite_ns, 'ms_next'),
'hasnext': (composite_ns, 'hasnext'),
# undefined type
SYMBOL_UNDEFINE: (None, 'undefine'),
}
# Operation symbols corresponding to ast grammar
ops_symbol_map = {
# ast grammar
@ -88,13 +88,13 @@ ops_symbol_map = {
ast.LShift: '<<',
ast.RShift: '>>',
ast.BitXor: '^',
# undefined type
SYMBOL_UNDEFINE: '',
}
# Escape an object to another object, eg: system function(len,xxx)
# Some space set aside for readability of code
convert_object_map = {
T.add: multitype_ops.add,
T.sub: multitype_ops.sub,
@ -124,7 +124,7 @@ convert_object_map = {
T.is_not: F.is_not,
T.contains: multitype_ops.in_,
T.not_contains: multitype_ops.not_in_,
# system function
T.len: M.ms_len,
T.bool_: M.bool_,
@ -134,7 +134,7 @@ convert_object_map = {
T.zip: C.zip_operation,
T.enumerate: M.enumerate_,
T.isinstance: M.isinstance_,
# custom define operation
T.iter: M.ms_iter,
T.next: M.ms_next,
@ -145,7 +145,7 @@ convert_object_map = {
T.make_slice: F.make_slice,
T.range: F.make_range,
T.while_cond: M.while_cond,
# lib function
math.floor: NO_IMPLEMENT,
math.trunc: NO_IMPLEMENT,
@ -154,13 +154,14 @@ convert_object_map = {
math.sin: NO_IMPLEMENT,
math.cos: NO_IMPLEMENT,
math.tan: NO_IMPLEMENT,
# user defined
RowTensor: F.make_row_tensor,
SparseTensor: F.make_sparse_tensor,
COOTensor: F.make_coo_tensor,
CSRTensor: F.make_csr_tensor
}
# If security is not enabled, map T.print to F.print_
if not security.enable_security():
convert_object_map[T.print] = F.print_

@ -50,55 +50,45 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt',
def MakeTuple(*elts): # pragma: no cover
"""Tuple builder."""
raise RuntimeError('This operation is not meant to be called directly.')
def make_dict(key, value): # pragma: no cover
"""Dict builder."""
raise RuntimeError('This operation is not meant to be called directly.')
def make_list(*elts): # pragma: no cover
"""List builder."""
raise RuntimeError('This operation is not meant to be called directly.')
def make_slice(*elts): # pragma: no cover
"""Slice builder."""
raise RuntimeError('This operation is not meant to be called directly.')
def make_range(*elts): # pragma: no cover
"""Range tuple builder."""
raise RuntimeError('This operation is not meant to be called directly.')
def switch(cond, tb, fb): # pragma: no cover
"""Switch statement, returns one of the two values."""
raise RuntimeError('This operation is not meant to be called directly.')
def hasnext(it): # pragma: no cover
"""Hasnext function."""
raise RuntimeError('This operation is not meant to be called directly.')
def to_array(x):
"""The to_array function."""
raise RuntimeError('This operation is not meant to be called directly.')
def not_contains(x): # pragma: no cover
"""Not in function."""
raise RuntimeError('This operation is not meant to be called directly.')
def while_cond(x): # pragma: no cover
"""While condition function."""
raise RuntimeError('This operation is not meant to be called directly.')
def bool_(x): # pragma: no cover
"""Judge true function."""
raise RuntimeError('This operation is not meant to be called directly.')

@ -16,27 +16,37 @@
import os
from mindspore import log as logger
from mindspore._extends.parallel_compile.akg_compiler.akg_process import create_akg_parallel_process
class Messager:
'''Messager'''
def __init__(self, fdin, fdout):
"""
Initialize the Messager.
Args:
fdin: input file descriptor
fdout: output file descriptor
"""
self.fdin = fdin
self.fdout = fdout
self.fin = os.fdopen(fdin, "r")
self.fout = os.fdopen(fdout, "w")
self.message = ''
def __del__(self):
"""
Close the file descriptors when the Messager instance is deleted.
"""
os.close(self.fdin)
os.close(self.fdout)
def get_message(self):
"""
Get message from remote.
Returns:
message
"""
@ -58,13 +68,13 @@ class Messager:
self.send_ack()
self.exit()
return self.message
def send_res(self, res, keep_format=True):
"""
Send result to remote.
Args:
keep_format: True or False
"""
logger.debug(f"[OUT] {str(res)}")
if keep_format:
@ -72,7 +82,7 @@ class Messager:
else:
res_str = str(res).replace('\n', '').replace('\r', '').replace(' ', '')
tag = '[~]' # The same as client kTAG
# No longer written via print(tag + res_str, flush=True)
try:
self.fout.write(tag + res_str + "\n")
@ -82,69 +92,76 @@ class Messager:
self.exit()
finally:
pass
def send_ack(self, success=True):
"""
Send ack to remote.
Args:
success: True or False
"""
if success:
self.send_res('ACK')
else:
self.send_res('ERR')
def loop(self):
"""
Messaging loop.
"""
while True:
self.handle()
def run(self):
"""Run the messaging loop."""
self.loop()
def handle(self):
"""
An interface for communicating with the remote end.
Note:
All subclasses should override this interface.
"""
raise NotImplementedError
def exit(self):
"""
An interface handling the procedure before exit.
Note:
All subclasses should override this interface.
"""
"""
raise NotImplementedError
class AkgBuilder():
"""Akg building wrapper"""
def __init__(self, platform):
"""
初始化 AkgBuilder
Args:
platform: 平台标识
"""
self.platform = platform
self.attrs = None
def create(self, process_num, waitime):
""" Create akg processor"""
""" 创建 akg 处理器"""
self.akg_processor = create_akg_parallel_process(process_num, waitime, self.platform)
def accept_json(self, json):
""" Accept json"""
""" 接受 json 数据"""
return self.akg_processor.accept_json(json)
def compile(self):
"""Compile"""
"""编译"""
return self.akg_processor.compile(self.attrs)
def handle(self, messager, arg):
"""Handle message about akg"""
"""处理关于 akg 的消息"""
if arg == 'AKG/START':
messager.send_ack()
process_num_str = messager.get_message()
@ -172,7 +189,8 @@ class AkgBuilder():
break
else:
raise RuntimeError("Unknown message type: %s" % arg)
def get_logger():
"""Get the logger."""
return logger

@ -20,19 +20,24 @@ from mindspore._extends.remote.kernel_build_server import Messager, get_logger,
class AkgMessager(Messager):
'''
Default Messager for akg kernels.
It works as a server, communicating with the c++ client.
'''
def __init__(self, fdin, fdout):
"""
Initialize the AkgMessager instance.
:param fdin: input file descriptor
:param fdout: output file descriptor
"""
super().__init__(fdin, fdout)
get_logger().info("[TRACE] Akg Messager init...")
self.akg_builder = AkgBuilder("default")
def handle(self):
"""
Communicate with remote client.
Reference protocol between them at PR#4063
"""
arg = self.get_message()
if "AKG" in arg:
@ -42,11 +47,18 @@ class AkgMessager(Messager):
self.exit()
def exit(self):
"""
Exit the AkgMessager.
"""
get_logger().info("[TRACE] Akg Messager Exit...")
exit()
if __name__ == '__main__':
"""
Program entry: check the command-line arguments and initialize the AkgMessager.
"""
warnings.simplefilter("ignore")
if len(sys.argv) != 3:
raise Exception('Incorrect argv: {}'.format(sys.argv))

@ -16,23 +16,24 @@
import sys
import warnings
import json
from mindspore._extends.parallel_compile.tbe_compiler.tbe_job_manager import TbeJobManager
from mindspore._extends.remote.kernel_build_server import Messager, get_logger, AkgBuilder
class AscendMessager(Messager):
"""
Ascend Messager
It works as a server, communicating with c++ client.
"""
# Initialize with both TBE and AKG builders
def __init__(self, fdin, fdout):
super().__init__(fdin, fdout)
get_logger().info("[TRACE] Ascend Messager init...")
self.tbe_builder = TbeJobManager()
self.akg_builder = AkgBuilder("ASCEND")
# Handle communication with the remote client.
def handle(self):
"""
Communicate with remote client.
@ -51,7 +52,7 @@ class AscendMessager(Messager):
self.exit()
finally:
pass
if "job_type" in job_json:
res = self.tbe_builder.job_handler(arg)
self.send_res(res)
@ -59,17 +60,18 @@ class AscendMessager(Messager):
get_logger().error("[TRACE] Request is not a TBE Job message: {}".format(arg))
self.send_ack(False)
self.exit()
# Exit: reset the TBE builder and terminate.
def exit(self):
self.tbe_builder.reset()
get_logger().info("[TRACE] Ascend Messager Exit...")
exit()
if __name__ == '__main__':
warnings.simplefilter("ignore")
if len(sys.argv) != 3:
raise Exception('Incorrect argv: {}'.format(sys.argv))
get_logger().debug(f"[TRACE] argv: {str(sys.argv)}")
messager = AscendMessager(int(sys.argv[1]), int(sys.argv[2]))
messager.run()

@ -22,6 +22,21 @@ def cell_attr_register(fn=None, attrs=None):
"""
Cell init attributes register.
Args:
fn (function, optional): The __init__ function of the cell. Defaults to None.
attrs (list(string) | string, optional): A list of attributes to register.
Can be a list of strings or a single string. Defaults to None.
Returns:
function: The original function wrapped with attribute registration.
This function registers the init attributes of a cell class.
Via the decorator pattern, the arguments of the cell's __init__ are saved as operator attributes.
If fn is not provided, the decorator wrap_cell is returned; otherwise the wrapped __init__ is returned.
"""
"""
Cell init attributes register.
Registering the decorator of the built-in operator cell __init__
function will add save all the parameters of __init__ as operator attributes.
@ -34,8 +49,38 @@ def cell_attr_register(fn=None, attrs=None):
"""
def wrap_cell(fn):
"""
Decorator that records the initialization arguments of a class.

Args:
fn (function): The function to be decorated.

Returns:
function: A new function that records the arguments passed to fn when called.
"""
@wraps(fn)
def deco(self, *args, **kwargs):
"""
Decorator that records the initialization arguments of a class instance.

Args:
self: The class instance.
*args: Variable positional arguments passed to the decorated function.
**kwargs: Variable keyword arguments passed to the decorated function.
attrs: Optional; the attributes to record, a string or a list of strings.

Returns:
None

Raises:
ValueError: If attrs is not a string or a list of strings, or if an element of attrs is not a string.

This function records the arguments passed to __init__ when the instance is initialized.
If attrs is None, all arguments passed to __init__ (excluding self) are recorded.
If attrs is a string or a list of strings, only the named attributes are recorded.
The recorded arguments are saved as the instance's cell_init_args attribute, formatted as "class name + argument list".
"""
arguments = []
if attrs is None:
bound_args = inspect.signature(fn).bind(self, *args, **kwargs)

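# Illustrative usage sketch (not part of the source): decorating a cell's __init__
# with cell_attr_register so its arguments are recorded as described above.
# MyCell, its arguments, and the import path are assumptions.
from mindspore import nn
from mindspore.common.api import cell_attr_register  # import path assumed

class MyCell(nn.Cell):
    @cell_attr_register
    def __init__(self, in_channels, out_channels):
        super(MyCell, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

# After construction, the init arguments are expected in cell_init_args,
# formatted as "class name + argument list":
# net = MyCell(3, 16)
# print(net.cell_init_args)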
@ -19,16 +19,25 @@ accumulation and so on.
Note:
This feature is a beta feature, and we are still improving its functionality.
"""
# Import the AutoBoost class from this package's boost module.
from .boost import AutoBoost
# Import the OptimizerProcess and ParameterProcess classes from the base module.
from .base import OptimizerProcess, ParameterProcess
# Import the BoostTrainOneStepCell and BoostTrainOneStepWithLossScaleCell classes from the boost_cell_wrapper module.
from .boost_cell_wrapper import BoostTrainOneStepCell, BoostTrainOneStepWithLossScaleCell
# Import the LessBN class from the less_batch_normalization module.
from .less_batch_normalization import LessBN
# Import GradientFreeze, FreezeOpt and freeze_cell from the grad_freeze module.
from .grad_freeze import GradientFreeze, FreezeOpt, freeze_cell
# Import the GradientAccumulation class from the grad_accumulation module.
from .grad_accumulation import GradientAccumulation
# Import the AdaSum class from the adasum module.
from .adasum import AdaSum
# Import the DimReduce class from the dim_reduce module.
from .dim_reduce import DimReduce
# List of all public members of this package.
__all__ = ['AutoBoost',
'OptimizerProcess', 'ParameterProcess',
'BoostTrainOneStepCell', 'BoostTrainOneStepWithLossScaleCell',

@ -22,38 +22,41 @@ from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.operations._inner_ops import Send, Receive
__all__ = ["AdaSum"]
MAX_NUM_HASH = 2 ** 31
_update_parameters = C.MultitypeFuncGraph("update_parameters")
@_update_parameters.register("Tensor", "Tensor", "Tensor", "Tensor")
def _update_parameters_after_broadcast(delta_weight, update_delta_weight, parameter, old_parameter):
"""更新参数的函数在广播后应用delta_weight来更新参数."""
shape = F.shape(delta_weight)
update_delta_weight = P.Reshape()(update_delta_weight, shape)
new_parameter = old_parameter - update_delta_weight
return P.Assign()(parameter, new_parameter)
def _send_before_receive(send_part, send, recv):
"""在接收之前发送数据的辅助函数."""
send_ok = send(send_part)
return recv(send_ok)
def _receive_before_send(send_part, send, recv):
"""在发送之前接收数据的辅助函数."""
receive_ok = recv(send_part)
send_part = F.depend(send_part, receive_ok)
return F.depend(receive_ok, send(send_part))
def _send_recv_res(left_send, recv_part, local_part, allreduce, parameter_divisibility, allreduce_node_num):
"""send result and receive result."""
"""发送结果并接收结果的辅助函数."""
if parameter_divisibility:
recv_part = P.Squeeze()(recv_part)
local_part = F.depend(local_part, recv_part)
@ -76,14 +79,14 @@ def _send_recv_res(left_send, recv_part, local_part, allreduce, parameter_divisi
res = allreduce(local_part)
res /= allreduce_node_num
return res
_adasum_opt_forward = C.MultitypeFuncGraph("adasum_opt_forward")
@_adasum_opt_forward.register("Bool", "Function", "Bool", "Int64", "Function", "Function", "Tensor")
def _adasum_opt_forward_process(left_send, allreduce, parameter_divisibility, allreduce_node_num, send, recv, delta_w):
"""adasum optimizer process."""
"""adaSum优化器的前向过程处理函数."""
if parameter_divisibility:
delta_w = P.Squeeze()(delta_w)
ori_len = F.shape(delta_w)[0]
@ -93,7 +96,7 @@ def _adasum_opt_forward_process(left_send, allreduce, parameter_divisibility, al
else:
left_part = delta_w
right_part = delta_w
if left_send:
if parameter_divisibility:
recv_part = _send_before_receive(left_part, send, recv)
@ -108,26 +111,26 @@ def _adasum_opt_forward_process(left_send, allreduce, parameter_divisibility, al
recv_part = left_part
update_delta_w = _send_recv_res(left_send, recv_part, left_part, allreduce, parameter_divisibility,
allreduce_node_num)
return update_delta_w
_adasum_opt_rollback = C.MultitypeFuncGraph("adasum_opt_rollback")
@_adasum_opt_rollback.register("Bool", "Bool", "Tensor", "Function", "Function")
def _adasum_opt_rollback_process(left_send, parameter_divisibility, delta_w, send, recv):
"""adasum optimizer rollback process."""
"""adaSum优化器的回滚处理函数."""
if parameter_divisibility:
if left_send:
recv_part = _send_before_receive(delta_w, send, recv)
else:
recv_part = _receive_before_send(delta_w, send, recv)
recv_part = P.Squeeze()(recv_part)
recv_part = P.Reshape()(recv_part, (-1,))
delta_w = P.Reshape()(delta_w, (-1,))
if left_send:
res = P.Concat()((recv_part, delta_w))
else:
@ -135,28 +138,28 @@ def _adasum_opt_rollback_process(left_send, parameter_divisibility, delta_w, sen
else:
res = delta_w
return res
class AdaSum(Cell):
r"""
The Adaptive Summation, or AdaSum, is a novel algorithm for improving distributed data
parallel training of Deep Learning models.
Args:
rank (int): Rank number.
device_number (int): Device number.
group_number (int): Group number.
parameter_tuple (Tuple(Parameter)): Tuple of parameters.
Inputs:
- **delta_weights** (Tuple(Tensor)) - Tuple of gradients.
- **parameters** (Tuple(Parameter)) - Tuple of current parameters.
- **old_parameters** (Tuple(Parameter)) - Tuple of last parameters.
Outputs:
- **adasum_parameters** (Tuple(Tensor)) - Tuple of parameters after adasum process.
"""
def __init__(self, rank, device_number, group_number, parameter_tuple):
"""AdaSum类的初始化函数."""
super(AdaSum, self).__init__()
self.rank = rank
self.device_number = device_number
@ -164,9 +167,9 @@ class AdaSum(Cell):
self.parameter_tuple = parameter_tuple
self._generate_communication_op()
self.hyper_map = C.HyperMap()
def _generate_communication_op(self):
"""generate communication op."""
"""生成通信操作的私有方法."""
self.calc_times = int(math.log(self.group_number, 2))
self.send_node = []
self.send_list_forward = []
@ -179,7 +182,7 @@ class AdaSum(Cell):
self.allreduce_node_num_list = []
last_delta_weights = []
group_start_rank = (self.rank // self.device_number) * self.device_number
for step in range(self.calc_times):
current_group = self.device_number * (2 ** step)
sr_target = self.rank
@ -189,7 +192,7 @@ class AdaSum(Cell):
else:
dest_target = sr_target - current_group
self.send_node.append(False)
neighbor_ids = []
group_name_last = 0
for index in range(2 ** (step + 1)):
@ -201,7 +204,7 @@ class AdaSum(Cell):
group_name_last += neighbor_id
group_name = "adasum_" + str(step) + "_" + str(group_name_last)
create_group(group_name, neighbor_ids)
send_left = []
send_right = []
recv_left = []
@ -234,7 +237,7 @@ class AdaSum(Cell):
send_right.append(send)
recv_right.append(recv)
weights_index += 1
if self.send_node and self.send_node[-1]:
self.send_list_forward.append(send_left)
self.send_list_rollback.append(send_right)
@ -247,27 +250,27 @@ class AdaSum(Cell):
self.recv_list_forward.append(recv_left)
self.recv_list_rollback.append(recv_right)
last_delta_weights = left_delta_weights
server_all_reduce = P.AllReduce("sum", group_name)
server_all_reduce.add_prim_attr("fusion", fusion_id + 2)
self.allreduce_list.append(server_all_reduce)
for param_divisibility in delta_weights_divisibility:
if param_divisibility:
allreduce_node_num += (0,)
else:
allreduce_node_num += (2 ** (step + 1),)
self.allreduce_node_num_list.append(allreduce_node_num)
broadcast_group = [x for x in range(group_start_rank, group_start_rank + self.device_number)]
broadcast_group_name = "broadcast_group_" + str(group_start_rank)
create_group(broadcast_group_name, broadcast_group)
for b_rank in range(len(broadcast_group)):
self.broadcast_list.append(P.Broadcast(b_rank, group=broadcast_group_name))
self.sync_barrier = P.AllReduce("sum", group=broadcast_group_name)
def _get_delta_weights_info(self, last_delta_weights):
"""get delta weights info."""
"""获取delta权重信息的私有方法."""
half_delta_weights = []
if last_delta_weights:
half_delta_weights = last_delta_weights
@ -292,14 +295,16 @@ class AdaSum(Cell):
right_delta_weights.append((right_shape, dtype))
delta_weights_divisibility += (divisibility_flag,)
return left_delta_weights, right_delta_weights, delta_weights_divisibility
def _hash(self, step, target, weights_index):
"""计算哈希值的私有方法."""
target = "tag" + str(step) + str(target) + str(weights_index)
target_hash = hashlib.sha1(target.encode()).hexdigest()
hash_res = int(int(target_hash, 16) % MAX_NUM_HASH)
return hash_res
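# Worked example (illustrative, not part of the source): the tag hash above,
# reproduced standalone. It builds "tag<step><target><weights_index>", takes the
# SHA-1 hex digest, and reduces it modulo 2**31:
#   import hashlib
#   tag = "tag" + str(0) + str(3) + str(7)
#   sr_tag = int(int(hashlib.sha1(tag.encode()).hexdigest(), 16) % (2 ** 31))
# This yields a stable id in [0, 2**31) used as the send/receive tag.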
def construct(self, delta_weights, parameters, old_parameters):
"""构建方法用于执行adaSum优化过程."""
forward_weights = [delta_weights]
for i in range(self.calc_times):
process_weights = self.hyper_map(F.partial(_adasum_opt_forward, self.send_node[i], self.allreduce_list[i]),
@ -314,4 +319,4 @@ class AdaSum(Cell):
forward_weights[j] = process_weights
adasum_parameters = self.hyper_map(F.partial(_update_parameters), delta_weights, forward_weights[0],
parameters, old_parameters)
return adasum_parameters

@ -452,11 +452,14 @@ class _GeneratorWorkerMp(multiprocessing.Process):
"""
def __init__(self, dataset, eof, max_rowsize, queue_size, ppid):
# Initialize a multiprocessing queue to store indices.
self.idx_queue = multiprocessing.Queue(queue_size)
# If shared memory is enabled, use a shared queue; otherwise use a multiprocessing queue.
if get_enable_shared_mem():
self.res_queue = _SharedQueue(queue_size, max_rowsize=max_rowsize)
else:
self.res_queue = multiprocessing.Queue(queue_size)
# Set _joincancelled to True so the queues do not block on process exit.
self.idx_queue._joincancelled = True # pylint: disable=W0212
self.res_queue._joincancelled = True # pylint: disable=W0212
super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, True, ppid))
@ -465,6 +468,7 @@ class _GeneratorWorkerMp(multiprocessing.Process):
"""
Put function for worker index queue. Never block. Raise queue.Full on failure.
"""
# Put the item into idx_queue without blocking; queue.Full is raised on failure.
self.idx_queue.put_nowait(item)
def get(self):
@ -476,12 +480,19 @@ class _GeneratorWorkerMp(multiprocessing.Process):
return self.res_queue.get(timeout=30)
def queue_empty(self):
# Check whether idx_queue is empty.
if not self.idx_queue.empty():
# If not empty, log a warning.
logger.warning("idx_queue is not empty.")
return False
# Check whether res_queue is empty.
if not self.res_queue.empty():
# If not empty, log a warning.
logger.warning("res_queue is not empty.")
return False
# Both queues are empty.
return True
def __del__(self):
@ -632,14 +643,17 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None,
num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None,
python_multiprocessing=True, max_rowsize=6):
# Call the parent class initializer.
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id)
# If source is a zip object, convert it to a list.
if isinstance(source, builtins.zip):
# Although zip is iterable, it cannot be iterated repeatedly, so copy it into a list.
self.source = [item for item in source]
else:
self.source = source
self.prepared_source = None # source to be sent to C++
# If the operator_mixed attribute is True, force num_parallel_workers to 1.
if hasattr(self, 'operator_mixed') and getattr(self, 'operator_mixed') is True:
self.num_parallel_workers = 1
logger.warning(
@ -650,56 +664,78 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
self.python_multiprocessing = python_multiprocessing
# Convert column_names to a list.
self.column_names = to_list(column_names)
# If column_types is given, convert it to a de-type list.
if column_types is not None:
self.column_types = mstypelist_to_detypelist(column_types)
else:
self.column_types = []
self.schema = schema
# If schema is given, normalize it to a Schema object.
if schema is not None:
self.schema = schema
# Convert schema to a Schema instance if it is not one already.
if not isinstance(schema, Schema):
self.schema = Schema(schema)
# Move get dataset_size by len from parse to here, because self.source will
# lose its '__len__' attribute after deepcopy.
self.source_len = -1 # unknown
# If self.source has __len__, record its length.
if hasattr(self.source, "__len__"):
self.source_len = len(self.source)
# Set the maximum row size.
self.max_rowsize = max_rowsize
# The sample function defaults to None.
self.sample_fn = None
def __deepcopy__(self, memodict):
# Deep-copy this object; memodict caches objects that have already been copied.
if id(self) in memodict:
# If this object's id is already in memodict, return the cached copy.
return memodict[id(self)]
# Otherwise call __safe_deepcopy__ with memodict, excluding some attributes.
new_op = self.__safe_deepcopy__(memodict, exclude=("source", "__transfer_dataset__"))
sample_fn = None
# If the new op has a sampler and self.source supports __getitem__:
if new_op.sampler is not None and hasattr(self.source, "__getitem__"):
# The reason why there is a try catch here is because when the new op is being constructed with shared
# memory enabled, there will be an exception thrown if there is not enough shared memory available
# If source_len is -1, raise RuntimeError: a random access dataset requires __len__.
if self.source_len == -1:
raise RuntimeError("Attempt to construct a random access dataset, '__len__' method is required!")
try:
# If the new op uses more than one worker, validate memory usage first.
if new_op.num_parallel_workers > 1:
self.__validate_memory_usage()
# Create a SamplerFn object for parallel sampling.
sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing,
self.max_rowsize)
# Use _cpp_sampler_fn_mp as prepared_source for multiprocess sampling.
new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn))
else:
# Otherwise use _cpp_sampler_fn for single-thread sampling.
new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source))
# Store sample_fn on the new op.
new_op.sample_fn = sample_fn
except RuntimeError as e:
# Re-raise a RuntimeError as a generic Exception with its message.
raise Exception(str(e))
else:
try:
# Otherwise clear the sampler and keep sample_fn.
new_op.sampler = None
new_op.sample_fn = sample_fn
# Clamp source_len to num_samples when num_samples is non-zero.
new_op.source_len = min(new_op.source_len,
new_op.num_samples) if new_op.num_samples != 0 else new_op.source_len
# Check that self.source is iterable.
iter(self.source)
except TypeError:
# Use generator function if input callable
@ -711,19 +747,26 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
return new_op
# Whether the dataset is shuffled.
def is_shuffled(self):
return self.sampler.is_shuffled()
# Whether the dataset is sharded.
def is_sharded(self):
return self.sampler.is_sharded()
# Parse this dataset node.
def parse(self, children=None):
# If schema is None, return a GeneratorNode without a schema.
if self.schema is None:
return cde.GeneratorNode(self.prepared_source, self.column_names, self.column_types, self.source_len,
self.sampler, self.num_parallel_workers)
# Get the schema.
schema = self.schema
# If it is a Schema object, use its cpp_schema.
if isinstance(schema, Schema):
schema = self.schema.cpp_schema
# Return a GeneratorNode with the schema.
return cde.GeneratorNode(self.prepared_source, schema, self.source_len, self.sampler,
self.num_parallel_workers)
@ -735,24 +778,37 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
# If num_parallel_workers is too large when python_multiprocessing=True,
# it may cause an OOM error; first get the num_shards.
valid_num_shards = 1
# If self.sampler is a samplers.DistributedSampler, take its num_shards.
if isinstance(self.sampler, samplers.DistributedSampler):
valid_num_shards = self.sampler.num_shards
# Otherwise fall back to self.num_shards when it is set.
elif self.num_shards is not None:
valid_num_shards = self.num_shards
# get process memory usage
# Get the current process.
process = psutil.Process(os.getpid())
# Get the current process's resident memory.
process_memory = process.memory_info().rss
# Get the system's free memory.
sys_memory_free = psutil.virtual_memory().free
# Estimate the total memory that may be used.
total_memory_maybe_used = process_memory * self.num_parallel_workers * valid_num_shards
# If the estimate exceeds 85% of free memory, suggest a smaller worker count.
if total_memory_maybe_used / sys_memory_free > 0.85:
# valid workers = free memory * 0.85 / shards / per-process memory.
valid_num_worker = math.floor(sys_memory_free * 0.85 / valid_num_shards / process_memory)
# Clamp the worker count to at least 1.
valid_num_worker = 1 if valid_num_worker <= 0 else valid_num_worker
# Warn that num_parallel_workers is too large and may cause high memory usage or OOM.
info = "GeneratorDataset's num_parallel_workers: {} is too large which may cause a lot of memory " \
"occupation (>85%) or out of memory(OOM) during multiprocessing. Therefore, it is recommended " \
"to reduce num_parallel_workers to {} or smaller.".format(self.num_parallel_workers,
valid_num_worker)
logger.warning(info)
@ -764,37 +820,55 @@ class _NumpySlicesDataset:
def __init__(self, data, column_list=None):
self.column_list = None
# Convert dict data into tuple
# If data is a dict, convert it via process_dict.
if isinstance(data, dict):
data = self.process_dict(data)
# If data is a tuple, convert each element to a numpy array.
if isinstance(data, tuple):
self.data = ()
data_len = len(data)
for i in range(data_len):
self.data = self.data + (np.array(data[i]),)
else:
# Otherwise wrap data itself as a single numpy array column.
self.data = (np.array(data),)
# check whether the data length in each column is equal
# Get the length of each column.
data_len = [len(data_item) for data_item in self.data]
# If the column lengths differ, raise ValueError.
if data_len[1:] != data_len[:-1]:
raise ValueError("Data length in each column is not equal.")
# Init column_name
# If column_list is given, use it.
if column_list is not None:
self.column_list = column_list
# Otherwise generate default column names.
elif self.column_list is None:
self.column_list = []
column_num = len(self.data)
for i in range(column_num):
self.column_list.append("column_" + str(i))
def __getitem__(self, index):
# Fetch the row at the given index from each column.
data_row = [d[index, ...] for d in self.data]
# Return the row as a tuple.
data_res = tuple(data_row)
return data_res
def __len__(self):
# Return the length of the first column.
return len(self.data[0])
def process_dict(self, input_data):
@ -802,24 +876,29 @@ class _NumpySlicesDataset:
Convert the dict like data into tuple format, when input is a tuple of dicts then compose it into a dict first.
"""
# Convert pandas like dict(has "values" column) into General dict
data_keys = list(input_data.keys())
# Get the value of the first key.
data_col = input_data[data_keys[0]]
# If the values have a `values` attribute, convert to a general dict.
if hasattr(data_col, "values"):
new_dict = {}
for key in data_keys:
# Replace each value with its underlying values array.
item1 = input_data.pop(key)
new_dict[key] = item1.values
input_data = new_dict
# Convert the data in dict into tuple
data = ()  # Initialize an empty tuple.
keys = list(input_data.keys())  # Collect the input dict keys.
self.column_list = keys  # Use the keys as the column list.
for key in keys:  # Iterate over the keys.
value = input_data[key]  # Get the value for each key.
data = data + (list(value),)  # Append the value as a list to the tuple.
return data
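# Illustrative usage sketch (not part of the source): how _NumpySlicesDataset
# flattens dict data into per-column numpy arrays, per the logic above. The
# sample data is made up.
sample = {"label": [0, 1, 0], "score": [0.1, 0.9, 0.3]}
ds = _NumpySlicesDataset(sample)
# ds.column_list -> ['label', 'score']
# len(ds) -> 3, and ds[1] -> (array(1), array(0.9))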
class NumpySlicesDataset(GeneratorDataset):
@ -909,7 +988,9 @@ class NumpySlicesDataset(GeneratorDataset):
@check_numpyslicesdataset
def __init__(self, data, column_names=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None,
num_shards=None, shard_id=None):
# Create a _NumpySlicesDataset from data and column_names.
dataset = _NumpySlicesDataset(data, column_names)
# Call the parent initializer with the dataset and related parameters.
super().__init__(dataset, column_names=dataset.column_list, num_samples=num_samples,
num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler,
num_shards=num_shards, shard_id=shard_id)

@ -17,65 +17,71 @@ from ..._checkparam import Validator as validator
from ...common import dtype as mstype
from ..primitive import prim_attr_register, PrimitiveWithCheck
from .. import signature as sig
class UpdateCache(PrimitiveWithCheck):
"""
Update the value of input_x, similar to ScatterNdUpdate.
The difference is that UpdateCache will not update when indices < 0 or indices >= max_num.

Inputs:
- **input_x** (Parameter) - Parameter which is going to be updated.
- **indices** (Tensor) - Update indices of input_x.
- **updates** (Tensor) - The update values.

Outputs:
- **out** (Tensor) - Returns a [1] Tensor, which is not useful.
"""
# Primitive signature: input dtypes and read/write permissions.
__mindspore_signature__ = (
# input_x: dtype T, read-write.
sig.make_sig('input_x', sig.sig_rw.RW_WRITE,
dtype=sig.sig_dtype.T),
# indices: dtype T1.
sig.make_sig('indices', dtype=sig.sig_dtype.T1),
# updates: dtype T.
sig.make_sig('updates', dtype=sig.sig_dtype.T),
# max_num: dtype T1.
sig.make_sig('max_num', dtype=sig.sig_dtype.T1)
)
@prim_attr_register
def __init__(self):
"""init UpdateCache"""
"""初始化 UpdateCache"""
# 初始化输入和输出名称
self.init_prim_io_names(inputs=['input_x', 'indices', 'update', 'max_num'],
outputs=['out'])
def check_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape):
# Check input shapes.
return [1]
def check_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype):
# Check input dtypes.
validator.check_tensor_dtype_valid(
"indices", indices_dtype, mstype.int_type, self.name)
return input_x_dtype
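# Reference-semantics sketch (illustrative, not part of the source): per the
# docstring above, UpdateCache behaves like a scatter update that skips any
# index outside [0, max_num). The helper name and shapes are assumptions.
import numpy as np

def update_cache_ref(input_x, indices, updates, max_num):
    for i, idx in enumerate(indices):
        if 0 <= idx < max_num:
            input_x[idx] = updates[i]
    return np.zeros(1)  # the [1]-shaped output is not useful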
class SubAndFilter(PrimitiveWithCheck):
"""
Dynamic kernel that subtracts an offset and
returns the elements in the range [0, max_num).
Inputs:
- **input_x** (Tensor) - Input tensor.
- **max_num** (Int) - The max value of an element after subtracting `offset`.
- **offset** (int) - Specifies the offset value of this `input_x`.

Outputs:
tuple(Tensor), tuple of 2 tensors, filter_res and filter_idx.
- **filter_res** (Tensor) - The result of `input_x` minus `offset`,
keeping only values in the range [0, max_num).
- **filter_idx** (Tensor) - A tensor containing the indices of elements in the input
corresponding to the output tensor.
Supported Platforms:
`CPU`
Examples:
>>> x = Tensor(np.array([1, 3, 5, 8, 9, 16]), mindspore.int32)
>>> max_num = 10
@ -87,35 +93,38 @@ class SubAndFilter(PrimitiveWithCheck):
"""
@prim_attr_register
def __init__(self):
"""init SubAndFilter"""
"""初始化 SubAndFilter"""
# 初始化输入和输出名称
self.init_prim_io_names(inputs=['input_x', 'max_num', 'offset'],
outputs=['sub_res', 'sub_idx'])
def check_shape(self, input_x_shape, max_num_shape, offset_shape):
# Check input shapes.
return ((-1,), (-1,))
def check_dtype(self, input_x_dtype, max_num_dtype, offset_dtype):
# Check input dtypes.
validator.check_tensor_dtype_valid(
"input_x", input_x_dtype, mstype.int_type, self.name)
return input_x_dtype
class MapUniform(PrimitiveWithCheck):
"""
Map a tensor by using the formula: value = key % `group_num` * `per_group_size` + key // `group_num`.
Inputs:
- **input** (Tensor) - Input Tensor.
- **per_group_size** (int) - The size of each group.
- **group_num** (int) - The number of group.
Outputs:
Tensor, has the same dtype and shape as the `input`.
Supported Platforms:
`CPU`
Examples:
>>> input_x = Tensor(np.array([0, 1, 2, 3, 4, 5, 6, 7]))
>>> per_group_size = 4
@ -125,33 +134,34 @@ class MapUniform(PrimitiveWithCheck):
>>> print(output)
[0, 4, 1, 5, 2, 6, 3, 7]
"""
@prim_attr_register
def __init__(self):
"""init MapUniform"""
"""初始化 MapUniform"""
self.init_prim_io_names(inputs=['input', 'per_group_size', 'group_num'],
outputs=['output'])
def check_dtype(self, input_dtype, per_group_size_dtype, group_num_dtype):
"""检查输入数据类型"""
validator.check_tensor_dtype_valid(
"input", input_dtype, mstype.int_type, self.name)
validator.check_value_type(
'per_group_size', per_group_size_dtype, [mstype.Int], self.name)
validator.check_value_type(
'group_num', group_num_dtype, [mstype.Int], self.name)
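# Worked example (illustrative, not part of the source): the mapping formula
# value = key % group_num * per_group_size + key // group_num, checked in plain
# Python for the docstring example above.
per_group_size, group_num = 4, 2
print([k % group_num * per_group_size + k // group_num for k in range(8)])
# -> [0, 4, 1, 5, 2, 6, 3, 7]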
class CacheSwapTable(PrimitiveWithCheck):
"""
Delete a hashmap entry, and insert a new key into the hashmap; return the key and value of the deleted entry.
Inputs:
- **cache_table** (Parameter) - The cache table which is on device.
- **swap_cache_idx** (Tensor) - The indices of the table to swap; -1 is skipped.
- **miss_value** (int) - The values which are going to be swapped into the cache table.
Outputs:
- **old_value** (Tensor) - The values which are swapped out.
"""
__mindspore_signature__ = (
sig.make_sig('cache_table', sig.sig_rw.RW_WRITE,
@ -159,31 +169,35 @@ class CacheSwapTable(PrimitiveWithCheck):
sig.make_sig('swap_cache_idx', dtype=sig.sig_dtype.T1),
sig.make_sig('miss_value', dtype=sig.sig_dtype.T)
)
@prim_attr_register
def __init__(self):
"""init CacheSwapTable"""
"""初始化 CacheSwapTable"""
self.init_prim_io_names(inputs=['cache_table', 'swap_cache_idx', 'miss_value'],
outputs=['old_value'])
def check_shape(self, cache_table_shape, swap_cache_idx_shape, miss_value_shape):
# The cache table must be 2-D.
if len(cache_table_shape) != 2:
raise ValueError(
"cache table shape must be 2, but got %d" % len(cache_table_shape))
return miss_value_shape
def check_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype):
# Validate that swap_cache_idx_dtype is an int type.
validator.check_tensor_dtype_valid(
"swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name)
return miss_value_dtype
class MapCacheIdx(PrimitiveWithCheck):
"""
MapCacheIdx merges SearchCacheIdx, CacheSwapHashmap and UpdateCache together.
Given an indices tensor, it outputs the cache indices found by searching the hashmap.
"""
__mindspore_signature__ = (
sig.make_sig('hashmap', sig.sig_rw.RW_WRITE,
@ -193,56 +207,69 @@ class MapCacheIdx(PrimitiveWithCheck):
sig.make_sig('emb_max_num', dtype=sig.sig_dtype.T),
sig.make_sig('cache_max_num', dtype=sig.sig_dtype.T)
)
@prim_attr_register
def __init__(self):
"""init MapCacheIdx"""
"""初始化 MapCacheIdx"""
self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'offset'],
outputs=['cache_idx', 'old_emb_idx', 'miss_emb_idx', 'swap_cache_idx'])
def __check__(self, hashmap, indices, step, emb_max_num, offset):
# Get the shape of the hashmap.
hashmap_shape = hashmap['shape']
# The hashmap must be 2-D.
if len(hashmap_shape) != 2:
raise ValueError("The dimension of 'hashmap' in SearchCacheIdx must be 2, "
"but got %d." % len(hashmap_shape))
# Set the output shape.
out_shape = (indices['shape'], -1, -1, -1)
# Get the dtypes of hashmap and indices.
hashmap_dtype = hashmap['dtype']
indices_dtype = indices['dtype']
# Collect the dtypes for validation.
args = {"hashmap": hashmap_dtype, "indices": indices_dtype}
# Check that the dtypes are the same and valid.
validator.check_tensors_dtypes_same_and_valid(
args, mstype.int_type, self.name)
# Set the output dtypes.
out_dtype = (hashmap_dtype, hashmap_dtype,
hashmap_dtype, hashmap_dtype)
# Build the output dict.
out = {'shape': out_shape,
'dtype': out_dtype,
'value': None}
# Propagate max_shape from indices if present, otherwise use the indices shape.
if 'max_shape' in indices:
out['max_shape'] = (indices['max_shape'], indices['max_shape'],
indices['max_shape'], indices['max_shape'])
else:
out['max_shape'] = (indices['shape'], indices['shape'],
indices['shape'], indices['shape'])
# Propagate min_shape from indices if present, otherwise use zeros.
if 'min_shape' in indices:
out['min_shape'] = (indices['min_shape'], 0, 0, 0)
else:
out['min_shape'] = (0, 0, 0, 0)
return out
class DynamicAssign(PrimitiveWithCheck):
"""
Assigns `Parameter` with a value; the `value` can have a dynamic shape.
Inputs:
- **variable** (Parameter) - The `Parameter`.
- **value** (Tensor) - The value to be assigned.
- **variable** (Parameter) - `Parameter`
- **value** (Tensor) - 要分配的值
Outputs:
Tensor, has the same type as original `variable`.
Supported Platforms:
`CPU`
"""
@ -250,41 +277,42 @@ class DynamicAssign(PrimitiveWithCheck):
sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('value', dtype=sig.sig_dtype.T)
)
@prim_attr_register
def __init__(self):
self.init_prim_io_names(inputs=['ref', 'value'], outputs=['output'])
def check_dtype(self, variable, value):
# Skip the check when the variable is a ref key.
if variable != mstype.type_refkey:
# Check that the variable is a number-typed tensor.
validator.check_tensor_dtype_valid(
"variable", variable, mstype.number_type, self.name)
# Check that the value is a number-typed scalar or tensor.
validator.check_scalar_or_tensor_types_same(
{"value": value}, mstype.number_type, self.name)
class PadAndShift(PrimitiveWithCheck):
"""
Pad a tensor with -1, and shift it by a length.
Inputs:
- **input_x** (Tensor) - The input Tensor, which will be copied
to `output`.
- **cum_sum_arr** (Tensor) - The last value of cum_sum_arr is
the pad length of the output tensor, cum_sum_arr[shift_idx] is
the start of the shift, and cum_sum_arr[shift_idx+1] is the end.
- **shift_idx** (Int) - The index into cum_sum_arr.
If implemented in Python, PadAndShift is:
output = [-1] * cum_sum_arr[-1]
start = cum_sum_arr[shift_idx]
end = cum_sum_arr[shift_idx + 1]
output[start:end] = input_x[:(end-start)]
Outputs:
Tensor, has the same type as original `variable`.
Supported Platforms:
`CPU`
Examples:
>>> input_x = Tensor(np.array([9, 13, -1, -1, -1, -1, -1, -1]), mstype.int32)
>>> cum_sum_arr = Tensor(np.array([0, 3, 5]), mstype.int32)
@ -296,11 +324,14 @@ class PadAndShift(PrimitiveWithCheck):
"""
@prim_attr_register
def __init__(self):
# Initialize input and output names.
self.init_prim_io_names(
inputs=['input_x', 'cum_sum_arr', 'shift_idx'], outputs=['output'])
def check_shape(self, input_x_shape, cum_sum_arr_shape, shift_idx_shape):
# Check input shapes.
return input_x_shape
def check_dtype(self, input_x_dtype, cum_sum_arr_dtype, shift_idx_dtype):
# Check input dtypes.
return input_x_dtype
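# Worked example (illustrative, not part of the source): the Python reference
# semantics from the docstring above, evaluated for the docstring inputs.
input_x = [9, 13, -1, -1, -1, -1, -1, -1]
cum_sum_arr = [0, 3, 5]
shift_idx = 1
output = [-1] * cum_sum_arr[-1]
start, end = cum_sum_arr[shift_idx], cum_sum_arr[shift_idx + 1]
output[start:end] = input_x[:(end - start)]
print(output)  # -> [-1, -1, -1, 9, 13]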

@ -12,39 +12,39 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Operators for TensorArray."""
import mindspore as ms
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype
from ..primitive import prim_attr_register, PrimitiveWithInfer, Primitive
class TensorArray(PrimitiveWithInfer):
r"""
TensorArrayCreate is used to create a TensorArray and return a unique handle.
.. warning::
This is an experimental prototype that is subject to change and/or deletion.
Args:
dtype (mindspore.dtype): the data type in the TensorArray.
element_shape (tuple[int]): the shape of each tensor in a TensorArray.
dynamic_size (bool): If true the TensorArray can increase the size. Default: True.
size (int): The size of the TensorArray if dynamic_size = False.
name (string): the name of this TensorArray. Default: "TA".
Inputs:
None.
Outputs:
- **output** (Tensor[mindspore.int64]) - a unique handle bound to the TensorArray.
Supported Platforms:
``GPU`` ``CPU``
Examples:
>>> import mindspore
>>> import mindspore.ops as ops
@ -55,6 +55,7 @@ class TensorArray(PrimitiveWithInfer):
"""
@prim_attr_register
def __init__(self, dtype, element_shape, dynamic_size=True, size=0, name="TA"):
"""初始化TensorArray类设置参数和属性."""
validator.check_type_name("dtype", dtype, mstype.number_type + (mstype.bool_,), self.name)
validator.check_int(size, 0, Rel.GE, "size", self.name)
self.add_prim_attr('dtype', dtype)
@ -63,32 +64,34 @@ class TensorArray(PrimitiveWithInfer):
self.add_prim_attr('size', size)
self.add_prim_attr('side_effect_mem', True)
self.add_prim_attr('name', name)
def infer_shape(self):
"""推断输出形状."""
return ()
def infer_dtype(self):
"""推断输出数据类型."""
return mstype.int64
class TensorArrayWrite(PrimitiveWithInfer):
r"""
TensorArrayWrite is used to write a tensor into a created TensorArray.
.. warning::
This is an experimental prototype that is subject to change and/or deletion.
Inputs:
- **index** (Tensor[int64]) - The position to write.
- **value** (Tensor) - The value to add into the TensorArray.
- **handle** (Tensor[int64]) - The handle pointed to the TensorArray.
Outputs:
None.
Supported Platforms:
``GPU`` ``CPU``
Examples:
>>> import mindspore
>>> import mindspore.ops as ops
@ -99,39 +102,42 @@ class TensorArrayWrite(PrimitiveWithInfer):
"""
@prim_attr_register
def __init__(self):
"""初始化TensorArrayWrite类."""
self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, handle_shape, index_shape, value_shape):
"""推断输出形状."""
return ()
def infer_dtype(self, handle_type, index_type, value_type):
"""推断输出数据类型."""
validator.check_type_name("handle", handle_type, (ms.int64), self.name)
validator.check_type_name("index", index_type, (int, ms.int64), self.name)
validator.check_type_name("value", value_type, mstype.number_type + (mstype.bool_,), self.name)
return mstype.int64
class TensorArrayRead(PrimitiveWithInfer):
r"""
TensorArrayRead is used to read a tensor from a created TensorArray at the given index.
.. warning::
This is an experimental prototype that is subject to change and/or deletion.
Args:
dtype (mindspore.dtype): the data type in the TensorArray.
element_shape (tuple[int]): the shape of each tensor in a TensorArray.
Inputs:
- **index** (Tensor[int64]) - The position to read.
- **handle** (mindspore.int64) - The handle pointed to the TensorArray.
Outputs:
- **output** (Tensor) - the value in position index.
Supported Platforms:
``GPU`` ``CPU``
Examples:
>>> import mindspore
>>> import mindspore.ops as ops
@ -146,38 +152,41 @@ class TensorArrayRead(PrimitiveWithInfer):
"""
@prim_attr_register
def __init__(self, dtype, element_shape):
"""初始化TensorArrayRead类设置参数和属性."""
validator.check_type_name("dtype", dtype, mstype.number_type + (mstype.bool_,), self.name)
self.add_prim_attr('dtype', dtype)
self.add_prim_attr('element_shape', element_shape)
self.add_prim_attr('side_effect_mem', True)
self.dtype = dtype
self.shape = element_shape
def infer_shape(self, handle_shape, index_shape):
"""推断输出形状."""
return self.shape
def infer_dtype(self, handle_type, index_type):
"""推断输出数据类型."""
validator.check_type_name("handle", handle_type, (ms.int64), self.name)
validator.check_type_name("index", index_type, (int, ms.int64), self.name)
return self.dtype
class TensorArrayClose(PrimitiveWithInfer):
r"""
TensorArrayClose is used to close the created TensorArray; the resources in the TensorArray will be deleted.
.. warning::
This is an experimental prototype that is subject to change and/or deletion.
Inputs:
- **handle** (mindspore.int64) - The handle pointed to the TensorArray.
Outputs:
None.
Supported Platforms:
``GPU`` ``CPU``
Examples:
>>> import mindspore
>>> import mindspore.ops as ops
@ -188,32 +197,35 @@ class TensorArrayClose(PrimitiveWithInfer):
"""
@prim_attr_register
def __init__(self):
"""初始化TensorArrayClose类."""
self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, handle_shape):
"""推断输出形状."""
return ()
def infer_dtype(self, handle_type):
"""推断输出数据类型."""
validator.check_type_name("handle", handle_type, (ms.int64), self.name)
return mstype.int64
class TensorArrayClear(PrimitiveWithInfer):
r"""
TensorArrayClear is used to reset the created TensorArray. The TensorArray instance is still available.
.. warning::
This is an experimental prototype that is subject to change and/or deletion.
Inputs:
- **handle** (mindspore.int64) - The handle pointed to the TensorArray.
Outputs:
None.
Supported Platforms:
``GPU`` ``CPU``
Examples:
>>> import mindspore
>>> import mindspore.ops as ops
@ -224,36 +236,39 @@ class TensorArrayClear(PrimitiveWithInfer):
"""
@prim_attr_register
def __init__(self):
"""初始化TensorArrayClear类."""
self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, handle_shape):
"""推断输出形状."""
return ()
def infer_dtype(self, handle_type):
"""推断输出数据类型."""
validator.check_type_name("handle", handle_type, (ms.int64), self.name)
return mstype.int64
class TensorArrayStack(Primitive):
r"""
TensorArrayStack is used to stack the tensors in a created TensorArray into one tensor.
.. warning::
This is an experimental prototype that is subject to change and/or deletion.
Args:
dtype (mindspore.dtype): the data type in the TensorArray.
element_shape (tuple[int]): the shape of each tensor in a TensorArray.
Inputs:
- **handle** (mindspore.int64) - The handle pointed to the TensorArray.
Outputs:
- **output** (Tensor) - the stacked value from the TensorArray.
Supported Platforms:
``GPU`` ``CPU``
Examples:
>>> import mindspore
>>> import mindspore.ops as ops
@ -269,31 +284,31 @@ class TensorArrayStack(Primitive):
"""
@prim_attr_register
def __init__(self, dtype, element_shape, dynamic_size, size):
"""Initialize TensorArrayStack"""
"""初始化TensorArrayStack类设置参数和属性."""
self.init_prim_io_names(inputs=[''], outputs=['output'])
self.add_prim_attr('dtype', dtype)
self.add_prim_attr('element_shape', element_shape)
self.add_prim_attr('is_dynamic_shape', dynamic_size)
self.add_prim_attr('size', size)
self.add_prim_attr('side_effect_mem', True)
class TensorArraySize(PrimitiveWithInfer):
r"""
TensorArraySize is used to get the logical size of the created TensorArray.
.. warning::
This is an experimental prototype that is subject to change and/or deletion.
Inputs:
- **handle** (mindspore.int64) - The handle pointed to the TensorArray.
Outputs:
- **output** (Tensor[mindspore.int64]) - the logical size of the TensorArray.
Supported Platforms:
``GPU`` ``CPU``
Examples:
>>> import mindspore
>>> import mindspore.ops as ops
@ -304,34 +319,37 @@ class TensorArraySize(PrimitiveWithInfer):
"""
@prim_attr_register
def __init__(self):
"""初始化TensorArraySize类."""
self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, handle_shape):
"""推断输出形状."""
return ()
def infer_dtype(self, handle_type):
"""推断输出数据类型."""
validator.check_type_name("handle", handle_type, (ms.int64), self.name)
return mstype.int64
class TensorArrayGather(PrimitiveWithInfer):
r"""
TensorArrayGather is used to gather specified elements from the created TensorArray.
.. warning::
This is an experimental prototype that is subject to change and/or deletion.
Args:
dtype (mindspore.dtype): the data type in the TensorArray.
element_shape (tuple[int]): the shape of each tensor in a TensorArray.
Inputs:
- **handle** (mindspore.int64) - The handle pointed to the TensorArray.
- **indices** (mindspore.int32) - The locations of the gathered elements.
Outputs:
- **output** (Tensor) - The gathered value from the TensorArray.
Examples:
>>> import mindspore
>>> import mindspore.ops as ops
@ -344,17 +362,20 @@ class TensorArrayGather(PrimitiveWithInfer):
"""
@prim_attr_register
def __init__(self, dtype, element_shape):
"""初始化TensorArrayGather类设置参数和属性."""
self.init_prim_io_names(inputs=['handle', 'indices'], outputs=['value'])
self.add_prim_attr("side_effect_mem", True)
self.dtype = dtype
self.element_shape = element_shape
def infer_shape(self, handle, indices):
"""推断输出形状."""
if len(indices) != 1:
return ValueError("indices dimension should be equal to 1")
return [indices[0]] + list(self.element_shape)
def infer_dtype(self, handle, indices):
"""推断输出数据类型."""
validator.check_type_name("handle", handle, (ms.int64), self.name)
validator.check_type_name("indices", indices, (ms.int32), self.name)
return self.dtype

@ -366,31 +366,46 @@ class ModelCheckpoint(Callback):
"""
def __init__(self, prefix='CKP', directory=None, config=None):
# Initialize prefix, directory, config and related state.
super(ModelCheckpoint, self).__init__()
# The latest checkpoint file name starts empty.
self._latest_ckpt_file_name = ""
# Record the initialization time.
self._init_time = time.time()
# Record the last trigger time.
self._last_time = time.time()
# Record the last keep time.
self._last_time_for_keep = time.time()
# The last triggered step starts at 0.
self._last_triggered_step = 0
# Check that prefix is a string not containing '/'.
if not isinstance(prefix, str) or prefix.find('/') >= 0:
raise ValueError("For 'ModelCheckpoint', the argument 'prefix' "
"for checkpoint file name is invalid, it must be "
"string and does not contain '/', but got {}.".format(prefix))
self._prefix = prefix
# Also set the exception prefix.
self._exception_prefix = prefix
# If a directory is given, create it; otherwise use the current directory.
if directory is not None:
self._directory = _make_directory(directory)
else:
self._directory = _cur_dir
# If recovery is enabled, record the checkpoint path in the recovery context.
if _get_recovery_context("enable_recovery"):
_set_recovery_context(ckpt_path=self._directory)
# If config is None, use the default CheckpointConfig.
if config is None:
self._config = CheckpointConfig()
else:
# If config is not a CheckpointConfig, raise TypeError.
if not isinstance(config, CheckpointConfig):
raise TypeError("For 'ModelCheckpoint', the type of argument 'config' should be "
"'CheckpointConfig', "
@ -398,11 +413,17 @@ class ModelCheckpoint(Callback):
self._config = config
# get existing checkpoint files
# Create a CheckpointManager object.
self._manager = CheckpointManager()
# If files with the same name exist, change the file name.
self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
# Get append_dict from the config, defaulting to an empty dict.
self._append_dict = self._config.append_dict or {}
# Get epoch_num from append_dict, defaulting to 0.
self._append_epoch_num = self._append_dict["epoch_num"] if "epoch_num" in self._append_dict else 0
# Get step_num from append_dict, defaulting to 0.
self._append_step_num = self._append_dict["step_num"] if "step_num" in self._append_dict else 0
# Whether the graph has been saved.
self._graph_saved = False
self._need_flush_from_cache = True
@ -413,6 +434,7 @@ class ModelCheckpoint(Callback):
Args:
run_context (RunContext): Context of the train running.
"""
# If the role is PServer, add the role name and rank to the prefix
if _is_role_pserver():
self._prefix = "PServer_" + str(_get_ps_mode_rank()) + "_" + self._prefix
cb_params = run_context.original_args()
@ -423,18 +445,23 @@ class ModelCheckpoint(Callback):
self._last_triggered_step = cb_params.last_save_ckpt_step
cb_params.last_save_ckpt_step = None
# Create the directory if it doesn't exist
_make_directory(self._directory)
# save graph (only once)
if not self._graph_saved:
graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
# If the graph file already exists and the mode is GRAPH_MODE, remove it
if os.path.isfile(graph_file_name) and context.get_context("mode") == context.GRAPH_MODE:
os.remove(graph_file_name)
# Save the graph
_save_graph(cb_params.train_network, graph_file_name)
self._graph_saved = True
# Wait for any asynchronous checkpoint saving threads to finish
thread_list = threading.enumerate()
for thread in thread_list:
if thread.getName() == "asyn_save_ckpt":
thread.join()
# Save the checkpoint
self._save_ckpt(cb_params)
def end(self, run_context):
@ -444,44 +471,63 @@ class ModelCheckpoint(Callback):
Args:
run_context (RunContext): Context of the train running.
"""
# Get the training parameters.
cb_params = run_context.original_args()
# Mark that the last checkpoint should be saved.
_to_save_last_ckpt = True
# Save the last checkpoint.
self._save_ckpt(cb_params, _to_save_last_ckpt)
# Wait for any asynchronous checkpoint-saving thread to finish.
thread_list = threading.enumerate()
for thread in thread_list:
if thread.getName() == "asyn_save_ckpt":
thread.join()
# Destroy all allgather cells.
destroy_allgather_cell()
def _check_save_ckpt(self, cb_params, force_to_save):
"""Check whether save checkpoint files or not."""
# If save_checkpoint_steps is configured and positive:
if self._config.save_checkpoint_steps and self._config.save_checkpoint_steps > 0:
# Save when enough steps have passed since the last trigger, or when forced.
if cb_params.cur_step_num >= self._last_triggered_step + self._config.save_checkpoint_steps \
or force_to_save is True:
return True
# If save_checkpoint_seconds is configured and positive:
elif self._config.save_checkpoint_seconds and self._config.save_checkpoint_seconds > 0:
self._cur_time = time.time()
# Save when enough time has passed since the last trigger, or when forced.
if (self._cur_time - self._last_time) > self._config.save_checkpoint_seconds or force_to_save:
self._last_time = self._cur_time
return True
return False
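# Illustrative sketch (not part of the source): the two trigger modes checked
# above are selected through CheckpointConfig; the argument values are made up,
# and steps/seconds are treated as mutually exclusive here.
#   config_by_step = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=5)
#   config_by_time = CheckpointConfig(save_checkpoint_steps=None, save_checkpoint_seconds=60)
#   ckpt_cb = ModelCheckpoint(prefix='CKP', directory='./ckpt', config=config_by_step)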
def _save_ckpt(self, cb_params, force_to_save=False):
"""Save checkpoint files."""
# Skip if this step already triggered a save.
if cb_params.cur_step_num == self._last_triggered_step:
return
# if param is cache enable, flush data from cache to host before save_ckpt
if self._need_flush_from_cache:
self._flush_from_cache(cb_params)
# Check whether a checkpoint should be saved; force_to_save forces it.
save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
# Compute the current step's position within the epoch.
step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
# If saving, build the current checkpoint file name.
if save_ckpt:
cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
+ str(step_num_in_epoch) + ".ckpt"
@ -489,43 +535,68 @@ class ModelCheckpoint(Callback):
self._manager.update_ckpoint_filelist(self._directory, self._prefix)
# keep checkpoint files number equal max number.
if self._config.keep_checkpoint_max and 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num:
# If keep_checkpoint_max is set and reached, remove the oldest checkpoint file.
self._manager.remove_oldest_ckpoint_file()
elif self._config.keep_checkpoint_per_n_minutes and self._config.keep_checkpoint_per_n_minutes > 0:
# If keep_checkpoint_per_n_minutes is set, record the current time.
self._cur_time_for_keep = time.time()
# Within the window, keep only one checkpoint per n minutes.
if (self._cur_time_for_keep - self._last_time_for_keep) \
< self._config.keep_checkpoint_per_n_minutes * 60:
self._manager.keep_one_ckpoint_per_minutes(self._config.keep_checkpoint_per_n_minutes,
self._cur_time_for_keep)
# generate the new checkpoint file and rename it.
# Point the global _save_dir at this callback's directory.
global _save_dir
_save_dir = self._directory
# Build the full path of the current checkpoint file.
cur_file = os.path.join(self._directory, cur_ckpoint_file)
# Record the keep time and the triggered step.
self._last_time_for_keep = time.time()
self._last_triggered_step = cb_params.cur_step_num
# If GE (Graph Execution) is enabled, run the checkpoint graph.
if context.get_context("enable_ge"):
set_cur_net(cb_params.train_network)
cb_params.train_network.exec_checkpoint_graph()
# Update "epoch_num" in append_dict with the current epoch.
if "epoch_num" in self._append_dict:
self._append_dict["epoch_num"] = self._append_epoch_num + cb_params.cur_epoch_num
# Update "step_num" in append_dict with the current step.
if "step_num" in self._append_dict:
self._append_dict["step_num"] = self._append_step_num + cb_params.cur_step_num
# Use the configured saved_network if given, otherwise the training network.
network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network
# Save the checkpoint.
save_checkpoint(network, cur_file, self._config.integrated_save, self._config.async_save,
self._append_dict, self._config.enc_key, self._config.enc_mode)
# Record the latest checkpoint file name.
self._latest_ckpt_file_name = cur_file
def _flush_from_cache(self, cb_params):
"""Flush cache data to host if tensor is cache enable."""
has_cache_params = False
# Get the parameters of the training network.
params = cb_params.train_network.get_parameters()
for param in params:
# If a parameter has cache enabled, flush its Tensor data from cache to host.
if param.cache_enable:
has_cache_params = True
Tensor(param).flush_from_cache()
# If no parameter has cache enabled, skip flushing next time.
if not has_cache_params:
self._need_flush_from_cache = False
@property
@ -535,63 +606,88 @@ class ModelCheckpoint(Callback):
class CheckpointManager:
"""Manage checkpoint files according to train_config of checkpoint."""
"""管理检查点文件,根据训练配置进行管理。"""
def __init__(self):
"""初始化检查点管理器,创建空的检查点文件列表。"""
self._ckpoint_filelist = []
@property
def ckpoint_filelist(self):
"""Get all the related checkpoint files managed here."""
"""获取当前管理的所有检查点文件列表。"""
return self._ckpoint_filelist
@property
def ckpoint_num(self):
"""Get the number of the related checkpoint files managed here."""
"""获取当前管理的检查点文件数量。"""
return len(self._ckpoint_filelist)
def update_ckpoint_filelist(self, directory, prefix):
"""Update the checkpoint file list."""
"""更新检查点文件列表,根据目录和前缀筛选符合条件的检查点文件。"""
self._ckpoint_filelist = []
# List all files in the directory.
files = os.listdir(directory)
for filename in files:
# Keep files that start with the prefix and end with .ckpt.
if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix + "-"):
# Take the middle part of the file name.
mid_name = filename[len(prefix):-5]
# Accept the file only if the middle part contains no letters.
flag = not (True in [char.isalpha() for char in mid_name])
if flag:
self._ckpoint_filelist.append(os.path.join(directory, filename))
def remove_ckpoint_file(self, file_name):
"""Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
"""从检查点管理器中移除指定的检查点文件,并从目录中删除该文件。"""
try:
# Make the file writable, delete it, and drop it from the file list.
os.chmod(file_name, stat.S_IWRITE)
os.remove(file_name)
self._ckpoint_filelist.remove(file_name)
except OSError:
# Log a warning when the OS refuses the removal.
logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
except ValueError:
# Log a warning when the file is not in the managed list.
logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
def remove_oldest_ckpoint_file(self):
"""Remove the oldest checkpoint file from this checkpoint manager and also from the directory."""
"""移除检查点管理器中最早的检查点文件,并从目录中删除该文件。"""
# Sort all checkpoint files by modification time and remove the oldest one.
ckpoint_files = sorted(self._ckpoint_filelist, key=os.path.getmtime)
self.remove_ckpoint_file(ckpoint_files[0])
def keep_one_ckpoint_per_minutes(self, minutes, cur_time):
"""Only keep the latest one ckpt file per minutes, remove other files generated in [last_time, cur_time]."""
"""保留每分钟生成的最新检查点文件,移除在指定时间范围内生成的其他文件。"""
# Files to delete within the time window.
del_list = []
# Track the oldest file and its modification time.
oldest_file = ''
oldest_time = cur_time
for ck_file in self._ckpoint_filelist:
modify_time = os.path.getmtime(ck_file)
# Collect files modified within the last `minutes` minutes.
if cur_time - modify_time < 60 * minutes:
del_list.append(ck_file)
# Update the oldest file seen so far.
if modify_time < oldest_time:
oldest_time = modify_time
oldest_file = ck_file
for mv_file in del_list:
# Keep the oldest file in the window; remove the others.
if mv_file == oldest_file:
continue
self.remove_ckpoint_file(mv_file)

@ -256,36 +256,53 @@ class DatasetHelper:
"""
def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1):
# Check that dataset_sink_mode is a bool.
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
# Check that sink_size is an int.
Validator.check_is_int(sink_size)
# sink_size must be -1 or positive.
if sink_size < -1 or sink_size == 0:
raise ValueError("The 'sink_size' must be -1 or positive, but got sink_size {}.".format(sink_size))
# sink_size == -1 means the full dataset size.
if sink_size == -1:
sink_size = dataset.get_dataset_size()
# In sink mode, pick the iterator class by device and role.
if dataset_sink_mode:
# With GE enabled, use the GE iterator.
if context.get_context("enable_ge"):
iterclass = _DatasetIterGE
else:
# In GRAPH_MODE, pick the iterator by role.
if context.get_context("mode") == context.GRAPH_MODE:
# Scheduler or parameter server roles use the PS server iterator.
if _is_role_sched() or _is_role_pserver():
iterclass = _DatasetIterPSServer
# Workers in PS mode use the PS worker iterator.
elif _is_role_worker() and _is_ps_mode():
iterclass = _DatasetIterPSWork
# Ascend and GPU targets use the loop-sink iterator.
elif (context.get_context("device_target") == "Ascend") or \
(context.get_context("device_target") == "GPU"):
iterclass = _DatasetIterMSLoopSink
# CPU does not support dataset sink mode.
elif context.get_context("device_target") == "CPU":
raise RuntimeError("Currently dataset sink mode is not supported when the device "
"target is CPU, please set dataset sink mode to False.")
# Outside GRAPH_MODE, use the PyNative iterator.
else:
iterclass = _DatasetIterPyNative
# Create the iterator.
self.iter = iterclass(dataset, sink_size, epoch_num)
# Without sink mode, use the normal iterator.
else:
iterclass = _DatasetIterNormal
self.iter = iterclass(dataset, epoch_num=epoch_num)
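# Illustrative usage sketch (not part of the source), assuming the constructor
# behavior above; `ds` and `network` are hypothetical:
#   helper = DatasetHelper(ds, dataset_sink_mode=True, sink_size=-1, epoch_num=1)
#   for inputs in helper:
#       outputs = network(*inputs)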
def __iter__(self):
# Return the iterator of self.iter.
return self.iter.__iter__()
# A temp solution for loop sink. Delete later
@ -301,6 +318,7 @@ class DatasetHelper:
>>>
>>> types, shapes = dataset_helper.types_shapes()
"""
# Get the types and shapes from the currently configured dataset.
return self.iter.types_shapes()
def sink_size(self):
@ -316,18 +334,22 @@ class DatasetHelper:
>>> # if sink_size==-1, then will return the full size of source dataset.
>>> sink_size = dataset_helper.sink_size()
"""
# Return the iterator's sink size.
return self.iter.get_sink_size()
def stop_send(self):
"""Stop send data about data sink."""
# Stop sending data to the data sink.
self.iter.stop_send()
def release(self):
"""Free up resources about data sink."""
# Free the data sink's resources.
self.iter.release()
def continue_send(self):
"""Continue to send data to device at the beginning of epoch."""
# Continue sending data to the device at the start of each epoch.
self.iter.continue_send()
def _reset(self, step):
@ -339,6 +361,7 @@ class DatasetHelper:
In sink mode, it returns the types and shapes of the current data.
Generally, it works in dynamic shape scenarios.
"""
# Return the iterator's data info.
return self.iter.get_data_info()
def dynamic_min_max_shapes(self):
@ -355,6 +378,7 @@ class DatasetHelper:
>>>
>>> min_shapes, max_shapes = dataset_helper.dynamic_min_max_shapes()
"""
# Delegate to self.iter's dynamic_min_max_shapes.
return self.iter.dynamic_min_max_shapes()
@ -362,20 +386,27 @@ class _DatasetIter:
"""Base iter for dataset helper"""
def __init__(self, dataset, sink_size, epoch_num):
# Initialize with the dataset, sink size and epoch count.
self.dataset = dataset
self.sink_size = sink_size
self.sink_count = self.get_sink_count(dataset)
# If the dataset has no __transfer_dataset__ attribute yet:
if not hasattr(dataset, '__transfer_dataset__'):
# If the dataset has a __loop_size__ attribute:
if hasattr(dataset, '__loop_size__'):
# PS mode does not support loop sink and needs the real sink size.
# Outside PS-mode workers, use the dataset's loop size as the sink size.
if not (_is_role_worker() and _is_ps_mode()):
self.sink_size = dataset.__loop_size__
# Create a data info queue only for single-sink Ascend runs over non-trivial datasets.
create_data_info_queue = (sink_size == 1 and self.sink_count == 1 and dataset.get_dataset_size() != 1
and context.get_context("device_target") == "Ascend")
# Execute the data graph with the sink size and queue flag.
dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size,
create_data_info_queue=create_data_info_queue)
# Send data unless the dataset opts out via __no_send__.
if not hasattr(dataset, '__no_send__'):
_send_data(dataset, epoch_num)
else:
@ -384,33 +415,48 @@ class _DatasetIter:
_cell_graph_executor.set_queue_name(dataset.__transfer_dataset__.queue_name)
_send_data_no_flag(dataset, epoch_num)
        # Bind the transfer dataset's stop_send method.
self.stop_send = dataset.__transfer_dataset__.stop_send
        # Bind the transfer dataset's release method.
self.release = dataset.__transfer_dataset__.release
        # Bind the transfer dataset's continue_send method.
self.continue_send = dataset.__transfer_dataset__.continue_send
        # Bind the transfer dataset's get_data_info method.
self.get_data_info = dataset.__transfer_dataset__.get_data_info
        # Bind the dataset's dynamic_min_max_shapes attribute.
self.dynamic_min_max_shapes = dataset.dynamic_min_max_shapes
        # Record the dataset's element types and shapes.
self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)
        # Expose the transfer dataset's _reset method when available.
if hasattr(dataset.__transfer_dataset__, "_reset"):
self._reset = dataset.__transfer_dataset__._reset # pylint: disable=W0212
def __iter__(self):
        # Restart the sink-step index.
        self.index = 0
        return self
return self
    # Produce the next sink step.
def __next__(self):
        # Stop after sink_count steps.
if self.index >= self.sink_count:
raise StopIteration()
        self.index += 1
        # Run the per-step operation.
        return self.op()
return self.op()
def types_shapes(self):
"""
Return the types and shapes of the dataset. The type and shape of each data in the dataset
should be consistent.
"""
return self.dataset_types, self.dataset_shapes
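        # For example (dtypes and shapes are illustrative, for an image pipeline):
        # >>> types, shapes = helper.types_shapes()
        # >>> types   # e.g. [mindspore.float32, mindspore.int32]
        # >>> shapes  # e.g. [(32, 3, 224, 224), (32,)]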
def get_sink_count(self, dataset):
"""
获取数据集的sink次数
:param dataset: 数据集对象
:return: sink次数
"""
sink_count = 1
if hasattr(dataset, '__loop_size__'):
loop_size = dataset.__loop_size__
@ -421,7 +467,10 @@ class _DatasetIter:
return sink_count
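        # Illustrative arithmetic (assuming the elided branch divides the dataset
        # size by the loop size): a 1000-batch dataset with __loop_size__ == 100
        # gives a sink count of 10.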
def get_sink_size(self):
"""get sink_size to device"""
"""
获取设备的sink大小
:return: sink大小
"""
sink_size = 1
if hasattr(self.dataset, '__loop_size__'):
sink_size = self.dataset.__loop_size__

@ -23,47 +23,59 @@ from setuptools import setup, find_packages
from setuptools.command.egg_info import egg_info
from setuptools.command.build_py import build_py
# Read build configuration from environment variables.
backend_policy = os.getenv('BACKEND_POLICY')
device_target = os.getenv('BACKEND_TARGET')
commit_id = os.getenv('COMMIT_ID').replace("\n", "")
package_name = os.getenv('MS_PACKAGE_NAME').replace("\n", "")
build_path = os.getenv('BUILD_PATH')
# Directory containing this setup script.
pwd = os.path.dirname(os.path.realpath(__file__))
# Package staging directory under the build tree.
pkg_dir = os.path.join(build_path, 'package')
def _read_file(filename):
"""读取文件内容"""
with open(os.path.join(pwd, filename), encoding='UTF-8') as f:
return f.read()
# Read the version string.
version = _read_file('version.txt').replace("\n", "")
# Read README.md for the long description.
readme = _read_file('README.md')
def _write_version(file):
"""写入版本号"""
file.write("__version__ = '{}'\n".format(version))
def _write_config(file):
"""写入后端策略"""
file.write("__backend__ = '{}'\n".format(backend_policy))
def _write_commit_file(file):
"""写入commit_id"""
file.write("__commit_id__ = '{}'\n".format(commit_id))
def _write_package_name(file):
"""写入包名"""
file.write("__package_name__ = '{}'\n".format(package_name))
def _write_device_target(file):
"""写入设备目标"""
file.write("__device_target__ = '{}'\n".format(device_target))
def build_dependencies():
"""generate python file"""
    # Generate version.py.
version_file = os.path.join(pkg_dir, 'mindspore', 'version.py')
with open(version_file, 'w') as f:
_write_version(f)
@ -72,6 +84,7 @@ def build_dependencies():
with open(version_file, 'w') as f:
_write_version(f)
    # Generate default_config.py.
config_file = os.path.join(pkg_dir, 'mindspore', 'default_config.py')
with open(config_file, 'w') as f:
_write_config(f)
@ -80,6 +93,7 @@ def build_dependencies():
with open(config_file, 'w') as f:
_write_config(f)
    # Append the device target to default_config.py.
target = os.path.join(pkg_dir, 'mindspore', 'default_config.py')
with open(target, 'a') as f:
_write_device_target(f)
@ -88,6 +102,7 @@ def build_dependencies():
with open(target, 'a') as f:
_write_device_target(f)
    # Append the package name to default_config.py.
package_info = os.path.join(pkg_dir, 'mindspore', 'default_config.py')
with open(package_info, 'a') as f:
_write_package_name(f)
@ -96,6 +111,7 @@ def build_dependencies():
with open(package_info, 'a') as f:
_write_package_name(f)
    # Generate the .commit_id file.
commit_file = os.path.join(pkg_dir, 'mindspore', '.commit_id')
with open(commit_file, 'w') as f:
_write_commit_file(f)
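
# After build_dependencies() runs, the generated files look roughly like this
# (all values are illustrative):
#
#   mindspore/version.py:        __version__ = '1.9.0'
#   mindspore/default_config.py: __backend__ = 'ms'
#                                __device_target__ = 'Ascend'
#                                __package_name__ = 'mindspore'
#   mindspore/.commit_id:        __commit_id__ = '<sha>'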
@ -145,16 +161,24 @@ def update_permissions(path):
Args:
path (str): Target directory path.
"""
    # Permissions are POSIX-style; nothing to do on Windows.
if platform.system() == "Windows":
return
    # Walk every file and directory under the target path.
for dirpath, dirnames, filenames in os.walk(path):
        # Directories: owner rwx, group r-x (0o750).
        for dirname in dirnames:
            dir_fullpath = os.path.join(dirpath, dirname)
            os.chmod(dir_fullpath, stat.S_IREAD | stat.S_IWRITE |
                     stat.S_IEXEC | stat.S_IRGRP | stat.S_IXGRP)
        # Files: owner read-only (0o400).
        for filename in filenames:
            file_fullpath = os.path.join(dirpath, filename)
            os.chmod(file_fullpath, stat.S_IREAD)
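
# The flag combinations above correspond to the familiar octal modes; a quick
# sanity check (runnable on any POSIX Python):
#
#   import stat
#   assert stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC | stat.S_IRGRP | stat.S_IXGRP == 0o750
#   assert stat.S_IREAD == 0o400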
@ -163,7 +187,9 @@ class EggInfo(egg_info):
def run(self):
super().run()
        # Tighten permissions on the egg-info directory.
        egg_info_dir = os.path.join(pkg_dir, 'mindspore.egg-info')
        update_permissions(egg_info_dir)
@ -172,41 +198,64 @@ class BuildPy(build_py):
def run(self):
super().run()
        # Tighten permissions under build/lib/mindspore.
        mindspore_dir = os.path.join(pkg_dir, 'build', 'lib', 'mindspore')
        update_permissions(mindspore_dir)
        # Tighten permissions under build/lib/mindspore/_akg.
        mindspore_dir = os.path.join(pkg_dir, 'build', 'lib', 'mindspore', '_akg')
        update_permissions(mindspore_dir)
# Package metadata and build configuration.
setup(
    name=package_name,
    version=version,
    author='The MindSpore Authors',
    author_email='contact@mindspore.cn',
    url='https://www.mindspore.cn',
    download_url='https://github.com/mindspore-ai/mindspore/tags',
    project_urls={
        'Sources': 'https://github.com/mindspore-ai/mindspore',
        'Issue Tracker': 'https://github.com/mindspore-ai/mindspore/issues',
    },
    description='MindSpore is a new open source deep learning training/inference '
                'framework that could be used for mobile, edge and cloud scenarios.',
    # The README serves as the long description on PyPI.
    long_description=readme,
    long_description_content_type="text/markdown",
    packages=find_packages(),
    package_data=package_data,
    include_package_data=True,
    # Custom commands that tighten file permissions after egg_info and build_py.
    cmdclass={
        'egg_info': EggInfo,
        'build_py': BuildPy,
    },
    # Console script for administering the dataset cache.
    entry_points={
        'console_scripts': [
            'cache_admin=mindspore.dataset.engine.cache_admin:main',
        ],
    },
    python_requires='>=3.7',
    install_requires=required_package,
    classifiers=[
'Development Status :: 4 - Beta',
'Environment :: Console',
@ -223,6 +272,8 @@ setup(
'Topic :: Software Development :: Libraries',
'Topic :: Software Development :: Libraries :: Python Modules',
],
    license='Apache 2.0',
    keywords='mindspore machine learning',
)
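
# The console_scripts entry point above means an installed package exposes a
# `cache_admin` command, roughly equivalent to (illustrative):
#
#   from mindspore.dataset.engine.cache_admin import main
#   main()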

@ -17,5 +17,7 @@ import mindspore.context as context
def setup_module(module):
    # Silence pylint's unused-argument warning for the fixture parameter.
# pylint: disable=unused-argument
    # Run the tests in graph mode.
context.set_context(mode=context.GRAPH_MODE)

@ -26,25 +26,28 @@ from .utils import keyword
def mindspore_test(verification_pipeline):
"""
Run verification pipeline.
Args:
verification_pipeline (list): Pipeline designed to do verification.
Returns:
"""
def decorate(get_verification_set):
        # Materialize the verification set.
verification_set = get_verification_set()
        # Buckets for each kind of pipeline component.
        facade_components = []
        data_components = []
        builder_components = []
        executor_components = []
        verifier_components = []
        fi_policy_components = []
        er_policy_components = []
for component in verification_pipeline:
            # Dispatch each component into the bucket matching its interface.
if issubclass(component, IFacadeComponent):
facade_components.append(component)
elif issubclass(component, IDataComponent):
@ -62,68 +65,90 @@ def mindspore_test(verification_pipeline):
else:
raise Exception(f'{component} is not an instance of {IComponent}')
        # Facade components may transform the verification set.
for component in facade_components:
fc = component(verification_set)
verification_set = fc()
        # Collect inputs from the data components.
        inputs = []
for component in data_components:
dc = component(verification_set)
item = dc()
inputs.extend(item)
        # Warn when no inputs were generated.
if not inputs:
logging.warning("Inputs set is empty.")
        # Collect the functions under test from the builder components.
        functions = []
for component in builder_components:
bc = component(verification_set)
f = bc()
functions.extend(f)
        # Warn when no functions were generated.
if not functions:
logging.warning("Function set is empty.")
        # Pair functions with inputs via the FI policy components.
        fis = []
for component in fi_policy_components:
fipc = component(verification_set, functions, inputs)
result = fipc()
fis.extend(result)
        # Warn when no function-inputs pairs were generated.
if not fis:
logging.warning("Function inputs pair set is empty.")
        # The parametrized test body.
def test_case(args):
            # Unpack the system under test and its inputs.
sut, inputs = args
            # Run each executor component against the SUT and inputs.
            results = []
for component in executor_components:
ec = component(verification_set, sut, inputs)
result = ec()
results.append(result)
            # Warn when no results were produced.
if not results:
logging.warning("Result set is empty.")
            # Pair expected values with actual results via the ER policy components.
            expect_actuals = []
for component in er_policy_components:
erpc = component(verification_set, verification_set['expect'], results)
result = erpc()
expect_actuals.extend(result)
            # Warn when no expected/actual pairs were produced.
if not expect_actuals:
logging.warning("Expect Result pair set is empty.")
            # Verify every expected/actual pair.
for ea in expect_actuals:
for component in verifier_components:
vc = component(verification_set, *ea)
vc()
        # Build a readable test-case name from function and input metadata.
def get_tc_name(f, inputs):
            # Join the ids and groups of the function and its inputs.
tc_id = f[keyword.id] + '-' + inputs[keyword.id]
group = f[keyword.group] + '-' + inputs[keyword.group]
return 'Group_' + group + '-' + 'Id_' + tc_id
        # Parametrize pytest over the function-inputs pairs.
if fis:
m = pytest.mark.parametrize('args', fis, ids=lambda fi: get_tc_name(*fi))(test_case)
m.__orig__ = get_verification_set
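
# Usage sketch (component names are hypothetical; real pipelines are composed
# from the framework's IComponent subclasses):
#
#   @mindspore_test([MyFacade, MyDataComponent, MyBuilder, MyExecutor,
#                    MyVerifier, MyFIPolicy, MyERPolicy])
#   def get_my_verification_set():
#       return {'expect': ..., ...}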
