From 1531f3358222ffcb5276d6c436b8cc84d49bea19 Mon Sep 17 00:00:00 2001 From: yixin <2050485123@qq.com> Date: Wed, 25 Dec 2024 09:27:11 +0800 Subject: [PATCH] add comments for _extends\graph_kernel\expanders\_utils.py --- .../_extends/graph_kernel/expanders/_utils.py | 297 ++++++++++++++++++ 1 file changed, 297 insertions(+) diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/_utils.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/_utils.py index 1bdb205d..9fb8a071 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/_utils.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/expanders/_utils.py @@ -27,6 +27,19 @@ class Expander: __metaclass__ = ABCMeta def __init__(self, expand_info): + """ + 初始化方法。 + + Args: + expand_info (dict): 包含模型信息的字典,包括模型名称、输入描述、输出描述、属性、处理函数等。 + + Attributes: + name (str): 模型名称。 + inputs (list): 输入描述列表。 + outputs (list): 输出描述列表。 + attrs (dict): 模型属性字典。 + processor (callable): 处理函数。 + """ self.name = expand_info["name"] self.inputs = expand_info["input_desc"] self.outputs = expand_info["output_desc"] @@ -34,6 +47,19 @@ class Expander: self.processor = expand_info["process"] def run(self): + """ + 将操作扩展为图。 + + Args: + 无 + + Returns: + 返回扩展后的图对象。 + + Raises: + GraphKernelUnsupportedException: 如果检查失败,则引发此异常。 + + """ """ Expand the operator to a graph. @@ -58,9 +84,31 @@ class Expander: return graph def _check(self): + """ + 检查输入。 + + Args: + 无 + + Returns: + 无 + + Raises: + ValueError: 如果输入不符合要求,则引发此异常。 + + """ """Check inputs""" def _check_output_same(self, outputs): + """ + 检查输出是否与预期一致。 + + Args: + outputs (list): 实际输出值的列表。 + + Raises: + GKException: 如果实际输出值与预期不一致,则抛出异常。 + """ for index, value in enumerate(self.outputs): if list(outputs[index].shape) != list(value['shape']): raise GKException("{} 's output shape {} is wrong. Expected:{}".format( @@ -74,6 +122,18 @@ class Expander: @abstractmethod def _expand(self, graph_builder): + """ + Expand 操作符,此函数应在子类中重写。 + + Args: + graph_builder (GraphBuilder): 图构建器对象。 + + Raises: + Exception: 如果子类未重写此方法,则抛出异常,提示 "_expand() is not implemented in {}". + + Returns: + None + """ """Expand operator, this function should be overridden in subclass""" raise Exception("_expand() is not implemented in {}".format(self.__class__.__name__)) @@ -82,10 +142,34 @@ class ExpanderInfoValidator: """ExpanderInfoValidator is the utility class which defines the validator decorator for expanders""" def __init__(self): + """ + 初始化方法。 + + Args: + 无 + + Returns: + 无 + + """ """Init""" @staticmethod def _add_check_function(kls, func): + """ + 向类 Expander 中的 `_check` 函数添加新的检查函数 `func`。 + + Args: + kls (type): 需要被修改的类对象,应继承自 Expander 类。 + func (callable): 要添加的新检查函数,该函数接受一个对象作为参数。 + + Returns: + None + + Raises: + AttributeError: 如果 kls 类中不存在 `_check` 方法。 + + """ """ Rewrite the function `_check` in class Expander to append the new `func` after the original checks. @@ -93,6 +177,21 @@ class ExpanderInfoValidator: old_check = getattr(kls, "_check") def new_check(obj): + """ + 执行新的检查函数。 + + Args: + obj (Any): 需要检查的对象。 + + Returns: + None + + Raises: + None + + 这个函数首先调用旧版本的检查函数 `old_check` 对传入的对象 `obj` 进行检查, + 然后调用自定义的函数 `func` 对该对象进行处理。 + """ old_check(obj) func(obj) @@ -103,6 +202,34 @@ class ExpanderInfoValidator: """ Add new supported format for the operator + Args: + *input_format: A variable number of arguments representing the new supported formats. + + Returns: + A wrapper function that adds the specified formats to the operator's supported formats list. + + Raises: + GKException: Raised if the length of the registered format list does not match the length of the input formats, + or if the input formats do not match any registered format. + Exception: Raised if the wrapped class is not a subclass of Expander. + + Description: + This function adds a list `__supported_formats` to the expander, + which contains the whitelist of formats supported by the operator. + It also rewrites the `_check` function to check the formats. + + Example: + python + @add_format("text", "image") + class MyOperator(Expander): + pass + ``` + + In this example, `MyOperator` will support the "text" and "image" formats. + """ + """ + Add new supported format for the operator + this function will add a list `__supported_formats` into the expander, saving the whitelist of formats that this op supports. it also rewrites the `_check` function to check the formats. @@ -110,6 +237,19 @@ class ExpanderInfoValidator: format_list_name = "__supported_formats" def _check_format(obj): + """ + 检查对象的输入格式是否与已注册的格式匹配。 + + Args: + obj (object): 需要检查的对象。 + + Raises: + GKException: 如果输入格式与已注册的格式不匹配,则引发异常。 + + Returns: + None + + """ inp_formats = [inp['format'] for inp in obj.inputs] for formats in getattr(obj, format_list_name): if len(formats) != len(inp_formats): @@ -120,6 +260,18 @@ class ExpanderInfoValidator: raise GKException("Unregistered format ({}) for op {}".format(','.join(inp_formats), obj.name)) def wrapper(cls): + """ + 为给定的类添加包装功能。 + + Args: + cls: 需要被包装的类,必须继承自 Expander 类。 + + Returns: + 返回包装后的类。 + + Raises: + Exception: 如果 cls 不是 Expander 的子类,则抛出异常。 + """ if not issubclass(cls, Expander): raise Exception("{} should be subclass of Expander.".format(cls.__name__)) if not hasattr(cls, format_list_name): @@ -132,11 +284,49 @@ class ExpanderInfoValidator: @staticmethod def check_all_formats_same(kls): + """ + 检查所有格式是否相同。 + + Args: + kls: 待检查的类 + + Returns: + 返回传入的类 kls,并在类上注册一个检查函数,用于验证该类所有输入格式是否一致。 + + Raises: + Exception: 如果传入的类 kls 不是 Expander 的子类,则抛出异常。 + GKException: 如果 kls 类中的输入格式不一致,则抛出异常,并显示不匹配格式的信息。 + + """ """Check that all formats are the same""" # Ensure no args case can return a class def _check(*args): + """ + 检查操作输入格式是否一致的装饰器。 + + Args: + *args: 可变参数,装饰器可以接收任意数量的参数。 + + Returns: + wrapper: 返回一个装饰器函数,用于包装类。 + + Raises: + GKException: 如果所有输入的格式不一致,抛出GKException异常。 + Exception: 如果被装饰的类不是Expander的子类,抛出异常。 + + """ def _check_format(obj): + """ + 检查输入格式是否一致。 + + Args: + obj (Any): 包含输入信息的对象。 + + Raises: + GKException: 如果所有输入格式不一致,则抛出异常,并包含不匹配格式的具体信息。 + + """ inp_formats = [inp['format'] for inp in obj.inputs] if all((fmt == inp_formats[0] for fmt in inp_formats[1:])): return @@ -144,6 +334,19 @@ class ExpanderInfoValidator: ','.join(inp_formats), obj.name)) def wrapper(cls): + """ + 将给定类包装为 Expander 的子类,并进行格式检查。 + + Args: + cls (class): 需要包装的类。 + + Returns: + class: 包装后的类。 + + Raises: + Exception: 如果 cls 不是 Expander 的子类,则抛出异常。 + + """ if not issubclass(cls, Expander): raise Exception("{} should be subclass of Expander.".format(cls.__name__)) ExpanderInfoValidator._add_check_function(cls, _check_format) @@ -155,14 +358,53 @@ class ExpanderInfoValidator: @staticmethod def check_attrs(*args): + """ + 检查属性是否存在。 + + Args: + *args: 一个或多个属性名,用于检查对象是否具有这些属性。 + + Returns: + 一个装饰器函数,该装饰器函数用于验证类是否具有指定的属性。 + + Raises: + GKException: 如果对象不具有指定的属性,则抛出该异常。 + Exception: 如果被装饰的类不是 Expander 的子类,则抛出该异常。 + + """ """Check the attrs exist""" def _check_attr(obj): + """ + 检查对象是否具有指定的属性。 + + Args: + obj (object): 要检查的对象。 + + Raises: + GKException: 如果对象不具有指定的属性,则抛出异常。 + + Returns: + None + """ for a in args: if a not in obj.attrs: raise GKException("attr '{}' does not exist.".format(a)) def wrapper(cls): + """ + 对类进行包装,确保该类是 Expander 的子类,并添加属性检查功能。 + + Args: + cls (class): 需要包装的类。 + + Returns: + class: 包装后的类。 + + Raises: + Exception: 如果 cls 不是 Expander 的子类,则抛出异常。 + + """ if not issubclass(cls, Expander): raise Exception("{} should be subclass of Expander.".format(cls.__name__)) ExpanderInfoValidator._add_check_function(cls, _check_attr) @@ -172,6 +414,21 @@ class ExpanderInfoValidator: def to_frac_z_axis(ori_shape, ori_axis): + """ + 判断是否为分形NZ格式 + + Args: + ---- + ori_shape: list or tuple + 输入的原始形状 + ori_axis: list or tuple + 操作的原始形状的轴 + + Returns: + ------- + output: list + 分形Nz形状的轴 + """ """ judge the format is fractal NZ Parameters @@ -208,6 +465,16 @@ def to_frac_z_axis(ori_shape, ori_axis): def infer_shape_from_fractalnz(fractal): + """ + 从fractalnz形状推断原始形状 + + Args: + fractal (list): fractalnz形状,一个包含形状的列表 + + Returns: + list: 推断出的原始形状 + + """ "get original shape from fractalnz shape" shape = [] dims = len(fractal) @@ -222,6 +489,17 @@ def infer_shape_from_fractalnz(fractal): def get_reduced_ori_shape(shape, axis): + """ + 获取基于原始形状的降维后的形状。 + + Args: + shape (List[int]): 原始形状,是一个整数列表。 + axis (List[int]): 需要降维的轴索引列表。 + + Returns: + List[int]: 降维后的形状,是一个整数列表。 + + """ "get shape after reduced which is based on original shape" reduced_ori_shape = [] for i, value in enumerate(shape): @@ -233,6 +511,25 @@ def get_reduced_ori_shape(shape, axis): def get_reduce_axis_shape(shape, data_format, axis): + """ + 根据给定的输入形状、数据格式和轴,获取在指定格式下的归约轴和原始的归约形状。 + + Args: + ----- + shape: list or tuple + 输入的形状。 + data_format: str + 输入的数据格式。 + axis: None, int, list or tuple + 在原始形状下的归约轴。 + + Returns: + -------- + reduce_axis: list + 在指定数据格式下的归约轴。 + ori_reduced_shape: list + 原始的归约形状。 + """ """ Get the reduce axis under format `data_format` and original reduced shape. Parameters