[inductor] Fix removal of constexpr args from the launcher signature (#161924)

Fixes the case described below which occurs when:
- A user `torch.compile`s a function that uses a triton kernel.
- `TORCHINDUCTOR_DUMP_LAUNCH_PARAMS=1` is set.

Problem:

If the user-defined triton kernel is not autotuned:

```python
import os
os.environ["TORCHINDUCTOR_DUMP_LAUNCH_PARAMS"] = "1"
@triton.jit
def kernel(..., BLOCK_SIZE: tl.constexpr):
    ...
@torch.compile
def fn(..):
    kernel[..](..., 128)

fn(..)
```
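
For reference, a fuller standalone repro might look like the sketch below. The kernel body, shapes, and names such as `add_one_kernel` are illustrative and not taken from the PR; the key ingredients are the env var, a positional `tl.constexpr` argument, and the absence of `@triton.autotune`.

```python
import os

os.environ["TORCHINDUCTOR_DUMP_LAUNCH_PARAMS"] = "1"

import torch
import triton
import triton.language as tl


@triton.jit
def add_one_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + 1, mask=mask)


@torch.compile
def fn(x):
    out = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 128),)
    # BLOCK_SIZE is passed positionally; there is no @triton.autotune on the kernel
    add_one_kernel[grid](x, out, n, 128)
    return out


fn(torch.randn(1024, device="cuda"))
```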

Then in `triton_heuristics._interpret_args_grid`, the `filtered_signature` function:

```python
def filtered_signature() -> list[str]:
    # constexprs are not passed in as args
    return [
        x
        for x in self.triton_meta["signature"].keys()
        if x not in cfg.kwargs.keys()
    ]
```

Because `triton.autotune` is not used on the `triton.jit` function, `cfg` above will be empty, so `BLOCK_SIZE` is not removed from the signature even though it is a constexpr and is removed from the arguments passed to `_interpret_args_grid`. The result is a mismatch between the number of parameters in the signature and the number of arguments, which leads to the error `NameError: name '_grid_2' is not defined`.
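
To make the mismatch concrete, here is a toy illustration in plain Python. The names are made-up stand-ins for the real inductor data structures, not code from the PR.

```python
# Stand-ins for triton_meta["signature"].keys() and cfg.kwargs when the kernel
# is not autotuned.
signature_keys = ["in_ptr0", "out_ptr0", "n_elements", "BLOCK_SIZE"]
cfg_kwargs = {}  # empty: no @triton.autotune, so the config carries no kwargs

# What filtered_signature() returns: BLOCK_SIZE survives the filter.
filtered = [x for x in signature_keys if x not in cfg_kwargs]

# But the constexpr has already been stripped from the runtime arguments.
runtime_args = ["in_ptr0", "out_ptr0", "n_elements"]

print(len(filtered), len(runtime_args))  # 4 vs 3 -> parameter/argument mismatch
```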

Fix:

Use the triton jit kernel's `constexprs` to determine which args to remove. Not sure if this is a good fix, so suggestions are welcome.
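
For context, a `triton.jit`-wrapped kernel exposes the positional indices of its constexpr parameters via its `constexprs` attribute, which is what the fix keys off. A small sketch (the kernel is illustrative; the exact container type of `constexprs` may vary across triton versions):

```python
import triton
import triton.language as tl


@triton.jit
def kernel(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pass


# Positional index of BLOCK_SIZE in the kernel's signature.
print(kernel.constexprs)  # e.g. [2]
```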

Test plan:

Added a `dump_launch_params` parametrization to an existing triton kernel test to cover this edge case.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161924
Approved by: https://github.com/davidberard98
Author: Mwiza Kunda
Date: 2025-09-12 13:58:09 +00:00
Committed by: PyTorch MergeBot
Commit: 03798b0f91 (parent 6c334885d4)
2 changed files with 22 additions and 6 deletions


```diff
@@ -4,6 +4,7 @@
 # Skip do not assign a lambda expression, use a def
 import functools
 import logging
+import os
 
 import torch
 import torch._dynamo.testing
@@ -1280,8 +1281,11 @@ def forward(self, x_1, output_1):
         self.assertEqual(compiled_out, eager_out)
 
     @requires_gpu
+    @common_utils.parametrize("dump_launch_params", ["0", "1"])
     @common_utils.parametrize("dynamic", [False, True])
-    def test_triton_kernel_equal_to_1_arg(self, dynamic):
+    def test_triton_kernel_equal_to_1_arg(self, dynamic, dump_launch_params):
+        os.environ["TORCHINDUCTOR_DUMP_LAUNCH_PARAMS"] = dump_launch_params
+
         @triton.jit
         def add_kernel_half_n_elements(
             in_ptr0,
```


```diff
@@ -1311,11 +1311,23 @@ class CachingAutotuner(KernelInterface):
             def filtered_signature() -> list[str]:
                 # constexprs are not passed in as args
-                return [
-                    x
-                    for x in self.triton_meta["signature"].keys()
-                    if x not in cfg.kwargs.keys()
-                ]
+                new_signature: list[str] = []
+                from triton.runtime.interpreter import InterpretedFunction
+
+                for i, x in enumerate(self.triton_meta["signature"].keys()):
+                    if isinstance(self.fn, InterpretedFunction):
+                        # These are torch compiled triton kernels that definitely
+                        # have block size configs. Dynamo does not currently
+                        # trace user defined triton kernels when TRITON_INTERPRET=1
+                        if x not in cfg.kwargs.keys():
+                            new_signature.append(x)
+                    elif i not in self.fn.constexprs:
+                        # use constexprs rather than just configs since user
+                        # defined triton kernels may not have any configs
+                        new_signature.append(x)
+                return new_signature
 
         else:
 
             def filtered_signature() -> list[str]:
```