_extends\parallel_compile\akg_compiler

branch-yixin
yixin 7 months ago
parent cd3f01ab90
commit 0012f23abf

@ -16,4 +16,8 @@
Extension functions.
Python functions that will be called in the c++ parts of MindSpore.
扩展函数
这些Python函数将在MindSpore的C++部分中被调用
"""

@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""akg process"""
import os
import json
import subprocess
@ -24,10 +26,10 @@ from mindspore._extends.parallel_compile.akg_compiler.get_file_path import get_a
def _compile_akg_task_default(json_strs, attrs):
"""
compile func called in single process
编译函数在单个进程中调用
Parameters:
json_strs: list. List contains multiple kernel infos, suitable for json compile api.
参数
json_strs列表包含多个内核信息的列表适用于json编译API
"""
sys.path.insert(0, get_akg_path())
@ -37,15 +39,15 @@ def _compile_akg_task_default(json_strs, attrs):
for json_str in json_strs:
res = func(json_str, attrs)
if not res:
raise ValueError("Compile error, args: {}! build attrs: {}".format(json_str, attrs))
raise ValueError("编译错误,参数:{}!构建属性:{}".format(json_str, attrs))
def _compile_akg_task_ascend(json_strs, attrs):
"""
compile func called in single process
编译函数在单个进程中调用
Parameters:
json_strs: list. List contains multiple kernel infos, suitable for json compile api.
参数
json_strs列表包含多个内核信息的列表适用于json编译API
"""
if attrs is None:
attrs = "{}"
@ -56,35 +58,33 @@ def _compile_akg_task_ascend(json_strs, attrs):
if compile_result.returncode:
json_dict = json.loads(json_str)
if not json_dict.get("composite"):
raise ValueError("Compile error, json str: {}! build attrs: {}".format(json_str, attrs))
logger.debug("Will try to split, json str: {}! build attrs: {}".format(json_str, attrs))
raise ValueError("编译错误json字符串{}!构建属性:{}".format(json_str, attrs))
logger.debug("将尝试拆分json字符串{}!构建属性:{}".format(json_str, attrs))
def create_akg_parallel_process(process_num, wait_time, platform):
"""
create AkgParallelCompiler object
创建AkgParallelCompiler对象
Returns:
返回
AkgParallelCompiler
"""
return AkgProcess(process_num, wait_time, platform)
class AkgProcess:
"""akg kernel parallel process"""
"""akg内核并行进程"""
def __init__(self, process_num, wait_time, platform):
"""
Args:
process_num: int. processes number
wait_time: int. max time the function blocked
参数
process_numint进程数量
wait_timeint函数阻塞的最大时间
"""
if not isinstance(process_num, int):
raise ValueError("AKG kernel compiling process number must be of type int, but got {} with type {}"
.format(process_num, type(wait_time)))
raise ValueError("AKG内核编译进程数量必须是int类型但得到的是{},类型为{}".format(process_num, type(wait_time)))
if not isinstance(wait_time, int):
raise ValueError("AKG kernel compiling wait time must be of type int, but got {} with type {}"
.format(wait_time, type(wait_time)))
raise ValueError("AKG内核编译等待时间必须是int类型但得到的是{},类型为{}".format(wait_time, type(wait_time)))
if process_num == 0:
process_num = 1
max_proc_num = 16
@ -96,13 +96,12 @@ class AkgProcess:
def compile(self, attrs=None):
"""
compile kernel by multi processes
Return:
True for all compile success, False for some failed.
多进程编译内核
返回
所有编译成功返回True部分失败返回False
"""
if self.argc == 0:
raise ValueError("In AKG kernel compiling, the number of kernel json that need to be compiled can "
"not be zero.")
raise ValueError("在AKG内核编译中需要编译的内核json数量不能为零。")
args = list((arg, attrs) for arg in self.args)
if self.platform == "ASCEND":
with Pool(processes=self.process_num) as pool:
@ -116,12 +115,11 @@ class AkgProcess:
def accept_json(self, json_str):
"""
accept json data before compile
Args:
json_str: str. kernel info.
在编译前接受内核的json数据
参数
json_strstr内核信息
"""
if not isinstance(json_str, str):
raise ValueError("In AKG kernel compiling, the kernel json must be of type str, but got {} with type {}"
.format(json, type(json)))
raise ValueError("在AKG内核编译中内核json必须是str类型但得到的是{},类型为{}".format(json_str, type(json_str)))
self.args[self.argc % self.process_num].append(json_str)
self.argc += 1

@ -28,16 +28,23 @@ def run_compiler(op_json, attrs=None):
None
"""
from get_file_path import get_akg_path
# 将akg路径添加到sys.path中
sys.path.insert(0, get_akg_path())
# 导入akg模块
p = __import__("akg", globals(), locals(), ['ms'], 0)
# 获取akg.ms.compilewithjson函数
func = getattr(p.ms, "compilewithjson")
# 调用akg.ms.compilewithjson函数进行编译
res = func(op_json, attrs)
# 如果编译失败,抛出异常
if not res:
raise ValueError("Compile error")
if __name__ == "__main__":
# 如果命令行参数大于2则调用run_compiler函数传入op_json和attrs
if len(sys.argv) > 2:
run_compiler(sys.argv[1], sys.argv[2])
# 否则只传入op_json
else:
run_compiler(sys.argv[1])

@ -19,18 +19,27 @@ import os
def get_akg_path():
"""get akg directory base path"""
# 提示信息如果找不到mindspore模块请检查1MindSpore是否成功编译。2MindSpore是否成功安装使用pip install安装或设置环境变量PYTHONPATH为${mindspore_build_dir}/package
hint = "Please check: 1) whether MindSpore is compiled successfully. " \
"2) Whether MindSpore is installed successfully with pip install or " \
"the path ${mindspore_build_dir}/package is set in env PYTHONPATH."
# 查找mindspore模块
search_res = importlib.util.find_spec("mindspore")
if search_res is None:
# 如果找不到mindspore模块抛出异常
raise RuntimeError("Cannot find mindspore module! {}".format(hint))
# 获取mindspore模块的路径
res_path = search_res.origin
# 在路径中查找__init__.py文件
find_pos = res_path.find("__init__.py")
if find_pos == -1:
# 如果找不到__init__.py文件抛出异常
raise RuntimeError("Find module mindspore origin file failed! {}".format(hint))
# 获取akg路径
akg_path = "{}_akg".format(res_path[:find_pos])
# 如果akg路径不存在抛出异常
if not os.path.isdir(akg_path):
raise RuntimeError("Cannot find akg from mindspore module! {}".format(hint))
# 返回akg路径
return akg_path

Loading…
Cancel
Save