add comments for _extends\graph_kernel\expanders\_utils.py

branch-yixin
yixin 7 months ago
parent 500f5d3348
commit 1531f33582

@ -27,6 +27,19 @@ class Expander:
__metaclass__ = ABCMeta
def __init__(self, expand_info):
"""
初始化方法
Args:
expand_info (dict): 包含模型信息的字典包括模型名称输入描述输出描述属性处理函数等
Attributes:
name (str): 模型名称
inputs (list): 输入描述列表
outputs (list): 输出描述列表
attrs (dict): 模型属性字典
processor (callable): 处理函数
"""
self.name = expand_info["name"]
self.inputs = expand_info["input_desc"]
self.outputs = expand_info["output_desc"]
@ -34,6 +47,19 @@ class Expander:
self.processor = expand_info["process"]
def run(self):
"""
将操作扩展为图
Args:
Returns:
返回扩展后的图对象
Raises:
GraphKernelUnsupportedException: 如果检查失败则引发此异常
"""
"""
Expand the operator to a graph.
@ -58,9 +84,31 @@ class Expander:
return graph
def _check(self):
"""
检查输入
Args:
Returns:
Raises:
ValueError: 如果输入不符合要求则引发此异常
"""
"""Check inputs"""
def _check_output_same(self, outputs):
"""
检查输出是否与预期一致
Args:
outputs (list): 实际输出值的列表
Raises:
GKException: 如果实际输出值与预期不一致则抛出异常
"""
for index, value in enumerate(self.outputs):
if list(outputs[index].shape) != list(value['shape']):
raise GKException("{} 's output shape {} is wrong. Expected:{}".format(
@ -74,6 +122,18 @@ class Expander:
@abstractmethod
def _expand(self, graph_builder):
"""
Expand 操作符此函数应在子类中重写
Args:
graph_builder (GraphBuilder): 图构建器对象
Raises:
Exception: 如果子类未重写此方法则抛出异常提示 "_expand() is not implemented in {}".
Returns:
None
"""
"""Expand operator, this function should be overridden in subclass"""
raise Exception("_expand() is not implemented in {}".format(self.__class__.__name__))
@ -82,10 +142,34 @@ class ExpanderInfoValidator:
"""ExpanderInfoValidator is the utility class which defines the validator decorator for expanders"""
def __init__(self):
"""
初始化方法
Args:
Returns:
"""
"""Init"""
@staticmethod
def _add_check_function(kls, func):
"""
向类 Expander 中的 `_check` 函数添加新的检查函数 `func`
Args:
kls (type): 需要被修改的类对象应继承自 Expander
func (callable): 要添加的新检查函数该函数接受一个对象作为参数
Returns:
None
Raises:
AttributeError: 如果 kls 类中不存在 `_check` 方法
"""
"""
Rewrite the function `_check` in class Expander
to append the new `func` after the original checks.
@ -93,6 +177,21 @@ class ExpanderInfoValidator:
old_check = getattr(kls, "_check")
def new_check(obj):
"""
执行新的检查函数
Args:
obj (Any): 需要检查的对象
Returns:
None
Raises:
None
这个函数首先调用旧版本的检查函数 `old_check` 对传入的对象 `obj` 进行检查
然后调用自定义的函数 `func` 对该对象进行处理
"""
old_check(obj)
func(obj)
@ -103,6 +202,34 @@ class ExpanderInfoValidator:
"""
Add new supported format for the operator
Args:
*input_format: A variable number of arguments representing the new supported formats.
Returns:
A wrapper function that adds the specified formats to the operator's supported formats list.
Raises:
GKException: Raised if the length of the registered format list does not match the length of the input formats,
or if the input formats do not match any registered format.
Exception: Raised if the wrapped class is not a subclass of Expander.
Description:
This function adds a list `__supported_formats` to the expander,
which contains the whitelist of formats supported by the operator.
It also rewrites the `_check` function to check the formats.
Example:
python
@add_format("text", "image")
class MyOperator(Expander):
pass
```
In this example, `MyOperator` will support the "text" and "image" formats.
"""
"""
Add new supported format for the operator
this function will add a list `__supported_formats` into the expander,
saving the whitelist of formats that this op supports.
it also rewrites the `_check` function to check the formats.
@ -110,6 +237,19 @@ class ExpanderInfoValidator:
format_list_name = "__supported_formats"
def _check_format(obj):
"""
检查对象的输入格式是否与已注册的格式匹配
Args:
obj (object): 需要检查的对象
Raises:
GKException: 如果输入格式与已注册的格式不匹配则引发异常
Returns:
None
"""
inp_formats = [inp['format'] for inp in obj.inputs]
for formats in getattr(obj, format_list_name):
if len(formats) != len(inp_formats):
@ -120,6 +260,18 @@ class ExpanderInfoValidator:
raise GKException("Unregistered format ({}) for op {}".format(','.join(inp_formats), obj.name))
def wrapper(cls):
"""
为给定的类添加包装功能
Args:
cls: 需要被包装的类必须继承自 Expander
Returns:
返回包装后的类
Raises:
Exception: 如果 cls 不是 Expander 的子类则抛出异常
"""
if not issubclass(cls, Expander):
raise Exception("{} should be subclass of Expander.".format(cls.__name__))
if not hasattr(cls, format_list_name):
@ -132,11 +284,49 @@ class ExpanderInfoValidator:
@staticmethod
def check_all_formats_same(kls):
"""
检查所有格式是否相同
Args:
kls: 待检查的类
Returns:
返回传入的类 kls并在类上注册一个检查函数用于验证该类所有输入格式是否一致
Raises:
Exception: 如果传入的类 kls 不是 Expander 的子类则抛出异常
GKException: 如果 kls 类中的输入格式不一致则抛出异常并显示不匹配格式的信息
"""
"""Check that all formats are the same"""
# Ensure no args case can return a class
def _check(*args):
"""
检查操作输入格式是否一致的装饰器
Args:
*args: 可变参数装饰器可以接收任意数量的参数
Returns:
wrapper: 返回一个装饰器函数用于包装类
Raises:
GKException: 如果所有输入的格式不一致抛出GKException异常
Exception: 如果被装饰的类不是Expander的子类抛出异常
"""
def _check_format(obj):
"""
检查输入格式是否一致
Args:
obj (Any): 包含输入信息的对象
Raises:
GKException: 如果所有输入格式不一致则抛出异常并包含不匹配格式的具体信息
"""
inp_formats = [inp['format'] for inp in obj.inputs]
if all((fmt == inp_formats[0] for fmt in inp_formats[1:])):
return
@ -144,6 +334,19 @@ class ExpanderInfoValidator:
','.join(inp_formats), obj.name))
def wrapper(cls):
"""
将给定类包装为 Expander 的子类并进行格式检查
Args:
cls (class): 需要包装的类
Returns:
class: 包装后的类
Raises:
Exception: 如果 cls 不是 Expander 的子类则抛出异常
"""
if not issubclass(cls, Expander):
raise Exception("{} should be subclass of Expander.".format(cls.__name__))
ExpanderInfoValidator._add_check_function(cls, _check_format)
@ -155,14 +358,53 @@ class ExpanderInfoValidator:
@staticmethod
def check_attrs(*args):
"""
检查属性是否存在
Args:
*args: 一个或多个属性名用于检查对象是否具有这些属性
Returns:
一个装饰器函数该装饰器函数用于验证类是否具有指定的属性
Raises:
GKException: 如果对象不具有指定的属性则抛出该异常
Exception: 如果被装饰的类不是 Expander 的子类则抛出该异常
"""
"""Check the attrs exist"""
def _check_attr(obj):
"""
检查对象是否具有指定的属性
Args:
obj (object): 要检查的对象
Raises:
GKException: 如果对象不具有指定的属性则抛出异常
Returns:
None
"""
for a in args:
if a not in obj.attrs:
raise GKException("attr '{}' does not exist.".format(a))
def wrapper(cls):
"""
对类进行包装确保该类是 Expander 的子类并添加属性检查功能
Args:
cls (class): 需要包装的类
Returns:
class: 包装后的类
Raises:
Exception: 如果 cls 不是 Expander 的子类则抛出异常
"""
if not issubclass(cls, Expander):
raise Exception("{} should be subclass of Expander.".format(cls.__name__))
ExpanderInfoValidator._add_check_function(cls, _check_attr)
@ -172,6 +414,21 @@ class ExpanderInfoValidator:
def to_frac_z_axis(ori_shape, ori_axis):
"""
判断是否为分形NZ格式
Args:
----
ori_shape: list or tuple
输入的原始形状
ori_axis: list or tuple
操作的原始形状的轴
Returns:
-------
output: list
分形Nz形状的轴
"""
"""
judge the format is fractal NZ
Parameters
@ -208,6 +465,16 @@ def to_frac_z_axis(ori_shape, ori_axis):
def infer_shape_from_fractalnz(fractal):
"""
从fractalnz形状推断原始形状
Args:
fractal (list): fractalnz形状一个包含形状的列表
Returns:
list: 推断出的原始形状
"""
"get original shape from fractalnz shape"
shape = []
dims = len(fractal)
@ -222,6 +489,17 @@ def infer_shape_from_fractalnz(fractal):
def get_reduced_ori_shape(shape, axis):
"""
获取基于原始形状的降维后的形状
Args:
shape (List[int]): 原始形状是一个整数列表
axis (List[int]): 需要降维的轴索引列表
Returns:
List[int]: 降维后的形状是一个整数列表
"""
"get shape after reduced which is based on original shape"
reduced_ori_shape = []
for i, value in enumerate(shape):
@ -233,6 +511,25 @@ def get_reduced_ori_shape(shape, axis):
def get_reduce_axis_shape(shape, data_format, axis):
"""
根据给定的输入形状数据格式和轴获取在指定格式下的归约轴和原始的归约形状
Args:
-----
shape: list or tuple
输入的形状
data_format: str
输入的数据格式
axis: None, int, list or tuple
在原始形状下的归约轴
Returns:
--------
reduce_axis: list
在指定数据格式下的归约轴
ori_reduced_shape: list
原始的归约形状
"""
"""
Get the reduce axis under format `data_format` and original reduced shape.
Parameters

Loading…
Cancel
Save