diff --git a/src/mindspore2022/mindspore/python/mindspore/communication/_comm_helper.py b/src/mindspore2022/mindspore/python/mindspore/communication/_comm_helper.py index 9af56409..28802f9f 100644 --- a/src/mindspore2022/mindspore/python/mindspore/communication/_comm_helper.py +++ b/src/mindspore2022/mindspore/python/mindspore/communication/_comm_helper.py @@ -18,15 +18,21 @@ from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched from mindspore import log as logger from ._hccl_management import load_lib as hccl_load_lib from .._c_expression import get_rank_id, get_rank_size - + +# 检查HCCL是否可用 _HCCL_AVAILABLE = False +# 检查HCCL测试是否可用 _HCCL_TEST_AVAILABLE = False +# 检查NCCL是否可用 _NCCL_AVAILABLE = False +# 检查MPI是否可用 _MPI_AVAILABLE = False try: + # 尝试导入mindspore._ms_mpi,如果成功则NCCL可用 import mindspore._ms_mpi as mpi _NCCL_AVAILABLE = True except ImportError: + # 如果导入失败,则NCCL不可用 _NCCL_AVAILABLE = False # 尝试加载 HCCL 库,如果成功则设置 _HCCL_AVAILABLE 为 True,否则捕获 RuntimeError 并设置为 False @@ -102,41 +108,41 @@ DEFAULT_BACKEND = Backend("hccl") class GlobalComm: """ World communication information. The GlobalComm is a global class. The members contain: - + - BACKEND: The communication library used, using HCCL/NCCL. - WORLD_COMM_GROUP: Global communication domain. """ BACKEND = DEFAULT_BACKEND WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP - INITED = False - CHECK_ENVS = True + INITED = False # 标记全局通信是否已初始化 + CHECK_ENVS = True # 标记是否需要检查通信环境变量 class _ExistingGroup: """ - The communication groups which exist in the progress. + 用于表示在程序运行过程中存在的通信组。 """ - ITEMS = {} - - + ITEMS = {} # 存储通信组的字典,键为通信组的标识符,值为通信组对象 + + def is_hccl_available(): """ - Check HCCL api is available. - + 检查HCCL API是否可用。 + Returns: - Boolean. Return whether HCCL is available or not. + Boolean: 返回HCCL是否可用。 """ - return _HCCL_AVAILABLE - - + return _HCCL_AVAILABLE # 返回一个布尔值,指示HCCL是否可用 + + def is_mpi_available(): """ - Check HCCL & MPI api is available. - + 检查HCCL和MPI API是否可用。 + Returns: - Boolean. Return whether HCCL & MPI is available or not. + Boolean: 返回HCCL和MPI是否同时可用。 """ - return _MPI_AVAILABLE + return _MPI_AVAILABLE # 返回一个布尔值,指示HCCL和MPI是否同时可用 def is_nccl_available(): """ @@ -253,18 +259,26 @@ def _get_local_rank_helper(group, backend): """ # 获取当前进程的rank id,根据不同的后端和组进行处理 rank_id = None + # 根据不同的后端选择获取rank_id的方法 if backend == Backend.HCCL_MPI: + # 使用HCCL MPI后端时,通过mpi.get_rank_id获取rank_id rank_id = mpi.get_rank_id(group) elif backend == Backend.HCCL: + # 使用HCCL后端时,根据group的不同选择获取rank_id的方法 if group == HCCL_WORLD_COMM_GROUP: + # 如果group是HCCL_WORLD_COMM_GROUP,则使用hccl.get_local_rank_id获取rank_id rank_id = hccl.get_local_rank_id() else: + # 对于其他group,同样使用hccl.get_local_rank_id获取rank_id rank_id = hccl.get_local_rank_id(group) elif backend == Backend.NCCL: + # 如果使用NCCL后端,当前不支持get_local_rank_id方法,抛出异常 raise RuntimeError("Nccl doesn't support get_local_rank_id now.") else: + # 如果backend既不是HCCL_MPI也不是HCCL,抛出异常表示不支持的backend raise ValueError("For '_get_local_rank_helper', the argument 'backend' {} is not supported, " "please use hccl_mpi or hccl.".format(backend)) + # 返回获取到的rank_id return rank_id @check_parameter_available @@ -288,6 +302,7 @@ def _get_size_helper(group, backend): size = 1 return size # 根据不同的后端设置size的值 + # 根据不同的后端获取组的大小 if backend == Backend.HCCL_MPI: size = mpi.get_rank_size(group) elif backend == Backend.HCCL: diff --git a/src/mindspore2022/mindspore/python/mindspore/communication/_hccl_management.py b/src/mindspore2022/mindspore/python/mindspore/communication/_hccl_management.py index 8307a839..1841b222 100644 --- a/src/mindspore2022/mindspore/python/mindspore/communication/_hccl_management.py +++ b/src/mindspore2022/mindspore/python/mindspore/communication/_hccl_management.py @@ -44,18 +44,26 @@ def check_group(group): # 检查集体通信中的排名编号是否合法 def check_rank_num(rank_num): """ - A function that check if a collection communication rank number is legal.If not raise error. + 检查通信集合中的排名编号是否合法。如果不合法则抛出错误。 + + 参数: + rank_num: 需要检查的排名编号,预期为整数类型。 Returns: - None + None: 该函数不返回任何值,但可能会抛出异常。 """ - if isinstance(rank_num, (int)): + # 检查 rank_num 是否为整数类型 + if isinstance(rank_num, int): + # 检查 rank_num 是否在合法范围内(大于0且小于等于 MAX_RANK_NUM) if rank_num > MAX_RANK_NUM or rank_num <= 0: - raise ValueError("For 'create_group', the size of argument 'rand_ids' should be greater than 0 and" - "less than {}, but got the size of 'rank_ids' : {}.".format(MAX_RANK_NUM, rank_num)) + # 如果 rank_num 不在合法范围内,抛出 ValueError 异常,并提供详细的错误信息 + raise ValueError("对于 'create_group' 函数,参数 'rank_ids' 的大小必须大于0且" + "小于等于 {},但得到的 'rank_ids' 的大小为: {}。".format(MAX_RANK_NUM, rank_num)) else: - raise TypeError("The argument 'rank_num' must be type of int, " - "but got 'rank_num' type : {}.".format(type(rank_num))) + # 如果 rank_num 不是整数类型,抛出 TypeError 异常,并提供详细的错误信息 + raise TypeError("参数 'rank_num' 必须为整数类型," + "但得到的 'rank_num' 类型为: {}。".format(type(rank_num))) + #检查集体通信中的排名标识(rank id)是否合法 def check_rank_id(rank_id): @@ -80,26 +88,32 @@ def check_rank_id(rank_id): def load_lib(): """load hccl lib""" try: + # 获取当前文件所在的目录 base_dir = os.path.dirname(os.path.realpath(__file__)) + # 构建库文件的路径 lib_path = os.path.join(base_dir, "../lib", HCCL_LIB) + # 加载库文件 hccl_lib = ctypes.CDLL(lib_path) except Exception: + # 如果加载失败则抛出运行时错误 raise RuntimeError('Get hccl lib error.') - + + # 将加载的库文件设置为全局变量 global HCCL_LIB_CTYPES HCCL_LIB_CTYPES = hccl_lib - + def c_str(string): """Convert a python string to C string.""" + # 将字符串转换为C风格字符串 if not isinstance(string, str): string = string.decode('ascii') return ctypes.c_char_p(string.encode('utf-8')) - - + + def c_array(ctype, values): """Create ctypes array from a python array.""" + # 从Python数组创建ctypes数组 return (ctype * len(values))(*values) - #用于创建包含指定数量和ID的HCCL通信组,但不能创建世界组。 def create_group(group, rank_num, rank_ids): """ @@ -156,11 +170,16 @@ def destroy_group(group): Returns: None """ - check_group(group) - c_group = c_str(group) - ret = HCCL_LIB_CTYPES.HcomDestroyGroup(c_group) - if ret != 0: - raise RuntimeError('Destroy group error.') + # 检查传入的组是否有效 +check_group(group) +# 将组名转换为C风格的字符串 +c_group = c_str(group) +# 调用HCCL库中的函数销毁指定的组 +ret = HCCL_LIB_CTYPES.HcomDestroyGroup(c_group) +# 如果返回值不为0,说明销毁组时发生了错误,抛出异常 +if ret != 0: + raise RuntimeError('Destroy group error.') + def get_rank_size(group="hccl_world_group"): diff --git a/src/mindspore2022/mindspore/python/mindspore/communication/management.py b/src/mindspore2022/mindspore/python/mindspore/communication/management.py index a00dee56..cba0dc9b 100644 --- a/src/mindspore2022/mindspore/python/mindspore/communication/management.py +++ b/src/mindspore2022/mindspore/python/mindspore/communication/management.py @@ -20,40 +20,51 @@ from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \ _create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \ _get_local_rank_helper, _get_local_size_helper, GlobalComm from .._c_expression import init_hccl, finalize_hccl, init_gpu_collective - +# 导入mindspore上下文模块 +# 导入mindspore的进程服务器和调度器角色判断函数 +# 导入通信辅助模块,包括后端类型、获取和设置rank及size的辅助函数、创建和销毁组的辅助函数、以及世界通信组和本地通信组的相关标识 +# 导入C语言表达式模块,包含HCCL和NCCL的初始化和终结化函数,以及GPU集体通信的初始化函数 __all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size", "get_local_rank_size", "get_world_rank_from_group_rank", "get_group_rank_from_world_rank", "create_group", "destroy_group", "HCCL_WORLD_COMM_GROUP", "NCCL_WORLD_COMM_GROUP"] +# 默认的世界通信组 DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP def _get_group(group): - """Return the world communication group if the `group` is `DEFAULT_WORLD_COMM_GROUP`.""" + """根据输入的`group`参数返回相应的通信组。 + + 如果`group`等于`DEFAULT_WORLD_COMM_GROUP`,则返回全局通信组`GlobalComm.WORLD_COMM_GROUP`; + 否则,返回传入的`group`参数。 + """ if group == DEFAULT_WORLD_COMM_GROUP: - return GlobalComm.WORLD_COMM_GROUP - return group + return GlobalComm.WORLD_COMM_GROUP # 返回默认的世界通信组 + return group # 返回传入的通信组参数 def _check_task_sink_envs(): - """ - Check whether task_sink environment variables have been exported or not. - - return True if task_sink environment variables have been exported, False otherwise. + """检查任务接收器(task_sink)相关的环境变量是否已导出。 + + 该函数通过检查环境变量`GRAPH_OP_RUN`来判断任务接收器的环境变量是否已导出。 + + 返回值: + - 如果环境变量`GRAPH_OP_RUN`已导出且其值可以转换为整数1,则返回False,表示环境变量未正确设置为启用状态。 + - 如果环境变量`GRAPH_OP_RUN`未导出或其值不能转换为整数1,则返回True,表示环境变量未导出或设置有误。 """ import os - task_sink = os.getenv("GRAPH_OP_RUN") + task_sink = os.getenv("GRAPH_OP_RUN") # 获取名为"GRAPH_OP_RUN"的环境变量 if task_sink: try: - if int(task_sink) == 1: - return False + if int(task_sink) == 1: # 尝试将环境变量的值转换为整数并检查是否等于1 + return False # 如果等于1,返回False,表示环境变量已导出但设置为禁用状态(非预期情况) except ValueError: - return True + return True # 如果转换为整数失败,返回True,表示环境变量设置有误 finally: - pass - return True + pass # finally块中的代码在这里是空操作,通常用于清理操作 + return True # 如果环境变量未导出,返回True def _check_parallel_envs(): @@ -128,6 +139,7 @@ def init(backend_name=None): mpi_init = True # 根据设备目标选择后端名称,如果不支持则抛出异常 + # 根据设备目标设置默认的后端名称 if backend_name is None: if device_target == "Ascend": backend_name = "hccl" @@ -140,7 +152,6 @@ def init(backend_name=None): if not isinstance(backend_name, str): raise TypeError("For 'init', the argument 'backend_name' must be a string, " "but got the type : {}".format(type(backend_name))) - # 根据后端名称初始化通信环境 if backend_name == "hccl": # 如果设备目标不是Ascend,抛出异常 @@ -182,10 +193,11 @@ def release(): Examples: >>> from mindspore.communication import init, release - >>> init() - >>> release() + >>> init() # 初始化分布式通信环境 + >>> release() # 释放分布式通信资源 + """ - finalize_hccl() + finalize_hccl()# 结束 HCCL 的使用,释放相关资源 def get_rank(group=GlobalComm.WORLD_COMM_GROUP): @@ -214,12 +226,18 @@ def get_rank(group=GlobalComm.WORLD_COMM_GROUP): >>> print(rank_id) >>> # the result is the rank_id in world_group """ + # 检查传入的group参数是否为字符串类型 if not isinstance(group, str): + # 如果group参数不是字符串类型,则抛出TypeError异常 raise TypeError("For 'get_rank', the argument 'group' must be type of string, " "but got 'group' type : {}.".format(type(group))) + # 调用_get_rank_helper函数,获取指定group的rank ID + # _get_group函数用于解析group参数 + # GlobalComm.BACKEND变量存储了当前使用的通信后端 return _get_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND) + def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP): """ Gets local rank ID for current device in specified collective communication group. @@ -251,9 +269,11 @@ def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP): >>> print("local_rank is: {}, world_rank is {}".format(local_rank, world_rank)) local_rank is: 1, world_rank is 9 """ + # 检查传入的 'group' 参数是否为字符串类型,如果不是则抛出 TypeError if not isinstance(group, str): raise TypeError("For 'get_local_rank', the argument 'group' must be type of string, " "but got 'group' type : {}.".format(type(group))) + # 调用辅助函数 _get_local_rank_helper 来获取本地排名,传入的参数为解析后的组和全局通信后端 return _get_local_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND) @@ -287,9 +307,11 @@ def get_group_size(group=GlobalComm.WORLD_COMM_GROUP): >>> print("group_size is: ", group_size) group_size is: 8 """ + # 检查传入的参数 'group' 是否为字符串类型 if not isinstance(group, str): raise TypeError("For 'get_group_size', the argument 'group' must be type of string, " "but got 'group' type : {}.".format(type(group))) + # 返回指定组的大小,使用辅助函数 _get_size_helper 和全局通信后端 return _get_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND) @@ -323,9 +345,11 @@ def get_local_rank_size(group=GlobalComm.WORLD_COMM_GROUP): >>> print("local_rank_size is: ", local_rank_size) local_rank_size is: 8 """ + # 检查传入的 'group' 参数是否为字符串类型,如果不是则抛出 TypeError if not isinstance(group, str): raise TypeError("For 'get_local_rank_size', the argument 'group' must be type of string, " "but got 'group' type : {}.".format(type(group))) + # 调用辅助函数获取本地组的大小,使用 _get_group 函数获取组,并使用 GlobalComm.BACKEND 作为后端 return _get_local_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND) @@ -364,12 +388,12 @@ def get_world_rank_from_group_rank(group, group_rank_id): >>> print("world_rank_id is: ", world_rank_id) world_rank_id is: 4 """ + # 检查传入的 group 参数是否为字符串类型 if not isinstance(group, str): raise TypeError("For 'get_world_rank_from_group_rank', the argument 'group' must be type of string, " "but got 'group' type : {}.".format(type(group))) + # 返回根据组和组排名获取的世界排名的帮助函数的结果 return _get_world_rank_from_group_rank_helper(group=group, group_rank_id=group_rank_id, backend=GlobalComm.BACKEND) - - def get_group_rank_from_world_rank(world_rank_id, group): """ Get the rank ID in the specified user communication group corresponding to @@ -406,12 +430,13 @@ def get_group_rank_from_world_rank(world_rank_id, group): >>> print("group_rank_id is: ", group_rank_id) group_rank_id is: 1 """ + # 检查输入参数 'group' 是否为字符串类型,如果不是则抛出 TypeError if not isinstance(group, str): raise TypeError("For 'get_group_rank_from_world_rank', the argument 'group' must be type of string, " "but got 'group' type : {}.".format(type(group))) + # 调用辅助函数 _get_group_rank_from_world_rank_helper 来获取组的排名,并返回结果 return _get_group_rank_from_world_rank_helper(world_rank_id=world_rank_id, group=group, backend=GlobalComm.BACKEND) - - +#创建一个用户集体通信组,通过传入通信组名称和设备ID列表来实现。 def create_group(group, rank_ids): """ Create a user collective communication group. @@ -444,9 +469,11 @@ def create_group(group, rank_ids): >>> create_group(group, rank_ids) >>> allreduce = ops.AllReduce(group) """ + # 检查传入的 group 参数是否为字符串类型,如果不是则抛出 TypeError if not isinstance(group, str): raise TypeError("For 'create_group', the argument 'group' must be type of string, " "but got 'group' type : {}.".format(type(group))) + # 调用辅助函数 _create_group_helper 来创建组,使用指定的 rank_ids 和后端 _create_group_helper(group, rank_ids, backend=GlobalComm.BACKEND) @@ -467,7 +494,9 @@ def destroy_group(group): ValueError: If group is "hccl_world_group" or backend is invalid. RuntimeError: If HCCL is not available or MindSpore is GPU version. """ + # 检查传入的 group 参数是否为字符串类型,如果不是则抛出 TypeError if not isinstance(group, str): raise TypeError("For 'destroy_group', the argument 'group' must be type of string, " "but got 'group' type : {}.".format(type(group))) - _destroy_group_helper(group, backend=GlobalComm.BACKEND) + # 调用辅助函数 _destroy_group_helper 来销毁指定的组,并使用全局通信后端 + _destroy_group_helper(group, backend=GlobalComm.BACKEND) \ No newline at end of file