[Inductor] Use indices for constants in triton_meta (#121427)

@bertmaher pointed out that constants are passed with their indices, not their names. Looking at triton source, this appears to be true 392370b303/python/triton/runtime/jit.py (L381-L385)
I'm guessing both indices and names work here but lets be consistent.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121427
Approved by: https://github.com/aakhundov
This commit is contained in:
Oguz Ulgen
2024-03-07 10:24:19 -08:00
committed by PyTorch MergeBot
parent f61192b014
commit 18d574a07a

View File

@ -986,15 +986,15 @@ class WrapperCodeGen(CodeGen):
from .common import KernelArgType, SizeArg, TensorArg
signature: List[KernelArgType] = []
constants = {}
constants: Dict[int, Any] = {}
non_constant_indices = []
equal_to_1_args: List[str] = []
equal_to_1_arg_idx: List[int] = []
for idx, key in enumerate(kernel.arg_names):
if key not in kwargs:
continue
arg = kwargs[key]
if idx in kernel.constexprs:
constants[key] = arg
constants[idx] = arg
else:
non_constant_indices.append(idx)
if isinstance(arg, ir.Buffer):
@ -1020,7 +1020,7 @@ class WrapperCodeGen(CodeGen):
else:
signature.append(SizeArg(key, arg))
if arg is not None and V.graph.sizevars.statically_known_equals(arg, 1): # type: ignore[arg-type]
equal_to_1_args.append(key)
equal_to_1_arg_idx.append(idx)
index_dtype = "tl.int32"
triton_meta = {
"signature": signature_to_meta(
@ -1033,13 +1033,13 @@ class WrapperCodeGen(CodeGen):
# Triton compiler includes equal_to_1 args into constants even
# when they are not constexpr. otherwise there may be a segfault
# during launching the Inductor-compiled Triton kernel.
# TODO(aakhundov): add None args to constnats, too. currently, this
# TODO(aakhundov): add None args to constants, too. currently, this
# causes CUDA errors in test_aot_inductor.test_triton_kernel_with_none_input.
# https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307
# https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384
"constants": {
**constants,
**{arg: 1 for arg in equal_to_1_args},
**{idx: 1 for idx in equal_to_1_arg_idx},
},
"configs": [
config_of(