diff --git a/src/mindspore2022/cmake/dependency_graphengine.cmake b/src/mindspore2022/cmake/dependency_graphengine.cmake index af63a609..89eece28 100644 --- a/src/mindspore2022/cmake/dependency_graphengine.cmake +++ b/src/mindspore2022/cmake/dependency_graphengine.cmake @@ -20,8 +20,11 @@ function(find_submodule_lib module name path) ) endfunction() +# 定义一个函数,用于生成protobuf文件 function(ge_protobuf_generate c_var h_var) + # 调用common_protobuf_generate函数,生成protobuf文件 common_protobuf_generate(${CMAKE_BINARY_DIR}/proto/ge/proto ${c_var} ${h_var} ${ARGN}) + # 将生成的c文件和h文件赋值给c_var和h_var set(${c_var} ${${c_var}} PARENT_SCOPE) set(${h_var} ${${h_var}} PARENT_SCOPE) endfunction() diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/builtin_operations.py b/src/mindspore2022/mindspore/python/mindspore/_extends/builtin_operations.py index 1b0e4e9a..d07dbf71 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/builtin_operations.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/builtin_operations.py @@ -22,31 +22,100 @@ from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype def ScalarAdd(x, y): + """ + 实现标量加法运算。 + + Args: + x (float): 第一个加数。 + y (float): 第二个加数。 + + Returns: + float: x 和 y 的和。 + + """ """Implement `scalar_add`.""" return x + y def ScalarMul(x, y): + """ + 标量乘法函数。 + + Args: + x (float): 第一个标量。 + y (float): 第二个标量。 + + Returns: + float: 两个标量的乘积。 + + """ """Implement `scalar_mul`.""" return x * y def ScalarMod(x, y): + """ + 对两个数进行模运算。 + + Args: + x (int or float): 被模数。 + y (int or float): 模数。 + + Returns: + int or float: x 对 y 取模的结果。 + + """ """Implement `scalar_mul`.""" return x % y def ScalarSub(x, y): + """ + 实现标量减法运算。 + + Args: + x (float): 第一个标量值。 + y (float): 第二个标量值。 + + Returns: + float: x 和 y 的差值。 + + """ """Implement `scalar_sub`.""" return x - y def ScalarUsub(x): + """ + 对给定的标量 x 进行取反操作。 + + Args: + x (float or int): 需要取反的标量。 + + Returns: + float or int: 取反后的标量。 + + """ """Implement `scalar_usub`.""" return -x def TupleGetItem(x, index): + """ + 从给定对象中获取指定索引处的元素。 + + Args: + x (Union[Tensor, dict]): 输入对象,可以是Tensor类型或字典类型。 + index (int): 要获取的元素的索引。 + + Returns: + Union[Tensor, Any]: 如果输入是Tensor类型,则返回Tensor类型的元素; + 如果输入是字典类型,则返回字典中对应索引的值; + 否则,返回输入对象中对应索引的元素。 + + Raises: + IndexError: 如果索引超出范围。 + """ """Implement `tuple_getitem`.""" if isinstance(x, Tensor): x = x.asnumpy() @@ -64,36 +133,111 @@ def TupleGetItem(x, index): def scalar_gt(x, y): + """ + 判断两个标量值x和y的大小关系。 + + Args: + x (float or int): 第一个标量值。 + y (float or int): 第二个标量值。 + + Returns: + bool: 如果x大于y,则返回True;否则返回False。 + + """ """Implement `scalar_gt`.""" return x > y def scalar_ne(x, y): + """ + 比较两个标量值是否不相等。 + + Args: + x (float): 第一个标量值。 + y (float): 第二个标量值。 + + Returns: + bool: 如果 x 不等于 y,则返回 True;否则返回 False。 + + """ """Implement `scalar_ne`.""" return x != y def scalar_eq(x, y): + """ + 判断两个标量值是否相等。 + + Args: + x (Any): 第一个标量值。 + y (Any): 第二个标量值。 + + Returns: + bool: 如果 x 和 y 相等,返回 True;否则返回 False。 + + """ """Implement `scalar_eq`.""" return x == y def scalar_le(x, y): + """ + 判断标量 x 是否小于等于标量 y。 + + Args: + x (float): 第一个标量值。 + y (float): 第二个标量值。 + + Returns: + bool: 如果 x 小于等于 y,则返回 True;否则返回 False。 + + """ """Implement `scalar_le`.""" return x <= y def scalar_lt(x, y): + """ + 判断两个标量值的大小关系。 + + Args: + x (float): 第一个标量值。 + y (float): 第二个标量值。 + + Returns: + bool: 如果 x 小于 y,则返回 True;否则返回 False。 + + """ """Implement `scalar_lt`.""" return x < y def identity(x): + """ + 返回输入参数本身。 + + Args: + x: 任何类型的输入参数。 + + Returns: + 返回输入参数本身。 + + """ """Implement `identity`.""" 
return x def zeros_like_tensor(x): + """ + 根据给定的张量x创建一个形状相同但所有元素为零的新张量。 + + Args: + x (Tensor): 输入的张量,用于确定新张量的形状。 + + Returns: + Tensor: 一个与输入张量x形状相同但所有元素为零的新张量。 + + """ """Implement `zeros_like_tensor`.""" x = x.asnumpy() value = Tensor(np.zeros(x.shape).astype(np.float32)) @@ -101,61 +245,201 @@ def zeros_like_tensor(x): def Switch(c, x, y): + """ + 实现 `switch` 功能。 + + Args: + c (bool): 条件值,如果为 True,则返回 x,否则返回 y。 + x (Any): 条件为 True 时返回的值。 + y (Any): 条件为 False 时返回的值。 + + Returns: + Any: 根据条件 c 返回 x 或 y。 + + """ """Implement `switch`.""" return x if c else y def list_getitem(data, item): + """ + 从列表中获取指定索引处的元素。 + + Args: + data (list): 待查询的列表。 + item (int): 要获取的元素的索引。 + + Returns: + 返回列表中索引为item的元素。 + + Raises: + IndexError: 如果索引超出列表范围。 + """ """Implement `list_getitem`.""" return data[item] def bool_not(x): + """ + 对输入值取反。 + + Args: + x (bool): 要取反的布尔值。 + + Returns: + bool: x 的逻辑非值。 + + """ """Implement `bool_not`.""" return not x def bool_and(x, y): + """ + 对两个布尔值进行逻辑与操作。 + + Args: + x (bool): 第一个布尔值。 + y (bool): 第二个布尔值。 + + Returns: + bool: 返回两个布尔值进行逻辑与操作后的结果。 + + """ """Implement `bool_and`.""" return x and y def bool_or(x, y): + """ + 实现布尔或运算。 + + Args: + x (bool): 第一个布尔值。 + y (bool): 第二个布尔值。 + + Returns: + bool: 如果 x 或 y 为 True,则返回 True,否则返回 False。 + + """ """Implement `bool_or`.""" return x or y def make_list(*xs): + """ + 将不定数量的参数转换为一个列表。 + + Args: + *xs: 不定数量的参数,可以是任意类型。 + + Returns: + list: 包含所有传入参数的列表。 + + Examples: + >>> make_list(1, 2, 3) + [1, 2, 3] + >>> make_list('a', 'b', 'c') + ['a', 'b', 'c'] + >>> make_list(1, 'a', [1, 2, 3]) + [1, 'a', [1, 2, 3]] + """ """Implement `make_list`.""" return list(xs) def list_len(x): + """ + 计算列表的长度。 + + Args: + x (list): 需要计算长度的列表。 + + Returns: + int: 列表的长度。 + + """ """Implement `list_len`.""" return len(x) def Depend(value, expr): + """ + 依赖函数,根据给定的表达式返回相应的值。 + + Args: + value (Any): 要返回的值。 + expr (Any): 表达式,该参数在当前实现中被忽略。 + + Returns: + Any: 返回与输入相同的值。 + + """ """Implement `Depend`.""" return value def UpdateState(monad, *exprs): + """ + 更新状态。 + + Args: + monad (Monad): 一个符合 Monad 类型的对象。 + *exprs (Any): 需要更新的表达式,可以为任意类型。 + + Returns: + Monad: 更新后的 Monad 对象。 + + """ """Implement `UpdateState`.""" return monad def Load(value, u=None): + """ + 加载指定的值。 + + Args: + value (Any): 要加载的值。 + u (Optional[Any], optional): 可选参数,默认为None。当前版本未使用,保留以便未来扩展。 + + Returns: + Any: 返回加载的值。 + + """ """Implement `Load`.""" return value # only used in PyNative mode def make_ref(key, value, ref): + """ + 创建一个引用对象。 + + Args: + key (str): 键名,用于标识引用的对象。 + value (Any): 引用对象的值。 + ref (Any): 引用对象,可以为任意类型。 + + Returns: + Any: 返回引用的值。 + + """ return value def scalar_cast(x, t): + """ + 将标量值x转换为指定的NumPy数据类型t。 + + Args: + x (float, int): 要转换的标量值。 + t (np.dtype): 目标NumPy数据类型。 + + Returns: + Any: 转换后的标量值,类型为t。 + + """ """Implement scalar_cast.""" np_type = dtype_to_nptype(t) value = np_type(x) @@ -164,16 +448,46 @@ def scalar_cast(x, t): def typeof(x): + """ + 实现 typeof 函数。 + + Args: + x (Any): 要获取类型的对象。 + + Returns: + str: 返回传入对象的Python类型名称。 + + """ """Implement typeof.""" return get_py_obj_dtype(x) def tuple_to_array(x): + """ + 将元组转换为数组。 + + Args: + x (tuple): 待转换的元组。 + + Returns: + Tensor: 转换后的数组。 + + """ """Implement `tuple_to_array`.""" return Tensor(np.array(x)) def stop_gradient(x): + """ + 停止梯度传播。 + + Args: + x (Tensor): 需要停止梯度传播的张量。 + + Returns: + Tensor: 停止梯度传播的张量。 + + """ """Implement `stop_gradient`.""" return x @@ -182,6 +496,20 @@ hyper_map = C.HyperMap() def mixed_precision_cast(dst_type, x): + """ + 实现混合精度转换函数。 + + Args: + dst_type (mstype.Type): 目标数据类型。 + 
x (Union[Tensor, list, tuple]): 需要进行类型转换的数据,可以是单个Tensor,也可以是一个包含Tensor的列表或元组。 + + Returns: + Union[Tensor, list, tuple]: 转换后的数据,类型与输入一致。 + + Raises: + TypeError: 如果输入数据类型不支持,将引发TypeError异常。 + + """ """Implement `mixed_precision_cast`.""" def cast_inner(data): diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/__init__.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/__init__.py index 495b8fb0..d59153e3 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/__init__.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/__init__.py @@ -13,6 +13,9 @@ # limitations under the License. # ============================================================================ """init""" +# 从splitter模块中导入split_with_json函数 from .splitter import split_with_json +# 从expander模块中导入get_op_expander函数 from .expander import get_op_expander +# 从parallel_estimate模块中导入estimate_calculation_amount和estimate_ops函数 from .parallel_estimate import estimate_calculation_amount, estimate_ops diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expander.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expander.py index fb104da9..0264233f 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expander.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expander.py @@ -22,8 +22,32 @@ from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedEx def create_expander(expand_info): + """ + 根据操作符名称创建一个扩展器。 + + Args: + expand_info (dict): 包含操作符名称及其他相关信息的字典。 + + Returns: + Any: 调用指定操作符名称的扩展器后返回的结果。 + + Raises: + GraphKernelUnsupportedException: 如果指定的操作符名称在扩展器模块中不存在,则抛出此异常。 + + """ """Create an expander according to op name""" def call_func(func, arg): + """ + 调用给定的函数并返回其结果。 + + Args: + func (callable): 要调用的函数。 + arg: 要传递给函数的参数。 + + Returns: + 调用给定函数后的返回值。 + + """ return func(arg) op_name = str(expand_info['name']) if not hasattr(expanders, op_name): @@ -33,6 +57,21 @@ def create_expander(expand_info): def extract_expand_info(kernel_info): + """ + 将json格式的kernel信息转换为更友好的格式。 + + Args: + kernel_info (dict): 包含kernel信息的字典。 + + Returns: + dict: 转换后的kernel信息字典,包含以下键: + - name (str): kernel的名称。 + - input_desc (list): 输入描述列表。 + - output_desc (list): 输出描述列表。 + - attr (dict): 属性字典,键为属性名,值为属性值。 + - process (str): 处理过程的描述。 + + """ """Convert the json into a more friendly format""" input_desc = [] if 'input_desc' in kernel_info and kernel_info['input_desc']: @@ -53,6 +92,20 @@ def extract_expand_info(kernel_info): def get_op_expander(json_str: str): + """ + 通过json信息获取操作扩展器。 + + Args: + json_str (str): 包含操作扩展器信息的json字符串。 + + Returns: + str: 返回扩展后的操作图的json描述。 + + Raises: + jd.JSONDecodeError: 如果输入的json字符串解码失败。 + GraphKernelUnsupportedException: 如果操作图不支持的操作类型。 + + """ """get op expander by json info""" try: kernel_info = json.loads(json_str) diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/__init__.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/__init__.py index 5fe87dd5..b47ae225 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/__init__.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================ """expanders init. 
Deprecated, please add the new operators in the c++ file""" +"""扩展器初始化。已弃用,请在新运算符中添加C++文件""" from .addn import AddN diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/_utils.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/_utils.py index 1bdb205d..9fb8a071 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/_utils.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/_utils.py @@ -27,6 +27,19 @@ class Expander: __metaclass__ = ABCMeta def __init__(self, expand_info): + """ + 初始化方法。 + + Args: + expand_info (dict): 包含模型信息的字典,包括模型名称、输入描述、输出描述、属性、处理函数等。 + + Attributes: + name (str): 模型名称。 + inputs (list): 输入描述列表。 + outputs (list): 输出描述列表。 + attrs (dict): 模型属性字典。 + processor (callable): 处理函数。 + """ self.name = expand_info["name"] self.inputs = expand_info["input_desc"] self.outputs = expand_info["output_desc"] @@ -34,6 +47,19 @@ class Expander: self.processor = expand_info["process"] def run(self): + """ + 将操作扩展为图。 + + Args: + 无 + + Returns: + 返回扩展后的图对象。 + + Raises: + GraphKernelUnsupportedException: 如果检查失败,则引发此异常。 + + """ """ Expand the operator to a graph. @@ -58,9 +84,31 @@ class Expander: return graph def _check(self): + """ + 检查输入。 + + Args: + 无 + + Returns: + 无 + + Raises: + ValueError: 如果输入不符合要求,则引发此异常。 + + """ """Check inputs""" def _check_output_same(self, outputs): + """ + 检查输出是否与预期一致。 + + Args: + outputs (list): 实际输出值的列表。 + + Raises: + GKException: 如果实际输出值与预期不一致,则抛出异常。 + """ for index, value in enumerate(self.outputs): if list(outputs[index].shape) != list(value['shape']): raise GKException("{} 's output shape {} is wrong. Expected:{}".format( @@ -74,6 +122,18 @@ class Expander: @abstractmethod def _expand(self, graph_builder): + """ + Expand 操作符,此函数应在子类中重写。 + + Args: + graph_builder (GraphBuilder): 图构建器对象。 + + Raises: + Exception: 如果子类未重写此方法,则抛出异常,提示 "_expand() is not implemented in {}". + + Returns: + None + """ """Expand operator, this function should be overridden in subclass""" raise Exception("_expand() is not implemented in {}".format(self.__class__.__name__)) @@ -82,10 +142,34 @@ class ExpanderInfoValidator: """ExpanderInfoValidator is the utility class which defines the validator decorator for expanders""" def __init__(self): + """ + 初始化方法。 + + Args: + 无 + + Returns: + 无 + + """ """Init""" @staticmethod def _add_check_function(kls, func): + """ + 向类 Expander 中的 `_check` 函数添加新的检查函数 `func`。 + + Args: + kls (type): 需要被修改的类对象,应继承自 Expander 类。 + func (callable): 要添加的新检查函数,该函数接受一个对象作为参数。 + + Returns: + None + + Raises: + AttributeError: 如果 kls 类中不存在 `_check` 方法。 + + """ """ Rewrite the function `_check` in class Expander to append the new `func` after the original checks. @@ -93,6 +177,21 @@ class ExpanderInfoValidator: old_check = getattr(kls, "_check") def new_check(obj): + """ + 执行新的检查函数。 + + Args: + obj (Any): 需要检查的对象。 + + Returns: + None + + Raises: + None + + 这个函数首先调用旧版本的检查函数 `old_check` 对传入的对象 `obj` 进行检查, + 然后调用自定义的函数 `func` 对该对象进行处理。 + """ old_check(obj) func(obj) @@ -103,6 +202,34 @@ class ExpanderInfoValidator: """ Add new supported format for the operator + Args: + *input_format: A variable number of arguments representing the new supported formats. + + Returns: + A wrapper function that adds the specified formats to the operator's supported formats list. 
+ + Raises: + GKException: Raised if the length of the registered format list does not match the length of the input formats, + or if the input formats do not match any registered format. + Exception: Raised if the wrapped class is not a subclass of Expander. + + Description: + This function adds a list `__supported_formats` to the expander, + which contains the whitelist of formats supported by the operator. + It also rewrites the `_check` function to check the formats. + + Example: + python + @add_format("text", "image") + class MyOperator(Expander): + pass + ``` + + In this example, `MyOperator` will support the "text" and "image" formats. + """ + """ + Add new supported format for the operator + this function will add a list `__supported_formats` into the expander, saving the whitelist of formats that this op supports. it also rewrites the `_check` function to check the formats. @@ -110,6 +237,19 @@ class ExpanderInfoValidator: format_list_name = "__supported_formats" def _check_format(obj): + """ + 检查对象的输入格式是否与已注册的格式匹配。 + + Args: + obj (object): 需要检查的对象。 + + Raises: + GKException: 如果输入格式与已注册的格式不匹配,则引发异常。 + + Returns: + None + + """ inp_formats = [inp['format'] for inp in obj.inputs] for formats in getattr(obj, format_list_name): if len(formats) != len(inp_formats): @@ -120,6 +260,18 @@ class ExpanderInfoValidator: raise GKException("Unregistered format ({}) for op {}".format(','.join(inp_formats), obj.name)) def wrapper(cls): + """ + 为给定的类添加包装功能。 + + Args: + cls: 需要被包装的类,必须继承自 Expander 类。 + + Returns: + 返回包装后的类。 + + Raises: + Exception: 如果 cls 不是 Expander 的子类,则抛出异常。 + """ if not issubclass(cls, Expander): raise Exception("{} should be subclass of Expander.".format(cls.__name__)) if not hasattr(cls, format_list_name): @@ -132,11 +284,49 @@ class ExpanderInfoValidator: @staticmethod def check_all_formats_same(kls): + """ + 检查所有格式是否相同。 + + Args: + kls: 待检查的类 + + Returns: + 返回传入的类 kls,并在类上注册一个检查函数,用于验证该类所有输入格式是否一致。 + + Raises: + Exception: 如果传入的类 kls 不是 Expander 的子类,则抛出异常。 + GKException: 如果 kls 类中的输入格式不一致,则抛出异常,并显示不匹配格式的信息。 + + """ """Check that all formats are the same""" # Ensure no args case can return a class def _check(*args): + """ + 检查操作输入格式是否一致的装饰器。 + + Args: + *args: 可变参数,装饰器可以接收任意数量的参数。 + + Returns: + wrapper: 返回一个装饰器函数,用于包装类。 + + Raises: + GKException: 如果所有输入的格式不一致,抛出GKException异常。 + Exception: 如果被装饰的类不是Expander的子类,抛出异常。 + + """ def _check_format(obj): + """ + 检查输入格式是否一致。 + + Args: + obj (Any): 包含输入信息的对象。 + + Raises: + GKException: 如果所有输入格式不一致,则抛出异常,并包含不匹配格式的具体信息。 + + """ inp_formats = [inp['format'] for inp in obj.inputs] if all((fmt == inp_formats[0] for fmt in inp_formats[1:])): return @@ -144,6 +334,19 @@ class ExpanderInfoValidator: ','.join(inp_formats), obj.name)) def wrapper(cls): + """ + 将给定类包装为 Expander 的子类,并进行格式检查。 + + Args: + cls (class): 需要包装的类。 + + Returns: + class: 包装后的类。 + + Raises: + Exception: 如果 cls 不是 Expander 的子类,则抛出异常。 + + """ if not issubclass(cls, Expander): raise Exception("{} should be subclass of Expander.".format(cls.__name__)) ExpanderInfoValidator._add_check_function(cls, _check_format) @@ -155,14 +358,53 @@ class ExpanderInfoValidator: @staticmethod def check_attrs(*args): + """ + 检查属性是否存在。 + + Args: + *args: 一个或多个属性名,用于检查对象是否具有这些属性。 + + Returns: + 一个装饰器函数,该装饰器函数用于验证类是否具有指定的属性。 + + Raises: + GKException: 如果对象不具有指定的属性,则抛出该异常。 + Exception: 如果被装饰的类不是 Expander 的子类,则抛出该异常。 + + """ """Check the attrs exist""" def _check_attr(obj): + """ + 检查对象是否具有指定的属性。 + + Args: + obj (object): 要检查的对象。 + + Raises: + GKException: 如果对象不具有指定的属性,则抛出异常。 + + Returns: + None 
+ """ for a in args: if a not in obj.attrs: raise GKException("attr '{}' does not exist.".format(a)) def wrapper(cls): + """ + 对类进行包装,确保该类是 Expander 的子类,并添加属性检查功能。 + + Args: + cls (class): 需要包装的类。 + + Returns: + class: 包装后的类。 + + Raises: + Exception: 如果 cls 不是 Expander 的子类,则抛出异常。 + + """ if not issubclass(cls, Expander): raise Exception("{} should be subclass of Expander.".format(cls.__name__)) ExpanderInfoValidator._add_check_function(cls, _check_attr) @@ -172,6 +414,21 @@ class ExpanderInfoValidator: def to_frac_z_axis(ori_shape, ori_axis): + """ + 判断是否为分形NZ格式 + + Args: + ---- + ori_shape: list or tuple + 输入的原始形状 + ori_axis: list or tuple + 操作的原始形状的轴 + + Returns: + ------- + output: list + 分形Nz形状的轴 + """ """ judge the format is fractal NZ Parameters @@ -208,6 +465,16 @@ def to_frac_z_axis(ori_shape, ori_axis): def infer_shape_from_fractalnz(fractal): + """ + 从fractalnz形状推断原始形状 + + Args: + fractal (list): fractalnz形状,一个包含形状的列表 + + Returns: + list: 推断出的原始形状 + + """ "get original shape from fractalnz shape" shape = [] dims = len(fractal) @@ -222,6 +489,17 @@ def infer_shape_from_fractalnz(fractal): def get_reduced_ori_shape(shape, axis): + """ + 获取基于原始形状的降维后的形状。 + + Args: + shape (List[int]): 原始形状,是一个整数列表。 + axis (List[int]): 需要降维的轴索引列表。 + + Returns: + List[int]: 降维后的形状,是一个整数列表。 + + """ "get shape after reduced which is based on original shape" reduced_ori_shape = [] for i, value in enumerate(shape): @@ -233,6 +511,25 @@ def get_reduced_ori_shape(shape, axis): def get_reduce_axis_shape(shape, data_format, axis): + """ + 根据给定的输入形状、数据格式和轴,获取在指定格式下的归约轴和原始的归约形状。 + + Args: + ----- + shape: list or tuple + 输入的形状。 + data_format: str + 输入的数据格式。 + axis: None, int, list or tuple + 在原始形状下的归约轴。 + + Returns: + -------- + reduce_axis: list + 在指定数据格式下的归约轴。 + ori_reduced_shape: list + 原始的归约形状。 + """ """ Get the reduce axis under format `data_format` and original reduced shape. Parameters diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/addn.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/addn.py index 8dd9049c..111d637f 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/addn.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/addn.py @@ -13,20 +13,47 @@ # limitations under the License. 
# =========================================================================== """generate json desc for addn""" +# 导入GraphKernelUnsupportedException异常类 from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException +# 导入Expander和ExpanderInfoValidator类 from ._utils import Expander, ExpanderInfoValidator as VLD +# 使用VLD.check_all_formats_same装饰器,确保所有输入格式相同 @VLD.check_all_formats_same class AddN(Expander): """Expand AddN to multiple Adds""" + # 检查输入数量是否大于1 def _check(self): + """ + 检查输入的数量是否满足要求。 + + Args: + 无 + + Returns: + 无 + + Raises: + GKException: 如果输入的数量小于2,则抛出GKException异常。 + """ if len(self.inputs) < 2: raise GKException("For 'AddN', the inputs num should be greater than 1, but got {}" .format(len(self.inputs))) + # 将AddN展开为多个Add操作 def _expand(self, graph_builder): + """ + 对输入张量进行逐元素加法运算。 + + Args: + graph_builder (GraphBuilder): 图构建器对象,用于生成图节点。 + + Returns: + Tensor: 逐元素加法运算后的结果张量。 + + """ result = self.inputs[0] for inp in self.inputs[1:]: result = graph_builder.emit('Add', [result, inp]) diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/batchnorm.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/batchnorm.py index 799dc3f5..bd29f9a3 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/batchnorm.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/batchnorm.py @@ -36,15 +36,19 @@ class BatchNorm(Expander): input_x_ori_type = input_x.dtype input_x_new_type = input_x.dtype + # 如果输入数据的类型为float16,而scale、offset、mean、variance的类型为float32,则将输入数据类型转换为float32 if input_x.dtype == "float16" and input_scale.dtype == "float32" and input_offset.dtype == "float32" and \ input_mean.dtype == "float32" and input_variance.dtype == "float32": input_x_new_type = "float32" + # 如果输入数据类型与原始类型不同,则进行类型转换 if input_x_new_type != input_x_ori_type: input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': input_x_new_type}) + # 如果是训练模式 if self.attrs['is_training']: self.inputs[0] = input_x res_y, mean_res, variance_res, mean_muls, y_sqrt_rec = self._bn_train(graph_builder) + # 如果输入数据类型与原始类型不同,则将输出数据类型转换为原始类型 if input_x_new_type != input_x_ori_type: res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type}) return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec @@ -70,21 +74,42 @@ class BatchNorm(Expander): return res_y, var_add, var_add, var_add, var_add def _bn_train(self, graph_builder): + """ + 在训练模式下扩展BatchNorm。 + + Args: + graph_builder (GraphBuilder): 图构建器实例。 + + Returns: + tuple: 包含以下内容的元组: + - res_y (Tensor): 归一化后的输出。 + - mean_res (Tensor): 更新后的移动均值。 + - variance_res (Tensor): 更新后的移动方差。 + - mean_muls (Tensor): 输入数据的均值。 + - y_sqrt_rec (Tensor): 1 / sqrt(方差 + epsilon),用于反向传播。 + + """ """expand BatchNorm for training mode""" + # 获取输入数据 input_x = self.inputs[0] input_scale = self.inputs[1] input_offset = self.inputs[2] input_mean = self.inputs[3] input_variance = self.inputs[4] + # 获取epsilon值 epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon']) + # 获取reduce轴 reduce_axis = () + # 获取输入数据的形状 shape_x = input_x.shape + # 根据输入数据的格式,设置reduce轴和num值 if input_x.data_format == DF.NHWC: reduce_axis = (0, 1, 2) num = shape_x[0] * shape_x[1] * shape_x[2] else: reduce_axis = (0, 2, 3) num = shape_x[0] * shape_x[2] * shape_x[3] + # 计算num的倒数 num_rec = 1.0 / num num_rec_v = graph_builder.value(input_scale.dtype, num_rec) @@ -112,41 +137,67 @@ class BatchNorm(Expander): y_sqrt_rec = 
graph_builder.emit('RealDiv', [scalar_one_v, y_sqrt]) # compute res_y + # 计算输入x和mean_muls_expand的差值 tmp_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand]) + # 如果输入x的数据格式为DF.DEFAULT或DF.NCHW,则对y_sqrt_rec进行reshape操作 if input_x.data_format in (DF.DEFAULT, DF.NCHW): y_sqrt_rec_expand = graph_builder.emit( 'Reshape', [y_sqrt_rec], attrs={'shape': ExpandDims.infer_shape(y_sqrt_rec.shape, [-1, -1])}) + # 否则,y_sqrt_rec保持不变 else: y_sqrt_rec_expand = y_sqrt_rec + # 计算tmp_sub和y_sqrt_rec_expand的乘积 y_norm = graph_builder.emit('Mul', [tmp_sub, y_sqrt_rec_expand]) + # 如果输入x的数据格式为DF.DEFAULT或DF.NCHW,则对input_scale进行reshape操作 if input_x.data_format in (DF.DEFAULT, DF.NCHW): input_scale_expand = graph_builder.emit( 'Reshape', [input_scale], attrs={'shape': ExpandDims.infer_shape(input_scale.shape, [-1, -1])}) + # 否则,input_scale保持不变 else: input_scale_expand = input_scale + # 计算input_scale_expand和y_norm的乘积 res_y_mul = graph_builder.emit('Mul', [input_scale_expand, y_norm]) + # 如果输入x的数据格式为DF.DEFAULT或DF.NCHW,则对input_offset进行reshape操作 if input_x.data_format in (DF.DEFAULT, DF.NCHW): input_offset_expand = graph_builder.emit( 'Reshape', [input_offset], attrs={'shape': ExpandDims.infer_shape(input_offset.shape, [-1, -1])}) + # 否则,input_offset保持不变 else: input_offset_expand = input_offset + # 计算res_y_mul和input_offset_expand的和 res_y = graph_builder.emit('Add', [res_y_mul, input_offset_expand]) # compute mean_res + # 计算动量减去1的值 momentum_sub = scalar_one - self.attrs['momentum'] + # 将动量减去1的值转换为输入数据的类型 momentum_v_sub = graph_builder.value(input_scale.dtype, momentum_sub) + # 计算新的移动平均值的临时值 + # 计算新的running_mean_tmp new_running_mean_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_mean]) + # 计算momentum_v momentum_v = graph_builder.value(input_scale.dtype, self.attrs['momentum']) + # 计算current_mean_tmp current_mean_tmp = graph_builder.emit('Mul', [momentum_v, mean_muls]) + # 计算updated_moving_mean updated_moving_mean = graph_builder.emit('Add', [new_running_mean_tmp, current_mean_tmp]) + # 将updated_moving_mean赋值给input_mean mean_res = graph_builder.emit('Assign', [input_mean, updated_moving_mean]) # variance_res is calculated by sample variance, and need to multiply by num / (num - 1) + # 计算方差 var_num = float(num) / (num - 1) + # 将方差转换为输入数据的类型 var_num_v = graph_builder.value(input_scale.dtype, var_num) + # 计算方差乘积 var_mul_update = graph_builder.emit('Mul', [var_num_v, var_mul]) + # 计算新的移动方差 new_running_var_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_variance]) + # 计算当前移动方差 current_var_tmp = graph_builder.emit('Mul', [momentum_v, var_mul_update]) + # 更新移动方差 updated_moving_variance = graph_builder.emit('Add', [new_running_var_tmp, current_var_tmp]) + # 将更新后的移动方差赋值给输入方差 variance_res = graph_builder.emit('Assign', [input_variance, updated_moving_variance]) + # 返回结果 return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py index eeb94ca1..2393ba90 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py @@ -11,46 +11,59 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# =========================================================================== -"""generate json desc for BatchNormGrad""" +"""generate json desc for BatchNormGrad, which computes the gradients of a Batch Normalization layer""" + +# import the required modules and classes from mindspore._extends.graph_kernel.model.model import DataFormat as DF from ._utils import Expander, ExpanderInfoValidator as VLD from .expand_dims import ExpandDims - +# BatchNormGrad expander class, derived from Expander @VLD.add_format(DF.NHWC, DF.NHWC, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT) @VLD.add_format(DF.NCHW, DF.NCHW, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT) @VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT) @VLD.check_attrs('is_training', 'epsilon') class BatchNormGrad(Expander): - """BatchNormGrad expander""" + """BatchNormGrad expander, which computes the gradients of a Batch Normalization layer""" + # expansion entry point, called to build the BatchNormGrad computation def _expand(self, graph_builder): - # get op info - input_dy = self.inputs[0] - input_x = self.inputs[1] - input_scale = self.inputs[2] - input_save_mean = self.inputs[3] - input_save_inv_variance = self.inputs[4] + # get op info: incoming gradient, input data, scale, saved mean and saved inverse variance + input_dy = self.inputs[0] # gradient coming from the next layer + input_x = self.inputs[1] # input data + input_scale = self.inputs[2] # scale (gamma) + input_save_mean = self.inputs[3] # saved mean + input_save_inv_variance = self.inputs[4] # saved inverse variance + # derive reduce_axis from the input data format for the later ReduceSum ops reduce_axis = () shape_x = input_x.shape - if input_x.data_format == DF.NHWC: - reduce_axis = (0, 1, 2) - num = shape_x[0] * shape_x[1] * shape_x[2] - else: - reduce_axis = (0, 2, 3) - num = shape_x[0] * shape_x[2] * shape_x[3] - ori_type = input_x.dtype + if input_x.data_format == DF.NHWC: # NHWC layout + reduce_axis = (0, 1, 2) # axes to reduce over + num = shape_x[0] * shape_x[1] * shape_x[2] # number of reduced elements + else: # otherwise NCHW layout + reduce_axis = (0, 2, 3) # axes to reduce over + num = shape_x[0] * shape_x[2] * shape_x[3] # number of reduced elements + ori_type = input_x.dtype # original data type + + # if the original dtype is float16, compute in float32 to avoid precision loss if ori_type == 'float16': input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'}) if input_dy.dtype == 'float16': input_dy = graph_builder.emit('Cast', [input_dy], attrs={'dst_type': 'float32'}) - num_rec = -1.0 / num - num_rec_v = graph_builder.value(input_scale.dtype, num_rec) - dbeta = graph_builder.emit('ReduceSum', [input_dy], attrs={'reduce_axis': reduce_axis, 'keep_dims': False}) + num_rec = -1.0 / num # negative reciprocal of the element count + num_rec_v = graph_builder.value(input_scale.dtype, num_rec) # wrap it as a graph value + dbeta = graph_builder.emit('ReduceSum', [input_dy], attrs={'reduce_axis': reduce_axis, 'keep_dims': False}) # dbeta, the gradient of beta - # in training input_save_inv_variance means 1 / sqrt(variance + epsilon), which is calculated in forward pass + # in training, input_save_inv_variance means 1 / sqrt(variance + epsilon), computed in the forward pass; otherwise compute it here if self.attrs['is_training']: inv_variance = input_save_inv_variance else: @@ -61,7 +74,7 @@ class BatchNormGrad(Expander): scalar_one_v = graph_builder.value(input_scale.dtype, scalar_one) inv_variance = graph_builder.emit('RealDiv', [scalar_one_v, sqrt_var_eps]) - # compute dgamma + # compute dgamma (the gradient of gamma) if input_x.data_format in (DF.DEFAULT, DF.NCHW): input_save_mean = graph_builder.emit( 'Reshape', [input_save_mean], attrs={'shape': ExpandDims.infer_shape(input_save_mean.shape, [-1, -1])}) inv_variance = graph_builder.emit( 'Reshape', [inv_variance], attrs={'shape': ExpandDims.infer_shape(inv_variance.shape, [-1, -1])})
input_scale = graph_builder.emit( 'Reshape', [input_scale], attrs={'shape': ExpandDims.infer_shape(input_scale.shape, [-1, -1])}) - x_sub_mean = graph_builder.emit('Sub', [input_x, input_save_mean]) - x_div = graph_builder.emit('Mul', [x_sub_mean, inv_variance]) - dgamma_param = graph_builder.emit('Mul', [input_dy, x_div]) + x_sub_mean = graph_builder.emit('Sub', [input_x, input_save_mean]) # x minus the saved mean + x_div = graph_builder.emit('Mul', [x_sub_mean, inv_variance]) # (x - mean) * inv_variance + dgamma_param = graph_builder.emit('Mul', [input_dy, x_div]) # elementwise term for dgamma dgamma = graph_builder.emit( - 'ReduceSum', [dgamma_param], attrs={'reduce_axis': reduce_axis, 'keep_dims': False}) + 'ReduceSum', [dgamma_param], attrs={'reduce_axis': reduce_axis, 'keep_dims': False}) # dgamma, the gradient of gamma - # compute dx + # compute dx (the gradient of x) if self.attrs['is_training']: tmp_b = graph_builder.emit('Mul', [num_rec_v, dbeta]) if input_x.data_format in (DF.DEFAULT, DF.NCHW): @@ -95,11 +108,12 @@ class BatchNormGrad(Expander): y_scale = graph_builder.emit('Mul', [input_scale, input_dy]) dx = graph_builder.emit('Mul', [inv_variance, y_scale]) if ori_type == 'float16': - dx = graph_builder.emit('Cast', [dx], attrs={'dst_type': 'float16'}) + dx = graph_builder.emit('Cast', [dx], attrs={'dst_type': 'float16'}) # cast back to float16 if the original dtype was float16 - # set output tensors' data_format + # set the output tensors' data_format dx.data_format = self.outputs[0]['format'] dgamma.data_format = self.outputs[1]['format'] dbeta.data_format = self.outputs[2]['format'] - return dx, dgamma, dbeta + # return the results + return dx, dgamma, dbeta \ No newline at end of file diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/bias_add_grad.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/bias_add_grad.py index 161f33c0..8e8394dd 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/bias_add_grad.py @@ -13,37 +13,64 @@ # limitations under the License.
# =========================================================================== """generate json desc for bias_add""" +# 导入MindSpore的DataFormat类 from mindspore._extends.graph_kernel.model.model import DataFormat as DF +# 导入Expander和ExpanderInfoValidator类 from ._utils import Expander, ExpanderInfoValidator as VLD +# 为BiasAddGrad类添加DF.DEFAULT、DF.NHWC、DF.NCHW、DF.FRAC_NZ格式的验证 @VLD.add_format(DF.DEFAULT) @VLD.add_format(DF.NHWC) @VLD.add_format(DF.NCHW) @VLD.add_format(DF.FRAC_NZ) +# 定义BiasAddGrad类,继承自Expander类 class BiasAddGrad(Expander): """BiasAddGrad expander""" def _expand(self, graph_builder): + """ + 内部方法,用于扩展输入张量的维度。 + + Args: + graph_builder (GraphBuilder): 图构建器对象,用于生成图操作。 + + Returns: + Tensor: 扩展维度后的张量。 + + """ + # 获取输入张量 x = self.inputs[0] + # 定义reduce_axis,用于指定求和的维度 reduce_axis = () + # 如果输入张量的数据格式为NHWC,则reduce_axis为(0, 1, 2) if x.data_format == DF.NHWC: reduce_axis = (0, 1, 2) + # 如果输入张量的数据格式为NCHW,则reduce_axis为(0, 2, 3) elif x.data_format == DF.NCHW: reduce_axis = (0, 2, 3) + # 如果输入张量的数据格式为FRAC_NZ,则reduce_axis为(-2, -3) elif x.data_format == DF.FRAC_NZ: reduce_axis = (-2, -3) + # 如果输入张量的数据格式为DefaultFormat,则根据shape的长度确定reduce_axis else: # DefaultFormat shape's length should be from 2 to 4 + # 如果x的维度为2,则reduce_axis为(0,) if len(x.shape) == 2: reduce_axis = (0,) + # 如果x的维度为3,则reduce_axis为(0, 1) elif len(x.shape) == 3: reduce_axis = (0, 1) + # 否则,reduce_axis为(0, 2, 3) else: reduce_axis = (0, 2, 3) + # 发射ReduceSum操作,计算x的reduce_sum,reduce_axis为reduce_axis,keep_dims为False result = graph_builder.emit('ReduceSum', [x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False}) + # 如果x的数据格式为DF.FRAC_NZ,则将result的shape改为x.shape[:-4] + [x.shape[-1] * x.shape[-4]] if x.data_format == DF.FRAC_NZ: out_shape = x.shape[:-4] + [x.shape[-1] * x.shape[-4]] + # 发射Reshape操作,将result的shape改为out_shape result = graph_builder.emit('Reshape', [result], attrs={'shape': out_shape}) + # 返回result return result diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py index e6c345f4..81fe4cbb 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py @@ -21,13 +21,28 @@ class ClipByNormNoDivSum(Expander): """ClipByNormNoDivSum expander""" def _expand(self, graph_builder): + """ + 对输入的张量进行计算,返回计算结果。 + + Args: + graph_builder (GraphBuilder): 图构建器对象,用于生成计算图。 + + Returns: + Tensor: 计算结果张量。 + + """ input_x0, input_x1, input_x2, input_x3 = self.inputs # cal result + # 计算大于结果 greater_res = graph_builder.emit('Greater', [input_x0, input_x1]) + # 根据大于结果选择input_x0或input_x2 select_res0 = graph_builder.emit('Select', [greater_res, input_x0, input_x2]) + # 计算select_res0的平方根 sqrt_res = graph_builder.emit('Sqrt', [select_res0]) + # 根据大于结果选择sqrt_res或input_x0 select_res1 = graph_builder.emit('Select', [greater_res, sqrt_res, input_x0]) + # 计算select_res1和input_x3的最大值 result = graph_builder.emit('Maximum', [select_res1, input_x3]) return result diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/complex/__init__.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/complex/__init__.py index 742db667..8e10678e 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/complex/__init__.py +++ 
b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/complex/__init__.py @@ -14,8 +14,13 @@ # ============================================================================ """complex expanders init""" +# 从当前目录下的abs模块中导入CAbs类 from .abs import CAbs +# 从当前目录下的add模块中导入CAdd类 from .add import CAdd +# 从当前目录下的div模块中导入CDiv类 from .div import CDiv +# 从当前目录下的mul模块中导入CMul类 from .mul import CMul +# 从当前目录下的sub模块中导入CSub类 from .sub import CSub diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/complex/abs.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/complex/abs.py index 44f99afd..0f16730e 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/complex/abs.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/complex/abs.py @@ -20,11 +20,23 @@ class CAbs(Expander): """CAbs expander""" def _expand(self, graph_builder): + # 获取输入的第一个元素 input_x = self.inputs[0] + + # 发射指令CReal,将输入x的实部提取出来 x_real = graph_builder.emit('CReal', [input_x]) + # 发射指令CImag,将输入x的虚部提取出来 x_imag = graph_builder.emit('CImag', [input_x]) + + # 发射指令Mul,计算x的实部的平方 squre_x_real = graph_builder.emit('Mul', [x_real, x_real]) + # 发射指令Mul,计算x的虚部的平方 squre_x_imag = graph_builder.emit('Mul', [x_imag, x_imag]) + + # 发射指令Add,计算实部和虚部的平方和 squre_sum = graph_builder.emit('Add', [squre_x_real, squre_x_imag]) + + # 发射指令Sqrt,计算平方和的平方根 result = graph_builder.emit('Sqrt', [squre_sum]) + return result diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/complex/add.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/complex/add.py index efb04e1a..4f64d7da 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/complex/add.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/complex/add.py @@ -22,12 +22,35 @@ class CAdd(Expander): """CAdd expander""" def _expand(self, graph_builder): + """ + 将两个复数相加。 + + Args: + graph_builder (GraphBuilder): 图构建器对象。 + + Returns: + Node: 相加后的复数结果。 + + """ + # 获取输入参数 input_x, input_y = self.inputs + + # 将输入参数x转换为实数部分 x_real = graph_builder.emit('CReal', [input_x]) + # 将输入参数y转换为实数部分 y_real = graph_builder.emit('CReal', [input_y]) + + # 将输入参数x转换为虚数部分 x_imag = graph_builder.emit('CImag', [input_x]) + # 将输入参数y转换为虚数部分 y_imag = graph_builder.emit('CImag', [input_y]) + + # 将x和y的实数部分相加 result_real = graph_builder.emit('Add', [x_real, y_real]) + # 将x和y的虚数部分相加 result_imag = graph_builder.emit('Add', [x_imag, y_imag]) + + # 将相加后的实数部分和虚数部分组合为复数 result = graph_builder.emit('Complex', [result_real, result_imag]) + return result diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/complex/div.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/complex/div.py index 0f681aa5..dc05064c 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/complex/div.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/complex/div.py @@ -22,22 +22,43 @@ class CDiv(Expander): """CDiv expander""" def _expand(self, graph_builder): + """ + CDiv Implementation + + Args: + graph_builder: 图构建器对象,用于构建计算图。 + + Returns: + 返回复数除法结果。 + + 实现复数除法(CDiv)操作。 + 获取两个输入的复数,分别计算它们的实部和虚部。 + 然后计算分母和分子的实部和虚部,并进行除法运算, + 最后将得到的商的实部和虚部合并为复数结果返回。 + """ """CDiv Implementation""" + # 获取输入的两个复数 input_x, input_y = self.inputs - x_real = graph_builder.emit('CReal', 
[input_x]) - y_real = graph_builder.emit('CReal', [input_y]) - x_imag = graph_builder.emit('CImag', [input_x]) - y_imag = graph_builder.emit('CImag', [input_y]) - squre_y_real = graph_builder.emit('Mul', [y_real, y_real]) - squre_y_imag = graph_builder.emit('Mul', [y_imag, y_imag]) - final_denominator = graph_builder.emit('Add', [squre_y_real, squre_y_imag]) - x_real_mul_y_real = graph_builder.emit('Mul', [x_real, y_real]) - x_imag_mul_y_imag = graph_builder.emit('Mul', [x_imag, y_imag]) - x_real_mul_y_imag = graph_builder.emit('Mul', [x_real, y_imag]) - x_imag_mul_y_real = graph_builder.emit('Mul', [x_imag, y_real]) - final_numerator_real = graph_builder.emit('Add', [x_real_mul_y_real, x_imag_mul_y_imag]) - final_numerator_imag = graph_builder.emit('Sub', [x_imag_mul_y_real, x_real_mul_y_imag]) - result_real = graph_builder.emit('RealDiv', [final_numerator_real, final_denominator]) - result_imag = graph_builder.emit('RealDiv', [final_numerator_imag, final_denominator]) - result = graph_builder.emit('Complex', [result_real, result_imag]) + # real parts of the inputs + x_real = graph_builder.emit('CReal', [input_x]) # CReal extracts the real part of input_x + y_real = graph_builder.emit('CReal', [input_y]) # CReal extracts the real part of input_y + # imaginary parts of the inputs + x_imag = graph_builder.emit('CImag', [input_x]) # CImag extracts the imaginary part of input_x + y_imag = graph_builder.emit('CImag', [input_y]) # CImag extracts the imaginary part of input_y + # denominator: |y|^2 = y_real^2 + y_imag^2 + squre_y_real = graph_builder.emit('Mul', [y_real, y_real]) # square of y_real + squre_y_imag = graph_builder.emit('Mul', [y_imag, y_imag]) # square of y_imag + final_denominator = graph_builder.emit('Add', [squre_y_real, squre_y_imag]) # the denominator + # numerator terms + x_real_mul_y_real = graph_builder.emit('Mul', [x_real, y_real]) # x_real * y_real + x_imag_mul_y_imag = graph_builder.emit('Mul', [x_imag, y_imag]) # x_imag * y_imag + x_real_mul_y_imag = graph_builder.emit('Mul', [x_real, y_imag]) # x_real * y_imag + x_imag_mul_y_real = graph_builder.emit('Mul', [x_imag, y_real]) # x_imag * y_real + final_numerator_real = graph_builder.emit('Add', [x_real_mul_y_real, x_imag_mul_y_imag]) # real part of the numerator + final_numerator_imag = graph_builder.emit('Sub', [x_imag_mul_y_real, x_real_mul_y_imag]) # imaginary part of the numerator + # quotient + result_real = graph_builder.emit('RealDiv', [final_numerator_real, final_denominator]) # real part of the quotient + result_imag = graph_builder.emit('RealDiv', [final_numerator_imag, final_denominator]) # imaginary part of the quotient + # combine the quotient into a complex result + result = graph_builder.emit('Complex', [result_real, result_imag]) # pack the real and imaginary parts into a complex value return result diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/complex/mul.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/complex/mul.py index a964ae96..61c18eac 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/complex/mul.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/complex/mul.py @@ -22,17 +22,45 @@ class CMul(Expander): """CMul expander""" def _expand(self, graph_builder): + """ Compute the product of two complex numbers. Args: graph_builder (GraphBuilder): graph builder used to emit operator nodes into the graph. Returns: the complex product of the two inputs. """ """CMul Implementation""" + # the two complex inputs input_x, input_y = self.inputs - x_real = graph_builder.emit('CReal', [input_x]) - y_real = graph_builder.emit('CReal', [input_y]) - x_imag = graph_builder.emit('CImag', [input_x]) - y_imag = graph_builder.emit('CImag', [input_y]) - x_real_mul_y_real = graph_builder.emit('Mul', [x_real, y_real]) - x_imag_mul_y_imag =
graph_builder.emit('Mul', [x_imag, y_imag]) - x_real_mul_y_imag = graph_builder.emit('Mul', [x_real, y_imag]) - x_imag_mul_y_real = graph_builder.emit('Mul', [x_imag, y_real]) - result_real = graph_builder.emit('Sub', [x_real_mul_y_real, x_imag_mul_y_imag]) - result_imag = graph_builder.emit('Add', [x_real_mul_y_imag, x_imag_mul_y_real]) - result = graph_builder.emit('Complex', [result_real, result_imag]) + + # 获取输入复数的实部 + x_real = graph_builder.emit('CReal', [input_x]) # 发射指令获取input_x的实部 + y_real = graph_builder.emit('CReal', [input_y]) # 发射指令获取input_y的实部 + + # 获取输入复数的虚部 + x_imag = graph_builder.emit('CImag', [input_x]) # 发射指令获取input_x的虚部 + y_imag = graph_builder.emit('CImag', [input_y]) # 发射指令获取input_y的虚部 + + # 计算实部与实部的乘积 + x_real_mul_y_real = graph_builder.emit('Mul', [x_real, y_real]) # 发射指令计算x_real与y_real的乘积 + + # 计算虚部与虚部的乘积 + x_imag_mul_y_imag = graph_builder.emit('Mul', [x_imag, y_imag]) # 发射指令计算x_imag与y_imag的乘积 + + # 计算实部与虚部的乘积 + x_real_mul_y_imag = graph_builder.emit('Mul', [x_real, y_imag]) # 发射指令计算x_real与y_imag的乘积 + x_imag_mul_y_real = graph_builder.emit('Mul', [x_imag, y_real]) # 发射指令计算x_imag与y_real的乘积 + + # 计算复数的实部 + result_real = graph_builder.emit('Sub', [x_real_mul_y_real, x_imag_mul_y_imag]) # 发射指令计算实部结果 + + # 计算复数的虚部 + result_imag = graph_builder.emit('Add', [x_real_mul_y_imag, x_imag_mul_y_real]) # 发射指令计算虚部结果 + + # 构造复数结果 + result = graph_builder.emit('Complex', [result_real, result_imag]) # 发射指令构造复数结果 + return result diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/complex/sub.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/complex/sub.py index f3715769..8b846f68 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/complex/sub.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/complex/sub.py @@ -22,12 +22,38 @@ class CSub(Expander): """CSub expander""" def _expand(self, graph_builder): + """ + 计算两个复数的差。 + + Args: + graph_builder (GraphBuilder): 图构建器对象,用于生成计算图。 + + Returns: + Tensor: 计算得到的复数差的结果。 + + """ + # 获取输入 input_x, input_y = self.inputs + + # 提取输入x的实部 x_real = graph_builder.emit('CReal', [input_x]) + + # 提取输入y的实部 y_real = graph_builder.emit('CReal', [input_y]) + + # 提取输入x的虚部 x_imag = graph_builder.emit('CImag', [input_x]) + + # 提取输入y的虚部 y_imag = graph_builder.emit('CImag', [input_y]) + + # 计算实部之差 result_real = graph_builder.emit('Sub', [x_real, y_real]) + + # 计算虚部之差 result_imag = graph_builder.emit('Sub', [x_imag, y_imag]) + + # 将实部和虚部组合成复数结果 result = graph_builder.emit('Complex', [result_real, result_imag]) + return result diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/conv2d.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/conv2d.py index 36b3fb72..40441dd5 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/conv2d.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/conv2d.py @@ -18,6 +18,7 @@ from mindspore._extends.graph_kernel.model.model import DataFormat as DF from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException from ._utils import Expander, ExpanderInfoValidator as VLD +# 定义常量 M_ALIGN = 32 N_ALIGN = 32 K_ALIGN = 16 @@ -29,6 +30,7 @@ C_CHANNEL_ALIGN = 16 OUT_NHW_ALIGN = 128 +# 添加格式验证 @VLD.add_format(DF.DEFAULT, DF.DEFAULT) @VLD.add_format(DF.NHWC, DF.NHWC) @VLD.check_attrs('format', 'pad_list', 'pad_mode', 'groups', 
'group', 'kernel_size', 'stride', 'dilation') @@ -47,6 +49,16 @@ class Conv2D(Expander): """ def __init__(self, expand_info): + """ + 类的构造函数 + + Args: + expand_info (dict): 扩展信息字典,包含一些扩展的配置参数。 + + Returns: + None + + """ super().__init__(expand_info) self.dst_type = self.outputs[0]['data_type'] self.dst_format = self.outputs[0]['format'] @@ -59,6 +71,19 @@ class Conv2D(Expander): self.k = 0 def _optimize_to_matmul(self): + """ + 检查是否可以将Conv2D优化为MatMul。 + + Args: + 无 + + Returns: + bool: 如果可以将Conv2D优化为MatMul,则返回True;否则返回False。 + + """ + """ + Check if the Conv2D can be optimized to MatMul. + """ stride = self.attrs['stride'] dilation = self.attrs['dilation'] _, h, w, _ = self.inputs[1]['shape'] @@ -68,6 +93,18 @@ class Conv2D(Expander): return False def _common_check(self): + """ + 对输入和属性的通用检查 + + Args: + 无 + + Returns: + 无 + + Raises: + GKException: 如果输入数据类型不是 float16,或者输入格式不是 NHWC,或者属性 groups 和 group 不是 1,或者属性 dilation 不是 [1, 1, 1, 1],抛出异常 + """ """common check for inputs and attrs""" type_0 = self.inputs[0]['data_type'] type_1 = self.inputs[1]['data_type'] @@ -91,26 +128,52 @@ class Conv2D(Expander): .format(dilation)) def _check(self): + """ + 检查卷积2D操作的参数和输入是否合法。 + + Args: + 无 + + Raises: + GKException: 当输入参数或输入维度不满足要求时抛出异常。 + + Returns: + 无 + """ + # 调用_common_check()方法 self._common_check() + # 获取pad_list pad_list = self.attrs['pad_list'] + # 检查pad_list的维度是否为4 check_nd(pad_list, 4) + # 调用conv_had_pad()方法,判断是否有pad self.has_pad = conv_had_pad(pad_list, self.attrs['pad_mode']) + # 获取输入的shape shape_0 = self.inputs[0]['shape'] shape_1 = self.inputs[1]['shape'] + # 获取stride stride = self.attrs['stride'] + # 检查shape_0的维度是否为4 check_nd(shape_0, 4) + # 检查shape_1的维度是否为4 check_nd(shape_1, 4) + # 检查stride的维度是否为4 check_nd(stride, 4) + # 获取shape_0的各个维度 n0, h0, w0, c0 = shape_0 + # 获取shape_1的各个维度 n1, h1, w1, c1 = shape_1 + # 检查n0是否为N0_CHANNEL_ALIGN的倍数 if (n0 % N0_CHANNEL_ALIGN) != 0: raise GKException("For 'Conv2D', N channel of first input should be multiples of {}, but got {}" .format(N0_CHANNEL_ALIGN, n0)) + # 检查n1是否为N1_CHANNEL_ALIGN的倍数 if (n1 % N1_CHANNEL_ALIGN) != 0: raise GKException("For 'Conv2D', N channel of second input should be multiples of {}, but got {}" .format(N1_CHANNEL_ALIGN, n1)) + # 检查c0和c1是否相等,并且是否为C_CHANNEL_ALIGN的倍数 if c0 != c1 or (c0 % C_CHANNEL_ALIGN) != 0: raise GKException("For 'Conv2D', C channel of inputs should be same and also be multiples of {}, but got " "{} and {}".format(C_CHANNEL_ALIGN, c0, c1)) @@ -130,68 +193,106 @@ class Conv2D(Expander): # check if can optimize to matmul self.m, self.n, self.k = n0 * h0 * w0, n1, c1 + # 调用_optimize_to_matmul()方法,判断是否可以优化为matmul self.can_optimize_to_matmul = self._optimize_to_matmul() # requirements if self.can_optimize_to_matmul: + # 如果可以优化为matmul,检查k是否大于K_LIMIT if self.k > K_LIMIT: raise GKException("For 'Conv2D', if transformed to 'MatMul', C0 should not be larger than {}, but got " "{}".format(K_LIMIT, self.k)) + # 如果可以优化为matmul,检查m*n*k的总大小是否大于MNK_LIMIT if self.m * self.n * self.k >= MNK_LIMIT: raise GKException("For 'Conv2D', if transformed to 'MatMul', The total size should not be larger than " "{}, but got {}".format(MNK_LIMIT, self.m * self.n * self.k)) else: + # 如果不能优化为matmul,计算输出的大小 out_h, out_w = (h0 - h1) // stride[-2] + 1, (w0 - w1) // stride[-1] + 1 + # 检查n0*out_h*out_w是否为OUT_NHW_ALIGN的倍数 if ((n0 * out_h * out_w) % OUT_NHW_ALIGN) != 0: raise GKException("For 'Conv2D', N({}) * H({}) * W({}) of output should be multiplies of {}" .format(n0, out_h, out_w, OUT_NHW_ALIGN)) + # 检查stride是否为[1, 1, 2, 2] if stride != [1, 1, 
2, 2]: raise GKException("For 'Conv2D', value of attr 'stride' should be [1, 1, 2, 2], but got {}" .format(stride)) + # 保存pad后的shape self.shape_0_pad = [n0, h0, w0, c0] self.shape_1_pad = [n1, h1, w1, c1] - def _expand(self, graph_builder): +def _expand(self, graph_builder): + """ + 对输入进行扩展处理。 + + Args: + graph_builder (GraphBuilder): 图构建器对象。 + + Returns: + Tensor: 扩展处理后的结果。 + + """ + # 获取输入0 input_0 = self.inputs[0] + # 获取输入1 input_1 = self.inputs[1] + # 获取输入0的形状 n0, _, _, c0 = input_0.shape + # 获取输入1的形状 n1, _, _, c1 = input_1.shape + # 获取输入0的填充形状 n0_p, h0_p, w0_p, c0_p = self.shape_0_pad + # 获取输入1的填充形状 n1_p, _, _, c1_p = self.shape_1_pad pad_value = 0 # input0 pad + # 初始化输入0的填充前后的值 input_0_pad_before = [0, 0, 0, 0] input_0_pad_after = [0, 0, 0, 0] + # 如果有填充,则获取填充列表 if self.has_pad: pad_list = self.attrs['pad_list'] + # 设置输入0的填充前后的值 input_0_pad_before = [0, pad_list[0], pad_list[2], 0] input_0_pad_after = [0, pad_list[1], pad_list[3], 0] + # 设置输入0的填充后的值 input_0_pad_after[0] = n0_p - n0 input_0_pad_after[3] = c0_p - c0 + # 如果输入0的填充前后的值不为默认值,则进行填充操作 if input_0_pad_before != [0, 0, 0, 0] or input_0_pad_after != [0, 0, 0, 0]: + # 发射填充操作 input_0 = graph_builder.emit('PadAkg', [input_0], attrs={'head': input_0_pad_before, 'tail': input_0_pad_after, 'pad_val': pad_value}) # input1 pad + # 计算input_1的pad值 input_1_pad_after = [n1_p - n1, 0, 0, c1_p - c1] + # 如果input_1的pad值不为0,则进行pad操作 if input_1_pad_after != [0, 0, 0, 0]: input_1 = graph_builder.emit('PadAkg', [input_1], attrs={'head': [0, 0, 0, 0], 'tail': input_1_pad_after, 'pad_val': pad_value}) + # 如果可以优化为matmul操作,则进行matmul操作 if self.can_optimize_to_matmul: + # 将input_0和input_1进行reshape操作 a = graph_builder.emit('Reshape', [input_0], attrs={'shape': [self.m, self.k]}) b = graph_builder.emit('Reshape', [input_1], attrs={'shape': [self.n, self.k]}) + # 进行matmul操作 c = graph_builder.emit('MatMul', [a, b], attrs={'transpose_a': False, 'transpose_b': True, 'dst_type': self.dst_type}) + # 将结果进行reshape操作 result = graph_builder.emit('Reshape', [c], attrs={'shape': [n0_p, h0_p, w0_p, n1_p], 'format': self.dst_format}) + # 否则进行Conv2D操作 else: + # 设置Conv2D操作的属性 attrs = self.attrs attrs['pad_list'] = [0, 0, 0, 0] attrs['dst_type'] = self.dst_type + # 进行Conv2D操作 result = graph_builder.emit('Conv2D', [input_0, input_1], attrs=attrs) # unpad unpad_after = [input_0_pad_after[0], 0, 0, input_1_pad_after[0]] diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/dropout_grad.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/dropout_grad.py index ac7a011f..6293e4de 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/dropout_grad.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/dropout_grad.py @@ -13,18 +13,36 @@ # limitations under the License. 
# =========================================================================== """generate json desc for DropoutGrad""" +# 导入Expander和ExpanderInfoValidator类 from ._utils import Expander, ExpanderInfoValidator as VLD +# 定义DropoutGrad类,继承自Expander类 @VLD.check_all_formats_same @VLD.check_attrs('keep_prob') class DropoutGrad(Expander): """DropoutGrad expander""" def _expand(self, graph_builder): + """ + 对输入数据进行扩展操作。 + + Args: + graph_builder (GraphBuilder): 图构建器对象。 + + Returns: + Tensor: 扩展后的输入数据。 + + """ + # 获取输入数据和掩码 input_dy, input_mask = self.inputs + # 获取保持概率 keep_prob = self.attrs['keep_prob'] + # 计算保持概率的倒数 r_keep_prob = graph_builder.value(input_dy.dtype, 1.0 / keep_prob) + # 计算输入数据和保持概率的乘积 result = graph_builder.emit('Mul', [input_dy, r_keep_prob]) + # 计算乘积和掩码的乘积 result = graph_builder.emit('Mul', [result, input_mask]) + # 返回结果 return result diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/equal_count.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/equal_count.py index 53bfbb0e..940d1c33 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/equal_count.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/equal_count.py @@ -17,34 +17,84 @@ from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedEx from ._utils import Expander, ExpanderInfoValidator as VLD +# @VLD.check_all_formats_same:检查所有格式的相同性 @VLD.check_all_formats_same class EqualCount(Expander): """EqualCount expander""" def __init__(self, expand_info): + """ + 初始化方法。 + + Args: + expand_info (dict): 扩展信息字典。 + + Returns: + None + + """ + # 调用父类的初始化方法 super().__init__(expand_info) + # 获取输入x的形状 self.shape_x = self.inputs[0]['shape'] + # 获取输入y的形状 self.shape_y = self.inputs[1]['shape'] + # 获取输入x的数据类型 self.dtype_x = self.inputs[0]['data_type'] + # 获取输入y的数据类型 self.dtype_y = self.inputs[1]['data_type'] def _check(self): + """ + 检查输入的两个张量是否具有相同的形状和数据类型。 + + Args: + 无 + + Returns: + 无 + + Raises: + GKException: 如果两个张量的形状不同,则引发异常,异常信息中包含两个张量的形状。 + GKException: 如果两个张量的数据类型不同,则引发异常,异常信息中包含两个张量的数据类型。 + """ + # 判断输入的形状是否相同 if self.shape_x != self.shape_y: + # 如果不相同,抛出异常 raise GKException("For 'EqualCount', the inputs shape should be same, but got {} and {}" .format(self.shape_x, self.shape_y)) + # 判断输入的数据类型是否相同 if self.dtype_x != self.dtype_y: + # 如果不相同,抛出异常 raise GKException("For 'EqualCount', the inputs data type should be same, but got {} and {}" .format(self.dtype_x, self.dtype_y)) def _expand(self, graph_builder): + """ + 扩展输入维度的方法。 + + Args: + graph_builder: 图构建器对象,用于生成计算图。 + + Returns: + 扩展后的张量。 + + """ + # 获取输入张量 input_x = self.inputs[0] input_y = self.inputs[1] + # 比较输入张量是否相等 eql_val = graph_builder.emit('Equal', [input_x, input_y]) + # 将比较结果转换为float32类型 cast_val = graph_builder.emit('Cast', [eql_val], attrs={'dst_type': 'float32'}) + # 获取输入张量的维度 axis = list(range(len(input_x.shape))) + # 对比较结果进行求和 result = graph_builder.emit('ReduceSum', [cast_val], attrs={'reduce_axis': axis, 'keep_dims': True}) + # 如果求和结果的数据类型与输入张量的数据类型不同,则将求和结果转换为输入张量的数据类型 if result.dtype != input_x.dtype: result = graph_builder.emit('Cast', [result], attrs={'dst_type': input_x.dtype}) + # 返回求和结果 return result diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/erfc.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/erfc.py index 7e97c455..198120d6 100644 --- 
a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/erfc.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/erfc.py @@ -16,20 +16,44 @@ from ._utils import Expander +# 定义一个Erfc类,继承自Expander类 class Erfc(Expander): """Erfc expander""" def _expand(self, graph_builder): + """ + 对输入数据进行扩展处理。 + + Args: + graph_builder (GraphBuilder): 图构建器对象。 + + Returns: + Tensor: 扩展处理后的结果。 + + """ + # 获取输入数据 input_x = self.inputs[0] + # 初始化结果 result = None + # 如果输入数据的类型是float16 if input_x.dtype == "float16": + # 创建一个float32类型的常量1 const_one = graph_builder.value("float32", 1) + # 将输入数据转换为float32类型 input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': "float32"}) + # 计算输入数据的erf值 erf_result = graph_builder.emit('Erf', [input_x]) + # 计算结果 result = graph_builder.emit('Sub', [const_one, erf_result]) + # 将结果转换为float16类型 result = graph_builder.emit('Cast', [result], attrs={'dst_type': "float16"}) + # 返回结果 return result + # 创建一个与输入数据类型相同的常量1 const_one = graph_builder.value(input_x.dtype, 1) + # 计算输入数据的erf值 erf_result = graph_builder.emit('Erf', [input_x]) + # 计算结果 result = graph_builder.emit('Sub', [const_one, erf_result]) + # 返回结果 return result diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/expand_dims.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/expand_dims.py index 7403f119..ff2b46e5 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/expand_dims.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/expand_dims.py @@ -21,6 +21,16 @@ class ExpandDims(Expander): """ExpandDims expander""" def _expand(self, graph_builder): + """ + 对输入数据进行维度扩展。 + + Args: + graph_builder (GraphBuilder): 图构建器对象。 + + Returns: + Tensor: 扩展后的数据。 + + """ input_x = self.inputs[0] shape = self.infer_shape(input_x.shape, self.attrs['axis']) result = graph_builder.emit('Reshape', [input_x], attrs={'shape': shape}) @@ -29,8 +39,35 @@ class ExpandDims(Expander): @staticmethod def infer_shape(shape, axis): + """ + 根据给定的轴位置推断新的形状。 + + Args: + shape (list or tuple): 原始形状,表示一个多维数组的尺寸。 + axis (int, list or tuple): 指定要插入新维度的轴位置。如果为整数,表示在指定位置插入一个维度;如果为列表或元组,则按顺序在指定位置插入多个维度。 + + Returns: + list: 插入新维度后的新形状。 + + Raises: + ValueError: 如果axis的值或类型不符合要求时抛出。 + + """ """infer shape for expand_dims""" def insert_axis(shape, axis): + """ + 在指定轴上插入一个新的维度。 + + Args: + shape (list): 原始数组的形状,类型为列表。 + axis (int): 要插入新维度的轴的位置。 + + Returns: + list: 插入新维度后的数组形状。 + + Raises: + ValueError: 如果axis的类型不是int,或者axis的值不在合法范围内,将抛出异常。 + """ if not isinstance(axis, int) or axis > len(shape) or axis < -len(shape) - 1: raise ValueError("For 'ExpandDims', value of attr 'axis' should be of type int and in the range [{}, " "{}], but got {} with type {}".format(-len(shape) - 1, len(shape), axis, type(axis))) diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/fused_adam.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/fused_adam.py index c424f0aa..fc781df9 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/fused_adam.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/fused_adam.py @@ -21,24 +21,51 @@ class FusedAdam(Expander): """FusedAdam expander""" def _expand(self, graph_builder): + """ + 使用图构建器对模型参数进行更新。 + + Args: + graph_builder (GraphBuilder): 图构建器实例,用于生成计算图。 + + Returns: + Tensor: 更新后的参数结果。 + + """ + # 获取输入参数 beta_1, 
one_sub_beta_1, beta_2, one_sub_beta_2, eps, lr, param, m, v, gradient = self.inputs + # 计算beta_1乘以m beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m]) + # 计算one_sub_beta_1乘以gradient one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient]) + # 计算next_m next_m = graph_builder.emit('Add', [beta_1_mul_m, one_sub_beta_1_mul_grad]) + # 计算beta_2乘以v beta_2_mul_v = graph_builder.emit('Mul', [beta_2, v]) + # 计算gradient的平方 grad_square = graph_builder.emit('Mul', [gradient, gradient]) + # 计算one_sub_beta_2乘以gradient的平方 one_sub_beta_2_mul_grad_square = graph_builder.emit('Mul', [one_sub_beta_2, grad_square]) + # 计算next_v next_v = graph_builder.emit('Add', [beta_2_mul_v, one_sub_beta_2_mul_grad_square]) + # 计算next_v的平方根 sqrt_next_v = graph_builder.emit('Sqrt', [next_v]) + # 计算sqrt_next_v加上eps sqrt_next_v_add_eps = graph_builder.emit('Add', [sqrt_next_v, eps]) + # 计算更新值 update = graph_builder.emit('RealDiv', [next_m, sqrt_next_v_add_eps]) + # 计算更新值乘以lr update_with_lr = graph_builder.emit('Mul', [lr, update]) + # 计算next_para next_para = graph_builder.emit('Sub', [param, update_with_lr]) + # 将next_para赋值给param param_result = graph_builder.emit( 'InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True}) + # 将next_m赋值给m param_result = graph_builder.emit('InplaceAssign', [m, next_m, param_result], attrs={'fake_output': True}) + # 将next_v赋值给v param_result = graph_builder.emit('InplaceAssign', [v, next_v, param_result], attrs={'fake_output': True}) + # 返回param_result return param_result diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py index 53598068..47dcbce1 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py @@ -21,27 +21,54 @@ class FusedAdamWeightDecay(Expander): """FusedAdamWeightDecay expander""" def _expand(self, graph_builder): + """ + 对输入参数进行梯度下降更新。 + + Args: + graph_builder (GraphBuilder): 图构建器对象,用于在图中添加节点。 + + Returns: + ParaResult: 更新后的参数结果节点。 + + """ beta_1, one_sub_beta_1, beta_2, one_sub_beta_2, eps, lr, param, m, v, gradient, weight_decay = self.inputs # compute result + # 计算beta_1和m的乘积 beta_1_mul_m = graph_builder.emit('Mul', [beta_1, m]) + # 计算one_sub_beta_1和gradient的乘积 one_sub_beta_1_mul_grad = graph_builder.emit('Mul', [one_sub_beta_1, gradient]) + # 计算next_m next_m = graph_builder.emit('Add', [beta_1_mul_m, one_sub_beta_1_mul_grad]) + # 计算beta_2和v的乘积 beta_2_mul_v = graph_builder.emit('Mul', [beta_2, v]) + # 计算gradient的平方 grad_square = graph_builder.emit('Mul', [gradient, gradient]) + # 计算one_sub_beta_2和grad_square的乘积 one_sub_beta_2_mul_grad_square = graph_builder.emit('Mul', [one_sub_beta_2, grad_square]) + # 计算next_v next_v = graph_builder.emit('Add', [beta_2_mul_v, one_sub_beta_2_mul_grad_square]) + # 计算sqrt_next_v sqrt_next_v = graph_builder.emit('Sqrt', [next_v]) + # 计算sqrt_next_v和eps的和 sqrt_next_v_add_eps = graph_builder.emit('Add', [sqrt_next_v, eps]) + # 计算update update = graph_builder.emit('RealDiv', [next_m, sqrt_next_v_add_eps]) + # 计算param_with_weight_decay param_with_weight_decay = graph_builder.emit('Mul', [weight_decay, param]) + # 计算update和param_with_weight_decay的和 update = graph_builder.emit('Add', [update, param_with_weight_decay]) + # 计算update_with_lr update_with_lr = 
graph_builder.emit('Mul', [lr, update]) + # 计算next_para next_para = graph_builder.emit('Sub', [param, update_with_lr]) + # 将next_para赋值给param para_result = graph_builder.emit( 'InplaceAssign', [param, next_para, next_para], attrs={'fake_output': True}) + # 将next_m赋值给m para_result = graph_builder.emit('InplaceAssign', [m, next_m, para_result], attrs={'fake_output': True}) + # 将next_v赋值给v para_result = graph_builder.emit('InplaceAssign', [v, next_v, para_result], attrs={'fake_output': True}) return para_result diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/fused_mul_add.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/fused_mul_add.py index 86f3a4d1..3cad5a0a 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/fused_mul_add.py @@ -20,9 +20,23 @@ class FusedMulAdd(Expander): """FusedMulAdd expander""" def _expand(self, graph_builder): + """ + 执行扩展操作。 + + Args: + graph_builder (GraphBuilder): 图构建器对象。 + + Returns: + Tensor: 执行加法操作后的结果。 + + """ + # 获取输入 input_x, input_y, input_z = self.inputs + # 发射乘法操作 mul_res = graph_builder.emit('Mul', [input_x, input_y]) + # 发射加法操作 result = graph_builder.emit('Add', [mul_res, input_z]) + # 返回结果 return result diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/gather.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/gather.py index fb690816..f31bcb76 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/gather.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/gather.py @@ -22,22 +22,47 @@ class Gather(Expander): """Expand Gather""" def _expand(self, graph_builder): + """ + 对输入张量进行扩展操作。 + + Args: + graph_builder (GraphBuilder): 图构建器对象。 + + Returns: + Tensor: 扩展后的张量。 + + """ + # 获取输入和索引 inputs, indices = self.inputs + # 获取轴 axis = self.attrs['axis'] + # 如果轴小于0,则将其转换为正数 if axis < 0: axis += len(inputs.shape) + # 如果索引的维度为1,则直接进行Gather操作 if len(indices.shape) == 1: result = graph_builder.emit('Gather', [inputs, indices], attrs={'axis': axis}) + # 否则,对索引进行Reshape操作,然后进行Gather操作,最后再进行Reshape操作 else: + # 获取原始索引的形状 ori_indices_shape = indices.shape + # 计算索引的形状的乘积 indices_shape_one_dim = 1 for dim in ori_indices_shape: indices_shape_one_dim *= dim + # 构造新的索引形状 new_indices_shape = [indices_shape_one_dim] + # 对索引进行Reshape操作 reshape_indices = graph_builder.emit('Reshape', [indices], attrs={'shape': new_indices_shape}) + # 对输入和Reshape后的索引进行Gather操作 tmp_result = graph_builder.emit('Gather', [inputs, reshape_indices], attrs={'axis': axis}) + # 获取输出的形状 output_shape = list(inputs.shape) + # 将索引的形状插入到输出的形状中 output_shape[axis:axis] = ori_indices_shape + # 删除输出的形状中多余的维度 del output_shape[axis + len(ori_indices_shape)] + # 对Gather操作的结果进行Reshape操作 result = graph_builder.emit('Reshape', [tmp_result], attrs={'shape': output_shape}) + # 返回结果 return result diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/gelu.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/gelu.py index 24fe81bc..24aeab68 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/gelu.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/gelu.py @@ -22,6 +22,16 @@ class GeLU(Expander): CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564 # 
np.sqrt(2/np.pi) def _expand(self, graph_builder): + """ + 计算输入张量的GELU激活函数值。 + + Args: + graph_builder (GraphBuilder): 图构建器对象,用于生成计算图。 + + Returns: + Tensor: 输入张量的GELU激活函数值。 + + """ # cal formula are: # gelu of x is 0.5 * x * (1.0 + tanh(y)) # y is 'sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)' @@ -29,20 +39,33 @@ class GeLU(Expander): input_x = self.inputs[0] # cal y + # 计算 input_x 的平方 mul_0 = graph_builder.emit('Mul', [input_x, input_x]) + # 计算 input_x 的立方 pow_0 = graph_builder.emit('Mul', [mul_0, input_x]) + # 创建一个 CSVALUE 常量 const_csvalue = graph_builder.value(pow_0.dtype, self.CSVALUE) + # 计算 pow_0 和 CSVALUE 的乘积 mul_1 = graph_builder.emit('Mul', [pow_0, const_csvalue]) + # 计算 input_x 和 mul_1 的和 tanh_res = graph_builder.emit('Add', [input_x, mul_1]) + # 创建一个 CSVALUE_SQRT_TWO_DIV_PI 常量 const_csvalue_sqrt_two_div_pi = graph_builder.value(tanh_res.dtype, self.CSVALUE_SQRT_TWO_DIV_PI) + # 计算 tanh_res 和 CSVALUE_SQRT_TWO_DIV_PI 的乘积 y = graph_builder.emit('Mul', [tanh_res, const_csvalue_sqrt_two_div_pi]) # cal gelu(x) + # 计算 y 的 tanh 值 tanh_y = graph_builder.emit('Tanh', [y]) + # 创建一个 1 常量 const_one = graph_builder.value(tanh_y.dtype, 1) + # 创建一个 0.5 常量 const_half = graph_builder.value(tanh_y.dtype, 0.5) + # 计算 tanh_y 和 1 的和 tanh_y_add_one = graph_builder.emit('Add', [tanh_y, const_one]) + # 计算 input_x 和 tanh_y_add_one 的乘积 mul_x = graph_builder.emit('Mul', [input_x, tanh_y_add_one]) + # 计算 const_half 和 mul_x 的乘积 result = graph_builder.emit('Mul', [const_half, mul_x]) return result diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/gelu_grad.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/gelu_grad.py index 09c15055..1ddc9816 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/gelu_grad.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/gelu_grad.py @@ -19,11 +19,29 @@ from ._utils import Expander, ExpanderInfoValidator as VLD @VLD.check_all_formats_same class GeLUGrad(Expander): """GeLUGrad expander""" - CSVALUE = 0.044715 + + # CSVALUE = 0.044715 + CSVALUE = 0.044715 # CSVALUE的值为0.044715 CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564 # np.sqrt(2/np.pi) CSVALUE_TRI = 0.134141 # CSVALUE * 3 def _expand(self, graph_builder): + """ + 计算GELU函数的梯度。 + + Args: + graph_builder (GraphBuilder): 图构建器对象,用于生成计算图。 + + Returns: + Tensor: GELU函数的梯度。 + + 计算公式如下: + GELU的梯度dy和x是dy * y' + y' = 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(para)) * mul_right + tanh_para = sqrt(2.0 / pi) * (x + 0.044715 * x * x * x) + mul_right = sqrt(2.0 / pi) * (1 + 3 * 0.044715 * x * x) + + """ # cal formula are: # gelu_grad of dy and x is dy * y' # y' is 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(para)) * mul_right @@ -33,21 +51,33 @@ class GeLUGrad(Expander): input_dy, input_x, _ = self.inputs # create some const var + # 创建一个常量,值为self.CSVALUE,数据类型为input_dy.dtype const_csvalue = graph_builder.value(input_dy.dtype, self.CSVALUE) + # 创建一个常量,值为self.CSVALUE_SQRT_TWO_DIV_PI,数据类型为input_dy.dtype const_csvalue_sqrt_two_div_pi = graph_builder.value(input_dy.dtype, self.CSVALUE_SQRT_TWO_DIV_PI) + # 创建一个常量,值为self.CSVALUE_TRI,数据类型为input_dy.dtype const_csvalue_tri = graph_builder.value(input_dy.dtype, self.CSVALUE_TRI) + # 创建一个常量,值为1,数据类型为input_dy.dtype const_one = graph_builder.value(input_dy.dtype, 1) + # 创建一个常量,值为0.5,数据类型为input_dy.dtype const_half = graph_builder.value(input_dy.dtype, 0.5) # cal mul_right + # 计算input_x的平方 mul_double = 
graph_builder.emit('Mul', [input_x, input_x]) + # 将const_csvalue_tri与mul_double相乘 mul_double_mul_tri = graph_builder.emit('Mul', [const_csvalue_tri, mul_double]) + # 将const_one与mul_double_mul_tri相加 mul_add_one = graph_builder.emit('Add', [const_one, mul_double_mul_tri]) + # 将const_csvalue_sqrt_two_div_pi与mul_add_one相乘 mul_right = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_one]) # cal tanh_para + # 计算input_x和mul_double的乘积 mul_triple = graph_builder.emit('Mul', [input_x, mul_double]) + # 计算const_csvalue和mul_triple的乘积 mul_triple_mul_csvalue = graph_builder.emit('Mul', [const_csvalue, mul_triple]) + # 计算input_x和mul_triple_mul_csvalue的和 mul_add_x = graph_builder.emit('Add', [input_x, mul_triple_mul_csvalue]) tanh_para = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_x]) diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/gkdropout.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/gkdropout.py index 8101c6c7..74870291 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/gkdropout.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/gkdropout.py @@ -22,12 +22,27 @@ class GkDropout(Expander): """GkDropout expander""" def _expand(self, graph_builder): + """ + 对输入数据进行dropout操作。 + + Args: + graph_builder (GraphBuilder): 图构建器对象。 + + Returns: + tuple: 包含两个元素,第一个是执行dropout操作后的结果,第二个是生成的掩码。 + + """ + # 获取输入数据和掩码 input_x, input_mask = self.inputs + # 获取保持概率 keep_prob = self.attrs['keep_prob'] + # 计算保持概率的倒数 r_keep_prob = graph_builder.value(input_x.dtype, 1.0 / keep_prob) + # 计算保持概率 keep_prob = graph_builder.value(input_x.dtype, keep_prob) + # 如果掩码的数据类型与输入数据类型不同,则进行类型转换 if input_mask.dtype != input_x.dtype: input_mask = graph_builder.emit('Cast', [input_mask], attrs={'dst_type': input_x.dtype}) mask = graph_builder.emit('LessEqual', [input_mask, keep_prob]) # output is bool type diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/identity.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/identity.py index fe500660..d0abf4c3 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/identity.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/identity.py @@ -20,6 +20,16 @@ class Identity(Expander): """Identity expander""" def _expand(self, graph_builder): - input_x = self.inputs[0] - result = graph_builder.emit('Reshape', [input_x], attrs={'shape': input_x.shape}) - return result + """ + 对输入数据进行重塑操作。 + + Args: + graph_builder (GraphBuilder): 图构建器对象,用于构建计算图。 + + Returns: + Tensor: 重塑后的输入数据。 + + """ + input_x = self.inputs[0] # 获取输入数据 + result = graph_builder.emit('Reshape', [input_x], attrs={'shape': input_x.shape}) # 使用图构建器对象构建计算图,对输入数据进行重塑操作 + return result # 返回重塑后的输入数据 diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/layernorm.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/layernorm.py index eaa44140..791445c1 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/layernorm.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/layernorm.py @@ -25,67 +25,107 @@ class LayerNorm(Expander): """LayerNorm expander""" def _expand(self, graph_builder): + """ + 对输入进行扩展处理,包括批量归一化操作。 + + Args: + graph_builder (GraphBuilder): 图构建器对象,用于构建计算图。 + + Returns: + tuple: 
包含三个元素的元组,分别是处理后的输入、均值和方差。 + - res (Tensor): 处理后的输入张量。 + - mean (Tensor): 输入的均值张量。 + - variance (Tensor): 输入的方差张量。 + + """ + # 获取输入数据 input_x, input_gamma, input_beta = self.inputs + # 获取处理器类型 processor = self.processor + # 获取归一化开始轴 begin_norm_axis = self.attrs['begin_norm_axis'] + # 获取epsilon值 epsilon = self.attrs['epsilon'] + # 获取输入数据的原始数据类型 ori_dtype = input_x.dtype + # 如果处理器类型为aicore且输入数据类型为float16,则将输入数据类型转换为float32 if processor == 'aicore' and ori_dtype == 'float16': input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'}) input_gamma = graph_builder.emit('Cast', [input_gamma], attrs={'dst_type': 'float32'}) input_beta = graph_builder.emit('Cast', [input_beta], attrs={'dst_type': 'float32'}) + # 获取输入数据的原始形状 ori_shape_x = input_x.shape + # 如果输入数据的格式为FRAC_NZ,则根据FRAC_NZ格式获取输入数据的形状 if input_x.data_format == DF.FRAC_NZ: ori_shape_x = infer_shape_from_fractalnz(input_x.shape) # Calculate the scaling ratio of the average + # 如果begin_norm_axis小于0,则将其加上ori_shape_x的长度 if begin_norm_axis < 0: begin_norm_axis += len(ori_shape_x) + # 定义reduce_axis,用于存储需要归一化的维度 reduce_axis = () + # 遍历ori_shape_x,如果维度大于begin_norm_axis或者等于begin_norm_axis,则将其加入reduce_axis for i, _ in enumerate(ori_shape_x): if i > begin_norm_axis or i == begin_norm_axis: reduce_axis = reduce_axis + (i,) + # 计算reduce_elts,即需要归一化的维度上的元素个数 reduce_elts = 1.0 for i in reduce_axis: reduce_elts *= ori_shape_x[i] # after reduced + # 获取归一化后的ori_shape_x ori_reduced_shape_x = get_reduced_ori_shape(ori_shape_x, reduce_axis) + # 定义axis,用于存储归一化的维度 axis = reduce_axis + # 如果input_x的数据格式为DF.FRAC_NZ,则将axis转换为frac_z轴 if input_x.data_format == DF.FRAC_NZ: axis = to_frac_z_axis(ori_shape_x, reduce_axis) + # 计算mean_cof_v,即归一化系数 mean_cof_v = graph_builder.value(input_x.dtype, 1.0 / reduce_elts) # Calculate mean + # 计算输入张量的均值 mean_red = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True}) + # 将均值乘以系数 mean = graph_builder.emit('Mul', [mean_red, mean_cof_v]) + # 如果输入张量的数据格式为DF.FRAC_NZ,则对均值进行重整 if input_x.data_format == DF.FRAC_NZ: mean = graph_builder.emit('Reshape', [mean], attrs={'shape': ori_reduced_shape_x}) # Calculate variance - variance_sub = graph_builder.emit('Sub', [input_x, mean]) - variance_mul = graph_builder.emit('Mul', [variance_sub, variance_sub]) - variance_red = graph_builder.emit('ReduceSum', [variance_mul], attrs={'reduce_axis': axis, 'keep_dims': True}) - variance = graph_builder.emit('Mul', [variance_red, mean_cof_v]) + # 计算方差 + variance_sub = graph_builder.emit('Sub', [input_x, mean]) # 计算输入与均值的差值 + variance_mul = graph_builder.emit('Mul', [variance_sub, variance_sub]) # 计算差值的平方 + variance_red = graph_builder.emit('ReduceSum', [variance_mul], attrs={'reduce_axis': axis, 'keep_dims': True}) # 对差值的平方求和 + variance = graph_builder.emit('Mul', [variance_red, mean_cof_v]) # 计算方差 + # 如果输入数据的格式为DF.FRAC_NZ,则对方差进行reshape操作 if input_x.data_format == DF.FRAC_NZ: variance = graph_builder.emit('Reshape', [variance], attrs={'shape': ori_reduced_shape_x}) # Calculate normalize + # 计算输入x与均值之间的差值 normalize_sub = graph_builder.emit('Sub', [input_x, mean]) + # 创建一个epsilon值,用于防止除零错误 epsilon_v = graph_builder.value(input_x.dtype, epsilon) + # 计算方差加上epsilon的值 normalize_add = graph_builder.emit('Add', [variance, epsilon_v]) normlize_rsqrt = graph_builder.emit('Rsqrt', [normalize_add]) normalize_mul = graph_builder.emit('Mul', [normalize_sub, normlize_rsqrt]) # Calculate scale and translate + # 计算归一化后的乘积 scale_mul = graph_builder.emit('Mul', [normalize_mul, input_gamma]) + # 计算最终结果 res = 
graph_builder.emit('Add', [scale_mul, input_beta]) + # 如果处理器为aicore且原始数据类型为float16,则将结果、均值和方差转换为float16 if processor == 'aicore' and ori_dtype == 'float16': res = graph_builder.emit('Cast', [res], attrs={'dst_type': 'float16'}) mean = graph_builder.emit('Cast', [mean], attrs={'dst_type': 'float16'}) diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/layernorm_grad.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/layernorm_grad.py index 2ae7078b..4facbef1 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/layernorm_grad.py @@ -23,13 +23,33 @@ class LayerNormGrad(Expander): """LayerNormGrad expander""" def _expand(self, graph_builder): + """ + 对输入进行扩展操作。 + + Args: + graph_builder (GraphBuilder): 图构建器对象。 + + Returns: + tuple: 包含dx, dg, db的元组。 + dx (Tensor): 梯度相对于输入x的导数。 + dg (Tensor): 梯度相对于gamma的导数。 + db (Tensor): 梯度相对于beta的导数。 + + """ + # 获取输入参数 x, dy, variance, mean, gamma = self.inputs + # 获取处理器类型 processor = self.processor + # 获取归一化轴的起始位置 begin_norm_axis = self.attrs['begin_norm_axis'] + # 获取参数轴的起始位置 begin_params_axis = self.attrs['begin_params_axis'] + # 获取epsilon值,默认为1e-12 epsilon = self.attrs['epsilon'] if 'epsilon' in self.attrs else 1e-12 + # 获取输入数据的原始数据类型 ori_dtype = x.dtype + # 如果处理器类型为aicore且数据类型为float16,则将输入数据转换为float32 if processor == 'aicore' and ori_dtype == 'float16': x = graph_builder.emit('Cast', [x], attrs={'dst_type': 'float32'}) dy = graph_builder.emit('Cast', [dy], attrs={'dst_type': 'float32'}) @@ -37,77 +57,121 @@ class LayerNormGrad(Expander): mean = graph_builder.emit('Cast', [mean], attrs={'dst_type': 'float32'}) gamma = graph_builder.emit('Cast', [gamma], attrs={'dst_type': 'float32'}) + # 如果归一化轴的起始位置小于0,则将其转换为正数 if begin_norm_axis < 0: begin_norm_axis += len(x.shape) + # 如果参数轴的起始位置小于0,则将其转换为正数 if begin_params_axis < 0: begin_params_axis += len(x.shape) + # 获取归一化轴和参数轴的范围 norm_axis = tuple(range(begin_norm_axis, len(x.shape))) param_axis = tuple(range(0, begin_params_axis)) + # 计算归一化轴的维度乘积 reduce_size = 1.0 for i in norm_axis: reduce_size *= x.shape[i] # set some constant val. 
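# ---------------------------------------------------------------------------
# 说明(示意性参考,非本补丁新增的源码):下面用一段最小的 NumPy 代码对照上文
# LayerNorm 展开图中各算子的数学含义(ReduceSum 求均值、(x-mean)^2 求方差、
# rsqrt(var+eps) 归一化、gamma/beta 缩放平移),同样的 rsqrt(var+eps) 项也会在
# 下文 LayerNormGrad 的反向推导中复用。其中 begin_norm_axis、eps 的取值以及
# 函数名 layer_norm_reference 均为示例假设,并非 MindSpore 源码约定。
import numpy as np

def layer_norm_reference(x, gamma, beta, begin_norm_axis=-1, eps=1e-7):
    # 与展开图一致:沿归一化轴求均值与方差,再做缩放平移
    axis = tuple(range(begin_norm_axis % x.ndim, x.ndim))
    mean = x.mean(axis=axis, keepdims=True)                  # ReduceSum * mean_cof_v
    var = ((x - mean) ** 2).mean(axis=axis, keepdims=True)   # Sub -> Mul -> ReduceSum * mean_cof_v
    y = (x - mean) / np.sqrt(var + eps)                      # Sub -> Add(eps) -> Rsqrt -> Mul
    return y * gamma + beta, mean, var                       # Mul(gamma) -> Add(beta)
# ---------------------------------------------------------------------------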
+ # 计算epsilon的值 eps = graph_builder.value(x.dtype, epsilon) + # 计算-0.5的值 const_neg_half = graph_builder.value(x.dtype, -0.5) + # 计算-2.0的值 const_neg_two = graph_builder.value(x.dtype, -2.0) + # 计算2.0的值 const_two = graph_builder.value(x.dtype, 2.0) + # 计算-1.0的值 const_neg_one = graph_builder.value(x.dtype, -1.0) + # 计算mean_cof的值 mean_cof = graph_builder.value(x.dtype, (1.0 / reduce_size)) # cal dg db + # 计算方差和eps的和 var_eps = graph_builder.emit('Add', [variance, eps]) + # 计算方差和eps的和的对数 var_eps_log = graph_builder.emit('Log', [var_eps]) + # 计算方差和eps的对数乘以-0.5 var_eps_mul = graph_builder.emit('Mul', [var_eps_log, const_neg_half]) + # 计算方差和eps的对数乘以-0.5的指数 rsqrt_var_eps = graph_builder.emit('Exp', [var_eps_mul]) + # 计算x和mean的差 + # 计算输入x减去均值 x_sub_mean = graph_builder.emit('Sub', [x, mean]) + # 计算x减去均值乘以rsqrt_var_eps x_sub_mean_mul_rsqrt_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, x_sub_mean]) + # 计算dy乘以x减去均值乘以rsqrt_var_eps dg_mul = graph_builder.emit('Mul', [dy, x_sub_mean_mul_rsqrt_var_eps]) + # 计算dg,对dg_mul进行求和,reduce_axis为param_axis,keep_dims为False dg = graph_builder.emit('ReduceSum', [dg_mul], attrs={'reduce_axis': param_axis, 'keep_dims': False}) + # 计算db,对dy进行求和,reduce_axis为param_axis,keep_dims为False db = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': param_axis, 'keep_dims': False}) # pd_var + # 计算tmp_var_eps tmp_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, rsqrt_var_eps]) + # 计算r_tmp_var_eps r_tmp_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, tmp_var_eps]) + # 计算dy_mul_gamma dy_mul_gamma = graph_builder.emit('Mul', [dy, gamma]) + # 计算tmp_mul tmp_mul = graph_builder.emit('Mul', [dy_mul_gamma, x_sub_mean]) + # 计算padvar_mul1 padvar_mul1 = graph_builder.emit('ReduceSum', [tmp_mul], attrs={'reduce_axis': norm_axis, 'keep_dims': True}) + # 计算padvar_mul3 padvar_mul3 = graph_builder.emit('Mul', [padvar_mul1, r_tmp_var_eps]) + # 计算pd_var pd_var = graph_builder.emit('Mul', [padvar_mul3, const_neg_half]) # pd_mean + # 计算pdmean1_sum,使用ReduceSum函数,输入为dy_mul_gamma,归约轴为norm_axis,保持维度为True pdmean1_sum = graph_builder.emit('ReduceSum', [dy_mul_gamma], attrs={'reduce_axis': norm_axis, 'keep_dims': True}) + # 计算neg_rsqrt_var_eps,使用Mul函数,输入为rsqrt_var_eps和const_neg_one neg_rsqrt_var_eps = graph_builder.emit('Mul', [rsqrt_var_eps, const_neg_one]) + # 计算pd_mean_1,使用Mul函数,输入为neg_rsqrt_var_eps和pdmean1_sum pd_mean_1 = graph_builder.emit('Mul', [neg_rsqrt_var_eps, pdmean1_sum]) + # 计算pdmean2_mul1,使用Mul函数,输入为const_neg_two和x_sub_mean pdmean2_mul1 = graph_builder.emit('Mul', [const_neg_two, x_sub_mean]) + # 计算pdmean2_sum,使用ReduceSum函数,输入为pdmean2_mul1,归约轴为norm_axis,保持维度为True pdmean2_sum = graph_builder.emit('ReduceSum', [pdmean2_mul1], attrs={'reduce_axis': norm_axis, 'keep_dims': True}) + # 计算pdmean2_mul3,使用Mul函数,输入为pdmean2_sum和mean_cof pdmean2_mul3 = graph_builder.emit('Mul', [pdmean2_sum, mean_cof]) + # 计算pd_mean_2,使用Mul函数,输入为pdmean2_mul3和pd_var pd_mean_2 = graph_builder.emit('Mul', [pdmean2_mul3, pd_var]) + # 计算pd_mean,使用Add函数,输入为pd_mean_1和pd_mean_2 pd_mean = graph_builder.emit('Add', [pd_mean_1, pd_mean_2]) # cal dx + # 计算pd_x_1 pd_x_1 = graph_builder.emit('Mul', [dy_mul_gamma, rsqrt_var_eps]) + # 计算pdx2_mul pdx2_mul = graph_builder.emit('Mul', [pd_var, x_sub_mean]) + # 计算pdx2_mul_two pdx2_mul_two = graph_builder.emit('Mul', [pdx2_mul, const_two]) + # 计算pd_x_2 pd_x_2 = graph_builder.emit('Mul', [pdx2_mul_two, mean_cof]) + # 计算pd_x_3 pd_x_3 = graph_builder.emit('Mul', [pd_mean, mean_cof]) + # 计算dx_tmp dx_tmp = graph_builder.emit('Add', [pd_x_1, pd_x_2]) + # 计算dx dx = 
graph_builder.emit('Add', [dx_tmp, pd_x_3]) + # 如果处理器为aicore且原始数据类型为float16,则将dx、dg、db转换为float16 if processor == 'aicore' and ori_dtype == 'float16': dx = graph_builder.emit('Cast', [dx], attrs={'dst_type': 'float16'}) dg = graph_builder.emit('Cast', [dg], attrs={'dst_type': 'float16'}) db = graph_builder.emit('Cast', [db], attrs={'dst_type': 'float16'}) + # 返回dx、dg、db return dx, dg, db diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/logsoftmax.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/logsoftmax.py index 27aa8035..691f4780 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/logsoftmax.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/logsoftmax.py @@ -23,24 +23,49 @@ class LogSoftmax(Expander): """LogSoftmax expander""" def _expand(self, graph_builder): + """ + 对输入数据进行Softmax归一化。 + + Args: + graph_builder (GraphBuilder): 图构建器对象,用于构建计算图。 + + Returns: + Tensor: Softmax归一化后的结果。 + + """ + # 获取输入数据 input_x = self.inputs[0] + # 获取轴参数 axis = self.attrs['axis'] + # 获取处理器类型 processor = self.processor + # 如果轴参数是整数,则将其转换为元组 if isinstance(axis, int): axis = (axis,) + # 获取输入数据的原始数据类型 ori_dtype = input_x.dtype + # 如果原始数据类型不是float16且处理器类型是aicore,则将输入数据转换为float16 if ori_dtype != "float16" and processor == "aicore": input_x_f16 = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float16'}) + # 对转换后的数据进行ReduceMax操作 max_x_f16 = graph_builder.emit('ReduceMax', [input_x_f16], attrs={'reduce_axis': axis, 'keep_dims': True}) + # 将ReduceMax操作的结果转换回原始数据类型 max_x = graph_builder.emit('Cast', [max_x_f16], attrs={'dst_type': ori_dtype}) else: + # 对输入数据进行ReduceMax操作 max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': axis, 'keep_dims': True}) + # 计算输入数据与ReduceMax操作结果的差值 data_sub = graph_builder.emit('Sub', [input_x, max_x]) + # 计算差值的指数 data_exp = graph_builder.emit('Exp', [data_sub]) + # 对指数结果进行ReduceSum操作 data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': axis, 'keep_dims': True}) + # 计算ReduceSum结果的log log_expsum = graph_builder.emit('Log', [data_expsum]) + # 计算差值与log的差值 result = graph_builder.emit('Sub', [data_sub, log_expsum]) + # 返回结果 return result diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py index 9a075d7c..7af59ed0 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py @@ -23,14 +23,32 @@ class LogSoftmaxGrad(Expander): """LogSoftmaxGrad expander""" def _expand(self, graph_builder): + """ + 对输入进行扩展操作。 + + Args: + graph_builder (GraphBuilder): 图构建器对象。 + + Returns: + Tensor: 扩展操作的结果。 + + """ + # 获取输入的logits和dy input_logits, input_dy = self.inputs + # 获取axis参数 axis = self.attrs['axis'] + # 如果axis是整数,则将其转换为元组 if isinstance(axis, int): axis = (axis,) + # 计算softmax softmax = graph_builder.emit('Exp', [input_logits]) + # 计算dy的sum dy_sum = graph_builder.emit('ReduceSum', [input_dy], attrs={'reduce_axis': axis, 'keep_dims': True}) + # 计算softmax和dy_sum的乘积 mul_result = graph_builder.emit('Mul', [softmax, dy_sum]) + # 计算input_dy和mul_result的差 result = graph_builder.emit('Sub', [input_dy, mul_result]) + # 返回结果 return result diff --git 
a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/matmul.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/matmul.py index 8af69e0b..1d888f58 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/matmul.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/matmul.py @@ -25,48 +25,139 @@ class MatMul(Expander): """ def __init__(self, expand_info): + """ + 初始化MatMul类实例。 + + Args: + expand_info (dict): 扩展信息字典,包含操作所需的额外信息。 + + Attributes: + transpose_a (bool): 是否对矩阵A进行转置。 + transpose_b (bool): 是否对矩阵B进行转置。 + left_format (str): 矩阵A的数据格式。 + right_format (str): 矩阵B的数据格式。 + shape_a (tuple): 矩阵A的形状。 + shape_b (tuple): 矩阵B的形状。 + + """ + # 调用父类的初始化方法 super(MatMul, self).__init__(expand_info) + # 获取transpose_a属性 self.transpose_a = self.attrs['transpose_a'] + # 获取transpose_b属性 self.transpose_b = self.attrs['transpose_b'] + # 获取left_format属性 self.left_format = self.attrs['left_format'] + # 获取right_format属性 self.right_format = self.attrs['right_format'] + # 获取输入A的shape self.shape_a = self.inputs[0]['shape'] + # 获取输入B的shape self.shape_b = self.inputs[1]['shape'] def _optimize_to_mul(self): + """ + 检查是否可以用乘法(mul)替换矩阵乘法(matmul) + + Args: + 无 + + Returns: + bool: 如果可以用乘法替换矩阵乘法,返回True;否则返回False。 + + """ """check if matmul can be replace by mul""" + # 如果处理器不是'aicore',或者左格式或右格式不是默认格式,则返回False if self.processor != 'aicore' or self.left_format != DF.DEFAULT or self.right_format != DF.DEFAULT: return False + # 如果transpose_a为True,则k_a为shape_a的倒数第二个维度,否则为shape_a的倒数第一个维度 k_a = self.shape_a[-2] if self.transpose_a else self.shape_a[-1] + # 如果transpose_b为True,则k_b为shape_b的倒数第一个维度,否则为shape_b的倒数第二个维度 k_b = self.shape_b[-1] if self.transpose_b else self.shape_b[-2] + # 如果k_a或k_b不等于1,则返回False if k_a != 1 or k_b != 1: return False + # 否则返回True return True def _check(self): + """ + 检查输入个数是否满足矩阵乘法的要求。 + + Args: + 无 + + Returns: + 无 + + Raises: + GKException: 如果输入的个数小于2,则抛出GKException异常,提示信息为 "For 'MatMul', inputs number should bigger than 1, but got {}.",其中{}为输入的个数。 + + """ + # 获取输入的个数 input_num = len(self.inputs) + # 如果输入的个数小于2,抛出异常 if input_num < 2: raise GKException("For 'MatMul', inputs number should bigger than 1, but got {}.".format(input_num)) def _expand(self, graph_builder): + """ + 将MatMul或BatchMatMul操作替换为Mul操作。 + + Args: + graph_builder (GraphBuilder): 图构建器对象。 + + Returns: + Node: Mul操作的结果节点。 + + Raises: + GKException: 如果不需要将MatMul/BatchMatMul替换为Mul操作,则引发异常。 + + """ + # 定义一个函数,用于转置shape def transpose(shape): + """ + 将给定的shape进行转置操作。 + + Args: + shape (tuple): 输入的shape,为一个元组,表示多维数组的形状。 + + Returns: + list: 转置后的shape,以列表形式返回。 + + """ + # 将shape转换为列表 trans_shape = list(shape) + # 将shape的倒数第二个元素和倒数第一个元素交换位置 trans_shape[-2] = shape[-1] trans_shape[-1] = shape[-2] + # 返回转置后的shape return trans_shape + # 如果不需要优化为乘法,则抛出异常 if not self._optimize_to_mul(): raise GKException("MatMul/BatchMatMul do not need to be replaced by Mul") # Matmul is replaced by Mul([b m k], [b k n]) when k==1 + # 获取输入a input_a = self.inputs[0] + # 获取输入b input_b = self.inputs[1] + # 如果transpose_a为True,则对输入a进行转置 if self.transpose_a: + # 获取输入a的转置形状 shape_a_trans = transpose(self.shape_a) + # 对输入a进行转置 input_a = graph_builder.emit('Reshape', [input_a], attrs={'shape': shape_a_trans}) + # 如果transpose_b为True,则对输入b进行转置 if self.transpose_b: + # 获取输入b的转置形状 shape_b_trans = transpose(self.shape_b) + # 对输入b进行转置 input_b = graph_builder.emit('Reshape', [input_b], attrs={'shape': shape_b_trans}) + # 对输入a和输入b进行乘法运算 result = 
graph_builder.emit('Mul', [input_a, input_b]) + # 如果dst_type在attrs中,并且输入a的数据类型与dst_type不同,则对结果进行类型转换 if 'dst_type' in self.attrs and self.inputs[0].dtype != self.attrs['dst_type']: + # 对结果进行类型转换 result = graph_builder.emit('Cast', [result], attrs={'dst_type': self.attrs['dst_type']}) return result diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/maximum_grad.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/maximum_grad.py index d29989cb..0e62e7b7 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/maximum_grad.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/maximum_grad.py @@ -23,35 +23,76 @@ class MaximumGrad(Expander): """MaximumGrad expander""" def _check(self): + """ + 检查MaximumGrad的属性是否符合要求。 + + Args: + 无 + + Returns: + 返回父类的检查结果。 + + Raises: + GKException: 当 'grad_x' 和 'grad_y' 的值都为 False 时抛出异常。 + + """ + # 如果attr 'grad_x'和'grad_y'的值都为False,则抛出异常 if not self.attrs.get('grad_x', True) and not self.attrs.get('grad_y', True): raise GKException("For 'MaximumGrad', value of attr 'grad_x' and 'grad_y' should be False, but got {} and " "{}".format(self.attrs.get('grad_x'), self.attrs.get('grad_y'))) + # 调用父类的方法 return super()._check() def _expand(self, graph_builder): + """ + 根据输入计算梯度,并返回两个梯度结果。 + + Args: + graph_builder (GraphBuilder): 图构建器对象,用于构建计算图。 + + Returns: + tuple: 包含两个梯度结果的元组,第一个元素为对输入x的梯度,第二个元素为对输入y的梯度。 + + """ + # 获取输入的x、y和dout input_x, input_y, input_dout = self.inputs + # 比较x和y的大小,返回一个布尔值 ge_result = graph_builder.emit('GreaterEqual', [input_x, input_y]) + # 将布尔值转换为与x相同的类型 ge_result = graph_builder.emit('Cast', [ge_result], attrs={'dst_type': input_x.dtype}) + # 计算dx,即x的梯度 dx = graph_builder.emit('Mul', [ge_result, input_dout]) + # 计算dy,即y的梯度 dy = graph_builder.emit('Sub', [input_dout, dx]) + # 获取dx和dy的reduce轴 reduce_axis_x = MinimumGrad.get_reduce_axis(input_x.shape, dx.shape) reduce_axis_y = MinimumGrad.get_reduce_axis(input_y.shape, dy.shape) + # 如果dx有reduce轴 if reduce_axis_x: + # 对dx进行求和 dx_reduce = graph_builder.emit('ReduceSum', [dx], attrs={'reduce_axis': reduce_axis_x, 'keep_dims': False}) + # 如果dx_reduce的形状与input_x的形状不同,则进行reshape if dx_reduce.shape != input_x.shape: dx_out = graph_builder.emit('Reshape', [dx_reduce], attrs={'shape': input_x.shape}) + # 否则,dx_out等于dx_reduce else: dx_out = dx_reduce + # 否则,dx_out等于dx else: dx_out = dx + # 如果dy有reduce轴 if reduce_axis_y: + # 对dy进行求和 dy_reduce = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': reduce_axis_y, 'keep_dims': False}) + # 如果dy_reduce的形状与input_y的形状不同,则进行reshape if dy_reduce.shape != input_y.shape: dy_out = graph_builder.emit('Reshape', [dy_reduce], attrs={'shape': input_y.shape}) + # 否则,dy_out等于dy_reduce else: dy_out = dy_reduce + # 否则,dy_out等于dy else: dy_out = dy diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/minimum_grad.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/minimum_grad.py index 8772fb0e..0a98d1df 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/minimum_grad.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/minimum_grad.py @@ -22,59 +22,117 @@ class MinimumGrad(Expander): """MinimumGrad expander""" def _check(self): + """ + 检查MinimumGrad类的属性是否满足要求。 + + Args: + 无 + + Returns: + bool: 如果属性符合要求,则返回True,否则抛出GKException异常。 + + Raises: + GKException: 
如果MinimumGrad类的属性'grad_x'和'grad_y'均为False,则抛出此异常。 + + """ + # 如果attr 'grad_x'和'grad_y'的值都为False,则抛出异常 if not self.attrs.get('grad_x', True) and not self.attrs.get('grad_y', True): raise GKException("For 'MinimumGrad', value of attr 'grad_x' and 'grad_y' should be False, but got {} and " "{}".format(self.attrs.get('grad_x'), self.attrs.get('grad_y'))) + # 调用父类的方法 return super(MinimumGrad, self)._check() def _expand(self, graph_builder): + """ + 计算两个输入的梯度。 + + Args: + graph_builder (GraphBuilder): 图构建器对象,用于在图中执行操作。 + + Returns: + tuple: 包含两个梯度结果的元组。 + + """ + # 输入参数 input_x, input_y, input_dout = self.inputs - le_result = graph_builder.emit('LessEqual', [input_x, input_y]) - le_result = graph_builder.emit('Cast', [le_result], attrs={'dst_type': input_x.dtype}) - dx = graph_builder.emit('Mul', [le_result, input_dout]) - dy = graph_builder.emit('Sub', [input_dout, dx]) + # 执行 LessEqual 操作 + le_result = graph_builder.emit('LessEqual', [input_x, input_y]) # 执行 LessEqual 操作 + # 将结果转换为与 input_x 相同的数据类型 + le_result = graph_builder.emit('Cast', [le_result], attrs={'dst_type': input_x.dtype}) # 将结果转换为与 input_x 相同的数据类型 + # 执行 Mul 操作,将 le_result 和 input_dout 相乘 + dx = graph_builder.emit('Mul', [le_result, input_dout]) # 执行 Mul 操作,将 le_result 和 input_dout 相乘 + # 执行 Sub 操作,用 input_dout 减去 dx + dy = graph_builder.emit('Sub', [input_dout, dx]) # 执行 Sub 操作,用 input_dout 减去 dx + # 对于 minimumgrad 操作,输出形状应与输入形状相同, + # 但某些元素级操作可能会广播输入形状, + # 导致输出形状不等于原始输入形状,因此需要减少输出来使它们相等 # for minimumgrad op, output_shape should be equal to input_shape, # but some elementwise operating may broadcast input_shape # then output_shape not equal to original input_shape, so need to reduce output to let them equal - reduce_axis_x = self.get_reduce_axis(input_x.shape, dx.shape) - reduce_axis_y = self.get_reduce_axis(input_y.shape, dy.shape) + reduce_axis_x = self.get_reduce_axis(input_x.shape, dx.shape) # 获取 x 的减少轴 + reduce_axis_y = self.get_reduce_axis(input_y.shape, dy.shape) # 获取 y 的减少轴 if reduce_axis_x: - dx_reduce = graph_builder.emit('ReduceSum', [dx], attrs={'reduce_axis': reduce_axis_x, 'keep_dims': False}) + # 如果存在减少轴,执行 ReduceSum 操作 + dx_reduce = graph_builder.emit('ReduceSum', [dx], attrs={'reduce_axis': reduce_axis_x, 'keep_dims': False}) # 执行 ReduceSum 操作 if dx_reduce.shape != input_x.shape: - dx_out = graph_builder.emit('Reshape', [dx_reduce], attrs={'shape': input_x.shape}) + # 如果减少后的形状不等于原始输入形状,则进行 Reshape 操作 + dx_out = graph_builder.emit('Reshape', [dx_reduce], attrs={'shape': input_x.shape}) # 如果减少后的形状不等于原始输入形状,则进行 Reshape 操作 else: - dx_out = dx_reduce + dx_out = dx_reduce # 否则直接使用减少后的结果 else: - dx_out = dx + dx_out = dx # 如果没有减少轴,则直接使用 dx if reduce_axis_y: - dy_reduce = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': reduce_axis_y, 'keep_dims': False}) + # 如果存在减少轴,执行 ReduceSum 操作 + dy_reduce = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': reduce_axis_y, 'keep_dims': False}) # 执行 ReduceSum 操作 if dy_reduce.shape != input_y.shape: - dy_out = graph_builder.emit('Reshape', [dy_reduce], attrs={'shape': input_y.shape}) + # 如果减少后的形状不等于原始输入形状,则进行 Reshape 操作 + dy_out = graph_builder.emit('Reshape', [dy_reduce], attrs={'shape': input_y.shape}) # 如果减少后的形状不等于原始输入形状,则进行 Reshape 操作 else: - dy_out = dy_reduce + dy_out = dy_reduce # 否则直接使用减少后的结果 else: - dy_out = dy + dy_out = dy # 如果没有减少轴,则直接使用 dy - # output two results, regardless of grad_x and grad_y + # 输出两个结果, return dx_out, dy_out @staticmethod def get_reduce_axis(original_shape, broadcast_shape): + """ + 计算最终输出形状的归约轴。 + + Args: + original_shape (tuple 
of int): 原始形状,一个包含整数的元组。 + broadcast_shape (tuple of int): 广播形状,一个包含整数的元组。 + + Returns: + list of int: 归约轴列表,表示在最终输出形状中需要归约的轴索引。 + + Raises: + ValueError: 如果original_shape的长度大于broadcast_shape的长度,或者original_shape和broadcast_shape无法广播。 + + """ """compute reduce axis for final output_shape""" + # 如果original_shape的长度大于broadcast_shape的长度 if len(original_shape) > len(broadcast_shape): raise ValueError("For 'MinimumGrad', the length of original_shape should be less than or equal to the " "length of broadcast_shape, but got {} and {}".format(original_shape, broadcast_shape)) + # 创建一个tmp_shape列表,长度为broadcast_shape的长度,前面填充1,后面填充original_shape的元素 tmp_shape = [1] * (len(broadcast_shape) - len(original_shape)) + original_shape reduce_axis = [] + # 遍历tmp_shape中的每个元素 for idx, _ in enumerate(tmp_shape): + # 如果tmp_shape中的元素与broadcast_shape中的对应元素不相等 if tmp_shape[idx] != broadcast_shape[idx]: + # 如果tmp_shape中的元素为1 if tmp_shape[idx] == 1: + # 将当前索引添加到reduce_axis列表中 reduce_axis.append(idx) else: + # 抛出异常,表示original_shape和broadcast_shape无法广播 raise ValueError("For 'MinimumGrad', original_shape {} and broadcast_shape {} can not broadcast." .format(original_shape, broadcast_shape)) return reduce_axis diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/oneslike.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/oneslike.py index e4a4a4d0..7b91cb2a 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/oneslike.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/oneslike.py @@ -20,7 +20,24 @@ class OnesLike(Expander): """OnesLike expander""" def _expand(self, graph_builder): + """ + 将输入张量扩展至指定形状。 + + Args: + graph_builder: 图构建器对象,用于构建图结构。 + + Returns: + 扩展后的张量。 + + """ + # 获取输入张量 input_x = self.inputs[0] + + # 创建一个值为1的常量,数据类型与输入张量相同 const_one = graph_builder.value(input_x.dtype, 1) + + # 使用BroadcastTo操作将常量扩展至输入张量的形状 result = graph_builder.emit('BroadcastTo', [const_one], attrs={'shape': input_x.shape}) + + # 返回扩展后的张量 return result diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/reduce_mean.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/reduce_mean.py index abb2df57..2258135a 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/reduce_mean.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/reduce_mean.py @@ -22,22 +22,41 @@ from ._utils import Expander, ExpanderInfoValidator as VLD class ReduceMean(Expander): """ReduceMean expander""" - def _expand(self, graph_builder): +def _expand(self, graph_builder): + """ + 对输入张量进行扩展操作。 + + Args: + graph_builder (GraphBuilder): 图构建器对象。 + + Returns: + Tensor: 扩展操作后的张量。 + + """ + # 获取输入张量 x = self.inputs[0] + # 获取扩展操作的轴 axis = self.attrs['axis'] + # 获取是否保持维度 keep_dims = self.attrs['keep_dims'] + # 如果轴不是元组或列表,则将其转换为元组 if not isinstance(axis, (tuple, list)): axis = (axis,) + # 如果轴为空,则将其设置为张量的所有维度 elif not axis: axis = list(range(len(x.shape))) + # 计算缩减的大小 reduce_size = 1.0 for idx in axis: reduce_size *= x.shape[idx] + # 创建一个与输入张量相同数据类型的值,值为缩减的大小 reduce_size_value = graph_builder.value(x.dtype, reduce_size) + # 沿指定轴对输入张量进行求和操作 sum_x = graph_builder.emit('ReduceSum', [x], attrs={'reduce_axis': axis, 'keep_dims': keep_dims}) + # 将求和结果除以缩减的大小,得到扩展后的张量 result = graph_builder.emit('RealDiv', [sum_x, reduce_size_value]) return result diff --git 
a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/relu_grad.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/relu_grad.py index d2e7f740..1fe53f22 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/relu_grad.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/relu_grad.py @@ -21,12 +21,28 @@ class ReluGrad(Expander): """ReLU expander""" def _expand(self, graph_builder): + """ + 在指定的图构建器中扩展当前节点。 + + Args: + graph_builder (GraphBuilder): 图构建器实例,用于在图中生成新的节点。 + + Returns: + Tensor: 返回计算后的结果张量。 + + """ + # 获取输入张量 input_x = self.inputs[0] input_y = self.inputs[1] + # 生成一个与input_y相同数据类型的0值张量 + # 生成一个与input_y相同数据类型的0值张量 const_zero = graph_builder.value(input_y.dtype, 0) + # 判断input_y是否大于0,生成布尔张量 ge_result = graph_builder.emit('Greater', [input_y, const_zero]) + # 将布尔张量转换为与input_x相同数据类型的张量 ge_result = graph_builder.emit('Cast', [ge_result], attrs={'dst_type': input_x.dtype}) + # 将转换后的张量与input_x相乘 result = graph_builder.emit('Mul', [ge_result, input_x]) return result diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py index c2c9d20e..c8ed536d 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py @@ -21,21 +21,46 @@ class SigmoidCrossEntropyWithLogits(Expander): """SigmoidCrossEntropyWithLogits expander""" def _expand(self, graph_builder): + """ + 计算sigmoid交叉熵损失。 + + Args: + graph_builder: 图构建器对象,用于构建计算图。 + + Returns: + 计算得到的sigmoid交叉熵损失值。 + + """ logits, labels = self.inputs + # 计算 logits 和 labels 的 sigmoid_cross_entropy_with_logits # Calculate sigmoid_cross_entropy_with_logits(logits, labels) - # formula of sigmoid_cross_entropy_with_logits is: - # -(labels * log(sigmoid(logits)) + (1 - labels) * log(1 - sigmoid(logits))) - # To ensure stability and avoid overflow, the formula equal to : - # max(logits, 0) - logits * labels + log(1 + exp(-abs(logits))) + # sigmoid_cross_entropy_with_logits 的公式为: + # -(labels * log(sigmoid(logits)) + (1 - labels) * log(1 - sigmoid(logits))) + # 为了确保稳定性并避免溢出,该公式等价于: + # max(logits, 0) - logits * labels + log(1 + exp(-abs(logits))) + + # 创建一个值为 1.0 的常量 const_one = graph_builder.value(logits.dtype, 1.0) + # 创建一个值为 0.0 的常量 const_zero = graph_builder.value(logits.dtype, 0.0) + + # 计算 logits 和 0 的最大值 max_logits = graph_builder.emit('Maximum', [logits, const_zero]) + # 计算 logits 和 labels 的乘积 logits_mul_labels = graph_builder.emit('Mul', [logits, labels]) + # 计算 logits 的绝对值 abs_logits = graph_builder.emit('Abs', [logits]) + # 计算 logits 的负值 neg_abs_logits = graph_builder.emit('Neg', [abs_logits]) + # 计算 exp(-abs(logits)) exp_neg_abs_logits = graph_builder.emit('Exp', [neg_abs_logits]) + # 计算 1 + exp(-abs(logits)) one_add_exp_neg_abs_logits = graph_builder.emit('Add', [const_one, exp_neg_abs_logits]) + # 计算 log(1 + exp(-abs(logits))) log_one_add_exp_neg_abs_logits = graph_builder.emit('Log', [one_add_exp_neg_abs_logits]) + # 计算 max(logits, 0) - logits * labels res_tmp = graph_builder.emit('Sub', [max_logits, logits_mul_labels]) + # 计算最终结果 res = graph_builder.emit('Add', [res_tmp, log_one_add_exp_neg_abs_logits]) + return res diff --git 
a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py index 07f9dba3..cf49d4cc 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py @@ -21,15 +21,29 @@ class SigmoidCrossEntropyWithLogitsGrad(Expander): """SigmoidCrossEntropyWithLogitsGrad expander""" def _expand(self, graph_builder): + """ + 计算sigmoid交叉熵损失的梯度。 + + Args: + graph_builder: 图构建器对象,用于构建计算图。 + + Returns: + 计算得到的梯度值。 + + """ logits, label, dout = self.inputs - # Calculate sigmoid_cross_entropy_with_logits_grad(logits, label, dout) - # formula of sigmoid_cross_entropy_with_logits_grad is : + # 计算sigmoid_cross_entropy_with_logits_grad(logits, label, dout) + # sigmoid_cross_entropy_with_logits_grad的公式为: # (sigmoid(logits) - label) * dout + # 计算sigmoid(logits) + # Calculate sigmoid(logits) const_one = graph_builder.value(logits.dtype, 1.0) - neg_x = graph_builder.emit('Neg', [logits]) - exp_neg_x = graph_builder.emit('Exp', [neg_x]) - add_exp = graph_builder.emit('Add', [const_one, exp_neg_x]) - sigmoid_res = graph_builder.emit('RealDiv', [const_one, add_exp]) + neg_x = graph_builder.emit('Neg', [logits]) # 计算-logits + exp_neg_x = graph_builder.emit('Exp', [neg_x]) # 计算e^(-logits) + add_exp = graph_builder.emit('Add', [const_one, exp_neg_x]) # 计算1 + e^(-logits) + sigmoid_res = graph_builder.emit('RealDiv', [const_one, add_exp]) # 计算1 / (1 + e^(-logits)),即sigmoid(logits) + # 计算(sigmoid(logits) - label) sigmoid_res_sub_label = graph_builder.emit('Sub', [sigmoid_res, label]) - res = graph_builder.emit('Mul', [sigmoid_res_sub_label, dout]) + # 计算最终结果 + res = graph_builder.emit('Mul', [sigmoid_res_sub_label, dout]) # 计算(sigmoid(logits) - label) * dout return res diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py index 56d9413a..72564754 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py @@ -16,16 +16,31 @@ from ._utils import Expander, ExpanderInfoValidator as VLD -@VLD.check_all_formats_same +@VLD.check_all_formats_same # 定义一个SigmoidGrad类,继承自Expander类 class SigmoidGrad(Expander): """SigmoidGrad expander""" def _expand(self, graph_builder): + """ + 计算 sigmoid 函数的梯度。 + + Args: + graph_builder (GraphBuilder): 图构建器对象,用于在图中添加操作。 + + Returns: + Tensor: 计算得到的 sigmoid 梯度。 + + """ input_y, dy = self.inputs + # 计算 sigmoid_grad(y, dy) + # sigmoid_grad 的公式是: (1 - y) * y * dy # Calculate sigmoid_grad(y, dy) # formula of sigmoid_grad is : (1 - y) * y * dy const_one = graph_builder.value(input_y.dtype, 1.0) + # 1 - y one_mins_y = graph_builder.emit('Sub', [const_one, input_y]) + # y * dy y_mul_dy = graph_builder.emit('Mul', [input_y, dy]) + # (1 - y) * (y * dy) res = graph_builder.emit('Mul', [one_mins_y, y_mul_dy]) return res diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/slice.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/slice.py index 5bc265ea..fb847dd2 100644 --- 
a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/slice.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/slice.py @@ -21,15 +21,41 @@ class Slice(Expander): """Slice expander""" def _expand(self, graph_builder): + """ + 在图中扩展输入张量。 + + Args: + graph_builder (GraphBuilder): 图构建器对象。 + + Returns: + Tensor: 扩展后的输出张量。 + + """ + # 获取输入张量 input_x = self.inputs[0] + + # 获取开始索引 begin = self.attrs['begin'] + # 获取切片大小 size = self.attrs['size'] + + # 初始化结束索引列表 end = [] + # 初始化步长列表 strides = [] + + # 遍历每个维度,计算结束索引和步长 for i, begin_idx in enumerate(begin): + # 步长设置为1 strides.append(1) + # 计算结束索引 end.append(begin_idx + size[i]) + + # 创建一个新的张量作为输出 output = graph_builder.tensor(size, input_x.dtype, input_x.data_format) + + # 执行StridedSlice操作,对输入张量进行切片 graph_builder.op('StridedSlice', output, [input_x], attrs={'begin': begin, 'end': end, 'strides': strides}) + # 返回输出张量 return output diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/softmax.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/softmax.py index 991d5e6a..9bad459e 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/softmax.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/softmax.py @@ -25,45 +25,75 @@ class Softmax(Expander): """Softmax expander""" def _expand(self, graph_builder): + """ + 计算Softmax函数值。 + + Args: + graph_builder: 图构建器对象。 + + Returns: + Softmax函数的计算结果。 + + """ + # 获取输入数据 input_x = self.inputs[0] + # 获取处理器 processor = self.processor + # 获取轴信息 axis = self.attrs['axis'] + # 获取输入数据的原始形状 ori_shape = input_x.shape + # 如果输入数据格式为FRAC_NZ,则推断其形状 if input_x.data_format == DF.FRAC_NZ: ori_shape = infer_shape_from_fractalnz(input_x.shape) + # 遍历轴信息,处理负数轴索引 for i, _ in enumerate(list(axis)): if axis[i] < 0: axis[i] += len(ori_shape) + # 获取减少维度后的原始形状 ori_reduced_shape = get_reduced_ori_shape(ori_shape, axis) + # 获取减少的轴 reduce_axis = axis + # 如果输入数据格式为FRAC_NZ,则转换轴 if input_x.data_format == DF.FRAC_NZ: reduce_axis = to_frac_z_axis(ori_shape, axis) + # 获取输入数据的原始数据类型 ori_dtype = input_x.dtype + # 如果原始数据类型不是float16且处理器为aicore,则进行类型转换 if ori_dtype != "float16" and processor == "aicore": input_x_f16 = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float16'}) max_x_f16 = graph_builder.emit('ReduceMax', [input_x_f16], attrs={'reduce_axis': reduce_axis, 'keep_dims': True}) max_x = graph_builder.emit('Cast', [max_x_f16], attrs={'dst_type': ori_dtype}) else: + # 计算最大值 max_x = graph_builder.emit('ReduceMax', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': True}) + # 如果原始数据类型为float16且处理器为aicore,则进行类型转换 if ori_dtype == "float16" and processor == "aicore": max_x = graph_builder.emit('Cast', [max_x], attrs={'dst_type': "float32"}) input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': "float32"}) + # 如果输入数据格式为FRAC_NZ,则重新调整最大值的形状 if input_x.data_format == DF.FRAC_NZ: max_x = graph_builder.emit('Reshape', [max_x], attrs={'shape': ori_reduced_shape}) + # 计算输入数据减去最大值的差值 data_sub = graph_builder.emit('Sub', [input_x, max_x]) + # 计算差值的指数 data_exp = graph_builder.emit('Exp', [data_sub]) + # 计算指数的和 data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': reduce_axis, 'keep_dims': True}) + # 如果输入数据格式为FRAC_NZ,则重新调整指数和的形状 if input_x.data_format == DF.FRAC_NZ: data_expsum = graph_builder.emit('Reshape', [data_expsum], attrs={'shape': ori_reduced_shape}) + # 计算Softmax值 result = graph_builder.emit('RealDiv', 
[data_exp, data_expsum]) + # 如果原始数据类型为float16且处理器为aicore,则进行类型转换 if ori_dtype == "float16" and processor == "aicore": result = graph_builder.emit('Cast', [result], attrs={'dst_type': ori_dtype}) diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py index e28e74f4..12cae367 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py @@ -22,21 +22,45 @@ class SoftmaxCrossEntropyWithLogits(Expander): """SoftmaxCrossEntropyWithLogits expander""" def _expand(self, graph_builder): + """ + 计算损失值和 logits 的梯度。 + + Args: + graph_builder (GraphBuilder): 图构建器对象。 + + Returns: + Tuple[Tensor, Tensor]: 损失值和 logits 的梯度。 + + """ logits, label = self.inputs + # 计算 softmax_cross_entropy_with_logits(logits, label) + # softmax_cross_entropy_with_logits 的公式是: -reduce_sum(label * log(softmax(logits))) # Calculate softmax_cross_entropy_with_logits(logits, label) # formula of softmax_cross_entropy_with_logits is : -reduce_sum(label * log(softmax(logits))) axis = (-1,) max_x = graph_builder.emit('ReduceMax', [logits], attrs={'reduce_axis': axis, 'keep_dims': True}) + # 计算 logits 的最大值 data_sub = graph_builder.emit('Sub', [logits, max_x]) + # logits 减去最大值 data_exp = graph_builder.emit('Exp', [data_sub]) + # 对上一步结果进行指数运算 data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': axis, 'keep_dims': True}) + # 对指数运算结果求和 data_softmax = graph_builder.emit('RealDiv', [data_exp, data_expsum]) + # 计算 softmax const_eps = graph_builder.value(logits.dtype, 0.000001) + # 定义一个极小的常数,用于防止除以零的错误 data_softmax_safety = graph_builder.emit("Maximum", [data_softmax, const_eps]) + # 确保 softmax 的值不为零 softmax_log = graph_builder.emit('Log', [data_softmax_safety]) + # 对 softmax 结果取对数 label_mul_log = graph_builder.emit('Mul', [label, softmax_log]) + # 将 label 与 softmax 的对数相乘 tmp_res = data_expsum = graph_builder.emit('ReduceSum', [label_mul_log], attrs={ 'reduce_axis': axis, 'keep_dims': False}) + # 对上一步结果进行求和 loss = graph_builder.emit('Neg', [tmp_res]) + # 计算损失值,即上一步结果的负值 dlogits = graph_builder.emit('Sub', [data_softmax, label]) + # 计算 logits 的梯度 return loss, dlogits diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py index d4ccff60..d3c23a4e 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py @@ -13,29 +13,48 @@ # limitations under the License. 
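# ---------------------------------------------------------------------------
# 说明(示意性参考,非本补丁新增的源码):下面用一段最小的 NumPy 代码对照上文
# Softmax / SoftmaxCrossEntropyWithLogits 展开图所用的数值稳定写法
# (先减去 ReduceMax 再做 Exp,loss 取 -sum(label*log(softmax)),梯度为 softmax - label)。
# 其中 axis、eps 的取值以及函数名 softmax_xent_reference 均为示例假设。
import numpy as np

def softmax_xent_reference(logits, label, axis=-1, eps=1e-6):
    # 减去最大值避免 Exp 溢出,与展开图中 ReduceMax -> Sub 对应
    shifted = logits - logits.max(axis=axis, keepdims=True)            # ReduceMax -> Sub
    exp = np.exp(shifted)                                              # Exp
    softmax = exp / exp.sum(axis=axis, keepdims=True)                  # ReduceSum -> RealDiv
    loss = -(label * np.log(np.maximum(softmax, eps))).sum(axis=axis)  # Maximum -> Log -> Mul -> ReduceSum -> Neg
    dlogits = softmax - label                                          # logits 的梯度
    return loss, dlogits
# ---------------------------------------------------------------------------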
# =========================================================================== """generate json desc for SoftmaxGradExt""" -from mindspore._extends.graph_kernel.model.model import DataFormat as DF -from ._utils import Expander, ExpanderInfoValidator as VLD -from ._utils import get_reduce_axis_shape +from mindspore._extends.graph_kernel.model.model import DataFormat as DF # 导入DataFormat类 +from ._utils import Expander, ExpanderInfoValidator as VLD # 导入Expander和ExpanderInfoValidator类 +from ._utils import get_reduce_axis_shape # 导入get_reduce_axis_shape函数 -@VLD.add_format(DF.FRAC_NZ, DF.FRAC_NZ, DF.DEFAULT) +@VLD.add_format(DF.FRAC_NZ, DF.FRAC_NZ, DF.DEFAULT) # 使用ExpanderInfoValidator类添加FRAC_NZ格式 @VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT) @VLD.check_attrs('axis') class SoftmaxGradExt(Expander): """SoftmaxGradExt expander""" def _expand(self, graph_builder): + """ + 对输入数据进行扩展处理。 + + Args: + graph_builder (GraphBuilder): 图构建器对象。 + + Returns: + Tensor: 处理后的数据。 + + """ + # 获取输入参数 x, y, z = self.inputs + # 获取指定的轴 axis = self.attrs['axis'] + # 获取需要减少的轴和原始减少的形状 reduce_axis, ori_reduced_shape = get_reduce_axis_shape(x.shape, x.data_format, axis) + # 将x和y相乘 data_mul = graph_builder.emit('Mul', [x, y]) + # 对乘积进行求和,并保留维度 data_sum = graph_builder.emit('ReduceSum', [data_mul], attrs={'reduce_axis': reduce_axis, 'keep_dims': True, 'reduce_output_fuse': True}) + # 如果x的数据格式为FRAC_NZ,则对求和结果进行重塑 if x.data_format == DF.FRAC_NZ: data_sum = graph_builder.emit('Reshape', [data_sum], attrs={'shape': ori_reduced_shape}) + # 从x中减去求和结果 data_sub = graph_builder.emit('Sub', [x, data_sum]) + # 将减法结果与y相乘 data_mul2 = graph_builder.emit('Mul', [data_sub, y]) + # 将结果与z相乘得到最终结果 result = graph_builder.emit('Mul', [data_mul2, z]) return result diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/sqrt_grad.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/sqrt_grad.py index 9f072e7a..8d1bb0bd 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/sqrt_grad.py @@ -21,9 +21,24 @@ class SqrtGrad(Expander): """SqrtGrad expander""" def _expand(self, graph_builder): + """ + 计算并返回给定输入 x 的平方根的梯度。 + + Args: + graph_builder (GraphBuilder): 图构建器对象,用于构建计算图。 + + Returns: + Tensor: 返回给定输入 x 的平方根的梯度。 + + """ + # 获取输入 x 和梯度 dout # formula of sqrt_grad is dout / (2 * x) x, dout = self.inputs + # 创建一个常数值 2 const_two = graph_builder.value(x.dtype, 2) + # 计算 2 * x dividend = graph_builder.emit('Mul', [x, const_two]) + # 计算梯度:dout / (2 * x) result = graph_builder.emit('RealDiv', [dout, dividend]) + # 返回计算结果 return result diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/square_sum_all.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/square_sum_all.py index a06c1a39..f38953d8 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/square_sum_all.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/square_sum_all.py @@ -21,24 +21,57 @@ class SquareSumAll(Expander): """SquareSumAll expander""" def _check(self): + """ + 检查输入是否合法。 + + Args: + 无。 + + Returns: + 无。 + + Raises: + GKException: 如果输入的数量不等于2,则抛出GKException异常。 + + """ """check inputs""" + # 获取输入的数量 input_num = len(self.inputs) if input_num != 2: + # 如果输入的数量不等于2,则抛出异常 raise GKException("For 'SquareSumAll', the inputs number should be 2, but got 
{}.".format(input_num)) def _expand(self, graph_builder): + """ + 对输入的两个变量进行扩展操作。 + + Args: + graph_builder (GraphBuilder): 图构建器对象,用于在图构建过程中发射操作。 + + Returns: + tuple: 包含两个元素的元组,每个元素为扩展操作的结果。 + + """ """do expand""" + # 获取输入的两个变量 x0 = self.inputs[0] x1 = self.inputs[1] + # 获取x0的形状 ori_shape = x0.shape + # 初始化一个空列表,用于存储维度索引 axis = [] + # 遍历ori_shape,将每个维度的索引添加到axis列表中 for i, _ in enumerate(ori_shape): axis.append(i) + # 对x0进行平方运算 square_res0 = graph_builder.emit('Mul', [x0, x0]) + # 对x1进行平方运算 square_res1 = graph_builder.emit('Mul', [x1, x1]) + # 对square_res0进行求和运算,求和的维度为axis,并保持维度不变 result0 = graph_builder.emit('ReduceSum', [square_res0], attrs={'reduce_axis': axis, 'keep_dims': False}) + # 对square_res1进行求和运算,求和的维度为axis,并保持维度不变 result1 = graph_builder.emit('ReduceSum', [square_res1], attrs={'reduce_axis': axis, 'keep_dims': False}) return result0, result1 diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/square_sum_v1.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/square_sum_v1.py index 8774dcbe..67d6f28b 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/square_sum_v1.py @@ -25,13 +25,30 @@ class SquareSumV1(Expander): """Square expander""" def _expand(self, graph_builder): + """ + 计算输入张量的平方并沿指定轴进行求和。 + + Args: + graph_builder (GraphBuilder): 图构建器对象。 + + Returns: + Tensor: 计算得到的张量。 + + """ + # 获取输入的第一个元素 x = self.inputs[0] + # 获取属性中的axis值 axis = self.attrs['axis'] + # 获取需要reduce的axis和原始的reduced shape reduce_axis, ori_reduced_shape = get_reduce_axis_shape(x.shape, x.data_format, axis) + # 计算x的平方 square_res = graph_builder.emit('Mul', [x, x]) + # 对平方结果进行ReduceSum操作 result = graph_builder.emit('ReduceSum', [square_res], attrs={'reduce_axis': reduce_axis, 'keep_dims': False}) + # 如果数据格式为FRAC_NZ,则对结果进行Reshape操作 if x.data_format == DF.FRAC_NZ: result = graph_builder.emit('Reshape', [result], attrs={'shape': ori_reduced_shape}) + # 返回最终结果 return result diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/squared_difference.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/squared_difference.py index 316b000e..ca0dc687 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/squared_difference.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/squared_difference.py @@ -21,10 +21,24 @@ class SquaredDifference(Expander): """SquaredDifference expander""" def _expand(self, graph_builder): + """ + 根据输入的两个输入值计算并返回它们的平方差的计算结果。 + + Args: + graph_builder (GraphBuilder): 图构建器对象,用于在图中生成节点和边。 + + Returns: + Node: 计算结果节点。 + + """ + # 获取输入的第一个值 input_x = self.inputs[0] + # 获取输入的第二个值 input_y = self.inputs[1] + # 使用图构建器计算输入值的差值 sub_val = graph_builder.emit('Sub', [input_x, input_y]) + # 使用图构建器计算差值的平方 result = graph_builder.emit('Mul', [sub_val, sub_val]) return result diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/squeeze.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/squeeze.py index a5e631b5..66178c0b 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/squeeze.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/squeeze.py @@ -21,27 +21,67 @@ class Squeeze(Expander): """Squeeze expander""" def _expand(self, 
graph_builder): + """ + 扩展输入的维度。 + + Args: + graph_builder (GraphBuilder): 图构建器对象,用于构建图结构。 + + Returns: + Tensor: 扩展维度后的输入。 + + """ + # 获取输入的第一个元素 input_x = self.inputs[0] + + # 根据输入的shape和axis属性推断输出shape out_shape = self.infer_shape(input_x.shape, self.attrs['axis']) + + # 使用graph_builder发射Reshape操作,并设置shape属性为out_shape result = graph_builder.emit('Reshape', [input_x], attrs={'shape': out_shape}) + # 返回结果 return result @staticmethod def infer_shape(shape, axis): + """ + 根据指定的axis推断squeeze后的shape。 + + Args: + shape (list, tuple): 原始数据的shape。 + axis (int, list, tuple): 需要被squeeze的维度。如果为int,则只squeeze该维度; + 如果为list或tuple,则squeeze列表或元组中的每个维度。如果为空,则squeeze所有维度为1的维度。 + + Returns: + list: squeeze后的shape。 + + Raises: + ValueError: 如果输入的axis类型不符合要求,抛出异常。 + + """ """infer shape for squeeze""" def squeeze_axis(shape, axis): + # 如果axis为空,移除shape中所有值为1的维度 if not axis: out_shape = list(d for d in shape if d != 1) else: + # 获取shape的维度数量 ndim = len(shape) + # 移除shape中指定的axis维度 out_shape = list(shape[i] for i in range(ndim) if not (i in axis or (i - ndim) in axis)) + # 如果out_shape为空,则将其设置为[1] if not out_shape: out_shape = [1] return out_shape + + # 如果shape是列表或元组类型 if isinstance(shape, (list, tuple)): + # 如果axis是整数类型,则将其转换为列表 if isinstance(axis, int): axis = [axis] + # 如果axis是列表或元组类型,则调用squeeze_axis函数处理 if isinstance(axis, (list, tuple)): return squeeze_axis(shape, axis) + # 如果输入不符合要求,则抛出异常 raise ValueError("Invalid axis for Squeeze.") diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/tanh_grad.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/tanh_grad.py index a43c7962..60681a54 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/tanh_grad.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/tanh_grad.py @@ -21,11 +21,29 @@ class TanhGrad(Expander): """TanhGrad expander""" def _expand(self, graph_builder): + """ + 计算1减去输入值的平方后,再与输入的导数相乘。 + + Args: + graph_builder (GraphBuilder): 图构建器对象,用于构建计算图。 + + Returns: + Tensor: 计算结果,类型为Tensor。 + + """ + # 获取输入值 input_y, input_dy = self.inputs + # 创建一个值为1的常量,数据类型与input_y相同 const_one = graph_builder.value(input_y.dtype, 1) + + # 计算input_y的平方 double_y = graph_builder.emit('Mul', [input_y, input_y]) + + # 计算1减去input_y的平方 one_sub_double_y = graph_builder.emit('Sub', [const_one, double_y]) + + # 计算input_dy与1减去input_y的平方的乘积 result = graph_builder.emit('Mul', [input_dy, one_sub_double_y]) return result diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/tile.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/tile.py index 918b7486..9b262d45 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/tile.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/tile.py @@ -25,30 +25,48 @@ class Tile(Expander): def _get_output_shape(self): """Get output shape""" + # 获取输入形状的列表 shape = list(self.inputs[0].shape) + # 获取属性"multiples"的列表 multiples = list(self.attrs["multiples"]) + # 计算"multiples"和输入形状的长度差 diff_len = len(multiples) - len(shape) + # 如果长度差小于0,抛出异常 if diff_len < 0: raise GKException("For 'Tile', dimensions of attr 'multiples' should be greater than or equal to " "dimensions of input shape, but got {} and {}".format(multiples, shape)) + # 如果长度差大于0,则扩展输入形状的列表 if diff_len > 0: for _ in range(diff_len): shape.insert(0, 1) + # 初始化输出形状的列表 output_shape = [] + # 遍历输入形状和multiples的元组 for sh, mul in 
list(zip(shape, multiples)): + # 如果输入形状和multiples的值都不为1,则抛出异常 if sh != 1 and mul != 1: raise GKException("For 'Tile', input shape{} and attr 'multiples'{} can not broadcast." .format(self.inputs[0].shape, multiples)) + # 计算维度 dim = sh * mul + # 将计算得到的维度添加到输出形状的列表中 output_shape.append(dim) + # 返回输出形状的列表 return output_shape def _expand(self, graph_builder): + # 获取输入的第一个元素 input_x = self.inputs[0] + + # 获取输出形状 output_shape = self._get_output_shape() + # 使用graph_builder的emit方法生成BroadcastTo操作 + # 参数为[input_x]和输出形状 result = graph_builder.emit('BroadcastTo', [input_x], attrs={'shape': output_shape}) + + # 返回结果 return result diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/__init__.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/__init__.py index 8125a8b1..b9f7be45 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/__init__.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/__init__.py @@ -14,6 +14,9 @@ # =========================================================================== """GraphKernel cost model init""" +# 导入split模块 from .graph_split import split +# 导入GraphBuilder和load_composite模块 from .model_builder import GraphBuilder, load_composite +# 导入parallel_estimate模块 from .graph_parallel import parallel_estimate diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/graph_parallel.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/graph_parallel.py index 7e69da3e..f910b8fe 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/graph_parallel.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/graph_parallel.py @@ -20,128 +20,299 @@ class ParalGain: """Paral Gain""" def __init__(self, fusion_type, bottleneck, gain, block_assign, type_info): + """ + 类的构造函数。 + + Args: + fusion_type (str): 融合类型。 + bottleneck (int): 瓶颈层的大小。 + gain (float): 增益值。 + block_assign (list): 块分配列表。 + type_info (dict): 类型信息字典。 + + Returns: + None + """ + # 初始化融合类型 self.fusion_type = fusion_type + # 初始化瓶颈层 self.bottleneck = bottleneck + # 初始化增益 self.gain = gain + # 初始化块分配 self.block_assign = block_assign + # 初始化类型信息 self.type_info = type_info class ScheduleAnalyzer: """schedule analyzer""" + # 定义一个常量,表示wrap的大小 WRAP_SIZE = 32 + # 定义一个常量,表示最大SM数量 MAX_SM = 80 # Volta + # 定义一个常量,表示最大线程数量 MAX_NUM_THREADS = 1024 + # 定义一个常量,表示最大block数量 MAX_BLOCK = 256 + # 定义一个常量,表示流水线操作的阈值 PIPELINE_OP_THREADHOLD = 5 def __init__(self, graph): + """ + 初始化图处理类。 + + Args: + graph (Graph): 图对象,用于存储图的结构和参数。 + + Attributes: + graph (Graph): 图对象,存储图的结构和参数。 + block_num (int): 块的数量,初始值为0。 + block_weight (float): 块的权重,初始值为0。 + ops (List[Operation]): 图的操作列表。 + dom_op (List[Operation]): 输出的每个操作对应的操作列表。 + + """ + # 将传入的图对象赋值给实例变量graph self.graph = graph + # 初始化block数量为0 self.block_num = 0 + # 初始化block权重为0 self.block_weight = 0 + # 通过图对象的deduce_parameters方法获取参数,并赋值给outputs变量 _, outputs = graph.deduce_parameters() + # 将图对象的操作列表赋值给实例变量ops self.ops = graph.ops + # 将outputs中的每个输出对应的操作收集到一个列表中,并赋值给实例变量dom_op self.dom_op = list(out.op for out in outputs) @staticmethod def prod(shape): + """ + 计算形状乘积。 + + Args: + shape (list): 一个包含整数的列表,表示形状。 + + Returns: + int: 形状乘积的结果。 + + """ """Compute shape product""" + # 初始化结果变量为shape的第一个元素 res = shape[0] + # 遍历shape列表,从第二个元素开始 for i in range(1, len(shape)): + # 将当前结果与shape的下一个元素相乘 res = res * shape[i] + # 返回计算后的结果 return res def _cal_weight(self, ops): + """ + 计算给定操作列表的总权重。 + + 
Args: + ops (list): 包含多个操作对象的列表。 + + Returns: + int: 所有操作的权重总和。 + + """ weight = 0 for op in ops: + # 遍历每个操作 weight += self.prod(op.output.shape) * \ - PrimLib.dtype_bytes(op.output.dtype) + PrimLib.dtype_bytes(op.output.dtype) # 计算op的输出数据类型的字节数 return weight def injective_analyze(self): + """ + 分析单射情况。 + + Args: + 无 + + Returns: + 无 + + """ """analyze injective case""" + # 计算常量大小 const_size = max((self.prod(op.output.shape) for op in self.dom_op)) + # 调整常量大小,确保是MAX_NUM_THREADS的倍数 const_size = (const_size + self.MAX_NUM_THREADS - 1) // self.MAX_NUM_THREADS * self.MAX_NUM_THREADS + # 计算总权重 total_weight = self._cal_weight(self.ops) + # 计算总块数 total_block = (const_size + self.MAX_NUM_THREADS - 1) // self.MAX_NUM_THREADS + # 判断是否需要分割块 need_block_split = const_size > self.MAX_BLOCK * self.MAX_NUM_THREADS if need_block_split: + # 如果需要分割块,设置块数为MAX_BLOCK self.block_num = self.MAX_BLOCK + # 计算波数 waves = (total_block + self.MAX_BLOCK - 1) // self.MAX_BLOCK + # 计算块权重 self.block_weight = total_weight // total_block * waves else: + # 如果不需要分割块,设置块数为总块数 self.block_num = total_block + # 计算块权重 self.block_weight = total_weight // self.block_num def reduce_analyze(self): + """ + 分析reduce操作。 + + Args: + 无 + + Returns: + 无 + + Raises: + RuntimeError: 如果并行融合不支持多个reduce操作,或者没有找到reduce操作。 + + """ """analyze reduce case""" + # 定义线程数 thread_x, thread_y = 32, 32 reduce_op = None + for op in self.ops: + # 判断操作类型是否为reduce if PrimLib.iter_type(op) == PrimLib.REDUCE: + # 如果已经存在reduce操作,则抛出异常 if reduce_op: raise RuntimeError("Parallel fusion does not support multiple reduce op now.") reduce_op = op + + # 如果没有找到reduce操作,则抛出异常 if not reduce_op: raise RuntimeError("Parallel fusion does not find a reduce op.") + + # 获取reduce操作的输入形状 shape = reduce_op.inputs[0].shape + # 获取reduce操作的reduce轴 reduce_axis = reduce_op.attrs['reduce_axis'] + # 计算总空间 total_space = self.prod(shape) + # 计算reduce空间 red_space = shape[reduce_axis[0]] for i in range(1, len(reduce_axis)): red_space *= shape[reduce_axis[i]] + + # 获取数据类型大小 dtype_size = PrimLib.dtype_bytes(reduce_op.output.dtype) + # 计算权重 weight = self._cal_weight(self.ops) # reduce + injective + # 计算block_x block_x = (total_space // red_space + thread_y - 1) // thread_y + # 计算block_w block_w = (weight + block_x - 1) // block_x + # 计算waves waves = (block_x + self.MAX_BLOCK - 1) // self.MAX_BLOCK + # 设置block_num self.block_num = min(self.MAX_BLOCK, block_x) + + # 定义all_reduce all_reduce = 10 # 1 reduce init + 3 sync + 5 bin + 1 write + # 计算block_weight self.block_weight = (block_w + all_reduce * dtype_size * thread_x * thread_y) * waves def default_analyze(self): + """ + 默认分析函数 + + Args: + 无 + + Returns: + 无 + + Raises: + 无 + + """ """analyze default case""" + # 定义一个内部函数,用于计算默认空间 def _cal_default_space(op): + # 计算op的输出空间 space = self.prod(op.output.shape) + # 遍历op的所有输入 for t in op.inputs: + # 计算输入的空间 size = self.prod(t.shape) + # 如果输入空间大于当前空间,则更新空间 if size > space: space = size + # 返回计算出的空间 return space + + # 计算所有操作中的最大空间 space = max((_cal_default_space(op) for op in self.dom_op)) - # each sm least 4 wrap + # 每个sm至少包含4个wrap + # 计算所需的block数量 block = (space + (self.WRAP_SIZE * 4) - 1) // (self.WRAP_SIZE * 4) + # 将block数量限制在最大block数量之内 self.block_num = min(self.MAX_BLOCK, block) + # 计算每个block的权重 self.block_weight = self._cal_weight(self.ops) // self.block_num def analyze(self): """analyze ops""" def _ops_type(ops, dom_op): + """ + 判断操作列表中是否包含reduce操作。 + + Args: + ops (list): 操作列表。 + dom_op (list): 操作列表。 + + Returns: + bool: 如果操作列表中包含reduce操作,则返回True;否则返回False。 + """ + # 检查ops列表中是否有reduce操作 
have_reduce = any( + # 如果op的类型是PrimLib.REDUCE,则返回True (PrimLib.iter_type(op) == PrimLib.REDUCE for op in ops)) if have_reduce: + # 如果有reduce操作,返回True return True + # 否则返回dom_op[0]的类型 return PrimLib.iter_type(dom_op[0]) + # 调用_ops_type函数,获取dom_op的类型 dom_type = _ops_type(self.ops, self.dom_op) + # 如果dom_type是PrimLib.ELEMWISE或PrimLib.BROADCAST类型 if dom_type in (PrimLib.ELEMWISE, PrimLib.BROADCAST): + # 调用injective_analyze方法 self.injective_analyze() + # 如果dom_type是PrimLib.REDUCE类型 elif dom_type == PrimLib.REDUCE: + # 调用reduce_analyze方法 self.reduce_analyze() + # 如果dom_type是其他类型 else: + # 调用default_analyze方法 self.default_analyze() def suitable_to_pipeline(self): """judge whether is suitable to be pipeline optimized""" + # 判断是否适合进行流水线优化 + + # Reduce操作不适合 # Reduce is not suitable def _contain_reduce(ops): for op in ops: + # Reduce操作可能导致分片效果差 # Reduce may make the tiling bad. if PrimLib.primtives.get(op.prim, None) == PrimLib.REDUCE: return True @@ -149,6 +320,7 @@ class ScheduleAnalyzer: suitable = True if _contain_reduce(self.ops): + # 如果包含Reduce操作,则不适合进行流水线优化 suitable = False return suitable @@ -166,13 +338,16 @@ class ScheduleAnalyzer: classes (list[list[int]]): The list of clusters. Each cluster is a list of indices. """ def _cal_mean(classes): + # 计算每个聚类的均值 class_datas = list(list(data[cid] for cid in cls) for cls in classes) return list(sum(cls) / len(cls) if cls else float('inf') for cls in class_datas) def _cal_distance(a, b): + # 计算两个元素之间的距离 return abs(a - b) def _check_different(old_classes, new_classes): + # 检查新旧聚类是否不同 for o, n in zip(old_classes, new_classes): if o != n: return True @@ -201,31 +376,39 @@ class ScheduleAnalyzer: min_idx = i if min_dis > cur_dis else min_idx min_dis = cur_dis if min_dis > cur_dis else min_dis new_classes[min_idx].append(idx) + # 检查聚类是否发生变化 changed = _check_different(classes, new_classes) + # 更新聚类 classes = new_classes return classes @staticmethod def pipeline_fusion_analyze(blocks, op_sizes, exclude_id): """analyze whether the segments can be pipeline optimized""" - # op size first, block second. + # op size first, block second。 + # 操作大小在前,块在后 def _simple_factor(block, op_size): return block + 5 * op_size def _take_second(elem): return elem[1] + # 计算每个块的简单因子 simple_indicators = list(_simple_factor(b, s) for b, s in zip(blocks, op_sizes)) # 2 classes, one heavy, the other light + # 两类,一类重,一类轻 classes = ScheduleAnalyzer.k_mean(simple_indicators, 2, exclude_id) if not classes: return [] + # 计算每类的均值 means = list(sum([simple_indicators[idx] for idx in cls]) / len(cls) if cls else float('inf') for cls in classes) # The target two clusters should be a heavy one and a light one. + # 目标两类应该是一类重的和一类轻的 # The light one maybe suitable to run with pipeline optimized. + # 轻的一类可能适合进行流水线优化 classes_infos = list([cls, m] for cls, m in zip(classes, means)) classes_infos.sort(key=_take_second) pipeline_target = None @@ -234,6 +417,7 @@ class ScheduleAnalyzer: pipeline_target = ci break pipeline_gids, pipeline_mean = pipeline_target + # 如果轻的一类的均值大于某个阈值,则返回空列表 if pipeline_mean > _simple_factor(float(ScheduleAnalyzer.MAX_SM) / len(blocks), ScheduleAnalyzer.PIPELINE_OP_THREADHOLD): return [] @@ -241,6 +425,7 @@ class ScheduleAnalyzer: pipeline_blocks = [] pipeline_weight = len(pipeline_gids) # Try to make two paralleled at least. 
+ # 至少尝试两个并行 if pipeline_weight > 3 and pipeline_weight > len(blocks) / 2: if len(pipeline_gids[:pipeline_weight // 2]) > 1: pipeline_blocks.append(pipeline_gids[:pipeline_weight // 2]) @@ -252,49 +437,114 @@ class ScheduleAnalyzer: @staticmethod def fusion_consult(blocks, op_sizes, exclude_gid): + """ + 获取并行融合的建议。 + + Args: + blocks (list): 包含多个计算块的列表。 + op_sizes (list): 每个操作的尺寸列表。 + exclude_gid (int): 需要排除的组ID。 + + Returns: + tuple: 包含融合类型和类型信息的元组。 + + Raises: + 无 + + """ """get a recommendation for parallel fusion""" + # 默认是块融合 # Default is block fusion fusion_type = "block_fusion" type_info = None + # 禁用管道优化 activate_pipeline_optimization = False # Disable pipeline optimization for now. + # 如果启用管道优化 if activate_pipeline_optimization: + # 对块、操作大小和排除组ID进行管道融合分析 pipeline_info = ScheduleAnalyzer.pipeline_fusion_analyze( blocks, op_sizes, exclude_gid) + # 如果存在管道信息 if pipeline_info: + # 融合类型为块管道融合 fusion_type = "block_pipeline_fusion" + # 设置类型信息为管道信息 type_info = pipeline_info return fusion_type, type_info def block_parallel_estimate(graphs): + """ + 估计块并行增益。 + + Args: + graphs (list): 图集合,每个元素是一个图对象。 + + Returns: + ParalGain: 包含块并行增益信息的ParalGain对象。 + + """ """estimate block parallel gain""" + # 初始化变量 sum_block, max_weight, sum_weight, blocks, op_sizes, exclude_gid = 0, 0, 0, [], [], [] + + # 遍历图集合 for gid, g in enumerate(graphs): + # 创建ScheduleAnalyzer对象 s = ScheduleAnalyzer(g) + # 分析图 s.analyze() + # 累加块的数量 sum_block += s.block_num + # 更新最大权重 if s.block_weight > max_weight: max_weight = s.block_weight + # 累加权重 sum_weight += s.block_weight + # 添加块的数量到blocks列表 blocks.append(s.block_num) + # 添加操作数量到op_sizes列表 op_sizes.append(len(s.ops)) + # 如果不适合流水线处理,将gid添加到exclude_gid列表 if not s.suitable_to_pipeline(): exclude_gid.append(gid) + + # 如果块的数量大于ScheduleAnalyzer.MAX_SM * 32,返回"none" if sum_block > ScheduleAnalyzer.MAX_SM * 32: return ParalGain("none", sum_weight, 0, list(0 for _ in graphs), None) + # 获取融合类型和类型信息 fusion_type, type_info = ScheduleAnalyzer.fusion_consult(blocks, op_sizes, tuple(exclude_gid)) + # 返回ParalGain对象 return ParalGain(fusion_type, max_weight, sum_weight - max_weight, blocks, type_info) def parallel_estimate(graphs, target): + """ + 并行估计函数。 + + Args: + graphs (list): 图结构列表。 + target (str): 目标类型,例如"aicore"。 + + Returns: + ParalGain: 并行增益对象。 + + """ """Estimate parallel gain""" + # 如果目标是"aicore" if target == "aicore": + # 融合类型为"block_fusion" fusion_type = "block_fusion" + # 类型信息为空 type_info = None + # 假设估计值为1000 fake_estimate = 1000 + # 生成一个与graphs长度相同的列表,每个元素都是1 fake_blocks = list(1 for g in graphs) + # 返回ParalGain对象 return ParalGain(fusion_type, fake_estimate, fake_estimate, fake_blocks, type_info) + # 调用block_parallel_estimate函数进行并行估计 return block_parallel_estimate(graphs) diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/graph_split.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/graph_split.py index 4f6c9778..3dea775d 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/graph_split.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/graph_split.py @@ -21,33 +21,98 @@ from .model import DataFormat as DF def tensor_size(tensor): + """ + 获取张量的总大小。 + + Args: + tensor (torch.Tensor): 输入的张量。 + + Returns: + int: 张量的总大小。 + + """ """get tensor size""" + # 初始化大小为1 size = 1 + # 遍历张量的形状 for i in tensor.shape: + # 将当前维度的大小乘到总大小上 size *= i + # 返回总大小 return size def reduce_nums(ops): + """ + 统计以'Reduce'开头的操作数量。 + + Args: + ops (List[Operation]): 
操作列表,其中每个操作都是一个Operation对象。
+
+    Returns:
+        int: 以'Reduce'开头的操作数量。
+
+    """
     """get reduce nums"""
     count = 0
+    # 遍历操作列表
     for op in ops:
+        # 判断操作是否以'Reduce'开头
         if op.prim.startswith('Reduce'):
+            # 如果是,计数器加一
            count += 1
+    # 返回计数结果
     return count


 def may_stitch(dom, a, r, stitch_axis_size, stitch_buffer_size):
+    """
+    检查dom区域与候选区域a之间是否可以进行buffer stitch(拼接)融合。
+
+    Args:
+        dom (Area): 主导(dominant)融合区域。
+        a (Area): 候选的待拼接区域。
+        r (int): dom与a之间的关系模式(PrimLib中的pattern值)。
+        stitch_axis_size (int): 拼接轴元素数量的阈值。
+        stitch_buffer_size (int): 拼接缓冲区大小的阈值。
+
+    Returns:
+        bool: 如果可以进行拼接,返回True;否则返回False。
+    """
     """check if can stitch"""

     def _same_stitch_axis(stitch_tensors, final_outs, stitch_axis_size):
+        """
+        判断给定的张量列表是否具有相同的拼接轴。
+
+        Args:
+            stitch_tensors (list of Tensor): 待拼接的张量列表。
+            final_outs (list of Tensor): a区域的最终输出张量列表。
+            stitch_axis_size (int): 拼接轴元素数量的阈值。
+
+        Returns:
+            bool: 如果所有张量的拼接轴相同,则返回True;否则返回False。
+        """
         """does a and b have same stitch axis"""

         def _stitch_axis(shape, stitch_axis_size):
+            """
+            获取拼接轴。
+
+            Args:
+                shape (list): 张量的形状。
+                stitch_axis_size (int): 拼接轴元素数量的阈值。
+
+            Returns:
+                list: 组成拼接轴的各维度列表。
+            """
             """get stitch axis"""
             stitchaxis = []
             size = 1
+            # 依次累乘前导维度,直到元素数量达到stitch_axis_size
             for i in shape:
                 size = size * i
                 stitchaxis.append(i)
                 if size >= stitch_axis_size:
@@ -56,25 +121,28 @@ def may_stitch(dom, a, r, stitch_axis_size, stitch_buffer_size):

         x = []
         x.extend(stitch_tensors)
         x.extend(final_outs)
+        # 如果各张量的拼接轴不一致,则返回False
         stitch_axis_0 = _stitch_axis(x[0].shape, stitch_axis_size)
         for item in x:
             i_stitch_axis = _stitch_axis(item.shape, stitch_axis_size)
             if not i_stitch_axis or i_stitch_axis != stitch_axis_0:
                 return False
         return True

+    # 判断a与r的关系模式,以及dom到a是否无环
     if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a):
+        # a中包含两个及以上reduce操作时不做拼接
         if reduce_nums(a.ops) >= 2:
             return False
         dom_outs = set(op.output for op in dom.ops)
         a_ins = set(op_input for op in a.ops for op_input in op.inputs)
         a_outs = set(op.output for op in a.ops)
         a_final_outs = list(tensor for tensor in a_outs if tensor not in a_ins)
+        # 获取需要拼接的张量:dom的输出中被a用作输入的张量
         stitch_tensors = list(tensor for tensor in dom_outs if tensor in a_ins)
         if not _same_stitch_axis(stitch_tensors, a_final_outs, stitch_axis_size):
             return False
+        # 只要存在大小达到stitch_buffer_size的待拼接张量,即可拼接
         return any((tensor_size(tensor) >= stitch_buffer_size for tensor in stitch_tensors))
     return False


@@ -83,52 +151,109 @@ class CommonPattern:
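# --------------------------------------------------------------------------------------------
# Illustration (not part of the MindSpore diff): how `may_stitch` above derives the stitch axis.
# `stitch_axis_of` is a hypothetical stand-alone rewrite of the nested `_stitch_axis` helper,
# shown only to make the prefix-product rule concrete; the break/return behaviour after the
# threshold is an assumption, since that part of the hunk is elided in this diff.
def stitch_axis_of(shape, stitch_axis_size):
    """Collect leading dimensions until their element count reaches `stitch_axis_size`."""
    stitch_axis, size = [], 1
    for dim in shape:
        size *= dim
        stitch_axis.append(dim)
        if size >= stitch_axis_size:
            break  # enough leading elements for one stitch block
    return stitch_axis

# Two tensors may be stitched only if they share this leading prefix, e.g. with the
# 1024 * 8 threshold used later by `_reduce_stitch`:
assert stitch_axis_of((1024, 16, 32), 1024 * 8) == [1024, 16]
assert stitch_axis_of((1024, 16, 8), 1024 * 8) == [1024, 16]
# --------------------------------------------------------------------------------------------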
@staticmethod def reshape(dom): + """ + 对reshape dom进行融合策略的函数。 + + Args: + dom (PrimDom): 需要进行reshape操作的dom对象。 + + Returns: + tuple: 包含两个元素的元组。 + - List[PrimDom]: 包含最小面积的dom对象的列表。 + - bool: 表示是否进行前向融合的布尔值。 + + """ """fuse strategy for reshape dom""" + # 判断dom的模式是否为PrimLib.RESHAPE if dom.pattern != PrimLib.RESHAPE: return [] + + # 初始化最小面积和是否前向融合的标志 min_area, forward_fuse = None, False + + # 遍历dom的输出关系 for a, _ in dom.out_relations.items(): + # 如果a的模式小于等于PrimLib.BROADCAST,并且dom与a之间没有环,且a的模式小于当前最小面积的模式 if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a) and \ (min_area is None or a.pattern < min_area.pattern): + # 更新最小面积为a min_area = a + + # 遍历dom的输入关系 for a, _ in dom.in_relations.items(): + # 如果a的模式小于等于PrimLib.BROADCAST,且a与dom之间没有环,且dom的第一个操作的第一个输入的to_ops长度为1,且a不是输出 if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \ len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \ (min_area is None or a.pattern < min_area.pattern): + # 更新最小面积为a,并设置前向融合标志为True min_area, forward_fuse = a, True + + # 如果最小面积存在,则返回最小面积列表和前向融合标志,否则返回空列表 return ([min_area], forward_fuse) if min_area else [] @staticmethod def elemwise_depth(dom): + """ + 在深度上对elemwise dom进行融合策略。 + + Args: + dom (object): 待融合的dom对象。 + + Returns: + tuple: 包含两个元素的元组,第一个元素为dom的输入操作列表,第二个元素为布尔值,表示是否成功进行融合。 + + """ """fuse strategy in depth for elemwise dom""" + # 如果dom的模式不是Elemwise或Broadcast,或者dom的输入关系数量不为1,则返回空列表 if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.in_relations) != 1: return [] + + # 获取dom的唯一输入关系及其关联的操作 a, r = list(dom.in_relations.items())[0] + + # 如果a的模式大于Broadcast,或者a的输出关系数量不为1,或者r的模式不是Elemwise,或者a和dom的输出形状不一致,则返回空列表 if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or r != PrimLib.ELEMWISE or \ a.dom_op().output.shape != dom.dom_op().output.shape: return [] + + # 返回包含a的列表和True return [a], True @staticmethod def elemwise_width(dom): """fuse strategy in width for elemwise dom""" + # 如果dom的模式不是PrimLib.ELEMWISE或PrimLib.BROADCAST,则返回空列表 if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST): return [] + fused = [] + # 遍历dom的输入关系 for a, r in dom.in_relations.items(): + # 如果a的模式小于等于PrimLib.BROADCAST,关系r为PrimLib.ELEMWISE,且a是无环的, + # 同时a的dom操作输出形状与dom的dom操作输出形状相同 if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_acyclic(dom) and \ a.dom_op().output.shape == dom.dom_op().output.shape: + # 将满足条件的a添加到fused列表中 fused.append(a) + + # 返回fused列表和一个布尔值True return fused, True @staticmethod def assign(dom): """fuse strategy for assign dom""" + # 判断dom的操作数个数是否等于1,并且操作是否为Assign if len(dom.ops) != 1 or dom.dom_op().prim != "Assign": + # 如果不满足条件,返回空列表 return [] + + # 初始化一个空列表,用于存储融合后的结果 fused = [] + # 遍历dom的输入关系 for a, _ in dom.in_relations.items(): + # 将输入关系中的a添加到fused列表中 fused.append(a) + # 返回融合后的结果和表示成功融合的布尔值True return fused, True @@ -139,37 +264,56 @@ class GraphSplitByPattern: """Reachable table""" def __init__(self, size): + # 初始化一个空的列表,用于存储地图状态 self.map = [] + # 初始化一个集合,包含所有活细胞的位置 self.alive = set(range(size)) + # 遍历从0到size-1的每一个数字 for i in range(0, size): + # 在map列表中添加一个新的列表,包含size个False值 self.map.append([False] * size) + # 将对角线位置设置为True,表示这些位置是活细胞 self.map[i][i] = True def reachable(self, x, y): """reachable from x to y""" + # 返回地图中位置 (x, y) 的可达性 return self.map[x][y] def sync(self, x, y): - """sync from y to x""" + """ + 从 y 同步到 x + """ + # 遍历 alive 列表 for i in self.alive: + # 调用 _link 方法,将 y 中的元素链接到 x 中 self._link(self.map[y][i], x, i) def _link(self, cond, f, t): """link from `f` to `t`""" + # 如果条件为真 if cond: + # 在映射字典中将 f 到 t 的路径标记为 True 
self.map[f][t] = True def fuse(self, x, y): """fuse y to x""" for i in self.alive: + # i 是 y 的后继节点,将 x 的前驱节点连接到 i # i is the succeeding node of y, links the x's previous nodes to i if self.map[y][i] and not self.map[x][i]: + # 遍历 x 的前驱节点 for pre in self.alive: + # 将 x 的前驱节点连接到 i self._link(self.map[pre][x], pre, i) + # i 是 y 的前驱节点,将 i 连接到 x 的后继节点 # i is the previous node of y, link i to x's succeeding nodes if self.map[i][y] and not self.map[i][x]: + # 遍历 x 的后继节点 for suc in self.alive: + # 将 i 连接到 x 的后继节点 self._link(self.map[x][suc], i, suc) + # 从 self.alive 中移除 y self.alive.remove(y) class Area: @@ -181,92 +325,145 @@ class GraphSplitByPattern: """StitchInfo""" def __init__(self): + """ + 初始化函数。 + """ + # 初始化存储拼接操作的集合 self.stitch_ops = set() + # 初始化存储原子拼接操作的集合 self.stitch_atomic_ops = set() def has_stitch_op(self): """check stitch_op exists""" + # 检查是否存在 stitch_ops return self.stitch_ops or self.stitch_atomic_ops def __init__(self, init_op, is_output, unique_id, reach_tab, recompute_ops=None): + # 初始化模式类型 self.pattern = PrimLib.iter_type(init_op) if init_op is not None else PrimLib.UNKNOWN + # 初始化操作列表 self.ops = [] if init_op is None else [init_op] + # 初始化输入关系字典 self.in_relations = dict() # {area1: relation1, area2: relation2, ...} + # 初始化输出关系字典 self.out_relations = dict() # {area1: relation1, area2: relation2, ...} + # 初始化模式 self.mode = None + # 初始化缝合信息 self.stitch_info = self.StitchInfo() + # 初始化重计算操作列表 self.recompute_ops = [] if recompute_ops is None else recompute_ops + # 初始化原始操作映射 self.ori_op_map = {} + # 初始化是否重计算 self.is_recompute = False + # 初始化是否为输出 self.is_output = is_output + # 初始化输出排除集合 self.output_excluded = set() + # 如果模式是减少类型,则执行以下逻辑 if self.pattern == PrimLib.REDUCE: + # 定义用于收集减少排除的函数 def _gather_reduce_exclude(): + # 初始化递归栈 recursion_stack = [init_op] + # 循环遍历递归栈 while recursion_stack: + # 弹出栈顶操作 op = recursion_stack.pop() + # 遍历操作的输出到操作 for to in op.output.to_ops: + # 获取输出到操作的索引 idx = to.inputs.index(op.output) + # 如果关系大于元素级关系,则将该操作添加到输出排除集合 if self.get_relation(to, idx) > PrimLib.ELEMWISE: self.output_excluded.add(to) else: + # 否则,将该操作添加到递归栈中 recursion_stack.append(to) + # 调用收集减少排除的函数 _gather_reduce_exclude() + # 初始化唯一ID self.unique_id = unique_id + # 初始化可达表 self.reach_tab = reach_tab def __str__(self): + # 将self.ops中的每个op的output.name拼接成一个由'-'连接的字符串 return '<' + '-'.join((op.output.name for op in self.ops)) + '>' def __repr__(self): + # 返回对象的字符串表示 return str(self) @staticmethod def get_relation(op, i): """Get op relation""" + # 初始化关系为未知 relation = PrimLib.UNKNOWN + # 获取输入的关系 _, elem_relation = PrimLib.input_relation(op, i) + # 遍历每个关系 for r in elem_relation: + # 如果关系为空,则将关系更新为广播 if r is None: + # 更新关系为最大关系,若当前关系为UNKNOWN,则更新为BROADCAST relation = max(relation, PrimLib.BROADCAST) + # 如果当前关系大于已有关系,则更新关系 elif r > relation: relation = r return relation def link_input(self, area_map): """Link inputs""" + # 遍历self.ops[0].inputs中的每个元素 for i, t in enumerate(self.ops[0].inputs): + # 如果当前元素t的op不为空 if t.op is not None: + # 从area_map中获取当前元素t的op对应的area + # 并调用self.get_relation获取self.ops[0]和i的relation area, relation = area_map[t.op], self.get_relation(self.ops[0], i) + # 将area和relation保存到self.in_relations中 self.in_relations[area] = relation def link_output(self): """Link outputs""" + # 遍历输入区域与关系的字典 for input_area, r in self.in_relations.items(): + # 将当前对象添加到输入区域的输出关系中 input_area.out_relations[self] = r + # 遍历输出关系 for out, _ in self.out_relations.items(): + # 同步当前对象与输出对象的唯一标识符 self.reach_tab.sync(self.unique_id, out.unique_id) def update_stitch_info(self, stitch_info): """Update stitch 
info""" + # 如果stitch_info中存在stitch_ops if stitch_info.stitch_ops: + # 更新self.stitch_info中的stitch_ops self.stitch_info.stitch_ops.update(stitch_info.stitch_ops) + # 如果stitch_info中存在stitch_atomic_ops if stitch_info.stitch_atomic_ops: + # 更新self.stitch_info中的stitch_atomic_ops self.stitch_info.stitch_atomic_ops.update(stitch_info.stitch_atomic_ops) def fuse(self, area): - """Fuse `area` to `self`""" + """将`area`融合到`self`中""" def _update_relation(relations, a, r): + # 更新关系 relations[a] = max(r, relations[a]) if a in relations else r def _update_pattern(): + # 更新模式 if area.pattern > self.pattern: self.pattern = area.pattern if area in self.in_relations and self.in_relations.get(area) > self.pattern: self.pattern = self.in_relations.get(area) def _fuse_relation(self_relations, new_relations): + # 融合关系 for a, r in new_relations.items(): if a != self: _update_relation(self_relations, a, r) @@ -274,65 +471,94 @@ class GraphSplitByPattern: self_relations.pop(area) def _redirect_relation(rels): - """Replace `area` with `self` in relations""" + """在关系中用`self`替换`area`""" if area in rels: r = rels.pop(area) _update_relation(rels, self, r) + # 如果`area`需要重新计算 if area.is_recompute: self.cp_ops(area) + # 如果`self`的模式大于或等于`area`的模式 if self.pattern >= area.pattern: self.ops.extend(area.ops) else: self.ops = area.ops + self.ops + # 更新模式 _update_pattern() + # 融合输入关系 _fuse_relation(self.in_relations, area.in_relations) + # 融合输出关系 _fuse_relation(self.out_relations, area.out_relations) + # 更新输入关系的重定向 for a, _ in area.in_relations.items(): _redirect_relation(a.out_relations) + # 更新输出关系的重定向 for a, _ in area.out_relations.items(): _redirect_relation(a.in_relations) + # 如果`self`的模式大于PrimLib.RESHAPE if self.pattern > PrimLib.RESHAPE: self.mode = self.MODE_COMPOSITE + # 如果`area`是输出而`self`不是 if area.is_output and not self.is_output: self.is_output = True + # 如果`area`有排除的输出 if area.output_excluded: self.output_excluded.update(area.output_excluded) + # 更新拼接信息 self.update_stitch_info(area.stitch_info) + # 如果`area`不需要重新计算 if not area.is_recompute: self.reach_tab.fuse(self.unique_id, area.unique_id) + # 融合重新计算的操作 self.recompute_ops.extend(area.recompute_ops) def check_acyclic(self, to): """Check circle. 
It returns false if circle exists""" + # 遍历所有出边关系 for out, _ in self.out_relations.items(): + # 如果当前出边的节点不是目标节点,并且从当前出边节点到目标节点可达 if out != to and self.reach_tab.reachable(out.unique_id, to.unique_id): + # 存在环,返回False return False + # 不存在环,返回True return True def dom_op(self): - """Get dom op""" + """ + 获取dom操作 + """ + # 返回操作列表中的第一个操作 return self.ops[0] def reduce_out_exclude(self, area): """Check whether op is reduce_out_exclude """ + # 如果self.output_excluded为真 if self.output_excluded: + # 遍历self.output_excluded中的每个操作 for op in self.output_excluded: + # 如果操作在area.ops中存在 if op in area.ops: + # 返回True return True + # 如果没有找到符合条件的操作,返回False return False def cp_ops(self, area): """copy recompute_ops in area to ops, self is area's user""" tail_tensor = area.recompute_ops[-1].output + # 复制张量,所有复制的张量都是Tensor.PARA_NONE # copy tensors, all copied are Tensor.PARA_NONE tensor_map = {} if area.recompute_ops[0].inputs: + # 如果第一个操作的输入不为空,则将输入张量映射为自身 tensor_map[area.recompute_ops[0].inputs[0]] = area.recompute_ops[0].inputs[0] for op in area.recompute_ops: orig_tensor = op.output cp_tensor = Tensor(orig_tensor.name, orig_tensor.shape, orig_tensor.dtype, orig_tensor.data_format) tensor_map[orig_tensor] = cp_tensor + + # 复制操作 # copy ops cp_ops = [] for op in area.recompute_ops: @@ -341,6 +567,8 @@ class GraphSplitByPattern: cp_op.all_inputs = cp_op.inputs cp_ops.append(cp_op) area.ori_op_map[cp_op] = op + + # 连接复制的操作 # connect copied ops for op in self.ops: if tail_tensor in op.inputs: @@ -348,6 +576,8 @@ class GraphSplitByPattern: op.inputs.append(tensor_map.get(tail_tensor)) tail_tensor.to_ops.remove(op) tensor_map.get(tail_tensor).to_ops.append(op) + + # 将复制的操作填充到self.recompute_area中 # fill cp_ops in self.recompute_area cp_dom_op = None for cp, ori in area.ori_op_map.items(): @@ -358,60 +588,94 @@ class GraphSplitByPattern: area.ops.extend((op for op in cp_ops if op != cp_dom_op)) def __init__(self, graph, flags): + # 初始化图对象 self.graph = graph + # 初始化空区域列表 self.areas = [] + # 初始化标志对象 self.flags = flags + # 初始化是否启用重新计算融合的标志 self.enable_recompute = self.flags.get("enable_recompute_fusion", False) + # 初始化是否启用缝合融合的标志 self.enable_stitch_fusion = self.flags.get("enable_stitch_fusion", False) + # 初始化是否启用水平融合的标志 self.enable_horizontal_fusion = self.flags.get("enable_horizontal_fusion", False) + # 初始化可达表 self.reach_tab = self.ReachTable(len(graph.ops) + 1 if self.enable_recompute else len(graph.ops)) + # 初始化区域映射字典 self.area_map = {} + # 获取图的输出参数 _, outputs = graph.deduce_parameters() idx = 0 + # 遍历图中的所有操作 for op in graph.ops: + # 判断操作是否是输出操作 is_output = op.output in outputs + # 创建一个区域对象 a = self.Area(op, is_output, idx, self.reach_tab) idx += 1 + # 设置默认模式 self.set_default_mode(a) + # 将区域对象添加到区域列表中 self.areas.append(a) + # 设置区域映射 self.set_area_map([op], a) + # 遍历所有区域,设置输入链接 for a in self.areas: a.link_input(self.area_map) + # 从后往前遍历区域,设置输出链接 for i in range(len(self.areas) - 1, -1, -1): self.areas[i].link_output() + # 如果启用了重新计算融合 if self.enable_recompute: + # 创建一个用于重新计算的区域对象 self.recom_area = self.Area(None, False, idx, self.reach_tab) + # 设置重新计算标志 self.recom_area.is_recompute = True + # 初始化重新计算区域的前驱、用户和支配区域 self.recom_pre = None self.recom_user = None self.recom_dom = None + # 初始化重新计算区域的支配用户 self.dom_user_r = PrimLib.UNKNOWN + # 初始化重新计算结果标志 self.recom_res = False + # 初始化原始操作映射 self.orig_op_map = {} def set_area_map(self, ops, area): """update area_map after op fused to area""" + # 遍历操作列表 for op in ops: + # 将操作映射到指定的区域 self.area_map[op] = area def set_default_mode(self, area): """Set default mode""" + # 设置区域模式为默认模式 
area.mode = self.get_default_mode(area.ops[0]) @staticmethod def limit_area_size(dominant, fuse_areas, limit_size=200): """Remove some areas if the size is too large""" + # 计算每个区域的操作数大小 area_sizes = map(lambda area: len(area.ops), fuse_areas) + # 计算主要区域的操作数大小 dom_size = len(dominant.ops) + # 如果总操作数大小不超过限制大小,则返回原区域列表 if dom_size + prod_reduce(lambda x, y: x + y, area_sizes) <= limit_size: return fuse_areas + # 按操作数大小优先融合较小的区域 # fuse the smaller area in priority fuse_areas.sort(key=lambda area: len(area.ops)) new_fuse_areas = [] for area in fuse_areas: + # 如果加上当前区域后超过限制大小,则跳出循环 if dom_size + len(area.ops) > limit_size: break + # 累加当前区域的操作数到总操作数中 dom_size += len(area.ops) + # 将当前区域添加到新的区域列表中 new_fuse_areas.append(area) return new_fuse_areas @@ -419,174 +683,292 @@ class GraphSplitByPattern: """Fuse areas""" def _fuse_area(): + # 遍历所有区域 for dominant in self.areas: + # 使用选择器函数选择区域 result = selector(dominant) + # 如果选择器没有返回结果或结果为空,则跳过当前循环 if not result or not result[0]: continue + # 解包结果 fuse_areas, is_forward = result + # 限制融合区域的大小 fuse_areas = self.limit_area_size(dominant, fuse_areas) + # 如果没有融合区域,则跳过当前循环 if not fuse_areas: continue + # 判断融合方向 if is_forward: + # 如果是正向融合 for area in fuse_areas: + # 将当前区域与融合区域融合 dominant.fuse(area) + # 更新区域映射 self.set_area_map(area.ops, dominant) + # 从区域列表中移除融合区域 self.areas.remove(area) else: + # 如果是反向融合 forward_area = dominant for area in fuse_areas: + # 将当前区域与融合区域融合 area.fuse(forward_area) + # 更新区域映射 self.set_area_map(forward_area.ops, area) + # 从区域列表中移除当前区域 self.areas.remove(forward_area) + # 更新当前区域为下一个融合区域 forward_area = area + # 返回True表示有变化 return True + # 如果没有进行任何融合操作,则返回False return False changed, do_again = False, True while do_again: + # 执行融合操作 do_again = _fuse_area() + # 更新是否有变化的标志 changed = changed or do_again return changed def hfuse(self, selector): """Fuse horizontal areas with same input tensor""" + # 定义一个内部函数,用于执行融合操作 def _do_fuse(areas): + # 遍历所有区域,除了最后一个 for i in range(len(areas) - 1): dom = areas[i] + # 遍历剩余的区域 for a in areas[i + 1:]: + # 如果两个区域无环且满足selector函数,且融合后的区域大小不超过限制 if dom.check_acyclic(a) and a.check_acyclic(dom) and \ selector(dom, a) and self.limit_area_size(dom, [a], 64): + # 融合区域 dom.fuse(a) + # 更新区域映射 self.set_area_map(a.ops, dom) + # 从区域列表中移除已融合的区域 self.areas.remove(a) + # 返回True表示有变化 return True + # 如果没有发生融合,返回False return False + # 定义一个内部函数,用于更新区域列表 def _update_areas(areas, from_op): + # 遍历from_op的所有输出操作 for op in from_op.to_ops: + # 获取操作对应的区域 a = self.area_map.get(op) + # 如果区域存在且不在当前区域列表中,则添加到列表中 if a in self.areas and a not in areas: areas.append(a) + # 初始化变化标志 changed = False + # 不断循环,直到没有变化为止 while True: for dom in self.areas: + # 如果当前区域有多个输出关系,则尝试进行融合 if len(dom.out_relations) > 1 and _do_fuse(list(dom.out_relations.keys())): changed = True break + # 如果没有发生任何变化,则跳出循环 else: break + + # 获取输入参数 inputs, _ = self.graph.deduce_parameters() + # 不断循环,直到没有变化为止 while True: for t in inputs: + # 初始化区域列表 areas = [] + # 更新区域列表 _update_areas(areas, t) + # 如果区域列表中有多个区域,则尝试进行融合 if len(areas) > 1 and _do_fuse(areas): changed = True break + # 如果没有发生任何变化,则跳出循环 else: break + + # 返回是否有变化 return changed def fuse_recom(self, selector): """Fuse recompute area to its user""" + # 遍历主导区域数组 for dominant in [self.recom_area, self.recom_user]: + # 使用选择器函数处理主导区域 result = selector(dominant) + # 如果选择器函数返回结果且第一个元素为真 if result and result[0]: + # 解包结果,获取需要融合的区域和额外信息 fuse_areas, _ = result + # 对需要融合的区域进行大小限制 fuse_areas = self.limit_area_size(dominant, fuse_areas) + # 如果没有需要融合的区域,则跳过当前循环 if not fuse_areas: continue + # 如果需要融合的第一个区域是主导区域之一 if fuse_areas[0] in 
[self.recom_area, self.recom_user]: + # 将recom_area融合到recom_user self.recom_user.fuse(self.recom_area) + # 设置区域映射 self.set_area_map(self.recom_area.ops, self.recom_user) + # 设置融合成功的标志 self.recom_res = True + # 返回成功标志 return True + # 如果没有成功融合,则返回失败标志 return False def index_op(self): """index op by order, the copied op share id with original op, for topo-sort""" + # 创建一个空字典,用于存储操作及其对应的索引 ids = {} + + # 遍历图中的操作,并为每个操作分配一个索引 for i, op in enumerate(self.graph.ops): + # 将操作作为键,索引作为值存储到ids字典中 ids[op] = i + + # 如果存在orig_op_map属性 if hasattr(self, 'orig_op_map'): + # 遍历orig_op_map中的键值对 for k, v in self.orig_op_map.items(): + # 如果v在ids中存在,则将k的索引设置为v的索引 ids[k] = ids.get(v) + + # 返回ids字典 return ids def to_subgraphs(self): """Transform op groups to subgraphs""" + # 获取操作索引 ids = self.index_op() + # 初始化子图列表 subgraphs = [] + # 初始化图模式列表 graphmodes = [] + # 遍历区域列表 for i, area in enumerate(self.areas): + # 根据操作索引对区域中的操作进行排序 area.ops.sort(key=ids.get) + # 创建子图并添加到子图列表中 subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), area.ops, area.stitch_info, area.recompute_ops)) + # 根据区域模式设置图模式并添加到图模式列表中 graphmodes.append("basic" if area.mode == self.Area.MODE_BASIC else "composite") return subgraphs, graphmodes def pattern_fuse(self, fuse_func=None): """fuse Areas by pattern repeatedly""" + # 删除fuse_func参数 del fuse_func + # 抛出异常,提示pattern_fuse()函数在当前类中未实现 raise Exception("pattern_fuse() is not implemented in {}".format(self.__class__.__name__)) def split(self): """Split graph by pattern""" + # 融合模式 self.pattern_fuse() + # 重新计算融合结果 self.recompute_fuse() + # 将图拆分为子图和模式 subgraphs, graphmodes = self.to_subgraphs() + # 返回子图和模式 return subgraphs, graphmodes def set_recompute(self, dom_area, ops, user_area): """set the recompute area and connect with other areas""" + # 将操作添加到recompute区域 self.recom_area.recompute_ops.extend(ops) # recom_area: set dom_op and correct ops length + # 设置dom_op并修正ops长度 patterns = list(PrimLib.iter_type(op) for op in ops) self.recom_area.pattern = max(patterns) for i, pat in enumerate(patterns): if pat == self.recom_area.pattern: + # 修正ops列表,使其长度与patterns相同 self.recom_area.ops = [ops[i]] * len(ops) break + # disconnect dom_area and user_area + # 断开dom_area和user_area的连接 self.dom_user_r = dom_area.out_relations[user_area] dom_area.out_relations.pop(user_area) user_area.in_relations.pop(dom_area) + # connect recom_area and user_area + # 连接recom_area和user_area user_area.in_relations[self.recom_area] = self.dom_user_r self.recom_area.out_relations[user_area] = self.dom_user_r + # connect recom_pre and recom_area + # 连接recom_pre和recom_area self.recom_pre = self.area_map.get(ops[0].inputs[0].op) if ops[0].inputs and ops[0].inputs[0].op else None if self.recom_pre is not None: self.recom_area.in_relations[self.recom_pre] = dom_area.in_relations[self.recom_pre] self.recom_pre.out_relations[self.recom_area] = dom_area.in_relations[self.recom_pre] + # set related areas + # 设置相关区域 self.recom_user = user_area self.recom_dom = dom_area self.recom_res = False def clear_recompute(self): """disconnect recom_area from other areas, and clear recom_area""" - self.recom_area.out_relations.clear() - self.recom_area.in_relations.clear() + # 断开recom_area与其他区域的连接 + self.recom_area.out_relations.clear() # 清除recom_area的输出关系 + self.recom_area.in_relations.clear() # 清除recom_area的输入关系 + + # 如果没有recom_res if not self.recom_res: + # 从recom_user的输入关系中移除recom_area self.recom_user.in_relations.pop(self.recom_area) + # 在recom_user的输入关系中添加recom_dom,并设置关系为dom_user_r self.recom_user.in_relations[self.recom_dom] = self.dom_user_r + # 
在recom_dom的输出关系中设置recom_user的关系为dom_user_r self.recom_dom.out_relations[self.recom_user] = self.dom_user_r + # 如果存在recom_pre if self.recom_pre: + # 从recom_pre的输出关系中移除recom_area self.recom_pre.out_relations.pop(self.recom_area) + + # 清除recom_area的操作 self.recom_area.ops.clear() + # 清除recom_area的重计算操作 self.recom_area.recompute_ops.clear() + # 将recom_area的原始操作映射更新到orig_op_map中 self.orig_op_map.update(self.recom_area.ori_op_map) + # 清除recom_area的原始操作映射 self.recom_area.ori_op_map.clear() def to_subgraph(self, dom): """Transform area to subgraphs""" + # 获取索引操作符 ids = self.index_op() + + # 初始化一个空列表用于存储操作 dom_ops = list() + + # 将dom的ops属性扩展到dom_ops列表中 dom_ops.extend(dom.ops) + + # 根据ids的get方法对dom_ops进行排序 dom_ops.sort(key=ids.get) + + # 初始化一个空列表用于存储子图 subgraph = [] + + # 使用dom_ops和指定的图名创建一个Graph对象,并将其赋值给subgraph subgraph = Graph('{}_area'.format(self.graph.name), dom_ops) + + # 返回创建好的子图 return subgraph def find_cheap_regions(self, dom): @@ -595,44 +977,57 @@ class GraphSplitByPattern: def _grow_region(region_ops, op, weight, inputs): """include op to region_ops if region grow""" # region successfully ends at inputs + # 如果op没有输入,则将其添加到region_ops中,并返回False表示区域不再增长 if not op.inputs: region_ops.append(op) return False, op, weight, True + # 如果op的输入是inputs中的一个,且只有一个输入,且op的类型小于等于BROADCAST,则将其添加到region_ops中 if op.inputs[0] in inputs and len(op.inputs) == 1 and \ PrimLib.iter_type(op) <= PrimLib.BROADCAST: region_ops.append(op) return False, op, weight, True # region fails to grow + # 如果weight大于20,或op的输入个数大于1,或op的类型大于BROADCAST,则区域增长失败 max_weight = 20 if weight > max_weight or len(op.inputs) > 1 or PrimLib.iter_type(op) > PrimLib.BROADCAST: return False, op, weight, False # region grows successfully + # 如果区域成功增长,则增加weight,并将op添加到region_ops中 weight = weight + 1 region_ops.append(op) return True, op.inputs[0].op, weight, False def _find_cheap_regions(dom): + # 从dom区域中提取子图 sub = self.to_subgraph(dom) + # 推导子图的输入和输出参数 inputs, outputs = sub.deduce_parameters() + # 如果没有输入,则返回空列表 if not inputs: return list() cheap_regions = [] for output in outputs: - # tensor should have user other than user_area to be fused + # tensor should have user other than user_area to be fused + # 如果输出的操作数少于2,则跳过 if len(output.to_ops) < 2: continue + # 初始化区域操作列表和区域增长标志 region_ops = [] grow = True candidate_op = output.op weight = 1 result = False + # 当区域增长时,不断尝试扩展区域 while grow: grow, candidate_op, weight, result = _grow_region(region_ops, candidate_op, weight, inputs) + # 如果区域成功扩展,则反转操作列表,并检查是否满足区域大小条件 if result: region_ops.reverse() # tensor size should equal or becomes larger(cast up, broadcast) + # 如果区域的第一个操作的输入张量大小大于最后一个操作的输出张量大小,则跳过 if region_ops[0].inputs and region_ops[0].inputs[0].get_size() > region_ops[-1].output.get_size(): continue + # 将满足条件的区域添加到cheap_regions列表中 cheap_regions.append(region_ops) return cheap_regions @@ -642,7 +1037,9 @@ class GraphSplitByPattern: """select the user area has only one edge to dom area""" def _get_edge_num(dom_area, user_area): - """get edge num between two areas""" + """ + 获取两个区域之间的边数 + """ dom_graph = self.to_subgraph(dom_area) _, dom_outputs = dom_graph.deduce_parameters() user_graph = self.to_subgraph(user_area) @@ -650,13 +1047,18 @@ class GraphSplitByPattern: return len(list(t for t in dom_outputs if t in user_inputs)) def _select_user_area(tail_tensor): + """ + 选择只有一个边到dom区域的user区域 + """ user_areas = [] for user_op in tail_tensor.to_ops: user_area = self.area_map.get(user_op) if user_area.pattern == PrimLib.RESHAPE: + # 如果user区域是reshape操作,则跳过 continue edge_num = 
_get_edge_num(self.area_map.get(tail_tensor.op), user_area) if edge_num == 1 and not user_area in user_areas: + # 如果edge数为1且user区域不在user_areas中,则添加到user_areas user_areas.append(user_area) return user_areas @@ -669,14 +1071,20 @@ class GraphSplitByPattern: """split the unfusing pattern by add recompute area""" def recompute_cheap_region(dom): + # 对每个廉价区域进行处理 for cheap_region in cheap_regions: + # 获取用户区域 user_areas = self.select_user_area(cheap_region[-1].output) if not user_areas: continue for user_area in user_areas: + # 设置重新计算区域 self.set_recompute(dom, cheap_region, user_area) + # 融合模式 self.pattern_fuse(self.fuse_recom) + # 清除重新计算区域 self.clear_recompute() + # 如果重新计算结果有效 if self.recom_res: return True return False @@ -687,13 +1095,17 @@ class GraphSplitByPattern: for dom in orig_areas: if dom not in self.areas or not dom.out_relations: continue + # 找到廉价区域 cheap_regions = self.find_cheap_regions(dom) + # 对当前区域进行廉价区域重新计算 if recompute_cheap_region(dom): recompute_suc = True return recompute_suc + # 如果启用了重新计算 if self.enable_recompute: while do_recompute_fuse(): + # 融合模式 self.pattern_fuse() @@ -705,10 +1117,14 @@ class GraphSplitGpu(GraphSplitByPattern): def get_default_mode(self, op): """Get default mode in GPU""" + # 判断操作是否为矩阵乘法 if op.prim == "MatMul": + # 如果是矩阵乘法,并且输入数据类型为float16且属性Akg存在,则返回MODE_COMPOSITE模式,否则返回MODE_BASIC模式 return self.Area.MODE_COMPOSITE if op.inputs[0].dtype == "float16" and op.attrs['Akg'] else \ self.Area.MODE_BASIC + # 获取操作的迭代类型 pattern = PrimLib.iter_type(op) + # 根据迭代类型返回相应的模式 return self.Area.MODE_BASIC if pattern == PrimLib.RESHAPE else self.Area.MODE_COMPOSITE def pattern_fuse(self, fuse_func=None): @@ -720,128 +1136,213 @@ class GraphSplitGpu(GraphSplitByPattern): return a.pattern > PrimLib.REDUCE or r > PrimLib.BROADCAST def _broadcast_depth(dom): + # 如果dom的模式不是Elemwise或Broadcast,或者输出关系数量不为1,或者dom是输出,或者操作数超过BROADCAST_FUSE_DEPTH,则返回空列表 if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1 or \ dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH: return [] + + # 获取dom的输出关系中的第一个元素及其对应的关系 a, r = list(dom.out_relations.items())[0] + + # 如果满足广播排除条件,或者a的输入关系数量不为1,则返回空列表 if _broadcast_pat_exclude(dom, a, r) or len(a.in_relations) != 1: return [] + + # 返回包含a的列表和False return [a], False def _broadcast_width(dom): + # 如果dom的模式不是PrimLib.ELEMWISE或PrimLib.BROADCAST,或者dom是输出节点,或者dom的操作数深度超过BROADCAST_FUSE_DEPTH,则返回空列表 if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \ dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH: return [] - fused = [] + fused = [] # 初始化一个空列表,用于存储符合条件的节点 + + # 遍历dom的输出关系 for a, r in dom.out_relations.items(): + # 如果满足以下任一条件,则返回空列表 if _broadcast_pat_exclude(dom, a, r) or not dom.check_acyclic(a) or \ (fused and fused[0].dom_op().output.shape != a.dom_op().output.shape): return [] + # 否则,将节点a添加到fused列表中 fused.append(a) + + # 返回fused列表和一个False值 return fused, False def _reduce_pat_exclude(_, a, r): + # 如果操作数a的操作数量大于设定的减少融合深度,则返回True if len(a.ops) > self.REDUCE_FUSE_DEPTH: return True + # 如果a的模式大于基本库中的元素级操作,或者r大于基本库中的减少操作,或者r等于基本库中的广播操作,则返回True return a.pattern > PrimLib.ELEMWISE or r > PrimLib.REDUCE or r == PrimLib.BROADCAST def _reduce_depth(dom): + # 检查dom的pattern是否为REDUCE,且dom的in_relations长度是否为1 if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1: + # 如果不是,则返回空列表 return [] + + # 获取dom的in_relations的第一个元素及其关系 a, r = list(dom.in_relations.items())[0] + + # 检查特定条件,如果满足,则返回空列表 if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \ 
_is_atomic_add_available(dom): # to evade the precision problem. + # 为避免精度问题 return [] + + # 检查是否满足特定条件或a的out_relations长度不为1,如果满足,则返回空列表 if _reduce_pat_exclude(dom, a, r) or len(a.out_relations) != 1: return [] + + # 返回包含a的列表和布尔值True return [a], True def _reduce_width(dom): + # 判断dom的模式是否为REDUCE if dom.pattern != PrimLib.REDUCE: return [] + fused = [] for a, r in dom.in_relations.items(): + # 判断dom的第一个操作的输入数据类型是否为float16,且a为输出,a的操作数大于等于10,且满足_is_atomic_add_available条件 if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \ _is_atomic_add_available(dom): + # 跳过,以避免精度问题 # to evade the precision problem. continue + + # 判断是否满足_reduce_pat_exclude条件,且a在dom中是无环的 if not _reduce_pat_exclude(dom, a, r) and a.check_acyclic(dom): + # 将满足条件的a添加到fused列表中 fused.append(a) + return fused, True def _is_atomic_add_available(dom): + # 检查是否存在Reduce操作 if any(("Reduce" in x.prim for x in dom.ops[1:])): return False + + # 获取第一个操作 op = dom.ops[0] + + # 检查是否有reduce_axis属性 if "reduce_axis" in op.attrs: reduce_axis = op.attrs["reduce_axis"] + # 检查是否有axis属性 elif "axis" in op.attrs: reduce_axis = [op.attrs["axis"]] else: + # 如果以上属性都不存在,抛出异常 raise Exception("For '{}', can not find the attr 'reduce_axis' or 'axis'".format(op.prim)) + + # 检查reduce_axis中是否包含输入张量的倒数第二个维度 if len(op.inputs[0].shape) - 1 in reduce_axis: + # 计算reduce操作的大小 reduce_size = prod_reduce(lambda x, y: x * y, (op.inputs[0].shape[i] for i in reduce_axis)) + # 判断reduce操作的大小是否大于等于1024 return reduce_size >= 1024 + # 如果没有包含倒数第二个维度,返回True return True def _reduce_output(dom): + # 判断dom的模式是否为REDUCE if dom.pattern != PrimLib.REDUCE: return [] + # 判断reduce操作的次数是否大于1 if reduce_nums(dom.ops) > 1: return [] + # 判断是否为原子加法可用 if _is_atomic_add_available(dom): return [] + # 判断是否为全归约 is_all_reduce = tensor_size(dom.ops[0].output) == 1 + # 排除大尺寸的全归约 # excluded large size all reduce if is_all_reduce and tensor_size(dom.ops[0].inputs[0]) > 1024 * 12: return [] fused = [] for a, r in dom.out_relations.items(): + # 判断a的模式和r的模式是否都小于等于BROADCAST,且dom的a节点无环,且a不在reduce_out_exclude列表中 if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \ dom.check_acyclic(a) and not dom.reduce_out_exclude(a): fused.append(a) return fused, False def _reduce_stitch(dom): + # 如果dom的模式不是PrimLib.REDUCE,则返回空列表 if dom.pattern != PrimLib.REDUCE: return [] + + # 如果dom的第一个操作的输出张量大小为1,则返回空列表 if tensor_size(dom.ops[0].output) == 1: return [] + + # 如果dom的第一个操作的第一个输入张量大小小于1024 * 12,则返回空列表 if tensor_size(dom.ops[0].inputs[0]) < 1024 * 12: return [] fused = [] + # 遍历dom的输出关系 for a, r in dom.out_relations.items(): + # 如果不满足拼接条件,则跳过当前循环 if not may_stitch(dom, a, r, 1024 * 8, 1024 * 1024): continue + + # 如果a的模式是PrimLib.REDUCE if a.pattern == PrimLib.REDUCE: + # 如果a和dom的reduce轴相同 if a.ops[0].attrs['reduce_axis'] == dom.ops[0].attrs['reduce_axis']: + # 将dom的输出名称添加到stitch_ops集合中 dom.stitch_info.stitch_ops.add(dom.ops[0].output.name) + # 将a添加到fused列表中 fused.append(a) + # 如果a的模式是PrimLib.BROADCAST elif a.pattern == PrimLib.BROADCAST: + # 将dom的输出名称添加到stitch_ops集合中 dom.stitch_info.stitch_ops.add(dom.ops[0].output.name) + # 将a添加到fused列表中 fused.append(a) + + # 返回fused列表和False return fused, False def _transpose(dom): + # 如果dom的操作数不为1或者第一个操作不是"Transpose",则返回空列表 if len(dom.ops) != 1 or dom.ops[0].prim != "Transpose": return [] fused = [] + + # 遍历dom的输入关系 for a, _ in dom.in_relations.items(): + # 如果输入的操作模式小于等于PrimLib.BROADCAST, + # 并且输入的操作是无环的,并且输入的操作数小于等于TRANSPOSE_FUSE_DEPTH,则将其添加到fused列表中 if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and len(a.ops) <= 
self.TRANSPOSE_FUSE_DEPTH: fused.append(a) + # 返回fused列表和一个表示成功的布尔值 return fused, True def _strided_slice(dom): + # 判断操作是否为 StridedSlice if dom.dom_op().prim != "StridedSlice": return [] + + # 初始化 fused 列表 fused = [] + + # 遍历 dom 的输入关系 for a, _ in dom.in_relations.items(): + # 判断输入节点是否满足条件 if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \ len(a.out_relations) == 1 and not a.is_output: + # 如果满足条件,则将其添加到 fused 列表中 fused.append(a) + + # 返回 fused 列表和布尔值 True return fused, True def _gather_output(dom, reduce_fusion=False): @@ -860,37 +1361,52 @@ class GraphSplitGpu(GraphSplitByPattern): Returns: Boolean. Whether this operator should be excluded. """ + # 获取reduce操作符的reduce轴 axis = op.attrs["reduce_axis"] + # 如果reduce轴是整数,则将其转换为列表 if isinstance(axis, int): axis = [axis] + # 获取输入数据的形状长度 in_shape_len = len(op.inputs[0].shape) + # 对每个reduce轴进行转换,如果是负数则加上输入形状的长度 for i, dim in enumerate(axis): + # 如果dim是负数,则将其转换为正数 axis[i] = in_shape_len + dim if dim < 0 else dim + + # 初始化一个空列表,用于存储经过过滤的reduce轴 fix_axis = [] + # 遍历每个reduce轴 for ax in axis: + # 如果当前轴的长度为1,则跳过 if op.inputs[0].shape[ax] == 1: continue + # 否则,将当前轴添加到fix_axis列表中 fix_axis.append(ax) + + # 返回fix_axis和axis_list的交集是否非空 return bool(set(fix_axis) & set(axis_list)) def _bfs_visit(start_op, start_prims, total_ops, end_ops, gather_axis): - consisten_shape = start_op.output.shape - visited = [] - op_queue = [start_op] + consisten_shape = start_op.output.shape # 获取起始操作的输出形状 + visited = [] # 用于记录已经访问过的操作 + op_queue = [start_op] # 初始化操作队列,包含起始操作 def _early_stop(cur_op): + # 提前停止函数 if cur_op in end_ops: - # If reduce the gather axis, stop early for not fusion. + # 如果当前操作在结束操作中 + # 如果减少聚合轴,则不融合,提前停止 if cur_op.prim == "ReduceSum" and _reduce_exclude(cur_op, gather_axis): return True else: + # 如果当前操作不在起始操作中或形状不一致 if (cur_op.prim in start_prims and cur_op != start_op) or \ consisten_shape != cur_op.output.shape: return True return False while op_queue: - tmp_queue = [] + tmp_queue = [] # 临时队列,用于存储下一层的操作 for op in op_queue: if op in visited or not op in total_ops: continue @@ -899,9 +1415,9 @@ class GraphSplitGpu(GraphSplitByPattern): if op in end_ops: continue for to_op in op.output.to_ops: - tmp_queue.append(to_op) - visited.append(op) - op_queue = tmp_queue + tmp_queue.append(to_op) # 将当前操作的输出操作添加到临时队列中 + visited.append(op) # 将当前操作标记为已访问 + op_queue = tmp_queue # 更新操作队列为临时队列 return True def _shape_consistent(start_prims, end_prims, source, target): @@ -911,30 +1427,41 @@ class GraphSplitGpu(GraphSplitByPattern): When fusing ReduceSum, first check if TensorScatterAdd and/or UnsortedSegmentSum has already been fused, if so, stop ReduceSum fusion. 
""" + # 获取source和target的所有操作 total_ops = source.ops + target.ops + # 获取所有操作的prim集合 op_prims_set = {op.prim for op in total_ops} + # 如果需要融合ReduceSum,并且TensorScatterAdd和/或UnsortedSegmentSum已经被融合,则不融合ReduceSum if reduce_fusion and (len({"TensorScatterAdd", "UnsortedSegmentSum"} & op_prims_set) >= 1): return False + + # 获取source的起始操作 start_ops = [] for op in source.ops: if op.prim in start_prims: start_ops.append(op) + + # 获取total_ops中的结束操作 end_ops = [] for op in total_ops: if op.prim in end_prims and not any((to_op in total_ops for to_op in op.output.to_ops)): end_ops.append(op) + # 遍历start_ops中的每一个操作 for start_op in start_ops: + # 获取操作的gather_axis属性 gather_axis = start_op.attrs.get("axis", None) if gather_axis is None: - # For GatherNd + # 对于GatherNd gather_axis = list(range(len(start_op.inputs[1].shape))) elif isinstance(gather_axis, int): gather_axis = [gather_axis] + # 调用_bfs_visit函数检查形状一致性 is_consistent = _bfs_visit(start_op, start_prims, total_ops, end_ops, gather_axis) if not is_consistent: return False + return True appected_areas = {"ReduceSum"} if reduce_fusion else {"TensorScatterAdd", "UnsortedSegmentSum"} @@ -947,11 +1474,13 @@ class GraphSplitGpu(GraphSplitByPattern): def _broadcast_tot(dom): """Fuse rule for TensorScatterAdd and UnsortedSegmentSum.""" def _same_input(op1, op2): + # 判断两个操作的输入是否有交集 return bool(set(op1.inputs) & set(op2.inputs)) if len(dom.ops) != 1: return [] + # 只融合 TensorScatterAdd 的第一个输入和 UnsortedSegmentSum 的前两个输入 # Only fuse the first input for `TensorScatterAdd`` and the first and second input for `UnsortedSegmentSum`. fuse_arg = {"TensorScatterAdd": slice(1, None), "UnsortedSegmentSum": slice(0, 2)} arg_idx = fuse_arg.get(dom.dom_op().prim, -1) @@ -960,9 +1489,11 @@ class GraphSplitGpu(GraphSplitByPattern): fuse_tensor = dom.dom_op().inputs[arg_idx] for a, _ in dom.in_relations.items(): + # 规则1:类型相同且有至少一个相同输入 # Rule 1: Same type with at lease one same input. if a.dom_op().prim == dom.dom_op().prim and _same_input(dom.dom_op(), a.dom_op()): return [a], True + # 规则2:在指定位置输入中融合操作(reshape/elementwise/broadcast) # Rule 2: Fuse op(reshape/elementwise/broadcast) in specified position inputs. 
if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a) and \ any((op.output in fuse_tensor for op in a.ops)): @@ -971,74 +1502,121 @@ class GraphSplitGpu(GraphSplitByPattern): def _broadcast_onehot(dom, fwd=True): """Fuse rule for OneHot.""" + # 判断当前操作的原始操作是否为OneHot if dom.dom_op().prim != "OneHot": return None fused = [] + # 根据fwd参数决定是遍历输入关系还是输出关系 neighbours = dom.in_relations.items() if fwd else dom.out_relations.items() for a, _ in neighbours: + # 判断当前关系是否满足广播模式 if a.pattern <= PrimLib.BROADCAST: + # 判断当前关系是否满足无环、单一输出且不是输出节点(如果fwd为True) + # 或者判断当前关系的来源节点是否满足无环(如果fwd为False) if (fwd and a.check_acyclic(dom) and len(a.out_relations) == 1 and not a.is_output) or \ (not fwd and dom.check_acyclic(a)): + # 将满足条件的关系添加到fused列表中 fused.append(a) return fused, fwd def _h_broadcast(dom, a): + # 判断dom的模式是否大于PrimLib.BROADCAST if dom.pattern > PrimLib.BROADCAST: return [] + # 判断a的模式是否小于等于PrimLib.BROADCAST,并且dom的第一个操作的输出形状与a的第一个操作的输出形状是否相同 return a.pattern <= PrimLib.BROADCAST and dom.ops[0].output.shape == a.ops[0].output.shape def _h_reduce(dom, a): + # 如果dom的模式不是PrimLib.REDUCE或者dom的拼接信息包含拼接操作,则返回空列表 if dom.pattern != PrimLib.REDUCE or dom.stitch_info.stitch_ops: return [] + + # 获取dom的操作 dom_op = dom.ops[0] + # 如果dom的操作不是reduce操作或者dom是原子加法可用的情况,则返回空列表 if not PrimLib.is_reduce(dom_op) or _is_atomic_add_available(dom): return [] + + # 获取a的操作 op = a.ops[0] + # 返回a的模式是PrimLib.REDUCE,a的拼接信息不包含拼接操作,a的操作是reduce操作, + # dom的输入和a的输入形状相同,且dom和a的reduce轴相同 return a.pattern == PrimLib.REDUCE and not a.stitch_info.stitch_ops and \ PrimLib.is_reduce(op) and dom_op.inputs[0].shape == op.inputs[0].shape and \ dom_op.attrs.get("reduce_axis") == op.attrs.get("reduce_axis") def _h_opaque(dom, a): + # 检查dom的第一个操作是否是StridedSlice if dom.ops[0].prim not in {"StridedSlice"}: + # 如果不是,则返回空列表 return [] + # 返回以下条件的逻辑与结果: + # a的第一个操作是否与dom的第一个操作相同 + # dom的第一个操作的输出形状是否与a的第一个操作的输出形状相同 + # dom的第一个操作的第一个输入的形状是否与a的第一个操作的第一个输入的形状相同 return a.ops[0].prim == dom.ops[0].prim and dom.ops[0].output.shape == a.ops[0].output.shape and \ dom.ops[0].inputs[0].shape == a.ops[0].inputs[0].shape def _fuse_loop(): changed = True + # 当changed为True时,进入循环 while changed: + # 尝试融合reshape模式 changed = self.fuse(CommonPattern.reshape) + # 尝试融合assign模式,如果失败则保留原来的changed值 changed = self.fuse(CommonPattern.assign) or changed + # 尝试融合elemwise_depth模式,如果失败则保留原来的changed值 changed = self.fuse(CommonPattern.elemwise_depth) or changed + # 尝试融合elemwise_width模式,如果失败则保留原来的changed值 changed = self.fuse(CommonPattern.elemwise_width) or changed + # 尝试融合_reduce_depth模式,如果失败则保留原来的changed值 changed = self.fuse(_reduce_depth) or changed + # 尝试融合_reduce_width模式,如果失败则保留原来的changed值 changed = self.fuse(_reduce_width) or changed + # 尝试融合_broadcast_depth模式,如果失败则保留原来的changed值 changed = self.fuse(_broadcast_depth) or changed + # 尝试融合_broadcast_width模式,如果失败则保留原来的changed值 changed = self.fuse(_broadcast_width) or changed + # 尝试融合_strided_slice模式,如果失败则保留原来的changed值 changed = self.fuse(_strided_slice) or changed + # 尝试融合正向的_broadcast_onehot模式,如果失败则保留原来的changed值 changed = self.fuse(partial(_broadcast_onehot, fwd=True)) or changed + # 尝试融合反向的_broadcast_onehot模式,如果失败则保留原来的changed值 changed = self.fuse(partial(_broadcast_onehot, fwd=False)) or changed + # 尝试融合_broadcast_tot模式,如果失败则保留原来的changed值 changed = self.fuse(_broadcast_tot) or changed + # 尝试融合不启用reduce_fusion的_gather_output模式,如果失败则保留原来的changed值 changed = self.fuse(partial(_gather_output, reduce_fusion=False)) or changed + # 尝试融合启用reduce_fusion的_gather_output模式,如果失败则保留原来的changed值 changed = self.fuse(partial(_gather_output, 
reduce_fusion=True)) or changed + # 尝试融合_reduce_output模式,如果失败则保留原来的changed值 changed = self.fuse(_reduce_output) or changed + # 如果启用了stitch_fusion,则尝试融合_reduce_stitch模式,如果失败则保留原来的changed值 if self.enable_stitch_fusion: changed = self.fuse(_reduce_stitch) or changed + # 融合_transpose模式 self.fuse(_transpose) + # 如果启用了horizontal_fusion,则进行以下融合操作 if self.enable_horizontal_fusion: + # 融合_h_broadcast模式 self.hfuse(_h_broadcast) + # 融合_h_reduce模式 self.hfuse(_h_reduce) + # 融合_h_opaque模式 self.hfuse(_h_opaque) def _fuse_once(fuse_func): + # 如果满足任一条件,则直接返回 if fuse_func(CommonPattern.reshape) or fuse_func(CommonPattern.elemwise_depth) or \ fuse_func(CommonPattern.elemwise_width) or fuse_func(_reduce_depth) or \ fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or fuse_func(_broadcast_width): return + # 如果满足任一条件,则直接返回 if fuse_func(_reduce_output) or (self.enable_stitch_fusion and fuse_func(_reduce_stitch)): return + # 调用 fuse_func 对 _transpose 进行处理 fuse_func(_transpose) return @@ -1057,6 +1635,7 @@ class GraphSplitAscend(GraphSplitByPattern): """Get default mode for Ascend""" def _dtype_same(tensors): + # 判断所有张量的数据类型是否相同 dtype = tensors[0].dtype for tensor_ in tensors: if tensor_.dtype != dtype: @@ -1064,25 +1643,31 @@ class GraphSplitAscend(GraphSplitByPattern): return True if op.prim == "MatMul": + # 如果是矩阵乘法操作 if op.inputs[0].dtype == "float16" and not _dtype_same(op.inputs): + # 如果输入张量中存在不同数据类型的张量,则返回复合模式 return self.Area.MODE_COMPOSITE if op.prim in ("Tile", "BroadcastTo", "ExpandDims"): + # 如果是平铺、广播到、扩展维度操作 return self.Area.MODE_COMPOSITE return self.Area.MODE_BASIC def pattern_fuse(self, fuse_func=None): """fuse Areas by pattern""" + # 判断某个运算是否可能在多核上运行 def _likely_multicore(dom): op = dom.dom_op() iter_size = tensor_size(op.output if not PrimLib.is_reduce(op) else op.inputs[0]) return iter_size > 1024 + # 判断某个运算是否应该被排除在广播融合之外 def _broadcast_pat_exclude(dom, a, r): if _likely_multicore(a) and (dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH): return True return a.pattern > PrimLib.REDUCE or r > PrimLib.BROADCAST + # 获取广播融合深度 def _broadcast_depth(dom): if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1: return [] @@ -1091,6 +1676,7 @@ class GraphSplitAscend(GraphSplitByPattern): return [] return [a], False + # 获取广播融合宽度 def _broadcast_width(dom): if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST): return [] @@ -1102,6 +1688,7 @@ class GraphSplitAscend(GraphSplitByPattern): fused.append(a) return fused, False + # 判断某个运算是否应该被排除在归约融合之外 def _reduce_pat_exclude(dom, a, r): if len(a.ops) > self.REDUCE_FUSE_DEPTH: return True @@ -1110,6 +1697,7 @@ class GraphSplitAscend(GraphSplitByPattern): return True return a.pattern > PrimLib.BROADCAST or r > PrimLib.REDUCE + # 获取归约融合深度 def _reduce_depth(dom): if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1: return [] @@ -1118,6 +1706,7 @@ class GraphSplitAscend(GraphSplitByPattern): return [] return [a], True + # 获取归约融合宽度 def _reduce_width(dom): if dom.pattern != PrimLib.REDUCE: return [] @@ -1127,6 +1716,7 @@ class GraphSplitAscend(GraphSplitByPattern): fused.append(a) return fused, True + # 获取矩阵乘法融合深度 def _matmul_depth(dom): if dom.dom_op().prim != "MatMul" and dom.dom_op().prim != "BatchMatMul": return [] @@ -1139,6 +1729,7 @@ class GraphSplitAscend(GraphSplitByPattern): fused.append(a) return fused, False + # 获取归约输出融合 def _reduce_output(dom): if dom.pattern != PrimLib.REDUCE: return [] @@ -1152,6 +1743,7 @@ class GraphSplitAscend(GraphSplitByPattern): fused.append(a) return fused, False + 
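# Both the GPU `_fuse_loop` above and the Ascend `_fuse_loop` further below follow the same
# fixed-point driver: re-run every fusion pass until a whole round reports no change.
# A minimal standalone sketch of that driver pattern (toy pass names, not the real fusion rules):
def run_until_stable(passes, apply_pass):
    """Re-run every pass, in order, until a full round makes no change."""
    changed = True
    while changed:
        changed = False
        for p in passes:
            # keep `changed` sticky: any successful pass triggers another full round
            changed = apply_pass(p) or changed

# toy usage: each "pass" removes at most one matching item from a shared worklist
worklist = ["reshape", "elemwise", "reduce", "elemwise"]
def toy_apply(name):
    if name in worklist:
        worklist.remove(name)
        return True
    return False

run_until_stable(["reshape", "elemwise", "reduce"], toy_apply)
assert not worklist  # emptied after three rounds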
# 获取归约拼接融合 def _reduce_stitch(dom): if dom.pattern != PrimLib.REDUCE: return [] @@ -1173,10 +1765,11 @@ class GraphSplitAscend(GraphSplitByPattern): fused.append(a) return fused, False + # 判断转置数据操作是否支持某个运算的融合 def _transdata_pattern_support(dom, a): transdata_op = dom.dom_op() - # Currently, if transdata has the pad, it is not used to fuse + # 如果转置数据操作有填充,则不用于融合 def _has_pad(): res = False input_shape = transdata_op.inputs[0].shape @@ -1197,7 +1790,7 @@ class GraphSplitAscend(GraphSplitByPattern): if a.dom_op().prim == "MatMul" and len(dom.ops) == 1: return True - # reshape/elewise/broadcast + transdata + # 重塑/元素级操作/广播 + 转置数据 if a.pattern <= PrimLib.BROADCAST and len(dom.ops) == 1: op_attrs = dom.dom_op().attrs if 'src_format' not in op_attrs.keys() \ @@ -1207,12 +1800,13 @@ class GraphSplitAscend(GraphSplitByPattern): src_format, dst_format = op_attrs['src_format'], op_attrs['dst_format'] if src_format == DF.FRAC_NZ and dst_format in (DF.DEFAULT, DF.NCHW): return True - # For the Default/NCHW to FRAC_NZ, currently only the Cast+Transdata is supported + # 对于Default/NCHW到FRAC_NZ的转换,目前仅支持Cast+Transdata if src_format in (DF.DEFAULT, DF.NCHW) and dst_format == DF.FRAC_NZ \ and len(a.ops) == 1 and a.dom_op().prim == "Cast" and not a.is_output: return True return False + # 获取转置数据融合 def _transdata(dom): if dom.dom_op().prim != "TransData": return [] @@ -1222,6 +1816,7 @@ class GraphSplitAscend(GraphSplitByPattern): fused.append(a) return fused, True + # 循环进行融合操作 def _fuse_loop(): changed = True while changed: @@ -1239,6 +1834,7 @@ class GraphSplitAscend(GraphSplitByPattern): changed = self.fuse(_reduce_stitch) or changed self.fuse(_transdata) + # 执行一次融合操作 def _fuse_once(fuse_func): if fuse_func(CommonPattern.reshape) or fuse_func(CommonPattern.elemwise_depth) or \ fuse_func(CommonPattern.elemwise_width) or fuse_func(_reduce_depth) or \ @@ -1254,9 +1850,14 @@ class GraphSplitAscend(GraphSplitByPattern): def split(graph, target, flags): """Split graph""" + # 初始化结果变量 result = None + # 如果目标设备是"cuda" if target == "cuda": + # 使用GraphSplitGpu类对图进行分割 result = GraphSplitGpu(graph, flags).split() else: + # 否则,使用GraphSplitAscend类对图进行分割 result = GraphSplitAscend(graph, flags).split() + # 返回分割结果 return result diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/model.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/model.py index 23701f97..a84d421b 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/model.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/model.py @@ -24,39 +24,61 @@ class Utils: @staticmethod def get_attr_type(attr): """Get attr type""" + # 判断attr是否为bool类型 if isinstance(attr, bool): return 'bool' + # 判断attr是否为str类型 if isinstance(attr, str): return 'str' + # 判断attr是否为int类型 if isinstance(attr, int): return 'int' + # 判断attr是否为float类型 if isinstance(attr, float): return 'float' + # 判断attr是否为list或tuple类型 if isinstance(attr, (list, tuple)): + # 判断attr是否为空 if not attr: raise ValueError("attr is invalid: the length of attr is 0") + # 判断attr的第一个元素是否为int类型 if isinstance(attr[0], int): return 'listInt' + # 判断attr的第一个元素是否为str类型 if isinstance(attr[0], str): return 'listStr' + # 如果attr的类型不在支持的列表中,则抛出异常 raise ValueError("attr {} type {} is not in supported list ['bool', 'str', 'int', 'float', 'int' list, " "'str' list]".format(attr, type(attr))) class DataFormat: """DataFormat""" + # 默认格式 DEFAULT = "DefaultFormat" + # NC1KHKWHWC0格式 NC1KHKWHWC0 = "NC1KHKWHWC0" + # ND格式 ND = "ND" + # 
NCHW格式 NCHW = "NCHW" + # NHWC格式 NHWC = "NHWC" + # HWCN格式 HWCN = "HWCN" + # NC1HWC0格式 NC1HWC0 = "NC1HWC0" + # FRAC_Z格式 FRAC_Z = "FracZ" + # FRAC_NZ格式 FRAC_NZ = "FRACTAL_NZ" + # C1HWNCOC0格式 C1HWNCOC0 = "C1HWNCoC0" + # NC1HWC0_C04格式 NC1HWC0_C04 = "NC1HWC0_C04" + # FRACTAL_Z_C04格式 FRACTAL_Z_C04 = "FRACTAL_Z_C04" + # NDHWC格式 NDHWC = "NDHWC" def __init__(self): @@ -65,29 +87,47 @@ class DataFormat: class DataType: """Data Type""" + # 浮点型 FLOAT = "float" + # 半精度浮点型 FLOAT16 = "float16" + # 单精度浮点型 FLOAT32 = "float32" + # 双精度浮点型 FLOAT64 = "float64" + # 整型 INT = "int" + # 8位整型 INT8 = "int8" + # 16位整型 INT16 = "int16" + # 32位整型 INT32 = "int32" + # 64位整型 INT64 = "int64" + # 无符号整型 UINT = "uint" + # 8位无符号整型 UINT8 = "uint8" + # 16位无符号整型 UINT16 = "uint16" + # 32位无符号整型 UINT32 = "uint32" + # 64位无符号整型 UINT64 = "uint64" + # 布尔型 BOOL = "bool" + # 初始化函数 def __init__(self): + # 无需执行任何操作 pass class PrimLib: """Prim lib""" + # 定义PrimLib类中的常量 UNKNOWN = 0 RESHAPE = 1 ELEMWISE = 2 @@ -102,53 +142,73 @@ class PrimLib: """Prim""" def __init__(self, iter_type, calibrate=1, relation_func=None): + # 初始化Prim类,设置iter_type、calibrate和relation_func属性 self.iter_type = iter_type self.calibrate = calibrate self.relation_func = relation_func if relation_func is None: + # 如果relation_func为None,则设置默认的relation_func self.relation_func = lambda *x: self.default_relation_func[iter_type](self, *x) def default_reshape_relation(self, op, input_idx): """Process reshape relation""" + # 处理reshape关系 axis_relation, elem_relation = self.unknown_relation(op, input_idx) + # 将elem_relation设置为PrimLib.RESHAPE elem_relation = [PrimLib.RESHAPE] * len(elem_relation) return axis_relation, elem_relation def default_elemwise_broadcast_relation(self, op, input_idx): """Process elemwise and broadcast relation""" + # 处理elemwise和broadcast关系 out_shape = op.output.shape in_shape = op.inputs[input_idx].shape + # 如果输出形状的长度小于输入形状的长度,则抛出异常 if len(out_shape) < len(in_shape): raise ValueError("For '{}', the input/output size is abnormal, as the length of output shape{} is less " "than the length of input shape{}".format(op.prim, out_shape, in_shape)) axis_relation, elem_relation = [], [] + # 计算输出形状和输入形状的长度差 delta = len(out_shape) - len(in_shape) if delta > 0: + # 如果输出形状的长度大于输入形状的长度,则在axis_relation和elem_relation中添加None for i in range(0, delta): axis_relation.append(None) elem_relation.append(None) + # 遍历输入形状的每个元素 for i, _ in enumerate(in_shape): + # 在axis_relation中添加当前元素的索引 axis_relation.append(i) + # 如果输出形状的对应元素等于输入形状的对应元素,则elem_relation添加PrimLib.ELEMWISE,否则添加PrimLib.BROADCAST elem_relation.append( PrimLib.ELEMWISE if out_shape[i + delta] == in_shape[i] else PrimLib.BROADCAST) return axis_relation, elem_relation def default_reduce_relation(self, op, input_idx): """Process reduce relation""" + # 处理reduce关系 axis_relation, elem_relation = self.default_elemwise_broadcast_relation(op, input_idx) + # 遍历reduce_axis中的每个元素 for i in op.attrs['reduce_axis']: + # 将elem_relation中对应元素的值设置为PrimLib.REDUCE elem_relation[i] = PrimLib.REDUCE return axis_relation, elem_relation def unknown_relation(self, op, input_idx): """Process unknown relation""" + # 获取输出和输入的形状 out_shape = op.output.shape in_shape = op.inputs[input_idx].shape + # 获取所有可能的轴关系 all_relation = list(range(len(in_shape))) + # 初始化轴关系列表 axis_relation = [all_relation for i in range(0, len(out_shape))] + # 初始化元素关系列表 elem_relation = [PrimLib.UNKNOWN for i in range(0, len(out_shape))] + # 返回轴关系和元素关系 return axis_relation, elem_relation + # 默认的关系函数列表 default_relation_func = [ unknown_relation, default_reshape_relation, @@ -158,6 
+218,7 @@ class PrimLib: unknown_relation, ] + # 定义基本操作 primtives = { 'Add': Prim(ELEMWISE), 'Abs': Prim(ELEMWISE), @@ -239,7 +300,9 @@ class PrimLib: @classmethod def get_prim(cls, op): """Get op primtive""" + # 从cls.primtives中获取op.prim对应的prim prim = cls.primtives.get(op.prim, None) + # 如果prim为None,则打印警告信息,并返回cls.default_primtive if prim is None: print('[WARN] primtive is not registered: ' + op.prim) prim = cls.default_primtive @@ -248,50 +311,65 @@ class PrimLib: @classmethod def input_relation(cls, op, input_idx): """Get op's input_relation according to input_idx""" + # 调用cls.get_prim(op)获取op对应的prim,然后调用prim的relation_func方法获取op的input_relation return cls.get_prim(op).relation_func(op, input_idx) @classmethod def iter_type(cls, op): """Get op's iter type""" + # 调用cls.get_prim(op)获取op对应的prim,然后返回prim的iter_type return cls.get_prim(op).iter_type @classmethod def is_reduce(cls, op): """Check whether op's iter type is reduce""" + # 调用cls.get_prim(op)获取op对应的prim,然后判断prim的iter_type是否为cls.REDUCE return cls.get_prim(op).iter_type == cls.REDUCE @classmethod def calibrate_iter_size(cls, op, iter_size): """Get calibrate_iter_size""" + # 调用cls.get_prim(op)获取op对应的prim,然后返回prim的calibrate乘以iter_size return cls.get_prim(op).calibrate * iter_size @classmethod def dtype_bytes(cls, dtype): """Get dtype bytes""" + # 初始化bits和unit为1 bits, unit = 1, 1 + # 从dtype的最后一个字符开始向前遍历 for i in range(len(dtype) - 1, 0, -1): + # 如果当前字符是数字,则将bits加上当前字符对应的数字乘以unit,并将unit乘以10 if dtype[i].isdecimal(): bits += int(dtype[i]) * unit unit *= 10 + # 如果当前字符不是数字,则跳出循环 else: break + # 返回bits除以8的结果 return bits // 8 @classmethod def inplace_reuse(cls, op, input_idx, start_axis=0): """Check whether op is inplace reuse""" + # 如果op.output.dtype的字节数大于op.inputs[input_idx].dtype的字节数,则返回False if cls.dtype_bytes(op.output.dtype) > cls.dtype_bytes(op.inputs[input_idx].dtype): return False + # 调用cls.get_prim(op)获取op对应的prim,然后调用prim的relation_func方法获取op的input_relation _, elem_relation = cls.get_prim(op).relation_func(op, input_idx) + # 从start_axis开始遍历elem_relation for i in range(start_axis, len(elem_relation)): + # 如果elem_relation中的元素不等于cls.ELEMWISE,则返回False if elem_relation[i] != cls.ELEMWISE: return False + # 如果以上条件都不满足,则返回True return True class Tensor: """Tensor""" + # 参数类型常量 PARA_NONE = 0 PARA_INPUT = 1 PARA_OUTPUT = 2 @@ -303,6 +381,7 @@ class Tensor: self.members = [leader] def __init__(self, name, shape, dtype, data_format=DataFormat.DEFAULT, para_type=0): + # 初始化Tensor对象 self.name = name self.shape = shape self.dtype = dtype @@ -313,13 +392,16 @@ class Tensor: self.buddy = None def __str__(self): + # 返回Tensor对象的字符串表示 return self.name + str(list(self.shape)) def __repr__(self): + # 返回Tensor对象的字符串表示 return "%s.%s%s" % (self.name, self.dtype, str(list(self.shape))) def get_size(self): """Get size""" + # 获取Tensor对象的大小 size = PrimLib.dtype_bytes(self.dtype) for i in self.shape: size *= i @@ -327,6 +409,7 @@ class Tensor: def add_buddy(self, tensor): """Add buddy""" + # 添加buddy if self.buddy is None: self.buddy = self.Buddy(self) self.buddy.members.append(tensor) @@ -337,6 +420,7 @@ class Value: """Value""" def __init__(self, name, dtype, value, data_format=DataFormat.DEFAULT): + # 初始化Value对象 self.name = name self.shape = [1] self.dtype = dtype @@ -344,14 +428,17 @@ class Value: self.data_format = data_format def __str__(self): + # 返回Value对象的字符串表示 return self.name + str(list(self.shape)) def __repr__(self): + # 返回Value对象的字符串表示 return "%s.%s%s" % (self.name, self.dtype, str(list(self.shape))) @staticmethod def get_size(): """Get size""" + # 
获取Value对象的大小 return 1 @@ -359,23 +446,31 @@ class Operator: """Operator""" def __init__(self, primtive, inputs, output, attrs): + # 初始化Operator对象 self.prim = primtive self.inputs = inputs self.output = output self.attrs = attrs + # 将当前Operator对象添加到每个输入的to_ops列表中 for t in inputs: t.to_ops.append(self) + # 如果输出的op属性为None,则将当前Operator对象赋值给输出的op属性 if output.op is None: output.op = self + # 初始化all_inputs列表,用于存储Tensor输入和Value输入 self.all_inputs = [] # include Tensor inputs and Value inputs. def __str__(self): + # 将self.all_inputs中的元素转换为字符串,并用逗号连接起来 args = ', '.join((str(t) for t in self.all_inputs)) + # 构造表达式字符串 expr = "%s = %s.%s(%s) id:%s" % ( str(self.output), self.prim, self.output.dtype, args, id(self)) + # 如果self.attrs不为空,则返回表达式字符串和self.attrs的字符串连接 return expr if not self.attrs else '%s // %s' % (expr, str(self.attrs)) def __repr__(self): + # 返回self的字符串表示 return str(self) @@ -383,6 +478,7 @@ class Graph: """Graph""" def __init__(self, name, ops, stitch_info=None, recompute_ops=None): + # 初始化Graph对象 self.name = name self.ops = ops # in topo order, can not use set self.inputs = [] @@ -393,10 +489,12 @@ class Graph: def set_processor(self, processor): """Set processor""" + # 设置处理器 self.processor = processor def add(self, ops): """Add ops""" + # 添加操作 if isinstance(ops, Operator): self.ops.append(ops) else: @@ -404,101 +502,148 @@ class Graph: def extract_subgraph(self, graph_name, tensor_names, difference=False): """Extract subgraph from this graph""" + # 从当前图中提取子图 graph = Graph(graph_name, []) outputs = set(tensor_names) if difference: + # 如果difference为True,则提取不在outputs中的操作 for op in self.ops: if op.output.name not in outputs: graph.add(op) else: + # 如果difference为False,则提取在outputs中的操作 for op in self.ops: if op.output.name in outputs: graph.add(op) outputs.remove(op.output.name) + # 如果outputs中还有元素,则抛出异常 for name in outputs: raise ValueError("Invalid input tensor : {}, can not find it in graph".format(name)) return graph def deduce_parameters(self): """Deduce parameters""" + # 初始化输入和输出列表 inputs, outputs = [], [] + # 遍历所有操作 for op in self.ops: + # 遍历操作的所有输入 for t in op.inputs: + # 如果输入不在输入列表中,且输入的操作不在操作列表中,则将输入添加到输入列表中 if t not in inputs and t.op not in self.ops: inputs.append(t) + # 如果操作输出已经在输出列表中,则跳过 if op.output in outputs: continue + # 如果操作输出是输出参数类型,或者操作输出没有后续操作,则将操作输出添加到输出列表中 if op.output.para_type == Tensor.PARA_OUTPUT or not op.output.to_ops: outputs.append(op.output) continue + # 如果操作输出的后续操作不在操作列表中,则将操作输出添加到输出列表中 if any((succ not in self.ops for succ in op.output.to_ops)): outputs.append(op.output) + # 如果有指定的输入,则将指定的输入赋值给输入列表 if self.inputs: inputs = self.inputs + # 如果有指定的输出,则将指定的输出赋值给输出列表 if self.outputs: outputs = self.outputs + # 返回输入和输出列表 return inputs, outputs def __str__(self): + # 调用deduce_parameters方法获取输入和输出列表 inputs, outputs = self.deduce_parameters() + # 将输入列表转换为字符串 para_str = ', '.join((repr(t) for t in inputs)) + # 将输出列表转换为字符串 out_str = ', '.join((repr(t) for t in outputs)) + # 初始化行列表 lines = [] + # 添加操作名称、输入和输出到行列表中 lines.append("%s(%s) -> %s {" % (self.name, para_str, out_str)) + # 如果有拼接信息,则添加拼接操作和拼接原子操作到行列表中 if self.stitch_info: if self.stitch_info.stitch_ops: lines.append(' stitch -> ' + str(self.stitch_info.stitch_ops)) if self.stitch_info.stitch_atomic_ops: lines.append(' stitch_atomic_ops-> ' + str(self.stitch_info.stitch_atomic_ops)) + # 遍历所有操作,将操作添加到行列表中 for op in self.ops: lines.append(' ' + str(op)) + # 添加结束符号到行列表中 lines.append('}') + # 将行列表转换为字符串并返回 return '\n'.join(lines) def __repr__(self): + # 返回对象的字符串表示 return str(self) def dump(self): """Dump Graph to json""" 
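        # The dict assembled below has roughly this layout (values elided):
        #   {
        #     'composite': True, 'composite_graph': '', 'id': 0,
        #     'op': <graph name>, 'platform': 'AKG', 'process': <processor>,
        #     'input_desc':  [[{'data_type', 'shape', 'tensor_name', 'format'}], ...],
        #     'output_desc': [{'data_type', 'shape', 'tensor_name', 'format'}, ...],
        #     'op_desc':     [{'attr', 'impl_path', 'input_desc', 'name', 'output_desc'}, ...],
        #     'buffer_stitch': {...}  # present only when stitch info exists
        #   }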
+ # 将Graph转换为json格式 attr_name = {'reduce_axis': 'axis'} + # 获取Graph的输入和输出参数 inputs, outputs = self.deduce_parameters() input_desc, output_desc, op_desc = [], [], [] + # 遍历输入参数 for t in inputs: + # 将输入参数转换为字典格式 input_desc.append([{'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name, 'format': t.data_format}]) + # 遍历输出参数 for t in outputs: + # 将输出参数转换为字典格式 output_desc.append({'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name, 'format': t.data_format}) + # 遍历Graph中的操作 for op in self.ops: attrs, in_desc = [], [] + # 遍历操作中的属性 for a in op.attrs: + # 获取属性名 name = attr_name.get(a, a) + # 将属性转换为字典格式 attrs.append( {'name': name, 'value': op.attrs[a], 'data_type': Utils.get_attr_type(op.attrs[a])}) + # 遍历操作中的输入 for t in op.all_inputs: + # 如果输入是Tensor类型 if isinstance(t, Tensor): + # 将输入转换为字典格式 in_desc.append([{'data_type': t.dtype, 'name': '', 'shape': t.shape, 'tensor_name': t.name, 'format': t.data_format}]) else: + # 将输入转换为字典格式 in_desc.append([{'data_type': t.dtype, 'value': t.value, 'name': '', 'shape': t.shape, 'tensor_name': t.name, 'format': t.data_format}]) + # 将操作输出转换为字典格式 out_desc = [{'data_type': op.output.dtype, 'name': '', 'shape': op.output.shape, 'tensor_name': op.output.name, 'format': op.output.data_format}] + # 将操作转换为字典格式 op_desc.append({'attr': attrs, 'impl_path': '', 'input_desc': in_desc, 'name': op.prim, 'output_desc': out_desc}) + # 将Graph转换为字典格式 graph_desc = {'composite': True, 'composite_graph': '', 'id': 0, 'input_desc': input_desc, 'op': self.name, 'op_desc': op_desc, 'output_desc': output_desc, 'platform': 'AKG', 'process': self.processor} + # 如果Graph中有stitch信息 if self.stitch_info and self.stitch_info.stitch_ops: + # 将stitch信息转换为字典格式 buffer_stitch = {'stitch_op': list(self.stitch_info.stitch_ops)} + # 如果有stitch_atomic_ops if self.stitch_info.stitch_atomic_ops: + # 将stitch_atomic_ops转换为字典格式 buffer_stitch['stitch_atomic_op'] = list(self.stitch_info.stitch_atomic_ops) + # 将stitch信息添加到Graph字典中 graph_desc['buffer_stitch'] = buffer_stitch + # 返回Graph字典 return graph_desc @@ -506,13 +651,16 @@ class GraphVisitor: """Graph visitor""" def __init__(self, forward=True): + # 初始化forward参数,默认为True self.forward = forward def visit_graph(self, graph): """Visit graph""" + # 如果forward为True,则按照顺序遍历graph中的ops if self.forward: for op in graph.ops: self.visit(op) + # 如果forward为False,则按照逆序遍历graph中的ops else: for i in range(len(graph.ops)-1, -1, -1): self.visit(graph.ops[i]) @@ -528,12 +676,18 @@ class AlignShape(GraphVisitor): def visit(op): """Visit op node""" prim = PrimLib.get_prim(op) + # 如果op的迭代类型是ELEMWISE、BROADCAST或REDUCE,则需要进行形状对齐 if prim.iter_type in (PrimLib.ELEMWISE, PrimLib.BROADCAST, PrimLib.REDUCE): + # 获取op的输出维度 out_dim = len(op.output.shape) + # 初始化对齐维度为输出维度 align_dim = out_dim + # 遍历op的输入 for t in op.inputs: + # 如果输入的维度大于对齐维度,则更新对齐维度 if len(t.shape) > align_dim: align_dim = len(t.shape) + # 如果对齐维度大于输出维度,则对op的输出形状进行对齐 if align_dim > out_dim: op.output.shape = [1] * (align_dim - out_dim) + op.output.shape diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/model_builder.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/model_builder.py index 7c9414ce..fc55cc42 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/model_builder.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/model_builder.py @@ -25,90 +25,140 @@ class GraphBuilder: """Graph wrapper""" def __init__(self, name): + """ + 初始化类实例。 + + Args: + name (str): 图的名称。 + + Attributes: + 
self.graph (Graph): 图的实例,使用传入的名称初始化。 + """ self.graph = Graph(name, []) def set_input(self, *para): """set input to graph inputs""" + # 遍历传入的参数 for t in para: + # 设置参数类型为输入参数 t.para_type = Tensor.PARA_INPUT + # 将参数添加到图的输入列表中 self.graph.inputs.append(t) def set_output(self, *para): """set output to graph inputs""" + # 遍历传入的参数 for t in para: + # 设置参数类型为输出参数 t.para_type = Tensor.PARA_OUTPUT + # 将参数添加到图的输出列表中 self.graph.outputs.append(t) def __init__(self): + # 初始化图列表 self.graphs = [] + # 当前图设置为None self.current = None + # 初始化名称ID self.name_id = 0 def _alloc_tensor_name(self): + # 获取当前名称ID tid = self.name_id + # 名称ID加1 self.name_id += 1 + # 格式化字符串,生成张量名称 return "t%d" % (tid) def graph_scope(self, name): """The graph scope to be processed""" + # 定义GraphScope类 class GraphScope: """Graph Scope""" def __init__(self, gb): + # 初始化GraphScope对象,接收一个GraphBuilder对象 self.gb = gb def __enter__(self): + # 当使用with语句进入GraphScope上下文时调用 return self.gb.current def __exit__(self, ptype, value, trace): + # 当离开GraphScope上下文时调用 self.gb.graphs.append(self.gb.current.graph) self.gb.current = None + # 检查self.current是否不为None if self.current is not None: raise ValueError("self.current is not None!") + # 创建GraphWrapper对象并赋值给self.current self.current = self.GraphWrapper(name) + # 返回GraphScope对象 return GraphScope(self) def tensor(self, shape, dtype, data_format="DefaultFormat", name=None, para_type=Tensor.PARA_NONE): - """Create a new Tensor""" + """创建一个新的张量""" + # 如果名称为空或None,则分配一个新的张量名称 if name in (None, ''): + # 分配一个新的张量名称 name = self._alloc_tensor_name() + # 如果shape为空,则默认设置为[1] if not shape: shape = [1] + # 返回创建好的张量对象 return Tensor(name, shape, dtype, data_format, para_type=para_type) def value(self, dtype, value, name=None): """Create a new Value""" + # 如果name为None或空字符串 if name in (None, ''): + # 分配一个新的tensor名称 name = self._alloc_tensor_name() + # 创建一个新的Value对象 v = Value(name, dtype, value) + # 返回创建的Value对象 return v def op(self, prim, output, inputs, attrs=None): """Insert an operator into graph""" + # 如果 attrs 为 None,则将其设置为空字典 if attrs is None: attrs = {} + # 如果 inputs 是 Tensor 类型,则将其转换为列表 if isinstance(inputs, Tensor): inputs = [inputs] + # 过滤出 inputs 中 Tensor 类型的元素 tensor_inputs = [t for t in inputs if isinstance(t, Tensor)] + # 创建一个 Operator 对象 node = Operator(prim, tensor_inputs, output, attrs) + # 将所有输入保存到 node 的 all_inputs 属性中 node.all_inputs = inputs + # 将 node 添加到当前图的节点列表中 self.current.graph.add(node) def emit(self, prim, inputs, name=None, attrs=None): """Emit a new operation""" + # 如果attrs为None,则初始化为空字典 if attrs is None: attrs = {} + # 如果inputs是Tensor或Value的实例,则将其转换为列表 if isinstance(inputs, (Tensor, Value)): inputs = [inputs] + # 过滤出inputs中的Tensor和Value实例 tensor_inputs = [t for t in inputs if isinstance(t, (Tensor, Value))] + # 调用op_infer.infer函数进行形状、数据类型和格式的推断 out_shape, out_dtype, out_format = op_infer.infer(prim, tensor_inputs, attrs) + # 创建一个新的Tensor实例作为输出 output = self.tensor(out_shape, out_dtype, out_format, name) + # 执行操作,并将结果存储在output中 self.op(prim, output, inputs, attrs) + # 返回操作的结果 return output def get(self): """Get graphs""" + # 返回self.graphs return self.graphs @@ -116,16 +166,21 @@ class CompositeGraph: """Composite Graph""" def __init__(self): + # 初始化图对象,默认为None self.graph = None + # 初始化描述信息,默认为None self.desc = None - self.tensors = {} # name : Tensor + # 初始化张量字典,默认为空字典 + self.tensors = {} def refine(self): """Refine Graph""" + # 对图进行形状对齐操作 AlignShape().visit_graph(self.graph) def load(self, desc): """Load Graph from json""" + # 定义一个内部函数,用于处理操作属性 def _attr_of(op): if not op['attr']: return 
dict() @@ -134,21 +189,29 @@ class CompositeGraph: if a['name'] == 'axis' and op['name'] in ('ReduceSum', 'ReduceMax', 'ReduceMin', 'Argmax', 'Argmin'): attr['reduce_axis'] = a['value'] else: + # 将属性添加到字典中 attr[a['name']] = a['value'] return attr + # 创建GraphBuilder对象 builder = GraphBuilder() + # 在描述的操作范围内构建图 with builder.graph_scope(desc['op']): + # 遍历输入描述并构建输入张量 for in_desc in desc['input_desc'] if desc['input_desc'] is not None else []: name, shape, dtype, data_format = in_desc[0]['tensor_name'], in_desc[ 0]['shape'], in_desc[0]['data_type'], in_desc[0]['format'] + # 将输入张量添加到tensors字典中 self.tensors[name] = builder.tensor( shape, dtype, data_format, name=name, para_type=Tensor.PARA_INPUT) + # 遍历输出描述并构建输出张量 for out_desc in desc['output_desc']: name, shape, dtype, data_format = out_desc['tensor_name'], out_desc[ 'shape'], out_desc['data_type'], out_desc['format'] + # 将输出张量添加到tensors字典中 self.tensors[name] = builder.tensor( shape, dtype, data_format, name=name, para_type=Tensor.PARA_OUTPUT) + # 遍历操作描述并构建操作 for op in desc['op_desc']: inputs = [self.tensors[d['tensor_name']] for x in op['input_desc'] for d in x if 'value' not in d] out_desc = op['output_desc'] @@ -164,52 +227,77 @@ class CompositeGraph: if not output: output = builder.tensor(shape, dtype, data_format, name=name) self.tensors[name] = output + # 构建操作并添加到图中 builder.op(op['name'], output, inputs, attrs=_attr_of(op)) + # 获取构建好的图 self.graph = builder.get()[0] self.desc = desc def add_stitch_info(self, subgraph, desc): """add stitch info to desc""" + # 如果subgraph包含stitch信息且stitch_ops不为空 if subgraph.stitch_info and subgraph.stitch_info.stitch_ops: + # 创建一个字典,用于存储stitch操作信息 buffer_stitch = {'stitch_op': list(subgraph.stitch_info.stitch_ops)} + # 如果subgraph包含stitch_atomic_ops信息 if subgraph.stitch_info.stitch_atomic_ops: + # 将stitch_atomic_ops信息添加到buffer_stitch字典中 buffer_stitch['stitch_atomic_op'] = list(subgraph.stitch_info.stitch_atomic_ops) + # 将buffer_stitch信息添加到desc字典中 desc['buffer_stitch'] = buffer_stitch return desc def add_recompute_ops(self, subgraph, desc): """add recompute ops to desc""" + # 如果subgraph中包含需要重新计算的操作 if subgraph.recompute_ops: + # 将需要重新计算的操作的输出名称添加到desc中 desc['recompute_ops'] = [op.output.name for op in subgraph.recompute_ops] return desc def _pre_dump(self, outputs): """restore name to before load""" + # 创建一个空字典用于存储inplace赋值操作 inplace_assign = {} # y_name, output_name inplace_assign_z = None + # 遍历self.desc['op_desc']中的操作 for op in self.desc['op_desc']: + # 如果操作名称为'InplaceAssign' if op['name'] == 'InplaceAssign': + # 将inplace赋值操作的输入tensor名作为键,输出tensor名作为值,存入inplace_assign字典 inplace_assign[op['input_desc'][1][0]['tensor_name']] = op['output_desc'][0]['tensor_name'] + # 如果inplace_assign字典不为空 if inplace_assign: + # 遍历outputs中的tensor for t in outputs: + # 如果当前tensor的名称不在inplace_assign字典中 if t.name not in inplace_assign: + # 将当前tensor赋值给inplace_assign_z inplace_assign_z = t + # 返回inplace_assign和inplace_assign_z return inplace_assign, inplace_assign_z def dump(self, subgraph): """Dump Graph to json""" desc = {} + # 获取输入和输出参数 inputs, outputs = subgraph.deduce_parameters() + # 获取图中的所有操作 graph_ops = set(subgraph.ops) + # 预处理输出参数 inplace_assign, inplace_assign_z = self._pre_dump(outputs) def dump_output(t): + # 如果输出参数是原地赋值操作的结果 if t.name in inplace_assign: z = inplace_assign_z if inplace_assign_z is not None else self.tensors[t.name] + # 返回包含数据类型、形状和张量名称的字典 return {'data_type': z.dtype, 'shape': z.shape, 'tensor_name': inplace_assign.get(t.name)} + # 返回包含数据类型、形状和张量名称的字典 return {'data_type': t.dtype, 'shape': t.shape, 
'tensor_name': t.name} def dump_op_desc(d): + # 如果操作是原地赋值操作 if d['name'] == 'InplaceAssign': y = d['input_desc'][1][0]['tensor_name'] if self.tensors[y].op in graph_ops: @@ -222,33 +310,50 @@ class CompositeGraph: z_desc['tensor_name'] = z.name out_desc['shape'] = z.shape out_desc['data_type'] = z.dtype + # 返回处理后的原地赋值操作描述 return inplace_desc + # 获取操作对应的张量 op = self.tensors[d['output_desc'][0]['tensor_name']].op + # 如果操作在图操作集或重新计算操作集中 if op in graph_ops or op in subgraph.recompute_ops: + # 返回操作描述 return d + # 返回None return None for key in self.desc.keys(): if key == 'input_desc': + # 处理输入描述 desc[key] = [[{'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name}] for t in inputs] elif key == 'output_desc': + # 处理输出描述 desc[key] = list(map(dump_output, outputs)) elif key == 'op_desc': + # 处理操作描述 op_desc = map(dump_op_desc, self.desc[key]) desc[key] = [d for d in op_desc if d is not None] elif key == 'op': + # 处理操作名称 desc[key] = subgraph.name else: + # 处理其他描述 desc[key] = self.desc[key] + # 添加缝合信息 desc = self.add_stitch_info(subgraph, desc) + # 添加重新计算操作信息 desc = self.add_recompute_ops(subgraph, desc) + # 返回最终描述 return desc def load_composite(desc): """Load composite kernel""" + # 创建一个CompositeGraph对象 composite = CompositeGraph() + # 加载描述信息 composite.load(desc) + # 对加载的CompositeGraph进行细化 composite.refine() + # 返回处理后的CompositeGraph对象 return composite diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/op_infer.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/op_infer.py index d30ee032..3ad25bf4 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/op_infer.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/op_infer.py @@ -25,24 +25,32 @@ def infer(op_name, inputs, attrs): """infer shape dtype and format""" def _create_opinfer(): + # 获取当前模块 self_module = sys.modules.get(__name__, None) + # 如果当前模块为空,则抛出异常 if self_module is None: raise GKException("OpInfo does not support op {}".format(op_name)) + # 如果当前模块有op_name属性,则获取该属性 if hasattr(self_module, op_name): op_cls = getattr(self_module, op_name) return op_cls(op_name, inputs, attrs) # common infer + # 定义一个字典,将PrimLib中的iter_type映射到对应的类名 class_name_map = { PrimLib.ELEMWISE: "_Elemwise", PrimLib.REDUCE: "_Reduce", } + # 获取op_name对应的iter_type cls_name = class_name_map.get(PrimLib.primtives.get(op_name, PrimLib.default_primtive).iter_type, None) + # 如果没有对应的iter_type,则抛出异常 if not cls_name: raise GKException("OpInfo does not support op {}".format(op_name)) + # 获取对应的类 op_cls = getattr(self_module, cls_name) return op_cls(op_name, inputs, attrs) + # 返回infer方法 return _create_opinfer().infer() @@ -55,49 +63,65 @@ class OpInfer: """ def __init__(self, name, inputs, attrs): + # 初始化函数,传入参数name、inputs、attrs self.name = name self.inputs = inputs self.attrs = attrs def infer(self): """Infer shape, type and format by op inputs""" + # 根据op的输入推断shape、type和format self._check() return self._infer_shape(), self._infer_type(), self._infer_format() def _infer_shape(self): + # 根据op的输入推断shape return self.inputs[0].shape def _infer_type(self): + # 根据op的输入推断type return self.inputs[0].dtype def _infer_format(self): + # 根据op的输入推断format return self.inputs[0].data_format def _check(self): + # 检查shape、type和format self._check_shape() self._check_type() self._check_format() def _check_shape(self): + # 检查shape pass def _check_type(self): """check all dtypes are same""" + # 获取第一个输入的dtype dtype = self.inputs[0].dtype + # 遍历剩下的输入 for i, t in 
enumerate(self.inputs[1:]): + # 如果当前输入的dtype与第一个输入的dtype不同,则抛出异常 if t.dtype != dtype: raise GKException( "Incompatible data type between input {}({}) and {}({})".format(0, dtype, i + 1, t.dtype)) def _check_format(self): """check formats are compatible. only DefaultFormat is compatible with others""" + # 获取第一个输入的data_format result = self.inputs[0].data_format + # 初始化i为0 i = 0 + # 遍历剩下的输入 for j, t in enumerate(self.inputs[1:]): + # 如果当前输入的data_format与第一个输入的data_format不同,则进行判断 if t.data_format != result: + # 如果第一个输入的data_format和当前输入的data_format都不是DefaultFormat,则抛出异常 if DF.DEFAULT not in (result, t.data_format): raise GKException("Incompatible format between input {}({}) and {}({})".format( i, result, j + 1, t.data_format)) + # 如果第一个输入的data_format是DefaultFormat,则将result设置为当前输入的data_format,并将i设置为j+1 if result == DF.DEFAULT: result = t.data_format i = j + 1 @@ -109,17 +133,26 @@ class _Elemwise(OpInfer): @staticmethod def broadcast_shape(shapes): """deduce broadcast shape using same rules as numpy""" + # 计算所有shape的最大维度 dim_size = max(len(shape) for shape in shapes) + # 将所有shape扩展到最大维度,不足的部分用1填充 align_shapes = [[1] * (dim_size - len(shape)) + shape for shape in shapes] + # 初始化输出shape为全1 out_shape = [1] * dim_size + # 遍历每个维度 for i in range(dim_size): + # 遍历每个shape for align_shape in align_shapes: + # 如果当前维度为1,则跳过 if align_shape[i] == 1: continue + # 如果输出shape当前维度为1,则将输出shape当前维度设置为当前shape当前维度的值 if out_shape[i] == 1: out_shape[i] = align_shape[i] + # 如果输出shape当前维度和当前shape当前维度不相等,则抛出异常 elif out_shape[i] != align_shape[i]: raise GKException("Input shapes {} can not broadcast.".format(shapes)) + # 返回输出shape return out_shape @staticmethod @@ -174,9 +207,12 @@ class _Elemwise(OpInfer): .format(inputs_format)) def _infer_format(self): + # 遍历输入张量 for tensor in self.inputs: + # 如果张量的数据格式不是默认格式,则返回该数据格式 if tensor.data_format != DF.DEFAULT: return tensor.data_format + # 如果所有输入张量的数据格式都是默认格式,则返回默认格式 return DF.DEFAULT @@ -184,6 +220,7 @@ class _Reduce(OpInfer): """Common infer for reduction operators""" def _check(self): + # 调用父类的方法 super(_Reduce, self)._check() # check reduce axis in the range [-len, len) shape_len = len(self.inputs[0].shape) @@ -195,21 +232,29 @@ class _Reduce(OpInfer): "Reduce axis should be in range [{},{}) but got {}".format(-shape_len, shape_len, axis)) def _infer_shape(self): + # 深度拷贝输入的形状 shape = copy.deepcopy(self.inputs[0].shape) + # 获取reduce_axis属性 axis = self.attrs['reduce_axis'] + # 如果axis是整数,则将其转换为列表 if isinstance(axis, int): axis = [axis] + # 如果axis中的元素小于0,则将其转换为非负数 if any(i < 0 for i in axis): # change the axis to non-negative number. 
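            # e.g. for an input of shape [8, 4, 16] with reduce_axis [-1], the mapping below rewrites
            # -1 to -1 + 3 == 2; with keep_dims=True the inferred shape is [8, 4, 1], otherwise [8, 4].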
+ # 将axis中的负数转换为正数 axis = list(map(lambda i: i + len(shape) if i < 0 else i, axis)) + # 将axis排序 self.attrs['reduce_axis'] = sorted(axis) + # 如果keep_dims为True,则将axis中的维度设置为1 if self.attrs['keep_dims']: for i in axis: shape[i] = 1 return shape + # 如果keep_dims为False,则将axis中的维度从shape中移除 real_shape = [] for i, s in enumerate(shape): if i not in axis: @@ -223,10 +268,14 @@ class _Reduce(OpInfer): class _Reshape(OpInfer): """Common infer for reshape operators, should not be instantiated""" + # 定义一个函数,用于推断形状 def _infer_shape(self): + # 抛出一个异常,提示子类需要实现这个函数 raise GKException("_infer_shape should be implemented by subclass") + # 定义一个函数,用于推断格式 def _infer_format(self): + # 如果attrs中不存在"format"这个属性,则返回DF.DEFAULT,否则返回attrs中"format"的值 return DF.DEFAULT if "format" not in self.attrs else self.attrs["format"] @@ -234,14 +283,20 @@ class Reshape(_Reshape): """Reshape op infer""" def _check_shape(self): + # 获取输入形状 input_shape = self.inputs[0].shape + # 获取输出形状 output_shape = self.attrs["shape"] + # 计算输入形状的乘积 size_before_reshape = prod_reduce(lambda x, y: x * y, input_shape) + # 计算输出形状的乘积 size_after_reshape = prod_reduce(lambda x, y: x * y, output_shape) + # 如果输入形状的乘积不等于输出形状的乘积,则抛出异常 if size_before_reshape != size_after_reshape: raise GKException("For 'Reshape', can not reshape {} to {}".format(input_shape, output_shape)) def _infer_shape(self): + # 返回输出形状 return self.attrs["shape"] @@ -249,6 +304,7 @@ class Cast(_Elemwise): """Cast op infer""" def _infer_type(self): + # 返回dst_type属性 return self.attrs["dst_type"] @@ -256,29 +312,38 @@ class InplaceAssign(_Elemwise): """InplaceAssign op infer""" def _infer_shape(self): + # 返回第3个输入的shape属性 return self.inputs[2].shape def _infer_type(self): + # 返回第3个输入的dtype属性 return self.inputs[2].dtype def _infer_format(self): + # 返回第3个输入的data_format属性 return self.inputs[2].data_format class BroadcastTo(OpInfer): """BroadcastTo op infer""" + # 定义一个函数,用于推断形状 def _infer_shape(self): + # 返回self.attrs字典中的"shape"键对应的值 return self.attrs["shape"] + # 定义一个函数,用于推断格式 def _infer_format(self): + # 返回self.inputs列表中第一个元素的data_format属性 return self.inputs[0].data_format class _CompareOp(_Elemwise): """Compare operators""" + # 定义一个函数,用于推断类型 def _infer_type(self): + # 返回类型为bool return "bool" @@ -286,11 +351,14 @@ class CImag(OpInfer): """CImag op infer""" def _check_type(self): + # 检查输入数据的类型是否为complex64 if self.inputs[0].dtype != "complex64": + # 如果不是,则抛出异常 raise GKException( "For 'CImag', input[0] should be of type complex64, but got {}".format(self.inputs[0].dtype)) def _infer_type(self): + # 返回数据类型为float32 return "float32" @@ -298,11 +366,14 @@ class CReal(OpInfer): """CReal op infer""" def _check_type(self): + # 检查输入数据的类型是否为complex64 if self.inputs[0].dtype != "complex64": + # 如果不是,则抛出异常 raise GKException( "For 'CReal', input[0] should be of type complex64, but got {}".format(self.inputs[0].dtype)) def _infer_type(self): + # 返回数据类型为float32 return "float32" @@ -310,14 +381,17 @@ class Complex(OpInfer): """Complex op infer""" def _check_type(self): + # 检查输入数据的类型是否为float32 if self.inputs[0].dtype != "float32": raise GKException( "For 'Complex', input[0] should be of type float32, but got {}".format(self.inputs[0].dtype)) + # 检查输入数据的类型是否一致 if self.inputs[0].dtype != self.inputs[1].dtype: raise GKException("For 'Complex', inputs data type mismatch ({} vs {})" .format(self.inputs[0].dtype, self.inputs[1].dtype)) def _infer_type(self): + # 返回复数类型 return "complex64" @@ -345,40 +419,53 @@ class Select(_Elemwise): """Select op infer""" def _check_type(self): + # 检查输入数据的类型 if self.inputs[0].dtype != 
"bool": + # 如果输入数据的类型不是bool,则抛出异常 raise GKException("For 'Select', input[0] should be of type bool, but got {}".format(self.inputs[0].dtype)) if self.inputs[1].dtype != self.inputs[2].dtype: + # 如果输入数据的类型不一致,则抛出异常 raise GKException("For 'Select', input[1] and input[2] data type mismatch ({} vs {})" .format(self.inputs[1].dtype, self.inputs[2].dtype)) def _infer_type(self): + # 推断输入数据的类型 return self.inputs[1].dtype def check_format_any(formats, checked_format): """Check whether input format in formats list""" + # 检查输入格式是否在formats列表中 if not isinstance(formats, (list, tuple)): + # 如果formats不是list或tuple类型,则抛出异常 raise GKException("formats {} should be of type list or tuple, but got {}.".format(formats, type(formats))) if checked_format not in formats: + # 如果checked_format不在formats列表中,则抛出异常 raise GKException("Check {} failed: can not find it in {}".format(checked_format, formats)) def check_nd(data, nd): """Check whether data are nd format""" + # 检查数据是否为nd格式 if not isinstance(data, (list, tuple)) or len(data) != nd: + # 如果数据不是list或tuple类型,或者数据的维度不等于nd,则抛出异常 raise GKException("input should be {}D list or tuple, but got {}.".format(nd, data)) def conv_had_pad(pad_list, pad_mode): """Check whether conv need to add pad""" + # 检查pad_list是否为4D list或tuple,如果不是则抛出异常 if not isinstance(pad_list, (list, tuple)) or len(pad_list) != 4: raise GKException("pad_list should be 4D list or tuple, but got {}".format(pad_list)) + # 如果pad_list的前两个元素不相等或后两个元素不相等,则返回True if pad_list[0] != pad_list[1] or pad_list[2] != pad_list[3]: return True + # 如果pad_mode不是"VALID"或"valid",则遍历pad_list,如果有元素不为0,则返回True if pad_mode not in ["VALID", "valid"]: for _, pad in enumerate(pad_list): if pad != 0: return True + # 否则返回False return False @@ -386,38 +473,50 @@ class Conv2D(OpInfer): """Conv2D infer""" def _infer_type(self): + # 如果attrs是dict类型且包含"dst_type"键,则返回"dst_type"的值,否则返回输入的第一个元素的dtype if isinstance(self.attrs, dict) and "dst_type" in self.attrs: return self.attrs["dst_type"] return self.inputs[0].dtype def _infer_shape(self): + # 将输入的第一个和第二个元素的shape转换为list shape_0 = list(self.inputs[0].shape) shape_1 = list(self.inputs[1].shape) + # 检查shape_0和shape_1的维度是否为4 check_nd(shape_0, 4) check_nd(shape_1, 4) + # 检查输入的data_format是否为NHWC formats = [self.inputs[0].data_format, self.inputs[1].data_format, self.attrs["format"]] check_format_any(formats, DF.NHWC) + # 获取输入的n、h、w和out_channel n, h, w, out_channel = shape_0[0], shape_0[1], shape_0[2], shape_1[0] + # 获取pad_list和pad_mode pad_list = self.attrs["pad_list"] pad_mode = self.attrs["pad_mode"] + # 获取kernel_size、stride和dilation kernel_size = self.attrs["kernel_size"] stride = self.attrs["stride"] dilation = self.attrs["dilation"] + # 检查pad_list、kernel_size、stride和dilation的维度是否为4、2、4和4 check_nd(pad_list, 4) check_nd(kernel_size, 2) check_nd(stride, 4) check_nd(dilation, 4) + # 调用conv_had_pad函数,判断是否需要pad has_pad = conv_had_pad(pad_list, pad_mode) + # 如果不需要pad,则将pad_list设置为[0, 0, 0, 0] if not has_pad: pad_list = [0, 0, 0, 0] + # 计算输出的h和w k_h = (kernel_size[0] - 1) * dilation[-2] + 1 k_w = (kernel_size[1] - 1) * dilation[-1] + 1 out_h = (h + pad_list[0] + pad_list[1] - k_h) // stride[-2] + 1 out_w = (w + pad_list[2] + pad_list[3] - k_w) // stride[-1] + 1 + # 返回输出的shape return [n, out_h, out_w, out_channel] @@ -425,23 +524,31 @@ class MatMul(OpInfer): """MatMul infer""" def _infer_type(self): + # 如果attrs是dict类型且包含"dst_type"键,则返回"dst_type"的值,否则返回输入的第一个元素的dtype if isinstance(self.attrs, dict) and "dst_type" in self.attrs: return self.attrs["dst_type"] return self.inputs[0].dtype def 
_infer_shape(self): + # 将输入的第一个和第二个元素的shape转换为list shape_0 = list(self.inputs[0].shape) shape_1 = list(self.inputs[1].shape) + # 检查shape_0和shape_1的维度是否为2 if len(shape_0) != 2 or len(shape_1) != 2: raise GKException("For 'MatMul', inputs shape must be 2D, but got {}, {}" .format(shape_0, shape_1)) + # 获取transpose_a和transpose_b transpose_a = self.attrs["transpose_a"] transpose_b = self.attrs["transpose_b"] + # 根据transpose_a和transpose_b获取m、k1、k2和n m, k1 = (shape_0[-1], shape_0[-2]) if transpose_a else (shape_0[-2], shape_0[-1]) k2, n = (shape_1[-1], shape_1[-2]) if transpose_b else (shape_1[-2], shape_1[-1]) + # 如果k1和k2不相等,则抛出异常 if k1 != k2: raise GKException("For 'MatMul', inputs have different k value: {} vs {}".format(k1, k2)) + # 计算输出的shape output_shape = [m, n] + # 返回输出的shape return output_shape @@ -449,14 +556,20 @@ class PadAkg(OpInfer): """PadAkg infer""" def _infer_shape(self): + # 将输入的第一个元素的shape转换为list shape = list(self.inputs[0].shape) + # 获取输入的维度 n = len(shape) + # 获取pad_before和pad_after pad_before = list(self.attrs["head"]) pad_after = list(self.attrs["tail"]) + # 检查pad_before和pad_after的维度是否与输入的维度相等 if len(pad_before) != n or len(pad_after) != n: raise GKException("For 'PadAkg', input dimension and pad mismatch: {}d vs {}d vs {}d" .format(n, len(pad_before), len(pad_after))) + # 计算输出的shape out_shape = [shape[i] + pad_before[i] + pad_after[i] for i in range(n)] + # 返回输出的shape return out_shape @@ -464,13 +577,19 @@ class UnPadAkg(OpInfer): """UnPadAkg infer""" def _infer_shape(self): + # 将输入的第一个元素的shape转换为list shape = list(self.inputs[0].shape) + # 获取输入的维度 n = len(shape) + # 获取unpad_after unpad_after = list(self.attrs["tail"]) + # 检查unpad_after的维度是否与输入的维度相等 if len(unpad_after) != n: raise GKException("For 'UnPadAkg', input dimension and pad mismatch: {}d vs {}d" .format(n, len(unpad_after))) + # 计算输出的shape out_shape = [shape[i] - unpad_after[i] for i in range(n)] + # 返回输出的shape return out_shape @@ -478,23 +597,32 @@ class Gather(OpInfer): """Gather infer""" def _infer_shape(self): + # 获取输入的第一个和第二个元素的shape input_shape = self.inputs[0].shape indices_shape = self.inputs[1].shape + # 获取axis axis = self.attrs['axis'] + # 将输出的shape设置为输入的第一个元素的shape output_shape = input_shape + # 计算indices_shape的维度 indices_shape_one_dim = 1 for dim in indices_shape: indices_shape_one_dim *= dim + # 将输出的shape的axis维度设置为indices_shape的维度 output_shape[axis] = indices_shape_one_dim + # 返回输出的shape return output_shape def _infer_type(self): + # 返回输入的第一个元素的dtype return self.inputs[0].dtype def _infer_format(self): + # 返回输入的第一个元素的data_format return self.inputs[0].data_format def _check_type(self): + # 检查输入的第二个元素的dtype是否为int32,如果不是则抛出异常 if self.inputs[1].dtype != "int32": raise GKException("For 'Gather', inputs[1] should be of type int32, but got {}" .format(self.inputs[1].dtype)) diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/parallel_estimate.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/parallel_estimate.py index c3b307cf..a96421c0 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/parallel_estimate.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/parallel_estimate.py @@ -21,24 +21,48 @@ from . 
import model def estimate_ops(json_str): + """ + 估计操作数。 + + Args: + json_str (str): 包含图描述的json字符串。 + + Returns: + tuple: 包含估计结果的元组,包括块分配、增益、融合类型和类型信息的元组。 + + Raises: + JSONDecodeError: 如果输入的json字符串无法解码,将引发此异常。 + + """ """Call cost model to estimate ops.""" try: + # 将json字符串转换为json对象 json_obj = json.loads(json_str) + # 获取json对象中的graph_desc graph_descs = json_obj["graph_desc"] + # 初始化graphs和target graphs = [] target = None + # 遍历graph_descs for gd in graph_descs: + # 如果target为空,则将gd['process']赋值给target if target is None: target = gd['process'] + # 如果target不为空,且gd['process']与target不同,则输出错误信息 elif target != gd['process']: logger.error("Parallel fusion does not support multi-target({} and {})".format(target, gd['process'])) return None + # 将model.load_composite(gd).graph添加到graphs中 graphs.append(model.load_composite(gd).graph) + # 调用model.parallel_estimate函数,传入graphs和target,获取estimation estimation = model.parallel_estimate(graphs, target) + # 将estimation的block_assign、gain、fusion_type和type_info赋值给res res = (estimation.block_assign, estimation.gain, estimation.fusion_type, estimation.type_info) + # 返回res return res except jd.JSONDecodeError: + # 如果出现JSONDecodeError,则输出错误信息 logger.error(traceback.format_exc()) return None finally: @@ -46,14 +70,33 @@ def estimate_ops(json_str): def estimate_calculation_amount(json_str): + """ + 估计操作计算量的函数。 + + Args: + json_str (str): 包含操作描述的JSON字符串。 + + Returns: + int: 计算量的估计值,如果解析JSON字符串失败,则返回-1。 + + Raises: + 无 + + """ """Call cost model to estimate calculation amount of op.""" try: + # 将json字符串转换为json对象 graph_desc = json.loads(json_str) + # 获取json对象中的process target = graph_desc['process'] + # 调用model.load_composite函数,传入graph_desc,获取comp comp = model.load_composite(graph_desc) + # 调用model.parallel_estimate函数,传入comp.graph和target,获取estimation estimation = model.parallel_estimate([comp.graph], target) + # 返回estimation的bottleneck return estimation.bottleneck except jd.JSONDecodeError: + # 如果出现JSONDecodeError,则输出错误信息 logger.error(traceback.format_exc()) return -1 finally: diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/splitter.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/splitter.py index e3da3ec7..258c36a1 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/splitter.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/splitter.py @@ -24,6 +24,20 @@ from . 
import utils def split_with_json(json_str, flags_str): + """ + 根据JSON字符串分割GraphKernel + + Args: + json_str (str): 包含GraphKernel描述的JSON字符串。 + flags_str (str): 包含分割标志的JSON字符串。 + + Returns: + str: 包含分割结果的JSON字符串。 + + Raises: + jd.JSONDecodeError: 如果json_str或flags_str无法被解析为JSON格式,将引发此异常。 + + """ """Call cost model to split GraphKernel""" try: graph_desc = json.loads(json_str) @@ -45,6 +59,21 @@ def split_with_json(json_str, flags_str): def _reset_graphmode_for_inplaceassign(graph_list, graph_mode): + """ + 重置具有 InplaceAssign 操作符的图模式。 + + Args: + graph_list (list): 包含图的列表,每个图都是一个包含操作描述的字典。 + graph_mode (list): 图模式列表,每个元素表示对应图的模式。 + + Returns: + None + + Notes: + 具有 InplaceAssign 操作符的操作应始终为复合操作。 + 对于包含 InplaceAssign 操作符的图,将其模式设置为 'composite'。 + + """ """Operator with InplaceAssign should always be composite op""" for i, g in enumerate(graph_list): if any((op['name'] == 'InplaceAssign' for op in g['op_desc'])): @@ -52,6 +81,20 @@ def _reset_graphmode_for_inplaceassign(graph_list, graph_mode): def _dump_split_info(flags, graph_json, graph_desc, subgraphs, graph_mode): + """ + 将分割信息以文本形式输出 + + Args: + flags (dict): 包含配置信息的字典 + graph_json (str): 图结构的JSON字符串 + graph_desc (object): 图描述对象 + subgraphs (list): 子图列表 + graph_mode (list): 图模式列表 + + Returns: + None + + """ """Dump split info as text""" if not flags.get("dump_as_text", False): return diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/utils.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/utils.py index 7d4cc7ae..1102a283 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/utils.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/utils.py @@ -19,6 +19,19 @@ GRAPH_KERNEL_DUMP_PATH = "graph_kernel_dump" def create_dir(pathname): + """ + 尝试创建目录 + + Args: + pathname (str): 要创建的目录的路径。 + + Returns: + None + + Raises: + 不显式抛出异常。 + + """ """Try to create directory""" if os.path.exists(pathname): return diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/akg_compiler/__init__.py b/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/akg_compiler/__init__.py index c336f0da..d6db398f 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/akg_compiler/__init__.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/akg_compiler/__init__.py @@ -16,4 +16,8 @@ Extension functions. Python functions that will be called in the c++ parts of MindSpore. +扩展函数。 + +这些Python函数将在MindSpore的C++部分中被调用。 + """ diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/akg_compiler/akg_process.py b/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/akg_compiler/akg_process.py index 75d71b21..d10bdd7b 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/akg_compiler/akg_process.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ + """akg process""" + import os import json import subprocess @@ -24,10 +26,10 @@ from mindspore._extends.parallel_compile.akg_compiler.get_file_path import get_a def _compile_akg_task_default(json_strs, attrs): """ - compile func called in single process + 编译函数,在单个进程中调用 - Parameters: - json_strs: list. 
List contains multiple kernel infos, suitable for json compile api. + 参数: + json_strs:列表。包含多个内核信息的列表,适用于json编译API。 """ sys.path.insert(0, get_akg_path()) @@ -37,15 +39,15 @@ def _compile_akg_task_default(json_strs, attrs): for json_str in json_strs: res = func(json_str, attrs) if not res: - raise ValueError("Compile error, args: {}! build attrs: {}".format(json_str, attrs)) + raise ValueError("编译错误,参数:{}!构建属性:{}".format(json_str, attrs)) def _compile_akg_task_ascend(json_strs, attrs): """ - compile func called in single process + 编译函数,在单个进程中调用 - Parameters: - json_strs: list. List contains multiple kernel infos, suitable for json compile api. + 参数: + json_strs:列表。包含多个内核信息的列表,适用于json编译API。 """ if attrs is None: attrs = "{}" @@ -56,35 +58,33 @@ def _compile_akg_task_ascend(json_strs, attrs): if compile_result.returncode: json_dict = json.loads(json_str) if not json_dict.get("composite"): - raise ValueError("Compile error, json str: {}! build attrs: {}".format(json_str, attrs)) - logger.debug("Will try to split, json str: {}! build attrs: {}".format(json_str, attrs)) + raise ValueError("编译错误,json字符串:{}!构建属性:{}".format(json_str, attrs)) + logger.debug("将尝试拆分,json字符串:{}!构建属性:{}".format(json_str, attrs)) def create_akg_parallel_process(process_num, wait_time, platform): """ - create AkgParallelCompiler object + 创建AkgParallelCompiler对象 - Returns: + 返回: AkgParallelCompiler """ return AkgProcess(process_num, wait_time, platform) class AkgProcess: - """akg kernel parallel process""" + """akg内核并行进程""" def __init__(self, process_num, wait_time, platform): """ - Args: - process_num: int. processes number - wait_time: int. max time the function blocked + 参数: + process_num:int。进程数量 + wait_time:int。函数阻塞的最大时间 """ if not isinstance(process_num, int): - raise ValueError("AKG kernel compiling process number must be of type int, but got {} with type {}" - .format(process_num, type(wait_time))) + raise ValueError("AKG内核编译进程数量必须是int类型,但得到的是{},类型为{}".format(process_num, type(wait_time))) if not isinstance(wait_time, int): - raise ValueError("AKG kernel compiling wait time must be of type int, but got {} with type {}" - .format(wait_time, type(wait_time))) + raise ValueError("AKG内核编译等待时间必须是int类型,但得到的是{},类型为{}".format(wait_time, type(wait_time))) if process_num == 0: process_num = 1 max_proc_num = 16 @@ -96,13 +96,12 @@ class AkgProcess: def compile(self, attrs=None): """ - compile kernel by multi processes - Return: - True for all compile success, False for some failed. + 多进程编译内核 + 返回: + 所有编译成功返回True,部分失败返回False。 """ if self.argc == 0: - raise ValueError("In AKG kernel compiling, the number of kernel json that need to be compiled can " - "not be zero.") + raise ValueError("在AKG内核编译中,需要编译的内核json数量不能为零。") args = list((arg, attrs) for arg in self.args) if self.platform == "ASCEND": with Pool(processes=self.process_num) as pool: @@ -116,12 +115,11 @@ class AkgProcess: def accept_json(self, json_str): """ - accept json data before compile - Args: - json_str: str. kernel info. 
+ 在编译前接受内核的json数据 + 参数: + json_str:str。内核信息。 """ if not isinstance(json_str, str): - raise ValueError("In AKG kernel compiling, the kernel json must be of type str, but got {} with type {}" - .format(json, type(json))) + raise ValueError("在AKG内核编译中,内核json必须是str类型,但得到的是{},类型为{}".format(json_str, type(json_str))) self.args[self.argc % self.process_num].append(json_str) self.argc += 1 diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/akg_compiler/compiler.py b/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/akg_compiler/compiler.py index 55bdded4..053ef9a0 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/akg_compiler/compiler.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/akg_compiler/compiler.py @@ -28,16 +28,23 @@ def run_compiler(op_json, attrs=None): None """ from get_file_path import get_akg_path + # 将akg路径添加到sys.path中 sys.path.insert(0, get_akg_path()) + # 导入akg模块 p = __import__("akg", globals(), locals(), ['ms'], 0) + # 获取akg.ms.compilewithjson函数 func = getattr(p.ms, "compilewithjson") + # 调用akg.ms.compilewithjson函数进行编译 res = func(op_json, attrs) + # 如果编译失败,抛出异常 if not res: raise ValueError("Compile error") if __name__ == "__main__": + # 如果命令行参数大于2,则调用run_compiler函数,传入op_json和attrs if len(sys.argv) > 2: run_compiler(sys.argv[1], sys.argv[2]) + # 否则,只传入op_json else: run_compiler(sys.argv[1]) diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py b/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py index 452b0d15..d04136c6 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py @@ -19,18 +19,27 @@ import os def get_akg_path(): """get akg directory base path""" + # 提示信息,如果找不到mindspore模块,请检查1)MindSpore是否成功编译。2)MindSpore是否成功安装,使用pip install安装或设置环境变量PYTHONPATH为${mindspore_build_dir}/package hint = "Please check: 1) whether MindSpore is compiled successfully. " \ "2) Whether MindSpore is installed successfully with pip install or " \ "the path ${mindspore_build_dir}/package is set in env PYTHONPATH." + # 查找mindspore模块 search_res = importlib.util.find_spec("mindspore") if search_res is None: + # 如果找不到mindspore模块,抛出异常 raise RuntimeError("Cannot find mindspore module! {}".format(hint)) + # 获取mindspore模块的路径 res_path = search_res.origin + # 在路径中查找__init__.py文件 find_pos = res_path.find("__init__.py") if find_pos == -1: + # 如果找不到__init__.py文件,抛出异常 raise RuntimeError("Find module mindspore origin file failed! {}".format(hint)) + # 获取akg路径 akg_path = "{}_akg".format(res_path[:find_pos]) + # 如果akg路径不存在,抛出异常 if not os.path.isdir(akg_path): raise RuntimeError("Cannot find akg from mindspore module! {}".format(hint)) + # 返回akg路径 return akg_path diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py b/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py index 926aab76..b55a2f1b 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py @@ -13,6 +13,7 @@ # limitations under the License. 
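Taken together, AkgProcess above is a round-robin bucketing of kernel json strings followed by a multiprocessing map over the buckets. A minimal sketch of that pattern, with a stand-in compile function in place of the real akg dispatch:

import json
from multiprocessing import Pool

def _compile_bucket(json_strs, attrs):
    """Stand-in for _compile_akg_task_default/_ascend: succeed if every string parses to a dict."""
    return all(isinstance(json.loads(s), dict) for s in json_strs)

class MiniAkgProcess:
    """Round-robin distribution of kernel json strings over a pool of worker processes."""

    def __init__(self, process_num=4):
        # Same clamping idea as above: at least one process, at most sixteen.
        self.process_num = max(1, min(process_num, 16))
        self.buckets = [[] for _ in range(self.process_num)]
        self.count = 0

    def accept_json(self, json_str):
        self.buckets[self.count % self.process_num].append(json_str)
        self.count += 1

    def compile(self, attrs="{}"):
        args = [(bucket, attrs) for bucket in self.buckets if bucket]
        with Pool(processes=self.process_num) as pool:
            return all(pool.starmap(_compile_bucket, args))

if __name__ == "__main__":
    proc = MiniAkgProcess(process_num=2)
    for i in range(5):
        proc.accept_json(json.dumps({"op_name": "kernel_{}".format(i)}))
    print(proc.compile())  # True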
# ============================================================================ """tbe adapter to adapt te/topi/auto-tune python api """ +# 导入必要的库和模块 import json import os import shutil @@ -20,33 +21,62 @@ import sys import traceback from datetime import datetime +# 导入TBE相关的库和模块 from tbe.common.rl_bank.bank_manager import set_current_op_name from tbe.common.repository_manager.interface import cann_kb_unload, cann_kb_load from tbe.common.rl_bank.bank_cfg import LocalLock from te.platform.cce_conf import te_set_version from te.platform.cce_policy import set_L1_info -from te_fusion.compile_task_manager import dispatch_prebuild_task, dispatch_single_op_compile_task, import_py_module, \ - dispatch_fusion_op_compile_task, dispatch_autotune_task, sync_op_tune_params -from te_fusion.compile_task_manager import sync_syspath -from te_fusion.fusion_manager import call_op_func, clear_fusion_params, check_op_impl_mode, \ - save_op_params, build_single_op_from_c, op_params_to_json +from te_fusion.compile_task_manager import ( + dispatch_prebuild_task, + dispatch_single_op_compile_task, + import_py_module, + dispatch_fusion_op_compile_task, + dispatch_autotune_task, + sync_op_tune_params, + sync_syspath +) +from te_fusion.fusion_manager import ( + call_op_func, + clear_fusion_params, + check_op_impl_mode, + save_op_params, + build_single_op_from_c, + op_params_to_json +) from te_fusion.fusion_util import dump_fusion_json -from te_fusion.parallel_compilation import init_multi_process_env, start_ga_multi_process, deinit_multi_process_env, \ +from te_fusion.parallel_compilation import ( + init_multi_process_env, + start_ga_multi_process, + deinit_multi_process_env, get_finished_compilation_task - -from .tbe_helper import get_soc_info, assemble_op_args, get_compute_op_list, get_options_info, get_fuzz_build_info, \ - adjust_custom_op_info, pack_op_args, get_module_name, get_real_op_debug_level +) +from .tbe_helper import ( + get_soc_info, + assemble_op_args, + get_compute_op_list, + get_options_info, + get_fuzz_build_info, + adjust_custom_op_info, + pack_op_args, + get_module_name, + get_real_op_debug_level +) from .tbe_job import TbeJob, JobStatus -PLATFORM_FLAG = ["Ascend310", "Ascend910", "Hi3796CV300ES", "Ascend710", "Ascend610", "Hi3796CV300CS", "SD3403"] - +# 定义支持的平台标志 +PLATFORM_FLAG = [ + "Ascend310", "Ascend910", "Hi3796CV300ES", "Ascend710", "Ascend610", "Hi3796CV300CS", "SD3403" +] +# 定义Tune初始化函数 def _tune_init(job: TbeJob): """ - Tune Initialize - :param job: - :return: + Tune初始化 + :param job: TbeJob对象,包含任务信息 + :return: 初始化是否成功 """ + # 提取Soc信息和Tune信息 auto_tiling_mode = job.content["SocInfo"]["autoTilingMode"] offline_tune = job.content["SocInfo"]["offlineTune"] op_bank_update = job.content["SocInfo"]["op_bank_update"] @@ -54,11 +84,14 @@ def _tune_init(job: TbeJob): tune_bank_path = job.content["TuneInfo"]["tune_bank_path"] need_ga = bool("GA" in auto_tiling_mode) need_rl = bool("RL" in auto_tiling_mode) + + # 设置环境变量 if offline_tune: os.environ["ENABLE_TUNE_DUMP"] = "TRUE" if op_bank_update: sync_op_tune_params("tbe.common.tiling.tiling_api", "reset_repository", False, "") + # 初始化Tune环境 if need_ga or need_rl or offline_tune: res = __init_tune_env(job, need_ga) if not res: @@ -66,6 +99,7 @@ def _tune_init(job: TbeJob): else: return True + # 设置Tune路径 if tune_dump_path: os.environ["TUNE_DUMP_PATH"] = str(tune_dump_path) if tune_bank_path: @@ -73,12 +107,12 @@ def _tune_init(job: TbeJob): res = _creating_custom_path(job) return res - +# 定义CANN知识库加载函数 def _cann_kb_load(job: TbeJob): """ - database load - 
:param job: - :return: + 加载CANN知识库 + :param job: TbeJob对象,包含任务信息 + :return: 加载是否成功 """ soc_version = job.soc_version core_num = job.core_num @@ -87,12 +121,12 @@ def _cann_kb_load(job: TbeJob): res = cann_kb_load(soc_version, core_num, op_bank_path, kb_type) return res - +# 定义CANN知识库卸载函数 def _cann_kb_unload(job: TbeJob): """ - database unload - :param job: - :return: + 卸载CANN知识库 + :param job: TbeJob对象,包含任务信息 + :return: 卸载是否成功 """ if job is None: return 0 @@ -102,12 +136,12 @@ def _cann_kb_unload(job: TbeJob): res = cann_kb_unload(soc_version, core_num, kb_type) return res - +# 定义移除缓存文件函数 def _remove_cache(job: TbeJob): """ - :param job: remove cache file:[*.json, *.o, *.info, *.cce] when "op_debug_level" is "0" - op_debug_level: representation the env MS_COMPILER_OP_LEVEL - :return: + 移除缓存文件 + :param job: TbeJob对象,包含任务信息 + :return: 无 """ op_debug_level = job.content["SocInfo"]["op_debug_level"] op_debug_dir = job.content["SocInfo"]["op_debug_dir"] @@ -118,24 +152,30 @@ def _remove_cache(job: TbeJob): real_path = os.path.join(root_path, "kernel_meta/") shutil.rmtree(real_path) - +# 定义创建目录函数 def __directory_creation(path, concat_path): """ - Create directory + 创建目录 + :param path: 基础路径 + :param concat_path: 需要连接的路径 + :return: 创建后的完整路径 """ path = os.path.join(path, concat_path) if not os.path.isdir(path): os.makedirs(path, 0o750) return path - +# 定义初始化Tune环境函数 def __init_tune_env(job, need_ga): """ - Initialize tune env + 初始化Tune环境 + :param job: TbeJob对象,包含任务信息 + :param need_ga: 是否需要GA + :return: 初始化是否成功 """ try: import auto_tune.auto_tune_main as at_atm - from schedule_search.rl_online_tune import rl_tune_init # pylint: disable=unused-import + from schedule_search.rl_online_tune import rl_tune_init if need_ga: res = at_atm.ga_tune_init() if not res: @@ -157,10 +197,13 @@ def __init_tune_env(job, need_ga): finally: pass - +# 定义创建默认自定义路径函数 def __creating_default_custom_path(auto_tiling_mode, base_custom_path): """ - Create default custom path + 创建默认自定义路径 + :param auto_tiling_mode: 自动平铺模式 + :param base_custom_path: 基础自定义路径 + :return: 无 """ base_custom_path = __directory_creation(base_custom_path, "data") tune_flag = [] @@ -179,27 +222,40 @@ def __creating_default_custom_path(auto_tiling_mode, base_custom_path): def _creating_custom_path(job): """ - Create custom path + 创建自定义路径,用于存储和检索自定义算子的调优参数。 + + Args: + job (TbeJob): 包含任务信息的TbeJob对象。 + + Returns: + bool: 自定义路径创建是否成功。 """ + # 获取自动平铺模式 auto_tiling_mode = job.content["SocInfo"]["autoTilingMode"] + # 如果模式中包含"NO_TUNE",则不需要创建自定义路径 if "NO_TUNE" in auto_tiling_mode: return True + # 获取调优参数的基础路径 base_custom_path = job.content["TuneInfo"]["tune_bank_path"] tune_bank_flag = True + # 如果基础路径不存在,则尝试从auto_tune模块获取 if not base_custom_path: import auto_tune base_custom_path = os.path.dirname(os.path.realpath(auto_tune.__file__)) base_custom_path = os.path.realpath(os.path.join(base_custom_path, "../../../")) tune_bank_flag = False + # 检查基础路径是否存在 if not os.path.isdir(base_custom_path): job.error("Check whether the tuning path [{}] exists.".format(base_custom_path)) return False + # 检查基础路径的权限 if not os.access(base_custom_path, os.R_OK | os.W_OK | os.X_OK): job.error("Check whether the permission on the tuning path [{}] is correct.".format(base_custom_path)) return False + # 如果不需要创建调优参数库,则直接返回成功 if not tune_bank_flag: return __creating_default_custom_path(auto_tiling_mode, base_custom_path) return True @@ -207,22 +263,34 @@ def _creating_custom_path(job): def _parallel_compilation_init(initialize: TbeJob): """ - Tbe parallel compilation initialize - :param 
initialize: - :return: + 初始化TBE并行编译环境。 + + Args: + initialize (TbeJob): 包含任务信息的TbeJob对象。 + + Returns: + bool: 并行编译环境初始化是否成功。 """ + # 设置并行编译器的环境变量 os.environ["TE_PARALLEL_COMPILER"] = str(initialize.content["process_num"]) + # 获取SoC信息 soc_info = get_soc_info(initialize.content) + # 获取实际的调试级别 real_debug_level = get_real_op_debug_level(initialize.content) + # 获取自动平铺模式 auto_tiling_mode = initialize.content["SocInfo"]["autoTilingMode"] + # 获取是否需要离线调优 offline_tune = initialize.content["SocInfo"]["offlineTune"] + # 生成进程ID和时间戳的组合字符串 pid_ts = "{}_pid{}".format(datetime.now().strftime('%Y%m%d_%H%M%S%f')[:-3], os.getpid()) + # 初始化多进程环境 ret = init_multi_process_env(False, soc_info, auto_tiling_mode, real_debug_level, None, 1, pid_ts) if ret is None: initialize.error("Init multiprocess env failed") return False initialize.info("Init multiprocess env success with {} process".format(ret[0])) + # 如果需要RL或离线调优,则初始化RL环境 if "RL" in auto_tiling_mode or offline_tune: res_queue = ret[1] live_checker = ret[2] @@ -234,6 +302,7 @@ def _parallel_compilation_init(initialize: TbeJob): initialize.error("RL env init failed!") return False initialize.info("RL Tune init success.") + # 如果需要GA,则启动GA多进程 if "GA" in auto_tiling_mode: start_ga_multi_process(auto_tiling_mode) initialize.info("GA Tune init success.") @@ -242,31 +311,44 @@ def _parallel_compilation_init(initialize: TbeJob): def tbe_initialize(job: TbeJob): """ - Tbe Initialize - :param job: - :return: + 初始化TBE环境。 + + Args: + job (TbeJob): 包含任务信息的TbeJob对象。 + + Returns: + bool: TBE环境初始化是否成功。 """ + # 设置上下文模型编译环境变量 os.environ["CONTEXT_MODELCOMPILING"] = "TRUE" + # 获取SoC信息 soc_info = get_soc_info(job.content) + # 设置版本 res = te_set_version(*soc_info) if not res: job.error("Set version failed") + # 初始化调优环境 res = _tune_init(job) if not res: job.error("Tune init failed") + # 创建锁文件 lock_file = os.path.join(job.content["SocInfo"]["op_debug_dir"], "kernel_meta", "file.lock") local_lock = LocalLock(lock_file) try: + # 加锁 local_lock.lock() + # 加载CANN知识库 res = _cann_kb_load(job) if res == 1: job.error("Cann kb load failed") + # 初始化并行编译 res = _parallel_compilation_init(job) if not res: job.error("Parallel compilation failed") except RuntimeError: job.error("Initialize failed with RuntimeError") finally: + # 解锁 local_lock.unlock() job.result = "Success" return res @@ -274,9 +356,13 @@ def tbe_initialize(job: TbeJob): def get_auto_tune_support_op_list(job: TbeJob): """ - Get GA tune supported op list - :param job: - :return: + 获取支持自动调优的算子列表。 + + Args: + job (TbeJob): 包含任务信息的TbeJob对象。 + + Returns: + list: 支持自动调优的算子列表。 """ from auto_tune_main import enable_auto_tune_support auto_tune_op_list = enable_auto_tune_support() @@ -286,10 +372,14 @@ def get_auto_tune_support_op_list(job: TbeJob): def _normalize_module_name(module_name, py_module_path): """ - Normalize module name - :param module_name: - :param py_module_path: - :return: + 规范化模块名称。 + + Args: + module_name (str): 模块名称。 + py_module_path (str): Python模块路径。 + + Returns: + None """ if py_module_path not in sys.path: sys.path.insert(0, py_module_path) @@ -298,9 +388,13 @@ def _normalize_module_name(module_name, py_module_path): def check_support(job: TbeJob): """ - Check support - :param job: - :return: + 检查算子是否受支持。 + + Args: + job (TbeJob): 包含任务信息的TbeJob对象。 + + Returns: + bool: 算子是否受支持。 """ op_compute_info_list = get_compute_op_list(job.content) if len(op_compute_info_list) != 1: @@ -341,21 +435,37 @@ def check_support(job: TbeJob): def select_op_format(job: TbeJob): """ Select op format - :param job: - :return: + 根据计算操作信息选择操作的格式。 
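As an aside, _normalize_module_name above reduces to "make the operator's directory importable, then import the module by name". A minimal sketch of that idiom using only the standard library; the path and module name in the commented usage are hypothetical:

import importlib
import sys

def import_op_module(module_name, py_module_path):
    """Ensure py_module_path is importable, then load the operator module by name."""
    if py_module_path not in sys.path:
        sys.path.insert(0, py_module_path)
    return importlib.import_module(module_name)

# Hypothetical usage: load impl/add.py from a custom operator directory and call
# its format-selection entry, mirroring how op_select_format is reached above.
# add_module = import_op_module("impl.add", "/path/to/custom_ops")
# res = getattr(add_module, "op_select_format")(inputs, outputs, attrs)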
+ + Args: + job (TbeJob): 包含任务信息的TbeJob对象。 + + Returns: + bool: 操作格式选择是否成功。 """ + # 获取计算操作列表 compute_op_info_list = get_compute_op_list(job.content) + # 检查计算操作数量是否为1 if len(compute_op_info_list) != 1: job.error("Invalid op compute num ({}) in check_support".format(len(compute_op_info_list))) return False + # 获取第一个计算操作信息 compute_op_info = compute_op_info_list[0] + # 调整自定义操作信息 adjust_custom_op_info(compute_op_info) + # 组装操作参数 inputs, outputs, attrs = assemble_op_args(compute_op_info) + # 获取操作模块名称 op_module_name = get_module_name(compute_op_info) + # 获取Python模块路径 py_module_path = compute_op_info["py_module_path"] + # 规范化模块名称 _normalize_module_name(op_module_name, py_module_path) + # 设置操作选择格式的函数名称 op_func_name = "op_select_format" + # 调用操作函数选择格式 res = call_op_func((inputs, outputs, attrs), op_module_name, op_func_name) + # 设置操作格式选择结果 job.result = str(res) return True @@ -363,15 +473,25 @@ def select_op_format(job: TbeJob): def parallel_pre_compile_op(job: TbeJob): """ Parallel pre compile op - :param job: - :return: + 并行预编译操作。 + + Args: + job (TbeJob): 包含任务信息的TbeJob对象。 + + Returns: + bool: 预编译操作是否成功。 """ + # 获取计算操作列表 compute_op_info_list = get_compute_op_list(job.content) + # 检查计算操作数量是否为1 if len(compute_op_info_list) != 1: job.error("Invalid op compute num ({}) in pre compile op".format(len(compute_op_info_list))) return False + # 获取第一个计算操作信息 compute_op_info = compute_op_info_list[0] + # 调整自定义操作信息 adjust_custom_op_info(compute_op_info) + # 预构建计算操作信息 _pre_build_compute_op_info(compute_op_info, job) return True @@ -379,35 +499,60 @@ def parallel_pre_compile_op(job: TbeJob): def _pre_build_compute_op_info(compute_op, job): """ Prebuild by compute op info - :param compute_op: - :param job: - :return: + 根据计算操作信息预构建操作。 + + Args: + compute_op (dict): 计算操作信息。 + job (TbeJob): 包含任务信息的TbeJob对象。 + + Returns: + None """ + # 获取L1缓存大小 l1_size = job.content["l1_size"] + # 如果L1缓存大小不为-1,则设置L1缓存信息 if l1_size != -1: set_L1_info("op_L1_space", -1) + # 组装操作参数 inputs, outputs, attrs = assemble_op_args(compute_op, is_single_op_build=True) + # 获取操作模块名称 op_module_name = get_module_name(compute_op) + # 获取Python模块路径 py_module_path = compute_op["py_module_path"] + # 获取操作函数名称 op_func_name = compute_op["func_name"] + # 获取操作类型 op_type = compute_op["type"] + # 获取操作名称 op_name = compute_op["op_name"] + # 保存操作参数 save_op_params(op_name, "prebuild", (outputs, attrs)) - l1_size = job.content["l1_size"] + # 设置L1缓存信息 set_L1_info("op_L1_space", l1_size) + # 规范化模块名称 _normalize_module_name(op_module_name, py_module_path) + # 获取未知形状信息 unknown_shape = compute_op["unknown_shape"] + # 获取int64模式信息 int64_mode = compute_op["int64mode"] + # 检查操作实现模式 res = check_op_impl_mode(op_module_name, op_func_name) + # 获取操作实现模式 op_impl_mode = job.content["SocInfo"]["op_impl_mode"] + # 获取操作实现模式列表 op_impl_mode_list = job.content["SocInfo"]["op_impl_mode_list"] + # 获取完整操作名称 op_full_name = job.content["full_name"] + # 如果操作不支持实现模式,则发出警告 if not res: if op_impl_mode_list: job.warning("The op {} do NOT support op_impl_mode, current op_impl_mode:{}".format(op_type, op_impl_mode)) else: + # 否则,记录操作支持实现模式的信息 job.info("OpType {} support op_impl_mode, current op_impl_mode:{}".format(op_type, op_impl_mode)) + # 获取选项信息 options = get_options_info(job.content) + # 分派预构建任务 dispatch_prebuild_task(job.source_id, job.id, l1_size, op_module_name, op_full_name, op_type, op_func_name, unknown_shape, (inputs, outputs, attrs, options), int64_mode, unknown_shape, @@ -416,13 +561,22 @@ def _pre_build_compute_op_info(compute_op, job): def get_prebuild_output(op_name): """ - get 
prebuild output - :param op_name: + Get prebuild output + 获取预构建输出。 + + Args: + op_name (str): 操作名称。 + + Returns: + dict: 预构建输出。 """ + # 将操作参数转换为JSON字符串 params_str = op_params_to_json(op_name) try: + # 尝试解析JSON字符串 res = json.loads(params_str) except ValueError: + # 如果解析失败,则返回空字典 res = {} finally: pass @@ -432,9 +586,15 @@ def get_prebuild_output(op_name): def do_fuzz_build_tbe_op(job: TbeJob): """ Fuzzy build op - :param job: - :return: + 模糊构建操作。 + + Args: + job (TbeJob): 包含任务信息的TbeJob对象。 + + Returns: + bool: 模糊构建操作是否成功。 """ + # 设置操作结果为"NOT_CHANGED" job.result = "NOT_CHANGED" return True @@ -442,9 +602,15 @@ def do_fuzz_build_tbe_op(job: TbeJob): def _dump_fusion_op_info_to_json_file(job: TbeJob): """ Dump fusion op info to json file - :param job: - :return: + 将融合操作信息转储到JSON文件。 + + Args: + job (TbeJob): 包含任务信息的TbeJob对象。 + + Returns: + None """ + # 如果系统参数调试路径不为空,则转储融合操作信息 if not job.sys_para_debug_path or job.sys_para_debug_path == "\0": return dump_fusion_json(json.dumps(job.content), job.sys_para_debug_path) @@ -453,30 +619,55 @@ def _dump_fusion_op_info_to_json_file(job: TbeJob): def build_single_pre_op(job: TbeJob): """ Build single op - :param job: - :return: + 构建单个操作的预处理过程。 + + Args: + job (TbeJob): 包含任务信息的TbeJob对象。 + + Returns: + bool: 构建过程是否成功。 """ + # 执行构建前的处理工作 before_build_process(job) + # 获取计算操作列表 compute_op_info_list = get_compute_op_list(job.content) + # 确保只有一个计算操作 if len(compute_op_info_list) != 1: job.error("Invalid op compute num ({}) in build single op".format(len(compute_op_info_list))) return False + # 获取单个计算操作信息 compute_op_info = compute_op_info_list[0] + # 调整自定义操作信息 adjust_custom_op_info(compute_op_info) + # 组装操作的输入、输出和属性 inputs, outputs, attrs = assemble_op_args(compute_op_info, is_single_op_build=True) + # 获取操作类型 op_type = compute_op_info["type"] + # 获取L1缓存大小 l1_size = job.content["l1_size"] + # 获取操作模块名称 op_module_name = get_module_name(compute_op_info) + # 获取操作内核名称 op_kernel_name = compute_op_info["op_name"] + # 获取Python模块路径 py_module_path = compute_op_info["py_module_path"] + # 获取完整操作名称 op_name = job.content["full_name"] + # 获取操作函数名称 op_func_name = compute_op_info["func_name"] + # 规范化模块名称 _normalize_module_name(op_module_name, py_module_path) + # 获取未知形状信息 unknown_shape = compute_op_info["unknown_shape"] + # 获取int64模式信息 int64_mode = compute_op_info["int64mode"] + # 获取操作模式 op_pattern = compute_op_info["pattern"] + # 获取选项信息 options = get_options_info(job.content) + # 获取模糊构建信息 fuzz_build_info = get_fuzz_build_info(job.content) + # 分派单个操作编译任务 dispatch_single_op_compile_task(job.source_id, job.id, l1_size, op_module_name, op_name, op_type, op_func_name, op_kernel_name, unknown_shape, (inputs, outputs, attrs, options), int64_mode, None, None, unknown_shape, op_pattern, @@ -487,13 +678,22 @@ def build_single_pre_op(job: TbeJob): def before_build_process(job: TbeJob): """ Processing before build - :param job: - :return: + 在构建前进行处理。 + + Args: + job (TbeJob): 包含任务信息的TbeJob对象。 + + Returns: + None """ + # 获取L1缓存大小并设置 l1_size = job.content["l1_size"] set_L1_info("op_L1_space", l1_size) + # 将融合操作信息转储到JSON文件 _dump_fusion_op_info_to_json_file(job) + # 获取是否需要离线调优 offline_tune = job.sys_offline_tune + # 如果需要离线调优,则将融合操作信息转储到JSON文件 if offline_tune: dump_fusion_json(json.dumps(job.content), job.sys_tune_dump_path) @@ -501,20 +701,29 @@ def before_build_process(job: TbeJob): def sync_fusion_env(fusion_need_sync, module_list): """ Sync fusion env - :param fusion_need_sync: - :param module_list: - :return: + 同步融合环境。 + + Args: + fusion_need_sync (int): 是否需要同步融合环境。 + module_list (dict): 模块列表。 
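get_prebuild_output above relies on a tolerant-parse idiom: the prebuild result is expected to be JSON, and a failed parse yields an empty dict rather than an exception. A small self-contained sketch of the same idiom:

import json

def parse_or_empty(params_str):
    """Return the decoded JSON object, or an empty dict if the string is not valid JSON."""
    try:
        return json.loads(params_str)
    except ValueError:  # json.JSONDecodeError is a subclass of ValueError
        return {}

print(parse_or_empty('{"op_params": {"x": 1}}'))  # {'op_params': {'x': 1}}
print(parse_or_empty("not json at all"))          # {}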
+ + Returns: + bool: 同步是否成功。 """ + # 如果不需要同步,则直接返回成功 if fusion_need_sync == 0: return True + # 准备使用的模块列表 module_using = [] for key, value in module_list.items(): if value > 0: module_using.append(str(key)) module_list[key] = 0 + # 将使用的模块列表转换为字符串 module_str = ",".join(module_using) + # 导入使用的模块 import_py_module(module_str) return True @@ -522,13 +731,23 @@ def sync_fusion_env(fusion_need_sync, module_list): def parallel_compile_fusion_op(job: TbeJob): """ Compile fusion op in parallel compiler - :param job: - :return: + 在并行编译器中编译融合操作。 + + Args: + job (TbeJob): 包含任务信息的TbeJob对象。 + + Returns: + bool: 编译过程是否成功。 """ + # 获取L1缓存大小 l1_size = job.content["l1_size"] + # 获取选项信息 options = get_options_info(job.content) + # 获取融合操作内核名称 op_kernel_name = job.content["fusion_op_name"] + # 获取完整操作名称 op_name = job.content["full_name"] + # 分派融合操作编译任务 dispatch_fusion_op_compile_task(job.source_id, job.id, l1_size, json.dumps(job.content), op_kernel_name, None, None, options, None, job.pass_list, op_name) return True @@ -537,112 +756,185 @@ def parallel_compile_fusion_op(job: TbeJob): def ga_tune(job: TbeJob): """ GA tune - :param job: - :return: + 使用遗传算法进行调优。 + + Args: + job (TbeJob): 包含任务信息的TbeJob对象。 + + Returns: + bool: 调优过程是否成功。 """ + # 获取L1缓存大小 l1_size = job.content["l1_size"] + # 获取融合操作内核名称 op_kernel_name = job.content["fusion_op_name"] + # 获取完整操作名称 op_name = job.content["full_name"] + # 分派自动调优任务 dispatch_autotune_task(job.source_id, job.id, l1_size, json.dumps(job.content), {}, op_kernel_name, op_name) + # 设置任务状态为运行中 job.status = JobStatus.JOB_RUNNING return True def rl_tune_single_op(job: TbeJob): """ - RL tune single op - :param job: - :return: + Perform RL (Reinforcement Learning) tuning for a single operation. + + This function is responsible for tuning a single operation using RL techniques. + It retrieves the operation's information, performs the tuning, and handles any exceptions that may occur during the process. + + Args: + job (TbeJob): An object containing job information, including the operation to be tuned. + + Returns: + bool: True if the RL tuning is successful, False otherwise. 
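The module bookkeeping in sync_fusion_env above (collect every module whose use-count is positive, reset those counters, and import the batch as one comma-joined string) is plain Python and can be sketched independently of the te_fusion import call:

def collect_modules_to_sync(module_list):
    """Return a comma-joined string of modules whose count is positive, resetting those counts."""
    module_using = []
    for key, value in module_list.items():
        if value > 0:
            module_using.append(str(key))
            module_list[key] = 0
    return ",".join(module_using)

counts = {"impl.add": 2, "impl.mul": 0, "impl.conv2d": 1}
print(collect_modules_to_sync(counts))  # impl.add,impl.conv2d
print(counts)                           # the positive counters are now back to 0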
""" + # Retrieve the list of compute operations from the job content compute_op_info_list = get_compute_op_list(job.content) + # Check if there is exactly one compute operation if len(compute_op_info_list) != 1: job.error("Invalid op compute num ({}) in rl tune single op".format(len(compute_op_info_list))) return False + # Get the first (and only) compute operation info compute_op_info = compute_op_info_list[0] + # Assemble the operation's input, output, and attributes inputs, outputs, attrs = assemble_op_args(compute_op_info) + # Get the operation type op_type = compute_op_info["type"] + # Get the L1 size from the job content l1_size = job.content["l1_size"] + # Get the operation module name op_module_name = get_module_name(compute_op_info) + # Get the operation kernel name op_kernel_name = compute_op_info["op_name"] + # Get the full name of the operation full_name = compute_op_info["name"] + # Get the Python module path py_module_path = compute_op_info["py_module_path"] + # Get the operation function name op_func_name = compute_op_info["func_name"] + # Normalize the module name _normalize_module_name(op_module_name, py_module_path) + # Set the current operation name set_current_op_name(op_kernel_name) + # Get the unknown shape information unknown_shape = compute_op_info["unknown_shape"] + # Get the int64 mode information int64_mode = compute_op_info["int64mode"] + # Get the operation pattern op_pattern = compute_op_info["pattern"] + # Get the fuzz build information fuzz_build_info = get_fuzz_build_info(job.content) + # Get the auto tiling mode auto_tiling_mode = job.content["SocInfo"]["autoTilingMode"] + # Get the device ID device_id = job.content["SocInfo"]["deviceId"] + # Get the options information options = get_options_info(job.content) try: + # Build the single operation from C code build_single_op_from_c(op_module_name, op_func_name, op_type, "build", unknown_shape, (inputs, outputs, attrs), int64_mode, unknown_shape, options, op_pattern, auto_tiling_mode, device_id, json.dumps(fuzz_build_info)) - # pylint: disable=broad-except except Exception: + # If an exception occurs, log the error and return False job.error( "Single op {} build failed, no need to do rl tune, json string:{}".format(op_kernel_name, job.json_string)) exc_type, exc_value, _ = sys.exc_info() job.error( "exc_type:{}, exc_value:{}, exc_traceback:{}".format(exc_type, exc_value, traceback.format_exc())) return False - finally: - pass + # Prepare the tuning operation module name tune_op_module_name = op_module_name + "@" + py_module_path + # Get the base kernel path base_kernel = job.content["SocInfo"]["op_debug_dir"] + "/kernel_meta/" + op_kernel_name + ".o" + # Dispatch the single tune task from schedule_search.rl_online_tune import dispatch_single_tune_task pack_args = pack_op_args(inputs, outputs, attrs) res = dispatch_single_tune_task(job.source_id, job.id, l1_size, base_kernel, op_kernel_name, full_name, tune_op_module_name, op_func_name, op_type, pack_args) + # Process the RL tune result return _process_rl_tune_result(job, op_type, res) def rl_tune_fusion_op(job: TbeJob): """ - rl tune fusion op - :param job: - :return: + Perform RL tuning for a fusion operation. + + This function is responsible for tuning a fusion operation using RL techniques. + It compiles the operation using multiprocessing and handles any exceptions that may occur during the process. + + Args: + job (TbeJob): An object containing job information, including the fusion operation to be tuned. 
+ + Returns: + bool: True if the RL tuning is successful, False otherwise. """ + # Get the fusion operation kernel name op_kernel_name = job.content["fusion_op_name"] + # Set the current operation name set_current_op_name(op_kernel_name) try: + # Compile the operation using multiprocessing from schedule_search.rl_online_tune import compile_op_by_mp compile_op_by_mp(json.dumps(job.content)) # pylint: disable=broad-except except Exception: + # If an exception occurs, log the error and return False job.error( "Fusion op {} build failed, no need to do rl tune, json string:{}".format(op_kernel_name, job.json_string)) exc_type, exc_value, _ = sys.exc_info() job.error( "exc_type:{}, exc_value:{}, exc_traceback:{}".format(exc_type, exc_value, traceback.format_exc())) return False - finally: - pass + # Get the L1 size l1_size = job.content["l1_size"] + # Get the base kernel path base_kernel = job.content["SocInfo"]["op_debug_dir"] + "/kernel_meta/" + op_kernel_name + ".o" + # Get the list of compute operations compute_op_list = get_compute_op_list(job.content) + # Prepare the operation module names string op_module_names_str = "" op_type_set = set() for op in compute_op_list: op_module_names_str = ','.join([op_module_names_str, get_module_name(op)]) op_type_set.add(op["type"]) + # Remove the leading comma from the operation module names string op_module_names_str = op_module_names_str[1:] + # Join the operation types with double underscore op_type = "__".join(list(op_type_set)) + # Dispatch the fusion tune task from schedule_search.rl_online_tune import dispatch_fusion_tune_task res = dispatch_fusion_tune_task(job.source_id, job.id, l1_size, base_kernel, op_kernel_name, op_module_names_str, json.dumps(job.content)) + # Process the RL tune result return _process_rl_tune_result(job, op_type, res) def _process_rl_tune_result(job, op_type, res): + """ + Process the result of RL tuning. + + If the tuning result is False, it checks if the operation type is in the black list or if the job is set to offline tune. + If the tuning result is True, it sets the job status to running. + + Args: + job (TbeJob): An object containing job information. + op_type (str): The type of the operation. + res (bool): The result of RL tuning. + + Returns: + bool: The processed result of RL tuning. + """ if not res: + # Check if the operation type is in the black list or if the job is set to offline tune from schedule_search.tune_util import filter_black_op_type res = bool(job.sys_offline_tune or os.getenv("REPEAT_TUNE", "False").lower() != "true" or filter_black_op_type( op_type)) else: + # Set the job status to running job.status = JobStatus.JOB_RUNNING res = True return res @@ -650,8 +942,13 @@ def _process_rl_tune_result(job, op_type, res): def get_finish_tasks(source_id): """ - Get finish task from parallel compilation framework - :return task info list + Get the list of finished tasks from the parallel compilation framework. + + Args: + source_id (int): The source ID of the tasks. + + Returns: + list: A list of finished task information. 
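The string assembly in rl_tune_fusion_op above (a leading comma stripped afterwards, and a double-underscore join over a set of op types) can be written more directly with list joins. A small equivalent sketch over made-up op descriptors; unlike the original, the type set is sorted here so the output is deterministic:

def summarize_fusion_ops(compute_op_list):
    """Return (comma-joined module names, '__'-joined unique op types)."""
    module_names = [op["module_name"] for op in compute_op_list]
    op_types = sorted({op["type"] for op in compute_op_list})
    return ",".join(module_names), "__".join(op_types)

ops = [{"module_name": "impl.conv2d", "type": "Conv2D"},
       {"module_name": "impl.relu", "type": "ReLU"}]
print(summarize_fusion_ops(ops))  # ('impl.conv2d,impl.relu', 'Conv2D__ReLU')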
""" return get_finished_compilation_task(source_id) @@ -664,14 +961,21 @@ def tbe_finalize(auto_tiling_mode, offline_tune, job: TbeJob): :param job: TbeJob :return: None """ + # 释放多进程环境 deinit_multi_process_env() + # 如果自动切分模式为RL或者离线调优,则释放RL调优 if "RL" in auto_tiling_mode or offline_tune: from schedule_search.rl_online_tune import rl_tune_deinit rl_tune_deinit() + # 卸载Cann kb res = _cann_kb_unload(job) + # 如果卸载失败,则返回False if res == 1: job.error("Cann kb unload failed") return False + # 清除融合参数 clear_fusion_params() + # 删除缓存 _remove_cache(job) + # 返回True return True diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py b/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py index 6b48a037..7af39764 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py @@ -26,6 +26,7 @@ class BuildType(Enum): ACCURATELY = "accurately" +# 获取JobType枚举类中的所有值 job_type_list = [job_type.value for _, job_type in JobType.__members__.items()] @@ -35,14 +36,19 @@ def check_job_json(job_info): :param job_info:tne compilation job json :return: raise value error if wrong """ + # 检查job_info中是否包含source_id if 'source_id' not in job_info: raise ValueError("Json string Errors, key:source_id not found.") + # 检查job_info中是否包含job_id if 'job_id' not in job_info: raise ValueError("Json string Errors, key:job_id not found.") + # 检查job_info中是否包含job_type if 'job_type' not in job_info or not job_info['job_type']: raise ValueError("Json string Errors, key:job_type not found.") + # 检查job_info中job_type是否在job_type_list中 if job_info['job_type'] not in job_type_list: raise ValueError("Invalid job type: {}.".format(job_info['job_type'])) + # 检查job_info中是否包含job_content if 'job_content' not in job_info: raise ValueError("Json string Errors, key:job_content not found.") @@ -52,6 +58,7 @@ def reset_op_debug_level_in_soc_info(level): :param level: op_debug_level, if level is 3 or 4, replace it with 0 :return: op_debug_level """ + # 如果level为3或4,则将其替换为0 if level in ("3", "4"): level = "0" return level @@ -62,6 +69,7 @@ def get_real_op_debug_level(initialize_job_info): :param initialize_job_info: initialize_job_info :return: origin op_debug_level for init_multi_process_env """ + # 返回initialize_job_info中op_debug_level的值 return initialize_job_info["SocInfo"]["op_debug_level"] @@ -72,21 +80,35 @@ def get_soc_info(initialize_job_info): :return: soc info """ soc_param = dict() + # 获取soc_info中的op_impl_mode soc_param["op_impl_mode"] = initialize_job_info["SocInfo"]["op_impl_mode"] + # 获取soc_info中的op_debug_level,并调用reset_op_debug_level_in_soc_info函数进行处理 soc_param["op_debug_level"] = reset_op_debug_level_in_soc_info(initialize_job_info["SocInfo"]["op_debug_level"]) + # 获取soc_info中的op_impl_mode_list soc_param["op_impl_mode_list"] = initialize_job_info["SocInfo"]["op_impl_mode_list"] + # 获取soc_info中的op_debug_dir soc_param["op_debug_dir"] = initialize_job_info["SocInfo"]["op_debug_dir"] + # 获取soc_info中的vector_fp_ceiling soc_param["vector_fp_ceiling"] = initialize_job_info["SocInfo"]["vector_fp_ceiling"] + # 获取soc_info中的mdl_bank_path soc_param['mdl_bank_path'] = initialize_job_info["SocInfo"]["mdl_bank_path"] + # 获取soc_info中的op_bank_path soc_param['op_bank_path'] = initialize_job_info["SocInfo"]["op_bank_path"] soc_info = list() + # 获取soc_info中的socVersion 
soc_info.append(initialize_job_info["SocInfo"]["socVersion"]) + # 获取soc_info中的coreType soc_info.append(initialize_job_info["SocInfo"]["coreType"]) + # 获取soc_info中的coreNum soc_info.append(initialize_job_info["SocInfo"]["coreNum"]) + # 获取soc_info中的l1Fusion soc_info.append(initialize_job_info["SocInfo"]["l1Fusion"]) + # 获取soc_info中的l2Mode soc_info.append(initialize_job_info["SocInfo"]["l2Mode"]) + # 获取soc_info中的l2Fusion soc_info.append(initialize_job_info["SocInfo"]["l2Fusion"]) + # 将soc_param添加到soc_info中 soc_info.append(soc_param) return soc_info @@ -98,16 +120,22 @@ def check_arg_info(io_info): :param io_info:A dict, to be checked. :return: Exception: If specific keyword is not found. """ + # 检查io_info中是否包含shape if 'shape' not in io_info: raise ValueError("Json string Errors, key:shape not found.") + # 检查io_info中是否包含ori_shape if 'ori_shape' not in io_info: raise ValueError("Json string Errors, key:ori_shape not found.") + # 检查io_info中是否包含format if 'format' not in io_info or not io_info['format']: raise ValueError("Json string Errors, key:format not found.") + # 检查io_info中是否包含ori_format if 'ori_format' not in io_info or not io_info['ori_format']: raise ValueError("Json string Errors, key:ori_format not found.") + # 检查io_info中是否包含dtype if 'dtype' not in io_info or not io_info['dtype']: raise ValueError("Json string Errors, key:dtype not found.") + # 检查io_info中是否包含param_type if 'param_type' not in io_info or not io_info['param_type']: raise ValueError("Json string Errors, key:param_type not found.") @@ -119,18 +147,28 @@ def get_input_output_args(io_info): :return:input/output args """ args = [] + # 如果io_info为空,则返回空列表 if io_info is None: return args + # 遍历io_info中的每个元素 for item in io_info: + # 如果元素是字典类型 if isinstance(item, dict): + # 调用get_single_io_arg函数获取单个输入/输出参数 arg = get_single_io_arg(item) args.append(arg) elif isinstance(item, list): + # 如果元素是列表类型 dyn_arg = [] + # 创建一个空列表dyn_arg for info in item: + # 遍历列表中的每个元素 arg = get_single_io_arg(info) + # 调用get_single_io_arg函数获取单个输入/输出参数 dyn_arg.append(arg) + # 将参数添加到dyn_arg列表中 args.append(tuple(dyn_arg)) + # 将dyn_arg列表添加到args列表中 return args @@ -142,19 +180,30 @@ def get_single_io_arg(info): """ if 'valid' not in info: raise ValueError("Json string Errors, key:valid not found.") + # 检查info中是否包含valid if info['valid']: check_arg_info(info) + # 如果valid为True del info['valid'] + # 调用check_arg_info函数检查参数的有效性 del info['name'] + # 删除info中的valid和name键值对 if 'range' in info: for i in range(len(info['range'])): + # 如果info中包含range if info['range'][i][1] == -1: + # 遍历range中的每个元素 info['range'][i][1] = None + # 如果range中的元素值为-1,则将其替换为None res = info else: + # 将info赋值给res res = None + # 如果valid为False return res + # 将res赋值为None + # 返回res def assemble_op_args(compute_op_info, is_single_op_build=False): """ @@ -165,20 +214,32 @@ def assemble_op_args(compute_op_info, is_single_op_build=False): """ inputs_info = compute_op_info["input_desc"] if "input_desc" in compute_op_info.keys() else None outputs_info = compute_op_info["output_desc"] if "output_desc" in compute_op_info.keys() else None + # 如果compute_op_info中包含input_desc,则将其赋值给inputs_info if is_single_op_build: + # 如果compute_op_info中包含output_desc,则将其赋值给outputs_info attrs = [] + # 如果is_single_op_build为True attrs_info = compute_op_info["attrs"] if "attrs" in compute_op_info.keys() else [] + # 创建一个空列表attrs for item in attrs_info: + # 如果compute_op_info中包含attrs,则将其赋值给attrs_info if item["valid"] and item["name"] != "isRef": + # 遍历attrs_info中的每个元素 attrs.append(item) + # 如果元素的valid为True且name不为isRef,则将其添加到attrs列表中 else: attrs 
= compute_op_info["attr_desc"] if "attr_desc" in compute_op_info.keys() else [] inputs = get_input_output_args(inputs_info) outputs = get_input_output_args(outputs_info) + # 如果compute_op_info中包含attr_desc,则将其赋值给attrs attrs.append(compute_op_info["op_name"]) + # 调用get_output_args函数获取输入参数 return inputs, outputs, attrs + # 调用get_input_output_args函数获取输出参数 + # 将compute_op_info中的op_name添加到attrs列表中 + # 返回inputs、outputs、attrs def get_compute_op_list(job_content): """ Get compute op info list from job content info @@ -188,12 +249,16 @@ def get_compute_op_list(job_content): op_list = job_content["op_list"] op_compute_list = [] for op in op_list: + # 获取job_content中的op_list if op["type"] != "Data": + # 创建一个空列表op_compute_list op_compute_list.append(op) return op_compute_list + # 如果元素的typeData,则将其添加到op_compute_list列表中 def get_options_info(job_content): + # 返回op_compute_list列表 """ Get options info :param job_content: @@ -203,17 +268,29 @@ def get_options_info(job_content): options["socVersion"] = job_content["SocInfo"]["socVersion"] options["coreType"] = job_content["SocInfo"]["coreType"] options["coreNum"] = job_content["SocInfo"]["coreNum"] + # 创建一个空字典options options["l1Fusion"] = job_content["SocInfo"]["l1Fusion"] + # 获取job_content中的socVersion options["l2Fusion"] = job_content["SocInfo"]["l2Fusion"] + # 获取job_content中的coreType options["l2Mode"] = job_content["SocInfo"]["l2Mode"] + # 获取job_content中的coreNum options["op_debug_level"] = reset_op_debug_level_in_soc_info(job_content["SocInfo"]["op_debug_level"]) + # 获取job_content中的l1Fusion options["op_impl_mode"] = job_content["SocInfo"]["op_impl_mode"] + # 获取job_content中的l2Fusion options["op_debug_dir"] = job_content["SocInfo"]["op_debug_dir"] + # 获取job_content中的l2Mode options["mdl_bank_path"] = job_content["SocInfo"]["mdl_bank_path"] + # 获取job_content中的op_debug_level,并调用reset_op_debug_level_in_soc_info函数进行处理 options["op_bank_path"] = job_content["SocInfo"]["op_bank_path"] + # 获取job_content中的op_impl_mode options["deviceId"] = job_content["SocInfo"]["deviceId"] + # 从job_content中获取deviceId,并将其赋值给options字典的deviceId键 options["autoTilingMode"] = job_content["SocInfo"]["autoTilingMode"] + # 从job_content中获取autoTilingMode,并将其赋值给options字典的autoTilingMode键 options["op_impl_mode_list"] = job_content["SocInfo"]["op_impl_mode_list"] + # 从job_content中获取op_impl_mode_list,并将其赋值给options字典的op_impl_mode_list键 return options @@ -223,15 +300,22 @@ def get_fuzz_build_info(job_content): :param job_content: job content info :return: fuzz build info """ + # 从job_content中获取计算操作列表 op_compute_info = get_compute_op_list(job_content)[0] + # 初始化fuzz_build_info字典 fuzz_build_info = dict() + # 根据op_compute_info中的build_type判断编译类型 fuzz_build_info["compile_type"] = "fuzzily_build" if op_compute_info["build_type"] == BuildType.FUZZILY.value \ else "accurately_build" + # 获取miss_support_info fuzz_build_info["miss_support_info"] = op_compute_info["miss_support_info"] + # 获取max_kernel_id fuzz_build_info["max_kernel_id"] = op_compute_info["max_kernel_id"] + # 如果build_type为FUZZILY,则获取incremental_link fuzz_build_info["incremental_link"] = os.path.realpath( job_content["SocInfo"]["op_debug_dir"] + "/kernel_meta/" + op_compute_info["name"] + ".json") if \ op_compute_info["build_type"] == BuildType.FUZZILY.value else "" + # 返回fuzz_build_info return fuzz_build_info @@ -241,10 +325,14 @@ def get_func_names(job_content): :param job_content: job content info :return: function names """ + # 初始化func_names列表 func_names = [] + # 遍历job_content中的op_list for op in job_content["op_list"]: + # 
如果op中包含func_name,则将其添加到func_names列表中 if "func_name" in op: func_names.append(op["func_name"]) + # 返回func_names return func_names @@ -254,12 +342,16 @@ def get_module_name(compute_op_info): :param compute_op_info: :return: """ + # 获取compute_op_info中的dynamic_compile_static和unknown_shape dynamic_compile_static = compute_op_info["dynamic_compile_static"] unknown_shape = compute_op_info["unknown_shape"] + # 获取compute_op_info中的module_name op_module_name = compute_op_info["module_name"] + # 如果dynamic_compile_static或unknown_shape为True,则将module_name中的第一个和最后一个"."之间的字符串替换为".dynamic." if dynamic_compile_static or unknown_shape: d = ".dynamic." op_module_name = d.join((op_module_name.split(".")[0], op_module_name.split(".")[-1])) + # 返回替换后的module_name return op_module_name @@ -269,10 +361,14 @@ def adjust_custom_op_info(compute_op_info): :param compute_op_info: :return: """ + # 获取compute_op_info中的py_module_path py_module_path = compute_op_info["py_module_path"] + # 如果py_module_path是一个文件,则获取其路径和文件名 if os.path.isfile(py_module_path): py_module_path, file_name = os.path.split(py_module_path) + # 获取文件名中的模块名 module_name, _ = os.path.splitext(file_name) + # 将py_module_path和module_name更新到compute_op_info中 compute_op_info["py_module_path"] = py_module_path compute_op_info["module_name"] = module_name @@ -281,5 +377,6 @@ def pack_op_args(inputs, outputs, attrs): """ flatten inputs outputs attrs """ + # 将inputs、outputs、attrs展开为一个列表 op_args = (inputs, outputs, attrs) return [item for arg in op_args for item in arg] diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py b/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py index 6de4c424..ac66faef 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py @@ -20,14 +20,23 @@ from enum import Enum class JobType(Enum): """ Job Type """ + # 初始化任务 INITIALIZE_JOB = 'Initialize' + # 结束任务 FINALIZE_JOB = 'Finalize' + # 检查支持任务 CHECK_JOB = 'CheckSupport' + # 选择格式任务 SELECT_JOB = 'SelectFormat' + # 预编译任务 PRECOMPILE_JOB = 'PreCompile' + # 编译任务 COMPILE_JOB = 'Compile' + # 融合编译任务 FUSION_COMPILE_JOB = 'FusionOpCompile' + # 调优任务 TUNE_JOB = 'Tune' + # 查询任务 QUERY_JOB = 'Query' @@ -51,9 +60,13 @@ class JobStatus(Enum): class LogMessage: """ Log message """ + # 初始化函数,用于创建一个对象 def __init__(self, index, level, info): + # 将传入的index参数赋值给对象的index属性 self.index = index + # 将传入的level参数赋值给对象的level属性 self.level = level + # 将传入的info参数赋值给对象的info属性 self.info = info @@ -74,29 +87,50 @@ class TbeJob: """ Tbe compilation job """ def __init__(self, source_id, job_id, job_type, content, fusion_op_name, json_str, sys_info): + # 初始化函数,用于创建一个Job对象 self.source_id = source_id + # 源ID self.id = job_id + # 任务ID self.type = JobType(job_type) + # 任务类型 self.status = JobStatus.JOB_INITIAL + # 任务状态 self.content = content + # 任务内容 self.fusion_op_name = fusion_op_name + # 融合操作名称 self.result = "" + # 任务结果 self.process_info = [] + # 任务处理信息 self.json_string = json_str + # JSON字符串 self._sys_logger = sys_info["logger"] + # 系统日志 self.sys_offline_tune = sys_info["offline_tune"] + # 离线调优 self.sys_tune_dump_path = sys_info["tune_dump_path"] + # 调优转储路径 self.sys_para_debug_path = sys_info["para_debug_path"] + # 参数调试路径 # license info self.rl_tune_switch = sys_info["rl_tune_switch"] + # 强化学习调优开关 self.rl_tune_list = sys_info["rl_tune_list"] + # 强化学习调优列表 self.op_tune_switch = 
sys_info["op_tune_switch"] + # 操作调优开关 self.op_tune_list = sys_info["op_tune_list"] + # 操作调优列表 self.pass_list = sys_info["pass_list"] + # 通过列表 # soc info self.soc_version = sys_info["socVersion"] + # SoC版本 self.core_num = sys_info["coreNum"] + # 核心数量 self.op_bank_path = sys_info["op_bank_path"] def debug(self, msg, *args, **kwargs): @@ -106,9 +140,13 @@ class TbeJob: :param args: :return: """ + # 获取处理后的消息 processed_msg = _get_message(msg, args) + # 创建日志消息对象 message = LogMessage(len(self.process_info), LogLevel.DEBUG, processed_msg) + # 将日志消息对象添加到process_info列表中 self.process_info.append(message) + # 使用系统日志记录器记录日志 self._sys_logger.debug(msg, *args, **kwargs) def info(self, msg, *args, **kwargs): @@ -118,9 +156,13 @@ class TbeJob: :param args: :return: """ + # 获取处理后的消息 processed_msg = _get_message(msg, args) + # 创建日志消息对象 message = LogMessage(len(self.process_info), LogLevel.INFO, processed_msg) + # 将日志消息对象添加到process_info列表中 self.process_info.append(message) + # 使用系统日志记录器记录日志 self._sys_logger.info(msg, *args, **kwargs) def warning(self, msg, *args, **kwargs): @@ -130,9 +172,13 @@ class TbeJob: :param args: :return: """ + # 获取处理后的消息 processed_msg = _get_message(msg, args) + # 创建日志消息对象 message = LogMessage(len(self.process_info), LogLevel.WARNING, processed_msg) + # 将日志消息对象添加到process_info列表中 self.process_info.append(message) + # 使用系统日志记录器记录警告信息 self._sys_logger.warning(msg, *args, **kwargs) def error(self, msg, *args, **kwargs): @@ -142,9 +188,13 @@ class TbeJob: :param args: :return: """ + # 获取处理后的消息 processed_msg = _get_message(msg, args) + # 创建一个LogMessage对象,包含消息的长度、日志级别和消息内容 message = LogMessage(len(self.process_info), LogLevel.ERROR, processed_msg) + # 将LogMessage对象添加到process_info列表中 self.process_info.append(message) + # 使用_sys_logger记录错误日志,msg为原始消息,args和kwargs为参数 self._sys_logger.error(msg, *args, **kwargs) def error_manager(self, msg, *args, **kwargs): @@ -154,30 +204,50 @@ class TbeJob: :param args: :return: """ + # 如果msg为空,则输出警告信息并返回 if not msg: self.warning("Get empty error manager message, op_name: {}".format(self.fusion_op_name)) return + # 初始化异常信息为None exception_info = None + # 获取融合操作名称 op_name = self.fusion_op_name + # 如果msg是Exception类型 if isinstance(msg, Exception): + # 遍历msg的参数 for arg in msg.args: + # 如果参数是字典类型且包含"errCode"键 if isinstance(arg, dict) and "errCode" in arg: + # 将异常信息赋值给exception_info exception_info = arg break + # 如果没有找到异常信息 if not exception_info: + # 输出错误信息 self.error("Exception message:{}".format(msg)) return + # 如果msg不是Exception类型 else: + # 将msg的第一个元素赋值给异常信息 exception_info = msg[0] + # 如果msg的长度大于等于2 if len(msg) >= 2: + # 将msg的第二个元素赋值给融合操作名称 op_name = msg[1] + # 如果异常信息不是字典类型或为空 if not isinstance(exception_info, dict) or not exception_info: + # 输出警告信息 self.warning("Get illegal error manager message, op_name: {}".format(self.fusion_op_name)) return + # 将异常信息中的op_name字段赋值为融合操作名称 exception_info["op_name"] = op_name + # 将异常信息转换为JSON格式 processed_msg = json.dumps(exception_info) + # 创建LogMessage对象 message = LogMessage(len(self.process_info), LogLevel.ERROR_MANAGER, processed_msg) + # 将LogMessage对象添加到process_info列表中 self.process_info.append(message) + # 输出异常信息 self._sys_logger.exception(msg, *args, **kwargs) def get_result(self): @@ -186,15 +256,26 @@ class TbeJob: :return: job process result string """ result = dict() + # 获取任务状态 result["status"] = self.status.value + # 获取任务源ID result["source_id"] = self.source_id + # 获取任务ID result["job_id"] = self.id + # 获取任务类型 result["job_type"] = self.type.value + # 获取融合操作名称 result["fusion_op_name"] = self.fusion_op_name + # 获取任务结果 
result["result"] = self.result process_info = [] + # 遍历任务处理信息 for info in self.process_info: + # 构造消息字典 msg = {"index": info.index, "level": info.level.value, "message": info.info} + # 将消息字典添加到处理信息列表中 process_info.append(msg) + # 将处理信息列表添加到结果字典中 result["process_info"] = process_info + # 将结果字典转换为JSON字符串并返回 return json.dumps(result) diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py b/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py index be60df96..44246954 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py @@ -29,6 +29,7 @@ class TbeJobManager: """ TBE compiler job manager """ def __init__(self): + # 定义一个字典,用于存储不同类型的任务及其对应的处理函数 self.job_handlers = { JobType.INITIALIZE_JOB: self.initialize_handler, JobType.FINALIZE_JOB: self.finalize_handler, @@ -41,24 +42,43 @@ class TbeJobManager: JobType.QUERY_JOB: self.query_handler } + # 定义一个字典,用于存储所有任务 self._all_jobs = {} + # 定义一个字典,用于存储已完成任务 self._finished_jobs = {} + # 定义一个字典,用于存储正在运行的任务 self._running_jobs = {} + # 定义一个字典,用于存储原始完成任务 self._raw_finish_jobs = {} + # 定义一个布尔值,用于判断TBE是否初始化 self.tbe_initialize = False + # 定义一个变量,用于存储初始化缓存 self.init_cache = None + # 定义一个字符串,用于存储参数调试路径 self.para_debug_path = "" + # 定义一个字符串,用于存储自动调优模式 self.auto_tiling_mode = "" + # 定义一个布尔值,用于判断是否离线调优 self.offline_tune = False + # 定义一个列表,用于存储调优操作 self.tune_op_list = [] + # 定义一个字符串,用于存储调优输出路径 self.tune_dump_path = "" + # 定义一个字符串,用于存储调优库路径 self.tune_bank_path = "" + # 定义一个列表,用于存储自动调优操作 self.auto_tune_op_list = [] + # 定义一个字典,用于存储预编译操作 self.pre_build_ops = {} + # 定义一个整数,用于存储融合编译需要同步的次数 self.fusion_need_sync = 0 + # 定义一个字典,用于存储导入的模块 self.imported_module = {} + # 定义一个字符串,用于存储SoC版本 self.soc_version = "" + # 定义一个整数,用于存储核心数量 self.core_num = 0 + # 定义一个字符串,用于存储操作库路径 self.op_bank_path = "" # license info self.rl_tune_switch = "" @@ -68,6 +88,7 @@ class TbeJobManager: self.pass_list = "" def __del__(self): + # 删除对象时调用reset方法 self.reset() def reset(self): @@ -75,22 +96,38 @@ class TbeJobManager: Reset the job manager :return: None """ + # 重置所有任务 self._all_jobs = {} + # 重置已完成任务 self._finished_jobs = {} + # 重置正在运行的任务 self._running_jobs = {} + # 重置原始已完成任务 self._raw_finish_jobs = {} + # 重置调试路径 self.para_debug_path = "" + # 重置自动切分模式 self.auto_tiling_mode = "" + # 重置离线调优 self.offline_tune = False + # 重置调优操作列表 self.tune_op_list = [] + # 重置调优导出路径 self.tune_dump_path = "" + # 重置调优银行路径 self.tune_bank_path = "" + # 重置自动调优操作列表 self.auto_tune_op_list = [] + # 重置预构建操作 self.pre_build_ops = [] + # 重置融合需要同步 self.fusion_need_sync = 0 + # 重置导入模块 self.imported_module = {} + # 如果tbe_initialize为True,则调用tbe_finalize方法 if self.tbe_initialize: tbe_finalize(self.auto_tiling_mode, self.offline_tune, self.init_cache) + # 重置tbe_initialize self.tbe_initialize = False self.init_cache = None self.soc_version = "" @@ -105,11 +142,17 @@ class TbeJobManager: """ job = None try: + # 将job_str转换为json格式 job_json = json.loads(job_str) + # 检查job_json的合法性 check_job_json(job_json) + # 获取job_id job_id = job_json["job_id"] + # 获取source_id source_id = job_json["source_id"] + # 获取job_type job_type = job_json["job_type"] + # 获取系统信息 sys_info = self._get_job_sys_info() fusion_op_name = "NA" if "fusion_op_name" not in job_json["job_content"] else job_json["job_content"][ "fusion_op_name"] @@ -140,173 +183,260 @@ class TbeJobManager: def 
initialize_handler(self, job: TbeJob): """ Initialize job handler """ + # 初始化系统信息 self._init_sys_info(job) + # 调用tbe_initialize函数初始化job res = tbe_initialize(job) + # 如果初始化失败,记录错误信息,并将job状态设置为JOB_FAILED if not res: job.error("Process Initialize Job failed, job json string:{}".format(job.json_string)) return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED) + # 如果auto_tiling_mode中包含"GA",则获取自动调优支持的操作列表 if "GA" in self.auto_tiling_mode: self.auto_tune_op_list = get_auto_tune_support_op_list(job) + # 设置tbe_initialize为True self.tbe_initialize = True + # 将job保存到init_cache中 self.init_cache = job + # 将job状态设置为JOB_SUCCESS return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS) def finalize_handler(self, job: TbeJob): """ Finalize job handler """ + # 如果tbe_initialize为False,则直接将job状态设置为JOB_SUCCESS if not self.tbe_initialize: return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS) + # 调用tbe_finalize函数,传入auto_tiling_mode和offline_tune参数 res = tbe_finalize(self.auto_tiling_mode, self.offline_tune, job) + # 如果finalize失败,记录错误信息,并将job状态设置为JOB_FAILED if not res: job.error("Process Finalize Job failed, job json string:{}".format(job.json_string)) return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED) + # 将job状态设置为JOB_SUCCESS return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS) def check_support_handler(self, job: TbeJob): """ Check Support job handler """ + # 调用check_support函数,检查job是否支持 res = check_support(job) + # 如果不支持,记录错误信息,并将job状态设置为JOB_FAILED if not res: job.error("Process CheckSupport Job failed, job json string:{}".format(job.json_string)) return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED) + # 更新导入的操作模块 self._update_imported_op_module(job) + # 将job状态设置为JOB_SUCCESS return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS) def select_format_handler(self, job: TbeJob): """ Select Format job handler """ + # 调用select_op_format函数,选择操作格式 res = select_op_format(job) + # 如果选择失败,记录错误信息,并将job状态设置为JOB_FAILED if not res: job.error("Process SelectFormat Job failed, job json string:{}".format(job.json_string)) return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED) + # 将job状态设置为JOB_SUCCESS return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS) def pre_compile_handler(self, job: TbeJob): """ Pre Compile job handler """ + # 调用parallel_pre_compile_op函数,对job进行预处理 res = parallel_pre_compile_op(job) + # 如果预处理失败,则记录错误信息,并将job状态设置为JOB_FAILED if not res: job.error("Process PreCompile Job failed, job json string:{}".format(job.json_string)) return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED) + # 将job添加到pre_build_ops字典中,以fusion_op_name为键 self.pre_build_ops[job.content["fusion_op_name"]] = job + # 将job状态设置为JOB_RUNNING return self.add_to_running_jobs(job) def compile_handler(self, job: TbeJob): """ Compile job handler """ + # 获取job中的compute_op_list compute_op_list = get_compute_op_list(job.content) + # 如果compute_op_list只有一个元素,则调用single_op_compile函数进行编译 if len(compute_op_list) == 1: # pylint: disable=no-else-return return self.single_op_compile(job) else: + # 调用before_build_process函数,对job进行预处理 before_build_process(job) + # 如果需要同步fusion,则调用sync_fusion_env函数进行同步 if self.fusion_need_sync: sync_fusion_env(self.fusion_need_sync, self.imported_module) self.fusion_need_sync = 0 + # 调用parallel_compile_fusion_op函数,对job进行编译 res = parallel_compile_fusion_op(job) + # 如果编译失败,则记录错误信息,并将job状态设置为JOB_FAILED if not res: job.error("Parallel_compile_fusion_op Job failed, job json string:{}".format(job.json_string)) return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED) + # 
将job状态设置为JOB_RUNNING return self.add_to_running_jobs(job) def single_op_compile(self, job: TbeJob): """Single operator compile""" + # 调用do_fuzz_build_tbe_op函数,对job进行编译 res = do_fuzz_build_tbe_op(job) + # 如果编译失败,则记录错误信息,并将job状态设置为JOB_FAILED if not res: job.error("Process do fuzz build tbe op failed, job json string:{}".format(job.json_string)) return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED) + # 如果job.result为"NOT_CHANGED",则调用before_build_process函数进行预处理,并调用build_single_pre_op函数进行编译 if job.result == "NOT_CHANGED": job.result = "" before_build_process(job) res = build_single_pre_op(job) + # 如果编译失败,则记录错误信息,并将job状态设置为JOB_FAILED if not res: job.error("Process build single pre op failed, job json string:{}".format(job.json_string)) return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED) + # 将job状态设置为JOB_RUNNING return self.add_to_running_jobs(job) + # 如果job.result为"SUCCESS",则将job状态设置为JOB_SUCCESS if job.result == "SUCCESS": return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS) + # 如果编译失败,则记录错误信息,并将job状态设置为JOB_FAILED job.error("Process do fuzz build tbe op failed, job json string:{}".format(job.json_string)) return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED) def tune_handler(self, job: TbeJob): """ Tune job handler """ before_build_process(job) + # 选择调优模式 tune_mode = self._select_tune_mode(job) + # 如果调优模式为不调优,则直接调用编译处理函数 if tune_mode == TuneMode.NO_TUNE: return self.compile_handler(job) + # 获取计算操作列表 compute_op_list = get_compute_op_list(job.content) + # 如果计算操作列表只有一个,则调用单操作调优函数 if len(compute_op_list) == 1: return self.single_op_tune(job) + # 否则调用融合操作调优函数 return self.fusion_op_tune(job) def single_op_tune(self, job: TbeJob): """Single operator tune""" + # 选择调优模式 tune_mode = self._select_tune_mode(job) + # 如果调优模式为强化学习调优 if tune_mode == TuneMode.RL_TUNE: + # 调用强化学习单操作调优函数 res = rl_tune_single_op(job) + # 如果调优失败,则记录错误信息,并将任务状态设置为失败 if not res: job.error( "Tune Job failed, tune type {}, job json string:{}".format(tune_mode, job.json_string)) return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED) + # 否则,如果需要同步融合环境,则调用同步融合环境函数 else: if self.fusion_need_sync: sync_fusion_env(self.fusion_need_sync, self.imported_module) self.fusion_need_sync = 0 + # 调用遗传算法调优函数 res = ga_tune(job) + # 如果调优失败,则记录错误信息,并调用编译处理函数 if not res: job.error("ga tune Job failed, job json string:{}".format(job.json_string)) return self.compile_handler(job) + # 如果任务状态为运行中 if job.status == JobStatus.JOB_RUNNING: + # 如果调优模式为强化学习调优,则更新导入的操作模块 if tune_mode == TuneMode.RL_TUNE: self._update_imported_op_module(job) + # 将任务添加到运行中任务列表 return self.add_to_running_jobs(job) + # 否则将任务添加到已完成任务列表,并设置任务状态为成功 return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS) def fusion_op_tune(self, job: TbeJob): """Fusion operator tune""" + # 选择调优模式 tune_mode = self._select_tune_mode(job) + # 如果需要同步融合环境,则调用同步融合环境函数 if self.fusion_need_sync: sync_fusion_env(self.fusion_need_sync, self.imported_module) self.fusion_need_sync = 0 + # 如果调优模式为强化学习调优,则调用强化学习融合操作调优函数 if tune_mode == TuneMode.RL_TUNE: res = rl_tune_fusion_op(job) + # 否则调用遗传算法调优函数 else: res = ga_tune(job) + # 如果调优失败,则记录错误信息,并将任务状态设置为失败 if not res: job.error( "Tune Job failed, tune type {}, job json string:{}".format(tune_mode, job.json_string)) return self.add_to_finished_jobs(job, JobStatus.JOB_FAILED) + # 如果任务状态为运行中,则将任务添加到运行中任务列表 if job.status == JobStatus.JOB_RUNNING: return self.add_to_running_jobs(job) + # 否则将任务添加到已完成任务列表,并设置任务状态为成功 return self.add_to_finished_jobs(job, JobStatus.JOB_SUCCESS) def query_handler(self, query_job: TbeJob): """ Query job handler """ 
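The query handler below walks a fixed sequence of job tables: finished, raw-finished (refreshed once from the compilation framework), running, then the catch-all table of all jobs. Stripped of the TBE details, that is a priority lookup over dictionaries keyed by (source_id, job_id); a minimal sketch, with the refresh step left as an optional stand-in callback:

def find_job(tables, source_id, job_id, refresh_raw_finished=None):
    """Search the job tables in priority order; optionally refresh 'raw_finished' once."""
    key = (source_id, job_id)
    for name, table in tables:
        if name == "raw_finished" and key not in table and refresh_raw_finished:
            refresh_raw_finished(table)   # pull newly finished tasks before giving up
        if key in table:
            return name, table[key]
    return None, None

tables = [("finished", {}),
          ("raw_finished", {}),
          ("running", {(1, 7): "compiling"}),
          ("all", {(1, 7): "compiling"})]
print(find_job(tables, 1, 7))  # ('running', 'compiling')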
+ # 获取查询任务的source_id和job_id target_source_id = query_job.content["source_id"] target_job_id = query_job.content["job_id"] + # 根据source_id和job_id获取已完成的任务 target_job = get_job(self._finished_jobs, target_source_id, target_job_id) + # 如果找到了已完成的任务 if target_job: + # 记录警告信息 query_job.warning("Query a finished job: {}".format(query_job.content)) + # 将查询任务的结果设置为已完成任务的结果 query_job.result = target_job.get_result() + # 将查询任务添加到已完成任务列表中,并返回成功状态 return self.add_to_finished_jobs(query_job, JobStatus.JOB_SUCCESS) + # 根据source_id和job_id获取未完成的任务 target_job = get_job(self._raw_finish_jobs, target_source_id, target_job_id) + # 如果未找到未完成的任务 if not target_job: + # 更新未完成的任务列表 self.update_raw_finished_jobs(query_job) + # 再次根据source_id和job_id获取未完成的任务 target_job = get_job(self._raw_finish_jobs, target_source_id, target_job_id) + # 如果找到了未完成的任务 if target_job: + # 记录调试信息 query_job.debug("Found job in raw finished jobs, source_id:{}, job_id:{}".format(target_source_id, target_job_id)) + # 将查询任务的结果设置为未完成任务的结果 query_job.result = target_job.get_result() + # 从未完成任务列表中删除该任务 del_job(self._raw_finish_jobs, target_job.source_id, target_job.id) + # 将未完成任务添加到已完成任务列表中,并返回成功状态 self.add_to_finished_jobs(target_job, target_job.status) return self.add_to_finished_jobs(query_job, JobStatus.JOB_SUCCESS) + # 根据source_id和job_id获取正在运行的任务 target_job = get_job(self._running_jobs, target_source_id, target_job_id) + # 如果找到了正在运行的任务 if target_job: + # 将查询任务的结果设置为正在运行任务的结果 query_job.result = target_job.get_result() + # 将查询任务添加到已完成任务列表中,并返回成功状态 return self.add_to_finished_jobs(query_job, JobStatus.JOB_SUCCESS) + # 根据source_id和job_id获取所有任务 target_job = get_job(self._all_jobs, target_source_id, target_job_id) + # 如果找到了所有任务 if target_job: + # 记录调试信息 query_job.debug("Found job in all jobs, source_id:{}, job_id:{}".format(target_source_id, target_job_id)) + # 记录调试信息 target_job.debug("Be Queried") + # 将查询任务的结果设置为所有任务的结果 query_job.result = target_job.get_result() + # 将查询任务添加到已完成任务列表中,并返回成功状态 return self.add_to_finished_jobs(query_job, JobStatus.JOB_SUCCESS) + # 如果没有找到任何任务,记录错误信息 query_job.error("Can't find job in finished/raw_finished/running jobs, source_id: {}".format(target_source_id)) + # 将查询任务的结果设置为空 query_job.result = "" + # 将查询任务添加到已完成任务列表中,并返回失败状态 return self.add_to_finished_jobs(query_job, JobStatus.JOB_FAILED) def _get_job_sys_info(self): @@ -314,10 +444,15 @@ class TbeJobManager: Get job manager system info :return: system info """ + # 创建一个字典,用于存储系统信息 sys_info = dict() + # 将DummyLogger添加到系统信息中 sys_info["logger"] = DummyLogger + # 将para_debug_path添加到系统信息中 sys_info["para_debug_path"] = self.para_debug_path + # 将tune_dump_path添加到系统信息中 sys_info["tune_dump_path"] = self.tune_dump_path + # 将offline_tune添加到系统信息中 sys_info["offline_tune"] = self.offline_tune # license info sys_info["rl_tune_switch"] = self.rl_tune_switch @@ -362,12 +497,17 @@ class TbeJobManager: :param job: :return: """ + # 获取计算操作列表 compute_op_info = get_compute_op_list(job.content)[0] + # 获取操作模块名称 op_module_name = compute_op_info["module_name"] + # 如果操作模块名称在已导入模块中,则增加引用次数 if op_module_name in self.imported_module.keys(): self.imported_module[op_module_name] = self.imported_module[op_module_name] + 1 + # 否则,将操作模块名称添加到已导入模块中,并设置引用次数为1 else: self.imported_module[op_module_name] = 1 + # 增加融合需要同步的次数 self.fusion_need_sync = self.fusion_need_sync + 1 def _select_tune_mode(self, job): @@ -376,18 +516,25 @@ class TbeJobManager: :param job: tbe tune job :return: NO_TUNE RL_TUNE or GA_TUNE """ + # 获取job的SocInfo中的autoTilingMode和offlineTune auto_tiling_mode = 
job.content["SocInfo"]["autoTilingMode"] offline_tune = job.content["SocInfo"]["offlineTune"] + # 获取job的full_name full_name = job.content["full_name"] + # 获取job的func_names func_names = get_func_names(job.content) + # 如果self.tune_op_list不为空且full_name不在self.tune_op_list中,则返回TuneMode.NO_TUNE if self.tune_op_list and full_name not in self.tune_op_list: return TuneMode.NO_TUNE + # 如果offline_tune为True,则返回TuneMode.RL_TUNE if offline_tune: return TuneMode.RL_TUNE + # 如果auto_tiling_mode中包含TuneMode.GA_TUNE.value,则遍历func_names,如果func_name.lower()在self.auto_tune_op_list中,则返回TuneMode.GA_TUNE if TuneMode.GA_TUNE.value in auto_tiling_mode: for func_name in func_names: if func_name.lower() in self.auto_tune_op_list: return TuneMode.GA_TUNE + # 如果auto_tiling_mode中包含TuneMode.RL_TUNE.value,则返回TuneMode.RL_TUNE if TuneMode.RL_TUNE.value in auto_tiling_mode: return TuneMode.RL_TUNE return TuneMode.NO_TUNE @@ -398,15 +545,22 @@ class TbeJobManager: :param query_job: query job :return: Node """ + # 获取已完成任务 new_finished_jobs = get_finish_tasks(query_job.source_id) + # 遍历已完成任务 for new_job in new_finished_jobs: + # 获取任务ID source_id = new_job["graph_id"] job_id = new_job["task_id"] + # 获取任务 target_job = get_job(self._running_jobs, source_id, job_id) + # 如果任务不存在,则报错 if not target_job: query_job.error("Can't get job, source id:{}, job id:{}".format(source_id, job_id)) continue + # 设置任务结果 target_job.result = new_job["op_res"] if "op_res" in new_job else new_job["result"] + # 如果任务类型为预编译任务,则进行预编译 if target_job.type == JobType.PRECOMPILE_JOB: op_name = target_job.content["fusion_op_name"] op_params = get_prebuild_output(op_name) @@ -415,13 +569,17 @@ class TbeJobManager: pre_compile_result["op_params"] = op_params pre_compile_result["core_type"] = new_job["core_type"] if "core_type" in new_job else "" target_job.result = json.dumps(pre_compile_result) + # 输出任务结果 target_job.info("Query result:{}".format(new_job["result"])) + # 如果任务状态码为0,则任务成功 if new_job["status_code"] == 0: target_job.status = JobStatus.JOB_SUCCESS target_job.info("Query info_msg:{}".format(new_job["info_msg"])) + # 否则任务失败 else: target_job.status = JobStatus.JOB_FAILED target_job.error("Query info_msg:{}".format(new_job["info_msg"])) + # 输出错误信息 if "err_args" in new_job: target_job.error("Query err_args:{}".format(new_job["err_args"])) if "except_msg" in new_job: @@ -429,7 +587,9 @@ class TbeJobManager: if "except_tuple_msg" in new_job: target_job.error_manager(new_job["except_tuple_msg"]) target_job.error("\nOriginal compile json: \n {}\n".format(target_job.json_string)) + # 将任务添加到已完成任务列表 post_job(self._raw_finish_jobs, target_job) + # 从运行中任务列表中删除任务 del_job(self._running_jobs, target_job.source_id, target_job.id) def add_to_finished_jobs(self, job, status): @@ -456,8 +616,11 @@ class TbeJobManager: class TuneMode(Enum): """Class of tune mode: NO_TUNE, GA, RL""" + # 不调优模式 NO_TUNE = "NO_TUNE" + # 遗传算法调优模式 GA_TUNE = "GA" + # 强化学习调优模式 RL_TUNE = "RL" @@ -469,18 +632,22 @@ class DummyLogger: @staticmethod def debug(msg, *args, **kwargs): + """Debug级别日志""" pass @staticmethod def info(msg, *args, **kwargs): + """Info级别日志""" pass @staticmethod def warning(msg, *args, **kwargs): + """Warning级别日志""" pass @staticmethod def error(msg, *args, **kwargs): + """Error级别日志""" pass @staticmethod @@ -497,10 +664,13 @@ def get_job(jobs, source_id, job_id): :return: job instance if found in job list None if not found in job list """ + # 如果source_id不在jobs的键中,返回None if source_id not in jobs.keys(): return None + # 如果job_id不在jobs[source_id]的键中,返回None if job_id not in 
jobs[source_id].keys(): return None + # 返回jobs[source_id][job_id] return jobs[source_id][job_id] @@ -526,9 +696,15 @@ def del_job(jobs, source_id, job_id): :param job_id: target job's job_id :return: bool True or False """ + # 判断source_id是否在jobs字典中 if source_id not in jobs.keys(): + # 如果不在,返回False return False + # 判断job_id是否在jobs[source_id]字典中 if job_id not in jobs[source_id].keys(): + # 如果不在,返回False return False + # 删除jobs[source_id]字典中的job_id键值对 del jobs[source_id][job_id] + # 返回True return True diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/parse/__init__.py b/src/mindspore2022/mindspore/python/mindspore/_extends/parse/__init__.py index 99b0a40e..f78c7e9c 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/parse/__init__.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/parse/__init__.py @@ -26,6 +26,7 @@ from .parser import (Parser, create_instance, is_supported_create_instance_type, get_object_description, get_class_attr_namespace_symbol, get_ms_class_name, get_ms_class_attr) +# 导入parser模块中的所有函数和类 __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol', 'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_ast_type', 'get_node_type', 'get_args_default_values', 'get_ast_namespace_symbol', 'get_operation_namespace_symbol', diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/parse/namespace.py b/src/mindspore2022/mindspore/python/mindspore/_extends/parse/namespace.py index df904e0a..2cd79562 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/parse/namespace.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/parse/namespace.py @@ -16,131 +16,136 @@ # ============================================================================ """Define the namespace of parse.""" -import builtins - -from mindspore import log as logger +import builtins # 导入内置模块builtins +from mindspore import log as logger # 从mindspore库导入log模块并重命名为logger class Namespace: """ - Base class of namespace for resolve variables. + 基类,用于解析变量命名空间。 Args: - name (str): The namespace's name. - dicts (dict): A list of dict containing the namespace's variable. + name (str): 命名空间的名称。 + dicts (dict): 包含命名空间变量的字典列表。 """ def __init__(self, name, *dicts): - self.name = name - self.dicts = dicts + self.name = name # 初始化命名空间名称 + self.dicts = dicts # 初始化包含变量的字典列表 def __contains__(self, name): + # 检查命名空间中是否包含指定名称的变量 for d in self.dicts: if name in d: return True return False def __getitem__(self, name): + # 获取命名空间中指定名称的变量 for d in self.dicts: if name in d: return d[name] - raise NameError(name) + raise NameError(name) # 如果未找到,抛出NameError def __repr__(self): + # 返回命名空间的字符串表示 return f'Namespace:{self.name}' class CellNamespace(Namespace): """ - Namespace for Cell object. + Cell对象的命名空间。 Args: - name (str): Valid module name, it can be imported. 
+ name (str): 可导入的有效模块名称。 """ def __init__(self, name): - mod_dict = vars(__import__(name, fromlist=['_'])) - builtins_dict = vars(builtins) - super().__init__(name, mod_dict, builtins_dict) + mod_dict = vars(__import__(name, fromlist=['_'])) # 导入模块并获取其变量字典 + builtins_dict = vars(builtins) # 获取内置模块的变量字典 + super().__init__(name, mod_dict, builtins_dict) # 调用父类初始化 def __getstate__(self): + # 获取对象的状态,用于序列化 return (self.name,) def __setstate__(self, state): + # 设置对象的状态,用于反序列化 name, = state - mod_dict = vars(__import__(name, fromlist=['_'])) - builtins_dict = vars(builtins) - super().__init__(name, mod_dict, builtins_dict) - + mod_dict = vars(__import__(name, fromlist=['_'])) # 重新导入模块 + builtins_dict = vars(builtins) # 重新获取内置模块字典 + super().__init__(name, mod_dict, builtins_dict) # 重新初始化父类 class ClosureNamespace(Namespace): """ - Namespace for function closure. + 函数闭包的命名空间。 Args: - fn (Function): A python function. + fn (Function): 一个Python函数。 """ def __init__(self, fn): - name = f'{fn.__module__}..<{fn.__name__}>' - names = fn.__code__.co_freevars - cells = fn.__closure__ - ns = dict(zip(names, cells or ())) - super().__init__(name, ns) + name = f'{fn.__module__}..<{fn.__name__}>' # 构造命名空间名称 + names = fn.__code__.co_freevars # 获取函数的自由变量名称 + cells = fn.__closure__ # 获取函数的闭包 + ns = dict(zip(names, cells or ())) # 构造命名空间字典 + super().__init__(name, ns) # 调用父类初始化 def __getitem__(self, name): + # 获取命名空间中指定名称的变量 d, = self.dicts try: - return d[name].cell_contents + return d[name].cell_contents # 返回闭包内容 except ValueError: - raise UnboundLocalError(name) - + raise UnboundLocalError(name) # 如果未找到,抛出UnboundLocalError class ClassMemberNamespace(Namespace): """ - Namespace of a class's closure. + 类闭包的命名空间。 Args: - obj (Object): A python class object. + obj (Object): 一个Python类对象。 """ def __init__(self, obj): - self.__class_member_namespace__ = True - label = f'{obj.__module__}..<{obj.__class__.__name__}::{id(obj)}>' - super().__init__(label, obj) + self.__class_member_namespace__ = True # 标记为类成员命名空间 + label = f'{obj.__module__}..<{obj.__class__.__name__}::{id(obj)}>' # 构造命名空间标签 + super().__init__(label, obj) # 调用父类初始化 def __getitem__(self, name): + # 获取命名空间中指定名称的变量 d, = self.dicts if name == "self": - return d + return d # 如果名称是self,返回对象本身 if name == "namespace": - return self + return self # 如果名称是namespace,返回命名空间对象 try: if hasattr(d, name): - return getattr(d, name) - return d.__dict__[name] + return getattr(d, name) # 如果对象有该属性,返回属性值 + return d.__dict__[name] # 否则从对象字典中获取 except ValueError: - raise UnboundLocalError(name) + raise UnboundLocalError(name) # 如果未找到,抛出UnboundLocalError except KeyError: logger.info(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}', so will return None.") - raise AttributeError(name) - + raise AttributeError(name) # 如果未找到属性,记录日志并抛出AttributeError class ClassAttrNamespace(Namespace): """ - Namespace of a class. + 类的命名空间。 Args: - obj (Object): A python class object. 
+ obj (Object): 一个Python类对象。 """ def __init__(self, obj): - name = f'{obj.__module__}..<{obj.__class__.__name__}::{id(obj)}>' - super().__init__(name, obj) + name = f'{obj.__module__}..<{obj.__class__.__name__}::{id(obj)}>' # 构造命名空间名称 + super().__init__(name, obj) # 调用父类初始化 def __getattr__(self, name): + # 获取命名空间中指定名称的属性 d, = self.dicts try: if hasattr(d, name): - return getattr(d, name) - return d.__dict__[name] + return getattr(d, name) # 如果对象有该属性,返回属性值 + return d.__dict__[name] # 否则从对象字典中获取 except ValueError: - raise UnboundLocalError(name) + raise UnboundLocalError(name) # 如果未找到,抛出UnboundLocalError except KeyError: logger.info(f"'{d.__class__.__name__ }' object has no attribute or method: '{name}', so will return None.") - raise AttributeError(name) + raise AttributeError(name) # 如果未找到属性,记录日志并抛出AttributeError + diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/parse/resources.py b/src/mindspore2022/mindspore/python/mindspore/_extends/parse/resources.py index b381ec3e..5f35e5d9 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/parse/resources.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/parse/resources.py @@ -17,7 +17,7 @@ """Resources for ast tree parse.""" import ast import math - + from mindspore import RowTensor, SparseTensor, COOTensor, CSRTensor from mindspore.ops import functional as F, composite as C from mindspore.ops.composite import multitype_ops @@ -25,16 +25,16 @@ from mindspore._c_expression import security from . import standard_method as M from . import trope as T from .namespace import CellNamespace - + # namespace define functional_ns = CellNamespace('mindspore.ops.functional') composite_ns = CellNamespace('mindspore.ops.composite') trope_ns = CellNamespace('mindspore._extends.parse.trope') - + NO_IMPLEMENT = None # not implemented SYMBOL_UNDEFINE = 0xFF # Undefined var and function - -# Some space set aside for readability of code + +# 一些空间设置以提高代码可读性 parse_object_map = { # ast grammar ast.Add: (trope_ns, 'add'), @@ -64,17 +64,17 @@ parse_object_map = { ast.IsNot: (trope_ns, 'is_not'), ast.In: (trope_ns, 'contains'), ast.NotIn: (trope_ns, 'not_contains'), - + # operation symbol type 'getitem': (composite_ns, 'getitem'), 'ms_iter': (composite_ns, 'ms_iter'), 'ms_next': (composite_ns, 'ms_next'), 'hasnext': (composite_ns, 'hasnext'), - + # undefined type SYMBOL_UNDEFINE: (None, 'undefine'), } - + # Operation symbols corresponding to ast grammar ops_symbol_map = { # ast grammar @@ -88,13 +88,13 @@ ops_symbol_map = { ast.LShift: '<<', ast.RShift: '>>', ast.BitXor: '^', - + # undefined type SYMBOL_UNDEFINE: '', } - -# Escape an object to another object, eg: system function(len,xxx) -# Some space set aside for readability of code + +# 将一个对象转为另一个对象,例如:系统函数(len,xxx) +# 一些空间设置以提高代码可读性 convert_object_map = { T.add: multitype_ops.add, T.sub: multitype_ops.sub, @@ -124,7 +124,7 @@ convert_object_map = { T.is_not: F.is_not, T.contains: multitype_ops.in_, T.not_contains: multitype_ops.not_in_, - + # system function T.len: M.ms_len, T.bool_: M.bool_, @@ -134,7 +134,7 @@ convert_object_map = { T.zip: C.zip_operation, T.enumerate: M.enumerate_, T.isinstance: M.isinstance_, - + # custom define operation T.iter: M.ms_iter, T.next: M.ms_next, @@ -145,7 +145,7 @@ convert_object_map = { T.make_slice: F.make_slice, T.range: F.make_range, T.while_cond: M.while_cond, - + # lib function math.floor: NO_IMPLEMENT, math.trunc: NO_IMPLEMENT, @@ -154,13 +154,14 @@ convert_object_map = { math.sin: NO_IMPLEMENT, math.cos: NO_IMPLEMENT, math.tan: 
NO_IMPLEMENT, - + # user defined RowTensor: F.make_row_tensor, SparseTensor: F.make_sparse_tensor, COOTensor: F.make_coo_tensor, CSRTensor: F.make_csr_tensor } - + +# 如果不启用安全性,则将 T.print 映射到 F.print_ if not security.enable_security(): - convert_object_map[T.print] = F.print_ + convert_object_map[T.print] = F.print_ \ No newline at end of file diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/parse/standard_method.py b/src/mindspore2022/mindspore/python/mindspore/_extends/parse/standard_method.py index fb6cbbb8..34731c01 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/parse/standard_method.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/parse/standard_method.py @@ -17,10 +17,10 @@ """standard_method""" from dataclasses import dataclass - + from mindspore import Tensor, Parameter, CSRTensor, COOTensor from mindspore import dtype as mstype - + from ..._checkparam import Validator as validator from ...ops import functional as F from ...ops import operations as P @@ -32,10 +32,10 @@ from ...ops.composite.multitype_ops import _constexpr_utils as const_utils from ...ops.composite.multitype_ops import _compile_utils as compile_utils from ...ops.operations._inner_ops import Format from ...ops.primitive import constexpr - - + + __all__ = ['MultitypeFuncGraph', 'env_get', 'hyper_add', 'zeros_like', 'ones_like'] - + shape_ = P.Shape() dtype_ = P.DType() abs_ = P.Abs() @@ -46,30 +46,30 @@ _format = Format() _reduce_sum_default = P.ReduceSum() _reduce_sum_keepdims = P.ReduceSum(True) _mean_keepdims = P.ReduceMean(True) - + itemsize_map = {mstype.bool_: 1, mstype.int8: 1, mstype.uint8: 1, mstype.float16: 2, mstype.int16: 2, mstype.uint16: 2, mstype.float32: 4, mstype.int32: 4, mstype.uint32: 4, mstype.float64: 8, mstype.int64: 8, mstype.uint64: 8} - + nan_tensor = Tensor(float('nan'), dtype=mstype.float32) - - + + def mean(x, axis=(), keep_dims=False): """ Reduces a dimension of a tensor by averaging all elements in the dimension. - + Args: axis (Union[None, int, tuple(int), list(int)]): Dimensions of reduction, - when axis is None or empty tuple, reduce all dimensions. Default: (). + when axis is None or empty tuple, reduce all dimensions. Default: (). keep_dims (bool): Whether to keep the reduced dimensions. Default: False. - + Returns: Tensor, has the same data type as input tensor. - + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - + Examples: >>> import numpy as np >>> from mindspore import Tensor @@ -82,36 +82,36 @@ def mean(x, axis=(), keep_dims=False): axis = () reduce_mean = P.ReduceMean(keep_dims) return reduce_mean(x, axis) - - + + def all_(x, axis=(), keep_dims=False): """ Check all array elements along a given axis evaluate to True. - + Args: x (Tensor): A Tensor to be reduced. axis (Union[None, int, tuple(int)): Dimensions of reduction. keep_dims (bool): Whether to keep the reduced dimensions. - + Returns: Tensor, has the same data type as x. """ - + if axis is None: axis = () reduce_all = P.ReduceAll(keep_dims) return reduce_all(x, axis) - - + + def any_(x, axis=(), keep_dims=False): """ Check any array element along a given axis evaluate to True. - + Args: x (Tensor): A Tensor to be reduced. axis (Union[None, int, tuple(int)): Dimensions of reduction. keep_dims (bool): Whether to keep the reduced dimensions. - + Returns: Tensor, has the same data type as x. 
""" @@ -119,59 +119,59 @@ def any_(x, axis=(), keep_dims=False): axis = () reduce_any = P.ReduceAny(keep_dims) return reduce_any(x, axis) - - + + def size_(x): """ Return the number of elements in tensor `x`. - + Note: To strictly follow Numpy's behaviour, return 1 for tensor scalar. - + Args: x (Tensor): Input tensor. - + Returns: size(int). """ if not shape_(x): return size_op_(x) + 1 return size_op_(x) - - + + def itemsize_(x): """ Return length of one tensor element in bytes. - + Args: x (Tensor): Input tensor. - + Returns: itemsize(int). """ return get_itemsize(x.dtype) - - + + def nbytes_(x): """ Return total number of bytes taken by the tensor. - + Args: x (Tensor): Input tensor. - + Returns: nbytes(int). """ return itemsize_(x) * F.shape_mul(shape_(x)) - - + + def strides_(x): """ Return the tuple of bytes to step in each dimension when traversing a tensor. - + Args: x (Tensor): Input tensor. - + Returns: strides (tuple[int]). """ @@ -184,12 +184,12 @@ def strides_(x): stride *= tensor_shape[j] strides += (stride,) return strides - - + + def astype(x, dtype, copy=True): # pylint: disable=redefined-outer-name """ Return a copy of the tensor, casted to a specified type. - + Args: dtype (Union[:class:`mindspore.dtype`, str]): Designated tensor dtype, can be in format of :class:`mindspore.dtype.float32` or `float32`. @@ -197,16 +197,16 @@ def astype(x, dtype, copy=True): # pylint: disable=redefined-outer-name copy (bool, optional): By default, astype always returns a newly allocated tensor. If this is set to false, the input tensor is returned instead of a copy if possible. Default: True. - + Returns: Tensor, with the designated dtype. - + Raises: TypeError: If `dtype` has types not specified above, or values cannot be understood. - + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - + Examples: >>> import numpy as np >>> from mindspore import Tensor @@ -219,35 +219,35 @@ def astype(x, dtype, copy=True): # pylint: disable=redefined-outer-name if not copy and dtype == x.dtype: return x return F.cast(x, dtype) - - + + def transpose(x, *axis): r""" Return a view of the tensor with axes transposed. - + For a 1-D tensor this has no effect, as a transposed vector is simply the same vector. For a 2-D tensor, this is a standard matrix transpose. For a n-D tensor, if axes are given, their order indicates how the axes are permuted. If axes are not provided and tensor.shape = (i[0], i[1],...i[n-2], i[n-1]), then tensor.transpose().shape = (i[n-1], i[n-2], ... i[1], i[0]). - + Args: axes(Union[None, tuple(int), list(int), int], optional): If axes is None or blank, tensor.transpose() will reverse the order of the axes. If axes is tuple(int) or list(int), tensor.transpose() will transpose the tensor to the new axes order. If axes is int, this form is simply intended as a convenience alternative to the tuple/list form. - + Returns: Tensor, has the same dimension as input tensor, with axes suitably permuted. - + Raises: TypeError: If input arguments have types not specified above. ValueError: If the number of `axes` is not euqal to a.ndim. - + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - + Examples: >>> import numpy as np >>> from mindspore import Tensor @@ -259,32 +259,32 @@ def transpose(x, *axis): ndim = F.rank(x) perm = check_transpose_axis_const(axis, ndim) return F.transpose(x, perm) - - + + # `tensor.T` is used as a property in graph mode T_ = transpose - - + + def reshape(x, *shape): """ Give a new shape to a tensor without changing its data. 
- + Args: shape(Union[int, tuple(int), list(int)]): The new shape should be compatible with the original shape. If an integer, then the result will be a 1-D array of that length. One shape dimension can be -1. In this case, the value is inferred from the length of the array and remaining dimensions. - + Returns: Tensor, with new specified shape. - + Raises: TypeError: If new_shape is not integer, list or tuple, or `x` is not tensor. ValueError: If new_shape is not compatible with the original shape. - + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - + Examples: >>> from mindspore import Tensor >>> from mindspore import dtype as mstype @@ -297,18 +297,18 @@ def reshape(x, *shape): """ new_shape = check_reshape_shp_const(shape) return F.reshape(x, new_shape) - - + + def ravel(x): """ Return a contiguous flattened tensor. - + Returns: Tensor, a 1-D tensor, containing the same elements of the input. - + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - + Examples: >>> import numpy as np >>> from mindspore import Tensor @@ -318,27 +318,27 @@ def ravel(x): (24,) """ return reshape(x, (-1,)) - - + + def flatten(x, order='C'): r""" Return a copy of the tensor collapsed into one dimension. - + Args: order (str, optional): Can choose between 'C' and 'F'. 'C' means to flatten in row-major (C-style) order. 'F' means to flatten in column-major (Fortran-style) order. Only 'C' and 'F' are supported. Default: 'C'. - + Returns: Tensor, has the same data type as input. - + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - + Raises: TypeError: If `order` is not string type. ValueError: If `order` is string type, but not 'C' or 'F'. - + Examples: >>> import numpy as np >>> from mindspore import Tensor @@ -350,30 +350,30 @@ def flatten(x, order='C'): order = check_flatten_order_const(order) if order == 'C': return F.reshape(x, (-1,)) - + perm = F.make_range(0, F.rank(x)) new_order = F.tuple_reversed(perm) return F.reshape(F.transpose(x, new_order), (-1,)) - - + + def swapaxes(x, axis1, axis2): """ Interchange two axes of a tensor. - + Args: axis1 (int): First axis. axis2 (int): Second axis. - + Returns: Transposed tensor, has the same data type as the input. - + Raises: TypeError: If `axis1` or `axis2` is not integer. ValueError: If `axis1` or `axis2` is not in the range of :math:`[-ndim, ndim-1]`. - + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - + Examples: >>> import numpy as np >>> from mindspore import Tensor @@ -383,12 +383,12 @@ def swapaxes(x, axis1, axis2): (4,3,2) """ axis1, axis2 = check_swapaxes_axis_const((axis1, axis2), x.ndim) - + if axis1 == axis2: return x if axis1 > axis2: axis1, axis2 = axis2, axis1 - + perm = F.make_range(0, x.ndim) new_perm = None if axis2 + 1 < x.ndim: @@ -397,27 +397,27 @@ def swapaxes(x, axis1, axis2): else: new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \ perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1] - + return F.transpose(x, new_perm) - - + + def squeeze(x, axis=None): """ Remove single-dimensional entries from the shape of a tensor. - + Args: axis (Union[None, int, list(int), tuple(int)], optional): Default is None. - + Returns: Tensor, with all or a subset of the dimensions of length 1 removed. - + Raises: TypeError: If input arguments have types not specified above. - ValueError: If specified axis has shape entry :math:`> 1`. - + ValueError: If specified axis has shape entry :math:`> 1`. 
+ Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - + Examples: >>> import numpy as np >>> from mindspore import Tensor @@ -432,27 +432,27 @@ def squeeze(x, axis=None): # yield squeezed shape based on the axes new_shape = prepare_shape_for_squeeze_const(shape, axis) return F.reshape(x, new_shape) - - + + def argmax(x, axis=None): """ Returns the indices of the maximum values along an axis. - + Args: axis (int, optional): By default, the index is into the flattened array, otherwise along the specified axis. Defaults to None. - + Returns: Tensor, array of indices into the array. It has the same shape as a.shape with the dimension along axis removed. - + Raises: ValueError: if axis is out of range. - + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - + Examples: >>> import numpy as np >>> from mindspore import Tensor @@ -468,28 +468,28 @@ def argmax(x, axis=None): else: axis = check_axis_in_range_const(axis, F.rank(x)) return P.Argmax(axis)(x) - - + + def argmin(x, axis=None): """ Returns the indices of the minimum values along an axis. - + Args: a (Union[int, float, bool, list, tuple, Tensor]): Input array. axis (int, optional): By default, the index is into the flattened array, otherwise along the specified axis. Defaults to None. - + Returns: Tensor, array of indices into the array. It has the same shape as a.shape with the dimension along axis removed. - + Raises: ValueError: if axis is out of range. - + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - + Examples: >>> import numpy as np >>> from mindspore import Tensor @@ -506,16 +506,16 @@ def argmin(x, axis=None): axis = check_axis_in_range_const(axis, F.rank(x)) # P.Argmin is currently not supported return P.Argmax(axis)(F.neg_tensor(x)) - - + + def cumsum(x, axis=None, dtype=None): """ Returns the cumulative sum of the elements along a given axis. - + Note: If ``x.dtype`` is :class:`int8`, :class:`int16` or :class:`bool`, the result `dtype` will be elevated to :class:`int32`, :class:`int64` is not supported. - + Args: x (Tensor): Input tensor. axis (int, optional): Axis along which the cumulative sum is computed. The @@ -523,13 +523,13 @@ def cumsum(x, axis=None, dtype=None): dtype (:class:`mindspore.dtype`, optional): If not specified, stay the same as original, tensor, unless it has an integer dtype with a precision less than :class:`float32`. In that case, :class:`float32` is used. - + Returns: Tensor. - + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - + Examples: >>> import numpy as np >>> from mindspore import Tensor @@ -552,24 +552,24 @@ def cumsum(x, axis=None, dtype=None): if dtype is not None and original_dtype != dtype: return cumsum_(x, axis).astype(dtype, copy=False) return cumsum_(x, axis) - - + + def copy(x): """ Returns a copy of the tensor. - + Note: The current implementation does not support `order` argument. - + Args: x (Tensor): Input tensor. - + Returns: Copied tensor. - + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - + Examples: >>> import numpy as np >>> from mindspore import Tensor @@ -590,12 +590,12 @@ def copy(x): x = x / 1.0 x = x.astype(origin_dtype) return x - - + + def max(x, axis=None, keepdims=False, initial=None, where=True): # pylint: disable=redefined-builtin """ Returns the maximum of a tensor or maximum along an axis. - + Args: x (Tensor): Input Tensor. axis (None or int or tuple of ints, optional): defaults to None. 
Axis or @@ -613,17 +613,17 @@ def max(x, axis=None, keepdims=False, initial=None, where=True): # pylint: disab A boolean array which is broadcasted to match the dimensions of array, and selects elements to include in the reduction. If non-default value is passed, initial must also be provided. - + Returns: Tensor or scalar, maximum of input tensor. If `axis` is None, the result is a scalar value. If `axis` is given, the result is an array of dimension ``a.ndim - 1``. - + Raises: TypeError: if the input is not a tensor. - + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - + Examples: >>> import numpy as np >>> from mindspore import Tensor @@ -635,12 +635,12 @@ def max(x, axis=None, keepdims=False, initial=None, where=True): # pylint: disab """ return compile_utils.reduce_(x, P.ReduceMax(keepdims), cmp_fn=F.maximum, axis=axis, keepdims=keepdims, initial=initial, where=where) - - + + def min(x, axis=None, keepdims=False, initial=None, where=True): # pylint: disable=redefined-builtin """ Returns the minimum of a tensor or minimum along an axis. - + Args: a (Tensor): Input data. axis (None or int or tuple of ints, optional): defaults to None. Axis or @@ -658,17 +658,17 @@ def min(x, axis=None, keepdims=False, initial=None, where=True): # pylint: disab A boolean array which is broadcasted to match the dimensions of array, and selects elements to include in the reduction. If non-default value is passed, initial must also be provided. - + Returns: Tensor or scalar, minimum of `a`. If axis is None, the result is a scalar value. If `axis` is given, the result is an array of dimension ``a.ndim - 1``. - + Raises: TypeError: if the input is not a tensor. - + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - + Examples: >>> import numpy as np >>> from mindspore import Tensor @@ -680,26 +680,26 @@ def min(x, axis=None, keepdims=False, initial=None, where=True): # pylint: disab """ return compile_utils.reduce_(x, P.ReduceMin(keepdims), cmp_fn=F.minimum, axis=axis, keepdims=keepdims, initial=initial, where=where) - - + + def resize(x, *new_shape): """ Changes shape and size of array in-place. - + Note: Instead of changing the size of the input array and returns nothing as in numpy, this method returns a new Tensor with the input size. Numpy argument `refcheck` is not supported. - + Args: new_shape (Union[ints, tuple of ints]): Shape of resized array. - + Returns: Tensor. - + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - + Examples: >>> from mindspore import numpy as np >>> x = np.array([[0, 1], [2, 3]]) @@ -723,12 +723,12 @@ def resize(x, *new_shape): else: res = flattened[:new_size] return res.reshape(new_shape) - - + + def diagonal(x, offset=0, axis1=0, axis2=1): """ Returns specified diagonals. - + Args: offset (int, optional): Offset of the diagonal from the main diagonal. Can be positive or negative. Defaults to main diagonal. @@ -738,16 +738,16 @@ def diagonal(x, offset=0, axis1=0, axis2=1): axis2 (int, optional): Axis to be used as the second axis of the 2-D sub-arrays from which the diagonals should be taken. Defaults to second axis. - + Returns: Tensor, if `a` is 2-D, then `a` 1-D array containing the diagonal. - + Raises: ValueError: if the input tensor has less than two dimensions. 
- + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - + Examples: >>> import mindspore.numpy as np >>> a = np.arange(4).reshape(2,2) @@ -762,7 +762,7 @@ def diagonal(x, offset=0, axis1=0, axis2=1): if ndim < 2: const_utils.raise_value_error('diagonal requires an array of at least two dimensions') dtype = x.dtype - + axes = check_axis_valid((axis1, axis2), ndim) perm = () for i in range(ndim): @@ -770,10 +770,10 @@ def diagonal(x, offset=0, axis1=0, axis2=1): perm += (i,) perm += axes x = x.transpose(perm) - + shape = x.shape n, m = shape[-2:] - + e = F.eye(n, m, dtype) if offset >= m or offset <= -n: e = F.fill(dtype, (n, m), 0) @@ -788,10 +788,10 @@ def diagonal(x, offset=0, axis1=0, axis2=1): e_lower = e[0:n+offset:1, ...] e = P.Concat(0)((e_upper, e_lower)).astype(dtype) e = P.BroadcastTo(shape)(e) - + prod = F.tensor_mul(x, e) res = F.reduce_sum(prod.astype(mstype.float32), -1) - + begin = () for i in range(ndim-2): begin += (0,) @@ -805,12 +805,12 @@ def diagonal(x, offset=0, axis1=0, axis2=1): size += (last_dim_end,) res = F.tensor_slice(res, begin, size) return res.astype(dtype) - - + + def trace(x, offset=0, axis1=0, axis2=1, dtype=None): """ Returns the sum along diagonals of the array. - + Args: offset (int, optional): Offset of the diagonal from the main diagonal. Can be positive or negative. Defaults to main diagonal. @@ -822,16 +822,16 @@ def trace(x, offset=0, axis1=0, axis2=1, dtype=None): second axis. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. - + Returns: Tensor, sum_along_diagonals. - + Raises: ValueError: if the input tensor has less than two dimensions. - + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - + Examples: >>> import mindspore.numpy as np >>> x = np.eye(3) @@ -846,36 +846,36 @@ def trace(x, offset=0, axis1=0, axis2=1, dtype=None): return F.fill(dtype, shape[:-1], 0) res = F.reduce_sum(d.astype(mstype.float32), -1) return res.astype(dtype) - - + + def take(x, indices, axis=None, mode='clip'): """ Takes elements from an array along an axis. - + Args: a (Tensor): Source array with shape `(Ni…, M, Nk…)`. indices (Tensor): The indices with shape `(Nj...)` of the values to extract. axis (int, optional): The axis over which to select values. By default, the flattened input array is used. Defaults to None. mode ('raise', 'wrap', 'clip', optional): Defaults to "clip". - + - edge: Pads with the edge values of `arr`. - raise: Raises an error; - wrap: Wraps around; - clip: Clips to the range. 'clip' mode means that all indices that are too large are replaced by the index that addresses the last element along that axis. Note that this disables indexing with negative numbers. - + Returns: Tensor, the indexed result. - + Raises: ValueError: if axis is out of range. TypeError: if the input is not a Tensor. 
- + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - + Examples: >>> import mindspore.numpy as np >>> a = np.array([4, 3, 5, 7, 6, 8]) @@ -893,12 +893,12 @@ def take(x, indices, axis=None, mode='clip'): a = x ndim = a.ndim axis = check_axis_in_range_const(axis, ndim) - + shape_a = a.shape shape_indices = indices.shape size_indices = indices.size indices = compile_utils.check_indices(shape_a[axis], indices, mode) - + # reshapes indices to shape (Ni..., Nj..., Nk) shape_ni = tuple_slice(shape_a, None, axis) shape_nk = tuple_slice(shape_a, axis + 1, None) @@ -907,15 +907,15 @@ def take(x, indices, axis=None, mode='clip'): indices = indices.reshape(shape_indices) shape_indices = shape_ni + (indices.size,) + shape_nk indices = P.BroadcastTo(shape_indices)(indices) - + res = F.gather_d(a, axis, indices) return res.reshape(shape_out) - - + + def choose(x, choices, mode='clip'): """ Construct an array from an index array and a list of arrays to choose from. - + Args: choices (sequence of arrays): Choice arrays. `a` and all of the `choices` must be broadcastable to the same shape. If `choices` is itself an array, then @@ -923,24 +923,24 @@ def choose(x, choices, mode='clip'): is taken as defining the "sequence". mode ('raise', 'wrap', 'clip', optional): Specifies how indices outside ``[0, n-1]`` will be treated: - + 'raise' – raise an error (default); - + 'wrap' – wrap around; - + 'clip' – clip to the range. 'clip' mode means that all indices that are too large are replaced by the index that addresses the last element along that axis. Note that this disables indexing with negative numbers. - + Returns: Tensor, the merged result. - + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - + Raises: ValueError: if ``len(condlist) != len(choicelist)``. - + Examples: >>> import mindspore.numpy as np >>> choices = [[0, 1, 2, 3], [10, 11, 12, 13], [20, 21, 22, 23], [30, 31, 32, 33]] @@ -965,7 +965,7 @@ def choose(x, choices, mode='clip'): for choice in choicelist: tmp.append(P.BroadcastTo(shape_choice)(choice)) choices = F.stack(tmp) - + if x.ndim == 0 or choices.ndim == 0: const_utils.raise_value_error('input cannot be scalars') a = P.BroadcastTo(shape_choice)(x) @@ -974,7 +974,7 @@ def choose(x, choices, mode='clip'): a = a.astype(mstype.int32) choices = choices.astype(mstype.int32) a = compile_utils.check_indices(choices.shape[0], a, mode, allow_negative_index=False) - + grids = [] ndim = len(a.shape) for i in range(ndim): @@ -985,12 +985,12 @@ def choose(x, choices, mode='clip'): grid = P.Stack(-1)(grids) indices = P.Concat(-1)((a.reshape(a.shape + (1,)), grid)) return F.gather_nd(choices, indices).astype(dtype) - - + + def searchsorted(x, v, side='left', sorter=None): """ Finds indices where elements should be inserted to maintain order. - + Args: v (Union[int, float, bool, list, tuple, Tensor]): Values to insert into `a`. side ('left', 'right', optional): If 'left', the index of the first suitable @@ -999,16 +999,16 @@ def searchsorted(x, v, side='left', sorter=None): sorter (Union[int, float, bool, list, tuple, Tensor]): 1-D optional array of integer indices that sort array `a` into ascending order. They are typically the result of argsort. - + Returns: Tensor, array of insertion points with the same shape as `v`. - + Raises: ValueError: if argument for `side` or `sorter` is invalid. 
- + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - + Examples: >>> from mindspore import numpy as np >>> x = np.array([1,2,3,4,5]) @@ -1030,7 +1030,7 @@ def searchsorted(x, v, side='left', sorter=None): less_op = F.tensor_le if side == 'left' else F.tensor_lt i = F.fill(mstype.int32, shape, 0) j = F.fill(mstype.int32, shape, a.size) - + sort_range = F.make_range(get_log2_size(F.shape_mul(a.shape) + 1)) for _ in sort_range: mid = (i - F.neg_tensor(j))//2 @@ -1038,29 +1038,29 @@ def searchsorted(x, v, side='left', sorter=None): i = F.select(mask, i, mid) j = F.select(mask, mid, j) return j - - + + def fill(x, value): """ Fills the array with a scalar value. - + Note: Unlike Numpy, tensor.fill() will always returns a new tensor, instead of filling the original tensor. - + Args: value (Union[None, int, float, bool]): All elements of a will be assigned this value. - + Returns: Tensor, with the original dtype and shape as input tensor. - + Raises: TypeError: If input arguments have types not specified above. ValueError: If `shape` has entries < 0. - + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - + Examples: >>> import numpy as np >>> from mindspore import Tensor @@ -1077,30 +1077,30 @@ def fill(x, value): if not isinstance(value, (int, float, bool)): const_utils.raise_type_error("input value must be a scalar.") return F.fill(x.dtype, x.shape, value) - - + + def ptp(x, axis=None, keepdims=False): """ The name of the function comes from the acronym for "peak to peak". - + Note: Numpy arguments `dtype` and `out` are not supported. - + Args: x (Tensor): Input tensor. axis (Union[None, int, tuple(int)]): Axis or axes along which the range is computed. The default is to compute the variance of the flattened array. Default: None. keepdims (bool): Default is False. - + Returns: Tensor. - + Raises: TypeError: if the input is not a tensor. - + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - + Examples: >>> from mindspore import Tensor >>> x = Tensor([[4.0, 9.0, 2.0, 10.0], [6.0, 9.0, 7.0, 12.0]]).astype("float32") @@ -1116,21 +1116,21 @@ def ptp(x, axis=None, keepdims=False): else: check_axis_type(axis, True, True, False) axis = check_axis_valid(axis, x.ndim) - + return x.max(axis, keepdims) - x.min(axis, keepdims) - - + + def clip(x, xmin, xmax, dtype=None): """ Clips (limits) the values in an array. - + Given an interval, values outside the interval are clipped to the interval edges. For example, if an interval of :math:`[0, 1]` is specified, values smaller than 0 become 0, and values larger than 1 become 1. - + Note: Currently, clip with `nan` is not supported. - + Args: x (Tensor): Tensor containing elements to clip. xmin (Tensor, scalar, None): Minimum value. If None, clipping is not performed @@ -1141,14 +1141,14 @@ def clip(x, xmin, xmax, dtype=None): to match their shapes. dtype (:class:`mindspore.dtype`, optional): defaults to None. Overrides the dtype of the output Tensor. - + Returns: Tensor, a tensor with the elements of `x`, but where values < `xmin` are replaced with `xmin`, and those > `xmax` with `xmax`. - + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - + Examples: >>> from mindspore import Tensor >>> x = Tensor([1, 2, 3, -4, 0, 3, 2, 0]).astype("float32") @@ -1176,20 +1176,20 @@ def clip(x, xmin, xmax, dtype=None): if dtype is not None and dtype != x.dtype: return x.astype(dtype) return x - - + + def var(x, axis=None, ddof=0, keepdims=False): """ Compute the variance along the specified axis. 
The variance is the average of the squared deviations from the mean, i.e., :math:`var = mean(abs(x - x.mean())**2)`. - + Return the variance, which is computed for the flattened array by default, otherwise over the specified axis. - + Note: Numpy arguments `dtype`, `out` and `where` are not supported. - + Args: x (Tensor): A Tensor to be calculated. axis (Union[None, int, tuple(int)]): Axis or axes along which the variance is computed. @@ -1197,13 +1197,13 @@ def var(x, axis=None, ddof=0, keepdims=False): ddof (int): Means Delta Degrees of Freedom. Default: 0. The divisor used in calculations is :math:`N - ddof`, where :math:`N` represents the number of elements. keepdims (bool): Default: `False`. - + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - + Returns: Standard deviation tensor. - + Examples: >>> import mindspore.numpy as np >>> input_x = np.array([1., 2., 3., 4.]) @@ -1214,7 +1214,7 @@ def var(x, axis=None, ddof=0, keepdims=False): return nan_tensor.astype(x.dtype) if not isinstance(ddof, int) or not isinstance(keepdims, int): const_utils.raise_type_error("integer argument expected") - + if axis is None: axis = () else: @@ -1226,43 +1226,43 @@ def var(x, axis=None, ddof=0, keepdims=False): x_sum = _reduce_sum_keepdims(x_pow, axis) else: x_sum = _reduce_sum_default(x_pow, axis) - + if axis == (): axis = F.make_range(x.ndim) nums = 1 for ax in axis: nums *= x.shape[ax] return F.tensor_div(x_sum, nums - ddof) - - + + def std(x, axis=None, ddof=0, keepdims=False): """ Compute the standard deviation along the specified axis. The standard deviation is the square root of the average of the squared deviations from the mean, i.e., :math:`std = sqrt(mean(abs(x - x.mean())**2))`. - + Return the standard deviation, which is computed for the flattened array by default, otherwise over the specified axis. - + Note: Numpy arguments `dtype`, `out` and `where` are not supported. - + Args: x (Tensor): A Tensor to be calculated. axis (Union[None, int, tuple(int)]): Axis or axes along which the standard deviation is computed. Default: `None`. - + If `None`, compute the standard deviation of the flattened array. ddof (int): Means Delta Degrees of Freedom. The divisor used in calculations is :math:`N - ddof`, where :math:`N` represents the number of elements. Default: 0. keepdims: Default: `False`. - + Returns: Standard deviation tensor. - + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - + Examples: >>> import mindspore.numpy as np >>> input_x = np.array([1., 2., 3., 4.]) @@ -1271,16 +1271,16 @@ def std(x, axis=None, ddof=0, keepdims=False): """ x_var = var(x, axis, ddof, keepdims) return F.tensor_pow(x_var, 0.5) - - + + def sum(x, axis=None, dtype=None, keepdims=False, initial=None): # pylint: disable=redefined-builtin """ Return sum of array elements over a given axis. - + Note: Numpy arguments `out`, `where`, `casting`, `order`, `subok`, `signature`, and `extobj` are not supported. - + Args: x (Union[int, float, bool, list, tuple, Tensor]): Elements to sum. axis (Union[None, int, tuple(int)]): Axis or axes along which a sum is performed. Default: None. @@ -1296,19 +1296,19 @@ def sum(x, axis=None, dtype=None, keepdims=False, initial=None): # pylint: disab sub-classes of ndarray, however any non-default value will be. If the sub-class method does not implement keepdims any exceptions will be raised. initial (scalar): Starting value for the sum. - + Returns: Tensor. A tensor with the same shape as input, with the specified axis removed. 
If input tensor is a 0-d array, or if axis is None, a scalar is returned. - + Raises: TypeError: If input is not array_like or `axis` is not int or tuple of ints or `keepdims` is not integer or `initial` is not scalar. ValueError: If any axis is out of range or duplicate axes exist. - + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - + Examples: >>> import mindspore.numpy as np >>> input_x = np.array([-1, 0, 1]).astype('int32') @@ -1328,7 +1328,7 @@ def sum(x, axis=None, dtype=None, keepdims=False, initial=None): # pylint: disab axis = () else: axis = check_and_canonicalize_axes(axis, x.ndim) - + if not check_type_support(input_x.dtype, 'GPU', (mstype.float64, mstype.float32, mstype.float16)): input_x = input_x.astype(mstype.float32) if 0 in x.shape: @@ -1340,29 +1340,29 @@ def sum(x, axis=None, dtype=None, keepdims=False, initial=None): # pylint: disab if initial is not None: res += initial return res.astype(dtype) - - + + def repeat(x, repeats, axis=None): """ Repeat elements of an array. - + Args: x (Tensor): Input tensor. repeats (Union[int, tuple, list]): The number of repetitions for each element. `repeats` is broadcasted to fit the shape of the given axis. axis (int, optional): The axis along which to repeat values. By default, use the flattened input tensor, and return a flat output tensor. - + Returns: Tensor, has the same shape as input tensor except along the given axis. - + Raises: ValueError: if axis is out of range. TypeError: if input is not a Tensor. - + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - + Examples: >>> import mindspore.numpy as np >>> x = np.array(3) @@ -1391,7 +1391,7 @@ def repeat(x, repeats, axis=None): const_utils.raise_type_error('axes should be integers') check_axis_in_range_const(axis, x.ndim) axis = axis + x.ndim if axis < 0 else axis - + if len(repeats) == 1: repeats = repeats[0] if repeats == 0: @@ -1406,83 +1406,83 @@ def repeat(x, repeats, axis=None): if rep != 0: repeated_subs.append(repeat_elements(sub, rep, axis)) return P.Concat(axis)(repeated_subs) - - + + def getitem(data, index): """Implementation of `getitem`.""" return data.__getitem__(index) - - + + def setitem(data, index, value): """Implementation of `setitem`.""" return data.__setitem__(index, value) - - + + def item(data, *args): """Implementation of `item`.""" return compile_utils.tensor_item(data, *args) - - + + def itemset(data, *args): """Implementation of `itemset`.""" return compile_utils.tensor_itemset(data, *args) - - + + def ms_iter(xs): """Implementation of `iter`.""" return xs.__ms_iter__() - - + + def ms_next(it): """Implementation of `next`.""" return it.__ms_next__() - - + + def hasnext(it): """Implementation of `hasnext`.""" return it.__ms_hasnext__() - - + + def ms_len(data): """Implementation of `len`.""" return data.__len__() - - + + def floor(x): """Implementation of `floor`.""" return x.__floor__() - - + + def trunc(x): """Implementation of `trunc`.""" return x.__trunc__() - - + + def uadd(x): """Implementation of `uadd`.""" return x.__pos__() - - + + def usub(x): """Implementation of `usub`.""" return x.__neg__() - - + + def scalar_truediv(x, y): """Implementation of `scalar_truediv`.""" return x.__truediv__(y) - - + + def scalar_floordiv(x, y): """Implementation of `scalar_floordiv`.""" return x.__floordiv__(y) - - + + def bool_(x): """Implementation of `bool`.""" return x.__bool__() - - + + def enumerate_(x, start=0): """Enumerate list or tuple or tensor.""" x_type = F.typeof(x) @@ -1496,22 +1496,22 @@ def enumerate_(x, start=0): else: ret = 
zip(range(start, start + len(x)), x) return ret - - + + def expand_tensor_as(x, y): """Expand tensor""" broadcast_to = P.BroadcastTo(shape_(y)) return broadcast_to(x) - - + + def expand_dims(x, axis): """ Insert a dimension of shape 1 at the specified axis of Tensor """ check_is_int(axis, 'axis') return P.ExpandDims()(x, axis) - - + + def masked_fill(x, mask, value): """ Fills elements of self tensor with value where mask is True. @@ -1523,34 +1523,34 @@ def masked_fill(x, mask, value): mask = P.BroadcastTo(mask_shape)(mask) check_value_type('value', value, [int, float], "Tensor") return C.array_ops.masked_fill(x, mask, value) - - + + def narrow(x, axis, start, length): """ Returns a narrowed tensor from input tensor. The dimension axis is input from start to start + length. """ return F.narrow(x, axis, start, length) - - + + def view(x, *shape): """Reshape tensor, if shape is -1, reshape tensor into one dimension""" shape = check_view_shape(shape) return F.reshape(x, shape) - - + + @constexpr def check_is_tuple(x): """check whether x is tuple.""" return isinstance(x, mstype.Tuple) - - + + @constexpr def check_is_func(x): """check whether x is function.""" return isinstance(x, mstype.function_type) - - + + def isinstance_(x, base_type): """Determine whether x is an instance of base.""" x_type = F.typeof(x) @@ -1565,8 +1565,8 @@ def isinstance_(x, base_type): if check_is_func(F.typeof(base_type)) and base_type.__is_csr_func__(): cmp_type = mstype.csr_tensor_type return check_type_same(x_type, cmp_type) - - + + def while_cond(x): """For while condition, if the condition is a tensor, the loop will not be unrolled""" if F.issubclass_(F.typeof(x), F.typeof(mstype.tensor)): @@ -1574,8 +1574,8 @@ def while_cond(x): if is_cond: return F.cast(x, mstype.bool_) return x - - + + def coo_to_csr(x): """convert coo to csr.""" row_indices = x.indices[:, 0] @@ -1587,33 +1587,33 @@ def coo_to_csr(x): values = x.values[sort_idx] indptr = F.coo2csr(row_indices, x.shape[0]) return CSRTensor(indptr, col_indices, values, x.shape) - - + + def coo_to_dense(x): """convert coo to dense.""" zeros_tensor = F.zeros(x.shape, x.values.dtype) return F.tensor_scatter_update(zeros_tensor, x.indices, x.values) - - + + def csr_to_coo(x): """convert csr to coo.""" row_indices = F.csr2coo(x.indptr, x.values.shape[0]) coo_indices = P.Stack(1)((row_indices, x.indices)) return COOTensor(coo_indices, x.values, x.shape) - - + + def csr_to_dense(x): """convert csr to dense.""" coo_tensor = x.to_coo() return coo_tensor.to_dense() - - + + @constexpr def empty_tensor(dtype): """Return empty tensor""" return Tensor([], dtype) - - + + @constexpr def check_type_same(x_type, base_type): """Check x_type is same as base_type.""" @@ -1630,10 +1630,10 @@ def check_type_same(x_type, base_type): slice: mstype.Slice, } sparse_mstype_set = (mstype.csr_tensor_type,) - + has_int = False has_tensor = False - + def to_target_type(origin_type): try: if isinstance(origin_type, type): @@ -1642,7 +1642,7 @@ def check_type_same(x_type, base_type): ret_type = pytype_to_mstype[origin_type] elif origin_type in sparse_mstype_set: ret_type = origin_type - + if ret_type == mstype.Int: nonlocal has_int has_int = True @@ -1661,30 +1661,30 @@ def check_type_same(x_type, base_type): if (isinstance(x_type, mstype.Bool) and has_int) or (isinstance(x_type, mstype.ref_type) and has_tensor): return True return isinstance(x_type, target_type) - - + + @constexpr def get_itemsize(x_type): """get itemsize from tensor's dtype.""" return itemsize_map[x_type] - - + + @constexpr 
def check_is_tensor(x): """check whether x is tensor.""" if isinstance(x, mstype.tensor_type): return True return False - - + + @constexpr def check_is_tuple_or_list_or_tensor(x, op_name, arg_name): """check whether x is list or tuple or tensor.""" if isinstance(x, (mstype.List, mstype.Tuple, mstype.tensor_type)): return True raise TypeError(f"For '{op_name}', the '{arg_name}' should be tuple or list or tensor, but got {x}.") - - + + @constexpr def check_is_const_int(x, op_name, arg_name): """check whether x is const int.""" @@ -1693,16 +1693,16 @@ def check_is_const_int(x, op_name, arg_name): if not isinstance(x, int): raise TypeError(f"For '{op_name}', the '{arg_name}' should be a const int number, but got {x}.") return True - - + + @constexpr def check_is_tensor_bool_cond(shp): """check if tensor is a bool condition""" if shp in ((), (1,)): return True raise ValueError(f"Only tensor which shape is () or (1,) can be converted to bool, but got tensor shape is {shp}") - - + + @constexpr def const_tensor_to_bool(x): """convert bool tensor to bool condition""" @@ -1715,8 +1715,8 @@ def const_tensor_to_bool(x): return bool(x[0]) raise ValueError( f"Only tensor which shape is () or (1,) can be converted to bool, but got tensor shape is {x.shape}") - - + + @constexpr def check_view_shape(x): """Check view function input shape""" @@ -1727,8 +1727,8 @@ def check_view_shape(x): raise ValueError(f"Only one tuple is needed, but got {x}") x = x[0] return x - - + + # convert normal param_check functions to constexpr functions check_astype_dtype_const = constexpr(validator.check_astype_dtype) check_transpose_axis_const = constexpr(validator.check_transpose_axis) @@ -1751,154 +1751,154 @@ check_type_support = constexpr(validator.check_type_support) check_is_int = constexpr(validator.check_is_int) check_type_name = constexpr(validator.check_type_name) check_value_type = constexpr(validator.check_value_type) - - + + def tensor_bool(x): """tensor as condition, if is constant, return immediate bool value""" is_cond = check_is_tensor_bool_cond(F.shape(x)) if is_cond and F.isconstant(x): return const_tensor_to_bool(x) return F.cast(x, mstype.bool_) - - + + def and_(x, y): """Implementation of `and` (`&`).""" return x.__and__(y) - - + + def or_(x, y): """Implementation of `or` (`|`).""" return x.__or__(y) - - + + def matmul(x, y): """Implementation of `matmul` (`@`).""" return x.__matmul__(y) - - + + def float_bool(x): """Implementation of `float_bool`.""" return x != 0.0 - - + + def int_bool(x): """Implementation of `int_bool`.""" return x != 0 - - + + def str_bool(x): """Implementation of `str_bool`.""" if x == "": return False return True - - + + def list_bool(x): """Implementation of `tuple_bool`.""" return len(x) != 0 - - + + def tuple_bool(x): """Implementation of `tuple_bool`.""" return len(x) != 0 - - + + def dict_bool(x): """Implementation of `dict_bool`.""" return len(x) != 0 - - + + def none_bool(x): """Implementation of `none_bool`.""" return False - - + + def func_bool(x): """Implementation of `func_bool`.""" return True - - + + def float_floordiv(x, y): """Implementation of `float_floordiv`.""" return floor(x / y) - - + + ############# # Iteration # ############# - - + + @dataclass(frozen=True) class SequenceIterator: """ SequenceIterator is a util dataclass for iterating sequence object. - + Iterator to use for sequences like List, Array. 
""" - + idx: int seq: list - + @core(ignore_values=True) def __ms_hasnext__(self): """Whether the index is past the length of the sequence.""" return self.idx < ms_len(self.seq) - + @core(ignore_values=True) def __ms_next__(self): """Return the next element and a new iterator.""" return self.seq[self.idx], SequenceIterator(self.idx + 1, self.seq) - - + + def list_iter(xs): """Iterator for List.""" return SequenceIterator(0, xs) - - + + def array_iter(xs): """Iterator for Array.""" return SequenceIterator(0, xs) - - + + def tuple_next(xs): """Next tuple.""" return xs[0], tail(xs) - - + + def tuple_hasnext(xs): """Whether the tuple is empty or not.""" return len(xs) > 0 - - + + def list_next(xs): """Next list.""" return xs[0], tail(xs) - - + + def list_hasnext(xs): """Whether the list is empty or not.""" return len(xs) > 0 - - + + # pylint: disable=redefined-outer-name def list_append(self_, item): return _append(self_, item) - - + + def list_insert(self_, index, obj): """Insert into list""" return _insert(self_, index, obj) - + ################# # Array methods # ################# - - + + def to_array(x): """Implementation of `to_array`.""" return x.__ms_to_array__() - - + + def filter_(fun, iter_): """Support the use of built-in function filter.""" result = [] @@ -1906,71 +1906,71 @@ def filter_(fun, iter_): if fun(elem): result.append(elem) return result - + ################## # Sparse methods # ################## - - + + def csr_astype(x, dtype): """Implementation of `astype` for CSRTensor.""" data = x.values.astype(dtype) return F.make_csr_tensor(x.indptr, x.indices, data, x.shape) - - + + def csr_sum(x, axis): """Implementation of `sum` for CSRTensor.""" return F.csr_reduce_sum(x, axis) - - + + def csr_abs(x): """Implementation of `abs` for CSRTensor.""" data = F.absolute(x.values) return F.make_csr_tensor(x.indptr, x.indices, data, x.shape) - - + + def csr_mv(x, dense_vector): """Implementation of `abs` for CSRTensor.""" check_value_type('dense_vector', dense_vector, (Tensor,), 'CSRTensor.mv') return F.csr_mv(x, dense_vector) - - + + def csr_to_tuple(x): """Implementation of `to_tuple` for CSRTensor.""" res = (x.indptr, x.indices, x.values, x.shape) return res - - + + def coo_astype(x, dtype): """Implementation of `astype` for COOTensor.""" data = x.values.astype(dtype) return F.make_coo_tensor(x.indices, data, x.shape) - - + + def coo_to_tuple(x): """Implementation of `to_tuple` for COOTensor.""" return x.indices, x.values, x.shape - - + + def coo_abs(x): """Implementation of `abs` for COOTensor.""" data = F.absolute(x.values) return F.make_coo_tensor(x.indices, data, x.shape) - + ################ # Sparse Attrs # ################ - - + + def sparse_size_(x): """ Return the size of SparseTensor.values. That is the number of non-zero values in SparseTensor. """ return size_(x.values) - - + + def sparse_ndim_(x): """ Return the ndim of SparseTensor, according to its dense shape. 
""" - return F.tuple_len(x.shape) + return F.tuple_len(x.shape) \ No newline at end of file diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/parse/trope.py b/src/mindspore2022/mindspore/python/mindspore/_extends/parse/trope.py index 84ec9562..f5d06128 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/parse/trope.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/parse/trope.py @@ -50,55 +50,45 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt', def MakeTuple(*elts): # pragma: no cover - """Tuple builder.""" + """Tuple builder.""" # 创建元组的构造函数 raise RuntimeError('This operation is not meant to be called directly.') - def make_dict(key, value): # pragma: no cover - """Dict builder.""" + """Dict builder.""" # 创建字典的构造函数 raise RuntimeError('This operation is not meant to be called directly.') - def make_list(*elts): # pragma: no cover - """List builder.""" + """List builder.""" # 创建列表的构造函数 raise RuntimeError('This operation is not meant to be called directly.') - def make_slice(*elts): # pragma: no cover - """Slice builder.""" + """Slice builder.""" # 创建切片的构造函数 raise RuntimeError('This operation is not meant to be called directly.') - def make_range(*elts): # pragma: no cover - """Range tuple builder.""" + """Range tuple builder.""" # 创建范围元组的构造函数 raise RuntimeError('This operation is not meant to be called directly.') - def switch(cond, tb, fb): # pragma: no cover - """Switch statement, returns one of the two values.""" + """Switch statement, returns one of the two values.""" # 返回两个值中的一个的开关语句 raise RuntimeError('This operation is not meant to be called directly.') - def hasnext(it): # pragma: no cover - """Hasnext function.""" + """Hasnext function.""" # 判断是否有下一个元素的函数 raise RuntimeError('This operation is not meant to be called directly.') - def to_array(x): - """The to_array function.""" + """The to_array function.""" # 将输入转换为数组的函数 raise RuntimeError('This operation is not meant to be called directly.') - def not_contains(x): # pragma: no cover - """Not in function.""" + """Not in function.""" # 判断元素是否不在集合中的函数 raise RuntimeError('This operation is not meant to be called directly.') - def while_cond(x): # pragma: no cover - """Not in function.""" + """Not in function.""" # 判断条件是否成立的函数 raise RuntimeError('This operation is not meant to be called directly.') - def bool_(x): # pragma: no cover - """judge true function.""" + """judge true function.""" # 判断一个值是否为真值的函数 raise RuntimeError('This operation is not meant to be called directly.') diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/remote/kernel_build_server.py b/src/mindspore2022/mindspore/python/mindspore/_extends/remote/kernel_build_server.py index 72f589f3..02678768 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/remote/kernel_build_server.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/remote/kernel_build_server.py @@ -16,27 +16,37 @@ import os from mindspore import log as logger from mindspore._extends.parallel_compile.akg_compiler.akg_process import create_akg_parallel_process - - + + class Messager: - + '''Messager''' - + def __init__(self, fdin, fdout): + """ + 初始化 Messager 类 + + Args: + fdin: 输入文件描述符 + fdout: 输出文件描述符 + """ self.fdin = fdin self.fdout = fdout self.fin = os.fdopen(fdin, "r") self.fout = os.fdopen(fdout, "w") self.message = '' - + def __del__(self): + """ + 删除 Messager 实例时关闭文件描述符 + """ os.close(self.fdin) os.close(self.fdout) - + def get_message(self): """ - Get message from remote - + 从远程获取消息 
+ Returns: message """ @@ -58,13 +68,13 @@ class Messager: self.send_ack() self.exit() return self.message - + def send_res(self, res, keep_format=True): """ - Send result to remote - + 发送结果到远程 + Args: - keep_format: True or False + keep_format: True 或 False """ logger.debug(f"[OUT] {str(res)}") if keep_format: @@ -72,7 +82,7 @@ class Messager: else: res_str = str(res).replace('\n', '').replace('\r', '').replace(' ', '') tag = '[~]' # The same as client kTAG - + # Not write by print(tag + res_str, flush=True) any more try: self.fout.write(tag + res_str + "\n") @@ -82,69 +92,76 @@ class Messager: self.exit() finally: pass - + def send_ack(self, success=True): """ - Send ack to remote - + 发送确认消息到远程 + Args: - success: True or False + success: True 或 False """ if success: self.send_res('ACK') else: self.send_res('ERR') - + def loop(self): """ - Messaging loop + 消息循环 """ while True: self.handle() - + def run(self): + """运行消息循环""" self.loop() - + def handle(self): """ - A interface communicates with remote. - + 与远程通信的接口。 + Note: - All subclasses should override this interface. + 所有子类应该重写此接口。 """ raise NotImplementedError - + def exit(self): """ - A interface handles the procedure before exit. - + 处理退出之前的程序。 + Note: - All subclasses should override this interface. + 所有子类应该重写此接口。 """ raise NotImplementedError - - + + class AkgBuilder(): """Akg building wrapper""" - + def __init__(self, platform): + """ + 初始化 AkgBuilder 类 + + Args: + platform: 平台标识 + """ self.platform = platform self.attrs = None - + def create(self, process_num, waitime): - """ Create akg processor""" + """ 创建 akg 处理器""" self.akg_processor = create_akg_parallel_process(process_num, waitime, self.platform) - + def accept_json(self, json): - """ Accept json""" + """ 接受 json 数据""" return self.akg_processor.accept_json(json) - + def compile(self): - """Compile""" + """编译""" return self.akg_processor.compile(self.attrs) - + def handle(self, messager, arg): - """Handle message about akg""" + """处理关于 akg 的消息""" if arg == 'AKG/START': messager.send_ack() process_num_str = messager.get_message() @@ -172,7 +189,8 @@ class AkgBuilder(): break else: raise RuntimeError("Unknown message type: %s" % arg) - - + + def get_logger(): - return logger + """获取日志记录器""" + return logger \ No newline at end of file diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/remote/kernel_build_server_akg.py b/src/mindspore2022/mindspore/python/mindspore/_extends/remote/kernel_build_server_akg.py index bd1ee1fd..a81c1100 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/remote/kernel_build_server_akg.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/remote/kernel_build_server_akg.py @@ -20,19 +20,24 @@ from mindspore._extends.remote.kernel_build_server import Messager, get_logger, class AkgMessager(Messager): ''' - Default Messager for akg kernels. - It works as a server, communicating with c++ client. + 默认的 akg 内核消息处理器。 + 它作为一个服务器,与 C++ 客户端进行通信。 ''' def __init__(self, fdin, fdout): + """ + 初始化 AkgMessager 实例。 + :param fdin: 输入文件描述符 + :param fdout: 输出文件描述符 + """ super().__init__(fdin, fdout) get_logger().info("[TRACE] Akg Messager init...") self.akg_builder = AkgBuilder("default") def handle(self): """ - Communicate with remote client. 
- Reference protocol between them at PR#4063 + 与远程客户端进行通信。 + 它们之间的参考协议见 PR#4063。 """ arg = self.get_message() if "AKG" in arg: @@ -42,11 +47,18 @@ class AkgMessager(Messager): self.exit() def exit(self): + """ + 退出 AkgMessager。 + """ get_logger().info("[TRACE] Akg Messager Exit...") exit() if __name__ == '__main__': + """ + 程序入口。 + 检查命令行参数并初始化 AkgMessager 实例。 + """ warnings.simplefilter("ignore") if len(sys.argv) != 3: raise Exception('Incorrect argv: {}'.format(sys.argv)) diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/remote/kernel_build_server_ascend.py b/src/mindspore2022/mindspore/python/mindspore/_extends/remote/kernel_build_server_ascend.py index dc276dca..65469320 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/remote/kernel_build_server_ascend.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/remote/kernel_build_server_ascend.py @@ -16,23 +16,24 @@ import sys import warnings import json - + from mindspore._extends.parallel_compile.tbe_compiler.tbe_job_manager import TbeJobManager from mindspore._extends.remote.kernel_build_server import Messager, get_logger, AkgBuilder - - + + class AscendMessager(Messager): """ Ascend Messager It works as a server, communicating with c++ client. """ - + # 初始化方法 def __init__(self, fdin, fdout): super().__init__(fdin, fdout) get_logger().info("[TRACE] Ascend Messager init...") self.tbe_builder = TbeJobManager() self.akg_builder = AkgBuilder("ASCEND") - + + # 处理与远程客户端的通信 def handle(self): """ Communicate with remote client. @@ -51,7 +52,7 @@ class AscendMessager(Messager): self.exit() finally: pass - + if "job_type" in job_json: res = self.tbe_builder.job_handler(arg) self.send_res(res) @@ -59,17 +60,18 @@ class AscendMessager(Messager): get_logger().error("[TRACE] Request is not a TBE Job message: {}".format(arg)) self.send_ack(False) self.exit() - + + # 退出方法 def exit(self): self.tbe_builder.reset() get_logger().info("[TRACE] Ascend Messager Exit...") exit() - - + + if __name__ == '__main__': warnings.simplefilter("ignore") if len(sys.argv) != 3: raise Exception('Incorrect argv: {}'.format(sys.argv)) get_logger().debug(f"[TRACE] argv: {str(sys.argv)}") messager = AscendMessager(int(sys.argv[1]), int(sys.argv[2])) - messager.run() + messager.run() \ No newline at end of file diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/utils.py b/src/mindspore2022/mindspore/python/mindspore/_extends/utils.py index 18f47470..ec8fdb3b 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/utils.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/utils.py @@ -22,6 +22,21 @@ def cell_attr_register(fn=None, attrs=None): """ Cell init attributes register. + Args: + fn (function, optional): The __init__ function of the cell. Defaults to None. + attrs (list(string) | string, optional): A list of attributes to register. + Can be a list of strings or a single string. Defaults to None. + + Returns: + function: The original function wrapped with attribute registration. + + 该函数用于注册cell类的初始化属性。 + 通过装饰器模式,将cell类的__init__函数的参数保存为operator的属性。 + 如果未提供fn参数,则返回装饰器函数wrap_cell,否则返回包装后的__init__函数。 + """ + """ + Cell init attributes register. + Registering the decorator of the built-in operator cell __init__ function will add save all the parameters of __init__ as operator attributes. 
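The hunk above documents the decorator pattern behind cell_attr_register. As a rough, simplified sketch of that pattern (the names record_init_args and DemoCell are assumptions for illustration, not MindSpore's actual implementation), recording the __init__ arguments on the instance can look like this:

import inspect
from functools import wraps

def record_init_args(fn):
    """Simplified stand-in for the cell_attr_register pattern described above."""
    @wraps(fn)
    def deco(self, *args, **kwargs):
        # Bind the actual call to __init__'s signature so default values are captured too.
        bound = inspect.signature(fn).bind(self, *args, **kwargs)
        bound.apply_defaults()
        arguments = list(bound.arguments.values())[1:]  # drop `self`
        # Save "class name + argument list" on the instance, mirroring cell_init_args.
        self.cell_init_args = type(self).__name__ + str(arguments)
        return fn(self, *args, **kwargs)
    return deco

class DemoCell:
    @record_init_args
    def __init__(self, channels, kernel_size=3):
        self.channels = channels
        self.kernel_size = kernel_size

# DemoCell(16).cell_init_args == "DemoCell[16, 3]"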
@@ -34,8 +49,38 @@ def cell_attr_register(fn=None, attrs=None): """ def wrap_cell(fn): + """ + 装饰器函数,用于记录类的初始化参数。 + + Args: + fn (function): 需要被装饰的函数。 + + Returns: + function: 返回一个新的函数,该函数在调用时会记录传递给fn函数的参数。 + + """ @wraps(fn) def deco(self, *args, **kwargs): + """ + 这是一个装饰器函数,用于记录类的初始化参数。 + + Args: + self: 类实例对象。 + *args: 传递给被装饰函数的可变位置参数。 + **kwargs: 传递给被装饰函数的可变关键字参数。 + attrs: 可选参数,指定要记录的属性。可以是字符串或字符串列表。 + + Returns: + None + + Raises: + ValueError: 如果attrs不是字符串或字符串列表,或者attrs中的元素不是字符串时抛出。 + + 该函数的主要作用是在类实例初始化时,记录传递给__init__方法的参数。 + 如果attrs为None,则记录所有传递给__init__方法的参数(不包括self)。 + 如果attrs为字符串或字符串列表,则只记录指定的属性。 + 记录的参数将被保存为实例的cell_init_args属性,格式为"类名+参数列表"。 + """ arguments = [] if attrs is None: bound_args = inspect.signature(fn).bind(self, *args, **kwargs) diff --git a/src/mindspore2022/mindspore/python/mindspore/boost/__init__.py b/src/mindspore2022/mindspore/python/mindspore/boost/__init__.py index 255c5d7d..de6cbe31 100644 --- a/src/mindspore2022/mindspore/python/mindspore/boost/__init__.py +++ b/src/mindspore2022/mindspore/python/mindspore/boost/__init__.py @@ -19,16 +19,25 @@ accumulation and so on. Note: This feature is a beta feature, and we are still improving its functionality. """ +# 从当前包的boost模块导入AutoBoost类 from .boost import AutoBoost +# 从当前包的base模块导入OptimizerProcess和ParameterProcess类 from .base import OptimizerProcess, ParameterProcess +# 从当前包的boost_cell_wrapper模块导入BoostTrainOneStepCell和BoostTrainOneStepWithLossScaleCell类 from .boost_cell_wrapper import BoostTrainOneStepCell, BoostTrainOneStepWithLossScaleCell +# 从当前包的less_batch_normalization模块导入LessBN类 from .less_batch_normalization import LessBN +# 从当前包的grad_freeze模块导入GradientFreeze, FreezeOpt和freeze_cell类或函数 from .grad_freeze import GradientFreeze, FreezeOpt, freeze_cell +# 从当前包的grad_accumulation模块导入GradientAccumulation类 from .grad_accumulation import GradientAccumulation +# 从当前包的adasum模块导入AdaSum类 from .adasum import AdaSum +# 从当前包的dim_reduce模块导入DimReduce类 from .dim_reduce import DimReduce +# 定义一个列表,包含所有要公开的模块成员 __all__ = ['AutoBoost', 'OptimizerProcess', 'ParameterProcess', 'BoostTrainOneStepCell', 'BoostTrainOneStepWithLossScaleCell', diff --git a/src/mindspore2022/mindspore/python/mindspore/boost/adasum.py b/src/mindspore2022/mindspore/python/mindspore/boost/adasum.py index 136bdd7e..f3768938 100644 --- a/src/mindspore2022/mindspore/python/mindspore/boost/adasum.py +++ b/src/mindspore2022/mindspore/python/mindspore/boost/adasum.py @@ -22,38 +22,41 @@ from mindspore.ops import composite as C from mindspore.ops import functional as F from mindspore.ops import operations as P from mindspore.ops.operations._inner_ops import Send, Receive - - + + __all__ = ["AdaSum"] - - + + MAX_NUM_HASH = 2 ** 31 - - + + _update_parameters = C.MultitypeFuncGraph("update_parameters") - - + + @_update_parameters.register("Tensor", "Tensor", "Tensor", "Tensor") def _update_parameters_after_broadcast(delta_weight, update_delta_weight, parameter, old_parameter): + """更新参数的函数,在广播后应用delta_weight来更新参数.""" shape = F.shape(delta_weight) update_delta_weight = P.Reshape()(update_delta_weight, shape) new_parameter = old_parameter - update_delta_weight return P.Assign()(parameter, new_parameter) - - + + def _send_before_receive(send_part, send, recv): + """在接收之前发送数据的辅助函数.""" send_ok = send(send_part) return recv(send_ok) - - + + def _receive_before_send(send_part, send, recv): + """在发送之前接收数据的辅助函数.""" receive_ok = recv(send_part) send_part = F.depend(send_part, receive_ok) return F.depend(receive_ok, send(send_part)) - - + + def _send_recv_res(left_send, 
recv_part, local_part, allreduce, parameter_divisibility, allreduce_node_num): - """send result and receive result.""" + """发送结果并接收结果的辅助函数.""" if parameter_divisibility: recv_part = P.Squeeze()(recv_part) local_part = F.depend(local_part, recv_part) @@ -76,14 +79,14 @@ def _send_recv_res(left_send, recv_part, local_part, allreduce, parameter_divisi res = allreduce(local_part) res /= allreduce_node_num return res - - + + _adasum_opt_forward = C.MultitypeFuncGraph("adasum_opt_forward") - - + + @_adasum_opt_forward.register("Bool", "Function", "Bool", "Int64", "Function", "Function", "Tensor") def _adasum_opt_forward_process(left_send, allreduce, parameter_divisibility, allreduce_node_num, send, recv, delta_w): - """adasum optimizer process.""" + """adaSum优化器的前向过程处理函数.""" if parameter_divisibility: delta_w = P.Squeeze()(delta_w) ori_len = F.shape(delta_w)[0] @@ -93,7 +96,7 @@ def _adasum_opt_forward_process(left_send, allreduce, parameter_divisibility, al else: left_part = delta_w right_part = delta_w - + if left_send: if parameter_divisibility: recv_part = _send_before_receive(left_part, send, recv) @@ -108,26 +111,26 @@ def _adasum_opt_forward_process(left_send, allreduce, parameter_divisibility, al recv_part = left_part update_delta_w = _send_recv_res(left_send, recv_part, left_part, allreduce, parameter_divisibility, allreduce_node_num) - + return update_delta_w - - + + _adasum_opt_rollback = C.MultitypeFuncGraph("adasum_opt_rollback") - - + + @_adasum_opt_rollback.register("Bool", "Bool", "Tensor", "Function", "Function") def _adasum_opt_rollback_process(left_send, parameter_divisibility, delta_w, send, recv): - """adasum optimizer rollback process.""" + """adaSum优化器的回滚处理函数.""" if parameter_divisibility: if left_send: recv_part = _send_before_receive(delta_w, send, recv) else: recv_part = _receive_before_send(delta_w, send, recv) - + recv_part = P.Squeeze()(recv_part) recv_part = P.Reshape()(recv_part, (-1,)) delta_w = P.Reshape()(delta_w, (-1,)) - + if left_send: res = P.Concat()((recv_part, delta_w)) else: @@ -135,28 +138,28 @@ def _adasum_opt_rollback_process(left_send, parameter_divisibility, delta_w, sen else: res = delta_w return res - - + + class AdaSum(Cell): r""" - The Adaptive Summation, or AdaSum, is a novel algorithm for improving distributed data - parallel training of Deep Learning models. - + 自适应加法(AdaSum)是一种新算法,用于改善深度学习模型的分布式数据并行训练。 + Args: - rank (int): Rank number. - device_number (int): Device number. - group_number (int): Group number. - parameter_tuple (Tuple(Parameter)): Tuple of parameters. - + rank (int): 排名编号。 + device_number (int): 设备数量。 + group_number (int): 组数量。 + parameter_tuple (Tuple(Parameter)): 参数元组。 + Inputs: - - **delta_weights** (Tuple(Tensor)) - Tuple of gradients. - - **parameters** (Tuple(Parameter)) - Tuple of current parameters. - - **old_parameters** (Tuple(Parameter)) - Tuple of last parameters. - + - **delta_weights** (Tuple(Tensor)) - 梯度的元组。 + - **parameters** (Tuple(Parameter)) - 当前参数的元组。 + - **old_parameters** (Tuple(Parameter)) - 上一参数的元组。 + Outputs: - - **adasum_parameters** (Tuple(Tensor)) - Tuple of parameters after adasum process. 
+ - **adasum_parameters** (Tuple(Tensor)) - 经过adasum处理后的参数元组。 """ def __init__(self, rank, device_number, group_number, parameter_tuple): + """AdaSum类的初始化函数.""" super(AdaSum, self).__init__() self.rank = rank self.device_number = device_number @@ -164,9 +167,9 @@ class AdaSum(Cell): self.parameter_tuple = parameter_tuple self._generate_communication_op() self.hyper_map = C.HyperMap() - + def _generate_communication_op(self): - """generate communication op.""" + """生成通信操作的私有方法.""" self.calc_times = int(math.log(self.group_number, 2)) self.send_node = [] self.send_list_forward = [] @@ -179,7 +182,7 @@ class AdaSum(Cell): self.allreduce_node_num_list = [] last_delta_weights = [] group_start_rank = (self.rank // self.device_number) * self.device_number - + for step in range(self.calc_times): current_group = self.device_number * (2 ** step) sr_target = self.rank @@ -189,7 +192,7 @@ class AdaSum(Cell): else: dest_target = sr_target - current_group self.send_node.append(False) - + neighbor_ids = [] group_name_last = 0 for index in range(2 ** (step + 1)): @@ -201,7 +204,7 @@ class AdaSum(Cell): group_name_last += neighbor_id group_name = "adasum_" + str(step) + "_" + str(group_name_last) create_group(group_name, neighbor_ids) - + send_left = [] send_right = [] recv_left = [] @@ -234,7 +237,7 @@ class AdaSum(Cell): send_right.append(send) recv_right.append(recv) weights_index += 1 - + if self.send_node and self.send_node[-1]: self.send_list_forward.append(send_left) self.send_list_rollback.append(send_right) @@ -247,27 +250,27 @@ class AdaSum(Cell): self.recv_list_forward.append(recv_left) self.recv_list_rollback.append(recv_right) last_delta_weights = left_delta_weights - + server_all_reduce = P.AllReduce("sum", group_name) server_all_reduce.add_prim_attr("fusion", fusion_id + 2) self.allreduce_list.append(server_all_reduce) - + for param_divisibility in delta_weights_divisibility: if param_divisibility: allreduce_node_num += (0,) else: allreduce_node_num += (2 ** (step + 1),) self.allreduce_node_num_list.append(allreduce_node_num) - + broadcast_group = [x for x in range(group_start_rank, group_start_rank + self.device_number)] broadcast_group_name = "broadcast_group_" + str(group_start_rank) create_group(broadcast_group_name, broadcast_group) for b_rank in range(len(broadcast_group)): self.broadcast_list.append(P.Broadcast(b_rank, group=broadcast_group_name)) self.sync_barrier = P.AllReduce("sum", group=broadcast_group_name) - + def _get_delta_weights_info(self, last_delta_weights): - """get delta weights info.""" + """获取delta权重信息的私有方法.""" half_delta_weights = [] if last_delta_weights: half_delta_weights = last_delta_weights @@ -292,14 +295,16 @@ class AdaSum(Cell): right_delta_weights.append((right_shape, dtype)) delta_weights_divisibility += (divisibility_flag,) return left_delta_weights, right_delta_weights, delta_weights_divisibility - + def _hash(self, step, target, weights_index): + """计算哈希值的私有方法.""" target = "tag" + str(step) + str(target) + str(weights_index) target_hash = hashlib.sha1(target.encode()).hexdigest() hash_res = int(int(target_hash, 16) % MAX_NUM_HASH) return hash_res - + def construct(self, delta_weights, parameters, old_parameters): + """构建方法,用于执行adaSum优化过程.""" forward_weights = [delta_weights] for i in range(self.calc_times): process_weights = self.hyper_map(F.partial(_adasum_opt_forward, self.send_node[i], self.allreduce_list[i]), @@ -314,4 +319,4 @@ class AdaSum(Cell): forward_weights[j] = process_weights adasum_parameters = self.hyper_map(F.partial(_update_parameters), 
delta_weights, forward_weights[0], parameters, old_parameters) - return adasum_parameters + return adasum_parameters \ No newline at end of file diff --git a/src/mindspore2022/mindspore/python/mindspore/dataset/engine/datasets_user_defined.py b/src/mindspore2022/mindspore/python/mindspore/dataset/engine/datasets_user_defined.py index a5cc35e5..900150b3 100644 --- a/src/mindspore2022/mindspore/python/mindspore/dataset/engine/datasets_user_defined.py +++ b/src/mindspore2022/mindspore/python/mindspore/dataset/engine/datasets_user_defined.py @@ -452,11 +452,14 @@ class _GeneratorWorkerMp(multiprocessing.Process): """ def __init__(self, dataset, eof, max_rowsize, queue_size, ppid): + # 初始化一个多进程队列,用于存储索引 self.idx_queue = multiprocessing.Queue(queue_size) + # 如果启用了共享内存,则初始化一个共享队列,否则初始化一个多进程队列 if get_enable_shared_mem(): self.res_queue = _SharedQueue(queue_size, max_rowsize=max_rowsize) else: self.res_queue = multiprocessing.Queue(queue_size) + # 设置队列的_joincancelled属性为True,表示在进程退出时,队列不会阻塞 self.idx_queue._joincancelled = True # pylint: disable=W0212 self.res_queue._joincancelled = True # pylint: disable=W0212 super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, True, ppid)) @@ -465,6 +468,7 @@ class _GeneratorWorkerMp(multiprocessing.Process): """ Put function for worker index queue. Never block. Raise queue.Full on failure. """ + # 将item放入idx_queue队列中,不阻塞,如果失败则抛出queue.Full异常 self.idx_queue.put_nowait(item) def get(self): @@ -476,12 +480,19 @@ class _GeneratorWorkerMp(multiprocessing.Process): return self.res_queue.get(timeout=30) def queue_empty(self): + # 检查idx_queue是否为空 if not self.idx_queue.empty(): + # 如果不为空,记录警告日志 logger.warning("idx_queue is not empty.") + # 返回False return False + # 检查res_queue是否为空 if not self.res_queue.empty(): + # 如果不为空,记录警告日志 logger.warning("res_queue is not empty.") + # 返回False return False + # 如果两个队列都为空,返回True return True def __del__(self): @@ -632,14 +643,17 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset): def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None, python_multiprocessing=True, max_rowsize=6): + # 调用父类的初始化方法 super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, shuffle=shuffle, num_shards=num_shards, shard_id=shard_id) + # 如果source是zip类型,则将其转换为列表 if isinstance(source, builtins.zip): # Although zip is iteratable, it does not have the feature of repeated iteration, so pass it to the array. 
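# Added illustration (not in the original source): a zip object is a one-shot iterator,
# e.g.
#     z = zip([1, 2], [3, 4])
#     list(z)   # [(1, 2), (3, 4)]
#     list(z)   # []  -- already exhausted
# so it is materialized into a list below to allow the dataset to be iterated repeatedly.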
self.source = [item for item in source] else: self.source = source self.prepared_source = None # source to be sent to C++ + # 如果self.operator_mixed属性为True,则将num_parallel_workers设置为1 if hasattr(self, 'operator_mixed') and getattr(self, 'operator_mixed') is True: self.num_parallel_workers = 1 logger.warning( @@ -650,56 +664,78 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset): self.python_multiprocessing = python_multiprocessing + # 将column_names转换为列表 self.column_names = to_list(column_names) + # 如果column_types不为空,则将其转换为detypelist类型 if column_types is not None: self.column_types = mstypelist_to_detypelist(column_types) else: self.column_types = [] self.schema = schema + # 如果schema不为空,则将其转换为Schema类型 if schema is not None: + # 如果schema不为空,则将其赋值给self.schema self.schema = schema + # 如果schema不是Schema类型,则将其转换为Schema类型 if not isinstance(schema, Schema): self.schema = Schema(schema) # Move get dataset_size by len from parse to here, because self.source will # lose attribution of '__len__' after deepcopy. self.source_len = -1 # unknown + # 如果self.source有__len__属性,则获取self.source的长度 if hasattr(self.source, "__len__"): self.source_len = len(self.source) + # 设置最大行大小 self.max_rowsize = max_rowsize + # 设置采样函数为None self.sample_fn = None def __deepcopy__(self, memodict): + # 深度复制当前对象,并传入一个字典memodict,用于存储已经复制的对象 if id(self) in memodict: + # 如果当前对象的id已经在memodict中,则直接返回该对象 return memodict[id(self)] + # 否则,调用__safe_deepcopy__方法进行深度复制,并传入memodict和exclude参数 new_op = self.__safe_deepcopy__(memodict, exclude=("source", "__transfer_dataset__")) sample_fn = None + # 如果新对象的sampler属性不为空,并且self.source对象具有__getitem__方法 if new_op.sampler is not None and hasattr(self.source, "__getitem__"): # The reason why there is a try catch here is because when the new op is being constructed with shared # memory enabled, there will be an exception thrown if there is not enough shared memory available + # 如果self.source_len为-1,则抛出RuntimeError异常,因为尝试构造一个随机访问的数据集,需要__len__方法 if self.source_len == -1: raise RuntimeError("Attempt to construct a random access dataset, '__len__' method is required!") try: + # 如果新对象的num_parallel_workers大于1,则调用__validate_memory_usage方法进行内存使用验证 if new_op.num_parallel_workers > 1: self.__validate_memory_usage() + # 创建一个SamplerFn对象,用于并行采样 sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing, self.max_rowsize) + # 将新对象的prepared_source属性设置为_cpp_sampler_fn_mp函数,用于并行采样 new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn)) else: + # 否则,将新对象的prepared_source属性设置为_cpp_sampler_fn函数,用于单线程采样 new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source)) + # 将新对象的sample_fn属性设置为sample_fn new_op.sample_fn = sample_fn except RuntimeError as e: + # 如果抛出RuntimeError异常,则抛出Exception异常,并传入异常信息 raise Exception(str(e)) else: try: + # 否则,将新对象的sampler属性设置为None,sample_fn属性设置为sample_fn new_op.sampler = None new_op.sample_fn = sample_fn + # 将新对象的source_len属性设置为min(new_op.source_len, new_op.num_samples),如果new_op.num_samples不为0,否则设置为new_op.source_len new_op.source_len = min(new_op.source_len, new_op.num_samples) if new_op.num_samples != 0 else new_op.source_len + # 遍历self.source对象 iter(self.source) except TypeError: # Use generator function if input callable @@ -711,19 +747,26 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset): return new_op + # 判断是否被洗牌 def is_shuffled(self): return self.sampler.is_shuffled() + # 判断是否被分片 def is_sharded(self): return self.sampler.is_sharded() + # 解析 def parse(self, children=None): + # 
如果schema为空,则返回GeneratorNode对象 if self.schema is None: return cde.GeneratorNode(self.prepared_source, self.column_names, self.column_types, self.source_len, self.sampler, self.num_parallel_workers) + # 获取schema schema = self.schema + # 如果schema是Schema类型,则获取cpp_schema if isinstance(schema, Schema): schema = self.schema.cpp_schema + # 返回GeneratorNode对象 return cde.GeneratorNode(self.prepared_source, schema, self.source_len, self.sampler, self.num_parallel_workers) @@ -735,24 +778,37 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset): # if use num_parallel_workers is to large when python_multiprocessing=True which would cause # OOM error get the num_shards valid_num_shards = 1 + # 判断self.sampler是否为samplers.DistributedSampler类型 if isinstance(self.sampler, samplers.DistributedSampler): + # 如果是,则将self.sampler的num_shards赋值给valid_num_shards valid_num_shards = self.sampler.num_shards + # 否则,判断self.num_shards是否为None elif self.num_shards is not None: + # 如果不是,则将self.num_shards赋值给valid_num_shards valid_num_shards = self.num_shards # get process memory usage + # 获取当前进程 process = psutil.Process(os.getpid()) + # 获取当前进程的内存信息 process_memory = process.memory_info().rss + # 获取系统内存的空闲量 sys_memory_free = psutil.virtual_memory().free + # 计算可能使用的总内存量 total_memory_maybe_used = process_memory * self.num_parallel_workers * valid_num_shards + # 如果总内存可能使用的内存量除以系统可用内存大于0.85 if total_memory_maybe_used / sys_memory_free > 0.85: + # 计算有效的worker数量,即系统可用内存乘以0.85除以有效的shards数量再除以每个进程的内存 valid_num_worker = math.floor(sys_memory_free * 0.85 / valid_num_shards / process_memory) + # 如果有效的worker数量小于等于0,则将其设置为1 valid_num_worker = 1 if valid_num_worker <= 0 else valid_num_worker + # 构造警告信息,提示用户num_parallel_workers设置过大,可能会导致内存占用过高或OOM,建议将其减小到valid_num_worker或更小 info = "GeneratorDataset's num_parallel_workers: {} is too large which may cause a lot of memory " \ "occupation (>85%) or out of memory(OOM) during multiprocessing. Therefore, it is recommended " \ "to reduce num_parallel_workers to {} or smaller.".format(self.num_parallel_workers, valid_num_worker) + # 打印警告信息 logger.warning(info) @@ -764,37 +820,55 @@ class _NumpySlicesDataset: def __init__(self, data, column_list=None): self.column_list = None # Convert dict data into tuple + # 判断data是否为字典类型 if isinstance(data, dict): + # 如果是字典类型,则调用process_dict方法处理 data = self.process_dict(data) + # 判断data是否为元组类型 if isinstance(data, tuple): + # 如果是元组类型,则将self.data初始化为空元组 self.data = () + # 获取data的长度 data_len = len(data) + # 遍历data中的每个元素 for i in range(data_len): + # 将data中的每个元素转换为numpy数组,并添加到self.data中 self.data = self.data + (np.array(data[i]),) else: + # 如果data不是元组类型,则将data转换为numpy数组,并添加到self.data中 self.data = (np.array(data),) # check whether the data length in each column is equal + # 获取每个data_item的长度 data_len = [len(data_item) for data_item in self.data] + # 如果每个data_item的长度不相等,则抛出ValueError异常 if data_len[1:] != data_len[:-1]: raise ValueError("Data length in each column is not equal.") # Init column_name + # 如果column_list不为空,则将self.column_list赋值为column_list if column_list is not None: self.column_list = column_list + # 如果self.column_list为空,则将self.column_list赋值为空列表 elif self.column_list is None: self.column_list = [] + # 获取data的列数 column_num = len(self.data) + # 遍历列数,将"column_" + str(i)添加到self.column_list中 for i in range(column_num): self.column_list.append("column_" + str(i)) def __getitem__(self, index): + # 获取指定索引的数据行 data_row = [d[index, ...] 
for d in self.data] + # 将数据行转换为元组 data_res = tuple(data_row) + # 返回数据行 return data_res def __len__(self): + # 返回data的第一个元素的长度 return len(self.data[0]) def process_dict(self, input_data): @@ -802,24 +876,29 @@ class _NumpySlicesDataset: Convert the dict like data into tuple format, when input is a tuple of dicts then compose it into a dict first. """ # Convert pandas like dict(has "values" column) into General dict + # 将pandas样式的字典(有"values"列)转换为通用字典 data_keys = list(input_data.keys()) + # 获取字典的第一个键对应的值 data_col = input_data[data_keys[0]] + # 如果值有values属性,则将其转换为通用字典 if hasattr(data_col, "values"): new_dict = {} for key in data_keys: + # 将字典中的键对应的值转换为列表 item1 = input_data.pop(key) new_dict[key] = item1.values + # 将转换后的字典赋值给input_data input_data = new_dict # Convert the data in dict into tuple - data = () - keys = list(input_data.keys()) - self.column_list = keys - for key in keys: - value = input_data[key] - data = data + (list(value),) + data = () # 初始化一个空元组 + keys = list(input_data.keys()) # 将输入数据的键转换为列表 + self.column_list = keys # 将键列表赋值给实例变量column_list + for key in keys: # 遍历键列表 + value = input_data[key] # 获取键对应的值 + data = data + (list(value),) # 将值转换为列表,并添加到元组中 - return data + return data # 返回元组 class NumpySlicesDataset(GeneratorDataset): @@ -909,7 +988,9 @@ class NumpySlicesDataset(GeneratorDataset): @check_numpyslicesdataset def __init__(self, data, column_names=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None): + # 创建一个_NumpySlicesDataset对象,传入data和column_names参数 dataset = _NumpySlicesDataset(data, column_names) + # 调用父类的__init__方法,传入dataset、column_names、num_samples、num_parallel_workers、shuffle、sampler、num_shards和shard_id参数 super().__init__(dataset, column_names=dataset.column_list, num_samples=num_samples, num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler, num_shards=num_shards, shard_id=shard_id) diff --git a/src/mindspore2022/mindspore/python/mindspore/ops/operations/_embedding_cache_ops.py b/src/mindspore2022/mindspore/python/mindspore/ops/operations/_embedding_cache_ops.py index a6d1f45d..a043b90c 100644 --- a/src/mindspore2022/mindspore/python/mindspore/ops/operations/_embedding_cache_ops.py +++ b/src/mindspore2022/mindspore/python/mindspore/ops/operations/_embedding_cache_ops.py @@ -17,65 +17,71 @@ from ..._checkparam import Validator as validator from ...common import dtype as mstype from ..primitive import prim_attr_register, PrimitiveWithCheck from .. import signature as sig - - + + class UpdateCache(PrimitiveWithCheck): """ - Update the value fo input_x, similar to ScatterNdUpdate. - The difference is that UpdateCache will not update when indices < 0 or indices >= max_num. - + 更新 input_x 的值,类似于 ScatterNdUpdate。 + 不同之处在于,UpdateCache 当 indices < 0 或 indices >= max_num 时不会更新。 + Inputs: - - **input_x** (Parameter) - Parameter which is going to be updated. - - **indices** (Tensor) - Update indices of input_x. - - **updates** (Tensor) - The update values. - + - **input_x** (Parameter) - 将要更新的参数。 + - **indices** (Tensor) - input_x 的更新索引。 + - **updates** (Tensor) - 更新值。 + Outputs: - - **out** (Tensor) - Returns a [1] Tensor, which is not useful. 
+ - **out** (Tensor) - 返回一个 [1] 的张量,这个张量没有用处。 """ + # 定义函数签名,指定输入参数的类型和读写权限 __mindspore_signature__ = ( + # 定义输入参数input_x,类型为T,读写权限为写 sig.make_sig('input_x', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + # 定义输入参数indices,类型为T1 sig.make_sig('indices', dtype=sig.sig_dtype.T1), + # 定义输入参数updates,类型为T sig.make_sig('updates', dtype=sig.sig_dtype.T), + # 定义输入参数max_num,类型为T1 sig.make_sig('max_num', dtype=sig.sig_dtype.T1) ) - + @prim_attr_register def __init__(self): - """init UpdateCache""" - + """初始化 UpdateCache""" + + # 初始化输入和输出名称 self.init_prim_io_names(inputs=['input_x', 'indices', 'update', 'max_num'], outputs=['out']) - + def check_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape): + # 检查输入形状 return [1] - + def check_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype): + # 检查输入数据类型 validator.check_tensor_dtype_valid( "indices", indices_dtype, mstype.int_type, self.name) return input_x_dtype - - + + class SubAndFilter(PrimitiveWithCheck): """ - Dynamic kernel, sub an offset and - return the elements which in range [0, max_num). - + 动态内核,减去一个偏移量并返回在范围 [0, max_num) 内的元素。 + Inputs: - - **input_x** (Tensor) - Input tensor. - - **max_num** (Int) - The max value of element that after sub `offset`. - - **offset** (int) - Specifies the offset value of this `input_x`. - + - **input_x** (Tensor) - 输入张量。 + - **max_num** (Int) - 减去 `offset` 后元素的最大值。 + - **offset** (int) - 指定此 `input_x` 的偏移值。 + Outputs: - tuple(Tensor), tuple of 2 tensors, filter_res and filter_idx. - - **filter_res** (Tensor) - The result that `input_x` minus `offset`, - and return which in the range [0, max_num). - - **filter_idx** (Tensor) - A tensor containing indices of elements in the input - coressponding to the output tensor. - + tuple(Tensor), 由 2 个张量组成的元组,filter_res 和 filter_idx。 + - **filter_res** (Tensor) - `input_x` 减去 `offset` 的结果, + 并返回在范围 [0, max_num) 内的值。 + - **filter_idx** (Tensor) - 一个张量,包含与输出张量对应的输入元素的索引。 + Supported Platforms: `CPU` - + Examples: >>> x = Tensor(np.array([1, 3, 5, 8, 9, 16]), mindspore.int32) >>> max_num = 10 @@ -87,35 +93,38 @@ class SubAndFilter(PrimitiveWithCheck): """ @prim_attr_register def __init__(self): - """init SubAndFilter""" - + """初始化 SubAndFilter""" + + # 初始化输入和输出名称 self.init_prim_io_names(inputs=['input_x', 'max_num', 'offset'], outputs=['sub_res', 'sub_idx']) - + def check_shape(self, input_x_shape, max_num_shape, offset_shape): + # 检查输入形状 return ((-1,), (-1,)) - + def check_dtype(self, input_x_dtype, max_num_dtype, offset_dtype): + # 检查输入数据类型 validator.check_tensor_dtype_valid( "input_x", input_x_dtype, mstype.int_type, self.name) return input_x_dtype - - + + class MapUniform(PrimitiveWithCheck): """ - Map a tensor by using fomula : value = key % `group_num` * `per_group_size` + key // `group_num`. - + 通过公式映射一个张量:value = key % `group_num` * `per_group_size` + key // `group_num`。 + Inputs: - - **input** (Tensor) - Input Tensor. - - **per_group_size** (int) - The size of each group. - - **group_num** (int) - The number of group. - + - **input** (Tensor) - 输入张量。 + - **per_group_size** (int) - 每个组的大小。 + - **group_num** (int) - 组的数量。 + Outputs: - Tensor, has the same dtype and shape as the `input`. 
- + Tensor,具有与 `input` 相同的 dtype 和形状。 + Supported Platforms: `CPU` - + Examples: >>> input_x = Tensor(np.array([0, 1, 2, 3, 4, 5, 6, 7])) >>> per_group_size = 4 @@ -125,33 +134,34 @@ class MapUniform(PrimitiveWithCheck): >>> print(output) [0, 4, 1, 5, 2, 6, 3, 7] """ - + @prim_attr_register def __init__(self): - """init MapUniform""" + """初始化 MapUniform""" self.init_prim_io_names(inputs=['input', 'per_group_size', 'group_num'], outputs=['output']) - + def check_dtype(self, input_dtype, per_group_size_dtype, group_num_dtype): + """检查输入数据类型""" validator.check_tensor_dtype_valid( "input", input_dtype, mstype.int_type, self.name) validator.check_value_type( 'per_group_size', per_group_size_dtype, [mstype.Int], self.name) validator.check_value_type( 'group_num', group_num_dtype, [mstype.Int], self.name) - - + + class CacheSwapTable(PrimitiveWithCheck): """ - Delete a hashmap entry,and insert a new key to hashmap, return the key and value of delete entry. - + 删除一个哈希映射条目,并插入一个新键到哈希映射中,返回删除条目的键和值。 + Inputs: - - **cache_table** (Parameter) - The cache table which is on device. - - **swap_cache_idx** (Tensor) - The index of table which need to swap. -1 is skipped. - - **miss_value** (int) - The values which arg going to swap into cache table. - + - **cache_table** (Parameter) - 在设备上的缓存表。 + - **swap_cache_idx** (Tensor) - 需要交换的表索引,-1 被跳过。 + - **miss_value** (int) - 将要交换到缓存表的值。 + Outputs: - - **old_value** (Tensor) - The values which are swapped out. + - **old_value** (Tensor) - 被交换出去的值。 """ __mindspore_signature__ = ( sig.make_sig('cache_table', sig.sig_rw.RW_WRITE, @@ -159,31 +169,35 @@ class CacheSwapTable(PrimitiveWithCheck): sig.make_sig('swap_cache_idx', dtype=sig.sig_dtype.T1), sig.make_sig('miss_value', dtype=sig.sig_dtype.T) ) - + @prim_attr_register def __init__(self): - """init CacheSwapTable""" - + """初始化 CacheSwapTable""" + self.init_prim_io_names(inputs=['cache_table', 'swap_cache_idx', 'miss_value'], outputs=['old_value']) - + def check_shape(self, cache_table_shape, swap_cache_idx_shape, miss_value_shape): + # 检查cache_table_shape的长度是否为2,如果不是,则抛出ValueError异常 if len(cache_table_shape) != 2: raise ValueError( "cache table shape must be 2, but got %d" % len(cache_table_shape)) - + + # 返回miss_value_shape return miss_value_shape - + def check_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype): + # 检查swap_cache_idx_dtype是否为mstype.int_type,如果不是,则抛出ValueError异常 validator.check_tensor_dtype_valid( "swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name) + # 返回miss_value_dtype return miss_value_dtype - - + + class MapCacheIdx(PrimitiveWithCheck): """ - MapCacheIdx merge SearchCacheIdx, CacheSwapHashmap, UpdateCache together. - When input an indices tensor, it will output the cache indices which search in hashmap. 
+ MapCacheIdx 将 SearchCacheIdx、CacheSwapHashmap 和 UpdateCache 合并在一起。 + 当输入一个索引张量时,它将输出在哈希映射中搜索的缓存索引。 """ __mindspore_signature__ = ( sig.make_sig('hashmap', sig.sig_rw.RW_WRITE, @@ -193,56 +207,69 @@ class MapCacheIdx(PrimitiveWithCheck): sig.make_sig('emb_max_num', dtype=sig.sig_dtype.T), sig.make_sig('cache_max_num', dtype=sig.sig_dtype.T) ) - + @prim_attr_register def __init__(self): - """init MapCacheIdx""" - + """初始化 MapCacheIdx""" + self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'offset'], outputs=['cache_idx', 'old_emb_idx', 'miss_emb_idx', 'swap_cache_idx']) - + def __check__(self, hashmap, indices, step, emb_max_num, offset): + # 获取hashmap的形状 hashmap_shape = hashmap['shape'] + # 如果hashmap的维度不是2,则抛出异常 if len(hashmap_shape) != 2: raise ValueError("The dimension of 'hashmap' in SearchCacheIdx must be 2, " "but got %d." % len(hashmap_shape)) + # 设置输出的形状 out_shape = (indices['shape'], -1, -1, -1) - + + # 获取hashmap和indices的数据类型 hashmap_dtype = hashmap['dtype'] indices_dtype = indices['dtype'] + # 将数据类型存入字典 args = {"hashmap": hashmap_dtype, "indices": indices_dtype} + # 检查数据类型是否相同且有效 validator.check_tensors_dtypes_same_and_valid( args, mstype.int_type, self.name) + # 设置输出的数据类型 out_dtype = (hashmap_dtype, hashmap_dtype, hashmap_dtype, hashmap_dtype) - + + # 设置输出的字典 out = {'shape': out_shape, 'dtype': out_dtype, 'value': None} + # 如果indices中有max_shape,则设置输出的max_shape if 'max_shape' in indices: out['max_shape'] = (indices['max_shape'], indices['max_shape'], indices['max_shape'], indices['max_shape']) + # 否则,设置输出的max_shape为indices的形状 else: out['max_shape'] = (indices['shape'], indices['shape'], indices['shape'], indices['shape']) + # 如果indices中有min_shape,则设置输出的min_shape if 'min_shape' in indices: out['min_shape'] = (indices['min_shape'], 0, 0, 0) + # 否则,设置输出的min_shape为(0, 0, 0, 0) else: out['min_shape'] = (0, 0, 0, 0) + # 返回输出的字典 return out - - + + class DynamicAssign(PrimitiveWithCheck): """ - Assigns `Parameter` with a value, the `value` can have a dynamic shape. - + 将 `Parameter` 与值分配,`value` 可以具有动态形状。 + Inputs: - - **variable** (Parameter) - The `Parameter`. - - **value** (Tensor) - The value to be assigned. - + - **variable** (Parameter) - `Parameter`。 + - **value** (Tensor) - 要分配的值。 + Outputs: - Tensor, has the same type as original `variable`. - + Tensor,具有与原始 `variable` 相同的类型。 + Supported Platforms: `CPU` """ @@ -250,41 +277,42 @@ class DynamicAssign(PrimitiveWithCheck): sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), sig.make_sig('value', dtype=sig.sig_dtype.T) ) - + @prim_attr_register def __init__(self): self.init_prim_io_names(inputs=['ref', 'value'], outputs=['output']) - + def check_dtype(self, variable, value): + # 检查变量是否为mstype.type_refkey if variable != mstype.type_refkey: + # 检查变量是否为mstype.number_type类型 validator.check_tensor_dtype_valid( "variable", variable, mstype.number_type, self.name) + # 检查value是否为mstype.number_type类型 validator.check_scalar_or_tensor_types_same( {"value": value}, mstype.number_type, self.name) - - + + class PadAndShift(PrimitiveWithCheck): """ - Pad a tensor with -1, and shift with a length. - + 用 -1 填充张量,并按长度进行移位。 + Inputs: - - **input_x** (Tensor) - The input Tensor, which will be copied - to `output`. - - **cum_sum_arr** (Tensor) - The last value of cum_sum_arr is - the pad length of output tensor, cum_sum_arr[shift_idx] is - the start to shift, and cum_sum_arr[shift_idx+1] is the end. - - **shift_idx** (Int) - The idx of cum_sum_arr. 
- if use python, PadAndShift is: + - **input_x** (Tensor) - 输入张量,将被复制到 `output`。 + - **cum_sum_arr** (Tensor) - cum_sum_arr 的最后一个值是输出张量的填充长度, + cum_sum_arr[shift_idx] 是开始移位,cum_sum_arr[shift_idx+1] 是结束。 + - **shift_idx** (Int) - cum_sum_arr 的索引。 + 如果使用 Python,PadAndShift 为: output = [-1] * cum_sum_arr[-1] start = cum_sum_arr[shift_idx] end = cum_sum_arr[shift_idx + 1] output[start:end] = input_x[:(end-start)] Outputs: - Tensor, has the same type as original `variable`. - + Tensor,具有与原始 `variable` 相同的类型。 + Supported Platforms: `CPU` - + Examples: >>> input_x = Tensor(np.array([9, 13, -1, -1, -1, -1, -1, -1]), mstype.int32) >>> cum_sum_arr = Tensor(np.array([0, 3, 5]), mstype.int32) @@ -296,11 +324,14 @@ class PadAndShift(PrimitiveWithCheck): """ @prim_attr_register def __init__(self): + # 初始化输入输出名称 self.init_prim_io_names( inputs=['input_x', 'cum_sum_arr', 'shift_idx'], outputs=['output']) - + def check_shape(self, input_x_shape, cum_sum_arr_shape, shift_idx_shape): + # 检查输入形状 return input_x_shape - + def check_dtype(self, input_x_dtype, cum_sum_arr_dtype, shift_idx_dtype): - return input_x_dtype + # 检查输入数据类型 + return input_x_dtype \ No newline at end of file diff --git a/src/mindspore2022/mindspore/python/mindspore/ops/operations/_tensor_array.py b/src/mindspore2022/mindspore/python/mindspore/ops/operations/_tensor_array.py index 989e547e..013c35b9 100644 --- a/src/mindspore2022/mindspore/python/mindspore/ops/operations/_tensor_array.py +++ b/src/mindspore2022/mindspore/python/mindspore/ops/operations/_tensor_array.py @@ -12,39 +12,39 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ - + """Operators for TensorArray.""" - + import mindspore as ms from ..._checkparam import Validator as validator from ..._checkparam import Rel from ...common import dtype as mstype from ..primitive import prim_attr_register, PrimitiveWithInfer, Primitive - - + + class TensorArray(PrimitiveWithInfer): r""" TensorArrayCreate used to create a TensorArray and return an unique handle. - + .. warning:: This is an experimental prototype that is subject to change and/or deletion. - + Args: dtype (mindspore.dtype): the data type in the TensorArray. element_shape (tuple[int]): the shape of each tensor in a TensorArray. dynamic_size (bool): If true the TensorArray can increase the size. Default: True. size (int): The size of the TensorArray if dynamic_size = False. name (string): the name of this TensorArray. Default: "TA". - + Inputs: None. - + Outputs: - **output** (Tensor[mindspore.int64]) - an unique handle binded to the TensorArray. 
- + Supported Platforms: ``GPU`` ``CPU`` - + Examples: >>> import mindspore >>> import mindspore.ops as ops @@ -55,6 +55,7 @@ class TensorArray(PrimitiveWithInfer): """ @prim_attr_register def __init__(self, dtype, element_shape, dynamic_size=True, size=0, name="TA"): + """初始化TensorArray类,设置参数和属性.""" validator.check_type_name("dtype", dtype, mstype.number_type + (mstype.bool_,), self.name) validator.check_int(size, 0, Rel.GE, "size", self.name) self.add_prim_attr('dtype', dtype) @@ -63,32 +64,34 @@ class TensorArray(PrimitiveWithInfer): self.add_prim_attr('size', size) self.add_prim_attr('side_effect_mem', True) self.add_prim_attr('name', name) - + def infer_shape(self): + """推断输出形状.""" return () - + def infer_dtype(self): + """推断输出数据类型.""" return mstype.int64 - - + + class TensorArrayWrite(PrimitiveWithInfer): r""" TensorArrayWrite used to write tensor into a created TensorArray. - + .. warning:: This is an experimental prototype that is subject to change and/or deletion. - + Inputs: - **index** (Tensor[int64]) - The position to write. - **value** (Tensor) - The value to add into the TensorArray. - **handle** (Tensor[int64]) - The handle pointed to the TensorArray. - + Outputs: None. - + Supported Platforms: ``GPU`` ``CPU`` - + Examples: >>> import mindspore >>> import mindspore.ops as ops @@ -99,39 +102,42 @@ class TensorArrayWrite(PrimitiveWithInfer): """ @prim_attr_register def __init__(self): + """初始化TensorArrayWrite类.""" self.add_prim_attr('side_effect_mem', True) - + def infer_shape(self, handle_shape, index_shape, value_shape): + """推断输出形状.""" return () - + def infer_dtype(self, handle_type, index_type, value_type): + """推断输出数据类型.""" validator.check_type_name("handle", handle_type, (ms.int64), self.name) validator.check_type_name("index", index_type, (int, ms.int64), self.name) validator.check_type_name("value", value_type, mstype.number_type + (mstype.bool_,), self.name) return mstype.int64 - - + + class TensorArrayRead(PrimitiveWithInfer): r""" TensorArrayRead used to read tensor from a created TensorArray by the given index. - + .. warning:: This is an experimental prototype that is subject to change and/or deletion. - + Args: dtype (mindspore.dtype): the data type in the TensorArray. element_shape (tuple[int]): the shape of each tensor in a TensorArray. - + Inputs: - **index** (Tensor[int64]) - The position to read. - **handle** (mindspore.int64) - The handle pointed to the TensorArray. - + Outputs: - **output** (Tensor) - the value in position index. - + Supported Platforms: ``GPU`` ``CPU`` - + Examples: >>> import mindspore >>> import mindspore.ops as ops @@ -146,38 +152,41 @@ class TensorArrayRead(PrimitiveWithInfer): """ @prim_attr_register def __init__(self, dtype, element_shape): + """初始化TensorArrayRead类,设置参数和属性.""" validator.check_type_name("dtype", dtype, mstype.number_type + (mstype.bool_,), self.name) self.add_prim_attr('dtype', dtype) self.add_prim_attr('element_shape', element_shape) self.add_prim_attr('side_effect_mem', True) self.dtype = dtype self.shape = element_shape - + def infer_shape(self, handle_shape, index_shape): + """推断输出形状.""" return self.shape - + def infer_dtype(self, handle_type, index_type): + """推断输出数据类型.""" validator.check_type_name("handle", handle_type, (ms.int64), self.name) validator.check_type_name("index", index_type, (int, ms.int64), self.name) return self.dtype - - + + class TensorArrayClose(PrimitiveWithInfer): r""" TensorArrayClose used to close the created TensorArray. The resources in TensorArray will be deleted. - + .. 
warning:: This is an experimental prototype that is subject to change and/or deletion. - + Inputs: - **handle** (mindspore.int64) - The handle pointed to the TensorArray. - + Outputs: None. - + Supported Platforms: ``GPU`` ``CPU`` - + Examples: >>> import mindspore >>> import mindspore.ops as ops @@ -188,32 +197,35 @@ class TensorArrayClose(PrimitiveWithInfer): """ @prim_attr_register def __init__(self): + """初始化TensorArrayClose类.""" self.add_prim_attr('side_effect_mem', True) - + def infer_shape(self, handle_shape): + """推断输出形状.""" return () - + def infer_dtype(self, handle_type): + """推断输出数据类型.""" validator.check_type_name("handle", handle_type, (ms.int64), self.name) return mstype.int64 - - + + class TensorArrayClear(PrimitiveWithInfer): r""" TensorArrayClear used to reset the created TensorArray. The instance of TensorArray is still aviliable. - + .. warning:: This is an experimental prototype that is subject to change and/or deletion. - + Inputs: - **handle** (mindspore.int64) - The handle pointed to the TensorArray. - + Outputs: None. - + Supported Platforms: ``GPU`` ``CPU`` - + Examples: >>> import mindspore >>> import mindspore.ops as ops @@ -224,36 +236,39 @@ class TensorArrayClear(PrimitiveWithInfer): """ @prim_attr_register def __init__(self): + """初始化TensorArrayClear类.""" self.add_prim_attr('side_effect_mem', True) - + def infer_shape(self, handle_shape): + """推断输出形状.""" return () - + def infer_dtype(self, handle_type): + """推断输出数据类型.""" validator.check_type_name("handle", handle_type, (ms.int64), self.name) return mstype.int64 - - + + class TensorArrayStack(Primitive): r""" TensorArrayStack used to stack the tensors in a created TensorArray into one tensor. - + .. warning:: This is an experimental prototype that is subject to change and/or deletion. - + Args: dtype (mindspore.dtype): the data type in the TensorArray. element_shape (tuple[int]): the shape of each tensor in a TensorArray. - + Inputs: - **handle** (mindspore.int64) - The handle pointed to the TensorArray. - + Outputs: - **output** (Tensor) - the stacked value from the TensorArray. - + Supported Platforms: ``GPU`` ``CPU`` - + Examples: >>> import mindspore >>> import mindspore.ops as ops @@ -269,31 +284,31 @@ class TensorArrayStack(Primitive): """ @prim_attr_register def __init__(self, dtype, element_shape, dynamic_size, size): - """Initialize TensorArrayStack""" + """初始化TensorArrayStack类,设置参数和属性.""" self.init_prim_io_names(inputs=[''], outputs=['output']) self.add_prim_attr('dtype', dtype) self.add_prim_attr('element_shape', element_shape) self.add_prim_attr('is_dynamic_shape', dynamic_size) self.add_prim_attr('size', size) self.add_prim_attr('side_effect_mem', True) - - + + class TensorArraySize(PrimitiveWithInfer): r""" TensorArraySize used to get the logical size of the created TensorArray. - + .. warning:: This is an experimental prototype that is subject to change and/or deletion. - + Inputs: - **handle** (mindspore.int64) - The handle pointed to the TensorArray. - + Outputs: - **output** (Tensor[mindspore.int64]) - the logical size of the TensorArray. 
- + Supported Platforms: ``GPU`` ``CPU`` - + Examples: >>> import mindspore >>> import mindspore.ops as ops @@ -304,34 +319,37 @@ class TensorArraySize(PrimitiveWithInfer): """ @prim_attr_register def __init__(self): + """初始化TensorArraySize类.""" self.add_prim_attr('side_effect_mem', True) - + def infer_shape(self, handle_shape): + """推断输出形状.""" return () - + def infer_dtype(self, handle_type): + """推断输出数据类型.""" validator.check_type_name("handle", handle_type, (ms.int64), self.name) return mstype.int64 - - + + class TensorArrayGather(PrimitiveWithInfer): r""" TensorArrayGather used to gather specified elements from the created TensorArray. - + .. warning:: This is an experimental prototype that is subject to change and/or deletion. - + Args: dtype (mindspore.dtype): the data type in the TensorArray. element_shape (tuple[int]): the shape of each tensor in a TensorArray. - + Inputs: - **handle** (mindspore.int64) - The handle pointed to the TensorArray. - **indices** (mindspore.int32) - The locations of the gathered elements. - + Outputs: - **output** (Tensor) - The gathered value from the TensorArray. - + Examples: >>> import mindspore >>> import mindspore.ops as ops @@ -344,17 +362,20 @@ class TensorArrayGather(PrimitiveWithInfer): """ @prim_attr_register def __init__(self, dtype, element_shape): + """初始化TensorArrayGather类,设置参数和属性.""" self.init_prim_io_names(inputs=['handle', 'indices'], outputs=['value']) self.add_prim_attr("side_effect_mem", True) self.dtype = dtype self.element_shape = element_shape - + def infer_shape(self, handle, indices): + """推断输出形状.""" if len(indices) != 1: return ValueError("indices dimension should be equal to 1") return [indices[0]] + list(self.element_shape) - + def infer_dtype(self, handle, indices): + """推断输出数据类型.""" validator.check_type_name("handle", handle, (ms.int64), self.name) validator.check_type_name("indices", indices, (ms.int32), self.name) - return self.dtype + return self.dtype \ No newline at end of file diff --git a/src/mindspore2022/mindspore/python/mindspore/ops/operations/nn_ops.py b/src/mindspore2022/mindspore/python/mindspore/ops/operations/nn_ops.py index 4db53e09..c0c643d4 100644 --- a/src/mindspore2022/mindspore/python/mindspore/ops/operations/nn_ops.py +++ b/src/mindspore2022/mindspore/python/mindspore/ops/operations/nn_ops.py @@ -290,48 +290,34 @@ class AdaptiveAvgPool2D(PrimitiveWithInfer): class Softmax(Primitive): - r""" - Softmax operation. - - Applies the Softmax operation to the input tensor on the specified axis. - Suppose a slice in the given axis :math:`x`, then for each element :math:`x_i`, - the Softmax function is shown as follows: - - .. math:: - \text{output}(x_i) = \frac{exp(x_i)}{\sum_{j = 0}^{N-1}\exp(x_j)}, + """ + Softmax操作。 - where :math:`N` is the length of the tensor. + 将Softmax操作应用于输入张量的指定轴。 + 假设在给定轴上的切片为x,则对于每个元素x_i,Softmax函数如下所示: Args: - axis (Union[int, tuple]): The axis to perform the Softmax operation. Default: -1. + axis (Union[int, tuple]): 执行Softmax操作的轴。默认为-1。 Inputs: - - **logits** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of - additional dimensions, with float16 or float32 data type. + - **logits** (Tensor): 形状为(N, *)的张量,其中*表示任意数量的额外维度,数据类型为float16或float32。 Outputs: - Tensor, with the same type and shape as the logits. + Tensor,与logits具有相同的类型和形状。 Raises: - TypeError: If `axis` is neither an int nor a tuple. - TypeError: If dtype of `logits` is neither float16 nor float32. - ValueError: If `axis` is a tuple whose length is less than 1. 
- ValueError: If `axis` is a tuple whose elements are not all in range [-len(logits.shape), len(logits.shape)). + TypeError: 如果`axis`既不是int也不是tuple。 + TypeError: 如果`logits`的数据类型既不是float16也不是float32。 + ValueError: 如果`axis`是一个长度小于1的元组。 + ValueError: 如果`axis`是一个元组,但其元素不在[-len(logits.shape), len(logits.shape))范围内。 Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - - Examples: - >>> logits = Tensor(np.array([1, 2, 3, 4, 5]), mindspore.float32) - >>> softmax = ops.Softmax() - >>> output = softmax(logits) - >>> print(output) - [0.01165623 0.03168492 0.08612854 0.23412167 0.6364086 ] """ @prim_attr_register def __init__(self, axis=-1): - """Initialize Softmax.""" + """初始化Softmax。""" self.init_prim_io_names(inputs=['x'], outputs=['output']) validator.check_value_type("axis", axis, [int, tuple], self.name) if isinstance(axis, int): @@ -341,125 +327,86 @@ class Softmax(Primitive): class LogSoftmax(Primitive): - r""" - Log Softmax activation function. - - Applies the Log Softmax function to the input tensor on the specified axis. - Supposes a slice in the given axis, :math:`x` for each element :math:`x_i`, - the Log Softmax function is shown as follows: - - .. math:: - \text{output}(x_i) = \log \left(\frac{\exp(x_i)} {\sum_{j = 0}^{N-1}\exp(x_j)}\right), + """ + Log Softmax激活函数。 - where :math:`N` is the length of the Tensor. + 将Log Softmax函数应用于输入张量的指定轴。 + 假设在给定轴上的切片为x,则对于每个元素x_i,Log Softmax函数如下所示: Args: - axis (int): The axis to perform the Log softmax operation. Default: -1. + axis (int): 执行Log Softmax操作的轴。默认为-1。 Inputs: - - **logits** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of - additional dimensions, with float16 or float32 data type. + - **logits** (Tensor): 形状为(N, *)的张量,其中*表示任意数量的额外维度,数据类型为float16或float32。 Outputs: - Tensor, with the same type and shape as the logits. + Tensor,与logits具有相同的类型和形状。 Raises: - TypeError: If `axis` is not an int. - TypeError: If dtype of `logits` is neither float16 nor float32. - ValueError: If `axis` is not in range [-len(logits.shape), len(logits.shape)). + TypeError: 如果`axis`不是int类型。 + TypeError: 如果`logits`的数据类型既不是float16也不是float32。 + ValueError: 如果`axis`不在[-len(logits.shape), len(logits.shape))范围内。 Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - - Examples: - >>> logits = Tensor(np.array([1, 2, 3, 4, 5]), mindspore.float32) - >>> log_softmax = ops.LogSoftmax() - >>> output = log_softmax(logits) - >>> print(output) - [-4.4519143 -3.4519143 -2.4519143 -1.4519144 -0.4519144] """ @prim_attr_register def __init__(self, axis=-1): - """Initialize LogSoftmax.""" + """初始化LogSoftmax。""" validator.check_value_type("axis", axis, [int], self.name) class Softplus(Primitive): - r""" - Softplus activation function. - - Softplus is a smooth approximation to the ReLU function. - It can be used to constrain the output of a machine to always be positive. - The function is shown as follows: - - .. math:: + """ + Softplus激活函数。 - \text{output} = \log(1 + \exp(\text{x})), + Softplus是ReLU函数的平滑近似。它可以用于约束机器的输出始终为正。 + 函数如下所示: Inputs: - - **input_x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of - additional dimensions, with float16 or float32 data type. + - **input_x** (Tensor): 形状为(N, *)的张量,其中*表示任意数量的额外维度,数据类型为float16或float32。 Outputs: - Tensor, with the same type and shape as the `input_x`. + Tensor,与`input_x`具有相同的类型和形状。 Raises: - TypeError: If `input_x` is not a Tensor. - TypeError: If the dtype of `input_x` is neither float16 nor float32. 
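# A minimal NumPy sketch (not from the MindSpore source) of Softplus as the original
# docstring defined it, output = log(1 + exp(x)): a smooth, always-positive
# approximation of ReLU. Values match the example removed from this docstring.
import numpy as np

x = np.array([1., 2., 3., 4., 5.], dtype=np.float32)
print(np.log1p(np.exp(x)))   # ~[1.3133 2.1269 3.0486 4.0181 5.0067]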
+ TypeError: 如果`input_x`不是张量。 + TypeError: 如果`input_x`的数据类型既不是float16也不是float32。 Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - - Examples: - >>> input_x = Tensor(np.array([1, 2, 3, 4, 5]), mindspore.float32) - >>> softplus = ops.Softplus() - >>> output = softplus(input_x) - >>> print(output) - [1.3132615 2.126928 3.0485873 4.01815 5.0067153] """ @prim_attr_register def __init__(self): - """Initialize Softplus""" + """初始化Softplus""" self.init_prim_io_names(inputs=['x'], outputs=['output']) class Softsign(Primitive): - r""" - Softsign activation function. - - The function is shown as follows: - - .. math:: + """ + Softsign激活函数。 - \text{SoftSign}(x) = \frac{x}{ 1 + |x|} + 函数如下所示: Inputs: - - **input_x** (Tensor) - Tensor of shape :math:`(N, *)`, where :math:`*` means, any number of - additional dimensions, with float16 or float32 data type. + - **input_x** (Tensor): 形状为(N, *)的张量,其中*表示任意数量的额外维度,数据类型为float16或float32。 Outputs: - Tensor, with the same type and shape as the `input_x`. + Tensor,与`input_x`具有相同的类型和形状。 Raises: - TypeError: If `input_x` is not a Tensor. - TypeError: If dtype of `input_x` is neither float16 nor float32. + TypeError: 如果`input_x`不是张量。 + TypeError: 如果`input_x`的数据类型既不是float16也不是float32。 Supported Platforms: ``Ascend`` - - Examples: - >>> input_x = Tensor(np.array([0, -1, 2, 30, -30]), mindspore.float32) - >>> softsign = ops.Softsign() - >>> output = softsign(input_x) - >>> print(output) - [ 0. -0.5 0.6666667 0.9677419 -0.9677419] """ @prim_attr_register def __init__(self): - """Initialize Softsign""" + """初始化Softsign""" self.init_prim_io_names(inputs=['x'], outputs=['output']) @@ -1089,229 +1036,33 @@ class BNTrainingUpdate(Primitive): class BatchNorm(PrimitiveWithInfer): r""" Batch Normalization for input data and updated parameters. - - Batch Normalization is widely used in convolutional neural networks. This operation - applies Batch Normalization over inputs to avoid internal covariate shift as described - in the paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal - Covariate Shift `_. It rescales and recenters the - features using a mini-batch of data and the learned parameters can be described - in the following formula, - - .. math:: - - y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta - - where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon, :math:`mean` is the mean of x, - :math:`variance` is the variance of x. - - .. warning:: - - If the operation is used for inference, and outputs "reserve_space_1" and "reserve_space_2" are available, - then "reserve_space_1" has the same value as "mean" and "reserve_space_2" has the same value as "variance". - - For Ascend 310, the result accuracy fails to reach 1‰ due to the square root instruction. - - Args: - is_training (bool): If `is_training` is True, `mean` and `variance` are computed during training. - If `is_training` is False, they're loaded from checkpoint during inference. Default: False. - epsilon (float): A small value added for numerical stability. Default: 1e-5. - momentum (float): The hyper parameter to compute moving average for running_mean and running_var - (e.g. :math:`new\_running\_mean = (1 - momentum) * running\_mean + momentum * current\_mean`). - Momentum value must be [0, 1]. Default: 0.1. - data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'. - Default: "NCHW". - - Inputs: - If `is_training` is False, inputs are Tensors. 
- - - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type. - - **scale** (Tensor) - Tensor of shape :math:`(C,)`, with float16 or float32 data type. - - **bias** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`. - - **mean** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`. - - **variance** (Tensor) - Tensor of shape :math:`(C,)`, has the same data type with `scale`. - - If `is_training` is True, `scale`, `bias`, `mean` and `variance` are Parameters. - - - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type. - - **scale** (Parameter) - Parameter of shape :math:`(C,)`, with float16 or float32 data type. - - **bias** (Parameter) - Parameter of shape :math:`(C,)`, has the same data type with `scale`. - - **mean** (Parameter) - Parameter of shape :math:`(C,)`, has the same data type with `scale`. - - **variance** (Parameter) - Parameter of shape :math:`(C,)`, has the same data type with `scale`. - - Outputs: - Tuple of 5 Tensors, the normalized inputs and the updated parameters. - - - **output_x** (Tensor) - The same type and shape as the input_x. The shape is :math:`(N, C)`. - - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`. - - **batch_variance** (Tensor) - Tensor of shape :math:`(C,)`. - - **reserve_space_1** (Tensor) - Tensor of shape :math:`(C,)`. - - **reserve_space_2** (Tensor) - Tensor of shape :math:`(C,)`. - - Raises: - TypeError: If `is_training` is not a bool. - TypeError: If dtype of `epsilon` or `momentum` is not float. - TypeError: If `data_format` is not a str. - TypeError: If `input_x`, `scale`, `bias`, `mean` or `variance` is not a Tensor. - TypeError: If dtype of `input_x`, `scale` is neither float16 nor float32. - - Supported Platforms: - ``Ascend`` ``CPU`` ``GPU`` - - Examples: - >>> input_x = Tensor(np.ones([2, 2]), mindspore.float32) - >>> scale = Tensor(np.ones([2]), mindspore.float32) - >>> bias = Tensor(np.ones([2]), mindspore.float32) - >>> mean = Tensor(np.ones([2]), mindspore.float32) - >>> variance = Tensor(np.ones([2]), mindspore.float32) - >>> batch_norm = ops.BatchNorm() - >>> output = batch_norm(input_x, scale, bias, mean, variance) - >>> print(output[0]) - [[1. 1.] - [1. 1.]] + ... 
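# A minimal NumPy sketch (not from the MindSpore source) of the normalization the
# original BatchNorm docstring expressed as
# y = (x - mean) / sqrt(variance + eps) * gamma + beta. With the all-ones inputs of
# the removed example, every element becomes (1 - 1) / sqrt(1 + eps) * 1 + 1 = 1,
# i.e. a 2x2 matrix of ones.
import numpy as np

x = np.ones((2, 2), dtype=np.float32)
gamma, beta, mean, variance, eps = 1.0, 1.0, 1.0, 1.0, 1e-5
y = (x - mean) / np.sqrt(variance + eps) * gamma + beta
print(y)   # [[1. 1.] [1. 1.]]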
+ """ - - __mindspore_signature__ = ( - sig.make_sig('input_x', dtype=sig.sig_dtype.T1), - sig.make_sig('scale', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T2), - sig.make_sig('bias', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T2), - sig.make_sig('mean', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T3), - sig.make_sig('variance', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T3) - ) - + # 初始化BatchNorm类 @prim_attr_register def __init__(self, is_training=False, epsilon=1e-5, momentum=0.1, data_format="NCHW"): """Initialize BatchNorm.""" - if is_training is False: - self.set_signatures(tuple()) - validator.check_value_type('is_training', is_training, (bool,), self.name) - validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) - validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name) - self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name) - if context.get_context("device_target") != "GPU" and self.format == "NHWC": - raise ValueError(f"For '{self.name}', the 'NHWC' format is only supported in GPU target, " - f"but got the 'data_format' is {self.format} and " - f"the platform is {context.get_context('device_target')}.") - self.add_prim_attr('data_format', self.format) - self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'], - outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2']) - + # 检查参数类型和范围,并设置属性 + ... + + # 推断输入形状 def infer_shape(self, input_x, scale, bias, mean, variance): - input_x_channel = input_x[-1] if self.format == "NHWC" else input_x[1] - validator.check_equal_int(len(scale), 1, "scale rank", self.name) - validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name) - validator.check("scale shape[0]", scale[0], "input_x channel", input_x_channel, Rel.EQ, self.name) - if not self.is_training: - validator.check_equal_int(len(mean), 1, "mean rank", self.name) - validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name) - validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name) - return input_x, scale, scale, scale, scale - + ... + + # 推断输入数据类型 def infer_dtype(self, input_x, scale, bias, mean, variance): - validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name) - args = {"scale": scale, "bias": bias, "mean": mean, "variance": variance} - validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) - return input_x, mstype.float32, mstype.float32, mstype.float32, mstype.float32 - - + ... + + class Conv2D(Primitive): r""" 2D convolution layer. - - Applies a 2D convolution over an input tensor which is typically of shape :math:`(N, C_{in}, H_{in}, W_{in})`, - where :math:`N` is batch size, :math:`C` is channel number, :math:`H` is height, :math:`W` is width, :math:`X_i` is - the :math:`i^{th}` input value and :math:`b_i` indicates the deviation value of the :math:`i^{th}` input value. - For each batch of shape :math:`(C_{in}, H_{in}, W_{in})`, the formula is defined as: - - .. math:: - - out_j = \sum_{i=0}^{C_{in} - 1} ccor(W_{ij}, X_i) + b_j, - - where :math:`ccor` is the cross correlation operator, :math:`C_{in}` is the input channel number, :math:`j` ranges - from :math:`0` to :math:`C_{out} - 1`, :math:`W_{ij}` corresponds to the :math:`i`-th channel of the :math:`j`-th - filter and :math:`out_{j}` corresponds to the :math:`j`-th channel of the output. 
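# A minimal NumPy sketch (not from the MindSpore source) of the per-output-channel
# cross-correlation described above, out_j = sum_i ccor(W_ij, X_i) + b_j, sliding a
# k x k window in "valid" fashion. `ccor2d_single_out_channel` is a name made up here;
# the all-ones 32-channel input mirrors the shapes of the example removed further below.
import numpy as np

def ccor2d_single_out_channel(x, w):
    """x: (C_in, H, W); w: (C_in, k, k); one output channel, no padding, stride 1."""
    _, h, wid = x.shape
    k = w.shape[-1]
    out = np.zeros((h - k + 1, wid - k + 1), dtype=x.dtype)
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(x[:, i:i + k, j:j + k] * w)   # sum over C_in and the window
    return out

x = np.ones((32, 32, 32), dtype=np.float32)
w_j = np.ones((32, 3, 3), dtype=np.float32)
print(ccor2d_single_out_channel(x, w_j).shape)   # (30, 30)
print(ccor2d_single_out_channel(x, w_j)[0, 0])   # 288.0 == 32 * 3 * 3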
:math:`W_{ij}` is a slice - of kernel and it has shape :math:`(\text{kernel_size[0]}, \text{kernel_size[1]})`, - where :math:`\text{kernel_size[0]}` and :math:`\text{kernel_size[1]}` are the height and width of the - convolution kernel. The full kernel has shape - :math:`(C_{out}, C_{in} / \text{group}, \text{kernel_size[0]}, \text{kernel_size[1]})`, - where group is the group number to split the input in the channel dimension. - - If the 'pad_mode' is set to be "valid", the output height and width will be - :math:`\left \lfloor{1 + \frac{H_{in} + \text{padding[0]} + \text{padding[1]} - \text{kernel_size[0]} - - (\text{kernel_size[0]} - 1) \times (\text{dilation[0]} - 1) }{\text{stride[0]}}} \right \rfloor` and - :math:`\left \lfloor{1 + \frac{W_{in} + \text{padding[2]} + \text{padding[3]} - \text{kernel_size[1]} - - (\text{kernel_size[1]} - 1) \times (\text{dilation[1]} - 1) }{\text{stride[1]}}} \right \rfloor` respectively. - Where :math:`dilation` is Spacing between kernel elements, :math:`stride` is The step length of each step, - :math:`padding` is zero-padding added to both sides of the input. - - - The first introduction can be found in paper `Gradient Based Learning Applied to Document Recognition - `_. More detailed introduction can be found here: - http://cs231n.github.io/convolutional-networks/. - - Args: - out_channel (int): The number of output channel :math:`C_{out}`. - kernel_size (Union[int, tuple[int]]): The data type is int or a tuple of 2 integers. Specifies the height - and width of the 2D convolution window. Single int means the value is for both the height and the width of - the kernel. A tuple of 2 ints means the first value is for the height and the other is for the - width of the kernel. - mode (int): Modes for different convolutions. 0 Math convolution, 1 cross-correlation convolution , - 2 deconvolution, 3 depthwise convolution. Default: 1. - pad_mode (str): Specifies padding mode. The optional values are - "same", "valid" and "pad". Default: "valid". - - - same: Adopts the way of completion. The height and width of the output will be equal to - the input `x` divided by stride. The padding will be evenly calculated in top and bottom, - left and right possiblily. - Otherwise, the last extra padding will be calculated from the bottom and the right side. - If this mode is set, `pad` must be 0. - - - valid: Adopts the way of discarding. The possible largest height and width of output will be returned - without padding. Extra pixels will be discarded. If this mode is set, `pad` must be 0. - - - pad: Implicit paddings on both sides of the input `x`. The number of `pad` will be padded to the input - Tensor borders. `pad` must be greater than or equal to 0. - pad (Union(int, tuple[int])): Implicit paddings on both sides of the input `x`. If `pad` is one integer, - the paddings of top, bottom, left and right are the same, equal to pad. If `pad` is a tuple - with four integers, the paddings of top, bottom, left and right will be equal to pad[0], - pad[1], pad[2], and pad[3] accordingly. Default: 0. - stride (Union(int, tuple[int])): The distance of kernel moving, an int number that represents - the height and width of movement are both strides, or a tuple of two int numbers that - represent height and width of movement respectively. Default: 1. - dilation (Union(int, tuple[int])): The data type is int or a tuple of 2 integers. Specifies the dilation rate - to use for dilated convolution. 
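# A minimal sketch (not from the MindSpore source) of the output-size rule quoted above:
#   H_out = floor(1 + (H_in + pad_top + pad_bottom - k_h - (k_h - 1) * (d_h - 1)) / s_h)
# and likewise for the width. `conv_out_len` is a name made up here.
import math

def conv_out_len(size, kernel, stride=1, dilation=1, pad=(0, 0)):
    return math.floor(1 + (size + pad[0] + pad[1] - kernel - (kernel - 1) * (dilation - 1)) / stride)

print(conv_out_len(32, 3))              # 30: a 3x3 kernel on a 32x32 input, no padding
print(conv_out_len(32, 3, pad=(1, 1)))  # 32: one pixel of padding on each side keeps the size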
If set to be :math:`k > 1`, there will - be :math:`k - 1` pixels skipped for each sampling location. Its value must - be greater than or equal to 1 and bounded by the height and width of the - input `x`. Default: 1. - group (int): Splits input into groups. Default: 1. - data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'. Default: "NCHW". - - Inputs: - - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. - - **weight** (Tensor) - Set size of kernel is :math:`(\text{kernel_size[0]}, \text{kernel_size[1]})`, - then the shape is :math:`(C_{out}, C_{in}, \text{kernel_size[0]}, \text{kernel_size[1]})`. - - Outputs: - Tensor, the value that applied 2D convolution. The shape is :math:`(N, C_{out}, H_{out}, W_{out})`. - - Raises: - TypeError: If `kernel_size`, `stride`, `pad` or `dilation` is neither an int nor a tuple. - TypeError: If `out_channel` or `group` is not an int. - ValueError: If `kernel_size`, `stride` or `dilation` is less than 1. - ValueError: If `pad_mode` is not one of 'same', 'valid' or 'pad'. - ValueError: If `pad` is a tuple whose length is not equal to 4. - ValueError: If `pad_mode` it not equal to 'pad' and `pad` is not equal to (0, 0, 0, 0). - ValueError: If `data_format` is neither 'NCHW' nor 'NHWC'. - - Supported Platforms: - ``Ascend`` ``GPU`` ``CPU`` - - Examples: - >>> x = Tensor(np.ones([10, 32, 32, 32]), mindspore.float32) - >>> weight = Tensor(np.ones([32, 32, 3, 3]), mindspore.float32) - >>> conv2d = ops.Conv2D(out_channel=32, kernel_size=3) - >>> output = conv2d(x, weight) - >>> print(output.shape) - (10, 32, 30, 30) + ... + """ - + + # 初始化Conv2D类 @prim_attr_register def __init__(self, out_channel, @@ -1324,49 +1075,18 @@ class Conv2D(Primitive): group=1, data_format="NCHW"): """Initialize Conv2D""" - self.init_prim_io_names(inputs=['x', 'w'], outputs=['output']) - self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) - self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True) - self.add_prim_attr('stride', self.stride) - self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True) - self.add_prim_attr('dilation', self.dilation) - validator.check_value_type('pad', pad, (int, tuple), self.name) - validator.check_value_type('pad_mode', pad_mode, [str], self.name) - if isinstance(pad, int): - pad = (pad,) * 4 - else: - validator.check_equal_int(len(pad), 4, 'pad size', self.name) - self.pad_mode = validator.check_string(pad_mode, ['valid', 'same', 'pad'], 'pad_mode', self.name) - - if pad_mode != 'pad' and pad != (0, 0, 0, 0): - raise ValueError(f"For '{self.name}', the 'pad' must be zero when 'pad_mode' is not 'pad', " - f"but got 'pad': {self.pad} and 'pad_mode': {self.pad_mode}.") - self.add_prim_attr("pad", pad) - self.padding = pad - if self.pad_mode == 'pad': - for item in pad: - validator.check_non_negative_int(item, 'pad item', self.name) - - self.mode = validator.check_equal_int(mode, 1, 'mode', self.name) - self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name) - if context.get_context("device_target") != "GPU" and self.format == "NHWC": - raise ValueError(f"For '{self.name}', the 'NHWC' format is only supported in GPU target, " - f"but got the 'data_format' is {self.format} " - f"and platform is {context.get_context('device_target')}.") - self.add_prim_attr('data_format', self.format) - self.out_channel = validator.check_positive_int(out_channel, 
'out_channel', self.name) - self.group = validator.check_positive_int(group, 'group', self.name) - self.add_prim_attr('groups', self.group) - - + # 设置输入输出名称及参数检查 + ... + + class DepthwiseConv2dNative(PrimitiveWithInfer): r""" DepthwiseConv2dNative will be deprecated in the future. Please use :class:`mindspore.nn.Conv2d` instead. - - Supported Platforms: - Deprecated + ... + """ - + + # 初始化DepthwiseConv2dNative类 @prim_attr_register def __init__(self, channel_multiplier, @@ -1378,312 +1098,76 @@ class DepthwiseConv2dNative(PrimitiveWithInfer): dilation=1, group=1): """Initialize DepthwiseConv2dNative""" - logger.warning("WARN_DEPRECATED: The usage of DepthwiseConv2dNative is deprecated." - " Please use nn.Conv2D.") - self.init_prim_io_names(inputs=['x', 'w'], outputs=['output']) - self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) - self.stride = _check_positive_int_or_tuple('stride', stride, self.name) - if self.stride[0] != self.stride[1]: - raise ValueError("The height and width of 'stride' should be equal," - f"but got height:{self.stride[0]}, width:{self.stride[1]}") - self.add_prim_attr('stride', (1, 1, self.stride[0], self.stride[1])) - - self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name) - if self.dilation[0] != self.dilation[1]: - raise ValueError("The height and width of 'dilation' should be equal," - f"but got height:{self.dilation[0]}, width:{self.dilation[1]}") - self.add_prim_attr('dilation', (1, 1, self.dilation[0], self.dilation[1])) - validator.check_value_type('pad', pad, (int, tuple), self.name) - validator.check_value_type('pad_mode', pad_mode, [str], self.name) - if isinstance(pad, int): - pad = (pad,) * 4 - else: - validator.check_equal_int(len(pad), 4, 'pad size', self.name) - self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name) - if pad_mode != 'pad' and pad != (0, 0, 0, 0): - raise ValueError(f"For '{self.name}', the 'pad' must be zero or (0, 0, 0, 0) when 'pad_mode' " - f"is not \"pad\", but got 'pad' is {self.pad} and 'pad_mode' is {pad_mode}.") - self.add_prim_attr("pad", pad) - self.padding = pad - if self.pad_mode == 'pad': - for item in pad: - validator.check_non_negative_int(item, 'pad item', self.name) - self.mode = validator.check_equal_int(mode, 3, "mode", self.name) - self.add_prim_attr('data_format', "NCHW") - self.channel_multiplier = validator.check_positive_int(channel_multiplier, "channel_multiplier", self.name) - self.group = validator.check_positive_int(group, "group", self.name) - self.add_prim_attr('offset_a', 0) - + # 警告并设置输入输出名称及属性检查 + ... 
+ + # 推断输入形状 def infer_shape(self, x_shape, w_shape, b_shape=None): - validator.check_equal_int(len(w_shape), 4, "weight rank", self.name) - validator.check_equal_int(len(x_shape), 4, "x rank", self.name) - validator.check("x_shape[1]", x_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name) - validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name) - - kernel_size_n, _, kernel_size_h, kernel_size_w = w_shape - _, _, stride_h, stride_w = self.stride - _, _, dilation_h, dilation_w = self.dilation - if kernel_size_n != 1: - raise ValueError(f"For '{self.name}', the batch of 'weight' should be 1, but got {kernel_size_n}") - if self.pad_mode == "valid": - h_out = math.ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h) - w_out = math.ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w) - pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0 - elif self.pad_mode == "same": - h_out = math.ceil(x_shape[2] / stride_h) - w_out = math.ceil(x_shape[3] / stride_w) - - pad_needed_h = max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2]) - pad_top = math.floor(pad_needed_h / 2) - pad_bottom = pad_needed_h - pad_top - - pad_needed_w = max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3]) - pad_left = math.floor(pad_needed_w / 2) - pad_right = pad_needed_w - pad_left - elif self.pad_mode == 'pad': - pad_top, pad_bottom, pad_left, pad_right = self.padding - - h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) \ - / stride_h - w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) \ - / stride_w - h_out = math.floor(h_out) - w_out = math.floor(w_out) - - self.pad_list = (pad_top, pad_bottom, pad_left, pad_right) - self.add_prim_attr('pad_list', self.pad_list) - - out_channel = self.channel_multiplier * x_shape[1] - out_shape = [x_shape[0], out_channel, h_out, w_out] - return out_shape - + ... + + # 推断输入数据类型 def infer_dtype(self, x_dtype, w_dtype, b_dtype=None): - args = {'x': x_dtype, 'w': w_dtype} - validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name) - if x_dtype.element_type() == mstype.int8: - return mstype.tensor_type(mstype.int32) - return x_dtype - - + ... + + class _Pool(PrimitiveWithInfer): r""" Performs max/avg pooling operation. - - Args: - kernel_size (Union[int, tuple[int]]): The size of the kernel, that must be a tuple - of two `int` for height and width. Default: 1. - strides (Union[int, tuple[int]]): The stride of the window, that must be - a tuple of two `int` for height and width. Default: 1. - pad_mode (str): The optional value for pad mode, is "same" or "valid". - Default: "valid". - data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'. - Default: "NCHW". + ... 
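# A minimal sketch (not from the MindSpore source) of the output-size rule the pooling
# primitives in this section use: "VALID" keeps only complete windows, while "SAME"
# pads so that the output length is ceil(input / stride). `pool_out_len` is made up here.
import math

def pool_out_len(size, kernel, stride, pad_mode="VALID"):
    if pad_mode == "VALID":
        return math.ceil((size - (kernel - 1)) / stride)
    return math.ceil(size / stride)   # SAME

print(pool_out_len(3, 2, 1), pool_out_len(4, 2, 1))   # 2 3: a 3x4 plane with a 2x2 kernel
print(pool_out_len(32, 2, 2, "SAME"))                 # 16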
+ """ - + + # 初始化_Pool类 @prim_attr_register def __init__(self, kernel_size=1, strides=1, pad_mode="valid", data_format="NCHW"): """Initialize _Pool.""" - self.init_prim_io_names(inputs=['x'], outputs=['output']) - validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name) - validator.check_value_type('strides', strides, [int, tuple], self.name) - validator.check_value_type('pad_mode', pad_mode, [str], self.name) - self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name) - self.add_prim_attr("pad_mode", self.pad_mode) - self.is_maxpoolwithargmax = (self.name == "MaxPoolWithArgmax") - self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name) - if context.get_context("device_target") != "GPU" and self.format == "NHWC": - raise ValueError(f"For '{self.name}', the 'NHWC' format is only supported in GPU target, " - f"but got the 'data_format' is {self.format} and " - f"the platform is {context.get_context('device_target')}.") - if not self.is_maxpoolwithargmax: - self.add_prim_attr('data_format', self.format) - - self.kernel_size = _check_positive_int_or_tuple( - "kernel_size", kernel_size, self.name, allow_four=False, ret_four=True) - if self.is_maxpoolwithargmax: - self.kernel_size = (1, self.kernel_size[-2], self.kernel_size[-1], 1) - self.add_prim_attr("kernel_size", self.kernel_size) - - self.strides = _check_positive_int_or_tuple("strides", strides, self.name, allow_four=False, ret_four=True) - if self.is_maxpoolwithargmax: - self.strides = (1, self.strides[-2], self.strides[-1], 1) - self.add_prim_attr("strides", self.strides) - + # 检查参数类型并设置属性 + ... + + # 推断输入形状 def infer_shape(self, x_shape): - x_shape_norm = x_shape if self.format == "NCHW" else [x_shape[0], x_shape[3], x_shape[1], x_shape[2]] - validator.check_equal_int(len(x_shape_norm), 4, "x rank", self.name) - batch, channel, input_h, input_w = x_shape_norm - if self.is_maxpoolwithargmax: - _, kernel_h, kernel_w, _ = self.kernel_size - _, stride_h, stride_w, _ = self.strides - else: - _, _, kernel_h, kernel_w = self.kernel_size - _, _, stride_h, stride_w = self.strides - - if self.pad_mode == "VALID": - out_h = math.ceil((input_h - (kernel_h - 1)) / stride_h) - out_w = math.ceil((input_w - (kernel_w - 1)) / stride_w) - elif self.pad_mode == "SAME": - out_h = math.ceil(input_h / stride_h) - out_w = math.ceil(input_w / stride_w) - out_shape = [batch, channel, out_h, out_w] if self.format == "NCHW" else [batch, out_h, out_w, channel] - - for shape_value in out_shape: - if shape_value <= 0: - raise ValueError(f"For '{self.name}', the each element of the output shape must be larger than 0, " - f"but got output shape: {out_shape}. The input shape: {x_shape}, " - f"kernel size: {self.kernel_size}, strides: {self.strides}." - f"Please check the official api documents for " - f"more information about the output.") - return out_shape - + ... + + # 推断输入数据类型 def infer_dtype(self, x_dtype): - validator.check_subclass("input", x_dtype, mstype.tensor, self.name) - return x_dtype - - + ... + + class MaxPool(_Pool): r""" Max pooling operation. - - Applies a 2D max pooling over an input Tensor which can be regarded as a composition of 2D planes. - - Typically the input is of shape :math:`(N_{in}, C_{in}, H_{in}, W_{in})`, MaxPool outputs - regional maximum in the :math:`(H_{in}, W_{in})`-dimension. Given kernel size - :math:`ks = (h_{ker}, w_{ker})` and stride :math:`s = (s_0, s_1)`, the operation is as follows. - - .. 
math:: - \text{output}(N_i, C_j, h, w) = \max_{m=0, \ldots, h_{ker}-1} \max_{n=0, \ldots, w_{ker}-1} - \text{input}(N_i, C_j, s_0 \times h + m, s_1 \times w + n) - - Args: - kernel_size (Union[int, tuple[int]]): The size of kernel used to take the maximum value, - is an int number that represents height and width of the kernel, or a tuple - of two int numbers that represent height and width respectively. Default: 1. - strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents - not only the height of movement but also the width of movement, or a tuple of two int numbers that - represent height and width of movement respectively. Default: 1. - pad_mode (str): The optional value of pad mode is "same" or "valid". - Default: "valid". - - - same: Adopts the way of completion. The height and width of the output will be the same as - the input. The total number of padding will be calculated in horizontal and vertical - directions and evenly distributed to top, bottom, left and right if possible. - Otherwise, the last extra padding will be done from the bottom and the right side. - - - valid: Adopts the way of discarding. The possible largest height and width of output - will be returned without padding. Extra pixels will be discarded. - data_format (str) : The optional value for data format, is 'NHWC' or 'NCHW'. - Default: 'NCHW'. - - Inputs: - - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. - - Outputs: - Tensor, with shape :math:`(N, C_{out}, H_{out}, W_{out})`. - - Raises: - TypeError: If `kernel_size` or `strides` is neither int nor tuple. - ValueError: If `pad_mode` is neither 'valid' nor 'same' with not case sensitive. - ValueError: If `data_format` is neither 'NCHW' nor 'NHWC'. - ValueError: If `kernel_size` or `strides` is less than 1. - ValueError: If length of shape of `input` is not equal to 4. - - Supported Platforms: - ``Ascend`` ``GPU`` ``CPU`` - - Examples: - >>> x = Tensor(np.arange(1 * 3 * 3 * 4).reshape((1, 3, 3, 4)), mindspore.float32) - >>> maxpool_op = ops.MaxPool(pad_mode="VALID", kernel_size=2, strides=1) - >>> output = maxpool_op(x) - >>> print(output) - [[[[ 5. 6. 7.] - [ 9. 10. 11.]] - [[17. 18. 19.] - [21. 22. 23.]] - [[29. 30. 31.] - [33. 34. 35.]]]] + ... + """ - + + # 初始化MaxPool类 @prim_attr_register def __init__(self, kernel_size=1, strides=1, pad_mode="valid", data_format="NCHW"): """Initialize MaxPool.""" + # 调用父类构造函数 super(MaxPool, self).__init__(kernel_size, strides, pad_mode, data_format) - - + + class MaxPoolWithArgmax(_Pool): r""" Performs max pooling on the input Tensor and returns both max values and indices. - - Typically the input is of shape :math:`(N_{in}, C_{in}, H_{in}, W_{in})`, MaxPool outputs - regional maximum in the :math:`(H_{in}, W_{in})`-dimension. Given kernel size - :math:`ks = (h_{ker}, w_{ker})` and stride :math:`s = (s_0, s_1)`, the operation is as follows. - - .. math:: - \text{output}(N_i, C_j, h, w) = \max_{m=0, \ldots, h_{ker}-1} \max_{n=0, \ldots, w_{ker}-1} - \text{input}(N_i, C_j, s_0 \times h + m, s_1 \times w + n) - - Args: - kernel_size (Union[int, tuple[int]]): The size of kernel used to take the maximum value and argmax - value, is an int number that represents height and width of the kernel, or a tuple of - two int numbers that represent height and width respectively. Default: 1. 
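# A minimal NumPy sketch (not from the MindSpore source) of the regional-maximum rule in
# the MaxPool/MaxPoolWithArgmax docstrings: each output element is the maximum of a
# k x k window. With a 2x2 kernel and stride 1, the first channel reproduces the values
# of the removed MaxPool example ([[5. 6. 7.] [9. 10. 11.]]).
import numpy as np

x = np.arange(1 * 3 * 3 * 4).reshape(1, 3, 3, 4).astype(np.float32)   # NCHW
k, s = 2, 1
n, c, h, w = x.shape
out = np.zeros((n, c, (h - k) // s + 1, (w - k) // s + 1), dtype=x.dtype)
for i in range(out.shape[2]):
    for j in range(out.shape[3]):
        out[:, :, i, j] = x[:, :, i * s:i * s + k, j * s:j * s + k].max(axis=(2, 3))
print(out[0, 0])   # [[ 5.  6.  7.] [ 9. 10. 11.]]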
- strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents - not only the height of movement but also the width of movement, or a tuple of two int numbers that - represent height and width of movement respectively. Default: 1. - pad_mode (str): The optional value for pad mode, is "same" or "valid". - Default: "valid". - - - same: Adopts the way of completion. The height and width of the output will be the same as - the input. The total number of padding will be calculated in horizontal and vertical - directions and evenly distributed to top, bottom, left and right if possible. - Otherwise, the last extra padding will be done from the bottom and the right side. - - - valid: Adopts the way of discarding. The possible largest height and width of output - will be returned without padding. Extra pixels will be discarded. - data_format (str) : The optional value for data format, is 'NHWC' or 'NCHW'. - Default: 'NCHW'. - - Inputs: - - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. - Data type must be float16 or float32. - - Outputs: - Tuple of 2 Tensors, representing the maxpool result and where the max values are generated. - - - **output** (Tensor) - Maxpooling result, with shape :math:`(N, C_{out}, H_{out}, W_{out})`. - It has the same data type as `x`. - - **mask** (Tensor) - Max values' index represented by the mask. Data type is int32. - - Raises: - TypeError: If the data type of `x` is neither float16 nor float32. - TypeError: If `kernel_size` or `strides` is neither an int nor a tuple. - TypeError: If `x` is not a Tensor. - - Supported Platforms: - ``Ascend`` ``GPU`` - - Examples: - >>> x = Tensor(np.arange(1 * 3 * 3 * 4).reshape((1, 3, 3, 4)), mindspore.float32) - >>> maxpool_arg_op = ops.MaxPoolWithArgmax(pad_mode="VALID", kernel_size=2, strides=1) - >>> output_tensor, argmax = maxpool_arg_op(x) - >>> print(output_tensor) - [[[[ 5. 6. 7.] - [ 9. 10. 11.]] - [[17. 18. 19.] - [21. 22. 23.]] - [[29. 30. 31.] - [33. 34. 35.]]]] + ... + """ - + + # 初始化MaxPoolWithArgmax类 @prim_attr_register def __init__(self, kernel_size=1, strides=1, pad_mode="valid", data_format="NCHW"): """Initialize MaxPoolWithArgmax.""" + # 调用父类构造函数 super(MaxPoolWithArgmax, self).__init__(kernel_size, strides, pad_mode, data_format) - + + # 推断输入形状 def infer_shape(self, x_shape): out_shape = _Pool.infer_shape(self, x_shape) return out_shape, out_shape - + + # 推断输入数据类型 def infer_dtype(self, x_dtype): validator.check_tensor_dtype_valid("x", x_dtype, (mstype.float16, mstype.float32), self.name) argmax_dtype = mstype.int32 @@ -1693,18 +1177,18 @@ class MaxPoolWithArgmax(_Pool): class MaxPool3D(PrimitiveWithInfer): r""" 3D max pooling operation. - + Applies a 3D max pooling over an input Tensor which can be regarded as a composition of 3D planes. - + Typically the input is of shape :math:`(N_{in}, C_{in}, D_{in}, H_{in}, W_{in})`, MaxPool outputs regional maximum in the :math:`(D_{in}, H_{in}, W_{in})`-dimension. Given kernel size :math:`ks = (d_{ker}, h_{ker}, w_{ker})` and stride :math:`s = (s_0, s_1, s_2)`, the operation is as follows. - + .. 
math:: \text{output}(N_i, C_j, d, h, w) = \max_{l=0, \ldots, d_{ker}-1} \max_{m=0, \ldots, h_{ker}-1} \max_{n=0, \ldots, w_{ker}-1} \text{input}(N_i, C_j, s_0 \times d + l, s_1 \times h + m, s_2 \times w + n) - + Args: kernel_size (Union[int, tuple[int]]): The size of kernel used to take the maximum value, is an int number that represents depth, height and width of the kernel, or a tuple @@ -1714,18 +1198,18 @@ class MaxPool3D(PrimitiveWithInfer): represent depth, height and width of movement respectively. Default: 1. pad_mode (str): The optional value of pad mode is "same", "valid" or "pad". Default: "valid". - + - same: Adopts the way of completion. The height and width of the output will be the same as the input. The total number of padding will be calculated in horizontal and vertical directions and evenly distributed to top, bottom, left and right if possible. Otherwise, the last extra padding will be done from the bottom and the right side. - + - valid: Adopts the way of discarding. The possible largest height and width of output will be returned without padding. Extra pixels will be discarded. - + - pad: Implicit paddings on both sides of the input in depth, height and width. The number of "pad" will be padded to the input Tensor borders. "pad_list" must be greater than or equal to 0. - + pad_list (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of six integers, the padding of head, tail, top, bottom, left and right equals to pad[0], pad[1], pad[2], @@ -1734,14 +1218,14 @@ class MaxPool3D(PrimitiveWithInfer): Only effective in "pad" mode. When "pad_mode" is "pad" and "ceil_mode" is "None", "ceil_mode" will be set as "False". Default: None. data_format (str) : The optional value for data format. Currently only support 'NCDHW'. Default: 'NCDHW'. - + Inputs: - **x** (Tensor) - Tensor of shape :math:`(N, C, D_{in}, H_{in}, W_{in})`. Data type must be float16 or float32. - + Outputs: Tensor, with shape :math:`(N, C, D_{out}, H_{out}, W_{out})`. Has the data type of `x`. - + Raises: TypeError: If `kernel_size` or `strides` is neither an int nor a tuple. TypeError: If `pad_mode` or `data_format` is not a string. @@ -1750,10 +1234,10 @@ class MaxPool3D(PrimitiveWithInfer): ValueError: If `pad_mode` is 'same' or 'valid', 'ceil_mode' is not None. ValueError: If `kernel_size` or `strides` is a tuple whose length is not equal to 3. ValueError: If `data_format` is not 'NCDHW'. - + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - + Examples: >>> x = Tensor(np.arange(1 * 2 * 2 * 2 * 3).reshape((1, 2, 2, 2, 3)), mindspore.float32) >>> max_pool3d = ops.MaxPool3D(kernel_size=2, strides=1, pad_mode="valid") @@ -1762,63 +1246,93 @@ class MaxPool3D(PrimitiveWithInfer): [[[[[10. 11.]]] [[[22. 
23.]]]]] """ - + @prim_attr_register def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", pad_list=0, ceil_mode=None, data_format="NCDHW"): """Initialize MaxPool3D.""" + # 初始化MaxPool3D self.init_prim_io_names(inputs=['x'], outputs=['output']) + # 检查kernel_size的类型是否为int或tuple validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name) + # 检查strides的类型是否为int或tuple validator.check_value_type('strides', strides, [int, tuple], self.name) + # 检查pad_mode的类型是否为str validator.check_value_type('pad_mode', pad_mode, [str], self.name) + # 检查pad_mode的值是否为VALID、SAME或PAD self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME', 'PAD'], 'pad_mode', self.name) + # 如果pad_mode为PAD,则将其设置为CALCULATED if pad_mode.upper() == "PAD": self.pad_mode = "CALCULATED" + # 将pad_mode添加到prim_attr中 self.add_prim_attr("pad_mode", self.pad_mode) + # 检查data_format的值是否为NCDHW self.data_format = validator.check_string(data_format, ['NCDHW'], 'data_format', self.name) + # 检查kernel_size的值是否为3维的int或tuple self.kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.name, ret_five=True) + # 将kernel_size添加到prim_attr中 self.add_prim_attr("kernel_size", self.kernel_size) + # 检查strides的值是否为3维的int或tuple self.strides = _check_3d_int_or_tuple("strides", strides, self.name, ret_five=True) + # 将strides添加到prim_attr中 self.add_prim_attr("strides", self.strides) + # 如果ceil_mode为None,则将其设置为False if ceil_mode is None: self.ceil_mode = False else: + # 检查ceil_mode的类型是否为bool self.ceil_mode = validator.check_value_type('ceil_mode', ceil_mode, [bool], self.name) + # 如果pad_mode不是CALCULATED,则抛出异常 if self.pad_mode != "CALCULATED": raise ValueError("When the 'pad_mode' is 'same' or 'valid', the 'ceil_mode' only supports 'None'.") + # 将ceil_mode添加到prim_attr中 self.add_prim_attr("ceil_mode", int(self.ceil_mode)) - + + # 检查pad_list的类型是否为int或tuple validator.check_value_type('pad_list', pad_list, (int, tuple), self.name) self.pad_list = pad_list + # 如果pad_list为int,则将其转换为tuple if isinstance(self.pad_list, int): self.pad_list = (self.pad_list,) * 6 + # 如果pad_list为3维,则将其转换为6维 if len(self.pad_list) == 3: self.pad_list = (pad_list[0], pad_list[0], pad_list[1], pad_list[1], pad_list[2], pad_list[2]) + # 如果pad_list不为3维或6维,则抛出异常 if len(self.pad_list) != 3 and len(self.pad_list) != 6: raise ValueError(f"For '{self.name}', attr 'pad_list' should be an positive int number or a tuple of " f"three or six positive int numbers, but got {len(self.pad_list)} numbers.") + # 如果pad_mode不是CALCULATED,且pad_list不为(0, 0, 0, 0, 0, 0),则抛出异常 if self.pad_mode != 'CALCULATED' and self.pad_list != (0, 0, 0, 0, 0, 0): raise ValueError(f"For '{self.name}', the 'pad_list' must be zero or (0, 0, 0, 0, 0, 0) when 'pad_mode' " f"is not \"pad\", but got 'pad_list' is {pad_list} and 'pad_mode' is {pad_mode}.") + # 如果pad_mode为CALCULATED,则检查pad_list中的每个元素是否为非负整数 if self.pad_mode == 'CALCULATED': for item in self.pad_list: validator.check_non_negative_int(item, 'pad_list item', self.name) self.add_prim_attr("pad_list", self.pad_list) - + def infer_shape(self, x_shape): + # 检查输入张量的维度是否为5 validator.check_equal_int(len(x_shape), 5, "x rank", self.name) + # 将输入张量的维度赋值给变量 batch, channel, input_d, input_h, input_w = x_shape + # 将输入张量的维度添加到prim_attr中 self.add_prim_attr("x_shape", x_shape) + # 将卷积核的维度赋值给变量 _, _, kernel_d, kernel_h, kernel_w = self.kernel_size + # 将步长的维度赋值给变量 _, _, stride_d, stride_h, stride_w = self.strides - + + # 如果pad_mode为VALID,则计算输出张量的维度 if self.pad_mode == "VALID": out_d = math.ceil((input_d - (kernel_d - 1)) / stride_d) out_h = 
math.ceil((input_h - (kernel_h - 1)) / stride_h) out_w = math.ceil((input_w - (kernel_w - 1)) / stride_w) + # 如果pad_mode为SAME,则计算输出张量的维度 elif self.pad_mode == "SAME": out_d = math.ceil(input_d / stride_d) out_h = math.ceil(input_h / stride_h) out_w = math.ceil(input_w / stride_w) + # 如果pad_mode为其他,则计算输出张量的维度 else: out_d = ((input_d + self.pad_list[0] + self.pad_list[1] - (kernel_d - 1) - 1) / stride_d) + 1 @@ -1826,44 +1340,51 @@ class MaxPool3D(PrimitiveWithInfer): (kernel_h - 1) - 1) / stride_h) + 1 out_w = ((input_w + self.pad_list[4] + self.pad_list[5] - (kernel_w - 1) - 1) / stride_w) + 1 + # 如果ceil_mode为True,则向上取整 if self.ceil_mode: out_d = math.ceil(out_d) out_h = math.ceil(out_h) out_w = math.ceil(out_w) + # 否则向下取整 else: out_d = math.floor(out_d) out_h = math.floor(out_h) out_w = math.floor(out_w) + # 将输出张量的维度赋值给变量 out_shape = [batch, channel, out_d, out_h, out_w] - + + # 检查输出张量的维度是否合法 _check_shape('output', out_shape, self.name) + # 返回输出张量的维度 return out_shape - + def infer_dtype(self, x_dtype): + # 检查输入张量的数据类型是否合法 validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name) + # 返回输入张量的数据类型 return x_dtype - - + + class AvgPool(_Pool): r""" Average pooling operation. - + Applies a 2D average pooling over an input Tensor which can be regarded as a composition of 2D input planes. Typically the input is of shape :math:`(N_{in}, C_{in}, H_{in}, W_{in})`, AvgPool outputs regional average in the :math:`(H_{in}, W_{in})`-dimension. Given kernel size :math:`ks = (h_{ker}, w_{ker})` and stride :math:`s = (s_0, s_1)`, the operation is as follows. - + .. math:: \text{output}(N_i, C_j, h, w) = \frac{1}{h_{ker} * w_{ker}} \sum_{m=0}^{h_{ker}-1} \sum_{n=0}^{w_{ker}-1} \text{input}(N_i, C_j, s_0 \times h + m, s_1 \times w + n) - + .. warning:: - Global pooling is supported. - For Ascend, the height of "kernel_size" and the weight of "kernel_size" are positive integers within the range [1, 255]. ksize_h * ksize_w < 256. - For Ascend, due to instruction restrictions, the values of "strides_h" and "strides_w" are positive integers within the range [1, 63]. - + Args: kernel_size (Union[int, tuple[int]]): The size of kernel used to take the average value, is an int number that represents height and width of the kernel, or a tuple @@ -1873,33 +1394,33 @@ class AvgPool(_Pool): represent height and width of movement respectively. Default: 1. pad_mode (str): The optional value for pad mode, is "same" or "valid". Default: "valid". - + - same: Adopts the way of completion. The height and width of the output will be the same as the input. The total number of padding will be calculated in horizontal and vertical directions and evenly distributed to top and bottom, left and right if possible. Otherwise, the last extra padding will be done from the bottom and the right side. - + - valid: Adopts the way of discarding. The possible largest height and width of output will be returned without padding. Extra pixels will be discarded. data_format (str): The format of input and output data. It should be 'NHWC' or 'NCHW'. Default: 'NCHW'. - + Inputs: - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. - + Outputs: Tensor, with shape :math:`(N, C_{out}, H_{out}, W_{out})`. - + Raises: TypeError: If `kernel_size` or `strides` is neither int nor tuple. ValueError: If `pad_mode` is neither 'valid' nor 'same' with not case sensitive. ValueError: If `data_format` is neither 'NCHW' nor 'NHWC'. ValueError: If `kernel_size` or `strides` is less than 1. 
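# A minimal NumPy sketch (not from the MindSpore source) of the averaging rule above:
# each AvgPool output element is the mean of a k x k window. Assuming the example below
# uses a 2x2 kernel with stride 1 on the arange input, the top-left window of the last
# channel averages (24 + 25 + 28 + 29) / 4 = 26.5, matching the printed output.
import numpy as np

x = np.arange(1 * 3 * 3 * 4).reshape(1, 3, 3, 4).astype(np.float32)
print(x[0, 2, 0:2, 0:2].mean())   # 26.5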
ValueError: If length of shape of `x` is not equal to 4. - + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - + Examples: >>> class Net(nn.Cell): ... def __init__(self): @@ -1921,18 +1442,18 @@ class AvgPool(_Pool): [[26.5 27.5 28.5] [30.5 31.5 32.5]]]] """ - + @prim_attr_register def __init__(self, kernel_size=1, strides=1, pad_mode="valid", data_format="NCHW"): """Initialize AvgPool.""" super(AvgPool, self).__init__(kernel_size, strides, pad_mode, data_format) - - + + class Conv2DBackpropInput(Primitive): r""" The Conv2DBackpropInput interface is deprecated, please refer to :class:`mindspore.ops.Conv2DTranspose` if you want to do unsampling. - + Supported Platforms: Deprecated """ @@ -1941,7 +1462,7 @@ class Conv2DBackpropInput(Primitive): sig.make_sig('filter', dtype=sig.sig_dtype.T1), sig.make_sig('input_sizes', dtype=sig.sig_dtype.T2) ) - + @prim_attr_register def __init__(self, out_channel, @@ -1955,54 +1476,80 @@ class Conv2DBackpropInput(Primitive): group=1, data_format="NCHW"): """Initialize Conv2DBackpropInput""" + # 初始化Conv2DBackpropInput self.init_prim_io_names(inputs=['out_backprop', 'filter', 'input_sizes'], outputs=['output']) + # 初始化输入输出名称 self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name) + # 检查out_channel是否为正整数 self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) + # 检查kernel_size是否为正整数或元组 self.add_prim_attr('kernel_size', self.kernel_size) + # 添加kernel_size属性 self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name) + # 检查data_format是否为NCHW或NHWC if context.get_context("device_target") != "GPU" and self.format == "NHWC": raise ValueError(f"For '{self.name}', the 'NHWC' format is only supported in GPU target, " f"but got the 'data_format' is {self.format} and " f"the platform is {context.get_context('device_target')}.") + # 如果不是GPU平台,且data_format为NHWC,则抛出异常 self.add_prim_attr('data_format', self.format) + # 添加data_format属性 self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True) + # 检查stride是否为正整数或元组 self.stride = _update_attr_by_format(self.stride, self.format) + # 根据data_format更新stride属性 self.add_prim_attr('stride', self.stride) + # 添加stride属性 self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True) + # 检查dilation是否为正整数或元组 self.dilation = _update_attr_by_format(self.dilation, self.format) + # 根据data_format更新dilation属性 self.add_prim_attr('dilation', self.dilation) + # 添加dilation属性 validator.check_value_type('pad', pad, (int, tuple), self.name) + # 检查pad是否为整数或元组 validator.check_value_type('pad_mode', pad_mode, [str], self.name) + # 检查pad_mode是否为字符串 if isinstance(pad, int): pad = (pad,) * 4 else: validator.check_equal_int(len(pad), 4, 'pad size', self.name) + # 如果pad为整数,则将其转换为元组 self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name) + # 检查pad_mode是否为valid、same或pad if pad_mode != 'pad' and pad != (0, 0, 0, 0): raise ValueError(f"For '{self.name}', the 'pad' must be zero or (0, 0, 0, 0) when 'pad_mode' " f"is not \"pad\", but got 'pad' is {self.pad} and 'pad_mode' is {pad_mode}.") + # 如果pad_mode不是pad,且pad不是(0, 0, 0, 0),则抛出异常 self.add_prim_attr("pad", pad) + # 添加pad属性 self.padding = pad + # 将pad赋值给padding if self.pad_mode == 'pad': for item in pad: validator.check_non_negative_int(item, 'pad item', self.name) - + + # 如果pad_mode为pad,则检查pad中的每个元素是否为非负整数 pad_mode = pad_mode.upper() self.add_prim_attr('pad_mode', pad_mode) + # 
添加pad_mode属性 self.mode = validator.check_equal_int(mode, 1, 'mode', self.name) + # 检查mode是否为1 self.group = validator.check_positive_int(group, 'group', self.name) + # 检查group是否为正整数 self.add_prim_attr('groups', self.group) + # 添加group属性 if pad_list: for x in pad_list: validator.check_non_negative_int(x, 'element of pad_list', self.name) self.pad_list = pad_list - - + + class Conv2DTranspose(Conv2DBackpropInput): """ Compute a 2D transposed convolution, which is also known as a deconvolution (although it is not an actual deconvolution). - + Args: out_channel (int): The dimensionality of the output space. kernel_size (Union[int, tuple[int]]): The size of the convolution window. @@ -2017,9 +1564,9 @@ class Conv2DTranspose(Conv2DBackpropInput): dilation (Union[int. tuple[int]]): Specifies the dilation rate to be used for the dilated convolution. Default: 1. group (int): Splits input into groups. Default: 1. - data_format (str): The format of input and output data. It should be 'NHWC' or 'NCHW',\ + data_format (str): The format of input and output data. It should be 'NHWC' or 'NCHW', default is 'NCHW'. - + Inputs: - **dout** (Tensor) - the gradients with respect to the output of the convolution. The shape conforms to the default data_format :math:`(N, C_{out}, H_{out}, W_{out})`. @@ -2027,10 +1574,10 @@ class Conv2DTranspose(Conv2DBackpropInput): :math:`(C_{out}, C_{in}, K_1, K_2)`. - **input_size** (Tensor) - A tuple describes the shape of the input which conforms to the format :math:`(N, C_{in}, H_{in}, W_{in})`. - + Outputs: Tensor, the gradients with respect to the input of convolution. It has the same shape as the input. - + Raises: TypeError: If `kernel_size`, `stride`, `pad` or `dilation` is neither an int nor a tuple. TypeError: If `out_channel` or `group` is not an int. @@ -2039,10 +1586,10 @@ class Conv2DTranspose(Conv2DBackpropInput): ValueError: If `padding` is a tuple whose length is not equal to 4. ValueError: If `pad_mode` it not equal to 'pad' and `pad` is not equal to (0, 0, 0, 0). ValueError: If `data_format` is neither 'NCHW' nor 'NHWC'. 
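# A minimal sketch (not from the MindSpore source) of the shape arithmetic behind a
# stride-1, unpadded transposed convolution: the spatial size grows back by (kernel - 1),
# the inverse of the "valid" forward rule. `conv_transpose_out_len` is made up here.
def conv_transpose_out_len(size, kernel, stride=1):
    return (size - 1) * stride + kernel

print(conv_transpose_out_len(30, 3))   # 32: a 30x30 gradient maps back to the 32x32 input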
- + Supported Platforms: ``Ascend`` ``GPU`` ``CPU`` - + Examples: >>> dout = Tensor(np.ones([10, 32, 30, 30]), mindspore.float32) >>> weight = Tensor(np.ones([32, 32, 3, 3]), mindspore.float32) @@ -2052,13 +1599,13 @@ class Conv2DTranspose(Conv2DBackpropInput): >>> print(output.shape) (10, 32, 32, 32) """ - + @prim_attr_register def __init__(self, out_channel, kernel_size, pad_mode="valid", pad=0, pad_list=None, mode=1, stride=1, dilation=1, group=1, data_format="NCHW"): """Initialize Conv2DTranspose.""" super(Conv2DTranspose, self).__init__(out_channel, kernel_size, pad_mode, pad, - pad_list, mode, stride, dilation, group, data_format) + pad_list, mode, stride, dilation, group, data_format)t) class BiasAdd(Primitive): diff --git a/src/mindspore2022/mindspore/python/mindspore/train/callback/_checkpoint.py b/src/mindspore2022/mindspore/python/mindspore/train/callback/_checkpoint.py index 87661110..1158c783 100644 --- a/src/mindspore2022/mindspore/python/mindspore/train/callback/_checkpoint.py +++ b/src/mindspore2022/mindspore/python/mindspore/train/callback/_checkpoint.py @@ -366,31 +366,46 @@ class ModelCheckpoint(Callback): """ def __init__(self, prefix='CKP', directory=None, config=None): + # 初始化函数,设置前缀、目录、配置等参数 super(ModelCheckpoint, self).__init__() + # 调用父类的初始化函数 self._latest_ckpt_file_name = "" + # 初始化最新检查点文件名为空字符串 self._init_time = time.time() + # 初始化初始化时间为当前时间 self._last_time = time.time() + # 初始化最后时间时间为当前时间 self._last_time_for_keep = time.time() + # 初始化最后保存时间为当前时间 self._last_triggered_step = 0 + # 初始化最后触发的步数为0 + # 检查前缀是否为字符串且不包含'/' if not isinstance(prefix, str) or prefix.find('/') >= 0: raise ValueError("For 'ModelCheckpoint', the argument 'prefix' " "for checkpoint file name is invalid, it must be " "string and does not contain '/', but got {}.".format(prefix)) self._prefix = prefix + # 设置前缀 self._exception_prefix = prefix + # 设置异常前缀 + # 如果目录不为空,则创建目录 if directory is not None: self._directory = _make_directory(directory) else: self._directory = _cur_dir + # 否则,使用当前目录 + # 如果启用了恢复上下文,则设置检查点路径 if _get_recovery_context("enable_recovery"): _set_recovery_context(ckpt_path=self._directory) + # 如果config为None,则使用默认的CheckpointConfig if config is None: self._config = CheckpointConfig() else: + # 如果config不是CheckpointConfig类型,则抛出TypeError异常 if not isinstance(config, CheckpointConfig): raise TypeError("For 'ModelCheckpoint', the type of argument 'config' should be " "'CheckpointConfig', " @@ -398,11 +413,17 @@ class ModelCheckpoint(Callback): self._config = config # get existing checkpoint files + # 创建CheckpointManager对象 self._manager = CheckpointManager() + # 如果存在相同名称的文件,则更改文件名 self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix) + # 获取配置中的append_dict参数,如果没有则设置为空字典 self._append_dict = self._config.append_dict or {} + # 获取append_dict中的epoch_num参数,如果没有则设置为0 self._append_epoch_num = self._append_dict["epoch_num"] if "epoch_num" in self._append_dict else 0 + # 获取append_dict中的step_num参数,如果没有则设置为0 self._append_step_num = self._append_dict["step_num"] if "step_num" in self._append_dict else 0 + # 标记是否已经保存了图 self._graph_saved = False self._need_flush_from_cache = True @@ -413,6 +434,7 @@ class ModelCheckpoint(Callback): Args: run_context (RunContext): Context of the train running. 
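# A minimal usage sketch (not part of this patch) of the callback implemented here; the
# argument values are illustrative only, and the resulting object would be passed to
# Model.train(..., callbacks=[ckpt_cb]).
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

config = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=5)
ckpt_cb = ModelCheckpoint(prefix="CKP", directory="./checkpoints", config=config)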
""" + # If the role is PServer, add the role name and rank to the prefix if _is_role_pserver(): self._prefix = "PServer_" + str(_get_ps_mode_rank()) + "_" + self._prefix cb_params = run_context.original_args() @@ -423,18 +445,23 @@ class ModelCheckpoint(Callback): self._last_triggered_step = cb_params.last_save_ckpt_step cb_params.last_save_ckpt_step = None + # Create the directory if it doesn't exist _make_directory(self._directory) # save graph (only once) if not self._graph_saved: graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta') + # If the graph file already exists and the mode is GRAPH_MODE, remove it if os.path.isfile(graph_file_name) and context.get_context("mode") == context.GRAPH_MODE: os.remove(graph_file_name) + # Save the graph _save_graph(cb_params.train_network, graph_file_name) self._graph_saved = True + # Wait for any asynchronous checkpoint saving threads to finish thread_list = threading.enumerate() for thread in thread_list: if thread.getName() == "asyn_save_ckpt": thread.join() + # Save the checkpoint self._save_ckpt(cb_params) def end(self, run_context): @@ -444,44 +471,63 @@ class ModelCheckpoint(Callback): Args: run_context (RunContext): Context of the train running. """ + # 获取训练的参数 cb_params = run_context.original_args() + # 设置保存最后一个checkpoint的标志为True _to_save_last_ckpt = True + # 保存最后一个checkpoint self._save_ckpt(cb_params, _to_save_last_ckpt) + # 获取当前线程列表 thread_list = threading.enumerate() + # 遍历线程列表 for thread in thread_list: + # 如果线程名为"asyn_save_ckpt",则等待该线程结束 if thread.getName() == "asyn_save_ckpt": thread.join() + # 销毁所有gather cell destroy_allgather_cell() def _check_save_ckpt(self, cb_params, force_to_save): """Check whether save checkpoint files or not.""" + # 如果配置了保存检查点步数且步数大于0 if self._config.save_checkpoint_steps and self._config.save_checkpoint_steps > 0: + # 如果当前步数大于等于上次触发保存检查点步数加上保存检查点步数,或者强制保存检查点 if cb_params.cur_step_num >= self._last_triggered_step + self._config.save_checkpoint_steps \ or force_to_save is True: return True + # 如果配置了保存检查点秒数且秒数大于0 elif self._config.save_checkpoint_seconds and self._config.save_checkpoint_seconds > 0: + # 获取当前时间 self._cur_time = time.time() + # 如果当前时间减去上次时间大于保存检查点秒数,或者强制保存检查点 if (self._cur_time - self._last_time) > self._config.save_checkpoint_seconds or force_to_save: + # 更新上次时间 self._last_time = self._cur_time return True + # 返回False return False def _save_ckpt(self, cb_params, force_to_save=False): """Save checkpoint files.""" + # 如果当前步骤数等于最后触发的步骤数,则返回 if cb_params.cur_step_num == self._last_triggered_step: return # if param is cache enable, flush data from cache to host before save_ckpt + # 如果需要从缓存中刷新数据,则调用_flush_from_cache方法 if self._need_flush_from_cache: self._flush_from_cache(cb_params) + # 检查是否需要保存检查点,如果force_to_save为True,则强制保存 save_ckpt = self._check_save_ckpt(cb_params, force_to_save) + # 计算当前步数在epoch中的位置 step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1) + # 如果需要保存检查点,则创建当前检查点的文件名 if save_ckpt: cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \ + str(step_num_in_epoch) + ".ckpt" @@ -489,43 +535,68 @@ class ModelCheckpoint(Callback): self._manager.update_ckpoint_filelist(self._directory, self._prefix) # keep checkpoint files number equal max number. 
            if self._config.keep_checkpoint_max and 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num:
+                # If keep_checkpoint_max is set and the number of checkpoint files has reached it, remove the oldest file
                self._manager.remove_oldest_ckpoint_file()
            elif self._config.keep_checkpoint_per_n_minutes and self._config.keep_checkpoint_per_n_minutes > 0:
+                # If keep_checkpoint_per_n_minutes is set and greater than 0, record the current time
                self._cur_time_for_keep = time.time()
+                # If less than keep_checkpoint_per_n_minutes * 60 seconds have passed since the last record, keep only one checkpoint per minute
                if (self._cur_time_for_keep - self._last_time_for_keep) \
                        < self._config.keep_checkpoint_per_n_minutes * 60:
                    self._manager.keep_one_ckpoint_per_minutes(self._config.keep_checkpoint_per_n_minutes,
                                                               self._cur_time_for_keep)

            # generate the new checkpoint file and rename it.
+            # Declare the global variable _save_dir and assign self._directory to it
            global _save_dir
            _save_dir = self._directory
+            # Build the path of the current checkpoint file
            cur_file = os.path.join(self._directory, cur_ckpoint_file)
+            # Record the current time
            self._last_time_for_keep = time.time()
+            # Record the currently triggered step
            self._last_triggered_step = cb_params.cur_step_num

+            # If GE (Graph Engine) is enabled
            if context.get_context("enable_ge"):
+                # Set the current network
                set_cur_net(cb_params.train_network)
+                # Execute the checkpoint graph
                cb_params.train_network.exec_checkpoint_graph()
+            # If _append_dict contains "epoch_num"
            if "epoch_num" in self._append_dict:
+                # Store _append_epoch_num plus the current epoch number in "epoch_num"
                self._append_dict["epoch_num"] = self._append_epoch_num + cb_params.cur_epoch_num
+            # If _append_dict contains "step_num"
            if "step_num" in self._append_dict:
+                # Store _append_step_num plus the current step number in "step_num"
                self._append_dict["step_num"] = self._append_step_num + cb_params.cur_step_num
+            # Pick the network to save: use self._config.saved_network if it is not None, otherwise use cb_params.train_network
            network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network
+            # Save the checkpoint
            save_checkpoint(network, cur_file, self._config.integrated_save, self._config.async_save,
                            self._append_dict, self._config.enc_key, self._config.enc_mode)
+            # Record the latest checkpoint file name
            self._latest_ckpt_file_name = cur_file

    def _flush_from_cache(self, cb_params):
        """Flush cache data to host if tensor is cache enable."""
+        # Initialize has_cache_params to False
        has_cache_params = False
+        # Get the parameters of the training network
        params = cb_params.train_network.get_parameters()
+        # Iterate over the parameters
        for param in params:
+            # If the parameter's cache_enable is True
            if param.cache_enable:
+                # Set has_cache_params to True
                has_cache_params = True
+                # Flush the parameter's Tensor data from the cache to the host
                Tensor(param).flush_from_cache()
+        # If no parameter has cache_enable set to True
        if not has_cache_params:
+            # Set _need_flush_from_cache to False
            self._need_flush_from_cache = False

    @property
@@ -535,63 +606,88 @@ class ModelCheckpoint(Callback):

class CheckpointManager:
-    """Manage checkpoint files according to train_config of checkpoint."""
+    """Manage checkpoint files according to the training config."""

    def __init__(self):
+        """Initialize the checkpoint manager with an empty checkpoint file list."""
        self._ckpoint_filelist = []

    @property
    def ckpoint_filelist(self):
-        """Get all the related checkpoint files managed here."""
+        """Get all the checkpoint files managed here."""
        return self._ckpoint_filelist

    @property
    def ckpoint_num(self):
-        """Get the number of the related checkpoint files managed here."""
+        """Get the number of checkpoint files managed here."""
        return len(self._ckpoint_filelist)

    def update_ckpoint_filelist(self, directory, prefix):
-        """Update the checkpoint file list."""
+        """Update the checkpoint file list, filtering files in the directory by the given prefix."""
+        # Initialize an empty list for the ckpt files
        self._ckpoint_filelist = []
+        # List all files in the given directory
        files = os.listdir(directory)
+        # Iterate over all files
        for filename in files:
+            # Check whether the file starts with the given prefix and ends with .ckpt
            if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix + "-"):
+                # Extract the middle part of the file name
                mid_name = filename[len(prefix):-5]
+                # Check whether the middle part contains any letters
                flag = not (True in [char.isalpha() for char in mid_name])
+                # If it contains no letters, add the file path to the list
                if flag:
                    self._ckpoint_filelist.append(os.path.join(directory, filename))

    def remove_ckpoint_file(self, file_name):
-        """Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
+        """Remove the specified checkpoint file from this manager and delete it from the directory."""
        try:
+            # Make the file writable
            os.chmod(file_name, stat.S_IWRITE)
+            # Remove the file
            os.remove(file_name)
+            # Remove the file from the checkpoint file list
            self._ckpoint_filelist.remove(file_name)
        except OSError:
+            # Catch OSError and log a warning
            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
        except ValueError:
+            # Catch ValueError and log a warning
            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)

    def remove_oldest_ckpoint_file(self):
-        """Remove the oldest checkpoint file from this checkpoint manager and also from the directory."""
+        """Remove the oldest checkpoint file from this manager and delete it from the directory."""
+        # Sort all checkpoint files by modification time
        ckpoint_files = sorted(self._ckpoint_filelist, key=os.path.getmtime)
+        # Remove the checkpoint file with the earliest modification time
        self.remove_ckpoint_file(ckpoint_files[0])

    def keep_one_ckpoint_per_minutes(self, minutes, cur_time):
-        """Only keep the latest one ckpt file per minutes, remove other files generated in [last_time, cur_time]."""
+        """Keep only the latest checkpoint file per minute, removing other files generated in [last_time, cur_time]."""
+        # List of files to be deleted
        del_list = []
+        # Name of the oldest file found so far
        oldest_file = ''
+        # Start from the current time when looking for the oldest file
        oldest_time = cur_time
+        # Iterate over the files in _ckpoint_filelist
        for ck_file in self._ckpoint_filelist:
+            # Get the file's modification time
            modify_time = os.path.getmtime(ck_file)
+            # If the file was modified less than 60 * minutes seconds ago, add it to del_list
            if cur_time - modify_time < 60 * minutes:
                del_list.append(ck_file)

+                # If the file is older than oldest_time, update oldest_time and oldest_file
                if modify_time < oldest_time:
                    oldest_time = modify_time
                    oldest_file = ck_file
+        # Iterate over the files in del_list
        for mv_file in del_list:
+            # Skip the oldest file
            if mv_file == oldest_file:
                continue
-            self.remove_ckpoint_file(mv_file)
+            # Remove the file via remove_ckpoint_file
+            self.remove_ckpoint_file(mv_file)
\ No newline at end of file
diff --git a/src/mindspore2022/mindspore/python/mindspore/train/dataset_helper.py b/src/mindspore2022/mindspore/python/mindspore/train/dataset_helper.py
index 7dc22fe6..c026669c 100644
--- a/src/mindspore2022/mindspore/python/mindspore/train/dataset_helper.py
+++ b/src/mindspore2022/mindspore/python/mindspore/train/dataset_helper.py
@@ -256,36 +256,53 @@ class DatasetHelper:
    """

    def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1):
+        # Check that dataset_sink_mode is a bool
        dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
+        # Check that sink_size is an integer
        Validator.check_is_int(sink_size)
+        # If sink_size is less than -1 or equal to 0, raise an exception
        if sink_size < -1 or sink_size == 0:
            raise ValueError("The 'sink_size' must be -1 or positive, but got sink_size {}.".format(sink_size))
+        # If sink_size is -1, use the dataset's full size
        if sink_size == -1:
            sink_size = dataset.get_dataset_size()

+        # In sink mode, pick an iterator class according to the device type
        if dataset_sink_mode:
+            # If GE is enabled, use the GE iterator
            if context.get_context("enable_ge"):
                iterclass = _DatasetIterGE
            else:
+                # In GRAPH_MODE, pick the iterator according to the role
                if context.get_context("mode") == context.GRAPH_MODE:
+                    # For the scheduler or parameter-server role, use the PS server iterator
                    if _is_role_sched() or _is_role_pserver():
                        iterclass = _DatasetIterPSServer
+                    # For a worker role in parameter-server mode, use the PS worker iterator
                    elif _is_role_worker() and _is_ps_mode():
                        iterclass = _DatasetIterPSWork
+                    # For the Ascend or GPU device target, use the loop-sink iterator
                    elif (context.get_context("device_target") == "Ascend") or \
(context.get_context("device_target") == "Ascend") or \ (context.get_context("device_target") == "GPU"): iterclass = _DatasetIterMSLoopSink + # 如果当前设备类型为CPU,则抛出异常,因为CPU不支持数据集下沉模式 elif context.get_context("device_target") == "CPU": raise RuntimeError("Currently dataset sink mode is not supported when the device " "target is CPU, please set dataset sink mode to False.") + # 如果当前模式不是GRAPH_MODE,则使用PyNative的迭代器 else: iterclass = _DatasetIterPyNative + # 创建迭代器 self.iter = iterclass(dataset, sink_size, epoch_num) + # 如果dataset_sink_mode为False,则使用普通的迭代器 else: + # 如果不是分布式训练,则使用_DatasetIterNormal类 iterclass = _DatasetIterNormal + # 初始化迭代器 self.iter = iterclass(dataset, epoch_num=epoch_num) def __iter__(self): + # 返回self.iter的迭代器 return self.iter.__iter__() # A temp solution for loop sink. Delete later @@ -301,6 +318,7 @@ class DatasetHelper: >>> >>> types, shapes = dataset_helper.types_shapes() """ + # 从当前配置的dataset中获取类型和形状 return self.iter.types_shapes() def sink_size(self): @@ -316,18 +334,22 @@ class DatasetHelper: >>> # if sink_size==-1, then will return the full size of source dataset. >>> sink_size = dataset_helper.sink_size() """ + # 返回迭代器的接收缓冲区大小 return self.iter.get_sink_size() def stop_send(self): """Stop send data about data sink.""" + # 停止发送关于数据接收器的数据 self.iter.stop_send() def release(self): """Free up resources about data sink.""" + # 释放数据接收器的资源 self.iter.release() def continue_send(self): """Continue to send data to device at the beginning of epoch.""" + # 在每个epoch的开始处继续向设备发送数据 self.iter.continue_send() def _reset(self, step): @@ -339,6 +361,7 @@ class DatasetHelper: In sink mode, it returns the types and shapes of the current data. Generally, it works in dynamic shape scenarios. """ + # 返回迭代器的数据信息 return self.iter.get_data_info() def dynamic_min_max_shapes(self): @@ -355,6 +378,7 @@ class DatasetHelper: >>> >>> min_shapes, max_shapes = dataset_helper.dynamic_min_max_shapes() """ + # 返回self.iter的dynamic_min_max_shapes方法 return self.iter.dynamic_min_max_shapes() @@ -362,20 +386,27 @@ class _DatasetIter: """Base iter for dataset helper""" def __init__(self, dataset, sink_size, epoch_num): + # 初始化函数,传入数据集、sink大小和epoch数量 self.dataset = dataset self.sink_size = sink_size self.sink_count = self.get_sink_count(dataset) + # 如果数据集没有__transfer_dataset__属性 if not hasattr(dataset, '__transfer_dataset__'): + # 如果数据集有__loop_size__属性 if hasattr(dataset, '__loop_size__'): # PS mode does not support loop sink and need get the real sink size. 
+                # If not a worker role or not in PS mode, set sink_size to the dataset's loop size
                if not (_is_role_worker() and _is_ps_mode()):
                    self.sink_size = dataset.__loop_size__
+            # Create the data info queue when sink_size is 1, sink_count is 1, the dataset size is not 1 and the device target is Ascend
            create_data_info_queue = (sink_size == 1 and self.sink_count == 1 and dataset.get_dataset_size() != 1
                                      and context.get_context("device_target") == "Ascend")
+            # Execute the data graph, passing sink_size and create_data_info_queue
            dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size,
                                                           create_data_info_queue=create_data_info_queue)

+            # If the dataset has no __no_send__ attribute, send the data
            if not hasattr(dataset, '__no_send__'):
                _send_data(dataset, epoch_num)
        else:
@@ -384,33 +415,48 @@ class _DatasetIter:
            _cell_graph_executor.set_queue_name(dataset.__transfer_dataset__.queue_name)
            _send_data_no_flag(dataset, epoch_num)

+        # Get the dataset's stop_send method
        self.stop_send = dataset.__transfer_dataset__.stop_send
+        # Get the dataset's release method
        self.release = dataset.__transfer_dataset__.release
+        # Get the dataset's continue_send method
        self.continue_send = dataset.__transfer_dataset__.continue_send
+        # Get the dataset's get_data_info method
        self.get_data_info = dataset.__transfer_dataset__.get_data_info
+        # Get the dataset's dynamic_min_max_shapes attribute
        self.dynamic_min_max_shapes = dataset.dynamic_min_max_shapes
+        # Get the dataset's data types and shapes
        self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)
+        # If __transfer_dataset__ has a _reset method, keep a reference to it
        if hasattr(dataset.__transfer_dataset__, "_reset"):
            self._reset = dataset.__transfer_dataset__._reset  # pylint: disable=W0212

    def __iter__(self):
+        # Initialize the index to 0
        self.index = 0
+        # Return self
        return self
-
+    # Return the next item of the iterator
    def __next__(self):
+        # If the index reaches sink_count, raise StopIteration
        if self.index >= self.sink_count:
            raise StopIteration()
+        # Increment the index
        self.index += 1
+        # Return the result of op()
        return self.op()

    def types_shapes(self):
        """
-        Return the types and shapes of the dataset. The type and shape of each data in the dataset
-        should be consistent.
+        Return the types and shapes of the dataset. The type and shape of each data item in the
+        dataset should be consistent.
        """
        return self.dataset_types, self.dataset_shapes

    def get_sink_count(self, dataset):
+        """
+        Get the sink count of the dataset.
+
+        :param dataset: the dataset object
+        :return: the sink count
+        """
        sink_count = 1
        if hasattr(dataset, '__loop_size__'):
            loop_size = dataset.__loop_size__
@@ -421,7 +467,10 @@ class _DatasetIter:
        return sink_count

    def get_sink_size(self):
-        """get sink_size to device"""
+        """
+        Get the sink size for the device.
+
+        :return: the sink size
+        """
        sink_size = 1
        if hasattr(self.dataset, '__loop_size__'):
            sink_size = self.dataset.__loop_size__
diff --git a/src/mindspore2022/setup.py b/src/mindspore2022/setup.py
index 63d69b00..26b61f63 100644
--- a/src/mindspore2022/setup.py
+++ b/src/mindspore2022/setup.py
@@ -23,47 +23,59 @@ from setuptools import setup, find_packages
from setuptools.command.egg_info import egg_info
from setuptools.command.build_py import build_py

+# Read the environment variables
backend_policy = os.getenv('BACKEND_POLICY')
device_target = os.getenv('BACKEND_TARGET')
commit_id = os.getenv('COMMIT_ID').replace("\n", "")
package_name = os.getenv('MS_PACKAGE_NAME').replace("\n", "")
build_path = os.getenv('BUILD_PATH')
+# Get the path of the current file
pwd = os.path.dirname(os.path.realpath(__file__))
+# Get the package directory path
pkg_dir = os.path.join(build_path, 'package')


def _read_file(filename):
+    """Read the content of a file."""
    with open(os.path.join(pwd, filename), encoding='UTF-8') as f:
        return f.read()


+# Read the version number
version = _read_file('version.txt').replace("\n", "")
+# Read the content of README.md
readme = _read_file('README.md')


def _write_version(file):
+    """Write the version number."""
    file.write("__version__ = '{}'\n".format(version))


def _write_config(file):
+    """Write the backend policy."""
    file.write("__backend__ = '{}'\n".format(backend_policy))


def _write_commit_file(file):
+    """Write the commit id."""
    file.write("__commit_id__ = '{}'\n".format(commit_id))


def _write_package_name(file):
+    """Write the package name."""
    file.write("__package_name__ = '{}'\n".format(package_name))


def _write_device_target(file):
+    """Write the device target."""
    file.write("__device_target__ = '{}'\n".format(device_target))


def build_dependencies():
    """generate python file"""
+    # Generate the version.py file
    version_file = os.path.join(pkg_dir, 'mindspore', 'version.py')
    with open(version_file, 'w') as f:
        _write_version(f)

@@ -72,6 +84,7 @@ def build_dependencies():
    with open(version_file, 'w') as f:
        _write_version(f)

+    # Generate the default_config.py file
    config_file = os.path.join(pkg_dir, 'mindspore', 'default_config.py')
    with open(config_file, 'w') as f:
        _write_config(f)

@@ -80,6 +93,7 @@ def build_dependencies():
    with open(config_file, 'w') as f:
        _write_config(f)

+    # Append the device_target to default_config.py
    target = os.path.join(pkg_dir, 'mindspore', 'default_config.py')
    with open(target, 'a') as f:
        _write_device_target(f)

@@ -88,6 +102,7 @@ def build_dependencies():
    with open(target, 'a') as f:
        _write_device_target(f)

+    # Append the package_name to default_config.py
    package_info = os.path.join(pkg_dir, 'mindspore', 'default_config.py')
    with open(package_info, 'a') as f:
        _write_package_name(f)

@@ -96,6 +111,7 @@ def build_dependencies():
    with open(package_info, 'a') as f:
        _write_package_name(f)

+    # Generate the .commit_id file
    commit_file = os.path.join(pkg_dir, 'mindspore', '.commit_id')
    with open(commit_file, 'w') as f:
        _write_commit_file(f)

@@ -145,16 +161,24 @@ def update_permissions(path):

    Args:
        path (str): Target directory path.
""" + # 判断操作系统是否为Windows if platform.system() == "Windows": return + # 遍历目标目录下的所有文件和文件夹 for dirpath, dirnames, filenames in os.walk(path): + # 遍历文件夹 for dirname in dirnames: + # 获取文件夹的完整路径 dir_fullpath = os.path.join(dirpath, dirname) + # 更新文件夹的权限 os.chmod(dir_fullpath, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC | stat.S_IRGRP | stat.S_IXGRP) + # 遍历文件 for filename in filenames: + # 获取文件的完整路径 file_fullpath = os.path.join(dirpath, filename) + # 更新文件的权限 os.chmod(file_fullpath, stat.S_IREAD) @@ -163,7 +187,9 @@ class EggInfo(egg_info): def run(self): super().run() + # 获取egg-info目录的路径 egg_info_dir = os.path.join(pkg_dir, 'mindspore.egg-info') + # 更新egg-info目录的权限 update_permissions(egg_info_dir) @@ -172,41 +198,64 @@ class BuildPy(build_py): def run(self): super().run() + # 获取build目录下的lib/mindspore目录的路径 mindspore_dir = os.path.join(pkg_dir, 'build', 'lib', 'mindspore') + # 更新lib/mindspore目录的权限 update_permissions(mindspore_dir) + # 获取build目录下的lib/mindspore/_akg目录的路径 mindspore_dir = os.path.join(pkg_dir, 'build', 'lib', 'mindspore', '_akg') + # 更新lib/mindspore/_akg目录的权限 update_permissions(mindspore_dir) +# 设置包的名称 setup( name=package_name, + # 设置包的版本 version=version, + # 设置包的作者 author='The MindSpore Authors', + # 设置包的作者邮箱 author_email='contact@mindspore.cn', + # 设置包的网址 url='https://www.mindspore.cn', + # 设置包的下载网址 download_url='https://github.com/mindspore-ai/mindspore/tags', + # 设置包的源代码网址 project_urls={ 'Sources': 'https://github.com/mindspore-ai/mindspore', + # 设置包的问题追踪网址 'Issue Tracker': 'https://github.com/mindspore-ai/mindspore/issues', }, + # 设置包的描述 description='MindSpore is a new open source deep learning training/inference ' 'framework that could be used for mobile, edge and cloud scenarios.', + # 读取readme文件作为包的详细描述 long_description=readme, + # 设置详细描述的格式 long_description_content_type="text/markdown", + # 查找包中的所有模块 packages=find_packages(), + # 设置包的数据 package_data=package_data, + # 包含包中的所有数据 include_package_data=True, + # 设置自定义的命令类 cmdclass={ 'egg_info': EggInfo, 'build_py': BuildPy, }, + # 设置包的入口点 entry_points={ 'console_scripts': [ 'cache_admin=mindspore.dataset.engine.cache_admin:main', ], }, + # 设置包的Python版本要求 python_requires='>=3.7', + # 设置包的依赖 install_requires=required_package, + # 设置包的分类器 classifiers=[ 'Development Status :: 4 - Beta', 'Environment :: Console', @@ -223,6 +272,8 @@ setup( 'Topic :: Software Development :: Libraries', 'Topic :: Software Development :: Libraries :: Python Modules', ], + # 设置包的许可证 license='Apache 2.0', + # 设置包的关键词 keywords='mindspore machine learning', ) diff --git a/src/mindspore2022/tests/mindspore_test_framework/__init__.py b/src/mindspore2022/tests/mindspore_test_framework/__init__.py index 3a970347..5ce38111 100644 --- a/src/mindspore2022/tests/mindspore_test_framework/__init__.py +++ b/src/mindspore2022/tests/mindspore_test_framework/__init__.py @@ -17,5 +17,7 @@ import mindspore.context as context def setup_module(module): + # 禁用pylint对未使用参数的警告 # pylint: disable=unused-argument + # 设置上下文模式为图模式 context.set_context(mode=context.GRAPH_MODE) diff --git a/src/mindspore2022/tests/mindspore_test_framework/mindspore_test.py b/src/mindspore2022/tests/mindspore_test_framework/mindspore_test.py index 5ec5d8b6..33017044 100644 --- a/src/mindspore2022/tests/mindspore_test_framework/mindspore_test.py +++ b/src/mindspore2022/tests/mindspore_test_framework/mindspore_test.py @@ -26,25 +26,28 @@ from .utils import keyword def mindspore_test(verification_pipeline): """ - Run verification pipeline. 
+    Run the verification pipeline.

    Args:
-        verification_pipeline (list): Pipeline designed to do verification.
+        verification_pipeline (list): Pipeline designed to do verification.

    Returns:
    """

    def decorate(get_verification_set):
+        # Get the verification set
        verification_set = get_verification_set()

-        facade_components = []
-        data_components = []
-        builder_components = []
-        executor_components = []
-        verifier_components = []
-        fi_policy_components = []
-        er_policy_components = []
+        # Initialize the component lists
+        facade_components = []  # facade components
+        data_components = []  # data components
+        builder_components = []  # builder components
+        executor_components = []  # executor components
+        verifier_components = []  # verifier components
+        fi_policy_components = []  # function-inputs policy components
+        er_policy_components = []  # expect-result policy components
        for component in verification_pipeline:
+            # Dispatch the component to the matching list by type
            if issubclass(component, IFacadeComponent):
                facade_components.append(component)
            elif issubclass(component, IDataComponent):
@@ -62,68 +65,90 @@ def mindspore_test(verification_pipeline):
            else:
                raise Exception(f'{component} is not an instance of {IComponent}')

+        # Process the facade components in order
        for component in facade_components:
            fc = component(verification_set)
            verification_set = fc()

+        # Initialize the input list
        inputs = []
+        # Process the data components in order
        for component in data_components:
            dc = component(verification_set)
            item = dc()
            inputs.extend(item)

+        # Log a warning if the input set is empty
        if not inputs:
            logging.warning("Inputs set is empty.")

+        # Initialize the function list
        functions = []
+        # Process the builder components in order
        for component in builder_components:
            bc = component(verification_set)
            f = bc()
            functions.extend(f)

+        # Log a warning if the function set is empty
        if not functions:
            logging.warning("Function set is empty.")

+        # Initialize the function-inputs pair list
        fis = []
+        # Process the FI policy components in order
        for component in fi_policy_components:
            fipc = component(verification_set, functions, inputs)
            result = fipc()
            fis.extend(result)

+        # Log a warning if the function-inputs pair set is empty
        if not fis:
            logging.warning("Function inputs pair set is empty.")

+        # Define the test case function
        def test_case(args):
+            # Unpack the system under test and its inputs
            sut, inputs = args
+            # Initialize the result list
            results = []

+            # Process the executor components in order
            for component in executor_components:
                ec = component(verification_set, sut, inputs)
                result = ec()
                results.append(result)

+            # Log a warning if the result set is empty
            if not results:
                logging.warning("Result set is empty.")

+            # Initialize the expected/actual result pair list
            expect_actuals = []
+            # Process the ER policy components in order
            for component in er_policy_components:
                erpc = component(verification_set, verification_set['expect'], results)
                result = erpc()
                expect_actuals.extend(result)

+            # Log a warning if the expected/actual pair set is empty
            if not expect_actuals:
                logging.warning("Expect Result pair set is empty.")

+            # Process the verifier components in order
            for ea in expect_actuals:
                for component in verifier_components:
                    vc = component(verification_set, *ea)
                    vc()

+        # Define the test case name generator
        def get_tc_name(f, inputs):
+            # Join the test case id and group name
            tc_id = f[keyword.id] + '-' + inputs[keyword.id]
            group = f[keyword.group] + '-' + inputs[keyword.group]
            return 'Group_' + group + '-' + 'Id_' + tc_id

+        # If function-inputs pairs exist, generate the test cases
        if fis:
            m = pytest.mark.parametrize('args', fis, ids=lambda fi: get_tc_name(*fi))(test_case)
            m.__orig__ = get_verification_set
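+            # Note (an assumption about how pytest consumes this): `test_case` is collected through
+            # the parametrize mark above, and each generated test id has the form
+            #     'Group_<function group>-<inputs group>-Id_<function id>-<inputs id>'
+            # as built by get_tc_name.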