Compare commits


2 Commits

Author SHA1 Message Date
7a5b38ccd3 Update on "[Inductor] Fix unbacked float symbol handling in kernel codegen"
When a function compiled with `torch.compile` calls `.item()` on a float tensor argument (e.g., a threshold passed to `torch.clamp`), the generated Triton kernel references an unbacked float symbol (e.g., `zuf0`) that was never added to the kernel's parameter list, causing a compilation error. (A minimal reproducer is sketched below, after the commit list.)

Fixes: #166888

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
2025-11-03 13:25:27 -08:00
d0111aed27 [Inductor] Fix unbacked float symbol handling in kernel codegen
[ghstack-poisoned]
2025-11-03 12:48:24 -08:00
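
For context, here is a minimal reproducer distilled from the issue and the new test in this PR; the `"cuda"` device string and the tensor shapes are illustrative, not prescribed:

import torch

def fn(x, max_val):
    # .item() on a 0-dim float tensor produces an unbacked float symbol
    # (named zuf0 in Inductor's codegen)
    return torch.clamp(x, 0, max_val.item())

compiled = torch.compile(fn)
# Before this fix, the generated Triton kernel referenced zuf0 without
# declaring it as a kernel argument, so compilation failed here.
out = compiled(
    torch.randn(10, 20, 30, device="cuda"),
    torch.tensor(5.0, device="cuda"),
)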
4 changed files with 34 additions and 2 deletions

View File

@@ -14408,6 +14408,20 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
         self.common(fn, (torch.randn(6, 4, device=GPU_TYPE).t().contiguous().t(),))

+    @skip_if_halide
+    @requires_cuda_and_triton
+    def test_unbacked_float_item(self):
+        def fn(x, max_val):
+            return torch.clamp(x, 0, max_val.item())
+
+        self.common(
+            fn,
+            (
+                torch.randn(10, 20, 30, device=self.device),
+                torch.tensor(5.0, device=self.device),
+            ),
+        )

     # end of class CommonTemplate - add new tests here
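
For readers unfamiliar with the harness: `CommonTemplate.common` runs the function both eagerly and under `torch.compile` and compares the outputs. A rough standalone equivalent of that check (an assumption about the helper's behavior, not its actual code):

import torch

def check(fn, args):
    # Approximation of what self.common asserts: the compiled result
    # must match the eager result.
    torch.testing.assert_close(torch.compile(fn)(*args), fn(*args))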

View File

@@ -2970,6 +2970,12 @@ class CppPythonBindingsCodeCache(CppCodeCache):
                     throw std::runtime_error("expected int arg");
                 return reinterpret_cast<uintptr_t>(result);
             }}
+            template <> inline float parse_arg<float>(PyObject* args, size_t n) {{
+                auto result = PyFloat_AsDouble(PyTuple_GET_ITEM(args, n));
+                if(unlikely(result == -1.0 && PyErr_Occurred()))
+                    throw std::runtime_error("expected float arg");
+                return static_cast<float>(result);
+            }}

             {extra_parse_arg}
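
Note the doubled braces: this C++ is embedded in a Python f-string inside the code cache, so `{{` and `}}` render as literal braces while `{extra_parse_arg}` is substituted at generation time. `PyFloat_AsDouble` returns `-1.0` on failure, which is why the check also consults `PyErr_Occurred()` before raising. A small sketch of the rendering step (the `extra_parse_arg` value here is a dummy placeholder):

# Sketch of the f-string rendering; extra_parse_arg is a dummy value.
extra_parse_arg = "/* additional parse_arg specializations */"
rendered = f"""
template <> inline float parse_arg<float>(PyObject* args, size_t n) {{
    auto result = PyFloat_AsDouble(PyTuple_GET_ITEM(args, n));
    if(unlikely(result == -1.0 && PyErr_Occurred()))
        throw std::runtime_error("expected float arg");
    return static_cast<float>(result);
}}
{extra_parse_arg}
"""
print(rendered)  # the doubled braces emit as single braces in the C++ source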

View File

@@ -1731,9 +1731,15 @@ class KernelArgs:
             call_args.append(self.wrap_ptr_arg(outer, dtype))
             arg_types.append(f"{cpp_dtype}*")
         for outer, inner in self.sizevars.items():
-            arg_defs.append(f"const {INDEX_TYPE} {inner}")
+            if isinstance(outer, sympy.Symbol) and symbol_is_type(
+                outer, (SymT.UNBACKED_FLOAT)
+            ):
+                arg_defs.append(f"const float {inner}")
+                arg_types.append("const float")
+            else:
+                arg_defs.append(f"const {INDEX_TYPE} {inner}")
+                arg_types.append(f"const {INDEX_TYPE}")
             call_args.append(self.wrap_size_arg(outer))
-            arg_types.append(f"const {INDEX_TYPE}")
             if V.graph.wrapper_code:
                 V.graph.wrapper_code.ensure_size_computed(outer)
         assert not self.workspace_args, "Workspace not supported on CPU "

@@ -2352,6 +2358,7 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
                     SymT.UNBACKED_INT,
                     SymT.SIZE,
                     SymT.PRECOMPUTED_SIZE,
+                    SymT.UNBACKED_FLOAT,
                 ),
             )
         }
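
The first hunk makes the CPU argdef builder emit a `const float` parameter for an unbacked float sizevar instead of forcing the index type; the second adds `SymT.UNBACKED_FLOAT` to the symbol kinds the kernel tracks. Note that `(SymT.UNBACKED_FLOAT)` is a parenthesized value, not a one-element tuple; `symbol_is_type` accepts a single `SymT` as well as an iterable, so the check works either way. A tiny sketch of the predicate (assuming the `make_symbol` helper from `torch.utils._sympy.symbol`; the name `zuf0` is illustrative):

from torch.utils._sympy.symbol import SymT, make_symbol, symbol_is_type

# Unbacked float symbols carry the "zuf" prefix.
zuf0 = make_symbol(SymT.UNBACKED_FLOAT, 0)
assert symbol_is_type(zuf0, SymT.UNBACKED_FLOAT)
# With the change above, the argdef builder now emits "const float zuf0"
# for this symbol rather than an index-typed parameter.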

View File

@@ -4,6 +4,7 @@ from typing import Any, Optional
 import sympy
 import torch
+from torch.utils._sympy.symbol import symbol_is_type, SymT
 from .. import config
 from ..runtime.hints import AttrsDescriptorWrapper

@@ -71,6 +72,10 @@ def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str:
         return "constexpr"
     elif isinstance(arg.expr, (float, sympy.Float)):
         return "fp32"
+    elif isinstance(arg.expr, sympy.Symbol) and symbol_is_type(
+        arg.expr, (SymT.UNBACKED_FLOAT)
+    ):
+        return "fp32"
     elif isinstance(arg.expr, bool):
         return "i1"