Pian Pawakapan
2025-09-11 22:52:45 -07:00
parent 864ffe12d7
commit 4477015fc3
2 changed files with 84 additions and 25 deletions

torch/_inductor/codegen/multi_kernel.py

@@ -19,6 +19,9 @@ from .common import TensorArg, WorkspaceArg
log = logging.getLogger(__name__)
# When True, benchmark every sub-kernel for each new shape key; when False,
# fall back to the sub-kernel whose recorded size-hint key is nearest in L1
# distance (see MultiKernelCall._select_kernel_by_shape).
FULL_BENCHMARK = False
class MultiKernelState:
"""
Maintain state of multi-kernel compilation so we don't define duplicated
@@ -31,7 +34,7 @@ class MultiKernelState:
self.subkernel_to_kernel_name = {}
self.kernel_defs = IndentedBuffer()
def define_kernel(self, kernels):
def define_kernel(self, kernels, kernel_shape_keys=None):
"""
Previously we named the multi-kernel "multi_kernel_{kernel_names[0]}".
This had some minor issues.
@@ -85,7 +88,7 @@ class MultiKernelState:
for i in range(len(kernels)):
arg_index[i] = [slice(0, len(call_args))]
shape_specialize = isinstance(kernels[0], TritonTemplateKernel)
shape_specialize = kernel_shape_keys is not None
buf = self.kernel_defs
buf.writeline("")
buf.writeline("arg_index = {")
@@ -93,13 +96,25 @@ class MultiKernelState:
slice_reprs = ", ".join(repr(s) for s in slice_list)
buf.writeline(f" {key}: [{slice_reprs}],")
buf.writeline("}")
buf.writeline(
f"{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, ["
)
with buf.indent():
for name in kernel_names:
buf.writeline(f"{name},")
buf.writeline(f"], arg_index=arg_index, shape_specialize={shape_specialize})")
if not shape_specialize: # no size hint keys, just call with list of kernels
buf.writeline(
f"{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, ["
)
with buf.indent():
for name in kernel_names:
buf.writeline(f"{name},")
buf.writeline(f"], arg_index=arg_index, shape_specialize=False)")
else: # call with dict[size hint key, kernel]
assert isinstance(kernels[0], TritonTemplateKernel)
assert len(kernels) == len(kernel_shape_keys)
buf.writeline(
f"{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, {{"
)
with buf.indent():
for shape_key, name in zip(kernel_shape_keys, kernel_names):
buf.writeline(f"{shape_key}: {name},")
buf.writeline(f"}}, arg_index=arg_index, shape_specialize=True)")
if config.triton.autotune_at_compile_time:
V.graph.wrapper_code.src_to_kernel["\n".join(kernel_names)] = (
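To make the two emitted forms concrete, here is a minimal standalone sketch of the wrapper source each branch generates, using a plain io.StringIO instead of Inductor's IndentedBuffer; the kernel names and shape keys are hypothetical:

import io

def emit(kernel_names, kernel_shape_keys=None):
    # Mimics the two branches of MultiKernelState.define_kernel above.
    buf = io.StringIO()
    name = "multi_kernel_0"
    if kernel_shape_keys is None:
        # List form: kernel choice does not depend on shape.
        buf.write(f"{name} = async_compile.multi_kernel({name!r}, [\n")
        for k in kernel_names:
            buf.write(f"    {k},\n")
        buf.write("], arg_index=arg_index, shape_specialize=False)\n")
    else:
        # Dict form: each size-hint key maps to the kernel tuned for it.
        buf.write(f"{name} = async_compile.multi_kernel({name!r}, {{\n")
        for key, k in zip(kernel_shape_keys, kernel_names):
            buf.write(f"    {key}: {k},\n")
        buf.write("}, arg_index=arg_index, shape_specialize=True)\n")
    return buf.getvalue()

print(emit(["triton_mm_0", "triton_mm_1"],
           [((64, 64),), ((128, 128),)]))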
@@ -127,9 +142,18 @@ class MultiKernel:
def __init__(self, kernels):
assert len(kernels) >= 2
self.kernels = kernels
self.kernels = []
self.kernel_shape_keys = None
if isinstance(kernels, dict):
self.kernel_shape_keys = []
for shape_key, kernel in kernels.items():
self.kernels.append(kernel)
self.kernel_shape_keys.append(shape_key)
else:
self.kernels = kernels
self.kernel_name = V.graph.wrapper_code.multi_kernel_state.define_kernel(
kernels
self.kernels, self.kernel_shape_keys
)
# need this since some code in inductor checks if the kernel object has an args
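A standalone sketch of the dict-vs-list normalization above, with plain strings standing in for the kernel objects:

def split_kernels(kernels):
    # dict[shape_key, kernel] -> parallel lists, preserving insertion order;
    # a plain list passes through with no shape keys.
    if isinstance(kernels, dict):
        return list(kernels.values()), list(kernels.keys())
    return kernels, None

kernels, shape_keys = split_kernels({((64, 64),): "k0", ((128, 128),): "k1"})
assert kernels == ["k0", "k1"]
assert shape_keys == [((64, 64),), ((128, 128),)]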
@@ -268,7 +292,21 @@ class MultiKernelCall:
def __init__(self, multi_kernel_name, kernels, arg_index, shape_specialize=False):
assert len(kernels) >= 2
self._kernels = kernels
if isinstance(kernels, dict):
# This means that for each unique shape we do a separate assessment
# of which kernel is best. This is particularly useful for matmul
# kernels, where the best kernel can vary based on very small
# differences in shape.
self._shape_specialize = True
# Pre-seed the cache with the size-hint keys so a runtime shape that
# matches a hint exactly skips benchmarking.
self._shape_cache = {
    shape_key: i
    for i, shape_key in enumerate(kernels.keys())
    if shape_key is not None
}
self._kernels = list(kernels.values())
self._kernel_hints = list(kernels.keys())
else:
self._shape_specialize = False
self._shape_cache = {}
self._kernels = kernels
self._kernel_hints = None
self.multi_kernel_name = multi_kernel_name
self.disable_cache = os.environ.get(
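A sketch of the cache pre-seeding above: hint keys map straight to their sub-kernel index, so a runtime shape that matches a hint exactly never benchmarks; a None key, if present, carries no hint and is skipped (keys here are hypothetical):

kernels = {((64, 64),): "k0", None: "k_fallback", ((128, 128),): "k2"}
shape_cache = {
    shape_key: i
    for i, shape_key in enumerate(kernels.keys())
    if shape_key is not None
}
assert shape_cache == {((64, 64),): 0, ((128, 128),): 2}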
@@ -287,13 +325,6 @@ class MultiKernelCall:
self._recorded = False
# This means for each unique shape we will do a separate assessment
# for which kernel is the best. This is particularly useful for matmul
# kernels where the best kernel can vary based on very small differences
# in shape.
self._shape_specialize = shape_specialize
self._shape_cache = {}
def cache_file_path(self):
key = code_hash(
",".join(
@@ -418,7 +449,9 @@ class MultiKernelCall:
if self._shape_specialize:
cache_key = self._get_shape_cache_key(*args, **kwargs)
cached_choice = self._get_cached_shape_choice(cache_key)
print("running with cache key", cache_key)
if cached_choice is not None:
print("found cached choice", cached_choice)
self.picked_kernel = cached_choice
log.debug(
"using cached shape-specialized choice %dth sub-kernel in %s. Cache key: %s",
@@ -482,14 +515,29 @@ class MultiKernelCall:
"""
self._shape_cache[cache_key] = kernel_idx
def _l1_dist(self, k1, k2):
    """
    L1 distance between two shape cache keys, where each key is a
    tuple of per-tensor shape tuples.
    """
    dist = 0
    for s1, s2 in zip(k1, k2):
        dist += sum(abs(x - y) for x, y in zip(s1, s2))
    return dist
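A worked example of this distance: each key is a tuple of per-tensor shape tuples, and mismatches are summed elementwise:

def l1_dist(k1, k2):
    return sum(
        abs(x - y)
        for s1, s2 in zip(k1, k2)
        for x, y in zip(s1, s2)
    )

a = ((64, 64), (64, 32))
b = ((70, 64), (70, 32))
assert l1_dist(a, b) == 12  # 6 + 0 + 6 + 0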
def _select_kernel_by_shape(self, *args, **kwargs):
"""
Benchmark kernels for a particular shape and return the
best kernel for this shape.
"""
assert self._shape_specialize and self._kernel_hints is not None
shape_key = self._get_shape_cache_key(*args, **kwargs)
timings = self.benchmark_sub_kernels(*args, **kwargs)
self.picked_kernel = timings.index(min(timings))
if FULL_BENCHMARK:
timings = self.benchmark_sub_kernels(*args, **kwargs)
self.picked_kernel = timings.index(min(timings))
else:
# A hint key of None carries no size information; give it a huge
# sentinel distance so it is only picked if no hinted key exists.
dists = [
    self._l1_dist(shape_key, key) if key is not None else 2**64
    for key in self._kernel_hints
]
self.picked_kernel = dists.index(min(dists))
print(f"Selected kernel index {self.picked_kernel} for fresh key {shape_key} based on existing key {self._kernel_hints[self.picked_kernel]}")
self._cache_shape_choice(shape_key, self.picked_kernel)
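A sketch of the nearest-hint fallback above, with hypothetical hint keys: without a full benchmark, the sub-kernel whose hint key is closest in L1 distance wins, and hint-less (None) entries rank last via the sentinel:

def l1_dist(k1, k2):
    return sum(abs(x - y) for s1, s2 in zip(k1, k2) for x, y in zip(s1, s2))

hints = [((64, 64),), ((256, 256),), None]
fresh = ((96, 96),)
dists = [l1_dist(fresh, h) if h is not None else 2**64 for h in hints]
assert dists.index(min(dists)) == 0  # (64, 64) is 64 away; (256, 256) is 320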
def _metrics_table_row(self, timings):

torch/_inductor/codegen/simd.py

@@ -1692,6 +1692,16 @@ class SIMDScheduling(BaseScheduling):
return kernel
def _make_shape_cache_key(self, node: MultiTemplateBuffer, hint: int) -> tuple[tuple[int, ...], ...]:
"""
Returns the cache key for hint-based multi-kernel dispatch: a tuple of
the input and output shapes with every symbolic dimension replaced by the hint.
"""
out = []
for arg in list(node.inputs) + [node]:
if (size := arg.maybe_get_size()) is not None:
out.append(tuple(hint if isinstance(s, sympy.Symbol) else s for s in size))
return tuple(out)
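A standalone sketch of the key construction, where the sizes list stands in for the maybe_get_size() results of node.inputs plus the node itself:

import sympy

def make_shape_cache_key(sizes, hint):
    # Replace every symbolic dimension with the integer hint; concrete
    # dimensions pass through unchanged.
    return tuple(
        tuple(hint if isinstance(s, sympy.Symbol) else s for s in size)
        for size in sizes
    )

s0 = sympy.Symbol("s0")
key = make_shape_cache_key([(s0, 64), (64, 32), (s0, 32)], hint=128)
assert key == ((128, 64), (64, 32), (128, 32))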
def codegen_template(
self,
template_node,
@@ -1715,10 +1725,10 @@ class SIMDScheduling(BaseScheduling):
isinstance(template_node.node, MultiTemplateBuffer)
and template_node.node._make_kernel_renders
):
kernels = []
kernels = {}
src_codes = []
for make_kernel_render in template_node.node._make_kernel_renders.values():
for size_hint, make_kernel_render in template_node.node._make_kernel_renders.items():
kernel, render = make_kernel_render(
template_node.node, hint_override=hint_override
)
@@ -1743,12 +1753,13 @@ class SIMDScheduling(BaseScheduling):
prologue_nodes,
only_gen_src_code=False,
)
kernels.append(kernel)
shape_cache_key = (
    None
    if size_hint is None
    else self._make_shape_cache_key(template_node.node, size_hint)
)
kernels[shape_cache_key] = kernel
if only_gen_src_code:
return "\n\n".join(src_codes)
MultiKernel.merge_workspaces_inplace(kernels)
MultiKernel.merge_workspaces_inplace(list(kernels.values()))
multi_kernel = MultiKernel(kernels)
node_schedule = [*prologue_nodes, template_node, *epilogue_nodes]
self.codegen_comment(node_schedule)