Replace sympy.solve with a new simplified one. (#105877)

This PR implements `try_solve`: a function that tries to move terms of a relational
expression around, so as to isolate a given variable on the left-hand side.

For example:

```python
>>> try_solve(Eq(a + 5, 3), a)
Eq(a, -2)
>>> try_solve(Gt(Mod(a, 3), 0), a) # returns None
>>> try_solve(Gt(Mod(a, 3), 0), Mod(a, 3))
Gt(Mod(a, 3), 0), Mod(a, 3)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105877
Approved by: https://github.com/ezyang
This commit is contained in:
Yukio Siraichi
2023-08-01 16:32:49 -03:00
committed by PyTorch MergeBot
parent bfed2da2e4
commit bd84651e19
3 changed files with 447 additions and 116 deletions

View File

@ -4,12 +4,17 @@ import itertools
import sys
import sympy
from typing import Callable, List, Tuple, Type
from torch.testing._internal.common_device_type import skipIf
from torch.testing._internal.common_utils import (
TEST_Z3,
instantiate_parametrized_tests,
parametrize,
run_tests,
TestCase,
)
from torch.utils._sympy.functions import FloorDiv
from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve
from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
from torch.utils._sympy.reference import ReferenceAnalysis
from torch.utils._sympy.interp import sympy_interp
@ -55,6 +60,8 @@ CONSTANTS = [
]
# less constants for N^2 situations
LESS_CONSTANTS = [-1, 0, 1, 2, 100]
# SymPy relational types.
RELATIONAL_TYPES = [sympy.Eq, sympy.Ne, sympy.Gt, sympy.Ge, sympy.Lt, sympy.Le]
def valid_unary(fn, v):
@ -271,8 +278,252 @@ class TestSympyInterp(TestCase):
self.assertEqual(ref_r, r)
def type_name_fn(type: Type) -> str:
return type.__name__
def parametrize_relational_types(*types):
def wrapper(f: Callable):
return parametrize("op", types or RELATIONAL_TYPES, name_fn=type_name_fn)(f)
return wrapper
class TestSympySolve(TestCase):
def _create_integer_symbols(self) -> List[sympy.Symbol]:
return sympy.symbols("a b c", integer=True)
def test_give_up(self):
from sympy import Eq, Ne
a, b, c = self._create_integer_symbols()
cases = [
# Not a relational operation.
a + b,
# 'a' appears on both sides.
Eq(a, a + 1),
# 'a' doesn't appear on neither side.
Eq(b, c + 1),
# Result is a 'sympy.And'.
Eq(FloorDiv(a, b), c),
# Result is a 'sympy.Or'.
Ne(FloorDiv(a, b), c),
]
for case in cases:
e = try_solve(case, a)
self.assertEqual(e, None)
@parametrize_relational_types()
def test_noop(self, op):
a, b, _ = self._create_integer_symbols()
lhs, rhs = a, 42 * b
expr = op(lhs, rhs)
r = try_solve(expr, a)
self.assertNotEqual(r, None)
r_expr, r_rhs = r
self.assertEqual(r_expr, expr)
self.assertEqual(r_rhs, rhs)
@parametrize_relational_types()
def test_noop_rhs(self, op):
a, b, _ = self._create_integer_symbols()
lhs, rhs = 42 * b, a
mirror = mirror_rel_op(op)
self.assertNotEqual(mirror, None)
expr = op(lhs, rhs)
r = try_solve(expr, a)
self.assertNotEqual(r, None)
r_expr, r_rhs = r
self.assertEqual(r_expr, mirror(rhs, lhs))
self.assertEqual(r_rhs, lhs)
def _test_cases(self, cases: List[Tuple[sympy.Basic, sympy.Basic]], thing: sympy.Basic, op: Type[sympy.Rel], **kwargs):
for source, expected in cases:
r = try_solve(source, thing, **kwargs)
self.assertTrue(
(r is None and expected is None)
or (r is not None and expected is not None)
)
if r is not None:
r_expr, r_rhs = r
self.assertEqual(r_rhs, expected)
self.assertEqual(r_expr, op(thing, expected))
def test_addition(self):
from sympy import Eq
a, b, c = self._create_integer_symbols()
cases = [
(Eq(a + b, 0), -b),
(Eq(a + 5, b - 5), b - 10),
(Eq(a + c * b, 1), 1 - c * b),
]
self._test_cases(cases, a, Eq)
@parametrize_relational_types(sympy.Eq, sympy.Ne)
def test_multiplication_division(self, op):
a, b, c = self._create_integer_symbols()
cases = [
(op(a * b, 1), 1 / b),
(op(a * 5, b - 5), (b - 5) / 5),
(op(a * b, c), c / b),
]
self._test_cases(cases, a, op)
@parametrize_relational_types(*INEQUALITY_TYPES)
def test_multiplication_division_inequality(self, op):
a, b, _ = self._create_integer_symbols()
intneg = sympy.Symbol("neg", integer=True, negative=True)
intpos = sympy.Symbol("pos", integer=True, positive=True)
cases = [
# Divide/multiply both sides by positive number.
(op(a * intpos, 1), 1 / intpos),
(op(a / (5 * intpos), 1), 5 * intpos),
(op(a * 5, b - 5), (b - 5) / 5),
# 'b' is not strictly positive nor negative, so we can't
# divide/multiply both sides by 'b'.
(op(a * b, 1), None),
(op(a / b, 1), None),
(op(a * b * intpos, 1), None),
]
mirror_cases = [
# Divide/multiply both sides by negative number.
(op(a * intneg, 1), 1 / intneg),
(op(a / (5 * intneg), 1), 5 * intneg),
(op(a * -5, b - 5), -(b - 5) / 5),
]
mirror_op = mirror_rel_op(op)
assert mirror_op is not None
self._test_cases(cases, a, op)
self._test_cases(mirror_cases, a, mirror_op)
@parametrize_relational_types()
def test_floordiv(self, op):
from sympy import Eq, Ne, Gt, Ge, Lt, Le
a, b, c = sympy.symbols("a b c")
pos = sympy.Symbol("pos", positive=True)
integer = sympy.Symbol("integer", integer=True)
# (Eq(FloorDiv(a, pos), integer), And(Ge(a, integer * pos), Lt(a, (integer + 1) * pos))),
# (Eq(FloorDiv(a + 5, pos), integer), And(Ge(a, integer * pos), Lt(a, (integer + 1) * pos))),
# (Ne(FloorDiv(a, pos), integer), Or(Lt(a, integer * pos), Ge(a, (integer + 1) * pos))),
special_case = {
# 'FloorDiv' turns into 'And', which can't be simplified any further.
Eq: (Eq(FloorDiv(a, pos), integer), None),
# 'FloorDiv' turns into 'Or', which can't be simplified any further.
Ne: (Ne(FloorDiv(a, pos), integer), None),
Gt: (Gt(FloorDiv(a, pos), integer), (integer + 1) * pos),
Ge: (Ge(FloorDiv(a, pos), integer), integer * pos),
Lt: (Lt(FloorDiv(a, pos), integer), integer * pos),
Le: (Le(FloorDiv(a, pos), integer), (integer + 1) * pos),
}[op]
cases: List[Tuple[sympy.Basic, sympy.Basic]] = [
# 'b' is not strictly positive
(op(FloorDiv(a, b), integer), None),
# 'c' is not strictly positive
(op(FloorDiv(a, pos), c), None),
]
# The result might change after 'FloorDiv' transformation.
# Specifically:
# - [Ge, Gt] => Ge
# - [Le, Lt] => Lt
if op in (sympy.Gt, sympy.Ge):
r_op = sympy.Ge
elif op in (sympy.Lt, sympy.Le):
r_op = sympy.Lt
else:
r_op = op
self._test_cases([special_case, *cases], a, r_op)
self._test_cases([(special_case[0], None), *cases], a, r_op, floordiv_inequality=False)
def test_floordiv_eq_simplify(self):
from sympy import Eq, Lt, Le
a = sympy.Symbol("a", positive=True, integer=True)
def check(expr, expected):
r = try_solve(expr, a)
self.assertNotEqual(r, None)
r_expr, _ = r
self.assertEqual(r_expr, expected)
# (a + 10) // 3 == 3
# =====================================
# 3 * 3 <= a + 10 (always true)
# a + 10 < 4 * 3 (not sure)
check(Eq(FloorDiv(a + 10, 3), 3), Lt(a, (3 + 1) * 3 - 10))
# (a + 10) // 2 == 4
# =====================================
# 4 * 2 <= 10 - a (not sure)
# 10 - a < 5 * 2 (always true)
check(Eq(FloorDiv(10 - a, 2), 4), Le(a, -(4 * 2 - 10)))
@skipIf(not TEST_Z3, "Z3 not installed")
def test_z3_proof_floordiv_eq_simplify(self):
import z3
from sympy import Eq, Lt
a = sympy.Symbol("a", positive=True, integer=True)
a_ = z3.Int("a")
# (a + 10) // 3 == 3
# =====================================
# 3 * 3 <= a + 10 (always true)
# a + 10 < 4 * 3 (not sure)
solver = z3.SolverFor("QF_NRA")
# Add assertions for 'a_'.
solver.add(a_ > 0)
expr = Eq(FloorDiv(a + 10, 3), 3)
r_expr, _ = try_solve(expr, a)
# Check 'try_solve' really returns the 'expected' below.
expected = Lt(a, (3 + 1) * 3 - 10)
self.assertEqual(r_expr, expected)
# Check whether there is an integer 'a_' such that the
# equation below is satisfied.
solver.add(
# expr
(z3.ToInt((a_ + 10) / 3.0) == 3)
!=
# expected
(a_ < (3 + 1) * 3 - 10)
)
# Assert that there's no such an integer.
# i.e. the transformation is sound.
r = solver.check()
self.assertEqual(r, z3.unsat)
instantiate_parametrized_tests(TestValueRanges)
instantiate_parametrized_tests(TestSympyInterp)
instantiate_parametrized_tests(TestSympySolve)
if __name__ == "__main__":