branch-lwh
liuwenhao 2 months ago
commit bdf240c986

@ -0,0 +1,38 @@
import os
import stat
import fcntl
import platform
from logging.handlers import RotatingFileHandler
class _MultiCompatibleRotatingFileHandler(RotatingFileHandler):
"""Inherit RotatingFileHandler for multiprocess compatibility.
这个类继承自`RotatingFileHandler`,是为了在多进程环境下安全地使用日志回滚功能。
在多进程环境下,多个进程可能会同时尝试写入或回滚日志文件,这可能会导致文件损坏或数据丢失。
通过在这个类中对相关方法进行重写,确保了日志文件在多进程环境下的正确处理。
"""
def doRollover(self):
"""Override doRollover for multiprocess compatibility
and setting permission of Log file
这个方法重写了`RotatingFileHandler`中的`doRollover`方法,增加了多进程兼容性,
并设置了日志文件的权限。
1. 使用`fcntl`模块获得独占锁,确保在回滚日志文件时不会有其他进程进行写操作。
2. 设置日志文件的权限,以确保日志文件的安全性。
3. 调用父类的`doRollover`方法执行实际的日志回滚操作。
4. 回滚后,修改日志文件的权限,使其可读可写。
"""
# Attain an exclusive lock with blocking mode by `fcntl` module.
with open(self.baseFilename, 'a') as file_pointer:
# On systems other than Windows, lock the file with the `fcntl` module
if platform.system() != "Windows":
fcntl.lockf(file_pointer.fileno(), fcntl.LOCK_EX)
# Set the log file to read-only to improve security
os.chmod(self.baseFilename, stat.S_IREAD)
# Call the parent class's `doRollover` to perform the actual rollover
super().doRollover()
# Restore read and write permissions so that subsequent log writes can proceed
os.chmod(self.baseFilename, stat.S_IREAD | stat.S_IWRITE)
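A hedged usage sketch (not part of this diff): the handler can be attached to a standard logger; the file name and size limits below are illustrative assumptions.
import logging

logger = logging.getLogger("example")
handler = _MultiCompatibleRotatingFileHandler("app.log", maxBytes=10 * 1024 * 1024, backupCount=5)  # placeholder file name and limits
handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
logger.addHandler(handler)
logger.warning("rollover-safe logging example")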

@ -18,7 +18,7 @@ from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
from mindspore import log as logger
from ._hccl_management import load_lib as hccl_load_lib
from .._c_expression import get_rank_id, get_rank_size
_HCCL_AVAILABLE = False
_HCCL_TEST_AVAILABLE = False
_NCCL_AVAILABLE = False
@ -28,14 +28,15 @@ try:
_NCCL_AVAILABLE = True
except ImportError:
_NCCL_AVAILABLE = False
# Try to load the HCCL library; on success set _HCCL_AVAILABLE to True, otherwise catch RuntimeError and set it to False
try:
hccl_load_lib()
_HCCL_AVAILABLE = True
except RuntimeError:
_HCCL_AVAILABLE = False
# If HCCL is available, import _hccl_management and try to import mindspore._ascend_mpi; on success set _MPI_AVAILABLE to True, otherwise catch ImportError and set it to False
if _HCCL_AVAILABLE:
from . import _hccl_management as hccl
try:
@ -43,6 +44,7 @@ if _HCCL_AVAILABLE:
_MPI_AVAILABLE = True
except ImportError:
_MPI_AVAILABLE = False
# If HCCL is not available, try to import hccl_test.manage.api; on success set _HCCL_AVAILABLE and _HCCL_TEST_AVAILABLE to True, otherwise catch ImportError and set _HCCL_AVAILABLE to False
else:
try:
import hccl_test.manage.api as hccl
@ -50,12 +52,11 @@ else:
_HCCL_TEST_AVAILABLE = True
except ImportError:
_HCCL_AVAILABLE = False
# Communication group name constants for HCCL and NCCL
HCCL_WORLD_COMM_GROUP = "hccl_world_group"
NCCL_WORLD_COMM_GROUP = "nccl_world_group"
class Backend:
"""
Class for available backends.
@ -82,13 +83,17 @@ class Backend:
def __new__(cls, name):
"""Create instance object of Backend."""
# Check that the given name is a string
if not isinstance(name, str):
raise TypeError("For 'Backend', the class variable 'name' must be a string, "
"but got the type : {}".format(type(name)))
# Look up the Backend class attribute for the uppercase name; fall back to Backend.UNDEFINED if it does not exist
value = getattr(Backend, name.upper(), Backend.UNDEFINED)
# Backend.UNDEFINED means the given name is not supported
if value == Backend.UNDEFINED:
raise ValueError("For 'Backend', the class variable 'name' {} is not supported, "
"please use hccl or nccl.".format(name))
# Return the resolved Backend class attribute value
return value
DEFAULT_BACKEND = Backend("hccl")
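A hedged illustration of how Backend.__new__ resolves names, based on the code above (not part of the diff):
backend = Backend("nccl")   # the case-insensitive lookup resolves to Backend.NCCL
try:
    Backend("gloo")         # an unsupported name falls through to Backend.UNDEFINED and raises
except ValueError as err:
    print(err)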
@ -97,7 +102,7 @@ DEFAULT_BACKEND = Backend("hccl")
class GlobalComm:
"""
World communication information. The GlobalComm is a global class. The members contain:
- BACKEND: The communication library used, HCCL or NCCL.
- WORLD_COMM_GROUP: Global communication domain.
"""
@ -105,35 +110,34 @@ class GlobalComm:
WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
INITED = False
CHECK_ENVS = True
class _ExistingGroup:
"""
The communication groups which exist in the process.
"""
ITEMS = {}
def is_hccl_available():
"""
Check HCCL api is available.
Returns:
Boolean. Return whether HCCL is available or not.
"""
return _HCCL_AVAILABLE
def is_mpi_available():
"""
Check HCCL & MPI api is available.
Returns:
Boolean. Return whether HCCL & MPI is available or not.
"""
return _MPI_AVAILABLE
def is_nccl_available():
"""
Check NCCL api is available.
@ -158,19 +162,25 @@ def check_parameter_available(func):
Wrapper. If not available, raise Error.
"""
def wrapper(*args, **kargs):
# If the current role is parameter server or scheduler, call the wrapped function directly
if _is_role_pserver() or _is_role_sched():
return func(*args, **kargs)
# Check that distributed communication has been initialized; raise an exception otherwise
if not GlobalComm.INITED:
raise RuntimeError("Distributed Communication has not been inited")
# Get the group, which defaults to None
group = None
# Check whether the keyword arguments contain "group" and validate its type
if "group" in kargs.keys():
group = kargs.get("group")
if group is not None and not isinstance(group, str):
raise TypeError("The parameter 'group' should be str or None, "
"but got the type : {}".format(type(group)))
# Get the backend from the keyword arguments
if "backend" in kargs.keys():
backend = kargs.get("backend")
# Check whether the backend is available; raise an exception if it is not
if backend is Backend.HCCL and not is_hccl_available():
raise RuntimeError("Distributed Communication doesn't have HCCL built in")
if backend is Backend.HCCL_MPI and not is_mpi_available():
@ -178,15 +188,16 @@ def check_parameter_available(func):
if backend is Backend.NCCL and not is_nccl_available():
raise RuntimeError("Distributed Communication doesn't have NCCL built in")
# If no group is specified, set the default group based on the backend
if group is None:
if backend is Backend.HCCL or backend is Backend.HCCL_MPI:
group = HCCL_WORLD_COMM_GROUP
elif backend is Backend.NCCL:
group = NCCL_WORLD_COMM_GROUP
# Call the wrapped function
return func(*args, **kargs)
return wrapper
@check_parameter_available
def _get_rank_helper(group, backend):
"""
@ -202,10 +213,13 @@ def _get_rank_helper(group, backend):
Returns:
Integer. The local rank id of the calling process.
"""
# Helper that obtains the rank id for the given backend and group.
# If the current role is parameter server or scheduler, the rank id is set to 0 and returned.
rank_id = None
if _is_role_pserver() or _is_role_sched():
rank_id = 0
return rank_id
# Obtain the rank id according to the backend
if backend == Backend.HCCL_MPI:
rank_id = mpi.get_rank_id(group)
elif backend == Backend.HCCL:
@ -216,6 +230,7 @@ def _get_rank_helper(group, backend):
elif backend == Backend.NCCL:
rank_id = get_rank_id(group)
else:
# Raise ValueError if the backend is not supported
raise ValueError("For '_get_rank_helper', the argument 'backend' {} is not supported, "
"please use hccl_mpi, hccl or nccl.".format(backend))
return rank_id
@ -236,6 +251,7 @@ def _get_local_rank_helper(group, backend):
Returns:
Integer. The local rank id of the calling process.
"""
# Get the rank id of the current process, handled according to the backend and group
rank_id = None
if backend == Backend.HCCL_MPI:
rank_id = mpi.get_rank_id(group)
@ -251,7 +267,6 @@ def _get_local_rank_helper(group, backend):
"please use hccl_mpi or hccl.".format(backend))
return rank_id
@check_parameter_available
def _get_size_helper(group, backend):
"""
@ -268,9 +283,11 @@ def _get_size_helper(group, backend):
Integer. The rank size of specified group.
"""
size = None
# If the current role is parameter server or scheduler, set size to 1 and return
if _is_role_pserver() or _is_role_sched():
size = 1
return size
# Set size according to the backend
if backend == Backend.HCCL_MPI:
size = mpi.get_rank_size(group)
elif backend == Backend.HCCL:
@ -302,19 +319,23 @@ def _get_local_size_helper(group, backend):
Integer. The local rank size where the calling process is being within specified group.
"""
size = None
# Get the local rank size of the group according to the backend
if backend == Backend.HCCL:
# If the group is the world communication group, get the local rank size of the world group
if group == HCCL_WORLD_COMM_GROUP:
size = hccl.get_local_rank_size()
# Otherwise get the local rank size of the specified group
else:
size = hccl.get_local_rank_size(group)
# The NCCL backend does not support getting the local rank size; raise an exception
elif backend == Backend.NCCL:
raise RuntimeError("Nccl doesn't support get_local_rank_size now.")
# Raise an exception for unsupported backends
else:
raise ValueError("For '_get_local_size_helper', the argument 'backend' {} is not supported, "
"please use hccl.".format(backend))
return size
@check_parameter_available
def _get_world_rank_from_group_rank_helper(group, group_rank_id, backend):
"""
@ -333,21 +354,26 @@ def _get_world_rank_from_group_rank_helper(group, group_rank_id, backend):
Integer. A rank id in world communication group.
"""
world_rank_id = None
# Check that group_rank_id is an integer; raise TypeError if it is not
if not isinstance(group_rank_id, int):
raise TypeError("For 'get_world_rank_from_group_rank', the argument 'group_rank_id' must be"
" type of int, but got 'group_rank_id' type : {}.".format(type(group_rank_id)))
# Handle the lookup differently depending on the backend
if backend == Backend.HCCL:
# When HCCL is used on GPU and the group argument is HCCL_WORLD_COMM_GROUP, raise ValueError
if group == HCCL_WORLD_COMM_GROUP:
raise ValueError("For 'get_world_rank_from_group_rank' on GPU, the argument 'group' "
"should be 'NCCL_WORLD_COMM_GROUP', but got 'HCCL_WORLD_COMM_GROUP'.")
# Call hccl.get_world_rank_from_group_rank to obtain the world_rank_id
world_rank_id = hccl.get_world_rank_from_group_rank(group, group_rank_id)
elif backend == Backend.NCCL:
# NCCL does not support this operation; raise RuntimeError
raise RuntimeError("Nccl doesn't support get_world_rank_from_group_rank now.")
else:
# Raise ValueError for unsupported backends
raise ValueError("The argument 'backend' {} is not supported, please use hccl.".format(backend))
# Return the obtained world_rank_id
return world_rank_id
@check_parameter_available
def _get_group_rank_from_world_rank_helper(world_rank_id, group, backend):
"""
@ -366,21 +392,27 @@ def _get_group_rank_from_world_rank_helper(world_rank_id, group, backend):
Integer. A rank id in user communication group.
"""
group_rank_id = None
# Check that world_rank_id is an integer; raise TypeError if it is not
if not isinstance(world_rank_id, int):
raise TypeError("For 'get_group_rank_from_world_rank', the argument 'world_rank_id' must be type of int, "
"but got 'world_rank_id' type : {}.".format(type(world_rank_id)))
# Handle the lookup differently depending on the backend
if backend == Backend.HCCL:
# Validate the group argument for the GPU backend; raise ValueError if it is wrong
if group == HCCL_WORLD_COMM_GROUP:
raise ValueError("For 'get_group_rank_from_world_rank' on GPU, the argument 'group' "
"should be 'NCCL_WORLD_COMM_GROUP', but got 'HCCL_WORLD_COMM_GROUP'.")
# Call the hccl module to obtain the group_rank_id
group_rank_id = hccl.get_group_rank_from_world_rank(world_rank_id, group)
# The NCCL backend does not support this operation; raise RuntimeError
elif backend == Backend.NCCL:
raise RuntimeError("Nccl doesn't support get_group_rank_from_world_rank now.")
# Raise ValueError if the backend is neither HCCL nor NCCL
else:
raise ValueError("The argument 'backend' {} is not supported, please use hccl.".format(backend))
# Return the obtained group_rank_id
return group_rank_id
@check_parameter_available
def _create_group_helper(group, rank_ids, backend):
"""
@ -395,34 +427,46 @@ def _create_group_helper(group, rank_ids, backend):
TypeError: If rank_ids is not a list.
ValueError: If rank_ids size is not larger than 1 or rank_ids has duplicate data or backend is invalid.
"""
# Check whether the group already exists
if group in _ExistingGroup.ITEMS.keys():
# If the group exists and the provided rank_ids differ from the stored ones, raise an exception
if rank_ids != _ExistingGroup.ITEMS[group]:
raise ValueError("The group {} has been created, the rank_list is {}, "
"but current rank_list for the group is {}".
format(group, _ExistingGroup.ITEMS[group], rank_ids))
# Log a warning that the group already exists
logger.warning("%r group has existed.", group)
return
# Create the group according to the backend
if backend == Backend.HCCL:
# Check that rank_ids is a list
if not isinstance(rank_ids, list):
raise TypeError("For 'create_group', the argument 'rank_ids' must be type of list, "
"but got 'rank_ids' type : {}.".format(type(rank_ids)))
# Check that the length of rank_ids is greater than 1
rank_size = len(rank_ids)
if rank_size < 1:
raise ValueError("For 'create_group', the argument 'rank_ids' size should be greater than 1, "
"but got 'rank_ids' size : {}.".format(len(rank_ids)))
# Check whether rank_ids contains duplicate elements
if len(rank_ids) - len(list(set(rank_ids))) > 0:
raise ValueError("List rank_ids in Group {} has duplicate data!".format(group))
# Create the group with HCCL
hccl.create_group(group, rank_size, rank_ids)
elif backend == Backend.HCCL_MPI:
# Create the group with HCCL_MPI
mpi.create_group(group, rank_ids)
elif backend == Backend.NCCL:
# NCCL does not support creating groups yet; raise an exception
raise RuntimeError("Nccl doesn't support create_group now.")
else:
# Raise an exception for unsupported backends
raise ValueError("The context configuration parameter 'backend' {} is not supported, "
"please use hccl.".format(backend))
# Record the newly created group and its rank_ids in _ExistingGroup
_ExistingGroup.ITEMS[group] = rank_ids
@check_parameter_available
def _destroy_group_helper(group, backend):
"""
@ -435,12 +479,17 @@ def _destroy_group_helper(group, backend):
Raises:
ValueError: If group is "hccl_world_group" or backend is invalid.
"""
# Destroy the communication group according to the backend
if backend == Backend.HCCL:
# Check whether the group is the HCCL world communication group
if group == HCCL_WORLD_COMM_GROUP:
raise ValueError("The hccl_world_group does not support destruction.")
# Destroy the specified HCCL communication group
hccl.destroy_group(group)
elif backend == Backend.NCCL:
# The NCCL backend does not support destroying communication groups yet
raise RuntimeError("Nccl doesn't support destroy_group now.")
else:
# Raise an error for unsupported backend types
raise ValueError("The context configuration parameter 'backend' {} is not supported, "
"please use hccl.".format(backend))

@ -24,7 +24,7 @@ MAX_RANK_NUM = 4096
HCCL_LIB = 'libhccl_plugin.so'
HCCL_LIB_CTYPES = ""
# Check whether the name of a collective communication group is legal
def check_group(group):
"""
A function that checks whether a collective communication group is legal.
@ -41,7 +41,7 @@ def check_group(group):
raise TypeError("The type of communication group name must be type of string, "
"but got 'group' type : {}.".format(type(group)))
# Check whether the rank number in collective communication is legal
def check_rank_num(rank_num):
"""
A function that checks whether a collective communication rank number is legal. If not, raise an error.
@ -57,7 +57,7 @@ def check_rank_num(rank_num):
raise TypeError("The argument 'rank_num' must be type of int, "
"but got 'rank_num' type : {}.".format(type(rank_num)))
# Check whether the rank id in collective communication is legal
def check_rank_id(rank_id):
"""
A function that checks whether a collective communication rank id is legal. If not, raise an error.
@ -65,15 +65,18 @@ def check_rank_id(rank_id):
Returns:
None
"""
# Check that rank_id is an integer
if isinstance(rank_id, (int)):
# Check that rank_id is within the valid range
if rank_id >= MAX_RANK_NUM or rank_id < 0:
raise ValueError("The rank id in the communication group must be greater or equal 0 and "
"less than {}, but got type value : {}.".format(MAX_RANK_NUM, rank_id))
else:
# Raise a type error if rank_id is not an integer
raise TypeError("The rank id in the communication group must be type of int, "
"but got type value : {}.".format(type(rank_id)))
# Load the HCCL (Huawei Collective Communication Library) library
def load_lib():
"""load hccl lib"""
try:
@ -82,11 +85,10 @@ def load_lib():
hccl_lib = ctypes.CDLL(lib_path)
except Exception:
raise RuntimeError('Get hccl lib error.')
global HCCL_LIB_CTYPES
HCCL_LIB_CTYPES = hccl_lib
def c_str(string):
"""Convert a python string to C string."""
if not isinstance(string, str):
@ -98,7 +100,7 @@ def c_array(ctype, values):
"""Create ctypes array from a python array."""
return (ctype * len(values))(*values)
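A hedged sketch (the names and values are illustrative, not from the patch) of how c_str and c_array build the ctypes arguments that the HCCL calls below expect:
import ctypes

rank_ids = [0, 1, 2, 3]                               # illustrative rank list
c_array_rank_ids = c_array(ctypes.c_uint, rank_ids)   # ctypes array of unsigned ints
c_rank_num = ctypes.c_uint(len(rank_ids))             # rank count as a C unsigned int
c_group = c_str("example_group")                      # Python str converted to a C string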
# Create an HCCL communication group with the given rank number and rank ids; the world group cannot be created this way.
def create_group(group, rank_num, rank_ids):
"""
Create group.
@ -112,28 +114,38 @@ def create_group(group, rank_num, rank_ids):
Returns:
None
"""
# Validate the group
check_group(group)
# Validate the rank number
check_rank_num(rank_num)
# Check that rank_ids is a list
if isinstance(rank_ids, (list)):
# Ensure rank_num matches the length of rank_ids
if rank_num != len(rank_ids):
raise ValueError("The argument 'rank_num' number should be equal to the length "
"of rank_ids, but got 'rank_num' value : {} and 'rank_ids' value : {}."
.format(rank_num, rank_ids))
# Check that every element of rank_ids is a non-negative integer
for rank_id in rank_ids:
if not isinstance(rank_id, (int)) or rank_id < 0:
raise ValueError("The elements of argument 'rank_ids' must be "
"unsigned integer, but got the type : {}".format(type(rank_id)))
# Convert rank_ids to a C array
c_array_rank_ids = c_array(ctypes.c_uint, rank_ids)
# Convert rank_num to a C unsigned integer
c_rank_num = ctypes.c_uint(rank_num)
# Convert group to a C string
c_group = c_str(group)
# Call the HCCL library to create the group
ret = HCCL_LIB_CTYPES.HcomCreateGroup(c_group, c_rank_num, c_array_rank_ids)
# Check whether the group was created successfully
if ret != 0:
raise RuntimeError('Create group error, the error code is {}.'.format(ret))
else:
# Raise a type error if rank_ids is not a list
raise TypeError("For 'create_group', the argument 'rank_ids' must be type of list, "
"but got 'rank_ids' type : {}.".format(type(rank_ids)))
# Destroy an HCCL group created by the user
def destroy_group(group):
"""
A function that destroys a group created by the user.
@ -162,16 +174,23 @@ def get_rank_size(group="hccl_world_group"):
An integer scalar with the num of ranks.
"""
# In PYNATIVE_MODE, return the HCCL rank size directly
if context.get_context("mode") == context.PYNATIVE_MODE:
return get_hccl_rank_size()
# Check that the given group is valid
check_group(group)
# Convert the group name to a C string
c_group = c_str(group)
# A C unsigned integer that receives the rank size
c_rank_size = ctypes.c_uint()
# Call HcomGetRankSize in the HCCL library to get the rank size of the group
ret = HCCL_LIB_CTYPES.HcomGetRankSize(c_group, ctypes.byref(c_rank_size))
# A non-zero return value means getting the rank size failed; raise a runtime error
if ret != 0:
raise RuntimeError('Get rank size error.')
# Return the obtained rank size
return c_rank_size.value
@ -186,17 +205,22 @@ def get_rank_id(group="hccl_world_group"):
if context.get_context("mode") == context.PYNATIVE_MODE:
return get_hccl_rank_id()
# Check that the group is valid
check_group(group)
# Convert the group to a C string
c_group = c_str(group)
# A ctypes unsigned integer that receives the rank id
c_rank_id = ctypes.c_uint()
# Call the HCCL library to get the rank id of the calling process
ret = HCCL_LIB_CTYPES.HcomGetRankId(c_group, ctypes.byref(c_rank_id))
# A non-zero return value means getting the rank id failed; raise a RuntimeError
if ret != 0:
raise RuntimeError('Get rank id error.')
# Return the obtained rank id
return c_rank_id.value
def get_local_rank_size(group="hccl_world_group"):
"""
A function that returns the number of local ranks within the given collection communication group.
@ -207,19 +231,25 @@ def get_local_rank_size(group="hccl_world_group"):
Returns:
An integer scalar with the num of local ranks.
"""
# Raise an exception when running in PYNATIVE_MODE, which is not supported
if context.get_context("mode") is context.PYNATIVE_MODE:
raise RuntimeError("The function 'get_local_rank_size' is not supported in PYNATIVE_MODE, "
"'get_local_rank_size' only support GRAPH_MODE")
# Check that the given group is valid
check_group(group)
# Convert the group name to a C string
c_group = c_str(group)
# A ctypes unsigned integer that receives the local rank size
c_local_rank_size = ctypes.c_uint()
# Call HcomGetLocalRankSize in the HCCL library to get the local rank size
ret = HCCL_LIB_CTYPES.HcomGetLocalRankSize(c_group, ctypes.byref(c_local_rank_size))
# A non-zero return value means getting the local rank size failed; raise an exception
if ret != 0:
raise RuntimeError('Get local rank size error.')
# Return the obtained local rank size
return c_local_rank_size.value
def get_local_rank_id(group="hccl_world_group"):
"""
Get local rank id.
@ -230,16 +260,23 @@ def get_local_rank_id(group="hccl_world_group"):
An integer scalar with the local rank id of the calling process.
"""
# Raise an exception when running in PYNATIVE_MODE, which is not supported
if context.get_context("mode") is context.PYNATIVE_MODE:
raise RuntimeError("The function 'get_local_rank_id' is not supported in PYNATIVE_MODE, "
"'get_local_rank_id' only support GRAPH_MODE")
# Check that the group is valid
check_group(group)
# Convert the group name to a C string
c_group = c_str(group)
# A ctypes unsigned integer that receives the local rank id
c_local_rank_id = ctypes.c_uint()
# Call HcomGetLocalRankId in the HCCL library to get the local rank id
ret = HCCL_LIB_CTYPES.HcomGetLocalRankId(c_group, ctypes.byref(c_local_rank_id))
# A non-zero return value means getting the local rank id failed; raise an exception
if ret != 0:
raise RuntimeError('Get local rank id error.')
# Return the obtained local rank id
return c_local_rank_id.value
@ -256,18 +293,25 @@ def get_world_rank_from_group_rank(group, group_rank_id):
if context.get_context("mode") is context.PYNATIVE_MODE:
raise RuntimeError("The function 'get_world_rank_from_group_rank' is not supported in PYNATIVE_MODE, "
"'get_world_rank_from_group_rank' only support GRAPH_MODE")
# Check that the group name is valid
check_group(group)
# Check that the rank id within the group is valid
check_rank_id(group_rank_id)
# Convert the group name to a C string
c_group = c_str(group)
# Convert the group rank id to a C unsigned integer
c_group_rank_id = ctypes.c_uint(group_rank_id)
# A C unsigned integer that receives the world rank id
c_world_rank_id = ctypes.c_uint()
# Call HcomGetWorldRankFromGroupRank in the HCCL library to map the group rank id to the world rank id
ret = HCCL_LIB_CTYPES.HcomGetWorldRankFromGroupRank(c_group, c_group_rank_id, ctypes.byref(c_world_rank_id))
# A non-zero return value means the call failed; raise a RuntimeError
if ret != 0:
raise RuntimeError('Get world rank from group rank error.')
# Return the obtained world rank id
return c_world_rank_id.value
def get_group_rank_from_world_rank(world_rank_id, group):
"""
Get group rank from world rank.
@ -281,13 +325,21 @@ def get_group_rank_from_world_rank(world_rank_id, group):
if context.get_context("mode") is context.PYNATIVE_MODE:
raise RuntimeError("The function 'get_group_rank_from_world_rank' is not supported in PYNATIVE_MODE, "
"'get_group_rank_from_world_rank' only support GRAPH_MODE")
# Check that the group is valid
check_group(group)
# Check that the world rank id is valid
check_rank_id(world_rank_id)
# Convert the group to a C string
c_group = c_str(group)
# Convert the world rank id to a C unsigned integer
c_world_rank_id = ctypes.c_uint(world_rank_id)
# A C unsigned integer that receives the group rank id
c_group_rank_id = ctypes.c_uint()
# Map the world rank id to the group rank id
ret = HCCL_LIB_CTYPES.HcomGetGroupRankFromWorldRank(c_world_rank_id, c_group, ctypes.byref(c_group_rank_id))
# A non-zero return value means the call failed; raise a runtime error
if ret != 0:
raise RuntimeError('Get group rank from world rank error.')
# Return the obtained group rank id
return c_group_rank_id.value

@ -63,25 +63,31 @@ def _check_parallel_envs():
Raises:
RuntimeError: If parallel environment variables have not been exported or have been exported to wrong values.
"""
# Return immediately if environment validation is disabled
if not GlobalComm.CHECK_ENVS:
return
import os
# Get the value of the RANK_ID environment variable
rank_id_str = os.getenv("RANK_ID")
# Raise a runtime error if RANK_ID is not set
if not rank_id_str:
raise RuntimeError("Environment variables RANK_ID has not been exported, please export variables 'RANK_ID'.")
try:
# Try to convert RANK_ID to an integer
int(rank_id_str)
except ValueError:
# Print an error message if the conversion fails
print("Environment variables 'RANK_ID' should be number, but got the type : {}".format(type(rank_id_str)))
finally:
# No-op in all cases
pass
# Get the MINDSPORE_HCCL_CONFIG_PATH and RANK_TABLE_FILE environment variables
rank_table_file_str = os.getenv("MINDSPORE_HCCL_CONFIG_PATH")
rank_table_file_str_old = os.getenv("RANK_TABLE_FILE")
# Raise a runtime error if neither environment variable is set
if not rank_table_file_str and not rank_table_file_str_old:
raise RuntimeError("Get hccl rank_table_file failed, "
"please export MINDSPORE_HCCL_CONFIG_PATH or RANK_TABLE_FILE.")
def init(backend_name=None):
"""
Initialize distributed backend, e.g. HCCL/NCCL, it is required before using the communication service.
@ -109,15 +115,19 @@ def init(backend_name=None):
>>> from mindspore.communication import init
>>> init()
"""
# Return immediately if the current role is parameter server or scheduler
if _is_role_pserver() or _is_role_sched():
return
# Check the task sink environment variables and get the device target and mode
task_sink = _check_task_sink_envs()
device_target = context.get_context("device_target")
mode = context.get_context("mode")
mpi_init = False
# If task sink is absent and the mode is graph mode, set mpi_init to True
if not task_sink and mode == context.GRAPH_MODE:
mpi_init = True
# Choose the backend name based on the device target; raise an exception if it is not supported
if backend_name is None:
if device_target == "Ascend":
backend_name = "hccl"
@ -126,28 +136,35 @@ def init(backend_name=None):
else:
raise RuntimeError("For 'set_context', the argument 'device_target' {} is not supported in "
"parallel initialization, please use Ascend or GPU.".format(device_target))
# Check that the backend name is a string; raise an exception if it is not
if not isinstance(backend_name, str):
raise TypeError("For 'init', the argument 'backend_name' must be a string, "
"but got the type : {}".format(type(backend_name)))
# Initialize the communication environment according to the backend name
if backend_name == "hccl":
# Raise an exception if the device target is not Ascend
if device_target != "Ascend":
raise RuntimeError("For 'init', the argument 'backend_name' should be 'Ascend' to init hccl, "
"but got {}".format(device_target))
# If MPI initialization is not needed, check the parallel environment and set the backend name
if not mpi_init:
_check_parallel_envs()
GlobalComm.BACKEND = Backend("hccl")
else:
GlobalComm.BACKEND = Backend("hccl_mpi")
# Initialize HCCL and set the world communication group and the initialization flag
init_hccl()
GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
GlobalComm.INITED = True
elif backend_name == "nccl":
# Initialize GPU collective communication and set the backend name, world communication group and initialization flag
init_gpu_collective()
GlobalComm.BACKEND = Backend("nccl")
GlobalComm.WORLD_COMM_GROUP = NCCL_WORLD_COMM_GROUP
GlobalComm.INITED = True
else:
# Raise an exception for unsupported backend names
raise RuntimeError("For 'init', the argument 'backend_name' must be nccl while 'device_target' is GPU, "
"but got the 'backend_name' : hccl.")

@ -45,14 +45,15 @@ class OneOf(OneOf_):
TypeError: raise type error for invalid inputs.
"""
self.patterns = patterns
# Check whether patterns is an instance of the Pattern class
if isinstance(patterns, Pattern):
OneOf_.__init__(self, [patterns])
# Check whether patterns is a tuple or list whose elements are all Pattern instances
elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns):
OneOf_.__init__(self, patterns)
# Raise a TypeError if patterns matches neither case
else:
raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}")
class Prim(Prim_):
r"""
Express a pattern of certain primitive type(s).
@ -76,25 +77,33 @@ class Prim(Prim_):
Raises:
TypeError: raise type error for invalid argument.
"""
# Check that name is a string; raise TypeError otherwise
if name is not None and not isinstance(name, str):
raise TypeError(f"Expect string, got : {name}")
self.name = name
# If types is a string, split it on '|' into a list
if isinstance(types, str):
if self.name is None:
self.name = types
self.types = types.split('|')
# If types is a Primitive, wrap it in a list
elif isinstance(types, Primitive):
if self.name is None:
self.name = types.name
self.types = [types]
# If types is a tuple or list whose elements are all Primitive instances
elif isinstance(types, (tuple, list)) and all(isinstance(tp, Primitive) for tp in types):
# If self.name is None, initialize it to an empty string and concatenate the names of all Primitives
if self.name is None:
self.name = ""
for prim in types:
self.name += prim.name
# Set self.types to the given types
self.types = types
# Raise a TypeError if types does not match any expected type
else:
raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}")
# Call the base class Prim_ initializer with self.types and self.name
Prim_.__init__(self, self.types, self.name)
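A hedged illustration of the constructor logic above; the primitive type names are placeholders, not from the diff:
p1 = Prim("Conv2D|MatMul")          # a '|'-separated type string is split into a list of type names
p2 = Prim("ReLU", name="my_relu")   # a single type string with an explicit pattern name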
@ -115,16 +124,22 @@ class Call(Call_):
Raises:
TypeError: raise type error for invalid argument.
"""
# Check that prim_pattern is a Pattern, str or Primitive; raise TypeError otherwise
if not isinstance(prim_pattern, (Pattern, str, Primitive)):
raise TypeError(f"Expect prim_pattern to be Pattern, Primitive or string, got : {prim_pattern}")
self.prim_pattern = prim_pattern
# Initialize the inputs list
self.inputs = []
# If inputs is None, nothing needs to be done
if inputs is None:
pass
# If inputs is a tuple or list whose elements are all Pattern instances, assign it to self.inputs
elif isinstance(inputs, (tuple, list)) and all(isinstance(input, Pattern) for input in inputs):
self.inputs = inputs
# Raise a TypeError if inputs matches neither case
else:
raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}")
# Call the parent class Call_ initializer with self.prim_pattern and self.inputs
Call_.__init__(self, self.prim_pattern, self.inputs)
@ -145,6 +160,7 @@ class NoneOf(NoneOf_):
TypeError: raise type error for invalid argument.
"""
self.patterns = patterns
# Initialize the NoneOf_ base class according to the type of patterns
if patterns is None:
NoneOf_.__init__(self, ())
elif isinstance(patterns, Pattern):
@ -154,7 +170,6 @@ class NoneOf(NoneOf_):
else:
raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}")
class NewTensor(NewTensor_):
r"""
New Tensor to be used in the target.
@ -167,13 +182,16 @@ class NewTensor(NewTensor_):
Raises:
TypeError: raise type error for invalid argument.
"""
# Keep the input tensor
self.input_tensor = input_tensor
# Check whether the input is a Tensor
if isinstance(input_tensor, Tensor):
# If it is a Tensor, call the NewTensor_ initializer
NewTensor_.__init__(self, input_tensor)
else:
# Raise a TypeError if it is not a Tensor
raise TypeError(f"Expect input_tensor to be a Tensor got : {input_tensor}")
class NewParameter(NewParameter_):
r"""
New Parameter to be used in the target.
@ -193,11 +211,14 @@ class NewParameter(NewParameter_):
self.default_tensor = default_tensor
self.requires_grad = requires_grad
self.layerwise_parallel = layerwise_parallel
# Check that the argument types are as expected
if isinstance(para_name, str) and isinstance(default_tensor, Tensor) and isinstance(requires_grad, bool) and\
isinstance(layerwise_parallel, bool):
# Initialize the NewParameter_ base class
NewParameter_.__init__(self, self.para_name, self.default_tensor, self.requires_grad,
self.layerwise_parallel)
else:
# Raise a TypeError if the argument types are not as expected
raise TypeError(f"Expect para_name(str), default_tensor(Tensor), requires_grad(bool), \
layerwise_parallel(bool) got : {para_name}, {default_tensor}, \
{requires_grad}, {layerwise_parallel}")

@ -39,6 +39,7 @@ class PyPassManager(PyPassManager_):
TypeError: If argument has invalid type.
"""
def __init__(self, requires_grad=True, run_only_once=False):
# Initializer that takes two boolean arguments, sets the instance attributes and calls the parent initializer
if not isinstance(requires_grad, bool):
raise TypeError(f"Expect bool, got : ({type(requires_grad)}){requires_grad}")
if not isinstance(run_only_once, bool):
@ -48,17 +49,20 @@ class PyPassManager(PyPassManager_):
PyPassManager_.__init__(self)
def register(self, py_pass):
# Register a Python pass: check that it is a function and obtain its pattern and target
if not isfunction(py_pass):
raise TypeError(f"Expect function pass, got : ({type(py_pass)}){py_pass}")
pattern, target = py_pass()
pass_name = py_pass.__name__
# Check that the pattern and target are Pattern instances
if not isinstance(pattern, Pattern):
raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}")
if not isinstance(target, Pattern):
raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}")
# Call the parent class's register method to register the pass and its related information
super().register(pass_name, pattern, target, self.requires_grad, self.run_only_once_)
def unregister(self, py_pass):
# Remove the given Python pass, identified by its name as a string or by the function object, from the registry
if isinstance(py_pass, str):
super().unregister(py_pass)
return
@ -66,27 +70,30 @@ class PyPassManager(PyPassManager_):
super().unregister(py_pass.__name__)
return
raise TypeError(f"Expect py_pass to be string or function, got ({type(py_pass)}){py_pass}")
def __call__(self, py_pass):
# Register the Python pass in the registry and return it
self.register(py_pass)
return py_pass
def gen_new_parameter(self, pattern):
# Generate a new parameter from the given pattern, which must be a NewParameter
if not isinstance(pattern, NewParameter):
raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}")
super().gen_new_parameter(pattern)
def set_renorm(self, should_renorm):
# Set whether renormalization should be performed; the argument must be a bool
if not isinstance(should_renorm, bool):
raise TypeError(f"Expect should_renorm to be a bool, got {should_renorm}")
super().set_renorm(should_renorm)
def set_reopt(self, do_reopt):
# Set whether re-optimization should be performed; the argument must be a bool
if not isinstance(do_reopt, bool):
raise TypeError(f"Expect do_reopt to be a bool, got {do_reopt}")
super().set_reopt(do_reopt)
def register_pass(requires_grad=True, run_only_once=False):
"""
Register python pass to specified pipeline phase which would be used in compilation.
@ -165,12 +172,13 @@ def cancel_new_parameter(pattern):
>>> # some compilations
>>> cancel_new_parameter(abc)
"""
# Check that the given pattern is an instance of NewParameter
if not isinstance(pattern, NewParameter):
raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}")
# Create a PyPassManager object
ppm = PyPassManager()
# Unregister the parameter with the given name from the PyPassManager
ppm.unregister(pattern.para_name)
def set_renorm(should_renorm):
"""
Set whether or not to do renormalization after modified graph in python pass(es).

@ -19,22 +19,25 @@ import sys
def parse_args():
"""
parse args .
Args:
Returns:
args.
Examples:
>>> parse_args()
"""
# Create an ArgumentParser described as "MindSpore dependency packages version checker."
parser = ArgumentParser(description="MindSpore dependency packages version checker.")
# Add the --mindspore_version string argument with the help text "MindSpore version."
parser.add_argument("--mindspore_version", type=str, help="MindSpore version.")
# Add the --supported_version string argument, which may be given multiple times, with the help text "Supported environment version."
parser.add_argument("--supported_version", type=str, action='append', help="Supported environment version.")
# Parse the command line arguments and return the result
args = parser.parse_args()
return args
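A hedged example invocation of the parser (the version numbers are placeholders):
import sys

sys.argv = ["_check_deps_version.py", "--mindspore_version=1.6.0",
            "--supported_version=1.81", "--supported_version=1.82"]
args = parse_args()
print(args.mindspore_version, args.supported_version)  # 1.6.0 ['1.81', '1.82']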
def check_deps_version(mindspore_version, supported_version):
"""
check te/hccl/topi version
@ -46,6 +49,7 @@ def check_deps_version(mindspore_version, supported_version):
Returns:
void
"""
# Try to import the hccl, te and topi packages and check whether their versions match the supported versions
try:
from hccl import sys_version as hccl_version
v = '.'.join(hccl_version.__sys_version__.split('.')[0:2])
@ -63,18 +67,20 @@ def check_deps_version(mindspore_version, supported_version):
print(f"MindSpore version {mindspore_version} and \"topi\" wheel package version {v} does not "
"match, reference to the match info on: https://www.mindspore.cn/install")
# Catch import errors and print the corresponding check failure message
except ImportError as e:
print("CheckFailed: ", e.args)
print("MindSpore relies on the 3 whl packages of \"te\", \"topi\" and \"hccl\" in the \"fwkacllib\" "
"folder of the Ascend AI software package (Ascend Data Center Solution), please check whether they are "
"installed correctly or not, reference to the match info on: https://www.mindspore.cn/install")
def main():
# Parse the command line arguments
args = parse_args()
# Check whether the mindspore version matches the supported versions
check_deps_version(args.mindspore_version, args.supported_version)
if __name__ == "__main__":
# Avoid the impact of the relative path env; this only affects the current process
sys.path = sys.path[1:]
main()

@ -31,47 +31,58 @@ class EnvChecker(metaclass=ABCMeta):
@abstractmethod
def check_env(self, e):
"""Check whether the environment meets the requirements."""
pass
@abstractmethod
def set_env(self):
"""Set up the environment."""
pass
@abstractmethod
def check_version(self):
"""Check whether the version meets the requirements."""
pass
class GPUEnvChecker(EnvChecker):
"""GPU environment check."""
def __init__(self):
# Initialize the list of supported CUDA versions
self.version = ["10.1", "11.1"]
# Map from library key to library name
self.lib_key_to_lib_name = {'libcu': 'libcuda.so'}
# env
# Get the value of the PATH environment variable
self.path = os.getenv("PATH")
# Get the value of the LD_LIBRARY_PATH environment variable
self.ld_lib_path = os.getenv("LD_LIBRARY_PATH")
# check
# Initialize the version to "0"
self.v = "0"
# Get the CUDA library paths
self.cuda_lib_path = self._get_lib_path("libcu")
# Get the CUDA binary paths
self.cuda_bin_path = self._get_bin_path("cuda")
# Get the cuDNN library paths
self.cudnn_lib_path = self._get_lib_path("libcudnn")
def check_env(self, e):
# Re-raise the given exception e
raise e
def set_env(self):
# Set up the environment; the current implementation does nothing
return
def _get_bin_path(self, bin_name):
"""Get bin path by bin name."""
# 如果二进制名称为 "cuda",则调用获取 CUDA 二进制路径的方法
if bin_name == "cuda":
return self._get_cuda_bin_path()
# 否则返回空列表
return []
def _get_cuda_bin_path(self):
"""Get cuda bin path by lib path."""
path_list = []
for path in self.cuda_lib_path:
path = os.path.abspath(path.strip()+"/bin/")
@ -81,56 +92,87 @@ class GPUEnvChecker(EnvChecker):
def _get_nvcc_version(self, is_set_env):
"""Get cuda version by nvcc command."""
# Run the nvcc command to obtain the CUDA version information
nvcc_result = subprocess.run(["nvcc", "--version | grep release"],
timeout=3, text=True, capture_output=True, check=False)
# A non-zero return code means the command failed
if nvcc_result.returncode:
# If the environment variable has not been set yet
if not is_set_env:
# Walk through the candidate CUDA binary paths
for path in self.cuda_bin_path:
# Check whether an nvcc file exists in the path
if Path(path + "/nvcc").is_file():
# Prepend the path to the PATH environment variable
os.environ['PATH'] = path + ":" + os.environ['PATH']
# Retry obtaining the version information recursively
return self._get_nvcc_version(True)
# Return an empty string if the command failed and nvcc was not found
return ""
# Get the command output
result = nvcc_result.stdout
# Walk through every line of the output
for line in result.split('\n'):
if line:
# Extract and return the CUDA version number
return line.strip().split("release")[1].split(",")[0].strip()
# Return an empty string if no version information was found
return ""
def _get_cudnn_version(self):
"""Get cudnn version by libcudnn.so."""
# Initialize the cuDNN version list as empty
cudnn_version = []
# Walk through the cuDNN library paths
for path in self.cudnn_lib_path:
# Find all libcudnn.so files under the path
real_path = glob.glob(path + "/lib*/libcudnn.so.*.*")
# If no matching file is found, continue with the next path
if real_path == []:
continue
# Use the ls command to get the file information
ls_cudnn = subprocess.run(["ls", real_path[0]], timeout=10, text=True,
capture_output=True, check=False)
# If the ls command succeeded, parse its output to get the version number
if ls_cudnn.returncode == 0:
cudnn_version = ls_cudnn.stdout.split('/')[-1].strip('libcudnn.so.').strip().split('.')
# If the version has only two parts, append '0' as the third part
if len(cudnn_version) == 2:
cudnn_version.append('0')
# Stop searching once a version has been found
break
# Join the version parts into a string
version_str = ''.join([n for n in cudnn_version])
# Return the first three characters of the version string
return version_str[0:3]
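A hedged standalone sketch of the version-string handling above, showing how a two-part cuDNN version is padded and packed into the integer compared against the 760/800 thresholds in check_version:
def cudnn_version_number(parts):
    """Pad a ['major', 'minor'] version to three parts and pack it into an int such as 765 or 800."""
    parts = list(parts)
    if len(parts) == 2:
        parts.append('0')
    return int(''.join(parts)[0:3])

assert cudnn_version_number(['7', '6', '5']) == 765   # cuDNN 7.6.5
assert cudnn_version_number(['8', '0']) == 800        # cuDNN 8.0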
def _get_cudart_version(self):
"""Get cuda runtime version by libcudart.so."""
# Walk through the candidate CUDA library paths
for path in self.cuda_lib_path:
# Find all candidate libcudart.so files under the path
real_path = glob.glob(path + "/lib*/libcudart.so.*.*.*")
# If no matching file is found, skip this path
if real_path == []:
continue
# Get the file name information to determine the CUDA version
ls_cudart = subprocess.run(["ls", real_path[0]], timeout=10, text=True,
capture_output=True, check=False)
# If the command succeeded, parse its output to extract the version number
if ls_cudart.returncode == 0:
self.v = ls_cudart.stdout.split('/')[-1].strip('libcudart.so.').strip()
# Stop searching once a version has been found
break
# Return the CUDA version that was found
return self.v
def check_version(self):
"""Check cuda version."""
version_match = False
# Call the private check and set the version_match flag accordingly
if self._check_version():
version_match = True
# If the versions do not match, emit different warnings depending on the CUDA version
if not version_match:
if self.v == "0":
logger.warning("Can not found cuda libs, please confirm that the correct "
@ -140,17 +182,20 @@ class GPUEnvChecker(EnvChecker):
logger.warning(f"MindSpore version {__version__} and cuda version {self.v} does not match, "
"please refer to the installation guide for version matching "
"information: https://www.mindspore.cn/install")
# Get the nvcc version and check whether it matches the versions supported by MindSpore
nvcc_version = self._get_nvcc_version(False)
if nvcc_version and (nvcc_version not in self.version):
logger.warning(f"MindSpore version {__version__} and nvcc(cuda bin) version {nvcc_version} "
"does not match, please refer to the installation guide for version matching "
"information: https://www.mindspore.cn/install")
# Get the cuDNN version and check whether it meets the minimum requirement
cudnn_version = self._get_cudnn_version()
if cudnn_version and int(cudnn_version) < 760:
logger.warning(f"MindSpore version {__version__} and cudDNN version {cudnn_version} "
"does not match, please refer to the installation guide for version matching "
"information: https://www.mindspore.cn/install. The recommended version is "
"CUDA10.1 with cuDNN7.6.x and CUDA11.1 with cuDNN8.0.x")
# Check the compatibility of the cuDNN and CUDA versions; CUDA versions above 10 require cuDNN 8.0 or higher
if cudnn_version and int(cudnn_version) < 800 and int(str(self.v).split('.')[0]) > 10:
logger.warning(f"CUDA version {self.v} and cuDNN version {cudnn_version} "
"does not match, please refer to the installation guide for version matching "
@ -159,45 +204,58 @@ class GPUEnvChecker(EnvChecker):
def _check_version(self):
"""Check cuda version"""
# Get the CUDA runtime version
v = self._get_cudart_version()
# Parse the version string into a version object
v = version.parse(v)
# Build a "major.minor" version string
v_str = str(v.major) + "." + str(v.minor)
# Check whether the version string is in the list of supported versions
if v_str not in self.version:
return False
# The versions match, return True
return True
def _get_lib_path(self, lib_name):
"""Get gpu lib path by ldd command."""
path_list = []
current_path = os.path.split(os.path.realpath(__file__))[0]
mindspore_path = os.path.join(current_path, "../")
"""通过ldd命令获取gpu库路径。"""
path_list = [] # 初始化一个空列表用于存储路径
current_path = os.path.split(os.path.realpath(__file__))[0] # 获取当前文件的绝对路径并分割以获取目录部分
mindspore_path = os.path.join(current_path, "../") # 构建mindspore路径通常是当前文件的上一级目录
try:
# 使用glob模块查找mindspore_path目录下所有以_c_expression.so开头的文件路径
real_path = glob.glob(mindspore_path + "/_c_expression*.so*")
if real_path == []:
logger.error(f"{self.lib_key_to_lib_name[lib_name]} (need by mindspore-gpu) is not found, please "
f"confirm that _c_expression.so is in directory:{mindspore_path} and the correct cuda "
"version has been installed, you can refer to the installation "
"guidelines: https://www.mindspore.cn/install")
return path_list
if real_path == []: # 如果没有找到任何文件
# 记录错误日志提示用户确认_c_expression.so文件是否存在以及是否安装了正确的cuda版本
logger.error(f"{self.lib_key_to_lib_name[lib_name]} (mindspore-gpu所需的库) 未找到,请确认 "
f"_c_expression.so是否位于目录:{mindspore_path}并且已安装正确的cuda版本"
"您可以参考安装指南https://www.mindspore.cn/install")
return path_list # 返回空路径列表
# 使用subprocess.Popen执行ldd命令以获取依赖库的信息
ldd_r = subprocess.Popen(['ldd', real_path[0]], stdout=subprocess.PIPE)
# 使用subprocess.Popen的stdin参数从ldd_r.stdout接收输出并执行grep命令以过滤出包含指定库名的信息
ldd_result = subprocess.Popen(['grep', lib_name], stdin=ldd_r.stdout, stdout=subprocess.PIPE)
# 获取grep命令的输出结果并解码为字符串
result = ldd_result.communicate()[0].decode()
for i in result.split('\n'):
for i in result.split('\n'): # 按行分割结果字符串
# 使用partition方法从每一行中提取出库文件的路径
path = i.partition("=>")[2]
if path.lower().find("not found") > 0:
logger.warning(f"Cuda {self.version} version(need by mindspore-gpu) is not found, please confirm "
"that the path of cuda is set to the env LD_LIBRARY_PATH, please refer to the "
"installation guidelines: https://www.mindspore.cn/install")
continue
if path.lower().find("not found") > 0: # 如果路径中包含"not found"
# 记录警告日志提示用户确认cuda路径是否已添加到环境变量LD_LIBRARY_PATH中
logger.warning(f"Cuda {self.version}版本(由mindspore-gpu要求的) 未找到请确认cuda路径已设置到环境变量LD_LIBRARY_PATH中"
"您可以参考安装指南https://www.mindspore.cn/install")
continue # 继续下一次循环
# 从路径中去除库名部分
path = path.partition(lib_name)[0]
if path:
if path: # 如果路径非空
# 将路径的绝对路径并去除末尾斜杠后添加到path_list中
path_list.append(os.path.abspath(path.strip() + "../"))
# 返回path_list中唯一的路径
return np.unique(path_list)
except subprocess.TimeoutExpired:
logger.warning("Failed to check cuda version due to the ldd command timeout, please confirm that "
"the correct cuda version has been installed, you can refer to the "
"installation guidelines: https://www.mindspore.cn/install")
return path_list
except subprocess.TimeoutExpired: # 捕获subprocess.TimeoutExpired异常
# 记录警告日志提示用户确认cuda版本是否正确安装因为ldd命令超时
logger.warning("由于ldd命令超时无法检查cuda版本请确认已安装正确的cuda版本"
"您可以参考安装指南:https://www.mindspore.cn/install")
return path_list # 返回空路径列表
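A hedged standalone sketch of the ldd-plus-grep pipeline used above (the inspected binary is a placeholder):
import subprocess

binary = "/bin/ls"  # placeholder; any dynamically linked binary or shared object works
ldd_proc = subprocess.Popen(["ldd", binary], stdout=subprocess.PIPE)
grep_proc = subprocess.Popen(["grep", "libc"], stdin=ldd_proc.stdout, stdout=subprocess.PIPE)
print(grep_proc.communicate()[0].decode())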
def _read_version(self, file_path):
"""Get gpu version info in version.txt."""
@ -211,70 +269,80 @@ class GPUEnvChecker(EnvChecker):
class AscendEnvChecker(EnvChecker):
"""ascend environment check"""
"""Ascend 环境检查类"""
def __init__(self):
# Initialize the list of supported Ascend package versions
self.version = ["1.81"]
# Locations of the version.info file under the different installation paths
atlas_nnae_version = "/usr/local/Ascend/nnae/latest/fwkacllib/version.info"
atlas_toolkit_version = "/usr/local/Ascend/ascend-toolkit/latest/fwkacllib/version.info"
hisi_fwk_version = "/usr/local/Ascend/latest/fwkacllib/version.info"
# Check whether the Atlas NNAE environment exists
if os.path.exists(atlas_nnae_version):
# If it exists, use the Atlas NNAE default paths
self.fwk_path = "/usr/local/Ascend/nnae/latest/fwkacllib"  # framework path
self.op_impl_path = "/usr/local/Ascend/nnae/latest/opp/op_impl/built-in/ai_core/tbe"  # operator implementation path
self.tbe_path = self.fwk_path + "/lib64"  # TBE library path
self.cce_path = self.fwk_path + "/ccec_compiler/bin"  # CCE compiler path
self.fwk_version = atlas_nnae_version  # framework version file path
self.op_path = "/usr/local/Ascend/nnae/latest/opp"  # operator path
self.aicpu_path = "/usr/local/Ascend/nnae/latest"  # AI CPU path
# Check whether the Atlas Toolkit environment exists
elif os.path.exists(atlas_toolkit_version):
# If it exists, use the Atlas Toolkit default paths
self.fwk_path = "/usr/local/Ascend/ascend-toolkit/latest/fwkacllib"  # framework path
self.op_impl_path = "/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe"  # operator implementation path
self.tbe_path = self.fwk_path + "/lib64"  # TBE library path
self.cce_path = self.fwk_path + "/ccec_compiler/bin"  # CCE compiler path
self.fwk_version = atlas_toolkit_version  # framework version file path
self.op_path = "/usr/local/Ascend/ascend-toolkit/latest/opp"  # operator path
self.aicpu_path = "/usr/local/Ascend/ascend-toolkit/latest"  # AI CPU path
# Check whether the Hisi environment exists
elif os.path.exists(hisi_fwk_version):
# If it exists, use the Hisi default paths
self.fwk_path = "/usr/local/Ascend/latest/fwkacllib"  # framework path
self.op_impl_path = "/usr/local/Ascend/latest/opp/op_impl/built-in/ai_core/tbe"  # operator implementation path
self.tbe_path = self.fwk_path + "/lib64"  # TBE library path
self.cce_path = self.fwk_path + "/ccec_compiler/bin"  # CCE compiler path
self.fwk_version = hisi_fwk_version  # framework version file path
self.op_path = "/usr/local/Ascend/latest/opp"  # operator path
self.aicpu_path = "/usr/local/Ascend/latest"  # AI CPU path
else:
# Custom or unknown environment: none of the above exists, so use empty paths
self.fwk_path = ""  # framework path
self.op_impl_path = ""  # operator implementation path
self.tbe_path = ""  # TBE library path
self.cce_path = ""  # CCE compiler path
self.fwk_version = ""  # framework version file path
self.op_path = ""  # operator path
self.aicpu_path = ""  # AI CPU path
# Initialize the environment variables
self.path = os.getenv("PATH")
self.python_path = os.getenv("PYTHONPATH")
self.ld_lib_path = os.getenv("LD_LIBRARY_PATH")
self.ascend_opp_path = os.getenv("ASCEND_OPP_PATH")
self.ascend_aicpu_path = os.getenv("ASCEND_AICPU_PATH")
# Path contents that need to be checked
self.path_check = "/fwkacllib/ccec_compiler/bin"
self.python_path_check = "opp/op_impl/built-in/ai_core/tbe"
self.ld_lib_path_check_fwk = "/fwkacllib/lib64"
self.ld_lib_path_check_addons = "/add-ons"
self.ascend_opp_path_check = "/op"
self.v = ""
def check_env(self, e):
self._check_env()
raise e
def check_version(self):
# Skip the version check if the version file at the configured path does not exist
if not Path(self.fwk_version).is_file():
logger.warning("Using custom Ascend AI software package (Ascend Data Center Solution) path, package "
"version checking is skipped, please make sure Ascend AI software package (Ascend Data "
@ -282,40 +350,47 @@ class AscendEnvChecker(EnvChecker):
"https://www.mindspore.cn/install")
return
# Read the version information from the version file
v = self._read_version(self.fwk_version)
# Log a warning if the version that was read is not in the supported list
if v not in self.version:
v_list = str([x for x in self.version])
logger.warning(f"MindSpore version {__version__} and Ascend AI software package (Ascend Data Center "
f"Solution)version {v} does not match, the version of software package expect one of "
f"{v_list}, please reference to the match info on: https://www.mindspore.cn/install")
def check_deps_version(self):
"""
te, topi, hccl wheel package version check
in order to update the change of 'LD_LIBRARY_PATH' env, run a sub process
"""
# Build the input argument list with the mindspore version and the supported versions
input_args = ["--mindspore_version=" + __version__]
for v in self.version:
input_args.append("--supported_version=" + v)
# Get the path of the dependency version check script
deps_version_checker = os.path.join(os.path.split(os.path.realpath(__file__))[0],
"_check_deps_version.py")
# Build the command to run: the python interpreter, the script path and the input arguments
call_cmd = [sys.executable, deps_version_checker] + input_args
try:
# Run the version check in a subprocess with a 3 second timeout, capturing its output
process = subprocess.run(call_cmd, timeout=3, text=True, capture_output=True, check=False)
# If the subprocess produced output, log it as a warning and count down before continuing
if process.stdout.strip() != "":
logger.warning(process.stdout.strip())
warning_countdown = 3
for i in range(warning_countdown, 0, -1):
logger.warning(f"Please pay attention to the above warning, countdown: {i}")
time.sleep(1)
# If the version check times out, log it and skip
except subprocess.TimeoutExpired:
logger.info("Package te, topi, hccl version check timed out, skip.")
def set_env(self):
# Set up the Ascend environment variables
if not self.tbe_path:
self._check_env()
return
try:
import te # pylint: disable=unused-import
# pylint: disable=broad-except
@ -329,32 +404,35 @@ class AscendEnvChecker(EnvChecker):
raise EnvironmentError(
f"No such directory: {self.tbe_path}, Please check if Ascend AI software package (Ascend Data "
"Center Solution) is installed correctly.")
# check te version after set te env
self.check_deps_version()
# Set the environment variable for the operator implementation path
if Path(self.op_impl_path).is_dir():
# python path for sub process
if os.getenv('PYTHONPATH'):
os.environ['PYTHONPATH'] = self.op_impl_path + ":" + os.environ['PYTHONPATH']
else:
os.environ['PYTHONPATH'] = self.op_impl_path
# sys path for this process
sys.path.append(self.op_impl_path)
os.environ['TBE_IMPL_PATH'] = self.op_impl_path
else:
raise EnvironmentError(
f"No such directory: {self.op_impl_path}, Please check if Ascend AI software package (Ascend Data Center "
"Solution) is installed correctly.")
# Set the CCE path environment variable
if Path(self.cce_path).is_dir():
os.environ['PATH'] = self.cce_path + ":" + os.environ['PATH']
else:
raise EnvironmentError(
f"No such directory: {self.cce_path}, Please check if Ascend AI software package (Ascend Data Center "
"Solution) is installed correctly.")
# Set the OP path environment variable
if self.op_path is None:
pass
elif Path(self.op_path).is_dir():
@ -363,7 +441,8 @@ class AscendEnvChecker(EnvChecker):
raise EnvironmentError(
f"No such directory: {self.op_path}, Please check if Ascend AI software package (Ascend Data Center "
"Solution) is installed correctly.")
# Set the AICPU path environment variable
if self.aicpu_path is None:
pass
elif Path(self.aicpu_path).is_dir():
@ -372,44 +451,54 @@ class AscendEnvChecker(EnvChecker):
raise EnvironmentError(
f"No such directory: {self.aicpu_path}, Please check if Ascend AI software package (Ascend Data Center"
" Solution) is installed correctly.")
def _check_env(self):
"""ascend dependence path check"""
# Check whether the PATH environment variable is set correctly
if self.path is None or self.path_check not in self.path:
logger.warning("Can not find ccec_compiler(need by mindspore-ascend), please check if you have set env "
"PATH, you can reference to the installation guidelines https://www.mindspore.cn/install")
# Check whether the PYTHONPATH environment variable is set correctly
if self.python_path is None or self.python_path_check not in self.python_path:
logger.warning(
"Can not find tbe op implement(need by mindspore-ascend), please check if you have set env "
"PYTHONPATH, you can reference to the installation guidelines "
"https://www.mindspore.cn/install")
# Check whether the LD_LIBRARY_PATH environment variable is set correctly
if self.ld_lib_path is None or not (self.ld_lib_path_check_fwk in self.ld_lib_path and
self.ld_lib_path_check_addons in self.ld_lib_path):
logger.warning("Can not find driver so(need by mindspore-ascend), please check if you have set env "
"LD_LIBRARY_PATH, you can reference to the installation guidelines "
"https://www.mindspore.cn/install")
# Check whether the ASCEND_OPP_PATH environment variable is set correctly
if self.ascend_opp_path is None or self.ascend_opp_path_check not in self.ascend_opp_path:
logger.warning(
"Can not find opp path (need by mindspore-ascend), please check if you have set env ASCEND_OPP_PATH, "
"you can reference to the installation guidelines https://www.mindspore.cn/install")
def _read_version(self, file_path):
"""get ascend version info"""
with open(file_path, 'r') as f:
all_info = f.readlines()
# Walk through every line in the file
for line in all_info:
# Check whether the line starts with "Version="
if line.startswith("Version="):
# Strip the trailing newline, split on "=" and take the version number
full_version = line.strip().split("=")[1]
# Keep only the major and minor version numbers, joined with "."
self.v = '.'.join(full_version.split('.')[0:2])
# Return the version number
return self.v
# If no version information was found, return the current default value
return self.v
def check_version_and_env_config():
"""check version and env config"""
"""检查版本和环境配置"""
# 检查包名以确定使用哪种环境检查器
if __package_name__.lower() == "mindspore-ascend":
env_checker = AscendEnvChecker()
# Note: pre-load libgomp.so to solve error like "cannot allocate memory in static TLS block"
@ -425,19 +514,21 @@ def check_version_and_env_config():
else:
logger.info(f"Package version {__package_name__} does not need to check any environment variable, skipping.")
return
# Return immediately if the version check has been turned off
if os.getenv("MS_DEV_CLOSE_VERSION_CHECK") == "ON":
return
# Set the environment variable to turn off further version checks
os.environ["MS_DEV_CLOSE_VERSION_CHECK"] = "ON"
try:
# check version of ascend site or cuda
env_checker.check_version()
from .. import _c_expression # pylint: disable=unused-import
# Set up the environment
env_checker.set_env()
except ImportError as e:
# Handle the import error by checking the environment
env_checker.check_env(e)
def _set_pb_env():
"""Set env variable `PROTOCOL_BUFFERS` to prevent memory overflow."""
if os.getenv("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION") == "cpp":
@ -449,7 +540,9 @@ def _set_pb_env():
logger.info("Setting the env `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python` to prevent memory overflow "
"during save or load checkpoint file.")
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
# Check the version and the environment configuration
check_version_and_env_config()
# Set the protocol buffers environment variable to prevent memory overflow
_set_pb_env()

@ -26,23 +26,26 @@ def _check_mul():
"""
from importlib import import_module
import numpy as np
try:
ms = import_module("mindspore")
except ModuleNotFoundError:
ms = None
finally:
pass
# Print the MindSpore version information
print(f"MindSpore version: ", ms.__version__)
# Create two MindSpore tensors holding the arrays [1.0, 2.0, 3.0] and [4.0, 5.0, 6.0]
input_x = ms.Tensor(np.array([1.0, 2.0, 3.0]), ms.float32)
input_y = ms.Tensor(np.array([4.0, 5.0, 6.0]), ms.float32)
# Create a multiplication operator
mul = ms.ops.Mul()
# Run the multiplication
mul(input_x, input_y)
# Print that the multiplication result is correct and MindSpore has been installed successfully
print(f"The result of multiplication calculation is correct, MindSpore has been installed successfully!")
def run_check():
"""
Provide a convenient API to check if the installation is successful or failed.
@ -55,10 +58,13 @@ def run_check():
The result of multiplication calculation is correct, MindSpore has been installed successfully!
"""
try:
# Try to run the multiplication check
_check_mul()
# pylint: disable=broad-except
# Catch all exceptions and print the error information
except Exception as e:
print("MindSpore running check failed.")
print(str(e))
finally:
# This part runs whether or not an exception occurred
pass