diff --git a/qin.txt b/qin.txt new file mode 100644 index 00000000..b87bfcea --- /dev/null +++ b/qin.txt @@ -0,0 +1,38 @@ +import os +import stat +import fcntl +import platform +from logging.handlers import RotatingFileHandler + +class _MultiCompatibleRotatingFileHandler(RotatingFileHandler): + """Inherit RotatingFileHandler for multiprocess compatibility. + + 这个类继承自`RotatingFileHandler`,是为了在多进程环境下安全地使用日志回滚功能。 + 在多进程环境下,多个进程可能会同时尝试写入或回滚日志文件,这可能会导致文件损坏或数据丢失。 + 通过在这个类中对相关方法进行重写,确保了日志文件在多进程环境下的正确处理。 + """ + + def doRollover(self): + """Override doRollover for multiprocess compatibility + and setting permission of Log file + + 这个方法重写了`RotatingFileHandler`中的`doRollover`方法,增加了多进程兼容性, + 并设置了日志文件的权限。 + + 1. 使用`fcntl`模块获得独占锁,确保在回滚日志文件时不会有其他进程进行写操作。 + 2. 设置日志文件的权限,以确保日志文件的安全性。 + 3. 调用父类的`doRollover`方法执行实际的日志回滚操作。 + 4. 回滚后,修改日志文件的权限,使其可读可写。 + """ + + # Attain an exclusive lock with blocking mode by `fcntl` module. + with open(self.baseFilename, 'a') as file_pointer: + # 如果操作系统不是Windows,使用`fcntl`模块对文件加锁 + if platform.system() != "Windows": + fcntl.lockf(file_pointer.fileno(), fcntl.LOCK_EX) + # 设置日志文件权限为只读,增加安全性 + os.chmod(self.baseFilename, stat.S_IREAD) + # 调用父类的`doRollover`方法执行日志回滚操作 + super().doRollover() + # 修改日志文件的权限为可读可写,以便后续的日志写入操作 + os.chmod(self.baseFilename, stat.S_IREAD | stat.S_IWRITE) diff --git a/src/mindspore2022/mindspore/python/mindspore/communication/_comm_helper.py b/src/mindspore2022/mindspore/python/mindspore/communication/_comm_helper.py index 05b3c51c..9af56409 100644 --- a/src/mindspore2022/mindspore/python/mindspore/communication/_comm_helper.py +++ b/src/mindspore2022/mindspore/python/mindspore/communication/_comm_helper.py @@ -18,7 +18,7 @@ from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched from mindspore import log as logger from ._hccl_management import load_lib as hccl_load_lib from .._c_expression import get_rank_id, get_rank_size - + _HCCL_AVAILABLE = False _HCCL_TEST_AVAILABLE = False _NCCL_AVAILABLE = False @@ -28,14 +28,15 @@ try: _NCCL_AVAILABLE = True except ImportError: _NCCL_AVAILABLE = False - - + +# 尝试加载 HCCL 库,如果成功则设置 _HCCL_AVAILABLE 为 True,否则捕获 RuntimeError 并设置为 False try: hccl_load_lib() _HCCL_AVAILABLE = True except RuntimeError: _HCCL_AVAILABLE = False - + +# 如果 HCCL 可用,则导入 _hccl_management 并尝试导入 mindspore._ascend_mpi,如果成功则设置 _MPI_AVAILABLE 为 True,否则捕获 ImportError 并设置为 False if _HCCL_AVAILABLE: from . import _hccl_management as hccl try: @@ -43,6 +44,7 @@ if _HCCL_AVAILABLE: _MPI_AVAILABLE = True except ImportError: _MPI_AVAILABLE = False +# 如果 HCCL 不可用,则尝试导入 hccl_test.manage.api,如果成功则设置 _HCCL_AVAILABLE 和 _HCCL_TEST_AVAILABLE 为 True,否则捕获 ImportError 并设置 _HCCL_AVAILABLE 为 False else: try: import hccl_test.manage.api as hccl @@ -50,12 +52,11 @@ else: _HCCL_TEST_AVAILABLE = True except ImportError: _HCCL_AVAILABLE = False - - + +# 定义 HCCL 和 NCCL 的通信组名称常量 HCCL_WORLD_COMM_GROUP = "hccl_world_group" NCCL_WORLD_COMM_GROUP = "nccl_world_group" - class Backend: """ Class for available backends. @@ -82,13 +83,17 @@ class Backend: def __new__(cls, name): """Create instance object of Backend.""" + # 检查传入的name是否为字符串类型 if not isinstance(name, str): raise TypeError("For 'Backend', the class variable 'name' must be a string, " "but got the type : {}".format(type(name))) + # 获取对应name的大写形式的Backend类属性值,如果不存在则返回Backend.UNDEFINED value = getattr(Backend, name.upper(), Backend.UNDEFINED) + # 如果获取到的值是Backend.UNDEFINED,说明传入的name不被支持 if value == Backend.UNDEFINED: raise ValueError("For 'Backend', the class variable 'name' {} is not supported, " "please use hccl or nccl.".format(name)) + # 返回获取到的Backend类属性值 return value DEFAULT_BACKEND = Backend("hccl") @@ -97,7 +102,7 @@ DEFAULT_BACKEND = Backend("hccl") class GlobalComm: """ World communication information. The GlobalComm is a global class. The members contain: - + - BACKEND: The communication library used, using HCCL/NCCL. - WORLD_COMM_GROUP: Global communication domain. """ @@ -105,35 +110,34 @@ class GlobalComm: WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP INITED = False CHECK_ENVS = True - - + + class _ExistingGroup: """ The communication groups which exist in the progress. """ ITEMS = {} - - + + def is_hccl_available(): """ Check HCCL api is available. - + Returns: Boolean. Return whether HCCL is available or not. """ return _HCCL_AVAILABLE - - + + def is_mpi_available(): """ Check HCCL & MPI api is available. - + Returns: Boolean. Return whether HCCL & MPI is available or not. """ return _MPI_AVAILABLE - def is_nccl_available(): """ Check NCCL api is available. @@ -158,19 +162,25 @@ def check_parameter_available(func): Wrapper. If not available, raise Error. """ def wrapper(*args, **kargs): + # 如果当前角色是参数服务器或者调度器,直接调用被装饰的函数 if _is_role_pserver() or _is_role_sched(): return func(*args, **kargs) + # 检查分布式通信是否已经初始化,未初始化则抛出异常 if not GlobalComm.INITED: raise RuntimeError("Distributed Communication has not been inited") + # 获取参数组,默认为None group = None + # 检查关键字参数中是否包含"group",并进行类型检查 if "group" in kargs.keys(): group = kargs.get("group") if group is not None and not isinstance(group, str): raise TypeError("The parameter 'group' should be str or None, " "but got the type : {}".format(type(group))) + # 获取后端,默认为None if "backend" in kargs.keys(): backend = kargs.get("backend") + # 检查后端是否可用,不可用则抛出异常 if backend is Backend.HCCL and not is_hccl_available(): raise RuntimeError("Distributed Communication doesn't have HCCL built in") if backend is Backend.HCCL_MPI and not is_mpi_available(): @@ -178,15 +188,16 @@ def check_parameter_available(func): if backend is Backend.NCCL and not is_nccl_available(): raise RuntimeError("Distributed Communication doesn't have NCCL built in") + # 如果未指定group,根据backend设置默认的group if group is None: - if backend is Backend.HCCL or Backend.HCCL_MPI: + if backend is Backend.HCCL or backend is Backend.HCCL_MPI: group = HCCL_WORLD_COMM_GROUP elif backend is Backend.NCCL: group = NCCL_WORLD_COMM_GROUP + # 调用被装饰的函数 return func(*args, **kargs) return wrapper - @check_parameter_available def _get_rank_helper(group, backend): """ @@ -202,10 +213,13 @@ def _get_rank_helper(group, backend): Returns: Integer. The local rank id of the calling process. """ + # 辅助函数,用于根据不同的后端和组获取 rank_id + # 获取当前角色的rank_id,如果是参数服务器或调度器角色,rank_id设为0并返回 rank_id = None if _is_role_pserver() or _is_role_sched(): rank_id = 0 return rank_id + # 根据不同的后端获取rank_id if backend == Backend.HCCL_MPI: rank_id = mpi.get_rank_id(group) elif backend == Backend.HCCL: @@ -216,6 +230,7 @@ def _get_rank_helper(group, backend): elif backend == Backend.NCCL: rank_id = get_rank_id(group) else: + # 如果后端不被支持,抛出ValueError异常 raise ValueError("For '_get_rank_helper', the argument 'backend' {} is not supported, " "please use hccl_mpi, hccl or nccl.".format(backend)) return rank_id @@ -236,6 +251,7 @@ def _get_local_rank_helper(group, backend): Returns: Integer. The local rank id of the calling process. """ + # 获取当前进程的rank id,根据不同的后端和组进行处理 rank_id = None if backend == Backend.HCCL_MPI: rank_id = mpi.get_rank_id(group) @@ -251,7 +267,6 @@ def _get_local_rank_helper(group, backend): "please use hccl_mpi or hccl.".format(backend)) return rank_id - @check_parameter_available def _get_size_helper(group, backend): """ @@ -268,9 +283,11 @@ def _get_size_helper(group, backend): Integer. The rank size of specified group. """ size = None + # 如果当前角色是参数服务器或调度器,则将size设为1并返回 if _is_role_pserver() or _is_role_sched(): size = 1 return size + # 根据不同的后端设置size的值 if backend == Backend.HCCL_MPI: size = mpi.get_rank_size(group) elif backend == Backend.HCCL: @@ -302,19 +319,23 @@ def _get_local_size_helper(group, backend): Integer. The local rank size where the calling process is being within specified group. """ size = None + # 根据不同的后端获取本地进程组的大小 if backend == Backend.HCCL: + # 如果组是全局通信组,则获取全局通信组的本地排名大小 if group == HCCL_WORLD_COMM_GROUP: size = hccl.get_local_rank_size() + # 否则获取指定组的本地排名大小 else: size = hccl.get_local_rank_size(group) + # NCCL后端不支持获取本地排名大小,抛出异常 elif backend == Backend.NCCL: raise RuntimeError("Nccl doesn't support get_local_rank_size now.") + # 对于不支持的后端,抛出异常 else: raise ValueError("For '_get_local_size_helper', the argument 'backend' {} is not supported, " "please use hccl.".format(backend)) return size - @check_parameter_available def _get_world_rank_from_group_rank_helper(group, group_rank_id, backend): """ @@ -333,21 +354,26 @@ def _get_world_rank_from_group_rank_helper(group, group_rank_id, backend): Integer. A rank id in world communication group. """ world_rank_id = None + # 检查 group_rank_id 是否为整数类型,如果不是则抛出 TypeError if not isinstance(group_rank_id, int): raise TypeError("For 'get_world_rank_from_group_rank', the argument 'group_rank_id' must be" " type of int, but got 'group_rank_id' type : {}.".format(type(group_rank_id))) + # 根据不同的后端选择不同的逻辑处理方式 if backend == Backend.HCCL: + # 如果在 GPU 上使用 HCCL,但 group 参数为 HCCL_WORLD_COMM_GROUP,则抛出 ValueError if group == HCCL_WORLD_COMM_GROUP: raise ValueError("For 'get_world_rank_from_group_rank' on GPU, the argument 'group' " "should be 'NCCL_WORLD_COMM_GROUP', but got 'HCCL_WORLD_COMM_GROUP'.") + # 调用 hccl.get_world_rank_from_group_rank 方法获取 world_rank_id world_rank_id = hccl.get_world_rank_from_group_rank(group, group_rank_id) elif backend == Backend.NCCL: + # 如果使用 NCCL,则抛出 RuntimeError 表示不支持该操作 raise RuntimeError("Nccl doesn't support get_world_rank_from_group_rank now.") else: + # 如果 backend 参数不支持,则抛出 ValueError 表示不支持该后端 raise ValueError("The argument 'backend' {} is not supported, please use hccl.".format(backend)) + # 返回获取的 world_rank_id return world_rank_id - - @check_parameter_available def _get_group_rank_from_world_rank_helper(world_rank_id, group, backend): """ @@ -366,21 +392,27 @@ def _get_group_rank_from_world_rank_helper(world_rank_id, group, backend): Integer. A rank id in user communication group. """ group_rank_id = None + # 检查 world_rank_id 是否为整数类型,如果不是则抛出 TypeError if not isinstance(world_rank_id, int): raise TypeError("For 'get_group_rank_from_world_rank', the argument 'world_rank_id' must be type of int, " "but got 'world_rank_id' type : {}.".format(type(world_rank_id))) + # 根据不同的后端处理获取 group_rank_id 的逻辑 if backend == Backend.HCCL: + # 检查 GPU 后端的 group 参数是否正确,如果不正确则抛出 ValueError if group == HCCL_WORLD_COMM_GROUP: raise ValueError("For 'get_group_rank_from_world_rank' on GPU, the argument 'group' " "should be 'NCCL_WORLD_COMM_GROUP', but got 'HCCL_WORLD_COMM_GROUP'.") + # 调用 hccl 模块的函数获取 group_rank_id group_rank_id = hccl.get_group_rank_from_world_rank(world_rank_id, group) + # NCCL 后端不支持此操作,抛出 RuntimeError elif backend == Backend.NCCL: raise RuntimeError("Nccl doesn't support get_group_rank_from_world_rank now.") + # 如果后端不是 HCCL 或 NCCL,则抛出 ValueError 表示不支持的后端 else: raise ValueError("The argument 'backend' {} is not supported, please use hccl.".format(backend)) + # 返回获取到的 group_rank_id return group_rank_id - @check_parameter_available def _create_group_helper(group, rank_ids, backend): """ @@ -395,34 +427,46 @@ def _create_group_helper(group, rank_ids, backend): TypeError: If rank_ids is not a list. ValueError: If rank_ids size is not larger than 1 or rank_ids has duplicate data or backend is invalid. """ + # 检查组是否已存在 if group in _ExistingGroup.ITEMS.keys(): + # 如果组已存在且提供的rank_ids与存储的不一致,抛出异常 if rank_ids != _ExistingGroup.ITEMS[group]: raise ValueError("The group {} has been created, the rank_list is {}, " "but current rank_list for the group is {}". format(group, _ExistingGroup.ITEMS[group], rank_ids)) + # 记录警告信息,提示组已存在 logger.warning("%r group has existed.", group) return + + # 根据不同的后端创建组 if backend == Backend.HCCL: + # 检查rank_ids是否为列表类型 if not isinstance(rank_ids, list): raise TypeError("For 'create_group', the argument 'rank_ids' must be type of list, " "but got 'rank_ids' type : {}.".format(type(rank_ids))) + # 检查rank_ids的长度是否大于1 rank_size = len(rank_ids) if rank_size < 1: raise ValueError("For 'create_group', the argument 'rank_ids' size should be greater than 1, " "but got 'rank_ids' size : {}.".format(len(rank_ids))) + # 检查rank_ids中是否有重复的元素 if len(rank_ids) - len(list(set(rank_ids))) > 0: raise ValueError("List rank_ids in Group {} has duplicate data!".format(group)) + # 使用HCCL创建组 hccl.create_group(group, rank_size, rank_ids) elif backend == Backend.HCCL_MPI: + # 使用HCCL_MPI创建组 mpi.create_group(group, rank_ids) elif backend == Backend.NCCL: + # NCCL暂不支持创建组,抛出异常 raise RuntimeError("Nccl doesn't support create_group now.") else: + # 如果后端不支持,抛出异常 raise ValueError("The context configuration parameter 'backend' {} is not supported, " "please use hccl.".format(backend)) - _ExistingGroup.ITEMS[group] = rank_ids - + # 将新创建的组及其rank_ids添加到_existingGroup中 + _ExistingGroup.ITEMS[group] = rank_ids @check_parameter_available def _destroy_group_helper(group, backend): """ @@ -435,12 +479,17 @@ def _destroy_group_helper(group, backend): Raises: ValueError: If group is "hccl_world_group" or backend is invalid. """ + # 根据后端类型销毁通信组 if backend == Backend.HCCL: + # 检查是否为 HCCL 的全局通信组 if group == HCCL_WORLD_COMM_GROUP: raise ValueError("The hccl_world_group does not support destruction.") + # 销毁指定的 HCCL 通信组 hccl.destroy_group(group) elif backend == Backend.NCCL: + # 当前 NCCL 后端不支持销毁通信组 raise RuntimeError("Nccl doesn't support destroy_group now.") else: + # 抛出错误,表示不支持的后端类型 raise ValueError("The context configuration parameter 'backend' {} is not supported, " - "please use hccl.".format(backend)) + "please use hccl.".format(backend)) \ No newline at end of file diff --git a/src/mindspore2022/mindspore/python/mindspore/communication/_hccl_management.py b/src/mindspore2022/mindspore/python/mindspore/communication/_hccl_management.py index 92dbd698..8307a839 100644 --- a/src/mindspore2022/mindspore/python/mindspore/communication/_hccl_management.py +++ b/src/mindspore2022/mindspore/python/mindspore/communication/_hccl_management.py @@ -24,7 +24,7 @@ MAX_RANK_NUM = 4096 HCCL_LIB = 'libhccl_plugin.so' HCCL_LIB_CTYPES = "" - +# 检查集体通信组的名称是否合法 def check_group(group): """ A function that check if a collection communication group is legal. @@ -41,7 +41,7 @@ def check_group(group): raise TypeError("The type of communication group name must be type of string, " "but got 'group' type : {}.".format(type(group))) - +# 检查集体通信中的排名编号是否合法 def check_rank_num(rank_num): """ A function that check if a collection communication rank number is legal.If not raise error. @@ -57,7 +57,7 @@ def check_rank_num(rank_num): raise TypeError("The argument 'rank_num' must be type of int, " "but got 'rank_num' type : {}.".format(type(rank_num))) - +#检查集体通信中的排名标识(rank id)是否合法 def check_rank_id(rank_id): """ A function that check if a collection communication rank id is legal.If not raise error. @@ -65,15 +65,18 @@ def check_rank_id(rank_id): Returns: None """ + # 检查rank_id是否为整数类型 if isinstance(rank_id, (int)): + # 检查rank_id是否在有效范围内 if rank_id >= MAX_RANK_NUM or rank_id < 0: raise ValueError("The rand id in the communication group must be greater or equal 0 and " "less than {}, but got type value : {}.".format(MAX_RANK_NUM, rank_id)) else: + # 如果rank_id不是整数类型,抛出类型错误 raise TypeError("The rand id in the communication group must be must be type of int, " "but got type value : {}.".format(type(rank_id))) - +# 加载 HCCL(Huawei Cloud Communication Library)库 def load_lib(): """load hccl lib""" try: @@ -82,11 +85,10 @@ def load_lib(): hccl_lib = ctypes.CDLL(lib_path) except Exception: raise RuntimeError('Get hccl lib error.') - + global HCCL_LIB_CTYPES HCCL_LIB_CTYPES = hccl_lib - def c_str(string): """Convert a python string to C string.""" if not isinstance(string, str): @@ -98,7 +100,7 @@ def c_array(ctype, values): """Create ctypes array from a python array.""" return (ctype * len(values))(*values) - +#用于创建包含指定数量和ID的HCCL通信组,但不能创建世界组。 def create_group(group, rank_num, rank_ids): """ Create group. @@ -112,28 +114,38 @@ def create_group(group, rank_num, rank_ids): Returns: None """ + # 检查组的有效性 check_group(group) + # 检查排名数量的有效性 check_rank_num(rank_num) + # 检查rank_ids是否为列表类型 if isinstance(rank_ids, (list)): + # 确保rank_num与rank_ids的长度一致 if rank_num != len(rank_ids): raise ValueError("The argument 'rank_num' number should be equal to the length " "of rank_ids, but got 'rank_num' value : {} and 'rank_ids' value : {}." .format(rank_num, rank_ids)) + # 检查rank_ids中的每个元素是否为非负整数 for rank_id in rank_ids: if not isinstance(rank_id, (int)) or rank_id < 0: raise ValueError("The elements of argument 'rank_ids' must be " "unsigned integer, but got the type : {}".format(type(rank_id))) + # 将rank_ids转换为C类型的数组 c_array_rank_ids = c_array(ctypes.c_uint, rank_ids) + # 将rank_num转换为C类型的无符号整数 c_rank_num = ctypes.c_uint(rank_num) + # 将group转换为C类型的字符串 c_group = c_str(group) + # 调用HCCL库创建组 ret = HCCL_LIB_CTYPES.HcomCreateGroup(c_group, c_rank_num, c_array_rank_ids) + # 检查组创建是否成功 if ret != 0: raise RuntimeError('Create group error, the error code is {}.'.format(ret)) else: + # 如果rank_ids不是列表类型,抛出类型错误 raise TypeError("For 'create_group', the argument 'rank_ids' must be type of list, " "but got 'rank_ids' type : {}.".format(type(rank_ids))) - - +#用于销毁用户创建的HCCL组 def destroy_group(group): """ A function that destroy the group which created by user. @@ -162,16 +174,23 @@ def get_rank_size(group="hccl_world_group"): An integer scalar with the num of ranks. """ + # 根据上下文的模式判断是否为PYNATIVE_MODE模式,若是,则直接返回HCCL的rank size if context.get_context("mode") == context.PYNATIVE_MODE: return get_hccl_rank_size() - + + # 检查给定的组是否有效 check_group(group) + # 将组名转换为C字符串格式 c_group = c_str(group) + # 定义一个C类型的无符号整数用于存储rank size c_rank_size = ctypes.c_uint() + # 调用HCCL库的HcomGetRankSize函数获取组内的rank size ret = HCCL_LIB_CTYPES.HcomGetRankSize(c_group, ctypes.byref(c_rank_size)) + # 如果返回值不为0,表示获取rank size失败,抛出运行时错误 if ret != 0: raise RuntimeError('Get rank size error.') - + + # 返回获取到的rank size值 return c_rank_size.value @@ -186,17 +205,22 @@ def get_rank_id(group="hccl_world_group"): if context.get_context("mode") == context.PYNATIVE_MODE: return get_hccl_rank_id() + # 检查组的有效性 check_group(group) + # 将组转换为 C 字符串格式 c_group = c_str(group) + # 定义一个用于存储 rank id 的 ctypes 无符号整数 c_rank_id = ctypes.c_uint() + # 调用 HCCL 库获取当前进程的 rank id ret = HCCL_LIB_CTYPES.HcomGetRankId(c_group, ctypes.byref(c_rank_id)) + # 如果返回值不为 0,表示获取 rank id 出错,抛出 RuntimeError 异常 if ret != 0: raise RuntimeError('Get rank id error.') + # 返回获取到的 rank id 值 return c_rank_id.value - def get_local_rank_size(group="hccl_world_group"): """ A function that returns the number of local ranks within the given collection communication group. @@ -207,19 +231,25 @@ def get_local_rank_size(group="hccl_world_group"): Returns: An integer scalar with the num of local ranks. """ + # 检查当前运行模式是否为PYNATIVE_MODE,如果是则抛出异常 if context.get_context("mode") is context.PYNATIVE_MODE: raise RuntimeError("The function 'get_local_rank_size' is not supported in PYNATIVE_MODE, " "'get_local_rank_size' only support GRAPH_MODE") + # 验证传入的组是否有效 check_group(group) + # 将组名称转换为C字符串格式 c_group = c_str(group) + # 定义一个ctypes的无符号整数变量,用于存储本地排名大小 c_local_rank_size = ctypes.c_uint() + # 调用HCCL库中的HcomGetLocalRankSize函数获取本地排名大小 ret = HCCL_LIB_CTYPES.HcomGetLocalRankSize(c_group, ctypes.byref(c_local_rank_size)) + # 如果返回值不为0,说明获取本地排名大小时出错,抛出异常 if ret != 0: raise RuntimeError('Get local rank size error.') + # 返回获取到的本地排名大小 return c_local_rank_size.value - def get_local_rank_id(group="hccl_world_group"): """ Get local rank id. @@ -230,16 +260,23 @@ def get_local_rank_id(group="hccl_world_group"): An integer scalar with the local rank id of the calling process. """ + # 检查当前运行模式是否为PYNATIVE_MODE,如果是则抛出异常 if context.get_context("mode") is context.PYNATIVE_MODE: raise RuntimeError("The function 'get_local_rank_id' is not supported in PYNATIVE_MODE, " "'get_local_rank_id' only support GRAPH_MODE") + # 验证群组的有效性 check_group(group) + # 将群组名称转换为C字符串格式 c_group = c_str(group) + # 定义一个无符号整型的C类型变量来存储本地排名ID c_local_rank_id = ctypes.c_uint() + # 调用HCCL库的HcomGetLocalRankId函数获取本地排名ID ret = HCCL_LIB_CTYPES.HcomGetLocalRankId(c_group, ctypes.byref(c_local_rank_id)) + # 如果返回值不为0,表示获取本地排名ID失败,抛出异常 if ret != 0: raise RuntimeError('Get local rank id error.') + # 返回获取到的本地排名ID值 return c_local_rank_id.value @@ -256,18 +293,25 @@ def get_world_rank_from_group_rank(group, group_rank_id): if context.get_context("mode") is context.PYNATIVE_MODE: raise RuntimeError("The function 'get_world_rank_from_group_rank' is not supported in PYNATIVE_MODE, " "'get_world_rank_from_group_rank' only support GRAPH_MODE") + # 检查组名是否有效 check_group(group) + # 检查组内rank ID是否有效 check_rank_id(group_rank_id) + # 将组名转换为C字符串格式 c_group = c_str(group) + # 将组内rank ID转换为C的无符号整数类型 c_group_rank_id = ctypes.c_uint(group_rank_id) + # 声明一个用于存储世界rank ID的C的无符号整数类型变量 c_world_rank_id = ctypes.c_uint() + # 调用HCCL库中的HcomGetWorldRankFromGroupRank函数,根据组名和组内rank ID获取对应的世界rank ID ret = HCCL_LIB_CTYPES.HcomGetWorldRankFromGroupRank(c_group, c_group_rank_id, ctypes.byref(c_world_rank_id)) + # 如果返回值不为0,说明函数调用出错,抛出RuntimeError异常 if ret != 0: - raise RuntimeError('Get world rank from group rank error.') + raise RuntimeError('根据组内rank ID获取世界rank ID时出错。') + # 返回获取到的世界rank ID的值 return c_world_rank_id.value - def get_group_rank_from_world_rank(world_rank_id, group): """ Get group rank from world rank. @@ -281,13 +325,21 @@ def get_group_rank_from_world_rank(world_rank_id, group): if context.get_context("mode") is context.PYNATIVE_MODE: raise RuntimeError("The function 'get_group_rank_from_world_rank' is not supported in PYNATIVE_MODE, " "'get_group_rank_from_world_rank' only support GRAPH_MODE") + # 检查组的有效性 check_group(group) + # 检查世界排名ID的有效性 check_rank_id(world_rank_id) + # 将组转换为C字符串 c_group = c_str(group) + # 将世界排名ID转换为C无符号整数 c_world_rank_id = ctypes.c_uint(world_rank_id) + # 创建一个用于存储组排名ID的C无符号整数 c_group_rank_id = ctypes.c_uint() + # 从世界排名ID获取组排名ID ret = HCCL_LIB_CTYPES.HcomGetGroupRankFromWorldRank(c_world_rank_id, c_group, ctypes.byref(c_group_rank_id)) + # 如果返回值不为0,抛出运行时错误 if ret != 0: raise RuntimeError('Get group rank from world rank error.') - return c_group_rank_id.value + # 返回获取到的组排名ID的值 + return c_group_rank_id.value \ No newline at end of file diff --git a/src/mindspore2022/mindspore/python/mindspore/communication/management.py b/src/mindspore2022/mindspore/python/mindspore/communication/management.py index a64276af..a00dee56 100644 --- a/src/mindspore2022/mindspore/python/mindspore/communication/management.py +++ b/src/mindspore2022/mindspore/python/mindspore/communication/management.py @@ -63,25 +63,31 @@ def _check_parallel_envs(): Raises: RuntimeError: If parallel environment variables have not been exported or have been exported to wrong values. """ + # 检查是否需要进行环境验证,如果不进行则直接返回 if not GlobalComm.CHECK_ENVS: return import os + # 获取环境变量RANK_ID的值 rank_id_str = os.getenv("RANK_ID") + # 如果RANK_ID未设置,抛出运行时错误 if not rank_id_str: raise RuntimeError("Environment variables RANK_ID has not been exported, please export variables 'RANK_ID'.") try: + # 尝试将RANK_ID转换为整数 int(rank_id_str) except ValueError: + # 如果转换失败,打印错误信息 print("Environment variables 'RANK_ID' should be number, but got the type : {}".format(type(rank_id_str))) finally: + # 无论是否发生异常,此块为空操作 pass + # 获取环境变量MINDSPORE_HCCL_CONFIG_PATH和RANK_TABLE_FILE的值 rank_table_file_str = os.getenv("MINDSPORE_HCCL_CONFIG_PATH") rank_table_file_str_old = os.getenv("RANK_TABLE_FILE") + # 如果两个环境变量都未设置,抛出运行时错误 if not rank_table_file_str and not rank_table_file_str_old: raise RuntimeError("Get hccl rank_table_file failed, " "please export MINDSPORE_HCCL_CONFIG_PATH or RANK_TABLE_FILE.") - - def init(backend_name=None): """ Initialize distributed backend, e.g. HCCL/NCCL, it is required before using the communication service. @@ -109,15 +115,19 @@ def init(backend_name=None): >>> from mindspore.communication import init >>> init() """ + # 检查当前角色是否为参数服务器或调度器,如果是则直接返回 if _is_role_pserver() or _is_role_sched(): return + # 检查任务接收环境变量,获取设备目标和模式 task_sink = _check_task_sink_envs() device_target = context.get_context("device_target") mode = context.get_context("mode") mpi_init = False + # 如果没有任务接收且模式为图模式,设置mpi_init为True if not task_sink and mode == context.GRAPH_MODE: mpi_init = True + # 根据设备目标选择后端名称,如果不支持则抛出异常 if backend_name is None: if device_target == "Ascend": backend_name = "hccl" @@ -126,28 +136,35 @@ def init(backend_name=None): else: raise RuntimeError("For 'set_context', the argument 'device_target' {} is not supported in " "parallel initialization, please use Ascend or GPU.".format(device_target)) + # 检查后端名称是否为字符串,如果不是则抛出异常 if not isinstance(backend_name, str): raise TypeError("For 'init', the argument 'backend_name' must be a string, " "but got the type : {}".format(type(backend_name))) + # 根据后端名称初始化通信环境 if backend_name == "hccl": + # 如果设备目标不是Ascend,抛出异常 if device_target != "Ascend": raise RuntimeError("For 'init', the argument 'backend_name' should be 'Ascend' to init hccl, " "but got {}".format(device_target)) + # 如果不需要MPI初始化,检查并行环境并设置后端名称 if not mpi_init: _check_parallel_envs() GlobalComm.BACKEND = Backend("hccl") else: GlobalComm.BACKEND = Backend("hccl_mpi") + # 初始化HCCL并设置全局通信组和初始化状态 init_hccl() GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP GlobalComm.INITED = True elif backend_name == "nccl": + # 初始化GPU集体通信并设置后端名称、全局通信组和初始化状态 init_gpu_collective() GlobalComm.BACKEND = Backend("nccl") GlobalComm.WORLD_COMM_GROUP = NCCL_WORLD_COMM_GROUP GlobalComm.INITED = True else: + # 如果后端名称不支持,抛出异常 raise RuntimeError("For 'init', the argument 'backend_name' must be nccl while 'device_target' is GPU, " "but got the 'backend_name' : hccl.") diff --git a/src/mindspore2022/mindspore/python/mindspore/graph_utils/graph_pattern.py b/src/mindspore2022/mindspore/python/mindspore/graph_utils/graph_pattern.py index 32fcb19d..537c82cd 100644 --- a/src/mindspore2022/mindspore/python/mindspore/graph_utils/graph_pattern.py +++ b/src/mindspore2022/mindspore/python/mindspore/graph_utils/graph_pattern.py @@ -45,14 +45,15 @@ class OneOf(OneOf_): TypeError: raise type error for invalid inputs. """ self.patterns = patterns + # 检查 patterns 是否是 Pattern 类的实例 if isinstance(patterns, Pattern): OneOf_.__init__(self, [patterns]) + # 检查 patterns 是否是 tuple 或 list 类型,并且其中所有元素都是 Pattern 类的实例 elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns): OneOf_.__init__(self, patterns) + # 如果 patterns 不符合上述两种情况,则抛出 TypeError 异常 else: raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}") - - class Prim(Prim_): r""" Express a pattern of certain primitive type(s). @@ -76,25 +77,33 @@ class Prim(Prim_): Raises: TypeError: raise type error for invalid argument. """ + # 检查name是否为字符串类型,如果不是则抛出TypeError if name is not None and not isinstance(name, str): raise TypeError(f"Expect string, got : {name}") self.name = name + # 如果types是字符串类型,则将其按'|'分割成列表 if isinstance(types, str): if self.name is None: self.name = types self.types = types.split('|') + # 如果types是Primitive类型,则直接将其放入列表中 elif isinstance(types, Primitive): if self.name is None: self.name = types.name self.types = [types] + # 如果 types 是元组或列表,并且其中所有元素都是 Primitive 类型 elif isinstance(types, (tuple, list)) and all(isinstance(tp, Primitive) for tp in types): + # 如果 self.name 为 None,则初始化为空字符串并拼接所有 Primitive 的 name if self.name is None: self.name = "" for prim in types: self.name += prim.name + # 设置 self.types 为传入的 types self.types = types + # 如果 types 不符合预期类型,抛出 TypeError else: raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}") + # 调用基类 Prim_ 的初始化方法,传入 self.types 和 self.name Prim_.__init__(self, self.types, self.name) @@ -115,16 +124,22 @@ class Call(Call_): Raises: TypeError: raise type error for invalid argument. """ + # 检查 prim_pattern 是否为 Pattern, str 或 Primitive 类型,如果不是则抛出 TypeError if not isinstance(prim_pattern, (Pattern, str, Primitive)): raise TypeError(f"Expect prim_pattern to be Pattern, Primitive or string, got : {prim_pattern}") self.prim_pattern = prim_pattern + # 初始化 inputs 列表 self.inputs = [] + # 如果 inputs 为 None,则不做任何操作 if inputs is None: pass + # 如果 inputs 是 tuple 或 list 并且其中所有元素都是 Pattern 类型,则赋值给 self.inputs elif isinstance(inputs, (tuple, list)) and all(isinstance(input, Pattern) for input in inputs): self.inputs = inputs + # 如果 inputs 不符合上述条件,则抛出 TypeError else: raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}") + # 调用父类 Call_ 的初始化方法,传入 self.prim_pattern 和 self.inputs Call_.__init__(self, self.prim_pattern, self.inputs) @@ -145,6 +160,7 @@ class NoneOf(NoneOf_): TypeError: raise type error for invalid argument. """ self.patterns = patterns + # 根据 patterns 的类型初始化 NoneOf_ 类 if patterns is None: NoneOf_.__init__(self, ()) elif isinstance(patterns, Pattern): @@ -154,7 +170,6 @@ class NoneOf(NoneOf_): else: raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}") - class NewTensor(NewTensor_): r""" New Tensor to be used in the target. @@ -167,13 +182,16 @@ class NewTensor(NewTensor_): Raises: TypeError: raise type error for invalid argument. """ + # 初始化输入张量 self.input_tensor = input_tensor + # 检查输入是否为 Tensor 类型 if isinstance(input_tensor, Tensor): + # 如果是 Tensor 类型,则调用 NewTensor_ 的初始化方法 NewTensor_.__init__(self, input_tensor) else: + # 如果不是 Tensor 类型,则抛出 TypeError 异常 raise TypeError(f"Expect input_tensor to be a Tensor, got : {input_tensor}") - class NewParameter(NewParameter_): r""" New Parameter to be used in the target. @@ -193,11 +211,14 @@ class NewParameter(NewParameter_): self.default_tensor = default_tensor self.requires_grad = requires_grad self.layerwise_parallel = layerwise_parallel + # 检查参数类型是否符合预期 if isinstance(para_name, str) and isinstance(default_tensor, Tensor) and isinstance(requires_grad, bool) and\ isinstance(layerwise_parallel, bool): + # 初始化 NewParameter_ 类 NewParameter_.__init__(self, self.para_name, self.default_tensor, self.requires_grad, self.layerwise_parallel) else: + # 如果参数类型不符合预期,抛出 TypeError raise TypeError(f"Expect para_name(str), default_tensor(Tensor), requires_grad(bool), \ layerwise_parallel(bool), got : {para_name}, {default_tensor}, \ - {requires_grad}, {layerwise_parallel}") + {requires_grad}, {layerwise_parallel}") \ No newline at end of file diff --git a/src/mindspore2022/mindspore/python/mindspore/graph_utils/python_pass/python_pass_register.py b/src/mindspore2022/mindspore/python/mindspore/graph_utils/python_pass/python_pass_register.py index 445c36a9..2266248a 100644 --- a/src/mindspore2022/mindspore/python/mindspore/graph_utils/python_pass/python_pass_register.py +++ b/src/mindspore2022/mindspore/python/mindspore/graph_utils/python_pass/python_pass_register.py @@ -39,6 +39,7 @@ class PyPassManager(PyPassManager_): TypeError: If argument has invalid type. """ def __init__(self, requires_grad=True, run_only_once=False): + # 初始化方法,接收两个布尔参数,设置实例的属性并调用父类的初始化方法 if not isinstance(requires_grad, bool): raise TypeError(f"Expect bool, got : ({type(requires_grad)}){requires_grad}") if not isinstance(run_only_once, bool): @@ -48,17 +49,20 @@ class PyPassManager(PyPassManager_): PyPassManager_.__init__(self) def register(self, py_pass): + # 注册一个Python pass,检查其是否为函数类型,并获取其模式和目标 if not isfunction(py_pass): raise TypeError(f"Expect function pass, got : ({type(py_pass)}){py_pass}") pattern, target = py_pass() pass_name = py_pass.__name__ + # 检查模式和目标是否为Pattern类型 if not isinstance(pattern, Pattern): raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}") if not isinstance(target, Pattern): raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}") + # 调用父类的register方法,注册pass及其相关信息 super().register(pass_name, pattern, target, self.requires_grad, self.run_only_once_) - def unregister(self, py_pass): + # 从注册表中移除指定的Python传递对象,可以是字符串形式的名称或函数对象 if isinstance(py_pass, str): super().unregister(py_pass) return @@ -66,27 +70,30 @@ class PyPassManager(PyPassManager_): super().unregister(py_pass.__name__) return raise TypeError(f"Expect py_pass to be string or function, got ({type(py_pass)}){py_pass}") - + def __call__(self, py_pass): + # 将Python传递对象注册到注册表中,并返回该对象 self.register(py_pass) return py_pass - + def gen_new_parameter(self, pattern): + # 根据给定的模式生成新的参数,模式必须是NewParameter类型 if not isinstance(pattern, NewParameter): raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}") super().gen_new_parameter(pattern) - + def set_renorm(self, should_renorm): + # 设置是否进行重归一化操作,参数必须是布尔值 if not isinstance(should_renorm, bool): raise TypeError(f"Expect should_renorm to be a bool, got {should_renorm}") super().set_renorm(should_renorm) - + def set_reopt(self, do_reopt): + # 设置是否进行重新优化操作,参数必须是布尔值 if not isinstance(do_reopt, bool): raise TypeError(f"Expect do_reopt to be a bool, got {do_reopt}") super().set_reopt(do_reopt) - def register_pass(requires_grad=True, run_only_once=False): """ Register python pass to specified pipeline phase which would be used in compilation. @@ -165,12 +172,13 @@ def cancel_new_parameter(pattern): >>> # some compilations >>> cancel_new_parameter(abc) """ + # 检查传入的pattern是否为NewParameter的实例 if not isinstance(pattern, NewParameter): raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}") + # 创建一个PyPassManager对象 ppm = PyPassManager() + # 从PyPassManager中注销指定名称的参数 ppm.unregister(pattern.para_name) - - def set_renorm(should_renorm): """ Set whether or not to do renormalization after modified graph in python pass(es). diff --git a/src/mindspore2022/mindspore/python/mindspore/run_check/_check_deps_version.py b/src/mindspore2022/mindspore/python/mindspore/run_check/_check_deps_version.py index 4e4426b2..65984f7d 100644 --- a/src/mindspore2022/mindspore/python/mindspore/run_check/_check_deps_version.py +++ b/src/mindspore2022/mindspore/python/mindspore/run_check/_check_deps_version.py @@ -19,22 +19,25 @@ import sys def parse_args(): """ parse args . - + Args: - + Returns: args. - + Examples: >>> parse_args() """ + # 创建一个ArgumentParser对象,用于解析命令行参数,描述信息为"MindSpore dependency packages version checker." parser = ArgumentParser(description="MindSpore dependency packages version checker.") + # 添加一个命令行参数--mindspore_version,类型为字符串,帮助信息为"MindSpore version." parser.add_argument("--mindspore_version", type=str, help="MindSpore version.") + # 添加一个命令行参数--supported_version,类型为字符串,可以多次指定,帮助信息为"Supported environment version." parser.add_argument("--supported_version", type=str, action='append', help="Supported environment version.") + # 解析命令行参数并返回结果 args = parser.parse_args() return args - def check_deps_version(mindspore_version, supported_version): """ check te/hccl/topi version @@ -46,6 +49,7 @@ def check_deps_version(mindspore_version, supported_version): Returns: void """ + # 尝试导入并检查 hccl、te 和 topi 包的版本是否与支持的版本匹配 try: from hccl import sys_version as hccl_version v = '.'.join(hccl_version.__sys_version__.split('.')[0:2]) @@ -63,18 +67,20 @@ def check_deps_version(mindspore_version, supported_version): print(f"MindSpore version {mindspore_version} and \"topi\" wheel package version {v} does not " "match, reference to the match info on: https://www.mindspore.cn/install") + # 捕获导入错误并打印相应的检查失败信息 except ImportError as e: print("CheckFailed: ", e.args) print("MindSpore relies on the 3 whl packages of \"te\", \"topi\" and \"hccl\" in the \"fwkacllib\" " "folder of the Ascend AI software package (Ascend Data Center Solution), please check whether they are " "installed correctly or not, reference to the match info on: https://www.mindspore.cn/install") - def main(): + # 解析命令行参数 args = parse_args() + # 检查 mindspore 的版本是否在支持的版本范围内 check_deps_version(args.mindspore_version, args.supported_version) - - + if __name__ == "__main__": - sys.path = sys.path[1:] # avoid the impact of relative path env, only affect this process + # 避免相对路径环境的影响,仅影响当前进程 + sys.path = sys.path[1:] main() diff --git a/src/mindspore2022/mindspore/python/mindspore/run_check/_check_version.py b/src/mindspore2022/mindspore/python/mindspore/run_check/_check_version.py index 5eb6f281..61ea0a8b 100644 --- a/src/mindspore2022/mindspore/python/mindspore/run_check/_check_version.py +++ b/src/mindspore2022/mindspore/python/mindspore/run_check/_check_version.py @@ -31,47 +31,58 @@ class EnvChecker(metaclass=ABCMeta): @abstractmethod def check_env(self, e): - pass + """检查环境是否符合要求""" @abstractmethod def set_env(self): - pass + """设置环境""" @abstractmethod def check_version(self): - pass + """检查版本是否符合要求""" class GPUEnvChecker(EnvChecker): """GPU environment check.""" def __init__(self): + # 初始化版本列表 self.version = ["10.1", "11.1"] + # 初始化库键到库名的映射字典 self.lib_key_to_lib_name = {'libcu': 'libcuda.so'} # env + # 获取系统环境变量 PATH 的值 self.path = os.getenv("PATH") + # 获取系统环境变量 LD_LIBRARY_PATH 的值 self.ld_lib_path = os.getenv("LD_LIBRARY_PATH") # check + # 初始化版本号为 "0" self.v = "0" + # 获取 CUDA 库的路径 self.cuda_lib_path = self._get_lib_path("libcu") + # 获取 CUDA 可执行文件的路径 self.cuda_bin_path = self._get_bin_path("cuda") + # 获取 cuDNN 库的路径 self.cudnn_lib_path = self._get_lib_path("libcudnn") - def check_env(self, e): + # 抛出传入的异常 e raise e - + def set_env(self): + # 设置环境变量,当前实现为空 return - + def _get_bin_path(self, bin_name): """Get bin path by bin name.""" + # 如果二进制名称为 "cuda",则调用获取 CUDA 二进制路径的方法 if bin_name == "cuda": return self._get_cuda_bin_path() + # 否则返回空列表 return [] def _get_cuda_bin_path(self): - """Get cuda bin path by lib path.""" + # Get cuda bin path by lib path. path_list = [] for path in self.cuda_lib_path: path = os.path.abspath(path.strip()+"/bin/") @@ -81,56 +92,87 @@ class GPUEnvChecker(EnvChecker): def _get_nvcc_version(self, is_set_env): """Get cuda version by nvcc command.""" + # 运行 nvcc 命令获取 CUDA 版本信息 nvcc_result = subprocess.run(["nvcc", "--version | grep release"], timeout=3, text=True, capture_output=True, check=False) + # 如果命令返回非零值,表示命令执行失败 if nvcc_result.returncode: + # 如果尚未设置环境变量 if not is_set_env: + # 遍历预设的 CUDA 二进制路径 for path in self.cuda_bin_path: + # 检查路径中是否存在 nvcc 文件 if Path(path + "/nvcc").is_file(): + # 将路径添加到环境变量 PATH 中 os.environ['PATH'] = path + ":" + os.environ['PATH'] + # 递归调用以重新尝试获取版本信息 return self._get_nvcc_version(True) + # 如果命令执行失败且未找到 nvcc 文件,返回空字符串 return "" + # 获取命令输出结果 result = nvcc_result.stdout + # 遍历输出结果的每一行 for line in result.split('\n'): if line: + # 提取并返回 CUDA 版本号 return line.strip().split("release")[1].split(",")[0].strip() + # 如果未找到版本信息,返回空字符串 return "" def _get_cudnn_version(self): """Get cudnn version by libcudnn.so.""" + # 初始化cudnn版本列表为空 cudnn_version = [] + # 遍历cudnn库路径 for path in self.cudnn_lib_path: + # 查找路径下所有的libcudnn.so文件 real_path = glob.glob(path + "/lib*/libcudnn.so.*.*") + # 如果没有找到对应的文件,继续下一个路径 if real_path == []: continue + # 使用ls命令获取文件信息 ls_cudnn = subprocess.run(["ls", real_path[0]], timeout=10, text=True, capture_output=True, check=False) + # 如果ls命令执行成功,解析输出以获取版本号 if ls_cudnn.returncode == 0: cudnn_version = ls_cudnn.stdout.split('/')[-1].strip('libcudnn.so.').strip().split('.') + # 如果版本号只有两个部分,添加一个'.0'作为第三部分 if len(cudnn_version) == 2: cudnn_version.append('0') + # 找到版本号后跳出循环 break + # 将版本号列表转换为字符串 version_str = ''.join([n for n in cudnn_version]) + # 返回版本号的前三位 return version_str[0:3] def _get_cudart_version(self): """Get cuda runtime version by libcudart.so.""" + # 遍历可能的 CUDA 库路径 for path in self.cuda_lib_path: + # 查找路径下所有可能的 libcudart.so 文件 real_path = glob.glob(path + "/lib*/libcudart.so.*.*.*") + # 如果没有找到任何文件,则跳过当前路径 if real_path == []: continue + # 获取文件名信息以确定 CUDA 版本 ls_cudart = subprocess.run(["ls", real_path[0]], timeout=10, text=True, capture_output=True, check=False) + # 如果命令成功执行,则解析输出以提取版本号 if ls_cudart.returncode == 0: self.v = ls_cudart.stdout.split('/')[-1].strip('libcudart.so.').strip() + # 找到版本号后跳出循环 break + # 返回找到的 CUDA 版本号 return self.v def check_version(self): """Check cuda version.""" version_match = False + # 调用私有方法检查版本是否匹配,并根据结果设置version_match标志 if self._check_version(): version_match = True + # 如果版本不匹配,根据CUDA版本号输出不同的警告信息 if not version_match: if self.v == "0": logger.warning("Can not found cuda libs, please confirm that the correct " @@ -140,17 +182,20 @@ class GPUEnvChecker(EnvChecker): logger.warning(f"MindSpore version {__version__} and cuda version {self.v} does not match, " "please refer to the installation guide for version matching " "information: https://www.mindspore.cn/install") + # 获取nvcc版本号,并检查是否与MindSpore支持的版本匹配 nvcc_version = self._get_nvcc_version(False) if nvcc_version and (nvcc_version not in self.version): logger.warning(f"MindSpore version {__version__} and nvcc(cuda bin) version {nvcc_version} " "does not match, please refer to the installation guide for version matching " "information: https://www.mindspore.cn/install") + # 获取cudnn版本号,并检查是否符合最低要求 cudnn_version = self._get_cudnn_version() if cudnn_version and int(cudnn_version) < 760: logger.warning(f"MindSpore version {__version__} and cudDNN version {cudnn_version} " "does not match, please refer to the installation guide for version matching " "information: https://www.mindspore.cn/install. The recommended version is " "CUDA10.1 with cuDNN7.6.x and CUDA11.1 with cuDNN8.0.x") + # 检查cudnn版本号与CUDA版本号的兼容性,对于CUDA 11.0以上版本,cudnn版本需要至少为8.0 if cudnn_version and int(cudnn_version) < 800 and int(str(self.v).split('.')[0]) > 10: logger.warning(f"CUDA version {self.v} and cuDNN version {cudnn_version} " "does not match, please refer to the installation guide for version matching " @@ -159,45 +204,58 @@ class GPUEnvChecker(EnvChecker): def _check_version(self): """Check cuda version""" + # 获取 CUDA 运行时版本 v = self._get_cudart_version() + # 解析版本字符串为版本对象 v = version.parse(v) + # 构造版本号字符串,格式为 "主版本.次版本" v_str = str(v.major) + "." + str(v.minor) + # 检查构造的版本号字符串是否在预定义的版本列表中 if v_str not in self.version: return False + # 版本号匹配,返回 True return True def _get_lib_path(self, lib_name): - """Get gpu lib path by ldd command.""" - path_list = [] - current_path = os.path.split(os.path.realpath(__file__))[0] - mindspore_path = os.path.join(current_path, "../") + """通过ldd命令获取gpu库路径。""" + path_list = [] # 初始化一个空列表用于存储路径 + current_path = os.path.split(os.path.realpath(__file__))[0] # 获取当前文件的绝对路径并分割以获取目录部分 + mindspore_path = os.path.join(current_path, "../") # 构建mindspore路径,通常是当前文件的上一级目录 try: + # 使用glob模块查找mindspore_path目录下所有以_c_expression.so开头的文件路径 real_path = glob.glob(mindspore_path + "/_c_expression*.so*") - if real_path == []: - logger.error(f"{self.lib_key_to_lib_name[lib_name]} (need by mindspore-gpu) is not found, please " - f"confirm that _c_expression.so is in directory:{mindspore_path} and the correct cuda " - "version has been installed, you can refer to the installation " - "guidelines: https://www.mindspore.cn/install") - return path_list + if real_path == []: # 如果没有找到任何文件 + # 记录错误日志,提示用户确认_c_expression.so文件是否存在以及是否安装了正确的cuda版本 + logger.error(f"{self.lib_key_to_lib_name[lib_name]} (mindspore-gpu所需的库) 未找到,请确认 " + f"_c_expression.so是否位于目录:{mindspore_path}中,并且已安装正确的cuda版本," + "您可以参考安装指南:https://www.mindspore.cn/install") + return path_list # 返回空路径列表 + # 使用subprocess.Popen执行ldd命令以获取依赖库的信息 ldd_r = subprocess.Popen(['ldd', real_path[0]], stdout=subprocess.PIPE) + # 使用subprocess.Popen的stdin参数从ldd_r.stdout接收输出,并执行grep命令以过滤出包含指定库名的信息 ldd_result = subprocess.Popen(['grep', lib_name], stdin=ldd_r.stdout, stdout=subprocess.PIPE) + # 获取grep命令的输出结果,并解码为字符串 result = ldd_result.communicate()[0].decode() - for i in result.split('\n'): + for i in result.split('\n'): # 按行分割结果字符串 + # 使用partition方法从每一行中提取出库文件的路径 path = i.partition("=>")[2] - if path.lower().find("not found") > 0: - logger.warning(f"Cuda {self.version} version(need by mindspore-gpu) is not found, please confirm " - "that the path of cuda is set to the env LD_LIBRARY_PATH, please refer to the " - "installation guidelines: https://www.mindspore.cn/install") - continue + if path.lower().find("not found") > 0: # 如果路径中包含"not found" + # 记录警告日志,提示用户确认cuda路径是否已添加到环境变量LD_LIBRARY_PATH中 + logger.warning(f"Cuda {self.version}版本(由mindspore-gpu要求的) 未找到,请确认cuda路径已设置到环境变量LD_LIBRARY_PATH中," + "您可以参考安装指南:https://www.mindspore.cn/install") + continue # 继续下一次循环 + # 从路径中去除库名部分 path = path.partition(lib_name)[0] - if path: + if path: # 如果路径非空 + # 将路径的绝对路径并去除末尾斜杠后添加到path_list中 path_list.append(os.path.abspath(path.strip() + "../")) + # 返回path_list中唯一的路径 return np.unique(path_list) - except subprocess.TimeoutExpired: - logger.warning("Failed to check cuda version due to the ldd command timeout, please confirm that " - "the correct cuda version has been installed, you can refer to the " - "installation guidelines: https://www.mindspore.cn/install") - return path_list + except subprocess.TimeoutExpired: # 捕获subprocess.TimeoutExpired异常 + # 记录警告日志,提示用户确认cuda版本是否正确安装,因为ldd命令超时 + logger.warning("由于ldd命令超时,无法检查cuda版本,请确认已安装正确的cuda版本," + "您可以参考安装指南:https://www.mindspore.cn/install") + return path_list # 返回空路径列表 def _read_version(self, file_path): """Get gpu version info in version.txt.""" @@ -211,70 +269,80 @@ class GPUEnvChecker(EnvChecker): class AscendEnvChecker(EnvChecker): - """ascend environment check""" + """Ascend 环境检查类""" def __init__(self): + # 初始化 Ascend 环境检查器的版本列表 self.version = ["1.81"] + + # 定义不同路径下的 version.info 文件位置 atlas_nnae_version = "/usr/local/Ascend/nnae/latest/fwkacllib/version.info" atlas_toolkit_version = "/usr/local/Ascend/ascend-toolkit/latest/fwkacllib/version.info" hisi_fwk_version = "/usr/local/Ascend/latest/fwkacllib/version.info" + + # 检查 Atlas NNAE 环境是否存在 if os.path.exists(atlas_nnae_version): - # atlas default path - self.fwk_path = "/usr/local/Ascend/nnae/latest/fwkacllib" - self.op_impl_path = "/usr/local/Ascend/nnae/latest/opp/op_impl/built-in/ai_core/tbe" - self.tbe_path = self.fwk_path + "/lib64" - self.cce_path = self.fwk_path + "/ccec_compiler/bin" - self.fwk_version = atlas_nnae_version - self.op_path = "/usr/local/Ascend/nnae/latest/opp" - self.aicpu_path = "/usr/local/Ascend/nnae/latest" + # 如果存在,设置默认路径 + self.fwk_path = "/usr/local/Ascend/nnae/latest/fwkacllib" # Framework 路径 + self.op_impl_path = "/usr/local/Ascend/nnae/latest/opp/op_impl/built-in/ai_core/tbe" # Operator 实现路径 + self.tbe_path = self.fwk_path + "/lib64" # TBE 库路径 + self.cce_path = self.fwk_path + "/ccec_compiler/bin" # CCE 编译器路径 + self.fwk_version = atlas_nnae_version # Framework 版本文件路径 + self.op_path = "/usr/local/Ascend/nnae/latest/opp" # Operator 路径 + self.aicpu_path = "/usr/local/Ascend/nnae/latest" # AI CPU 路径 + + # 检查 Atlas Toolkit 环境是否存在 elif os.path.exists(atlas_toolkit_version): - # atlas default path - self.fwk_path = "/usr/local/Ascend/ascend-toolkit/latest/fwkacllib" - self.op_impl_path = "/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe" - self.tbe_path = self.fwk_path + "/lib64" - self.cce_path = self.fwk_path + "/ccec_compiler/bin" - self.fwk_version = atlas_toolkit_version - self.op_path = "/usr/local/Ascend/ascend-toolkit/latest/opp" - self.aicpu_path = "/usr/local/Ascend/ascend-toolkit/latest" + # 如果存在,设置默认路径 + self.fwk_path = "/usr/local/Ascend/ascend-toolkit/latest/fwkacllib" # Framework 路径 + self.op_impl_path = "/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe" # Operator 实现路径 + self.tbe_path = self.fwk_path + "/lib64" # TBE 库路径 + self.cce_path = self.fwk_path + "/ccec_compiler/bin" # CCE 编译器路径 + self.fwk_version = atlas_toolkit_version # Framework 版本文件路径 + self.op_path = "/usr/local/Ascend/ascend-toolkit/latest/opp" # Operator 路径 + self.aicpu_path = "/usr/local/Ascend/ascend-toolkit/latest" # AI CPU 路径 + + # 检查 Hisi 环境是否存在 elif os.path.exists(hisi_fwk_version): - # hisi default path - self.fwk_path = "/usr/local/Ascend/latest/fwkacllib" - self.op_impl_path = "/usr/local/Ascend/latest/opp/op_impl/built-in/ai_core/tbe" - self.tbe_path = self.fwk_path + "/lib64" - self.cce_path = self.fwk_path + "/ccec_compiler/bin" - self.fwk_version = hisi_fwk_version - self.op_path = "/usr/local/Ascend/latest/opp" - self.aicpu_path = "/usr/local/Ascend/latest" + # 如果存在,设置默认路径 + self.fwk_path = "/usr/local/Ascend/latest/fwkacllib" # Framework 路径 + self.op_impl_path = "/usr/local/Ascend/latest/opp/op_impl/built-in/ai_core/tbe" # Operator 实现路径 + self.tbe_path = self.fwk_path + "/lib64" # TBE 库路径 + self.cce_path = self.fwk_path + "/ccec_compiler/bin" # CCE 编译器路径 + self.fwk_version = hisi_fwk_version # Framework 版本文件路径 + self.op_path = "/usr/local/Ascend/latest/opp" # Operator 路径 + self.aicpu_path = "/usr/local/Ascend/latest" # AI CPU 路径 + else: - # custom or unknown environment - self.fwk_path = "" - self.op_impl_path = "" - self.tbe_path = "" - self.cce_path = "" - self.fwk_version = "" - self.op_path = "" - self.aicpu_path = "" - - # env + # 如果以上环境都不存在,设置为空路径 + self.fwk_path = "" # Framework 路径 + self.op_impl_path = "" # Operator 实现路径 + self.tbe_path = "" # TBE 库路径 + self.cce_path = "" # CCE 编译器路径 + self.fwk_version = "" # Framework 版本文件路径 + self.op_path = "" # Operator 路径 + self.aicpu_path = "" # AI CPU 路径 + + # 初始化环境变量 self.path = os.getenv("PATH") self.python_path = os.getenv("PYTHONPATH") self.ld_lib_path = os.getenv("LD_LIBRARY_PATH") self.ascend_opp_path = os.getenv("ASCEND_OPP_PATH") self.ascend_aicpu_path = os.getenv("ASCEND_AICPU_PATH") - # check content + # 设置需要检查的路径内容 self.path_check = "/fwkacllib/ccec_compiler/bin" self.python_path_check = "opp/op_impl/built-in/ai_core/tbe" self.ld_lib_path_check_fwk = "/fwkacllib/lib64" self.ld_lib_path_check_addons = "/add-ons" self.ascend_opp_path_check = "/op" self.v = "" - def check_env(self, e): self._check_env() raise e def check_version(self): + # 检查指定路径的版本文件是否存在,如果不存在则跳过版本检查 if not Path(self.fwk_version).is_file(): logger.warning("Using custom Ascend AI software package (Ascend Data Center Solution) path, package " "version checking is skipped, please make sure Ascend AI software package (Ascend Data " @@ -282,40 +350,47 @@ class AscendEnvChecker(EnvChecker): "https://www.mindspore.cn/install") return + # 读取版本文件中的版本信息 v = self._read_version(self.fwk_version) + # 如果读取的版本不在支持的版本列表中,则记录警告信息 if v not in self.version: v_list = str([x for x in self.version]) logger.warning(f"MindSpore version {__version__} and Ascend AI software package (Ascend Data Center " f"Solution)version {v} does not match, the version of software package expect one of " f"{v_list}, please reference to the match info on: https://www.mindspore.cn/install") - def check_deps_version(self): """ te, topi, hccl wheel package version check in order to update the change of 'LD_LIBRARY_PATH' env, run a sub process """ + # 构建输入参数列表,包含mindspore版本和受支持的版本列表 input_args = ["--mindspore_version=" + __version__] for v in self.version: input_args.append("--supported_version=" + v) + # 获取依赖版本检查脚本的路径 deps_version_checker = os.path.join(os.path.split(os.path.realpath(__file__))[0], "_check_deps_version.py") + # 构建调用命令,包括python解释器路径、脚本路径和输入参数 call_cmd = [sys.executable, deps_version_checker] + input_args try: + # 运行子进程进行版本检查,设置超时时间为3秒,并捕获输出 process = subprocess.run(call_cmd, timeout=3, text=True, capture_output=True, check=False) + # 如果子进程的输出不为空,则记录警告信息并进行倒计时提醒 if process.stdout.strip() != "": logger.warning(process.stdout.strip()) warning_countdown = 3 for i in range(warning_countdown, 0, -1): logger.warning(f"Please pay attention to the above warning, countdown: {i}") time.sleep(1) + # 如果版本检查超时,则记录信息并跳过 except subprocess.TimeoutExpired: logger.info("Package te, topi, hccl version check timed out, skip.") - def set_env(self): + # 设置Ascend环境变量 if not self.tbe_path: self._check_env() return - + try: import te # pylint: disable=unused-import # pylint: disable=broad-except @@ -329,32 +404,35 @@ class AscendEnvChecker(EnvChecker): raise EnvironmentError( f"No such directory: {self.tbe_path}, Please check if Ascend AI software package (Ascend Data " "Center Solution) is installed correctly.") - - # check te version after set te env + + # 检查te版本 self.check_deps_version() - + + # 设置op实现路径环境变量 if Path(self.op_impl_path).is_dir(): - # python path for sub process + # python路径用于子进程 if os.getenv('PYTHONPATH'): os.environ['PYTHONPATH'] = self.op_impl_path + ":" + os.environ['PYTHONPATH'] else: os.environ['PYTHONPATH'] = self.op_impl_path - # sys path for this process + # sys路径用于当前进程 sys.path.append(self.op_impl_path) - + os.environ['TBE_IMPL_PATH'] = self.op_impl_path else: raise EnvironmentError( - f"No such directory: {self.op_impl_path}, Please check if Ascend AI software package (Ascend Data " - "Center Solution) is installed correctly.") - + f"No such directory: {self.op_impl_path}, Please check if Ascend AI software package (Ascend Data Center " + "Solution) is installed correctly.") + + # 设置CCE路径环境变量 if Path(self.cce_path).is_dir(): os.environ['PATH'] = self.cce_path + ":" + os.environ['PATH'] else: raise EnvironmentError( f"No such directory: {self.cce_path}, Please check if Ascend AI software package (Ascend Data Center " "Solution) is installed correctly.") - + + # 设置OP路径环境变量 if self.op_path is None: pass elif Path(self.op_path).is_dir(): @@ -363,7 +441,8 @@ class AscendEnvChecker(EnvChecker): raise EnvironmentError( f"No such directory: {self.op_path}, Please check if Ascend AI software package (Ascend Data Center " "Solution) is installed correctly.") - + + # 设置AICPU路径环境变量 if self.aicpu_path is None: pass elif Path(self.aicpu_path).is_dir(): @@ -372,44 +451,54 @@ class AscendEnvChecker(EnvChecker): raise EnvironmentError( f"No such directory: {self.aicpu_path}, Please check if Ascend AI software package (Ascend Data Center" " Solution) is installed correctly.") - + def _check_env(self): """ascend dependence path check""" + # 检查是否设置正确的PATH环境变量 if self.path is None or self.path_check not in self.path: logger.warning("Can not find ccec_compiler(need by mindspore-ascend), please check if you have set env " "PATH, you can reference to the installation guidelines https://www.mindspore.cn/install") - + + # 检查是否设置正确的PYTHONPATH环境变量 if self.python_path is None or self.python_path_check not in self.python_path: logger.warning( "Can not find tbe op implement(need by mindspore-ascend), please check if you have set env " "PYTHONPATH, you can reference to the installation guidelines " "https://www.mindspore.cn/install") - + + # 检查是否设置正确的LD_LIBRARY_PATH环境变量 if self.ld_lib_path is None or not (self.ld_lib_path_check_fwk in self.ld_lib_path and self.ld_lib_path_check_addons in self.ld_lib_path): logger.warning("Can not find driver so(need by mindspore-ascend), please check if you have set env " "LD_LIBRARY_PATH, you can reference to the installation guidelines " "https://www.mindspore.cn/install") - + + # 检查是否设置正确的ASCEND_OPP_PATH环境变量 if self.ascend_opp_path is None or self.ascend_opp_path_check not in self.ascend_opp_path: logger.warning( "Can not find opp path (need by mindspore-ascend), please check if you have set env ASCEND_OPP_PATH, " "you can reference to the installation guidelines https://www.mindspore.cn/install") - def _read_version(self, file_path): """get ascend version info""" with open(file_path, 'r') as f: all_info = f.readlines() + # 遍历文件中的每一行 for line in all_info: + # 检查行是否以 "Version=" 开头 if line.startswith("Version="): + # 去除行末的换行符并按 "=" 分割, 获取版本号 full_version = line.strip().split("=")[1] + # 提取主版本号和次版本号, 并用 "." 连接 self.v = '.'.join(full_version.split('.')[0:2]) + # 返回版本号 return self.v + # 如果未找到版本信息, 返回 None 或默认值 return self.v def check_version_and_env_config(): - """check version and env config""" + """检查版本和环境配置""" + # 检查包名以确定使用哪种环境检查器 if __package_name__.lower() == "mindspore-ascend": env_checker = AscendEnvChecker() # Note: pre-load libgomp.so to solve error like "cannot allocate memory in statis TLS block" @@ -425,19 +514,21 @@ def check_version_and_env_config(): else: logger.info(f"Package version {__package_name__} does not need to check any environment variable, skipping.") return + # 检查是否关闭版本检查,如果已关闭则直接返回 if os.getenv("MS_DEV_CLOSE_VERSION_CHECK") == "ON": return + # 设置环境变量以关闭版本检查 os.environ["MS_DEV_CLOSE_VERSION_CHECK"] = "ON" - try: - # check version of ascend site or cuda + # 检查 ascend site 或 cuda 的版本 env_checker.check_version() from .. import _c_expression # pylint: disable=unused-import + # 设置环境 env_checker.set_env() except ImportError as e: + # 处理导入错误,检查环境 env_checker.check_env(e) - def _set_pb_env(): """Set env variable `PROTOCOL_BUFFERS` to prevent memory overflow.""" if os.getenv("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION") == "cpp": @@ -449,7 +540,9 @@ def _set_pb_env(): logger.info("Setting the env `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python` to prevent memory overflow " "during save or load checkpoint file.") os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" - - + +# 检查版本和环境配置 check_version_and_env_config() + +# 设置协议缓冲区的环境变量, 防止内存溢出 _set_pb_env() diff --git a/src/mindspore2022/mindspore/python/mindspore/run_check/run_check.py b/src/mindspore2022/mindspore/python/mindspore/run_check/run_check.py index 225cd92a..4838e654 100644 --- a/src/mindspore2022/mindspore/python/mindspore/run_check/run_check.py +++ b/src/mindspore2022/mindspore/python/mindspore/run_check/run_check.py @@ -26,23 +26,26 @@ def _check_mul(): """ from importlib import import_module import numpy as np - + try: ms = import_module("mindspore") except ModuleNotFoundError: ms = None finally: pass - + + # 打印MindSpore版本信息 print(f"MindSpore version: ", ms.__version__) - + + # 创建两个MindSpore张量,分别包含数组[1.0, 2.0, 3.0]和[4.0, 5.0, 6.0] input_x = ms.Tensor(np.array([1.0, 2.0, 3.0]), ms.float32) input_y = ms.Tensor(np.array([4.0, 5.0, 6.0]), ms.float32) + # 创建一个乘法操作对象 mul = ms.ops.Mul() + # 执行乘法操作 mul(input_x, input_y) + # 打印乘法计算结果正确,MindSpore安装成功的信息 print(f"The result of multiplication calculation is correct, MindSpore has been installed successfully!") - - def run_check(): """ Provide a convenient API to check if the installation is successful or failed. @@ -55,10 +58,13 @@ def run_check(): The result of multiplication calculation is correct, MindSpore has been installed successfully! """ try: + # 尝试执行检查乘法操作的函数 _check_mul() # pylint: disable=broad-except + # 捕获所有异常并打印错误信息 except Exception as e: print("MindSpore running check failed.") print(str(e)) finally: - pass + # 无论是否发生异常,都会执行此部分代码 + pass # 执行乘法检查的函数,并处理可能的异常情况。如果检查失败,打印错误信息。 \ No newline at end of file