add comments for _extends/graph_kernel/parrellel_estimate.py

branch-yixin
yixin 7 months ago
parent bfa789008e
commit 5437cea5c8

@ -21,24 +21,48 @@ from . import model
def estimate_ops(json_str):
"""
估计操作数
Args:
json_str (str): 包含图描述的json字符串
Returns:
tuple: 包含估计结果的元组包括块分配增益融合类型和类型信息的元组
Raises:
JSONDecodeError: 如果输入的json字符串无法解码将引发此异常
"""
"""Call cost model to estimate ops."""
try:
# 将json字符串转换为json对象
json_obj = json.loads(json_str)
# 获取json对象中的graph_desc
graph_descs = json_obj["graph_desc"]
# 初始化graphs和target
graphs = []
target = None
# 遍历graph_descs
for gd in graph_descs:
# 如果target为空则将gd['process']赋值给target
if target is None:
target = gd['process']
# 如果target不为空且gd['process']与target不同则输出错误信息
elif target != gd['process']:
logger.error("Parallel fusion does not support multi-target({} and {})".format(target, gd['process']))
return None
# 将model.load_composite(gd).graph添加到graphs中
graphs.append(model.load_composite(gd).graph)
# 调用model.parallel_estimate函数传入graphs和target获取estimation
estimation = model.parallel_estimate(graphs, target)
# 将estimation的block_assign、gain、fusion_type和type_info赋值给res
res = (estimation.block_assign, estimation.gain,
estimation.fusion_type, estimation.type_info)
# 返回res
return res
except jd.JSONDecodeError:
# 如果出现JSONDecodeError则输出错误信息
logger.error(traceback.format_exc())
return None
finally:
@ -46,14 +70,33 @@ def estimate_ops(json_str):
def estimate_calculation_amount(json_str):
"""
估计操作计算量的函数
Args:
json_str (str): 包含操作描述的JSON字符串
Returns:
int: 计算量的估计值如果解析JSON字符串失败则返回-1
Raises:
"""
"""Call cost model to estimate calculation amount of op."""
try:
# 将json字符串转换为json对象
graph_desc = json.loads(json_str)
# 获取json对象中的process
target = graph_desc['process']
# 调用model.load_composite函数传入graph_desc获取comp
comp = model.load_composite(graph_desc)
# 调用model.parallel_estimate函数传入comp.graph和target获取estimation
estimation = model.parallel_estimate([comp.graph], target)
# 返回estimation的bottleneck
return estimation.bottleneck
except jd.JSONDecodeError:
# 如果出现JSONDecodeError则输出错误信息
logger.error(traceback.format_exc())
return -1
finally:

Loading…
Cancel
Save