# multi-kernel matmuls based on varying hint sizes (#156628)
The core idea is to generate multiple matmul kernels using different hints for the symbolic variables, then select the most appropriate one at runtime for each unique shape we encounter. You can find some early experimentation details in these posts:

- https://fb.workplace.com/groups/8940092306109185/posts/9803850776399996/
- https://fb.workplace.com/groups/8940092306109185/posts/9695805170537891/
- https://fb.workplace.com/groups/257735836456307/posts/906589324904285/

Four graphs (not reproduced here) summarize the motivation:

- The empirically observed worst-case performance if an oracle always selected the least optimal hint for a given runtime size.
- The performance of a hint size of 64 relative to the worst case: as the runtime sizes increase, the performance gradually approaches the worst case.
- The performance of a hint size of 4096: very poor for small sizes, and also suboptimal for some mid-sized shapes.
- Finally, the graph that motivated this PR: the performance when selecting the best of three kernels generated with three different hints (64, 256, and 4096).

## How to review this PR

At a high level, this extends @shunting314's multi-kernel abstraction to support varying GEMM choices driven by different hints. A few key points:

1. Unlike reduction kernels, triton template matmuls pass their grid as arguments to the kernel. This PR updates `MultiKernelCall` to support kernels with varying arguments.
2. The `V.graph.sizevars.size_hints` API is extended to accept a `hint_override`, allowing us to substitute the example input's size hint with a custom value when generating multiple kernels.
3. The choice generation and benchmarking logic is updated to support multiple hint values. One kernel is generated per value in `torch._inductor.config.multi_kernel_hints`, and at runtime we select the most suitable kernel for the current shape.
4. This PR does not add support for cpp wrapper codegen, to keep it scoped. That will be added in the next PR.

## Results

The following basic test shows the multi-kernel dispatch working: we no longer see significant variance based on the original hint size: https://gist.github.com/bobrenjc93/ba711d529e65fd65839b34799f6323ec

Before:

```
Hint\Runtime |   64   |  256   |  4096
---------------------------------------
64           | 0.0948 | 0.3124 | 4.9477
256          | 0.2243 | 0.2256 | 3.3880
4096         | 0.3384 | 0.3404 | 3.3010
```

After:

```
Hint\Runtime |   64   |  256   |  4096
---------------------------------------
64           | 0.0951 | 0.2289 | 3.3013
256          | 0.0952 | 0.2258 | 3.4045
4096         | 0.0957 | 0.2231 | 3.3146
```

We also see an average speedup of 5.04% for the matrix of all hint/runtime pairs in [64, 4096] for every increment of 64: https://docs.google.com/spreadsheets/d/12TmYUDrAAFASGuP3POXTKPeAvQWIRzKzdrVSIb3vQkA/edit?gid=480268938#gid=480268938

NB: This is just the beginning, and I plan to investigate further to improve on this initial result. For posterity, the script used to generate that matrix is here: https://gist.github.com/bobrenjc93/c211fd0bd97fad8f46b91ad9dee76ad0

HUD benchmark runs:
- base: https://github.com/pytorch/pytorch/actions/runs/15889871988
- head: https://github.com/pytorch/pytorch/actions/runs/15889876842

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156628
Approved by: https://github.com/jansel
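For context, here is a minimal usage sketch modeled on the test added in this PR. The config knobs (`multi_kernel_hints`, `benchmark_kernel`, `triton.multi_kernel`) and the compile options come from the diff below; `dynamic=True`, the CUDA device, and the loop over sizes are illustrative assumptions rather than part of the PR's test.

```python
import torch
import torch._inductor.config as inductor_config

# One hint-specialized GEMM kernel is generated per value here (taken from the
# test in this PR); at runtime MultiKernelCall picks the best one per shape.
inductor_config.multi_kernel_hints = [64, 256, 4096]
inductor_config.benchmark_kernel = True
inductor_config.triton.multi_kernel = 1


def fn(x, y):
    return x @ y


compiled_fn = torch.compile(
    fn,
    dynamic=True,  # assumption: keep the GEMM sizes symbolic so the hints matter
    options={
        "max_autotune": True,
        "max_autotune_gemm_backends": "TRITON",
    },
)

# For each unique runtime shape, the generated multi-kernel benchmarks its
# hint-specialized sub-kernels once and caches the winner for reuse.
for n in (64, 256, 4096):
    x = torch.randn(n, n, device="cuda")
    y = torch.randn(n, n, device="cuda")
    compiled_fn(x, y)
```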
Committed by: PyTorch MergeBot
Parent: 191693ac85
Commit: 5221448574
@ -51,7 +51,7 @@ aten = torch.ops.aten
|
||||
|
||||
|
||||
def patches(fn):
|
||||
def skip_cache(self, choices, name, key, benchmark):
|
||||
def skip_cache(self, choices, name, key, benchmark, hint_override=None):
|
||||
if benchmark is None:
|
||||
return {}
|
||||
timings = benchmark(choices)
|
||||
|
@ -1205,7 +1205,7 @@ class TestMaxAutotune(TestCase):
|
||||
# Make sure all args of generate_and_load_args are passed to make_key_args (Except generate_with_caching)
|
||||
# update this function each time a new arg is added to generate_and_load and make sure the arg is added to make_key
|
||||
self.assertEqual(generate_and_load_args - 1, make_key_args)
|
||||
self.assertEqual(generate_and_load_args, 16)
|
||||
self.assertEqual(generate_and_load_args, 17)
|
||||
|
||||
@fresh_cache()
|
||||
@config.patch(
|
||||
@ -1293,7 +1293,7 @@ class TestMaxAutotune(TestCase):
|
||||
'layout':"[[10,30],[30,1],torch.float32,device(type='cuda',index=0),0]",
|
||||
'num_consumer_groups':0,'num_buffers_warp_spec':0,'epilogue_fn_hash':'identity',
|
||||
'kwargs':{'EVEN_K':False,'ALLOW_TF32':True,'USE_FAST_ACCUM':False,'ACC_TYPE':'tl.float32',
|
||||
'BLOCK_M':16,'BLOCK_N':32,'BLOCK_K':16,'GROUP_M':8}}"""
|
||||
'BLOCK_M':16,'BLOCK_N':32,'BLOCK_K':16,'GROUP_M':8},'hint_override':None}"""
|
||||
|
||||
expected = expected.replace("cuda", GPU_TYPE)
|
||||
self.assertExpectedInline(
|
||||
@ -1332,7 +1332,7 @@ class TestMaxAutotune(TestCase):
|
||||
'num_stages':1,'num_warps':2,'prefix_args':0,'suffix_args':0,'call_sizes':[s77,s94],
|
||||
'layout':"[[s77,s94],[s94,1],torch.float32,device(type='cuda',index=0),0]",'num_consumer_groups':0,
|
||||
'num_buffers_warp_spec':0,'epilogue_fn_hash':'identity','kwargs':{'EVEN_K':False,'ALLOW_TF32':True,
|
||||
'USE_FAST_ACCUM':False,'ACC_TYPE':'tl.float32','BLOCK_M':16,'BLOCK_N':32,'BLOCK_K':16,'GROUP_M':8}}"""
|
||||
'USE_FAST_ACCUM':False,'ACC_TYPE':'tl.float32','BLOCK_M':16,'BLOCK_N':32,'BLOCK_K':16,'GROUP_M':8},'hint_override':None}"""
|
||||
expected = expected.replace("cuda", GPU_TYPE)
|
||||
self.assertExpectedInline(
|
||||
remove_white_space(cache_key),
|
||||
@ -1585,6 +1585,7 @@ class TestMaxAutotunePrecompile(TestCase):
|
||||
op: str,
|
||||
inputs: str,
|
||||
benchmark: Callable[[Any], dict[ChoiceCaller, float]],
|
||||
hint_override: Optional[int] = None,
|
||||
) -> Optional[dict[ChoiceCaller, float]]:
|
||||
if benchmark is not None:
|
||||
return benchmark(choices)
|
||||
|
@ -18,7 +18,12 @@ from torch.testing._internal.common_utils import (
|
||||
parametrize,
|
||||
skipIfXpu,
|
||||
)
|
||||
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
|
||||
from torch.testing._internal.inductor_utils import (
|
||||
GPU_TYPE,
|
||||
HAS_GPU,
|
||||
IS_BIG_GPU,
|
||||
requires_triton,
|
||||
)
|
||||
|
||||
|
||||
class TransformerSnippet(nn.Module):
|
||||
@ -71,6 +76,7 @@ def make_cpp_wrapper_test(orig_test, **extra_args):
|
||||
{
|
||||
"triton.multi_kernel": int(os.environ.get("TORCHINDUCTOR_MULTI_KERNEL", "1")),
|
||||
"benchmark_kernel": True,
|
||||
"multi_kernel_hints": [64, 256, 4096],
|
||||
}
|
||||
)
|
||||
@instantiate_parametrized_tests
|
||||
@ -91,6 +97,56 @@ class MultiKernelTest(TestCase):
|
||||
else:
|
||||
self.assertFalse(_contains_multi_kernel_code(wrapper_code))
|
||||
|
||||
@requires_triton()
|
||||
@unittest.skipIf(not IS_BIG_GPU, "templates require big gpu")
|
||||
def test_triton_gemm(self):
|
||||
def fn(x, y):
|
||||
return x @ y
|
||||
|
||||
compiled_fn = torch.compile(
|
||||
fn,
|
||||
options={
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "TRITON",
|
||||
},
|
||||
)
|
||||
x = torch.randn(4096, 4096, device=GPU_TYPE)
|
||||
y = torch.randn(4096, 4096, device=GPU_TYPE)
|
||||
act, wrapper_code = run_and_get_code(compiled_fn, x, y)
|
||||
ref = fn(x, y)
|
||||
|
||||
# wrapper_code will contain 2 entries if cpp_wrapper=True.
|
||||
# One for the first pass and one for the second pass.
|
||||
# We mainly care about the wrapper for the final pass here.
|
||||
wrapper_code = wrapper_code[-1]
|
||||
self.assertEqual(ref, act)
|
||||
self.assertTrue(_contains_multi_kernel_code(wrapper_code))
|
||||
|
||||
@requires_triton()
|
||||
@unittest.skipIf(not IS_BIG_GPU, "templates require big gpu")
|
||||
def test_triton_relu_fused_gemm(self):
|
||||
def fn(x, y):
|
||||
return (x @ y).relu()
|
||||
|
||||
compiled_fn = torch.compile(
|
||||
fn,
|
||||
options={
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "TRITON",
|
||||
},
|
||||
)
|
||||
x = torch.randn(4096, 4096, device=GPU_TYPE)
|
||||
y = torch.randn(4096, 4096, device=GPU_TYPE)
|
||||
act, wrapper_code = run_and_get_code(compiled_fn, x, y)
|
||||
ref = fn(x, y)
|
||||
|
||||
# wrapper_code will contain 2 entries if cpp_wrapper=True.
|
||||
# One for the first pass and one for the second pass.
|
||||
# We mainly care about the wrapper for the final pass here.
|
||||
wrapper_code = wrapper_code[-1]
|
||||
self.assertEqual(ref, act)
|
||||
self.assertTrue(_contains_multi_kernel_code(wrapper_code))
|
||||
|
||||
@parametrize("force_kernel", (0, 1))
|
||||
@unittest.mock.patch.dict(
|
||||
os.environ, {"TORCHINDUCTOR_DISABLE_MULTI_KERNEL_CACHE": "1"}
|
||||
|
@ -20,7 +20,7 @@ aten = torch.ops.aten
|
||||
|
||||
|
||||
def patches(fn):
|
||||
def skip_cache(self, choices, name, key, benchmark):
|
||||
def skip_cache(self, choices, name, key, benchmark, hint_override=None):
|
||||
if benchmark is None:
|
||||
return {}
|
||||
return benchmark(choices)
|
||||
|
@ -265,6 +265,7 @@ class PersistentCache(CacheBase):
|
||||
op: str,
|
||||
inputs: str,
|
||||
benchmark: Optional[Callable[[Any], dict[ChoiceCaller, float]]],
|
||||
hint_override: Optional[int] = None,
|
||||
) -> dict[ChoiceCaller, float]:
|
||||
"""
|
||||
Check to see if we have benchmarked the given choice callers. For each
|
||||
@ -277,6 +278,7 @@ class PersistentCache(CacheBase):
|
||||
b. `max_autotune_gemm=False`: don't benchmark the choice, return nothing.
|
||||
"""
|
||||
precision = torch.get_float32_matmul_precision()
|
||||
cache_key = f"{inputs}_{hint_override}" if hint_override is not None else inputs
|
||||
|
||||
timings = {}
|
||||
|
||||
@ -285,9 +287,11 @@ class PersistentCache(CacheBase):
|
||||
hit = True
|
||||
for choice in choices:
|
||||
choice_hash = choice.hash_key()
|
||||
if choice_hash in cache.get(op, {}).get(inputs, {}).get(precision, {}):
|
||||
if choice_hash in cache.get(op, {}).get(cache_key, {}).get(
|
||||
precision, {}
|
||||
):
|
||||
# cache hit
|
||||
timings[choice] = cache[op][inputs][precision][choice_hash]
|
||||
timings[choice] = cache[op][cache_key][precision][choice_hash]
|
||||
else:
|
||||
# cache miss
|
||||
hit = False
|
||||
@ -300,9 +304,9 @@ class PersistentCache(CacheBase):
|
||||
timings = benchmark(choices)
|
||||
assert all(choice in timings for choice in choices)
|
||||
local_cache.setdefault(op, {})
|
||||
local_cache[op].setdefault(inputs, {}).setdefault(precision, {})
|
||||
local_cache[op].setdefault(cache_key, {}).setdefault(precision, {})
|
||||
for choice, timing in timings.items():
|
||||
local_cache[op][inputs][precision][choice.hash_key()] = timing
|
||||
local_cache[op][cache_key][precision][choice.hash_key()] = timing
|
||||
|
||||
self.update_local_cache(local_cache)
|
||||
|
||||
|
@ -124,10 +124,13 @@ class CUDACombinedScheduling(BaseScheduling):
|
||||
return self._triton_scheduling.benchmark_codegened_module(module)
|
||||
|
||||
def generate_kernel_code_from_nodes(
|
||||
self, nodes: Sequence[Any], benchmark_kernel: bool = False
|
||||
self,
|
||||
nodes: Sequence[Any],
|
||||
benchmark_kernel: bool = False,
|
||||
hint_override: Optional[int] = None,
|
||||
) -> str:
|
||||
return self._triton_scheduling.generate_kernel_code_from_nodes(
|
||||
nodes, benchmark_kernel
|
||||
nodes, benchmark_kernel, hint_override=hint_override
|
||||
)
|
||||
|
||||
def benchmark_combo_kernel(
|
||||
|
@ -4,6 +4,7 @@ import logging
|
||||
import os
|
||||
import pathlib
|
||||
|
||||
from torch._inductor.ir import MultiTemplateBuffer
|
||||
from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
@ -45,6 +46,9 @@ class MultiKernelState:
|
||||
|
||||
We should name the multi-kernel differently in these 2 cases.
|
||||
"""
|
||||
# Prevent circular import
|
||||
from ..select_algorithm import TritonTemplateKernel
|
||||
|
||||
kernel_names = tuple(k.kernel_name for k in kernels)
|
||||
if kernel_names in self.subkernel_to_kernel_name:
|
||||
return self.subkernel_to_kernel_name[kernel_names]
|
||||
@ -58,15 +62,44 @@ class MultiKernelState:
|
||||
# the second pass of cpp-wrapper.
|
||||
return multi_kernel_name
|
||||
|
||||
arg_index: dict[int, list[slice]] = {}
|
||||
_, call_args, _, arg_types = kernels[0].args.python_argdefs()
|
||||
if isinstance(kernels[0], TritonTemplateKernel) and isinstance(
|
||||
kernels[0].output_node, MultiTemplateBuffer
|
||||
):
|
||||
for i, kernel in enumerate(kernels):
|
||||
additional_call_args, additional_arg_types = (
|
||||
kernel.additional_call_args_and_types()
|
||||
)
|
||||
if i not in arg_index:
|
||||
arg_index[i] = []
|
||||
arg_index[i].append(slice(0, len(call_args)))
|
||||
arg_index[i].append(
|
||||
slice(
|
||||
len(call_args) + i * len(additional_call_args),
|
||||
len(call_args) + (i + 1) * len(additional_call_args),
|
||||
)
|
||||
)
|
||||
else:
|
||||
kernels[0].add_numel_to_call_args(multi_kernel_name, call_args, arg_types)
|
||||
for i in range(len(kernels)):
|
||||
arg_index[i] = [slice(0, len(call_args))]
|
||||
|
||||
shape_specialize = isinstance(kernels[0], TritonTemplateKernel)
|
||||
buf = self.kernel_defs
|
||||
buf.writeline("")
|
||||
buf.writeline("arg_index = {")
|
||||
for key, slice_list in arg_index.items():
|
||||
slice_reprs = ", ".join(repr(s) for s in slice_list)
|
||||
buf.writeline(f" {key}: [{slice_reprs}],")
|
||||
buf.writeline("}")
|
||||
buf.writeline(
|
||||
f"{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, ["
|
||||
)
|
||||
with buf.indent():
|
||||
for name in kernel_names:
|
||||
buf.writeline(f"{name},")
|
||||
buf.writeline("])")
|
||||
buf.writeline(f"], arg_index=arg_index, shape_specialize={shape_specialize})")
|
||||
|
||||
if config.triton.autotune_at_compile_time:
|
||||
V.graph.wrapper_code.src_to_kernel["\n".join(kernel_names)] = (
|
||||
@ -135,6 +168,9 @@ class MultiKernel:
|
||||
Collect the union of arguments from all subkernels as the arguments
|
||||
for the multi-kernel.
|
||||
"""
|
||||
# Prevent circular import
|
||||
from ..select_algorithm import TritonTemplateKernel
|
||||
|
||||
assert kernel_name == self.kernel_name
|
||||
V.graph.wrapper_code.write_triton_header_once()
|
||||
_, call_args, _, arg_types = self.kernels[0].args.python_argdefs()
|
||||
@ -148,17 +184,42 @@ class MultiKernel:
|
||||
# the fast kernel directly
|
||||
kernel_name = MultiKernelCall.lookup_choice(self.kernel_name)
|
||||
|
||||
# numels for all subkernels should be the same. Use kernels[0] here
|
||||
self.kernels[0].add_numel_to_call_args(kernel_name, call_args, arg_types)
|
||||
if isinstance(self.kernels[0], TritonTemplateKernel) and isinstance(
|
||||
self.kernels[0].output_node, MultiTemplateBuffer
|
||||
):
|
||||
# For matmuls the grid arguments are passed in as additional arguments
|
||||
# to the kernel run method. These grids change based on the various
|
||||
# parameters of the matmul. So we need to pass each kernel's grid into
|
||||
# the multi call kernel.
|
||||
multi_call_args = call_args
|
||||
multi_call_arg_types = arg_types
|
||||
for i, kernel in enumerate(self.kernels):
|
||||
additional_call_args, additional_arg_types = (
|
||||
kernel.additional_call_args_and_types()
|
||||
)
|
||||
multi_call_args.extend(list(additional_call_args))
|
||||
multi_call_arg_types.extend(list(additional_arg_types))
|
||||
else:
|
||||
# numels for all subkernels should be the same. Use kernels[0] here
|
||||
self.kernels[0].add_numel_to_call_args(kernel_name, call_args, arg_types)
|
||||
multi_call_args = call_args
|
||||
multi_call_arg_types = arg_types
|
||||
|
||||
for ws in self.kernels[0].args.workspace_args:
|
||||
V.graph.wrapper_code.generate_workspace_allocation(ws)
|
||||
|
||||
V.graph.wrapper_code.generate_kernel_call(
|
||||
kernel_name,
|
||||
call_args,
|
||||
arg_types=arg_types,
|
||||
)
|
||||
if V.graph.cpp_wrapper:
|
||||
# We have already selected the best kernel at compile time
|
||||
# so we only have one set of call args. NB: this currently
|
||||
# doesn't work with MultiTemplateBuffer kernels. @bobrenjc93
|
||||
# will add it in a subsequent PR.
|
||||
V.graph.wrapper_code.generate_kernel_call(
|
||||
kernel_name, call_args, arg_types=arg_types
|
||||
)
|
||||
else:
|
||||
V.graph.wrapper_code.generate_kernel_call(
|
||||
kernel_name, multi_call_args, arg_types=multi_call_arg_types
|
||||
)
|
||||
|
||||
for ws in reversed(self.kernels[0].args.workspace_args):
|
||||
V.graph.wrapper_code.generate_workspace_deallocation(ws)
|
||||
@ -205,7 +266,7 @@ class MultiKernelCall:
|
||||
This class is called at run time to actually run the kernel
|
||||
"""
|
||||
|
||||
def __init__(self, multi_kernel_name, kernels):
|
||||
def __init__(self, multi_kernel_name, kernels, arg_index, shape_specialize=False):
|
||||
assert len(kernels) >= 2
|
||||
self._kernels = kernels
|
||||
self.multi_kernel_name = multi_kernel_name
|
||||
@ -215,6 +276,7 @@ class MultiKernelCall:
|
||||
) == "1" or is_metric_table_enabled("persistent_red_perf")
|
||||
|
||||
self.picked_kernel = None
|
||||
self.arg_index = arg_index
|
||||
if config.triton.multi_kernel > 1:
|
||||
# manually force a subkernel to ease perf testing
|
||||
picked_by_config = config.triton.multi_kernel - 2
|
||||
@ -225,6 +287,13 @@ class MultiKernelCall:
|
||||
|
||||
self._recorded = False
|
||||
|
||||
# This means for each unique shape we will do a separate assessment
|
||||
# for which kernel is the best. This is particularly useful for matmul
|
||||
# kernels where the best kernel can vary based on very small differences
|
||||
# in shape.
|
||||
self._shape_specialize = shape_specialize
|
||||
self._shape_cache = {}
|
||||
|
||||
def cache_file_path(self):
|
||||
key = code_hash(
|
||||
",".join(
|
||||
@ -282,18 +351,34 @@ class MultiKernelCall:
|
||||
be picked.
|
||||
"""
|
||||
|
||||
def wrap_fn(kernel):
|
||||
def wrap_fn(kernel, index):
|
||||
def inner():
|
||||
args_clone, kwargs_clone = kernel.clone_args(*args, **kwargs)
|
||||
filtered_args = self._get_filtered_args(args, index)
|
||||
args_clone, kwargs_clone = kernel.clone_args(*filtered_args, **kwargs)
|
||||
return kernel.run(*args_clone, **kwargs_clone)
|
||||
|
||||
return inner
|
||||
|
||||
return [
|
||||
benchmarker.benchmark_gpu(wrap_fn(kernel), rep=40)
|
||||
for kernel in self.kernels
|
||||
benchmarker.benchmark_gpu(wrap_fn(kernel, index), rep=40)
|
||||
for index, kernel in enumerate(self.kernels)
|
||||
]
|
||||
|
||||
def _get_filtered_args(self, args, index):
|
||||
"""
|
||||
We pass in all arguments to all kernels into the MultiKernelCall
|
||||
so when invoking a particular kernel we need to filter to only the
|
||||
arguments for that specific kernel.
|
||||
"""
|
||||
|
||||
# This is sometimes invoked at runtime where V.graph is
|
||||
# a NullHandler
|
||||
if hasattr(V.graph, "cpp_wrapper") and V.graph.cpp_wrapper:
|
||||
# for cpp-wrapper, we should not filter the args since
|
||||
# we already have chosen a single kernel and arg set.
|
||||
return args
|
||||
return [item for s in self.arg_index[index] for item in args[s]]
|
||||
|
||||
# record_choice and lookup_choice are helper functions for cpp-wrapper
|
||||
# codegen. The first pass uses record_choice to keep the choice and
|
||||
# the second pass does the lookup by calling lookup_choice.
|
||||
@ -330,6 +415,20 @@ class MultiKernelCall:
|
||||
return V.graph.multi_kernel_to_choice[multi_kernel_name]
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
if self._shape_specialize:
|
||||
cache_key = self._get_shape_cache_key(*args, **kwargs)
|
||||
cached_choice = self._get_cached_shape_choice(cache_key)
|
||||
if cached_choice is not None:
|
||||
self.picked_kernel = cached_choice
|
||||
log.debug(
|
||||
"using cached shape-specialized choice %dth sub-kernel in %s. Cache key: %s",
|
||||
self.picked_kernel,
|
||||
[k.inductor_meta.get("kernel_name") for k in self.kernels],
|
||||
cache_key,
|
||||
)
|
||||
else:
|
||||
self._select_kernel_by_shape(*args, **kwargs)
|
||||
|
||||
if self.picked_kernel is None:
|
||||
timings = self.benchmark_sub_kernels(*args, **kwargs)
|
||||
self.picked_kernel = timings.index(min(timings))
|
||||
@ -345,6 +444,7 @@ class MultiKernelCall:
|
||||
get_metric_table("persistent_red_perf").add_row(
|
||||
functools.partial(self._metrics_table_row, timings)
|
||||
)
|
||||
|
||||
if not self.disable_cache:
|
||||
self.store_cache()
|
||||
|
||||
@ -355,8 +455,42 @@ class MultiKernelCall:
|
||||
)
|
||||
assert picked_kernel_name is not None
|
||||
self.record_choice(self.multi_kernel_name, picked_kernel_name)
|
||||
self.run = self.kernels[self.picked_kernel].run # type: ignore[method-assign]
|
||||
self.run(*args, **kwargs)
|
||||
|
||||
run = self.kernels[self.picked_kernel].run # type: ignore[method-assign]
|
||||
filtered_args = self._get_filtered_args(args, self.picked_kernel)
|
||||
run(*filtered_args, **kwargs)
|
||||
|
||||
def _get_shape_cache_key(self, *args, **kwargs):
|
||||
"""
|
||||
Generate a cache key based on tensor shapes for shape-specialized dispatch.
|
||||
"""
|
||||
shapes = []
|
||||
for arg in args:
|
||||
if hasattr(arg, "shape"):
|
||||
shapes.append(tuple(arg.shape))
|
||||
return tuple(shapes)
|
||||
|
||||
def _get_cached_shape_choice(self, cache_key):
|
||||
"""
|
||||
Get cached kernel choice for a specific shape.
|
||||
"""
|
||||
return self._shape_cache.get(cache_key)
|
||||
|
||||
def _cache_shape_choice(self, cache_key, kernel_idx):
|
||||
"""
|
||||
Cache kernel choice for a specific shape
|
||||
"""
|
||||
self._shape_cache[cache_key] = kernel_idx
|
||||
|
||||
def _select_kernel_by_shape(self, *args, **kwargs):
|
||||
"""
|
||||
Benchmark kernels for a particular shape and return the
|
||||
best kernel for this shape.
|
||||
"""
|
||||
shape_key = self._get_shape_cache_key(*args, **kwargs)
|
||||
timings = self.benchmark_sub_kernels(*args, **kwargs)
|
||||
self.picked_kernel = timings.index(min(timings))
|
||||
self._cache_shape_choice(shape_key, self.picked_kernel)
|
||||
|
||||
def _metrics_table_row(self, timings):
|
||||
def get_kernel_path(k):
|
||||
|
@ -18,6 +18,7 @@ import sympy
|
||||
|
||||
import torch
|
||||
import torch._logging
|
||||
from torch._inductor.ir import MultiTemplateBuffer
|
||||
from torch._inductor.tiling_utils import analyze_memory_coalescing
|
||||
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
|
||||
from torch.fx.immutable_collections import immutable_dict
|
||||
@ -1548,18 +1549,19 @@ class SIMDScheduling(BaseScheduling):
|
||||
index_vars = kernel.split_and_set_ranges(node.get_ranges())
|
||||
node.codegen(index_vars)
|
||||
|
||||
def codegen_template(
|
||||
self, template_node, epilogue_nodes, prologue_nodes, *, only_gen_src_code=False
|
||||
) -> Optional[str]:
|
||||
def _codegen_single_template(
|
||||
self,
|
||||
kernel,
|
||||
render,
|
||||
template_node,
|
||||
epilogue_nodes,
|
||||
prologue_nodes,
|
||||
*,
|
||||
only_gen_src_code=False,
|
||||
):
|
||||
"""
|
||||
Codegen a triton template
|
||||
|
||||
If `only_gen_src_code` the src code will be returned instead of codegen'd into the wrapper
|
||||
Helper method to codegen a single template kernel variant
|
||||
"""
|
||||
_, (_numel, rnumel) = template_node.group
|
||||
assert rnumel == 1
|
||||
kernel, render = template_node.node.make_kernel_render(template_node.node)
|
||||
|
||||
buf_name_to_prologue_group = {}
|
||||
template_reads = template_node.used_buffer_names()
|
||||
prologue_group = []
|
||||
@ -1655,18 +1657,113 @@ class SIMDScheduling(BaseScheduling):
|
||||
if only_gen_src_code:
|
||||
return src_code
|
||||
|
||||
kernel_name = self.define_kernel(src_code, node_schedule, kernel)
|
||||
kernel.kernel_name = self.define_kernel(src_code, node_schedule, kernel)
|
||||
|
||||
if config.trace.enabled:
|
||||
set_kernel_post_grad_provenance_tracing(node_schedule, kernel_name)
|
||||
set_kernel_post_grad_provenance_tracing(
|
||||
node_schedule, kernel.kernel_name
|
||||
)
|
||||
|
||||
self.codegen_comment(node_schedule)
|
||||
kernel.call_kernel(kernel_name, template_node.node)
|
||||
return kernel
|
||||
|
||||
V.graph.removed_buffers |= kernel.removed_buffers
|
||||
V.graph.inplaced_to_remove |= kernel.inplaced_to_remove
|
||||
self.free_buffers_in_scheduler()
|
||||
return None
|
||||
def codegen_template(
|
||||
self,
|
||||
template_node,
|
||||
epilogue_nodes,
|
||||
prologue_nodes,
|
||||
*,
|
||||
only_gen_src_code=False,
|
||||
hint_override: Optional[int] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Codegen a triton template with multi-kernel dispatch support
|
||||
|
||||
If `only_gen_src_code=True` the src code will be returned instead of being
|
||||
codegenned into the wrapper
|
||||
"""
|
||||
|
||||
_, (_numel, rnumel) = template_node.group
|
||||
assert rnumel == 1
|
||||
|
||||
if (
|
||||
isinstance(template_node.node, MultiTemplateBuffer)
|
||||
and template_node.node._make_kernel_renders
|
||||
):
|
||||
kernels = []
|
||||
src_codes = []
|
||||
|
||||
for make_kernel_render in template_node.node._make_kernel_renders.values():
|
||||
kernel, render = make_kernel_render(
|
||||
template_node.node, hint_override=hint_override
|
||||
)
|
||||
|
||||
if only_gen_src_code:
|
||||
src_code = self._codegen_single_template(
|
||||
kernel,
|
||||
render,
|
||||
template_node,
|
||||
epilogue_nodes,
|
||||
prologue_nodes,
|
||||
only_gen_src_code=True,
|
||||
)
|
||||
assert isinstance(src_code, str)
|
||||
src_codes.append(src_code)
|
||||
else:
|
||||
kernel = self._codegen_single_template(
|
||||
kernel,
|
||||
render,
|
||||
template_node,
|
||||
epilogue_nodes,
|
||||
prologue_nodes,
|
||||
only_gen_src_code=False,
|
||||
)
|
||||
kernels.append(kernel)
|
||||
|
||||
if only_gen_src_code:
|
||||
return "\n\n".join(src_codes)
|
||||
|
||||
MultiKernel.merge_workspaces_inplace(kernels)
|
||||
multi_kernel = MultiKernel(kernels)
|
||||
node_schedule = [*prologue_nodes, template_node, *epilogue_nodes]
|
||||
self.codegen_comment(node_schedule)
|
||||
|
||||
multi_kernel.call_kernel(multi_kernel.kernel_name)
|
||||
V.graph.removed_buffers |= multi_kernel.removed_buffers
|
||||
V.graph.inplaced_to_remove |= multi_kernel.inplaced_to_remove
|
||||
self.free_buffers_in_scheduler()
|
||||
return None
|
||||
else:
|
||||
kernel, render = template_node.node.make_kernel_render(
|
||||
template_node.node, hint_override=hint_override
|
||||
)
|
||||
|
||||
if only_gen_src_code:
|
||||
return self._codegen_single_template(
|
||||
kernel,
|
||||
render,
|
||||
template_node,
|
||||
epilogue_nodes,
|
||||
prologue_nodes,
|
||||
only_gen_src_code=True,
|
||||
)
|
||||
else:
|
||||
kernel = self._codegen_single_template(
|
||||
kernel,
|
||||
render,
|
||||
template_node,
|
||||
epilogue_nodes,
|
||||
prologue_nodes,
|
||||
only_gen_src_code=False,
|
||||
)
|
||||
|
||||
node_schedule = [*prologue_nodes, template_node, *epilogue_nodes]
|
||||
self.codegen_comment(node_schedule)
|
||||
kernel.call_kernel(kernel.kernel_name, template_node.node)
|
||||
|
||||
V.graph.removed_buffers |= kernel.removed_buffers
|
||||
V.graph.inplaced_to_remove |= kernel.inplaced_to_remove
|
||||
self.free_buffers_in_scheduler()
|
||||
return None
|
||||
|
||||
def codegen_sync(self):
|
||||
V.graph.wrapper_code.writeline(V.graph.device_ops.synchronize())
|
||||
@ -2439,7 +2536,9 @@ class SIMDScheduling(BaseScheduling):
|
||||
def ready_to_flush(self) -> bool:
|
||||
return False
|
||||
|
||||
def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False):
|
||||
def generate_kernel_code_from_nodes(
|
||||
self, nodes, benchmark_kernel=False, hint_override: Optional[int] = None
|
||||
):
|
||||
if not any(n.is_template() for n in nodes):
|
||||
_, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group
|
||||
node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
|
||||
@ -2464,6 +2563,7 @@ class SIMDScheduling(BaseScheduling):
|
||||
epilogue,
|
||||
prologue,
|
||||
only_gen_src_code=True,
|
||||
hint_override=hint_override,
|
||||
)
|
||||
|
||||
src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_")
|
||||
|
@ -1855,6 +1855,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
||||
min_elem_per_thread=0,
|
||||
optimize_mask=True,
|
||||
fixed_config: Optional[FixedTritonConfig] = None,
|
||||
hint_override: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.optimize_mask: bool = optimize_mask
|
||||
@ -1872,6 +1873,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
||||
collections.defaultdict(dict)
|
||||
)
|
||||
self.tma_min_block_sizes = dict[str, int]()
|
||||
self.hint_override = hint_override
|
||||
self._load_counts: collections.Counter[str] = collections.Counter()
|
||||
|
||||
# A set of autotuning hints to pass as part of triton_meta
|
||||
@ -3696,13 +3698,13 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
||||
buf = V.graph.try_get_buffer(arg_name)
|
||||
if buf:
|
||||
result.writeline(
|
||||
f"{var_name} = rand_strided({V.graph.sizevars.size_hints(buf.get_size())}, {V.graph.sizevars.size_hints(buf.get_stride())}, device='{buf.get_device()}', dtype={buf.get_dtype()})" # noqa: B950 line too long
|
||||
f"{var_name} = rand_strided({V.graph.sizevars.size_hints(buf.get_size(), hint_override=self.hint_override)}, {V.graph.sizevars.size_hints(buf.get_stride(), hint_override=self.hint_override)}, device='{buf.get_device()}', dtype={buf.get_dtype()})" # noqa: B950 line too long
|
||||
)
|
||||
elif arg_name in V.graph.constants:
|
||||
# note that random seed is put in V.graph.constants
|
||||
const_tensor = V.graph.constants[arg_name]
|
||||
result.writeline(
|
||||
f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # type: ignore[arg-type] # noqa: B950 line too long
|
||||
f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size(), hint_override=self.hint_override)}, {V.graph.sizevars.size_hints(const_tensor.stride(), hint_override=self.hint_override)}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # type: ignore[arg-type] # noqa: B950 line too long
|
||||
)
|
||||
elif isinstance(arg_sig, SizeArg):
|
||||
symval_hint = V.graph.sizevars.size_hint(arg_sig.expr)
|
||||
|
@ -445,6 +445,12 @@ force_same_precision: bool = Config(
|
||||
default=False,
|
||||
)
|
||||
|
||||
# Size hints for multi-kernel dispatch.
|
||||
# A reasonable default value of this config would be [64, 256, 4096]
|
||||
# TODO: @bobrenjc93 to roll this out to a few internal models to ensure this works
|
||||
# as expected before turning it on for everyone.
|
||||
multi_kernel_hints: list[int] = []
|
||||
|
||||
# Specify candidate backends for gemm autotune.
|
||||
# Possible choices are combinations of: ATen, Triton, CUTLASS, CK, CPP.
|
||||
# ATen: default Pytorch ATen kernels.
|
||||
|
@ -4877,7 +4877,7 @@ class MultiTemplateBuffer(TritonTemplateBuffer):
|
||||
self,
|
||||
layout: Layout,
|
||||
inputs: Sequence[IRNode],
|
||||
choice_timings_fn: Callable[[], dict[ChoiceCaller, float]],
|
||||
choice_timings_fn: Callable[[Optional[int]], dict[ChoiceCaller, float]],
|
||||
unfiltered_choices: list[ChoiceCaller],
|
||||
allowed_prologue_inps: OrderedSet[str],
|
||||
) -> None:
|
||||
@ -4888,7 +4888,7 @@ class MultiTemplateBuffer(TritonTemplateBuffer):
|
||||
allowed_prologue_inps=allowed_prologue_inps,
|
||||
)
|
||||
self._choice_timings_fn = choice_timings_fn
|
||||
self._choice_timings: Optional[dict[ChoiceCaller, float]] = None
|
||||
self._choice_timings: dict[Optional[int], dict[ChoiceCaller, float]] = {}
|
||||
self.original_inputs = inputs
|
||||
self._output_plannable = all(
|
||||
isinstance(choice, TritonTemplateCallerBase)
|
||||
@ -4898,6 +4898,7 @@ class MultiTemplateBuffer(TritonTemplateBuffer):
|
||||
)
|
||||
for choice in unfiltered_choices
|
||||
)
|
||||
self._make_kernel_renders: dict[Optional[int], Any] = {}
|
||||
|
||||
@property
|
||||
def output_plannable(self) -> bool:
|
||||
@ -4906,11 +4907,12 @@ class MultiTemplateBuffer(TritonTemplateBuffer):
|
||||
"""
|
||||
return self._output_plannable
|
||||
|
||||
@property
|
||||
def choice_timings(self) -> dict[ChoiceCaller, float]:
|
||||
if self._choice_timings is None:
|
||||
self._choice_timings = self._choice_timings_fn()
|
||||
return self._choice_timings
|
||||
def choice_timings(
|
||||
self, hint_override: Optional[int] = None
|
||||
) -> dict[ChoiceCaller, float]:
|
||||
if hint_override not in self._choice_timings:
|
||||
self._choice_timings[hint_override] = self._choice_timings_fn(hint_override)
|
||||
return self._choice_timings[hint_override]
|
||||
|
||||
@contextlib.contextmanager
|
||||
def swap_as_triton_caller(self, caller: TritonTemplateCallerBase) -> Iterator[None]:
|
||||
@ -4934,8 +4936,22 @@ class MultiTemplateBuffer(TritonTemplateBuffer):
|
||||
assert self.get_stride() == caller.layout.stride
|
||||
self.make_kernel_render = caller.get_make_kernel_render()
|
||||
|
||||
def get_min_choice(self) -> tuple[ChoiceCaller, float]:
|
||||
return min(self.choice_timings.items(), key=lambda x: x[1])
|
||||
def get_min_choice(
|
||||
self, hint_override: Optional[int] = None
|
||||
) -> tuple[ChoiceCaller, float]:
|
||||
timings = self.choice_timings(hint_override=hint_override)
|
||||
min_choice = min(timings, key=timings.get) # type: ignore[arg-type]
|
||||
return (min_choice, timings[min_choice])
|
||||
|
||||
def finalize_as_triton_callers(
|
||||
self, callers: dict[Optional[int], TritonTemplateCallerBase]
|
||||
) -> None:
|
||||
"""Finalize with multiple callers for different hint overrides"""
|
||||
for hint_override, caller in callers.items():
|
||||
self._make_kernel_renders[hint_override] = caller.get_make_kernel_render()
|
||||
|
||||
# Set the default to be the one without hint override
|
||||
self.make_kernel_render = self._make_kernel_renders[None]
|
||||
|
||||
|
||||
class CUDATemplateBuffer(TemplateBuffer):
|
||||
|
@ -28,6 +28,7 @@ import torch
|
||||
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
|
||||
from torch._dynamo.utils import counters, dynamo_timed
|
||||
from torch._inductor.codecache import LambdaFuture, PyCodeCache
|
||||
from torch._inductor.ir import TritonTemplateCallerBase
|
||||
from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
|
||||
from torch.fx.experimental.symbolic_shapes import free_symbols
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
@ -2716,7 +2717,10 @@ class Scheduler:
|
||||
return backend.benchmark_fused_nodes(nodes)
|
||||
|
||||
def generate_kernel_code_from_nodes(
|
||||
self, nodes: Sequence[BaseSchedulerNode], benchmark_kernel: bool
|
||||
self,
|
||||
nodes: Sequence[BaseSchedulerNode],
|
||||
benchmark_kernel: bool,
|
||||
hint_override: Optional[int] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Benchmark fused list of nodes and return the execution time
|
||||
@ -2727,7 +2731,9 @@ class Scheduler:
|
||||
self.current_device = device
|
||||
backend = self.get_backend(device)
|
||||
with dynamo_timed("benchmark_fused_nodes"):
|
||||
return backend.generate_kernel_code_from_nodes(nodes, benchmark_kernel)
|
||||
return backend.generate_kernel_code_from_nodes(
|
||||
nodes, benchmark_kernel, hint_override=hint_override
|
||||
)
|
||||
|
||||
def benchmark_codegened_module(
|
||||
self, module: ModuleType, device: torch.device
|
||||
@ -2789,7 +2795,7 @@ class Scheduler:
|
||||
min_node_unfused = next(
|
||||
(
|
||||
timing
|
||||
for timing in multi_node.choice_timings
|
||||
for timing in multi_node.choice_timings()
|
||||
if isinstance(
|
||||
timing,
|
||||
torch._inductor.select_algorithm.ExternKernelCaller,
|
||||
@ -2801,7 +2807,23 @@ class Scheduler:
|
||||
min_node_unfused,
|
||||
torch._inductor.ir.TritonTemplateCallerBase,
|
||||
):
|
||||
node.node.finalize_as_triton_caller(min_node_unfused)
|
||||
if config.multi_kernel_hints:
|
||||
callers: dict[Optional[int], TritonTemplateCallerBase] = {}
|
||||
callers[None] = min_node_unfused
|
||||
|
||||
for hint in config.multi_kernel_hints:
|
||||
timings = multi_node.choice_timings(hint_override=hint)
|
||||
triton_timings = {
|
||||
k: v
|
||||
for k, v in timings.items()
|
||||
if isinstance(k, TritonTemplateCallerBase)
|
||||
}
|
||||
choice = min(triton_timings.items(), key=lambda x: x[1])[0]
|
||||
callers[hint] = choice
|
||||
|
||||
node.node.finalize_as_triton_callers(callers)
|
||||
else:
|
||||
node.node.finalize_as_triton_caller(min_node_unfused)
|
||||
continue
|
||||
|
||||
out_tensorbox = min_node_unfused.output_node()
|
||||
@ -2924,10 +2946,10 @@ class Scheduler:
|
||||
async_compile = torch._inductor.async_compile.AsyncCompile()
|
||||
|
||||
def compile_kernel(
|
||||
nodes: Sequence[BaseSchedulerNode],
|
||||
nodes: Sequence[BaseSchedulerNode], hint_override: Optional[int] = None
|
||||
) -> tuple[Optional[LambdaFuture], ModuleType]:
|
||||
src_code = self.generate_kernel_code_from_nodes(
|
||||
nodes, benchmark_kernel=True
|
||||
nodes, benchmark_kernel=True, hint_override=hint_override
|
||||
)
|
||||
mod = PyCodeCache.load(src_code)
|
||||
if not async_compile.use_process_pool():
|
||||
@ -2949,8 +2971,58 @@ class Scheduler:
|
||||
)
|
||||
assert isinstance(multi_node, ir.MultiTemplateBuffer)
|
||||
|
||||
hint_override_best_fusion_choice: dict[
|
||||
Optional[int], TritonTemplateCallerBase
|
||||
] = {}
|
||||
future_choices: list[tuple[Any, Optional[LambdaFuture], ModuleType]] = []
|
||||
for hint_override in config.multi_kernel_hints:
|
||||
choice_timings = multi_node.choice_timings(hint_override)
|
||||
for choice, unfused_time in sorted(
|
||||
choice_timings.items(), key=lambda x: x[1]
|
||||
):
|
||||
if not isinstance(
|
||||
choice, torch._inductor.select_algorithm.TritonTemplateCaller
|
||||
):
|
||||
continue
|
||||
with multi_node.swap_as_triton_caller(choice):
|
||||
future_choices.append(
|
||||
(
|
||||
choice,
|
||||
*compile_kernel(
|
||||
node_list_fused, hint_override=choice.hint_override
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
min_ms_fused = float("inf")
|
||||
ms_fused_choice: Optional[TritonTemplateCallerBase] = None
|
||||
new_timings = {}
|
||||
for choice, future, mod_fused in future_choices:
|
||||
try:
|
||||
if future is not None:
|
||||
future.result()
|
||||
except Exception as e:
|
||||
if fusion_log.isEnabledFor(logging.DEBUG):
|
||||
fusion_log.debug(
|
||||
"Exception in compiling %s: %s",
|
||||
"prologue" if not epilogue_fusion else "epilogue",
|
||||
str(e),
|
||||
)
|
||||
continue
|
||||
with multi_node.swap_as_triton_caller(choice):
|
||||
ms_fused, path = self.benchmark_codegened_module(
|
||||
mod_fused, device
|
||||
)
|
||||
new_timings[choice] = ms_fused
|
||||
if ms_fused < min_ms_fused:
|
||||
min_ms_fused = ms_fused
|
||||
ms_fused_choice = choice
|
||||
multi_node._choice_timings[hint_override] = new_timings
|
||||
assert isinstance(ms_fused_choice, TritonTemplateCallerBase)
|
||||
hint_override_best_fusion_choice[hint_override] = ms_fused_choice
|
||||
|
||||
# Eagerly compile and benchmark non-template nodes
|
||||
choice_timings = multi_node.choice_timings
|
||||
choice_timings = multi_node.choice_timings()
|
||||
_, ms1 = multi_node.get_min_choice()
|
||||
ms2, path2 = (
|
||||
self.benchmark_fused_nodes(node_list_2)
|
||||
@ -3025,8 +3097,15 @@ class Scheduler:
|
||||
log_fusion(min_ms_fused, ms1, ms2)
|
||||
|
||||
if min_ms_fused < (ms1 + ms2) and ms_fused_choice is not None:
|
||||
multi_node.finalize_as_triton_caller(ms_fused_choice)
|
||||
multi_node._choice_timings = new_timings
|
||||
if config.multi_kernel_hints:
|
||||
hint_override_best_fusion_choice[None] = ms_fused_choice
|
||||
multi_node.finalize_as_triton_callers(
|
||||
hint_override_best_fusion_choice
|
||||
)
|
||||
else:
|
||||
multi_node.finalize_as_triton_caller(ms_fused_choice)
|
||||
|
||||
multi_node._choice_timings[None] = new_timings
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
@ -5010,7 +5089,10 @@ class BaseScheduling:
|
||||
raise NotImplementedError
|
||||
|
||||
def generate_kernel_code_from_nodes(
|
||||
self, nodes: Sequence[BaseSchedulerNode], benchmark_kernel: bool
|
||||
self,
|
||||
nodes: Sequence[BaseSchedulerNode],
|
||||
benchmark_kernel: bool,
|
||||
hint_override: Optional[int] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a kernel given a list of pre-fused nodes.
|
||||
|
@ -294,6 +294,15 @@ RecordedEventsType = list[tuple[str, list[Any], dict[str, Any]]]
|
||||
|
||||
|
||||
class TritonTemplateKernel(TritonKernel):
|
||||
"""
|
||||
A specialized kernel class for Triton templates that handles code generation
|
||||
for templated Triton kernels.
|
||||
|
||||
This class extends TritonKernel to provide additional functionality for
|
||||
template-based kernel generation, including support for subgraphs, workspace
|
||||
arguments, and prologue/epilogue fusion.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel_name,
|
||||
@ -314,6 +323,7 @@ class TritonTemplateKernel(TritonKernel):
|
||||
subgraphs: Optional[list[ir.ComputedBuffer]] = None,
|
||||
workspace_arg: Optional[WorkspaceArg] = None,
|
||||
prologue_loads_all_inputs=False,
|
||||
hint_override: Optional[int] = None,
|
||||
) -> None:
|
||||
numel = sympy_product(output_node.get_size())
|
||||
super().__init__(
|
||||
@ -322,6 +332,7 @@ class TritonTemplateKernel(TritonKernel):
|
||||
"r0_": sympy.S.One,
|
||||
},
|
||||
features=SIMDKernelFeatures([], numel),
|
||||
hint_override=hint_override,
|
||||
)
|
||||
self.input_nodes = input_nodes
|
||||
self.output_node = output_node
|
||||
@ -1096,16 +1107,26 @@ class TritonTemplateKernel(TritonKernel):
|
||||
def codegen_range_tree(self):
|
||||
pass # ignore default codegen
|
||||
|
||||
def additional_call_args_and_types(self):
|
||||
if isinstance(self.grid_fn, SymbolicGridFn):
|
||||
grid_args = self.grid_fn.sympy_call(*self.call_sizes, self.meta)
|
||||
assert len(grid_args) in (0, 3), "grid_fn should return 3 values"
|
||||
return (grid_args, map(type, grid_args))
|
||||
elif all(isinstance(x, (int, sympy.Integer)) for x in self.call_sizes):
|
||||
grid_args = self.grid_fn(*map(int, self.call_sizes), self.meta)
|
||||
assert len(grid_args) in (0, 3), "grid_fn should return 3 values"
|
||||
return (grid_args, map(type, grid_args))
|
||||
return ((), ())
|
||||
|
||||
def call_kernel(self, name: str, node: Optional[ir.IRNode] = None):
|
||||
wrapper = V.graph.wrapper_code
|
||||
_, call_args, _, arg_types = self.args.python_argdefs()
|
||||
|
||||
grid_args = ()
|
||||
if isinstance(self.grid_fn, SymbolicGridFn):
|
||||
grid_args = self.grid_fn.sympy_call(*self.call_sizes, self.meta)
|
||||
elif all(isinstance(x, (int, sympy.Integer)) for x in self.call_sizes):
|
||||
grid_args = self.grid_fn(*map(int, self.call_sizes), self.meta)
|
||||
else:
|
||||
additional_call_args, additional_arg_types = (
|
||||
self.additional_call_args_and_types()
|
||||
)
|
||||
|
||||
if not additional_call_args:
|
||||
assert not V.graph.cpp_wrapper, "cpp_wrapper requires SymbolicGridFn"
|
||||
wrapper.add_import_once(f"import {self.grid_fn.__module__}")
|
||||
meta = wrapper.add_meta_once(self.meta)
|
||||
@ -1114,9 +1135,9 @@ class TritonTemplateKernel(TritonKernel):
|
||||
f"*{fn_name}({', '.join(map(pexpr, self.call_sizes))}, {meta})"
|
||||
)
|
||||
arg_types.append(None)
|
||||
assert len(grid_args) in (0, 3), "grid_fn should return 3 values"
|
||||
call_args.extend(grid_args)
|
||||
arg_types.extend(map(type, grid_args))
|
||||
|
||||
call_args.extend(additional_call_args)
|
||||
arg_types.extend(additional_arg_types)
|
||||
|
||||
if self.workspace_arg is not None:
|
||||
wrapper.generate_workspace_allocation(self.workspace_arg)
|
||||
@ -1202,6 +1223,7 @@ class GeneratedCodeCache:
|
||||
num_consumer_groups: int,
|
||||
num_buffers_warp_spec: int,
|
||||
kwargs: dict[str, Any],
|
||||
hint_override: Optional[int] = None,
|
||||
) -> Optional[str]:
|
||||
def layout_key(layout: ir.Layout) -> str:
|
||||
assert not isinstance(layout, ir.FlexibleLayout)
|
||||
@ -1252,6 +1274,7 @@ class GeneratedCodeCache:
|
||||
"num_buffers_warp_spec": num_buffers_warp_spec,
|
||||
"epilogue_fn_hash": epilogue_fn_hash,
|
||||
"kwargs": kwargs,
|
||||
"hint_override": hint_override,
|
||||
}
|
||||
)
|
||||
|
||||
@ -1356,6 +1379,7 @@ class TritonTemplate(KernelTemplate):
|
||||
layout: ir.Layout,
|
||||
kwargs: dict[str, Any],
|
||||
generate_with_caching,
|
||||
hint_override: Optional[int] = None,
|
||||
) -> Optional[GenerateAndLoadResult]:
|
||||
"""Generate the python code and load it into the current process"""
|
||||
caching_enabled = (
|
||||
@ -1428,6 +1452,7 @@ class TritonTemplate(KernelTemplate):
|
||||
output_node=fake_out,
|
||||
workspace_arg=workspace_arg,
|
||||
use_jit=False,
|
||||
hint_override=hint_override,
|
||||
**kernel_options,
|
||||
)
|
||||
|
||||
@ -1547,6 +1572,7 @@ class TritonTemplate(KernelTemplate):
|
||||
call_sizes: Optional[Sequence[sympy.core.symbol.Symbol]] = None,
|
||||
workspace_arg: Optional[WorkspaceArg] = None,
|
||||
generate_with_caching=False,
|
||||
hint_override: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""This function generates a TritonTemplateCaller
|
||||
@ -1591,6 +1617,7 @@ class TritonTemplate(KernelTemplate):
|
||||
layout,
|
||||
kwargs,
|
||||
generate_with_caching and self._cache_codegen_enabled_for_template,
|
||||
hint_override=hint_override,
|
||||
)
|
||||
|
||||
# May happen as a result of div by 0.
|
||||
@ -1612,6 +1639,7 @@ class TritonTemplate(KernelTemplate):
|
||||
extra_args = V.graph.sizevars.size_hints(
|
||||
map(sympy.expand, result.kernel_args_sizevars_keys),
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
)
|
||||
|
||||
kernel_hash_name = f"triton_{self.name}_{next(self.index_counter)}"
|
||||
@ -1635,13 +1663,14 @@ class TritonTemplate(KernelTemplate):
|
||||
|
||||
options = result.kernel_options
|
||||
|
||||
def make_kernel_render(out_node):
|
||||
def make_kernel_render(out_node, hint_override: Optional[int] = None):
|
||||
assert result is not None
|
||||
kernel = self.kernel_type(
|
||||
kernel_name=str(Placeholder.KERNEL_NAME),
|
||||
output_node=out_node,
|
||||
workspace_arg=workspace_arg,
|
||||
use_jit=False,
|
||||
hint_override=hint_override,
|
||||
**options,
|
||||
)
|
||||
render = functools.partial(
|
||||
@ -1657,6 +1686,7 @@ class TritonTemplate(KernelTemplate):
|
||||
*V.graph.sizevars.size_hints(
|
||||
call_sizes,
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
),
|
||||
kwargs,
|
||||
)
|
||||
@ -1708,6 +1738,7 @@ class TritonTemplate(KernelTemplate):
|
||||
mutated_inputs=mutated_inputs,
|
||||
workspace_arg=workspace_arg,
|
||||
allowed_prologue_inps=result.prologue_supported_inputs,
|
||||
hint_override=hint_override,
|
||||
)
|
||||
|
||||
|
||||
@ -1783,6 +1814,7 @@ class TritonTemplateCaller(ir.TritonTemplateCallerBase):
|
||||
mutated_inputs=None,
|
||||
workspace_arg: Optional[WorkspaceArg] = None,
|
||||
allowed_prologue_inps: Optional[OrderedSet[str]] = None,
|
||||
hint_override: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__(name, input_nodes, layout, description)
|
||||
self.make_kernel_render = make_kernel_render
|
||||
@ -1802,6 +1834,7 @@ class TritonTemplateCaller(ir.TritonTemplateCallerBase):
|
||||
self.allowed_prologue_inps = (
|
||||
allowed_prologue_inps if allowed_prologue_inps is not None else OrderedSet()
|
||||
)
|
||||
self.hint_override = hint_override
|
||||
|
||||
def benchmark(self, *args, out):
|
||||
assert self.bmreq is not None
|
||||
@ -2256,7 +2289,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
# TODO(nmacchioni): remove this hacky way to tell if we ran benchmarking
|
||||
has_autotuned = False
|
||||
|
||||
def benchmark(choices):
|
||||
def benchmark(choices, hint_override: Optional[int] = None):
|
||||
nonlocal has_autotuned
|
||||
# TODO(nmacchioni): remove this hacky way to tell if we ran benchmarking
|
||||
has_autotuned = True
|
||||
@ -2264,13 +2297,13 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
# TODO(nmacchioni): remove this layer of abstraction
|
||||
# construct `benchmark_fn` which should pick between in-process and sub-process autotuning
|
||||
benchmark_fn = self.make_benchmark_fn(
|
||||
choices, input_nodes, layout, input_gen_fns
|
||||
choices, input_nodes, layout, input_gen_fns, hint_override=hint_override
|
||||
)
|
||||
# `benchmark_fn(choices)` will execute each choice, and return a dict[choice, timing] which
|
||||
# maps each choice to its runtime, calculated by the specified benchmarker, in milliseconds
|
||||
return benchmark_fn(choices)
|
||||
|
||||
def autotune(choices):
|
||||
def autotune(choices, hint_override: Optional[int] = None):
|
||||
log.debug("Starting autotuning")
|
||||
|
||||
with dynamo_timed(
|
||||
@ -2292,13 +2325,13 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
),
|
||||
},
|
||||
):
|
||||
return benchmark(choices)
|
||||
return benchmark(choices, hint_override=hint_override)
|
||||
|
||||
if config.autotune_in_subproc:
|
||||
# Initialize the subprocess pool so it will warm up early.
|
||||
torch._inductor.autotune_process.get_tuning_process_pool()
|
||||
|
||||
def do_autotuning(choices, precompile_fn):
|
||||
def do_autotuning(choices, precompile_fn, hint_override: Optional[int] = None):
|
||||
precompile_start_ts = time.time()
|
||||
with dynamo_timed(
|
||||
f"{name}_template_precompiling",
|
||||
@ -2319,7 +2352,8 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
candidates,
|
||||
name,
|
||||
inputs_key,
|
||||
autotune,
|
||||
lambda choices: autotune(choices, hint_override=hint_override),
|
||||
hint_override=hint_override,
|
||||
)
|
||||
choices = self.prune_choices_postscreen(
|
||||
choices, timings, name, inputs_key, self.prescreening_cache
|
||||
@ -2332,7 +2366,8 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
choices,
|
||||
name,
|
||||
inputs_key,
|
||||
autotune,
|
||||
lambda choices: autotune(choices, hint_override=hint_override),
|
||||
hint_override=hint_override,
|
||||
)
|
||||
|
||||
autotune_elapse = time.time() - autotune_start_ts
|
||||
@ -2355,6 +2390,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
autotune_elapse,
|
||||
precompile_elapse,
|
||||
prescreening_elapse,
|
||||
hint_override=hint_override,
|
||||
)
|
||||
|
||||
def profiler_bench_function():
|
||||
@ -2389,8 +2425,16 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
|
||||
if return_multi_template and (config.max_autotune or config.max_autotune_gemm):
|
||||
|
||||
def get_timings():
|
||||
timings = do_autotuning(choices, precompile_fn)
|
||||
def get_timings(hint_override: Optional[int] = None):
|
||||
filtered_choices = [
|
||||
c
|
||||
for c in choices
|
||||
if not hasattr(c, "hint_override")
|
||||
or c.hint_override == hint_override
|
||||
]
|
||||
timings = do_autotuning(
|
||||
filtered_choices, precompile_fn, hint_override=hint_override
|
||||
)
|
||||
min_extern_choice = float("inf")
|
||||
for choice, timing in timings.items():
|
||||
if isinstance(choice, ExternKernelCaller):
|
||||
@ -2627,6 +2671,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
input_nodes: list[ir.IRNode],
|
||||
layout: ir.Layout,
|
||||
input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]],
|
||||
hint_override: Optional[int] = None,
|
||||
) -> AutotuneArgs:
|
||||
"""
|
||||
Factory method to create AutotuneArgs from a list of ChoiceCallers.
|
||||
@ -2636,7 +2681,9 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
|
||||
# de-duplicate args
|
||||
unique_example_inputs = {
|
||||
x.get_name(): input_gen_fns.get(i, cls.benchmark_example_value)(x)
|
||||
x.get_name(): input_gen_fns.get(
|
||||
i, lambda x: cls.benchmark_example_value(x, hint_override=hint_override)
|
||||
)(x)
|
||||
for i, x in enumerate(input_nodes)
|
||||
}
|
||||
example_inputs = list(unique_example_inputs.values())
|
||||
@ -2649,20 +2696,23 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
V.graph.sizevars.size_hints(
|
||||
input_node.get_size(),
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
),
|
||||
V.graph.sizevars.size_hints(
|
||||
input_node.get_stride(),
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
),
|
||||
V.graph.sizevars.size_hint(
|
||||
input_node.get_layout().offset,
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
),
|
||||
)
|
||||
)
|
||||
for input_node in input_nodes
|
||||
]
|
||||
out = cls.benchmark_example_value(layout)
|
||||
out = cls.benchmark_example_value(layout, hint_override=hint_override)
|
||||
out_extern = torch.as_strided(
|
||||
out, out.size(), out.stride(), V.graph.sizevars.size_hint(layout.offset)
|
||||
)
|
||||
@ -2771,8 +2821,11 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
input_nodes: list[ir.IRNode],
|
||||
layout: ir.Layout,
|
||||
input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]],
|
||||
hint_override: Optional[int] = None,
|
||||
) -> dict[ChoiceCaller, float]:
|
||||
inputs = cls.get_inputs(choices, input_nodes, layout, input_gen_fns)
|
||||
inputs = cls.get_inputs(
|
||||
choices, input_nodes, layout, input_gen_fns, hint_override=hint_override
|
||||
)
|
||||
return cls.benchmark_choices(choices, inputs)
|
||||
|
||||
@classmethod
|
||||
@ -2782,6 +2835,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
input_nodes: list[ir.IRNode],
|
||||
layout: ir.Layout,
|
||||
input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]],
|
||||
hint_override: Optional[int] = None,
|
||||
):
|
||||
from . import autotune_process
|
||||
|
||||
@ -2791,7 +2845,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
triton = [c for c in choices if not isinstance(c, ExternKernelCaller)]
|
||||
|
||||
timings = cls.benchmark_in_current_process(
|
||||
extern, input_nodes, layout, input_gen_fns
|
||||
extern, input_nodes, layout, input_gen_fns, hint_override=hint_override
|
||||
)
|
||||
timings.update(autotune_process.benchmark_in_sub_process(triton)) # type: ignore[arg-type]
|
||||
return timings
|
||||
@ -2803,6 +2857,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
input_nodes: list[ir.IRNode],
|
||||
layout: ir.Layout,
|
||||
input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]],
|
||||
hint_override: Optional[int] = None,
|
||||
):
|
||||
if DEBUG:
|
||||
print(f"{len(choices)} tuning requests:")
|
||||
@ -2813,6 +2868,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
input_nodes=input_nodes,
|
||||
layout=layout,
|
||||
input_gen_fns=input_gen_fns,
|
||||
hint_override=hint_override,
|
||||
)
|
||||
else:
|
||||
return functools.partial(
|
||||
@ -2820,6 +2876,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
input_nodes=input_nodes,
|
||||
layout=layout,
|
||||
input_gen_fns=input_gen_fns,
|
||||
hint_override=hint_override,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -2978,6 +3035,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
elapse: float,
|
||||
precompile_elapse: float,
|
||||
prescreening_elapse: Optional[float] = None,
|
||||
hint_override: Optional[int] = None,
|
||||
):
|
||||
V.debug.log_autotuning_results(
|
||||
name, input_nodes, timings, elapse, precompile_elapse
|
||||
@ -2992,6 +3050,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
V.graph.sizevars.size_hints(
|
||||
n.get_size(),
|
||||
fallback=config.unbacked_symint_fallback, # type: ignore[arg-type]
|
||||
hint_override=hint_override,
|
||||
),
|
||||
)
|
||||
)
|
||||
@ -3079,7 +3138,7 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def benchmark_example_value(node):
|
||||
def benchmark_example_value(node, hint_override: Optional[int] = None):
|
||||
"""
|
||||
Convert an ir.Buffer into a concrete torch.Tensor we can use for
|
||||
benchmarking.
|
||||
@ -3098,10 +3157,12 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||
V.graph.sizevars.size_hints(
|
||||
node.get_size(),
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
),
|
||||
V.graph.sizevars.size_hints(
|
||||
node.get_stride(),
|
||||
fallback=config.unbacked_symint_fallback,
|
||||
hint_override=hint_override,
|
||||
),
|
||||
node.get_device(),
|
||||
node.get_dtype(),
|
||||
@@ -3109,6 +3170,7 @@ class AlgorithmSelectorCache(PersistentCache):
             V.graph.sizevars.size_hints(
                 V.graph.get_allocation_size(node),
                 fallback=config.unbacked_symint_fallback,
+                hint_override=hint_override,
             ),
         )
 
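The two hunks above are where the override actually changes behavior: the example tensor used for benchmarking is allocated from the overridden size and stride hints. A hedged sketch of that idea follows; `resolve_hint` and `example_tensor` are illustrative stand-ins, not `V.graph.sizevars.size_hints` or the real tensor-generation helper:

```python
from typing import Optional, Sequence, Union

import torch


def resolve_hint(dims: Sequence[Union[int, str]], hint_override: Optional[int] = None):
    # Symbolic dims (modeled as strings) fall back to the override, else to a
    # default example hint of 64.
    return tuple(d if isinstance(d, int) else (hint_override or 64) for d in dims)


def example_tensor(
    sizes, strides, dtype=torch.float16, device="cpu", hint_override: Optional[int] = None
) -> torch.Tensor:
    size = resolve_hint(sizes, hint_override)
    stride = resolve_hint(strides, hint_override)
    return torch.empty_strided(size, stride, dtype=dtype, device=device)


t = example_tensor(["m", 128], [128, 1], hint_override=256)
print(t.shape, t.stride())  # torch.Size([256, 128]) (128, 1)
```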
@@ -59,6 +59,15 @@ def statically_known_true(
 # lifting and in some cases we should be directly passing through to ShapeEnv,
 # but there is some extra inductor logic that needs to be handled here
 class SizeVarAllocator:
+    """
+    A class that manages symbolic size variables and their relationships.
+
+    This class works with the ShapeEnv to handle symbolic shape expressions,
+    simplify them, and provide utilities for guarding, checking, and evaluating
+    symbolic expressions. It also manages precomputed replacements and stride
+    calculations for tensor operations.
+    """
+
     def __init__(self, shape_env=None) -> None:
         super().__init__()
         if shape_env is None:
@@ -527,7 +536,9 @@ class SizeVarAllocator:
             return sympy_subs(expr, self.inv_precomputed_replacements)  # type: ignore[arg-type]
         return expr
 
-    def symbolic_hint(self, expr: Union[Expr, int]) -> Union[Expr, int]:
+    def symbolic_hint(
+        self, expr: Union[Expr, int], hint_override: Optional[int] = None
+    ) -> Union[Expr, int]:
         if isinstance(expr, int):
             return expr
         # Substitute all hints into expr, but leave unbacked symints alone
@@ -541,13 +552,21 @@ class SizeVarAllocator:
                 return int(expr)  # type: ignore[return-value]
             except TypeError:
                 return expr  # inf/nan/I
 
+        if hint_override:
+            return hint_override
+
         expr = self.remove_precomputed_replacements(expr)
         return sympy_subs(expr, self.var_to_val)
 
     def size_hint(
-        self, expr: Union[Expr, int], *, fallback: Optional[int] = None
+        self,
+        expr: Union[Expr, int],
+        *,
+        fallback: Optional[int] = None,
+        hint_override: Optional[int] = None,
     ) -> int:
-        out = self.symbolic_hint(expr)
+        out = self.symbolic_hint(expr, hint_override=hint_override)
         if not isinstance(out, (int, sympy.Integer)) and fallback is not None:
             # Use the provided heuristic fallback hint
             unbacked_sym_vrs = {
@@ -581,8 +600,12 @@ class SizeVarAllocator:
         exprs: Iterable[Union[Expr, int]],
         *,
         fallback: Optional[int] = None,
+        hint_override: Optional[int] = None,
     ) -> tuple[int, ...]:
-        return tuple(self.size_hint(x, fallback=fallback) for x in exprs)
+        return tuple(
+            self.size_hint(x, fallback=fallback, hint_override=hint_override)
+            for x in exprs
+        )
 
     def size_hints_or_throw(
         self,
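A toy reimplementation of the override semantics added to `symbolic_hint`, `size_hint`, and `size_hints` above, using plain sympy with no ShapeEnv; `VAR_TO_VAL` stands in for the allocator's recorded example-input hints:

```python
from typing import Optional, Union

import sympy

# Stand-in for the allocator's var_to_val mapping (example-input hints).
VAR_TO_VAL = {sympy.Symbol("s0"): sympy.Integer(64)}


def symbolic_hint(expr: Union[sympy.Expr, int], hint_override: Optional[int] = None):
    if isinstance(expr, int):
        return expr  # concrete ints are returned as-is, as in the diff
    if hint_override:
        return hint_override  # a (truthy) override wins over the recorded hint
    return expr.xreplace(VAR_TO_VAL)


def size_hints(exprs, *, hint_override: Optional[int] = None):
    return tuple(int(symbolic_hint(e, hint_override=hint_override)) for e in exprs)


s0 = sympy.Symbol("s0")
print(size_hints([s0, 2 * s0, 128]))                      # (64, 128, 128)
print(size_hints([s0, 2 * s0, 128], hint_override=4096))  # (4096, 4096, 128)
```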
@@ -5,7 +5,7 @@ import itertools
 import math
 from functools import partial
 from threading import Lock
-from typing import Any, Callable, TYPE_CHECKING
+from typing import Any, Callable, Optional, TYPE_CHECKING
 
 import torch
 from torch.utils._ordered_set import OrderedSet
@@ -33,6 +33,7 @@ class BaseConfig:
     block_k: int
     num_stages: int
     num_warps: int
+    hint_override: Optional[int] = None
 
 
 @dataclasses.dataclass
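Since `BaseConfig` now carries an optional `hint_override`, each template config can be stamped with the hint it was generated for. A minimal sketch of that pattern together with `dataclasses.replace`, which the scaling code further below uses; `GemmConfig` here is illustrative, not the full `BaseConfig` hierarchy:

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass(frozen=True)
class GemmConfig:
    block_m: int
    block_n: int
    block_k: int
    num_stages: int
    num_warps: int
    hint_override: Optional[int] = None  # None means "the traced example hint"


base = GemmConfig(block_m=64, block_n=64, block_k=32, num_stages=4, num_warps=8)
per_hint = [dataclasses.replace(base, hint_override=h) for h in (None, 64, 256, 4096)]
print([c.hint_override for c in per_hint])  # [None, 64, 256, 4096]
```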
@@ -421,7 +422,7 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
        """
        Finalizes configs after scaling, applying additional constraints.
        """
-        used: OrderedSet[tuple[int, ...]] = OrderedSet()
+        used: OrderedSet[tuple[Optional[int], ...]] = OrderedSet()
 
        max_mm_configs = config.test_configs.max_mm_configs
 
@@ -430,11 +431,12 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
            num_warps = min(conf.num_warps, conf.block_m * conf.block_n // 256)
 
            # Construct key for finding duplicate configs
-            key: tuple[int, ...] = (
+            key: tuple[Optional[int], ...] = (
                conf.block_m,
                conf.block_n,
                conf.block_k,
                conf.num_stages,
+                conf.hint_override,
                num_warps,
            )
 
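Because `hint_override` may be None for the default config, the dedup key widens from `tuple[int, ...]` to `tuple[Optional[int], ...]`, and otherwise-identical block shapes produced for different hints must survive deduplication. A small sketch of that behavior with toy tuples rather than real configs:

```python
from typing import Optional

# (block_m, block_n, block_k, num_stages, hint_override, num_warps)
candidate_keys: list[tuple[Optional[int], ...]] = [
    (64, 64, 32, 4, None, 8),  # default-hint variant
    (64, 64, 32, 4, 64, 8),    # same shape, generated for hint 64
    (64, 64, 32, 4, 64, 8),    # true duplicate, dropped below
]

seen: set[tuple[Optional[int], ...]] = set()
unique = []
for key in candidate_keys:
    if key not in seen:
        seen.add(key)
        unique.append(key)
print(unique)  # both hint variants kept, the duplicate removed
```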
@@ -451,12 +453,11 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
                "BLOCK_M": conf.block_m,
                "BLOCK_N": conf.block_n,
                "BLOCK_K": conf.block_k,
-                "num_stages": conf.num_stages,
-                "num_warps": num_warps,
+                "hint_override": conf.hint_override,
            }
            if group_m is not None:
                kwargs["GROUP_M"] = group_m
-            yield self.triton_config(**kwargs)
+            yield self.triton_config(conf.num_stages, num_warps, **kwargs)
 
    def _scale_mm_configs(
        self,
@@ -467,6 +468,7 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
        scale: float,
        has_int8_tensor: bool,
        exclude: Callable[[int, int, int], bool],
+        hint_override: Optional[int] = None,
    ) -> list[BaseConfig]:
        """
        Scales and filters matrix multiplication configs based on input size.
@@ -476,47 +478,52 @@ class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton):
        min_block_size = 16
        min_block_size_k = 32 if has_int8_tensor else 16
 
-        m = max(
-            next_power_of_2(
-                V.graph.sizevars.size_hint(
-                    m,
-                    fallback=config.unbacked_symint_fallback,  # type: ignore[arg-type]
-                )
-            ),
-            min_block_size,
-        )
-        n = max(
-            next_power_of_2(
-                V.graph.sizevars.size_hint(
-                    n,
-                    fallback=config.unbacked_symint_fallback,  # type: ignore[arg-type]
-                )
-            ),
-            min_block_size,
-        )
-        k = max(
-            next_power_of_2(
-                V.graph.sizevars.size_hint(
-                    k,
-                    fallback=config.unbacked_symint_fallback,  # type: ignore[arg-type]
-                )
-            ),
-            min_block_size_k,
-        )
-
        scaled_configs = []
-        for c in configs:
-            scaled_config = dataclasses.replace(
-                c,
-                block_m=max(min(int(c.block_m * scale), m), min_block_size),
-                block_n=max(min(int(c.block_n * scale), n), min_block_size),
-                block_k=max(min(int(c.block_k * scale), k), min_block_size_k),
-            )
-
-            if not exclude(
-                scaled_config.block_m, scaled_config.block_n, scaled_config.block_k
-            ):
-                scaled_configs.append(scaled_config)
+        for hint_override in [None] + config.multi_kernel_hints:
+            m_hint = max(
+                next_power_of_2(
+                    V.graph.sizevars.size_hint(
+                        m,
+                        fallback=config.unbacked_symint_fallback,  # type: ignore[arg-type]
+                        hint_override=hint_override,
+                    )
+                ),
+                min_block_size,
+            )
+            n_hint = max(
+                next_power_of_2(
+                    V.graph.sizevars.size_hint(
+                        n,
+                        fallback=config.unbacked_symint_fallback,  # type: ignore[arg-type]
+                        hint_override=hint_override,
+                    )
+                ),
+                min_block_size,
+            )
+            k_hint = max(
+                next_power_of_2(
+                    V.graph.sizevars.size_hint(
+                        k,
+                        fallback=config.unbacked_symint_fallback,  # type: ignore[arg-type]
+                        hint_override=hint_override,
+                    )
+                ),
+                min_block_size_k,
+            )
+
+            for c in configs:
+                scaled_config = dataclasses.replace(
+                    c,
+                    block_m=max(min(int(c.block_m * scale), m_hint), min_block_size),
+                    block_n=max(min(int(c.block_n * scale), n_hint), min_block_size),
+                    block_k=max(min(int(c.block_k * scale), k_hint), min_block_size_k),
+                    hint_override=hint_override,
+                )
+
+                if not exclude(
+                    scaled_config.block_m, scaled_config.block_n, scaled_config.block_k
+                ):
+                    scaled_configs.append(scaled_config)
 
        return scaled_configs
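Condensed sketch of the rewritten loop above: one scaling pass is made per entry in `[None] + multi_kernel_hints`, so each hint value contributes its own block-size-capped copies of every config. `Cfg`, `size_hint`, and `MULTI_KERNEL_HINTS` are toy stand-ins for the Inductor types and `torch._inductor.config.multi_kernel_hints`:

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass(frozen=True)
class Cfg:
    block_m: int
    hint_override: Optional[int] = None


MULTI_KERNEL_HINTS = [64, 256, 4096]


def next_power_of_2(n: int) -> int:
    return 1 << (n - 1).bit_length()


def size_hint(traced_m: int, hint_override: Optional[int] = None) -> int:
    return hint_override or traced_m  # the override replaces the traced hint


def scale_configs(configs, traced_m):
    scaled = []
    for hint_override in [None] + MULTI_KERNEL_HINTS:
        # Cap each block size against the (possibly overridden) problem size.
        m = max(next_power_of_2(size_hint(traced_m, hint_override)), 16)
        for c in configs:
            scaled.append(
                dataclasses.replace(c, block_m=min(c.block_m, m), hint_override=hint_override)
            )
    return scaled


print(scale_configs([Cfg(block_m=128)], traced_m=32))
```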