Files
pytorch/torch/utils/_sympy/functions.py
2023-07-05 18:35:57 +00:00

175 lines
5.7 KiB
Python

import sympy
from sympy.core.logic import fuzzy_and, fuzzy_or
__all__ = ["FloorDiv", "ModularIndexing", "CleanDiv", "CeilDiv", "LShift", "RShift"]
class FloorDiv(sympy.Function):
    """
    Symbolic floor division: ``FloorDiv(base, divisor)`` stands for ``base // divisor``.

    We maintain this so that:
    1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b.
    2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b)
    """
    nargs = (2,)
    precedence = 50  # precedence of mul  # noqa: F811

    # Default return type for SymPy assumptions.
    # https://docs.sympy.org/latest/guides/assumptions.html#implementing-assumptions-handlers
    is_real = True

    @property
    def base(self):
        """The dividend, i.e. the first argument."""
        return self.args[0]

    @property
    def divisor(self):
        """The divisor, i.e. the second argument."""
        return self.args[1]

    def _sympystr(self, printer):
        """Render as ``(base//divisor)``, letting the printer parenthesize each operand."""
        base = printer.parenthesize(self.base, self.precedence)
        divisor = printer.parenthesize(self.divisor, self.precedence)
        return f"({base}//{divisor})"

    # SymPy assumptions based on argument types.
    def _eval_is_real(self):
        """Fuzzy OR of the operands' ``is_real`` assumptions."""
        return fuzzy_or([self.base.is_real, self.divisor.is_real])

    def _eval_is_integer(self):
        """Fuzzy AND of the operands' ``is_integer`` assumptions."""
        return fuzzy_and([self.base.is_integer, self.divisor.is_integer])

    # Automatic evaluation.
    # https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval
    @classmethod
    def eval(cls, base, divisor):
        """
        Simplify ``base // divisor`` where possible.

        Returns a simplified expression, or ``None`` to leave the call
        unevaluated.

        Raises:
            TypeError: if an operand is a Boolean, or is complex while known
                to be neither integer nor real.
            ZeroDivisionError: if ``divisor`` is known to be zero.
        """
        def check_supported_type(x):
            if (x.is_integer is False and x.is_real is False and x.is_complex) or x.is_Boolean:
                raise TypeError(
                    f"unsupported operand type(s) for //: "
                    f"'{type(base).__name__}' and '{type(divisor).__name__}'"
                    f", expected integer or real")

        check_supported_type(base)
        check_supported_type(divisor)

        # We don't provide the same error message as in Python because SymPy
        # makes it difficult to check the types.
        if divisor.is_zero:
            raise ZeroDivisionError("division by zero")

        if base.is_zero:
            return sympy.S.Zero
        if base.is_integer and divisor == 1:
            return base
        if base.is_real and divisor == 1:
            return sympy.floor(base)
        if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
            return base // divisor
        if isinstance(base, (sympy.Integer, sympy.Float)) and isinstance(divisor, (sympy.Integer, sympy.Float)):
            return sympy.floor(base / divisor)
        if isinstance(base, FloorDiv):
            # (a // b) // c  ->  a // (b * c)
            # NOTE(review): this identity holds for positive integer divisors;
            # the sign of the divisors is not checked here.
            return FloorDiv(base.args[0], base.args[1] * divisor)

        if isinstance(base, sympy.Add):
            # Pull any term that is an exact multiple of the divisor out of
            # the floor division.
            for a in base.args:
                try:
                    gcd = sympy.gcd(a, divisor)
                except sympy.PolynomialError:
                    # sympy.gcd cannot treat this term as a polynomial (e.g.
                    # it contains a Piecewise); divisibility is unknown, so
                    # leave the term where it is.
                    continue
                if gcd == divisor:
                    return FloorDiv(base - a, divisor) + a / gcd

        # Cancel a common factor between base and divisor. This is a
        # best-effort simplification: if sympy cannot compute it, the
        # expression is simply left unevaluated instead of failing.
        try:
            gcd = sympy.gcd(base, divisor)
            if gcd != 1:
                return FloorDiv(
                    sympy.simplify(base / gcd), sympy.simplify(divisor / gcd)
                )
        except sympy.PolynomialError:
            pass

        return None
class ModularIndexing(sympy.Function):
    """
    ModularIndexing(a, b, c) => (a // b) % c
    """
    nargs = (3,)
    is_integer = True

    @classmethod
    def eval(cls, base, divisor, modulus):
        """
        Simplify ``(base // divisor) % modulus`` where possible.

        Returns a simplified expression, or ``None`` to leave the call
        unevaluated.
        """
        if base == 0 or modulus == 1:
            return sympy.Integer(0)

        if (
            isinstance(base, sympy.Integer)
            and isinstance(divisor, sympy.Integer)
            and isinstance(modulus, sympy.Integer)
        ):
            return (base // divisor) % modulus

        # Cancel a common factor between base and divisor. Best effort only:
        # sympy.gcd raises PolynomialError for operands it cannot treat as
        # polynomials (e.g. containing a Piecewise), in which case this
        # simplification is skipped rather than aborting construction.
        try:
            if divisor != 1:
                gcd = sympy.gcd(base, divisor)
                if gcd != 1:
                    return ModularIndexing(
                        sympy.simplify(base / gcd), sympy.simplify(divisor / gcd), modulus
                    )
        except sympy.PolynomialError:
            pass

        if isinstance(base, sympy.Add):
            # Terms that are exact multiples of modulus * divisor are dropped;
            # every other term is kept.
            period = modulus * divisor
            new_terms = []
            all_positive = True
            for term in base.args:
                try:
                    is_multiple = sympy.gcd(term, period) == period
                except sympy.PolynomialError:
                    # Divisibility cannot be established, so keep the term.
                    is_multiple = False
                if is_multiple:
                    continue
                if (isinstance(term, sympy.Integer) and term < 0) or (
                    isinstance(term, sympy.Mul)
                    and isinstance(term.args[0], sympy.Integer)
                    and term.args[0] < 0
                ):
                    # workaround for https://github.com/openai/triton/issues/619,
                    # if there are negative terms, // produces wrong result
                    # TODO if https://github.com/openai/triton/issues/619 is fixed
                    # this optimization would become valid
                    all_positive = False
                    break
                new_terms.append(term)

            # Only rewrite when at least one term was actually dropped.
            if len(new_terms) != len(base.args) and all_positive:
                return ModularIndexing(sum(new_terms), divisor, modulus)

        if isinstance(base, FloorDiv):
            # ((a // b) // c) % m  ->  (a // (b * c)) % m
            return ModularIndexing(base.args[0], base.args[1] * divisor, modulus)

        return None
class CleanDiv(FloorDiv):
    """
    Floor division where the caller may assume the result needs no rounding.

    It adds nothing to FloorDiv; it exists as a distinct type to enable
    future optimizations.
    """
class CeilDiv(sympy.Function):
    """
    Division used in indexing that rounds the quotient up.

    Construction returns ``CleanDiv(base, divisor)`` when the gcd of the
    operands equals the divisor, and otherwise
    ``FloorDiv(base + (divisor - 1), divisor)``.
    """
    is_integer = True

    def __new__(cls, base, divisor):
        divides_exactly = sympy.gcd(base, divisor) == divisor
        if divides_exactly:
            return CleanDiv(base, divisor)
        # Round up by biasing the dividend before flooring.
        biased_base = base + (divisor - 1)
        return FloorDiv(biased_base, divisor)
class LShift(sympy.Function):
    """
    Left shift: ``LShift(base, shift)`` evaluates to ``base * 2 ** shift``.
    """

    @classmethod
    def eval(cls, base, shift):
        """
        Evaluate the shift once the sign of ``shift`` can be decided.

        Returns ``base * 2 ** shift``, or ``None`` (call left unevaluated)
        when SymPy cannot decide whether ``shift`` is negative.

        Raises:
            ValueError: if ``shift`` is known to be negative.
        """
        is_negative = shift < 0
        if isinstance(is_negative, sympy.Rel):
            # The comparison stayed symbolic, so the sign is unknown. Taking
            # its truth value would raise TypeError; stay unevaluated so the
            # check runs again once the shift is substituted.
            return None
        if is_negative:
            raise ValueError('negative shift count')
        return base * 2 ** shift
class RShift(sympy.Function):
    """
    Right shift: ``RShift(base, shift)`` evaluates to ``base // 2 ** shift``.
    """

    @classmethod
    def eval(cls, base, shift):
        """
        Evaluate the shift once the sign of ``shift`` can be decided.

        Returns ``base // 2 ** shift``, or ``None`` (call left unevaluated)
        when SymPy cannot decide whether ``shift`` is negative.

        Raises:
            ValueError: if ``shift`` is known to be negative.
        """
        is_negative = shift < 0
        if isinstance(is_negative, sympy.Rel):
            # The comparison stayed symbolic, so the sign is unknown. Taking
            # its truth value would raise TypeError; stay unevaluated so the
            # check runs again once the shift is substituted.
            return None
        if is_negative:
            raise ValueError('negative shift count')
        return base // 2 ** shift