|
|
|
@ -18,7 +18,7 @@ from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
|
|
|
|
|
from mindspore import log as logger
|
|
|
|
|
from ._hccl_management import load_lib as hccl_load_lib
|
|
|
|
|
from .._c_expression import get_rank_id, get_rank_size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Backend availability flags. Each one starts out False and is switched to True
# further down in this module once the matching import or library load succeeds.
_HCCL_AVAILABLE = False
|
|
|
|
|
_HCCL_TEST_AVAILABLE = False
|
|
|
|
|
_NCCL_AVAILABLE = False
|
|
|
|
@ -28,14 +28,15 @@ try:
|
|
|
|
|
_NCCL_AVAILABLE = True
|
|
|
|
|
except ImportError:
|
|
|
|
|
_NCCL_AVAILABLE = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Probe for the HCCL library: a RuntimeError raised by hccl_load_lib() marks HCCL
# as unavailable, a clean return marks it as available.
try:
    hccl_load_lib()
except RuntimeError:
    _HCCL_AVAILABLE = False
else:
    _HCCL_AVAILABLE = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 如果 HCCL 可用,则导入 _hccl_management 并尝试导入 mindspore._ascend_mpi,如果成功则设置 _MPI_AVAILABLE 为 True,否则捕获 ImportError 并设置为 False
|
|
|
|
|
if _HCCL_AVAILABLE:
|
|
|
|
|
from . import _hccl_management as hccl
|
|
|
|
|
try:
|
|
|
|
@ -43,6 +44,7 @@ if _HCCL_AVAILABLE:
|
|
|
|
|
_MPI_AVAILABLE = True
|
|
|
|
|
except ImportError:
|
|
|
|
|
_MPI_AVAILABLE = False
|
|
|
|
|
# 如果 HCCL 不可用,则尝试导入 hccl_test.manage.api,如果成功则设置 _HCCL_AVAILABLE 和 _HCCL_TEST_AVAILABLE 为 True,否则捕获 ImportError 并设置 _HCCL_AVAILABLE 为 False
|
|
|
|
|
else:
|
|
|
|
|
try:
|
|
|
|
|
import hccl_test.manage.api as hccl
|
|
|
|
@ -50,12 +52,11 @@ else:
|
|
|
|
|
_HCCL_TEST_AVAILABLE = True
|
|
|
|
|
except ImportError:
|
|
|
|
|
_HCCL_AVAILABLE = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Names of the world communication groups for the HCCL and NCCL backends.
|
|
|
|
|
HCCL_WORLD_COMM_GROUP = "hccl_world_group"
|
|
|
|
|
NCCL_WORLD_COMM_GROUP = "nccl_world_group"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Backend:
|
|
|
|
|
"""
|
|
|
|
|
Class for available backends.
|
|
|
|
@ -82,13 +83,17 @@ class Backend:
|
|
|
|
|
|
|
|
|
|
def __new__(cls, name):
|
|
|
|
|
"""Create instance object of Backend."""
|
|
|
|
|
# 检查传入的name是否为字符串类型
|
|
|
|
|
if not isinstance(name, str):
|
|
|
|
|
raise TypeError("For 'Backend', the class variable 'name' must be a string, "
|
|
|
|
|
"but got the type : {}".format(type(name)))
|
|
|
|
|
# 获取对应name的大写形式的Backend类属性值,如果不存在则返回Backend.UNDEFINED
|
|
|
|
|
value = getattr(Backend, name.upper(), Backend.UNDEFINED)
|
|
|
|
|
# 如果获取到的值是Backend.UNDEFINED,说明传入的name不被支持
|
|
|
|
|
if value == Backend.UNDEFINED:
|
|
|
|
|
raise ValueError("For 'Backend', the class variable 'name' {} is not supported, "
|
|
|
|
|
"please use hccl or nccl.".format(name))
|
|
|
|
|
# 返回获取到的Backend类属性值
|
|
|
|
|
return value
|
|
|
|
|
|
|
|
|
|
# Module-wide default backend: Backend("hccl") resolves to the Backend.HCCL class attribute.
DEFAULT_BACKEND = Backend("hccl")
|
|
|
|
@ -97,7 +102,7 @@ DEFAULT_BACKEND = Backend("hccl")
|
|
|
|
|
class GlobalComm:
|
|
|
|
|
"""
|
|
|
|
|
World communication information. The GlobalComm is a global class. The members contain:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
- BACKEND: The communication library used, using HCCL/NCCL.
|
|
|
|
|
- WORLD_COMM_GROUP: Global communication domain.
|
|
|
|
|
"""
|
|
|
|
@ -105,35 +110,34 @@ class GlobalComm:
|
|
|
|
|
WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
|
|
|
|
|
INITED = False
|
|
|
|
|
CHECK_ENVS = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _ExistingGroup:
    """
    Registry of the communication groups that already exist in this process.
    """
    # Shared, class-level mapping: group name -> rank id list recorded for it.
    ITEMS = dict()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_hccl_available():
    """
    Report whether the HCCL API can be used.

    Returns:
        Boolean. The current value of the module-level `_HCCL_AVAILABLE` flag.
    """
    return _HCCL_AVAILABLE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_mpi_available():
    """
    Report whether the HCCL & MPI API can be used.

    Returns:
        Boolean. The current value of the module-level `_MPI_AVAILABLE` flag.
    """
    return _MPI_AVAILABLE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_nccl_available():
|
|
|
|
|
"""
|
|
|
|
|
Check NCCL api is available.
|
|
|
|
@ -158,19 +162,25 @@ def check_parameter_available(func):
|
|
|
|
|
Wrapper. If not available, raise Error.
|
|
|
|
|
"""
|
|
|
|
|
def wrapper(*args, **kargs):
|
|
|
|
|
# 如果当前角色是参数服务器或者调度器,直接调用被装饰的函数
|
|
|
|
|
if _is_role_pserver() or _is_role_sched():
|
|
|
|
|
return func(*args, **kargs)
|
|
|
|
|
# 检查分布式通信是否已经初始化,未初始化则抛出异常
|
|
|
|
|
if not GlobalComm.INITED:
|
|
|
|
|
raise RuntimeError("Distributed Communication has not been inited")
|
|
|
|
|
# 获取参数组,默认为None
|
|
|
|
|
group = None
|
|
|
|
|
# 检查关键字参数中是否包含"group",并进行类型检查
|
|
|
|
|
if "group" in kargs.keys():
|
|
|
|
|
group = kargs.get("group")
|
|
|
|
|
if group is not None and not isinstance(group, str):
|
|
|
|
|
raise TypeError("The parameter 'group' should be str or None, "
|
|
|
|
|
"but got the type : {}".format(type(group)))
|
|
|
|
|
|
|
|
|
|
# 获取后端,默认为None
|
|
|
|
|
if "backend" in kargs.keys():
|
|
|
|
|
backend = kargs.get("backend")
|
|
|
|
|
# 检查后端是否可用,不可用则抛出异常
|
|
|
|
|
if backend is Backend.HCCL and not is_hccl_available():
|
|
|
|
|
raise RuntimeError("Distributed Communication doesn't have HCCL built in")
|
|
|
|
|
if backend is Backend.HCCL_MPI and not is_mpi_available():
|
|
|
|
@ -178,15 +188,16 @@ def check_parameter_available(func):
|
|
|
|
|
if backend is Backend.NCCL and not is_nccl_available():
|
|
|
|
|
raise RuntimeError("Distributed Communication doesn't have NCCL built in")
|
|
|
|
|
|
|
|
|
|
# 如果未指定group,根据backend设置默认的group
|
|
|
|
|
if group is None:
|
|
|
|
|
if backend is Backend.HCCL or Backend.HCCL_MPI:
|
|
|
|
|
if backend is Backend.HCCL or backend is Backend.HCCL_MPI:
|
|
|
|
|
group = HCCL_WORLD_COMM_GROUP
|
|
|
|
|
elif backend is Backend.NCCL:
|
|
|
|
|
group = NCCL_WORLD_COMM_GROUP
|
|
|
|
|
# 调用被装饰的函数
|
|
|
|
|
return func(*args, **kargs)
|
|
|
|
|
return wrapper
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@check_parameter_available
|
|
|
|
|
def _get_rank_helper(group, backend):
|
|
|
|
|
"""
|
|
|
|
@ -202,10 +213,13 @@ def _get_rank_helper(group, backend):
|
|
|
|
|
Returns:
|
|
|
|
|
Integer. The local rank id of the calling process.
|
|
|
|
|
"""
|
|
|
|
|
# 辅助函数,用于根据不同的后端和组获取 rank_id
|
|
|
|
|
# 获取当前角色的rank_id,如果是参数服务器或调度器角色,rank_id设为0并返回
|
|
|
|
|
rank_id = None
|
|
|
|
|
if _is_role_pserver() or _is_role_sched():
|
|
|
|
|
rank_id = 0
|
|
|
|
|
return rank_id
|
|
|
|
|
# 根据不同的后端获取rank_id
|
|
|
|
|
if backend == Backend.HCCL_MPI:
|
|
|
|
|
rank_id = mpi.get_rank_id(group)
|
|
|
|
|
elif backend == Backend.HCCL:
|
|
|
|
@ -216,6 +230,7 @@ def _get_rank_helper(group, backend):
|
|
|
|
|
elif backend == Backend.NCCL:
|
|
|
|
|
rank_id = get_rank_id(group)
|
|
|
|
|
else:
|
|
|
|
|
# 如果后端不被支持,抛出ValueError异常
|
|
|
|
|
raise ValueError("For '_get_rank_helper', the argument 'backend' {} is not supported, "
|
|
|
|
|
"please use hccl_mpi, hccl or nccl.".format(backend))
|
|
|
|
|
return rank_id
|
|
|
|
@ -236,6 +251,7 @@ def _get_local_rank_helper(group, backend):
|
|
|
|
|
Returns:
|
|
|
|
|
Integer. The local rank id of the calling process.
|
|
|
|
|
"""
|
|
|
|
|
# 获取当前进程的rank id,根据不同的后端和组进行处理
|
|
|
|
|
rank_id = None
|
|
|
|
|
if backend == Backend.HCCL_MPI:
|
|
|
|
|
rank_id = mpi.get_rank_id(group)
|
|
|
|
@ -251,7 +267,6 @@ def _get_local_rank_helper(group, backend):
|
|
|
|
|
"please use hccl_mpi or hccl.".format(backend))
|
|
|
|
|
return rank_id
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@check_parameter_available
|
|
|
|
|
def _get_size_helper(group, backend):
|
|
|
|
|
"""
|
|
|
|
@ -268,9 +283,11 @@ def _get_size_helper(group, backend):
|
|
|
|
|
Integer. The rank size of specified group.
|
|
|
|
|
"""
|
|
|
|
|
size = None
|
|
|
|
|
# 如果当前角色是参数服务器或调度器,则将size设为1并返回
|
|
|
|
|
if _is_role_pserver() or _is_role_sched():
|
|
|
|
|
size = 1
|
|
|
|
|
return size
|
|
|
|
|
# 根据不同的后端设置size的值
|
|
|
|
|
if backend == Backend.HCCL_MPI:
|
|
|
|
|
size = mpi.get_rank_size(group)
|
|
|
|
|
elif backend == Backend.HCCL:
|
|
|
|
@ -302,19 +319,23 @@ def _get_local_size_helper(group, backend):
|
|
|
|
|
Integer. The local rank size where the calling process is being within specified group.
|
|
|
|
|
"""
|
|
|
|
|
size = None
|
|
|
|
|
# 根据不同的后端获取本地进程组的大小
|
|
|
|
|
if backend == Backend.HCCL:
|
|
|
|
|
# 如果组是全局通信组,则获取全局通信组的本地排名大小
|
|
|
|
|
if group == HCCL_WORLD_COMM_GROUP:
|
|
|
|
|
size = hccl.get_local_rank_size()
|
|
|
|
|
# 否则获取指定组的本地排名大小
|
|
|
|
|
else:
|
|
|
|
|
size = hccl.get_local_rank_size(group)
|
|
|
|
|
# NCCL后端不支持获取本地排名大小,抛出异常
|
|
|
|
|
elif backend == Backend.NCCL:
|
|
|
|
|
raise RuntimeError("Nccl doesn't support get_local_rank_size now.")
|
|
|
|
|
# 对于不支持的后端,抛出异常
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError("For '_get_local_size_helper', the argument 'backend' {} is not supported, "
|
|
|
|
|
"please use hccl.".format(backend))
|
|
|
|
|
return size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@check_parameter_available
def _get_world_rank_from_group_rank_helper(group, group_rank_id, backend):
    """
    Translate a rank id inside `group` into the matching world rank id.

    Args:
        group: The communication group that `group_rank_id` belongs to.
        group_rank_id (int): A rank id within `group`.
        backend: The communication backend; only HCCL performs the lookup.

    Raises:
        TypeError: If `group_rank_id` is not an int.
        ValueError: If `group` is `HCCL_WORLD_COMM_GROUP` under the HCCL backend,
            or if `backend` is neither HCCL nor NCCL.
        RuntimeError: If `backend` is NCCL.

    Returns:
        Integer. A rank id in world communication group.
    """
    # Validate the rank id before looking at the backend at all.
    if not isinstance(group_rank_id, int):
        raise TypeError("For 'get_world_rank_from_group_rank', the argument 'group_rank_id' must be"
                        " type of int, but got 'group_rank_id' type : {}.".format(type(group_rank_id)))

    if backend == Backend.HCCL:
        # The world group itself is rejected on this path.
        # NOTE(review): the message mentions GPU / NCCL although this is the HCCL branch - confirm wording.
        if group == HCCL_WORLD_COMM_GROUP:
            raise ValueError("For 'get_world_rank_from_group_rank' on GPU, the argument 'group' "
                             "should be 'NCCL_WORLD_COMM_GROUP', but got 'HCCL_WORLD_COMM_GROUP'.")
        return hccl.get_world_rank_from_group_rank(group, group_rank_id)

    # NCCL is a known backend but has no implementation for this query.
    if backend == Backend.NCCL:
        raise RuntimeError("Nccl doesn't support get_world_rank_from_group_rank now.")

    # Anything else is an unsupported backend.
    raise ValueError("The argument 'backend' {} is not supported, please use hccl.".format(backend))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@check_parameter_available
def _get_group_rank_from_world_rank_helper(world_rank_id, group, backend):
    """
    Translate a world rank id into the matching rank id inside `group`.

    Args:
        world_rank_id (int): A rank id in the world communication group.
        group: The communication group to map the rank into.
        backend: The communication backend; only HCCL performs the lookup.

    Raises:
        TypeError: If `world_rank_id` is not an int.
        ValueError: If `group` is `HCCL_WORLD_COMM_GROUP` under the HCCL backend,
            or if `backend` is neither HCCL nor NCCL.
        RuntimeError: If `backend` is NCCL.

    Returns:
        Integer. A rank id in user communication group.
    """
    # Validate the rank id before looking at the backend at all.
    if not isinstance(world_rank_id, int):
        raise TypeError("For 'get_group_rank_from_world_rank', the argument 'world_rank_id' must be type of int, "
                        "but got 'world_rank_id' type : {}.".format(type(world_rank_id)))

    if backend == Backend.HCCL:
        # The world group itself is rejected on this path.
        # NOTE(review): the message mentions GPU / NCCL although this is the HCCL branch - confirm wording.
        if group == HCCL_WORLD_COMM_GROUP:
            raise ValueError("For 'get_group_rank_from_world_rank' on GPU, the argument 'group' "
                             "should be 'NCCL_WORLD_COMM_GROUP', but got 'HCCL_WORLD_COMM_GROUP'.")
        return hccl.get_group_rank_from_world_rank(world_rank_id, group)

    # NCCL is a known backend but has no implementation for this query.
    if backend == Backend.NCCL:
        raise RuntimeError("Nccl doesn't support get_group_rank_from_world_rank now.")

    # Anything else is an unsupported backend.
    raise ValueError("The argument 'backend' {} is not supported, please use hccl.".format(backend))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@check_parameter_available
def _create_group_helper(group, rank_ids, backend):
    """
    Create the communication group `group` over `rank_ids` with the given backend.

    A group that is already recorded in `_ExistingGroup.ITEMS` with the same rank
    list is not created again: a warning is logged and the function returns.

    Args:
        group: Name of the communication group to create.
        rank_ids (list): Rank ids that form the group.
        backend: The communication backend (HCCL, HCCL_MPI or NCCL).

    Raises:
        TypeError: If rank_ids is not a list (checked on the HCCL path).
        ValueError: If the group was already created with a different rank list,
            if rank_ids is empty or has duplicate data (HCCL path), or if backend is invalid.
        RuntimeError: If backend is NCCL.
    """
    # An already-registered group must be requested with the identical rank list.
    if group in _ExistingGroup.ITEMS:
        recorded_ranks = _ExistingGroup.ITEMS[group]
        if rank_ids != recorded_ranks:
            raise ValueError("The group {} has been created, the rank_list is {}, "
                             "but current rank_list for the group is {}".
                             format(group, recorded_ranks, rank_ids))
        logger.warning("%r group has existed.", group)
        return

    if backend == Backend.HCCL:
        if not isinstance(rank_ids, list):
            raise TypeError("For 'create_group', the argument 'rank_ids' must be type of list, "
                            "but got 'rank_ids' type : {}.".format(type(rank_ids)))
        rank_size = len(rank_ids)
        # NOTE(review): only an empty list is rejected here, while the message says
        # "greater than 1" - confirm which minimum size is intended.
        if rank_size < 1:
            raise ValueError("For 'create_group', the argument 'rank_ids' size should be greater than 1, "
                             "but got 'rank_ids' size : {}.".format(rank_size))
        # Fewer unique ids than entries means at least one rank id is repeated.
        if len(set(rank_ids)) < rank_size:
            raise ValueError("List rank_ids in Group {} has duplicate data!".format(group))
        hccl.create_group(group, rank_size, rank_ids)
    elif backend == Backend.HCCL_MPI:
        mpi.create_group(group, rank_ids)
    elif backend == Backend.NCCL:
        raise RuntimeError("Nccl doesn't support create_group now.")
    else:
        raise ValueError("The context configuration parameter 'backend' {} is not supported, "
                         "please use hccl.".format(backend))

    # Remember the new group together with the rank list it was created with.
    _ExistingGroup.ITEMS[group] = rank_ids


|
|
|
|
|
@check_parameter_available
def _destroy_group_helper(group, backend):
    """
    Destroy the communication group `group` through the given backend.

    Args:
        group: Name of the communication group to destroy.
        backend: The communication backend; only HCCL performs the destruction.

    Raises:
        ValueError: If group is "hccl_world_group" or backend is invalid.
        RuntimeError: If backend is NCCL.
    """
    if backend == Backend.HCCL:
        # The HCCL world group can never be destroyed.
        if group == HCCL_WORLD_COMM_GROUP:
            raise ValueError("The hccl_world_group does not support destruction.")
        # NOTE(review): the group's entry in _ExistingGroup.ITEMS is not removed here - confirm
        # whether a destroyed group is meant to stay registered.
        hccl.destroy_group(group)
        return

    # NCCL is a known backend but cannot destroy groups.
    if backend == Backend.NCCL:
        raise RuntimeError("Nccl doesn't support destroy_group now.")

    # Anything else is an unsupported backend.
    raise ValueError("The context configuration parameter 'backend' {} is not supported, "
                     "please use hccl.".format(backend))
|