Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[triton_op] Skip HOP dispatch when possible (#132822)
The capture_triton decorator returns a function that goes through the triton kernel wrapper HOP. This is useful for make_fx tracing and non-strict export. However, the HOP dispatch is slow (~1ms) and not necessary in certain situations.

This PR skips the HOP dispatch for any capture_triton-wrapped triton kernel that is registered as an implementation of a `@triton_op` custom operator. We do this by creating a new thread-local flag that controls whether a capture_triton-wrapped triton kernel goes through HOP dispatch or not.

Test Plan:
- new test and existing tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132822
Approved by: https://github.com/SherlockNoMad
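For context, here is a minimal sketch of the usage pattern this PR optimizes, adapted from the new test in the diff below. The `mylib::add` operator name and the add kernel come from that test; the explicit `device="cuda"` is an assumption standing in for the test's GPU_TYPE.

import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: "tl.constexpr"):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    y = tl.load(in_ptr1 + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


@torch._library.triton_op("mylib::add", mutates_args=())
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    output = torch.empty_like(x)
    n_elements = output.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    # Under tracing (make_fx, non-strict export) this launch is captured via the
    # triton_kernel_wrapper HOP. In eager mode, the @triton_op backend now runs
    # with the new thread-local flag disabled, so the kernel is launched directly
    # and the ~1ms HOP dispatch overhead is skipped.
    torch._library.capture_triton(add_kernel)[grid](
        x, y, output, n_elements, BLOCK_SIZE=16
    )
    return output


x = torch.randn(3, device="cuda")
y = torch.randn(3, device="cuda")
out = add(x, y)  # eager call: no HOP dispatch on the kernel launch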
@@ -2412,6 +2412,62 @@ class CustomOpTests(torch._inductor.test_case.TestCase):
         expected = torch.empty_like(x)
         self.assertEqual(out, expected)
 
+    @requires_gpu
+    def test_capture_triton_disabled_in_triton_op(self):
+        import triton
+        import triton.language as tl
+
+        @triton.jit
+        def add_kernel(
+            in_ptr0,
+            in_ptr1,
+            out_ptr,
+            n_elements,
+            BLOCK_SIZE: "tl.constexpr",
+        ):
+            pid = tl.program_id(axis=0)
+            block_start = pid * BLOCK_SIZE
+            offsets = block_start + tl.arange(0, BLOCK_SIZE)
+            mask = offsets < n_elements
+            x = tl.load(in_ptr0 + offsets, mask=mask)
+            y = tl.load(in_ptr1 + offsets, mask=mask)
+            output = x + y
+            tl.store(out_ptr + offsets, output, mask=mask)
+
+        add_kernel_decorated = torch._library.capture_triton(add_kernel)
+
+        status = []
+
+        @torch._library.triton_op("mylib::add", mutates_args=())
+        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            import torch._higher_order_ops.triton_kernel_wrap
+
+            status.append(torch._library.triton.is_capture_triton_enabled())
+
+            # capture_triton should return the kernel directly if disabled
+            result = torch._library.capture_triton(add_kernel)
+            self.assertIs(result, add_kernel)
+
+            # Smoke test: check that with capture_triton disabled this still does something
+            output = torch.empty_like(x)
+            output2 = torch.empty_like(x)
+
+            n_elements = output.numel()
+            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+            add_kernel_decorated[grid](x, y, output, n_elements, BLOCK_SIZE=16)
+
+            add_kernel_decorated.run(
+                x, y, output2, n_elements, BLOCK_SIZE=16, grid=grid, warmup=False
+            )
+
+            return output + output2
+
+        x = torch.randn(3, device=GPU_TYPE)
+        y = torch.randn(3, device=GPU_TYPE)
+        z = add(x, y)
+        self.assertEqual(status[-1], False)
+        self.assertEqual(z, (x + y) * 2)
+
     @requires_gpu
     @common_utils.parametrize("dynamic", [False, True])
     @common_utils.parametrize("autotune", [False, True])
@@ -1004,9 +1004,21 @@ class TraceableTritonKernelWrapper:
         return tracing_triton_hopifier_singleton.call_getitem(self, args)
 
     def run(self, *args, **kwargs):
-        return tracing_triton_hopifier_singleton.call_run(self, args, kwargs, None)
+        from torch._library.triton import is_capture_triton_enabled
+
+        if is_capture_triton_enabled():
+            return tracing_triton_hopifier_singleton.call_run(self, args, kwargs, None)
+        else:
+            assert self.kernel is not None
+            return self.kernel.run(*args, **kwargs)
 
     def __call__(self, *args, **kwargs):
-        return tracing_triton_hopifier_singleton.call_triton_kernel(
-            self, args, kwargs, None
-        )
+        from torch._library.triton import is_capture_triton_enabled
+
+        if is_capture_triton_enabled():
+            return tracing_triton_hopifier_singleton.call_triton_kernel(
+                self, args, kwargs, None
+            )
+        else:
+            assert self.kernel is not None
+            return self.kernel[self.grid](*args, **kwargs)
@@ -1,6 +1,9 @@
-from typing import Callable, Iterable, Optional, Union
+import contextlib
+import threading
+from typing import Callable, Generator, Iterable, Optional, Union
 
 from .custom_ops import custom_op
 from .infer_schema import infer_schema
 
 
 def triton_op(
@@ -90,7 +93,18 @@ def triton_op(
     """
 
     def dec(fn: Callable) -> Callable:
-        result = custom_op(name, fn, mutates_args=mutates_args)
+        def backend_fn(*args, **kwargs):  # type: ignore[no-untyped-def]
+            # Optimization: we're passing regular Tensors into the triton kernel, so
+            # no need to go through HOP dispatch
+            with set_capture_triton_enabled(False):
+                return fn(*args, **kwargs)
+
+        result = custom_op(
+            name,
+            backend_fn,
+            mutates_args=mutates_args,
+            schema=infer_schema(fn, mutates_args=mutates_args),
+        )
         from .._subclasses.functional_tensor import FunctionalTensorMode
 
         # We require that the user pass us a function that is make_fx traceable,
@@ -120,6 +134,31 @@ def triton_op(
     return dec(fn)
 
 
+capture_triton_enabled = threading.local()
+capture_triton_enabled_default = True
+
+
+@contextlib.contextmanager
+def set_capture_triton_enabled(enabled: bool) -> Generator[None, None, None]:
+    """If triton kernels annotated with @capture_triton should dispatch via HOP
+    or go straight to the triton kernel execution.
+
+    We have this switch because eager-mode performance of HOP dispatch is slow
+    enough to matter (~1ms) and we know that capture_triton isn't necessary in
+    some situations (eager-mode with regular Tensors)
+    """
+    try:
+        prev = is_capture_triton_enabled()
+        capture_triton_enabled.value = enabled
+        yield
+    finally:
+        capture_triton_enabled.value = prev
+
+
+def is_capture_triton_enabled() -> bool:
+    return getattr(capture_triton_enabled, "value", capture_triton_enabled_default)
+
+
 def capture_triton(triton_kernel: Callable, /) -> Callable:
     """Allows capture of a triton kernel into a graph via make_fx or
     non-strict export (coming soon).
@@ -189,4 +228,6 @@ def capture_triton(triton_kernel: Callable, /) -> Callable:
         raise RuntimeError(
             "capture_triton only works on functions annotated with triton.jit or triton.autotune"
        )
+    if not is_capture_triton_enabled():
+        return triton_kernel
     return TraceableTritonKernelWrapper(triton_kernel, None, None)
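For illustration, a small sketch of how the new thread-local switch behaves. The module path torch._library.triton is taken from the `from torch._library.triton import is_capture_triton_enabled` imports in the diff above; this is the same mechanism the `@triton_op` backend wrapper uses to disable HOP dispatch around the user's implementation.

from torch._library.triton import (
    is_capture_triton_enabled,
    set_capture_triton_enabled,
)

# Default: capture_triton-wrapped kernels dispatch through the HOP.
assert is_capture_triton_enabled()

with set_capture_triton_enabled(False):
    # Inside the context manager the thread-local flag is flipped, so
    # capture_triton(kernel) returns the raw kernel unchanged and calls on an
    # existing TraceableTritonKernelWrapper fall through to kernel.run /
    # kernel[grid] instead of the HOP.
    assert not is_capture_triton_enabled()

# The previous value is restored on exit (the try/finally in the diff).
assert is_capture_triton_enabled()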