Fixed the issue where the universal checkpoint cannot be loaded for ZeRO stage 3 when the world size expands. (#7599)

When the world size expands from 2 to 4, the existing checkpoint is converted to a universal
checkpoint and then loaded.
A new rank, for example rank 3, will try to load the model file
`zero_pp_rank_3_mp_rank_00_model_states.pt`, but that file was never
produced by the previous run.
For stage 3 it is sufficient to load only the first file, i.e.
`zero_pp_rank_0_mp_rank_00_model_states.pt`.
The existing unit test
`TestZeROUniversalCheckpointDP::test_dp_world_size_2to4` reproduces this
problem.
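
For illustration, here is a minimal standalone sketch of the mismatch. The helper name `model_state_file`, the checkpoint directory `ckpt`, and the tag `global_step10` are made up for the example; only the `zero_pp_rank_*_mp_rank_00_model_states.pt` naming and the rank-0 fallback under universal checkpoint loading come from this change:

```python
import os


def model_state_file(ckpt_dir, tag, dp_rank, load_universal=False):
    # Under universal checkpoint loading, always target the rank-0 file,
    # mirroring the pp_placeholder="0" behavior introduced by this fix.
    pp_rank = 0 if load_universal else dp_rank
    return os.path.join(ckpt_dir, str(tag), f"zero_pp_rank_{pp_rank}_mp_rank_00_model_states.pt")


# Files written by the original world-size-2 run: ranks 0 and 1 only.
produced = {model_state_file("ckpt", "global_step10", r) for r in range(2)}

for rank in range(4):  # world size expanded to 4
    old_target = model_state_file("ckpt", "global_step10", rank)
    new_target = model_state_file("ckpt", "global_step10", rank, load_universal=True)
    print(f"rank {rank}: old target exists={old_target in produced}, "
          f"fixed target exists={new_target in produced}")
```

Ranks 2 and 3 only find a model-states file when they fall back to the rank-0 file, which is what the patched code now does whenever `load_universal_checkpoint()` is true.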

---------

Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Author: zhengchenyu
Date: 2025-10-01 23:37:19 +08:00
Committed by: GitHub
Parent: 330f738cd7
Commit: 07e76bd45f


@@ -3011,7 +3011,7 @@ class DeepSpeedEngine(Module):
         bf16_mode = self.bfloat16_enabled()
         return self._get_rank_zero_ckpt_name(checkpoints_path, tag, mp_rank, pp_rank, bf16_mode)
 
-    def _get_ckpt_name(self, checkpoints_path, tag, mp_placeholder=None):
+    def _get_ckpt_name(self, checkpoints_path, tag, mp_placeholder=None, pp_placeholder=None):
         if mp_placeholder is not None:
             mp_rank_str = mp_placeholder
         else:
@@ -3019,7 +3019,12 @@
             mp_rank_str = f"{mp_rank:02d}"
 
         if self.zero_optimization_partition_weights():
-            filename = "zero_pp_rank_{}".format(dist.get_rank(group=self.optimizer.dp_process_group))
+            if pp_placeholder is not None:
+                pp_rank = pp_placeholder
+            else:
+                pp_rank = dist.get_rank(group=self.optimizer.dp_process_group)
+
+            filename = "zero_pp_rank_{}".format(pp_rank)
             ckpt_name = os.path.join(
                 checkpoints_path,
                 str(tag),
@@ -3054,15 +3059,15 @@
     def _get_all_ckpt_names(self, checkpoints_path, tag):
         # It is required that (checkpoints_path, tag) are consistent among all ranks.
-        ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, tag, mp_placeholder="*")
+        ckpt_file_pattern = self._get_ckpt_name(checkpoints_path,
+                                                tag,
+                                                mp_placeholder="*",
+                                                pp_placeholder="0" if self.load_universal_checkpoint() else None)
         import glob
 
         ckpt_files = glob.glob(ckpt_file_pattern)
         ckpt_files.sort()
 
-        if self.load_universal_checkpoint():
-            return [ckpt_files[0]]
-        else:
-            return ckpt_files
+        return ckpt_files
 
     def load_checkpoint(self,
                         load_dir,