mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
fast path for sympy gcd in floordiv (#134880)
Summary: Re-implementation of https://github.com/pytorch/pytorch/pull/134150, which was reverted because of some internal tests hanging (case B). The original motivation was to get some other internal test unstuck (case A). The root cause is that sympy.gcd is both very clever as well as can blow up in some cases. This PR introduces a fast path with an appropriate fallback to sympy.gcd that ensures that both cases A and B go through. Test Plan: See the included test for specific examples. Also https://fb.workplace.com/groups/1075192433118967/posts/1491493248155548/?comment_id=1491938994777640&reply_comment_id=1492622821375924 Differential Revision: D62043315 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134880 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
67208f08bd
commit
8bfd4916d6
@ -14,7 +14,7 @@ from torch.testing._internal.common_utils import (
|
||||
run_tests,
|
||||
TestCase,
|
||||
)
|
||||
from torch.utils._sympy.functions import FloorDiv
|
||||
from torch.utils._sympy.functions import FloorDiv, simple_floordiv_gcd
|
||||
from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve
|
||||
from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
|
||||
from torch.utils._sympy.reference import ReferenceAnalysis, PythonReferenceAnalysis
|
||||
@ -667,6 +667,24 @@ class TestSympySolve(TestCase):
|
||||
r = solver.check()
|
||||
self.assertEqual(r, z3.unsat)
|
||||
|
||||
def test_simple_floordiv_gcd(self):
|
||||
x, y, z = sympy.symbols("x y z")
|
||||
|
||||
# positive tests
|
||||
self.assertEqual(simple_floordiv_gcd(x, x), x)
|
||||
self.assertEqual(simple_floordiv_gcd(128 * x, 2304), 128)
|
||||
self.assertEqual(simple_floordiv_gcd(128 * x + 128 * y, 2304), 128)
|
||||
self.assertEqual(simple_floordiv_gcd(128 * x + 128 * y + 8192 * z, 9216), 128)
|
||||
self.assertEqual(simple_floordiv_gcd(49152 * x, 96 * x), 96 * x)
|
||||
self.assertEqual(simple_floordiv_gcd(96 * x, 96 * x), 96 * x)
|
||||
self.assertEqual(simple_floordiv_gcd(x * y, x), x)
|
||||
self.assertEqual(simple_floordiv_gcd(384 * x * y, x * y), x * y)
|
||||
self.assertEqual(simple_floordiv_gcd(256 * x * y, 8 * x), 8 * x)
|
||||
|
||||
# negative tests
|
||||
self.assertEqual(simple_floordiv_gcd(x * y + x + y + 1, x + 1), 1)
|
||||
|
||||
|
||||
class TestSingletonInt(TestCase):
|
||||
def test_basic(self):
|
||||
j1 = SingletonInt(1, coeff=1)
|
||||
|
@ -95,6 +95,46 @@ def fuzzy_eq(x, y):
|
||||
return x == y
|
||||
|
||||
|
||||
def simple_floordiv_gcd(p, q):
|
||||
"""
|
||||
Fast path for sympy.gcd, using a simple factoring strategy.
|
||||
|
||||
We try to rewrite p and q in the form n*e*p1 + n*e*p2 and n*e*q0,
|
||||
where n is the greatest common integer factor and e is the largest
|
||||
syntactic common factor (i.e., common sub-expression) in p and q.
|
||||
Then the gcd returned is n*e, cancelling which we would be left with
|
||||
p1 + p2 and q0.
|
||||
|
||||
Note that further factoring of p1 + p2 and q0 might be possible with
|
||||
sympy.factor (which uses domain-specific theories). E.g., we are unable
|
||||
to find that x*y + x + y + 1 is divisible by x + 1. More generally,
|
||||
when q is of the form q1 + q2 (instead of being already factored) it
|
||||
might be necessary to fall back on sympy.gcd.
|
||||
"""
|
||||
|
||||
def integer_coefficient(x):
|
||||
integer_coefficients = [
|
||||
abs(int(arg))
|
||||
for arg in sympy.Mul.make_args(x)
|
||||
if isinstance(arg, (int, sympy.Integer))
|
||||
]
|
||||
return math.prod(integer_coefficients)
|
||||
|
||||
def integer_factor(expr):
|
||||
integer_factors = map(integer_coefficient, sympy.Add.make_args(expr))
|
||||
return functools.reduce(math.gcd, integer_factors)
|
||||
|
||||
gcd = math.gcd(integer_factor(p), integer_factor(q))
|
||||
p, q = p / gcd, q / gcd
|
||||
|
||||
base_splits = list(map(sympy.Mul.make_args, sympy.Add.make_args(p)))
|
||||
divisor_split = sympy.Mul.make_args(q)
|
||||
for x in divisor_split:
|
||||
if all(x in base_split for base_split in base_splits):
|
||||
gcd = gcd * x
|
||||
return gcd
|
||||
|
||||
|
||||
# It would be nice to have assertions on whether or not inputs is_integer
|
||||
# However, with bugs like https://github.com/sympy/sympy/issues/26620 sympy
|
||||
# sometimes inconsistently reports floats an integers.
|
||||
@ -202,7 +242,9 @@ class FloorDiv(sympy.Function):
|
||||
return FloorDiv(base - term, divisor) + quotient
|
||||
|
||||
try:
|
||||
gcd = sympy.gcd(base, divisor)
|
||||
gcd = simple_floordiv_gcd(base, divisor)
|
||||
if equal_valued(gcd, 1) and isinstance(divisor, sympy.Add):
|
||||
gcd = sympy.gcd(base, divisor)
|
||||
if not equal_valued(gcd, 1):
|
||||
return FloorDiv(
|
||||
sympy.simplify(base / gcd), sympy.simplify(divisor / gcd)
|
||||
|
Reference in New Issue
Block a user