|
|
|
@ -21,24 +21,48 @@ from . import model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def estimate_ops(json_str):
|
|
|
|
|
"""
|
|
|
|
|
估计操作数。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
json_str (str): 包含图描述的json字符串。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
tuple: 包含估计结果的元组,包括块分配、增益、融合类型和类型信息的元组。
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
JSONDecodeError: 如果输入的json字符串无法解码,将引发此异常。
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Call cost model to estimate ops."""
|
|
|
|
|
try:
|
|
|
|
|
# 将json字符串转换为json对象
|
|
|
|
|
json_obj = json.loads(json_str)
|
|
|
|
|
# 获取json对象中的graph_desc
|
|
|
|
|
graph_descs = json_obj["graph_desc"]
|
|
|
|
|
# 初始化graphs和target
|
|
|
|
|
graphs = []
|
|
|
|
|
target = None
|
|
|
|
|
# 遍历graph_descs
|
|
|
|
|
for gd in graph_descs:
|
|
|
|
|
# 如果target为空,则将gd['process']赋值给target
|
|
|
|
|
if target is None:
|
|
|
|
|
target = gd['process']
|
|
|
|
|
# 如果target不为空,且gd['process']与target不同,则输出错误信息
|
|
|
|
|
elif target != gd['process']:
|
|
|
|
|
logger.error("Parallel fusion does not support multi-target({} and {})".format(target, gd['process']))
|
|
|
|
|
return None
|
|
|
|
|
# 将model.load_composite(gd).graph添加到graphs中
|
|
|
|
|
graphs.append(model.load_composite(gd).graph)
|
|
|
|
|
# 调用model.parallel_estimate函数,传入graphs和target,获取estimation
|
|
|
|
|
estimation = model.parallel_estimate(graphs, target)
|
|
|
|
|
# 将estimation的block_assign、gain、fusion_type和type_info赋值给res
|
|
|
|
|
res = (estimation.block_assign, estimation.gain,
|
|
|
|
|
estimation.fusion_type, estimation.type_info)
|
|
|
|
|
# 返回res
|
|
|
|
|
return res
|
|
|
|
|
except jd.JSONDecodeError:
|
|
|
|
|
# 如果出现JSONDecodeError,则输出错误信息
|
|
|
|
|
logger.error(traceback.format_exc())
|
|
|
|
|
return None
|
|
|
|
|
finally:
|
|
|
|
@ -46,14 +70,33 @@ def estimate_ops(json_str):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def estimate_calculation_amount(json_str):
|
|
|
|
|
"""
|
|
|
|
|
估计操作计算量的函数。
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
json_str (str): 包含操作描述的JSON字符串。
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
int: 计算量的估计值,如果解析JSON字符串失败,则返回-1。
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
无
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
"""Call cost model to estimate calculation amount of op."""
|
|
|
|
|
try:
|
|
|
|
|
# 将json字符串转换为json对象
|
|
|
|
|
graph_desc = json.loads(json_str)
|
|
|
|
|
# 获取json对象中的process
|
|
|
|
|
target = graph_desc['process']
|
|
|
|
|
# 调用model.load_composite函数,传入graph_desc,获取comp
|
|
|
|
|
comp = model.load_composite(graph_desc)
|
|
|
|
|
# 调用model.parallel_estimate函数,传入comp.graph和target,获取estimation
|
|
|
|
|
estimation = model.parallel_estimate([comp.graph], target)
|
|
|
|
|
# 返回estimation的bottleneck
|
|
|
|
|
return estimation.bottleneck
|
|
|
|
|
except jd.JSONDecodeError:
|
|
|
|
|
# 如果出现JSONDecodeError,则输出错误信息
|
|
|
|
|
logger.error(traceback.format_exc())
|
|
|
|
|
return -1
|
|
|
|
|
finally:
|
|
|
|
|