From 92c4771853892193d73d87bd60eca4dc7efc51d8 Mon Sep 17 00:00:00 2001
From: Avik Chaudhuri
Date: Mon, 26 Aug 2024 07:27:59 +0000
Subject: [PATCH] fix stuck floordiv (#134150)

Summary: Fixes https://github.com/pytorch/pytorch/issues/134133

Test Plan:
Tested on the small repro in the linked issue with different lengths N (replacing 100), recording N vs. time taken in nanoseconds:
10 127268319
20 220839662
30 325463125
40 429259441
50 553136055
60 670799769
70 999170514
80 899014103
90 997168902
100 1168202035
110 1388556619
120 1457488235
130 1609816470
140 2177889877
150 1917560313
160 2121096113
170 2428502334
180 4117450755
190 4003068224

So N ~ 200 takes ~5s. Previously even smaller N would go for >1 min. Didn't add a perf test because ezyang is planning to build a benchmark.

Also tested on https://www.internalfb.com/diff/D61560171, which now gets past the stuck point.

Differential Revision: D61619660

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134150
Approved by: https://github.com/ezyang
---
 .../benchmarks/sum_floordiv_benchmark.py | 39 +++++++++++++++++++
 torch/utils/_sympy/functions.py          | 10 +++--
 2 files changed, 46 insertions(+), 3 deletions(-)
 create mode 100644 benchmarks/dynamo/pr_time_benchmarks/benchmarks/sum_floordiv_benchmark.py

diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/sum_floordiv_benchmark.py b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/sum_floordiv_benchmark.py
new file mode 100644
index 000000000000..227c0bd911ad
--- /dev/null
+++ b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/sum_floordiv_benchmark.py
@@ -0,0 +1,39 @@
+import sys
+
+from benchmark_base import BenchmarkBase
+
+import torch
+
+
+class Benchmark(BenchmarkBase):
+    N = 100
+
+    def name(self):
+        return "sum_floordiv_regression"
+
+    def description(self):
+        return "information at https://github.com/pytorch/pytorch/issues/134133"
+
+    def prepare_once(self):
+        class M(torch.nn.Module):
+            def forward(self, x):
+                total = sum(t.item() for t in x)
+                return total // 2
+
+        self.m = M()
+        self.input = [torch.tensor(i + 2) for i in range(self.N)]
+
+    def prepare(self):
+        torch._dynamo.reset()
+
+    def work(self):
+        torch.export.export(self.m, (self.input,))
+
+
+def main():
+    result_path = sys.argv[1]
+    Benchmark().enable_instruction_count().collect_all().append_results(result_path)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py
index 0998b5767861..d54495047e27 100644
--- a/torch/utils/_sympy/functions.py
+++ b/torch/utils/_sympy/functions.py
@@ -186,7 +186,8 @@ class FloorDiv(sympy.Function):
 
         # Expands (x + y) // b into x // b + y // b.
         # This only works if floor is an identity, i.e. x / b is an integer.
-        for term in sympy.Add.make_args(base):
+        base_args = sympy.Add.make_args(base)
+        for term in base_args:
             quotient = term / divisor
             if quotient.is_integer and isinstance(divisor, sympy.Integer):
                 # NB: this is correct even if the divisor is not an integer, but it
@@ -195,8 +196,11 @@ class FloorDiv(sympy.Function):
                 return FloorDiv(base - term, divisor) + quotient
 
         try:
-            gcd = sympy.gcd(base, divisor)
-            if not equal_valued(gcd, 1):
+            # sympy.gcd tends to blow up on large sums, so use it on each summand instead
+            gcd, *gcds_ = (sympy.gcd(term, divisor) for term in base_args)
+            if not equal_valued(gcd, 1) and all(
+                equal_valued(gcd, gcd_) for gcd_ in gcds_
+            ):
                 return FloorDiv(
                     sympy.simplify(base / gcd), sympy.simplify(divisor / gcd)
                 )