Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Defer selection of triton template (#120275)
Our prior approach to epilogue fusion was to pick a choice from a set of triton templates and extern calls based on benchmarking inputs, then unconditionally fuse epilogues. This can be sub-optimal in the following ways:

- We select an extern kernel, but an epilogue such as relu() exists for which a triton template + relu would have been faster.
- We select a triton template and fuse the epilogue, but register spilling makes the fused kernel slower than not fusing at all.

In this PR we wait to select either the Triton template or the extern kernel until we have benchmarking results for the kernel itself and its epilogue. As soon as a successful fusion occurs, where a fused Triton template + epilogue is faster than the unfused choice, we finalize the MultiTemplateBuffer as that specific template. If no fusion occurs, we finalize the MultiTemplateBuffer after the fusion pass.

Note: if there are multiple epilogue fusions (not very likely), even though we select a template after the first fusion, we still benchmark whether subsequent epilogues are worth fusing. We could potentially defer choosing the template in that case as well in a follow-up, at the expense of compile time.

Gives a 4% HF training win and a 10% TIMM inference win. Increases compilation time, which I will be trying to address further in follow-up PRs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120275
Approved by: https://github.com/jansel
ghstack dependencies: #121996
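As a quick usage sketch (adapted from the test added in this PR; it assumes a CUDA device, and the kernel actually selected will vary by hardware), the deferred-selection path is exercised by enabling the config flags below and compiling a GEMM + epilogue under max-autotune:

import torch
from torch._inductor import config

with config.patch(
    {
        "benchmark_kernel": True,
        "benchmark_fusion": True,
        "benchmark_multi_templates": True,
    }
):
    m = torch.nn.Linear(256, 256, bias=True).half().cuda()
    inp = torch.rand([256, 256]).half().cuda()

    def foo(mod, x):
        # relu() is the epilogue the scheduler may fuse into whichever
        # template the MultiTemplateBuffer is finalized as
        return torch.nn.functional.relu(mod(x))

    foo_c = torch.compile(mode="max-autotune-no-cudagraphs")(foo)
    with torch.no_grad():
        out = foo_c(m, inp)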
Committed by: PyTorch MergeBot
Parent: e5e0685f61
Commit: cbbed46377
@@ -5,7 +5,7 @@ import sys
import torch
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import run_and_get_code
from torch._inductor.utils import fresh_inductor_cache, run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
IS_CI,

@@ -14,6 +14,7 @@ from torch.testing._internal.common_utils import (
slowTest,
TEST_WITH_ASAN,
)
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
# Make the helper files in test/ importable

@@ -35,6 +36,7 @@ if IS_WINDOWS and IS_CI:
sys.exit(0)
raise unittest.SkipTest("requires sympy/functorch/filelock")
from inductor.test_torchinductor import check_model, check_model_cuda, copy_tests

@@ -191,6 +193,108 @@ if HAS_CUDA and not TEST_WITH_ASAN:
copy_tests(BenchmarkFusionTestTemplate, BenchmarkFusionCudaTest, "cuda")
class BenchmarkMultiTemplateFusionCudaTest(InductorTestCase):

@classmethod
def setUpClass(cls):
super().setUpClass()
cls._stack = contextlib.ExitStack()
cls._stack.enter_context(
config.patch(
{
"benchmark_kernel": True,
"benchmark_fusion": True,
"benchmark_multi_templates": True,
}
)
)

@classmethod
def tearDownClass(cls):
cls._stack.close()
super().tearDownClass()

def _equivalent_output_code_impl(self, size, first_dim=None, activation=True):
def foo(m, inp):
a = m(inp)
if activation:
return torch.nn.functional.relu(a)
return a
foo_c = torch.compile(mode="max-autotune-no-cudagraphs")(foo)
first_dim = first_dim if first_dim is not None else size
m = torch.nn.Linear(size, size, bias=True).half().cuda()
inp = torch.rand([first_dim, size]).half().cuda()
with torch.no_grad():
res, code = run_and_get_code(foo_c, m, inp)
torch._dynamo.reset()
with unittest.mock.patch.object(
torch._inductor.config, "benchmark_multi_templates", False
):
foo_c = torch.compile(mode="max-autotune-no-cudagraphs")(foo)
with torch.no_grad():
res2, code2 = run_and_get_code(foo_c, m, inp)
self.assertEqual(res, res2, atol=1e-4, rtol=1.1)
return code, code2

@fresh_inductor_cache()
@torch._inductor.config.patch(max_autotune_gemm_backends="TRITON")
def test_equivalent_template_code(self):
code, code2 = self._equivalent_output_code_impl(256)
for out_code in [code, code2]:
FileCheck().check("def call").check_count(
"empty_strided_cuda", 1, exactly=True
).check("triton_tem_fused_relu_0.run").check_count(
"del", 3, exactly=True
).check(
"return"
).run(
out_code[0]
)

@fresh_inductor_cache()
@torch._inductor.config.patch(max_autotune_gemm_backends="ATEN")
def test_equivalent_extern_code(self):
torch._dynamo.reset()
code, code2 = self._equivalent_output_code_impl(512, 1, False)
for out_code in [code, code2]:
FileCheck().check("def call").check_count(
"empty_strided_cuda", 1, exactly=True
).check("extern_kernels.").check_count("del", 3, exactly=True).check(
"return"
).run(
out_code[0]
)

def test_changed_layout(self):
# cat addmm planning will change layout - make sure propagated
def fn(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor):
return torch.cat(
[
torch.addmm(a, b, c),
torch.addmm(b, c, a),
],
1,
)
args = [
torch.randn(4, 4, device="cuda"),
torch.randn(4, 4, device="cuda"),
torch.randn(4, 4, device="cuda"),
]
expected = fn(*args)
actual = torch.compile(fn, mode="max-autotune")(*args)
self.assertEqual(expected, actual)
torch._dynamo.reset()

if HAS_CPU and not torch.backends.mps.is_available():

class BenchmarkFusionCpuTest(TestCase):
@@ -14,11 +14,10 @@ from torch._inductor.autotune_process import (
TuningProcessPool,
)
from torch._inductor.graph import GraphLowering
from torch._inductor.ir import Buffer, FixedLayout
from torch._inductor.ir import Buffer, ChoiceCaller, FixedLayout
from torch._inductor.kernel.mm_plus_mm import aten_mm_plus_mm
from torch._inductor.select_algorithm import (
AlgorithmSelectorCache,
ChoiceCaller,
TritonTemplateCaller,
)
from torch._inductor.test_case import run_tests, TestCase
@@ -200,6 +200,7 @@ RUN_PARALLEL_BLOCKLIST = [
"test_tensorexpr",
"test_cuda_primary_ctx",
"test_cuda_trace",
"inductor/test_benchmark_fusion",
"test_cuda_nvml_based_avail",
# temporarily sets a global config
"test_autograd_fallback",
@@ -56,7 +56,7 @@ from torch.fx.experimental.symbolic_shapes import has_hint, hint_int, ShapeEnv
if TYPE_CHECKING:
from torch._inductor.graph import GraphLowering
from torch._inductor.select_algorithm import ChoiceCaller
from torch._inductor.ir import ChoiceCaller
from torch.hub import _Faketqdm, tqdm
@@ -16,7 +16,6 @@ from typing import (
Optional,
Set,
Tuple,
TYPE_CHECKING,
Union,
)

@@ -32,7 +31,6 @@ from torch.utils._sympy.value_ranges import ValueRanges
from .. import config, metrics
from ..utils import (
DeferredLineBase,
do_bench,
free_symbol_startswith,
IndentedBuffer,
sympy_dot,

@@ -42,8 +40,6 @@ from ..utils import (
)
from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V
if TYPE_CHECKING:
from ..ir import TensorBox
schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
@@ -1672,45 +1668,6 @@ def jinja2_env():
return None

PrimitiveInfoType = Union[int, float, bool, str, List[Union[int, str, float, bool]]]

class ChoiceCaller:
"""
Represents a possible choice used in autotune_process.py.
During autotuning, self.benchmark() is first called to get benchmark result,
and if this choice is selected, self.output_node() is called to get the output_node.
Children classes: TritonTemplateCaller, CUDATemplateCaller.
"""

def __init__(self, name, input_nodes, layout):
super().__init__()
self.name = name
self.layout = layout
self.input_nodes = input_nodes

def benchmark(self, *args, out) -> float:
algo = self.to_callable()
return do_bench(lambda: algo(*args, out=out))

def call_name(self) -> str:
raise NotImplementedError()

def to_callable(self):
raise NotImplementedError()

def hash_key(self) -> str:
raise NotImplementedError()

def output_node(self) -> "TensorBox":
raise NotImplementedError()

def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
"""Information returned here is logged to the autotune log file when that is enabled."""
return {}

class KernelTemplate:
"""
Base class for defining kernel templates.

@@ -1752,7 +1709,7 @@ class KernelTemplate:
except NotImplementedError:
pass

def generate(self, **kwargs) -> ChoiceCaller:
def generate(self, **kwargs) -> "torch._inductor.ir.ChoiceCaller":
"""
Generates a ChoiceCaller instance from the given arguments.
"""
@@ -3,12 +3,19 @@ from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union
from ... import ir
from ...autotune_process import CUDABenchmarkRequest
from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout, TensorBox
from ...select_algorithm import ChoiceCaller
from ...ir import (
Buffer,
ChoiceCaller,
CUDATemplateBuffer,
IRNode,
Layout,
PrimitiveInfoType,
TensorBox,
)
from ...utils import sympy_product
from ...virtualized import V
from ..common import IndentedBuffer, Kernel, OpOverrides
from ..common import IndentedBuffer, Kernel, OpOverrides, PrimitiveInfoType
from ..cpp import CppPrinter, DTYPE_TO_CPP
if TYPE_CHECKING:
@@ -31,6 +31,7 @@ import sympy
import torch
import torch._logging
import torch.utils._pytree as pytree
from torch._dynamo.utils import preserve_rng_state
from torch._inductor.metrics import is_metric_table_enabled, log_kernel_metadata
from torch._prims_common import is_integer_dtype

@@ -3848,7 +3849,18 @@ class TritonScheduling(BaseScheduling):
def ready_to_flush(self) -> bool:
return False

@preserve_rng_state()
def benchmark_fused_nodes(self, nodes):
@dataclasses.dataclass
class LastUsageHolder:
n: Any
last_usage: Any

def __del__(self):
self.n.last_usage = self.last_usage

last_usage_holders = [LastUsageHolder(n, n.last_usage) for n in nodes]
# empty last_usage. May cause more aggressive 'evict_last'. Should be fine.
for n in nodes:
n.last_usage = set()

@@ -3926,6 +3938,12 @@ class TritonScheduling(BaseScheduling):
# generating out of range indices for later calls.
ms = do_bench(lambda: call(wrapped_jit_function.clone_args(*args)[0]))

# overhead of cloning args gives bias for fusing the kernel
# in the case of mutating/in-placeable second fusion
# TODO - would be better as a hook in triton do_bench that reset
# the input values between benchmarking
ms = ms - do_bench(lambda: wrapped_jit_function.clone_args(*args))

log.debug(
"The fused kernel for %s took %.3f ms to run",
{n.get_name() for n in nodes},
@@ -284,6 +284,10 @@ debug_fusion = os.environ.get("TORCHINDUCTOR_DEBUG_FUSION") == "1"
benchmark_fusion = os.environ.get("TORCHINDUCTOR_BENCHMARK_FUSION") == "1"
enabled_metric_tables = os.environ.get("TORCHINDUCTOR_ENABLED_METRIC_TABLES", "")

benchmark_multi_templates = (
os.environ.get("TORCHINDUCTOR_BENCHMARK_MULTI_TEMPLATES", "0") == "1"
)

# how many nodes to allow into a single fusion
max_fusion_size = 64
@@ -66,6 +66,7 @@ from .utils import (
convert_shape_to_inductor,
convert_shape_to_symint,
developer_warning,
do_bench,
get_kernel_metadata,
is_dynamic,
pad_listlike,

@@ -3476,6 +3477,92 @@ class TritonTemplateBuffer(TemplateBuffer):
pass

PrimitiveInfoType = Union[int, float, bool, str, List[Union[int, str, float, bool]]]

class ChoiceCaller:
"""
Represents a possible choice used in autotune_process.py.
During autotuning, self.benchmark() is first called to get benchmark result,
and if this choice is selected, self.output_node() is called to get the output_node.
Children classes: TritonTemplateCaller, CUDATemplateCaller.
"""

def __init__(self, name, input_nodes, layout):
super().__init__()
self.name = name
self.layout = layout
self.input_nodes = input_nodes

def benchmark(self, *args, out) -> float:
algo = self.to_callable()
return do_bench(lambda: algo(*args, out=out))

def call_name(self) -> str:
raise NotImplementedError()

def to_callable(self):
raise NotImplementedError()

def hash_key(self) -> str:
raise NotImplementedError()

def output_node(self) -> "TensorBox":
raise NotImplementedError()

def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
"""Information returned here is logged to the autotune log file when that is enabled."""
return {}

class TritonTemplateCallerBase(ChoiceCaller):
def get_make_kernel_render(self) -> Any:
raise NotImplementedError()

class MultiTemplateBuffer(TritonTemplateBuffer):
"""
Represents a Buffer with multiple backing implementation choices.
Choices can be TritonTemplates or ExternKernels. During scheduling if there is a potential
epilogue we will benchmark each of the choices with the epilogue to determine an implementation.
Otherwise, the fastest base choice will be chosen.
"""

def __init__(
self,
layout: Layout,
inputs: List[IRNode],
choice_timings: Dict[ChoiceCaller, float],
):
super().__init__(layout=layout, inputs=inputs, make_kernel_render=None)
self.choice_timings = choice_timings
self.original_inputs = inputs

@contextlib.contextmanager
def swap_as_triton_caller(self, caller: TritonTemplateCallerBase):
assert isinstance(caller, torch._inductor.select_algorithm.TritonTemplateCaller)
assert self.layout == caller.layout
render = self.make_kernel_render
self.make_kernel_render = caller.get_make_kernel_render()
try:
yield
finally:
self.make_kernel_render = render

def finalize_as_triton_caller(self, caller: TritonTemplateCallerBase):
assert isinstance(caller, torch._inductor.select_algorithm.TritonTemplateCaller)
assert self.layout.size == caller.layout.size
assert self.layout.stride == caller.layout.stride
self.make_kernel_render = caller.get_make_kernel_render()

def get_min_choice(self) -> Tuple[ChoiceCaller, float]:
min_choice = min(self.choice_timings, key=self.choice_timings.get) # type: ignore[arg-type]
return (min_choice, self.choice_timings[min_choice])

class CUDATemplateBuffer(TemplateBuffer):
def __init__(
self,
@@ -1,7 +1,8 @@
import logging
from typing import List
from ..select_algorithm import autotune_select_algorithm, ChoiceCaller, TritonTemplate
from ..ir import ChoiceCaller
from ..select_algorithm import autotune_select_algorithm, TritonTemplate
from .mm_common import mm_args, mm_configs, mm_grid, mm_options

log = logging.getLogger(__name__)
@@ -1258,7 +1258,6 @@ class Scheduler:
self.fuse_cache = {}
self.post_grad_graph_id = next(_post_grad_graph_counter)
self.nodes = []
self.available_buffer_names = {
*V.graph.graph_inputs.keys(),
*V.graph.constants.keys(),

@@ -1310,6 +1309,7 @@ class Scheduler:
self.topological_sort_schedule()
self.logged_slow_fusion = set()
self.fuse_nodes()
self.finalize_multi_template_buffers()
if config.reorder_for_compute_comm_overlap:
# Refresh node_users and inverse_users to reflect fused nodes
self.compute_node_users()

@@ -1695,7 +1695,7 @@ class Scheduler:
fusion_log.debug("===== fusion complete (%d iterations) =====", i + 1)
break

def benchmark_fused_nodes(self, nodes):
def benchmark_fused_nodes(self, nodes) -> Tuple[float, str]:
"""
Benchmark fused list of nodes and return the execution time
in milliseconds on randomly generated inputs.

@@ -1707,6 +1707,54 @@ class Scheduler:
backend = self.get_backend(device)
return backend.benchmark_fused_nodes(nodes)

def finalize_multi_template_buffers(self):
def replace_buffer(orig_node: ir.MultiTemplateBuffer, new_node: ir.Buffer):
replaced_name = new_node.name
orig_name = orig_node.get_name()
assert isinstance(orig_name, str) and isinstance(replaced_name, str)
del V.graph.name_to_buffer[replaced_name]
new_node.name = orig_name
V.graph.buffers.remove(orig_node)
V.graph.name_to_buffer[orig_name] = new_node

for i, node in enumerate(self.nodes):
if isinstance(node, SchedulerNode) and isinstance(
node.node, ir.MultiTemplateBuffer
):
multi_node = node.node
min_node_unfused, _ = multi_node.get_min_choice()
if isinstance(
min_node_unfused,
torch._inductor.ir.TritonTemplateCallerBase,
):
node.node.finalize_as_triton_caller(min_node_unfused)
continue
out_tensorbox = min_node_unfused.output_node()
out_storage = out_tensorbox.data
assert isinstance(out_storage, ir.StorageBox)
out_buffer = out_storage.data
assert isinstance(out_buffer, ir.Buffer)
out_buffer.layout = multi_node.layout
replace_buffer(multi_node, out_buffer)
new_scheduler_node = self.create_scheduler_node(out_buffer)
self.nodes[i] = new_scheduler_node
self.name_to_node[node.get_name()] = new_scheduler_node
self.name_to_fused_node[node.get_name()] = new_scheduler_node
new_scheduler_node.users = node.users
new_scheduler_node.min_order = node.min_order
new_scheduler_node.max_order = node.max_order
new_scheduler_node.last_usage = node.last_usage
for user in new_scheduler_node.users:
user.node.inverse_users.remove(node)
user.node.inverse_users.append(new_scheduler_node)

def speedup_by_fusion(self, node1, node2):
"""
If config.benchmark_fusion is False, always return True.

@@ -1749,42 +1797,82 @@ class Scheduler:
why = WhyNoFuse(node1, node2)
try:
ms1, path1 = self.benchmark_fused_nodes(node_list_1)
if math.isinf(ms1):
why("register spilling of the first kernel")
return False
def log_fusion(ms_fused, ms1, ms2):
if fusion_log.isEnabledFor(logging.DEBUG):
if ms_fused < ms1 + ms2:
fusion_log.debug(
"can fuse (benchmark): fusing %s with %s cause %sx speedup",
node1.get_names(),
node2.get_names(),
green_text(f"{(ms1 + ms2) / ms_fused:.3f}"),
)
else:
fusion_log.debug(
"cannot fuse (benchmark): fusing %s with %s cause %sx slowdown",
node1.get_names(),
node2.get_names(),
red_text(f"{ms_fused / (ms1 + ms2):.3f}"),
)

if isinstance(node1, SchedulerNode) and isinstance(
node1.node, ir.MultiTemplateBuffer
):
multi_node = node1.node
choice_timings = multi_node.choice_timings
_, ms1 = multi_node.get_min_choice()
ms2, path2 = self.benchmark_fused_nodes(node_list_2)
if math.isinf(ms2):
why("register spilling of the second kernel")
return False
ms_fused, path_fused = self.benchmark_fused_nodes(node_list_fused)
if math.isinf(ms_fused):
why("register spilling of the fused kernel")
return False
except CompilationError as e:
# workaround triton issue: https://github.com/openai/triton/issues/2151
if "Loop-carried variable" in str(e):
return True # allow fusion
else:
raise

if fusion_log.isEnabledFor(logging.DEBUG):
if ms_fused < ms1 + ms2:
fusion_log.debug(
"can fuse (benchmark): fusing %s with %s cause %sx speedup",
node1.get_names(),
node2.get_names(),
green_text(f"{(ms1 + ms2) / ms_fused:.3f}"),
)
else:
fusion_log.debug(
"cannot fuse (benchmark): fusing %s with %s cause %sx slowdown",
node1.get_names(),
node2.get_names(),
red_text(f"{ms_fused / (ms1 + ms2):.3f}"),
)
min_ms_fused = float("inf")
ms_fused_choice = None

for choice, unfused_time in choice_timings.items():
if not isinstance(choice, torch._inductor.ir.TritonTemplateCallerBase):
continue
if unfused_time >= ms1 + ms2:
continue
# TODO - parallel compile triton templates
# TODO - should prune/skip choices that are not within certain % of best choice
with node1.node.swap_as_triton_caller(choice):
ms_fused, _ = self.benchmark_fused_nodes(node_list_fused)
if ms_fused < min_ms_fused:
min_ms_fused = ms_fused
ms_fused_choice = choice

log_fusion(min_ms_fused, ms1, ms2)

# after we do a fusion, we finalize a triton template.
# TODO - could preserve multi template and choices for subsequent fusions
if min_ms_fused < (ms1 + ms2) and ms_fused_choice is not None:
node1.node.finalize_as_triton_caller(ms_fused_choice)
return True
else:
return False
else:
try:
ms1, path1 = self.benchmark_fused_nodes(node_list_1)
if math.isinf(ms1):
why("register spilling of the first kernel")
return False
ms2, path2 = self.benchmark_fused_nodes(node_list_2)
if math.isinf(ms2):
why("register spilling of the second kernel")
return False
ms_fused, path_fused = self.benchmark_fused_nodes(node_list_fused)
if math.isinf(ms_fused):
why("register spilling of the fused kernel")
return False
except CompilationError as e:
# workaround triton issue: https://github.com/openai/triton/issues/2151
if "Loop-carried variable" in str(e):
return True # allow fusion
else:
raise

log_fusion(ms_fused, ms1, ms2)
if (
is_metric_table_enabled("slow_fusion")
and ms_fused >= ms1 + ms2
@@ -22,12 +22,8 @@ from torch._dynamo.utils import counters, identity, preserve_rng_state
from . import config, ir
from .autotune_process import TensorMeta, TritonBenchmarkRequest
from .codecache import code_hash, PersistentCache, PyCodeCache
from .codegen.common import (
ChoiceCaller,
IndentedBuffer,
KernelTemplate,
PrimitiveInfoType,
)
from .codegen.common import IndentedBuffer, KernelTemplate
from .codegen.triton import (
gen_common_triton_imports,
texpr,

@@ -35,8 +31,10 @@ from .codegen.triton import (
TritonPrinter,
TritonScheduling,
)
from .codegen.triton_utils import config_of, signature_to_meta
from .exc import CUDACompileError
from .ir import ChoiceCaller, PrimitiveInfoType
from .utils import (
do_bench,
get_dtype_size,

@@ -653,7 +651,7 @@ class ExternKernelChoice:
)

class TritonTemplateCaller(ChoiceCaller):
class TritonTemplateCaller(ir.TritonTemplateCallerBase):
def __init__(
self,
name,

@@ -713,6 +711,9 @@ class TritonTemplateCaller(ChoiceCaller):
"""Information returned here is logged to the autotune log file when that is enabled."""
return self.log_info

def get_make_kernel_render(self):
return self.make_kernel_render

class ExternKernelCaller(ChoiceCaller):
def __init__(

@@ -814,6 +815,7 @@ class AlgorithmSelectorCache(PersistentCache):
# generating a random torch.Tensor for benchmarking.
input_gen_fns: Optional[Dict[int, Callable[[ir.Buffer], torch.Tensor]]] = None,
precompilation_timeout_seconds: int = 60 * 60,
return_multi_template=False,
):
from .codegen.cuda.cuda_kernel import CUDATemplateCaller

@@ -911,6 +913,33 @@ class AlgorithmSelectorCache(PersistentCache):
or config.trace.log_autotuning_results
):
self.log_results(name, input_nodes, timings, autotune_elapse)

if return_multi_template:
min_extern_choice = float("inf")
for choice, timing in timings.items():
if isinstance(choice, ExternKernelCaller):
min_extern_choice = min(min_extern_choice, timing)

timings = {
choice: time
for choice, time in timings.items()
if (
time <= min_extern_choice
or not isinstance(choice, ExternKernelCaller)
)
}

if len(timings) == 1:
return next(iter(timings)).output_node()

return torch._inductor.ir.TensorBox.create(
torch._inductor.ir.MultiTemplateBuffer(
layout,
input_nodes,
timings,
)
)

selected_choice = builtins.min(timings, key=timings.__getitem__).output_node()
log.debug("selected choice: %s", str(selected_choice))
return selected_choice

@@ -1143,6 +1172,14 @@ def autotune_select_algorithm(*args, **kwargs):
global _ALGORITHM_SELECTOR_CACHE
if _ALGORITHM_SELECTOR_CACHE is None:
_ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache()

if "return_multi_template" not in kwargs:
# TODO - enable multi templates even if benchmark_fusion not enabled
kwargs["return_multi_template"] = (
torch._inductor.config.benchmark_multi_templates
and torch._inductor.config.benchmark_fusion
)

return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs)
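To summarize the scheduler-side selection logic shown in the diff above, here is a simplified standalone sketch (hypothetical helper names and simplified types, not the actual Inductor API): keep the per-choice autotuning timings, and when an epilogue fusion is attempted, benchmark each Triton choice fused with the epilogue and finalize the fastest fused choice only if it beats the best unfused choice plus a separate epilogue kernel.

from typing import Callable, Dict, Optional, Tuple

def pick_template(
    choice_timings: Dict[str, float],     # unfused time (ms) per candidate choice
    epilogue_ms: float,                   # time of the epilogue run as its own kernel
    is_triton: Callable[[str], bool],     # extern kernels cannot fuse an epilogue
    bench_fused: Callable[[str], float],  # benchmark a triton choice fused with the epilogue
) -> Tuple[Optional[str], bool]:
    """Return (chosen choice, whether the epilogue was fused into it)."""
    best_unfused, ms1 = min(choice_timings.items(), key=lambda kv: kv[1])
    ms2 = epilogue_ms

    min_ms_fused, ms_fused_choice = float("inf"), None
    for choice, unfused_time in choice_timings.items():
        if not is_triton(choice):
            continue
        if unfused_time >= ms1 + ms2:
            continue  # cannot win even before fusing the epilogue
        ms_fused = bench_fused(choice)
        if ms_fused < min_ms_fused:
            min_ms_fused, ms_fused_choice = ms_fused, choice

    if ms_fused_choice is not None and min_ms_fused < ms1 + ms2:
        return ms_fused_choice, True   # finalize as this triton template, epilogue fused
    return best_unfused, False         # keep the fastest unfused choice; no epilogue fusion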