[Custom Ops] Add a new API to allow users to register an autocast for the custom op (#145588)

Fixes #137033

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145588
Approved by: https://github.com/zou3519
This commit is contained in:
Yanbo Liang
2025-01-24 20:08:45 -08:00
committed by PyTorch MergeBot
parent f951d216e0
commit ec91b7720f
4 changed files with 323 additions and 0 deletions

View File

@ -42,6 +42,7 @@ for any operators (they may have been created using :func:`torch.library.custom_
via PyTorch's C++ operator registration APIs).
.. autofunction:: register_kernel
.. autofunction:: register_autocast
.. autofunction:: register_autograd
.. autofunction:: register_fake
.. autofunction:: register_vmap

View File

@ -3033,6 +3033,150 @@ class TestCustomOpAPI(TestCase):
self.assertEqual(z, x + y)
self.assertTrue(called)
@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
@unittest.skipIf(not TEST_CUDA, "requires CUDA")
def test_library_register_autocast(self):
for device in ["cuda", "cpu"]:
for mode in ["function", "qualname", "opoverload"]:
@torch.library.custom_op("mylib::my_sin", mutates_args=())
def my_sin(x: Tensor) -> Tensor:
return torch.sin(x)
if mode == "function":
torch.library.register_autocast(my_sin, device, torch.float16)
elif mode == "qualname":
torch.library.register_autocast(
"mylib::my_sin", device, torch.float16
)
elif mode == "opoverload":
torch.library.register_autocast(
torch.ops.mylib.my_sin.default, device, torch.float16
)
x = torch.randn(3, dtype=torch.float32, device=device)
with torch.autocast(device, dtype=torch.float16):
y = torch.ops.mylib.my_sin(x)
self.assertEqual(y.dtype, torch.float16)
@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
@unittest.skipIf(not TEST_CUDA, "requires CUDA")
def test_library_register_autocast_low_level(self):
for device in ["cuda", "cpu"]:
for mode in ["qualname", "opoverload"]:
with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib:
lib.define("my_sin(Tensor x) -> Tensor")
def my_sin(x: Tensor) -> Tensor:
return torch.sin(x)
lib.impl("my_sin", my_sin, device.upper())
if mode == "qualname":
torch.library.register_autocast(
"_torch_testing::my_sin", device, torch.float16, lib=lib
)
elif mode == "opoverload":
torch.library.register_autocast(
torch.ops._torch_testing.my_sin.default,
device,
torch.float16,
lib=lib,
)
x = torch.randn(3, dtype=torch.float32, device=device)
with torch.autocast(device, dtype=torch.float16):
y = torch.ops._torch_testing.my_sin(x)
self.assertEqual(y.dtype, torch.float16)
@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
@unittest.skipIf(not TEST_CUDA, "requires CUDA")
def test_library_register_autocast_list_input(self):
for device in ["cuda", "cpu"]:
for mode in ["function", "qualname", "opoverload"]:
@torch.library.custom_op("mylib::my_add_sin", mutates_args=())
def my_add_sin(x: List[Tensor]) -> Tensor:
return torch.sin(x[0] + x[1])
if mode == "function":
torch.library.register_autocast(my_add_sin, device, torch.float16)
elif mode == "qualname":
torch.library.register_autocast(
"mylib::my_add_sin", device, torch.float16
)
elif mode == "opoverload":
torch.library.register_autocast(
torch.ops.mylib.my_add_sin.default, device, torch.float16
)
lst = [
torch.randn(3, dtype=torch.float32, device=device) for _ in range(2)
]
with torch.autocast(device, dtype=torch.float16):
y = torch.ops.mylib.my_add_sin(lst)
self.assertEqual(y.dtype, torch.float16)
@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
@unittest.skipIf(not TEST_CUDA, "requires CUDA")
def test_library_register_autocast_multiple_times(self):
for device in ["cuda", "cpu"]:
@torch.library.custom_op("mylib::my_sin", mutates_args=())
def my_sin(x: Tensor) -> Tensor:
return torch.sin(x)
torch.library.register_autocast(my_sin, device, torch.float16)
x = torch.randn(3, dtype=torch.float32, device=device)
with torch.autocast(device, dtype=torch.float16):
y1 = my_sin(x)
self.assertEqual(y1.dtype, torch.float16)
# Ensure calling register_autocast multiple times does not error out.
torch.library.register_autocast(my_sin, device, torch.float16)
with torch.autocast(device, dtype=torch.float16):
y2 = my_sin(x)
self.assertEqual(y2.dtype, torch.float16)
@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
@unittest.skipIf(not TEST_CUDA, "requires CUDA")
def test_library_register_autocast_multiple_times_different_devices(self):
@torch.library.custom_op("mylib::my_sin", mutates_args=())
def my_sin(x: Tensor) -> Tensor:
return torch.sin(x)
# Register autocast for CUDA
torch.library.register_autocast(my_sin, "cuda", torch.float16)
x1 = torch.randn(3, dtype=torch.float32, device="cuda")
with torch.autocast("cuda", dtype=torch.float16):
y1 = my_sin(x1)
self.assertEqual(y1.dtype, torch.float16)
# Register autocast for CPU
torch.library.register_autocast(my_sin, "cpu", torch.float16)
x2 = torch.randn(3, dtype=torch.float32, device="cpu")
with torch.autocast("cpu", dtype=torch.float16):
y2 = my_sin(x2)
self.assertEqual(y2.dtype, torch.float16)
# Register CUDA autocast for the second time
torch.library.register_autocast(my_sin, "cuda", torch.float16)
with torch.autocast("cuda", dtype=torch.float16):
y3 = my_sin(x1)
self.assertEqual(y3.dtype, torch.float16)
# Register CPU autocast for the second time
torch.library.register_autocast(my_sin, "cpu", torch.float16)
with torch.autocast("cpu", dtype=torch.float16):
y4 = my_sin(x2)
self.assertEqual(y4.dtype, torch.float16)
@skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
def test_library_register_autograd(self):
for mode in ["function", "qualname", "opoverload"]:

View File

@ -1,4 +1,5 @@
# mypy: allow-untyped-defs
import collections
import inspect
import logging
import weakref
@ -8,6 +9,7 @@ from typing import Any, Callable, Literal, Optional, overload, Union
import torch
from torch import _C, _ops, Tensor
from torch.types import _dtype
from torch.utils._exposed_in import exposed_in
from . import autograd, utils
@ -195,6 +197,8 @@ class CustomOpDef:
self._backward_fn: Optional[Callable] = None
self._torch_dispatch_fns: dict[type, Callable] = {}
self._vmap_fn: Optional[Callable] = None
self._autocast_cuda_dtype: Optional[_dtype] = None
self._autocast_cpu_dtype: Optional[_dtype] = None
self._lib = get_library_allowing_overwrite(self._namespace, self._name)
self._register_to_dispatcher()
@ -763,6 +767,96 @@ class CustomOpDef:
else:
return register(func)
def register_autocast(
self,
device_type: str,
cast_inputs: _dtype,
):
r"""Register an autocast dispatch rule for this custom op.
Valid `device_type` include: "cpu" and "cuda".
Args:
op (str | OpOverload): The operator to register an autocast dispatch rule to.
device_type(str): Device type to use. 'cuda' or 'cpu'.
The type is the same as the `type` attribute of a :class:`torch.device`.
Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
cast_inputs (:class:`torch.dtype`): When custom op runs in an autocast-enabled region,
casts incoming floating-point Tensors to the target dtype (non-floating-point Tensors
are not affected), then executes custom op with autocast disabled.
lib (Optional[Library]): If provided, the lifetime of this registration
Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> import torch
>>> from torch import Tensor
>>> from torch.library import custom_op
>>>
>>> # Create a custom op that works on cuda
>>> @torch.library.custom_op("mylib::my_sin", mutates_args=())
>>> def my_sin(x: Tensor) -> Tensor:
>>> return torch.sin(x)
>>>
>>> # Register autocast dispatch rule for the cuda device
>>> torch.library.register_autocast("mylib::my_sin", "cuda", torch.float16)
>>>
>>> x = torch.randn(3, dtype=torch.float32, device="cuda")
>>> with torch.autocast("cuda", dtype=torch.float16):
>>> y = torch.ops.mylib.my_sin(x)
>>> assert y.dtype == torch.float16
"""
if not isinstance(device_type, str):
raise ValueError(
f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
)
if device_type not in ["cpu", "cuda"]:
raise ValueError(f"Unknown device type: {device_type}")
need_register_cuda = self._autocast_cuda_dtype is None
need_register_cpu = self._autocast_cpu_dtype is None
if device_type == "cuda":
self._autocast_cuda_dtype = cast_inputs
else:
self._autocast_cpu_dtype = cast_inputs
def kernel(_, *args, **kwargs):
assert len(kwargs) == 0, "Custom ops do not support kwargs yet."
autocast_keyset = torch._C.DispatchKeySet(
torch._C.DispatchKey.AutocastCPU
) | torch._C.DispatchKeySet(torch._C.DispatchKey.AutocastCUDA)
with torch._C._ExcludeDispatchKeyGuard(autocast_keyset):
return self._opoverload(*_cast(args, device_type, cast_inputs))
if need_register_cuda and self._autocast_cuda_dtype:
self._lib.impl(self._name, kernel, "AutocastCUDA", with_keyset=True)
elif need_register_cpu and self._autocast_cpu_dtype:
self._lib.impl(self._name, kernel, "AutocastCPU", with_keyset=True)
return kernel
# TODO: Merge this function with torch.amp.autocast_mode._cast, and refactor it
# into a utility function once custom ops support arbitrary input types.
def _cast(value, device_type: str, dtype: _dtype):
if isinstance(value, torch.Tensor):
is_eligible = (
value.is_floating_point()
and value.device.type == device_type
and (value.dtype is not torch.float64)
)
return value.to(dtype) if is_eligible else value
elif isinstance(value, (str, bytes)):
return value
elif isinstance(value, collections.abc.Iterable):
iterable = (_cast(v, device_type, dtype) for v in value)
if isinstance(value, (list, tuple)):
return type(value)(iterable)
else:
return iterable
else:
return value
def increment_version(val: Any) -> None:
if isinstance(val, Tensor):

View File

@ -22,6 +22,7 @@ from typing_extensions import deprecated, ParamSpec
import torch
import torch._library as _library
from torch._library.custom_ops import (
_cast,
_maybe_get_opdef,
custom_op,
CustomOpDef,
@ -30,6 +31,7 @@ from torch._library.custom_ops import (
from torch._library.infer_schema import infer_schema # noqa: F401
from torch._library.triton import triton_op, wrap_triton
from torch._ops import OpOverload
from torch.types import _dtype
__all__ = [
@ -38,6 +40,7 @@ __all__ = [
"define",
"fallthrough_kernel",
"impl_abstract",
"register_autocast",
"register_fake",
"register_torch_dispatch",
"register_vmap",
@ -823,6 +826,87 @@ def register_kernel(
return _impl(op, device_types, func, lib=lib, disable_dynamo=True)
def register_autocast(
op: _op_identifier,
device_type: str,
cast_inputs: _dtype,
/,
*,
lib: Optional[Library] = None,
):
r"""Register an autocast dispatch rule for this custom op.
Valid `device_type` include: "cpu" and "cuda".
Args:
op (str | OpOverload): The operator to register an autocast dispatch rule to.
device_type(str): Device type to use. 'cuda' or 'cpu'.
The type is the same as the `type` attribute of a :class:`torch.device`.
Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
cast_inputs (:class:`torch.dtype`): When custom op runs in an autocast-enabled region,
casts incoming floating-point Tensors to the target dtype (non-floating-point Tensors
are not affected), then executes custom op with autocast disabled.
lib (Optional[Library]): If provided, the lifetime of this registration
Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> import torch
>>> from torch import Tensor
>>> from torch.library import custom_op
>>>
>>> # Create a custom op that works on cuda
>>> @torch.library.custom_op("mylib::my_sin", mutates_args=())
>>> def my_sin(x: Tensor) -> Tensor:
>>> return torch.sin(x)
>>>
>>> # Register autocast dispatch rule for the cuda device
>>> torch.library.register_autocast("mylib::my_sin", "cuda", torch.float16)
>>>
>>> x = torch.randn(3, dtype=torch.float32, device="cuda")
>>> with torch.autocast("cuda", dtype=torch.float16):
>>> y = torch.ops.mylib.my_sin(x)
>>> assert y.dtype == torch.float16
"""
if not isinstance(
op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
):
raise ValueError(
f"register_autocast(op): got unexpected type for op: {type(op)}"
)
if device_type not in ["cpu", "cuda"]:
raise ValueError(f"Unknown device type: {device_type}")
if isinstance(op, torch._ops.OpOverload):
op = op._name
opdef = _maybe_get_opdef(op)
if opdef is not None:
return opdef.register_autocast(device_type, cast_inputs)
assert isinstance(op, str)
qualname = op
_op = torch._library.utils.lookup_op(qualname)
namespace, opname = torch._library.utils.parse_namespace(qualname)
if lib is None:
lib = Library(namespace, "FRAGMENT")
_keep_alive.append(lib)
def kernel(_, *args, **kwargs):
assert len(kwargs) == 0, "Custom ops do not support kwargs yet."
autocast_keyset = torch._C.DispatchKeySet(
torch._C.DispatchKey.AutocastCPU
) | torch._C.DispatchKeySet(torch._C.DispatchKey.AutocastCUDA)
with torch._C._ExcludeDispatchKeyGuard(autocast_keyset):
return _op(*_cast(args, device_type, cast_inputs))
if device_type == "cuda":
return lib.impl(opname, kernel, "AutocastCUDA", with_keyset=True)
else:
# device_type is "cpu"
return lib.impl(opname, kernel, "AutocastCPU", with_keyset=True)
def register_fake(
op: _op_identifier,
func: Optional[Callable] = None,