mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Fix FloorDiv should not generate non integer rationals (due to sympy bug) (#164398)
FloorDiv eval have this optimization ``` # Expands (x + y) // b into x // b + y // b. # This only works if floor is an identity, i.e. x / b is an integer. ``` Before this PR this optimization would generate a result in an expression like this. Duo to a bug in sympy. ``` Mul(Rational(1, 22), Add(Mul(Integer(24), Symbol('s37', integer=True, positive=True)), Integer(672)), FloorDiv(Mul(Symbol('s14', integer=True, positive=True), Symbol('s46', integer=True, positive=True)), Integer(2016))) ``` This is because in sympy an expression can have .is_integer =True yet have 1/22 in it! This PR ensure we do not generate that by simply opting out if this optimization if we end up with quotient that have such rational. Fix https://github.com/pytorch/pytorch/issues/164385, https://github.com/pytorch/pytorch/issues/154996 https://github.com/pytorch/pytorch/issues/153375 https://github.com/pytorch/pytorch/issues/164063 and internal user issue. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164398 Approved by: https://github.com/jansel, https://github.com/isuruf
This commit is contained in:
committed by
PyTorch MergeBot
parent
22e219d996
commit
15c8bdcc5e
@ -254,6 +254,46 @@ class TestIndexingSimplification(InductorTestCase):
|
||||
ms = benchmarker.benchmark_gpu(lambda: f(x))
|
||||
print(f"{ms=:.03f}")
|
||||
|
||||
@unittest.skipUnless(HAS_GPU, "Need GPU for this test")
|
||||
def test_floordiv_div_sympy_is_integer_bug(self):
|
||||
def foo(arg0, arg1, arg2, arg3, arg4, sentinel):
|
||||
t0 = arg0
|
||||
t1 = t0.reshape((28, 24, 3, 127))
|
||||
t2 = t1.var(dim=2)
|
||||
t3 = arg1
|
||||
t4 = arg2
|
||||
t5 = torch.nn.functional.embedding(
|
||||
torch.clamp(t3, 0, t4.size(0) - 1).to(torch.long), t4
|
||||
)
|
||||
t6 = arg3
|
||||
t7 = torch.nn.functional.pad(t6, [0, 1], mode="constant", value=0.0)
|
||||
t8 = arg4
|
||||
t9 = t8.sum(dim=1)
|
||||
t10 = torch.baddbmm(t5, t7, t9)
|
||||
t11 = torch.cat([t2, t10], dim=0)
|
||||
output = t11 + sentinel
|
||||
return output
|
||||
|
||||
arg0 = torch.rand(
|
||||
[36, 7112, 1, 1], dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True
|
||||
)
|
||||
arg1 = torch.randint(0, 512, [30, 24], dtype=torch.int64, device=GPU_TYPE)
|
||||
arg2 = torch.rand(
|
||||
[512, 127], dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True
|
||||
)
|
||||
arg3 = torch.rand(
|
||||
[30, 24, 15], dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True
|
||||
)
|
||||
arg4 = torch.rand(
|
||||
[30, 4, 16, 127], dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True
|
||||
)
|
||||
sentinel = torch.tensor(
|
||||
0.0, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True
|
||||
)
|
||||
compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
|
||||
out_compiled = compiled_foo(arg0, arg1, arg2, arg3, arg4, sentinel)
|
||||
out_compiled.sum().backward()
|
||||
|
||||
|
||||
class ExprPrinterTests(InductorTestCase):
|
||||
def test_print_pow(self):
|
||||
|
@ -1949,6 +1949,19 @@ class TestFloorDiv(TestCase):
|
||||
TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y)
|
||||
)
|
||||
|
||||
def test_floordiv_div_does_not_generate_non_int_rational(self):
|
||||
s14 = sympy.Symbol("s14", integer=True, positive=True)
|
||||
s37 = sympy.Symbol("s37", integer=True, positive=True)
|
||||
|
||||
inner_expr = FloorDiv(s14, 2016)
|
||||
middle_expr = (24 * s37 + 672) * inner_expr
|
||||
numerator = middle_expr + 21
|
||||
denominator = 22
|
||||
result = FloorDiv(numerator, denominator)
|
||||
rationals = result.atoms(sympy.Rational)
|
||||
all_rationals_ints = all(r.q == 1 for r in rationals)
|
||||
self.assertTrue(all_rationals_ints)
|
||||
|
||||
def test_floordiv_simplify(self):
|
||||
# Tests how we simplify or evaluate FloorDiv without free variables
|
||||
shape_env = ShapeEnv()
|
||||
|
@ -20,6 +20,8 @@ from sympy.core.traversal import walk
|
||||
from sympy.printing.precedence import PRECEDENCE
|
||||
from sympy.utilities.iterables import sift
|
||||
|
||||
from torch.torch_version import TorchVersion
|
||||
|
||||
from .numbers import int_oo
|
||||
|
||||
|
||||
@ -268,7 +270,20 @@ class FloorDiv(sympy.Function):
|
||||
for term in sympy.Add.make_args(base):
|
||||
quotient = term / divisor
|
||||
|
||||
if quotient.is_integer:
|
||||
# This is a sympy bug fixed in https://github.com/sympy/sympy/pull/28442
|
||||
# sympy can generate a quotient with (1/22)*.... such that quotient.is_integer is True
|
||||
# FloorDiv should not allow that as output. see
|
||||
quotient_is_integer = None
|
||||
if isinstance(quotient, sympy.Mul) and TorchVersion(
|
||||
sympy.__version__
|
||||
) < TorchVersion("1.15.0"):
|
||||
rationals = quotient.atoms(sympy.Rational)
|
||||
all_rationals_ints = all(r.q == 1 for r in rationals)
|
||||
quotient_is_integer = quotient.is_integer and all_rationals_ints
|
||||
else:
|
||||
quotient_is_integer = quotient.is_integer
|
||||
|
||||
if quotient_is_integer:
|
||||
terms.append(term)
|
||||
quotients += quotient
|
||||
|
||||
@ -308,7 +323,6 @@ class ModularIndexing(sympy.Function):
|
||||
) -> Optional[sympy.Basic]:
|
||||
if base == 0 or modulus == 1:
|
||||
return sympy.S.Zero
|
||||
|
||||
if (
|
||||
isinstance(base, sympy.Integer)
|
||||
and isinstance(divisor, sympy.Integer)
|
||||
|
Reference in New Issue
Block a user