Make user defined Triton kernels serializable for fx_graph_runnable (#160002)

Resolves https://github.com/pytorch/pytorch/issues/153475, where `fx_graph_runnable` did not work with user-defined Triton kernels.
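
As a minimal sketch of the flow this fixes (the kernel below is illustrative, not from this PR's tests, and assumes a CUDA device with Triton installed): compiling a function that launches a user-defined Triton kernel with the `fx_graph_runnable` logging artifact enabled (e.g. `TORCH_LOGS="fx_graph_runnable"`) should now emit a repro script that runs standalone.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Element-wise add over a 1D grid; each program handles one block.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = out.numel()
    add_kernel[(triton.cdiv(n_elements, 1024),)](x, y, out, n_elements, BLOCK_SIZE=1024)
    return out


# With this fix, the fx_graph_runnable payload logged for this compile
# re-defines add_kernel, so the dumped script can be executed directly.
x = torch.randn(4096, device="cuda")
torch.compile(add)(x, x)
```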

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160002
Approved by: https://github.com/eellison
Author: PaulZhang12
Date: 2025-08-08 13:07:09 -07:00
Committed by: PyTorch MergeBot
Parent: fb887c3bb5
Commit: 4183d4ff3d

2 changed files with 154 additions and 0 deletions


@@ -11,12 +11,65 @@ import torch.distributed as dist
from torch._inductor.codecache import WritableTempFile
from torch._inductor.test_case import TestCase
from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE
from torch.utils._triton import has_triton

if torch.distributed.is_available():
    from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
    from torch.testing._internal.distributed.fake_pg import FakeStore

if has_triton():
    import triton
    import triton.language as tl

    def init_to_zero(name):
        return lambda nargs: nargs[name].zero_()

    @triton.jit
    def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        output = x + y
        tl.atomic_add(output_ptr + offsets, output, mask=mask)

    @triton.autotune(
        configs=[
            triton.Config(
                {"BLOCK_SIZE": 1024},
                num_warps=4,
                num_stages=2,
                pre_hook=init_to_zero("output_ptr"),
            )
        ],
        pre_hook=init_to_zero("output_ptr"),
        post_hook=init_to_zero("output_ptr"),
        key=["n_elements"],
    )
    @triton.jit
    def add_kernel_autotune(
        x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        output = x + y
        tl.atomic_add(output_ptr + offsets, output, mask=mask)

from torch.testing._internal.inductor_utils import GPU_TYPE
from torch.testing._internal.triton_utils import requires_gpu
class FxGraphRunnableArtifactFilter(logging.Filter):
    def filter(self, record):
@@ -100,6 +153,41 @@ class FxGraphRunnableTest(TestCase):
        torch.compile(f)(torch.randn(4))
        self._exec_and_verify_payload()

    @unittest.skipUnless(has_triton(), "Triton not available")
    def test_user_defined_triton_kernel_autotune(self):
        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            output = torch.ones(x.shape, device=x.device, dtype=x.dtype)
            n_elements = output.numel()

            def grid(
                meta,
            ):
                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

            add_kernel_autotune[grid](x, y, output, n_elements)
            return output

        x = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16)
        y = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16)

        torch.compile(add)(x, y)
        self._exec_and_verify_payload()

    @unittest.skipUnless(has_triton(), "Triton not available")
    @requires_gpu
    def test_user_defined_triton_kernel(self):
        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            output = torch.ones(x.shape, device=x.device, dtype=x.dtype)
            n_elements = x.numel()
            add_kernel[n_elements,](x, y, output, n_elements, BLOCK_SIZE=4)
            return output

        x = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16)
        y = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16)

        torch.compile(add)(x, y)
        self._exec_and_verify_payload()

    def test_two_inputs_matmul(self):
        def f(a, b):
            return (a @ b).relu()


@@ -34,6 +34,21 @@ from tempfile import TemporaryFile
from typing import Any, Callable, IO, Optional, TYPE_CHECKING, Union
from typing_extensions import Unpack

from torch.utils._triton import has_triton

if has_triton():
    from triton.runtime.autotuner import Autotuner
    from triton.runtime.jit import JITFunction
else:
    # Stub fallbacks so the isinstance() checks below are safe without Triton.
    class Autotuner:  # type: ignore[no-redef]
        pass

    class JITFunction:  # type: ignore[no-redef]
        pass

import torch
import torch.fx as fx
import torch.nn as nn
@@ -58,6 +73,7 @@ from torch._dynamo.debug_utils import (
)
from torch._dynamo.utils import clone_inputs, counters, same
from torch._environment import is_fbcode
from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table
from torch._inductor.cpp_builder import normalize_path_separator
from torch._inductor.output_code import OutputCode
from torch._library.fake_class_registry import FakeScriptObject
@@ -302,6 +318,16 @@ from torch.testing._internal.distributed.fake_pg import FakeStore
"""
).strip()
triton_imports = ""
if len(kernel_side_table.id_to_kernel) > 0:
triton_imports = textwrap.dedent(
"""
import triton
import triton.language as tl
"""
).strip()
model_str = textwrap.dedent(
f"""
{generate_env_vars_string(stable_output=stable_output)}
@@ -312,6 +338,7 @@ from torch._dynamo.testing import rand_strided
        from math import inf
        import torch._inductor.inductor_prims
        {distributed_imports}
        {triton_imports}

        {generate_config_string(stable_output=stable_output)}
@@ -330,6 +357,45 @@ isolate_fails_code_str = None
    model_str += f"# torch git version: {torch.version.git_version}\n\n\n"
    model_str += _cuda_system_info_comment()

    kernel_side_table_prefix = (
        "torch._higher_order_ops.triton_kernel_wrap.kernel_side_table"
    )
    # Track which grid entry corresponds to the best config
    for id in kernel_side_table.id_to_kernel:
        kernel = kernel_side_table.get_kernel(id)
        if isinstance(kernel, Autotuner):
            # Rebuild the @triton.autotune decorator from the captured configs
            # (hooks and the original key are not serialized).
            config_strs = []
            for kernel_config in kernel.configs:
                config_strs.append(f"""triton.Config(
                    {str(kernel_config.kwargs)},
                    num_warps={kernel_config.num_warps},
                    num_stages={kernel_config.num_stages},
                )""")
            config_str = ",".join(config_strs)

            model_str += textwrap.dedent(f"""
            @triton.autotune(
                configs=[
                    {config_str}
                ],
                key=[]
            )
            """).strip()

        model_str += "\n@triton.jit\n"
        # Splice in the kernel's original source and re-register it in the
        # side table so the serialized graph can resolve it when re-run.
        src_code = kernel.src if isinstance(kernel, JITFunction) else kernel.fn.src
        fn_name = (
            kernel._fn_name if isinstance(kernel, JITFunction) else kernel.fn._fn_name
        )
        fn_name = fn_name.split(".")[-1]

        model_str += src_code
        model_str += "\n"
        model_str += f"{kernel_side_table_prefix}.add_kernel({fn_name})\n"

    if len(kernel_side_table.constant_args) > 0:
        model_str += f"{kernel_side_table_prefix}.constant_args={kernel_side_table.constant_args}\n"

    model_str += NNModuleToString.convert(gm)

    writer = InputWriter(save_dir, stable_hash=stable_hash)
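
For reference, here is a hypothetical excerpt of what the loop above would append to the generated script for the autotuned test kernel (formatting is illustrative; the config kwargs, `num_warps`, and `num_stages` come from the captured `triton.Config`):

```python
@triton.autotune(
    configs=[
        triton.Config(
            {'BLOCK_SIZE': 1024},
            num_warps=4,
            num_stages=2,
        )
    ],
    key=[]
)
@triton.jit
def add_kernel_autotune(
    x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr
):
    ...  # the kernel's original source is spliced in verbatim

torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(add_kernel_autotune)
```

Note that the rebuilt decorator carries only each config's kwargs, `num_warps`, and `num_stages`; `pre_hook`/`post_hook` and the original `key` are not reproduced, so the repro script autotunes with `key=[]`.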