[inductor] Fix removal of constexpr args from the launcher signature (#161924)
Fixes the case described below, which occurs when:

- A user `torch.compile`s a function that uses a triton kernel.
- `TORCHINDUCTOR_DUMP_LAUNCH_PARAMS=1` is set.

Problem: if the user-defined triton kernel is not autotuned:

```python
import os
os.environ["TORCHINDUCTOR_DUMP_LAUNCH_PARAMS"] = "1"

@triton.jit
def kernel(..., BLOCK_SIZE: tl.constexpr):
    ...

@torch.compile
def fn(..):
    kernel[..](..., 128)

fn(..)
```

then in `triton_heuristics._interpret_args_grid`, the `filtered_signature` function:

```python
def filtered_signature() -> list[str]:
    # constexprs are not passed in as args
    return [
        x
        for x in self.triton_meta["signature"].keys()
        if x not in cfg.kwargs.keys()
    ]
```

Because `triton.autotune` is not used on the `triton.jit` function, `cfg` above will be empty, so `BLOCK_SIZE` is not removed from the signature even though it is a constexpr and is removed from the arguments passed to `_interpret_args_grid`. This results in a mismatch between the number of parameters in the signature and the number of arguments, which leads to the error `NameError: name '_grid_2' is not defined`.

Fix: use the triton jit kernel's `constexprs` to decide which args to remove. Not sure if this is a good fix, so suggestions are welcome.

Test plan: added a parameter to an existing triton kernel test to cover this edge case.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161924
Approved by: https://github.com/davidberard98
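The repro above uses `..` placeholders; purely for illustration, here is one way it could be fleshed out into a runnable script. This is a hedged sketch, not the PR's test: the kernel `add_one_kernel`, the wrapper `fn`, and the tensor shapes are invented, and it assumes a CUDA device with triton installed. Before this fix, running it with the launch-params dump enabled could hit the `NameError` described above.

```python
# Hypothetical repro sketch (not the PR's test plan).
import os

# Must be set before inductor generates launchers.
os.environ["TORCHINDUCTOR_DUMP_LAUNCH_PARAMS"] = "1"

import torch
import triton
import triton.language as tl


@triton.jit  # note: no @triton.autotune, so cfg.kwargs is empty at launch time
def add_one_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + 1, mask=mask)


@torch.compile
def fn(x):
    out = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 128),)
    # BLOCK_SIZE is a constexpr: it appears in the kernel signature but is not
    # a runtime launch argument, which is what tripped the old filter.
    add_one_kernel[grid](x, out, n, BLOCK_SIZE=128)
    return out


if __name__ == "__main__":
    x = torch.randn(1024, device="cuda")
    torch.testing.assert_close(fn(x), x + 1)
```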
Committed by: PyTorch MergeBot
Parent: 6c334885d4
Commit: 03798b0f91
Test changes (the existing `test_triton_kernel_equal_to_1_arg` is parametrized over `dump_launch_params`):

```diff
@@ -4,6 +4,7 @@
 # Skip do not assign a lambda expression, use a def
 import functools
 import logging
+import os
 
 import torch
 import torch._dynamo.testing
@@ -1280,8 +1281,11 @@ def forward(self, x_1, output_1):
         self.assertEqual(compiled_out, eager_out)
 
     @requires_gpu
+    @common_utils.parametrize("dump_launch_params", ["0", "1"])
     @common_utils.parametrize("dynamic", [False, True])
-    def test_triton_kernel_equal_to_1_arg(self, dynamic):
+    def test_triton_kernel_equal_to_1_arg(self, dynamic, dump_launch_params):
+        os.environ["TORCHINDUCTOR_DUMP_LAUNCH_PARAMS"] = dump_launch_params
+
         @triton.jit
         def add_kernel_half_n_elements(
             in_ptr0,
```
Launcher changes in `CachingAutotuner` (the `filtered_signature` helper now filters by the kernel's `constexprs` rather than only by autotune config kwargs):

```diff
@@ -1311,11 +1311,23 @@ class CachingAutotuner(KernelInterface):
 
            def filtered_signature() -> list[str]:
-                # constexprs are not passed in as args
-                return [
-                    x
-                    for x in self.triton_meta["signature"].keys()
-                    if x not in cfg.kwargs.keys()
-                ]
+                new_signature: list[str] = []
+                from triton.runtime.interpreter import InterpretedFunction
+
+                for i, x in enumerate(self.triton_meta["signature"].keys()):
+                    if isinstance(self.fn, InterpretedFunction):
+                        # These are torch compiled triton kernels that definitely
+                        # have block size configs. Dynamo does not currently
+                        # trace user defined triton kernels when TRITON_INTERPRET=1
+                        if x not in cfg.kwargs.keys():
+                            new_signature.append(x)
+                    elif i not in self.fn.constexprs:
+                        # use constexprs rather than just configs since user
+                        # defined triton kernels may not have any configs
+                        new_signature.append(x)
+
+                return new_signature
 
        else:
 
            def filtered_signature() -> list[str]:
```
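To make the behavioral change concrete, below is a toy, standalone comparison of the old and new filtering logic under the no-autotune scenario from the commit message. The `signature`, `cfg_kwargs`, and `constexpr_indices` values are hypothetical stand-ins for `self.triton_meta["signature"]`, `cfg.kwargs`, and `self.fn.constexprs`; this is a sketch of the idea, not inductor code.

```python
# Hypothetical stand-ins (not real inductor metadata): a kernel with three
# runtime args plus one tl.constexpr, compiled without @triton.autotune.
signature = {
    "in_ptr0": "*fp32",
    "out_ptr": "*fp32",
    "n_elements": "i32",
    "BLOCK_SIZE": "constexpr",
}
cfg_kwargs: dict[str, int] = {}  # empty because there are no autotune configs
constexpr_indices = [3]          # BLOCK_SIZE is parameter index 3 on the JIT fn

# Old behavior: filter by autotune config kwargs only; with no configs,
# BLOCK_SIZE survives and the signature has 4 names for 3 runtime args.
old_signature = [x for x in signature if x not in cfg_kwargs]

# New behavior: filter by the JIT function's constexpr indices, so the
# constexpr is dropped regardless of whether any configs exist.
new_signature = [x for i, x in enumerate(signature) if i not in constexpr_indices]

print(old_signature)  # ['in_ptr0', 'out_ptr', 'n_elements', 'BLOCK_SIZE']
print(new_signature)  # ['in_ptr0', 'out_ptr', 'n_elements']
```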