From ffdf6162c789956e5a8cf33b386fb055d8bb7482 Mon Sep 17 00:00:00 2001 From: yixin <2050485123@qq.com> Date: Wed, 25 Dec 2024 16:41:30 +0800 Subject: [PATCH] _extends\graph_kernel\model --- .../_extends/graph_kernel/model/__init__.py | 3 + .../graph_kernel/model/graph_parallel.py | 61 ++ .../graph_kernel/model/graph_split.py | 675 +++++++++++++++++- .../_extends/graph_kernel/model/model.py | 154 ++++ .../graph_kernel/model/model_builder.py | 109 ++- .../_extends/graph_kernel/model/op_infer.py | 128 ++++ 6 files changed, 1091 insertions(+), 39 deletions(-) diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/__init__.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/__init__.py index 8125a8b1..b9f7be45 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/__init__.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/__init__.py @@ -14,6 +14,9 @@ # =========================================================================== """GraphKernel cost model init""" +# 导入split模块 from .graph_split import split +# 导入GraphBuilder和load_composite模块 from .model_builder import GraphBuilder, load_composite +# 导入parallel_estimate模块 from .graph_parallel import parallel_estimate diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/graph_parallel.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/graph_parallel.py index 7e69da3e..fdf90238 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/graph_parallel.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/graph_parallel.py @@ -20,42 +20,103 @@ class ParalGain: """Paral Gain""" def __init__(self, fusion_type, bottleneck, gain, block_assign, type_info): + """ + 类的构造函数。 + + Args: + fusion_type (str): 融合类型。 + bottleneck (int): 瓶颈层的大小。 + gain (float): 增益值。 + block_assign (list): 块分配列表。 + type_info (dict): 类型信息字典。 + + Returns: + None + """ + # 初始化融合类型 self.fusion_type = fusion_type + # 初始化瓶颈层 self.bottleneck = bottleneck + # 初始化增益 self.gain = gain + # 初始化块分配 self.block_assign = block_assign + # 初始化类型信息 self.type_info = type_info class ScheduleAnalyzer: """schedule analyzer""" + # 定义一个常量,表示wrap的大小 WRAP_SIZE = 32 + # 定义一个常量,表示最大SM数量 MAX_SM = 80 # Volta + # 定义一个常量,表示最大线程数量 MAX_NUM_THREADS = 1024 + # 定义一个常量,表示最大block数量 MAX_BLOCK = 256 + # 定义一个常量,表示流水线操作的阈值 PIPELINE_OP_THREADHOLD = 5 def __init__(self, graph): + """ + 初始化图处理类。 + + Args: + graph (Graph): 图对象,用于存储图的结构和参数。 + + Attributes: + graph (Graph): 图对象,存储图的结构和参数。 + block_num (int): 块的数量,初始值为0。 + block_weight (float): 块的权重,初始值为0。 + ops (List[Operation]): 图的操作列表。 + dom_op (List[Operation]): 输出的每个操作对应的操作列表。 + + """ + # 将传入的图对象赋值给实例变量graph self.graph = graph + # 初始化block数量为0 self.block_num = 0 + # 初始化block权重为0 self.block_weight = 0 + # 通过图对象的deduce_parameters方法获取参数,并赋值给outputs变量 _, outputs = graph.deduce_parameters() + # 将图对象的操作列表赋值给实例变量ops self.ops = graph.ops + # 将outputs中的每个输出对应的操作收集到一个列表中,并赋值给实例变量dom_op self.dom_op = list(out.op for out in outputs) @staticmethod def prod(shape): + """ + 计算形状乘积。 + + Args: + shape (list): 一个包含整数的列表,表示形状。 + + Returns: + int: 形状乘积的结果。 + + """ """Compute shape product""" + # 初始化结果变量为shape的第一个元素 res = shape[0] + # 遍历shape列表,从第二个元素开始 for i in range(1, len(shape)): + # 将当前结果与shape的下一个元素相乘 res = res * shape[i] + # 返回计算后的结果 return res def _cal_weight(self, ops): + # 初始化权重为0 weight = 0 for op in ops: + # 计算当前操作的权重 weight += self.prod(op.output.shape) * \ + 
               PrimLib.dtype_bytes(op.output.dtype)  # bytes per element of the output dtype
-                PrimLib.dtype_bytes(op.output.dtype)
+        # return the accumulated weight
         return weight

     def injective_analyze(self):
diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/graph_split.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/graph_split.py
index 4f6c9778..3dea775d 100644
--- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/graph_split.py
+++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/graph_split.py
@@ -21,33 +21,98 @@ from .model import DataFormat as DF

 def tensor_size(tensor):
+    """
+    Get the total element count of a tensor.
+
+    Args:
+        tensor (Tensor): the input tensor.
+
+    Returns:
+        int: the total element count of the tensor.
+    """
     """get tensor size"""
+    # start from 1
     size = 1
+    # walk over the tensor shape
     for i in tensor.shape:
+        # multiply the current dimension into the total
         size *= i
+    # return the total size
     return size

 def reduce_nums(ops):
+    """
+    Count the ops whose primitive name starts with 'Reduce'.
+
+    Args:
+        ops (List[Operator]): the operator list.
+
+    Returns:
+        int: the number of ops whose primitive starts with 'Reduce'.
+    """
     """get reduce nums"""
     count = 0
+    # walk over the operator list
     for op in ops:
+        # check whether the primitive name starts with 'Reduce'
         if op.prim.startswith('Reduce'):
+            # if so, increase the counter
             count += 1
+    # return the count
     return count

 def may_stitch(dom, a, r, stitch_axis_size, stitch_buffer_size):
+    """
+    Check whether `dom` and `a` can be stitched.
+
+    Args:
+        dom: the dominant area, i.e. the context of the stitch.
+        a: the candidate area to be stitched with `dom`.
+        r: the relation pattern between `dom` and `a`.
+        stitch_axis_size (int): the size threshold of the stitch axis.
+        stitch_buffer_size (int): the buffer size threshold that makes stitching worthwhile.
+
+    Returns:
+        bool: True if the two areas can be stitched, otherwise False.
+    """
     """check if can stitch"""

     def _same_stitch_axis(stitch_tensors, final_outs, stitch_axis_size):
+        """
+        Check whether the given tensors share the same stitch axis.
+
+        Args:
+            stitch_tensors (list of Tensor): the tensors to be stitched.
+            final_outs (list of Tensor): the final output tensors of `a`.
+            stitch_axis_size (int): the size threshold of the stitch axis.
+
+        Returns:
+            bool: True if all tensors share the same stitch axes, otherwise False.
+        """
         """does a and b have same stitch axis"""

         def _stitch_axis(shape, stitch_axis_size):
+            """
+            Collect the leading axes whose accumulated size reaches `stitch_axis_size`.
+
+            Args:
+                shape (list): the tensor shape.
+                stitch_axis_size (int): the size threshold of the stitch axis.
+
+            Returns:
+                list: the stitch axes.
+            """
             """get stitch axis"""
             stitchaxis = []
             size = 1
+            # accumulate leading dimensions until the size threshold is reached
             for i in shape:
                 size = size * i
                 stitchaxis.append(i)
                 if size >= stitch_axis_size:
@@ -56,25 +121,28 @@ def may_stitch(dom, a, r, stitch_axis_size, stitch_buffer_size):

         x = []
         x.extend(stitch_tensors)
         x.extend(final_outs)
+        # all tensors must share the stitch axes of the first one, otherwise give up
         stitch_axis_0 = _stitch_axis(x[0].shape, stitch_axis_size)
         for item in x:
             i_stitch_axis = _stitch_axis(item.shape, stitch_axis_size)
             if not i_stitch_axis or i_stitch_axis != stitch_axis_0:
                 return False
         return True

+    # stitch only when a's pattern is at most REDUCE, the relation is at most BROADCAST,
+    # and fusing `a` keeps the graph acyclic
     if a.pattern <= PrimLib.REDUCE and r <= PrimLib.BROADCAST and dom.check_acyclic(a):
+        # if `a` contains two or more reduce ops, do not stitch
         if reduce_nums(a.ops) >= 2:
             return False
+        # outputs produced by dom's ops
         dom_outs = set(op.output for op in dom.ops)
+        # inputs consumed by a's ops
         a_ins = set(op_input for op in a.ops for op_input in op.inputs)
+        # outputs produced by a's ops
         a_outs = set(op.output for op in a.ops)
+        # a's final outputs, i.e. outputs not consumed inside `a`
         a_final_outs = list(tensor for tensor in a_outs if tensor not in a_ins)
+        # dom outputs consumed by `a` are the tensors to be stitched
         stitch_tensors = list(tensor for tensor in dom_outs if tensor in a_ins)
         if not _same_stitch_axis(stitch_tensors, a_final_outs, stitch_axis_size):
             return False
+        # stitch only if some stitched tensor reaches the buffer size threshold
         return any((tensor_size(tensor) >= stitch_buffer_size for tensor in stitch_tensors))
     return False

@@ -83,52 +151,109 @@ class CommonPattern:

     @staticmethod
     def reshape(dom):
+        """
+        Fuse strategy for a reshape dom.
+
+        Args:
+            dom: the dom area whose pattern is RESHAPE.
+
+        Returns:
+            tuple: a tuple of two elements.
+                - list: a list containing the area with the minimum pattern.
+                - bool: whether the fusion is a forward fusion.
+        """
         """fuse strategy for reshape dom"""
+        # only handle doms whose pattern is PrimLib.RESHAPE
        if dom.pattern != PrimLib.RESHAPE:
            return []
+
+        # initialize the minimum area and the forward-fuse flag
        min_area, forward_fuse = None, False
+
+        # walk over dom's output relations
        for a, _ in dom.out_relations.items():
+            # a's pattern is at most BROADCAST, fusing keeps the graph acyclic,
+            # and a's pattern is smaller than the current minimum
            if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a) and \
                (min_area is None or a.pattern < min_area.pattern):
+                # take `a` as the new minimum area
                min_area = a
+
+        # walk over dom's input relations
        for a, _ in dom.in_relations.items():
+            # a's pattern is at most BROADCAST, fusing keeps the graph acyclic,
+            # the first input of dom's first op has a single consumer, `a` is not an output,
+            # and a's pattern is smaller than the current minimum
            if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \
                len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \
                (min_area is None or a.pattern < min_area.pattern):
+                # take `a` as the new minimum area and mark the fusion as forward
                min_area, forward_fuse = a, True
+
+        # return the minimum area and the forward-fuse flag if found, otherwise an empty list
        return ([min_area], forward_fuse) if min_area else []

     @staticmethod
     def elemwise_depth(dom):
+        """
+        Fuse strategy in depth for an elemwise dom.
+
+        Args:
+            dom: the dom area to be fused.
+
+        Returns:
+            tuple: a tuple of two elements; the first is the list of areas to fuse,
+                the second is a bool indicating a forward fusion.
+        """
         """fuse strategy in depth for elemwise dom"""
+        # dom must be ELEMWISE or BROADCAST and have exactly one input relation
        if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.in_relations) != 1:
            return []
+
+        # take dom's single input area and its relation
        a, r = list(dom.in_relations.items())[0]
+
+        # give up if a's pattern is above BROADCAST, `a` has more than one output relation,
+        # the relation is not ELEMWISE, or the output shapes differ
        if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or r != PrimLib.ELEMWISE or \
                a.dom_op().output.shape != dom.dom_op().output.shape:
            return []
+
+        # fuse `a` in the forward direction
        return [a], True

     @staticmethod
     def elemwise_width(dom):
         """fuse strategy in width for elemwise dom"""
+        # dom must be ELEMWISE or BROADCAST
        if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
            return []
+
        fused = []
+        # walk over dom's input relations
        for a, r in dom.in_relations.items():
+            # a's pattern is at most BROADCAST, the relation is ELEMWISE,
+            # fusing keeps the graph acyclic, and the output shapes match
            if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_acyclic(dom) and \
                a.dom_op().output.shape == dom.dom_op().output.shape:
+                # collect the candidate area
                fused.append(a)
+
+        # fuse the candidates in the forward direction
        return fused, True

     @staticmethod
     def assign(dom):
         """fuse strategy for assign dom"""
+        # dom must contain exactly one op and that op must be Assign
        if len(dom.ops) != 1 or dom.dom_op().prim != "Assign":
+            # otherwise there is nothing to fuse
            return []
+
+        # 
初始化一个空列表,用于存储融合后的结果 fused = [] + # 遍历dom的输入关系 for a, _ in dom.in_relations.items(): + # 将输入关系中的a添加到fused列表中 fused.append(a) + # 返回融合后的结果和表示成功融合的布尔值True return fused, True @@ -139,37 +264,56 @@ class GraphSplitByPattern: """Reachable table""" def __init__(self, size): + # 初始化一个空的列表,用于存储地图状态 self.map = [] + # 初始化一个集合,包含所有活细胞的位置 self.alive = set(range(size)) + # 遍历从0到size-1的每一个数字 for i in range(0, size): + # 在map列表中添加一个新的列表,包含size个False值 self.map.append([False] * size) + # 将对角线位置设置为True,表示这些位置是活细胞 self.map[i][i] = True def reachable(self, x, y): """reachable from x to y""" + # 返回地图中位置 (x, y) 的可达性 return self.map[x][y] def sync(self, x, y): - """sync from y to x""" + """ + 从 y 同步到 x + """ + # 遍历 alive 列表 for i in self.alive: + # 调用 _link 方法,将 y 中的元素链接到 x 中 self._link(self.map[y][i], x, i) def _link(self, cond, f, t): """link from `f` to `t`""" + # 如果条件为真 if cond: + # 在映射字典中将 f 到 t 的路径标记为 True self.map[f][t] = True def fuse(self, x, y): """fuse y to x""" for i in self.alive: + # i 是 y 的后继节点,将 x 的前驱节点连接到 i # i is the succeeding node of y, links the x's previous nodes to i if self.map[y][i] and not self.map[x][i]: + # 遍历 x 的前驱节点 for pre in self.alive: + # 将 x 的前驱节点连接到 i self._link(self.map[pre][x], pre, i) + # i 是 y 的前驱节点,将 i 连接到 x 的后继节点 # i is the previous node of y, link i to x's succeeding nodes if self.map[i][y] and not self.map[i][x]: + # 遍历 x 的后继节点 for suc in self.alive: + # 将 i 连接到 x 的后继节点 self._link(self.map[x][suc], i, suc) + # 从 self.alive 中移除 y self.alive.remove(y) class Area: @@ -181,92 +325,145 @@ class GraphSplitByPattern: """StitchInfo""" def __init__(self): + """ + 初始化函数。 + """ + # 初始化存储拼接操作的集合 self.stitch_ops = set() + # 初始化存储原子拼接操作的集合 self.stitch_atomic_ops = set() def has_stitch_op(self): """check stitch_op exists""" + # 检查是否存在 stitch_ops return self.stitch_ops or self.stitch_atomic_ops def __init__(self, init_op, is_output, unique_id, reach_tab, recompute_ops=None): + # 初始化模式类型 self.pattern = PrimLib.iter_type(init_op) if init_op is not None else PrimLib.UNKNOWN + # 初始化操作列表 self.ops = [] if init_op is None else [init_op] + # 初始化输入关系字典 self.in_relations = dict() # {area1: relation1, area2: relation2, ...} + # 初始化输出关系字典 self.out_relations = dict() # {area1: relation1, area2: relation2, ...} + # 初始化模式 self.mode = None + # 初始化缝合信息 self.stitch_info = self.StitchInfo() + # 初始化重计算操作列表 self.recompute_ops = [] if recompute_ops is None else recompute_ops + # 初始化原始操作映射 self.ori_op_map = {} + # 初始化是否重计算 self.is_recompute = False + # 初始化是否为输出 self.is_output = is_output + # 初始化输出排除集合 self.output_excluded = set() + # 如果模式是减少类型,则执行以下逻辑 if self.pattern == PrimLib.REDUCE: + # 定义用于收集减少排除的函数 def _gather_reduce_exclude(): + # 初始化递归栈 recursion_stack = [init_op] + # 循环遍历递归栈 while recursion_stack: + # 弹出栈顶操作 op = recursion_stack.pop() + # 遍历操作的输出到操作 for to in op.output.to_ops: + # 获取输出到操作的索引 idx = to.inputs.index(op.output) + # 如果关系大于元素级关系,则将该操作添加到输出排除集合 if self.get_relation(to, idx) > PrimLib.ELEMWISE: self.output_excluded.add(to) else: + # 否则,将该操作添加到递归栈中 recursion_stack.append(to) + # 调用收集减少排除的函数 _gather_reduce_exclude() + # 初始化唯一ID self.unique_id = unique_id + # 初始化可达表 self.reach_tab = reach_tab def __str__(self): + # 将self.ops中的每个op的output.name拼接成一个由'-'连接的字符串 return '<' + '-'.join((op.output.name for op in self.ops)) + '>' def __repr__(self): + # 返回对象的字符串表示 return str(self) @staticmethod def get_relation(op, i): """Get op relation""" + # 初始化关系为未知 relation = PrimLib.UNKNOWN + # 获取输入的关系 _, elem_relation = PrimLib.input_relation(op, i) + # 遍历每个关系 for r in elem_relation: + # 如果关系为空,则将关系更新为广播 if r is 
None: + # 更新关系为最大关系,若当前关系为UNKNOWN,则更新为BROADCAST relation = max(relation, PrimLib.BROADCAST) + # 如果当前关系大于已有关系,则更新关系 elif r > relation: relation = r return relation def link_input(self, area_map): """Link inputs""" + # 遍历self.ops[0].inputs中的每个元素 for i, t in enumerate(self.ops[0].inputs): + # 如果当前元素t的op不为空 if t.op is not None: + # 从area_map中获取当前元素t的op对应的area + # 并调用self.get_relation获取self.ops[0]和i的relation area, relation = area_map[t.op], self.get_relation(self.ops[0], i) + # 将area和relation保存到self.in_relations中 self.in_relations[area] = relation def link_output(self): """Link outputs""" + # 遍历输入区域与关系的字典 for input_area, r in self.in_relations.items(): + # 将当前对象添加到输入区域的输出关系中 input_area.out_relations[self] = r + # 遍历输出关系 for out, _ in self.out_relations.items(): + # 同步当前对象与输出对象的唯一标识符 self.reach_tab.sync(self.unique_id, out.unique_id) def update_stitch_info(self, stitch_info): """Update stitch info""" + # 如果stitch_info中存在stitch_ops if stitch_info.stitch_ops: + # 更新self.stitch_info中的stitch_ops self.stitch_info.stitch_ops.update(stitch_info.stitch_ops) + # 如果stitch_info中存在stitch_atomic_ops if stitch_info.stitch_atomic_ops: + # 更新self.stitch_info中的stitch_atomic_ops self.stitch_info.stitch_atomic_ops.update(stitch_info.stitch_atomic_ops) def fuse(self, area): - """Fuse `area` to `self`""" + """将`area`融合到`self`中""" def _update_relation(relations, a, r): + # 更新关系 relations[a] = max(r, relations[a]) if a in relations else r def _update_pattern(): + # 更新模式 if area.pattern > self.pattern: self.pattern = area.pattern if area in self.in_relations and self.in_relations.get(area) > self.pattern: self.pattern = self.in_relations.get(area) def _fuse_relation(self_relations, new_relations): + # 融合关系 for a, r in new_relations.items(): if a != self: _update_relation(self_relations, a, r) @@ -274,65 +471,94 @@ class GraphSplitByPattern: self_relations.pop(area) def _redirect_relation(rels): - """Replace `area` with `self` in relations""" + """在关系中用`self`替换`area`""" if area in rels: r = rels.pop(area) _update_relation(rels, self, r) + # 如果`area`需要重新计算 if area.is_recompute: self.cp_ops(area) + # 如果`self`的模式大于或等于`area`的模式 if self.pattern >= area.pattern: self.ops.extend(area.ops) else: self.ops = area.ops + self.ops + # 更新模式 _update_pattern() + # 融合输入关系 _fuse_relation(self.in_relations, area.in_relations) + # 融合输出关系 _fuse_relation(self.out_relations, area.out_relations) + # 更新输入关系的重定向 for a, _ in area.in_relations.items(): _redirect_relation(a.out_relations) + # 更新输出关系的重定向 for a, _ in area.out_relations.items(): _redirect_relation(a.in_relations) + # 如果`self`的模式大于PrimLib.RESHAPE if self.pattern > PrimLib.RESHAPE: self.mode = self.MODE_COMPOSITE + # 如果`area`是输出而`self`不是 if area.is_output and not self.is_output: self.is_output = True + # 如果`area`有排除的输出 if area.output_excluded: self.output_excluded.update(area.output_excluded) + # 更新拼接信息 self.update_stitch_info(area.stitch_info) + # 如果`area`不需要重新计算 if not area.is_recompute: self.reach_tab.fuse(self.unique_id, area.unique_id) + # 融合重新计算的操作 self.recompute_ops.extend(area.recompute_ops) def check_acyclic(self, to): """Check circle. 
It returns false if circle exists""" + # 遍历所有出边关系 for out, _ in self.out_relations.items(): + # 如果当前出边的节点不是目标节点,并且从当前出边节点到目标节点可达 if out != to and self.reach_tab.reachable(out.unique_id, to.unique_id): + # 存在环,返回False return False + # 不存在环,返回True return True def dom_op(self): - """Get dom op""" + """ + 获取dom操作 + """ + # 返回操作列表中的第一个操作 return self.ops[0] def reduce_out_exclude(self, area): """Check whether op is reduce_out_exclude """ + # 如果self.output_excluded为真 if self.output_excluded: + # 遍历self.output_excluded中的每个操作 for op in self.output_excluded: + # 如果操作在area.ops中存在 if op in area.ops: + # 返回True return True + # 如果没有找到符合条件的操作,返回False return False def cp_ops(self, area): """copy recompute_ops in area to ops, self is area's user""" tail_tensor = area.recompute_ops[-1].output + # 复制张量,所有复制的张量都是Tensor.PARA_NONE # copy tensors, all copied are Tensor.PARA_NONE tensor_map = {} if area.recompute_ops[0].inputs: + # 如果第一个操作的输入不为空,则将输入张量映射为自身 tensor_map[area.recompute_ops[0].inputs[0]] = area.recompute_ops[0].inputs[0] for op in area.recompute_ops: orig_tensor = op.output cp_tensor = Tensor(orig_tensor.name, orig_tensor.shape, orig_tensor.dtype, orig_tensor.data_format) tensor_map[orig_tensor] = cp_tensor + + # 复制操作 # copy ops cp_ops = [] for op in area.recompute_ops: @@ -341,6 +567,8 @@ class GraphSplitByPattern: cp_op.all_inputs = cp_op.inputs cp_ops.append(cp_op) area.ori_op_map[cp_op] = op + + # 连接复制的操作 # connect copied ops for op in self.ops: if tail_tensor in op.inputs: @@ -348,6 +576,8 @@ class GraphSplitByPattern: op.inputs.append(tensor_map.get(tail_tensor)) tail_tensor.to_ops.remove(op) tensor_map.get(tail_tensor).to_ops.append(op) + + # 将复制的操作填充到self.recompute_area中 # fill cp_ops in self.recompute_area cp_dom_op = None for cp, ori in area.ori_op_map.items(): @@ -358,60 +588,94 @@ class GraphSplitByPattern: area.ops.extend((op for op in cp_ops if op != cp_dom_op)) def __init__(self, graph, flags): + # 初始化图对象 self.graph = graph + # 初始化空区域列表 self.areas = [] + # 初始化标志对象 self.flags = flags + # 初始化是否启用重新计算融合的标志 self.enable_recompute = self.flags.get("enable_recompute_fusion", False) + # 初始化是否启用缝合融合的标志 self.enable_stitch_fusion = self.flags.get("enable_stitch_fusion", False) + # 初始化是否启用水平融合的标志 self.enable_horizontal_fusion = self.flags.get("enable_horizontal_fusion", False) + # 初始化可达表 self.reach_tab = self.ReachTable(len(graph.ops) + 1 if self.enable_recompute else len(graph.ops)) + # 初始化区域映射字典 self.area_map = {} + # 获取图的输出参数 _, outputs = graph.deduce_parameters() idx = 0 + # 遍历图中的所有操作 for op in graph.ops: + # 判断操作是否是输出操作 is_output = op.output in outputs + # 创建一个区域对象 a = self.Area(op, is_output, idx, self.reach_tab) idx += 1 + # 设置默认模式 self.set_default_mode(a) + # 将区域对象添加到区域列表中 self.areas.append(a) + # 设置区域映射 self.set_area_map([op], a) + # 遍历所有区域,设置输入链接 for a in self.areas: a.link_input(self.area_map) + # 从后往前遍历区域,设置输出链接 for i in range(len(self.areas) - 1, -1, -1): self.areas[i].link_output() + # 如果启用了重新计算融合 if self.enable_recompute: + # 创建一个用于重新计算的区域对象 self.recom_area = self.Area(None, False, idx, self.reach_tab) + # 设置重新计算标志 self.recom_area.is_recompute = True + # 初始化重新计算区域的前驱、用户和支配区域 self.recom_pre = None self.recom_user = None self.recom_dom = None + # 初始化重新计算区域的支配用户 self.dom_user_r = PrimLib.UNKNOWN + # 初始化重新计算结果标志 self.recom_res = False + # 初始化原始操作映射 self.orig_op_map = {} def set_area_map(self, ops, area): """update area_map after op fused to area""" + # 遍历操作列表 for op in ops: + # 将操作映射到指定的区域 self.area_map[op] = area def set_default_mode(self, area): """Set default mode""" + # 设置区域模式为默认模式 
area.mode = self.get_default_mode(area.ops[0]) @staticmethod def limit_area_size(dominant, fuse_areas, limit_size=200): """Remove some areas if the size is too large""" + # 计算每个区域的操作数大小 area_sizes = map(lambda area: len(area.ops), fuse_areas) + # 计算主要区域的操作数大小 dom_size = len(dominant.ops) + # 如果总操作数大小不超过限制大小,则返回原区域列表 if dom_size + prod_reduce(lambda x, y: x + y, area_sizes) <= limit_size: return fuse_areas + # 按操作数大小优先融合较小的区域 # fuse the smaller area in priority fuse_areas.sort(key=lambda area: len(area.ops)) new_fuse_areas = [] for area in fuse_areas: + # 如果加上当前区域后超过限制大小,则跳出循环 if dom_size + len(area.ops) > limit_size: break + # 累加当前区域的操作数到总操作数中 dom_size += len(area.ops) + # 将当前区域添加到新的区域列表中 new_fuse_areas.append(area) return new_fuse_areas @@ -419,174 +683,292 @@ class GraphSplitByPattern: """Fuse areas""" def _fuse_area(): + # 遍历所有区域 for dominant in self.areas: + # 使用选择器函数选择区域 result = selector(dominant) + # 如果选择器没有返回结果或结果为空,则跳过当前循环 if not result or not result[0]: continue + # 解包结果 fuse_areas, is_forward = result + # 限制融合区域的大小 fuse_areas = self.limit_area_size(dominant, fuse_areas) + # 如果没有融合区域,则跳过当前循环 if not fuse_areas: continue + # 判断融合方向 if is_forward: + # 如果是正向融合 for area in fuse_areas: + # 将当前区域与融合区域融合 dominant.fuse(area) + # 更新区域映射 self.set_area_map(area.ops, dominant) + # 从区域列表中移除融合区域 self.areas.remove(area) else: + # 如果是反向融合 forward_area = dominant for area in fuse_areas: + # 将当前区域与融合区域融合 area.fuse(forward_area) + # 更新区域映射 self.set_area_map(forward_area.ops, area) + # 从区域列表中移除当前区域 self.areas.remove(forward_area) + # 更新当前区域为下一个融合区域 forward_area = area + # 返回True表示有变化 return True + # 如果没有进行任何融合操作,则返回False return False changed, do_again = False, True while do_again: + # 执行融合操作 do_again = _fuse_area() + # 更新是否有变化的标志 changed = changed or do_again return changed def hfuse(self, selector): """Fuse horizontal areas with same input tensor""" + # 定义一个内部函数,用于执行融合操作 def _do_fuse(areas): + # 遍历所有区域,除了最后一个 for i in range(len(areas) - 1): dom = areas[i] + # 遍历剩余的区域 for a in areas[i + 1:]: + # 如果两个区域无环且满足selector函数,且融合后的区域大小不超过限制 if dom.check_acyclic(a) and a.check_acyclic(dom) and \ selector(dom, a) and self.limit_area_size(dom, [a], 64): + # 融合区域 dom.fuse(a) + # 更新区域映射 self.set_area_map(a.ops, dom) + # 从区域列表中移除已融合的区域 self.areas.remove(a) + # 返回True表示有变化 return True + # 如果没有发生融合,返回False return False + # 定义一个内部函数,用于更新区域列表 def _update_areas(areas, from_op): + # 遍历from_op的所有输出操作 for op in from_op.to_ops: + # 获取操作对应的区域 a = self.area_map.get(op) + # 如果区域存在且不在当前区域列表中,则添加到列表中 if a in self.areas and a not in areas: areas.append(a) + # 初始化变化标志 changed = False + # 不断循环,直到没有变化为止 while True: for dom in self.areas: + # 如果当前区域有多个输出关系,则尝试进行融合 if len(dom.out_relations) > 1 and _do_fuse(list(dom.out_relations.keys())): changed = True break + # 如果没有发生任何变化,则跳出循环 else: break + + # 获取输入参数 inputs, _ = self.graph.deduce_parameters() + # 不断循环,直到没有变化为止 while True: for t in inputs: + # 初始化区域列表 areas = [] + # 更新区域列表 _update_areas(areas, t) + # 如果区域列表中有多个区域,则尝试进行融合 if len(areas) > 1 and _do_fuse(areas): changed = True break + # 如果没有发生任何变化,则跳出循环 else: break + + # 返回是否有变化 return changed def fuse_recom(self, selector): """Fuse recompute area to its user""" + # 遍历主导区域数组 for dominant in [self.recom_area, self.recom_user]: + # 使用选择器函数处理主导区域 result = selector(dominant) + # 如果选择器函数返回结果且第一个元素为真 if result and result[0]: + # 解包结果,获取需要融合的区域和额外信息 fuse_areas, _ = result + # 对需要融合的区域进行大小限制 fuse_areas = self.limit_area_size(dominant, fuse_areas) + # 如果没有需要融合的区域,则跳过当前循环 if not fuse_areas: continue + # 如果需要融合的第一个区域是主导区域之一 if fuse_areas[0] in 
[self.recom_area, self.recom_user]: + # 将recom_area融合到recom_user self.recom_user.fuse(self.recom_area) + # 设置区域映射 self.set_area_map(self.recom_area.ops, self.recom_user) + # 设置融合成功的标志 self.recom_res = True + # 返回成功标志 return True + # 如果没有成功融合,则返回失败标志 return False def index_op(self): """index op by order, the copied op share id with original op, for topo-sort""" + # 创建一个空字典,用于存储操作及其对应的索引 ids = {} + + # 遍历图中的操作,并为每个操作分配一个索引 for i, op in enumerate(self.graph.ops): + # 将操作作为键,索引作为值存储到ids字典中 ids[op] = i + + # 如果存在orig_op_map属性 if hasattr(self, 'orig_op_map'): + # 遍历orig_op_map中的键值对 for k, v in self.orig_op_map.items(): + # 如果v在ids中存在,则将k的索引设置为v的索引 ids[k] = ids.get(v) + + # 返回ids字典 return ids def to_subgraphs(self): """Transform op groups to subgraphs""" + # 获取操作索引 ids = self.index_op() + # 初始化子图列表 subgraphs = [] + # 初始化图模式列表 graphmodes = [] + # 遍历区域列表 for i, area in enumerate(self.areas): + # 根据操作索引对区域中的操作进行排序 area.ops.sort(key=ids.get) + # 创建子图并添加到子图列表中 subgraphs.append(Graph('{}_{}'.format(self.graph.name, i), area.ops, area.stitch_info, area.recompute_ops)) + # 根据区域模式设置图模式并添加到图模式列表中 graphmodes.append("basic" if area.mode == self.Area.MODE_BASIC else "composite") return subgraphs, graphmodes def pattern_fuse(self, fuse_func=None): """fuse Areas by pattern repeatedly""" + # 删除fuse_func参数 del fuse_func + # 抛出异常,提示pattern_fuse()函数在当前类中未实现 raise Exception("pattern_fuse() is not implemented in {}".format(self.__class__.__name__)) def split(self): """Split graph by pattern""" + # 融合模式 self.pattern_fuse() + # 重新计算融合结果 self.recompute_fuse() + # 将图拆分为子图和模式 subgraphs, graphmodes = self.to_subgraphs() + # 返回子图和模式 return subgraphs, graphmodes def set_recompute(self, dom_area, ops, user_area): """set the recompute area and connect with other areas""" + # 将操作添加到recompute区域 self.recom_area.recompute_ops.extend(ops) # recom_area: set dom_op and correct ops length + # 设置dom_op并修正ops长度 patterns = list(PrimLib.iter_type(op) for op in ops) self.recom_area.pattern = max(patterns) for i, pat in enumerate(patterns): if pat == self.recom_area.pattern: + # 修正ops列表,使其长度与patterns相同 self.recom_area.ops = [ops[i]] * len(ops) break + # disconnect dom_area and user_area + # 断开dom_area和user_area的连接 self.dom_user_r = dom_area.out_relations[user_area] dom_area.out_relations.pop(user_area) user_area.in_relations.pop(dom_area) + # connect recom_area and user_area + # 连接recom_area和user_area user_area.in_relations[self.recom_area] = self.dom_user_r self.recom_area.out_relations[user_area] = self.dom_user_r + # connect recom_pre and recom_area + # 连接recom_pre和recom_area self.recom_pre = self.area_map.get(ops[0].inputs[0].op) if ops[0].inputs and ops[0].inputs[0].op else None if self.recom_pre is not None: self.recom_area.in_relations[self.recom_pre] = dom_area.in_relations[self.recom_pre] self.recom_pre.out_relations[self.recom_area] = dom_area.in_relations[self.recom_pre] + # set related areas + # 设置相关区域 self.recom_user = user_area self.recom_dom = dom_area self.recom_res = False def clear_recompute(self): """disconnect recom_area from other areas, and clear recom_area""" - self.recom_area.out_relations.clear() - self.recom_area.in_relations.clear() + # 断开recom_area与其他区域的连接 + self.recom_area.out_relations.clear() # 清除recom_area的输出关系 + self.recom_area.in_relations.clear() # 清除recom_area的输入关系 + + # 如果没有recom_res if not self.recom_res: + # 从recom_user的输入关系中移除recom_area self.recom_user.in_relations.pop(self.recom_area) + # 在recom_user的输入关系中添加recom_dom,并设置关系为dom_user_r self.recom_user.in_relations[self.recom_dom] = self.dom_user_r + # 
在recom_dom的输出关系中设置recom_user的关系为dom_user_r self.recom_dom.out_relations[self.recom_user] = self.dom_user_r + # 如果存在recom_pre if self.recom_pre: + # 从recom_pre的输出关系中移除recom_area self.recom_pre.out_relations.pop(self.recom_area) + + # 清除recom_area的操作 self.recom_area.ops.clear() + # 清除recom_area的重计算操作 self.recom_area.recompute_ops.clear() + # 将recom_area的原始操作映射更新到orig_op_map中 self.orig_op_map.update(self.recom_area.ori_op_map) + # 清除recom_area的原始操作映射 self.recom_area.ori_op_map.clear() def to_subgraph(self, dom): """Transform area to subgraphs""" + # 获取索引操作符 ids = self.index_op() + + # 初始化一个空列表用于存储操作 dom_ops = list() + + # 将dom的ops属性扩展到dom_ops列表中 dom_ops.extend(dom.ops) + + # 根据ids的get方法对dom_ops进行排序 dom_ops.sort(key=ids.get) + + # 初始化一个空列表用于存储子图 subgraph = [] + + # 使用dom_ops和指定的图名创建一个Graph对象,并将其赋值给subgraph subgraph = Graph('{}_area'.format(self.graph.name), dom_ops) + + # 返回创建好的子图 return subgraph def find_cheap_regions(self, dom): @@ -595,44 +977,57 @@ class GraphSplitByPattern: def _grow_region(region_ops, op, weight, inputs): """include op to region_ops if region grow""" # region successfully ends at inputs + # 如果op没有输入,则将其添加到region_ops中,并返回False表示区域不再增长 if not op.inputs: region_ops.append(op) return False, op, weight, True + # 如果op的输入是inputs中的一个,且只有一个输入,且op的类型小于等于BROADCAST,则将其添加到region_ops中 if op.inputs[0] in inputs and len(op.inputs) == 1 and \ PrimLib.iter_type(op) <= PrimLib.BROADCAST: region_ops.append(op) return False, op, weight, True # region fails to grow + # 如果weight大于20,或op的输入个数大于1,或op的类型大于BROADCAST,则区域增长失败 max_weight = 20 if weight > max_weight or len(op.inputs) > 1 or PrimLib.iter_type(op) > PrimLib.BROADCAST: return False, op, weight, False # region grows successfully + # 如果区域成功增长,则增加weight,并将op添加到region_ops中 weight = weight + 1 region_ops.append(op) return True, op.inputs[0].op, weight, False def _find_cheap_regions(dom): + # 从dom区域中提取子图 sub = self.to_subgraph(dom) + # 推导子图的输入和输出参数 inputs, outputs = sub.deduce_parameters() + # 如果没有输入,则返回空列表 if not inputs: return list() cheap_regions = [] for output in outputs: - # tensor should have user other than user_area to be fused + # tensor should have user other than user_area to be fused + # 如果输出的操作数少于2,则跳过 if len(output.to_ops) < 2: continue + # 初始化区域操作列表和区域增长标志 region_ops = [] grow = True candidate_op = output.op weight = 1 result = False + # 当区域增长时,不断尝试扩展区域 while grow: grow, candidate_op, weight, result = _grow_region(region_ops, candidate_op, weight, inputs) + # 如果区域成功扩展,则反转操作列表,并检查是否满足区域大小条件 if result: region_ops.reverse() # tensor size should equal or becomes larger(cast up, broadcast) + # 如果区域的第一个操作的输入张量大小大于最后一个操作的输出张量大小,则跳过 if region_ops[0].inputs and region_ops[0].inputs[0].get_size() > region_ops[-1].output.get_size(): continue + # 将满足条件的区域添加到cheap_regions列表中 cheap_regions.append(region_ops) return cheap_regions @@ -642,7 +1037,9 @@ class GraphSplitByPattern: """select the user area has only one edge to dom area""" def _get_edge_num(dom_area, user_area): - """get edge num between two areas""" + """ + 获取两个区域之间的边数 + """ dom_graph = self.to_subgraph(dom_area) _, dom_outputs = dom_graph.deduce_parameters() user_graph = self.to_subgraph(user_area) @@ -650,13 +1047,18 @@ class GraphSplitByPattern: return len(list(t for t in dom_outputs if t in user_inputs)) def _select_user_area(tail_tensor): + """ + 选择只有一个边到dom区域的user区域 + """ user_areas = [] for user_op in tail_tensor.to_ops: user_area = self.area_map.get(user_op) if user_area.pattern == PrimLib.RESHAPE: + # 如果user区域是reshape操作,则跳过 continue edge_num = 
_get_edge_num(self.area_map.get(tail_tensor.op), user_area) if edge_num == 1 and not user_area in user_areas: + # 如果edge数为1且user区域不在user_areas中,则添加到user_areas user_areas.append(user_area) return user_areas @@ -669,14 +1071,20 @@ class GraphSplitByPattern: """split the unfusing pattern by add recompute area""" def recompute_cheap_region(dom): + # 对每个廉价区域进行处理 for cheap_region in cheap_regions: + # 获取用户区域 user_areas = self.select_user_area(cheap_region[-1].output) if not user_areas: continue for user_area in user_areas: + # 设置重新计算区域 self.set_recompute(dom, cheap_region, user_area) + # 融合模式 self.pattern_fuse(self.fuse_recom) + # 清除重新计算区域 self.clear_recompute() + # 如果重新计算结果有效 if self.recom_res: return True return False @@ -687,13 +1095,17 @@ class GraphSplitByPattern: for dom in orig_areas: if dom not in self.areas or not dom.out_relations: continue + # 找到廉价区域 cheap_regions = self.find_cheap_regions(dom) + # 对当前区域进行廉价区域重新计算 if recompute_cheap_region(dom): recompute_suc = True return recompute_suc + # 如果启用了重新计算 if self.enable_recompute: while do_recompute_fuse(): + # 融合模式 self.pattern_fuse() @@ -705,10 +1117,14 @@ class GraphSplitGpu(GraphSplitByPattern): def get_default_mode(self, op): """Get default mode in GPU""" + # 判断操作是否为矩阵乘法 if op.prim == "MatMul": + # 如果是矩阵乘法,并且输入数据类型为float16且属性Akg存在,则返回MODE_COMPOSITE模式,否则返回MODE_BASIC模式 return self.Area.MODE_COMPOSITE if op.inputs[0].dtype == "float16" and op.attrs['Akg'] else \ self.Area.MODE_BASIC + # 获取操作的迭代类型 pattern = PrimLib.iter_type(op) + # 根据迭代类型返回相应的模式 return self.Area.MODE_BASIC if pattern == PrimLib.RESHAPE else self.Area.MODE_COMPOSITE def pattern_fuse(self, fuse_func=None): @@ -720,128 +1136,213 @@ class GraphSplitGpu(GraphSplitByPattern): return a.pattern > PrimLib.REDUCE or r > PrimLib.BROADCAST def _broadcast_depth(dom): + # 如果dom的模式不是Elemwise或Broadcast,或者输出关系数量不为1,或者dom是输出,或者操作数超过BROADCAST_FUSE_DEPTH,则返回空列表 if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1 or \ dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH: return [] + + # 获取dom的输出关系中的第一个元素及其对应的关系 a, r = list(dom.out_relations.items())[0] + + # 如果满足广播排除条件,或者a的输入关系数量不为1,则返回空列表 if _broadcast_pat_exclude(dom, a, r) or len(a.in_relations) != 1: return [] + + # 返回包含a的列表和False return [a], False def _broadcast_width(dom): + # 如果dom的模式不是PrimLib.ELEMWISE或PrimLib.BROADCAST,或者dom是输出节点,或者dom的操作数深度超过BROADCAST_FUSE_DEPTH,则返回空列表 if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \ dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH: return [] - fused = [] + fused = [] # 初始化一个空列表,用于存储符合条件的节点 + + # 遍历dom的输出关系 for a, r in dom.out_relations.items(): + # 如果满足以下任一条件,则返回空列表 if _broadcast_pat_exclude(dom, a, r) or not dom.check_acyclic(a) or \ (fused and fused[0].dom_op().output.shape != a.dom_op().output.shape): return [] + # 否则,将节点a添加到fused列表中 fused.append(a) + + # 返回fused列表和一个False值 return fused, False def _reduce_pat_exclude(_, a, r): + # 如果操作数a的操作数量大于设定的减少融合深度,则返回True if len(a.ops) > self.REDUCE_FUSE_DEPTH: return True + # 如果a的模式大于基本库中的元素级操作,或者r大于基本库中的减少操作,或者r等于基本库中的广播操作,则返回True return a.pattern > PrimLib.ELEMWISE or r > PrimLib.REDUCE or r == PrimLib.BROADCAST def _reduce_depth(dom): + # 检查dom的pattern是否为REDUCE,且dom的in_relations长度是否为1 if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1: + # 如果不是,则返回空列表 return [] + + # 获取dom的in_relations的第一个元素及其关系 a, r = list(dom.in_relations.items())[0] + + # 检查特定条件,如果满足,则返回空列表 if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \ 
_is_atomic_add_available(dom): # to evade the precision problem. + # 为避免精度问题 return [] + + # 检查是否满足特定条件或a的out_relations长度不为1,如果满足,则返回空列表 if _reduce_pat_exclude(dom, a, r) or len(a.out_relations) != 1: return [] + + # 返回包含a的列表和布尔值True return [a], True def _reduce_width(dom): + # 判断dom的模式是否为REDUCE if dom.pattern != PrimLib.REDUCE: return [] + fused = [] for a, r in dom.in_relations.items(): + # 判断dom的第一个操作的输入数据类型是否为float16,且a为输出,a的操作数大于等于10,且满足_is_atomic_add_available条件 if dom.ops[0].inputs[0].dtype == "float16" and a.is_output and len(a.ops) >= 10 and \ _is_atomic_add_available(dom): + # 跳过,以避免精度问题 # to evade the precision problem. continue + + # 判断是否满足_reduce_pat_exclude条件,且a在dom中是无环的 if not _reduce_pat_exclude(dom, a, r) and a.check_acyclic(dom): + # 将满足条件的a添加到fused列表中 fused.append(a) + return fused, True def _is_atomic_add_available(dom): + # 检查是否存在Reduce操作 if any(("Reduce" in x.prim for x in dom.ops[1:])): return False + + # 获取第一个操作 op = dom.ops[0] + + # 检查是否有reduce_axis属性 if "reduce_axis" in op.attrs: reduce_axis = op.attrs["reduce_axis"] + # 检查是否有axis属性 elif "axis" in op.attrs: reduce_axis = [op.attrs["axis"]] else: + # 如果以上属性都不存在,抛出异常 raise Exception("For '{}', can not find the attr 'reduce_axis' or 'axis'".format(op.prim)) + + # 检查reduce_axis中是否包含输入张量的倒数第二个维度 if len(op.inputs[0].shape) - 1 in reduce_axis: + # 计算reduce操作的大小 reduce_size = prod_reduce(lambda x, y: x * y, (op.inputs[0].shape[i] for i in reduce_axis)) + # 判断reduce操作的大小是否大于等于1024 return reduce_size >= 1024 + # 如果没有包含倒数第二个维度,返回True return True def _reduce_output(dom): + # 判断dom的模式是否为REDUCE if dom.pattern != PrimLib.REDUCE: return [] + # 判断reduce操作的次数是否大于1 if reduce_nums(dom.ops) > 1: return [] + # 判断是否为原子加法可用 if _is_atomic_add_available(dom): return [] + # 判断是否为全归约 is_all_reduce = tensor_size(dom.ops[0].output) == 1 + # 排除大尺寸的全归约 # excluded large size all reduce if is_all_reduce and tensor_size(dom.ops[0].inputs[0]) > 1024 * 12: return [] fused = [] for a, r in dom.out_relations.items(): + # 判断a的模式和r的模式是否都小于等于BROADCAST,且dom的a节点无环,且a不在reduce_out_exclude列表中 if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \ dom.check_acyclic(a) and not dom.reduce_out_exclude(a): fused.append(a) return fused, False def _reduce_stitch(dom): + # 如果dom的模式不是PrimLib.REDUCE,则返回空列表 if dom.pattern != PrimLib.REDUCE: return [] + + # 如果dom的第一个操作的输出张量大小为1,则返回空列表 if tensor_size(dom.ops[0].output) == 1: return [] + + # 如果dom的第一个操作的第一个输入张量大小小于1024 * 12,则返回空列表 if tensor_size(dom.ops[0].inputs[0]) < 1024 * 12: return [] fused = [] + # 遍历dom的输出关系 for a, r in dom.out_relations.items(): + # 如果不满足拼接条件,则跳过当前循环 if not may_stitch(dom, a, r, 1024 * 8, 1024 * 1024): continue + + # 如果a的模式是PrimLib.REDUCE if a.pattern == PrimLib.REDUCE: + # 如果a和dom的reduce轴相同 if a.ops[0].attrs['reduce_axis'] == dom.ops[0].attrs['reduce_axis']: + # 将dom的输出名称添加到stitch_ops集合中 dom.stitch_info.stitch_ops.add(dom.ops[0].output.name) + # 将a添加到fused列表中 fused.append(a) + # 如果a的模式是PrimLib.BROADCAST elif a.pattern == PrimLib.BROADCAST: + # 将dom的输出名称添加到stitch_ops集合中 dom.stitch_info.stitch_ops.add(dom.ops[0].output.name) + # 将a添加到fused列表中 fused.append(a) + + # 返回fused列表和False return fused, False def _transpose(dom): + # 如果dom的操作数不为1或者第一个操作不是"Transpose",则返回空列表 if len(dom.ops) != 1 or dom.ops[0].prim != "Transpose": return [] fused = [] + + # 遍历dom的输入关系 for a, _ in dom.in_relations.items(): + # 如果输入的操作模式小于等于PrimLib.BROADCAST, + # 并且输入的操作是无环的,并且输入的操作数小于等于TRANSPOSE_FUSE_DEPTH,则将其添加到fused列表中 if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and len(a.ops) <= 
self.TRANSPOSE_FUSE_DEPTH: fused.append(a) + # 返回fused列表和一个表示成功的布尔值 return fused, True def _strided_slice(dom): + # 判断操作是否为 StridedSlice if dom.dom_op().prim != "StridedSlice": return [] + + # 初始化 fused 列表 fused = [] + + # 遍历 dom 的输入关系 for a, _ in dom.in_relations.items(): + # 判断输入节点是否满足条件 if a.pattern <= PrimLib.BROADCAST and a.check_acyclic(dom) and \ len(a.out_relations) == 1 and not a.is_output: + # 如果满足条件,则将其添加到 fused 列表中 fused.append(a) + + # 返回 fused 列表和布尔值 True return fused, True def _gather_output(dom, reduce_fusion=False): @@ -860,37 +1361,52 @@ class GraphSplitGpu(GraphSplitByPattern): Returns: Boolean. Whether this operator should be excluded. """ + # 获取reduce操作符的reduce轴 axis = op.attrs["reduce_axis"] + # 如果reduce轴是整数,则将其转换为列表 if isinstance(axis, int): axis = [axis] + # 获取输入数据的形状长度 in_shape_len = len(op.inputs[0].shape) + # 对每个reduce轴进行转换,如果是负数则加上输入形状的长度 for i, dim in enumerate(axis): + # 如果dim是负数,则将其转换为正数 axis[i] = in_shape_len + dim if dim < 0 else dim + + # 初始化一个空列表,用于存储经过过滤的reduce轴 fix_axis = [] + # 遍历每个reduce轴 for ax in axis: + # 如果当前轴的长度为1,则跳过 if op.inputs[0].shape[ax] == 1: continue + # 否则,将当前轴添加到fix_axis列表中 fix_axis.append(ax) + + # 返回fix_axis和axis_list的交集是否非空 return bool(set(fix_axis) & set(axis_list)) def _bfs_visit(start_op, start_prims, total_ops, end_ops, gather_axis): - consisten_shape = start_op.output.shape - visited = [] - op_queue = [start_op] + consisten_shape = start_op.output.shape # 获取起始操作的输出形状 + visited = [] # 用于记录已经访问过的操作 + op_queue = [start_op] # 初始化操作队列,包含起始操作 def _early_stop(cur_op): + # 提前停止函数 if cur_op in end_ops: - # If reduce the gather axis, stop early for not fusion. + # 如果当前操作在结束操作中 + # 如果减少聚合轴,则不融合,提前停止 if cur_op.prim == "ReduceSum" and _reduce_exclude(cur_op, gather_axis): return True else: + # 如果当前操作不在起始操作中或形状不一致 if (cur_op.prim in start_prims and cur_op != start_op) or \ consisten_shape != cur_op.output.shape: return True return False while op_queue: - tmp_queue = [] + tmp_queue = [] # 临时队列,用于存储下一层的操作 for op in op_queue: if op in visited or not op in total_ops: continue @@ -899,9 +1415,9 @@ class GraphSplitGpu(GraphSplitByPattern): if op in end_ops: continue for to_op in op.output.to_ops: - tmp_queue.append(to_op) - visited.append(op) - op_queue = tmp_queue + tmp_queue.append(to_op) # 将当前操作的输出操作添加到临时队列中 + visited.append(op) # 将当前操作标记为已访问 + op_queue = tmp_queue # 更新操作队列为临时队列 return True def _shape_consistent(start_prims, end_prims, source, target): @@ -911,30 +1427,41 @@ class GraphSplitGpu(GraphSplitByPattern): When fusing ReduceSum, first check if TensorScatterAdd and/or UnsortedSegmentSum has already been fused, if so, stop ReduceSum fusion. 
""" + # 获取source和target的所有操作 total_ops = source.ops + target.ops + # 获取所有操作的prim集合 op_prims_set = {op.prim for op in total_ops} + # 如果需要融合ReduceSum,并且TensorScatterAdd和/或UnsortedSegmentSum已经被融合,则不融合ReduceSum if reduce_fusion and (len({"TensorScatterAdd", "UnsortedSegmentSum"} & op_prims_set) >= 1): return False + + # 获取source的起始操作 start_ops = [] for op in source.ops: if op.prim in start_prims: start_ops.append(op) + + # 获取total_ops中的结束操作 end_ops = [] for op in total_ops: if op.prim in end_prims and not any((to_op in total_ops for to_op in op.output.to_ops)): end_ops.append(op) + # 遍历start_ops中的每一个操作 for start_op in start_ops: + # 获取操作的gather_axis属性 gather_axis = start_op.attrs.get("axis", None) if gather_axis is None: - # For GatherNd + # 对于GatherNd gather_axis = list(range(len(start_op.inputs[1].shape))) elif isinstance(gather_axis, int): gather_axis = [gather_axis] + # 调用_bfs_visit函数检查形状一致性 is_consistent = _bfs_visit(start_op, start_prims, total_ops, end_ops, gather_axis) if not is_consistent: return False + return True appected_areas = {"ReduceSum"} if reduce_fusion else {"TensorScatterAdd", "UnsortedSegmentSum"} @@ -947,11 +1474,13 @@ class GraphSplitGpu(GraphSplitByPattern): def _broadcast_tot(dom): """Fuse rule for TensorScatterAdd and UnsortedSegmentSum.""" def _same_input(op1, op2): + # 判断两个操作的输入是否有交集 return bool(set(op1.inputs) & set(op2.inputs)) if len(dom.ops) != 1: return [] + # 只融合 TensorScatterAdd 的第一个输入和 UnsortedSegmentSum 的前两个输入 # Only fuse the first input for `TensorScatterAdd`` and the first and second input for `UnsortedSegmentSum`. fuse_arg = {"TensorScatterAdd": slice(1, None), "UnsortedSegmentSum": slice(0, 2)} arg_idx = fuse_arg.get(dom.dom_op().prim, -1) @@ -960,9 +1489,11 @@ class GraphSplitGpu(GraphSplitByPattern): fuse_tensor = dom.dom_op().inputs[arg_idx] for a, _ in dom.in_relations.items(): + # 规则1:类型相同且有至少一个相同输入 # Rule 1: Same type with at lease one same input. if a.dom_op().prim == dom.dom_op().prim and _same_input(dom.dom_op(), a.dom_op()): return [a], True + # 规则2:在指定位置输入中融合操作(reshape/elementwise/broadcast) # Rule 2: Fuse op(reshape/elementwise/broadcast) in specified position inputs. 
if a.pattern <= PrimLib.BROADCAST and dom.check_acyclic(a) and \ any((op.output in fuse_tensor for op in a.ops)): @@ -971,74 +1502,121 @@ class GraphSplitGpu(GraphSplitByPattern): def _broadcast_onehot(dom, fwd=True): """Fuse rule for OneHot.""" + # 判断当前操作的原始操作是否为OneHot if dom.dom_op().prim != "OneHot": return None fused = [] + # 根据fwd参数决定是遍历输入关系还是输出关系 neighbours = dom.in_relations.items() if fwd else dom.out_relations.items() for a, _ in neighbours: + # 判断当前关系是否满足广播模式 if a.pattern <= PrimLib.BROADCAST: + # 判断当前关系是否满足无环、单一输出且不是输出节点(如果fwd为True) + # 或者判断当前关系的来源节点是否满足无环(如果fwd为False) if (fwd and a.check_acyclic(dom) and len(a.out_relations) == 1 and not a.is_output) or \ (not fwd and dom.check_acyclic(a)): + # 将满足条件的关系添加到fused列表中 fused.append(a) return fused, fwd def _h_broadcast(dom, a): + # 判断dom的模式是否大于PrimLib.BROADCAST if dom.pattern > PrimLib.BROADCAST: return [] + # 判断a的模式是否小于等于PrimLib.BROADCAST,并且dom的第一个操作的输出形状与a的第一个操作的输出形状是否相同 return a.pattern <= PrimLib.BROADCAST and dom.ops[0].output.shape == a.ops[0].output.shape def _h_reduce(dom, a): + # 如果dom的模式不是PrimLib.REDUCE或者dom的拼接信息包含拼接操作,则返回空列表 if dom.pattern != PrimLib.REDUCE or dom.stitch_info.stitch_ops: return [] + + # 获取dom的操作 dom_op = dom.ops[0] + # 如果dom的操作不是reduce操作或者dom是原子加法可用的情况,则返回空列表 if not PrimLib.is_reduce(dom_op) or _is_atomic_add_available(dom): return [] + + # 获取a的操作 op = a.ops[0] + # 返回a的模式是PrimLib.REDUCE,a的拼接信息不包含拼接操作,a的操作是reduce操作, + # dom的输入和a的输入形状相同,且dom和a的reduce轴相同 return a.pattern == PrimLib.REDUCE and not a.stitch_info.stitch_ops and \ PrimLib.is_reduce(op) and dom_op.inputs[0].shape == op.inputs[0].shape and \ dom_op.attrs.get("reduce_axis") == op.attrs.get("reduce_axis") def _h_opaque(dom, a): + # 检查dom的第一个操作是否是StridedSlice if dom.ops[0].prim not in {"StridedSlice"}: + # 如果不是,则返回空列表 return [] + # 返回以下条件的逻辑与结果: + # a的第一个操作是否与dom的第一个操作相同 + # dom的第一个操作的输出形状是否与a的第一个操作的输出形状相同 + # dom的第一个操作的第一个输入的形状是否与a的第一个操作的第一个输入的形状相同 return a.ops[0].prim == dom.ops[0].prim and dom.ops[0].output.shape == a.ops[0].output.shape and \ dom.ops[0].inputs[0].shape == a.ops[0].inputs[0].shape def _fuse_loop(): changed = True + # 当changed为True时,进入循环 while changed: + # 尝试融合reshape模式 changed = self.fuse(CommonPattern.reshape) + # 尝试融合assign模式,如果失败则保留原来的changed值 changed = self.fuse(CommonPattern.assign) or changed + # 尝试融合elemwise_depth模式,如果失败则保留原来的changed值 changed = self.fuse(CommonPattern.elemwise_depth) or changed + # 尝试融合elemwise_width模式,如果失败则保留原来的changed值 changed = self.fuse(CommonPattern.elemwise_width) or changed + # 尝试融合_reduce_depth模式,如果失败则保留原来的changed值 changed = self.fuse(_reduce_depth) or changed + # 尝试融合_reduce_width模式,如果失败则保留原来的changed值 changed = self.fuse(_reduce_width) or changed + # 尝试融合_broadcast_depth模式,如果失败则保留原来的changed值 changed = self.fuse(_broadcast_depth) or changed + # 尝试融合_broadcast_width模式,如果失败则保留原来的changed值 changed = self.fuse(_broadcast_width) or changed + # 尝试融合_strided_slice模式,如果失败则保留原来的changed值 changed = self.fuse(_strided_slice) or changed + # 尝试融合正向的_broadcast_onehot模式,如果失败则保留原来的changed值 changed = self.fuse(partial(_broadcast_onehot, fwd=True)) or changed + # 尝试融合反向的_broadcast_onehot模式,如果失败则保留原来的changed值 changed = self.fuse(partial(_broadcast_onehot, fwd=False)) or changed + # 尝试融合_broadcast_tot模式,如果失败则保留原来的changed值 changed = self.fuse(_broadcast_tot) or changed + # 尝试融合不启用reduce_fusion的_gather_output模式,如果失败则保留原来的changed值 changed = self.fuse(partial(_gather_output, reduce_fusion=False)) or changed + # 尝试融合启用reduce_fusion的_gather_output模式,如果失败则保留原来的changed值 changed = self.fuse(partial(_gather_output, 
reduce_fusion=True)) or changed + # 尝试融合_reduce_output模式,如果失败则保留原来的changed值 changed = self.fuse(_reduce_output) or changed + # 如果启用了stitch_fusion,则尝试融合_reduce_stitch模式,如果失败则保留原来的changed值 if self.enable_stitch_fusion: changed = self.fuse(_reduce_stitch) or changed + # 融合_transpose模式 self.fuse(_transpose) + # 如果启用了horizontal_fusion,则进行以下融合操作 if self.enable_horizontal_fusion: + # 融合_h_broadcast模式 self.hfuse(_h_broadcast) + # 融合_h_reduce模式 self.hfuse(_h_reduce) + # 融合_h_opaque模式 self.hfuse(_h_opaque) def _fuse_once(fuse_func): + # 如果满足任一条件,则直接返回 if fuse_func(CommonPattern.reshape) or fuse_func(CommonPattern.elemwise_depth) or \ fuse_func(CommonPattern.elemwise_width) or fuse_func(_reduce_depth) or \ fuse_func(_reduce_width) or fuse_func(_broadcast_depth) or fuse_func(_broadcast_width): return + # 如果满足任一条件,则直接返回 if fuse_func(_reduce_output) or (self.enable_stitch_fusion and fuse_func(_reduce_stitch)): return + # 调用 fuse_func 对 _transpose 进行处理 fuse_func(_transpose) return @@ -1057,6 +1635,7 @@ class GraphSplitAscend(GraphSplitByPattern): """Get default mode for Ascend""" def _dtype_same(tensors): + # 判断所有张量的数据类型是否相同 dtype = tensors[0].dtype for tensor_ in tensors: if tensor_.dtype != dtype: @@ -1064,25 +1643,31 @@ class GraphSplitAscend(GraphSplitByPattern): return True if op.prim == "MatMul": + # 如果是矩阵乘法操作 if op.inputs[0].dtype == "float16" and not _dtype_same(op.inputs): + # 如果输入张量中存在不同数据类型的张量,则返回复合模式 return self.Area.MODE_COMPOSITE if op.prim in ("Tile", "BroadcastTo", "ExpandDims"): + # 如果是平铺、广播到、扩展维度操作 return self.Area.MODE_COMPOSITE return self.Area.MODE_BASIC def pattern_fuse(self, fuse_func=None): """fuse Areas by pattern""" + # 判断某个运算是否可能在多核上运行 def _likely_multicore(dom): op = dom.dom_op() iter_size = tensor_size(op.output if not PrimLib.is_reduce(op) else op.inputs[0]) return iter_size > 1024 + # 判断某个运算是否应该被排除在广播融合之外 def _broadcast_pat_exclude(dom, a, r): if _likely_multicore(a) and (dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH): return True return a.pattern > PrimLib.REDUCE or r > PrimLib.BROADCAST + # 获取广播融合深度 def _broadcast_depth(dom): if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1: return [] @@ -1091,6 +1676,7 @@ class GraphSplitAscend(GraphSplitByPattern): return [] return [a], False + # 获取广播融合宽度 def _broadcast_width(dom): if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST): return [] @@ -1102,6 +1688,7 @@ class GraphSplitAscend(GraphSplitByPattern): fused.append(a) return fused, False + # 判断某个运算是否应该被排除在归约融合之外 def _reduce_pat_exclude(dom, a, r): if len(a.ops) > self.REDUCE_FUSE_DEPTH: return True @@ -1110,6 +1697,7 @@ class GraphSplitAscend(GraphSplitByPattern): return True return a.pattern > PrimLib.BROADCAST or r > PrimLib.REDUCE + # 获取归约融合深度 def _reduce_depth(dom): if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1: return [] @@ -1118,6 +1706,7 @@ class GraphSplitAscend(GraphSplitByPattern): return [] return [a], True + # 获取归约融合宽度 def _reduce_width(dom): if dom.pattern != PrimLib.REDUCE: return [] @@ -1127,6 +1716,7 @@ class GraphSplitAscend(GraphSplitByPattern): fused.append(a) return fused, True + # 获取矩阵乘法融合深度 def _matmul_depth(dom): if dom.dom_op().prim != "MatMul" and dom.dom_op().prim != "BatchMatMul": return [] @@ -1139,6 +1729,7 @@ class GraphSplitAscend(GraphSplitByPattern): fused.append(a) return fused, False + # 获取归约输出融合 def _reduce_output(dom): if dom.pattern != PrimLib.REDUCE: return [] @@ -1152,6 +1743,7 @@ class GraphSplitAscend(GraphSplitByPattern): fused.append(a) return fused, False + 
# 获取归约拼接融合 def _reduce_stitch(dom): if dom.pattern != PrimLib.REDUCE: return [] @@ -1173,10 +1765,11 @@ class GraphSplitAscend(GraphSplitByPattern): fused.append(a) return fused, False + # 判断转置数据操作是否支持某个运算的融合 def _transdata_pattern_support(dom, a): transdata_op = dom.dom_op() - # Currently, if transdata has the pad, it is not used to fuse + # 如果转置数据操作有填充,则不用于融合 def _has_pad(): res = False input_shape = transdata_op.inputs[0].shape @@ -1197,7 +1790,7 @@ class GraphSplitAscend(GraphSplitByPattern): if a.dom_op().prim == "MatMul" and len(dom.ops) == 1: return True - # reshape/elewise/broadcast + transdata + # 重塑/元素级操作/广播 + 转置数据 if a.pattern <= PrimLib.BROADCAST and len(dom.ops) == 1: op_attrs = dom.dom_op().attrs if 'src_format' not in op_attrs.keys() \ @@ -1207,12 +1800,13 @@ class GraphSplitAscend(GraphSplitByPattern): src_format, dst_format = op_attrs['src_format'], op_attrs['dst_format'] if src_format == DF.FRAC_NZ and dst_format in (DF.DEFAULT, DF.NCHW): return True - # For the Default/NCHW to FRAC_NZ, currently only the Cast+Transdata is supported + # 对于Default/NCHW到FRAC_NZ的转换,目前仅支持Cast+Transdata if src_format in (DF.DEFAULT, DF.NCHW) and dst_format == DF.FRAC_NZ \ and len(a.ops) == 1 and a.dom_op().prim == "Cast" and not a.is_output: return True return False + # 获取转置数据融合 def _transdata(dom): if dom.dom_op().prim != "TransData": return [] @@ -1222,6 +1816,7 @@ class GraphSplitAscend(GraphSplitByPattern): fused.append(a) return fused, True + # 循环进行融合操作 def _fuse_loop(): changed = True while changed: @@ -1239,6 +1834,7 @@ class GraphSplitAscend(GraphSplitByPattern): changed = self.fuse(_reduce_stitch) or changed self.fuse(_transdata) + # 执行一次融合操作 def _fuse_once(fuse_func): if fuse_func(CommonPattern.reshape) or fuse_func(CommonPattern.elemwise_depth) or \ fuse_func(CommonPattern.elemwise_width) or fuse_func(_reduce_depth) or \ @@ -1254,9 +1850,14 @@ class GraphSplitAscend(GraphSplitByPattern): def split(graph, target, flags): """Split graph""" + # 初始化结果变量 result = None + # 如果目标设备是"cuda" if target == "cuda": + # 使用GraphSplitGpu类对图进行分割 result = GraphSplitGpu(graph, flags).split() else: + # 否则,使用GraphSplitAscend类对图进行分割 result = GraphSplitAscend(graph, flags).split() + # 返回分割结果 return result diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/model.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/model.py index 23701f97..a84d421b 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/model.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/model.py @@ -24,39 +24,61 @@ class Utils: @staticmethod def get_attr_type(attr): """Get attr type""" + # 判断attr是否为bool类型 if isinstance(attr, bool): return 'bool' + # 判断attr是否为str类型 if isinstance(attr, str): return 'str' + # 判断attr是否为int类型 if isinstance(attr, int): return 'int' + # 判断attr是否为float类型 if isinstance(attr, float): return 'float' + # 判断attr是否为list或tuple类型 if isinstance(attr, (list, tuple)): + # 判断attr是否为空 if not attr: raise ValueError("attr is invalid: the length of attr is 0") + # 判断attr的第一个元素是否为int类型 if isinstance(attr[0], int): return 'listInt' + # 判断attr的第一个元素是否为str类型 if isinstance(attr[0], str): return 'listStr' + # 如果attr的类型不在支持的列表中,则抛出异常 raise ValueError("attr {} type {} is not in supported list ['bool', 'str', 'int', 'float', 'int' list, " "'str' list]".format(attr, type(attr))) class DataFormat: """DataFormat""" + # 默认格式 DEFAULT = "DefaultFormat" + # NC1KHKWHWC0格式 NC1KHKWHWC0 = "NC1KHKWHWC0" + # ND格式 ND = "ND" + # 
NCHW格式 NCHW = "NCHW" + # NHWC格式 NHWC = "NHWC" + # HWCN格式 HWCN = "HWCN" + # NC1HWC0格式 NC1HWC0 = "NC1HWC0" + # FRAC_Z格式 FRAC_Z = "FracZ" + # FRAC_NZ格式 FRAC_NZ = "FRACTAL_NZ" + # C1HWNCOC0格式 C1HWNCOC0 = "C1HWNCoC0" + # NC1HWC0_C04格式 NC1HWC0_C04 = "NC1HWC0_C04" + # FRACTAL_Z_C04格式 FRACTAL_Z_C04 = "FRACTAL_Z_C04" + # NDHWC格式 NDHWC = "NDHWC" def __init__(self): @@ -65,29 +87,47 @@ class DataFormat: class DataType: """Data Type""" + # 浮点型 FLOAT = "float" + # 半精度浮点型 FLOAT16 = "float16" + # 单精度浮点型 FLOAT32 = "float32" + # 双精度浮点型 FLOAT64 = "float64" + # 整型 INT = "int" + # 8位整型 INT8 = "int8" + # 16位整型 INT16 = "int16" + # 32位整型 INT32 = "int32" + # 64位整型 INT64 = "int64" + # 无符号整型 UINT = "uint" + # 8位无符号整型 UINT8 = "uint8" + # 16位无符号整型 UINT16 = "uint16" + # 32位无符号整型 UINT32 = "uint32" + # 64位无符号整型 UINT64 = "uint64" + # 布尔型 BOOL = "bool" + # 初始化函数 def __init__(self): + # 无需执行任何操作 pass class PrimLib: """Prim lib""" + # 定义PrimLib类中的常量 UNKNOWN = 0 RESHAPE = 1 ELEMWISE = 2 @@ -102,53 +142,73 @@ class PrimLib: """Prim""" def __init__(self, iter_type, calibrate=1, relation_func=None): + # 初始化Prim类,设置iter_type、calibrate和relation_func属性 self.iter_type = iter_type self.calibrate = calibrate self.relation_func = relation_func if relation_func is None: + # 如果relation_func为None,则设置默认的relation_func self.relation_func = lambda *x: self.default_relation_func[iter_type](self, *x) def default_reshape_relation(self, op, input_idx): """Process reshape relation""" + # 处理reshape关系 axis_relation, elem_relation = self.unknown_relation(op, input_idx) + # 将elem_relation设置为PrimLib.RESHAPE elem_relation = [PrimLib.RESHAPE] * len(elem_relation) return axis_relation, elem_relation def default_elemwise_broadcast_relation(self, op, input_idx): """Process elemwise and broadcast relation""" + # 处理elemwise和broadcast关系 out_shape = op.output.shape in_shape = op.inputs[input_idx].shape + # 如果输出形状的长度小于输入形状的长度,则抛出异常 if len(out_shape) < len(in_shape): raise ValueError("For '{}', the input/output size is abnormal, as the length of output shape{} is less " "than the length of input shape{}".format(op.prim, out_shape, in_shape)) axis_relation, elem_relation = [], [] + # 计算输出形状和输入形状的长度差 delta = len(out_shape) - len(in_shape) if delta > 0: + # 如果输出形状的长度大于输入形状的长度,则在axis_relation和elem_relation中添加None for i in range(0, delta): axis_relation.append(None) elem_relation.append(None) + # 遍历输入形状的每个元素 for i, _ in enumerate(in_shape): + # 在axis_relation中添加当前元素的索引 axis_relation.append(i) + # 如果输出形状的对应元素等于输入形状的对应元素,则elem_relation添加PrimLib.ELEMWISE,否则添加PrimLib.BROADCAST elem_relation.append( PrimLib.ELEMWISE if out_shape[i + delta] == in_shape[i] else PrimLib.BROADCAST) return axis_relation, elem_relation def default_reduce_relation(self, op, input_idx): """Process reduce relation""" + # 处理reduce关系 axis_relation, elem_relation = self.default_elemwise_broadcast_relation(op, input_idx) + # 遍历reduce_axis中的每个元素 for i in op.attrs['reduce_axis']: + # 将elem_relation中对应元素的值设置为PrimLib.REDUCE elem_relation[i] = PrimLib.REDUCE return axis_relation, elem_relation def unknown_relation(self, op, input_idx): """Process unknown relation""" + # 获取输出和输入的形状 out_shape = op.output.shape in_shape = op.inputs[input_idx].shape + # 获取所有可能的轴关系 all_relation = list(range(len(in_shape))) + # 初始化轴关系列表 axis_relation = [all_relation for i in range(0, len(out_shape))] + # 初始化元素关系列表 elem_relation = [PrimLib.UNKNOWN for i in range(0, len(out_shape))] + # 返回轴关系和元素关系 return axis_relation, elem_relation + # 默认的关系函数列表 default_relation_func = [ unknown_relation, default_reshape_relation, @@ -158,6 
+218,7 @@ class PrimLib: unknown_relation, ] + # 定义基本操作 primtives = { 'Add': Prim(ELEMWISE), 'Abs': Prim(ELEMWISE), @@ -239,7 +300,9 @@ class PrimLib: @classmethod def get_prim(cls, op): """Get op primtive""" + # 从cls.primtives中获取op.prim对应的prim prim = cls.primtives.get(op.prim, None) + # 如果prim为None,则打印警告信息,并返回cls.default_primtive if prim is None: print('[WARN] primtive is not registered: ' + op.prim) prim = cls.default_primtive @@ -248,50 +311,65 @@ class PrimLib: @classmethod def input_relation(cls, op, input_idx): """Get op's input_relation according to input_idx""" + # 调用cls.get_prim(op)获取op对应的prim,然后调用prim的relation_func方法获取op的input_relation return cls.get_prim(op).relation_func(op, input_idx) @classmethod def iter_type(cls, op): """Get op's iter type""" + # 调用cls.get_prim(op)获取op对应的prim,然后返回prim的iter_type return cls.get_prim(op).iter_type @classmethod def is_reduce(cls, op): """Check whether op's iter type is reduce""" + # 调用cls.get_prim(op)获取op对应的prim,然后判断prim的iter_type是否为cls.REDUCE return cls.get_prim(op).iter_type == cls.REDUCE @classmethod def calibrate_iter_size(cls, op, iter_size): """Get calibrate_iter_size""" + # 调用cls.get_prim(op)获取op对应的prim,然后返回prim的calibrate乘以iter_size return cls.get_prim(op).calibrate * iter_size @classmethod def dtype_bytes(cls, dtype): """Get dtype bytes""" + # 初始化bits和unit为1 bits, unit = 1, 1 + # 从dtype的最后一个字符开始向前遍历 for i in range(len(dtype) - 1, 0, -1): + # 如果当前字符是数字,则将bits加上当前字符对应的数字乘以unit,并将unit乘以10 if dtype[i].isdecimal(): bits += int(dtype[i]) * unit unit *= 10 + # 如果当前字符不是数字,则跳出循环 else: break + # 返回bits除以8的结果 return bits // 8 @classmethod def inplace_reuse(cls, op, input_idx, start_axis=0): """Check whether op is inplace reuse""" + # 如果op.output.dtype的字节数大于op.inputs[input_idx].dtype的字节数,则返回False if cls.dtype_bytes(op.output.dtype) > cls.dtype_bytes(op.inputs[input_idx].dtype): return False + # 调用cls.get_prim(op)获取op对应的prim,然后调用prim的relation_func方法获取op的input_relation _, elem_relation = cls.get_prim(op).relation_func(op, input_idx) + # 从start_axis开始遍历elem_relation for i in range(start_axis, len(elem_relation)): + # 如果elem_relation中的元素不等于cls.ELEMWISE,则返回False if elem_relation[i] != cls.ELEMWISE: return False + # 如果以上条件都不满足,则返回True return True class Tensor: """Tensor""" + # 参数类型常量 PARA_NONE = 0 PARA_INPUT = 1 PARA_OUTPUT = 2 @@ -303,6 +381,7 @@ class Tensor: self.members = [leader] def __init__(self, name, shape, dtype, data_format=DataFormat.DEFAULT, para_type=0): + # 初始化Tensor对象 self.name = name self.shape = shape self.dtype = dtype @@ -313,13 +392,16 @@ class Tensor: self.buddy = None def __str__(self): + # 返回Tensor对象的字符串表示 return self.name + str(list(self.shape)) def __repr__(self): + # 返回Tensor对象的字符串表示 return "%s.%s%s" % (self.name, self.dtype, str(list(self.shape))) def get_size(self): """Get size""" + # 获取Tensor对象的大小 size = PrimLib.dtype_bytes(self.dtype) for i in self.shape: size *= i @@ -327,6 +409,7 @@ class Tensor: def add_buddy(self, tensor): """Add buddy""" + # 添加buddy if self.buddy is None: self.buddy = self.Buddy(self) self.buddy.members.append(tensor) @@ -337,6 +420,7 @@ class Value: """Value""" def __init__(self, name, dtype, value, data_format=DataFormat.DEFAULT): + # 初始化Value对象 self.name = name self.shape = [1] self.dtype = dtype @@ -344,14 +428,17 @@ class Value: self.data_format = data_format def __str__(self): + # 返回Value对象的字符串表示 return self.name + str(list(self.shape)) def __repr__(self): + # 返回Value对象的字符串表示 return "%s.%s%s" % (self.name, self.dtype, str(list(self.shape))) @staticmethod def get_size(): """Get size""" + # 
获取Value对象的大小 return 1 @@ -359,23 +446,31 @@ class Operator: """Operator""" def __init__(self, primtive, inputs, output, attrs): + # 初始化Operator对象 self.prim = primtive self.inputs = inputs self.output = output self.attrs = attrs + # 将当前Operator对象添加到每个输入的to_ops列表中 for t in inputs: t.to_ops.append(self) + # 如果输出的op属性为None,则将当前Operator对象赋值给输出的op属性 if output.op is None: output.op = self + # 初始化all_inputs列表,用于存储Tensor输入和Value输入 self.all_inputs = [] # include Tensor inputs and Value inputs. def __str__(self): + # 将self.all_inputs中的元素转换为字符串,并用逗号连接起来 args = ', '.join((str(t) for t in self.all_inputs)) + # 构造表达式字符串 expr = "%s = %s.%s(%s) id:%s" % ( str(self.output), self.prim, self.output.dtype, args, id(self)) + # 如果self.attrs不为空,则返回表达式字符串和self.attrs的字符串连接 return expr if not self.attrs else '%s // %s' % (expr, str(self.attrs)) def __repr__(self): + # 返回self的字符串表示 return str(self) @@ -383,6 +478,7 @@ class Graph: """Graph""" def __init__(self, name, ops, stitch_info=None, recompute_ops=None): + # 初始化Graph对象 self.name = name self.ops = ops # in topo order, can not use set self.inputs = [] @@ -393,10 +489,12 @@ class Graph: def set_processor(self, processor): """Set processor""" + # 设置处理器 self.processor = processor def add(self, ops): """Add ops""" + # 添加操作 if isinstance(ops, Operator): self.ops.append(ops) else: @@ -404,101 +502,148 @@ class Graph: def extract_subgraph(self, graph_name, tensor_names, difference=False): """Extract subgraph from this graph""" + # 从当前图中提取子图 graph = Graph(graph_name, []) outputs = set(tensor_names) if difference: + # 如果difference为True,则提取不在outputs中的操作 for op in self.ops: if op.output.name not in outputs: graph.add(op) else: + # 如果difference为False,则提取在outputs中的操作 for op in self.ops: if op.output.name in outputs: graph.add(op) outputs.remove(op.output.name) + # 如果outputs中还有元素,则抛出异常 for name in outputs: raise ValueError("Invalid input tensor : {}, can not find it in graph".format(name)) return graph def deduce_parameters(self): """Deduce parameters""" + # 初始化输入和输出列表 inputs, outputs = [], [] + # 遍历所有操作 for op in self.ops: + # 遍历操作的所有输入 for t in op.inputs: + # 如果输入不在输入列表中,且输入的操作不在操作列表中,则将输入添加到输入列表中 if t not in inputs and t.op not in self.ops: inputs.append(t) + # 如果操作输出已经在输出列表中,则跳过 if op.output in outputs: continue + # 如果操作输出是输出参数类型,或者操作输出没有后续操作,则将操作输出添加到输出列表中 if op.output.para_type == Tensor.PARA_OUTPUT or not op.output.to_ops: outputs.append(op.output) continue + # 如果操作输出的后续操作不在操作列表中,则将操作输出添加到输出列表中 if any((succ not in self.ops for succ in op.output.to_ops)): outputs.append(op.output) + # 如果有指定的输入,则将指定的输入赋值给输入列表 if self.inputs: inputs = self.inputs + # 如果有指定的输出,则将指定的输出赋值给输出列表 if self.outputs: outputs = self.outputs + # 返回输入和输出列表 return inputs, outputs def __str__(self): + # 调用deduce_parameters方法获取输入和输出列表 inputs, outputs = self.deduce_parameters() + # 将输入列表转换为字符串 para_str = ', '.join((repr(t) for t in inputs)) + # 将输出列表转换为字符串 out_str = ', '.join((repr(t) for t in outputs)) + # 初始化行列表 lines = [] + # 添加操作名称、输入和输出到行列表中 lines.append("%s(%s) -> %s {" % (self.name, para_str, out_str)) + # 如果有拼接信息,则添加拼接操作和拼接原子操作到行列表中 if self.stitch_info: if self.stitch_info.stitch_ops: lines.append(' stitch -> ' + str(self.stitch_info.stitch_ops)) if self.stitch_info.stitch_atomic_ops: lines.append(' stitch_atomic_ops-> ' + str(self.stitch_info.stitch_atomic_ops)) + # 遍历所有操作,将操作添加到行列表中 for op in self.ops: lines.append(' ' + str(op)) + # 添加结束符号到行列表中 lines.append('}') + # 将行列表转换为字符串并返回 return '\n'.join(lines) def __repr__(self): + # 返回对象的字符串表示 return str(self) def dump(self): """Dump Graph to json""" 
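+        # A rough sketch of the returned layout (field names mirror the
+        # graph_desc dict assembled below; values here are illustrative only):
+        #   {'composite': True, 'op': <graph name>, 'process': <processor>,
+        #    'input_desc': [[{...}]], 'output_desc': [{...}], 'op_desc': [{...}]}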
+ # 将Graph转换为json格式 attr_name = {'reduce_axis': 'axis'} + # 获取Graph的输入和输出参数 inputs, outputs = self.deduce_parameters() input_desc, output_desc, op_desc = [], [], [] + # 遍历输入参数 for t in inputs: + # 将输入参数转换为字典格式 input_desc.append([{'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name, 'format': t.data_format}]) + # 遍历输出参数 for t in outputs: + # 将输出参数转换为字典格式 output_desc.append({'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name, 'format': t.data_format}) + # 遍历Graph中的操作 for op in self.ops: attrs, in_desc = [], [] + # 遍历操作中的属性 for a in op.attrs: + # 获取属性名 name = attr_name.get(a, a) + # 将属性转换为字典格式 attrs.append( {'name': name, 'value': op.attrs[a], 'data_type': Utils.get_attr_type(op.attrs[a])}) + # 遍历操作中的输入 for t in op.all_inputs: + # 如果输入是Tensor类型 if isinstance(t, Tensor): + # 将输入转换为字典格式 in_desc.append([{'data_type': t.dtype, 'name': '', 'shape': t.shape, 'tensor_name': t.name, 'format': t.data_format}]) else: + # 将输入转换为字典格式 in_desc.append([{'data_type': t.dtype, 'value': t.value, 'name': '', 'shape': t.shape, 'tensor_name': t.name, 'format': t.data_format}]) + # 将操作输出转换为字典格式 out_desc = [{'data_type': op.output.dtype, 'name': '', 'shape': op.output.shape, 'tensor_name': op.output.name, 'format': op.output.data_format}] + # 将操作转换为字典格式 op_desc.append({'attr': attrs, 'impl_path': '', 'input_desc': in_desc, 'name': op.prim, 'output_desc': out_desc}) + # 将Graph转换为字典格式 graph_desc = {'composite': True, 'composite_graph': '', 'id': 0, 'input_desc': input_desc, 'op': self.name, 'op_desc': op_desc, 'output_desc': output_desc, 'platform': 'AKG', 'process': self.processor} + # 如果Graph中有stitch信息 if self.stitch_info and self.stitch_info.stitch_ops: + # 将stitch信息转换为字典格式 buffer_stitch = {'stitch_op': list(self.stitch_info.stitch_ops)} + # 如果有stitch_atomic_ops if self.stitch_info.stitch_atomic_ops: + # 将stitch_atomic_ops转换为字典格式 buffer_stitch['stitch_atomic_op'] = list(self.stitch_info.stitch_atomic_ops) + # 将stitch信息添加到Graph字典中 graph_desc['buffer_stitch'] = buffer_stitch + # 返回Graph字典 return graph_desc @@ -506,13 +651,16 @@ class GraphVisitor: """Graph visitor""" def __init__(self, forward=True): + # 初始化forward参数,默认为True self.forward = forward def visit_graph(self, graph): """Visit graph""" + # 如果forward为True,则按照顺序遍历graph中的ops if self.forward: for op in graph.ops: self.visit(op) + # 如果forward为False,则按照逆序遍历graph中的ops else: for i in range(len(graph.ops)-1, -1, -1): self.visit(graph.ops[i]) @@ -528,12 +676,18 @@ class AlignShape(GraphVisitor): def visit(op): """Visit op node""" prim = PrimLib.get_prim(op) + # 如果op的迭代类型是ELEMWISE、BROADCAST或REDUCE,则需要进行形状对齐 if prim.iter_type in (PrimLib.ELEMWISE, PrimLib.BROADCAST, PrimLib.REDUCE): + # 获取op的输出维度 out_dim = len(op.output.shape) + # 初始化对齐维度为输出维度 align_dim = out_dim + # 遍历op的输入 for t in op.inputs: + # 如果输入的维度大于对齐维度,则更新对齐维度 if len(t.shape) > align_dim: align_dim = len(t.shape) + # 如果对齐维度大于输出维度,则对op的输出形状进行对齐 if align_dim > out_dim: op.output.shape = [1] * (align_dim - out_dim) + op.output.shape diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/model_builder.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/model_builder.py index 7c9414ce..fc55cc42 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/model_builder.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/model_builder.py @@ -25,90 +25,140 @@ class GraphBuilder: """Graph wrapper""" def __init__(self, name): + """ + 初始化类实例。 + + Args: + name (str): 图的名称。 + + Attributes: + 
self.graph (Graph): 图的实例,使用传入的名称初始化。 + """ self.graph = Graph(name, []) def set_input(self, *para): """set input to graph inputs""" + # 遍历传入的参数 for t in para: + # 设置参数类型为输入参数 t.para_type = Tensor.PARA_INPUT + # 将参数添加到图的输入列表中 self.graph.inputs.append(t) def set_output(self, *para): """set output to graph inputs""" + # 遍历传入的参数 for t in para: + # 设置参数类型为输出参数 t.para_type = Tensor.PARA_OUTPUT + # 将参数添加到图的输出列表中 self.graph.outputs.append(t) def __init__(self): + # 初始化图列表 self.graphs = [] + # 当前图设置为None self.current = None + # 初始化名称ID self.name_id = 0 def _alloc_tensor_name(self): + # 获取当前名称ID tid = self.name_id + # 名称ID加1 self.name_id += 1 + # 格式化字符串,生成张量名称 return "t%d" % (tid) def graph_scope(self, name): """The graph scope to be processed""" + # 定义GraphScope类 class GraphScope: """Graph Scope""" def __init__(self, gb): + # 初始化GraphScope对象,接收一个GraphBuilder对象 self.gb = gb def __enter__(self): + # 当使用with语句进入GraphScope上下文时调用 return self.gb.current def __exit__(self, ptype, value, trace): + # 当离开GraphScope上下文时调用 self.gb.graphs.append(self.gb.current.graph) self.gb.current = None + # 检查self.current是否不为None if self.current is not None: raise ValueError("self.current is not None!") + # 创建GraphWrapper对象并赋值给self.current self.current = self.GraphWrapper(name) + # 返回GraphScope对象 return GraphScope(self) def tensor(self, shape, dtype, data_format="DefaultFormat", name=None, para_type=Tensor.PARA_NONE): - """Create a new Tensor""" + """创建一个新的张量""" + # 如果名称为空或None,则分配一个新的张量名称 if name in (None, ''): + # 分配一个新的张量名称 name = self._alloc_tensor_name() + # 如果shape为空,则默认设置为[1] if not shape: shape = [1] + # 返回创建好的张量对象 return Tensor(name, shape, dtype, data_format, para_type=para_type) def value(self, dtype, value, name=None): """Create a new Value""" + # 如果name为None或空字符串 if name in (None, ''): + # 分配一个新的tensor名称 name = self._alloc_tensor_name() + # 创建一个新的Value对象 v = Value(name, dtype, value) + # 返回创建的Value对象 return v def op(self, prim, output, inputs, attrs=None): """Insert an operator into graph""" + # 如果 attrs 为 None,则将其设置为空字典 if attrs is None: attrs = {} + # 如果 inputs 是 Tensor 类型,则将其转换为列表 if isinstance(inputs, Tensor): inputs = [inputs] + # 过滤出 inputs 中 Tensor 类型的元素 tensor_inputs = [t for t in inputs if isinstance(t, Tensor)] + # 创建一个 Operator 对象 node = Operator(prim, tensor_inputs, output, attrs) + # 将所有输入保存到 node 的 all_inputs 属性中 node.all_inputs = inputs + # 将 node 添加到当前图的节点列表中 self.current.graph.add(node) def emit(self, prim, inputs, name=None, attrs=None): """Emit a new operation""" + # 如果attrs为None,则初始化为空字典 if attrs is None: attrs = {} + # 如果inputs是Tensor或Value的实例,则将其转换为列表 if isinstance(inputs, (Tensor, Value)): inputs = [inputs] + # 过滤出inputs中的Tensor和Value实例 tensor_inputs = [t for t in inputs if isinstance(t, (Tensor, Value))] + # 调用op_infer.infer函数进行形状、数据类型和格式的推断 out_shape, out_dtype, out_format = op_infer.infer(prim, tensor_inputs, attrs) + # 创建一个新的Tensor实例作为输出 output = self.tensor(out_shape, out_dtype, out_format, name) + # 执行操作,并将结果存储在output中 self.op(prim, output, inputs, attrs) + # 返回操作的结果 return output def get(self): """Get graphs""" + # 返回self.graphs return self.graphs @@ -116,16 +166,21 @@ class CompositeGraph: """Composite Graph""" def __init__(self): + # 初始化图对象,默认为None self.graph = None + # 初始化描述信息,默认为None self.desc = None - self.tensors = {} # name : Tensor + # 初始化张量字典,默认为空字典 + self.tensors = {} def refine(self): """Refine Graph""" + # 对图进行形状对齐操作 AlignShape().visit_graph(self.graph) def load(self, desc): """Load Graph from json""" + # 定义一个内部函数,用于处理操作属性 def _attr_of(op): if not op['attr']: return 
dict() @@ -134,21 +189,29 @@ class CompositeGraph: if a['name'] == 'axis' and op['name'] in ('ReduceSum', 'ReduceMax', 'ReduceMin', 'Argmax', 'Argmin'): attr['reduce_axis'] = a['value'] else: + # 将属性添加到字典中 attr[a['name']] = a['value'] return attr + # 创建GraphBuilder对象 builder = GraphBuilder() + # 在描述的操作范围内构建图 with builder.graph_scope(desc['op']): + # 遍历输入描述并构建输入张量 for in_desc in desc['input_desc'] if desc['input_desc'] is not None else []: name, shape, dtype, data_format = in_desc[0]['tensor_name'], in_desc[ 0]['shape'], in_desc[0]['data_type'], in_desc[0]['format'] + # 将输入张量添加到tensors字典中 self.tensors[name] = builder.tensor( shape, dtype, data_format, name=name, para_type=Tensor.PARA_INPUT) + # 遍历输出描述并构建输出张量 for out_desc in desc['output_desc']: name, shape, dtype, data_format = out_desc['tensor_name'], out_desc[ 'shape'], out_desc['data_type'], out_desc['format'] + # 将输出张量添加到tensors字典中 self.tensors[name] = builder.tensor( shape, dtype, data_format, name=name, para_type=Tensor.PARA_OUTPUT) + # 遍历操作描述并构建操作 for op in desc['op_desc']: inputs = [self.tensors[d['tensor_name']] for x in op['input_desc'] for d in x if 'value' not in d] out_desc = op['output_desc'] @@ -164,52 +227,77 @@ class CompositeGraph: if not output: output = builder.tensor(shape, dtype, data_format, name=name) self.tensors[name] = output + # 构建操作并添加到图中 builder.op(op['name'], output, inputs, attrs=_attr_of(op)) + # 获取构建好的图 self.graph = builder.get()[0] self.desc = desc def add_stitch_info(self, subgraph, desc): """add stitch info to desc""" + # 如果subgraph包含stitch信息且stitch_ops不为空 if subgraph.stitch_info and subgraph.stitch_info.stitch_ops: + # 创建一个字典,用于存储stitch操作信息 buffer_stitch = {'stitch_op': list(subgraph.stitch_info.stitch_ops)} + # 如果subgraph包含stitch_atomic_ops信息 if subgraph.stitch_info.stitch_atomic_ops: + # 将stitch_atomic_ops信息添加到buffer_stitch字典中 buffer_stitch['stitch_atomic_op'] = list(subgraph.stitch_info.stitch_atomic_ops) + # 将buffer_stitch信息添加到desc字典中 desc['buffer_stitch'] = buffer_stitch return desc def add_recompute_ops(self, subgraph, desc): """add recompute ops to desc""" + # 如果subgraph中包含需要重新计算的操作 if subgraph.recompute_ops: + # 将需要重新计算的操作的输出名称添加到desc中 desc['recompute_ops'] = [op.output.name for op in subgraph.recompute_ops] return desc def _pre_dump(self, outputs): """restore name to before load""" + # 创建一个空字典用于存储inplace赋值操作 inplace_assign = {} # y_name, output_name inplace_assign_z = None + # 遍历self.desc['op_desc']中的操作 for op in self.desc['op_desc']: + # 如果操作名称为'InplaceAssign' if op['name'] == 'InplaceAssign': + # 将inplace赋值操作的输入tensor名作为键,输出tensor名作为值,存入inplace_assign字典 inplace_assign[op['input_desc'][1][0]['tensor_name']] = op['output_desc'][0]['tensor_name'] + # 如果inplace_assign字典不为空 if inplace_assign: + # 遍历outputs中的tensor for t in outputs: + # 如果当前tensor的名称不在inplace_assign字典中 if t.name not in inplace_assign: + # 将当前tensor赋值给inplace_assign_z inplace_assign_z = t + # 返回inplace_assign和inplace_assign_z return inplace_assign, inplace_assign_z def dump(self, subgraph): """Dump Graph to json""" desc = {} + # 获取输入和输出参数 inputs, outputs = subgraph.deduce_parameters() + # 获取图中的所有操作 graph_ops = set(subgraph.ops) + # 预处理输出参数 inplace_assign, inplace_assign_z = self._pre_dump(outputs) def dump_output(t): + # 如果输出参数是原地赋值操作的结果 if t.name in inplace_assign: z = inplace_assign_z if inplace_assign_z is not None else self.tensors[t.name] + # 返回包含数据类型、形状和张量名称的字典 return {'data_type': z.dtype, 'shape': z.shape, 'tensor_name': inplace_assign.get(t.name)} + # 返回包含数据类型、形状和张量名称的字典 return {'data_type': t.dtype, 'shape': t.shape, 
'tensor_name': t.name} def dump_op_desc(d): + # 如果操作是原地赋值操作 if d['name'] == 'InplaceAssign': y = d['input_desc'][1][0]['tensor_name'] if self.tensors[y].op in graph_ops: @@ -222,33 +310,50 @@ class CompositeGraph: z_desc['tensor_name'] = z.name out_desc['shape'] = z.shape out_desc['data_type'] = z.dtype + # 返回处理后的原地赋值操作描述 return inplace_desc + # 获取操作对应的张量 op = self.tensors[d['output_desc'][0]['tensor_name']].op + # 如果操作在图操作集或重新计算操作集中 if op in graph_ops or op in subgraph.recompute_ops: + # 返回操作描述 return d + # 返回None return None for key in self.desc.keys(): if key == 'input_desc': + # 处理输入描述 desc[key] = [[{'data_type': t.dtype, 'shape': t.shape, 'tensor_name': t.name}] for t in inputs] elif key == 'output_desc': + # 处理输出描述 desc[key] = list(map(dump_output, outputs)) elif key == 'op_desc': + # 处理操作描述 op_desc = map(dump_op_desc, self.desc[key]) desc[key] = [d for d in op_desc if d is not None] elif key == 'op': + # 处理操作名称 desc[key] = subgraph.name else: + # 处理其他描述 desc[key] = self.desc[key] + # 添加缝合信息 desc = self.add_stitch_info(subgraph, desc) + # 添加重新计算操作信息 desc = self.add_recompute_ops(subgraph, desc) + # 返回最终描述 return desc def load_composite(desc): """Load composite kernel""" + # 创建一个CompositeGraph对象 composite = CompositeGraph() + # 加载描述信息 composite.load(desc) + # 对加载的CompositeGraph进行细化 composite.refine() + # 返回处理后的CompositeGraph对象 return composite diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/op_infer.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/op_infer.py index d30ee032..3ad25bf4 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/op_infer.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/model/op_infer.py @@ -25,24 +25,32 @@ def infer(op_name, inputs, attrs): """infer shape dtype and format""" def _create_opinfer(): + # 获取当前模块 self_module = sys.modules.get(__name__, None) + # 如果当前模块为空,则抛出异常 if self_module is None: raise GKException("OpInfo does not support op {}".format(op_name)) + # 如果当前模块有op_name属性,则获取该属性 if hasattr(self_module, op_name): op_cls = getattr(self_module, op_name) return op_cls(op_name, inputs, attrs) # common infer + # 定义一个字典,将PrimLib中的iter_type映射到对应的类名 class_name_map = { PrimLib.ELEMWISE: "_Elemwise", PrimLib.REDUCE: "_Reduce", } + # 获取op_name对应的iter_type cls_name = class_name_map.get(PrimLib.primtives.get(op_name, PrimLib.default_primtive).iter_type, None) + # 如果没有对应的iter_type,则抛出异常 if not cls_name: raise GKException("OpInfo does not support op {}".format(op_name)) + # 获取对应的类 op_cls = getattr(self_module, cls_name) return op_cls(op_name, inputs, attrs) + # 返回infer方法 return _create_opinfer().infer() @@ -55,49 +63,65 @@ class OpInfer: """ def __init__(self, name, inputs, attrs): + # 初始化函数,传入参数name、inputs、attrs self.name = name self.inputs = inputs self.attrs = attrs def infer(self): """Infer shape, type and format by op inputs""" + # 根据op的输入推断shape、type和format self._check() return self._infer_shape(), self._infer_type(), self._infer_format() def _infer_shape(self): + # 根据op的输入推断shape return self.inputs[0].shape def _infer_type(self): + # 根据op的输入推断type return self.inputs[0].dtype def _infer_format(self): + # 根据op的输入推断format return self.inputs[0].data_format def _check(self): + # 检查shape、type和format self._check_shape() self._check_type() self._check_format() def _check_shape(self): + # 检查shape pass def _check_type(self): """check all dtypes are same""" + # 获取第一个输入的dtype dtype = self.inputs[0].dtype + # 遍历剩下的输入 for i, t in 
enumerate(self.inputs[1:]): + # 如果当前输入的dtype与第一个输入的dtype不同,则抛出异常 if t.dtype != dtype: raise GKException( "Incompatible data type between input {}({}) and {}({})".format(0, dtype, i + 1, t.dtype)) def _check_format(self): """check formats are compatible. only DefaultFormat is compatible with others""" + # 获取第一个输入的data_format result = self.inputs[0].data_format + # 初始化i为0 i = 0 + # 遍历剩下的输入 for j, t in enumerate(self.inputs[1:]): + # 如果当前输入的data_format与第一个输入的data_format不同,则进行判断 if t.data_format != result: + # 如果第一个输入的data_format和当前输入的data_format都不是DefaultFormat,则抛出异常 if DF.DEFAULT not in (result, t.data_format): raise GKException("Incompatible format between input {}({}) and {}({})".format( i, result, j + 1, t.data_format)) + # 如果第一个输入的data_format是DefaultFormat,则将result设置为当前输入的data_format,并将i设置为j+1 if result == DF.DEFAULT: result = t.data_format i = j + 1 @@ -109,17 +133,26 @@ class _Elemwise(OpInfer): @staticmethod def broadcast_shape(shapes): """deduce broadcast shape using same rules as numpy""" + # 计算所有shape的最大维度 dim_size = max(len(shape) for shape in shapes) + # 将所有shape扩展到最大维度,不足的部分用1填充 align_shapes = [[1] * (dim_size - len(shape)) + shape for shape in shapes] + # 初始化输出shape为全1 out_shape = [1] * dim_size + # 遍历每个维度 for i in range(dim_size): + # 遍历每个shape for align_shape in align_shapes: + # 如果当前维度为1,则跳过 if align_shape[i] == 1: continue + # 如果输出shape当前维度为1,则将输出shape当前维度设置为当前shape当前维度的值 if out_shape[i] == 1: out_shape[i] = align_shape[i] + # 如果输出shape当前维度和当前shape当前维度不相等,则抛出异常 elif out_shape[i] != align_shape[i]: raise GKException("Input shapes {} can not broadcast.".format(shapes)) + # 返回输出shape return out_shape @staticmethod @@ -174,9 +207,12 @@ class _Elemwise(OpInfer): .format(inputs_format)) def _infer_format(self): + # 遍历输入张量 for tensor in self.inputs: + # 如果张量的数据格式不是默认格式,则返回该数据格式 if tensor.data_format != DF.DEFAULT: return tensor.data_format + # 如果所有输入张量的数据格式都是默认格式,则返回默认格式 return DF.DEFAULT @@ -184,6 +220,7 @@ class _Reduce(OpInfer): """Common infer for reduction operators""" def _check(self): + # 调用父类的方法 super(_Reduce, self)._check() # check reduce axis in the range [-len, len) shape_len = len(self.inputs[0].shape) @@ -195,21 +232,29 @@ class _Reduce(OpInfer): "Reduce axis should be in range [{},{}) but got {}".format(-shape_len, shape_len, axis)) def _infer_shape(self): + # 深度拷贝输入的形状 shape = copy.deepcopy(self.inputs[0].shape) + # 获取reduce_axis属性 axis = self.attrs['reduce_axis'] + # 如果axis是整数,则将其转换为列表 if isinstance(axis, int): axis = [axis] + # 如果axis中的元素小于0,则将其转换为非负数 if any(i < 0 for i in axis): # change the axis to non-negative number. 
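+            # Illustrative example: for an input of rank 3, reduce_axis [-1]
+            # is normalized to [2] by the mapping below.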
+ # 将axis中的负数转换为正数 axis = list(map(lambda i: i + len(shape) if i < 0 else i, axis)) + # 将axis排序 self.attrs['reduce_axis'] = sorted(axis) + # 如果keep_dims为True,则将axis中的维度设置为1 if self.attrs['keep_dims']: for i in axis: shape[i] = 1 return shape + # 如果keep_dims为False,则将axis中的维度从shape中移除 real_shape = [] for i, s in enumerate(shape): if i not in axis: @@ -223,10 +268,14 @@ class _Reduce(OpInfer): class _Reshape(OpInfer): """Common infer for reshape operators, should not be instantiated""" + # 定义一个函数,用于推断形状 def _infer_shape(self): + # 抛出一个异常,提示子类需要实现这个函数 raise GKException("_infer_shape should be implemented by subclass") + # 定义一个函数,用于推断格式 def _infer_format(self): + # 如果attrs中不存在"format"这个属性,则返回DF.DEFAULT,否则返回attrs中"format"的值 return DF.DEFAULT if "format" not in self.attrs else self.attrs["format"] @@ -234,14 +283,20 @@ class Reshape(_Reshape): """Reshape op infer""" def _check_shape(self): + # 获取输入形状 input_shape = self.inputs[0].shape + # 获取输出形状 output_shape = self.attrs["shape"] + # 计算输入形状的乘积 size_before_reshape = prod_reduce(lambda x, y: x * y, input_shape) + # 计算输出形状的乘积 size_after_reshape = prod_reduce(lambda x, y: x * y, output_shape) + # 如果输入形状的乘积不等于输出形状的乘积,则抛出异常 if size_before_reshape != size_after_reshape: raise GKException("For 'Reshape', can not reshape {} to {}".format(input_shape, output_shape)) def _infer_shape(self): + # 返回输出形状 return self.attrs["shape"] @@ -249,6 +304,7 @@ class Cast(_Elemwise): """Cast op infer""" def _infer_type(self): + # 返回dst_type属性 return self.attrs["dst_type"] @@ -256,29 +312,38 @@ class InplaceAssign(_Elemwise): """InplaceAssign op infer""" def _infer_shape(self): + # 返回第3个输入的shape属性 return self.inputs[2].shape def _infer_type(self): + # 返回第3个输入的dtype属性 return self.inputs[2].dtype def _infer_format(self): + # 返回第3个输入的data_format属性 return self.inputs[2].data_format class BroadcastTo(OpInfer): """BroadcastTo op infer""" + # 定义一个函数,用于推断形状 def _infer_shape(self): + # 返回self.attrs字典中的"shape"键对应的值 return self.attrs["shape"] + # 定义一个函数,用于推断格式 def _infer_format(self): + # 返回self.inputs列表中第一个元素的data_format属性 return self.inputs[0].data_format class _CompareOp(_Elemwise): """Compare operators""" + # 定义一个函数,用于推断类型 def _infer_type(self): + # 返回类型为bool return "bool" @@ -286,11 +351,14 @@ class CImag(OpInfer): """CImag op infer""" def _check_type(self): + # 检查输入数据的类型是否为complex64 if self.inputs[0].dtype != "complex64": + # 如果不是,则抛出异常 raise GKException( "For 'CImag', input[0] should be of type complex64, but got {}".format(self.inputs[0].dtype)) def _infer_type(self): + # 返回数据类型为float32 return "float32" @@ -298,11 +366,14 @@ class CReal(OpInfer): """CReal op infer""" def _check_type(self): + # 检查输入数据的类型是否为complex64 if self.inputs[0].dtype != "complex64": + # 如果不是,则抛出异常 raise GKException( "For 'CReal', input[0] should be of type complex64, but got {}".format(self.inputs[0].dtype)) def _infer_type(self): + # 返回数据类型为float32 return "float32" @@ -310,14 +381,17 @@ class Complex(OpInfer): """Complex op infer""" def _check_type(self): + # 检查输入数据的类型是否为float32 if self.inputs[0].dtype != "float32": raise GKException( "For 'Complex', input[0] should be of type float32, but got {}".format(self.inputs[0].dtype)) + # 检查输入数据的类型是否一致 if self.inputs[0].dtype != self.inputs[1].dtype: raise GKException("For 'Complex', inputs data type mismatch ({} vs {})" .format(self.inputs[0].dtype, self.inputs[1].dtype)) def _infer_type(self): + # 返回复数类型 return "complex64" @@ -345,40 +419,53 @@ class Select(_Elemwise): """Select op infer""" def _check_type(self): + # 检查输入数据的类型 if self.inputs[0].dtype != 
"bool": + # 如果输入数据的类型不是bool,则抛出异常 raise GKException("For 'Select', input[0] should be of type bool, but got {}".format(self.inputs[0].dtype)) if self.inputs[1].dtype != self.inputs[2].dtype: + # 如果输入数据的类型不一致,则抛出异常 raise GKException("For 'Select', input[1] and input[2] data type mismatch ({} vs {})" .format(self.inputs[1].dtype, self.inputs[2].dtype)) def _infer_type(self): + # 推断输入数据的类型 return self.inputs[1].dtype def check_format_any(formats, checked_format): """Check whether input format in formats list""" + # 检查输入格式是否在formats列表中 if not isinstance(formats, (list, tuple)): + # 如果formats不是list或tuple类型,则抛出异常 raise GKException("formats {} should be of type list or tuple, but got {}.".format(formats, type(formats))) if checked_format not in formats: + # 如果checked_format不在formats列表中,则抛出异常 raise GKException("Check {} failed: can not find it in {}".format(checked_format, formats)) def check_nd(data, nd): """Check whether data are nd format""" + # 检查数据是否为nd格式 if not isinstance(data, (list, tuple)) or len(data) != nd: + # 如果数据不是list或tuple类型,或者数据的维度不等于nd,则抛出异常 raise GKException("input should be {}D list or tuple, but got {}.".format(nd, data)) def conv_had_pad(pad_list, pad_mode): """Check whether conv need to add pad""" + # 检查pad_list是否为4D list或tuple,如果不是则抛出异常 if not isinstance(pad_list, (list, tuple)) or len(pad_list) != 4: raise GKException("pad_list should be 4D list or tuple, but got {}".format(pad_list)) + # 如果pad_list的前两个元素不相等或后两个元素不相等,则返回True if pad_list[0] != pad_list[1] or pad_list[2] != pad_list[3]: return True + # 如果pad_mode不是"VALID"或"valid",则遍历pad_list,如果有元素不为0,则返回True if pad_mode not in ["VALID", "valid"]: for _, pad in enumerate(pad_list): if pad != 0: return True + # 否则返回False return False @@ -386,38 +473,50 @@ class Conv2D(OpInfer): """Conv2D infer""" def _infer_type(self): + # 如果attrs是dict类型且包含"dst_type"键,则返回"dst_type"的值,否则返回输入的第一个元素的dtype if isinstance(self.attrs, dict) and "dst_type" in self.attrs: return self.attrs["dst_type"] return self.inputs[0].dtype def _infer_shape(self): + # 将输入的第一个和第二个元素的shape转换为list shape_0 = list(self.inputs[0].shape) shape_1 = list(self.inputs[1].shape) + # 检查shape_0和shape_1的维度是否为4 check_nd(shape_0, 4) check_nd(shape_1, 4) + # 检查输入的data_format是否为NHWC formats = [self.inputs[0].data_format, self.inputs[1].data_format, self.attrs["format"]] check_format_any(formats, DF.NHWC) + # 获取输入的n、h、w和out_channel n, h, w, out_channel = shape_0[0], shape_0[1], shape_0[2], shape_1[0] + # 获取pad_list和pad_mode pad_list = self.attrs["pad_list"] pad_mode = self.attrs["pad_mode"] + # 获取kernel_size、stride和dilation kernel_size = self.attrs["kernel_size"] stride = self.attrs["stride"] dilation = self.attrs["dilation"] + # 检查pad_list、kernel_size、stride和dilation的维度是否为4、2、4和4 check_nd(pad_list, 4) check_nd(kernel_size, 2) check_nd(stride, 4) check_nd(dilation, 4) + # 调用conv_had_pad函数,判断是否需要pad has_pad = conv_had_pad(pad_list, pad_mode) + # 如果不需要pad,则将pad_list设置为[0, 0, 0, 0] if not has_pad: pad_list = [0, 0, 0, 0] + # 计算输出的h和w k_h = (kernel_size[0] - 1) * dilation[-2] + 1 k_w = (kernel_size[1] - 1) * dilation[-1] + 1 out_h = (h + pad_list[0] + pad_list[1] - k_h) // stride[-2] + 1 out_w = (w + pad_list[2] + pad_list[3] - k_w) // stride[-1] + 1 + # 返回输出的shape return [n, out_h, out_w, out_channel] @@ -425,23 +524,31 @@ class MatMul(OpInfer): """MatMul infer""" def _infer_type(self): + # 如果attrs是dict类型且包含"dst_type"键,则返回"dst_type"的值,否则返回输入的第一个元素的dtype if isinstance(self.attrs, dict) and "dst_type" in self.attrs: return self.attrs["dst_type"] return self.inputs[0].dtype def 
_infer_shape(self): + # 将输入的第一个和第二个元素的shape转换为list shape_0 = list(self.inputs[0].shape) shape_1 = list(self.inputs[1].shape) + # 检查shape_0和shape_1的维度是否为2 if len(shape_0) != 2 or len(shape_1) != 2: raise GKException("For 'MatMul', inputs shape must be 2D, but got {}, {}" .format(shape_0, shape_1)) + # 获取transpose_a和transpose_b transpose_a = self.attrs["transpose_a"] transpose_b = self.attrs["transpose_b"] + # 根据transpose_a和transpose_b获取m、k1、k2和n m, k1 = (shape_0[-1], shape_0[-2]) if transpose_a else (shape_0[-2], shape_0[-1]) k2, n = (shape_1[-1], shape_1[-2]) if transpose_b else (shape_1[-2], shape_1[-1]) + # 如果k1和k2不相等,则抛出异常 if k1 != k2: raise GKException("For 'MatMul', inputs have different k value: {} vs {}".format(k1, k2)) + # 计算输出的shape output_shape = [m, n] + # 返回输出的shape return output_shape @@ -449,14 +556,20 @@ class PadAkg(OpInfer): """PadAkg infer""" def _infer_shape(self): + # 将输入的第一个元素的shape转换为list shape = list(self.inputs[0].shape) + # 获取输入的维度 n = len(shape) + # 获取pad_before和pad_after pad_before = list(self.attrs["head"]) pad_after = list(self.attrs["tail"]) + # 检查pad_before和pad_after的维度是否与输入的维度相等 if len(pad_before) != n or len(pad_after) != n: raise GKException("For 'PadAkg', input dimension and pad mismatch: {}d vs {}d vs {}d" .format(n, len(pad_before), len(pad_after))) + # 计算输出的shape out_shape = [shape[i] + pad_before[i] + pad_after[i] for i in range(n)] + # 返回输出的shape return out_shape @@ -464,13 +577,19 @@ class UnPadAkg(OpInfer): """UnPadAkg infer""" def _infer_shape(self): + # 将输入的第一个元素的shape转换为list shape = list(self.inputs[0].shape) + # 获取输入的维度 n = len(shape) + # 获取unpad_after unpad_after = list(self.attrs["tail"]) + # 检查unpad_after的维度是否与输入的维度相等 if len(unpad_after) != n: raise GKException("For 'UnPadAkg', input dimension and pad mismatch: {}d vs {}d" .format(n, len(unpad_after))) + # 计算输出的shape out_shape = [shape[i] - unpad_after[i] for i in range(n)] + # 返回输出的shape return out_shape @@ -478,23 +597,32 @@ class Gather(OpInfer): """Gather infer""" def _infer_shape(self): + # 获取输入的第一个和第二个元素的shape input_shape = self.inputs[0].shape indices_shape = self.inputs[1].shape + # 获取axis axis = self.attrs['axis'] + # 将输出的shape设置为输入的第一个元素的shape output_shape = input_shape + # 计算indices_shape的维度 indices_shape_one_dim = 1 for dim in indices_shape: indices_shape_one_dim *= dim + # 将输出的shape的axis维度设置为indices_shape的维度 output_shape[axis] = indices_shape_one_dim + # 返回输出的shape return output_shape def _infer_type(self): + # 返回输入的第一个元素的dtype return self.inputs[0].dtype def _infer_format(self): + # 返回输入的第一个元素的data_format return self.inputs[0].data_format def _check_type(self): + # 检查输入的第二个元素的dtype是否为int32,如果不是则抛出异常 if self.inputs[1].dtype != "int32": raise GKException("For 'Gather', inputs[1] should be of type int32, but got {}" .format(self.inputs[1].dtype))
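For reference, the shape-deduction rules introduced above can be exercised in isolation. The snippet below is a minimal standalone sketch (the names demo_broadcast_shape and demo_conv2d_out_hw are illustrative and not part of this patch); it mirrors the NumPy-style broadcasting used by _Elemwise.broadcast_shape and the NHWC output-size formula used by Conv2D._infer_shape, assuming 2-element kernel_size and 4-element stride/dilation/pad lists as checked in the code above.

# Minimal standalone sketch of the broadcasting rule in _Elemwise.broadcast_shape.
def demo_broadcast_shape(shapes):
    dim_size = max(len(shape) for shape in shapes)
    # Left-pad every shape with 1s so all shapes have the same rank.
    aligned = [[1] * (dim_size - len(shape)) + list(shape) for shape in shapes]
    out_shape = [1] * dim_size
    for i in range(dim_size):
        for shape in aligned:
            if shape[i] == 1:
                continue
            if out_shape[i] == 1:
                out_shape[i] = shape[i]
            elif out_shape[i] != shape[i]:
                raise ValueError("Input shapes {} can not broadcast.".format(shapes))
    return out_shape

# Minimal standalone sketch of the NHWC output-size formula in Conv2D._infer_shape.
def demo_conv2d_out_hw(h, w, kernel_size, stride, dilation, pad_list):
    k_h = (kernel_size[0] - 1) * dilation[-2] + 1
    k_w = (kernel_size[1] - 1) * dilation[-1] + 1
    out_h = (h + pad_list[0] + pad_list[1] - k_h) // stride[-2] + 1
    out_w = (w + pad_list[2] + pad_list[3] - k_w) // stride[-1] + 1
    return out_h, out_w

if __name__ == "__main__":
    # [4, 1, 3] broadcast with [1, 5, 3] gives [4, 5, 3], as in numpy.
    assert demo_broadcast_shape([[4, 1, 3], [1, 5, 3]]) == [4, 5, 3]
    # A 3x3 kernel with stride 1, dilation 1 and symmetric padding 1 keeps 32x32.
    assert demo_conv2d_out_hw(32, 32, (3, 3), [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]) == (32, 32)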