(add comments and docstrings to resources.py for clarity)

branch-yixin
yixin 7 months ago
parent eff398104d
commit c7222ba353

@ -17,7 +17,7 @@
"""Resources for ast tree parse.""" """Resources for ast tree parse."""
import ast import ast
import math import math
from mindspore import RowTensor, SparseTensor, COOTensor, CSRTensor from mindspore import RowTensor, SparseTensor, COOTensor, CSRTensor
from mindspore.ops import functional as F, composite as C from mindspore.ops import functional as F, composite as C
from mindspore.ops.composite import multitype_ops from mindspore.ops.composite import multitype_ops
@ -25,16 +25,16 @@ from mindspore._c_expression import security
from . import standard_method as M from . import standard_method as M
from . import trope as T from . import trope as T
from .namespace import CellNamespace from .namespace import CellNamespace
# namespace define # namespace define
functional_ns = CellNamespace('mindspore.ops.functional') functional_ns = CellNamespace('mindspore.ops.functional')
composite_ns = CellNamespace('mindspore.ops.composite') composite_ns = CellNamespace('mindspore.ops.composite')
trope_ns = CellNamespace('mindspore._extends.parse.trope') trope_ns = CellNamespace('mindspore._extends.parse.trope')
NO_IMPLEMENT = None # not implemented NO_IMPLEMENT = None # not implemented
SYMBOL_UNDEFINE = 0xFF # Undefined var and function SYMBOL_UNDEFINE = 0xFF # Undefined var and function
# Some space set aside for readability of code # 一些空间设置以提高代码可读性
parse_object_map = { parse_object_map = {
# ast grammar # ast grammar
ast.Add: (trope_ns, 'add'), ast.Add: (trope_ns, 'add'),
@ -64,17 +64,17 @@ parse_object_map = {
ast.IsNot: (trope_ns, 'is_not'), ast.IsNot: (trope_ns, 'is_not'),
ast.In: (trope_ns, 'contains'), ast.In: (trope_ns, 'contains'),
ast.NotIn: (trope_ns, 'not_contains'), ast.NotIn: (trope_ns, 'not_contains'),
# operation symbol type # operation symbol type
'getitem': (composite_ns, 'getitem'), 'getitem': (composite_ns, 'getitem'),
'ms_iter': (composite_ns, 'ms_iter'), 'ms_iter': (composite_ns, 'ms_iter'),
'ms_next': (composite_ns, 'ms_next'), 'ms_next': (composite_ns, 'ms_next'),
'hasnext': (composite_ns, 'hasnext'), 'hasnext': (composite_ns, 'hasnext'),
# undefined type # undefined type
SYMBOL_UNDEFINE: (None, 'undefine'), SYMBOL_UNDEFINE: (None, 'undefine'),
} }
# Operation symbols corresponding to ast grammar # Operation symbols corresponding to ast grammar
ops_symbol_map = { ops_symbol_map = {
# ast grammar # ast grammar
@ -88,13 +88,13 @@ ops_symbol_map = {
ast.LShift: '<<', ast.LShift: '<<',
ast.RShift: '>>', ast.RShift: '>>',
ast.BitXor: '^', ast.BitXor: '^',
# undefined type # undefined type
SYMBOL_UNDEFINE: '', SYMBOL_UNDEFINE: '',
} }
# Escape an object to another object, eg: system function(len,xxx) # 将一个对象转为另一个对象,例如:系统函数(len,xxx)
# Some space set aside for readability of code # 一些空间设置以提高代码可读性
convert_object_map = { convert_object_map = {
T.add: multitype_ops.add, T.add: multitype_ops.add,
T.sub: multitype_ops.sub, T.sub: multitype_ops.sub,
@ -124,7 +124,7 @@ convert_object_map = {
T.is_not: F.is_not, T.is_not: F.is_not,
T.contains: multitype_ops.in_, T.contains: multitype_ops.in_,
T.not_contains: multitype_ops.not_in_, T.not_contains: multitype_ops.not_in_,
# system function # system function
T.len: M.ms_len, T.len: M.ms_len,
T.bool_: M.bool_, T.bool_: M.bool_,
@ -134,7 +134,7 @@ convert_object_map = {
T.zip: C.zip_operation, T.zip: C.zip_operation,
T.enumerate: M.enumerate_, T.enumerate: M.enumerate_,
T.isinstance: M.isinstance_, T.isinstance: M.isinstance_,
# custom define operation # custom define operation
T.iter: M.ms_iter, T.iter: M.ms_iter,
T.next: M.ms_next, T.next: M.ms_next,
@ -145,7 +145,7 @@ convert_object_map = {
T.make_slice: F.make_slice, T.make_slice: F.make_slice,
T.range: F.make_range, T.range: F.make_range,
T.while_cond: M.while_cond, T.while_cond: M.while_cond,
# lib function # lib function
math.floor: NO_IMPLEMENT, math.floor: NO_IMPLEMENT,
math.trunc: NO_IMPLEMENT, math.trunc: NO_IMPLEMENT,
@ -154,13 +154,14 @@ convert_object_map = {
math.sin: NO_IMPLEMENT, math.sin: NO_IMPLEMENT,
math.cos: NO_IMPLEMENT, math.cos: NO_IMPLEMENT,
math.tan: NO_IMPLEMENT, math.tan: NO_IMPLEMENT,
# user defined # user defined
RowTensor: F.make_row_tensor, RowTensor: F.make_row_tensor,
SparseTensor: F.make_sparse_tensor, SparseTensor: F.make_sparse_tensor,
COOTensor: F.make_coo_tensor, COOTensor: F.make_coo_tensor,
CSRTensor: F.make_csr_tensor CSRTensor: F.make_csr_tensor
} }
# 如果不启用安全性,则将 T.print 映射到 F.print_
if not security.enable_security(): if not security.enable_security():
convert_object_map[T.print] = F.print_ convert_object_map[T.print] = F.print_

@ -50,55 +50,45 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt',
def MakeTuple(*elts): # pragma: no cover def MakeTuple(*elts): # pragma: no cover
"""Tuple builder.""" """Tuple builder.""" # 创建元组的构造函数
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')
def make_dict(key, value): # pragma: no cover def make_dict(key, value): # pragma: no cover
"""Dict builder.""" """Dict builder.""" # 创建字典的构造函数
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')
def make_list(*elts): # pragma: no cover def make_list(*elts): # pragma: no cover
"""List builder.""" """List builder.""" # 创建列表的构造函数
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')
def make_slice(*elts): # pragma: no cover def make_slice(*elts): # pragma: no cover
"""Slice builder.""" """Slice builder.""" # 创建切片的构造函数
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')
def make_range(*elts): # pragma: no cover def make_range(*elts): # pragma: no cover
"""Range tuple builder.""" """Range tuple builder.""" # 创建范围元组的构造函数
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')
def switch(cond, tb, fb): # pragma: no cover def switch(cond, tb, fb): # pragma: no cover
"""Switch statement, returns one of the two values.""" """Switch statement, returns one of the two values.""" # 返回两个值中的一个的开关语句
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')
def hasnext(it): # pragma: no cover def hasnext(it): # pragma: no cover
"""Hasnext function.""" """Hasnext function.""" # 判断是否有下一个元素的函数
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')
def to_array(x): def to_array(x):
"""The to_array function.""" """The to_array function.""" # 将输入转换为数组的函数
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')
def not_contains(x): # pragma: no cover def not_contains(x): # pragma: no cover
"""Not in function.""" """Not in function.""" # 判断元素是否不在集合中的函数
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')
def while_cond(x): # pragma: no cover def while_cond(x): # pragma: no cover
"""Not in function.""" """Not in function.""" # 判断条件是否成立的函数
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')
def bool_(x): # pragma: no cover def bool_(x): # pragma: no cover
"""judge true function.""" """judge true function.""" # 判断一个值是否为真值的函数
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')

@ -16,27 +16,37 @@
import os import os
from mindspore import log as logger from mindspore import log as logger
from mindspore._extends.parallel_compile.akg_compiler.akg_process import create_akg_parallel_process from mindspore._extends.parallel_compile.akg_compiler.akg_process import create_akg_parallel_process
class Messager: class Messager:
'''Messager''' '''Messager'''
def __init__(self, fdin, fdout): def __init__(self, fdin, fdout):
"""
初始化 Messager
Args:
fdin: 输入文件描述符
fdout: 输出文件描述符
"""
self.fdin = fdin self.fdin = fdin
self.fdout = fdout self.fdout = fdout
self.fin = os.fdopen(fdin, "r") self.fin = os.fdopen(fdin, "r")
self.fout = os.fdopen(fdout, "w") self.fout = os.fdopen(fdout, "w")
self.message = '' self.message = ''
def __del__(self): def __del__(self):
"""
删除 Messager 实例时关闭文件描述符
"""
os.close(self.fdin) os.close(self.fdin)
os.close(self.fdout) os.close(self.fdout)
def get_message(self): def get_message(self):
""" """
Get message from remote 从远程获取消息
Returns: Returns:
message message
""" """
@ -58,13 +68,13 @@ class Messager:
self.send_ack() self.send_ack()
self.exit() self.exit()
return self.message return self.message
def send_res(self, res, keep_format=True): def send_res(self, res, keep_format=True):
""" """
Send result to remote 发送结果到远程
Args: Args:
keep_format: True or False keep_format: True False
""" """
logger.debug(f"[OUT] {str(res)}") logger.debug(f"[OUT] {str(res)}")
if keep_format: if keep_format:
@ -72,7 +82,7 @@ class Messager:
else: else:
res_str = str(res).replace('\n', '').replace('\r', '').replace(' ', '') res_str = str(res).replace('\n', '').replace('\r', '').replace(' ', '')
tag = '[~]' # The same as client kTAG tag = '[~]' # The same as client kTAG
# Not write by print(tag + res_str, flush=True) any more # Not write by print(tag + res_str, flush=True) any more
try: try:
self.fout.write(tag + res_str + "\n") self.fout.write(tag + res_str + "\n")
@ -82,69 +92,76 @@ class Messager:
self.exit() self.exit()
finally: finally:
pass pass
def send_ack(self, success=True): def send_ack(self, success=True):
""" """
Send ack to remote 发送确认消息到远程
Args: Args:
success: True or False success: True False
""" """
if success: if success:
self.send_res('ACK') self.send_res('ACK')
else: else:
self.send_res('ERR') self.send_res('ERR')
def loop(self): def loop(self):
""" """
Messaging loop 消息循环
""" """
while True: while True:
self.handle() self.handle()
def run(self): def run(self):
"""运行消息循环"""
self.loop() self.loop()
def handle(self): def handle(self):
""" """
A interface communicates with remote. 与远程通信的接口
Note: Note:
All subclasses should override this interface. 所有子类应该重写此接口
""" """
raise NotImplementedError raise NotImplementedError
def exit(self): def exit(self):
""" """
A interface handles the procedure before exit. 处理退出之前的程序
Note: Note:
All subclasses should override this interface. 所有子类应该重写此接口
""" """
raise NotImplementedError raise NotImplementedError
class AkgBuilder(): class AkgBuilder():
"""Akg building wrapper""" """Akg building wrapper"""
def __init__(self, platform): def __init__(self, platform):
"""
初始化 AkgBuilder
Args:
platform: 平台标识
"""
self.platform = platform self.platform = platform
self.attrs = None self.attrs = None
def create(self, process_num, waitime): def create(self, process_num, waitime):
""" Create akg processor""" """ 创建 akg 处理器"""
self.akg_processor = create_akg_parallel_process(process_num, waitime, self.platform) self.akg_processor = create_akg_parallel_process(process_num, waitime, self.platform)
def accept_json(self, json): def accept_json(self, json):
""" Accept json""" """ 接受 json 数据"""
return self.akg_processor.accept_json(json) return self.akg_processor.accept_json(json)
def compile(self): def compile(self):
"""Compile""" """编译"""
return self.akg_processor.compile(self.attrs) return self.akg_processor.compile(self.attrs)
def handle(self, messager, arg): def handle(self, messager, arg):
"""Handle message about akg""" """处理关于 akg 的消息"""
if arg == 'AKG/START': if arg == 'AKG/START':
messager.send_ack() messager.send_ack()
process_num_str = messager.get_message() process_num_str = messager.get_message()
@ -172,7 +189,8 @@ class AkgBuilder():
break break
else: else:
raise RuntimeError("Unknown message type: %s" % arg) raise RuntimeError("Unknown message type: %s" % arg)
def get_logger(): def get_logger():
return logger """获取日志记录器"""
return logger

@ -20,19 +20,24 @@ from mindspore._extends.remote.kernel_build_server import Messager, get_logger,
class AkgMessager(Messager): class AkgMessager(Messager):
''' '''
Default Messager for akg kernels. 默认的 akg 内核消息处理器
It works as a server, communicating with c++ client. 它作为一个服务器 C++ 客户端进行通信
''' '''
def __init__(self, fdin, fdout): def __init__(self, fdin, fdout):
"""
初始化 AkgMessager 实例
:param fdin: 输入文件描述符
:param fdout: 输出文件描述符
"""
super().__init__(fdin, fdout) super().__init__(fdin, fdout)
get_logger().info("[TRACE] Akg Messager init...") get_logger().info("[TRACE] Akg Messager init...")
self.akg_builder = AkgBuilder("default") self.akg_builder = AkgBuilder("default")
def handle(self): def handle(self):
""" """
Communicate with remote client. 与远程客户端进行通信
Reference protocol between them at PR#4063 它们之间的参考协议见 PR#4063。
""" """
arg = self.get_message() arg = self.get_message()
if "AKG" in arg: if "AKG" in arg:
@ -42,11 +47,18 @@ class AkgMessager(Messager):
self.exit() self.exit()
def exit(self): def exit(self):
"""
退出 AkgMessager
"""
get_logger().info("[TRACE] Akg Messager Exit...") get_logger().info("[TRACE] Akg Messager Exit...")
exit() exit()
if __name__ == '__main__': if __name__ == '__main__':
"""
程序入口
检查命令行参数并初始化 AkgMessager 实例
"""
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
if len(sys.argv) != 3: if len(sys.argv) != 3:
raise Exception('Incorrect argv: {}'.format(sys.argv)) raise Exception('Incorrect argv: {}'.format(sys.argv))

@ -16,23 +16,24 @@
import sys import sys
import warnings import warnings
import json import json
from mindspore._extends.parallel_compile.tbe_compiler.tbe_job_manager import TbeJobManager from mindspore._extends.parallel_compile.tbe_compiler.tbe_job_manager import TbeJobManager
from mindspore._extends.remote.kernel_build_server import Messager, get_logger, AkgBuilder from mindspore._extends.remote.kernel_build_server import Messager, get_logger, AkgBuilder
class AscendMessager(Messager): class AscendMessager(Messager):
""" """
Ascend Messager Ascend Messager
It works as a server, communicating with c++ client. It works as a server, communicating with c++ client.
""" """
# 初始化方法
def __init__(self, fdin, fdout): def __init__(self, fdin, fdout):
super().__init__(fdin, fdout) super().__init__(fdin, fdout)
get_logger().info("[TRACE] Ascend Messager init...") get_logger().info("[TRACE] Ascend Messager init...")
self.tbe_builder = TbeJobManager() self.tbe_builder = TbeJobManager()
self.akg_builder = AkgBuilder("ASCEND") self.akg_builder = AkgBuilder("ASCEND")
# 处理与远程客户端的通信
def handle(self): def handle(self):
""" """
Communicate with remote client. Communicate with remote client.
@ -51,7 +52,7 @@ class AscendMessager(Messager):
self.exit() self.exit()
finally: finally:
pass pass
if "job_type" in job_json: if "job_type" in job_json:
res = self.tbe_builder.job_handler(arg) res = self.tbe_builder.job_handler(arg)
self.send_res(res) self.send_res(res)
@ -59,17 +60,18 @@ class AscendMessager(Messager):
get_logger().error("[TRACE] Request is not a TBE Job message: {}".format(arg)) get_logger().error("[TRACE] Request is not a TBE Job message: {}".format(arg))
self.send_ack(False) self.send_ack(False)
self.exit() self.exit()
# 退出方法
def exit(self): def exit(self):
self.tbe_builder.reset() self.tbe_builder.reset()
get_logger().info("[TRACE] Ascend Messager Exit...") get_logger().info("[TRACE] Ascend Messager Exit...")
exit() exit()
if __name__ == '__main__': if __name__ == '__main__':
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
if len(sys.argv) != 3: if len(sys.argv) != 3:
raise Exception('Incorrect argv: {}'.format(sys.argv)) raise Exception('Incorrect argv: {}'.format(sys.argv))
get_logger().debug(f"[TRACE] argv: {str(sys.argv)}") get_logger().debug(f"[TRACE] argv: {str(sys.argv)}")
messager = AscendMessager(int(sys.argv[1]), int(sys.argv[2])) messager = AscendMessager(int(sys.argv[1]), int(sys.argv[2]))
messager.run() messager.run()
Loading…
Cancel
Save