feat(boost): 增加对新模块的导入及中文注释和文档翻译

branch-yixin
yixin 2 months ago
parent c7222ba353
commit 510df7cf31

@ -19,16 +19,25 @@ accumulation and so on.
Note:
This feature is a beta feature, and we are still improving its functionality.
"""
# 从当前包的boost模块导入AutoBoost类
from .boost import AutoBoost
# 从当前包的base模块导入OptimizerProcess和ParameterProcess类
from .base import OptimizerProcess, ParameterProcess
# 从当前包的boost_cell_wrapper模块导入BoostTrainOneStepCell和BoostTrainOneStepWithLossScaleCell类
from .boost_cell_wrapper import BoostTrainOneStepCell, BoostTrainOneStepWithLossScaleCell
# 从当前包的less_batch_normalization模块导入LessBN类
from .less_batch_normalization import LessBN
# 从当前包的grad_freeze模块导入GradientFreeze, FreezeOpt和freeze_cell类或函数
from .grad_freeze import GradientFreeze, FreezeOpt, freeze_cell
# 从当前包的grad_accumulation模块导入GradientAccumulation类
from .grad_accumulation import GradientAccumulation
# 从当前包的adasum模块导入AdaSum类
from .adasum import AdaSum
# 从当前包的dim_reduce模块导入DimReduce类
from .dim_reduce import DimReduce
# 定义一个列表,包含所有要公开的模块成员
__all__ = ['AutoBoost',
'OptimizerProcess', 'ParameterProcess',
'BoostTrainOneStepCell', 'BoostTrainOneStepWithLossScaleCell',

@ -22,38 +22,41 @@ from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.operations._inner_ops import Send, Receive
__all__ = ["AdaSum"]
MAX_NUM_HASH = 2 ** 31
_update_parameters = C.MultitypeFuncGraph("update_parameters")
@_update_parameters.register("Tensor", "Tensor", "Tensor", "Tensor")
def _update_parameters_after_broadcast(delta_weight, update_delta_weight, parameter, old_parameter):
"""更新参数的函数在广播后应用delta_weight来更新参数."""
shape = F.shape(delta_weight)
update_delta_weight = P.Reshape()(update_delta_weight, shape)
new_parameter = old_parameter - update_delta_weight
return P.Assign()(parameter, new_parameter)
def _send_before_receive(send_part, send, recv):
"""在接收之前发送数据的辅助函数."""
send_ok = send(send_part)
return recv(send_ok)
def _receive_before_send(send_part, send, recv):
"""在发送之前接收数据的辅助函数."""
receive_ok = recv(send_part)
send_part = F.depend(send_part, receive_ok)
return F.depend(receive_ok, send(send_part))
def _send_recv_res(left_send, recv_part, local_part, allreduce, parameter_divisibility, allreduce_node_num):
"""send result and receive result."""
"""发送结果并接收结果的辅助函数."""
if parameter_divisibility:
recv_part = P.Squeeze()(recv_part)
local_part = F.depend(local_part, recv_part)
@ -76,14 +79,14 @@ def _send_recv_res(left_send, recv_part, local_part, allreduce, parameter_divisi
res = allreduce(local_part)
res /= allreduce_node_num
return res
_adasum_opt_forward = C.MultitypeFuncGraph("adasum_opt_forward")
@_adasum_opt_forward.register("Bool", "Function", "Bool", "Int64", "Function", "Function", "Tensor")
def _adasum_opt_forward_process(left_send, allreduce, parameter_divisibility, allreduce_node_num, send, recv, delta_w):
"""adasum optimizer process."""
"""adaSum优化器的前向过程处理函数."""
if parameter_divisibility:
delta_w = P.Squeeze()(delta_w)
ori_len = F.shape(delta_w)[0]
@ -93,7 +96,7 @@ def _adasum_opt_forward_process(left_send, allreduce, parameter_divisibility, al
else:
left_part = delta_w
right_part = delta_w
if left_send:
if parameter_divisibility:
recv_part = _send_before_receive(left_part, send, recv)
@ -108,26 +111,26 @@ def _adasum_opt_forward_process(left_send, allreduce, parameter_divisibility, al
recv_part = left_part
update_delta_w = _send_recv_res(left_send, recv_part, left_part, allreduce, parameter_divisibility,
allreduce_node_num)
return update_delta_w
_adasum_opt_rollback = C.MultitypeFuncGraph("adasum_opt_rollback")
@_adasum_opt_rollback.register("Bool", "Bool", "Tensor", "Function", "Function")
def _adasum_opt_rollback_process(left_send, parameter_divisibility, delta_w, send, recv):
"""adasum optimizer rollback process."""
"""adaSum优化器的回滚处理函数."""
if parameter_divisibility:
if left_send:
recv_part = _send_before_receive(delta_w, send, recv)
else:
recv_part = _receive_before_send(delta_w, send, recv)
recv_part = P.Squeeze()(recv_part)
recv_part = P.Reshape()(recv_part, (-1,))
delta_w = P.Reshape()(delta_w, (-1,))
if left_send:
res = P.Concat()((recv_part, delta_w))
else:
@ -135,28 +138,28 @@ def _adasum_opt_rollback_process(left_send, parameter_divisibility, delta_w, sen
else:
res = delta_w
return res
class AdaSum(Cell):
r"""
The Adaptive Summation, or AdaSum, is a novel algorithm for improving distributed data
parallel training of Deep Learning models.
自适应加法AdaSum是一种新算法用于改善深度学习模型的分布式数据并行训练
Args:
rank (int): Rank number.
device_number (int): Device number.
group_number (int): Group number.
parameter_tuple (Tuple(Parameter)): Tuple of parameters.
rank (int): 排名编号
device_number (int): 设备数量
group_number (int): 组数量
parameter_tuple (Tuple(Parameter)): 参数元组
Inputs:
- **delta_weights** (Tuple(Tensor)) - Tuple of gradients.
- **parameters** (Tuple(Parameter)) - Tuple of current parameters.
- **old_parameters** (Tuple(Parameter)) - Tuple of last parameters.
- **delta_weights** (Tuple(Tensor)) - 梯度的元组
- **parameters** (Tuple(Parameter)) - 当前参数的元组
- **old_parameters** (Tuple(Parameter)) - 上一参数的元组
Outputs:
- **adasum_parameters** (Tuple(Tensor)) - Tuple of parameters after adasum process.
- **adasum_parameters** (Tuple(Tensor)) - 经过adasum处理后的参数元组
"""
def __init__(self, rank, device_number, group_number, parameter_tuple):
"""AdaSum类的初始化函数."""
super(AdaSum, self).__init__()
self.rank = rank
self.device_number = device_number
@ -164,9 +167,9 @@ class AdaSum(Cell):
self.parameter_tuple = parameter_tuple
self._generate_communication_op()
self.hyper_map = C.HyperMap()
def _generate_communication_op(self):
"""generate communication op."""
"""生成通信操作的私有方法."""
self.calc_times = int(math.log(self.group_number, 2))
self.send_node = []
self.send_list_forward = []
@ -179,7 +182,7 @@ class AdaSum(Cell):
self.allreduce_node_num_list = []
last_delta_weights = []
group_start_rank = (self.rank // self.device_number) * self.device_number
for step in range(self.calc_times):
current_group = self.device_number * (2 ** step)
sr_target = self.rank
@ -189,7 +192,7 @@ class AdaSum(Cell):
else:
dest_target = sr_target - current_group
self.send_node.append(False)
neighbor_ids = []
group_name_last = 0
for index in range(2 ** (step + 1)):
@ -201,7 +204,7 @@ class AdaSum(Cell):
group_name_last += neighbor_id
group_name = "adasum_" + str(step) + "_" + str(group_name_last)
create_group(group_name, neighbor_ids)
send_left = []
send_right = []
recv_left = []
@ -234,7 +237,7 @@ class AdaSum(Cell):
send_right.append(send)
recv_right.append(recv)
weights_index += 1
if self.send_node and self.send_node[-1]:
self.send_list_forward.append(send_left)
self.send_list_rollback.append(send_right)
@ -247,27 +250,27 @@ class AdaSum(Cell):
self.recv_list_forward.append(recv_left)
self.recv_list_rollback.append(recv_right)
last_delta_weights = left_delta_weights
server_all_reduce = P.AllReduce("sum", group_name)
server_all_reduce.add_prim_attr("fusion", fusion_id + 2)
self.allreduce_list.append(server_all_reduce)
for param_divisibility in delta_weights_divisibility:
if param_divisibility:
allreduce_node_num += (0,)
else:
allreduce_node_num += (2 ** (step + 1),)
self.allreduce_node_num_list.append(allreduce_node_num)
broadcast_group = [x for x in range(group_start_rank, group_start_rank + self.device_number)]
broadcast_group_name = "broadcast_group_" + str(group_start_rank)
create_group(broadcast_group_name, broadcast_group)
for b_rank in range(len(broadcast_group)):
self.broadcast_list.append(P.Broadcast(b_rank, group=broadcast_group_name))
self.sync_barrier = P.AllReduce("sum", group=broadcast_group_name)
def _get_delta_weights_info(self, last_delta_weights):
"""get delta weights info."""
"""获取delta权重信息的私有方法."""
half_delta_weights = []
if last_delta_weights:
half_delta_weights = last_delta_weights
@ -292,14 +295,16 @@ class AdaSum(Cell):
right_delta_weights.append((right_shape, dtype))
delta_weights_divisibility += (divisibility_flag,)
return left_delta_weights, right_delta_weights, delta_weights_divisibility
def _hash(self, step, target, weights_index):
"""计算哈希值的私有方法."""
target = "tag" + str(step) + str(target) + str(weights_index)
target_hash = hashlib.sha1(target.encode()).hexdigest()
hash_res = int(int(target_hash, 16) % MAX_NUM_HASH)
return hash_res
def construct(self, delta_weights, parameters, old_parameters):
"""构建方法用于执行adaSum优化过程."""
forward_weights = [delta_weights]
for i in range(self.calc_times):
process_weights = self.hyper_map(F.partial(_adasum_opt_forward, self.send_node[i], self.allreduce_list[i]),
@ -314,4 +319,4 @@ class AdaSum(Cell):
forward_weights[j] = process_weights
adasum_parameters = self.hyper_map(F.partial(_update_parameters), delta_weights, forward_weights[0],
parameters, old_parameters)
return adasum_parameters
return adasum_parameters
Loading…
Cancel
Save