mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Replace sympy.solve with a new simplified one. (#105877)
This PR implements `try_solve`: a function that tries to move terms of a relational expression around, so as to isolate a given variable on the left-hand side. For example: ```python >>> try_solve(Eq(a + 5, 3), a) Eq(a, -2) >>> try_solve(Gt(Mod(a, 3), 0), a) # returns None >>> try_solve(Gt(Mod(a, 3), 0), Mod(a, 3)) Gt(Mod(a, 3), 0), Mod(a, 3) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/105877 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
bfed2da2e4
commit
bd84651e19
@ -4,12 +4,17 @@ import itertools
|
||||
import sys
|
||||
|
||||
import sympy
|
||||
from typing import Callable, List, Tuple, Type
|
||||
from torch.testing._internal.common_device_type import skipIf
|
||||
from torch.testing._internal.common_utils import (
|
||||
TEST_Z3,
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
run_tests,
|
||||
TestCase,
|
||||
)
|
||||
from torch.utils._sympy.functions import FloorDiv
|
||||
from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve
|
||||
from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
|
||||
from torch.utils._sympy.reference import ReferenceAnalysis
|
||||
from torch.utils._sympy.interp import sympy_interp
|
||||
@ -55,6 +60,8 @@ CONSTANTS = [
|
||||
]
|
||||
# less constants for N^2 situations
|
||||
LESS_CONSTANTS = [-1, 0, 1, 2, 100]
|
||||
# SymPy relational types.
|
||||
RELATIONAL_TYPES = [sympy.Eq, sympy.Ne, sympy.Gt, sympy.Ge, sympy.Lt, sympy.Le]
|
||||
|
||||
|
||||
def valid_unary(fn, v):
|
||||
@ -271,8 +278,252 @@ class TestSympyInterp(TestCase):
|
||||
self.assertEqual(ref_r, r)
|
||||
|
||||
|
||||
def type_name_fn(type: Type) -> str:
|
||||
return type.__name__
|
||||
|
||||
def parametrize_relational_types(*types):
|
||||
def wrapper(f: Callable):
|
||||
return parametrize("op", types or RELATIONAL_TYPES, name_fn=type_name_fn)(f)
|
||||
return wrapper
|
||||
|
||||
|
||||
class TestSympySolve(TestCase):
|
||||
def _create_integer_symbols(self) -> List[sympy.Symbol]:
|
||||
return sympy.symbols("a b c", integer=True)
|
||||
|
||||
def test_give_up(self):
|
||||
from sympy import Eq, Ne
|
||||
|
||||
a, b, c = self._create_integer_symbols()
|
||||
|
||||
cases = [
|
||||
# Not a relational operation.
|
||||
a + b,
|
||||
# 'a' appears on both sides.
|
||||
Eq(a, a + 1),
|
||||
# 'a' doesn't appear on neither side.
|
||||
Eq(b, c + 1),
|
||||
# Result is a 'sympy.And'.
|
||||
Eq(FloorDiv(a, b), c),
|
||||
# Result is a 'sympy.Or'.
|
||||
Ne(FloorDiv(a, b), c),
|
||||
]
|
||||
|
||||
for case in cases:
|
||||
e = try_solve(case, a)
|
||||
self.assertEqual(e, None)
|
||||
|
||||
@parametrize_relational_types()
|
||||
def test_noop(self, op):
|
||||
a, b, _ = self._create_integer_symbols()
|
||||
|
||||
lhs, rhs = a, 42 * b
|
||||
expr = op(lhs, rhs)
|
||||
|
||||
r = try_solve(expr, a)
|
||||
self.assertNotEqual(r, None)
|
||||
|
||||
r_expr, r_rhs = r
|
||||
self.assertEqual(r_expr, expr)
|
||||
self.assertEqual(r_rhs, rhs)
|
||||
|
||||
@parametrize_relational_types()
|
||||
def test_noop_rhs(self, op):
|
||||
a, b, _ = self._create_integer_symbols()
|
||||
|
||||
lhs, rhs = 42 * b, a
|
||||
|
||||
mirror = mirror_rel_op(op)
|
||||
self.assertNotEqual(mirror, None)
|
||||
|
||||
expr = op(lhs, rhs)
|
||||
|
||||
r = try_solve(expr, a)
|
||||
self.assertNotEqual(r, None)
|
||||
|
||||
r_expr, r_rhs = r
|
||||
self.assertEqual(r_expr, mirror(rhs, lhs))
|
||||
self.assertEqual(r_rhs, lhs)
|
||||
|
||||
def _test_cases(self, cases: List[Tuple[sympy.Basic, sympy.Basic]], thing: sympy.Basic, op: Type[sympy.Rel], **kwargs):
|
||||
for source, expected in cases:
|
||||
r = try_solve(source, thing, **kwargs)
|
||||
|
||||
self.assertTrue(
|
||||
(r is None and expected is None)
|
||||
or (r is not None and expected is not None)
|
||||
)
|
||||
|
||||
if r is not None:
|
||||
r_expr, r_rhs = r
|
||||
self.assertEqual(r_rhs, expected)
|
||||
self.assertEqual(r_expr, op(thing, expected))
|
||||
|
||||
def test_addition(self):
|
||||
from sympy import Eq
|
||||
|
||||
a, b, c = self._create_integer_symbols()
|
||||
|
||||
cases = [
|
||||
(Eq(a + b, 0), -b),
|
||||
(Eq(a + 5, b - 5), b - 10),
|
||||
(Eq(a + c * b, 1), 1 - c * b),
|
||||
]
|
||||
|
||||
self._test_cases(cases, a, Eq)
|
||||
|
||||
@parametrize_relational_types(sympy.Eq, sympy.Ne)
|
||||
def test_multiplication_division(self, op):
|
||||
a, b, c = self._create_integer_symbols()
|
||||
|
||||
cases = [
|
||||
(op(a * b, 1), 1 / b),
|
||||
(op(a * 5, b - 5), (b - 5) / 5),
|
||||
(op(a * b, c), c / b),
|
||||
]
|
||||
|
||||
self._test_cases(cases, a, op)
|
||||
|
||||
@parametrize_relational_types(*INEQUALITY_TYPES)
|
||||
def test_multiplication_division_inequality(self, op):
|
||||
a, b, _ = self._create_integer_symbols()
|
||||
intneg = sympy.Symbol("neg", integer=True, negative=True)
|
||||
intpos = sympy.Symbol("pos", integer=True, positive=True)
|
||||
|
||||
cases = [
|
||||
# Divide/multiply both sides by positive number.
|
||||
(op(a * intpos, 1), 1 / intpos),
|
||||
(op(a / (5 * intpos), 1), 5 * intpos),
|
||||
(op(a * 5, b - 5), (b - 5) / 5),
|
||||
# 'b' is not strictly positive nor negative, so we can't
|
||||
# divide/multiply both sides by 'b'.
|
||||
(op(a * b, 1), None),
|
||||
(op(a / b, 1), None),
|
||||
(op(a * b * intpos, 1), None),
|
||||
]
|
||||
|
||||
mirror_cases = [
|
||||
# Divide/multiply both sides by negative number.
|
||||
(op(a * intneg, 1), 1 / intneg),
|
||||
(op(a / (5 * intneg), 1), 5 * intneg),
|
||||
(op(a * -5, b - 5), -(b - 5) / 5),
|
||||
]
|
||||
mirror_op = mirror_rel_op(op)
|
||||
assert mirror_op is not None
|
||||
|
||||
self._test_cases(cases, a, op)
|
||||
self._test_cases(mirror_cases, a, mirror_op)
|
||||
|
||||
@parametrize_relational_types()
|
||||
def test_floordiv(self, op):
|
||||
from sympy import Eq, Ne, Gt, Ge, Lt, Le
|
||||
|
||||
a, b, c = sympy.symbols("a b c")
|
||||
pos = sympy.Symbol("pos", positive=True)
|
||||
integer = sympy.Symbol("integer", integer=True)
|
||||
|
||||
# (Eq(FloorDiv(a, pos), integer), And(Ge(a, integer * pos), Lt(a, (integer + 1) * pos))),
|
||||
# (Eq(FloorDiv(a + 5, pos), integer), And(Ge(a, integer * pos), Lt(a, (integer + 1) * pos))),
|
||||
# (Ne(FloorDiv(a, pos), integer), Or(Lt(a, integer * pos), Ge(a, (integer + 1) * pos))),
|
||||
|
||||
special_case = {
|
||||
# 'FloorDiv' turns into 'And', which can't be simplified any further.
|
||||
Eq: (Eq(FloorDiv(a, pos), integer), None),
|
||||
# 'FloorDiv' turns into 'Or', which can't be simplified any further.
|
||||
Ne: (Ne(FloorDiv(a, pos), integer), None),
|
||||
Gt: (Gt(FloorDiv(a, pos), integer), (integer + 1) * pos),
|
||||
Ge: (Ge(FloorDiv(a, pos), integer), integer * pos),
|
||||
Lt: (Lt(FloorDiv(a, pos), integer), integer * pos),
|
||||
Le: (Le(FloorDiv(a, pos), integer), (integer + 1) * pos),
|
||||
}[op]
|
||||
|
||||
cases: List[Tuple[sympy.Basic, sympy.Basic]] = [
|
||||
# 'b' is not strictly positive
|
||||
(op(FloorDiv(a, b), integer), None),
|
||||
# 'c' is not strictly positive
|
||||
(op(FloorDiv(a, pos), c), None),
|
||||
]
|
||||
|
||||
# The result might change after 'FloorDiv' transformation.
|
||||
# Specifically:
|
||||
# - [Ge, Gt] => Ge
|
||||
# - [Le, Lt] => Lt
|
||||
if op in (sympy.Gt, sympy.Ge):
|
||||
r_op = sympy.Ge
|
||||
elif op in (sympy.Lt, sympy.Le):
|
||||
r_op = sympy.Lt
|
||||
else:
|
||||
r_op = op
|
||||
|
||||
self._test_cases([special_case, *cases], a, r_op)
|
||||
self._test_cases([(special_case[0], None), *cases], a, r_op, floordiv_inequality=False)
|
||||
|
||||
def test_floordiv_eq_simplify(self):
|
||||
from sympy import Eq, Lt, Le
|
||||
|
||||
a = sympy.Symbol("a", positive=True, integer=True)
|
||||
|
||||
def check(expr, expected):
|
||||
r = try_solve(expr, a)
|
||||
self.assertNotEqual(r, None)
|
||||
r_expr, _ = r
|
||||
self.assertEqual(r_expr, expected)
|
||||
|
||||
# (a + 10) // 3 == 3
|
||||
# =====================================
|
||||
# 3 * 3 <= a + 10 (always true)
|
||||
# a + 10 < 4 * 3 (not sure)
|
||||
check(Eq(FloorDiv(a + 10, 3), 3), Lt(a, (3 + 1) * 3 - 10))
|
||||
|
||||
# (a + 10) // 2 == 4
|
||||
# =====================================
|
||||
# 4 * 2 <= 10 - a (not sure)
|
||||
# 10 - a < 5 * 2 (always true)
|
||||
check(Eq(FloorDiv(10 - a, 2), 4), Le(a, -(4 * 2 - 10)))
|
||||
|
||||
@skipIf(not TEST_Z3, "Z3 not installed")
|
||||
def test_z3_proof_floordiv_eq_simplify(self):
|
||||
import z3
|
||||
from sympy import Eq, Lt
|
||||
|
||||
a = sympy.Symbol("a", positive=True, integer=True)
|
||||
a_ = z3.Int("a")
|
||||
|
||||
# (a + 10) // 3 == 3
|
||||
# =====================================
|
||||
# 3 * 3 <= a + 10 (always true)
|
||||
# a + 10 < 4 * 3 (not sure)
|
||||
solver = z3.SolverFor("QF_NRA")
|
||||
|
||||
# Add assertions for 'a_'.
|
||||
solver.add(a_ > 0)
|
||||
|
||||
expr = Eq(FloorDiv(a + 10, 3), 3)
|
||||
r_expr, _ = try_solve(expr, a)
|
||||
|
||||
# Check 'try_solve' really returns the 'expected' below.
|
||||
expected = Lt(a, (3 + 1) * 3 - 10)
|
||||
self.assertEqual(r_expr, expected)
|
||||
|
||||
# Check whether there is an integer 'a_' such that the
|
||||
# equation below is satisfied.
|
||||
solver.add(
|
||||
# expr
|
||||
(z3.ToInt((a_ + 10) / 3.0) == 3)
|
||||
!=
|
||||
# expected
|
||||
(a_ < (3 + 1) * 3 - 10)
|
||||
)
|
||||
|
||||
# Assert that there's no such an integer.
|
||||
# i.e. the transformation is sound.
|
||||
r = solver.check()
|
||||
self.assertEqual(r, z3.unsat)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestValueRanges)
|
||||
instantiate_parametrized_tests(TestSympyInterp)
|
||||
instantiate_parametrized_tests(TestSympySolve)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user