diff --git a/src/mindspore2022/mindspore/python/mindspore/parallel/_auto_parallel_context.py b/src/mindspore2022/mindspore/python/mindspore/parallel/_auto_parallel_context.py index a802a2a4..438564a1 100644 --- a/src/mindspore2022/mindspore/python/mindspore/parallel/_auto_parallel_context.py +++ b/src/mindspore2022/mindspore/python/mindspore/parallel/_auto_parallel_context.py @@ -29,538 +29,621 @@ _DEFAULT_NCCL_FUSION_GROUP_NAME = "nccl_world_groupsum1" class _ParallelFusionConfig: """ - The key of the Parallel fusion method configuration. + 并行融合方法的配置键类。 """ - ALLREDUCE = "allreduce" - ALLGATHER = "allgather" - REDUCESCATTER = "reducescatter" - MODE = "mode" - FUSION_CONFIG = "config" - AUTO = "auto" - INDEX = "index" - SIZE = "size" - - + ALLREDUCE = "allreduce" # allreduce 通信融合方法的键 + ALLGATHER = "allgather" # allgather 通信融合方法的键 + REDUCESCATTER = "reducescatter" # reducescatter 通信融合方法的键 + MODE = "mode" # 融合方法的模式键 + FUSION_CONFIG = "config" # 融合方法的配置键 + AUTO = "auto" # 自动融合模式 + INDEX = "index" # 通过索引设置融合的策略 + SIZE = "size" # 通过大小设置融合的策略 + + class _ParallelOptimizerConfig: """ - The key of the Parallel Optimizer. There are three + 并行优化器的配置键类。 """ - GRADIENT_ACCUMULATION_SHARD = "gradient_accumulation_shard" - PARALLEL_OPTIMIZER_THRESHOLD = "parallel_optimizer_threshold" - - + GRADIENT_ACCUMULATION_SHARD = "gradient_accumulation_shard" # 梯度累积分片配置键 + PARALLEL_OPTIMIZER_THRESHOLD = "parallel_optimizer_threshold" # 并行优化器阈值配置键 + + class _AutoParallelContext: """ - _AutoParallelContext is the environment in which operations are executed - - Note: - Create a context through instantiating Context object is not recommended. - Should use auto_parallel_context() to get the context since Context is singleton. + _AutoParallelContext 类是操作执行的环境。 + + 注意: + 不建议通过实例化 Context 对象来创建上下文。 + 应该使用 auto_parallel_context() 来获取上下文,因为 Context 是单例模式。 """ - _instance = None - _instance_lock = threading.Lock() - + _instance = None # 类的单例实例 + _instance_lock = threading.Lock() # 用于线程安全的锁 + def __init__(self): - self._context_handle = AutoParallelContext.get_instance() - self._dataset_strategy_using_str = True - + """ + 初始化方法。 + """ + self._context_handle = AutoParallelContext.get_instance() # 获取 AutoParallelContext 的实例 + self._dataset_strategy_using_str = True # 标记是否使用字符串形式的 dataset 策略 + def __new__(cls): + """ + 类的实例化方法,确保单例模式。 + """ if cls._instance is None: - cls._instance_lock.acquire() - cls._instance = object.__new__(cls) - cls._instance_lock.release() - return cls._instance - + cls._instance_lock.acquire() # 获取锁 + cls._instance = object.__new__(cls) # 创建实例 + cls._instance_lock.release() # 释放锁 + return cls._instance # 返回单例实例 + def check_context_handle(self): """ - Check context handle. - - Raises: - ValueError: If the context handle is none. + 检查上下文句柄。 + + 引发: + ValueError: 如果上下文句柄为 None。 """ if self._context_handle is None: raise ValueError("Context handle is none in context!!!") - + def set_device_num(self, device_num): """ - Set device num for auto parallel. - - Args: - device_num (int): The device number. - - Raises: - ValueError: If the device num is not in [1, 4096]. 
+ 设置自动并行的设备数量。 + + 参数: + device_num (int): 设备数量。 + + 引发: + ValueError: 如果设备数量不在 [1, 4096] 范围内。 """ - self.check_context_handle() + self.check_context_handle() # 检查上下文句柄 if device_num < 1 or device_num > 4096: raise ValueError("The context configuration parameter 'device_num' must be in [1, 4096], " "but got the value of device_num : {}.".format(device_num)) - from mindspore.communication._comm_helper import _HCCL_TEST_AVAILABLE - self._context_handle.set_hccl_test_avaible(_HCCL_TEST_AVAILABLE) - self._context_handle.set_device_num(device_num) - + from mindspore.communication._comm_helper import _HCCL_TEST_AVAILABLE # 导入 HCCL 测试可用性标志 + self._context_handle.set_hccl_test_avaible(_HCCL_TEST_AVAILABLE) # 设置 HCCL 测试可用性 + self._context_handle.set_device_num(device_num) # 设置设备数量 + def get_device_num(self): - """Get device num.""" - self.check_context_handle() - return self._context_handle.get_device_num() - + """ + 获取设备数量。 + """ + self.check_context_handle() # 检查上下文句柄 + return self._context_handle.get_device_num() # 返回设备数量 + def set_comm_fusion(self, config): """ - Set fusion method for auto parallel. - - Args: - config (dict): A dict contains the methods and values for setting the communication fusion. Currently it - supports: `allreduce`. - - Raises: - KeyError: When key of comm_fusion is not 'allreduce'. + 设置自动并行的通信融合方法。 + + 参数: + config (dict): 包含设置通信融合的方法和值的字典。目前支持 'allreduce'。 + + 引发: + KeyError: 当通信融合的键不是 'allreduce' 时。 """ - self.check_context_handle() + self.check_context_handle() # 检查上下文句柄 for key in list(config.keys()): if key == _ParallelFusionConfig.ALLREDUCE: - self._set_allreduce_comm_fusion(config[key]) + self._set_allreduce_comm_fusion(config[key]) # 设置 allreduce 通信融合 elif key == _ParallelFusionConfig.ALLGATHER: - self._set_allgather_comm_fusion(config[key], key) + self._set_allgather_comm_fusion(config[key], key) # 设置 allgather 通信融合 elif key == _ParallelFusionConfig.REDUCESCATTER: - self._set_allgather_comm_fusion(config[key], key) + self._set_allgather_comm_fusion(config[key], key) # 设置 reducescatter 通信融合 else: raise KeyError("comm fusion type must be allreduce, allgather or reducescatter, but got {}".format(key)) - def _set_allreduce_comm_fusion(self, comm_fusion): - """ - Set fusion method for auto parallel. - Args: - comm_fusion (dict): A dict contains the methods and values for setting the fusion method. Currently it - supports four fusion methods: `auto`, `size` and `index`. +def _set_allreduce_comm_fusion(self, comm_fusion): + """ + 设置自动并行的 allreduce 通信融合方法。 - Raises: - KeyError: When key of comm_fusion is not 'mode' or 'config'. - KeyError: When `mode` is not 'auto', 'size' or 'index'. 
- """ - self.check_context_handle() - if not self.get_enable_all_reduce_fusion(): - return - if not isinstance(comm_fusion, dict): - raise TypeError("For 'comm_fusion', the 'allreduce' config must be dict, but got the type : {}.".format( - type(comm_fusion))) - if _ParallelFusionConfig.MODE not in comm_fusion: - raise KeyError("For 'comm_fusion', the key 'mode' should be contained.") - if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion: - raise KeyError("For 'comm_fusion', the key 'config' should be contained.") - check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.INDEX, _ParallelFusionConfig.SIZE] - if comm_fusion[_ParallelFusionConfig.MODE] in check_mode: - self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE]) - else: - raise KeyError("fusion method mode must be auto, index or size, but got {}".format( - comm_fusion[_ParallelFusionConfig.MODE])) - if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.AUTO: - self.set_fusion_threshold_mb(fusion_threshold=64) - if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.SIZE: - self.set_fusion_threshold_mb(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG]) - if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.INDEX: - self.set_all_reduce_fusion_split_indices(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG]) - - def _set_allgather_comm_fusion(self, comm_fusion, comm_type="allgather"): - """ - Set allgather and reducescatter fusion method for auto parallel. + 参数: + comm_fusion (dict): 包含设置融合方法的方法和值的字典。目前支持 'auto', 'size' 和 'index'。 - Args: - comm_fusion (dict): A dict contains the methods and values for setting the fusion method. Currently it - supports four fusion methods: `auto` and `size`. - comm_type (str): The name of the communication operator, `allgather` or `reducescatter`. 
+ 引发: + KeyError: 当通信融合的键不是 'mode' 或 'config' 时。 + KeyError: 当 'mode' 不是 'auto', 'size' 或 'index' 时。 + """ + self.check_context_handle() # 检查上下文句柄 + if not self.get_enable_all_reduce_fusion(): + return # 如果未启用 allreduce 融合,则直接返回 + if not isinstance(comm_fusion, dict): + raise TypeError("For 'comm_fusion', the 'allreduce' config must be dict, but got the type : {}.".format( + type(comm_fusion))) # 检查配置类型是否为字典 + if _ParallelFusionConfig.MODE not in comm_fusion: + raise KeyError("For 'comm_fusion', the key 'mode' should be contained.") # 检查是否包含 'mode' 键 + if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion: + raise KeyError("For 'comm_fusion', the key 'config' should be contained.") # 检查是否包含 'config' 键 + check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.INDEX, _ParallelFusionConfig.SIZE] + if comm_fusion[_ParallelFusionConfig.MODE] in check_mode: + self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE]) # 设置融合模式 + else: + raise KeyError("fusion method mode must be auto, index or size, but got {}".format( + comm_fusion[_ParallelFusionConfig.MODE])) # 抛出错误,如果模式不支持 + if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.AUTO: + self.set_fusion_threshold_mb(fusion_threshold=64) # 自动模式下设置默认融合阈值 + if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.SIZE: + self.set_fusion_threshold_mb(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG]) # 按大小模式设置融合阈值 + if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.INDEX: + self.set_all_reduce_fusion_split_indices(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG]) # 按索引模式设置融合策略 + + +def _set_allgather_comm_fusion(self, comm_fusion, comm_type="allgather"): + """ + 设置自动并行的 allgather 和 reducescatter 通信融合方法。 - Raises: - KeyError: When key of comm_fusion is not 'mode' or 'config'. - KeyError: When `mode` is not 'auto', 'size'. 
- """ - self.check_context_handle() - if comm_type == "allgather" and not self.get_enable_all_gather_fusion(): - return - if comm_type == "reducescatter" and not self.get_enable_reduce_scatter_fusion(): - return - if not isinstance(comm_fusion, dict): - raise TypeError("For 'comm_fusion', {} config must be dict, but got the type : {}.".format( - comm_type, type(comm_fusion))) - if _ParallelFusionConfig.MODE not in comm_fusion: - raise KeyError("For 'comm_fusion', the key 'mode' should be contained.") - if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion: - raise KeyError("For 'comm_fusion', the key 'config' should be contained.") - check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.SIZE] - if comm_fusion[_ParallelFusionConfig.MODE] in check_mode: - self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE]) - else: - raise KeyError("fusion method mode must be auto or size, but got {}".format( - comm_fusion[_ParallelFusionConfig.MODE])) + 参数: + comm_fusion (dict): 包含设置融合方法的方法和值的字典。目前支持 'auto' 和 'size'。 + comm_type (str): 通信操作符的名称,'allgather' 或 'reducescatter'。 - fusion_threshold = 64 if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.AUTO else \ - comm_fusion[_ParallelFusionConfig.FUSION_CONFIG] - self.set_fusion_threshold_mb(fusion_threshold, comm_type) + 引发: + KeyError: 当通信融合的键不是 'mode' 或 'config' 时。 + KeyError: 当 'mode' 不是 'auto' 或 'size' 时。 + """ + self.check_context_handle() # 检查上下文句柄 + if comm_type == "allgather" and not self.get_enable_all_gather_fusion(): + return # 如果未启用 allgather 融合,则直接返回 + if comm_type == "reducescatter" and not self.get_enable_reduce_scatter_fusion(): + return # 如果未启用 reducescatter 融合,则直接返回 + if not isinstance(comm_fusion, dict): + raise TypeError("For 'comm_fusion', {} config must be dict, but got the type : {}.".format( + comm_type, type(comm_fusion))) # 检查配置类型是否为字典 + if _ParallelFusionConfig.MODE not in comm_fusion: + raise KeyError("For 'comm_fusion', the key 'mode' should be contained.") # 检查是否包含 'mode' 键 + if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion: + raise KeyError("For 'comm_fusion', the key 'config' should be contained.") # 检查是否包含 'config' 键 + check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.SIZE] + if comm_fusion[_ParallelFusionConfig.MODE] in check_mode: + self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE]) # 设置融合模式 + else: + raise KeyError("fusion method mode must be auto or size, but got {}".format( + comm_fusion[_ParallelFusionConfig.MODE])) # 抛出错误,如果模式不支持 + + fusion_threshold = 64 if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.AUTO else \ + comm_fusion[_ParallelFusionConfig.FUSION_CONFIG] # 设置融合阈值 + self.set_fusion_threshold_mb(fusion_threshold, comm_type) # 设置指定通信类型的融合阈值 + + +def get_comm_fusion(self): + """ + 获取通信融合配置。 + """ + self.check_context_handle() # 检查上下文句柄 + mode = self._context_handle.get_fusion_mode() # 获取当前融合模式 + if mode in (_ParallelFusionConfig.AUTO, _ParallelFusionConfig.SIZE): + config = self.fusion_threshold_mb() # 获取融合阈值作为配置 + if mode == _ParallelFusionConfig.INDEX: + config = self.get_all_reduce_fusion_split_indices() # 获取索引作为配置 + return {_ParallelFusionConfig.ALLREDUCE: {_ParallelFusionConfig.MODE: mode, + _ParallelFusionConfig.FUSION_CONFIG: config}} # 返回配置字典 - def get_comm_fusion(self): - """Get comm fusion config.""" - self.check_context_handle() - mode = self._context_handle.get_fusion_mode() - if mode in (_ParallelFusionConfig.AUTO, _ParallelFusionConfig.SIZE): - config = 
self.fusion_threshold_mb() - if mode == _ParallelFusionConfig.INDEX: - config = self.get_all_reduce_fusion_split_indices() - return {_ParallelFusionConfig.ALLREDUCE: {_ParallelFusionConfig.MODE: mode, - _ParallelFusionConfig.FUSION_CONFIG: config}} +def set_fusion_threshold_mb(self, fusion_threshold=64, comm_type="allreduce"): + """ + 设置自动并行的融合阈值(MB)。 - def set_fusion_threshold_mb(self, fusion_threshold=64, comm_type="allreduce"): - """ - Set fusion threshold (MB) for auto parallel. + 参数: + fusion_threshold (int): 融合阈值(单位:MB)。默认:64。 + comm_type (str): 通信操作符的名称,'allreduce', 'allgather' 或 'reducescatter'。 - Args: - fusion_threshold (int): The fusion threshold (unit: MB). Default: 64. - comm_type (str): The name of the communication operator, `allreduce`, `allgather` or `reducescatter`. + 引发: + ValueError: 如果融合阈值不在 [0, +inf] 范围内。 + """ + self.check_context_handle() # 检查上下文句柄 + if fusion_threshold < 0: + raise ValueError("fusion threshold must be larger than 0, but got {}".format(fusion_threshold)) # 检查阈值是否有效 - Raises: - ValueError: If the fusion threshold is not in [0, +inf]. - """ - self.check_context_handle() - if fusion_threshold < 0: - raise ValueError("fusion threshold must be larger than 0, but got {}".format(fusion_threshold)) + if comm_type == _ParallelFusionConfig.ALLREDUCE: + self._context_handle.set_fusion_threshold_mb(fusion_threshold) # 设置 allreduce 的融合阈值 + if comm_type == _ParallelFusionConfig.ALLGATHER: + self._context_handle.set_allgather_fusion_threshold_mb(fusion_threshold) # 设置 allgather 的融合阈值 + if comm_type == _ParallelFusionConfig.REDUCESCATTER: + self._context_handle.set_reducescatter_fusion_threshold_mb(fusion_threshold) # 设置 reducescatter 的融合阈值 - if comm_type == _ParallelFusionConfig.ALLREDUCE: - self._context_handle.set_fusion_threshold_mb(fusion_threshold) - if comm_type == _ParallelFusionConfig.ALLGATHER: - self._context_handle.set_allgather_fusion_threshold_mb(fusion_threshold) - if comm_type == _ParallelFusionConfig.REDUCESCATTER: - self._context_handle.set_reducescatter_fusion_threshold_mb(fusion_threshold) +def fusion_threshold_mb(self): + """ + 获取 allreduce 的融合阈值。 + """ + self.check_context_handle() # 检查上下文句柄 + return self._context_handle.fusion_threshold_mb() # 返回 allreduce 的融合阈值 - def fusion_threshold_mb(self): - """Get all reduce threshold.""" - self.check_context_handle() - return self._context_handle.fusion_threshold_mb() - def allgather_fusion_threshold_mb(self): - """Get allgather threshold.""" - self.check_context_handle() - return self._context_handle.allgather_fusion_threshold_mb() +def allgather_fusion_threshold_mb(self): + """ + 获取 allgather 的融合阈值。 + """ + self.check_context_handle() # 检查上下文句柄 + return self._context_handle.allgather_fusion_threshold_mb() # 返回 allgather 的融合阈值 - def reducescatter_fusion_threshold_mb(self): - """Get reducescatter threshold.""" - self.check_context_handle() - return self._context_handle.reducescatter_fusion_threshold_mb() - def set_global_rank(self, global_rank): - """ - Set global rank for auto parallel. +def reducescatter_fusion_threshold_mb(self): + """ + 获取 reducescatter 的融合阈值。 + """ + self.check_context_handle() # 检查上下文句柄 + return self._context_handle.reducescatter_fusion_threshold_mb() # 返回 reducescatter 的融合阈值 - Args: - global_rank (int): The rank id of current rank. - Raises: - ValueError: If the global rank is not in [1, 4096]. 
- """ - self.check_context_handle() - if global_rank < 0 or global_rank > 4095: - raise ValueError("The context configuration parameter 'global_rank' must be in [0, 4095], " - "but got the value of global_rank : {}.".format(global_rank)) - self._context_handle.set_global_rank(global_rank) +def set_global_rank(self, global_rank): + """ + 设置自动并行的全局排名。 - def get_global_rank(self): - """Get current rank id.""" - self.check_context_handle() - return self._context_handle.get_global_rank() - - def set_pipeline_stages(self, stages): - """Set the stages of the pipeline""" - if isinstance(stages, bool) or not isinstance(stages, int): - raise TypeError("For 'set_auto_parallel_context', the argument 'pipeline_stages' " - "must be int, but got the type : {}.".format(type(stages))) - if stages < 1: - raise ValueError("For 'set_auto_parallel_context', the argument 'pipeline_stages' " - "should be greater or equal 1, but got the value of stages : {}.".format(stages)) - self.check_context_handle() - self._context_handle.set_pipeline_stage_split_num(stages) + 参数: + global_rank (int): 当前排名的排名 ID。 - def get_pipeline_stages(self): - """Get the stages of the pipeline""" - self.check_context_handle() - return self._context_handle.get_pipeline_stage_split_num() + 引发: + ValueError: 如果全局排名不在 [0, 4095] 范围内。 + """ + self.check_context_handle() # 检查上下文句柄 + if global_rank < 0 or global_rank > 4095: + raise ValueError("The context configuration parameter 'global_rank' must be in [0, 4095], " + "but got the value of global_rank : {}.".format(global_rank)) # 检查全局排名是否有效 + self._context_handle.set_global_rank(global_rank) # 设置全局排名 - def set_gradients_mean(self, gradients_mean): - """ - Set gradients_mean flag. - Note: - If gradients_mean is true, it will insert a div operator after parameter gradients allreduce. +def get_global_rank(self): + """ + 获取当前排名 ID。 + """ + self.check_context_handle() # 检查上下文句柄 + return self._context_handle.get_global_rank() # 返回当前排名 ID - Args: - gradients_mean (bool): The gradients_mean flag. - """ - self.check_context_handle() - self._context_handle.set_gradients_mean(gradients_mean) - def get_gradients_mean(self): - """Get gradients_mean flag.""" - self.check_context_handle() - return self._context_handle.get_gradients_mean() +def set_pipeline_stages(self, stages): + """ + 设置流水线的阶段数。 + """ + if isinstance(stages, bool) or not isinstance(stages, int): + raise TypeError("For 'set_auto_parallel_context', the argument 'pipeline_stages' " + "must be int, but got the type : {}.".format(type(stages))) # 检查阶段数类型是否为整数 + if stages < 1: + raise ValueError("For 'set_auto_parallel_context', the argument 'pipeline_stages' " + "should be greater or equal 1, but got the value of stages : {}.".format(stages)) # 检查阶段数是否有效 + self.check_context_handle() # 检查上下文句柄 + self._context_handle.set_pipeline_stage_split_num(stages) # 设置阶段数 - def set_gradient_fp32_sync(self, gradient_fp32_sync): - """ - Set gradient_fp32_sync. - Note: - If gradient_fp32_sync is true, - it will convert tensor type from fp16 to fp32 before parameter gradients allreduce. +def get_pipeline_stages(self): + """ + 获取流水线的阶段数。 + """ + self.check_context_handle() # 检查上下文句柄 + return self._context_handle.get_pipeline_stage_split_num() # 返回阶段数 - Args: - gradient_fp32_sync (bool): The gradient_fp32_sync flag. 
- """ - self.check_context_handle() - self._context_handle.set_gradient_fp32_sync(gradient_fp32_sync) - def get_gradient_fp32_sync(self): - """Get gradient_fp32_sync flag.""" - self.check_context_handle() - return self._context_handle.get_gradient_fp32_sync() +def set_gradients_mean(self, gradients_mean): + """ + 设置 gradients_mean 标志。 - def set_loss_repeated_mean(self, loss_repeated_mean): - """ - Set loss_repeated_mean flag. + 注意: + 如果 gradients_mean 为真,将在参数梯度 allreduce 后插入一个除法操作符。 - Note: - If loss_repeated_mean is true, - Distributed automatic differentiation will perform a mean operator - in backward in the case of repeated calculations. + 参数: + gradients_mean (bool): gradients_mean 标志。 + """ + self.check_context_handle() # 检查上下文句柄 + self._context_handle.set_gradients_mean(gradients_mean) # 设置 gradients_mean 标志 - Args: - loss_repeated_mean (bool): The loss_repeated_mean flag. - """ - if not isinstance(loss_repeated_mean, bool): - raise TypeError("For 'set_auto_parallel_context', the argument 'loss_repeated_mean' " - "must be bool, but got the type : {}.".format(type(loss_repeated_mean))) - self.check_context_handle() - self._context_handle.set_loss_repeated_mean(loss_repeated_mean) +def get_gradients_mean(self): + """ + 获取 gradients_mean 标志。 + """ + self.check_context_handle() # 检查上下文句柄 + return self._context_handle.get_gradients_mean() # 返回 gradients_mean 标志 + + +def set_gradient_fp32_sync(self, gradient_fp32_sync): + """ + 设置 gradient_fp32_sync。 + + 注意: + 如果 gradient_fp32_sync 为真,将在参数梯度 allreduce 前将张量类型从 fp16 转换为 fp32。 + + 参数: + gradient_fp32_sync (bool): gradient_fp32_sync 标志。 + """ + self.check_context_handle() # 检查上下文句柄 + self._context_handle.set_gradient_fp32_sync(gradient_fp32_sync) # 设置 gradient_fp32_sync 标志 + + +def get_gradient_fp32_sync(self): + """ + 获取 gradient_fp32_sync 标志。 + """ + self.check_context_handle() # 检查上下文句柄 + return self._context_handle.get_gradient_fp32_sync() # 返回 gradient_fp32_sync 标志 + + +def set_loss_repeated_mean(self, loss_repeated_mean): + """ + 设置 loss_repeated_mean 标志。 + + 注意: + 如果 loss_repeated_mean 为真,在重复计算的情况下,分布式自动微分将执行一个均值操作符。 + + 参数: + loss_repeated_mean (bool): loss_repeated_mean 标志。 + """ + if not isinstance(loss_repeated_mean, bool): + raise TypeError("For 'set_auto_parallel_context', the argument 'loss_repeated_mean' " + "must be bool, but got the type : {}.".format(type(loss_repeated_mean))) # 检查类型是否为布尔 def get_loss_repeated_mean(self): """Get loss_repeated_mean flag.""" self.check_context_handle() return self._context_handle.get_loss_repeated_mean() - def set_parallel_mode(self, parallel_mode): - """ - Set parallel mode for auto parallel. +def set_parallel_mode(self, parallel_mode): + """ + 设置自动并行的并行模式。 - Args: - parallel_mode (str): The parallel mode of auto parallel. + 参数: + parallel_mode (str): 自动并行的并行模式。 - Raises: - ValueError: If parallel mode is not supported. 
- """ - self.check_context_handle() - run_mode = context.get_context("mode") - if run_mode == context.PYNATIVE_MODE and parallel_mode not in ( - context.ParallelMode.DATA_PARALLEL, context.ParallelMode.STAND_ALONE, - context.ParallelMode.AUTO_PARALLEL): - raise ValueError(f"Pynative only supports STAND_ALONE, DATA_PARALLEL and AUTO_PARALLEL using" - f" sharding_propagation under shard function" - f" for ParallelMode, " - f"but got {parallel_mode.upper()}.") - ret = self._context_handle.set_parallel_mode(parallel_mode) - if ret is False: - raise ValueError("The context configuration parameter 'parallel_mode' only support 'stand_alone', " - "'data_parallel', 'hybrid_parallel', 'semi_auto_parallel' and 'auto_parallel', " - "but got the value : {}.".format(parallel_mode)) - - def get_parallel_mode(self): - """Get parallel mode.""" - self.check_context_handle() - if _is_role_pserver(): - return context.ParallelMode.STAND_ALONE - return self._context_handle.get_parallel_mode() + 引发: + ValueError: 如果并行模式不支持。 + """ + self.check_context_handle() # 检查上下文句柄是否有效 + run_mode = context.get_context("mode") # 获取当前的运行模式 + if run_mode == context.PYNATIVE_MODE and parallel_mode not in ( + context.ParallelMode.DATA_PARALLEL, context.ParallelMode.STAND_ALONE, + context.ParallelMode.AUTO_PARALLEL): + raise ValueError(f"Pynative 模式仅支持 STAND_ALONE, DATA_PARALLEL 和 AUTO_PARALLEL 使用" + f" sharding_propagation 在 shard 函数下" + f" 作为并行模式, " + f"但得到的是 {parallel_mode.upper()}.") # 检查 Pynative 模式下支持的并行模式 + ret = self._context_handle.set_parallel_mode(parallel_mode) # 设置并行模式 + if ret is False: + raise ValueError("上下文配置参数 'parallel_mode' 仅支持 'stand_alone', " + "'data_parallel', 'hybrid_parallel', 'semi_auto_parallel' 和 'auto_parallel', " + "但得到的是 : {}.".format(parallel_mode)) # 检查设置的并行模式是否有效 + + +def get_parallel_mode(self): + """ + 获取并行模式。 + """ + self.check_context_handle() # 检查上下文句柄是否有效 + if _is_role_pserver(): # 检查当前角色是否为参数服务器 + return context.ParallelMode.STAND_ALONE # 参数服务器默认使用 STAND_ALONE 模式 + return self._context_handle.get_parallel_mode() # 返回当前并行模式 - def set_strategy_search_mode(self, search_mode): - """ - Set search mode of strategy. - Args: - search_mode (str): The search mode of strategy. - """ - self.check_context_handle() - ret = self._context_handle.set_strategy_search_mode(search_mode) - if ret is False: - raise ValueError("The context configuration parameter 'auto_parallel_search_mode' only support " - "'recursive_programming', 'dynamic_programming' and 'sharding_propagation', " - "but got the value: {}." - .format(search_mode)) - - def get_strategy_search_mode(self): - """Get search mode of strategy.""" - self.check_context_handle() - return self._context_handle.get_strategy_search_mode() +def set_strategy_search_mode(self, search_mode): + """ + 设置策略搜索模式。 - def set_auto_parallel_search_mode(self, search_mode): - """ - Set search mode of strategy searching. This is the old version of 'search_mode', and will be deleted in a future - MindSpore version. + 参数: + search_mode (str): 策略搜索模式。 + """ + self.check_context_handle() # 检查上下文句柄是否有效 + ret = self._context_handle.set_strategy_search_mode(search_mode) # 设置策略搜索模式 + if ret is False: + raise ValueError("上下文配置参数 'auto_parallel_search_mode' 仅支持 " + "'recursive_programming', 'dynamic_programming' 和 'sharding_propagation', " + "但得到的是: {}." + .format(search_mode)) # 检查设置的搜索模式是否有效 - Args: - search_mode (str): The search mode of strategy. - """ - logger.warning("The attribute 'auto_parallel_search_mode' is currently replaced by 'search_mode'. 
" - "The attribute 'auto_parallel_search_mode' will be deleted in a future MindSpore version.") - self.check_context_handle() - ret = self._context_handle.set_strategy_search_mode(search_mode) - if ret is False: - raise ValueError("The context configuration parameter 'search_mode' only support " - "'recursive_programming', 'dynamic_programming' and 'sharding_propagation', " - "but got the value: {}." - .format(search_mode)) - - def get_auto_parallel_search_mode(self): - """Get search mode of strategy. This is the old version of 'search_mode', and will be deleted in a future - MindSpore version. - """ - logger.warning("The attribute 'auto_parallel_search_mode' is currently replaced by 'search_mode'. " - "The attribute 'auto_parallel_search_mode' will be deleted in a future MindSpore version.") - self.check_context_handle() - return self._context_handle.get_strategy_search_mode() - def set_sharding_propagation(self, sharding_propagation): - """ - Set the value of sharding strategy propagation in AUTO_PARALLEL mode. If True, the strategy-configured operators - will propagate the strategies to other operators with minimum redistribution cost; otherwise, the algorithm - will search the desired strategies. Default: False. - This attribute is replaced by context.set_auto_parallel(search_mode="sharding_propagation"). +def get_strategy_search_mode(self): + """ + 获取策略搜索模式。 + """ + self.check_context_handle() # 检查上下文句柄是否有效 + return self._context_handle.get_strategy_search_mode() # 返回当前策略搜索模式 - Args: - sharding_propagation (bool): Enable/disable strategy propagation. - """ - logger.warning("This attribute is replaced by context.set_auto_parallel(search_mode='sharding_propagation'), " - "and this attribute will be deleted in a future MindSpore version.") - self.check_context_handle() - if not isinstance(sharding_propagation, bool): - raise TypeError("For 'set_auto_parallel_context().set_sharding_propagation', " - "the argument 'sharding_propagation' must be bool, but got the type : {}." - .format(type(sharding_propagation))) - self._context_handle.set_sharding_propagation(sharding_propagation) - - def get_sharding_propagation(self): - """Get the value of sharding strategy propagation.""" - self.check_context_handle() - return self._context_handle.get_sharding_propagation() - def set_parameter_broadcast(self, parameter_broadcast): - """ - Set parameter broadcast. +def set_auto_parallel_search_mode(self, search_mode): + """ + 设置策略搜索模式(旧版)。这是 'search_mode' 的旧版本,将在未来的 MindSpore 版本中删除。 - Args: - parameter_broadcast (bool): Parameter broadcast or not. - """ - self.check_context_handle() - self._context_handle.set_parameter_broadcast(parameter_broadcast) + 参数: + search_mode (str): 策略搜索模式。 + """ + logger.warning("属性 'auto_parallel_search_mode' 目前已被 'search_mode' 替代。" + "属性 'auto_parallel_search_mode' 将在未来的 MindSpore 版本中删除。") + self.check_context_handle() # 检查上下文句柄是否有效 + ret = self._context_handle.set_strategy_search_mode(search_mode) # 设置策略搜索模式 + if ret is False: + raise ValueError("上下文配置参数 'search_mode' 仅支持 " + "'recursive_programming', 'dynamic_programming' 和 'sharding_propagation', " + "但得到的是: {}." 
+ .format(search_mode)) # 检查设置的搜索模式是否有效 + + +def get_auto_parallel_search_mode(self): + """ + 获取策略搜索模式(旧版)。这是 'search_mode' 的旧版本,将在未来的 MindSpore 版本中删除。 + """ + logger.warning("属性 'auto_parallel_search_mode' 目前已被 'search_mode' 替代。" + "属性 'auto_parallel_search_mode' 将在未来的 MindSpore 版本中删除。") + self.check_context_handle() # 检查上下文句柄是否有效 + return self._context_handle.get_strategy_search_mode() # 返回当前策略搜索模式 - def get_parameter_broadcast(self): - """Get parameter broadcast flag.""" - self.check_context_handle() - return self._context_handle.get_parameter_broadcast() - def set_strategy_ckpt_load_file(self, strategy_ckpt_load_file): - """ - Set strategy checkpoint load path. +def set_sharding_propagation(self, sharding_propagation): + """ + 设置 AUTO_PARALLEL 模式下分片策略的传播值。如果为 True,配置策略的操作符将策略传播到其他操作符,以最小化重新分配成本; + 否则,算法将搜索所需的策略。默认:False。 + 此属性已被 context.set_auto_parallel(search_mode="sharding_propagation") 替代。 - Args: - strategy_ckpt_load_file (str): Path to load parallel strategy checkpoint. - """ - self.check_context_handle() - self._context_handle.set_strategy_ckpt_load_file(strategy_ckpt_load_file) + 参数: + sharding_propagation (bool): 启用/禁用策略传播。 + """ + logger.warning("此属性已被 context.set_auto_parallel(search_mode='sharding_propagation') 替代," + "并且此属性将在未来的 MindSpore 版本中删除。") + self.check_context_handle() # 检查上下文句柄是否有效 + if not isinstance(sharding_propagation, bool): + raise TypeError("对于 'set_auto_parallel_context().set_sharding_propagation', " + "参数 'sharding_propagation' 必须为 bool 类型,但得到的是: {}." + .format(type(sharding_propagation))) # 检查类型是否为布尔 + self._context_handle.set_sharding_propagation(sharding_propagation) # 设置分片策略传播值 - def get_strategy_ckpt_load_file(self): - """Get strategy checkpoint load path.""" - self.check_context_handle() - return self._context_handle.get_strategy_ckpt_load_file() - def set_full_batch(self, full_batch): - """ - Set whether load full batch on each device. +def get_sharding_propagation(self): + """ + 获取分片策略的传播值。 + """ + self.check_context_handle() # 检查上下文句柄是否有效 + return self._context_handle.get_sharding_propagation() # 返回分片策略传播值 - Args: - full_batch (bool): True if load full batch on each device. - """ - self.check_context_handle() - self._context_handle.set_full_batch(full_batch) - def get_full_batch(self): - """Get whether load full batch on each device.""" - self.check_context_handle() - if _is_role_pserver(): - return False - return self._context_handle.get_full_batch() +def set_parameter_broadcast(self, parameter_broadcast): + """ + 设置参数广播。 - def set_dataset_strategy(self, dataset_strategy): - """ - Set dataset sharding strategy. + 参数: + parameter_broadcast (bool): 是否广播参数。 + """ + self.check_context_handle() # 检查上下文句柄是否有效 + self._context_handle.set_parameter_broadcast(parameter_broadcast) # 设置参数广播标志 - Args: - dataset_strategy (str or tuple(tuple)): The dataset sharding strategy. - """ - self.check_context_handle() - if isinstance(dataset_strategy, str): - if dataset_strategy not in ("full_batch", "data_parallel"): - raise ValueError("For 'set_auto_parallel_context', the argument " - "'dataset_strategy' must be 'full_batch' or 'data_parallel', but got the value : {}." 
- .format(dataset_strategy)) - self._context_handle.set_full_batch(dataset_strategy == "full_batch") - self._dataset_strategy_using_str = True - return - if not isinstance(dataset_strategy, tuple): - raise TypeError("For 'set_auto_parallel_context', the argument 'dataset_strategy' " - "must be str or tuple type, but got the type : {}.".format(type(dataset_strategy))) - for ele in dataset_strategy: - if not isinstance(ele, tuple): - raise TypeError("For 'set_auto_parallel_context', the element of argument " - "'dataset_strategy' must be tuple, but got the type : {} .".format(type(ele))) - for dim in ele: - if not isinstance(dim, int): - raise TypeError("For 'set_auto_parallel_context', the element of argument " - "'dataset_strategy' must be int type, but got the type : {} .".format(type(dim))) - self._dataset_strategy_using_str = False - self._context_handle.set_dataset_strategy(dataset_strategy) - - def get_dataset_strategy(self): - """Get dataset sharding strategy.""" - self.check_context_handle() - if self._dataset_strategy_using_str: - if self._context_handle.get_full_batch(): - return "full_batch" - return "data_parallel" - return self._context_handle.get_dataset_strategy() - def set_grad_accumulation_step(self, grad_accumulation_step): - """ - Set grad accumulation step. +def get_parameter_broadcast(self): + """ + 获取参数广播标志。 + """ + self.check_context_handle() # 检查上下文句柄是否有效 + return self._context_handle.get_parameter_broadcast() # 返回参数广播标志 - Args: - grad_accumulation_step (int): The grad accumulation step. - """ - self.check_context_handle() - Validator.check_positive_int(grad_accumulation_step) - self._context_handle.set_grad_accumulation_step(grad_accumulation_step) - def get_grad_accumulation_step(self): - """Get grad accumulation step.""" - self.check_context_handle() - return self._context_handle.get_grad_accumulation_step() +def set_strategy_ckpt_load_file(self, strategy_ckpt_load_file): + """ + 设置策略检查点加载路径。 - def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file): - """ - Set strategy checkpoint save path. + 参数: + strategy_ckpt_load_file (str): 加载并行策略检查点的路径。 + """ + self.check_context_handle() # 检查上下文句柄是否有效 + self._context_handle.set_strategy_ckpt_load_file(strategy_ckpt_load_file) # 设置策略检查点加载路径 - Args: - strategy_ckpt_save_file (bool): Path to save parallel strategy checkpoint. 
- """ - self.check_context_handle() - dir_path = os.path.dirname(strategy_ckpt_save_file) - if dir_path and not os.path.exists(dir_path): - os.makedirs(dir_path) - self._context_handle.set_strategy_ckpt_save_file(strategy_ckpt_save_file) + +def get_strategy_ckpt_load_file(self): + """ + 获取策略检查点加载路径。 + """ + self.check_context_handle() # 检查上下文句柄是否有效 + return self._context_handle.get_strategy_ckpt_load_file() # 返回策略检查点加载路径 + + +def set_full_batch(self, full_batch): + """ + 设置是否在每个设备上加载完整批次。 + + 参数: + full_batch (bool): 如果在每个设备上加载完整批次则为 True。 + """ + self.check_context_handle() # 检查上下文句柄是否有效 + self._context_handle.set_full_batch(full_batch) # 设置是否加载完整批次 + + +def get_full_batch(self): + """ + 获取是否在每个设备上加载完整批次。 + """ + self.check_context_handle() # 检查上下文句柄是否有效 + if _is_role_pserver(): # 检查当前角色是否为参数服务器 + return False # 参数服务器不加载完整批次 + return self._context_handle.get_full_batch() # 返回是否加载完整批次 + + +def set_dataset_strategy(self, dataset_strategy): + """ + 设置数据集分片策略。 + + 参数: + dataset_strategy (str or tuple(tuple)): 数据集分片策略。 + """ + self.check_context_handle() # 检查上下文句柄是否有效 + if isinstance(dataset_strategy, str): + if dataset_strategy not in ("full_batch", "data_parallel"): + raise ValueError("对于 'set_auto_parallel_context', 参数 " + "'dataset_strategy' 必须为 'full_batch' 或 'data_parallel', 但得到的是: {}." + .format(dataset_strategy)) # 检查字符串策略是否有效 + self._context_handle.set_full_batch(dataset_strategy == "full_batch") # 设置是否加载完整批次 + self._dataset_strategy_using_str = True # 标记使用字符串策略 + return + if not isinstance(dataset_strategy, tuple): + raise TypeError("对于 'set_auto_parallel_context', 参数 'dataset_strategy' " + "必须为 str 或 tuple 类型,但得到的是: {}." + .format(type(dataset_strategy))) # 检查类型是否为元组 + for ele in dataset_strategy: + if not isinstance(ele, tuple): + raise TypeError("对于 'set_auto_parallel_context', 参数 'dataset_strategy' 的元素 " + "必须为 tuple 类型,但得到的是: {} .".format(type(ele))) # 检查元素类型是否为元组 + for dim in ele: + if not isinstance(dim, int): + raise TypeError("对于 'set_auto_parallel_context', 参数 'dataset_strategy' 的元素 " + "必须为 int 类型,但得到的是: {} .".format(type(dim))) # 检查维度类型是否为整数 + self._dataset_strategy_using_str = False # 标记不使用字符串策略 + self._context_handle.set_dataset_strategy(dataset_strategy) # 设置数据集分片策略 + + +def get_dataset_strategy(self): + """ + 获取数据集分片策略。 + """ + self.check_context_handle() # 检查上下文句柄是否有效 + if self._dataset_strategy_using_str: # 如果使用字符串策略 + if self._context_handle.get_full_batch(): + return "full_batch" # 返回完整批次策略 + return "data_parallel" # 返回数据并行策略 + return self._context_handle.get_dataset_strategy() # 返回数据集分片策略 + + +def set_grad_accumulation_step(self, grad_accumulation_step): + """ + 设置梯度累积步长。 + + 参数: + grad_accumulation_step (int): 梯度累积步长。 + """ + self.check_context_handle() # 检查上下文句柄是否有效 + Validator.check_positive_int(grad_accumulation_step) # 检查步长是否为正整数 + self._context_handle.set_grad_accumulation_step(grad_accumulation_step) # 设置梯度累积步长 + + +def get_grad_accumulation_step(self): + """ + 获取梯度累积步长。 + """ + self.check_context_handle() # 检查上下文句柄是否有效 + return self._context_handle.get_grad_accumulation_step() # 返回梯度累积步长 + + +def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file): + """ + 设置策略检查点保存路径。 + + 参数: + strategy_ckpt_save_file (bool): 保存并行策略检查点的路径。 + """ + self.check_context_handle() # 检查上下文句柄是否有效 + dir_path = os.path.dirname(strategy_ckpt_save_file) # 获取目录路径 + if dir_path and not os.path.exists(dir_path): + os.makedirs(dir_path) # 如果目录不存在,则创建目录 + self._context_handle.set_strategy_ckpt_save_file(strategy_ckpt_save_file) # 设置策略检查点保存路径 + + +def 
get_strategy_ckpt_save_file(self): + """ + 获取策略检查点保存路径。 + """ + self.check_context_handle() # 检查上下文句柄是否有效 + return self._context_handle.get_strategy_ckpt_save_file() # 返回策略检查点保存路径 + + +def set_group_ckpt_save_file(self, group_ckpt_save_file): + """ + 设置组检查点保存路径。 + """ + self.check_context_handle() # 检查上下文句柄是否有效 + dir_path = os.path.dirname(group_ckpt_save_file) # 获取目录路径 + if dir_path and not os.path.exists(dir_path): + os.makedirs(dir_path) + self._context_handle.set_strategy_ckpt_save_file(strategy_ckpt_save_file) def get_strategy_ckpt_save_file(self): """Get strategy checkpoint save path.""" @@ -817,114 +900,133 @@ class _AutoParallelContext: def set_enable_alltoall(self, enable_a2a): """ - Set the value of enabling AllToAll. If False, AllGather and Split are used to circumvent AllToAll. - Default: False. + 设置启用 AllToAll 的值。如果为 False,则使用 AllGather 和 Split 来绕过 AllToAll。 + 默认:False。 - Args: - enable_a2a (bool): Enable/disable AllToAll. + 参数: + enable_a2a (bool): 启用/禁用 AllToAll。 """ - self.check_context_handle() - if not isinstance(enable_a2a, bool): - raise TypeError("For 'set_auto_parallel_context().set_enable_alltoall', the argument 'enable_a2a' " - "must be bool, but got the type : {}.".format(type(enable_a2a))) - self._context_handle.set_enable_alltoall(enable_a2a) + self.check_context_handle() # 检查上下文句柄是否有效 + if not isinstance(enable_a2a, bool): + raise TypeError("对于 'set_auto_parallel_context().set_enable_alltoall', 参数 'enable_a2a' " + "必须为 bool 类型,但得到的是: {}.".format(type(enable_a2a))) # 检查类型是否为布尔 + self._context_handle.set_enable_alltoall(enable_a2a) # 设置启用 AllToAll 的值 - def get_enable_alltoall(self): - """Get the value of enabling AllToAll.""" - self.check_context_handle() - return self._context_handle.get_enable_alltoall() - def set_communi_parallel_mode(self, communi_parallel_mode): - """ - Set communication parallel mode. +def get_enable_alltoall(self): + """ + 获取启用 AllToAll 的值。 + """ + self.check_context_handle() # 检查上下文句柄是否有效 + return self._context_handle.get_enable_alltoall() # 返回启用 AllToAll 的值 - Args: - communi_parallel_mode (str): The communication parallel mode. - Raises: - ValueError: If parallel mode is not supported. - """ - if not isinstance(communi_parallel_mode, str): - raise TypeError("For 'set_auto_parallel_context().set_communi_parallel_mode', " - "the argument 'communi_parallel_mode' must be str, but got the type : {}." - .format(type(communi_parallel_mode))) - self.check_context_handle() - ret = self._context_handle.set_communi_parallel_mode(communi_parallel_mode) - if ret is False: - raise ValueError("For 'set_auto_parallel_context().set_communi_parallel_mode', " - "the argument 'communi_parallel_mode' only support 'ALL_GROUP_PARALLEL', " - "'SAME_SEVER_GROUP_PARALLEL' and 'NO_GROUP_PARALLEL', " - "but got the value : {}.".format(communi_parallel_mode)) - - def get_communi_parallel_mode(self): - """Get communication parallel mode.""" - self.check_context_handle() - return self._context_handle.get_communi_parallel_mode() +def set_communi_parallel_mode(self, communi_parallel_mode): + """ + 设置通信并行模式。 - def set_optimizer_weight_shard_size(self, optimizer_weight_shard_size): - """ - Set optimizer_weight_shard_size. + 参数: + communi_parallel_mode (str): 通信并行模式。 - Args: - optimizer_weight_shard_size (int): Opt shard group size when not globally use parallel - optimizer across devices. 
- """ - self.check_context_handle() - if not isinstance(optimizer_weight_shard_size, int) or isinstance(optimizer_weight_shard_size, bool): - raise TypeError(f"The type of optimizer_weight_shard_size must be int, \ - but got {type(optimizer_weight_shard_size)}.") - if optimizer_weight_shard_size <= 1: - logger.warning("The setting 'optimizer_weight_shard_size' is invalid. " - "Please use the integer larger than 1.") - return - self._context_handle.set_optimizer_weight_shard_size(optimizer_weight_shard_size) - - def get_optimizer_weight_shard_size(self): - """Get optimizer_weight_shard_size.""" - self.check_context_handle() - return self._context_handle.get_optimizer_weight_shard_size() + 引发: + ValueError: 如果并行模式不支持。 + """ + if not isinstance(communi_parallel_mode, str): + raise TypeError("对于 'set_auto_parallel_context().set_communi_parallel_mode', " + "参数 'communi_parallel_mode' 必须为 str 类型,但得到的是: {}." + .format(type(communi_parallel_mode))) # 检查类型是否为字符串 + self.check_context_handle() # 检查上下文句柄是否有效 + ret = self._context_handle.set_communi_parallel_mode(communi_parallel_mode) # 设置通信并行模式 + if ret is False: + raise ValueError("对于 'set_auto_parallel_context().set_communi_parallel_mode', " + "参数 'communi_parallel_mode' 仅支持 'ALL_GROUP_PARALLEL', " + "'SAME_SEVER_GROUP_PARALLEL' 和 'NO_GROUP_PARALLEL', " + "但得到的是: {}.".format(communi_parallel_mode)) # 检查设置的并行模式是否有效 + + +def get_communi_parallel_mode(self): + """ + 获取通信并行模式。 + """ + self.check_context_handle() # 检查上下文句柄是否有效 + return self._context_handle.get_communi_parallel_mode() # 返回当前通信并行模式 - def set_optimizer_weight_shard_aggregated_save(self, optimizer_weight_shard_aggregated_save): - """ - Set optimizer_weight_shard_aggregated_save. - Args: - optimizer_weight_shard_aggregated_save (bool): Whether to integrated save weight shard when - enable parallel optimizer. 
- """ - self.check_context_handle() - if not isinstance(optimizer_weight_shard_aggregated_save, bool): - raise TypeError('optimizer_weight_shard_aggregated_save is invalid type') - self._context_handle.set_optimizer_weight_shard_aggregated_save(optimizer_weight_shard_aggregated_save) +def set_optimizer_weight_shard_size(self, optimizer_weight_shard_size): + """ + 设置 optimizer_weight_shard_size。 + + 参数: + optimizer_weight_shard_size (int): 当未全局使用跨设备的并行优化器时,优化器分片组大小。 + """ + self.check_context_handle() # 检查上下文句柄是否有效 + if not isinstance(optimizer_weight_shard_size, int) or isinstance(optimizer_weight_shard_size, bool): + raise TypeError(f"optimizer_weight_shard_size 的类型必须为 int, " + f"但得到的是: {type(optimizer_weight_shard_size)}.") # 检查类型是否为整数 + if optimizer_weight_shard_size <= 1: + logger.warning("设置 'optimizer_weight_shard_size' 无效。 " + "请使用大于 1 的整数。") + return # 如果分片大小无效,则不设置 + self._context_handle.set_optimizer_weight_shard_size(optimizer_weight_shard_size) # 设置优化器分片组大小 + + +def get_optimizer_weight_shard_size(self): + """ + 获取 optimizer_weight_shard_size。 + """ + self.check_context_handle() # 检查上下文句柄是否有效 + return self._context_handle.get_optimizer_weight_shard_size() # 返回优化器分片组大小 - def get_optimizer_weight_shard_aggregated_save(self): - """Get optimizer_weight_shard_size.""" - self.check_context_handle() - return self._context_handle.get_optimizer_weight_shard_aggregated_save() +def set_optimizer_weight_shard_aggregated_save(self, optimizer_weight_shard_aggregated_save): + """ + 设置 optimizer_weight_shard_aggregated_save。 + 参数: + optimizer_weight_shard_aggregated_save (bool): 在启用并行优化器时,是否集成保存权重分片。 + """ + self.check_context_handle() # 检查上下文句柄是否有效 + if not isinstance(optimizer_weight_shard_aggregated_save, bool): + raise TypeError('optimizer_weight_shard_aggregated_save 的类型无效') # 检查类型是否为布尔 + self._context_handle.set_optimizer_weight_shard_aggregated_save(optimizer_weight_shard_aggregated_save) # 设置是否集成保存权重分片 - def reset(self): - """Reset all settings.""" - self.check_context_handle() - self._context_handle.reset() +def get_optimizer_weight_shard_aggregated_save(self): + """ + 获取 optimizer_weight_shard_size。 + """ + self.check_context_handle() # 检查上下文句柄是否有效 + return self._context_handle.get_optimizer_weight_shard_aggregated_save() # 返回是否集成保存权重分片 - def _check_and_default_group(self, group): - """Validate the given group, if group is empty, returns a default fusion group""" - if isinstance(group, (str)): - group_len = len(group) - if group_len > _MAX_GROUP_NAME_LEN: - raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}') + +def reset(self): + """ + 重置所有设置。 + """ + self.check_context_handle() # 检查上下文句柄是否有效 + self._context_handle.reset() # 重置上下文句柄 + + +def _check_and_default_group(self, group): + """ + 验证给定的组,如果组为空,则返回默认的融合组。 + """ + if isinstance(group, (str)): + group_len = len(group) + if group_len > _MAX_GROUP_NAME_LEN: + raise ValueError(f'组名长度超出范围 {_MAX_GROUP_NAME_LEN}') # 检查组名长度是否超出范围 + else: + raise TypeError('组必须是 Python 字符串') # 检查类型是否为字符串 + + if group == "": + if context.get_context("device_target") == "Ascend": + group = _DEFAULT_HCCL_FUSION_GROUP_NAME # Ascend 设备使用默认 HCCL 融合组名 else: - raise TypeError('Group must be a python str') + group = _DEFAULT_NCCL_FUSION_GROUP_NAME # 其他设备使用默认 NCCL 融合组名 + return group # 返回验证后的组名 - if group == "": - if context.get_context("device_target") == "Ascend": - group = _DEFAULT_HCCL_FUSION_GROUP_NAME - else: - group = _DEFAULT_NCCL_FUSION_GROUP_NAME - return group + +_auto_parallel_context = None # 全局自动并行上下文的实例,初始化为 None``` 
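The hunk above documents the `comm_fusion` configuration dictionary consumed by `set_comm_fusion`: one entry per communication operator, each with a `mode` key (`auto`, `size`, or `index`, where `index` is only accepted for `allreduce`) and a `config` key carrying the threshold or split indices. Below is a minimal sketch of driving these setters through `auto_parallel_context()`, the accessor the class docstring itself recommends; the device count, thresholds, and split indices are illustrative assumptions, not values taken from the patch.

```python
# Hedged sketch of the comm_fusion interface described above; the concrete
# numbers (device count, thresholds, split indices) are arbitrary examples.
from mindspore.parallel._auto_parallel_context import auto_parallel_context

ctx = auto_parallel_context()   # singleton accessor recommended by the class docstring
ctx.set_device_num(8)           # must lie in [1, 4096]

# "size" mode: fuse allreduce communication into buckets up to a threshold in MB
ctx.set_comm_fusion({"allreduce": {"mode": "size", "config": 32}})

# "index" mode (allreduce only): split fusion at explicit parameter indices
ctx.set_comm_fusion({"allreduce": {"mode": "index", "config": [20, 35]}})

# allgather / reducescatter accept only "auto" or "size"; in "auto" mode the
# 64 MB default threshold is used and the "config" value is ignored
ctx.set_comm_fusion({"allgather": {"mode": "auto", "config": None}})

print(ctx.get_comm_fusion())    # reports the current allreduce mode and config
```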
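`set_dataset_strategy` accepts either one of the strings `"full_batch"` / `"data_parallel"` or a tuple of per-input layouts made of integer tuples. A short sketch of both forms follows; the `(8, 1)` layout for a single 2-D dataset input is an assumed example, not taken from the patch.

```python
# Hedged sketch of the two dataset_strategy forms validated above; the (8, 1)
# layout is an assumed example for one 2-D dataset input sharded over 8 devices.
from mindspore.parallel._auto_parallel_context import auto_parallel_context

ctx = auto_parallel_context()

ctx.set_dataset_strategy("data_parallel")   # string form: equivalent to full_batch=False
print(ctx.get_dataset_strategy())           # -> "data_parallel"

ctx.set_dataset_strategy(((8, 1),))         # tuple form: explicit layout per dataset input
print(ctx.get_dataset_strategy())           # -> the tuple that was set
```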
 _auto_parallel_context = None diff --git a/src/mindspore2022/mindspore/python/mindspore/parallel/_cell_wrapper.py b/src/mindspore2022/mindspore/python/mindspore/parallel/_cell_wrapper.py index fa063178..092ab801 100644 --- a/src/mindspore2022/mindspore/python/mindspore/parallel/_cell_wrapper.py +++ b/src/mindspore2022/mindspore/python/mindspore/parallel/_cell_wrapper.py @@ -30,10 +30,12 @@ class AllGatherCell(Cell): def __init__(self, group): super(AllGatherCell, self).__init__(auto_prefix=False) + # Create the AllGather operator object self.allgather = AllGather(group) @ms_function() def construct(self, x): + # Run the AllGather operation x = self.allgather(x) return x @@ -50,10 +52,12 @@ class SaveOptShardCkptCell(Cell): """ def __init__(self, group): super(SaveOptShardCkptCell, self).__init__(auto_prefix=False) + # Create the AllGather operator object self.allgather1 = AllGather(group) self.allgather2 = AllGather() def construct(self, x): + # Run the AllGather operation x = self.allgather1(x) x = self.allgather2(x) @@ -64,11 +68,14 @@ def get_allgather_cell(group, need_merge_twice=False): """Get AllGatherCell object.""" global _allgather_cell if need_merge_twice: + # If the slice has to be merged twice, create a SaveOptShardCkptCell object _allgather_cell = SaveOptShardCkptCell(group) else: if group: + # If a device group is specified, create an AllGatherCell object for that group _allgather_cell = AllGatherCell(group) else: + # Otherwise, create an AllGatherCell object on the global communication group _allgather_cell = AllGatherCell(GlobalComm.WORLD_COMM_GROUP) return _allgather_cell @@ -77,4 +84,5 @@ def destroy_allgather_cell(): """Destroy AllGatherCell object.""" global _allgather_cell if _allgather_cell: + # Destroy the AllGatherCell object _allgather_cell = None
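The second file's helpers wrap `AllGather` so that optimizer-sharded parameter slices can be merged back before saving. Below is a minimal usage sketch under stated assumptions: the process is launched as one rank of a distributed job so that `init()` can bring up the HCCL/NCCL backend, and the local tensor shape and contents are placeholders.

```python
# Hedged sketch: merging a locally sharded tensor with the helper cells above.
# Assumes a distributed launch so init() can create the world group; the
# tensor contents and shape are placeholders.
import numpy as np
from mindspore import Tensor
from mindspore.communication import init
from mindspore.parallel._cell_wrapper import get_allgather_cell, destroy_allgather_cell

init()                                      # set up the communication backend for this rank

local_slice = Tensor(np.arange(4, dtype=np.float32))
allgather = get_allgather_cell(group="", need_merge_twice=False)   # empty group -> world group
merged = allgather(local_slice)             # slices from all ranks, concatenated
print(merged.shape)

destroy_allgather_cell()                    # release the cached cell
```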