_extends\parallel_compile\akg_compiler

branch-yixin
yixin 7 months ago
parent cd3f01ab90
commit 0012f23abf

@ -16,4 +16,8 @@
Extension functions. Extension functions.
Python functions that will be called in the c++ parts of MindSpore. Python functions that will be called in the c++ parts of MindSpore.
扩展函数
这些Python函数将在MindSpore的C++部分中被调用
""" """

@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""akg process""" """akg process"""
import os import os
import json import json
import subprocess import subprocess
@ -24,10 +26,10 @@ from mindspore._extends.parallel_compile.akg_compiler.get_file_path import get_a
def _compile_akg_task_default(json_strs, attrs): def _compile_akg_task_default(json_strs, attrs):
""" """
compile func called in single process 编译函数在单个进程中调用
Parameters: 参数
json_strs: list. List contains multiple kernel infos, suitable for json compile api. json_strs列表包含多个内核信息的列表适用于json编译API
""" """
sys.path.insert(0, get_akg_path()) sys.path.insert(0, get_akg_path())
@ -37,15 +39,15 @@ def _compile_akg_task_default(json_strs, attrs):
for json_str in json_strs: for json_str in json_strs:
res = func(json_str, attrs) res = func(json_str, attrs)
if not res: if not res:
raise ValueError("Compile error, args: {}! build attrs: {}".format(json_str, attrs)) raise ValueError("编译错误,参数:{}!构建属性:{}".format(json_str, attrs))
def _compile_akg_task_ascend(json_strs, attrs): def _compile_akg_task_ascend(json_strs, attrs):
""" """
compile func called in single process 编译函数在单个进程中调用
Parameters: 参数
json_strs: list. List contains multiple kernel infos, suitable for json compile api. json_strs列表包含多个内核信息的列表适用于json编译API
""" """
if attrs is None: if attrs is None:
attrs = "{}" attrs = "{}"
@ -56,35 +58,33 @@ def _compile_akg_task_ascend(json_strs, attrs):
if compile_result.returncode: if compile_result.returncode:
json_dict = json.loads(json_str) json_dict = json.loads(json_str)
if not json_dict.get("composite"): if not json_dict.get("composite"):
raise ValueError("Compile error, json str: {}! build attrs: {}".format(json_str, attrs)) raise ValueError("编译错误json字符串{}!构建属性:{}".format(json_str, attrs))
logger.debug("Will try to split, json str: {}! build attrs: {}".format(json_str, attrs)) logger.debug("将尝试拆分json字符串{}!构建属性:{}".format(json_str, attrs))
def create_akg_parallel_process(process_num, wait_time, platform): def create_akg_parallel_process(process_num, wait_time, platform):
""" """
create AkgParallelCompiler object 创建AkgParallelCompiler对象
Returns: 返回
AkgParallelCompiler AkgParallelCompiler
""" """
return AkgProcess(process_num, wait_time, platform) return AkgProcess(process_num, wait_time, platform)
class AkgProcess: class AkgProcess:
"""akg kernel parallel process""" """akg内核并行进程"""
def __init__(self, process_num, wait_time, platform): def __init__(self, process_num, wait_time, platform):
""" """
Args: 参数
process_num: int. processes number process_numint进程数量
wait_time: int. max time the function blocked wait_timeint函数阻塞的最大时间
""" """
if not isinstance(process_num, int): if not isinstance(process_num, int):
raise ValueError("AKG kernel compiling process number must be of type int, but got {} with type {}" raise ValueError("AKG内核编译进程数量必须是int类型但得到的是{},类型为{}".format(process_num, type(wait_time)))
.format(process_num, type(wait_time)))
if not isinstance(wait_time, int): if not isinstance(wait_time, int):
raise ValueError("AKG kernel compiling wait time must be of type int, but got {} with type {}" raise ValueError("AKG内核编译等待时间必须是int类型但得到的是{},类型为{}".format(wait_time, type(wait_time)))
.format(wait_time, type(wait_time)))
if process_num == 0: if process_num == 0:
process_num = 1 process_num = 1
max_proc_num = 16 max_proc_num = 16
@ -96,13 +96,12 @@ class AkgProcess:
def compile(self, attrs=None): def compile(self, attrs=None):
""" """
compile kernel by multi processes 多进程编译内核
Return: 返回
True for all compile success, False for some failed. 所有编译成功返回True部分失败返回False
""" """
if self.argc == 0: if self.argc == 0:
raise ValueError("In AKG kernel compiling, the number of kernel json that need to be compiled can " raise ValueError("在AKG内核编译中需要编译的内核json数量不能为零。")
"not be zero.")
args = list((arg, attrs) for arg in self.args) args = list((arg, attrs) for arg in self.args)
if self.platform == "ASCEND": if self.platform == "ASCEND":
with Pool(processes=self.process_num) as pool: with Pool(processes=self.process_num) as pool:
@ -116,12 +115,11 @@ class AkgProcess:
def accept_json(self, json_str): def accept_json(self, json_str):
""" """
accept json data before compile 在编译前接受内核的json数据
Args: 参数
json_str: str. kernel info. json_strstr内核信息
""" """
if not isinstance(json_str, str): if not isinstance(json_str, str):
raise ValueError("In AKG kernel compiling, the kernel json must be of type str, but got {} with type {}" raise ValueError("在AKG内核编译中内核json必须是str类型但得到的是{},类型为{}".format(json_str, type(json_str)))
.format(json, type(json)))
self.args[self.argc % self.process_num].append(json_str) self.args[self.argc % self.process_num].append(json_str)
self.argc += 1 self.argc += 1

@ -28,16 +28,23 @@ def run_compiler(op_json, attrs=None):
None None
""" """
from get_file_path import get_akg_path from get_file_path import get_akg_path
# 将akg路径添加到sys.path中
sys.path.insert(0, get_akg_path()) sys.path.insert(0, get_akg_path())
# 导入akg模块
p = __import__("akg", globals(), locals(), ['ms'], 0) p = __import__("akg", globals(), locals(), ['ms'], 0)
# 获取akg.ms.compilewithjson函数
func = getattr(p.ms, "compilewithjson") func = getattr(p.ms, "compilewithjson")
# 调用akg.ms.compilewithjson函数进行编译
res = func(op_json, attrs) res = func(op_json, attrs)
# 如果编译失败,抛出异常
if not res: if not res:
raise ValueError("Compile error") raise ValueError("Compile error")
if __name__ == "__main__": if __name__ == "__main__":
# 如果命令行参数大于2则调用run_compiler函数传入op_json和attrs
if len(sys.argv) > 2: if len(sys.argv) > 2:
run_compiler(sys.argv[1], sys.argv[2]) run_compiler(sys.argv[1], sys.argv[2])
# 否则只传入op_json
else: else:
run_compiler(sys.argv[1]) run_compiler(sys.argv[1])

@ -19,18 +19,27 @@ import os
def get_akg_path(): def get_akg_path():
"""get akg directory base path""" """get akg directory base path"""
# 提示信息如果找不到mindspore模块请检查1MindSpore是否成功编译。2MindSpore是否成功安装使用pip install安装或设置环境变量PYTHONPATH为${mindspore_build_dir}/package
hint = "Please check: 1) whether MindSpore is compiled successfully. " \ hint = "Please check: 1) whether MindSpore is compiled successfully. " \
"2) Whether MindSpore is installed successfully with pip install or " \ "2) Whether MindSpore is installed successfully with pip install or " \
"the path ${mindspore_build_dir}/package is set in env PYTHONPATH." "the path ${mindspore_build_dir}/package is set in env PYTHONPATH."
# 查找mindspore模块
search_res = importlib.util.find_spec("mindspore") search_res = importlib.util.find_spec("mindspore")
if search_res is None: if search_res is None:
# 如果找不到mindspore模块抛出异常
raise RuntimeError("Cannot find mindspore module! {}".format(hint)) raise RuntimeError("Cannot find mindspore module! {}".format(hint))
# 获取mindspore模块的路径
res_path = search_res.origin res_path = search_res.origin
# 在路径中查找__init__.py文件
find_pos = res_path.find("__init__.py") find_pos = res_path.find("__init__.py")
if find_pos == -1: if find_pos == -1:
# 如果找不到__init__.py文件抛出异常
raise RuntimeError("Find module mindspore origin file failed! {}".format(hint)) raise RuntimeError("Find module mindspore origin file failed! {}".format(hint))
# 获取akg路径
akg_path = "{}_akg".format(res_path[:find_pos]) akg_path = "{}_akg".format(res_path[:find_pos])
# 如果akg路径不存在抛出异常
if not os.path.isdir(akg_path): if not os.path.isdir(akg_path):
raise RuntimeError("Cannot find akg from mindspore module! {}".format(hint)) raise RuntimeError("Cannot find akg from mindspore module! {}".format(hint))
# 返回akg路径
return akg_path return akg_path

Loading…
Cancel
Save