mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Complete revamp of float/promotion sympy handling (#126905)"
This reverts commit fb696ef3aa34e20c0fef1c0210a397abd3ea5885. Reverted https://github.com/pytorch/pytorch/pull/126905 on behalf of https://github.com/ezyang due to internal user reported ceiling equality simplification problem, I have a plan ([comment](https://github.com/pytorch/pytorch/pull/126905#issuecomment-2148805840))
This commit is contained in:
@ -36,12 +36,7 @@ UNARY_OPS = [
|
||||
"floor",
|
||||
"ceil",
|
||||
]
|
||||
BINARY_OPS = [
|
||||
"truediv", "floordiv",
|
||||
# "truncdiv", # TODO
|
||||
# NB: pow is float_pow
|
||||
"add", "mul", "sub", "pow", "pow_by_natural", "minimum", "maximum", "mod"
|
||||
]
|
||||
BINARY_OPS = ["truediv", "div", "floordiv", "truncdiv", "add", "mul", "sub", "pow", "minimum", "maximum", "mod"]
|
||||
|
||||
UNARY_BOOL_OPS = ["not_"]
|
||||
BINARY_BOOL_OPS = ["or_", "and_"]
|
||||
@ -86,24 +81,16 @@ def valid_unary(fn, v):
|
||||
|
||||
def valid_binary(fn, a, b):
|
||||
if fn == "pow" and (
|
||||
# sympy will expand to x*x*... for integral b; don't do it if it's big
|
||||
b > 4
|
||||
# no imaginary numbers
|
||||
or a <= 0
|
||||
# 0**0 is undefined
|
||||
or (a == b == 0)
|
||||
or ( # sympy will expand to x*x*... for integral b; don't do it if it's big
|
||||
a <= 0 and b == -1
|
||||
)
|
||||
or (a == b == 0) # no imaginary numbers # 0**0 is undefined
|
||||
):
|
||||
return False
|
||||
elif fn == "pow_by_natural" and (
|
||||
# sympy will expand to x*x*... for integral b; don't do it if it's big
|
||||
b > 4
|
||||
or b < 0
|
||||
or (a == b == 0)
|
||||
):
|
||||
elif fn == "mod" and b == 0:
|
||||
return False
|
||||
elif fn == "mod" and (a < 0 or b <= 0):
|
||||
return False
|
||||
elif (fn in ["div", "truediv", "floordiv"]) and b == 0:
|
||||
elif (fn == "div" or fn == "truediv") and b == 0:
|
||||
return False
|
||||
return True
|
||||
|
||||
@ -143,26 +130,27 @@ class TestValueRanges(TestCase):
|
||||
ValueRangeAnalysis.pow(ValueRanges.unknown(), ValueRanges.wrap(0.5))
|
||||
|
||||
@parametrize("fn", BINARY_OPS)
|
||||
@parametrize("dtype", ("int", "float"))
|
||||
def test_binary_ref(self, fn, dtype):
|
||||
@parametrize("dtype_a", ("int", "float"))
|
||||
@parametrize("dtype_b", ("int", "float"))
|
||||
def test_binary_ref(self, fn, dtype_a, dtype_b):
|
||||
to_dtype = {"int": sympy.Integer, "float": sympy.Float}
|
||||
# Don't test float on int only methods
|
||||
if dtype == "float" and fn in ["pow_by_natural", "mod"]:
|
||||
return
|
||||
dtype = to_dtype[dtype]
|
||||
dtype_a = to_dtype[dtype_a]
|
||||
dtype_b = to_dtype[dtype_b]
|
||||
for a, b in itertools.product(CONSTANTS, repeat=2):
|
||||
if not valid_binary(fn, a, b):
|
||||
continue
|
||||
a = dtype(a)
|
||||
b = dtype(b)
|
||||
a = dtype_a(a)
|
||||
b = dtype_b(b)
|
||||
with self.subTest(a=a, b=b):
|
||||
r = getattr(ValueRangeAnalysis, fn)(a, b)
|
||||
if r == ValueRanges.unknown():
|
||||
continue
|
||||
ref_r = getattr(ReferenceAnalysis, fn)(a, b)
|
||||
|
||||
self.assertEqual(r.lower.is_integer, r.upper.is_integer)
|
||||
self.assertEqual(ref_r.is_integer, r.upper.is_integer)
|
||||
# sympy.floordiv does 1.0 // 1.0 == 1 rather than 1.0. wtf
|
||||
if fn != "floordiv":
|
||||
self.assertEqual(r.lower.is_integer, r.upper.is_integer)
|
||||
self.assertEqual(ref_r.is_integer, r.upper.is_integer)
|
||||
self.assertEqual(r.lower, r.upper)
|
||||
self.assertEqual(ref_r, r.lower)
|
||||
|
||||
@ -212,8 +200,7 @@ class TestValueRanges(TestCase):
|
||||
|
||||
@parametrize("fn", UNARY_OPS)
|
||||
def test_unary_ref_range(self, fn):
|
||||
# TODO: bring back sympy.oo testing for float unary fns
|
||||
vals = CONSTANTS
|
||||
vals = [-sympy.oo, *CONSTANTS, sympy.oo]
|
||||
for a in generate_range(vals):
|
||||
with self.subTest(a=a):
|
||||
ref_r = getattr(ValueRangeAnalysis, fn)(a)
|
||||
@ -229,26 +216,40 @@ class TestValueRanges(TestCase):
|
||||
# This takes about 4s for all the variants
|
||||
@parametrize("fn", BINARY_OPS + COMPARE_OPS)
|
||||
def test_binary_ref_range(self, fn):
|
||||
# TODO: bring back sympy.oo testing for float unary fns
|
||||
vals = LESS_CONSTANTS
|
||||
vals = [-sympy.oo, *LESS_CONSTANTS, sympy.oo]
|
||||
for a, b in itertools.product(generate_range(vals), repeat=2):
|
||||
# don't attempt pow on exponents that are too large (but oo is OK)
|
||||
if fn == "pow" and b.upper > 4 and b.upper != sympy.oo:
|
||||
continue
|
||||
with self.subTest(a=a, b=b):
|
||||
ref_r = getattr(ValueRangeAnalysis, fn)(a, b)
|
||||
for a0, b0 in itertools.product(LESS_CONSTANTS, repeat=2):
|
||||
if a0 not in a or b0 not in b:
|
||||
continue
|
||||
if not valid_binary(fn, a0, b0):
|
||||
continue
|
||||
with self.subTest(a0=a0, b0=b0):
|
||||
ref_r = getattr(ValueRangeAnalysis, fn)(a, b)
|
||||
r = getattr(ReferenceAnalysis, fn)(
|
||||
sympy.Integer(a0), sympy.Integer(b0)
|
||||
)
|
||||
if r.is_finite:
|
||||
self.assertIn(r, ref_r)
|
||||
|
||||
def test_rational_bounds(self):
|
||||
# Repro from https://github.com/pytorch/pytorch/issues/105097
|
||||
from sympy import floor, Eq
|
||||
shape_0 = sympy.Symbol('shape_0', positive=True, integer=True)
|
||||
new_expr = (
|
||||
Eq(30 * floor(4 * ((shape_0 + 1) // 96) *
|
||||
((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 647 +
|
||||
2584 * ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 647),
|
||||
2880 * floor(((shape_0 + 1) // 96) *
|
||||
((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 15528 +
|
||||
323 * ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 7764)))
|
||||
new_range_env = {shape_0: ValueRanges(lower=1, upper=190)}
|
||||
self.assertTrue(new_expr.subs({shape_0: 95}))
|
||||
self.assertIn(True, sympy_interp(ValueRangeAnalysis, new_range_env, new_expr))
|
||||
|
||||
|
||||
class TestSympyInterp(TestCase):
|
||||
@parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS)
|
||||
@ -257,13 +258,7 @@ class TestSympyInterp(TestCase):
|
||||
if fn in ("div", "truncdiv", "minimum", "maximum", "mod"):
|
||||
return
|
||||
|
||||
is_integer = None
|
||||
if fn == "pow_by_natural":
|
||||
is_integer = True
|
||||
|
||||
x = sympy.Dummy('x', integer=is_integer)
|
||||
y = sympy.Dummy('y', integer=is_integer)
|
||||
|
||||
from sympy.abc import x, y
|
||||
vals = CONSTANTS
|
||||
if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}:
|
||||
vals = [True, False]
|
||||
@ -305,17 +300,29 @@ class TestSympyInterp(TestCase):
|
||||
if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}:
|
||||
arity = 2
|
||||
|
||||
is_integer = None
|
||||
if fn == "pow_by_natural":
|
||||
is_integer = True
|
||||
|
||||
x = sympy.Dummy('x', integer=is_integer)
|
||||
y = sympy.Dummy('y', integer=is_integer)
|
||||
from sympy.abc import x, y
|
||||
|
||||
symbols = [x]
|
||||
if arity == 2:
|
||||
symbols = [x, y]
|
||||
|
||||
# Workaround mpf from symbol error
|
||||
if fn == "minimum":
|
||||
sympy_expr = sympy.Min(x, y)
|
||||
elif fn == "maximum":
|
||||
sympy_expr = sympy.Max(x, y)
|
||||
else:
|
||||
sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols)
|
||||
|
||||
if arity == 1:
|
||||
def trace_f(px):
|
||||
return sympy_interp(PythonReferenceAnalysis, {x: px}, sympy_expr)
|
||||
else:
|
||||
def trace_f(px, py):
|
||||
return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, sympy_expr)
|
||||
|
||||
gm = fx.symbolic_trace(trace_f)
|
||||
|
||||
for args in itertools.product(vals, repeat=arity):
|
||||
if arity == 1 and not valid_unary(fn, *args):
|
||||
continue
|
||||
@ -323,28 +330,11 @@ class TestSympyInterp(TestCase):
|
||||
continue
|
||||
if fn == "truncdiv" and args[1] == 0:
|
||||
continue
|
||||
elif fn in ("pow", "pow_by_natural") and (args[0] == 0 and args[1] <= 0):
|
||||
elif fn == "pow" and (args[0] == 0 and args[1] <= 0):
|
||||
continue
|
||||
elif fn == "floordiv" and args[1] == 0:
|
||||
continue
|
||||
with self.subTest(args=args):
|
||||
# Workaround mpf from symbol error
|
||||
if fn == "minimum":
|
||||
sympy_expr = sympy.Min(x, y)
|
||||
elif fn == "maximum":
|
||||
sympy_expr = sympy.Max(x, y)
|
||||
else:
|
||||
sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols)
|
||||
|
||||
if arity == 1:
|
||||
def trace_f(px):
|
||||
return sympy_interp(PythonReferenceAnalysis, {x: px}, sympy_expr)
|
||||
else:
|
||||
def trace_f(px, py):
|
||||
return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, sympy_expr)
|
||||
|
||||
gm = fx.symbolic_trace(trace_f)
|
||||
|
||||
self.assertEqual(
|
||||
sympy_interp(PythonReferenceAnalysis, dict(zip(symbols, args)), sympy_expr),
|
||||
gm(*args)
|
||||
|
Reference in New Issue
Block a user