Mirror of https://github.com/deepspeedai/DeepSpeed.git (synced 2025-10-20 15:33:51 +08:00)
Enable grad scaler for ZeRO-0 + torch.autocast path (#7619)
Currently, the DeepSpeed engine does not enable the grad scaler for the ZeRO-0 + `torch.autocast` path, even when dtype is set to `fp16`. This leads to errors in tests when we replace our hard-coded tolerances with PyTorch's [standard tolerances](https://docs.pytorch.org/docs/stable/testing.html#torch.testing.assert_close) (thank you @stas00 for your suggestion regarding the previous PR).

This PR enables the grad scaler for this path to improve accuracy, and refactors the tests to simplify validation by using `torch.testing.assert_close`. The tests now rely on PyTorch's standard (and stricter) tolerances, and they still pass.

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
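For context, this is roughly the configuration that exercises the newly covered path: ZeRO disabled (stage 0), DeepSpeed's `torch.autocast` integration enabled, and `fp16` as the autocast dtype. A minimal sketch only; the `torch_autocast` block follows the shape of DeepSpeed's documented autocast config, but treat the exact keys and dtype string as illustrative, and the model/optimizer settings are placeholders:

```python
import torch
import deepspeed

# Placeholder model; any torch.nn.Module works here.
model = torch.nn.Linear(8, 8)

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {"stage": 0},   # ZeRO-0: plain data parallelism
    "torch_autocast": {                  # assumed schema, per DeepSpeed's autocast docs
        "enabled": True,
        "dtype": "torch.float16",        # fp16 autocast: the path this PR fixes
    },
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
}

engine, optimizer, _, _ = deepspeed.initialize(model=model, config=ds_config)
```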
@@ -336,8 +336,13 @@ class DeepSpeedEngine(Module):
         if not isinstance(model_parameters, list):
             model_parameters = list(model_parameters)
 
+        # grad scaler only for Z0 (no ZeRO) + fp16 + torch_autocast
+        # ZeRO1/2/3 optimizers have their own grad scaler logic
+        self.torch_autocast_z0_gradscaler = None
         if self.torch_autocast_enabled():
             init_autocast_params(self, self.torch_autocast_dtype(), self.torch_autocast_lower_precision_safe_modules())
+            if (not self.zero_optimization() and self.torch_autocast_dtype() == torch.float16):
+                self.torch_autocast_z0_gradscaler = torch.amp.GradScaler(device=get_accelerator().device_name())
 
         self._configure_zenflow = lambda: configure_zenflow(self)
         self._is_zenflow_update_boundary = lambda: is_zenflow_update_boundary(self)
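The `torch.float16` guard reflects why a scaler is needed at all: fp16's narrow exponent range lets small gradient values underflow to zero, whereas bf16 shares fp32's exponent range and needs no scaling. A quick standalone illustration:

```python
import torch

# fp16's smallest subnormal is ~6e-8, so a tiny gradient value vanishes:
print(torch.tensor(1e-8, dtype=torch.float16))           # tensor(0., dtype=torch.float16)

# Scaling first (e.g. by GradScaler's default init_scale of 2**16)
# keeps the value representable; it is unscaled again before the step:
print(torch.tensor(1e-8 * 2**16, dtype=torch.float16))   # ~0.00066, nonzero

# bf16 keeps fp32's exponent range, hence no scaler on the bf16 autocast path:
print(torch.tensor(1e-8, dtype=torch.bfloat16))          # ~1e-08, still nonzero
```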
@@ -2303,7 +2308,12 @@ class DeepSpeedEngine(Module):
         elif self.bfloat16_enabled():
             self.optimizer.backward(loss, retain_graph=retain_graph)
         else:
-            if self.eigenvalue_enabled():
+            if self.torch_autocast_z0_gradscaler:
+                if self.eigenvalue_enabled():
+                    self.torch_autocast_z0_gradscaler.scale(loss).backward(create_graph=True, retain_graph=True)
+                else:
+                    self.torch_autocast_z0_gradscaler.scale(loss).backward(retain_graph=retain_graph)
+            elif self.eigenvalue_enabled():
                 loss.backward(create_graph=True, retain_graph=True)
             else:
                 loss.backward(retain_graph=retain_graph)
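The nested branches above vary along two independent axes: scaler or no scaler, and eigenvalue mode (which needs a differentiable backward graph) or not. A condensed paraphrase of the new ZeRO-0 backward path, not the verbatim engine code:

```python
def _z0_backward_sketch(engine, loss, retain_graph=False):
    # Eigenvalue computation requires building a graph of the backward pass.
    create_graph = engine.eigenvalue_enabled()
    if engine.torch_autocast_z0_gradscaler:
        # Scale the loss so fp16 gradients survive the backward pass.
        loss = engine.torch_autocast_z0_gradscaler.scale(loss)
    loss.backward(create_graph=create_graph,
                  retain_graph=True if create_graph else retain_graph)
```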
@@ -2402,6 +2412,9 @@ class DeepSpeedEngine(Module):
 
     def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
         if self.gradient_clipping() > 0.0:
+            if self.torch_autocast_z0_gradscaler:
+                # Unscale for gradient clipping
+                self.torch_autocast_z0_gradscaler.unscale_(self.optimizer)
             if not (self.fp16_enabled() or self.bfloat16_enabled() or self.amp_enabled() or self.zero_optimization()):
                 self.clip_fp32_gradients()
             elif self.amp_enabled():
@@ -2409,7 +2422,11 @@ class DeepSpeedEngine(Module):
                 # https://nvidia.github.io/apex/advanced.html#gradient-clipping
                 master_params = amp.master_params(self.optimizer)
                 clip_grad_norm_(parameters=master_params, max_norm=self.gradient_clipping(), mpu=self.mpu)
-        self.optimizer.step()
+        if self.torch_autocast_z0_gradscaler:
+            self.torch_autocast_z0_gradscaler.step(self.optimizer)
+            self.torch_autocast_z0_gradscaler.update()
+        else:
+            self.optimizer.step()
 
         if hasattr(self.optimizer, '_global_grad_norm'):
             self._global_grad_norm = self.optimizer._global_grad_norm
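Taken together, the two hunks above follow the standard `torch.amp.GradScaler` recipe from the PyTorch AMP docs: unscale once before clipping, then let the scaler drive the optimizer step (it skips the step when gradients contain inf/NaN) and update the scale factor. The same pattern in isolation, with `model`, `optimizer`, and `loss_fn` as placeholders:

```python
import torch

scaler = torch.amp.GradScaler(device="cuda")

def train_step(model, optimizer, loss_fn, x, y, max_norm=1.0):
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = loss_fn(model(x), y)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # gradients must be unscaled before clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    scaler.step(optimizer)      # skipped internally if grads contain inf/NaN
    scaler.update()             # adjusts the scale factor for the next step
```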
@@ -10,7 +10,7 @@ import pytest
 import torch
 from torch.nn.parallel import DistributedDataParallel as DDP
 
-from unit.common import DistributedTest, enable_determinism, reduce_boolean_flags
+from unit.common import DistributedTest, enable_determinism, allclose_on_all_ranks
 from unit.simple_model import SimpleModel
 from unit.util import bf16_required_version_check
 
@@ -19,9 +19,6 @@ from deepspeed.accelerator import get_accelerator
 from deepspeed.runtime.zero import GatheredParameters
 from deepspeed.runtime.torch_autocast import PARAM_COMM_DTYPE_ATTR_NAME, get_comm_dtype
 
-RTOL = 0.1
-ATOL = 0.0
-
 
 def cls_to_qualname(cls):
     return f"{cls.__module__}.{cls.__name__}"
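With the hand-rolled `RTOL = 0.1` / `ATOL = 0.0` gone, the tests inherit `torch.testing.assert_close`'s dtype-based defaults, which are far stricter (for `float16`, `rtol=1e-3` and `atol=1e-5` per the PyTorch testing docs). Roughly:

```python
import torch

a = torch.tensor([1.0000], dtype=torch.float16)
b = torch.tensor([1.0010], dtype=torch.float16)

# Passes: the difference is within the documented fp16 defaults
# (rtol=1e-3, atol=1e-5), picked automatically from the dtype.
torch.testing.assert_close(a, b)

# The old hand-rolled tolerance was about 100x looser than this.
assert torch.allclose(a.float(), b.float(), rtol=0.1, atol=0.0)
```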
@@ -42,7 +39,7 @@ class SimpleModelWithLayerNorm(torch.nn.Module):
 
 
 def step_amp(enabled, baseline_model, baseline_optimizer, target_engine, dtype, enable_autocast_outside,
-             baseline_scaler, step, x, y, rtol, atol, expect_match):
+             baseline_scaler, step, x, y, expect_match):
     device_type = get_accelerator().device_name()
 
     # Runs the forward pass with autocasting.
@@ -60,13 +57,7 @@ def step_amp(enabled, baseline_model, baseline_optimizer, target_engine, dtype,
 
     # reduce-scatter in `dtype` makes a difference in the loss.
     if step <= 1 and expect_match:
-        allclose_local = torch.allclose(baseline_loss.float(), target_loss.float(), rtol=rtol, atol=atol)
-        if not allclose_local:
-            print(f"Losses do not match: baseline_loss={baseline_loss}, target_loss={target_loss}")
-        # Ensure all ranks either pass or fail together.
-        # If some ranks fail while others pass, subsequent tests or iterations may hang.
-        if not reduce_boolean_flags(allclose_local, all):
-            assert False, f"Losses do not match on one or more ranks."
+        allclose_on_all_ranks(baseline_loss, target_loss)
 
     target_engine.backward(target_loss)
     target_engine.step()
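The removed inline logic is folded into `allclose_on_all_ranks`, which presumably keeps the same all-ranks-agree semantics while switching to `torch.testing.assert_close`. A hypothetical sketch of such a helper (the real one lives in `unit.common` and may differ):

```python
import torch
import torch.distributed as dist

def allclose_on_all_ranks_sketch(actual, expected,
                                 msg="Losses do not match on one or more ranks."):
    # Compare locally with PyTorch's standard dtype-based tolerances.
    try:
        torch.testing.assert_close(actual.float(), expected.float())
        ok = True
    except AssertionError as e:
        print(f"rank {dist.get_rank()}: {e}")
        ok = False
    # AND the verdict across ranks: if some ranks raised while others
    # passed, subsequent collectives could hang.
    flag = torch.tensor(int(ok), device=actual.device)
    dist.all_reduce(flag, op=dist.ReduceOp.MIN)
    assert flag.item() == 1, msg
```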
@@ -139,7 +130,7 @@ def compare_loss(model_cls,
 
     for i, (x, y) in enumerate(zip(xs, ys)):
         step_amp(enable, baseline_model, baseline_optimizer, target_engine, dtype, enable_autocast_outside,
-                 baseline_scaler, i, x, y, RTOL, ATOL, expect_match)
+                 baseline_scaler, i, x, y, expect_match)
 
     for module in target_engine.modules():
         for p in module.parameters(recurse=False):