[Inductor] Refactor "r" reduction prefix to {"r0_", "r1_"}. (#142020)

Preparatory refactor for https://github.com/pytorch/pytorch/pull/137243.

# Feature

This PR changes the `RINDEX` / `"r"` symbol type to `(R0_INDEX, R1_INDEX)` and `("r0_", "r1_")`, respectively. This allows the relevant code to support 2D (often ND) reductions. Unlike the parent PR, this one does not change the tiling algorithm, so `"r1_"` is never used. However, it prepares other parts of the system to handle `"r1_"` once we start using it. This should significantly reduce the chances of hitting merge conflicts, making the parent PR much easier to land.

The only change to the generated triton code is to rename `"rindex"` -> `"r0_index"`, `"RBLOCK"` -> `"R0_BLOCK"`, etc. To maintain compatibilty with existing codegen, this also generates aliases to the old reduction variables like `rindex = r0_index`. If we generated 2D reductions (which this PR will not do), the aliases would be more complicated and would collapse 2D multi-indices to linear indices. See some example kernels in the parent PR.

These aliases can be eliminated by the Triton compiler, and should not impact the final machine code running on the GPU. See the perf testing in the parent PR which confirms the aliases do not impact perf.

# Test plan

The existing CI provides good coverage. This PR modifies the expected code in a few places, renaming reduction variables from `r.*` to `r0_.*`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142020
Approved by: https://github.com/jansel

Co-authored-by: Jason Ansel <jansel@meta.com>
This commit is contained in:
Blaine Burton Rister
2024-12-12 17:22:20 +00:00
committed by PyTorch MergeBot
parent cf538efd0c
commit 520ba556cd
14 changed files with 398 additions and 167 deletions

View File

@ -109,8 +109,8 @@ class TestCoordinateDescentTuner(TestCase):
max_block = TRITON_MAX_BLOCK
self.assertFalse(tuner.value_too_large("XBLOCK", max_block["X"]))
self.assertTrue(tuner.value_too_large("XBLOCK", max_block["X"] * 2))
self.assertFalse(tuner.value_too_large("RBLOCK", max_block["R"]))
self.assertTrue(tuner.value_too_large("RBLOCK", max_block["R"] * 2))
self.assertFalse(tuner.value_too_large("R0_BLOCK", max_block["R0_"]))
self.assertTrue(tuner.value_too_large("R0_BLOCK", max_block["R0_"] * 2))
if __name__ == "__main__":

View File

@ -487,7 +487,7 @@ class PaddingTest(TestCaseBase):
# make sure the load for softmax is aligned
self.assertTrue(
"tl.load(in_ptr0 + (r1 + 30528*x0)" in forward_wrapper,
"tl.load(in_ptr0 + (r0_1 + 30528*x0)" in forward_wrapper,
f"forward_wrapper: {forward_wrapper}",
)

View File

@ -1846,7 +1846,7 @@ class CommonTemplate:
from torch._inductor.runtime.runtime_utils import next_power_of_2
from torch._inductor.runtime.triton_heuristics import triton_config_reduction
size_hints = {"x": 67108864, "r": 8192}
size_hints = {"x": 67108864, "r0_": 8192}
for i in range(4):
size_hints["x"] = next_power_of_2(size_hints["x"])
triton_config_reduction(size_hints, 1, 2048, 1, 8)
@ -12547,8 +12547,8 @@ if HAS_GPU and not TEST_WITH_ASAN:
self.assertExpectedInline(
"\n".join(lines),
"""\
tmp0 = tl.load(in_ptr0 + (x1 + 512*x0 + 262144*r2), rmask, eviction_policy='evict_last', other=0.0)
tmp1 = tl.load(in_ptr1 + (x3 + 262144*r2), rmask, eviction_policy='evict_first', other=0.0)""",
tmp0 = tl.load(in_ptr0 + (x1 + 512*x0 + 262144*r0_2), r0_mask, eviction_policy='evict_last', other=0.0)
tmp1 = tl.load(in_ptr1 + (x3 + 262144*r0_2), r0_mask, eviction_policy='evict_first', other=0.0)""",
)
@config.patch("triton.use_block_ptr", True)
@ -12571,16 +12571,16 @@ if HAS_GPU and not TEST_WITH_ASAN:
self.assertExpectedInline(
"\n".join(lines),
"""\
tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0)
tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[262144, 512], strides=[1, 262144], block_shape=[XBLOCK, RBLOCK], order=[0, 1], offsets=[xoffset, roffset]), boundary_check=[1], padding_option='zero')
tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0)
tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r0_2)), rmask, eviction_policy='evict_last', other=0.0)
tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[262144, 512], strides=[1, 262144], block_shape=[XBLOCK, R0_BLOCK], order=[0, 1], offsets=[xoffset, roffset]), boundary_check=[1], padding_option='zero')
tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r0_2)), rmask, eviction_policy='evict_last', other=0.0)
tmp1 = tl.load(block_ptr0, boundary_check=[1], padding_option='zero', eviction_policy='evict_first')""", # noqa: B950 line too long
)
else:
self.assertExpectedInline(
"\n".join(lines),
"""\
tmp0 = tl.reshape(tl.broadcast_to(tl.load(block_ptr0, boundary_check=[2], padding_option='zero', eviction_policy='evict_last')[:, None, :, :], [(511 + XBLOCK) // 512, ((1) * ((1) <= ((511 + XBLOCK) // 512)) + ((511 + XBLOCK) // 512) * (((511 + XBLOCK) // 512) < (1))), ((512) * ((512) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (512))), RBLOCK]), [XBLOCK, RBLOCK])
tmp0 = tl.reshape(tl.broadcast_to(tl.load(block_ptr0, boundary_check=[2], padding_option='zero', eviction_policy='evict_last')[:, None, :, :], [(511 + XBLOCK) // 512, ((1) * ((1) <= ((511 + XBLOCK) // 512)) + ((511 + XBLOCK) // 512) * (((511 + XBLOCK) // 512) < (1))), ((512) * ((512) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (512))), R0_BLOCK]), [XBLOCK, R0_BLOCK])
tmp1 = tl.load(block_ptr1, boundary_check=[1], padding_option='zero', eviction_policy='evict_first')""", # noqa: B950 line too long
)

View File

@ -51,6 +51,7 @@ from ..utils import (
get_dtype_size,
IndentedBuffer,
Placeholder,
prefix_is_reduction,
sympy_index_symbol,
sympy_product,
sympy_subs,
@ -75,9 +76,7 @@ fusion_log = torch._logging.getArtifactLogger(__name__, "fusion")
pexpr = PythonPrinter().doprint
def prefix_is_reduction(prefix: str) -> bool:
return prefix[0] == "r"
all_prefixes = OrderedSet(["z", "y", "x", "r0_", "r1_"])
@dataclasses.dataclass
@ -359,7 +358,7 @@ class SIMDKernel(Kernel):
self.range_trees: List[IterationRangesRoot] = []
self.range_tree_nodes: Dict[sympy.Symbol, IterationRangesEntry] = {}
self.iter_vars_count = itertools.count()
self.inside_reduction = self.numels["r"] != 1
self.inside_reduction = features.is_reduction()
self.cooperative_reduction: bool = (
override_cooperative_reduction
if override_cooperative_reduction is not None
@ -385,6 +384,12 @@ class SIMDKernel(Kernel):
self.simplify_indexing = simplify_indexing
self.initialize_range_tree(pid_cache)
@property
@cache_on_self
@no_type_check # https://github.com/python/mypy/issues/17184
def num_reduction_dims(self) -> int:
return sum(prefix_is_reduction(prefix) for prefix in self.numels)
def dtype_to_str(self, dtype: torch.dtype) -> str:
raise NotImplementedError
@ -396,25 +401,34 @@ class SIMDKernel(Kernel):
return False
def initialize_range_tree(self, pid_cache):
no_r_dim = not self.inside_reduction or self.numels["r"] == 1
active_prefixes = OrderedSet(
prefix for prefix in all_prefixes if prefix in self.numels
)
no_r_dim = not self.inside_reduction or not self.features.is_reduction()
prefixes = "zyxr"
active_prefixes = prefixes[-len(self.numels) :]
def filtered_index_map(seq, mask) -> Dict[Any, int]:
return {
val: idx for idx, val in enumerate(val for val in seq if val in mask)
}
grid_dims = "xyz"
grid_dims = ["x", "y", "z"]
reduction_dims = ["r0_", "r1_"]
if self.no_x_dim:
tensor_dims = "r"
tensor_dims = reduction_dims
elif no_r_dim:
tensor_dims = "xyz"
tensor_dims = grid_dims
else:
tensor_dims = "xyzr"
tensor_dims = grid_dims + reduction_dims
tensor_dims = "".join(p for p in tensor_dims if p in active_prefixes)
# Filter out unused tensor dims.
# Convert to dicts for O(1) index lookup.
tensor_dim_map = filtered_index_map(tensor_dims, active_prefixes)
grid_dim_map = filtered_index_map(grid_dims, all_prefixes)
for i, prefix in enumerate(active_prefixes):
is_reduction = prefix_is_reduction(prefix)
tensor_dim = tensor_dims.find(prefix) if prefix in tensor_dims else None
grid_dim = None if is_reduction else grid_dims.find(prefix)
tensor_dim = tensor_dim_map.get(prefix)
grid_dim = grid_dim_map.get(prefix)
index = i if grid_dim is None else grid_dim
self.range_trees.append(
IterationRangesRoot(
@ -427,7 +441,7 @@ class SIMDKernel(Kernel):
is_loop=is_reduction and not self.persistent_reduction,
tensor_dim=tensor_dim,
grid_dim=grid_dim,
has_zdim="z" in active_prefixes,
has_zdim="z" in self.numels,
)
)
@ -528,7 +542,7 @@ class SIMDKernel(Kernel):
@contextlib.contextmanager
def ctx():
if self.numels["r"] == 1:
if not self.features.is_reduction():
assert not self.inside_reduction
yield
return
@ -963,7 +977,7 @@ class SIMDKernel(Kernel):
def welford_reduce_fallback(self, dtype, value):
sum_ = ops.reduction(dtype, dtype, "sum", value)
self.inside_reduction = False
rnumel = ops.index_expr(self.numels["r"], dtype)
rnumel = ops.index_expr(self.features.reduction_numel, dtype)
mean = ops.truediv(sum_, rnumel)
self.inside_reduction = True
@ -1572,10 +1586,9 @@ class SIMDScheduling(BaseScheduling):
Create a tiling dict from pointwise and reduction splits.
"""
pw_prefixes = ["z", "y", "x"][-len(pw_tiling) :]
reduction_prefixes = ["r"][: len(reduction_tiling)]
reduction_prefixes = ["r0_", "r1_"][: len(reduction_tiling)]
return immutable_dict(
list(zip(pw_prefixes, pw_tiling))
+ list(zip(reduction_prefixes, reduction_tiling))
[*zip(pw_prefixes, pw_tiling), *zip(reduction_prefixes, reduction_tiling)]
)
@classmethod

View File

@ -55,12 +55,16 @@ from ..runtime.triton_heuristics import (
)
from ..scheduler import BaseSchedulerNode, FusedSchedulerNode, Scheduler, SchedulerNode
from ..utils import (
cache_on_self,
DelayReplaceLine,
get_bounds_index_expr,
get_fused_kernel_name,
get_kernel_metadata,
is_welford_reduction,
Placeholder,
prefix_is_reduction,
sympy_dot,
sympy_product,
sympy_subs,
triton_type,
upcast_compute_type,
@ -87,7 +91,6 @@ from .simd import (
IterationRangesEntry,
IterationRangesRoot,
pexpr,
prefix_is_reduction,
SIMDKernel,
SIMDScheduling,
)
@ -170,16 +173,19 @@ class TritonSymbols:
Stores sympy.Symbol instances and constants associated with triton codegen.
"""
reduction_types = OrderedSet([SymT.R0_INDEX, SymT.R1_INDEX])
block_types = OrderedSet([SymT.XBLOCK, SymT.YBLOCK, SymT.ZBLOCK, *reduction_types])
block_offsets = {
symt: sympy.Symbol(f"{prefix_str[symt]}offset", integer=True, nonnegative=True)
for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.ZBLOCK, SymT.RINDEX]
for symt in block_types
}
block_sizes = {
symt: sympy.Symbol(
f"{prefix_str[symt].upper()}BLOCK", integer=True, positive=True
)
for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.ZBLOCK, SymT.RINDEX]
for symt in block_types
}
@classmethod
@ -213,7 +219,7 @@ class IndexingOptions:
return "tmp" in self.mask_str
def has_rmask(self):
return "rmask" in self.mask_str
return any(str(mask).startswith("r") for mask in self.mask_vars)
@dataclasses.dataclass
@ -357,7 +363,7 @@ class BlockPtrOptions:
if (
not V.kernel.inside_reduction
and len(params.strides) == len(V.kernel.numels) - 1
and V.kernel.numels["r"] != 1
and V.kernel.features.is_reduction()
):
# Need to expand rank by 1 to match rank when self.inside_reduction=True
final_shape.append(sympy.S.One)
@ -376,9 +382,9 @@ class BlockPtrOptions:
def replace_roffset(self, expr: sympy.Expr, replacement: sympy.Expr) -> sympy.Expr:
"""
Replaces instances of roffset with the new expression.
Replaces instances of r0_offset with the new expression.
"""
roffset = TritonSymbols.block_offsets[SymT.RINDEX]
roffset = TritonSymbols.block_offsets[SymT.R0_INDEX]
return sympy_subs(expr, {roffset: replacement})
def format(self, name: str, roffset=True) -> str:
@ -387,7 +393,7 @@ class BlockPtrOptions:
Args:
name: variable name for pointer
roffset: should roffset be included in offsets=..., for use with tl.advance()
roffset: should rn_offset be included in offsets=..., for use with tl.advance()
Returns:
"tl.make_block_ptr(...)"
@ -448,11 +454,11 @@ class BlockPtrOptions:
Codegen string to pass to tl.advance(name, ...).
Advance is the difference between offsets in each loop iteration.
To compute it, we replace roffset with multiples of RBLOCK.
Since we expect roffset to vary in range(0, rnumel, RBLOCK), the first
iteration has roffset=0, while the second has roffset=RBLOCK.
To compute it, we replace roffset with multiples of R0_BLOCK.
Since we expect roffset to vary in range(0, rnumel, R0_BLOCK), the first
iteration has roffset=0, while the second has roffset=R0_BLOCK.
"""
rblock = TritonSymbols.block_sizes[SymT.RINDEX]
rblock = TritonSymbols.block_sizes[SymT.R0_INDEX]
advance = [
(
self.replace_roffset(offset, rblock)
@ -466,7 +472,10 @@ class BlockPtrOptions:
return False # block_ptr can't do indirect indexing
def has_rindex(self) -> bool:
return any(free_symbol_is_type(expr, SymT.RINDEX) for expr in self.block_shape)
return any(
free_symbol_is_type(expr, TritonSymbols.reduction_types)
for expr in self.block_shape
)
def has_rmask(self):
return self.has_rindex()
@ -722,11 +731,14 @@ class TritonCSEVariable(CSEVariable):
for arg in args:
if isinstance(arg, TritonCSEVariable):
self.mask_vars.update(arg.mask_vars)
elif isinstance(arg, sympy.Symbol) and arg.name[0] in "xyr":
elif isinstance(arg, sympy.Symbol):
# most of the time index vars don't need masks associated with them
# however, when index vars are used to compute indices for indirect reads
# those reads should subsequently be masked,
self.mask_vars.update({f"{arg.name[0]}mask"})
for symt in TritonSymbols.block_types:
if symbol_is_type(arg, symt):
self.mask_vars.update({f"{prefix_str[symt]}mask"})
break
def maybe_upcast_float32(convert_output: bool = True):
@ -1543,6 +1555,9 @@ class TritonKernel(SIMDKernel):
self.autotune_hints: OrderedSet[AutotuneHint] = OrderedSet()
self.triton_meta: Optional[Dict[str, Any]] = None
if self.inside_reduction:
self.codegen_reduction_numels(self.body)
if self.cooperative_reduction:
self.init_cooperative_reduction()
@ -1592,12 +1607,24 @@ class TritonKernel(SIMDKernel):
# reduction indexing goes inside a loop
if not tree.is_loop:
self.iteration_ranges_codegen_header(tree, self.body)
if self.inside_reduction and self.range_trees[-1].is_loop:
# workaround for this issue:
# https://gist.github.com/jansel/6527126f781559095c5531f98a4235a7
self.body.writeline(
f"rbase = {self.iteration_ranges_ranges_code(self.range_trees[-1])}"
)
elif self.inside_reduction:
# workaround for this issue:
# https://gist.github.com/jansel/6527126f781559095c5531f98a4235a7
self.body.writeline(
f"{tree.prefix}base = {self.iteration_ranges_ranges_code(tree)}"
)
if self.inside_reduction:
if any(tree.is_loop for tree in self.range_trees):
# If the kernel contains loops, compute rbase.
rn_bases = self._get_reduction_symbols(
"base", integer=True, nonnegative=True
)
rbase = self._flatten_reduction_indices(rn_bases)
self.body.splice(f"rbase = {self.index_to_str(rbase)}")
else:
# For looped reductions, indexing is deferred to the innermost loop.
self.codegen_reduction_indices(self.body)
def need_numel_args(self):
r"""
@ -1644,7 +1671,9 @@ class TritonKernel(SIMDKernel):
mask_vars: OrderedSet[str] = OrderedSet()
for var in index_vars:
assert isinstance(var, sympy.Symbol)
has_rindex = has_rindex or symbol_is_type(var, SymT.RINDEX)
has_rindex = has_rindex or symbol_is_type(
var, TritonSymbols.reduction_types
)
if override_mask:
pass
elif symbol_is_type(var, SymT.TMP):
@ -1664,11 +1693,14 @@ class TritonKernel(SIMDKernel):
):
pass
else:
# var is one of xN, yN or rN
assert symbol_is_type(
var, (SymT.RINDEX, SymT.XBLOCK, SymT.YBLOCK, SymT.ZBLOCK)
), var.name
mask_vars.add(f"{var.name[0]}mask")
# var is one of xN, yN, r0_N or r1_N
prefix_matches = [
prefix_str[symt]
for symt in TritonSymbols.block_types
if symbol_is_type(var, symt)
]
assert len(prefix_matches) == 1, f"Ambiguous type: {var.name}"
mask_vars.add(f"{prefix_matches[0]}mask")
need_dense = (
config.triton.dense_indexing
@ -1835,8 +1867,8 @@ class TritonKernel(SIMDKernel):
range_trees = self.active_range_trees(reorder=True)
# Partition the index into subexpressions pertaining to each range tree.
# For example xindex * 5 + rindex * 3 is partitioned to
# (xindex * 5, rindex * 3).
# For example xindex * 5 + r0_index * 3 is partitioned to
# (xindex * 5, r0_index * 3).
index_subexprs = [
BlockPatternMatcher.get_subexpr_involving_symbol(
index_relative_to_xyr_index, tree.symbol()
@ -1848,7 +1880,7 @@ class TritonKernel(SIMDKernel):
range_symbols = {tree.symbol() for tree in range_trees}
block_params = BlockParameters()
for tree, subexpr in zip(range_trees, index_subexprs):
# Reject mixed terms, e.g. xindex * rindex.
# Reject mixed terms, e.g. xindex * r0_index.
# NB: the zero expression is allowed, for broadcasting.
if len(range_symbols.intersection(subexpr.free_symbols)) > 1:
return None
@ -2847,17 +2879,19 @@ class TritonKernel(SIMDKernel):
):
return
if self.inside_reduction and self.range_trees[-1].is_loop:
if self.cooperative_reduction:
self.body.writeline(
"for roffset in range(rsplit_start, rsplit_end, RBLOCK):"
)
else:
self.body.writeline("for roffset in range(0, rnumel, RBLOCK):")
innermost_tree = self.range_trees[-1]
if self.inside_reduction and innermost_tree.is_loop:
prefix = innermost_tree.prefix
loop_start = "rsplit_start" if self.cooperative_reduction else "0"
loop_end = "rsplit_end" if self.cooperative_reduction else f"{prefix}numel"
self.body.writeline(
f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):"
)
with self.body.indent():
# last range tree is always reduction
self.iteration_ranges_codegen_header(self.range_trees[-1], self.body)
self.codegen_reduction_indices(self.body)
self.body.splice(self.indexing_code)
self.body.splice(self.loads)
self.body.splice(self.compute)
@ -3202,7 +3236,7 @@ class TritonKernel(SIMDKernel):
for tree in self.range_trees:
if tree.is_reduction and self.persistent_reduction:
# RBLOCK for persistent_reduction is defined in codegen_static_numels
# Rn_BLOCK for persistent_reduction is defined in codegen_static_numels
continue
if tree.tensor_dim is None:
continue
@ -3300,7 +3334,7 @@ class TritonKernel(SIMDKernel):
This code stomps on the passed-in values by writing an constant to the top of the kernel.
In a kernel like:
def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, Rn_BLOCK : tl.constexpr):
We would add
xnumel = 4096
@ -3320,7 +3354,7 @@ class TritonKernel(SIMDKernel):
val = self._get_persistent_RBLOCK(tree.numel)
if self.cooperative_reduction:
val = f"{val} // RSPLIT"
code.writeline(f"RBLOCK: tl.constexpr = {val}")
code.writeline(f"{tree.prefix.upper()}BLOCK: tl.constexpr = {val}")
if tree.prefix == "x" and self.no_x_dim:
code.writeline("XBLOCK: tl.constexpr = 1")
@ -3488,6 +3522,74 @@ class TritonKernel(SIMDKernel):
if self._has_constant_mask(tree):
mask_vars.discard(f"{tree.prefix}mask")
@cache_on_self
def get_reduction_prefixes(self) -> List[str]:
return [
prefix_str[symt]
for symt in list(TritonSymbols.reduction_types)[: self.num_reduction_dims]
]
def codegen_reduction_numels(self, buffer) -> None:
"""
Generates code that flattens ND reduction numels, block sizes, etc. into 1D.
"""
# rnumel = r0_numel * ... * r(n-1)_numel
reduction_trees = [tree for tree in self.range_trees if tree.is_reduction]
rnumel = " * ".join(sorted(f"{tree.prefix}numel" for tree in reduction_trees))
buffer.splice(f"rnumel = {self.kexpr(rnumel)}")
# RBLOCK = R0_BLOCK * ... * R(N-1)_BLOCK
rn_blocks = [
TritonSymbols.block_sizes[tree.symt]
for tree in self.range_trees
if tree.is_reduction
]
rblock = sympy_product(rn_blocks)
buffer.splice(f"RBLOCK: tl.constexpr = {self.kexpr(rblock)}")
def _get_reduction_symbols(self, suffix: str, **kwargs) -> List[sympy.Symbol]:
"""
Helper to initialize symbols like rn_numel, rn_base, etc.
"""
rn_prefixes = self.get_reduction_prefixes()
return [sympy.Symbol(f"{prefix}{suffix}", **kwargs) for prefix in rn_prefixes]
@cache_on_self
def _get_reduction_index_coeffs(self) -> List[sympy.Expr]:
"""
Compute coefficients to convert ND reduction indices to linear indices.
For example:
rindex = r0_index * r1_numel * ... * rn_numel + ... + rn_index.
"""
rn_prefixes = self.get_reduction_prefixes()
rn_numels = self._get_reduction_symbols("numel", integer=True, positive=True)
return [
sympy_product(rn_numels[idx + 1 :]) for idx in range(len(rn_prefixes) - 1)
] + [sympy.Integer(1)]
def _flatten_reduction_indices(self, multi_inds: List[sympy.Expr]) -> sympy.Expr:
"""
Compute linear reduction indices from N dimensional ones.
"""
coeffs = self._get_reduction_index_coeffs()
return sympy_dot(coeffs, multi_inds)
def codegen_reduction_indices(self, buffer) -> None:
"""
Generates code that converts ND reduction indices into linear indices.
"""
# Gather relevant numels, indices, etc.
rn_offsets = self._get_reduction_symbols(
"offset", integer=True, nonnegative=True
)
rn_inds = self._get_reduction_symbols("index", integer=True, nonnegative=True)
# Compute roffset and rindex.
roffset = self._flatten_reduction_indices(rn_offsets)
buffer.splice(f"roffset = {self.index_to_str(roffset)}")
rindex = self._flatten_reduction_indices(rn_inds)
buffer.splice(f"rindex = {self.index_to_str(rindex)}")
def iteration_ranges_codegen_header(self, entry, code):
x = entry.prefix
if entry.is_loop:
@ -3774,7 +3876,7 @@ class TritonScheduling(SIMDScheduling):
)
)
if optional_cooperative:
rnumel = kernel.numels["r"]
rnumel = kernel.features.reduction_numel
# for larger sizes non-cooperative gets very slow
if V.graph.sizevars.statically_known_leq(rnumel, 65536):
kernels.append(

View File

@ -490,7 +490,7 @@ class ComboKernel(Kernel):
This code stomps on the passed-in values by writing an constant to the top of the kernel.
In a kernel like:
def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
We would add
xnumel = 4096
@ -525,7 +525,8 @@ class ComboKernel(Kernel):
)
val = next_power_of_2(val)
code.writeline(f"RBLOCK_{num}: tl.constexpr = {val}")
uniquify_block_sizes.append("RBLOCK")
code.writeline(f"R0_BLOCK_{num}: tl.constexpr = {val}")
uniquify_block_sizes.append("R0_BLOCK")
if tree.prefix == "x" and sub_kernel.no_x_dim:
code.writeline(f"XBLOCK_{num}: tl.constexpr = 1")
@ -723,17 +724,18 @@ class ComboKernel(Kernel):
def codegen_blocks(self, code: IndentedBuffer) -> None:
for block in self.block_args:
assert block in [
assert block in {
"XBLOCK",
"YBLOCK",
"RBLOCK",
], f"{block} is not supported without autotuning"
"R0_BLOCK",
}, f"{block} is not supported without autotuning"
if "YBLOCK" in self.block_args:
code.splice(f"XBLOCK: tl.constexpr = {self.block_size_2d}")
code.splice(f"YBLOCK: tl.constexpr = {self.block_size_2d}")
else:
code.splice(f"XBLOCK: tl.constexpr = {self.block_size_1d}")
if "RBLOCK" in self.block_args:
if "R0_BLOCK" in self.block_args:
code.splice(f"R0_BLOCK: tl.constexpr = {self.block_size_reduce}")
code.splice(f"RBLOCK: tl.constexpr = {self.block_size_reduce}")
def add_blockd_to_args(self, argdefs: List[str]) -> List[str]:
@ -753,6 +755,7 @@ class ComboKernel(Kernel):
if self.enable_autotune:
argdefs.extend(block_args)
self.block_args = list(block_names.keys())
return argdefs
def add_numel_to_args(self, argdefs: List[str], signature: List[Any]) -> List[str]:

View File

@ -5,13 +5,12 @@ from typing import Dict
import sympy
from torch._inductor import config
from torch._inductor.codegen.simd import IterationRangesRoot
from torch._inductor.codegen.simd import IterationRangesRoot, prefix_is_reduction
from torch._inductor.codegen.triton import triton_compute_type, TritonKernel
from torch._inductor.runtime.triton_heuristics import split_scan_grid
from torch.utils._sympy.functions import CeilDiv
from ..utils import sympy_product
from .simd import prefix_is_reduction
class TritonSplitScanKernel(TritonKernel):
@ -52,18 +51,17 @@ class TritonSplitScanKernel(TritonKernel):
return False
def initialize_range_tree(self, pid_cache):
prefixes = "yxr"
prefixes = ["y", "x", "r0_"]
assert len(self.numels) <= len(
prefixes
), "z dimension not supported for split scan"
active_prefixes = prefixes[len(prefixes) - len(self.numels) :]
grid_dims = "rxy"
grid_dims = {"r0_": 0, "x": 1, "y": 2}
for prefix in active_prefixes:
numel = self.numels[prefix]
is_reduction = prefix == "r"
tensor_dim = 0 if is_reduction else None
grid_dim = grid_dims.find(prefix)
tensor_dim = 0 if prefix_is_reduction(prefix) else None
grid_dim = grid_dims[prefix]
self.range_trees.append(
IterationRangesRoot(
f"{prefix}index",

View File

@ -1003,12 +1003,12 @@ class Reduction(Loops):
def inner_fn_args(self) -> Sequence[Sequence[Expr]]:
index = self._index(self.ranges)
rindex = self._index(self.reduction_ranges, SymT.RINDEX)
rindex = self._index(self.reduction_ranges, SymT.R0_INDEX)
return (index, rindex)
def inner_fn_free_unbacked_symbols(self) -> Set[Symbol]:
index = self._index(self.ranges)
rindex = self._index(self.reduction_ranges, SymT.RINDEX)
rindex = self._index(self.reduction_ranges, SymT.R0_INDEX)
return extract_free_unbacked_symbols(self.inner_fn, index, rindex)
def constant_to_device(self, device: torch.device) -> IRNode:
@ -1973,13 +1973,13 @@ class Scan(Loops):
def inner_fn_args(self) -> Sequence[Sequence[_IntLike]]:
index = self._index(self.ranges)
rindex = self._index(self.scan_ranges, SymT.RINDEX)
rindex = self._index(self.scan_ranges, SymT.R0_INDEX)
idx = self.reindex(index, rindex)
return (idx,)
def inner_fn_free_unbacked_symbols(self) -> Set[Symbol]:
index = self._index(self.ranges)
rindex = self._index(self.scan_ranges, SymT.RINDEX)
rindex = self._index(self.scan_ranges, SymT.R0_INDEX)
idx = self.reindex(index, rindex)
return extract_free_unbacked_symbols(self.inner_fn, idx)
@ -2170,13 +2170,13 @@ class Sort(Loops):
def inner_fn_args(self) -> Sequence[Sequence[Expr]]:
index = self._index(self.ranges)
rindex = self._index(self.sort_ranges, SymT.RINDEX)
rindex = self._index(self.sort_ranges, SymT.R0_INDEX)
idx = self.reindex(index, rindex)
return (idx,)
def inner_fn_free_unbacked_symbols(self) -> Set[Symbol]:
index = self._index(self.ranges)
rindex = self._index(self.sort_ranges, SymT.RINDEX)
rindex = self._index(self.sort_ranges, SymT.R0_INDEX)
idx = self.reindex(index, rindex)
return extract_free_unbacked_symbols(self.inner_fn, idx)

View File

@ -85,10 +85,11 @@ class CoordescTuner:
"XBLOCK",
"YBLOCK",
"ZBLOCK",
# NOTE: we should not tune RBLOCK for persistent reduction.
# NOTE: we should not tune R0_BLOCK for persistent reduction.
# We rely on the fact that persistent reduction's triton.Config
# does not have the RBLOCK field to guarantee that.
"RBLOCK",
# does not have the R0_BLOCK field to guarantee that.
"R0_BLOCK",
"R1_BLOCK",
# the following 3 are for mm
"BLOCK_M",
"BLOCK_N",
@ -101,8 +102,10 @@ class CoordescTuner:
return out
def value_too_large(self, name: str, val: int) -> bool:
if name in {"XBLOCK", "YBLOCK", "ZBLOCK", "RBLOCK"}:
return val > self.get_config_max(name[0].lower())
block_suffix = "BLOCK"
if name.endswith(block_suffix):
prefix = name.strip(block_suffix).lower()
return val > self.get_config_max(prefix)
if name == "num_warps":
return val > self.get_warpsmax()
@ -245,7 +248,7 @@ class CoordescTuner:
for name in tunable_fields:
cur_val = get_field(best_config, name)
# some kernel don't have RBLOCK/YBLOCK/ZBLOCK. So cur_val may be None
# some kernel don't have R0_BLOCK/YBLOCK/ZBLOCK. So cur_val may be None
if cur_val is None:
continue

View File

@ -14,7 +14,8 @@ TRITON_MAX_BLOCK = {
"X": 4096,
"Y": 1024,
"Z": 1024,
"R": 4096 * 16, # * 16 is multi-kernel only
"R0_": 4096 * 16, # * 16 is multi-kernel only
"R1_": 2048 * 16, # * 16 is multi-kernel only
}
TRITON_MAX_RSPLIT = 64

View File

@ -20,6 +20,7 @@ from typing import Any, Container, Dict, List, Optional, Set, Tuple
import torch
from ..triton_bundler import TritonBundler
from ..utils import prefix_is_reduction
from .autotune_cache import AutotuneCache
from .benchmarking import benchmarker
from .coordinate_descent_tuner import CoordescTuner
@ -104,6 +105,12 @@ except AttributeError: # Compile workers only have a mock version of torch
log = logging.getLogger(__name__)
def get_total_reduction_numel(numels: Dict[str, int]) -> int:
return conditional_product(
*[numel for prefix, numel in numels.items() if prefix_is_reduction(prefix)]
)
def autotune_hints_to_configs(
hints: Set[AutotuneHint],
size_hints,
@ -333,23 +340,28 @@ class CachingAutotuner(KernelInterface):
for triton_config, compiled_binary in zip(
self.configs, compiled_binaries
):
assert len(self.size_hints) == 2
assert len(self.size_hints) >= 2
xblock = triton_config.kwargs.get("XBLOCK", 1)
rblock = triton_config.kwargs["RBLOCK"]
reduction_kwargs = [
kwarg for kwarg in triton_config.kwargs if kwarg.startswith("R")
]
rblocks = [
triton_config.kwargs[kwarg] for kwarg in reduction_kwargs
]
total_block = (self.size_hints["x"] + xblock - 1) // xblock
nreg = getattr(compiled_binary, "n_regs", None)
if nreg is None:
continue
# make sure rblock is not too small
if rblock <= 64:
# make sure rblocks are not too small
if conditional_product(*rblocks) <= 64:
continue
# each SM of A100 has 65536 32-bit registers. To maximize
# the theoretical occupancy, we need run 2048 threads on each
# SM. So each thread should use no more than 65536 / 2048
# = 32 registers. In cases where occupancy matters, and each
# thread uses too many registers, reduce RBLOCK to reduce
# thread uses too many registers, reduce R0_BLOCK to reduce
# the register usage.
# For kernel https://gist.github.com/shunting314/e4cccc031fe30d378b9b23c08c238cbd
# from PLBartForCausalLM, latency improve from
@ -385,12 +397,19 @@ class CachingAutotuner(KernelInterface):
# no need to improve occupancy
continue
new_config = copy.deepcopy(triton_config)
new_config.kwargs["RBLOCK"] = rblock // 2
# Reduce the largest Rn_BLOCK by a factor of 2.
largest_rkwarg: str = max(
reduction_kwargs, key=triton_config.kwargs.__getitem__
)
new_config.kwargs[largest_rkwarg] //= 2
if new_config in seen_configs:
continue
seen_configs.add(new_config)
log.debug(
"Dynamically scale down RBLOCK from TritonConfig(%s) and get a new TritonConfig(%s)",
"Dynamically scale down %s from TritonConfig(%s) and get a new TritonConfig(%s)",
largest_rkwarg,
triton_config,
new_config,
)
@ -998,8 +1017,8 @@ class CachingAutotuner(KernelInterface):
assert not (
self.heuristic_type == HeuristicType.PERSISTENT_REDUCTION
and "RBLOCK" in launcher.config.kwargs
), "Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have RBLOCK"
and "R0_BLOCK" in launcher.config.kwargs
), "Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have R0_BLOCK"
start_time = time.time_ns()
best_config = self.coordesc_tuner.autotune(
benchmark_one_config, launcher.config, None
@ -1362,6 +1381,20 @@ def check_config(cfg, *, xnumel=None, ynumel=None, znumel=None):
)
def check_max_block(cfg: Dict[str, int]):
"""
Check that block sizes are within the maximum allowed.
"""
for var, val in cfg.items():
block_suffix = "BLOCK"
if block_suffix in var:
prefix = var.removesuffix(block_suffix)
max_block = TRITON_MAX_BLOCK[prefix]
assert (
val <= max_block
), f"'{var}' too large. Maximum: {max_block}. Actual: {val}."
def _num_warps(num_warps, max_num_warps=8, min_num_warps=2, register_intensive=False):
# On AMD GPU each warp has 64 lanes which is double the size on NV GPU,
# therefore using half the number of warps here correspondingly.
@ -1486,82 +1519,143 @@ def triton_config(
cfg["YBLOCK"] = y
if z:
cfg["ZBLOCK"] = z
assert x <= TRITON_MAX_BLOCK["X"], f"increase TRITON_MAX_BLOCK['X'] to {x}"
check_max_block(cfg)
check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel)
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
def _get_nd_reduction_numels(r: int, size_hints: Dict[str, int]) -> Dict[str, int]:
"""
Converts a linear reduction numel to ND, in row major order.
This order is often desirable as it presents opportunities to coalesce memory
accesses.
For example, if r = 64 and size_hints = [32,32], this function returns [32, 2].
This unraveling works because both r and size_hints are powers of 2.
"""
# Shrink r to size_hints.
r = min(r, get_total_reduction_numel(size_hints))
num_reduction_dims = len(
[prefix for prefix in size_hints if prefix_is_reduction(prefix)]
)
remaining = r
rnumels = {}
for idx in range(num_reduction_dims - 1, -1, -1):
prefix = f"r{idx}_"
max_size = min(size_hints[prefix], TRITON_MAX_BLOCK[prefix.upper()])
dim = min(max_size, remaining)
assert (
remaining % dim == 0
), f"Expected dimension '{dim}' to divide remaining size '{remaining}'"
rnumels[prefix] = dim
remaining //= dim
# Sanity check the results.
final_numel = conditional_product(*rnumels.values())
assert (
r == final_numel
), f"Expected ND reduction size ({rnumels}) to have {r} elements."
assert all(
rnumels[prefix] <= size_hints[prefix] for prefix in rnumels
), f"rnumels exceed size_hints. {rnumels} > {size_hints}"
return rnumels
def triton_config_reduction(
size_hints, x, r, num_stages=1, num_warps=None, register_intensive=False
size_hints,
x: int,
r: int,
num_stages=1,
num_warps=None,
register_intensive=False,
) -> Config:
"""
Construct a reduction triton config with some adjustment heuristics
based on size_hints. Size_hints is a tuple of numels in each tile
dimension and will be rounded up to the nearest power of 2.
"""
target = conditional_product(x, r)
if conditional_product(*size_hints.values()) < target:
target //= 8
# Convert the linear reduction numel into a multi-dimensional block.
rnumels = _get_nd_reduction_numels(r, size_hints)
# shrink sizes to size hints
x = min(x, size_hints["x"])
r = min(r, size_hints["r"])
def total_numel() -> int:
return conditional_product(x, *rnumels.values())
target = total_numel()
if conditional_product(*size_hints.values()) < target:
target //= 8
# if we are below original block size, scale up where we can
while x < size_hints["x"] and conditional_product(x, r) < target:
while x < size_hints["x"] and total_numel() < target:
x *= 2
while r < size_hints["r"] and conditional_product(x, r) < target:
r *= 2
for prefix in sorted(rnumels):
while rnumels[prefix] < size_hints[prefix] and total_numel() < target:
rnumels[prefix] *= 2
if num_warps is None:
num_warps = conditional_product(x, r) // 128
num_warps = total_numel() // 128
num_warps = _num_warps(
num_warps, max_num_warps=16, register_intensive=register_intensive
)
x, _num_blocks = _check_max_grid_x(size_hints, x, num_warps)
while conditional_product(x, r) > target:
if r == 1:
break
r = r // 2
for prefix in sorted(rnumels):
while total_numel() > target:
if rnumels[prefix] == 1:
break
rnumels[prefix] //= 2
cfg = {"XBLOCK": x, "RBLOCK": r}
cfg = _get_config({"x": x, **rnumels})
check_max_block(cfg)
check_config(cfg, xnumel=size_hints["x"])
assert x <= TRITON_MAX_BLOCK["X"], f"increase TRITON_MAX_BLOCK['X'] to {x}"
assert r <= TRITON_MAX_BLOCK["R"], f"increase TRITON_MAX_BLOCK['r'] to {r}"
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
def _get_config(numels: Dict[str, int]) -> Dict[str, int]:
"""
Convert numels ("x", "r0_", etc.) to block sizes ("XBLOCK", "R0_BLOCK"), etc.
"""
return {prefix.upper() + "BLOCK": numel for prefix, numel in numels.items()}
def triton_config_tiled_reduction(size_hints, x, y, r, num_stages=1):
"""
Construct a tile reduction triton config with some adjustment
heuristics based on size_hints. Size_hints is a tuple of numels in
each tile dimension and will be rounded up to the nearest power of 2.
"""
target = conditional_product(x, y, r)
if conditional_product(*size_hints) < target:
target //= 8
# Convert the linear reduction numel into a multi-dimensional block.
rnumels = _get_nd_reduction_numels(r, size_hints)
# shrink sizes to size hints
x = min(x, size_hints["x"])
y = min(y, size_hints["y"])
r = min(r, size_hints["r"])
def total_numel() -> int:
return conditional_product(x, y, *rnumels.values())
target = total_numel()
if conditional_product(*size_hints.values()) < target:
target //= 8
# if we are below original block size, scale up where we can
while x < size_hints["x"] and conditional_product(x, y, r) < target:
while x < size_hints["x"] and total_numel() < target:
x *= 2
while r < size_hints["r"] and conditional_product(x, y, r) < target:
r *= 2
while y < size_hints["y"] and conditional_product(x, y, r) < target:
for prefix in sorted(rnumels):
while rnumels[prefix] < size_hints[prefix] and total_numel() < target:
rnumels[prefix] *= 2
while y < size_hints[1] and total_numel() < target:
y *= 2
cfg = {"XBLOCK": x, "YBLOCK": y, "RBLOCK": r}
num_warps = _num_warps(conditional_product(x, y, r) // 256, min_num_warps=1)
check_config(cfg, xnumel=size_hints["x"], ynumel=size_hints["y"])
assert r <= TRITON_MAX_BLOCK["R"], f"increase TRITON_MAX_BLOCK['r'] to {r}"
cfg = _get_config({"x": x, "y": y, **rnumels})
num_warps = _num_warps(total_numel() // 256, min_num_warps=1)
check_config(cfg, xnumel=size_hints[0], ynumel=size_hints[1])
check_max_block(cfg)
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
@ -1657,35 +1751,36 @@ def _reduction_configs(
*, size_hints: Dict[str, int], inductor_meta: Dict[str, Any]
) -> List[Config]:
reduction_hint = inductor_meta.get("reduction_hint", None)
assert len(size_hints) == 2
rnumel = size_hints["r"]
# Convert reductions to 1D, to simplify heuristics.
rnumel = get_total_reduction_numel(size_hints)
register_intensive = False
MAX_RBLOCK = 2048
MAX_R0_BLOCK = 2048
if (
size_hints["x"] >= 1024
and inductor_meta.get("num_load", 0) + inductor_meta.get("num_reduction", 0)
>= 10
):
# A heuristics to reduce RBLOCK if a kernel potentially need many registers.
# A heuristics to reduce R0_BLOCK if a kernel potentially need many registers.
# Consider load and reduction since load need move data into registers and
# reduction needs an accumulator.
#
# The magic numbers are a bit arbitrary.
#
# We cannot rely on dynamically scaling down RBLOCK later, since sometimes
# We cannot rely on dynamically scaling down R0_BLOCK later, since sometimes
# triton makes it to use less registers with worse perf. Check:
# https://github.com/pytorch/pytorch/issues/126463
#
# The heuristic is a very simple one since registers can be reused. But
# hopefully it can be a good enough indicator.
MAX_RBLOCK = 1024
MAX_R0_BLOCK = 1024
register_intensive = True
contiguous_config = triton_config_reduction(
size_hints,
1,
(rnumel if 256 <= rnumel < MAX_RBLOCK else MAX_RBLOCK),
rnumel if 256 <= rnumel < MAX_R0_BLOCK else MAX_R0_BLOCK,
register_intensive=register_intensive,
)
outer_config = triton_config_reduction(
@ -1694,7 +1789,7 @@ def _reduction_configs(
tiny_config = triton_config_reduction(
size_hints,
2 * (256 // rnumel) if rnumel <= 256 else 1,
min(rnumel, MAX_RBLOCK),
min(rnumel, MAX_R0_BLOCK),
register_intensive=register_intensive,
)
if inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise"):
@ -1713,7 +1808,7 @@ def _reduction_configs(
tiny_config,
triton_config_reduction(size_hints, 64, 64),
triton_config_reduction(size_hints, 8, 512),
# halve the XBLOCK/RBLOCK compared to outer_config
# halve the XBLOCK/Rn_BLOCK compared to outer_config
# TODO: this may only be beneficial when each iteration of the reduction
# is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
triton_config_reduction(size_hints, 64, 4, num_warps=8),
@ -1734,8 +1829,6 @@ def reduction(
size_hints["x"] = 1
assert triton_meta is not None
if len(size_hints) != 2:
raise NotImplementedError(f"size_hints: {size_hints}")
configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
return cached_autotune(
@ -1759,7 +1852,12 @@ def cooperative_reduction(
inductor_meta["reduction_hint"] = reduction_hint
if inductor_meta.get("no_x_dim"):
size_hints["x"] = 1
xnumel, rnumel = size_hints["x"], size_hints["r"]
# Cooperative reductions currently only support a single reduction dimension.
assert (
len(size_hints) == 2
), "Cooperative reductions don't support tiling reduction dims"
xnumel, rnumel = size_hints["x"], size_hints["r0_"]
# TODO(jansel): we should base target on the SM count of the local GPU
target = 64
@ -1768,11 +1866,12 @@ def cooperative_reduction(
assert split <= TRITON_MAX_RSPLIT
if inductor_meta["persistent_reduction"]:
configs = _persistent_reduction_configs(
{"x": xnumel, "r": rnumel // split}, reduction_hint, inductor_meta
{"x": xnumel, "r0_": rnumel // split}, reduction_hint, inductor_meta
)
else:
configs = _reduction_configs(
size_hints={"x": xnumel, "r": rnumel // split}, inductor_meta=inductor_meta
size_hints={"x": xnumel, "r0_": rnumel // split},
inductor_meta=inductor_meta,
)
for config in configs:
config.kwargs["RSPLIT"] = split
@ -1793,7 +1892,8 @@ def _persistent_reduction_configs(
reduction_hint=False,
inductor_meta=None,
):
xnumel, rnumel = size_hints["x"], size_hints["r"]
xnumel = size_hints["x"]
rnumel = get_total_reduction_numel(size_hints)
configs = [
triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True)
@ -1809,12 +1909,16 @@ def _persistent_reduction_configs(
elif reduction_hint == ReductionHint.OUTER_TINY:
configs = [
triton_config_reduction(
size_hints, 2 * (256 // rnumel) if rnumel <= 256 else 1, rnumel
size_hints,
2 * (256 // rnumel) if rnumel <= 256 else 1,
rnumel,
)
]
for c in configs:
# we don't need RBLOCK for persistent reduction
c.kwargs.pop("RBLOCK")
# we don't need Rn_BLOCK for persistent reduction
for prefix in size_hints:
if prefix_is_reduction(prefix):
c.kwargs.pop(f"{prefix.upper()}BLOCK")
if disable_pointwise_autotuning(inductor_meta):
configs = configs[:1]
@ -1865,11 +1969,12 @@ def split_scan(
configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
# Fixup configs to enforce the minimum RBLOCK size
# Fixup configs to enforce the minimum Rn_BLOCK size
min_rblock = inductor_meta.get("min_split_scan_rblock", 256)
for cfg in configs:
if cfg.kwargs["RBLOCK"] < min_rblock:
cfg.kwargs["RBLOCK"] = min_rblock
for var in list(cfg.kwargs.keys()):
if var.startswith("R") and cfg.kwargs[var] < min_rblock:
cfg.kwargs[var] = min_rblock
return cached_autotune(
size_hints,
@ -2028,7 +2133,7 @@ def maybe_cooperative_reduction_grid(xnumel):
def split_scan_grid(xnumel, rnumel):
def grid_fn(meta):
assert meta.get("XBLOCK", 1) == 1
return (ceildiv(rnumel, meta.get("RBLOCK", 1)), xnumel, 1)
return (ceildiv(rnumel, meta.get("R0_BLOCK", 1)), xnumel, 1)
grid_fn_str = f"split_scan_grid({xnumel}, {rnumel})"
setattr(grid_fn, "grid_fn_str", grid_fn_str) # noqa: B010

View File

@ -262,7 +262,7 @@ class TritonTemplateKernel(TritonKernel):
super().__init__(
{
"x": numel,
"r": sympy.S.One,
"r0_": sympy.S.One,
},
features=SIMDKernelFeatures([], numel),
)

View File

@ -681,6 +681,10 @@ def get_bounds_index_expr(index):
return ValueRanges.unknown()
def prefix_is_reduction(prefix: str) -> bool:
return prefix[0] == "r"
def sympy_index_symbol_with_prefix(prefix: SymT, idx: int) -> sympy.Symbol:
"""
Used to generate an integer-nonnegative symbol.

View File

@ -13,7 +13,7 @@ in this file and seeing what breaks.
"""
from enum import auto, Enum
from typing import Sequence, Union
from typing import Iterable, Union
import sympy
@ -36,9 +36,10 @@ class SymT(Enum):
# Inductor: An indexing variable i0 in loops IR which ranges over non-reduced
# dim in the loop
INDEX = auto()
# Inductor: A reduction indexing r0 variable in loops IR which ranges over
# reduced dim in the loop
RINDEX = auto()
# Inductor: A reduction indexing (r0, r1) variables in loops IR which ranges over
# reduced dim(s) in the loop
R0_INDEX = auto()
R1_INDEX = auto()
# Inductor: In templated kernels torch._inductor.kernel, we have a hook to
# store the final output and append epilogue fusions. To do this, we must
# know what the indexes the outputs range over. NB: These will also
@ -67,7 +68,8 @@ prefix_str = {
SymT.TMP: "tmp",
SymT.PRECOMPUTED_SIZE: "ps",
SymT.INDEX: "i",
SymT.RINDEX: "r",
SymT.R0_INDEX: "r0_",
SymT.R1_INDEX: "r1_",
SymT.TEMPLATE_INDEX: "idx",
SymT.XBLOCK: "x",
SymT.YBLOCK: "y",
@ -85,7 +87,7 @@ def make_symbol(prefix: SymT, idx: int, **kwargs) -> sympy.Symbol:
# This type is a little wider than it should be, because free_symbols says
# that it contains Basic, rather than Symbol
def symbol_is_type(sym: sympy.Basic, prefix: Union[SymT, Sequence[SymT]]) -> bool:
def symbol_is_type(sym: sympy.Basic, prefix: Union[SymT, Iterable[SymT]]) -> bool:
assert isinstance(sym, sympy.Symbol)
name_str = sym.name.lower() # Match capitalized names like XBLOCK, RBLOCK
if isinstance(prefix, SymT):
@ -94,5 +96,5 @@ def symbol_is_type(sym: sympy.Basic, prefix: Union[SymT, Sequence[SymT]]) -> boo
return name_str.startswith(tuple(prefix_str[p] for p in prefix))
def free_symbol_is_type(e: sympy.Expr, prefix: SymT) -> bool:
def free_symbol_is_type(e: sympy.Expr, prefix: Union[SymT, Iterable[SymT]]) -> bool:
return any(symbol_is_type(v, prefix) for v in e.free_symbols)