mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00

generalize custom_fwd&custom_bwd to be device-agnostic (#126531)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126531
Approved by: https://github.com/jgong5, https://github.com/gujinghui, https://github.com/albanD, https://github.com/EikanWang
ghstack dependencies: #126527

committed by PyTorch MergeBot
parent c09205a057
commit e7a42702f9
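
The change in a nutshell: custom_fwd and custom_bwd, previously CUDA-only helpers under torch.cuda.amp, move to torch.amp and take a device_type keyword, while the old torch.cuda.amp entry points become deprecation shims that forward to them. A minimal sketch of the call-site migration, mirroring the test updates further down in this diff (running MyMM end to end would still need a CUDA device):

    import torch

    class MyMM(torch.autograd.Function):
        @staticmethod
        @torch.amp.custom_fwd(device_type="cuda")  # was: @torch.cuda.amp.custom_fwd
        def forward(ctx, a, b):
            ctx.save_for_backward(a, b)
            return a.mm(b)

        @staticmethod
        @torch.amp.custom_bwd(device_type="cuda")  # was: @torch.cuda.amp.custom_bwd
        def backward(ctx, grad):
            a, b = ctx.saved_tensors
            return grad.mm(b.t()), a.t().mm(grad)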
@@ -46,6 +46,12 @@ Autocasting
 .. autoclass:: autocast
     :members:
 
+.. currentmodule:: torch.amp
+
+.. autofunction:: custom_fwd
+
+.. autofunction:: custom_bwd
+
 .. currentmodule:: torch.cuda.amp
 
 .. autoclass:: autocast
@@ -446,13 +446,13 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
 
 
         class MyMM(torch.autograd.Function):
             @staticmethod
-            @torch.cuda.amp.custom_fwd
+            @torch.amp.custom_fwd(device_type="cuda")
             def forward(ctx, a, b):
                 ctx.save_for_backward(a, b)
                 return a.mm(b)
 
             @staticmethod
-            @torch.cuda.amp.custom_bwd
+            @torch.amp.custom_bwd(device_type="cuda")
             def backward(ctx, grad):
                 a, b = ctx.saved_tensors
                 return grad.mm(b.t()), a.t().mm(grad)
@@ -1721,7 +1721,7 @@ torch.cuda.synchronize()
     def test_autocast_custom_enabled(self):
         class MyMM(torch.autograd.Function):
             @staticmethod
-            @torch.cuda.amp.custom_fwd
+            @torch.amp.custom_fwd(device_type="cuda")
             def forward(ctx, a, b):
                 self.assertTrue(a.dtype is torch.float32)
                 self.assertTrue(b.dtype is torch.float32)
@@ -1730,7 +1730,7 @@ torch.cuda.synchronize()
                 return a.mm(b)
 
             @staticmethod
-            @torch.cuda.amp.custom_bwd
+            @torch.amp.custom_bwd(device_type="cuda")
             def backward(ctx, grad):
                 self.assertTrue(torch.is_autocast_enabled())
                 a, b = ctx.saved_tensors
@@ -1754,7 +1754,7 @@ torch.cuda.synchronize()
     def test_autocast_custom_cast_inputs(self):
         class MyMM(torch.autograd.Function):
             @staticmethod
-            @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
+            @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)
             def forward(ctx, a, container, expect_type):
                 b = container[1][0]
                 self.assertTrue(a.dtype is expect_type)
@@ -1764,7 +1764,7 @@ torch.cuda.synchronize()
                 return a.mm(b)
 
             @staticmethod
-            @torch.cuda.amp.custom_bwd
+            @torch.amp.custom_bwd(device_type="cuda")
             def backward(ctx, grad):
                 self.assertFalse(torch.is_autocast_enabled())
                 a, b = ctx.saved_tensors
@@ -1799,6 +1799,39 @@ torch.cuda.synchronize()
             loss = output.sum()
         loss.backward()
 
+    def test_autocast_custom_deprecated_warning(self):
+        with warnings.catch_warnings(record=True) as w:
+
+            class MyMM(torch.autograd.Function):
+                @staticmethod
+                @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
+                def forward(ctx, x, y):
+                    ctx.save_for_backward(x, y)
+                    self.assertFalse(torch.is_autocast_enabled())
+                    return x + y
+
+                @staticmethod
+                @torch.cuda.amp.custom_bwd
+                def backward(ctx, grad):
+                    _, _ = ctx.saved_tensors
+                    self.assertFalse(torch.is_autocast_enabled())
+                    return grad, grad
+
+        self.assertRegex(
+            str(w[0].message), r"torch.cuda.amp.custom_fwd\(args...\) is deprecated."
+        )
+        self.assertRegex(
+            str(w[1].message), r"torch.cuda.amp.custom_bwd\(args...\) is deprecated."
+        )
+
+        mymm = MyMM.apply
+        x = torch.randn(3, 3, requires_grad=True)
+        y = torch.randn(3, 3, requires_grad=True)
+        with torch.amp.autocast("cuda"):
+            output = mymm(x, y)
+            loss = output.sum()
+        loss.backward()
+
     def test_autocast_cat_jit(self):
         # Reported at https://github.com/pytorch/pytorch/issues/38958
 
@@ -1278,6 +1278,7 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
         "torch._C._wrap_tensor_impl",
         "torch._C.fork",
         "torch._C.get_autocast_cpu_dtype",
+        "torch._C.get_autocast_dtype",
         "torch._C.get_autocast_gpu_dtype",
         "torch._C.get_autocast_ipu_dtype",
         "torch._C.get_autocast_xla_dtype",

@@ -2327,6 +2328,8 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
         "torch.amp.autocast_mode._enter_autocast",
         "torch.amp.autocast_mode._exit_autocast",
         "torch.amp.autocast_mode.autocast_decorator",
+        "torch.amp.autocast_mode.custom_bwd",
+        "torch.amp.autocast_mode.custom_fwd",
         "torch.are_deterministic_algorithms_enabled",
         "torch.atleast_1d",
         "torch.atleast_2d",
@@ -87,6 +87,7 @@ constant_fold_functions = [
     torch.cuda.get_device_properties,
     torch.cuda.is_available,
     torch.distributed.is_available,
+    torch.get_autocast_dtype,
     torch.get_autocast_gpu_dtype,
     torch.get_default_dtype,
     torch.is_autocast_cache_enabled,
@@ -2,6 +2,8 @@ from .autocast_mode import (
     _enter_autocast,
     _exit_autocast,
     autocast,
+    custom_bwd,
+    custom_fwd,
     is_autocast_available,
 )
 from .grad_scaler import GradScaler
@@ -1,3 +1,4 @@
+import collections
 import functools
 import warnings
 
@@ -6,7 +7,20 @@ from typing import Any, Optional
 import torch
 from torch.types import _dtype
 
-__all__ = ["autocast_decorator", "autocast", "is_autocast_available"]
+try:
+    import numpy as np
+
+    HAS_NUMPY = True
+except ModuleNotFoundError:
+    np = None  # type: ignore[assignment]
+
+__all__ = [
+    "autocast_decorator",
+    "autocast",
+    "is_autocast_available",
+    "custom_fwd",
+    "custom_bwd",
+]
 
 
 def is_autocast_available(device_type: str) -> bool:
@@ -366,3 +380,123 @@ def _exit_autocast(mode):
     if torch._C._is_torch_function_mode_enabled():
         return torch.overrides.handle_torch_function(torch.amp._exit_autocast, [], mode)
     mode.__exit__(None, None, None)
+
+
+# Casts Tensors and containers of Tensors. Special-cases passthroughs for strings and np.ndarrays, which
+# may be falsely detected as "Iterables."
+def _cast(value, device_type: str, dtype: _dtype):
+    if isinstance(value, torch.Tensor):
+        is_eligible = (
+            value.is_floating_point()
+            and value.device.type == device_type
+            and (value.dtype is not torch.float64)
+        )
+        return value.to(dtype) if is_eligible else value
+    elif isinstance(value, (str, bytes)):
+        return value
+    elif HAS_NUMPY and isinstance(value, np.ndarray):
+        return value
+    elif isinstance(value, collections.abc.Mapping):
+        return {
+            _cast(k, device_type, dtype): _cast(v, device_type, dtype)
+            for k, v in value.items()
+        }
+    elif isinstance(value, collections.abc.Iterable):
+        iterable = (_cast(v, device_type, dtype) for v in value)
+        if isinstance(value, (list, tuple)):
+            return type(value)(iterable)
+        else:
+            return iterable
+    else:
+        return value
+
+
+def custom_fwd(
+    fwd=None,
+    *,
+    device_type: str,
+    cast_inputs: Optional[_dtype] = None,
+):
+    """
+    Create a helper decorator for ``forward`` methods of custom autograd functions.
+
+    Autograd functions are subclasses of :class:`torch.autograd.Function`.
+    See the :ref:`example page<amp-custom-examples>` for more detail.
+
+    Args:
+        device_type(str): Device type to use. 'cuda', 'cpu', 'xpu' and so on.
+            The type is the same as the `type` attribute of a :class:`torch.device`.
+            Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
+        cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``,
+            when ``forward`` runs in an autocast-enabled region, casts incoming
+            floating-point Tensors to the target dtype (non-floating-point Tensors are not affected),
+            then executes ``forward`` with autocast disabled.
+            If ``None``, ``forward``'s internal ops execute with the current autocast state.
+
+    .. note::
+        If the decorated ``forward`` is called outside an autocast-enabled region,
+        :func:`custom_fwd<custom_fwd>` is a no-op and ``cast_inputs`` has no effect.
+    """
+    if not isinstance(device_type, str):
+        raise ValueError(
+            f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
+        )
+    if fwd is None:
+        return functools.partial(
+            custom_fwd, device_type=device_type, cast_inputs=cast_inputs
+        )
+
+    @functools.wraps(fwd)
+    def decorate_fwd(*args, **kwargs):
+        args[0]._dtype = torch.get_autocast_dtype(device_type)
+        if cast_inputs is None:
+            args[0]._fwd_used_autocast = torch.is_autocast_enabled(device_type)
+            return fwd(*args, **kwargs)
+        else:
+            autocast_context = torch.is_autocast_enabled(device_type)
+            args[0]._fwd_used_autocast = False
+            if autocast_context:
+                with autocast(device_type=device_type, enabled=False):
+                    return fwd(
+                        *_cast(args, device_type, cast_inputs),
+                        **_cast(kwargs, device_type, cast_inputs),
+                    )
+            else:
+                return fwd(*args, **kwargs)
+
+    return decorate_fwd
+
+
+# Autograd ensures incoming gradients are the same type as forward outputs. Allowing a separate
+# cast_inputs argument on custom_bwd is unnecessary and could cause errors if it doesn't match
+# cast_inputs supplied to custom_fwd.
+def custom_bwd(bwd=None, *, device_type: str):
+    """Create a helper decorator for backward methods of custom autograd functions.
+
+    Autograd functions are subclasses of :class:`torch.autograd.Function`.
+    Ensures that ``backward`` executes with the same autocast state as ``forward``.
+    See the :ref:`example page<amp-custom-examples>` for more detail.
+
+    Args:
+        device_type(str): Device type to use. 'cuda', 'cpu', 'xpu' and so on.
+            The type is the same as the `type` attribute of a :class:`torch.device`.
+            Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
+    """
+
+    if not isinstance(device_type, str):
+        raise ValueError(
+            f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
+        )
+    if bwd is None:
+        return functools.partial(custom_bwd, device_type=device_type)
+
+    @functools.wraps(bwd)
+    def decorate_bwd(*args, **kwargs):
+        with autocast(
+            device_type=device_type,
+            enabled=args[0]._fwd_used_autocast,
+            dtype=args[0]._dtype,
+        ):
+            return bwd(*args, **kwargs)
+
+    return decorate_bwd
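
The torch.amp decorators above record the autocast dtype and enabled state on ctx during forward and replay them in backward, keyed by the given device_type. Below is a minimal usage sketch on CPU; it is an illustration rather than part of the diff, and LinearFn, the shapes, and the bfloat16 choice are all made up:

    import torch

    class LinearFn(torch.autograd.Function):
        @staticmethod
        @torch.amp.custom_fwd(device_type="cpu", cast_inputs=torch.float32)
        def forward(ctx, x, w):
            # With cast_inputs set, floating-point CPU inputs arrive as float32
            # and autocast is disabled for the body of forward.
            ctx.save_for_backward(x, w)
            return x.mm(w)

        @staticmethod
        @torch.amp.custom_bwd(device_type="cpu")
        def backward(ctx, grad):
            # backward runs under the same autocast state that forward observed.
            x, w = ctx.saved_tensors
            return grad.mm(w.t()), x.t().mm(grad)

    x = torch.randn(4, 8, requires_grad=True)
    w = torch.randn(8, 2, requires_grad=True)
    with torch.amp.autocast("cpu", dtype=torch.bfloat16):
        out = LinearFn.apply(x, w)  # forward itself executes in float32
    out.sum().backward()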
@@ -1,17 +1,9 @@
-import collections
 import functools
 import warnings
+from typing import Any
 
 import torch
 
-try:
-    import numpy as np
-
-    HAS_NUMPY = True
-except ModuleNotFoundError:
-    np = None  # type: ignore[assignment]
-from typing import Any
-
 __all__ = ["autocast", "custom_fwd", "custom_bwd"]
 
 
@@ -57,93 +49,25 @@ class autocast(torch.amp.autocast_mode.autocast):
         return super().__call__(func)
 
 
-# Casts Tensors and containers of Tensors. Special-cases passthroughs for strings and np.ndarrays, which
-# may be falsely detected as "Iterables."
-def _cast(value, dtype):
-    if isinstance(value, torch.Tensor):
-        is_eligible = (
-            value.is_floating_point()
-            and value.is_cuda
-            and (value.dtype is not torch.float64)
-        )
-        return value.to(dtype) if is_eligible else value
-    elif isinstance(value, (str, bytes)):
-        return value
-    elif HAS_NUMPY and isinstance(value, np.ndarray):
-        return value
-    elif isinstance(value, collections.abc.Mapping):
-        return {_cast(k, dtype): _cast(v, dtype) for k, v in value.items()}
-    elif isinstance(value, collections.abc.Iterable):
-        iterable = (_cast(v, dtype) for v in value)
-        if isinstance(value, (list, tuple)):
-            return type(value)(iterable)
-        else:
-            return iterable
-    else:
-        return value
-
-
-# custom_fwd is a decorator that may or may not be used with arguments, following
-# https://github.com/dabeaz/python-cookbook/tree/master/src/9/defining_a_decorator_that_takes_an_optional_argument.
-# this works:
-#     @custom_fwd
-#     def forward(...):
-# this also works:
-#     @custom_fwd(cast_inputs=torch.float)
-#     def forward(...):
 def custom_fwd(fwd=None, *, cast_inputs=None):
     """
-    Create a helper decorator for ``forward`` methods of custom autograd functions.
-
-    Autograd functions are subclasses of :class:`torch.autograd.Function`.
-    See the :ref:`example page<amp-custom-examples>` for more detail.
-
-    Args:
-        cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``,
-            when ``forward`` runs in an autocast-enabled region, casts incoming
-            floating-point CUDA Tensors to the target dtype (non-floating-point Tensors are not affected),
-            then executes ``forward`` with autocast disabled.
-            If ``None``, ``forward``'s internal ops execute with the current autocast state.
-
-    .. note::
-        If the decorated ``forward`` is called outside an autocast-enabled region,
-        :func:`custom_fwd<custom_fwd>` is a no-op and ``cast_inputs`` has no effect.
+    ``torch.cuda.amp.custom_fwd(args...)`` is deprecated. Please use
+    ``torch.amp.custom_fwd(args..., device_type='cuda')`` instead.
     """
-    if fwd is None:
-        return functools.partial(custom_fwd, cast_inputs=cast_inputs)
-
-    @functools.wraps(fwd)
-    def decorate_fwd(*args, **kwargs):
-        args[0]._dtype = torch.get_autocast_gpu_dtype()
-        if cast_inputs is None:
-            args[0]._fwd_used_autocast = torch.is_autocast_enabled()
-            return fwd(*args, **kwargs)
-        else:
-            autocast_context = torch.is_autocast_enabled()
-            args[0]._fwd_used_autocast = False
-            if autocast_context:
-                with autocast(enabled=False):
-                    return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs))
-            else:
-                return fwd(*args, **kwargs)
-
-    return decorate_fwd
+    warnings.warn(
+        "torch.cuda.amp.custom_fwd(args...) is deprecated. Please use torch.amp.custom_fwd(args..., device_type='cuda') instead."
+    )
+    return functools.partial(torch.amp.custom_fwd, device_type="cuda")(
+        fwd=fwd, cast_inputs=cast_inputs
+    )
 
 
-# Autograd ensures incoming gradients are the same type as forward outputs. Allowing a separate
-# cast_inputs argument on custom_bwd is unnecessary and could cause errors if it doesn't match
-# cast_inputs supplied to custom_fwd.
 def custom_bwd(bwd):
-    """Create a helper decorator for backward methods of custom autograd functions.
-
-    Autograd functions are subclasses of :class:`torch.autograd.Function`.
-    Ensures that ``backward`` executes with the same autocast state as ``forward``.
-    See the :ref:`example page<amp-custom-examples>` for more detail.
     """
-
-    @functools.wraps(bwd)
-    def decorate_bwd(*args, **kwargs):
-        with autocast(enabled=args[0]._fwd_used_autocast, dtype=args[0]._dtype):
-            return bwd(*args, **kwargs)
-
-    return decorate_bwd
+    ``torch.cuda.amp.custom_bwd(args...)`` is deprecated. Please use
+    ``torch.amp.custom_bwd(args..., device_type='cuda')`` instead.
+    """
+    warnings.warn(
+        "torch.cuda.amp.custom_bwd(args...) is deprecated. Please use torch.amp.custom_bwd(args..., device_type='cuda') instead."
+    )
+    return functools.partial(torch.amp.custom_bwd, device_type="cuda")(bwd)
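
For downstream code, the practical effect of this shim is that the legacy spelling keeps working: it warns at decoration time (via a plain warnings.warn here) and forwards to the torch.amp implementation. A small equivalence sketch, with names of my own choosing:

    import torch

    # Both build functionally equivalent wrappers; the first additionally emits
    # the "torch.cuda.amp.custom_fwd(args...) is deprecated" warning shown above.
    legacy_decorator = torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    current_decorator = torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32)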