unify symbolic_shapes and sizevars dynamic shapes APIs naming 1 (#154774)

Inductor have a set of APIs that allows performing symbolic evaluations similar to that of symbolic shapes
but it operates on sympy expressions instead of symnodes. Namings are not consistent making them consistent
in this stack.

Step 1 : unify statically_know_true naming! for consistent experience.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154774
Approved by: https://github.com/drisspg, https://github.com/bobrenjc93, https://github.com/eellison
This commit is contained in:
Laith Sakka
2025-06-11 18:47:29 -07:00
committed by PyTorch MergeBot
parent 9df2e8020f
commit f4376cac54
6 changed files with 38 additions and 59 deletions

View File

@ -35,7 +35,7 @@ from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
from .ops_handler import DefaultHandler
from .sizevars import evaluate_expr
from .sizevars import statically_known_true
from .utils import generate_assert
from .virtualized import V
@ -322,7 +322,7 @@ class IndexPropagation(DefaultHandler):
for k, v in self.indirect_var_ranges.items()
),
)
return evaluate_expr(self.shape_env, e, self.axioms, var_to_range)
return statically_known_true(self.shape_env, e, self.axioms, var_to_range)
def indirect_indexing(
self,

View File

@ -633,7 +633,7 @@ class IRNode:
return sympy_product(self.get_size())
def is_zero_elements(self) -> bool:
return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0))
return V.graph.sizevars.statically_known_true(sympy.Eq(self.get_numel(), 0))
def realize(self) -> Optional[str]:
"""
@ -1664,7 +1664,7 @@ class Reduction(Loops):
reindex = View.dynamic_reshape_indexer(
reduction_ranges, [reduction_numel], dense_index
)
need_mask = not V.graph.sizevars.is_expr_static_and_true(
need_mask = not V.graph.sizevars.statically_known_true(
sympy.Eq(reduction_numel % split, 0)
)
@ -2113,7 +2113,7 @@ class WelfordReduction(MultiOutputReduction):
recursively
"""
reduction_numel = sympy_product(reduction_ranges)
need_mask = not V.graph.sizevars.is_expr_static_and_true(
need_mask = not V.graph.sizevars.statically_known_true(
sympy.Eq(reduction_numel % split, 0)
)
@ -2296,7 +2296,7 @@ class Scan(Loops):
assert len(dtypes) == len(inner_fns)
# Scan with a single element is just a copy
if sizevars.is_expr_static_and_true(sympy.Le(scan_numel, 1)):
if sizevars.statically_known_true(sympy.Le(scan_numel, 1)):
return [
Pointwise.create(
device=device,
@ -2496,7 +2496,7 @@ class Sort(Loops):
max_rblock = 512
is_persistent_kernel = (
config.triton.persistent_reductions
and sizevars.is_expr_static_and_true(sympy.Le(sort_numel, max_rblock))
and sizevars.statically_known_true(sympy.Le(sort_numel, max_rblock))
)
if not is_persistent_kernel:
# We only support persistent triton kernels
@ -2505,7 +2505,7 @@ class Sort(Loops):
assert len(dtypes) == len(inner_fns)
# Sort with a single element is just a copy
if sizevars.is_expr_static_and_true(sympy.Le(sort_numel, 1)):
if sizevars.statically_known_true(sympy.Le(sort_numel, 1)):
return [
Pointwise.create(
device=device,
@ -4057,7 +4057,7 @@ class Buffer(IRNode, CodegenSymbol):
)
def is_zero_elements(self): # type: ignore[no-untyped-def]
return V.graph.sizevars.is_expr_static_and_true(sympy.Eq(self.get_numel(), 0))
return V.graph.sizevars.statically_known_true(sympy.Eq(self.get_numel(), 0))
def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
# Loading from a zero-element buffer is a no-op

View File

@ -750,7 +750,7 @@ def tuned_mm(mat1, mat2, *, layout=None):
k_splits = get_k_splits(m, n, k)
for k_split in k_splits:
if not V.graph.sizevars.is_expr_static_and_true(
if not V.graph.sizevars.statically_known_true(
sympy.Eq(sympy.Mod(k, k_split), 0)
):
continue

View File

@ -8,11 +8,7 @@ from typing import Any, Callable, cast, Optional, Union
import sympy
from sympy import Expr
from torch.fx.experimental.symbolic_shapes import (
free_unbacked_symbols,
has_free_unbacked_symbols,
ShapeEnv,
)
from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols, ShapeEnv
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
from torch.utils._sympy.symbol import symbol_is_type, SymT
@ -32,7 +28,7 @@ from .virtualized import V
log = logging.getLogger(__name__)
def evaluate_expr(
def statically_known_true(
shape_env: ShapeEnv,
expr: Union[sympy.Basic, bool],
axioms: Optional[tuple[sympy.Expr]] = None,
@ -308,33 +304,16 @@ class SizeVarAllocator:
return [x for x in sizes if x is not None], reindex, prune
# Note - [On Statically Known]
#
# The statically_known_* family of functions below replaces a prior system, called maybe_guard_*. The prior system
# operated by providing essentially a question, where the size hinted values were evaluated. If the condition was
# true, we add a guard and return True, otherwise, False.
#
# def maybe_guard_foo(args):
# if size_hinted_check(args):
# return False # No guard, no optim
# guard(args) # Make a guard
# return True # Safe to apply optimization
#
# The prior system incurred a guard, and green lit an optimization.
#
# The new system works in reverse - in the new system, if we know that the inputs are static, and evaluate the
# condition as true, we green light the optimization, and we do not incur a guard. If we cannot prove that, we
# return False.
#
# def maybe_guard_foo(args):
# if all_static(args):
# return True # Safe to apply optimization
# else:
# return False # No guard, no optim
# See Note - [On Statically Known]
def is_expr_static_and_true(self, expr: Union[sympy.Basic, bool]) -> bool:
return evaluate_expr(self.shape_env, expr)
# The statically_known_* family of functions below NEVER guard, they could return True if the
# asked questions can be answered without guarding otherwise they return False.
# Those are similar to statically_known_true in symbolic_shapes but operate on sympy
# expressions instead of symnodes.
def statically_known_true(self, expr: Union[sympy.Basic, bool]) -> bool:
"""
Returns true if an expression is always true (symbolically or via guards),
false otherwise. Never add guards, or throw data dependent errors.
"""
return statically_known_true(self.shape_env, expr)
def statically_known_equals(
self, left: Union[Expr, int], right: Union[Expr, int]
@ -342,9 +321,8 @@ class SizeVarAllocator:
"""
Returns a bool indicating if it is sound to optimize as if left and right are equal.
"""
return self.is_expr_static_and_true(sympy.Eq(left, right)) # type: ignore[arg-type]
return self.statically_known_true(sympy.Eq(left, right)) # type: ignore[arg-type]
# See Note - [On Statically Known]
def statically_known_list_equals(self, left: list[Expr], right: list[Expr]) -> bool:
"""
Returns a bool indicating if it is sound to optimize as if left and right lists are equal.
@ -353,51 +331,48 @@ class SizeVarAllocator:
self.statically_known_equals(l, r) for l, r in zip(left, right)
)
# See Note - [On Statically Known]
def statically_known_leq(self, left: Expr, right: Union[Expr, int]) -> bool:
"""
Returns a bool indicating if it is sound to optimize as if left is less than or equal to right.
"""
expr = left <= right
return self.is_expr_static_and_true(expr)
return self.statically_known_true(expr)
# See Note - [On Statically Known]
def statically_known_geq(self, left: Expr, right: Union[Expr, int]) -> bool:
"""
Returns a bool indicating if it is sound to optimize as if left is greater than or equal to right.
"""
expr = left >= right
return self.is_expr_static_and_true(expr)
return self.statically_known_true(expr)
# See Note - [On Statically Known]
def statically_known_lt(self, left: Expr, right: Union[Expr, int]) -> bool:
"""
Returns a bool indicating if it is sound to optimize as if left is less than right.
"""
expr = left < right
return self.is_expr_static_and_true(expr)
return self.statically_known_true(expr)
# See Note - [On Statically Known]
def statically_known_gt(self, left: Expr, right: Union[Expr, int]) -> bool:
"""
Returns a bool indicating if it is sound to optimize as if left is greater than right.
"""
expr = left > right
return self.is_expr_static_and_true(expr)
return self.statically_known_true(expr)
# See Note - [On Statically Known]
def statically_known_multiple_of(
self, numerator: Expr, denominator: Union[Expr, int]
) -> bool:
"""
Return a bool indicating if it is sound to optimize for the numerator being a multiple of the denominator.
"""
if free_unbacked_symbols(numerator) or free_unbacked_symbols(denominator):
# The reason we skip unbacked here is that we want to avoid the cost of trying to eval this symbolically.
if has_free_unbacked_symbols(numerator) or has_free_unbacked_symbols(
denominator
):
return False
expr = sympy.Eq(numerator % denominator, 0)
return self.is_expr_static_and_true(expr) # type: ignore[arg-type]
return self.statically_known_true(expr) # type: ignore[arg-type]
# See Note - [On Statically Known]
def statically_known_power_of_2(self, expr: Expr) -> bool:
"""
Returns a bool indicating if x is known to be a power of 2.
@ -454,6 +429,9 @@ class SizeVarAllocator:
last_var = var
return order
# Similar to the functions guard_or_false/guard_or_true in symbolic_shapes but operates on sympy
# expressions instead of symnodes. see Note [guard_or_].
def guard_or_false(self, left):
return self.evaluate_expr(left, fallback_value=False)

View File

@ -1593,7 +1593,7 @@ def use_decompose_k_choice(m: _IntLike, n: _IntLike, k: _IntLike) -> bool:
from torch._inductor.virtualized import V
return (
V.graph.sizevars.is_expr_static_and_true(
V.graph.sizevars.statically_known_true(
sympy.And(
sympy.Ge(k, decompose_k_threshold * m),
sympy.Ge(k, decompose_k_threshold * n),
@ -2757,7 +2757,7 @@ def expr_fits_within_32bit(e: sympy.Expr) -> bool:
# Allow for unhinted e as long as we can still statically prove
# (e.g., via ValueRanges) that it is still in bounds
if V.graph.sizevars.is_expr_static_and_true(e <= int_max):
if V.graph.sizevars.statically_known_true(e <= int_max):
return True
# Otherwise, the hint MUST exist and be in range
return has_hint(e) and size_hint(e) <= int_max

View File

@ -1335,6 +1335,7 @@ def compute_unbacked_bindings(
return symbol_to_path
# Note [guard_or_]
# The following two functions are common utilities used while defining unbacked semantics
# of various framework code. Those would be used in situations you prefer to guard and know
# the result of the expression over not guarding, but in case you hit a data dependent error