mirror of https://github.com/pytorch/pytorch.git
synced 2025-11-04 16:04:58 +08:00
Compare commits

3 Commits

ciflow/tru ... new-codege
| Author | SHA1 | Date |
|---|---|---|
| | 25e1e99fcd | |
| | f803a57a7a | |
| | 39287c561b | |
@@ -3573,6 +3573,7 @@ class CommonTemplate:
        self.assertEqual(actual, expect)

    @skip_if_halide  # only 32-bit indexing
    @skipIfRocm
    @largeTensorTest("4GB", inductor=True)
    def test_large_pointwise(self):
        def fn(a):
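For context, a minimal sketch of what a test of this shape exercises (the body is truncated in this view; the function below is an assumed stand-in, not the PR's actual code): a pointwise op on a tensor large enough that 32-bit indexing no longer suffices, hence the Halide skip.

```python
import torch

def fn(a):
    return a + 1.0  # assumed pointwise body; the hunk is truncated here

compiled = torch.compile(fn)
a = torch.randn(2**30, device="cuda")  # ~4GB of float32, per the decorator
torch.testing.assert_close(compiled(a), fn(a))
```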
@@ -12962,6 +12963,7 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
            if name not in {"softmax", "log_softmax", "logsumexp"}
        ],
    )
    @skipIfRocm
    def test_pointwise(self, name, op):
        dtype = torch.float32
        check_lowp = True

@@ -34,7 +34,6 @@ from ...utils._sympy.value_ranges import ValueRanges
from .. import config, ir, metrics
from ..async_compile import AsyncCompile
from ..codecache import code_hash, get_path, PyCodeCache, write_atomic
from ..debug import set_kernel_post_grad_provenance_tracing
from ..ops_handler import DefaultHandler
from ..runtime import triton_heuristics
from ..runtime.benchmarking import benchmarker
@@ -48,7 +47,6 @@ from ..runtime.runtime_utils import get_max_y_grid, next_power_of_2
from ..scheduler import BaseSchedulerNode, FusedSchedulerNode, Scheduler, SchedulerNode
from ..utils import (
    cache_on_self,
    DelayMaybeLine,
    DelayReplaceLine,
    get_bounds_index_expr,
    get_fused_kernel_name,
@@ -75,7 +73,6 @@ from .common import (
    DeferredLine,
    IndentedBuffer,
    InplacedBuffer,
    is_buffer_removed,
    OpOverrides,
    PythonPrinter,
    RemovedArg,
@@ -641,13 +638,6 @@ def triton_reshape(
    return f"{value}[{', '.join(expand)}]"


def enable_pdl_codegen():
    if not torch._inductor.config.triton.enable_pdl:
        return False
    major, _ = torch.cuda.get_device_capability(torch.cuda.current_device())
    return major >= 9


# NB: Inheriting from PythonPrinter is somewhat dangerous, because there are a
# number of operators which Triton "implements", but in a way that is
# inconsistent with Python semantics (and consistent with C semantics).  We
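One side of this compare deletes this helper in favor of checking `config.triton.enable_pdl` directly at call sites (see the `codegen_kernel` hunk further below). A hedged sketch of the intended call-site pattern, assuming the helper from the hunk above is in scope:

```python
# Guard every PDL-specific line behind the gate so generated kernels
# stay valid when the flag is off or the GPU is pre-Hopper (SM < 9.0).
def maybe_emit_gdc_wait(lines: list[str]) -> None:
    if enable_pdl_codegen():
        lines.append("tl.extra.cuda.gdc_wait()")
```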
@@ -1607,6 +1597,10 @@ class TritonKernelOverrides(TritonOverrides):
        V.kernel.cse.put(cache_key, (mantissa, exponent))
        return (mantissa, exponent)

    @staticmethod
    def device_assert_async(cond, msg):
        return f"tl.device_assert({cond}, {repr(msg)})"


class HelperFunctions:
    """An ordered set of helper functions."""
@@ -1950,8 +1944,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
        self.fixed_config = fixed_config
        super().__init__(tiling, **kwargs)
        self.cse = TritonCSE(self.newvar_prefix, self.suffix)
        # Cache of values that can be reused for the prologue.
        self.prologue_cache: dict[str, str] = {}
        self.prologue: IndentedBuffer = IndentedBuffer()
        self.post_loop_combine: IndentedBuffer = IndentedBuffer()
        self.post_loop_store: IndentedBuffer = IndentedBuffer()
@@ -1966,7 +1958,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
        self.tma_min_block_sizes = dict[str, int]()
        self.hint_override = hint_override
        self._load_counts: collections.Counter[str] = collections.Counter()
        self._load_index = 0

        # A set of autotuning hints to pass as part of triton_meta
        self.autotune_hints = OrderedSet[AutotuneHint]()
@@ -1983,44 +1974,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
        if self.cooperative_reduction:
            self.init_cooperative_reduction_mask()

        self.has_load_with_contiguous_rdim = False
        # We track the store name since a store can be canceled later
        self.stores_with_contiguous_rdim: list[str] = []

    @staticmethod
    def _has_stride1_on_rdim(index) -> bool:
        # This analysis is only needed in deterministic mode so far,
        # to filter triton configs. Return False immediately to avoid
        # increasing compilation time when the mode is off.
        if not (
            config.deterministic or config.test_configs.force_filter_reduction_configs
        ):
            return False
        support_vars = index.free_symbols
        reduce_vars = [
            var
            for var in support_vars
            if symbol_is_type(var, TritonSymbols.reduction_types)
        ]

        if len(reduce_vars) == 0:
            return False

        # For an expression like "x0 + 150528*((x1//(s27*s38))) + 3*(ModularIndexing(x1, 1, s38)) + 672*(ModularIndexing(x1, s38, s27))",
        # stride_vars would raise a DivisionByZero error.
        try:
            stride_vars = V.graph.sizevars.stride_vars(index, reduce_vars, support_vars)
        except ZeroDivisionError:
            return False

        return any(stride == 1 for stride in stride_vars)

    @property
    def has_store_with_contiguous_rdim(self) -> bool:
        return not all(
            is_buffer_removed(name) for name in self.stores_with_contiguous_rdim
        )

    def dtype_to_str(self, dtype: torch.dtype) -> str:
        return triton_type(dtype)

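A standalone sketch of what `_has_stride1_on_rdim` computes, using plain sympy in place of `V.graph.sizevars.stride_vars` (illustrative symbols, not the real helper):

```python
import sympy

x0 = sympy.Symbol("x0", integer=True)
r0 = sympy.Symbol("r0_0", integer=True)  # a reduction-typed symbol

index = 512 * x0 + r0  # reduction dim advances by 1 -> contiguous access
stride_on_r = sympy.diff(index, r0)  # coefficient of r0_0 is its stride
assert stride_on_r == 1
```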
@@ -2088,6 +2041,10 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
            )

    def codegen_range_tree(self):
        # Skip if pointwise inner loop
        # return

        for tree in self.range_trees:
            # reduction indexing goes inside a loop
            if not tree.is_loop:
@@ -2532,95 +2489,86 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
            and self.range_trees[-1].is_loop
            and indexing.has_rindex()
        ) or indexing.can_lift:
            if indexing.can_lift and var in self.prologue_cache:
                # Check for epilogue subtiling to reuse the same
                # tensor descriptor.
                block_descriptor = self.prologue_cache[var]
            block_descriptor_id = next(self.block_ptr_id)
            if isinstance(indexing, BlockPtrOptions):
                block_descriptor = f"block_ptr{block_descriptor_id}"
            else:
                block_descriptor_id = next(self.block_ptr_id)
                if isinstance(indexing, BlockPtrOptions):
                    block_descriptor = f"block_ptr{block_descriptor_id}"
                else:
                    block_descriptor = f"tma_descriptor{block_descriptor_id}"
                line_body = DeferredLine(
                    name, f"{block_descriptor} = {indexing.format(var, roffset=False)}"
                )
                if indexing.can_lift:
                    self.prologue.writeline(line_body)
                    # Cache the descriptor for epilogue subtiling
                    self.prologue_cache[var] = block_descriptor
                else:
                    self.body.writeline(line_body)
                block_descriptor = f"tma_descriptor{block_descriptor_id}"
            line_body = DeferredLine(
                name, f"{block_descriptor} = {indexing.format(var, roffset=False)}"
            )
            if indexing.can_lift:
                self.prologue.writeline(line_body)
            else:
                self.body.writeline(line_body)

                if isinstance(indexing, BlockPtrOptions):
                    # Store for later use. If the buffer is removed the below advancements
                    # are no longer necessary
                    self.block_ptr_to_buffer[block_descriptor] = name
            if isinstance(indexing, BlockPtrOptions):
                # Store for later use. If the buffer is removed the below advancements
                # are no longer necessary
                self.block_ptr_to_buffer[block_descriptor] = name

                    # Generate block pointer advancements, for later use.
                    for symt in TritonSymbols.reduction_types:
                        advance_offsets = indexing.advance_roffset(symt)
                # Generate block pointer advancements, for later use.
                for symt in TritonSymbols.reduction_types:
                    advance_offsets = indexing.advance_roffset(symt)

                        # Ignore identity advancements.
                        if all(
                            V.graph.sizevars.statically_known_equals(
                                offset, sympy.Integer(0)
                            )
                            for offset in advance_offsets
                        ):
                            continue

                        advancements = self.pointer_advancements[symt]
                        assert block_descriptor not in advancements, (
                            f"duplicate advancement for pointer '{block_descriptor}' at type '{symt}'"
                    # Ignore identity advancements.
                    if all(
                        V.graph.sizevars.statically_known_equals(
                            offset, sympy.Integer(0)
                        )
                        advancements[block_descriptor] = advance_offsets
                        for offset in advance_offsets
                    ):
                        continue

                    advancements = self.pointer_advancements[symt]
                    assert block_descriptor not in advancements, (
                        f"duplicate advancement for pointer '{block_descriptor}' at type '{symt}'"
                    )
                    advancements[block_descriptor] = advance_offsets
        else:
            block_descriptor = indexing.format(var)
        return block_descriptor, other

    def codegen_block_ptr_store_line(self, name, indexing, block_ptr, value, other=""):
        def stringify_shape(shape):
            return tuple(
                symt.name if isinstance(symt, sympy.Symbol) else str(symt)
                for symt in shape
            )

        if value.shape:
            value_forward_shape = stringify_shape(value.shape)
            value_reverse_shape = stringify_shape(value.shape[::-1])
        # TMA stores may require transposing the data to ensure we are contiguous along
        # the final dimension. We do this by checking the shape information on value.
        # It can either
        #    1. Match the final shape. In this case no broadcast/reshape
        #       is necessary.
        #    2. Exist as the Transpose of the final shape, which means we had to transpose
        #       the store_descriptor relative to the accumulator indexing/value. If this
        #       happens we will generate a tl.trans().
        #    3. A mismatched provided shape. When this occurs we will error.
        #    4. No shape is provided. This will proceed with the default explicit broadcast
        #       described below.
        #
        # To prevent unintended side effects we will gate options 1-3 behind isinstance(indexing, TensorDescriptorOptions).
        if isinstance(indexing, TensorDescriptorOptions) and value.shape:
            str_final_shape = tuple([symt.name for symt in indexing.final_shape])
            if value.shape[::-1] == str_final_shape:
                value = f"tl.trans({value})"
            elif value.shape != str_final_shape:
                raise AssertionError(
                    "TMA store requires no broadcasting when a shape is provided"
                )
        else:
            value_forward_shape = None
            value_reverse_shape = None
        final_shape = stringify_shape(indexing.final_shape)
        # TODO: Generalize to N Dimensions
        if (
            value_forward_shape != final_shape
            and value_reverse_shape == final_shape
            and len(final_shape) == 2
        ):
            # TMA stores may require transposing the data to ensure we are contiguous along
            # the final dimension. This applies to Block-pointers generally, but should only practically
            # be reached with TMA.
            value = f"tl.trans({value})"
            # Stores require an explicit broadcast. We do this in two phases:
            #  1. Broadcast the operand to the final shape of the range trees, e.g. [ZBLOCK,
            #     YBLOCK, XBLOCK]. This protects against implicit broadcasting from loads.
            #  2. In case the block pointer / tma descriptor has different dimensionality, broadcast/reshape the
            #     result to the shape of the pointer.
            value = f"tl.broadcast_to({value}, {indexing.final_shape})"

        # Stores require an explicit broadcast. We do this in two phases:
        #  1. Broadcast the operand to the final shape of the range trees, e.g. [ZBLOCK,
        #     YBLOCK, XBLOCK]. This protects against implicit broadcasting from loads.
        #  2. In case the block pointer / tma descriptor has different dimensionality, broadcast/reshape the
        #     result to the shape of the pointer.
        value = f"tl.broadcast_to({value}, {indexing.final_shape})"
            # These dims no longer need broadcasting.
            for idx, (dim, broadcast_dim) in enumerate(
                zip(indexing.final_shape, indexing.broadcast_shape)
            ):
                if V.graph.sizevars.statically_known_equals(dim, broadcast_dim):
                    indexing.broadcasting_dims[idx] = False

        # These dims no longer need broadcasting.
        for idx, (dim, broadcast_dim) in enumerate(
            zip(indexing.final_shape, indexing.broadcast_shape)
        ):
            if V.graph.sizevars.statically_known_equals(dim, broadcast_dim):
                indexing.broadcasting_dims[idx] = False

        value = indexing.codegen_broadcast_and_reshape(
            value, indexing.final_shape, indexing.block_shape, False
        )
            value = indexing.codegen_broadcast_and_reshape(
                value, indexing.final_shape, indexing.block_shape, False
            )

        # workaround https://github.com/triton-lang/triton/issues/2814
        value = f"{value}.to({triton_store_type(V.graph.get_dtype(name))})"
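The shape-matching rules in the comment above reduce to a small decision table. A pure-Python sketch (names illustrative, not Inductor's API):

```python
def tma_store_adjustment(value_shape: tuple, final_shape: tuple) -> str:
    if value_shape == final_shape:
        return "store as-is"          # case 1: shapes already match
    if value_shape == final_shape[::-1]:
        return "tl.trans(value)"      # case 2: transposed descriptor
    raise AssertionError(             # case 3: anything else is an error
        "TMA store requires no broadcasting when a shape is provided"
    )

assert tma_store_adjustment(("XBLOCK", "YBLOCK"), ("YBLOCK", "XBLOCK")) == "tl.trans(value)"
```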
@@ -2669,27 +2617,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
        else:
            return self.loads

    def _handle_pdl_before_load(self, wait_buffer):
        GDC_WAIT = "tl.extra.cuda.gdc_wait()"
        self._load_index += 1
        if self.inside_reduction:
            wait_buffer = self.body
        if enable_pdl_codegen():
            if self._load_index == 1:
                wait_buffer.writeline(GDC_WAIT)

    def _handle_pdl_after_load(self, launch_buffer, result_var):
        GDC_LAUNCH = "tl.extra.cuda.gdc_launch_dependents()"
        if self.inside_reduction:
            launch_buffer = self.post_loop_combine
        if enable_pdl_codegen():
            current_load_index = self._load_index
            launch_if_last_load = DelayMaybeLine(
                lambda: current_load_index == self._load_index,
                f"0; {GDC_LAUNCH} # gdc launch for {result_var}",
            )
            self.cse.generate(launch_buffer, launch_if_last_load, dtype=torch.int32)

    def load(self, name: str, index: sympy.Expr):
        """
        Load from the memory location 'name', offset by some indexing expression 'index'.
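The two helpers above implement first-load/last-load bookkeeping: `gdc_wait()` must precede the first global load, and `gdc_launch_dependents()` must follow the last one, which is only known once codegen finishes (hence `DelayMaybeLine`). A minimal model of that counter logic:

```python
class PdlLoadTracker:
    """Sketch of the _load_index bookkeeping (not the real class)."""

    def __init__(self) -> None:
        self.load_index = 0

    def before_load(self) -> bool:
        self.load_index += 1
        return self.load_index == 1  # emit gdc_wait only before the first load

    def after_load_predicate(self):
        snapshot = self.load_index
        # Evaluated after codegen: true only if no later load bumped the
        # counter, i.e. this was the last load, so the launch line survives.
        return lambda: snapshot == self.load_index
```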
@@ -2711,12 +2638,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
                force=False,
            ),
        )

        if isinstance(indexing, IndexingOptions) and self._has_stride1_on_rdim(
            indexing.index
        ):
            self.has_load_with_contiguous_rdim = True

        has_rindex = indexing.has_rindex()
        has_tmpmask = indexing.has_tmpmask()

@@ -2826,11 +2747,11 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
                dtype = torch.bool

        load_buffer = self.get_load_buffer(indexing)
        self._handle_pdl_before_load(load_buffer)
        if config.triton.enable_pdl:
            load_buffer.writeline("tl.extra.cuda.gdc_wait()")
        result_var = self.cse.generate(
            load_buffer, make_line(line), dtype=dtype, shape=shape
        )
        self._handle_pdl_after_load(load_buffer, result_var)
        if result_var.use_count > 1:
            load_counts[name] -= 1  # don't double count cache hit
        assert isinstance(result_var, TritonCSEVariable)
@@ -2884,11 +2805,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
            tma_compatibility_checker=tma_compatibility_checker,
        )

        if isinstance(indexing, IndexingOptions) and self._has_stride1_on_rdim(
            indexing.index
        ):
            self.stores_with_contiguous_rdim.append(name)

        # Guard against write-after-read corruption in triton.
        # See https://github.com/triton-lang/triton/issues/1615
        # This triton bug means that a load which is broadcasted over multiple
@@ -2924,9 +2840,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):

        exit_stack.close()

    def device_assert_async(self, cond, msg) -> None:
        self.compute.writeline(f"tl.device_assert({cond}, {repr(msg)})")

    def guard_cooperative_store(self, name, buffer):
        """
        For cooperative reductions only one thread block should write out the result.
@@ -2984,7 +2897,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
                "Bucketize only supports indexing with int32 and int64"
            )

        self._handle_pdl_before_load(self.compute)
        result = self.cse.generate(
            self.compute,
            f"triton_helpers.bucketize_binary_search({values}, "
@@ -2998,7 +2910,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
            dtype=indexing_dtype,  # type: ignore[attr-defined]
            shape=values.shape,
        )
        self._handle_pdl_after_load(self.compute, result)

        masks = self._combine_masks(values, boundary_indices, sorter_indices)
        result.mask_vars = masks  # type: ignore[attr-defined]
@@ -3175,34 +3086,14 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):

        if self.persistent_reduction:
            default = ir.Reduction.default_value(reduction_type, src_dtype)

            def update_constant_dtype(constant, src_dtype, dst_dtype):
                "update reduction constant mask value to match dst_dtype"

                # int is the only mask which may not fit within lower bitwidth,
                # because float uses inf/-inf
                if src_dtype.is_floating_point or src_dtype == torch.bool:
                    return constant

                if src_dtype == dst_dtype or constant == 0:
                    return constant

                if constant == torch.iinfo(src_dtype).max:
                    return torch.iinfo(dst_dtype).max
                elif constant == torch.iinfo(src_dtype).min:
                    return torch.iinfo(dst_dtype).min
                else:
                    return constant
            default = self._map_tuple_or_scalar(constant_repr, default)

            def _mask_value(value, default) -> CSEVariable:
                default = update_constant_dtype(default, src_dtype, value.dtype)
                default_str = self._map_tuple_or_scalar(constant_repr, default)

                return self.cse.generate(
                    self.compute,
                    where_cond(value, default_str),
                    where_cond(value, default),
                    dtype=value.dtype,
                    shape=value.shape,
                    shape=value.shape if value.shape is not None else default.shape,
                )

            masked_value: Union[CSEVariable, Sequence[CSEVariable]]
@@ -3211,14 +3102,13 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
                # will fallback below
                pass
            elif isinstance(value, tuple):
                masked_value = [_mask_value(v, d) for v, d in zip(value, default)]  # type: ignore[arg-type]
                masked_value = [_mask_value(v, d) for v, d in zip(value, default)]
            else:
                masked_value = _mask_value(value, default)

            if reduction_type in ("argmax", "argmin"):
                assert isinstance(masked_value, CSEVariable)
                accumulator_dtype = V.kernel.get_index_dtype_as_torch_dtype()

                accumulator_index = str(
                    self.cse.generate(
                        self.compute,
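A worked example of why `update_constant_dtype` exists: an integer reduction default such as `iinfo(int64).max` does not fit in a lower-bitwidth accumulator, so it is remapped to the destination dtype's extremum (floats keep `inf`/`-inf` and need no remap):

```python
import torch

src, dst = torch.int64, torch.int32
constant = torch.iinfo(src).max  # 9223372036854775807, overflows int32
remapped = torch.iinfo(dst).max  # 2147483647
assert remapped == 2**31 - 1
```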
@@ -3993,7 +3883,268 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):

        code.splice(self.prologue)
        self.prologue.clear()
        self.prologue_cache.clear()

    def _should_use_pointwise_inner_loop(self):
        """Check if we should use inner loop optimization for this pointwise kernel."""

        # Safety check: never apply to reduction kernels
        if self.inside_reduction:
            return False

        # Only apply to non-template kernels
        if self.fixed_config or self.cooperative_reduction or self.persistent_reduction:
            return False

        # Check for operations incompatible with inner loop
        code_str = ""
        if hasattr(self, 'indexing_code') and self.indexing_code:
            code_str += str(self.indexing_code)
        if hasattr(self, 'loads') and self.loads:
            code_str += str(self.loads)
        if hasattr(self, 'compute') and self.compute:
            code_str += str(self.compute)
        if hasattr(self, 'stores') and self.stores:
            code_str += str(self.stores)

        # Skip if using block pointers - they have complex shape calculations
        if "tl.make_block_ptr" in code_str:
            return False

        # Skip if using tensor descriptors (TMA)
        if "tl.make_tensor_descriptor" in code_str:
            return False

        # Separate trees by type
        valid_prefixes = {'x', 'y', 'z'}
        valid_trees = []
        reduction_trees = []

        for tree in self.range_trees:
            if tree.prefix in valid_prefixes:
                valid_trees.append(tree)
            elif tree.prefix.startswith('r') and (len(tree.prefix) == 1 or tree.prefix[1:].replace('_', '').isdigit()):
                reduction_trees.append(tree)

        # Check if all reduction trees are trivial (numel=1)
        for tree in reduction_trees:
            try:
                size_hint = V.graph.sizevars.size_hint(tree.numel, fallback=999999)
                if size_hint > 1:
                    return False
            except Exception:
                return False

        # Need at least one valid dimension
        if len(valid_trees) == 0 or len(valid_trees) > 2:
            return False

        # Check dimension sizes
        dimension_sizes = []
        for tree in valid_trees:
            try:
                size_hint = V.graph.sizevars.size_hint(
                    tree.numel,
                    fallback=config.unbacked_symint_fallback
                )
                dimension_sizes.append((tree.prefix, size_hint))
            except Exception:
                return False

        # if dimension_sizes:
        #     largest_size = max(size for _, size in dimension_sizes)
        #     # Need sufficient work to hide memory latency
        #     if largest_size < 512:
        #         return False
        # else:
        #     return False

        return True
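Distilled, the predicate above accepts kernels with one or two pointwise range trees and only trivial reduction trees, and bails if the generated code uses block pointers or TMA descriptors. A self-contained sketch of the tree checks (hypothetical tree objects, not Inductor's):

```python
from types import SimpleNamespace

def qualifies(trees) -> bool:
    pointwise = [t for t in trees if t.prefix in {"x", "y", "z"}]
    reductions = [t for t in trees if t.prefix.startswith("r")]
    return 1 <= len(pointwise) <= 2 and all(t.numel == 1 for t in reductions)

trees = [SimpleNamespace(prefix="x", numel=1 << 20),
         SimpleNamespace(prefix="r0_", numel=1)]
assert qualifies(trees)
```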
    def _get_split_dimension_info(self):
        """Get information about which dimension to split for inner loop."""
        # Only consider x, y, z dimensions (skip reduction dimensions)
        valid_prefixes = {'x', 'y', 'z'}
        non_reduction_trees = [
            t for t in self.range_trees
            if t.prefix in valid_prefixes
        ]

        if not non_reduction_trees:
            return None, []

        # Find dimension with largest numel
        split_tree = max(
            non_reduction_trees,
            key=lambda tree: V.graph.sizevars.size_hint(
                tree.numel, fallback=config.unbacked_symint_fallback
            )
        )

        return split_tree, non_reduction_trees

    def _get_largest_dimension_for_inner_loop(self):
        """Determine which dimension to split based on largest numel."""
        if len(self.range_trees) == 1:
            return self.range_trees[0]

        # Find dimension with largest numel
        largest_tree = max(
            self.range_trees,
            key=lambda tree: V.graph.sizevars.size_hint(
                tree.numel, fallback=config.unbacked_symint_fallback
            )
        )
        return largest_tree
    def _codegen_pointwise_with_inner_loop(self):
        """Generate pointwise kernel body with inner R0_BLOCK loop."""

        # Save the current generated code
        saved_indexing = self.indexing_code.getvalue()
        saved_loads = self.loads.getvalue()
        saved_compute = self.compute.getvalue()
        saved_stores = self.stores.getvalue()

        # Clear buffers for our custom generation
        self.indexing_code.clear()
        self.body.clear()
        self.loads.clear()
        self.compute.clear()
        self.stores.clear()

        # Determine which dimension to split
        non_reduction_trees = [t for t in self.range_trees if t.prefix in {'x', 'y', 'z'}]
        if not non_reduction_trees:
            # Restore and fallback
            self.indexing_code.writelines(saved_indexing.split('\n'))
            self.loads.writelines(saved_loads.split('\n'))
            self.compute.writelines(saved_compute.split('\n'))
            self.stores.writelines(saved_stores.split('\n'))
            return

        split_tree = max(
            non_reduction_trees,
            key=lambda tree: V.graph.sizevars.size_hint(
                tree.numel, fallback=config.unbacked_symint_fallback
            )
        )
        split_prefix = split_tree.prefix

        # Add static assertion
        self.body.writeline(f"tl.static_assert({split_prefix.upper()}BLOCK % R0_BLOCK == 0)")

        # Map prefix to program_id
        prefix_to_pid = {
            'x': "tl.program_id(0)",
            'y': "(tl.program_id(1) + tl.program_id(2) * tl.num_programs(1))",
            'z': "tl.program_id(2)"
        }

        # Build prefix list for broadcast calculations
        all_prefixes = [t.prefix for t in non_reduction_trees]

        # Generate offset+index+mask for NON-SPLIT dimensions
        for tree in non_reduction_trees:
            if tree.prefix != split_prefix:
                prefix = tree.prefix
                pid = prefix_to_pid[prefix]

                idx = all_prefixes.index(prefix)
                if len(all_prefixes) == 1:
                    broadcast = ""
                elif len(all_prefixes) == 2:
                    broadcast = "[:, None]" if idx == 0 else "[None, :]"
                else:
                    parts = ["None"] * len(all_prefixes)
                    parts[idx] = ":"
                    broadcast = f"[{', '.join(parts)}]"

                self.body.writeline(f"{prefix}offset = {pid} * {prefix.upper()}BLOCK")
                self.body.writeline(
                    f"{prefix}index = {prefix}offset + tl.arange(0, {prefix.upper()}BLOCK){broadcast}"
                )
                self.body.writeline(f"{prefix}mask = {prefix}index < {prefix}numel")

        # For split dimension, generate offset only
        pid = prefix_to_pid[split_prefix]
        self.body.writeline(f"{split_prefix}offset = {pid} * {split_prefix.upper()}BLOCK")
        self.body.writeline(f"tile_start = {split_prefix}offset")

        # Save original name
        original_name = split_tree.name

        # Generate pipelined inner loop
        self.body.writeline(f"for r in tl.range(0, {split_prefix.upper()}BLOCK, R0_BLOCK, num_stages=2):")

        with self.body.indent():
            self.body.writeline("lanes = tl.arange(0, R0_BLOCK)")

            # Calculate broadcast for split dimension
            split_idx = all_prefixes.index(split_prefix)
            if len(all_prefixes) == 1:
                broadcast = ""
            elif len(all_prefixes) == 2:
                broadcast = "[:, None]" if split_idx == 0 else "[None, :]"
            else:
                parts = ["None"] * len(all_prefixes)
                parts[split_idx] = ":"
                broadcast = f"[{', '.join(parts)}]"

            self.body.writeline(
                f"{split_prefix}index = (tile_start + r + lanes){broadcast}"
            )
            self.body.writeline(f"{split_prefix}mask = {split_prefix}index < {split_prefix}numel")

            # Override the range tree name temporarily
            split_tree.name = f"{split_prefix}index"

            # Generate derived indices
            for node_name, entry in self.range_tree_nodes.items():
                if hasattr(entry, 'expr') and entry.name != f"{split_prefix}index":
                    line = f"{entry.name} = {self.kexpr(self.rename_indexing(entry.expr))}"
                    self.body.writeline(line)

            # Fix shapes in the saved code
            block_name = f"{split_prefix.upper()}BLOCK"

            # Replace in all buffers
            import re

            def fix_shapes(code):
                if not code:
                    return code

                # Replace [XBLOCK] with [R0_BLOCK]
                code = code.replace(f"[{block_name}]", "[R0_BLOCK]")

                # Replace tl.full([XBLOCK], ...) with tl.full([R0_BLOCK], ...)
                code = re.sub(
                    rf'tl\.full\(\[{block_name}\]',
                    'tl.full([R0_BLOCK]',
                    code
                )

                # Replace broadcast_to(expr, [XBLOCK]) with broadcast_to(expr, [R0_BLOCK])
                code = re.sub(
                    rf'broadcast_to\((.*?),\s*\[{block_name}\]\)',
                    r'broadcast_to(\1, [R0_BLOCK])',
                    code
                )

                return code

            # Apply fixes and write the code
            if saved_indexing:
                self.body.splice(fix_shapes(saved_indexing))
            if saved_loads:
                self.body.splice(fix_shapes(saved_loads))
            if saved_compute:
                self.body.splice(fix_shapes(saved_compute))
            if saved_stores:
                self.body.splice(fix_shapes(saved_stores))

        # Restore original name
        split_tree.name = original_name
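For a 1D kernel, the writelines above produce roughly the following Triton source; the sketch below is an illustrative reconstruction (a hand-written add kernel), not actual Inductor output:

```python
import triton
import triton.language as tl

@triton.jit
def pointwise_inner_loop_sketch(in_ptr, out_ptr, xnumel,
                                XBLOCK: tl.constexpr, R0_BLOCK: tl.constexpr):
    tl.static_assert(XBLOCK % R0_BLOCK == 0)
    xoffset = tl.program_id(0) * XBLOCK
    tile_start = xoffset
    # Pipelined inner loop over R0_BLOCK-sized chunks of the XBLOCK tile.
    for r in tl.range(0, XBLOCK, R0_BLOCK, num_stages=2):
        lanes = tl.arange(0, R0_BLOCK)
        xindex = tile_start + r + lanes
        xmask = xindex < xnumel
        val = tl.load(in_ptr + xindex, mask=xmask)
        tl.store(out_ptr + xindex, val + 1.0, mask=xmask)
```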
    def codegen_body(self):
        """
@@ -4015,6 +4166,15 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
        ):
            return

        if self._should_use_pointwise_inner_loop():
            self._codegen_pointwise_with_inner_loop()
            # Clear the buffers since we've used them
            self.indexing_code.clear()
            self.loads.clear()
            self.compute.clear()
            self.stores.clear()
            return

        loop_trees = [tree for tree in self.range_trees if tree.is_loop]
        if self.inside_reduction and len(loop_trees) > 0:
            # Write the loop headers.
@@ -4105,16 +4265,12 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
                    args.append(str(arg))
                elif isinstance(arg, SymbolicCallArg):
                    hint = V.graph.sizevars.size_hint(
                        arg.inner_expr,
                        hint_override=self.hint_override,
                        fallback=config.unbacked_symint_fallback,
                        arg.inner_expr, fallback=config.unbacked_symint_fallback
                    )
                    args.append(str(hint))
                elif isinstance(arg, sympy.Expr):
                    hint = V.graph.sizevars.size_hint(
                        arg,
                        hint_override=self.hint_override,
                        fallback=config.unbacked_symint_fallback,
                        arg, fallback=config.unbacked_symint_fallback
                    )
                    args.append(str(hint))
                else:
@@ -4173,11 +4329,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
                        f"{var_name} = rand_strided({size}, {stride}, device='{const_tensor.device}', dtype={const_tensor.dtype})"  # type: ignore[arg-type]  # noqa: B950 line too long
                    )
                elif isinstance(arg_sig, SizeArg):
                    symval_hint = V.graph.sizevars.size_hint(
                        arg_sig.expr,
                        hint_override=self.hint_override,
                        fallback=config.unbacked_symint_fallback,
                    )
                    symval_hint = V.graph.sizevars.size_hint(arg_sig.expr)

                    # Force the seed_offset to be 0 so calls to the same kernel
                    # using different seed offset will have the same benchmark harness.
@@ -4187,9 +4339,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
                    result.writeline(f"{var_name} = {symval_hint}")
                elif isinstance(arg_sig, WorkspaceArg):
                    device = V.graph.get_current_device_or_throw()
                    count = V.graph.sizevars.size_hint(
                        arg_sig.count, hint_override=self.hint_override
                    )
                    count = V.graph.sizevars.size_hint(arg_sig.count)
                    result.writeline(
                        f"{var_name} = torch.zeros({count}, device='{device}', dtype={arg_sig.dtype})"
                    )
@@ -4284,7 +4434,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
            "min_split_scan_rblock": config.triton.min_split_scan_rblock,
            "spill_threshold": config.triton.spill_threshold,
            "store_cubin": config.triton.store_cubin,
            "deterministic": config.deterministic,
        }
        if torch.version.hip is not None:
            inductor_meta["is_hip"] = True
@@ -4315,6 +4464,8 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
        metadata, and benchmarking infra.
        """

        from ..utils import prefix_is_reduction

        code = IndentedBuffer()

        size_hints = {}
@@ -4426,6 +4577,10 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):

            add_constexpr_arg(f"{tree.prefix.upper()}BLOCK")

        # Inner-loop pointwise kernels take an extra R0_BLOCK constexpr arg.
        if self._should_use_pointwise_inner_loop():
            add_constexpr_arg("R0_BLOCK")

        if self.cooperative_reduction:
            add_constexpr_arg("RSPLIT")

@@ -4457,12 +4612,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
            **self.inductor_meta_common(),
        }

        if config.deterministic or config.test_configs.force_filter_reduction_configs:
            inductor_meta["has_loadstore_with_contiguous_rdim"] = (
                self.has_load_with_contiguous_rdim
                or self.has_store_with_contiguous_rdim
            )

        # Bail on 3d tiling, which has more complicated coalesce patterns
        looped_red = V.kernel.features.is_reduction() and not self.persistent_reduction
        tiling_scores = self.tiling_scores
@@ -4538,7 +4687,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):

        triton_meta["configs"] = [config_of(signature)]

        if enable_pdl_codegen():
        if config.triton.enable_pdl:
            triton_meta["launch_pdl"] = True

        # Triton compiler includes equal_to_1 args into constants even
@@ -4990,7 +5139,7 @@ class TritonScheduling(SIMDScheduling):
            )
        return cls.backend_features

    def codegen_comment(self, node_schedule, kernel_name=None):
    def codegen_comment(self, node_schedule):
        wrapper = V.graph.wrapper_code
        origins, _detailed_origins = get_kernel_metadata(node_schedule, wrapper)
        if origins:
@@ -5016,13 +5165,6 @@ class TritonScheduling(SIMDScheduling):
                    f"{wrapper.comment} Fused node name list: {', '.join(node_names)}"
                )

        if kernel_name:
            debug_handle = set_kernel_post_grad_provenance_tracing(
                node_schedule,  # type: ignore[arg-type]
                kernel_name,
            )
            wrapper.write_provenance_debug_handle(kernel_name, debug_handle)

    def define_kernel(self, src_code, node_schedule, kernel):
        wrapper = V.graph.wrapper_code
        if src_code in wrapper.src_to_kernel:

@@ -17,7 +17,7 @@ import re
import sys
import threading
import time
from collections import defaultdict, namedtuple
from collections import namedtuple
from typing import (
    Any,
    Callable,
@@ -30,9 +30,8 @@ from typing import (
)

import torch
from torch._dynamo.utils import counters, set_feature_use
from torch._dynamo.utils import set_feature_use
from torch._environment import is_fbcode
from torch._inductor import metrics
from torch._prims_common import compute_required_storage_length
from torch.utils._ordered_set import OrderedSet

@@ -237,7 +236,7 @@ def check_autotune_cache(
        not disabled
        and filename is not None
        and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning"))
        and os.environ.get("TRITON_INTERPRET", "0") != "1"
        and not os.environ.get("TRITON_INTERPRET", "0") == "1"
    ):
        configs_hash = hash_configs(configs)

@@ -312,8 +311,6 @@ class CachingAutotuner(KernelInterface):
            "device_type": self.device_props.type,
        }
        self.inductor_meta = {} if inductor_meta is None else inductor_meta
        self.deterministic_mode = self.inductor_meta.get("deterministic", False)

        self.save_cache_hook = save_cache_hook
        self.mutated_arg_names = mutated_arg_names
        self.reset_to_zero_arg_names = (
@@ -378,7 +375,7 @@ class CachingAutotuner(KernelInterface):
        self.is_backward = False

        # Mode for launch grid calculation
        self.grid_mode: Literal["python", "cpp"] = "python"
        self.grid_mode: Literal["python", "python_slow", "cpp"] = "python"

    def is_statically_launchable(self):
        """
@@ -485,8 +482,7 @@ class CachingAutotuner(KernelInterface):
        # Currently it relies on _make_launchers(), which requires a cuda context, to populate nreg.
        device_prop = self.device_props
        if (
            not self.deterministic_mode
            and self.inductor_meta.get("dynamic_scale_rblock", True)
            self.inductor_meta.get("dynamic_scale_rblock", True)
            and not self.inductor_meta.get("persistent_reduction")
            and self.heuristic_type == HeuristicType.REDUCTION
            and self.size_hints is not None
@@ -690,17 +686,9 @@ class CachingAutotuner(KernelInterface):

        return get_interface_for_device(self.device_props.type.replace("hip", "cuda"))

    def _create_compile_meta(self, cfg: Config) -> dict[str, Any]:
        """
        Create compilation metadata for a given autotuner config. This involves
        processing the Config kwargs so that the kwargs that are not part
        of the triton signature are passed in as options to triton.compile
        instead.
        """
    def _precompile_config(self, cfg: Config) -> CompileResult[_KernelType]:
        """Ahead of time compile a given autotuner config."""
        compile_meta = copy.deepcopy(self.triton_meta)
        compile_meta["num_warps"] = cfg.num_warps
        compile_meta["num_stages"] = cfg.num_stages

        cfg_kwargs = cfg.kwargs
        if self.device_props.type == "hip":
            cfg_kwargs = {**cfg_kwargs}
@@ -708,13 +696,14 @@ class CachingAutotuner(KernelInterface):
                if k in cfg_kwargs:
                    compile_meta[k] = cfg_kwargs.pop(k)
        compile_meta["constants"].update(cfg_kwargs)

        for i in self.fn.constexprs:
            arg_name = self.fn.arg_names[i]
            if arg_name not in compile_meta["constants"] and (
                arg_name == "num_warps" or arg_name == "num_stages"
            ):
                compile_meta["constants"][arg_name] = getattr(cfg, arg_name)
        compile_meta["num_warps"] = cfg.num_warps
        compile_meta["num_stages"] = cfg.num_stages
        if HAS_WARP_SPEC:
            compile_meta["num_consumer_groups"] = getattr(cfg, "num_consumer_groups", 0)
            compile_meta["num_buffers_warp_spec"] = getattr(
@ -728,53 +717,6 @@ class CachingAutotuner(KernelInterface):
 | 
			
		||||
        compile_meta["device_type"] = self.device_props.type
 | 
			
		||||
        compile_meta["cc"] = self.device_props.cc
 | 
			
		||||
 | 
			
		||||
        return compile_meta
 | 
			
		||||

    def _create_compile_options(
        self, cfg: Config, compile_meta: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Create options to pass to triton.compile based on the compile metadata
        and the given config.
        """
        options = {
            "num_warps": compile_meta["num_warps"],
            "num_stages": compile_meta["num_stages"],
            "debug": compile_meta["debug"],
            "sanitize_overflow": False,  # turn off additional asserts added for overflow checks
        }
        if "enable_fp_fusion" in compile_meta:
            options["enable_fp_fusion"] = compile_meta["enable_fp_fusion"]
        if HAS_WARP_SPEC:
            options.update(
                {
                    "num_consumer_groups": compile_meta.get("num_consumer_groups", 0),
                    "num_buffers_warp_spec": compile_meta.get(
                        "num_buffers_warp_spec", 0
                    ),
                }
            )
        if self.device_props.type == "cuda":
            options.update(
                {
                    "launch_cooperative_grid": compile_meta.get(
                        "launch_cooperative_grid", False
                    ),
                    "launch_pdl": compile_meta.get("launch_pdl", False),
                }
            )
        if self.device_props.type == "hip":
            if "waves_per_eu" in compile_meta:
                options["waves_per_eu"] = compile_meta["waves_per_eu"]
            if "matrix_instr_nonkdim" in compile_meta:
                options["matrix_instr_nonkdim"] = compile_meta["matrix_instr_nonkdim"]

        return options
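
The backend-specific branches above only copy keys that are actually present in compile_meta. A hedged, self-contained sketch of that merging for a hypothetical HIP config:

# Hedged sketch (not part of the diff): base options plus backend-specific
# keys copied over only when compile_meta provides them.
compile_meta = {"num_warps": 8, "num_stages": 2, "debug": False, "waves_per_eu": 2}
options = {
    "num_warps": compile_meta["num_warps"],
    "num_stages": compile_meta["num_stages"],
    "debug": compile_meta["debug"],
    "sanitize_overflow": False,
}
for key in ("waves_per_eu", "matrix_instr_nonkdim"):
    if key in compile_meta:
        options[key] = compile_meta[key]
assert options["waves_per_eu"] == 2 and "matrix_instr_nonkdim" not in options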

    def _precompile_config(self, cfg: Config) -> CompileResult[_KernelType]:
        """Ahead of time compile a given autotuner config."""
        compile_meta = self._create_compile_meta(cfg)

        if self.device_props.type == "cpu":
            triton_helpers.set_driver_to_cpu()
        else:
@ -807,8 +749,37 @@ class CachingAutotuner(KernelInterface):
            cc_warp_size(compile_meta["cc"]),
        )

        options = self._create_compile_options(cfg, compile_meta)

        options = {
            "num_warps": compile_meta["num_warps"],
            "num_stages": compile_meta["num_stages"],
            "debug": compile_meta["debug"],
            "sanitize_overflow": False,  # turn off additional asserts added for overflow checks
        }
        if "enable_fp_fusion" in compile_meta:
            options["enable_fp_fusion"] = compile_meta["enable_fp_fusion"]
        if HAS_WARP_SPEC:
            options.update(
                {
                    "num_consumer_groups": compile_meta.get("num_consumer_groups", 0),
                    "num_buffers_warp_spec": compile_meta.get(
                        "num_buffers_warp_spec", 0
                    ),
                }
            )
        if self.device_props.type == "cuda":
            options.update(
                {
                    "launch_cooperative_grid": compile_meta.get(
                        "launch_cooperative_grid", False
                    ),
                    "launch_pdl": compile_meta.get("launch_pdl", False),
                }
            )
        if self.device_props.type == "hip":
            if "waves_per_eu" in compile_meta:
                options["waves_per_eu"] = compile_meta["waves_per_eu"]
            if "matrix_instr_nonkdim" in compile_meta:
                options["matrix_instr_nonkdim"] = compile_meta["matrix_instr_nonkdim"]
        compile_kwargs = {
            "target": target,
            "options": options,
@ -898,34 +869,25 @@ class CachingAutotuner(KernelInterface):
            )
            # reset to zero before evaluating any config
            self.reset_to_zero_args(*args, **kwargs)
            kernel_name = self.inductor_meta.get("kernel_name", "triton kernel")
            if autograd_profiler._is_profiler_enabled:
                profiler_kwargs = self.get_profiler_kwargs(stream, launcher)
                with torch._C._profiler._RecordFunctionFast(
                    kernel_name,
                    self.inductor_meta.get("kernel_name", "triton kernel"),
                    cloned_args,
                    profiler_kwargs,
                ):
                    try:
                        launcher(
                            *cloned_args,
                            **cloned_kwargs,
                            stream=stream,
                        )
                    except Exception:
                        log.error("Failed during launch %s: ", kernel_name)
                        raise

            else:
                try:
                    launcher(
                        *cloned_args,
                        **cloned_kwargs,
                        stream=stream,
                    )
                except Exception:
                    log.error("Failed during launch %s: ", kernel_name)
                    raise

            else:
                launcher(
                    *cloned_args,
                    **cloned_kwargs,
                    stream=stream,
                )
            self.restore_args_from_cpu(cpu_copies)

        # only use profiler when not already in a profiler instance
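
The wrapper above uses the private `torch._C._profiler._RecordFunctionFast` helper; a hedged sketch of the same idea with the public profiler API, marking a launch so it shows up in traces (the op name is illustrative):

# Hedged sketch (not part of the diff), using the public analogue of
# torch._C._profiler._RecordFunctionFast.
import torch

with torch.profiler.profile() as prof:
    with torch.profiler.record_function("triton_poi_fused_example"):
        torch.ones(4).sum()
# the labeled region appears as an event in the collected trace
assert any(e.name == "triton_poi_fused_example" for e in prof.events())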
@ -1092,18 +1054,6 @@ class CachingAutotuner(KernelInterface):
                        k.shared,
                    )

            if metrics.is_metric_table_enabled("kernel_autotune"):
                if self.fn.fn is None:
                    self.fn = self._reload_kernel().fn

                kernel_path = self.fn.fn.__code__.co_filename
                kernel_name = self.fn.__name__

                for k, v in timings.items():
                    metrics.log_kernel_autotune_result(
                        kernel_path, kernel_name, k.config, v
                    )

            self.reset_to_zero_args(*args, **kwargs)
            return timings

@ -1198,26 +1148,13 @@ class CachingAutotuner(KernelInterface):
        Then if coordinate descent tuning is run with max-autotune disabled, it will start from C1;
        while if coordinate descent tuning is run with max-autotune enabled, it will start from C3.
        """
        if self.heuristic_type in (
            HeuristicType.TEMPLATE,
            HeuristicType.USER_AUTOTUNE,
            HeuristicType.FIXED,
        if (
            self.heuristic_type == HeuristicType.TEMPLATE
            or self.heuristic_type == HeuristicType.USER_AUTOTUNE
        ):
            # skip triton template
            return launcher

        if self.deterministic_mode and self.heuristic_type in (
            HeuristicType.REDUCTION,
            HeuristicType.PERSISTENT_REDUCTION,
            HeuristicType.SPLIT_SCAN,
        ):
            # RBLOCK size is not the only thing that affects the numerics of a
            # reduction: num_warps also matters, since it affects how much data
            # each thread handles, how many warp reductions run in parallel,
            # and how much data is left for the block reduction.
            return launcher
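
The comment above boils down to floating-point addition not being associative: any config change that reshapes the accumulation tree can change the result bit-for-bit. A one-line demonstration:

# Hedged illustration (not part of the diff) of non-associative float addition.
assert (0.1 + 0.2) + 0.3 != 0.1 + (0.2 + 0.3)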

        with dynamo_timed(
            "CachingAutotuner.coordinate_descent_tuning",
            # These generate too many pt2_compile_event logs:
@ -1251,7 +1188,6 @@ class CachingAutotuner(KernelInterface):
            config2launcher[config] = launcher

            out = self.bench(launcher, *args, **kwargs)
            counters["inductor"]["coordesc_tuning_bench"] += 1
            log.debug(
                "COORDESC: %s: %f, nreg %d, nspill %d, #shared-mem %d",
                launcher.config,
@ -2574,7 +2510,27 @@ def pointwise(

    configs = None
    if len(size_hints) == 1:
        if disable_pointwise_autotuning(inductor_meta) and not (

        use_looped_pointwise = True
        has_r0_block = any(
            arg == "R0_BLOCK"
            for arg in triton_meta.get("signature", {}).keys()
        )
        if has_r0_block:
            configs = [
                Config({"XBLOCK": 512, "R0_BLOCK": 256}, num_warps=4, num_stages=1),
                Config({"XBLOCK": 1024, "R0_BLOCK": 512}, num_warps=4, num_stages=1),
                Config({"XBLOCK": 2048, "R0_BLOCK": 1024}, num_warps=8, num_stages=1),
                Config({"XBLOCK": 4096, "R0_BLOCK": 2048}, num_warps=8, num_stages=1),
            ]
            # configs = [
            #     triton_config_with_settings(size_hints, 512),
            #     triton_config_with_settings(size_hints, 1024),
            #     triton_config_with_settings(size_hints, 2048),
            #     triton_config_with_settings(size_hints, 4096),
            # ]

        elif disable_pointwise_autotuning(inductor_meta) and not (
            inductor_meta.get("max_autotune")
            or inductor_meta.get("max_autotune_pointwise")
        ):
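
The signature probe above gates the looped-pointwise configs on the kernel actually declaring an R0_BLOCK constexpr. A hedged, self-contained sketch with a hypothetical signature dict:

# Hedged sketch (not part of the diff); the signature keys are illustrative.
triton_meta = {
    "signature": {"in_ptr0": "*fp32", "out_ptr0": "*fp32", "xnumel": "i32",
                  "XBLOCK": "constexpr", "R0_BLOCK": "constexpr"},
}
has_r0_block = any(arg == "R0_BLOCK" for arg in triton_meta.get("signature", {}))
assert has_r0_block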
@ -2588,7 +2544,36 @@ def pointwise(
                *hinted_configs,
            ]
    if len(size_hints) == 2:
        if (
        use_looped_pointwise = True
        has_r0_block = any(
            arg == "R0_BLOCK"
            for arg in triton_meta.get("signature", {}).keys()
        )
        if has_r0_block:
            # Find which dimension is larger
            dim_sizes = [(name, size) for name, size in size_hints.items()]
            larger_dim, larger_size = max(dim_sizes, key=lambda x: x[1])
            smaller_dim, smaller_size = min(dim_sizes, key=lambda x: x[1])

            # Split the larger dimension with R0_BLOCK
            larger_block = larger_dim.upper() + "BLOCK"
            smaller_block = smaller_dim.upper() + "BLOCK"

            configs = [
                Config({larger_block: 512, "R0_BLOCK": 256, smaller_block: 32}, num_warps=4, num_stages=1),
                Config({larger_block: 1024, "R0_BLOCK": 512, smaller_block: 32}, num_warps=4, num_stages=1),
                Config({larger_block: 512, "R0_BLOCK": 256, smaller_block: 64}, num_warps=4, num_stages=1),
                Config({larger_block: 1024, "R0_BLOCK": 512, smaller_block: 64}, num_warps=8, num_stages=1),
            ]
            # configs = [
            #     triton_config_with_settings(size_hints, 32, 32),
            #     triton_config_with_settings(size_hints, 64, 64),
            #     triton_config_with_settings(size_hints, 256, 16),
            #     triton_config_with_settings(size_hints, 16, 256),
            # ]

        elif (
            disable_pointwise_autotuning(inductor_meta) or tile_hint == TileHint.SQUARE
        ) and not (
            inductor_meta.get("max_autotune")
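
In the 2D branch above, the larger hinted dimension takes the big block plus the R0_BLOCK inner loop, while the smaller one keeps a modest block. A hedged sketch with hypothetical size hints:

# Hedged sketch (not part of the diff); size_hints values are illustrative.
size_hints = {"x": 8192, "y": 64}
dim_sizes = list(size_hints.items())
larger_dim, _ = max(dim_sizes, key=lambda t: t[1])
smaller_dim, _ = min(dim_sizes, key=lambda t: t[1])
assert larger_dim.upper() + "BLOCK" == "XBLOCK"
assert smaller_dim.upper() + "BLOCK" == "YBLOCK"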
@ -2878,144 +2863,6 @@ def adapt_config_for_tiling(
    )


class ReductionConfigKey:
    """
    The part of a reduction config that affects determinism.
    """

    def __init__(self, config: Config):
        # persistent reduction does not have a RBLOCK, use -1 as a flag
        self.r0_block = config.kwargs.get("R0_BLOCK", -1)
        self.r1_block = config.kwargs.get("R1_BLOCK", -1)
        self.num_warps = config.num_warps
        self.num_ctas = config.num_ctas

    def __hash__(self) -> int:
        return hash((self.r0_block, self.r1_block, self.num_warps, self.num_ctas))

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, ReductionConfigKey)
            and self.r0_block == other.r0_block
            and self.r1_block == other.r1_block
            and self.num_warps == other.num_warps
            and self.num_ctas == other.num_ctas
        )
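
Since the key hashes only (R0_BLOCK, R1_BLOCK, num_warps, num_ctas), configs that differ in other fields collapse into one group. A hedged illustration, assuming triton.Config semantics (num_ctas defaults to 1):

# Hedged sketch (not part of the diff): XBLOCK does not affect reduction
# numerics, so these two configs share a ReductionConfigKey.
a = Config({"XBLOCK": 1, "R0_BLOCK": 64}, num_warps=4, num_stages=1)
b = Config({"XBLOCK": 8, "R0_BLOCK": 64}, num_warps=4, num_stages=1)
assert ReductionConfigKey(a) == ReductionConfigKey(b)
assert hash(ReductionConfigKey(a)) == hash(ReductionConfigKey(b))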


def filter_reduction_configs_for_determinism(
    inductor_meta: dict[str, Any], configs: list[Config]
) -> list[Config]:
    """
    Filter reduction configs so the numerics can be deterministic.

    This function groups configs by the fields that affect determinism
    - rblock size
    - num warps
    - num ctas
    and returns the most promising group based on heuristics.

    Heuristics:
    - skip reduction configs with a too-small RBLOCK
    - skip reduction configs with XBLOCK==1 if we are confident they will not perform well
    - pick the largest group: autotuning more configs has a better chance of finding good perf
    - if there is a tie, pick the group with the second largest RBLOCK
    - if there is still a tie, pick the group with the second largest num_warps
    """
    configs = unique_configs(configs)
    assert len(configs) > 0

    def _do_filter_due_to_inductor_config():
        return (
            inductor_meta.get("deterministic", False)
            or torch._inductor.config.test_configs.force_filter_reduction_configs
        )

    if not _do_filter_due_to_inductor_config() or len(configs) == 1:
        # no filtering happens if NOT in deterministic mode
        return configs

    if log.isEnabledFor(logging.DEBUG):
        log.debug("reduction configs before filtering:")
        for c in configs:
            log.debug("%s", c)
            log.debug("")

    def _has_too_small_rblock(config):
        rblock = config.kwargs.get("R0_BLOCK")
        # a too-small RBLOCK is likely to be bad
        return rblock is not None and rblock <= 4

    def _nonpromising_xblock_1(config):
        # a kernel like https://gist.github.com/shunting314/0b3281c087e79bc915fe45985ff9d7d5
        # without a load/store on a contiguous rdim is unlikely to perform well with XBLOCK==1
        return config.kwargs["XBLOCK"] == 1 and not inductor_meta.get(
            "has_loadstore_with_contiguous_rdim", True
        )

    newconfigs = [*filter(lambda x: not _has_too_small_rblock(x), configs)]
    # accept the filtering only if there are configs left
    if len(newconfigs) > 0:
        configs = newconfigs

    newconfigs = [*filter(lambda x: not _nonpromising_xblock_1(x), configs)]
    if len(newconfigs) > 0:
        configs = newconfigs

    groups: defaultdict[ReductionConfigKey, list[Config]] = defaultdict(
        list
    )  # group configs by RBLOCK, num_warps, num_ctas

    for c in configs:
        key = ReductionConfigKey(c)
        groups[key].append(c)

    assert len(groups) > 0

    def _pick_group():
        grouplist = sorted(groups.items(), key=lambda x: len(x[1]), reverse=True)
        max_group_size = len(grouplist[0][1])
        grouplist = [*filter(lambda g: len(g[1]) == max_group_size, grouplist)]

        assert len(grouplist) > 0
        if len(grouplist) == 1:
            return grouplist[0][1]

        # break the tie by R0_BLOCK
        grouplist = sorted(grouplist, key=lambda x: x[0].r0_block)
        if grouplist[0][0].r0_block != grouplist[-1][0].r0_block:
            max_r0_block = grouplist[-1][0].r0_block
            grouplist = [*filter(lambda x: x[0].r0_block != max_r0_block, grouplist)]
            second_max_r0_block = grouplist[-1][0].r0_block
            grouplist = [
                *filter(lambda x: x[0].r0_block == second_max_r0_block, grouplist)
            ]
        if len(grouplist) == 1:
            return grouplist[0][1]

        # break the tie by num_warps
        grouplist = sorted(grouplist, key=lambda x: x[0].num_warps)
        if grouplist[0][0].num_warps != grouplist[-1][0].num_warps:
            max_num_warps = grouplist[-1][0].num_warps
            grouplist = [*filter(lambda x: x[0].num_warps != max_num_warps, grouplist)]
            second_max_num_warps = grouplist[-1][0].num_warps
            grouplist = [
                *filter(lambda x: x[0].num_warps == second_max_num_warps, grouplist)
            ]

        # there is still a tie, pick the first one
        return grouplist[0][1]

    configs = _pick_group()

    if log.isEnabledFor(logging.DEBUG):
        log.debug("reduction configs after filtering:")
        for c in configs:
            log.debug("%s", c)
            log.debug("")
    return configs
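
A hedged usage sketch of the filter: with determinism requested, the surviving configs all share one (R0_BLOCK, num_warps, num_ctas) key, so autotuning can no longer pick numerics-changing variants run to run. The candidate configs are illustrative, and Config/ReductionConfigKey/filter_reduction_configs_for_determinism are assumed in scope as in this module:

# Hedged sketch (not part of the diff).
cands = [
    Config({"XBLOCK": 1, "R0_BLOCK": 128}, num_warps=4, num_stages=1),
    Config({"XBLOCK": 4, "R0_BLOCK": 128}, num_warps=4, num_stages=1),
    Config({"XBLOCK": 1, "R0_BLOCK": 2}, num_warps=8, num_stages=1),  # RBLOCK <= 4: dropped
]
picked = filter_reduction_configs_for_determinism({"deterministic": True}, cands)
assert all(ReductionConfigKey(c) == ReductionConfigKey(picked[0]) for c in picked)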


def reduction(
    size_hints,
    reduction_hint=False,
@ -3041,7 +2888,6 @@ def reduction(
    )

    configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
    configs = filter_reduction_configs_for_determinism(inductor_meta, configs)
    return cached_autotune(
        size_hints,
        configs=configs,
@ -3089,7 +2935,6 @@ def cooperative_reduction(
    # TODO(jansel): add more configs in max_autotune

    configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
    configs = filter_reduction_configs_for_determinism(inductor_meta, configs)
    return cached_autotune(
        size_hints,
        configs=configs,
@ -3141,12 +2986,15 @@ def _persistent_reduction_configs(
    if "y" in size_hints:
        pass
    # TODO(jansel): we should be able to improve these heuristics
    elif reduction_hint == ReductionHint.INNER and rnumel >= 256:
    elif reduction_hint == ReductionHint.INNER:
        if rnumel > 1024:
            configs = configs[:1]
        else:
            x_block = 8
            if xnumel // x_block < 128 or loads_and_stores >= 5:
            if xnumel // x_block < 128 or (loads_and_stores >= 5 and rnumel >= 256):
                # With 5 or more loads/stores there is a lot of register pressure.
                # rnumel < 256 means no vectorized loads if we split up the r dim,
                # so xblock still needs to be larger in that case.
                x_block = 1

            configs = [
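
A hedged distillation of the revised condition above, showing when the heuristic falls back to XBLOCK=1 (the function name and values are illustrative):

# Hedged sketch (not part of the diff).
def pick_x_block(xnumel, rnumel, loads_and_stores):
    x_block = 8
    if xnumel // x_block < 128 or (loads_and_stores >= 5 and rnumel >= 256):
        x_block = 1
    return x_block

assert pick_x_block(10_000, 512, 6) == 1  # heavy loads/stores and large rnumel
assert pick_x_block(10_000, 128, 6) == 8  # rnumel < 256: keep the larger XBLOCK
assert pick_x_block(512, 512, 2) == 1     # too few x elements per block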
@ -3202,7 +3050,6 @@ def persistent_reduction(
    configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
    inductor_meta.pop(persistent_reduction_key)

    configs = filter_reduction_configs_for_determinism(inductor_meta, configs)
    return cached_autotune(
        size_hints,
        configs,
@ -3240,7 +3087,6 @@ def split_scan(
                cfg.kwargs[var] = min_rblock

    configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
    configs = filter_reduction_configs_for_determinism(inductor_meta, configs)
    return cached_autotune(
        size_hints,
        configs=configs,
@ -3380,14 +3226,14 @@ class GridExpr:
    """Generate code for grid size expressions in launcher"""

    inductor_meta: dict[str, Any]
    mode: Literal["python", "cpp"] = "python"
    mode: Literal["python", "cpp", "python_slow"] = "python"
    prefix: list[str] = dataclasses.field(default_factory=list)
    x_grid: Union[str, int] = 1
    y_grid: Union[str, int] = 1
    z_grid: Union[str, int] = 1

    def __post_init__(self) -> None:
        assert self.mode in ("python", "cpp")
        assert self.mode in ("python", "cpp", "python_slow")

    def generate(self, meta: dict[str, int]) -> None:
        raise NotImplementedError
@ -3403,6 +3249,10 @@ class GridExpr:
        # negative integer division is floored
        if self.mode == "python":
            return f"-(({numel}) // -({block}))"
        # This is more generic than above, and works in languages where
        # positive integer division is floored/truncated
        elif self.mode == "python_slow":
            return f"(({numel} + {block} - 1) // ({block}))"
        # For cpp code gen
        return f"(({numel} + ({block} - 1)) / ({block}))"
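
All three emitted forms compute the same ceil-division; a hedged check of the two Python variants for numel=10, block=4:

# Hedged check (not part of the diff): both codegen'd expressions equal
# ceil(10 / 4) == 3; the C++ form truncates to the same value.
numel, block = 10, 4
python_fast = -((numel) // -(block))            # "python" mode
python_slow = (numel + block - 1) // (block)    # "python_slow" mode
assert python_fast == python_slow == 3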

@ -3411,7 +3261,7 @@ class GridExpr:
        items = self._constant_fold(max, seq)
        if len(items) <= 1:
            return items[0]
        if self.mode == "python":
        if self.mode in ("python", "python_slow"):
            return f"max({', '.join(map(str, items))})"
        return functools.reduce(lambda x, y: f"std::max({x}, {y})", items)
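
A hedged check of the two emission paths above: the Python modes emit a single max(...) call, while the C++ path folds the sequence into nested std::max calls:

# Hedged sketch (not part of the diff).
import functools

items = ["x_grid", "y_grid", "z_grid"]
py = f"max({', '.join(map(str, items))})"
cpp = functools.reduce(lambda x, y: f"std::max({x}, {y})", items)
assert py == "max(x_grid, y_grid, z_grid)"
assert cpp == "std::max(std::max(x_grid, y_grid), z_grid)"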

@ -3434,7 +3284,7 @@ class GridExpr:

    def assign_tmp(self, name: str, expr: Union[str, int]) -> str:
        # Grid functions are one per kernel, so name collisions are fine
        if self.mode == "python":
        if self.mode in ("python", "python_slow"):
            return f"{name} = {expr}"
        if self.mode == "cpp":
            return f"uint32_t {name} = {expr};"
@ -3444,7 +3294,7 @@ class GridExpr:
    def from_meta(
        inductor_meta: dict[str, Any],
        cfg: Union[Config, dict[str, int]],
        mode: Literal["python", "cpp"] = "python",
        mode: Literal["python", "cpp", "python_slow"] = "python",
    ) -> GridExpr:
        grid_cls = globals()[inductor_meta["grid_type"]]
        assert issubclass(grid_cls, GridExpr)
