Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
[multi-kernel] shape-similarity kernel selection (#163090)
Introduces a variant of size-hint multi-kernel where, for novel runtime shapes, instead of running full benchmarking to determine the optimal kernel, we select one of the kernels pre-generated from multi-kernel hints, based on the similarity between the hint and runtime input & output shapes (L1 distance in log2 space).

Some caveats/changes:
- Size-hint multi-kernel now only kicks in if the kernel has dynamic shapes.
- Pre-generation still only does a 1-d search over the specified hints, e.g. `matmul([s0, s1], [s1, s2])` with size-hints `[64, 256]` only generates 2 kernels, based on tuning shapes ([64, 64], [64, 64]) and ([256, 256], [256, 256]). Extending this to a reasonable n-d search (via a user API?) is left as a follow-up extension.

Benchmarking results, compared to multi-kernel with full benchmarking (hints 64, 4096) and to compiling with the ground-truth hint:

<img width="1902" height="1222" alt="550541081_1088709150049684_6528797079439730237_n" src="https://github.com/user-attachments/assets/056cca48-c16a-4451-9b4a-fa13a7a058a9" />

Full benchmarking doing worse is extremely weird, but we did see similar spikes in #156628.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163090
Approved by: https://github.com/bobrenjc93
Committed by: PyTorch MergeBot
Parent: 22c5e8c17c
Commit: 2a9745de3c
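The core idea from the commit message, shown in isolation before the diff: each pre-generated kernel is keyed by the input & output shapes it was tuned for, and at runtime the kernel whose key is closest to the observed shapes under an L1 distance in log2 space is picked, with no benchmarking. Below is a minimal standalone sketch of that selection; the helper names (`log2_l1_distance`, `pick_kernel`) and the example shapes are illustrative only, not Inductor APIs. The real logic lives in `SizeHintMultiKernelCall._dist_heuristic` and `_select_kernel_by_shape` in the diff further down.

```python
import math

def log2_l1_distance(shapes_a, shapes_b):
    # Sum of per-dimension |log2(x) - log2(y)| over corresponding tensor shapes.
    def dim_dist(x, y):
        lx = math.log2(x) if x > 0 else -1
        ly = math.log2(y) if y > 0 else -1
        return abs(lx - ly)

    return sum(
        sum(dim_dist(x, y) for x, y in zip(sa, sb))
        for sa, sb in zip(shapes_a, shapes_b)
    )

def pick_kernel(hint_shape_keys, runtime_shapes):
    # Index of the pre-generated kernel whose tuning shapes are closest to the runtime shapes.
    dists = [log2_l1_distance(runtime_shapes, key) for key in hint_shape_keys]
    return dists.index(min(dists))

# Kernels pre-generated for hints 64 and 256 (1-d search: every dynamic dim gets the same hint).
hints = [((64, 64), (64, 64)), ((256, 256), (256, 256))]
print(pick_kernel(hints, ((100, 64), (64, 64))))  # -> 0: the 64-hint kernel is closer in log2 space
```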
@@ -50,6 +50,16 @@ def _contains_multi_kernel_code(wrapper_code: str):
    )


def _contains_size_hint_multi_kernel_code(wrapper_code: str):
    return (
        re.search(
            r"multi_kernel_[^ ]* = async_compile.size_hint_multi_kernel[(]",
            wrapper_code,
        )
        is not None
    )


def make_cpp_wrapper_test(orig_test, **extra_args):
    """
    Wrap an existing test into a new test with cpp-wrapper enabled.

@@ -115,6 +125,7 @@ class MultiKernelTest(TestCase):
        )
        x = torch.randn(4096, 4096, device=GPU_TYPE)
        y = torch.randn(4096, 4096, device=GPU_TYPE)
        torch._dynamo.mark_dynamic(x, 0)
        act, wrapper_code = run_and_get_code(compiled_fn, x, y)
        ref = fn(x, y)

@@ -123,7 +134,7 @@ class MultiKernelTest(TestCase):
        # We mainly care about the wrapper for the final pass here.
        wrapper_code = wrapper_code[-1]
        self.assertEqual(ref, act)
        self.assertTrue(_contains_multi_kernel_code(wrapper_code))
        self.assertTrue(_contains_size_hint_multi_kernel_code(wrapper_code))

    @requires_triton()
    # TODO: bobrenjc93 to fix multi-kernel for ROCM
@@ -142,6 +153,7 @@ class MultiKernelTest(TestCase):
        )
        x = torch.randn(4096, 4096, device=GPU_TYPE)
        y = torch.randn(4096, 4096, device=GPU_TYPE)
        torch._dynamo.mark_dynamic(x, 0)
        act, wrapper_code = run_and_get_code(compiled_fn, x, y)
        ref = fn(x, y)

@@ -150,7 +162,7 @@ class MultiKernelTest(TestCase):
        # We mainly care about the wrapper for the final pass here.
        wrapper_code = wrapper_code[-1]
        self.assertEqual(ref, act)
        self.assertTrue(_contains_multi_kernel_code(wrapper_code))
        self.assertTrue(_contains_size_hint_multi_kernel_code(wrapper_code))

    @parametrize("force_kernel", (0, 1))
    @unittest.mock.patch.dict(
@@ -521,6 +521,11 @@ class AsyncCompile:
        # no need to call this in parallel since the sub-kernels are already parallel tasks
        return MultiKernelCall(*args, **kwargs)

    def size_hint_multi_kernel(self, *args, **kwargs) -> Any:
        from torch._inductor.codegen.multi_kernel import SizeHintMultiKernelCall

        return SizeHintMultiKernelCall(*args, **kwargs)

    def cpp(self, source_code: str):
        kernel_code_log.info("CPP Kernel:\n%s", source_code)
        if get_compile_threads() <= 1:
@@ -1,8 +1,10 @@
# mypy: allow-untyped-defs
import functools
import logging
import math
import os
import pathlib
from typing import Any, Optional, Union

from torch._inductor.ir import MultiTemplateBuffer
from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
@@ -31,7 +33,13 @@ class MultiKernelState:
        self.subkernel_to_kernel_name = {}
        self.kernel_defs = IndentedBuffer()

    def define_kernel(self, kernels):
    def define_kernel(
        self,
        kernels: list[Any],
        kernel_shape_keys: Optional[
            list[Union[None, tuple[tuple[int, ...], ...]]]
        ] = None,
    ) -> str:
        """
        Previously we name the multi kernel as "multi_kernel_{kernel_names[0]}".
        This has some minor issue.
@@ -45,6 +53,12 @@ class MultiKernelState:
        The only different is cache eviction policy.

        We should name the multi-kernel differently in these 2 cases.

        kernels:
            A list of kernels
        kernel_shape_keys:
            Specified for size-hint multi-kernels.
            Each list element is a shape key, corresponding to the concrete input & output size hints each kernel was tuned for.
        """
        # Prevent circular import
        from ..select_algorithm import TritonTemplateKernel
@@ -68,9 +82,7 @@ class MultiKernelState:
            kernels[0].output_node, MultiTemplateBuffer
        ):
            for i, kernel in enumerate(kernels):
                additional_call_args, additional_arg_types = (
                    kernel.additional_call_args_and_types()
                )
                additional_call_args, _ = kernel.additional_call_args_and_types()
                if i not in arg_index:
                    arg_index[i] = []
                arg_index[i].append(slice(0, len(call_args)))
@@ -85,7 +97,7 @@ class MultiKernelState:
            for i in range(len(kernels)):
                arg_index[i] = [slice(0, len(call_args))]

        shape_specialize = isinstance(kernels[0], TritonTemplateKernel)
        keyed_by_sizes = kernel_shape_keys is not None
        buf = self.kernel_defs
        buf.writeline("")
        buf.writeline("arg_index = {")
@@ -93,13 +105,26 @@ class MultiKernelState:
            slice_reprs = ", ".join(repr(s) for s in slice_list)
            buf.writeline(f" {key}: [{slice_reprs}],")
        buf.writeline("}")
        buf.writeline(
            f"{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, ["
        )
        with buf.indent():
            for name in kernel_names:
                buf.writeline(f"{name},")
        buf.writeline(f"], arg_index=arg_index, shape_specialize={shape_specialize})")

        if not keyed_by_sizes:  # no size hint keys, just call with list of kernels
            buf.writeline(
                f"{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, ["
            )
            with buf.indent():
                for name in kernel_names:
                    buf.writeline(f"{name},")
            buf.writeline("], arg_index=arg_index)")
        else:  # call with dict[size hint key, kernel]
            assert isinstance(kernels[0], TritonTemplateKernel)
            assert isinstance(kernel_shape_keys, list)
            assert len(kernels) == len(kernel_shape_keys)
            buf.writeline(
                f"{multi_kernel_name} = async_compile.size_hint_multi_kernel({multi_kernel_name!r}, {{"
            )
            with buf.indent():
                for shape_key, name in zip(kernel_shape_keys, kernel_names):
                    buf.writeline(f"{shape_key}: {name},")
            buf.writeline("}, arg_index=arg_index)")

        if config.triton.autotune_at_compile_time:
            V.graph.wrapper_code.src_to_kernel["\n".join(kernel_names)] = (
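To make the hunk above concrete, here is a rough, standalone approximation of the wrapper text that this codegen emits in its two branches (plain multi-kernel vs. size-hint multi-kernel keyed by shapes). The kernel names and shape keys below are made up for illustration, and the real output also includes the `arg_index` dict emitted just before this call.

```python
def emit_multi_kernel_def(multi_kernel_name, kernel_names, kernel_shape_keys=None):
    # Mirrors the buf.writeline(...) calls above: a list of kernels when there are no
    # shape keys, otherwise a dict keyed by the shapes each kernel was tuned for.
    lines = []
    if kernel_shape_keys is None:
        lines.append(f"{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, [")
        lines.extend(f"    {name}," for name in kernel_names)
        lines.append("], arg_index=arg_index)")
    else:
        lines.append(f"{multi_kernel_name} = async_compile.size_hint_multi_kernel({multi_kernel_name!r}, {{")
        lines.extend(f"    {key}: {name}," for key, name in zip(kernel_shape_keys, kernel_names))
        lines.append("}, arg_index=arg_index)")
    return "\n".join(lines)

print(emit_multi_kernel_def(
    "multi_kernel_mm",                                  # hypothetical multi-kernel name
    ["triton_mm_hint_64", "triton_mm_hint_256"],        # hypothetical sub-kernel names
    [((64, 64), (64, 64)), ((256, 256), (256, 256))],   # shape keys from size hints 64 and 256
))
```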
@@ -266,8 +291,8 @@ class MultiKernelCall:
    This class is called at run time to actually run the kernel
    """

    def __init__(self, multi_kernel_name, kernels, arg_index, shape_specialize=False):
        assert len(kernels) >= 2
    def __init__(self, multi_kernel_name, kernels, arg_index):
        assert len(kernels) >= 1
        self._kernels = kernels
        self.multi_kernel_name = multi_kernel_name

@@ -287,13 +312,6 @@ class MultiKernelCall:

        self._recorded = False

        # This means for each unique shape we will do a separate assessment
        # for which kernel is the best. This is particularly useful for matmul
        # kernels where the best kernel can vary based on very small differences
        # in shape.
        self._shape_specialize = shape_specialize
        self._shape_cache = {}

    def cache_file_path(self):
        key = code_hash(
            ",".join(
@@ -415,20 +433,6 @@ class MultiKernelCall:
        return V.graph.multi_kernel_to_choice[multi_kernel_name]

    def run(self, *args, **kwargs):
        if self._shape_specialize:
            cache_key = self._get_shape_cache_key(*args, **kwargs)
            cached_choice = self._get_cached_shape_choice(cache_key)
            if cached_choice is not None:
                self.picked_kernel = cached_choice
                log.debug(
                    "using cached shape-specialized choice %dth sub-kernel in %s. Cache key: %s",
                    self.picked_kernel,
                    [k.inductor_meta.get("kernel_name") for k in self.kernels],
                    cache_key,
                )
            else:
                self._select_kernel_by_shape(*args, **kwargs)

        if self.picked_kernel is None:
            timings = self.benchmark_sub_kernels(*args, **kwargs)
            self.picked_kernel = timings.index(min(timings))
@@ -460,38 +464,6 @@ class MultiKernelCall:
        filtered_args = self._get_filtered_args(args, self.picked_kernel)
        run(*filtered_args, **kwargs)

    def _get_shape_cache_key(self, *args, **kwargs):
        """
        Generate a cache key based on tensor shapes for shape-specialized dispatch.
        """
        shapes = []
        for arg in args:
            if hasattr(arg, "shape"):
                shapes.append(tuple(arg.shape))
        return tuple(shapes)

    def _get_cached_shape_choice(self, cache_key):
        """
        Get cached kernel choice for a specific shape.
        """
        return self._shape_cache.get(cache_key)

    def _cache_shape_choice(self, cache_key, kernel_idx):
        """
        Cache kernel choice for a specific shape
        """
        self._shape_cache[cache_key] = kernel_idx

    def _select_kernel_by_shape(self, *args, **kwargs):
        """
        Benchmark kernels for a particular shape and return the
        best kernel for this shape.
        """
        shape_key = self._get_shape_cache_key(*args, **kwargs)
        timings = self.benchmark_sub_kernels(*args, **kwargs)
        self.picked_kernel = timings.index(min(timings))
        self._cache_shape_choice(shape_key, self.picked_kernel)

    def _metrics_table_row(self, timings):
        def get_kernel_path(k):
            return k.fn.fn.__code__.co_filename
@@ -511,3 +483,121 @@ class MultiKernelCall:
            row[f"kernel{i}_path"] = ""
            row[f"kernel{i}_latency"] = ""
        return row


class SizeHintMultiKernel(MultiKernel):
    """
    Version of multi-kernel that generates kernels based on specified size hints.
    Currently only performs 1-d search over hints; doesn't perform combinatorial n-d search
    if n > 1 dynamic dimensions are specified.

    e.g. matmul([s0, s1], [s1, s2]) with size-hints [64, 256] only generates 2 kernels,
    based on tuning shapes ([64, 64], [64, 64]) and ([256, 256], [256, 256])
    """

    def __init__(self, kernels):
        assert isinstance(kernels, dict) and len(kernels) >= 1

        self.kernels, self.kernel_shape_keys = [], []
        for shape_key, kernel in kernels.items():
            self.kernels.append(kernel)
            self.kernel_shape_keys.append(shape_key)
        self.kernel_name = V.graph.wrapper_code.multi_kernel_state.define_kernel(
            self.kernels, self.kernel_shape_keys
        )

        # need this since some code in inductor check if the kernel object has an args
        # attribute to decide if it's a non-null kernel.
        self.args = object()


class SizeHintMultiKernelCall(MultiKernelCall):
    """
    Runtime class for size-hint multi-kernels.
    Instead of having a plain list of kernels to benchmark over, keys them by input & output shapes,
    and optionally perform shape-based selection. The pre-generated kernel is chosen based on the shape keys,
    with the heuristic being log2 l1 distance between the pre-generated / runtime input & output shapes.
    """

    def __init__(self, multi_kernel_name, kernels, arg_index):
        super().__init__(multi_kernel_name, list(kernels.values()), arg_index)
        self._kernel_hints = list(kernels.keys())

        # Caches results for unique shapes.
        self._shape_cache = {}

    def _get_shape_cache_key(self, *args, **kwargs):
        """
        Generate a cache key based on tensor shapes for shape-specialized dispatch.
        """
        shapes = []
        for arg in args:
            if hasattr(arg, "shape"):
                shapes.append(tuple(arg.shape))
        return tuple(shapes)

    def _get_cached_shape_choice(self, cache_key):
        """
        Get cached kernel choice for a specific shape.
        """
        return self._shape_cache.get(cache_key)

    def _cache_shape_choice(self, cache_key, kernel_idx):
        """
        Cache kernel choice for a specific shape.
        """
        self._shape_cache[cache_key] = kernel_idx

    def _dist_heuristic(self, k1, k2):
        """
        log2 L1 distance heuristic for kernel selection.
        """

        def dist(x, y):
            lx = math.log2(x) if x > 0 else -1
            ly = math.log2(y) if y > 0 else -1
            return abs(lx - ly)

        out = 0
        for s1, s2 in zip(k1, k2):
            out += sum(dist(x, y) for x, y in zip(s1, s2))
        return out

    def run(self, *args, **kwargs):
        cache_key = self._get_shape_cache_key(*args, **kwargs)
        cached_choice = self._get_cached_shape_choice(cache_key)
        if cached_choice is not None:
            self.picked_kernel = cached_choice
            log.debug(
                "using cached shape-specialized choice %dth sub-kernel in %s. Cache key: %s",
                self.picked_kernel,
                [k.inductor_meta.get("kernel_name") for k in self.kernels],
                cache_key,
            )
        else:
            self._select_kernel_by_shape(*args, **kwargs)

        if not self._recorded:
            self._recorded = True
            picked_kernel_name = self.kernels[self.picked_kernel].inductor_meta.get(
                "kernel_name"
            )
            assert picked_kernel_name is not None
            self.record_choice(self.multi_kernel_name, picked_kernel_name)

        run = self.kernels[self.picked_kernel].run  # type: ignore[method-assign]
        filtered_args = self._get_filtered_args(args, self.picked_kernel)
        run(*filtered_args, **kwargs)

    def _select_kernel_by_shape(self, *args, **kwargs):
        """
        Benchmark kernels for a particular shape and return the
        best kernel for this shape.
        """
        shape_key = self._get_shape_cache_key(*args, **kwargs)
        dists = [
            self._dist_heuristic(shape_key, key) if key is not None else 2**62
            for key in self._kernel_hints
        ]
        self.picked_kernel = dists.index(min(dists))
        self._cache_shape_choice(shape_key, self.picked_kernel)
@@ -60,7 +60,7 @@ from ..utils import (
from ..virtualized import ops, OpsWrapper, V
from .block_analysis import BlockPatternMatcher
from .common import CSEVariable, index_prevent_reordering, Kernel, PythonPrinter
from .multi_kernel import MultiKernel
from .multi_kernel import MultiKernel, SizeHintMultiKernel
from .simd_kernel_features import (
    DisableReduction,
    EnableReduction,
@@ -1689,6 +1689,51 @@ class SIMDScheduling(BaseScheduling):

        return kernel

    def _get_multikernel_shapes(
        self, node: MultiTemplateBuffer
    ) -> tuple[tuple[int, ...], ...]:
        from ..ir import IRNode

        def get_size(arg):
            if not isinstance(arg, IRNode) or (size := arg.maybe_get_size()) is None:
                return None
            return tuple(s for s in size)

        out = []
        for arg in list(node.inputs) + [node]:
            if isinstance(arg, (list, tuple)):
                out.append(tuple(get_size(_arg) for _arg in arg))
            else:
                out.append(get_size(arg))
        return tuple(out)

    def _kernel_has_dynamic_shapes(self, node: MultiTemplateBuffer) -> bool:
        shapes = self._get_multikernel_shapes(node)
        return any(
            any(
                isinstance(s, sympy.Expr) and not isinstance(s, sympy.Integer)
                for s in shape
            )
            for shape in shapes
        )

    def _make_shape_cache_key(
        self, node: MultiTemplateBuffer, hint: int
    ) -> tuple[tuple[int, ...], ...]:
        """
        Returns cache key for hint-based multi-graph; key is tuple of shapes with hint filled in.
        """
        shapes = self._get_multikernel_shapes(node)
        return tuple(
            tuple(
                hint
                if isinstance(s, sympy.Expr) and not isinstance(s, sympy.Integer)
                else s
                for s in shape
            )
            for shape in shapes
        )

    def codegen_template(
        self,
        template_node,
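The `_make_shape_cache_key` helper added above is what produces the dict keys seen earlier: every symbolic (dynamic) dimension in the kernel's input/output shapes is replaced by the size hint, while concrete dimensions are kept. A small self-contained sketch of that substitution, assuming `sympy` is available and using a made-up symbol `s0`:

```python
import sympy

def make_shape_cache_key(shapes, hint):
    # Replace every symbolic (non-integer) dimension with the hint; keep concrete sizes.
    return tuple(
        tuple(
            hint if isinstance(s, sympy.Expr) and not isinstance(s, sympy.Integer) else s
            for s in shape
        )
        for shape in shapes
    )

s0 = sympy.Symbol("s0")
# e.g. a matmul whose first input and output have a dynamic leading dimension, tuned with hint 64:
print(make_shape_cache_key(((s0, 4096), (4096, 4096), (s0, 4096)), 64))
# -> ((64, 4096), (4096, 4096), (64, 4096))
```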
@@ -1711,11 +1756,16 @@ class SIMDScheduling(BaseScheduling):
        if (
            isinstance(template_node.node, MultiTemplateBuffer)
            and template_node.node._make_kernel_renders
            and len(template_node.node._make_kernel_renders) > 1
            and self._kernel_has_dynamic_shapes(template_node.node)
        ):
            kernels = []
            kernels = {}
            src_codes = []

            for make_kernel_render in template_node.node._make_kernel_renders.values():
            for (
                size_hint,
                make_kernel_render,
            ) in template_node.node._make_kernel_renders.items():
                kernel, render = make_kernel_render(
                    template_node.node, hint_override=hint_override
                )
@@ -1732,6 +1782,8 @@ class SIMDScheduling(BaseScheduling):
                    assert isinstance(src_code, str)
                    src_codes.append(src_code)
                else:
                    if size_hint is None:
                        continue  # skip kernel generation based on real runtime value; only use hints
                    kernel = self._codegen_single_template(
                        kernel,
                        render,
@@ -1740,13 +1792,18 @@ class SIMDScheduling(BaseScheduling):
                        prologue_nodes,
                        only_gen_src_code=False,
                    )
                    kernels.append(kernel)
                    shape_cache_key = (
                        None
                        if size_hint is None
                        else self._make_shape_cache_key(template_node.node, size_hint)
                    )
                    kernels[shape_cache_key] = kernel

            if only_gen_src_code:
                return "\n\n".join(src_codes)

            MultiKernel.merge_workspaces_inplace(kernels)
            multi_kernel = MultiKernel(kernels)
            MultiKernel.merge_workspaces_inplace(list(kernels.values()))
            multi_kernel = SizeHintMultiKernel(kernels)
            node_schedule = [*prologue_nodes, template_node, *epilogue_nodes]
            self.codegen_comment(node_schedule)