@@ -35,6 +35,7 @@ _update_parameters = C.MultitypeFuncGraph("update_parameters")
@_update_parameters.register("Tensor", "Tensor", "Tensor", "Tensor")
def _update_parameters_after_broadcast(delta_weight, update_delta_weight, parameter, old_parameter):
    """Update the parameter by applying the delta_weight received after broadcast."""
    shape = F.shape(delta_weight)
    update_delta_weight = P.Reshape()(update_delta_weight, shape)
    new_parameter = old_parameter - update_delta_weight
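
The reshape-then-subtract step above is easy to check outside the graph. A minimal NumPy sketch of the same arithmetic (illustrative only; the shapes and values are assumptions, not taken from the MindSpore code):

import numpy as np

old_parameter = np.ones((2, 3), dtype=np.float32)          # current weights
update_delta_weight = np.full(6, 0.1, dtype=np.float32)    # flat update received from the broadcast
update_delta_weight = update_delta_weight.reshape((2, 3))  # restore the parameter's shape
new_parameter = old_parameter - update_delta_weight        # same update rule as the graph code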

@@ -42,18 +43,20 @@ def _update_parameters_after_broadcast(delta_weight, update_delta_weight, parameter, old_parameter):
def _send_before_receive(send_part, send, recv):
    """Send data before receiving."""
    send_ok = send(send_part)
    return recv(send_ok)


def _receive_before_send(send_part, send, recv):
    """Receive data before sending."""
    receive_ok = recv(send_part)
    # F.depend ties send_part to the completed receive, so the receive runs before the send.
    send_part = F.depend(send_part, receive_ok)
    return F.depend(receive_ok, send(send_part))
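
One detail worth stating for reviewers: the two helpers exist because the peers on either side of an exchange must use opposite orderings, otherwise both ranks would block on their first communication op. A hedged sketch of how the pair is typically selected (mirroring the `left_send` branch used elsewhere in this file; `local_half` is an illustrative name):

# On the rank whose `left_send` flag is True:
#     recv_part = _send_before_receive(local_half, send, recv)
# On the matching rank (flag is False):
#     recv_part = _receive_before_send(local_half, send, recv)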

def _send_recv_res(left_send, recv_part, local_part, allreduce, parameter_divisibility, allreduce_node_num):
    """Send result and receive result."""
    if parameter_divisibility:
        recv_part = P.Squeeze()(recv_part)
        local_part = F.depend(local_part, recv_part)
@@ -83,7 +86,7 @@ _adasum_opt_forward = C.MultitypeFuncGraph("adasum_opt_forward")
@_adasum_opt_forward.register("Bool", "Function", "Bool", "Int64", "Function", "Function", "Tensor")
def _adasum_opt_forward_process(left_send, allreduce, parameter_divisibility, allreduce_node_num, send, recv, delta_w):
    """Adasum optimizer forward process."""
    if parameter_divisibility:
        delta_w = P.Squeeze()(delta_w)
        ori_len = F.shape(delta_w)[0]
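
Beyond this hunk, the forward process combines the gradient halves exchanged between peers. As context, a hedged NumPy sketch of the adaptive-summation rule described in the AdaSum paper (names and values here are illustrative, not the graph implementation):

import numpy as np

def adasum_combine(g1, g2):
    # Scale each gradient by how much the other already points in its direction, then add.
    dot = np.dot(g1, g2)
    return (1 - dot / (2 * np.dot(g1, g1))) * g1 + (1 - dot / (2 * np.dot(g2, g2))) * g2

g1 = np.array([0.2, -0.1, 0.4], dtype=np.float32)
g2 = np.array([0.1, 0.3, 0.5], dtype=np.float32)
print(adasum_combine(g1, g2))  # orthogonal gradients are summed, identical gradients are kept as-is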

@@ -117,7 +120,7 @@ _adasum_opt_rollback = C.MultitypeFuncGraph("adasum_opt_rollback")
@_adasum_opt_rollback.register("Bool", "Bool", "Tensor", "Function", "Function")
def _adasum_opt_rollback_process(left_send, parameter_divisibility, delta_w, send, recv):
    """Adasum optimizer rollback process."""
    if parameter_divisibility:
        if left_send:
            recv_part = _send_before_receive(delta_w, send, recv)

@@ -139,24 +142,24 @@ def _adasum_opt_rollback_process(left_send, parameter_divisibility, delta_w, send, recv):
class AdaSum(Cell):
    r"""
    The Adaptive Summation, or AdaSum, is a novel algorithm for improving distributed data
    parallel training of Deep Learning models.

    Args:
        rank (int): Rank number.
        device_number (int): Device number.
        group_number (int): Group number.
        parameter_tuple (Tuple(Parameter)): Tuple of parameters.

    Inputs:
        - **delta_weights** (Tuple(Tensor)) - Tuple of gradients.
        - **parameters** (Tuple(Parameter)) - Tuple of current parameters.
        - **old_parameters** (Tuple(Parameter)) - Tuple of last parameters.

    Outputs:
        - **adasum_parameters** (Tuple(Tensor)) - Tuple of parameters after the adasum process.
    """

    def __init__(self, rank, device_number, group_number, parameter_tuple):
        """Initialize AdaSum."""
        super(AdaSum, self).__init__()
        self.rank = rank
        self.device_number = device_number
@@ -166,7 +169,7 @@ class AdaSum(Cell):
        self.hyper_map = C.HyperMap()

    def _generate_communication_op(self):
        """Generate communication operators."""
        # Number of fusion rounds: the group tree has log2(group_number) levels.
        self.calc_times = int(math.log(self.group_number, 2))
        self.send_node = []
        self.send_list_forward = []
@@ -267,7 +270,7 @@ class AdaSum(Cell):
        self.sync_barrier = P.AllReduce("sum", group=broadcast_group_name)

    def _get_delta_weights_info(self, last_delta_weights):
        """Get the delta weights info."""
        half_delta_weights = []
        if last_delta_weights:
            half_delta_weights = last_delta_weights
@@ -294,12 +297,14 @@ class AdaSum(Cell):
        return left_delta_weights, right_delta_weights, delta_weights_divisibility
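
A rough NumPy illustration of a left/right split with a divisibility flag, in the spirit of the values returned above (purely illustrative; the names and the choice of which half takes the odd element are assumptions, not taken from the real method):

import numpy as np

flat = np.arange(7, dtype=np.float32)   # a flattened delta weight
divisible = flat.size % 2 == 0          # False here: an odd length cannot be split evenly
mid = flat.size // 2 + flat.size % 2    # in this sketch the left half takes the extra element
left, right = flat[:mid], flat[mid:]
print(left, right, divisible)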

    def _hash(self, step, target, weights_index):
        """Compute a hash value from the step, target and weights index."""
        target = "tag" + str(step) + str(target) + str(weights_index)
        target_hash = hashlib.sha1(target.encode()).hexdigest()
        hash_res = int(int(target_hash, 16) % MAX_NUM_HASH)
        return hash_res
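
The tag computation above is self-contained and can be reproduced for debugging. A standalone sketch, assuming MAX_NUM_HASH is the module-level bound (its value here is a placeholder):

import hashlib

MAX_NUM_HASH = 2 ** 31  # assumed bound; the real constant is defined elsewhere in the file

def hash_tag(step, target, weights_index):
    # Same recipe as AdaSum._hash: sha1 over "tag{step}{target}{index}", reduced modulo MAX_NUM_HASH.
    key = "tag" + str(step) + str(target) + str(weights_index)
    digest = hashlib.sha1(key.encode()).hexdigest()
    return int(digest, 16) % MAX_NUM_HASH

print(hash_tag(0, 3, 7))  # deterministic tag for a given (step, target, index) triple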

    def construct(self, delta_weights, parameters, old_parameters):
        """Run the adasum optimization process."""
        forward_weights = [delta_weights]
        for i in range(self.calc_times):
            process_weights = self.hyper_map(F.partial(_adasum_opt_forward, self.send_node[i], self.allreduce_list[i]),