mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor] Enable combo kernels with unbacked inputs (#162442)
Internal user tried enabling combo kernels, but ran into "Cannot convert symbols to int". This PR is to enable combo kernels on inputs with data-dependent shapes. ### Example exception ``` File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton.py", line 4997, in benchmark_combo_kernel kernel_code_list = self.generate_combo_kernel_code( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/simd.py", line 1849, in generate_combo_kernel_code src_code = kernel.codegen_kernel() ^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton_combo_kernel.py", line 802, in codegen_kernel code.splice(self.codegen_kernel_benchmark(num_gb=0)) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton_combo_kernel.py", line 852, in codegen_kernel_benchmark var_names.extend(self.kernel_benchmark_extra_args()) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/colinpeppler/pytorch/torch/_inductor/codegen/triton_combo_kernel.py", line 733, in kernel_benchmark_extra_args extra_args.append(str(V.graph.sizevars.size_hint(tree.numel))) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/colinpeppler/pytorch/torch/_inductor/sizevars.py", line 584, in size_hint return int(out) ^^^^^^^^ File "/home/colinpeppler/.conda/envs/pytorch/lib/python3.12/site-packages/sympy/core/expr.py", line 307, in __int__ raise TypeError("Cannot convert symbols to int") torch._inductor.exc.InductorError: TypeError: Cannot convert symbols to int ``` Differential Revision: [D82042230](https://our.internmc.facebook.com/intern/diff/D82042230) Pull Request resolved: https://github.com/pytorch/pytorch/pull/162442 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
6d65737aee
commit
94755e81c4
@ -564,6 +564,33 @@ class TestUnbackedSymints(InductorTestCase):
|
||||
expected = fn(*example_inputs)
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
@skipGPUIf(not HAS_GPU, "requires gpu and triton")
|
||||
@dynamo_config.patch({"capture_dynamic_output_shape_ops": True})
|
||||
@inductor_config.patch({"combo_kernels": True, "benchmark_combo_kernel": True})
|
||||
def test_combo_kernel_size_hint_failure(self, device):
|
||||
# A size hint failure is "TypeError: Cannot convert symbols to int"
|
||||
if device == "cpu":
|
||||
raise unittest.SkipTest("Combo kernels must be for GPU.")
|
||||
|
||||
def fn(x):
|
||||
nz = torch.nonzero(x)
|
||||
u0 = nz.size(0)
|
||||
t1 = torch.ones(u0, device=device)
|
||||
t2 = torch.zeros(u0 + 1, device=device)
|
||||
t3 = torch.zeros(u0 * 2, device=device)
|
||||
t4 = torch.zeros(u0 - x.size(0), device=device)
|
||||
out1 = t1 - 1
|
||||
out2 = t2 + 2
|
||||
out3 = t3 * 3
|
||||
out4 = t4 / 4
|
||||
return out1, out2, out3, out4
|
||||
|
||||
example_inputs = (torch.randn(32, device=device, dtype=torch.float16),)
|
||||
torch._dynamo.mark_dynamic(example_inputs[0], 0)
|
||||
actual = torch.compile(fn, fullgraph=True)(*example_inputs)
|
||||
expected = fn(*example_inputs)
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestUnbackedSymints, globals(), allow_xpu=True)
|
||||
|
||||
|
@ -1010,7 +1010,10 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
|
||||
# for the "cat". However, I think it might be a bit overwhelming that
|
||||
# we add such complexity only for handling some particular cases for
|
||||
# benchmarking.
|
||||
out_numel = V.graph.sizevars.size_hint(sympy_product(self.numels.values()))
|
||||
out_numel = V.graph.sizevars.size_hint(
|
||||
sympy_product(self.numels.values()),
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
)
|
||||
for i, arg in enumerate(call_args):
|
||||
# "buf" may be narrowed. In this case, the number of memory accesses
|
||||
# should be estimated based on the reinterpreted layout.
|
||||
@ -1021,7 +1024,9 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
|
||||
nbytes.append(0)
|
||||
continue
|
||||
arg_numel = V.graph.get_numel(arg)
|
||||
buf_size = V.graph.sizevars.size_hint(arg_numel)
|
||||
buf_size = V.graph.sizevars.size_hint(
|
||||
arg_numel, fallback=config.unbacked_symint_fallback
|
||||
)
|
||||
if buf_size > out_numel:
|
||||
# This arg points to a buf that has been sliced.
|
||||
# We need to count each individual slice to have
|
||||
|
@ -3868,9 +3868,15 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
||||
if isinstance(arg, int):
|
||||
args.append(str(arg))
|
||||
elif isinstance(arg, SymbolicCallArg):
|
||||
args.append(str(V.graph.sizevars.size_hint(arg.inner_expr)))
|
||||
hint = V.graph.sizevars.size_hint(
|
||||
arg.inner_expr, fallback=config.unbacked_symint_fallback
|
||||
)
|
||||
args.append(str(hint))
|
||||
elif isinstance(arg, sympy.Expr):
|
||||
args.append(str(V.graph.sizevars.size_hint(arg)))
|
||||
hint = V.graph.sizevars.size_hint(
|
||||
arg, fallback=config.unbacked_symint_fallback
|
||||
)
|
||||
args.append(str(hint))
|
||||
else:
|
||||
raise ValueError(f"Unsupported numel argument type: {type(arg)}")
|
||||
return args
|
||||
@ -3887,14 +3893,34 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
||||
var_name = f"arg_{next(name_cnt)}"
|
||||
buf = V.graph.try_get_buffer(arg_name)
|
||||
if buf:
|
||||
size = V.graph.sizevars.size_hints(
|
||||
buf.get_size(),
|
||||
hint_override=self.hint_override,
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
)
|
||||
stride = V.graph.sizevars.size_hints(
|
||||
buf.get_stride(),
|
||||
hint_override=self.hint_override,
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
)
|
||||
result.writeline(
|
||||
f"{var_name} = rand_strided({V.graph.sizevars.size_hints(buf.get_size(), hint_override=self.hint_override)}, {V.graph.sizevars.size_hints(buf.get_stride(), hint_override=self.hint_override)}, device='{buf.get_device()}', dtype={buf.get_dtype()})" # noqa: B950 line too long
|
||||
f"{var_name} = rand_strided({size}, {stride}, device='{buf.get_device()}', dtype={buf.get_dtype()})" # noqa: B950 line too long
|
||||
)
|
||||
elif arg_name in V.graph.constants:
|
||||
# note that random seed is put in V.graph.constants
|
||||
const_tensor = V.graph.constants[arg_name]
|
||||
size = V.graph.sizevars.size_hints(
|
||||
const_tensor.size(),
|
||||
hint_override=self.hint_override,
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
)
|
||||
stride = V.graph.sizevars.size_hints(
|
||||
const_tensor.stride(),
|
||||
hint_override=self.hint_override,
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
)
|
||||
result.writeline(
|
||||
f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size(), hint_override=self.hint_override)}, {V.graph.sizevars.size_hints(const_tensor.stride(), hint_override=self.hint_override)}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # type: ignore[arg-type] # noqa: B950 line too long
|
||||
f"{var_name} = rand_strided({size}, {stride}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # type: ignore[arg-type] # noqa: B950 line too long
|
||||
)
|
||||
elif isinstance(arg_sig, SizeArg):
|
||||
symval_hint = V.graph.sizevars.size_hint(arg_sig.expr)
|
||||
|
@ -90,7 +90,10 @@ def _default_custom_combo_kernel_horizontal_partition(
|
||||
long_reduction = [
|
||||
n
|
||||
for n in reduction
|
||||
if V.graph.sizevars.size_hint(n.group[-1][-1]) > 2048 # type: ignore[arg-type]
|
||||
if (
|
||||
V.graph.sizevars.shape_env.has_hint(n.group[-1][-1])
|
||||
and V.graph.sizevars.size_hint(n.group[-1][-1]) > 2048 # type: ignore[arg-type]
|
||||
)
|
||||
]
|
||||
short_reduction = [n for n in reduction if n not in long_reduction]
|
||||
if long_reduction:
|
||||
@ -103,6 +106,7 @@ def _default_custom_combo_kernel_horizontal_partition(
|
||||
for n in not_reduction
|
||||
if not kernel_map[n].inside_reduction
|
||||
and len(kernel_map[n].numels) == 2
|
||||
and V.graph.sizevars.shape_env.has_hint(kernel_map[n].numels["x"])
|
||||
and V.graph.sizevars.size_hint(kernel_map[n].numels["x"]) > LARGE_NUMELS
|
||||
]
|
||||
if large_pointwise:
|
||||
@ -485,7 +489,11 @@ class ComboKernel(Kernel):
|
||||
|
||||
def select_heuristics(self, sub_kernel: TritonKernel) -> tuple[str, dict[str, int]]:
|
||||
size_hints = {
|
||||
prefix: next_power_of_2(V.graph.sizevars.size_hint(numel))
|
||||
prefix: next_power_of_2(
|
||||
V.graph.sizevars.size_hint(
|
||||
numel, fallback=config.unbacked_symint_fallback
|
||||
)
|
||||
)
|
||||
for prefix, numel in sub_kernel.numels.items()
|
||||
if not prefix_is_reduction(prefix) or sub_kernel.inside_reduction
|
||||
}
|
||||
@ -726,7 +734,13 @@ class ComboKernel(Kernel):
|
||||
if numel_name not in self.dynamic_shape_args:
|
||||
continue
|
||||
if not tree.is_reduction or sub_kernel.inside_reduction:
|
||||
extra_args.append(str(V.graph.sizevars.size_hint(tree.numel)))
|
||||
extra_args.append(
|
||||
str(
|
||||
V.graph.sizevars.size_hint(
|
||||
tree.numel, fallback=config.unbacked_symint_fallback
|
||||
)
|
||||
)
|
||||
)
|
||||
return extra_args
|
||||
|
||||
def codegen_kernel(self, name: Optional[str] = None) -> str:
|
||||
@ -810,14 +824,26 @@ class ComboKernel(Kernel):
|
||||
var_name = f"arg_{next(name_cnt)}"
|
||||
buf = V.graph.try_get_buffer(arg_name)
|
||||
if buf:
|
||||
size = V.graph.sizevars.size_hints(
|
||||
buf.get_size(), fallback=config.unbacked_symint_fallback
|
||||
)
|
||||
stride = V.graph.sizevars.size_hints(
|
||||
buf.get_stride(), fallback=config.unbacked_symint_fallback
|
||||
)
|
||||
result.writeline(
|
||||
f"{var_name} = rand_strided({V.graph.sizevars.size_hints(buf.get_size())}, {V.graph.sizevars.size_hints(buf.get_stride())}, device='{buf.get_device()}', dtype={buf.get_dtype()})" # noqa: B950 line too long
|
||||
f"{var_name} = rand_strided({size}, {stride}, device='{buf.get_device()}', dtype={buf.get_dtype()})" # noqa: B950 line too long
|
||||
)
|
||||
elif arg_name in V.graph.constants:
|
||||
# note that random seed is put in V.graph.constants
|
||||
const_tensor = V.graph.constants[arg_name]
|
||||
size = V.graph.sizevars.size_hints(
|
||||
const_tensor.size(), fallback=config.unbacked_symint_fallback
|
||||
)
|
||||
stride = V.graph.sizevars.size_hints(
|
||||
const_tensor.stride(), fallback=config.unbacked_symint_fallback
|
||||
)
|
||||
result.writeline(
|
||||
f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # type: ignore[arg-type] # noqa: B950 line too long
|
||||
f"{var_name} = rand_strided({size}, {stride}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # type: ignore[arg-type] # noqa: B950 line too long
|
||||
)
|
||||
elif isinstance(arg_sig, SizeArg):
|
||||
symval_hint = V.graph.sizevars.size_hint(arg_sig.expr)
|
||||
|
@ -2385,6 +2385,9 @@ class PythonWrapperCodegen(CodeGen):
|
||||
# constant now, need type info. I agree, this needs type info, and while this is not true type info
|
||||
# it suffices as a type hint for the purposes of producing the correct code for this type.
|
||||
arg = SymbolicCallArg(sym, tree.numel)
|
||||
|
||||
is_benchmark_kernel = kernel_name == ""
|
||||
if not is_benchmark_kernel:
|
||||
self.writeline(SymbolicCallArgLine(self, arg, V.graph))
|
||||
|
||||
return arg
|
||||
|
Reference in New Issue
Block a user