|
|
|
@ -22,8 +22,32 @@ from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedEx
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_expander(expand_info):
|
|
|
|
|
"""
|
|
|
|
|
根据操作符名称创建一个扩展器。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
expand_info (dict): 包含操作符名称及其他相关信息的字典。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Any: 调用指定操作符名称的扩展器后返回的结果。
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
GraphKernelUnsupportedException: 如果指定的操作符名称在扩展器模块中不存在,则抛出此异常。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Create an expander according to op name"""
|
|
|
|
|
def call_func(func, arg):
|
|
|
|
|
"""
|
|
|
|
|
调用给定的函数并返回其结果。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
func (callable): 要调用的函数。
|
|
|
|
|
arg: 要传递给函数的参数。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
调用给定函数后的返回值。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
return func(arg)
|
|
|
|
|
op_name = str(expand_info['name'])
|
|
|
|
|
if not hasattr(expanders, op_name):
|
|
|
|
@ -33,6 +57,21 @@ def create_expander(expand_info):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_expand_info(kernel_info):
|
|
|
|
|
"""
|
|
|
|
|
将json格式的kernel信息转换为更友好的格式。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
kernel_info (dict): 包含kernel信息的字典。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
dict: 转换后的kernel信息字典,包含以下键:
|
|
|
|
|
- name (str): kernel的名称。
|
|
|
|
|
- input_desc (list): 输入描述列表。
|
|
|
|
|
- output_desc (list): 输出描述列表。
|
|
|
|
|
- attr (dict): 属性字典,键为属性名,值为属性值。
|
|
|
|
|
- process (str): 处理过程的描述。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Convert the json into a more friendly format"""
|
|
|
|
|
input_desc = []
|
|
|
|
|
if 'input_desc' in kernel_info and kernel_info['input_desc']:
|
|
|
|
@ -53,6 +92,20 @@ def extract_expand_info(kernel_info):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_op_expander(json_str: str):
|
|
|
|
|
"""
|
|
|
|
|
通过json信息获取操作扩展器。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
json_str (str): 包含操作扩展器信息的json字符串。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
str: 返回扩展后的操作图的json描述。
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
jd.JSONDecodeError: 如果输入的json字符串解码失败。
|
|
|
|
|
GraphKernelUnsupportedException: 如果操作图不支持的操作类型。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""get op expander by json info"""
|
|
|
|
|
try:
|
|
|
|
|
kernel_info = json.loads(json_str)
|
|
|
|
|