[Custom Ops] Add a new API to allow users to register an autocast for the custom op (#145588)
Fixes #137033

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145588
Approved by: https://github.com/zou3519

Committed by: PyTorch MergeBot (parent f951d216e0, commit ec91b7720f)
docs/source/library.rst

@@ -42,6 +42,7 @@ for any operators (they may have been created using :func:`torch.library.custom_
via PyTorch's C++ operator registration APIs).

.. autofunction:: register_kernel
.. autofunction:: register_autocast
.. autofunction:: register_autograd
.. autofunction:: register_fake
.. autofunction:: register_vmap
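For orientation, this is the end-to-end usage the new API enables (a minimal sketch adapted from the docstring example added below; it assumes a CUDA device is available):

    import torch
    from torch import Tensor

    @torch.library.custom_op("mylib::my_sin", mutates_args=())
    def my_sin(x: Tensor) -> Tensor:
        return torch.sin(x)

    # Inside an autocast("cuda") region, float32 inputs are now cast to
    # float16 before the kernel runs, and the op itself runs with autocast
    # disabled.
    torch.library.register_autocast("mylib::my_sin", "cuda", torch.float16)

    x = torch.randn(3, dtype=torch.float32, device="cuda")
    with torch.autocast("cuda", dtype=torch.float16):
        y = torch.ops.mylib.my_sin(x)
    assert y.dtype == torch.float16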
test/test_custom_ops.py

@@ -3033,6 +3033,150 @@ class TestCustomOpAPI(TestCase):
        self.assertEqual(z, x + y)
        self.assertTrue(called)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
    def test_library_register_autocast(self):
        for device in ["cuda", "cpu"]:
            for mode in ["function", "qualname", "opoverload"]:

                @torch.library.custom_op("mylib::my_sin", mutates_args=())
                def my_sin(x: Tensor) -> Tensor:
                    return torch.sin(x)

                if mode == "function":
                    torch.library.register_autocast(my_sin, device, torch.float16)
                elif mode == "qualname":
                    torch.library.register_autocast(
                        "mylib::my_sin", device, torch.float16
                    )
                elif mode == "opoverload":
                    torch.library.register_autocast(
                        torch.ops.mylib.my_sin.default, device, torch.float16
                    )

                x = torch.randn(3, dtype=torch.float32, device=device)
                with torch.autocast(device, dtype=torch.float16):
                    y = torch.ops.mylib.my_sin(x)
                self.assertEqual(y.dtype, torch.float16)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
    def test_library_register_autocast_low_level(self):
        for device in ["cuda", "cpu"]:
            for mode in ["qualname", "opoverload"]:
                with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
                    lib.define("my_sin(Tensor x) -> Tensor")

                    def my_sin(x: Tensor) -> Tensor:
                        return torch.sin(x)

                    lib.impl("my_sin", my_sin, device.upper())

                    if mode == "qualname":
                        torch.library.register_autocast(
                            "_torch_testing::my_sin", device, torch.float16, lib=lib
                        )
                    elif mode == "opoverload":
                        torch.library.register_autocast(
                            torch.ops._torch_testing.my_sin.default,
                            device,
                            torch.float16,
                            lib=lib,
                        )

                    x = torch.randn(3, dtype=torch.float32, device=device)
                    with torch.autocast(device, dtype=torch.float16):
                        y = torch.ops._torch_testing.my_sin(x)
                    self.assertEqual(y.dtype, torch.float16)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
    def test_library_register_autocast_list_input(self):
        for device in ["cuda", "cpu"]:
            for mode in ["function", "qualname", "opoverload"]:

                @torch.library.custom_op("mylib::my_add_sin", mutates_args=())
                def my_add_sin(x: List[Tensor]) -> Tensor:
                    return torch.sin(x[0] + x[1])

                if mode == "function":
                    torch.library.register_autocast(my_add_sin, device, torch.float16)
                elif mode == "qualname":
                    torch.library.register_autocast(
                        "mylib::my_add_sin", device, torch.float16
                    )
                elif mode == "opoverload":
                    torch.library.register_autocast(
                        torch.ops.mylib.my_add_sin.default, device, torch.float16
                    )

                lst = [
                    torch.randn(3, dtype=torch.float32, device=device) for _ in range(2)
                ]
                with torch.autocast(device, dtype=torch.float16):
                    y = torch.ops.mylib.my_add_sin(lst)
                self.assertEqual(y.dtype, torch.float16)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
    def test_library_register_autocast_multiple_times(self):
        for device in ["cuda", "cpu"]:

            @torch.library.custom_op("mylib::my_sin", mutates_args=())
            def my_sin(x: Tensor) -> Tensor:
                return torch.sin(x)

            torch.library.register_autocast(my_sin, device, torch.float16)

            x = torch.randn(3, dtype=torch.float32, device=device)
            with torch.autocast(device, dtype=torch.float16):
                y1 = my_sin(x)
            self.assertEqual(y1.dtype, torch.float16)

            # Ensure calling register_autocast multiple times does not error out.
            torch.library.register_autocast(my_sin, device, torch.float16)

            with torch.autocast(device, dtype=torch.float16):
                y2 = my_sin(x)
            self.assertEqual(y2.dtype, torch.float16)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
    def test_library_register_autocast_multiple_times_different_devices(self):
        @torch.library.custom_op("mylib::my_sin", mutates_args=())
        def my_sin(x: Tensor) -> Tensor:
            return torch.sin(x)

        # Register autocast for CUDA
        torch.library.register_autocast(my_sin, "cuda", torch.float16)

        x1 = torch.randn(3, dtype=torch.float32, device="cuda")
        with torch.autocast("cuda", dtype=torch.float16):
            y1 = my_sin(x1)
        self.assertEqual(y1.dtype, torch.float16)

        # Register autocast for CPU
        torch.library.register_autocast(my_sin, "cpu", torch.float16)

        x2 = torch.randn(3, dtype=torch.float32, device="cpu")
        with torch.autocast("cpu", dtype=torch.float16):
            y2 = my_sin(x2)
        self.assertEqual(y2.dtype, torch.float16)

        # Register CUDA autocast for the second time
        torch.library.register_autocast(my_sin, "cuda", torch.float16)

        with torch.autocast("cuda", dtype=torch.float16):
            y3 = my_sin(x1)
        self.assertEqual(y3.dtype, torch.float16)

        # Register CPU autocast for the second time
        torch.library.register_autocast(my_sin, "cpu", torch.float16)

        with torch.autocast("cpu", dtype=torch.float16):
            y4 = my_sin(x2)
        self.assertEqual(y4.dtype, torch.float16)

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_library_register_autograd(self):
        for mode in ["function", "qualname", "opoverload"]:
torch/_library/custom_ops.py

@@ -1,4 +1,5 @@
# mypy: allow-untyped-defs
import collections
import inspect
import logging
import weakref
@@ -8,6 +9,7 @@ from typing import Any, Callable, Literal, Optional, overload, Union

import torch
from torch import _C, _ops, Tensor
from torch.types import _dtype
from torch.utils._exposed_in import exposed_in

from . import autograd, utils
@@ -195,6 +197,8 @@ class CustomOpDef:
        self._backward_fn: Optional[Callable] = None
        self._torch_dispatch_fns: dict[type, Callable] = {}
        self._vmap_fn: Optional[Callable] = None
        self._autocast_cuda_dtype: Optional[_dtype] = None
        self._autocast_cpu_dtype: Optional[_dtype] = None

        self._lib = get_library_allowing_overwrite(self._namespace, self._name)
        self._register_to_dispatcher()
@@ -763,6 +767,96 @@ class CustomOpDef:
        else:
            return register(func)

    def register_autocast(
        self,
        device_type: str,
        cast_inputs: _dtype,
    ):
        r"""Register an autocast dispatch rule for this custom op.

        Valid `device_type` values include "cpu" and "cuda".

        Args:
            device_type (str): Device type to use. 'cuda' or 'cpu'.
                The type is the same as the `type` attribute of a :class:`torch.device`.
                Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
            cast_inputs (:class:`torch.dtype`): When the custom op runs in an
                autocast-enabled region, casts incoming floating-point Tensors to the
                target dtype (non-floating-point Tensors are not affected), then
                executes the custom op with autocast disabled.

        Examples::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
            >>> import torch
            >>> from torch import Tensor
            >>> from torch.library import custom_op
            >>>
            >>> # Create a custom op that works on cuda
            >>> @torch.library.custom_op("mylib::my_sin", mutates_args=())
            >>> def my_sin(x: Tensor) -> Tensor:
            >>>     return torch.sin(x)
            >>>
            >>> # Register autocast dispatch rule for the cuda device
            >>> torch.library.register_autocast("mylib::my_sin", "cuda", torch.float16)
            >>>
            >>> x = torch.randn(3, dtype=torch.float32, device="cuda")
            >>> with torch.autocast("cuda", dtype=torch.float16):
            >>>     y = torch.ops.mylib.my_sin(x)
            >>> assert y.dtype == torch.float16

        """
        if not isinstance(device_type, str):
            raise ValueError(
                f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
            )
        if device_type not in ["cpu", "cuda"]:
            raise ValueError(f"Unknown device type: {device_type}")

        need_register_cuda = self._autocast_cuda_dtype is None
        need_register_cpu = self._autocast_cpu_dtype is None
        if device_type == "cuda":
            self._autocast_cuda_dtype = cast_inputs
        else:
            self._autocast_cpu_dtype = cast_inputs

        def kernel(_, *args, **kwargs):
            assert len(kwargs) == 0, "Custom ops do not support kwargs yet."
            autocast_keyset = torch._C.DispatchKeySet(
                torch._C.DispatchKey.AutocastCPU
            ) | torch._C.DispatchKeySet(torch._C.DispatchKey.AutocastCUDA)
            # Exclude both autocast keys so the redispatch below executes the
            # op with autocast disabled, as documented for `cast_inputs`.
            with torch._C._ExcludeDispatchKeyGuard(autocast_keyset):
                return self._opoverload(*_cast(args, device_type, cast_inputs))

        # Register the dispatch kernel only on the first call per device type;
        # repeated registrations for the same device skip re-registering.
        if need_register_cuda and self._autocast_cuda_dtype:
            self._lib.impl(self._name, kernel, "AutocastCUDA", with_keyset=True)
        elif need_register_cpu and self._autocast_cpu_dtype:
            self._lib.impl(self._name, kernel, "AutocastCPU", with_keyset=True)

        return kernel
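The `cast_inputs` rule above only touches eligible inputs: floating-point tensors on the target device, excluding float64. A small sketch showing that non-floating-point tensors pass through unchanged (the op name `mylib::gather_probe` is hypothetical):

    import torch
    from torch import Tensor

    @torch.library.custom_op("mylib::gather_probe", mutates_args=())
    def gather_probe(x: Tensor, idx: Tensor) -> Tensor:
        # idx is integral, so autocast leaves it alone; x arrives as float16.
        assert idx.dtype == torch.int64
        return x[idx]

    torch.library.register_autocast(gather_probe, "cpu", torch.float16)

    x = torch.randn(4, dtype=torch.float32)  # eligible: floating point, on "cpu"
    idx = torch.tensor([0, 2])               # not eligible: integral
    with torch.autocast("cpu", dtype=torch.float16):
        out = torch.ops.mylib.gather_probe(x, idx)
    assert out.dtype == torch.float16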
# TODO: Merge this function with torch.amp.autocast_mode._cast, and refactor it
# into a utility function once custom ops support arbitrary input types.
def _cast(value, device_type: str, dtype: _dtype):
    if isinstance(value, torch.Tensor):
        is_eligible = (
            value.is_floating_point()
            and value.device.type == device_type
            and (value.dtype is not torch.float64)
        )
        return value.to(dtype) if is_eligible else value
    elif isinstance(value, (str, bytes)):
        return value
    elif isinstance(value, collections.abc.Iterable):
        iterable = (_cast(v, device_type, dtype) for v in value)
        if isinstance(value, (list, tuple)):
            # Rebuild lists/tuples eagerly so the container type is preserved.
            return type(value)(iterable)
        else:
            # Other iterables are returned lazily as a generator.
            return iterable
    else:
        return value
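For intuition, `_cast` recurses structurally: lists and tuples are rebuilt with their floating-point leaves cast, everything else passes through. A quick illustrative check (`_cast` is a private helper; this assumes a CPU-only environment):

    import torch
    from torch._library.custom_ops import _cast

    args = (
        [torch.zeros(2, dtype=torch.float32), torch.zeros(2, dtype=torch.int32)],
        "label",
        torch.zeros(2, dtype=torch.float64),
    )
    out = _cast(args, "cpu", torch.float16)

    assert out[0][0].dtype == torch.float16  # float32 leaf was cast
    assert out[0][1].dtype == torch.int32    # integral tensor untouched
    assert out[1] == "label"                 # strings pass through as-is
    assert out[2].dtype == torch.float64     # float64 is deliberately exempt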

def increment_version(val: Any) -> None:
    if isinstance(val, Tensor):
torch/library.py

@@ -22,6 +22,7 @@ from typing_extensions import deprecated, ParamSpec
import torch
import torch._library as _library
from torch._library.custom_ops import (
    _cast,
    _maybe_get_opdef,
    custom_op,
    CustomOpDef,
@@ -30,6 +31,7 @@ from torch._library.custom_ops import (
from torch._library.infer_schema import infer_schema  # noqa: F401
from torch._library.triton import triton_op, wrap_triton
from torch._ops import OpOverload
from torch.types import _dtype


__all__ = [
@@ -38,6 +40,7 @@ __all__ = [
    "define",
    "fallthrough_kernel",
    "impl_abstract",
    "register_autocast",
    "register_fake",
    "register_torch_dispatch",
    "register_vmap",
@@ -823,6 +826,87 @@ def register_kernel(
    return _impl(op, device_types, func, lib=lib, disable_dynamo=True)


def register_autocast(
    op: _op_identifier,
    device_type: str,
    cast_inputs: _dtype,
    /,
    *,
    lib: Optional[Library] = None,
):
    r"""Register an autocast dispatch rule for this custom op.

    Valid `device_type` values include "cpu" and "cuda".

    Args:
        op (str | OpOverload): The operator to register an autocast dispatch rule for.
        device_type (str): Device type to use. 'cuda' or 'cpu'.
            The type is the same as the `type` attribute of a :class:`torch.device`.
            Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
        cast_inputs (:class:`torch.dtype`): When the custom op runs in an
            autocast-enabled region, casts incoming floating-point Tensors to the
            target dtype (non-floating-point Tensors are not affected), then
            executes the custom op with autocast disabled.
        lib (Optional[Library]): If provided, the lifetime of this registration
            will be tied to the lifetime of the Library object.

    Examples::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> from torch import Tensor
        >>> from torch.library import custom_op
        >>>
        >>> # Create a custom op that works on cuda
        >>> @torch.library.custom_op("mylib::my_sin", mutates_args=())
        >>> def my_sin(x: Tensor) -> Tensor:
        >>>     return torch.sin(x)
        >>>
        >>> # Register autocast dispatch rule for the cuda device
        >>> torch.library.register_autocast("mylib::my_sin", "cuda", torch.float16)
        >>>
        >>> x = torch.randn(3, dtype=torch.float32, device="cuda")
        >>> with torch.autocast("cuda", dtype=torch.float16):
        >>>     y = torch.ops.mylib.my_sin(x)
        >>> assert y.dtype == torch.float16

    """
    if not isinstance(
        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
    ):
        raise ValueError(
            f"register_autocast(op): got unexpected type for op: {type(op)}"
        )
    if device_type not in ["cpu", "cuda"]:
        raise ValueError(f"Unknown device type: {device_type}")

    if isinstance(op, torch._ops.OpOverload):
        op = op._name
    opdef = _maybe_get_opdef(op)
    if opdef is not None:
        # Ops defined via @torch.library.custom_op delegate to CustomOpDef.
        return opdef.register_autocast(device_type, cast_inputs)

    assert isinstance(op, str)
    qualname = op
    _op = torch._library.utils.lookup_op(qualname)

    namespace, opname = torch._library.utils.parse_namespace(qualname)
    if lib is None:
        lib = Library(namespace, "FRAGMENT")
        _keep_alive.append(lib)

    def kernel(_, *args, **kwargs):
        assert len(kwargs) == 0, "Custom ops do not support kwargs yet."
        autocast_keyset = torch._C.DispatchKeySet(
            torch._C.DispatchKey.AutocastCPU
        ) | torch._C.DispatchKeySet(torch._C.DispatchKey.AutocastCUDA)
        with torch._C._ExcludeDispatchKeyGuard(autocast_keyset):
            return _op(*_cast(args, device_type, cast_inputs))

    if device_type == "cuda":
        return lib.impl(opname, kernel, "AutocastCUDA", with_keyset=True)
    else:
        # device_type is "cpu"
        return lib.impl(opname, kernel, "AutocastCPU", with_keyset=True)
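For operators defined through the low-level torch.library.Library route rather than @torch.library.custom_op, the qualname path above is taken and the registration can be tied to a Library via `lib=`. A sketch along the lines of the new low-level test (the `mylib::my_cos` op is illustrative):

    import torch
    from torch import Tensor

    lib = torch.library.Library("mylib", "FRAGMENT")
    lib.define("my_cos(Tensor x) -> Tensor")

    def my_cos(x: Tensor) -> Tensor:
        return torch.cos(x)

    lib.impl("my_cos", my_cos, "CPU")

    # Scope the autocast registration's lifetime to `lib`.
    torch.library.register_autocast("mylib::my_cos", "cpu", torch.float16, lib=lib)

    x = torch.randn(3, dtype=torch.float32)
    with torch.autocast("cpu", dtype=torch.float16):
        y = torch.ops.mylib.my_cos(x)
    assert y.dtype == torch.float16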

def register_fake(
    op: _op_identifier,
    func: Optional[Callable] = None,