graph_utils,ruiqin

pull/5/head
zhang 7 months ago
parent 61450586a7
commit ddf9ac7477

@ -45,14 +45,15 @@ class OneOf(OneOf_):
TypeError: raise type error for invalid inputs.
"""
self.patterns = patterns
# 检查 patterns 是否是 Pattern 类的实例
if isinstance(patterns, Pattern):
OneOf_.__init__(self, [patterns])
# 检查 patterns 是否是 tuple 或 list 类型,并且其中所有元素都是 Pattern 类的实例
elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns):
OneOf_.__init__(self, patterns)
# 如果 patterns 不符合上述两种情况,则抛出 TypeError 异常
else:
raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}")
class Prim(Prim_):
r"""
Express a pattern of certain primitive type(s).
@ -76,25 +77,33 @@ class Prim(Prim_):
Raises:
TypeError: raise type error for invalid argument.
"""
# 检查name是否为字符串类型如果不是则抛出TypeError
if name is not None and not isinstance(name, str):
raise TypeError(f"Expect string, got : {name}")
self.name = name
# 如果types是字符串类型则将其按'|'分割成列表
if isinstance(types, str):
if self.name is None:
self.name = types
self.types = types.split('|')
# 如果types是Primitive类型则直接将其放入列表中
elif isinstance(types, Primitive):
if self.name is None:
self.name = types.name
self.types = [types]
# 如果 types 是元组或列表,并且其中所有元素都是 Primitive 类型
elif isinstance(types, (tuple, list)) and all(isinstance(tp, Primitive) for tp in types):
# 如果 self.name 为 None则初始化为空字符串并拼接所有 Primitive 的 name
if self.name is None:
self.name = ""
for prim in types:
self.name += prim.name
# 设置 self.types 为传入的 types
self.types = types
# 如果 types 不符合预期类型,抛出 TypeError
else:
raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}")
# 调用基类 Prim_ 的初始化方法,传入 self.types 和 self.name
Prim_.__init__(self, self.types, self.name)
@ -115,16 +124,22 @@ class Call(Call_):
Raises:
TypeError: raise type error for invalid argument.
"""
# 检查 prim_pattern 是否为 Pattern, str 或 Primitive 类型,如果不是则抛出 TypeError
if not isinstance(prim_pattern, (Pattern, str, Primitive)):
raise TypeError(f"Expect prim_pattern to be Pattern, Primitive or string, got : {prim_pattern}")
self.prim_pattern = prim_pattern
# 初始化 inputs 列表
self.inputs = []
# 如果 inputs 为 None则不做任何操作
if inputs is None:
pass
# 如果 inputs 是 tuple 或 list 并且其中所有元素都是 Pattern 类型,则赋值给 self.inputs
elif isinstance(inputs, (tuple, list)) and all(isinstance(input, Pattern) for input in inputs):
self.inputs = inputs
# 如果 inputs 不符合上述条件,则抛出 TypeError
else:
raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}")
# 调用父类 Call_ 的初始化方法,传入 self.prim_pattern 和 self.inputs
Call_.__init__(self, self.prim_pattern, self.inputs)
@ -145,6 +160,7 @@ class NoneOf(NoneOf_):
TypeError: raise type error for invalid argument.
"""
self.patterns = patterns
# 根据 patterns 的类型初始化 NoneOf_ 类
if patterns is None:
NoneOf_.__init__(self, ())
elif isinstance(patterns, Pattern):
@ -154,7 +170,6 @@ class NoneOf(NoneOf_):
else:
raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}")
class NewTensor(NewTensor_):
r"""
New Tensor to be used in the target.
@ -167,13 +182,16 @@ class NewTensor(NewTensor_):
Raises:
TypeError: raise type error for invalid argument.
"""
# 初始化输入张量
self.input_tensor = input_tensor
# 检查输入是否为 Tensor 类型
if isinstance(input_tensor, Tensor):
# 如果是 Tensor 类型,则调用 NewTensor_ 的初始化方法
NewTensor_.__init__(self, input_tensor)
else:
# 如果不是 Tensor 类型,则抛出 TypeError 异常
raise TypeError(f"Expect input_tensor to be a Tensor got : {input_tensor}")
class NewParameter(NewParameter_):
r"""
New Parameter to be used in the target.
@ -193,11 +211,14 @@ class NewParameter(NewParameter_):
self.default_tensor = default_tensor
self.requires_grad = requires_grad
self.layerwise_parallel = layerwise_parallel
# 检查参数类型是否符合预期
if isinstance(para_name, str) and isinstance(default_tensor, Tensor) and isinstance(requires_grad, bool) and\
isinstance(layerwise_parallel, bool):
# 初始化 NewParameter_ 类
NewParameter_.__init__(self, self.para_name, self.default_tensor, self.requires_grad,
self.layerwise_parallel)
else:
# 如果参数类型不符合预期,抛出 TypeError
raise TypeError(f"Expect para_name(str), default_tensor(Tensor), requires_grad(bool), \
layerwise_parallel(bool) got : {para_name}, {default_tensor}, \
{requires_grad}, {layerwise_parallel}")
{requires_grad}, {layerwise_parallel}")

@ -39,6 +39,7 @@ class PyPassManager(PyPassManager_):
TypeError: If argument has invalid type.
"""
def __init__(self, requires_grad=True, run_only_once=False):
# 初始化方法,接收两个布尔参数,设置实例的属性并调用父类的初始化方法
if not isinstance(requires_grad, bool):
raise TypeError(f"Expect bool, got : ({type(requires_grad)}){requires_grad}")
if not isinstance(run_only_once, bool):
@ -48,17 +49,20 @@ class PyPassManager(PyPassManager_):
PyPassManager_.__init__(self)
def register(self, py_pass):
# 注册一个Python pass检查其是否为函数类型并获取其模式和目标
if not isfunction(py_pass):
raise TypeError(f"Expect function pass, got : ({type(py_pass)}){py_pass}")
pattern, target = py_pass()
pass_name = py_pass.__name__
# 检查模式和目标是否为Pattern类型
if not isinstance(pattern, Pattern):
raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}")
if not isinstance(target, Pattern):
raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}")
# 调用父类的register方法注册pass及其相关信息
super().register(pass_name, pattern, target, self.requires_grad, self.run_only_once_)
def unregister(self, py_pass):
# 从注册表中移除指定的Python传递对象可以是字符串形式的名称或函数对象
if isinstance(py_pass, str):
super().unregister(py_pass)
return
@ -66,27 +70,30 @@ class PyPassManager(PyPassManager_):
super().unregister(py_pass.__name__)
return
raise TypeError(f"Expect py_pass to be string or function, got ({type(py_pass)}){py_pass}")
def __call__(self, py_pass):
# 将Python传递对象注册到注册表中并返回该对象
self.register(py_pass)
return py_pass
def gen_new_parameter(self, pattern):
# 根据给定的模式生成新的参数模式必须是NewParameter类型
if not isinstance(pattern, NewParameter):
raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}")
super().gen_new_parameter(pattern)
def set_renorm(self, should_renorm):
# 设置是否进行重归一化操作,参数必须是布尔值
if not isinstance(should_renorm, bool):
raise TypeError(f"Expect should_renorm to be a bool, got {should_renorm}")
super().set_renorm(should_renorm)
def set_reopt(self, do_reopt):
# 设置是否进行重新优化操作,参数必须是布尔值
if not isinstance(do_reopt, bool):
raise TypeError(f"Expect do_reopt to be a bool, got {do_reopt}")
super().set_reopt(do_reopt)
def register_pass(requires_grad=True, run_only_once=False):
"""
Register python pass to specified pipeline phase which would be used in compilation.
@ -165,12 +172,13 @@ def cancel_new_parameter(pattern):
>>> # some compilations
>>> cancel_new_parameter(abc)
"""
# 检查传入的pattern是否为NewParameter的实例
if not isinstance(pattern, NewParameter):
raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}")
# 创建一个PyPassManager对象
ppm = PyPassManager()
# 从PyPassManager中注销指定名称的参数
ppm.unregister(pattern.para_name)
def set_renorm(should_renorm):
"""
Set whether or not to do renormalization after modified graph in python pass(es).

Loading…
Cancel
Save