diff --git a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/parallel_estimate.py b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/parallel_estimate.py index c3b307cf..a96421c0 100644 --- a/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/parallel_estimate.py +++ b/src/mindspore2022/mindspore/python/mindspore/_extends/graph_kernel/parallel_estimate.py @@ -21,24 +21,48 @@ from . import model def estimate_ops(json_str): + """ + 估计操作数。 + + Args: + json_str (str): 包含图描述的json字符串。 + + Returns: + tuple: 包含估计结果的元组,包括块分配、增益、融合类型和类型信息的元组。 + + Raises: + JSONDecodeError: 如果输入的json字符串无法解码,将引发此异常。 + + """ """Call cost model to estimate ops.""" try: + # 将json字符串转换为json对象 json_obj = json.loads(json_str) + # 获取json对象中的graph_desc graph_descs = json_obj["graph_desc"] + # 初始化graphs和target graphs = [] target = None + # 遍历graph_descs for gd in graph_descs: + # 如果target为空,则将gd['process']赋值给target if target is None: target = gd['process'] + # 如果target不为空,且gd['process']与target不同,则输出错误信息 elif target != gd['process']: logger.error("Parallel fusion does not support multi-target({} and {})".format(target, gd['process'])) return None + # 将model.load_composite(gd).graph添加到graphs中 graphs.append(model.load_composite(gd).graph) + # 调用model.parallel_estimate函数,传入graphs和target,获取estimation estimation = model.parallel_estimate(graphs, target) + # 将estimation的block_assign、gain、fusion_type和type_info赋值给res res = (estimation.block_assign, estimation.gain, estimation.fusion_type, estimation.type_info) + # 返回res return res except jd.JSONDecodeError: + # 如果出现JSONDecodeError,则输出错误信息 logger.error(traceback.format_exc()) return None finally: @@ -46,14 +70,33 @@ def estimate_ops(json_str): def estimate_calculation_amount(json_str): + """ + 估计操作计算量的函数。 + + Args: + json_str (str): 包含操作描述的JSON字符串。 + + Returns: + int: 计算量的估计值,如果解析JSON字符串失败,则返回-1。 + + Raises: + 无 + + """ """Call cost model to estimate calculation amount of op.""" try: + # 将json字符串转换为json对象 graph_desc = json.loads(json_str) + # 获取json对象中的process target = graph_desc['process'] + # 调用model.load_composite函数,传入graph_desc,获取comp comp = model.load_composite(graph_desc) + # 调用model.parallel_estimate函数,传入comp.graph和target,获取estimation estimation = model.parallel_estimate([comp.graph], target) + # 返回estimation的bottleneck return estimation.bottleneck except jd.JSONDecodeError: + # 如果出现JSONDecodeError,则输出错误信息 logger.error(traceback.format_exc()) return -1 finally: