mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
multi-kernel matmuls based on varying hint sizes (#156628)
The core idea is to generate multiple matmul kernels using different hints for symbolic variables, then select the most appropriate one at runtime for each unique shape we encounter. You can find some early experimentation details in these posts: https://fb.workplace.com/groups/8940092306109185/posts/9803850776399996/ https://fb.workplace.com/groups/8940092306109185/posts/9695805170537891/ https://fb.workplace.com/groups/257735836456307/posts/906589324904285/ Here’s a graph illustrating the empirically observed worst-case performance if an oracle always selected the least optimal hint for a given runtime size:  This graph illustrates the performance of a hint size of 64 relative to the worst case. Notice that as the runtime sizes increase, the performance gradually approaches the worst case:  This graph shows the performance of a hint size of 4096 — very poor for small sizes, and also suboptimal for some mid-sized shapes:  Finally, here’s the graph that motivated this PR. It illustrates the performance when selecting the best of three kernels generated with three different hints — 64, 256, and 4096:  ## How to review this PR At a high level, this extends @shunting314's multi-kernel abstraction to support varying GEMM choices driven by different hints. A few key points: 1. Unlike reduction kernels, triton template matmuls pass their grid as arguments to the kernel. This PR updates `MultiKernelCall` to support kernels with varying arguments. 2. The `V.graph.sizevars.size_hints` API is extended to accept a `hint_override`, allowing us to substitute the example input’s size hint with a custom value when generating multiple kernels. 3. The choice generation and benchmarking logic is updated to support multiple hint values. One kernel is generated per value in `torch._inductor.config.multi_kernel_hints`, and at runtime, we select the most suitable kernel for the current shape. 4. This PR does not add support for cpp wrapper codegen to keep it scoped. That will be added in the next PR. ## Results The following is a basic test that shows our basic multi kernel working where we no longer show significant variance based on the original hint size: https://gist.github.com/bobrenjc93/ba711d529e65fd65839b34799f6323ec Before ``` Hint\Runtime | 64 | 256 | 4096 --------------------------------------------------- 64 | 0.0948 | 0.3124 | 4.9477 256 | 0.2243 | 0.2256 | 3.3880 4096 | 0.3384 | 0.3404 | 3.3010 ``` After ``` Hint\Runtime | 64 | 256 | 4096 --------------------------------------------------- 64 | 0.0951 | 0.2289 | 3.3013 256 | 0.0952 | 0.2258 | 3.4045 4096 | 0.0957 | 0.2231 | 3.3146 ``` We also see an average speedup of 5.04% for the matrix of all hint/runtime pairs in [64, 4096] for every increment of 64: https://docs.google.com/spreadsheets/d/12TmYUDrAAFASGuP3POXTKPeAvQWIRzKzdrVSIb3vQkA/edit?gid=480268938#gid=480268938  NB: This is just the beginning and I plan on doing more investigation to see further improve on this initial result. For posterity the script used to generate that matrix is here: https://gist.github.com/bobrenjc93/c211fd0bd97fad8f46b91ad9dee76ad0 HUD benchmark runs: base: https://github.com/pytorch/pytorch/actions/runs/15889871988 head: https://github.com/pytorch/pytorch/actions/runs/15889876842 Pull Request resolved: https://github.com/pytorch/pytorch/pull/156628 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
191693ac85
commit
5221448574
@ -294,6 +294,15 @@ RecordedEventsType = list[tuple[str, list[Any], dict[str, Any]]]
|
||||
|
||||
|
||||
class TritonTemplateKernel(TritonKernel):
|
||||
"""
|
||||
A specialized kernel class for Triton templates that handles code generation
|
||||
for templated Triton kernels.
|
||||
|
||||
This class extends TritonKernel to provide additional functionality for
|
||||
template-based kernel generation, including support for subgraphs, workspace
|
||||
arguments, and prologue/epilogue fusion.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel_name,
|
||||
@ -314,6 +323,7 @@ class TritonTemplateKernel(TritonKernel):
|
||||
subgraphs: Optional[list[ir.ComputedBuffer]] = None,
|
||||
workspace_arg: Optional[WorkspaceArg] = None,
|
||||
prologue_loads_all_inputs=False,
|
||||
hint_override: Optional[int] = None,
|
||||
) -> None:
|
||||
numel = sympy_product(output_node.get_size())
|
||||
super().__init__(
|
||||
@ -322,6 +332,7 @@ class TritonTemplateKernel(TritonKernel):
|
||||
"r0_": sympy.S.One,
|
||||
},
|
||||
features=SIMDKernelFeatures([], numel),
|
||||
hint_override=hint_override,
|
||||
)
|
||||
self.input_nodes = input_nodes
|
||||
self.output_node = output_node
|
||||
@ -1096,16 +1107,26 @@ class TritonTemplateKernel(TritonKernel):
|
||||
def codegen_range_tree(self):
|
||||
pass # ignore default codegen
|
||||
|
||||
def additional_call_args_and_types(self):
|
||||
if isinstance(self.grid_fn, SymbolicGridFn):
|
||||
grid_args = self.grid_fn.sympy_call(*self.call_sizes, self.meta)
|
||||
assert len(grid_args) in (0, 3), "grid_fn should return 3 values"
|
||||
return (grid_args, map(type, grid_args))
|
||||
elif all(isinstance(x, (int, sympy.Integer)) for x in self.call_sizes):
|
||||
grid_args = self.grid_fn(*map(int, self.call_sizes), self.meta)
|
||||
assert len(grid_args) in (0, 3), "grid_fn should return 3 values"
|
||||
return (grid_args, map(type, grid_args))
|
||||
return ((), ())
|
||||
|
||||
def call_kernel(self, name: str, node: Optional[ir.IRNode] = None):
|
||||
wrapper = V.graph.wrapper_code
|
||||
_, call_args, _, arg_types = self.args.python_argdefs()
|
||||
|
||||
grid_args = ()
|
||||
if isinstance(self.grid_fn, SymbolicGridFn):
|
||||
grid_args = self.grid_fn.sympy_call(*self.call_sizes, self.meta)
|
||||
elif all(isinstance(x, (int, sympy.Integer)) for x in self.call_sizes):
|
||||
grid_args = self.grid_fn(*map(int, self.call_sizes), self.meta)
|
||||
else:
|
||||
additional_call_args, additional_arg_types = (
|
||||
self.additional_call_args_and_types()
|
||||
)
|
||||
|
||||
if not additional_call_args:
|
||||
assert not V.graph.cpp_wrapper, "cpp_wrapper requires SymbolicGridFn"
|
||||
wrapper.add_import_once(f"import {self.grid_fn.__module__}")
|
||||
meta = wrapper.add_meta_once(self.meta)
|
||||
@ -1114,9 +1135,9 @@ class TritonTemplateKernel(TritonKernel):
|
||||
f"*{fn_name}({', '.join(map(pexpr, self.call_sizes))}, {meta})"
|
||||
)
|
||||
arg_types.append(None)
|
||||
assert len(grid_args) in (0, 3), "grid_fn should return 3 values"
|
||||
call_args.extend(grid_args)
|
||||
arg_types.extend(map(type, grid_args))
|
||||
|
||||
call_args.extend(additional_call_args)
|
||||
arg_types.extend(additional_arg_types)
|
||||
|
||||
if self.workspace_arg is not None:
|
||||
wrapper.generate_workspace_allocation(self.workspace_arg)
|
||||
@ -1202,6 +1223,7 @@ class GeneratedCodeCache:
|
||||
num_consumer_groups: int,
|
||||
num_buffers_warp_spec: int,
|
||||
kwargs: dict[str, Any],
|
||||
hint_override: Optional[int] = None,
|
||||
) -> Optional[str]:
|
||||
def layout_key(layout: ir.Layout) -> str:
|
||||
assert not isinstance(layout, ir.FlexibleLayout)
|
||||
@ -1252,6 +1274,7 @@ class GeneratedCodeCache:
|
||||
"num_buffers_warp_spec": num_buffers_warp_spec,
|
||||
"epilogue_fn_hash": epilogue_fn_hash,
|
||||
"kwargs": kwargs,
|
||||
"hint_override": hint_override,
|
||||
}
|
||||
)
|
||||
|
||||
@ -1356,6 +1379,7 @@ class TritonTemplate(KernelTemplate):
|
||||
layout: ir.Layout,
|
||||
kwargs: dict[str, Any],
|
||||
generate_with_caching,
|
||||
hint_override: Optional[int] = None,
|
||||
) -> Optional[GenerateAndLoadResult]:
|
||||
"""Generate the python code and load it into the current process"""
|
||||
caching_enabled = (
|
||||
@ -1428,6 +1452,7 @@ class TritonTemplate(KernelTemplate):
|
||||
output_node=fake_out,
|
||||
workspace_arg=workspace_arg,
|
||||
use_jit=False,
|
||||
hint_override=hint_override,
|
||||
**kernel_options,
|
||||
)
|
||||
|
||||
@ -1547,6 +1572,7 @@ class TritonTemplate(KernelTemplate):
|
||||
call_sizes: Optional[Sequence[sympy.core.symbol.Symbol]] = None,
|
||||
workspace_arg: Optional[WorkspaceArg] = None,
|
||||
generate_with_caching=False,
|
||||
hint_override: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""This function generates a TritonTemplateCaller
|
||||
@ -1591,6 +1617,7 @@ class TritonTemplate(KernelTemplate):
|
||||
layout,
|
||||
kwargs,
|
||||
generate_with_caching and self._cache_codegen_enabled_for_template,
|
||||
hint_override=hint_override,
|
||||
)
|
||||
|
||||
# May happen as result of dev by 0.
|
||||
@ -1612,6 +1639,7 @@ class TritonTemplate(KernelTemplate):
|
||||
extra_args = V.graph.sizevars.size_hints(
|
||||
map(sympy.expand, result.kernel_args_sizevars_keys),
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
)
|
||||
|
||||
kernel_hash_name = f"triton_{self.name}_{next(self.index_counter)}"
|
||||
@ -1635,13 +1663,14 @@ class TritonTemplate(KernelTemplate):
|
||||
|
||||
options = result.kernel_options
|
||||
|
||||
def make_kernel_render(out_node):
|
||||
def make_kernel_render(out_node, hint_override: Optional[int] = None):
|
||||
assert result is not None
|
||||
kernel = self.kernel_type(
|
||||
kernel_name=str(Placeholder.KERNEL_NAME),
|
||||
output_node=out_node,
|
||||
workspace_arg=workspace_arg,
|
||||
use_jit=False,
|
||||
hint_override=hint_override,
|
||||
**options,
|
||||
)
|
||||
render = functools.partial(
|
||||
@ -1657,6 +1686,7 @@ class TritonTemplate(KernelTemplate):
|
||||
*V.graph.sizevars.size_hints(
|
||||
call_sizes,
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
),
|
||||
kwargs,
|
||||
)
|
||||
@ -1708,6 +1738,7 @@ class TritonTemplate(KernelTemplate):
|
||||
mutated_inputs=mutated_inputs,
|
||||
workspace_arg=workspace_arg,
|
||||
allowed_prologue_inps=result.prologue_supported_inputs,
|
||||
hint_override=hint_override,
|
||||
)
|
||||
|
||||
|
||||
@ -1783,6 +1814,7 @@ class TritonTemplateCaller(ir.TritonTemplateCallerBase):
|
||||
mutated_inputs=None,
|
||||
workspace_arg: Optional[WorkspaceArg] = None,
|
||||
allowed_prologue_inps: Optional[OrderedSet[str]] = None,
|
||||
hint_override: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__(name, input_nodes, layout, description)
|
||||
self.make_kernel_render = make_kernel_render
|
||||
@ -1802,6 +1834,7 @@ class TritonTemplateCaller(ir.TritonTemplateCallerBase):
|
||||
self.allowed_prologue_inps = (
|
||||
allowed_prologue_inps if allowed_prologue_inps is not None else OrderedSet()
|
||||
)
|
||||
self.hint_override = hint_override
|
||||
|
||||
def benchmark(self, *args, out):
|
||||
assert self.bmreq is not None
|
||||
@ -2256,7 +2289,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
# TODO(nmacchioni): remove this hacky way to tell if we ran benchmarking
|
||||
has_autotuned = False
|
||||
|
||||
def benchmark(choices):
|
||||
def benchmark(choices, hint_override: Optional[int] = None):
|
||||
nonlocal has_autotuned
|
||||
# TODO(nmacchioni): remove this hacky way to tell if we ran benchmarking
|
||||
has_autotuned = True
|
||||
@ -2264,13 +2297,13 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
# TODO(nmacchioni): remove this layer of abstraction
|
||||
# construct `benchmark_fn` which should pick between in-process and sub-process autotuning
|
||||
benchmark_fn = self.make_benchmark_fn(
|
||||
choices, input_nodes, layout, input_gen_fns
|
||||
choices, input_nodes, layout, input_gen_fns, hint_override=hint_override
|
||||
)
|
||||
# `benchmark_fn(choices)` will execute each choice, and return a dict[choice, timing] which
|
||||
# maps each choice to its runtime, calculated by the specified benchmarker, in milliseconds
|
||||
return benchmark_fn(choices)
|
||||
|
||||
def autotune(choices):
|
||||
def autotune(choices, hint_override: Optional[int] = None):
|
||||
log.debug("Starting autotuning")
|
||||
|
||||
with dynamo_timed(
|
||||
@ -2292,13 +2325,13 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
),
|
||||
},
|
||||
):
|
||||
return benchmark(choices)
|
||||
return benchmark(choices, hint_override=hint_override)
|
||||
|
||||
if config.autotune_in_subproc:
|
||||
# Initialize the suprocess pool so it will warmup early.
|
||||
torch._inductor.autotune_process.get_tuning_process_pool()
|
||||
|
||||
def do_autotuning(choices, precompile_fn):
|
||||
def do_autotuning(choices, precompile_fn, hint_override: Optional[int] = None):
|
||||
precompile_start_ts = time.time()
|
||||
with dynamo_timed(
|
||||
f"{name}_template_precompiling",
|
||||
@ -2319,7 +2352,8 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
candidates,
|
||||
name,
|
||||
inputs_key,
|
||||
autotune,
|
||||
lambda choices: autotune(choices, hint_override=hint_override),
|
||||
hint_override=hint_override,
|
||||
)
|
||||
choices = self.prune_choices_postscreen(
|
||||
choices, timings, name, inputs_key, self.prescreening_cache
|
||||
@ -2332,7 +2366,8 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
choices,
|
||||
name,
|
||||
inputs_key,
|
||||
autotune,
|
||||
lambda choices: autotune(choices, hint_override=hint_override),
|
||||
hint_override=hint_override,
|
||||
)
|
||||
|
||||
autotune_elapse = time.time() - autotune_start_ts
|
||||
@ -2355,6 +2390,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
autotune_elapse,
|
||||
precompile_elapse,
|
||||
prescreening_elapse,
|
||||
hint_override=hint_override,
|
||||
)
|
||||
|
||||
def profiler_bench_function():
|
||||
@ -2389,8 +2425,16 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
|
||||
if return_multi_template and (config.max_autotune or config.max_autotune_gemm):
|
||||
|
||||
def get_timings():
|
||||
timings = do_autotuning(choices, precompile_fn)
|
||||
def get_timings(hint_override: Optional[int] = None):
|
||||
filtered_choices = [
|
||||
c
|
||||
for c in choices
|
||||
if not hasattr(c, "hint_override")
|
||||
or c.hint_override == hint_override
|
||||
]
|
||||
timings = do_autotuning(
|
||||
filtered_choices, precompile_fn, hint_override=hint_override
|
||||
)
|
||||
min_extern_choice = float("inf")
|
||||
for choice, timing in timings.items():
|
||||
if isinstance(choice, ExternKernelCaller):
|
||||
@ -2627,6 +2671,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
input_nodes: list[ir.IRNode],
|
||||
layout: ir.Layout,
|
||||
input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]],
|
||||
hint_override: Optional[int] = None,
|
||||
) -> AutotuneArgs:
|
||||
"""
|
||||
Factory method to create AutotuneArgs from a list of ChoiceCallers.
|
||||
@ -2636,7 +2681,9 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
|
||||
# de-duplicate args
|
||||
unique_example_inputs = {
|
||||
x.get_name(): input_gen_fns.get(i, cls.benchmark_example_value)(x)
|
||||
x.get_name(): input_gen_fns.get(
|
||||
i, lambda x: cls.benchmark_example_value(x, hint_override=hint_override)
|
||||
)(x)
|
||||
for i, x in enumerate(input_nodes)
|
||||
}
|
||||
example_inputs = list(unique_example_inputs.values())
|
||||
@ -2649,20 +2696,23 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
V.graph.sizevars.size_hints(
|
||||
input_node.get_size(),
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
),
|
||||
V.graph.sizevars.size_hints(
|
||||
input_node.get_stride(),
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
),
|
||||
V.graph.sizevars.size_hint(
|
||||
input_node.get_layout().offset,
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
),
|
||||
)
|
||||
)
|
||||
for input_node in input_nodes
|
||||
]
|
||||
out = cls.benchmark_example_value(layout)
|
||||
out = cls.benchmark_example_value(layout, hint_override=hint_override)
|
||||
out_extern = torch.as_strided(
|
||||
out, out.size(), out.stride(), V.graph.sizevars.size_hint(layout.offset)
|
||||
)
|
||||
@ -2771,8 +2821,11 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
input_nodes: list[ir.IRNode],
|
||||
layout: ir.Layout,
|
||||
input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]],
|
||||
hint_override: Optional[int] = None,
|
||||
) -> dict[ChoiceCaller, float]:
|
||||
inputs = cls.get_inputs(choices, input_nodes, layout, input_gen_fns)
|
||||
inputs = cls.get_inputs(
|
||||
choices, input_nodes, layout, input_gen_fns, hint_override=hint_override
|
||||
)
|
||||
return cls.benchmark_choices(choices, inputs)
|
||||
|
||||
@classmethod
|
||||
@ -2782,6 +2835,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
input_nodes: list[ir.IRNode],
|
||||
layout: ir.Layout,
|
||||
input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]],
|
||||
hint_override: Optional[int] = None,
|
||||
):
|
||||
from . import autotune_process
|
||||
|
||||
@ -2791,7 +2845,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
triton = [c for c in choices if not isinstance(c, ExternKernelCaller)]
|
||||
|
||||
timings = cls.benchmark_in_current_process(
|
||||
extern, input_nodes, layout, input_gen_fns
|
||||
extern, input_nodes, layout, input_gen_fns, hint_override=hint_override
|
||||
)
|
||||
timings.update(autotune_process.benchmark_in_sub_process(triton)) # type: ignore[arg-type]
|
||||
return timings
|
||||
@ -2803,6 +2857,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
input_nodes: list[ir.IRNode],
|
||||
layout: ir.Layout,
|
||||
input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]],
|
||||
hint_override: Optional[int] = None,
|
||||
):
|
||||
if DEBUG:
|
||||
print(f"{len(choices)} tuning requests:")
|
||||
@ -2813,6 +2868,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
input_nodes=input_nodes,
|
||||
layout=layout,
|
||||
input_gen_fns=input_gen_fns,
|
||||
hint_override=hint_override,
|
||||
)
|
||||
else:
|
||||
return functools.partial(
|
||||
@ -2820,6 +2876,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
input_nodes=input_nodes,
|
||||
layout=layout,
|
||||
input_gen_fns=input_gen_fns,
|
||||
hint_override=hint_override,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -2978,6 +3035,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
elapse: float,
|
||||
precompile_elapse: float,
|
||||
prescreening_elapse: Optional[float] = None,
|
||||
hint_override: Optional[int] = None,
|
||||
):
|
||||
V.debug.log_autotuning_results(
|
||||
name, input_nodes, timings, elapse, precompile_elapse
|
||||
@ -2992,6 +3050,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
V.graph.sizevars.size_hints(
|
||||
n.get_size(),
|
||||
fallback=config.unbacked_symint_fallback, # type: ignore[arg-type]
|
||||
hint_override=hint_override,
|
||||
),
|
||||
)
|
||||
)
|
||||
@ -3079,7 +3138,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def benchmark_example_value(node):
|
||||
def benchmark_example_value(node, hint_override: Optional[int] = None):
|
||||
"""
|
||||
Convert an ir.Buffer into a concrete torch.Tensor we can use for
|
||||
benchmarking.
|
||||
@ -3098,10 +3157,12 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
V.graph.sizevars.size_hints(
|
||||
node.get_size(),
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
),
|
||||
V.graph.sizevars.size_hints(
|
||||
node.get_stride(),
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
),
|
||||
node.get_device(),
|
||||
node.get_dtype(),
|
||||
@ -3109,6 +3170,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
V.graph.sizevars.size_hints(
|
||||
V.graph.get_allocation_size(node),
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
),
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user