|
|
|
@ -30,10 +30,12 @@ class AllGatherCell(Cell):
|
|
|
|
|
def __init__(self, group):
    """Set up the cell with one AllGather operator.

    Args:
        group: Communication group handed straight to ``AllGather``.
    """
    # auto_prefix=False is forwarded unchanged to the base-class initializer.
    super().__init__(auto_prefix=False)
    # Single AllGather operator, bound to the caller-supplied group.
    self.allgather = AllGather(group)
|
|
|
|
|
|
|
|
|
|
@ms_function()
def construct(self, x):
    """Run the AllGather operator built in ``__init__`` on ``x``.

    Args:
        x: Input passed to ``self.allgather``.

    Returns:
        Whatever ``self.allgather(x)`` produces.
    """
    # Hand the operator's output back directly instead of rebinding ``x``.
    return self.allgather(x)
|
|
|
|
@ -50,10 +52,12 @@ class SaveOptShardCkptCell(Cell):
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self, group):
    """Set up the cell with two AllGather operators.

    Args:
        group: Communication group handed to the first ``AllGather``.
            The second operator is built with no arguments.
    """
    # auto_prefix=False is forwarded unchanged to the base-class initializer.
    super().__init__(auto_prefix=False)
    # First operator uses the caller-supplied group.
    self.allgather1 = AllGather(group)
    # Second operator relies on AllGather's own default group.
    self.allgather2 = AllGather()
|
|
|
|
|
|
|
|
|
|
def construct(self, x):
|
|
|
|
|
# Run the two AllGather operators back to back: allgather1 on the input,
# then allgather2 on that result.
# NOTE(review): this excerpt ends right after the second gather, so no return
# statement is visible here — confirm against the full file.
|
|
|
|
|
x = self.allgather1(x)
|
|
|
|
|
x = self.allgather2(x)
|
|
|
|
|
|
|
|
|
def get_allgather_cell(group, need_merge_twice=False):
    """Build an all-gather cell, store it in ``_allgather_cell`` and return it.

    A new cell is created on every call; the module-level ``_allgather_cell``
    is rebound to it each time.

    Args:
        group: Communication group for the gather. On the single-gather path a
            falsy value is replaced by ``GlobalComm.WORLD_COMM_GROUP``.
        need_merge_twice: When true, build a ``SaveOptShardCkptCell`` instead
            of an ``AllGatherCell``.

    Returns:
        The newly created cell.
    """
    global _allgather_cell
    if need_merge_twice:
        # Two-stage gather: ``group`` is passed through as given, with no
        # fallback to the world group on this path.
        _allgather_cell = SaveOptShardCkptCell(group)
    else:
        # Single gather: use the requested group, or the global communication
        # group when none was specified.
        _allgather_cell = AllGatherCell(group or GlobalComm.WORLD_COMM_GROUP)
    return _allgather_cell
|
|
|
|
|
|
|
|
|
def destroy_allgather_cell():
    """Clear the module-level ``_allgather_cell`` reference if one is set."""
    global _allgather_cell
    if not _allgather_cell:
        # Nothing stored (or a falsy value): leave the global untouched.
        return
    # Drop the reference to the stored cell.
    _allgather_cell = None
|
|
|
|
|