Fix FloorDiv should not generate non integer rationals (due to sympy bug) (#164398)

FloorDiv eval have this optimization
```
  # Expands (x + y) // b into x // b + y // b.
  # This only works if floor is an identity, i.e. x / b is an integer.
 ```

 Before this PR this optimization would generate a result in an expression like this. Duo to a bug in sympy.
 ```
Mul(Rational(1, 22), Add(Mul(Integer(24), Symbol('s37', integer=True, positive=True)), Integer(672)), FloorDiv(Mul(Symbol('s14', integer=True, positive=True), Symbol('s46', integer=True, positive=True)), Integer(2016)))
 ```

 This is because in sympy an expression can have .is_integer =True yet have 1/22 in it!
 This PR ensure we do not generate that by simply opting out if this optimization if we end
 up with quotient that have such rational.

  Fix
  https://github.com/pytorch/pytorch/issues/164385,
  https://github.com/pytorch/pytorch/issues/154996
  https://github.com/pytorch/pytorch/issues/153375
  https://github.com/pytorch/pytorch/issues/164063
and internal user issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164398
Approved by: https://github.com/jansel, https://github.com/isuruf
This commit is contained in:
Laith Sakka
2025-10-02 08:21:57 -07:00
committed by PyTorch MergeBot
parent 22e219d996
commit 15c8bdcc5e
3 changed files with 69 additions and 2 deletions

View File

@ -254,6 +254,46 @@ class TestIndexingSimplification(InductorTestCase):
ms = benchmarker.benchmark_gpu(lambda: f(x))
print(f"{ms=:.03f}")
@unittest.skipUnless(HAS_GPU, "Need GPU for this test")
def test_floordiv_div_sympy_is_integer_bug(self):
def foo(arg0, arg1, arg2, arg3, arg4, sentinel):
t0 = arg0
t1 = t0.reshape((28, 24, 3, 127))
t2 = t1.var(dim=2)
t3 = arg1
t4 = arg2
t5 = torch.nn.functional.embedding(
torch.clamp(t3, 0, t4.size(0) - 1).to(torch.long), t4
)
t6 = arg3
t7 = torch.nn.functional.pad(t6, [0, 1], mode="constant", value=0.0)
t8 = arg4
t9 = t8.sum(dim=1)
t10 = torch.baddbmm(t5, t7, t9)
t11 = torch.cat([t2, t10], dim=0)
output = t11 + sentinel
return output
arg0 = torch.rand(
[36, 7112, 1, 1], dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True
)
arg1 = torch.randint(0, 512, [30, 24], dtype=torch.int64, device=GPU_TYPE)
arg2 = torch.rand(
[512, 127], dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True
)
arg3 = torch.rand(
[30, 24, 15], dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True
)
arg4 = torch.rand(
[30, 4, 16, 127], dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True
)
sentinel = torch.tensor(
0.0, dtype=torch.bfloat16, device=GPU_TYPE, requires_grad=True
)
compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
out_compiled = compiled_foo(arg0, arg1, arg2, arg3, arg4, sentinel)
out_compiled.sum().backward()
class ExprPrinterTests(InductorTestCase):
def test_print_pow(self):

View File

@ -1949,6 +1949,19 @@ class TestFloorDiv(TestCase):
TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y)
)
def test_floordiv_div_does_not_generate_non_int_rational(self):
s14 = sympy.Symbol("s14", integer=True, positive=True)
s37 = sympy.Symbol("s37", integer=True, positive=True)
inner_expr = FloorDiv(s14, 2016)
middle_expr = (24 * s37 + 672) * inner_expr
numerator = middle_expr + 21
denominator = 22
result = FloorDiv(numerator, denominator)
rationals = result.atoms(sympy.Rational)
all_rationals_ints = all(r.q == 1 for r in rationals)
self.assertTrue(all_rationals_ints)
def test_floordiv_simplify(self):
# Tests how we simplify or evaluate FloorDiv without free variables
shape_env = ShapeEnv()

View File

@ -20,6 +20,8 @@ from sympy.core.traversal import walk
from sympy.printing.precedence import PRECEDENCE
from sympy.utilities.iterables import sift
from torch.torch_version import TorchVersion
from .numbers import int_oo
@ -268,7 +270,20 @@ class FloorDiv(sympy.Function):
for term in sympy.Add.make_args(base):
quotient = term / divisor
if quotient.is_integer:
# This is a sympy bug fixed in https://github.com/sympy/sympy/pull/28442
# sympy can generate a quotient with (1/22)*.... such that quotient.is_integer is True
# FloorDiv should not allow that as output. see
quotient_is_integer = None
if isinstance(quotient, sympy.Mul) and TorchVersion(
sympy.__version__
) < TorchVersion("1.15.0"):
rationals = quotient.atoms(sympy.Rational)
all_rationals_ints = all(r.q == 1 for r in rationals)
quotient_is_integer = quotient.is_integer and all_rationals_ints
else:
quotient_is_integer = quotient.is_integer
if quotient_is_integer:
terms.append(term)
quotients += quotient
@ -308,7 +323,6 @@ class ModularIndexing(sympy.Function):
) -> Optional[sympy.Basic]:
if base == 0 or modulus == 1:
return sympy.S.Zero
if (
isinstance(base, sympy.Integer)
and isinstance(divisor, sympy.Integer)