Compare commits

...

31 Commits

Author SHA1 Message Date
donghaoqian 8ff195e3f3 泛读报告格式调整后再次提交
2 months ago
donghaoqian 40500495da 提交修改后的泛读报告
2 months ago
donghaoqian 1319480deb 删除无关文件
2 months ago
donghaoqian 79516a3321 Merge branch 'main' of https://bdgit.educoder.net/pstluih63/mindspore_group_2
2 months ago
donghaoqian eeea480af9 更新泛读报告
2 months ago
pupb4l8rm ab1f7b5097 Merge pull request 'ruiqin' (#6) from Branch_ruiqin into main
2 months ago
zhang 759ad9d09c communication_ruiqin2
2 months ago
liuwenhao bdf240c986 Merge branch 'main' of https://bdgit.educoder.net/pstluih63/mindspore_group_2 into branch-lwh
2 months ago
pupb4l8rm b12adc5753 Merge pull request 'ruiqin' (#5) from Branch_ruiqin into main
2 months ago
liuwenhao ae1d86521d Merge branch 'main' of https://bdgit.educoder.net/pstluih63/mindspore_group_2 into branch-lwh
2 months ago
zhang 542726fc1a run_check,ruiqin
2 months ago
zhang ddf9ac7477 graph_utils,ruiqin
2 months ago
zhang 61450586a7 communication,ruiqin
2 months ago
pstluih63 6d63bbbd8c Merge pull request '提交mindspore泛读报告' (#4) from branch-donghaoqian into main
2 months ago
donghaoqian b47832ea8c 提交mindspore泛读报告
2 months ago
pbp2cnvu4 93f0af2f73 Merge pull request 'merge zwt' (#3) from branch_zwt into main
2 months ago
zouwentao 3668e69ed5 注释
2 months ago
zouwentao a672592d99 comment
2 months ago
donghaoqian 74aff73449 Merge branch 'main' of https://bdgit.educoder.net/pstluih63/mindspore_group_2 into branch-donghaoqian
2 months ago
donghaoqian db2c8b9de0 董昊千代码注读
2 months ago
pptw92c8a c8a0ad4d29 Merge pull request 'xiangguo' (#2) from branch-xg into main
2 months ago
donghaoqian 69ae5ccbe1 董昊千注读/nn/loss/loss.py
2 months ago
xiangguo 7fa1f5ab19 branch-xg
2 months ago
xiangguo ea602b4c9f branch-xg
2 months ago
liuwenhao 5880e5be32 Merge branch 'main' of https://bdgit.educoder.net/pstluih63/mindspore_group_2 into branch-lwh
2 months ago
pc8nkf3eg 0c93bc5640 Merge pull request '_extend文件夹泛读注释' (#1) from branch-yixin into main
2 months ago
xiangguo 6201dc959e add comment
2 months ago
xiangguo bc8eeb6dd1 add comment setup.py
2 months ago
zhang 4e49598ac9 ruiqn
2 months ago
donghaoqian ade7c451d5 创建branch-donghaoqian
2 months ago
liuwenhao 65ca9afacc 注释
2 months ago

@ -1,2 +0,0 @@
# mindspore_group_2

@ -20,8 +20,11 @@ function(find_submodule_lib module name path)
) )
endfunction() endfunction()
# protobuf
function(ge_protobuf_generate c_var h_var) function(ge_protobuf_generate c_var h_var)
# common_protobuf_generateprotobuf
common_protobuf_generate(${CMAKE_BINARY_DIR}/proto/ge/proto ${c_var} ${h_var} ${ARGN}) common_protobuf_generate(${CMAKE_BINARY_DIR}/proto/ge/proto ${c_var} ${h_var} ${ARGN})
# chc_varh_var
set(${c_var} ${${c_var}} PARENT_SCOPE) set(${c_var} ${${c_var}} PARENT_SCOPE)
set(${h_var} ${${h_var}} PARENT_SCOPE) set(${h_var} ${${h_var}} PARENT_SCOPE)
endfunction() endfunction()

@ -19,23 +19,30 @@ from mindspore import log as logger
from ._hccl_management import load_lib as hccl_load_lib from ._hccl_management import load_lib as hccl_load_lib
from .._c_expression import get_rank_id, get_rank_size from .._c_expression import get_rank_id, get_rank_size
# 检查HCCL是否可用
_HCCL_AVAILABLE = False _HCCL_AVAILABLE = False
# 检查HCCL测试是否可用
_HCCL_TEST_AVAILABLE = False _HCCL_TEST_AVAILABLE = False
# 检查NCCL是否可用
_NCCL_AVAILABLE = False _NCCL_AVAILABLE = False
# 检查MPI是否可用
_MPI_AVAILABLE = False _MPI_AVAILABLE = False
try: try:
# 尝试导入mindspore._ms_mpi如果成功则NCCL可用
import mindspore._ms_mpi as mpi import mindspore._ms_mpi as mpi
_NCCL_AVAILABLE = True _NCCL_AVAILABLE = True
except ImportError: except ImportError:
# 如果导入失败则NCCL不可用
_NCCL_AVAILABLE = False _NCCL_AVAILABLE = False
# 尝试加载 HCCL 库,如果成功则设置 _HCCL_AVAILABLE 为 True否则捕获 RuntimeError 并设置为 False
try: try:
hccl_load_lib() hccl_load_lib()
_HCCL_AVAILABLE = True _HCCL_AVAILABLE = True
except RuntimeError: except RuntimeError:
_HCCL_AVAILABLE = False _HCCL_AVAILABLE = False
# 如果 HCCL 可用,则导入 _hccl_management 并尝试导入 mindspore._ascend_mpi如果成功则设置 _MPI_AVAILABLE 为 True否则捕获 ImportError 并设置为 False
if _HCCL_AVAILABLE: if _HCCL_AVAILABLE:
from . import _hccl_management as hccl from . import _hccl_management as hccl
try: try:
@ -43,6 +50,7 @@ if _HCCL_AVAILABLE:
_MPI_AVAILABLE = True _MPI_AVAILABLE = True
except ImportError: except ImportError:
_MPI_AVAILABLE = False _MPI_AVAILABLE = False
# 如果 HCCL 不可用,则尝试导入 hccl_test.manage.api如果成功则设置 _HCCL_AVAILABLE 和 _HCCL_TEST_AVAILABLE 为 True否则捕获 ImportError 并设置 _HCCL_AVAILABLE 为 False
else: else:
try: try:
import hccl_test.manage.api as hccl import hccl_test.manage.api as hccl
@ -51,11 +59,10 @@ else:
except ImportError: except ImportError:
_HCCL_AVAILABLE = False _HCCL_AVAILABLE = False
# 定义 HCCL 和 NCCL 的通信组名称常量
HCCL_WORLD_COMM_GROUP = "hccl_world_group" HCCL_WORLD_COMM_GROUP = "hccl_world_group"
NCCL_WORLD_COMM_GROUP = "nccl_world_group" NCCL_WORLD_COMM_GROUP = "nccl_world_group"
class Backend: class Backend:
""" """
Class for available backends. Class for available backends.
@ -82,13 +89,17 @@ class Backend:
def __new__(cls, name): def __new__(cls, name):
"""Create instance object of Backend.""" """Create instance object of Backend."""
# 检查传入的name是否为字符串类型
if not isinstance(name, str): if not isinstance(name, str):
raise TypeError("For 'Backend', the class variable 'name' must be a string, " raise TypeError("For 'Backend', the class variable 'name' must be a string, "
"but got the type : {}".format(type(name))) "but got the type : {}".format(type(name)))
# 获取对应name的大写形式的Backend类属性值如果不存在则返回Backend.UNDEFINED
value = getattr(Backend, name.upper(), Backend.UNDEFINED) value = getattr(Backend, name.upper(), Backend.UNDEFINED)
# 如果获取到的值是Backend.UNDEFINED说明传入的name不被支持
if value == Backend.UNDEFINED: if value == Backend.UNDEFINED:
raise ValueError("For 'Backend', the class variable 'name' {} is not supported, " raise ValueError("For 'Backend', the class variable 'name' {} is not supported, "
"please use hccl or nccl.".format(name)) "please use hccl or nccl.".format(name))
# 返回获取到的Backend类属性值
return value return value
DEFAULT_BACKEND = Backend("hccl") DEFAULT_BACKEND = Backend("hccl")
@ -103,36 +114,35 @@ class GlobalComm:
""" """
BACKEND = DEFAULT_BACKEND BACKEND = DEFAULT_BACKEND
WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
INITED = False INITED = False # 标记全局通信是否已初始化
CHECK_ENVS = True CHECK_ENVS = True # 标记是否需要检查通信环境变量
class _ExistingGroup: class _ExistingGroup:
""" """
The communication groups which exist in the progress. 用于表示在程序运行过程中存在的通信组
""" """
ITEMS = {} ITEMS = {} # 存储通信组的字典,键为通信组的标识符,值为通信组对象
def is_hccl_available(): def is_hccl_available():
""" """
Check HCCL api is available. 检查HCCL API是否可用
Returns: Returns:
Boolean. Return whether HCCL is available or not. Boolean: 返回HCCL是否可用
""" """
return _HCCL_AVAILABLE return _HCCL_AVAILABLE # 返回一个布尔值指示HCCL是否可用
def is_mpi_available(): def is_mpi_available():
""" """
Check HCCL & MPI api is available. 检查HCCL和MPI API是否可用
Returns: Returns:
Boolean. Return whether HCCL & MPI is available or not. Boolean: 返回HCCL和MPI是否同时可用
""" """
return _MPI_AVAILABLE return _MPI_AVAILABLE # 返回一个布尔值指示HCCL和MPI是否同时可用
def is_nccl_available(): def is_nccl_available():
""" """
@ -158,19 +168,25 @@ def check_parameter_available(func):
Wrapper. If not available, raise Error. Wrapper. If not available, raise Error.
""" """
def wrapper(*args, **kargs): def wrapper(*args, **kargs):
# 如果当前角色是参数服务器或者调度器,直接调用被装饰的函数
if _is_role_pserver() or _is_role_sched(): if _is_role_pserver() or _is_role_sched():
return func(*args, **kargs) return func(*args, **kargs)
# 检查分布式通信是否已经初始化,未初始化则抛出异常
if not GlobalComm.INITED: if not GlobalComm.INITED:
raise RuntimeError("Distributed Communication has not been inited") raise RuntimeError("Distributed Communication has not been inited")
# 获取参数组默认为None
group = None group = None
# 检查关键字参数中是否包含"group",并进行类型检查
if "group" in kargs.keys(): if "group" in kargs.keys():
group = kargs.get("group") group = kargs.get("group")
if group is not None and not isinstance(group, str): if group is not None and not isinstance(group, str):
raise TypeError("The parameter 'group' should be str or None, " raise TypeError("The parameter 'group' should be str or None, "
"but got the type : {}".format(type(group))) "but got the type : {}".format(type(group)))
# 获取后端默认为None
if "backend" in kargs.keys(): if "backend" in kargs.keys():
backend = kargs.get("backend") backend = kargs.get("backend")
# 检查后端是否可用,不可用则抛出异常
if backend is Backend.HCCL and not is_hccl_available(): if backend is Backend.HCCL and not is_hccl_available():
raise RuntimeError("Distributed Communication doesn't have HCCL built in") raise RuntimeError("Distributed Communication doesn't have HCCL built in")
if backend is Backend.HCCL_MPI and not is_mpi_available(): if backend is Backend.HCCL_MPI and not is_mpi_available():
@ -178,15 +194,16 @@ def check_parameter_available(func):
if backend is Backend.NCCL and not is_nccl_available(): if backend is Backend.NCCL and not is_nccl_available():
raise RuntimeError("Distributed Communication doesn't have NCCL built in") raise RuntimeError("Distributed Communication doesn't have NCCL built in")
# 如果未指定group根据backend设置默认的group
if group is None: if group is None:
if backend is Backend.HCCL or Backend.HCCL_MPI: if backend is Backend.HCCL or backend is Backend.HCCL_MPI:
group = HCCL_WORLD_COMM_GROUP group = HCCL_WORLD_COMM_GROUP
elif backend is Backend.NCCL: elif backend is Backend.NCCL:
group = NCCL_WORLD_COMM_GROUP group = NCCL_WORLD_COMM_GROUP
# 调用被装饰的函数
return func(*args, **kargs) return func(*args, **kargs)
return wrapper return wrapper
@check_parameter_available @check_parameter_available
def _get_rank_helper(group, backend): def _get_rank_helper(group, backend):
""" """
@ -202,10 +219,13 @@ def _get_rank_helper(group, backend):
Returns: Returns:
Integer. The local rank id of the calling process. Integer. The local rank id of the calling process.
""" """
# 辅助函数,用于根据不同的后端和组获取 rank_id
# 获取当前角色的rank_id如果是参数服务器或调度器角色rank_id设为0并返回
rank_id = None rank_id = None
if _is_role_pserver() or _is_role_sched(): if _is_role_pserver() or _is_role_sched():
rank_id = 0 rank_id = 0
return rank_id return rank_id
# 根据不同的后端获取rank_id
if backend == Backend.HCCL_MPI: if backend == Backend.HCCL_MPI:
rank_id = mpi.get_rank_id(group) rank_id = mpi.get_rank_id(group)
elif backend == Backend.HCCL: elif backend == Backend.HCCL:
@ -216,6 +236,7 @@ def _get_rank_helper(group, backend):
elif backend == Backend.NCCL: elif backend == Backend.NCCL:
rank_id = get_rank_id(group) rank_id = get_rank_id(group)
else: else:
# 如果后端不被支持抛出ValueError异常
raise ValueError("For '_get_rank_helper', the argument 'backend' {} is not supported, " raise ValueError("For '_get_rank_helper', the argument 'backend' {} is not supported, "
"please use hccl_mpi, hccl or nccl.".format(backend)) "please use hccl_mpi, hccl or nccl.".format(backend))
return rank_id return rank_id
@ -236,22 +257,30 @@ def _get_local_rank_helper(group, backend):
Returns: Returns:
Integer. The local rank id of the calling process. Integer. The local rank id of the calling process.
""" """
# 获取当前进程的rank id根据不同的后端和组进行处理
rank_id = None rank_id = None
# 根据不同的后端选择获取rank_id的方法
if backend == Backend.HCCL_MPI: if backend == Backend.HCCL_MPI:
# 使用HCCL MPI后端时通过mpi.get_rank_id获取rank_id
rank_id = mpi.get_rank_id(group) rank_id = mpi.get_rank_id(group)
elif backend == Backend.HCCL: elif backend == Backend.HCCL:
# 使用HCCL后端时根据group的不同选择获取rank_id的方法
if group == HCCL_WORLD_COMM_GROUP: if group == HCCL_WORLD_COMM_GROUP:
# 如果group是HCCL_WORLD_COMM_GROUP则使用hccl.get_local_rank_id获取rank_id
rank_id = hccl.get_local_rank_id() rank_id = hccl.get_local_rank_id()
else: else:
# 对于其他group同样使用hccl.get_local_rank_id获取rank_id
rank_id = hccl.get_local_rank_id(group) rank_id = hccl.get_local_rank_id(group)
elif backend == Backend.NCCL: elif backend == Backend.NCCL:
# 如果使用NCCL后端当前不支持get_local_rank_id方法抛出异常
raise RuntimeError("Nccl doesn't support get_local_rank_id now.") raise RuntimeError("Nccl doesn't support get_local_rank_id now.")
else: else:
# 如果backend既不是HCCL_MPI也不是HCCL抛出异常表示不支持的backend
raise ValueError("For '_get_local_rank_helper', the argument 'backend' {} is not supported, " raise ValueError("For '_get_local_rank_helper', the argument 'backend' {} is not supported, "
"please use hccl_mpi or hccl.".format(backend)) "please use hccl_mpi or hccl.".format(backend))
# 返回获取到的rank_id
return rank_id return rank_id
@check_parameter_available @check_parameter_available
def _get_size_helper(group, backend): def _get_size_helper(group, backend):
""" """
@ -268,9 +297,12 @@ def _get_size_helper(group, backend):
Integer. The rank size of specified group. Integer. The rank size of specified group.
""" """
size = None size = None
# 如果当前角色是参数服务器或调度器则将size设为1并返回
if _is_role_pserver() or _is_role_sched(): if _is_role_pserver() or _is_role_sched():
size = 1 size = 1
return size return size
# 根据不同的后端设置size的值
# 根据不同的后端获取组的大小
if backend == Backend.HCCL_MPI: if backend == Backend.HCCL_MPI:
size = mpi.get_rank_size(group) size = mpi.get_rank_size(group)
elif backend == Backend.HCCL: elif backend == Backend.HCCL:
@ -302,19 +334,23 @@ def _get_local_size_helper(group, backend):
Integer. The local rank size where the calling process is being within specified group. Integer. The local rank size where the calling process is being within specified group.
""" """
size = None size = None
# 根据不同的后端获取本地进程组的大小
if backend == Backend.HCCL: if backend == Backend.HCCL:
# 如果组是全局通信组,则获取全局通信组的本地排名大小
if group == HCCL_WORLD_COMM_GROUP: if group == HCCL_WORLD_COMM_GROUP:
size = hccl.get_local_rank_size() size = hccl.get_local_rank_size()
# 否则获取指定组的本地排名大小
else: else:
size = hccl.get_local_rank_size(group) size = hccl.get_local_rank_size(group)
# NCCL后端不支持获取本地排名大小抛出异常
elif backend == Backend.NCCL: elif backend == Backend.NCCL:
raise RuntimeError("Nccl doesn't support get_local_rank_size now.") raise RuntimeError("Nccl doesn't support get_local_rank_size now.")
# 对于不支持的后端,抛出异常
else: else:
raise ValueError("For '_get_local_size_helper', the argument 'backend' {} is not supported, " raise ValueError("For '_get_local_size_helper', the argument 'backend' {} is not supported, "
"please use hccl.".format(backend)) "please use hccl.".format(backend))
return size return size
@check_parameter_available @check_parameter_available
def _get_world_rank_from_group_rank_helper(group, group_rank_id, backend): def _get_world_rank_from_group_rank_helper(group, group_rank_id, backend):
""" """
@ -333,21 +369,26 @@ def _get_world_rank_from_group_rank_helper(group, group_rank_id, backend):
Integer. A rank id in world communication group. Integer. A rank id in world communication group.
""" """
world_rank_id = None world_rank_id = None
# 检查 group_rank_id 是否为整数类型,如果不是则抛出 TypeError
if not isinstance(group_rank_id, int): if not isinstance(group_rank_id, int):
raise TypeError("For 'get_world_rank_from_group_rank', the argument 'group_rank_id' must be" raise TypeError("For 'get_world_rank_from_group_rank', the argument 'group_rank_id' must be"
" type of int, but got 'group_rank_id' type : {}.".format(type(group_rank_id))) " type of int, but got 'group_rank_id' type : {}.".format(type(group_rank_id)))
# 根据不同的后端选择不同的逻辑处理方式
if backend == Backend.HCCL: if backend == Backend.HCCL:
# 如果在 GPU 上使用 HCCL但 group 参数为 HCCL_WORLD_COMM_GROUP则抛出 ValueError
if group == HCCL_WORLD_COMM_GROUP: if group == HCCL_WORLD_COMM_GROUP:
raise ValueError("For 'get_world_rank_from_group_rank' on GPU, the argument 'group' " raise ValueError("For 'get_world_rank_from_group_rank' on GPU, the argument 'group' "
"should be 'NCCL_WORLD_COMM_GROUP', but got 'HCCL_WORLD_COMM_GROUP'.") "should be 'NCCL_WORLD_COMM_GROUP', but got 'HCCL_WORLD_COMM_GROUP'.")
# 调用 hccl.get_world_rank_from_group_rank 方法获取 world_rank_id
world_rank_id = hccl.get_world_rank_from_group_rank(group, group_rank_id) world_rank_id = hccl.get_world_rank_from_group_rank(group, group_rank_id)
elif backend == Backend.NCCL: elif backend == Backend.NCCL:
# 如果使用 NCCL则抛出 RuntimeError 表示不支持该操作
raise RuntimeError("Nccl doesn't support get_world_rank_from_group_rank now.") raise RuntimeError("Nccl doesn't support get_world_rank_from_group_rank now.")
else: else:
# 如果 backend 参数不支持,则抛出 ValueError 表示不支持该后端
raise ValueError("The argument 'backend' {} is not supported, please use hccl.".format(backend)) raise ValueError("The argument 'backend' {} is not supported, please use hccl.".format(backend))
# 返回获取的 world_rank_id
return world_rank_id return world_rank_id
@check_parameter_available @check_parameter_available
def _get_group_rank_from_world_rank_helper(world_rank_id, group, backend): def _get_group_rank_from_world_rank_helper(world_rank_id, group, backend):
""" """
@ -366,21 +407,27 @@ def _get_group_rank_from_world_rank_helper(world_rank_id, group, backend):
Integer. A rank id in user communication group. Integer. A rank id in user communication group.
""" """
group_rank_id = None group_rank_id = None
# 检查 world_rank_id 是否为整数类型,如果不是则抛出 TypeError
if not isinstance(world_rank_id, int): if not isinstance(world_rank_id, int):
raise TypeError("For 'get_group_rank_from_world_rank', the argument 'world_rank_id' must be type of int, " raise TypeError("For 'get_group_rank_from_world_rank', the argument 'world_rank_id' must be type of int, "
"but got 'world_rank_id' type : {}.".format(type(world_rank_id))) "but got 'world_rank_id' type : {}.".format(type(world_rank_id)))
# 根据不同的后端处理获取 group_rank_id 的逻辑
if backend == Backend.HCCL: if backend == Backend.HCCL:
# 检查 GPU 后端的 group 参数是否正确,如果不正确则抛出 ValueError
if group == HCCL_WORLD_COMM_GROUP: if group == HCCL_WORLD_COMM_GROUP:
raise ValueError("For 'get_group_rank_from_world_rank' on GPU, the argument 'group' " raise ValueError("For 'get_group_rank_from_world_rank' on GPU, the argument 'group' "
"should be 'NCCL_WORLD_COMM_GROUP', but got 'HCCL_WORLD_COMM_GROUP'.") "should be 'NCCL_WORLD_COMM_GROUP', but got 'HCCL_WORLD_COMM_GROUP'.")
# 调用 hccl 模块的函数获取 group_rank_id
group_rank_id = hccl.get_group_rank_from_world_rank(world_rank_id, group) group_rank_id = hccl.get_group_rank_from_world_rank(world_rank_id, group)
# NCCL 后端不支持此操作,抛出 RuntimeError
elif backend == Backend.NCCL: elif backend == Backend.NCCL:
raise RuntimeError("Nccl doesn't support get_group_rank_from_world_rank now.") raise RuntimeError("Nccl doesn't support get_group_rank_from_world_rank now.")
# 如果后端不是 HCCL 或 NCCL则抛出 ValueError 表示不支持的后端
else: else:
raise ValueError("The argument 'backend' {} is not supported, please use hccl.".format(backend)) raise ValueError("The argument 'backend' {} is not supported, please use hccl.".format(backend))
# 返回获取到的 group_rank_id
return group_rank_id return group_rank_id
@check_parameter_available @check_parameter_available
def _create_group_helper(group, rank_ids, backend): def _create_group_helper(group, rank_ids, backend):
""" """
@ -395,34 +442,46 @@ def _create_group_helper(group, rank_ids, backend):
TypeError: If rank_ids is not a list. TypeError: If rank_ids is not a list.
ValueError: If rank_ids size is not larger than 1 or rank_ids has duplicate data or backend is invalid. ValueError: If rank_ids size is not larger than 1 or rank_ids has duplicate data or backend is invalid.
""" """
# 检查组是否已存在
if group in _ExistingGroup.ITEMS.keys(): if group in _ExistingGroup.ITEMS.keys():
# 如果组已存在且提供的rank_ids与存储的不一致抛出异常
if rank_ids != _ExistingGroup.ITEMS[group]: if rank_ids != _ExistingGroup.ITEMS[group]:
raise ValueError("The group {} has been created, the rank_list is {}, " raise ValueError("The group {} has been created, the rank_list is {}, "
"but current rank_list for the group is {}". "but current rank_list for the group is {}".
format(group, _ExistingGroup.ITEMS[group], rank_ids)) format(group, _ExistingGroup.ITEMS[group], rank_ids))
# 记录警告信息,提示组已存在
logger.warning("%r group has existed.", group) logger.warning("%r group has existed.", group)
return return
# 根据不同的后端创建组
if backend == Backend.HCCL: if backend == Backend.HCCL:
# 检查rank_ids是否为列表类型
if not isinstance(rank_ids, list): if not isinstance(rank_ids, list):
raise TypeError("For 'create_group', the argument 'rank_ids' must be type of list, " raise TypeError("For 'create_group', the argument 'rank_ids' must be type of list, "
"but got 'rank_ids' type : {}.".format(type(rank_ids))) "but got 'rank_ids' type : {}.".format(type(rank_ids)))
# 检查rank_ids的长度是否大于1
rank_size = len(rank_ids) rank_size = len(rank_ids)
if rank_size < 1: if rank_size < 1:
raise ValueError("For 'create_group', the argument 'rank_ids' size should be greater than 1, " raise ValueError("For 'create_group', the argument 'rank_ids' size should be greater than 1, "
"but got 'rank_ids' size : {}.".format(len(rank_ids))) "but got 'rank_ids' size : {}.".format(len(rank_ids)))
# 检查rank_ids中是否有重复的元素
if len(rank_ids) - len(list(set(rank_ids))) > 0: if len(rank_ids) - len(list(set(rank_ids))) > 0:
raise ValueError("List rank_ids in Group {} has duplicate data!".format(group)) raise ValueError("List rank_ids in Group {} has duplicate data!".format(group))
# 使用HCCL创建组
hccl.create_group(group, rank_size, rank_ids) hccl.create_group(group, rank_size, rank_ids)
elif backend == Backend.HCCL_MPI: elif backend == Backend.HCCL_MPI:
# 使用HCCL_MPI创建组
mpi.create_group(group, rank_ids) mpi.create_group(group, rank_ids)
elif backend == Backend.NCCL: elif backend == Backend.NCCL:
# NCCL暂不支持创建组抛出异常
raise RuntimeError("Nccl doesn't support create_group now.") raise RuntimeError("Nccl doesn't support create_group now.")
else: else:
# 如果后端不支持,抛出异常
raise ValueError("The context configuration parameter 'backend' {} is not supported, " raise ValueError("The context configuration parameter 'backend' {} is not supported, "
"please use hccl.".format(backend)) "please use hccl.".format(backend))
_ExistingGroup.ITEMS[group] = rank_ids
# 将新创建的组及其rank_ids添加到_existingGroup中
_ExistingGroup.ITEMS[group] = rank_ids
@check_parameter_available @check_parameter_available
def _destroy_group_helper(group, backend): def _destroy_group_helper(group, backend):
""" """
@ -435,12 +494,17 @@ def _destroy_group_helper(group, backend):
Raises: Raises:
ValueError: If group is "hccl_world_group" or backend is invalid. ValueError: If group is "hccl_world_group" or backend is invalid.
""" """
# 根据后端类型销毁通信组
if backend == Backend.HCCL: if backend == Backend.HCCL:
# 检查是否为 HCCL 的全局通信组
if group == HCCL_WORLD_COMM_GROUP: if group == HCCL_WORLD_COMM_GROUP:
raise ValueError("The hccl_world_group does not support destruction.") raise ValueError("The hccl_world_group does not support destruction.")
# 销毁指定的 HCCL 通信组
hccl.destroy_group(group) hccl.destroy_group(group)
elif backend == Backend.NCCL: elif backend == Backend.NCCL:
# 当前 NCCL 后端不支持销毁通信组
raise RuntimeError("Nccl doesn't support destroy_group now.") raise RuntimeError("Nccl doesn't support destroy_group now.")
else: else:
# 抛出错误,表示不支持的后端类型
raise ValueError("The context configuration parameter 'backend' {} is not supported, " raise ValueError("The context configuration parameter 'backend' {} is not supported, "
"please use hccl.".format(backend)) "please use hccl.".format(backend))

@ -24,7 +24,7 @@ MAX_RANK_NUM = 4096
HCCL_LIB = 'libhccl_plugin.so' HCCL_LIB = 'libhccl_plugin.so'
HCCL_LIB_CTYPES = "" HCCL_LIB_CTYPES = ""
# 检查集体通信组的名称是否合法
def check_group(group): def check_group(group):
""" """
A function that check if a collection communication group is legal. A function that check if a collection communication group is legal.
@ -41,23 +41,31 @@ def check_group(group):
raise TypeError("The type of communication group name must be type of string, " raise TypeError("The type of communication group name must be type of string, "
"but got 'group' type : {}.".format(type(group))) "but got 'group' type : {}.".format(type(group)))
# 检查集体通信中的排名编号是否合法
def check_rank_num(rank_num): def check_rank_num(rank_num):
""" """
A function that check if a collection communication rank number is legal.If not raise error. 检查通信集合中的排名编号是否合法如果不合法则抛出错误
参数:
rank_num: 需要检查的排名编号预期为整数类型
Returns: Returns:
None None: 该函数不返回任何值但可能会抛出异常
""" """
if isinstance(rank_num, (int)): # 检查 rank_num 是否为整数类型
if isinstance(rank_num, int):
# 检查 rank_num 是否在合法范围内大于0且小于等于 MAX_RANK_NUM
if rank_num > MAX_RANK_NUM or rank_num <= 0: if rank_num > MAX_RANK_NUM or rank_num <= 0:
raise ValueError("For 'create_group', the size of argument 'rand_ids' should be greater than 0 and" # 如果 rank_num 不在合法范围内,抛出 ValueError 异常,并提供详细的错误信息
"less than {}, but got the size of 'rank_ids' : {}.".format(MAX_RANK_NUM, rank_num)) raise ValueError("对于 'create_group' 函数,参数 'rank_ids' 的大小必须大于0且"
"小于等于 {},但得到的 'rank_ids' 的大小为: {}".format(MAX_RANK_NUM, rank_num))
else: else:
raise TypeError("The argument 'rank_num' must be type of int, " # 如果 rank_num 不是整数类型,抛出 TypeError 异常,并提供详细的错误信息
"but got 'rank_num' type : {}.".format(type(rank_num))) raise TypeError("参数 'rank_num' 必须为整数类型,"
"但得到的 'rank_num' 类型为: {}".format(type(rank_num)))
#检查集体通信中的排名标识(rank id)是否合法
def check_rank_id(rank_id): def check_rank_id(rank_id):
""" """
A function that check if a collection communication rank id is legal.If not raise error. A function that check if a collection communication rank id is legal.If not raise error.
@ -65,30 +73,38 @@ def check_rank_id(rank_id):
Returns: Returns:
None None
""" """
# 检查rank_id是否为整数类型
if isinstance(rank_id, (int)): if isinstance(rank_id, (int)):
# 检查rank_id是否在有效范围内
if rank_id >= MAX_RANK_NUM or rank_id < 0: if rank_id >= MAX_RANK_NUM or rank_id < 0:
raise ValueError("The rand id in the communication group must be greater or equal 0 and " raise ValueError("The rand id in the communication group must be greater or equal 0 and "
"less than {}, but got type value : {}.".format(MAX_RANK_NUM, rank_id)) "less than {}, but got type value : {}.".format(MAX_RANK_NUM, rank_id))
else: else:
# 如果rank_id不是整数类型抛出类型错误
raise TypeError("The rand id in the communication group must be must be type of int, " raise TypeError("The rand id in the communication group must be must be type of int, "
"but got type value : {}.".format(type(rank_id))) "but got type value : {}.".format(type(rank_id)))
# 加载 HCCLHuawei Cloud Communication Library
def load_lib(): def load_lib():
"""load hccl lib""" """load hccl lib"""
try: try:
# 获取当前文件所在的目录
base_dir = os.path.dirname(os.path.realpath(__file__)) base_dir = os.path.dirname(os.path.realpath(__file__))
# 构建库文件的路径
lib_path = os.path.join(base_dir, "../lib", HCCL_LIB) lib_path = os.path.join(base_dir, "../lib", HCCL_LIB)
# 加载库文件
hccl_lib = ctypes.CDLL(lib_path) hccl_lib = ctypes.CDLL(lib_path)
except Exception: except Exception:
# 如果加载失败则抛出运行时错误
raise RuntimeError('Get hccl lib error.') raise RuntimeError('Get hccl lib error.')
# 将加载的库文件设置为全局变量
global HCCL_LIB_CTYPES global HCCL_LIB_CTYPES
HCCL_LIB_CTYPES = hccl_lib HCCL_LIB_CTYPES = hccl_lib
def c_str(string): def c_str(string):
"""Convert a python string to C string.""" """Convert a python string to C string."""
# 将字符串转换为C风格字符串
if not isinstance(string, str): if not isinstance(string, str):
string = string.decode('ascii') string = string.decode('ascii')
return ctypes.c_char_p(string.encode('utf-8')) return ctypes.c_char_p(string.encode('utf-8'))
@ -96,9 +112,9 @@ def c_str(string):
def c_array(ctype, values): def c_array(ctype, values):
"""Create ctypes array from a python array.""" """Create ctypes array from a python array."""
# 从Python数组创建ctypes数组
return (ctype * len(values))(*values) return (ctype * len(values))(*values)
#用于创建包含指定数量和ID的HCCL通信组但不能创建世界组。
def create_group(group, rank_num, rank_ids): def create_group(group, rank_num, rank_ids):
""" """
Create group. Create group.
@ -112,28 +128,38 @@ def create_group(group, rank_num, rank_ids):
Returns: Returns:
None None
""" """
# 检查组的有效性
check_group(group) check_group(group)
# 检查排名数量的有效性
check_rank_num(rank_num) check_rank_num(rank_num)
# 检查rank_ids是否为列表类型
if isinstance(rank_ids, (list)): if isinstance(rank_ids, (list)):
# 确保rank_num与rank_ids的长度一致
if rank_num != len(rank_ids): if rank_num != len(rank_ids):
raise ValueError("The argument 'rank_num' number should be equal to the length " raise ValueError("The argument 'rank_num' number should be equal to the length "
"of rank_ids, but got 'rank_num' value : {} and 'rank_ids' value : {}." "of rank_ids, but got 'rank_num' value : {} and 'rank_ids' value : {}."
.format(rank_num, rank_ids)) .format(rank_num, rank_ids))
# 检查rank_ids中的每个元素是否为非负整数
for rank_id in rank_ids: for rank_id in rank_ids:
if not isinstance(rank_id, (int)) or rank_id < 0: if not isinstance(rank_id, (int)) or rank_id < 0:
raise ValueError("The elements of argument 'rank_ids' must be " raise ValueError("The elements of argument 'rank_ids' must be "
"unsigned integer, but got the type : {}".format(type(rank_id))) "unsigned integer, but got the type : {}".format(type(rank_id)))
# 将rank_ids转换为C类型的数组
c_array_rank_ids = c_array(ctypes.c_uint, rank_ids) c_array_rank_ids = c_array(ctypes.c_uint, rank_ids)
# 将rank_num转换为C类型的无符号整数
c_rank_num = ctypes.c_uint(rank_num) c_rank_num = ctypes.c_uint(rank_num)
# 将group转换为C类型的字符串
c_group = c_str(group) c_group = c_str(group)
# 调用HCCL库创建组
ret = HCCL_LIB_CTYPES.HcomCreateGroup(c_group, c_rank_num, c_array_rank_ids) ret = HCCL_LIB_CTYPES.HcomCreateGroup(c_group, c_rank_num, c_array_rank_ids)
# 检查组创建是否成功
if ret != 0: if ret != 0:
raise RuntimeError('Create group error, the error code is {}.'.format(ret)) raise RuntimeError('Create group error, the error code is {}.'.format(ret))
else: else:
# 如果rank_ids不是列表类型抛出类型错误
raise TypeError("For 'create_group', the argument 'rank_ids' must be type of list, " raise TypeError("For 'create_group', the argument 'rank_ids' must be type of list, "
"but got 'rank_ids' type : {}.".format(type(rank_ids))) "but got 'rank_ids' type : {}.".format(type(rank_ids)))
#用于销毁用户创建的HCCL组
def destroy_group(group): def destroy_group(group):
""" """
A function that destroy the group which created by user. A function that destroy the group which created by user.
@ -144,11 +170,16 @@ def destroy_group(group):
Returns: Returns:
None None
""" """
check_group(group) # 检查传入的组是否有效
c_group = c_str(group) check_group(group)
ret = HCCL_LIB_CTYPES.HcomDestroyGroup(c_group) # 将组名转换为C风格的字符串
if ret != 0: c_group = c_str(group)
raise RuntimeError('Destroy group error.') # 调用HCCL库中的函数销毁指定的组
ret = HCCL_LIB_CTYPES.HcomDestroyGroup(c_group)
# 如果返回值不为0说明销毁组时发生了错误抛出异常
if ret != 0:
raise RuntimeError('Destroy group error.')
def get_rank_size(group="hccl_world_group"): def get_rank_size(group="hccl_world_group"):
@ -162,16 +193,23 @@ def get_rank_size(group="hccl_world_group"):
An integer scalar with the num of ranks. An integer scalar with the num of ranks.
""" """
# 根据上下文的模式判断是否为PYNATIVE_MODE模式若是则直接返回HCCL的rank size
if context.get_context("mode") == context.PYNATIVE_MODE: if context.get_context("mode") == context.PYNATIVE_MODE:
return get_hccl_rank_size() return get_hccl_rank_size()
# 检查给定的组是否有效
check_group(group) check_group(group)
# 将组名转换为C字符串格式
c_group = c_str(group) c_group = c_str(group)
# 定义一个C类型的无符号整数用于存储rank size
c_rank_size = ctypes.c_uint() c_rank_size = ctypes.c_uint()
# 调用HCCL库的HcomGetRankSize函数获取组内的rank size
ret = HCCL_LIB_CTYPES.HcomGetRankSize(c_group, ctypes.byref(c_rank_size)) ret = HCCL_LIB_CTYPES.HcomGetRankSize(c_group, ctypes.byref(c_rank_size))
# 如果返回值不为0表示获取rank size失败抛出运行时错误
if ret != 0: if ret != 0:
raise RuntimeError('Get rank size error.') raise RuntimeError('Get rank size error.')
# 返回获取到的rank size值
return c_rank_size.value return c_rank_size.value
@ -186,17 +224,22 @@ def get_rank_id(group="hccl_world_group"):
if context.get_context("mode") == context.PYNATIVE_MODE: if context.get_context("mode") == context.PYNATIVE_MODE:
return get_hccl_rank_id() return get_hccl_rank_id()
# 检查组的有效性
check_group(group) check_group(group)
# 将组转换为 C 字符串格式
c_group = c_str(group) c_group = c_str(group)
# 定义一个用于存储 rank id 的 ctypes 无符号整数
c_rank_id = ctypes.c_uint() c_rank_id = ctypes.c_uint()
# 调用 HCCL 库获取当前进程的 rank id
ret = HCCL_LIB_CTYPES.HcomGetRankId(c_group, ctypes.byref(c_rank_id)) ret = HCCL_LIB_CTYPES.HcomGetRankId(c_group, ctypes.byref(c_rank_id))
# 如果返回值不为 0表示获取 rank id 出错,抛出 RuntimeError 异常
if ret != 0: if ret != 0:
raise RuntimeError('Get rank id error.') raise RuntimeError('Get rank id error.')
# 返回获取到的 rank id 值
return c_rank_id.value return c_rank_id.value
def get_local_rank_size(group="hccl_world_group"): def get_local_rank_size(group="hccl_world_group"):
""" """
A function that returns the number of local ranks within the given collection communication group. A function that returns the number of local ranks within the given collection communication group.
@ -207,19 +250,25 @@ def get_local_rank_size(group="hccl_world_group"):
Returns: Returns:
An integer scalar with the num of local ranks. An integer scalar with the num of local ranks.
""" """
# 检查当前运行模式是否为PYNATIVE_MODE如果是则抛出异常
if context.get_context("mode") is context.PYNATIVE_MODE: if context.get_context("mode") is context.PYNATIVE_MODE:
raise RuntimeError("The function 'get_local_rank_size' is not supported in PYNATIVE_MODE, " raise RuntimeError("The function 'get_local_rank_size' is not supported in PYNATIVE_MODE, "
"'get_local_rank_size' only support GRAPH_MODE") "'get_local_rank_size' only support GRAPH_MODE")
# 验证传入的组是否有效
check_group(group) check_group(group)
# 将组名称转换为C字符串格式
c_group = c_str(group) c_group = c_str(group)
# 定义一个ctypes的无符号整数变量用于存储本地排名大小
c_local_rank_size = ctypes.c_uint() c_local_rank_size = ctypes.c_uint()
# 调用HCCL库中的HcomGetLocalRankSize函数获取本地排名大小
ret = HCCL_LIB_CTYPES.HcomGetLocalRankSize(c_group, ctypes.byref(c_local_rank_size)) ret = HCCL_LIB_CTYPES.HcomGetLocalRankSize(c_group, ctypes.byref(c_local_rank_size))
# 如果返回值不为0说明获取本地排名大小时出错抛出异常
if ret != 0: if ret != 0:
raise RuntimeError('Get local rank size error.') raise RuntimeError('Get local rank size error.')
# 返回获取到的本地排名大小
return c_local_rank_size.value return c_local_rank_size.value
def get_local_rank_id(group="hccl_world_group"): def get_local_rank_id(group="hccl_world_group"):
""" """
Get local rank id. Get local rank id.
@ -230,16 +279,23 @@ def get_local_rank_id(group="hccl_world_group"):
An integer scalar with the local rank id of the calling process. An integer scalar with the local rank id of the calling process.
""" """
# 检查当前运行模式是否为PYNATIVE_MODE如果是则抛出异常
if context.get_context("mode") is context.PYNATIVE_MODE: if context.get_context("mode") is context.PYNATIVE_MODE:
raise RuntimeError("The function 'get_local_rank_id' is not supported in PYNATIVE_MODE, " raise RuntimeError("The function 'get_local_rank_id' is not supported in PYNATIVE_MODE, "
"'get_local_rank_id' only support GRAPH_MODE") "'get_local_rank_id' only support GRAPH_MODE")
# 验证群组的有效性
check_group(group) check_group(group)
# 将群组名称转换为C字符串格式
c_group = c_str(group) c_group = c_str(group)
# 定义一个无符号整型的C类型变量来存储本地排名ID
c_local_rank_id = ctypes.c_uint() c_local_rank_id = ctypes.c_uint()
# 调用HCCL库的HcomGetLocalRankId函数获取本地排名ID
ret = HCCL_LIB_CTYPES.HcomGetLocalRankId(c_group, ctypes.byref(c_local_rank_id)) ret = HCCL_LIB_CTYPES.HcomGetLocalRankId(c_group, ctypes.byref(c_local_rank_id))
# 如果返回值不为0表示获取本地排名ID失败抛出异常
if ret != 0: if ret != 0:
raise RuntimeError('Get local rank id error.') raise RuntimeError('Get local rank id error.')
# 返回获取到的本地排名ID值
return c_local_rank_id.value return c_local_rank_id.value
@ -256,18 +312,25 @@ def get_world_rank_from_group_rank(group, group_rank_id):
if context.get_context("mode") is context.PYNATIVE_MODE: if context.get_context("mode") is context.PYNATIVE_MODE:
raise RuntimeError("The function 'get_world_rank_from_group_rank' is not supported in PYNATIVE_MODE, " raise RuntimeError("The function 'get_world_rank_from_group_rank' is not supported in PYNATIVE_MODE, "
"'get_world_rank_from_group_rank' only support GRAPH_MODE") "'get_world_rank_from_group_rank' only support GRAPH_MODE")
# 检查组名是否有效
check_group(group) check_group(group)
# 检查组内rank ID是否有效
check_rank_id(group_rank_id) check_rank_id(group_rank_id)
# 将组名转换为C字符串格式
c_group = c_str(group) c_group = c_str(group)
# 将组内rank ID转换为C的无符号整数类型
c_group_rank_id = ctypes.c_uint(group_rank_id) c_group_rank_id = ctypes.c_uint(group_rank_id)
# 声明一个用于存储世界rank ID的C的无符号整数类型变量
c_world_rank_id = ctypes.c_uint() c_world_rank_id = ctypes.c_uint()
# 调用HCCL库中的HcomGetWorldRankFromGroupRank函数根据组名和组内rank ID获取对应的世界rank ID
ret = HCCL_LIB_CTYPES.HcomGetWorldRankFromGroupRank(c_group, c_group_rank_id, ctypes.byref(c_world_rank_id)) ret = HCCL_LIB_CTYPES.HcomGetWorldRankFromGroupRank(c_group, c_group_rank_id, ctypes.byref(c_world_rank_id))
# 如果返回值不为0说明函数调用出错抛出RuntimeError异常
if ret != 0: if ret != 0:
raise RuntimeError('Get world rank from group rank error.') raise RuntimeError('根据组内rank ID获取世界rank ID时出错。')
# 返回获取到的世界rank ID的值
return c_world_rank_id.value return c_world_rank_id.value
def get_group_rank_from_world_rank(world_rank_id, group): def get_group_rank_from_world_rank(world_rank_id, group):
""" """
Get group rank from world rank. Get group rank from world rank.
@ -281,13 +344,21 @@ def get_group_rank_from_world_rank(world_rank_id, group):
if context.get_context("mode") is context.PYNATIVE_MODE: if context.get_context("mode") is context.PYNATIVE_MODE:
raise RuntimeError("The function 'get_group_rank_from_world_rank' is not supported in PYNATIVE_MODE, " raise RuntimeError("The function 'get_group_rank_from_world_rank' is not supported in PYNATIVE_MODE, "
"'get_group_rank_from_world_rank' only support GRAPH_MODE") "'get_group_rank_from_world_rank' only support GRAPH_MODE")
# 检查组的有效性
check_group(group) check_group(group)
# 检查世界排名ID的有效性
check_rank_id(world_rank_id) check_rank_id(world_rank_id)
# 将组转换为C字符串
c_group = c_str(group) c_group = c_str(group)
# 将世界排名ID转换为C无符号整数
c_world_rank_id = ctypes.c_uint(world_rank_id) c_world_rank_id = ctypes.c_uint(world_rank_id)
# 创建一个用于存储组排名ID的C无符号整数
c_group_rank_id = ctypes.c_uint() c_group_rank_id = ctypes.c_uint()
# 从世界排名ID获取组排名ID
ret = HCCL_LIB_CTYPES.HcomGetGroupRankFromWorldRank(c_world_rank_id, c_group, ctypes.byref(c_group_rank_id)) ret = HCCL_LIB_CTYPES.HcomGetGroupRankFromWorldRank(c_world_rank_id, c_group, ctypes.byref(c_group_rank_id))
# 如果返回值不为0抛出运行时错误
if ret != 0: if ret != 0:
raise RuntimeError('Get group rank from world rank error.') raise RuntimeError('Get group rank from world rank error.')
# 返回获取到的组排名ID的值
return c_group_rank_id.value return c_group_rank_id.value

@ -20,40 +20,51 @@ from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \
_create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \ _create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \
_get_local_rank_helper, _get_local_size_helper, GlobalComm _get_local_rank_helper, _get_local_size_helper, GlobalComm
from .._c_expression import init_hccl, finalize_hccl, init_gpu_collective from .._c_expression import init_hccl, finalize_hccl, init_gpu_collective
# 导入mindspore上下文模块
# 导入mindspore的进程服务器和调度器角色判断函数
# 导入通信辅助模块包括后端类型、获取和设置rank及size的辅助函数、创建和销毁组的辅助函数、以及世界通信组和本地通信组的相关标识
# 导入C语言表达式模块包含HCCL和NCCL的初始化和终结化函数以及GPU集体通信的初始化函数
__all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size", __all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size",
"get_local_rank_size", "get_world_rank_from_group_rank", "get_local_rank_size", "get_world_rank_from_group_rank",
"get_group_rank_from_world_rank", "create_group", "destroy_group", "get_group_rank_from_world_rank", "create_group", "destroy_group",
"HCCL_WORLD_COMM_GROUP", "NCCL_WORLD_COMM_GROUP"] "HCCL_WORLD_COMM_GROUP", "NCCL_WORLD_COMM_GROUP"]
# 默认的世界通信组
DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
def _get_group(group): def _get_group(group):
"""Return the world communication group if the `group` is `DEFAULT_WORLD_COMM_GROUP`.""" """根据输入的`group`参数返回相应的通信组。
如果`group`等于`DEFAULT_WORLD_COMM_GROUP`则返回全局通信组`GlobalComm.WORLD_COMM_GROUP`
否则返回传入的`group`参数
"""
if group == DEFAULT_WORLD_COMM_GROUP: if group == DEFAULT_WORLD_COMM_GROUP:
return GlobalComm.WORLD_COMM_GROUP return GlobalComm.WORLD_COMM_GROUP # 返回默认的世界通信组
return group return group # 返回传入的通信组参数
def _check_task_sink_envs(): def _check_task_sink_envs():
""" """检查任务接收器task_sink相关的环境变量是否已导出。
Check whether task_sink environment variables have been exported or not.
return True if task_sink environment variables have been exported, False otherwise. 该函数通过检查环境变量`GRAPH_OP_RUN`来判断任务接收器的环境变量是否已导出
返回值
- 如果环境变量`GRAPH_OP_RUN`已导出且其值可以转换为整数1则返回False表示环境变量未正确设置为启用状态
- 如果环境变量`GRAPH_OP_RUN`未导出或其值不能转换为整数1则返回True表示环境变量未导出或设置有误
""" """
import os import os
task_sink = os.getenv("GRAPH_OP_RUN") task_sink = os.getenv("GRAPH_OP_RUN") # 获取名为"GRAPH_OP_RUN"的环境变量
if task_sink: if task_sink:
try: try:
if int(task_sink) == 1: if int(task_sink) == 1: # 尝试将环境变量的值转换为整数并检查是否等于1
return False return False # 如果等于1返回False表示环境变量已导出但设置为禁用状态非预期情况
except ValueError: except ValueError:
return True return True # 如果转换为整数失败返回True表示环境变量设置有误
finally: finally:
pass pass # finally块中的代码在这里是空操作通常用于清理操作
return True return True # 如果环境变量未导出返回True
def _check_parallel_envs(): def _check_parallel_envs():
@ -63,25 +74,31 @@ def _check_parallel_envs():
Raises: Raises:
RuntimeError: If parallel environment variables have not been exported or have been exported to wrong values. RuntimeError: If parallel environment variables have not been exported or have been exported to wrong values.
""" """
# 检查是否需要进行环境验证,如果不进行则直接返回
if not GlobalComm.CHECK_ENVS: if not GlobalComm.CHECK_ENVS:
return return
import os import os
# 获取环境变量RANK_ID的值
rank_id_str = os.getenv("RANK_ID") rank_id_str = os.getenv("RANK_ID")
# 如果RANK_ID未设置抛出运行时错误
if not rank_id_str: if not rank_id_str:
raise RuntimeError("Environment variables RANK_ID has not been exported, please export variables 'RANK_ID'.") raise RuntimeError("Environment variables RANK_ID has not been exported, please export variables 'RANK_ID'.")
try: try:
# 尝试将RANK_ID转换为整数
int(rank_id_str) int(rank_id_str)
except ValueError: except ValueError:
# 如果转换失败,打印错误信息
print("Environment variables 'RANK_ID' should be number, but got the type : {}".format(type(rank_id_str))) print("Environment variables 'RANK_ID' should be number, but got the type : {}".format(type(rank_id_str)))
finally: finally:
# 无论是否发生异常,此块为空操作
pass pass
# 获取环境变量MINDSPORE_HCCL_CONFIG_PATH和RANK_TABLE_FILE的值
rank_table_file_str = os.getenv("MINDSPORE_HCCL_CONFIG_PATH") rank_table_file_str = os.getenv("MINDSPORE_HCCL_CONFIG_PATH")
rank_table_file_str_old = os.getenv("RANK_TABLE_FILE") rank_table_file_str_old = os.getenv("RANK_TABLE_FILE")
# 如果两个环境变量都未设置,抛出运行时错误
if not rank_table_file_str and not rank_table_file_str_old: if not rank_table_file_str and not rank_table_file_str_old:
raise RuntimeError("Get hccl rank_table_file failed, " raise RuntimeError("Get hccl rank_table_file failed, "
"please export MINDSPORE_HCCL_CONFIG_PATH or RANK_TABLE_FILE.") "please export MINDSPORE_HCCL_CONFIG_PATH or RANK_TABLE_FILE.")
def init(backend_name=None): def init(backend_name=None):
""" """
Initialize distributed backend, e.g. HCCL/NCCL, it is required before using the communication service. Initialize distributed backend, e.g. HCCL/NCCL, it is required before using the communication service.
@ -109,15 +126,20 @@ def init(backend_name=None):
>>> from mindspore.communication import init >>> from mindspore.communication import init
>>> init() >>> init()
""" """
# 检查当前角色是否为参数服务器或调度器,如果是则直接返回
if _is_role_pserver() or _is_role_sched(): if _is_role_pserver() or _is_role_sched():
return return
# 检查任务接收环境变量,获取设备目标和模式
task_sink = _check_task_sink_envs() task_sink = _check_task_sink_envs()
device_target = context.get_context("device_target") device_target = context.get_context("device_target")
mode = context.get_context("mode") mode = context.get_context("mode")
mpi_init = False mpi_init = False
# 如果没有任务接收且模式为图模式设置mpi_init为True
if not task_sink and mode == context.GRAPH_MODE: if not task_sink and mode == context.GRAPH_MODE:
mpi_init = True mpi_init = True
# 根据设备目标选择后端名称,如果不支持则抛出异常
# 根据设备目标设置默认的后端名称
if backend_name is None: if backend_name is None:
if device_target == "Ascend": if device_target == "Ascend":
backend_name = "hccl" backend_name = "hccl"
@ -126,28 +148,34 @@ def init(backend_name=None):
else: else:
raise RuntimeError("For 'set_context', the argument 'device_target' {} is not supported in " raise RuntimeError("For 'set_context', the argument 'device_target' {} is not supported in "
"parallel initialization, please use Ascend or GPU.".format(device_target)) "parallel initialization, please use Ascend or GPU.".format(device_target))
# 检查后端名称是否为字符串,如果不是则抛出异常
if not isinstance(backend_name, str): if not isinstance(backend_name, str):
raise TypeError("For 'init', the argument 'backend_name' must be a string, " raise TypeError("For 'init', the argument 'backend_name' must be a string, "
"but got the type : {}".format(type(backend_name))) "but got the type : {}".format(type(backend_name)))
# 根据后端名称初始化通信环境
if backend_name == "hccl": if backend_name == "hccl":
# 如果设备目标不是Ascend抛出异常
if device_target != "Ascend": if device_target != "Ascend":
raise RuntimeError("For 'init', the argument 'backend_name' should be 'Ascend' to init hccl, " raise RuntimeError("For 'init', the argument 'backend_name' should be 'Ascend' to init hccl, "
"but got {}".format(device_target)) "but got {}".format(device_target))
# 如果不需要MPI初始化检查并行环境并设置后端名称
if not mpi_init: if not mpi_init:
_check_parallel_envs() _check_parallel_envs()
GlobalComm.BACKEND = Backend("hccl") GlobalComm.BACKEND = Backend("hccl")
else: else:
GlobalComm.BACKEND = Backend("hccl_mpi") GlobalComm.BACKEND = Backend("hccl_mpi")
# 初始化HCCL并设置全局通信组和初始化状态
init_hccl() init_hccl()
GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
GlobalComm.INITED = True GlobalComm.INITED = True
elif backend_name == "nccl": elif backend_name == "nccl":
# 初始化GPU集体通信并设置后端名称、全局通信组和初始化状态
init_gpu_collective() init_gpu_collective()
GlobalComm.BACKEND = Backend("nccl") GlobalComm.BACKEND = Backend("nccl")
GlobalComm.WORLD_COMM_GROUP = NCCL_WORLD_COMM_GROUP GlobalComm.WORLD_COMM_GROUP = NCCL_WORLD_COMM_GROUP
GlobalComm.INITED = True GlobalComm.INITED = True
else: else:
# 如果后端名称不支持,抛出异常
raise RuntimeError("For 'init', the argument 'backend_name' must be nccl while 'device_target' is GPU, " raise RuntimeError("For 'init', the argument 'backend_name' must be nccl while 'device_target' is GPU, "
"but got the 'backend_name' : hccl.") "but got the 'backend_name' : hccl.")
@ -165,10 +193,11 @@ def release():
Examples: Examples:
>>> from mindspore.communication import init, release >>> from mindspore.communication import init, release
>>> init() >>> init() # 初始化分布式通信环境
>>> release() >>> release() # 释放分布式通信资源
""" """
finalize_hccl() finalize_hccl()# 结束 HCCL 的使用,释放相关资源
def get_rank(group=GlobalComm.WORLD_COMM_GROUP): def get_rank(group=GlobalComm.WORLD_COMM_GROUP):
@ -197,12 +226,18 @@ def get_rank(group=GlobalComm.WORLD_COMM_GROUP):
>>> print(rank_id) >>> print(rank_id)
>>> # the result is the rank_id in world_group >>> # the result is the rank_id in world_group
""" """
# 检查传入的group参数是否为字符串类型
if not isinstance(group, str): if not isinstance(group, str):
# 如果group参数不是字符串类型则抛出TypeError异常
raise TypeError("For 'get_rank', the argument 'group' must be type of string, " raise TypeError("For 'get_rank', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group))) "but got 'group' type : {}.".format(type(group)))
# 调用_get_rank_helper函数获取指定group的rank ID
# _get_group函数用于解析group参数
# GlobalComm.BACKEND变量存储了当前使用的通信后端
return _get_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND) return _get_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP): def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP):
""" """
Gets local rank ID for current device in specified collective communication group. Gets local rank ID for current device in specified collective communication group.
@ -234,9 +269,11 @@ def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP):
>>> print("local_rank is: {}, world_rank is {}".format(local_rank, world_rank)) >>> print("local_rank is: {}, world_rank is {}".format(local_rank, world_rank))
local_rank is: 1, world_rank is 9 local_rank is: 1, world_rank is 9
""" """
# 检查传入的 'group' 参数是否为字符串类型,如果不是则抛出 TypeError
if not isinstance(group, str): if not isinstance(group, str):
raise TypeError("For 'get_local_rank', the argument 'group' must be type of string, " raise TypeError("For 'get_local_rank', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group))) "but got 'group' type : {}.".format(type(group)))
# 调用辅助函数 _get_local_rank_helper 来获取本地排名,传入的参数为解析后的组和全局通信后端
return _get_local_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND) return _get_local_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
@ -270,9 +307,11 @@ def get_group_size(group=GlobalComm.WORLD_COMM_GROUP):
>>> print("group_size is: ", group_size) >>> print("group_size is: ", group_size)
group_size is: 8 group_size is: 8
""" """
# 检查传入的参数 'group' 是否为字符串类型
if not isinstance(group, str): if not isinstance(group, str):
raise TypeError("For 'get_group_size', the argument 'group' must be type of string, " raise TypeError("For 'get_group_size', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group))) "but got 'group' type : {}.".format(type(group)))
# 返回指定组的大小,使用辅助函数 _get_size_helper 和全局通信后端
return _get_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND) return _get_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
@ -306,9 +345,11 @@ def get_local_rank_size(group=GlobalComm.WORLD_COMM_GROUP):
>>> print("local_rank_size is: ", local_rank_size) >>> print("local_rank_size is: ", local_rank_size)
local_rank_size is: 8 local_rank_size is: 8
""" """
# 检查传入的 'group' 参数是否为字符串类型,如果不是则抛出 TypeError
if not isinstance(group, str): if not isinstance(group, str):
raise TypeError("For 'get_local_rank_size', the argument 'group' must be type of string, " raise TypeError("For 'get_local_rank_size', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group))) "but got 'group' type : {}.".format(type(group)))
# 调用辅助函数获取本地组的大小,使用 _get_group 函数获取组,并使用 GlobalComm.BACKEND 作为后端
return _get_local_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND) return _get_local_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
@ -347,12 +388,12 @@ def get_world_rank_from_group_rank(group, group_rank_id):
>>> print("world_rank_id is: ", world_rank_id) >>> print("world_rank_id is: ", world_rank_id)
world_rank_id is: 4 world_rank_id is: 4
""" """
# 检查传入的 group 参数是否为字符串类型
if not isinstance(group, str): if not isinstance(group, str):
raise TypeError("For 'get_world_rank_from_group_rank', the argument 'group' must be type of string, " raise TypeError("For 'get_world_rank_from_group_rank', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group))) "but got 'group' type : {}.".format(type(group)))
# 返回根据组和组排名获取的世界排名的帮助函数的结果
return _get_world_rank_from_group_rank_helper(group=group, group_rank_id=group_rank_id, backend=GlobalComm.BACKEND) return _get_world_rank_from_group_rank_helper(group=group, group_rank_id=group_rank_id, backend=GlobalComm.BACKEND)
def get_group_rank_from_world_rank(world_rank_id, group): def get_group_rank_from_world_rank(world_rank_id, group):
""" """
Get the rank ID in the specified user communication group corresponding to Get the rank ID in the specified user communication group corresponding to
@ -389,12 +430,13 @@ def get_group_rank_from_world_rank(world_rank_id, group):
>>> print("group_rank_id is: ", group_rank_id) >>> print("group_rank_id is: ", group_rank_id)
group_rank_id is: 1 group_rank_id is: 1
""" """
# 检查输入参数 'group' 是否为字符串类型,如果不是则抛出 TypeError
if not isinstance(group, str): if not isinstance(group, str):
raise TypeError("For 'get_group_rank_from_world_rank', the argument 'group' must be type of string, " raise TypeError("For 'get_group_rank_from_world_rank', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group))) "but got 'group' type : {}.".format(type(group)))
# 调用辅助函数 _get_group_rank_from_world_rank_helper 来获取组的排名,并返回结果
return _get_group_rank_from_world_rank_helper(world_rank_id=world_rank_id, group=group, backend=GlobalComm.BACKEND) return _get_group_rank_from_world_rank_helper(world_rank_id=world_rank_id, group=group, backend=GlobalComm.BACKEND)
#创建一个用户集体通信组通过传入通信组名称和设备ID列表来实现。
def create_group(group, rank_ids): def create_group(group, rank_ids):
""" """
Create a user collective communication group. Create a user collective communication group.
@ -427,9 +469,11 @@ def create_group(group, rank_ids):
>>> create_group(group, rank_ids) >>> create_group(group, rank_ids)
>>> allreduce = ops.AllReduce(group) >>> allreduce = ops.AllReduce(group)
""" """
# 检查传入的 group 参数是否为字符串类型,如果不是则抛出 TypeError
if not isinstance(group, str): if not isinstance(group, str):
raise TypeError("For 'create_group', the argument 'group' must be type of string, " raise TypeError("For 'create_group', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group))) "but got 'group' type : {}.".format(type(group)))
# 调用辅助函数 _create_group_helper 来创建组,使用指定的 rank_ids 和后端
_create_group_helper(group, rank_ids, backend=GlobalComm.BACKEND) _create_group_helper(group, rank_ids, backend=GlobalComm.BACKEND)
@ -450,7 +494,9 @@ def destroy_group(group):
ValueError: If group is "hccl_world_group" or backend is invalid. ValueError: If group is "hccl_world_group" or backend is invalid.
RuntimeError: If HCCL is not available or MindSpore is GPU version. RuntimeError: If HCCL is not available or MindSpore is GPU version.
""" """
# 检查传入的 group 参数是否为字符串类型,如果不是则抛出 TypeError
if not isinstance(group, str): if not isinstance(group, str):
raise TypeError("For 'destroy_group', the argument 'group' must be type of string, " raise TypeError("For 'destroy_group', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group))) "but got 'group' type : {}.".format(type(group)))
# 调用辅助函数 _destroy_group_helper 来销毁指定的组,并使用全局通信后端
_destroy_group_helper(group, backend=GlobalComm.BACKEND) _destroy_group_helper(group, backend=GlobalComm.BACKEND)

@ -452,11 +452,14 @@ class _GeneratorWorkerMp(multiprocessing.Process):
""" """
def __init__(self, dataset, eof, max_rowsize, queue_size, ppid): def __init__(self, dataset, eof, max_rowsize, queue_size, ppid):
# 初始化一个多进程队列,用于存储索引
self.idx_queue = multiprocessing.Queue(queue_size) self.idx_queue = multiprocessing.Queue(queue_size)
# 如果启用了共享内存,则初始化一个共享队列,否则初始化一个多进程队列
if get_enable_shared_mem(): if get_enable_shared_mem():
self.res_queue = _SharedQueue(queue_size, max_rowsize=max_rowsize) self.res_queue = _SharedQueue(queue_size, max_rowsize=max_rowsize)
else: else:
self.res_queue = multiprocessing.Queue(queue_size) self.res_queue = multiprocessing.Queue(queue_size)
# 设置队列的_joincancelled属性为True表示在进程退出时队列不会阻塞
self.idx_queue._joincancelled = True # pylint: disable=W0212 self.idx_queue._joincancelled = True # pylint: disable=W0212
self.res_queue._joincancelled = True # pylint: disable=W0212 self.res_queue._joincancelled = True # pylint: disable=W0212
super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, True, ppid)) super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, True, ppid))
@ -465,6 +468,7 @@ class _GeneratorWorkerMp(multiprocessing.Process):
""" """
Put function for worker index queue. Never block. Raise queue.Full on failure. Put function for worker index queue. Never block. Raise queue.Full on failure.
""" """
# 将item放入idx_queue队列中不阻塞如果失败则抛出queue.Full异常
self.idx_queue.put_nowait(item) self.idx_queue.put_nowait(item)
def get(self): def get(self):
@ -476,12 +480,19 @@ class _GeneratorWorkerMp(multiprocessing.Process):
return self.res_queue.get(timeout=30) return self.res_queue.get(timeout=30)
def queue_empty(self): def queue_empty(self):
# 检查idx_queue是否为空
if not self.idx_queue.empty(): if not self.idx_queue.empty():
# 如果不为空,记录警告日志
logger.warning("idx_queue is not empty.") logger.warning("idx_queue is not empty.")
# 返回False
return False return False
# 检查res_queue是否为空
if not self.res_queue.empty(): if not self.res_queue.empty():
# 如果不为空,记录警告日志
logger.warning("res_queue is not empty.") logger.warning("res_queue is not empty.")
# 返回False
return False return False
# 如果两个队列都为空返回True
return True return True
def __del__(self): def __del__(self):
@ -632,14 +643,17 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None, def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None,
num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None, num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None,
python_multiprocessing=True, max_rowsize=6): python_multiprocessing=True, max_rowsize=6):
# 调用父类的初始化方法
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id) shuffle=shuffle, num_shards=num_shards, shard_id=shard_id)
# 如果source是zip类型则将其转换为列表
if isinstance(source, builtins.zip): if isinstance(source, builtins.zip):
# Although zip is iteratable, it does not have the feature of repeated iteration, so pass it to the array. # Although zip is iteratable, it does not have the feature of repeated iteration, so pass it to the array.
self.source = [item for item in source] self.source = [item for item in source]
else: else:
self.source = source self.source = source
self.prepared_source = None # source to be sent to C++ self.prepared_source = None # source to be sent to C++
# 如果self.operator_mixed属性为True则将num_parallel_workers设置为1
if hasattr(self, 'operator_mixed') and getattr(self, 'operator_mixed') is True: if hasattr(self, 'operator_mixed') and getattr(self, 'operator_mixed') is True:
self.num_parallel_workers = 1 self.num_parallel_workers = 1
logger.warning( logger.warning(
@ -650,56 +664,78 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
self.python_multiprocessing = python_multiprocessing self.python_multiprocessing = python_multiprocessing
# 将column_names转换为列表
self.column_names = to_list(column_names) self.column_names = to_list(column_names)
# 如果column_types不为空则将其转换为detypelist类型
if column_types is not None: if column_types is not None:
self.column_types = mstypelist_to_detypelist(column_types) self.column_types = mstypelist_to_detypelist(column_types)
else: else:
self.column_types = [] self.column_types = []
self.schema = schema self.schema = schema
# 如果schema不为空则将其转换为Schema类型
if schema is not None: if schema is not None:
# 如果schema不为空则将其赋值给self.schema
self.schema = schema self.schema = schema
# 如果schema不是Schema类型则将其转换为Schema类型
if not isinstance(schema, Schema): if not isinstance(schema, Schema):
self.schema = Schema(schema) self.schema = Schema(schema)
# Move get dataset_size by len from parse to here, because self.source will # Move get dataset_size by len from parse to here, because self.source will
# lose attribution of '__len__' after deepcopy. # lose attribution of '__len__' after deepcopy.
self.source_len = -1 # unknown self.source_len = -1 # unknown
# 如果self.source有__len__属性则获取self.source的长度
if hasattr(self.source, "__len__"): if hasattr(self.source, "__len__"):
self.source_len = len(self.source) self.source_len = len(self.source)
# 设置最大行大小
self.max_rowsize = max_rowsize self.max_rowsize = max_rowsize
# 设置采样函数为None
self.sample_fn = None self.sample_fn = None
def __deepcopy__(self, memodict): def __deepcopy__(self, memodict):
# 深度复制当前对象并传入一个字典memodict用于存储已经复制的对象
if id(self) in memodict: if id(self) in memodict:
# 如果当前对象的id已经在memodict中则直接返回该对象
return memodict[id(self)] return memodict[id(self)]
# 否则调用__safe_deepcopy__方法进行深度复制并传入memodict和exclude参数
new_op = self.__safe_deepcopy__(memodict, exclude=("source", "__transfer_dataset__")) new_op = self.__safe_deepcopy__(memodict, exclude=("source", "__transfer_dataset__"))
sample_fn = None sample_fn = None
# 如果新对象的sampler属性不为空并且self.source对象具有__getitem__方法
if new_op.sampler is not None and hasattr(self.source, "__getitem__"): if new_op.sampler is not None and hasattr(self.source, "__getitem__"):
# The reason why there is a try catch here is because when the new op is being constructed with shared # The reason why there is a try catch here is because when the new op is being constructed with shared
# memory enabled, there will be an exception thrown if there is not enough shared memory available # memory enabled, there will be an exception thrown if there is not enough shared memory available
# 如果self.source_len为-1则抛出RuntimeError异常因为尝试构造一个随机访问的数据集需要__len__方法
if self.source_len == -1: if self.source_len == -1:
raise RuntimeError("Attempt to construct a random access dataset, '__len__' method is required!") raise RuntimeError("Attempt to construct a random access dataset, '__len__' method is required!")
try: try:
# 如果新对象的num_parallel_workers大于1则调用__validate_memory_usage方法进行内存使用验证
if new_op.num_parallel_workers > 1: if new_op.num_parallel_workers > 1:
self.__validate_memory_usage() self.__validate_memory_usage()
# 创建一个SamplerFn对象用于并行采样
sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing, sample_fn = SamplerFn(self.source, new_op.num_parallel_workers, self.python_multiprocessing,
self.max_rowsize) self.max_rowsize)
# 将新对象的prepared_source属性设置为_cpp_sampler_fn_mp函数用于并行采样
new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn)) new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn_mp(sample_ids, sample_fn))
else: else:
# 否则将新对象的prepared_source属性设置为_cpp_sampler_fn函数用于单线程采样
new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source)) new_op.prepared_source = (lambda sample_ids: _cpp_sampler_fn(sample_ids, self.source))
# 将新对象的sample_fn属性设置为sample_fn
new_op.sample_fn = sample_fn new_op.sample_fn = sample_fn
except RuntimeError as e: except RuntimeError as e:
# 如果抛出RuntimeError异常则抛出Exception异常并传入异常信息
raise Exception(str(e)) raise Exception(str(e))
else: else:
try: try:
# 否则将新对象的sampler属性设置为Nonesample_fn属性设置为sample_fn
new_op.sampler = None new_op.sampler = None
new_op.sample_fn = sample_fn new_op.sample_fn = sample_fn
# 将新对象的source_len属性设置为min(new_op.source_len, new_op.num_samples)如果new_op.num_samples不为0否则设置为new_op.source_len
new_op.source_len = min(new_op.source_len, new_op.source_len = min(new_op.source_len,
new_op.num_samples) if new_op.num_samples != 0 else new_op.source_len new_op.num_samples) if new_op.num_samples != 0 else new_op.source_len
# 遍历self.source对象
iter(self.source) iter(self.source)
except TypeError: except TypeError:
# Use generator function if input callable # Use generator function if input callable
@ -711,19 +747,26 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
return new_op return new_op
# 判断是否被洗牌
def is_shuffled(self): def is_shuffled(self):
return self.sampler.is_shuffled() return self.sampler.is_shuffled()
# 判断是否被分片
def is_sharded(self): def is_sharded(self):
return self.sampler.is_sharded() return self.sampler.is_sharded()
# 解析
def parse(self, children=None): def parse(self, children=None):
# 如果schema为空则返回GeneratorNode对象
if self.schema is None: if self.schema is None:
return cde.GeneratorNode(self.prepared_source, self.column_names, self.column_types, self.source_len, return cde.GeneratorNode(self.prepared_source, self.column_names, self.column_types, self.source_len,
self.sampler, self.num_parallel_workers) self.sampler, self.num_parallel_workers)
# 获取schema
schema = self.schema schema = self.schema
# 如果schema是Schema类型则获取cpp_schema
if isinstance(schema, Schema): if isinstance(schema, Schema):
schema = self.schema.cpp_schema schema = self.schema.cpp_schema
# 返回GeneratorNode对象
return cde.GeneratorNode(self.prepared_source, schema, self.source_len, self.sampler, return cde.GeneratorNode(self.prepared_source, schema, self.source_len, self.sampler,
self.num_parallel_workers) self.num_parallel_workers)
@ -735,24 +778,37 @@ class GeneratorDataset(MappableDataset, UnionBaseDataset):
# if use num_parallel_workers is to large when python_multiprocessing=True which would cause # if use num_parallel_workers is to large when python_multiprocessing=True which would cause
# OOM error get the num_shards # OOM error get the num_shards
valid_num_shards = 1 valid_num_shards = 1
# 判断self.sampler是否为samplers.DistributedSampler类型
if isinstance(self.sampler, samplers.DistributedSampler): if isinstance(self.sampler, samplers.DistributedSampler):
# 如果是则将self.sampler的num_shards赋值给valid_num_shards
valid_num_shards = self.sampler.num_shards valid_num_shards = self.sampler.num_shards
# 否则判断self.num_shards是否为None
elif self.num_shards is not None: elif self.num_shards is not None:
# 如果不是则将self.num_shards赋值给valid_num_shards
valid_num_shards = self.num_shards valid_num_shards = self.num_shards
# get process memory usage # get process memory usage
# 获取当前进程
process = psutil.Process(os.getpid()) process = psutil.Process(os.getpid())
# 获取当前进程的内存信息
process_memory = process.memory_info().rss process_memory = process.memory_info().rss
# 获取系统内存的空闲量
sys_memory_free = psutil.virtual_memory().free sys_memory_free = psutil.virtual_memory().free
# 计算可能使用的总内存量
total_memory_maybe_used = process_memory * self.num_parallel_workers * valid_num_shards total_memory_maybe_used = process_memory * self.num_parallel_workers * valid_num_shards
# 如果总内存可能使用的内存量除以系统可用内存大于0.85
if total_memory_maybe_used / sys_memory_free > 0.85: if total_memory_maybe_used / sys_memory_free > 0.85:
# 计算有效的worker数量即系统可用内存乘以0.85除以有效的shards数量再除以每个进程的内存
valid_num_worker = math.floor(sys_memory_free * 0.85 / valid_num_shards / process_memory) valid_num_worker = math.floor(sys_memory_free * 0.85 / valid_num_shards / process_memory)
# 如果有效的worker数量小于等于0则将其设置为1
valid_num_worker = 1 if valid_num_worker <= 0 else valid_num_worker valid_num_worker = 1 if valid_num_worker <= 0 else valid_num_worker
# 构造警告信息提示用户num_parallel_workers设置过大可能会导致内存占用过高或OOM建议将其减小到valid_num_worker或更小
info = "GeneratorDataset's num_parallel_workers: {} is too large which may cause a lot of memory " \ info = "GeneratorDataset's num_parallel_workers: {} is too large which may cause a lot of memory " \
"occupation (>85%) or out of memory(OOM) during multiprocessing. Therefore, it is recommended " \ "occupation (>85%) or out of memory(OOM) during multiprocessing. Therefore, it is recommended " \
"to reduce num_parallel_workers to {} or smaller.".format(self.num_parallel_workers, "to reduce num_parallel_workers to {} or smaller.".format(self.num_parallel_workers,
valid_num_worker) valid_num_worker)
# 打印警告信息
logger.warning(info) logger.warning(info)
@ -764,37 +820,55 @@ class _NumpySlicesDataset:
def __init__(self, data, column_list=None): def __init__(self, data, column_list=None):
self.column_list = None self.column_list = None
# Convert dict data into tuple # Convert dict data into tuple
# 判断data是否为字典类型
if isinstance(data, dict): if isinstance(data, dict):
# 如果是字典类型则调用process_dict方法处理
data = self.process_dict(data) data = self.process_dict(data)
# 判断data是否为元组类型
if isinstance(data, tuple): if isinstance(data, tuple):
# 如果是元组类型则将self.data初始化为空元组
self.data = () self.data = ()
# 获取data的长度
data_len = len(data) data_len = len(data)
# 遍历data中的每个元素
for i in range(data_len): for i in range(data_len):
# 将data中的每个元素转换为numpy数组并添加到self.data中
self.data = self.data + (np.array(data[i]),) self.data = self.data + (np.array(data[i]),)
else: else:
# 如果data不是元组类型则将data转换为numpy数组并添加到self.data中
self.data = (np.array(data),) self.data = (np.array(data),)
# check whether the data length in each column is equal # check whether the data length in each column is equal
# 获取每个data_item的长度
data_len = [len(data_item) for data_item in self.data] data_len = [len(data_item) for data_item in self.data]
# 如果每个data_item的长度不相等则抛出ValueError异常
if data_len[1:] != data_len[:-1]: if data_len[1:] != data_len[:-1]:
raise ValueError("Data length in each column is not equal.") raise ValueError("Data length in each column is not equal.")
# Init column_name # Init column_name
# 如果column_list不为空则将self.column_list赋值为column_list
if column_list is not None: if column_list is not None:
self.column_list = column_list self.column_list = column_list
# 如果self.column_list为空则将self.column_list赋值为空列表
elif self.column_list is None: elif self.column_list is None:
self.column_list = [] self.column_list = []
# 获取data的列数
column_num = len(self.data) column_num = len(self.data)
# 遍历列数,将"column_" + str(i)添加到self.column_list中
for i in range(column_num): for i in range(column_num):
self.column_list.append("column_" + str(i)) self.column_list.append("column_" + str(i))
def __getitem__(self, index): def __getitem__(self, index):
# 获取指定索引的数据行
data_row = [d[index, ...] for d in self.data] data_row = [d[index, ...] for d in self.data]
# 将数据行转换为元组
data_res = tuple(data_row) data_res = tuple(data_row)
# 返回数据行
return data_res return data_res
def __len__(self): def __len__(self):
# 返回data的第一个元素的长度
return len(self.data[0]) return len(self.data[0])
def process_dict(self, input_data): def process_dict(self, input_data):
@ -802,24 +876,29 @@ class _NumpySlicesDataset:
Convert the dict like data into tuple format, when input is a tuple of dicts then compose it into a dict first. Convert the dict like data into tuple format, when input is a tuple of dicts then compose it into a dict first.
""" """
# Convert pandas like dict(has "values" column) into General dict # Convert pandas like dict(has "values" column) into General dict
# 将pandas样式的字典有"values"列)转换为通用字典
data_keys = list(input_data.keys()) data_keys = list(input_data.keys())
# 获取字典的第一个键对应的值
data_col = input_data[data_keys[0]] data_col = input_data[data_keys[0]]
# 如果值有values属性则将其转换为通用字典
if hasattr(data_col, "values"): if hasattr(data_col, "values"):
new_dict = {} new_dict = {}
for key in data_keys: for key in data_keys:
# 将字典中的键对应的值转换为列表
item1 = input_data.pop(key) item1 = input_data.pop(key)
new_dict[key] = item1.values new_dict[key] = item1.values
# 将转换后的字典赋值给input_data
input_data = new_dict input_data = new_dict
# Convert the data in dict into tuple # Convert the data in dict into tuple
data = () data = () # 初始化一个空元组
keys = list(input_data.keys()) keys = list(input_data.keys()) # 将输入数据的键转换为列表
self.column_list = keys self.column_list = keys # 将键列表赋值给实例变量column_list
for key in keys: for key in keys: # 遍历键列表
value = input_data[key] value = input_data[key] # 获取键对应的值
data = data + (list(value),) data = data + (list(value),) # 将值转换为列表,并添加到元组中
return data return data # 返回元组
class NumpySlicesDataset(GeneratorDataset): class NumpySlicesDataset(GeneratorDataset):
@ -909,7 +988,9 @@ class NumpySlicesDataset(GeneratorDataset):
@check_numpyslicesdataset @check_numpyslicesdataset
def __init__(self, data, column_names=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, def __init__(self, data, column_names=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None,
num_shards=None, shard_id=None): num_shards=None, shard_id=None):
# 创建一个_NumpySlicesDataset对象传入data和column_names参数
dataset = _NumpySlicesDataset(data, column_names) dataset = _NumpySlicesDataset(data, column_names)
# 调用父类的__init__方法传入dataset、column_names、num_samples、num_parallel_workers、shuffle、sampler、num_shards和shard_id参数
super().__init__(dataset, column_names=dataset.column_list, num_samples=num_samples, super().__init__(dataset, column_names=dataset.column_list, num_samples=num_samples,
num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler, num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler,
num_shards=num_shards, shard_id=shard_id) num_shards=num_shards, shard_id=shard_id)

@ -45,14 +45,15 @@ class OneOf(OneOf_):
TypeError: raise type error for invalid inputs. TypeError: raise type error for invalid inputs.
""" """
self.patterns = patterns self.patterns = patterns
# 检查 patterns 是否是 Pattern 类的实例
if isinstance(patterns, Pattern): if isinstance(patterns, Pattern):
OneOf_.__init__(self, [patterns]) OneOf_.__init__(self, [patterns])
# 检查 patterns 是否是 tuple 或 list 类型,并且其中所有元素都是 Pattern 类的实例
elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns): elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns):
OneOf_.__init__(self, patterns) OneOf_.__init__(self, patterns)
# 如果 patterns 不符合上述两种情况,则抛出 TypeError 异常
else: else:
raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}") raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}")
class Prim(Prim_): class Prim(Prim_):
r""" r"""
Express a pattern of certain primitive type(s). Express a pattern of certain primitive type(s).
@ -76,25 +77,33 @@ class Prim(Prim_):
Raises: Raises:
TypeError: raise type error for invalid argument. TypeError: raise type error for invalid argument.
""" """
# 检查name是否为字符串类型如果不是则抛出TypeError
if name is not None and not isinstance(name, str): if name is not None and not isinstance(name, str):
raise TypeError(f"Expect string, got : {name}") raise TypeError(f"Expect string, got : {name}")
self.name = name self.name = name
# 如果types是字符串类型则将其按'|'分割成列表
if isinstance(types, str): if isinstance(types, str):
if self.name is None: if self.name is None:
self.name = types self.name = types
self.types = types.split('|') self.types = types.split('|')
# 如果types是Primitive类型则直接将其放入列表中
elif isinstance(types, Primitive): elif isinstance(types, Primitive):
if self.name is None: if self.name is None:
self.name = types.name self.name = types.name
self.types = [types] self.types = [types]
# 如果 types 是元组或列表,并且其中所有元素都是 Primitive 类型
elif isinstance(types, (tuple, list)) and all(isinstance(tp, Primitive) for tp in types): elif isinstance(types, (tuple, list)) and all(isinstance(tp, Primitive) for tp in types):
# 如果 self.name 为 None则初始化为空字符串并拼接所有 Primitive 的 name
if self.name is None: if self.name is None:
self.name = "" self.name = ""
for prim in types: for prim in types:
self.name += prim.name self.name += prim.name
# 设置 self.types 为传入的 types
self.types = types self.types = types
# 如果 types 不符合预期类型,抛出 TypeError
else: else:
raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}") raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}")
# 调用基类 Prim_ 的初始化方法,传入 self.types 和 self.name
Prim_.__init__(self, self.types, self.name) Prim_.__init__(self, self.types, self.name)
@ -115,16 +124,22 @@ class Call(Call_):
Raises: Raises:
TypeError: raise type error for invalid argument. TypeError: raise type error for invalid argument.
""" """
# 检查 prim_pattern 是否为 Pattern, str 或 Primitive 类型,如果不是则抛出 TypeError
if not isinstance(prim_pattern, (Pattern, str, Primitive)): if not isinstance(prim_pattern, (Pattern, str, Primitive)):
raise TypeError(f"Expect prim_pattern to be Pattern, Primitive or string, got : {prim_pattern}") raise TypeError(f"Expect prim_pattern to be Pattern, Primitive or string, got : {prim_pattern}")
self.prim_pattern = prim_pattern self.prim_pattern = prim_pattern
# 初始化 inputs 列表
self.inputs = [] self.inputs = []
# 如果 inputs 为 None则不做任何操作
if inputs is None: if inputs is None:
pass pass
# 如果 inputs 是 tuple 或 list 并且其中所有元素都是 Pattern 类型,则赋值给 self.inputs
elif isinstance(inputs, (tuple, list)) and all(isinstance(input, Pattern) for input in inputs): elif isinstance(inputs, (tuple, list)) and all(isinstance(input, Pattern) for input in inputs):
self.inputs = inputs self.inputs = inputs
# 如果 inputs 不符合上述条件,则抛出 TypeError
else: else:
raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}") raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}")
# 调用父类 Call_ 的初始化方法,传入 self.prim_pattern 和 self.inputs
Call_.__init__(self, self.prim_pattern, self.inputs) Call_.__init__(self, self.prim_pattern, self.inputs)
@ -145,6 +160,7 @@ class NoneOf(NoneOf_):
TypeError: raise type error for invalid argument. TypeError: raise type error for invalid argument.
""" """
self.patterns = patterns self.patterns = patterns
# 根据 patterns 的类型初始化 NoneOf_ 类
if patterns is None: if patterns is None:
NoneOf_.__init__(self, ()) NoneOf_.__init__(self, ())
elif isinstance(patterns, Pattern): elif isinstance(patterns, Pattern):
@ -154,7 +170,6 @@ class NoneOf(NoneOf_):
else: else:
raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}") raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}")
class NewTensor(NewTensor_): class NewTensor(NewTensor_):
r""" r"""
New Tensor to be used in the target. New Tensor to be used in the target.
@ -167,13 +182,16 @@ class NewTensor(NewTensor_):
Raises: Raises:
TypeError: raise type error for invalid argument. TypeError: raise type error for invalid argument.
""" """
# 初始化输入张量
self.input_tensor = input_tensor self.input_tensor = input_tensor
# 检查输入是否为 Tensor 类型
if isinstance(input_tensor, Tensor): if isinstance(input_tensor, Tensor):
# 如果是 Tensor 类型,则调用 NewTensor_ 的初始化方法
NewTensor_.__init__(self, input_tensor) NewTensor_.__init__(self, input_tensor)
else: else:
# 如果不是 Tensor 类型,则抛出 TypeError 异常
raise TypeError(f"Expect input_tensor to be a Tensor got : {input_tensor}") raise TypeError(f"Expect input_tensor to be a Tensor got : {input_tensor}")
class NewParameter(NewParameter_): class NewParameter(NewParameter_):
r""" r"""
New Parameter to be used in the target. New Parameter to be used in the target.
@ -193,11 +211,14 @@ class NewParameter(NewParameter_):
self.default_tensor = default_tensor self.default_tensor = default_tensor
self.requires_grad = requires_grad self.requires_grad = requires_grad
self.layerwise_parallel = layerwise_parallel self.layerwise_parallel = layerwise_parallel
# 检查参数类型是否符合预期
if isinstance(para_name, str) and isinstance(default_tensor, Tensor) and isinstance(requires_grad, bool) and\ if isinstance(para_name, str) and isinstance(default_tensor, Tensor) and isinstance(requires_grad, bool) and\
isinstance(layerwise_parallel, bool): isinstance(layerwise_parallel, bool):
# 初始化 NewParameter_ 类
NewParameter_.__init__(self, self.para_name, self.default_tensor, self.requires_grad, NewParameter_.__init__(self, self.para_name, self.default_tensor, self.requires_grad,
self.layerwise_parallel) self.layerwise_parallel)
else: else:
# 如果参数类型不符合预期,抛出 TypeError
raise TypeError(f"Expect para_name(str), default_tensor(Tensor), requires_grad(bool), \ raise TypeError(f"Expect para_name(str), default_tensor(Tensor), requires_grad(bool), \
layerwise_parallel(bool) got : {para_name}, {default_tensor}, \ layerwise_parallel(bool) got : {para_name}, {default_tensor}, \
{requires_grad}, {layerwise_parallel}") {requires_grad}, {layerwise_parallel}")

@ -39,6 +39,7 @@ class PyPassManager(PyPassManager_):
TypeError: If argument has invalid type. TypeError: If argument has invalid type.
""" """
def __init__(self, requires_grad=True, run_only_once=False): def __init__(self, requires_grad=True, run_only_once=False):
# 初始化方法,接收两个布尔参数,设置实例的属性并调用父类的初始化方法
if not isinstance(requires_grad, bool): if not isinstance(requires_grad, bool):
raise TypeError(f"Expect bool, got : ({type(requires_grad)}){requires_grad}") raise TypeError(f"Expect bool, got : ({type(requires_grad)}){requires_grad}")
if not isinstance(run_only_once, bool): if not isinstance(run_only_once, bool):
@ -48,17 +49,20 @@ class PyPassManager(PyPassManager_):
PyPassManager_.__init__(self) PyPassManager_.__init__(self)
def register(self, py_pass): def register(self, py_pass):
# 注册一个Python pass检查其是否为函数类型并获取其模式和目标
if not isfunction(py_pass): if not isfunction(py_pass):
raise TypeError(f"Expect function pass, got : ({type(py_pass)}){py_pass}") raise TypeError(f"Expect function pass, got : ({type(py_pass)}){py_pass}")
pattern, target = py_pass() pattern, target = py_pass()
pass_name = py_pass.__name__ pass_name = py_pass.__name__
# 检查模式和目标是否为Pattern类型
if not isinstance(pattern, Pattern): if not isinstance(pattern, Pattern):
raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}") raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}")
if not isinstance(target, Pattern): if not isinstance(target, Pattern):
raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}") raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}")
# 调用父类的register方法注册pass及其相关信息
super().register(pass_name, pattern, target, self.requires_grad, self.run_only_once_) super().register(pass_name, pattern, target, self.requires_grad, self.run_only_once_)
def unregister(self, py_pass): def unregister(self, py_pass):
# 从注册表中移除指定的Python传递对象可以是字符串形式的名称或函数对象
if isinstance(py_pass, str): if isinstance(py_pass, str):
super().unregister(py_pass) super().unregister(py_pass)
return return
@ -68,25 +72,28 @@ class PyPassManager(PyPassManager_):
raise TypeError(f"Expect py_pass to be string or function, got ({type(py_pass)}){py_pass}") raise TypeError(f"Expect py_pass to be string or function, got ({type(py_pass)}){py_pass}")
def __call__(self, py_pass): def __call__(self, py_pass):
# 将Python传递对象注册到注册表中并返回该对象
self.register(py_pass) self.register(py_pass)
return py_pass return py_pass
def gen_new_parameter(self, pattern): def gen_new_parameter(self, pattern):
# 根据给定的模式生成新的参数模式必须是NewParameter类型
if not isinstance(pattern, NewParameter): if not isinstance(pattern, NewParameter):
raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}") raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}")
super().gen_new_parameter(pattern) super().gen_new_parameter(pattern)
def set_renorm(self, should_renorm): def set_renorm(self, should_renorm):
# 设置是否进行重归一化操作,参数必须是布尔值
if not isinstance(should_renorm, bool): if not isinstance(should_renorm, bool):
raise TypeError(f"Expect should_renorm to be a bool, got {should_renorm}") raise TypeError(f"Expect should_renorm to be a bool, got {should_renorm}")
super().set_renorm(should_renorm) super().set_renorm(should_renorm)
def set_reopt(self, do_reopt): def set_reopt(self, do_reopt):
# 设置是否进行重新优化操作,参数必须是布尔值
if not isinstance(do_reopt, bool): if not isinstance(do_reopt, bool):
raise TypeError(f"Expect do_reopt to be a bool, got {do_reopt}") raise TypeError(f"Expect do_reopt to be a bool, got {do_reopt}")
super().set_reopt(do_reopt) super().set_reopt(do_reopt)
def register_pass(requires_grad=True, run_only_once=False): def register_pass(requires_grad=True, run_only_once=False):
""" """
Register python pass to specified pipeline phase which would be used in compilation. Register python pass to specified pipeline phase which would be used in compilation.
@ -165,12 +172,13 @@ def cancel_new_parameter(pattern):
>>> # some compilations >>> # some compilations
>>> cancel_new_parameter(abc) >>> cancel_new_parameter(abc)
""" """
# 检查传入的pattern是否为NewParameter的实例
if not isinstance(pattern, NewParameter): if not isinstance(pattern, NewParameter):
raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}") raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}")
# 创建一个PyPassManager对象
ppm = PyPassManager() ppm = PyPassManager()
# 从PyPassManager中注销指定名称的参数
ppm.unregister(pattern.para_name) ppm.unregister(pattern.para_name)
def set_renorm(should_renorm): def set_renorm(should_renorm):
""" """
Set whether or not to do renormalization after modified graph in python pass(es). Set whether or not to do renormalization after modified graph in python pass(es).

@ -1,47 +1,117 @@
# Copyright 2020-2022 Huawei Technologies Co., Ltd # Copyright 2020-2022 Huawei Technologies Co., Ltd
# # 代码版权声明说明此代码由华为技术有限公司在2020-2022年间开发
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# 说明此代码使用Apache License 2.0版本的许可证
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# 说明除非遵守许可证,否则不得使用此文件
# You may obtain a copy of the License at # You may obtain a copy of the License at
# # 提供许可证的获取地址
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# # 许可证的具体地址
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# 除非适用法律要求或书面同意
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# 许可证在“现状”基础上进行分发
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 不附带任何形式的明示或暗示的担保或条件
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# 请参阅许可证了解特定的权限和
# limitations under the License. # limitations under the License.
# 限制条件
# ============================================================================ # ============================================================================
# 标准分割线,通常用于分隔许可证部分与代码部分
"""cell""" """cell"""
# 文档字符串模块的名称为cell
import gc import gc
# 导入垃圾回收模块,用于管理内存
import inspect import inspect
# 导入inspect模块用于获取活对象的信息
import os import os
# 导入os模块用于与操作系统进行交互
import time import time
# 导入time模块用于处理时间相关操作
from collections import OrderedDict from collections import OrderedDict
# 从collections模块导入OrderedDict类用于创建有序字典
from types import FunctionType, MethodType from types import FunctionType, MethodType
# 从types模块导入FunctionType和MethodType类用于类型检查
import numpy import numpy
# 导入numpy模块用于科学计算
from mindspore._checkparam import args_type_check from mindspore._checkparam import args_type_check
# 从mindspore._checkparam模块导入args_type_check函数用于检查函数参数的类型
from mindspore import log as logger from mindspore import log as logger
# 从mindspore模块导入log模块并命名为logger用于日志记录
from mindspore.common.parameter import PARAMETER_NAME_DEFAULT from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
# 从mindspore.common.parameter模块导入PARAMETER_NAME_DEFAULT常量用于默认参数名称
from mindspore.common.hook_handle import HookHandle from mindspore.common.hook_handle import HookHandle
# 从mindspore.common.hook_handle模块导入HookHandle类用于管理钩子处理
from mindspore.context import ParallelMode from mindspore.context import ParallelMode
# 从mindspore.context模块导入ParallelMode类用于并行模式配置
from mindspore.ops.composite import Shard from mindspore.ops.composite import Shard
# 从mindspore.ops.composite模块导入Shard类用于分片操作
from .. import context from .. import context
# 导入相对路径的context模块用于上下文配置
from .._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType from .._c_expression import init_pipeline, update_func_graph_hyper_params, Cell_, FuncGraph, MixedPrecisionType
# 从相对路径的_c_expression模块导入多个函数和类用于初始化管道、更新函数图超参数、Cell的基础类、函数图类、混合精度类型
from .._checkparam import Validator from .._checkparam import Validator
# 从相对路径的_checkparam模块导入Validator类用于参数验证
from ..common import dtype as mstype from ..common import dtype as mstype
# 从相对路径的common模块导入dtype模块并重命名为mstype用于数据类型定义
from ..common.api import _cell_graph_executor, _pynative_executor, _check_all_tensor, cells_compile_cache from ..common.api import _cell_graph_executor, _pynative_executor, _check_all_tensor, cells_compile_cache
# 从相对路径的common.api模块导入多个函数和类用于单元图执行器、原生模式执行器、检查所有张量、编译缓存
from ..common.parameter import Parameter, ParameterTuple from ..common.parameter import Parameter, ParameterTuple
# 从相对路径的common.parameter模块导入Parameter类和ParameterTuple类用于参数和参数元组
from ..common.variable import Variable from ..common.variable import Variable
# 从相对路径的common.variable模块导入Variable类用于变量表示
from ..common.tensor import Tensor, CSRTensor, COOTensor from ..common.tensor import Tensor, CSRTensor, COOTensor
# 从相对路径的common.tensor模块导入Tensor类、CSRTensor类和COOTensor类用于张量表示
from ..ops.operations import Cast from ..ops.operations import Cast
# 从相对路径的ops.operations模块导入Cast类用于类型转换操作
from ..ops.primitive import Primitive from ..ops.primitive import Primitive
# 从相对路径的ops.primitive模块导入Primitive类用于基础操作
from ..ops.operations import _inner_ops as inner from ..ops.operations import _inner_ops as inner
from ..parallel._tensor import _load_tensor_by_layout # 从相对路径的ops.operations模块导入_inner_ops并重命名为inner用于内部操作
from ..parallel._tensor import _load_tensor_by_layout
# 从相对路径的parallel._tensor模块导入_load_tensor_by_layout函数用于按布局加载张量
class Cell(Cell_): class Cell(Cell_):
# 定义Cell类继承自Cell_类这是MindSpore中神经网络的基本构建单元
""" """
The basic building block of neural networks in MindSpore. The model or neural network layer should inherit this The basic building block of neural networks in MindSpore. The model or neural network layer should inherit this
base class. base class.
@ -81,6 +151,8 @@ class Cell(Cell_):
... # the parameter's name will be 'net.weight'. ... # the parameter's name will be 'net.weight'.
[Parameter (name=weight, shape=(240, 120, 4, 4), dtype=Float32, requires_grad=True)] [Parameter (name=weight, shape=(240, 120, 4, 4), dtype=Float32, requires_grad=True)]
""" """
# 类文档字符串解释Cell类的作用、继承关系、参数、支持平台及示例
class _CellGuard: class _CellGuard:
"""Detecting whether the cell is a top-level cell with the 'with statement'.""" """Detecting whether the cell is a top-level cell with the 'with statement'."""

@ -22,61 +22,68 @@ from ...common.api import ms_function
class _FirstGrad(Cell): class _FirstGrad(Cell):
# 计算第一个梯度的类
def __init__(self, fn): def __init__(self, fn):
super(_FirstGrad, self).__init__() super(_FirstGrad, self).__init__()
self.first_grad_op = C.GradOperation(sens_param=True, get_all=True) self.first_grad_op = C.GradOperation(sens_param=True, get_all=True)
self.fn = fn self.fn = fn
def construct(self, u, first_grad_input): def construct(self, u, first_grad_input):
# 构造方法,用于计算梯度
return self.first_grad_op(self.fn)(*first_grad_input, u) return self.first_grad_op(self.fn)(*first_grad_input, u)
class _JvpFirstGrad(Cell): class _JvpFirstGrad(Cell):
# 计算Jacobian-Vector-Product的第一个梯度的类
def __init__(self): def __init__(self):
super(_JvpFirstGrad, self).__init__() super(_JvpFirstGrad, self).__init__()
self.first_grad_op = C.GradOperation(sens_param=True, get_all=True) self.first_grad_op = C.GradOperation(sens_param=True, get_all=True)
def construct(self, u, fn, first_grad_input): def construct(self, u, fn, first_grad_input):
# 构造方法用于计算JVP的第一个梯度
return self.first_grad_op(fn)(*first_grad_input, u) return self.first_grad_op(fn)(*first_grad_input, u)
class _FirstGradSingleValue(Cell): class _FirstGradSingleValue(Cell):
# 计算单值梯度的类
def __init__(self, fn): def __init__(self, fn):
super(_FirstGradSingleValue, self).__init__() super(_FirstGradSingleValue, self).__init__()
self.first_grad_single_value_op = C.GradOperation(sens_param=True) self.first_grad_single_value_op = C.GradOperation(sens_param=True)
self.fn = fn self.fn = fn
def construct(self, u, first_grad_single_value_input): def construct(self, u, first_grad_single_value_input):
# 构造方法,用于计算单值梯度
return self.first_grad_single_value_op(self.fn)(*first_grad_single_value_input, u) return self.first_grad_single_value_op(self.fn)(*first_grad_single_value_input, u)
class _JvpFirstGradSingleValue(Cell): class _JvpFirstGradSingleValue(Cell):
# 计算Jacobian-Vector-Product的单值梯度的类
def __init__(self): def __init__(self):
super(_JvpFirstGradSingleValue, self).__init__() super(_JvpFirstGradSingleValue, self).__init__()
self.first_grad_single_value_op = C.GradOperation(sens_param=True) self.first_grad_single_value_op = C.GradOperation(sens_param=True)
def construct(self, u, fn, first_grad_single_value_input): def construct(self, u, fn, first_grad_single_value_input):
# 构造方法用于计算JVP的单值梯度
return self.first_grad_single_value_op(fn)(*first_grad_single_value_input, u) return self.first_grad_single_value_op(fn)(*first_grad_single_value_input, u)
class Jvp(Cell): class Jvp(Cell):
""" """
Compute the jacobian-vector-product of the given fn. Jvp is equivalent to forward mode autodiff. 计算给定fn的雅可比向量积Jvp等同于前向模式自动微分
Args: Args:
fn (Cell): The fn that takes Tensor inputs and returns a tuple of Tensors or a Tensor. fn (Cell): 接受Tensor输入并返回Tensor元组或Tensor的fn
Inputs: Inputs:
- **inputs** (Tensors) - The inputs to `fn`. - **inputs** (Tensors) - `fn`的输入
- **v** (Tensors or Tuple of Tensors) - The vector for which the Jacobian vector product is computed. - **v** (Tensors Tensor元组) - 用于计算雅可比向量积的向量
Must have the same size as the input of `fn`. 必须与`fn`的输入大小相同
Outputs: Outputs:
A tuple with 2 Tensors or Tuple of Tensors: 包含2个Tensors或Tensor元组的元组
- **net_output** (Tensors or Tuple of Tensors) - The output of `fn(inputs)`. - **net_output** (Tensors Tensor元组) - `fn(inputs)`的输出
- **jvp** (Tensors or Tuple of Tensors) - The result of the jacobian vector product. - **jvp** (Tensors Tensor元组) - 雅可比向量积的结果
Supported Platforms: Supported Platforms:
``Ascend`` ``GPU`` ``CPU`` ``Ascend`` ``GPU`` ``CPU``
@ -113,6 +120,7 @@ class Jvp(Cell):
@ms_function @ms_function
def construct(self, *args): def construct(self, *args):
# 构造方法用于计算JVP
jvp_input = args[0:-1] jvp_input = args[0:-1]
v = args[-1] v = args[-1]
output = self.fn(*jvp_input) output = self.fn(*jvp_input)
@ -135,8 +143,8 @@ class Jvp(Cell):
class _JvpInner(Cell): class _JvpInner(Cell):
""" """
Compute the jacobian-vector-product of the given network. Jvp is equivalent to forward mode autodiff. 计算给定网络的雅可比向量积Jvp等同于前向模式自动微分
This class implements the inner process of function jvp. 该类实现了JVP的内部过程
""" """
def __init__(self): def __init__(self):
super(_JvpInner, self).__init__() super(_JvpInner, self).__init__()
@ -152,6 +160,7 @@ class _JvpInner(Cell):
self.tuple_len = Primitive("tuple_len") self.tuple_len = Primitive("tuple_len")
def construct(self, *args): def construct(self, *args):
# 构造方法用于计算内部JVP
fn = args[0] fn = args[0]
v = args[1] v = args[1]
jvp_input = args[2:] jvp_input = args[2:]
@ -175,22 +184,21 @@ class _JvpInner(Cell):
class Vjp(Cell): class Vjp(Cell):
""" """
Computes the dot product between a vector `v` and the Jacobian of the given fn at the point 计算给定向量`v`与给定fn在输入点处的雅可比的点积
given by the inputs.
Args: Args:
fn (Cell): The fn that takes Tensor inputs and returns a tuple of Tensors or a Tensor. fn (Cell): 接受Tensor输入并返回Tensor元组或Tensor的fn
Inputs: Inputs:
- **inputs** (Tensors) - The inputs to `fn`. Must be a tuple or a list. - **inputs** (Tensors) - `fn`的输入必须是元组或列表
- **v** (Tensors or Tuple of Tensors) - The vector for which the vector Jacobian product is computed. - **v** (Tensors Tensor元组) - 用于计算向量雅可比积的向量
Must have the same size as the output of `fn`. 必须与`fn`的输出大小相同
Outputs: Outputs:
A tuple with 2 Tensors or Tuple of Tensors: 包含2个Tensors或Tensor元组的元组
- **net_output** (Tensors or Tuple of Tensors) - The output of `fn(inputs)`. - **net_output** (Tensors Tensor元组) - `fn(inputs)`的输出
- **vjp** (Tensors or Tuple of Tensors) - The result of the dot product. - **vjp** (Tensors Tensor元组) - 点积的结果
Supported Platforms: Supported Platforms:
``Ascend`` ``GPU`` ``CPU`` ``Ascend`` ``GPU`` ``CPU``
@ -226,6 +234,7 @@ class Vjp(Cell):
@ms_function @ms_function
def construct(self, *args): def construct(self, *args):
# 构造方法用于计算VJP
front_input = args[0:-1] front_input = args[0:-1]
output = self.fn(*front_input) output = self.fn(*front_input)
if self.tuple_len(front_input) == 1: if self.tuple_len(front_input) == 1:
@ -237,8 +246,8 @@ class Vjp(Cell):
class _VjpInner(Cell): class _VjpInner(Cell):
""" """
Computes the dot product between a vector `v` and the Jacobian of the given network at the point 计算给定向量`v`与给定网络在输入点处的雅可比的点积
given by the inputs. This class implements the inner process of function vjp. 该类实现了VJP的内部过程
""" """
def __init__(self): def __init__(self):
@ -248,6 +257,7 @@ class _VjpInner(Cell):
self.tuple_len = Primitive("tuple_len") self.tuple_len = Primitive("tuple_len")
def construct(self, *args): def construct(self, *args):
# 构造方法用于计算内部VJP
fn = args[0] fn = args[0]
front_input = args[1:-1] front_input = args[1:-1]
input_with_v = args[1:] input_with_v = args[1:]

@ -48,24 +48,23 @@ class LossBase(Cell):
""" """
def __init__(self, reduction='mean'): def __init__(self, reduction='mean'):
"""Initialize Loss.""" """Initialize Loss.""" # 初始化LossBase类接收一个参数reduction默认值为'mean'
super(LossBase, self).__init__() super(LossBase, self).__init__() # 调用父类Cell的初始化方法
if reduction not in ('mean', 'sum', 'none'): if reduction not in ('mean', 'sum', 'none'): # 检查reduction参数是否为'mean', 'sum', 'none'中的一个
raise ValueError(f"For '{self.cls_name}', the 'reduction' should be in ['mean', 'sum', 'none'], " raise ValueError(f"For '{self.cls_name}', the 'reduction' should be in ['mean', 'sum', 'none'], " # 如果参数不在允许的范围内抛出ValueError
f"but got {reduction}.") f"but got {reduction}.")
self.average = True # 设置average属性为True默认进行平均
self.average = True self.reduce = True # 设置reduce属性为True默认进行降维
self.reduce = True if reduction == 'sum': # 如果reduction参数为'sum'
if reduction == 'sum': self.average = False # 设置average属性为False不进行平均
self.average = False if reduction == 'none': # 如果reduction参数为'none'
if reduction == 'none': self.reduce = False # 设置reduce属性为False不进行降维
self.reduce = False
self.reduce_mean = P.ReduceMean() # 定义reduce_mean操作用于计算平均损失
self.reduce_mean = P.ReduceMean() self.reduce_sum = P.ReduceSum() # 定义reduce_sum操作用于计算总损失
self.reduce_sum = P.ReduceSum() self.mul = P.Mul() # 定义mul操作用于权重乘法
self.mul = P.Mul() self.cast = P.Cast() # 定义cast操作用于数据类型转换
self.cast = P.Cast()
def get_axis(self, x): def get_axis(self, x):
""" """
@ -98,10 +97,10 @@ class LossBase(Cell):
>>> print(output) >>> print(output)
(0, 1) (0, 1)
""" """
shape = F.shape(x) shape = F.shape(x) # 获取输入张量x的形状
length = F.tuple_len(shape) length = F.tuple_len(shape) # 获取形状的长度(即维度数量)
perm = F.make_range(0, length) perm = F.make_range(0, length) # 生成一个从0到length-1的元组表示所有轴
return perm return perm # 返回这个元组
def get_loss(self, x, weights=1.0): def get_loss(self, x, weights=1.0):
""" """
@ -141,20 +140,19 @@ class LossBase(Cell):
>>> print(output) >>> print(output)
0.11111111 0.11111111
""" """
input_dtype = x.dtype input_dtype = x.dtype # 获取输入张量x的数据类型
x = self.cast(x, mstype.float32) x = self.cast(x, mstype.float32) # 将输入张量x的数据类型转换为float32
weights = self.cast(weights, mstype.float32) weights = self.cast(weights, mstype.float32) # 将权重weights的数据类型转换为float32
x = self.mul(weights, x) x = self.mul(weights, x) # 将权重weights与输入张量x相乘
if self.reduce and self.average: if self.reduce and self.average: # 如果需要降维且进行平均
x = self.reduce_mean(x, self.get_axis(x)) x = self.reduce_mean(x, self.get_axis(x)) # 计算平均损失
if self.reduce and not self.average: if self.reduce and not self.average: # 如果需要降维但不进行平均
x = self.reduce_sum(x, self.get_axis(x)) x = self.reduce_sum(x, self.get_axis(x)) # 计算总损失
x = self.cast(x, input_dtype) x = self.cast(x, input_dtype) # 将损失x的数据类型转换回输入张量x的原始数据类型
return x return x # 返回计算得到的损失
def construct(self, logits, labels): def construct(self, logits, labels):
raise NotImplementedError raise NotImplementedError # 这是一个抽象方法,需要在子类中实现
class _Loss(LossBase): class _Loss(LossBase):
""" """
@ -162,23 +160,21 @@ class _Loss(LossBase):
""" """
def __init__(self, reduction='mean'): def __init__(self, reduction='mean'):
"""Initialize _Loss.""" """Initialize _Loss.""" # 初始化_Loss类接收一个参数reduction默认值为'mean'
log.warning("'_Loss' is deprecated from version 1.3 and " log.warning("'_Loss' is deprecated from version 1.3 and " # 输出警告信息提示_Loss类已过时
"will be removed in a future version, use 'LossBase' instead.") "will be removed in a future version, use 'LossBase' instead.")
super(_Loss, self).__init__(reduction) super(_Loss, self).__init__(reduction) # 调用父类LossBase的初始化方法
def construct(self, logits, labels): def construct(self, logits, labels):
raise NotImplementedError raise NotImplementedError # 这是一个抽象方法,需要在子类中实现
@constexpr @constexpr
def _check_is_tensor(param_name, input_data, cls_name): def _check_is_tensor(param_name, input_data, cls_name):
"""Internal function, used to check whether the input data is Tensor.""" """Internal function, used to check whether the input data is Tensor.""" # 定义一个内部函数用于检查输入数据是否为Tensor
if input_data is not None and not isinstance(F.typeof(input_data), mstype.tensor_type): if input_data is not None and not isinstance(F.typeof(input_data), mstype.tensor_type): # 如果输入数据不为None且类型不是Tensor
raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{mstype.tensor_type}', " raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{mstype.tensor_type}', " # 抛出TypeError
f"but got '{F.typeof(input_data)}'") f"but got '{F.typeof(input_data)}'")
class L1Loss(LossBase): class L1Loss(LossBase):
r""" r"""
L1Loss is used to calculate the mean absolute error between the predicted value and the target value. L1Loss is used to calculate the mean absolute error between the predicted value and the target value.
@ -238,16 +234,15 @@ class L1Loss(LossBase):
""" """
def __init__(self, reduction='mean'): def __init__(self, reduction='mean'):
"""Initialize L1Loss.""" """Initialize L1Loss.""" # 初始化L1Loss类接收一个参数reduction默认值为'mean'
super(L1Loss, self).__init__(reduction) super(L1Loss, self).__init__(reduction) # 调用父类LossBase的初始化方法
self.abs = P.Abs() self.abs = P.Abs() # 定义abs操作用于计算绝对值
def construct(self, logits, labels): def construct(self, logits, labels):
_check_is_tensor('logits', logits, self.cls_name) _check_is_tensor('logits', logits, self.cls_name) # 检查logits是否为Tensor
_check_is_tensor('labels', labels, self.cls_name) _check_is_tensor('labels', labels, self.cls_name) # 检查labels是否为Tensor
x = self.abs(logits - labels) x = self.abs(logits - labels) # 计算logits与labels的差的绝对值
return self.get_loss(x) return self.get_loss(x) # 使用self.get_loss方法计算加权损失并返回
class MSELoss(LossBase): class MSELoss(LossBase):
r""" r"""
@ -308,10 +303,10 @@ class MSELoss(LossBase):
""" """
def construct(self, logits, labels): def construct(self, logits, labels):
_check_is_tensor('logits', logits, self.cls_name) _check_is_tensor('logits', logits, self.cls_name) # 检查logits是否为Tensor
_check_is_tensor('labels', labels, self.cls_name) _check_is_tensor('labels', labels, self.cls_name) # 检查labels是否为Tensor
x = F.square(logits - labels) x = F.square(logits - labels) # 计算logits与labels的差的平方
return self.get_loss(x) return self.get_loss(x) # 使用self.get_loss方法计算加权损失并返回
class RMSELoss(LossBase): class RMSELoss(LossBase):
@ -356,15 +351,14 @@ class RMSELoss(LossBase):
""" """
def __init__(self): def __init__(self):
"""Initialize RMSELoss.""" """Initialize RMSELoss.""" # 初始化RMSELoss类
super(RMSELoss, self).__init__() super(RMSELoss, self).__init__() # 调用父类LossBase的初始化方法
self.MSELoss = MSELoss() self.MSELoss = MSELoss() # 初始化MSELoss对象用于计算均方误差
def construct(self, logits, label): def construct(self, logits, label):
rmse_loss = F.sqrt(self.MSELoss(logits, label)) rmse_loss = F.sqrt(self.MSELoss(logits, label)) # 计算均方误差损失然后取平方根得到RMSE损失
return rmse_loss
return rmse_loss # 返回计算得到的RMSE损失
class MAELoss(LossBase): class MAELoss(LossBase):
r""" r"""
@ -426,16 +420,15 @@ class MAELoss(LossBase):
""" """
def __init__(self, reduction='mean'): def __init__(self, reduction='mean'):
"""Initialize MAELoss.""" """Initialize MAELoss.""" # 初始化MAELoss类接收一个参数reduction默认值为'mean'
super(MAELoss, self).__init__(reduction) super(MAELoss, self).__init__(reduction) # 调用父类LossBase的初始化方法
self.abs = P.Abs() self.abs = P.Abs() # 定义abs操作用于计算绝对值
def construct(self, logits, label): def construct(self, logits, label):
_check_is_tensor('logits', logits, self.cls_name) _check_is_tensor('logits', logits, self.cls_name) # 检查logits是否为Tensor
_check_is_tensor('labels', label, self.cls_name) _check_is_tensor('labels', label, self.cls_name) # 检查labels是否为Tensor
x = self.abs(logits - label) x = self.abs(logits - label) # 计算logits与labels的差的绝对值
return self.get_loss(x) return self.get_loss(x) # 使用self.get_loss方法计算加权损失并返回
class SmoothL1Loss(LossBase): class SmoothL1Loss(LossBase):
r""" r"""
@ -491,16 +484,15 @@ class SmoothL1Loss(LossBase):
""" """
def __init__(self, beta=1.0): def __init__(self, beta=1.0):
"""Initialize SmoothL1Loss.""" """Initialize SmoothL1Loss.""" # 初始化SmoothL1Loss类接收一个参数beta默认值为1.0
super(SmoothL1Loss, self).__init__() super(SmoothL1Loss, self).__init__() # 调用父类LossBase的初始化方法
self.beta = beta self.beta = beta # 设置beta属性表示平滑阈值
self.smooth_l1_loss = P.SmoothL1Loss(self.beta) self.smooth_l1_loss = P.SmoothL1Loss(self.beta) # 定义smooth_l1_loss操作用于计算平滑L1损失
def construct(self, logits, labels): def construct(self, logits, labels):
_check_is_tensor('logits', logits, self.cls_name) _check_is_tensor('logits', logits, self.cls_name) # 检查logits是否为Tensor
_check_is_tensor('labels', labels, self.cls_name) _check_is_tensor('labels', labels, self.cls_name) # 检查labels是否为Tensor
return self.smooth_l1_loss(logits, labels) return self.smooth_l1_loss(logits, labels) # 使用self.smooth_l1_loss计算平滑L1损失并返回
class SoftMarginLoss(LossBase): class SoftMarginLoss(LossBase):
r""" r"""
@ -545,12 +537,11 @@ class SoftMarginLoss(LossBase):
""" """
def __init__(self, reduction='mean'): def __init__(self, reduction='mean'):
super(SoftMarginLoss, self).__init__() super(SoftMarginLoss, self).__init__() # 调用父类LossBase的初始化方法
self.soft_margin_loss = P.SoftMarginLoss(reduction) self.soft_margin_loss = P.SoftMarginLoss(reduction) # 定义soft_margin_loss操作用于计算SoftMargin损失
def construct(self, logits, labels): def construct(self, logits, labels):
return self.soft_margin_loss(logits, labels) return self.soft_margin_loss(logits, labels) # 使用self.soft_margin_loss计算SoftMargin损失并返回
class SoftmaxCrossEntropyWithLogits(LossBase): class SoftmaxCrossEntropyWithLogits(LossBase):
r""" r"""
@ -619,27 +610,28 @@ class SoftmaxCrossEntropyWithLogits(LossBase):
def __init__(self, def __init__(self,
sparse=False, sparse=False,
reduction='none'): reduction='none'):
"""Initialize SoftmaxCrossEntropyWithLogits.""" """Initialize SoftmaxCrossEntropyWithLogits.""" # 初始化SoftmaxCrossEntropyWithLogits类接收sparse和reduction两个参数
super(SoftmaxCrossEntropyWithLogits, self).__init__(reduction) super(SoftmaxCrossEntropyWithLogits, self).__init__(reduction) # 调用父类LossBase的初始化方法传入reduction参数
self.sparse = validator.check_bool(sparse, "sparse", self.cls_name) self.sparse = validator.check_bool(sparse, "sparse", self.cls_name) # 检查sparse参数是否为布尔值如果是则赋值给self.sparse
self.reduction = reduction self.reduction = reduction # 设置reduction属性表示减少类型
self.softmax_cross_entropy = P.SoftmaxCrossEntropyWithLogits() self.softmax_cross_entropy = P.SoftmaxCrossEntropyWithLogits() # 定义softmax_cross_entropy操作用于计算Softmax交叉熵损失
self.one_hot = P.OneHot() self.one_hot = P.OneHot() # 定义one_hot操作用于将标签转换为OneHot编码
self.on_value = Tensor(1.0, mstype.float32) self.on_value = Tensor(1.0, mstype.float32) # 定义on_value属性表示OneHot编码中正类的值
self.off_value = Tensor(0., mstype.float32) self.off_value = Tensor(0., mstype.float32) # 定义off_value属性表示OneHot编码中负类的值
self.is_cpugpu = context.get_context('device_target') in ["CPU", "GPU"] self.is_cpugpu = context.get_context('device_target') in ["CPU", "GPU"] # 定义is_cpugpu属性表示是否在CPU或GPU上运行
self.sparse_softmax_cross_entropy = P.SparseSoftmaxCrossEntropyWithLogits() self.sparse_softmax_cross_entropy = P.SparseSoftmaxCrossEntropyWithLogits() # 定义sparse_softmax_cross_entropy操作用于计算稀疏Softmax交叉熵损失
def construct(self, logits, labels): def construct(self, logits, labels):
_check_is_tensor('logits', logits, self.cls_name) _check_is_tensor('logits', logits, self.cls_name) # 检查logits是否为Tensor
_check_is_tensor('labels', labels, self.cls_name) _check_is_tensor('labels', labels, self.cls_name) # 检查labels是否为Tensor
if self.sparse: if self.sparse: # 如果使用稀疏标签格式
if self.reduction == 'mean': if self.reduction == 'mean': # 如果reduction为'mean'
x = self.sparse_softmax_cross_entropy(logits, labels) x = self.sparse_softmax_cross_entropy(logits, labels) # 使用稀疏Softmax交叉熵损失计算损失
return x return x # 返回计算得到的损失
labels = self.one_hot(labels, F.shape(logits)[-1], self.on_value, self.off_value) labels = self.one_hot(labels, F.shape(logits)[-1], self.on_value, self.off_value) # 将labels转换为OneHot编码
x = self.softmax_cross_entropy(logits, labels)[0] x = self.softmax_cross_entropy(logits, labels)[0] # 计算Softmax交叉熵损失取第一个返回值
return self.get_loss(x) return self.get_loss(x) # 使用self.get_loss方法计算加权损失并返回
@constexpr @constexpr

@ -85,58 +85,84 @@ def array(obj, dtype=None, copy=True, ndmin=0):
>>> print(np.array([1,2,3])) >>> print(np.array([1,2,3]))
[1 2 3] [1 2 3]
""" """
if dtype is not None: if dtype is not None: # 如果用户指定了数据类型则检查并转换为mindspore的数据类型
dtype = _check_dtype(dtype) dtype = _check_dtype(dtype)
res = asarray(obj, dtype) res = asarray(obj, dtype) # 将输入对象转换为tensor
if ndmin > res.ndim: if ndmin > res.ndim: # 如果用户指定的最小维度大于转换后的tensor维度则在tensor的前面添加维度
if res.size == 0: if res.size == 0: # 如果tensor为空抛出异常
_raise_value_error("Empty tensor cannot be expanded beyond the current dimension.") _raise_value_error("Empty tensor cannot be expanded beyond the current dimension.")
res = _expand(res, ndmin) res = _expand(res, ndmin) # 扩展tensor的维度
if copy and isinstance(obj, Tensor): if copy and isinstance(obj, Tensor): # 如果copy为True且输入对象已经是tensor则创建其副本
res = copy_(res) res = copy_(res)
elif dtype is not None and dtype != res.dtype: elif dtype is not None and dtype != res.dtype: # 如果用户指定了数据类型且与转换后的tensor数据类型不同则转换数据类型
res = res.astype(dtype) res = res.astype(dtype)
return res return res # 返回最终生成的tensor
@constexpr @constexpr
def asarray_const(a, dtype=None): def asarray_const(a, dtype=None):
# 标记此函数为constexpr意味着它是一个编译时常量函数
"""Converts the input to tensor. Note here `a` cannot be tensor itself.""" """Converts the input to tensor. Note here `a` cannot be tensor itself."""
# 文档字符串解释函数作用将输入转换为张量注意这里的a不能是张量本身
_check_input_for_asarray(a) _check_input_for_asarray(a)
# 检查输入a是否符合asarray函数的输入要求
if dtype is not None: if dtype is not None:
# 如果dtype不为None
dtype = _check_dtype(dtype) dtype = _check_dtype(dtype)
# 检查并确认dtype是一个有效的数据类型
if isinstance(a, (float, int, bool)) and dtype is None: if isinstance(a, (float, int, bool)) and dtype is None:
# 如果a是float、int或bool类型并且dtype未指定
dtype = _get_dtype_from_scalar(a) dtype = _get_dtype_from_scalar(a)
# 从标量a中获取数据类型并赋值给dtype
if isinstance(a, (list, tuple)): if isinstance(a, (list, tuple)):
# 如果a是list或tuple类型
# Convert all tuple/nested tuples to lists # Convert all tuple/nested tuples to lists
a = _deep_list(a) a = _deep_list(a)
# 将所有tuple及其嵌套的tuple转换为list
# Convert all tensor sub-elements to numpy arrays # Convert all tensor sub-elements to numpy arrays
a = _deep_tensor_to_nparray(a) a = _deep_tensor_to_nparray(a)
# 将所有tensor子元素转换为numpy数组
a = onp.asarray(a) a = onp.asarray(a)
# 使用numpy的asarray函数将a转换为numpy数组
if a.dtype is onp.dtype('object'): if a.dtype is onp.dtype('object'):
# 如果转换后的numpy数组的数据类型是object
raise ValueError('Input array must have the same size across all dimensions.') raise ValueError('Input array must have the same size across all dimensions.')
# 抛出ValueError表示输入数组在所有维度上必须具有相同的大小
# If dtype is not specified, we keep consistent with numpy decision # If dtype is not specified, we keep consistent with numpy decision
# only exceptions are: we use int/float32 # only exceptions are: we use int/float32
if dtype is None: if dtype is None:
# 如果dtype未指定
dtype = mstype.pytype_to_dtype(a.dtype) dtype = mstype.pytype_to_dtype(a.dtype)
# 将numpy数组的数据类型转换为mindspore的dtype
if dtype == mstype.float64: if dtype == mstype.float64:
# 如果dtype是float64
dtype = mstype.float32 dtype = mstype.float32
# 将dtype改为float32
elif dtype == mstype.int64: elif dtype == mstype.int64:
# 如果dtype是int64
dtype = mstype.int32 dtype = mstype.int32
# 将dtype改为int32
if isinstance(a, onp.ndarray) and dtype is None: if isinstance(a, onp.ndarray) and dtype is None:
# 如果a是numpy数组并且dtype未指定
if a.dtype is onp.dtype('object'): if a.dtype is onp.dtype('object'):
# 如果numpy数组的数据类型是object
raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.") raise TypeError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.")
# 抛出TypeError表示输入数据包含不支持的元素
dtype = mstype.pytype_to_dtype(a.dtype) dtype = mstype.pytype_to_dtype(a.dtype)
# 将numpy数组的数据类型转换为mindspore的dtype
a = Tensor.from_numpy(a) a = Tensor.from_numpy(a)
# 将numpy数组转换为mindspore的Tensor
return Tensor(a, dtype=dtype) return Tensor(a, dtype=dtype)
# 返回一个具有指定dtype的Tensor
def asarray(a, dtype=None): def asarray(a, dtype=None):
@ -168,29 +194,46 @@ def asarray(a, dtype=None):
[1 2 3] [1 2 3]
""" """
if dtype is not None: if dtype is not None:
# 如果dtype不为None
dtype = _check_dtype(dtype) dtype = _check_dtype(dtype)
# 检查并确认dtype是一个有效的数据类型
if isinstance(a, Tensor): if isinstance(a, Tensor):
# 如果a是Tensor类型
if dtype is None or dtype == a.dtype: if dtype is None or dtype == a.dtype:
# 如果dtype未指定或指定的数据类型与a的数据类型相同
return a return a
# 直接返回a
return a.astype(dtype) return a.astype(dtype)
# 如果指定的数据类型与a的数据类型不同将a的数据类型转换为指定的dtype并返回
return asarray_const(a, dtype) return asarray_const(a, dtype)
# 如果a不是Tensor类型调用asarray_const函数将其转换为Tensor并返回
@constexpr @constexpr
def asfarray_const(a, dtype=mstype.float32): def asfarray_const(a, dtype=mstype.float32):
"""Converts the input to tensor. Note here `a` cannot be tensor itself.""" """Converts the input to tensor. Note here `a` cannot be tensor itself."""
# 文档字符串解释函数作用将输入转换为张量注意这里的a不能是张量本身
_check_input_for_asarray(a) _check_input_for_asarray(a)
# 检查输入a是否符合asarray函数的输入要求
if isinstance(a, (list, tuple)): if isinstance(a, (list, tuple)):
# 如果a是list或tuple类型
# Convert all tuple/nested tuples to lists # Convert all tuple/nested tuples to lists
a = _deep_list(a) a = _deep_list(a)
# 将所有tuple及其嵌套的tuple转换为list
# Convert all tensor sub-elements to numpy arrays # Convert all tensor sub-elements to numpy arrays
a = _deep_tensor_to_nparray(a) a = _deep_tensor_to_nparray(a)
# 将所有tensor子元素转换为numpy数组
a = onp.asarray(a) a = onp.asarray(a)
# 使用numpy的asarray函数将a转换为numpy数组
if a.dtype is onp.dtype('object'): if a.dtype is onp.dtype('object'):
# 如果转换后的numpy数组的数据类型是object
raise ValueError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.") raise ValueError(f"For Tensor conversion, the input_data is {a} that contains unsupported element.")
# 抛出ValueError表示输入数组在所有维度上必须具有相同的大小
a = Tensor.from_numpy(a) a = Tensor.from_numpy(a)
# 将numpy数组转换为mindspore的Tensor
return Tensor(a, dtype) return Tensor(a, dtype)
# 返回一个具有指定dtype的Tensor
def asfarray(a, dtype=mstype.float32): def asfarray(a, dtype=mstype.float32):
@ -206,7 +249,6 @@ def asfarray(a, dtype=mstype.float32):
be in format of np.int32, or \'int32\'. If dtype is :class:`None`, the data type be in format of np.int32, or \'int32\'. If dtype is :class:`None`, the data type
of the new tensor will be inferred from `a`. Default is :class:`mindspore.float32`. of the new tensor will be inferred from `a`. Default is :class:`mindspore.float32`.
Returns: Returns:
Tensor, generated tensor with the specified float dtype. Tensor, generated tensor with the specified float dtype.
@ -223,16 +265,24 @@ def asfarray(a, dtype=mstype.float32):
[1. 2. 3.] [1. 2. 3.]
""" """
if dtype is None: if dtype is None:
# 如果dtype未指定
return asarray(a) return asarray(a)
# 调用asarray函数将a转换为Tensor并返回
dtype = _check_dtype(dtype) dtype = _check_dtype(dtype)
# 检查并确认dtype是一个有效的数据类型
if dtype not in (mstype.float16, mstype.float32, mstype.float64): if dtype not in (mstype.float16, mstype.float32, mstype.float64):
# 如果dtype不是float16、float32或float64
dtype = mstype.float32 dtype = mstype.float32
# 将dtype改为float32
if isinstance(a, Tensor): if isinstance(a, Tensor):
# 如果a是Tensor类型
return a.astype(dtype) return a.astype(dtype)
# 将a的数据类型转换为指定的dtype并返回
return asfarray_const(a, dtype) return asfarray_const(a, dtype)
# 如果a不是Tensor类型调用asfarray_const函数将其转换为Tensor并返回
def copy_(a): def copy_(a):
@ -261,7 +311,9 @@ def copy_(a):
[1. 1.]] [1. 1.]]
""" """
a = asarray(a) a = asarray(a)
# 使用asarray函数将a转换为Tensor
return a.copy() return a.copy()
# 返回a的副本
def ones(shape, dtype=mstype.float32): def ones(shape, dtype=mstype.float32):
@ -290,11 +342,17 @@ def ones(shape, dtype=mstype.float32):
[1. 1.]] [1. 1.]]
""" """
shape = _check_shape(shape) shape = _check_shape(shape)
# 检查并确认shape是一个有效的形状
dtype = _check_dtype(dtype) dtype = _check_dtype(dtype)
# 检查并确认dtype是一个有效的数据类型
if _is_shape_empty(shape): if _is_shape_empty(shape):
# 如果shape表示的形状是空的
return full(shape, 1.0, dtype) return full(shape, 1.0, dtype)
# 使用full函数创建一个指定形状、数据类型并用1.0填充的Tensor
output = F.fill(dtype, shape, 1) output = F.fill(dtype, shape, 1)
# 使用F.fill函数创建一个指定形状、数据类型并用1填充的Tensor
return output return output
# 返回创建的Tensor
def zeros(shape, dtype=mstype.float32): def zeros(shape, dtype=mstype.float32):
@ -323,11 +381,17 @@ def zeros(shape, dtype=mstype.float32):
[0. 0.]] [0. 0.]]
""" """
shape = _check_shape(shape) shape = _check_shape(shape)
# 检查并确认shape是一个有效的形状
dtype = _check_dtype(dtype) dtype = _check_dtype(dtype)
# 检查并确认dtype是一个有效的数据类型
if _is_shape_empty(shape): if _is_shape_empty(shape):
# 如果shape表示的形状是空的
return full(shape, 0.0, dtype) return full(shape, 0.0, dtype)
# 使用full函数创建一个指定形状、数据类型并用0.0填充的Tensor
output = F.fill(dtype, shape, 0) output = F.fill(dtype, shape, 0)
# 使用F.fill函数创建一个指定形状、数据类型并用0填充的Tensor
return output return output
# 返回创建的Tensor
def full(shape, fill_value, dtype=None): def full(shape, fill_value, dtype=None):
@ -360,24 +424,42 @@ def full(shape, fill_value, dtype=None):
[True True]] [True True]]
""" """
shape = _check_shape(shape) shape = _check_shape(shape)
# 检查并确认shape是一个有效的形状
if not isinstance(fill_value, ARRAY_TYPES): if not isinstance(fill_value, ARRAY_TYPES):
# 如果fill_value不是int、float、bool、list、tuple、Tensor类型
_raise_type_error("fill value should be int, float, bool, list, tuple, Tensor, but got", fill_value) _raise_type_error("fill value should be int, float, bool, list, tuple, Tensor, but got", fill_value)
# 抛出TypeError表示fill_value类型不支持
if dtype is not None: if dtype is not None:
# 如果dtype不为None
dtype = _check_dtype(dtype) dtype = _check_dtype(dtype)
# 检查并确认dtype是一个有效的数据类型
else: else:
# 如果dtype为None
if isinstance(fill_value, (int, float, bool)): if isinstance(fill_value, (int, float, bool)):
# 如果fill_value是int、float或bool类型
dtype = _get_dtype_from_scalar(fill_value) dtype = _get_dtype_from_scalar(fill_value)
# 从标量fill_value中获取数据类型并赋值给dtype
if isinstance(fill_value, Tensor): if isinstance(fill_value, Tensor):
# 如果fill_value是Tensor类型
dtype = fill_value.dtype dtype = fill_value.dtype
# 从Tensor fill_value中获取数据类型并赋值给dtype
if not _is_shape_empty(shape): if not _is_shape_empty(shape):
# 如果shape表示的形状不是空的
if isinstance(fill_value, (int, float, bool)): if isinstance(fill_value, (int, float, bool)):
# 如果fill_value是int、float或bool类型
return F.fill(dtype, shape, fill_value) return F.fill(dtype, shape, fill_value)
# 使用F.fill函数创建一个指定形状、数据类型并用fill_value填充的Tensor
if isinstance(fill_value, (list, tuple)): if isinstance(fill_value, (list, tuple)):
# 如果fill_value是list或tuple类型
fill_value = asarray_const(fill_value) fill_value = asarray_const(fill_value)
# 使用asarray_const函数将fill_value转换为Tensor
return broadcast_to(fill_value, shape) return broadcast_to(fill_value, shape)
# 使用broadcast_to函数将fill_value广播到指定的shape并返回结果
# if shape contains zero, use c.Tensor() # if shape contains zero, use c.Tensor()
return _convert_64_to_32(empty_compile(dtype, shape)) return _convert_64_to_32(empty_compile(dtype, shape))
# 如果shape包含零使用empty_compile函数创建一个空的Tensor并使用_convert_64_to_32函数将数据类型从float64转换为float32
@constexpr @constexpr

@ -21,36 +21,44 @@ from .. import signature as sig
class UpdateCache(PrimitiveWithCheck): class UpdateCache(PrimitiveWithCheck):
""" """
Update the value fo input_x, similar to ScatterNdUpdate. 更新 input_x 的值类似于 ScatterNdUpdate
The difference is that UpdateCache will not update when indices < 0 or indices >= max_num. 不同之处在于UpdateCache indices < 0 indices >= max_num 时不会更新
Inputs: Inputs:
- **input_x** (Parameter) - Parameter which is going to be updated. - **input_x** (Parameter) - 将要更新的参数
- **indices** (Tensor) - Update indices of input_x. - **indices** (Tensor) - input_x 的更新索引
- **updates** (Tensor) - The update values. - **updates** (Tensor) - 更新值
Outputs: Outputs:
- **out** (Tensor) - Returns a [1] Tensor, which is not useful. - **out** (Tensor) - 返回一个 [1] 的张量这个张量没有用处
""" """
# 定义函数签名,指定输入参数的类型和读写权限
__mindspore_signature__ = ( __mindspore_signature__ = (
# 定义输入参数input_x类型为T读写权限为写
sig.make_sig('input_x', sig.sig_rw.RW_WRITE, sig.make_sig('input_x', sig.sig_rw.RW_WRITE,
dtype=sig.sig_dtype.T), dtype=sig.sig_dtype.T),
# 定义输入参数indices类型为T1
sig.make_sig('indices', dtype=sig.sig_dtype.T1), sig.make_sig('indices', dtype=sig.sig_dtype.T1),
# 定义输入参数updates类型为T
sig.make_sig('updates', dtype=sig.sig_dtype.T), sig.make_sig('updates', dtype=sig.sig_dtype.T),
# 定义输入参数max_num类型为T1
sig.make_sig('max_num', dtype=sig.sig_dtype.T1) sig.make_sig('max_num', dtype=sig.sig_dtype.T1)
) )
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""init UpdateCache""" """初始化 UpdateCache"""
# 初始化输入和输出名称
self.init_prim_io_names(inputs=['input_x', 'indices', 'update', 'max_num'], self.init_prim_io_names(inputs=['input_x', 'indices', 'update', 'max_num'],
outputs=['out']) outputs=['out'])
def check_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape): def check_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape):
# 检查输入形状
return [1] return [1]
def check_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype): def check_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype):
# 检查输入数据类型
validator.check_tensor_dtype_valid( validator.check_tensor_dtype_valid(
"indices", indices_dtype, mstype.int_type, self.name) "indices", indices_dtype, mstype.int_type, self.name)
return input_x_dtype return input_x_dtype
@ -58,20 +66,18 @@ class UpdateCache(PrimitiveWithCheck):
class SubAndFilter(PrimitiveWithCheck): class SubAndFilter(PrimitiveWithCheck):
""" """
Dynamic kernel, sub an offset and 动态内核减去一个偏移量并返回在范围 [0, max_num) 内的元素
return the elements which in range [0, max_num).
Inputs: Inputs:
- **input_x** (Tensor) - Input tensor. - **input_x** (Tensor) - 输入张量
- **max_num** (Int) - The max value of element that after sub `offset`. - **max_num** (Int) - 减去 `offset` 后元素的最大值
- **offset** (int) - Specifies the offset value of this `input_x`. - **offset** (int) - 指定此 `input_x` 的偏移值
Outputs: Outputs:
tuple(Tensor), tuple of 2 tensors, filter_res and filter_idx. tuple(Tensor), 2 个张量组成的元组filter_res filter_idx
- **filter_res** (Tensor) - The result that `input_x` minus `offset`, - **filter_res** (Tensor) - `input_x` 减去 `offset` 的结果
and return which in the range [0, max_num). 并返回在范围 [0, max_num) 内的值
- **filter_idx** (Tensor) - A tensor containing indices of elements in the input - **filter_idx** (Tensor) - 一个张量包含与输出张量对应的输入元素的索引
coressponding to the output tensor.
Supported Platforms: Supported Platforms:
`CPU` `CPU`
@ -87,15 +93,18 @@ class SubAndFilter(PrimitiveWithCheck):
""" """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""init SubAndFilter""" """初始化 SubAndFilter"""
# 初始化输入和输出名称
self.init_prim_io_names(inputs=['input_x', 'max_num', 'offset'], self.init_prim_io_names(inputs=['input_x', 'max_num', 'offset'],
outputs=['sub_res', 'sub_idx']) outputs=['sub_res', 'sub_idx'])
def check_shape(self, input_x_shape, max_num_shape, offset_shape): def check_shape(self, input_x_shape, max_num_shape, offset_shape):
# 检查输入形状
return ((-1,), (-1,)) return ((-1,), (-1,))
def check_dtype(self, input_x_dtype, max_num_dtype, offset_dtype): def check_dtype(self, input_x_dtype, max_num_dtype, offset_dtype):
# 检查输入数据类型
validator.check_tensor_dtype_valid( validator.check_tensor_dtype_valid(
"input_x", input_x_dtype, mstype.int_type, self.name) "input_x", input_x_dtype, mstype.int_type, self.name)
return input_x_dtype return input_x_dtype
@ -103,15 +112,15 @@ class SubAndFilter(PrimitiveWithCheck):
class MapUniform(PrimitiveWithCheck): class MapUniform(PrimitiveWithCheck):
""" """
Map a tensor by using fomula : value = key % `group_num` * `per_group_size` + key // `group_num`. 通过公式映射一个张量value = key % `group_num` * `per_group_size` + key // `group_num`
Inputs: Inputs:
- **input** (Tensor) - Input Tensor. - **input** (Tensor) - 输入张量
- **per_group_size** (int) - The size of each group. - **per_group_size** (int) - 每个组的大小
- **group_num** (int) - The number of group. - **group_num** (int) - 组的数量
Outputs: Outputs:
Tensor, has the same dtype and shape as the `input`. Tensor具有与 `input` 相同的 dtype 和形状
Supported Platforms: Supported Platforms:
`CPU` `CPU`
@ -128,11 +137,12 @@ class MapUniform(PrimitiveWithCheck):
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""init MapUniform""" """初始化 MapUniform"""
self.init_prim_io_names(inputs=['input', 'per_group_size', 'group_num'], self.init_prim_io_names(inputs=['input', 'per_group_size', 'group_num'],
outputs=['output']) outputs=['output'])
def check_dtype(self, input_dtype, per_group_size_dtype, group_num_dtype): def check_dtype(self, input_dtype, per_group_size_dtype, group_num_dtype):
"""检查输入数据类型"""
validator.check_tensor_dtype_valid( validator.check_tensor_dtype_valid(
"input", input_dtype, mstype.int_type, self.name) "input", input_dtype, mstype.int_type, self.name)
validator.check_value_type( validator.check_value_type(
@ -143,15 +153,15 @@ class MapUniform(PrimitiveWithCheck):
class CacheSwapTable(PrimitiveWithCheck): class CacheSwapTable(PrimitiveWithCheck):
""" """
Delete a hashmap entry,and insert a new key to hashmap, return the key and value of delete entry. 删除一个哈希映射条目并插入一个新键到哈希映射中返回删除条目的键和值
Inputs: Inputs:
- **cache_table** (Parameter) - The cache table which is on device. - **cache_table** (Parameter) - 在设备上的缓存表
- **swap_cache_idx** (Tensor) - The index of table which need to swap. -1 is skipped. - **swap_cache_idx** (Tensor) - 需要交换的表索引-1 被跳过
- **miss_value** (int) - The values which arg going to swap into cache table. - **miss_value** (int) - 将要交换到缓存表的值
Outputs: Outputs:
- **old_value** (Tensor) - The values which are swapped out. - **old_value** (Tensor) - 被交换出去的值
""" """
__mindspore_signature__ = ( __mindspore_signature__ = (
sig.make_sig('cache_table', sig.sig_rw.RW_WRITE, sig.make_sig('cache_table', sig.sig_rw.RW_WRITE,
@ -162,28 +172,32 @@ class CacheSwapTable(PrimitiveWithCheck):
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""init CacheSwapTable""" """初始化 CacheSwapTable"""
self.init_prim_io_names(inputs=['cache_table', 'swap_cache_idx', 'miss_value'], self.init_prim_io_names(inputs=['cache_table', 'swap_cache_idx', 'miss_value'],
outputs=['old_value']) outputs=['old_value'])
def check_shape(self, cache_table_shape, swap_cache_idx_shape, miss_value_shape): def check_shape(self, cache_table_shape, swap_cache_idx_shape, miss_value_shape):
# 检查cache_table_shape的长度是否为2如果不是则抛出ValueError异常
if len(cache_table_shape) != 2: if len(cache_table_shape) != 2:
raise ValueError( raise ValueError(
"cache table shape must be 2, but got %d" % len(cache_table_shape)) "cache table shape must be 2, but got %d" % len(cache_table_shape))
# 返回miss_value_shape
return miss_value_shape return miss_value_shape
def check_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype): def check_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype):
# 检查swap_cache_idx_dtype是否为mstype.int_type如果不是则抛出ValueError异常
validator.check_tensor_dtype_valid( validator.check_tensor_dtype_valid(
"swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name) "swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name)
# 返回miss_value_dtype
return miss_value_dtype return miss_value_dtype
class MapCacheIdx(PrimitiveWithCheck): class MapCacheIdx(PrimitiveWithCheck):
""" """
MapCacheIdx merge SearchCacheIdx, CacheSwapHashmap, UpdateCache together. MapCacheIdx SearchCacheIdxCacheSwapHashmap UpdateCache 合并在一起
When input an indices tensor, it will output the cache indices which search in hashmap. 当输入一个索引张量时它将输出在哈希映射中搜索的缓存索引
""" """
__mindspore_signature__ = ( __mindspore_signature__ = (
sig.make_sig('hashmap', sig.sig_rw.RW_WRITE, sig.make_sig('hashmap', sig.sig_rw.RW_WRITE,
@ -196,52 +210,65 @@ class MapCacheIdx(PrimitiveWithCheck):
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""init MapCacheIdx""" """初始化 MapCacheIdx"""
self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'offset'], self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'offset'],
outputs=['cache_idx', 'old_emb_idx', 'miss_emb_idx', 'swap_cache_idx']) outputs=['cache_idx', 'old_emb_idx', 'miss_emb_idx', 'swap_cache_idx'])
def __check__(self, hashmap, indices, step, emb_max_num, offset): def __check__(self, hashmap, indices, step, emb_max_num, offset):
# 获取hashmap的形状
hashmap_shape = hashmap['shape'] hashmap_shape = hashmap['shape']
# 如果hashmap的维度不是2则抛出异常
if len(hashmap_shape) != 2: if len(hashmap_shape) != 2:
raise ValueError("The dimension of 'hashmap' in SearchCacheIdx must be 2, " raise ValueError("The dimension of 'hashmap' in SearchCacheIdx must be 2, "
"but got %d." % len(hashmap_shape)) "but got %d." % len(hashmap_shape))
# 设置输出的形状
out_shape = (indices['shape'], -1, -1, -1) out_shape = (indices['shape'], -1, -1, -1)
# 获取hashmap和indices的数据类型
hashmap_dtype = hashmap['dtype'] hashmap_dtype = hashmap['dtype']
indices_dtype = indices['dtype'] indices_dtype = indices['dtype']
# 将数据类型存入字典
args = {"hashmap": hashmap_dtype, "indices": indices_dtype} args = {"hashmap": hashmap_dtype, "indices": indices_dtype}
# 检查数据类型是否相同且有效
validator.check_tensors_dtypes_same_and_valid( validator.check_tensors_dtypes_same_and_valid(
args, mstype.int_type, self.name) args, mstype.int_type, self.name)
# 设置输出的数据类型
out_dtype = (hashmap_dtype, hashmap_dtype, out_dtype = (hashmap_dtype, hashmap_dtype,
hashmap_dtype, hashmap_dtype) hashmap_dtype, hashmap_dtype)
# 设置输出的字典
out = {'shape': out_shape, out = {'shape': out_shape,
'dtype': out_dtype, 'dtype': out_dtype,
'value': None} 'value': None}
# 如果indices中有max_shape则设置输出的max_shape
if 'max_shape' in indices: if 'max_shape' in indices:
out['max_shape'] = (indices['max_shape'], indices['max_shape'], out['max_shape'] = (indices['max_shape'], indices['max_shape'],
indices['max_shape'], indices['max_shape']) indices['max_shape'], indices['max_shape'])
# 否则设置输出的max_shape为indices的形状
else: else:
out['max_shape'] = (indices['shape'], indices['shape'], out['max_shape'] = (indices['shape'], indices['shape'],
indices['shape'], indices['shape']) indices['shape'], indices['shape'])
# 如果indices中有min_shape则设置输出的min_shape
if 'min_shape' in indices: if 'min_shape' in indices:
out['min_shape'] = (indices['min_shape'], 0, 0, 0) out['min_shape'] = (indices['min_shape'], 0, 0, 0)
# 否则设置输出的min_shape为(0, 0, 0, 0)
else: else:
out['min_shape'] = (0, 0, 0, 0) out['min_shape'] = (0, 0, 0, 0)
# 返回输出的字典
return out return out
class DynamicAssign(PrimitiveWithCheck): class DynamicAssign(PrimitiveWithCheck):
""" """
Assigns `Parameter` with a value, the `value` can have a dynamic shape. `Parameter` 与值分配`value` 可以具有动态形状
Inputs: Inputs:
- **variable** (Parameter) - The `Parameter`. - **variable** (Parameter) - `Parameter`
- **value** (Tensor) - The value to be assigned. - **value** (Tensor) - 要分配的值
Outputs: Outputs:
Tensor, has the same type as original `variable`. Tensor具有与原始 `variable` 相同的类型
Supported Platforms: Supported Platforms:
`CPU` `CPU`
@ -256,31 +283,32 @@ class DynamicAssign(PrimitiveWithCheck):
self.init_prim_io_names(inputs=['ref', 'value'], outputs=['output']) self.init_prim_io_names(inputs=['ref', 'value'], outputs=['output'])
def check_dtype(self, variable, value): def check_dtype(self, variable, value):
# 检查变量是否为mstype.type_refkey
if variable != mstype.type_refkey: if variable != mstype.type_refkey:
# 检查变量是否为mstype.number_type类型
validator.check_tensor_dtype_valid( validator.check_tensor_dtype_valid(
"variable", variable, mstype.number_type, self.name) "variable", variable, mstype.number_type, self.name)
# 检查value是否为mstype.number_type类型
validator.check_scalar_or_tensor_types_same( validator.check_scalar_or_tensor_types_same(
{"value": value}, mstype.number_type, self.name) {"value": value}, mstype.number_type, self.name)
class PadAndShift(PrimitiveWithCheck): class PadAndShift(PrimitiveWithCheck):
""" """
Pad a tensor with -1, and shift with a length. -1 填充张量并按长度进行移位
Inputs: Inputs:
- **input_x** (Tensor) - The input Tensor, which will be copied - **input_x** (Tensor) - 输入张量将被复制到 `output`
to `output`. - **cum_sum_arr** (Tensor) - cum_sum_arr 的最后一个值是输出张量的填充长度
- **cum_sum_arr** (Tensor) - The last value of cum_sum_arr is cum_sum_arr[shift_idx] 是开始移位cum_sum_arr[shift_idx+1] 是结束
the pad length of output tensor, cum_sum_arr[shift_idx] is - **shift_idx** (Int) - cum_sum_arr 的索引
the start to shift, and cum_sum_arr[shift_idx+1] is the end. 如果使用 PythonPadAndShift
- **shift_idx** (Int) - The idx of cum_sum_arr.
if use python, PadAndShift is:
output = [-1] * cum_sum_arr[-1] output = [-1] * cum_sum_arr[-1]
start = cum_sum_arr[shift_idx] start = cum_sum_arr[shift_idx]
end = cum_sum_arr[shift_idx + 1] end = cum_sum_arr[shift_idx + 1]
output[start:end] = input_x[:(end-start)] output[start:end] = input_x[:(end-start)]
Outputs: Outputs:
Tensor, has the same type as original `variable`. Tensor具有与原始 `variable` 相同的类型
Supported Platforms: Supported Platforms:
`CPU` `CPU`
@ -296,11 +324,14 @@ class PadAndShift(PrimitiveWithCheck):
""" """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
# 初始化输入输出名称
self.init_prim_io_names( self.init_prim_io_names(
inputs=['input_x', 'cum_sum_arr', 'shift_idx'], outputs=['output']) inputs=['input_x', 'cum_sum_arr', 'shift_idx'], outputs=['output'])
def check_shape(self, input_x_shape, cum_sum_arr_shape, shift_idx_shape): def check_shape(self, input_x_shape, cum_sum_arr_shape, shift_idx_shape):
# 检查输入形状
return input_x_shape return input_x_shape
def check_dtype(self, input_x_dtype, cum_sum_arr_dtype, shift_idx_dtype): def check_dtype(self, input_x_dtype, cum_sum_arr_dtype, shift_idx_dtype):
# 检查输入数据类型
return input_x_dtype return input_x_dtype

@ -55,6 +55,7 @@ class TensorArray(PrimitiveWithInfer):
""" """
@prim_attr_register @prim_attr_register
def __init__(self, dtype, element_shape, dynamic_size=True, size=0, name="TA"): def __init__(self, dtype, element_shape, dynamic_size=True, size=0, name="TA"):
"""初始化TensorArray类设置参数和属性."""
validator.check_type_name("dtype", dtype, mstype.number_type + (mstype.bool_,), self.name) validator.check_type_name("dtype", dtype, mstype.number_type + (mstype.bool_,), self.name)
validator.check_int(size, 0, Rel.GE, "size", self.name) validator.check_int(size, 0, Rel.GE, "size", self.name)
self.add_prim_attr('dtype', dtype) self.add_prim_attr('dtype', dtype)
@ -65,9 +66,11 @@ class TensorArray(PrimitiveWithInfer):
self.add_prim_attr('name', name) self.add_prim_attr('name', name)
def infer_shape(self): def infer_shape(self):
"""推断输出形状."""
return () return ()
def infer_dtype(self): def infer_dtype(self):
"""推断输出数据类型."""
return mstype.int64 return mstype.int64
@ -99,12 +102,15 @@ class TensorArrayWrite(PrimitiveWithInfer):
""" """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""初始化TensorArrayWrite类."""
self.add_prim_attr('side_effect_mem', True) self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, handle_shape, index_shape, value_shape): def infer_shape(self, handle_shape, index_shape, value_shape):
"""推断输出形状."""
return () return ()
def infer_dtype(self, handle_type, index_type, value_type): def infer_dtype(self, handle_type, index_type, value_type):
"""推断输出数据类型."""
validator.check_type_name("handle", handle_type, (ms.int64), self.name) validator.check_type_name("handle", handle_type, (ms.int64), self.name)
validator.check_type_name("index", index_type, (int, ms.int64), self.name) validator.check_type_name("index", index_type, (int, ms.int64), self.name)
validator.check_type_name("value", value_type, mstype.number_type + (mstype.bool_,), self.name) validator.check_type_name("value", value_type, mstype.number_type + (mstype.bool_,), self.name)
@ -146,6 +152,7 @@ class TensorArrayRead(PrimitiveWithInfer):
""" """
@prim_attr_register @prim_attr_register
def __init__(self, dtype, element_shape): def __init__(self, dtype, element_shape):
"""初始化TensorArrayRead类设置参数和属性."""
validator.check_type_name("dtype", dtype, mstype.number_type + (mstype.bool_,), self.name) validator.check_type_name("dtype", dtype, mstype.number_type + (mstype.bool_,), self.name)
self.add_prim_attr('dtype', dtype) self.add_prim_attr('dtype', dtype)
self.add_prim_attr('element_shape', element_shape) self.add_prim_attr('element_shape', element_shape)
@ -154,9 +161,11 @@ class TensorArrayRead(PrimitiveWithInfer):
self.shape = element_shape self.shape = element_shape
def infer_shape(self, handle_shape, index_shape): def infer_shape(self, handle_shape, index_shape):
"""推断输出形状."""
return self.shape return self.shape
def infer_dtype(self, handle_type, index_type): def infer_dtype(self, handle_type, index_type):
"""推断输出数据类型."""
validator.check_type_name("handle", handle_type, (ms.int64), self.name) validator.check_type_name("handle", handle_type, (ms.int64), self.name)
validator.check_type_name("index", index_type, (int, ms.int64), self.name) validator.check_type_name("index", index_type, (int, ms.int64), self.name)
return self.dtype return self.dtype
@ -188,12 +197,15 @@ class TensorArrayClose(PrimitiveWithInfer):
""" """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""初始化TensorArrayClose类."""
self.add_prim_attr('side_effect_mem', True) self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, handle_shape): def infer_shape(self, handle_shape):
"""推断输出形状."""
return () return ()
def infer_dtype(self, handle_type): def infer_dtype(self, handle_type):
"""推断输出数据类型."""
validator.check_type_name("handle", handle_type, (ms.int64), self.name) validator.check_type_name("handle", handle_type, (ms.int64), self.name)
return mstype.int64 return mstype.int64
@ -224,12 +236,15 @@ class TensorArrayClear(PrimitiveWithInfer):
""" """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""初始化TensorArrayClear类."""
self.add_prim_attr('side_effect_mem', True) self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, handle_shape): def infer_shape(self, handle_shape):
"""推断输出形状."""
return () return ()
def infer_dtype(self, handle_type): def infer_dtype(self, handle_type):
"""推断输出数据类型."""
validator.check_type_name("handle", handle_type, (ms.int64), self.name) validator.check_type_name("handle", handle_type, (ms.int64), self.name)
return mstype.int64 return mstype.int64
@ -269,7 +284,7 @@ class TensorArrayStack(Primitive):
""" """
@prim_attr_register @prim_attr_register
def __init__(self, dtype, element_shape, dynamic_size, size): def __init__(self, dtype, element_shape, dynamic_size, size):
"""Initialize TensorArrayStack""" """初始化TensorArrayStack类设置参数和属性."""
self.init_prim_io_names(inputs=[''], outputs=['output']) self.init_prim_io_names(inputs=[''], outputs=['output'])
self.add_prim_attr('dtype', dtype) self.add_prim_attr('dtype', dtype)
self.add_prim_attr('element_shape', element_shape) self.add_prim_attr('element_shape', element_shape)
@ -304,12 +319,15 @@ class TensorArraySize(PrimitiveWithInfer):
""" """
@prim_attr_register @prim_attr_register
def __init__(self): def __init__(self):
"""初始化TensorArraySize类."""
self.add_prim_attr('side_effect_mem', True) self.add_prim_attr('side_effect_mem', True)
def infer_shape(self, handle_shape): def infer_shape(self, handle_shape):
"""推断输出形状."""
return () return ()
def infer_dtype(self, handle_type): def infer_dtype(self, handle_type):
"""推断输出数据类型."""
validator.check_type_name("handle", handle_type, (ms.int64), self.name) validator.check_type_name("handle", handle_type, (ms.int64), self.name)
return mstype.int64 return mstype.int64
@ -344,17 +362,20 @@ class TensorArrayGather(PrimitiveWithInfer):
""" """
@prim_attr_register @prim_attr_register
def __init__(self, dtype, element_shape): def __init__(self, dtype, element_shape):
"""初始化TensorArrayGather类设置参数和属性."""
self.init_prim_io_names(inputs=['handle', 'indices'], outputs=['value']) self.init_prim_io_names(inputs=['handle', 'indices'], outputs=['value'])
self.add_prim_attr("side_effect_mem", True) self.add_prim_attr("side_effect_mem", True)
self.dtype = dtype self.dtype = dtype
self.element_shape = element_shape self.element_shape = element_shape
def infer_shape(self, handle, indices): def infer_shape(self, handle, indices):
"""推断输出形状."""
if len(indices) != 1: if len(indices) != 1:
return ValueError("indices dimension should be equal to 1") return ValueError("indices dimension should be equal to 1")
return [indices[0]] + list(self.element_shape) return [indices[0]] + list(self.element_shape)
def infer_dtype(self, handle, indices): def infer_dtype(self, handle, indices):
"""推断输出数据类型."""
validator.check_type_name("handle", handle, (ms.int64), self.name) validator.check_type_name("handle", handle, (ms.int64), self.name)
validator.check_type_name("indices", indices, (ms.int32), self.name) validator.check_type_name("indices", indices, (ms.int32), self.name)
return self.dtype return self.dtype

@ -30,10 +30,12 @@ class AllGatherCell(Cell):
def __init__(self, group): def __init__(self, group):
super(AllGatherCell, self).__init__(auto_prefix=False) super(AllGatherCell, self).__init__(auto_prefix=False)
# 创建AllGather操作对象
self.allgather = AllGather(group) self.allgather = AllGather(group)
@ms_function() @ms_function()
def construct(self, x): def construct(self, x):
# 执行AllGather操作
x = self.allgather(x) x = self.allgather(x)
return x return x
@ -50,10 +52,12 @@ class SaveOptShardCkptCell(Cell):
""" """
def __init__(self, group): def __init__(self, group):
super(SaveOptShardCkptCell, self).__init__(auto_prefix=False) super(SaveOptShardCkptCell, self).__init__(auto_prefix=False)
# 创建AllGather操作对象
self.allgather1 = AllGather(group) self.allgather1 = AllGather(group)
self.allgather2 = AllGather() self.allgather2 = AllGather()
def construct(self, x): def construct(self, x):
# 执行AllGather操作
x = self.allgather1(x) x = self.allgather1(x)
x = self.allgather2(x) x = self.allgather2(x)
@ -64,11 +68,14 @@ def get_allgather_cell(group, need_merge_twice=False):
"""Get AllGatherCell object.""" """Get AllGatherCell object."""
global _allgather_cell global _allgather_cell
if need_merge_twice: if need_merge_twice:
# 如果需要两次合并则创建SaveOptShardCkptCell对象
_allgather_cell = SaveOptShardCkptCell(group) _allgather_cell = SaveOptShardCkptCell(group)
else: else:
if group: if group:
# 如果有指定的设备组则创建AllGatherCell对象
_allgather_cell = AllGatherCell(group) _allgather_cell = AllGatherCell(group)
else: else:
# 否则创建AllGatherCell对象使用全局通信组
_allgather_cell = AllGatherCell(GlobalComm.WORLD_COMM_GROUP) _allgather_cell = AllGatherCell(GlobalComm.WORLD_COMM_GROUP)
return _allgather_cell return _allgather_cell
@ -77,4 +84,5 @@ def destroy_allgather_cell():
"""Destroy AllGatherCell object.""" """Destroy AllGatherCell object."""
global _allgather_cell global _allgather_cell
if _allgather_cell: if _allgather_cell:
# 销毁AllGatherCell对象
_allgather_cell = None _allgather_cell = None

@ -28,13 +28,16 @@ def parse_args():
Examples: Examples:
>>> parse_args() >>> parse_args()
""" """
# 创建一个ArgumentParser对象用于解析命令行参数描述信息为"MindSpore dependency packages version checker."
parser = ArgumentParser(description="MindSpore dependency packages version checker.") parser = ArgumentParser(description="MindSpore dependency packages version checker.")
# 添加一个命令行参数--mindspore_version类型为字符串帮助信息为"MindSpore version."
parser.add_argument("--mindspore_version", type=str, help="MindSpore version.") parser.add_argument("--mindspore_version", type=str, help="MindSpore version.")
# 添加一个命令行参数--supported_version类型为字符串可以多次指定帮助信息为"Supported environment version."
parser.add_argument("--supported_version", type=str, action='append', help="Supported environment version.") parser.add_argument("--supported_version", type=str, action='append', help="Supported environment version.")
# 解析命令行参数并返回结果
args = parser.parse_args() args = parser.parse_args()
return args return args
def check_deps_version(mindspore_version, supported_version): def check_deps_version(mindspore_version, supported_version):
""" """
check te/hccl/topi version check te/hccl/topi version
@ -46,6 +49,7 @@ def check_deps_version(mindspore_version, supported_version):
Returns: Returns:
void void
""" """
# 尝试导入并检查 hccl、te 和 topi 包的版本是否与支持的版本匹配
try: try:
from hccl import sys_version as hccl_version from hccl import sys_version as hccl_version
v = '.'.join(hccl_version.__sys_version__.split('.')[0:2]) v = '.'.join(hccl_version.__sys_version__.split('.')[0:2])
@ -63,18 +67,20 @@ def check_deps_version(mindspore_version, supported_version):
print(f"MindSpore version {mindspore_version} and \"topi\" wheel package version {v} does not " print(f"MindSpore version {mindspore_version} and \"topi\" wheel package version {v} does not "
"match, reference to the match info on: https://www.mindspore.cn/install") "match, reference to the match info on: https://www.mindspore.cn/install")
# 捕获导入错误并打印相应的检查失败信息
except ImportError as e: except ImportError as e:
print("CheckFailed: ", e.args) print("CheckFailed: ", e.args)
print("MindSpore relies on the 3 whl packages of \"te\", \"topi\" and \"hccl\" in the \"fwkacllib\" " print("MindSpore relies on the 3 whl packages of \"te\", \"topi\" and \"hccl\" in the \"fwkacllib\" "
"folder of the Ascend AI software package (Ascend Data Center Solution), please check whether they are " "folder of the Ascend AI software package (Ascend Data Center Solution), please check whether they are "
"installed correctly or not, reference to the match info on: https://www.mindspore.cn/install") "installed correctly or not, reference to the match info on: https://www.mindspore.cn/install")
def main(): def main():
# 解析命令行参数
args = parse_args() args = parse_args()
# 检查 mindspore 的版本是否在支持的版本范围内
check_deps_version(args.mindspore_version, args.supported_version) check_deps_version(args.mindspore_version, args.supported_version)
if __name__ == "__main__": if __name__ == "__main__":
sys.path = sys.path[1:] # avoid the impact of relative path env, only affect this process # 避免相对路径环境的影响,仅影响当前进程
sys.path = sys.path[1:]
main() main()

@ -31,47 +31,58 @@ class EnvChecker(metaclass=ABCMeta):
@abstractmethod @abstractmethod
def check_env(self, e): def check_env(self, e):
pass """检查环境是否符合要求"""
@abstractmethod @abstractmethod
def set_env(self): def set_env(self):
pass """设置环境"""
@abstractmethod @abstractmethod
def check_version(self): def check_version(self):
pass """检查版本是否符合要求"""
class GPUEnvChecker(EnvChecker): class GPUEnvChecker(EnvChecker):
"""GPU environment check.""" """GPU environment check."""
def __init__(self): def __init__(self):
# 初始化版本列表
self.version = ["10.1", "11.1"] self.version = ["10.1", "11.1"]
# 初始化库键到库名的映射字典
self.lib_key_to_lib_name = {'libcu': 'libcuda.so'} self.lib_key_to_lib_name = {'libcu': 'libcuda.so'}
# env # env
# 获取系统环境变量 PATH 的值
self.path = os.getenv("PATH") self.path = os.getenv("PATH")
# 获取系统环境变量 LD_LIBRARY_PATH 的值
self.ld_lib_path = os.getenv("LD_LIBRARY_PATH") self.ld_lib_path = os.getenv("LD_LIBRARY_PATH")
# check # check
# 初始化版本号为 "0"
self.v = "0" self.v = "0"
# 获取 CUDA 库的路径
self.cuda_lib_path = self._get_lib_path("libcu") self.cuda_lib_path = self._get_lib_path("libcu")
# 获取 CUDA 可执行文件的路径
self.cuda_bin_path = self._get_bin_path("cuda") self.cuda_bin_path = self._get_bin_path("cuda")
# 获取 cuDNN 库的路径
self.cudnn_lib_path = self._get_lib_path("libcudnn") self.cudnn_lib_path = self._get_lib_path("libcudnn")
def check_env(self, e): def check_env(self, e):
# 抛出传入的异常 e
raise e raise e
def set_env(self): def set_env(self):
# 设置环境变量,当前实现为空
return return
def _get_bin_path(self, bin_name): def _get_bin_path(self, bin_name):
"""Get bin path by bin name.""" """Get bin path by bin name."""
# 如果二进制名称为 "cuda",则调用获取 CUDA 二进制路径的方法
if bin_name == "cuda": if bin_name == "cuda":
return self._get_cuda_bin_path() return self._get_cuda_bin_path()
# 否则返回空列表
return [] return []
def _get_cuda_bin_path(self): def _get_cuda_bin_path(self):
"""Get cuda bin path by lib path.""" # Get cuda bin path by lib path.
path_list = [] path_list = []
for path in self.cuda_lib_path: for path in self.cuda_lib_path:
path = os.path.abspath(path.strip()+"/bin/") path = os.path.abspath(path.strip()+"/bin/")
@ -81,56 +92,87 @@ class GPUEnvChecker(EnvChecker):
def _get_nvcc_version(self, is_set_env): def _get_nvcc_version(self, is_set_env):
"""Get cuda version by nvcc command.""" """Get cuda version by nvcc command."""
# 运行 nvcc 命令获取 CUDA 版本信息
nvcc_result = subprocess.run(["nvcc", "--version | grep release"], nvcc_result = subprocess.run(["nvcc", "--version | grep release"],
timeout=3, text=True, capture_output=True, check=False) timeout=3, text=True, capture_output=True, check=False)
# 如果命令返回非零值,表示命令执行失败
if nvcc_result.returncode: if nvcc_result.returncode:
# 如果尚未设置环境变量
if not is_set_env: if not is_set_env:
# 遍历预设的 CUDA 二进制路径
for path in self.cuda_bin_path: for path in self.cuda_bin_path:
# 检查路径中是否存在 nvcc 文件
if Path(path + "/nvcc").is_file(): if Path(path + "/nvcc").is_file():
# 将路径添加到环境变量 PATH 中
os.environ['PATH'] = path + ":" + os.environ['PATH'] os.environ['PATH'] = path + ":" + os.environ['PATH']
# 递归调用以重新尝试获取版本信息
return self._get_nvcc_version(True) return self._get_nvcc_version(True)
# 如果命令执行失败且未找到 nvcc 文件,返回空字符串
return "" return ""
# 获取命令输出结果
result = nvcc_result.stdout result = nvcc_result.stdout
# 遍历输出结果的每一行
for line in result.split('\n'): for line in result.split('\n'):
if line: if line:
# 提取并返回 CUDA 版本号
return line.strip().split("release")[1].split(",")[0].strip() return line.strip().split("release")[1].split(",")[0].strip()
# 如果未找到版本信息,返回空字符串
return "" return ""
def _get_cudnn_version(self): def _get_cudnn_version(self):
"""Get cudnn version by libcudnn.so.""" """Get cudnn version by libcudnn.so."""
# 初始化cudnn版本列表为空
cudnn_version = [] cudnn_version = []
# 遍历cudnn库路径
for path in self.cudnn_lib_path: for path in self.cudnn_lib_path:
# 查找路径下所有的libcudnn.so文件
real_path = glob.glob(path + "/lib*/libcudnn.so.*.*") real_path = glob.glob(path + "/lib*/libcudnn.so.*.*")
# 如果没有找到对应的文件,继续下一个路径
if real_path == []: if real_path == []:
continue continue
# 使用ls命令获取文件信息
ls_cudnn = subprocess.run(["ls", real_path[0]], timeout=10, text=True, ls_cudnn = subprocess.run(["ls", real_path[0]], timeout=10, text=True,
capture_output=True, check=False) capture_output=True, check=False)
# 如果ls命令执行成功解析输出以获取版本号
if ls_cudnn.returncode == 0: if ls_cudnn.returncode == 0:
cudnn_version = ls_cudnn.stdout.split('/')[-1].strip('libcudnn.so.').strip().split('.') cudnn_version = ls_cudnn.stdout.split('/')[-1].strip('libcudnn.so.').strip().split('.')
# 如果版本号只有两个部分,添加一个'.0'作为第三部分
if len(cudnn_version) == 2: if len(cudnn_version) == 2:
cudnn_version.append('0') cudnn_version.append('0')
# 找到版本号后跳出循环
break break
# 将版本号列表转换为字符串
version_str = ''.join([n for n in cudnn_version]) version_str = ''.join([n for n in cudnn_version])
# 返回版本号的前三位
return version_str[0:3] return version_str[0:3]
def _get_cudart_version(self): def _get_cudart_version(self):
"""Get cuda runtime version by libcudart.so.""" """Get cuda runtime version by libcudart.so."""
# 遍历可能的 CUDA 库路径
for path in self.cuda_lib_path: for path in self.cuda_lib_path:
# 查找路径下所有可能的 libcudart.so 文件
real_path = glob.glob(path + "/lib*/libcudart.so.*.*.*") real_path = glob.glob(path + "/lib*/libcudart.so.*.*.*")
# 如果没有找到任何文件,则跳过当前路径
if real_path == []: if real_path == []:
continue continue
# 获取文件名信息以确定 CUDA 版本
ls_cudart = subprocess.run(["ls", real_path[0]], timeout=10, text=True, ls_cudart = subprocess.run(["ls", real_path[0]], timeout=10, text=True,
capture_output=True, check=False) capture_output=True, check=False)
# 如果命令成功执行,则解析输出以提取版本号
if ls_cudart.returncode == 0: if ls_cudart.returncode == 0:
self.v = ls_cudart.stdout.split('/')[-1].strip('libcudart.so.').strip() self.v = ls_cudart.stdout.split('/')[-1].strip('libcudart.so.').strip()
# 找到版本号后跳出循环
break break
# 返回找到的 CUDA 版本号
return self.v return self.v
def check_version(self): def check_version(self):
"""Check cuda version.""" """Check cuda version."""
version_match = False version_match = False
# 调用私有方法检查版本是否匹配并根据结果设置version_match标志
if self._check_version(): if self._check_version():
version_match = True version_match = True
# 如果版本不匹配根据CUDA版本号输出不同的警告信息
if not version_match: if not version_match:
if self.v == "0": if self.v == "0":
logger.warning("Can not found cuda libs, please confirm that the correct " logger.warning("Can not found cuda libs, please confirm that the correct "
@ -140,17 +182,20 @@ class GPUEnvChecker(EnvChecker):
logger.warning(f"MindSpore version {__version__} and cuda version {self.v} does not match, " logger.warning(f"MindSpore version {__version__} and cuda version {self.v} does not match, "
"please refer to the installation guide for version matching " "please refer to the installation guide for version matching "
"information: https://www.mindspore.cn/install") "information: https://www.mindspore.cn/install")
# 获取nvcc版本号并检查是否与MindSpore支持的版本匹配
nvcc_version = self._get_nvcc_version(False) nvcc_version = self._get_nvcc_version(False)
if nvcc_version and (nvcc_version not in self.version): if nvcc_version and (nvcc_version not in self.version):
logger.warning(f"MindSpore version {__version__} and nvcc(cuda bin) version {nvcc_version} " logger.warning(f"MindSpore version {__version__} and nvcc(cuda bin) version {nvcc_version} "
"does not match, please refer to the installation guide for version matching " "does not match, please refer to the installation guide for version matching "
"information: https://www.mindspore.cn/install") "information: https://www.mindspore.cn/install")
# 获取cudnn版本号并检查是否符合最低要求
cudnn_version = self._get_cudnn_version() cudnn_version = self._get_cudnn_version()
if cudnn_version and int(cudnn_version) < 760: if cudnn_version and int(cudnn_version) < 760:
logger.warning(f"MindSpore version {__version__} and cudDNN version {cudnn_version} " logger.warning(f"MindSpore version {__version__} and cudDNN version {cudnn_version} "
"does not match, please refer to the installation guide for version matching " "does not match, please refer to the installation guide for version matching "
"information: https://www.mindspore.cn/install. The recommended version is " "information: https://www.mindspore.cn/install. The recommended version is "
"CUDA10.1 with cuDNN7.6.x and CUDA11.1 with cuDNN8.0.x") "CUDA10.1 with cuDNN7.6.x and CUDA11.1 with cuDNN8.0.x")
# 检查cudnn版本号与CUDA版本号的兼容性对于CUDA 11.0以上版本cudnn版本需要至少为8.0
if cudnn_version and int(cudnn_version) < 800 and int(str(self.v).split('.')[0]) > 10: if cudnn_version and int(cudnn_version) < 800 and int(str(self.v).split('.')[0]) > 10:
logger.warning(f"CUDA version {self.v} and cuDNN version {cudnn_version} " logger.warning(f"CUDA version {self.v} and cuDNN version {cudnn_version} "
"does not match, please refer to the installation guide for version matching " "does not match, please refer to the installation guide for version matching "
@ -159,45 +204,58 @@ class GPUEnvChecker(EnvChecker):
def _check_version(self): def _check_version(self):
"""Check cuda version""" """Check cuda version"""
# 获取 CUDA 运行时版本
v = self._get_cudart_version() v = self._get_cudart_version()
# 解析版本字符串为版本对象
v = version.parse(v) v = version.parse(v)
# 构造版本号字符串,格式为 "主版本.次版本"
v_str = str(v.major) + "." + str(v.minor) v_str = str(v.major) + "." + str(v.minor)
# 检查构造的版本号字符串是否在预定义的版本列表中
if v_str not in self.version: if v_str not in self.version:
return False return False
# 版本号匹配,返回 True
return True return True
def _get_lib_path(self, lib_name): def _get_lib_path(self, lib_name):
"""Get gpu lib path by ldd command.""" """通过ldd命令获取gpu库路径。"""
path_list = [] path_list = [] # 初始化一个空列表用于存储路径
current_path = os.path.split(os.path.realpath(__file__))[0] current_path = os.path.split(os.path.realpath(__file__))[0] # 获取当前文件的绝对路径并分割以获取目录部分
mindspore_path = os.path.join(current_path, "../") mindspore_path = os.path.join(current_path, "../") # 构建mindspore路径通常是当前文件的上一级目录
try: try:
# 使用glob模块查找mindspore_path目录下所有以_c_expression.so开头的文件路径
real_path = glob.glob(mindspore_path + "/_c_expression*.so*") real_path = glob.glob(mindspore_path + "/_c_expression*.so*")
if real_path == []: if real_path == []: # 如果没有找到任何文件
logger.error(f"{self.lib_key_to_lib_name[lib_name]} (need by mindspore-gpu) is not found, please " # 记录错误日志提示用户确认_c_expression.so文件是否存在以及是否安装了正确的cuda版本
f"confirm that _c_expression.so is in directory:{mindspore_path} and the correct cuda " logger.error(f"{self.lib_key_to_lib_name[lib_name]} (mindspore-gpu所需的库) 未找到,请确认 "
"version has been installed, you can refer to the installation " f"_c_expression.so是否位于目录:{mindspore_path}并且已安装正确的cuda版本"
"guidelines: https://www.mindspore.cn/install") "您可以参考安装指南https://www.mindspore.cn/install")
return path_list return path_list # 返回空路径列表
# 使用subprocess.Popen执行ldd命令以获取依赖库的信息
ldd_r = subprocess.Popen(['ldd', real_path[0]], stdout=subprocess.PIPE) ldd_r = subprocess.Popen(['ldd', real_path[0]], stdout=subprocess.PIPE)
# 使用subprocess.Popen的stdin参数从ldd_r.stdout接收输出并执行grep命令以过滤出包含指定库名的信息
ldd_result = subprocess.Popen(['grep', lib_name], stdin=ldd_r.stdout, stdout=subprocess.PIPE) ldd_result = subprocess.Popen(['grep', lib_name], stdin=ldd_r.stdout, stdout=subprocess.PIPE)
# 获取grep命令的输出结果并解码为字符串
result = ldd_result.communicate()[0].decode() result = ldd_result.communicate()[0].decode()
for i in result.split('\n'): for i in result.split('\n'): # 按行分割结果字符串
# 使用partition方法从每一行中提取出库文件的路径
path = i.partition("=>")[2] path = i.partition("=>")[2]
if path.lower().find("not found") > 0: if path.lower().find("not found") > 0: # 如果路径中包含"not found"
logger.warning(f"Cuda {self.version} version(need by mindspore-gpu) is not found, please confirm " # 记录警告日志提示用户确认cuda路径是否已添加到环境变量LD_LIBRARY_PATH中
"that the path of cuda is set to the env LD_LIBRARY_PATH, please refer to the " logger.warning(f"Cuda {self.version}版本(由mindspore-gpu要求的) 未找到请确认cuda路径已设置到环境变量LD_LIBRARY_PATH中"
"installation guidelines: https://www.mindspore.cn/install") "您可以参考安装指南https://www.mindspore.cn/install")
continue continue # 继续下一次循环
# 从路径中去除库名部分
path = path.partition(lib_name)[0] path = path.partition(lib_name)[0]
if path: if path: # 如果路径非空
# 将路径的绝对路径并去除末尾斜杠后添加到path_list中
path_list.append(os.path.abspath(path.strip() + "../")) path_list.append(os.path.abspath(path.strip() + "../"))
# 返回path_list中唯一的路径
return np.unique(path_list) return np.unique(path_list)
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired: # 捕获subprocess.TimeoutExpired异常
logger.warning("Failed to check cuda version due to the ldd command timeout, please confirm that " # 记录警告日志提示用户确认cuda版本是否正确安装因为ldd命令超时
"the correct cuda version has been installed, you can refer to the " logger.warning("由于ldd命令超时无法检查cuda版本请确认已安装正确的cuda版本"
"installation guidelines: https://www.mindspore.cn/install") "您可以参考安装指南:https://www.mindspore.cn/install")
return path_list return path_list # 返回空路径列表
def _read_version(self, file_path): def _read_version(self, file_path):
"""Get gpu version info in version.txt.""" """Get gpu version info in version.txt."""
@ -211,70 +269,80 @@ class GPUEnvChecker(EnvChecker):
class AscendEnvChecker(EnvChecker): class AscendEnvChecker(EnvChecker):
"""ascend environment check""" """Ascend 环境检查类"""
def __init__(self): def __init__(self):
# 初始化 Ascend 环境检查器的版本列表
self.version = ["1.81"] self.version = ["1.81"]
# 定义不同路径下的 version.info 文件位置
atlas_nnae_version = "/usr/local/Ascend/nnae/latest/fwkacllib/version.info" atlas_nnae_version = "/usr/local/Ascend/nnae/latest/fwkacllib/version.info"
atlas_toolkit_version = "/usr/local/Ascend/ascend-toolkit/latest/fwkacllib/version.info" atlas_toolkit_version = "/usr/local/Ascend/ascend-toolkit/latest/fwkacllib/version.info"
hisi_fwk_version = "/usr/local/Ascend/latest/fwkacllib/version.info" hisi_fwk_version = "/usr/local/Ascend/latest/fwkacllib/version.info"
# 检查 Atlas NNAE 环境是否存在
if os.path.exists(atlas_nnae_version): if os.path.exists(atlas_nnae_version):
# atlas default path # 如果存在,设置默认路径
self.fwk_path = "/usr/local/Ascend/nnae/latest/fwkacllib" self.fwk_path = "/usr/local/Ascend/nnae/latest/fwkacllib" # Framework 路径
self.op_impl_path = "/usr/local/Ascend/nnae/latest/opp/op_impl/built-in/ai_core/tbe" self.op_impl_path = "/usr/local/Ascend/nnae/latest/opp/op_impl/built-in/ai_core/tbe" # Operator 实现路径
self.tbe_path = self.fwk_path + "/lib64" self.tbe_path = self.fwk_path + "/lib64" # TBE 库路径
self.cce_path = self.fwk_path + "/ccec_compiler/bin" self.cce_path = self.fwk_path + "/ccec_compiler/bin" # CCE 编译器路径
self.fwk_version = atlas_nnae_version self.fwk_version = atlas_nnae_version # Framework 版本文件路径
self.op_path = "/usr/local/Ascend/nnae/latest/opp" self.op_path = "/usr/local/Ascend/nnae/latest/opp" # Operator 路径
self.aicpu_path = "/usr/local/Ascend/nnae/latest" self.aicpu_path = "/usr/local/Ascend/nnae/latest" # AI CPU 路径
# 检查 Atlas Toolkit 环境是否存在
elif os.path.exists(atlas_toolkit_version): elif os.path.exists(atlas_toolkit_version):
# atlas default path # 如果存在,设置默认路径
self.fwk_path = "/usr/local/Ascend/ascend-toolkit/latest/fwkacllib" self.fwk_path = "/usr/local/Ascend/ascend-toolkit/latest/fwkacllib" # Framework 路径
self.op_impl_path = "/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe" self.op_impl_path = "/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe" # Operator 实现路径
self.tbe_path = self.fwk_path + "/lib64" self.tbe_path = self.fwk_path + "/lib64" # TBE 库路径
self.cce_path = self.fwk_path + "/ccec_compiler/bin" self.cce_path = self.fwk_path + "/ccec_compiler/bin" # CCE 编译器路径
self.fwk_version = atlas_toolkit_version self.fwk_version = atlas_toolkit_version # Framework 版本文件路径
self.op_path = "/usr/local/Ascend/ascend-toolkit/latest/opp" self.op_path = "/usr/local/Ascend/ascend-toolkit/latest/opp" # Operator 路径
self.aicpu_path = "/usr/local/Ascend/ascend-toolkit/latest" self.aicpu_path = "/usr/local/Ascend/ascend-toolkit/latest" # AI CPU 路径
# 检查 Hisi 环境是否存在
elif os.path.exists(hisi_fwk_version): elif os.path.exists(hisi_fwk_version):
# hisi default path # 如果存在,设置默认路径
self.fwk_path = "/usr/local/Ascend/latest/fwkacllib" self.fwk_path = "/usr/local/Ascend/latest/fwkacllib" # Framework 路径
self.op_impl_path = "/usr/local/Ascend/latest/opp/op_impl/built-in/ai_core/tbe" self.op_impl_path = "/usr/local/Ascend/latest/opp/op_impl/built-in/ai_core/tbe" # Operator 实现路径
self.tbe_path = self.fwk_path + "/lib64" self.tbe_path = self.fwk_path + "/lib64" # TBE 库路径
self.cce_path = self.fwk_path + "/ccec_compiler/bin" self.cce_path = self.fwk_path + "/ccec_compiler/bin" # CCE 编译器路径
self.fwk_version = hisi_fwk_version self.fwk_version = hisi_fwk_version # Framework 版本文件路径
self.op_path = "/usr/local/Ascend/latest/opp" self.op_path = "/usr/local/Ascend/latest/opp" # Operator 路径
self.aicpu_path = "/usr/local/Ascend/latest" self.aicpu_path = "/usr/local/Ascend/latest" # AI CPU 路径
else:
# custom or unknown environment
self.fwk_path = ""
self.op_impl_path = ""
self.tbe_path = ""
self.cce_path = ""
self.fwk_version = ""
self.op_path = ""
self.aicpu_path = ""
# env else:
# 如果以上环境都不存在,设置为空路径
self.fwk_path = "" # Framework 路径
self.op_impl_path = "" # Operator 实现路径
self.tbe_path = "" # TBE 库路径
self.cce_path = "" # CCE 编译器路径
self.fwk_version = "" # Framework 版本文件路径
self.op_path = "" # Operator 路径
self.aicpu_path = "" # AI CPU 路径
# 初始化环境变量
self.path = os.getenv("PATH") self.path = os.getenv("PATH")
self.python_path = os.getenv("PYTHONPATH") self.python_path = os.getenv("PYTHONPATH")
self.ld_lib_path = os.getenv("LD_LIBRARY_PATH") self.ld_lib_path = os.getenv("LD_LIBRARY_PATH")
self.ascend_opp_path = os.getenv("ASCEND_OPP_PATH") self.ascend_opp_path = os.getenv("ASCEND_OPP_PATH")
self.ascend_aicpu_path = os.getenv("ASCEND_AICPU_PATH") self.ascend_aicpu_path = os.getenv("ASCEND_AICPU_PATH")
# check content # 设置需要检查的路径内容
self.path_check = "/fwkacllib/ccec_compiler/bin" self.path_check = "/fwkacllib/ccec_compiler/bin"
self.python_path_check = "opp/op_impl/built-in/ai_core/tbe" self.python_path_check = "opp/op_impl/built-in/ai_core/tbe"
self.ld_lib_path_check_fwk = "/fwkacllib/lib64" self.ld_lib_path_check_fwk = "/fwkacllib/lib64"
self.ld_lib_path_check_addons = "/add-ons" self.ld_lib_path_check_addons = "/add-ons"
self.ascend_opp_path_check = "/op" self.ascend_opp_path_check = "/op"
self.v = "" self.v = ""
def check_env(self, e): def check_env(self, e):
self._check_env() self._check_env()
raise e raise e
def check_version(self): def check_version(self):
# 检查指定路径的版本文件是否存在,如果不存在则跳过版本检查
if not Path(self.fwk_version).is_file(): if not Path(self.fwk_version).is_file():
logger.warning("Using custom Ascend AI software package (Ascend Data Center Solution) path, package " logger.warning("Using custom Ascend AI software package (Ascend Data Center Solution) path, package "
"version checking is skipped, please make sure Ascend AI software package (Ascend Data " "version checking is skipped, please make sure Ascend AI software package (Ascend Data "
@ -282,36 +350,43 @@ class AscendEnvChecker(EnvChecker):
"https://www.mindspore.cn/install") "https://www.mindspore.cn/install")
return return
# 读取版本文件中的版本信息
v = self._read_version(self.fwk_version) v = self._read_version(self.fwk_version)
# 如果读取的版本不在支持的版本列表中,则记录警告信息
if v not in self.version: if v not in self.version:
v_list = str([x for x in self.version]) v_list = str([x for x in self.version])
logger.warning(f"MindSpore version {__version__} and Ascend AI software package (Ascend Data Center " logger.warning(f"MindSpore version {__version__} and Ascend AI software package (Ascend Data Center "
f"Solution)version {v} does not match, the version of software package expect one of " f"Solution)version {v} does not match, the version of software package expect one of "
f"{v_list}, please reference to the match info on: https://www.mindspore.cn/install") f"{v_list}, please reference to the match info on: https://www.mindspore.cn/install")
def check_deps_version(self): def check_deps_version(self):
""" """
te, topi, hccl wheel package version check te, topi, hccl wheel package version check
in order to update the change of 'LD_LIBRARY_PATH' env, run a sub process in order to update the change of 'LD_LIBRARY_PATH' env, run a sub process
""" """
# 构建输入参数列表包含mindspore版本和受支持的版本列表
input_args = ["--mindspore_version=" + __version__] input_args = ["--mindspore_version=" + __version__]
for v in self.version: for v in self.version:
input_args.append("--supported_version=" + v) input_args.append("--supported_version=" + v)
# 获取依赖版本检查脚本的路径
deps_version_checker = os.path.join(os.path.split(os.path.realpath(__file__))[0], deps_version_checker = os.path.join(os.path.split(os.path.realpath(__file__))[0],
"_check_deps_version.py") "_check_deps_version.py")
# 构建调用命令包括python解释器路径、脚本路径和输入参数
call_cmd = [sys.executable, deps_version_checker] + input_args call_cmd = [sys.executable, deps_version_checker] + input_args
try: try:
# 运行子进程进行版本检查设置超时时间为3秒并捕获输出
process = subprocess.run(call_cmd, timeout=3, text=True, capture_output=True, check=False) process = subprocess.run(call_cmd, timeout=3, text=True, capture_output=True, check=False)
# 如果子进程的输出不为空,则记录警告信息并进行倒计时提醒
if process.stdout.strip() != "": if process.stdout.strip() != "":
logger.warning(process.stdout.strip()) logger.warning(process.stdout.strip())
warning_countdown = 3 warning_countdown = 3
for i in range(warning_countdown, 0, -1): for i in range(warning_countdown, 0, -1):
logger.warning(f"Please pay attention to the above warning, countdown: {i}") logger.warning(f"Please pay attention to the above warning, countdown: {i}")
time.sleep(1) time.sleep(1)
# 如果版本检查超时,则记录信息并跳过
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
logger.info("Package te, topi, hccl version check timed out, skip.") logger.info("Package te, topi, hccl version check timed out, skip.")
def set_env(self): def set_env(self):
# 设置Ascend环境变量
if not self.tbe_path: if not self.tbe_path:
self._check_env() self._check_env()
return return
@ -330,24 +405,26 @@ class AscendEnvChecker(EnvChecker):
f"No such directory: {self.tbe_path}, Please check if Ascend AI software package (Ascend Data " f"No such directory: {self.tbe_path}, Please check if Ascend AI software package (Ascend Data "
"Center Solution) is installed correctly.") "Center Solution) is installed correctly.")
# check te version after set te env # 检查te版本
self.check_deps_version() self.check_deps_version()
# 设置op实现路径环境变量
if Path(self.op_impl_path).is_dir(): if Path(self.op_impl_path).is_dir():
# python path for sub process # python路径用于子进程
if os.getenv('PYTHONPATH'): if os.getenv('PYTHONPATH'):
os.environ['PYTHONPATH'] = self.op_impl_path + ":" + os.environ['PYTHONPATH'] os.environ['PYTHONPATH'] = self.op_impl_path + ":" + os.environ['PYTHONPATH']
else: else:
os.environ['PYTHONPATH'] = self.op_impl_path os.environ['PYTHONPATH'] = self.op_impl_path
# sys path for this process # sys路径用于当前进程
sys.path.append(self.op_impl_path) sys.path.append(self.op_impl_path)
os.environ['TBE_IMPL_PATH'] = self.op_impl_path os.environ['TBE_IMPL_PATH'] = self.op_impl_path
else: else:
raise EnvironmentError( raise EnvironmentError(
f"No such directory: {self.op_impl_path}, Please check if Ascend AI software package (Ascend Data " f"No such directory: {self.op_impl_path}, Please check if Ascend AI software package (Ascend Data Center "
"Center Solution) is installed correctly.") "Solution) is installed correctly.")
# 设置CCE路径环境变量
if Path(self.cce_path).is_dir(): if Path(self.cce_path).is_dir():
os.environ['PATH'] = self.cce_path + ":" + os.environ['PATH'] os.environ['PATH'] = self.cce_path + ":" + os.environ['PATH']
else: else:
@ -355,6 +432,7 @@ class AscendEnvChecker(EnvChecker):
f"No such directory: {self.cce_path}, Please check if Ascend AI software package (Ascend Data Center " f"No such directory: {self.cce_path}, Please check if Ascend AI software package (Ascend Data Center "
"Solution) is installed correctly.") "Solution) is installed correctly.")
# 设置OP路径环境变量
if self.op_path is None: if self.op_path is None:
pass pass
elif Path(self.op_path).is_dir(): elif Path(self.op_path).is_dir():
@ -364,6 +442,7 @@ class AscendEnvChecker(EnvChecker):
f"No such directory: {self.op_path}, Please check if Ascend AI software package (Ascend Data Center " f"No such directory: {self.op_path}, Please check if Ascend AI software package (Ascend Data Center "
"Solution) is installed correctly.") "Solution) is installed correctly.")
# 设置AICPU路径环境变量
if self.aicpu_path is None: if self.aicpu_path is None:
pass pass
elif Path(self.aicpu_path).is_dir(): elif Path(self.aicpu_path).is_dir():
@ -375,41 +454,51 @@ class AscendEnvChecker(EnvChecker):
def _check_env(self): def _check_env(self):
"""ascend dependence path check""" """ascend dependence path check"""
# 检查是否设置正确的PATH环境变量
if self.path is None or self.path_check not in self.path: if self.path is None or self.path_check not in self.path:
logger.warning("Can not find ccec_compiler(need by mindspore-ascend), please check if you have set env " logger.warning("Can not find ccec_compiler(need by mindspore-ascend), please check if you have set env "
"PATH, you can reference to the installation guidelines https://www.mindspore.cn/install") "PATH, you can reference to the installation guidelines https://www.mindspore.cn/install")
# 检查是否设置正确的PYTHONPATH环境变量
if self.python_path is None or self.python_path_check not in self.python_path: if self.python_path is None or self.python_path_check not in self.python_path:
logger.warning( logger.warning(
"Can not find tbe op implement(need by mindspore-ascend), please check if you have set env " "Can not find tbe op implement(need by mindspore-ascend), please check if you have set env "
"PYTHONPATH, you can reference to the installation guidelines " "PYTHONPATH, you can reference to the installation guidelines "
"https://www.mindspore.cn/install") "https://www.mindspore.cn/install")
# 检查是否设置正确的LD_LIBRARY_PATH环境变量
if self.ld_lib_path is None or not (self.ld_lib_path_check_fwk in self.ld_lib_path and if self.ld_lib_path is None or not (self.ld_lib_path_check_fwk in self.ld_lib_path and
self.ld_lib_path_check_addons in self.ld_lib_path): self.ld_lib_path_check_addons in self.ld_lib_path):
logger.warning("Can not find driver so(need by mindspore-ascend), please check if you have set env " logger.warning("Can not find driver so(need by mindspore-ascend), please check if you have set env "
"LD_LIBRARY_PATH, you can reference to the installation guidelines " "LD_LIBRARY_PATH, you can reference to the installation guidelines "
"https://www.mindspore.cn/install") "https://www.mindspore.cn/install")
# 检查是否设置正确的ASCEND_OPP_PATH环境变量
if self.ascend_opp_path is None or self.ascend_opp_path_check not in self.ascend_opp_path: if self.ascend_opp_path is None or self.ascend_opp_path_check not in self.ascend_opp_path:
logger.warning( logger.warning(
"Can not find opp path (need by mindspore-ascend), please check if you have set env ASCEND_OPP_PATH, " "Can not find opp path (need by mindspore-ascend), please check if you have set env ASCEND_OPP_PATH, "
"you can reference to the installation guidelines https://www.mindspore.cn/install") "you can reference to the installation guidelines https://www.mindspore.cn/install")
def _read_version(self, file_path): def _read_version(self, file_path):
"""get ascend version info""" """get ascend version info"""
with open(file_path, 'r') as f: with open(file_path, 'r') as f:
all_info = f.readlines() all_info = f.readlines()
# 遍历文件中的每一行
for line in all_info: for line in all_info:
# 检查行是否以 "Version=" 开头
if line.startswith("Version="): if line.startswith("Version="):
# 去除行末的换行符并按 "=" 分割, 获取版本号
full_version = line.strip().split("=")[1] full_version = line.strip().split("=")[1]
# 提取主版本号和次版本号, 并用 "." 连接
self.v = '.'.join(full_version.split('.')[0:2]) self.v = '.'.join(full_version.split('.')[0:2])
# 返回版本号
return self.v return self.v
# 如果未找到版本信息, 返回 None 或默认值
return self.v return self.v
def check_version_and_env_config(): def check_version_and_env_config():
"""check version and env config""" """检查版本和环境配置"""
# 检查包名以确定使用哪种环境检查器
if __package_name__.lower() == "mindspore-ascend": if __package_name__.lower() == "mindspore-ascend":
env_checker = AscendEnvChecker() env_checker = AscendEnvChecker()
# Note: pre-load libgomp.so to solve error like "cannot allocate memory in statis TLS block" # Note: pre-load libgomp.so to solve error like "cannot allocate memory in statis TLS block"
@ -425,19 +514,21 @@ def check_version_and_env_config():
else: else:
logger.info(f"Package version {__package_name__} does not need to check any environment variable, skipping.") logger.info(f"Package version {__package_name__} does not need to check any environment variable, skipping.")
return return
# 检查是否关闭版本检查,如果已关闭则直接返回
if os.getenv("MS_DEV_CLOSE_VERSION_CHECK") == "ON": if os.getenv("MS_DEV_CLOSE_VERSION_CHECK") == "ON":
return return
# 设置环境变量以关闭版本检查
os.environ["MS_DEV_CLOSE_VERSION_CHECK"] = "ON" os.environ["MS_DEV_CLOSE_VERSION_CHECK"] = "ON"
try: try:
# check version of ascend site or cuda # 检查 ascend site 或 cuda 的版本
env_checker.check_version() env_checker.check_version()
from .. import _c_expression # pylint: disable=unused-import from .. import _c_expression # pylint: disable=unused-import
# 设置环境
env_checker.set_env() env_checker.set_env()
except ImportError as e: except ImportError as e:
# 处理导入错误,检查环境
env_checker.check_env(e) env_checker.check_env(e)
def _set_pb_env(): def _set_pb_env():
"""Set env variable `PROTOCOL_BUFFERS` to prevent memory overflow.""" """Set env variable `PROTOCOL_BUFFERS` to prevent memory overflow."""
if os.getenv("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION") == "cpp": if os.getenv("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION") == "cpp":
@ -450,6 +541,8 @@ def _set_pb_env():
"during save or load checkpoint file.") "during save or load checkpoint file.")
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
# 检查版本和环境配置
check_version_and_env_config() check_version_and_env_config()
# 设置协议缓冲区的环境变量, 防止内存溢出
_set_pb_env() _set_pb_env()

@ -34,15 +34,18 @@ def _check_mul():
finally: finally:
pass pass
# 打印MindSpore版本信息
print(f"MindSpore version: ", ms.__version__) print(f"MindSpore version: ", ms.__version__)
# 创建两个MindSpore张量分别包含数组[1.0, 2.0, 3.0]和[4.0, 5.0, 6.0]
input_x = ms.Tensor(np.array([1.0, 2.0, 3.0]), ms.float32) input_x = ms.Tensor(np.array([1.0, 2.0, 3.0]), ms.float32)
input_y = ms.Tensor(np.array([4.0, 5.0, 6.0]), ms.float32) input_y = ms.Tensor(np.array([4.0, 5.0, 6.0]), ms.float32)
# 创建一个乘法操作对象
mul = ms.ops.Mul() mul = ms.ops.Mul()
# 执行乘法操作
mul(input_x, input_y) mul(input_x, input_y)
# 打印乘法计算结果正确MindSpore安装成功的信息
print(f"The result of multiplication calculation is correct, MindSpore has been installed successfully!") print(f"The result of multiplication calculation is correct, MindSpore has been installed successfully!")
def run_check(): def run_check():
""" """
Provide a convenient API to check if the installation is successful or failed. Provide a convenient API to check if the installation is successful or failed.
@ -55,10 +58,13 @@ def run_check():
The result of multiplication calculation is correct, MindSpore has been installed successfully! The result of multiplication calculation is correct, MindSpore has been installed successfully!
""" """
try: try:
# 尝试执行检查乘法操作的函数
_check_mul() _check_mul()
# pylint: disable=broad-except # pylint: disable=broad-except
# 捕获所有异常并打印错误信息
except Exception as e: except Exception as e:
print("MindSpore running check failed.") print("MindSpore running check failed.")
print(str(e)) print(str(e))
finally: finally:
pass # 无论是否发生异常,都会执行此部分代码
pass # 执行乘法检查的函数,并处理可能的异常情况。如果检查失败,打印错误信息。

@ -22,185 +22,171 @@ from ..common import dtype as mstype
from .utils_const import _type_convert, _raise_value_error, _callable_const, _super_check, pack from .utils_const import _type_convert, _raise_value_error, _callable_const, _super_check, pack
from ..ops.composite import GradOperation from ..ops.composite import GradOperation
grad = GradOperation(get_all=False, get_by_list=False, sens_param=False) grad = GradOperation(get_all=False, get_by_list=False, sens_param=False) # 定义一个求梯度操作,设置为只求第一个参数的梯度,不通过列表获取,且不使用敏感参数
_eps_net = ops.Eps() _eps_net = ops.Eps() # 定义一个计算数值精度的操作
def _convert_64_to_32(tensor): # 定义一个函数将输入的tensor从float64或int64类型转换为float32或int32类型
def _convert_64_to_32(tensor):
"""Convert Tensor with float64/int64 types to float32/int32.""" """Convert Tensor with float64/int64 types to float32/int32."""
if tensor.dtype == mstype.float64: if tensor.dtype == mstype.float64: # 如果tensor的数据类型是float64
return tensor.astype("float32") return tensor.astype("float32") # 将其转换为float32类型
if tensor.dtype == mstype.int64: if tensor.dtype == mstype.int64: # 如果tensor的数据类型是int64
return tensor.astype("int32") return tensor.astype("int32") # 将其转换为int32类型
return tensor return tensor # 如果不是以上两种类型则直接返回原tensor
def _to_tensor(*args, dtype=None): def _to_tensor(*args, dtype=None): # 定义一个函数将输入的参数转换为tensor
"""Returns each input as Tensor""" """Returns each input as Tensor"""
res = () res = () # 初始化一个空元组用于存储结果
for arg in args: for arg in args: # 遍历每一个输入参数
if isinstance(arg, (int, float, bool, list, tuple)): if isinstance(arg, (int, float, bool, list, tuple)): # 如果参数是整数、浮点数、布尔值、列表或元组
arg = _type_convert(Tensor, arg) arg = _type_convert(Tensor, arg) # 将其转换为Tensor类型
if dtype is None: if dtype is None: # 如果没有指定dtype
arg = _convert_64_to_32(arg) arg = _convert_64_to_32(arg) # 调用_convert_64_to_32函数进行类型转换
else: else: # 如果指定了dtype
arg = arg.astype(dtype) arg = arg.astype(dtype) # 将tensor转换为指定的dtype
elif not isinstance(arg, Tensor): elif not isinstance(arg, Tensor): # 如果参数不是Tensor类型
_raise_value_error("Expect input to be array like.") _raise_value_error("Expect input to be array like.") # 抛出错误,提示输入应为数组形式
res += (arg,) res += (arg,) # 将转换后的tensor添加到结果元组中
if len(res) == 1: if len(res) == 1: # 如果结果元组中只有一个元素
return res[0] return res[0] # 直接返回该元素
return res return res # 否则返回整个元组
def _to_scalar(arr): # 定义一个函数将输入的Tensor或ndarray转换为标量值
def _to_scalar(arr):
"""Convert a scalar Tensor or ndarray to a scalar.""" """Convert a scalar Tensor or ndarray to a scalar."""
if isinstance(arr, (int, float, bool)): if isinstance(arr, (int, float, bool)): # 如果输入参数是整数、浮点数或布尔值
return arr return arr # 直接返回该参数
if isinstance(arr, Tensor): if isinstance(arr, Tensor): # 如果输入参数是Tensor类型
if arr.shape: if arr.shape: # 如果tensor的形状不是空的即不是标量
return arr return arr # 返回整个tensor
return arr.asnumpy().item() return arr.asnumpy().item() # 如果是标量将其转换为numpy数组并返回标量值
raise ValueError("{} are not supported.".format(type(arr))) raise ValueError("{} are not supported.".format(type(arr))) # 如果输入参数不是以上两种类型,抛出错误,提示不支持该类型
def _eps(x): # 定义一个函数计算输入tensor的数值精度
def _eps(x): return _eps_net(x[(0,) * x.ndim]) # 使用_ops.Eps操作计算数值精度x[(0,) * x.ndim]确保输入的是一个标量
return _eps_net(x[(0,) * x.ndim])
def _safe_normalize(x, threshold=None): # 定义一个函数对输入的tensor进行归一化如果归一化结果非常小则设置为零
def _safe_normalize(x, threshold=None):
"""Normalize method that cast very small results to zero.""" """Normalize method that cast very small results to zero."""
x_sum2 = F.reduce_sum(F.pows(x, 2.0)) x_sum2 = F.reduce_sum(F.pows(x, 2.0)) # 计算tensor元素平方的和
norm = F.pows(x_sum2, 1. / 2.0) norm = F.pows(x_sum2, 1. / 2.0) # 计算上述和的平方根得到norm
if threshold is None: if threshold is None: # 如果没有指定threshold
if x.dtype in (mstype.float32, mstype.float64): if x.dtype in (mstype.float32, mstype.float64): # 如果tensor的dtype是float32或float64
# pick the first element of x to get the eps # pick the first element of x to get the eps # 获取eps来作为threshold
threshold = _eps(x) threshold = _eps(x)
else: else: # 如果tensor的dtype不是float32或float64
threshold = 0 threshold = 0 # 设置threshold为0
use_norm = greater(norm, threshold) use_norm = greater(norm, threshold) # 比较norm和threshold得到一个布尔mask
x_norm = x / norm x_norm = x / norm # 使用norm对tensor进行归一化
normalized_x = where(use_norm, x_norm, zeros_like(x)) normalized_x = where(use_norm, x_norm, zeros_like(x)) # 如果norm大于threshold则使用归一化后的tensor否则使用零
norm = where(use_norm, norm, zeros_like(norm)) norm = where(use_norm, norm, zeros_like(norm)) # 如果norm大于threshold则保留norm否则使用零
return normalized_x, norm return normalized_x, norm # 返回归一化后的tensor及其对应的norm
def sparse_dot(a, b): # 定义一个函数计算稀疏矩阵CSRTensor与向量generic Tensor的点积
def sparse_dot(a, b):
"""Returns the dot product of CSRTensor and generic Tensor(vector).""" """Returns the dot product of CSRTensor and generic Tensor(vector)."""
b_aligned = F.reshape(b, (b.shape[0], -1)) b_aligned = F.reshape(b, (b.shape[0], -1)) # 将向量b重塑为(b.shape[0], -1)的形状,使其可以与稀疏矩阵相乘
res = F.csr_mv(a, b_aligned) res = F.csr_mv(a, b_aligned) # 使用csr_mv操作计算稀疏矩阵a与向量b_aligned的点积
res = F.reshape(res, a.shape[:-1] + b.shape[1:]) res = F.reshape(res, a.shape[:-1] + b.shape[1:]) # 将计算结果重新塑形为a.shape[:-1] + b.shape[1:]的形状
return res return res # 返回结果
def _normalize_matvec(f): def _normalize_matvec(f): # 定义一个函数,对输入的矩阵或向量进行归一化处理
"""Normalize an argument for computing matrix-vector products.""" """Normalize an argument for computing matrix-vector products."""
if isinstance(f, Tensor): if isinstance(f, Tensor): # 如果输入参数是Tensor类型
return F.partial(dot, f) return F.partial(dot, f) # 返回一个带有矩阵参数f的dot函数的部分应用
if isinstance(f, CSRTensor):
return F.partial(sparse_dot, f)
return f
if isinstance(f, CSRTensor): # 如果输入参数是CSRTensor类型
return F.partial(sparse_dot, f) # 返回一个带有稀疏矩阵参数f的sparse_dot函数的部分应用
def _norm(x, ord_=None): return f # 如果输入参数不是上述两种类型,则直接返回原参数
if ord_ == mnp.inf:
res = mnp.max(mnp.abs(x))
else:
res = mnp.sqrt(mnp.sum(x ** 2))
return res
def _norm(x, ord_=None): # 定义一个函数计算输入tensor的范数
if ord_ == mnp.inf: # 如果ord_为无穷大实际为最大值
res = mnp.max(mnp.abs(x)) # 返回tensor绝对值的最大值
else: # 如果ord_不是无穷大
res = mnp.sqrt(mnp.sum(x ** 2)) # 返回tensor元素平方和的平方根即L2范数
return res # 返回结果
def _nd_transpose(a): def _nd_transpose(a): # 定义一个函数对输入的tensor进行转置最后一个维度与倒数第二个维度互换
dims = a.ndim dims = a.ndim # 获取tensor的维度数
if dims < 2: if dims < 2: # 如果tensor的维度小于2
_raise_value_error("to do _nd_transpose for input a's ndim is not greater or equal to 2d, which is invalid.") _raise_value_error("to do _nd_transpose for input a's ndim is not greater or equal to 2d, which is invalid.") # 抛出错误提示输入tensor的维度应大于等于2
axes = ops.make_range(0, dims) axes = ops.make_range(0, dims) # 生成一个从0到tensor维度数的序列
axes = axes[:-2] + (axes[-1],) + (axes[-2],) axes = axes[:-2] + (axes[-1],) + (axes[-2],) # 将序列中的倒数第二个和最后一个元素互换位置
return ops.transpose(a, axes) return ops.transpose(a, axes) # 使用transpose操作对tensor进行转置
def _value_check(func_name, arg1, arg2, arg_name='', attr_name='', op="in", fmt="attr", msg=None): # 定义一个函数,用于检查输入参数的值是否符合预期
return _super_check(pack(arg1, arg2), (func_name, arg_name, attr_name), op, fmt, msg, True) # 调用_super_check函数进行检查
def _value_check(func_name, arg1, arg2, arg_name='', attr_name='', op="in", fmt="attr", msg=None): def _type_check(func_name, arg1, arg2, arg_name='', op="isinstance", fmt="type", msg=None): # 定义一个函数,用于检查输入参数的类型是否符合预期
return _super_check(pack(arg1, arg2), (func_name, arg_name, attr_name), op, fmt, msg, True) return _super_check(pack(arg1, arg2), (func_name, arg_name), op, fmt, msg, False) # 调用_super_check函数进行检查
def _mstype_check(func_name, arg, arg_mstype, arg_name='a'): # 定义一个函数用于检查输入参数的mstype是否符合预期
def _type_check(func_name, arg1, arg2, arg_name='', op="isinstance", fmt="type", msg=None): return _super_check((F.typeof(arg), arg_mstype), pack(arg, arg_mstype, func_name, arg_name), "isinstance", "mstype", # 调用_super_check函数进行检查
return _super_check(pack(arg1, arg2), (func_name, arg_name), op, fmt, msg, False)
def _mstype_check(func_name, arg, arg_mstype, arg_name='a'):
return _super_check((F.typeof(arg), arg_mstype), pack(arg, arg_mstype, func_name, arg_name), "isinstance", "mstype",
None, False) None, False)
def _dtype_check(func_name, arg, arg_dtype, arg_name='a'): # 定义一个函数,用于检查输入参数的数据类型是否符合预期
return _super_check((F.dtype(arg), arg_dtype), (func_name, arg_name, "data type"), "in", "attr", # 调用_super_check函数进行检查
None, False)
def _dtype_check(func_name, arg, arg_dtype, arg_name='a'): def _square_check(func_name, arg, arg_name='a'): # 定义一个函数,用于检查输入参数是否为方阵
return _super_check((F.dtype(arg), arg_dtype), (func_name, arg_name, "data type"), "in", "attr", None, False) arg_shape = arg.shape # 获取输入参数的形状
_super_check((len(arg_shape), 2), (func_name, arg_name, 'dimension'), '==', 'attr', None, True) # 检查输入参数的维度是否为2
_super_check(arg_shape, (func_name, arg_name), '==', 'square', None, True) # 检查输入参数的形状是否为方阵
def _square_check(func_name, arg, arg_name='a'): return arg # 返回检查后的参数
arg_shape = arg.shape
_super_check((len(arg_shape), 2), (func_name, arg_name, 'dimension'), '==', 'attr', None, True) def _solve_check(func_name, arg1, arg2, arg1_name='a', arg2_name='b', sparse=False): # 定义一个函数,用于在求解线性方程组时检查输入参数
_super_check(arg_shape, (func_name, arg_name), '==', 'square', None, True) arg1_shape, arg1_dtype = arg1.shape, F.dtype(arg1) # 获取第一个参数的形状和数据类型
return arg arg2_shape, arg2_dtype = arg2.shape, F.dtype(arg2) # 获取第二个参数的形状和数据类型
_square_check(func_name, arg1, arg1_name) # 检查第一个参数是否为方阵
_super_check((len(arg2_shape), (1, 2)), (func_name, arg2_name, 'dimension'), 'in', 'attr', None, True) # 检查第二个参数的维度是否为1或2
def _solve_check(func_name, arg1, arg2, arg1_name='a', arg2_name='b', sparse=False): _super_check((arg1_shape, arg2_shape), (func_name, arg1_name, arg2_name, sparse), 'solve', 'solve', None, True) # 检查第一个参数和第二个参数的形状是否可以用于求解线性方程组
arg1_shape, arg1_dtype = arg1.shape, F.dtype(arg1) _super_check((arg1_dtype, arg2_dtype), (func_name, arg1_name, arg2_name, 'data type'), '==', 'match', None, False) # 检查第一个参数和第二个参数的数据类型是否匹配
arg2_shape, arg2_dtype = arg2.shape, F.dtype(arg2) return arg1, arg2 # 返回检查后的两个参数
_square_check(func_name, arg1, arg1_name)
_super_check((len(arg2_shape), (1, 2)), (func_name, arg2_name, 'dimension'), 'in', 'attr', None, True) def _sparse_check(func_name, a, m, b, x0): # 定义一个函数用于在稀疏求解器如cg, bicgstab和gmres中检查输入参数
_super_check((arg1_shape, arg2_shape), (func_name, arg1_name, arg2_name, sparse), 'solve', 'solve', None, True)
_super_check((arg1_dtype, arg2_dtype), (func_name, arg1_name, arg2_name, 'data type'), '==', 'match', None, False)
return arg1, arg2
def _sparse_check(func_name, a, m, b, x0):
"""Used for cg, bicgstab and gmres method.""" """Used for cg, bicgstab and gmres method."""
def _check_right(arg, arg_name): def _check_right(arg, arg_name): # 定义一个内部函数用于检查右侧参数b或x0
if arg is None: if arg is None: # 如果参数为None
return mnp.zeros_like(b) # x0 same as b return mnp.zeros_like(b) # x0 same as b # 返回与b形状相同元素为零的tensor
# Type # Type
_mstype_check(func_name, arg, mstype.tensor_type, arg_name) _mstype_check(func_name, arg, mstype.tensor_type, arg_name) # 检查参数的mstype是否为tensor_type
# DType # DType
_dtype_check(func_name, arg, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], arg_name) _dtype_check(func_name, arg, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], arg_name) # 检查参数的数据类型是否在指定的类型列表中
# Shape # Shape
if (arg.ndim != 1 and arg.ndim != 2) or (arg.ndim == 2 and arg.shape[1] != 1): if (arg.ndim != 1 and arg.ndim != 2) or (arg.ndim == 2 and arg.shape[1] != 1): # 检查参数的形状是否为(N,)或(N, 1)
_raise_value_error("For: '", func_name, "', the shape of '", arg_name, _raise_value_error("For: '", func_name, "', the shape of '", arg_name, # 如果不满足条件,抛出错误
"' should be like (N,) or (N, 1), bug got ", arg.shape, ".") "' should be like (N,) or (N, 1), bug got ", arg.shape, ".")
return arg return arg # 返回检查后的参数
b = _check_right(b, 'b') b = _check_right(b, 'b') # 检查参数b
x0 = _check_right(x0, 'x0') x0 = _check_right(x0, 'x0') # 检查参数x0
def _check_left(arg, arg_name): def _check_left(arg, arg_name): # 定义一个内部函数用于检查左侧参数a或m
if arg is None: if arg is None: # 如果参数为None
return lambda x: x # identity function return lambda x: x # identity function # 返回一个恒等函数
# Type # Type
_mstype_check(func_name, arg, [mstype.function_type, mstype.tensor_type, mstype.csr_tensor_type], arg_name) _mstype_check(func_name, arg, [mstype.function_type, mstype.tensor_type, mstype.csr_tensor_type], arg_name) # 检查参数的mstype是否为function_type, tensor_type或csr_tensor_type
if _callable_const(F.typeof(arg)): if _callable_const(F.typeof(arg)): # 如果参数是一个可调用的常量(即函数)
return arg return arg # 返回该参数
# DType # DType
if isinstance(arg, CSRTensor): if isinstance(arg, CSRTensor): # 如果参数是CSRTensor类型
_dtype_check(func_name, arg.indptr, [mstype.int32], arg_name) _dtype_check(func_name, arg.indptr, [mstype.int32], arg_name) # 检查CSRTensor的indptr数据类型是否为int32
_dtype_check(func_name, arg.indices, [mstype.int32], arg_name) _dtype_check(func_name, arg.indices, [mstype.int32], arg_name) # 检查CSRTensor的indices数据类型是否为int32
_dtype_check(func_name, arg.values, [mstype.float32], arg_name) _dtype_check(func_name, arg.values, [mstype.float32], arg_name) # 检查CSRTensor的values数据类型是否为float32
else: else: # 如果参数不是CSRTensor类型
_dtype_check(func_name, arg, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], arg_name) _dtype_check(func_name, arg, [mstype.int32, mstype.int64, mstype.float32, mstype.float64], arg_name) # 检查参数的数据类型是否在指定的类型列表中
# Shape # Shape
_solve_check(func_name, arg, b, arg_name, 'b', True) _solve_check(func_name, arg, b, arg_name, 'b', True) # 检查参数a和b的形状是否可以用于求解线性方程组
_solve_check(func_name, arg, x0, arg_name, 'x0', True) _solve_check(func_name, arg, x0, arg_name, 'x0', True) # 检查参数a和x0的形状是否可以用于求解线性方程组
if isinstance(arg, Tensor) and F.dtype(arg) in (mstype.int32, mstype.int64): if isinstance(arg, Tensor) and F.dtype(arg) in (mstype.int32, mstype.int64): # 如果参数是Tensor类型且数据类型为int32或int64
arg = F.cast(arg, mstype.float64) arg = F.cast(arg, mstype.float64) # 将其转换为float64类型
return arg return arg # 返回检查后的参数
a = _check_left(a, 'A') a = _check_left(a, 'A') # 检查参数a
m = _check_left(m, 'M') m = _check_left(m, 'M') # 检查参数m
b = b.flatten() b = b.flatten() # 将参数b展平为一维的tensor
x0 = x0.flatten() x0 = x0.flatten() # 将参数x0展平为一维的tensor
if F.dtype(b) in (mstype.int32, mstype.int64): if F.dtype(b) in (mstype.int32, mstype.int64): # 如果参数b的数据类型为int32或int64
b = F.cast(b, mstype.float64) b = F.cast(b, mstype.float64) # 将其转换为float64类型
x0 = F.cast(x0, mstype.float64) x0 = F.cast(x0, mstype.float64) # 将其转换为float64类型
return a, m, b, x0 return a, m, b, x0 # 返回检查并转换后的参数

@ -366,31 +366,46 @@ class ModelCheckpoint(Callback):
""" """
def __init__(self, prefix='CKP', directory=None, config=None): def __init__(self, prefix='CKP', directory=None, config=None):
# 初始化函数,设置前缀、目录、配置等参数
super(ModelCheckpoint, self).__init__() super(ModelCheckpoint, self).__init__()
# 调用父类的初始化函数
self._latest_ckpt_file_name = "" self._latest_ckpt_file_name = ""
# 初始化最新检查点文件名为空字符串
self._init_time = time.time() self._init_time = time.time()
# 初始化初始化时间为当前时间
self._last_time = time.time() self._last_time = time.time()
# 初始化最后时间时间为当前时间
self._last_time_for_keep = time.time() self._last_time_for_keep = time.time()
# 初始化最后保存时间为当前时间
self._last_triggered_step = 0 self._last_triggered_step = 0
# 初始化最后触发的步数为0
# 检查前缀是否为字符串且不包含'/'
if not isinstance(prefix, str) or prefix.find('/') >= 0: if not isinstance(prefix, str) or prefix.find('/') >= 0:
raise ValueError("For 'ModelCheckpoint', the argument 'prefix' " raise ValueError("For 'ModelCheckpoint', the argument 'prefix' "
"for checkpoint file name is invalid, it must be " "for checkpoint file name is invalid, it must be "
"string and does not contain '/', but got {}.".format(prefix)) "string and does not contain '/', but got {}.".format(prefix))
self._prefix = prefix self._prefix = prefix
# 设置前缀
self._exception_prefix = prefix self._exception_prefix = prefix
# 设置异常前缀
# 如果目录不为空,则创建目录
if directory is not None: if directory is not None:
self._directory = _make_directory(directory) self._directory = _make_directory(directory)
else: else:
self._directory = _cur_dir self._directory = _cur_dir
# 否则,使用当前目录
# 如果启用了恢复上下文,则设置检查点路径
if _get_recovery_context("enable_recovery"): if _get_recovery_context("enable_recovery"):
_set_recovery_context(ckpt_path=self._directory) _set_recovery_context(ckpt_path=self._directory)
# 如果config为None则使用默认的CheckpointConfig
if config is None: if config is None:
self._config = CheckpointConfig() self._config = CheckpointConfig()
else: else:
# 如果config不是CheckpointConfig类型则抛出TypeError异常
if not isinstance(config, CheckpointConfig): if not isinstance(config, CheckpointConfig):
raise TypeError("For 'ModelCheckpoint', the type of argument 'config' should be " raise TypeError("For 'ModelCheckpoint', the type of argument 'config' should be "
"'CheckpointConfig', " "'CheckpointConfig', "
@ -398,11 +413,17 @@ class ModelCheckpoint(Callback):
self._config = config self._config = config
# get existing checkpoint files # get existing checkpoint files
# 创建CheckpointManager对象
self._manager = CheckpointManager() self._manager = CheckpointManager()
# 如果存在相同名称的文件,则更改文件名
self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix) self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
# 获取配置中的append_dict参数如果没有则设置为空字典
self._append_dict = self._config.append_dict or {} self._append_dict = self._config.append_dict or {}
# 获取append_dict中的epoch_num参数如果没有则设置为0
self._append_epoch_num = self._append_dict["epoch_num"] if "epoch_num" in self._append_dict else 0 self._append_epoch_num = self._append_dict["epoch_num"] if "epoch_num" in self._append_dict else 0
# 获取append_dict中的step_num参数如果没有则设置为0
self._append_step_num = self._append_dict["step_num"] if "step_num" in self._append_dict else 0 self._append_step_num = self._append_dict["step_num"] if "step_num" in self._append_dict else 0
# 标记是否已经保存了图
self._graph_saved = False self._graph_saved = False
self._need_flush_from_cache = True self._need_flush_from_cache = True
@ -413,6 +434,7 @@ class ModelCheckpoint(Callback):
Args: Args:
run_context (RunContext): Context of the train running. run_context (RunContext): Context of the train running.
""" """
# If the role is PServer, add the role name and rank to the prefix
if _is_role_pserver(): if _is_role_pserver():
self._prefix = "PServer_" + str(_get_ps_mode_rank()) + "_" + self._prefix self._prefix = "PServer_" + str(_get_ps_mode_rank()) + "_" + self._prefix
cb_params = run_context.original_args() cb_params = run_context.original_args()
@ -423,18 +445,23 @@ class ModelCheckpoint(Callback):
self._last_triggered_step = cb_params.last_save_ckpt_step self._last_triggered_step = cb_params.last_save_ckpt_step
cb_params.last_save_ckpt_step = None cb_params.last_save_ckpt_step = None
# Create the directory if it doesn't exist
_make_directory(self._directory) _make_directory(self._directory)
# save graph (only once) # save graph (only once)
if not self._graph_saved: if not self._graph_saved:
graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta') graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
# If the graph file already exists and the mode is GRAPH_MODE, remove it
if os.path.isfile(graph_file_name) and context.get_context("mode") == context.GRAPH_MODE: if os.path.isfile(graph_file_name) and context.get_context("mode") == context.GRAPH_MODE:
os.remove(graph_file_name) os.remove(graph_file_name)
# Save the graph
_save_graph(cb_params.train_network, graph_file_name) _save_graph(cb_params.train_network, graph_file_name)
self._graph_saved = True self._graph_saved = True
# Wait for any asynchronous checkpoint saving threads to finish
thread_list = threading.enumerate() thread_list = threading.enumerate()
for thread in thread_list: for thread in thread_list:
if thread.getName() == "asyn_save_ckpt": if thread.getName() == "asyn_save_ckpt":
thread.join() thread.join()
# Save the checkpoint
self._save_ckpt(cb_params) self._save_ckpt(cb_params)
def end(self, run_context): def end(self, run_context):
@ -444,44 +471,63 @@ class ModelCheckpoint(Callback):
Args: Args:
run_context (RunContext): Context of the train running. run_context (RunContext): Context of the train running.
""" """
# 获取训练的参数
cb_params = run_context.original_args() cb_params = run_context.original_args()
# 设置保存最后一个checkpoint的标志为True
_to_save_last_ckpt = True _to_save_last_ckpt = True
# 保存最后一个checkpoint
self._save_ckpt(cb_params, _to_save_last_ckpt) self._save_ckpt(cb_params, _to_save_last_ckpt)
# 获取当前线程列表
thread_list = threading.enumerate() thread_list = threading.enumerate()
# 遍历线程列表
for thread in thread_list: for thread in thread_list:
# 如果线程名为"asyn_save_ckpt",则等待该线程结束
if thread.getName() == "asyn_save_ckpt": if thread.getName() == "asyn_save_ckpt":
thread.join() thread.join()
# 销毁所有gather cell
destroy_allgather_cell() destroy_allgather_cell()
def _check_save_ckpt(self, cb_params, force_to_save): def _check_save_ckpt(self, cb_params, force_to_save):
"""Check whether save checkpoint files or not.""" """Check whether save checkpoint files or not."""
# 如果配置了保存检查点步数且步数大于0
if self._config.save_checkpoint_steps and self._config.save_checkpoint_steps > 0: if self._config.save_checkpoint_steps and self._config.save_checkpoint_steps > 0:
# 如果当前步数大于等于上次触发保存检查点步数加上保存检查点步数,或者强制保存检查点
if cb_params.cur_step_num >= self._last_triggered_step + self._config.save_checkpoint_steps \ if cb_params.cur_step_num >= self._last_triggered_step + self._config.save_checkpoint_steps \
or force_to_save is True: or force_to_save is True:
return True return True
# 如果配置了保存检查点秒数且秒数大于0
elif self._config.save_checkpoint_seconds and self._config.save_checkpoint_seconds > 0: elif self._config.save_checkpoint_seconds and self._config.save_checkpoint_seconds > 0:
# 获取当前时间
self._cur_time = time.time() self._cur_time = time.time()
# 如果当前时间减去上次时间大于保存检查点秒数,或者强制保存检查点
if (self._cur_time - self._last_time) > self._config.save_checkpoint_seconds or force_to_save: if (self._cur_time - self._last_time) > self._config.save_checkpoint_seconds or force_to_save:
# 更新上次时间
self._last_time = self._cur_time self._last_time = self._cur_time
return True return True
# 返回False
return False return False
def _save_ckpt(self, cb_params, force_to_save=False): def _save_ckpt(self, cb_params, force_to_save=False):
"""Save checkpoint files.""" """Save checkpoint files."""
# 如果当前步骤数等于最后触发的步骤数,则返回
if cb_params.cur_step_num == self._last_triggered_step: if cb_params.cur_step_num == self._last_triggered_step:
return return
# if param is cache enable, flush data from cache to host before save_ckpt # if param is cache enable, flush data from cache to host before save_ckpt
# 如果需要从缓存中刷新数据则调用_flush_from_cache方法
if self._need_flush_from_cache: if self._need_flush_from_cache:
self._flush_from_cache(cb_params) self._flush_from_cache(cb_params)
# 检查是否需要保存检查点如果force_to_save为True则强制保存
save_ckpt = self._check_save_ckpt(cb_params, force_to_save) save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
# 计算当前步数在epoch中的位置
step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1) step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
# 如果需要保存检查点,则创建当前检查点的文件名
if save_ckpt: if save_ckpt:
cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \ cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
+ str(step_num_in_epoch) + ".ckpt" + str(step_num_in_epoch) + ".ckpt"
@ -489,43 +535,68 @@ class ModelCheckpoint(Callback):
self._manager.update_ckpoint_filelist(self._directory, self._prefix) self._manager.update_ckpoint_filelist(self._directory, self._prefix)
# keep checkpoint files number equal max number. # keep checkpoint files number equal max number.
if self._config.keep_checkpoint_max and 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num: if self._config.keep_checkpoint_max and 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num:
# 如果keep_checkpoint_max配置存在且大于0且小于等于当前checkpoint文件数量则删除最旧的checkpoint文件
self._manager.remove_oldest_ckpoint_file() self._manager.remove_oldest_ckpoint_file()
elif self._config.keep_checkpoint_per_n_minutes and self._config.keep_checkpoint_per_n_minutes > 0: elif self._config.keep_checkpoint_per_n_minutes and self._config.keep_checkpoint_per_n_minutes > 0:
# 如果keep_checkpoint_per_n_minutes配置存在且大于0则记录当前时间
self._cur_time_for_keep = time.time() self._cur_time_for_keep = time.time()
# 如果当前时间与上次记录的时间之差小于keep_checkpoint_per_n_minutes配置的分钟数乘以60则保留每个分钟的一个checkpoint文件
if (self._cur_time_for_keep - self._last_time_for_keep) \ if (self._cur_time_for_keep - self._last_time_for_keep) \
< self._config.keep_checkpoint_per_n_minutes * 60: < self._config.keep_checkpoint_per_n_minutes * 60:
self._manager.keep_one_ckpoint_per_minutes(self._config.keep_checkpoint_per_n_minutes, self._manager.keep_one_ckpoint_per_minutes(self._config.keep_checkpoint_per_n_minutes,
self._cur_time_for_keep) self._cur_time_for_keep)
# generate the new checkpoint file and rename it. # generate the new checkpoint file and rename it.
# 定义全局变量_save_dir并将其赋值为self._directory
global _save_dir global _save_dir
_save_dir = self._directory _save_dir = self._directory
# 获取当前checkpoint文件的路径
cur_file = os.path.join(self._directory, cur_ckpoint_file) cur_file = os.path.join(self._directory, cur_ckpoint_file)
# 记录当前时间
self._last_time_for_keep = time.time() self._last_time_for_keep = time.time()
# 记录当前触发步数
self._last_triggered_step = cb_params.cur_step_num self._last_triggered_step = cb_params.cur_step_num
# 如果启用了GEGraph Execution
if context.get_context("enable_ge"): if context.get_context("enable_ge"):
# 设置当前网络
set_cur_net(cb_params.train_network) set_cur_net(cb_params.train_network)
# 执行checkpoint图
cb_params.train_network.exec_checkpoint_graph() cb_params.train_network.exec_checkpoint_graph()
# 如果_append_dict中包含"epoch_num"
if "epoch_num" in self._append_dict: if "epoch_num" in self._append_dict:
# 将_append_epoch_num加上当前epoch数赋值给"epoch_num"
self._append_dict["epoch_num"] = self._append_epoch_num + cb_params.cur_epoch_num self._append_dict["epoch_num"] = self._append_epoch_num + cb_params.cur_epoch_num
# 如果_append_dict中包含"step_num"
if "step_num" in self._append_dict: if "step_num" in self._append_dict:
# 将_append_step_num加上当前step数赋值给"step_num"
self._append_dict["step_num"] = self._append_step_num + cb_params.cur_step_num self._append_dict["step_num"] = self._append_step_num + cb_params.cur_step_num
# 获取保存的网络如果self._config.saved_network不为None则使用self._config.saved_network否则使用cb_params.train_network
network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network
# 保存checkpoint
save_checkpoint(network, cur_file, self._config.integrated_save, self._config.async_save, save_checkpoint(network, cur_file, self._config.integrated_save, self._config.async_save,
self._append_dict, self._config.enc_key, self._config.enc_mode) self._append_dict, self._config.enc_key, self._config.enc_mode)
# 记录最新的checkpoint文件名
self._latest_ckpt_file_name = cur_file self._latest_ckpt_file_name = cur_file
def _flush_from_cache(self, cb_params): def _flush_from_cache(self, cb_params):
"""Flush cache data to host if tensor is cache enable.""" """Flush cache data to host if tensor is cache enable."""
# 初始化has_cache_params为False
has_cache_params = False has_cache_params = False
# 获取训练网络中的参数
params = cb_params.train_network.get_parameters() params = cb_params.train_network.get_parameters()
# 遍历参数
for param in params: for param in params:
# 如果参数的cache_enable为True
if param.cache_enable: if param.cache_enable:
# 设置has_cache_params为True
has_cache_params = True has_cache_params = True
# 将参数的Tensor数据从缓存中刷新到主机
Tensor(param).flush_from_cache() Tensor(param).flush_from_cache()
# 如果没有参数的cache_enable为True
if not has_cache_params: if not has_cache_params:
# 设置_need_flush_from_cache为False
self._need_flush_from_cache = False self._need_flush_from_cache = False
@property @property
@ -535,63 +606,88 @@ class ModelCheckpoint(Callback):
class CheckpointManager: class CheckpointManager:
"""Manage checkpoint files according to train_config of checkpoint.""" """管理检查点文件,根据训练配置进行管理。"""
def __init__(self): def __init__(self):
"""初始化检查点管理器,创建空的检查点文件列表。"""
self._ckpoint_filelist = [] self._ckpoint_filelist = []
@property @property
def ckpoint_filelist(self): def ckpoint_filelist(self):
"""Get all the related checkpoint files managed here.""" """获取当前管理的所有检查点文件列表。"""
return self._ckpoint_filelist return self._ckpoint_filelist
@property @property
def ckpoint_num(self): def ckpoint_num(self):
"""Get the number of the related checkpoint files managed here.""" """获取当前管理的检查点文件数量。"""
return len(self._ckpoint_filelist) return len(self._ckpoint_filelist)
def update_ckpoint_filelist(self, directory, prefix): def update_ckpoint_filelist(self, directory, prefix):
"""Update the checkpoint file list.""" """更新检查点文件列表,根据目录和前缀筛选符合条件的检查点文件。"""
# 初始化一个空列表用于存储ckpt文件
self._ckpoint_filelist = [] self._ckpoint_filelist = []
# 获取指定目录下的所有文件
files = os.listdir(directory) files = os.listdir(directory)
# 遍历所有文件
for filename in files: for filename in files:
# 判断文件是否以指定前缀开头,并且以.ckpt结尾
if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix + "-"): if os.path.splitext(filename)[-1] == ".ckpt" and filename.startswith(prefix + "-"):
# 获取文件名中间部分
mid_name = filename[len(prefix):-5] mid_name = filename[len(prefix):-5]
# 判断中间部分是否包含字母
flag = not (True in [char.isalpha() for char in mid_name]) flag = not (True in [char.isalpha() for char in mid_name])
# 如果不包含字母,则将文件路径添加到列表中
if flag: if flag:
self._ckpoint_filelist.append(os.path.join(directory, filename)) self._ckpoint_filelist.append(os.path.join(directory, filename))
def remove_ckpoint_file(self, file_name): def remove_ckpoint_file(self, file_name):
"""Remove the specified checkpoint file from this checkpoint manager and also from the directory.""" """从检查点管理器中移除指定的检查点文件,并从目录中删除该文件。"""
try: try:
# 修改文件权限为可写
os.chmod(file_name, stat.S_IWRITE) os.chmod(file_name, stat.S_IWRITE)
# 删除文件
os.remove(file_name) os.remove(file_name)
# 从ckpoint文件列表中移除该文件
self._ckpoint_filelist.remove(file_name) self._ckpoint_filelist.remove(file_name)
except OSError: except OSError:
# 捕获OSError异常并记录警告日志
logger.warning("OSError, failed to remove the older ckpt file %s.", file_name) logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
except ValueError: except ValueError:
# 捕获ValueError异常并记录警告日志
logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name) logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
def remove_oldest_ckpoint_file(self): def remove_oldest_ckpoint_file(self):
"""Remove the oldest checkpoint file from this checkpoint manager and also from the directory.""" """移除检查点管理器中最早的检查点文件,并从目录中删除该文件。"""
# 获取所有checkpoint文件并按修改时间排序
ckpoint_files = sorted(self._ckpoint_filelist, key=os.path.getmtime) ckpoint_files = sorted(self._ckpoint_filelist, key=os.path.getmtime)
# 删除最早修改的checkpoint文件
self.remove_ckpoint_file(ckpoint_files[0]) self.remove_ckpoint_file(ckpoint_files[0])
def keep_one_ckpoint_per_minutes(self, minutes, cur_time): def keep_one_ckpoint_per_minutes(self, minutes, cur_time):
"""Only keep the latest one ckpt file per minutes, remove other files generated in [last_time, cur_time].""" """保留每分钟生成的最新检查点文件,移除在指定时间范围内生成的其他文件。"""
# 定义一个空列表,用于存储需要删除的文件
del_list = [] del_list = []
# 定义一个空字符串,用于存储最旧的文件名
oldest_file = '' oldest_file = ''
# 定义一个变量,用于存储当前时间
oldest_time = cur_time oldest_time = cur_time
# 遍历_ckpoint_filelist中的文件
for ck_file in self._ckpoint_filelist: for ck_file in self._ckpoint_filelist:
# 获取文件的修改时间
modify_time = os.path.getmtime(ck_file) modify_time = os.path.getmtime(ck_file)
# 如果当前时间减去文件的修改时间小于60*minutes则将文件添加到del_list中
if cur_time - modify_time < 60 * minutes: if cur_time - modify_time < 60 * minutes:
del_list.append(ck_file) del_list.append(ck_file)
# 如果文件的修改时间小于oldest_time则更新oldest_time和oldest_file
if modify_time < oldest_time: if modify_time < oldest_time:
oldest_time = modify_time oldest_time = modify_time
oldest_file = ck_file oldest_file = ck_file
# 遍历del_list中的文件
for mv_file in del_list: for mv_file in del_list:
# 如果文件是最旧的文件,则跳过
if mv_file == oldest_file: if mv_file == oldest_file:
continue continue
# 调用remove_ckpoint_file方法删除文件
self.remove_ckpoint_file(mv_file) self.remove_ckpoint_file(mv_file)

@ -256,36 +256,53 @@ class DatasetHelper:
""" """
def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1): def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1):
# 检查dataset_sink_mode是否为布尔值
dataset_sink_mode = Validator.check_bool(dataset_sink_mode) dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
# 检查sink_size是否为整数
Validator.check_is_int(sink_size) Validator.check_is_int(sink_size)
# 如果sink_size小于-1或者等于0抛出异常
if sink_size < -1 or sink_size == 0: if sink_size < -1 or sink_size == 0:
raise ValueError("The 'sink_size' must be -1 or positive, but got sink_size {}.".format(sink_size)) raise ValueError("The 'sink_size' must be -1 or positive, but got sink_size {}.".format(sink_size))
# 如果sink_size等于-1则将其设置为dataset的dataset_size
if sink_size == -1: if sink_size == -1:
sink_size = dataset.get_dataset_size() sink_size = dataset.get_dataset_size()
# 如果dataset_sink_mode为True则根据不同的设备类型选择不同的迭代器
if dataset_sink_mode: if dataset_sink_mode:
# 如果启用了GE则使用GE的迭代器
if context.get_context("enable_ge"): if context.get_context("enable_ge"):
iterclass = _DatasetIterGE iterclass = _DatasetIterGE
else: else:
# 如果当前模式为GRAPH_MODE则根据角色选择不同的迭代器
if context.get_context("mode") == context.GRAPH_MODE: if context.get_context("mode") == context.GRAPH_MODE:
# 如果当前角色为调度器或者参数服务器,则使用参数服务器的迭代器
if _is_role_sched() or _is_role_pserver(): if _is_role_sched() or _is_role_pserver():
iterclass = _DatasetIterPSServer iterclass = _DatasetIterPSServer
# 如果当前角色为工作节点并且是参数服务器模式,则使用参数服务器工作节点的迭代器
elif _is_role_worker() and _is_ps_mode(): elif _is_role_worker() and _is_ps_mode():
iterclass = _DatasetIterPSWork iterclass = _DatasetIterPSWork
# 如果当前设备类型为Ascend或者GPU则使用多线程循环的迭代器
elif (context.get_context("device_target") == "Ascend") or \ elif (context.get_context("device_target") == "Ascend") or \
(context.get_context("device_target") == "GPU"): (context.get_context("device_target") == "GPU"):
iterclass = _DatasetIterMSLoopSink iterclass = _DatasetIterMSLoopSink
# 如果当前设备类型为CPU则抛出异常因为CPU不支持数据集下沉模式
elif context.get_context("device_target") == "CPU": elif context.get_context("device_target") == "CPU":
raise RuntimeError("Currently dataset sink mode is not supported when the device " raise RuntimeError("Currently dataset sink mode is not supported when the device "
"target is CPU, please set dataset sink mode to False.") "target is CPU, please set dataset sink mode to False.")
# 如果当前模式不是GRAPH_MODE则使用PyNative的迭代器
else: else:
iterclass = _DatasetIterPyNative iterclass = _DatasetIterPyNative
# 创建迭代器
self.iter = iterclass(dataset, sink_size, epoch_num) self.iter = iterclass(dataset, sink_size, epoch_num)
# 如果dataset_sink_mode为False则使用普通的迭代器
else: else:
# 如果不是分布式训练则使用_DatasetIterNormal类
iterclass = _DatasetIterNormal iterclass = _DatasetIterNormal
# 初始化迭代器
self.iter = iterclass(dataset, epoch_num=epoch_num) self.iter = iterclass(dataset, epoch_num=epoch_num)
def __iter__(self): def __iter__(self):
# 返回self.iter的迭代器
return self.iter.__iter__() return self.iter.__iter__()
# A temp solution for loop sink. Delete later # A temp solution for loop sink. Delete later
@ -301,6 +318,7 @@ class DatasetHelper:
>>> >>>
>>> types, shapes = dataset_helper.types_shapes() >>> types, shapes = dataset_helper.types_shapes()
""" """
# 从当前配置的dataset中获取类型和形状
return self.iter.types_shapes() return self.iter.types_shapes()
def sink_size(self): def sink_size(self):
@ -316,18 +334,22 @@ class DatasetHelper:
>>> # if sink_size==-1, then will return the full size of source dataset. >>> # if sink_size==-1, then will return the full size of source dataset.
>>> sink_size = dataset_helper.sink_size() >>> sink_size = dataset_helper.sink_size()
""" """
# 返回迭代器的接收缓冲区大小
return self.iter.get_sink_size() return self.iter.get_sink_size()
def stop_send(self): def stop_send(self):
"""Stop send data about data sink.""" """Stop send data about data sink."""
# 停止发送关于数据接收器的数据
self.iter.stop_send() self.iter.stop_send()
def release(self): def release(self):
"""Free up resources about data sink.""" """Free up resources about data sink."""
# 释放数据接收器的资源
self.iter.release() self.iter.release()
def continue_send(self): def continue_send(self):
"""Continue to send data to device at the beginning of epoch.""" """Continue to send data to device at the beginning of epoch."""
# 在每个epoch的开始处继续向设备发送数据
self.iter.continue_send() self.iter.continue_send()
def _reset(self, step): def _reset(self, step):
@ -339,6 +361,7 @@ class DatasetHelper:
In sink mode, it returns the types and shapes of the current data. In sink mode, it returns the types and shapes of the current data.
Generally, it works in dynamic shape scenarios. Generally, it works in dynamic shape scenarios.
""" """
# 返回迭代器的数据信息
return self.iter.get_data_info() return self.iter.get_data_info()
def dynamic_min_max_shapes(self): def dynamic_min_max_shapes(self):
@ -355,6 +378,7 @@ class DatasetHelper:
>>> >>>
>>> min_shapes, max_shapes = dataset_helper.dynamic_min_max_shapes() >>> min_shapes, max_shapes = dataset_helper.dynamic_min_max_shapes()
""" """
# 返回self.iter的dynamic_min_max_shapes方法
return self.iter.dynamic_min_max_shapes() return self.iter.dynamic_min_max_shapes()
@ -362,20 +386,27 @@ class _DatasetIter:
"""Base iter for dataset helper""" """Base iter for dataset helper"""
def __init__(self, dataset, sink_size, epoch_num): def __init__(self, dataset, sink_size, epoch_num):
# 初始化函数传入数据集、sink大小和epoch数量
self.dataset = dataset self.dataset = dataset
self.sink_size = sink_size self.sink_size = sink_size
self.sink_count = self.get_sink_count(dataset) self.sink_count = self.get_sink_count(dataset)
# 如果数据集没有__transfer_dataset__属性
if not hasattr(dataset, '__transfer_dataset__'): if not hasattr(dataset, '__transfer_dataset__'):
# 如果数据集有__loop_size__属性
if hasattr(dataset, '__loop_size__'): if hasattr(dataset, '__loop_size__'):
# PS mode does not support loop sink and need get the real sink size. # PS mode does not support loop sink and need get the real sink size.
# 如果不是worker角色或者不是ps模式则设置sink_size为dataset的循环大小
if not (_is_role_worker() and _is_ps_mode()): if not (_is_role_worker() and _is_ps_mode()):
self.sink_size = dataset.__loop_size__ self.sink_size = dataset.__loop_size__
# 如果sink_size为1sink_count为1dataset的大小不为1并且设备目标为Ascend则创建数据信息队列
create_data_info_queue = (sink_size == 1 and self.sink_count == 1 and dataset.get_dataset_size() != 1 create_data_info_queue = (sink_size == 1 and self.sink_count == 1 and dataset.get_dataset_size() != 1
and context.get_context("device_target") == "Ascend") and context.get_context("device_target") == "Ascend")
# 执行数据图并将sink_size和create_data_info_queue作为参数传入
dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size, dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size,
create_data_info_queue=create_data_info_queue) create_data_info_queue=create_data_info_queue)
# 如果dataset没有__no_send__属性则发送数据
if not hasattr(dataset, '__no_send__'): if not hasattr(dataset, '__no_send__'):
_send_data(dataset, epoch_num) _send_data(dataset, epoch_num)
else: else:
@ -384,33 +415,48 @@ class _DatasetIter:
_cell_graph_executor.set_queue_name(dataset.__transfer_dataset__.queue_name) _cell_graph_executor.set_queue_name(dataset.__transfer_dataset__.queue_name)
_send_data_no_flag(dataset, epoch_num) _send_data_no_flag(dataset, epoch_num)
# 获取dataset的stop_send方法
self.stop_send = dataset.__transfer_dataset__.stop_send self.stop_send = dataset.__transfer_dataset__.stop_send
# 获取dataset的release方法
self.release = dataset.__transfer_dataset__.release self.release = dataset.__transfer_dataset__.release
# 获取dataset的continue_send方法
self.continue_send = dataset.__transfer_dataset__.continue_send self.continue_send = dataset.__transfer_dataset__.continue_send
# 获取dataset的get_data_info方法
self.get_data_info = dataset.__transfer_dataset__.get_data_info self.get_data_info = dataset.__transfer_dataset__.get_data_info
# 获取dataset的dynamic_min_max_shapes属性
self.dynamic_min_max_shapes = dataset.dynamic_min_max_shapes self.dynamic_min_max_shapes = dataset.dynamic_min_max_shapes
# 获取dataset的数据类型和数据形状
self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset) self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)
# 如果dataset的__transfer_dataset__属性中有_reset方法则获取该_reset方法
if hasattr(dataset.__transfer_dataset__, "_reset"): if hasattr(dataset.__transfer_dataset__, "_reset"):
self._reset = dataset.__transfer_dataset__._reset # pylint: disable=W0212 self._reset = dataset.__transfer_dataset__._reset # pylint: disable=W0212
def __iter__(self): def __iter__(self):
# 初始化索引为0
self.index = 0 self.index = 0
# 返回self
return self return self
# 迭代器的下一项
def __next__(self): def __next__(self):
# 如果索引大于等于sink_count抛出StopIteration异常
if self.index >= self.sink_count: if self.index >= self.sink_count:
raise StopIteration() raise StopIteration()
# 索引加1
self.index += 1 self.index += 1
# 返回op()的返回值
return self.op() return self.op()
def types_shapes(self): def types_shapes(self):
""" """
Return the types and shapes of the dataset. The type and shape of each data in the dataset 返回数据集的类型和形状数据集中每个数据的类型和形状应该是一致的
should be consistent.
""" """
return self.dataset_types, self.dataset_shapes return self.dataset_types, self.dataset_shapes
def get_sink_count(self, dataset): def get_sink_count(self, dataset):
"""
获取数据集的sink次数
:param dataset: 数据集对象
:return: sink次数
"""
sink_count = 1 sink_count = 1
if hasattr(dataset, '__loop_size__'): if hasattr(dataset, '__loop_size__'):
loop_size = dataset.__loop_size__ loop_size = dataset.__loop_size__
@ -421,7 +467,10 @@ class _DatasetIter:
return sink_count return sink_count
def get_sink_size(self): def get_sink_size(self):
"""get sink_size to device""" """
获取设备的sink大小
:return: sink大小
"""
sink_size = 1 sink_size = 1
if hasattr(self.dataset, '__loop_size__'): if hasattr(self.dataset, '__loop_size__'):
sink_size = self.dataset.__loop_size__ sink_size = self.dataset.__loop_size__

@ -23,47 +23,59 @@ from setuptools import setup, find_packages
from setuptools.command.egg_info import egg_info from setuptools.command.egg_info import egg_info
from setuptools.command.build_py import build_py from setuptools.command.build_py import build_py
# 获取环境变量
backend_policy = os.getenv('BACKEND_POLICY') backend_policy = os.getenv('BACKEND_POLICY')
device_target = os.getenv('BACKEND_TARGET') device_target = os.getenv('BACKEND_TARGET')
commit_id = os.getenv('COMMIT_ID').replace("\n", "") commit_id = os.getenv('COMMIT_ID').replace("\n", "")
package_name = os.getenv('MS_PACKAGE_NAME').replace("\n", "") package_name = os.getenv('MS_PACKAGE_NAME').replace("\n", "")
build_path = os.getenv('BUILD_PATH') build_path = os.getenv('BUILD_PATH')
# 获取当前文件路径
pwd = os.path.dirname(os.path.realpath(__file__)) pwd = os.path.dirname(os.path.realpath(__file__))
# 获取包目录路径
pkg_dir = os.path.join(build_path, 'package') pkg_dir = os.path.join(build_path, 'package')
def _read_file(filename): def _read_file(filename):
"""读取文件内容"""
with open(os.path.join(pwd, filename), encoding='UTF-8') as f: with open(os.path.join(pwd, filename), encoding='UTF-8') as f:
return f.read() return f.read()
# 读取版本号
version = _read_file('version.txt').replace("\n", "") version = _read_file('version.txt').replace("\n", "")
# 读取README.md文件内容
readme = _read_file('README.md') readme = _read_file('README.md')
def _write_version(file): def _write_version(file):
"""写入版本号"""
file.write("__version__ = '{}'\n".format(version)) file.write("__version__ = '{}'\n".format(version))
def _write_config(file): def _write_config(file):
"""写入后端策略"""
file.write("__backend__ = '{}'\n".format(backend_policy)) file.write("__backend__ = '{}'\n".format(backend_policy))
def _write_commit_file(file): def _write_commit_file(file):
"""写入commit_id"""
file.write("__commit_id__ = '{}'\n".format(commit_id)) file.write("__commit_id__ = '{}'\n".format(commit_id))
def _write_package_name(file): def _write_package_name(file):
"""写入包名"""
file.write("__package_name__ = '{}'\n".format(package_name)) file.write("__package_name__ = '{}'\n".format(package_name))
def _write_device_target(file): def _write_device_target(file):
"""写入设备目标"""
file.write("__device_target__ = '{}'\n".format(device_target)) file.write("__device_target__ = '{}'\n".format(device_target))
def build_dependencies(): def build_dependencies():
"""generate python file""" """generate python file"""
# 生成version.py文件
version_file = os.path.join(pkg_dir, 'mindspore', 'version.py') version_file = os.path.join(pkg_dir, 'mindspore', 'version.py')
with open(version_file, 'w') as f: with open(version_file, 'w') as f:
_write_version(f) _write_version(f)
@ -72,6 +84,7 @@ def build_dependencies():
with open(version_file, 'w') as f: with open(version_file, 'w') as f:
_write_version(f) _write_version(f)
# 生成default_config.py文件
config_file = os.path.join(pkg_dir, 'mindspore', 'default_config.py') config_file = os.path.join(pkg_dir, 'mindspore', 'default_config.py')
with open(config_file, 'w') as f: with open(config_file, 'w') as f:
_write_config(f) _write_config(f)
@ -80,6 +93,7 @@ def build_dependencies():
with open(config_file, 'w') as f: with open(config_file, 'w') as f:
_write_config(f) _write_config(f)
# 向default_config.py文件中追加device_target
target = os.path.join(pkg_dir, 'mindspore', 'default_config.py') target = os.path.join(pkg_dir, 'mindspore', 'default_config.py')
with open(target, 'a') as f: with open(target, 'a') as f:
_write_device_target(f) _write_device_target(f)
@ -88,6 +102,7 @@ def build_dependencies():
with open(target, 'a') as f: with open(target, 'a') as f:
_write_device_target(f) _write_device_target(f)
# 向default_config.py文件中追加package_name
package_info = os.path.join(pkg_dir, 'mindspore', 'default_config.py') package_info = os.path.join(pkg_dir, 'mindspore', 'default_config.py')
with open(package_info, 'a') as f: with open(package_info, 'a') as f:
_write_package_name(f) _write_package_name(f)
@ -96,6 +111,7 @@ def build_dependencies():
with open(package_info, 'a') as f: with open(package_info, 'a') as f:
_write_package_name(f) _write_package_name(f)
# 生成.commit_id文件
commit_file = os.path.join(pkg_dir, 'mindspore', '.commit_id') commit_file = os.path.join(pkg_dir, 'mindspore', '.commit_id')
with open(commit_file, 'w') as f: with open(commit_file, 'w') as f:
_write_commit_file(f) _write_commit_file(f)
@ -145,16 +161,24 @@ def update_permissions(path):
Args: Args:
path (str): Target directory path. path (str): Target directory path.
""" """
# 判断操作系统是否为Windows
if platform.system() == "Windows": if platform.system() == "Windows":
return return
# 遍历目标目录下的所有文件和文件夹
for dirpath, dirnames, filenames in os.walk(path): for dirpath, dirnames, filenames in os.walk(path):
# 遍历文件夹
for dirname in dirnames: for dirname in dirnames:
# 获取文件夹的完整路径
dir_fullpath = os.path.join(dirpath, dirname) dir_fullpath = os.path.join(dirpath, dirname)
# 更新文件夹的权限
os.chmod(dir_fullpath, stat.S_IREAD | stat.S_IWRITE | os.chmod(dir_fullpath, stat.S_IREAD | stat.S_IWRITE |
stat.S_IEXEC | stat.S_IRGRP | stat.S_IXGRP) stat.S_IEXEC | stat.S_IRGRP | stat.S_IXGRP)
# 遍历文件
for filename in filenames: for filename in filenames:
# 获取文件的完整路径
file_fullpath = os.path.join(dirpath, filename) file_fullpath = os.path.join(dirpath, filename)
# 更新文件的权限
os.chmod(file_fullpath, stat.S_IREAD) os.chmod(file_fullpath, stat.S_IREAD)
@ -163,7 +187,9 @@ class EggInfo(egg_info):
def run(self): def run(self):
super().run() super().run()
# 获取egg-info目录的路径
egg_info_dir = os.path.join(pkg_dir, 'mindspore.egg-info') egg_info_dir = os.path.join(pkg_dir, 'mindspore.egg-info')
# 更新egg-info目录的权限
update_permissions(egg_info_dir) update_permissions(egg_info_dir)
@ -172,41 +198,64 @@ class BuildPy(build_py):
def run(self): def run(self):
super().run() super().run()
# 获取build目录下的lib/mindspore目录的路径
mindspore_dir = os.path.join(pkg_dir, 'build', 'lib', 'mindspore') mindspore_dir = os.path.join(pkg_dir, 'build', 'lib', 'mindspore')
# 更新lib/mindspore目录的权限
update_permissions(mindspore_dir) update_permissions(mindspore_dir)
# 获取build目录下的lib/mindspore/_akg目录的路径
mindspore_dir = os.path.join(pkg_dir, 'build', 'lib', 'mindspore', '_akg') mindspore_dir = os.path.join(pkg_dir, 'build', 'lib', 'mindspore', '_akg')
# 更新lib/mindspore/_akg目录的权限
update_permissions(mindspore_dir) update_permissions(mindspore_dir)
# 设置包的名称
setup( setup(
name=package_name, name=package_name,
# 设置包的版本
version=version, version=version,
# 设置包的作者
author='The MindSpore Authors', author='The MindSpore Authors',
# 设置包的作者邮箱
author_email='contact@mindspore.cn', author_email='contact@mindspore.cn',
# 设置包的网址
url='https://www.mindspore.cn', url='https://www.mindspore.cn',
# 设置包的下载网址
download_url='https://github.com/mindspore-ai/mindspore/tags', download_url='https://github.com/mindspore-ai/mindspore/tags',
# 设置包的源代码网址
project_urls={ project_urls={
'Sources': 'https://github.com/mindspore-ai/mindspore', 'Sources': 'https://github.com/mindspore-ai/mindspore',
# 设置包的问题追踪网址
'Issue Tracker': 'https://github.com/mindspore-ai/mindspore/issues', 'Issue Tracker': 'https://github.com/mindspore-ai/mindspore/issues',
}, },
# 设置包的描述
description='MindSpore is a new open source deep learning training/inference ' description='MindSpore is a new open source deep learning training/inference '
'framework that could be used for mobile, edge and cloud scenarios.', 'framework that could be used for mobile, edge and cloud scenarios.',
# 读取readme文件作为包的详细描述
long_description=readme, long_description=readme,
# 设置详细描述的格式
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
# 查找包中的所有模块
packages=find_packages(), packages=find_packages(),
# 设置包的数据
package_data=package_data, package_data=package_data,
# 包含包中的所有数据
include_package_data=True, include_package_data=True,
# 设置自定义的命令类
cmdclass={ cmdclass={
'egg_info': EggInfo, 'egg_info': EggInfo,
'build_py': BuildPy, 'build_py': BuildPy,
}, },
# 设置包的入口点
entry_points={ entry_points={
'console_scripts': [ 'console_scripts': [
'cache_admin=mindspore.dataset.engine.cache_admin:main', 'cache_admin=mindspore.dataset.engine.cache_admin:main',
], ],
}, },
# 设置包的Python版本要求
python_requires='>=3.7', python_requires='>=3.7',
# 设置包的依赖
install_requires=required_package, install_requires=required_package,
# 设置包的分类器
classifiers=[ classifiers=[
'Development Status :: 4 - Beta', 'Development Status :: 4 - Beta',
'Environment :: Console', 'Environment :: Console',
@ -223,6 +272,8 @@ setup(
'Topic :: Software Development :: Libraries', 'Topic :: Software Development :: Libraries',
'Topic :: Software Development :: Libraries :: Python Modules', 'Topic :: Software Development :: Libraries :: Python Modules',
], ],
# 设置包的许可证
license='Apache 2.0', license='Apache 2.0',
# 设置包的关键词
keywords='mindspore machine learning', keywords='mindspore machine learning',
) )

@ -1,3 +1,6 @@
这段代码是一个Python类的实现名为`MindData`它是一个用于模拟MindSpore框架中数据集处理的桩Stub下面是对这段代码的逐行注释
```python
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -12,77 +15,94 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
'''Remove after MindData merge to MindSpore ''' '''Remove after MindData merge to MindSpore '''
import numpy as np import numpy as np
from mindspore import Tensor from mindspore import Tensor
class MindData: class MindData:
""" Stub for MindData """ """ Stub for MindData """
# 构造函数初始化MindData类的实例
def __init__(self, size=1, batch_size=None, repeat_count=1, def __init__(self, size=1, batch_size=None, repeat_count=1,
np_types=None, output_shapes=None, input_indexs=()): np_types=None, output_shapes=None, input_indexs=()):
self._size = size self._size = size # 数据集的大小
self._batch_size = batch_size self._batch_size = batch_size # 批处理大小
self._repeat_count = repeat_count self._repeat_count = repeat_count # 重复次数
self._np_types = np_types self._np_types = np_types # NumPy数据类型
self._output_shapes = output_shapes self._output_shapes = output_shapes # 输出形状
self._input_indexs = input_indexs self._input_indexs = input_indexs # 输入索引
self._iter_num = 0 self._iter_num = 0 # 迭代次数计数器
self.dynamic_setting = [False, None] self.dynamic_setting = [False, None] # 动态设置标志和值
# 获取数据集大小
def get_dataset_size(self): def get_dataset_size(self):
return self._size return self._size
# 获取重复次数
def get_repeat_count(self): def get_repeat_count(self):
return self._repeat_count return self._repeat_count
# 获取批处理大小
def get_batch_size(self): def get_batch_size(self):
return self._batch_size return self._batch_size
# 获取输出数据类型
def output_types(self): def output_types(self):
return self._np_types return self._np_types
# 获取输出形状
def output_shapes(self): def output_shapes(self):
return self._output_shapes return self._output_shapes
# 输入索引属性
@property @property
def input_indexs(self): def input_indexs(self):
return self._input_indexs return self._input_indexs
# 设备队列设置
def device_que(self, send_epoch_end=True, create_data_info_queue=False): def device_que(self, send_epoch_end=True, create_data_info_queue=False):
self.queue_name = '6ba41974-209e-11ea-88b0-a24efeb2c736' self.queue_name = '6ba41974-209e-11ea-88b0-a24efeb2c736'
self.send_epoch_end = send_epoch_end self.send_epoch_end = send_epoch_end
return self return self
# 创建元组迭代器
def create_tuple_iterator(self, num_epochs=-1, do_copy=True): def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
return self.__iter__() return self.__iter__()
# 发送数据
def send(self, num_epochs=-1): def send(self, num_epochs=-1):
pass pass
# 停止发送数据
def stop_send(self): def stop_send(self):
pass pass
# 释放资源
def release(self): def release(self):
pass pass
# 继续发送数据
def continue_send(self): def continue_send(self):
pass pass
# 获取数据信息
def get_data_info(self): def get_data_info(self):
pass pass
# 动态最小最大形状
def dynamic_min_max_shapes(self): def dynamic_min_max_shapes(self):
pass pass
# 获取长度
def __len__(self): def __len__(self):
return self._size return self._size
# 迭代器
def __iter__(self): def __iter__(self):
return self return self
# 获取下一个元素
def __next__(self): def __next__(self):
if self._size < self._iter_num: if self._size < self._iter_num:
raise StopIteration raise StopIteration
@ -90,11 +110,13 @@ class MindData:
next_value = [] next_value = []
for shape, typ in zip(self._output_shapes, self._np_types): for shape, typ in zip(self._output_shapes, self._np_types):
next_value.append(Tensor(np.ndarray(shape, typ))) next_value.append(Tensor(np.ndarray(shape, typ)))
return tuple(next_value) return tuple(next_value)
# 下一个元素
def next(self): def next(self):
return self.__next__() return self.__next__()
# 重置迭代器
def reset(self): def reset(self):
self._iter_num = 0 self._iter_num = 0

Loading…
Cancel
Save