Deprecate device-specific GradScaler autocast API (#126527)

# Motivation

## For `torch.amp.GradScaler`
- `torch.cpu.amp.GradScaler(args...)` is completely equivalent to `torch.amp.GradScaler("cpu", args...)`.
- `torch.cuda.amp.GradScaler(args...)` is completely equivalent to `torch.amp.GradScaler("cuda", args...)`.

So we intend to deprecate them and **strongly recommend** that developers use `torch.amp.GradScaler` instead.
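
Migration is a one-line change. A minimal sketch (assuming a CUDA build; `init_scale=2.0` is just an illustrative argument):

```python
import torch

# Deprecated, device-specific constructor (now emits a deprecation warning):
# scaler = torch.cuda.amp.GradScaler(init_scale=2.0)

# Device-agnostic replacement; pass "cpu" instead of "cuda" for the CPU variant.
scaler = torch.amp.GradScaler("cuda", init_scale=2.0)
```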

## For `custom_fwd` and `custom_bwd`
These decorators are a good solution for letting a custom autograd function run either with or without autocast's effects, even inside an autocast-enabled region, and the same mechanism can be shared by other backends such as CPU and XPU.
So we generalize them to be device-agnostic, move them into `torch/amp/autocast_mode.py`, and re-expose them as `torch.amp.custom_fwd` and `torch.amp.custom_bwd`. Meanwhile, we deprecate `torch.cuda.amp.custom_fwd` and `torch.cuda.amp.custom_bwd`.
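
A rough sketch of the generalized decorators applied to a custom autograd function; the `device_type` keyword shown here mirrors `torch.autocast` and is an assumption about the final signature landed by the follow-up PR:

```python
import torch
from torch.amp import custom_fwd, custom_bwd  # re-exposed, device-agnostic versions

class MulConstant(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type="cuda")  # previously: @torch.cuda.amp.custom_fwd
    def forward(ctx, x, constant):
        ctx.constant = constant
        return x * constant

    @staticmethod
    @custom_bwd(device_type="cuda")  # previously: @torch.cuda.amp.custom_bwd
    def backward(ctx, grad_output):
        # No gradient for the Python-number `constant` argument.
        return grad_output * ctx.constant, None
```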

# Additional Context
Add a unit test to cover the deprecation warning.
No additional unit tests are needed for the functionality of `torch.amp.custom_fwd`/`torch.amp.custom_bwd`; the existing tests that previously covered `torch.cuda.amp.custom_fwd`/`torch.cuda.amp.custom_bwd` cover them as well.
To facilitate review, we split these code changes into two PRs: the first covers `torch.amp.GradScaler`, and the follow-up covers `custom_fwd` and `custom_bwd`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126527
Approved by: https://github.com/jgong5, https://github.com/gujinghui, https://github.com/janeyx99, https://github.com/EikanWang
Yu, Guangye
2024-05-24 21:49:05 +00:00
committed by PyTorch MergeBot
parent ef86a27dba
commit c09205a057
8 changed files with 43 additions and 19 deletions

View File

@@ -2101,7 +2101,7 @@ class BenchmarkRunner:
# which is bad as Gradscaler has state and can adjust the scaling
# factor between eager and dynamo run, making accuracy check
# harder.
# self.grad_scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
# self.grad_scaler = torch.amp.GradScaler(device="cuda", init_scale=2.0)
self.autocast = functools.partial(
torch.amp.autocast, device_type=devices[0]
)

View File

@@ -19,18 +19,15 @@ are much faster in ``lower_precision_fp``. Other ops, like reductions, often req
range of ``float32``. Mixed precision tries to match each op to its appropriate datatype.
Ordinarily, "automatic mixed precision training" with datatype of ``torch.float16`` uses :class:`torch.autocast` and
:class:`torch.cpu.amp.GradScaler` or :class:`torch.cuda.amp.GradScaler` together, as shown in the :ref:`CUDA Automatic Mixed Precision examples<amp-examples>`
:class:`torch.amp.GradScaler` together, as shown in the :ref:`CUDA Automatic Mixed Precision examples<amp-examples>`
and `CUDA Automatic Mixed Precision recipe <https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html>`_.
However, :class:`torch.autocast` and :class:`torch.GradScaler` are modular, and may be used separately if desired.
As shown in the CPU example section of :class:`torch.autocast`, "automatic mixed precision training/inference" on CPU with
datatype of ``torch.bfloat16`` only uses :class:`torch.autocast`.
For CUDA and CPU, APIs are also provided separately:
* ``torch.autocast("cuda", args...)`` is equivalent to ``torch.cuda.amp.autocast(args...)``.
* ``torch.autocast("cpu", args...)`` is equivalent to ``torch.cpu.amp.autocast(args...)``. For CPU, only lower precision floating point datatype of ``torch.bfloat16`` is supported for now.
* ``torch.GradScaler("cuda", args...)`` is equivalent to ``torch.cuda.amp.GradScaler(args...)``.
* ``torch.GradScaler("cpu", args...)`` is equivalent to ``torch.cpu.amp.GradScaler(args...)``.
.. warning::
``torch.cuda.amp.autocast(args...)`` and ``torch.cpu.amp.autocast(args...)`` will be deprecated. Please use ``torch.autocast("cuda", args...)`` or ``torch.autocast("cpu", args...)`` instead.
``torch.cuda.amp.GradScaler(args...)`` and ``torch.cpu.amp.GradScaler(args...)`` will be deprecated. Please use ``torch.GradScaler("cuda", args...)`` or ``torch.GradScaler("cpu", args...)`` instead.
:class:`torch.autocast` and :class:`torch.cpu.amp.autocast` are new in version `1.10`.

View File

@@ -257,7 +257,7 @@ class TestShardedGradScalerParityWithDDP(FSDPTest):
use_orig_params=use_orig_params,
)
grad_scaler = ShardedGradScaler(init_scale=2.0)
ref_grad_scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
ref_grad_scaler = torch.amp.GradScaler(device="cuda", init_scale=2.0)
scaled_losses: List[torch.Tensor] = []
device = torch.device("cuda")
torch.manual_seed(42 + self.rank + 1)

View File

@@ -3437,13 +3437,15 @@ exit(2)
grads_graphed = [[g.clone() for g in gs] for gs in grads]
# Gradient Scaler
scaler_for_control = torch.cuda.amp.GradScaler(init_scale=128.0)
scaler_for_control = torch.amp.GradScaler(
device="cuda", init_scale=128.0
)
with torch.no_grad():
scaler_for_control._lazy_init_scale_growth_tracker(
torch.device("cuda")
)
scaler_for_graphed = torch.cuda.amp.GradScaler()
scaler_for_graphed = torch.amp.GradScaler(device="cuda")
scaler_for_graphed.load_state_dict(scaler_for_control.state_dict())
with torch.no_grad():
scaler_for_graphed._lazy_init_scale_growth_tracker(
@@ -4722,7 +4724,7 @@ class TestCudaOptims(TestCase):
def test_graph_grad_scaling(self, device, dtype, optim_info, foreach, fused):
torch.cuda.empty_cache()
scaler = torch.cuda.amp.GradScaler(init_scale=4.0)
scaler = torch.amp.GradScaler(device="cuda", init_scale=4.0)
g = torch.cuda.CUDAGraph()
s = torch.cuda.Stream()

View File

@@ -1159,7 +1159,7 @@ t2.start()
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
def test_grad_scaling_scale(self):
scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
scaler = torch.amp.GradScaler(device="cuda", init_scale=2.0)
t0 = torch.full((1,), 4.0, dtype=torch.float32, device="cuda:0")
t1 = torch.full((1,), 4.0, dtype=torch.float32, device="cuda:1")
# Create some nested iterables of tensors on different devices.
@@ -1205,8 +1205,12 @@ t2.start()
opt_scaling1,
) = _create_scaling_models_optimizers(device=dev1)
scaler = torch.cuda.amp.GradScaler(
init_scale=128.0, growth_factor=2.0, enabled=enabled, growth_interval=1
scaler = torch.amp.GradScaler(
device="cuda",
init_scale=128.0,
growth_factor=2.0,
enabled=enabled,
growth_interval=1,
)
def run(model0, model1, optimizer0, optimizer1, try_scaling_api):

View File

@@ -6152,7 +6152,7 @@ else:
@onlyNativeDeviceTypes
def test_grad_scaler_pass_itself(self, device):
device = torch.device(device)
GradScaler = torch.cuda.amp.GradScaler if "cuda" == device.type else torch.cpu.amp.GradScaler
GradScaler = partial(torch.amp.GradScaler, device=device.type)
class _PlaceHolderOptimizer(torch.optim.Optimizer):
tester = self
@@ -6165,7 +6165,7 @@ else:
class Optimizer1(_PlaceHolderOptimizer):
def step(self, closure=None, *, grad_scaler=None):
self.tester.assertTrue(isinstance(grad_scaler, GradScaler))
self.tester.assertTrue(isinstance(grad_scaler, torch.amp.GradScaler))
self.tester.assertFalse(hasattr(self, "grad_scale"))
self.tester.assertFalse(hasattr(self, "found_inf"))
@@ -6189,6 +6189,17 @@ else:
scaler.step(o2)
scaler.update()
@onlyNativeDeviceTypes
def test_grad_scaler_deprecated_warning(self, device):
device = torch.device(device)
GradScaler = torch.cuda.amp.GradScaler if "cuda" == device.type else torch.cpu.amp.GradScaler
with self.assertWarnsRegex(
UserWarning,
rf"torch.{device.type}.amp.GradScaler\(args...\) is deprecated.",
):
_ = GradScaler(init_scale=2.0)
@dtypesIfCUDA(torch.float, torch.double, torch.half)
@dtypesIfCPU(torch.float, torch.double, torch.bfloat16, torch.half)
@dtypes(torch.float, torch.double)

View File

@@ -1,3 +1,5 @@
import warnings
import torch
__all__ = ["GradScaler"]
@@ -6,7 +8,7 @@ __all__ = ["GradScaler"]
class GradScaler(torch.amp.GradScaler):
r"""
See :class:`torch.amp.GradScaler`.
``torch.cpu.amp.GradScaler(args...)`` is equivalent to ``torch.amp.GradScaler("cpu", args...)``
``torch.cpu.amp.GradScaler(args...)`` is deprecated. Please use ``torch.amp.GradScaler("cpu", args...)`` instead.
"""
def __init__(
@@ -17,6 +19,9 @@ class GradScaler(torch.amp.GradScaler):
growth_interval: int = 2000,
enabled: bool = True,
) -> None:
warnings.warn(
"torch.cpu.amp.GradScaler(args...) is deprecated. Please use torch.amp.GradScaler('cpu', args...) instead."
)
super().__init__(
"cpu",
init_scale=init_scale,

View File

@@ -1,3 +1,5 @@
import warnings
import torch
__all__ = ["GradScaler"]
@@ -6,7 +8,7 @@ __all__ = ["GradScaler"]
class GradScaler(torch.amp.GradScaler):
r"""
See :class:`torch.amp.GradScaler`.
``torch.cuda.amp.GradScaler(args...)`` is equivalent to ``torch.amp.GradScaler("cuda", args...)``
``torch.cuda.amp.GradScaler(args...)`` is deprecated. Please use ``torch.amp.GradScaler("cuda", args...)`` instead.
"""
def __init__(
@@ -17,6 +19,9 @@ class GradScaler(torch.amp.GradScaler):
growth_interval: int = 2000,
enabled: bool = True,
) -> None:
warnings.warn(
"torch.cuda.amp.GradScaler(args...) is deprecated. Please use torch.amp.GradScaler('cuda', args...) instead."
)
super().__init__(
"cuda",
init_scale=init_scale,