mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Restrict block analysis to only match dimension sizes and strides that are integers. E.g. `sympy` can match index expressions like `ModularIndexing(xindex, 4, 4)) + 4*(ModularIndexing(xindex, 32, 2))` with the candidate below that is invalid. ```python match_expr = stride_mod0_*((xindex//(dim_mod1_*dim_mod2_*dim_mod3_*dim_mod4_))) + stride_mod1_*(ModularIndexing(xindex, dim_mod2_*dim_mod3_*dim_mod4_, dim_mod1_)) + stride_mod2_*(ModularIndexing(xindex, dim_mod3_*dim_mod4_, dim_mod2_)) + stride_mod3_*(ModularIndexing(xindex, dim_mod4_, dim_mod3_)) + stride_mod4_*(ModularIndexing(xindex, 1, dim_mod4_)) match={ dim_mod4_: 32, dim_mod3_: 2, stride_mod3_: 4, dim_mod2_: 1/16, dim_mod1_: 4, stride_mod1_: 1, stride_mod4_: 0, stride_mod2_: 0, stride_mod0_: 0 } ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/149615 Approved by: https://github.com/blaine-rister
193 lines
7.1 KiB
Python
193 lines
7.1 KiB
Python
import collections
|
|
import functools
|
|
import textwrap
|
|
from typing import Optional
|
|
|
|
import sympy
|
|
from sympy import Expr, Symbol
|
|
|
|
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
|
|
|
|
from ..utils import sympy_dot, sympy_subs
|
|
from ..virtualized import V
|
|
|
|
|
|
class BlockPatternMatcher:
|
|
"""
|
|
Matches block indexing expressions.
|
|
"""
|
|
|
|
_indexing_wild_signed_int = functools.partial(
|
|
sympy.Wild, properties=[lambda x: x.is_integer]
|
|
)
|
|
_indexing_wild_unsigned_int = functools.partial(
|
|
sympy.Wild, properties=[lambda x: x.is_integer and x.is_nonnegative]
|
|
)
|
|
|
|
@classmethod
|
|
def get_subexpr_involving_symbol(cls, expr: Expr, symbol: Symbol) -> Expr:
|
|
"""
|
|
Given a sympy expression, return the subexpression comprised only of terms
|
|
involving the specified symbol.
|
|
|
|
For example, if `expr` is `x * 5 + x ** 2 + y * 2 + 5`, and `symbol` is `x`,
|
|
this returns `x * 5 + x ** 2`.
|
|
"""
|
|
expr = cls._preprocess(expr)
|
|
return sympy.S.Zero + sum(
|
|
term for term in sympy.Add.make_args(expr) if symbol in term.free_symbols
|
|
)
|
|
|
|
@staticmethod
|
|
def get_slice_numels(dims: list[Expr]) -> list[Expr]:
|
|
"""
|
|
Compute the cumulative size of each dimension's slice.
|
|
This proceeds from the last dim up to the second.
|
|
"""
|
|
numels = collections.deque([sympy.S.One])
|
|
for dim in dims[:0:-1]:
|
|
numel = dim * numels[0]
|
|
numels.appendleft(numel)
|
|
return [*numels]
|
|
|
|
@staticmethod
|
|
def _preprocess(expr: Expr) -> Expr:
|
|
# Remove any Identity nodes, e.g. expand x + (5 * y) to x + 5 * y.
|
|
return expr.expand(identity=True)
|
|
|
|
@classmethod
|
|
def match_mod_div_block_expr(
|
|
cls,
|
|
index: Expr,
|
|
index_var: Symbol,
|
|
numel: Expr,
|
|
num_dims: int,
|
|
) -> Optional[tuple[list[Expr], list[Expr], list[Expr]]]:
|
|
"""
|
|
Matches modular indexing expressions, converting them to implied block dimensions and strides.
|
|
See triton.py for more information.
|
|
"""
|
|
index = cls._preprocess(index)
|
|
|
|
# Pattern match to find the strides and offset.
|
|
wild_unsigned_int = functools.partial(
|
|
cls._indexing_wild_unsigned_int, exclude=[index_var]
|
|
)
|
|
wild_signed_int = functools.partial(
|
|
cls._indexing_wild_signed_int, exclude=[index_var]
|
|
)
|
|
dims: list[Expr] = [
|
|
wild_unsigned_int(f"dim_mod{idx}") for idx in range(num_dims)
|
|
]
|
|
strides: list[Expr] = [
|
|
wild_signed_int(f"stride_mod{idx}") for idx in range(num_dims)
|
|
]
|
|
|
|
# The first dimension's index is computed by division.
|
|
# The remaining are computed by modulo.
|
|
slice_numels = cls.get_slice_numels(dims[:num_dims])
|
|
block_index_exprs = [FloorDiv(index_var, slice_numels[0])] + [
|
|
ModularIndexing(index_var, numel, dim)
|
|
for dim, numel in zip(dims[1:], slice_numels[1:])
|
|
]
|
|
|
|
# Calculate a linear index from block indices.
|
|
match_expr = sympy_dot(strides, block_index_exprs)
|
|
|
|
# Heuristic: if the number of dimensions is high, check that the minimum requirements
|
|
# are met before attempting an expensive full match. see triton.py:match_mod_div_block
|
|
# for more details. In short, here we check that each subexpression in sympy.Add contains
|
|
# only FloorDiv or ModularIndexing expressions.
|
|
if num_dims >= 5:
|
|
stride = sympy.symbols("stride", cls=wild_signed_int)
|
|
denom, other = sympy.symbols("denominator other", cls=wild_unsigned_int)
|
|
mod_div_pattern = stride * ModularIndexing(index_var, denom, other)
|
|
floor_div_pattern = stride * FloorDiv(index_var, denom)
|
|
first_dim_floor_div_matched = False
|
|
match_failed = False
|
|
for arg in sympy.Add.make_args(index):
|
|
if arg.match(floor_div_pattern):
|
|
# There should only be a single FloorDiv(index, denom) expression
|
|
# corresponding to the first dimension
|
|
if first_dim_floor_div_matched:
|
|
match_failed = True
|
|
break
|
|
first_dim_floor_div_matched = True
|
|
elif arg.match(mod_div_pattern):
|
|
continue
|
|
else:
|
|
match_failed = True
|
|
break
|
|
|
|
if match_failed:
|
|
return None
|
|
|
|
# Pattern match.
|
|
match = index.match(match_expr)
|
|
if match is None:
|
|
return None
|
|
|
|
# Provide default values for unmatched dims and strides.
|
|
for dim in dims[1:]:
|
|
if dim not in match:
|
|
match[dim] = sympy.S.One
|
|
for stride in strides[1:]:
|
|
if stride not in match:
|
|
match[stride] = sympy.S.Zero
|
|
|
|
sizevars = V.graph.sizevars
|
|
|
|
def get_match(expr: Expr) -> Expr:
|
|
return sizevars.lookup_precomputed_size(match[expr])
|
|
|
|
# Replace wildcards with matched expressions.
|
|
dims = [dims[0]] + [get_match(dim) for dim in dims[1:]]
|
|
strides = [get_match(stride) for stride in strides]
|
|
slice_numels = cls.get_slice_numels(dims)
|
|
block_index_exprs = [sympy_subs(expr, match) for expr in block_index_exprs]
|
|
|
|
# The leading dimension is not directly matched in our expression.
|
|
# We solve for it by dividing the range tree numel by the product of
|
|
# all other dimensions. We quit if they are not known to be divisible.
|
|
assert dims[0] not in match, "Expected not to match the leading dimension!"
|
|
if not sizevars.statically_known_multiple_of(numel, slice_numels[0]):
|
|
return None
|
|
dims[0] = numel / slice_numels[0]
|
|
|
|
# Sanity check that we can recover the index from the matched subexpressions.
|
|
matched_index = sympy_dot(strides, block_index_exprs)
|
|
assert sizevars.statically_known_equals(
|
|
# New precomputed replacements may be generated when the `get_match` function
|
|
# above is called, but the `index` that is being matched has not been updated.
|
|
# So remove them when checking for equivalence e.g. if ps0=3*s0 and
|
|
# index=3*s0*expr, matched_index=ps0*expr, then index == matched_index
|
|
sizevars.remove_precomputed_replacements(matched_index),
|
|
sizevars.remove_precomputed_replacements(index),
|
|
), textwrap.dedent(
|
|
f"""
|
|
Invalid match!
|
|
Index: {index}
|
|
Matched expression: {matched_index}
|
|
"""
|
|
)
|
|
|
|
return dims, strides, block_index_exprs
|
|
|
|
@classmethod
|
|
def match_affine_block_expr(
|
|
cls,
|
|
index: Expr,
|
|
index_var: Symbol,
|
|
) -> Optional[Expr]:
|
|
"""
|
|
Matches simple expressions of the form stride * index, returning the
|
|
stride.
|
|
"""
|
|
index = cls._preprocess(index)
|
|
stride = cls._indexing_wild_signed_int(name="stride", exclude=[index_var])
|
|
m = index.match(index_var * stride)
|
|
if m is None:
|
|
return None
|
|
|
|
return m[stride]
|