fast path for sympy gcd in floordiv (#134880)

Summary:
Re-implementation of https://github.com/pytorch/pytorch/pull/134150, which was reverted because of some internal tests hanging (case B). The original motivation was to get some other internal test unstuck (case A).

The root cause is that sympy.gcd is very clever but can also blow up in some cases. This PR introduces a fast path, with an appropriate fallback to sympy.gcd, that ensures both cases A and B go through.

Test Plan:
See the included test for specific examples.
Also https://fb.workplace.com/groups/1075192433118967/posts/1491493248155548/?comment_id=1491938994777640&reply_comment_id=1492622821375924

Differential Revision: D62043315

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134880
Approved by: https://github.com/ezyang
This commit is contained in:
Avik Chaudhuri
2024-09-04 14:56:49 +00:00
committed by PyTorch MergeBot
parent 67208f08bd
commit 8bfd4916d6
2 changed files with 62 additions and 2 deletions

View File

@ -14,7 +14,7 @@ from torch.testing._internal.common_utils import (
run_tests,
TestCase,
)
from torch.utils._sympy.functions import FloorDiv
from torch.utils._sympy.functions import FloorDiv, simple_floordiv_gcd
from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve
from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
from torch.utils._sympy.reference import ReferenceAnalysis, PythonReferenceAnalysis
@ -667,6 +667,24 @@ class TestSympySolve(TestCase):
r = solver.check()
self.assertEqual(r, z3.unsat)
def test_simple_floordiv_gcd(self):
    """Check simple_floordiv_gcd against known (base, divisor, gcd) triples."""
    x, y, z = sympy.symbols("x y z")
    cases = [
        # positive tests: a shared integer and/or syntactic factor exists
        (x, x, x),
        (128 * x, 2304, 128),
        (128 * x + 128 * y, 2304, 128),
        (128 * x + 128 * y + 8192 * z, 9216, 128),
        (49152 * x, 96 * x, 96 * x),
        (96 * x, 96 * x, 96 * x),
        (x * y, x, x),
        (384 * x * y, x * y, x * y),
        (256 * x * y, 8 * x, 8 * x),
        # negative tests: the common factor is not syntactically visible
        (x * y + x + y + 1, x + 1, 1),
    ]
    for base, divisor, expected in cases:
        self.assertEqual(simple_floordiv_gcd(base, divisor), expected)
class TestSingletonInt(TestCase):
def test_basic(self):
j1 = SingletonInt(1, coeff=1)

View File

@ -95,6 +95,46 @@ def fuzzy_eq(x, y):
return x == y
def simple_floordiv_gcd(p, q):
    """
    Cheap, purely syntactic stand-in for sympy.gcd.

    The result is assembled in two stages:

    1. n, the gcd of the absolute integer coefficients of every additive
       term of p and of q.
    2. e, the product of those multiplicative factors of q / n that are
       also found among the multiplicative factors of every additive
       term of p / n.

    The value returned is n * e. Put differently, p is viewed as
    n*e*p1 + n*e*p2 + ... and q as n*e*q0, so that cancelling n*e leaves
    p1 + p2 + ... and q0.

    No algebraic factoring is attempted, so common factors that only
    sympy.factor would uncover are missed: x*y + x + y + 1 is not
    recognised as a multiple of x + 1, and the result for that pair is 1.
    When q is a sum q1 + q2 rather than an already factored product, a
    fallback to sympy.gcd might therefore be necessary.
    """

    def abs_int_coefficient(term):
        # Product of the magnitudes of the integer factors of one term;
        # a term without any integer factor contributes 1.
        product = 1
        for factor in sympy.Mul.make_args(term):
            if isinstance(factor, (int, sympy.Integer)):
                product *= abs(int(factor))
        return product

    def int_content(expr):
        # gcd of the integer coefficients over all additive terms of expr.
        terms = sympy.Add.make_args(expr)
        return functools.reduce(
            math.gcd, (abs_int_coefficient(term) for term in terms)
        )

    # Stage 1: strip the shared integer factor from both operands.
    common = math.gcd(int_content(p), int_content(q))
    p = p / common
    q = q / common

    # Stage 2: keep each factor of the divisor that every term of the
    # base also lists among its own factors.
    term_factors = [
        sympy.Mul.make_args(term) for term in sympy.Add.make_args(p)
    ]
    for factor in sympy.Mul.make_args(q):
        if any(factor not in factors for factors in term_factors):
            continue
        common = common * factor
    return common
# It would be nice to have assertions on whether or not inputs is_integer
# However, with bugs like https://github.com/sympy/sympy/issues/26620 sympy
# sometimes inconsistently reports floats as integers.
@ -202,7 +242,9 @@ class FloorDiv(sympy.Function):
return FloorDiv(base - term, divisor) + quotient
try:
gcd = sympy.gcd(base, divisor)
gcd = simple_floordiv_gcd(base, divisor)
if equal_valued(gcd, 1) and isinstance(divisor, sympy.Add):
gcd = sympy.gcd(base, divisor)
if not equal_valued(gcd, 1):
return FloorDiv(
sympy.simplify(base / gcd), sympy.simplify(divisor / gcd)