mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Inductor] Refactor "r" reduction prefix to {"r0_", "r1_"}. (#142020)
Preparatory refactor for https://github.com/pytorch/pytorch/pull/137243. # Feature This PR changes the `RINDEX` / `"r"` symbol type to `(R0_INDEX, R1_INDEX)` and `("r0_", "r1_")`, respectively. This allows the relevant code to support 2D (often ND) reductions. Unlike the parent PR, this one does not change the tiling algorithm, so `"r1_"` is never used. However, it prepares other parts of the system to handle `"r1_"` once we start using it. This should significantly reduce the chances of hitting merge conflicts, making the parent PR much easier to land. The only change to the generated triton code is to rename `"rindex"` -> `"r0_index"`, `"RBLOCK"` -> `"R0_BLOCK"`, etc. To maintain compatibilty with existing codegen, this also generates aliases to the old reduction variables like `rindex = r0_index`. If we generated 2D reductions (which this PR will not do), the aliases would be more complicated and would collapse 2D multi-indices to linear indices. See some example kernels in the parent PR. These aliases can be eliminated by the Triton compiler, and should not impact the final machine code running on the GPU. See the perf testing in the parent PR which confirms the aliases do not impact perf. # Test plan The existing CI provides good coverage. This PR modifies the expected code in a few places, renaming reduction variables from `r.*` to `r0_.*`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/142020 Approved by: https://github.com/jansel Co-authored-by: Jason Ansel <jansel@meta.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
cf538efd0c
commit
520ba556cd
@ -109,8 +109,8 @@ class TestCoordinateDescentTuner(TestCase):
|
||||
max_block = TRITON_MAX_BLOCK
|
||||
self.assertFalse(tuner.value_too_large("XBLOCK", max_block["X"]))
|
||||
self.assertTrue(tuner.value_too_large("XBLOCK", max_block["X"] * 2))
|
||||
self.assertFalse(tuner.value_too_large("RBLOCK", max_block["R"]))
|
||||
self.assertTrue(tuner.value_too_large("RBLOCK", max_block["R"] * 2))
|
||||
self.assertFalse(tuner.value_too_large("R0_BLOCK", max_block["R0_"]))
|
||||
self.assertTrue(tuner.value_too_large("R0_BLOCK", max_block["R0_"] * 2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -487,7 +487,7 @@ class PaddingTest(TestCaseBase):
|
||||
|
||||
# make sure the load for softmax is aligned
|
||||
self.assertTrue(
|
||||
"tl.load(in_ptr0 + (r1 + 30528*x0)" in forward_wrapper,
|
||||
"tl.load(in_ptr0 + (r0_1 + 30528*x0)" in forward_wrapper,
|
||||
f"forward_wrapper: {forward_wrapper}",
|
||||
)
|
||||
|
||||
|
@ -1846,7 +1846,7 @@ class CommonTemplate:
|
||||
from torch._inductor.runtime.runtime_utils import next_power_of_2
|
||||
from torch._inductor.runtime.triton_heuristics import triton_config_reduction
|
||||
|
||||
size_hints = {"x": 67108864, "r": 8192}
|
||||
size_hints = {"x": 67108864, "r0_": 8192}
|
||||
for i in range(4):
|
||||
size_hints["x"] = next_power_of_2(size_hints["x"])
|
||||
triton_config_reduction(size_hints, 1, 2048, 1, 8)
|
||||
@ -12547,8 +12547,8 @@ if HAS_GPU and not TEST_WITH_ASAN:
|
||||
self.assertExpectedInline(
|
||||
"\n".join(lines),
|
||||
"""\
|
||||
tmp0 = tl.load(in_ptr0 + (x1 + 512*x0 + 262144*r2), rmask, eviction_policy='evict_last', other=0.0)
|
||||
tmp1 = tl.load(in_ptr1 + (x3 + 262144*r2), rmask, eviction_policy='evict_first', other=0.0)""",
|
||||
tmp0 = tl.load(in_ptr0 + (x1 + 512*x0 + 262144*r0_2), r0_mask, eviction_policy='evict_last', other=0.0)
|
||||
tmp1 = tl.load(in_ptr1 + (x3 + 262144*r0_2), r0_mask, eviction_policy='evict_first', other=0.0)""",
|
||||
)
|
||||
|
||||
@config.patch("triton.use_block_ptr", True)
|
||||
@ -12571,16 +12571,16 @@ if HAS_GPU and not TEST_WITH_ASAN:
|
||||
self.assertExpectedInline(
|
||||
"\n".join(lines),
|
||||
"""\
|
||||
tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0)
|
||||
tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[262144, 512], strides=[1, 262144], block_shape=[XBLOCK, RBLOCK], order=[0, 1], offsets=[xoffset, roffset]), boundary_check=[1], padding_option='zero')
|
||||
tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0)
|
||||
tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r0_2)), rmask, eviction_policy='evict_last', other=0.0)
|
||||
tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[262144, 512], strides=[1, 262144], block_shape=[XBLOCK, R0_BLOCK], order=[0, 1], offsets=[xoffset, roffset]), boundary_check=[1], padding_option='zero')
|
||||
tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r0_2)), rmask, eviction_policy='evict_last', other=0.0)
|
||||
tmp1 = tl.load(block_ptr0, boundary_check=[1], padding_option='zero', eviction_policy='evict_first')""", # noqa: B950 line too long
|
||||
)
|
||||
else:
|
||||
self.assertExpectedInline(
|
||||
"\n".join(lines),
|
||||
"""\
|
||||
tmp0 = tl.reshape(tl.broadcast_to(tl.load(block_ptr0, boundary_check=[2], padding_option='zero', eviction_policy='evict_last')[:, None, :, :], [(511 + XBLOCK) // 512, ((1) * ((1) <= ((511 + XBLOCK) // 512)) + ((511 + XBLOCK) // 512) * (((511 + XBLOCK) // 512) < (1))), ((512) * ((512) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (512))), RBLOCK]), [XBLOCK, RBLOCK])
|
||||
tmp0 = tl.reshape(tl.broadcast_to(tl.load(block_ptr0, boundary_check=[2], padding_option='zero', eviction_policy='evict_last')[:, None, :, :], [(511 + XBLOCK) // 512, ((1) * ((1) <= ((511 + XBLOCK) // 512)) + ((511 + XBLOCK) // 512) * (((511 + XBLOCK) // 512) < (1))), ((512) * ((512) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (512))), R0_BLOCK]), [XBLOCK, R0_BLOCK])
|
||||
tmp1 = tl.load(block_ptr1, boundary_check=[1], padding_option='zero', eviction_policy='evict_first')""", # noqa: B950 line too long
|
||||
)
|
||||
|
||||
|
@ -51,6 +51,7 @@ from ..utils import (
|
||||
get_dtype_size,
|
||||
IndentedBuffer,
|
||||
Placeholder,
|
||||
prefix_is_reduction,
|
||||
sympy_index_symbol,
|
||||
sympy_product,
|
||||
sympy_subs,
|
||||
@ -75,9 +76,7 @@ fusion_log = torch._logging.getArtifactLogger(__name__, "fusion")
|
||||
|
||||
pexpr = PythonPrinter().doprint
|
||||
|
||||
|
||||
def prefix_is_reduction(prefix: str) -> bool:
|
||||
return prefix[0] == "r"
|
||||
all_prefixes = OrderedSet(["z", "y", "x", "r0_", "r1_"])
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@ -359,7 +358,7 @@ class SIMDKernel(Kernel):
|
||||
self.range_trees: List[IterationRangesRoot] = []
|
||||
self.range_tree_nodes: Dict[sympy.Symbol, IterationRangesEntry] = {}
|
||||
self.iter_vars_count = itertools.count()
|
||||
self.inside_reduction = self.numels["r"] != 1
|
||||
self.inside_reduction = features.is_reduction()
|
||||
self.cooperative_reduction: bool = (
|
||||
override_cooperative_reduction
|
||||
if override_cooperative_reduction is not None
|
||||
@ -385,6 +384,12 @@ class SIMDKernel(Kernel):
|
||||
self.simplify_indexing = simplify_indexing
|
||||
self.initialize_range_tree(pid_cache)
|
||||
|
||||
@property
|
||||
@cache_on_self
|
||||
@no_type_check # https://github.com/python/mypy/issues/17184
|
||||
def num_reduction_dims(self) -> int:
|
||||
return sum(prefix_is_reduction(prefix) for prefix in self.numels)
|
||||
|
||||
def dtype_to_str(self, dtype: torch.dtype) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -396,25 +401,34 @@ class SIMDKernel(Kernel):
|
||||
return False
|
||||
|
||||
def initialize_range_tree(self, pid_cache):
|
||||
no_r_dim = not self.inside_reduction or self.numels["r"] == 1
|
||||
active_prefixes = OrderedSet(
|
||||
prefix for prefix in all_prefixes if prefix in self.numels
|
||||
)
|
||||
no_r_dim = not self.inside_reduction or not self.features.is_reduction()
|
||||
|
||||
prefixes = "zyxr"
|
||||
active_prefixes = prefixes[-len(self.numels) :]
|
||||
def filtered_index_map(seq, mask) -> Dict[Any, int]:
|
||||
return {
|
||||
val: idx for idx, val in enumerate(val for val in seq if val in mask)
|
||||
}
|
||||
|
||||
grid_dims = "xyz"
|
||||
grid_dims = ["x", "y", "z"]
|
||||
reduction_dims = ["r0_", "r1_"]
|
||||
if self.no_x_dim:
|
||||
tensor_dims = "r"
|
||||
tensor_dims = reduction_dims
|
||||
elif no_r_dim:
|
||||
tensor_dims = "xyz"
|
||||
tensor_dims = grid_dims
|
||||
else:
|
||||
tensor_dims = "xyzr"
|
||||
tensor_dims = grid_dims + reduction_dims
|
||||
|
||||
tensor_dims = "".join(p for p in tensor_dims if p in active_prefixes)
|
||||
# Filter out unused tensor dims.
|
||||
# Convert to dicts for O(1) index lookup.
|
||||
tensor_dim_map = filtered_index_map(tensor_dims, active_prefixes)
|
||||
grid_dim_map = filtered_index_map(grid_dims, all_prefixes)
|
||||
|
||||
for i, prefix in enumerate(active_prefixes):
|
||||
is_reduction = prefix_is_reduction(prefix)
|
||||
tensor_dim = tensor_dims.find(prefix) if prefix in tensor_dims else None
|
||||
grid_dim = None if is_reduction else grid_dims.find(prefix)
|
||||
tensor_dim = tensor_dim_map.get(prefix)
|
||||
grid_dim = grid_dim_map.get(prefix)
|
||||
index = i if grid_dim is None else grid_dim
|
||||
self.range_trees.append(
|
||||
IterationRangesRoot(
|
||||
@ -427,7 +441,7 @@ class SIMDKernel(Kernel):
|
||||
is_loop=is_reduction and not self.persistent_reduction,
|
||||
tensor_dim=tensor_dim,
|
||||
grid_dim=grid_dim,
|
||||
has_zdim="z" in active_prefixes,
|
||||
has_zdim="z" in self.numels,
|
||||
)
|
||||
)
|
||||
|
||||
@ -528,7 +542,7 @@ class SIMDKernel(Kernel):
|
||||
|
||||
@contextlib.contextmanager
|
||||
def ctx():
|
||||
if self.numels["r"] == 1:
|
||||
if not self.features.is_reduction():
|
||||
assert not self.inside_reduction
|
||||
yield
|
||||
return
|
||||
@ -963,7 +977,7 @@ class SIMDKernel(Kernel):
|
||||
def welford_reduce_fallback(self, dtype, value):
|
||||
sum_ = ops.reduction(dtype, dtype, "sum", value)
|
||||
self.inside_reduction = False
|
||||
rnumel = ops.index_expr(self.numels["r"], dtype)
|
||||
rnumel = ops.index_expr(self.features.reduction_numel, dtype)
|
||||
mean = ops.truediv(sum_, rnumel)
|
||||
|
||||
self.inside_reduction = True
|
||||
@ -1572,10 +1586,9 @@ class SIMDScheduling(BaseScheduling):
|
||||
Create a tiling dict from pointwise and reduction splits.
|
||||
"""
|
||||
pw_prefixes = ["z", "y", "x"][-len(pw_tiling) :]
|
||||
reduction_prefixes = ["r"][: len(reduction_tiling)]
|
||||
reduction_prefixes = ["r0_", "r1_"][: len(reduction_tiling)]
|
||||
return immutable_dict(
|
||||
list(zip(pw_prefixes, pw_tiling))
|
||||
+ list(zip(reduction_prefixes, reduction_tiling))
|
||||
[*zip(pw_prefixes, pw_tiling), *zip(reduction_prefixes, reduction_tiling)]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -55,12 +55,16 @@ from ..runtime.triton_heuristics import (
|
||||
)
|
||||
from ..scheduler import BaseSchedulerNode, FusedSchedulerNode, Scheduler, SchedulerNode
|
||||
from ..utils import (
|
||||
cache_on_self,
|
||||
DelayReplaceLine,
|
||||
get_bounds_index_expr,
|
||||
get_fused_kernel_name,
|
||||
get_kernel_metadata,
|
||||
is_welford_reduction,
|
||||
Placeholder,
|
||||
prefix_is_reduction,
|
||||
sympy_dot,
|
||||
sympy_product,
|
||||
sympy_subs,
|
||||
triton_type,
|
||||
upcast_compute_type,
|
||||
@ -87,7 +91,6 @@ from .simd import (
|
||||
IterationRangesEntry,
|
||||
IterationRangesRoot,
|
||||
pexpr,
|
||||
prefix_is_reduction,
|
||||
SIMDKernel,
|
||||
SIMDScheduling,
|
||||
)
|
||||
@ -170,16 +173,19 @@ class TritonSymbols:
|
||||
Stores sympy.Symbol instances and constants associated with triton codegen.
|
||||
"""
|
||||
|
||||
reduction_types = OrderedSet([SymT.R0_INDEX, SymT.R1_INDEX])
|
||||
block_types = OrderedSet([SymT.XBLOCK, SymT.YBLOCK, SymT.ZBLOCK, *reduction_types])
|
||||
|
||||
block_offsets = {
|
||||
symt: sympy.Symbol(f"{prefix_str[symt]}offset", integer=True, nonnegative=True)
|
||||
for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.ZBLOCK, SymT.RINDEX]
|
||||
for symt in block_types
|
||||
}
|
||||
|
||||
block_sizes = {
|
||||
symt: sympy.Symbol(
|
||||
f"{prefix_str[symt].upper()}BLOCK", integer=True, positive=True
|
||||
)
|
||||
for symt in [SymT.XBLOCK, SymT.YBLOCK, SymT.ZBLOCK, SymT.RINDEX]
|
||||
for symt in block_types
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@ -213,7 +219,7 @@ class IndexingOptions:
|
||||
return "tmp" in self.mask_str
|
||||
|
||||
def has_rmask(self):
|
||||
return "rmask" in self.mask_str
|
||||
return any(str(mask).startswith("r") for mask in self.mask_vars)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@ -357,7 +363,7 @@ class BlockPtrOptions:
|
||||
if (
|
||||
not V.kernel.inside_reduction
|
||||
and len(params.strides) == len(V.kernel.numels) - 1
|
||||
and V.kernel.numels["r"] != 1
|
||||
and V.kernel.features.is_reduction()
|
||||
):
|
||||
# Need to expand rank by 1 to match rank when self.inside_reduction=True
|
||||
final_shape.append(sympy.S.One)
|
||||
@ -376,9 +382,9 @@ class BlockPtrOptions:
|
||||
|
||||
def replace_roffset(self, expr: sympy.Expr, replacement: sympy.Expr) -> sympy.Expr:
|
||||
"""
|
||||
Replaces instances of roffset with the new expression.
|
||||
Replaces instances of r0_offset with the new expression.
|
||||
"""
|
||||
roffset = TritonSymbols.block_offsets[SymT.RINDEX]
|
||||
roffset = TritonSymbols.block_offsets[SymT.R0_INDEX]
|
||||
return sympy_subs(expr, {roffset: replacement})
|
||||
|
||||
def format(self, name: str, roffset=True) -> str:
|
||||
@ -387,7 +393,7 @@ class BlockPtrOptions:
|
||||
|
||||
Args:
|
||||
name: variable name for pointer
|
||||
roffset: should roffset be included in offsets=..., for use with tl.advance()
|
||||
roffset: should rn_offset be included in offsets=..., for use with tl.advance()
|
||||
|
||||
Returns:
|
||||
"tl.make_block_ptr(...)"
|
||||
@ -448,11 +454,11 @@ class BlockPtrOptions:
|
||||
Codegen string to pass to tl.advance(name, ...).
|
||||
|
||||
Advance is the difference between offsets in each loop iteration.
|
||||
To compute it, we replace roffset with multiples of RBLOCK.
|
||||
Since we expect roffset to vary in range(0, rnumel, RBLOCK), the first
|
||||
iteration has roffset=0, while the second has roffset=RBLOCK.
|
||||
To compute it, we replace roffset with multiples of R0_BLOCK.
|
||||
Since we expect roffset to vary in range(0, rnumel, R0_BLOCK), the first
|
||||
iteration has roffset=0, while the second has roffset=R0_BLOCK.
|
||||
"""
|
||||
rblock = TritonSymbols.block_sizes[SymT.RINDEX]
|
||||
rblock = TritonSymbols.block_sizes[SymT.R0_INDEX]
|
||||
advance = [
|
||||
(
|
||||
self.replace_roffset(offset, rblock)
|
||||
@ -466,7 +472,10 @@ class BlockPtrOptions:
|
||||
return False # block_ptr can't do indirect indexing
|
||||
|
||||
def has_rindex(self) -> bool:
|
||||
return any(free_symbol_is_type(expr, SymT.RINDEX) for expr in self.block_shape)
|
||||
return any(
|
||||
free_symbol_is_type(expr, TritonSymbols.reduction_types)
|
||||
for expr in self.block_shape
|
||||
)
|
||||
|
||||
def has_rmask(self):
|
||||
return self.has_rindex()
|
||||
@ -722,11 +731,14 @@ class TritonCSEVariable(CSEVariable):
|
||||
for arg in args:
|
||||
if isinstance(arg, TritonCSEVariable):
|
||||
self.mask_vars.update(arg.mask_vars)
|
||||
elif isinstance(arg, sympy.Symbol) and arg.name[0] in "xyr":
|
||||
elif isinstance(arg, sympy.Symbol):
|
||||
# most of the time index vars don't need masks associated with them
|
||||
# however, when index vars are used to compute indices for indirect reads
|
||||
# those reads should subsequently be masked,
|
||||
self.mask_vars.update({f"{arg.name[0]}mask"})
|
||||
for symt in TritonSymbols.block_types:
|
||||
if symbol_is_type(arg, symt):
|
||||
self.mask_vars.update({f"{prefix_str[symt]}mask"})
|
||||
break
|
||||
|
||||
|
||||
def maybe_upcast_float32(convert_output: bool = True):
|
||||
@ -1543,6 +1555,9 @@ class TritonKernel(SIMDKernel):
|
||||
self.autotune_hints: OrderedSet[AutotuneHint] = OrderedSet()
|
||||
self.triton_meta: Optional[Dict[str, Any]] = None
|
||||
|
||||
if self.inside_reduction:
|
||||
self.codegen_reduction_numels(self.body)
|
||||
|
||||
if self.cooperative_reduction:
|
||||
self.init_cooperative_reduction()
|
||||
|
||||
@ -1592,12 +1607,24 @@ class TritonKernel(SIMDKernel):
|
||||
# reduction indexing goes inside a loop
|
||||
if not tree.is_loop:
|
||||
self.iteration_ranges_codegen_header(tree, self.body)
|
||||
if self.inside_reduction and self.range_trees[-1].is_loop:
|
||||
# workaround for this issue:
|
||||
# https://gist.github.com/jansel/6527126f781559095c5531f98a4235a7
|
||||
self.body.writeline(
|
||||
f"rbase = {self.iteration_ranges_ranges_code(self.range_trees[-1])}"
|
||||
)
|
||||
elif self.inside_reduction:
|
||||
# workaround for this issue:
|
||||
# https://gist.github.com/jansel/6527126f781559095c5531f98a4235a7
|
||||
self.body.writeline(
|
||||
f"{tree.prefix}base = {self.iteration_ranges_ranges_code(tree)}"
|
||||
)
|
||||
|
||||
if self.inside_reduction:
|
||||
if any(tree.is_loop for tree in self.range_trees):
|
||||
# If the kernel contains loops, compute rbase.
|
||||
rn_bases = self._get_reduction_symbols(
|
||||
"base", integer=True, nonnegative=True
|
||||
)
|
||||
rbase = self._flatten_reduction_indices(rn_bases)
|
||||
self.body.splice(f"rbase = {self.index_to_str(rbase)}")
|
||||
else:
|
||||
# For looped reductions, indexing is deferred to the innermost loop.
|
||||
self.codegen_reduction_indices(self.body)
|
||||
|
||||
def need_numel_args(self):
|
||||
r"""
|
||||
@ -1644,7 +1671,9 @@ class TritonKernel(SIMDKernel):
|
||||
mask_vars: OrderedSet[str] = OrderedSet()
|
||||
for var in index_vars:
|
||||
assert isinstance(var, sympy.Symbol)
|
||||
has_rindex = has_rindex or symbol_is_type(var, SymT.RINDEX)
|
||||
has_rindex = has_rindex or symbol_is_type(
|
||||
var, TritonSymbols.reduction_types
|
||||
)
|
||||
if override_mask:
|
||||
pass
|
||||
elif symbol_is_type(var, SymT.TMP):
|
||||
@ -1664,11 +1693,14 @@ class TritonKernel(SIMDKernel):
|
||||
):
|
||||
pass
|
||||
else:
|
||||
# var is one of xN, yN or rN
|
||||
assert symbol_is_type(
|
||||
var, (SymT.RINDEX, SymT.XBLOCK, SymT.YBLOCK, SymT.ZBLOCK)
|
||||
), var.name
|
||||
mask_vars.add(f"{var.name[0]}mask")
|
||||
# var is one of xN, yN, r0_N or r1_N
|
||||
prefix_matches = [
|
||||
prefix_str[symt]
|
||||
for symt in TritonSymbols.block_types
|
||||
if symbol_is_type(var, symt)
|
||||
]
|
||||
assert len(prefix_matches) == 1, f"Ambiguous type: {var.name}"
|
||||
mask_vars.add(f"{prefix_matches[0]}mask")
|
||||
|
||||
need_dense = (
|
||||
config.triton.dense_indexing
|
||||
@ -1835,8 +1867,8 @@ class TritonKernel(SIMDKernel):
|
||||
range_trees = self.active_range_trees(reorder=True)
|
||||
|
||||
# Partition the index into subexpressions pertaining to each range tree.
|
||||
# For example xindex * 5 + rindex * 3 is partitioned to
|
||||
# (xindex * 5, rindex * 3).
|
||||
# For example xindex * 5 + r0_index * 3 is partitioned to
|
||||
# (xindex * 5, r0_index * 3).
|
||||
index_subexprs = [
|
||||
BlockPatternMatcher.get_subexpr_involving_symbol(
|
||||
index_relative_to_xyr_index, tree.symbol()
|
||||
@ -1848,7 +1880,7 @@ class TritonKernel(SIMDKernel):
|
||||
range_symbols = {tree.symbol() for tree in range_trees}
|
||||
block_params = BlockParameters()
|
||||
for tree, subexpr in zip(range_trees, index_subexprs):
|
||||
# Reject mixed terms, e.g. xindex * rindex.
|
||||
# Reject mixed terms, e.g. xindex * r0_index.
|
||||
# NB: the zero expression is allowed, for broadcasting.
|
||||
if len(range_symbols.intersection(subexpr.free_symbols)) > 1:
|
||||
return None
|
||||
@ -2847,17 +2879,19 @@ class TritonKernel(SIMDKernel):
|
||||
):
|
||||
return
|
||||
|
||||
if self.inside_reduction and self.range_trees[-1].is_loop:
|
||||
if self.cooperative_reduction:
|
||||
self.body.writeline(
|
||||
"for roffset in range(rsplit_start, rsplit_end, RBLOCK):"
|
||||
)
|
||||
else:
|
||||
self.body.writeline("for roffset in range(0, rnumel, RBLOCK):")
|
||||
innermost_tree = self.range_trees[-1]
|
||||
if self.inside_reduction and innermost_tree.is_loop:
|
||||
prefix = innermost_tree.prefix
|
||||
loop_start = "rsplit_start" if self.cooperative_reduction else "0"
|
||||
loop_end = "rsplit_end" if self.cooperative_reduction else f"{prefix}numel"
|
||||
self.body.writeline(
|
||||
f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):"
|
||||
)
|
||||
|
||||
with self.body.indent():
|
||||
# last range tree is always reduction
|
||||
self.iteration_ranges_codegen_header(self.range_trees[-1], self.body)
|
||||
self.codegen_reduction_indices(self.body)
|
||||
self.body.splice(self.indexing_code)
|
||||
self.body.splice(self.loads)
|
||||
self.body.splice(self.compute)
|
||||
@ -3202,7 +3236,7 @@ class TritonKernel(SIMDKernel):
|
||||
|
||||
for tree in self.range_trees:
|
||||
if tree.is_reduction and self.persistent_reduction:
|
||||
# RBLOCK for persistent_reduction is defined in codegen_static_numels
|
||||
# Rn_BLOCK for persistent_reduction is defined in codegen_static_numels
|
||||
continue
|
||||
if tree.tensor_dim is None:
|
||||
continue
|
||||
@ -3300,7 +3334,7 @@ class TritonKernel(SIMDKernel):
|
||||
This code stomps on the passed-in values by writing an constant to the top of the kernel.
|
||||
|
||||
In a kernel like:
|
||||
def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
|
||||
def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, Rn_BLOCK : tl.constexpr):
|
||||
|
||||
We would add
|
||||
xnumel = 4096
|
||||
@ -3320,7 +3354,7 @@ class TritonKernel(SIMDKernel):
|
||||
val = self._get_persistent_RBLOCK(tree.numel)
|
||||
if self.cooperative_reduction:
|
||||
val = f"{val} // RSPLIT"
|
||||
code.writeline(f"RBLOCK: tl.constexpr = {val}")
|
||||
code.writeline(f"{tree.prefix.upper()}BLOCK: tl.constexpr = {val}")
|
||||
|
||||
if tree.prefix == "x" and self.no_x_dim:
|
||||
code.writeline("XBLOCK: tl.constexpr = 1")
|
||||
@ -3488,6 +3522,74 @@ class TritonKernel(SIMDKernel):
|
||||
if self._has_constant_mask(tree):
|
||||
mask_vars.discard(f"{tree.prefix}mask")
|
||||
|
||||
@cache_on_self
|
||||
def get_reduction_prefixes(self) -> List[str]:
|
||||
return [
|
||||
prefix_str[symt]
|
||||
for symt in list(TritonSymbols.reduction_types)[: self.num_reduction_dims]
|
||||
]
|
||||
|
||||
def codegen_reduction_numels(self, buffer) -> None:
|
||||
"""
|
||||
Generates code that flattens ND reduction numels, block sizes, etc. into 1D.
|
||||
"""
|
||||
# rnumel = r0_numel * ... * r(n-1)_numel
|
||||
reduction_trees = [tree for tree in self.range_trees if tree.is_reduction]
|
||||
rnumel = " * ".join(sorted(f"{tree.prefix}numel" for tree in reduction_trees))
|
||||
buffer.splice(f"rnumel = {self.kexpr(rnumel)}")
|
||||
|
||||
# RBLOCK = R0_BLOCK * ... * R(N-1)_BLOCK
|
||||
rn_blocks = [
|
||||
TritonSymbols.block_sizes[tree.symt]
|
||||
for tree in self.range_trees
|
||||
if tree.is_reduction
|
||||
]
|
||||
rblock = sympy_product(rn_blocks)
|
||||
buffer.splice(f"RBLOCK: tl.constexpr = {self.kexpr(rblock)}")
|
||||
|
||||
def _get_reduction_symbols(self, suffix: str, **kwargs) -> List[sympy.Symbol]:
|
||||
"""
|
||||
Helper to initialize symbols like rn_numel, rn_base, etc.
|
||||
"""
|
||||
rn_prefixes = self.get_reduction_prefixes()
|
||||
return [sympy.Symbol(f"{prefix}{suffix}", **kwargs) for prefix in rn_prefixes]
|
||||
|
||||
@cache_on_self
|
||||
def _get_reduction_index_coeffs(self) -> List[sympy.Expr]:
|
||||
"""
|
||||
Compute coefficients to convert ND reduction indices to linear indices.
|
||||
For example:
|
||||
rindex = r0_index * r1_numel * ... * rn_numel + ... + rn_index.
|
||||
"""
|
||||
rn_prefixes = self.get_reduction_prefixes()
|
||||
rn_numels = self._get_reduction_symbols("numel", integer=True, positive=True)
|
||||
return [
|
||||
sympy_product(rn_numels[idx + 1 :]) for idx in range(len(rn_prefixes) - 1)
|
||||
] + [sympy.Integer(1)]
|
||||
|
||||
def _flatten_reduction_indices(self, multi_inds: List[sympy.Expr]) -> sympy.Expr:
|
||||
"""
|
||||
Compute linear reduction indices from N dimensional ones.
|
||||
"""
|
||||
coeffs = self._get_reduction_index_coeffs()
|
||||
return sympy_dot(coeffs, multi_inds)
|
||||
|
||||
def codegen_reduction_indices(self, buffer) -> None:
|
||||
"""
|
||||
Generates code that converts ND reduction indices into linear indices.
|
||||
"""
|
||||
# Gather relevant numels, indices, etc.
|
||||
rn_offsets = self._get_reduction_symbols(
|
||||
"offset", integer=True, nonnegative=True
|
||||
)
|
||||
rn_inds = self._get_reduction_symbols("index", integer=True, nonnegative=True)
|
||||
|
||||
# Compute roffset and rindex.
|
||||
roffset = self._flatten_reduction_indices(rn_offsets)
|
||||
buffer.splice(f"roffset = {self.index_to_str(roffset)}")
|
||||
rindex = self._flatten_reduction_indices(rn_inds)
|
||||
buffer.splice(f"rindex = {self.index_to_str(rindex)}")
|
||||
|
||||
def iteration_ranges_codegen_header(self, entry, code):
|
||||
x = entry.prefix
|
||||
if entry.is_loop:
|
||||
@ -3774,7 +3876,7 @@ class TritonScheduling(SIMDScheduling):
|
||||
)
|
||||
)
|
||||
if optional_cooperative:
|
||||
rnumel = kernel.numels["r"]
|
||||
rnumel = kernel.features.reduction_numel
|
||||
# for larger sizes non-cooperative gets very slow
|
||||
if V.graph.sizevars.statically_known_leq(rnumel, 65536):
|
||||
kernels.append(
|
||||
|
@ -490,7 +490,7 @@ class ComboKernel(Kernel):
|
||||
This code stomps on the passed-in values by writing an constant to the top of the kernel.
|
||||
|
||||
In a kernel like:
|
||||
def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
|
||||
def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
|
||||
|
||||
We would add
|
||||
xnumel = 4096
|
||||
@ -525,7 +525,8 @@ class ComboKernel(Kernel):
|
||||
)
|
||||
val = next_power_of_2(val)
|
||||
code.writeline(f"RBLOCK_{num}: tl.constexpr = {val}")
|
||||
uniquify_block_sizes.append("RBLOCK")
|
||||
code.writeline(f"R0_BLOCK_{num}: tl.constexpr = {val}")
|
||||
uniquify_block_sizes.append("R0_BLOCK")
|
||||
|
||||
if tree.prefix == "x" and sub_kernel.no_x_dim:
|
||||
code.writeline(f"XBLOCK_{num}: tl.constexpr = 1")
|
||||
@ -723,17 +724,18 @@ class ComboKernel(Kernel):
|
||||
|
||||
def codegen_blocks(self, code: IndentedBuffer) -> None:
|
||||
for block in self.block_args:
|
||||
assert block in [
|
||||
assert block in {
|
||||
"XBLOCK",
|
||||
"YBLOCK",
|
||||
"RBLOCK",
|
||||
], f"{block} is not supported without autotuning"
|
||||
"R0_BLOCK",
|
||||
}, f"{block} is not supported without autotuning"
|
||||
if "YBLOCK" in self.block_args:
|
||||
code.splice(f"XBLOCK: tl.constexpr = {self.block_size_2d}")
|
||||
code.splice(f"YBLOCK: tl.constexpr = {self.block_size_2d}")
|
||||
else:
|
||||
code.splice(f"XBLOCK: tl.constexpr = {self.block_size_1d}")
|
||||
if "RBLOCK" in self.block_args:
|
||||
if "R0_BLOCK" in self.block_args:
|
||||
code.splice(f"R0_BLOCK: tl.constexpr = {self.block_size_reduce}")
|
||||
code.splice(f"RBLOCK: tl.constexpr = {self.block_size_reduce}")
|
||||
|
||||
def add_blockd_to_args(self, argdefs: List[str]) -> List[str]:
|
||||
@ -753,6 +755,7 @@ class ComboKernel(Kernel):
|
||||
if self.enable_autotune:
|
||||
argdefs.extend(block_args)
|
||||
self.block_args = list(block_names.keys())
|
||||
|
||||
return argdefs
|
||||
|
||||
def add_numel_to_args(self, argdefs: List[str], signature: List[Any]) -> List[str]:
|
||||
|
@ -5,13 +5,12 @@ from typing import Dict
|
||||
import sympy
|
||||
|
||||
from torch._inductor import config
|
||||
from torch._inductor.codegen.simd import IterationRangesRoot
|
||||
from torch._inductor.codegen.simd import IterationRangesRoot, prefix_is_reduction
|
||||
from torch._inductor.codegen.triton import triton_compute_type, TritonKernel
|
||||
from torch._inductor.runtime.triton_heuristics import split_scan_grid
|
||||
from torch.utils._sympy.functions import CeilDiv
|
||||
|
||||
from ..utils import sympy_product
|
||||
from .simd import prefix_is_reduction
|
||||
|
||||
|
||||
class TritonSplitScanKernel(TritonKernel):
|
||||
@ -52,18 +51,17 @@ class TritonSplitScanKernel(TritonKernel):
|
||||
return False
|
||||
|
||||
def initialize_range_tree(self, pid_cache):
|
||||
prefixes = "yxr"
|
||||
prefixes = ["y", "x", "r0_"]
|
||||
assert len(self.numels) <= len(
|
||||
prefixes
|
||||
), "z dimension not supported for split scan"
|
||||
active_prefixes = prefixes[len(prefixes) - len(self.numels) :]
|
||||
|
||||
grid_dims = "rxy"
|
||||
grid_dims = {"r0_": 0, "x": 1, "y": 2}
|
||||
for prefix in active_prefixes:
|
||||
numel = self.numels[prefix]
|
||||
is_reduction = prefix == "r"
|
||||
tensor_dim = 0 if is_reduction else None
|
||||
grid_dim = grid_dims.find(prefix)
|
||||
tensor_dim = 0 if prefix_is_reduction(prefix) else None
|
||||
grid_dim = grid_dims[prefix]
|
||||
self.range_trees.append(
|
||||
IterationRangesRoot(
|
||||
f"{prefix}index",
|
||||
|
@ -1003,12 +1003,12 @@ class Reduction(Loops):
|
||||
|
||||
def inner_fn_args(self) -> Sequence[Sequence[Expr]]:
|
||||
index = self._index(self.ranges)
|
||||
rindex = self._index(self.reduction_ranges, SymT.RINDEX)
|
||||
rindex = self._index(self.reduction_ranges, SymT.R0_INDEX)
|
||||
return (index, rindex)
|
||||
|
||||
def inner_fn_free_unbacked_symbols(self) -> Set[Symbol]:
|
||||
index = self._index(self.ranges)
|
||||
rindex = self._index(self.reduction_ranges, SymT.RINDEX)
|
||||
rindex = self._index(self.reduction_ranges, SymT.R0_INDEX)
|
||||
return extract_free_unbacked_symbols(self.inner_fn, index, rindex)
|
||||
|
||||
def constant_to_device(self, device: torch.device) -> IRNode:
|
||||
@ -1973,13 +1973,13 @@ class Scan(Loops):
|
||||
|
||||
def inner_fn_args(self) -> Sequence[Sequence[_IntLike]]:
|
||||
index = self._index(self.ranges)
|
||||
rindex = self._index(self.scan_ranges, SymT.RINDEX)
|
||||
rindex = self._index(self.scan_ranges, SymT.R0_INDEX)
|
||||
idx = self.reindex(index, rindex)
|
||||
return (idx,)
|
||||
|
||||
def inner_fn_free_unbacked_symbols(self) -> Set[Symbol]:
|
||||
index = self._index(self.ranges)
|
||||
rindex = self._index(self.scan_ranges, SymT.RINDEX)
|
||||
rindex = self._index(self.scan_ranges, SymT.R0_INDEX)
|
||||
idx = self.reindex(index, rindex)
|
||||
return extract_free_unbacked_symbols(self.inner_fn, idx)
|
||||
|
||||
@ -2170,13 +2170,13 @@ class Sort(Loops):
|
||||
|
||||
def inner_fn_args(self) -> Sequence[Sequence[Expr]]:
|
||||
index = self._index(self.ranges)
|
||||
rindex = self._index(self.sort_ranges, SymT.RINDEX)
|
||||
rindex = self._index(self.sort_ranges, SymT.R0_INDEX)
|
||||
idx = self.reindex(index, rindex)
|
||||
return (idx,)
|
||||
|
||||
def inner_fn_free_unbacked_symbols(self) -> Set[Symbol]:
|
||||
index = self._index(self.ranges)
|
||||
rindex = self._index(self.sort_ranges, SymT.RINDEX)
|
||||
rindex = self._index(self.sort_ranges, SymT.R0_INDEX)
|
||||
idx = self.reindex(index, rindex)
|
||||
return extract_free_unbacked_symbols(self.inner_fn, idx)
|
||||
|
||||
|
@ -85,10 +85,11 @@ class CoordescTuner:
|
||||
"XBLOCK",
|
||||
"YBLOCK",
|
||||
"ZBLOCK",
|
||||
# NOTE: we should not tune RBLOCK for persistent reduction.
|
||||
# NOTE: we should not tune R0_BLOCK for persistent reduction.
|
||||
# We rely on the fact that persistent reduction's triton.Config
|
||||
# does not have the RBLOCK field to guarantee that.
|
||||
"RBLOCK",
|
||||
# does not have the R0_BLOCK field to guarantee that.
|
||||
"R0_BLOCK",
|
||||
"R1_BLOCK",
|
||||
# the following 3 are for mm
|
||||
"BLOCK_M",
|
||||
"BLOCK_N",
|
||||
@ -101,8 +102,10 @@ class CoordescTuner:
|
||||
return out
|
||||
|
||||
def value_too_large(self, name: str, val: int) -> bool:
|
||||
if name in {"XBLOCK", "YBLOCK", "ZBLOCK", "RBLOCK"}:
|
||||
return val > self.get_config_max(name[0].lower())
|
||||
block_suffix = "BLOCK"
|
||||
if name.endswith(block_suffix):
|
||||
prefix = name.strip(block_suffix).lower()
|
||||
return val > self.get_config_max(prefix)
|
||||
if name == "num_warps":
|
||||
return val > self.get_warpsmax()
|
||||
|
||||
@ -245,7 +248,7 @@ class CoordescTuner:
|
||||
|
||||
for name in tunable_fields:
|
||||
cur_val = get_field(best_config, name)
|
||||
# some kernel don't have RBLOCK/YBLOCK/ZBLOCK. So cur_val may be None
|
||||
# some kernel don't have R0_BLOCK/YBLOCK/ZBLOCK. So cur_val may be None
|
||||
if cur_val is None:
|
||||
continue
|
||||
|
||||
|
@ -14,7 +14,8 @@ TRITON_MAX_BLOCK = {
|
||||
"X": 4096,
|
||||
"Y": 1024,
|
||||
"Z": 1024,
|
||||
"R": 4096 * 16, # * 16 is multi-kernel only
|
||||
"R0_": 4096 * 16, # * 16 is multi-kernel only
|
||||
"R1_": 2048 * 16, # * 16 is multi-kernel only
|
||||
}
|
||||
TRITON_MAX_RSPLIT = 64
|
||||
|
||||
|
@ -20,6 +20,7 @@ from typing import Any, Container, Dict, List, Optional, Set, Tuple
|
||||
import torch
|
||||
|
||||
from ..triton_bundler import TritonBundler
|
||||
from ..utils import prefix_is_reduction
|
||||
from .autotune_cache import AutotuneCache
|
||||
from .benchmarking import benchmarker
|
||||
from .coordinate_descent_tuner import CoordescTuner
|
||||
@ -104,6 +105,12 @@ except AttributeError: # Compile workers only have a mock version of torch
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_total_reduction_numel(numels: Dict[str, int]) -> int:
|
||||
return conditional_product(
|
||||
*[numel for prefix, numel in numels.items() if prefix_is_reduction(prefix)]
|
||||
)
|
||||
|
||||
|
||||
def autotune_hints_to_configs(
|
||||
hints: Set[AutotuneHint],
|
||||
size_hints,
|
||||
@ -333,23 +340,28 @@ class CachingAutotuner(KernelInterface):
|
||||
for triton_config, compiled_binary in zip(
|
||||
self.configs, compiled_binaries
|
||||
):
|
||||
assert len(self.size_hints) == 2
|
||||
assert len(self.size_hints) >= 2
|
||||
xblock = triton_config.kwargs.get("XBLOCK", 1)
|
||||
rblock = triton_config.kwargs["RBLOCK"]
|
||||
reduction_kwargs = [
|
||||
kwarg for kwarg in triton_config.kwargs if kwarg.startswith("R")
|
||||
]
|
||||
rblocks = [
|
||||
triton_config.kwargs[kwarg] for kwarg in reduction_kwargs
|
||||
]
|
||||
total_block = (self.size_hints["x"] + xblock - 1) // xblock
|
||||
nreg = getattr(compiled_binary, "n_regs", None)
|
||||
if nreg is None:
|
||||
continue
|
||||
|
||||
# make sure rblock is not too small
|
||||
if rblock <= 64:
|
||||
# make sure rblocks are not too small
|
||||
if conditional_product(*rblocks) <= 64:
|
||||
continue
|
||||
|
||||
# each SM of A100 has 65536 32-bit registers. To maximize
|
||||
# the theoretical occupancy, we need run 2048 threads on each
|
||||
# SM. So each thread should use no more than 65536 / 2048
|
||||
# = 32 registers. In cases where occupancy matters, and each
|
||||
# thread uses too many registers, reduce RBLOCK to reduce
|
||||
# thread uses too many registers, reduce R0_BLOCK to reduce
|
||||
# the register usage.
|
||||
# For kernel https://gist.github.com/shunting314/e4cccc031fe30d378b9b23c08c238cbd
|
||||
# from PLBartForCausalLM, latency improve from
|
||||
@ -385,12 +397,19 @@ class CachingAutotuner(KernelInterface):
|
||||
# no need to improve occupancy
|
||||
continue
|
||||
new_config = copy.deepcopy(triton_config)
|
||||
new_config.kwargs["RBLOCK"] = rblock // 2
|
||||
|
||||
# Reduce the largest Rn_BLOCK by a factor of 2.
|
||||
largest_rkwarg: str = max(
|
||||
reduction_kwargs, key=triton_config.kwargs.__getitem__
|
||||
)
|
||||
new_config.kwargs[largest_rkwarg] //= 2
|
||||
|
||||
if new_config in seen_configs:
|
||||
continue
|
||||
seen_configs.add(new_config)
|
||||
log.debug(
|
||||
"Dynamically scale down RBLOCK from TritonConfig(%s) and get a new TritonConfig(%s)",
|
||||
"Dynamically scale down %s from TritonConfig(%s) and get a new TritonConfig(%s)",
|
||||
largest_rkwarg,
|
||||
triton_config,
|
||||
new_config,
|
||||
)
|
||||
@ -998,8 +1017,8 @@ class CachingAutotuner(KernelInterface):
|
||||
|
||||
assert not (
|
||||
self.heuristic_type == HeuristicType.PERSISTENT_REDUCTION
|
||||
and "RBLOCK" in launcher.config.kwargs
|
||||
), "Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have RBLOCK"
|
||||
and "R0_BLOCK" in launcher.config.kwargs
|
||||
), "Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have R0_BLOCK"
|
||||
start_time = time.time_ns()
|
||||
best_config = self.coordesc_tuner.autotune(
|
||||
benchmark_one_config, launcher.config, None
|
||||
@ -1362,6 +1381,20 @@ def check_config(cfg, *, xnumel=None, ynumel=None, znumel=None):
|
||||
)
|
||||
|
||||
|
||||
def check_max_block(cfg: Dict[str, int]):
|
||||
"""
|
||||
Check that block sizes are within the maximum allowed.
|
||||
"""
|
||||
for var, val in cfg.items():
|
||||
block_suffix = "BLOCK"
|
||||
if block_suffix in var:
|
||||
prefix = var.removesuffix(block_suffix)
|
||||
max_block = TRITON_MAX_BLOCK[prefix]
|
||||
assert (
|
||||
val <= max_block
|
||||
), f"'{var}' too large. Maximum: {max_block}. Actual: {val}."
|
||||
|
||||
|
||||
def _num_warps(num_warps, max_num_warps=8, min_num_warps=2, register_intensive=False):
|
||||
# On AMD GPU each warp has 64 lanes which is double the size on NV GPU,
|
||||
# therefore using half the number of warps here correspondingly.
|
||||
@ -1486,82 +1519,143 @@ def triton_config(
|
||||
cfg["YBLOCK"] = y
|
||||
if z:
|
||||
cfg["ZBLOCK"] = z
|
||||
assert x <= TRITON_MAX_BLOCK["X"], f"increase TRITON_MAX_BLOCK['X'] to {x}"
|
||||
check_max_block(cfg)
|
||||
check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel)
|
||||
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
|
||||
|
||||
|
||||
def _get_nd_reduction_numels(r: int, size_hints: Dict[str, int]) -> Dict[str, int]:
|
||||
"""
|
||||
Converts a linear reduction numel to ND, in row major order.
|
||||
This order is often desirable as it presents opportunities to coalesce memory
|
||||
accesses.
|
||||
For example, if r = 64 and size_hints = [32,32], this function returns [32, 2].
|
||||
This unraveling works because both r and size_hints are powers of 2.
|
||||
"""
|
||||
# Shrink r to size_hints.
|
||||
r = min(r, get_total_reduction_numel(size_hints))
|
||||
num_reduction_dims = len(
|
||||
[prefix for prefix in size_hints if prefix_is_reduction(prefix)]
|
||||
)
|
||||
|
||||
remaining = r
|
||||
rnumels = {}
|
||||
for idx in range(num_reduction_dims - 1, -1, -1):
|
||||
prefix = f"r{idx}_"
|
||||
max_size = min(size_hints[prefix], TRITON_MAX_BLOCK[prefix.upper()])
|
||||
dim = min(max_size, remaining)
|
||||
assert (
|
||||
remaining % dim == 0
|
||||
), f"Expected dimension '{dim}' to divide remaining size '{remaining}'"
|
||||
rnumels[prefix] = dim
|
||||
remaining //= dim
|
||||
|
||||
# Sanity check the results.
|
||||
final_numel = conditional_product(*rnumels.values())
|
||||
assert (
|
||||
r == final_numel
|
||||
), f"Expected ND reduction size ({rnumels}) to have {r} elements."
|
||||
assert all(
|
||||
rnumels[prefix] <= size_hints[prefix] for prefix in rnumels
|
||||
), f"rnumels exceed size_hints. {rnumels} > {size_hints}"
|
||||
|
||||
return rnumels
|
||||
|
||||
|
||||
def triton_config_reduction(
|
||||
size_hints, x, r, num_stages=1, num_warps=None, register_intensive=False
|
||||
size_hints,
|
||||
x: int,
|
||||
r: int,
|
||||
num_stages=1,
|
||||
num_warps=None,
|
||||
register_intensive=False,
|
||||
) -> Config:
|
||||
"""
|
||||
Construct a reduction triton config with some adjustment heuristics
|
||||
based on size_hints. Size_hints is a tuple of numels in each tile
|
||||
dimension and will be rounded up to the nearest power of 2.
|
||||
"""
|
||||
|
||||
target = conditional_product(x, r)
|
||||
if conditional_product(*size_hints.values()) < target:
|
||||
target //= 8
|
||||
# Convert the linear reduction numel into a multi-dimensional block.
|
||||
rnumels = _get_nd_reduction_numels(r, size_hints)
|
||||
|
||||
# shrink sizes to size hints
|
||||
x = min(x, size_hints["x"])
|
||||
r = min(r, size_hints["r"])
|
||||
|
||||
def total_numel() -> int:
|
||||
return conditional_product(x, *rnumels.values())
|
||||
|
||||
target = total_numel()
|
||||
if conditional_product(*size_hints.values()) < target:
|
||||
target //= 8
|
||||
|
||||
# if we are below original block size, scale up where we can
|
||||
while x < size_hints["x"] and conditional_product(x, r) < target:
|
||||
while x < size_hints["x"] and total_numel() < target:
|
||||
x *= 2
|
||||
while r < size_hints["r"] and conditional_product(x, r) < target:
|
||||
r *= 2
|
||||
for prefix in sorted(rnumels):
|
||||
while rnumels[prefix] < size_hints[prefix] and total_numel() < target:
|
||||
rnumels[prefix] *= 2
|
||||
|
||||
if num_warps is None:
|
||||
num_warps = conditional_product(x, r) // 128
|
||||
num_warps = total_numel() // 128
|
||||
num_warps = _num_warps(
|
||||
num_warps, max_num_warps=16, register_intensive=register_intensive
|
||||
)
|
||||
|
||||
x, _num_blocks = _check_max_grid_x(size_hints, x, num_warps)
|
||||
|
||||
while conditional_product(x, r) > target:
|
||||
if r == 1:
|
||||
break
|
||||
r = r // 2
|
||||
for prefix in sorted(rnumels):
|
||||
while total_numel() > target:
|
||||
if rnumels[prefix] == 1:
|
||||
break
|
||||
rnumels[prefix] //= 2
|
||||
|
||||
cfg = {"XBLOCK": x, "RBLOCK": r}
|
||||
cfg = _get_config({"x": x, **rnumels})
|
||||
check_max_block(cfg)
|
||||
check_config(cfg, xnumel=size_hints["x"])
|
||||
assert x <= TRITON_MAX_BLOCK["X"], f"increase TRITON_MAX_BLOCK['X'] to {x}"
|
||||
assert r <= TRITON_MAX_BLOCK["R"], f"increase TRITON_MAX_BLOCK['r'] to {r}"
|
||||
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
|
||||
|
||||
|
||||
def _get_config(numels: Dict[str, int]) -> Dict[str, int]:
|
||||
"""
|
||||
Convert numels ("x", "r0_", etc.) to block sizes ("XBLOCK", "R0_BLOCK"), etc.
|
||||
"""
|
||||
|
||||
return {prefix.upper() + "BLOCK": numel for prefix, numel in numels.items()}
|
||||
|
||||
|
||||
def triton_config_tiled_reduction(size_hints, x, y, r, num_stages=1):
|
||||
"""
|
||||
Construct a tile reduction triton config with some adjustment
|
||||
heuristics based on size_hints. Size_hints is a tuple of numels in
|
||||
each tile dimension and will be rounded up to the nearest power of 2.
|
||||
"""
|
||||
|
||||
target = conditional_product(x, y, r)
|
||||
if conditional_product(*size_hints) < target:
|
||||
target //= 8
|
||||
# Convert the linear reduction numel into a multi-dimensional block.
|
||||
rnumels = _get_nd_reduction_numels(r, size_hints)
|
||||
|
||||
# shrink sizes to size hints
|
||||
x = min(x, size_hints["x"])
|
||||
y = min(y, size_hints["y"])
|
||||
r = min(r, size_hints["r"])
|
||||
|
||||
def total_numel() -> int:
|
||||
return conditional_product(x, y, *rnumels.values())
|
||||
|
||||
target = total_numel()
|
||||
if conditional_product(*size_hints.values()) < target:
|
||||
target //= 8
|
||||
|
||||
# if we are below original block size, scale up where we can
|
||||
while x < size_hints["x"] and conditional_product(x, y, r) < target:
|
||||
while x < size_hints["x"] and total_numel() < target:
|
||||
x *= 2
|
||||
while r < size_hints["r"] and conditional_product(x, y, r) < target:
|
||||
r *= 2
|
||||
while y < size_hints["y"] and conditional_product(x, y, r) < target:
|
||||
for prefix in sorted(rnumels):
|
||||
while rnumels[prefix] < size_hints[prefix] and total_numel() < target:
|
||||
rnumels[prefix] *= 2
|
||||
while y < size_hints[1] and total_numel() < target:
|
||||
y *= 2
|
||||
|
||||
cfg = {"XBLOCK": x, "YBLOCK": y, "RBLOCK": r}
|
||||
num_warps = _num_warps(conditional_product(x, y, r) // 256, min_num_warps=1)
|
||||
check_config(cfg, xnumel=size_hints["x"], ynumel=size_hints["y"])
|
||||
assert r <= TRITON_MAX_BLOCK["R"], f"increase TRITON_MAX_BLOCK['r'] to {r}"
|
||||
cfg = _get_config({"x": x, "y": y, **rnumels})
|
||||
num_warps = _num_warps(total_numel() // 256, min_num_warps=1)
|
||||
check_config(cfg, xnumel=size_hints[0], ynumel=size_hints[1])
|
||||
check_max_block(cfg)
|
||||
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
|
||||
|
||||
|
||||
@ -1657,35 +1751,36 @@ def _reduction_configs(
|
||||
*, size_hints: Dict[str, int], inductor_meta: Dict[str, Any]
|
||||
) -> List[Config]:
|
||||
reduction_hint = inductor_meta.get("reduction_hint", None)
|
||||
assert len(size_hints) == 2
|
||||
rnumel = size_hints["r"]
|
||||
|
||||
# Convert reductions to 1D, to simplify heuristics.
|
||||
rnumel = get_total_reduction_numel(size_hints)
|
||||
|
||||
register_intensive = False
|
||||
MAX_RBLOCK = 2048
|
||||
MAX_R0_BLOCK = 2048
|
||||
if (
|
||||
size_hints["x"] >= 1024
|
||||
and inductor_meta.get("num_load", 0) + inductor_meta.get("num_reduction", 0)
|
||||
>= 10
|
||||
):
|
||||
# A heuristics to reduce RBLOCK if a kernel potentially need many registers.
|
||||
# A heuristics to reduce R0_BLOCK if a kernel potentially need many registers.
|
||||
# Consider load and reduction since load need move data into registers and
|
||||
# reduction needs an accumulator.
|
||||
#
|
||||
# The magic numbers are a bit arbitrary.
|
||||
#
|
||||
# We cannot rely on dynamically scaling down RBLOCK later, since sometimes
|
||||
# We cannot rely on dynamically scaling down R0_BLOCK later, since sometimes
|
||||
# triton makes it to use less registers with worse perf. Check:
|
||||
# https://github.com/pytorch/pytorch/issues/126463
|
||||
#
|
||||
# The heuristic is a very simple one since registers can be reused. But
|
||||
# hopefully it can be a good enough indicator.
|
||||
MAX_RBLOCK = 1024
|
||||
MAX_R0_BLOCK = 1024
|
||||
register_intensive = True
|
||||
|
||||
contiguous_config = triton_config_reduction(
|
||||
size_hints,
|
||||
1,
|
||||
(rnumel if 256 <= rnumel < MAX_RBLOCK else MAX_RBLOCK),
|
||||
rnumel if 256 <= rnumel < MAX_R0_BLOCK else MAX_R0_BLOCK,
|
||||
register_intensive=register_intensive,
|
||||
)
|
||||
outer_config = triton_config_reduction(
|
||||
@ -1694,7 +1789,7 @@ def _reduction_configs(
|
||||
tiny_config = triton_config_reduction(
|
||||
size_hints,
|
||||
2 * (256 // rnumel) if rnumel <= 256 else 1,
|
||||
min(rnumel, MAX_RBLOCK),
|
||||
min(rnumel, MAX_R0_BLOCK),
|
||||
register_intensive=register_intensive,
|
||||
)
|
||||
if inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise"):
|
||||
@ -1713,7 +1808,7 @@ def _reduction_configs(
|
||||
tiny_config,
|
||||
triton_config_reduction(size_hints, 64, 64),
|
||||
triton_config_reduction(size_hints, 8, 512),
|
||||
# halve the XBLOCK/RBLOCK compared to outer_config
|
||||
# halve the XBLOCK/Rn_BLOCK compared to outer_config
|
||||
# TODO: this may only be beneficial when each iteration of the reduction
|
||||
# is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
|
||||
triton_config_reduction(size_hints, 64, 4, num_warps=8),
|
||||
@ -1734,8 +1829,6 @@ def reduction(
|
||||
size_hints["x"] = 1
|
||||
|
||||
assert triton_meta is not None
|
||||
if len(size_hints) != 2:
|
||||
raise NotImplementedError(f"size_hints: {size_hints}")
|
||||
|
||||
configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
|
||||
return cached_autotune(
|
||||
@ -1759,7 +1852,12 @@ def cooperative_reduction(
|
||||
inductor_meta["reduction_hint"] = reduction_hint
|
||||
if inductor_meta.get("no_x_dim"):
|
||||
size_hints["x"] = 1
|
||||
xnumel, rnumel = size_hints["x"], size_hints["r"]
|
||||
|
||||
# Cooperative reductions currently only support a single reduction dimension.
|
||||
assert (
|
||||
len(size_hints) == 2
|
||||
), "Cooperative reductions don't support tiling reduction dims"
|
||||
xnumel, rnumel = size_hints["x"], size_hints["r0_"]
|
||||
|
||||
# TODO(jansel): we should base target on the SM count of the local GPU
|
||||
target = 64
|
||||
@ -1768,11 +1866,12 @@ def cooperative_reduction(
|
||||
assert split <= TRITON_MAX_RSPLIT
|
||||
if inductor_meta["persistent_reduction"]:
|
||||
configs = _persistent_reduction_configs(
|
||||
{"x": xnumel, "r": rnumel // split}, reduction_hint, inductor_meta
|
||||
{"x": xnumel, "r0_": rnumel // split}, reduction_hint, inductor_meta
|
||||
)
|
||||
else:
|
||||
configs = _reduction_configs(
|
||||
size_hints={"x": xnumel, "r": rnumel // split}, inductor_meta=inductor_meta
|
||||
size_hints={"x": xnumel, "r0_": rnumel // split},
|
||||
inductor_meta=inductor_meta,
|
||||
)
|
||||
for config in configs:
|
||||
config.kwargs["RSPLIT"] = split
|
||||
@ -1793,7 +1892,8 @@ def _persistent_reduction_configs(
|
||||
reduction_hint=False,
|
||||
inductor_meta=None,
|
||||
):
|
||||
xnumel, rnumel = size_hints["x"], size_hints["r"]
|
||||
xnumel = size_hints["x"]
|
||||
rnumel = get_total_reduction_numel(size_hints)
|
||||
|
||||
configs = [
|
||||
triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True)
|
||||
@ -1809,12 +1909,16 @@ def _persistent_reduction_configs(
|
||||
elif reduction_hint == ReductionHint.OUTER_TINY:
|
||||
configs = [
|
||||
triton_config_reduction(
|
||||
size_hints, 2 * (256 // rnumel) if rnumel <= 256 else 1, rnumel
|
||||
size_hints,
|
||||
2 * (256 // rnumel) if rnumel <= 256 else 1,
|
||||
rnumel,
|
||||
)
|
||||
]
|
||||
for c in configs:
|
||||
# we don't need RBLOCK for persistent reduction
|
||||
c.kwargs.pop("RBLOCK")
|
||||
# we don't need Rn_BLOCK for persistent reduction
|
||||
for prefix in size_hints:
|
||||
if prefix_is_reduction(prefix):
|
||||
c.kwargs.pop(f"{prefix.upper()}BLOCK")
|
||||
|
||||
if disable_pointwise_autotuning(inductor_meta):
|
||||
configs = configs[:1]
|
||||
@ -1865,11 +1969,12 @@ def split_scan(
|
||||
|
||||
configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
|
||||
|
||||
# Fixup configs to enforce the minimum RBLOCK size
|
||||
# Fixup configs to enforce the minimum Rn_BLOCK size
|
||||
min_rblock = inductor_meta.get("min_split_scan_rblock", 256)
|
||||
for cfg in configs:
|
||||
if cfg.kwargs["RBLOCK"] < min_rblock:
|
||||
cfg.kwargs["RBLOCK"] = min_rblock
|
||||
for var in list(cfg.kwargs.keys()):
|
||||
if var.startswith("R") and cfg.kwargs[var] < min_rblock:
|
||||
cfg.kwargs[var] = min_rblock
|
||||
|
||||
return cached_autotune(
|
||||
size_hints,
|
||||
@ -2028,7 +2133,7 @@ def maybe_cooperative_reduction_grid(xnumel):
|
||||
def split_scan_grid(xnumel, rnumel):
|
||||
def grid_fn(meta):
|
||||
assert meta.get("XBLOCK", 1) == 1
|
||||
return (ceildiv(rnumel, meta.get("RBLOCK", 1)), xnumel, 1)
|
||||
return (ceildiv(rnumel, meta.get("R0_BLOCK", 1)), xnumel, 1)
|
||||
|
||||
grid_fn_str = f"split_scan_grid({xnumel}, {rnumel})"
|
||||
setattr(grid_fn, "grid_fn_str", grid_fn_str) # noqa: B010
|
||||
|
@ -262,7 +262,7 @@ class TritonTemplateKernel(TritonKernel):
|
||||
super().__init__(
|
||||
{
|
||||
"x": numel,
|
||||
"r": sympy.S.One,
|
||||
"r0_": sympy.S.One,
|
||||
},
|
||||
features=SIMDKernelFeatures([], numel),
|
||||
)
|
||||
|
@ -681,6 +681,10 @@ def get_bounds_index_expr(index):
|
||||
return ValueRanges.unknown()
|
||||
|
||||
|
||||
def prefix_is_reduction(prefix: str) -> bool:
|
||||
return prefix[0] == "r"
|
||||
|
||||
|
||||
def sympy_index_symbol_with_prefix(prefix: SymT, idx: int) -> sympy.Symbol:
|
||||
"""
|
||||
Used to generate an integer-nonnegative symbol.
|
||||
|
@ -13,7 +13,7 @@ in this file and seeing what breaks.
|
||||
"""
|
||||
|
||||
from enum import auto, Enum
|
||||
from typing import Sequence, Union
|
||||
from typing import Iterable, Union
|
||||
|
||||
import sympy
|
||||
|
||||
@ -36,9 +36,10 @@ class SymT(Enum):
|
||||
# Inductor: An indexing variable i0 in loops IR which ranges over non-reduced
|
||||
# dim in the loop
|
||||
INDEX = auto()
|
||||
# Inductor: A reduction indexing r0 variable in loops IR which ranges over
|
||||
# reduced dim in the loop
|
||||
RINDEX = auto()
|
||||
# Inductor: A reduction indexing (r0, r1) variables in loops IR which ranges over
|
||||
# reduced dim(s) in the loop
|
||||
R0_INDEX = auto()
|
||||
R1_INDEX = auto()
|
||||
# Inductor: In templated kernels torch._inductor.kernel, we have a hook to
|
||||
# store the final output and append epilogue fusions. To do this, we must
|
||||
# know what the indexes the outputs range over. NB: These will also
|
||||
@ -67,7 +68,8 @@ prefix_str = {
|
||||
SymT.TMP: "tmp",
|
||||
SymT.PRECOMPUTED_SIZE: "ps",
|
||||
SymT.INDEX: "i",
|
||||
SymT.RINDEX: "r",
|
||||
SymT.R0_INDEX: "r0_",
|
||||
SymT.R1_INDEX: "r1_",
|
||||
SymT.TEMPLATE_INDEX: "idx",
|
||||
SymT.XBLOCK: "x",
|
||||
SymT.YBLOCK: "y",
|
||||
@ -85,7 +87,7 @@ def make_symbol(prefix: SymT, idx: int, **kwargs) -> sympy.Symbol:
|
||||
|
||||
# This type is a little wider than it should be, because free_symbols says
|
||||
# that it contains Basic, rather than Symbol
|
||||
def symbol_is_type(sym: sympy.Basic, prefix: Union[SymT, Sequence[SymT]]) -> bool:
|
||||
def symbol_is_type(sym: sympy.Basic, prefix: Union[SymT, Iterable[SymT]]) -> bool:
|
||||
assert isinstance(sym, sympy.Symbol)
|
||||
name_str = sym.name.lower() # Match capitalized names like XBLOCK, RBLOCK
|
||||
if isinstance(prefix, SymT):
|
||||
@ -94,5 +96,5 @@ def symbol_is_type(sym: sympy.Basic, prefix: Union[SymT, Sequence[SymT]]) -> boo
|
||||
return name_str.startswith(tuple(prefix_str[p] for p in prefix))
|
||||
|
||||
|
||||
def free_symbol_is_type(e: sympy.Expr, prefix: SymT) -> bool:
|
||||
def free_symbol_is_type(e: sympy.Expr, prefix: Union[SymT, Iterable[SymT]]) -> bool:
|
||||
return any(symbol_is_type(v, prefix) for v in e.free_symbols)
|
||||
|
Reference in New Issue
Block a user