multi-kernel matmuls based on varying hint sizes (#156628)

The core idea is to generate multiple matmul kernels using different hints for symbolic variables, then select the most appropriate one at runtime for each unique shape we encounter. You can find some early experimentation details in these posts:

https://fb.workplace.com/groups/8940092306109185/posts/9803850776399996/
https://fb.workplace.com/groups/8940092306109185/posts/9695805170537891/
https://fb.workplace.com/groups/257735836456307/posts/906589324904285/
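To make the core idea concrete before diving into the graphs, here is a minimal, self-contained sketch (not the actual Inductor implementation): one matmul kernel is produced per hint, and a dispatcher picks one per runtime shape. The nearest-hint selection rule below is purely illustrative; in this PR the choice is made by the `MultiKernelCall` machinery at runtime.

```python
# Illustrative sketch only -- not the Inductor implementation. One "kernel" is
# built per hint size, and a dispatcher picks one for each runtime shape.
import torch

HINTS = [64, 256, 4096]  # example hint sizes, matching the experiments below


def compile_for_hint(hint: int):
    # Stand-in for "generate a matmul kernel tuned assuming the dynamic
    # dimension is roughly `hint`"; here it is just eager matmul.
    def kernel(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return a @ b

    return kernel


KERNELS = {hint: compile_for_hint(hint) for hint in HINTS}


def multi_kernel_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Hypothetical selection rule: pick the kernel whose hint is closest to the
    # runtime size of the dynamic dimension.
    runtime_size = a.shape[0]
    best_hint = min(HINTS, key=lambda h: abs(h - runtime_size))
    return KERNELS[best_hint](a, b)


out = multi_kernel_matmul(torch.randn(300, 128), torch.randn(128, 64))
```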

Here’s a graph illustrating the empirically observed worst-case performance if an oracle always selected the least optimal hint for a given runtime size:

![image](https://github.com/user-attachments/assets/6d90ee06-a572-453e-9cba-03006f343301)

This graph illustrates the performance of a hint size of 64 relative to the worst case. Notice that as the runtime sizes increase, the performance gradually approaches the worst case:

![image](https://github.com/user-attachments/assets/85ad49fe-165a-474c-8d03-db2e57654213)

This graph shows the performance of a hint size of 4096 — very poor for small sizes, and also suboptimal for some mid-sized shapes:

![image](https://github.com/user-attachments/assets/adea1106-3bc8-40f3-97b0-20d940fb74f1)

Finally, here’s the graph that motivated this PR. It illustrates the performance when selecting the best of three kernels generated with three different hints — 64, 256, and 4096:

![image](https://github.com/user-attachments/assets/a7cb0ce5-8139-48b1-b5c9-7670e75cbfce)

## How to review this PR

At a high level, this extends @shunting314's multi-kernel abstraction to support varying GEMM choices driven by different hints. A few key points:

1. Unlike reduction kernels, triton template matmuls pass their grid as arguments to the kernel. This PR updates `MultiKernelCall` to support kernels with varying arguments.
2. The `V.graph.sizevars.size_hints` API is extended to accept a `hint_override`, allowing us to substitute the example input’s size hint with a custom value when generating multiple kernels.
3. The choice generation and benchmarking logic is updated to support multiple hint values. One kernel is generated per value in `torch._inductor.config.multi_kernel_hints`, and at runtime we select the most suitable kernel for the current shape (see the illustrative usage sketch after this list).
4. To keep this PR scoped, cpp wrapper codegen support is not added here; it will come in the next PR.
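As a concrete illustration of points 2 and 3, here is a hypothetical usage sketch. It is not taken from this PR's tests; the exact knobs and their defaults may differ, and the hint values simply mirror the experiments above.

```python
# Hypothetical usage sketch based on points 2-3 above; exact knobs and
# defaults may differ from the final implementation.
import torch
import torch._inductor.config as inductor_config

# One kernel is generated per hint listed here (values match the experiments above).
inductor_config.multi_kernel_hints = [64, 256, 4096]
# Multi-kernel GEMMs build on the autotuned template choices (see the
# `config.max_autotune or config.max_autotune_gemm` guard in the diff below).
inductor_config.max_autotune_gemm = True


@torch.compile(dynamic=True)
def mm(a, b):
    return a @ b


# Assumes a CUDA device; the most suitable generated kernel is selected at
# runtime for each shape encountered.
a = torch.randn(512, 1024, device="cuda")
b = torch.randn(1024, 2048, device="cuda")
mm(a, b)
```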

## Results

The following basic test shows the multi-kernel dispatch working: we no longer see significant variance based on the original hint size: https://gist.github.com/bobrenjc93/ba711d529e65fd65839b34799f6323ec

Before
```
Hint\Runtime |     64     |    256     |    4096
---------------------------------------------------
     64      |   0.0948   |   0.3124   |   4.9477
    256      |   0.2243   |   0.2256   |   3.3880
    4096     |   0.3384   |   0.3404   |   3.3010
```

After
```
Hint\Runtime |     64     |    256     |    4096
---------------------------------------------------
     64      |   0.0951   |   0.2289   |   3.3013
    256      |   0.0952   |   0.2258   |   3.4045
    4096     |   0.0957   |   0.2231   |   3.3146
```

We also see an average speedup of 5.04% across the matrix of all hint/runtime pairs in [64, 4096], sampled at increments of 64: https://docs.google.com/spreadsheets/d/12TmYUDrAAFASGuP3POXTKPeAvQWIRzKzdrVSIb3vQkA/edit?gid=480268938#gid=480268938

![Worst Case, multi-kernel](https://github.com/user-attachments/assets/712df23b-87e2-4d9d-95c2-cc25305ba2ed)

NB: This is just the beginning; I plan to do more investigation to further improve on this initial result.

For posterity, the script used to generate that matrix is here: https://gist.github.com/bobrenjc93/c211fd0bd97fad8f46b91ad9dee76ad0

HUD benchmark runs:
base: https://github.com/pytorch/pytorch/actions/runs/15889871988
head: https://github.com/pytorch/pytorch/actions/runs/15889876842

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156628
Approved by: https://github.com/jansel
bobrenjc93 authored 2025-07-11 20:54:18 -07:00, committed by PyTorch MergeBot
parent 191693ac85, commit 5221448574
15 changed files with 635 additions and 139 deletions


```diff
@@ -294,6 +294,15 @@ RecordedEventsType = list[tuple[str, list[Any], dict[str, Any]]]
class TritonTemplateKernel(TritonKernel):
"""
A specialized kernel class for Triton templates that handles code generation
for templated Triton kernels.
This class extends TritonKernel to provide additional functionality for
template-based kernel generation, including support for subgraphs, workspace
arguments, and prologue/epilogue fusion.
"""
def __init__(
self,
kernel_name,
@@ -314,6 +323,7 @@ class TritonTemplateKernel(TritonKernel):
subgraphs: Optional[list[ir.ComputedBuffer]] = None,
workspace_arg: Optional[WorkspaceArg] = None,
prologue_loads_all_inputs=False,
hint_override: Optional[int] = None,
) -> None:
numel = sympy_product(output_node.get_size())
super().__init__(
@@ -322,6 +332,7 @@ class TritonTemplateKernel(TritonKernel):
"r0_": sympy.S.One,
},
features=SIMDKernelFeatures([], numel),
hint_override=hint_override,
)
self.input_nodes = input_nodes
self.output_node = output_node
@@ -1096,16 +1107,26 @@ class TritonTemplateKernel(TritonKernel):
def codegen_range_tree(self):
pass # ignore default codegen
def additional_call_args_and_types(self):
if isinstance(self.grid_fn, SymbolicGridFn):
grid_args = self.grid_fn.sympy_call(*self.call_sizes, self.meta)
assert len(grid_args) in (0, 3), "grid_fn should return 3 values"
return (grid_args, map(type, grid_args))
elif all(isinstance(x, (int, sympy.Integer)) for x in self.call_sizes):
grid_args = self.grid_fn(*map(int, self.call_sizes), self.meta)
assert len(grid_args) in (0, 3), "grid_fn should return 3 values"
return (grid_args, map(type, grid_args))
return ((), ())
def call_kernel(self, name: str, node: Optional[ir.IRNode] = None):
wrapper = V.graph.wrapper_code
_, call_args, _, arg_types = self.args.python_argdefs()
grid_args = ()
if isinstance(self.grid_fn, SymbolicGridFn):
grid_args = self.grid_fn.sympy_call(*self.call_sizes, self.meta)
elif all(isinstance(x, (int, sympy.Integer)) for x in self.call_sizes):
grid_args = self.grid_fn(*map(int, self.call_sizes), self.meta)
else:
additional_call_args, additional_arg_types = (
self.additional_call_args_and_types()
)
if not additional_call_args:
assert not V.graph.cpp_wrapper, "cpp_wrapper requires SymbolicGridFn"
wrapper.add_import_once(f"import {self.grid_fn.__module__}")
meta = wrapper.add_meta_once(self.meta)
@@ -1114,9 +1135,9 @@ class TritonTemplateKernel(TritonKernel):
f"*{fn_name}({', '.join(map(pexpr, self.call_sizes))}, {meta})"
)
arg_types.append(None)
assert len(grid_args) in (0, 3), "grid_fn should return 3 values"
call_args.extend(grid_args)
arg_types.extend(map(type, grid_args))
call_args.extend(additional_call_args)
arg_types.extend(additional_arg_types)
if self.workspace_arg is not None:
wrapper.generate_workspace_allocation(self.workspace_arg)
@@ -1202,6 +1223,7 @@ class GeneratedCodeCache:
num_consumer_groups: int,
num_buffers_warp_spec: int,
kwargs: dict[str, Any],
hint_override: Optional[int] = None,
) -> Optional[str]:
def layout_key(layout: ir.Layout) -> str:
assert not isinstance(layout, ir.FlexibleLayout)
@@ -1252,6 +1274,7 @@ class GeneratedCodeCache:
"num_buffers_warp_spec": num_buffers_warp_spec,
"epilogue_fn_hash": epilogue_fn_hash,
"kwargs": kwargs,
"hint_override": hint_override,
}
)
@@ -1356,6 +1379,7 @@ class TritonTemplate(KernelTemplate):
layout: ir.Layout,
kwargs: dict[str, Any],
generate_with_caching,
hint_override: Optional[int] = None,
) -> Optional[GenerateAndLoadResult]:
"""Generate the python code and load it into the current process"""
caching_enabled = (
@@ -1428,6 +1452,7 @@ class TritonTemplate(KernelTemplate):
output_node=fake_out,
workspace_arg=workspace_arg,
use_jit=False,
hint_override=hint_override,
**kernel_options,
)
@@ -1547,6 +1572,7 @@ class TritonTemplate(KernelTemplate):
call_sizes: Optional[Sequence[sympy.core.symbol.Symbol]] = None,
workspace_arg: Optional[WorkspaceArg] = None,
generate_with_caching=False,
hint_override: Optional[int] = None,
**kwargs,
):
"""This function generates a TritonTemplateCaller
@@ -1591,6 +1617,7 @@ class TritonTemplate(KernelTemplate):
layout,
kwargs,
generate_with_caching and self._cache_codegen_enabled_for_template,
hint_override=hint_override,
)
# May happen as result of dev by 0.
@@ -1612,6 +1639,7 @@ class TritonTemplate(KernelTemplate):
extra_args = V.graph.sizevars.size_hints(
map(sympy.expand, result.kernel_args_sizevars_keys),
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
)
kernel_hash_name = f"triton_{self.name}_{next(self.index_counter)}"
@@ -1635,13 +1663,14 @@ class TritonTemplate(KernelTemplate):
options = result.kernel_options
def make_kernel_render(out_node):
def make_kernel_render(out_node, hint_override: Optional[int] = None):
assert result is not None
kernel = self.kernel_type(
kernel_name=str(Placeholder.KERNEL_NAME),
output_node=out_node,
workspace_arg=workspace_arg,
use_jit=False,
hint_override=hint_override,
**options,
)
render = functools.partial(
@@ -1657,6 +1686,7 @@ class TritonTemplate(KernelTemplate):
*V.graph.sizevars.size_hints(
call_sizes,
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
),
kwargs,
)
@@ -1708,6 +1738,7 @@ class TritonTemplate(KernelTemplate):
mutated_inputs=mutated_inputs,
workspace_arg=workspace_arg,
allowed_prologue_inps=result.prologue_supported_inputs,
hint_override=hint_override,
)
@@ -1783,6 +1814,7 @@ class TritonTemplateCaller(ir.TritonTemplateCallerBase):
mutated_inputs=None,
workspace_arg: Optional[WorkspaceArg] = None,
allowed_prologue_inps: Optional[OrderedSet[str]] = None,
hint_override: Optional[int] = None,
) -> None:
super().__init__(name, input_nodes, layout, description)
self.make_kernel_render = make_kernel_render
@@ -1802,6 +1834,7 @@ class TritonTemplateCaller(ir.TritonTemplateCallerBase):
self.allowed_prologue_inps = (
allowed_prologue_inps if allowed_prologue_inps is not None else OrderedSet()
)
self.hint_override = hint_override
def benchmark(self, *args, out):
assert self.bmreq is not None
@@ -2256,7 +2289,7 @@ class AlgorithmSelectorCache(PersistentCache):
# TODO(nmacchioni): remove this hacky way to tell if we ran benchmarking
has_autotuned = False
def benchmark(choices):
def benchmark(choices, hint_override: Optional[int] = None):
nonlocal has_autotuned
# TODO(nmacchioni): remove this hacky way to tell if we ran benchmarking
has_autotuned = True
@@ -2264,13 +2297,13 @@ class AlgorithmSelectorCache(PersistentCache):
# TODO(nmacchioni): remove this layer of abstraction
# construct `benchmark_fn` which should pick between in-process and sub-process autotuning
benchmark_fn = self.make_benchmark_fn(
choices, input_nodes, layout, input_gen_fns
choices, input_nodes, layout, input_gen_fns, hint_override=hint_override
)
# `benchmark_fn(choices)` will execute each choice, and return a dict[choice, timing] which
# maps each choice to its runtime, calculated by the specified benchmarker, in milliseconds
return benchmark_fn(choices)
def autotune(choices):
def autotune(choices, hint_override: Optional[int] = None):
log.debug("Starting autotuning")
with dynamo_timed(
@@ -2292,13 +2325,13 @@ class AlgorithmSelectorCache(PersistentCache):
),
},
):
return benchmark(choices)
return benchmark(choices, hint_override=hint_override)
if config.autotune_in_subproc:
# Initialize the suprocess pool so it will warmup early.
torch._inductor.autotune_process.get_tuning_process_pool()
def do_autotuning(choices, precompile_fn):
def do_autotuning(choices, precompile_fn, hint_override: Optional[int] = None):
precompile_start_ts = time.time()
with dynamo_timed(
f"{name}_template_precompiling",
@@ -2319,7 +2352,8 @@ class AlgorithmSelectorCache(PersistentCache):
candidates,
name,
inputs_key,
autotune,
lambda choices: autotune(choices, hint_override=hint_override),
hint_override=hint_override,
)
choices = self.prune_choices_postscreen(
choices, timings, name, inputs_key, self.prescreening_cache
@@ -2332,7 +2366,8 @@ class AlgorithmSelectorCache(PersistentCache):
choices,
name,
inputs_key,
autotune,
lambda choices: autotune(choices, hint_override=hint_override),
hint_override=hint_override,
)
autotune_elapse = time.time() - autotune_start_ts
@@ -2355,6 +2390,7 @@ class AlgorithmSelectorCache(PersistentCache):
autotune_elapse,
precompile_elapse,
prescreening_elapse,
hint_override=hint_override,
)
def profiler_bench_function():
@@ -2389,8 +2425,16 @@ class AlgorithmSelectorCache(PersistentCache):
if return_multi_template and (config.max_autotune or config.max_autotune_gemm):
def get_timings():
timings = do_autotuning(choices, precompile_fn)
def get_timings(hint_override: Optional[int] = None):
filtered_choices = [
c
for c in choices
if not hasattr(c, "hint_override")
or c.hint_override == hint_override
]
timings = do_autotuning(
filtered_choices, precompile_fn, hint_override=hint_override
)
min_extern_choice = float("inf")
for choice, timing in timings.items():
if isinstance(choice, ExternKernelCaller):
@@ -2627,6 +2671,7 @@ class AlgorithmSelectorCache(PersistentCache):
input_nodes: list[ir.IRNode],
layout: ir.Layout,
input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]],
hint_override: Optional[int] = None,
) -> AutotuneArgs:
"""
Factory method to create AutotuneArgs from a list of ChoiceCallers.
@@ -2636,7 +2681,9 @@ class AlgorithmSelectorCache(PersistentCache):
# de-duplicate args
unique_example_inputs = {
x.get_name(): input_gen_fns.get(i, cls.benchmark_example_value)(x)
x.get_name(): input_gen_fns.get(
i, lambda x: cls.benchmark_example_value(x, hint_override=hint_override)
)(x)
for i, x in enumerate(input_nodes)
}
example_inputs = list(unique_example_inputs.values())
@@ -2649,20 +2696,23 @@ class AlgorithmSelectorCache(PersistentCache):
V.graph.sizevars.size_hints(
input_node.get_size(),
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
),
V.graph.sizevars.size_hints(
input_node.get_stride(),
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
),
V.graph.sizevars.size_hint(
input_node.get_layout().offset,
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
),
)
)
for input_node in input_nodes
]
out = cls.benchmark_example_value(layout)
out = cls.benchmark_example_value(layout, hint_override=hint_override)
out_extern = torch.as_strided(
out, out.size(), out.stride(), V.graph.sizevars.size_hint(layout.offset)
)
@@ -2771,8 +2821,11 @@ class AlgorithmSelectorCache(PersistentCache):
input_nodes: list[ir.IRNode],
layout: ir.Layout,
input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]],
hint_override: Optional[int] = None,
) -> dict[ChoiceCaller, float]:
inputs = cls.get_inputs(choices, input_nodes, layout, input_gen_fns)
inputs = cls.get_inputs(
choices, input_nodes, layout, input_gen_fns, hint_override=hint_override
)
return cls.benchmark_choices(choices, inputs)
@classmethod
@@ -2782,6 +2835,7 @@ class AlgorithmSelectorCache(PersistentCache):
input_nodes: list[ir.IRNode],
layout: ir.Layout,
input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]],
hint_override: Optional[int] = None,
):
from . import autotune_process
@@ -2791,7 +2845,7 @@ class AlgorithmSelectorCache(PersistentCache):
triton = [c for c in choices if not isinstance(c, ExternKernelCaller)]
timings = cls.benchmark_in_current_process(
extern, input_nodes, layout, input_gen_fns
extern, input_nodes, layout, input_gen_fns, hint_override=hint_override
)
timings.update(autotune_process.benchmark_in_sub_process(triton)) # type: ignore[arg-type]
return timings
@@ -2803,6 +2857,7 @@ class AlgorithmSelectorCache(PersistentCache):
input_nodes: list[ir.IRNode],
layout: ir.Layout,
input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]],
hint_override: Optional[int] = None,
):
if DEBUG:
print(f"{len(choices)} tuning requests:")
@@ -2813,6 +2868,7 @@ class AlgorithmSelectorCache(PersistentCache):
input_nodes=input_nodes,
layout=layout,
input_gen_fns=input_gen_fns,
hint_override=hint_override,
)
else:
return functools.partial(
@ -2820,6 +2876,7 @@ class AlgorithmSelectorCache(PersistentCache):
input_nodes=input_nodes,
layout=layout,
input_gen_fns=input_gen_fns,
hint_override=hint_override,
)
@staticmethod
@@ -2978,6 +3035,7 @@ class AlgorithmSelectorCache(PersistentCache):
elapse: float,
precompile_elapse: float,
prescreening_elapse: Optional[float] = None,
hint_override: Optional[int] = None,
):
V.debug.log_autotuning_results(
name, input_nodes, timings, elapse, precompile_elapse
@@ -2992,6 +3050,7 @@ class AlgorithmSelectorCache(PersistentCache):
V.graph.sizevars.size_hints(
n.get_size(),
fallback=config.unbacked_symint_fallback, # type: ignore[arg-type]
hint_override=hint_override,
),
)
)
@@ -3079,7 +3138,7 @@ class AlgorithmSelectorCache(PersistentCache):
)
@staticmethod
def benchmark_example_value(node):
def benchmark_example_value(node, hint_override: Optional[int] = None):
"""
Convert an ir.Buffer into a concrete torch.Tensor we can use for
benchmarking.
@@ -3098,10 +3157,12 @@ class AlgorithmSelectorCache(PersistentCache):
V.graph.sizevars.size_hints(
node.get_size(),
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
),
V.graph.sizevars.size_hints(
node.get_stride(),
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
),
node.get_device(),
node.get_dtype(),
@@ -3109,6 +3170,7 @@ class AlgorithmSelectorCache(PersistentCache):
V.graph.sizevars.size_hints(
V.graph.get_allocation_size(node),
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
),
)
```