communication,ruiqin

pull/5/head
zhang 2 months ago
parent 4e49598ac9
commit 61450586a7

@ -18,7 +18,7 @@ from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
from mindspore import log as logger
from ._hccl_management import load_lib as hccl_load_lib
from .._c_expression import get_rank_id, get_rank_size
# Availability flags for the communication backends. Each one starts as False
# and is reassigned by the load/import probes further down in this module.
_HCCL_AVAILABLE = False
_HCCL_TEST_AVAILABLE = False
_NCCL_AVAILABLE = False
@ -28,14 +28,15 @@ try:
_NCCL_AVAILABLE = True
except ImportError:
_NCCL_AVAILABLE = False
# Try to load the HCCL library: set _HCCL_AVAILABLE to True on success, or to False if a RuntimeError is raised.
try:
hccl_load_lib()
_HCCL_AVAILABLE = True
except RuntimeError:
_HCCL_AVAILABLE = False
# If HCCL is available, import _hccl_management and try to import mindspore._ascend_mpi: set _MPI_AVAILABLE to True on success, or to False if an ImportError is raised.
if _HCCL_AVAILABLE:
from . import _hccl_management as hccl
try:
@ -43,6 +44,7 @@ if _HCCL_AVAILABLE:
_MPI_AVAILABLE = True
except ImportError:
_MPI_AVAILABLE = False
# If HCCL is not available, try to import hccl_test.manage.api: on success set _HCCL_AVAILABLE and _HCCL_TEST_AVAILABLE to True; on ImportError set _HCCL_AVAILABLE to False.
else:
try:
import hccl_test.manage.api as hccl
@ -50,12 +52,11 @@ else:
_HCCL_TEST_AVAILABLE = True
except ImportError:
_HCCL_AVAILABLE = False
# Names of the world communication groups for the HCCL and NCCL backends.
HCCL_WORLD_COMM_GROUP = "hccl_world_group"
NCCL_WORLD_COMM_GROUP = "nccl_world_group"
class Backend:
"""
Class for available backends.
@ -82,13 +83,17 @@ class Backend:
def __new__(cls, name):
"""Create instance object of Backend."""
# 检查传入的name是否为字符串类型
if not isinstance(name, str):
raise TypeError("For 'Backend', the class variable 'name' must be a string, "
"but got the type : {}".format(type(name)))
# 获取对应name的大写形式的Backend类属性值如果不存在则返回Backend.UNDEFINED
value = getattr(Backend, name.upper(), Backend.UNDEFINED)
# 如果获取到的值是Backend.UNDEFINED说明传入的name不被支持
if value == Backend.UNDEFINED:
raise ValueError("For 'Backend', the class variable 'name' {} is not supported, "
"please use hccl or nccl.".format(name))
# 返回获取到的Backend类属性值
return value
DEFAULT_BACKEND = Backend("hccl")
@ -97,7 +102,7 @@ DEFAULT_BACKEND = Backend("hccl")
class GlobalComm:
"""
World communication information. The GlobalComm is a global class. The members contain:
- BACKEND: The communication library used, using HCCL/NCCL.
- WORLD_COMM_GROUP: Global communication domain.
"""
@ -105,35 +110,34 @@ class GlobalComm:
WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
INITED = False
CHECK_ENVS = True
class _ExistingGroup:
"""
The communication groups which exist in the progress.
"""
ITEMS = {}
def is_hccl_available():
    """
    Check HCCL api is available.

    Returns:
        Boolean. Return whether HCCL is available or not.
    """
    # The flag is computed once, at import time, by the module-level probes.
    return _HCCL_AVAILABLE
def is_mpi_available():
    """
    Check HCCL & MPI api is available.

    Returns:
        Boolean. Return whether HCCL & MPI is available or not.
    """
    # The flag is computed once, at import time, by the module-level probes.
    return _MPI_AVAILABLE
def is_nccl_available():
"""
Check NCCL api is available.
@ -158,19 +162,25 @@ def check_parameter_available(func):
Wrapper. If not available, raise Error.
"""
def wrapper(*args, **kargs):
# 如果当前角色是参数服务器或者调度器,直接调用被装饰的函数
if _is_role_pserver() or _is_role_sched():
return func(*args, **kargs)
# 检查分布式通信是否已经初始化,未初始化则抛出异常
if not GlobalComm.INITED:
raise RuntimeError("Distributed Communication has not been inited")
# 获取参数组默认为None
group = None
# 检查关键字参数中是否包含"group",并进行类型检查
if "group" in kargs.keys():
group = kargs.get("group")
if group is not None and not isinstance(group, str):
raise TypeError("The parameter 'group' should be str or None, "
"but got the type : {}".format(type(group)))
# 获取后端默认为None
if "backend" in kargs.keys():
backend = kargs.get("backend")
# 检查后端是否可用,不可用则抛出异常
if backend is Backend.HCCL and not is_hccl_available():
raise RuntimeError("Distributed Communication doesn't have HCCL built in")
if backend is Backend.HCCL_MPI and not is_mpi_available():
@ -178,15 +188,16 @@ def check_parameter_available(func):
if backend is Backend.NCCL and not is_nccl_available():
raise RuntimeError("Distributed Communication doesn't have NCCL built in")
# 如果未指定group根据backend设置默认的group
if group is None:
if backend is Backend.HCCL or Backend.HCCL_MPI:
if backend is Backend.HCCL or backend is Backend.HCCL_MPI:
group = HCCL_WORLD_COMM_GROUP
elif backend is Backend.NCCL:
group = NCCL_WORLD_COMM_GROUP
# 调用被装饰的函数
return func(*args, **kargs)
return wrapper
@check_parameter_available
def _get_rank_helper(group, backend):
"""
@ -202,10 +213,13 @@ def _get_rank_helper(group, backend):
Returns:
Integer. The local rank id of the calling process.
"""
# 辅助函数,用于根据不同的后端和组获取 rank_id
# 获取当前角色的rank_id如果是参数服务器或调度器角色rank_id设为0并返回
rank_id = None
if _is_role_pserver() or _is_role_sched():
rank_id = 0
return rank_id
# 根据不同的后端获取rank_id
if backend == Backend.HCCL_MPI:
rank_id = mpi.get_rank_id(group)
elif backend == Backend.HCCL:
@ -216,6 +230,7 @@ def _get_rank_helper(group, backend):
elif backend == Backend.NCCL:
rank_id = get_rank_id(group)
else:
# 如果后端不被支持抛出ValueError异常
raise ValueError("For '_get_rank_helper', the argument 'backend' {} is not supported, "
"please use hccl_mpi, hccl or nccl.".format(backend))
return rank_id
@ -236,6 +251,7 @@ def _get_local_rank_helper(group, backend):
Returns:
Integer. The local rank id of the calling process.
"""
# 获取当前进程的rank id根据不同的后端和组进行处理
rank_id = None
if backend == Backend.HCCL_MPI:
rank_id = mpi.get_rank_id(group)
@ -251,7 +267,6 @@ def _get_local_rank_helper(group, backend):
"please use hccl_mpi or hccl.".format(backend))
return rank_id
@check_parameter_available
def _get_size_helper(group, backend):
"""
@ -268,9 +283,11 @@ def _get_size_helper(group, backend):
Integer. The rank size of specified group.
"""
size = None
# 如果当前角色是参数服务器或调度器则将size设为1并返回
if _is_role_pserver() or _is_role_sched():
size = 1
return size
# 根据不同的后端设置size的值
if backend == Backend.HCCL_MPI:
size = mpi.get_rank_size(group)
elif backend == Backend.HCCL:
@ -302,19 +319,23 @@ def _get_local_size_helper(group, backend):
Integer. The local rank size where the calling process is being within specified group.
"""
size = None
# 根据不同的后端获取本地进程组的大小
if backend == Backend.HCCL:
# 如果组是全局通信组,则获取全局通信组的本地排名大小
if group == HCCL_WORLD_COMM_GROUP:
size = hccl.get_local_rank_size()
# 否则获取指定组的本地排名大小
else:
size = hccl.get_local_rank_size(group)
# NCCL后端不支持获取本地排名大小抛出异常
elif backend == Backend.NCCL:
raise RuntimeError("Nccl doesn't support get_local_rank_size now.")
# 对于不支持的后端,抛出异常
else:
raise ValueError("For '_get_local_size_helper', the argument 'backend' {} is not supported, "
"please use hccl.".format(backend))
return size
@check_parameter_available
def _get_world_rank_from_group_rank_helper(group, group_rank_id, backend):
"""
@ -333,21 +354,26 @@ def _get_world_rank_from_group_rank_helper(group, group_rank_id, backend):
Integer. A rank id in world communication group.
"""
world_rank_id = None
# 检查 group_rank_id 是否为整数类型,如果不是则抛出 TypeError
if not isinstance(group_rank_id, int):
raise TypeError("For 'get_world_rank_from_group_rank', the argument 'group_rank_id' must be"
" type of int, but got 'group_rank_id' type : {}.".format(type(group_rank_id)))
# 根据不同的后端选择不同的逻辑处理方式
if backend == Backend.HCCL:
# 如果在 GPU 上使用 HCCL但 group 参数为 HCCL_WORLD_COMM_GROUP则抛出 ValueError
if group == HCCL_WORLD_COMM_GROUP:
raise ValueError("For 'get_world_rank_from_group_rank' on GPU, the argument 'group' "
"should be 'NCCL_WORLD_COMM_GROUP', but got 'HCCL_WORLD_COMM_GROUP'.")
# 调用 hccl.get_world_rank_from_group_rank 方法获取 world_rank_id
world_rank_id = hccl.get_world_rank_from_group_rank(group, group_rank_id)
elif backend == Backend.NCCL:
# 如果使用 NCCL则抛出 RuntimeError 表示不支持该操作
raise RuntimeError("Nccl doesn't support get_world_rank_from_group_rank now.")
else:
# 如果 backend 参数不支持,则抛出 ValueError 表示不支持该后端
raise ValueError("The argument 'backend' {} is not supported, please use hccl.".format(backend))
# 返回获取的 world_rank_id
return world_rank_id
@check_parameter_available
def _get_group_rank_from_world_rank_helper(world_rank_id, group, backend):
"""
@ -366,21 +392,27 @@ def _get_group_rank_from_world_rank_helper(world_rank_id, group, backend):
Integer. A rank id in user communication group.
"""
group_rank_id = None
# 检查 world_rank_id 是否为整数类型,如果不是则抛出 TypeError
if not isinstance(world_rank_id, int):
raise TypeError("For 'get_group_rank_from_world_rank', the argument 'world_rank_id' must be type of int, "
"but got 'world_rank_id' type : {}.".format(type(world_rank_id)))
# 根据不同的后端处理获取 group_rank_id 的逻辑
if backend == Backend.HCCL:
# 检查 GPU 后端的 group 参数是否正确,如果不正确则抛出 ValueError
if group == HCCL_WORLD_COMM_GROUP:
raise ValueError("For 'get_group_rank_from_world_rank' on GPU, the argument 'group' "
"should be 'NCCL_WORLD_COMM_GROUP', but got 'HCCL_WORLD_COMM_GROUP'.")
# 调用 hccl 模块的函数获取 group_rank_id
group_rank_id = hccl.get_group_rank_from_world_rank(world_rank_id, group)
# NCCL 后端不支持此操作,抛出 RuntimeError
elif backend == Backend.NCCL:
raise RuntimeError("Nccl doesn't support get_group_rank_from_world_rank now.")
# 如果后端不是 HCCL 或 NCCL则抛出 ValueError 表示不支持的后端
else:
raise ValueError("The argument 'backend' {} is not supported, please use hccl.".format(backend))
# 返回获取到的 group_rank_id
return group_rank_id
@check_parameter_available
def _create_group_helper(group, rank_ids, backend):
"""
@ -395,34 +427,46 @@ def _create_group_helper(group, rank_ids, backend):
TypeError: If rank_ids is not a list.
ValueError: If rank_ids size is not larger than 1 or rank_ids has duplicate data or backend is invalid.
"""
# 检查组是否已存在
if group in _ExistingGroup.ITEMS.keys():
# 如果组已存在且提供的rank_ids与存储的不一致抛出异常
if rank_ids != _ExistingGroup.ITEMS[group]:
raise ValueError("The group {} has been created, the rank_list is {}, "
"but current rank_list for the group is {}".
format(group, _ExistingGroup.ITEMS[group], rank_ids))
# 记录警告信息,提示组已存在
logger.warning("%r group has existed.", group)
return
# 根据不同的后端创建组
if backend == Backend.HCCL:
# 检查rank_ids是否为列表类型
if not isinstance(rank_ids, list):
raise TypeError("For 'create_group', the argument 'rank_ids' must be type of list, "
"but got 'rank_ids' type : {}.".format(type(rank_ids)))
# 检查rank_ids的长度是否大于1
rank_size = len(rank_ids)
if rank_size < 1:
raise ValueError("For 'create_group', the argument 'rank_ids' size should be greater than 1, "
"but got 'rank_ids' size : {}.".format(len(rank_ids)))
# 检查rank_ids中是否有重复的元素
if len(rank_ids) - len(list(set(rank_ids))) > 0:
raise ValueError("List rank_ids in Group {} has duplicate data!".format(group))
# 使用HCCL创建组
hccl.create_group(group, rank_size, rank_ids)
elif backend == Backend.HCCL_MPI:
# 使用HCCL_MPI创建组
mpi.create_group(group, rank_ids)
elif backend == Backend.NCCL:
# NCCL暂不支持创建组抛出异常
raise RuntimeError("Nccl doesn't support create_group now.")
else:
# 如果后端不支持,抛出异常
raise ValueError("The context configuration parameter 'backend' {} is not supported, "
"please use hccl.".format(backend))
_ExistingGroup.ITEMS[group] = rank_ids
# 将新创建的组及其rank_ids添加到_existingGroup中
_ExistingGroup.ITEMS[group] = rank_ids
@check_parameter_available
def _destroy_group_helper(group, backend):
"""
@ -435,12 +479,17 @@ def _destroy_group_helper(group, backend):
Raises:
ValueError: If group is "hccl_world_group" or backend is invalid.
"""
# 根据后端类型销毁通信组
if backend == Backend.HCCL:
# 检查是否为 HCCL 的全局通信组
if group == HCCL_WORLD_COMM_GROUP:
raise ValueError("The hccl_world_group does not support destruction.")
# 销毁指定的 HCCL 通信组
hccl.destroy_group(group)
elif backend == Backend.NCCL:
# 当前 NCCL 后端不支持销毁通信组
raise RuntimeError("Nccl doesn't support destroy_group now.")
else:
# 抛出错误,表示不支持的后端类型
raise ValueError("The context configuration parameter 'backend' {} is not supported, "
"please use hccl.".format(backend))
"please use hccl.".format(backend))

@ -24,7 +24,7 @@ MAX_RANK_NUM = 4096
# File name of the HCCL plugin shared library.
# NOTE(review): presumably used by load_lib() to build the library path — confirm.
HCCL_LIB = 'libhccl_plugin.so'
# Placeholder until load_lib() rebinds this name to the loaded ctypes.CDLL handle.
HCCL_LIB_CTYPES = ""
# Check that the name of a collective communication group is legal.
def check_group(group):
"""
A function that check if a collection communication group is legal.
@ -41,7 +41,7 @@ def check_group(group):
raise TypeError("The type of communication group name must be type of string, "
"but got 'group' type : {}.".format(type(group)))
# Check that a rank number used in collective communication is legal.
def check_rank_num(rank_num):
"""
A function that check if a collection communication rank number is legal.If not raise error.
@ -57,7 +57,7 @@ def check_rank_num(rank_num):
raise TypeError("The argument 'rank_num' must be type of int, "
"but got 'rank_num' type : {}.".format(type(rank_num)))
# Check that a rank id used in collective communication is legal.
def check_rank_id(rank_id):
"""
A function that check if a collection communication rank id is legal.If not raise error.
@ -65,15 +65,18 @@ def check_rank_id(rank_id):
Returns:
None
"""
# 检查rank_id是否为整数类型
if isinstance(rank_id, (int)):
# 检查rank_id是否在有效范围内
if rank_id >= MAX_RANK_NUM or rank_id < 0:
raise ValueError("The rand id in the communication group must be greater or equal 0 and "
"less than {}, but got type value : {}.".format(MAX_RANK_NUM, rank_id))
else:
# 如果rank_id不是整数类型抛出类型错误
raise TypeError("The rand id in the communication group must be must be type of int, "
"but got type value : {}.".format(type(rank_id)))
# Load the HCCL (Huawei Collective Communication Library) shared library.
def load_lib():
"""load hccl lib"""
try:
@ -82,11 +85,10 @@ def load_lib():
hccl_lib = ctypes.CDLL(lib_path)
except Exception:
raise RuntimeError('Get hccl lib error.')
global HCCL_LIB_CTYPES
HCCL_LIB_CTYPES = hccl_lib
def c_str(string):
"""Convert a python string to C string."""
if not isinstance(string, str):
@ -98,7 +100,7 @@ def c_array(ctype, values):
"""Create ctypes array from a python array."""
return (ctype * len(values))(*values)
# Create an HCCL communication group with the given rank number and rank ids; the world group cannot be created this way.
def create_group(group, rank_num, rank_ids):
"""
Create group.
@ -112,28 +114,38 @@ def create_group(group, rank_num, rank_ids):
Returns:
None
"""
# 检查组的有效性
check_group(group)
# 检查排名数量的有效性
check_rank_num(rank_num)
# 检查rank_ids是否为列表类型
if isinstance(rank_ids, (list)):
# 确保rank_num与rank_ids的长度一致
if rank_num != len(rank_ids):
raise ValueError("The argument 'rank_num' number should be equal to the length "
"of rank_ids, but got 'rank_num' value : {} and 'rank_ids' value : {}."
.format(rank_num, rank_ids))
# 检查rank_ids中的每个元素是否为非负整数
for rank_id in rank_ids:
if not isinstance(rank_id, (int)) or rank_id < 0:
raise ValueError("The elements of argument 'rank_ids' must be "
"unsigned integer, but got the type : {}".format(type(rank_id)))
# 将rank_ids转换为C类型的数组
c_array_rank_ids = c_array(ctypes.c_uint, rank_ids)
# 将rank_num转换为C类型的无符号整数
c_rank_num = ctypes.c_uint(rank_num)
# 将group转换为C类型的字符串
c_group = c_str(group)
# 调用HCCL库创建组
ret = HCCL_LIB_CTYPES.HcomCreateGroup(c_group, c_rank_num, c_array_rank_ids)
# 检查组创建是否成功
if ret != 0:
raise RuntimeError('Create group error, the error code is {}.'.format(ret))
else:
# 如果rank_ids不是列表类型抛出类型错误
raise TypeError("For 'create_group', the argument 'rank_ids' must be type of list, "
"but got 'rank_ids' type : {}.".format(type(rank_ids)))
# Destroy an HCCL group that was created by the user.
def destroy_group(group):
"""
A function that destroy the group which created by user.
@ -162,16 +174,23 @@ def get_rank_size(group="hccl_world_group"):
An integer scalar with the num of ranks.
"""
# 根据上下文的模式判断是否为PYNATIVE_MODE模式若是则直接返回HCCL的rank size
if context.get_context("mode") == context.PYNATIVE_MODE:
return get_hccl_rank_size()
# 检查给定的组是否有效
check_group(group)
# 将组名转换为C字符串格式
c_group = c_str(group)
# 定义一个C类型的无符号整数用于存储rank size
c_rank_size = ctypes.c_uint()
# 调用HCCL库的HcomGetRankSize函数获取组内的rank size
ret = HCCL_LIB_CTYPES.HcomGetRankSize(c_group, ctypes.byref(c_rank_size))
# 如果返回值不为0表示获取rank size失败抛出运行时错误
if ret != 0:
raise RuntimeError('Get rank size error.')
# 返回获取到的rank size值
return c_rank_size.value
@ -186,17 +205,22 @@ def get_rank_id(group="hccl_world_group"):
if context.get_context("mode") == context.PYNATIVE_MODE:
return get_hccl_rank_id()
# 检查组的有效性
check_group(group)
# 将组转换为 C 字符串格式
c_group = c_str(group)
# 定义一个用于存储 rank id 的 ctypes 无符号整数
c_rank_id = ctypes.c_uint()
# 调用 HCCL 库获取当前进程的 rank id
ret = HCCL_LIB_CTYPES.HcomGetRankId(c_group, ctypes.byref(c_rank_id))
# 如果返回值不为 0表示获取 rank id 出错,抛出 RuntimeError 异常
if ret != 0:
raise RuntimeError('Get rank id error.')
# 返回获取到的 rank id 值
return c_rank_id.value
def get_local_rank_size(group="hccl_world_group"):
"""
A function that returns the number of local ranks within the given collection communication group.
@ -207,19 +231,25 @@ def get_local_rank_size(group="hccl_world_group"):
Returns:
An integer scalar with the num of local ranks.
"""
# 检查当前运行模式是否为PYNATIVE_MODE如果是则抛出异常
if context.get_context("mode") is context.PYNATIVE_MODE:
raise RuntimeError("The function 'get_local_rank_size' is not supported in PYNATIVE_MODE, "
"'get_local_rank_size' only support GRAPH_MODE")
# 验证传入的组是否有效
check_group(group)
# 将组名称转换为C字符串格式
c_group = c_str(group)
# 定义一个ctypes的无符号整数变量用于存储本地排名大小
c_local_rank_size = ctypes.c_uint()
# 调用HCCL库中的HcomGetLocalRankSize函数获取本地排名大小
ret = HCCL_LIB_CTYPES.HcomGetLocalRankSize(c_group, ctypes.byref(c_local_rank_size))
# 如果返回值不为0说明获取本地排名大小时出错抛出异常
if ret != 0:
raise RuntimeError('Get local rank size error.')
# 返回获取到的本地排名大小
return c_local_rank_size.value
def get_local_rank_id(group="hccl_world_group"):
"""
Get local rank id.
@ -230,16 +260,23 @@ def get_local_rank_id(group="hccl_world_group"):
An integer scalar with the local rank id of the calling process.
"""
# 检查当前运行模式是否为PYNATIVE_MODE如果是则抛出异常
if context.get_context("mode") is context.PYNATIVE_MODE:
raise RuntimeError("The function 'get_local_rank_id' is not supported in PYNATIVE_MODE, "
"'get_local_rank_id' only support GRAPH_MODE")
# 验证群组的有效性
check_group(group)
# 将群组名称转换为C字符串格式
c_group = c_str(group)
# 定义一个无符号整型的C类型变量来存储本地排名ID
c_local_rank_id = ctypes.c_uint()
# 调用HCCL库的HcomGetLocalRankId函数获取本地排名ID
ret = HCCL_LIB_CTYPES.HcomGetLocalRankId(c_group, ctypes.byref(c_local_rank_id))
# 如果返回值不为0表示获取本地排名ID失败抛出异常
if ret != 0:
raise RuntimeError('Get local rank id error.')
# 返回获取到的本地排名ID值
return c_local_rank_id.value
@ -256,18 +293,25 @@ def get_world_rank_from_group_rank(group, group_rank_id):
if context.get_context("mode") is context.PYNATIVE_MODE:
raise RuntimeError("The function 'get_world_rank_from_group_rank' is not supported in PYNATIVE_MODE, "
"'get_world_rank_from_group_rank' only support GRAPH_MODE")
# 检查组名是否有效
check_group(group)
# 检查组内rank ID是否有效
check_rank_id(group_rank_id)
# 将组名转换为C字符串格式
c_group = c_str(group)
# 将组内rank ID转换为C的无符号整数类型
c_group_rank_id = ctypes.c_uint(group_rank_id)
# 声明一个用于存储世界rank ID的C的无符号整数类型变量
c_world_rank_id = ctypes.c_uint()
# 调用HCCL库中的HcomGetWorldRankFromGroupRank函数根据组名和组内rank ID获取对应的世界rank ID
ret = HCCL_LIB_CTYPES.HcomGetWorldRankFromGroupRank(c_group, c_group_rank_id, ctypes.byref(c_world_rank_id))
# 如果返回值不为0说明函数调用出错抛出RuntimeError异常
if ret != 0:
raise RuntimeError('Get world rank from group rank error.')
raise RuntimeError('根据组内rank ID获取世界rank ID时出错。')
# 返回获取到的世界rank ID的值
return c_world_rank_id.value
def get_group_rank_from_world_rank(world_rank_id, group):
"""
Get group rank from world rank.
@ -281,13 +325,21 @@ def get_group_rank_from_world_rank(world_rank_id, group):
if context.get_context("mode") is context.PYNATIVE_MODE:
raise RuntimeError("The function 'get_group_rank_from_world_rank' is not supported in PYNATIVE_MODE, "
"'get_group_rank_from_world_rank' only support GRAPH_MODE")
# 检查组的有效性
check_group(group)
# 检查世界排名ID的有效性
check_rank_id(world_rank_id)
# 将组转换为C字符串
c_group = c_str(group)
# 将世界排名ID转换为C无符号整数
c_world_rank_id = ctypes.c_uint(world_rank_id)
# 创建一个用于存储组排名ID的C无符号整数
c_group_rank_id = ctypes.c_uint()
# 从世界排名ID获取组排名ID
ret = HCCL_LIB_CTYPES.HcomGetGroupRankFromWorldRank(c_world_rank_id, c_group, ctypes.byref(c_group_rank_id))
# 如果返回值不为0抛出运行时错误
if ret != 0:
raise RuntimeError('Get group rank from world rank error.')
return c_group_rank_id.value
# 返回获取到的组排名ID的值
return c_group_rank_id.value

@ -63,25 +63,31 @@ def _check_parallel_envs():
Raises:
RuntimeError: If parallel environment variables have not been exported or have been exported to wrong values.
"""
# 检查是否需要进行环境验证,如果不进行则直接返回
if not GlobalComm.CHECK_ENVS:
return
import os
# 获取环境变量RANK_ID的值
rank_id_str = os.getenv("RANK_ID")
# 如果RANK_ID未设置抛出运行时错误
if not rank_id_str:
raise RuntimeError("Environment variables RANK_ID has not been exported, please export variables 'RANK_ID'.")
try:
# 尝试将RANK_ID转换为整数
int(rank_id_str)
except ValueError:
# 如果转换失败,打印错误信息
print("Environment variables 'RANK_ID' should be number, but got the type : {}".format(type(rank_id_str)))
finally:
# 无论是否发生异常,此块为空操作
pass
# 获取环境变量MINDSPORE_HCCL_CONFIG_PATH和RANK_TABLE_FILE的值
rank_table_file_str = os.getenv("MINDSPORE_HCCL_CONFIG_PATH")
rank_table_file_str_old = os.getenv("RANK_TABLE_FILE")
# 如果两个环境变量都未设置,抛出运行时错误
if not rank_table_file_str and not rank_table_file_str_old:
raise RuntimeError("Get hccl rank_table_file failed, "
"please export MINDSPORE_HCCL_CONFIG_PATH or RANK_TABLE_FILE.")
def init(backend_name=None):
"""
Initialize distributed backend, e.g. HCCL/NCCL, it is required before using the communication service.
@ -109,15 +115,19 @@ def init(backend_name=None):
>>> from mindspore.communication import init
>>> init()
"""
# 检查当前角色是否为参数服务器或调度器,如果是则直接返回
if _is_role_pserver() or _is_role_sched():
return
# 检查任务接收环境变量,获取设备目标和模式
task_sink = _check_task_sink_envs()
device_target = context.get_context("device_target")
mode = context.get_context("mode")
mpi_init = False
# 如果没有任务接收且模式为图模式设置mpi_init为True
if not task_sink and mode == context.GRAPH_MODE:
mpi_init = True
# 根据设备目标选择后端名称,如果不支持则抛出异常
if backend_name is None:
if device_target == "Ascend":
backend_name = "hccl"
@ -126,28 +136,35 @@ def init(backend_name=None):
else:
raise RuntimeError("For 'set_context', the argument 'device_target' {} is not supported in "
"parallel initialization, please use Ascend or GPU.".format(device_target))
# 检查后端名称是否为字符串,如果不是则抛出异常
if not isinstance(backend_name, str):
raise TypeError("For 'init', the argument 'backend_name' must be a string, "
"but got the type : {}".format(type(backend_name)))
# 根据后端名称初始化通信环境
if backend_name == "hccl":
# 如果设备目标不是Ascend抛出异常
if device_target != "Ascend":
raise RuntimeError("For 'init', the argument 'backend_name' should be 'Ascend' to init hccl, "
"but got {}".format(device_target))
# 如果不需要MPI初始化检查并行环境并设置后端名称
if not mpi_init:
_check_parallel_envs()
GlobalComm.BACKEND = Backend("hccl")
else:
GlobalComm.BACKEND = Backend("hccl_mpi")
# 初始化HCCL并设置全局通信组和初始化状态
init_hccl()
GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
GlobalComm.INITED = True
elif backend_name == "nccl":
# 初始化GPU集体通信并设置后端名称、全局通信组和初始化状态
init_gpu_collective()
GlobalComm.BACKEND = Backend("nccl")
GlobalComm.WORLD_COMM_GROUP = NCCL_WORLD_COMM_GROUP
GlobalComm.INITED = True
else:
# 如果后端名称不支持,抛出异常
raise RuntimeError("For 'init', the argument 'backend_name' must be nccl while 'device_target' is GPU, "
"but got the 'backend_name' : hccl.")

Loading…
Cancel
Save