Mirror of https://github.com/deepspeedai/DeepSpeed.git, synced 2025-10-20 15:33:51 +08:00
Fix the universal checkpoint issue for stage3 when there are multiple subgroups. (#7585)
**Describe the bug**

When the model is large and there are multiple subgroups, running ds_to_universal.py fails with the error log below:

```
*** 1. Extracting ZeRO fragments
  0%|          | 0/1 [00:03<?, ?it/s]
Traceback (most recent call last):
  File "/work/zhengchenyu/ai-project/qwen3/scripts/ds_to_universal_example.py", line 21, in <module>
    main()
  File "/work/zhengchenyu/ai-project/qwen3/scripts/ds_to_universal_example.py", line 18, in main
    ds_to_universal_main(args)
  File "/opt/conda/lib/python3.11/site-packages/deepspeed/checkpoint/ds_to_universal.py", line 523, in main
    _extract_zero_shard_files_stage3(args, optim_files, param_shapes, dp_degree, temp_dir)
  File "/opt/conda/lib/python3.11/site-packages/deepspeed/checkpoint/ds_to_universal.py", line 375, in _extract_zero_shard_files_stage3
    _do_parallel_work(do_work, list(range(dp_degree)), args.num_extract_workers)
  File "/opt/conda/lib/python3.11/site-packages/deepspeed/checkpoint/ds_to_universal.py", line 359, in _do_parallel_work
    results.append(do_work(work))
                   ^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/deepspeed/checkpoint/ds_to_universal.py", line 167, in extract_zero_shards_stage3
    dump_param_fragment(temp_dir, 0, dp_index, state_key, flat_state[state_key], name, offset,
  File "/opt/conda/lib/python3.11/site-packages/deepspeed/checkpoint/ds_to_universal.py", line 194, in dump_param_fragment
    state_flat_tensor = state_flat_tensor.narrow(0, offset, numel).clone()
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: start (0) + length (155582464) exceeds dimension size (74499072).
```

**To Reproduce**

Steps to reproduce the behavior:
1. Train with a large model, or set sub_group_size to a lower value, then train and save a checkpoint.
2. Run ds_to_universal.py.

**The reason**

The previous stage3 universal checkpoint implementation did not take subgroups into account. The following problems also showed up during debugging:

* Multiple sub-groups are not handled, which results in data loss.
* When load_checkpoint is True, all processes save to the same zero model checkpoint file. If multiple processes write it at the same time, the file is corrupted; occasional file corruption was observed during testing.

Related issue: #7584

---------

Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
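For orientation, a minimal, scaled-down sketch of why indexing sub-group `[0]` loses data: a ZeRO-3 optimizer shard keeps one flat fp32 tensor and one Adam state entry per sub-group. The dictionary below is an invented stand-in for one rank's shard, not the real checkpoint layout in full.

```python
import torch

# Invented, scaled-down stand-in for one rank's ZeRO-3 optimizer shard: a real
# checkpoint keeps one flat fp32 tensor and one Adam state entry *per sub-group*.
optim_shard = {
    "fp32_flat_groups": [torch.zeros(7), torch.zeros(8)],  # sub-group 0, sub-group 1
    "optimizer_state_dict": {
        "state": {
            0: {"exp_avg": torch.zeros(7), "exp_avg_sq": torch.zeros(7)},
            1: {"exp_avg": torch.zeros(8), "exp_avg_sq": torch.zeros(8)},
        }
    },
}

# The old converter always indexed [0], so every parameter fragment was sliced out
# of sub-group 0's flat tensor. A parameter that actually lives in a later sub-group
# then asks narrow() for more elements than sub-group 0 holds -- in the log above,
# a 155,582,464-element fragment against a 74,499,072-element tensor.
too_long = optim_shard["fp32_flat_groups"][1].numel()  # 8 > 7
try:
    optim_shard["fp32_flat_groups"][0].narrow(0, 0, too_long)
except RuntimeError as e:
    print(e)  # start (0) + length (8) exceeds dimension size (7)
```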
**deepspeed/checkpoint/ds_to_universal.py**

```diff
@@ -152,14 +152,14 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D):

 def extract_zero_shards_stage3(optim_files, param_shapes, dp_degree, temp_dir, dp_index):
     state_dict = torch.load(optim_files[dp_index], map_location='cpu', weights_only=False)

-    flat_state = dict(
-        exp_avg=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg"],
-        exp_avg_sq=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg_sq"],
-        fp32=state_dict[OPTIMIZER_STATE_DICT]['fp32_flat_groups'][0],
-    )
-
-    offset = 0
-    for name, shape in param_shapes.items():
-        unpartitioned_numel = shape.numel()
-        partitioned_numel, _ = _zero_partitioned_param_info(unpartitioned_numel, dp_degree)
-        padding_free_numel = min(partitioned_numel, abs(unpartitioned_numel - dp_index * partitioned_numel))
+    for idx, sub_group_shape in enumerate(param_shapes):
+        flat_state = dict(
+            exp_avg=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][idx]["exp_avg"],
+            exp_avg_sq=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][idx]["exp_avg_sq"],
+            fp32=state_dict[OPTIMIZER_STATE_DICT]['fp32_flat_groups'][idx],
+        )
+
+        offset = 0
+        for name, shape in sub_group_shape.items():
+            unpartitioned_numel = shape.numel()
+            partitioned_numel, _ = _zero_partitioned_param_info(unpartitioned_numel, dp_degree)
+            padding_free_numel = min(partitioned_numel, abs(unpartitioned_numel - dp_index * partitioned_numel))
```
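A minimal sketch of the slice arithmetic in the loop above, assuming `_zero_partitioned_param_info` returns a ceil-based per-rank partition size (the helper's body is not shown in this hunk); `padding_free_slice` and the numbers are made up for illustration.

```python
import math

# Hypothetical helper mirroring the arithmetic in extract_zero_shards_stage3,
# assuming the per-rank partition size is math.ceil(numel / dp_degree).
def padding_free_slice(unpartitioned_numel, dp_degree, dp_index):
    partitioned_numel = math.ceil(unpartitioned_numel / dp_degree)
    # The last rank may hold fewer "real" elements; the tail of its slot is padding.
    padding_free_numel = min(partitioned_numel,
                             abs(unpartitioned_numel - dp_index * partitioned_numel))
    return partitioned_numel, padding_free_numel

# A parameter with 10 elements split across dp_degree=4 ranks:
for rank in range(4):
    print(rank, padding_free_slice(10, 4, rank))
# Ranks 0-2 print (3, 3); rank 3 prints (3, 1): it owns 3 slots but only 1 holds
# real data, so only that padding-free prefix is dumped to the universal shard.
```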
```diff
@@ -390,10 +390,10 @@ def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir):
         print(f'Warning: Unused patterns={unmatched_patterns} while merging tp slices')


-def _merge_zero3_slice_files(args, param_shapes, dp_degree, temp_dir):
+def _merge_zero3_slice_files(args, param_keys, dp_degree, temp_dir):
     zero_output_folder = os.path.join(args.output_folder, "zero")
     do_work = partial(merge_zero3_slices, dp_degree, zero_output_folder, temp_dir)
-    _do_parallel_work(do_work, param_shapes.keys(), args.num_merge_workers)
+    _do_parallel_work(do_work, param_keys, args.num_merge_workers)


 def _zero_partitioned_param_info(unpartitioned_numel, world_size):
@@ -514,7 +514,6 @@ def main(args):
     else:
         model_files = _get_model_state_files(args.input_folder)
         param_shapes = _parse_model_states_stage3(model_files)
-        param_shapes = {k: v for d in param_shapes for k, v in d.items()}
         dp_degree = len(model_files)

     temp_dir = os.path.join(args.output_folder, 'tmp')
@@ -523,7 +522,8 @@ def main(args):
         _extract_zero_shard_files_stage3(args, optim_files, param_shapes, dp_degree, temp_dir)

         print('*** 2. Merging slices .....')
-        _merge_zero3_slice_files(args, param_shapes, dp_degree, temp_dir)
+        param_keys = {key for sub_group_shapes in param_shapes for key in sub_group_shapes.keys()}
+        _merge_zero3_slice_files(args, param_keys, dp_degree, temp_dir)

         print('*** 3. Saving common optimizer states')
         _save_optimizer_state_stage3(args, optim_files)
```
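With this change, `param_shapes` stays a list with one `name -> shape` dict per sub-group instead of being flattened, and the merge step only needs the union of parameter names. A small sketch of that transformation; the parameter names and shapes are invented for illustration.

```python
import torch

# param_shapes as parsed from the stage3 model states: one {name: shape} dict per sub-group.
param_shapes = [
    {"linears.0.weight": torch.Size([10, 10]), "linears.0.bias": torch.Size([10])},
    {"linears.1.weight": torch.Size([10, 10]), "linears.1.bias": torch.Size([10])},
]

# The merge step is indexed by parameter name only, regardless of which sub-group owns it.
param_keys = {key for sub_group_shapes in param_shapes for key in sub_group_shapes.keys()}
print(sorted(param_keys))
# ['linears.0.bias', 'linears.0.weight', 'linears.1.bias', 'linears.1.weight']
```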
**deepspeed/runtime/engine.py**

```diff
@@ -1838,6 +1838,7 @@ class DeepSpeedEngine(Module):
         optimizer = Stage3ZeroOptimizer(
             self.module,
             optimizer,
+            self.param_names,
             timers=timers,
             ds_config=self.config,
             static_loss_scale=self.loss_scale(),
@@ -1886,6 +1887,7 @@ class DeepSpeedEngine(Module):
         model_dtype, gradient_accumulation_dtype = self.get_data_types()
         optimizer = MiCS_Optimizer(self.module,
                                    basic_optimizer,
+                                   self.param_names,
                                    timers=timers,
                                    ds_config=self.config,
                                    static_loss_scale=self.loss_scale(),
@@ -3010,9 +3012,6 @@ class DeepSpeedEngine(Module):
         mp_rank_str = f"{mp_rank:02d}"

         if self.zero_optimization_partition_weights():
-            if self.load_universal_checkpoint():
-                filename = "zero_pp_rank_0"
-            else:
-                filename = "zero_pp_rank_{}".format(dist.get_rank(group=self.optimizer.dp_process_group))
+            filename = "zero_pp_rank_{}".format(dist.get_rank(group=self.optimizer.dp_process_group))
             ckpt_name = os.path.join(
                 checkpoints_path,
@@ -3053,6 +3052,9 @@ class DeepSpeedEngine(Module):

         ckpt_files = glob.glob(ckpt_file_pattern)
         ckpt_files.sort()
-        return ckpt_files
+        if self.load_universal_checkpoint():
+            return [ckpt_files[0]]
+        else:
+            return ckpt_files

     def load_checkpoint(self,
```

**deepspeed/runtime/zero/mics.py**

```diff
@@ -367,6 +367,7 @@ class MiCS_Optimizer(DeepSpeedZeroOptimizer_Stage3):
     def __init__(self,
                  module,
                  init_optimizer,
+                 param_names,
                  timers,
                  ds_config,
                  static_loss_scale=1,
@@ -398,7 +399,7 @@ class MiCS_Optimizer(DeepSpeedZeroOptimizer_Stage3):
                  aio_config=None):

         log_dist("Init MiCS optimizer", ranks=[0])
-        super().__init__(module, init_optimizer, timers, ds_config, static_loss_scale, dynamic_loss_scale,
+        super().__init__(module, init_optimizer, param_names, timers, ds_config, static_loss_scale, dynamic_loss_scale,
                          dynamic_loss_args, verbose, contiguous_gradients, reduce_bucket_size, prefetch_bucket_size,
                          max_reuse_distance, max_live_parameters, param_persistence_threshold,
                          model_persistence_threshold, dp_process_group, reduce_scatter, overlap_comm,
```
**deepspeed/runtime/zero/stage3.py**

```diff
@@ -141,6 +141,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         self,
         module,
         init_optimizer,
+        param_names,
         timers,
         ds_config,
         static_loss_scale=1.0,
@@ -200,6 +201,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
             raise SystemError("Cannot use fp16 without accelerator.")

         self.optimizer = init_optimizer
+        self.param_names = param_names

         # Use torch (un)flatten ops
         self.flatten = _flatten_dense_tensors
@@ -2806,8 +2808,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
             raise NotImplementedError("ZeRO-3 does not yet support elastic checkpointing, please disable for now.")

         if checkpoint_folder:
-            self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights,
-                                            param_shapes)
+            self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights)
         else:
             self._rigid_load_state_dict(state_dict_list[dist.get_rank(group=self.dp_process_group)],
                                         load_optimizer_states=load_optimizer_states)
@@ -2828,11 +2829,10 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
             self.persistent_parameters[0].partition(self.persistent_parameters)
             # self.persistent_parameters[0].all_gather(self.persistent_parameters) # this will be done in checkpoint_event_epilogue() so remove it to prevent double all_gather

-    def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights,
-                                   param_shapes):
-        self.load_hp_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder, param_shapes)
+    def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights):
+        self.load_hp_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder)

-    def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, param_shapes):
+    def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir):
         """ Load optimizer and model states from the checkpoint directory. """
         checkpoint_dir = os.path.join(checkpoint_dir, "zero")
         optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
```
```diff
@@ -2842,18 +2842,34 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         optim_sd = torch.load(optim_state_path, weights_only=False)
         self._load_global_state_stage3(optim_sd)

-        key_list = ["fp32", "exp_avg", "exp_avg_sq"]
-
-        for key in key_list:
-            key_tensor = torch.empty(0)
-            for layer in param_shapes[0].keys():
-                key_layer_state_partition = self.load_hp_checkpoint_state(os.path.join(checkpoint_dir, layer), key)
-                key_tensor = torch.cat((key_tensor, key_layer_state_partition))
-            if key == "fp32":
-                self.fp32_partitioned_groups_flat[0].data.copy_(key_tensor)
-                self.optimizer.param_groups[0]['params'].append(self.fp32_partitioned_groups_flat[0])
-            else:
-                optim_sd[OPTIMIZER_STATE_DICT]['state'][0][key] = key_tensor
+        # Generally the step of each optimizer file should be the same, we can obtain from any parameter.
+        state_step = optim_sd[OPTIMIZER_STATE_DICT]['state'][0]['step']
+        for key in ["fp32", "exp_avg", "exp_avg_sq"]:
+            for sub_group_id, fp16_group in enumerate(self.fp16_groups):
+                fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]
+                key_tensor = torch.zeros_like(fp32_param)
+                offset = 0
+                for param in fp16_group:
+                    if param not in self.param_names:
+                        raise ValueError(f"failed to find optimizer param in named params")
+                    param_name = self.param_names[param]
+                    key_layer_state_partition = self.load_hp_checkpoint_state(os.path.join(checkpoint_dir, param_name),
+                                                                              key)
+                    key_tensor.narrow(0, offset, key_layer_state_partition.numel()).copy_(key_layer_state_partition)
+                    offset += key_layer_state_partition.numel()
+                if key == "fp32":
+                    self.fp32_partitioned_groups_flat[sub_group_id].data.copy_(key_tensor)
+                    self.optimizer.state[fp32_param]['step'] = state_step
+                else:
+                    self.optimizer.state[fp32_param][key] = key_tensor
+
+        for param_group in self.optimizer.param_groups:
+            # Generally, the hyperparameters of each parameter should be the same, we can obtain from any parameter.
+            for key, value in optim_sd[OPTIMIZER_STATE_DICT]["param_groups"][0].items():
+                if key == 'params':
+                    param_group['params'] = []
+                else:
+                    param_group[key] = value

         if self.swap_optimizer:
             # Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint
```
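The new loading path rebuilds each sub-group's flat fp32 partition by copying the per-parameter shards it reads from disk into the right offsets. A standalone sketch of that narrow/copy_ pattern, with invented tensors standing in for what `load_hp_checkpoint_state` would return for this rank.

```python
import torch

# Invented per-parameter shards standing in for what load_hp_checkpoint_state()
# would return; the real code reads them from <checkpoint_dir>/<param_name>.
loaded_shards = {"layer1.weight": torch.arange(6, dtype=torch.float32),
                 "layer1.bias": torch.arange(2, dtype=torch.float32)}

# The flat fp32 partition for one sub-group on this rank (8 elements here).
flat_partition = torch.zeros(8)

offset = 0
for name in ["layer1.weight", "layer1.bias"]:  # iteration order must match the fp16 group
    shard = loaded_shards[name]
    # Copy the shard into its slot of the flat partition, as key_tensor.narrow(...).copy_(...) does above.
    flat_partition.narrow(0, offset, shard.numel()).copy_(shard)
    offset += shard.numel()

print(flat_partition)  # tensor([0., 1., 2., 3., 4., 5., 0., 1.])
```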
```diff
@@ -2869,10 +2885,6 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
                 self._release_sub_group(sub_group_id, timer_names)
             self._post_step(timer_names)

-        self.optimizer.load_state_dict(optim_sd[OPTIMIZER_STATE_DICT])
-        for param_group in self.optimizer.param_groups:
-            param_group['params'] = []
-
         for sub_group_id in range(len(self.fp32_partitioned_groups_flat)):
             fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]
             if sum(fp32_param.size()) > 0:
```
**tests/unit/checkpoint/test_universal_checkpoint.py**

```diff
@@ -3,6 +3,9 @@

 # DeepSpeed Team

+import os
+import math
+
 import deepspeed
 from types import SimpleNamespace
 from torch.utils._pytree import tree_map
@@ -72,13 +75,13 @@ def init_ds_engine(model, ds_config, use_torch_adam):
     return model


-def train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype, tmpdir):
+def train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype, tmpdir, world_size):
    if dtype == torch.bfloat16 and not bf16_required_version_check():
        return

    test_step = 8

-    model = SimpleModel(hidden_dim)
+    model = SimpleModel(hidden_dim, nlayers=2)
    model = init_ds_engine(model, ds_config, use_torch_adam)
    data_loader = random_dataloader(model=model,
                                    total_samples=test_step,
@@ -124,6 +127,7 @@ def train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype,
        model.optimizer._set_fp32_optimizer_param_groups()
        optimizer_state = gather_opt_state(model.optimizer.optimizer.state_dict())
        model.optimizer._clear_fp32_optimizer_param_groups()
+        update_gathered_stage3_optimizer(optimizer_state, model._get_zero_param_shapes(), world_size)
    else:
        optimizer_state = gather_opt_state(model.optimizer.optimizer.state_dict())
@@ -135,7 +139,7 @@ def train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype,


 @pytest.fixture
-def ds_config(zero_stage, dtype):
+def ds_config(zero_stage, dtype, sub_group_size):
     ds_config = {
         "train_batch_size": 8,
         "optimizer": {
```
```diff
@@ -149,6 +153,8 @@ def ds_config(zero_stage, dtype):
         ds_config["fp16"] = {"enabled": True, "initial_scale_power": 8}
     elif dtype == torch.bfloat16:
         ds_config["bf16"] = {"enabled": True}
+    if sub_group_size > 0:
+        ds_config["zero_optimization"]["sub_group_size"] = sub_group_size
     return ds_config
```
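For context, a sketch of the kind of config the fixture ends up producing when a positive sub_group_size is requested; the optimizer type and learning rate below are illustrative placeholders, not the fixture's literal values.

```python
# Illustrative ZeRO-3 config with an explicit sub_group_size; a small value
# forces the optimizer state into multiple sub-groups, which is the case this fix targets.
ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 0.00015}},  # placeholder values
    "zero_optimization": {
        "stage": 3,
        "sub_group_size": 100,  # roughly this many elements per flattened sub-group
    },
    "fp16": {"enabled": True, "initial_scale_power": 8},
}
```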
```diff
@@ -157,7 +163,7 @@ class _baseline(DistributedFixture):

     def run(self, tmpdir, ds_config, zero_stage, dtype, load_optim, use_torch_adam):
         hidden_dim = 10
-        train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype, tmpdir)
+        train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype, tmpdir, self.world_size)


 class baseline_ws2(_baseline):
```
```diff
@@ -168,13 +174,46 @@ class baseline_ws4(_baseline):
     world_size = 4


+# Stage3 use shard parameter, need to reorganize the optimizer parameters.
+def update_gathered_stage3_optimizer(optimizer_state, param_shapes, world_size):
+    for sub_group_id, group in enumerate(optimizer_state["param_groups"]):
+        group["params"] = None
+
+    new_state = {}
+    for sub_group_id, sub_group_param_shape in enumerate(param_shapes):
+        total_numel = optimizer_state['state'][sub_group_id]['exp_avg'].numel()
+        assert total_numel % world_size == 0
+        numel_per_rank = total_numel // world_size
+        param_offset_in_current_rank = 0
+        for param_name, param_shape in sub_group_param_shape.items():
+            param_numel = param_shape.numel()
+            param_partition_numel = math.ceil(param_numel / world_size)
+            param_optimizer_tensor = {
+                "exp_avg": torch.zeros(param_numel),
+                "exp_avg_sq": torch.zeros(param_numel),
+                "step": optimizer_state['state'][sub_group_id]['step'],
+            }
+            for key in ["exp_avg", "exp_avg_sq"]:
+                write_offset = 0
+                for rank in range(world_size):
+                    offset = param_offset_in_current_rank + rank * numel_per_rank
+                    length = min(param_partition_numel, param_numel - rank * param_partition_numel)
+                    tmp = optimizer_state['state'][sub_group_id][key].narrow(0, offset, length)
+                    param_optimizer_tensor[key].narrow(0, write_offset, length).copy_(tmp)
+                    write_offset += length
+            param_offset_in_current_rank += param_partition_numel
+            new_state[param_name] = param_optimizer_tensor
+    optimizer_state["state"] = new_state
+
+
 @pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16, torch.float32])
 @pytest.mark.parametrize("zero_stage", [1, 3])
 @pytest.mark.parametrize("use_torch_adam", [False, True])
 @pytest.mark.parametrize("load_optim", [False, True])
+@pytest.mark.parametrize("sub_group_size", [-1, 100])
 class TestZeROUniversalCheckpointDP(DistributedTest):

-    def _run_test(self, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
+    def _run_test(self, tmpdir, dtype, ds_config, load_optim, use_torch_adam, world_size):
         if dtype == torch.bfloat16 and not bf16_required_version_check():
             pytest.skip(
                 " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly"
```
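To make the regrouping above concrete, here is a small walk-through with invented sizes and values: two ranks, one sub-group holding a 6-element parameter "w" followed by a 2-element parameter "b". Each rank's 4-element partition is stored contiguously in the gathered flat tensor, and the loop pulls each parameter's per-rank slices back into one contiguous tensor.

```python
import math
import torch

# Invented example: world_size=2, gathered exp_avg for one sub-group laid out as
#   [ w0 w1 w2 b0 | w3 w4 w5 b1 ]   <- rank 0 partition, then rank 1 partition
gathered = torch.tensor([0., 1., 2., 100., 3., 4., 5., 101.])

world_size = 2
numel_per_rank = gathered.numel() // world_size  # 4
param_shapes = {"w": torch.Size([6]), "b": torch.Size([2])}

param_offset_in_current_rank = 0
for name, shape in param_shapes.items():
    numel = shape.numel()
    partition_numel = math.ceil(numel / world_size)
    out = torch.zeros(numel)
    write_offset = 0
    for rank in range(world_size):
        # Where this parameter's slice starts inside this rank's segment of the flat tensor.
        offset = param_offset_in_current_rank + rank * numel_per_rank
        length = min(partition_numel, numel - rank * partition_numel)
        out.narrow(0, write_offset, length).copy_(gathered.narrow(0, offset, length))
        write_offset += length
    param_offset_in_current_rank += partition_numel
    print(name, out)
# w tensor([0., 1., 2., 3., 4., 5.])
# b tensor([100., 101.])
```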
```diff
@@ -184,14 +223,20 @@ class TestZeROUniversalCheckpointDP(DistributedTest):
         loaded_model_state, loaded_optimizer_state = torch.load(f"{tmpdir}/baseline_state.pt", weights_only=False)

         ds_config["checkpoint"] = {"load_universal": True}
-        univ_model = SimpleModel(hidden_dim)
+        univ_model = SimpleModel(hidden_dim, nlayers=2)
         univ_model = init_ds_engine(univ_model, ds_config, use_torch_adam)
         univ_model.load_checkpoint(tmpdir, tag=f"{CP_TAG}_universal", load_optimizer_states=load_optim)

         model_state = univ_model.state_dict()
         compare_state_dicts(model_state, loaded_model_state)

-        if load_optim and ds_config["zero_optimization"]["stage"] != 3:
-            optimizer_state = gather_opt_state(univ_model.optimizer.optimizer.state_dict())
+        if load_optim:
+            if ds_config["zero_optimization"]["stage"] == 3:
+                univ_model.optimizer._set_fp32_optimizer_param_groups()
+                optimizer_state = gather_opt_state(univ_model.optimizer.optimizer.state_dict())
+                univ_model.optimizer._clear_fp32_optimizer_param_groups()
+                update_gathered_stage3_optimizer(optimizer_state, univ_model._get_zero_param_shapes(), world_size)
+            else:
+                optimizer_state = gather_opt_state(univ_model.optimizer.optimizer.state_dict())
             # padding sizes may differ when dp sizes are different
             param_count = sum(p.numel() for p in univ_model.parameters())
```
```diff
@@ -216,12 +261,12 @@ class TestZeROUniversalCheckpointDP(DistributedTest):

     @pytest.mark.world_size(2)
     def test_dp_world_size_2to2(self, baseline_ws2, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
-        self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam)
+        self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam, 2)

     @pytest.mark.world_size(2)
     def test_dp_world_size_4to2(self, baseline_ws4, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
-        self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam)
+        self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam, 2)

     @pytest.mark.world_size(4)
     def test_dp_world_size_2to4(self, baseline_ws2, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
-        self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam)
+        self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam, 4)
```