|
|
|
@ -20,40 +20,51 @@ from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \
|
|
|
|
|
_create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \
|
|
|
|
|
_get_local_rank_helper, _get_local_size_helper, GlobalComm
|
|
|
|
|
from .._c_expression import init_hccl, finalize_hccl, init_gpu_collective
|
|
|
|
|
|
|
|
|
|
# Import the MindSpore context module.
|
|
|
|
|
# Import MindSpore's role-check functions (server role and scheduler role).
|
|
|
|
|
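# Import the communication helpers: the Backend type, the rank/size and local rank/size getter helpers, the group create/destroy helpers, the HCCL/NCCL world communication group names, and GlobalComm.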
|
|
|
|
|
# Import from the _c_expression extension module: the HCCL init/finalize functions and the GPU collective-communication init function.
|
|
|
|
|
|
|
|
|
|
__all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size",
|
|
|
|
|
"get_local_rank_size", "get_world_rank_from_group_rank",
|
|
|
|
|
"get_group_rank_from_world_rank", "create_group", "destroy_group",
|
|
|
|
|
"HCCL_WORLD_COMM_GROUP", "NCCL_WORLD_COMM_GROUP"]
|
|
|
|
|
|
|
|
|
|
# Default world communication group.
|
|
|
|
|
DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_group(group):
    """Resolve the default world-group alias to the active world communication group.

    Args:
        group: Name of a communication group. The public callers in this file
            check that it is a string before passing it here; this helper does
            no validation of its own.

    Returns:
        `GlobalComm.WORLD_COMM_GROUP` when `group` equals
        `DEFAULT_WORLD_COMM_GROUP`; otherwise `group` itself, unchanged.
    """
    if group == DEFAULT_WORLD_COMM_GROUP:
        return GlobalComm.WORLD_COMM_GROUP
    return group
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _check_task_sink_envs():
    """
    Check the task_sink switch carried by the `GRAPH_OP_RUN` environment variable.

    Returns:
        bool, False only when `GRAPH_OP_RUN` is set to a value that `int()`
        parses as 1. True in every other case: the variable is unset or empty,
        holds a different integer, or cannot be parsed as an integer.

    NOTE(review): presumably `GRAPH_OP_RUN=1` means task sink is switched off,
    so True reads as "task sink stays enabled" - confirm against the callers.
    """
    import os

    task_sink = os.getenv("GRAPH_OP_RUN")
    if not task_sink:
        # Unset or empty: nothing overrides the default.
        return True
    try:
        return int(task_sink) != 1
    except ValueError:
        # A non-integer value is handled the same way as "not set to 1".
        return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _check_parallel_envs():
|
|
|
|
@ -128,6 +139,7 @@ def init(backend_name=None):
|
|
|
|
|
mpi_init = True
|
|
|
|
|
|
|
|
|
|
# 根据设备目标选择后端名称,如果不支持则抛出异常
|
|
|
|
|
# 根据设备目标设置默认的后端名称
|
|
|
|
|
if backend_name is None:
|
|
|
|
|
if device_target == "Ascend":
|
|
|
|
|
backend_name = "hccl"
|
|
|
|
@ -140,7 +152,6 @@ def init(backend_name=None):
|
|
|
|
|
if not isinstance(backend_name, str):
|
|
|
|
|
raise TypeError("For 'init', the argument 'backend_name' must be a string, "
|
|
|
|
|
"but got the type : {}".format(type(backend_name)))
|
|
|
|
|
|
|
|
|
|
# 根据后端名称初始化通信环境
|
|
|
|
|
if backend_name == "hccl":
|
|
|
|
|
# 如果设备目标不是Ascend,抛出异常
|
|
|
|
@ -182,10 +193,11 @@ def release():
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> from mindspore.communication import init, release
|
|
|
|
|
>>> init()
|
|
|
|
|
>>> release()
|
|
|
|
|
>>> init() # 初始化分布式通信环境
|
|
|
|
|
>>> release() # 释放分布式通信资源
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
finalize_hccl()
|
|
|
|
|
finalize_hccl()# 结束 HCCL 的使用,释放相关资源
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_rank(group=GlobalComm.WORLD_COMM_GROUP):
|
|
|
|
@ -214,12 +226,18 @@ def get_rank(group=GlobalComm.WORLD_COMM_GROUP):
|
|
|
|
|
>>> print(rank_id)
|
|
|
|
|
>>> # the result is the rank_id in world_group
|
|
|
|
|
"""
|
|
|
|
|
# 检查传入的group参数是否为字符串类型
|
|
|
|
|
if not isinstance(group, str):
|
|
|
|
|
# 如果group参数不是字符串类型,则抛出TypeError异常
|
|
|
|
|
raise TypeError("For 'get_rank', the argument 'group' must be type of string, "
|
|
|
|
|
"but got 'group' type : {}.".format(type(group)))
|
|
|
|
|
# 调用_get_rank_helper函数,获取指定group的rank ID
|
|
|
|
|
# _get_group函数用于解析group参数
|
|
|
|
|
# GlobalComm.BACKEND变量存储了当前使用的通信后端
|
|
|
|
|
return _get_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP):
|
|
|
|
|
"""
|
|
|
|
|
Gets local rank ID for current device in specified collective communication group.
|
|
|
|
@ -251,9 +269,11 @@ def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP):
|
|
|
|
|
>>> print("local_rank is: {}, world_rank is {}".format(local_rank, world_rank))
|
|
|
|
|
local_rank is: 1, world_rank is 9
|
|
|
|
|
"""
|
|
|
|
|
# 检查传入的 'group' 参数是否为字符串类型,如果不是则抛出 TypeError
|
|
|
|
|
if not isinstance(group, str):
|
|
|
|
|
raise TypeError("For 'get_local_rank', the argument 'group' must be type of string, "
|
|
|
|
|
"but got 'group' type : {}.".format(type(group)))
|
|
|
|
|
# 调用辅助函数 _get_local_rank_helper 来获取本地排名,传入的参数为解析后的组和全局通信后端
|
|
|
|
|
return _get_local_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -287,9 +307,11 @@ def get_group_size(group=GlobalComm.WORLD_COMM_GROUP):
|
|
|
|
|
>>> print("group_size is: ", group_size)
|
|
|
|
|
group_size is: 8
|
|
|
|
|
"""
|
|
|
|
|
# 检查传入的参数 'group' 是否为字符串类型
|
|
|
|
|
if not isinstance(group, str):
|
|
|
|
|
raise TypeError("For 'get_group_size', the argument 'group' must be type of string, "
|
|
|
|
|
"but got 'group' type : {}.".format(type(group)))
|
|
|
|
|
# 返回指定组的大小,使用辅助函数 _get_size_helper 和全局通信后端
|
|
|
|
|
return _get_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -323,9 +345,11 @@ def get_local_rank_size(group=GlobalComm.WORLD_COMM_GROUP):
|
|
|
|
|
>>> print("local_rank_size is: ", local_rank_size)
|
|
|
|
|
local_rank_size is: 8
|
|
|
|
|
"""
|
|
|
|
|
# 检查传入的 'group' 参数是否为字符串类型,如果不是则抛出 TypeError
|
|
|
|
|
if not isinstance(group, str):
|
|
|
|
|
raise TypeError("For 'get_local_rank_size', the argument 'group' must be type of string, "
|
|
|
|
|
"but got 'group' type : {}.".format(type(group)))
|
|
|
|
|
# 调用辅助函数获取本地组的大小,使用 _get_group 函数获取组,并使用 GlobalComm.BACKEND 作为后端
|
|
|
|
|
return _get_local_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -364,12 +388,12 @@ def get_world_rank_from_group_rank(group, group_rank_id):
|
|
|
|
|
>>> print("world_rank_id is: ", world_rank_id)
|
|
|
|
|
world_rank_id is: 4
|
|
|
|
|
"""
|
|
|
|
|
# 检查传入的 group 参数是否为字符串类型
|
|
|
|
|
if not isinstance(group, str):
|
|
|
|
|
raise TypeError("For 'get_world_rank_from_group_rank', the argument 'group' must be type of string, "
|
|
|
|
|
"but got 'group' type : {}.".format(type(group)))
|
|
|
|
|
# 返回根据组和组排名获取的世界排名的帮助函数的结果
|
|
|
|
|
return _get_world_rank_from_group_rank_helper(group=group, group_rank_id=group_rank_id, backend=GlobalComm.BACKEND)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_group_rank_from_world_rank(world_rank_id, group):
|
|
|
|
|
"""
|
|
|
|
|
Get the rank ID in the specified user communication group corresponding to
|
|
|
|
@ -406,12 +430,13 @@ def get_group_rank_from_world_rank(world_rank_id, group):
|
|
|
|
|
>>> print("group_rank_id is: ", group_rank_id)
|
|
|
|
|
group_rank_id is: 1
|
|
|
|
|
"""
|
|
|
|
|
# 检查输入参数 'group' 是否为字符串类型,如果不是则抛出 TypeError
|
|
|
|
|
if not isinstance(group, str):
|
|
|
|
|
raise TypeError("For 'get_group_rank_from_world_rank', the argument 'group' must be type of string, "
|
|
|
|
|
"but got 'group' type : {}.".format(type(group)))
|
|
|
|
|
# 调用辅助函数 _get_group_rank_from_world_rank_helper 来获取组的排名,并返回结果
|
|
|
|
|
return _get_group_rank_from_world_rank_helper(world_rank_id=world_rank_id, group=group, backend=GlobalComm.BACKEND)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Create a user collective communication group from a group name and a list of rank IDs.
|
|
|
|
|
def create_group(group, rank_ids):
|
|
|
|
|
"""
|
|
|
|
|
Create a user collective communication group.
|
|
|
|
@ -444,9 +469,11 @@ def create_group(group, rank_ids):
|
|
|
|
|
>>> create_group(group, rank_ids)
|
|
|
|
|
>>> allreduce = ops.AllReduce(group)
|
|
|
|
|
"""
|
|
|
|
|
# 检查传入的 group 参数是否为字符串类型,如果不是则抛出 TypeError
|
|
|
|
|
if not isinstance(group, str):
|
|
|
|
|
raise TypeError("For 'create_group', the argument 'group' must be type of string, "
|
|
|
|
|
"but got 'group' type : {}.".format(type(group)))
|
|
|
|
|
# 调用辅助函数 _create_group_helper 来创建组,使用指定的 rank_ids 和后端
|
|
|
|
|
_create_group_helper(group, rank_ids, backend=GlobalComm.BACKEND)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -467,7 +494,9 @@ def destroy_group(group):
|
|
|
|
|
ValueError: If group is "hccl_world_group" or backend is invalid.
|
|
|
|
|
RuntimeError: If HCCL is not available or MindSpore is GPU version.
|
|
|
|
|
"""
|
|
|
|
|
# 检查传入的 group 参数是否为字符串类型,如果不是则抛出 TypeError
|
|
|
|
|
if not isinstance(group, str):
|
|
|
|
|
raise TypeError("For 'destroy_group', the argument 'group' must be type of string, "
|
|
|
|
|
"but got 'group' type : {}.".format(type(group)))
|
|
|
|
|
_destroy_group_helper(group, backend=GlobalComm.BACKEND)
|
|
|
|
|
# 调用辅助函数 _destroy_group_helper 来销毁指定的组,并使用全局通信后端
|
|
|
|
|
_destroy_group_helper(group, backend=GlobalComm.BACKEND)
|