From ddf9ac7477acf910b550183657ec4c9f98e968c7 Mon Sep 17 00:00:00 2001 From: zhang <3237520467@qq.com> Date: Mon, 30 Dec 2024 20:04:19 +0800 Subject: [PATCH] graph_utils,ruiqin --- .../mindspore/graph_utils/graph_pattern.py | 31 ++++++++++++++++--- .../python_pass/python_pass_register.py | 24 +++++++++----- 2 files changed, 42 insertions(+), 13 deletions(-) diff --git a/src/mindspore2022/mindspore/python/mindspore/graph_utils/graph_pattern.py b/src/mindspore2022/mindspore/python/mindspore/graph_utils/graph_pattern.py index 32fcb19d..537c82cd 100644 --- a/src/mindspore2022/mindspore/python/mindspore/graph_utils/graph_pattern.py +++ b/src/mindspore2022/mindspore/python/mindspore/graph_utils/graph_pattern.py @@ -45,14 +45,15 @@ class OneOf(OneOf_): TypeError: raise type error for invalid inputs. """ self.patterns = patterns + # 检查 patterns 是否是 Pattern 类的实例 if isinstance(patterns, Pattern): OneOf_.__init__(self, [patterns]) + # 检查 patterns 是否是 tuple 或 list 类型,并且其中所有元素都是 Pattern 类的实例 elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns): OneOf_.__init__(self, patterns) + # 如果 patterns 不符合上述两种情况,则抛出 TypeError 异常 else: raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}") - - class Prim(Prim_): r""" Express a pattern of certain primitive type(s). @@ -76,25 +77,33 @@ class Prim(Prim_): Raises: TypeError: raise type error for invalid argument. """ + # 检查name是否为字符串类型,如果不是则抛出TypeError if name is not None and not isinstance(name, str): raise TypeError(f"Expect string, got : {name}") self.name = name + # 如果types是字符串类型,则将其按'|'分割成列表 if isinstance(types, str): if self.name is None: self.name = types self.types = types.split('|') + # 如果types是Primitive类型,则直接将其放入列表中 elif isinstance(types, Primitive): if self.name is None: self.name = types.name self.types = [types] + # 如果 types 是元组或列表,并且其中所有元素都是 Primitive 类型 elif isinstance(types, (tuple, list)) and all(isinstance(tp, Primitive) for tp in types): + # 如果 self.name 为 None,则初始化为空字符串并拼接所有 Primitive 的 name if self.name is None: self.name = "" for prim in types: self.name += prim.name + # 设置 self.types 为传入的 types self.types = types + # 如果 types 不符合预期类型,抛出 TypeError else: raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}") + # 调用基类 Prim_ 的初始化方法,传入 self.types 和 self.name Prim_.__init__(self, self.types, self.name) @@ -115,16 +124,22 @@ class Call(Call_): Raises: TypeError: raise type error for invalid argument. """ + # 检查 prim_pattern 是否为 Pattern, str 或 Primitive 类型,如果不是则抛出 TypeError if not isinstance(prim_pattern, (Pattern, str, Primitive)): raise TypeError(f"Expect prim_pattern to be Pattern, Primitive or string, got : {prim_pattern}") self.prim_pattern = prim_pattern + # 初始化 inputs 列表 self.inputs = [] + # 如果 inputs 为 None,则不做任何操作 if inputs is None: pass + # 如果 inputs 是 tuple 或 list 并且其中所有元素都是 Pattern 类型,则赋值给 self.inputs elif isinstance(inputs, (tuple, list)) and all(isinstance(input, Pattern) for input in inputs): self.inputs = inputs + # 如果 inputs 不符合上述条件,则抛出 TypeError else: raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}") + # 调用父类 Call_ 的初始化方法,传入 self.prim_pattern 和 self.inputs Call_.__init__(self, self.prim_pattern, self.inputs) @@ -145,6 +160,7 @@ class NoneOf(NoneOf_): TypeError: raise type error for invalid argument. """ self.patterns = patterns + # 根据 patterns 的类型初始化 NoneOf_ 类 if patterns is None: NoneOf_.__init__(self, ()) elif isinstance(patterns, Pattern): @@ -154,7 +170,6 @@ class NoneOf(NoneOf_): else: raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}") - class NewTensor(NewTensor_): r""" New Tensor to be used in the target. @@ -167,13 +182,16 @@ class NewTensor(NewTensor_): Raises: TypeError: raise type error for invalid argument. """ + # 初始化输入张量 self.input_tensor = input_tensor + # 检查输入是否为 Tensor 类型 if isinstance(input_tensor, Tensor): + # 如果是 Tensor 类型,则调用 NewTensor_ 的初始化方法 NewTensor_.__init__(self, input_tensor) else: + # 如果不是 Tensor 类型,则抛出 TypeError 异常 raise TypeError(f"Expect input_tensor to be a Tensor, got : {input_tensor}") - class NewParameter(NewParameter_): r""" New Parameter to be used in the target. @@ -193,11 +211,14 @@ class NewParameter(NewParameter_): self.default_tensor = default_tensor self.requires_grad = requires_grad self.layerwise_parallel = layerwise_parallel + # 检查参数类型是否符合预期 if isinstance(para_name, str) and isinstance(default_tensor, Tensor) and isinstance(requires_grad, bool) and\ isinstance(layerwise_parallel, bool): + # 初始化 NewParameter_ 类 NewParameter_.__init__(self, self.para_name, self.default_tensor, self.requires_grad, self.layerwise_parallel) else: + # 如果参数类型不符合预期,抛出 TypeError raise TypeError(f"Expect para_name(str), default_tensor(Tensor), requires_grad(bool), \ layerwise_parallel(bool), got : {para_name}, {default_tensor}, \ - {requires_grad}, {layerwise_parallel}") + {requires_grad}, {layerwise_parallel}") \ No newline at end of file diff --git a/src/mindspore2022/mindspore/python/mindspore/graph_utils/python_pass/python_pass_register.py b/src/mindspore2022/mindspore/python/mindspore/graph_utils/python_pass/python_pass_register.py index 445c36a9..2266248a 100644 --- a/src/mindspore2022/mindspore/python/mindspore/graph_utils/python_pass/python_pass_register.py +++ b/src/mindspore2022/mindspore/python/mindspore/graph_utils/python_pass/python_pass_register.py @@ -39,6 +39,7 @@ class PyPassManager(PyPassManager_): TypeError: If argument has invalid type. """ def __init__(self, requires_grad=True, run_only_once=False): + # 初始化方法,接收两个布尔参数,设置实例的属性并调用父类的初始化方法 if not isinstance(requires_grad, bool): raise TypeError(f"Expect bool, got : ({type(requires_grad)}){requires_grad}") if not isinstance(run_only_once, bool): @@ -48,17 +49,20 @@ class PyPassManager(PyPassManager_): PyPassManager_.__init__(self) def register(self, py_pass): + # 注册一个Python pass,检查其是否为函数类型,并获取其模式和目标 if not isfunction(py_pass): raise TypeError(f"Expect function pass, got : ({type(py_pass)}){py_pass}") pattern, target = py_pass() pass_name = py_pass.__name__ + # 检查模式和目标是否为Pattern类型 if not isinstance(pattern, Pattern): raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}") if not isinstance(target, Pattern): raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}") + # 调用父类的register方法,注册pass及其相关信息 super().register(pass_name, pattern, target, self.requires_grad, self.run_only_once_) - def unregister(self, py_pass): + # 从注册表中移除指定的Python传递对象,可以是字符串形式的名称或函数对象 if isinstance(py_pass, str): super().unregister(py_pass) return @@ -66,27 +70,30 @@ class PyPassManager(PyPassManager_): super().unregister(py_pass.__name__) return raise TypeError(f"Expect py_pass to be string or function, got ({type(py_pass)}){py_pass}") - + def __call__(self, py_pass): + # 将Python传递对象注册到注册表中,并返回该对象 self.register(py_pass) return py_pass - + def gen_new_parameter(self, pattern): + # 根据给定的模式生成新的参数,模式必须是NewParameter类型 if not isinstance(pattern, NewParameter): raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}") super().gen_new_parameter(pattern) - + def set_renorm(self, should_renorm): + # 设置是否进行重归一化操作,参数必须是布尔值 if not isinstance(should_renorm, bool): raise TypeError(f"Expect should_renorm to be a bool, got {should_renorm}") super().set_renorm(should_renorm) - + def set_reopt(self, do_reopt): + # 设置是否进行重新优化操作,参数必须是布尔值 if not isinstance(do_reopt, bool): raise TypeError(f"Expect do_reopt to be a bool, got {do_reopt}") super().set_reopt(do_reopt) - def register_pass(requires_grad=True, run_only_once=False): """ Register python pass to specified pipeline phase which would be used in compilation. @@ -165,12 +172,13 @@ def cancel_new_parameter(pattern): >>> # some compilations >>> cancel_new_parameter(abc) """ + # 检查传入的pattern是否为NewParameter的实例 if not isinstance(pattern, NewParameter): raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}") + # 创建一个PyPassManager对象 ppm = PyPassManager() + # 从PyPassManager中注销指定名称的参数 ppm.unregister(pattern.para_name) - - def set_renorm(should_renorm): """ Set whether or not to do renormalization after modified graph in python pass(es).