Fix the universal checkpoint issue for stage3 when there are multiple subgroups. (#7585)

**Describe the bug**

When the model is large and there are multiple subgroups, running
ds_to_universal.py fails with the error log below:

```
*** 1. Extracting ZeRO fragments
  0%|                                                     | 0/1 [00:03<?, ?it/s]
Traceback (most recent call last):
  File "/work/zhengchenyu/ai-project/qwen3/scripts/ds_to_universal_example.py", line 21, in <module>
    main()
  File "/work/zhengchenyu/ai-project/qwen3/scripts/ds_to_universal_example.py", line 18, in main
    ds_to_universal_main(args)
  File "/opt/conda/lib/python3.11/site-packages/deepspeed/checkpoint/ds_to_universal.py", line 523, in main
    _extract_zero_shard_files_stage3(args, optim_files, param_shapes, dp_degree, temp_dir)
  File "/opt/conda/lib/python3.11/site-packages/deepspeed/checkpoint/ds_to_universal.py", line 375, in _extract_zero_shard_files_stage3
    _do_parallel_work(do_work, list(range(dp_degree)), args.num_extract_workers)
  File "/opt/conda/lib/python3.11/site-packages/deepspeed/checkpoint/ds_to_universal.py", line 359, in _do_parallel_work
    results.append(do_work(work))
                   ^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/deepspeed/checkpoint/ds_to_universal.py", line 167, in extract_zero_shards_stage3
    dump_param_fragment(temp_dir, 0, dp_index, state_key, flat_state[state_key], name, offset,
  File "/opt/conda/lib/python3.11/site-packages/deepspeed/checkpoint/ds_to_universal.py", line 194, in dump_param_fragment
    state_flat_tensor = state_flat_tensor.narrow(0, offset, numel).clone()
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: start (0) + length (155582464) exceeds dimension size (74499072).

```

**To Reproduce**
Steps to reproduce the behavior:
1. Train with a large model, or set sub_group_size to a lower value so that multiple subgroups are created (a config sketch follows), then save a checkpoint.
2. Run ds_to_universal.py
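
For reference, a minimal config sketch (not taken from the repo; the concrete values are illustrative) that forces multiple sub-groups even for a small model, followed by the usual conversion command:

```python
# Illustrative ZeRO-3 config: a deliberately tiny sub_group_size splits the
# partitioned parameters into several sub-groups, which is what triggers the bug.
ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "zero_optimization": {
        "stage": 3,
        "sub_group_size": 100,  # small on purpose; production values are much larger
    },
    "bf16": {"enabled": True},
}

# After training and saving a checkpoint with this config, convert it with, e.g.:
#   python deepspeed/checkpoint/ds_to_universal.py \
#       --input_folder  <checkpoint_dir>/<tag> \
#       --output_folder <checkpoint_dir>/<tag>_universal
```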

**The reason**

I found that the previous stage3 universal checkpoint implementation did
not take subgroups into account. I also found the following problems
during debugging.

* It cannot handle multiple sub-groups, which results in data loss (see the toy sketch after this list).
* When load_checkpoint is True, all processes save to the same ZeRO model
checkpoint file. If multiple processes write at the same time, the file gets
corrupted; corruption was occasionally observed during testing.
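
To make the first bullet concrete, here is a toy sketch (not the real DeepSpeed code; dp partitioning and padding are ignored). Slicing every parameter out of sub-group 0's flat tensor with a single running offset eventually runs past the end of that tensor, which is exactly the `RuntimeError` above; slicing each sub-group's own flat tensor with its own offset does not:

```python
import torch

# Toy setup: two sub-groups, each with its own flat optimizer-state tensor.
flat_groups = [torch.zeros(6), torch.zeros(4)]
param_shapes = [{"a": torch.Size([6])}, {"b": torch.Size([4])}]

# Old behaviour, simplified: all sub-group shapes merged into one dict and
# sliced out of flat_groups[0] with one running offset.
merged = {k: v for d in param_shapes for k, v in d.items()}
offset = 0
for name, shape in merged.items():
    numel = shape.numel()
    if offset + numel > flat_groups[0].numel():
        # Same failure mode as the reported error.
        print(f"{name}: start ({offset}) + length ({numel}) "
              f"exceeds dimension size ({flat_groups[0].numel()})")
        break
    flat_groups[0].narrow(0, offset, numel)
    offset += numel

# Fixed behaviour: index the current sub-group's flat tensor and reset the
# offset per sub-group, as the patch below does with `idx`.
for idx, shapes in enumerate(param_shapes):
    offset = 0
    for name, shape in shapes.items():
        frag = flat_groups[idx].narrow(0, offset, shape.numel()).clone()
        offset += shape.numel()
```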

Related issue: #7584

---------

Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Author: zhengchenyu
Date: 2025-09-28 01:39:43 +08:00 (committed by GitHub)
Parent: 6ea345ae27
Commit: 91d14527b6
5 changed files with 117 additions and 57 deletions

View File: deepspeed/checkpoint/ds_to_universal.py

@ -152,14 +152,14 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D):
def extract_zero_shards_stage3(optim_files, param_shapes, dp_degree, temp_dir, dp_index):
state_dict = torch.load(optim_files[dp_index], map_location='cpu', weights_only=False)
for idx, sub_group_shape in enumerate(param_shapes):
flat_state = dict(
exp_avg=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg"],
exp_avg_sq=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg_sq"],
fp32=state_dict[OPTIMIZER_STATE_DICT]['fp32_flat_groups'][0],
exp_avg=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][idx]["exp_avg"],
exp_avg_sq=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][idx]["exp_avg_sq"],
fp32=state_dict[OPTIMIZER_STATE_DICT]['fp32_flat_groups'][idx],
)
offset = 0
for name, shape in param_shapes.items():
for name, shape in sub_group_shape.items():
unpartitioned_numel = shape.numel()
partitioned_numel, _ = _zero_partitioned_param_info(unpartitioned_numel, dp_degree)
padding_free_numel = min(partitioned_numel, abs(unpartitioned_numel - dp_index * partitioned_numel))
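
As a worked example of the per-rank slicing above (a sketch assuming each rank's partition holds `ceil(unpartitioned_numel / dp_degree)` elements, with the last rank padded):

```python
import math

# Hypothetical parameter with 10 elements, partitioned across dp_degree = 4 ranks.
unpartitioned_numel, dp_degree = 10, 4
partitioned_numel = math.ceil(unpartitioned_numel / dp_degree)  # 3 elements per rank (padded)

for dp_index in range(dp_degree):
    # Real (non-padding) elements this rank holds for the parameter.
    padding_free_numel = min(partitioned_numel,
                             abs(unpartitioned_numel - dp_index * partitioned_numel))
    print(dp_index, padding_free_numel)  # ranks 0-2 hold 3 real elements, rank 3 holds 1
```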
@ -390,10 +390,10 @@ def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir):
print(f'Warning: Unused patterns={unmatched_patterns} while merging tp slices')
def _merge_zero3_slice_files(args, param_shapes, dp_degree, temp_dir):
def _merge_zero3_slice_files(args, param_keys, dp_degree, temp_dir):
zero_output_folder = os.path.join(args.output_folder, "zero")
do_work = partial(merge_zero3_slices, dp_degree, zero_output_folder, temp_dir)
_do_parallel_work(do_work, param_shapes.keys(), args.num_merge_workers)
_do_parallel_work(do_work, param_keys, args.num_merge_workers)
def _zero_partitioned_param_info(unpartitioned_numel, world_size):
@ -514,7 +514,6 @@ def main(args):
else:
model_files = _get_model_state_files(args.input_folder)
param_shapes = _parse_model_states_stage3(model_files)
param_shapes = {k: v for d in param_shapes for k, v in d.items()}
dp_degree = len(model_files)
temp_dir = os.path.join(args.output_folder, 'tmp')
@ -523,7 +522,8 @@ def main(args):
_extract_zero_shard_files_stage3(args, optim_files, param_shapes, dp_degree, temp_dir)
print('*** 2. Merging slices .....')
_merge_zero3_slice_files(args, param_shapes, dp_degree, temp_dir)
param_keys = {key for sub_group_shapes in param_shapes for key in sub_group_shapes.keys()}
_merge_zero3_slice_files(args, param_keys, dp_degree, temp_dir)
print('*** 3. Saving common optimizer states')
_save_optimizer_state_stage3(args, optim_files)

View File: deepspeed/runtime/engine.py

@ -1838,6 +1838,7 @@ class DeepSpeedEngine(Module):
optimizer = Stage3ZeroOptimizer(
self.module,
optimizer,
self.param_names,
timers=timers,
ds_config=self.config,
static_loss_scale=self.loss_scale(),
@ -1886,6 +1887,7 @@ class DeepSpeedEngine(Module):
model_dtype, gradient_accumulation_dtype = self.get_data_types()
optimizer = MiCS_Optimizer(self.module,
basic_optimizer,
self.param_names,
timers=timers,
ds_config=self.config,
static_loss_scale=self.loss_scale(),
@ -3010,9 +3012,6 @@ class DeepSpeedEngine(Module):
mp_rank_str = f"{mp_rank:02d}"
if self.zero_optimization_partition_weights():
if self.load_universal_checkpoint():
filename = "zero_pp_rank_0"
else:
filename = "zero_pp_rank_{}".format(dist.get_rank(group=self.optimizer.dp_process_group))
ckpt_name = os.path.join(
checkpoints_path,
@ -3053,6 +3052,9 @@ class DeepSpeedEngine(Module):
ckpt_files = glob.glob(ckpt_file_pattern)
ckpt_files.sort()
if self.load_universal_checkpoint():
return [ckpt_files[0]]
else:
return ckpt_files
def load_checkpoint(self,
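
For context on the new `param_names` argument threaded into the ZeRO-3 and MiCS optimizers above: in the DeepSpeed engine it is, roughly, a mapping from each parameter object to its fully qualified name, which the stage-3 loader further below uses to locate each parameter's folder under the universal checkpoint's `zero/` directory. A minimal sketch of the idea:

```python
import torch

# A throwaway module, just to show the shape of the mapping.
model = torch.nn.Linear(4, 2)

# Rough equivalent of the engine's self.param_names: parameter object -> name.
param_names = {param: name for name, param in model.named_parameters()}
print(list(param_names.values()))  # ['weight', 'bias']
```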

View File: deepspeed/runtime/zero/mics.py

@ -367,6 +367,7 @@ class MiCS_Optimizer(DeepSpeedZeroOptimizer_Stage3):
def __init__(self,
module,
init_optimizer,
param_names,
timers,
ds_config,
static_loss_scale=1,
@ -398,7 +399,7 @@ class MiCS_Optimizer(DeepSpeedZeroOptimizer_Stage3):
aio_config=None):
log_dist("Init MiCS optimizer", ranks=[0])
super().__init__(module, init_optimizer, timers, ds_config, static_loss_scale, dynamic_loss_scale,
super().__init__(module, init_optimizer, param_names, timers, ds_config, static_loss_scale, dynamic_loss_scale,
dynamic_loss_args, verbose, contiguous_gradients, reduce_bucket_size, prefetch_bucket_size,
max_reuse_distance, max_live_parameters, param_persistence_threshold,
model_persistence_threshold, dp_process_group, reduce_scatter, overlap_comm,

View File: deepspeed/runtime/zero/stage3.py

@ -141,6 +141,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
self,
module,
init_optimizer,
param_names,
timers,
ds_config,
static_loss_scale=1.0,
@ -200,6 +201,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
raise SystemError("Cannot use fp16 without accelerator.")
self.optimizer = init_optimizer
self.param_names = param_names
# Use torch (un)flatten ops
self.flatten = _flatten_dense_tensors
@ -2806,8 +2808,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
raise NotImplementedError("ZeRO-3 does not yet support elastic checkpointing, please disable for now.")
if checkpoint_folder:
self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights,
param_shapes)
self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights)
else:
self._rigid_load_state_dict(state_dict_list[dist.get_rank(group=self.dp_process_group)],
load_optimizer_states=load_optimizer_states)
@ -2828,11 +2829,10 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
self.persistent_parameters[0].partition(self.persistent_parameters)
# self.persistent_parameters[0].all_gather(self.persistent_parameters) # this will be done in checkpoint_event_epilogue() so remove it to prevent double all_gather
def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights,
param_shapes):
self.load_hp_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder, param_shapes)
def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights):
self.load_hp_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder)
def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, param_shapes):
def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir):
""" Load optimizer and model states from the checkpoint directory. """
checkpoint_dir = os.path.join(checkpoint_dir, "zero")
optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
@ -2842,18 +2842,34 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
optim_sd = torch.load(optim_state_path, weights_only=False)
self._load_global_state_stage3(optim_sd)
key_list = ["fp32", "exp_avg", "exp_avg_sq"]
for key in key_list:
key_tensor = torch.empty(0)
for layer in param_shapes[0].keys():
key_layer_state_partition = self.load_hp_checkpoint_state(os.path.join(checkpoint_dir, layer), key)
key_tensor = torch.cat((key_tensor, key_layer_state_partition))
# Generally the step of each optimizer file should be the same, we can obtain from any parameter.
state_step = optim_sd[OPTIMIZER_STATE_DICT]['state'][0]['step']
for key in ["fp32", "exp_avg", "exp_avg_sq"]:
for sub_group_id, fp16_group in enumerate(self.fp16_groups):
fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]
key_tensor = torch.zeros_like(fp32_param)
offset = 0
for param in fp16_group:
if param not in self.param_names:
raise ValueError(f"failed to find optimizer param in named params")
param_name = self.param_names[param]
key_layer_state_partition = self.load_hp_checkpoint_state(os.path.join(checkpoint_dir, param_name),
key)
key_tensor.narrow(0, offset, key_layer_state_partition.numel()).copy_(key_layer_state_partition)
offset += key_layer_state_partition.numel()
if key == "fp32":
self.fp32_partitioned_groups_flat[0].data.copy_(key_tensor)
self.optimizer.param_groups[0]['params'].append(self.fp32_partitioned_groups_flat[0])
self.fp32_partitioned_groups_flat[sub_group_id].data.copy_(key_tensor)
self.optimizer.state[fp32_param]['step'] = state_step
else:
optim_sd[OPTIMIZER_STATE_DICT]['state'][0][key] = key_tensor
self.optimizer.state[fp32_param][key] = key_tensor
for param_group in self.optimizer.param_groups:
# Generally, the hyperparameters of each parameter should be the same, we can obtain from any parameter.
for key, value in optim_sd[OPTIMIZER_STATE_DICT]["param_groups"][0].items():
if key == 'params':
param_group['params'] = []
else:
param_group[key] = value
if self.swap_optimizer:
# Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint
@ -2869,10 +2885,6 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
self._release_sub_group(sub_group_id, timer_names)
self._post_step(timer_names)
self.optimizer.load_state_dict(optim_sd[OPTIMIZER_STATE_DICT])
for param_group in self.optimizer.param_groups:
param_group['params'] = []
for sub_group_id in range(len(self.fp32_partitioned_groups_flat)):
fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]
if sum(fp32_param.size()) > 0:

View File: tests/unit/checkpoint/test_universal_checkpoint.py

@ -3,6 +3,9 @@
# DeepSpeed Team
import os
import math
import deepspeed
from types import SimpleNamespace
from torch.utils._pytree import tree_map
@ -72,13 +75,13 @@ def init_ds_engine(model, ds_config, use_torch_adam):
return model
def train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype, tmpdir):
def train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype, tmpdir, world_size):
if dtype == torch.bfloat16 and not bf16_required_version_check():
return
test_step = 8
model = SimpleModel(hidden_dim)
model = SimpleModel(hidden_dim, nlayers=2)
model = init_ds_engine(model, ds_config, use_torch_adam)
data_loader = random_dataloader(model=model,
total_samples=test_step,
@ -124,6 +127,7 @@ def train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype,
model.optimizer._set_fp32_optimizer_param_groups()
optimizer_state = gather_opt_state(model.optimizer.optimizer.state_dict())
model.optimizer._clear_fp32_optimizer_param_groups()
update_gathered_stage3_optimizer(optimizer_state, model._get_zero_param_shapes(), world_size)
else:
optimizer_state = gather_opt_state(model.optimizer.optimizer.state_dict())
@ -135,7 +139,7 @@ def train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype,
@pytest.fixture
def ds_config(zero_stage, dtype):
def ds_config(zero_stage, dtype, sub_group_size):
ds_config = {
"train_batch_size": 8,
"optimizer": {
@ -149,6 +153,8 @@ def ds_config(zero_stage, dtype):
ds_config["fp16"] = {"enabled": True, "initial_scale_power": 8}
elif dtype == torch.bfloat16:
ds_config["bf16"] = {"enabled": True}
if sub_group_size > 0:
ds_config["zero_optimization"]["sub_group_size"] = sub_group_size
return ds_config
@ -157,7 +163,7 @@ class _baseline(DistributedFixture):
def run(self, tmpdir, ds_config, zero_stage, dtype, load_optim, use_torch_adam):
hidden_dim = 10
train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype, tmpdir)
train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype, tmpdir, self.world_size)
class baseline_ws2(_baseline):
@ -168,13 +174,46 @@ class baseline_ws4(_baseline):
world_size = 4
# Stage3 use shard parameter, need to reorganize the optimizer parameters.
def update_gathered_stage3_optimizer(optimizer_state, param_shapes, world_size):
for sub_group_id, group in enumerate(optimizer_state["param_groups"]):
group["params"] = None
new_state = {}
for sub_group_id, sub_group_param_shape in enumerate(param_shapes):
total_numel = optimizer_state['state'][sub_group_id]['exp_avg'].numel()
assert total_numel % world_size == 0
numel_per_rank = total_numel // world_size
param_offset_in_current_rank = 0
for param_name, param_shape in sub_group_param_shape.items():
param_numel = param_shape.numel()
param_partition_numel = math.ceil(param_numel / world_size)
param_optimizer_tensor = {
"exp_avg": torch.zeros(param_numel),
"exp_avg_sq": torch.zeros(param_numel),
"step": optimizer_state['state'][sub_group_id]['step'],
}
for key in ["exp_avg", "exp_avg_sq"]:
write_offset = 0
for rank in range(world_size):
offset = param_offset_in_current_rank + rank * numel_per_rank
length = min(param_partition_numel, param_numel - rank * param_partition_numel)
tmp = optimizer_state['state'][sub_group_id][key].narrow(0, offset, length)
param_optimizer_tensor[key].narrow(0, write_offset, length).copy_(tmp)
write_offset += length
param_offset_in_current_rank += param_partition_numel
new_state[param_name] = param_optimizer_tensor
optimizer_state["state"] = new_state
@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16, torch.float32])
@pytest.mark.parametrize("zero_stage", [1, 3])
@pytest.mark.parametrize("use_torch_adam", [False, True])
@pytest.mark.parametrize("load_optim", [False, True])
@pytest.mark.parametrize("sub_group_size", [-1, 100])
class TestZeROUniversalCheckpointDP(DistributedTest):
def _run_test(self, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
def _run_test(self, tmpdir, dtype, ds_config, load_optim, use_torch_adam, world_size):
if dtype == torch.bfloat16 and not bf16_required_version_check():
pytest.skip(
" DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly"
@ -184,14 +223,20 @@ class TestZeROUniversalCheckpointDP(DistributedTest):
loaded_model_state, loaded_optimizer_state = torch.load(f"{tmpdir}/baseline_state.pt", weights_only=False)
ds_config["checkpoint"] = {"load_universal": True}
univ_model = SimpleModel(hidden_dim)
univ_model = SimpleModel(hidden_dim, nlayers=2)
univ_model = init_ds_engine(univ_model, ds_config, use_torch_adam)
univ_model.load_checkpoint(tmpdir, tag=f"{CP_TAG}_universal", load_optimizer_states=load_optim)
model_state = univ_model.state_dict()
compare_state_dicts(model_state, loaded_model_state)
if load_optim and ds_config["zero_optimization"]["stage"] != 3:
if load_optim:
if ds_config["zero_optimization"]["stage"] == 3:
univ_model.optimizer._set_fp32_optimizer_param_groups()
optimizer_state = gather_opt_state(univ_model.optimizer.optimizer.state_dict())
univ_model.optimizer._clear_fp32_optimizer_param_groups()
update_gathered_stage3_optimizer(optimizer_state, univ_model._get_zero_param_shapes(), world_size)
else:
optimizer_state = gather_opt_state(univ_model.optimizer.optimizer.state_dict())
# padding sizes may differ when dp sizes are different
param_count = sum(p.numel() for p in univ_model.parameters())
@ -216,12 +261,12 @@ class TestZeROUniversalCheckpointDP(DistributedTest):
@pytest.mark.world_size(2)
def test_dp_world_size_2to2(self, baseline_ws2, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam)
self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam, 2)
@pytest.mark.world_size(2)
def test_dp_world_size_4to2(self, baseline_ws4, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam)
self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam, 2)
@pytest.mark.world_size(4)
def test_dp_world_size_2to4(self, baseline_ws2, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam)
self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam, 4)