From 47b3fb5e7f544b97d61450d13d8e13a49d9418ca Mon Sep 17 00:00:00 2001
From: zhengchenyu
Date: Mon, 29 Sep 2025 04:26:11 +0800
Subject: [PATCH] Fix the error when loading a universal checkpoint in
 multi-machine mode. (#7601)

In a multi-machine environment, loading a stage 3 universal checkpoint
produces incorrect results, causing the loss to increase abnormally.

---
 deepspeed/runtime/zero/stage3.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index b163729e3..76790223a 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -2904,7 +2904,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         self.overflow = sd.get('overflow', self.overflow)
 
     def load_hp_checkpoint_state(self, folder, key):
-        local_rank = dist.get_local_rank()
+        rank = dist.get_rank(group=self.dp_process_group)
 
         # Load tensors from files and reshape them to flat vectors
         loaded_checkpoint_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=False).view(-1)
@@ -2918,8 +2918,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
             padding_size = world_size * partitioned_numel - unpartitioned_numel
             padding_tensor = torch.zeros(padding_size, dtype=loaded_checkpoint_state.dtype)
             loaded_checkpoint_state = torch.cat([loaded_checkpoint_state, padding_tensor])
-        checkpoint_state_partition = loaded_checkpoint_state.narrow(0, local_rank * partitioned_numel,
-                                                                    partitioned_numel)
+        checkpoint_state_partition = loaded_checkpoint_state.narrow(0, rank * partitioned_numel, partitioned_numel)
         return checkpoint_state_partition
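
For context, below is a minimal, self-contained sketch of why the slice index must be the global data-parallel rank rather than the node-local rank. It mimics the even-split-plus-padding partitioning visible in load_hp_checkpoint_state, but partition_for_rank is a hypothetical helper written for this illustration, not DeepSpeed code, and the node/GPU counts are made up.

# Sketch only: shows that local ranks repeat across nodes (0,1,0,1 for 2 nodes x 2 GPUs),
# so slicing the flat checkpoint tensor by local rank makes different nodes load the
# same shard instead of their own.
import torch

def partition_for_rank(flat_state: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Return the shard of a flat checkpoint tensor owned by `rank` (hypothetical helper)."""
    # Even split with zero padding so every rank gets the same number of elements.
    partitioned_numel = (flat_state.numel() + world_size - 1) // world_size
    padding_size = world_size * partitioned_numel - flat_state.numel()
    if padding_size > 0:
        padding_tensor = torch.zeros(padding_size, dtype=flat_state.dtype)
        flat_state = torch.cat([flat_state, padding_tensor])
    return flat_state.narrow(0, rank * partitioned_numel, partitioned_numel)

if __name__ == "__main__":
    flat = torch.arange(10, dtype=torch.float32)  # stand-in for a loaded {key}.pt tensor
    world_size, gpus_per_node = 4, 2               # assumed 2 nodes x 2 GPUs
    for global_rank in range(world_size):
        local_rank = global_rank % gpus_per_node
        by_local = partition_for_rank(flat, local_rank, world_size)
        by_global = partition_for_rank(flat, global_rank, world_size)
        print(f"rank {global_rank} (local {local_rank}): "
              f"local-rank slice {by_local.tolist()}, global-rank slice {by_global.tolist()}")

Running the sketch shows global ranks 2 and 3 receiving copies of the shards belonging to ranks 0 and 1 when the local rank is used, while slicing by the global rank gives every rank its own distinct shard; this is consistent with the abnormal loss described in the commit message.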