Make user defined Triton kernels serializable for fx_graph_runnable (#160002)
Resolves https://github.com/pytorch/pytorch/issues/153475, where `fx_graph_runnable` did not work with user-defined Triton kernels.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160002
Approved by: https://github.com/eellison
Committed by: PyTorch MergeBot
Parent: fb887c3bb5
Commit: 4183d4ff3d
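For context, here is a minimal sketch of the scenario this PR fixes: a `torch.compile`d function that launches a user-defined Triton kernel while the repro dump is enabled, so the compiler emits a standalone `fx_graph_runnable` script. The kernel, the `add` wrapper, and the `TORCH_LOGS` setting are illustrative assumptions modeled on the new tests below, not part of the diff; a CUDA-capable GPU with Triton installed is assumed.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Plain elementwise add; mirrors the kernels added to the test file below.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = out.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=1024)
    return out


if __name__ == "__main__":
    # Run with TORCH_LOGS="fx_graph_runnable" (an assumption based on the tests below)
    # so the compiler emits a standalone repro script. Before this change the emitted
    # script did not re-declare user-defined Triton kernels and could not run on its
    # own; with this change the kernel source is serialized and re-registered via
    # torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.
    x = torch.randn(4096, device="cuda")
    y = torch.randn(4096, device="cuda")
    torch.compile(add)(x, y)
```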
@@ -11,12 +11,65 @@ import torch.distributed as dist
from torch._inductor.codecache import WritableTempFile
from torch._inductor.test_case import TestCase
from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE
from torch.utils._triton import has_triton


if torch.distributed.is_available():
    from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
    from torch.testing._internal.distributed.fake_pg import FakeStore

if has_triton():
    import triton
    import triton.language as tl

    def init_to_zero(name):
        return lambda nargs: nargs[name].zero_()

    @triton.jit
    def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(axis=0)

        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements

        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        output = x + y
        tl.atomic_add(output_ptr + offsets, output, mask=mask)

    @triton.autotune(
        configs=[
            triton.Config(
                {"BLOCK_SIZE": 1024},
                num_warps=4,
                num_stages=2,
                pre_hook=init_to_zero("output_ptr"),
            )
        ],
        pre_hook=init_to_zero("output_ptr"),
        post_hook=init_to_zero("output_ptr"),
        key=["n_elements"],
    )
    @triton.jit
    def add_kernel_autotune(
        x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr
    ):
        pid = tl.program_id(axis=0)

        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements

        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        output = x + y
        tl.atomic_add(output_ptr + offsets, output, mask=mask)


from torch.testing._internal.inductor_utils import GPU_TYPE
from torch.testing._internal.triton_utils import requires_gpu


class FxGraphRunnableArtifactFilter(logging.Filter):
    def filter(self, record):
@@ -100,6 +153,41 @@ class FxGraphRunnableTest(TestCase):
        torch.compile(f)(torch.randn(4))
        self._exec_and_verify_payload()

    @unittest.skipUnless(has_triton(), "Triton not available")
    def test_user_defined_triton_kernel_autotune(self):
        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            output = torch.ones(x.shape, device=x.device, dtype=x.dtype)
            n_elements = output.numel()

            def grid(
                meta,
            ):
                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

            add_kernel_autotune[grid](x, y, output, n_elements)
            return output

        x = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16)
        y = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16)

        torch.compile(add)(x, y)
        self._exec_and_verify_payload()

    @unittest.skipUnless(has_triton(), "Triton not available")
    @requires_gpu
    def test_user_defined_triton_kernel(self):
        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            output = torch.ones(x.shape, device=x.device, dtype=x.dtype)
            n_elements = x.numel()
            add_kernel[n_elements,](x, y, output, n_elements, BLOCK_SIZE=4)
            return output

        x = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16)
        y = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16)

        torch.compile(add)(x, y)
        self._exec_and_verify_payload()

    def test_two_inputs_matmul(self):
        def f(a, b):
            return (a @ b).relu()
@@ -34,6 +34,21 @@ from tempfile import TemporaryFile
from typing import Any, Callable, IO, Optional, TYPE_CHECKING, Union
from typing_extensions import Unpack

from torch.utils._triton import has_triton


if has_triton():
    from triton.runtime.autotuner import Autotuner
    from triton.runtime.jit import JITFunction
else:

    class Autotuner:  # type: ignore[no-redef]
        pass

    class JITFunction:  # type: ignore[no-redef]
        pass


import torch
import torch.fx as fx
import torch.nn as nn
@@ -58,6 +73,7 @@ from torch._dynamo.debug_utils import (
)
from torch._dynamo.utils import clone_inputs, counters, same
from torch._environment import is_fbcode
from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table
from torch._inductor.cpp_builder import normalize_path_separator
from torch._inductor.output_code import OutputCode
from torch._library.fake_class_registry import FakeScriptObject
@@ -302,6 +318,16 @@ from torch.testing._internal.distributed.fake_pg import FakeStore
            """
        ).strip()

    triton_imports = ""

    if len(kernel_side_table.id_to_kernel) > 0:
        triton_imports = textwrap.dedent(
            """
            import triton
            import triton.language as tl
            """
        ).strip()

    model_str = textwrap.dedent(
        f"""
{generate_env_vars_string(stable_output=stable_output)}
@@ -312,6 +338,7 @@ from torch._dynamo.testing import rand_strided
from math import inf
import torch._inductor.inductor_prims
{distributed_imports}
{triton_imports}

{generate_config_string(stable_output=stable_output)}

@@ -330,6 +357,45 @@ isolate_fails_code_str = None
    model_str += f"# torch git version: {torch.version.git_version}\n\n\n"
    model_str += _cuda_system_info_comment()

    kernel_side_table_prefix = (
        "torch._higher_order_ops.triton_kernel_wrap.kernel_side_table"
    )
    # Track which grid entry corresponds to the best config
    for id in kernel_side_table.id_to_kernel:
        kernel = kernel_side_table.get_kernel(id)
        if isinstance(kernel, Autotuner):
            config_strs = []
            for kernel_config in kernel.configs:
                config_strs.append(f"""triton.Config(
                    {str(kernel_config.kwargs)},
                    num_warps={kernel_config.num_warps},
                    num_stages={kernel_config.num_stages},
                )""")

            config_str = ",".join(config_strs)
            model_str += textwrap.dedent(f"""
            @triton.autotune(
                configs=[
                    {config_str}
                ],
                key=[]
            )
            """).strip()

        model_str += "\n@triton.jit\n"
        src_code = kernel.src if isinstance(kernel, JITFunction) else kernel.fn.src
        fn_name = (
            kernel._fn_name if isinstance(kernel, JITFunction) else kernel.fn._fn_name
        )
        fn_name = fn_name.split(".")[-1]

        model_str += src_code
        model_str += "\n"
        model_str += f"{kernel_side_table_prefix}.add_kernel({fn_name})\n"

    if len(kernel_side_table.constant_args) > 0:
        model_str += f"{kernel_side_table_prefix}.constant_args={kernel_side_table.constant_args}\n"

    model_str += NNModuleToString.convert(gm)

    writer = InputWriter(save_dir, stable_hash=stable_hash)