|
|
|
@ -45,14 +45,15 @@ class OneOf(OneOf_):
|
|
|
|
|
TypeError: raise type error for invalid inputs.
|
|
|
|
|
"""
|
|
|
|
|
self.patterns = patterns
|
|
|
|
|
# 检查 patterns 是否是 Pattern 类的实例
|
|
|
|
|
if isinstance(patterns, Pattern):
|
|
|
|
|
OneOf_.__init__(self, [patterns])
|
|
|
|
|
# 检查 patterns 是否是 tuple 或 list 类型,并且其中所有元素都是 Pattern 类的实例
|
|
|
|
|
elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns):
|
|
|
|
|
OneOf_.__init__(self, patterns)
|
|
|
|
|
# 如果 patterns 不符合上述两种情况,则抛出 TypeError 异常
|
|
|
|
|
else:
|
|
|
|
|
raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Prim(Prim_):
|
|
|
|
|
r"""
|
|
|
|
|
Express a pattern of certain primitive type(s).
|
|
|
|
@ -76,25 +77,33 @@ class Prim(Prim_):
|
|
|
|
|
Raises:
|
|
|
|
|
TypeError: raise type error for invalid argument.
|
|
|
|
|
"""
|
|
|
|
|
# 检查name是否为字符串类型,如果不是则抛出TypeError
|
|
|
|
|
if name is not None and not isinstance(name, str):
|
|
|
|
|
raise TypeError(f"Expect string, got : {name}")
|
|
|
|
|
self.name = name
|
|
|
|
|
# 如果types是字符串类型,则将其按'|'分割成列表
|
|
|
|
|
if isinstance(types, str):
|
|
|
|
|
if self.name is None:
|
|
|
|
|
self.name = types
|
|
|
|
|
self.types = types.split('|')
|
|
|
|
|
# 如果types是Primitive类型,则直接将其放入列表中
|
|
|
|
|
elif isinstance(types, Primitive):
|
|
|
|
|
if self.name is None:
|
|
|
|
|
self.name = types.name
|
|
|
|
|
self.types = [types]
|
|
|
|
|
# 如果 types 是元组或列表,并且其中所有元素都是 Primitive 类型
|
|
|
|
|
elif isinstance(types, (tuple, list)) and all(isinstance(tp, Primitive) for tp in types):
|
|
|
|
|
# 如果 self.name 为 None,则初始化为空字符串并拼接所有 Primitive 的 name
|
|
|
|
|
if self.name is None:
|
|
|
|
|
self.name = ""
|
|
|
|
|
for prim in types:
|
|
|
|
|
self.name += prim.name
|
|
|
|
|
# 设置 self.types 为传入的 types
|
|
|
|
|
self.types = types
|
|
|
|
|
# 如果 types 不符合预期类型,抛出 TypeError
|
|
|
|
|
else:
|
|
|
|
|
raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}")
|
|
|
|
|
# 调用基类 Prim_ 的初始化方法,传入 self.types 和 self.name
|
|
|
|
|
Prim_.__init__(self, self.types, self.name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -115,16 +124,22 @@ class Call(Call_):
|
|
|
|
|
Raises:
|
|
|
|
|
TypeError: raise type error for invalid argument.
|
|
|
|
|
"""
|
|
|
|
|
# 检查 prim_pattern 是否为 Pattern, str 或 Primitive 类型,如果不是则抛出 TypeError
|
|
|
|
|
if not isinstance(prim_pattern, (Pattern, str, Primitive)):
|
|
|
|
|
raise TypeError(f"Expect prim_pattern to be Pattern, Primitive or string, got : {prim_pattern}")
|
|
|
|
|
self.prim_pattern = prim_pattern
|
|
|
|
|
# 初始化 inputs 列表
|
|
|
|
|
self.inputs = []
|
|
|
|
|
# 如果 inputs 为 None,则不做任何操作
|
|
|
|
|
if inputs is None:
|
|
|
|
|
pass
|
|
|
|
|
# 如果 inputs 是 tuple 或 list 并且其中所有元素都是 Pattern 类型,则赋值给 self.inputs
|
|
|
|
|
elif isinstance(inputs, (tuple, list)) and all(isinstance(input, Pattern) for input in inputs):
|
|
|
|
|
self.inputs = inputs
|
|
|
|
|
# 如果 inputs 不符合上述条件,则抛出 TypeError
|
|
|
|
|
else:
|
|
|
|
|
raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}")
|
|
|
|
|
# 调用父类 Call_ 的初始化方法,传入 self.prim_pattern 和 self.inputs
|
|
|
|
|
Call_.__init__(self, self.prim_pattern, self.inputs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -145,6 +160,7 @@ class NoneOf(NoneOf_):
|
|
|
|
|
TypeError: raise type error for invalid argument.
|
|
|
|
|
"""
|
|
|
|
|
self.patterns = patterns
|
|
|
|
|
# 根据 patterns 的类型初始化 NoneOf_ 类
|
|
|
|
|
if patterns is None:
|
|
|
|
|
NoneOf_.__init__(self, ())
|
|
|
|
|
elif isinstance(patterns, Pattern):
|
|
|
|
@ -154,7 +170,6 @@ class NoneOf(NoneOf_):
|
|
|
|
|
else:
|
|
|
|
|
raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NewTensor(NewTensor_):
|
|
|
|
|
r"""
|
|
|
|
|
New Tensor to be used in the target.
|
|
|
|
@ -167,13 +182,16 @@ class NewTensor(NewTensor_):
|
|
|
|
|
Raises:
|
|
|
|
|
TypeError: raise type error for invalid argument.
|
|
|
|
|
"""
|
|
|
|
|
# 初始化输入张量
|
|
|
|
|
self.input_tensor = input_tensor
|
|
|
|
|
# 检查输入是否为 Tensor 类型
|
|
|
|
|
if isinstance(input_tensor, Tensor):
|
|
|
|
|
# 如果是 Tensor 类型,则调用 NewTensor_ 的初始化方法
|
|
|
|
|
NewTensor_.__init__(self, input_tensor)
|
|
|
|
|
else:
|
|
|
|
|
# 如果不是 Tensor 类型,则抛出 TypeError 异常
|
|
|
|
|
raise TypeError(f"Expect input_tensor to be a Tensor, got : {input_tensor}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NewParameter(NewParameter_):
|
|
|
|
|
r"""
|
|
|
|
|
New Parameter to be used in the target.
|
|
|
|
@ -193,11 +211,14 @@ class NewParameter(NewParameter_):
|
|
|
|
|
self.default_tensor = default_tensor
|
|
|
|
|
self.requires_grad = requires_grad
|
|
|
|
|
self.layerwise_parallel = layerwise_parallel
|
|
|
|
|
# 检查参数类型是否符合预期
|
|
|
|
|
if isinstance(para_name, str) and isinstance(default_tensor, Tensor) and isinstance(requires_grad, bool) and\
|
|
|
|
|
isinstance(layerwise_parallel, bool):
|
|
|
|
|
# 初始化 NewParameter_ 类
|
|
|
|
|
NewParameter_.__init__(self, self.para_name, self.default_tensor, self.requires_grad,
|
|
|
|
|
self.layerwise_parallel)
|
|
|
|
|
else:
|
|
|
|
|
# 如果参数类型不符合预期,抛出 TypeError
|
|
|
|
|
raise TypeError(f"Expect para_name(str), default_tensor(Tensor), requires_grad(bool), \
|
|
|
|
|
layerwise_parallel(bool), got : {para_name}, {default_tensor}, \
|
|
|
|
|
{requires_grad}, {layerwise_parallel}")
|
|
|
|
|
{requires_grad}, {layerwise_parallel}")
|