|
|
@ -12,7 +12,9 @@
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
# limitations under the License.
|
|
|
|
# ============================================================================
|
|
|
|
# ============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
"""akg process"""
|
|
|
|
"""akg process"""
|
|
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
import os
|
|
|
|
import json
|
|
|
|
import json
|
|
|
|
import subprocess
|
|
|
|
import subprocess
|
|
|
@ -24,10 +26,10 @@ from mindspore._extends.parallel_compile.akg_compiler.get_file_path import get_a
|
|
|
|
|
|
|
|
|
|
|
|
def _compile_akg_task_default(json_strs, attrs):
|
|
|
|
def _compile_akg_task_default(json_strs, attrs):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
compile func called in single process
|
|
|
|
编译函数,在单个进程中调用
|
|
|
|
|
|
|
|
|
|
|
|
Parameters:
|
|
|
|
参数:
|
|
|
|
json_strs: list. List contains multiple kernel infos, suitable for json compile api.
|
|
|
|
json_strs:列表。包含多个内核信息的列表,适用于json编译API。
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
sys.path.insert(0, get_akg_path())
|
|
|
|
sys.path.insert(0, get_akg_path())
|
|
|
@ -37,15 +39,15 @@ def _compile_akg_task_default(json_strs, attrs):
|
|
|
|
for json_str in json_strs:
|
|
|
|
for json_str in json_strs:
|
|
|
|
res = func(json_str, attrs)
|
|
|
|
res = func(json_str, attrs)
|
|
|
|
if not res:
|
|
|
|
if not res:
|
|
|
|
raise ValueError("Compile error, args: {}! build attrs: {}".format(json_str, attrs))
|
|
|
|
raise ValueError("编译错误,参数:{}!构建属性:{}".format(json_str, attrs))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _compile_akg_task_ascend(json_strs, attrs):
|
|
|
|
def _compile_akg_task_ascend(json_strs, attrs):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
compile func called in single process
|
|
|
|
编译函数,在单个进程中调用
|
|
|
|
|
|
|
|
|
|
|
|
Parameters:
|
|
|
|
参数:
|
|
|
|
json_strs: list. List contains multiple kernel infos, suitable for json compile api.
|
|
|
|
json_strs:列表。包含多个内核信息的列表,适用于json编译API。
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
if attrs is None:
|
|
|
|
if attrs is None:
|
|
|
|
attrs = "{}"
|
|
|
|
attrs = "{}"
|
|
|
@ -56,35 +58,33 @@ def _compile_akg_task_ascend(json_strs, attrs):
|
|
|
|
if compile_result.returncode:
|
|
|
|
if compile_result.returncode:
|
|
|
|
json_dict = json.loads(json_str)
|
|
|
|
json_dict = json.loads(json_str)
|
|
|
|
if not json_dict.get("composite"):
|
|
|
|
if not json_dict.get("composite"):
|
|
|
|
raise ValueError("Compile error, json str: {}! build attrs: {}".format(json_str, attrs))
|
|
|
|
raise ValueError("编译错误,json字符串:{}!构建属性:{}".format(json_str, attrs))
|
|
|
|
logger.debug("Will try to split, json str: {}! build attrs: {}".format(json_str, attrs))
|
|
|
|
logger.debug("将尝试拆分,json字符串:{}!构建属性:{}".format(json_str, attrs))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_akg_parallel_process(process_num, wait_time, platform):
|
|
|
|
def create_akg_parallel_process(process_num, wait_time, platform):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
create AkgParallelCompiler object
|
|
|
|
创建AkgParallelCompiler对象
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
返回:
|
|
|
|
AkgParallelCompiler
|
|
|
|
AkgParallelCompiler
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
return AkgProcess(process_num, wait_time, platform)
|
|
|
|
return AkgProcess(process_num, wait_time, platform)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AkgProcess:
|
|
|
|
class AkgProcess:
|
|
|
|
"""akg kernel parallel process"""
|
|
|
|
"""akg内核并行进程"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, process_num, wait_time, platform):
|
|
|
|
def __init__(self, process_num, wait_time, platform):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Args:
|
|
|
|
参数:
|
|
|
|
process_num: int. processes number
|
|
|
|
process_num:int。进程数量
|
|
|
|
wait_time: int. max time the function blocked
|
|
|
|
wait_time:int。函数阻塞的最大时间
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
if not isinstance(process_num, int):
|
|
|
|
if not isinstance(process_num, int):
|
|
|
|
raise ValueError("AKG kernel compiling process number must be of type int, but got {} with type {}"
|
|
|
|
raise ValueError("AKG内核编译进程数量必须是int类型,但得到的是{},类型为{}".format(process_num, type(wait_time)))
|
|
|
|
.format(process_num, type(wait_time)))
|
|
|
|
|
|
|
|
if not isinstance(wait_time, int):
|
|
|
|
if not isinstance(wait_time, int):
|
|
|
|
raise ValueError("AKG kernel compiling wait time must be of type int, but got {} with type {}"
|
|
|
|
raise ValueError("AKG内核编译等待时间必须是int类型,但得到的是{},类型为{}".format(wait_time, type(wait_time)))
|
|
|
|
.format(wait_time, type(wait_time)))
|
|
|
|
|
|
|
|
if process_num == 0:
|
|
|
|
if process_num == 0:
|
|
|
|
process_num = 1
|
|
|
|
process_num = 1
|
|
|
|
max_proc_num = 16
|
|
|
|
max_proc_num = 16
|
|
|
@ -96,13 +96,12 @@ class AkgProcess:
|
|
|
|
|
|
|
|
|
|
|
|
def compile(self, attrs=None):
|
|
|
|
def compile(self, attrs=None):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
compile kernel by multi processes
|
|
|
|
多进程编译内核
|
|
|
|
Return:
|
|
|
|
返回:
|
|
|
|
True for all compile success, False for some failed.
|
|
|
|
所有编译成功返回True,部分失败返回False。
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
if self.argc == 0:
|
|
|
|
if self.argc == 0:
|
|
|
|
raise ValueError("In AKG kernel compiling, the number of kernel json that need to be compiled can "
|
|
|
|
raise ValueError("在AKG内核编译中,需要编译的内核json数量不能为零。")
|
|
|
|
"not be zero.")
|
|
|
|
|
|
|
|
args = list((arg, attrs) for arg in self.args)
|
|
|
|
args = list((arg, attrs) for arg in self.args)
|
|
|
|
if self.platform == "ASCEND":
|
|
|
|
if self.platform == "ASCEND":
|
|
|
|
with Pool(processes=self.process_num) as pool:
|
|
|
|
with Pool(processes=self.process_num) as pool:
|
|
|
@ -116,12 +115,11 @@ class AkgProcess:
|
|
|
|
|
|
|
|
|
|
|
|
def accept_json(self, json_str):
|
|
|
|
def accept_json(self, json_str):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
accept json data before compile
|
|
|
|
在编译前接受内核的json数据
|
|
|
|
Args:
|
|
|
|
参数:
|
|
|
|
json_str: str. kernel info.
|
|
|
|
json_str:str。内核信息。
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
if not isinstance(json_str, str):
|
|
|
|
if not isinstance(json_str, str):
|
|
|
|
raise ValueError("In AKG kernel compiling, the kernel json must be of type str, but got {} with type {}"
|
|
|
|
raise ValueError("在AKG内核编译中,内核json必须是str类型,但得到的是{},类型为{}".format(json_str, type(json_str)))
|
|
|
|
.format(json, type(json)))
|
|
|
|
|
|
|
|
self.args[self.argc % self.process_num].append(json_str)
|
|
|
|
self.args[self.argc % self.process_num].append(json_str)
|
|
|
|
self.argc += 1
|
|
|
|
self.argc += 1
|
|
|
|