|
|
|
@ -27,6 +27,19 @@ class Expander:
|
|
|
|
|
__metaclass__ = ABCMeta
|
|
|
|
|
|
|
|
|
|
def __init__(self, expand_info):
|
|
|
|
|
"""
|
|
|
|
|
初始化方法。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
expand_info (dict): 包含模型信息的字典,包括模型名称、输入描述、输出描述、属性、处理函数等。
|
|
|
|
|
|
|
|
|
|
Attributes:
|
|
|
|
|
name (str): 模型名称。
|
|
|
|
|
inputs (list): 输入描述列表。
|
|
|
|
|
outputs (list): 输出描述列表。
|
|
|
|
|
attrs (dict): 模型属性字典。
|
|
|
|
|
processor (callable): 处理函数。
|
|
|
|
|
"""
|
|
|
|
|
self.name = expand_info["name"]
|
|
|
|
|
self.inputs = expand_info["input_desc"]
|
|
|
|
|
self.outputs = expand_info["output_desc"]
|
|
|
|
@ -34,6 +47,19 @@ class Expander:
|
|
|
|
|
self.processor = expand_info["process"]
|
|
|
|
|
|
|
|
|
|
def run(self):
|
|
|
|
|
"""
|
|
|
|
|
将操作扩展为图。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
无
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
返回扩展后的图对象。
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
GraphKernelUnsupportedException: 如果检查失败,则引发此异常。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""
|
|
|
|
|
Expand the operator to a graph.
|
|
|
|
|
|
|
|
|
@ -58,9 +84,31 @@ class Expander:
|
|
|
|
|
return graph
|
|
|
|
|
|
|
|
|
|
def _check(self):
|
|
|
|
|
"""
|
|
|
|
|
检查输入。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
无
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
无
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
ValueError: 如果输入不符合要求,则引发此异常。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Check inputs"""
|
|
|
|
|
|
|
|
|
|
def _check_output_same(self, outputs):
|
|
|
|
|
"""
|
|
|
|
|
检查输出是否与预期一致。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
outputs (list): 实际输出值的列表。
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
GKException: 如果实际输出值与预期不一致,则抛出异常。
|
|
|
|
|
"""
|
|
|
|
|
for index, value in enumerate(self.outputs):
|
|
|
|
|
if list(outputs[index].shape) != list(value['shape']):
|
|
|
|
|
raise GKException("{} 's output shape {} is wrong. Expected:{}".format(
|
|
|
|
@ -74,6 +122,18 @@ class Expander:
|
|
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
|
def _expand(self, graph_builder):
|
|
|
|
|
"""
|
|
|
|
|
Expand 操作符,此函数应在子类中重写。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
graph_builder (GraphBuilder): 图构建器对象。
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
Exception: 如果子类未重写此方法,则抛出异常,提示 "_expand() is not implemented in {}".
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
None
|
|
|
|
|
"""
|
|
|
|
|
"""Expand operator, this function should be overridden in subclass"""
|
|
|
|
|
raise Exception("_expand() is not implemented in {}".format(self.__class__.__name__))
|
|
|
|
|
|
|
|
|
@ -82,10 +142,34 @@ class ExpanderInfoValidator:
|
|
|
|
|
"""ExpanderInfoValidator is the utility class which defines the validator decorator for expanders"""
|
|
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
|
"""
|
|
|
|
|
初始化方法。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
无
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
无
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Init"""
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _add_check_function(kls, func):
|
|
|
|
|
"""
|
|
|
|
|
向类 Expander 中的 `_check` 函数添加新的检查函数 `func`。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
kls (type): 需要被修改的类对象,应继承自 Expander 类。
|
|
|
|
|
func (callable): 要添加的新检查函数,该函数接受一个对象作为参数。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
None
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
AttributeError: 如果 kls 类中不存在 `_check` 方法。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""
|
|
|
|
|
Rewrite the function `_check` in class Expander
|
|
|
|
|
to append the new `func` after the original checks.
|
|
|
|
@ -93,6 +177,21 @@ class ExpanderInfoValidator:
|
|
|
|
|
old_check = getattr(kls, "_check")
|
|
|
|
|
|
|
|
|
|
def new_check(obj):
|
|
|
|
|
"""
|
|
|
|
|
执行新的检查函数。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
obj (Any): 需要检查的对象。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
None
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
None
|
|
|
|
|
|
|
|
|
|
这个函数首先调用旧版本的检查函数 `old_check` 对传入的对象 `obj` 进行检查,
|
|
|
|
|
然后调用自定义的函数 `func` 对该对象进行处理。
|
|
|
|
|
"""
|
|
|
|
|
old_check(obj)
|
|
|
|
|
func(obj)
|
|
|
|
|
|
|
|
|
@ -103,6 +202,34 @@ class ExpanderInfoValidator:
|
|
|
|
|
"""
|
|
|
|
|
Add new supported format for the operator
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
*input_format: A variable number of arguments representing the new supported formats.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
A wrapper function that adds the specified formats to the operator's supported formats list.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
GKException: Raised if the length of the registered format list does not match the length of the input formats,
|
|
|
|
|
or if the input formats do not match any registered format.
|
|
|
|
|
Exception: Raised if the wrapped class is not a subclass of Expander.
|
|
|
|
|
|
|
|
|
|
Description:
|
|
|
|
|
This function adds a list `__supported_formats` to the expander,
|
|
|
|
|
which contains the whitelist of formats supported by the operator.
|
|
|
|
|
It also rewrites the `_check` function to check the formats.
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
python
|
|
|
|
|
@add_format("text", "image")
|
|
|
|
|
class MyOperator(Expander):
|
|
|
|
|
pass
|
|
|
|
|
```
|
|
|
|
|
|
|
|
|
|
In this example, `MyOperator` will support the "text" and "image" formats.
|
|
|
|
|
"""
|
|
|
|
|
"""
|
|
|
|
|
Add new supported format for the operator
|
|
|
|
|
|
|
|
|
|
this function will add a list `__supported_formats` into the expander,
|
|
|
|
|
saving the whitelist of formats that this op supports.
|
|
|
|
|
it also rewrites the `_check` function to check the formats.
|
|
|
|
@ -110,6 +237,19 @@ class ExpanderInfoValidator:
|
|
|
|
|
format_list_name = "__supported_formats"
|
|
|
|
|
|
|
|
|
|
def _check_format(obj):
|
|
|
|
|
"""
|
|
|
|
|
检查对象的输入格式是否与已注册的格式匹配。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
obj (object): 需要检查的对象。
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
GKException: 如果输入格式与已注册的格式不匹配,则引发异常。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
None
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
inp_formats = [inp['format'] for inp in obj.inputs]
|
|
|
|
|
for formats in getattr(obj, format_list_name):
|
|
|
|
|
if len(formats) != len(inp_formats):
|
|
|
|
@ -120,6 +260,18 @@ class ExpanderInfoValidator:
|
|
|
|
|
raise GKException("Unregistered format ({}) for op {}".format(','.join(inp_formats), obj.name))
|
|
|
|
|
|
|
|
|
|
def wrapper(cls):
|
|
|
|
|
"""
|
|
|
|
|
为给定的类添加包装功能。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
cls: 需要被包装的类,必须继承自 Expander 类。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
返回包装后的类。
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
Exception: 如果 cls 不是 Expander 的子类,则抛出异常。
|
|
|
|
|
"""
|
|
|
|
|
if not issubclass(cls, Expander):
|
|
|
|
|
raise Exception("{} should be subclass of Expander.".format(cls.__name__))
|
|
|
|
|
if not hasattr(cls, format_list_name):
|
|
|
|
@ -132,11 +284,49 @@ class ExpanderInfoValidator:
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_all_formats_same(kls):
|
|
|
|
|
"""
|
|
|
|
|
检查所有格式是否相同。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
kls: 待检查的类
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
返回传入的类 kls,并在类上注册一个检查函数,用于验证该类所有输入格式是否一致。
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
Exception: 如果传入的类 kls 不是 Expander 的子类,则抛出异常。
|
|
|
|
|
GKException: 如果 kls 类中的输入格式不一致,则抛出异常,并显示不匹配格式的信息。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Check that all formats are the same"""
|
|
|
|
|
|
|
|
|
|
# Ensure no args case can return a class
|
|
|
|
|
def _check(*args):
|
|
|
|
|
"""
|
|
|
|
|
检查操作输入格式是否一致的装饰器。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
*args: 可变参数,装饰器可以接收任意数量的参数。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
wrapper: 返回一个装饰器函数,用于包装类。
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
GKException: 如果所有输入的格式不一致,抛出GKException异常。
|
|
|
|
|
Exception: 如果被装饰的类不是Expander的子类,抛出异常。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
def _check_format(obj):
|
|
|
|
|
"""
|
|
|
|
|
检查输入格式是否一致。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
obj (Any): 包含输入信息的对象。
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
GKException: 如果所有输入格式不一致,则抛出异常,并包含不匹配格式的具体信息。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
inp_formats = [inp['format'] for inp in obj.inputs]
|
|
|
|
|
if all((fmt == inp_formats[0] for fmt in inp_formats[1:])):
|
|
|
|
|
return
|
|
|
|
@ -144,6 +334,19 @@ class ExpanderInfoValidator:
|
|
|
|
|
','.join(inp_formats), obj.name))
|
|
|
|
|
|
|
|
|
|
def wrapper(cls):
|
|
|
|
|
"""
|
|
|
|
|
将给定类包装为 Expander 的子类,并进行格式检查。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
cls (class): 需要包装的类。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
class: 包装后的类。
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
Exception: 如果 cls 不是 Expander 的子类,则抛出异常。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
if not issubclass(cls, Expander):
|
|
|
|
|
raise Exception("{} should be subclass of Expander.".format(cls.__name__))
|
|
|
|
|
ExpanderInfoValidator._add_check_function(cls, _check_format)
|
|
|
|
@ -155,14 +358,53 @@ class ExpanderInfoValidator:
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def check_attrs(*args):
|
|
|
|
|
"""
|
|
|
|
|
检查属性是否存在。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
*args: 一个或多个属性名,用于检查对象是否具有这些属性。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
一个装饰器函数,该装饰器函数用于验证类是否具有指定的属性。
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
GKException: 如果对象不具有指定的属性,则抛出该异常。
|
|
|
|
|
Exception: 如果被装饰的类不是 Expander 的子类,则抛出该异常。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Check the attrs exist"""
|
|
|
|
|
|
|
|
|
|
def _check_attr(obj):
|
|
|
|
|
"""
|
|
|
|
|
检查对象是否具有指定的属性。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
obj (object): 要检查的对象。
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
GKException: 如果对象不具有指定的属性,则抛出异常。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
None
|
|
|
|
|
"""
|
|
|
|
|
for a in args:
|
|
|
|
|
if a not in obj.attrs:
|
|
|
|
|
raise GKException("attr '{}' does not exist.".format(a))
|
|
|
|
|
|
|
|
|
|
def wrapper(cls):
|
|
|
|
|
"""
|
|
|
|
|
对类进行包装,确保该类是 Expander 的子类,并添加属性检查功能。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
cls (class): 需要包装的类。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
class: 包装后的类。
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
Exception: 如果 cls 不是 Expander 的子类,则抛出异常。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
if not issubclass(cls, Expander):
|
|
|
|
|
raise Exception("{} should be subclass of Expander.".format(cls.__name__))
|
|
|
|
|
ExpanderInfoValidator._add_check_function(cls, _check_attr)
|
|
|
|
@ -172,6 +414,21 @@ class ExpanderInfoValidator:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def to_frac_z_axis(ori_shape, ori_axis):
|
|
|
|
|
"""
|
|
|
|
|
判断是否为分形NZ格式
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
----
|
|
|
|
|
ori_shape: list or tuple
|
|
|
|
|
输入的原始形状
|
|
|
|
|
ori_axis: list or tuple
|
|
|
|
|
操作的原始形状的轴
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
-------
|
|
|
|
|
output: list
|
|
|
|
|
分形Nz形状的轴
|
|
|
|
|
"""
|
|
|
|
|
"""
|
|
|
|
|
judge the format is fractal NZ
|
|
|
|
|
Parameters
|
|
|
|
@ -208,6 +465,16 @@ def to_frac_z_axis(ori_shape, ori_axis):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def infer_shape_from_fractalnz(fractal):
|
|
|
|
|
"""
|
|
|
|
|
从fractalnz形状推断原始形状
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
fractal (list): fractalnz形状,一个包含形状的列表
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
list: 推断出的原始形状
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"get original shape from fractalnz shape"
|
|
|
|
|
shape = []
|
|
|
|
|
dims = len(fractal)
|
|
|
|
@ -222,6 +489,17 @@ def infer_shape_from_fractalnz(fractal):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_reduced_ori_shape(shape, axis):
|
|
|
|
|
"""
|
|
|
|
|
获取基于原始形状的降维后的形状。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
shape (List[int]): 原始形状,是一个整数列表。
|
|
|
|
|
axis (List[int]): 需要降维的轴索引列表。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
List[int]: 降维后的形状,是一个整数列表。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"get shape after reduced which is based on original shape"
|
|
|
|
|
reduced_ori_shape = []
|
|
|
|
|
for i, value in enumerate(shape):
|
|
|
|
@ -233,6 +511,25 @@ def get_reduced_ori_shape(shape, axis):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_reduce_axis_shape(shape, data_format, axis):
|
|
|
|
|
"""
|
|
|
|
|
根据给定的输入形状、数据格式和轴,获取在指定格式下的归约轴和原始的归约形状。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
-----
|
|
|
|
|
shape: list or tuple
|
|
|
|
|
输入的形状。
|
|
|
|
|
data_format: str
|
|
|
|
|
输入的数据格式。
|
|
|
|
|
axis: None, int, list or tuple
|
|
|
|
|
在原始形状下的归约轴。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
--------
|
|
|
|
|
reduce_axis: list
|
|
|
|
|
在指定数据格式下的归约轴。
|
|
|
|
|
ori_reduced_shape: list
|
|
|
|
|
原始的归约形状。
|
|
|
|
|
"""
|
|
|
|
|
"""
|
|
|
|
|
Get the reduce axis under format `data_format` and original reduced shape.
|
|
|
|
|
Parameters
|
|
|
|
|