donghaoqian 7 months ago
commit 79516a3321

@ -0,0 +1,38 @@
import os
import stat
import fcntl
import platform
from logging.handlers import RotatingFileHandler
class _MultiCompatibleRotatingFileHandler(RotatingFileHandler):
"""Inherit RotatingFileHandler for multiprocess compatibility.
这个类继承自`RotatingFileHandler`,是为了在多进程环境下安全地使用日志回滚功能。
在多进程环境下,多个进程可能会同时尝试写入或回滚日志文件,这可能会导致文件损坏或数据丢失。
通过在这个类中对相关方法进行重写,确保了日志文件在多进程环境下的正确处理。
"""
def doRollover(self):
"""Override doRollover for multiprocess compatibility
and setting permission of Log file
这个方法重写了`RotatingFileHandler`中的`doRollover`方法,增加了多进程兼容性,
并设置了日志文件的权限。
1. 使用`fcntl`模块获得独占锁,确保在回滚日志文件时不会有其他进程进行写操作。
2. 设置日志文件的权限,以确保日志文件的安全性。
3. 调用父类的`doRollover`方法执行实际的日志回滚操作。
4. 回滚后,修改日志文件的权限,使其可读可写。
"""
# Attain an exclusive lock with blocking mode by `fcntl` module.
with open(self.baseFilename, 'a') as file_pointer:
# 如果操作系统不是Windows使用`fcntl`模块对文件加锁
if platform.system() != "Windows":
fcntl.lockf(file_pointer.fileno(), fcntl.LOCK_EX)
# 设置日志文件权限为只读,增加安全性
os.chmod(self.baseFilename, stat.S_IREAD)
# 调用父类的`doRollover`方法执行日志回滚操作
super().doRollover()
# 修改日志文件的权限为可读可写,以便后续的日志写入操作
os.chmod(self.baseFilename, stat.S_IREAD | stat.S_IWRITE)

@ -18,24 +18,31 @@ from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched
from mindspore import log as logger
from ._hccl_management import load_lib as hccl_load_lib
from .._c_expression import get_rank_id, get_rank_size
# 检查HCCL是否可用
_HCCL_AVAILABLE = False
# 检查HCCL测试是否可用
_HCCL_TEST_AVAILABLE = False
# 检查NCCL是否可用
_NCCL_AVAILABLE = False
# 检查MPI是否可用
_MPI_AVAILABLE = False
try:
# 尝试导入mindspore._ms_mpi如果成功则NCCL可用
import mindspore._ms_mpi as mpi
_NCCL_AVAILABLE = True
except ImportError:
# 如果导入失败则NCCL不可用
_NCCL_AVAILABLE = False
# 尝试加载 HCCL 库,如果成功则设置 _HCCL_AVAILABLE 为 True否则捕获 RuntimeError 并设置为 False
try:
hccl_load_lib()
_HCCL_AVAILABLE = True
except RuntimeError:
_HCCL_AVAILABLE = False
# 如果 HCCL 可用,则导入 _hccl_management 并尝试导入 mindspore._ascend_mpi如果成功则设置 _MPI_AVAILABLE 为 True否则捕获 ImportError 并设置为 False
if _HCCL_AVAILABLE:
from . import _hccl_management as hccl
try:
@ -43,6 +50,7 @@ if _HCCL_AVAILABLE:
_MPI_AVAILABLE = True
except ImportError:
_MPI_AVAILABLE = False
# 如果 HCCL 不可用,则尝试导入 hccl_test.manage.api如果成功则设置 _HCCL_AVAILABLE 和 _HCCL_TEST_AVAILABLE 为 True否则捕获 ImportError 并设置 _HCCL_AVAILABLE 为 False
else:
try:
import hccl_test.manage.api as hccl
@ -50,12 +58,11 @@ else:
_HCCL_TEST_AVAILABLE = True
except ImportError:
_HCCL_AVAILABLE = False
# 定义 HCCL 和 NCCL 的通信组名称常量
HCCL_WORLD_COMM_GROUP = "hccl_world_group"
NCCL_WORLD_COMM_GROUP = "nccl_world_group"
class Backend:
"""
Class for available backends.
@ -82,13 +89,17 @@ class Backend:
def __new__(cls, name):
"""Create instance object of Backend."""
# 检查传入的name是否为字符串类型
if not isinstance(name, str):
raise TypeError("For 'Backend', the class variable 'name' must be a string, "
"but got the type : {}".format(type(name)))
# 获取对应name的大写形式的Backend类属性值如果不存在则返回Backend.UNDEFINED
value = getattr(Backend, name.upper(), Backend.UNDEFINED)
# 如果获取到的值是Backend.UNDEFINED说明传入的name不被支持
if value == Backend.UNDEFINED:
raise ValueError("For 'Backend', the class variable 'name' {} is not supported, "
"please use hccl or nccl.".format(name))
# 返回获取到的Backend类属性值
return value
DEFAULT_BACKEND = Backend("hccl")
@ -97,42 +108,41 @@ DEFAULT_BACKEND = Backend("hccl")
class GlobalComm:
"""
World communication information. The GlobalComm is a global class. The members contain:
- BACKEND: The communication library used, using HCCL/NCCL.
- WORLD_COMM_GROUP: Global communication domain.
"""
BACKEND = DEFAULT_BACKEND
WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
INITED = False
CHECK_ENVS = True
INITED = False # 标记全局通信是否已初始化
CHECK_ENVS = True # 标记是否需要检查通信环境变量
class _ExistingGroup:
"""
The communication groups which exist in the progress.
用于表示在程序运行过程中存在的通信组
"""
ITEMS = {}
ITEMS = {} # 存储通信组的字典,键为通信组的标识符,值为通信组对象
def is_hccl_available():
"""
Check HCCL api is available.
检查HCCL API是否可用
Returns:
Boolean. Return whether HCCL is available or not.
Boolean: 返回HCCL是否可用
"""
return _HCCL_AVAILABLE
return _HCCL_AVAILABLE # 返回一个布尔值指示HCCL是否可用
def is_mpi_available():
"""
Check HCCL & MPI api is available.
检查HCCL和MPI API是否可用
Returns:
Boolean. Return whether HCCL & MPI is available or not.
Boolean: 返回HCCL和MPI是否同时可用
"""
return _MPI_AVAILABLE
return _MPI_AVAILABLE # 返回一个布尔值指示HCCL和MPI是否同时可用
def is_nccl_available():
"""
@ -158,19 +168,25 @@ def check_parameter_available(func):
Wrapper. If not available, raise Error.
"""
def wrapper(*args, **kargs):
# 如果当前角色是参数服务器或者调度器,直接调用被装饰的函数
if _is_role_pserver() or _is_role_sched():
return func(*args, **kargs)
# 检查分布式通信是否已经初始化,未初始化则抛出异常
if not GlobalComm.INITED:
raise RuntimeError("Distributed Communication has not been inited")
# 获取参数组默认为None
group = None
# 检查关键字参数中是否包含"group",并进行类型检查
if "group" in kargs.keys():
group = kargs.get("group")
if group is not None and not isinstance(group, str):
raise TypeError("The parameter 'group' should be str or None, "
"but got the type : {}".format(type(group)))
# 获取后端默认为None
if "backend" in kargs.keys():
backend = kargs.get("backend")
# 检查后端是否可用,不可用则抛出异常
if backend is Backend.HCCL and not is_hccl_available():
raise RuntimeError("Distributed Communication doesn't have HCCL built in")
if backend is Backend.HCCL_MPI and not is_mpi_available():
@ -178,15 +194,16 @@ def check_parameter_available(func):
if backend is Backend.NCCL and not is_nccl_available():
raise RuntimeError("Distributed Communication doesn't have NCCL built in")
# 如果未指定group根据backend设置默认的group
if group is None:
if backend is Backend.HCCL or Backend.HCCL_MPI:
if backend is Backend.HCCL or backend is Backend.HCCL_MPI:
group = HCCL_WORLD_COMM_GROUP
elif backend is Backend.NCCL:
group = NCCL_WORLD_COMM_GROUP
# 调用被装饰的函数
return func(*args, **kargs)
return wrapper
@check_parameter_available
def _get_rank_helper(group, backend):
"""
@ -202,10 +219,13 @@ def _get_rank_helper(group, backend):
Returns:
Integer. The local rank id of the calling process.
"""
# 辅助函数,用于根据不同的后端和组获取 rank_id
# 获取当前角色的rank_id如果是参数服务器或调度器角色rank_id设为0并返回
rank_id = None
if _is_role_pserver() or _is_role_sched():
rank_id = 0
return rank_id
# 根据不同的后端获取rank_id
if backend == Backend.HCCL_MPI:
rank_id = mpi.get_rank_id(group)
elif backend == Backend.HCCL:
@ -216,6 +236,7 @@ def _get_rank_helper(group, backend):
elif backend == Backend.NCCL:
rank_id = get_rank_id(group)
else:
# 如果后端不被支持抛出ValueError异常
raise ValueError("For '_get_rank_helper', the argument 'backend' {} is not supported, "
"please use hccl_mpi, hccl or nccl.".format(backend))
return rank_id
@ -236,22 +257,30 @@ def _get_local_rank_helper(group, backend):
Returns:
Integer. The local rank id of the calling process.
"""
# 获取当前进程的rank id根据不同的后端和组进行处理
rank_id = None
# 根据不同的后端选择获取rank_id的方法
if backend == Backend.HCCL_MPI:
# 使用HCCL MPI后端时通过mpi.get_rank_id获取rank_id
rank_id = mpi.get_rank_id(group)
elif backend == Backend.HCCL:
# 使用HCCL后端时根据group的不同选择获取rank_id的方法
if group == HCCL_WORLD_COMM_GROUP:
# 如果group是HCCL_WORLD_COMM_GROUP则使用hccl.get_local_rank_id获取rank_id
rank_id = hccl.get_local_rank_id()
else:
# 对于其他group同样使用hccl.get_local_rank_id获取rank_id
rank_id = hccl.get_local_rank_id(group)
elif backend == Backend.NCCL:
# 如果使用NCCL后端当前不支持get_local_rank_id方法抛出异常
raise RuntimeError("Nccl doesn't support get_local_rank_id now.")
else:
# 如果backend既不是HCCL_MPI也不是HCCL抛出异常表示不支持的backend
raise ValueError("For '_get_local_rank_helper', the argument 'backend' {} is not supported, "
"please use hccl_mpi or hccl.".format(backend))
# 返回获取到的rank_id
return rank_id
@check_parameter_available
def _get_size_helper(group, backend):
"""
@ -268,9 +297,12 @@ def _get_size_helper(group, backend):
Integer. The rank size of specified group.
"""
size = None
# 如果当前角色是参数服务器或调度器则将size设为1并返回
if _is_role_pserver() or _is_role_sched():
size = 1
return size
# 根据不同的后端设置size的值
# 根据不同的后端获取组的大小
if backend == Backend.HCCL_MPI:
size = mpi.get_rank_size(group)
elif backend == Backend.HCCL:
@ -302,19 +334,23 @@ def _get_local_size_helper(group, backend):
Integer. The local rank size where the calling process is being within specified group.
"""
size = None
# 根据不同的后端获取本地进程组的大小
if backend == Backend.HCCL:
# 如果组是全局通信组,则获取全局通信组的本地排名大小
if group == HCCL_WORLD_COMM_GROUP:
size = hccl.get_local_rank_size()
# 否则获取指定组的本地排名大小
else:
size = hccl.get_local_rank_size(group)
# NCCL后端不支持获取本地排名大小抛出异常
elif backend == Backend.NCCL:
raise RuntimeError("Nccl doesn't support get_local_rank_size now.")
# 对于不支持的后端,抛出异常
else:
raise ValueError("For '_get_local_size_helper', the argument 'backend' {} is not supported, "
"please use hccl.".format(backend))
return size
@check_parameter_available
def _get_world_rank_from_group_rank_helper(group, group_rank_id, backend):
"""
@ -333,21 +369,26 @@ def _get_world_rank_from_group_rank_helper(group, group_rank_id, backend):
Integer. A rank id in world communication group.
"""
world_rank_id = None
# 检查 group_rank_id 是否为整数类型,如果不是则抛出 TypeError
if not isinstance(group_rank_id, int):
raise TypeError("For 'get_world_rank_from_group_rank', the argument 'group_rank_id' must be"
" type of int, but got 'group_rank_id' type : {}.".format(type(group_rank_id)))
# 根据不同的后端选择不同的逻辑处理方式
if backend == Backend.HCCL:
# 如果在 GPU 上使用 HCCL但 group 参数为 HCCL_WORLD_COMM_GROUP则抛出 ValueError
if group == HCCL_WORLD_COMM_GROUP:
raise ValueError("For 'get_world_rank_from_group_rank' on GPU, the argument 'group' "
"should be 'NCCL_WORLD_COMM_GROUP', but got 'HCCL_WORLD_COMM_GROUP'.")
# 调用 hccl.get_world_rank_from_group_rank 方法获取 world_rank_id
world_rank_id = hccl.get_world_rank_from_group_rank(group, group_rank_id)
elif backend == Backend.NCCL:
# 如果使用 NCCL则抛出 RuntimeError 表示不支持该操作
raise RuntimeError("Nccl doesn't support get_world_rank_from_group_rank now.")
else:
# 如果 backend 参数不支持,则抛出 ValueError 表示不支持该后端
raise ValueError("The argument 'backend' {} is not supported, please use hccl.".format(backend))
# 返回获取的 world_rank_id
return world_rank_id
@check_parameter_available
def _get_group_rank_from_world_rank_helper(world_rank_id, group, backend):
"""
@ -366,21 +407,27 @@ def _get_group_rank_from_world_rank_helper(world_rank_id, group, backend):
Integer. A rank id in user communication group.
"""
group_rank_id = None
# 检查 world_rank_id 是否为整数类型,如果不是则抛出 TypeError
if not isinstance(world_rank_id, int):
raise TypeError("For 'get_group_rank_from_world_rank', the argument 'world_rank_id' must be type of int, "
"but got 'world_rank_id' type : {}.".format(type(world_rank_id)))
# 根据不同的后端处理获取 group_rank_id 的逻辑
if backend == Backend.HCCL:
# 检查 GPU 后端的 group 参数是否正确,如果不正确则抛出 ValueError
if group == HCCL_WORLD_COMM_GROUP:
raise ValueError("For 'get_group_rank_from_world_rank' on GPU, the argument 'group' "
"should be 'NCCL_WORLD_COMM_GROUP', but got 'HCCL_WORLD_COMM_GROUP'.")
# 调用 hccl 模块的函数获取 group_rank_id
group_rank_id = hccl.get_group_rank_from_world_rank(world_rank_id, group)
# NCCL 后端不支持此操作,抛出 RuntimeError
elif backend == Backend.NCCL:
raise RuntimeError("Nccl doesn't support get_group_rank_from_world_rank now.")
# 如果后端不是 HCCL 或 NCCL则抛出 ValueError 表示不支持的后端
else:
raise ValueError("The argument 'backend' {} is not supported, please use hccl.".format(backend))
# 返回获取到的 group_rank_id
return group_rank_id
@check_parameter_available
def _create_group_helper(group, rank_ids, backend):
"""
@ -395,34 +442,46 @@ def _create_group_helper(group, rank_ids, backend):
TypeError: If rank_ids is not a list.
ValueError: If rank_ids size is not larger than 1 or rank_ids has duplicate data or backend is invalid.
"""
# 检查组是否已存在
if group in _ExistingGroup.ITEMS.keys():
# 如果组已存在且提供的rank_ids与存储的不一致抛出异常
if rank_ids != _ExistingGroup.ITEMS[group]:
raise ValueError("The group {} has been created, the rank_list is {}, "
"but current rank_list for the group is {}".
format(group, _ExistingGroup.ITEMS[group], rank_ids))
# 记录警告信息,提示组已存在
logger.warning("%r group has existed.", group)
return
# 根据不同的后端创建组
if backend == Backend.HCCL:
# 检查rank_ids是否为列表类型
if not isinstance(rank_ids, list):
raise TypeError("For 'create_group', the argument 'rank_ids' must be type of list, "
"but got 'rank_ids' type : {}.".format(type(rank_ids)))
# 检查rank_ids的长度是否大于1
rank_size = len(rank_ids)
if rank_size < 1:
raise ValueError("For 'create_group', the argument 'rank_ids' size should be greater than 1, "
"but got 'rank_ids' size : {}.".format(len(rank_ids)))
# 检查rank_ids中是否有重复的元素
if len(rank_ids) - len(list(set(rank_ids))) > 0:
raise ValueError("List rank_ids in Group {} has duplicate data!".format(group))
# 使用HCCL创建组
hccl.create_group(group, rank_size, rank_ids)
elif backend == Backend.HCCL_MPI:
# 使用HCCL_MPI创建组
mpi.create_group(group, rank_ids)
elif backend == Backend.NCCL:
# NCCL暂不支持创建组抛出异常
raise RuntimeError("Nccl doesn't support create_group now.")
else:
# 如果后端不支持,抛出异常
raise ValueError("The context configuration parameter 'backend' {} is not supported, "
"please use hccl.".format(backend))
_ExistingGroup.ITEMS[group] = rank_ids
# 将新创建的组及其rank_ids添加到_existingGroup中
_ExistingGroup.ITEMS[group] = rank_ids
@check_parameter_available
def _destroy_group_helper(group, backend):
"""
@ -435,12 +494,17 @@ def _destroy_group_helper(group, backend):
Raises:
ValueError: If group is "hccl_world_group" or backend is invalid.
"""
# 根据后端类型销毁通信组
if backend == Backend.HCCL:
# 检查是否为 HCCL 的全局通信组
if group == HCCL_WORLD_COMM_GROUP:
raise ValueError("The hccl_world_group does not support destruction.")
# 销毁指定的 HCCL 通信组
hccl.destroy_group(group)
elif backend == Backend.NCCL:
# 当前 NCCL 后端不支持销毁通信组
raise RuntimeError("Nccl doesn't support destroy_group now.")
else:
# 抛出错误,表示不支持的后端类型
raise ValueError("The context configuration parameter 'backend' {} is not supported, "
"please use hccl.".format(backend))
"please use hccl.".format(backend))

@ -24,7 +24,7 @@ MAX_RANK_NUM = 4096
HCCL_LIB = 'libhccl_plugin.so'
HCCL_LIB_CTYPES = ""
# 检查集体通信组的名称是否合法
def check_group(group):
"""
A function that check if a collection communication group is legal.
@ -41,23 +41,31 @@ def check_group(group):
raise TypeError("The type of communication group name must be type of string, "
"but got 'group' type : {}.".format(type(group)))
# 检查集体通信中的排名编号是否合法
def check_rank_num(rank_num):
"""
A function that check if a collection communication rank number is legal.If not raise error.
检查通信集合中的排名编号是否合法如果不合法则抛出错误
参数:
rank_num: 需要检查的排名编号预期为整数类型
Returns:
None
None: 该函数不返回任何值但可能会抛出异常
"""
if isinstance(rank_num, (int)):
# 检查 rank_num 是否为整数类型
if isinstance(rank_num, int):
# 检查 rank_num 是否在合法范围内大于0且小于等于 MAX_RANK_NUM
if rank_num > MAX_RANK_NUM or rank_num <= 0:
raise ValueError("For 'create_group', the size of argument 'rand_ids' should be greater than 0 and"
"less than {}, but got the size of 'rank_ids' : {}.".format(MAX_RANK_NUM, rank_num))
# 如果 rank_num 不在合法范围内,抛出 ValueError 异常,并提供详细的错误信息
raise ValueError("对于 'create_group' 函数,参数 'rank_ids' 的大小必须大于0且"
"小于等于 {},但得到的 'rank_ids' 的大小为: {}".format(MAX_RANK_NUM, rank_num))
else:
raise TypeError("The argument 'rank_num' must be type of int, "
"but got 'rank_num' type : {}.".format(type(rank_num)))
# 如果 rank_num 不是整数类型,抛出 TypeError 异常,并提供详细的错误信息
raise TypeError("参数 'rank_num' 必须为整数类型,"
"但得到的 'rank_num' 类型为: {}".format(type(rank_num)))
#检查集体通信中的排名标识(rank id)是否合法
def check_rank_id(rank_id):
"""
A function that check if a collection communication rank id is legal.If not raise error.
@ -65,40 +73,48 @@ def check_rank_id(rank_id):
Returns:
None
"""
# 检查rank_id是否为整数类型
if isinstance(rank_id, (int)):
# 检查rank_id是否在有效范围内
if rank_id >= MAX_RANK_NUM or rank_id < 0:
raise ValueError("The rand id in the communication group must be greater or equal 0 and "
"less than {}, but got type value : {}.".format(MAX_RANK_NUM, rank_id))
else:
# 如果rank_id不是整数类型抛出类型错误
raise TypeError("The rand id in the communication group must be must be type of int, "
"but got type value : {}.".format(type(rank_id)))
# 加载 HCCLHuawei Cloud Communication Library
def load_lib():
"""load hccl lib"""
try:
# 获取当前文件所在的目录
base_dir = os.path.dirname(os.path.realpath(__file__))
# 构建库文件的路径
lib_path = os.path.join(base_dir, "../lib", HCCL_LIB)
# 加载库文件
hccl_lib = ctypes.CDLL(lib_path)
except Exception:
# 如果加载失败则抛出运行时错误
raise RuntimeError('Get hccl lib error.')
# 将加载的库文件设置为全局变量
global HCCL_LIB_CTYPES
HCCL_LIB_CTYPES = hccl_lib
def c_str(string):
"""Convert a python string to C string."""
# 将字符串转换为C风格字符串
if not isinstance(string, str):
string = string.decode('ascii')
return ctypes.c_char_p(string.encode('utf-8'))
def c_array(ctype, values):
"""Create ctypes array from a python array."""
# 从Python数组创建ctypes数组
return (ctype * len(values))(*values)
#用于创建包含指定数量和ID的HCCL通信组但不能创建世界组。
def create_group(group, rank_num, rank_ids):
"""
Create group.
@ -112,28 +128,38 @@ def create_group(group, rank_num, rank_ids):
Returns:
None
"""
# 检查组的有效性
check_group(group)
# 检查排名数量的有效性
check_rank_num(rank_num)
# 检查rank_ids是否为列表类型
if isinstance(rank_ids, (list)):
# 确保rank_num与rank_ids的长度一致
if rank_num != len(rank_ids):
raise ValueError("The argument 'rank_num' number should be equal to the length "
"of rank_ids, but got 'rank_num' value : {} and 'rank_ids' value : {}."
.format(rank_num, rank_ids))
# 检查rank_ids中的每个元素是否为非负整数
for rank_id in rank_ids:
if not isinstance(rank_id, (int)) or rank_id < 0:
raise ValueError("The elements of argument 'rank_ids' must be "
"unsigned integer, but got the type : {}".format(type(rank_id)))
# 将rank_ids转换为C类型的数组
c_array_rank_ids = c_array(ctypes.c_uint, rank_ids)
# 将rank_num转换为C类型的无符号整数
c_rank_num = ctypes.c_uint(rank_num)
# 将group转换为C类型的字符串
c_group = c_str(group)
# 调用HCCL库创建组
ret = HCCL_LIB_CTYPES.HcomCreateGroup(c_group, c_rank_num, c_array_rank_ids)
# 检查组创建是否成功
if ret != 0:
raise RuntimeError('Create group error, the error code is {}.'.format(ret))
else:
# 如果rank_ids不是列表类型抛出类型错误
raise TypeError("For 'create_group', the argument 'rank_ids' must be type of list, "
"but got 'rank_ids' type : {}.".format(type(rank_ids)))
#用于销毁用户创建的HCCL组
def destroy_group(group):
"""
A function that destroy the group which created by user.
@ -144,11 +170,16 @@ def destroy_group(group):
Returns:
None
"""
check_group(group)
c_group = c_str(group)
ret = HCCL_LIB_CTYPES.HcomDestroyGroup(c_group)
if ret != 0:
raise RuntimeError('Destroy group error.')
# 检查传入的组是否有效
check_group(group)
# 将组名转换为C风格的字符串
c_group = c_str(group)
# 调用HCCL库中的函数销毁指定的组
ret = HCCL_LIB_CTYPES.HcomDestroyGroup(c_group)
# 如果返回值不为0说明销毁组时发生了错误抛出异常
if ret != 0:
raise RuntimeError('Destroy group error.')
def get_rank_size(group="hccl_world_group"):
@ -162,16 +193,23 @@ def get_rank_size(group="hccl_world_group"):
An integer scalar with the num of ranks.
"""
# 根据上下文的模式判断是否为PYNATIVE_MODE模式若是则直接返回HCCL的rank size
if context.get_context("mode") == context.PYNATIVE_MODE:
return get_hccl_rank_size()
# 检查给定的组是否有效
check_group(group)
# 将组名转换为C字符串格式
c_group = c_str(group)
# 定义一个C类型的无符号整数用于存储rank size
c_rank_size = ctypes.c_uint()
# 调用HCCL库的HcomGetRankSize函数获取组内的rank size
ret = HCCL_LIB_CTYPES.HcomGetRankSize(c_group, ctypes.byref(c_rank_size))
# 如果返回值不为0表示获取rank size失败抛出运行时错误
if ret != 0:
raise RuntimeError('Get rank size error.')
# 返回获取到的rank size值
return c_rank_size.value
@ -186,17 +224,22 @@ def get_rank_id(group="hccl_world_group"):
if context.get_context("mode") == context.PYNATIVE_MODE:
return get_hccl_rank_id()
# 检查组的有效性
check_group(group)
# 将组转换为 C 字符串格式
c_group = c_str(group)
# 定义一个用于存储 rank id 的 ctypes 无符号整数
c_rank_id = ctypes.c_uint()
# 调用 HCCL 库获取当前进程的 rank id
ret = HCCL_LIB_CTYPES.HcomGetRankId(c_group, ctypes.byref(c_rank_id))
# 如果返回值不为 0表示获取 rank id 出错,抛出 RuntimeError 异常
if ret != 0:
raise RuntimeError('Get rank id error.')
# 返回获取到的 rank id 值
return c_rank_id.value
def get_local_rank_size(group="hccl_world_group"):
"""
A function that returns the number of local ranks within the given collection communication group.
@ -207,19 +250,25 @@ def get_local_rank_size(group="hccl_world_group"):
Returns:
An integer scalar with the num of local ranks.
"""
# 检查当前运行模式是否为PYNATIVE_MODE如果是则抛出异常
if context.get_context("mode") is context.PYNATIVE_MODE:
raise RuntimeError("The function 'get_local_rank_size' is not supported in PYNATIVE_MODE, "
"'get_local_rank_size' only support GRAPH_MODE")
# 验证传入的组是否有效
check_group(group)
# 将组名称转换为C字符串格式
c_group = c_str(group)
# 定义一个ctypes的无符号整数变量用于存储本地排名大小
c_local_rank_size = ctypes.c_uint()
# 调用HCCL库中的HcomGetLocalRankSize函数获取本地排名大小
ret = HCCL_LIB_CTYPES.HcomGetLocalRankSize(c_group, ctypes.byref(c_local_rank_size))
# 如果返回值不为0说明获取本地排名大小时出错抛出异常
if ret != 0:
raise RuntimeError('Get local rank size error.')
# 返回获取到的本地排名大小
return c_local_rank_size.value
def get_local_rank_id(group="hccl_world_group"):
"""
Get local rank id.
@ -230,16 +279,23 @@ def get_local_rank_id(group="hccl_world_group"):
An integer scalar with the local rank id of the calling process.
"""
# 检查当前运行模式是否为PYNATIVE_MODE如果是则抛出异常
if context.get_context("mode") is context.PYNATIVE_MODE:
raise RuntimeError("The function 'get_local_rank_id' is not supported in PYNATIVE_MODE, "
"'get_local_rank_id' only support GRAPH_MODE")
# 验证群组的有效性
check_group(group)
# 将群组名称转换为C字符串格式
c_group = c_str(group)
# 定义一个无符号整型的C类型变量来存储本地排名ID
c_local_rank_id = ctypes.c_uint()
# 调用HCCL库的HcomGetLocalRankId函数获取本地排名ID
ret = HCCL_LIB_CTYPES.HcomGetLocalRankId(c_group, ctypes.byref(c_local_rank_id))
# 如果返回值不为0表示获取本地排名ID失败抛出异常
if ret != 0:
raise RuntimeError('Get local rank id error.')
# 返回获取到的本地排名ID值
return c_local_rank_id.value
@ -256,18 +312,25 @@ def get_world_rank_from_group_rank(group, group_rank_id):
if context.get_context("mode") is context.PYNATIVE_MODE:
raise RuntimeError("The function 'get_world_rank_from_group_rank' is not supported in PYNATIVE_MODE, "
"'get_world_rank_from_group_rank' only support GRAPH_MODE")
# 检查组名是否有效
check_group(group)
# 检查组内rank ID是否有效
check_rank_id(group_rank_id)
# 将组名转换为C字符串格式
c_group = c_str(group)
# 将组内rank ID转换为C的无符号整数类型
c_group_rank_id = ctypes.c_uint(group_rank_id)
# 声明一个用于存储世界rank ID的C的无符号整数类型变量
c_world_rank_id = ctypes.c_uint()
# 调用HCCL库中的HcomGetWorldRankFromGroupRank函数根据组名和组内rank ID获取对应的世界rank ID
ret = HCCL_LIB_CTYPES.HcomGetWorldRankFromGroupRank(c_group, c_group_rank_id, ctypes.byref(c_world_rank_id))
# 如果返回值不为0说明函数调用出错抛出RuntimeError异常
if ret != 0:
raise RuntimeError('Get world rank from group rank error.')
raise RuntimeError('根据组内rank ID获取世界rank ID时出错。')
# 返回获取到的世界rank ID的值
return c_world_rank_id.value
def get_group_rank_from_world_rank(world_rank_id, group):
"""
Get group rank from world rank.
@ -281,13 +344,21 @@ def get_group_rank_from_world_rank(world_rank_id, group):
if context.get_context("mode") is context.PYNATIVE_MODE:
raise RuntimeError("The function 'get_group_rank_from_world_rank' is not supported in PYNATIVE_MODE, "
"'get_group_rank_from_world_rank' only support GRAPH_MODE")
# 检查组的有效性
check_group(group)
# 检查世界排名ID的有效性
check_rank_id(world_rank_id)
# 将组转换为C字符串
c_group = c_str(group)
# 将世界排名ID转换为C无符号整数
c_world_rank_id = ctypes.c_uint(world_rank_id)
# 创建一个用于存储组排名ID的C无符号整数
c_group_rank_id = ctypes.c_uint()
# 从世界排名ID获取组排名ID
ret = HCCL_LIB_CTYPES.HcomGetGroupRankFromWorldRank(c_world_rank_id, c_group, ctypes.byref(c_group_rank_id))
# 如果返回值不为0抛出运行时错误
if ret != 0:
raise RuntimeError('Get group rank from world rank error.')
return c_group_rank_id.value
# 返回获取到的组排名ID的值
return c_group_rank_id.value

@ -20,40 +20,51 @@ from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \
_create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \
_get_local_rank_helper, _get_local_size_helper, GlobalComm
from .._c_expression import init_hccl, finalize_hccl, init_gpu_collective
# 导入mindspore上下文模块
# 导入mindspore的进程服务器和调度器角色判断函数
# 导入通信辅助模块包括后端类型、获取和设置rank及size的辅助函数、创建和销毁组的辅助函数、以及世界通信组和本地通信组的相关标识
# 导入C语言表达式模块包含HCCL和NCCL的初始化和终结化函数以及GPU集体通信的初始化函数
__all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size",
"get_local_rank_size", "get_world_rank_from_group_rank",
"get_group_rank_from_world_rank", "create_group", "destroy_group",
"HCCL_WORLD_COMM_GROUP", "NCCL_WORLD_COMM_GROUP"]
# 默认的世界通信组
DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
def _get_group(group):
"""Return the world communication group if the `group` is `DEFAULT_WORLD_COMM_GROUP`."""
"""根据输入的`group`参数返回相应的通信组。
如果`group`等于`DEFAULT_WORLD_COMM_GROUP`则返回全局通信组`GlobalComm.WORLD_COMM_GROUP`
否则返回传入的`group`参数
"""
if group == DEFAULT_WORLD_COMM_GROUP:
return GlobalComm.WORLD_COMM_GROUP
return group
return GlobalComm.WORLD_COMM_GROUP # 返回默认的世界通信组
return group # 返回传入的通信组参数
def _check_task_sink_envs():
"""
Check whether task_sink environment variables have been exported or not.
return True if task_sink environment variables have been exported, False otherwise.
"""检查任务接收器task_sink相关的环境变量是否已导出。
该函数通过检查环境变量`GRAPH_OP_RUN`来判断任务接收器的环境变量是否已导出
返回值
- 如果环境变量`GRAPH_OP_RUN`已导出且其值可以转换为整数1则返回False表示环境变量未正确设置为启用状态
- 如果环境变量`GRAPH_OP_RUN`未导出或其值不能转换为整数1则返回True表示环境变量未导出或设置有误
"""
import os
task_sink = os.getenv("GRAPH_OP_RUN")
task_sink = os.getenv("GRAPH_OP_RUN") # 获取名为"GRAPH_OP_RUN"的环境变量
if task_sink:
try:
if int(task_sink) == 1:
return False
if int(task_sink) == 1: # 尝试将环境变量的值转换为整数并检查是否等于1
return False # 如果等于1返回False表示环境变量已导出但设置为禁用状态非预期情况
except ValueError:
return True
return True # 如果转换为整数失败返回True表示环境变量设置有误
finally:
pass
return True
pass # finally块中的代码在这里是空操作通常用于清理操作
return True # 如果环境变量未导出返回True
def _check_parallel_envs():
@ -63,25 +74,31 @@ def _check_parallel_envs():
Raises:
RuntimeError: If parallel environment variables have not been exported or have been exported to wrong values.
"""
# 检查是否需要进行环境验证,如果不进行则直接返回
if not GlobalComm.CHECK_ENVS:
return
import os
# 获取环境变量RANK_ID的值
rank_id_str = os.getenv("RANK_ID")
# 如果RANK_ID未设置抛出运行时错误
if not rank_id_str:
raise RuntimeError("Environment variables RANK_ID has not been exported, please export variables 'RANK_ID'.")
try:
# 尝试将RANK_ID转换为整数
int(rank_id_str)
except ValueError:
# 如果转换失败,打印错误信息
print("Environment variables 'RANK_ID' should be number, but got the type : {}".format(type(rank_id_str)))
finally:
# 无论是否发生异常,此块为空操作
pass
# 获取环境变量MINDSPORE_HCCL_CONFIG_PATH和RANK_TABLE_FILE的值
rank_table_file_str = os.getenv("MINDSPORE_HCCL_CONFIG_PATH")
rank_table_file_str_old = os.getenv("RANK_TABLE_FILE")
# 如果两个环境变量都未设置,抛出运行时错误
if not rank_table_file_str and not rank_table_file_str_old:
raise RuntimeError("Get hccl rank_table_file failed, "
"please export MINDSPORE_HCCL_CONFIG_PATH or RANK_TABLE_FILE.")
def init(backend_name=None):
"""
Initialize distributed backend, e.g. HCCL/NCCL, it is required before using the communication service.
@ -109,15 +126,20 @@ def init(backend_name=None):
>>> from mindspore.communication import init
>>> init()
"""
# 检查当前角色是否为参数服务器或调度器,如果是则直接返回
if _is_role_pserver() or _is_role_sched():
return
# 检查任务接收环境变量,获取设备目标和模式
task_sink = _check_task_sink_envs()
device_target = context.get_context("device_target")
mode = context.get_context("mode")
mpi_init = False
# 如果没有任务接收且模式为图模式设置mpi_init为True
if not task_sink and mode == context.GRAPH_MODE:
mpi_init = True
# 根据设备目标选择后端名称,如果不支持则抛出异常
# 根据设备目标设置默认的后端名称
if backend_name is None:
if device_target == "Ascend":
backend_name = "hccl"
@ -126,28 +148,34 @@ def init(backend_name=None):
else:
raise RuntimeError("For 'set_context', the argument 'device_target' {} is not supported in "
"parallel initialization, please use Ascend or GPU.".format(device_target))
# 检查后端名称是否为字符串,如果不是则抛出异常
if not isinstance(backend_name, str):
raise TypeError("For 'init', the argument 'backend_name' must be a string, "
"but got the type : {}".format(type(backend_name)))
# 根据后端名称初始化通信环境
if backend_name == "hccl":
# 如果设备目标不是Ascend抛出异常
if device_target != "Ascend":
raise RuntimeError("For 'init', the argument 'backend_name' should be 'Ascend' to init hccl, "
"but got {}".format(device_target))
# 如果不需要MPI初始化检查并行环境并设置后端名称
if not mpi_init:
_check_parallel_envs()
GlobalComm.BACKEND = Backend("hccl")
else:
GlobalComm.BACKEND = Backend("hccl_mpi")
# 初始化HCCL并设置全局通信组和初始化状态
init_hccl()
GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
GlobalComm.INITED = True
elif backend_name == "nccl":
# 初始化GPU集体通信并设置后端名称、全局通信组和初始化状态
init_gpu_collective()
GlobalComm.BACKEND = Backend("nccl")
GlobalComm.WORLD_COMM_GROUP = NCCL_WORLD_COMM_GROUP
GlobalComm.INITED = True
else:
# 如果后端名称不支持,抛出异常
raise RuntimeError("For 'init', the argument 'backend_name' must be nccl while 'device_target' is GPU, "
"but got the 'backend_name' : hccl.")
@ -165,10 +193,11 @@ def release():
Examples:
>>> from mindspore.communication import init, release
>>> init()
>>> release()
>>> init() # 初始化分布式通信环境
>>> release() # 释放分布式通信资源
"""
finalize_hccl()
finalize_hccl()# 结束 HCCL 的使用,释放相关资源
def get_rank(group=GlobalComm.WORLD_COMM_GROUP):
@ -197,12 +226,18 @@ def get_rank(group=GlobalComm.WORLD_COMM_GROUP):
>>> print(rank_id)
>>> # the result is the rank_id in world_group
"""
# 检查传入的group参数是否为字符串类型
if not isinstance(group, str):
# 如果group参数不是字符串类型则抛出TypeError异常
raise TypeError("For 'get_rank', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group)))
# 调用_get_rank_helper函数获取指定group的rank ID
# _get_group函数用于解析group参数
# GlobalComm.BACKEND变量存储了当前使用的通信后端
return _get_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP):
"""
Gets local rank ID for current device in specified collective communication group.
@ -234,9 +269,11 @@ def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP):
>>> print("local_rank is: {}, world_rank is {}".format(local_rank, world_rank))
local_rank is: 1, world_rank is 9
"""
# 检查传入的 'group' 参数是否为字符串类型,如果不是则抛出 TypeError
if not isinstance(group, str):
raise TypeError("For 'get_local_rank', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group)))
# 调用辅助函数 _get_local_rank_helper 来获取本地排名,传入的参数为解析后的组和全局通信后端
return _get_local_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
@ -270,9 +307,11 @@ def get_group_size(group=GlobalComm.WORLD_COMM_GROUP):
>>> print("group_size is: ", group_size)
group_size is: 8
"""
# 检查传入的参数 'group' 是否为字符串类型
if not isinstance(group, str):
raise TypeError("For 'get_group_size', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group)))
# 返回指定组的大小,使用辅助函数 _get_size_helper 和全局通信后端
return _get_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
@ -306,9 +345,11 @@ def get_local_rank_size(group=GlobalComm.WORLD_COMM_GROUP):
>>> print("local_rank_size is: ", local_rank_size)
local_rank_size is: 8
"""
# 检查传入的 'group' 参数是否为字符串类型,如果不是则抛出 TypeError
if not isinstance(group, str):
raise TypeError("For 'get_local_rank_size', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group)))
# 调用辅助函数获取本地组的大小,使用 _get_group 函数获取组,并使用 GlobalComm.BACKEND 作为后端
return _get_local_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND)
@ -347,12 +388,12 @@ def get_world_rank_from_group_rank(group, group_rank_id):
>>> print("world_rank_id is: ", world_rank_id)
world_rank_id is: 4
"""
# 检查传入的 group 参数是否为字符串类型
if not isinstance(group, str):
raise TypeError("For 'get_world_rank_from_group_rank', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group)))
# 返回根据组和组排名获取的世界排名的帮助函数的结果
return _get_world_rank_from_group_rank_helper(group=group, group_rank_id=group_rank_id, backend=GlobalComm.BACKEND)
def get_group_rank_from_world_rank(world_rank_id, group):
"""
Get the rank ID in the specified user communication group corresponding to
@ -389,12 +430,13 @@ def get_group_rank_from_world_rank(world_rank_id, group):
>>> print("group_rank_id is: ", group_rank_id)
group_rank_id is: 1
"""
# 检查输入参数 'group' 是否为字符串类型,如果不是则抛出 TypeError
if not isinstance(group, str):
raise TypeError("For 'get_group_rank_from_world_rank', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group)))
# 调用辅助函数 _get_group_rank_from_world_rank_helper 来获取组的排名,并返回结果
return _get_group_rank_from_world_rank_helper(world_rank_id=world_rank_id, group=group, backend=GlobalComm.BACKEND)
#创建一个用户集体通信组通过传入通信组名称和设备ID列表来实现。
def create_group(group, rank_ids):
"""
Create a user collective communication group.
@ -427,9 +469,11 @@ def create_group(group, rank_ids):
>>> create_group(group, rank_ids)
>>> allreduce = ops.AllReduce(group)
"""
# 检查传入的 group 参数是否为字符串类型,如果不是则抛出 TypeError
if not isinstance(group, str):
raise TypeError("For 'create_group', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group)))
# 调用辅助函数 _create_group_helper 来创建组,使用指定的 rank_ids 和后端
_create_group_helper(group, rank_ids, backend=GlobalComm.BACKEND)
@ -450,7 +494,9 @@ def destroy_group(group):
ValueError: If group is "hccl_world_group" or backend is invalid.
RuntimeError: If HCCL is not available or MindSpore is GPU version.
"""
# 检查传入的 group 参数是否为字符串类型,如果不是则抛出 TypeError
if not isinstance(group, str):
raise TypeError("For 'destroy_group', the argument 'group' must be type of string, "
"but got 'group' type : {}.".format(type(group)))
_destroy_group_helper(group, backend=GlobalComm.BACKEND)
# 调用辅助函数 _destroy_group_helper 来销毁指定的组,并使用全局通信后端
_destroy_group_helper(group, backend=GlobalComm.BACKEND)

@ -45,14 +45,15 @@ class OneOf(OneOf_):
TypeError: raise type error for invalid inputs.
"""
self.patterns = patterns
# 检查 patterns 是否是 Pattern 类的实例
if isinstance(patterns, Pattern):
OneOf_.__init__(self, [patterns])
# 检查 patterns 是否是 tuple 或 list 类型,并且其中所有元素都是 Pattern 类的实例
elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns):
OneOf_.__init__(self, patterns)
# 如果 patterns 不符合上述两种情况,则抛出 TypeError 异常
else:
raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}")
class Prim(Prim_):
r"""
Express a pattern of certain primitive type(s).
@ -76,25 +77,33 @@ class Prim(Prim_):
Raises:
TypeError: raise type error for invalid argument.
"""
# 检查name是否为字符串类型如果不是则抛出TypeError
if name is not None and not isinstance(name, str):
raise TypeError(f"Expect string, got : {name}")
self.name = name
# 如果types是字符串类型则将其按'|'分割成列表
if isinstance(types, str):
if self.name is None:
self.name = types
self.types = types.split('|')
# 如果types是Primitive类型则直接将其放入列表中
elif isinstance(types, Primitive):
if self.name is None:
self.name = types.name
self.types = [types]
# 如果 types 是元组或列表,并且其中所有元素都是 Primitive 类型
elif isinstance(types, (tuple, list)) and all(isinstance(tp, Primitive) for tp in types):
# 如果 self.name 为 None则初始化为空字符串并拼接所有 Primitive 的 name
if self.name is None:
self.name = ""
for prim in types:
self.name += prim.name
# 设置 self.types 为传入的 types
self.types = types
# 如果 types 不符合预期类型,抛出 TypeError
else:
raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}")
# 调用基类 Prim_ 的初始化方法,传入 self.types 和 self.name
Prim_.__init__(self, self.types, self.name)
@ -115,16 +124,22 @@ class Call(Call_):
Raises:
TypeError: raise type error for invalid argument.
"""
# 检查 prim_pattern 是否为 Pattern, str 或 Primitive 类型,如果不是则抛出 TypeError
if not isinstance(prim_pattern, (Pattern, str, Primitive)):
raise TypeError(f"Expect prim_pattern to be Pattern, Primitive or string, got : {prim_pattern}")
self.prim_pattern = prim_pattern
# 初始化 inputs 列表
self.inputs = []
# 如果 inputs 为 None则不做任何操作
if inputs is None:
pass
# 如果 inputs 是 tuple 或 list 并且其中所有元素都是 Pattern 类型,则赋值给 self.inputs
elif isinstance(inputs, (tuple, list)) and all(isinstance(input, Pattern) for input in inputs):
self.inputs = inputs
# 如果 inputs 不符合上述条件,则抛出 TypeError
else:
raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}")
# 调用父类 Call_ 的初始化方法,传入 self.prim_pattern 和 self.inputs
Call_.__init__(self, self.prim_pattern, self.inputs)
@ -145,6 +160,7 @@ class NoneOf(NoneOf_):
TypeError: raise type error for invalid argument.
"""
self.patterns = patterns
# 根据 patterns 的类型初始化 NoneOf_ 类
if patterns is None:
NoneOf_.__init__(self, ())
elif isinstance(patterns, Pattern):
@ -154,7 +170,6 @@ class NoneOf(NoneOf_):
else:
raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}")
class NewTensor(NewTensor_):
r"""
New Tensor to be used in the target.
@ -167,13 +182,16 @@ class NewTensor(NewTensor_):
Raises:
TypeError: raise type error for invalid argument.
"""
# 初始化输入张量
self.input_tensor = input_tensor
# 检查输入是否为 Tensor 类型
if isinstance(input_tensor, Tensor):
# 如果是 Tensor 类型,则调用 NewTensor_ 的初始化方法
NewTensor_.__init__(self, input_tensor)
else:
# 如果不是 Tensor 类型,则抛出 TypeError 异常
raise TypeError(f"Expect input_tensor to be a Tensor got : {input_tensor}")
class NewParameter(NewParameter_):
r"""
New Parameter to be used in the target.
@ -193,11 +211,14 @@ class NewParameter(NewParameter_):
self.default_tensor = default_tensor
self.requires_grad = requires_grad
self.layerwise_parallel = layerwise_parallel
# 检查参数类型是否符合预期
if isinstance(para_name, str) and isinstance(default_tensor, Tensor) and isinstance(requires_grad, bool) and\
isinstance(layerwise_parallel, bool):
# 初始化 NewParameter_ 类
NewParameter_.__init__(self, self.para_name, self.default_tensor, self.requires_grad,
self.layerwise_parallel)
else:
# 如果参数类型不符合预期,抛出 TypeError
raise TypeError(f"Expect para_name(str), default_tensor(Tensor), requires_grad(bool), \
layerwise_parallel(bool) got : {para_name}, {default_tensor}, \
{requires_grad}, {layerwise_parallel}")
{requires_grad}, {layerwise_parallel}")

@ -39,6 +39,7 @@ class PyPassManager(PyPassManager_):
TypeError: If argument has invalid type.
"""
def __init__(self, requires_grad=True, run_only_once=False):
# 初始化方法,接收两个布尔参数,设置实例的属性并调用父类的初始化方法
if not isinstance(requires_grad, bool):
raise TypeError(f"Expect bool, got : ({type(requires_grad)}){requires_grad}")
if not isinstance(run_only_once, bool):
@ -48,17 +49,20 @@ class PyPassManager(PyPassManager_):
PyPassManager_.__init__(self)
def register(self, py_pass):
# 注册一个Python pass检查其是否为函数类型并获取其模式和目标
if not isfunction(py_pass):
raise TypeError(f"Expect function pass, got : ({type(py_pass)}){py_pass}")
pattern, target = py_pass()
pass_name = py_pass.__name__
# 检查模式和目标是否为Pattern类型
if not isinstance(pattern, Pattern):
raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}")
if not isinstance(target, Pattern):
raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}")
# 调用父类的register方法注册pass及其相关信息
super().register(pass_name, pattern, target, self.requires_grad, self.run_only_once_)
def unregister(self, py_pass):
# 从注册表中移除指定的Python传递对象可以是字符串形式的名称或函数对象
if isinstance(py_pass, str):
super().unregister(py_pass)
return
@ -66,27 +70,30 @@ class PyPassManager(PyPassManager_):
super().unregister(py_pass.__name__)
return
raise TypeError(f"Expect py_pass to be string or function, got ({type(py_pass)}){py_pass}")
def __call__(self, py_pass):
# 将Python传递对象注册到注册表中并返回该对象
self.register(py_pass)
return py_pass
def gen_new_parameter(self, pattern):
# 根据给定的模式生成新的参数模式必须是NewParameter类型
if not isinstance(pattern, NewParameter):
raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}")
super().gen_new_parameter(pattern)
def set_renorm(self, should_renorm):
# 设置是否进行重归一化操作,参数必须是布尔值
if not isinstance(should_renorm, bool):
raise TypeError(f"Expect should_renorm to be a bool, got {should_renorm}")
super().set_renorm(should_renorm)
def set_reopt(self, do_reopt):
# 设置是否进行重新优化操作,参数必须是布尔值
if not isinstance(do_reopt, bool):
raise TypeError(f"Expect do_reopt to be a bool, got {do_reopt}")
super().set_reopt(do_reopt)
def register_pass(requires_grad=True, run_only_once=False):
"""
Register python pass to specified pipeline phase which would be used in compilation.
@ -165,12 +172,13 @@ def cancel_new_parameter(pattern):
>>> # some compilations
>>> cancel_new_parameter(abc)
"""
# 检查传入的pattern是否为NewParameter的实例
if not isinstance(pattern, NewParameter):
raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}")
# 创建一个PyPassManager对象
ppm = PyPassManager()
# 从PyPassManager中注销指定名称的参数
ppm.unregister(pattern.para_name)
def set_renorm(should_renorm):
"""
Set whether or not to do renormalization after modified graph in python pass(es).

@ -19,22 +19,25 @@ import sys
def parse_args():
"""
parse args .
Args:
Returns:
args.
Examples:
>>> parse_args()
"""
# 创建一个ArgumentParser对象用于解析命令行参数描述信息为"MindSpore dependency packages version checker."
parser = ArgumentParser(description="MindSpore dependency packages version checker.")
# 添加一个命令行参数--mindspore_version类型为字符串帮助信息为"MindSpore version."
parser.add_argument("--mindspore_version", type=str, help="MindSpore version.")
# 添加一个命令行参数--supported_version类型为字符串可以多次指定帮助信息为"Supported environment version."
parser.add_argument("--supported_version", type=str, action='append', help="Supported environment version.")
# 解析命令行参数并返回结果
args = parser.parse_args()
return args
def check_deps_version(mindspore_version, supported_version):
"""
check te/hccl/topi version
@ -46,6 +49,7 @@ def check_deps_version(mindspore_version, supported_version):
Returns:
void
"""
# 尝试导入并检查 hccl、te 和 topi 包的版本是否与支持的版本匹配
try:
from hccl import sys_version as hccl_version
v = '.'.join(hccl_version.__sys_version__.split('.')[0:2])
@ -63,18 +67,20 @@ def check_deps_version(mindspore_version, supported_version):
print(f"MindSpore version {mindspore_version} and \"topi\" wheel package version {v} does not "
"match, reference to the match info on: https://www.mindspore.cn/install")
# 捕获导入错误并打印相应的检查失败信息
except ImportError as e:
print("CheckFailed: ", e.args)
print("MindSpore relies on the 3 whl packages of \"te\", \"topi\" and \"hccl\" in the \"fwkacllib\" "
"folder of the Ascend AI software package (Ascend Data Center Solution), please check whether they are "
"installed correctly or not, reference to the match info on: https://www.mindspore.cn/install")
def main():
# 解析命令行参数
args = parse_args()
# 检查 mindspore 的版本是否在支持的版本范围内
check_deps_version(args.mindspore_version, args.supported_version)
if __name__ == "__main__":
sys.path = sys.path[1:] # avoid the impact of relative path env, only affect this process
# 避免相对路径环境的影响,仅影响当前进程
sys.path = sys.path[1:]
main()

@ -31,47 +31,58 @@ class EnvChecker(metaclass=ABCMeta):
@abstractmethod
def check_env(self, e):
pass
"""检查环境是否符合要求"""
@abstractmethod
def set_env(self):
pass
"""设置环境"""
@abstractmethod
def check_version(self):
pass
"""检查版本是否符合要求"""
class GPUEnvChecker(EnvChecker):
"""GPU environment check."""
def __init__(self):
# 初始化版本列表
self.version = ["10.1", "11.1"]
# 初始化库键到库名的映射字典
self.lib_key_to_lib_name = {'libcu': 'libcuda.so'}
# env
# 获取系统环境变量 PATH 的值
self.path = os.getenv("PATH")
# 获取系统环境变量 LD_LIBRARY_PATH 的值
self.ld_lib_path = os.getenv("LD_LIBRARY_PATH")
# check
# 初始化版本号为 "0"
self.v = "0"
# 获取 CUDA 库的路径
self.cuda_lib_path = self._get_lib_path("libcu")
# 获取 CUDA 可执行文件的路径
self.cuda_bin_path = self._get_bin_path("cuda")
# 获取 cuDNN 库的路径
self.cudnn_lib_path = self._get_lib_path("libcudnn")
def check_env(self, e):
# 抛出传入的异常 e
raise e
def set_env(self):
# 设置环境变量,当前实现为空
return
def _get_bin_path(self, bin_name):
"""Get bin path by bin name."""
# 如果二进制名称为 "cuda",则调用获取 CUDA 二进制路径的方法
if bin_name == "cuda":
return self._get_cuda_bin_path()
# 否则返回空列表
return []
def _get_cuda_bin_path(self):
"""Get cuda bin path by lib path."""
# Get cuda bin path by lib path.
path_list = []
for path in self.cuda_lib_path:
path = os.path.abspath(path.strip()+"/bin/")
@ -81,56 +92,87 @@ class GPUEnvChecker(EnvChecker):
def _get_nvcc_version(self, is_set_env):
"""Get cuda version by nvcc command."""
# 运行 nvcc 命令获取 CUDA 版本信息
nvcc_result = subprocess.run(["nvcc", "--version | grep release"],
timeout=3, text=True, capture_output=True, check=False)
# 如果命令返回非零值,表示命令执行失败
if nvcc_result.returncode:
# 如果尚未设置环境变量
if not is_set_env:
# 遍历预设的 CUDA 二进制路径
for path in self.cuda_bin_path:
# 检查路径中是否存在 nvcc 文件
if Path(path + "/nvcc").is_file():
# 将路径添加到环境变量 PATH 中
os.environ['PATH'] = path + ":" + os.environ['PATH']
# 递归调用以重新尝试获取版本信息
return self._get_nvcc_version(True)
# 如果命令执行失败且未找到 nvcc 文件,返回空字符串
return ""
# 获取命令输出结果
result = nvcc_result.stdout
# 遍历输出结果的每一行
for line in result.split('\n'):
if line:
# 提取并返回 CUDA 版本号
return line.strip().split("release")[1].split(",")[0].strip()
# 如果未找到版本信息,返回空字符串
return ""
def _get_cudnn_version(self):
"""Get cudnn version by libcudnn.so."""
# 初始化cudnn版本列表为空
cudnn_version = []
# 遍历cudnn库路径
for path in self.cudnn_lib_path:
# 查找路径下所有的libcudnn.so文件
real_path = glob.glob(path + "/lib*/libcudnn.so.*.*")
# 如果没有找到对应的文件,继续下一个路径
if real_path == []:
continue
# 使用ls命令获取文件信息
ls_cudnn = subprocess.run(["ls", real_path[0]], timeout=10, text=True,
capture_output=True, check=False)
# 如果ls命令执行成功解析输出以获取版本号
if ls_cudnn.returncode == 0:
cudnn_version = ls_cudnn.stdout.split('/')[-1].strip('libcudnn.so.').strip().split('.')
# 如果版本号只有两个部分,添加一个'.0'作为第三部分
if len(cudnn_version) == 2:
cudnn_version.append('0')
# 找到版本号后跳出循环
break
# 将版本号列表转换为字符串
version_str = ''.join([n for n in cudnn_version])
# 返回版本号的前三位
return version_str[0:3]
def _get_cudart_version(self):
"""Get cuda runtime version by libcudart.so."""
# 遍历可能的 CUDA 库路径
for path in self.cuda_lib_path:
# 查找路径下所有可能的 libcudart.so 文件
real_path = glob.glob(path + "/lib*/libcudart.so.*.*.*")
# 如果没有找到任何文件,则跳过当前路径
if real_path == []:
continue
# 获取文件名信息以确定 CUDA 版本
ls_cudart = subprocess.run(["ls", real_path[0]], timeout=10, text=True,
capture_output=True, check=False)
# 如果命令成功执行,则解析输出以提取版本号
if ls_cudart.returncode == 0:
self.v = ls_cudart.stdout.split('/')[-1].strip('libcudart.so.').strip()
# 找到版本号后跳出循环
break
# 返回找到的 CUDA 版本号
return self.v
def check_version(self):
"""Check cuda version."""
version_match = False
# 调用私有方法检查版本是否匹配并根据结果设置version_match标志
if self._check_version():
version_match = True
# 如果版本不匹配根据CUDA版本号输出不同的警告信息
if not version_match:
if self.v == "0":
logger.warning("Can not found cuda libs, please confirm that the correct "
@ -140,17 +182,20 @@ class GPUEnvChecker(EnvChecker):
logger.warning(f"MindSpore version {__version__} and cuda version {self.v} does not match, "
"please refer to the installation guide for version matching "
"information: https://www.mindspore.cn/install")
# 获取nvcc版本号并检查是否与MindSpore支持的版本匹配
nvcc_version = self._get_nvcc_version(False)
if nvcc_version and (nvcc_version not in self.version):
logger.warning(f"MindSpore version {__version__} and nvcc(cuda bin) version {nvcc_version} "
"does not match, please refer to the installation guide for version matching "
"information: https://www.mindspore.cn/install")
# 获取cudnn版本号并检查是否符合最低要求
cudnn_version = self._get_cudnn_version()
if cudnn_version and int(cudnn_version) < 760:
logger.warning(f"MindSpore version {__version__} and cudDNN version {cudnn_version} "
"does not match, please refer to the installation guide for version matching "
"information: https://www.mindspore.cn/install. The recommended version is "
"CUDA10.1 with cuDNN7.6.x and CUDA11.1 with cuDNN8.0.x")
# 检查cudnn版本号与CUDA版本号的兼容性对于CUDA 11.0以上版本cudnn版本需要至少为8.0
if cudnn_version and int(cudnn_version) < 800 and int(str(self.v).split('.')[0]) > 10:
logger.warning(f"CUDA version {self.v} and cuDNN version {cudnn_version} "
"does not match, please refer to the installation guide for version matching "
@ -159,45 +204,58 @@ class GPUEnvChecker(EnvChecker):
def _check_version(self):
"""Check cuda version"""
# 获取 CUDA 运行时版本
v = self._get_cudart_version()
# 解析版本字符串为版本对象
v = version.parse(v)
# 构造版本号字符串,格式为 "主版本.次版本"
v_str = str(v.major) + "." + str(v.minor)
# 检查构造的版本号字符串是否在预定义的版本列表中
if v_str not in self.version:
return False
# 版本号匹配,返回 True
return True
def _get_lib_path(self, lib_name):
"""Get gpu lib path by ldd command."""
path_list = []
current_path = os.path.split(os.path.realpath(__file__))[0]
mindspore_path = os.path.join(current_path, "../")
"""通过ldd命令获取gpu库路径。"""
path_list = [] # 初始化一个空列表用于存储路径
current_path = os.path.split(os.path.realpath(__file__))[0] # 获取当前文件的绝对路径并分割以获取目录部分
mindspore_path = os.path.join(current_path, "../") # 构建mindspore路径通常是当前文件的上一级目录
try:
# 使用glob模块查找mindspore_path目录下所有以_c_expression.so开头的文件路径
real_path = glob.glob(mindspore_path + "/_c_expression*.so*")
if real_path == []:
logger.error(f"{self.lib_key_to_lib_name[lib_name]} (need by mindspore-gpu) is not found, please "
f"confirm that _c_expression.so is in directory:{mindspore_path} and the correct cuda "
"version has been installed, you can refer to the installation "
"guidelines: https://www.mindspore.cn/install")
return path_list
if real_path == []: # 如果没有找到任何文件
# 记录错误日志提示用户确认_c_expression.so文件是否存在以及是否安装了正确的cuda版本
logger.error(f"{self.lib_key_to_lib_name[lib_name]} (mindspore-gpu所需的库) 未找到,请确认 "
f"_c_expression.so是否位于目录:{mindspore_path}并且已安装正确的cuda版本"
"您可以参考安装指南https://www.mindspore.cn/install")
return path_list # 返回空路径列表
# 使用subprocess.Popen执行ldd命令以获取依赖库的信息
ldd_r = subprocess.Popen(['ldd', real_path[0]], stdout=subprocess.PIPE)
# 使用subprocess.Popen的stdin参数从ldd_r.stdout接收输出并执行grep命令以过滤出包含指定库名的信息
ldd_result = subprocess.Popen(['grep', lib_name], stdin=ldd_r.stdout, stdout=subprocess.PIPE)
# 获取grep命令的输出结果并解码为字符串
result = ldd_result.communicate()[0].decode()
for i in result.split('\n'):
for i in result.split('\n'): # 按行分割结果字符串
# 使用partition方法从每一行中提取出库文件的路径
path = i.partition("=>")[2]
if path.lower().find("not found") > 0:
logger.warning(f"Cuda {self.version} version(need by mindspore-gpu) is not found, please confirm "
"that the path of cuda is set to the env LD_LIBRARY_PATH, please refer to the "
"installation guidelines: https://www.mindspore.cn/install")
continue
if path.lower().find("not found") > 0: # 如果路径中包含"not found"
# 记录警告日志提示用户确认cuda路径是否已添加到环境变量LD_LIBRARY_PATH中
logger.warning(f"Cuda {self.version}版本(由mindspore-gpu要求的) 未找到请确认cuda路径已设置到环境变量LD_LIBRARY_PATH中"
"您可以参考安装指南https://www.mindspore.cn/install")
continue # 继续下一次循环
# 从路径中去除库名部分
path = path.partition(lib_name)[0]
if path:
if path: # 如果路径非空
# 将路径的绝对路径并去除末尾斜杠后添加到path_list中
path_list.append(os.path.abspath(path.strip() + "../"))
# 返回path_list中唯一的路径
return np.unique(path_list)
except subprocess.TimeoutExpired:
logger.warning("Failed to check cuda version due to the ldd command timeout, please confirm that "
"the correct cuda version has been installed, you can refer to the "
"installation guidelines: https://www.mindspore.cn/install")
return path_list
except subprocess.TimeoutExpired: # 捕获subprocess.TimeoutExpired异常
# 记录警告日志提示用户确认cuda版本是否正确安装因为ldd命令超时
logger.warning("由于ldd命令超时无法检查cuda版本请确认已安装正确的cuda版本"
"您可以参考安装指南:https://www.mindspore.cn/install")
return path_list # 返回空路径列表
def _read_version(self, file_path):
"""Get gpu version info in version.txt."""
@ -211,70 +269,80 @@ class GPUEnvChecker(EnvChecker):
class AscendEnvChecker(EnvChecker):
"""ascend environment check"""
"""Ascend 环境检查类"""
def __init__(self):
# 初始化 Ascend 环境检查器的版本列表
self.version = ["1.81"]
# 定义不同路径下的 version.info 文件位置
atlas_nnae_version = "/usr/local/Ascend/nnae/latest/fwkacllib/version.info"
atlas_toolkit_version = "/usr/local/Ascend/ascend-toolkit/latest/fwkacllib/version.info"
hisi_fwk_version = "/usr/local/Ascend/latest/fwkacllib/version.info"
# 检查 Atlas NNAE 环境是否存在
if os.path.exists(atlas_nnae_version):
# atlas default path
self.fwk_path = "/usr/local/Ascend/nnae/latest/fwkacllib"
self.op_impl_path = "/usr/local/Ascend/nnae/latest/opp/op_impl/built-in/ai_core/tbe"
self.tbe_path = self.fwk_path + "/lib64"
self.cce_path = self.fwk_path + "/ccec_compiler/bin"
self.fwk_version = atlas_nnae_version
self.op_path = "/usr/local/Ascend/nnae/latest/opp"
self.aicpu_path = "/usr/local/Ascend/nnae/latest"
# 如果存在,设置默认路径
self.fwk_path = "/usr/local/Ascend/nnae/latest/fwkacllib" # Framework 路径
self.op_impl_path = "/usr/local/Ascend/nnae/latest/opp/op_impl/built-in/ai_core/tbe" # Operator 实现路径
self.tbe_path = self.fwk_path + "/lib64" # TBE 库路径
self.cce_path = self.fwk_path + "/ccec_compiler/bin" # CCE 编译器路径
self.fwk_version = atlas_nnae_version # Framework 版本文件路径
self.op_path = "/usr/local/Ascend/nnae/latest/opp" # Operator 路径
self.aicpu_path = "/usr/local/Ascend/nnae/latest" # AI CPU 路径
# 检查 Atlas Toolkit 环境是否存在
elif os.path.exists(atlas_toolkit_version):
# atlas default path
self.fwk_path = "/usr/local/Ascend/ascend-toolkit/latest/fwkacllib"
self.op_impl_path = "/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe"
self.tbe_path = self.fwk_path + "/lib64"
self.cce_path = self.fwk_path + "/ccec_compiler/bin"
self.fwk_version = atlas_toolkit_version
self.op_path = "/usr/local/Ascend/ascend-toolkit/latest/opp"
self.aicpu_path = "/usr/local/Ascend/ascend-toolkit/latest"
# 如果存在,设置默认路径
self.fwk_path = "/usr/local/Ascend/ascend-toolkit/latest/fwkacllib" # Framework 路径
self.op_impl_path = "/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe" # Operator 实现路径
self.tbe_path = self.fwk_path + "/lib64" # TBE 库路径
self.cce_path = self.fwk_path + "/ccec_compiler/bin" # CCE 编译器路径
self.fwk_version = atlas_toolkit_version # Framework 版本文件路径
self.op_path = "/usr/local/Ascend/ascend-toolkit/latest/opp" # Operator 路径
self.aicpu_path = "/usr/local/Ascend/ascend-toolkit/latest" # AI CPU 路径
# 检查 Hisi 环境是否存在
elif os.path.exists(hisi_fwk_version):
# hisi default path
self.fwk_path = "/usr/local/Ascend/latest/fwkacllib"
self.op_impl_path = "/usr/local/Ascend/latest/opp/op_impl/built-in/ai_core/tbe"
self.tbe_path = self.fwk_path + "/lib64"
self.cce_path = self.fwk_path + "/ccec_compiler/bin"
self.fwk_version = hisi_fwk_version
self.op_path = "/usr/local/Ascend/latest/opp"
self.aicpu_path = "/usr/local/Ascend/latest"
# 如果存在,设置默认路径
self.fwk_path = "/usr/local/Ascend/latest/fwkacllib" # Framework 路径
self.op_impl_path = "/usr/local/Ascend/latest/opp/op_impl/built-in/ai_core/tbe" # Operator 实现路径
self.tbe_path = self.fwk_path + "/lib64" # TBE 库路径
self.cce_path = self.fwk_path + "/ccec_compiler/bin" # CCE 编译器路径
self.fwk_version = hisi_fwk_version # Framework 版本文件路径
self.op_path = "/usr/local/Ascend/latest/opp" # Operator 路径
self.aicpu_path = "/usr/local/Ascend/latest" # AI CPU 路径
else:
# custom or unknown environment
self.fwk_path = ""
self.op_impl_path = ""
self.tbe_path = ""
self.cce_path = ""
self.fwk_version = ""
self.op_path = ""
self.aicpu_path = ""
# env
# 如果以上环境都不存在,设置为空路径
self.fwk_path = "" # Framework 路径
self.op_impl_path = "" # Operator 实现路径
self.tbe_path = "" # TBE 库路径
self.cce_path = "" # CCE 编译器路径
self.fwk_version = "" # Framework 版本文件路径
self.op_path = "" # Operator 路径
self.aicpu_path = "" # AI CPU 路径
# 初始化环境变量
self.path = os.getenv("PATH")
self.python_path = os.getenv("PYTHONPATH")
self.ld_lib_path = os.getenv("LD_LIBRARY_PATH")
self.ascend_opp_path = os.getenv("ASCEND_OPP_PATH")
self.ascend_aicpu_path = os.getenv("ASCEND_AICPU_PATH")
# check content
# 设置需要检查的路径内容
self.path_check = "/fwkacllib/ccec_compiler/bin"
self.python_path_check = "opp/op_impl/built-in/ai_core/tbe"
self.ld_lib_path_check_fwk = "/fwkacllib/lib64"
self.ld_lib_path_check_addons = "/add-ons"
self.ascend_opp_path_check = "/op"
self.v = ""
def check_env(self, e):
self._check_env()
raise e
def check_version(self):
# 检查指定路径的版本文件是否存在,如果不存在则跳过版本检查
if not Path(self.fwk_version).is_file():
logger.warning("Using custom Ascend AI software package (Ascend Data Center Solution) path, package "
"version checking is skipped, please make sure Ascend AI software package (Ascend Data "
@ -282,40 +350,47 @@ class AscendEnvChecker(EnvChecker):
"https://www.mindspore.cn/install")
return
# 读取版本文件中的版本信息
v = self._read_version(self.fwk_version)
# 如果读取的版本不在支持的版本列表中,则记录警告信息
if v not in self.version:
v_list = str([x for x in self.version])
logger.warning(f"MindSpore version {__version__} and Ascend AI software package (Ascend Data Center "
f"Solution)version {v} does not match, the version of software package expect one of "
f"{v_list}, please reference to the match info on: https://www.mindspore.cn/install")
def check_deps_version(self):
"""
te, topi, hccl wheel package version check
in order to update the change of 'LD_LIBRARY_PATH' env, run a sub process
"""
# 构建输入参数列表包含mindspore版本和受支持的版本列表
input_args = ["--mindspore_version=" + __version__]
for v in self.version:
input_args.append("--supported_version=" + v)
# 获取依赖版本检查脚本的路径
deps_version_checker = os.path.join(os.path.split(os.path.realpath(__file__))[0],
"_check_deps_version.py")
# 构建调用命令包括python解释器路径、脚本路径和输入参数
call_cmd = [sys.executable, deps_version_checker] + input_args
try:
# 运行子进程进行版本检查设置超时时间为3秒并捕获输出
process = subprocess.run(call_cmd, timeout=3, text=True, capture_output=True, check=False)
# 如果子进程的输出不为空,则记录警告信息并进行倒计时提醒
if process.stdout.strip() != "":
logger.warning(process.stdout.strip())
warning_countdown = 3
for i in range(warning_countdown, 0, -1):
logger.warning(f"Please pay attention to the above warning, countdown: {i}")
time.sleep(1)
# 如果版本检查超时,则记录信息并跳过
except subprocess.TimeoutExpired:
logger.info("Package te, topi, hccl version check timed out, skip.")
def set_env(self):
# 设置Ascend环境变量
if not self.tbe_path:
self._check_env()
return
try:
import te # pylint: disable=unused-import
# pylint: disable=broad-except
@ -329,32 +404,35 @@ class AscendEnvChecker(EnvChecker):
raise EnvironmentError(
f"No such directory: {self.tbe_path}, Please check if Ascend AI software package (Ascend Data "
"Center Solution) is installed correctly.")
# check te version after set te env
# 检查te版本
self.check_deps_version()
# 设置op实现路径环境变量
if Path(self.op_impl_path).is_dir():
# python path for sub process
# python路径用于子进程
if os.getenv('PYTHONPATH'):
os.environ['PYTHONPATH'] = self.op_impl_path + ":" + os.environ['PYTHONPATH']
else:
os.environ['PYTHONPATH'] = self.op_impl_path
# sys path for this process
# sys路径用于当前进程
sys.path.append(self.op_impl_path)
os.environ['TBE_IMPL_PATH'] = self.op_impl_path
else:
raise EnvironmentError(
f"No such directory: {self.op_impl_path}, Please check if Ascend AI software package (Ascend Data "
"Center Solution) is installed correctly.")
f"No such directory: {self.op_impl_path}, Please check if Ascend AI software package (Ascend Data Center "
"Solution) is installed correctly.")
# 设置CCE路径环境变量
if Path(self.cce_path).is_dir():
os.environ['PATH'] = self.cce_path + ":" + os.environ['PATH']
else:
raise EnvironmentError(
f"No such directory: {self.cce_path}, Please check if Ascend AI software package (Ascend Data Center "
"Solution) is installed correctly.")
# 设置OP路径环境变量
if self.op_path is None:
pass
elif Path(self.op_path).is_dir():
@ -363,7 +441,8 @@ class AscendEnvChecker(EnvChecker):
raise EnvironmentError(
f"No such directory: {self.op_path}, Please check if Ascend AI software package (Ascend Data Center "
"Solution) is installed correctly.")
# 设置AICPU路径环境变量
if self.aicpu_path is None:
pass
elif Path(self.aicpu_path).is_dir():
@ -372,44 +451,54 @@ class AscendEnvChecker(EnvChecker):
raise EnvironmentError(
f"No such directory: {self.aicpu_path}, Please check if Ascend AI software package (Ascend Data Center"
" Solution) is installed correctly.")
def _check_env(self):
"""ascend dependence path check"""
# 检查是否设置正确的PATH环境变量
if self.path is None or self.path_check not in self.path:
logger.warning("Can not find ccec_compiler(need by mindspore-ascend), please check if you have set env "
"PATH, you can reference to the installation guidelines https://www.mindspore.cn/install")
# 检查是否设置正确的PYTHONPATH环境变量
if self.python_path is None or self.python_path_check not in self.python_path:
logger.warning(
"Can not find tbe op implement(need by mindspore-ascend), please check if you have set env "
"PYTHONPATH, you can reference to the installation guidelines "
"https://www.mindspore.cn/install")
# 检查是否设置正确的LD_LIBRARY_PATH环境变量
if self.ld_lib_path is None or not (self.ld_lib_path_check_fwk in self.ld_lib_path and
self.ld_lib_path_check_addons in self.ld_lib_path):
logger.warning("Can not find driver so(need by mindspore-ascend), please check if you have set env "
"LD_LIBRARY_PATH, you can reference to the installation guidelines "
"https://www.mindspore.cn/install")
# 检查是否设置正确的ASCEND_OPP_PATH环境变量
if self.ascend_opp_path is None or self.ascend_opp_path_check not in self.ascend_opp_path:
logger.warning(
"Can not find opp path (need by mindspore-ascend), please check if you have set env ASCEND_OPP_PATH, "
"you can reference to the installation guidelines https://www.mindspore.cn/install")
def _read_version(self, file_path):
"""get ascend version info"""
with open(file_path, 'r') as f:
all_info = f.readlines()
# 遍历文件中的每一行
for line in all_info:
# 检查行是否以 "Version=" 开头
if line.startswith("Version="):
# 去除行末的换行符并按 "=" 分割, 获取版本号
full_version = line.strip().split("=")[1]
# 提取主版本号和次版本号, 并用 "." 连接
self.v = '.'.join(full_version.split('.')[0:2])
# 返回版本号
return self.v
# 如果未找到版本信息, 返回 None 或默认值
return self.v
def check_version_and_env_config():
"""check version and env config"""
"""检查版本和环境配置"""
# 检查包名以确定使用哪种环境检查器
if __package_name__.lower() == "mindspore-ascend":
env_checker = AscendEnvChecker()
# Note: pre-load libgomp.so to solve error like "cannot allocate memory in statis TLS block"
@ -425,19 +514,21 @@ def check_version_and_env_config():
else:
logger.info(f"Package version {__package_name__} does not need to check any environment variable, skipping.")
return
# 检查是否关闭版本检查,如果已关闭则直接返回
if os.getenv("MS_DEV_CLOSE_VERSION_CHECK") == "ON":
return
# 设置环境变量以关闭版本检查
os.environ["MS_DEV_CLOSE_VERSION_CHECK"] = "ON"
try:
# check version of ascend site or cuda
# 检查 ascend site 或 cuda 的版本
env_checker.check_version()
from .. import _c_expression # pylint: disable=unused-import
# 设置环境
env_checker.set_env()
except ImportError as e:
# 处理导入错误,检查环境
env_checker.check_env(e)
def _set_pb_env():
"""Set env variable `PROTOCOL_BUFFERS` to prevent memory overflow."""
if os.getenv("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION") == "cpp":
@ -449,7 +540,9 @@ def _set_pb_env():
logger.info("Setting the env `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python` to prevent memory overflow "
"during save or load checkpoint file.")
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
# 检查版本和环境配置
check_version_and_env_config()
# 设置协议缓冲区的环境变量, 防止内存溢出
_set_pb_env()

@ -26,23 +26,26 @@ def _check_mul():
"""
from importlib import import_module
import numpy as np
try:
ms = import_module("mindspore")
except ModuleNotFoundError:
ms = None
finally:
pass
# 打印MindSpore版本信息
print(f"MindSpore version: ", ms.__version__)
# 创建两个MindSpore张量分别包含数组[1.0, 2.0, 3.0]和[4.0, 5.0, 6.0]
input_x = ms.Tensor(np.array([1.0, 2.0, 3.0]), ms.float32)
input_y = ms.Tensor(np.array([4.0, 5.0, 6.0]), ms.float32)
# 创建一个乘法操作对象
mul = ms.ops.Mul()
# 执行乘法操作
mul(input_x, input_y)
# 打印乘法计算结果正确MindSpore安装成功的信息
print(f"The result of multiplication calculation is correct, MindSpore has been installed successfully!")
def run_check():
"""
Provide a convenient API to check if the installation is successful or failed.
@ -55,10 +58,13 @@ def run_check():
The result of multiplication calculation is correct, MindSpore has been installed successfully!
"""
try:
# 尝试执行检查乘法操作的函数
_check_mul()
# pylint: disable=broad-except
# 捕获所有异常并打印错误信息
except Exception as e:
print("MindSpore running check failed.")
print(str(e))
finally:
pass
# 无论是否发生异常,都会执行此部分代码
pass # 执行乘法检查的函数,并处理可能的异常情况。如果检查失败,打印错误信息。

@ -1,3 +1,6 @@
这段代码是一个Python类的实现名为`MindData`它是一个用于模拟MindSpore框架中数据集处理的桩Stub下面是对这段代码的逐行注释
```python
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -12,77 +15,94 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''Remove after MindData merge to MindSpore '''
import numpy as np
from mindspore import Tensor
class MindData:
""" Stub for MindData """
# 构造函数初始化MindData类的实例
def __init__(self, size=1, batch_size=None, repeat_count=1,
np_types=None, output_shapes=None, input_indexs=()):
self._size = size
self._batch_size = batch_size
self._repeat_count = repeat_count
self._np_types = np_types
self._output_shapes = output_shapes
self._input_indexs = input_indexs
self._iter_num = 0
self.dynamic_setting = [False, None]
self._size = size # 数据集的大小
self._batch_size = batch_size # 批处理大小
self._repeat_count = repeat_count # 重复次数
self._np_types = np_types # NumPy数据类型
self._output_shapes = output_shapes # 输出形状
self._input_indexs = input_indexs # 输入索引
self._iter_num = 0 # 迭代次数计数器
self.dynamic_setting = [False, None] # 动态设置标志和值
# 获取数据集大小
def get_dataset_size(self):
return self._size
# 获取重复次数
def get_repeat_count(self):
return self._repeat_count
# 获取批处理大小
def get_batch_size(self):
return self._batch_size
# 获取输出数据类型
def output_types(self):
return self._np_types
# 获取输出形状
def output_shapes(self):
return self._output_shapes
# 输入索引属性
@property
def input_indexs(self):
return self._input_indexs
# 设备队列设置
def device_que(self, send_epoch_end=True, create_data_info_queue=False):
self.queue_name = '6ba41974-209e-11ea-88b0-a24efeb2c736'
self.send_epoch_end = send_epoch_end
return self
# 创建元组迭代器
def create_tuple_iterator(self, num_epochs=-1, do_copy=True):
return self.__iter__()
# 发送数据
def send(self, num_epochs=-1):
pass
# 停止发送数据
def stop_send(self):
pass
# 释放资源
def release(self):
pass
# 继续发送数据
def continue_send(self):
pass
# 获取数据信息
def get_data_info(self):
pass
# 动态最小最大形状
def dynamic_min_max_shapes(self):
pass
# 获取长度
def __len__(self):
return self._size
# 迭代器
def __iter__(self):
return self
# 获取下一个元素
def __next__(self):
if self._size < self._iter_num:
raise StopIteration
@ -90,11 +110,13 @@ class MindData:
next_value = []
for shape, typ in zip(self._output_shapes, self._np_types):
next_value.append(Tensor(np.ndarray(shape, typ)))
return tuple(next_value)
# 下一个元素
def next(self):
return self.__next__()
# 重置迭代器
def reset(self):
self._iter_num = 0

Loading…
Cancel
Save