Compare commits

...

3 Commits

SHA1 Message Date
25e1e99fcd new 2025-10-08 16:09:43 +00:00
f803a57a7a Fixes 2025-10-06 22:51:49 +00:00
39287c561b Fixes 2025-10-06 21:54:04 +00:00
3 changed files with 482 additions and 488 deletions

View File

@ -3573,6 +3573,7 @@ class CommonTemplate:
self.assertEqual(actual, expect)
@skip_if_halide # only 32-bit indexing
@skipIfRocm
@largeTensorTest("4GB", inductor=True)
def test_large_pointwise(self):
def fn(a):
@ -12962,6 +12963,7 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
if name not in {"softmax", "log_softmax", "logsumexp"}
],
)
@skipIfRocm
def test_pointwise(self, name, op):
dtype = torch.float32
check_lowp = True

View File

@ -34,7 +34,6 @@ from ...utils._sympy.value_ranges import ValueRanges
from .. import config, ir, metrics
from ..async_compile import AsyncCompile
from ..codecache import code_hash, get_path, PyCodeCache, write_atomic
from ..debug import set_kernel_post_grad_provenance_tracing
from ..ops_handler import DefaultHandler
from ..runtime import triton_heuristics
from ..runtime.benchmarking import benchmarker
@ -48,7 +47,6 @@ from ..runtime.runtime_utils import get_max_y_grid, next_power_of_2
from ..scheduler import BaseSchedulerNode, FusedSchedulerNode, Scheduler, SchedulerNode
from ..utils import (
cache_on_self,
DelayMaybeLine,
DelayReplaceLine,
get_bounds_index_expr,
get_fused_kernel_name,
@ -75,7 +73,6 @@ from .common import (
DeferredLine,
IndentedBuffer,
InplacedBuffer,
is_buffer_removed,
OpOverrides,
PythonPrinter,
RemovedArg,
@ -641,13 +638,6 @@ def triton_reshape(
return f"{value}[{', '.join(expand)}]"
def enable_pdl_codegen():
if not torch._inductor.config.triton.enable_pdl:
return False
major, _ = torch.cuda.get_device_capability(torch.cuda.current_device())
return major >= 9
# NB: Inheriting from PythonPrinter is somewhat dangerous, because there are a
# number of operators which Triton "implements", but in a way that is
# inconsistent with Python semantics (and consistent with C semantics). We
@ -1607,6 +1597,10 @@ class TritonKernelOverrides(TritonOverrides):
V.kernel.cse.put(cache_key, (mantissa, exponent))
return (mantissa, exponent)
@staticmethod
def device_assert_async(cond, msg):
return f"tl.device_assert({cond}, {repr(msg)})"
class HelperFunctions:
"""An ordered set of helper functions."""
@ -1950,8 +1944,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
self.fixed_config = fixed_config
super().__init__(tiling, **kwargs)
self.cse = TritonCSE(self.newvar_prefix, self.suffix)
# Cache of values that can be reused for the prologue.
self.prologue_cache: dict[str, str] = {}
self.prologue: IndentedBuffer = IndentedBuffer()
self.post_loop_combine: IndentedBuffer = IndentedBuffer()
self.post_loop_store: IndentedBuffer = IndentedBuffer()
@ -1966,7 +1958,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
self.tma_min_block_sizes = dict[str, int]()
self.hint_override = hint_override
self._load_counts: collections.Counter[str] = collections.Counter()
self._load_index = 0
# A set of autotuning hints to pass as part of triton_meta
self.autotune_hints = OrderedSet[AutotuneHint]()
@ -1983,44 +1974,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
if self.cooperative_reduction:
self.init_cooperative_reduction_mask()
self.has_load_with_contiguous_rdim = False
# We track the store name since a store can be canceled later
self.stores_with_contiguous_rdim: list[str] = []
@staticmethod
def _has_stride1_on_rdim(index) -> bool:
# This analysis is only needed in deterministic mode so far
# to filter triton configs. Return False immediately to avoid
# increasing compilation time when the mode is off.
if not (
config.deterministic or config.test_configs.force_filter_reduction_configs
):
return False
support_vars = index.free_symbols
reduce_vars = [
var
for var in support_vars
if symbol_is_type(var, TritonSymbols.reduction_types)
]
if len(reduce_vars) == 0:
return False
# for expression "x0 + 150528*((x1//(s27*s38))) + 3*(ModularIndexing(x1, 1, s38)) + 672*(ModularIndexing(x1, s38, s27))"
# stride_vars will results in DivisionByZero error
try:
stride_vars = V.graph.sizevars.stride_vars(index, reduce_vars, support_vars)
except ZeroDivisionError:
return False
return any(stride == 1 for stride in stride_vars)
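For intuition, the check above boils down to whether any reduction symbol appears with stride 1 in the flattened index. Below is a minimal standalone sketch that uses plain sympy coefficients as a stand-in for Inductor's stride_vars helper; the symbols x0/r0 and the helper name are hypothetical, not part of the diff.

import sympy

x0, r0 = sympy.symbols("x0 r0", integer=True, nonnegative=True)

def has_stride1_on_r(index: sympy.Expr, rvar: sympy.Symbol) -> bool:
    # Simplified stand-in for stride_vars: the linear coefficient of the
    # reduction symbol approximates its stride in the flattened index.
    return index.coeff(rvar) == 1

print(has_stride1_on_r(r0 + 64 * x0, r0))  # True: the reduction dim is contiguous
print(has_stride1_on_r(x0 + 64 * r0, r0))  # False: the reduction dim has stride 64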
@property
def has_store_with_contiguous_rdim(self) -> bool:
return not all(
is_buffer_removed(name) for name in self.stores_with_contiguous_rdim
)
def dtype_to_str(self, dtype: torch.dtype) -> str:
return triton_type(dtype)
@ -2088,6 +2041,10 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
)
def codegen_range_tree(self):
# Skip if pointwise inner loop
#return
for tree in self.range_trees:
# reduction indexing goes inside a loop
if not tree.is_loop:
@ -2532,95 +2489,86 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
and self.range_trees[-1].is_loop
and indexing.has_rindex()
) or indexing.can_lift:
if indexing.can_lift and var in self.prologue_cache:
# Check for epilogue subtiling to reuse the same
# tensor descriptor.
block_descriptor = self.prologue_cache[var]
block_descriptor_id = next(self.block_ptr_id)
if isinstance(indexing, BlockPtrOptions):
block_descriptor = f"block_ptr{block_descriptor_id}"
else:
block_descriptor_id = next(self.block_ptr_id)
if isinstance(indexing, BlockPtrOptions):
block_descriptor = f"block_ptr{block_descriptor_id}"
else:
block_descriptor = f"tma_descriptor{block_descriptor_id}"
line_body = DeferredLine(
name, f"{block_descriptor} = {indexing.format(var, roffset=False)}"
)
if indexing.can_lift:
self.prologue.writeline(line_body)
# Cache the descriptor for epilogue subtiling
self.prologue_cache[var] = block_descriptor
else:
self.body.writeline(line_body)
block_descriptor = f"tma_descriptor{block_descriptor_id}"
line_body = DeferredLine(
name, f"{block_descriptor} = {indexing.format(var, roffset=False)}"
)
if indexing.can_lift:
self.prologue.writeline(line_body)
else:
self.body.writeline(line_body)
if isinstance(indexing, BlockPtrOptions):
# Store for later use. If the buffer is removed the below advancements
# are no longer necessary
self.block_ptr_to_buffer[block_descriptor] = name
if isinstance(indexing, BlockPtrOptions):
# Store for later use. If the buffer is removed the below advancements
# are no longer necessary
self.block_ptr_to_buffer[block_descriptor] = name
# Generate block pointer advancements, for later use.
for symt in TritonSymbols.reduction_types:
advance_offsets = indexing.advance_roffset(symt)
# Generate block pointer advancements, for later use.
for symt in TritonSymbols.reduction_types:
advance_offsets = indexing.advance_roffset(symt)
# Ignore identity advancements.
if all(
V.graph.sizevars.statically_known_equals(
offset, sympy.Integer(0)
)
for offset in advance_offsets
):
continue
advancements = self.pointer_advancements[symt]
assert block_descriptor not in advancements, (
f"duplicate advancement for pointer '{block_descriptor}' at type '{symt}'"
# Ignore identity advancements.
if all(
V.graph.sizevars.statically_known_equals(
offset, sympy.Integer(0)
)
advancements[block_descriptor] = advance_offsets
for offset in advance_offsets
):
continue
advancements = self.pointer_advancements[symt]
assert block_descriptor not in advancements, (
f"duplicate advancement for pointer '{block_descriptor}' at type '{symt}'"
)
advancements[block_descriptor] = advance_offsets
else:
block_descriptor = indexing.format(var)
return block_descriptor, other
def codegen_block_ptr_store_line(self, name, indexing, block_ptr, value, other=""):
def stringify_shape(shape):
return tuple(
symt.name if isinstance(symt, sympy.Symbol) else str(symt)
for symt in shape
)
if value.shape:
value_forward_shape = stringify_shape(value.shape)
value_reverse_shape = stringify_shape(value.shape[::-1])
# TMA stores may require transposing the data to ensure we are contiguous along
# the final dimension. We do this by checking the shape information on value.
# One of the following cases applies:
# 1. The shape matches the final shape. In this case no broadcast/reshape
# is necessary.
# 2. The shape is the transpose of the final shape, which means we had to transpose
# the store_descriptor relative to the accumulator indexing/value. If this
# happens we will generate a tl.trans().
# 3. A mismatched shape is provided. When this occurs we will error.
# 4. No shape is provided. This will proceed with the default explicit broadcast
# described below.
#
# To prevent unintended side effects we gate cases 1-3 behind isinstance(indexing, TensorDescriptorOptions).
if isinstance(indexing, TensorDescriptorOptions) and value.shape:
str_final_shape = tuple([symt.name for symt in indexing.final_shape])
if value.shape[::-1] == str_final_shape:
value = f"tl.trans({value})"
elif value.shape != str_final_shape:
raise AssertionError(
"TMA store requires no broadcasting when a shape is provided"
)
else:
value_forward_shape = None
value_reverse_shape = None
final_shape = stringify_shape(indexing.final_shape)
# TODO: Generalize to N Dimensions
if (
value_forward_shape != final_shape
and value_reverse_shape == final_shape
and len(final_shape) == 2
):
# TMA stores may require transposing the data to ensure we are contiguous along
# the final dimension. This applies to Block-pointers generally, but should only practically
# be reached with TMA.
value = f"tl.trans({value})"
# Stores require an explicit broadcast. We do this in two phases:
# 1. Broadcast the operand to the final shape of the range trees, e.g. [ZBLOCK,
# YBLOCK, XBLOCK]. This protects against implicit broadcasting from loads.
# 2. In case the block pointer / tma descriptor has different dimensionality, broadcast/reshape the
# result to the shape of the pointer.
value = f"tl.broadcast_to({value}, {indexing.final_shape})"
# Stores require an explicit broadcast. We do this in two phases:
# 1. Broadcast the operand to the final shape of the range trees, e.g. [ZBLOCK,
# YBLOCK, XBLOCK]. This protects against implicit broadcasting from loads.
# 2. In case the block pointer / tma descriptor has different dimensionality, broadcast/reshape the
# result to the shape of the pointer.
value = f"tl.broadcast_to({value}, {indexing.final_shape})"
# These dims no longer need broadcasting.
for idx, (dim, broadcast_dim) in enumerate(
zip(indexing.final_shape, indexing.broadcast_shape)
):
if V.graph.sizevars.statically_known_equals(dim, broadcast_dim):
indexing.broadcasting_dims[idx] = False
# These dims no longer need broadcasting.
for idx, (dim, broadcast_dim) in enumerate(
zip(indexing.final_shape, indexing.broadcast_shape)
):
if V.graph.sizevars.statically_known_equals(dim, broadcast_dim):
indexing.broadcasting_dims[idx] = False
value = indexing.codegen_broadcast_and_reshape(
value, indexing.final_shape, indexing.block_shape, False
)
value = indexing.codegen_broadcast_and_reshape(
value, indexing.final_shape, indexing.block_shape, False
)
# workaround https://github.com/triton-lang/triton/issues/2814
value = f"{value}.to({triton_store_type(V.graph.get_dtype(name))})"
@ -2669,27 +2617,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
else:
return self.loads
def _handle_pdl_before_load(self, wait_buffer):
GDC_WAIT = "tl.extra.cuda.gdc_wait()"
self._load_index += 1
if self.inside_reduction:
wait_buffer = self.body
if enable_pdl_codegen():
if self._load_index == 1:
wait_buffer.writeline(GDC_WAIT)
def _handle_pdl_after_load(self, launch_buffer, result_var):
GDC_LAUNCH = "tl.extra.cuda.gdc_launch_dependents()"
if self.inside_reduction:
launch_buffer = self.post_loop_combine
if enable_pdl_codegen():
current_load_index = self._load_index
launch_if_last_load = DelayMaybeLine(
lambda: current_load_index == self._load_index,
f"0; {GDC_LAUNCH} # gdc launch for {result_var}",
)
self.cse.generate(launch_buffer, launch_if_last_load, dtype=torch.int32)
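For context, the two PDL helpers above place a gdc_wait before the first dependent load and a gdc_launch_dependents after the last one. Here is a minimal hand-written Triton sketch of that placement, assuming a CUDA device that supports programmatic dependent launch; the kernel and pointer names are hypothetical.

import triton
import triton.language as tl

@triton.jit
def pdl_pointwise(in_ptr, out_ptr, xnumel, XBLOCK: tl.constexpr):
    xindex = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)
    xmask = xindex < xnumel
    # Wait for the producing kernel's writes before the first dependent load.
    tl.extra.cuda.gdc_wait()
    tmp0 = tl.load(in_ptr + xindex, mask=xmask)
    # After the last load, let the dependent kernel start launching while the
    # remaining compute and stores finish.
    tl.extra.cuda.gdc_launch_dependents()
    tl.store(out_ptr + xindex, tmp0 * 2.0, mask=xmask)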
def load(self, name: str, index: sympy.Expr):
"""
Load from the memory location 'name', offset by some indexing expression 'index'.
@ -2711,12 +2638,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
force=False,
),
)
if isinstance(indexing, IndexingOptions) and self._has_stride1_on_rdim(
indexing.index
):
self.has_load_with_contiguous_rdim = True
has_rindex = indexing.has_rindex()
has_tmpmask = indexing.has_tmpmask()
@ -2826,11 +2747,11 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
dtype = torch.bool
load_buffer = self.get_load_buffer(indexing)
self._handle_pdl_before_load(load_buffer)
if config.triton.enable_pdl:
load_buffer.writeline("tl.extra.cuda.gdc_wait()")
result_var = self.cse.generate(
load_buffer, make_line(line), dtype=dtype, shape=shape
)
self._handle_pdl_after_load(load_buffer, result_var)
if result_var.use_count > 1:
load_counts[name] -= 1 # don't double count cache hit
assert isinstance(result_var, TritonCSEVariable)
@ -2884,11 +2805,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
tma_compatibility_checker=tma_compatibility_checker,
)
if isinstance(indexing, IndexingOptions) and self._has_stride1_on_rdim(
indexing.index
):
self.stores_with_contiguous_rdim.append(name)
# Guard against write-after-read corruption in triton.
# See # https://github.com/triton-lang/triton/issues/1615
# This triton bug means that a load which is broadcasted over multiple
@ -2924,9 +2840,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
exit_stack.close()
def device_assert_async(self, cond, msg) -> None:
self.compute.writeline(f"tl.device_assert({cond}, {repr(msg)})")
def guard_cooperative_store(self, name, buffer):
"""
For cooperative reductions only one thread block should write out the result.
@ -2984,7 +2897,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
"Bucketize only supports indexing with int32 and int64"
)
self._handle_pdl_before_load(self.compute)
result = self.cse.generate(
self.compute,
f"triton_helpers.bucketize_binary_search({values}, "
@ -2998,7 +2910,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
dtype=indexing_dtype, # type: ignore[attr-defined]
shape=values.shape,
)
self._handle_pdl_after_load(self.compute, result)
masks = self._combine_masks(values, boundary_indices, sorter_indices)
result.mask_vars = masks # type: ignore[attr-defined]
@ -3175,34 +3086,14 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
if self.persistent_reduction:
default = ir.Reduction.default_value(reduction_type, src_dtype)
def update_constant_dtype(constant, src_dtype, dst_dtype):
"update reduction constant mask value to match dst_dtype"
# int is the only mask value that may not fit within a lower bitwidth,
# because float uses inf/-inf
if src_dtype.is_floating_point or src_dtype == torch.bool:
return constant
if src_dtype == dst_dtype or constant == 0:
return constant
if constant == torch.iinfo(src_dtype).max:
return torch.iinfo(dst_dtype).max
elif constant == torch.iinfo(src_dtype).min:
return torch.iinfo(dst_dtype).min
else:
return constant
default = self._map_tuple_or_scalar(constant_repr, default)
def _mask_value(value, default) -> CSEVariable:
default = update_constant_dtype(default, src_dtype, value.dtype)
default_str = self._map_tuple_or_scalar(constant_repr, default)
return self.cse.generate(
self.compute,
where_cond(value, default_str),
where_cond(value, default),
dtype=value.dtype,
shape=value.shape,
shape=value.shape if value.shape is not None else default.shape,
)
masked_value: Union[CSEVariable, Sequence[CSEVariable]]
@ -3211,14 +3102,13 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
# will fallback below
pass
elif isinstance(value, tuple):
masked_value = [_mask_value(v, d) for v, d in zip(value, default)] # type: ignore[arg-type]
masked_value = [_mask_value(v, d) for v, d in zip(value, default)]
else:
masked_value = _mask_value(value, default)
if reduction_type in ("argmax", "argmin"):
assert isinstance(masked_value, CSEVariable)
accumulator_dtype = V.kernel.get_index_dtype_as_torch_dtype()
accumulator_index = str(
self.cse.generate(
self.compute,
@ -3993,7 +3883,268 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
code.splice(self.prologue)
self.prologue.clear()
self.prologue_cache.clear()
def _should_use_pointwise_inner_loop(self):
"""Check if we should use inner loop optimization for this pointwise kernel."""
# Safety check: never apply to reduction kernels
if self.inside_reduction:
return False
# Only apply to plain kernels: skip templates (fixed configs) and cooperative/persistent reductions
if self.fixed_config or self.cooperative_reduction or self.persistent_reduction:
return False
# Check for operations incompatible with inner loop
code_str = ""
if hasattr(self, 'indexing_code') and self.indexing_code:
code_str += str(self.indexing_code)
if hasattr(self, 'loads') and self.loads:
code_str += str(self.loads)
if hasattr(self, 'compute') and self.compute:
code_str += str(self.compute)
if hasattr(self, 'stores') and self.stores:
code_str += str(self.stores)
# Skip if using block pointers - they have complex shape calculations
if "tl.make_block_ptr" in code_str:
return False
# Skip if using tensor descriptors (TMA)
if "tl.make_tensor_descriptor" in code_str:
return False
# Separate trees by type
valid_prefixes = {'x', 'y', 'z'}
valid_trees = []
reduction_trees = []
for tree in self.range_trees:
if tree.prefix in valid_prefixes:
valid_trees.append(tree)
elif tree.prefix.startswith('r') and (len(tree.prefix) == 1 or tree.prefix[1:].replace('_', '').isdigit()):
reduction_trees.append(tree)
# Check if all reduction trees are trivial (numel=1)
for tree in reduction_trees:
try:
size_hint = V.graph.sizevars.size_hint(tree.numel, fallback=999999)
if size_hint > 1:
return False
except Exception:
return False
# Need at least one and at most two valid (x/y/z) dimensions
if len(valid_trees) == 0 or len(valid_trees) > 2:
return False
# Check dimension sizes
dimension_sizes = []
for tree in valid_trees:
try:
size_hint = V.graph.sizevars.size_hint(
tree.numel,
fallback=config.unbacked_symint_fallback
)
dimension_sizes.append((tree.prefix, size_hint))
except Exception:
return False
#if dimension_sizes:
# largest_size = max(size for _, size in dimension_sizes)
# # Need sufficient work to hide memory latency
# if largest_size < 512:
# return False
#else:
# return False
return True
def _get_split_dimension_info(self):
"""Get information about which dimension to split for inner loop."""
# Only consider x, y, z dimensions (skip reduction dimensions)
valid_prefixes = {'x', 'y', 'z'}
non_reduction_trees = [
t for t in self.range_trees
if t.prefix in valid_prefixes
]
if not non_reduction_trees:
return None, []
# Find dimension with largest numel
split_tree = max(
non_reduction_trees,
key=lambda tree: V.graph.sizevars.size_hint(
tree.numel, fallback=config.unbacked_symint_fallback
)
)
return split_tree, non_reduction_trees
def _get_largest_dimension_for_inner_loop(self):
"""Determine which dimension to split based on largest numel."""
if len(self.range_trees) == 1:
return self.range_trees[0]
# Find dimension with largest numel
largest_tree = max(
self.range_trees,
key=lambda tree: V.graph.sizevars.size_hint(
tree.numel, fallback=config.unbacked_symint_fallback
)
)
return largest_tree
def _codegen_pointwise_with_inner_loop(self):
"""Generate pointwise kernel body with inner R0_BLOCK loop."""
# Save the current generated code
saved_indexing = self.indexing_code.getvalue()
saved_loads = self.loads.getvalue()
saved_compute = self.compute.getvalue()
saved_stores = self.stores.getvalue()
# Clear buffers for our custom generation
self.indexing_code.clear()
self.body.clear()
self.loads.clear()
self.compute.clear()
self.stores.clear()
# Determine which dimension to split
non_reduction_trees = [t for t in self.range_trees if t.prefix in {'x', 'y', 'z'}]
if not non_reduction_trees:
# Restore and fallback
self.indexing_code.writelines(saved_indexing.split('\n'))
self.loads.writelines(saved_loads.split('\n'))
self.compute.writelines(saved_compute.split('\n'))
self.stores.writelines(saved_stores.split('\n'))
return
split_tree = max(
non_reduction_trees,
key=lambda tree: V.graph.sizevars.size_hint(
tree.numel, fallback=config.unbacked_symint_fallback
)
)
split_prefix = split_tree.prefix
# Add static assertion
self.body.writeline(f"tl.static_assert({split_prefix.upper()}BLOCK % R0_BLOCK == 0)")
# Map prefix to program_id
prefix_to_pid = {
'x': "tl.program_id(0)",
'y': "(tl.program_id(1) + tl.program_id(2) * tl.num_programs(1))",
'z': "tl.program_id(2)"
}
# Build prefix list for broadcast calculations
all_prefixes = [t.prefix for t in non_reduction_trees]
# Generate offset+index+mask for NON-SPLIT dimensions
for tree in non_reduction_trees:
if tree.prefix != split_prefix:
prefix = tree.prefix
pid = prefix_to_pid[prefix]
idx = all_prefixes.index(prefix)
if len(all_prefixes) == 1:
broadcast = ""
elif len(all_prefixes) == 2:
broadcast = "[:, None]" if idx == 0 else "[None, :]"
else:
parts = ["None"] * len(all_prefixes)
parts[idx] = ":"
broadcast = f"[{', '.join(parts)}]"
self.body.writeline(f"{prefix}offset = {pid} * {prefix.upper()}BLOCK")
self.body.writeline(
f"{prefix}index = {prefix}offset + tl.arange(0, {prefix.upper()}BLOCK){broadcast}"
)
self.body.writeline(f"{prefix}mask = {prefix}index < {prefix}numel")
# For split dimension, generate offset only
pid = prefix_to_pid[split_prefix]
self.body.writeline(f"{split_prefix}offset = {pid} * {split_prefix.upper()}BLOCK")
self.body.writeline(f"tile_start = {split_prefix}offset")
# Save original name
original_name = split_tree.name
# Generate pipelined inner loop
self.body.writeline(f"for r in tl.range(0, {split_prefix.upper()}BLOCK, R0_BLOCK, num_stages=2):")
with self.body.indent():
self.body.writeline("lanes = tl.arange(0, R0_BLOCK)")
# Calculate broadcast for split dimension
split_idx = all_prefixes.index(split_prefix)
if len(all_prefixes) == 1:
broadcast = ""
elif len(all_prefixes) == 2:
broadcast = "[:, None]" if split_idx == 0 else "[None, :]"
else:
parts = ["None"] * len(all_prefixes)
parts[split_idx] = ":"
broadcast = f"[{', '.join(parts)}]"
self.body.writeline(
f"{split_prefix}index = (tile_start + r + lanes){broadcast}"
)
self.body.writeline(f"{split_prefix}mask = {split_prefix}index < {split_prefix}numel")
# Override the range tree name temporarily
split_tree.name = f"{split_prefix}index"
# Generate derived indices
for node_name, entry in self.range_tree_nodes.items():
if hasattr(entry, 'expr') and entry.name != f"{split_prefix}index":
line = f"{entry.name} = {self.kexpr(self.rename_indexing(entry.expr))}"
self.body.writeline(line)
# Fix shapes in the saved code
block_name = f"{split_prefix.upper()}BLOCK"
# Replace in all buffers
import re
def fix_shapes(code):
if not code:
return code
# Replace [XBLOCK] with [R0_BLOCK]
code = code.replace(f"[{block_name}]", "[R0_BLOCK]")
# Replace tl.full([XBLOCK], ...) with tl.full([R0_BLOCK], ...)
code = re.sub(
rf'tl\.full\(\[{block_name}\]',
'tl.full([R0_BLOCK]',
code
)
# Replace broadcast_to(expr, [XBLOCK]) with broadcast_to(expr, [R0_BLOCK])
code = re.sub(
rf'broadcast_to\((.*?),\s*\[{block_name}\]\)',
r'broadcast_to(\1, [R0_BLOCK])',
code
)
return code
# Apply fixes and write the code
if saved_indexing:
self.body.splice(fix_shapes(saved_indexing))
if saved_loads:
self.body.splice(fix_shapes(saved_loads))
if saved_compute:
self.body.splice(fix_shapes(saved_compute))
if saved_stores:
self.body.splice(fix_shapes(saved_stores))
# Restore original name
split_tree.name = original_name
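To make the target of this codegen concrete, here is a hand-written sketch of the kind of 1D kernel the inner-loop path aims to emit. The kernel and pointer names are hypothetical; only the XBLOCK/R0_BLOCK loop structure mirrors the generation above.

import triton
import triton.language as tl

@triton.jit
def pointwise_inner_loop(in_ptr, out_ptr, xnumel, XBLOCK: tl.constexpr, R0_BLOCK: tl.constexpr):
    tl.static_assert(XBLOCK % R0_BLOCK == 0)
    xoffset = tl.program_id(0) * XBLOCK
    tile_start = xoffset
    # Each program owns an XBLOCK tile and walks it in R0_BLOCK chunks so that
    # loads for the next chunk can be software-pipelined (num_stages=2).
    for r in tl.range(0, XBLOCK, R0_BLOCK, num_stages=2):
        lanes = tl.arange(0, R0_BLOCK)
        xindex = tile_start + r + lanes
        xmask = xindex < xnumel
        tmp0 = tl.load(in_ptr + xindex, mask=xmask)
        tl.store(out_ptr + xindex, tmp0 + 1.0, mask=xmask)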
def codegen_body(self):
"""
@ -4015,6 +4166,15 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
):
return
if self._should_use_pointwise_inner_loop():
self._codegen_pointwise_with_inner_loop()
# Clear the buffers since we've used them
self.indexing_code.clear()
self.loads.clear()
self.compute.clear()
self.stores.clear()
return
loop_trees = [tree for tree in self.range_trees if tree.is_loop]
if self.inside_reduction and len(loop_trees) > 0:
# Write the loop headers.
@ -4105,16 +4265,12 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
args.append(str(arg))
elif isinstance(arg, SymbolicCallArg):
hint = V.graph.sizevars.size_hint(
arg.inner_expr,
hint_override=self.hint_override,
fallback=config.unbacked_symint_fallback,
arg.inner_expr, fallback=config.unbacked_symint_fallback
)
args.append(str(hint))
elif isinstance(arg, sympy.Expr):
hint = V.graph.sizevars.size_hint(
arg,
hint_override=self.hint_override,
fallback=config.unbacked_symint_fallback,
arg, fallback=config.unbacked_symint_fallback
)
args.append(str(hint))
else:
@ -4173,11 +4329,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
f"{var_name} = rand_strided({size}, {stride}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # type: ignore[arg-type] # noqa: B950 line too long
)
elif isinstance(arg_sig, SizeArg):
symval_hint = V.graph.sizevars.size_hint(
arg_sig.expr,
hint_override=self.hint_override,
fallback=config.unbacked_symint_fallback,
)
symval_hint = V.graph.sizevars.size_hint(arg_sig.expr)
# Force the seed_offset to be 0 so calls to the same kernel
# using different seed offset will have the same benchmark harness.
@ -4187,9 +4339,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
result.writeline(f"{var_name} = {symval_hint}")
elif isinstance(arg_sig, WorkspaceArg):
device = V.graph.get_current_device_or_throw()
count = V.graph.sizevars.size_hint(
arg_sig.count, hint_override=self.hint_override
)
count = V.graph.sizevars.size_hint(arg_sig.count)
result.writeline(
f"{var_name} = torch.zeros({count}, device='{device}', dtype={arg_sig.dtype})"
)
@ -4284,7 +4434,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
"min_split_scan_rblock": config.triton.min_split_scan_rblock,
"spill_threshold": config.triton.spill_threshold,
"store_cubin": config.triton.store_cubin,
"deterministic": config.deterministic,
}
if torch.version.hip is not None:
inductor_meta["is_hip"] = True
@ -4315,6 +4464,8 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
metadata, and benchmarking infra.
"""
from ..utils import prefix_is_reduction
code = IndentedBuffer()
size_hints = {}
@ -4426,6 +4577,10 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
add_constexpr_arg(f"{tree.prefix.upper()}BLOCK")
# Pointwise kernels that use the inner-loop path take an extra R0_BLOCK constexpr
if self._should_use_pointwise_inner_loop():
add_constexpr_arg("R0_BLOCK")
if self.cooperative_reduction:
add_constexpr_arg("RSPLIT")
@ -4457,12 +4612,6 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
**self.inductor_meta_common(),
}
if config.deterministic or config.test_configs.force_filter_reduction_configs:
inductor_meta["has_loadstore_with_contiguous_rdim"] = (
self.has_load_with_contiguous_rdim
or self.has_store_with_contiguous_rdim
)
# Bail on 3d tiling, which has more complicated coalesce patterns
looped_red = V.kernel.features.is_reduction() and not self.persistent_reduction
tiling_scores = self.tiling_scores
@ -4538,7 +4687,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
triton_meta["configs"] = [config_of(signature)]
if enable_pdl_codegen():
if config.triton.enable_pdl:
triton_meta["launch_pdl"] = True
# Triton compiler includes equal_to_1 args into constants even
@ -4990,7 +5139,7 @@ class TritonScheduling(SIMDScheduling):
)
return cls.backend_features
def codegen_comment(self, node_schedule, kernel_name=None):
def codegen_comment(self, node_schedule):
wrapper = V.graph.wrapper_code
origins, _detailed_origins = get_kernel_metadata(node_schedule, wrapper)
if origins:
@ -5016,13 +5165,6 @@ class TritonScheduling(SIMDScheduling):
f"{wrapper.comment} Fused node name list: {', '.join(node_names)}"
)
if kernel_name:
debug_handle = set_kernel_post_grad_provenance_tracing(
node_schedule, # type: ignore[arg-type]
kernel_name,
)
wrapper.write_provenance_debug_handle(kernel_name, debug_handle)
def define_kernel(self, src_code, node_schedule, kernel):
wrapper = V.graph.wrapper_code
if src_code in wrapper.src_to_kernel:

View File

@ -17,7 +17,7 @@ import re
import sys
import threading
import time
from collections import defaultdict, namedtuple
from collections import namedtuple
from typing import (
Any,
Callable,
@ -30,9 +30,8 @@ from typing import (
)
import torch
from torch._dynamo.utils import counters, set_feature_use
from torch._dynamo.utils import set_feature_use
from torch._environment import is_fbcode
from torch._inductor import metrics
from torch._prims_common import compute_required_storage_length
from torch.utils._ordered_set import OrderedSet
@ -237,7 +236,7 @@ def check_autotune_cache(
not disabled
and filename is not None
and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning"))
and os.environ.get("TRITON_INTERPRET", "0") != "1"
and not os.environ.get("TRITON_INTERPRET", "0") == "1"
):
configs_hash = hash_configs(configs)
@ -312,8 +311,6 @@ class CachingAutotuner(KernelInterface):
"device_type": self.device_props.type,
}
self.inductor_meta = {} if inductor_meta is None else inductor_meta
self.deterministic_mode = self.inductor_meta.get("deterministic", False)
self.save_cache_hook = save_cache_hook
self.mutated_arg_names = mutated_arg_names
self.reset_to_zero_arg_names = (
@ -378,7 +375,7 @@ class CachingAutotuner(KernelInterface):
self.is_backward = False
# Mode for launch grid calculation
self.grid_mode: Literal["python", "cpp"] = "python"
self.grid_mode: Literal["python", "python_slow", "cpp"] = "python"
def is_statically_launchable(self):
"""
@ -485,8 +482,7 @@ class CachingAutotuner(KernelInterface):
# Currently it relies on _make_launchers(), which requires a cuda context, to populate nreg.
device_prop = self.device_props
if (
not self.deterministic_mode
and self.inductor_meta.get("dynamic_scale_rblock", True)
self.inductor_meta.get("dynamic_scale_rblock", True)
and not self.inductor_meta.get("persistent_reduction")
and self.heuristic_type == HeuristicType.REDUCTION
and self.size_hints is not None
@ -690,17 +686,9 @@ class CachingAutotuner(KernelInterface):
return get_interface_for_device(self.device_props.type.replace("hip", "cuda"))
def _create_compile_meta(self, cfg: Config) -> dict[str, Any]:
"""
Create compilation metadata for a given autotuner config. This involves
processing the Config kwargs so that the kwargs that are not part
of the triton signature are passed in as options to triton.compile
instead
"""
def _precompile_config(self, cfg: Config) -> CompileResult[_KernelType]:
"""Ahead of time compile a given autotuner config."""
compile_meta = copy.deepcopy(self.triton_meta)
compile_meta["num_warps"] = cfg.num_warps
compile_meta["num_stages"] = cfg.num_stages
cfg_kwargs = cfg.kwargs
if self.device_props.type == "hip":
cfg_kwargs = {**cfg_kwargs}
@ -708,13 +696,14 @@ class CachingAutotuner(KernelInterface):
if k in cfg_kwargs:
compile_meta[k] = cfg_kwargs.pop(k)
compile_meta["constants"].update(cfg_kwargs)
for i in self.fn.constexprs:
arg_name = self.fn.arg_names[i]
if arg_name not in compile_meta["constants"] and (
arg_name == "num_warps" or arg_name == "num_stages"
):
compile_meta["constants"][arg_name] = getattr(cfg, arg_name)
compile_meta["num_warps"] = cfg.num_warps
compile_meta["num_stages"] = cfg.num_stages
if HAS_WARP_SPEC:
compile_meta["num_consumer_groups"] = getattr(cfg, "num_consumer_groups", 0)
compile_meta["num_buffers_warp_spec"] = getattr(
@ -728,53 +717,6 @@ class CachingAutotuner(KernelInterface):
compile_meta["device_type"] = self.device_props.type
compile_meta["cc"] = self.device_props.cc
return compile_meta
def _create_compile_options(
self, cfg: Config, compile_meta: dict[str, Any]
) -> dict[str, Any]:
"""
Create options to pass to triton.compile based on the compile metadata
and the given config.
"""
options = {
"num_warps": compile_meta["num_warps"],
"num_stages": compile_meta["num_stages"],
"debug": compile_meta["debug"],
"sanitize_overflow": False, # turn off additional asserts added for overflow checks
}
if "enable_fp_fusion" in compile_meta:
options["enable_fp_fusion"] = compile_meta["enable_fp_fusion"]
if HAS_WARP_SPEC:
options.update(
{
"num_consumer_groups": compile_meta.get("num_consumer_groups", 0),
"num_buffers_warp_spec": compile_meta.get(
"num_buffers_warp_spec", 0
),
}
)
if self.device_props.type == "cuda":
options.update(
{
"launch_cooperative_grid": compile_meta.get(
"launch_cooperative_grid", False
),
"launch_pdl": compile_meta.get("launch_pdl", False), # True
}
)
if self.device_props.type == "hip":
if "waves_per_eu" in compile_meta:
options["waves_per_eu"] = compile_meta["waves_per_eu"]
if "matrix_instr_nonkdim" in compile_meta:
options["matrix_instr_nonkdim"] = compile_meta["matrix_instr_nonkdim"]
return options
def _precompile_config(self, cfg: Config) -> CompileResult[_KernelType]:
"""Ahead of time compile a given autotuner config."""
compile_meta = self._create_compile_meta(cfg)
if self.device_props.type == "cpu":
triton_helpers.set_driver_to_cpu()
else:
@ -807,8 +749,37 @@ class CachingAutotuner(KernelInterface):
cc_warp_size(compile_meta["cc"]),
)
options = self._create_compile_options(cfg, compile_meta)
options = {
"num_warps": compile_meta["num_warps"],
"num_stages": compile_meta["num_stages"],
"debug": compile_meta["debug"],
"sanitize_overflow": False, # turn off additional asserts added for overflow checks
}
if "enable_fp_fusion" in compile_meta:
options["enable_fp_fusion"] = compile_meta["enable_fp_fusion"]
if HAS_WARP_SPEC:
options.update(
{
"num_consumer_groups": compile_meta.get("num_consumer_groups", 0),
"num_buffers_warp_spec": compile_meta.get(
"num_buffers_warp_spec", 0
),
}
)
if self.device_props.type == "cuda":
options.update(
{
"launch_cooperative_grid": compile_meta.get(
"launch_cooperative_grid", False
),
"launch_pdl": compile_meta.get("launch_pdl", False), # True
}
)
if self.device_props.type == "hip":
if "waves_per_eu" in compile_meta:
options["waves_per_eu"] = compile_meta["waves_per_eu"]
if "matrix_instr_nonkdim" in compile_meta:
options["matrix_instr_nonkdim"] = compile_meta["matrix_instr_nonkdim"]
compile_kwargs = {
"target": target,
"options": options,
@ -898,34 +869,25 @@ class CachingAutotuner(KernelInterface):
)
# reset to zero before evaluating any config
self.reset_to_zero_args(*args, **kwargs)
kernel_name = self.inductor_meta.get("kernel_name", "triton kernel")
if autograd_profiler._is_profiler_enabled:
profiler_kwargs = self.get_profiler_kwargs(stream, launcher)
with torch._C._profiler._RecordFunctionFast(
kernel_name,
self.inductor_meta.get("kernel_name", "triton kernel"),
cloned_args,
profiler_kwargs,
):
try:
launcher(
*cloned_args,
**cloned_kwargs,
stream=stream,
)
except Exception:
log.error("Failed during launch %s: ", kernel_name)
raise
else:
try:
launcher(
*cloned_args,
**cloned_kwargs,
stream=stream,
)
except Exception:
log.error("Failed during launch %s: ", kernel_name)
raise
else:
launcher(
*cloned_args,
**cloned_kwargs,
stream=stream,
)
self.restore_args_from_cpu(cpu_copies)
# only use profiler when not already in a profiler instance
@ -1092,18 +1054,6 @@ class CachingAutotuner(KernelInterface):
k.shared,
)
if metrics.is_metric_table_enabled("kernel_autotune"):
if self.fn.fn is None:
self.fn = self._reload_kernel().fn
kernel_path = self.fn.fn.__code__.co_filename
kernel_name = self.fn.__name__
for k, v in timings.items():
metrics.log_kernel_autotune_result(
kernel_path, kernel_name, k.config, v
)
self.reset_to_zero_args(*args, **kwargs)
return timings
@ -1198,26 +1148,13 @@ class CachingAutotuner(KernelInterface):
Then if coordinate descent tuning is run with max-autotune disabled, it will start from C1;
while if coordinate descent tuning is run with max-autotune enabled, it will start from C3.
"""
if self.heuristic_type in (
HeuristicType.TEMPLATE,
HeuristicType.USER_AUTOTUNE,
HeuristicType.FIXED,
if (
self.heuristic_type == HeuristicType.TEMPLATE
or self.heuristic_type == HeuristicType.USER_AUTOTUNE
):
# skip triton template
return launcher
if self.deterministic_mode and self.heuristic_type in (
HeuristicType.REDUCTION,
HeuristicType.PERSISTENT_REDUCTION,
HeuristicType.SPLIT_SCAN,
):
# RBLOCK size is not the only thing that matters for the numerics of a reduction.
# num_warps also matters since it affects how much data
# is handled by each thread, how many warp reductions we do
# in parallel and how much data is left for the block
# reduction.
return launcher
with dynamo_timed(
"CachingAutotuner.coordinate_descent_tuning",
# These generate too many pt2_compile_event logs:
@ -1251,7 +1188,6 @@ class CachingAutotuner(KernelInterface):
config2launcher[config] = launcher
out = self.bench(launcher, *args, **kwargs)
counters["inductor"]["coordesc_tuning_bench"] += 1
log.debug(
"COORDESC: %s: %f, nreg %d, nspill %d, #shared-mem %d",
launcher.config,
@ -2574,7 +2510,27 @@ def pointwise(
configs = None
if len(size_hints) == 1:
if disable_pointwise_autotuning(inductor_meta) and not (
use_looped_pointwise = True
has_r0_block = any(
arg == "R0_BLOCK"
for arg in triton_meta.get("signature", {}).keys()
)
if has_r0_block:
configs = [
Config({"XBLOCK": 512, "R0_BLOCK": 256}, num_warps=4, num_stages=1),
Config({"XBLOCK": 1024, "R0_BLOCK": 512}, num_warps=4, num_stages=1),
Config({"XBLOCK": 2048, "R0_BLOCK": 1024}, num_warps=8, num_stages=1),
Config({"XBLOCK": 4096, "R0_BLOCK": 2048}, num_warps=8, num_stages=1),
]
#configs = [
# triton_config_with_settings(size_hints, 512),
# triton_config_with_settings(size_hints, 1024),
# triton_config_with_settings(size_hints, 2048),
# triton_config_with_settings(size_hints, 4096),
#]
elif disable_pointwise_autotuning(inductor_meta) and not (
inductor_meta.get("max_autotune")
or inductor_meta.get("max_autotune_pointwise")
):
@ -2588,7 +2544,36 @@ def pointwise(
*hinted_configs,
]
if len(size_hints) == 2:
if (
use_looped_pointwise = True
has_r0_block = any(
arg == "R0_BLOCK"
for arg in triton_meta.get("signature", {}).keys()
)
if has_r0_block:
# Find which dimension is larger
dim_sizes = [(name, size) for name, size in size_hints.items()]
larger_dim, larger_size = max(dim_sizes, key=lambda x: x[1])
smaller_dim, smaller_size = min(dim_sizes, key=lambda x: x[1])
# Split the larger dimension with R0_BLOCK
larger_block = larger_dim.upper() + "BLOCK"
smaller_block = smaller_dim.upper() + "BLOCK"
configs = [
Config({larger_block: 512, "R0_BLOCK": 256, smaller_block: 32}, num_warps=4, num_stages=1),
Config({larger_block: 1024, "R0_BLOCK": 512, smaller_block: 32}, num_warps=4, num_stages=1),
Config({larger_block: 512, "R0_BLOCK": 256, smaller_block: 64}, num_warps=4, num_stages=1),
Config({larger_block: 1024, "R0_BLOCK": 512, smaller_block: 64}, num_warps=8, num_stages=1),
]
#configs = [
# triton_config_with_settings(size_hints, 32, 32),
# triton_config_with_settings(size_hints, 64, 64),
# triton_config_with_settings(size_hints, 256, 16),
# triton_config_with_settings(size_hints, 16, 256),
#]
elif (
disable_pointwise_autotuning(inductor_meta) or tile_hint == TileHint.SQUARE
) and not (
inductor_meta.get("max_autotune")
@ -2878,144 +2863,6 @@ def adapt_config_for_tiling(
)
class ReductionConfigKey:
"""
The part of a reduction config that affects determinism.
"""
def __init__(self, config: Config):
# persistent reduction does not have a RBLOCK, use -1 as a flag
self.r0_block = config.kwargs.get("R0_BLOCK", -1)
self.r1_block = config.kwargs.get("R1_BLOCK", -1)
self.num_warps = config.num_warps
self.num_ctas = config.num_ctas
def __hash__(self) -> int:
return hash((self.r0_block, self.r1_block, self.num_warps, self.num_ctas))
def __eq__(self, other: object) -> bool:
return (
isinstance(other, ReductionConfigKey)
and self.r0_block == other.r0_block
and self.r1_block == other.r1_block
and self.num_warps == other.num_warps
and self.num_ctas == other.num_ctas
)
def filter_reduction_configs_for_determinism(
inductor_meta: dict[str, Any], configs: list[Config]
) -> list[Config]:
"""
Filter configs for reduction so the numerics can be deterministic.
This function groups configs by the fields that affect determinism
- rblock size
- num warps
- num ctas
and returns the most promising group based on heuristics.
Heuristics:
- skip reduction configs with too small RBLOCK
- skip reduction configs with XBLOCK==1 if we are confident it will not perform well
- pick the group with the most configs: autotuning more configs gives a better chance of finding good perf
- if there is a tie, pick the group with second largest RBLOCK
- if there is still a tie, pick the group with second largest num_warps
"""
configs = unique_configs(configs)
assert len(configs) > 0
def _do_filter_due_to_inductor_config():
return (
inductor_meta.get("deterministic", False)
or torch._inductor.config.test_configs.force_filter_reduction_configs
)
if not _do_filter_due_to_inductor_config() or len(configs) == 1:
# no filtering happens if NOT in deterministic mode
return configs
if log.isEnabledFor(logging.DEBUG):
log.debug("reduction configs before filtering:")
for c in configs:
log.debug("%s", c)
log.debug("")
def _has_too_small_rblock(config):
rblock = config.kwargs.get("R0_BLOCK")
# too small RBLOCK is likely to be bad
return rblock is not None and rblock <= 4
def _nonpromising_xblock_1(config):
# kernel like https://gist.github.com/shunting314/0b3281c087e79bc915fe45985ff9d7d5
# without a load/store having contiguous rdim is unlikely to perform well with XBLOCK==1
return config.kwargs["XBLOCK"] == 1 and not inductor_meta.get(
"has_loadstore_with_contiguous_rdim", True
)
newconfigs = [*filter(lambda x: not _has_too_small_rblock(x), configs)]
# accept the filtering only if there are configs left
if len(newconfigs) > 0:
configs = newconfigs
newconfigs = [*filter(lambda x: not _nonpromising_xblock_1(x), configs)]
if len(newconfigs) > 0:
configs = newconfigs
groups: defaultdict[ReductionConfigKey, list[Config]] = defaultdict(
list
) # group configs by RBLOCK, num_warps, num_ctas
for c in configs:
key = ReductionConfigKey(c)
groups[key].append(c)
assert len(groups) > 0
def _pick_group():
grouplist = sorted(groups.items(), key=lambda x: len(x[1]), reverse=True)
max_group_size = len(grouplist[0][1])
grouplist = [*filter(lambda g: len(g[1]) == max_group_size, grouplist)]
assert len(grouplist) > 0
if len(grouplist) == 1:
return grouplist[0][1]
# break tie by R0_BLOCK
grouplist = sorted(grouplist, key=lambda x: x[0].r0_block)
if grouplist[0][0].r0_block != grouplist[-1][0].r0_block:
max_r0_block = grouplist[-1][0].r0_block
grouplist = [*filter(lambda x: x[0].r0_block != max_r0_block, grouplist)]
second_max_r0_block = grouplist[-1][0].r0_block
grouplist = [
*filter(lambda x: x[0].r0_block == second_max_r0_block, grouplist)
]
if len(grouplist) == 1:
return grouplist[0][1]
# break tie by num_warps
grouplist = sorted(grouplist, key=lambda x: x[0].num_warps)
if grouplist[0][0].num_warps != grouplist[-1][0].num_warps:
max_num_warps = grouplist[-1][0].num_warps
grouplist = [*filter(lambda x: x[0].num_warps != max_num_warps, grouplist)]
second_max_num_warps = grouplist[-1][0].num_warps
grouplist = [
*filter(lambda x: x[0].num_warps == second_max_num_warps, grouplist)
]
# there is still a tie, pick the first one
return grouplist[0][1]
configs = _pick_group()
if log.isEnabledFor(logging.DEBUG):
log.debug("reduction configs after filtering:")
for c in configs:
log.debug("%s", c)
log.debug("")
return configs
def reduction(
size_hints,
reduction_hint=False,
@ -3041,7 +2888,6 @@ def reduction(
)
configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
configs = filter_reduction_configs_for_determinism(inductor_meta, configs)
return cached_autotune(
size_hints,
configs=configs,
@ -3089,7 +2935,6 @@ def cooperative_reduction(
# TODO(jansel): add more configs in max_autotune
configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
configs = filter_reduction_configs_for_determinism(inductor_meta, configs)
return cached_autotune(
size_hints,
configs=configs,
@ -3141,12 +2986,15 @@ def _persistent_reduction_configs(
if "y" in size_hints:
pass
# TODO(jansel): we should be able to improve these heuristics
elif reduction_hint == ReductionHint.INNER and rnumel >= 256:
elif reduction_hint == ReductionHint.INNER:
if rnumel > 1024:
configs = configs[:1]
else:
x_block = 8
if xnumel // x_block < 128 or loads_and_stores >= 5:
if xnumel // x_block < 128 or (loads_and_stores >= 5 and rnumel >= 256):
# 5 or more loads/stores means a lot of register pressure.
# rnumel < 256 means no vectorized loads if we split up the r dim,
# so xblock still needs to be larger.
x_block = 1
configs = [
@ -3202,7 +3050,6 @@ def persistent_reduction(
configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
inductor_meta.pop(persistent_reduction_key)
configs = filter_reduction_configs_for_determinism(inductor_meta, configs)
return cached_autotune(
size_hints,
configs,
@ -3240,7 +3087,6 @@ def split_scan(
cfg.kwargs[var] = min_rblock
configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
configs = filter_reduction_configs_for_determinism(inductor_meta, configs)
return cached_autotune(
size_hints,
configs=configs,
@ -3380,14 +3226,14 @@ class GridExpr:
"""Generate code for grid size expressions in launcher"""
inductor_meta: dict[str, Any]
mode: Literal["python", "cpp"] = "python"
mode: Literal["python", "cpp", "python_slow"] = "python"
prefix: list[str] = dataclasses.field(default_factory=list)
x_grid: Union[str, int] = 1
y_grid: Union[str, int] = 1
z_grid: Union[str, int] = 1
def __post_init__(self) -> None:
assert self.mode in ("python", "cpp")
assert self.mode in ("python", "cpp", "python_slow")
def generate(self, meta: dict[str, int]) -> None:
raise NotImplementedError
@ -3403,6 +3249,10 @@ class GridExpr:
# negative integer division is floored
if self.mode == "python":
return f"-(({numel}) // -({block}))"
# This is more generic than above, and works in languages where
# positive integer division is floored/truncated
elif self.mode == "python_slow":
return f"(({numel} + {block} - 1) // ({block}))"
# For cpp code gen
return f"(({numel} + ({block} - 1)) / ({block}))"
@ -3411,7 +3261,7 @@ class GridExpr:
items = self._constant_fold(max, seq)
if len(items) <= 1:
return items[0]
if self.mode == "python":
if self.mode in ("python", "python_slow"):
return f"max({', '.join(map(str, items))})"
return functools.reduce(lambda x, y: f"std::max({x}, {y})", items)
@ -3434,7 +3284,7 @@ class GridExpr:
def assign_tmp(self, name: str, expr: Union[str, int]) -> str:
# Grid functions are one per kernel, so name collisions are fine
if self.mode == "python":
if self.mode in ("python", "python_slow"):
return f"{name} = {expr}"
if self.mode == "cpp":
return f"uint32_t {name} = {expr};"
@ -3444,7 +3294,7 @@ class GridExpr:
def from_meta(
inductor_meta: dict[str, Any],
cfg: Union[Config, dict[str, int]],
mode: Literal["python", "cpp"] = "python",
mode: Literal["python", "cpp", "python_slow"] = "python",
) -> GridExpr:
grid_cls = globals()[inductor_meta["grid_type"]]
assert issubclass(grid_cls, GridExpr)