mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor] Enable combo kernels with unbacked inputs (#162442)
Internal user tried enabling combo kernels, but ran into "Cannot convert symbols to int". This PR is to enable combo kernels on inputs with data-dependent shapes. ### Example exception ``` File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 4997, in benchmark_combo_kernel kernel_code_list = self.generate_combo_kernel_code( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/simd.py", line 1849, in generate_combo_kernel_code src_code = kernel.codegen_kernel() ^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton_combo_kernel.py", line 802, in codegen_kernel code.splice(self.codegen_kernel_benchmark(num_gb=0)) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton_combo_kernel.py", line 852, in codegen_kernel_benchmark var_names.extend(self.kernel_benchmark_extra_args()) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton_combo_kernel.py", line 733, in kernel_benchmark_extra_args extra_args.append(str(V.graph.sizevars.size_hint(tree.numel))) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/colinpeppler/pytorch/torch/_inductor/sizevars.py", line 584, in size_hint return int(out) ^^^^^^^^ File "/home/colinpeppler/.conda/envs/pytorch/lib/python3.12/site-packages/sympy/core/expr.py", line 307, in __int__ raise TypeError("Cannot convert symbols to int") torch._inductor.exc.InductorError: TypeError: Cannot convert symbols to int ``` Differential Revision: [D82042230](https://our.internmc.facebook.com/intern/diff/D82042230) Pull Request resolved: https://github.com/pytorch/pytorch/pull/162442 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
6d65737aee
commit
94755e81c4
@ -1010,7 +1010,10 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
|
||||
# for the "cat". However, I think it might be a bit overwhelming that
|
||||
# we add such complexity only for handling some particular cases for
|
||||
# benchmarking.
|
||||
out_numel = V.graph.sizevars.size_hint(sympy_product(self.numels.values()))
|
||||
out_numel = V.graph.sizevars.size_hint(
|
||||
sympy_product(self.numels.values()),
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
)
|
||||
for i, arg in enumerate(call_args):
|
||||
# "buf" may be narrowed. In this case, the number of memory accesses
|
||||
# should be estimated based on the reinterpreted layout.
|
||||
@ -1021,7 +1024,9 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
|
||||
nbytes.append(0)
|
||||
continue
|
||||
arg_numel = V.graph.get_numel(arg)
|
||||
buf_size = V.graph.sizevars.size_hint(arg_numel)
|
||||
buf_size = V.graph.sizevars.size_hint(
|
||||
arg_numel, fallback=config.unbacked_symint_fallback
|
||||
)
|
||||
if buf_size > out_numel:
|
||||
# This arg points to a buf that has been sliced.
|
||||
# We need to count each individual slice to have
|
||||
|
Reference in New Issue
Block a user