Mirror of https://github.com/deepspeedai/DeepSpeed.git, synced 2025-10-20 15:33:51 +08:00
Enable non-ZeRO mode (#7515)

Non-ZeRO mode is enabled via `stage=0`, which corresponds to DDP.

- Remove the hardwired path to bf16_optimizer.
- Enable `torch.autocast` for DDP training.
- Enable native mixed precision DDP for bfloat16.
- Update torch.autocast and native mixed precision unit tests.

Signed-off-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
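For illustration only (not part of the commit), a minimal sketch of what the newly enabled path looks like from user code: ZeRO stage 0 (plain DDP) combined with native bf16. The model, batch size, and optimizer settings are placeholders, and a launch via the `deepspeed` launcher is assumed.

```python
import torch
import deepspeed

ds_config = {
    "train_batch_size": 8,
    "zero_optimization": {"stage": 0},   # non-ZeRO / DDP mode
    "bf16": {"enabled": True},           # native bf16 mixed precision
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
}

model = torch.nn.Linear(16, 16)
engine, optimizer, _, _ = deepspeed.initialize(model=model, config=ds_config,
                                               model_parameters=model.parameters())

# Parameters are held in bf16, so the input is cast to match.
x = torch.randn(8, 16, device=engine.device, dtype=torch.bfloat16)
loss = engine(x).float().pow(2).mean()
engine.backward(loss)
engine.step()
```

With `stage=0`, bf16 no longer goes through the hardwired bf16_optimizer path; the diffs below route it through the reworked FP16_Optimizer with `low_precision_dtype=torch.bfloat16` and no loss scaling.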
@@ -33,7 +33,7 @@ app = modal.App("deepspeedai-torch-latest-ci", image=image)
 def pytest():
     import subprocess
     subprocess.run(
-        "pytest -n 4 --verbose tests/unit/runtime/zero/test_zero.py tests/unit/runtime/half_precision/test_bf16.py --torch_ver=2.6 --cuda_ver=12.4".split(),
+        "pytest -n 4 --verbose tests/unit/runtime/zero/test_zero.py tests/unit/runtime/half_precision/test_bf16.py tests/unit/runtime/zero/test_zero_autocast.py --torch_ver=2.6 --cuda_ver=12.4".split(),
         check=True,
         cwd=ROOT_PATH / ".",
     )
@@ -137,6 +137,9 @@ BFLOAT16_CHECK_OVERFLOW_DEFAULT = False
 BFLOAT16_IMMEDIATE_GRAD_UPDATE = "immediate_grad_update"
 BFLOAT16_IMMEDIATE_GRAD_UPDATE_DEFAULT = True
 
+# DDP variant of BFLOAT16
+DDP_BFLOAT16 = "bf16"
+
 #########################################
 # FP16 support
 #########################################
@@ -54,7 +54,7 @@ from deepspeed.runtime.zero.muon.muon_optimizer import MuonWithAuxAdam
 from deepspeed.runtime.constants import \
     ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
     PLD_THETA, PLD_GAMMA, BFLOAT16, FP16, AMP, GRADIENT_ACCUMULATION_STEPS, \
-    DATA_PARALLEL_GROUP, GLOBAL_RANK
+    DATA_PARALLEL_GROUP, GLOBAL_RANK, DDP_BFLOAT16
 from deepspeed.runtime.zero.config import ZeroStageEnum
 from deepspeed.compression import compression_scheduler
 from deepspeed.compression.constants import \
@@ -1091,13 +1091,9 @@ class DeepSpeedEngine(Module):
             model_dtype = torch.bfloat16
 
         if self._config.grad_accum_dtype is None:
-            if model_dtype == torch.bfloat16 and not self.zero_optimization():
-                grad_accum_dtype = torch.float32
-            else:
-                grad_accum_dtype = model_dtype
+            grad_accum_dtype = model_dtype
         else:
             grad_accum_dtype = DtypeEnum(self._config.grad_accum_dtype).value
 
         return (model_dtype, grad_accum_dtype)
 
     def _optimizer_has_ckpt_event_prologue(self):
@@ -1139,7 +1135,7 @@ class DeepSpeedEngine(Module):
                 or (self.zero_optimization_partition_weights() and self.is_first_weights_partition_group()):
             self.save_non_zero_checkpoint = True
 
-        if self.zero_optimization() or self.bfloat16_enabled():
+        if hasattr(self.optimizer, 'dp_process_group'):
             param_rank = dist.get_rank(group=self.optimizer.dp_process_group)
 
             # Only the first parameter parallel process needs to store the
@@ -1407,23 +1403,18 @@ class DeepSpeedEngine(Module):
             return AMP
         # data type checks
         elif model_dtype == grad_accum_dtype:
-            if model_dtype == torch.bfloat16:
-                if self.pipeline_parallelism:
-                    logger.warning(
-                        "**** BF16 gradient accumulation is not safe numerically with large number of accumulation steps, proceed with caution *****"
-                    )
-                    return BFLOAT16
-                else:
-                    raise NotImplementedError(
-                        "Bfloat16 wrapper must use a gradient accumulation type of fp32, enable ZeRO to use Bfloat16 gradient accumulation"
-                    )
-            if model_dtype == torch.float16:
-                return FP16
-            # else optimizer_wrapper = None
             if model_dtype == torch.float32:
                 return None
+            if model_dtype == torch.bfloat16 and self.pipeline_parallelism:
+                logger.warning(
+                    "**** BF16 gradient accumulation is not safe numerically with large number of accumulation steps, proceed with caution *****"
+                )
+                return BFLOAT16
+            return FP16 if model_dtype == torch.float16 else DDP_BFLOAT16
         elif model_dtype == torch.bfloat16 and grad_accum_dtype == torch.float32:
            return BFLOAT16
         else:
-            raise NotImplementedError("unsupported mix of model dtype and gradient accumulation type")
+            raise NotImplementedError(f"unsupported mix of {model_dtype=} and {grad_accum_dtype=}")
 
         return None
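For orientation only, a small sketch summarizing the wrapper selection that the reworked check above encodes for the non-ZeRO, non-AMP case. The dictionary and its name are illustrative, not DeepSpeed API; with pipeline parallelism, matching bf16 dtypes still resolve to the BFLOAT16 wrapper rather than DDP_BFLOAT16.

```python
import torch

# Illustrative (model_dtype, grad_accum_dtype) -> wrapper summary of the diff above.
WRAPPER_BY_DTYPES = {
    (torch.float32, torch.float32): None,               # no wrapper: plain fp32 DDP
    (torch.float16, torch.float16): "FP16",             # loss-scaled fp16 wrapper
    (torch.bfloat16, torch.bfloat16): "DDP_BFLOAT16",   # new native bf16 DDP wrapper
    (torch.bfloat16, torch.float32): "BFLOAT16",        # bf16 params, fp32 grad accumulation
}

print(WRAPPER_BY_DTYPES[(torch.bfloat16, torch.bfloat16)])  # DDP_BFLOAT16
```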
@@ -1466,8 +1457,9 @@ class DeepSpeedEngine(Module):
             self._set_client_model(model)
             self._broadcast_model()
             # TODO: maybe need to broadcast experts differently?
-        elif optimizer_wrapper == FP16:
-            self.optimizer = self._configure_fp16_optimizer(basic_optimizer)
+        elif optimizer_wrapper in [FP16, DDP_BFLOAT16]:
+            lp_dtype = torch.float16 if optimizer_wrapper == FP16 else torch.bfloat16
+            self.optimizer = self._configure_fp16_optimizer(basic_optimizer, lp_dtype)
         elif optimizer_wrapper == BFLOAT16:
             self.optimizer = self._configure_bf16_optimizer(basic_optimizer)
         else:
@@ -1641,7 +1633,7 @@ class DeepSpeedEngine(Module):
         )
         return quantizer
 
-    def _configure_fp16_optimizer(self, optimizer):
+    def _configure_fp16_optimizer(self, optimizer, low_precision_dtype):
         initial_dynamic_scale = self.initial_dynamic_scale()
         dynamic_loss_args = self.dynamic_loss_scale_args()
         clip_grad = self.gradient_clipping()
@@ -1659,6 +1651,7 @@ class DeepSpeedEngine(Module):
             optimizer = FP16_Optimizer(
                 optimizer,
                 deepspeed=self,
+                low_precision_dtype=low_precision_dtype,
                 dynamic_loss_scale=True,
                 initial_dynamic_scale=initial_dynamic_scale,
                 dynamic_loss_args=dynamic_loss_args,
@@ -1674,6 +1667,7 @@ class DeepSpeedEngine(Module):
             optimizer = FP16_Optimizer(
                 optimizer,
                 deepspeed=self,
+                low_precision_dtype=low_precision_dtype,
                 static_loss_scale=self.loss_scale(),
                 mpu=self.mpu,
                 clip_grad=clip_grad,
@@ -40,6 +40,7 @@ class FP16_Optimizer(DeepSpeedOptimizer):
     def __init__(self,
                  init_optimizer,
                  deepspeed=None,
+                 low_precision_dtype=torch.float16,
                  static_loss_scale=1.0,
                  dynamic_loss_scale=False,
                  initial_dynamic_scale=2**32,
@@ -53,11 +54,13 @@ class FP16_Optimizer(DeepSpeedOptimizer):
 
         self.fused_adam_legacy = fused_adam_legacy
         self.timers = timers
-        self.deepspeed = deepspeed
         self.has_moe_layers = has_moe_layers
+        self.deepspeed = deepspeed
+        self.using_pipeline = getattr(self.deepspeed, 'pipeline_parallelism', False)
+        self.low_precision_dtype = low_precision_dtype
+        self.use_grad_scaling = low_precision_dtype == torch.float16
         if not get_accelerator().is_available():
-            raise SystemError("Cannot use fp16 without accelerator.")
+            raise SystemError("Cannot use {low_precision_dtype} without accelerator.")
         self.optimizer = init_optimizer
 
         # param flattened by groups
@@ -86,24 +89,27 @@ class FP16_Optimizer(DeepSpeedOptimizer):
             param_group['params'] = [self.fp32_groups_flat[i]]
 
         # we may have a way of fusing dynamic scale. Do not support for now
-        if dynamic_loss_scale:
-            self.dynamic_loss_scale = True
-            self.cur_iter = 0
-            self.last_overflow_iter = -1
-            self.scale_factor = 2
+        if self.use_grad_scaling:
+            if dynamic_loss_scale:
+                self.dynamic_loss_scale = True
+                self.cur_iter = 0
+                self.last_overflow_iter = -1
+                self.scale_factor = 2
 
-            if dynamic_loss_args is None:
-                self.cur_scale = initial_dynamic_scale
-                self.scale_window = 1000
-                self.min_loss_scale = 1
-            else:
-                self.cur_scale = dynamic_loss_args[INITIAL_LOSS_SCALE]
-                self.scale_window = dynamic_loss_args[SCALE_WINDOW]
-                self.min_loss_scale = dynamic_loss_args[MIN_LOSS_SCALE]
-        else:
-            self.dynamic_loss_scale = False
-            self.cur_iter = 0
-            self.cur_scale = static_loss_scale
+                if dynamic_loss_args is None:
+                    self.cur_scale = initial_dynamic_scale
+                    self.scale_window = 1000
+                    self.min_loss_scale = 1
+                else:
+                    self.cur_scale = dynamic_loss_args[INITIAL_LOSS_SCALE]
+                    self.scale_window = dynamic_loss_args[SCALE_WINDOW]
+                    self.min_loss_scale = dynamic_loss_args[MIN_LOSS_SCALE]
+            else:
+                self.dynamic_loss_scale = False
+                self.cur_iter = 0
+                self.cur_scale = static_loss_scale
+        else:
+            self.cur_scale = 1.0
         self.verbose = verbose
 
         self.custom_loss_scaler = False
@@ -166,14 +172,15 @@ class FP16_Optimizer(DeepSpeedOptimizer):
             norm_groups.append(get_weight_norm(grads_groups_flat[i], mpu=self.mpu))
 
         self.overflow = self.overflow_checker.check_using_norm(norm_groups)
-        prev_scale = self.cur_scale
-        self._update_scale(self.overflow)
+        if self.use_grad_scaling:
+            prev_scale = self.cur_scale
+            self._update_scale(self.overflow)
 
-        if self.overflow:
-            if self.verbose:
-                logger.info("[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
-                            "scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
-            return self.overflow
+            if self.overflow:
+                if self.verbose:
+                    logger.info("[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
+                                "scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
+                return self.overflow
 
         scaled_grad_norm = get_global_norm(norm_list=norm_groups)
 
@@ -204,6 +211,8 @@ class FP16_Optimizer(DeepSpeedOptimizer):
         return self.optimizer.param_groups[0]["lr"]
 
     def override_loss_scale(self, loss_scale):
+        assert self.use_grad_scaling, f"Loss scale overriding only supported for torch.float16, rather than {self.low_precision_dtype}"
+
         if loss_scale != self.external_loss_scale:
             logger.info(f'[deepspeed] setting loss scale from {self.external_loss_scale} -> {loss_scale}')
         self.custom_loss_scaler = True
@@ -260,18 +269,20 @@ class FP16_Optimizer(DeepSpeedOptimizer):
         self.overflow = self.overflow_checker.has_overflow(fp16_params)
         if self.timers:
             self.timers(OVERFLOW_CHECK_TIMER).stop()
-        prev_scale = self.cur_scale
-        self._update_scale(self.overflow)
-        if self.overflow:
-            if self.verbose:
-                log_dist(
-                    "Overflow detected. Skipping step. Attempted loss "
-                    f"scale: {prev_scale}, reducing to {self.cur_scale}",
-                    ranks=[0])
-            # Clear gradients
-            for i, group in enumerate(self.fp16_groups):
-                for p in group:
-                    p.grad = None
+
+        if self.use_grad_scaling:
+            prev_scale = self.cur_scale
+            self._update_scale(self.overflow)
+            if self.overflow:
+                if self.verbose:
+                    log_dist(
+                        "Overflow detected. Skipping step. Attempted loss "
+                        f"scale: {prev_scale}, reducing to {self.cur_scale}",
+                        ranks=[0])
+                # Clear gradients
+                for i, group in enumerate(self.fp16_groups):
+                    for p in group:
+                        p.grad = None
 
         if self.timers:
             self.timers.log(OVERFLOW_TIMERS)
@@ -449,13 +460,14 @@ class FP16_Optimizer(DeepSpeedOptimizer):
             torch.save(checkpoint, "saved.pth")
         """
         state_dict = {}
-        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
-        state_dict['cur_scale'] = self.cur_scale
-        state_dict['cur_iter'] = self.cur_iter
-        if state_dict['dynamic_loss_scale']:
-            state_dict['last_overflow_iter'] = self.last_overflow_iter
-            state_dict['scale_factor'] = self.scale_factor
-            state_dict['scale_window'] = self.scale_window
+        if self.use_grad_scaling:
+            state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
+            state_dict['cur_scale'] = self.cur_scale
+            state_dict['cur_iter'] = self.cur_iter
+            if state_dict['dynamic_loss_scale']:
+                state_dict['last_overflow_iter'] = self.last_overflow_iter
+                state_dict['scale_factor'] = self.scale_factor
+                state_dict['scale_window'] = self.scale_window
         state_dict[OPTIMIZER_STATE_DICT] = self.optimizer.state_dict()
         state_dict['fp32_groups_flat'] = self.fp32_groups_flat
         state_dict[CLIP_GRAD] = self.clip_grad
@@ -483,13 +495,14 @@ class FP16_Optimizer(DeepSpeedOptimizer):
             optimizer.load_state_dict(checkpoint['optimizer'])
         """
         # I think it should actually be ok to reload the optimizer before the model.
-        self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
-        self.cur_scale = state_dict['cur_scale']
-        self.cur_iter = state_dict['cur_iter']
-        if state_dict['dynamic_loss_scale']:
-            self.last_overflow_iter = state_dict['last_overflow_iter']
-            self.scale_factor = state_dict['scale_factor']
-            self.scale_window = state_dict['scale_window']
+        if self.use_grad_scaling:
+            self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
+            self.cur_scale = state_dict['cur_scale']
+            self.cur_iter = state_dict['cur_iter']
+            if state_dict['dynamic_loss_scale']:
+                self.last_overflow_iter = state_dict['last_overflow_iter']
+                self.scale_factor = state_dict['scale_factor']
+                self.scale_window = state_dict['scale_window']
         if load_optimizer_states:
             self.optimizer.load_state_dict(state_dict[OPTIMIZER_STATE_DICT])
         self.clip_grad = state_dict[CLIP_GRAD]
@@ -515,12 +528,16 @@ class FP16_Optimizer(DeepSpeedOptimizer):
 
     # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
     def _get_loss_scale(self):
+        if not self.use_grad_scaling:
+            return None
+
         if self.custom_loss_scaler:
             return self.external_loss_scale
         else:
             return self.cur_scale
 
     def _set_loss_scale(self, value):
-        self.loss_scaler.cur_scale = value
+        if self.use_grad_scaling:
+            self.loss_scaler.cur_scale = value
 
     loss_scale = property(_get_loss_scale, _set_loss_scale)
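Taken together, the `use_grad_scaling` gates above make the loss-scaling and scale-related checkpoint state inert for the new bf16 DDP wrapper. A minimal sketch of the observable behaviour, assuming bf16 at ZeRO stage 0 as in the example near the top (config values are placeholders, launch via the `deepspeed` launcher is assumed):

```python
import torch
import deepspeed

model = torch.nn.Linear(8, 8)
ds_config = {
    "train_batch_size": 8,
    "zero_optimization": {"stage": 0},
    "bf16": {"enabled": True},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
}
engine, opt, _, _ = deepspeed.initialize(model=model, config=ds_config,
                                         model_parameters=model.parameters())

assert opt.low_precision_dtype is torch.bfloat16
assert opt.use_grad_scaling is False          # bf16 needs no loss scaling
print(opt.loss_scale)                         # None: _get_loss_scale short-circuits
opt.loss_scale = 4096.0                       # ignored: _set_loss_scale is gated
assert 'cur_scale' not in opt.state_dict()    # scale state is skipped when saving
```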
@@ -33,6 +33,46 @@ Gradient Accumulation
 .. autofunction:: deepspeed.DeepSpeedEngine.is_gradient_accumulation_boundary
 
 
+Mixed Precision Training
+-------------------------
+DeepSpeed supports mixed precision training using either native or PyTorch mechanisms. The desired mixed precision mode can be selected through the configuration dict.
+Mixed precision training can be used with ZeRO (i.e., stages > 0) and without ZeRO (i.e., stage=0).
+
+
+Native Mixed Precision
+======================================================
+DeepSpeed provides native support for
+`fp16 <https://www.deepspeed.ai/docs/config-json/#fp16-training-options>`_ and `bf16 <https://www.deepspeed.ai/docs/config-json/#bfloat16-training-options>`_ mixed precision training.
+
+
+PyTorch Automatic Mixed Precision (AMP)
+======================================================
+DeepSpeed provides torch-compatible automatic mixed precision (AMP) training via
+`torch.autocast <https://docs.pytorch.org/docs/stable/amp.html>`_ functionality. The following snippet illustrates how to enable Torch AMP.
+
+.. code-block:: python
+
+    {
+        "torch_autocast": {
+            "enabled": true,
+            "dtype": "bfloat16",
+            "lower_precision_safe_modules": ["torch.nn.Linear", "torch.nn.Conv2d"]
+        },
+        ...
+    }
+
+Each configuration item works as follows:
+
+* ``enabled``: Enable ``torch.autocast`` when set to ``True``. You don't need to call ``torch.autocast`` in your code. The grad scaler is also applied in the DeepSpeed optimizer.
+* ``dtype``: Lower precision dtype passed to ``torch.autocast``. Gradients for all-reduce (reduce-scatter) and parameters for all-gather (only for ZeRO3) of ``lower_precision_safe_modules`` are also downcast to this ``dtype``.
+* ``lower_precision_safe_modules``: The list of modules that will be downcast for all-reduce (reduce-scatter) and all-gather (ZeRO3). The precision for PyTorch operators in forward/backward follows ``torch.autocast``'s policy, not this list. If you don't set this item, DeepSpeed uses the default list: ``[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d]``.
+
+.. autofunction:: deepspeed.runtime.torch_autocast.init_autocast_params
+.. autofunction:: deepspeed.runtime.torch_autocast.is_autocast_initialized
+.. autofunction:: deepspeed.runtime.torch_autocast.get_default_autocast_lower_precision_modules
+.. autofunction:: deepspeed.runtime.torch_autocast.has_autocast_dtype
+
+
 Model Saving
 ------------
 .. autofunction:: deepspeed.DeepSpeedEngine.save_16bit_model
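To make the documented configuration concrete, a minimal sketch of driving `torch_autocast` from user code at ZeRO stage 0. The model, batch size, and optimizer settings are placeholders, and a launch via the `deepspeed` launcher is assumed; no explicit `torch.autocast` context is needed in the training loop.

```python
import torch
import deepspeed

ds_config = {
    "train_batch_size": 16,
    "zero_optimization": {"stage": 0},          # plain DDP, enabled by this PR
    "torch_autocast": {
        "enabled": True,
        "dtype": "bfloat16",
        "lower_precision_safe_modules": ["torch.nn.Linear", "torch.nn.Conv2d"],
    },
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
}

model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 4))
engine, _, _, _ = deepspeed.initialize(model=model, config=ds_config,
                                       model_parameters=model.parameters())

# The engine applies torch.autocast internally; inputs stay fp32.
x = torch.randn(16, 32, device=engine.device)
loss = engine(x).float().pow(2).mean()
engine.backward(loss)
engine.step()
```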
@@ -229,7 +229,7 @@ class TestOptimizerImplementation(DistributedTest):
         is_supported[(None, 'fp16', None)] = True
         is_supported[(None, 'fp16', 'fp16')] = True
         # BF16 Wrapper
         is_supported[(None, 'bf16', 'fp32')] = True
+        is_supported[(None, 'bf16', 'bf16')] = True
         is_supported[(None, 'bf16', None)] = True
         # No Wrapper
         is_supported[(None, 'fp32', None)] = True
@@ -139,7 +139,7 @@ def compare_loss(model_cls, enable, zero_stage, dtype, autocast_conf, enable_aut
 class TestZeroAutoCast(DistributedTest):
     world_size = 2
 
-    @pytest.mark.parametrize("zero_stage", [1, 2, 3])
+    @pytest.mark.parametrize("zero_stage", [0, 1, 2, 3])
     @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
     def test(self, enable, zero_stage, dtype):
         lower_precision_safe_modules = [torch.nn.Linear]
@@ -147,7 +147,7 @@ class TestZeroAutoCast(DistributedTest):
 
         compare_loss(SimpleModel, enable, zero_stage, dtype, autocast_conf, False, lower_precision_safe_modules)
 
-    @pytest.mark.parametrize("zero_stage", [1, 2, 3])
+    @pytest.mark.parametrize("zero_stage", [0, 1, 2, 3])
     @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
     def test_safe_modules_conf(self, enable, zero_stage, dtype):
         lower_precision_safe_modules = [torch.nn.Linear]