fix stuck floordiv (#134150)

Summary: Fixes https://github.com/pytorch/pytorch/issues/134133

Test Plan:
Tested on the small repro in the linked issue with different lengths N (substituted for the 100 used in the repro), recording N vs. time taken in nanoseconds:
10 127268319
20 220839662
30 325463125
40 429259441
50 553136055
60 670799769
70 999170514
80 899014103
90 997168902
100 1168202035
110 1388556619
120 1457488235
130 1609816470
140 2177889877
150 1917560313
160 2121096113
170 2428502334
180 4117450755
190 4003068224

So N ~ 200 takes ~5s. Previously, even smaller values of N would run for more than 1 minute.

Didn't add a perf test because ezyang is planning to build a benchmark.

Also tested on https://www.internalfb.com/diff/D61560171, which now gets past the stuck point.

Differential Revision: D61619660

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134150
Approved by: https://github.com/ezyang
This commit is contained in:
Avik Chaudhuri
2024-08-26 07:27:59 +00:00
committed by PyTorch MergeBot
parent c5f6b72041
commit 92c4771853
2 changed files with 46 additions and 3 deletions

View File

@ -0,0 +1,39 @@
import sys
from benchmark_base import BenchmarkBase
import torch
class Benchmark(BenchmarkBase):
    """Benchmark that exports a module summing N scalar tensors and floor-dividing.

    Background is at the URL returned by ``description()``.
    NOTE(review): the hook methods below (``prepare_once``, ``prepare``,
    ``work``) are presumably invoked by ``BenchmarkBase``; confirm their
    exact call order against that class.
    """

    # Number of scalar tensors fed to the module, i.e. the number of summands.
    N = 100

    def name(self):
        """Return the identifier string for this benchmark."""
        return "sum_floordiv_regression"

    def description(self):
        """Return a pointer to the issue this benchmark relates to."""
        return "information at https://github.com/pytorch/pytorch/issues/134133"

    def prepare_once(self):
        """Build the module under test and its input list."""

        class M(torch.nn.Module):
            def forward(self, x):
                # Sum the Python scalars extracted from every tensor, then
                # floor-divide the total by 2.
                total = sum(t.item() for t in x)
                return total // 2

        self.m = M()
        # N zero-dimensional tensors holding the values 2, 3, ..., N + 1.
        self.input = [torch.tensor(i + 2) for i in range(self.N)]

    def prepare(self):
        """Reset torch._dynamo state."""
        torch._dynamo.reset()

    def work(self):
        """Export the module with the prepared input list as its sole argument."""
        torch.export.export(self.m, (self.input,))
def main():
    """Run the benchmark, handing the results path from the command line on.

    The first command-line argument is passed to ``append_results`` after
    instruction counting is enabled and all measurements are collected.
    """
    if len(sys.argv) < 2:
        # Exit with a usage hint instead of a bare IndexError traceback.
        sys.exit(f"usage: {sys.argv[0]} RESULT_PATH")
    result_path = sys.argv[1]
    Benchmark().enable_instruction_count().collect_all().append_results(result_path)


if __name__ == "__main__":
    main()

View File

@ -186,7 +186,8 @@ class FloorDiv(sympy.Function):
# Expands (x + y) // b into x // b + y // b.
# This only works if floor is an identity, i.e. x / b is an integer.
for term in sympy.Add.make_args(base):
base_args = sympy.Add.make_args(base)
for term in base_args:
quotient = term / divisor
if quotient.is_integer and isinstance(divisor, sympy.Integer):
# NB: this is correct even if the divisor is not an integer, but it
@ -195,8 +196,11 @@ class FloorDiv(sympy.Function):
return FloorDiv(base - term, divisor) + quotient
try:
gcd = sympy.gcd(base, divisor)
if not equal_valued(gcd, 1):
# sympy.gcd tends to blow up on large sums, so use it on each summand instead
gcd, *gcds_ = (sympy.gcd(term, divisor) for term in base_args)
if not equal_valued(gcd, 1) and all(
equal_valued(gcd, gcd_) for gcd_ in gcds_
):
return FloorDiv(
sympy.simplify(base / gcd), sympy.simplify(divisor / gcd)
)