@@ -22,38 +22,41 @@ from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.operations._inner_ops import Send, Receive

__all__ = ["AdaSum"]

MAX_NUM_HASH = 2 ** 31
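
# MAX_NUM_HASH is the modulus applied in AdaSum._hash below; keeping the hashed tag
# under 2 ** 31 presumably keeps the Send/Receive sr_tag values within signed 32-bit range.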


_update_parameters = C.MultitypeFuncGraph("update_parameters")


@_update_parameters.register("Tensor", "Tensor", "Tensor", "Tensor")
def _update_parameters_after_broadcast(delta_weight, update_delta_weight, parameter, old_parameter):
    """Update the parameter by applying the delta weight after broadcast."""
    shape = F.shape(delta_weight)
    update_delta_weight = P.Reshape()(update_delta_weight, shape)
    new_parameter = old_parameter - update_delta_weight
    return P.Assign()(parameter, new_parameter)


def _send_before_receive(send_part, send, recv):
    """Helper that sends the local data before receiving from the peer."""
    send_ok = send(send_part)
    return recv(send_ok)


def _receive_before_send(send_part, send, recv):
    """Helper that receives from the peer before sending the local data."""
    receive_ok = recv(send_part)
    send_part = F.depend(send_part, receive_ok)
    return F.depend(receive_ok, send(send_part))
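
# The two helpers above fix the ordering of a paired exchange: one side sends first and
# then receives, the other receives first and then sends, so the two ranks do not both
# block on a receive at the same time. Passing send_ok into recv and the F.depend calls
# turn that ordering into explicit data dependencies, which keeps the graph compiler
# from reordering the communication ops.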


def _send_recv_res(left_send, recv_part, local_part, allreduce, parameter_divisibility, allreduce_node_num):
    """send result and receive result."""
    if parameter_divisibility:
        recv_part = P.Squeeze()(recv_part)
        local_part = F.depend(local_part, recv_part)
@@ -76,14 +79,14 @@ def _send_recv_res(left_send, recv_part, local_part, allreduce, parameter_divisi
        res = allreduce(local_part)
        res /= allreduce_node_num
    return res
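
# For reference, the AdaSum rule from the paper combines two gradients g_a and g_b as
#   adasum(g_a, g_b) = (1 - dot(g_a, g_b) / (2 * ||g_a||^2)) * g_a
#                      + (1 - dot(g_a, g_b) / (2 * ||g_b||^2)) * g_b
# i.e. each gradient is scaled down by its projection onto the other before summation.
# The divisible branch elided above presumably computes this combination on the exchanged
# halves; the fallback shown here simply averages the whole tensor over the group.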


_adasum_opt_forward = C.MultitypeFuncGraph("adasum_opt_forward")


@_adasum_opt_forward.register("Bool", "Function", "Bool", "Int64", "Function", "Function", "Tensor")
def _adasum_opt_forward_process(left_send, allreduce, parameter_divisibility, allreduce_node_num, send, recv, delta_w):
    """adasum optimizer forward process."""
    if parameter_divisibility:
        delta_w = P.Squeeze()(delta_w)
        ori_len = F.shape(delta_w)[0]
@@ -93,7 +96,7 @@ def _adasum_opt_forward_process(left_send, allreduce, parameter_divisibility, al
    else:
        left_part = delta_w
        right_part = delta_w

    if left_send:
        if parameter_divisibility:
            recv_part = _send_before_receive(left_part, send, recv)
@@ -108,26 +111,26 @@ def _adasum_opt_forward_process(left_send, allreduce, parameter_divisibility, al
            recv_part = left_part
        update_delta_w = _send_recv_res(left_send, recv_part, left_part, allreduce, parameter_divisibility,
                                        allreduce_node_num)

    return update_delta_w
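
# Forward step sketch: when a delta tensor can be split (parameter_divisibility is True)
# it is divided into a left and a right half (the split itself falls in the elided hunks).
# One half is exchanged with the partner rank (send-before-receive on the "left" sender,
# receive-before-send on the other side) and the received half is combined with the local
# half in _send_recv_res. Tensors that cannot be split are passed through whole and
# reduced across the group instead (allreduce_node_num is non-zero only in that case).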


_adasum_opt_rollback = C.MultitypeFuncGraph("adasum_opt_rollback")


@_adasum_opt_rollback.register("Bool", "Bool", "Tensor", "Function", "Function")
def _adasum_opt_rollback_process(left_send, parameter_divisibility, delta_w, send, recv):
    """adasum optimizer rollback process."""
    if parameter_divisibility:
        if left_send:
            recv_part = _send_before_receive(delta_w, send, recv)
        else:
            recv_part = _receive_before_send(delta_w, send, recv)

        recv_part = P.Squeeze()(recv_part)
        recv_part = P.Reshape()(recv_part, (-1,))
        delta_w = P.Reshape()(delta_w, (-1,))

        if left_send:
            res = P.Concat()((recv_part, delta_w))
        else:
@@ -135,28 +138,28 @@ def _adasum_opt_rollback_process(left_send, parameter_divisibility, delta_w, sen
    else:
        res = delta_w
    return res
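
# Rollback step sketch: the combined halves are exchanged back between the two partner
# ranks and concatenated into a full-size delta again, reversing the forward split; which
# half comes first in the concatenation depends on whether this rank held the left or the
# right half (left_send).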


class AdaSum(Cell):
    r"""
    The Adaptive Summation, or AdaSum, is a novel algorithm for improving distributed data
    parallel training of Deep Learning models.

    Args:
        rank (int): Rank number.
        device_number (int): Device number.
        group_number (int): Group number.
        parameter_tuple (Tuple(Parameter)): Tuple of parameters.

    Inputs:
        - **delta_weights** (Tuple(Tensor)) - Tuple of gradients.
        - **parameters** (Tuple(Parameter)) - Tuple of current parameters.
        - **old_parameters** (Tuple(Parameter)) - Tuple of last parameters.

    Outputs:
        - **adasum_parameters** (Tuple(Tensor)) - Tuple of parameters after adasum process.
    """

    def __init__(self, rank, device_number, group_number, parameter_tuple):
        """Initialize the AdaSum cell."""
        super(AdaSum, self).__init__()
        self.rank = rank
        self.device_number = device_number
@@ -164,9 +167,9 @@ class AdaSum(Cell):
        self.parameter_tuple = parameter_tuple
        self._generate_communication_op()
        self.hyper_map = C.HyperMap()

    def _generate_communication_op(self):
        """generate communication op."""
        self.calc_times = int(math.log(self.group_number, 2))
        self.send_node = []
        self.send_list_forward = []
@@ -179,7 +182,7 @@ class AdaSum(Cell):
        self.allreduce_node_num_list = []
        last_delta_weights = []
        group_start_rank = (self.rank // self.device_number) * self.device_number

        for step in range(self.calc_times):
            current_group = self.device_number * (2 ** step)
            sr_target = self.rank
@@ -189,7 +192,7 @@ class AdaSum(Cell):
            else:
                dest_target = sr_target - current_group
                self.send_node.append(False)

            neighbor_ids = []
            group_name_last = 0
            for index in range(2 ** (step + 1)):
@@ -201,7 +204,7 @@ class AdaSum(Cell):
                group_name_last += neighbor_id
            group_name = "adasum_" + str(step) + "_" + str(group_name_last)
            create_group(group_name, neighbor_ids)

            send_left = []
            send_right = []
            recv_left = []
@@ -234,7 +237,7 @@ class AdaSum(Cell):
                send_right.append(send)
                recv_right.append(recv)
                weights_index += 1

            if self.send_node and self.send_node[-1]:
                self.send_list_forward.append(send_left)
                self.send_list_rollback.append(send_right)
@@ -247,27 +250,27 @@ class AdaSum(Cell):
                self.recv_list_forward.append(recv_left)
                self.recv_list_rollback.append(recv_right)
                last_delta_weights = left_delta_weights

            server_all_reduce = P.AllReduce("sum", group_name)
            server_all_reduce.add_prim_attr("fusion", fusion_id + 2)
            self.allreduce_list.append(server_all_reduce)

            for param_divisibility in delta_weights_divisibility:
                if param_divisibility:
                    allreduce_node_num += (0,)
                else:
                    allreduce_node_num += (2 ** (step + 1),)
            self.allreduce_node_num_list.append(allreduce_node_num)

        broadcast_group = [x for x in range(group_start_rank, group_start_rank + self.device_number)]
        broadcast_group_name = "broadcast_group_" + str(group_start_rank)
        create_group(broadcast_group_name, broadcast_group)
        for b_rank in range(len(broadcast_group)):
            self.broadcast_list.append(P.Broadcast(b_rank, group=broadcast_group_name))
        self.sync_barrier = P.AllReduce("sum", group=broadcast_group_name)
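
    # Group construction sketch: there are log2(group_number) steps; at step k each rank
    # pairs with the rank device_number * 2**k away, and the 2**(k + 1) peers involved at
    # that step are collected into a group named "adasum_<step>_<sum of member ranks>".
    # A separate broadcast group spanning the device_number ranks of the local server is
    # created at the end and also backs the sync_barrier AllReduce.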

    def _get_delta_weights_info(self, last_delta_weights):
        """get delta weights info."""
        half_delta_weights = []
        if last_delta_weights:
            half_delta_weights = last_delta_weights
@@ -292,14 +295,16 @@ class AdaSum(Cell):
            right_delta_weights.append((right_shape, dtype))
            delta_weights_divisibility += (divisibility_flag,)
        return left_delta_weights, right_delta_weights, delta_weights_divisibility

    def _hash(self, step, target, weights_index):
        """Compute a hash value from (step, target, weights_index)."""
        target = "tag" + str(step) + str(target) + str(weights_index)
        target_hash = hashlib.sha1(target.encode()).hexdigest()
        hash_res = int(int(target_hash, 16) % MAX_NUM_HASH)
        return hash_res
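
    # A Send and its matching Receive presumably have to agree on the tag, so the tag is
    # derived deterministically from values both ranks can compute (the step, the target
    # rank and the weight index). SHA-1 is used only as a stable string-to-integer mapping
    # here, not for any security purpose.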

    def construct(self, delta_weights, parameters, old_parameters):
        """Apply the AdaSum process to the delta weights and update the parameters."""
        forward_weights = [delta_weights]
        for i in range(self.calc_times):
            process_weights = self.hyper_map(F.partial(_adasum_opt_forward, self.send_node[i], self.allreduce_list[i]),
@@ -314,4 +319,4 @@ class AdaSum(Cell):
            forward_weights[j] = process_weights
        adasum_parameters = self.hyper_map(F.partial(_update_parameters), delta_weights, forward_weights[0],
                                           parameters, old_parameters)
        return adasum_parameters
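
# Overview of construct: the forward loop combines each delta weight with a partner across
# progressively larger groups over calc_times = log2(group_number) steps; a second loop,
# largely elided in this diff, appears to roll the combined halves back out through
# _adasum_opt_rollback before _update_parameters assigns the final result to the parameters.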