mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283 Split this directory into two PRs to keep them from being too large. Test plan: dmypy restart && python3 scripts/lintrunner.py -a pyrefly check step 1: delete lines in the pyrefly.toml file from the project-excludes field step 2: run pyrefly check step 3: add suppressions, clean up unused suppressions before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199 after: INFO 0 errors (6,884 ignored) Pull Request resolved: https://github.com/pytorch/pytorch/pull/165062 Approved by: https://github.com/oulgen, https://github.com/mlazos
383 lines
13 KiB
Python
383 lines
13 KiB
Python
# mypy: allow-untyped-defs
|
|
"""This file implements the IndexPropagation ops handler, which wraps an
|
|
underlying handler to add a limited form of constant propagation, as well as
|
|
propagation of sympy expressions downstream of ops.index_expr calls.
|
|
|
|
For example, say we have the IR:
|
|
|
|
tmp0 = ops.index_expr(x, torch.int32)
|
|
tmp1 = ops.constant(2, torch.int32)
|
|
tmp2 = ops.mul(tmp0, tmp1)
|
|
tmp3 = ops.indirect_indexing(tmp2, x_size)
|
|
tmp4 = ops.load("buf0", tmp3)
|
|
|
|
The underlying handler would just see:
|
|
|
|
ops.load("buf0", x * 2)
|
|
|
|
This is limited by the operations that SymPyOps (below) can express as SymPy
expressions, and that the sympy expression printers can handle. Ops that
SymPyOps does not define, ops for which it returns NotImplemented, and results
that are neither constants nor integer-valued expressions are forwarded to the
underlying handler instead.
|
|
|
|
"""
|
|
|
|
import itertools
|
|
from collections.abc import Sequence
|
|
from dataclasses import dataclass
|
|
from typing import Any, Literal, Optional, overload, Union
|
|
from typing_extensions import TypeAlias
|
|
|
|
import sympy
|
|
|
|
import torch
|
|
from torch._prims_common import dtype_to_type, is_integer_dtype
|
|
from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where
|
|
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
|
|
|
|
from .ops_handler import DefaultHandler
|
|
from .sizevars import statically_known_true
|
|
from .utils import generate_assert
|
|
from .virtualized import V
|
|
|
|
|
|
# A value that TypedExpr can carry: a SymPy expression or a plain Python scalar.
_ExprType: TypeAlias = Union[sympy.Expr, float, int, bool]
|
|
|
|
|
|
def _is_constant(val: _ExprType):
    """Return whether ``val`` is a plain number rather than a symbolic expression.

    Non-SymPy values count as constants exactly when they are Python
    ``int``/``float``/``bool``; SymPy objects defer to SymPy's ``is_number``.
    """
    if not isinstance(val, sympy.Basic):
        return isinstance(val, (int, float, bool))
    return val.is_number
|
|
|
|
|
|
def upper_bound(val: _ExprType):
    """Return an upper bound for ``val``.

    SymPy expressions are bounded with ``bound_sympy``; any other value (a
    plain Python scalar) is its own bound and is returned unchanged.
    """
    if not isinstance(val, sympy.Expr):
        return val
    return bound_sympy(val).upper
|
|
|
|
|
|
@dataclass
class TypedExpr:
    """A SymPy expression (or Python scalar) tagged with a torch dtype.

    Constant values are normalized on construction: they are converted to the
    Python type corresponding to ``dtype`` and, for integer dtypes, wrapped
    into that dtype's range with two's-complement wraparound. Symbolic
    expressions are stored untouched.
    """

    expr: _ExprType
    dtype: torch.dtype

    def is_constant(self):
        """Return whether ``expr`` is a number rather than a symbolic expression."""
        return _is_constant(self.expr)

    def __post_init__(self):
        if not _is_constant(self.expr):
            # Symbolic expressions are kept exactly as given.
            return

        value = self.expr
        if isinstance(value, sympy.Expr):
            value = value.expand(identity=True)
        value = dtype_to_type(self.dtype)(value)

        if is_integer_dtype(self.dtype):
            nbits = torch.iinfo(self.dtype).bits
            # Signed values are shifted so the representable range starts at 0,
            # reduced modulo 2**bits, then shifted back; unsigned values need no
            # shift.
            offset = 2 ** (nbits - 1) if self.dtype.is_signed else 0
            value = (value + offset) % 2**nbits - offset

        self.expr = value
|
|
|
|
|
|
def _modular_indexing_if_sign_safe(x: TypedExpr, y: TypedExpr) -> TypedExpr:
    """Lower an integer ``x mod y`` to ``ModularIndexing`` when signs make it safe.

    Truncating (C/C++-style, sign of the dividend) and flooring (Python-style,
    sign of the divisor) remainders only disagree when the operands have
    different signs. ``ModularIndexing(x, 1, y)`` is therefore produced only
    when SymPy can prove ``x >= 0 and y > 0`` or ``x < 0 and y <= 0``, in which
    case both conventions give the same answer. Otherwise, or when the promoted
    dtype is not an integer dtype, ``NotImplemented`` is returned so the caller
    falls back to the wrapped handler.

    Kept at module level rather than on SymPyOps so it can never be mistaken
    for an op: IndexPropagation dispatches any op name for which
    ``hasattr(SymPyOps, name)`` holds.
    """
    result_type = torch.promote_types(x.dtype, y.dtype)
    if not is_integer_dtype(result_type):
        return NotImplemented

    x_expr = sympy.sympify(x.expr)
    y_expr = sympy.sympify(y.expr)
    # SymPy assumptions are three-valued (True/False/None); None means
    # "unknown", which must not be treated as a proof either way.
    if (
        x_expr.is_nonnegative is not None
        and x_expr.is_nonnegative == y_expr.is_positive
    ):
        result_expr = ModularIndexing(x.expr, sympy.S.One, y.expr)
        return TypedExpr(result_expr, result_type)
    return NotImplemented


class SymPyOps:
    """An ops handler where all IR values are SymPy expressions (``TypedExpr``).

    Each method builds the SymPy equivalent of the op with the same name. When
    a value cannot be represented as a SymPy expression, the method is either
    not defined, or returns ``NotImplemented``; IndexPropagation then falls
    back to the wrapped handler.
    """

    @staticmethod
    def identity(value: Any) -> Any:
        return value

    @staticmethod
    def constant(value: Union[int, float, bool], dtype: torch.dtype) -> TypedExpr:
        return TypedExpr(value, dtype)

    @staticmethod
    def index_expr(value: Union[sympy.Expr, int], dtype: torch.dtype) -> TypedExpr:
        return TypedExpr(value, dtype)

    @staticmethod
    def to_dtype(
        value: TypedExpr,
        dtype: torch.dtype,
        src_dtype: Optional[torch.dtype] = None,
        use_compute_types: bool = False,
    ) -> TypedExpr:
        # Only the dtype tag changes; TypedExpr re-normalizes constants for it.
        return TypedExpr(value.expr, dtype)

    @staticmethod
    def abs(x: TypedExpr) -> TypedExpr:
        return TypedExpr(abs(x.expr), x.dtype)  # type: ignore[arg-type]

    @staticmethod
    def square(x: TypedExpr) -> TypedExpr:
        return TypedExpr(x.expr * x.expr, x.dtype)

    @staticmethod
    def add(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        result_type = torch.promote_types(x.dtype, y.dtype)
        return TypedExpr(x.expr + y.expr, result_type)

    @staticmethod
    def sub(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        result_type = torch.promote_types(x.dtype, y.dtype)
        return TypedExpr(x.expr - y.expr, result_type)

    @staticmethod
    def mul(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        result_type = torch.promote_types(x.dtype, y.dtype)
        return TypedExpr(x.expr * y.expr, result_type)

    @staticmethod
    def neg(x: TypedExpr) -> TypedExpr:
        return TypedExpr(-x.expr, x.dtype)

    @staticmethod
    def floordiv(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        """Floor division via ``FloorDiv``; integer dtypes only, else NotImplemented."""
        result_type = torch.promote_types(x.dtype, y.dtype)
        if not is_integer_dtype(result_type):
            return NotImplemented

        return TypedExpr(FloorDiv(x.expr, y.expr), result_type)

    @staticmethod
    def mod(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        """Integer modulus, propagated only when the operand signs are provably safe.

        Lowering unconditionally to ``ModularIndexing`` ignores the sign
        convention. Assuming ``ModularIndexing`` folds constants with Python's
        flooring ``%`` (NOTE(review): confirm), ``mod(-7, 3)`` would fold to 2,
        whereas a truncating modulus yields -1. Sharing ``remainder``'s sign
        guard makes the result correct whichever convention ``ops.mod``
        follows; unprovable cases fall back to the wrapped handler.
        """
        return _modular_indexing_if_sign_safe(x, y)

    @staticmethod
    def remainder(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        """Integer remainder, propagated only when Python and C++ remainders agree.

        See ``_modular_indexing_if_sign_safe`` for the exact sign conditions.
        """
        return _modular_indexing_if_sign_safe(x, y)

    @staticmethod
    def minimum(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        result_type = torch.promote_types(x.dtype, y.dtype)
        return TypedExpr(sympy.Min(x.expr, y.expr), result_type)

    @staticmethod
    def maximum(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        result_type = torch.promote_types(x.dtype, y.dtype)
        return TypedExpr(sympy.Max(x.expr, y.expr), result_type)
|
|
|
|
|
|
@dataclass
class IndexPropVar:
    """A value flowing through ``IndexPropagation``.

    When ``is_symbolic`` is False, ``value`` is an opaque IR value produced by
    the wrapped handler. When it is True, ``value`` is a ``TypedExpr`` that can
    still be combined symbolically.
    """

    value: Any  # opaque IR value, or a TypedExpr when is_symbolic is True
    is_symbolic: bool = False

    @staticmethod
    def new_symbolic(expr: TypedExpr) -> "IndexPropVar":
        """Wrap ``expr`` as a symbolic ``IndexPropVar``."""
        return IndexPropVar(value=expr, is_symbolic=True)

    def __post_init__(self):
        # Only symbolic vars are validated; opaque IR values can be anything.
        if self.is_symbolic:
            assert isinstance(self.value, TypedExpr), (
                "Symbolic IndexPropVar must contain a TypedExpr"
            )
|
|
|
|
|
|
# Result of an IndexPropagation op: a single IndexPropVar, or a (possibly
# nested) tuple of them when the wrapped handler returns a list/tuple.
IndexPropResult: TypeAlias = Union[IndexPropVar, tuple["IndexPropResult", ...]]
|
|
|
|
|
|
class IndexPropagation(DefaultHandler):
    """Ops wrapper that tries to propagate constant and index_expr values through the computation.

    This aims to maximize the compile time simplification possible, and convert
    indirect indexing from arange into normal static indexing.

    Values are carried as ``IndexPropVar``. An op is evaluated symbolically
    with ``SymPyOps`` when SymPyOps defines it and every ``IndexPropVar``
    argument is symbolic; otherwise it is forwarded to the wrapped handler with
    symbolic arguments materialized as ``constant``/``index_expr`` ops.
    """

    def __init__(
        self,
        inner: Any,
        iter_ranges: dict[sympy.Symbol, sympy.Expr],
        indirect_var_ranges: dict[sympy.Symbol, sympy.Expr],
    ) -> None:
        """
        Args:
            inner: the wrapped ops handler; receives every op that is not
                propagated symbolically.
            iter_ranges: maps each iteration symbol to its extent; each symbol
                is treated as lying in ``[0, extent)``.
            indirect_var_ranges: same mapping for indirect variables. Stored by
                reference and read on every ``statically_true`` call.
        """
        self._inner = inner
        self.shape_env = V.graph.sizevars.shape_env

        # Iteration variable x with extent v ranges over [0, upper_bound(v) - 1].
        var_to_range = {
            k: ValueRanges(0, upper_bound(v) - 1) for k, v in iter_ranges.items()
        }
        self.var_to_range = tuple(
            itertools.chain(self.shape_env.var_to_range.items(), var_to_range.items())
        )
        # NOTE: this is intentionally kept as a reference so the caller can
        # update it in-place
        self.indirect_var_ranges = indirect_var_ranges

        # Symbolic facts 0 <= x < extent for each iteration variable, plus the
        # shape env's own axioms.
        axioms = []
        for x, s in iter_ranges.items():
            axioms.append(0 <= x)
            axioms.append(x < s)
        self.axioms = tuple(axioms) + self.shape_env.get_axioms()

    def materialize_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> Any:
        """Emit ``expr`` through the wrapped handler.

        Constants become a ``constant`` op (converted to the Python type for
        ``dtype``); anything else becomes an ``index_expr`` op.
        """
        # Construct a new constant/index_expr from the SymPy expression
        if _is_constant(expr):
            val = dtype_to_type(dtype)(expr)
            return self._inner.constant(val, dtype)
        return self._inner.index_expr(expr, dtype)

    def unwrap(self, a: Union[Any, IndexPropVar]) -> Any:
        """Convert ``a`` into values the wrapped handler understands.

        Lists/tuples are converted element-wise and returned as tuples;
        symbolic IndexPropVars are materialized; non-symbolic IndexPropVars
        yield their IR value; anything else is returned unchanged.
        """
        if isinstance(a, (list, tuple)):
            return tuple(self.unwrap(v) for v in a)

        if not isinstance(a, IndexPropVar):
            return a

        # Prefer the sympy representation if possible
        if a.is_symbolic:
            return self.materialize_expr(a.value.expr, a.value.dtype)

        return a.value

    def wrap(self, a) -> IndexPropResult:
        """Wrap a wrapped-handler result as non-symbolic IndexPropVar(s).

        Lists/tuples are wrapped element-wise and returned as tuples.
        """
        if isinstance(a, (list, tuple)):
            return tuple(self.wrap(v) for v in a)
        return IndexPropVar(a)

    # Typing-only overloads: ``indirect_indexing`` is declared to produce a
    # single IndexPropVar (its ``.value`` is read in ``indirect_indexing``).
    @overload
    def fallback(
        self,
        name: Literal["indirect_indexing"],
        args: Sequence[Any],
        kwargs: dict[str, Any],
    ) -> IndexPropVar: ...

    @overload
    def fallback(
        self, name: str, args: Sequence[Any], kwargs: dict[str, Any]
    ) -> IndexPropResult: ...

    def fallback(
        self, name: str, args: Sequence[Any], kwargs: dict[str, Any]
    ) -> IndexPropResult:
        """Run op ``name`` on the wrapped handler and wrap its result.

        All arguments are unwrapped first, so symbolic values are materialized
        as ``constant``/``index_expr`` ops. The result is never symbolic.
        """
        # Fallback to the wrapped handler
        new_args = [self.unwrap(a) for a in args]
        new_kwargs = {k: self.unwrap(v) for k, v in kwargs.items()}
        return self.wrap(getattr(self._inner, name)(*new_args, **new_kwargs))

    def propagate_sympy(
        self, name: str, args: Sequence[Any], kwargs: dict[str, Any]
    ) -> IndexPropResult:
        """Evaluate op ``name`` with SymPyOps and keep the result symbolic if valid.

        The result stays symbolic only if SymPyOps did not return
        ``NotImplemented`` and the result is a constant or a SymPy expression
        whose ``is_integer`` is truthy. Otherwise the op is re-run through
        ``fallback``.
        """
        # Build a new SymPy expression from this ops call
        def unwrap(a: Union[Any, IndexPropVar]) -> Any:
            # Unlike self.unwrap, hand the TypedExpr itself to SymPyOps instead
            # of materializing it.
            if not isinstance(a, IndexPropVar):
                return a
            return a.value

        new_args = [unwrap(a) for a in args]
        new_kwargs = {k: unwrap(v) for k, v in kwargs.items()}
        new_expr = getattr(SymPyOps, name)(*new_args, **new_kwargs)
        is_valid_expr = new_expr is not NotImplemented and (
            # Inductor doesn't expect floating point in sympy expressions, but
            # allow floating point constants to be propagated
            new_expr.is_constant() or new_expr.expr.is_integer
        )
        if not is_valid_expr:
            return self.fallback(name, args, kwargs)
        return IndexPropVar.new_symbolic(new_expr)

    def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
        """Dispatch every op without a dedicated method on this class.

        Propagates symbolically when SymPyOps has an attribute ``name`` and
        every IndexPropVar among the arguments is symbolic; otherwise forwards
        to the wrapped handler.
        """
        if not hasattr(SymPyOps, name):
            return self.fallback(name, args, kwargs)

        var_arguments = [
            a
            for a in itertools.chain(args, kwargs.values())
            if isinstance(a, IndexPropVar)
        ]
        if not all(v.is_symbolic for v in var_arguments):
            return self.fallback(name, args, kwargs)

        return self.propagate_sympy(name, args, kwargs)

    def statically_true(self, e):
        """Return whether the boolean SymPy expression ``e`` can be proven true.

        The proof may use the shape env's value ranges, the iteration-variable
        ranges from ``iter_ranges``, the current ``indirect_var_ranges``
        (re-read on every call), and ``self.axioms``. A False result means
        "could not prove", not "known to be false": callers such as
        ``indirect_indexing`` handle the undecided case separately.

        FIXME I think this may not be entirely right, as we may not be able to
        use all runtime_asserts. If this is an issue, just use guards in
        `self.axioms`.

        The proper way of handling this would be to have a global shape_env that adds
        runtime_asserts as they happen in the code. Then, it should be used in SimplifyIndexing
        to perform wrap_expr and in CSEProxy.check_bounds to elide upper / lower bounds also
        for indirect_indexing
        """
        var_to_range = (
            *self.var_to_range,
            *(
                (k, ValueRanges(0, upper_bound(v) - 1))
                for k, v in self.indirect_var_ranges.items()
            ),
        )
        # pyrefly: ignore # bad-argument-type
        return statically_known_true(self.shape_env, e, self.axioms, var_to_range)

    def indirect_indexing(
        self,
        index: Union[Any, IndexPropVar],
        size: Any,
        check: bool = True,
        wrap_neg=True,
    ) -> Any:
        """Handle ``ops.indirect_indexing``.

        If ``index`` is symbolic, return it directly as a SymPy expression
        instead of creating an indirect variable. If ``wrap_neg`` is set,
        negative indices are wrapped by ``size``. If ``generate_assert(check)``
        is true, a ``check_bounds`` op is emitted for the bounds that could not
        be proven statically. Non-symbolic indices are forwarded to the wrapped
        handler, and the resulting IR value is returned.
        """
        if isinstance(index, IndexPropVar) and index.is_symbolic:
            # If we find something we can convert into a direct indexing we do so
            # We still need to (perhaps) wrap the expression and add bound checks
            # We want to do this "constant folding", as we don't allow to fuse
            # kernels into indirect indexing

            expr = sympy.sympify(index.value.expr)

            # TODO Perhaps move this logic to the simplify indexing pass
            def wrap_expr(expr):
                # Positive, negative, mixed
                # (if neither sign can be proven, select at runtime via Where)
                if self.statically_true(0 <= expr):
                    return expr
                elif self.statically_true(expr < 0):
                    return expr + size
                else:
                    return Where(expr < 0, expr + size, expr)

            # Sometimes it's easier to prove 0 <= expr than the weaker -size <= expr
            # NOTE: both bounds are proven on the expression *before* wrapping.
            can_prove_lower = self.statically_true(0 <= expr) or self.statically_true(
                -size <= expr
            )
            can_prove_upper = self.statically_true(expr < size)
            if wrap_neg:
                expr = wrap_expr(expr)
            if generate_assert(check):
                # Only request runtime checks for the bounds we could not prove.
                self.fallback(
                    "check_bounds",
                    (expr, size),
                    dict(lower=not can_prove_lower, upper=not can_prove_upper),
                )
            return expr

        indirect_var = self.fallback(
            "indirect_indexing", (index, size, check, wrap_neg), {}
        ).value
        return indirect_var
|