Mirror of https://github.com/deepspeedai/DeepSpeed.git (synced 2025-10-20 15:33:51 +08:00)
**Describe the bug**

When the model is large and there are multiple sub-groups, running ds_to_universal.py fails. The error log is below:

```
*** 1. Extracting ZeRO fragments
  0%|          | 0/1 [00:03<?, ?it/s]
Traceback (most recent call last):
  File "/work/zhengchenyu/ai-project/qwen3/scripts/ds_to_universal_example.py", line 21, in <module>
    main()
  File "/work/zhengchenyu/ai-project/qwen3/scripts/ds_to_universal_example.py", line 18, in main
    ds_to_universal_main(args)
  File "/opt/conda/lib/python3.11/site-packages/deepspeed/checkpoint/ds_to_universal.py", line 523, in main
    _extract_zero_shard_files_stage3(args, optim_files, param_shapes, dp_degree, temp_dir)
  File "/opt/conda/lib/python3.11/site-packages/deepspeed/checkpoint/ds_to_universal.py", line 375, in _extract_zero_shard_files_stage3
    _do_parallel_work(do_work, list(range(dp_degree)), args.num_extract_workers)
  File "/opt/conda/lib/python3.11/site-packages/deepspeed/checkpoint/ds_to_universal.py", line 359, in _do_parallel_work
    results.append(do_work(work))
                   ^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/deepspeed/checkpoint/ds_to_universal.py", line 167, in extract_zero_shards_stage3
    dump_param_fragment(temp_dir, 0, dp_index, state_key, flat_state[state_key], name, offset,
  File "/opt/conda/lib/python3.11/site-packages/deepspeed/checkpoint/ds_to_universal.py", line 194, in dump_param_fragment
    state_flat_tensor = state_flat_tensor.narrow(0, offset, numel).clone()
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: start (0) + length (155582464) exceeds dimension size (74499072).
```

**To Reproduce**

Steps to reproduce the behavior:
1. Train a large model, or set `sub_group_size` to a low value, then train and save a checkpoint.
2. Run ds_to_universal.py.

**The reason**

The previous ZeRO stage 3 universal checkpoint implementation did not take sub-groups into account. The following problems were also found during debugging:

* Multiple sub-groups are not handled, which results in data loss.
* When `load_checkpoint` is True, all processes save to the same ZeRO model checkpoint file. If multiple processes write at the same time, the file is corrupted; file corruption was occasionally observed during testing.

Related issue: #7584

---------

Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
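For context, a conversion driver along the lines of the `ds_to_universal_example.py` referenced in the traceback can be sketched as below. The checkpoint paths are placeholders and the argument set mirrors the one exercised by the unit test in this change, so treat it as an illustrative sketch rather than the exact script that was used:

```python
# Minimal sketch of a ZeRO -> universal checkpoint conversion driver.
# The folder paths are placeholders; the arguments mirror those used in the test below.
from types import SimpleNamespace

from deepspeed.checkpoint.ds_to_universal import main as ds_to_universal_main


def main():
    args = SimpleNamespace(input_folder="/path/to/checkpoints/global_step100",
                           output_folder="/path/to/checkpoints/global_step100_universal",
                           num_extract_workers=1,
                           num_merge_workers=1,
                           keep_temp_folder=False,
                           strict=True,
                           inject_missing_state=False)
    ds_to_universal_main(args)


if __name__ == "__main__":
    main()
```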
273 lines
10 KiB
Python
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import math

import deepspeed
from types import SimpleNamespace
from torch.utils._pytree import tree_map

from deepspeed.utils.torch import required_torch_version
from deepspeed.checkpoint import UNIVERSAL_CHECKPOINT_INFO
from deepspeed.checkpoint.ds_to_universal import main as convert_to_universal

from unit.common import DistributedTest, DistributedFixture
from unit.simple_model import *
from unit.util import bf16_required_version_check

from unit.checkpoint.common import compare_opt_state_dicts, compare_state_dicts

import pytest
import deepspeed.comm as dist

def get_expected_mismatch_keys():
    # torch 1.2.* stores raw tensor id numbers in checkpoint state which leads to
    # false positive mismatches in checkpoint state comparisons.
    # Newer torch versions store tensor ids as 0, 1, 2, ...
    return [] if required_torch_version(min_version=1.4) else ['params']


def maybe_step(t):
    return not torch.is_tensor(t) or (t.device.type == 'cpu' and t.numel() == 1)

def gather_opt_state(optimizer_state):

    def gather_tensor(t):

        if maybe_step(t):
            return t
        else:
            buffer = [torch.zeros_like(t.flatten()) for _ in range(dist.get_world_size())]
            dist.all_gather(buffer, t.flatten())
            return torch.cat(buffer)

    return tree_map(gather_tensor, optimizer_state)

def remove_pad_in_opt_state(optimizer_state, num_params):

    def remove_pad(t):
        if maybe_step(t):
            return t
        else:
            return t[:num_params]

    return tree_map(remove_pad, optimizer_state)


CP_TAG = "test_tag"

def init_ds_engine(model, ds_config, use_torch_adam):

    if use_torch_adam:
        ds_optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
        del ds_config["optimizer"]
        model, _, _, _ = deepspeed.initialize(config=ds_config, model=model, optimizer=ds_optimizer)
    else:
        model, _, _, _ = deepspeed.initialize(config=ds_config, model=model, model_parameters=model.parameters())

    return model

def train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype, tmpdir, world_size):
    if dtype == torch.bfloat16 and not bf16_required_version_check():
        return

    test_step = 8

    model = SimpleModel(hidden_dim, nlayers=2)
    model = init_ds_engine(model, ds_config, use_torch_adam)
    data_loader = random_dataloader(model=model,
                                    total_samples=test_step,
                                    hidden_dim=hidden_dim,
                                    device=model.device,
                                    dtype=dtype)
    for batch in data_loader:
        loss = model(batch[0], batch[1])
        model.backward(loss)
        model.step()

    if ds_config["zero_optimization"]["stage"] == 3:
        model.optimizer._set_fp32_optimizer_param_groups()
        sd = model.optimizer.optimizer.state_dict() if load_optim else None
        model.optimizer._clear_fp32_optimizer_param_groups()
    else:
        sd = model.optimizer.optimizer.state_dict() if load_optim else None

    client_state = {}
    client_state[UNIVERSAL_CHECKPOINT_INFO] = {}
    client_state['iteration'] = test_step
    model.save_checkpoint(tmpdir, tag=CP_TAG, client_state=client_state)

    cp_dir = os.path.join(tmpdir, CP_TAG)
    univ_cp_dir = f"{cp_dir}_universal"

    args = SimpleNamespace(input_folder=cp_dir,
                           output_folder=univ_cp_dir,
                           num_extract_workers=1,
                           num_merge_workers=1,
                           keep_temp_folder=False,
                           strict=True,
                           inject_missing_state=False)

    dist.barrier()
    if dist.get_rank() == 0:
        convert_to_universal(args)

    model_state = model.state_dict()
    optimizer_state = None
    if load_optim:
        if ds_config["zero_optimization"]["stage"] == 3:
            model.optimizer._set_fp32_optimizer_param_groups()
            optimizer_state = gather_opt_state(model.optimizer.optimizer.state_dict())
            model.optimizer._clear_fp32_optimizer_param_groups()
            update_gathered_stage3_optimizer(optimizer_state, model._get_zero_param_shapes(), world_size)
        else:
            optimizer_state = gather_opt_state(model.optimizer.optimizer.state_dict())

    if dist.get_rank() == 0:
        torch.save((model_state, optimizer_state), os.path.join(tmpdir, "baseline_state.pt"))

    dist.barrier()
    model.destroy()

@pytest.fixture
def ds_config(zero_stage, dtype, sub_group_size):
    ds_config = {
        "train_batch_size": 8,
        "optimizer": {
            "type": 'Adam'
        },
        "zero_optimization": {
            "stage": zero_stage,
        }
    }
    if dtype == torch.float16:
        ds_config["fp16"] = {"enabled": True, "initial_scale_power": 8}
    elif dtype == torch.bfloat16:
        ds_config["bf16"] = {"enabled": True}
    if sub_group_size > 0:
        ds_config["zero_optimization"]["sub_group_size"] = sub_group_size
    return ds_config

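# For reference (illustrative expansion, not part of the original suite): with zero_stage=3,
# dtype=torch.bfloat16 and sub_group_size=100, the fixture above resolves to the config below.
# The small sub_group_size forces multiple sub-groups, which is exactly the case the
# universal checkpoint conversion has to handle.
_example_stage3_subgroup_config = {
    "train_batch_size": 8,
    "optimizer": {
        "type": 'Adam'
    },
    "zero_optimization": {
        "stage": 3,
        "sub_group_size": 100,
    },
    "bf16": {"enabled": True},
}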
class _baseline(DistributedFixture):
    world_size = None

    def run(self, tmpdir, ds_config, zero_stage, dtype, load_optim, use_torch_adam):
        hidden_dim = 10
        train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype, tmpdir, self.world_size)


class baseline_ws2(_baseline):
    world_size = 2


class baseline_ws4(_baseline):
    world_size = 4

# Stage 3 shards parameters across ranks, so the gathered optimizer state has to be
# reorganized back into per-parameter tensors.
def update_gathered_stage3_optimizer(optimizer_state, param_shapes, world_size):
    for sub_group_id, group in enumerate(optimizer_state["param_groups"]):
        group["params"] = None

    new_state = {}
    for sub_group_id, sub_group_param_shape in enumerate(param_shapes):
        total_numel = optimizer_state['state'][sub_group_id]['exp_avg'].numel()
        assert total_numel % world_size == 0
        numel_per_rank = total_numel // world_size
        param_offset_in_current_rank = 0
        for param_name, param_shape in sub_group_param_shape.items():
            param_numel = param_shape.numel()
            param_partition_numel = math.ceil(param_numel / world_size)
            param_optimizer_tensor = {
                "exp_avg": torch.zeros(param_numel),
                "exp_avg_sq": torch.zeros(param_numel),
                "step": optimizer_state['state'][sub_group_id]['step'],
            }
            for key in ["exp_avg", "exp_avg_sq"]:
                write_offset = 0
                for rank in range(world_size):
                    offset = param_offset_in_current_rank + rank * numel_per_rank
                    length = min(param_partition_numel, param_numel - rank * param_partition_numel)
                    tmp = optimizer_state['state'][sub_group_id][key].narrow(0, offset, length)
                    param_optimizer_tensor[key].narrow(0, write_offset, length).copy_(tmp)
                    write_offset += length
            param_offset_in_current_rank += param_partition_numel
            new_state[param_name] = param_optimizer_tensor
    optimizer_state["state"] = new_state

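# Illustrative sanity check (not part of the original suite) showing what the helper above does:
# with world_size=2, a single 3-element parameter is gathered into a rank-padded flat buffer laid
# out as [rank0: p0, p1 | rank1: p2, pad]; the helper reassembles it in the original order.
def _example_update_gathered_stage3_optimizer():
    gathered = {
        "param_groups": [{"params": None}],
        "state": {
            0: {  # single sub-group
                "step": 8,
                "exp_avg": torch.tensor([1., 2., 3., 0.]),  # rank 0 holds [1, 2], rank 1 holds [3, pad]
                "exp_avg_sq": torch.tensor([1., 4., 9., 0.]),
            }
        },
    }
    param_shapes = [{"weight": torch.Size([3])}]
    update_gathered_stage3_optimizer(gathered, param_shapes, world_size=2)
    assert torch.equal(gathered["state"]["weight"]["exp_avg"], torch.tensor([1., 2., 3.]))
    assert torch.equal(gathered["state"]["weight"]["exp_avg_sq"], torch.tensor([1., 4., 9.]))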
@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16, torch.float32])
@pytest.mark.parametrize("zero_stage", [1, 3])
@pytest.mark.parametrize("use_torch_adam", [False, True])
@pytest.mark.parametrize("load_optim", [False, True])
@pytest.mark.parametrize("sub_group_size", [-1, 100])
class TestZeROUniversalCheckpointDP(DistributedTest):

    def _run_test(self, tmpdir, dtype, ds_config, load_optim, use_torch_adam, world_size):
        if dtype == torch.bfloat16 and not bf16_required_version_check():
            pytest.skip(
                "DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA >= 11.0 and HW support for BFloat16 to run correctly"
            )

        hidden_dim = 10
        loaded_model_state, loaded_optimizer_state = torch.load(f"{tmpdir}/baseline_state.pt", weights_only=False)

        ds_config["checkpoint"] = {"load_universal": True}
        univ_model = SimpleModel(hidden_dim, nlayers=2)
        univ_model = init_ds_engine(univ_model, ds_config, use_torch_adam)
        univ_model.load_checkpoint(tmpdir, tag=f"{CP_TAG}_universal", load_optimizer_states=load_optim)

        model_state = univ_model.state_dict()
        compare_state_dicts(model_state, loaded_model_state)

        if load_optim:
            if ds_config["zero_optimization"]["stage"] == 3:
                univ_model.optimizer._set_fp32_optimizer_param_groups()
                optimizer_state = gather_opt_state(univ_model.optimizer.optimizer.state_dict())
                univ_model.optimizer._clear_fp32_optimizer_param_groups()
                update_gathered_stage3_optimizer(optimizer_state, univ_model._get_zero_param_shapes(), world_size)
            else:
                optimizer_state = gather_opt_state(univ_model.optimizer.optimizer.state_dict())
            # padding sizes may differ when dp sizes are different
            param_count = sum(p.numel() for p in univ_model.parameters())
            optimizer_state = remove_pad_in_opt_state(optimizer_state, param_count)
            loaded_optimizer_state = remove_pad_in_opt_state(loaded_optimizer_state, param_count)

            compare_opt_state_dicts(optimizer_state, loaded_optimizer_state, get_expected_mismatch_keys())

        # Run training again to verify that the optimizer has necessary states
        test_step = 8
        data_loader = random_dataloader(model=univ_model,
                                        total_samples=test_step,
                                        hidden_dim=hidden_dim,
                                        device=univ_model.device,
                                        dtype=dtype)
        for batch in data_loader:
            loss = univ_model(batch[0], batch[1])
            univ_model.backward(loss)
            univ_model.step()

        univ_model.destroy()

    @pytest.mark.world_size(2)
    def test_dp_world_size_2to2(self, baseline_ws2, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
        self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam, 2)

    @pytest.mark.world_size(2)
    def test_dp_world_size_4to2(self, baseline_ws4, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
        self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam, 2)

    @pytest.mark.world_size(4)
    def test_dp_world_size_2to4(self, baseline_ws2, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
        self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam, 4)