Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-04 08:00:58 +08:00)

Compare commits: ciflow/tru...mlazos/ck2 (1 commit)

Commit: 0fef7b0ae2
test/inductor/test_combo_kernels.py (new file, 262 lines)

@@ -0,0 +1,262 @@
# Owner(s): ["module: inductor"]

import sys
import unittest

import torch

import torch._inductor

from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    TestCase,
)

from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
from torch.testing._internal.triton_utils import requires_cuda

aten = torch.ops.aten

try:
    try:
        from .test_torchinductor import check_model, check_model_cuda
    except ImportError:
        from test_torchinductor import check_model, check_model_cuda
except (unittest.SkipTest, ImportError) as e:
    sys.stderr.write(f"{type(e)}: {e}\n")
    if __name__ == "__main__":
        sys.exit(0)
    raise


@instantiate_parametrized_tests
class ComboKernelTests(TestCase):
    check_model_cuda = check_model_cuda
    check_model_cpu = check_model
    check_kernel_count = True

    def setUp(self):
        super().setUp()
        torch._inductor.metrics.reset()
        torch._inductor.config.combo_kernels = True
        torch._inductor.config.benchmark_combo_kernel = False

    def tearDown(self):
        super().tearDown()
        torch._inductor.metrics.reset()

    @requires_cuda
    def test_activation_functions(self):
        def test_activations(a, b, c):
            a1 = torch.nn.functional.relu(a)
            b1 = torch.nn.functional.sigmoid(b)
            c1 = torch.nn.functional.tanh(c)
            return a1, b1, c1

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
        ]

        out_eager = test_activations(*inps)
        out_compiled = torch.compile(test_activations)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda
    def test_reduce_functions(self):
        def test_reduce(a, b, c, d):
            a1 = torch.sum(a, dim=0)
            b1 = torch.max(b, dim=0)
            c1 = torch.min(c, dim=0)
            d1 = torch.nn.functional.tanh(d)

            return a1, b1, c1, d1

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
            torch.rand(30, 8, device="cuda"),
        ]

        out_eager = test_reduce(*inps)
        out_compiled = torch.compile(test_reduce)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertTrue(torch._inductor.metrics.generated_kernel_count <= 2)

    @requires_cuda
    def test_mutated_args(self):
        def test_mutated(a, b, c, d):
            a.add_(1)
            b.sigmoid_()
            c = torch.add(c, 5)
            d.tanh_()

            return a, b, c, d

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
            torch.rand(30, 8, device="cuda"),
        ]

        out_eager = test_mutated(*inps)
        out_compiled = torch.compile(test_mutated)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

    @requires_cuda
    def test_reduce_split(self):
        def fn(a, b):
            a1 = torch.linalg.vector_norm(a)
            b1 = torch.sum(b, dim=0)
            return a1, b1

        inps = [
            torch.rand(2048, 512, device="cuda"),
            torch.rand(20, 20, device="cuda"),
        ]
        out_eager = fn(*inps)
        out_compiled = torch.compile(fn)(*inps)

        self.assertEqual(out_eager, out_compiled)

    @requires_cuda
    def test_2d_blocking_partitioning(self):
        def fn(a0, a1, a2, b0, b1, b2):
            c0 = torch.add(a0, b0)
            c1 = torch.add(a1, b1)
            c2 = torch.add(a2, b2)
            return c0, c1, c2

        self.check_model_cuda(
            fn,
            (
                torch.rand(30, 20, device="cuda"),
                torch.rand(40, 30, device="cuda"),
                torch.rand(36, 40, device="cuda"),
                torch.rand(30, 20, device="cuda"),
                torch.rand(30, 40, device="cuda").t(),
                torch.rand(40, 36, device="cuda").t(),
            ),
        )

        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)


@instantiate_parametrized_tests
class ComboKernelBenchmarkTests(TestCase):
    check_model_cuda = check_model_cuda
    check_model_cpu = check_model
    check_kernel_count = True

    def setUp(self):
        super().setUp()
        torch._inductor.metrics.reset()
        torch._inductor.config.combo_kernels = True
        torch._inductor.config.benchmark_combo_kernel = True

    def tearDown(self):
        super().tearDown()
        torch._inductor.metrics.reset()

    @requires_cuda
    def test_activation_benchmark(self):
        def test_activations(a, b, c):
            a1 = torch.nn.functional.relu(a)
            b1 = torch.nn.functional.sigmoid(b)
            c1 = torch.nn.functional.tanh(c)
            return a1, b1, c1

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
        ]

        out_eager = test_activations(*inps)
        out_compiled = torch.compile(test_activations)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 5)

    @requires_cuda
    def test_reduce_benchmark(self):
        def test_reduce(a, b, c, d):
            a1 = torch.sum(a, dim=0)
            b1 = torch.max(b, dim=0)
            c1 = torch.min(c, dim=0)
            d1 = torch.nn.functional.tanh(d)

            return a1, b1, c1, d1

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
            torch.rand(30, 8, device="cuda"),
        ]

        out_eager = test_reduce(*inps)
        out_compiled = torch.compile(test_reduce)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertTrue(4 < torch._inductor.metrics.generated_kernel_count <= 10)

    @requires_cuda
    def test_mutated_benchmark(self):
        def test_mutated(a, b, c, d):
            a.add_(1)
            b.sigmoid_()
            c = torch.add(c, 5)
            d.tanh_()

            return a, b, c, d

        inps = [
            torch.rand(10, 10, device="cuda"),
            torch.rand(20, 20, device="cuda"),
            torch.rand(10, 10, device="cuda"),
            torch.rand(30, 8, device="cuda"),
        ]

        out_eager = test_mutated(*inps)
        out_compiled = torch.compile(test_mutated)(*inps)

        self.assertEqual(out_eager, out_compiled)
        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 6)

    @requires_cuda
    def test_2d_blocking_benchmark(self):
        def fn(a0, a1, a2, b0, b1, b2):
            c0 = torch.add(a0, b0)
            c1 = torch.add(a1, b1)
            c2 = torch.add(a2, b2)
            return c0, c1, c2

        self.check_model_cuda(
            fn,
            (
                torch.rand(30, 20, device="cuda"),
                torch.rand(40, 30, device="cuda"),
                torch.rand(36, 40, device="cuda"),
                torch.rand(30, 20, device="cuda"),
                torch.rand(30, 40, device="cuda").t(),
                torch.rand(40, 36, device="cuda").t(),
            ),
        )

        self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    if HAS_CPU or HAS_CUDA:
        run_tests(needs="filelock")
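The tests above double as a usage recipe. A minimal standalone sketch of what they exercise (same flags as `setUp` above; assumes a CUDA machine with a Triton build):

    import torch
    import torch._inductor.config as inductor_config

    # Opt in to combo kernels (off by default); the benchmarking variant is left off.
    inductor_config.combo_kernels = True
    inductor_config.benchmark_combo_kernel = False

    def fn(a, b, c):
        # Three independent pointwise ops: candidates for a single combo kernel.
        return a.relu(), b.sigmoid(), c.tanh()

    inps = [torch.rand(10, 10, device="cuda") for _ in range(3)]
    out = torch.compile(fn)(*inps)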
@@ -622,7 +622,7 @@ class ForeachTests(TestCase):
             ),
         )
 
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)
+        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)
 
     @requires_cuda
     @inplace_bin_ops
@@ -132,10 +132,10 @@ class DynamoProfilerTests(torch._inductor.test_case.TestCase):
 
             args = (x, y)
 
-            events = self._test_profiling_kernel_names(fn, args, "_for_")
+            events = self._test_profiling_kernel_names(fn, args, "_poi_")
             event_found = False
             for event in events:
-                if event.name == "triton_for_fused_0":
+                if event.name == "triton_poi_fused_0":
                     event_found = True
                     self.assertTrue(
                         event.input_shapes
@@ -82,6 +82,7 @@ class CppWrapperCpu(WrapperCodeGen):
         arg_types=None,
         grid_fn: str = "grid",
         triton_meta=None,
+        grid_extra_kwargs="",
     ):
         """
         Generates kernel call code.
@@ -1,7 +1,7 @@
 import functools
 import os
 from itertools import chain, count
-from typing import Any, List, Optional, TYPE_CHECKING
+from typing import Any, Callable, List, Optional, TYPE_CHECKING
 
 import sympy
 
@@ -159,16 +159,27 @@ class CppWrapperCuda(CppWrapperCpu):
 
         return ", ".join(new_args)
 
-    def generate_default_grid(self, name: str, grid: List[Any], cuda: bool = True):
+    def generate_default_grid(
+        self,
+        name: str,
+        grid: List[Any],
+        cuda: bool = True,
+        grid_callable: Optional[Callable[..., Any]] = None,
+        **grid_extra_kwags,
+    ):
         """
         Generate grid configs for launching a CUDA kernel using the grid
         function from triton_heuristics.
         """
         if not cuda:
             return grid
-        assert isinstance(grid, list), f"expected {grid=} to be a list"
+        assert isinstance(grid, (list, tuple)), f"expected {grid=} to be a list"
         grid = [e.inner_expr if isinstance(e, SymbolicCallArg) else e for e in grid]
-        grid_fn = default_grid(*grid)
+        grid_callable = grid_callable or default_grid
+        if not grid_extra_kwags:
+            grid_fn = grid_callable(*grid)
+        else:
+            grid_fn = grid_callable(*grid, **grid_extra_kwags)
         params = CudaKernelParamCache.get(name)
         assert (
             params is not None
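For orientation, a self-contained toy version of the dispatch logic added above (stand-in names throughout; `default_grid` below is a placeholder, not the real helper from `triton_heuristics`): existing callers pass no extra arguments and keep the old behavior, while a caller can now supply its own grid function plus keyword arguments.

    from typing import Any, Callable, List, Optional

    def default_grid(*numels):  # stand-in for the real default_grid
        return tuple(numels)

    def generate_default_grid_sketch(
        grid: List[Any],
        grid_callable: Optional[Callable[..., Any]] = None,
        **grid_extra_kwags,
    ):
        grid_callable = grid_callable or default_grid
        if not grid_extra_kwags:
            return grid_callable(*grid)
        return grid_callable(*grid, **grid_extra_kwags)

    # Old behavior preserved:
    assert generate_default_grid_sketch([64, 1, 1]) == (64, 1, 1)

    # New: a custom grid function with extra keyword arguments, as a combo
    # kernel grid (e.g. grid_combo_kernels, imported by the new file further
    # down) would need. The keyword name here is illustrative only.
    def combo(*numels, num_kernels):
        return (numels[0] * num_kernels,) + numels[1:]

    assert generate_default_grid_sketch(
        [64, 1, 1], grid_callable=combo, num_kernels=3
    ) == (192, 1, 1)

The design keeps every existing call site working unchanged while letting combo kernels swap in their own grid computation.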
@@ -191,6 +202,7 @@ class CppWrapperCuda(CppWrapperCpu):
         arg_types=None,
         grid_fn: str = "grid",
         triton_meta=None,
+        grid_extra_kwargs="",
     ):
         if not cuda:
             # Even in CppWrapperCuda, we may see cpp kernels
@@ -71,8 +71,8 @@ class CUDACombinedScheduling(BaseScheduling):
     def flush(self):
         return self._triton_scheduling.flush()
 
-    def codegen_foreach(self, *args, **kwargs):
-        return self._triton_scheduling.codegen_foreach(*args, **kwargs)
+    def codegen_combo_kernel(self, *args, **kwargs):
+        return self._triton_scheduling.codegen_combo_kernel(*args, **kwargs)
 
     def benchmark_fused_nodes(self, nodes):
         return self._triton_scheduling.benchmark_fused_nodes(nodes)
@@ -81,3 +81,6 @@ class CUDACombinedScheduling(BaseScheduling):
         return self._triton_scheduling.generate_kernel_code_from_nodes(
             nodes, benchmark_kernel
         )
+
+    def benchmark_combo_kernel(self, node_list):
+        return self._triton_scheduling.benchmark_combo_kernel(node_list)
@@ -1296,6 +1296,7 @@ class TritonKernel(Kernel):
         reduction_hint=ReductionHint.DEFAULT,
         min_elem_per_thread=0,
         disable_persistent_reduction=False,
+        optimize_mask=True,
     ):
         if pid_cache is None:
             pid_cache = {}
@@ -1317,6 +1318,7 @@ class TritonKernel(Kernel):
         self.block_ptr_id = itertools.count()
         # buffer accesses in the kernel
         self.buf_accesses: DefaultDict[str, List[Dep]] = collections.defaultdict(list)
+        self.optimize_mask = optimize_mask
 
         self.persistent_reduction: bool = (
             not disable_persistent_reduction
@@ -1751,9 +1753,11 @@ class TritonKernel(Kernel):
         if isinstance(index, sympy.Integer):
             expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str()
             index_str = f"tl.full({expand_str}, {index_str}, tl.int32)"
-            return IndexingOptions(index_str, set(), "None", expand_str, has_rindex)
+            if self.optimize_mask:
+                return IndexingOptions(index_str, set(), "None", expand_str, has_rindex)
+            mask_vars = dense_mask_vars
 
-        if need_dense and not have_dense:
+        elif need_dense and not have_dense:
             expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str()
             index_str = f"tl.broadcast_to({index_str}, {expand_str})"
             mask_vars = dense_mask_vars
@@ -1785,6 +1789,8 @@ class TritonKernel(Kernel):
         return trees
 
     def filter_masks(self, mask_vars):
+        if not self.optimize_mask:
+            return
         for tree in self.range_trees:
             # Masks are superfluous if we only have one element
             if V.graph.sizevars.statically_known_equals(tree.numel, 1):  # type: ignore[arg-type]
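Why combo sub-kernels are built with `optimize_mask=False` (see `create_triton_kernel` in the new file below) is easiest to see from the shape of the generated kernel. A hand-written sketch under assumed names and sizes, runnable on a CUDA machine with Triton installed: all sub-kernels share one flat launch grid, `pid % num_kernels` selects the sub-kernel and `pid // num_kernels` becomes its local program id, so each sub-kernel generally sees more programs than its own numel needs and must keep its bounds masks.

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def combo_sketch(a_ptr, b_ptr, a_numel, b_numel, BLOCK: tl.constexpr):
        pid = tl.program_id(0)
        if pid % 2 == 0:  # sub-kernel 0: relu on a
            pid_offset = pid // 2
            offs = pid_offset * BLOCK + tl.arange(0, BLOCK)
            mask = offs < a_numel  # cannot be optimized away: grid is shared
            x = tl.load(a_ptr + offs, mask=mask)
            tl.store(a_ptr + offs, tl.maximum(x, 0.0), mask=mask)
        elif pid % 2 == 1:  # sub-kernel 1: sigmoid on b
            pid_offset = pid // 2
            offs = pid_offset * BLOCK + tl.arange(0, BLOCK)
            mask = offs < b_numel
            x = tl.load(b_ptr + offs, mask=mask)
            tl.store(b_ptr + offs, tl.sigmoid(x), mask=mask)

    a = torch.randn(1000, device="cuda")
    b = torch.randn(3000, device="cuda")
    BLOCK = 256
    # The grid must cover the largest sub-kernel, times the number of sub-kernels.
    blocks = max(triton.cdiv(a.numel(), BLOCK), triton.cdiv(b.numel(), BLOCK))
    combo_sketch[(2 * blocks,)](a, b, a.numel(), b.numel(), BLOCK=BLOCK)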
@@ -3759,34 +3765,53 @@ class TritonScheduling(BaseScheduling):
     def codegen_sync(self):
         V.graph.wrapper_code.writeline(V.graph.device_ops.synchronize())
 
-    def codegen_foreach(self, foreach_node):
-        from .triton_foreach import ForeachKernel
+    def codegen_combo_kernel(self, combo_kernel_node):
+        from .triton_combo_kernel import ComboKernel
 
-        for partitions_with_metadata in ForeachKernel.horizontal_partition(
-            foreach_node.get_subkernel_nodes(), self
-        ):
-            kernel = ForeachKernel()
-            for nodes, tiled_groups, numel, rnumel in partitions_with_metadata:
-                node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
-                (
-                    reduction_hint_val,
-                    mutations,
-                    index_dtype,
-                ) = self.get_kernel_args(node_schedule, numel, rnumel)
-
-                subkernel = kernel.create_sub_kernel(
-                    *tiled_groups,
-                    reduction_hint=reduction_hint_val,
-                    mutations=mutations,
-                    index_dtype=index_dtype,
-                )
-
-                self.codegen_node_schedule_with_kernel(
-                    node_schedule,
-                    subkernel,
-                )
-
-                with V.set_kernel_handler(subkernel):
+        subkernel_nodes = combo_kernel_node.get_subkernel_nodes()
+        fused_node_lists = [node.get_nodes() for node in subkernel_nodes]
+        subkernel_map, node_schedule_map = {}, {}
+        for pn, nodes in zip(subkernel_nodes, fused_node_lists):
+            _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group
+            node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
+            tiled_groups = self.select_tiling(node_schedule, numel, rnumel)
+            node_schedule_map[pn] = node_schedule, tiled_groups, numel, rnumel
+            (
+                reduction_hint_val,
+                mutations,
+                index_dtype,
+            ) = self.get_kernel_args(node_schedule, numel, rnumel)
+            subkernel_map[pn] = ComboKernel.create_triton_kernel(
+                *tiled_groups,
+                reduction_hint=reduction_hint_val,
+                mutations=mutations,
+                index_dtype=index_dtype,
+            )
+
+        partitions = ComboKernel.horizontal_partition(
+            nodes=combo_kernel_node.get_subkernel_nodes(),
+            triton_scheduling=self,
+            custom_algorithm=combo_kernel_node.use_custom_partition_algo,
+            kernel_map=subkernel_map,
+            node_info_map=node_schedule_map,
+        )
+        log.debug(
+            "ComboKernels: %d nodes partitioned into %s groups",
+            len(combo_kernel_node.get_subkernel_nodes()),
+            [len(p) for p in partitions],
+        )
+        for node_group in partitions:
+            fused_node_lists = [node.get_nodes() for node in node_group]
+            kernel = ComboKernel()
+
+            for pn, nodes in zip(node_group, fused_node_lists):
+                self.codegen_node_schedule_with_kernel(
+                    node_schedule_map[pn][0],
+                    kernel.create_sub_kernel(subkernel_map[pn]),
+                )
+                subkernel = subkernel_map[pn]
+                node_schedule = node_schedule_map[pn][0]
+                with V.set_kernel_handler(subkernel):  # type: ignore[call-arg]
                     for node in node_schedule:
                         if node not in (EnableReduction, DisableReduction):
                             node.mark_run()
@@ -3794,8 +3819,9 @@ class TritonScheduling(BaseScheduling):
                 V.graph.inplaced_to_remove |= subkernel.inplaced_to_remove
 
             src_code = kernel.codegen_kernel()
-            kernel_name = self.define_kernel(src_code, [foreach_node])
-            self.codegen_comment([foreach_node])
+            kernel_name = self.define_kernel(src_code, [combo_kernel_node])
+            self.codegen_comment([combo_kernel_node])
+            log.debug("ComboKernels: generated kernel %s.", kernel_name)
             kernel.call_kernel(V.graph.wrapper_code, kernel_name)
 
         self.scheduler.free_buffers()
@@ -4062,6 +4088,134 @@ class TritonScheduling(BaseScheduling):
         store_cache()
         return ms, mod.__file__
 
+    def benchmark_combo_kernel(self, node_list):
+        from .triton_combo_kernel import ComboKernel
+
+        def cache_file_path():
+            assert mod.__file__ is not None
+            return os.path.splitext(mod.__file__)[0] + ".kernel_perf"
+
+        def load_cache():
+            path = cache_file_path()
+            if os.path.exists(path):
+                with open(path) as fd:
+                    return tuple(float(e) for e in fd.read().split())
+            return (None, None)
+
+        def store_cache():
+            path = cache_file_path()
+            with open(path, "w") as fd:
+                fd.write(str(ms) + " " + str(ms_clone))
+
+        subkernel_nodes = node_list
+        fused_node_lists = [node.get_nodes() for node in subkernel_nodes]
+        subkernel_map, node_schedule_map = {}, {}
+        for pn, nodes in zip(subkernel_nodes, fused_node_lists):
+            _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group
+            node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
+            tiled_groups = self.select_tiling(node_schedule, numel, rnumel)
+            (
+                reduction_hint_val,
+                mutations,
+                index_dtype,
+            ) = self.get_kernel_args(node_schedule, numel, rnumel)
+            node_schedule_map[pn] = (node_schedule, tiled_groups, numel, rnumel)
+            subkernel_map[pn] = ComboKernel.create_triton_kernel(
+                *tiled_groups,
+                reduction_hint=reduction_hint_val,
+                mutations=mutations,
+                index_dtype=index_dtype,
+            )
+
+        partitions = ComboKernel.horizontal_partition(
+            nodes=subkernel_nodes,
+            triton_scheduling=self,
+            kernel_map=subkernel_map,
+            node_info_map=node_schedule_map,
+            custom_algorithm=True,
+        )
+        log.debug(
+            "ComboKernels: %d nodes partitioned into %s groups",
+            len(subkernel_nodes),
+            [len(p) for p in partitions],
+        )
+        total_ms, file_list = 0, []
+        total_clone_ms = 0
+        removed_buffers_orig = V.graph.removed_buffers
+        V.graph.removed_buffers = set(removed_buffers_orig)
+        inplaced_to_remove_orig = V.graph.inplaced_to_remove
+        V.graph.inplaced_to_remove = set(inplaced_to_remove_orig)
+        for node_group in partitions:
+            fused_node_lists = [node.get_nodes() for node in node_group]
+            kernel = ComboKernel()
+            names = [n.get_names() for nodes in fused_node_lists for n in nodes]
+
+            for pn, nodes in zip(node_group, fused_node_lists):
+                # empty last_usage. May cause more aggressive 'evict_last'. Should be fine.
+                for n in nodes:
+                    n.last_usage = set()
+
+                self.codegen_node_schedule_with_kernel(
+                    node_schedule_map[pn][0],
+                    kernel.create_sub_kernel(subkernel_map[pn]),
+                )
+                subkernel = subkernel_map[pn]
+                V.graph.removed_buffers |= subkernel.removed_buffers
+                V.graph.inplaced_to_remove |= subkernel.inplaced_to_remove
+            with config.patch("benchmark_kernel", True), V.set_kernel_handler(kernel):
+                src_code = kernel.codegen_kernel()
+
+            src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_")
+            mod = PyCodeCache.load(src_code)
+
+            log.debug(
+                "kernel src code for %s written to: %s",
+                names,
+                mod.__file__,
+            )
+            ms, ms_clone = load_cache()
+            if ms is not None:
+                total_ms += ms
+                total_clone_ms += ms_clone
+                file_list.append(mod.__file__)
+                continue
+
+            args = mod.get_args()
+            call = mod.call
+            wrapped_jit_function = mod.triton_
+
+            # call once to trigger the compilation
+            call(wrapped_jit_function.clone_args(*args)[0])
+
+            launchers = wrapped_jit_function.launchers
+            assert len(launchers) == 1
+            if launchers[0].n_spills > 0:
+                # skip benchmarking the kernel if there are register spills
+                ms = ms_clone = float("inf")
+            else:
+                # We have to clone the inplace updated arguments to avoid earlier calls
+                # generating out of range indices for later calls.
+                ms = do_bench_gpu(
+                    lambda: call(wrapped_jit_function.clone_args(*args)[0])
+                )
+                ms_clone = do_bench_gpu(
+                    lambda: wrapped_jit_function.clone_args(*args)[0]
+                )
+
+            log.debug(
+                "The fused kernel for %s took %.3f ms to run, %.3f ms to clone inputs",
+                {n.get_name() for n in nodes},
+                ms,
+                ms_clone,
+            )
+            store_cache()
+            total_ms += ms
+            total_clone_ms += ms_clone
+            file_list.append(mod.__file__)
+        V.graph.removed_buffers = removed_buffers_orig
+        V.graph.inplaced_to_remove = inplaced_to_remove_orig
+        return total_ms, total_clone_ms, file_list
+
 
 @dataclasses.dataclass
 class CandidateTiling:
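A note on the `.kernel_perf` files used by `load_cache()`/`store_cache()` above: each holds a single line with two floats, the kernel time and the input-clone time in milliseconds, stored next to the generated module. A minimal standalone reader as a sketch (`read_kernel_perf` is a hypothetical helper, not part of the diff):

    import os

    def read_kernel_perf(mod_file: str):
        """Read the '<ms> <ms_clone>' line written next to a generated kernel module."""
        path = os.path.splitext(mod_file)[0] + ".kernel_perf"
        if not os.path.exists(path):
            return None, None
        with open(path) as fd:
            ms, ms_clone = (float(e) for e in fd.read().split())
        return ms, ms_clone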
							
								
								
									
torch/_inductor/codegen/triton_combo_kernel.py (new file, 675 lines)

@@ -0,0 +1,675 @@
import itertools
import logging
import textwrap
from collections import defaultdict
from dataclasses import dataclass

from typing import Dict, List, Tuple

from sympy import Integer

from .. import config, metrics
from ..runtime.hints import DeviceProperties
from ..runtime.runtime_utils import next_power_of_2
from ..runtime.triton_heuristics import grid_combo_kernels
from ..scheduler import FusedSchedulerNode
from ..utils import Placeholder
from ..virtualized import V
from .common import DeferredLine, IndentedBuffer, Kernel, PythonPrinter, SizeArg
from .triton import gen_common_triton_imports, TritonKernel
from .triton_utils import config_of, signature_to_meta


log = logging.getLogger(__name__)
pexpr = PythonPrinter().doprint
LARGE_NUMELS = 512e5


def _default_custom_combo_kernel_horizontal_partition(
    nodes, triton_scheduling, kernel_map, node_info_map
):
    """Horizontally partition the given list of nodes into a list of list of nodes where each sublist
 | 
			
		||||
    represents a partion. Nodes in different partitions are implemented in different combo kernels.
 | 
			
		||||
    Nodes in the same partition are likely to be implemented
 | 
			
		||||
    in the same combo kernel, but subject to subsequent restricts like CUDA limits for number of args.
 | 
			
		||||
 | 
			
		||||
    Input arguments:
 | 
			
		||||
        nodes: a list of fused scheduler nodes to partition.
 | 
			
		||||
        triton_scheduling: TritonScheduling instance.
 | 
			
		||||
        kernel_map: a map from node to its kernel.
 | 
			
		||||
        node_info_map: a map from node to (node_schedule, tiled_groups, numel, rnumel).
 | 
			
		||||
    Oputput:
 | 
			
		||||
        a list of list of nodes with each sublist representing a partition.
 | 
			
		||||
 | 
			
		||||
    The default algorithm is to partition nodes based on the following rules:
 | 
			
		||||
        1) nodes with the same number of block dimensions are grouped togather.
 | 
			
		||||
        2) large pointwise nodes are separated from other nodes.
 | 
			
		||||
        3) large reduce nodes are separated from other nodes.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    assert len(nodes) >= 1

    # first partition nodes based on number of block dimensions
    tilings = [node_info_map[n][1] for n in nodes]

    max_dims = max(len(t) for t in tilings)
    nodes_per_ndim = []
    for i in range(2, max_dims + 1):
        group_per_dim = [n for n, t in zip(nodes, tilings) if len(t) == i]
        reduction = [
            n
            for n in group_per_dim
            if kernel_map[n].inside_reduction
            and not (kernel_map[n].persistent_reduction and kernel_map[n].no_x_dim)
        ]
        not_reduction = [n for n in group_per_dim if n not in reduction]
        # rnumel > 2048 usually has long execution time
        long_reduction = [n for n in reduction if n.group[-1][-1] > 2048]
        short_reduction = [n for n in reduction if n not in long_reduction]
        if long_reduction:
            log.warning(
                "ComboKernels: %d long reduction nodes are separated",
                len(long_reduction),
            )
        large_pointwise = [
            n
            for n in not_reduction
            if not kernel_map[n].inside_reduction
            and len(kernel_map[n].numels) == 2
            and V.graph.sizevars.size_hint(kernel_map[n].numels[0]) > LARGE_NUMELS
        ]
        if large_pointwise:
            # TODO benchmark the performance when large pointwise nodes combining with others
            log.warning(
                "ComboKernels: %d large pointwise nodes are separated",
                len(large_pointwise),
            )
            not_reduction = [n for n in not_reduction if n not in large_pointwise]
            for node in large_pointwise:
                nodes_per_ndim.append([node])

        for g in (not_reduction, short_reduction, long_reduction):
            if g:
                nodes_per_ndim.append(g)

    assert sum(len(p) for p in nodes_per_ndim) == len(nodes)
    return nodes_per_ndim


_custom_combo_kernel_horizontal_partition_algorithm = (
    _default_custom_combo_kernel_horizontal_partition
)

def set_custom_combo_kernel_horizontal_partition(algorithm):
    """Sets the algorithm used to partition nodes into horizontal partitions. Nodes in different partitions
    are implemented in different combo kernels. Nodes in the same partition are likely to be implemented
    in the same combo kernel, but subject to subsequent restrictions such as CUDA limits on the number of args.

    The algorithm should take a list of nodes and return a list of lists of nodes.

    The default algorithm partitions nodes based on the number of block dimensions.
    """
    global _custom_combo_kernel_horizontal_partition_algorithm
    _custom_combo_kernel_horizontal_partition_algorithm = algorithm

@dataclass
class PartitionState:
    partitions: List[List[FusedSchedulerNode]]
    cur_partition: List[FusedSchedulerNode]
    cur_count: int

    def finalize(self):
        if self.cur_partition:
            self.partitions.append(self.cur_partition)


class ComboKernel(Kernel):
    MAX_NUM_ARGS = 250  # number where I would no longer get triton errors

    @staticmethod
    def _update_partition(partition_state, node_rw_count, node_info):
        if partition_state.cur_count + node_rw_count > ComboKernel.MAX_NUM_ARGS:
            partition_state.partitions.append(partition_state.cur_partition)
            partition_state.cur_partition = [node_info]
            partition_state.cur_count = node_rw_count
        else:
            partition_state.cur_count += node_rw_count
            partition_state.cur_partition.append(node_info)

    @staticmethod
    def _base_horizontal_partition(subkernel_nodes, triton_scheduling, node_info_map):
        """Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnumel)
        for each subkernel node where each sublist is guaranteed to not exceed CUDA limits for number of args
        (read/writes) and to have the same 2D or 1D blocking strategy."""
        # TODO support combination of kernels with different block dimensions
        assert len(subkernel_nodes) >= 1

        ndim_to_partition_state: Dict[int, PartitionState] = defaultdict(
            lambda: PartitionState([], [], 0)
        )

        for node in subkernel_nodes:
            node_schedule, tiled_groups, numel, rnumel = node_info_map[node]
            node_info = node

            read_writes = node.read_writes
            read_write_count = len(read_writes.reads) + len(read_writes.writes)

            ndim = len(tiled_groups)
            partition_state = ndim_to_partition_state[ndim]
            ComboKernel._update_partition(partition_state, read_write_count, node_info)

        all_partitions = []
        for partition_state in ndim_to_partition_state.values():
            partition_state.finalize()
            all_partitions.extend(partition_state.partitions)

        return all_partitions

    @staticmethod
    def horizontal_partition(
        nodes, triton_scheduling, kernel_map, node_info_map, custom_algorithm=False
    ):
        """Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnum)
        for each subkernel node where each sublist forms a ComboKernel. It horizontally partitions nodes into
        sublists in the following way:
            1) call _custom_combo_kernel_horizontal_partition_algorithm() if custom_algorithm is True
            2) then, call _base_horizontal_partition() to partition nodes into sublists, each sublist is
               guaranteed to not exceed CUDA limits for number of args (read/writes) and to have the same
               2D or 1D blocking strategy.
        """
        if custom_algorithm:
            raw_partitions = _custom_combo_kernel_horizontal_partition_algorithm(
                nodes, triton_scheduling, kernel_map, node_info_map
            )
        else:
            raw_partitions = [nodes]

        all_partitions = []
        for raw_partition in raw_partitions:
            all_partitions.extend(
                ComboKernel._base_horizontal_partition(
                    raw_partition, triton_scheduling, node_info_map
                )
            )
        return all_partitions

    def __init__(self, use_custom_partition_algo=False):
        super().__init__()
        self.sub_kernels = []
        self.iter_vars_count = itertools.count()
        self.block_count = 0
        self.grids = []
        self.min_x_blocks_list = []
        self.use_custom_partition_algo = use_custom_partition_algo

    def codegen_pid_range(self, code, num):
        num_kernels = len(self.sub_kernels)
        if self.block_count == 0:
            cond = "if"
        else:
            cond = "elif"
        code.splice(f"{cond} pid % {num_kernels} == {num}:")
        with code.indent():
            code.splice(f"pid_offset = pid // {num_kernels}")
            self.block_count += 1

    def create_sub_kernel(self, triton_kernel):
        sub_kernel = triton_kernel
        metrics.generated_kernel_count -= 1
        sub_kernel.args = self.args
        sub_kernel.iter_vars_count = self.iter_vars_count
        sub_kernel.cse.iter_buffer_ids = self.cse.iter_buffer_ids
        self.sub_kernels.append(sub_kernel)
        return sub_kernel

    @staticmethod
    def create_triton_kernel(*groups, index_dtype, mutations, reduction_hint):
        return TritonKernel(
            *groups,
            index_dtype=index_dtype,
            mutations=mutations,
            pid_cache={"tl.program_id(0)": "pid_offset"},
            reduction_hint=reduction_hint,
            optimize_mask=False,
        )
    def codegen_static_numels_sub_kernel(self, code, sub_kernel, num):
        """
        We get a small speedup from hard coding numels if they are static.

        This code stomps on the passed-in values by writing a constant to the top of the kernel.

        In a kernel like:
        def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):

        We would add
        xnumel = 4096
        rnumel = 768

        after the signature, before the kernel code, if we decided to make these static. As it's hardcoded, it becomes
        a better signal to triton on how to unroll and do some static indexing. So, it's not so much that downstream
        knows that it's a static numel, as that you just plop a constant into the kernel.
        """
        grid = []
        uniquify_block_sizes = []
        for tree in sub_kernel.range_trees:
            simplified_tree_numel = V.graph.sizevars.simplify(tree.numel)
            code.writeline(f"{tree.prefix}numel = {int(simplified_tree_numel)}")

            if tree.prefix != "r":
                grid.append(int(simplified_tree_numel))

            if tree.prefix == "r" and sub_kernel.persistent_reduction:
                if isinstance(simplified_tree_numel, (Integer, int)):
                    val = int(simplified_tree_numel)
                else:
                    continue
                val = next_power_of_2(val)
                code.writeline(f"RBLOCK_{num}: tl.constexpr = {val}")
                uniquify_block_sizes.append("RBLOCK")

            if tree.prefix == "x" and sub_kernel.no_x_dim:
                code.writeline(f"XBLOCK_{num}: tl.constexpr = 1")
                uniquify_block_sizes.append("XBLOCK")
        self.grids.append(grid)
        return uniquify_block_sizes
    def min_x_blocks_sub_kernel(self, sub_kernel, num):
        """
        Record the minimum number of blocks a sub-kernel needs along x:
        xnumel for sub-kernels with no_x_dim (one program per x element),
        otherwise 0. Used when building the combo kernel's launch grid.
        """
        min_x_blocks = 0
 | 
			
		||||
        for tree in sub_kernel.range_trees:
 | 
			
		||||
            simplified_tree_numel = V.graph.sizevars.simplify(tree.numel)
 | 
			
		||||
            if tree.prefix == "x" and sub_kernel.no_x_dim:
 | 
			
		||||
                min_x_blocks = int(simplified_tree_numel)
 | 
			
		||||
        self.min_x_blocks_list.append(min_x_blocks)
 | 
			
		||||
 | 
			
		||||
    def select_heuristics(self, sub_kernel) -> Tuple[str, List[int]]:
 | 
			
		||||
        size_hints = [
 | 
			
		||||
            next_power_of_2(V.graph.sizevars.size_hint(numel))
 | 
			
		||||
            for numel in sub_kernel.numels
 | 
			
		||||
        ]
 | 
			
		||||
        if sub_kernel.persistent_reduction:
 | 
			
		||||
            assert sub_kernel.inside_reduction
 | 
			
		||||
            heuristics = "persistent_reduction"
 | 
			
		||||
        elif sub_kernel.inside_reduction:
 | 
			
		||||
            heuristics = "reduction"
 | 
			
		||||
        else:
 | 
			
		||||
            size_hints.pop()
 | 
			
		||||
            heuristics = "pointwise"
 | 
			
		||||
        return heuristics, size_hints
 | 
			
		||||
 | 
			
		||||
    def select_combo_heuristics(self, heuristics_list, size_hints_list):
 | 
			
		||||
        if "reduction" in heuristics_list:
 | 
			
		||||
            i, _ = max(
 | 
			
		||||
                enumerate(size_hints_list),
 | 
			
		||||
                key=lambda x: x[1][0] if heuristics_list[x[0]] == "reduction" else 0,
 | 
			
		||||
            )
 | 
			
		||||
            return heuristics_list[i], size_hints_list[i], self.sub_kernels[i]
 | 
			
		||||
        elif "pointwise" in heuristics_list:
 | 
			
		||||
            i, _ = max(
 | 
			
		||||
                enumerate(size_hints_list),
 | 
			
		||||
                key=lambda x: x[1][0] if heuristics_list[x[0]] == "pointwise" else 0,
 | 
			
		||||
            )
 | 
			
		||||
            # modify size_hint to avoid oom check fail (may be a false alarm)
 | 
			
		||||
            num_pointwise = len([e for e in heuristics_list if e == "pointwise"])
 | 
			
		||||
            num_reduction = len([e for e in heuristics_list if e == "reduction"])
 | 
			
		||||
            num_persistent_reduction = len(
 | 
			
		||||
                [e for e in heuristics_list if e == "persistent_reduction"]
 | 
			
		||||
            )
 | 
			
		||||
            assert (
 | 
			
		||||
                num_reduction == 0
 | 
			
		||||
            ), "combining pointwise and reduction are not supported yet."
 | 
			
		||||
            heuristics = (
 | 
			
		||||
                "pointwise_with_reduction"
 | 
			
		||||
                if num_persistent_reduction > 0
 | 
			
		||||
                else "pointwise"
 | 
			
		||||
            )
 | 
			
		||||
            if len(heuristics_list) - num_pointwise >= 4:
 | 
			
		||||
                size_hints = size_hints_list[i]
 | 
			
		||||
                size_hints[0] = min(128, size_hints[0])
 | 
			
		||||
            return heuristics, size_hints_list[i], self.sub_kernels[i]
 | 
			
		||||
        else:
 | 
			
		||||
            return heuristics_list[0], size_hints_list[0], self.sub_kernels[0]
 | 
			
		||||
 | 
			
		||||
    def get_mutated_args_sub_kernels(self) -> List[str]:
 | 
			
		||||
        mutated_args = set()
 | 
			
		||||
        for sub_kernel in self.sub_kernels:
 | 
			
		||||
            for mutation in sub_kernel.mutations:
 | 
			
		||||
                if mutation in sub_kernel.args.input_buffers:
                    mutated_args.add(sub_kernel.args.input_buffers[mutation])
                if (
                    mutation in sub_kernel.args.inplace_buffers
                    and mutation not in V.graph.removed_buffers
                    and mutation not in sub_kernel.removed_buffers
                ):
                    mutated_args.add(
                        sub_kernel.args.inplace_buffers[mutation].inner_name
                    )
                if mutation in sub_kernel.args.output_buffers:
                    mutated_args.add(sub_kernel.args.output_buffers[mutation])
        return sorted(mutated_args)

    def jit_line(
        self, heuristics, size_hints, selected_kernel, pointwise_with_reduce=False
    ):
        # TODO: is it correct to use the first sub kernel's heuristics?
        _, _, signature = self.args.python_argdefs()
        # TODO: is it OK to just use sub_kernels[0].index_dtype?
        index_dtype = self.sub_kernels[0].index_dtype
        for i, sub in enumerate(self.sub_kernels):
            self.min_x_blocks_sub_kernel(sub, i)
        triton_meta = {
            "signature": signature_to_meta(signature, size_dtype=index_dtype),
            "device": DeviceProperties.create(V.graph.scheduler.current_device),
            "constants": {},
        }
        triton_meta["configs"] = [config_of(signature)]
        mutated_args = self.get_mutated_args_sub_kernels()
        inductor_meta = {
            "kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
            "mutated_arg_names": mutated_args,
            **TritonKernel.inductor_meta_common(),
        }

        sub_kernel = selected_kernel
        if sub_kernel.inside_reduction:
            reduction_hint = sub_kernel.reduction_hint
            heuristics_line = f"""
                @triton_heuristics.{heuristics}(
                    size_hints={size_hints!r},
                    reduction_hint={reduction_hint},
                    filename=__file__,
                    triton_meta={triton_meta!r},
                    inductor_meta={inductor_meta!r}
                )
                @triton.jit
            """
        else:
            tile_hint = ""
            if len(size_hints) == 2:
                # TODO only input, output and 2 args
                tile_hint = "tile_hint=TileHint.SQUARE,"
            else:
                tile_hint = "tile_hint=TileHint.DEFAULT,"
            heuristics_line = f"""
                @triton_heuristics.{heuristics}(
                    size_hints={size_hints!r}, {tile_hint}
                    filename=__file__,
                    triton_meta={triton_meta!r},
                    inductor_meta={inductor_meta!r}
                )
                @triton.jit
            """

        return heuristics_line
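
    # For a 1D pointwise combo kernel, the f-string above renders roughly to
    # the following (size hints and meta dicts are invented placeholders):
    #
    #     @triton_heuristics.pointwise(
    #         size_hints=[4096], tile_hint=TileHint.DEFAULT,
    #         filename=__file__,
    #         triton_meta={'signature': {...}, 'device': ..., 'constants': {}, 'configs': [...]},
    #         inductor_meta={'kernel_name': ..., 'mutated_arg_names': [...], ...}
    #     )
    #     @triton.jit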

    def add_blockd_to_args(self, argdefs):
        block_args = {}
        for num, sub_kernel in enumerate(self.sub_kernels):
            # TODO: we assume all sub_kernels have the same block size
            for tree in sub_kernel.range_trees:
                if tree.prefix == "r" and (
                    not sub_kernel.inside_reduction or sub_kernel.persistent_reduction
                ):
                    continue
                if tree.prefix == "x" and sub_kernel.no_x_dim:
                    continue
                # argdefs.append(f"{tree.prefix.upper()}BLOCK_{num} : tl.constexpr")
                block_args[f"{tree.prefix.upper()}BLOCK : tl.constexpr"] = tree.prefix
        for arg in block_args:
            argdefs.append(arg)
        return argdefs
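
    # A toy rendering of the dedup above: keying the dict by the rendered arg
    # string keeps a single XBLOCK/RBLOCK per prefix even when several
    # sub-kernels contribute the same one (sub-kernel prefixes invented):
    #
    #     block_args = {}
    #     for prefixes in [("x",), ("x", "r"), ("x",)]:
    #         for p in prefixes:
    #             block_args[f"{p.upper()}BLOCK : tl.constexpr"] = p
    #     list(block_args)  # ['XBLOCK : tl.constexpr', 'RBLOCK : tl.constexpr']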

    def codegen_kernel(self, name=None):
        # TODO: is it correct to use the first sub kernel's heuristics?
        heuristics_list, size_hints_list = [], []
        for subkernel in self.sub_kernels:
            h, s = self.select_heuristics(subkernel)
            heuristics_list.append(h)
            size_hints_list.append(s)
        heuristics, size_hints, selected_kernel = self.select_combo_heuristics(
            heuristics_list, size_hints_list
        )
        pointwise_with_reduction, heuristics = (
            (True, "pointwise")
            if heuristics == "pointwise_with_reduction"
            else (False, heuristics)
        )
        code = IndentedBuffer()

        code.splice(gen_common_triton_imports())
        if config.benchmark_combo_kernel:
            code.splice(self.imports_for_benchmark_kernel())

        argdefs, _, _ = self.args.python_argdefs()
        argdefs = self.add_blockd_to_args(argdefs)
        code.splice(
            self.jit_line(
                heuristics,
                size_hints,
                selected_kernel,
                pointwise_with_reduce=pointwise_with_reduction,
            )
        )
        code.writeline(
            f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):"
        )

        with code.indent():
            code.splice("pid = tl.program_id(0)")

            for num, sub_kernel in enumerate(self.sub_kernels):
                self.codegen_pid_range(code, num)
                with code.indent():
                    uniquify = self.codegen_static_numels_sub_kernel(
                        code, sub_kernel, num
                    )
                    sub_kernel.codegen_body()
                    uniquified_body = self.uniquify_block_sizes(
                        sub_kernel.body, num, uniquify
                    )
                    code.splice(uniquified_body)

            code.splice("else:")
            with code.indent():
                code.splice("pass")

        if config.benchmark_combo_kernel:
            code.splice(self.codegen_kernel_benchmark(num_gb=0))

        return code.getvalue()
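
    # A rough sketch of the skeleton this emits (kernel name, pointers and
    # block counts invented): one jitted function whose sub-kernel bodies are
    # guarded by disjoint ranges of the flattened program id:
    #
    #     @triton.jit
    #     def triton_combo_kernel(in_ptr0, out_ptr0, in_ptr1, out_ptr1, XBLOCK: tl.constexpr):
    #         pid = tl.program_id(0)
    #         if pid < num_xblocks_0:
    #             ...  # sub-kernel 0 body, block sizes uniquified to XBLOCK_0
    #         elif pid < num_xblocks_1:
    #             ...  # sub-kernel 1 body, pid rebased to its own range
    #         else:
    #             pass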

    def codegen_kernel_benchmark(self, num_gb, grid=None):
        result = IndentedBuffer()
        argdefs, call_args, signature = self.args.python_argdefs()

        result.writelines(["", "", "def get_args():"])
        with result.indent():
            name_cnt = itertools.count()
            var_names = []
            for arg_name, arg_sig in zip(call_args, signature):
                var_name = f"arg_{next(name_cnt)}"
                buf = V.graph.get_buffer(arg_name)
                if buf:
                    result.writeline(
                        f"{var_name} = rand_strided({V.graph.sizevars.size_hints(buf.get_size())}, {V.graph.sizevars.size_hints(buf.get_stride())}, device='{buf.get_device()}', dtype={buf.get_dtype()})"  # noqa: B950 line too long
                    )
                elif arg_name in V.graph.constants:
                    # note that random seed is put in V.graph.constants
                    const_tensor = V.graph.constants[arg_name]
                    result.writeline(
                        f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})"  # type: ignore[arg-type]  # noqa: B950 line too long
                    )
                elif isinstance(arg_sig, SizeArg):
                    symval_hint = V.graph.sizevars.size_hint(arg_sig.expr)

                    # Force the seed_offset to be 0 so calls to the same kernel
                    # using different seed offset will have the same benchmark harness.
                    # We can dedup kernel definitions in this case.
                    if "seed_offset" in arg_sig.name:
                        symval_hint = 0
                    result.writeline(f"{var_name} = {symval_hint}")
                else:
                    raise KeyError(
                        f"Could not find the buffer or const tensor for {arg_name}"
                    )
                var_names.append(var_name)
            result.writeline(f"return {', '.join(var_names)},")

        result.writelines(["\n", "\n", "def call(args):"])
        if grid is None:
            grid = self.grid(self.grids)
            grid_str = ", ".join(pexpr(item) for item in grid)
            grid_extra_kwargs = f"num_kernels={len(self.sub_kernels)}, min_blocks={max(self.min_x_blocks_list) * len(self.sub_kernels)}"
            grid_str = f"{grid_str}, {grid_extra_kwargs}"
            grid_arg = f"grid=grid_combo_kernels({grid_str})"
        else:
            grid_arg = f"grid={grid}"
        index = V.graph.scheduler.current_device.index
        with result.indent():
            result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")
            with result.indent():
                result.writeline(
                    V.graph.device_ops.set_device(index)
                )  # no-op to ensure context
                stream_name = f"stream{index}"
                result.writeline(f"{stream_name} = get_raw_stream({index})")
                result.writeline(
                    f"{str(Placeholder.KERNEL_NAME)}.run(*args, {grid_arg}, stream={stream_name})"
                )

        # benchmark all configs
        result.writelines(["\n", "\n", "def benchmark_all_configs(args):"])
        with result.indent():
            result.writeline(f"with {V.graph.device_ops.device_guard(index)}:")
            with result.indent():
                result.writeline(
                    V.graph.device_ops.set_device(index)
                )  # no-op to ensure context
                result.writeline(
                    f"return {str(Placeholder.KERNEL_NAME)}.benchmark_all_configs(*args, {grid_arg})"
                )

        result.writelines(["\n", "\n", "if __name__ == '__main__':"])
        with result.indent():
            result.writeline("from triton.testing import do_bench")
            result.writeline("")

            result.writeline("args = get_args()")
            result.writeline(
                "ms = do_bench(lambda: call(args), rep=40, fast_flush=True)"
            )
            result.writeline(f"num_gb = {num_gb}")
            result.writeline("gb_per_s = num_gb / (ms / 1e3)")
            result.writeline(
                'print(f"{ms:.3f}ms    {num_gb:.3f}GB    {gb_per_s:.2f}GB/s")'
            )

        return result
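
    # Worked example of the throughput line the harness prints (timing and
    # num_gb invented):
    #
    #     ms, num_gb = 0.125, 0.4
    #     gb_per_s = num_gb / (ms / 1e3)  # 3200.0
    #     # prints: 0.125ms    0.400GB    3200.00GB/s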

    def imports_for_benchmark_kernel(self):
        return textwrap.dedent(
            """
            from torch._dynamo.testing import rand_strided
            {}
            import torch
            from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, grid_combo_kernels
        """.format(
                V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
            )
        )

    def uniquify_block_sizes(
        self, code: IndentedBuffer, num_kernel, uniquify: List[str]
    ) -> IndentedBuffer:
        if not uniquify:
            return code
        modified = IndentedBuffer(initial_indent=code._indent)
        for line in code._lines:
            if isinstance(line, str) and (blocks := [e for e in uniquify if e in line]):
                modified_line = line
                for block in blocks:
                    modified_line = modified_line.replace(
                        block, f"{block}_{num_kernel}"
                    )
                modified.writeline(modified_line)
            elif isinstance(line, DeferredLine) and (
                blocks := [e for e in uniquify if e in line.line]
            ):
                modified_line = line.line
                for block in blocks:
                    modified_line = modified_line.replace(
                        block, f"{block}_{num_kernel}"
                    )
                new_line = DeferredLine(line.name, modified_line)
                modified.writeline(new_line)
            else:
                modified.writeline(line)
        return modified

    def call_kernel(self, code, name: str):
        _, call_args, _ = self.args.python_argdefs()
        # Dynamo wraps unspec variables as 0d CPU tensors; convert them to scalars.
        for i in range(len(call_args)):
            if V.graph.is_unspec_arg(call_args[i]):
                call_args[i] = call_args[i] + ".item()"

        wrapper = V.graph.wrapper_code
        grid = self.grid(self.grids)
        grid = wrapper.generate_default_grid(
            name,
            grid,
            grid_callable=grid_combo_kernels,
            num_kernels=len(self.sub_kernels),
            min_blocks=max(self.min_x_blocks_list) * len(self.sub_kernels),
        )
        wrapper.generate_kernel_call(
            name,
            call_args,
            grid,
            V.graph.scheduler.current_device.index,
            cuda=True,
            triton=True,
            grid_fn="grid_combo_kernels",
            grid_extra_kwargs=f"num_kernels={len(self.sub_kernels)}, min_blocks={max(self.min_x_blocks_list) * len(self.sub_kernels)}",
        )

    def grid(self, sub_kernel_numels):
        xnumel = [e[-1] if len(e) > 0 else None for e in sub_kernel_numels]
        ynumel = [e[-2] if len(e) > 1 else None for e in sub_kernel_numels]
        znumel = [e[-3] if len(e) > 2 else None for e in sub_kernel_numels]

        # TODO: improve 1d/2d mixed cases
        xnumel = None if any(e is None for e in xnumel) else max(xnumel)
        ynumel = None if any(e is None for e in ynumel) else max(ynumel)
        znumel = None if any(e is None for e in znumel) else max(znumel)

        numels = (
            (xnumel,)
            if not ynumel
            else (ynumel, xnumel)
            if not znumel
            else (znumel, ynumel, xnumel)
        )
        return numels
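
    # For illustration, the merge above on invented numels: the combined launch
    # shape takes the max extent per dimension and collapses to 1D as soon as
    # any sub-kernel lacks a higher dimension:
    #
    #     grid([(1024,), (4096,)])     -> (4096,)
    #     grid([(8, 512), (16, 256)])  -> (16, 512)
    #     grid([(1024,), (16, 256)])   -> (1024,)   # mixed 1D/2D falls back to 1D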

@@ -1,248 +0,0 @@
import itertools
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Tuple

from sympy import Integer

from .. import metrics
from ..runtime.hints import DeviceProperties
from ..scheduler import SchedulerNode
from ..utils import ceildiv, Placeholder
from ..virtualized import V
from .common import IndentedBuffer, Kernel
from .triton import gen_common_triton_imports, TritonKernel
from .triton_utils import config_of, signature_to_meta


@dataclass
class PartitionState:
    partitions: List[
        List[Tuple[List[SchedulerNode], Tuple[Integer, ...], Integer, Integer]]
    ]
    cur_partition: List[
        Tuple[List[SchedulerNode], Tuple[Integer, ...], Integer, Integer]
    ]
    cur_count: int

    def finalize(self):
        if self.cur_partition:
            self.partitions.append(self.cur_partition)


class ForeachKernel(Kernel):
    MAX_NUM_ARGS = 250  # empirical limit above which Triton raises errors

    @staticmethod
    def _update_partition(partition_state, node_rw_count, node_info):
        if partition_state.cur_count + node_rw_count > ForeachKernel.MAX_NUM_ARGS:
            partition_state.partitions.append(partition_state.cur_partition)
            partition_state.cur_partition = [node_info]
            partition_state.cur_count = node_rw_count
        else:
            partition_state.cur_count += node_rw_count
            partition_state.cur_partition.append(node_info)

    @staticmethod
    def horizontal_partition(subkernel_nodes, triton_scheduling):
        """Generates a list of lists of node info tuples, (fused_nodes, tiling, numel, rnumel),
        one per subkernel node. Each sublist is guaranteed not to exceed the CUDA limit on the
        number of args (reads/writes) and to share the same 1D or 2D blocking strategy."""
        assert len(subkernel_nodes) >= 1

        partition_state_1d = PartitionState([], [], 0)
        yelem_to_partition_state_2d: Dict[Integer, PartitionState] = defaultdict(
            lambda: PartitionState([], [], 0)
        )

        for node in subkernel_nodes:
            fused_nodes = node.get_nodes()
            _, (numel, rnumel) = max(
                fused_nodes, key=lambda x: int(x.is_reduction())
            ).group
            tiled_groups = triton_scheduling.select_tiling(fused_nodes, numel, rnumel)
            node_info = fused_nodes, tiled_groups, numel, rnumel

            read_writes = node.read_writes
            read_write_count = len(read_writes.reads) + len(read_writes.writes)

            if tiled_groups[1] == 1:
                ForeachKernel._update_partition(
                    partition_state_1d, read_write_count, node_info
                )
            else:
                y_elem = tiled_groups[0]
                partition_state_2d = yelem_to_partition_state_2d[y_elem]
                ForeachKernel._update_partition(
                    partition_state_2d, read_write_count, node_info
                )

        partition_state_1d.finalize()
        all_partitions = partition_state_1d.partitions
        for partition_state_2d in yelem_to_partition_state_2d.values():
            partition_state_2d.finalize()
            all_partitions.extend(partition_state_2d.partitions)

        return all_partitions
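
    # A toy run of the greedy packing above, with the cap shrunk from
    # MAX_NUM_ARGS to 6 and invented per-node read/write counts:
    #
    #     nodes: a=3, b=2, c=4, d=1
    #     a -> cur=[a] (count 3);  b -> cur=[a, b] (count 5)
    #     c -> 5 + 4 > 6, flush: partitions=[[a, b]], cur=[c] (count 4)
    #     d -> cur=[c, d] (count 5);  finalize() -> [[a, b], [c, d]]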

    def __init__(self):
        super().__init__()
        self.blocking_2d = False
        self.block_size_1d = 1024  # Try tuning this value
        self.block_size_2d = 32
        self.num_warps = 8
        self.sub_kernels = []
        self.iter_vars_count = itertools.count()
        self.x_block_count = 0
        self.y_block_count = 0

    def get_block_size(self):
        if self.blocking_2d:
            return self.block_size_2d
        else:
            return self.block_size_1d

    @staticmethod
    def codegen_pid_offsets(code, block_count, lower_bound, prefix):
        if block_count == 0:
            code.splice(f"{prefix}pid_offset = {prefix}pid")
        else:
            code.splice(f"{prefix}pid_offset = {prefix}pid - {lower_bound}")

    def codegen_pid_range(self, code, x_elems):
        num_x_blocks = ceildiv(x_elems, self.get_block_size())
        upper_bound_x_pid = self.x_block_count + num_x_blocks
        lower_bound_x_pid = self.x_block_count

        if self.x_block_count == 0:
            cond = "if"
        else:
            cond = "elif"

        x_pid_bounds_check = (
            f"xpid >= {lower_bound_x_pid} and xpid < {upper_bound_x_pid}"
        )
        code.splice(f"{cond} {x_pid_bounds_check}:")

        with code.indent():
            ForeachKernel.codegen_pid_offsets(
                code, num_x_blocks, lower_bound_x_pid, "x"
            )
            self.x_block_count += num_x_blocks
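
    # Two consecutive calls with invented x_elems of 4096 and 6144 (1D block
    # size 1024, hence 4 and 6 blocks) would emit the dispatch chain:
    #
    #     if xpid >= 0 and xpid < 4:
    #         xpid_offset = xpid - 0
    #     elif xpid >= 4 and xpid < 10:
    #         xpid_offset = xpid - 4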

    def create_sub_kernel(self, *groups, index_dtype, mutations, reduction_hint):
        sub_kernel = TritonKernel(
            *groups,
            index_dtype=index_dtype,
            mutations=mutations,
            pid_cache={
                "tl.program_id(0)": "xpid_offset",
                "tl.program_id(1)": "ypid",
            },
            reduction_hint=reduction_hint,
        )
        if self.blocking_2d:
            assert len(groups) == 3

        self.blocking_2d |= groups[1] != 1 and len(groups) == 3
        metrics.generated_kernel_count -= 1
        sub_kernel.args = self.args
        sub_kernel.iter_vars_count = self.iter_vars_count
        sub_kernel.cse.iter_buffer_ids = self.cse.iter_buffer_ids
        self.sub_kernels.append(sub_kernel)
        return sub_kernel

    def jit_lines(self):
        can_use_32bit = all(k.index_dtype == "tl.int32" for k in self.sub_kernels)
        size_dtype = "tl.int32" if can_use_32bit else "tl.int64"
        _, _, signature = self.args.python_argdefs()
        triton_meta = {
            "signature": signature_to_meta(signature, size_dtype=size_dtype),
            "device": DeviceProperties.create(V.graph.scheduler.current_device),
            "constants": {},
        }
        triton_meta["configs"] = [config_of(signature)]
        inductor_meta = {
            "kernel_name": str(Placeholder.DESCRIPTIVE_NAME),
            **TritonKernel.inductor_meta_common(),
        }
        return f"""
            @triton_heuristics.foreach(
                num_warps={self.num_warps},
                triton_meta={triton_meta!r},
                inductor_meta={inductor_meta!r},
            )
            @triton.jit
        """

    def grid(self):
        return (
            self.x_block_count,
            ceildiv(int(self.sub_kernels[0].numels[0]), self.block_size_2d)
            if self.blocking_2d
            else 1,
            1,
        )

    def codegen_kernel(self, name=None):
        code = IndentedBuffer()

        code.splice(gen_common_triton_imports())
        argdefs, _, _ = self.args.python_argdefs()
        code.splice(self.jit_lines())
        code.writeline(
            f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(argdefs)}):"
        )

        with code.indent():
            code.splice("xpid = tl.program_id(0)")
            if self.blocking_2d:
                code.splice("ypid = tl.program_id(1)")
                code.splice(f"XBLOCK: tl.constexpr = {self.block_size_2d}")
                code.splice(f"YBLOCK: tl.constexpr = {self.block_size_2d}")
            else:
                code.splice(f"XBLOCK: tl.constexpr = {self.block_size_1d}")

            for sub_kernel in self.sub_kernels:
                assert len(sub_kernel.numels) <= 3
                # TODO mlazos: support dynamic shapes
                numel_ind = 0 if not self.blocking_2d else 1
                self.codegen_pid_range(code, int(sub_kernel.numels[numel_ind]))
                with code.indent():
                    if self.blocking_2d:
                        code.splice(f"ynumel = {sub_kernel.numels[0]}")
                        code.splice(f"xnumel = {sub_kernel.numels[1]}")
                    else:
                        code.splice(f"xnumel = {sub_kernel.numels[0]}")

                    sub_kernel.codegen_body()
                    code.splice(sub_kernel.body)

            code.splice("else:")
            with code.indent():
                code.splice("pass")

        return code.getvalue()

    def call_kernel(self, code, name: str):
        _, call_args, _ = self.args.python_argdefs()
        # Dynamo wraps unspec variables as 0d CPU tensors; convert them to scalars.
        for i in range(len(call_args)):
            if V.graph.is_unspec_arg(call_args[i]):
                call_args[i] = call_args[i] + ".item()"
        if V.graph.cpp_wrapper:
            V.graph.wrapper_code.generate_kernel_call(
                name,
                call_args,
                device_index=V.graph.scheduler.current_device.index,
                grid=self.grid(),
            )
        else:
            # TODO: refactor generate_kernel_call
            call_args_str = ", ".join(call_args)
            stream_name = code.write_get_raw_stream(
                V.graph.scheduler.current_device.index
            )
            code.writeline(
                f"{name}.run({call_args_str}, grid=({self.grid()}), stream={stream_name})"
            )

@@ -529,7 +529,7 @@ class WrapperCodeGen(CodeGen):
            """
            import triton
            import triton.language as tl
            from {} import grid, split_scan_grid, start_graph, end_graph
            from {} import grid, split_scan_grid, grid_combo_kernels, start_graph, end_graph
            {}
            """.format(
                triton_heuristics.__name__,
@@ -1361,8 +1361,15 @@ class WrapperCodeGen(CodeGen):
            """
        )

    def generate_default_grid(self, name: str, grid_args: List[Any]):
        return grid_args
    def generate_default_grid(
        self,
        name: str,
        grid: List[Any],
        cuda: bool = True,
        grid_callable: Optional[Callable[..., Any]] = None,
        **grid_extra_kwargs,
    ):
        return grid

    def generate_kernel_call(
        self,
@@ -1375,6 +1382,7 @@ class WrapperCodeGen(CodeGen):
        arg_types=None,
        grid_fn: str = "grid",
        triton_meta=None,
        grid_extra_kwargs="",
    ):
        """
        Generates kernel call code.
@@ -1392,6 +1400,8 @@ class WrapperCodeGen(CodeGen):
            )
            if triton:
                grid_str = ", ".join(pexpr(item) for item in grid)
                if grid_extra_kwargs:
                    grid_str = f"{grid_str}, {grid_extra_kwargs}"
                grid_str = f"{grid_fn}({grid_str})"
                self.writeline(
                    f"{name}.run({call_args_str}, grid={grid_str}, stream={stream_name})"

@@ -362,6 +362,12 @@ assert_indirect_indexing = True
# compute CSE bounds on variables that do not appear in the FX graph
compute_all_bounds = False

# enable combo kernels that combine data-independent kernels (in addition
# to foreach kernels) into a single one (experimental)
combo_kernels = False
# benchmark combo kernels and only allow ones with perf gains
benchmark_combo_kernel = False
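
# Hypothetical opt-in sketch (the two flags are the ones defined above; the
# compiled function is made up for illustration):
#
#     import torch
#     import torch._inductor.config as inductor_config
#
#     inductor_config.combo_kernels = True
#     inductor_config.benchmark_combo_kernel = True  # keep only combos with perf gains
#
#     @torch.compile
#     def f(a, b):
#         return a.relu(), b.sigmoid()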

# constant folding on the joint graph
joint_graph_constant_folding = True


@@ -1741,3 +1741,15 @@ def split_scan_grid(xnumel, rnumel):
    setattr(grid_fn, "grid_fn_str", grid_fn_str)  # noqa: B010

    return grid_fn


def grid_combo_kernels(*numels, num_kernels, min_blocks):
    """min_blocks is the minimal size of the grid x dimension"""
    kernel_grid_fn = grid(*numels)

    def grid_fn(meta):
        cuda_grid = list(kernel_grid_fn(meta))
        cuda_grid[0] = max(num_kernels * cuda_grid[0], min_blocks)
        return tuple(cuda_grid)

    return grid_fn
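
# A self-contained mimic of the scaling above (assuming, as `grid` does for
# the 1D case, that the x dimension is ceil-divided by meta["XBLOCK"]):
#
#     from math import ceil
#     base = lambda meta: (ceil(2048 / meta["XBLOCK"]), 1, 1)      # grid(2048)
#     combo = lambda meta: (max(3 * base(meta)[0], 4),) + base(meta)[1:]
#     combo({"XBLOCK": 1024})  # -> (6, 1, 1): 2 blocks per kernel * 3 kernels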

@@ -1025,8 +1025,10 @@ class FusedSchedulerNode(BaseSchedulerNode):


class ForeachKernelSchedulerNode(FusedSchedulerNode):
    """Scheduler node which consists of a list of scheduler nodes that each operate on a
    distinct tensor in a list of tensors."""
    """
    This is a scheduler node that consists of a set of scheduler nodes that
    have no data dependencies among them and can be executed in parallel.
    """

    def get_consumer_subnode_for(self, producer):
        if producer.get_name() in self.read_to_node:
@@ -1075,6 +1077,11 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
    @classmethod
    def fuse(cls, producer, consumer):
        assert producer.is_foreach() or consumer.is_foreach()
        use_custom_partition_algo = (
            producer.use_custom_partition_algo
            if producer.is_foreach()
            else consumer.use_custom_partition_algo
        )
        prev_node_1 = None
        prev_node_2 = None
        if producer.is_foreach() and consumer.is_foreach():
@@ -1108,13 +1115,24 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
                    fused_nodes.append(new_node)
                else:
                    fused_nodes.append(node)
        else:
            raise AssertionError(
                "At least one node passed to ForeachKernelSchedulerNode.fuse should be a foreach node"
            )

        return cls(producer.scheduler, fused_nodes, prev_node_1, prev_node_2)  # type: ignore[possibly-undefined]
        return cls(
            producer.scheduler,
            fused_nodes,
            use_custom_partition_algo=use_custom_partition_algo,
            prev_node_1=prev_node_1,
            prev_node_2=prev_node_2,
        )

    def __init__(
        self,
        scheduler: "Scheduler",
        nodes: List[SchedulerNode],
        snodes: List[SchedulerNode],
        use_custom_partition_algo: bool,
        prev_node_1=None,
        prev_node_2=None,
    ):
@@ -1122,9 +1140,9 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
        self.name_to_node = {}

        if prev_node_1 is None or prev_node_2 is None:
            super().__init__(scheduler, nodes)
            super().__init__(scheduler, snodes)

            for node in nodes:
            for node in snodes:
                for read in node.read_writes.reads:
                    self.read_to_node[read.name] = node

@@ -1132,7 +1150,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
                    self.name_to_node[name] = node
        else:
            self.scheduler = scheduler
            self.snodes = nodes
            self.snodes = snodes
            self.node: ir.Buffer = None  # type: ignore[assignment]
            self.users: List[NodeUser] = []

@@ -1163,17 +1181,43 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
            for name in other_node.get_names():
                self.name_to_node[name] = other_node

        self.group = (nodes[0].get_device(), "foreach")

        self.use_custom_partition_algo = use_custom_partition_algo
        self.group = (snodes[0].get_device(), "combo_kernel")
        self.origins: Set[torch.fx.Node] = set()

    @classmethod
    def combinable_nodes(cls, nodes: List[SchedulerNode]) -> List[SchedulerNode]:
        extern = [x for x in nodes if isinstance(x, ExternKernelSchedulerNode)]
        if extern:
            log.debug(
                "ComboKernels: %d external nodes are filtered %s",
                len(extern),
                [node.node.origins for node in extern],
            )
        filtered_nodes = [
            x
            for x in nodes
            if not isinstance(x, (NopKernelSchedulerNode, ExternKernelSchedulerNode))
        ]
        foreach_nodes = [
            x for x in filtered_nodes if isinstance(x, ForeachKernelSchedulerNode)
        ]
        if foreach_nodes:
            log.debug("ComboKernels: %d foreach nodes are filtered", len(foreach_nodes))
        filtered_nodes = [
            x for x in filtered_nodes if not isinstance(x, ForeachKernelSchedulerNode)
        ]
        template_nodes = [x for x in filtered_nodes if x.is_template()]
        if template_nodes:
            log.debug(
                "ComboKernels: %d template nodes are filtered", len(template_nodes)
            )
        filtered_nodes = [x for x in filtered_nodes if x not in template_nodes]
        return filtered_nodes

    def mark_run(self):
        raise NotImplementedError

    def codegen(self):
        assert isinstance(self.node, ir.ComputedBuffer), f"{type(self.node)=}"
        self.node.get_store_function()(self.node.make_loader()())

    def can_free(self):
        return NotImplementedError

@@ -1181,7 +1225,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
        return True

    def get_subkernel_nodes(self):
        """Returns a list of nodes which comprise the foreach kernel, operating on corresponding elements of our input lists.
        """Returns a list of nodes which comprise the combo kernel.
        These nodes may be vertically fused."""
        return list(self.snodes)

@@ -1339,6 +1383,8 @@ class Scheduler:
            # Refresh node_users and inverse_users to reflect fused nodes
            self.compute_node_users()
            self.nodes = comms.reorder_compute_and_comm_for_overlap(self.nodes)
        if config.combo_kernels:
            self.create_combo_kernel_nodes(num_ck_nodes=None)
        self.compute_last_usage()
        V.debug.ir_post_fusion(self.nodes)
        V.debug.graph_diagram(self.nodes)
@@ -1405,7 +1451,7 @@ class Scheduler:
            removed_node_names.update(names)
            snodes = [self.name_to_node[name] for name in names]

            fe_node = ForeachKernelSchedulerNode(self, snodes)  # type: ignore[arg-type]
            fe_node = ForeachKernelSchedulerNode(self, snodes, use_custom_partition_algo=False)  # type: ignore[arg-type]

            fe_nodes.append(fe_node)

@@ -1694,6 +1740,51 @@ class Scheduler:
            visit(node)
        self.nodes = result

    def _get_unmet_dep_nodes(self, snode):
        unmet_deps = set()
        if isinstance(
            snode,
            (
                SchedulerNode,
                ExternKernelSchedulerNode,
                NopKernelSchedulerNode,
                FusedSchedulerNode,
            ),
        ):
            for dep in snode.unmet_dependencies:
                unmet_deps.add(dep.name)
        else:
            raise RuntimeError(
                f"_get_unmet_dep_nodes is not implemented for {type(snode)}."
            )
        return list({self.name_to_fused_node[n] for n in unmet_deps})

    def _topological_sort_nodes(self):
        """
        Sort nodes by their topological order, return a list of node lists.
        """
        order = []
        nodes = {n: 0 for n in self.nodes}
        children: Dict[Any, Any] = {}
        for node in self.nodes:
            deps = self._get_unmet_dep_nodes(node)
            nodes[node] = len(deps)
            for dep in deps:
                c = children.get(dep, [])
                c.append(node)
                children[dep] = c

        zero_deg_nodes = [n for n, v in nodes.items() if v == 0]
        while zero_deg_nodes:
            order.append(zero_deg_nodes)
            for n in zero_deg_nodes:
                for user in children.get(n, []):
                    nodes[user] -= 1
                nodes.pop(n)
            zero_deg_nodes = [n for n, v in nodes.items() if v == 0]
        assert not nodes, "Topological sort failed!"
        return order
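
    # Toy level-order run of the Kahn-style sort above on an invented graph
    # with edges a -> c and b -> c:
    #
    #     level 0: [a, b]   # zero unmet dependencies
    #     level 1: [c]      # unblocked once a and b are removed
    #
    # Nodes within one level have no dependencies on each other, which is what
    # makes each level a candidate pool for grouping into combo kernels.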

    def compute_ancestors(self):
        """
        Populate each node.ancestors
@@ -1975,6 +2066,65 @@ class Scheduler:
        self.topological_sort_schedule()
        self.prune_redundant_deps()

    def create_combo_kernel_nodes(self, num_ck_nodes=None):
        """
        Groups parallel (data-independent) nodes into combo kernels.
        """
        fused_nodes = set(self.nodes)
        count = 0
        num_nodes_orig = len(self.nodes)
        log.debug("ComboKernels: Generating with num_ck_nodes = %s...", num_ck_nodes)
        for num, node_list in enumerate(self._group_nodes_for_combo_kernels()):
            node_list = ForeachKernelSchedulerNode.combinable_nodes(node_list)
            if len(node_list) < 2:
                continue
            if num_ck_nodes is not None and count > num_ck_nodes:
                break
            if not self.speedup_by_combo_kernel(node_list):
                log.debug("ComboKernels: Not speeding up %d-th group", num)
                continue
            count += 1
            group_snode = ForeachKernelSchedulerNode(
                node_list[0].scheduler, node_list, use_custom_partition_algo=True
            )
            log.info(
                "ComboKernels: Combining %d nodes for %d-th group",
                len(node_list),
                num,
            )
            for node in node_list:
                fused_nodes.remove(node)
            fused_nodes.add(group_snode)
            self.name_to_fused_node.update(
                {n.get_name(): group_snode for n in group_snode.get_nodes()}
            )
        self.nodes = sorted(fused_nodes, key=lambda x: x.min_order)
        self.topological_sort_schedule()
        log.info(
            "Generated ComboKernel nodes: %d ComboKernels, in total %d -> %d nodes",
            count,
            num_nodes_orig,
            len(self.nodes),
        )
        self.prune_redundant_deps()

    def _group_nodes_for_combo_kernels(self):
        """
        Returns a list of lists of nodes that are to be grouped together.
        """
        sorted_nodes = self._topological_sort_nodes()
        grouped_nodes = []
        max_num_nodes = 8
        for nodes in sorted_nodes:
            grouped_nodes.extend(
                [
                    nodes[i : i + max_num_nodes]
                    for i in range(0, len(nodes), max_num_nodes)
                ]
            )

        return grouped_nodes
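
    # The chunking above just slices each topological level into groups of at
    # most max_num_nodes, e.g. for an invented 19-node level:
    #
    #     nodes = list(range(19))
    #     [nodes[i : i + 8] for i in range(0, len(nodes), 8)]
    #     # -> [[0, ..., 7], [8, ..., 15], [16, 17, 18]]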

    def prune_redundant_deps(self):
        for node in self.nodes:
            node.prune_redundant_deps(self.name_to_fused_node)
@@ -2583,7 +2733,7 @@ class Scheduler:
            elif node.is_extern():
                self.codegen_extern_call(node)
            elif node.is_foreach():
                self.get_backend(device).codegen_foreach(node)  # type: ignore[possibly-undefined]
                self.get_backend(device).codegen_combo_kernel(node)  # type: ignore[possibly-undefined]
            elif isinstance(node, (FusedSchedulerNode, SchedulerNode)):
                self.get_backend(device).codegen_node(node)  # type: ignore[possibly-undefined]
            else:
@@ -2614,6 +2764,98 @@ class Scheduler:
        node = self.name_to_node[buf_name]
        return node.node.get_layout()

    def benchmark_combo_kernel(self, node_list):
        """
        Benchmark the list of nodes as a combo kernel and return its execution
        time (and the input-clone overhead) in milliseconds on randomly
        generated inputs.
        """
        device = node_list[0].get_device()
        V.graph.scheduler = self
        self.current_device = device
        backend = self.get_backend(device)
        return backend.benchmark_combo_kernel(node_list)

    def speedup_by_combo_kernel(self, node_list):
        """
        If config.benchmark_combo_kernel is False, always return True.
        Otherwise, return True if the combo kernel brings a speedup.
        """
        if not config.benchmark_combo_kernel:
            return True

        subkernel_nodes = node_list
        device = subkernel_nodes[0].get_device()

        # benchmark fusion is not supported for CPU right now.
        if device.type == "cpu":
            return True

        from triton.compiler.errors import CompilationError

        ms1, path1_list = 0.0, []
        for i, snode in enumerate(subkernel_nodes):
            node_list = snode.get_nodes()
            # We cannot accurately benchmark kernels that use atomic_add
            # due to how we generate random integer inputs.
            if any(
                hasattr(n.node, "data")
                and hasattr(n.node.data, "scatter_mode")
                and n.node.data.scatter_mode == "atomic_add"
                for n in node_list
            ):
                fusion_log.debug(
                    "ComboKernel: benchmarking may not be accurate due to atomic_add"
                )

            try:
                ms, path = self.benchmark_fused_nodes(node_list)
                if math.isinf(ms):
                    fusion_log.debug(
                        "ComboKernel benchmark: register spilling of %d-th subkernel",
                        i,
                    )
                    return False
            except CompilationError as e:
                # workaround triton issue: https://github.com/openai/triton/issues/2151
                if "Loop-carried variable" in str(e):
                    fusion_log.debug(
                        "ComboKernel benchmark: return True because of loop-carried variable"
                    )
                    return True  # allow fusion
                else:
                    raise
            ms1 += ms
            path1_list.append(path)

        try:
            ms2, ms2_clone, path2_list = self.benchmark_combo_kernel(subkernel_nodes)
        except CompilationError as e:
            # workaround triton issue: https://github.com/openai/triton/issues/2151
            if "Loop-carried variable" in str(e):
                fusion_log.debug(
                    "ComboKernel benchmark: return True because of loop-carried variable"
                )
                return True  # allow fusion
            else:
                raise

        # Small kernels are very likely to benefit but are hard to benchmark
        # reliably, so we treat them as a speedup.
        small_kernel = ms2_clone / ms2 > 0.6 and ms2 - ms2_clone < 0.2
        if fusion_log.isEnabledFor(logging.DEBUG):
            if ms1 > ms2 or small_kernel:
                fusion_log.debug(
                    "can fuse (benchmark): fusing causes %sx speedup",
                    green_text(f"{ms1 / ms2:.3f}"),
                )
            else:
                fusion_log.debug(
                    "cannot fuse (benchmark): fusing causes %sx slowdown",
                    red_text(f"{ms1 / ms2:.3f}"),
                )

        return ms2 < ms1 or small_kernel
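
    # Worked example of the small-kernel guard above (timings in ms invented):
    #
    #     ms2, ms2_clone = 0.30, 0.22
    #     ms2_clone / ms2 > 0.6 and ms2 - ms2_clone < 0.2  # True
    #
    # Cloning dominates the measured time, so the benchmark is considered
    # unreliable and the combo kernel is accepted regardless of ms1.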


class BaseScheduling:
    def can_fuse_vertical(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):