Fix refine_ranges corner case (#164075) (#164846)

Summary:
address https://github.com/pytorch/pytorch/issues/161360

u0>0 should update the range of u0 to start from [1, ..] this fix it. it was not doing that.

Test Plan: contbuild & OSS CI, see 27234792ad

D84038721

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164846
Approved by: https://github.com/izaitsevfb, https://github.com/ezyang
This commit is contained in:
Laith Sakka
2025-10-08 18:42:37 +00:00
committed by PyTorch MergeBot
parent 4c0fec3e4d
commit 0b85236477
5 changed files with 19 additions and 11 deletions

View File

@ -1 +1 @@
e0dda9059d082537cee36be6c5e4fe3b18c880c0
deb42f2a8e48f5032b4a98ee781a15fa87a157cf

View File

@ -6826,13 +6826,10 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
self.assertTrue(guard_failure is not None)
first_guard_failure = guard_failure[0].partition("\n")[0]
if torch._dynamo.config.assume_static_by_default:
self.assertIn(
"""tensor 'x' size mismatch at index 0. expected 2, actual 5""",
first_guard_failure,
)
else:
self.assertIn("""x.size()[0] < 3""", first_guard_failure)
self.assertIn(
"""tensor 'x' size mismatch at index 0. expected 2, actual 5""",
first_guard_failure,
)
def test_guard_failure_fn2(self):
def fn(x, y):

View File

@ -4049,6 +4049,17 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
self.assertEqual(compiled_func2(x, zero), func2(x, zero))
self.assertEqual(cnt.frame_count, 2)
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_select_2(self):
class M(torch.nn.Module):
def forward(self, x):
nz = x.nonzero()
return nz[-1]
mod = M()
x = torch.randn(4)
self.assertEqual(torch.compile(mod)(x), mod(x))
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_select_index_with_check(self):
def func3(x, y):

View File

@ -1858,7 +1858,7 @@ L['a'].size()[1] > L['a'].size()[0]
show_guards(tensor),
"""\
L['a'].size()[1] < L['a'].size()[0]
L['a'].size()[0] <= 19
3 <= L['a'].size()[0] and L['a'].size()[0] <= 19
L['a'].size()[1] <= 18""")
def test_sym_storage_offset(self):

View File

@ -7870,13 +7870,13 @@ class ShapeEnv:
# sympy.Eq may update both lower and upper bounds.
# sympy.G{t,e} may update the lower bound, only.
# sympy.L{t,e} may update the upper bound, only.
if lower < rhs_vr.lower and isinstance(
if lower <= rhs_vr.lower and isinstance(
r_expr, (sympy.Eq, sympy.Ge, sympy.Gt)
):
# Strictly greater relations allow us to refine a bit more, since
# x < y implies that the lower bound for x is: y + 1.
lower = rhs_vr.lower + int(isinstance(r_expr, sympy.Gt))
if upper > rhs_vr.upper and isinstance(
if upper >= rhs_vr.upper and isinstance(
r_expr, (sympy.Eq, sympy.Le, sympy.Lt)
):
upper = rhs_vr.upper - int(isinstance(r_expr, sympy.Lt))