mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
fast path for sympy gcd in floordiv (#134880)
Summary: Re-implementation of https://github.com/pytorch/pytorch/pull/134150, which was reverted because of some internal tests hanging (case B). The original motivation was to get some other internal test unstuck (case A). The root cause is that sympy.gcd is both very clever as well as can blow up in some cases. This PR introduces a fast path with an appropriate fallback to sympy.gcd that ensures that both cases A and B go through. Test Plan: See the included test for specific examples. Also https://fb.workplace.com/groups/1075192433118967/posts/1491493248155548/?comment_id=1491938994777640&reply_comment_id=1492622821375924 Differential Revision: D62043315 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134880 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
67208f08bd
commit
8bfd4916d6
@ -14,7 +14,7 @@ from torch.testing._internal.common_utils import (
|
||||
run_tests,
|
||||
TestCase,
|
||||
)
|
||||
from torch.utils._sympy.functions import FloorDiv
|
||||
from torch.utils._sympy.functions import FloorDiv, simple_floordiv_gcd
|
||||
from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve
|
||||
from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
|
||||
from torch.utils._sympy.reference import ReferenceAnalysis, PythonReferenceAnalysis
|
||||
@ -667,6 +667,24 @@ class TestSympySolve(TestCase):
|
||||
r = solver.check()
|
||||
self.assertEqual(r, z3.unsat)
|
||||
|
||||
def test_simple_floordiv_gcd(self):
|
||||
x, y, z = sympy.symbols("x y z")
|
||||
|
||||
# positive tests
|
||||
self.assertEqual(simple_floordiv_gcd(x, x), x)
|
||||
self.assertEqual(simple_floordiv_gcd(128 * x, 2304), 128)
|
||||
self.assertEqual(simple_floordiv_gcd(128 * x + 128 * y, 2304), 128)
|
||||
self.assertEqual(simple_floordiv_gcd(128 * x + 128 * y + 8192 * z, 9216), 128)
|
||||
self.assertEqual(simple_floordiv_gcd(49152 * x, 96 * x), 96 * x)
|
||||
self.assertEqual(simple_floordiv_gcd(96 * x, 96 * x), 96 * x)
|
||||
self.assertEqual(simple_floordiv_gcd(x * y, x), x)
|
||||
self.assertEqual(simple_floordiv_gcd(384 * x * y, x * y), x * y)
|
||||
self.assertEqual(simple_floordiv_gcd(256 * x * y, 8 * x), 8 * x)
|
||||
|
||||
# negative tests
|
||||
self.assertEqual(simple_floordiv_gcd(x * y + x + y + 1, x + 1), 1)
|
||||
|
||||
|
||||
class TestSingletonInt(TestCase):
|
||||
def test_basic(self):
|
||||
j1 = SingletonInt(1, coeff=1)
|
||||
|
Reference in New Issue
Block a user