mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
fix stuck floordiv (#134150)
Summary: Fixes https://github.com/pytorch/pytorch/issues/134133 Test Plan: Tested on the small repro in the linked issue with different lengths N (replacing 100), recording N vs. time taken in nanoseconds: 10 127268319 20 220839662 30 325463125 40 429259441 50 553136055 60 670799769 70 999170514 80 899014103 90 997168902 100 1168202035 110 1388556619 120 1457488235 130 1609816470 140 2177889877 150 1917560313 160 2121096113 170 2428502334 180 4117450755 190 4003068224 So N ~ 200 takes ~5s. Previously even smaller N would go for >1 min. Didn't add a perf test because ezyang is planning to build a benchmark. Also tested on https://www.internalfb.com/diff/D61560171, which now gets past the stuck point. Differential Revision: D61619660 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134150 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
c5f6b72041
commit
92c4771853
@ -0,0 +1,39 @@
|
||||
import sys
|
||||
|
||||
from benchmark_base import BenchmarkBase
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class Benchmark(BenchmarkBase):
|
||||
N = 100
|
||||
|
||||
def name(self):
|
||||
return "sum_floordiv_regression"
|
||||
|
||||
def description(self):
|
||||
return "information at https://github.com/pytorch/pytorch/issues/134133"
|
||||
|
||||
def prepare_once(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
total = sum(t.item() for t in x)
|
||||
return total // 2
|
||||
|
||||
self.m = M()
|
||||
self.input = [torch.tensor(i + 2) for i in range(self.N)]
|
||||
|
||||
def prepare(self):
|
||||
torch._dynamo.reset()
|
||||
|
||||
def work(self):
|
||||
torch.export.export(self.m, (self.input,))
|
||||
|
||||
|
||||
def main():
|
||||
result_path = sys.argv[1]
|
||||
Benchmark().enable_instruction_count().collect_all().append_results(result_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -186,7 +186,8 @@ class FloorDiv(sympy.Function):
|
||||
|
||||
# Expands (x + y) // b into x // b + y // b.
|
||||
# This only works if floor is an identity, i.e. x / b is an integer.
|
||||
for term in sympy.Add.make_args(base):
|
||||
base_args = sympy.Add.make_args(base)
|
||||
for term in base_args:
|
||||
quotient = term / divisor
|
||||
if quotient.is_integer and isinstance(divisor, sympy.Integer):
|
||||
# NB: this is correct even if the divisor is not an integer, but it
|
||||
@ -195,8 +196,11 @@ class FloorDiv(sympy.Function):
|
||||
return FloorDiv(base - term, divisor) + quotient
|
||||
|
||||
try:
|
||||
gcd = sympy.gcd(base, divisor)
|
||||
if not equal_valued(gcd, 1):
|
||||
# sympy.gcd tends to blow up on large sums, so use it on each summand instead
|
||||
gcd, *gcds_ = (sympy.gcd(term, divisor) for term in base_args)
|
||||
if not equal_valued(gcd, 1) and all(
|
||||
equal_valued(gcd, gcd_) for gcd_ in gcds_
|
||||
):
|
||||
return FloorDiv(
|
||||
sympy.simplify(base / gcd), sympy.simplify(divisor / gcd)
|
||||
)
|
||||
|
Reference in New Issue
Block a user