mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: Re-implementation of https://github.com/pytorch/pytorch/pull/134150, which was reverted because of some internal tests hanging (case B). The original motivation was to get some other internal test unstuck (case A). The root cause is that sympy.gcd is both very clever as well as can blow up in some cases. This PR introduces a fast path with an appropriate fallback to sympy.gcd that ensures that both cases A and B go through. Test Plan: See the included test for specific examples. Also https://fb.workplace.com/groups/1075192433118967/posts/1491493248155548/?comment_id=1491938994777640&reply_comment_id=1492622821375924 Differential Revision: D62043315 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134880 Approved by: https://github.com/ezyang
1197 lines
41 KiB
Python
1197 lines
41 KiB
Python
# mypy: allow-untyped-defs
|
|
import functools
|
|
import math
|
|
import operator
|
|
import sys
|
|
|
|
import sympy
|
|
from sympy import S
|
|
from sympy.core import sympify
|
|
from sympy.core.expr import Expr
|
|
from sympy.core.function import Application
|
|
from sympy.core.logic import _torf, fuzzy_and, fuzzy_or
|
|
from sympy.core.numbers import equal_valued
|
|
from sympy.core.operations import LatticeOp, ShortCircuit
|
|
from sympy.core.sorting import ordered
|
|
from sympy.core.traversal import walk
|
|
from sympy.utilities.iterables import sift
|
|
|
|
from .numbers import int_oo
|
|
|
|
|
|
# Portions of this file are adapted from the Sympy codebase, which was
|
|
# licensed as follows:
|
|
#
|
|
# Copyright (c) 2006-2023 SymPy Development Team
|
|
#
|
|
# All rights reserved.
|
|
#
|
|
# Redistribution and use in source and binary forms, with or without
|
|
# modification, are permitted provided that the following conditions are met:
|
|
#
|
|
# a. Redistributions of source code must retain the above copyright notice,
|
|
# this list of conditions and the following disclaimer.
|
|
# b. Redistributions in binary form must reproduce the above copyright
|
|
# notice, this list of conditions and the following disclaimer in the
|
|
# documentation and/or other materials provided with the distribution.
|
|
# c. Neither the name of SymPy nor the names of its contributors
|
|
# may be used to endorse or promote products derived from this software
|
|
# without specific prior written permission.
|
|
#
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
|
# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR
|
|
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
|
|
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
|
|
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
|
|
# DAMAGE.
|
|
|
|
__all__ = [
|
|
"FloorDiv",
|
|
"ModularIndexing",
|
|
"Where",
|
|
"PythonMod",
|
|
"Mod",
|
|
"CleanDiv",
|
|
"CeilToInt",
|
|
"FloorToInt",
|
|
"CeilDiv",
|
|
"IntTrueDiv",
|
|
"FloatTrueDiv",
|
|
"LShift",
|
|
"RShift",
|
|
"IsNonOverlappingAndDenseIndicator",
|
|
"TruncToFloat",
|
|
"TruncToInt",
|
|
"RoundToInt",
|
|
"RoundDecimal",
|
|
"ToFloat",
|
|
"FloatPow",
|
|
"PowByNatural",
|
|
"Identity",
|
|
]
|
|
|
|
|
|
def _keep_float(f):
|
|
@functools.wraps(f)
|
|
def inner(*args):
|
|
r = f(*args)
|
|
if any(isinstance(a, sympy.Float) for a in args) and not isinstance(
|
|
r, sympy.Float
|
|
):
|
|
r = sympy.Float(float(r))
|
|
return r
|
|
|
|
return inner
|
|
|
|
|
|
def fuzzy_eq(x, y):
|
|
if None in (x, y):
|
|
return None
|
|
return x == y
|
|
|
|
|
|
def simple_floordiv_gcd(p, q):
|
|
"""
|
|
Fast path for sympy.gcd, using a simple factoring strategy.
|
|
|
|
We try to rewrite p and q in the form n*e*p1 + n*e*p2 and n*e*q0,
|
|
where n is the greatest common integer factor and e is the largest
|
|
syntactic common factor (i.e., common sub-expression) in p and q.
|
|
Then the gcd returned is n*e, cancelling which we would be left with
|
|
p1 + p2 and q0.
|
|
|
|
Note that further factoring of p1 + p2 and q0 might be possible with
|
|
sympy.factor (which uses domain-specific theories). E.g., we are unable
|
|
to find that x*y + x + y + 1 is divisible by x + 1. More generally,
|
|
when q is of the form q1 + q2 (instead of being already factored) it
|
|
might be necessary to fall back on sympy.gcd.
|
|
"""
|
|
|
|
def integer_coefficient(x):
|
|
integer_coefficients = [
|
|
abs(int(arg))
|
|
for arg in sympy.Mul.make_args(x)
|
|
if isinstance(arg, (int, sympy.Integer))
|
|
]
|
|
return math.prod(integer_coefficients)
|
|
|
|
def integer_factor(expr):
|
|
integer_factors = map(integer_coefficient, sympy.Add.make_args(expr))
|
|
return functools.reduce(math.gcd, integer_factors)
|
|
|
|
gcd = math.gcd(integer_factor(p), integer_factor(q))
|
|
p, q = p / gcd, q / gcd
|
|
|
|
base_splits = list(map(sympy.Mul.make_args, sympy.Add.make_args(p)))
|
|
divisor_split = sympy.Mul.make_args(q)
|
|
for x in divisor_split:
|
|
if all(x in base_split for base_split in base_splits):
|
|
gcd = gcd * x
|
|
return gcd
|
|
|
|
|
|
# It would be nice to have assertions on whether or not inputs is_integer
|
|
# However, with bugs like https://github.com/sympy/sympy/issues/26620 sympy
|
|
# sometimes inconsistently reports floats an integers.
|
|
#
|
|
# What we can assume from sympy is that if something is an int, it
|
|
# definitely is is_integer, but if it is a float it may or may not
|
|
# be is_integer. So we are unable to do strong asserts that things
|
|
# are NOT integers.
|
|
|
|
|
|
# TODO: In Triton, // rounds to zero, but in Python, it is floor division.
|
|
# When we can prove both arguments are non-negative, we should just have a
|
|
# GenericFloorDiv (name pending) which can codegen efficiently in Python/C,
|
|
# and then PythonFloorDiv and CIntDiv which have the appropriate rounding
|
|
# semantics.
|
|
#
|
|
# Right now, FloorDiv de facto changes behavior if arguments are negative or
|
|
# not, this can potentially cause correctness issues.
|
|
class FloorDiv(sympy.Function):
|
|
"""
|
|
We maintain this so that:
|
|
1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b.
|
|
2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b)
|
|
|
|
NB: This is Python-style floor division, round to -Inf
|
|
"""
|
|
|
|
nargs = (2,)
|
|
precedence = 50 # precedence of mul # noqa: F811
|
|
|
|
is_integer = True
|
|
|
|
@property
|
|
def base(self):
|
|
return self.args[0]
|
|
|
|
@property
|
|
def divisor(self):
|
|
return self.args[1]
|
|
|
|
def _sympystr(self, printer):
|
|
base = printer.parenthesize(self.base, self.precedence)
|
|
divisor = printer.parenthesize(self.divisor, self.precedence)
|
|
return f"({base}//{divisor})"
|
|
|
|
# Automatic evaluation.
|
|
# https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval
|
|
@classmethod
|
|
def eval(cls, base, divisor):
|
|
# python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full
|
|
# Assert triggered by inequality solver
|
|
# assert base.is_integer, base
|
|
# assert divisor.is_integer, divisor
|
|
|
|
# We don't provide the same error message as in Python because SymPy
|
|
# makes it difficult to check the types.
|
|
if divisor.is_zero:
|
|
raise ZeroDivisionError("division by zero")
|
|
if base in (int_oo, -int_oo, sympy.oo, -sympy.oo) and divisor in (
|
|
int_oo,
|
|
-int_oo,
|
|
sympy.oo,
|
|
-sympy.oo,
|
|
):
|
|
return sympy.nan
|
|
if base is sympy.nan or divisor is sympy.nan:
|
|
return sympy.nan
|
|
|
|
if base.is_zero:
|
|
return sympy.S.Zero
|
|
if base.is_integer and equal_valued(divisor, 1):
|
|
return base
|
|
if base.is_integer and equal_valued(divisor, -1):
|
|
return sympy.Mul(base, -1)
|
|
if (
|
|
isinstance(base, sympy.Number)
|
|
and isinstance(divisor, sympy.Number)
|
|
and (
|
|
base in (int_oo, -int_oo, sympy.oo, -sympy.oo)
|
|
or divisor in (int_oo, -int_oo, sympy.oo, -sympy.oo)
|
|
)
|
|
):
|
|
r = float(base) / float(divisor)
|
|
if r == math.inf:
|
|
return int_oo
|
|
elif r == -math.inf:
|
|
return -int_oo
|
|
elif math.isnan(r):
|
|
return sympy.nan
|
|
else:
|
|
return sympy.Integer(math.floor(r))
|
|
if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
|
|
return sympy.Integer(int(base) // int(divisor))
|
|
if isinstance(base, FloorDiv):
|
|
return FloorDiv(base.args[0], base.args[1] * divisor)
|
|
|
|
# Expands (x + y) // b into x // b + y // b.
|
|
# This only works if floor is an identity, i.e. x / b is an integer.
|
|
for term in sympy.Add.make_args(base):
|
|
quotient = term / divisor
|
|
if quotient.is_integer and isinstance(divisor, sympy.Integer):
|
|
# NB: this is correct even if the divisor is not an integer, but it
|
|
# creates rational expressions that cause problems with dynamic
|
|
# shapes.
|
|
return FloorDiv(base - term, divisor) + quotient
|
|
|
|
try:
|
|
gcd = simple_floordiv_gcd(base, divisor)
|
|
if equal_valued(gcd, 1) and isinstance(divisor, sympy.Add):
|
|
gcd = sympy.gcd(base, divisor)
|
|
if not equal_valued(gcd, 1):
|
|
return FloorDiv(
|
|
sympy.simplify(base / gcd), sympy.simplify(divisor / gcd)
|
|
)
|
|
except sympy.PolynomialError:
|
|
pass # https://github.com/pytorch/pytorch/issues/108276
|
|
|
|
|
|
class ModularIndexing(sympy.Function):
|
|
"""
|
|
ModularIndexing(a, b, c) => (a // b) % c where % is the C modulus
|
|
"""
|
|
|
|
nargs = (3,)
|
|
is_integer = True
|
|
|
|
@classmethod
|
|
def eval(cls, base, divisor, modulus):
|
|
if base == 0 or modulus == 1:
|
|
return sympy.Integer(0)
|
|
|
|
if (
|
|
isinstance(base, sympy.Integer)
|
|
and isinstance(divisor, sympy.Integer)
|
|
and isinstance(modulus, sympy.Integer)
|
|
):
|
|
return (base // divisor) % modulus
|
|
|
|
try:
|
|
if divisor != 1:
|
|
gcd = sympy.gcd(base, divisor)
|
|
if gcd != 1:
|
|
return ModularIndexing(
|
|
sympy.simplify(base / gcd),
|
|
sympy.simplify(divisor / gcd),
|
|
modulus,
|
|
)
|
|
except sympy.PolynomialError:
|
|
pass # https://github.com/pytorch/pytorch/issues/108276
|
|
|
|
if isinstance(base, sympy.Add):
|
|
new_terms = []
|
|
all_positive = True
|
|
for term in base.args:
|
|
if sympy.gcd(term, modulus * divisor) != modulus * divisor:
|
|
if (isinstance(term, sympy.Integer) and term < 0) or (
|
|
isinstance(term, sympy.Mul)
|
|
and isinstance(term.args[0], sympy.Integer)
|
|
and term.args[0] < 0
|
|
):
|
|
# workaround for https://github.com/openai/triton/issues/619,
|
|
# if there are negative terms, // produces wrong result
|
|
# TODO if https://github.com/openai/triton/issues/619 is fixed
|
|
# this optimization would become valid
|
|
all_positive = False
|
|
break
|
|
else:
|
|
new_terms.append(term)
|
|
|
|
if len(new_terms) != len(base.args) and all_positive:
|
|
return ModularIndexing(sum(new_terms), divisor, modulus)
|
|
|
|
if isinstance(base, FloorDiv):
|
|
return ModularIndexing(base.args[0], base.args[1] * divisor, modulus)
|
|
|
|
def _eval_is_nonnegative(self):
|
|
p, q = self.args[:2]
|
|
return fuzzy_eq(p.is_nonnegative, q.is_nonnegative) # type: ignore[attr-defined]
|
|
|
|
def _eval_is_positive(self):
|
|
p, q = self.args[:2]
|
|
return fuzzy_eq(p.is_positive, q.is_positive) # type: ignore[attr-defined]
|
|
|
|
|
|
class Where(sympy.Function):
|
|
"""
|
|
Good ol' ternary operator
|
|
"""
|
|
|
|
nargs = (3,)
|
|
|
|
def _eval_is_integer(self):
|
|
return True if self.args[1].is_integer and self.args[2].is_integer else None # type: ignore[attr-defined]
|
|
|
|
def _eval_is_nonnegative(self):
|
|
return (
|
|
True
|
|
if self.args[1].is_nonnegative and self.args[2].is_nonnegative # type: ignore[attr-defined]
|
|
else None
|
|
)
|
|
|
|
def _eval_is_positive(self):
|
|
return True if self.args[1].is_positive and self.args[2].is_positive else None # type: ignore[attr-defined]
|
|
|
|
@classmethod
|
|
def eval(cls, c, p, q):
|
|
if c == sympy.true:
|
|
return p
|
|
elif c == sympy.false:
|
|
return q
|
|
|
|
|
|
# Python-style modulus: take sign from RHS
|
|
class PythonMod(sympy.Function):
|
|
nargs = (2,)
|
|
|
|
is_integer = True
|
|
|
|
@classmethod
|
|
def eval(cls, p, q):
|
|
# python test/dynamo/test_export.py -k ExportTests.test_trivial_constraint
|
|
# Triggered by sympy.solvers.inequalities.reduce_inequalities
|
|
# assert p.is_integer, p
|
|
# assert q.is_integer, q
|
|
|
|
if q.is_zero:
|
|
raise ZeroDivisionError("Modulo by zero")
|
|
|
|
# Three cases:
|
|
# 1. p == 0
|
|
# 2. p is either q or -q
|
|
# 3. p is integer and q == 1
|
|
if p is S.Zero or p in (q, -q) or q == 1:
|
|
return S.Zero
|
|
|
|
# Evaluate if they are both literals.
|
|
if q.is_Number and p.is_Number:
|
|
return p % q
|
|
|
|
# If q == 2, it's a matter of whether p is odd or even.
|
|
if q.is_Number and q == 2:
|
|
if p.is_even:
|
|
return S.Zero
|
|
if p.is_odd:
|
|
return S.One
|
|
|
|
# If p is a multiple of q.
|
|
r = p / q
|
|
if r.is_integer:
|
|
return S.Zero
|
|
|
|
# If p < q and its ratio is positive, then:
|
|
# - floor(p / q) = 0
|
|
# - p % q = p - floor(p / q) * q = p
|
|
less = p < q
|
|
if less.is_Boolean and bool(less) and r.is_positive:
|
|
return p
|
|
|
|
if sympy.Mod(p, q) == 0:
|
|
return S.Zero
|
|
|
|
# NB: args[1] for PythonMod
|
|
def _eval_is_nonnegative(self):
|
|
return True if self.args[1].is_positive else None # type: ignore[attr-defined]
|
|
|
|
def _eval_is_nonpositive(self):
|
|
return True if self.args[1].is_negative else None # type: ignore[attr-defined]
|
|
|
|
|
|
# Generic modulus: only defined on non-negative arguments
|
|
class Mod(sympy.Function):
|
|
nargs = (2,)
|
|
|
|
is_integer = True
|
|
is_nonnegative = True
|
|
|
|
@classmethod
|
|
def eval(cls, p, q):
|
|
# This was adapted from: sympy/core/mod.py
|
|
|
|
# Triggered by
|
|
# python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full
|
|
# assert p.is_integer, p
|
|
# assert q.is_integer, q
|
|
|
|
if q.is_zero:
|
|
raise ZeroDivisionError("Modulo by zero")
|
|
|
|
# Three cases:
|
|
# 1. p == 0
|
|
# 2. p is either q or -q
|
|
# 3. p is integer and q == 1
|
|
if p is S.Zero or p in (q, -q) or q == 1:
|
|
return S.Zero
|
|
|
|
# Evaluate if they are both literals.
|
|
if q.is_Number and p.is_Number:
|
|
assert p >= 0, p
|
|
assert q >= 1, q
|
|
return p % q
|
|
|
|
# If q == 2, it's a matter of whether p is odd or even.
|
|
if q.is_Number and q == 2:
|
|
if p.is_even:
|
|
return S.Zero
|
|
if p.is_odd:
|
|
return S.One
|
|
|
|
# If p is a multiple of q.
|
|
r = p / q
|
|
if r.is_integer:
|
|
return S.Zero
|
|
|
|
# If p < q and its ratio is positive, then:
|
|
# - floor(p / q) = 0
|
|
# - p % q = p - floor(p / q) * q = p
|
|
less = p < q
|
|
if less.is_Boolean and bool(less) and r.is_positive:
|
|
return p
|
|
|
|
|
|
class CleanDiv(FloorDiv):
|
|
"""
|
|
Div where we can assume no rounding.
|
|
This is to enable future optimizations.
|
|
"""
|
|
|
|
|
|
# Don't use sympy ceiling/floor as they will attempt simplifications involving
|
|
# frac
|
|
class CeilToInt(sympy.Function):
|
|
is_integer = True
|
|
|
|
@classmethod
|
|
def eval(cls, number):
|
|
# assert number.is_integer is not True, number
|
|
if number in (sympy.oo, int_oo):
|
|
return int_oo
|
|
if number in (-sympy.oo, -int_oo):
|
|
return -int_oo
|
|
if isinstance(number, sympy.Number):
|
|
return sympy.Integer(math.ceil(float(number)))
|
|
|
|
|
|
class FloorToInt(sympy.Function):
|
|
is_integer = True
|
|
|
|
@classmethod
|
|
def eval(cls, number):
|
|
# assert number.is_integer is not True, number
|
|
if number in (sympy.oo, int_oo):
|
|
return int_oo
|
|
if number in (-sympy.oo, int_oo):
|
|
return -int_oo
|
|
if isinstance(number, sympy.Number):
|
|
return sympy.Integer(math.floor(float(number)))
|
|
|
|
|
|
class CeilDiv(sympy.Function):
|
|
"""
|
|
Div used in indexing that rounds up.
|
|
"""
|
|
|
|
is_integer = True
|
|
|
|
def __new__(cls, base, divisor):
|
|
base = sympy.sympify(base)
|
|
divisor = sympy.sympify(divisor)
|
|
if sympy.gcd(base, divisor) == divisor:
|
|
return CleanDiv(base, divisor)
|
|
else:
|
|
return FloorDiv(base + (divisor - 1), divisor)
|
|
|
|
|
|
class LShift(sympy.Function):
|
|
is_integer = True
|
|
|
|
@classmethod
|
|
def eval(cls, base, shift):
|
|
if shift < 0:
|
|
raise ValueError("negative shift count")
|
|
return base * 2**shift
|
|
|
|
|
|
class RShift(sympy.Function):
|
|
is_integer = True
|
|
|
|
@classmethod
|
|
def eval(cls, base, shift):
|
|
if shift < 0:
|
|
raise ValueError("negative shift count")
|
|
return base // 2**shift
|
|
|
|
|
|
class MinMaxBase(Expr, LatticeOp): # type: ignore[misc]
|
|
def __new__(cls, *args, **assumptions):
|
|
from sympy.core.parameters import global_parameters
|
|
|
|
evaluate = assumptions.pop("evaluate", global_parameters.evaluate)
|
|
args = (sympify(arg) for arg in args)
|
|
|
|
# first standard filter, for cls.zero and cls.identity
|
|
# also reshape Max(a, Max(b, c)) to Max(a, b, c)
|
|
|
|
if evaluate:
|
|
try:
|
|
args = frozenset(cls._new_args_filter(args)) # type: ignore[assignment]
|
|
except ShortCircuit:
|
|
return cls.zero # type: ignore[attr-defined]
|
|
# remove redundant args that are easily identified
|
|
args = cls._collapse_arguments(args, **assumptions)
|
|
# find local zeros
|
|
args = cls._find_localzeros(args, **assumptions)
|
|
args = frozenset(args)
|
|
|
|
if not args:
|
|
return cls.identity # type: ignore[attr-defined]
|
|
|
|
if len(args) == 1:
|
|
return list(args).pop()
|
|
|
|
# base creation
|
|
obj = Expr.__new__(cls, *ordered(args), **assumptions)
|
|
obj._argset = args
|
|
return obj
|
|
|
|
@classmethod
|
|
def _collapse_arguments(cls, args, **assumptions):
|
|
"""Remove redundant args.
|
|
|
|
Examples
|
|
========
|
|
|
|
>>> from sympy import Min, Max
|
|
>>> from sympy.abc import a, b, c, d, e
|
|
|
|
Any arg in parent that appears in any
|
|
parent-like function in any of the flat args
|
|
of parent can be removed from that sub-arg:
|
|
|
|
>>> Min(a, Max(b, Min(a, c, d)))
|
|
Min(a, Max(b, Min(c, d)))
|
|
|
|
If the arg of parent appears in an opposite-than parent
|
|
function in any of the flat args of parent that function
|
|
can be replaced with the arg:
|
|
|
|
>>> Min(a, Max(b, Min(c, d, Max(a, e))))
|
|
Min(a, Max(b, Min(a, c, d)))
|
|
"""
|
|
if not args:
|
|
return args
|
|
args = list(ordered(args))
|
|
if cls is Min:
|
|
other = Max
|
|
else:
|
|
other = Min # type: ignore[assignment]
|
|
|
|
# find global comparable max of Max and min of Min if a new
|
|
# value is being introduced in these args at position 0 of
|
|
# the ordered args
|
|
if args[0].is_number:
|
|
sifted = mins, maxs = [], [] # type: ignore[var-annotated]
|
|
for i in args:
|
|
for v in walk(i, Min, Max):
|
|
if v.args[0].is_comparable:
|
|
sifted[isinstance(v, Max)].append(v)
|
|
small = Min.identity
|
|
for i in mins:
|
|
v = i.args[0]
|
|
if v.is_number and (v < small) == True: # noqa: E712
|
|
small = v
|
|
big = Max.identity
|
|
for i in maxs:
|
|
v = i.args[0]
|
|
if v.is_number and (v > big) == True: # noqa: E712
|
|
big = v
|
|
# at the point when this function is called from __new__,
|
|
# there may be more than one numeric arg present since
|
|
# local zeros have not been handled yet, so look through
|
|
# more than the first arg
|
|
if cls is Min:
|
|
for arg in args:
|
|
if not arg.is_number:
|
|
break
|
|
if (arg < small) == True: # noqa: E712
|
|
small = arg
|
|
elif cls == Max:
|
|
for arg in args:
|
|
if not arg.is_number:
|
|
break
|
|
if (arg > big) == True: # noqa: E712
|
|
big = arg
|
|
T = None
|
|
if cls is Min:
|
|
if small != Min.identity:
|
|
other = Max
|
|
T = small
|
|
elif big != Max.identity:
|
|
other = Min # type: ignore[assignment]
|
|
T = big
|
|
if T is not None:
|
|
# remove numerical redundancy
|
|
for i in range(len(args)):
|
|
a = args[i]
|
|
if isinstance(a, other):
|
|
a0 = a.args[0]
|
|
if ( # noqa: E712
|
|
(a0 > T) if other == Max else (a0 < T) # noqa: E712
|
|
) == True: # noqa: E712
|
|
args[i] = cls.identity # type: ignore[attr-defined]
|
|
|
|
# remove redundant symbolic args
|
|
def do(ai, a):
|
|
if not isinstance(ai, (Min, Max)):
|
|
return ai
|
|
cond = a in ai.args
|
|
if not cond:
|
|
return ai.func(*[do(i, a) for i in ai.args], evaluate=False)
|
|
if isinstance(ai, cls):
|
|
return ai.func(*[do(i, a) for i in ai.args if i != a], evaluate=False)
|
|
return a
|
|
|
|
for i, a in enumerate(args):
|
|
args[i + 1 :] = [do(ai, a) for ai in args[i + 1 :]]
|
|
|
|
# factor out common elements as for
|
|
# Min(Max(x, y), Max(x, z)) -> Max(x, Min(y, z))
|
|
# and vice versa when swapping Min/Max -- do this only for the
|
|
# easy case where all functions contain something in common;
|
|
# trying to find some optimal subset of args to modify takes
|
|
# too long
|
|
|
|
def factor_minmax(args):
|
|
is_other = lambda arg: isinstance(arg, other) # noqa: E731
|
|
other_args, remaining_args = sift(args, is_other, binary=True)
|
|
if not other_args:
|
|
return args
|
|
|
|
# Min(Max(x, y, z), Max(x, y, u, v)) -> {x,y}, ({z}, {u,v})
|
|
arg_sets = [set(arg.args) for arg in other_args]
|
|
common = set.intersection(*arg_sets)
|
|
if not common:
|
|
return args
|
|
|
|
new_other_args = list(common)
|
|
arg_sets_diff = [arg_set - common for arg_set in arg_sets]
|
|
|
|
# If any set is empty after removing common then all can be
|
|
# discarded e.g. Min(Max(a, b, c), Max(a, b)) -> Max(a, b)
|
|
if all(arg_sets_diff):
|
|
other_args_diff = [other(*s, evaluate=False) for s in arg_sets_diff]
|
|
new_other_args.append(cls(*other_args_diff, evaluate=False))
|
|
|
|
other_args_factored = other(*new_other_args, evaluate=False)
|
|
return remaining_args + [other_args_factored]
|
|
|
|
if len(args) > 1:
|
|
args = factor_minmax(args)
|
|
|
|
return args
|
|
|
|
@classmethod
|
|
def _new_args_filter(cls, arg_sequence):
|
|
"""
|
|
Generator filtering args.
|
|
|
|
first standard filter, for cls.zero and cls.identity.
|
|
Also reshape ``Max(a, Max(b, c))`` to ``Max(a, b, c)``,
|
|
and check arguments for comparability
|
|
"""
|
|
for arg in arg_sequence:
|
|
# pre-filter, checking comparability of arguments
|
|
if (
|
|
not isinstance(arg, Expr)
|
|
or arg.is_extended_real is False
|
|
or (arg.is_number and not arg.is_comparable)
|
|
):
|
|
raise ValueError(f"The argument '{arg}' is not comparable.")
|
|
|
|
if arg == cls.zero: # type: ignore[attr-defined]
|
|
raise ShortCircuit(arg)
|
|
elif arg == cls.identity: # type: ignore[attr-defined]
|
|
continue
|
|
elif arg.func == cls:
|
|
yield from arg.args
|
|
else:
|
|
yield arg
|
|
|
|
@classmethod
|
|
def _find_localzeros(cls, values, **options):
|
|
"""
|
|
Sequentially allocate values to localzeros.
|
|
|
|
When a value is identified as being more extreme than another member it
|
|
replaces that member; if this is never true, then the value is simply
|
|
appended to the localzeros.
|
|
"""
|
|
localzeros = set() # type: ignore[var-annotated]
|
|
for v in values:
|
|
is_newzero = True
|
|
localzeros_ = list(localzeros)
|
|
for z in localzeros_:
|
|
if id(v) == id(z):
|
|
is_newzero = False
|
|
else:
|
|
con = cls._is_connected(v, z)
|
|
if con:
|
|
is_newzero = False
|
|
if con is True or con == cls:
|
|
localzeros.remove(z)
|
|
localzeros.update([v])
|
|
if is_newzero:
|
|
localzeros.update([v])
|
|
return localzeros
|
|
|
|
@classmethod
|
|
def _is_connected(cls, x, y):
|
|
"""
|
|
Check if x and y are connected somehow.
|
|
"""
|
|
if x == y:
|
|
return True
|
|
t, f = Max, Min
|
|
for op in "><":
|
|
for j in range(2):
|
|
try:
|
|
if op == ">":
|
|
v = x >= y
|
|
else:
|
|
v = x <= y
|
|
except TypeError:
|
|
return False # non-real arg
|
|
if not v.is_Relational:
|
|
return t if v else f
|
|
t, f = f, t # type: ignore[assignment]
|
|
x, y = y, x
|
|
x, y = y, x # run next pass with reversed order relative to start
|
|
|
|
return False
|
|
|
|
_eval_is_algebraic = lambda s: _torf(i.is_algebraic for i in s.args) # noqa: E731
|
|
_eval_is_antihermitian = lambda s: _torf( # noqa: E731
|
|
i.is_antihermitian for i in s.args # noqa: E731
|
|
) # noqa: E731
|
|
_eval_is_commutative = lambda s: _torf( # noqa: E731
|
|
i.is_commutative for i in s.args # noqa: E731
|
|
) # noqa: E731
|
|
_eval_is_complex = lambda s: _torf(i.is_complex for i in s.args) # noqa: E731
|
|
_eval_is_composite = lambda s: _torf(i.is_composite for i in s.args) # noqa: E731
|
|
_eval_is_even = lambda s: _torf(i.is_even for i in s.args) # noqa: E731
|
|
_eval_is_finite = lambda s: _torf(i.is_finite for i in s.args) # noqa: E731
|
|
_eval_is_hermitian = lambda s: _torf(i.is_hermitian for i in s.args) # noqa: E731
|
|
_eval_is_imaginary = lambda s: _torf(i.is_imaginary for i in s.args) # noqa: E731
|
|
_eval_is_infinite = lambda s: _torf(i.is_infinite for i in s.args) # noqa: E731
|
|
_eval_is_integer = lambda s: _torf(i.is_integer for i in s.args) # noqa: E731
|
|
_eval_is_irrational = lambda s: _torf(i.is_irrational for i in s.args) # noqa: E731
|
|
_eval_is_negative = lambda s: _torf(i.is_negative for i in s.args) # noqa: E731
|
|
_eval_is_noninteger = lambda s: _torf(i.is_noninteger for i in s.args) # noqa: E731
|
|
_eval_is_nonnegative = lambda s: _torf( # noqa: E731
|
|
i.is_nonnegative for i in s.args # noqa: E731
|
|
) # noqa: E731
|
|
_eval_is_nonpositive = lambda s: _torf( # noqa: E731
|
|
i.is_nonpositive for i in s.args # noqa: E731
|
|
) # noqa: E731
|
|
_eval_is_nonzero = lambda s: _torf(i.is_nonzero for i in s.args) # noqa: E731
|
|
_eval_is_odd = lambda s: _torf(i.is_odd for i in s.args) # noqa: E731
|
|
_eval_is_polar = lambda s: _torf(i.is_polar for i in s.args) # noqa: E731
|
|
_eval_is_positive = lambda s: _torf(i.is_positive for i in s.args) # noqa: E731
|
|
_eval_is_prime = lambda s: _torf(i.is_prime for i in s.args) # noqa: E731
|
|
_eval_is_rational = lambda s: _torf(i.is_rational for i in s.args) # noqa: E731
|
|
_eval_is_real = lambda s: _torf(i.is_real for i in s.args) # noqa: E731
|
|
_eval_is_extended_real = lambda s: _torf( # noqa: E731
|
|
i.is_extended_real for i in s.args # noqa: E731
|
|
) # noqa: E731
|
|
_eval_is_transcendental = lambda s: _torf( # noqa: E731
|
|
i.is_transcendental for i in s.args # noqa: E731
|
|
) # noqa: E731
|
|
_eval_is_zero = lambda s: _torf(i.is_zero for i in s.args) # noqa: E731
|
|
|
|
|
|
class Max(MinMaxBase, Application): # type: ignore[misc]
|
|
r"""
|
|
Return, if possible, the maximum value of the list.
|
|
"""
|
|
zero = S.Infinity
|
|
identity = S.NegativeInfinity
|
|
|
|
def _eval_is_positive(self):
|
|
return fuzzy_or(a.is_positive for a in self.args) # type: ignore[attr-defined]
|
|
|
|
def _eval_is_nonnegative(self):
|
|
return fuzzy_or(a.is_nonnegative for a in self.args) # type: ignore[attr-defined]
|
|
|
|
def _eval_is_negative(self):
|
|
return fuzzy_and(a.is_negative for a in self.args)
|
|
|
|
|
|
class Min(MinMaxBase, Application): # type: ignore[misc]
|
|
"""
|
|
Return, if possible, the minimum value of the list.
|
|
"""
|
|
|
|
zero = S.NegativeInfinity
|
|
identity = S.Infinity
|
|
|
|
def _eval_is_positive(self):
|
|
return fuzzy_and(a.is_positive for a in self.args) # type: ignore[attr-defined]
|
|
|
|
def _eval_is_nonnegative(self):
|
|
return fuzzy_and(a.is_nonnegative for a in self.args) # type: ignore[attr-defined]
|
|
|
|
def _eval_is_negative(self):
|
|
return fuzzy_or(a.is_negative for a in self.args)
|
|
|
|
|
|
def safe_pow(base, exp):
|
|
sign = 1
|
|
if base < 0:
|
|
base = -base
|
|
sign = 1 if exp % 2 == 0 else -1
|
|
return sign * _safe_pow(base, exp)
|
|
|
|
|
|
# Prevent people from overflowing pow
|
|
def _safe_pow(base, exponent):
|
|
if exponent < 0:
|
|
raise ValueError("Exponent must be non-negative.")
|
|
|
|
if exponent == 0:
|
|
return 1
|
|
|
|
half_exp = safe_pow(base, exponent // 2)
|
|
if half_exp is int_oo:
|
|
return int_oo
|
|
|
|
# TODO: microoptimization is to avoid overflowing into arbitrary precision
|
|
# and detect overflow prior to doing operations
|
|
|
|
result = half_exp * half_exp
|
|
if result > sys.maxsize:
|
|
return int_oo
|
|
|
|
if exponent % 2 == 1:
|
|
result *= base
|
|
if result > sys.maxsize:
|
|
return int_oo
|
|
|
|
return result
|
|
|
|
|
|
class PowByNatural(sympy.Function):
|
|
is_integer = True
|
|
|
|
@classmethod
|
|
def eval(cls, base, exp):
|
|
if isinstance(base, sympy.Integer) and isinstance(exp, sympy.Integer):
|
|
r = safe_pow(base, exp)
|
|
if r in (-int_oo, int_oo):
|
|
return r
|
|
return sympy.Integer(r)
|
|
if isinstance(exp, sympy.Integer):
|
|
# Rely on regular sympy Pow for this (note that iterated
|
|
# multiplication turns into a Pow anyway, you can't escape!!)
|
|
return sympy.Pow(base, exp)
|
|
if exp in (int_oo, sympy.oo):
|
|
if base.is_nonnegative:
|
|
return int_oo
|
|
elif base.is_negative:
|
|
return sympy.zoo # this is apparently what (-2)**sympy.oo does
|
|
# NB: do NOT translate into sympy.Pow, we will lose knowledge that exp
|
|
# is a natural number if we do
|
|
|
|
|
|
# base is assumed to be nonnegative, thereby prevent complex numbers from
|
|
# occuring
|
|
class FloatPow(sympy.Function):
|
|
is_real = True
|
|
|
|
@classmethod
|
|
def eval(cls, base, exp):
|
|
# NB: These test sympy.Number, not sympy.Float, because:
|
|
# - Sometimes we may have sympy.oo or int_oo, and that's not a Float
|
|
# (but coerces to math.Inf)
|
|
# - Sometimes Float(0.0) will unpredictably decay to Integer(0),
|
|
# but we should still accept it in floatey contexts
|
|
if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number):
|
|
return sympy.Float(float(base) ** float(exp))
|
|
# NB: do not do any nontrivial reasoning
|
|
|
|
|
|
# Overloaded to be compatible with regular Python.
|
|
# https://github.com/pytorch/pytorch/issues/90900
|
|
#
|
|
# In particular, sympy division is willing to simplify x/x == 1
|
|
# where 1 is an integer, but this must be a float if x was float.
|
|
class FloatTrueDiv(sympy.Function):
|
|
is_real = True
|
|
|
|
@classmethod
|
|
def eval(cls, base, divisor):
|
|
# assert base.is_integer is not True, base
|
|
# assert divisor.is_integer is not True, divisor
|
|
|
|
if divisor.is_zero:
|
|
raise ZeroDivisionError("division by zero")
|
|
|
|
if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number):
|
|
return sympy.Float(float(base) / float(divisor))
|
|
|
|
|
|
# Overloaded to be compatible with regular Python. We distinguish this from
|
|
# FloatTrueDiv, because the code generation has to be different for this case:
|
|
# Python has a fancy algorithm for integer true division that isn't just
|
|
# "promote both arguments to float and use float division", so you need to
|
|
# codegen it differently. While technically you can work it out from the
|
|
# types of the input, this is often inconvenient to do in Inductor codegen,
|
|
# so just have a different operator
|
|
# NB: Right now, Inductor codegen doesn't implement this correctly lol
|
|
class IntTrueDiv(sympy.Function):
|
|
is_real = True
|
|
|
|
@classmethod
|
|
def eval(cls, base, divisor):
|
|
if divisor.is_zero:
|
|
raise ZeroDivisionError("division by zero")
|
|
|
|
if (
|
|
isinstance(base, sympy.Number)
|
|
and isinstance(divisor, sympy.Number)
|
|
and (
|
|
base in (int_oo, -int_oo, sympy.oo, -sympy.oo)
|
|
or divisor in (int_oo, -int_oo, sympy.oo, -sympy.oo)
|
|
)
|
|
):
|
|
# Don't have to worry about precision here, you're getting zero or
|
|
# inf from the division
|
|
return sympy.Float(float(base) / float(divisor))
|
|
if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
|
|
return sympy.Float(int(base) / int(divisor))
|
|
|
|
|
|
# TODO: As an indicator, this != 0 implies == 1 (and vice versa).
|
|
# Because we do not have the ability to guard on the stride permutation
|
|
# at the moment, it is hard to make further inferences when this is true,
|
|
# as although we know the tensor is contiguous in *some* layout, we don't
|
|
# know which one (however, you could, for example, make the inference that
|
|
# reshaping this to a 1D tensor can be guard-free.)
|
|
class IsNonOverlappingAndDenseIndicator(sympy.Function):
|
|
is_integer = True
|
|
|
|
@classmethod
|
|
def eval(cls, *args):
|
|
assert len(args) % 2 == 0
|
|
dim = len(args) // 2
|
|
sizes = args[0:dim]
|
|
strides = args[dim:]
|
|
|
|
# sym_node imported in torch.__init__. Local import to avoid an import cycle
|
|
from torch.fx.experimental.symbolic_shapes import (
|
|
eval_is_non_overlapping_and_dense,
|
|
)
|
|
|
|
if all(isinstance(a, sympy.Integer) for a in args):
|
|
return eval_is_non_overlapping_and_dense(
|
|
[int(a) for a in sizes], [int(a) for a in strides]
|
|
)
|
|
|
|
if dim == 1:
|
|
# Manually implement the rank one short circuit
|
|
if strides[0].is_Number and strides[0] == 1:
|
|
return 1
|
|
|
|
if sizes[0].is_Number and sizes[0] < 2:
|
|
return 1
|
|
|
|
# return 0 case covered by case above
|
|
|
|
# TODO: Inability to access size-obliviousness sucks: if we have a
|
|
# size oblivious test on a size-like unbacked SymInt, we could
|
|
# confidently return zero when we have a size-like u0 stride
|
|
# and a size-like u1 size. Maybe a fancy ValueRanges analysis for
|
|
# this function could help figure this out.
|
|
|
|
if all(isinstance(a, sympy.Integer) for a in strides):
|
|
assert dim != 0
|
|
# When all strides are integral, we can sort, and the size for the
|
|
# largest stride doesn't matter and can be arbitrarily symbolic
|
|
s_sizes, s_strides = zip(
|
|
*sorted(zip(sizes, strides), key=operator.itemgetter(1))
|
|
)
|
|
# Put something arbitrary in the max size spot, it'll be ignored
|
|
if all(isinstance(a, sympy.Integer) for a in s_sizes[:-1]):
|
|
s_sizes = s_sizes[:-1] + (42,)
|
|
# We can reuse the regular eval, because it is invariant to
|
|
# permutation of dimensions
|
|
return eval_is_non_overlapping_and_dense(
|
|
[int(a) for a in s_sizes], [int(a) for a in s_strides]
|
|
)
|
|
|
|
return None
|
|
|
|
|
|
# NB: this is inconsistent with math.trunc in Python
|
|
class TruncToFloat(sympy.Function):
|
|
is_real = True
|
|
|
|
@classmethod
|
|
def eval(cls, number):
|
|
# assert number.is_integer is not True, number
|
|
if isinstance(number, sympy.Number):
|
|
# NB: It is safe to use truncation to integer, which is what
|
|
# math.trunc does, as Python integers are arbitrary precision and
|
|
# so we are guaranteed not to lose precision when we do this
|
|
return sympy.Float(math.trunc(float(number)))
|
|
|
|
|
|
class TruncToInt(sympy.Function):
|
|
is_integer = True
|
|
|
|
@classmethod
|
|
def eval(cls, number):
|
|
# assert number.is_integer is not True, number
|
|
if number in (sympy.oo, int_oo):
|
|
return int_oo
|
|
if number in (-sympy.oo, -int_oo):
|
|
return -int_oo
|
|
if isinstance(number, sympy.Number):
|
|
return sympy.Integer(math.trunc(float(number)))
|
|
|
|
|
|
# This is float -> int
|
|
class RoundToInt(sympy.Function):
|
|
is_integer = True
|
|
|
|
@classmethod
|
|
def eval(cls, number):
|
|
# assert number.is_integer is not True, number
|
|
|
|
if number is sympy.oo:
|
|
return int_oo
|
|
if number is -sympy.oo:
|
|
return -int_oo
|
|
if isinstance(number, sympy.Number):
|
|
return sympy.Integer(round(float(number), 0))
|
|
|
|
|
|
# To get float -> int, Python style round semantics.
|
|
#
|
|
# x = PyFloat_AsDouble(self);
|
|
# if (o_ndigits == Py_None) {
|
|
# /* single-argument round or with None ndigits:
|
|
# * round to nearest integer */
|
|
# rounded = round(x);
|
|
# if (fabs(x-rounded) == 0.5)
|
|
# /* halfway case: round to even */
|
|
# rounded = 2.0*round(x/2.0);
|
|
# return PyLong_FromDouble(rounded);
|
|
# }
|
|
|
|
|
|
# NB: Like Round, this only ever returns floats. ndigits cannot be None
|
|
class RoundDecimal(sympy.Function):
|
|
is_real = True
|
|
|
|
@classmethod
|
|
def eval(cls, number, ndigits):
|
|
# assert number.is_integer is not True, number
|
|
|
|
if isinstance(number, sympy.Number) and isinstance(ndigits, sympy.Integer):
|
|
return sympy.Float(round(float(number), int(ndigits)))
|
|
|
|
|
|
class ToFloat(sympy.Function):
|
|
is_real = True
|
|
|
|
@classmethod
|
|
def eval(cls, number):
|
|
if number in [sympy.oo, -sympy.oo]:
|
|
return number
|
|
|
|
if isinstance(number, sympy.Integer):
|
|
return sympy.Float(int(number))
|
|
if number is int_oo:
|
|
return sympy.oo
|
|
if number is -int_oo:
|
|
return -sympy.oo
|
|
|
|
|
|
class Identity(sympy.Function):
|
|
"""
|
|
Prevents expansion and other optimizations
|
|
"""
|
|
|
|
def __repr__(self):
|
|
return f"Identity({self.args[0]})"
|
|
|
|
def _eval_is_real(self):
|
|
return self.args[0].is_real
|
|
|
|
def _eval_is_integer(self):
|
|
return self.args[0].is_integer # type: ignore[attr-defined]
|
|
|
|
|
|
def make_opaque_unary_fn(name):
|
|
class OpaqueUnaryFn(sympy.Function):
|
|
"""
|
|
Unlike the builtin sympy functions on real numbers like sympy.sqrt,
|
|
these equivalents do not do any nontrivial reasoning besides
|
|
constant propagation. This helps avoid performing transformations
|
|
that are valid for real numbers but are invalid for floating point;
|
|
in particular, while we are willing to make optimizations that change
|
|
numerics for Tensor compute, we are NOT willing to make optimziations
|
|
that change numerics for size compute.
|
|
"""
|
|
|
|
_torch_handler_name = name
|
|
|
|
@classmethod
|
|
def eval(cls, a):
|
|
if isinstance(a, (sympy.Integer, sympy.Float)):
|
|
# Python converts to float64 before computing, c.f.
|
|
# >>> math.sin(2**53+1)
|
|
# -0.848925964814655
|
|
# >>> math.sin(float(2**53+1))
|
|
# -0.848925964814655
|
|
try:
|
|
return sympy.Float(getattr(math, name)(float(a)))
|
|
# Just use sympy semantics for infinity/overflow, you might get some
|
|
# weird objects but ask silly questions, get silly answers
|
|
except OverflowError:
|
|
return getattr(sympy, name)(a)
|
|
elif a in [sympy.oo, -sympy.oo, sympy.zoo, -sympy.zoo, int_oo, -int_oo]:
|
|
if a is int_oo:
|
|
a = sympy.oo
|
|
if a is -int_oo:
|
|
a = -sympy.oo
|
|
return getattr(sympy, name)(a)
|
|
return None
|
|
|
|
OpaqueUnaryFn.__name__ = "OpaqueUnaryFn_" + name
|
|
|
|
return OpaqueUnaryFn
|
|
|
|
|
|
# Keep in sync with math_op_names in torch/fx/experimental/sym_node.py
|
|
OpaqueUnaryFn_sqrt = make_opaque_unary_fn("sqrt")
|
|
OpaqueUnaryFn_cos = make_opaque_unary_fn("cos")
|
|
OpaqueUnaryFn_cosh = make_opaque_unary_fn("cosh")
|
|
OpaqueUnaryFn_sin = make_opaque_unary_fn("sin")
|
|
OpaqueUnaryFn_sinh = make_opaque_unary_fn("sinh")
|
|
OpaqueUnaryFn_tan = make_opaque_unary_fn("tan")
|
|
OpaqueUnaryFn_tanh = make_opaque_unary_fn("tanh")
|
|
OpaqueUnaryFn_asin = make_opaque_unary_fn("asin")
|
|
OpaqueUnaryFn_acos = make_opaque_unary_fn("acos")
|
|
OpaqueUnaryFn_atan = make_opaque_unary_fn("atan")
|
|
OpaqueUnaryFn_exp = make_opaque_unary_fn("exp")
|
|
OpaqueUnaryFn_log = make_opaque_unary_fn("log")
|
|
OpaqueUnaryFn_asinh = make_opaque_unary_fn("asinh")
|