[triton_op] Skip HOP dispatch when possible (#132822)

The capture_triton decorator returns a function that goes through the
triton kernel wrapper HOP. This is useful for make_fx tracing and
non-strict export. However, the HOP dispatch is slow (~1ms) and not
necessary in certain situations.
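
For context, a rough sketch of the flow being optimized (illustrative only:
the kernel, shapes, and grid are made up, and the traced representation may
differ across versions):

    import torch
    import triton
    import triton.language as tl
    from torch._library.triton import capture_triton
    from torch.fx.experimental.proxy_tensor import make_fx

    @triton.jit
    def add_kernel(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: "tl.constexpr"):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        tl.store(out_ptr + offsets, x + y, mask=mask)

    # capture_triton wraps the kernel so that launches dispatch through the
    # triton kernel wrapper HOP rather than calling Triton directly.
    wrapped = capture_triton(add_kernel)

    def f(x, y):
        out = torch.empty_like(x)
        n = out.numel()
        grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
        wrapped[grid](x, y, out, n, BLOCK_SIZE=16)
        return out

    x = torch.randn(32, device="cuda")
    # make_fx records the launch as a HOP node in the graph; this is what makes
    # tracing and non-strict export work, but each eager launch also pays the
    # HOP dispatch cost.
    gm = make_fx(f)(x, x)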

This PR skips going through the HOP dispatch for any
capture_triton-wrapped triton kernels that are registered as
implementations to a `@triton_op` custom operator. We do this by
creating a new thread-local flag that controls whether
capture_triton-wrapped triton kernels go through HOP dispatch.
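
As a rough illustration of the pattern this affects (a sketch mirroring the new
test below; the operator name "mylib::add" is an assumption, and add_kernel and
the imports are reused from the sketch above):

    from torch._library.triton import triton_op, capture_triton

    @triton_op("mylib::add", mutates_args=())
    def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        n = out.numel()
        grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
        # Inside the triton_op backend, the new thread-local flag is disabled,
        # so capture_triton returns add_kernel itself and this launch goes
        # straight to Triton instead of through the HOP.
        capture_triton(add_kernel)[grid](x, y, out, n, BLOCK_SIZE=16)
        return out

    z = add(x, x)  # eager call with regular Tensors: no per-launch HOP overhead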

Test Plan:
- new test and existing tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132822
Approved by: https://github.com/SherlockNoMad
Author: rzou
Date: 2024-08-06 18:05:01 -07:00
Committed by: PyTorch MergeBot
Parent: b885ad8fce
Commit: 4ee5547b37
3 changed files with 115 additions and 6 deletions


@@ -2412,6 +2412,62 @@ class CustomOpTests(torch._inductor.test_case.TestCase):
         expected = torch.empty_like(x)
         self.assertEqual(out, expected)
 
+    @requires_gpu
+    def test_capture_triton_disabled_in_triton_op(self):
+        import triton
+        import triton.language as tl
+
+        @triton.jit
+        def add_kernel(
+            in_ptr0,
+            in_ptr1,
+            out_ptr,
+            n_elements,
+            BLOCK_SIZE: "tl.constexpr",
+        ):
+            pid = tl.program_id(axis=0)
+            block_start = pid * BLOCK_SIZE
+            offsets = block_start + tl.arange(0, BLOCK_SIZE)
+            mask = offsets < n_elements
+            x = tl.load(in_ptr0 + offsets, mask=mask)
+            y = tl.load(in_ptr1 + offsets, mask=mask)
+            output = x + y
+            tl.store(out_ptr + offsets, output, mask=mask)
+
+        add_kernel_decorated = torch._library.capture_triton(add_kernel)
+
+        status = []
+
+        @torch._library.triton_op("mylib::add", mutates_args=())
+        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            import torch._higher_order_ops.triton_kernel_wrap
+
+            status.append(torch._library.triton.is_capture_triton_enabled())
+
+            # capture_triton should return the kernel directly if disabled
+            result = torch._library.capture_triton(add_kernel)
+            self.assertIs(result, add_kernel)
+
+            # Smoke test: check that with capture_triton disabled this still does something
+            output = torch.empty_like(x)
+            output2 = torch.empty_like(x)
+            n_elements = output.numel()
+            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+            add_kernel_decorated[grid](x, y, output, n_elements, BLOCK_SIZE=16)
+            add_kernel_decorated.run(
+                x, y, output2, n_elements, BLOCK_SIZE=16, grid=grid, warmup=False
+            )
+            return output + output2
+
+        x = torch.randn(3, device=GPU_TYPE)
+        y = torch.randn(3, device=GPU_TYPE)
+        z = add(x, y)
+        self.assertEqual(status[-1], False)
+        self.assertEqual(z, (x + y) * 2)
+
     @requires_gpu
     @common_utils.parametrize("dynamic", [False, True])
     @common_utils.parametrize("autotune", [False, True])


@@ -1004,9 +1004,21 @@ class TraceableTritonKernelWrapper:
         return tracing_triton_hopifier_singleton.call_getitem(self, args)
 
     def run(self, *args, **kwargs):
-        return tracing_triton_hopifier_singleton.call_run(self, args, kwargs, None)
+        from torch._library.triton import is_capture_triton_enabled
+
+        if is_capture_triton_enabled():
+            return tracing_triton_hopifier_singleton.call_run(self, args, kwargs, None)
+        else:
+            assert self.kernel is not None
+            return self.kernel.run(*args, **kwargs)
 
     def __call__(self, *args, **kwargs):
-        return tracing_triton_hopifier_singleton.call_triton_kernel(
-            self, args, kwargs, None
-        )
+        from torch._library.triton import is_capture_triton_enabled
+
+        if is_capture_triton_enabled():
+            return tracing_triton_hopifier_singleton.call_triton_kernel(
+                self, args, kwargs, None
+            )
+        else:
+            assert self.kernel is not None
+            return self.kernel[self.grid](*args, **kwargs)


@@ -1,6 +1,9 @@
-from typing import Callable, Iterable, Optional, Union
+import contextlib
+import threading
+from typing import Callable, Generator, Iterable, Optional, Union
 
 from .custom_ops import custom_op
 from .infer_schema import infer_schema
 
 
 def triton_op(
@@ -90,7 +93,18 @@ def triton_op(
     """
 
     def dec(fn: Callable) -> Callable:
-        result = custom_op(name, fn, mutates_args=mutates_args)
+        def backend_fn(*args, **kwargs):  # type: ignore[no-untyped-def]
+            # Optimization: we're passing regular Tensors into the triton kernel, so
+            # no need to go through HOP dispatch
+            with set_capture_triton_enabled(False):
+                return fn(*args, **kwargs)
+
+        result = custom_op(
+            name,
+            backend_fn,
+            mutates_args=mutates_args,
+            schema=infer_schema(fn, mutates_args=mutates_args),
+        )
 
         from .._subclasses.functional_tensor import FunctionalTensorMode
 
         # We require that the user pass us a function that is make_fx traceable,
@@ -120,6 +134,31 @@ def triton_op(
     return dec(fn)
 
 
+capture_triton_enabled = threading.local()
+capture_triton_enabled_default = True
+
+
+@contextlib.contextmanager
+def set_capture_triton_enabled(enabled: bool) -> Generator[None, None, None]:
+    """If triton kernels annotated with @capture_triton should dispatch via HOP
+    or go straight to the triton kernel execution.
+
+    We have this switch because eager-mode performance of HOP dispatch is slow
+    enough to matter (~1ms) and we know that capture_triton isn't necessary in
+    some situations (eager-mode with regular Tensors)
+    """
+    try:
+        prev = is_capture_triton_enabled()
+        capture_triton_enabled.value = enabled
+        yield
+    finally:
+        capture_triton_enabled.value = prev
+
+
+def is_capture_triton_enabled() -> bool:
+    return getattr(capture_triton_enabled, "value", capture_triton_enabled_default)
+
+
 def capture_triton(triton_kernel: Callable, /) -> Callable:
     """Allows capture of a triton kernel into a graph via make_fx or
     non-strict export (coming soon).
@@ -189,4 +228,6 @@ def capture_triton(triton_kernel: Callable, /) -> Callable:
         raise RuntimeError(
             "capture_triton only works on functions annotated with triton.jit or triton.autotune"
         )
+    if not is_capture_triton_enabled():
+        return triton_kernel
     return TraceableTritonKernelWrapper(triton_kernel, None, None)