Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[inductor] Fix removal of constexpr args from the launcher signature (#161924)
Fixes the case described below, which occurs when:
- A user `torch.compile`s a function that uses a triton kernel.
- `TORCHINDUCTOR_DUMP_LAUNCH_PARAMS=1` is set.

Problem:

If the user-defined triton kernel is not autotuned:

```python
import os

os.environ["TORCHINDUCTOR_DUMP_LAUNCH_PARAMS"] = "1"

@triton.jit
def kernel(..., BLOCK_SIZE: tl.constexpr):
    ...

@torch.compile
def fn(..):
    kernel[..](..., 128)

fn(..)
```

then in `triton_heuristics._interpret_args_grid`, the `filtered_signature` function:

```python
def filtered_signature() -> list[str]:
    # constexprs are not passed in as args
    return [
        x
        for x in self.triton_meta["signature"].keys()
        if x not in cfg.kwargs.keys()
    ]
```

does not filter out `BLOCK_SIZE`: because `triton.autotune` is not used on the `triton.jit` function, `cfg` above is empty, so `BLOCK_SIZE` stays in the signature even though it is a constexpr and has already been removed from the arguments passed to `_interpret_args_grid`. This results in a mismatch between the number of parameters in the signature and the number of arguments, which leads to the error `NameError: name '_grid_2' is not defined`.

Fix: use the triton jit kernel's `constexprs` to decide which args to remove. Not sure if this is a good fix, so suggestions are welcome.

Test plan: added a parameter to an existing triton kernel to test for this edge case.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161924
Approved by: https://github.com/davidberard98
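To make the failure mode concrete, here is a minimal sketch of the mismatch (hypothetical names and values; the real pairing happens inside Inductor's generated launcher):

```python
# Old filtering logic, applied to a kernel that was never autotuned.
signature = ["in_ptr0", "out_ptr", "n_elements", "BLOCK_SIZE"]
cfg_kwargs: dict[str, int] = {}  # no triton.autotune -> cfg.kwargs is empty

filtered = [x for x in signature if x not in cfg_kwargs]
args = ["ptr0", "ptr1", 1024]  # constexprs are never passed at launch time

# BLOCK_SIZE survives the filter, so the signature is one name longer than
# the argument list -- the mismatch behind the reported NameError.
assert len(filtered) == len(args) + 1
```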
commit 03798b0f91
parent 6c334885d4
committed by PyTorch MergeBot
```diff
@@ -4,6 +4,7 @@
 # Skip do not assign a lambda expression, use a def
 import functools
 import logging
+import os
 
 import torch
 import torch._dynamo.testing
```
```diff
@@ -1280,8 +1281,11 @@ def forward(self, x_1, output_1):
         self.assertEqual(compiled_out, eager_out)
 
     @requires_gpu
+    @common_utils.parametrize("dump_launch_params", ["0", "1"])
     @common_utils.parametrize("dynamic", [False, True])
-    def test_triton_kernel_equal_to_1_arg(self, dynamic):
+    def test_triton_kernel_equal_to_1_arg(self, dynamic, dump_launch_params):
+        os.environ["TORCHINDUCTOR_DUMP_LAUNCH_PARAMS"] = dump_launch_params
+
         @triton.jit
         def add_kernel_half_n_elements(
             in_ptr0,
```
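Assuming the test lives in the usual place (test/inductor/test_triton_kernels.py, a path inferred from the test name rather than stated in this page), the new parametrization can be exercised directly with something like:

```
python test/inductor/test_triton_kernels.py -k test_triton_kernel_equal_to_1_arg
```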
```diff
@@ -1311,11 +1311,23 @@ class CachingAutotuner(KernelInterface):
 
             def filtered_signature() -> list[str]:
                 # constexprs are not passed in as args
-                return [
-                    x
-                    for x in self.triton_meta["signature"].keys()
-                    if x not in cfg.kwargs.keys()
-                ]
+                new_signature: list[str] = []
+                from triton.runtime.interpreter import InterpretedFunction
+
+                for i, x in enumerate(self.triton_meta["signature"].keys()):
+                    if isinstance(self.fn, InterpretedFunction):
+                        # These are torch compiled triton kernels that definitely
+                        # have block size configs. Dynamo does not currently
+                        # trace user defined triton kernels when TRITON_INTERPRET=1
+                        if x not in cfg.kwargs.keys():
+                            new_signature.append(x)
+                    elif i not in self.fn.constexprs:
+                        # use constexprs rather than just configs since user
+                        # defined triton kernels may not have any configs
+                        new_signature.append(x)
+
+                return new_signature
 
         else:
 
             def filtered_signature() -> list[str]:
```
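As a quick way to poke at the new branch logic outside of PyTorch, here is a self-contained sketch; `FakeJit` is a hypothetical stand-in for the triton `JITFunction`, whose `constexprs` attribute (which the fix reads via `self.fn.constexprs`) lists the positional indices of `tl.constexpr` parameters:

```python
# Re-creation of the fixed filtering under the assumptions stated above.
class FakeJit:
    constexprs = [3]  # BLOCK_SIZE is the 4th parameter in the signature below

signature = {
    "in_ptr0": "*fp32",
    "out_ptr": "*fp32",
    "n_elements": "i32",
    "BLOCK_SIZE": "constexpr",
}
fn = FakeJit()

# Filtering on constexpr indices works even when the kernel has no autotune
# configs -- exactly the case the old cfg.kwargs-based filter missed.
new_signature = [x for i, x in enumerate(signature) if i not in fn.constexprs]
print(new_signature)  # ['in_ptr0', 'out_ptr', 'n_elements'] -- matches the args
```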