Compare commits

...

2 Commits

SHA1        Message                                             Date
f4adf6dd1d  lint                                                2024-11-04 12:30:35 -08:00
41a1e2557f  [inductor] Error on unsupported autotuner configs   2024-11-04 12:25:03 -08:00
2 changed files with 108 additions and 1 deletion
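In short: torch.compile keeps accepting triton.autotune decorators that pass only configs and key, but now raises torch._dynamo.exc.Unsupported whenever a pre_hook or post_hook is attached, either on an individual triton.Config or on the autotuner itself. Below is a minimal sketch of the newly rejected pattern (illustrative names such as add_kernel and out_ptr, and it assumes a CUDA device and a Triton version whose Config accepts pre_hook; the test in this diff is the authoritative coverage):

import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE": 1024},
            num_warps=4,
            num_stages=2,
            # Runs arbitrary Python before each launch; this is what the new check rejects.
            pre_hook=lambda nargs: nargs["out_ptr"].zero_(),
        )
    ],
    key=["n_elements"],
)
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.atomic_add(out_ptr + offsets, x + y, mask=mask)


def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.zeros_like(x)
    n_elements = out.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    add_kernel[grid](x, y, out, n_elements)
    return out


x = torch.ones(4096, device="cuda", dtype=torch.float16)
# Eager execution still runs the hook and returns x + y.
add(x, x)
# Compiling the same function now raises torch._dynamo.exc.Unsupported with
# "Only configs and keys are supported for triton.autotune".
torch.compile(add, fullgraph=True)(x, x)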

File: test/inductor/test_triton_kernels.py

@@ -3377,6 +3377,99 @@ class CustomOpTests(torch._inductor.test_case.TestCase):
                "grid_wrapper_for_op_zeros_0"
            ).check_next("return (256").check_next("return (64").run(output)

    @requires_gpu
    def test_autotune_no_pre_or_post_hook(self):
        import triton
        import triton.language as tl

        def init_to_zero(name):
            return lambda nargs: nargs[name].zero_()

        # pre_hook requires running arbitrary code at runtime, which we cannot handle at this time
        # https://github.com/pytorch/pytorch/issues/139059
        def get_default_config():
            config = triton.Config(
                {"BLOCK_SIZE": 1024},
                num_warps=4,
                num_stages=2,
                pre_hook=init_to_zero("output_ptr"),
            )
            return [config]

        def get_okay_config():
            config = triton.Config({"BLOCK_SIZE": 1024}, num_warps=4, num_stages=2)
            return [config]

        @triton.autotune(
            configs=get_default_config(),
            key=["n_elements"],
        )
        @triton.jit
        def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(x_ptr + offsets, mask=mask)
            y = tl.load(y_ptr + offsets, mask=mask)
            output = x + y
            tl.atomic_add(output_ptr + offsets, output, mask=mask)

        # Instead of passing a bad Config, we pass in a pre_hook/post_hook---which we cannot support right now
        @triton.autotune(
            configs=get_okay_config(),
            key=["n_elements"],
            pre_hook=lambda x: None,
            post_hook=lambda x: None,
        )
        @triton.jit
        def add_kernel_autotuner_config(
            x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(x_ptr + offsets, mask=mask)
            y = tl.load(y_ptr + offsets, mask=mask)
            output = x + y
            tl.atomic_add(output_ptr + offsets, output, mask=mask)

        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            output = torch.ones(x.shape, device=x.device, dtype=x.dtype)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel[grid](x, y, output, n_elements)
            return output

        def add2(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            output = torch.ones(x.shape, device=x.device, dtype=x.dtype)
            n_elements = output.numel()
            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
            add_kernel_autotuner_config[grid](x, y, output, n_elements)
            return output

        x = torch.ones((4096,), device="cuda:0", dtype=torch.float16)
        y = torch.ones((4096,), device="cuda:0", dtype=torch.float16)

        # should always pass
        assert add(x, y).mean() == 2, "Problem with add kernel"

        # this should cause an exception, since pre_hook is not allowed
        msg = "Only configs and keys are supported for triton.autotune"
        with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
            add_compiled = torch.compile(add, mode="reduce-overhead", fullgraph=True)
            add_compiled(x, y).mean()

        # this should also cause an exception, since we can't pass pre_hook in the autotuner configs
        with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
            add_compiled2 = torch.compile(add2, mode="reduce-overhead", fullgraph=True)
            add_compiled2(x, y).mean()


common_utils.instantiate_parametrized_tests(KernelTests)
common_utils.instantiate_parametrized_tests(CustomOpTests)

File: torch/_higher_order_ops/triton_kernel_wrap.py

@@ -1038,7 +1038,6 @@ class TritonHOPifier:
            # We only support configs and keys arguments of triton.autotune
            # Make sure other arguments are defaulted
            defaults = inspect.signature(Autotuner.__init__).parameters
            # Newer versions of triton changed the attribute name from warmup to num_warmup and rep to num_rep.
            # The call to get_first_attr is to maintain backward-compatibility.
            if (
@@ -1068,6 +1067,21 @@ class TritonHOPifier:
                        "use_cuda_graph" in defaults
                        and defaults["use_cuda_graph"].default != kernel.use_cuda_graph
                    )
                    # pre_hook requires running arbitrary code at runtime, which we cannot handle at this time
                    # https://github.com/pytorch/pytorch/issues/139059
                    or (
                        # Check Config passed to autotuner in configs
                        any(cfg.pre_hook is not None for cfg in kernel.configs)
                    )
                    # we also cannot support pre/post hook in the autotuner config
                    or (
                        "pre_hook" in defaults
                        and defaults["pre_hook"] != kernel.pre_hook
                    )
                    or (
                        "post_hook" in defaults
                        and defaults["post_hook"] != kernel.post_hook
                    )
                )
            ):
                self.raise_unsupported(
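For reference, here is the same check pulled out as a standalone predicate. This is a rough sketch rather than the PR's exact code: it assumes a Triton 3.x-style Autotuner and, instead of comparing the pre_hook/post_hook attributes against the __init__ signature defaults as the diff above does, it reads the user_defined_pre_hook / user_defined_post_hook flags that recent Triton autotuners set when a hook was actually supplied (an assumption about the installed Triton, not part of this PR).

import inspect

from triton.runtime.autotuner import Autotuner


def uses_unsupported_hooks(kernel: Autotuner) -> bool:
    """Return True if this autotuner carries hooks that torch.compile cannot trace."""
    defaults = inspect.signature(Autotuner.__init__).parameters

    # A pre_hook on any individual Config means arbitrary Python must run at launch time.
    if any(cfg.pre_hook is not None for cfg in kernel.configs):
        return True

    # Hooks passed to triton.autotune itself. The signature check guards against
    # older Triton versions that do not accept these arguments at all; the
    # user_defined_* flags (assumed to exist on newer Autotuners) record whether
    # the user actually supplied a hook.
    if "pre_hook" in defaults and getattr(kernel, "user_defined_pre_hook", False):
        return True
    if "post_hook" in defaults and getattr(kernel, "user_defined_post_hook", False):
        return True

    return False

When any of these conditions holds, the code above calls self.raise_unsupported, which is the error the new test matches with assertRaisesRegex.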