mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Custom Ops] Add a new API to allow users to register an autocast for the custom op (#145588)
Fixes #137033 Pull Request resolved: https://github.com/pytorch/pytorch/pull/145588 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
f951d216e0
commit
ec91b7720f
@ -22,6 +22,7 @@ from typing_extensions import deprecated, ParamSpec
|
||||
import torch
|
||||
import torch._library as _library
|
||||
from torch._library.custom_ops import (
|
||||
_cast,
|
||||
_maybe_get_opdef,
|
||||
custom_op,
|
||||
CustomOpDef,
|
||||
@ -30,6 +31,7 @@ from torch._library.custom_ops import (
|
||||
from torch._library.infer_schema import infer_schema # noqa: F401
|
||||
from torch._library.triton import triton_op, wrap_triton
|
||||
from torch._ops import OpOverload
|
||||
from torch.types import _dtype
|
||||
|
||||
|
||||
__all__ = [
|
||||
@ -38,6 +40,7 @@ __all__ = [
|
||||
"define",
|
||||
"fallthrough_kernel",
|
||||
"impl_abstract",
|
||||
"register_autocast",
|
||||
"register_fake",
|
||||
"register_torch_dispatch",
|
||||
"register_vmap",
|
||||
@ -823,6 +826,87 @@ def register_kernel(
|
||||
return _impl(op, device_types, func, lib=lib, disable_dynamo=True)
|
||||
|
||||
|
||||
def register_autocast(
|
||||
op: _op_identifier,
|
||||
device_type: str,
|
||||
cast_inputs: _dtype,
|
||||
/,
|
||||
*,
|
||||
lib: Optional[Library] = None,
|
||||
):
|
||||
r"""Register an autocast dispatch rule for this custom op.
|
||||
|
||||
Valid `device_type` include: "cpu" and "cuda".
|
||||
|
||||
Args:
|
||||
op (str | OpOverload): The operator to register an autocast dispatch rule to.
|
||||
device_type(str): Device type to use. 'cuda' or 'cpu'.
|
||||
The type is the same as the `type` attribute of a :class:`torch.device`.
|
||||
Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
|
||||
cast_inputs (:class:`torch.dtype`): When custom op runs in an autocast-enabled region,
|
||||
casts incoming floating-point Tensors to the target dtype (non-floating-point Tensors
|
||||
are not affected), then executes custom op with autocast disabled.
|
||||
lib (Optional[Library]): If provided, the lifetime of this registration
|
||||
|
||||
Examples::
|
||||
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
|
||||
>>> import torch
|
||||
>>> from torch import Tensor
|
||||
>>> from torch.library import custom_op
|
||||
>>>
|
||||
>>> # Create a custom op that works on cuda
|
||||
>>> @torch.library.custom_op("mylib::my_sin", mutates_args=())
|
||||
>>> def my_sin(x: Tensor) -> Tensor:
|
||||
>>> return torch.sin(x)
|
||||
>>>
|
||||
>>> # Register autocast dispatch rule for the cuda device
|
||||
>>> torch.library.register_autocast("mylib::my_sin", "cuda", torch.float16)
|
||||
>>>
|
||||
>>> x = torch.randn(3, dtype=torch.float32, device="cuda")
|
||||
>>> with torch.autocast("cuda", dtype=torch.float16):
|
||||
>>> y = torch.ops.mylib.my_sin(x)
|
||||
>>> assert y.dtype == torch.float16
|
||||
|
||||
"""
|
||||
if not isinstance(
|
||||
op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
|
||||
):
|
||||
raise ValueError(
|
||||
f"register_autocast(op): got unexpected type for op: {type(op)}"
|
||||
)
|
||||
if device_type not in ["cpu", "cuda"]:
|
||||
raise ValueError(f"Unknown device type: {device_type}")
|
||||
|
||||
if isinstance(op, torch._ops.OpOverload):
|
||||
op = op._name
|
||||
opdef = _maybe_get_opdef(op)
|
||||
if opdef is not None:
|
||||
return opdef.register_autocast(device_type, cast_inputs)
|
||||
|
||||
assert isinstance(op, str)
|
||||
qualname = op
|
||||
_op = torch._library.utils.lookup_op(qualname)
|
||||
|
||||
namespace, opname = torch._library.utils.parse_namespace(qualname)
|
||||
if lib is None:
|
||||
lib = Library(namespace, "FRAGMENT")
|
||||
_keep_alive.append(lib)
|
||||
|
||||
def kernel(_, *args, **kwargs):
|
||||
assert len(kwargs) == 0, "Custom ops do not support kwargs yet."
|
||||
autocast_keyset = torch._C.DispatchKeySet(
|
||||
torch._C.DispatchKey.AutocastCPU
|
||||
) | torch._C.DispatchKeySet(torch._C.DispatchKey.AutocastCUDA)
|
||||
with torch._C._ExcludeDispatchKeyGuard(autocast_keyset):
|
||||
return _op(*_cast(args, device_type, cast_inputs))
|
||||
|
||||
if device_type == "cuda":
|
||||
return lib.impl(opname, kernel, "AutocastCUDA", with_keyset=True)
|
||||
else:
|
||||
# device_type is "cpu"
|
||||
return lib.impl(opname, kernel, "AutocastCPU", with_keyset=True)
|
||||
|
||||
|
||||
def register_fake(
|
||||
op: _op_identifier,
|
||||
func: Optional[Callable] = None,
|
||||
|
Reference in New Issue
Block a user