Compare commits

...

1 Commit

Author SHA1 Message Date
0fef7b0ae2 PT2 Inductor ComboKernels (#124969)
Summary:

A ComboKernel combines independent Inductor Triton kernels into a single kernel.
Consolidation with the foreach kernel:
1) On the scheduler side, the logic is consolidated into ForeachKernelSchedulerNode.
2) On the backend, the kernel is consolidated into ComboKernel.

Example:
- elementwise kernels
Original PyTorch function:
```
 def test_activations(a, b, c):
     a1 = torch.nn.functional.relu(a)
     b1 = torch.nn.functional.sigmoid(b)
     c1 = torch.nn.functional.tanh(c)
     return a1, b1, c1
```
Generated combo kernel:
```
@triton_heuristics.pointwise(
    size_hints=[512], tile_hint=TileHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*fp32', 5: '*fp32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5), equal_to_1=())]},
    inductor_meta={'kernel_name': 'triton_poi_fused_0', 'mutated_arg_names': []}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, out_ptr2, XBLOCK : tl.constexpr):
    pid = tl.program_id(0)
    if pid % 3 == 0:
        pid_offset = pid // 3
        xnumel = 100
        rnumel = 1
        xoffset = pid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        x0 = xindex
        tmp0 = tl.load(in_ptr0 + (x0), xmask)
        tmp1 = triton_helpers.maximum(0, tmp0)
        tl.store(out_ptr0 + (x0), tmp1, xmask)
    elif pid % 3 == 1:
        pid_offset = pid // 3
        xnumel = 400
        rnumel = 1
        xoffset = pid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        x1 = xindex
        tmp2 = tl.load(in_ptr1 + (x1), xmask)
        tmp3 = tl.sigmoid(tmp2)
        tl.store(out_ptr1 + (x1), tmp3, xmask)
    elif pid % 3 == 2:
        pid_offset = pid // 3
        xnumel = 100
        rnumel = 1
        xoffset = pid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        x2 = xindex
        tmp4 = tl.load(in_ptr2 + (x2), xmask)
        tmp5 = libdevice.tanh(tmp4)
        tl.store(out_ptr2 + (x2), tmp5, xmask)
    else:
        pass
```
- reduction kernels
Original PyTorch function:
```
def test_reduce(a, b, c):
     a1 = torch.sum(a, dim=0)
     b1 = torch.max(b, dim=0)
     c1 = torch.min(c, dim=0)
     return a1, b1, c1
```
Generated combo kernel:
```
 @triton_heuristics.persistent_reduction(
     size_hints=[32, 32],
     reduction_hint=ReductionHint.DEFAULT,
     filename=__file__,
     triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: '*fp32', 4: '*i64', 5: '*fp32', 6: '*i64', 7: '*fp32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3, 4, 5, 6, 7), equal_to_1=())]},
     inductor_meta={'kernel_name': 'triton_per_fused_0', 'mutated_arg_names': []}
 )
 @triton.jit
 def triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr0, out_ptr1, out_ptr2, out_ptr3, out_ptr4, XBLOCK : tl.constexpr):
     pid = tl.program_id(0)
     if pid % 3 == 0:
         pid_offset = pid // 3
         xnumel = 20
         rnumel = 20
         RBLOCK_0: tl.constexpr = 32
         xoffset = pid_offset * XBLOCK
         xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
         xmask = xindex < xnumel
         rindex = tl.arange(0, RBLOCK_0)[None, :]
         roffset = 0
         rmask = rindex < rnumel
         r1 = rindex
         x0 = xindex
         tmp0 = tl.load(in_ptr0 + (x0 + (20*r1)), rmask & xmask, other=0.0)
         tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK_0])
         tmp3 = tl.where(rmask & xmask, tmp1, float("-inf"))
         tmp4 = triton_helpers.max2(tmp3, 1)[:, None]
         tmp6 = tl.broadcast_to(rindex, tmp3.shape)
         _, tmp5_tmp = triton_helpers.max_with_index(tmp3, tmp6, 1)
         tmp5 = tmp5_tmp[:, None]
         tl.store(out_ptr0 + (x0), tmp4, xmask)
         tl.store(out_ptr1 + (x0), tmp5, xmask)
     elif pid % 3 == 1:
         pid_offset = pid // 3
         xnumel = 10
         rnumel = 10
         RBLOCK_1: tl.constexpr = 16
         xoffset = pid_offset * XBLOCK
         xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
         xmask = xindex < xnumel
         rindex = tl.arange(0, RBLOCK_1)[None, :]
         roffset = 0
         rmask = rindex < rnumel
         r3 = rindex
         x2 = xindex
         tmp7 = tl.load(in_ptr1 + (x2 + (10*r3)), rmask & xmask, other=0.0)
         tmp8 = tl.broadcast_to(tmp7, [XBLOCK, RBLOCK_1])
         tmp10 = tl.where(rmask & xmask, tmp8, float("inf"))
         tmp11 = triton_helpers.min2(tmp10, 1)[:, None]
         tmp13 = tl.broadcast_to(rindex, tmp10.shape)
         _, tmp12_tmp = triton_helpers.min_with_index(tmp10, tmp13, 1)
         tmp12 = tmp12_tmp[:, None]
         tl.store(out_ptr2 + (x2), tmp11, xmask)
         tl.store(out_ptr3 + (x2), tmp12, xmask)
     elif pid % 3 == 2:
         pid_offset = pid // 3
         xnumel = 10
         rnumel = 10
         RBLOCK_2: tl.constexpr = 16
         xoffset = pid_offset * XBLOCK
         xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
         xmask = xindex < xnumel
         rindex = tl.arange(0, RBLOCK_2)[None, :]
         roffset = 0
         rmask = rindex < rnumel
         r5 = rindex
         x4 = xindex
         tmp14 = tl.load(in_ptr2 + (x4 + (10*r5)), rmask & xmask, other=0.0)
         tmp15 = tl.broadcast_to(tmp14, [XBLOCK, RBLOCK_2])
         tmp17 = tl.where(rmask & xmask, tmp15, 0)
         tmp18 = tl.sum(tmp17, 1)[:, None]
         tl.store(out_ptr4 + (x4), tmp18, xmask)
     else:
         pass
```

Note: ComboKernels use masks so that kernels operating on tensors of different sizes can be combined into a single kernel.
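
To make the launch math concrete, here is a minimal sketch (not part of the PR) of how the combined grid is sized: the combo kernel launches enough programs for the largest sub-kernel times the number of sub-kernels, `pid % num_kernels` routes each program to one branch, and the masks above discard out-of-range elements. The helper name and the XBLOCK value are illustrative assumptions; the real logic lives in `grid_combo_kernels` (see the runtime/triton_heuristics.py hunk below).
```
# Illustrative sketch only, not part of this diff.
def combo_grid_x(xnumels, xblock, num_kernels, min_blocks=0):
    # programs the largest sub-kernel would need on its own
    per_kernel = max((n + xblock - 1) // xblock for n in xnumels)
    # the combined kernel interleaves sub-kernels via pid % num_kernels
    return max(num_kernels * per_kernel, min_blocks)

# The three pointwise sub-kernels above have 100, 400 and 100 elements; with a
# hypothetical XBLOCK of 128 each needs at most 4 programs, so the fused kernel
# launches 3 * 4 = 12 programs and xmask masks off the excess.
print(combo_grid_x([100, 400, 100], xblock=128, num_kernels=3))  # 12
```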

Test Plan:
```
buck2 test mode/dev-nosan caffe2/test/inductor:foreach
buck2 test mode/dev-nosan caffe2/test/inductor:combo_kernels
```
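
For local runs outside Buck, a hedged usage sketch (adapted from the new ComboKernel test file in this diff; the device and tensor sizes are just the ones used there): enable the two new config flags and compile a function with independent outputs.
```
# Hedged sketch based on the new combo-kernel tests; requires CUDA.
import torch
import torch._inductor.config as inductor_config

inductor_config.combo_kernels = True            # flag added in this PR
inductor_config.benchmark_combo_kernel = False  # optional benchmarking path

def activations(a, b, c):
    a1 = torch.nn.functional.relu(a)
    b1 = torch.nn.functional.sigmoid(b)
    c1 = torch.nn.functional.tanh(c)
    return a1, b1, c1

inps = [
    torch.rand(10, 10, device="cuda"),
    torch.rand(20, 20, device="cuda"),
    torch.rand(10, 10, device="cuda"),
]
out = torch.compile(activations)(*inps)  # lowers to a single combo kernel
```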

Differential Revision: D54134695
2024-05-14 13:43:10 -07:00
13 changed files with 1432 additions and 303 deletions

View File

@ -0,0 +1,262 @@
# Owner(s): ["module: inductor"]
import sys
import unittest
import torch
import torch._inductor
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
TestCase,
)
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
from torch.testing._internal.triton_utils import requires_cuda
aten = torch.ops.aten
try:
try:
from .test_torchinductor import check_model, check_model_cuda
except ImportError:
from test_torchinductor import check_model, check_model_cuda
except (unittest.SkipTest, ImportError) as e:
sys.stderr.write(f"{type(e)}: {e}\n")
if __name__ == "__main__":
sys.exit(0)
raise
@instantiate_parametrized_tests
class ComboKernelTests(TestCase):
check_model_cuda = check_model_cuda
check_model_cpu = check_model
check_kernel_count = True
def setUp(self):
super().setUp()
torch._inductor.metrics.reset()
torch._inductor.config.combo_kernels = True
torch._inductor.config.benchmark_combo_kernel = False
def tearDown(self):
super().tearDown()
torch._inductor.metrics.reset()
@requires_cuda
def test_activation_functions(self):
def test_activations(a, b, c):
a1 = torch.nn.functional.relu(a)
b1 = torch.nn.functional.sigmoid(b)
c1 = torch.nn.functional.tanh(c)
return a1, b1, c1
inps = [
torch.rand(10, 10, device="cuda"),
torch.rand(20, 20, device="cuda"),
torch.rand(10, 10, device="cuda"),
]
out_eager = test_activations(*inps)
out_compiled = torch.compile(test_activations)(*inps)
self.assertEqual(out_eager, out_compiled)
self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
@requires_cuda
def test_reduce_functions(self):
def test_reduce(a, b, c, d):
a1 = torch.sum(a, dim=0)
b1 = torch.max(b, dim=0)
c1 = torch.min(c, dim=0)
d1 = torch.nn.functional.tanh(d)
return a1, b1, c1, d1
inps = [
torch.rand(10, 10, device="cuda"),
torch.rand(20, 20, device="cuda"),
torch.rand(10, 10, device="cuda"),
torch.rand(30, 8, device="cuda"),
]
out_eager = test_reduce(*inps)
out_compiled = torch.compile(test_reduce)(*inps)
self.assertEqual(out_eager, out_compiled)
self.assertTrue(torch._inductor.metrics.generated_kernel_count <= 2)
@requires_cuda
def test_mutated_args(self):
def test_mutated(a, b, c, d):
a.add_(1)
b.sigmoid_()
c = torch.add(c, 5)
d.tanh_()
return a, b, c, d
inps = [
torch.rand(10, 10, device="cuda"),
torch.rand(20, 20, device="cuda"),
torch.rand(10, 10, device="cuda"),
torch.rand(30, 8, device="cuda"),
]
out_eager = test_mutated(*inps)
out_compiled = torch.compile(test_mutated)(*inps)
self.assertEqual(out_eager, out_compiled)
self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
@requires_cuda
def test_reduce_split(self):
def fn(a, b):
a1 = torch.linalg.vector_norm(a)
b1 = torch.sum(b, dim=0)
return a1, b1
inps = [
torch.rand(2048, 512, device="cuda"),
torch.rand(20, 20, device="cuda"),
]
out_eager = fn(*inps)
out_compiled = torch.compile(fn)(*inps)
self.assertEqual(out_eager, out_compiled)
@requires_cuda
def test_2d_blocking_partitioning(self):
def fn(a0, a1, a2, b0, b1, b2):
c0 = torch.add(a0, b0)
c1 = torch.add(a1, b1)
c2 = torch.add(a2, b2)
return c0, c1, c2
self.check_model_cuda(
fn,
(
torch.rand(30, 20, device="cuda"),
torch.rand(40, 30, device="cuda"),
torch.rand(36, 40, device="cuda"),
torch.rand(30, 20, device="cuda"),
torch.rand(30, 40, device="cuda").t(),
torch.rand(40, 36, device="cuda").t(),
),
)
self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)
@instantiate_parametrized_tests
class ComboKernelBenchmarkTests(TestCase):
check_model_cuda = check_model_cuda
check_model_cpu = check_model
check_kernel_count = True
def setUp(self):
super().setUp()
torch._inductor.metrics.reset()
torch._inductor.config.combo_kernels = True
torch._inductor.config.benchmark_combo_kernel = True
def tearDown(self):
super().tearDown()
torch._inductor.metrics.reset()
@requires_cuda
def test_activation_benchmark(self):
def test_activations(a, b, c):
a1 = torch.nn.functional.relu(a)
b1 = torch.nn.functional.sigmoid(b)
c1 = torch.nn.functional.tanh(c)
return a1, b1, c1
inps = [
torch.rand(10, 10, device="cuda"),
torch.rand(20, 20, device="cuda"),
torch.rand(10, 10, device="cuda"),
]
out_eager = test_activations(*inps)
out_compiled = torch.compile(test_activations)(*inps)
self.assertEqual(out_eager, out_compiled)
self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5)
@requires_cuda
def test_reduce_benchmark(self):
def test_reduce(a, b, c, d):
a1 = torch.sum(a, dim=0)
b1 = torch.max(b, dim=0)
c1 = torch.min(c, dim=0)
d1 = torch.nn.functional.tanh(d)
return a1, b1, c1, d1
inps = [
torch.rand(10, 10, device="cuda"),
torch.rand(20, 20, device="cuda"),
torch.rand(10, 10, device="cuda"),
torch.rand(30, 8, device="cuda"),
]
out_eager = test_reduce(*inps)
out_compiled = torch.compile(test_reduce)(*inps)
self.assertEqual(out_eager, out_compiled)
self.assertTrue(4 < torch._inductor.metrics.generated_kernel_count <= 10)
@requires_cuda
def test_mutated_benchmark(self):
def test_mutated(a, b, c, d):
a.add_(1)
b.sigmoid_()
c = torch.add(c, 5)
d.tanh_()
return a, b, c, d
inps = [
torch.rand(10, 10, device="cuda"),
torch.rand(20, 20, device="cuda"),
torch.rand(10, 10, device="cuda"),
torch.rand(30, 8, device="cuda"),
]
out_eager = test_mutated(*inps)
out_compiled = torch.compile(test_mutated)(*inps)
self.assertEqual(out_eager, out_compiled)
self.assertEqual(torch._inductor.metrics.generated_kernel_count, 6)
@requires_cuda
def test_2d_blocking_benchmark(self):
def fn(a0, a1, a2, b0, b1, b2):
c0 = torch.add(a0, b0)
c1 = torch.add(a1, b1)
c2 = torch.add(a2, b2)
return c0, c1, c2
self.check_model_cuda(
fn,
(
torch.rand(30, 20, device="cuda"),
torch.rand(40, 30, device="cuda"),
torch.rand(36, 40, device="cuda"),
torch.rand(30, 20, device="cuda"),
torch.rand(30, 40, device="cuda").t(),
torch.rand(40, 36, device="cuda").t(),
),
)
self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
if HAS_CPU or HAS_CUDA:
run_tests(needs="filelock")

View File

@ -622,7 +622,7 @@ class ForeachTests(TestCase):
),
)
self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)
self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
@requires_cuda
@inplace_bin_ops

View File

@ -132,10 +132,10 @@ class DynamoProfilerTests(torch._inductor.test_case.TestCase):
args = (x, y)
events = self._test_profiling_kernel_names(fn, args, "_for_")
events = self._test_profiling_kernel_names(fn, args, "_poi_")
event_found = False
for event in events:
if event.name == "triton_for_fused_0":
if event.name == "triton_poi_fused_0":
event_found = True
self.assertTrue(
event.input_shapes

View File

@ -82,6 +82,7 @@ class CppWrapperCpu(WrapperCodeGen):
arg_types=None,
grid_fn: str = "grid",
triton_meta=None,
grid_extra_kwargs="",
):
"""
Generates kernel call code.

View File

@ -1,7 +1,7 @@
import functools
import os
from itertools import chain, count
from typing import Any, List, Optional, TYPE_CHECKING
from typing import Any, Callable, List, Optional, TYPE_CHECKING
import sympy
@ -159,16 +159,27 @@ class CppWrapperCuda(CppWrapperCpu):
return ", ".join(new_args)
def generate_default_grid(self, name: str, grid: List[Any], cuda: bool = True):
def generate_default_grid(
self,
name: str,
grid: List[Any],
cuda: bool = True,
grid_callable: Optional[Callable[..., Any]] = None,
**grid_extra_kwags,
):
"""
Generate grid configs for launching a CUDA kernel using the grid
function from triton_heuristics.
"""
if not cuda:
return grid
assert isinstance(grid, list), f"expected {grid=} to be a list"
assert isinstance(grid, (list, tuple)), f"expected {grid=} to be a list"
grid = [e.inner_expr if isinstance(e, SymbolicCallArg) else e for e in grid]
grid_fn = default_grid(*grid)
grid_callable = grid_callable or default_grid
if not grid_extra_kwags:
grid_fn = grid_callable(*grid)
else:
grid_fn = grid_callable(*grid, **grid_extra_kwags)
params = CudaKernelParamCache.get(name)
assert (
params is not None
@ -191,6 +202,7 @@ class CppWrapperCuda(CppWrapperCpu):
arg_types=None,
grid_fn: str = "grid",
triton_meta=None,
grid_extra_kwargs="",
):
if not cuda:
# Even in CppWrapperCuda, we may see cpp kernels

View File

@ -71,8 +71,8 @@ class CUDACombinedScheduling(BaseScheduling):
def flush(self):
return self._triton_scheduling.flush()
def codegen_foreach(self, *args, **kwargs):
return self._triton_scheduling.codegen_foreach(*args, **kwargs)
def codegen_combo_kernel(self, *args, **kwargs):
return self._triton_scheduling.codegen_combo_kernel(*args, **kwargs)
def benchmark_fused_nodes(self, nodes):
return self._triton_scheduling.benchmark_fused_nodes(nodes)
@ -81,3 +81,6 @@ class CUDACombinedScheduling(BaseScheduling):
return self._triton_scheduling.generate_kernel_code_from_nodes(
nodes, benchmark_kernel
)
def benchmark_combo_kernel(self, node_list):
return self._triton_scheduling.benchmark_combo_kernel(node_list)

View File

@ -1296,6 +1296,7 @@ class TritonKernel(Kernel):
reduction_hint=ReductionHint.DEFAULT,
min_elem_per_thread=0,
disable_persistent_reduction=False,
optimize_mask=True,
):
if pid_cache is None:
pid_cache = {}
@ -1317,6 +1318,7 @@ class TritonKernel(Kernel):
self.block_ptr_id = itertools.count()
# buffer accesses in the kernel
self.buf_accesses: DefaultDict[str, List[Dep]] = collections.defaultdict(list)
self.optimize_mask = optimize_mask
self.persistent_reduction: bool = (
not disable_persistent_reduction
@ -1751,9 +1753,11 @@ class TritonKernel(Kernel):
if isinstance(index, sympy.Integer):
expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str()
index_str = f"tl.full({expand_str}, {index_str}, tl.int32)"
return IndexingOptions(index_str, set(), "None", expand_str, has_rindex)
if self.optimize_mask:
return IndexingOptions(index_str, set(), "None", expand_str, has_rindex)
mask_vars = dense_mask_vars
if need_dense and not have_dense:
elif need_dense and not have_dense:
expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str()
index_str = f"tl.broadcast_to({index_str}, {expand_str})"
mask_vars = dense_mask_vars
@ -1785,6 +1789,8 @@ class TritonKernel(Kernel):
return trees
def filter_masks(self, mask_vars):
if not self.optimize_mask:
return
for tree in self.range_trees:
# Masks are superfluous if we only have one element
if V.graph.sizevars.statically_known_equals(tree.numel, 1): # type: ignore[arg-type]
@ -3759,34 +3765,53 @@ class TritonScheduling(BaseScheduling):
def codegen_sync(self):
V.graph.wrapper_code.writeline(V.graph.device_ops.synchronize())
def codegen_foreach(self, foreach_node):
from .triton_foreach import ForeachKernel
def codegen_combo_kernel(self, combo_kernel_node):
from .triton_combo_kernel import ComboKernel
for partitions_with_metadata in ForeachKernel.horizontal_partition(
foreach_node.get_subkernel_nodes(), self
):
kernel = ForeachKernel()
for nodes, tiled_groups, numel, rnumel in partitions_with_metadata:
node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
(
reduction_hint_val,
mutations,
index_dtype,
) = self.get_kernel_args(node_schedule, numel, rnumel)
subkernel_nodes = combo_kernel_node.get_subkernel_nodes()
fused_node_lists = [node.get_nodes() for node in subkernel_nodes]
subkernel_map, node_schedule_map = {}, {}
for pn, nodes in zip(subkernel_nodes, fused_node_lists):
_, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group
node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
tiled_groups = self.select_tiling(node_schedule, numel, rnumel)
node_schedule_map[pn] = node_schedule, tiled_groups, numel, rnumel
(
reduction_hint_val,
mutations,
index_dtype,
) = self.get_kernel_args(node_schedule, numel, rnumel)
subkernel_map[pn] = ComboKernel.create_triton_kernel(
*tiled_groups,
reduction_hint=reduction_hint_val,
mutations=mutations,
index_dtype=index_dtype,
)
subkernel = kernel.create_sub_kernel(
*tiled_groups,
reduction_hint=reduction_hint_val,
mutations=mutations,
index_dtype=index_dtype,
)
partitions = ComboKernel.horizontal_partition(
nodes=combo_kernel_node.get_subkernel_nodes(),
triton_scheduling=self,
custom_algorithm=combo_kernel_node.use_custom_partition_algo,
kernel_map=subkernel_map,
node_info_map=node_schedule_map,
)
log.debug(
"ComboKernels: %d nodes partitioned into %s groups",
len(combo_kernel_node.get_subkernel_nodes()),
[len(p) for p in partitions],
)
for node_group in partitions:
fused_node_lists = [node.get_nodes() for node in node_group]
kernel = ComboKernel()
for pn, nodes in zip(node_group, fused_node_lists):
self.codegen_node_schedule_with_kernel(
node_schedule,
subkernel,
node_schedule_map[pn][0],
kernel.create_sub_kernel(subkernel_map[pn]),
)
with V.set_kernel_handler(subkernel):
subkernel = subkernel_map[pn]
node_schedule = node_schedule_map[pn][0]
with V.set_kernel_handler(subkernel): # type: ignore[call-arg]
for node in node_schedule:
if node not in (EnableReduction, DisableReduction):
node.mark_run()
@ -3794,8 +3819,9 @@ class TritonScheduling(BaseScheduling):
V.graph.inplaced_to_remove |= subkernel.inplaced_to_remove
src_code = kernel.codegen_kernel()
kernel_name = self.define_kernel(src_code, [foreach_node])
self.codegen_comment([foreach_node])
kernel_name = self.define_kernel(src_code, [combo_kernel_node])
self.codegen_comment([combo_kernel_node])
log.debug("ComboKernels: generated kernel %s.", kernel_name)
kernel.call_kernel(V.graph.wrapper_code, kernel_name)
self.scheduler.free_buffers()
@ -4062,6 +4088,134 @@ class TritonScheduling(BaseScheduling):
store_cache()
return ms, mod.__file__
def benchmark_combo_kernel(self, node_list):
from .triton_combo_kernel import ComboKernel
def cache_file_path():
assert mod.__file__ is not None
return os.path.splitext(mod.__file__)[0] + ".kernel_perf"
def load_cache():
path = cache_file_path()
if os.path.exists(path):
with open(path) as fd:
return tuple(float(e) for e in fd.read().split())
return (None, None)
def store_cache():
path = cache_file_path()
with open(path, "w") as fd:
fd.write(str(ms) + " " + str(ms_clone))
subkernel_nodes = node_list
fused_node_lists = [node.get_nodes() for node in subkernel_nodes]
subkernel_map, node_schedule_map = {}, {}
for pn, nodes in zip(subkernel_nodes, fused_node_lists):
_, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group
node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
tiled_groups = self.select_tiling(node_schedule, numel, rnumel)
(
reduction_hint_val,
mutations,
index_dtype,
) = self.get_kernel_args(node_schedule, numel, rnumel)
node_schedule_map[pn] = (node_schedule, tiled_groups, numel, rnumel)
subkernel_map[pn] = ComboKernel.create_triton_kernel(
*tiled_groups,
reduction_hint=reduction_hint_val,
mutations=mutations,
index_dtype=index_dtype,
)
partitions = ComboKernel.horizontal_partition(
nodes=subkernel_nodes,
triton_scheduling=self,
kernel_map=subkernel_map,
node_info_map=node_schedule_map,
custom_algorithm=True,
)
log.debug(
"ComboKernels: %d nodes partitioned into %s groups",
len(subkernel_nodes),
[len(p) for p in partitions],
)
total_ms, file_list = 0, []
total_clone_ms = 0
removed_buffers_orig = V.graph.removed_buffers
V.graph.removed_buffers = set(removed_buffers_orig)
inplaced_to_remove_orig = V.graph.inplaced_to_remove
V.graph.inplaced_to_remove = set(inplaced_to_remove_orig)
for node_group in partitions:
fused_node_lists = [node.get_nodes() for node in node_group]
kernel = ComboKernel()
names = [n.get_names() for nodes in fused_node_lists for n in nodes]
for pn, nodes in zip(node_group, fused_node_lists):
# empty last_usage. May cause more aggressive 'evict_last'. Should be fine.
for n in nodes:
n.last_usage = set()
self.codegen_node_schedule_with_kernel(
node_schedule_map[pn][0],
kernel.create_sub_kernel(subkernel_map[pn]),
)
subkernel = subkernel_map[pn]
V.graph.removed_buffers |= subkernel.removed_buffers
V.graph.inplaced_to_remove |= subkernel.inplaced_to_remove
with config.patch("benchmark_kernel", True), V.set_kernel_handler(kernel):
src_code = kernel.codegen_kernel()
src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_")
mod = PyCodeCache.load(src_code)
log.debug(
"kernel src code for %s written to: %s",
names,
mod.__file__,
)
ms, ms_clone = load_cache()
if ms is not None:
total_ms += ms
total_clone_ms += ms_clone
file_list.append(mod.__file__)
continue
args = mod.get_args()
call = mod.call
wrapped_jit_function = mod.triton_
# call once to trigger the compilation
call(wrapped_jit_function.clone_args(*args)[0])
launchers = wrapped_jit_function.launchers
assert len(launchers) == 1
if launchers[0].n_spills > 0:
# skip benchmarking the kernel if there are register spills
ms = ms_clone = float("inf")
else:
# We have to clone the inplace updated arguments to avoid earlier calls
# generating out of range indices for later calls.
ms = do_bench_gpu(
lambda: call(wrapped_jit_function.clone_args(*args)[0])
)
ms_clone = do_bench_gpu(
lambda: wrapped_jit_function.clone_args(*args)[0]
)
log.debug(
"The fused kernel for %s took %.3f ms to run, %.3f ms to clone inputs",
{n.get_name() for n in nodes},
ms,
ms_clone,
)
store_cache()
total_ms += ms
total_clone_ms += ms_clone
file_list.append(mod.__file__)
V.graph.removed_buffers = removed_buffers_orig
V.graph.inplaced_to_remove = inplaced_to_remove_orig
return total_ms, total_clone_ms, file_list
@dataclasses.dataclass
class CandidateTiling:

View File

@ -0,0 +1,675 @@
import itertools
import logging
import textwrap
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Tuple
from sympy import Integer
from .. import config, metrics
from ..runtime.hints import DeviceProperties
from ..runtime.runtime_utils import next_power_of_2
from ..runtime.triton_heuristics import grid_combo_kernels
from ..scheduler import FusedSchedulerNode
from ..utils import Placeholder
from ..virtualized import V
from .common import DeferredLine, IndentedBuffer, Kernel, PythonPrinter, SizeArg
from .triton import gen_common_triton_imports, TritonKernel
from .triton_utils import config_of, signature_to_meta
log = logging.getLogger(__name__)
pexpr = PythonPrinter().doprint
LARGE_NUMELS = 512e5
def _default_custom_combo_kernel_horizontal_partition(
nodes, triton_scheduling, kernel_map, node_info_map
):
"""Horizontally partition the given list of nodes into a list of list of nodes where each sublist
represents a partion. Nodes in different partitions are implemented in different combo kernels.
Nodes in the same partition are likely to be implemented
in the same combo kernel, but subject to subsequent restricts like CUDA limits for number of args.
Input arguments:
nodes: a list of fused scheduler nodes to partition.
triton_scheduling: TritonScheduling instance.
kernel_map: a map from node to its kernel.
node_info_map: a map from node to (node_schedule, tiled_groups, numel, rnumel).
Output:
a list of lists of nodes, with each sublist representing a partition.
The default algorithm is to partition nodes based on the following rules:
1) nodes with the same number of block dimensions are grouped together.
2) large pointwise nodes are separated from other nodes.
3) large reduction nodes are separated from other nodes.
"""
assert len(nodes) >= 1
# first partition nodes based on number of block dimensions
tilings = [node_info_map[n][1] for n in nodes]
max_dims = max(len(t) for t in tilings)
nodes_per_ndim = []
for i in range(2, max_dims + 1):
group_per_dim = [n for n, t in zip(nodes, tilings) if len(t) == i]
reduction = [
n
for n in group_per_dim
if kernel_map[n].inside_reduction
and not (kernel_map[n].persistent_reduction and kernel_map[n].no_x_dim)
]
not_reduction = [n for n in group_per_dim if n not in reduction]
# rnumel > 2048 usually has long execution time
long_reduction = [n for n in reduction if n.group[-1][-1] > 2048]
short_reduction = [n for n in reduction if n not in long_reduction]
if long_reduction:
log.warning(
"ComboKernels: %d long reduction nodes are separated",
len(long_reduction),
)
large_pointwise = [
n
for n in not_reduction
if not kernel_map[n].inside_reduction
and len(kernel_map[n].numels) == 2
and V.graph.sizevars.size_hint(kernel_map[n].numels[0]) > LARGE_NUMELS
]
if large_pointwise:
# TODO benchmark the performance when large pointwise nodes combining with others
log.warning(
"ComboKernels: %d large pointwise nodes are separated",
len(large_pointwise),
)
not_reduction = [n for n in not_reduction if n not in large_pointwise]
for node in large_pointwise:
nodes_per_ndim.append([node])
for g in (not_reduction, short_reduction, long_reduction):
if g:
nodes_per_ndim.append(g)
assert sum(len(p) for p in nodes_per_ndim) == len(nodes)
return nodes_per_ndim
_custom_combo_kernel_horizontal_partition_algorithm = (
_default_custom_combo_kernel_horizontal_partition
)
def set_custom_combo_kernel_horizontal_partition(algorithm):
"""Sets the algorithm used to partition nodes into horizontal partitions. Nodes in different partitions
are implemented in different combo kernels. Nodes in the same partition are likely to be implemented
in the same combo kernel, but subject to subsequent restrictions like CUDA limits on the number of args.
The algorithm should take a list of nodes and return a list of lists of nodes.
The default algorithm is to partition nodes based on number of block dimensions.
"""
global _custom_combo_kernel_horizontal_partition_algorithm
_custom_combo_kernel_horizontal_partition_algorithm = algorithm
@dataclass
class PartitionState:
partitions: List[List[FusedSchedulerNode]]
cur_partition: List[FusedSchedulerNode]
cur_count: int
def finalize(self):
if self.cur_partition:
self.partitions.append(self.cur_partition)
class ComboKernel(Kernel):
MAX_NUM_ARGS = 250 # number where I would no longer get triton errors
@staticmethod
def _update_partition(partition_state, node_rw_count, node_info):
if partition_state.cur_count + node_rw_count > ComboKernel.MAX_NUM_ARGS:
partition_state.partitions.append(partition_state.cur_partition)
partition_state.cur_partition = [node_info]
partition_state.cur_count = node_rw_count
else:
partition_state.cur_count += node_rw_count
partition_state.cur_partition.append(node_info)
@staticmethod
def _base_horizontal_partition(subkernel_nodes, triton_scheduling, node_info_map):
"""Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnumel)
for each subkernel node where each sublist is guaranteed to not exceed CUDA limits for number of args
(read/writes) and to have the same 2D or 1D blocking strategy."""
# TODO support combination of kernels with different block dimensions
assert len(subkernel_nodes) >= 1
ndim_to_partition_state: Dict[int, PartitionState] = defaultdict(
lambda: PartitionState([], [], 0)
)
for node in subkernel_nodes:
node_schedule, tiled_groups, numel, rnumel = node_info_map[node]
node_info = node
read_writes = node.read_writes
read_write_count = len(read_writes.reads) + len(read_writes.writes)
ndim = len(tiled_groups)
partition_state = ndim_to_partition_state[ndim]
ComboKernel._update_partition(partition_state, read_write_count, node_info)
all_partitions = []
for partition_state in ndim_to_partition_state.values():
partition_state.finalize()
all_partitions.extend(partition_state.partitions)
return all_partitions
@staticmethod
def horizontal_partition(
nodes, triton_scheduling, kernel_map, node_info_map, custom_algorithm=False
):
"""Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnum)
for each subkernel node where each sublist forms a ComboKernel. It horizontally partitions nodes into
sublists in the following way:
1) call _custom_combo_kernel_horizontal_partition_algorithm() if custom_algorithm is True
2) then, call _base_horizontal_partition() to partition nodes into sublists, each sublist is
guaranteed to not exceed CUDA limits for number of args (read/writes) and to have the same
2D or 1D blocking strategy.
"""
if custom_algorithm:
raw_partitions = _custom_combo_kernel_horizontal_partition_algorithm(
nodes, triton_scheduling, kernel_map, node_info_map
)
else:
raw_partitions = [nodes]
"""Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnumel)
for each subkernel node where each sublist is guaranteed to not exceed CUDA limits for number of args
(read/writes) and to have the same 2D or 1D blocking strategy."""
all_partitions = []
for raw_partition in raw_partitions:
all_partitions.extend(
ComboKernel._base_horizontal_partition(
raw_partition, triton_scheduling, node_info_map
)
)
return all_partitions
def __init__(self, use_custom_partition_algo=False):
super().__init__()
self.sub_kernels = []
self.iter_vars_count = itertools.count()
self.block_count = 0
self.grids = []
self.min_x_blocks_list = []
self.use_custom_partition_algo = use_custom_partition_algo
def codegen_pid_range(self, code, num):
num_kernels = len(self.sub_kernels)
if self.block_count == 0:
cond = "if"
else:
cond = "elif"
code.splice(f"{cond} pid % {num_kernels} == {num}:")
with code.indent():
code.splice(f"pid_offset = pid // {num_kernels}")
self.block_count += 1
def create_sub_kernel(self, triton_kernel):
sub_kernel = triton_kernel
metrics.generated_kernel_count -= 1
sub_kernel.args = self.args
sub_kernel.iter_vars_count = self.iter_vars_count
sub_kernel.cse.iter_buffer_ids = self.cse.iter_buffer_ids
self.sub_kernels.append(sub_kernel)
return sub_kernel
@staticmethod
def create_triton_kernel(*groups, index_dtype, mutations, reduction_hint):
return TritonKernel(
*groups,
index_dtype=index_dtype,
mutations=mutations,
pid_cache={"tl.program_id(0)": "pid_offset"},
reduction_hint=reduction_hint,
optimize_mask=False,
)
def codegen_static_numels_sub_kernel(self, code, sub_kernel, num):
"""
We get a small speedup from hard coding numels if they are static.
This code stomps on the passed-in values by writing a constant to the top of the kernel.
In a kernel like:
def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
We would add
xnumel = 4096
rnumel = 768
After the signature, before the kernel code, if we decided to make these static. As it's hardcoded, it becomes
a better signal to Triton on how to unroll and do some static indexing. So, it's not so much that downstream
knows that it's a static numel, as that you just plop a constant into the kernel.
"""
grid = []
uniquify_block_sizes = []
for tree in sub_kernel.range_trees:
simplified_tree_numel = V.graph.sizevars.simplify(tree.numel)
code.writeline(f"{tree.prefix}numel = {int(simplified_tree_numel)}")
if tree.prefix != "r":
grid.append(int(simplified_tree_numel))
if tree.prefix == "r" and sub_kernel.persistent_reduction:
if isinstance(simplified_tree_numel, (Integer, int)):
val = int(simplified_tree_numel)
else:
continue
val = next_power_of_2(val)
code.writeline(f"RBLOCK_{num}: tl.constexpr = {val}")
uniquify_block_sizes.append("RBLOCK")
if tree.prefix == "x" and sub_kernel.no_x_dim:
code.writeline(f"XBLOCK_{num}: tl.constexpr = 1")
uniquify_block_sizes.append("XBLOCK")
self.grids.append(grid)
return uniquify_block_sizes
def min_x_blocks_sub_kernel(self, sub_kernel, num):
"""
Kernels with no_x_dim need one Triton program per x element, so record xnumel as the
minimum number of x blocks for this sub-kernel. call_kernel() uses the maximum across
sub-kernels (times the number of sub-kernels) to size the launch grid.
"""
min_x_blocks = 0
for tree in sub_kernel.range_trees:
simplified_tree_numel = V.graph.sizevars.simplify(tree.numel)
if tree.prefix == "x" and sub_kernel.no_x_dim:
min_x_blocks = int(simplified_tree_numel)
self.min_x_blocks_list.append(min_x_blocks)
def select_heuristics(self, sub_kernel) -> Tuple[str, List[int]]:
size_hints = [
next_power_of_2(V.graph.sizevars.size_hint(numel))
for numel in sub_kernel.numels
]
if sub_kernel.persistent_reduction:
assert sub_kernel.inside_reduction
heuristics = "persistent_reduction"
elif sub_kernel.inside_reduction:
heuristics = "reduction"
else:
size_hints.pop()
heuristics = "pointwise"
return heuristics, size_hints
def select_combo_heuristics(self, heuristics_list, size_hints_list):
if "reduction" in heuristics_list:
i, _ = max(
enumerate(size_hints_list),
key=lambda x: x[1][0] if heuristics_list[x[0]] == "reduction" else 0,
)
return heuristics_list[i], size_hints_list[i], self.sub_kernels[i]
elif "pointwise" in heuristics_list:
i, _ = max(
enumerate(size_hints_list),
key=lambda x: x[1][0] if heuristics_list[x[0]] == "pointwise" else 0,
)
# modify size_hint to avoid oom check fail (may be a false alarm)
num_pointwise = len([e for e in heuristics_list if e == "pointwise"])
num_reduction = len([e for e in heuristics_list if e == "reduction"])
num_persistent_reduction = len(
[e for e in heuristics_list if e == "persistent_reduction"]
)
assert (
num_reduction == 0
), "combining pointwise and reduction are not supported yet."
heuristics = (
"pointwise_with_reduction"
if num_persistent_reduction > 0
else "pointwise"
)
if len(heuristics_list) - num_pointwise >= 4:
size_hints = size_hints_list[i]
size_hints[0] = min(128, size_hints[0])
return heuristics, size_hints_list[i], self.sub_kernels[i]
else:
return heuristics_list[0], size_hints_list[0], self.sub_kernels[0]
def get_mutated_args_sub_kernels(self) -> List[str]:
mutated_args = set()
for sub_kernel in self.sub_kernels:
for mutation in sub_kernel.mutations:
if mutation in sub_kernel.args.input_buffers:
mutated_args.add(sub_kernel.args.input_buffers[mutation])
if (
mutation in sub_kernel.args.inplace_buffers
and mutation not in V.graph.removed_buffers
and mutation not in sub_kernel.removed_buffers
):
mutated_args.add(
sub_kernel.args.inplace_buffers[mutation].inner_name
)
if mutation in sub_kernel.args.output_buffers:
mutated_args.add(sub_kernel.args.output_buffers[mutation])
return sorted(mutated_args)
def jit_line(
self, heuristics, size_hints, selected_kernel, pointwise_with_reduce=False
):
# TODO: is it correct to use the first sub kernel's heuristics?
_, _, signature = self.args.python_argdefs()
# TODO Is it ok to just use sub_kernel[0].index_dtype?
index_dtype = self.sub_kernels[0].index_dtype
for i, sub in enumerate(self.sub_kernels):
self.min_x_blocks_sub_kernel(sub, i)
triton_meta = {
"signature": signature_to_meta(signature, size_dtype=index_dtype),
"device": DeviceProperties.create(V.graph.scheduler.current_device),
"constants": {},
}
triton_meta["configs"] = [config_of(signature)]
mutated_args = self.get_mutated_args_sub_kernels()
inductor_meta = {
"kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
"mutated_arg_names": mutated_args,
**TritonKernel.inductor_meta_common(),
}
sub_kernel = selected_kernel
if sub_kernel.inside_reduction:
reduction_hint = sub_kernel.reduction_hint
heuristics_line = f"""
@triton_heuristics.{heuristics}(
size_hints={size_hints!r},
reduction_hint={reduction_hint},
filename=__file__,
triton_meta={triton_meta!r},
inductor_meta={inductor_meta}
)
@triton.jit
"""
else:
tile_hint = ""
if len(size_hints) == 2:
# TODO only input, output and 2 args
tile_hint = "tile_hint=TileHint.SQUARE,"
else:
tile_hint = "tile_hint=TileHint.DEFAULT,"
heuristics_line = f"""
@triton_heuristics.{heuristics}(
size_hints={size_hints!r}, {tile_hint}
filename=__file__,
triton_meta={triton_meta!r},
inductor_meta={inductor_meta!r}
)
@triton.jit
"""
return heuristics_line
def add_blockd_to_args(self, argdefs):
block_args = {}
for num, sub_kernel in enumerate(self.sub_kernels):
# TODO: we assume all sub_kernels have the same block size
for tree in sub_kernel.range_trees:
if tree.prefix == "r" and (
not sub_kernel.inside_reduction or sub_kernel.persistent_reduction
):
continue
if tree.prefix == "x" and sub_kernel.no_x_dim:
continue
# argdefs.append(f"{tree.prefix.upper()}BLOCK_{num} : tl.constexpr")
block_args[f"{tree.prefix.upper()}BLOCK : tl.constexpr"] = tree.prefix
for arg in block_args:
argdefs.append(arg)
return argdefs
def codegen_kernel(self, name=None):
# TODO: is it correct to use the first sub kernel's heuristics?
heuristics_list, size_hints_list = [], []
for subkernel in self.sub_kernels:
h, s = self.select_heuristics(subkernel)
heuristics_list.append(h)
size_hints_list.append(s)
heuristics, size_hints, selected_kernel = self.select_combo_heuristics(
heuristics_list, size_hints_list
)
pointwise_with_reduction, heuristics = (
(True, "pointwise")
if heuristics == "pointwise_with_reduction"
else (False, heuristics)
)
code = IndentedBuffer()
code.splice(gen_common_triton_imports())
if config.benchmark_combo_kernel:
code.splice(self.imports_for_benchmark_kernel())
argdefs, _, _ = self.args.python_argdefs()
argdefs = self.add_blockd_to_args(argdefs)
code.splice(
self.jit_line(
heuristics,
size_hints,
selected_kernel,
pointwise_with_reduce=pointwise_with_reduction,
)
)
code.writeline(
f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):"
)
with code.indent():
code.splice("pid = tl.program_id(0)")
for num, sub_kernel in enumerate(self.sub_kernels):
self.codegen_pid_range(code, num)
with code.indent():
uniquify = self.codegen_static_numels_sub_kernel(
code, sub_kernel, num
)
sub_kernel.codegen_body()
uniquified_body = self.uniquify_block_sizes(
sub_kernel.body, num, uniquify
)
code.splice(uniquified_body)
code.splice("else:")
with code.indent():
code.splice("pass")
if config.benchmark_combo_kernel:
code.splice(self.codegen_kernel_benchmark(num_gb=0))
return code.getvalue()
def codegen_kernel_benchmark(self, num_gb, grid=None):
result = IndentedBuffer()
argdefs, call_args, signature = self.args.python_argdefs()
result.writelines(["", "", "def get_args():"])
with result.indent():
name_cnt = itertools.count()
var_names = []
for arg_name, arg_sig in zip(call_args, signature):
var_name = f"arg_{next(name_cnt)}"
buf = V.graph.get_buffer(arg_name)
if buf:
result.writeline(
f"{var_name} = rand_strided({V.graph.sizevars.size_hints(buf.get_size())}, {V.graph.sizevars.size_hints(buf.get_stride())}, device='{buf.get_device()}', dtype={buf.get_dtype()})" # noqa: B950 line too long
)
elif arg_name in V.graph.constants:
# note that random seed is put in V.graph.constants
const_tensor = V.graph.constants[arg_name]
result.writeline(
f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # type: ignore[arg-type] # noqa: B950 line too long
)
elif isinstance(arg_sig, SizeArg):
symval_hint = V.graph.sizevars.size_hint(arg_sig.expr)
# Force the seed_offset to be 0 so calls to the same kernel
# using different seed offset will have the same benchmark harness.
# We can dedup kernel definitions in this case.
if "seed_offset" in arg_sig.name:
symval_hint = 0
result.writeline(f"{var_name} = {symval_hint}")
else:
raise KeyError(
f"Don't find the buffer or const tensor for {arg_name}"
)
var_names.append(var_name)
result.writeline(f"return {', '.join(var_names)},")
result.writelines(["\n", "\n", "def call(args):"])
if grid is None:
grid = self.grid(self.grids)
grid_str = ", ".join(pexpr(item) for item in grid)
grid_extra_kwargs = f"num_kernels={len(self.sub_kernels)}, min_blocks={max(self.min_x_blocks_list) * len(self.sub_kernels)}"
grid_str = f"{grid_str}, {grid_extra_kwargs}"
grid_arg = f"grid=grid_combo_kernels({grid_str})"
else:
grid_arg = f"grid={grid}"
index = V.graph.scheduler.current_device.index
with result.indent():
result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")
with result.indent():
result.writeline(
V.graph.device_ops.set_device(index)
) # no-op to ensure context
stream_name = f"stream{index}"
result.writeline(f"{stream_name} = get_raw_stream({index})")
result.writeline(
f"{str(Placeholder.KERNEL_NAME)}.run(*args, {grid_arg}, stream={stream_name})"
)
# benchmark all configs
result.writelines(["\n", "\n", "def benchmark_all_configs(args):"])
with result.indent():
result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")
with result.indent():
result.writeline(
V.graph.device_ops.set_device(index)
) # no-op to ensure context
result.writeline(
f"return {str(Placeholder.KERNEL_NAME)}.benchmark_all_configs(*args, {grid_arg})"
)
result.writelines(["\n", "\n", "if __name__ == '__main__':"])
with result.indent():
result.writeline("from triton.testing import do_bench")
result.writeline("")
result.writeline("args = get_args()")
result.writeline(
"ms = do_bench(lambda: call(args), rep=40, fast_flush=True)"
)
result.writeline(f"num_gb = {num_gb}")
result.writeline("gb_per_s = num_gb / (ms / 1e3)")
result.writeline(
'print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")'
)
return result
def imports_for_benchmark_kernel(self):
return textwrap.dedent(
"""
from torch._dynamo.testing import rand_strided
{}
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, grid_combo_kernels
""".format(
V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
)
)
def uniquify_block_sizes(
self, code: IndentedBuffer, num_kernel, uniquify: List[str]
) -> IndentedBuffer:
if not uniquify:
return code
modified = IndentedBuffer(initial_indent=code._indent)
for line in code._lines:
if isinstance(line, str) and (blocks := [e for e in uniquify if e in line]):
modified_line = line
for block in blocks:
modified_line = modified_line.replace(
block, f"{block}_{num_kernel}"
)
modified.writeline(modified_line)
elif isinstance(line, DeferredLine) and (
blocks := [e for e in uniquify if e in line.line]
):
modified_line = line.line
for block in blocks:
modified_line = modified_line.replace(
block, f"{block}_{num_kernel}"
)
new_line = DeferredLine(line.name, modified_line)
modified.writeline(new_line)
else:
modified.writeline(line)
return modified
def call_kernel(self, code, name: str):
_, call_args, _ = self.args.python_argdefs()
# dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
for i in range(len(call_args)):
if V.graph.is_unspec_arg(call_args[i]):
call_args[i] = call_args[i] + ".item()"
wrapper = V.graph.wrapper_code
grid = self.grid(self.grids)
grid = wrapper.generate_default_grid(
name,
grid,
grid_callable=grid_combo_kernels,
num_kernels=len(self.sub_kernels),
min_blocks=max(self.min_x_blocks_list) * len(self.sub_kernels),
)
wrapper.generate_kernel_call(
name,
call_args,
grid,
V.graph.scheduler.current_device.index,
cuda=True,
triton=True,
grid_fn="grid_combo_kernels",
grid_extra_kwargs=f"num_kernels={len(self.sub_kernels)}, min_blocks={max(self.min_x_blocks_list) * len(self.sub_kernels)}",
)
def grid(self, sub_kernel_numels):
xnumel = [e[-1] if len(e) > 0 else None for e in sub_kernel_numels]
ynumel = [e[-2] if len(e) > 1 else None for e in sub_kernel_numels]
znumel = [e[-3] if len(e) > 2 else None for e in sub_kernel_numels]
# TODO: improve 1d/2d mixed cases
xnumel = None if any(e is None for e in xnumel) else max(xnumel)
ynumel = None if any(e is None for e in ynumel) else max(ynumel)
znumel = None if any(e is None for e in znumel) else max(znumel)
numels = (
(xnumel,)
if not ynumel
else (ynumel, xnumel)
if not znumel
else (znumel, ynumel, xnumel)
)
return numels

View File

@ -1,248 +0,0 @@
import itertools
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Tuple
from sympy import Integer
from .. import metrics
from ..runtime.hints import DeviceProperties
from ..scheduler import SchedulerNode
from ..utils import ceildiv, Placeholder
from ..virtualized import V
from .common import IndentedBuffer, Kernel
from .triton import gen_common_triton_imports, TritonKernel
from .triton_utils import config_of, signature_to_meta
@dataclass
class PartitionState:
partitions: List[
List[Tuple[List[SchedulerNode], Tuple[Integer, ...], Integer, Integer]]
]
cur_partition: List[
Tuple[List[SchedulerNode], Tuple[Integer, ...], Integer, Integer]
]
cur_count: int
def finalize(self):
if self.cur_partition:
self.partitions.append(self.cur_partition)
class ForeachKernel(Kernel):
MAX_NUM_ARGS = 250 # number where I would no longer get triton errors
@staticmethod
def _update_partition(partition_state, node_rw_count, node_info):
if partition_state.cur_count + node_rw_count > ForeachKernel.MAX_NUM_ARGS:
partition_state.partitions.append(partition_state.cur_partition)
partition_state.cur_partition = [node_info]
partition_state.cur_count = node_rw_count
else:
partition_state.cur_count += node_rw_count
partition_state.cur_partition.append(node_info)
@staticmethod
def horizontal_partition(subkernel_nodes, triton_scheduling):
"""Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnumel)
for each subkernel node where each sublist is guaranteed to not exceed CUDA limits for number of args
(read/writes) and to have the same 2D or 1D blocking strategy."""
assert len(subkernel_nodes) >= 1
partition_state_1d = PartitionState([], [], 0)
yelem_to_partition_state_2d: Dict[Integer, PartitionState] = defaultdict(
lambda: PartitionState([], [], 0)
)
for node in subkernel_nodes:
fused_nodes = node.get_nodes()
_, (numel, rnumel) = max(
fused_nodes, key=lambda x: int(x.is_reduction())
).group
tiled_groups = triton_scheduling.select_tiling(fused_nodes, numel, rnumel)
node_info = fused_nodes, tiled_groups, numel, rnumel
read_writes = node.read_writes
read_write_count = len(read_writes.reads) + len(read_writes.writes)
if tiled_groups[1] == 1:
ForeachKernel._update_partition(
partition_state_1d, read_write_count, node_info
)
else:
y_elem = tiled_groups[0]
partition_state_2d = yelem_to_partition_state_2d[y_elem]
ForeachKernel._update_partition(
partition_state_2d, read_write_count, node_info
)
partition_state_1d.finalize()
all_partitions = partition_state_1d.partitions
for partition_state_2d in yelem_to_partition_state_2d.values():
partition_state_2d.finalize()
all_partitions.extend(partition_state_2d.partitions)
return all_partitions
def __init__(self):
super().__init__()
self.blocking_2d = False
self.block_size_1d = 1024 # Try tuning this value
self.block_size_2d = 32
self.num_warps = 8
self.sub_kernels = []
self.iter_vars_count = itertools.count()
self.x_block_count = 0
self.y_block_count = 0
def get_block_size(self):
if self.blocking_2d:
return self.block_size_2d
else:
return self.block_size_1d
@staticmethod
def codegen_pid_offsets(code, block_count, lower_bound, prefix):
if block_count == 0:
code.splice(f"{prefix}pid_offset = {prefix}pid")
else:
code.splice(f"{prefix}pid_offset = {prefix}pid - {lower_bound}")
def codegen_pid_range(self, code, x_elems):
num_x_blocks = ceildiv(x_elems, self.get_block_size())
upper_bound_x_pid = self.x_block_count + num_x_blocks
lower_bound_x_pid = self.x_block_count
if self.x_block_count == 0:
cond = "if"
else:
cond = "elif"
x_pid_bounds_check = (
f"xpid >= {lower_bound_x_pid} and xpid < {upper_bound_x_pid}"
)
code.splice(f"{cond} {x_pid_bounds_check}:")
with code.indent():
ForeachKernel.codegen_pid_offsets(
code, num_x_blocks, lower_bound_x_pid, "x"
)
self.x_block_count += num_x_blocks
def create_sub_kernel(self, *groups, index_dtype, mutations, reduction_hint):
sub_kernel = TritonKernel(
*groups,
index_dtype=index_dtype,
mutations=mutations,
pid_cache={
"tl.program_id(0)": "xpid_offset",
"tl.program_id(1)": "ypid",
},
reduction_hint=reduction_hint,
)
if self.blocking_2d:
assert len(groups) == 3
self.blocking_2d |= groups[1] != 1 and len(groups) == 3
metrics.generated_kernel_count -= 1
sub_kernel.args = self.args
sub_kernel.iter_vars_count = self.iter_vars_count
sub_kernel.cse.iter_buffer_ids = self.cse.iter_buffer_ids
self.sub_kernels.append(sub_kernel)
return sub_kernel
def jit_lines(self):
can_use_32bit = all(k.index_dtype == "tl.int32" for k in self.sub_kernels)
size_dtype = "tl.int32" if can_use_32bit else "tl.int64"
_, _, signature = self.args.python_argdefs()
triton_meta = {
"signature": signature_to_meta(signature, size_dtype=size_dtype),
"device": DeviceProperties.create(V.graph.scheduler.current_device),
"constants": {},
}
triton_meta["configs"] = [config_of(signature)]
inductor_meta = {
"kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
**TritonKernel.inductor_meta_common(),
}
return f"""
@triton_heuristics.foreach(
num_warps={self.num_warps},
triton_meta={triton_meta!r},
inductor_meta={inductor_meta!r},
)
@triton.jit
"""
def grid(self):
return (
self.x_block_count,
ceildiv(int(self.sub_kernels[0].numels[0]), self.block_size_2d)
if self.blocking_2d
else 1,
1,
)
def codegen_kernel(self, name=None):
code = IndentedBuffer()
code.splice(gen_common_triton_imports())
argdefs, _, _ = self.args.python_argdefs()
code.splice(self.jit_lines())
code.writeline(
f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):"
)
with code.indent():
code.splice("xpid = tl.program_id(0)")
if self.blocking_2d:
code.splice("ypid = tl.program_id(1)")
code.splice(f"XBLOCK: tl.constexpr = {self.block_size_2d}")
code.splice(f"YBLOCK: tl.constexpr = {self.block_size_2d}")
else:
code.splice(f"XBLOCK: tl.constexpr = {self.block_size_1d}")
for sub_kernel in self.sub_kernels:
assert len(sub_kernel.numels) <= 3
# TODO mlazos: support dynamic shapes
numel_ind = 0 if not self.blocking_2d else 1
self.codegen_pid_range(code, int(sub_kernel.numels[numel_ind]))
with code.indent():
if self.blocking_2d:
code.splice(f"ynumel = {sub_kernel.numels[0]}")
code.splice(f"xnumel = {sub_kernel.numels[1]}")
else:
code.splice(f"xnumel = {sub_kernel.numels[0]}")
sub_kernel.codegen_body()
code.splice(sub_kernel.body)
code.splice("else:")
with code.indent():
code.splice("pass")
return code.getvalue()
def call_kernel(self, code, name: str):
_, call_args, _ = self.args.python_argdefs()
# dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
for i in range(len(call_args)):
if V.graph.is_unspec_arg(call_args[i]):
call_args[i] = call_args[i] + ".item()"
if V.graph.cpp_wrapper:
V.graph.wrapper_code.generate_kernel_call(
name,
call_args,
device_index=V.graph.scheduler.current_device.index,
grid=self.grid(),
)
else:
# TODO: refactor generate_kernel_call
call_args_str = ", ".join(call_args)
stream_name = code.write_get_raw_stream(
V.graph.scheduler.current_device.index
)
code.writeline(
f"{name}.run({call_args_str}, grid=({self.grid()}), stream={stream_name})"
)

View File

@ -529,7 +529,7 @@ class WrapperCodeGen(CodeGen):
"""
import triton
import triton.language as tl
from {} import grid, split_scan_grid, start_graph, end_graph
from {} import grid, split_scan_grid, grid_combo_kernels, start_graph, end_graph
{}
""".format(
triton_heuristics.__name__,
@ -1361,8 +1361,15 @@ class WrapperCodeGen(CodeGen):
"""
)
def generate_default_grid(self, name: str, grid_args: List[Any]):
return grid_args
def generate_default_grid(
self,
name: str,
grid: List[Any],
cuda: bool = True,
grid_callable: Optional[Callable[..., Any]] = None,
**grid_extra_kwags,
):
return grid
def generate_kernel_call(
self,
@ -1375,6 +1382,7 @@ class WrapperCodeGen(CodeGen):
arg_types=None,
grid_fn: str = "grid",
triton_meta=None,
grid_extra_kwargs="",
):
"""
Generates kernel call code.
@ -1392,6 +1400,8 @@ class WrapperCodeGen(CodeGen):
)
if triton:
grid_str = ", ".join(pexpr(item) for item in grid)
if grid_extra_kwargs:
grid_str = f"{grid_str}, {grid_extra_kwargs}"
grid_str = f"{grid_fn}({grid_str})"
self.writeline(
f"{name}.run({call_args_str}, grid={grid_str}, stream={stream_name})"

View File

@ -362,6 +362,12 @@ assert_indirect_indexing = True
# compute CSE bounds on variables that do not appear in the FX graph
compute_all_bounds = False
# enable the combo kernel that combines data-independent kernels (in addition
# to foreach kernels) into a single one (experimental)
combo_kernels = False
# benchmark combo kernels and only allow ones with perf gains
benchmark_combo_kernel = False
# constant folding on the joint graph
joint_graph_constant_folding = True

View File

@ -1741,3 +1741,15 @@ def split_scan_grid(xnumel, rnumel):
setattr(grid_fn, "grid_fn_str", grid_fn_str) # noqa: B010
return grid_fn
def grid_combo_kernels(*numels, num_kernels, min_blocks):
"""min_blocks is the minimal size of the grid x dimension"""
kernel_grid_fn = grid(*numels)
def grid_fn(meta):
cuda_grid = list(kernel_grid_fn(meta))
cuda_grid[0] = max(num_kernels * cuda_grid[0], min_blocks)
return tuple(cuda_grid)
return grid_fn

View File

@ -1025,8 +1025,10 @@ class FusedSchedulerNode(BaseSchedulerNode):
class ForeachKernelSchedulerNode(FusedSchedulerNode):
"""Scheduler node which consists of a list of scheduler nodes that each operate on a
distinct tensor in a list of tensors."""
"""
This is a scheduler node that consists of a set of scheduler nodes that
have no data dependencies among them and can be executed in parallel.
"""
def get_consumer_subnode_for(self, producer):
if producer.get_name() in self.read_to_node:
@ -1075,6 +1077,11 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
@classmethod
def fuse(cls, producer, consumer):
assert producer.is_foreach() or consumer.is_foreach()
use_custom_partition_algo = (
producer.use_custom_partition_algo
if producer.is_foreach()
else consumer.use_custom_partition_algo
)
prev_node_1 = None
prev_node_2 = None
if producer.is_foreach() and consumer.is_foreach():
@ -1108,13 +1115,24 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
fused_nodes.append(new_node)
else:
fused_nodes.append(node)
else:
raise AssertionError(
"At least one node passed to ForeachKernelSchedulerNode.fuse should be a foreach node"
)
return cls(producer.scheduler, fused_nodes, prev_node_1, prev_node_2) # type: ignore[possibly-undefined]
return cls(
producer.scheduler,
fused_nodes,
use_custom_partition_algo=use_custom_partition_algo,
prev_node_1=prev_node_1,
prev_node_2=prev_node_2,
)
def __init__(
self,
scheduler: "Scheduler",
nodes: List[SchedulerNode],
snodes: List[SchedulerNode],
use_custom_partition_algo: bool,
prev_node_1=None,
prev_node_2=None,
):
@ -1122,9 +1140,9 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
self.name_to_node = {}
if prev_node_1 is None or prev_node_2 is None:
super().__init__(scheduler, nodes)
super().__init__(scheduler, snodes)
for node in nodes:
for node in snodes:
for read in node.read_writes.reads:
self.read_to_node[read.name] = node
@ -1132,7 +1150,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
self.name_to_node[name] = node
else:
self.scheduler = scheduler
self.snodes = nodes
self.snodes = snodes
self.node: ir.Buffer = None # type: ignore[assignment]
self.users: List[NodeUser] = []
@ -1163,17 +1181,43 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
for name in other_node.get_names():
self.name_to_node[name] = other_node
self.group = (nodes[0].get_device(), "foreach")
self.use_custom_partition_algo = use_custom_partition_algo
self.group = (snodes[0].get_device(), "combo_kernel")
self.origins: Set[torch.fx.Node] = set()
@classmethod
def combinable_nodes(cls, nodes: List[SchedulerNode]) -> List[SchedulerNode]:
extern = [x for x in nodes if isinstance(x, ExternKernelSchedulerNode)]
if extern:
log.debug(
"ComboKernels: %d external nodes are filtered %s",
len(extern),
[node.node.origins for node in extern],
)
filtered_nodes = [
x
for x in nodes
if not isinstance(x, (NopKernelSchedulerNode, ExternKernelSchedulerNode))
]
foreach_nodes = [
x for x in filtered_nodes if isinstance(x, ForeachKernelSchedulerNode)
]
if foreach_nodes:
log.debug("ComboKernels: %d foreach nodes are filtered", len(foreach_nodes))
filtered_nodes = [
x for x in filtered_nodes if not isinstance(x, ForeachKernelSchedulerNode)
]
template_nodes = [x for x in filtered_nodes if x.is_template()]
if template_nodes:
log.debug(
"ComboKernels: %d template nodes are filtered", {len(template_nodes)}
)
filtered_nodes = [x for x in filtered_nodes if x not in template_nodes]
return filtered_nodes
def mark_run(self):
raise NotImplementedError
def codegen(self):
assert isinstance(self.node, ir.ComputedBuffer), f"{type(self.node)=}"
self.node.get_store_function()(self.node.make_loader()())
def can_free(self):
raise NotImplementedError
@@ -1181,7 +1225,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
return True
def get_subkernel_nodes(self):
"""Returns a list of nodes which comprise the foreach kernel, operating on corresponding elements of our input lists.
"""Returns a list of nodes which comprise the combo kernel.
These nodes may be vertically fused."""
return list(self.snodes)
@@ -1339,6 +1383,8 @@ class Scheduler:
# Refresh node_users and inverse_users to reflect fused nodes
self.compute_node_users()
self.nodes = comms.reorder_compute_and_comm_for_overlap(self.nodes)
if config.combo_kernels:
self.create_combo_kernel_nodes(num_ck_nodes=None)
self.compute_last_usage()
V.debug.ir_post_fusion(self.nodes)
V.debug.graph_diagram(self.nodes)
@@ -1405,7 +1451,7 @@ class Scheduler:
removed_node_names.update(names)
snodes = [self.name_to_node[name] for name in names]
fe_node = ForeachKernelSchedulerNode(self, snodes) # type: ignore[arg-type]
fe_node = ForeachKernelSchedulerNode(self, snodes, use_custom_partition_algo=False) # type: ignore[arg-type]
fe_nodes.append(fe_node)
@@ -1694,6 +1740,51 @@ class Scheduler:
visit(node)
self.nodes = result
def _get_unmet_dep_nodes(self, snode):
unmet_deps = set()
if isinstance(
snode,
(
SchedulerNode,
ExternKernelSchedulerNode,
NopKernelSchedulerNode,
FusedSchedulerNode,
),
):
for dep in snode.unmet_dependencies:
unmet_deps.add(dep.name)
else:
raise RuntimeError(
f"get_unmet_dep_nodes is not implemented for {type(snode)}."
)
return list({self.name_to_fused_node[n] for n in unmet_deps})
def _topological_sort_nodes(self):
"""
Sort nodes by their topological order, return a list of node lists.
"""
order = []
nodes = {n: 0 for n in self.nodes}
children: Dict[Any, Any] = {}
for node in self.nodes:
deps = self._get_unmet_dep_nodes(node)
nodes[node] = len(deps)
for dep in deps:
c = children.get(dep, [])
c.append(node)
children[dep] = c
zero_deg_nodes = [n for n, v in nodes.items() if v == 0]
while zero_deg_nodes:
order.append(zero_deg_nodes)
for n in zero_deg_nodes:
for user in children.get(n, []):
nodes[user] -= 1
nodes.pop(n)
zero_deg_nodes = [n for n, v in nodes.items() if v == 0]
assert not nodes, "Topological sort failed!"
return order
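
The routine above is a level-by-level Kahn's algorithm: nodes placed in the same returned list have no unmet dependencies on one another, which is what makes them candidates for a combo kernel. A self-contained toy version on a hypothetical `{node: [deps]}` map:

```
# Toy level-by-level topological sort (illustration only, not Inductor's node types).
from collections import defaultdict

def level_sort(deps):
    indeg = {n: len(d) for n, d in deps.items()}   # unmet-dependency counts
    children = defaultdict(list)
    for n, ds in deps.items():
        for d in ds:
            children[d].append(n)
    levels = []
    frontier = [n for n, v in indeg.items() if v == 0]
    while frontier:
        levels.append(frontier)
        nxt = []
        for n in frontier:
            for c in children[n]:
                indeg[c] -= 1
                if indeg[c] == 0:
                    nxt.append(c)
            indeg.pop(n)
        frontier = nxt
    assert not indeg, "Topological sort failed!"
    return levels

print(level_sort({"a": [], "b": [], "c": ["a", "b"], "d": ["c"]}))
# [['a', 'b'], ['c'], ['d']] -> 'a' and 'b' could share a combo kernel
```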
def compute_ancestors(self):
"""
Populate each node.ancestors
@@ -1975,6 +2066,65 @@ class Scheduler:
self.topological_sort_schedule()
self.prune_redundant_deps()
def create_combo_kernel_nodes(self, num_ck_nodes=None):
"""
Group data-independent (parallel) nodes into ComboKernel scheduler nodes.
"""
fused_nodes = set(self.nodes)
count = 0
num_nodes_orig = len(self.nodes)
log.debug("ComboKernels: Generating with num_ck_nodes = %d...", num_ck_nodes)
for num, node_list in enumerate(self._group_nodes_for_combo_kernels()):
node_list = ForeachKernelSchedulerNode.combinable_nodes(node_list)
if len(node_list) < 2:
continue
if num_ck_nodes is not None and count > num_ck_nodes:
break
if not self.speedup_by_combo_kernel(node_list):
log.debug("ComboKernels: Not speeding up %d-th group", num)
continue
count += 1
group_snode = ForeachKernelSchedulerNode(
node_list[0].scheduler, node_list, use_custom_partition_algo=True
)
log.info(
"ComboKernels: Combining %d nodes for %d-th group",
len(node_list),
num,
)
for node in node_list:
fused_nodes.remove(node)
fused_nodes.add(group_snode)
self.name_to_fused_node.update(
{n.get_name(): group_snode for n in group_snode.get_nodes()}
)
self.nodes = sorted(fused_nodes, key=lambda x: x.min_order)
self.topological_sort_schedule()
log.info(
"Generated ComboKernel nodes: %d ComboKernels, totally %d -> %d nodels",
count,
num_nodes_orig,
len(self.nodes),
)
self.prune_redundant_deps()
def _group_nodes_for_combo_kernels(self):
"""
Returns a list of lists of nodes that are to be grouped together.
"""
sorted_nodes = self._topological_sort_nodes()
grouped_nodes = []
max_num_nodes = 8
for nodes in sorted_nodes:
grouped_nodes.extend(
[
nodes[i : i + max_num_nodes]
for i in range(0, len(nodes), max_num_nodes)
]
)
return grouped_nodes
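
In other words, each topological level is chunked into groups of at most `max_num_nodes` (8) candidate sub-kernels, for example:

```
# Chunking one topological level of 19 hypothetical independent kernels into combo candidates.
level = [f"kernel_{i}" for i in range(19)]
max_num_nodes = 8
groups = [level[i : i + max_num_nodes] for i in range(0, len(level), max_num_nodes)]
print([len(g) for g in groups])  # [8, 8, 3] -> three candidate combo kernel groups
```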
def prune_redundant_deps(self):
for node in self.nodes:
node.prune_redundant_deps(self.name_to_fused_node)
@@ -2583,7 +2733,7 @@ class Scheduler:
elif node.is_extern():
self.codegen_extern_call(node)
elif node.is_foreach():
self.get_backend(device).codegen_foreach(node) # type: ignore[possibly-undefined]
self.get_backend(device).codegen_combo_kernel(node) # type: ignore[possibly-undefined]
elif isinstance(node, (FusedSchedulerNode, SchedulerNode)):
self.get_backend(device).codegen_node(node) # type: ignore[possibly-undefined]
else:
@@ -2614,6 +2764,98 @@ class Scheduler:
node = self.name_to_node[buf_name]
return node.node.get_layout()
def benchmark_combo_kernel(self, node_list):
"""
Benchmark the fused list of nodes and return the execution time
in milliseconds on randomly generated inputs.
"""
device = node_list[0].get_device()
V.graph.scheduler = self
self.current_device = device
backend = self.get_backend(device)
return backend.benchmark_combo_kernel(node_list)
def speedup_by_combo_kernel(self, node_list):
"""
If config.benchmark_combo_kernel is False, always return True.
Otherwise, return True if the combo kernel brings a speedup.
"""
if not config.benchmark_combo_kernel:
return True
subkernel_nodes = node_list
device = subkernel_nodes[0].get_device()
# benchmarking combo kernels is not supported on CPU right now.
if device.type == "cpu":
return True
from triton.compiler.errors import CompilationError
ms1, path1_list = 0.0, []
for i, snode in enumerate(subkernel_nodes):
node_list = snode.get_nodes()
# We cannot accurately benchmark kernels that use atomic_add
# due to how we generate random integer inputs;
# such kernels are still benchmarked, but the numbers may be misleading.
if any(
hasattr(n.node, "data")
and hasattr(n.node.data, "scatter_mode")
and n.node.data.scatter_mode == "atomic_add"
for n in node_list
):
fusion_log.debug(
"ComboKernel: benchmarking may not accurate due to atomic_add"
)
try:
ms, path = self.benchmark_fused_nodes(node_list)
if math.isinf(ms):
fusion_log.debug(
"ComboKernel benchmark: register spilling of %d-th subkernel",
i,
)
return False
except CompilationError as e:
# workaround triton issue: https://github.com/openai/triton/issues/2151
if "Loop-carried variable" in str(e):
fusion_log.debug(
"ComboKernel benchmark: return True because of loop-carried variable"
)
return True # allow fusion
else:
raise
ms1 += ms
path1_list.append(path)
try:
ms2, ms2_clone, path2_list = self.benchmark_combo_kernel(subkernel_nodes)
except CompilationError as e:
# workaround triton issue: https://github.com/openai/triton/issues/2151
if "Loop-carried variable" in str(e):
fusion_log.debug(
"ComboKernel benchmark: return True because of loop-carried variable"
)
return True # allow fusion
else:
raise
# Small kernels are very likely to benefit but are hard to benchmark reliably, so accept them without comparing timings.
small_kernel = ms2_clone / ms2 > 0.6 and ms2 - ms2_clone < 0.2
if fusion_log.isEnabledFor(logging.DEBUG):
if ms1 > ms2 or small_kernel:
fusion_log.debug(
"can fuse (benchmark): fusing causes %sx speedup",
green_text(f"{ms1 / ms2:.3f}"),
)
else:
fusion_log.debug(
"cannot fuse (benchmark): fusing causes %sx slowdown",
red_text(f"{ms1 / ms2:.3f}"),
)
return ms2 < ms1 or small_kernel
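
The acceptance rule, on made-up timings: the combo kernel is kept either when it beats the summed sub-kernel times or when input cloning dominates its measurement (the `small_kernel` escape hatch above). The clone-time interpretation of `ms2_clone` is an assumption:

```
# Hypothetical timings in milliseconds.
ms1 = 0.050        # sum of the individually benchmarked sub-kernels
ms2 = 0.055        # combo kernel measurement (assumed to include input-clone overhead)
ms2_clone = 0.040  # portion of ms2 attributed to cloning inputs (assumed)
small_kernel = ms2_clone / ms2 > 0.6 and ms2 - ms2_clone < 0.2
print(ms2 < ms1 or small_kernel)  # True: cloning dominates, so the combo kernel is accepted
```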
class BaseScheduling:
def can_fuse_vertical(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):