mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145385 Approved by: https://github.com/ezyang
924 lines
35 KiB
Python
924 lines
35 KiB
Python
# mypy: allow-untyped-defs
|
|
import functools
|
|
import itertools
|
|
import logging
|
|
from collections.abc import Iterable, Sequence
|
|
from typing import Any, Callable, cast, Optional, Union
|
|
|
|
import sympy
|
|
from sympy import Expr
|
|
|
|
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, ShapeEnv
|
|
from torch.utils._ordered_set import OrderedSet
|
|
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
|
|
from torch.utils._sympy.symbol import symbol_is_type, SymT
|
|
from torch.utils._sympy.value_ranges import bound_sympy, IntInfinity, ValueRanges
|
|
|
|
from .runtime.runtime_utils import is_power_of_2
|
|
from .utils import (
|
|
has_free_symbols,
|
|
sympy_index_symbol,
|
|
sympy_index_symbol_with_prefix,
|
|
sympy_subs,
|
|
VarRanges,
|
|
)
|
|
from .virtualized import V
|
|
|
|
|
|
# Logger for this module; used for debug-level diagnostics only.
log = logging.getLogger(name=__name__)
|
|
|
|
|
|
def evaluate_expr(
    shape_env: ShapeEnv,
    expr: Union[sympy.Basic, bool],
    axioms: Optional[tuple[sympy.Expr, ...]] = None,
    var_to_range: Optional[tuple[tuple[sympy.Symbol, ValueRanges[Any]], ...]] = None,
) -> bool:
    """
    Return True only if ``expr`` can be shown to be true without a hint.

    Args:
        shape_env: ShapeEnv whose ``_maybe_evaluate_static`` is consulted.
        expr: A sympy boolean/relational expression, or a Python bool.
        axioms: Optional extra facts forwarded to ``_maybe_evaluate_static``.
            This is a variable-length tuple.
        var_to_range: Optional ``(symbol, ValueRanges)`` pairs forwarded to
            ``_maybe_evaluate_static``. This is a variable-length tuple.

    Returns:
        ``bool(expr)`` when ``expr`` compares equal to True or False.
        Otherwise, ``bool(result)`` when ``_maybe_evaluate_static`` returns a
        non-None result. In every other case, including when simplification
        raises, this returns False.
    """
    if expr in (True, False):
        return bool(expr)

    try:
        simplified = shape_env._maybe_evaluate_static(
            expr,
            axioms=axioms,
            var_to_range=var_to_range,
        )
        # bool() stays inside the try: converting an unevaluated relational
        # can raise, and that should also mean "not known to be true".
        if simplified is not None:
            return bool(simplified)
    except Exception:
        # Best effort by design: failure to reason about expr just means False.
        log.debug("Could not simplify %s", expr, exc_info=True)

    return False
|
|
|
|
|
|
# This class is a little awkward, because ShapeEnv is doing most of the heavy
# lifting and in some cases we should be directly passing through to ShapeEnv,
# but there is some extra inductor logic that needs to be handled here
class SizeVarAllocator:
    """
    Inductor-side helper around a ``ShapeEnv`` for working with symbolic sizes.

    It provides:
    - cached index simplification, loop simplification and stride extraction;
    - "statically known" queries that never add guards;
    - guard/evaluate helpers;
    - bookkeeping for precomputed size symbols (``ps*``).
    """

    def __init__(self, shape_env=None) -> None:
        super().__init__()
        if shape_env is None:
            shape_env = ShapeEnv()
        self.shape_env = shape_env
        # These alias (do not copy) the ShapeEnv's containers. The caches below
        # detect new replacements by watching len(self.replacements), which
        # relies on ShapeEnv updating this container in place.
        self.var_to_val = self.shape_env.var_to_val
        self.replacements: dict[sympy.Symbol, Expr] = self.shape_env.replacements
        # Maps of dynamic sizes that have to be precomputed on the host to the kernel args.
        # The basic idea is if we have some complicated sympy expression
        # f(s0), we may choose to precompute it on the host and then replace
        # all occurrences of that sympy expression with ps0, so that when we
        # codegen we simply reference ps0 directly without repeating
        # f(s0).  Unlike regular size variables, ps variables cannot be
        # guarded upon; so if we are asked to guard on a Sympy expression
        # which potentially could have already had a precomputed replacement
        # on it, we are obligated to invert the precomputed replacements
        # (inv_precomputed_replacements).
        self.precomputed_replacements: dict[Expr, sympy.Symbol] = {}
        self.inv_precomputed_replacements: dict[sympy.Symbol, Expr] = {}
        self.stride_vars = self.make_stride_vars_cache()
        self.simplify_with_ranges = self.make_simplify_with_ranges_cache()
        self._simplify_loops = self.make_simplify_loops_cache()

    def simplify(self, expr: Expr):
        """Expand ``expr`` and apply the ShapeEnv's symbol replacements."""
        return sympy.expand(expr).xreplace(self.replacements)

    def make_simplify_with_ranges_cache(self) -> Callable[[Expr, VarRanges], Expr]:
        """
        self._simplify_with_ranges() can be expensive, cache its results
        """
        cache: dict[tuple[Any, ...], Expr] = {}
        replacement_count = len(self.replacements)

        def simplify_with_ranges(expr: Expr, var_ranges: VarRanges) -> Expr:
            nonlocal replacement_count
            if replacement_count != len(self.replacements):
                # new replacements invalidates cached results
                cache.clear()
                replacement_count = len(self.replacements)
            key = (expr, *var_ranges.items())
            result = cache.get(key, None)
            if result is None:
                result = self._simplify_with_ranges(expr, var_ranges)
                cache[key] = result
            return result

        return simplify_with_ranges

    def make_simplify_loops_cache(self):
        """
        self._simplify_loops_impl() can be expensive, cache its results
        """
        cache: dict[tuple[Any, ...], Any] = {}
        replacement_count = len(self.replacements)

        def simplify_loops(index_vars, sizes, index_formulas):
            nonlocal replacement_count
            if replacement_count != len(self.replacements):
                # new replacements invalidates cached results
                cache.clear()
                replacement_count = len(self.replacements)
            # Keep the three sequences separate in the key. A flat
            # concatenation lets different (index_vars, sizes, index_formulas)
            # splits of the same elements collide and return a wrong result.
            key = (tuple(index_vars), tuple(sizes), tuple(index_formulas))
            result = cache.get(key, None)
            if result is None:
                result = self._simplify_loops_impl(index_vars, sizes, index_formulas)
                cache[key] = result
            return result

        return simplify_loops

    def _simplify_with_ranges(self, expr: Expr, var_ranges: VarRanges) -> Expr:
        """
        Simplify indexing expression with knowledge of the ranges of
        iteration variables.
        """

        expr = join_dimensions(self.simplify(expr))
        original_expr = expr

        var_to_range = dict(self.shape_env.var_to_range)
        var_to_range.update(
            {
                k: ValueRanges(
                    0, max(0, v - 1) if not has_free_symbols([v]) else IntInfinity()
                )
                for k, v in var_ranges.items()
            }
        )
        # Any other free symbol is assumed to be a non-negative size/index.
        for var in expr.free_symbols:
            if var not in var_to_range:
                var_to_range[var] = ValueRanges(0, IntInfinity())

        var_to_range_tuple = cast(
            tuple[tuple[sympy.Symbol, ValueRanges[sympy.Expr]], ...],
            tuple(var_to_range.items()),
        )

        axioms = []
        for var, upper_bound in var_ranges.items():
            axioms.append(0 <= var)
            axioms.append(var < upper_bound)
        axioms = tuple(axioms) + self.shape_env.get_axioms()

        def statically_known(expr):
            evaluated = self.shape_env._maybe_evaluate_static(
                expr,
                axioms=axioms,
                var_to_range=var_to_range_tuple,
            )
            return bool(evaluated)

        def remove_zero_terms(base, divisor):
            """Symbols smaller than the divisor are zero"""
            if not statically_known(base >= 0):
                return base

            for v in base.free_symbols:
                if v in var_ranges:
                    # var smaller than divisor can be removed
                    # if the rest is guaranteed to be multiple of divisor
                    rest = sympy.Wild("_rest", exclude=[v])
                    m = base.match(v + rest)
                    if m and v not in m[rest].free_symbols:
                        gcd = sympy.gcd(m[rest], divisor)
                        if gcd == divisor:
                            if statically_known(v < divisor):
                                base = m[rest]
            return base

        def visit_indexing_div(base, divisor):
            return FloorDiv(remove_zero_terms(base, divisor), divisor)

        def visit_modular_indexing(base, divisor, modulus):
            base = remove_zero_terms(base, divisor)

            can_remove_mod = statically_known(base >= 0) and statically_known(
                base < modulus * divisor
            )

            if can_remove_mod:
                return FloorDiv(base, divisor)
            return ModularIndexing(base, divisor, modulus)

        if expr.has(ModularIndexing):
            expr = expr.replace(
                ModularIndexing(
                    sympy.Wild("base", integer=True),
                    sympy.Wild("divisor", integer=True),
                    sympy.Wild("modulus", integer=True),
                ),
                visit_modular_indexing,
            )

        if expr.has(FloorDiv):
            expr = expr.replace(
                FloorDiv(
                    sympy.Wild("base", integer=True),
                    sympy.Wild("divisor", integer=True),
                ),
                visit_indexing_div,
            )

        # Keep simplifying until a fixed point is reached.
        if expr != original_expr:
            return self._simplify_with_ranges(expr, var_ranges)
        return expr

    def _simplify_loops_impl(
        self, index_vars: list[sympy.Symbol], sizes, index_formulas
    ):
        """
        Try to remove as many axis from loop iterations as possible, by:
            1) removing size==1 dimensions
            2) fuse contiguous dimensions into a single loop

        Returns ``(new_sizes, reindex, prune)``:
        - ``new_sizes`` lists the surviving loop sizes.
        - ``reindex`` maps an index over ``new_sizes`` back to an index over
          the original ``sizes``, inserting zeros for removed dims.
        - ``prune`` drops the removed dims from an index over the original
          ``sizes``.
        """
        sizes = list(map(self.simplify, sizes))

        strides = [
            # index_formulas may contain boolean expressions (e.g. s0 < 10),
            # for which "strides" don't make sense so we ignore them here.
            # NOTE: These expressions may still block merging dims in the sound
            # substitution test performed in can_merge_dims.
            (
                self.stride_vars(x, index_vars)
                if isinstance(x, sympy.Expr)
                else [0] * len(index_vars)
            )
            for x in index_formulas
        ]
        # With no index formulas there are no strides to check against (and
        # nothing constrains merging), so skip the check instead of hitting
        # IndexError on strides[0].
        assert not strides or len(sizes) == len(strides[0]), (
            len(sizes),
            len(strides[0]),
        )

        for i in range(len(sizes)):
            if sizes[i] == 1:
                # remove dim
                sizes[i] = None

        def can_merge_dims(a, b):
            for k in range(len(strides)):
                if self.simplify(strides[k][a] * sizes[a]) == self.simplify(
                    strides[k][b]
                ):
                    # approximate test passed, try sound version
                    va = index_vars[a]
                    vb = index_vars[b]
                    m1 = sympy_index_symbol("_merge_tester1")
                    m2 = sympy_index_symbol("_merge_tester2")
                    # NOTE: can't sub vb=0 here in case va * vb appears in the expression,
                    # in which case both expr1 and expr2 would be zero!
                    expr1 = sympy_subs(index_formulas[k], {va: m1 * sizes[a], vb: m2})
                    expr2 = sympy_subs(index_formulas[k], {va: 0, vb: (m1 + m2)})
                    if self.simplify(expr1) == self.simplify(expr2):
                        continue
                return False
            return True

        changed = True
        while changed:
            changed = False
            for i, j in itertools.product(
                reversed(range(len(sizes))), reversed(range(len(sizes)))
            ):
                if i == j or sizes[i] is None or sizes[j] is None:
                    continue
                if can_merge_dims(i, j):
                    changed = True
                    sizes[i] = sizes[i] * sizes[j]
                    sizes[j] = None

        def reindex(index):
            it = list(reversed(index))
            new_index = []
            for size in sizes:
                if size is None:
                    new_index.append(sympy.S.Zero)
                else:
                    new_index.append(it.pop())
            assert not it
            return new_index

        def prune(index):
            assert len(index) == len(sizes)
            return [i for i, s in zip(index, sizes) if s is not None]

        return [x for x in sizes if x is not None], reindex, prune

    # Note - [On Statically Known]
    #
    # The statically_known_* family of functions below replaces a prior system, called maybe_guard_*. The prior system
    # operated by providing essentially a question, where the size hinted values were evaluated. If the condition was
    # true, we add a guard and return True, otherwise, False.
    #
    # def maybe_guard_foo(args):
    #   if size_hinted_check(args):
    #       return False # No guard, no optim
    #   guard(args) # Make a guard
    #   return True # Safe to apply optimization
    #
    # The prior system incurred a guard, and green lit an optimization.
    #
    # The new system works in reverse - in the new system, if we know that the inputs are static, and evaluate the
    # condition as true, we green light the optimization, and we do not incur a guard. If we cannot prove that, we
    # return False.
    #
    # def maybe_guard_foo(args):
    #   if all_static(args):
    #       return True # Safe to apply optimization
    #   else:
    #       return False # No guard, no optim

    # See Note - [On Statically Known]

    def is_expr_static_and_true(self, expr: Union[sympy.Basic, bool]) -> bool:
        return evaluate_expr(self.shape_env, expr)

    def statically_known_equals(
        self, left: Union[Expr, int], right: Union[Expr, int]
    ) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left and right are equal.
        """
        return self.is_expr_static_and_true(sympy.Eq(left, right))  # type: ignore[arg-type]

    # See Note - [On Statically Known]
    def statically_known_list_equals(self, left: list[Expr], right: list[Expr]) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left and right lists are equal.
        """
        return len(left) == len(right) and all(
            self.statically_known_equals(l, r) for l, r in zip(left, right)
        )

    # See Note - [On Statically Known]
    def statically_known_leq(self, left: Expr, right: Union[Expr, int]) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left is less than or equal to right.
        """
        expr = left <= right
        return self.is_expr_static_and_true(expr)

    # See Note - [On Statically Known]
    def statically_known_geq(self, left: Expr, right: Union[Expr, int]) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left is greater than or equal to right.
        """
        expr = left >= right
        return self.is_expr_static_and_true(expr)

    # See Note - [On Statically Known]
    def statically_known_lt(self, left: Expr, right: Union[Expr, int]) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left is less than right.
        """
        expr = left < right
        return self.is_expr_static_and_true(expr)

    # See Note - [On Statically Known]
    def statically_known_gt(self, left: Expr, right: Union[Expr, int]) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left is greater than right.
        """
        expr = left > right
        return self.is_expr_static_and_true(expr)

    # See Note - [On Statically Known]
    def statically_known_multiple_of(
        self, numerator: Expr, denominator: Union[Expr, int]
    ) -> bool:
        """
        Return a bool indicating if it is sound to optimize for the numerator being a multiple of the denominator.
        """
        if free_unbacked_symbols(numerator) or free_unbacked_symbols(denominator):
            return False
        expr = sympy.Eq(numerator % denominator, 0)
        return self.is_expr_static_and_true(expr)  # type: ignore[arg-type]

    # See Note - [On Statically Known]
    def statically_known_power_of_2(self, expr: Expr) -> bool:
        """
        Returns a bool indicating if ``expr`` is a sympy Integer that is a power of 2.
        """
        return isinstance(expr, sympy.Integer) and is_power_of_2(int(expr))

    # The guard functions require you to ALREADY KNOW that a particular
    # condition holds.  If you don't know (you want to guard on an expression
    # being a particular value, and then get access to that value), use
    # the evaluate functions.

    def guard_equals(self, left: Expr, right: Expr) -> Expr:
        if isinstance(left, Expr):
            left = sympy_subs(left, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        if isinstance(right, Expr):
            right = sympy_subs(right, self.inv_precomputed_replacements)  # type: ignore[arg-type]

        expr = sympy.Eq(left, right)
        static_expr = self.shape_env._maybe_evaluate_static(expr)

        if static_expr is not None:
            assert bool(static_expr)
            return left

        assert self.shape_env.defer_runtime_assert(expr, "guard_equals")
        return left

    def guard_leq(self, left: Expr, right: Expr) -> None:
        """Guard on ``left <= right``, expressed as ``left < right + 1``."""
        return self.guard_lt(left, right + 1)

    def guard_lt(self, left: Expr, right: Expr) -> None:
        """Guard on ``left < right``; asserts if it is statically known false."""
        # ps* symbols cannot be guarded upon (see __init__), so invert any
        # precomputed replacements first, as guard_equals does.
        if isinstance(left, Expr):
            left = self.remove_precomputed_replacements(left)
        if isinstance(right, Expr):
            right = self.remove_precomputed_replacements(right)

        expr = sympy.Lt(left, right)
        static_expr = self.shape_env._maybe_evaluate_static(expr)

        if static_expr is not None:
            assert bool(static_expr)
            return

        assert self.shape_env.defer_runtime_assert(expr, "guard_lt")

    def guarded_order(self, seq):
        """
        Return the order of a sequence as a permutation of range(len(seq)) and guard on that order not changing.
        """
        seq = [*map(self.remove_precomputed_replacements, seq)]
        seq = [(self.size_hint(var), orig_idx, var) for orig_idx, var in enumerate(seq)]
        seq.sort()
        order = [-1] * len(seq)
        last_var = None
        for new_index, (_, orig_index, var) in enumerate(seq):
            order[orig_index] = new_index
            if last_var is not None:
                self.guard_leq(last_var, var)
            last_var = var
        return order

    # The evaluate functions evaluate some symbolic sympy expression
    # (NB: not necessarily an Expr) and return what the concrete result
    # is, guarding on the expression being that result

    # NB: write evaluate_expr(sympy.Lt(a, b)) rather than evaluate_expr(a < b)
    # as this will ensure that you actually have a sympy'ified expression,
    # and will prevent you from incorrectly writing evaluate_expr(a == b)
    # which does the wrong thing if a or b is a sympy expression
    def evaluate_expr(
        self,
        left: Union[Expr, sympy.logic.boolalg.Boolean],
        size_oblivious: bool = False,
    ) -> bool:
        assert isinstance(left, (Expr, sympy.logic.boolalg.Boolean)), type(left)
        return self.shape_env.evaluate_expr(
            sympy.sympify(left), size_oblivious=size_oblivious
        )

    def _evaluate_min_side(self, left: Expr, right: Expr) -> tuple[Expr, Expr, bool]:
        """
        Shared implementation of evaluate_min / evaluate_max.

        Returns ``(left, right, left_is_min)``. ``left`` and ``right`` have had
        precomputed replacements inverted. ``left_is_min`` says whether
        ``left`` was chosen (and, when hints exist, guarded on) as the smaller
        value.

        Raises:
            TypeError: if no order can be determined for unbacked symints.
        """
        if isinstance(left, Expr):
            left = sympy_subs(left, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        if isinstance(right, Expr):
            right = sympy_subs(right, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        try:
            lv = self.size_hint(left)
            rv = self.size_hint(right)
        except TypeError:  # unbacked symints
            if left == right or self.statically_known_leq(left, right):
                return left, right, True
            if self.statically_known_leq(right, left):
                return left, right, False
            gcd = sympy.gcd(left, right)
            if left == gcd:  # handle `min(10*u0, u0)` etc
                return left, right, True
            if right == gcd:
                return left, right, False
            raise TypeError(
                f"evaluate_min({left}, {right}) with unbacked symints"
            ) from None
        if lv <= rv:
            self.guard_leq(left, right)
            return left, right, True
        self.guard_leq(right, left)
        return left, right, False

    def evaluate_min(self, left: Expr, right: Expr) -> Expr:
        """return the smaller of left and right, and guard on that choice"""
        left, right, left_is_min = self._evaluate_min_side(left, right)
        return left if left_is_min else right

    def evaluate_max(self, left: Expr, right: Expr) -> Expr:
        """return the larger of left and right, and guard on that choice"""
        # Always choose the opposite of eval min for consistency
        # This means min(a, b) and max(a, b) produce the same guards.
        # Use the explicit side flag rather than an identity check against the
        # caller's arguments: min operates on substituted copies, so `is`
        # would pick the wrong side whenever a ps* symbol was inverted.
        left, right, left_is_min = self._evaluate_min_side(left, right)
        return right if left_is_min else left

    def evaluate_static_shape(self, left: Union[Expr, int]) -> int:
        if isinstance(left, int):
            return left
        right = self.size_hint(left)
        self.guard_equals(left, sympy.Integer(right))
        return int(right)

    def evaluate_static_shapes(self, left: Sequence[Union[Expr, int]]) -> list[int]:
        return [self.evaluate_static_shape(x) for x in left]

    def remove_precomputed_replacements(self, expr: Expr) -> Expr:
        """Substitute ps* symbols in ``expr`` back to the expressions they stand for."""
        if any(symbol_is_type(s, SymT.PRECOMPUTED_SIZE) for s in expr.free_symbols):  # type: ignore[attr-defined]
            return sympy_subs(expr, self.inv_precomputed_replacements)  # type: ignore[arg-type]
        return expr

    def symbolic_hint(self, expr: Union[Expr, int]) -> Union[Expr, int]:
        if isinstance(expr, int):
            return expr
        # Substitute all hints into expr, but leave unbacked symints alone
        expr = self.simplify(expr)
        if not isinstance(expr, Expr):
            assert isinstance(expr, int)
            return expr
        free_symbols = expr.free_symbols
        if not free_symbols:
            try:
                return int(expr)  # type: ignore[return-value]
            except TypeError:
                return expr  # inf/nan/I
        expr = self.remove_precomputed_replacements(expr)
        return sympy_subs(expr, self.var_to_val)

    def size_hint(
        self, expr: Union[Expr, int], *, fallback: Optional[int] = None
    ) -> int:
        out = self.symbolic_hint(expr)
        if not isinstance(out, (int, sympy.Integer)) and fallback is not None:
            # Use the provided heuristic fallback hint
            unbacked_sym_vrs = {
                s: self.shape_env.var_to_range.get(s, None) for s in out.free_symbols
            }
            if all(vr is not None for vr in unbacked_sym_vrs.values()):
                # Clamp the fallback into the statically known value range.
                hint_vr = bound_sympy(out, unbacked_sym_vrs)  # type: ignore[arg-type]
                if isinstance(hint_vr.lower, (int, sympy.Integer)):
                    fallback = max(fallback, int(hint_vr.lower))
                if isinstance(hint_vr.upper, (int, sympy.Integer)):
                    fallback = min(fallback, int(hint_vr.upper))
            return fallback

        try:
            return int(out)
        except Exception:
            log.debug("failed on: %s", out)
            raise

    def size_hints(
        self,
        exprs: Iterable[Expr],
        *,
        fallback: Optional[int] = None,
    ) -> tuple[int, ...]:
        return tuple(self.size_hint(x, fallback=fallback) for x in exprs)

    def _lru_cache(self, fn, maxsize=None):
        """
        Wrapper around functools.lru_cache that clears when replacements
        has been invalidated.
        """
        fn_cache = functools.lru_cache(maxsize)(fn)
        prior_len = len(self.replacements)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            nonlocal prior_len
            if prior_len != len(self.replacements):
                prior_len = len(self.replacements)
                fn_cache.cache_clear()
            return fn_cache(*args, **kwargs)

        return wrapper

    def make_stride_vars_cache(self):
        cache = self._lru_cache(self._stride_vars)

        def stride_vars(
            index: Expr,
            vars: Sequence[sympy.Symbol],
            support_vars: Optional[Sequence[sympy.Symbol]] = None,
        ) -> list[Expr]:
            if not support_vars:
                support_vars = vars
            # Tuples so the arguments are hashable for the lru_cache.
            return cache(index, tuple(vars), tuple(support_vars))

        return stride_vars

    def _stride_vars(
        self,
        index: Expr,
        vars: Sequence[sympy.Symbol],
        support_vars: Sequence[sympy.Symbol],
    ) -> list[Expr]:
        """Convert an indexing expression back into strides

        NOTE: This is only valid if the index is a standard strided offset
        calculation. e.g. 10 * ModularIndexing(i0 + 1, 1, 2) would give a
        stride of -10 because the index wraps around after the first element

        """
        strides = []
        index = self.simplify(index)
        # remove any offset
        index = index - sympy_subs(
            index, {v: sympy.S.Zero for v in support_vars if v != 0}
        )
        for i in range(len(vars)):
            # drop all the other dims
            index_dim = sympy_subs(
                index,
                {
                    support_vars[j]: sympy.S.Zero
                    for j in range(len(support_vars))
                    if vars[i] != support_vars[j] and support_vars[j] != 0
                },
            )
            v = vars[i]
            if v == 0:
                strides.append(sympy.S.Zero)
            else:
                # TODO(jansel): should we use sympy.diff here?
                strides.append(
                    sympy_subs(index_dim, {v: sympy.S.One})
                    - sympy_subs(index_dim, {v: sympy.S.Zero})
                )
        return strides

    def atomically_apply_size_hint(
        self, expr: Union[Expr, int], *, fallback: Optional[int] = None
    ) -> Union[Expr, int]:
        if isinstance(expr, int):
            return int(expr)

        # For multiple expressions that depend on an unbacked symint,
        # we want to compute them consistently for a size hint we have chosen.
        # So, recursively compute expressions via size hints of contained symbols.
        # For example: u1 * u2 - 10 ==> fallback * fallback - 10
        assert isinstance(expr, Expr), type(expr)
        free_symbols = expr.free_symbols
        # Use this allocator's own hints rather than reaching through
        # V.graph.sizevars, so the method also works outside a graph context.
        size_dict = {
            symbol: self.size_hint(symbol, fallback=fallback)
            for symbol in free_symbols
        }
        return expr.subs(size_dict)

    def offset_var(self, index: Expr, vars: list[sympy.Symbol]) -> Expr:
        """Extract offset part of an indexing expression"""
        index = self.simplify(index)
        return sympy_subs(index, {v: sympy.S.Zero for v in vars if v != 0})

    def stride_hints(
        self,
        index: Expr,
        vars: Sequence[sympy.Symbol],
        support_vars: Optional[Sequence[sympy.Symbol]] = None,
    ) -> list[int]:
        for v in index.free_symbols:
            if symbol_is_type(v, SymT.INDIRECT):  # type: ignore[attr-defined]
                index = sympy_subs(index, {v: 0})  # type: ignore[dict-item]
        result = []
        for s in self.stride_vars(index, vars, support_vars):
            try:
                result.append(self.size_hint(s))
            except TypeError:
                result.append(0)
        return result

    def stride_order(self, index: Expr, vars: list[sympy.Symbol]) -> list[int]:
        strides = tuple(map(abs, self.stride_hints(index, vars)))
        order = list(range(len(strides)))
        # Zero strides sort last; the rest sort by increasing magnitude.
        order.sort(key=lambda x: (strides[x] == 0, strides[x]))
        return order

    def lookup_precomputed_size(self, expr: Expr) -> Expr:
        if (
            isinstance(expr, (int, sympy.Symbol, sympy.Number))
            or expr.is_number
            or expr.is_symbol
        ):
            return expr
        expr = self.remove_precomputed_replacements(expr)
        if expr not in self.precomputed_replacements:
            sym = sympy_index_symbol_with_prefix(
                SymT.PRECOMPUTED_SIZE, len(self.precomputed_replacements)
            )
            self.precomputed_replacements[expr] = sym
            self.inv_precomputed_replacements[sym] = expr
        return self.precomputed_replacements[expr]

    def free_symbols(self) -> OrderedSet[sympy.Symbol]:
        return OrderedSet(self.var_to_val.keys()) - OrderedSet(self.replacements.keys())

    def combine_modular_indexing_pairs(self, index: sympy.Expr) -> sympy.Expr:
        """
        A pair of special ModularIndexing can be combined.

        E.g. ModularIndexing(ModularIndexing(x, 1, a), 1, b)
        We can simplify this to ModularIndexing(x, 1, b), if
        1. x is non negative integer
        2. a and b are positive integers
        3. a is a multiple of b.
        """

        def _check_args(x, div, mod, is_first):
            if not isinstance(div, sympy.Integer) or not isinstance(mod, sympy.Integer):
                return False
            if div != 1:
                return False
            if mod <= 0:
                return False

            if is_first:
                # first ModularIndexing should contain a nested ModularIndex
                if not isinstance(x, ModularIndexing):
                    return False
            else:
                # second ModularIndexing should contain a non-negative
                # symbol
                if not isinstance(x, sympy.Symbol) or not self.statically_known_geq(
                    x, 0
                ):
                    return False
            return True

        if isinstance(index, ModularIndexing):
            x, div, mod = index.args

            if not _check_args(x, div, mod, True):
                return index

            x2, div2, mod2 = x.args

            if not _check_args(x2, div2, mod2, False):
                return index

            if mod2 % mod != 0:
                return index

            return ModularIndexing(x2, 1, mod)

        return index

    def expand_floor_div(
        self, index: sympy.Expr
    ) -> Union[bool, tuple[sympy.Expr, sympy.Expr]]:
        """
        Expand the FloorDiv to the entire expression so that the expression may
        be simplified.

        E.g., for a 2D contiguous tensor with shape [a, 2 * b], and index variables
        x1, x2, index expression 'x1 * 2b + x2' can be easily combined.
        But index expression 'x1 * b + x2 // 2' can not.
        By expanding the FloorDiv to the entire expression, we get
        '(x1 * 2b + x2) // 2'. This transformation allows us to merge loops
        for the numerator!

        Return False if this optimization cannot be applied;
        Return the new expression and the denominator otherwise.
        The original expression will be equivalent to 'new_expression // denominator'
        """
        if not isinstance(index, sympy.Add):
            return False
        terms = index.args

        if len(terms) < 2:
            return False
        floor_div_index = -1
        varlist = []
        factorlist = []
        for idx, term in enumerate(terms):
            if isinstance(term, sympy.Mul):
                # For dynamic shape, term like '2*s1*x1' has 3 child nodes.
                # - A integer for 2
                # - A symbol for s1
                # - A symbol for x1
                # Skip for now.
                if len(term.args) != 2:
                    return False
                factor, var = term.args
                varlist.append(var)
                factorlist.append(factor)
                if not isinstance(factor, sympy.Integer) or not isinstance(
                    var, sympy.Symbol
                ):
                    return False
                # It's easier to reason about the correctness of the transformation
                # for non-negative integers.
                if not self.statically_known_geq(var, 0):
                    return False
            elif isinstance(term, FloorDiv):
                var, factor = term.args
                if not isinstance(factor, sympy.Integer) or not isinstance(
                    var, sympy.Symbol
                ):
                    return False
                if not self.statically_known_geq(var, 0):
                    return False
                if floor_div_index >= 0:
                    # can not handle multi FloorDiv yet
                    return False

                floor_div_index = idx
                varlist.append(var)
                # this factor is denominator
                factorlist.append(factor)
            else:
                return False

        if floor_div_index < 0:
            return False

        # Construct the new expression and remember the denominator
        denominator = factorlist[floor_div_index]
        new_index = sympy.S.Zero

        for idx, (var, factor) in enumerate(zip(varlist, factorlist)):
            if idx == floor_div_index:
                new_index += var
            else:
                new_index += (factor * denominator) * var

        return new_index, denominator
|
|
|
|
|
|
def join_dimensions(expr: Expr) -> Expr:
    """
    Merge ModularIndexing/FloorDiv terms of a sum that describe one larger
    dimension (see _join_dimensions_cached for the patterns).

    Only a sympy Add that contains a ModularIndexing can possibly be joined.
    Anything else is returned unchanged without touching the cache.
    """
    if isinstance(expr, sympy.Add) and expr.has(ModularIndexing):
        return _join_dimensions_cached(expr)
    return expr  # fast exit path
|
|
|
|
|
|
@functools.lru_cache(256)
def _join_dimensions_cached(expr: Expr) -> Expr:
    """
    ModularIndexing(i0, 1, 32) + 32 * ModularIndexing(i0, 32, 4)
    becomes
    ModularIndexing(i0, 1, 128)
    ModularIndexing(i0, 1, 32) + 32 * FloorDiv(i0, 32)
    becomes i0


    This type of pattern can come from view operations
    """
    assert isinstance(expr, sympy.Add)

    scale = sympy.Wild("scale", exclude=[0], integer=True)
    base = sympy.Wild("base", integer=True)
    divisor = sympy.Wild("divisor", integer=True)
    mod1 = sympy.Wild("modulus", integer=True)
    mod2 = sympy.Wild("modulus2", integer=True)

    # The "low" half of a split dimension: scale * ModularIndexing(base, divisor, mod1).
    low_pattern = scale * ModularIndexing(base, divisor, mod1)
    terms = expr.args

    # Pass 1: a low ModularIndexing term plus the matching "high"
    # ModularIndexing term fuse into one ModularIndexing with a larger modulus.
    for low in terms:
        found = low.match(low_pattern)
        if not found:
            continue
        k, b, d, m = found[scale], found[base], found[divisor], found[mod1]
        high_pattern = k * m * ModularIndexing(b, d * m, mod2)
        for high in terms:
            high_match = high.match(high_pattern)
            if high_match and low != high:
                return join_dimensions(
                    expr - low - high + k * ModularIndexing(b, d, m * high_match[mod2])
                )

    # Pass 2: a low ModularIndexing term plus the matching FloorDiv of the
    # same base fuse into a single FloorDiv.
    for low in terms:
        found = low.match(low_pattern)
        if not found:
            continue
        k, b, d, m = found[scale], found[base], found[divisor], found[mod1]
        high_pattern = k * m * FloorDiv(b, d * m)
        for high in terms:
            # The pattern has no wildcards, so success yields an empty dict.
            if high.match(high_pattern) is not None:
                return join_dimensions(expr - low - high + k * FloorDiv(b, d))

    return expr
|
|
|
|
|
|
class SimplifyIndexing(V.WrapperHandler):  # type: ignore[name-defined]
    """
    Ops-handler wrapper that first simplifies every index expression with
    ``V.graph.sizevars.simplify_with_ranges``, using the ``var_ranges`` given
    at construction. This lets ModularIndexing/FloorDiv terms be simplified.
    It then forwards the call to the wrapped handler.
    """

    def __init__(self, inner, var_ranges: VarRanges) -> None:
        super().__init__(inner)
        self.name = "SimplifyIndexing"

        def _simplify_index(index: Expr) -> Expr:
            # V.graph is resolved on every call, not captured here.
            return V.graph.sizevars.simplify_with_ranges(index, var_ranges)

        self._simplify: Callable[[Expr], Expr] = _simplify_index

    def load(self, name: str, index: sympy.Expr):
        simplified = self._simplify(index)
        return self._inner.load(name, simplified)

    def store(self, name, index, value, mode=None):
        simplified = self._simplify(index)
        return self._inner.store(name, simplified, value, mode=mode)

    def store_reduction(self, name, index, value):
        simplified = self._simplify(index)
        return self._inner.store_reduction(name, simplified, value)

    def index_expr(self, index, dtype):
        simplified = self._simplify(index)
        return self._inner.index_expr(simplified, dtype)

    def check_bounds(self, index, size, lower, upper):
        simplified = self._simplify(index)
        return self._inner.check_bounds(simplified, size, lower, upper)
|