[multi-kernel] shape-similarity kernel selection (#163090)

Introduces a variant of size-hint multi-kernel: for novel runtime shapes, instead of performing full benchmarking to determine the optimal kernel, it selects one of the kernels pre-generated from multi-kernel hints, based on the similarity between the hint and runtime input & output shapes (L1 distance in log2 space).
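To make the selection concrete, here is a minimal sketch of the heuristic, assuming two pre-generated matmul kernels keyed by the shapes they were tuned for (the kernel names and keys below are hypothetical; the actual logic lives in `SizeHintMultiKernelCall._dist_heuristic` in the diff):

```python
import math

# Hypothetical shape keys for two pre-generated matmul kernels, tuned at hints
# 64 and 4096; names and keys are illustrative, not the real codegen output.
kernel_keys = {
    ((64, 64), (64, 64), (64, 64)): "triton_mm_hint_64",
    ((4096, 4096), (4096, 4096), (4096, 4096)): "triton_mm_hint_4096",
}

def log2_l1_dist(shapes_a, shapes_b):
    # L1 distance between corresponding dims, taken in log2 space.
    def d(x, y):
        lx = math.log2(x) if x > 0 else -1
        ly = math.log2(y) if y > 0 else -1
        return abs(lx - ly)
    return sum(d(x, y) for sa, sb in zip(shapes_a, shapes_b) for x, y in zip(sa, sb))

# A novel runtime shape: a 256 x 256 matmul (inputs + output).
runtime_shapes = ((256, 256), (256, 256), (256, 256))
picked = min(kernel_keys, key=lambda k: log2_l1_dist(runtime_shapes, k))
print(kernel_keys[picked])  # 256 is closer to 64 than to 4096 in log2 space
```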

Some caveats/changes:
- Size-hint multi-kernel now only kicks in if the kernel has dynamic shapes
- Pre-generation still only performs a 1-d search over the specified hints, e.g. `matmul([s0, s1], [s1, s2])` with size hints `[64, 256]` generates only 2 kernels, tuned for shapes `([64, 64], [64, 64])` and `([256, 256], [256, 256])`. Extending this to a reasonable n-d search (via a user API?) is left as a future extension; see the sketch after this list
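For illustration, a rough sketch of how a single hint is substituted into every dynamic dim to produce one tuning shape per hint, mirroring `_make_shape_cache_key` in the diff (the `fill_hint` helper is hypothetical):

```python
import sympy

s0, s1, s2 = sympy.symbols("s0 s1 s2", positive=True, integer=True)

# Symbolic input shapes for matmul([s0, s1], [s1, s2]).
symbolic_shapes = ((s0, s1), (s1, s2))

def fill_hint(shapes, hint):
    # Substitute a single hint for every dynamic (symbolic) dim: a 1-d search,
    # with no combinatorial mixing of different hints across dims.
    return tuple(
        tuple(
            hint if isinstance(d, sympy.Expr) and not isinstance(d, sympy.Integer) else d
            for d in shape
        )
        for shape in shapes
    )

# Hints [64, 256] yield exactly two tuning shapes:
print([fill_hint(symbolic_shapes, h) for h in (64, 256)])
# [((64, 64), (64, 64)), ((256, 256), (256, 256))]
```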

Benchmarking results, compared against multi-kernel with full benchmarking (hints 64 and 4096) and against compiling with the ground-truth hint:
<img width="1902" height="1222" alt="550541081_1088709150049684_6528797079439730237_n" src="https://github.com/user-attachments/assets/056cca48-c16a-4451-9b4a-fa13a7a058a9" />

Full benchmarking performing worse here is surprising, but we saw similar spikes in #156628.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163090
Approved by: https://github.com/bobrenjc93
Author: Pian Pawakapan
Date: 2025-09-23 21:00:42 +00:00
Committed by: PyTorch MergeBot
Parent: 22c5e8c17c
Commit: 2a9745de3c
4 changed files with 239 additions and 75 deletions


@@ -50,6 +50,16 @@ def _contains_multi_kernel_code(wrapper_code: str):
)
def _contains_size_hint_multi_kernel_code(wrapper_code: str):
return (
re.search(
r"multi_kernel_[^ ]* = async_compile.size_hint_multi_kernel[(]",
wrapper_code,
)
is not None
)
def make_cpp_wrapper_test(orig_test, **extra_args):
"""
Wrap an existing test into a new test with cpp-wrapper enabled.
@@ -115,6 +125,7 @@ class MultiKernelTest(TestCase):
)
x = torch.randn(4096, 4096, device=GPU_TYPE)
y = torch.randn(4096, 4096, device=GPU_TYPE)
torch._dynamo.mark_dynamic(x, 0)
act, wrapper_code = run_and_get_code(compiled_fn, x, y)
ref = fn(x, y)
@@ -123,7 +134,7 @@ class MultiKernelTest(TestCase):
# We mainly care about the wrapper for the final pass here.
wrapper_code = wrapper_code[-1]
self.assertEqual(ref, act)
self.assertTrue(_contains_multi_kernel_code(wrapper_code))
self.assertTrue(_contains_size_hint_multi_kernel_code(wrapper_code))
@requires_triton()
# TODO: bobrenjc93 to fix multi-kernel for ROCM
@@ -142,6 +153,7 @@ class MultiKernelTest(TestCase):
)
x = torch.randn(4096, 4096, device=GPU_TYPE)
y = torch.randn(4096, 4096, device=GPU_TYPE)
torch._dynamo.mark_dynamic(x, 0)
act, wrapper_code = run_and_get_code(compiled_fn, x, y)
ref = fn(x, y)
@@ -150,7 +162,7 @@ class MultiKernelTest(TestCase):
# We mainly care about the wrapper for the final pass here.
wrapper_code = wrapper_code[-1]
self.assertEqual(ref, act)
self.assertTrue(_contains_multi_kernel_code(wrapper_code))
self.assertTrue(_contains_size_hint_multi_kernel_code(wrapper_code))
@parametrize("force_kernel", (0, 1))
@unittest.mock.patch.dict(


@@ -521,6 +521,11 @@ class AsyncCompile:
# no need to call this in parallel since the sub-kernels are already parallel tasks
return MultiKernelCall(*args, **kwargs)
def size_hint_multi_kernel(self, *args, **kwargs) -> Any:
from torch._inductor.codegen.multi_kernel import SizeHintMultiKernelCall
return SizeHintMultiKernelCall(*args, **kwargs)
def cpp(self, source_code: str):
kernel_code_log.info("CPP Kernel:\n%s", source_code)
if get_compile_threads() <= 1:


@@ -1,8 +1,10 @@
# mypy: allow-untyped-defs
import functools
import logging
import math
import os
import pathlib
from typing import Any, Optional, Union
from torch._inductor.ir import MultiTemplateBuffer
from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
@@ -31,7 +33,13 @@ class MultiKernelState:
self.subkernel_to_kernel_name = {}
self.kernel_defs = IndentedBuffer()
def define_kernel(self, kernels):
def define_kernel(
self,
kernels: list[Any],
kernel_shape_keys: Optional[
list[Union[None, tuple[tuple[int, ...], ...]]]
] = None,
) -> str:
"""
Previously we name the multi kernel as "multi_kernel_{kernel_names[0]}".
This has some minor issue.
@@ -45,6 +53,12 @@ class MultiKernelState:
The only different is cache eviction policy.
We should name the multi-kernel differently in these 2 cases.
kernels:
A list of kernels
kernel_shape_keys:
Specified for size-hint multi-kernels.
Each list element is a shape key, corresponding to the concrete input & output size hints each kernel was tuned for.
"""
# Prevent circular import
from ..select_algorithm import TritonTemplateKernel
@@ -68,9 +82,7 @@ class MultiKernelState:
kernels[0].output_node, MultiTemplateBuffer
):
for i, kernel in enumerate(kernels):
additional_call_args, additional_arg_types = (
kernel.additional_call_args_and_types()
)
additional_call_args, _ = kernel.additional_call_args_and_types()
if i not in arg_index:
arg_index[i] = []
arg_index[i].append(slice(0, len(call_args)))
@@ -85,7 +97,7 @@ class MultiKernelState:
for i in range(len(kernels)):
arg_index[i] = [slice(0, len(call_args))]
shape_specialize = isinstance(kernels[0], TritonTemplateKernel)
keyed_by_sizes = kernel_shape_keys is not None
buf = self.kernel_defs
buf.writeline("")
buf.writeline("arg_index = {")
@@ -93,13 +105,26 @@ class MultiKernelState:
slice_reprs = ", ".join(repr(s) for s in slice_list)
buf.writeline(f" {key}: [{slice_reprs}],")
buf.writeline("}")
buf.writeline(
f"{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, ["
)
with buf.indent():
for name in kernel_names:
buf.writeline(f"{name},")
buf.writeline(f"], arg_index=arg_index, shape_specialize={shape_specialize})")
if not keyed_by_sizes: # no size hint keys, just call with list of kernels
buf.writeline(
f"{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, ["
)
with buf.indent():
for name in kernel_names:
buf.writeline(f"{name},")
buf.writeline("], arg_index=arg_index)")
else: # call with dict[size hint key, kernel]
assert isinstance(kernels[0], TritonTemplateKernel)
assert isinstance(kernel_shape_keys, list)
assert len(kernels) == len(kernel_shape_keys)
buf.writeline(
f"{multi_kernel_name} = async_compile.size_hint_multi_kernel({multi_kernel_name!r}, {{"
)
with buf.indent():
for shape_key, name in zip(kernel_shape_keys, kernel_names):
buf.writeline(f"{shape_key}: {name},")
buf.writeline("}, arg_index=arg_index)")
if config.triton.autotune_at_compile_time:
V.graph.wrapper_code.src_to_kernel["\n".join(kernel_names)] = (
@@ -266,8 +291,8 @@ class MultiKernelCall:
This class is called at run time to actually run the kernel
"""
def __init__(self, multi_kernel_name, kernels, arg_index, shape_specialize=False):
assert len(kernels) >= 2
def __init__(self, multi_kernel_name, kernels, arg_index):
assert len(kernels) >= 1
self._kernels = kernels
self.multi_kernel_name = multi_kernel_name
@@ -287,13 +312,6 @@ class MultiKernelCall:
self._recorded = False
# This means for each unique shape we will do a separate assessment
# for which kernel is the best. This is particularly useful for matmul
# kernels where the best kernel can vary based on very small differences
# in shape.
self._shape_specialize = shape_specialize
self._shape_cache = {}
def cache_file_path(self):
key = code_hash(
",".join(
@@ -415,20 +433,6 @@ class MultiKernelCall:
return V.graph.multi_kernel_to_choice[multi_kernel_name]
def run(self, *args, **kwargs):
if self._shape_specialize:
cache_key = self._get_shape_cache_key(*args, **kwargs)
cached_choice = self._get_cached_shape_choice(cache_key)
if cached_choice is not None:
self.picked_kernel = cached_choice
log.debug(
"using cached shape-specialized choice %dth sub-kernel in %s. Cache key: %s",
self.picked_kernel,
[k.inductor_meta.get("kernel_name") for k in self.kernels],
cache_key,
)
else:
self._select_kernel_by_shape(*args, **kwargs)
if self.picked_kernel is None:
timings = self.benchmark_sub_kernels(*args, **kwargs)
self.picked_kernel = timings.index(min(timings))
@@ -460,38 +464,6 @@ class MultiKernelCall:
filtered_args = self._get_filtered_args(args, self.picked_kernel)
run(*filtered_args, **kwargs)
def _get_shape_cache_key(self, *args, **kwargs):
"""
Generate a cache key based on tensor shapes for shape-specialized dispatch.
"""
shapes = []
for arg in args:
if hasattr(arg, "shape"):
shapes.append(tuple(arg.shape))
return tuple(shapes)
def _get_cached_shape_choice(self, cache_key):
"""
Get cached kernel choice for a specific shape.
"""
return self._shape_cache.get(cache_key)
def _cache_shape_choice(self, cache_key, kernel_idx):
"""
Cache kernel choice for a specific shape
"""
self._shape_cache[cache_key] = kernel_idx
def _select_kernel_by_shape(self, *args, **kwargs):
"""
Benchmark kernels for a particular shape and return the
best kernel for this shape.
"""
shape_key = self._get_shape_cache_key(*args, **kwargs)
timings = self.benchmark_sub_kernels(*args, **kwargs)
self.picked_kernel = timings.index(min(timings))
self._cache_shape_choice(shape_key, self.picked_kernel)
def _metrics_table_row(self, timings):
def get_kernel_path(k):
return k.fn.fn.__code__.co_filename
@@ -511,3 +483,121 @@ class MultiKernelCall:
row[f"kernel{i}_path"] = ""
row[f"kernel{i}_latency"] = ""
return row
class SizeHintMultiKernel(MultiKernel):
"""
Version of multi-kernel that generates kernels based on specified size hints.
Currently only performs 1-d search over hints; doesn't perform combinatorial n-d search
if n > 1 dynamic dimensions are specified.
e.g. matmul([s0, s1], [s1, s2]) with size-hints [64, 256] only generates 2 kernels,
based on tuning shapes ([64, 64], [64, 64]) and ([256, 256], [256, 256])
"""
def __init__(self, kernels):
assert isinstance(kernels, dict) and len(kernels) >= 1
self.kernels, self.kernel_shape_keys = [], []
for shape_key, kernel in kernels.items():
self.kernels.append(kernel)
self.kernel_shape_keys.append(shape_key)
self.kernel_name = V.graph.wrapper_code.multi_kernel_state.define_kernel(
self.kernels, self.kernel_shape_keys
)
# need this since some code in inductor check if the kernel object has an args
# attribute to decide if it's a non-null kernel.
self.args = object()
class SizeHintMultiKernelCall(MultiKernelCall):
"""
Runtime class for size-hint multi-kernels.
Instead of having a plain list of kernels to benchmark over, keys them by input & output shapes,
and optionally perform shape-based selection. The pre-generated kernel is chosen based on the shape keys,
with the heuristic being log2 l1 distance between the pre-generated / runtime input & output shapes.
"""
def __init__(self, multi_kernel_name, kernels, arg_index):
super().__init__(multi_kernel_name, list(kernels.values()), arg_index)
self._kernel_hints = list(kernels.keys())
# Caches results for unique shapes.
self._shape_cache = {}
def _get_shape_cache_key(self, *args, **kwargs):
"""
Generate a cache key based on tensor shapes for shape-specialized dispatch.
"""
shapes = []
for arg in args:
if hasattr(arg, "shape"):
shapes.append(tuple(arg.shape))
return tuple(shapes)
def _get_cached_shape_choice(self, cache_key):
"""
Get cached kernel choice for a specific shape.
"""
return self._shape_cache.get(cache_key)
def _cache_shape_choice(self, cache_key, kernel_idx):
"""
Cache kernel choice for a specific shape.
"""
self._shape_cache[cache_key] = kernel_idx
def _dist_heuristic(self, k1, k2):
"""
log2 L1 distance heuristic for kernel selection.
"""
def dist(x, y):
lx = math.log2(x) if x > 0 else -1
ly = math.log2(y) if y > 0 else -1
return abs(lx - ly)
out = 0
for s1, s2 in zip(k1, k2):
out += sum(dist(x, y) for x, y in zip(s1, s2))
return out
def run(self, *args, **kwargs):
cache_key = self._get_shape_cache_key(*args, **kwargs)
cached_choice = self._get_cached_shape_choice(cache_key)
if cached_choice is not None:
self.picked_kernel = cached_choice
log.debug(
"using cached shape-specialized choice %dth sub-kernel in %s. Cache key: %s",
self.picked_kernel,
[k.inductor_meta.get("kernel_name") for k in self.kernels],
cache_key,
)
else:
self._select_kernel_by_shape(*args, **kwargs)
if not self._recorded:
self._recorded = True
picked_kernel_name = self.kernels[self.picked_kernel].inductor_meta.get(
"kernel_name"
)
assert picked_kernel_name is not None
self.record_choice(self.multi_kernel_name, picked_kernel_name)
run = self.kernels[self.picked_kernel].run # type: ignore[method-assign]
filtered_args = self._get_filtered_args(args, self.picked_kernel)
run(*filtered_args, **kwargs)
def _select_kernel_by_shape(self, *args, **kwargs):
"""
Benchmark kernels for a particular shape and return the
best kernel for this shape.
"""
shape_key = self._get_shape_cache_key(*args, **kwargs)
dists = [
self._dist_heuristic(shape_key, key) if key is not None else 2**62
for key in self._kernel_hints
]
self.picked_kernel = dists.index(min(dists))
self._cache_shape_choice(shape_key, self.picked_kernel)


@@ -60,7 +60,7 @@ from ..utils import (
from ..virtualized import ops, OpsWrapper, V
from .block_analysis import BlockPatternMatcher
from .common import CSEVariable, index_prevent_reordering, Kernel, PythonPrinter
from .multi_kernel import MultiKernel
from .multi_kernel import MultiKernel, SizeHintMultiKernel
from .simd_kernel_features import (
DisableReduction,
EnableReduction,
@@ -1689,6 +1689,51 @@ class SIMDScheduling(BaseScheduling):
return kernel
def _get_multikernel_shapes(
self, node: MultiTemplateBuffer
) -> tuple[tuple[int, ...], ...]:
from ..ir import IRNode
def get_size(arg):
if not isinstance(arg, IRNode) or (size := arg.maybe_get_size()) is None:
return None
return tuple(s for s in size)
out = []
for arg in list(node.inputs) + [node]:
if isinstance(arg, (list, tuple)):
out.append(tuple(get_size(_arg) for _arg in arg))
else:
out.append(get_size(arg))
return tuple(out)
def _kernel_has_dynamic_shapes(self, node: MultiTemplateBuffer) -> bool:
shapes = self._get_multikernel_shapes(node)
return any(
any(
isinstance(s, sympy.Expr) and not isinstance(s, sympy.Integer)
for s in shape
)
for shape in shapes
)
def _make_shape_cache_key(
self, node: MultiTemplateBuffer, hint: int
) -> tuple[tuple[int, ...], ...]:
"""
Returns cache key for hint-based multi-graph; key is tuple of shapes with hint filled in.
"""
shapes = self._get_multikernel_shapes(node)
return tuple(
tuple(
hint
if isinstance(s, sympy.Expr) and not isinstance(s, sympy.Integer)
else s
for s in shape
)
for shape in shapes
)
def codegen_template(
self,
template_node,
@@ -1711,11 +1756,16 @@
if (
isinstance(template_node.node, MultiTemplateBuffer)
and template_node.node._make_kernel_renders
and len(template_node.node._make_kernel_renders) > 1
and self._kernel_has_dynamic_shapes(template_node.node)
):
kernels = []
kernels = {}
src_codes = []
for make_kernel_render in template_node.node._make_kernel_renders.values():
for (
size_hint,
make_kernel_render,
) in template_node.node._make_kernel_renders.items():
kernel, render = make_kernel_render(
template_node.node, hint_override=hint_override
)
@@ -1732,6 +1782,8 @@
assert isinstance(src_code, str)
src_codes.append(src_code)
else:
if size_hint is None:
continue # skip kernel generation based on real runtime value; only use hints
kernel = self._codegen_single_template(
kernel,
render,
@@ -1740,13 +1792,18 @@
prologue_nodes,
only_gen_src_code=False,
)
kernels.append(kernel)
shape_cache_key = (
None
if size_hint is None
else self._make_shape_cache_key(template_node.node, size_hint)
)
kernels[shape_cache_key] = kernel
if only_gen_src_code:
return "\n\n".join(src_codes)
MultiKernel.merge_workspaces_inplace(kernels)
multi_kernel = MultiKernel(kernels)
MultiKernel.merge_workspaces_inplace(list(kernels.values()))
multi_kernel = SizeHintMultiKernel(kernels)
node_schedule = [*prologue_nodes, template_node, *epilogue_nodes]
self.codegen_comment(node_schedule)