Compare commits

2 Commits

ccb10f2203 Improve torch.library.opcheck and register_autograd docs
Fixes https://github.com/pytorch/pytorch/issues/141618

ghstack-source-id: 37ab54441b1d7d2913d49a88c4a167099ed2cf3b
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141883
2024-12-02 11:21:03 -05:00
f57f77fe93 Make torch.library.triton_op public
We've been using it privately for half a year and everything's been
good. This PR:
1. Makes torch.library.triton_op public
2. Renames capture_triton -> wrap_triton. We got feedback that no one
   knew what "capture triton" does.
3. Makes torch.library.wrap_triton public.
4. Improves some typing.

triton_op is used to construct a Python custom operator that may call 1+
triton kernels. Each of those triton kernels must be annotated with
wrap_triton.
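
For reference, a condensed sketch of that pattern (it mirrors the docstring example added in the diff below; the "mylib::add" name and the kernel are illustrative, and running it requires triton and a CUDA device):

    import torch
    import triton
    from triton import language as tl
    from torch.library import triton_op, wrap_triton

    @triton.jit
    def add_kernel(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        tl.store(out_ptr + offsets, x + y, mask=mask)

    # triton_op constructs the custom operator; every triton kernel call inside
    # it goes through wrap_triton so tracing (make_fx, torch.compile,
    # non-strict export) can see it.
    @triton_op("mylib::add", mutates_args={})
    def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        output = torch.empty_like(x)
        n_elements = output.numel()

        def grid(meta):
            return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

        wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16)
        return output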

Test Plan:
- existing tests

ghstack-source-id: a56bbd836aaf682bc59281170e9c478d795c8886
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141880
2024-12-02 11:20:59 -05:00
5 changed files with 86 additions and 40 deletions

View File

@@ -1002,7 +1002,7 @@ class TritonHOPifier:
     2. In order to capture a user-defined triton kernel while performing
     tracing (via make_fx or non-strict export), a user must annotate their
-    triton kernel with the `capture_triton` decorator. The decorator uses
+    triton kernel with the `wrap_triton` decorator. The decorator uses
     TritonHOPifier to convert calls to the triton kernel into a call
     to the HOP (which can then be traced).
@@ -1299,7 +1299,7 @@ class TritonHOPifier:
 ###############################################################################
-# Helpers for capture_triton API that makes a user-defined triton kernel traceable into
+# Helpers for wrap_triton API that makes a user-defined triton kernel traceable into
 # a graph via make_fx or non-strict export (coming soon)
@@ -1330,7 +1330,7 @@ class TracingTritonHOPifier(TritonHOPifier):
     ) -> Tuple[Union[int, sympy.Expr, SymInt], ...]:
         if not isinstance(grid, collections.abc.Sequence):
             raise RuntimeError(
-                "capture_triton can only handle grids that resolve to Sequence[int]."
+                "wrap_triton can only handle grids that resolve to Sequence[int]."
             )
         # normalize to tuple
         return tuple(grid)
@@ -1389,18 +1389,18 @@ class TraceableTritonKernelWrapper:
         return tracing_triton_hopifier_singleton.call_getitem(self, args)  # type: ignore[return-value]

     def run(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> Any:
-        from torch._library.triton import is_capture_triton_enabled
+        from torch._library.triton import is_wrap_triton_enabled

-        if is_capture_triton_enabled():
+        if is_wrap_triton_enabled():
             return tracing_triton_hopifier_singleton.call_run(self, args, kwargs, None)
         else:
             assert self.kernel is not None
             return self.kernel.run(*args, **kwargs)

     def __call__(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> Any:
-        from torch._library.triton import is_capture_triton_enabled
+        from torch._library.triton import is_wrap_triton_enabled

-        if is_capture_triton_enabled():
+        if is_wrap_triton_enabled():
             return tracing_triton_hopifier_singleton.call_triton_kernel(
                 self, args, kwargs, None
             )

View File

@@ -26,7 +26,7 @@ def custom_op(
     mutates_args: Union[str, Iterable[str]],
     device_types: device_types_t = None,
     schema: Optional[str] = None,
-) -> Callable:
+) -> Any:
     """Wraps a function into custom operator.

     Reasons why you may want to create a custom op include:
@@ -494,6 +494,10 @@ class CustomOpDef:
         not depend on or mutate global state. If you need a non-traceable backward,
         you can make it a separate custom_op that you call inside ``backward_fn``.

+        If you need different autograd behavior on different devices, then we
+        recommend creating two different custom operators, one for each device
+        that needs different behavior, and switching between them at runtime.
+
         Examples:
             >>> import torch
             >>> import numpy as np
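
The per-device guidance added in the hunk above could look roughly like the following sketch (the mylib::mul_cpu / mylib::mul_cuda operators and their implementations are hypothetical):

    # Hypothetical sketch: one custom operator per device, each with its own
    # autograd registration, selected at runtime by the caller.
    import torch
    from torch import Tensor

    @torch.library.custom_op("mylib::mul_cpu", mutates_args=())
    def mul_cpu(x: Tensor, y: float) -> Tensor:
        return torch.from_numpy(x.numpy(force=True) * y)

    @torch.library.custom_op("mylib::mul_cuda", mutates_args=())
    def mul_cuda(x: Tensor, y: float) -> Tensor:
        return x * y  # stand-in for a CUDA-specific implementation

    def setup_context(ctx, inputs, output):
        _, y = inputs
        ctx.y = y

    def backward(ctx, grad):
        return grad * ctx.y, None

    mul_cpu.register_autograd(backward, setup_context=setup_context)
    mul_cuda.register_autograd(backward, setup_context=setup_context)

    def mul(x: Tensor, y: float) -> Tensor:
        # Runtime switch between the two operators based on device.
        return mul_cuda(x, y) if x.is_cuda else mul_cpu(x, y)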

View File

@@ -1,11 +1,14 @@
 import contextlib
 import threading
-from typing import Callable, Generator, Iterable, Optional, Union
+from typing import Any, Callable, Generator, Iterable, Optional, Protocol, Union

+from torch.utils._exposed_in import exposed_in
+
 from .custom_ops import custom_op
 from .infer_schema import infer_schema


+@exposed_in("torch.library")
 def triton_op(
     name: str,
     fn: Optional[Callable] = None,
@@ -16,7 +19,18 @@ def triton_op(
 ) -> Callable:
     """Create a custom operator whose implementation is backed by 1+ triton kernels.

-    Use this instead of :func:`torch.library.custom_op` when the implementation
+    This is a more structured way of using triton kernels with PyTorch.
+    Prefer using triton kernels with no ``torch.library`` custom operator wrappers
+    (like :func:`torch.library.custom_op`, :func:`torch.library.triton_op`) because
+    that is simpler;
+    only use :func:`torch.library.custom_op`/:func:`torch.library.triton_op` if you
+    want to create an operator that behaves like PyTorch built-in operators.
+    For example, you may use a ``torch.library`` wrapper API to define the
+    behavior of the triton kernel when passed a tensor subclass or under
+    a TorchDispatchMode.
+
+    Use :func:`torch.library.triton_op` instead of :func:`torch.library.custom_op`
+    when the implementation
     consists of 1+ triton kernels. :func:`torch.library.custom_op` treats
     custom operators as opaque (:func:`torch.compile` and
     :func:`torch.export.export` will never trace into them), but ``triton_op``
@@ -25,7 +39,7 @@ def triton_op(
     Note that ``fn`` must only consist of calls to PyTorch-understood
     operators and triton kernels. Any triton kernels called inside ``fn``
-    must be wrapped in a call to :func:`torch._library.capture_triton``.
+    must be wrapped in a call to :func:`torch._library.wrap_triton``.

     Args:
         name (str): A name for the custom op that looks like "{namespace}::{name}",
@@ -46,7 +60,7 @@ def triton_op(
         >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
         >>> import torch
-        >>> from torch._library import triton_op, capture_triton
+        >>> from torch._library import triton_op, wrap_triton
         >>>
         >>> import triton
         >>> from triton import language as tl
@@ -76,8 +90,8 @@ def triton_op(
         >>> def grid(meta):
         >>>     return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
         >>>
-        >>> # NB: we need to wrap the triton kernel in a call to capture_triton
-        >>> capture_triton(add_kernel)[grid](x, y, output, n_elements, 16)
+        >>> # NB: we need to wrap the triton kernel in a call to wrap_triton
+        >>> wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16)
         >>> return output
         >>>
         >>> @torch.compile
@@ -96,7 +110,7 @@ def triton_op(
         def backend_fn(*args, **kwargs):  # type: ignore[no-untyped-def]
             # Optimization: we're passing regular Tensors into the triton kernel, so
             # no need to go through HOP dispatch
-            with set_capture_triton_enabled(False):
+            with set_wrap_triton_enabled(False):
                 return fn(*args, **kwargs)

         result = custom_op(
@@ -134,39 +148,55 @@ def triton_op(
     return dec(fn)


-capture_triton_enabled = threading.local()
-capture_triton_enabled_default = True
+wrap_triton_enabled = threading.local()
+wrap_triton_enabled_default = True


 @contextlib.contextmanager
-def set_capture_triton_enabled(enabled: bool) -> Generator[None, None, None]:
-    """If triton kernels annotated with @capture_triton should dispatch via HOP
+def set_wrap_triton_enabled(enabled: bool) -> Generator[None, None, None]:
+    """If triton kernels annotated with @wrap_triton should dispatch via HOP
     or go straight to the triton kernel execution.

     We have this switch because eager-mode performance of HOP dispatch is slow
-    enough to matter (~1ms) and we know that capture_triton isn't necessary in
+    enough to matter (~1ms) and we know that wrap_triton isn't necessary in
     some situations (eager-mode with regular Tensors)
     """
     try:
-        prev = is_capture_triton_enabled()
-        capture_triton_enabled.value = enabled
+        prev = is_wrap_triton_enabled()
+        wrap_triton_enabled.value = enabled
         yield
     finally:
-        capture_triton_enabled.value = prev
+        wrap_triton_enabled.value = prev


-def is_capture_triton_enabled() -> bool:
-    return getattr(capture_triton_enabled, "value", capture_triton_enabled_default)
+def is_wrap_triton_enabled() -> bool:
+    return getattr(wrap_triton_enabled, "value", wrap_triton_enabled_default)


-def capture_triton(triton_kernel: Callable, /) -> Callable:
+class TritonKernelLike(Protocol):
+    def __getitem__(self, grid: Any) -> "TritonKernelLike":
+        pass
+
+    def __call__(self, *args: Any, **kwargs: Any) -> None:
+        pass
+
+
+def capture_triton(triton_kernel: Callable, /) -> TritonKernelLike:
+    """This API has been renamed to wrap_triton"""
+    return wrap_triton(triton_kernel)
+
+
+@exposed_in("torch.library")
+def wrap_triton(triton_kernel: Callable, /) -> TritonKernelLike:
     """Allows capture of a triton kernel into a graph via make_fx or
-    non-strict export (coming soon).
+    non-strict ``torch.export``.

     These technologies perform Dispatcher-based tracing (via
     ``__torch_dispatch__``) and cannot see calls to raw triton kernels.
-    The ``capture_triton`` API returns a new callable that can actually
-    be traced into a graph.
+    The ``wrap_triton`` API wraps a triton kernel into a callable that
+    can actually be traced into a graph.

     Please use this API together with :func:`torch.library.triton_op`.

     Examples:
@@ -175,7 +205,7 @@ def capture_triton(triton_kernel: Callable, /) -> Callable:
         >>> import triton
         >>> from triton import language as tl
         >>> from torch.fx.experimental.proxy_tensor import make_fx
-        >>> from torch._higher_order_ops.triton_kernel_wrap import capture_triton
+        >>> from torch.library import wrap_triton
         >>>
         >>> @triton.jit
         >>> def add_kernel(
@@ -201,7 +231,7 @@ def capture_triton(triton_kernel: Callable, /) -> Callable:
         >>> def grid_fn(meta):
         >>>     return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
         >>>
-        >>> capture_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16)
+        >>> wrap_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16)
         >>> return output
         >>>
         >>> x = torch.randn(3, device="cuda")
@@ -226,8 +256,8 @@ def capture_triton(triton_kernel: Callable, /) -> Callable:
     if not isinstance(triton_kernel, (JITFunction, Autotuner)):
         raise RuntimeError(
-            "capture_triton only works on functions annotated with triton.jit or triton.autotune"
+            "wrap_triton only works on functions annotated with triton.jit or triton.autotune"
         )
-    if not is_capture_triton_enabled():
+    if not is_wrap_triton_enabled():
         return triton_kernel
     return TraceableTritonKernelWrapper(triton_kernel, None, None)

View File

@@ -18,6 +18,7 @@ from torch._library.custom_ops import (
     device_types_t,
 )
 from torch._library.infer_schema import infer_schema  # noqa: F401
+from torch._library.triton import triton_op, wrap_triton
 from torch._ops import OpOverload
@@ -32,6 +33,8 @@ __all__ = [
     "register_vmap",
     "get_ctx",
     "custom_op",
+    "triton_op",
+    "wrap_triton",
     "infer_schema",
 ]
@@ -868,6 +871,10 @@ def register_autograd(
     not depend on or mutate global state. If you need a non-traceable backward,
     you can make it a separate custom_op that you call inside ``backward_fn``.

+    If you need different autograd behavior on different devices, then we
+    recommend creating two different custom operators, one for each device
+    that needs different behavior, and switching between them at runtime.
+
     Examples:
         >>> import torch
         >>> import numpy as np
@@ -1318,12 +1325,12 @@ def opcheck(
         >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
         >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
-        >>> def numpy_add(x: Tensor, y: float) -> Tensor:
+        >>> def numpy_mul(x: Tensor, y: float) -> Tensor:
         >>>     x_np = x.numpy(force=True)
         >>>     z_np = x_np + y
         >>>     return torch.from_numpy(z_np).to(x.device)
         >>>
-        >>> @numpy_sin.register_fake
+        >>> @numpy_mul.register_fake
         >>> def _(x, y):
         >>>     return torch.empty_like(x)
         >>>
@@ -1334,7 +1341,7 @@ def opcheck(
         >>> def backward(ctx, grad):
         >>>     return grad * ctx.y, None
         >>>
-        >>> numpy_sin.register_autograd(backward, setup_context=setup_context)
+        >>> numpy_mul.register_autograd(backward, setup_context=setup_context)
         >>>
         >>> sample_inputs = [
         >>>     (torch.randn(3), 3.14),
@@ -1344,7 +1351,7 @@ def opcheck(
         >>> ]
         >>>
         >>> for args in sample_inputs:
-        >>>     torch.library.opcheck(foo, args)
+        >>>     torch.library.opcheck(numpy_mul, args)
     """
     import torch.testing._internal.optests as optests

View File

@@ -1,4 +1,9 @@
 # mypy: allow-untyped-defs
+from typing import Callable, TypeVar
+
+F = TypeVar("F")
+
+
 # Allows one to expose an API in a private submodule publicly as per the definition
 # in PyTorch's public api policy.
 #
@@ -7,8 +12,8 @@
 # may not be very robust because it's not clear what __module__ is used for.
 # However, both numpy and jax overwrite the __module__ attribute of their APIs
 # without problem, so it seems fine.
-def exposed_in(module):
-    def wrapper(fn):
+def exposed_in(module: str) -> Callable[[F], F]:
+    def wrapper(fn: F) -> F:
         fn.__module__ = module
         return fn
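
For illustration, a minimal sketch of what the decorator does (the function f is hypothetical; the import path is the one used in the diff above):

    from torch.utils._exposed_in import exposed_in

    # exposed_in only rewrites __module__, so the decorated function reports
    # the public module rather than the private one that defines it.
    @exposed_in("torch.library")
    def f() -> None:
        pass

    assert f.__module__ == "torch.library"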