(add comments and docstrings to resources.py for clarity)

branch-yixin
yixin 7 months ago
parent eff398104d
commit c7222ba353

@ -17,7 +17,7 @@
"""Resources for ast tree parse."""
import ast
import math
from mindspore import RowTensor, SparseTensor, COOTensor, CSRTensor
from mindspore.ops import functional as F, composite as C
from mindspore.ops.composite import multitype_ops
@ -25,16 +25,16 @@ from mindspore._c_expression import security
from . import standard_method as M
from . import trope as T
from .namespace import CellNamespace
# namespace define
functional_ns = CellNamespace('mindspore.ops.functional')
composite_ns = CellNamespace('mindspore.ops.composite')
trope_ns = CellNamespace('mindspore._extends.parse.trope')
NO_IMPLEMENT = None # not implemented
SYMBOL_UNDEFINE = 0xFF # Undefined var and function
# Some space set aside for readability of code
# 一些空间设置以提高代码可读性
parse_object_map = {
# ast grammar
ast.Add: (trope_ns, 'add'),
@ -64,17 +64,17 @@ parse_object_map = {
ast.IsNot: (trope_ns, 'is_not'),
ast.In: (trope_ns, 'contains'),
ast.NotIn: (trope_ns, 'not_contains'),
# operation symbol type
'getitem': (composite_ns, 'getitem'),
'ms_iter': (composite_ns, 'ms_iter'),
'ms_next': (composite_ns, 'ms_next'),
'hasnext': (composite_ns, 'hasnext'),
# undefined type
SYMBOL_UNDEFINE: (None, 'undefine'),
}
# Operation symbols corresponding to ast grammar
ops_symbol_map = {
# ast grammar
@ -88,13 +88,13 @@ ops_symbol_map = {
ast.LShift: '<<',
ast.RShift: '>>',
ast.BitXor: '^',
# undefined type
SYMBOL_UNDEFINE: '',
}
# Escape an object to another object, eg: system function(len,xxx)
# Some space set aside for readability of code
# 将一个对象转为另一个对象,例如:系统函数(len,xxx)
# 一些空间设置以提高代码可读性
convert_object_map = {
T.add: multitype_ops.add,
T.sub: multitype_ops.sub,
@ -124,7 +124,7 @@ convert_object_map = {
T.is_not: F.is_not,
T.contains: multitype_ops.in_,
T.not_contains: multitype_ops.not_in_,
# system function
T.len: M.ms_len,
T.bool_: M.bool_,
@ -134,7 +134,7 @@ convert_object_map = {
T.zip: C.zip_operation,
T.enumerate: M.enumerate_,
T.isinstance: M.isinstance_,
# custom define operation
T.iter: M.ms_iter,
T.next: M.ms_next,
@ -145,7 +145,7 @@ convert_object_map = {
T.make_slice: F.make_slice,
T.range: F.make_range,
T.while_cond: M.while_cond,
# lib function
math.floor: NO_IMPLEMENT,
math.trunc: NO_IMPLEMENT,
@ -154,13 +154,14 @@ convert_object_map = {
math.sin: NO_IMPLEMENT,
math.cos: NO_IMPLEMENT,
math.tan: NO_IMPLEMENT,
# user defined
RowTensor: F.make_row_tensor,
SparseTensor: F.make_sparse_tensor,
COOTensor: F.make_coo_tensor,
CSRTensor: F.make_csr_tensor
}
# 如果不启用安全性,则将 T.print 映射到 F.print_
if not security.enable_security():
convert_object_map[T.print] = F.print_
convert_object_map[T.print] = F.print_

@ -50,55 +50,45 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt',
def MakeTuple(*elts): # pragma: no cover
"""Tuple builder."""
"""Tuple builder.""" # 创建元组的构造函数
raise RuntimeError('This operation is not meant to be called directly.')
def make_dict(key, value): # pragma: no cover
"""Dict builder."""
"""Dict builder.""" # 创建字典的构造函数
raise RuntimeError('This operation is not meant to be called directly.')
def make_list(*elts): # pragma: no cover
"""List builder."""
"""List builder.""" # 创建列表的构造函数
raise RuntimeError('This operation is not meant to be called directly.')
def make_slice(*elts): # pragma: no cover
"""Slice builder."""
"""Slice builder.""" # 创建切片的构造函数
raise RuntimeError('This operation is not meant to be called directly.')
def make_range(*elts): # pragma: no cover
"""Range tuple builder."""
"""Range tuple builder.""" # 创建范围元组的构造函数
raise RuntimeError('This operation is not meant to be called directly.')
def switch(cond, tb, fb): # pragma: no cover
"""Switch statement, returns one of the two values."""
"""Switch statement, returns one of the two values.""" # 返回两个值中的一个的开关语句
raise RuntimeError('This operation is not meant to be called directly.')
def hasnext(it): # pragma: no cover
"""Hasnext function."""
"""Hasnext function.""" # 判断是否有下一个元素的函数
raise RuntimeError('This operation is not meant to be called directly.')
def to_array(x):
"""The to_array function."""
"""The to_array function.""" # 将输入转换为数组的函数
raise RuntimeError('This operation is not meant to be called directly.')
def not_contains(x): # pragma: no cover
"""Not in function."""
"""Not in function.""" # 判断元素是否不在集合中的函数
raise RuntimeError('This operation is not meant to be called directly.')
def while_cond(x): # pragma: no cover
"""Not in function."""
"""Not in function.""" # 判断条件是否成立的函数
raise RuntimeError('This operation is not meant to be called directly.')
def bool_(x): # pragma: no cover
"""judge true function."""
"""judge true function.""" # 判断一个值是否为真值的函数
raise RuntimeError('This operation is not meant to be called directly.')

@ -16,27 +16,37 @@
import os
from mindspore import log as logger
from mindspore._extends.parallel_compile.akg_compiler.akg_process import create_akg_parallel_process
class Messager:
'''Messager'''
def __init__(self, fdin, fdout):
"""
初始化 Messager
Args:
fdin: 输入文件描述符
fdout: 输出文件描述符
"""
self.fdin = fdin
self.fdout = fdout
self.fin = os.fdopen(fdin, "r")
self.fout = os.fdopen(fdout, "w")
self.message = ''
def __del__(self):
"""
删除 Messager 实例时关闭文件描述符
"""
os.close(self.fdin)
os.close(self.fdout)
def get_message(self):
"""
Get message from remote
从远程获取消息
Returns:
message
"""
@ -58,13 +68,13 @@ class Messager:
self.send_ack()
self.exit()
return self.message
def send_res(self, res, keep_format=True):
"""
Send result to remote
发送结果到远程
Args:
keep_format: True or False
keep_format: True False
"""
logger.debug(f"[OUT] {str(res)}")
if keep_format:
@ -72,7 +82,7 @@ class Messager:
else:
res_str = str(res).replace('\n', '').replace('\r', '').replace(' ', '')
tag = '[~]' # The same as client kTAG
# Not write by print(tag + res_str, flush=True) any more
try:
self.fout.write(tag + res_str + "\n")
@ -82,69 +92,76 @@ class Messager:
self.exit()
finally:
pass
def send_ack(self, success=True):
"""
Send ack to remote
发送确认消息到远程
Args:
success: True or False
success: True False
"""
if success:
self.send_res('ACK')
else:
self.send_res('ERR')
def loop(self):
"""
Messaging loop
消息循环
"""
while True:
self.handle()
def run(self):
"""运行消息循环"""
self.loop()
def handle(self):
"""
A interface communicates with remote.
与远程通信的接口
Note:
All subclasses should override this interface.
所有子类应该重写此接口
"""
raise NotImplementedError
def exit(self):
"""
A interface handles the procedure before exit.
处理退出之前的程序
Note:
All subclasses should override this interface.
所有子类应该重写此接口
"""
raise NotImplementedError
class AkgBuilder():
"""Akg building wrapper"""
def __init__(self, platform):
"""
初始化 AkgBuilder
Args:
platform: 平台标识
"""
self.platform = platform
self.attrs = None
def create(self, process_num, waitime):
""" Create akg processor"""
""" 创建 akg 处理器"""
self.akg_processor = create_akg_parallel_process(process_num, waitime, self.platform)
def accept_json(self, json):
""" Accept json"""
""" 接受 json 数据"""
return self.akg_processor.accept_json(json)
def compile(self):
"""Compile"""
"""编译"""
return self.akg_processor.compile(self.attrs)
def handle(self, messager, arg):
"""Handle message about akg"""
"""处理关于 akg 的消息"""
if arg == 'AKG/START':
messager.send_ack()
process_num_str = messager.get_message()
@ -172,7 +189,8 @@ class AkgBuilder():
break
else:
raise RuntimeError("Unknown message type: %s" % arg)
def get_logger():
return logger
"""获取日志记录器"""
return logger

@ -20,19 +20,24 @@ from mindspore._extends.remote.kernel_build_server import Messager, get_logger,
class AkgMessager(Messager):
'''
Default Messager for akg kernels.
It works as a server, communicating with c++ client.
默认的 akg 内核消息处理器
它作为一个服务器 C++ 客户端进行通信
'''
def __init__(self, fdin, fdout):
"""
初始化 AkgMessager 实例
:param fdin: 输入文件描述符
:param fdout: 输出文件描述符
"""
super().__init__(fdin, fdout)
get_logger().info("[TRACE] Akg Messager init...")
self.akg_builder = AkgBuilder("default")
def handle(self):
"""
Communicate with remote client.
Reference protocol between them at PR#4063
与远程客户端进行通信
它们之间的参考协议见 PR#4063。
"""
arg = self.get_message()
if "AKG" in arg:
@ -42,11 +47,18 @@ class AkgMessager(Messager):
self.exit()
def exit(self):
"""
退出 AkgMessager
"""
get_logger().info("[TRACE] Akg Messager Exit...")
exit()
if __name__ == '__main__':
"""
程序入口
检查命令行参数并初始化 AkgMessager 实例
"""
warnings.simplefilter("ignore")
if len(sys.argv) != 3:
raise Exception('Incorrect argv: {}'.format(sys.argv))

@ -16,23 +16,24 @@
import sys
import warnings
import json
from mindspore._extends.parallel_compile.tbe_compiler.tbe_job_manager import TbeJobManager
from mindspore._extends.remote.kernel_build_server import Messager, get_logger, AkgBuilder
class AscendMessager(Messager):
"""
Ascend Messager
It works as a server, communicating with c++ client.
"""
# 初始化方法
def __init__(self, fdin, fdout):
super().__init__(fdin, fdout)
get_logger().info("[TRACE] Ascend Messager init...")
self.tbe_builder = TbeJobManager()
self.akg_builder = AkgBuilder("ASCEND")
# 处理与远程客户端的通信
def handle(self):
"""
Communicate with remote client.
@ -51,7 +52,7 @@ class AscendMessager(Messager):
self.exit()
finally:
pass
if "job_type" in job_json:
res = self.tbe_builder.job_handler(arg)
self.send_res(res)
@ -59,17 +60,18 @@ class AscendMessager(Messager):
get_logger().error("[TRACE] Request is not a TBE Job message: {}".format(arg))
self.send_ack(False)
self.exit()
# 退出方法
def exit(self):
self.tbe_builder.reset()
get_logger().info("[TRACE] Ascend Messager Exit...")
exit()
if __name__ == '__main__':
warnings.simplefilter("ignore")
if len(sys.argv) != 3:
raise Exception('Incorrect argv: {}'.format(sys.argv))
get_logger().debug(f"[TRACE] argv: {str(sys.argv)}")
messager = AscendMessager(int(sys.argv[1]), int(sys.argv[2]))
messager.run()
messager.run()
Loading…
Cancel
Save