multi-kernel matmuls based on varying hint sizes (#156628)

The core idea is to generate multiple matmul kernels using different hints for symbolic variables, then select the most appropriate one at runtime for each unique shape we encounter. You can find some early experimentation details in these posts:

https://fb.workplace.com/groups/8940092306109185/posts/9803850776399996/
https://fb.workplace.com/groups/8940092306109185/posts/9695805170537891/
https://fb.workplace.com/groups/257735836456307/posts/906589324904285/
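
To make the core idea concrete, here is a minimal, hypothetical sketch of the per-shape dispatch scheme. The names (`ShapeSpecializedDispatcher`, `_bench`) are illustrative only; the actual implementation lives in `MultiKernelCall` in the diff below:

```
import time

def _bench(kernel, args, kwargs, reps=40):
    # Crude stand-in for Inductor's GPU benchmarker: time `reps` launches.
    start = time.perf_counter()
    for _ in range(reps):
        kernel.run(*args, **kwargs)
    return time.perf_counter() - start

class ShapeSpecializedDispatcher:
    """Sketch: pick the fastest pre-compiled kernel for each runtime shape."""

    def __init__(self, kernels):
        self.kernels = kernels  # one compiled kernel per size hint, e.g. 64/256/4096
        self.shape_cache = {}   # runtime shapes -> index of the fastest kernel

    def run(self, *args, **kwargs):
        key = tuple(tuple(a.shape) for a in args if hasattr(a, "shape"))
        if key not in self.shape_cache:
            # First time we see this shape: benchmark every candidate and cache the winner.
            timings = [_bench(k, args, kwargs) for k in self.kernels]
            self.shape_cache[key] = timings.index(min(timings))
        return self.kernels[self.shape_cache[key]].run(*args, **kwargs)
```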

Here’s a graph illustrating the empirically observed worst-case performance, i.e. if we always selected the least optimal hint for each runtime size:

![image](https://github.com/user-attachments/assets/6d90ee06-a572-453e-9cba-03006f343301)

This graph illustrates the performance of a hint size of 64 relative to the worst case. Notice that as the runtime sizes increase, the performance gradually approaches the worst case:

![image](https://github.com/user-attachments/assets/85ad49fe-165a-474c-8d03-db2e57654213)

This graph shows the performance of a hint size of 4096 — very poor for small sizes, and also suboptimal for some mid-sized shapes:

![image](https://github.com/user-attachments/assets/adea1106-3bc8-40f3-97b0-20d940fb74f1)

Finally, here’s the graph that motivated this PR. It illustrates the performance when selecting the best of three kernels generated with three different hints — 64, 256, and 4096:

![image](https://github.com/user-attachments/assets/a7cb0ce5-8139-48b1-b5c9-7670e75cbfce)

## How to review this PR

At a high level, this extends @shunting314's multi-kernel abstraction to support varying GEMM choices driven by different hints. A few key points:

1. Unlike reduction kernels, Triton template matmuls pass their grid as arguments to the kernel. This PR updates `MultiKernelCall` to support kernels with varying arguments.
2. The `V.graph.sizevars.size_hints` API is extended to accept a `hint_override`, allowing us to substitute the example input’s size hint with a custom value when generating multiple kernels.
3. The choice generation and benchmarking logic is updated to support multiple hint values. One kernel is generated per value in `torch._inductor.config.multi_kernel_hints`, and at runtime we select the most suitable kernel for the current shape (see the usage sketch after this list).
4. To keep this PR scoped, it does not add cpp wrapper codegen support; that will be added in the next PR.
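
To give a concrete sense of how the feature is exercised, here is a usage sketch modeled on the test added in this PR; the config knobs shown (`multi_kernel_hints`, `triton.multi_kernel`, `benchmark_kernel`) reflect this change and should be treated as illustrative rather than a stable public API:

```
import torch
from torch._inductor import config as inductor_config

# Requires a GPU large enough to use Triton GEMM templates.
inductor_config.multi_kernel_hints = [64, 256, 4096]  # one GEMM kernel per hint
inductor_config.triton.multi_kernel = 1               # enable multi-kernel dispatch
inductor_config.benchmark_kernel = True

def fn(x, y):
    return x @ y

compiled_fn = torch.compile(
    fn,
    options={"max_autotune": True, "max_autotune_gemm_backends": "TRITON"},
)

x = torch.randn(4096, 4096, device="cuda")
y = torch.randn(4096, 4096, device="cuda")
torch.testing.assert_close(compiled_fn(x, y), fn(x, y))
```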

## Results

The following basic test shows the multi-kernel dispatch working, with no more significant variance based on the original hint size: https://gist.github.com/bobrenjc93/ba711d529e65fd65839b34799f6323ec

Before
```
Hint\Runtime |     64     |    256     |    4096
---------------------------------------------------
     64      |   0.0948   |   0.3124   |   4.9477
    256      |   0.2243   |   0.2256   |   3.3880
    4096     |   0.3384   |   0.3404   |   3.3010
```

After
```
Hint\Runtime |     64     |    256     |    4096
---------------------------------------------------
     64      |   0.0951   |   0.2289   |   3.3013
    256      |   0.0952   |   0.2258   |   3.4045
    4096     |   0.0957   |   0.2231   |   3.3146
```

We also see an average speedup of 5.04% across the matrix of all hint/runtime pairs in [64, 4096], in increments of 64: https://docs.google.com/spreadsheets/d/12TmYUDrAAFASGuP3POXTKPeAvQWIRzKzdrVSIb3vQkA/edit?gid=480268938#gid=480268938

![Worst Case, multi-kernel](https://github.com/user-attachments/assets/712df23b-87e2-4d9d-95c2-cc25305ba2ed)

NB: This is just the beginning; I plan to investigate further to improve on this initial result.

For posterity, the script used to generate that matrix is here: https://gist.github.com/bobrenjc93/c211fd0bd97fad8f46b91ad9dee76ad0
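
The sketch below is a rough, hypothetical approximation of that kind of sweep (the sizes, configuration, and timing method are assumptions, not the contents of the gist):

```
import torch
from torch._inductor import config as inductor_config

inductor_config.multi_kernel_hints = [64, 256, 4096]

def matmul(x, y):
    return x @ y

compiled = torch.compile(
    matmul,
    dynamic=True,  # keep the matmul sizes symbolic so the hints matter
    options={"max_autotune": True, "max_autotune_gemm_backends": "TRITON"},
)

results = {}
for n in range(64, 4096 + 1, 64):
    x = torch.randn(n, n, device="cuda")
    y = torch.randn(n, n, device="cuda")
    compiled(x, y)  # warm-up; also triggers per-shape kernel selection
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    compiled(x, y)
    end.record()
    torch.cuda.synchronize()
    results[n] = start.elapsed_time(end)  # milliseconds
```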

HUD benchmark runs:
base: https://github.com/pytorch/pytorch/actions/runs/15889871988
head: https://github.com/pytorch/pytorch/actions/runs/15889876842

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156628
Approved by: https://github.com/jansel
This commit is contained in:
bobrenjc93
2025-07-11 20:54:18 -07:00
committed by PyTorch MergeBot
parent 191693ac85
commit 5221448574
15 changed files with 635 additions and 139 deletions

View File

@ -51,7 +51,7 @@ aten = torch.ops.aten
def patches(fn):
def skip_cache(self, choices, name, key, benchmark):
def skip_cache(self, choices, name, key, benchmark, hint_override=None):
if benchmark is None:
return {}
timings = benchmark(choices)

View File

@ -1205,7 +1205,7 @@ class TestMaxAutotune(TestCase):
# Make sure all args of generate_and_load_args are passed to make_key_args (Except generate_with_caching)
# update this function each time new arg added to generate_and_load and make sure arg is added to make_key
self.assertEqual(generate_and_load_args - 1, make_key_args)
self.assertEqual(generate_and_load_args, 16)
self.assertEqual(generate_and_load_args, 17)
@fresh_cache()
@config.patch(
@ -1293,7 +1293,7 @@ class TestMaxAutotune(TestCase):
'layout':"[[10,30],[30,1],torch.float32,device(type='cuda',index=0),0]",
'num_consumer_groups':0,'num_buffers_warp_spec':0,'epilogue_fn_hash':'identity',
'kwargs':{'EVEN_K':False,'ALLOW_TF32':True,'USE_FAST_ACCUM':False,'ACC_TYPE':'tl.float32',
'BLOCK_M':16,'BLOCK_N':32,'BLOCK_K':16,'GROUP_M':8}}"""
'BLOCK_M':16,'BLOCK_N':32,'BLOCK_K':16,'GROUP_M':8},'hint_override':None}"""
expected = expected.replace("cuda", GPU_TYPE)
self.assertExpectedInline(
@ -1332,7 +1332,7 @@ class TestMaxAutotune(TestCase):
'num_stages':1,'num_warps':2,'prefix_args':0,'suffix_args':0,'call_sizes':[s77,s94],
'layout':"[[s77,s94],[s94,1],torch.float32,device(type='cuda',index=0),0]",'num_consumer_groups':0,
'num_buffers_warp_spec':0,'epilogue_fn_hash':'identity','kwargs':{'EVEN_K':False,'ALLOW_TF32':True,
'USE_FAST_ACCUM':False,'ACC_TYPE':'tl.float32','BLOCK_M':16,'BLOCK_N':32,'BLOCK_K':16,'GROUP_M':8}}"""
'USE_FAST_ACCUM':False,'ACC_TYPE':'tl.float32','BLOCK_M':16,'BLOCK_N':32,'BLOCK_K':16,'GROUP_M':8},'hint_override':None}"""
expected = expected.replace("cuda", GPU_TYPE)
self.assertExpectedInline(
remove_white_space(cache_key),
@ -1585,6 +1585,7 @@ class TestMaxAutotunePrecompile(TestCase):
op: str,
inputs: str,
benchmark: Callable[[Any], dict[ChoiceCaller, float]],
hint_override: Optional[int] = None,
) -> Optional[dict[ChoiceCaller, float]]:
if benchmark is not None:
return benchmark(choices)

View File

@ -18,7 +18,12 @@ from torch.testing._internal.common_utils import (
parametrize,
skipIfXpu,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
from torch.testing._internal.inductor_utils import (
GPU_TYPE,
HAS_GPU,
IS_BIG_GPU,
requires_triton,
)
class TransformerSnippet(nn.Module):
@ -71,6 +76,7 @@ def make_cpp_wrapper_test(orig_test, **extra_args):
{
"triton.multi_kernel": int(os.environ.get("TORCHINDUCTOR_MULTI_KERNEL", "1")),
"benchmark_kernel": True,
"multi_kernel_hints": [64, 256, 4096],
}
)
@instantiate_parametrized_tests
@ -91,6 +97,56 @@ class MultiKernelTest(TestCase):
else:
self.assertFalse(_contains_multi_kernel_code(wrapper_code))
@requires_triton()
@unittest.skipIf(not IS_BIG_GPU, "templates require big gpu")
def test_triton_gemm(self):
def fn(x, y):
return x @ y
compiled_fn = torch.compile(
fn,
options={
"max_autotune": True,
"max_autotune_gemm_backends": "TRITON",
},
)
x = torch.randn(4096, 4096, device=GPU_TYPE)
y = torch.randn(4096, 4096, device=GPU_TYPE)
act, wrapper_code = run_and_get_code(compiled_fn, x, y)
ref = fn(x, y)
# wrapper_code will contain 2 entries if cpp_wrapper=True.
# One for the first pass and one for the second pass.
# We mainly care about the wrapper for the final pass here.
wrapper_code = wrapper_code[-1]
self.assertEqual(ref, act)
self.assertTrue(_contains_multi_kernel_code(wrapper_code))
@requires_triton()
@unittest.skipIf(not IS_BIG_GPU, "templates require big gpu")
def test_triton_relu_fused_gemm(self):
def fn(x, y):
return (x @ y).relu()
compiled_fn = torch.compile(
fn,
options={
"max_autotune": True,
"max_autotune_gemm_backends": "TRITON",
},
)
x = torch.randn(4096, 4096, device=GPU_TYPE)
y = torch.randn(4096, 4096, device=GPU_TYPE)
act, wrapper_code = run_and_get_code(compiled_fn, x, y)
ref = fn(x, y)
# wrapper_code will contain 2 entries if cpp_wrapper=True.
# One for the first pass and one for the second pass.
# We mainly care about the wrapper for the final pass here.
wrapper_code = wrapper_code[-1]
self.assertEqual(ref, act)
self.assertTrue(_contains_multi_kernel_code(wrapper_code))
@parametrize("force_kernel", (0, 1))
@unittest.mock.patch.dict(
os.environ, {"TORCHINDUCTOR_DISABLE_MULTI_KERNEL_CACHE": "1"}

View File

@ -20,7 +20,7 @@ aten = torch.ops.aten
def patches(fn):
def skip_cache(self, choices, name, key, benchmark):
def skip_cache(self, choices, name, key, benchmark, hint_override=None):
if benchmark is None:
return {}
return benchmark(choices)

View File

@ -265,6 +265,7 @@ class PersistentCache(CacheBase):
op: str,
inputs: str,
benchmark: Optional[Callable[[Any], dict[ChoiceCaller, float]]],
hint_override: Optional[int] = None,
) -> dict[ChoiceCaller, float]:
"""
Check to see if we have benchmarked the given choice callers. For each
@ -277,6 +278,7 @@ class PersistentCache(CacheBase):
b. `max_autotune_gemm=False`: don't benchmark the choice, return nothing.
"""
precision = torch.get_float32_matmul_precision()
cache_key = f"{inputs}_{hint_override}" if hint_override is not None else inputs
timings = {}
@ -285,9 +287,11 @@ class PersistentCache(CacheBase):
hit = True
for choice in choices:
choice_hash = choice.hash_key()
if choice_hash in cache.get(op, {}).get(inputs, {}).get(precision, {}):
if choice_hash in cache.get(op, {}).get(cache_key, {}).get(
precision, {}
):
# cache hit
timings[choice] = cache[op][inputs][precision][choice_hash]
timings[choice] = cache[op][cache_key][precision][choice_hash]
else:
# cache miss
hit = False
@ -300,9 +304,9 @@ class PersistentCache(CacheBase):
timings = benchmark(choices)
assert all(choice in timings for choice in choices)
local_cache.setdefault(op, {})
local_cache[op].setdefault(inputs, {}).setdefault(precision, {})
local_cache[op].setdefault(cache_key, {}).setdefault(precision, {})
for choice, timing in timings.items():
local_cache[op][inputs][precision][choice.hash_key()] = timing
local_cache[op][cache_key][precision][choice.hash_key()] = timing
self.update_local_cache(local_cache)

View File

@ -124,10 +124,13 @@ class CUDACombinedScheduling(BaseScheduling):
return self._triton_scheduling.benchmark_codegened_module(module)
def generate_kernel_code_from_nodes(
self, nodes: Sequence[Any], benchmark_kernel: bool = False
self,
nodes: Sequence[Any],
benchmark_kernel: bool = False,
hint_override: Optional[int] = None,
) -> str:
return self._triton_scheduling.generate_kernel_code_from_nodes(
nodes, benchmark_kernel
nodes, benchmark_kernel, hint_override=hint_override
)
def benchmark_combo_kernel(

View File

@ -4,6 +4,7 @@ import logging
import os
import pathlib
from torch._inductor.ir import MultiTemplateBuffer
from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
from torch.utils._ordered_set import OrderedSet
@ -45,6 +46,9 @@ class MultiKernelState:
We should name the multi-kernel differently in these 2 cases.
"""
# Prevent circular import
from ..select_algorithm import TritonTemplateKernel
kernel_names = tuple(k.kernel_name for k in kernels)
if kernel_names in self.subkernel_to_kernel_name:
return self.subkernel_to_kernel_name[kernel_names]
@ -58,15 +62,44 @@ class MultiKernelState:
# the second pass of cpp-wrapper.
return multi_kernel_name
arg_index: dict[int, list[slice]] = {}
_, call_args, _, arg_types = kernels[0].args.python_argdefs()
if isinstance(kernels[0], TritonTemplateKernel) and isinstance(
kernels[0].output_node, MultiTemplateBuffer
):
for i, kernel in enumerate(kernels):
additional_call_args, additional_arg_types = (
kernel.additional_call_args_and_types()
)
if i not in arg_index:
arg_index[i] = []
arg_index[i].append(slice(0, len(call_args)))
arg_index[i].append(
slice(
len(call_args) + i * len(additional_call_args),
len(call_args) + (i + 1) * len(additional_call_args),
)
)
else:
kernels[0].add_numel_to_call_args(multi_kernel_name, call_args, arg_types)
for i in range(len(kernels)):
arg_index[i] = [slice(0, len(call_args))]
shape_specialize = isinstance(kernels[0], TritonTemplateKernel)
buf = self.kernel_defs
buf.writeline("")
buf.writeline("arg_index = {")
for key, slice_list in arg_index.items():
slice_reprs = ", ".join(repr(s) for s in slice_list)
buf.writeline(f" {key}: [{slice_reprs}],")
buf.writeline("}")
buf.writeline(
f"{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, ["
)
with buf.indent():
for name in kernel_names:
buf.writeline(f"{name},")
buf.writeline("])")
buf.writeline(f"], arg_index=arg_index, shape_specialize={shape_specialize})")
if config.triton.autotune_at_compile_time:
V.graph.wrapper_code.src_to_kernel["\n".join(kernel_names)] = (
@ -135,6 +168,9 @@ class MultiKernel:
Collect the union of arguments from all subkernels as the arguments
for the multi-kernel.
"""
# Prevent circular import
from ..select_algorithm import TritonTemplateKernel
assert kernel_name == self.kernel_name
V.graph.wrapper_code.write_triton_header_once()
_, call_args, _, arg_types = self.kernels[0].args.python_argdefs()
@ -148,17 +184,42 @@ class MultiKernel:
# the fast kernel directly
kernel_name = MultiKernelCall.lookup_choice(self.kernel_name)
# numels for all subkernels should be the same. Use kernels[0] here
self.kernels[0].add_numel_to_call_args(kernel_name, call_args, arg_types)
if isinstance(self.kernels[0], TritonTemplateKernel) and isinstance(
self.kernels[0].output_node, MultiTemplateBuffer
):
# For matmuls the grid arguments are passed in as additional arguments
# to the kernel run method. These grids change based on the various
# parameters of the matmul, so we need to pass each kernel's grid into
# the multi-kernel call.
multi_call_args = call_args
multi_call_arg_types = arg_types
for i, kernel in enumerate(self.kernels):
additional_call_args, additional_arg_types = (
kernel.additional_call_args_and_types()
)
multi_call_args.extend(list(additional_call_args))
multi_call_arg_types.extend(list(additional_arg_types))
else:
# numels for all subkernels should be the same. Use kernels[0] here
self.kernels[0].add_numel_to_call_args(kernel_name, call_args, arg_types)
multi_call_args = call_args
multi_call_arg_types = arg_types
for ws in self.kernels[0].args.workspace_args:
V.graph.wrapper_code.generate_workspace_allocation(ws)
V.graph.wrapper_code.generate_kernel_call(
kernel_name,
call_args,
arg_types=arg_types,
)
if V.graph.cpp_wrapper:
# We have already selected the best kernel at compile time
# so we only have one set of call args. NB: this currently
# doesn't work with MultiTemplateBuffer kernels. @bobrenjc93
# will add it in a subsequent PR.
V.graph.wrapper_code.generate_kernel_call(
kernel_name, call_args, arg_types=arg_types
)
else:
V.graph.wrapper_code.generate_kernel_call(
kernel_name, multi_call_args, arg_types=multi_call_arg_types
)
for ws in reversed(self.kernels[0].args.workspace_args):
V.graph.wrapper_code.generate_workspace_deallocation(ws)
@ -205,7 +266,7 @@ class MultiKernelCall:
This class is called at run time to actually run the kernel
"""
def __init__(self, multi_kernel_name, kernels):
def __init__(self, multi_kernel_name, kernels, arg_index, shape_specialize=False):
assert len(kernels) >= 2
self._kernels = kernels
self.multi_kernel_name = multi_kernel_name
@ -215,6 +276,7 @@ class MultiKernelCall:
) == "1" or is_metric_table_enabled("persistent_red_perf")
self.picked_kernel = None
self.arg_index = arg_index
if config.triton.multi_kernel > 1:
# manually force a subkernel to ease perf testing
picked_by_config = config.triton.multi_kernel - 2
@ -225,6 +287,13 @@ class MultiKernelCall:
self._recorded = False
# This means for each unique shape we will do a separate assessment
# for which kernel is the best. This is particularly useful for matmul
# kernels where the best kernel can vary based on very small differences
# in shape.
self._shape_specialize = shape_specialize
self._shape_cache = {}
def cache_file_path(self):
key = code_hash(
",".join(
@ -282,18 +351,34 @@ class MultiKernelCall:
be picked.
"""
def wrap_fn(kernel):
def wrap_fn(kernel, index):
def inner():
args_clone, kwargs_clone = kernel.clone_args(*args, **kwargs)
filtered_args = self._get_filtered_args(args, index)
args_clone, kwargs_clone = kernel.clone_args(*filtered_args, **kwargs)
return kernel.run(*args_clone, **kwargs_clone)
return inner
return [
benchmarker.benchmark_gpu(wrap_fn(kernel), rep=40)
for kernel in self.kernels
benchmarker.benchmark_gpu(wrap_fn(kernel, index), rep=40)
for index, kernel in enumerate(self.kernels)
]
def _get_filtered_args(self, args, index):
"""
We pass the arguments for all kernels into the MultiKernelCall,
so when invoking a particular kernel we need to filter down to
only the arguments for that specific kernel.
"""
# This is sometimes invoked at runtime where V.graph is
# a NullHandler
if hasattr(V.graph, "cpp_wrapper") and V.graph.cpp_wrapper:
# for cpp-wrapper, we should not filter the args since
# we already have chosen a single kernel and arg set.
return args
return [item for s in self.arg_index[index] for item in args[s]]
# record_choice and lookup_choice are helper functions for cpp-wrapper
# codegen. The first pass use record_choice to keep the choice and
# the second pass do lookup by calling lookup_choice.
@ -330,6 +415,20 @@ class MultiKernelCall:
return V.graph.multi_kernel_to_choice[multi_kernel_name]
def run(self, *args, **kwargs):
if self._shape_specialize:
cache_key = self._get_shape_cache_key(*args, **kwargs)
cached_choice = self._get_cached_shape_choice(cache_key)
if cached_choice is not None:
self.picked_kernel = cached_choice
log.debug(
"using cached shape-specialized choice %dth sub-kernel in %s. Cache key: %s",
self.picked_kernel,
[k.inductor_meta.get("kernel_name") for k in self.kernels],
cache_key,
)
else:
self._select_kernel_by_shape(*args, **kwargs)
if self.picked_kernel is None:
timings = self.benchmark_sub_kernels(*args, **kwargs)
self.picked_kernel = timings.index(min(timings))
@ -345,6 +444,7 @@ class MultiKernelCall:
get_metric_table("persistent_red_perf").add_row(
functools.partial(self._metrics_table_row, timings)
)
if not self.disable_cache:
self.store_cache()
@ -355,8 +455,42 @@ class MultiKernelCall:
)
assert picked_kernel_name is not None
self.record_choice(self.multi_kernel_name, picked_kernel_name)
self.run = self.kernels[self.picked_kernel].run # type: ignore[method-assign]
self.run(*args, **kwargs)
run = self.kernels[self.picked_kernel].run # type: ignore[method-assign]
filtered_args = self._get_filtered_args(args, self.picked_kernel)
run(*filtered_args, **kwargs)
def _get_shape_cache_key(self, *args, **kwargs):
"""
Generate a cache key based on tensor shapes for shape-specialized dispatch.
"""
shapes = []
for arg in args:
if hasattr(arg, "shape"):
shapes.append(tuple(arg.shape))
return tuple(shapes)
def _get_cached_shape_choice(self, cache_key):
"""
Get cached kernel choice for a specific shape.
"""
return self._shape_cache.get(cache_key)
def _cache_shape_choice(self, cache_key, kernel_idx):
"""
Cache kernel choice for a specific shape
"""
self._shape_cache[cache_key] = kernel_idx
def _select_kernel_by_shape(self, *args, **kwargs):
"""
Benchmark kernels for a particular shape and return the
best kernel for this shape.
"""
shape_key = self._get_shape_cache_key(*args, **kwargs)
timings = self.benchmark_sub_kernels(*args, **kwargs)
self.picked_kernel = timings.index(min(timings))
self._cache_shape_choice(shape_key, self.picked_kernel)
def _metrics_table_row(self, timings):
def get_kernel_path(k):

View File

@ -18,6 +18,7 @@ import sympy
import torch
import torch._logging
from torch._inductor.ir import MultiTemplateBuffer
from torch._inductor.tiling_utils import analyze_memory_coalescing
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
from torch.fx.immutable_collections import immutable_dict
@ -1548,18 +1549,19 @@ class SIMDScheduling(BaseScheduling):
index_vars = kernel.split_and_set_ranges(node.get_ranges())
node.codegen(index_vars)
def codegen_template(
self, template_node, epilogue_nodes, prologue_nodes, *, only_gen_src_code=False
) -> Optional[str]:
def _codegen_single_template(
self,
kernel,
render,
template_node,
epilogue_nodes,
prologue_nodes,
*,
only_gen_src_code=False,
):
"""
Codegen a triton template
If `only_gen_src_code` the src code will be returned instead of codegen'd into the wrapper
Helper method to codegen a single template kernel variant
"""
_, (_numel, rnumel) = template_node.group
assert rnumel == 1
kernel, render = template_node.node.make_kernel_render(template_node.node)
buf_name_to_prologue_group = {}
template_reads = template_node.used_buffer_names()
prologue_group = []
@ -1655,18 +1657,113 @@ class SIMDScheduling(BaseScheduling):
if only_gen_src_code:
return src_code
kernel_name = self.define_kernel(src_code, node_schedule, kernel)
kernel.kernel_name = self.define_kernel(src_code, node_schedule, kernel)
if config.trace.enabled:
set_kernel_post_grad_provenance_tracing(node_schedule, kernel_name)
set_kernel_post_grad_provenance_tracing(
node_schedule, kernel.kernel_name
)
self.codegen_comment(node_schedule)
kernel.call_kernel(kernel_name, template_node.node)
return kernel
V.graph.removed_buffers |= kernel.removed_buffers
V.graph.inplaced_to_remove |= kernel.inplaced_to_remove
self.free_buffers_in_scheduler()
return None
def codegen_template(
self,
template_node,
epilogue_nodes,
prologue_nodes,
*,
only_gen_src_code=False,
hint_override: Optional[int] = None,
) -> Optional[str]:
"""
Codegen a triton template with multi-kernel dispatch support
If `only_gen_src_code=True` the src code will be returned instead of being
codegenned into the wrapper
"""
_, (_numel, rnumel) = template_node.group
assert rnumel == 1
if (
isinstance(template_node.node, MultiTemplateBuffer)
and template_node.node._make_kernel_renders
):
kernels = []
src_codes = []
for make_kernel_render in template_node.node._make_kernel_renders.values():
kernel, render = make_kernel_render(
template_node.node, hint_override=hint_override
)
if only_gen_src_code:
src_code = self._codegen_single_template(
kernel,
render,
template_node,
epilogue_nodes,
prologue_nodes,
only_gen_src_code=True,
)
assert isinstance(src_code, str)
src_codes.append(src_code)
else:
kernel = self._codegen_single_template(
kernel,
render,
template_node,
epilogue_nodes,
prologue_nodes,
only_gen_src_code=False,
)
kernels.append(kernel)
if only_gen_src_code:
return "\n\n".join(src_codes)
MultiKernel.merge_workspaces_inplace(kernels)
multi_kernel = MultiKernel(kernels)
node_schedule = [*prologue_nodes, template_node, *epilogue_nodes]
self.codegen_comment(node_schedule)
multi_kernel.call_kernel(multi_kernel.kernel_name)
V.graph.removed_buffers |= multi_kernel.removed_buffers
V.graph.inplaced_to_remove |= multi_kernel.inplaced_to_remove
self.free_buffers_in_scheduler()
return None
else:
kernel, render = template_node.node.make_kernel_render(
template_node.node, hint_override=hint_override
)
if only_gen_src_code:
return self._codegen_single_template(
kernel,
render,
template_node,
epilogue_nodes,
prologue_nodes,
only_gen_src_code=True,
)
else:
kernel = self._codegen_single_template(
kernel,
render,
template_node,
epilogue_nodes,
prologue_nodes,
only_gen_src_code=False,
)
node_schedule = [*prologue_nodes, template_node, *epilogue_nodes]
self.codegen_comment(node_schedule)
kernel.call_kernel(kernel.kernel_name, template_node.node)
V.graph.removed_buffers |= kernel.removed_buffers
V.graph.inplaced_to_remove |= kernel.inplaced_to_remove
self.free_buffers_in_scheduler()
return None
def codegen_sync(self):
V.graph.wrapper_code.writeline(V.graph.device_ops.synchronize())
@ -2439,7 +2536,9 @@ class SIMDScheduling(BaseScheduling):
def ready_to_flush(self) -> bool:
return False
def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False):
def generate_kernel_code_from_nodes(
self, nodes, benchmark_kernel=False, hint_override: Optional[int] = None
):
if not any(n.is_template() for n in nodes):
_, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group
node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
@ -2464,6 +2563,7 @@ class SIMDScheduling(BaseScheduling):
epilogue,
prologue,
only_gen_src_code=True,
hint_override=hint_override,
)
src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_")

View File

@ -1855,6 +1855,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
min_elem_per_thread=0,
optimize_mask=True,
fixed_config: Optional[FixedTritonConfig] = None,
hint_override: Optional[int] = None,
**kwargs,
) -> None:
self.optimize_mask: bool = optimize_mask
@ -1872,6 +1873,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
collections.defaultdict(dict)
)
self.tma_min_block_sizes = dict[str, int]()
self.hint_override = hint_override
self._load_counts: collections.Counter[str] = collections.Counter()
# A set of autotuning hints to pass as part of triton_meta
@ -3696,13 +3698,13 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
buf = V.graph.try_get_buffer(arg_name)
if buf:
result.writeline(
f"{var_name} = rand_strided({V.graph.sizevars.size_hints(buf.get_size())}, {V.graph.sizevars.size_hints(buf.get_stride())}, device='{buf.get_device()}', dtype={buf.get_dtype()})" # noqa: B950 line too long
f"{var_name} = rand_strided({V.graph.sizevars.size_hints(buf.get_size(), hint_override=self.hint_override)}, {V.graph.sizevars.size_hints(buf.get_stride(), hint_override=self.hint_override)}, device='{buf.get_device()}', dtype={buf.get_dtype()})" # noqa: B950 line too long
)
elif arg_name in V.graph.constants:
# note that random seed is put in V.graph.constants
const_tensor = V.graph.constants[arg_name]
result.writeline(
f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # type: ignore[arg-type] # noqa: B950 line too long
f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size(), hint_override=self.hint_override)}, {V.graph.sizevars.size_hints(const_tensor.stride(), hint_override=self.hint_override)}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # type: ignore[arg-type] # noqa: B950 line too long
)
elif isinstance(arg_sig, SizeArg):
symval_hint = V.graph.sizevars.size_hint(arg_sig.expr)

View File

@ -445,6 +445,12 @@ force_same_precision: bool = Config(
default=False,
)
# Size hints for multi-kernel dispatch.
# A reasonable default value of this config would be [64, 256, 4096]
# TODO: @bobrenjc93 to roll this out to a few internal models to ensure this works
# as expected before turning it on for everyone.
multi_kernel_hints: list[int] = []
# Specify candidate backends for gemm autotune.
# Possible choices are combinations of: ATen, Triton, CUTLASS, CK, CPP.
# ATen: default Pytorch ATen kernels.

View File

@ -4877,7 +4877,7 @@ class MultiTemplateBuffer(TritonTemplateBuffer):
self,
layout: Layout,
inputs: Sequence[IRNode],
choice_timings_fn: Callable[[], dict[ChoiceCaller, float]],
choice_timings_fn: Callable[[Optional[int]], dict[ChoiceCaller, float]],
unfiltered_choices: list[ChoiceCaller],
allowed_prologue_inps: OrderedSet[str],
) -> None:
@ -4888,7 +4888,7 @@ class MultiTemplateBuffer(TritonTemplateBuffer):
allowed_prologue_inps=allowed_prologue_inps,
)
self._choice_timings_fn = choice_timings_fn
self._choice_timings: Optional[dict[ChoiceCaller, float]] = None
self._choice_timings: dict[Optional[int], dict[ChoiceCaller, float]] = {}
self.original_inputs = inputs
self._output_plannable = all(
isinstance(choice, TritonTemplateCallerBase)
@ -4898,6 +4898,7 @@ class MultiTemplateBuffer(TritonTemplateBuffer):
)
for choice in unfiltered_choices
)
self._make_kernel_renders: dict[Optional[int], Any] = {}
@property
def output_plannable(self) -> bool:
@ -4906,11 +4907,12 @@ class MultiTemplateBuffer(TritonTemplateBuffer):
"""
return self._output_plannable
@property
def choice_timings(self) -> dict[ChoiceCaller, float]:
if self._choice_timings is None:
self._choice_timings = self._choice_timings_fn()
return self._choice_timings
def choice_timings(
self, hint_override: Optional[int] = None
) -> dict[ChoiceCaller, float]:
if hint_override not in self._choice_timings:
self._choice_timings[hint_override] = self._choice_timings_fn(hint_override)
return self._choice_timings[hint_override]
@contextlib.contextmanager
def swap_as_triton_caller(self, caller: TritonTemplateCallerBase) -> Iterator[None]:
@ -4934,8 +4936,22 @@ class MultiTemplateBuffer(TritonTemplateBuffer):
assert self.get_stride() == caller.layout.stride
self.make_kernel_render = caller.get_make_kernel_render()
def get_min_choice(self) -> tuple[ChoiceCaller, float]:
return min(self.choice_timings.items(), key=lambda x: x[1])
def get_min_choice(
self, hint_override: Optional[int] = None
) -> tuple[ChoiceCaller, float]:
timings = self.choice_timings(hint_override=hint_override)
min_choice = min(timings, key=timings.get) # type: ignore[arg-type]
return (min_choice, timings[min_choice])
def finalize_as_triton_callers(
self, callers: dict[Optional[int], TritonTemplateCallerBase]
) -> None:
"""Finalize with multiple callers for different hint overrides"""
for hint_override, caller in callers.items():
self._make_kernel_renders[hint_override] = caller.get_make_kernel_render()
# Set the default to be the one without hint override
self.make_kernel_render = self._make_kernel_renders[None]
class CUDATemplateBuffer(TemplateBuffer):

View File

@ -28,6 +28,7 @@ import torch
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
from torch._dynamo.utils import counters, dynamo_timed
from torch._inductor.codecache import LambdaFuture, PyCodeCache
from torch._inductor.ir import TritonTemplateCallerBase
from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
from torch.fx.experimental.symbolic_shapes import free_symbols
from torch.utils._ordered_set import OrderedSet
@ -2716,7 +2717,10 @@ class Scheduler:
return backend.benchmark_fused_nodes(nodes)
def generate_kernel_code_from_nodes(
self, nodes: Sequence[BaseSchedulerNode], benchmark_kernel: bool
self,
nodes: Sequence[BaseSchedulerNode],
benchmark_kernel: bool,
hint_override: Optional[int] = None,
) -> str:
"""
Benchmark fused list of nodes and return the execution time
@ -2727,7 +2731,9 @@ class Scheduler:
self.current_device = device
backend = self.get_backend(device)
with dynamo_timed("benchmark_fused_nodes"):
return backend.generate_kernel_code_from_nodes(nodes, benchmark_kernel)
return backend.generate_kernel_code_from_nodes(
nodes, benchmark_kernel, hint_override=hint_override
)
def benchmark_codegened_module(
self, module: ModuleType, device: torch.device
@ -2789,7 +2795,7 @@ class Scheduler:
min_node_unfused = next(
(
timing
for timing in multi_node.choice_timings
for timing in multi_node.choice_timings()
if isinstance(
timing,
torch._inductor.select_algorithm.ExternKernelCaller,
@ -2801,7 +2807,23 @@ class Scheduler:
min_node_unfused,
torch._inductor.ir.TritonTemplateCallerBase,
):
node.node.finalize_as_triton_caller(min_node_unfused)
if config.multi_kernel_hints:
callers: dict[Optional[int], TritonTemplateCallerBase] = {}
callers[None] = min_node_unfused
for hint in config.multi_kernel_hints:
timings = multi_node.choice_timings(hint_override=hint)
triton_timings = {
k: v
for k, v in timings.items()
if isinstance(k, TritonTemplateCallerBase)
}
choice = min(triton_timings.items(), key=lambda x: x[1])[0]
callers[hint] = choice
node.node.finalize_as_triton_callers(callers)
else:
node.node.finalize_as_triton_caller(min_node_unfused)
continue
out_tensorbox = min_node_unfused.output_node()
@ -2924,10 +2946,10 @@ class Scheduler:
async_compile = torch._inductor.async_compile.AsyncCompile()
def compile_kernel(
nodes: Sequence[BaseSchedulerNode],
nodes: Sequence[BaseSchedulerNode], hint_override: Optional[int] = None
) -> tuple[Optional[LambdaFuture], ModuleType]:
src_code = self.generate_kernel_code_from_nodes(
nodes, benchmark_kernel=True
nodes, benchmark_kernel=True, hint_override=hint_override
)
mod = PyCodeCache.load(src_code)
if not async_compile.use_process_pool():
@ -2949,8 +2971,58 @@ class Scheduler:
)
assert isinstance(multi_node, ir.MultiTemplateBuffer)
hint_override_best_fusion_choice: dict[
Optional[int], TritonTemplateCallerBase
] = {}
future_choices: list[tuple[Any, Optional[LambdaFuture], ModuleType]] = []
for hint_override in config.multi_kernel_hints:
choice_timings = multi_node.choice_timings(hint_override)
for choice, unfused_time in sorted(
choice_timings.items(), key=lambda x: x[1]
):
if not isinstance(
choice, torch._inductor.select_algorithm.TritonTemplateCaller
):
continue
with multi_node.swap_as_triton_caller(choice):
future_choices.append(
(
choice,
*compile_kernel(
node_list_fused, hint_override=choice.hint_override
),
)
)
min_ms_fused = float("inf")
ms_fused_choice: Optional[TritonTemplateCallerBase] = None
new_timings = {}
for choice, future, mod_fused in future_choices:
try:
if future is not None:
future.result()
except Exception as e:
if fusion_log.isEnabledFor(logging.DEBUG):
fusion_log.debug(
"Exception in compiling %s: %s",
"prologue" if not epilogue_fusion else "epilogue",
str(e),
)
continue
with multi_node.swap_as_triton_caller(choice):
ms_fused, path = self.benchmark_codegened_module(
mod_fused, device
)
new_timings[choice] = ms_fused
if ms_fused < min_ms_fused:
min_ms_fused = ms_fused
ms_fused_choice = choice
multi_node._choice_timings[hint_override] = new_timings
assert isinstance(ms_fused_choice, TritonTemplateCallerBase)
hint_override_best_fusion_choice[hint_override] = ms_fused_choice
# Eagerly compile and benchmark non-template nodes
choice_timings = multi_node.choice_timings
choice_timings = multi_node.choice_timings()
_, ms1 = multi_node.get_min_choice()
ms2, path2 = (
self.benchmark_fused_nodes(node_list_2)
@ -3025,8 +3097,15 @@ class Scheduler:
log_fusion(min_ms_fused, ms1, ms2)
if min_ms_fused < (ms1 + ms2) and ms_fused_choice is not None:
multi_node.finalize_as_triton_caller(ms_fused_choice)
multi_node._choice_timings = new_timings
if config.multi_kernel_hints:
hint_override_best_fusion_choice[None] = ms_fused_choice
multi_node.finalize_as_triton_callers(
hint_override_best_fusion_choice
)
else:
multi_node.finalize_as_triton_caller(ms_fused_choice)
multi_node._choice_timings[None] = new_timings
return True
else:
return False
@ -5010,7 +5089,10 @@ class BaseScheduling:
raise NotImplementedError
def generate_kernel_code_from_nodes(
self, nodes: Sequence[BaseSchedulerNode], benchmark_kernel: bool
self,
nodes: Sequence[BaseSchedulerNode],
benchmark_kernel: bool,
hint_override: Optional[int] = None,
) -> str:
"""
Generate a kernel given a list of pre-fused nodes.

View File

@ -294,6 +294,15 @@ RecordedEventsType = list[tuple[str, list[Any], dict[str, Any]]]
class TritonTemplateKernel(TritonKernel):
"""
A specialized kernel class for Triton templates that handles code generation
for templated Triton kernels.
This class extends TritonKernel to provide additional functionality for
template-based kernel generation, including support for subgraphs, workspace
arguments, and prologue/epilogue fusion.
"""
def __init__(
self,
kernel_name,
@ -314,6 +323,7 @@ class TritonTemplateKernel(TritonKernel):
subgraphs: Optional[list[ir.ComputedBuffer]] = None,
workspace_arg: Optional[WorkspaceArg] = None,
prologue_loads_all_inputs=False,
hint_override: Optional[int] = None,
) -> None:
numel = sympy_product(output_node.get_size())
super().__init__(
@ -322,6 +332,7 @@ class TritonTemplateKernel(TritonKernel):
"r0_": sympy.S.One,
},
features=SIMDKernelFeatures([], numel),
hint_override=hint_override,
)
self.input_nodes = input_nodes
self.output_node = output_node
@ -1096,16 +1107,26 @@ class TritonTemplateKernel(TritonKernel):
def codegen_range_tree(self):
pass # ignore default codegen
def additional_call_args_and_types(self):
if isinstance(self.grid_fn, SymbolicGridFn):
grid_args = self.grid_fn.sympy_call(*self.call_sizes, self.meta)
assert len(grid_args) in (0, 3), "grid_fn should return 3 values"
return (grid_args, map(type, grid_args))
elif all(isinstance(x, (int, sympy.Integer)) for x in self.call_sizes):
grid_args = self.grid_fn(*map(int, self.call_sizes), self.meta)
assert len(grid_args) in (0, 3), "grid_fn should return 3 values"
return (grid_args, map(type, grid_args))
return ((), ())
def call_kernel(self, name: str, node: Optional[ir.IRNode] = None):
wrapper = V.graph.wrapper_code
_, call_args, _, arg_types = self.args.python_argdefs()
grid_args = ()
if isinstance(self.grid_fn, SymbolicGridFn):
grid_args = self.grid_fn.sympy_call(*self.call_sizes, self.meta)
elif all(isinstance(x, (int, sympy.Integer)) for x in self.call_sizes):
grid_args = self.grid_fn(*map(int, self.call_sizes), self.meta)
else:
additional_call_args, additional_arg_types = (
self.additional_call_args_and_types()
)
if not additional_call_args:
assert not V.graph.cpp_wrapper, "cpp_wrapper requires SymbolicGridFn"
wrapper.add_import_once(f"import {self.grid_fn.__module__}")
meta = wrapper.add_meta_once(self.meta)
@ -1114,9 +1135,9 @@ class TritonTemplateKernel(TritonKernel):
f"*{fn_name}({', '.join(map(pexpr, self.call_sizes))}, {meta})"
)
arg_types.append(None)
assert len(grid_args) in (0, 3), "grid_fn should return 3 values"
call_args.extend(grid_args)
arg_types.extend(map(type, grid_args))
call_args.extend(additional_call_args)
arg_types.extend(additional_arg_types)
if self.workspace_arg is not None:
wrapper.generate_workspace_allocation(self.workspace_arg)
@ -1202,6 +1223,7 @@ class GeneratedCodeCache:
num_consumer_groups: int,
num_buffers_warp_spec: int,
kwargs: dict[str, Any],
hint_override: Optional[int] = None,
) -> Optional[str]:
def layout_key(layout: ir.Layout) -> str:
assert not isinstance(layout, ir.FlexibleLayout)
@ -1252,6 +1274,7 @@ class GeneratedCodeCache:
"num_buffers_warp_spec": num_buffers_warp_spec,
"epilogue_fn_hash": epilogue_fn_hash,
"kwargs": kwargs,
"hint_override": hint_override,
}
)
@ -1356,6 +1379,7 @@ class TritonTemplate(KernelTemplate):
layout: ir.Layout,
kwargs: dict[str, Any],
generate_with_caching,
hint_override: Optional[int] = None,
) -> Optional[GenerateAndLoadResult]:
"""Generate the python code and load it into the current process"""
caching_enabled = (
@ -1428,6 +1452,7 @@ class TritonTemplate(KernelTemplate):
output_node=fake_out,
workspace_arg=workspace_arg,
use_jit=False,
hint_override=hint_override,
**kernel_options,
)
@ -1547,6 +1572,7 @@ class TritonTemplate(KernelTemplate):
call_sizes: Optional[Sequence[sympy.core.symbol.Symbol]] = None,
workspace_arg: Optional[WorkspaceArg] = None,
generate_with_caching=False,
hint_override: Optional[int] = None,
**kwargs,
):
"""This function generates a TritonTemplateCaller
@ -1591,6 +1617,7 @@ class TritonTemplate(KernelTemplate):
layout,
kwargs,
generate_with_caching and self._cache_codegen_enabled_for_template,
hint_override=hint_override,
)
# May happen as result of dev by 0.
@ -1612,6 +1639,7 @@ class TritonTemplate(KernelTemplate):
extra_args = V.graph.sizevars.size_hints(
map(sympy.expand, result.kernel_args_sizevars_keys),
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
)
kernel_hash_name = f"triton_{self.name}_{next(self.index_counter)}"
@ -1635,13 +1663,14 @@ class TritonTemplate(KernelTemplate):
options = result.kernel_options
def make_kernel_render(out_node):
def make_kernel_render(out_node, hint_override: Optional[int] = None):
assert result is not None
kernel = self.kernel_type(
kernel_name=str(Placeholder.KERNEL_NAME),
output_node=out_node,
workspace_arg=workspace_arg,
use_jit=False,
hint_override=hint_override,
**options,
)
render = functools.partial(
@ -1657,6 +1686,7 @@ class TritonTemplate(KernelTemplate):
*V.graph.sizevars.size_hints(
call_sizes,
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
),
kwargs,
)
@ -1708,6 +1738,7 @@ class TritonTemplate(KernelTemplate):
mutated_inputs=mutated_inputs,
workspace_arg=workspace_arg,
allowed_prologue_inps=result.prologue_supported_inputs,
hint_override=hint_override,
)
@ -1783,6 +1814,7 @@ class TritonTemplateCaller(ir.TritonTemplateCallerBase):
mutated_inputs=None,
workspace_arg: Optional[WorkspaceArg] = None,
allowed_prologue_inps: Optional[OrderedSet[str]] = None,
hint_override: Optional[int] = None,
) -> None:
super().__init__(name, input_nodes, layout, description)
self.make_kernel_render = make_kernel_render
@ -1802,6 +1834,7 @@ class TritonTemplateCaller(ir.TritonTemplateCallerBase):
self.allowed_prologue_inps = (
allowed_prologue_inps if allowed_prologue_inps is not None else OrderedSet()
)
self.hint_override = hint_override
def benchmark(self, *args, out):
assert self.bmreq is not None
@ -2256,7 +2289,7 @@ class AlgorithmSelectorCache(PersistentCache):
# TODO(nmacchioni): remove this hacky way to tell if we ran benchmarking
has_autotuned = False
def benchmark(choices):
def benchmark(choices, hint_override: Optional[int] = None):
nonlocal has_autotuned
# TODO(nmacchioni): remove this hacky way to tell if we ran benchmarking
has_autotuned = True
@ -2264,13 +2297,13 @@ class AlgorithmSelectorCache(PersistentCache):
# TODO(nmacchioni): remove this layer of abstraction
# construct `benchmark_fn` which should pick between in-process and sub-process autotuning
benchmark_fn = self.make_benchmark_fn(
choices, input_nodes, layout, input_gen_fns
choices, input_nodes, layout, input_gen_fns, hint_override=hint_override
)
# `benchmark_fn(choices)` will execute each choice, and return a dict[choice, timing] which
# maps each choice to its runtime, calculated by the specified benchmarker, in milliseconds
return benchmark_fn(choices)
def autotune(choices):
def autotune(choices, hint_override: Optional[int] = None):
log.debug("Starting autotuning")
with dynamo_timed(
@ -2292,13 +2325,13 @@ class AlgorithmSelectorCache(PersistentCache):
),
},
):
return benchmark(choices)
return benchmark(choices, hint_override=hint_override)
if config.autotune_in_subproc:
# Initialize the suprocess pool so it will warmup early.
torch._inductor.autotune_process.get_tuning_process_pool()
def do_autotuning(choices, precompile_fn):
def do_autotuning(choices, precompile_fn, hint_override: Optional[int] = None):
precompile_start_ts = time.time()
with dynamo_timed(
f"{name}_template_precompiling",
@ -2319,7 +2352,8 @@ class AlgorithmSelectorCache(PersistentCache):
candidates,
name,
inputs_key,
autotune,
lambda choices: autotune(choices, hint_override=hint_override),
hint_override=hint_override,
)
choices = self.prune_choices_postscreen(
choices, timings, name, inputs_key, self.prescreening_cache
@ -2332,7 +2366,8 @@ class AlgorithmSelectorCache(PersistentCache):
choices,
name,
inputs_key,
autotune,
lambda choices: autotune(choices, hint_override=hint_override),
hint_override=hint_override,
)
autotune_elapse = time.time() - autotune_start_ts
@ -2355,6 +2390,7 @@ class AlgorithmSelectorCache(PersistentCache):
autotune_elapse,
precompile_elapse,
prescreening_elapse,
hint_override=hint_override,
)
def profiler_bench_function():
@ -2389,8 +2425,16 @@ class AlgorithmSelectorCache(PersistentCache):
if return_multi_template and (config.max_autotune or config.max_autotune_gemm):
def get_timings():
timings = do_autotuning(choices, precompile_fn)
def get_timings(hint_override: Optional[int] = None):
filtered_choices = [
c
for c in choices
if not hasattr(c, "hint_override")
or c.hint_override == hint_override
]
timings = do_autotuning(
filtered_choices, precompile_fn, hint_override=hint_override
)
min_extern_choice = float("inf")
for choice, timing in timings.items():
if isinstance(choice, ExternKernelCaller):
@ -2627,6 +2671,7 @@ class AlgorithmSelectorCache(PersistentCache):
input_nodes: list[ir.IRNode],
layout: ir.Layout,
input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]],
hint_override: Optional[int] = None,
) -> AutotuneArgs:
"""
Factory method to create AutotuneArgs from a list of ChoiceCallers.
@ -2636,7 +2681,9 @@ class AlgorithmSelectorCache(PersistentCache):
# de-duplicate args
unique_example_inputs = {
x.get_name(): input_gen_fns.get(i, cls.benchmark_example_value)(x)
x.get_name(): input_gen_fns.get(
i, lambda x: cls.benchmark_example_value(x, hint_override=hint_override)
)(x)
for i, x in enumerate(input_nodes)
}
example_inputs = list(unique_example_inputs.values())
@ -2649,20 +2696,23 @@ class AlgorithmSelectorCache(PersistentCache):
V.graph.sizevars.size_hints(
input_node.get_size(),
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
),
V.graph.sizevars.size_hints(
input_node.get_stride(),
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
),
V.graph.sizevars.size_hint(
input_node.get_layout().offset,
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
),
)
)
for input_node in input_nodes
]
out = cls.benchmark_example_value(layout)
out = cls.benchmark_example_value(layout, hint_override=hint_override)
out_extern = torch.as_strided(
out, out.size(), out.stride(), V.graph.sizevars.size_hint(layout.offset)
)
@ -2771,8 +2821,11 @@ class AlgorithmSelectorCache(PersistentCache):
input_nodes: list[ir.IRNode],
layout: ir.Layout,
input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]],
hint_override: Optional[int] = None,
) -> dict[ChoiceCaller, float]:
inputs = cls.get_inputs(choices, input_nodes, layout, input_gen_fns)
inputs = cls.get_inputs(
choices, input_nodes, layout, input_gen_fns, hint_override=hint_override
)
return cls.benchmark_choices(choices, inputs)
@classmethod
@ -2782,6 +2835,7 @@ class AlgorithmSelectorCache(PersistentCache):
input_nodes: list[ir.IRNode],
layout: ir.Layout,
input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]],
hint_override: Optional[int] = None,
):
from . import autotune_process
@ -2791,7 +2845,7 @@ class AlgorithmSelectorCache(PersistentCache):
triton = [c for c in choices if not isinstance(c, ExternKernelCaller)]
timings = cls.benchmark_in_current_process(
extern, input_nodes, layout, input_gen_fns
extern, input_nodes, layout, input_gen_fns, hint_override=hint_override
)
timings.update(autotune_process.benchmark_in_sub_process(triton)) # type: ignore[arg-type]
return timings
@ -2803,6 +2857,7 @@ class AlgorithmSelectorCache(PersistentCache):
input_nodes: list[ir.IRNode],
layout: ir.Layout,
input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]],
hint_override: Optional[int] = None,
):
if DEBUG:
print(f"{len(choices)} tuning requests:")
@ -2813,6 +2868,7 @@ class AlgorithmSelectorCache(PersistentCache):
input_nodes=input_nodes,
layout=layout,
input_gen_fns=input_gen_fns,
hint_override=hint_override,
)
else:
return functools.partial(
@ -2820,6 +2876,7 @@ class AlgorithmSelectorCache(PersistentCache):
input_nodes=input_nodes,
layout=layout,
input_gen_fns=input_gen_fns,
hint_override=hint_override,
)
@staticmethod
@ -2978,6 +3035,7 @@ class AlgorithmSelectorCache(PersistentCache):
elapse: float,
precompile_elapse: float,
prescreening_elapse: Optional[float] = None,
hint_override: Optional[int] = None,
):
V.debug.log_autotuning_results(
name, input_nodes, timings, elapse, precompile_elapse
@ -2992,6 +3050,7 @@ class AlgorithmSelectorCache(PersistentCache):
V.graph.sizevars.size_hints(
n.get_size(),
fallback=config.unbacked_symint_fallback, # type: ignore[arg-type]
hint_override=hint_override,
),
)
)
@ -3079,7 +3138,7 @@ class AlgorithmSelectorCache(PersistentCache):
)
@staticmethod
def benchmark_example_value(node):
def benchmark_example_value(node, hint_override: Optional[int] = None):
"""
Convert an ir.Buffer into a concrete torch.Tensor we can use for
benchmarking.
@ -3098,10 +3157,12 @@ class AlgorithmSelectorCache(PersistentCache):
V.graph.sizevars.size_hints(
node.get_size(),
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
),
V.graph.sizevars.size_hints(
node.get_stride(),
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
),
node.get_device(),
node.get_dtype(),
@ -3109,6 +3170,7 @@ class AlgorithmSelectorCache(PersistentCache):
V.graph.sizevars.size_hints(
V.graph.get_allocation_size(node),
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,
),
)

View File

@ -59,6 +59,15 @@ def statically_known_true(
# lifting and in some cases we should be directly passing through to ShapeEnv,
# but there is some extra inductor logic that needs to be handled here
class SizeVarAllocator:
"""
A class that manages symbolic size variables and their relationships.
This class works with the ShapeEnv to handle symbolic shape expressions,
simplify them, and provide utilities for guarding, checking, and evaluating
symbolic expressions. It also manages precomputed replacements and stride
calculations for tensor operations.
"""
def __init__(self, shape_env=None) -> None:
super().__init__()
if shape_env is None:
@ -527,7 +536,9 @@ class SizeVarAllocator:
return sympy_subs(expr, self.inv_precomputed_replacements) # type: ignore[arg-type]
return expr
def symbolic_hint(self, expr: Union[Expr, int]) -> Union[Expr, int]:
def symbolic_hint(
self, expr: Union[Expr, int], hint_override: Optional[int] = None
) -> Union[Expr, int]:
if isinstance(expr, int):
return expr
# Substitute all hints into expr, but leave unbacked symints alone
@ -541,13 +552,21 @@ class SizeVarAllocator:
return int(expr) # type: ignore[return-value]
except TypeError:
return expr # inf/nan/I
if hint_override:
return hint_override
expr = self.remove_precomputed_replacements(expr)
return sympy_subs(expr, self.var_to_val)
def size_hint(
self, expr: Union[Expr, int], *, fallback: Optional[int] = None
self,
expr: Union[Expr, int],
*,
fallback: Optional[int] = None,
hint_override: Optional[int] = None,
) -> int:
out = self.symbolic_hint(expr)
out = self.symbolic_hint(expr, hint_override=hint_override)
if not isinstance(out, (int, sympy.Integer)) and fallback is not None:
# Use the provided heuristic fallback hint
unbacked_sym_vrs = {
@ -581,8 +600,12 @@ class SizeVarAllocator:
exprs: Iterable[Union[Expr, int]],
*,
fallback: Optional[int] = None,
hint_override: Optional[int] = None,
) -> tuple[int, ...]:
return tuple(self.size_hint(x, fallback=fallback) for x in exprs)
return tuple(
self.size_hint(x, fallback=fallback, hint_override=hint_override)
for x in exprs
)
def size_hints_or_throw(
self,

View File

@ -5,7 +5,7 @@ import itertools
import math
from functools import partial
from threading import Lock
from typing import Any, Callable, TYPE_CHECKING
from typing import Any, Callable, Optional, TYPE_CHECKING
import torch
from torch.utils._ordered_set import OrderedSet
@ -33,6 +33,7 @@ class BaseConfig:
block_k: int
num_stages: int
num_warps: int
hint_override: Optional[int] = None
@dataclasses.dataclass
@ -421,7 +422,7 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
"""
Finalizes configs after scaling, applying additional constraints.
"""
used: OrderedSet[tuple[int, ...]] = OrderedSet()
used: OrderedSet[tuple[Optional[int], ...]] = OrderedSet()
max_mm_configs = config.test_configs.max_mm_configs
@ -430,11 +431,12 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
num_warps = min(conf.num_warps, conf.block_m * conf.block_n // 256)
# Construct key for finding duplicate configs
key: tuple[int, ...] = (
key: tuple[Optional[int], ...] = (
conf.block_m,
conf.block_n,
conf.block_k,
conf.num_stages,
conf.hint_override,
num_warps,
)
@ -451,12 +453,11 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
"BLOCK_M": conf.block_m,
"BLOCK_N": conf.block_n,
"BLOCK_K": conf.block_k,
"num_stages": conf.num_stages,
"num_warps": num_warps,
"hint_override": conf.hint_override,
}
if group_m is not None:
kwargs["GROUP_M"] = group_m
yield self.triton_config(**kwargs)
yield self.triton_config(conf.num_stages, num_warps, **kwargs)
def _scale_mm_configs(
self,
@ -467,6 +468,7 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
scale: float,
has_int8_tensor: bool,
exclude: Callable[[int, int, int], bool],
hint_override: Optional[int] = None,
) -> list[BaseConfig]:
"""
Scales and filters matrix multiplication configs based on input size.
@ -476,47 +478,52 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
min_block_size = 16
min_block_size_k = 32 if has_int8_tensor else 16
m = max(
next_power_of_2(
V.graph.sizevars.size_hint(
m,
fallback=config.unbacked_symint_fallback, # type: ignore[arg-type]
)
),
min_block_size,
)
n = max(
next_power_of_2(
V.graph.sizevars.size_hint(
n,
fallback=config.unbacked_symint_fallback, # type: ignore[arg-type]
)
),
min_block_size,
)
k = max(
next_power_of_2(
V.graph.sizevars.size_hint(
k,
fallback=config.unbacked_symint_fallback, # type: ignore[arg-type]
)
),
min_block_size_k,
)
scaled_configs = []
for c in configs:
scaled_config = dataclasses.replace(
c,
block_m=max(min(int(c.block_m * scale), m), min_block_size),
block_n=max(min(int(c.block_n * scale), n), min_block_size),
block_k=max(min(int(c.block_k * scale), k), min_block_size_k),
for hint_override in [None] + config.multi_kernel_hints:
m_hint = max(
next_power_of_2(
V.graph.sizevars.size_hint(
m,
fallback=config.unbacked_symint_fallback, # type: ignore[arg-type]
hint_override=hint_override,
)
),
min_block_size,
)
n_hint = max(
next_power_of_2(
V.graph.sizevars.size_hint(
n,
fallback=config.unbacked_symint_fallback, # type: ignore[arg-type]
hint_override=hint_override,
)
),
min_block_size,
)
k_hint = max(
next_power_of_2(
V.graph.sizevars.size_hint(
k,
fallback=config.unbacked_symint_fallback, # type: ignore[arg-type]
hint_override=hint_override,
)
),
min_block_size_k,
)
if not exclude(
scaled_config.block_m, scaled_config.block_n, scaled_config.block_k
):
scaled_configs.append(scaled_config)
for c in configs:
scaled_config = dataclasses.replace(
c,
block_m=max(min(int(c.block_m * scale), m_hint), min_block_size),
block_n=max(min(int(c.block_n * scale), n_hint), min_block_size),
block_k=max(min(int(c.block_k * scale), k_hint), min_block_size_k),
hint_override=hint_override,
)
if not exclude(
scaled_config.block_m, scaled_config.block_n, scaled_config.block_k
):
scaled_configs.append(scaled_config)
return scaled_configs