mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Inductor] Use indices for constants in triton_meta (#121427)
@bertmaher pointed out that constants are passed with their indices, not their names. Looking at the Triton source, this appears to be true: see python/triton/runtime/jit.py (L381-L385) at commit 392370b303.
I'm guessing both indices and names work here, but let's be consistent.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121427
Approved by: https://github.com/aakhundov
This commit is contained in:
committed by
PyTorch MergeBot
parent
f61192b014
commit
18d574a07a
@ -986,15 +986,15 @@ class WrapperCodeGen(CodeGen):
|
||||
from .common import KernelArgType, SizeArg, TensorArg
|
||||
|
||||
signature: List[KernelArgType] = []
|
||||
constants = {}
|
||||
constants: Dict[int, Any] = {}
|
||||
non_constant_indices = []
|
||||
equal_to_1_args: List[str] = []
|
||||
equal_to_1_arg_idx: List[int] = []
|
||||
for idx, key in enumerate(kernel.arg_names):
|
||||
if key not in kwargs:
|
||||
continue
|
||||
arg = kwargs[key]
|
||||
if idx in kernel.constexprs:
|
||||
constants[key] = arg
|
||||
constants[idx] = arg
|
||||
else:
|
||||
non_constant_indices.append(idx)
|
||||
if isinstance(arg, ir.Buffer):
|
||||
@ -1020,7 +1020,7 @@ class WrapperCodeGen(CodeGen):
|
||||
else:
|
||||
signature.append(SizeArg(key, arg))
|
||||
if arg is not None and V.graph.sizevars.statically_known_equals(arg, 1): # type: ignore[arg-type]
|
||||
equal_to_1_args.append(key)
|
||||
equal_to_1_arg_idx.append(idx)
|
||||
index_dtype = "tl.int32"
|
||||
triton_meta = {
|
||||
"signature": signature_to_meta(
|
||||
@ -1033,13 +1033,13 @@ class WrapperCodeGen(CodeGen):
|
||||
# Triton compiler includes equal_to_1 args into constants even
|
||||
# when they are not constexpr. otherwise there may be a segfault
|
||||
# during launching the Inductor-compiled Triton kernel.
|
||||
# TODO(aakhundov): add None args to constnats, too. currently, this
|
||||
# TODO(aakhundov): add None args to constants, too. currently, this
|
||||
# causes CUDA errors in test_aot_inductor.test_triton_kernel_with_none_input.
|
||||
# https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307
|
||||
# https://github.com/openai/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384
|
||||
"constants": {
|
||||
**constants,
|
||||
**{arg: 1 for arg in equal_to_1_args},
|
||||
**{idx: 1 for idx in equal_to_1_arg_idx},
|
||||
},
|
||||
"configs": [
|
||||
config_of(
|
||||
|
Reference in New Issue
Block a user