mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Type _sympy/functions.py [1/n] (#136205)
Signed-off-by: Bob Ren <bobren@fb.com> I was chatting with @jamesjwu about strategies to learn the code and he suggested adding types to some files. This stack of PRs adds types to _sympy/functions.py Pull Request resolved: https://github.com/pytorch/pytorch/pull/136205 Approved by: https://github.com/Skylion007, https://github.com/jamesjwu
This commit is contained in:
committed by
PyTorch MergeBot
parent
803ce507f1
commit
8d9c42735a
@ -3,6 +3,17 @@ import functools
|
||||
import math
|
||||
import operator
|
||||
import sys
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
SupportsFloat,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
import sympy
|
||||
from sympy import S
|
||||
@ -19,6 +30,8 @@ from sympy.utilities.iterables import sift
|
||||
from .numbers import int_oo
|
||||
|
||||
|
||||
_T = TypeVar("_T", bound=SupportsFloat)
|
||||
|
||||
# Portions of this file are adapted from the Sympy codebase, which was
|
||||
# licensed as follows:
|
||||
#
|
||||
@ -76,9 +89,9 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
def _keep_float(f):
|
||||
def _keep_float(f: Callable[..., _T]) -> Callable[..., sympy.Float]:
|
||||
@functools.wraps(f)
|
||||
def inner(*args):
|
||||
def inner(*args: Any) -> Union[_T, sympy.Float]:
|
||||
r = f(*args)
|
||||
if any(isinstance(a, sympy.Float) for a in args) and not isinstance(
|
||||
r, sympy.Float
|
||||
@ -89,13 +102,13 @@ def _keep_float(f):
|
||||
return inner
|
||||
|
||||
|
||||
def fuzzy_eq(x, y):
|
||||
def fuzzy_eq(x: Optional[bool], y: Optional[bool]) -> Optional[bool]:
|
||||
if None in (x, y):
|
||||
return None
|
||||
return x == y
|
||||
|
||||
|
||||
def simple_floordiv_gcd(p, q):
|
||||
def simple_floordiv_gcd(p: sympy.Basic, q: sympy.Basic) -> sympy.Basic:
|
||||
"""
|
||||
Fast path for sympy.gcd, using a simple factoring strategy.
|
||||
|
||||
@ -112,23 +125,27 @@ def simple_floordiv_gcd(p, q):
|
||||
might be necessary to fall back on sympy.gcd.
|
||||
"""
|
||||
|
||||
def integer_coefficient(x):
|
||||
integer_coefficients = [
|
||||
def integer_coefficient(x: sympy.Basic) -> int:
|
||||
integer_coefficients: List[int] = [
|
||||
abs(int(arg))
|
||||
for arg in sympy.Mul.make_args(x)
|
||||
if isinstance(arg, (int, sympy.Integer))
|
||||
]
|
||||
return math.prod(integer_coefficients)
|
||||
|
||||
def integer_factor(expr):
|
||||
integer_factors = map(integer_coefficient, sympy.Add.make_args(expr))
|
||||
def integer_factor(expr: sympy.Basic) -> int:
|
||||
integer_factors: Iterable[int] = map(
|
||||
integer_coefficient, sympy.Add.make_args(expr)
|
||||
)
|
||||
return functools.reduce(math.gcd, integer_factors)
|
||||
|
||||
gcd = math.gcd(integer_factor(p), integer_factor(q))
|
||||
gcd: int = math.gcd(integer_factor(p), integer_factor(q))
|
||||
p, q = p / gcd, q / gcd
|
||||
|
||||
base_splits = list(map(sympy.Mul.make_args, sympy.Add.make_args(p)))
|
||||
divisor_split = sympy.Mul.make_args(q)
|
||||
base_splits: List[Tuple[sympy.Basic, ...]] = list(
|
||||
map(sympy.Mul.make_args, sympy.Add.make_args(p))
|
||||
)
|
||||
divisor_split: Tuple[sympy.Basic, ...] = sympy.Mul.make_args(q)
|
||||
for x in divisor_split:
|
||||
if all(x in base_split for base_split in base_splits):
|
||||
gcd = gcd * x
|
||||
@ -162,20 +179,19 @@ class FloorDiv(sympy.Function):
|
||||
NB: This is Python-style floor division, round to -Inf
|
||||
"""
|
||||
|
||||
nargs = (2,)
|
||||
precedence = 50 # precedence of mul # noqa: F811
|
||||
|
||||
is_integer = True
|
||||
nargs: Tuple[int, ...] = (2,)
|
||||
precedence: int = 50 # precedence of mul # noqa: F811
|
||||
is_integer: bool = True
|
||||
|
||||
@property
|
||||
def base(self):
|
||||
def base(self) -> sympy.Basic:
|
||||
return self.args[0]
|
||||
|
||||
@property
|
||||
def divisor(self):
|
||||
def divisor(self) -> sympy.Basic:
|
||||
return self.args[1]
|
||||
|
||||
def _sympystr(self, printer):
|
||||
def _sympystr(self, printer: sympy.printing.printer.Printer) -> str:
|
||||
base = printer.parenthesize(self.base, self.precedence)
|
||||
divisor = printer.parenthesize(self.divisor, self.precedence)
|
||||
return f"({base}//{divisor})"
|
||||
@ -183,7 +199,7 @@ class FloorDiv(sympy.Function):
|
||||
# Automatic evaluation.
|
||||
# https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval
|
||||
@classmethod
|
||||
def eval(cls, base, divisor):
|
||||
def eval(cls, base: sympy.Basic, divisor: sympy.Basic) -> Union[sympy.Basic, None]:
|
||||
# python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full
|
||||
# Assert triggered by inequality solver
|
||||
# assert base.is_integer, base
|
||||
@ -252,17 +268,21 @@ class FloorDiv(sympy.Function):
|
||||
except sympy.PolynomialError:
|
||||
pass # https://github.com/pytorch/pytorch/issues/108276
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class ModularIndexing(sympy.Function):
|
||||
"""
|
||||
ModularIndexing(a, b, c) => (a // b) % c where % is the C modulus
|
||||
"""
|
||||
|
||||
nargs = (3,)
|
||||
is_integer = True
|
||||
nargs: Tuple[int, ...] = (3,)
|
||||
is_integer: bool = True
|
||||
|
||||
@classmethod
|
||||
def eval(cls, base, divisor, modulus):
|
||||
def eval(
|
||||
cls, base: sympy.Basic, divisor: sympy.Basic, modulus: sympy.Basic
|
||||
) -> Optional[sympy.Basic]:
|
||||
if base == 0 or modulus == 1:
|
||||
return sympy.Integer(0)
|
||||
|
||||
@ -286,8 +306,8 @@ class ModularIndexing(sympy.Function):
|
||||
pass # https://github.com/pytorch/pytorch/issues/108276
|
||||
|
||||
if isinstance(base, sympy.Add):
|
||||
new_terms = []
|
||||
all_positive = True
|
||||
new_terms: List[sympy.Basic] = []
|
||||
all_positive: bool = True
|
||||
for term in base.args:
|
||||
if sympy.gcd(term, modulus * divisor) != modulus * divisor:
|
||||
if (isinstance(term, sympy.Integer) and term < 0) or (
|
||||
@ -310,11 +330,13 @@ class ModularIndexing(sympy.Function):
|
||||
if isinstance(base, FloorDiv):
|
||||
return ModularIndexing(base.args[0], base.args[1] * divisor, modulus)
|
||||
|
||||
def _eval_is_nonnegative(self): # type:ignore[override]
|
||||
return None
|
||||
|
||||
def _eval_is_nonnegative(self) -> Optional[bool]:
|
||||
p, q = self.args[:2]
|
||||
return fuzzy_eq(p.is_nonnegative, q.is_nonnegative) # type: ignore[attr-defined]
|
||||
|
||||
def _eval_is_positive(self): # type:ignore[override]
|
||||
def _eval_is_positive(self) -> Optional[bool]:
|
||||
p, q = self.args[:2]
|
||||
return fuzzy_eq(p.is_positive, q.is_positive) # type: ignore[attr-defined]
|
||||
|
||||
@ -324,37 +346,40 @@ class Where(sympy.Function):
|
||||
Good ol' ternary operator
|
||||
"""
|
||||
|
||||
nargs = (3,)
|
||||
nargs: Tuple[int, ...] = (3,)
|
||||
|
||||
def _eval_is_integer(self):
|
||||
def _eval_is_integer(self) -> Optional[bool]:
|
||||
return True if self.args[1].is_integer and self.args[2].is_integer else None # type: ignore[attr-defined]
|
||||
|
||||
def _eval_is_nonnegative(self): # type:ignore[override]
|
||||
def _eval_is_nonnegative(self) -> Optional[bool]:
|
||||
return (
|
||||
True
|
||||
if self.args[1].is_nonnegative and self.args[2].is_nonnegative # type: ignore[attr-defined]
|
||||
else None
|
||||
)
|
||||
|
||||
def _eval_is_positive(self): # type:ignore[override]
|
||||
def _eval_is_positive(self) -> Optional[bool]:
|
||||
return True if self.args[1].is_positive and self.args[2].is_positive else None # type: ignore[attr-defined]
|
||||
|
||||
@classmethod
|
||||
def eval(cls, c, p, q):
|
||||
def eval(
|
||||
cls, c: sympy.Basic, p: sympy.Basic, q: sympy.Basic
|
||||
) -> Optional[sympy.Basic]:
|
||||
if c == sympy.true:
|
||||
return p
|
||||
elif c == sympy.false:
|
||||
return q
|
||||
return None
|
||||
|
||||
|
||||
# Python-style modulus: take sign from RHS
|
||||
class PythonMod(sympy.Function):
|
||||
nargs = (2,)
|
||||
nargs: Tuple[int, ...] = (2,)
|
||||
|
||||
is_integer = True
|
||||
is_integer: bool = True
|
||||
|
||||
@classmethod
|
||||
def eval(cls, p, q):
|
||||
def eval(cls, p: sympy.Expr, q: sympy.Expr) -> Optional[sympy.Expr]:
|
||||
# python test/dynamo/test_export.py -k ExportTests.test_trivial_constraint
|
||||
# Triggered by sympy.solvers.inequalities.reduce_inequalities
|
||||
# assert p.is_integer, p
|
||||
@ -396,11 +421,13 @@ class PythonMod(sympy.Function):
|
||||
if sympy.Mod(p, q) == 0:
|
||||
return S.Zero
|
||||
|
||||
return None
|
||||
|
||||
# NB: args[1] for PythonMod
|
||||
def _eval_is_nonnegative(self): # type:ignore[override]
|
||||
def _eval_is_nonnegative(self) -> Optional[bool]:
|
||||
return True if self.args[1].is_positive else None # type: ignore[attr-defined]
|
||||
|
||||
def _eval_is_nonpositive(self):
|
||||
def _eval_is_nonpositive(self) -> Optional[bool]:
|
||||
return True if self.args[1].is_negative else None # type: ignore[attr-defined]
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user