diff --git a/qin.txt b/qin.txt new file mode 100644 index 00000000..b87bfcea --- /dev/null +++ b/qin.txt @@ -0,0 +1,38 @@ +import os +import stat +import fcntl +import platform +from logging.handlers import RotatingFileHandler + +class _MultiCompatibleRotatingFileHandler(RotatingFileHandler): + """Inherit RotatingFileHandler for multiprocess compatibility. + + 这个类继承自`RotatingFileHandler`,是为了在多进程环境下安全地使用日志回滚功能。 + 在多进程环境下,多个进程可能会同时尝试写入或回滚日志文件,这可能会导致文件损坏或数据丢失。 + 通过在这个类中对相关方法进行重写,确保了日志文件在多进程环境下的正确处理。 + """ + + def doRollover(self): + """Override doRollover for multiprocess compatibility + and setting permission of Log file + + 这个方法重写了`RotatingFileHandler`中的`doRollover`方法,增加了多进程兼容性, + 并设置了日志文件的权限。 + + 1. 使用`fcntl`模块获得独占锁,确保在回滚日志文件时不会有其他进程进行写操作。 + 2. 设置日志文件的权限,以确保日志文件的安全性。 + 3. 调用父类的`doRollover`方法执行实际的日志回滚操作。 + 4. 回滚后,修改日志文件的权限,使其可读可写。 + """ + + # Attain an exclusive lock with blocking mode by `fcntl` module. + with open(self.baseFilename, 'a') as file_pointer: + # 如果操作系统不是Windows,使用`fcntl`模块对文件加锁 + if platform.system() != "Windows": + fcntl.lockf(file_pointer.fileno(), fcntl.LOCK_EX) + # 设置日志文件权限为只读,增加安全性 + os.chmod(self.baseFilename, stat.S_IREAD) + # 调用父类的`doRollover`方法执行日志回滚操作 + super().doRollover() + # 修改日志文件的权限为可读可写,以便后续的日志写入操作 + os.chmod(self.baseFilename, stat.S_IREAD | stat.S_IWRITE) diff --git a/src/mindspore2022/mindspore/python/mindspore/communication/_comm_helper.py b/src/mindspore2022/mindspore/python/mindspore/communication/_comm_helper.py index 05b3c51c..28802f9f 100644 --- a/src/mindspore2022/mindspore/python/mindspore/communication/_comm_helper.py +++ b/src/mindspore2022/mindspore/python/mindspore/communication/_comm_helper.py @@ -18,24 +18,31 @@ from mindspore.parallel._ps_context import _is_role_pserver, _is_role_sched from mindspore import log as logger from ._hccl_management import load_lib as hccl_load_lib from .._c_expression import get_rank_id, get_rank_size - + +# 检查HCCL是否可用 _HCCL_AVAILABLE = False +# 检查HCCL测试是否可用 _HCCL_TEST_AVAILABLE = False +# 检查NCCL是否可用 _NCCL_AVAILABLE = False +# 检查MPI是否可用 _MPI_AVAILABLE = False try: + # 尝试导入mindspore._ms_mpi,如果成功则NCCL可用 import mindspore._ms_mpi as mpi _NCCL_AVAILABLE = True except ImportError: + # 如果导入失败,则NCCL不可用 _NCCL_AVAILABLE = False - - + +# 尝试加载 HCCL 库,如果成功则设置 _HCCL_AVAILABLE 为 True,否则捕获 RuntimeError 并设置为 False try: hccl_load_lib() _HCCL_AVAILABLE = True except RuntimeError: _HCCL_AVAILABLE = False - + +# 如果 HCCL 可用,则导入 _hccl_management 并尝试导入 mindspore._ascend_mpi,如果成功则设置 _MPI_AVAILABLE 为 True,否则捕获 ImportError 并设置为 False if _HCCL_AVAILABLE: from . import _hccl_management as hccl try: @@ -43,6 +50,7 @@ if _HCCL_AVAILABLE: _MPI_AVAILABLE = True except ImportError: _MPI_AVAILABLE = False +# 如果 HCCL 不可用,则尝试导入 hccl_test.manage.api,如果成功则设置 _HCCL_AVAILABLE 和 _HCCL_TEST_AVAILABLE 为 True,否则捕获 ImportError 并设置 _HCCL_AVAILABLE 为 False else: try: import hccl_test.manage.api as hccl @@ -50,12 +58,11 @@ else: _HCCL_TEST_AVAILABLE = True except ImportError: _HCCL_AVAILABLE = False - - + +# 定义 HCCL 和 NCCL 的通信组名称常量 HCCL_WORLD_COMM_GROUP = "hccl_world_group" NCCL_WORLD_COMM_GROUP = "nccl_world_group" - class Backend: """ Class for available backends. @@ -82,13 +89,17 @@ class Backend: def __new__(cls, name): """Create instance object of Backend.""" + # 检查传入的name是否为字符串类型 if not isinstance(name, str): raise TypeError("For 'Backend', the class variable 'name' must be a string, " "but got the type : {}".format(type(name))) + # 获取对应name的大写形式的Backend类属性值,如果不存在则返回Backend.UNDEFINED value = getattr(Backend, name.upper(), Backend.UNDEFINED) + # 如果获取到的值是Backend.UNDEFINED,说明传入的name不被支持 if value == Backend.UNDEFINED: raise ValueError("For 'Backend', the class variable 'name' {} is not supported, " "please use hccl or nccl.".format(name)) + # 返回获取到的Backend类属性值 return value DEFAULT_BACKEND = Backend("hccl") @@ -97,42 +108,41 @@ DEFAULT_BACKEND = Backend("hccl") class GlobalComm: """ World communication information. The GlobalComm is a global class. The members contain: - + - BACKEND: The communication library used, using HCCL/NCCL. - WORLD_COMM_GROUP: Global communication domain. """ BACKEND = DEFAULT_BACKEND WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP - INITED = False - CHECK_ENVS = True - - + INITED = False # 标记全局通信是否已初始化 + CHECK_ENVS = True # 标记是否需要检查通信环境变量 + + class _ExistingGroup: """ - The communication groups which exist in the progress. + 用于表示在程序运行过程中存在的通信组。 """ - ITEMS = {} + ITEMS = {} # 存储通信组的字典,键为通信组的标识符,值为通信组对象 def is_hccl_available(): """ - Check HCCL api is available. + 检查HCCL API是否可用。 Returns: - Boolean. Return whether HCCL is available or not. + Boolean: 返回HCCL是否可用。 """ - return _HCCL_AVAILABLE + return _HCCL_AVAILABLE # 返回一个布尔值,指示HCCL是否可用 def is_mpi_available(): """ - Check HCCL & MPI api is available. + 检查HCCL和MPI API是否可用。 Returns: - Boolean. Return whether HCCL & MPI is available or not. + Boolean: 返回HCCL和MPI是否同时可用。 """ - return _MPI_AVAILABLE - + return _MPI_AVAILABLE # 返回一个布尔值,指示HCCL和MPI是否同时可用 def is_nccl_available(): """ @@ -158,19 +168,25 @@ def check_parameter_available(func): Wrapper. If not available, raise Error. """ def wrapper(*args, **kargs): + # 如果当前角色是参数服务器或者调度器,直接调用被装饰的函数 if _is_role_pserver() or _is_role_sched(): return func(*args, **kargs) + # 检查分布式通信是否已经初始化,未初始化则抛出异常 if not GlobalComm.INITED: raise RuntimeError("Distributed Communication has not been inited") + # 获取参数组,默认为None group = None + # 检查关键字参数中是否包含"group",并进行类型检查 if "group" in kargs.keys(): group = kargs.get("group") if group is not None and not isinstance(group, str): raise TypeError("The parameter 'group' should be str or None, " "but got the type : {}".format(type(group))) + # 获取后端,默认为None if "backend" in kargs.keys(): backend = kargs.get("backend") + # 检查后端是否可用,不可用则抛出异常 if backend is Backend.HCCL and not is_hccl_available(): raise RuntimeError("Distributed Communication doesn't have HCCL built in") if backend is Backend.HCCL_MPI and not is_mpi_available(): @@ -178,15 +194,16 @@ def check_parameter_available(func): if backend is Backend.NCCL and not is_nccl_available(): raise RuntimeError("Distributed Communication doesn't have NCCL built in") + # 如果未指定group,根据backend设置默认的group if group is None: - if backend is Backend.HCCL or Backend.HCCL_MPI: + if backend is Backend.HCCL or backend is Backend.HCCL_MPI: group = HCCL_WORLD_COMM_GROUP elif backend is Backend.NCCL: group = NCCL_WORLD_COMM_GROUP + # 调用被装饰的函数 return func(*args, **kargs) return wrapper - @check_parameter_available def _get_rank_helper(group, backend): """ @@ -202,10 +219,13 @@ def _get_rank_helper(group, backend): Returns: Integer. The local rank id of the calling process. """ + # 辅助函数,用于根据不同的后端和组获取 rank_id + # 获取当前角色的rank_id,如果是参数服务器或调度器角色,rank_id设为0并返回 rank_id = None if _is_role_pserver() or _is_role_sched(): rank_id = 0 return rank_id + # 根据不同的后端获取rank_id if backend == Backend.HCCL_MPI: rank_id = mpi.get_rank_id(group) elif backend == Backend.HCCL: @@ -216,6 +236,7 @@ def _get_rank_helper(group, backend): elif backend == Backend.NCCL: rank_id = get_rank_id(group) else: + # 如果后端不被支持,抛出ValueError异常 raise ValueError("For '_get_rank_helper', the argument 'backend' {} is not supported, " "please use hccl_mpi, hccl or nccl.".format(backend)) return rank_id @@ -236,22 +257,30 @@ def _get_local_rank_helper(group, backend): Returns: Integer. The local rank id of the calling process. """ + # 获取当前进程的rank id,根据不同的后端和组进行处理 rank_id = None + # 根据不同的后端选择获取rank_id的方法 if backend == Backend.HCCL_MPI: + # 使用HCCL MPI后端时,通过mpi.get_rank_id获取rank_id rank_id = mpi.get_rank_id(group) elif backend == Backend.HCCL: + # 使用HCCL后端时,根据group的不同选择获取rank_id的方法 if group == HCCL_WORLD_COMM_GROUP: + # 如果group是HCCL_WORLD_COMM_GROUP,则使用hccl.get_local_rank_id获取rank_id rank_id = hccl.get_local_rank_id() else: + # 对于其他group,同样使用hccl.get_local_rank_id获取rank_id rank_id = hccl.get_local_rank_id(group) elif backend == Backend.NCCL: + # 如果使用NCCL后端,当前不支持get_local_rank_id方法,抛出异常 raise RuntimeError("Nccl doesn't support get_local_rank_id now.") else: + # 如果backend既不是HCCL_MPI也不是HCCL,抛出异常表示不支持的backend raise ValueError("For '_get_local_rank_helper', the argument 'backend' {} is not supported, " "please use hccl_mpi or hccl.".format(backend)) + # 返回获取到的rank_id return rank_id - @check_parameter_available def _get_size_helper(group, backend): """ @@ -268,9 +297,12 @@ def _get_size_helper(group, backend): Integer. The rank size of specified group. """ size = None + # 如果当前角色是参数服务器或调度器,则将size设为1并返回 if _is_role_pserver() or _is_role_sched(): size = 1 return size + # 根据不同的后端设置size的值 + # 根据不同的后端获取组的大小 if backend == Backend.HCCL_MPI: size = mpi.get_rank_size(group) elif backend == Backend.HCCL: @@ -302,19 +334,23 @@ def _get_local_size_helper(group, backend): Integer. The local rank size where the calling process is being within specified group. """ size = None + # 根据不同的后端获取本地进程组的大小 if backend == Backend.HCCL: + # 如果组是全局通信组,则获取全局通信组的本地排名大小 if group == HCCL_WORLD_COMM_GROUP: size = hccl.get_local_rank_size() + # 否则获取指定组的本地排名大小 else: size = hccl.get_local_rank_size(group) + # NCCL后端不支持获取本地排名大小,抛出异常 elif backend == Backend.NCCL: raise RuntimeError("Nccl doesn't support get_local_rank_size now.") + # 对于不支持的后端,抛出异常 else: raise ValueError("For '_get_local_size_helper', the argument 'backend' {} is not supported, " "please use hccl.".format(backend)) return size - @check_parameter_available def _get_world_rank_from_group_rank_helper(group, group_rank_id, backend): """ @@ -333,21 +369,26 @@ def _get_world_rank_from_group_rank_helper(group, group_rank_id, backend): Integer. A rank id in world communication group. """ world_rank_id = None + # 检查 group_rank_id 是否为整数类型,如果不是则抛出 TypeError if not isinstance(group_rank_id, int): raise TypeError("For 'get_world_rank_from_group_rank', the argument 'group_rank_id' must be" " type of int, but got 'group_rank_id' type : {}.".format(type(group_rank_id))) + # 根据不同的后端选择不同的逻辑处理方式 if backend == Backend.HCCL: + # 如果在 GPU 上使用 HCCL,但 group 参数为 HCCL_WORLD_COMM_GROUP,则抛出 ValueError if group == HCCL_WORLD_COMM_GROUP: raise ValueError("For 'get_world_rank_from_group_rank' on GPU, the argument 'group' " "should be 'NCCL_WORLD_COMM_GROUP', but got 'HCCL_WORLD_COMM_GROUP'.") + # 调用 hccl.get_world_rank_from_group_rank 方法获取 world_rank_id world_rank_id = hccl.get_world_rank_from_group_rank(group, group_rank_id) elif backend == Backend.NCCL: + # 如果使用 NCCL,则抛出 RuntimeError 表示不支持该操作 raise RuntimeError("Nccl doesn't support get_world_rank_from_group_rank now.") else: + # 如果 backend 参数不支持,则抛出 ValueError 表示不支持该后端 raise ValueError("The argument 'backend' {} is not supported, please use hccl.".format(backend)) + # 返回获取的 world_rank_id return world_rank_id - - @check_parameter_available def _get_group_rank_from_world_rank_helper(world_rank_id, group, backend): """ @@ -366,21 +407,27 @@ def _get_group_rank_from_world_rank_helper(world_rank_id, group, backend): Integer. A rank id in user communication group. """ group_rank_id = None + # 检查 world_rank_id 是否为整数类型,如果不是则抛出 TypeError if not isinstance(world_rank_id, int): raise TypeError("For 'get_group_rank_from_world_rank', the argument 'world_rank_id' must be type of int, " "but got 'world_rank_id' type : {}.".format(type(world_rank_id))) + # 根据不同的后端处理获取 group_rank_id 的逻辑 if backend == Backend.HCCL: + # 检查 GPU 后端的 group 参数是否正确,如果不正确则抛出 ValueError if group == HCCL_WORLD_COMM_GROUP: raise ValueError("For 'get_group_rank_from_world_rank' on GPU, the argument 'group' " "should be 'NCCL_WORLD_COMM_GROUP', but got 'HCCL_WORLD_COMM_GROUP'.") + # 调用 hccl 模块的函数获取 group_rank_id group_rank_id = hccl.get_group_rank_from_world_rank(world_rank_id, group) + # NCCL 后端不支持此操作,抛出 RuntimeError elif backend == Backend.NCCL: raise RuntimeError("Nccl doesn't support get_group_rank_from_world_rank now.") + # 如果后端不是 HCCL 或 NCCL,则抛出 ValueError 表示不支持的后端 else: raise ValueError("The argument 'backend' {} is not supported, please use hccl.".format(backend)) + # 返回获取到的 group_rank_id return group_rank_id - @check_parameter_available def _create_group_helper(group, rank_ids, backend): """ @@ -395,34 +442,46 @@ def _create_group_helper(group, rank_ids, backend): TypeError: If rank_ids is not a list. ValueError: If rank_ids size is not larger than 1 or rank_ids has duplicate data or backend is invalid. """ + # 检查组是否已存在 if group in _ExistingGroup.ITEMS.keys(): + # 如果组已存在且提供的rank_ids与存储的不一致,抛出异常 if rank_ids != _ExistingGroup.ITEMS[group]: raise ValueError("The group {} has been created, the rank_list is {}, " "but current rank_list for the group is {}". format(group, _ExistingGroup.ITEMS[group], rank_ids)) + # 记录警告信息,提示组已存在 logger.warning("%r group has existed.", group) return + + # 根据不同的后端创建组 if backend == Backend.HCCL: + # 检查rank_ids是否为列表类型 if not isinstance(rank_ids, list): raise TypeError("For 'create_group', the argument 'rank_ids' must be type of list, " "but got 'rank_ids' type : {}.".format(type(rank_ids))) + # 检查rank_ids的长度是否大于1 rank_size = len(rank_ids) if rank_size < 1: raise ValueError("For 'create_group', the argument 'rank_ids' size should be greater than 1, " "but got 'rank_ids' size : {}.".format(len(rank_ids))) + # 检查rank_ids中是否有重复的元素 if len(rank_ids) - len(list(set(rank_ids))) > 0: raise ValueError("List rank_ids in Group {} has duplicate data!".format(group)) + # 使用HCCL创建组 hccl.create_group(group, rank_size, rank_ids) elif backend == Backend.HCCL_MPI: + # 使用HCCL_MPI创建组 mpi.create_group(group, rank_ids) elif backend == Backend.NCCL: + # NCCL暂不支持创建组,抛出异常 raise RuntimeError("Nccl doesn't support create_group now.") else: + # 如果后端不支持,抛出异常 raise ValueError("The context configuration parameter 'backend' {} is not supported, " "please use hccl.".format(backend)) - _ExistingGroup.ITEMS[group] = rank_ids - + # 将新创建的组及其rank_ids添加到_existingGroup中 + _ExistingGroup.ITEMS[group] = rank_ids @check_parameter_available def _destroy_group_helper(group, backend): """ @@ -435,12 +494,17 @@ def _destroy_group_helper(group, backend): Raises: ValueError: If group is "hccl_world_group" or backend is invalid. """ + # 根据后端类型销毁通信组 if backend == Backend.HCCL: + # 检查是否为 HCCL 的全局通信组 if group == HCCL_WORLD_COMM_GROUP: raise ValueError("The hccl_world_group does not support destruction.") + # 销毁指定的 HCCL 通信组 hccl.destroy_group(group) elif backend == Backend.NCCL: + # 当前 NCCL 后端不支持销毁通信组 raise RuntimeError("Nccl doesn't support destroy_group now.") else: + # 抛出错误,表示不支持的后端类型 raise ValueError("The context configuration parameter 'backend' {} is not supported, " - "please use hccl.".format(backend)) + "please use hccl.".format(backend)) \ No newline at end of file diff --git a/src/mindspore2022/mindspore/python/mindspore/communication/_hccl_management.py b/src/mindspore2022/mindspore/python/mindspore/communication/_hccl_management.py index 92dbd698..1841b222 100644 --- a/src/mindspore2022/mindspore/python/mindspore/communication/_hccl_management.py +++ b/src/mindspore2022/mindspore/python/mindspore/communication/_hccl_management.py @@ -24,7 +24,7 @@ MAX_RANK_NUM = 4096 HCCL_LIB = 'libhccl_plugin.so' HCCL_LIB_CTYPES = "" - +# 检查集体通信组的名称是否合法 def check_group(group): """ A function that check if a collection communication group is legal. @@ -41,23 +41,31 @@ def check_group(group): raise TypeError("The type of communication group name must be type of string, " "but got 'group' type : {}.".format(type(group))) - +# 检查集体通信中的排名编号是否合法 def check_rank_num(rank_num): """ - A function that check if a collection communication rank number is legal.If not raise error. + 检查通信集合中的排名编号是否合法。如果不合法则抛出错误。 + + 参数: + rank_num: 需要检查的排名编号,预期为整数类型。 Returns: - None + None: 该函数不返回任何值,但可能会抛出异常。 """ - if isinstance(rank_num, (int)): + # 检查 rank_num 是否为整数类型 + if isinstance(rank_num, int): + # 检查 rank_num 是否在合法范围内(大于0且小于等于 MAX_RANK_NUM) if rank_num > MAX_RANK_NUM or rank_num <= 0: - raise ValueError("For 'create_group', the size of argument 'rand_ids' should be greater than 0 and" - "less than {}, but got the size of 'rank_ids' : {}.".format(MAX_RANK_NUM, rank_num)) + # 如果 rank_num 不在合法范围内,抛出 ValueError 异常,并提供详细的错误信息 + raise ValueError("对于 'create_group' 函数,参数 'rank_ids' 的大小必须大于0且" + "小于等于 {},但得到的 'rank_ids' 的大小为: {}。".format(MAX_RANK_NUM, rank_num)) else: - raise TypeError("The argument 'rank_num' must be type of int, " - "but got 'rank_num' type : {}.".format(type(rank_num))) + # 如果 rank_num 不是整数类型,抛出 TypeError 异常,并提供详细的错误信息 + raise TypeError("参数 'rank_num' 必须为整数类型," + "但得到的 'rank_num' 类型为: {}。".format(type(rank_num))) +#检查集体通信中的排名标识(rank id)是否合法 def check_rank_id(rank_id): """ A function that check if a collection communication rank id is legal.If not raise error. @@ -65,40 +73,48 @@ def check_rank_id(rank_id): Returns: None """ + # 检查rank_id是否为整数类型 if isinstance(rank_id, (int)): + # 检查rank_id是否在有效范围内 if rank_id >= MAX_RANK_NUM or rank_id < 0: raise ValueError("The rand id in the communication group must be greater or equal 0 and " "less than {}, but got type value : {}.".format(MAX_RANK_NUM, rank_id)) else: + # 如果rank_id不是整数类型,抛出类型错误 raise TypeError("The rand id in the communication group must be must be type of int, " "but got type value : {}.".format(type(rank_id))) - +# 加载 HCCL(Huawei Cloud Communication Library)库 def load_lib(): """load hccl lib""" try: + # 获取当前文件所在的目录 base_dir = os.path.dirname(os.path.realpath(__file__)) + # 构建库文件的路径 lib_path = os.path.join(base_dir, "../lib", HCCL_LIB) + # 加载库文件 hccl_lib = ctypes.CDLL(lib_path) except Exception: + # 如果加载失败则抛出运行时错误 raise RuntimeError('Get hccl lib error.') - + + # 将加载的库文件设置为全局变量 global HCCL_LIB_CTYPES HCCL_LIB_CTYPES = hccl_lib - - + def c_str(string): """Convert a python string to C string.""" + # 将字符串转换为C风格字符串 if not isinstance(string, str): string = string.decode('ascii') return ctypes.c_char_p(string.encode('utf-8')) - - + + def c_array(ctype, values): """Create ctypes array from a python array.""" + # 从Python数组创建ctypes数组 return (ctype * len(values))(*values) - - +#用于创建包含指定数量和ID的HCCL通信组,但不能创建世界组。 def create_group(group, rank_num, rank_ids): """ Create group. @@ -112,28 +128,38 @@ def create_group(group, rank_num, rank_ids): Returns: None """ + # 检查组的有效性 check_group(group) + # 检查排名数量的有效性 check_rank_num(rank_num) + # 检查rank_ids是否为列表类型 if isinstance(rank_ids, (list)): + # 确保rank_num与rank_ids的长度一致 if rank_num != len(rank_ids): raise ValueError("The argument 'rank_num' number should be equal to the length " "of rank_ids, but got 'rank_num' value : {} and 'rank_ids' value : {}." .format(rank_num, rank_ids)) + # 检查rank_ids中的每个元素是否为非负整数 for rank_id in rank_ids: if not isinstance(rank_id, (int)) or rank_id < 0: raise ValueError("The elements of argument 'rank_ids' must be " "unsigned integer, but got the type : {}".format(type(rank_id))) + # 将rank_ids转换为C类型的数组 c_array_rank_ids = c_array(ctypes.c_uint, rank_ids) + # 将rank_num转换为C类型的无符号整数 c_rank_num = ctypes.c_uint(rank_num) + # 将group转换为C类型的字符串 c_group = c_str(group) + # 调用HCCL库创建组 ret = HCCL_LIB_CTYPES.HcomCreateGroup(c_group, c_rank_num, c_array_rank_ids) + # 检查组创建是否成功 if ret != 0: raise RuntimeError('Create group error, the error code is {}.'.format(ret)) else: + # 如果rank_ids不是列表类型,抛出类型错误 raise TypeError("For 'create_group', the argument 'rank_ids' must be type of list, " "but got 'rank_ids' type : {}.".format(type(rank_ids))) - - +#用于销毁用户创建的HCCL组 def destroy_group(group): """ A function that destroy the group which created by user. @@ -144,11 +170,16 @@ def destroy_group(group): Returns: None """ - check_group(group) - c_group = c_str(group) - ret = HCCL_LIB_CTYPES.HcomDestroyGroup(c_group) - if ret != 0: - raise RuntimeError('Destroy group error.') + # 检查传入的组是否有效 +check_group(group) +# 将组名转换为C风格的字符串 +c_group = c_str(group) +# 调用HCCL库中的函数销毁指定的组 +ret = HCCL_LIB_CTYPES.HcomDestroyGroup(c_group) +# 如果返回值不为0,说明销毁组时发生了错误,抛出异常 +if ret != 0: + raise RuntimeError('Destroy group error.') + def get_rank_size(group="hccl_world_group"): @@ -162,16 +193,23 @@ def get_rank_size(group="hccl_world_group"): An integer scalar with the num of ranks. """ + # 根据上下文的模式判断是否为PYNATIVE_MODE模式,若是,则直接返回HCCL的rank size if context.get_context("mode") == context.PYNATIVE_MODE: return get_hccl_rank_size() - + + # 检查给定的组是否有效 check_group(group) + # 将组名转换为C字符串格式 c_group = c_str(group) + # 定义一个C类型的无符号整数用于存储rank size c_rank_size = ctypes.c_uint() + # 调用HCCL库的HcomGetRankSize函数获取组内的rank size ret = HCCL_LIB_CTYPES.HcomGetRankSize(c_group, ctypes.byref(c_rank_size)) + # 如果返回值不为0,表示获取rank size失败,抛出运行时错误 if ret != 0: raise RuntimeError('Get rank size error.') - + + # 返回获取到的rank size值 return c_rank_size.value @@ -186,17 +224,22 @@ def get_rank_id(group="hccl_world_group"): if context.get_context("mode") == context.PYNATIVE_MODE: return get_hccl_rank_id() + # 检查组的有效性 check_group(group) + # 将组转换为 C 字符串格式 c_group = c_str(group) + # 定义一个用于存储 rank id 的 ctypes 无符号整数 c_rank_id = ctypes.c_uint() + # 调用 HCCL 库获取当前进程的 rank id ret = HCCL_LIB_CTYPES.HcomGetRankId(c_group, ctypes.byref(c_rank_id)) + # 如果返回值不为 0,表示获取 rank id 出错,抛出 RuntimeError 异常 if ret != 0: raise RuntimeError('Get rank id error.') + # 返回获取到的 rank id 值 return c_rank_id.value - def get_local_rank_size(group="hccl_world_group"): """ A function that returns the number of local ranks within the given collection communication group. @@ -207,19 +250,25 @@ def get_local_rank_size(group="hccl_world_group"): Returns: An integer scalar with the num of local ranks. """ + # 检查当前运行模式是否为PYNATIVE_MODE,如果是则抛出异常 if context.get_context("mode") is context.PYNATIVE_MODE: raise RuntimeError("The function 'get_local_rank_size' is not supported in PYNATIVE_MODE, " "'get_local_rank_size' only support GRAPH_MODE") + # 验证传入的组是否有效 check_group(group) + # 将组名称转换为C字符串格式 c_group = c_str(group) + # 定义一个ctypes的无符号整数变量,用于存储本地排名大小 c_local_rank_size = ctypes.c_uint() + # 调用HCCL库中的HcomGetLocalRankSize函数获取本地排名大小 ret = HCCL_LIB_CTYPES.HcomGetLocalRankSize(c_group, ctypes.byref(c_local_rank_size)) + # 如果返回值不为0,说明获取本地排名大小时出错,抛出异常 if ret != 0: raise RuntimeError('Get local rank size error.') + # 返回获取到的本地排名大小 return c_local_rank_size.value - def get_local_rank_id(group="hccl_world_group"): """ Get local rank id. @@ -230,16 +279,23 @@ def get_local_rank_id(group="hccl_world_group"): An integer scalar with the local rank id of the calling process. """ + # 检查当前运行模式是否为PYNATIVE_MODE,如果是则抛出异常 if context.get_context("mode") is context.PYNATIVE_MODE: raise RuntimeError("The function 'get_local_rank_id' is not supported in PYNATIVE_MODE, " "'get_local_rank_id' only support GRAPH_MODE") + # 验证群组的有效性 check_group(group) + # 将群组名称转换为C字符串格式 c_group = c_str(group) + # 定义一个无符号整型的C类型变量来存储本地排名ID c_local_rank_id = ctypes.c_uint() + # 调用HCCL库的HcomGetLocalRankId函数获取本地排名ID ret = HCCL_LIB_CTYPES.HcomGetLocalRankId(c_group, ctypes.byref(c_local_rank_id)) + # 如果返回值不为0,表示获取本地排名ID失败,抛出异常 if ret != 0: raise RuntimeError('Get local rank id error.') + # 返回获取到的本地排名ID值 return c_local_rank_id.value @@ -256,18 +312,25 @@ def get_world_rank_from_group_rank(group, group_rank_id): if context.get_context("mode") is context.PYNATIVE_MODE: raise RuntimeError("The function 'get_world_rank_from_group_rank' is not supported in PYNATIVE_MODE, " "'get_world_rank_from_group_rank' only support GRAPH_MODE") + # 检查组名是否有效 check_group(group) + # 检查组内rank ID是否有效 check_rank_id(group_rank_id) + # 将组名转换为C字符串格式 c_group = c_str(group) + # 将组内rank ID转换为C的无符号整数类型 c_group_rank_id = ctypes.c_uint(group_rank_id) + # 声明一个用于存储世界rank ID的C的无符号整数类型变量 c_world_rank_id = ctypes.c_uint() + # 调用HCCL库中的HcomGetWorldRankFromGroupRank函数,根据组名和组内rank ID获取对应的世界rank ID ret = HCCL_LIB_CTYPES.HcomGetWorldRankFromGroupRank(c_group, c_group_rank_id, ctypes.byref(c_world_rank_id)) + # 如果返回值不为0,说明函数调用出错,抛出RuntimeError异常 if ret != 0: - raise RuntimeError('Get world rank from group rank error.') + raise RuntimeError('根据组内rank ID获取世界rank ID时出错。') + # 返回获取到的世界rank ID的值 return c_world_rank_id.value - def get_group_rank_from_world_rank(world_rank_id, group): """ Get group rank from world rank. @@ -281,13 +344,21 @@ def get_group_rank_from_world_rank(world_rank_id, group): if context.get_context("mode") is context.PYNATIVE_MODE: raise RuntimeError("The function 'get_group_rank_from_world_rank' is not supported in PYNATIVE_MODE, " "'get_group_rank_from_world_rank' only support GRAPH_MODE") + # 检查组的有效性 check_group(group) + # 检查世界排名ID的有效性 check_rank_id(world_rank_id) + # 将组转换为C字符串 c_group = c_str(group) + # 将世界排名ID转换为C无符号整数 c_world_rank_id = ctypes.c_uint(world_rank_id) + # 创建一个用于存储组排名ID的C无符号整数 c_group_rank_id = ctypes.c_uint() + # 从世界排名ID获取组排名ID ret = HCCL_LIB_CTYPES.HcomGetGroupRankFromWorldRank(c_world_rank_id, c_group, ctypes.byref(c_group_rank_id)) + # 如果返回值不为0,抛出运行时错误 if ret != 0: raise RuntimeError('Get group rank from world rank error.') - return c_group_rank_id.value + # 返回获取到的组排名ID的值 + return c_group_rank_id.value \ No newline at end of file diff --git a/src/mindspore2022/mindspore/python/mindspore/communication/management.py b/src/mindspore2022/mindspore/python/mindspore/communication/management.py index a64276af..cba0dc9b 100644 --- a/src/mindspore2022/mindspore/python/mindspore/communication/management.py +++ b/src/mindspore2022/mindspore/python/mindspore/communication/management.py @@ -20,40 +20,51 @@ from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \ _create_group_helper, _destroy_group_helper, HCCL_WORLD_COMM_GROUP, NCCL_WORLD_COMM_GROUP, \ _get_local_rank_helper, _get_local_size_helper, GlobalComm from .._c_expression import init_hccl, finalize_hccl, init_gpu_collective - +# 导入mindspore上下文模块 +# 导入mindspore的进程服务器和调度器角色判断函数 +# 导入通信辅助模块,包括后端类型、获取和设置rank及size的辅助函数、创建和销毁组的辅助函数、以及世界通信组和本地通信组的相关标识 +# 导入C语言表达式模块,包含HCCL和NCCL的初始化和终结化函数,以及GPU集体通信的初始化函数 __all__ = ["init", "release", "get_rank", "get_local_rank", "get_group_size", "get_local_rank_size", "get_world_rank_from_group_rank", "get_group_rank_from_world_rank", "create_group", "destroy_group", "HCCL_WORLD_COMM_GROUP", "NCCL_WORLD_COMM_GROUP"] +# 默认的世界通信组 DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP def _get_group(group): - """Return the world communication group if the `group` is `DEFAULT_WORLD_COMM_GROUP`.""" + """根据输入的`group`参数返回相应的通信组。 + + 如果`group`等于`DEFAULT_WORLD_COMM_GROUP`,则返回全局通信组`GlobalComm.WORLD_COMM_GROUP`; + 否则,返回传入的`group`参数。 + """ if group == DEFAULT_WORLD_COMM_GROUP: - return GlobalComm.WORLD_COMM_GROUP - return group + return GlobalComm.WORLD_COMM_GROUP # 返回默认的世界通信组 + return group # 返回传入的通信组参数 def _check_task_sink_envs(): - """ - Check whether task_sink environment variables have been exported or not. - - return True if task_sink environment variables have been exported, False otherwise. + """检查任务接收器(task_sink)相关的环境变量是否已导出。 + + 该函数通过检查环境变量`GRAPH_OP_RUN`来判断任务接收器的环境变量是否已导出。 + + 返回值: + - 如果环境变量`GRAPH_OP_RUN`已导出且其值可以转换为整数1,则返回False,表示环境变量未正确设置为启用状态。 + - 如果环境变量`GRAPH_OP_RUN`未导出或其值不能转换为整数1,则返回True,表示环境变量未导出或设置有误。 """ import os - task_sink = os.getenv("GRAPH_OP_RUN") + task_sink = os.getenv("GRAPH_OP_RUN") # 获取名为"GRAPH_OP_RUN"的环境变量 if task_sink: try: - if int(task_sink) == 1: - return False + if int(task_sink) == 1: # 尝试将环境变量的值转换为整数并检查是否等于1 + return False # 如果等于1,返回False,表示环境变量已导出但设置为禁用状态(非预期情况) except ValueError: - return True + return True # 如果转换为整数失败,返回True,表示环境变量设置有误 finally: - pass - return True + pass # finally块中的代码在这里是空操作,通常用于清理操作 + return True # 如果环境变量未导出,返回True def _check_parallel_envs(): @@ -63,25 +74,31 @@ def _check_parallel_envs(): Raises: RuntimeError: If parallel environment variables have not been exported or have been exported to wrong values. """ + # 检查是否需要进行环境验证,如果不进行则直接返回 if not GlobalComm.CHECK_ENVS: return import os + # 获取环境变量RANK_ID的值 rank_id_str = os.getenv("RANK_ID") + # 如果RANK_ID未设置,抛出运行时错误 if not rank_id_str: raise RuntimeError("Environment variables RANK_ID has not been exported, please export variables 'RANK_ID'.") try: + # 尝试将RANK_ID转换为整数 int(rank_id_str) except ValueError: + # 如果转换失败,打印错误信息 print("Environment variables 'RANK_ID' should be number, but got the type : {}".format(type(rank_id_str))) finally: + # 无论是否发生异常,此块为空操作 pass + # 获取环境变量MINDSPORE_HCCL_CONFIG_PATH和RANK_TABLE_FILE的值 rank_table_file_str = os.getenv("MINDSPORE_HCCL_CONFIG_PATH") rank_table_file_str_old = os.getenv("RANK_TABLE_FILE") + # 如果两个环境变量都未设置,抛出运行时错误 if not rank_table_file_str and not rank_table_file_str_old: raise RuntimeError("Get hccl rank_table_file failed, " "please export MINDSPORE_HCCL_CONFIG_PATH or RANK_TABLE_FILE.") - - def init(backend_name=None): """ Initialize distributed backend, e.g. HCCL/NCCL, it is required before using the communication service. @@ -109,15 +126,20 @@ def init(backend_name=None): >>> from mindspore.communication import init >>> init() """ + # 检查当前角色是否为参数服务器或调度器,如果是则直接返回 if _is_role_pserver() or _is_role_sched(): return + # 检查任务接收环境变量,获取设备目标和模式 task_sink = _check_task_sink_envs() device_target = context.get_context("device_target") mode = context.get_context("mode") mpi_init = False + # 如果没有任务接收且模式为图模式,设置mpi_init为True if not task_sink and mode == context.GRAPH_MODE: mpi_init = True + # 根据设备目标选择后端名称,如果不支持则抛出异常 + # 根据设备目标设置默认的后端名称 if backend_name is None: if device_target == "Ascend": backend_name = "hccl" @@ -126,28 +148,34 @@ def init(backend_name=None): else: raise RuntimeError("For 'set_context', the argument 'device_target' {} is not supported in " "parallel initialization, please use Ascend or GPU.".format(device_target)) + # 检查后端名称是否为字符串,如果不是则抛出异常 if not isinstance(backend_name, str): raise TypeError("For 'init', the argument 'backend_name' must be a string, " "but got the type : {}".format(type(backend_name))) - + # 根据后端名称初始化通信环境 if backend_name == "hccl": + # 如果设备目标不是Ascend,抛出异常 if device_target != "Ascend": raise RuntimeError("For 'init', the argument 'backend_name' should be 'Ascend' to init hccl, " "but got {}".format(device_target)) + # 如果不需要MPI初始化,检查并行环境并设置后端名称 if not mpi_init: _check_parallel_envs() GlobalComm.BACKEND = Backend("hccl") else: GlobalComm.BACKEND = Backend("hccl_mpi") + # 初始化HCCL并设置全局通信组和初始化状态 init_hccl() GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP GlobalComm.INITED = True elif backend_name == "nccl": + # 初始化GPU集体通信并设置后端名称、全局通信组和初始化状态 init_gpu_collective() GlobalComm.BACKEND = Backend("nccl") GlobalComm.WORLD_COMM_GROUP = NCCL_WORLD_COMM_GROUP GlobalComm.INITED = True else: + # 如果后端名称不支持,抛出异常 raise RuntimeError("For 'init', the argument 'backend_name' must be nccl while 'device_target' is GPU, " "but got the 'backend_name' : hccl.") @@ -165,10 +193,11 @@ def release(): Examples: >>> from mindspore.communication import init, release - >>> init() - >>> release() + >>> init() # 初始化分布式通信环境 + >>> release() # 释放分布式通信资源 + """ - finalize_hccl() + finalize_hccl()# 结束 HCCL 的使用,释放相关资源 def get_rank(group=GlobalComm.WORLD_COMM_GROUP): @@ -197,12 +226,18 @@ def get_rank(group=GlobalComm.WORLD_COMM_GROUP): >>> print(rank_id) >>> # the result is the rank_id in world_group """ + # 检查传入的group参数是否为字符串类型 if not isinstance(group, str): + # 如果group参数不是字符串类型,则抛出TypeError异常 raise TypeError("For 'get_rank', the argument 'group' must be type of string, " "but got 'group' type : {}.".format(type(group))) + # 调用_get_rank_helper函数,获取指定group的rank ID + # _get_group函数用于解析group参数 + # GlobalComm.BACKEND变量存储了当前使用的通信后端 return _get_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND) + def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP): """ Gets local rank ID for current device in specified collective communication group. @@ -234,9 +269,11 @@ def get_local_rank(group=GlobalComm.WORLD_COMM_GROUP): >>> print("local_rank is: {}, world_rank is {}".format(local_rank, world_rank)) local_rank is: 1, world_rank is 9 """ + # 检查传入的 'group' 参数是否为字符串类型,如果不是则抛出 TypeError if not isinstance(group, str): raise TypeError("For 'get_local_rank', the argument 'group' must be type of string, " "but got 'group' type : {}.".format(type(group))) + # 调用辅助函数 _get_local_rank_helper 来获取本地排名,传入的参数为解析后的组和全局通信后端 return _get_local_rank_helper(group=_get_group(group), backend=GlobalComm.BACKEND) @@ -270,9 +307,11 @@ def get_group_size(group=GlobalComm.WORLD_COMM_GROUP): >>> print("group_size is: ", group_size) group_size is: 8 """ + # 检查传入的参数 'group' 是否为字符串类型 if not isinstance(group, str): raise TypeError("For 'get_group_size', the argument 'group' must be type of string, " "but got 'group' type : {}.".format(type(group))) + # 返回指定组的大小,使用辅助函数 _get_size_helper 和全局通信后端 return _get_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND) @@ -306,9 +345,11 @@ def get_local_rank_size(group=GlobalComm.WORLD_COMM_GROUP): >>> print("local_rank_size is: ", local_rank_size) local_rank_size is: 8 """ + # 检查传入的 'group' 参数是否为字符串类型,如果不是则抛出 TypeError if not isinstance(group, str): raise TypeError("For 'get_local_rank_size', the argument 'group' must be type of string, " "but got 'group' type : {}.".format(type(group))) + # 调用辅助函数获取本地组的大小,使用 _get_group 函数获取组,并使用 GlobalComm.BACKEND 作为后端 return _get_local_size_helper(group=_get_group(group), backend=GlobalComm.BACKEND) @@ -347,12 +388,12 @@ def get_world_rank_from_group_rank(group, group_rank_id): >>> print("world_rank_id is: ", world_rank_id) world_rank_id is: 4 """ + # 检查传入的 group 参数是否为字符串类型 if not isinstance(group, str): raise TypeError("For 'get_world_rank_from_group_rank', the argument 'group' must be type of string, " "but got 'group' type : {}.".format(type(group))) + # 返回根据组和组排名获取的世界排名的帮助函数的结果 return _get_world_rank_from_group_rank_helper(group=group, group_rank_id=group_rank_id, backend=GlobalComm.BACKEND) - - def get_group_rank_from_world_rank(world_rank_id, group): """ Get the rank ID in the specified user communication group corresponding to @@ -389,12 +430,13 @@ def get_group_rank_from_world_rank(world_rank_id, group): >>> print("group_rank_id is: ", group_rank_id) group_rank_id is: 1 """ + # 检查输入参数 'group' 是否为字符串类型,如果不是则抛出 TypeError if not isinstance(group, str): raise TypeError("For 'get_group_rank_from_world_rank', the argument 'group' must be type of string, " "but got 'group' type : {}.".format(type(group))) + # 调用辅助函数 _get_group_rank_from_world_rank_helper 来获取组的排名,并返回结果 return _get_group_rank_from_world_rank_helper(world_rank_id=world_rank_id, group=group, backend=GlobalComm.BACKEND) - - +#创建一个用户集体通信组,通过传入通信组名称和设备ID列表来实现。 def create_group(group, rank_ids): """ Create a user collective communication group. @@ -427,9 +469,11 @@ def create_group(group, rank_ids): >>> create_group(group, rank_ids) >>> allreduce = ops.AllReduce(group) """ + # 检查传入的 group 参数是否为字符串类型,如果不是则抛出 TypeError if not isinstance(group, str): raise TypeError("For 'create_group', the argument 'group' must be type of string, " "but got 'group' type : {}.".format(type(group))) + # 调用辅助函数 _create_group_helper 来创建组,使用指定的 rank_ids 和后端 _create_group_helper(group, rank_ids, backend=GlobalComm.BACKEND) @@ -450,7 +494,9 @@ def destroy_group(group): ValueError: If group is "hccl_world_group" or backend is invalid. RuntimeError: If HCCL is not available or MindSpore is GPU version. """ + # 检查传入的 group 参数是否为字符串类型,如果不是则抛出 TypeError if not isinstance(group, str): raise TypeError("For 'destroy_group', the argument 'group' must be type of string, " "but got 'group' type : {}.".format(type(group))) - _destroy_group_helper(group, backend=GlobalComm.BACKEND) + # 调用辅助函数 _destroy_group_helper 来销毁指定的组,并使用全局通信后端 + _destroy_group_helper(group, backend=GlobalComm.BACKEND) \ No newline at end of file diff --git a/src/mindspore2022/mindspore/python/mindspore/graph_utils/graph_pattern.py b/src/mindspore2022/mindspore/python/mindspore/graph_utils/graph_pattern.py index 32fcb19d..537c82cd 100644 --- a/src/mindspore2022/mindspore/python/mindspore/graph_utils/graph_pattern.py +++ b/src/mindspore2022/mindspore/python/mindspore/graph_utils/graph_pattern.py @@ -45,14 +45,15 @@ class OneOf(OneOf_): TypeError: raise type error for invalid inputs. """ self.patterns = patterns + # 检查 patterns 是否是 Pattern 类的实例 if isinstance(patterns, Pattern): OneOf_.__init__(self, [patterns]) + # 检查 patterns 是否是 tuple 或 list 类型,并且其中所有元素都是 Pattern 类的实例 elif isinstance(patterns, (tuple, list)) and all(isinstance(pattern, Pattern) for pattern in patterns): OneOf_.__init__(self, patterns) + # 如果 patterns 不符合上述两种情况,则抛出 TypeError 异常 else: raise TypeError(f"Expect patterns to be a list of Patterns/Pattern, got : {patterns}") - - class Prim(Prim_): r""" Express a pattern of certain primitive type(s). @@ -76,25 +77,33 @@ class Prim(Prim_): Raises: TypeError: raise type error for invalid argument. """ + # 检查name是否为字符串类型,如果不是则抛出TypeError if name is not None and not isinstance(name, str): raise TypeError(f"Expect string, got : {name}") self.name = name + # 如果types是字符串类型,则将其按'|'分割成列表 if isinstance(types, str): if self.name is None: self.name = types self.types = types.split('|') + # 如果types是Primitive类型,则直接将其放入列表中 elif isinstance(types, Primitive): if self.name is None: self.name = types.name self.types = [types] + # 如果 types 是元组或列表,并且其中所有元素都是 Primitive 类型 elif isinstance(types, (tuple, list)) and all(isinstance(tp, Primitive) for tp in types): + # 如果 self.name 为 None,则初始化为空字符串并拼接所有 Primitive 的 name if self.name is None: self.name = "" for prim in types: self.name += prim.name + # 设置 self.types 为传入的 types self.types = types + # 如果 types 不符合预期类型,抛出 TypeError else: raise TypeError(f"Expecting a primitive type string or a list of Primitives, got : {types}") + # 调用基类 Prim_ 的初始化方法,传入 self.types 和 self.name Prim_.__init__(self, self.types, self.name) @@ -115,16 +124,22 @@ class Call(Call_): Raises: TypeError: raise type error for invalid argument. """ + # 检查 prim_pattern 是否为 Pattern, str 或 Primitive 类型,如果不是则抛出 TypeError if not isinstance(prim_pattern, (Pattern, str, Primitive)): raise TypeError(f"Expect prim_pattern to be Pattern, Primitive or string, got : {prim_pattern}") self.prim_pattern = prim_pattern + # 初始化 inputs 列表 self.inputs = [] + # 如果 inputs 为 None,则不做任何操作 if inputs is None: pass + # 如果 inputs 是 tuple 或 list 并且其中所有元素都是 Pattern 类型,则赋值给 self.inputs elif isinstance(inputs, (tuple, list)) and all(isinstance(input, Pattern) for input in inputs): self.inputs = inputs + # 如果 inputs 不符合上述条件,则抛出 TypeError else: raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}") + # 调用父类 Call_ 的初始化方法,传入 self.prim_pattern 和 self.inputs Call_.__init__(self, self.prim_pattern, self.inputs) @@ -145,6 +160,7 @@ class NoneOf(NoneOf_): TypeError: raise type error for invalid argument. """ self.patterns = patterns + # 根据 patterns 的类型初始化 NoneOf_ 类 if patterns is None: NoneOf_.__init__(self, ()) elif isinstance(patterns, Pattern): @@ -154,7 +170,6 @@ class NoneOf(NoneOf_): else: raise TypeError(f"Expect list of Patterns/Pattern, got : {patterns}") - class NewTensor(NewTensor_): r""" New Tensor to be used in the target. @@ -167,13 +182,16 @@ class NewTensor(NewTensor_): Raises: TypeError: raise type error for invalid argument. """ + # 初始化输入张量 self.input_tensor = input_tensor + # 检查输入是否为 Tensor 类型 if isinstance(input_tensor, Tensor): + # 如果是 Tensor 类型,则调用 NewTensor_ 的初始化方法 NewTensor_.__init__(self, input_tensor) else: + # 如果不是 Tensor 类型,则抛出 TypeError 异常 raise TypeError(f"Expect input_tensor to be a Tensor, got : {input_tensor}") - class NewParameter(NewParameter_): r""" New Parameter to be used in the target. @@ -193,11 +211,14 @@ class NewParameter(NewParameter_): self.default_tensor = default_tensor self.requires_grad = requires_grad self.layerwise_parallel = layerwise_parallel + # 检查参数类型是否符合预期 if isinstance(para_name, str) and isinstance(default_tensor, Tensor) and isinstance(requires_grad, bool) and\ isinstance(layerwise_parallel, bool): + # 初始化 NewParameter_ 类 NewParameter_.__init__(self, self.para_name, self.default_tensor, self.requires_grad, self.layerwise_parallel) else: + # 如果参数类型不符合预期,抛出 TypeError raise TypeError(f"Expect para_name(str), default_tensor(Tensor), requires_grad(bool), \ layerwise_parallel(bool), got : {para_name}, {default_tensor}, \ - {requires_grad}, {layerwise_parallel}") + {requires_grad}, {layerwise_parallel}") \ No newline at end of file diff --git a/src/mindspore2022/mindspore/python/mindspore/graph_utils/python_pass/python_pass_register.py b/src/mindspore2022/mindspore/python/mindspore/graph_utils/python_pass/python_pass_register.py index 445c36a9..2266248a 100644 --- a/src/mindspore2022/mindspore/python/mindspore/graph_utils/python_pass/python_pass_register.py +++ b/src/mindspore2022/mindspore/python/mindspore/graph_utils/python_pass/python_pass_register.py @@ -39,6 +39,7 @@ class PyPassManager(PyPassManager_): TypeError: If argument has invalid type. """ def __init__(self, requires_grad=True, run_only_once=False): + # 初始化方法,接收两个布尔参数,设置实例的属性并调用父类的初始化方法 if not isinstance(requires_grad, bool): raise TypeError(f"Expect bool, got : ({type(requires_grad)}){requires_grad}") if not isinstance(run_only_once, bool): @@ -48,17 +49,20 @@ class PyPassManager(PyPassManager_): PyPassManager_.__init__(self) def register(self, py_pass): + # 注册一个Python pass,检查其是否为函数类型,并获取其模式和目标 if not isfunction(py_pass): raise TypeError(f"Expect function pass, got : ({type(py_pass)}){py_pass}") pattern, target = py_pass() pass_name = py_pass.__name__ + # 检查模式和目标是否为Pattern类型 if not isinstance(pattern, Pattern): raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}") if not isinstance(target, Pattern): raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}") + # 调用父类的register方法,注册pass及其相关信息 super().register(pass_name, pattern, target, self.requires_grad, self.run_only_once_) - def unregister(self, py_pass): + # 从注册表中移除指定的Python传递对象,可以是字符串形式的名称或函数对象 if isinstance(py_pass, str): super().unregister(py_pass) return @@ -66,27 +70,30 @@ class PyPassManager(PyPassManager_): super().unregister(py_pass.__name__) return raise TypeError(f"Expect py_pass to be string or function, got ({type(py_pass)}){py_pass}") - + def __call__(self, py_pass): + # 将Python传递对象注册到注册表中,并返回该对象 self.register(py_pass) return py_pass - + def gen_new_parameter(self, pattern): + # 根据给定的模式生成新的参数,模式必须是NewParameter类型 if not isinstance(pattern, NewParameter): raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}") super().gen_new_parameter(pattern) - + def set_renorm(self, should_renorm): + # 设置是否进行重归一化操作,参数必须是布尔值 if not isinstance(should_renorm, bool): raise TypeError(f"Expect should_renorm to be a bool, got {should_renorm}") super().set_renorm(should_renorm) - + def set_reopt(self, do_reopt): + # 设置是否进行重新优化操作,参数必须是布尔值 if not isinstance(do_reopt, bool): raise TypeError(f"Expect do_reopt to be a bool, got {do_reopt}") super().set_reopt(do_reopt) - def register_pass(requires_grad=True, run_only_once=False): """ Register python pass to specified pipeline phase which would be used in compilation. @@ -165,12 +172,13 @@ def cancel_new_parameter(pattern): >>> # some compilations >>> cancel_new_parameter(abc) """ + # 检查传入的pattern是否为NewParameter的实例 if not isinstance(pattern, NewParameter): raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}") + # 创建一个PyPassManager对象 ppm = PyPassManager() + # 从PyPassManager中注销指定名称的参数 ppm.unregister(pattern.para_name) - - def set_renorm(should_renorm): """ Set whether or not to do renormalization after modified graph in python pass(es). diff --git a/src/mindspore2022/mindspore/python/mindspore/run_check/_check_deps_version.py b/src/mindspore2022/mindspore/python/mindspore/run_check/_check_deps_version.py index 4e4426b2..65984f7d 100644 --- a/src/mindspore2022/mindspore/python/mindspore/run_check/_check_deps_version.py +++ b/src/mindspore2022/mindspore/python/mindspore/run_check/_check_deps_version.py @@ -19,22 +19,25 @@ import sys def parse_args(): """ parse args . - + Args: - + Returns: args. - + Examples: >>> parse_args() """ + # 创建一个ArgumentParser对象,用于解析命令行参数,描述信息为"MindSpore dependency packages version checker." parser = ArgumentParser(description="MindSpore dependency packages version checker.") + # 添加一个命令行参数--mindspore_version,类型为字符串,帮助信息为"MindSpore version." parser.add_argument("--mindspore_version", type=str, help="MindSpore version.") + # 添加一个命令行参数--supported_version,类型为字符串,可以多次指定,帮助信息为"Supported environment version." parser.add_argument("--supported_version", type=str, action='append', help="Supported environment version.") + # 解析命令行参数并返回结果 args = parser.parse_args() return args - def check_deps_version(mindspore_version, supported_version): """ check te/hccl/topi version @@ -46,6 +49,7 @@ def check_deps_version(mindspore_version, supported_version): Returns: void """ + # 尝试导入并检查 hccl、te 和 topi 包的版本是否与支持的版本匹配 try: from hccl import sys_version as hccl_version v = '.'.join(hccl_version.__sys_version__.split('.')[0:2]) @@ -63,18 +67,20 @@ def check_deps_version(mindspore_version, supported_version): print(f"MindSpore version {mindspore_version} and \"topi\" wheel package version {v} does not " "match, reference to the match info on: https://www.mindspore.cn/install") + # 捕获导入错误并打印相应的检查失败信息 except ImportError as e: print("CheckFailed: ", e.args) print("MindSpore relies on the 3 whl packages of \"te\", \"topi\" and \"hccl\" in the \"fwkacllib\" " "folder of the Ascend AI software package (Ascend Data Center Solution), please check whether they are " "installed correctly or not, reference to the match info on: https://www.mindspore.cn/install") - def main(): + # 解析命令行参数 args = parse_args() + # 检查 mindspore 的版本是否在支持的版本范围内 check_deps_version(args.mindspore_version, args.supported_version) - - + if __name__ == "__main__": - sys.path = sys.path[1:] # avoid the impact of relative path env, only affect this process + # 避免相对路径环境的影响,仅影响当前进程 + sys.path = sys.path[1:] main() diff --git a/src/mindspore2022/mindspore/python/mindspore/run_check/_check_version.py b/src/mindspore2022/mindspore/python/mindspore/run_check/_check_version.py index 5eb6f281..61ea0a8b 100644 --- a/src/mindspore2022/mindspore/python/mindspore/run_check/_check_version.py +++ b/src/mindspore2022/mindspore/python/mindspore/run_check/_check_version.py @@ -31,47 +31,58 @@ class EnvChecker(metaclass=ABCMeta): @abstractmethod def check_env(self, e): - pass + """检查环境是否符合要求""" @abstractmethod def set_env(self): - pass + """设置环境""" @abstractmethod def check_version(self): - pass + """检查版本是否符合要求""" class GPUEnvChecker(EnvChecker): """GPU environment check.""" def __init__(self): + # 初始化版本列表 self.version = ["10.1", "11.1"] + # 初始化库键到库名的映射字典 self.lib_key_to_lib_name = {'libcu': 'libcuda.so'} # env + # 获取系统环境变量 PATH 的值 self.path = os.getenv("PATH") + # 获取系统环境变量 LD_LIBRARY_PATH 的值 self.ld_lib_path = os.getenv("LD_LIBRARY_PATH") # check + # 初始化版本号为 "0" self.v = "0" + # 获取 CUDA 库的路径 self.cuda_lib_path = self._get_lib_path("libcu") + # 获取 CUDA 可执行文件的路径 self.cuda_bin_path = self._get_bin_path("cuda") + # 获取 cuDNN 库的路径 self.cudnn_lib_path = self._get_lib_path("libcudnn") - def check_env(self, e): + # 抛出传入的异常 e raise e - + def set_env(self): + # 设置环境变量,当前实现为空 return - + def _get_bin_path(self, bin_name): """Get bin path by bin name.""" + # 如果二进制名称为 "cuda",则调用获取 CUDA 二进制路径的方法 if bin_name == "cuda": return self._get_cuda_bin_path() + # 否则返回空列表 return [] def _get_cuda_bin_path(self): - """Get cuda bin path by lib path.""" + # Get cuda bin path by lib path. path_list = [] for path in self.cuda_lib_path: path = os.path.abspath(path.strip()+"/bin/") @@ -81,56 +92,87 @@ class GPUEnvChecker(EnvChecker): def _get_nvcc_version(self, is_set_env): """Get cuda version by nvcc command.""" + # 运行 nvcc 命令获取 CUDA 版本信息 nvcc_result = subprocess.run(["nvcc", "--version | grep release"], timeout=3, text=True, capture_output=True, check=False) + # 如果命令返回非零值,表示命令执行失败 if nvcc_result.returncode: + # 如果尚未设置环境变量 if not is_set_env: + # 遍历预设的 CUDA 二进制路径 for path in self.cuda_bin_path: + # 检查路径中是否存在 nvcc 文件 if Path(path + "/nvcc").is_file(): + # 将路径添加到环境变量 PATH 中 os.environ['PATH'] = path + ":" + os.environ['PATH'] + # 递归调用以重新尝试获取版本信息 return self._get_nvcc_version(True) + # 如果命令执行失败且未找到 nvcc 文件,返回空字符串 return "" + # 获取命令输出结果 result = nvcc_result.stdout + # 遍历输出结果的每一行 for line in result.split('\n'): if line: + # 提取并返回 CUDA 版本号 return line.strip().split("release")[1].split(",")[0].strip() + # 如果未找到版本信息,返回空字符串 return "" def _get_cudnn_version(self): """Get cudnn version by libcudnn.so.""" + # 初始化cudnn版本列表为空 cudnn_version = [] + # 遍历cudnn库路径 for path in self.cudnn_lib_path: + # 查找路径下所有的libcudnn.so文件 real_path = glob.glob(path + "/lib*/libcudnn.so.*.*") + # 如果没有找到对应的文件,继续下一个路径 if real_path == []: continue + # 使用ls命令获取文件信息 ls_cudnn = subprocess.run(["ls", real_path[0]], timeout=10, text=True, capture_output=True, check=False) + # 如果ls命令执行成功,解析输出以获取版本号 if ls_cudnn.returncode == 0: cudnn_version = ls_cudnn.stdout.split('/')[-1].strip('libcudnn.so.').strip().split('.') + # 如果版本号只有两个部分,添加一个'.0'作为第三部分 if len(cudnn_version) == 2: cudnn_version.append('0') + # 找到版本号后跳出循环 break + # 将版本号列表转换为字符串 version_str = ''.join([n for n in cudnn_version]) + # 返回版本号的前三位 return version_str[0:3] def _get_cudart_version(self): """Get cuda runtime version by libcudart.so.""" + # 遍历可能的 CUDA 库路径 for path in self.cuda_lib_path: + # 查找路径下所有可能的 libcudart.so 文件 real_path = glob.glob(path + "/lib*/libcudart.so.*.*.*") + # 如果没有找到任何文件,则跳过当前路径 if real_path == []: continue + # 获取文件名信息以确定 CUDA 版本 ls_cudart = subprocess.run(["ls", real_path[0]], timeout=10, text=True, capture_output=True, check=False) + # 如果命令成功执行,则解析输出以提取版本号 if ls_cudart.returncode == 0: self.v = ls_cudart.stdout.split('/')[-1].strip('libcudart.so.').strip() + # 找到版本号后跳出循环 break + # 返回找到的 CUDA 版本号 return self.v def check_version(self): """Check cuda version.""" version_match = False + # 调用私有方法检查版本是否匹配,并根据结果设置version_match标志 if self._check_version(): version_match = True + # 如果版本不匹配,根据CUDA版本号输出不同的警告信息 if not version_match: if self.v == "0": logger.warning("Can not found cuda libs, please confirm that the correct " @@ -140,17 +182,20 @@ class GPUEnvChecker(EnvChecker): logger.warning(f"MindSpore version {__version__} and cuda version {self.v} does not match, " "please refer to the installation guide for version matching " "information: https://www.mindspore.cn/install") + # 获取nvcc版本号,并检查是否与MindSpore支持的版本匹配 nvcc_version = self._get_nvcc_version(False) if nvcc_version and (nvcc_version not in self.version): logger.warning(f"MindSpore version {__version__} and nvcc(cuda bin) version {nvcc_version} " "does not match, please refer to the installation guide for version matching " "information: https://www.mindspore.cn/install") + # 获取cudnn版本号,并检查是否符合最低要求 cudnn_version = self._get_cudnn_version() if cudnn_version and int(cudnn_version) < 760: logger.warning(f"MindSpore version {__version__} and cudDNN version {cudnn_version} " "does not match, please refer to the installation guide for version matching " "information: https://www.mindspore.cn/install. The recommended version is " "CUDA10.1 with cuDNN7.6.x and CUDA11.1 with cuDNN8.0.x") + # 检查cudnn版本号与CUDA版本号的兼容性,对于CUDA 11.0以上版本,cudnn版本需要至少为8.0 if cudnn_version and int(cudnn_version) < 800 and int(str(self.v).split('.')[0]) > 10: logger.warning(f"CUDA version {self.v} and cuDNN version {cudnn_version} " "does not match, please refer to the installation guide for version matching " @@ -159,45 +204,58 @@ class GPUEnvChecker(EnvChecker): def _check_version(self): """Check cuda version""" + # 获取 CUDA 运行时版本 v = self._get_cudart_version() + # 解析版本字符串为版本对象 v = version.parse(v) + # 构造版本号字符串,格式为 "主版本.次版本" v_str = str(v.major) + "." + str(v.minor) + # 检查构造的版本号字符串是否在预定义的版本列表中 if v_str not in self.version: return False + # 版本号匹配,返回 True return True def _get_lib_path(self, lib_name): - """Get gpu lib path by ldd command.""" - path_list = [] - current_path = os.path.split(os.path.realpath(__file__))[0] - mindspore_path = os.path.join(current_path, "../") + """通过ldd命令获取gpu库路径。""" + path_list = [] # 初始化一个空列表用于存储路径 + current_path = os.path.split(os.path.realpath(__file__))[0] # 获取当前文件的绝对路径并分割以获取目录部分 + mindspore_path = os.path.join(current_path, "../") # 构建mindspore路径,通常是当前文件的上一级目录 try: + # 使用glob模块查找mindspore_path目录下所有以_c_expression.so开头的文件路径 real_path = glob.glob(mindspore_path + "/_c_expression*.so*") - if real_path == []: - logger.error(f"{self.lib_key_to_lib_name[lib_name]} (need by mindspore-gpu) is not found, please " - f"confirm that _c_expression.so is in directory:{mindspore_path} and the correct cuda " - "version has been installed, you can refer to the installation " - "guidelines: https://www.mindspore.cn/install") - return path_list + if real_path == []: # 如果没有找到任何文件 + # 记录错误日志,提示用户确认_c_expression.so文件是否存在以及是否安装了正确的cuda版本 + logger.error(f"{self.lib_key_to_lib_name[lib_name]} (mindspore-gpu所需的库) 未找到,请确认 " + f"_c_expression.so是否位于目录:{mindspore_path}中,并且已安装正确的cuda版本," + "您可以参考安装指南:https://www.mindspore.cn/install") + return path_list # 返回空路径列表 + # 使用subprocess.Popen执行ldd命令以获取依赖库的信息 ldd_r = subprocess.Popen(['ldd', real_path[0]], stdout=subprocess.PIPE) + # 使用subprocess.Popen的stdin参数从ldd_r.stdout接收输出,并执行grep命令以过滤出包含指定库名的信息 ldd_result = subprocess.Popen(['grep', lib_name], stdin=ldd_r.stdout, stdout=subprocess.PIPE) + # 获取grep命令的输出结果,并解码为字符串 result = ldd_result.communicate()[0].decode() - for i in result.split('\n'): + for i in result.split('\n'): # 按行分割结果字符串 + # 使用partition方法从每一行中提取出库文件的路径 path = i.partition("=>")[2] - if path.lower().find("not found") > 0: - logger.warning(f"Cuda {self.version} version(need by mindspore-gpu) is not found, please confirm " - "that the path of cuda is set to the env LD_LIBRARY_PATH, please refer to the " - "installation guidelines: https://www.mindspore.cn/install") - continue + if path.lower().find("not found") > 0: # 如果路径中包含"not found" + # 记录警告日志,提示用户确认cuda路径是否已添加到环境变量LD_LIBRARY_PATH中 + logger.warning(f"Cuda {self.version}版本(由mindspore-gpu要求的) 未找到,请确认cuda路径已设置到环境变量LD_LIBRARY_PATH中," + "您可以参考安装指南:https://www.mindspore.cn/install") + continue # 继续下一次循环 + # 从路径中去除库名部分 path = path.partition(lib_name)[0] - if path: + if path: # 如果路径非空 + # 将路径的绝对路径并去除末尾斜杠后添加到path_list中 path_list.append(os.path.abspath(path.strip() + "../")) + # 返回path_list中唯一的路径 return np.unique(path_list) - except subprocess.TimeoutExpired: - logger.warning("Failed to check cuda version due to the ldd command timeout, please confirm that " - "the correct cuda version has been installed, you can refer to the " - "installation guidelines: https://www.mindspore.cn/install") - return path_list + except subprocess.TimeoutExpired: # 捕获subprocess.TimeoutExpired异常 + # 记录警告日志,提示用户确认cuda版本是否正确安装,因为ldd命令超时 + logger.warning("由于ldd命令超时,无法检查cuda版本,请确认已安装正确的cuda版本," + "您可以参考安装指南:https://www.mindspore.cn/install") + return path_list # 返回空路径列表 def _read_version(self, file_path): """Get gpu version info in version.txt.""" @@ -211,70 +269,80 @@ class GPUEnvChecker(EnvChecker): class AscendEnvChecker(EnvChecker): - """ascend environment check""" + """Ascend 环境检查类""" def __init__(self): + # 初始化 Ascend 环境检查器的版本列表 self.version = ["1.81"] + + # 定义不同路径下的 version.info 文件位置 atlas_nnae_version = "/usr/local/Ascend/nnae/latest/fwkacllib/version.info" atlas_toolkit_version = "/usr/local/Ascend/ascend-toolkit/latest/fwkacllib/version.info" hisi_fwk_version = "/usr/local/Ascend/latest/fwkacllib/version.info" + + # 检查 Atlas NNAE 环境是否存在 if os.path.exists(atlas_nnae_version): - # atlas default path - self.fwk_path = "/usr/local/Ascend/nnae/latest/fwkacllib" - self.op_impl_path = "/usr/local/Ascend/nnae/latest/opp/op_impl/built-in/ai_core/tbe" - self.tbe_path = self.fwk_path + "/lib64" - self.cce_path = self.fwk_path + "/ccec_compiler/bin" - self.fwk_version = atlas_nnae_version - self.op_path = "/usr/local/Ascend/nnae/latest/opp" - self.aicpu_path = "/usr/local/Ascend/nnae/latest" + # 如果存在,设置默认路径 + self.fwk_path = "/usr/local/Ascend/nnae/latest/fwkacllib" # Framework 路径 + self.op_impl_path = "/usr/local/Ascend/nnae/latest/opp/op_impl/built-in/ai_core/tbe" # Operator 实现路径 + self.tbe_path = self.fwk_path + "/lib64" # TBE 库路径 + self.cce_path = self.fwk_path + "/ccec_compiler/bin" # CCE 编译器路径 + self.fwk_version = atlas_nnae_version # Framework 版本文件路径 + self.op_path = "/usr/local/Ascend/nnae/latest/opp" # Operator 路径 + self.aicpu_path = "/usr/local/Ascend/nnae/latest" # AI CPU 路径 + + # 检查 Atlas Toolkit 环境是否存在 elif os.path.exists(atlas_toolkit_version): - # atlas default path - self.fwk_path = "/usr/local/Ascend/ascend-toolkit/latest/fwkacllib" - self.op_impl_path = "/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe" - self.tbe_path = self.fwk_path + "/lib64" - self.cce_path = self.fwk_path + "/ccec_compiler/bin" - self.fwk_version = atlas_toolkit_version - self.op_path = "/usr/local/Ascend/ascend-toolkit/latest/opp" - self.aicpu_path = "/usr/local/Ascend/ascend-toolkit/latest" + # 如果存在,设置默认路径 + self.fwk_path = "/usr/local/Ascend/ascend-toolkit/latest/fwkacllib" # Framework 路径 + self.op_impl_path = "/usr/local/Ascend/ascend-toolkit/latest/opp/op_impl/built-in/ai_core/tbe" # Operator 实现路径 + self.tbe_path = self.fwk_path + "/lib64" # TBE 库路径 + self.cce_path = self.fwk_path + "/ccec_compiler/bin" # CCE 编译器路径 + self.fwk_version = atlas_toolkit_version # Framework 版本文件路径 + self.op_path = "/usr/local/Ascend/ascend-toolkit/latest/opp" # Operator 路径 + self.aicpu_path = "/usr/local/Ascend/ascend-toolkit/latest" # AI CPU 路径 + + # 检查 Hisi 环境是否存在 elif os.path.exists(hisi_fwk_version): - # hisi default path - self.fwk_path = "/usr/local/Ascend/latest/fwkacllib" - self.op_impl_path = "/usr/local/Ascend/latest/opp/op_impl/built-in/ai_core/tbe" - self.tbe_path = self.fwk_path + "/lib64" - self.cce_path = self.fwk_path + "/ccec_compiler/bin" - self.fwk_version = hisi_fwk_version - self.op_path = "/usr/local/Ascend/latest/opp" - self.aicpu_path = "/usr/local/Ascend/latest" + # 如果存在,设置默认路径 + self.fwk_path = "/usr/local/Ascend/latest/fwkacllib" # Framework 路径 + self.op_impl_path = "/usr/local/Ascend/latest/opp/op_impl/built-in/ai_core/tbe" # Operator 实现路径 + self.tbe_path = self.fwk_path + "/lib64" # TBE 库路径 + self.cce_path = self.fwk_path + "/ccec_compiler/bin" # CCE 编译器路径 + self.fwk_version = hisi_fwk_version # Framework 版本文件路径 + self.op_path = "/usr/local/Ascend/latest/opp" # Operator 路径 + self.aicpu_path = "/usr/local/Ascend/latest" # AI CPU 路径 + else: - # custom or unknown environment - self.fwk_path = "" - self.op_impl_path = "" - self.tbe_path = "" - self.cce_path = "" - self.fwk_version = "" - self.op_path = "" - self.aicpu_path = "" - - # env + # 如果以上环境都不存在,设置为空路径 + self.fwk_path = "" # Framework 路径 + self.op_impl_path = "" # Operator 实现路径 + self.tbe_path = "" # TBE 库路径 + self.cce_path = "" # CCE 编译器路径 + self.fwk_version = "" # Framework 版本文件路径 + self.op_path = "" # Operator 路径 + self.aicpu_path = "" # AI CPU 路径 + + # 初始化环境变量 self.path = os.getenv("PATH") self.python_path = os.getenv("PYTHONPATH") self.ld_lib_path = os.getenv("LD_LIBRARY_PATH") self.ascend_opp_path = os.getenv("ASCEND_OPP_PATH") self.ascend_aicpu_path = os.getenv("ASCEND_AICPU_PATH") - # check content + # 设置需要检查的路径内容 self.path_check = "/fwkacllib/ccec_compiler/bin" self.python_path_check = "opp/op_impl/built-in/ai_core/tbe" self.ld_lib_path_check_fwk = "/fwkacllib/lib64" self.ld_lib_path_check_addons = "/add-ons" self.ascend_opp_path_check = "/op" self.v = "" - def check_env(self, e): self._check_env() raise e def check_version(self): + # 检查指定路径的版本文件是否存在,如果不存在则跳过版本检查 if not Path(self.fwk_version).is_file(): logger.warning("Using custom Ascend AI software package (Ascend Data Center Solution) path, package " "version checking is skipped, please make sure Ascend AI software package (Ascend Data " @@ -282,40 +350,47 @@ class AscendEnvChecker(EnvChecker): "https://www.mindspore.cn/install") return + # 读取版本文件中的版本信息 v = self._read_version(self.fwk_version) + # 如果读取的版本不在支持的版本列表中,则记录警告信息 if v not in self.version: v_list = str([x for x in self.version]) logger.warning(f"MindSpore version {__version__} and Ascend AI software package (Ascend Data Center " f"Solution)version {v} does not match, the version of software package expect one of " f"{v_list}, please reference to the match info on: https://www.mindspore.cn/install") - def check_deps_version(self): """ te, topi, hccl wheel package version check in order to update the change of 'LD_LIBRARY_PATH' env, run a sub process """ + # 构建输入参数列表,包含mindspore版本和受支持的版本列表 input_args = ["--mindspore_version=" + __version__] for v in self.version: input_args.append("--supported_version=" + v) + # 获取依赖版本检查脚本的路径 deps_version_checker = os.path.join(os.path.split(os.path.realpath(__file__))[0], "_check_deps_version.py") + # 构建调用命令,包括python解释器路径、脚本路径和输入参数 call_cmd = [sys.executable, deps_version_checker] + input_args try: + # 运行子进程进行版本检查,设置超时时间为3秒,并捕获输出 process = subprocess.run(call_cmd, timeout=3, text=True, capture_output=True, check=False) + # 如果子进程的输出不为空,则记录警告信息并进行倒计时提醒 if process.stdout.strip() != "": logger.warning(process.stdout.strip()) warning_countdown = 3 for i in range(warning_countdown, 0, -1): logger.warning(f"Please pay attention to the above warning, countdown: {i}") time.sleep(1) + # 如果版本检查超时,则记录信息并跳过 except subprocess.TimeoutExpired: logger.info("Package te, topi, hccl version check timed out, skip.") - def set_env(self): + # 设置Ascend环境变量 if not self.tbe_path: self._check_env() return - + try: import te # pylint: disable=unused-import # pylint: disable=broad-except @@ -329,32 +404,35 @@ class AscendEnvChecker(EnvChecker): raise EnvironmentError( f"No such directory: {self.tbe_path}, Please check if Ascend AI software package (Ascend Data " "Center Solution) is installed correctly.") - - # check te version after set te env + + # 检查te版本 self.check_deps_version() - + + # 设置op实现路径环境变量 if Path(self.op_impl_path).is_dir(): - # python path for sub process + # python路径用于子进程 if os.getenv('PYTHONPATH'): os.environ['PYTHONPATH'] = self.op_impl_path + ":" + os.environ['PYTHONPATH'] else: os.environ['PYTHONPATH'] = self.op_impl_path - # sys path for this process + # sys路径用于当前进程 sys.path.append(self.op_impl_path) - + os.environ['TBE_IMPL_PATH'] = self.op_impl_path else: raise EnvironmentError( - f"No such directory: {self.op_impl_path}, Please check if Ascend AI software package (Ascend Data " - "Center Solution) is installed correctly.") - + f"No such directory: {self.op_impl_path}, Please check if Ascend AI software package (Ascend Data Center " + "Solution) is installed correctly.") + + # 设置CCE路径环境变量 if Path(self.cce_path).is_dir(): os.environ['PATH'] = self.cce_path + ":" + os.environ['PATH'] else: raise EnvironmentError( f"No such directory: {self.cce_path}, Please check if Ascend AI software package (Ascend Data Center " "Solution) is installed correctly.") - + + # 设置OP路径环境变量 if self.op_path is None: pass elif Path(self.op_path).is_dir(): @@ -363,7 +441,8 @@ class AscendEnvChecker(EnvChecker): raise EnvironmentError( f"No such directory: {self.op_path}, Please check if Ascend AI software package (Ascend Data Center " "Solution) is installed correctly.") - + + # 设置AICPU路径环境变量 if self.aicpu_path is None: pass elif Path(self.aicpu_path).is_dir(): @@ -372,44 +451,54 @@ class AscendEnvChecker(EnvChecker): raise EnvironmentError( f"No such directory: {self.aicpu_path}, Please check if Ascend AI software package (Ascend Data Center" " Solution) is installed correctly.") - + def _check_env(self): """ascend dependence path check""" + # 检查是否设置正确的PATH环境变量 if self.path is None or self.path_check not in self.path: logger.warning("Can not find ccec_compiler(need by mindspore-ascend), please check if you have set env " "PATH, you can reference to the installation guidelines https://www.mindspore.cn/install") - + + # 检查是否设置正确的PYTHONPATH环境变量 if self.python_path is None or self.python_path_check not in self.python_path: logger.warning( "Can not find tbe op implement(need by mindspore-ascend), please check if you have set env " "PYTHONPATH, you can reference to the installation guidelines " "https://www.mindspore.cn/install") - + + # 检查是否设置正确的LD_LIBRARY_PATH环境变量 if self.ld_lib_path is None or not (self.ld_lib_path_check_fwk in self.ld_lib_path and self.ld_lib_path_check_addons in self.ld_lib_path): logger.warning("Can not find driver so(need by mindspore-ascend), please check if you have set env " "LD_LIBRARY_PATH, you can reference to the installation guidelines " "https://www.mindspore.cn/install") - + + # 检查是否设置正确的ASCEND_OPP_PATH环境变量 if self.ascend_opp_path is None or self.ascend_opp_path_check not in self.ascend_opp_path: logger.warning( "Can not find opp path (need by mindspore-ascend), please check if you have set env ASCEND_OPP_PATH, " "you can reference to the installation guidelines https://www.mindspore.cn/install") - def _read_version(self, file_path): """get ascend version info""" with open(file_path, 'r') as f: all_info = f.readlines() + # 遍历文件中的每一行 for line in all_info: + # 检查行是否以 "Version=" 开头 if line.startswith("Version="): + # 去除行末的换行符并按 "=" 分割, 获取版本号 full_version = line.strip().split("=")[1] + # 提取主版本号和次版本号, 并用 "." 连接 self.v = '.'.join(full_version.split('.')[0:2]) + # 返回版本号 return self.v + # 如果未找到版本信息, 返回 None 或默认值 return self.v def check_version_and_env_config(): - """check version and env config""" + """检查版本和环境配置""" + # 检查包名以确定使用哪种环境检查器 if __package_name__.lower() == "mindspore-ascend": env_checker = AscendEnvChecker() # Note: pre-load libgomp.so to solve error like "cannot allocate memory in statis TLS block" @@ -425,19 +514,21 @@ def check_version_and_env_config(): else: logger.info(f"Package version {__package_name__} does not need to check any environment variable, skipping.") return + # 检查是否关闭版本检查,如果已关闭则直接返回 if os.getenv("MS_DEV_CLOSE_VERSION_CHECK") == "ON": return + # 设置环境变量以关闭版本检查 os.environ["MS_DEV_CLOSE_VERSION_CHECK"] = "ON" - try: - # check version of ascend site or cuda + # 检查 ascend site 或 cuda 的版本 env_checker.check_version() from .. import _c_expression # pylint: disable=unused-import + # 设置环境 env_checker.set_env() except ImportError as e: + # 处理导入错误,检查环境 env_checker.check_env(e) - def _set_pb_env(): """Set env variable `PROTOCOL_BUFFERS` to prevent memory overflow.""" if os.getenv("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION") == "cpp": @@ -449,7 +540,9 @@ def _set_pb_env(): logger.info("Setting the env `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python` to prevent memory overflow " "during save or load checkpoint file.") os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" - - + +# 检查版本和环境配置 check_version_and_env_config() + +# 设置协议缓冲区的环境变量, 防止内存溢出 _set_pb_env() diff --git a/src/mindspore2022/mindspore/python/mindspore/run_check/run_check.py b/src/mindspore2022/mindspore/python/mindspore/run_check/run_check.py index 225cd92a..4838e654 100644 --- a/src/mindspore2022/mindspore/python/mindspore/run_check/run_check.py +++ b/src/mindspore2022/mindspore/python/mindspore/run_check/run_check.py @@ -26,23 +26,26 @@ def _check_mul(): """ from importlib import import_module import numpy as np - + try: ms = import_module("mindspore") except ModuleNotFoundError: ms = None finally: pass - + + # 打印MindSpore版本信息 print(f"MindSpore version: ", ms.__version__) - + + # 创建两个MindSpore张量,分别包含数组[1.0, 2.0, 3.0]和[4.0, 5.0, 6.0] input_x = ms.Tensor(np.array([1.0, 2.0, 3.0]), ms.float32) input_y = ms.Tensor(np.array([4.0, 5.0, 6.0]), ms.float32) + # 创建一个乘法操作对象 mul = ms.ops.Mul() + # 执行乘法操作 mul(input_x, input_y) + # 打印乘法计算结果正确,MindSpore安装成功的信息 print(f"The result of multiplication calculation is correct, MindSpore has been installed successfully!") - - def run_check(): """ Provide a convenient API to check if the installation is successful or failed. @@ -55,10 +58,13 @@ def run_check(): The result of multiplication calculation is correct, MindSpore has been installed successfully! """ try: + # 尝试执行检查乘法操作的函数 _check_mul() # pylint: disable=broad-except + # 捕获所有异常并打印错误信息 except Exception as e: print("MindSpore running check failed.") print(str(e)) finally: - pass + # 无论是否发生异常,都会执行此部分代码 + pass # 执行乘法检查的函数,并处理可能的异常情况。如果检查失败,打印错误信息。 \ No newline at end of file diff --git a/src/mindspore2022/tests/dataset_mock.py b/src/mindspore2022/tests/dataset_mock.py index c3308a09..e4e1f42f 100644 --- a/src/mindspore2022/tests/dataset_mock.py +++ b/src/mindspore2022/tests/dataset_mock.py @@ -1,3 +1,6 @@ +这段代码是一个Python类的实现,名为`MindData`,它是一个用于模拟MindSpore框架中数据集处理的桩(Stub)。下面是对这段代码的逐行注释: + +```python # Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,77 +15,94 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ + '''Remove after MindData merge to MindSpore ''' import numpy as np from mindspore import Tensor - class MindData: """ Stub for MindData """ - + # 构造函数,初始化MindData类的实例 def __init__(self, size=1, batch_size=None, repeat_count=1, np_types=None, output_shapes=None, input_indexs=()): - self._size = size - self._batch_size = batch_size - self._repeat_count = repeat_count - self._np_types = np_types - self._output_shapes = output_shapes - self._input_indexs = input_indexs - self._iter_num = 0 - self.dynamic_setting = [False, None] - + self._size = size # 数据集的大小 + self._batch_size = batch_size # 批处理大小 + self._repeat_count = repeat_count # 重复次数 + self._np_types = np_types # NumPy数据类型 + self._output_shapes = output_shapes # 输出形状 + self._input_indexs = input_indexs # 输入索引 + self._iter_num = 0 # 迭代次数计数器 + self.dynamic_setting = [False, None] # 动态设置标志和值 + + # 获取数据集大小 def get_dataset_size(self): return self._size + # 获取重复次数 def get_repeat_count(self): return self._repeat_count + # 获取批处理大小 def get_batch_size(self): return self._batch_size + # 获取输出数据类型 def output_types(self): return self._np_types + # 获取输出形状 def output_shapes(self): return self._output_shapes + # 输入索引属性 @property def input_indexs(self): return self._input_indexs + # 设备队列设置 def device_que(self, send_epoch_end=True, create_data_info_queue=False): self.queue_name = '6ba41974-209e-11ea-88b0-a24efeb2c736' self.send_epoch_end = send_epoch_end return self + # 创建元组迭代器 def create_tuple_iterator(self, num_epochs=-1, do_copy=True): return self.__iter__() + # 发送数据 def send(self, num_epochs=-1): pass + # 停止发送数据 def stop_send(self): pass + # 释放资源 def release(self): pass + # 继续发送数据 def continue_send(self): pass + # 获取数据信息 def get_data_info(self): pass + # 动态最小最大形状 def dynamic_min_max_shapes(self): pass + # 获取长度 def __len__(self): return self._size + # 迭代器 def __iter__(self): return self + # 获取下一个元素 def __next__(self): if self._size < self._iter_num: raise StopIteration @@ -90,11 +110,13 @@ class MindData: next_value = [] for shape, typ in zip(self._output_shapes, self._np_types): next_value.append(Tensor(np.ndarray(shape, typ))) - return tuple(next_value) + # 下一个元素 def next(self): return self.__next__() + # 重置迭代器 def reset(self): self._iter_num = 0 +