[inductor] Enable combo kernels with unbacked inputs (#162442)

Internal user tried enabling combo kernels, but ran into "Cannot convert symbols to int". This PR is to enable combo kernels on inputs with data-dependent shapes.

### Example exception

```
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 4997, in benchmark_combo_kernel
    kernel_code_list = self.generate_combo_kernel_code(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/simd.py", line 1849, in generate_combo_kernel_code
    src_code = kernel.codegen_kernel()
               ^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton_combo_kernel.py", line 802, in codegen_kernel
    code.splice(self.codegen_kernel_benchmark(num_gb=0))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton_combo_kernel.py", line 852, in codegen_kernel_benchmark
    var_names.extend(self.kernel_benchmark_extra_args())
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton_combo_kernel.py", line 733, in kernel_benchmark_extra_args
    extra_args.append(str(V.graph.sizevars.size_hint(tree.numel)))
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/colinpeppler/pytorch/torch/_inductor/sizevars.py", line 584, in size_hint
    return int(out)
           ^^^^^^^^
  File "/home/colinpeppler/.conda/envs/pytorch/lib/python3.12/site-packages/sympy/core/expr.py", line 307, in __int__
    raise TypeError("Cannot convert symbols to int")
torch._inductor.exc.InductorError: TypeError: Cannot convert symbols to int
```

Differential Revision: [D82042230](https://our.internmc.facebook.com/intern/diff/D82042230)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162442
Approved by: https://github.com/jansel
This commit is contained in:
Colin Peppler
2025-09-08 17:58:07 -07:00
committed by PyTorch MergeBot
parent 6d65737aee
commit 94755e81c4
5 changed files with 99 additions and 12 deletions

View File

@ -1010,7 +1010,10 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
# for the "cat". However, I think it might be a bit overwhelming that
# we add such complexity only for handling some particular cases for
# benchmarking.
out_numel = V.graph.sizevars.size_hint(sympy_product(self.numels.values()))
out_numel = V.graph.sizevars.size_hint(
sympy_product(self.numels.values()),
fallback=config.unbacked_symint_fallback,
)
for i, arg in enumerate(call_args):
# "buf" may be narrowed. In this case, the number of memory accesses
# should be estimated based on the reinterpreted layout.
@ -1021,7 +1024,9 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
nbytes.append(0)
continue
arg_numel = V.graph.get_numel(arg)
buf_size = V.graph.sizevars.size_hint(arg_numel)
buf_size = V.graph.sizevars.size_hint(
arg_numel, fallback=config.unbacked_symint_fallback
)
if buf_size > out_numel:
# This arg points to a buf that has been sliced.
# We need to count each individual slice to have