mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
Fixed the issue that a universal checkpoint cannot be loaded for ZeRO stage 3 when the world size expands. (#7599)
When the world size expands from 2 to 4, we convert to a universal checkpoint and then load from it. A new rank — for example, rank 3 — will try to load the model file `zero_pp_rank_3_mp_rank_00_model_states.pt`, but this file was not produced during the previous run. For stage 3, we should just load the first file, i.e. `zero_pp_rank_0_mp_rank_00_model_states`. The existing unit test TestZeROUniversalCheckpointDP::test_dp_world_size_2to4 verifies this problem. --------- Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
This commit is contained in:
@ -3011,7 +3011,7 @@ class DeepSpeedEngine(Module):
|
||||
bf16_mode = self.bfloat16_enabled()
|
||||
return self._get_rank_zero_ckpt_name(checkpoints_path, tag, mp_rank, pp_rank, bf16_mode)
|
||||
|
||||
def _get_ckpt_name(self, checkpoints_path, tag, mp_placeholder=None):
|
||||
def _get_ckpt_name(self, checkpoints_path, tag, mp_placeholder=None, pp_placeholder=None):
|
||||
if mp_placeholder is not None:
|
||||
mp_rank_str = mp_placeholder
|
||||
else:
|
||||
@ -3019,7 +3019,12 @@ class DeepSpeedEngine(Module):
|
||||
mp_rank_str = f"{mp_rank:02d}"
|
||||
|
||||
if self.zero_optimization_partition_weights():
|
||||
filename = "zero_pp_rank_{}".format(dist.get_rank(group=self.optimizer.dp_process_group))
|
||||
if pp_placeholder is not None:
|
||||
pp_rank = pp_placeholder
|
||||
else:
|
||||
pp_rank = dist.get_rank(group=self.optimizer.dp_process_group)
|
||||
|
||||
filename = "zero_pp_rank_{}".format(pp_rank)
|
||||
ckpt_name = os.path.join(
|
||||
checkpoints_path,
|
||||
str(tag),
|
||||
@ -3054,15 +3059,15 @@ class DeepSpeedEngine(Module):
|
||||
|
||||
def _get_all_ckpt_names(self, checkpoints_path, tag):
|
||||
# It is required that (checkpoints_path, tag) are consistent among all ranks.
|
||||
ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, tag, mp_placeholder="*")
|
||||
ckpt_file_pattern = self._get_ckpt_name(checkpoints_path,
|
||||
tag,
|
||||
mp_placeholder="*",
|
||||
pp_placeholder="0" if self.load_universal_checkpoint() else None)
|
||||
import glob
|
||||
|
||||
ckpt_files = glob.glob(ckpt_file_pattern)
|
||||
ckpt_files.sort()
|
||||
if self.load_universal_checkpoint():
|
||||
return [ckpt_files[0]]
|
||||
else:
|
||||
return ckpt_files
|
||||
return ckpt_files
|
||||
|
||||
def load_checkpoint(self,
|
||||
load_dir,
|
||||
|
Reference in New Issue
Block a user