Use integer division in arange length calculation when start/end/step are integral (#134296)
Fixes #133338

Test Plan: run the following script under `TORCH_LOGS=dynamic`:

```python
import torch

torch._dynamo.config.capture_scalar_outputs = True

@torch.compile()
def f(x):
    y = x.item()
    torch._check_is_size(y)
    r = torch.arange(y, dtype=torch.float32)
    torch._check(r.size(0) == y)
    return r

f(torch.tensor([300]))
```

Run it before and after this diff, and verify that the following line no longer shows up in the logs:

```
I0813 11:05:44.890000 652898 torch/fx/experimental/symbolic_shapes.py:5198] [0/0] runtime_assert Eq(CeilToInt(IntTrueDiv(u0, 1)), u0) [guard added] at aa.py:10 in f (_dynamo/utils.py:2092 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(CeilToInt(IntTrueDiv(u0, 1)), u0)"
```

Also verify that CI passes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134296
Approved by: https://github.com/aorenste
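Why this helps: with a floating-point output dtype but integral start/end/step (as in the repro above, where `dtype=torch.float32`), the length was previously computed as a ceil over true division, which leaves a `CeilToInt(IntTrueDiv(...))` expression in the symbolic shape and forces the runtime assert shown above. For integral arguments the same length is expressible in pure integer arithmetic. A minimal sketch of the equivalence (plain Python; `length_ceil` and `length_floordiv` are hypothetical names used only for illustration):

```python
import math

def length_ceil(start, end, step):
    # Old-style length: ceil over true division. Symbolically this becomes
    # CeilToInt(IntTrueDiv(...)), which cannot be simplified away.
    return max(0, math.ceil((end - start) / step))

def length_floordiv(start, end, step):
    # New-style length: pure integer arithmetic, no ceil needed. sgn is +1
    # for a positive step and -1 for a negative one, as in the diff below.
    sgn = (step > 0) - (step < 0)
    return max(0, (end - start + step - sgn) // step)

# The two formulas agree for all integral inputs, in both step directions.
for start in range(-5, 6):
    for end in range(-5, 6):
        for step in (-3, -2, -1, 1, 2, 3):
            assert length_ceil(start, end, step) == length_floordiv(start, end, step)
```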
Committed by: PyTorch MergeBot
Parent: 1a0d00f1f4
Commit: 94f92fbd88
```diff
@@ -964,7 +964,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
         )
         # (dynamic shapes, static shapes)
         self.assertIn(cnt.frame_count, (5, 7))
-        self.assertIn(cnt.op_count, (94, 106, 121))
+        self.assertIn(cnt.op_count, (92, 106, 119))

     def test_convert_boxes_to_pooler_format(self):
         boxes1 = [
```
```diff
@@ -4989,14 +4989,14 @@ def arange(
     dtype = torch.int64 if integer_args else torch.get_default_dtype()

     is_integer = utils.is_integer_dtype(dtype)
-    if is_integer:
+    if is_integer or integer_args:
         xstart = sym_int(start)
         xend = sym_int(end)
         xstep = sym_int(step)

     # For int64 we truncate arguments to int before calculating length, but
     # other integral dtypes we don't. Weird... but needed to match ATen shapes.
-    if dtype == torch.int64:
+    if dtype == torch.int64 or integer_args:
         # Uses floordiv to avoid ceil in inductor.
         sgn = bool(xstep > 0) - bool(xstep < 0)  # type: ignore[possibly-undefined]
         length = (xend - xstart + xstep - sgn) // xstep  # type: ignore[possibly-undefined]
```
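As a quick eager-mode sanity check (outside `torch.compile`), the floordiv length can be compared against the sizes `torch.arange` actually produces for a float output dtype with integral arguments; `expected_length` is a hypothetical helper mirroring the formula in the diff:

```python
import torch

def expected_length(start, end, step):
    # Mirrors the meta formula above: (xend - xstart + xstep - sgn) // xstep,
    # clamped at zero for empty ranges.
    sgn = (step > 0) - (step < 0)
    return max(0, (end - start + step - sgn) // step)

# Eager ATen already produces these lengths; the diff makes the traced
# symbolic length match them without introducing a ceil.
for start, end, step in [(0, 300, 1), (5, 0, -2), (-7, 8, 3), (3, 3, 1)]:
    r = torch.arange(start, end, step, dtype=torch.float32)
    assert r.size(0) == expected_length(start, end, step), (start, end, step)
```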