[Custom Ops] Add a new API to allow users to register an autocast for the custom op (#145588)

Fixes #137033

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145588
Approved by: https://github.com/zou3519
This commit is contained in:
Yanbo Liang
2025-01-24 20:08:45 -08:00
committed by PyTorch MergeBot
parent f951d216e0
commit ec91b7720f
4 changed files with 323 additions and 0 deletions

View File

@ -22,6 +22,7 @@ from typing_extensions import deprecated, ParamSpec
import torch
import torch._library as _library
from torch._library.custom_ops import (
_cast,
_maybe_get_opdef,
custom_op,
CustomOpDef,
@ -30,6 +31,7 @@ from torch._library.custom_ops import (
from torch._library.infer_schema import infer_schema # noqa: F401
from torch._library.triton import triton_op, wrap_triton
from torch._ops import OpOverload
from torch.types import _dtype
__all__ = [
@ -38,6 +40,7 @@ __all__ = [
"define",
"fallthrough_kernel",
"impl_abstract",
"register_autocast",
"register_fake",
"register_torch_dispatch",
"register_vmap",
@ -823,6 +826,87 @@ def register_kernel(
return _impl(op, device_types, func, lib=lib, disable_dynamo=True)
def register_autocast(
op: _op_identifier,
device_type: str,
cast_inputs: _dtype,
/,
*,
lib: Optional[Library] = None,
):
r"""Register an autocast dispatch rule for this custom op.
Valid `device_type` include: "cpu" and "cuda".
Args:
op (str | OpOverload): The operator to register an autocast dispatch rule to.
device_type(str): Device type to use. 'cuda' or 'cpu'.
The type is the same as the `type` attribute of a :class:`torch.device`.
Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
cast_inputs (:class:`torch.dtype`): When custom op runs in an autocast-enabled region,
casts incoming floating-point Tensors to the target dtype (non-floating-point Tensors
are not affected), then executes custom op with autocast disabled.
lib (Optional[Library]): If provided, the lifetime of this registration
Examples::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> import torch
>>> from torch import Tensor
>>> from torch.library import custom_op
>>>
>>> # Create a custom op that works on cuda
>>> @torch.library.custom_op("mylib::my_sin", mutates_args=())
>>> def my_sin(x: Tensor) -> Tensor:
>>> return torch.sin(x)
>>>
>>> # Register autocast dispatch rule for the cuda device
>>> torch.library.register_autocast("mylib::my_sin", "cuda", torch.float16)
>>>
>>> x = torch.randn(3, dtype=torch.float32, device="cuda")
>>> with torch.autocast("cuda", dtype=torch.float16):
>>> y = torch.ops.mylib.my_sin(x)
>>> assert y.dtype == torch.float16
"""
if not isinstance(
op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
):
raise ValueError(
f"register_autocast(op): got unexpected type for op: {type(op)}"
)
if device_type not in ["cpu", "cuda"]:
raise ValueError(f"Unknown device type: {device_type}")
if isinstance(op, torch._ops.OpOverload):
op = op._name
opdef = _maybe_get_opdef(op)
if opdef is not None:
return opdef.register_autocast(device_type, cast_inputs)
assert isinstance(op, str)
qualname = op
_op = torch._library.utils.lookup_op(qualname)
namespace, opname = torch._library.utils.parse_namespace(qualname)
if lib is None:
lib = Library(namespace, "FRAGMENT")
_keep_alive.append(lib)
def kernel(_, *args, **kwargs):
assert len(kwargs) == 0, "Custom ops do not support kwargs yet."
autocast_keyset = torch._C.DispatchKeySet(
torch._C.DispatchKey.AutocastCPU
) | torch._C.DispatchKeySet(torch._C.DispatchKey.AutocastCUDA)
with torch._C._ExcludeDispatchKeyGuard(autocast_keyset):
return _op(*_cast(args, device_type, cast_inputs))
if device_type == "cuda":
return lib.impl(opname, kernel, "AutocastCUDA", with_keyset=True)
else:
# device_type is "cpu"
return lib.impl(opname, kernel, "AutocastCPU", with_keyset=True)
def register_fake(
op: _op_identifier,
func: Optional[Callable] = None,