Mwiza Kunda 03798b0f91 [inductor] Fix removal of constexpr args from the launcher signature (#161924)
Fixes the case described below which occurs when:
- A user `torch.compile`s a function that uses a triton kernel.
- `TORCHINDUCTOR_DUMP_LAUNCH_PARAMS=1` is set.

Problem:

If the user-defined Triton kernel is not autotuned:

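```python
import os

# Enable dumping of kernel launch parameters.
os.environ["TORCHINDUCTOR_DUMP_LAUNCH_PARAMS"] = "1"

import torch
import triton
import triton.language as tl

@triton.jit
def kernel(..., BLOCK_SIZE: tl.constexpr):
    ...

@torch.compile
def fn(..):
    # BLOCK_SIZE is passed directly; no triton.autotune config supplies it.
    kernel[..](..., 128)

fn(..)
```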

Then, in the `filtered_signature` function inside `triton_heuristics._interpret_args_grid`:

```python
        def filtered_signature() -> list[str]:
            # constexprs are not passed in as args
            return [
                x
                for x in self.triton_meta["signature"].keys()
                if x not in cfg.kwargs.keys()
            ]
```

Because `triton.autotune` is not used on the `triton.jit` function, `cfg.kwargs` above will be empty, so `BLOCK_SIZE` will not be removed from the signature even though it is a constexpr and has already been removed from the arguments passed to `_interpret_args_grid`. The signature therefore has one more parameter than there are arguments, and this mismatch leads to the error `NameError: name '_grid_2' is not defined`.
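The mismatch can be reproduced without torch or triton. The sketch below is a simplified stand-in, not the inductor code: the `signature` dict, its type strings, the `runtime_args` list, and the free-standing `filtered_signature(signature, cfg_kwargs)` helper are all hypothetical and only mirror the filter shown above.

```python
# Hypothetical stand-in for self.triton_meta["signature"]; names and types are made up.
signature = {
    "in_ptr": "*fp32",
    "out_ptr": "*fp32",
    "n_elements": "i32",
    "BLOCK_SIZE": "constexpr",
}

# Arguments as they reach the launcher: the constexpr has already been dropped.
runtime_args = ["in_buf", "out_buf", 1024]


def filtered_signature(signature, cfg_kwargs):
    # Same filter as above: only names found in the autotune config are removed.
    return [x for x in signature.keys() if x not in cfg_kwargs.keys()]


# Autotuned case: the config carries BLOCK_SIZE, so it is filtered out.
autotuned = filtered_signature(signature, {"BLOCK_SIZE": 128})

# Non-autotuned case: the config has no kwargs, so BLOCK_SIZE stays in.
not_autotuned = filtered_signature(signature, {})

print(len(autotuned), len(not_autotuned), len(runtime_args))
```

In the non-autotuned case the filtered signature is one entry longer than the argument list, which is the off-by-one described above.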

Fix:

Use the Triton JIT kernel's `constexprs` to determine which args to remove. Not sure if this is a good fix, so suggestions are welcome.
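A minimal sketch of the idea, assuming the constexpr parameter names are available as a plain collection. The `constexpr_names` argument and the standalone helper are hypothetical; how the actual patch obtains the names from the kernel and applies them may differ.

```python
def filtered_signature(signature, cfg_kwargs, constexpr_names):
    # Drop names supplied by the autotune config as well as names known to be
    # constexpr, so the result matches the arguments actually passed in.
    excluded = set(cfg_kwargs) | set(constexpr_names)
    return [name for name in signature if name not in excluded]


# Hypothetical signature; names and type strings are made up for illustration.
signature = {
    "in_ptr": "*fp32",
    "out_ptr": "*fp32",
    "n_elements": "i32",
    "BLOCK_SIZE": "constexpr",
}

# Non-autotuned kernel: the config is empty, but the constexpr is still removed.
fixed = filtered_signature(signature, {}, ["BLOCK_SIZE"])

# Autotuned kernel: the name appears in both sources and is removed once.
fixed_autotuned = filtered_signature(signature, {"BLOCK_SIZE": 128}, ["BLOCK_SIZE"])

print(fixed, fixed_autotuned)
```

Taking the union of both sources means the autotuned path behaves as before, while the non-autotuned path no longer leaves the constexpr in the signature.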

Test plan:

Added a parameter to an existing Triton kernel to test for this edge case.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161924
Approved by: https://github.com/davidberard98
2025-09-12 13:58:09 +00:00