Remove no longer needed torch._check_is_size calls from test_dynamic_shapes (#164627)

These torch._check_is_size calls are no longer needed in those tests to prevent data-dependent errors (DDE).
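
For context, a minimal sketch of the old pattern being removed, using the same test-style ShapeEnv setup as the diff below (illustrative only, not part of the change):

    import torch
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    shape_env = ShapeEnv()
    u0 = shape_env.create_unbacked_symint()  # no concrete value ("hint") is known

    # torch._check_is_size(u0) asserted that u0 is a valid size (u0 >= 0),
    # which these tests previously needed so that size/stride reasoning on
    # u0 would not raise a data-dependent error (DDE). Per this PR, the
    # tests no longer need the extra hint.
    torch._check_is_size(u0)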

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164627
Approved by: https://github.com/ezyang
Author: Laith Sakka
Date: 2025-10-03 17:20:48 -07:00
Committed by: PyTorch MergeBot
parent 9fc2c6446d
commit 8c728e129d

@@ -870,7 +870,6 @@ def forward(self, x_1):
     def test_non_overlapping_and_dense_unbacked(self):
         shape_env = ShapeEnv()
         u0 = shape_env.create_unbacked_symint()
-        torch._check_is_size(u0)
         cf = torch.ops.aten.is_non_overlapping_and_dense.default
         self.assertEqual(IsNonOverlappingAndDenseIndicator(u0.node.expr, 2, 2, 1), 1)
@@ -906,7 +905,6 @@ def forward(self, x_1):
         # unbacked
         u0 = shape_env.create_unbacked_symint()
-        torch._check_is_size(u0)
         self.assertTrue(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")))
         self.assertTrue(cf(torch.empty_strided((2, u0), (1, 2), device="meta")))
         self.assertTrue(cf(torch.empty_strided((u0,), (1,), device="meta")))
@@ -934,7 +932,6 @@ def forward(self, x_1):
         # return False on arbitrary strides
         u1 = shape_env.create_unbacked_symint()
-        torch._check_is_size(u1)
         self.assertFalse(
             cf(
                 torch.empty_strided(
@@ -1130,7 +1127,6 @@ def forward(self, x_1):
     def test_debug_has_internal_overlap_unbacked(self):
         shape_env = ShapeEnv()
         u0 = shape_env.create_unbacked_symint()
-        torch._check_is_size(u0)
         cf = torch._debug_has_internal_overlap
         self.assertEqual(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")), 0)
         self.assertEqual(cf(torch.empty_strided((2, u0), (1, 2), device="meta")), 0)
@@ -3448,8 +3444,6 @@ class TestUbackedOps(TestCase):
             t1 = x.view((f, f))
             t2 = x.reshape((f, f))
             t3 = torch._ops.ops.aten.view_copy(x, (f, f))
-            # TODO avoid _check_is_size here.
-            torch._check_is_size(f)
             return t1 * 10, t2 * 10, t3
 
         compiled_func = torch.compile(
@@ -3555,8 +3549,6 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
         # reshape (u2, u3) -> (u0, u1)
         def func(x, y):
             u0, u1 = y.tolist()
-            torch._check_is_size(u0)
-            torch._check_is_size(u1)
             result1 = torch.reshape(x, (u0, u1))
             return result1 * 10
@@ -3591,14 +3583,20 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
     _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(select); select = None
     ge_5: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
     _assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_5, "Runtime assertion failed for expression u0 >= 0 on node 'ge_2'"); ge_5 = _assert_scalar_2 = None
+    sym_sum: "Sym(u0 + 1)" = torch.sym_sum((1, _local_scalar_dense))
+    gt: "Sym(u0 + 1 > 0)" = sym_sum > 0; sym_sum = None
+    _assert_scalar_3 = torch.ops.aten._assert_scalar.default(gt, "Runtime assertion failed for expression 0 < u0 + 1 on node 'gt'"); gt = _assert_scalar_3 = None
     select_1: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1); arg0_1 = None
     _local_scalar_dense_1: "Sym(u1)" = torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None
     ge_7: "Sym(u1 >= 0)" = _local_scalar_dense_1 >= 0
-    _assert_scalar_3 = torch.ops.aten._assert_scalar.default(ge_7, "Runtime assertion failed for expression u1 >= 0 on node 'ge_3'"); ge_7 = _assert_scalar_3 = None
+    _assert_scalar_4 = torch.ops.aten._assert_scalar.default(ge_7, "Runtime assertion failed for expression u1 >= 0 on node 'ge_3'"); ge_7 = _assert_scalar_4 = None
+    sym_sum_1: "Sym(u1 + 1)" = torch.sym_sum((1, _local_scalar_dense_1))
+    gt_1: "Sym(u1 + 1 > 0)" = sym_sum_1 > 0; sym_sum_1 = None
+    _assert_scalar_5 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u1 + 1 on node 'gt_1'"); gt_1 = _assert_scalar_5 = None
     mul: "Sym(u2*u3)" = arg1_1 * arg2_1; arg1_1 = arg2_1 = None
     mul_1: "Sym(u0*u1)" = _local_scalar_dense * _local_scalar_dense_1
     eq: "Sym(Eq(u2*u3, u0*u1))" = mul == mul_1; mul = mul_1 = None
-    _assert_scalar_4 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u2*u3, u0*u1) on node 'eq'"); eq = _assert_scalar_4 = None
+    _assert_scalar_6 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u2*u3, u0*u1) on node 'eq'"); eq = _assert_scalar_6 = None
     clone: "f32[u2, u3][Max(1, u3), 1]cpu" = torch.ops.aten.clone.default(arg3_1, memory_format = torch.contiguous_format); arg3_1 = None
     view: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.view.default(clone, [_local_scalar_dense, _local_scalar_dense_1]); clone = _local_scalar_dense = _local_scalar_dense_1 = None
     mul_21: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None
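
The regenerated expected graph above now carries explicit 0 < u0 + 1 and 0 < u1 + 1 runtime asserts built with torch.sym_sum, which folds a sequence of ints/SymInts into a single symbolic sum. A minimal sketch of the expression it produces, using the same test-style ShapeEnv setup (illustrative only, not part of the diff):

    import torch
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    shape_env = ShapeEnv()
    u0 = shape_env.create_unbacked_symint()

    s = torch.sym_sum((1, u0))  # a single Sym(u0 + 1) expression
    gt = s > 0                  # SymBool backing the "0 < u0 + 1" assert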
@@ -3659,8 +3657,12 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
         # standard slice
         def f1(x, xs):
             u0, u1 = xs.tolist()
-            torch._check_is_size(u0, max=x.size(0))
-            torch._check_is_size(u1, max=x.size(0))
+            # In this test we add the torch checks not to avoid DDE but to
+            # ensure that we pick a specific path during compilation.
+            torch._check(u0 >= 0)
+            torch._check(u0 <= x.size(0))
+            torch._check(u1 >= 0)
+            torch._check(u1 <= x.size(0))
             torch._check(u0 <= u1)
             out = x[u0:u1]
             assert statically_known_true(out.size(0) == (u1 - u0))
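
Unlike the removed _check_is_size calls, the torch._check calls kept in this hunk are deliberate: each one records a runtime assert and refines the value range of the unbacked int, which steers compilation down the intended slicing path. A minimal sketch of that refinement effect, again with test-style ShapeEnv usage (illustrative only):

    import torch
    from torch.fx.experimental.symbolic_shapes import ShapeEnv, statically_known_true

    shape_env = ShapeEnv()
    u0 = shape_env.create_unbacked_symint()

    # Before the check, u0 >= 0 is unknown; torch._check records a runtime
    # assert and narrows u0's range, so the query becomes statically true
    # and no data-dependent guard is needed.
    torch._check(u0 >= 0)
    assert statically_known_true(u0 >= 0)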
@@ -3880,9 +3882,6 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
         # hints to compute strides.
         def func(x, y):
             u0, u1 = y.tolist()
-            torch._check_is_size(u0)
-            torch._check_is_size(u1)
             result2 = x.view(u0, u1) * 10
             return result2