From 3668e69ed5161888cd488bccea92bf38b140d392 Mon Sep 17 00:00:00 2001 From: zouwentao <1692762422@qq.com> Date: Mon, 30 Dec 2024 18:17:54 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../mindspore/python/mindspore/parallel/_cell_wrapper.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/mindspore2022/mindspore/python/mindspore/parallel/_cell_wrapper.py b/src/mindspore2022/mindspore/python/mindspore/parallel/_cell_wrapper.py index fa063178..092ab801 100644 --- a/src/mindspore2022/mindspore/python/mindspore/parallel/_cell_wrapper.py +++ b/src/mindspore2022/mindspore/python/mindspore/parallel/_cell_wrapper.py @@ -30,10 +30,12 @@ class AllGatherCell(Cell): def __init__(self, group): super(AllGatherCell, self).__init__(auto_prefix=False) + # 创建AllGather操作对象 self.allgather = AllGather(group) @ms_function() def construct(self, x): + # 执行AllGather操作 x = self.allgather(x) return x @@ -50,10 +52,12 @@ class SaveOptShardCkptCell(Cell): """ def __init__(self, group): super(SaveOptShardCkptCell, self).__init__(auto_prefix=False) + # 创建AllGather操作对象 self.allgather1 = AllGather(group) self.allgather2 = AllGather() def construct(self, x): + # 执行AllGather操作 x = self.allgather1(x) x = self.allgather2(x) @@ -64,11 +68,14 @@ def get_allgather_cell(group, need_merge_twice=False): """Get AllGatherCell object.""" global _allgather_cell if need_merge_twice: + # 如果需要两次合并,则创建SaveOptShardCkptCell对象 _allgather_cell = SaveOptShardCkptCell(group) else: if group: + # 如果有指定的设备组,则创建AllGatherCell对象 _allgather_cell = AllGatherCell(group) else: + # 否则,创建AllGatherCell对象,使用全局通信组 _allgather_cell = AllGatherCell(GlobalComm.WORLD_COMM_GROUP) return _allgather_cell @@ -77,4 +84,5 @@ def destroy_allgather_cell(): """Destroy AllGatherCell object.""" global _allgather_cell if _allgather_cell: + # 销毁AllGatherCell对象 _allgather_cell = None