diff --git a/deepspeed/checkpoint/ds_to_universal.py b/deepspeed/checkpoint/ds_to_universal.py
index 2c8cb280d..8a39f6bb4 100755
--- a/deepspeed/checkpoint/ds_to_universal.py
+++ b/deepspeed/checkpoint/ds_to_universal.py
@@ -152,21 +152,21 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D):
 def extract_zero_shards_stage3(optim_files, param_shapes, dp_degree, temp_dir, dp_index):
     state_dict = torch.load(optim_files[dp_index], map_location='cpu', weights_only=False)
 
-    flat_state = dict(
-        exp_avg=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg"],
-        exp_avg_sq=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg_sq"],
-        fp32=state_dict[OPTIMIZER_STATE_DICT]['fp32_flat_groups'][0],
-    )
-
-    offset = 0
-    for name, shape in param_shapes.items():
-        unpartitioned_numel = shape.numel()
-        partitioned_numel, _ = _zero_partitioned_param_info(unpartitioned_numel, dp_degree)
-        padding_free_numel = min(partitioned_numel, abs(unpartitioned_numel - dp_index * partitioned_numel))
-        for state_key in flat_state.keys():
-            dump_param_fragment(temp_dir, 0, dp_index, state_key, flat_state[state_key], name, offset,
-                                padding_free_numel)
-        offset += partitioned_numel
+    for idx, sub_group_shape in enumerate(param_shapes):
+        flat_state = dict(
+            exp_avg=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][idx]["exp_avg"],
+            exp_avg_sq=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][idx]["exp_avg_sq"],
+            fp32=state_dict[OPTIMIZER_STATE_DICT]['fp32_flat_groups'][idx],
+        )
+        offset = 0
+        for name, shape in sub_group_shape.items():
+            unpartitioned_numel = shape.numel()
+            partitioned_numel, _ = _zero_partitioned_param_info(unpartitioned_numel, dp_degree)
+            padding_free_numel = min(partitioned_numel, abs(unpartitioned_numel - dp_index * partitioned_numel))
+            for state_key in flat_state.keys():
+                dump_param_fragment(temp_dir, 0, dp_index, state_key, flat_state[state_key], name, offset,
+                                    padding_free_numel)
+            offset += partitioned_numel
 
 
 cnt = 0
@@ -390,10 +390,10 @@ def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir):
         print(f'Warning: Unused patterns={unmatched_patterns} while merging tp slices')
 
 
-def _merge_zero3_slice_files(args, param_shapes, dp_degree, temp_dir):
+def _merge_zero3_slice_files(args, param_keys, dp_degree, temp_dir):
     zero_output_folder = os.path.join(args.output_folder, "zero")
     do_work = partial(merge_zero3_slices, dp_degree, zero_output_folder, temp_dir)
-    _do_parallel_work(do_work, param_shapes.keys(), args.num_merge_workers)
+    _do_parallel_work(do_work, param_keys, args.num_merge_workers)
 
 
 def _zero_partitioned_param_info(unpartitioned_numel, world_size):
@@ -514,7 +514,6 @@ def main(args):
     else:
         model_files = _get_model_state_files(args.input_folder)
         param_shapes = _parse_model_states_stage3(model_files)
-        param_shapes = {k: v for d in param_shapes for k, v in d.items()}
        dp_degree = len(model_files)
 
         temp_dir = os.path.join(args.output_folder, 'tmp')
@@ -523,7 +522,8 @@ def main(args):
         _extract_zero_shard_files_stage3(args, optim_files, param_shapes, dp_degree, temp_dir)
 
         print('*** 2. Merging slices .....')
-        _merge_zero3_slice_files(args, param_shapes, dp_degree, temp_dir)
+        param_keys = {key for sub_group_shapes in param_shapes for key in sub_group_shapes.keys()}
+        _merge_zero3_slice_files(args, param_keys, dp_degree, temp_dir)
 
         print('*** 3. Saving common optimizer states')
         _save_optimizer_state_stage3(args, optim_files)
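
The rewritten `extract_zero_shards_stage3` walks each ZeRO-3 sub-group's flat `fp32`/`exp_avg`/`exp_avg_sq` tensors and dumps one fragment per named parameter. The per-rank slice arithmetic is the subtle part; the sketch below illustrates it with a stand-in `partition_info` helper (the internals of the real `_zero_partitioned_param_info` are assumed here, not quoted from the source):

```python
import math

def partition_info(unpartitioned_numel, world_size):
    # Illustrative stand-in for _zero_partitioned_param_info: every rank owns a
    # ceil-sized slice of the flattened parameter; the last rank may hold padding.
    partitioned_numel = math.ceil(unpartitioned_numel / world_size)
    padding_numel = partitioned_numel * world_size - unpartitioned_numel
    return partitioned_numel, padding_numel

# A 10-element parameter sharded across dp_degree=4: ranks 0-2 hold 3 real
# elements each, rank 3 holds 1 real element plus 2 elements of padding.
for dp_index in range(4):
    partitioned_numel, _ = partition_info(10, 4)
    padding_free_numel = min(partitioned_numel, abs(10 - dp_index * partitioned_numel))
    print(dp_index, partitioned_numel, padding_free_numel)  # -> (0,3,3) (1,3,3) (2,3,3) (3,3,1)
```
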
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index a5c106836..363da7a76 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -1838,6 +1838,7 @@ class DeepSpeedEngine(Module):
             optimizer = Stage3ZeroOptimizer(
                 self.module,
                 optimizer,
+                self.param_names,
                 timers=timers,
                 ds_config=self.config,
                 static_loss_scale=self.loss_scale(),
@@ -1886,6 +1887,7 @@ class DeepSpeedEngine(Module):
         model_dtype, gradient_accumulation_dtype = self.get_data_types()
         optimizer = MiCS_Optimizer(self.module,
                                    basic_optimizer,
+                                   self.param_names,
                                    timers=timers,
                                    ds_config=self.config,
                                    static_loss_scale=self.loss_scale(),
@@ -3010,10 +3012,7 @@ class DeepSpeedEngine(Module):
         mp_rank_str = f"{mp_rank:02d}"
 
         if self.zero_optimization_partition_weights():
-            if self.load_universal_checkpoint():
-                filename = "zero_pp_rank_0"
-            else:
-                filename = "zero_pp_rank_{}".format(dist.get_rank(group=self.optimizer.dp_process_group))
+            filename = "zero_pp_rank_{}".format(dist.get_rank(group=self.optimizer.dp_process_group))
             ckpt_name = os.path.join(
                 checkpoints_path,
                 str(tag),
@@ -3053,7 +3052,10 @@ class DeepSpeedEngine(Module):
 
         ckpt_files = glob.glob(ckpt_file_pattern)
         ckpt_files.sort()
-        return ckpt_files
+        if self.load_universal_checkpoint():
+            return [ckpt_files[0]]
+        else:
+            return ckpt_files
 
     def load_checkpoint(self,
                         load_dir,
diff --git a/deepspeed/runtime/zero/mics.py b/deepspeed/runtime/zero/mics.py
index 92b129f5d..0939409ff 100755
--- a/deepspeed/runtime/zero/mics.py
+++ b/deepspeed/runtime/zero/mics.py
@@ -367,6 +367,7 @@ class MiCS_Optimizer(DeepSpeedZeroOptimizer_Stage3):
     def __init__(self,
                  module,
                  init_optimizer,
+                 param_names,
                  timers,
                  ds_config,
                  static_loss_scale=1,
@@ -398,7 +399,7 @@ class MiCS_Optimizer(DeepSpeedZeroOptimizer_Stage3):
                  aio_config=None):
 
         log_dist("Init MiCS optimizer", ranks=[0])
-        super().__init__(module, init_optimizer, timers, ds_config, static_loss_scale, dynamic_loss_scale,
+        super().__init__(module, init_optimizer, param_names, timers, ds_config, static_loss_scale, dynamic_loss_scale,
                          dynamic_loss_args, verbose, contiguous_gradients, reduce_bucket_size, prefetch_bucket_size,
                          max_reuse_distance, max_live_parameters, param_persistence_threshold,
                          model_persistence_threshold, dp_process_group, reduce_scatter, overlap_comm,
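
The `param_names` argument threaded through `Stage3ZeroOptimizer` and `MiCS_Optimizer` is the engine's mapping from parameter tensors to their fully qualified names, which lets the ZeRO-3 optimizer locate each parameter's folder inside the universal checkpoint. A minimal sketch of such a mapping, assuming the usual `named_parameters()` convention rather than quoting engine.py verbatim:

```python
import torch

# Hypothetical sketch: map each parameter object to its dotted name so the
# optimizer can later look up per-parameter entries in the checkpoint's zero/ folder.
module = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
param_names = {param: name for name, param in module.named_parameters()}
assert param_names[module[0].weight] == "0.weight"
```

The `_get_all_ckpt_names` change follows the same theme: when `load_universal_checkpoint()` is enabled, only the first (sorted) `zero_pp_rank_*` model-states file is returned, presumably because each rank rebuilds its own optimizer partition from the universal `zero/` folder rather than reading every rank's shard.
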
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index b62e2610e..b163729e3 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -141,6 +141,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         self,
         module,
         init_optimizer,
+        param_names,
         timers,
         ds_config,
         static_loss_scale=1.0,
@@ -200,6 +201,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
             raise SystemError("Cannot use fp16 without accelerator.")
 
         self.optimizer = init_optimizer
+        self.param_names = param_names
 
         # Use torch (un)flatten ops
         self.flatten = _flatten_dense_tensors
@@ -2806,8 +2808,7 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
             raise NotImplementedError("ZeRO-3 does not yet support elastic checkpointing, please disable for now.")
 
         if checkpoint_folder:
-            self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights,
-                                            param_shapes)
+            self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights)
         else:
             self._rigid_load_state_dict(state_dict_list[dist.get_rank(group=self.dp_process_group)],
                                         load_optimizer_states=load_optimizer_states)
@@ -2828,11 +2829,10 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
             self.persistent_parameters[0].partition(self.persistent_parameters)
             # self.persistent_parameters[0].all_gather(self.persistent_parameters) # this will be done in checkpoint_event_epilogue() so remove it to prevent double all_gather
 
-    def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights,
-                                   param_shapes):
-        self.load_hp_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder, param_shapes)
+    def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights):
+        self.load_hp_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder)
 
-    def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, param_shapes):
+    def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir):
         """ Load optimizer and model states from the checkpoint directory. """
         checkpoint_dir = os.path.join(checkpoint_dir, "zero")
         optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
@@ -2842,18 +2842,34 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
         optim_sd = torch.load(optim_state_path, weights_only=False)
         self._load_global_state_stage3(optim_sd)
 
-        key_list = ["fp32", "exp_avg", "exp_avg_sq"]
+        # The step should be the same in every optimizer shard, so we can read it from any parameter.
+        state_step = optim_sd[OPTIMIZER_STATE_DICT]['state'][0]['step']
+        for key in ["fp32", "exp_avg", "exp_avg_sq"]:
+            for sub_group_id, fp16_group in enumerate(self.fp16_groups):
+                fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]
+                key_tensor = torch.zeros_like(fp32_param)
+                offset = 0
+                for param in fp16_group:
+                    if param not in self.param_names:
+                        raise ValueError(f"failed to find optimizer param in named params")
+                    param_name = self.param_names[param]
+                    key_layer_state_partition = self.load_hp_checkpoint_state(os.path.join(checkpoint_dir, param_name),
+                                                                              key)
+                    key_tensor.narrow(0, offset, key_layer_state_partition.numel()).copy_(key_layer_state_partition)
+                    offset += key_layer_state_partition.numel()
+                if key == "fp32":
+                    self.fp32_partitioned_groups_flat[sub_group_id].data.copy_(key_tensor)
+                    self.optimizer.state[fp32_param]['step'] = state_step
+                else:
+                    self.optimizer.state[fp32_param][key] = key_tensor
 
-        for key in key_list:
-            key_tensor = torch.empty(0)
-            for layer in param_shapes[0].keys():
-                key_layer_state_partition = self.load_hp_checkpoint_state(os.path.join(checkpoint_dir, layer), key)
-                key_tensor = torch.cat((key_tensor, key_layer_state_partition))
-            if key == "fp32":
-                self.fp32_partitioned_groups_flat[0].data.copy_(key_tensor)
-                self.optimizer.param_groups[0]['params'].append(self.fp32_partitioned_groups_flat[0])
-            else:
-                optim_sd[OPTIMIZER_STATE_DICT]['state'][0][key] = key_tensor
+        for param_group in self.optimizer.param_groups:
+            # The hyperparameters should be the same across param groups, so copy them from the first saved group.
+            for key, value in optim_sd[OPTIMIZER_STATE_DICT]["param_groups"][0].items():
+                if key == 'params':
+                    param_group['params'] = []
+                else:
+                    param_group[key] = value
 
         if self.swap_optimizer:
             # Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint
@@ -2869,10 +2885,6 @@ class DeepSpeedZeroOptimizer_Stage3(ZeROOptimizer):
                 self._release_sub_group(sub_group_id, timer_names)
             self._post_step(timer_names)
 
-        self.optimizer.load_state_dict(optim_sd[OPTIMIZER_STATE_DICT])
-        for param_group in self.optimizer.param_groups:
-            param_group['params'] = []
-
         for sub_group_id in range(len(self.fp32_partitioned_groups_flat)):
             fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]
             if sum(fp32_param.size()) > 0:
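
The core of the stage3.py change is packing per-parameter optimizer fragments back into the flat per-sub-group buffers with `narrow(...).copy_(...)`. The toy snippet below (made-up parameter names and sizes, no DeepSpeed dependency) isolates that packing pattern:

```python
import torch

# Two hypothetical per-parameter fragments loaded from the universal checkpoint.
fragments = {"linears.0.weight": torch.ones(6), "linears.0.bias": torch.full((2,), 2.0)}

flat = torch.zeros(8)  # stands in for one fp32_partitioned_groups_flat sub-group
offset = 0
for name, frag in fragments.items():
    # Same narrow/copy_ pattern used when rebuilding the flat sub-group above.
    flat.narrow(0, offset, frag.numel()).copy_(frag)
    offset += frag.numel()

assert torch.equal(flat, torch.tensor([1., 1., 1., 1., 1., 1., 2., 2.]))
```
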
diff --git a/tests/unit/checkpoint/test_universal_checkpoint.py b/tests/unit/checkpoint/test_universal_checkpoint.py
index 46d4294bd..27e151103 100644
--- a/tests/unit/checkpoint/test_universal_checkpoint.py
+++ b/tests/unit/checkpoint/test_universal_checkpoint.py
@@ -3,6 +3,9 @@
 
 # DeepSpeed Team
 
+import os
+import math
+
 import deepspeed
 from types import SimpleNamespace
 from torch.utils._pytree import tree_map
@@ -72,13 +75,13 @@ def init_ds_engine(model, ds_config, use_torch_adam):
     return model
 
 
-def train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype, tmpdir):
+def train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype, tmpdir, world_size):
     if dtype == torch.bfloat16 and not bf16_required_version_check():
         return
 
     test_step = 8
 
-    model = SimpleModel(hidden_dim)
+    model = SimpleModel(hidden_dim, nlayers=2)
     model = init_ds_engine(model, ds_config, use_torch_adam)
     data_loader = random_dataloader(model=model,
                                     total_samples=test_step,
@@ -124,6 +127,7 @@ def train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype,
             model.optimizer._set_fp32_optimizer_param_groups()
             optimizer_state = gather_opt_state(model.optimizer.optimizer.state_dict())
             model.optimizer._clear_fp32_optimizer_param_groups()
+            update_gathered_stage3_optimizer(optimizer_state, model._get_zero_param_shapes(), world_size)
         else:
             optimizer_state = gather_opt_state(model.optimizer.optimizer.state_dict())
 
@@ -135,7 +139,7 @@
 
 
 @pytest.fixture
-def ds_config(zero_stage, dtype):
+def ds_config(zero_stage, dtype, sub_group_size):
     ds_config = {
         "train_batch_size": 8,
         "optimizer": {
@@ -149,6 +153,8 @@
         ds_config["fp16"] = {"enabled": True, "initial_scale_power": 8}
     elif dtype == torch.bfloat16:
         ds_config["bf16"] = {"enabled": True}
+    if sub_group_size > 0:
+        ds_config["zero_optimization"]["sub_group_size"] = sub_group_size
     return ds_config
 
 
@@ -157,7 +163,7 @@ class _baseline(DistributedFixture):
 
     def run(self, tmpdir, ds_config, zero_stage, dtype, load_optim, use_torch_adam):
         hidden_dim = 10
-        train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype, tmpdir)
+        train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype, tmpdir, self.world_size)
 
 
 class baseline_ws2(_baseline):
@@ -168,13 +174,46 @@ class baseline_ws4(_baseline):
     world_size = 4
 
 
+# Stage 3 shards parameters across ranks, so the gathered optimizer state must be reorganized per parameter.
+def update_gathered_stage3_optimizer(optimizer_state, param_shapes, world_size):
+    for sub_group_id, group in enumerate(optimizer_state["param_groups"]):
+        group["params"] = None
+
+    new_state = {}
+    for sub_group_id, sub_group_param_shape in enumerate(param_shapes):
+        total_numel = optimizer_state['state'][sub_group_id]['exp_avg'].numel()
+        assert total_numel % world_size == 0
+        numel_per_rank = total_numel // world_size
+        param_offset_in_current_rank = 0
+        for param_name, param_shape in sub_group_param_shape.items():
+            param_numel = param_shape.numel()
+            param_partition_numel = math.ceil(param_numel / world_size)
+            param_optimizer_tensor = {
+                "exp_avg": torch.zeros(param_numel),
+                "exp_avg_sq": torch.zeros(param_numel),
+                "step": optimizer_state['state'][sub_group_id]['step'],
+            }
+            for key in ["exp_avg", "exp_avg_sq"]:
+                write_offset = 0
+                for rank in range(world_size):
+                    offset = param_offset_in_current_rank + rank * numel_per_rank
+                    length = min(param_partition_numel, param_numel - rank * param_partition_numel)
+                    tmp = optimizer_state['state'][sub_group_id][key].narrow(0, offset, length)
+                    param_optimizer_tensor[key].narrow(0, write_offset, length).copy_(tmp)
+                    write_offset += length
+            param_offset_in_current_rank += param_partition_numel
+            new_state[param_name] = param_optimizer_tensor
+    optimizer_state["state"] = new_state
+
+
 @pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16, torch.float32])
 @pytest.mark.parametrize("zero_stage", [1, 3])
 @pytest.mark.parametrize("use_torch_adam", [False, True])
 @pytest.mark.parametrize("load_optim", [False, True])
+@pytest.mark.parametrize("sub_group_size", [-1, 100])
 class TestZeROUniversalCheckpointDP(DistributedTest):
-    def _run_test(self, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
+    def _run_test(self, tmpdir, dtype, ds_config, load_optim, use_torch_adam, world_size):
         if dtype == torch.bfloat16 and not bf16_required_version_check():
             pytest.skip(
                 " DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly"
             )
@@ -184,15 +223,21 @@ class TestZeROUniversalCheckpointDP(DistributedTest):
 
         loaded_model_state, loaded_optimizer_state = torch.load(f"{tmpdir}/baseline_state.pt", weights_only=False)
 
         ds_config["checkpoint"] = {"load_universal": True}
-        univ_model = SimpleModel(hidden_dim)
+        univ_model = SimpleModel(hidden_dim, nlayers=2)
         univ_model = init_ds_engine(univ_model, ds_config, use_torch_adam)
         univ_model.load_checkpoint(tmpdir, tag=f"{CP_TAG}_universal", load_optimizer_states=load_optim)
         model_state = univ_model.state_dict()
         compare_state_dicts(model_state, loaded_model_state)
 
-        if load_optim and ds_config["zero_optimization"]["stage"] != 3:
-            optimizer_state = gather_opt_state(univ_model.optimizer.optimizer.state_dict())
+        if load_optim:
+            if ds_config["zero_optimization"]["stage"] == 3:
+                univ_model.optimizer._set_fp32_optimizer_param_groups()
+                optimizer_state = gather_opt_state(univ_model.optimizer.optimizer.state_dict())
+                univ_model.optimizer._clear_fp32_optimizer_param_groups()
+                update_gathered_stage3_optimizer(optimizer_state, univ_model._get_zero_param_shapes(), world_size)
+            else:
+                optimizer_state = gather_opt_state(univ_model.optimizer.optimizer.state_dict())
             # padding sizes may differ when dp sizes are different
             param_count = sum(p.numel() for p in univ_model.parameters())
             optimizer_state = remove_pad_in_opt_state(optimizer_state, param_count)
@@ -216,12 +261,12 @@ class TestZeROUniversalCheckpointDP(DistributedTest):
 
     @pytest.mark.world_size(2)
     def test_dp_world_size_2to2(self, baseline_ws2, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
-        self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam)
+        self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam, 2)
 
     @pytest.mark.world_size(2)
     def test_dp_world_size_4to2(self, baseline_ws4, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
-        self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam)
+        self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam, 2)
 
     @pytest.mark.world_size(4)
     def test_dp_world_size_2to4(self, baseline_ws2, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
-        self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam)
+        self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam, 4)
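
For reference, the reassembly performed by `update_gathered_stage3_optimizer` can be checked in isolation: each sub-group's gathered flat state is laid out rank-major, so a parameter's full tensor is rebuilt by taking one slice per rank at that parameter's offset within the rank. The sizes and values below are invented for illustration:

```python
import math
import torch

world_size = 2
param_numels = {"weight": 6, "bias": 2}
# Rank-major flat layout: [rank0: weight[0:3], bias[0:1] | rank1: weight[3:6], bias[1:2]]
flat = torch.tensor([0., 1., 2., 60., 3., 4., 5., 61.])
numel_per_rank = flat.numel() // world_size

rebuilt, offset_in_rank = {}, 0
for name, numel in param_numels.items():
    part = math.ceil(numel / world_size)
    pieces = []
    for rank in range(world_size):
        # Each rank contributes at most `part` elements of this parameter.
        length = min(part, numel - rank * part)
        pieces.append(flat.narrow(0, offset_in_rank + rank * numel_per_rank, length))
    rebuilt[name] = torch.cat(pieces)
    offset_in_rank += part

assert torch.equal(rebuilt["weight"], torch.arange(6, dtype=torch.float32))
assert torch.equal(rebuilt["bias"], torch.tensor([60., 61.]))
```
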