Fix error when loading a universal checkpoint in multi-machine mode (#7601)

In a multi-machine environment, loading a ZeRO stage 3 universal checkpoint produces incorrect results, causing the loss to increase abnormally: each process selected its checkpoint partition using the node-local rank instead of its global rank in the data-parallel group, so ranks on different nodes read the same slice.
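A minimal arithmetic sketch of the mismatch (illustrative only, not DeepSpeed code), assuming a hypothetical 2-node x 8-GPU layout and a made-up partition size:

# Hypothetical layout: 2 nodes x 8 GPUs, so local ranks repeat 0..7 on every node.
gpus_per_node = 8
partitioned_numel = 1000  # made-up per-rank partition size

for global_rank in (3, 11):
    local_rank = global_rank % gpus_per_node        # what the old code effectively used
    buggy_offset = local_rank * partitioned_numel   # identical on both nodes
    fixed_offset = global_rank * partitioned_numel  # unique per data-parallel rank
    print(global_rank, buggy_offset, fixed_offset)

# global rank 3  -> buggy 3000, fixed 3000   (single-node runs are unaffected)
# global rank 11 -> buggy 3000, fixed 11000  (rank 11 previously loaded rank 3's partition)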
zhengchenyu
2025-09-29 04:26:11 +08:00
committed by GitHub
parent 66c70312f2
commit 47b3fb5e7f


@@ -2904,7 +2904,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         self.overflow = sd.get('overflow', self.overflow)
 
     def load_hp_checkpoint_state(self, folder, key):
-        local_rank = dist.get_local_rank()
+        rank = dist.get_rank(group=self.dp_process_group)
         # Load tensors from files and reshape them to flat vectors
         loaded_checkpoint_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=False).view(-1)
@@ -2918,8 +2918,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
             padding_size = world_size * partitioned_numel - unpartitioned_numel
             padding_tensor = torch.zeros(padding_size, dtype=loaded_checkpoint_state.dtype)
             loaded_checkpoint_state = torch.cat([loaded_checkpoint_state, padding_tensor])
-        checkpoint_state_partition = loaded_checkpoint_state.narrow(0, local_rank * partitioned_numel,
-                                                                    partitioned_numel)
+        checkpoint_state_partition = loaded_checkpoint_state.narrow(0, rank * partitioned_numel, partitioned_numel)
         return checkpoint_state_partition
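For reference, the slicing pattern used above treats the flat, padded checkpoint tensor as world_size equal contiguous partitions, and each data-parallel rank narrows out its own slice. A small standalone sketch with made-up sizes:

import torch

world_size, partitioned_numel = 4, 3   # made-up sizes for illustration
flat_state = torch.arange(world_size * partitioned_numel, dtype=torch.float32)

rank = 2                               # global data-parallel rank
partition = flat_state.narrow(0, rank * partitioned_numel, partitioned_numel)
print(partition)                       # tensor([6., 7., 8.]) -> rank 2's slice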