branch_zwt
zouwentao 2 months ago
parent a672592d99
commit 3668e69ed5

@ -30,10 +30,12 @@ class AllGatherCell(Cell):
def __init__(self, group):
super(AllGatherCell, self).__init__(auto_prefix=False)
# 创建AllGather操作对象
self.allgather = AllGather(group)
@ms_function()
def construct(self, x):
# 执行AllGather操作
x = self.allgather(x)
return x
@ -50,10 +52,12 @@ class SaveOptShardCkptCell(Cell):
"""
def __init__(self, group):
super(SaveOptShardCkptCell, self).__init__(auto_prefix=False)
# 创建AllGather操作对象
self.allgather1 = AllGather(group)
self.allgather2 = AllGather()
def construct(self, x):
# 执行AllGather操作
x = self.allgather1(x)
x = self.allgather2(x)
@ -64,11 +68,14 @@ def get_allgather_cell(group, need_merge_twice=False):
"""Get AllGatherCell object."""
global _allgather_cell
if need_merge_twice:
# 如果需要两次合并则创建SaveOptShardCkptCell对象
_allgather_cell = SaveOptShardCkptCell(group)
else:
if group:
# 如果有指定的设备组则创建AllGatherCell对象
_allgather_cell = AllGatherCell(group)
else:
# 否则创建AllGatherCell对象使用全局通信组
_allgather_cell = AllGatherCell(GlobalComm.WORLD_COMM_GROUP)
return _allgather_cell
@ -77,4 +84,5 @@ def destroy_allgather_cell():
"""Destroy AllGatherCell object."""
global _allgather_cell
if _allgather_cell:
# 销毁AllGatherCell对象
_allgather_cell = None

Loading…
Cancel
Save