Use integer divison in arange length calculation when start/end/step are integral (#134296)

Fixes #133338

Test Plan:

```
TORCH_LOGS=dynamic python
import torch

torch._dynamo.config.capture_scalar_outputs = True

@torch.compile()
def f(x):
    y = x.item()
    torch._check_is_size(y)
    r = torch.arange(y, dtype=torch.float32)
    torch._check(r.size(0) == y)
    return r

f(torch.tensor([300]))
```

Before and after diff. Verify the following line

```
I0813 11:05:44.890000 652898 torch/fx/experimental/symbolic_shapes.py:5198] [0/0] runtime_assert Eq(CeilToInt(IntTrueDiv(u0, 1)), u0) [guard added] at aa.py:10 in f (_dynamo/utils.py:2092 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(CeilToInt(IntTrueDiv(u0, 1)), u0)"
```

no longer shows in the logs. Also verify CI passes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134296
Approved by: https://github.com/aorenste
This commit is contained in:
Bob Ren
2024-08-22 17:12:12 -07:00
committed by PyTorch MergeBot
parent 1a0d00f1f4
commit 94f92fbd88
2 changed files with 3 additions and 3 deletions

View File

@ -964,7 +964,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
)
# (dynamic shapes, static shapes)
self.assertIn(cnt.frame_count, (5, 7))
self.assertIn(cnt.op_count, (94, 106, 121))
self.assertIn(cnt.op_count, (92, 106, 119))
def test_convert_boxes_to_pooler_format(self):
boxes1 = [

View File

@ -4989,14 +4989,14 @@ def arange(
dtype = torch.int64 if integer_args else torch.get_default_dtype()
is_integer = utils.is_integer_dtype(dtype)
if is_integer:
if is_integer or integer_args:
xstart = sym_int(start)
xend = sym_int(end)
xstep = sym_int(step)
# For int64 we truncate arguments to int before calculating length, but
# other integral dtypes we don't. Weird... but needed to match ATen shapes.
if dtype == torch.int64:
if dtype == torch.int64 or integer_args:
# Uses floordiv to avoid ceil in inductor.
sgn = bool(xstep > 0) - bool(xstep < 0) # type: ignore[possibly-undefined]
length = (xend - xstart + xstep - sgn) // xstep # type: ignore[possibly-undefined]