Deprecate device-specific GradScaler autocast API (#126527)
# Motivation

## For `torch.amp.GradScaler`

- `torch.cpu.amp.GradScaler(args...)` is completely equivalent to `torch.amp.GradScaler("cpu", args...)`.
- `torch.cuda.amp.GradScaler(args...)` is completely equivalent to `torch.amp.GradScaler("cuda", args...)`.

We therefore intend to deprecate the device-specific constructors and **strongly recommend** that developers use `torch.amp.GradScaler` instead.

## For `custom_fwd` and `custom_bwd`

These decorators let a custom autograd function run correctly whether or not it is invoked inside an autocast-enabled region, and that behavior can be shared by other backends such as CPU and XPU. We therefore generalize them to be device-agnostic, move them into `torch/amp/autocast_mode.py`, and re-expose them as `torch.amp.custom_fwd` and `torch.amp.custom_bwd`. At the same time, we deprecate `torch.cuda.amp.custom_fwd` and `torch.cuda.amp.custom_bwd`.

# Additional Context

A unit test is added to cover the deprecation warning. No additional tests are needed for the functionality of `torch.amp.custom_fwd`/`custom_bwd`; the existing tests that previously covered `torch.cuda.amp.custom_fwd`/`custom_bwd` already exercise them. To ease review, the changes are split into two PRs: this first PR covers `torch.amp.GradScaler`, and the follow-up covers `custom_fwd` and `custom_bwd`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126527
Approved by: https://github.com/jgong5, https://github.com/gujinghui, https://github.com/janeyx99, https://github.com/EikanWang
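For context, a minimal migration sketch (not part of this PR's diff): the class name `ScaledMM` is a placeholder, and the `device_type` form of `custom_fwd`/`custom_bwd` shown here assumes the follow-up PR's device-agnostic API.

```python
import torch

# Deprecated: scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
# Device-agnostic replacement recommended by this PR:
scaler = torch.amp.GradScaler("cuda", init_scale=2.0)


# For custom autograd functions, the device-agnostic decorators take a
# device_type argument (illustrative only; this form lands in the follow-up
# PR covering custom_fwd/custom_bwd).
class ScaledMM(torch.autograd.Function):
    @staticmethod
    @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a.mm(b)

    @staticmethod
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        return grad_output.mm(b.t()), a.t().mm(grad_output)
```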
Committed by PyTorch MergeBot
Parent: ef86a27dba
Commit: c09205a057
@@ -2101,7 +2101,7 @@ class BenchmarkRunner:
             # which is bad as Gradscaler has state and can adjust the scaling
             # factor between eager and dynamo run, making accuracy check
             # harder.
-            # self.grad_scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
+            # self.grad_scaler = torch.amp.GradScaler(device="cuda", init_scale=2.0)
             self.autocast = functools.partial(
                 torch.amp.autocast, device_type=devices[0]
             )
@@ -19,18 +19,15 @@ are much faster in ``lower_precision_fp``. Other ops, like reductions, often req
 range of ``float32``. Mixed precision tries to match each op to its appropriate datatype.
 
 Ordinarily, "automatic mixed precision training" with datatype of ``torch.float16`` uses :class:`torch.autocast` and
-:class:`torch.cpu.amp.GradScaler` or :class:`torch.cuda.amp.GradScaler` together, as shown in the :ref:`CUDA Automatic Mixed Precision examples<amp-examples>`
+:class:`torch.amp.GradScaler` together, as shown in the :ref:`CUDA Automatic Mixed Precision examples<amp-examples>`
 and `CUDA Automatic Mixed Precision recipe <https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html>`_.
 However, :class:`torch.autocast` and :class:`torch.GradScaler` are modular, and may be used separately if desired.
 As shown in the CPU example section of :class:`torch.autocast`, "automatic mixed precision training/inference" on CPU with
 datatype of ``torch.bfloat16`` only uses :class:`torch.autocast`.
 
 For CUDA and CPU, APIs are also provided separately:
 
 * ``torch.autocast("cuda", args...)`` is equivalent to ``torch.cuda.amp.autocast(args...)``.
 * ``torch.autocast("cpu", args...)`` is equivalent to ``torch.cpu.amp.autocast(args...)``. For CPU, only lower precision floating point datatype of ``torch.bfloat16`` is supported for now.
 * ``torch.GradScaler("cuda", args...)`` is equivalent to ``torch.cuda.amp.GradScaler(args...)``.
 * ``torch.GradScaler("cpu", args...)`` is equivalent to ``torch.cpu.amp.GradScaler(args...)``.
 .. warning::
     ``torch.cuda.amp.autocast(args...)`` and ``torch.cpu.amp.autocast(args...)`` will be deprecated. Please use ``torch.autocast("cuda", args...)`` or ``torch.autocast("cpu", args...)`` instead.
     ``torch.cuda.amp.GradScaler(args...)`` and ``torch.cpu.amp.GradScaler(args...)`` will be deprecated. Please use ``torch.GradScaler("cuda", args...)`` or ``torch.GradScaler("cpu", args...)`` instead.
 
 :class:`torch.autocast` and :class:`torch.cpu.amp.autocast` are new in version `1.10`.
 
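For readers of the updated docs, here is a minimal sketch of the pattern described above: `torch.autocast` for the forward pass plus the device-agnostic `torch.amp.GradScaler` for gradient scaling. The model, optimizer, data, and loop length are placeholders chosen only for illustration.

```python
import torch

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Device-agnostic scaler, as recommended by the updated docs.
scaler = torch.amp.GradScaler("cuda")

for _ in range(3):
    inputs = torch.randn(4, 8, device="cuda")
    targets = torch.randn(4, 8, device="cuda")
    optimizer.zero_grad(set_to_none=True)
    # Run the forward pass in mixed precision.
    with torch.autocast("cuda", dtype=torch.float16):
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)         # unscales grads; skips the step if infs/NaNs are found
    scaler.update()                # adjusts the scale factor for the next iteration
```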
@@ -257,7 +257,7 @@ class TestShardedGradScalerParityWithDDP(FSDPTest):
             use_orig_params=use_orig_params,
         )
         grad_scaler = ShardedGradScaler(init_scale=2.0)
-        ref_grad_scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
+        ref_grad_scaler = torch.amp.GradScaler(device="cuda", init_scale=2.0)
         scaled_losses: List[torch.Tensor] = []
         device = torch.device("cuda")
         torch.manual_seed(42 + self.rank + 1)
@@ -3437,13 +3437,15 @@ exit(2)
         grads_graphed = [[g.clone() for g in gs] for gs in grads]
 
         # Gradient Scaler
-        scaler_for_control = torch.cuda.amp.GradScaler(init_scale=128.0)
+        scaler_for_control = torch.amp.GradScaler(
+            device="cuda", init_scale=128.0
+        )
         with torch.no_grad():
             scaler_for_control._lazy_init_scale_growth_tracker(
                 torch.device("cuda")
             )
 
-        scaler_for_graphed = torch.cuda.amp.GradScaler()
+        scaler_for_graphed = torch.amp.GradScaler(device="cuda")
         scaler_for_graphed.load_state_dict(scaler_for_control.state_dict())
         with torch.no_grad():
             scaler_for_graphed._lazy_init_scale_growth_tracker(
@@ -4722,7 +4724,7 @@ class TestCudaOptims(TestCase):
     def test_graph_grad_scaling(self, device, dtype, optim_info, foreach, fused):
         torch.cuda.empty_cache()
 
-        scaler = torch.cuda.amp.GradScaler(init_scale=4.0)
+        scaler = torch.amp.GradScaler(device="cuda", init_scale=4.0)
         g = torch.cuda.CUDAGraph()
         s = torch.cuda.Stream()
 
@@ -1159,7 +1159,7 @@ t2.start()
 
     @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
     def test_grad_scaling_scale(self):
-        scaler = torch.cuda.amp.GradScaler(init_scale=2.0)
+        scaler = torch.amp.GradScaler(device="cuda", init_scale=2.0)
         t0 = torch.full((1,), 4.0, dtype=torch.float32, device="cuda:0")
         t1 = torch.full((1,), 4.0, dtype=torch.float32, device="cuda:1")
         # Create some nested iterables of tensors on different devices.
@@ -1205,8 +1205,12 @@ t2.start()
             opt_scaling1,
         ) = _create_scaling_models_optimizers(device=dev1)
 
-        scaler = torch.cuda.amp.GradScaler(
-            init_scale=128.0, growth_factor=2.0, enabled=enabled, growth_interval=1
+        scaler = torch.amp.GradScaler(
+            device="cuda",
+            init_scale=128.0,
+            growth_factor=2.0,
+            enabled=enabled,
+            growth_interval=1,
         )
 
         def run(model0, model1, optimizer0, optimizer1, try_scaling_api):
@@ -6152,7 +6152,7 @@ else:
     @onlyNativeDeviceTypes
     def test_grad_scaler_pass_itself(self, device):
         device = torch.device(device)
-        GradScaler = torch.cuda.amp.GradScaler if "cuda" == device.type else torch.cpu.amp.GradScaler
+        GradScaler = partial(torch.amp.GradScaler, device=device.type)
 
         class _PlaceHolderOptimizer(torch.optim.Optimizer):
             tester = self
@@ -6165,7 +6165,7 @@ else:
 
         class Optimizer1(_PlaceHolderOptimizer):
             def step(self, closure=None, *, grad_scaler=None):
-                self.tester.assertTrue(isinstance(grad_scaler, GradScaler))
+                self.tester.assertTrue(isinstance(grad_scaler, torch.amp.GradScaler))
                 self.tester.assertFalse(hasattr(self, "grad_scale"))
                 self.tester.assertFalse(hasattr(self, "found_inf"))
 
@@ -6189,6 +6189,17 @@ else:
         scaler.step(o2)
         scaler.update()
 
+    @onlyNativeDeviceTypes
+    def test_grad_scaler_deprecated_warning(self, device):
+        device = torch.device(device)
+        GradScaler = torch.cuda.amp.GradScaler if "cuda" == device.type else torch.cpu.amp.GradScaler
+
+        with self.assertWarnsRegex(
+            UserWarning,
+            rf"torch.{device.type}.amp.GradScaler\(args...\) is deprecated.",
+        ):
+            _ = GradScaler(init_scale=2.0)
+
     @dtypesIfCUDA(torch.float, torch.double, torch.half)
     @dtypesIfCPU(torch.float, torch.double, torch.bfloat16, torch.half)
     @dtypes(torch.float, torch.double)
@@ -1,3 +1,5 @@
+import warnings
+
 import torch
 
 __all__ = ["GradScaler"]
@@ -6,7 +8,7 @@ __all__ = ["GradScaler"]
 class GradScaler(torch.amp.GradScaler):
     r"""
     See :class:`torch.amp.GradScaler`.
-    ``torch.cpu.amp.GradScaler(args...)`` is equivalent to ``torch.amp.GradScaler("cpu", args...)``
+    ``torch.cpu.amp.GradScaler(args...)`` is deprecated. Please use ``torch.amp.GradScaler("cpu", args...)`` instead.
     """
 
     def __init__(
@@ -17,6 +19,9 @@ class GradScaler(torch.amp.GradScaler):
         growth_interval: int = 2000,
         enabled: bool = True,
     ) -> None:
+        warnings.warn(
+            "torch.cpu.amp.GradScaler(args...) is deprecated. Please use torch.amp.GradScaler('cpu', args...) instead."
+        )
         super().__init__(
             "cpu",
             init_scale=init_scale,
@@ -1,3 +1,5 @@
+import warnings
+
 import torch
 
 __all__ = ["GradScaler"]
@@ -6,7 +8,7 @@ __all__ = ["GradScaler"]
 class GradScaler(torch.amp.GradScaler):
     r"""
     See :class:`torch.amp.GradScaler`.
-    ``torch.cuda.amp.GradScaler(args...)`` is equivalent to ``torch.amp.GradScaler("cuda", args...)``
+    ``torch.cuda.amp.GradScaler(args...)`` is deprecated. Please use ``torch.amp.GradScaler("cuda", args...)`` instead.
     """
 
     def __init__(
@@ -17,6 +19,9 @@ class GradScaler(torch.amp.GradScaler):
         growth_interval: int = 2000,
         enabled: bool = True,
     ) -> None:
+        warnings.warn(
+            "torch.cuda.amp.GradScaler(args...) is deprecated. Please use torch.amp.GradScaler('cuda', args...) instead."
+        )
        super().__init__(
             "cuda",
             init_scale=init_scale,
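A short, standalone sketch of what these wrapper classes now do at construction time (illustration only, assumes CUDA is available): the deprecated constructor emits a warning and otherwise forwards to the device-agnostic `torch.amp.GradScaler`.

```python
import warnings

import torch

# Constructing the deprecated wrapper emits a warning pointing at the
# device-agnostic API; behavior is otherwise unchanged.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    scaler = torch.cuda.amp.GradScaler(init_scale=2.0)

assert any("is deprecated" in str(w.message) for w in caught)
# The wrapper is still an instance of the recommended class.
assert isinstance(scaler, torch.amp.GradScaler)
```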