Revert "Complete revamp of float/promotion sympy handling (#126905)"

This reverts commit fb696ef3aa34e20c0fef1c0210a397abd3ea5885.

Reverted https://github.com/pytorch/pytorch/pull/126905 on behalf of https://github.com/ezyang due to internal user reported ceiling equality simplification problem, I have a plan ([comment](https://github.com/pytorch/pytorch/pull/126905#issuecomment-2148805840))
This commit is contained in:
PyTorch MergeBot
2024-06-05 03:57:58 +00:00
parent 55a4ef80c4
commit d5cb5d623a
37 changed files with 669 additions and 1605 deletions

View File

@ -36,12 +36,7 @@ UNARY_OPS = [
"floor",
"ceil",
]
BINARY_OPS = [
"truediv", "floordiv",
# "truncdiv", # TODO
# NB: pow is float_pow
"add", "mul", "sub", "pow", "pow_by_natural", "minimum", "maximum", "mod"
]
BINARY_OPS = ["truediv", "div", "floordiv", "truncdiv", "add", "mul", "sub", "pow", "minimum", "maximum", "mod"]
UNARY_BOOL_OPS = ["not_"]
BINARY_BOOL_OPS = ["or_", "and_"]
@ -86,24 +81,16 @@ def valid_unary(fn, v):
def valid_binary(fn, a, b):
if fn == "pow" and (
# sympy will expand to x*x*... for integral b; don't do it if it's big
b > 4
# no imaginary numbers
or a <= 0
# 0**0 is undefined
or (a == b == 0)
or ( # sympy will expand to x*x*... for integral b; don't do it if it's big
a <= 0 and b == -1
)
or (a == b == 0) # no imaginary numbers # 0**0 is undefined
):
return False
elif fn == "pow_by_natural" and (
# sympy will expand to x*x*... for integral b; don't do it if it's big
b > 4
or b < 0
or (a == b == 0)
):
elif fn == "mod" and b == 0:
return False
elif fn == "mod" and (a < 0 or b <= 0):
return False
elif (fn in ["div", "truediv", "floordiv"]) and b == 0:
elif (fn == "div" or fn == "truediv") and b == 0:
return False
return True
@ -143,26 +130,27 @@ class TestValueRanges(TestCase):
ValueRangeAnalysis.pow(ValueRanges.unknown(), ValueRanges.wrap(0.5))
@parametrize("fn", BINARY_OPS)
@parametrize("dtype", ("int", "float"))
def test_binary_ref(self, fn, dtype):
@parametrize("dtype_a", ("int", "float"))
@parametrize("dtype_b", ("int", "float"))
def test_binary_ref(self, fn, dtype_a, dtype_b):
to_dtype = {"int": sympy.Integer, "float": sympy.Float}
# Don't test float on int only methods
if dtype == "float" and fn in ["pow_by_natural", "mod"]:
return
dtype = to_dtype[dtype]
dtype_a = to_dtype[dtype_a]
dtype_b = to_dtype[dtype_b]
for a, b in itertools.product(CONSTANTS, repeat=2):
if not valid_binary(fn, a, b):
continue
a = dtype(a)
b = dtype(b)
a = dtype_a(a)
b = dtype_b(b)
with self.subTest(a=a, b=b):
r = getattr(ValueRangeAnalysis, fn)(a, b)
if r == ValueRanges.unknown():
continue
ref_r = getattr(ReferenceAnalysis, fn)(a, b)
self.assertEqual(r.lower.is_integer, r.upper.is_integer)
self.assertEqual(ref_r.is_integer, r.upper.is_integer)
# sympy.floordiv does 1.0 // 1.0 == 1 rather than 1.0. wtf
if fn != "floordiv":
self.assertEqual(r.lower.is_integer, r.upper.is_integer)
self.assertEqual(ref_r.is_integer, r.upper.is_integer)
self.assertEqual(r.lower, r.upper)
self.assertEqual(ref_r, r.lower)
@ -212,8 +200,7 @@ class TestValueRanges(TestCase):
@parametrize("fn", UNARY_OPS)
def test_unary_ref_range(self, fn):
# TODO: bring back sympy.oo testing for float unary fns
vals = CONSTANTS
vals = [-sympy.oo, *CONSTANTS, sympy.oo]
for a in generate_range(vals):
with self.subTest(a=a):
ref_r = getattr(ValueRangeAnalysis, fn)(a)
@ -229,26 +216,40 @@ class TestValueRanges(TestCase):
# This takes about 4s for all the variants
@parametrize("fn", BINARY_OPS + COMPARE_OPS)
def test_binary_ref_range(self, fn):
# TODO: bring back sympy.oo testing for float unary fns
vals = LESS_CONSTANTS
vals = [-sympy.oo, *LESS_CONSTANTS, sympy.oo]
for a, b in itertools.product(generate_range(vals), repeat=2):
# don't attempt pow on exponents that are too large (but oo is OK)
if fn == "pow" and b.upper > 4 and b.upper != sympy.oo:
continue
with self.subTest(a=a, b=b):
ref_r = getattr(ValueRangeAnalysis, fn)(a, b)
for a0, b0 in itertools.product(LESS_CONSTANTS, repeat=2):
if a0 not in a or b0 not in b:
continue
if not valid_binary(fn, a0, b0):
continue
with self.subTest(a0=a0, b0=b0):
ref_r = getattr(ValueRangeAnalysis, fn)(a, b)
r = getattr(ReferenceAnalysis, fn)(
sympy.Integer(a0), sympy.Integer(b0)
)
if r.is_finite:
self.assertIn(r, ref_r)
def test_rational_bounds(self):
# Repro from https://github.com/pytorch/pytorch/issues/105097
from sympy import floor, Eq
shape_0 = sympy.Symbol('shape_0', positive=True, integer=True)
new_expr = (
Eq(30 * floor(4 * ((shape_0 + 1) // 96) *
((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 647 +
2584 * ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 647),
2880 * floor(((shape_0 + 1) // 96) *
((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 15528 +
323 * ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 7764)))
new_range_env = {shape_0: ValueRanges(lower=1, upper=190)}
self.assertTrue(new_expr.subs({shape_0: 95}))
self.assertIn(True, sympy_interp(ValueRangeAnalysis, new_range_env, new_expr))
class TestSympyInterp(TestCase):
@parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS)
@ -257,13 +258,7 @@ class TestSympyInterp(TestCase):
if fn in ("div", "truncdiv", "minimum", "maximum", "mod"):
return
is_integer = None
if fn == "pow_by_natural":
is_integer = True
x = sympy.Dummy('x', integer=is_integer)
y = sympy.Dummy('y', integer=is_integer)
from sympy.abc import x, y
vals = CONSTANTS
if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}:
vals = [True, False]
@ -305,17 +300,29 @@ class TestSympyInterp(TestCase):
if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}:
arity = 2
is_integer = None
if fn == "pow_by_natural":
is_integer = True
x = sympy.Dummy('x', integer=is_integer)
y = sympy.Dummy('y', integer=is_integer)
from sympy.abc import x, y
symbols = [x]
if arity == 2:
symbols = [x, y]
# Workaround mpf from symbol error
if fn == "minimum":
sympy_expr = sympy.Min(x, y)
elif fn == "maximum":
sympy_expr = sympy.Max(x, y)
else:
sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols)
if arity == 1:
def trace_f(px):
return sympy_interp(PythonReferenceAnalysis, {x: px}, sympy_expr)
else:
def trace_f(px, py):
return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, sympy_expr)
gm = fx.symbolic_trace(trace_f)
for args in itertools.product(vals, repeat=arity):
if arity == 1 and not valid_unary(fn, *args):
continue
@ -323,28 +330,11 @@ class TestSympyInterp(TestCase):
continue
if fn == "truncdiv" and args[1] == 0:
continue
elif fn in ("pow", "pow_by_natural") and (args[0] == 0 and args[1] <= 0):
elif fn == "pow" and (args[0] == 0 and args[1] <= 0):
continue
elif fn == "floordiv" and args[1] == 0:
continue
with self.subTest(args=args):
# Workaround mpf from symbol error
if fn == "minimum":
sympy_expr = sympy.Min(x, y)
elif fn == "maximum":
sympy_expr = sympy.Max(x, y)
else:
sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols)
if arity == 1:
def trace_f(px):
return sympy_interp(PythonReferenceAnalysis, {x: px}, sympy_expr)
else:
def trace_f(px, py):
return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, sympy_expr)
gm = fx.symbolic_trace(trace_f)
self.assertEqual(
sympy_interp(PythonReferenceAnalysis, dict(zip(symbols, args)), sympy_expr),
gm(*args)