generalize custom_fwd&custom_bwd to be device-agnostic (#126531)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126531
Approved by: https://github.com/jgong5, https://github.com/gujinghui, https://github.com/albanD, https://github.com/EikanWang
ghstack dependencies: #126527
Author: Yu, Guangye
Date: 2024-05-24 21:53:25 +00:00
Committed by: PyTorch MergeBot
Parent: c09205a057
Commit: e7a42702f9
8 changed files with 202 additions and 99 deletions
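
In short, this change supersedes the CUDA-only decorators torch.cuda.amp.custom_fwd / torch.cuda.amp.custom_bwd with device-agnostic torch.amp.custom_fwd / torch.amp.custom_bwd that take an explicit device_type keyword. A minimal before/after sketch, mirroring the test updates below (illustrative only, not part of the diff):

    import torch

    class MyMM(torch.autograd.Function):
        @staticmethod
        # before this change: @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
        @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
        def forward(ctx, a, b):
            ctx.save_for_backward(a, b)
            return a.mm(b)

        @staticmethod
        # before this change: @torch.cuda.amp.custom_bwd
        @torch.amp.custom_bwd(device_type="cuda")
        def backward(ctx, grad):
            a, b = ctx.saved_tensors
            return grad.mm(b.t()), a.t().mm(grad)

The old spellings keep working for now, but they emit a deprecation warning pointing at the torch.amp equivalents.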

View File

@@ -46,6 +46,12 @@ Autocasting
.. autoclass:: autocast
:members:
.. currentmodule:: torch.amp
.. autofunction:: custom_fwd
.. autofunction:: custom_bwd
.. currentmodule:: torch.cuda.amp
.. autoclass:: autocast

View File

@@ -446,13 +446,13 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
class MyMM(torch.autograd.Function):
@staticmethod
@torch.cuda.amp.custom_fwd
@torch.amp.custom_fwd(device_type="cuda")
def forward(ctx, a, b):
ctx.save_for_backward(a, b)
return a.mm(b)
@staticmethod
@torch.cuda.amp.custom_bwd
@torch.amp.custom_bwd(device_type="cuda")
def backward(ctx, grad):
a, b = ctx.saved_tensors
return grad.mm(b.t()), a.t().mm(grad)

View File

@@ -1721,7 +1721,7 @@ torch.cuda.synchronize()
def test_autocast_custom_enabled(self):
class MyMM(torch.autograd.Function):
@staticmethod
@torch.cuda.amp.custom_fwd
@torch.amp.custom_fwd(device_type="cuda")
def forward(ctx, a, b):
self.assertTrue(a.dtype is torch.float32)
self.assertTrue(b.dtype is torch.float32)
@@ -1730,7 +1730,7 @@ torch.cuda.synchronize()
return a.mm(b)
@staticmethod
@torch.cuda.amp.custom_bwd
@torch.amp.custom_bwd(device_type="cuda")
def backward(ctx, grad):
self.assertTrue(torch.is_autocast_enabled())
a, b = ctx.saved_tensors
@@ -1754,7 +1754,7 @@ torch.cuda.synchronize()
def test_autocast_custom_cast_inputs(self):
class MyMM(torch.autograd.Function):
@staticmethod
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
@torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
def forward(ctx, a, container, expect_type):
b = container[1][0]
self.assertTrue(a.dtype is expect_type)
@@ -1764,7 +1764,7 @@ torch.cuda.synchronize()
return a.mm(b)
@staticmethod
@torch.cuda.amp.custom_bwd
@torch.amp.custom_bwd(device_type="cuda")
def backward(ctx, grad):
self.assertFalse(torch.is_autocast_enabled())
a, b = ctx.saved_tensors
@@ -1799,6 +1799,39 @@ torch.cuda.synchronize()
loss = output.sum()
loss.backward()
def test_autocast_custom_deprecated_warning(self):
with warnings.catch_warnings(record=True) as w:
class MyMM(torch.autograd.Function):
@staticmethod
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def forward(ctx, x, y):
ctx.save_for_backward(x, y)
self.assertFalse(torch.is_autocast_enabled())
return x + y
@staticmethod
@torch.cuda.amp.custom_bwd
def backward(ctx, grad):
_, _ = ctx.saved_tensors
self.assertFalse(torch.is_autocast_enabled())
return grad, grad
self.assertRegex(
str(w[0].message), r"torch.cuda.amp.custom_fwd\(args...\) is deprecated."
)
self.assertRegex(
str(w[1].message), r"torch.cuda.amp.custom_bwd\(args...\) is deprecated."
)
mymm = MyMM.apply
x = torch.randn(3, 3, requires_grad=True)
y = torch.randn(3, 3, requires_grad=True)
with torch.amp.autocast("cuda"):
output = mymm(x, y)
loss = output.sum()
loss.backward()
def test_autocast_cat_jit(self):
# Reported at https://github.com/pytorch/pytorch/issues/38958

View File

@@ -1278,6 +1278,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch._C._wrap_tensor_impl",
"torch._C.fork",
"torch._C.get_autocast_cpu_dtype",
"torch._C.get_autocast_dtype",
"torch._C.get_autocast_gpu_dtype",
"torch._C.get_autocast_ipu_dtype",
"torch._C.get_autocast_xla_dtype",
@@ -2327,6 +2328,8 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
"torch.amp.autocast_mode._enter_autocast",
"torch.amp.autocast_mode._exit_autocast",
"torch.amp.autocast_mode.autocast_decorator",
"torch.amp.autocast_mode.custom_bwd",
"torch.amp.autocast_mode.custom_fwd",
"torch.are_deterministic_algorithms_enabled",
"torch.atleast_1d",
"torch.atleast_2d",

View File

@@ -87,6 +87,7 @@ constant_fold_functions = [
torch.cuda.get_device_properties,
torch.cuda.is_available,
torch.distributed.is_available,
torch.get_autocast_dtype,
torch.get_autocast_gpu_dtype,
torch.get_default_dtype,
torch.is_autocast_cache_enabled,

View File

@@ -2,6 +2,8 @@ from .autocast_mode import (
_enter_autocast,
_exit_autocast,
autocast,
custom_bwd,
custom_fwd,
is_autocast_available,
)
from .grad_scaler import GradScaler

View File

@@ -1,3 +1,4 @@
import collections
import functools
import warnings
@@ -6,7 +7,20 @@ from typing import Any, Optional
import torch
from torch.types import _dtype
__all__ = ["autocast_decorator", "autocast", "is_autocast_available"]
try:
import numpy as np
HAS_NUMPY = True
except ModuleNotFoundError:
np = None # type: ignore[assignment]
__all__ = [
"autocast_decorator",
"autocast",
"is_autocast_available",
"custom_fwd",
"custom_bwd",
]
def is_autocast_available(device_type: str) -> bool:
@@ -366,3 +380,123 @@ def _exit_autocast(mode):
if torch._C._is_torch_function_mode_enabled():
return torch.overrides.handle_torch_function(torch.amp._exit_autocast, [], mode)
mode.__exit__(None, None, None)
# Casts Tensors and containers of Tensors. Special-cases passthroughs for strings and np.ndarrays, which
# may be falsely detected as "Iterables."
def _cast(value, device_type: str, dtype: _dtype):
if isinstance(value, torch.Tensor):
is_eligible = (
value.is_floating_point()
and value.device.type == device_type
and (value.dtype is not torch.float64)
)
return value.to(dtype) if is_eligible else value
elif isinstance(value, (str, bytes)):
return value
elif HAS_NUMPY and isinstance(value, np.ndarray):
return value
elif isinstance(value, collections.abc.Mapping):
return {
_cast(k, device_type, dtype): _cast(v, device_type, dtype)
for k, v in value.items()
}
elif isinstance(value, collections.abc.Iterable):
iterable = (_cast(v, device_type, dtype) for v in value)
if isinstance(value, (list, tuple)):
return type(value)(iterable)
else:
return iterable
else:
return value
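
To make the passthrough rules concrete, here is a small usage sketch of this private helper (the tensors, dtypes, and "cpu" device are arbitrary illustrative choices; _cast is internal and not part of the public API):

    import torch
    from torch.amp.autocast_mode import _cast  # private helper defined above

    args = (
        torch.randn(2, 2),                       # float32 on cpu: eligible, gets cast
        torch.randn(2, 2, dtype=torch.float64),  # float64 is never cast
        "label",                                 # strings pass through unchanged
        {"w": torch.randn(2, 2)},                # mappings are cast recursively
    )
    out = _cast(args, "cpu", torch.bfloat16)
    # out[0].dtype -> torch.bfloat16, out[1].dtype -> torch.float64,
    # out[2] -> "label" (unchanged), out[3]["w"].dtype -> torch.bfloat16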
def custom_fwd(
fwd=None,
*,
device_type: str,
cast_inputs: Optional[_dtype] = None,
):
"""
Create a helper decorator for ``forward`` methods of custom autograd functions.
Autograd functions are subclasses of :class:`torch.autograd.Function`.
See the :ref:`example page<amp-custom-examples>` for more detail.
Args:
device_type(str): Device type to use. 'cuda', 'cpu', 'xpu' and so on.
The type is the same as the `type` attribute of a :class:`torch.device`.
Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``,
when ``forward`` runs in an autocast-enabled region, casts incoming
floating-point Tensors to the target dtype (non-floating-point Tensors are not affected),
then executes ``forward`` with autocast disabled.
If ``None``, ``forward``'s internal ops execute with the current autocast state.
.. note::
If the decorated ``forward`` is called outside an autocast-enabled region,
:func:`custom_fwd<custom_fwd>` is a no-op and ``cast_inputs`` has no effect.
"""
if not isinstance(device_type, str):
raise ValueError(
f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
)
if fwd is None:
return functools.partial(
custom_fwd, device_type=device_type, cast_inputs=cast_inputs
)
@functools.wraps(fwd)
def decorate_fwd(*args, **kwargs):
args[0]._dtype = torch.get_autocast_dtype(device_type)
if cast_inputs is None:
args[0]._fwd_used_autocast = torch.is_autocast_enabled(device_type)
return fwd(*args, **kwargs)
else:
autocast_context = torch.is_autocast_enabled(device_type)
args[0]._fwd_used_autocast = False
if autocast_context:
with autocast(device_type=device_type, enabled=False):
return fwd(
*_cast(args, device_type, cast_inputs),
**_cast(kwargs, device_type, cast_inputs),
)
else:
return fwd(*args, **kwargs)
return decorate_fwd
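
A short usage sketch of the cast_inputs behavior documented above (device_type="cpu", bfloat16, and the Function name are arbitrary illustrative choices):

    import torch

    class Double(torch.autograd.Function):
        @staticmethod
        @torch.amp.custom_fwd(device_type="cpu", cast_inputs=torch.float32)
        def forward(ctx, x):
            # Under an enabled cpu autocast region, x arrives cast to float32
            # and autocast is disabled for the duration of this body.
            return x * 2.0

        @staticmethod
        @torch.amp.custom_bwd(device_type="cpu")
        def backward(ctx, grad):
            return grad * 2.0

    x = torch.randn(3, dtype=torch.bfloat16)
    with torch.autocast("cpu", dtype=torch.bfloat16):
        y = Double.apply(x)   # forward sees a float32 copy of x
    y2 = Double.apply(x)      # outside autocast the decorator is a no-op; x stays bfloat16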
# Autograd ensures incoming gradients are the same type as forward outputs. Allowing a separate
# cast_inputs argument on custom_bwd is unnecessary and could cause errors if it doesn't match
# cast_inputs supplied to custom_fwd.
def custom_bwd(bwd=None, *, device_type: str):
"""Create a helper decorator for backward methods of custom autograd functions.
Autograd functions are subclasses of :class:`torch.autograd.Function`.
Ensures that ``backward`` executes with the same autocast state as ``forward``.
See the :ref:`example page<amp-custom-examples>` for more detail.
Args:
device_type(str): Device type to use. 'cuda', 'cpu', 'xpu' and so on.
The type is the same as the `type` attribute of a :class:`torch.device`.
Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
"""
if not isinstance(device_type, str):
raise ValueError(
f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
)
if bwd is None:
return functools.partial(custom_bwd, device_type=device_type)
@functools.wraps(bwd)
def decorate_bwd(*args, **kwargs):
with autocast(
device_type=device_type,
enabled=args[0]._fwd_used_autocast,
dtype=args[0]._dtype,
):
return bwd(*args, **kwargs)
return decorate_bwd
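
Putting both decorators together, an end-to-end sketch of the device-agnostic pair (device_type="cpu" with bfloat16 is an illustrative choice; the same pattern applies to "cuda", "xpu", and other backends):

    import torch

    class MatMul(torch.autograd.Function):  # example Function, not from the diff
        @staticmethod
        @torch.amp.custom_fwd(device_type="cpu")
        def forward(ctx, a, b):
            ctx.save_for_backward(a, b)
            return a.mm(b)

        @staticmethod
        @torch.amp.custom_bwd(device_type="cpu")
        def backward(ctx, grad):
            # custom_bwd re-enters autocast with the same enabled flag and dtype
            # that forward observed, so backward runs under a consistent state.
            a, b = ctx.saved_tensors
            return grad.mm(b.t()), a.t().mm(grad)

    a = torch.randn(4, 4, requires_grad=True)
    b = torch.randn(4, 4, requires_grad=True)
    with torch.autocast("cpu", dtype=torch.bfloat16):
        out = MatMul.apply(a, b)   # runs in bfloat16 under autocast
    out.sum().backward()           # backward sees the autocast state forward saw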

View File

@@ -1,17 +1,9 @@
import collections
import functools
import warnings
from typing import Any
import torch
try:
import numpy as np
HAS_NUMPY = True
except ModuleNotFoundError:
np = None # type: ignore[assignment]
from typing import Any
__all__ = ["autocast", "custom_fwd", "custom_bwd"]
@@ -57,93 +49,25 @@ class autocast(torch.amp.autocast_mode.autocast):
return super().__call__(func)
# Casts Tensors and containers of Tensors. Special-cases passthroughs for strings and np.ndarrays, which
# may be falsely detected as "Iterables."
def _cast(value, dtype):
if isinstance(value, torch.Tensor):
is_eligible = (
value.is_floating_point()
and value.is_cuda
and (value.dtype is not torch.float64)
)
return value.to(dtype) if is_eligible else value
elif isinstance(value, (str, bytes)):
return value
elif HAS_NUMPY and isinstance(value, np.ndarray):
return value
elif isinstance(value, collections.abc.Mapping):
return {_cast(k, dtype): _cast(v, dtype) for k, v in value.items()}
elif isinstance(value, collections.abc.Iterable):
iterable = (_cast(v, dtype) for v in value)
if isinstance(value, (list, tuple)):
return type(value)(iterable)
else:
return iterable
else:
return value
# custom_fwd is a decorator that may or may not be used with arguments, following
# https://github.com/dabeaz/python-cookbook/tree/master/src/9/defining_a_decorator_that_takes_an_optional_argument.
# this works:
# @custom_fwd
# def forward(...):
# this also works:
# @custom_fwd(cast_inputs=torch.float)
# def forward(...):
def custom_fwd(fwd=None, *, cast_inputs=None):
"""
Create a helper decorator for ``forward`` methods of custom autograd functions.
Autograd functions are subclasses of :class:`torch.autograd.Function`.
See the :ref:`example page<amp-custom-examples>` for more detail.
Args:
cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``,
when ``forward`` runs in an autocast-enabled region, casts incoming
floating-point CUDA Tensors to the target dtype (non-floating-point Tensors are not affected),
then executes ``forward`` with autocast disabled.
If ``None``, ``forward``'s internal ops execute with the current autocast state.
.. note::
If the decorated ``forward`` is called outside an autocast-enabled region,
:func:`custom_fwd<custom_fwd>` is a no-op and ``cast_inputs`` has no effect.
``torch.cuda.amp.custom_fwd(args...)`` is deprecated. Please use
``torch.amp.custom_fwd(args..., device_type='cuda')`` instead.
"""
if fwd is None:
return functools.partial(custom_fwd, cast_inputs=cast_inputs)
@functools.wraps(fwd)
def decorate_fwd(*args, **kwargs):
args[0]._dtype = torch.get_autocast_gpu_dtype()
if cast_inputs is None:
args[0]._fwd_used_autocast = torch.is_autocast_enabled()
return fwd(*args, **kwargs)
else:
autocast_context = torch.is_autocast_enabled()
args[0]._fwd_used_autocast = False
if autocast_context:
with autocast(enabled=False):
return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs))
else:
return fwd(*args, **kwargs)
return decorate_fwd
warnings.warn(
"torch.cuda.amp.custom_fwd(args...) is deprecated. Please use torch.amp.custom_fwd(args..., device_type='cuda') instead."
)
return functools.partial(torch.amp.custom_fwd, device_type="cuda")(
fwd=fwd, cast_inputs=cast_inputs
)
# Autograd ensures incoming gradients are the same type as forward outputs. Allowing a separate
# cast_inputs argument on custom_bwd is unnecessary and could cause errors if it doesn't match
# cast_inputs supplied to custom_fwd.
def custom_bwd(bwd):
"""Create a helper decorator for backward methods of custom autograd functions.
Autograd functions are subclasses of :class:`torch.autograd.Function`.
Ensures that ``backward`` executes with the same autocast state as ``forward``.
See the :ref:`example page<amp-custom-examples>` for more detail.
"""
@functools.wraps(bwd)
def decorate_bwd(*args, **kwargs):
with autocast(enabled=args[0]._fwd_used_autocast, dtype=args[0]._dtype):
return bwd(*args, **kwargs)
return decorate_bwd
``torch.cuda.amp.custom_bwd(args...)`` is deprecated. Please use
``torch.amp.custom_bwd(args..., device_type='cuda')`` instead.
"""
warnings.warn(
"torch.cuda.amp.custom_bwd(args...) is deprecated. Please use torch.amp.custom_bwd(args..., device_type='cuda') instead."
)
return functools.partial(torch.amp.custom_bwd, device_type="cuda")(bwd)