fast path for sympy gcd in floordiv (#134880)

Summary:
Re-implementation of https://github.com/pytorch/pytorch/pull/134150, which was reverted because of some internal tests hanging (case B). The original motivation was to get some other internal test unstuck (case A).

The root cause is that sympy.gcd is both very clever as well as can blow up in some cases. This PR introduces a fast path with an appropriate fallback to sympy.gcd that ensures that both cases A and B go through.

Test Plan:
See the included test for specific examples.
Also https://fb.workplace.com/groups/1075192433118967/posts/1491493248155548/?comment_id=1491938994777640&reply_comment_id=1492622821375924

Differential Revision: D62043315

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134880
Approved by: https://github.com/ezyang
This commit is contained in:
Avik Chaudhuri
2024-09-04 14:56:49 +00:00
committed by PyTorch MergeBot
parent 67208f08bd
commit 8bfd4916d6
2 changed files with 62 additions and 2 deletions

View File

@ -14,7 +14,7 @@ from torch.testing._internal.common_utils import (
run_tests,
TestCase,
)
from torch.utils._sympy.functions import FloorDiv
from torch.utils._sympy.functions import FloorDiv, simple_floordiv_gcd
from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve
from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
from torch.utils._sympy.reference import ReferenceAnalysis, PythonReferenceAnalysis
@ -667,6 +667,24 @@ class TestSympySolve(TestCase):
r = solver.check()
self.assertEqual(r, z3.unsat)
def test_simple_floordiv_gcd(self):
x, y, z = sympy.symbols("x y z")
# positive tests
self.assertEqual(simple_floordiv_gcd(x, x), x)
self.assertEqual(simple_floordiv_gcd(128 * x, 2304), 128)
self.assertEqual(simple_floordiv_gcd(128 * x + 128 * y, 2304), 128)
self.assertEqual(simple_floordiv_gcd(128 * x + 128 * y + 8192 * z, 9216), 128)
self.assertEqual(simple_floordiv_gcd(49152 * x, 96 * x), 96 * x)
self.assertEqual(simple_floordiv_gcd(96 * x, 96 * x), 96 * x)
self.assertEqual(simple_floordiv_gcd(x * y, x), x)
self.assertEqual(simple_floordiv_gcd(384 * x * y, x * y), x * y)
self.assertEqual(simple_floordiv_gcd(256 * x * y, 8 * x), 8 * x)
# negative tests
self.assertEqual(simple_floordiv_gcd(x * y + x + y + 1, x + 1), 1)
class TestSingletonInt(TestCase):
def test_basic(self):
j1 = SingletonInt(1, coeff=1)