Type _sympy/functions.py [1/n] (#136205)

Signed-off-by: Bob Ren <bobren@fb.com>

I was chatting with @jamesjwu about strategies to learn the code and he suggested adding types to some files. This stack of PRs adds types to _sympy/functions.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136205
Approved by: https://github.com/Skylion007, https://github.com/jamesjwu
This commit is contained in:
Bob Ren
2024-09-19 07:40:29 -07:00
committed by PyTorch MergeBot
parent 803ce507f1
commit 8d9c42735a

View File

@ -3,6 +3,17 @@ import functools
import math
import operator
import sys
from typing import (
Any,
Callable,
Iterable,
List,
Optional,
SupportsFloat,
Tuple,
TypeVar,
Union,
)
import sympy
from sympy import S
@ -19,6 +30,8 @@ from sympy.utilities.iterables import sift
from .numbers import int_oo
_T = TypeVar("_T", bound=SupportsFloat)
# Portions of this file are adapted from the Sympy codebase, which was
# licensed as follows:
#
@ -76,9 +89,9 @@ __all__ = [
]
def _keep_float(f):
def _keep_float(f: Callable[..., _T]) -> Callable[..., sympy.Float]:
@functools.wraps(f)
def inner(*args):
def inner(*args: Any) -> Union[_T, sympy.Float]:
r = f(*args)
if any(isinstance(a, sympy.Float) for a in args) and not isinstance(
r, sympy.Float
@ -89,13 +102,13 @@ def _keep_float(f):
return inner
def fuzzy_eq(x, y):
def fuzzy_eq(x: Optional[bool], y: Optional[bool]) -> Optional[bool]:
if None in (x, y):
return None
return x == y
def simple_floordiv_gcd(p, q):
def simple_floordiv_gcd(p: sympy.Basic, q: sympy.Basic) -> sympy.Basic:
"""
Fast path for sympy.gcd, using a simple factoring strategy.
@ -112,23 +125,27 @@ def simple_floordiv_gcd(p, q):
might be necessary to fall back on sympy.gcd.
"""
def integer_coefficient(x):
integer_coefficients = [
def integer_coefficient(x: sympy.Basic) -> int:
integer_coefficients: List[int] = [
abs(int(arg))
for arg in sympy.Mul.make_args(x)
if isinstance(arg, (int, sympy.Integer))
]
return math.prod(integer_coefficients)
def integer_factor(expr):
integer_factors = map(integer_coefficient, sympy.Add.make_args(expr))
def integer_factor(expr: sympy.Basic) -> int:
integer_factors: Iterable[int] = map(
integer_coefficient, sympy.Add.make_args(expr)
)
return functools.reduce(math.gcd, integer_factors)
gcd = math.gcd(integer_factor(p), integer_factor(q))
gcd: int = math.gcd(integer_factor(p), integer_factor(q))
p, q = p / gcd, q / gcd
base_splits = list(map(sympy.Mul.make_args, sympy.Add.make_args(p)))
divisor_split = sympy.Mul.make_args(q)
base_splits: List[Tuple[sympy.Basic, ...]] = list(
map(sympy.Mul.make_args, sympy.Add.make_args(p))
)
divisor_split: Tuple[sympy.Basic, ...] = sympy.Mul.make_args(q)
for x in divisor_split:
if all(x in base_split for base_split in base_splits):
gcd = gcd * x
@ -162,20 +179,19 @@ class FloorDiv(sympy.Function):
NB: This is Python-style floor division, round to -Inf
"""
nargs = (2,)
precedence = 50 # precedence of mul # noqa: F811
is_integer = True
nargs: Tuple[int, ...] = (2,)
precedence: int = 50 # precedence of mul # noqa: F811
is_integer: bool = True
@property
def base(self):
def base(self) -> sympy.Basic:
return self.args[0]
@property
def divisor(self):
def divisor(self) -> sympy.Basic:
return self.args[1]
def _sympystr(self, printer):
def _sympystr(self, printer: sympy.printing.printer.Printer) -> str:
base = printer.parenthesize(self.base, self.precedence)
divisor = printer.parenthesize(self.divisor, self.precedence)
return f"({base}//{divisor})"
@ -183,7 +199,7 @@ class FloorDiv(sympy.Function):
# Automatic evaluation.
# https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval
@classmethod
def eval(cls, base, divisor):
def eval(cls, base: sympy.Basic, divisor: sympy.Basic) -> Union[sympy.Basic, None]:
# python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full
# Assert triggered by inequality solver
# assert base.is_integer, base
@ -252,17 +268,21 @@ class FloorDiv(sympy.Function):
except sympy.PolynomialError:
pass # https://github.com/pytorch/pytorch/issues/108276
return None
class ModularIndexing(sympy.Function):
"""
ModularIndexing(a, b, c) => (a // b) % c where % is the C modulus
"""
nargs = (3,)
is_integer = True
nargs: Tuple[int, ...] = (3,)
is_integer: bool = True
@classmethod
def eval(cls, base, divisor, modulus):
def eval(
cls, base: sympy.Basic, divisor: sympy.Basic, modulus: sympy.Basic
) -> Optional[sympy.Basic]:
if base == 0 or modulus == 1:
return sympy.Integer(0)
@ -286,8 +306,8 @@ class ModularIndexing(sympy.Function):
pass # https://github.com/pytorch/pytorch/issues/108276
if isinstance(base, sympy.Add):
new_terms = []
all_positive = True
new_terms: List[sympy.Basic] = []
all_positive: bool = True
for term in base.args:
if sympy.gcd(term, modulus * divisor) != modulus * divisor:
if (isinstance(term, sympy.Integer) and term < 0) or (
@ -310,11 +330,13 @@ class ModularIndexing(sympy.Function):
if isinstance(base, FloorDiv):
return ModularIndexing(base.args[0], base.args[1] * divisor, modulus)
def _eval_is_nonnegative(self): # type:ignore[override]
return None
def _eval_is_nonnegative(self) -> Optional[bool]:
p, q = self.args[:2]
return fuzzy_eq(p.is_nonnegative, q.is_nonnegative) # type: ignore[attr-defined]
def _eval_is_positive(self): # type:ignore[override]
def _eval_is_positive(self) -> Optional[bool]:
p, q = self.args[:2]
return fuzzy_eq(p.is_positive, q.is_positive) # type: ignore[attr-defined]
@ -324,37 +346,40 @@ class Where(sympy.Function):
Good ol' ternary operator
"""
nargs = (3,)
nargs: Tuple[int, ...] = (3,)
def _eval_is_integer(self):
def _eval_is_integer(self) -> Optional[bool]:
return True if self.args[1].is_integer and self.args[2].is_integer else None # type: ignore[attr-defined]
def _eval_is_nonnegative(self): # type:ignore[override]
def _eval_is_nonnegative(self) -> Optional[bool]:
return (
True
if self.args[1].is_nonnegative and self.args[2].is_nonnegative # type: ignore[attr-defined]
else None
)
def _eval_is_positive(self): # type:ignore[override]
def _eval_is_positive(self) -> Optional[bool]:
return True if self.args[1].is_positive and self.args[2].is_positive else None # type: ignore[attr-defined]
@classmethod
def eval(cls, c, p, q):
def eval(
cls, c: sympy.Basic, p: sympy.Basic, q: sympy.Basic
) -> Optional[sympy.Basic]:
if c == sympy.true:
return p
elif c == sympy.false:
return q
return None
# Python-style modulus: take sign from RHS
class PythonMod(sympy.Function):
nargs = (2,)
nargs: Tuple[int, ...] = (2,)
is_integer = True
is_integer: bool = True
@classmethod
def eval(cls, p, q):
def eval(cls, p: sympy.Expr, q: sympy.Expr) -> Optional[sympy.Expr]:
# python test/dynamo/test_export.py -k ExportTests.test_trivial_constraint
# Triggered by sympy.solvers.inequalities.reduce_inequalities
# assert p.is_integer, p
@ -396,11 +421,13 @@ class PythonMod(sympy.Function):
if sympy.Mod(p, q) == 0:
return S.Zero
return None
# NB: args[1] for PythonMod
def _eval_is_nonnegative(self): # type:ignore[override]
def _eval_is_nonnegative(self) -> Optional[bool]:
return True if self.args[1].is_positive else None # type: ignore[attr-defined]
def _eval_is_nonpositive(self):
def _eval_is_nonpositive(self) -> Optional[bool]:
return True if self.args[1].is_negative else None # type: ignore[attr-defined]