Fix refine_ranges corner case (#164075) (#164846)

Summary:
address https://github.com/pytorch/pytorch/issues/161360

u0>0 should update the range of u0 to start from [1, ..] this fix it. it was not doing that.

Test Plan: contbuild & OSS CI, see 27234792ad

D84038721

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164846
Approved by: https://github.com/izaitsevfb, https://github.com/ezyang
This commit is contained in:
Laith Sakka
2025-10-08 18:42:37 +00:00
committed by PyTorch MergeBot
parent 4c0fec3e4d
commit 0b85236477
5 changed files with 19 additions and 11 deletions

View File

@ -1 +1 @@
e0dda9059d082537cee36be6c5e4fe3b18c880c0 deb42f2a8e48f5032b4a98ee781a15fa87a157cf

View File

@ -6826,13 +6826,10 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
self.assertTrue(guard_failure is not None) self.assertTrue(guard_failure is not None)
first_guard_failure = guard_failure[0].partition("\n")[0] first_guard_failure = guard_failure[0].partition("\n")[0]
if torch._dynamo.config.assume_static_by_default: self.assertIn(
self.assertIn( """tensor 'x' size mismatch at index 0. expected 2, actual 5""",
"""tensor 'x' size mismatch at index 0. expected 2, actual 5""", first_guard_failure,
first_guard_failure, )
)
else:
self.assertIn("""x.size()[0] < 3""", first_guard_failure)
def test_guard_failure_fn2(self): def test_guard_failure_fn2(self):
def fn(x, y): def fn(x, y):

View File

@ -4049,6 +4049,17 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
self.assertEqual(compiled_func2(x, zero), func2(x, zero)) self.assertEqual(compiled_func2(x, zero), func2(x, zero))
self.assertEqual(cnt.frame_count, 2) self.assertEqual(cnt.frame_count, 2)
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_select_2(self):
class M(torch.nn.Module):
def forward(self, x):
nz = x.nonzero()
return nz[-1]
mod = M()
x = torch.randn(4)
self.assertEqual(torch.compile(mod)(x), mod(x))
@torch._dynamo.config.patch("capture_scalar_outputs", True) @torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_select_index_with_check(self): def test_unbacked_select_index_with_check(self):
def func3(x, y): def func3(x, y):

View File

@ -1858,7 +1858,7 @@ L['a'].size()[1] > L['a'].size()[0]
show_guards(tensor), show_guards(tensor),
"""\ """\
L['a'].size()[1] < L['a'].size()[0] L['a'].size()[1] < L['a'].size()[0]
L['a'].size()[0] <= 19 3 <= L['a'].size()[0] and L['a'].size()[0] <= 19
L['a'].size()[1] <= 18""") L['a'].size()[1] <= 18""")
def test_sym_storage_offset(self): def test_sym_storage_offset(self):

View File

@ -7870,13 +7870,13 @@ class ShapeEnv:
# sympy.Eq may update both lower and upper bounds. # sympy.Eq may update both lower and upper bounds.
# sympy.G{t,e} may update the lower bound, only. # sympy.G{t,e} may update the lower bound, only.
# sympy.L{t,e} may update the upper bound, only. # sympy.L{t,e} may update the upper bound, only.
if lower < rhs_vr.lower and isinstance( if lower <= rhs_vr.lower and isinstance(
r_expr, (sympy.Eq, sympy.Ge, sympy.Gt) r_expr, (sympy.Eq, sympy.Ge, sympy.Gt)
): ):
# Strictly greater relations allow us to refine a bit more, since # Strictly greater relations allow us to refine a bit more, since
# x < y implies that the lower bound for x is: y + 1. # x < y implies that the lower bound for x is: y + 1.
lower = rhs_vr.lower + int(isinstance(r_expr, sympy.Gt)) lower = rhs_vr.lower + int(isinstance(r_expr, sympy.Gt))
if upper > rhs_vr.upper and isinstance( if upper >= rhs_vr.upper and isinstance(
r_expr, (sympy.Eq, sympy.Le, sympy.Lt) r_expr, (sympy.Eq, sympy.Le, sympy.Lt)
): ):
upper = rhs_vr.upper - int(isinstance(r_expr, sympy.Lt)) upper = rhs_vr.upper - int(isinstance(r_expr, sympy.Lt))