(add comments and docstrings to resources.py for clarity)

branch-yixin
yixin 7 months ago
parent eff398104d
commit c7222ba353

@ -34,7 +34,7 @@ trope_ns = CellNamespace('mindspore._extends.parse.trope')
NO_IMPLEMENT = None # not implemented NO_IMPLEMENT = None # not implemented
SYMBOL_UNDEFINE = 0xFF # Undefined var and function SYMBOL_UNDEFINE = 0xFF # Undefined var and function
# Some space set aside for readability of code # 一些空间设置以提高代码可读性
parse_object_map = { parse_object_map = {
# ast grammar # ast grammar
ast.Add: (trope_ns, 'add'), ast.Add: (trope_ns, 'add'),
@ -93,8 +93,8 @@ ops_symbol_map = {
SYMBOL_UNDEFINE: '', SYMBOL_UNDEFINE: '',
} }
# Escape an object to another object, eg: system function(len,xxx) # 将一个对象转为另一个对象,例如:系统函数(len,xxx)
# Some space set aside for readability of code # 一些空间设置以提高代码可读性
convert_object_map = { convert_object_map = {
T.add: multitype_ops.add, T.add: multitype_ops.add,
T.sub: multitype_ops.sub, T.sub: multitype_ops.sub,
@ -162,5 +162,6 @@ convert_object_map = {
CSRTensor: F.make_csr_tensor CSRTensor: F.make_csr_tensor
} }
# 如果不启用安全性,则将 T.print 映射到 F.print_
if not security.enable_security(): if not security.enable_security():
convert_object_map[T.print] = F.print_ convert_object_map[T.print] = F.print_

@ -50,55 +50,45 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt',
def MakeTuple(*elts): # pragma: no cover def MakeTuple(*elts): # pragma: no cover
"""Tuple builder.""" """Tuple builder.""" # 创建元组的构造函数
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')
def make_dict(key, value): # pragma: no cover def make_dict(key, value): # pragma: no cover
"""Dict builder.""" """Dict builder.""" # 创建字典的构造函数
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')
def make_list(*elts): # pragma: no cover def make_list(*elts): # pragma: no cover
"""List builder.""" """List builder.""" # 创建列表的构造函数
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')
def make_slice(*elts): # pragma: no cover def make_slice(*elts): # pragma: no cover
"""Slice builder.""" """Slice builder.""" # 创建切片的构造函数
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')
def make_range(*elts): # pragma: no cover def make_range(*elts): # pragma: no cover
"""Range tuple builder.""" """Range tuple builder.""" # 创建范围元组的构造函数
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')
def switch(cond, tb, fb): # pragma: no cover def switch(cond, tb, fb): # pragma: no cover
"""Switch statement, returns one of the two values.""" """Switch statement, returns one of the two values.""" # 返回两个值中的一个的开关语句
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')
def hasnext(it): # pragma: no cover def hasnext(it): # pragma: no cover
"""Hasnext function.""" """Hasnext function.""" # 判断是否有下一个元素的函数
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')
def to_array(x): def to_array(x):
"""The to_array function.""" """The to_array function.""" # 将输入转换为数组的函数
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')
def not_contains(x): # pragma: no cover def not_contains(x): # pragma: no cover
"""Not in function.""" """Not in function.""" # 判断元素是否不在集合中的函数
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')
def while_cond(x): # pragma: no cover def while_cond(x): # pragma: no cover
"""Not in function.""" """Not in function.""" # 判断条件是否成立的函数
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')
def bool_(x): # pragma: no cover def bool_(x): # pragma: no cover
"""judge true function.""" """judge true function.""" # 判断一个值是否为真值的函数
raise RuntimeError('This operation is not meant to be called directly.') raise RuntimeError('This operation is not meant to be called directly.')

@ -23,6 +23,13 @@ class Messager:
'''Messager''' '''Messager'''
def __init__(self, fdin, fdout): def __init__(self, fdin, fdout):
"""
初始化 Messager
Args:
fdin: 输入文件描述符
fdout: 输出文件描述符
"""
self.fdin = fdin self.fdin = fdin
self.fdout = fdout self.fdout = fdout
self.fin = os.fdopen(fdin, "r") self.fin = os.fdopen(fdin, "r")
@ -30,12 +37,15 @@ class Messager:
self.message = '' self.message = ''
def __del__(self): def __del__(self):
"""
删除 Messager 实例时关闭文件描述符
"""
os.close(self.fdin) os.close(self.fdin)
os.close(self.fdout) os.close(self.fdout)
def get_message(self): def get_message(self):
""" """
Get message from remote 从远程获取消息
Returns: Returns:
message message
@ -61,10 +71,10 @@ class Messager:
def send_res(self, res, keep_format=True): def send_res(self, res, keep_format=True):
""" """
Send result to remote 发送结果到远程
Args: Args:
keep_format: True or False keep_format: True False
""" """
logger.debug(f"[OUT] {str(res)}") logger.debug(f"[OUT] {str(res)}")
if keep_format: if keep_format:
@ -85,10 +95,10 @@ class Messager:
def send_ack(self, success=True): def send_ack(self, success=True):
""" """
Send ack to remote 发送确认消息到远程
Args: Args:
success: True or False success: True False
""" """
if success: if success:
self.send_res('ACK') self.send_res('ACK')
@ -97,29 +107,30 @@ class Messager:
def loop(self): def loop(self):
""" """
Messaging loop 消息循环
""" """
while True: while True:
self.handle() self.handle()
def run(self): def run(self):
"""运行消息循环"""
self.loop() self.loop()
def handle(self): def handle(self):
""" """
A interface communicates with remote. 与远程通信的接口
Note: Note:
All subclasses should override this interface. 所有子类应该重写此接口
""" """
raise NotImplementedError raise NotImplementedError
def exit(self): def exit(self):
""" """
A interface handles the procedure before exit. 处理退出之前的程序
Note: Note:
All subclasses should override this interface. 所有子类应该重写此接口
""" """
raise NotImplementedError raise NotImplementedError
@ -128,23 +139,29 @@ class AkgBuilder():
"""Akg building wrapper""" """Akg building wrapper"""
def __init__(self, platform): def __init__(self, platform):
"""
初始化 AkgBuilder
Args:
platform: 平台标识
"""
self.platform = platform self.platform = platform
self.attrs = None self.attrs = None
def create(self, process_num, waitime): def create(self, process_num, waitime):
""" Create akg processor""" """ 创建 akg 处理器"""
self.akg_processor = create_akg_parallel_process(process_num, waitime, self.platform) self.akg_processor = create_akg_parallel_process(process_num, waitime, self.platform)
def accept_json(self, json): def accept_json(self, json):
""" Accept json""" """ 接受 json 数据"""
return self.akg_processor.accept_json(json) return self.akg_processor.accept_json(json)
def compile(self): def compile(self):
"""Compile""" """编译"""
return self.akg_processor.compile(self.attrs) return self.akg_processor.compile(self.attrs)
def handle(self, messager, arg): def handle(self, messager, arg):
"""Handle message about akg""" """处理关于 akg 的消息"""
if arg == 'AKG/START': if arg == 'AKG/START':
messager.send_ack() messager.send_ack()
process_num_str = messager.get_message() process_num_str = messager.get_message()
@ -175,4 +192,5 @@ class AkgBuilder():
def get_logger(): def get_logger():
"""获取日志记录器"""
return logger return logger

@ -20,19 +20,24 @@ from mindspore._extends.remote.kernel_build_server import Messager, get_logger,
class AkgMessager(Messager): class AkgMessager(Messager):
''' '''
Default Messager for akg kernels. 默认的 akg 内核消息处理器
It works as a server, communicating with c++ client. 它作为一个服务器 C++ 客户端进行通信
''' '''
def __init__(self, fdin, fdout): def __init__(self, fdin, fdout):
"""
初始化 AkgMessager 实例
:param fdin: 输入文件描述符
:param fdout: 输出文件描述符
"""
super().__init__(fdin, fdout) super().__init__(fdin, fdout)
get_logger().info("[TRACE] Akg Messager init...") get_logger().info("[TRACE] Akg Messager init...")
self.akg_builder = AkgBuilder("default") self.akg_builder = AkgBuilder("default")
def handle(self): def handle(self):
""" """
Communicate with remote client. 与远程客户端进行通信
Reference protocol between them at PR#4063 它们之间的参考协议见 PR#4063。
""" """
arg = self.get_message() arg = self.get_message()
if "AKG" in arg: if "AKG" in arg:
@ -42,11 +47,18 @@ class AkgMessager(Messager):
self.exit() self.exit()
def exit(self): def exit(self):
"""
退出 AkgMessager
"""
get_logger().info("[TRACE] Akg Messager Exit...") get_logger().info("[TRACE] Akg Messager Exit...")
exit() exit()
if __name__ == '__main__': if __name__ == '__main__':
"""
程序入口
检查命令行参数并初始化 AkgMessager 实例
"""
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
if len(sys.argv) != 3: if len(sys.argv) != 3:
raise Exception('Incorrect argv: {}'.format(sys.argv)) raise Exception('Incorrect argv: {}'.format(sys.argv))

@ -26,13 +26,14 @@ class AscendMessager(Messager):
Ascend Messager Ascend Messager
It works as a server, communicating with c++ client. It works as a server, communicating with c++ client.
""" """
# 初始化方法
def __init__(self, fdin, fdout): def __init__(self, fdin, fdout):
super().__init__(fdin, fdout) super().__init__(fdin, fdout)
get_logger().info("[TRACE] Ascend Messager init...") get_logger().info("[TRACE] Ascend Messager init...")
self.tbe_builder = TbeJobManager() self.tbe_builder = TbeJobManager()
self.akg_builder = AkgBuilder("ASCEND") self.akg_builder = AkgBuilder("ASCEND")
# 处理与远程客户端的通信
def handle(self): def handle(self):
""" """
Communicate with remote client. Communicate with remote client.
@ -60,6 +61,7 @@ class AscendMessager(Messager):
self.send_ack(False) self.send_ack(False)
self.exit() self.exit()
# 退出方法
def exit(self): def exit(self):
self.tbe_builder.reset() self.tbe_builder.reset()
get_logger().info("[TRACE] Ascend Messager Exit...") get_logger().info("[TRACE] Ascend Messager Exit...")

Loading…
Cancel
Save