Enable non-ZeRO mode (#7515)

Enabled via `stage=0`, which corresponds to DDP.
Remove hardwired path to bf16_optimizer.
Enable `torch.autocast` for DDP training.
Enable native mixed precision DDP for bfloat16.
Update `torch.autocast` and native mixed precision UTs.
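
For context, here is a minimal sketch of a DeepSpeed setup that exercises the new path (ZeRO `stage=0`, i.e. plain DDP, with native bf16 mixed precision); the config values are illustrative only and not part of this change:

```python
import deepspeed

# Illustrative config: stage 0 means no ZeRO partitioning, i.e. plain DDP.
ds_config = {
    "train_batch_size": 8,
    "zero_optimization": {"stage": 0},
    "bf16": {"enabled": True},   # native bf16 mixed precision (no loss scaling)
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-4}},
}

# `model` is assumed to be a torch.nn.Module defined elsewhere.
engine, optimizer, _, _ = deepspeed.initialize(
    model=model, model_parameters=model.parameters(), config=ds_config)
```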


---------

Signed-off-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Commit 889f0ead27 (parent 66ad278048), authored by Olatunji Ruwase and committed via GitHub on 2025-08-27 14:07:29 -04:00.
7 changed files with 133 additions and 79 deletions.


@@ -33,7 +33,7 @@ app = modal.App("deepspeedai-torch-latest-ci", image=image)
def pytest():
import subprocess
subprocess.run(
"pytest -n 4 --verbose tests/unit/runtime/zero/test_zero.py tests/unit/runtime/half_precision/test_bf16.py --torch_ver=2.6 --cuda_ver=12.4".split(),
"pytest -n 4 --verbose tests/unit/runtime/zero/test_zero.py tests/unit/runtime/half_precision/test_bf16.py tests/unit/runtime/zero/test_zero_autocast.py --torch_ver=2.6 --cuda_ver=12.4".split(),
check=True,
cwd=ROOT_PATH / ".",
)


@@ -137,6 +137,9 @@ BFLOAT16_CHECK_OVERFLOW_DEFAULT = False
BFLOAT16_IMMEDIATE_GRAD_UPDATE = "immediate_grad_update"
BFLOAT16_IMMEDIATE_GRAD_UPDATE_DEFAULT = True
# DDP variant of BFLOAT16
DDP_BFLOAT16 = "bf16"
#########################################
# FP16 support
#########################################


@@ -54,7 +54,7 @@ from deepspeed.runtime.zero.muon.muon_optimizer import MuonWithAuxAdam
from deepspeed.runtime.constants import \
ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \
PLD_THETA, PLD_GAMMA, BFLOAT16, FP16, AMP, GRADIENT_ACCUMULATION_STEPS, \
DATA_PARALLEL_GROUP, GLOBAL_RANK
DATA_PARALLEL_GROUP, GLOBAL_RANK, DDP_BFLOAT16
from deepspeed.runtime.zero.config import ZeroStageEnum
from deepspeed.compression import compression_scheduler
from deepspeed.compression.constants import \
@@ -1091,13 +1091,9 @@ class DeepSpeedEngine(Module):
model_dtype = torch.bfloat16
if self._config.grad_accum_dtype is None:
if model_dtype == torch.bfloat16 and not self.zero_optimization():
grad_accum_dtype = torch.float32
else:
grad_accum_dtype = model_dtype
grad_accum_dtype = model_dtype
else:
grad_accum_dtype = DtypeEnum(self._config.grad_accum_dtype).value
return (model_dtype, grad_accum_dtype)
def _optimizer_has_ckpt_event_prologue(self):
@@ -1139,7 +1135,7 @@ class DeepSpeedEngine(Module):
or (self.zero_optimization_partition_weights() and self.is_first_weights_partition_group()):
self.save_non_zero_checkpoint = True
if self.zero_optimization() or self.bfloat16_enabled():
if hasattr(self.optimizer, 'dp_process_group'):
param_rank = dist.get_rank(group=self.optimizer.dp_process_group)
# Only the first parameter parallel process needs to store the
@@ -1407,23 +1403,18 @@ class DeepSpeedEngine(Module):
return AMP
# data type checks
elif model_dtype == grad_accum_dtype:
if model_dtype == torch.bfloat16:
if self.pipeline_parallelism:
logger.warning(
"**** BF16 gradient accumulation is not safe numerically with large number of accumulation steps, proceed with caution *****"
)
return BFLOAT16
else:
raise NotImplementedError(
"Bfloat16 wrapper must use a gradient accumulation type of fp32, enable ZeRO to use Bfloat16 gradient accumulation"
)
if model_dtype == torch.float16:
return FP16
# else optimizer_wrapper = None
if model_dtype == torch.float32:
return None
if model_dtype == torch.bfloat16 and self.pipeline_parallelism:
logger.warning(
"**** BF16 gradient accumulation is not safe numerically with large number of accumulation steps, proceed with caution *****"
)
return BFLOAT16
return FP16 if model_dtype == torch.float16 else DDP_BFLOAT16
elif model_dtype == torch.bfloat16 and grad_accum_dtype == torch.float32:
return BFLOAT16
else:
raise NotImplementedError("unsupported mix of model dtype and gradient accumulation type")
raise NotImplementedError(f"unsupported mix of {model_dtype=} and {grad_accum_dtype=}")
return None
@@ -1466,8 +1457,9 @@ class DeepSpeedEngine(Module):
self._set_client_model(model)
self._broadcast_model()
# TODO: maybe need to broadcast experts differently?
elif optimizer_wrapper == FP16:
self.optimizer = self._configure_fp16_optimizer(basic_optimizer)
elif optimizer_wrapper in [FP16, DDP_BFLOAT16]:
lp_dtype = torch.float16 if optimizer_wrapper == FP16 else torch.bfloat16
self.optimizer = self._configure_fp16_optimizer(basic_optimizer, lp_dtype)
elif optimizer_wrapper == BFLOAT16:
self.optimizer = self._configure_bf16_optimizer(basic_optimizer)
else:
@@ -1641,7 +1633,7 @@ class DeepSpeedEngine(Module):
)
return quantizer
def _configure_fp16_optimizer(self, optimizer):
def _configure_fp16_optimizer(self, optimizer, low_precision_dtype):
initial_dynamic_scale = self.initial_dynamic_scale()
dynamic_loss_args = self.dynamic_loss_scale_args()
clip_grad = self.gradient_clipping()
@@ -1659,6 +1651,7 @@ class DeepSpeedEngine(Module):
optimizer = FP16_Optimizer(
optimizer,
deepspeed=self,
low_precision_dtype=low_precision_dtype,
dynamic_loss_scale=True,
initial_dynamic_scale=initial_dynamic_scale,
dynamic_loss_args=dynamic_loss_args,
@@ -1674,6 +1667,7 @@ class DeepSpeedEngine(Module):
optimizer = FP16_Optimizer(
optimizer,
deepspeed=self,
low_precision_dtype=low_precision_dtype,
static_loss_scale=self.loss_scale(),
mpu=self.mpu,
clip_grad=clip_grad,


@@ -40,6 +40,7 @@ class FP16_Optimizer(DeepSpeedOptimizer):
def __init__(self,
init_optimizer,
deepspeed=None,
low_precision_dtype=torch.float16,
static_loss_scale=1.0,
dynamic_loss_scale=False,
initial_dynamic_scale=2**32,
@@ -53,11 +54,13 @@ class FP16_Optimizer(DeepSpeedOptimizer):
self.fused_adam_legacy = fused_adam_legacy
self.timers = timers
self.deepspeed = deepspeed
self.has_moe_layers = has_moe_layers
self.deepspeed = deepspeed
self.using_pipeline = getattr(self.deepspeed, 'pipeline_parallelism', False)
self.low_precision_dtype = low_precision_dtype
self.use_grad_scaling = low_precision_dtype == torch.float16
if not get_accelerator().is_available():
raise SystemError("Cannot use fp16 without accelerator.")
raise SystemError("Cannot use {low_precision_dtype} without accelerator.")
self.optimizer = init_optimizer
# param flattened by groups
@@ -86,24 +89,27 @@ class FP16_Optimizer(DeepSpeedOptimizer):
param_group['params'] = [self.fp32_groups_flat[i]]
# we may have a way of fusing dynamic scale. Do not support for now
if dynamic_loss_scale:
self.dynamic_loss_scale = True
self.cur_iter = 0
self.last_overflow_iter = -1
self.scale_factor = 2
if self.use_grad_scaling:
if dynamic_loss_scale:
self.dynamic_loss_scale = True
self.cur_iter = 0
self.last_overflow_iter = -1
self.scale_factor = 2
if dynamic_loss_args is None:
self.cur_scale = initial_dynamic_scale
self.scale_window = 1000
self.min_loss_scale = 1
if dynamic_loss_args is None:
self.cur_scale = initial_dynamic_scale
self.scale_window = 1000
self.min_loss_scale = 1
else:
self.cur_scale = dynamic_loss_args[INITIAL_LOSS_SCALE]
self.scale_window = dynamic_loss_args[SCALE_WINDOW]
self.min_loss_scale = dynamic_loss_args[MIN_LOSS_SCALE]
else:
self.cur_scale = dynamic_loss_args[INITIAL_LOSS_SCALE]
self.scale_window = dynamic_loss_args[SCALE_WINDOW]
self.min_loss_scale = dynamic_loss_args[MIN_LOSS_SCALE]
self.dynamic_loss_scale = False
self.cur_iter = 0
self.cur_scale = static_loss_scale
else:
self.dynamic_loss_scale = False
self.cur_iter = 0
self.cur_scale = static_loss_scale
self.cur_scale = 1.0
self.verbose = verbose
self.custom_loss_scaler = False
@@ -166,14 +172,15 @@ class FP16_Optimizer(DeepSpeedOptimizer):
norm_groups.append(get_weight_norm(grads_groups_flat[i], mpu=self.mpu))
self.overflow = self.overflow_checker.check_using_norm(norm_groups)
prev_scale = self.cur_scale
self._update_scale(self.overflow)
if self.use_grad_scaling:
prev_scale = self.cur_scale
self._update_scale(self.overflow)
if self.overflow:
if self.verbose:
logger.info("[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
"scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
return self.overflow
if self.overflow:
if self.verbose:
logger.info("[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
"scale: {}, reducing to {}".format(prev_scale, self.cur_scale))
return self.overflow
scaled_grad_norm = get_global_norm(norm_list=norm_groups)
@@ -204,6 +211,8 @@ class FP16_Optimizer(DeepSpeedOptimizer):
return self.optimizer.param_groups[0]["lr"]
def override_loss_scale(self, loss_scale):
assert self.use_grad_scaling, f"Loss scale overriding only supported for torch.float16, rather than {self.low_precision_dtype}"
if loss_scale != self.external_loss_scale:
logger.info(f'[deepspeed] setting loss scale from {self.external_loss_scale} -> {loss_scale}')
self.custom_loss_scaler = True
@@ -260,18 +269,20 @@ class FP16_Optimizer(DeepSpeedOptimizer):
self.overflow = self.overflow_checker.has_overflow(fp16_params)
if self.timers:
self.timers(OVERFLOW_CHECK_TIMER).stop()
prev_scale = self.cur_scale
self._update_scale(self.overflow)
if self.overflow:
if self.verbose:
log_dist(
"Overflow detected. Skipping step. Attempted loss "
f"scale: {prev_scale}, reducing to {self.cur_scale}",
ranks=[0])
# Clear gradients
for i, group in enumerate(self.fp16_groups):
for p in group:
p.grad = None
if self.use_grad_scaling:
prev_scale = self.cur_scale
self._update_scale(self.overflow)
if self.overflow:
if self.verbose:
log_dist(
"Overflow detected. Skipping step. Attempted loss "
f"scale: {prev_scale}, reducing to {self.cur_scale}",
ranks=[0])
# Clear gradients
for i, group in enumerate(self.fp16_groups):
for p in group:
p.grad = None
if self.timers:
self.timers.log(OVERFLOW_TIMERS)
@@ -449,13 +460,14 @@ class FP16_Optimizer(DeepSpeedOptimizer):
torch.save(checkpoint, "saved.pth")
"""
state_dict = {}
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
state_dict['cur_scale'] = self.cur_scale
state_dict['cur_iter'] = self.cur_iter
if state_dict['dynamic_loss_scale']:
state_dict['last_overflow_iter'] = self.last_overflow_iter
state_dict['scale_factor'] = self.scale_factor
state_dict['scale_window'] = self.scale_window
if self.use_grad_scaling:
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
state_dict['cur_scale'] = self.cur_scale
state_dict['cur_iter'] = self.cur_iter
if state_dict['dynamic_loss_scale']:
state_dict['last_overflow_iter'] = self.last_overflow_iter
state_dict['scale_factor'] = self.scale_factor
state_dict['scale_window'] = self.scale_window
state_dict[OPTIMIZER_STATE_DICT] = self.optimizer.state_dict()
state_dict['fp32_groups_flat'] = self.fp32_groups_flat
state_dict[CLIP_GRAD] = self.clip_grad
@@ -483,13 +495,14 @@ class FP16_Optimizer(DeepSpeedOptimizer):
optimizer.load_state_dict(checkpoint['optimizer'])
"""
# I think it should actually be ok to reload the optimizer before the model.
self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
self.cur_scale = state_dict['cur_scale']
self.cur_iter = state_dict['cur_iter']
if state_dict['dynamic_loss_scale']:
self.last_overflow_iter = state_dict['last_overflow_iter']
self.scale_factor = state_dict['scale_factor']
self.scale_window = state_dict['scale_window']
if self.use_grad_scaling:
self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
self.cur_scale = state_dict['cur_scale']
self.cur_iter = state_dict['cur_iter']
if state_dict['dynamic_loss_scale']:
self.last_overflow_iter = state_dict['last_overflow_iter']
self.scale_factor = state_dict['scale_factor']
self.scale_window = state_dict['scale_window']
if load_optimizer_states:
self.optimizer.load_state_dict(state_dict[OPTIMIZER_STATE_DICT])
self.clip_grad = state_dict[CLIP_GRAD]
@@ -515,12 +528,16 @@ class FP16_Optimizer(DeepSpeedOptimizer):
# Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
def _get_loss_scale(self):
if not self.use_grad_scaling:
return None
if self.custom_loss_scaler:
return self.external_loss_scale
else:
return self.cur_scale
def _set_loss_scale(self, value):
self.loss_scaler.cur_scale = value
if self.use_grad_scaling:
self.loss_scaler.cur_scale = value
loss_scale = property(_get_loss_scale, _set_loss_scale)


@@ -33,6 +33,46 @@ Gradient Accumulation
.. autofunction:: deepspeed.DeepSpeedEngine.is_gradient_accumulation_boundary
Mixed Precision Training
-------------------------
DeepSpeed supports mixed precision training using either its native mechanisms or PyTorch automatic mixed precision (AMP). The desired mixed precision mode is selected through the configuration dict.
Mixed precision training can be used both with ZeRO (i.e., stages > 0) and without ZeRO (i.e., stage=0, which corresponds to DDP).
Native Mixed Precision
======================================================
DeepSpeed provides native support for
`fp16 <https://www.deepspeed.ai/docs/config-json/#fp16-training-options>`_ and `bf16 <https://www.deepspeed.ai/docs/config-json/#bfloat16-training-options>`_ mixed precision training.
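For example, the following configuration sketches (field values are illustrative, not prescriptive) enable native bf16 or fp16 mixed precision; both can be combined with any ZeRO stage, including ``stage=0``:

.. code-block:: python

   # Native bf16 mixed precision (no loss scaling required).
   {
       "bf16": {
           "enabled": true
       },
       ...
   }

   # Native fp16 mixed precision with dynamic loss scaling.
   {
       "fp16": {
           "enabled": true,
           "initial_scale_power": 16,
           "loss_scale_window": 1000
       },
       ...
   }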
PyTorch Automatic Mixed Precision (AMP)
======================================================
DeepSpeed provides torch-compatible automatic mixed precision (AMP) training via
`torch.autocast <https://docs.pytorch.org/docs/stable/amp.html>`_ functionality. The following snippet illustrates how to enable Torch AMP.
.. code-block:: python

   {
       "torch_autocast": {
           "enabled": true,
           "dtype": "bfloat16",
           "lower_precision_safe_modules": ["torch.nn.Linear", "torch.nn.Conv2d"]
       },
       ...
   }
Each configuration option works as follows (a usage sketch follows the list):
* ``enabled``: Enable ``torch.autocast`` when set to ``True``. You don't need to call ``torch.autocast`` in your code. The grad scaler is also applied in the DeepSpeed optimizer.
* ``dtype``: Lower precision dtype passed to ``torch.autocast``. Gradients for all-reduce (reduce-scatter) and parameters for all-gather (only for ZeRO3) of ``lower_precision_safe_modules`` are also downcasted to this ``dtype``.
* ``lower_precision_safe_modules``: The list of modules that will be downcasted for all-reduce (reduce-scatter) and all-gather (ZeRO3). The precision for PyTorch operators in forward/backward follows ``torch.autocast``'s policy, not this list. If you don't set this item, DeepSpeed uses the default list: ``[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d]``.
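The sketch below (assuming ``model``, ``data_loader``, and a ``ds_config`` dict containing the ``torch_autocast`` section above) shows that no explicit ``torch.autocast`` context or ``GradScaler`` is needed in the training loop:

.. code-block:: python

   import deepspeed

   engine, optimizer, _, _ = deepspeed.initialize(
       model=model, model_parameters=model.parameters(), config=ds_config)

   for batch, labels in data_loader:
       loss = engine(batch, labels)  # model is assumed to return the loss; autocast is applied by the engine
       engine.backward(loss)         # gradient (un)scaling, if needed, is handled internally
       engine.step()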
.. autofunction:: deepspeed.runtime.torch_autocast.init_autocast_params
.. autofunction:: deepspeed.runtime.torch_autocast.is_autocast_initialized
.. autofunction:: deepspeed.runtime.torch_autocast.get_default_autocast_lower_precision_modules
.. autofunction:: deepspeed.runtime.torch_autocast.has_autocast_dtype
Model Saving
------------
.. autofunction:: deepspeed.DeepSpeedEngine.save_16bit_model


@@ -229,7 +229,7 @@ class TestOptimizerImplementation(DistributedTest):
is_supported[(None, 'fp16', None)] = True
is_supported[(None, 'fp16', 'fp16')] = True
# BF16 Wrapper
is_supported[(None, 'bf16', 'fp32')] = True
is_supported[(None, 'bf16', 'bf16')] = True
is_supported[(None, 'bf16', None)] = True
# No Wrapper
is_supported[(None, 'fp32', None)] = True


@@ -139,7 +139,7 @@ def compare_loss(model_cls, enable, zero_stage, dtype, autocast_conf, enable_aut
class TestZeroAutoCast(DistributedTest):
world_size = 2
@pytest.mark.parametrize("zero_stage", [1, 2, 3])
@pytest.mark.parametrize("zero_stage", [0, 1, 2, 3])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
def test(self, enable, zero_stage, dtype):
lower_precision_safe_modules = [torch.nn.Linear]
@@ -147,7 +147,7 @@ class TestZeroAutoCast(DistributedTest):
compare_loss(SimpleModel, enable, zero_stage, dtype, autocast_conf, False, lower_precision_safe_modules)
@pytest.mark.parametrize("zero_stage", [1, 2, 3])
@pytest.mark.parametrize("zero_stage", [0, 1, 2, 3])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_safe_modules_conf(self, enable, zero_stage, dtype):
lower_precision_safe_modules = [torch.nn.Linear]