# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import torch

from deepspeed.utils import logger
from deepspeed.utils.tensor_fragment import map_to_flat_opt_states
from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank, see_memory_usage
from deepspeed.runtime.torch_autocast import get_comm_dtype, is_autocast_initialized

class DeepSpeedOptimizer(object):
    pass

class ZeROOptimizer(DeepSpeedOptimizer):
    """Base class with helpers shared by DeepSpeed's ZeRO optimizer implementations."""

    def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, checkpoint_dir: str) -> None:
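        """Restore high-precision (HP) optimizer state from a ZeRO checkpoint directory.

        Loads the global optimizer state from ``<checkpoint_dir>/zero/optimizer_state.pt``,
        then reloads each low-precision parameter's HP checkpoint state and rebuilds the
        flat optimizer state for every parameter group.

        Args:
            lp_groups_name: Name of the attribute on ``self`` that holds the low-precision
                parameter groups.
            checkpoint_dir: Root checkpoint directory; the ZeRO state is read from its
                ``zero`` subdirectory.
        """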
        checkpoint_dir = os.path.join(checkpoint_dir, "zero")
        optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
        assert os.path.isfile(
            optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
        optim_sd = torch.load(optim_state_path, weights_only=False)

        self._load_global_state(optim_sd)

        tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
        if self.mpu is None:
            logger.warning("MPU is not provided, setting tp size to 1 in checkpoint loading.")
            tp_world_size = 1
        else:
            tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
                else self.mpu.get_tensor_model_parallel_world_size()

        for i, (param_group,
                loaded_param_group) in enumerate(zip(self.optimizer.param_groups, optim_sd['param_groups'])):
            # We assume that all params in the same param_group have the same optimizer state keys.
            opt_keys = set()
            steps = []

            lp_groups = getattr(self, lp_groups_name)
            for lp in lp_groups[i]:
                if lp._hp_mapping is not None:
                    #print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
                    step = lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
                                                       tp_world_size)
                    for key in lp._hp_mapping.get_optim_state_keys():
                        opt_keys.add(key)
                    steps.append(step)

            hp_param = param_group['params'][0]
            assert all(step == steps[0] for step in steps), f"Steps {steps} are not equal"
            if steps[0] is not None:
                self.optimizer.state[hp_param]['step'] = steps[0]

            map_to_flat_opt_states(hp_param, lp_groups[i], self.optimizer.state, opt_keys)

            for key, value in loaded_param_group.items():
                if key == 'params':
                    continue
                param_group[key] = value

    def report_ipg_memory_usage(self, tag, param_elems, dtype=None):
        dtypes = self.ipg_buckets.keys() if dtype is None else [dtype]

        for dt in dtypes:
            bucket = self.ipg_buckets[dt]
            elem_count = bucket.elements + param_elems
            # Floor of the percentage of the reduce bucket that these elements would occupy.
            percent_of_bucket_size = (100.0 * elem_count) // self.reduce_bucket_size
            see_memory_usage(
                f"{tag}: elems in_bucket {dt} {bucket.elements} param {param_elems} max_percent {percent_of_bucket_size}"
            )

    def get_param_comm_dtype(self, param):
        if is_autocast_initialized():
            return get_comm_dtype(param)
        else:
            return self.communication_data_type