Mirror of https://github.com/deepspeedai/DeepSpeed.git, synced 2025-10-20 15:33:51 +08:00
Fixed the universal checkpoint loading error in multi-machine mode. (#7601)
In a multi-machine environment, loading a ZeRO Stage 3 universal checkpoint produces incorrect results, causing the loss to increase abnormally.
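The fix replaces the node-local rank with the global data-parallel rank when computing the partition offset: dist.get_local_rank() repeats 0..N-1 on every node, so on a multi-node job different machines read the same slice of the flat checkpoint tensor. Below is a minimal sketch of the offset arithmetic; the 2-node, 8-GPU topology and the partition size are assumed for illustration and are not taken from the commit.

# Illustration only: assumed topology of 2 nodes x 8 GPUs = 16 data-parallel ranks.
gpus_per_node = 8
world_size = 16
partitioned_numel = 1024  # hypothetical per-rank partition size

for global_rank in range(world_size):
    local_rank = global_rank % gpus_per_node
    wrong_offset = local_rank * partitioned_numel   # pre-fix: repeats on every node
    right_offset = global_rank * partitioned_numel  # post-fix: unique slice per rank
    if wrong_offset != right_offset:
        print(f"dp rank {global_rank}: loaded offset {wrong_offset}, expected {right_offset}")

Every rank on the second node ends up reading the first node's partitions, which is consistent with the abnormal loss described above.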
@@ -2904,7 +2904,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         self.overflow = sd.get('overflow', self.overflow)
 
     def load_hp_checkpoint_state(self, folder, key):
-        local_rank = dist.get_local_rank()
+        rank = dist.get_rank(group=self.dp_process_group)
 
         # Load tensors from files and reshape them to flat vectors
         loaded_checkpoint_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=False).view(-1)
@@ -2918,8 +2918,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         padding_size = world_size * partitioned_numel - unpartitioned_numel
         padding_tensor = torch.zeros(padding_size, dtype=loaded_checkpoint_state.dtype)
         loaded_checkpoint_state = torch.cat([loaded_checkpoint_state, padding_tensor])
-        checkpoint_state_partition = loaded_checkpoint_state.narrow(0, local_rank * partitioned_numel,
-                                                                    partitioned_numel)
+        checkpoint_state_partition = loaded_checkpoint_state.narrow(0, rank * partitioned_numel, partitioned_numel)
 
         return checkpoint_state_partition
 
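For reference, here is a standalone sketch of the padded slicing path touched by the second hunk; the tensor sizes and the rank value are hypothetical, and this is not the DeepSpeed implementation itself.

import torch

# Hypothetical sizes for illustration.
world_size = 4
unpartitioned_numel = 10  # length of the flat checkpoint tensor loaded from disk
partitioned_numel = 3     # per-rank partition size (ceil(10 / 4))
rank = 3                  # global data-parallel rank, as returned by
                          # dist.get_rank(group=self.dp_process_group) in the fix

loaded_checkpoint_state = torch.arange(unpartitioned_numel, dtype=torch.float32)

# Pad the flat tensor so it divides evenly across all data-parallel ranks.
padding_size = world_size * partitioned_numel - unpartitioned_numel
padding_tensor = torch.zeros(padding_size, dtype=loaded_checkpoint_state.dtype)
loaded_checkpoint_state = torch.cat([loaded_checkpoint_state, padding_tensor])

# Each rank takes a contiguous slice keyed by its *global* data-parallel rank.
checkpoint_state_partition = loaded_checkpoint_state.narrow(0, rank * partitioned_numel, partitioned_numel)
print(checkpoint_state_partition)  # tensor([9., 0., 0.]) -- the last slice includes the padding

With these numbers the 10-element flat tensor is padded to 12 so it splits evenly into four partitions of 3, and only the last rank's slice picks up the zero padding.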