diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py
index 5fe3623b271a..15a08e7f1627 100644
--- a/test/inductor/test_triton_kernels.py
+++ b/test/inductor/test_triton_kernels.py
@@ -4,6 +4,7 @@
 # Skip do not assign a lambda expression, use a def
 import functools
 import logging
+import os

 import torch
 import torch._dynamo.testing
@@ -1280,8 +1281,11 @@ def forward(self, x_1, output_1):
         self.assertEqual(compiled_out, eager_out)

     @requires_gpu
+    @common_utils.parametrize("dump_launch_params", ["0", "1"])
     @common_utils.parametrize("dynamic", [False, True])
-    def test_triton_kernel_equal_to_1_arg(self, dynamic):
+    def test_triton_kernel_equal_to_1_arg(self, dynamic, dump_launch_params):
+        os.environ["TORCHINDUCTOR_DUMP_LAUNCH_PARAMS"] = dump_launch_params
+
         @triton.jit
         def add_kernel_half_n_elements(
             in_ptr0,
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index 38a9bd1ad9c0..6d978af8d772 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -1311,11 +1311,23 @@ class CachingAutotuner(KernelInterface):
             def filtered_signature() -> list[str]:
                 # constexprs are not passed in as args
-                return [
-                    x
-                    for x in self.triton_meta["signature"].keys()
-                    if x not in cfg.kwargs.keys()
-                ]
+                new_signature: list[str] = []
+
+                from triton.runtime.interpreter import InterpretedFunction
+
+                for i, x in enumerate(self.triton_meta["signature"].keys()):
+                    if isinstance(self.fn, InterpretedFunction):
+                        # These are torch compiled triton kernels that definitely
+                        # have block size configs. Dynamo does not currently
+                        # trace user defined triton kernels when TRITON_INTERPRET=1
+                        if x not in cfg.kwargs.keys():
+                            new_signature.append(x)
+                    elif i not in self.fn.constexprs:
+                        # use constexprs rather than just configs since user
+                        # defined triton kernels may not have any configs
+                        new_signature.append(x)
+
+                return new_signature

         else:

             def filtered_signature() -> list[str]:
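
For orientation (not part of the patch), below is a minimal standalone sketch of the signature-filtering logic the triton_heuristics.py hunk introduces. The names `signature`, `constexpr_indices`, and `cfg_kwargs` are made-up stand-ins for `self.triton_meta["signature"]`, `self.fn.constexprs`, and `cfg.kwargs`; the real code distinguishes the two branches with `isinstance(self.fn, InterpretedFunction)`, since under TRITON_INTERPRET=1 only Inductor-generated kernels (which always carry block-size configs) are traced.

# Illustrative sketch only (hypothetical inputs), mirroring the patched
# filtered_signature(): drop constexpr parameters from the launch signature,
# using config kwargs for Inductor-generated (interpreted) kernels and the
# kernel's own constexpr indices for user-defined kernels.
signature = {
    "in_ptr0": "*fp32",
    "out_ptr0": "*fp32",
    "n_elements": "i32",
    "BLOCK_SIZE": "constexpr",
}
constexpr_indices = {3}           # positional indices of tl.constexpr params
cfg_kwargs = {"BLOCK_SIZE": 128}  # block sizes picked by an autotuner config


def filtered_signature(interpreted: bool) -> list[str]:
    out: list[str] = []
    for i, name in enumerate(signature):
        if interpreted:
            # Inductor-compiled kernels always carry their block sizes in a
            # config, so filtering by config kwargs is sufficient.
            if name not in cfg_kwargs:
                out.append(name)
        elif i not in constexpr_indices:
            # User-defined kernels may have no configs at all, so fall back
            # to the kernel's declared constexprs.
            out.append(name)
    return out


print(filtered_signature(interpreted=False))  # ['in_ptr0', 'out_ptr0', 'n_elements']
print(filtered_signature(interpreted=True))   # ['in_ptr0', 'out_ptr0', 'n_elements']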