Defer selection of triton template (#120275)

Our prior approach to epilogue fusion was to select a choice from a set of triton templates and extern calls based on benchmarking inputs, then unconditionally fuse epilogues. This can be sub-optimal in the following ways:

- We select an extern kernel, but an epilogue like relu() follows it, and choosing a triton template fused with the relu would have been faster.
- We select a triton template and fuse the epilogue, but register spilling makes the fused kernel slower than leaving the epilogue unfused.

In this PR we defer selecting either the Triton Template or the Extern Kernel until we have benchmarking results for both the kernel itself and its epilogue. As soon as a successful fusion occurs, where a fused Triton Template + epilogue is faster than the unfused choice, we finalize the MultiTemplateBuffer as that specific template. If no fusion occurs, we finalize the MultiTemplateBuffer after the fusion pass completes.
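
Roughly, the fusion-time decision looks like the sketch below. The `MultiTemplateBuffer` methods (`get_min_choice`, `swap_as_triton_caller`, `finalize_as_triton_caller`) and `benchmark_fused_nodes` are names introduced in this PR, but the wrapper function `try_epilogue_fusion` is made up for illustration; the real logic, with logging, extern-kernel handling, and error paths, lives in `Scheduler.speedup_by_fusion`:

```python
from torch._inductor import ir


def try_epilogue_fusion(scheduler, node1, node2, node_list_fused):
    """Decide whether to fuse an epilogue into a MultiTemplateBuffer (simplified sketch)."""
    multi_node = node1.node                              # an ir.MultiTemplateBuffer
    _, ms1 = multi_node.get_min_choice()                 # fastest unfused choice (triton or extern)
    ms2, _ = scheduler.benchmark_fused_nodes(node2.get_nodes())

    best_fused_ms, best_choice = float("inf"), None
    for choice, unfused_ms in multi_node.choice_timings.items():
        if not isinstance(choice, ir.TritonTemplateCallerBase):
            continue                                     # only triton templates can absorb an epilogue
        if unfused_ms >= ms1 + ms2:
            continue                                     # this template cannot beat the unfused baseline
        with multi_node.swap_as_triton_caller(choice):
            ms_fused, _ = scheduler.benchmark_fused_nodes(node_list_fused)
        if ms_fused < best_fused_ms:
            best_fused_ms, best_choice = ms_fused, choice

    if best_choice is not None and best_fused_ms < ms1 + ms2:
        multi_node.finalize_as_triton_caller(best_choice)  # commit to the winning template
        return True                                        # fuse the epilogue
    return False                                            # keep the fastest unfused choice
```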

Note: if there are multiple epilogue fusions (not super likely), even though we select a template after the first fusion, we will still benchmark to see if subsequent epilogues are worth fusing. We could potentially defer choosing a template in this case in a follow-up, at the expense of compile time.

Gives a 4% HF training win and a 10% TIMM inference win. It increases compilation time, which I will try to address further in follow-up PRs.
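
For reference, the new path is gated behind the `benchmark_multi_templates` config flag (env var `TORCHINDUCTOR_BENCHMARK_MULTI_TEMPLATES=1`), and it currently only takes effect together with `benchmark_fusion`. A minimal way to exercise it, loosely mirroring the test added in this PR (assumes a CUDA build):

```python
import torch
import torch._inductor.config as inductor_config

# Both flags are needed today: autotune_select_algorithm only returns a
# MultiTemplateBuffer when benchmark_multi_templates and benchmark_fusion are set.
with inductor_config.patch(
    {"benchmark_fusion": True, "benchmark_multi_templates": True}
):
    m = torch.nn.Linear(256, 256, bias=True).half().cuda()
    inp = torch.rand([256, 256]).half().cuda()

    # max-autotune triggers the template/extern autotuning whose selection is now deferred
    foo = torch.compile(mode="max-autotune-no-cudagraphs")(
        lambda x: torch.nn.functional.relu(m(x))
    )
    with torch.no_grad():
        out = foo(inp)
```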

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120275
Approved by: https://github.com/jansel
ghstack dependencies: #121996
Author: eellison
Date: 2024-03-18 18:35:34 -07:00
Committed by: PyTorch MergeBot
Parent: e5e0685f61
Commit: cbbed46377
12 changed files with 397 additions and 94 deletions

View File

@@ -5,7 +5,7 @@ import sys
import torch
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import run_and_get_code
from torch._inductor.utils import fresh_inductor_cache, run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
IS_CI,
@@ -14,6 +14,7 @@ from torch.testing._internal.common_utils import (
slowTest,
TEST_WITH_ASAN,
)
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
# Make the helper files in test/ importable
@@ -35,6 +36,7 @@ if IS_WINDOWS and IS_CI:
sys.exit(0)
raise unittest.SkipTest("requires sympy/functorch/filelock")
from inductor.test_torchinductor import check_model, check_model_cuda, copy_tests
@@ -191,6 +193,108 @@ if HAS_CUDA and not TEST_WITH_ASAN:
copy_tests(BenchmarkFusionTestTemplate, BenchmarkFusionCudaTest, "cuda")
class BenchmarkMultiTemplateFusionCudaTest(InductorTestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls._stack = contextlib.ExitStack()
cls._stack.enter_context(
config.patch(
{
"benchmark_kernel": True,
"benchmark_fusion": True,
"benchmark_multi_templates": True,
}
)
)
@classmethod
def tearDownClass(cls):
cls._stack.close()
super().tearDownClass()
def _equivalent_output_code_impl(self, size, first_dim=None, activation=True):
def foo(m, inp):
a = m(inp)
if activation:
return torch.nn.functional.relu(a)
return a
foo_c = torch.compile(mode="max-autotune-no-cudagraphs")(foo)
first_dim = first_dim if first_dim is not None else size
m = torch.nn.Linear(size, size, bias=True).half().cuda()
inp = torch.rand([first_dim, size]).half().cuda()
with torch.no_grad():
res, code = run_and_get_code(foo_c, m, inp)
torch._dynamo.reset()
with unittest.mock.patch.object(
torch._inductor.config, "benchmark_multi_templates", False
):
foo_c = torch.compile(mode="max-autotune-no-cudagraphs")(foo)
with torch.no_grad():
res2, code2 = run_and_get_code(foo_c, m, inp)
self.assertEqual(res, res2, atol=1e-4, rtol=1.1)
return code, code2
@fresh_inductor_cache()
@torch._inductor.config.patch(max_autotune_gemm_backends="TRITON")
def test_equivalent_template_code(self):
code, code2 = self._equivalent_output_code_impl(256)
for out_code in [code, code2]:
FileCheck().check("def call").check_count(
"empty_strided_cuda", 1, exactly=True
).check("triton_tem_fused_relu_0.run").check_count(
"del", 3, exactly=True
).check(
"return"
).run(
out_code[0]
)
@fresh_inductor_cache()
@torch._inductor.config.patch(max_autotune_gemm_backends="ATEN")
def test_equivalent_extern_code(self):
torch._dynamo.reset()
code, code2 = self._equivalent_output_code_impl(512, 1, False)
for out_code in [code, code2]:
FileCheck().check("def call").check_count(
"empty_strided_cuda", 1, exactly=True
).check("extern_kernels.").check_count("del", 3, exactly=True).check(
"return"
).run(
out_code[0]
)
def test_changed_layout(self):
# cat addmm planning will change layout - make sure propagated
def fn(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
return torch.cat(
[
torch.addmm(a, b, c),
torch.addmm(b, c, a),
],
1,
)
args = [
torch.randn(4, 4, device="cuda"),
torch.randn(4, 4, device="cuda"),
torch.randn(4, 4, device="cuda"),
]
expected = fn(*args)
actual = torch.compile(fn, mode="max-autotune")(*args)
self.assertEqual(expected, actual)
torch._dynamo.reset()
if HAS_CPU and not torch.backends.mps.is_available():
class BenchmarkFusionCpuTest(TestCase):

View File

@@ -14,11 +14,10 @@ from torch._inductor.autotune_process import (
TuningProcessPool,
)
from torch._inductor.graph import GraphLowering
from torch._inductor.ir import Buffer, FixedLayout
from torch._inductor.ir import Buffer, ChoiceCaller, FixedLayout
from torch._inductor.kernel.mm_plus_mm import aten_mm_plus_mm
from torch._inductor.select_algorithm import (
AlgorithmSelectorCache,
ChoiceCaller,
TritonTemplateCaller,
)
from torch._inductor.test_case import run_tests, TestCase

View File

@@ -200,6 +200,7 @@ RUN_PARALLEL_BLOCKLIST = [
"test_tensorexpr",
"test_cuda_primary_ctx",
"test_cuda_trace",
"inductor/test_benchmark_fusion",
"test_cuda_nvml_based_avail",
# temporarily sets a global config
"test_autograd_fallback",

View File

@@ -56,7 +56,7 @@ from torch.fx.experimental.symbolic_shapes import has_hint, hint_int, ShapeEnv
if TYPE_CHECKING:
from torch._inductor.graph import GraphLowering
from torch._inductor.select_algorithm import ChoiceCaller
from torch._inductor.ir import ChoiceCaller
from torch.hub import _Faketqdm, tqdm

View File

@@ -16,7 +16,6 @@ from typing import (
Optional,
Set,
Tuple,
TYPE_CHECKING,
Union,
)
@@ -32,7 +31,6 @@ from torch.utils._sympy.value_ranges import ValueRanges
from .. import config, metrics
from ..utils import (
DeferredLineBase,
do_bench,
free_symbol_startswith,
IndentedBuffer,
sympy_dot,
@@ -42,8 +40,6 @@ from ..utils import (
)
from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V
if TYPE_CHECKING:
from ..ir import TensorBox
schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
@@ -1672,45 +1668,6 @@ def jinja2_env():
return None
PrimitiveInfoType = Union[int, float, bool, str, List[Union[int, str, float, bool]]]
class ChoiceCaller:
"""
Represents a possible choice used in autotune_process.py.
During autotuning, self.benchmark() is first called to get benchmark result,
and if this choice is selected, self.output_node() is called to get the output_node.
Children classes: TritonTemplateCaller, CUDATemplateCaller.
"""
def __init__(self, name, input_nodes, layout):
super().__init__()
self.name = name
self.layout = layout
self.input_nodes = input_nodes
def benchmark(self, *args, out) -> float:
algo = self.to_callable()
return do_bench(lambda: algo(*args, out=out))
def call_name(self) -> str:
raise NotImplementedError()
def to_callable(self):
raise NotImplementedError()
def hash_key(self) -> str:
raise NotImplementedError()
def output_node(self) -> "TensorBox":
raise NotImplementedError()
def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
"""Information returned here is logged to the autotune log file when that is enabled."""
return {}
class KernelTemplate:
"""
Base class for defining kernel templates.
@@ -1752,7 +1709,7 @@ class KernelTemplate:
except NotImplementedError:
pass
def generate(self, **kwargs) -> ChoiceCaller:
def generate(self, **kwargs) -> "torch._inductor.ir.ChoiceCaller":
"""
Generates a ChoiceCaller instance from the given arguments.
"""

View File

@@ -3,12 +3,19 @@ from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union
from ... import ir
from ...autotune_process import CUDABenchmarkRequest
from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout, TensorBox
from ...select_algorithm import ChoiceCaller
from ...ir import (
Buffer,
ChoiceCaller,
CUDATemplateBuffer,
IRNode,
Layout,
PrimitiveInfoType,
TensorBox,
)
from ...utils import sympy_product
from ...virtualized import V
from ..common import IndentedBuffer, Kernel, OpOverrides
from ..common import IndentedBuffer, Kernel, OpOverrides, PrimitiveInfoType
from ..cpp import CppPrinter, DTYPE_TO_CPP
if TYPE_CHECKING:

View File

@@ -31,6 +31,7 @@ import sympy
import torch
import torch._logging
import torch.utils._pytree as pytree
from torch._dynamo.utils import preserve_rng_state
from torch._inductor.metrics import is_metric_table_enabled, log_kernel_metadata
from torch._prims_common import is_integer_dtype
@@ -3848,7 +3849,18 @@ class TritonScheduling(BaseScheduling):
def ready_to_flush(self) -> bool:
return False
@preserve_rng_state()
def benchmark_fused_nodes(self, nodes):
@dataclasses.dataclass
class LastUsageHolder:
n: Any
last_usage: Any
def __del__(self):
self.n.last_usage = self.last_usage
last_usage_holders = [LastUsageHolder(n, n.last_usage) for n in nodes]
# empty last_usage. May cause more aggressive 'evict_last'. Should be fine.
for n in nodes:
n.last_usage = set()
@@ -3926,6 +3938,12 @@
# generating out of range indices for later calls.
ms = do_bench(lambda: call(wrapped_jit_function.clone_args(*args)[0]))
# overhead of cloning args gives bias for fusing the kernel
# in the case of mutating/in-placeable second fusion
# TODO - would be better as a hook in triton do_bench that reset
# the input values between benchmarking
ms = ms - do_bench(lambda: wrapped_jit_function.clone_args(*args))
log.debug(
"The fused kernel for %s took %.3f ms to run",
{n.get_name() for n in nodes},

View File

@@ -284,6 +284,10 @@ debug_fusion = os.environ.get("TORCHINDUCTOR_DEBUG_FUSION") == "1"
benchmark_fusion = os.environ.get("TORCHINDUCTOR_BENCHMARK_FUSION") == "1"
enabled_metric_tables = os.environ.get("TORCHINDUCTOR_ENABLED_METRIC_TABLES", "")
benchmark_multi_templates = (
os.environ.get("TORCHINDUCTOR_BENCHMARK_MULTI_TEMPLATES", "0") == "1"
)
# how many nodes to allow into a single fusion
max_fusion_size = 64

View File

@@ -66,6 +66,7 @@ from .utils import (
convert_shape_to_inductor,
convert_shape_to_symint,
developer_warning,
do_bench,
get_kernel_metadata,
is_dynamic,
pad_listlike,
@@ -3476,6 +3477,92 @@ class TritonTemplateBuffer(TemplateBuffer):
pass
PrimitiveInfoType = Union[int, float, bool, str, List[Union[int, str, float, bool]]]
class ChoiceCaller:
"""
Represents a possible choice used in autotune_process.py.
During autotuning, self.benchmark() is first called to get benchmark result,
and if this choice is selected, self.output_node() is called to get the output_node.
Children classes: TritonTemplateCaller, CUDATemplateCaller.
"""
def __init__(self, name, input_nodes, layout):
super().__init__()
self.name = name
self.layout = layout
self.input_nodes = input_nodes
def benchmark(self, *args, out) -> float:
algo = self.to_callable()
return do_bench(lambda: algo(*args, out=out))
def call_name(self) -> str:
raise NotImplementedError()
def to_callable(self):
raise NotImplementedError()
def hash_key(self) -> str:
raise NotImplementedError()
def output_node(self) -> "TensorBox":
raise NotImplementedError()
def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
"""Information returned here is logged to the autotune log file when that is enabled."""
return {}
class TritonTemplateCallerBase(ChoiceCaller):
def get_make_kernel_render(self) -> Any:
raise NotImplementedError()
class MultiTemplateBuffer(TritonTemplateBuffer):
"""
Represents a Buffer with multiple backing implementation choices.
Choices can be TritonTemplates or ExternKernels. During scheduling if there is a potential
epilogue we will benchmark each of the choices with the epilogue to determine an implementation.
Otherwise, the fastest base choice will be chosen.
"""
def __init__(
self,
layout: Layout,
inputs: List[IRNode],
choice_timings: Dict[ChoiceCaller, float],
):
super().__init__(layout=layout, inputs=inputs, make_kernel_render=None)
self.choice_timings = choice_timings
self.original_inputs = inputs
@contextlib.contextmanager
def swap_as_triton_caller(self, caller: TritonTemplateCallerBase):
assert isinstance(caller, torch._inductor.select_algorithm.TritonTemplateCaller)
assert self.layout == caller.layout
render = self.make_kernel_render
self.make_kernel_render = caller.get_make_kernel_render()
try:
yield
finally:
self.make_kernel_render = render
def finalize_as_triton_caller(self, caller: TritonTemplateCallerBase):
assert isinstance(caller, torch._inductor.select_algorithm.TritonTemplateCaller)
assert self.layout.size == caller.layout.size
assert self.layout.stride == caller.layout.stride
self.make_kernel_render = caller.get_make_kernel_render()
def get_min_choice(self) -> Tuple[ChoiceCaller, float]:
min_choice = min(self.choice_timings, key=self.choice_timings.get) # type: ignore[arg-type]
return (min_choice, self.choice_timings[min_choice])
class CUDATemplateBuffer(TemplateBuffer):
def __init__(
self,

View File

@@ -1,7 +1,8 @@
import logging
from typing import List
from ..select_algorithm import autotune_select_algorithm, ChoiceCaller, TritonTemplate
from ..ir import ChoiceCaller
from ..select_algorithm import autotune_select_algorithm, TritonTemplate
from .mm_common import mm_args, mm_configs, mm_grid, mm_options
log = logging.getLogger(__name__)

View File

@@ -1258,7 +1258,6 @@ class Scheduler:
self.fuse_cache = {}
self.post_grad_graph_id = next(_post_grad_graph_counter)
self.nodes = []
self.available_buffer_names = {
*V.graph.graph_inputs.keys(),
*V.graph.constants.keys(),
@@ -1310,6 +1309,7 @@
self.topological_sort_schedule()
self.logged_slow_fusion = set()
self.fuse_nodes()
self.finalize_multi_template_buffers()
if config.reorder_for_compute_comm_overlap:
# Refresh node_users and inverse_users to reflect fused nodes
self.compute_node_users()
@@ -1695,7 +1695,7 @@
fusion_log.debug("===== fusion complete (%d iterations) =====", i + 1)
break
def benchmark_fused_nodes(self, nodes):
def benchmark_fused_nodes(self, nodes) -> Tuple[float, str]:
"""
Benchmark fused list of nodes and return the execution time
in milliseconds on randomly generated inputs.
@@ -1707,6 +1707,54 @@
backend = self.get_backend(device)
return backend.benchmark_fused_nodes(nodes)
def finalize_multi_template_buffers(self):
def replace_buffer(orig_node: ir.MultiTemplateBuffer, new_node: ir.Buffer):
replaced_name = new_node.name
orig_name = orig_node.get_name()
assert isinstance(orig_name, str) and isinstance(replaced_name, str)
del V.graph.name_to_buffer[replaced_name]
new_node.name = orig_name
V.graph.buffers.remove(orig_node)
V.graph.name_to_buffer[orig_name] = new_node
for i, node in enumerate(self.nodes):
if isinstance(node, SchedulerNode) and isinstance(
node.node, ir.MultiTemplateBuffer
):
multi_node = node.node
min_node_unfused, _ = multi_node.get_min_choice()
if isinstance(
min_node_unfused,
torch._inductor.ir.TritonTemplateCallerBase,
):
node.node.finalize_as_triton_caller(min_node_unfused)
continue
out_tensorbox = min_node_unfused.output_node()
out_storage = out_tensorbox.data
assert isinstance(out_storage, ir.StorageBox)
out_buffer = out_storage.data
assert isinstance(out_buffer, ir.Buffer)
out_buffer.layout = multi_node.layout
replace_buffer(multi_node, out_buffer)
new_scheduler_node = self.create_scheduler_node(out_buffer)
self.nodes[i] = new_scheduler_node
self.name_to_node[node.get_name()] = new_scheduler_node
self.name_to_fused_node[node.get_name()] = new_scheduler_node
new_scheduler_node.users = node.users
new_scheduler_node.min_order = node.min_order
new_scheduler_node.max_order = node.max_order
new_scheduler_node.last_usage = node.last_usage
for user in new_scheduler_node.users:
user.node.inverse_users.remove(node)
user.node.inverse_users.append(new_scheduler_node)
def speedup_by_fusion(self, node1, node2):
"""
If config.benchmark_fusion is False, always return True.
@@ -1749,42 +1797,82 @@
why = WhyNoFuse(node1, node2)
try:
ms1, path1 = self.benchmark_fused_nodes(node_list_1)
if math.isinf(ms1):
why("register spilling of the first kernel")
return False
def log_fusion(ms_fused, ms1, ms2):
if fusion_log.isEnabledFor(logging.DEBUG):
if ms_fused < ms1 + ms2:
fusion_log.debug(
"can fuse (benchmark): fusing %s with %s cause %sx speedup",
node1.get_names(),
node2.get_names(),
green_text(f"{(ms1 + ms2) / ms_fused:.3f}"),
)
else:
fusion_log.debug(
"cannot fuse (benchmark): fusing %s with %s cause %sx slowdown",
node1.get_names(),
node2.get_names(),
red_text(f"{ms_fused / (ms1 + ms2):.3f}"),
)
if isinstance(node1, SchedulerNode) and isinstance(
node1.node, ir.MultiTemplateBuffer
):
multi_node = node1.node
choice_timings = multi_node.choice_timings
_, ms1 = multi_node.get_min_choice()
ms2, path2 = self.benchmark_fused_nodes(node_list_2)
if math.isinf(ms2):
why("register spilling of the second kernel")
return False
ms_fused, path_fused = self.benchmark_fused_nodes(node_list_fused)
if math.isinf(ms_fused):
why("register spilling of the fused kernel")
return False
except CompilationError as e:
# workaround triton issue: https://github.com/openai/triton/issues/2151
if "Loop-carried variable" in str(e):
return True # allow fusion
else:
raise
if fusion_log.isEnabledFor(logging.DEBUG):
if ms_fused < ms1 + ms2:
fusion_log.debug(
"can fuse (benchmark): fusing %s with %s cause %sx speedup",
node1.get_names(),
node2.get_names(),
green_text(f"{(ms1 + ms2) / ms_fused:.3f}"),
)
else:
fusion_log.debug(
"cannot fuse (benchmark): fusing %s with %s cause %sx slowdown",
node1.get_names(),
node2.get_names(),
red_text(f"{ms_fused / (ms1 + ms2):.3f}"),
)
min_ms_fused = float("inf")
ms_fused_choice = None
for choice, unfused_time in choice_timings.items():
if not isinstance(choice, torch._inductor.ir.TritonTemplateCallerBase):
continue
if unfused_time >= ms1 + ms2:
continue
# TODO - parallel compile triton templates
# TODO - should prune/skip choices that are not within certain % of best choice
with node1.node.swap_as_triton_caller(choice):
ms_fused, _ = self.benchmark_fused_nodes(node_list_fused)
if ms_fused < min_ms_fused:
min_ms_fused = ms_fused
ms_fused_choice = choice
log_fusion(min_ms_fused, ms1, ms2)
# after we do a fusion, we finalize a triton template.
# TODO - could preserve multi template and choices for subsequent fusions
if min_ms_fused < (ms1 + ms2) and ms_fused_choice is not None:
node1.node.finalize_as_triton_caller(ms_fused_choice)
return True
else:
return False
else:
try:
ms1, path1 = self.benchmark_fused_nodes(node_list_1)
if math.isinf(ms1):
why("register spilling of the first kernel")
return False
ms2, path2 = self.benchmark_fused_nodes(node_list_2)
if math.isinf(ms2):
why("register spilling of the second kernel")
return False
ms_fused, path_fused = self.benchmark_fused_nodes(node_list_fused)
if math.isinf(ms_fused):
why("register spilling of the fused kernel")
return False
except CompilationError as e:
# workaround triton issue: https://github.com/openai/triton/issues/2151
if "Loop-carried variable" in str(e):
return True # allow fusion
else:
raise
log_fusion(ms_fused, ms1, ms2)
if (
is_metric_table_enabled("slow_fusion")
and ms_fused >= ms1 + ms2

View File

@@ -22,12 +22,8 @@ from torch._dynamo.utils import counters, identity, preserve_rng_state
from . import config, ir
from .autotune_process import TensorMeta, TritonBenchmarkRequest
from .codecache import code_hash, PersistentCache, PyCodeCache
from .codegen.common import (
ChoiceCaller,
IndentedBuffer,
KernelTemplate,
PrimitiveInfoType,
)
from .codegen.common import IndentedBuffer, KernelTemplate
from .codegen.triton import (
gen_common_triton_imports,
texpr,
@@ -35,8 +31,10 @@ from .codegen.triton import (
TritonPrinter,
TritonScheduling,
)
from .codegen.triton_utils import config_of, signature_to_meta
from .exc import CUDACompileError
from .ir import ChoiceCaller, PrimitiveInfoType
from .utils import (
do_bench,
get_dtype_size,
@@ -653,7 +651,7 @@ class ExternKernelChoice:
)
class TritonTemplateCaller(ChoiceCaller):
class TritonTemplateCaller(ir.TritonTemplateCallerBase):
def __init__(
self,
name,
@@ -713,6 +711,9 @@ class TritonTemplateCaller(ChoiceCaller):
"""Information returned here is logged to the autotune log file when that is enabled."""
return self.log_info
def get_make_kernel_render(self):
return self.make_kernel_render
class ExternKernelCaller(ChoiceCaller):
def __init__(
@@ -814,6 +815,7 @@ class AlgorithmSelectorCache(PersistentCache):
# generating a random torch.Tensor for benchmarking.
input_gen_fns: Optional[Dict[int, Callable[[ir.Buffer], torch.Tensor]]] = None,
precompilation_timeout_seconds: int = 60 * 60,
return_multi_template=False,
):
from .codegen.cuda.cuda_kernel import CUDATemplateCaller
@@ -911,6 +913,33 @@
or config.trace.log_autotuning_results
):
self.log_results(name, input_nodes, timings, autotune_elapse)
if return_multi_template:
min_extern_choice = float("inf")
for choice, timing in timings.items():
if isinstance(choice, ExternKernelCaller):
min_extern_choice = min(min_extern_choice, timing)
timings = {
choice: time
for choice, time in timings.items()
if (
time <= min_extern_choice
or not isinstance(choice, ExternKernelCaller)
)
}
if len(timings) == 1:
return next(iter(timings)).output_node()
return torch._inductor.ir.TensorBox.create(
torch._inductor.ir.MultiTemplateBuffer(
layout,
input_nodes,
timings,
)
)
selected_choice = builtins.min(timings, key=timings.__getitem__).output_node()
log.debug("selected choice: %s", str(selected_choice))
return selected_choice
@@ -1143,6 +1172,14 @@ def autotune_select_algorithm(*args, **kwargs):
global _ALGORITHM_SELECTOR_CACHE
if _ALGORITHM_SELECTOR_CACHE is None:
_ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache()
if "return_multi_template" not in kwargs:
# TODO - enable multi templates even if benchmark_fusion not enabled
kwargs["return_multi_template"] = (
torch._inductor.config.benchmark_multi_templates
and torch._inductor.config.benchmark_fusion
)
return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs)