Enable grad scaler for ZeRO-0 + torch.autocast path (#7619)

Currently, the DeepSpeed engine does not enable the grad scaler for the
ZeRO-0 + `torch.autocast` path, even when the dtype is set to `fp16`. This
leads to test failures when we replace our hard-coded tolerances with
PyTorch's [standard
tolerances](https://docs.pytorch.org/docs/stable/testing.html#torch.testing.assert_close)
(thank you, @stas00, for your suggestion on the previous PR).
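
For context, `torch.testing.assert_close` derives its default `rtol`/`atol` from the input dtypes (roughly `rtol=1e-3, atol=1e-5` for `float16`, per the tolerance table linked above), which is much stricter than the hand-tuned `RTOL = 0.1` / `ATOL = 0.0` the tests used before. A minimal illustration with made-up values, not taken from the test suite:

```python
import torch

baseline_loss = torch.tensor(2.3457, dtype=torch.float16)
target_loss = torch.tensor(2.3457, dtype=torch.float16)

# Uses the dtype-based default tolerances; raises AssertionError on mismatch.
torch.testing.assert_close(baseline_loss, target_loss)

# The old hand-coded check was far looser:
assert torch.allclose(baseline_loss.float(), target_loss.float(), rtol=0.1, atol=0.0)
```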

This PR enables the grad scaler for this path to improve accuracy, and
refactors the tests to simplify validation by using
`torch.testing.assert_close`. The tests now rely on PyTorch’s standard
(and stricter) tolerances, and they still pass.
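
Concretely, the new `torch_autocast_z0_gradscaler` path follows the standard `torch.amp.GradScaler` recipe: scale the loss before `backward()`, unscale before gradient clipping, step through the scaler (which skips the step on inf/NaN gradients), then update the scale. A standalone sketch of that pattern, not DeepSpeed code, assuming a CUDA device and a placeholder model/optimizer:

```python
import torch
from torch import nn

device = "cuda"
model = nn.Linear(8, 8).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.amp.GradScaler(device=device)

for _ in range(3):
    x = torch.randn(4, 8, device=device)
    with torch.autocast(device_type=device, dtype=torch.float16):
        loss = model(x).float().pow(2).mean()
    scaler.scale(loss).backward()    # engine.backward(): backward on the scaled loss
    scaler.unscale_(optimizer)       # _take_model_step(): unscale so clipping sees true grads
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)           # skips optimizer.step() if inf/NaN grads were found
    scaler.update()                  # adjusts the loss scale for the next iteration
    optimizer.zero_grad()
```

ZeRO stages 1-3 are unchanged because their optimizers already carry their own loss-scaling logic, as noted in the code comment below.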

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Author: Masahiro Tanaka
Date: 2025-10-04 06:21:08 -07:00
Committed by: GitHub
Parent: 65322e103c
Commit: 71d077da73
2 changed files with 23 additions and 15 deletions


@@ -336,8 +336,13 @@ class DeepSpeedEngine(Module):
         if not isinstance(model_parameters, list):
             model_parameters = list(model_parameters)

+        # grad scaler only for Z0 (no ZeRO) + fp16 + torch_autocast
+        # ZeRO1/2/3 optimizers have their own grad scaler logic
+        self.torch_autocast_z0_gradscaler = None
         if self.torch_autocast_enabled():
             init_autocast_params(self, self.torch_autocast_dtype(), self.torch_autocast_lower_precision_safe_modules())
+            if (not self.zero_optimization() and self.torch_autocast_dtype() == torch.float16):
+                self.torch_autocast_z0_gradscaler = torch.amp.GradScaler(device=get_accelerator().device_name())

         self._configure_zenflow = lambda: configure_zenflow(self)
         self._is_zenflow_update_boundary = lambda: is_zenflow_update_boundary(self)
@@ -2303,7 +2308,12 @@ class DeepSpeedEngine(Module):
         elif self.bfloat16_enabled():
             self.optimizer.backward(loss, retain_graph=retain_graph)
         else:
-            if self.eigenvalue_enabled():
+            if self.torch_autocast_z0_gradscaler:
+                if self.eigenvalue_enabled():
+                    self.torch_autocast_z0_gradscaler.scale(loss).backward(create_graph=True, retain_graph=True)
+                else:
+                    self.torch_autocast_z0_gradscaler.scale(loss).backward(retain_graph=retain_graph)
+            elif self.eigenvalue_enabled():
                 loss.backward(create_graph=True, retain_graph=True)
             else:
                 loss.backward(retain_graph=retain_graph)
@@ -2402,6 +2412,9 @@ class DeepSpeedEngine(Module):
     def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
         if self.gradient_clipping() > 0.0:
+            if self.torch_autocast_z0_gradscaler:
+                # Unscale for gradient clipping
+                self.torch_autocast_z0_gradscaler.unscale_(self.optimizer)
             if not (self.fp16_enabled() or self.bfloat16_enabled() or self.amp_enabled() or self.zero_optimization()):
                 self.clip_fp32_gradients()
             elif self.amp_enabled():
@@ -2409,7 +2422,11 @@ class DeepSpeedEngine(Module):
                 # https://nvidia.github.io/apex/advanced.html#gradient-clipping
                 master_params = amp.master_params(self.optimizer)
                 clip_grad_norm_(parameters=master_params, max_norm=self.gradient_clipping(), mpu=self.mpu)
-        self.optimizer.step()
+        if self.torch_autocast_z0_gradscaler:
+            self.torch_autocast_z0_gradscaler.step(self.optimizer)
+            self.torch_autocast_z0_gradscaler.update()
+        else:
+            self.optimizer.step()

         if hasattr(self.optimizer, '_global_grad_norm'):
             self._global_grad_norm = self.optimizer._global_grad_norm


@@ -10,7 +10,7 @@ import pytest
 import torch
 from torch.nn.parallel import DistributedDataParallel as DDP

-from unit.common import DistributedTest, enable_determinism, reduce_boolean_flags
+from unit.common import DistributedTest, enable_determinism, allclose_on_all_ranks
 from unit.simple_model import SimpleModel
 from unit.util import bf16_required_version_check
@@ -19,9 +19,6 @@ from deepspeed.accelerator import get_accelerator
 from deepspeed.runtime.zero import GatheredParameters
 from deepspeed.runtime.torch_autocast import PARAM_COMM_DTYPE_ATTR_NAME, get_comm_dtype

-RTOL = 0.1
-ATOL = 0.0
-

 def cls_to_qualname(cls):
     return f"{cls.__module__}.{cls.__name__}"
@@ -42,7 +39,7 @@ class SimpleModelWithLayerNorm(torch.nn.Module):
 def step_amp(enabled, baseline_model, baseline_optimizer, target_engine, dtype, enable_autocast_outside,
-             baseline_scaler, step, x, y, rtol, atol, expect_match):
+             baseline_scaler, step, x, y, expect_match):
     device_type = get_accelerator().device_name()

     # Runs the forward pass with autocasting.
@@ -60,13 +57,7 @@ def step_amp(enabled, baseline_model, baseline_optimizer, target_engine, dtype,
     # reduce-scatter in `dtype` makes a difference in the loss.
     if step <= 1 and expect_match:
-        allclose_local = torch.allclose(baseline_loss.float(), target_loss.float(), rtol=rtol, atol=atol)
-        if not allclose_local:
-            print(f"Losses do not match: baseline_loss={baseline_loss}, target_loss={target_loss}")
-        # Ensure all ranks either pass or fail together.
-        # If some ranks fail while others pass, subsequent tests or iterations may hang.
-        if not reduce_boolean_flags(allclose_local, all):
-            assert False, f"Losses do not match on one or more ranks."
+        allclose_on_all_ranks(baseline_loss, target_loss)

     target_engine.backward(target_loss)
     target_engine.step()
@@ -139,7 +130,7 @@ def compare_loss(model_cls,
     for i, (x, y) in enumerate(zip(xs, ys)):
         step_amp(enable, baseline_model, baseline_optimizer, target_engine, dtype, enable_autocast_outside,
-                 baseline_scaler, i, x, y, RTOL, ATOL, expect_match)
+                 baseline_scaler, i, x, y, expect_match)

     for module in target_engine.modules():
         for p in module.parameters(recurse=False):