communication_ruiqin2

Branch_ruiqin
zhang 2 months ago
parent 542726fc1a
commit 759ad9d09c

@ -18,15 +18,21 @@ from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
from mindspore import log as logger
from ._hccl_management import load_lib as hccl_load_lib
from .._c_expression import get_rank_id, get_rank_size
# 检查HCCL是否可用
_HCCL_AVAILABLE = False
# 检查HCCL测试是否可用
_HCCL_TEST_AVAILABLE = False
# 检查NCCL是否可用
_NCCL_AVAILABLE = False
# 检查MPI是否可用
_MPI_AVAILABLE = False
try:
# 尝试导入mindspore._ms_mpi如果成功则NCCL可用
import mindspore._ms_mpi as mpi
_NCCL_AVAILABLE = True
except ImportError:
# 如果导入失败则NCCL不可用
_NCCL_AVAILABLE = False
# 尝试加载 HCCL 库,如果成功则设置 _HCCL_AVAILABLE 为 True否则捕获 RuntimeError 并设置为 False
@ -102,41 +108,41 @@ DEFAULT_BACKEND = Backend("hccl")
class GlobalComm:
"""
World communication information. The GlobalComm is a global class. The members contain:
- BACKEND: The communication library used, using HCCL/NCCL.
- WORLD_COMM_GROUP: Global communication domain.
"""
BACKEND = DEFAULT_BACKEND
WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
INITED = False
CHECK_ENVS = True
INITED = False # 标记全局通信是否已初始化
CHECK_ENVS = True # 标记是否需要检查通信环境变量
class _ExistingGroup:
"""
The communication groups which exist in the progress.
用于表示在程序运行过程中存在的通信组
"""
ITEMS = {}
ITEMS = {} # 存储通信组的字典,键为通信组的标识符,值为通信组对象
def is_hccl_available():
"""
Check HCCL api is available.
检查HCCL API是否可用
Returns:
Boolean. Return whether HCCL is available or not.
Boolean: 返回HCCL是否可用
"""
return _HCCL_AVAILABLE
return _HCCL_AVAILABLE # 返回一个布尔值指示HCCL是否可用
def is_mpi_available():
"""
Check HCCL & MPI api is available.
检查HCCL和MPI API是否可用
Returns:
Boolean. Return whether HCCL & MPI is available or not.
Boolean: 返回HCCL和MPI是否同时可用
"""
return _MPI_AVAILABLE
return _MPI_AVAILABLE # 返回一个布尔值指示HCCL和MPI是否同时可用
def is_nccl_available():
"""
@ -253,18 +259,26 @@ def _get_local_rank_helper(group, backend):
"""
# 获取当前进程的rank id根据不同的后端和组进行处理
rank_id = None
# 根据不同的后端选择获取rank_id的方法
if backend == Backend.HCCL_MPI:
# 使用HCCL MPI后端时通过mpi.get_rank_id获取rank_id
rank_id = mpi.get_rank_id(group)
elif backend == Backend.HCCL:
# 使用HCCL后端时根据group的不同选择获取rank_id的方法
if group == HCCL_WORLD_COMM_GROUP:
# 如果group是HCCL_WORLD_COMM_GROUP则使用hccl.get_local_rank_id获取rank_id
rank_id = hccl.get_local_rank_id()
else:
# 对于其他group同样使用hccl.get_local_rank_id获取rank_id
rank_id = hccl.get_local_rank_id(group)
elif backend == Backend.NCCL:
# 如果使用NCCL后端当前不支持get_local_rank_id方法抛出异常
raise RuntimeError("Nccl doesn't support get_local_rank_id now.")
else:
# 如果backend既不是HCCL_MPI也不是HCCL抛出异常表示不支持的backend
raise ValueError("For '_get_local_rank_helper', the argument 'backend' {} is not supported, "
"please use hccl_mpi or hccl.".format(backend))
# 返回获取到的rank_id
return rank_id
@check_parameter_available
@ -288,6 +302,7 @@ def _get_size_helper(group, backend):
size = 1
return size
# 根据不同的后端设置size的值
# 根据不同的后端获取组的大小
if backend == Backend.HCCL_MPI:
size = mpi.get_rank_size(group)
elif backend == Backend.HCCL:

@ -44,18 +44,26 @@ def check_group(group):
# 检查集体通信中的排名编号是否合法
def check_rank_num(rank_num):
"""
A function that check if a collection communication rank number is legal.If not raise error.
检查通信集合中的排名编号是否合法如果不合法则抛出错误
参数:
rank_num: 需要检查的排名编号预期为整数类型
Returns:
None
None: 该函数不返回任何值但可能会抛出异常
"""
if isinstance(rank_num, (int)):
# 检查 rank_num 是否为整数类型
if isinstance(rank_num, int):
# 检查 rank_num 是否在合法范围内大于0且小于等于 MAX_RANK_NUM
if rank_num > MAX_RANK_NUM or rank_num <= 0:
raise ValueError("For 'create_group', the size of argument 'rand_ids' should be greater than 0 and"
"less than {}, but got the size of 'rank_ids' : {}.".format(MAX_RANK_NUM, rank_num))
# 如果 rank_num 不在合法范围内,抛出 ValueError 异常,并提供详细的错误信息
raise ValueError("对于 'create_group' 函数,参数 'rank_ids' 的大小必须大于0且"
"小于等于 {},但得到的 'rank_ids' 的大小为: {}".format(MAX_RANK_NUM, rank_num))
else:
raise TypeError("The argument 'rank_num' must be type of int, "
"but got 'rank_num' type : {}.".format(type(rank_num)))
# 如果 rank_num 不是整数类型,抛出 TypeError 异常,并提供详细的错误信息
raise TypeError("参数 'rank_num' 必须为整数类型,"
"但得到的 'rank_num' 类型为: {}".format(type(rank_num)))
#检查集体通信中的排名标识(rank id)是否合法
def check_rank_id(rank_id):
@ -80,26 +88,32 @@ def check_rank_id(rank_id):
def load_lib():
"""load hccl lib"""
try:
# 获取当前文件所在的目录
base_dir = os.path.dirname(os.path.realpath(__file__))
# 构建库文件的路径
lib_path = os.path.join(base_dir, "../lib", HCCL_LIB)
# 加载库文件
hccl_lib = ctypes.CDLL(lib_path)
except Exception:
# 如果加载失败则抛出运行时错误
raise RuntimeError('Get hccl lib error.')
# 将加载的库文件设置为全局变量
global HCCL_LIB_CTYPES
HCCL_LIB_CTYPES = hccl_lib
def c_str(string):
"""Convert a python string to C string."""
# 将字符串转换为C风格字符串
if not isinstance(string, str):
string = string.decode('ascii')
return ctypes.c_char_p(string.encode('utf-8'))
def c_array(ctype, values):
"""Create ctypes array from a python array."""
# 从Python数组创建ctypes数组
return (ctype * len(values))(*values)
#用于创建包含指定数量和ID的HCCL通信组但不能创建世界组。
def create_group(group, rank_num, rank_ids):
"""
@ -156,11 +170,16 @@ def destroy_group(group):
Returns:
None
"""
check_group(group)
c_group = c_str(group)
ret = HCCL_LIB_CTYPES.HcomDestroyGroup(c_group)
if ret != 0:
raise RuntimeError('Destroy group error.')
# 检查传入的组是否有效
check_group(group)
# 将组名转换为C风格的字符串
c_group = c_str(group)
# 调用HCCL库中的函数销毁指定的组
ret = HCCL_LIB_CTYPES.HcomDestroyGroup(c_group)
# 如果返回值不为0说明销毁组时发生了错误抛出异常
if ret != 0:
raise RuntimeError('Destroy group error.')
def get_rank_size(group="hccl_world_group"):

@ -20,40 +20,51 @@ from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \
_create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \
_get_local_rank_helper, _get_local_size_helper, GlobalComm
from .._c_expression import init_hccl, finalize_hccl, init_gpu_collective
# 导入mindspore上下文模块
# 导入mindspore的进程服务器和调度器角色判断函数
# 导入通信辅助模块包括后端类型、获取和设置rank及size的辅助函数、创建和销毁组的辅助函数、以及世界通信组和本地通信组的相关标识
# 导入C语言表达式模块包含HCCL和NCCL的初始化和终结化函数以及GPU集体通信的初始化函数
__all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size",
"get_local_rank_size", "get_world_rank_from_group_rank",
"get_group_rank_from_world_rank", "create_group", "destroy_group",
"HCCL_WORLD_COMM_GROUP", "NCCL_WORLD_COMM_GROUP"]
# 默认的世界通信组
DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
def _get_group(group):
"""Return the world communication group if the `group` is `DEFAULT_WORLD_COMM_GROUP`."""
"""根据输入的`group`参数返回相应的通信组。
如果`group`等于`DEFAULT_WORLD_COMM_GROUP`则返回全局通信组`GlobalComm.WORLD_COMM_GROUP`
否则返回传入的`group`参数
"""
if group == DEFAULT_WORLD_COMM_GROUP:
return GlobalComm.WORLD_COMM_GROUP
return group
return GlobalComm.WORLD_COMM_GROUP # 返回默认的世界通信组
return group # 返回传入的通信组参数
def _check_task_sink_envs():
"""
Check whether task_sink environment variables have been exported or not.
return True if task_sink environment variables have been exported, False otherwise.
"""检查任务接收器task_sink相关的环境变量是否已导出。
该函数通过检查环境变量`GRAPH_OP_RUN`来判断任务接收器的环境变量是否已导出
返回值
- 如果环境变量`GRAPH_OP_RUN`已导出且其值可以转换为整数1则返回False表示环境变量未正确设置为启用状态
- 如果环境变量`GRAPH_OP_RUN`未导出或其值不能转换为整数1则返回True表示环境变量未导出或设置有误
"""
import os
task_sink = os.getenv("GRAPH_OP_RUN")
task_sink = os.getenv("GRAPH_OP_RUN") # 获取名为"GRAPH_OP_RUN"的环境变量
if task_sink:
try:
if int(task_sink) == 1:
return False
if int(task_sink) == 1: # 尝试将环境变量的值转换为整数并检查是否等于1
return False # 如果等于1返回False表示环境变量已导出但设置为禁用状态非预期情况
except ValueError:
return True
return True # 如果转换为整数失败返回True表示环境变量设置有误
finally:
pass
return True
pass # finally块中的代码在这里是空操作通常用于清理操作
return True # 如果环境变量未导出返回True
def _check_parallel_envs():
@ -128,6 +139,7 @@ def init(backend_name=None):
mpi_init = True
# 根据设备目标选择后端名称,如果不支持则抛出异常
# 根据设备目标设置默认的后端名称
if backend_name is None:
if device_target == "Ascend":
backend_name = "hccl"
@ -140,7 +152,6 @@ def init(backend_name=None):
if not isinstance(backend_name, str):
raise TypeError("For 'init', the argument 'backend_name' must be a string, "
"but got the type : {}".format(type(backend_name)))
# 根据后端名称初始化通信环境
if backend_name == "hccl":
# 如果设备目标不是Ascend抛出异常
@ -182,10 +193,11 @@ def release():
Examples:
>>> from mindspore.communication import init, release
>>> init()
>>> release()
>>> init() # 初始化分布式通信环境
>>> release() # 释放分布式通信资源
"""
finalize_hccl()
finalize_hccl()# 结束 HCCL 的使用,释放相关资源
def get_rank(group=GlobalComm.WORLD_COMM_GROUP):
@ -214,12 +226,18 @@ def get_rank(group=GlobalComm.WORLD_COMM_GROUP):
>>> print(rank_id)
>>> # the result is the rank_id in world_group
"""
# 检查传入的group参数是否为字符串类型
if not isinstance(group, str):
# 如果group参数不是字符串类型则抛出TypeError异常
raise TypeError("For 'get_rank', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group)))
# 调用_get_rank_helper函数获取指定group的rank ID
# _get_group函数用于解析group参数
# GlobalComm.BACKEND变量存储了当前使用的通信后端
return _get_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP):
"""
Gets local rank ID for current device in specified collective communication group.
@ -251,9 +269,11 @@ def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP):
>>> print("local_rank is: {}, world_rank is {}".format(local_rank, world_rank))
local_rank is: 1, world_rank is 9
"""
# 检查传入的 'group' 参数是否为字符串类型,如果不是则抛出 TypeError
if not isinstance(group, str):
raise TypeError("For 'get_local_rank', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group)))
# 调用辅助函数 _get_local_rank_helper 来获取本地排名,传入的参数为解析后的组和全局通信后端
return _get_local_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
@ -287,9 +307,11 @@ def get_group_size(group=GlobalComm.WORLD_COMM_GROUP):
>>> print("group_size is: ", group_size)
group_size is: 8
"""
# 检查传入的参数 'group' 是否为字符串类型
if not isinstance(group, str):
raise TypeError("For 'get_group_size', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group)))
# 返回指定组的大小,使用辅助函数 _get_size_helper 和全局通信后端
return _get_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
@ -323,9 +345,11 @@ def get_local_rank_size(group=GlobalComm.WORLD_COMM_GROUP):
>>> print("local_rank_size is: ", local_rank_size)
local_rank_size is: 8
"""
# 检查传入的 'group' 参数是否为字符串类型,如果不是则抛出 TypeError
if not isinstance(group, str):
raise TypeError("For 'get_local_rank_size', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group)))
# 调用辅助函数获取本地组的大小,使用 _get_group 函数获取组,并使用 GlobalComm.BACKEND 作为后端
return _get_local_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
@ -364,12 +388,12 @@ def get_world_rank_from_group_rank(group, group_rank_id):
>>> print("world_rank_id is: ", world_rank_id)
world_rank_id is: 4
"""
# 检查传入的 group 参数是否为字符串类型
if not isinstance(group, str):
raise TypeError("For 'get_world_rank_from_group_rank', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group)))
# 返回根据组和组排名获取的世界排名的帮助函数的结果
return _get_world_rank_from_group_rank_helper(group=group, group_rank_id=group_rank_id, backend=GlobalComm.BACKEND)
def get_group_rank_from_world_rank(world_rank_id, group):
"""
Get the rank ID in the specified user communication group corresponding to
@ -406,12 +430,13 @@ def get_group_rank_from_world_rank(world_rank_id, group):
>>> print("group_rank_id is: ", group_rank_id)
group_rank_id is: 1
"""
# 检查输入参数 'group' 是否为字符串类型,如果不是则抛出 TypeError
if not isinstance(group, str):
raise TypeError("For 'get_group_rank_from_world_rank', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group)))
# 调用辅助函数 _get_group_rank_from_world_rank_helper 来获取组的排名,并返回结果
return _get_group_rank_from_world_rank_helper(world_rank_id=world_rank_id, group=group, backend=GlobalComm.BACKEND)
#创建一个用户集体通信组通过传入通信组名称和设备ID列表来实现。
def create_group(group, rank_ids):
"""
Create a user collective communication group.
@ -444,9 +469,11 @@ def create_group(group, rank_ids):
>>> create_group(group, rank_ids)
>>> allreduce = ops.AllReduce(group)
"""
# 检查传入的 group 参数是否为字符串类型,如果不是则抛出 TypeError
if not isinstance(group, str):
raise TypeError("For 'create_group', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group)))
# 调用辅助函数 _create_group_helper 来创建组,使用指定的 rank_ids 和后端
_create_group_helper(group, rank_ids, backend=GlobalComm.BACKEND)
@ -467,7 +494,9 @@ def destroy_group(group):
ValueError: If group is "hccl_world_group" or backend is invalid.
RuntimeError: If HCCL is not available or MindSpore is GPU version.
"""
# 检查传入的 group 参数是否为字符串类型,如果不是则抛出 TypeError
if not isinstance(group, str):
raise TypeError("For 'destroy_group', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group)))
_destroy_group_helper(group, backend=GlobalComm.BACKEND)
# 调用辅助函数 _destroy_group_helper 来销毁指定的组,并使用全局通信后端
_destroy_group_helper(group, backend=GlobalComm.BACKEND)
Loading…
Cancel
Save