mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
remove no longer needed torch._check_is_size calls from test_dynamic_shapes (#164627)
No longer needed in those tests to prevent DDE Pull Request resolved: https://github.com/pytorch/pytorch/pull/164627 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
9fc2c6446d
commit
8c728e129d
@ -870,7 +870,6 @@ def forward(self, x_1):
|
|||||||
def test_non_overlapping_and_dense_unbacked(self):
|
def test_non_overlapping_and_dense_unbacked(self):
|
||||||
shape_env = ShapeEnv()
|
shape_env = ShapeEnv()
|
||||||
u0 = shape_env.create_unbacked_symint()
|
u0 = shape_env.create_unbacked_symint()
|
||||||
torch._check_is_size(u0)
|
|
||||||
cf = torch.ops.aten.is_non_overlapping_and_dense.default
|
cf = torch.ops.aten.is_non_overlapping_and_dense.default
|
||||||
|
|
||||||
self.assertEqual(IsNonOverlappingAndDenseIndicator(u0.node.expr, 2, 2, 1), 1)
|
self.assertEqual(IsNonOverlappingAndDenseIndicator(u0.node.expr, 2, 2, 1), 1)
|
||||||
@ -906,7 +905,6 @@ def forward(self, x_1):
|
|||||||
|
|
||||||
# unbacked
|
# unbacked
|
||||||
u0 = shape_env.create_unbacked_symint()
|
u0 = shape_env.create_unbacked_symint()
|
||||||
torch._check_is_size(u0)
|
|
||||||
self.assertTrue(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")))
|
self.assertTrue(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")))
|
||||||
self.assertTrue(cf(torch.empty_strided((2, u0), (1, 2), device="meta")))
|
self.assertTrue(cf(torch.empty_strided((2, u0), (1, 2), device="meta")))
|
||||||
self.assertTrue(cf(torch.empty_strided((u0,), (1,), device="meta")))
|
self.assertTrue(cf(torch.empty_strided((u0,), (1,), device="meta")))
|
||||||
@ -934,7 +932,6 @@ def forward(self, x_1):
|
|||||||
|
|
||||||
# return False on arbitrary strides
|
# return False on arbitrary strides
|
||||||
u1 = shape_env.create_unbacked_symint()
|
u1 = shape_env.create_unbacked_symint()
|
||||||
torch._check_is_size(u1)
|
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
cf(
|
cf(
|
||||||
torch.empty_strided(
|
torch.empty_strided(
|
||||||
@ -1130,7 +1127,6 @@ def forward(self, x_1):
|
|||||||
def test_debug_has_internal_overlap_unbacked(self):
|
def test_debug_has_internal_overlap_unbacked(self):
|
||||||
shape_env = ShapeEnv()
|
shape_env = ShapeEnv()
|
||||||
u0 = shape_env.create_unbacked_symint()
|
u0 = shape_env.create_unbacked_symint()
|
||||||
torch._check_is_size(u0)
|
|
||||||
cf = torch._debug_has_internal_overlap
|
cf = torch._debug_has_internal_overlap
|
||||||
self.assertEqual(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")), 0)
|
self.assertEqual(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")), 0)
|
||||||
self.assertEqual(cf(torch.empty_strided((2, u0), (1, 2), device="meta")), 0)
|
self.assertEqual(cf(torch.empty_strided((2, u0), (1, 2), device="meta")), 0)
|
||||||
@ -3448,8 +3444,6 @@ class TestUbackedOps(TestCase):
|
|||||||
t1 = x.view((f, f))
|
t1 = x.view((f, f))
|
||||||
t2 = x.reshape((f, f))
|
t2 = x.reshape((f, f))
|
||||||
t3 = torch._ops.ops.aten.view_copy(x, (f, f))
|
t3 = torch._ops.ops.aten.view_copy(x, (f, f))
|
||||||
# TODO avoid _check_is_size here.
|
|
||||||
torch._check_is_size(f)
|
|
||||||
return t1 * 10, t2 * 10, t3
|
return t1 * 10, t2 * 10, t3
|
||||||
|
|
||||||
compiled_func = torch.compile(
|
compiled_func = torch.compile(
|
||||||
@ -3555,8 +3549,6 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]
|
|||||||
# reshape (u2, u3) -> (u0, u1)
|
# reshape (u2, u3) -> (u0, u1)
|
||||||
def func(x, y):
|
def func(x, y):
|
||||||
u0, u1 = y.tolist()
|
u0, u1 = y.tolist()
|
||||||
torch._check_is_size(u0)
|
|
||||||
torch._check_is_size(u1)
|
|
||||||
|
|
||||||
result1 = torch.reshape(x, (u0, u1))
|
result1 = torch.reshape(x, (u0, u1))
|
||||||
return result1 * 10
|
return result1 * 10
|
||||||
@ -3591,14 +3583,20 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
|
|||||||
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(select); select = None
|
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(select); select = None
|
||||||
ge_5: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
|
ge_5: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
|
||||||
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_5, "Runtime assertion failed for expression u0 >= 0 on node 'ge_2'"); ge_5 = _assert_scalar_2 = None
|
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_5, "Runtime assertion failed for expression u0 >= 0 on node 'ge_2'"); ge_5 = _assert_scalar_2 = None
|
||||||
|
sym_sum: "Sym(u0 + 1)" = torch.sym_sum((1, _local_scalar_dense))
|
||||||
|
gt: "Sym(u0 + 1 > 0)" = sym_sum > 0; sym_sum = None
|
||||||
|
_assert_scalar_3 = torch.ops.aten._assert_scalar.default(gt, "Runtime assertion failed for expression 0 < u0 + 1 on node 'gt'"); gt = _assert_scalar_3 = None
|
||||||
select_1: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1); arg0_1 = None
|
select_1: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1); arg0_1 = None
|
||||||
_local_scalar_dense_1: "Sym(u1)" = torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None
|
_local_scalar_dense_1: "Sym(u1)" = torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None
|
||||||
ge_7: "Sym(u1 >= 0)" = _local_scalar_dense_1 >= 0
|
ge_7: "Sym(u1 >= 0)" = _local_scalar_dense_1 >= 0
|
||||||
_assert_scalar_3 = torch.ops.aten._assert_scalar.default(ge_7, "Runtime assertion failed for expression u1 >= 0 on node 'ge_3'"); ge_7 = _assert_scalar_3 = None
|
_assert_scalar_4 = torch.ops.aten._assert_scalar.default(ge_7, "Runtime assertion failed for expression u1 >= 0 on node 'ge_3'"); ge_7 = _assert_scalar_4 = None
|
||||||
|
sym_sum_1: "Sym(u1 + 1)" = torch.sym_sum((1, _local_scalar_dense_1))
|
||||||
|
gt_1: "Sym(u1 + 1 > 0)" = sym_sum_1 > 0; sym_sum_1 = None
|
||||||
|
_assert_scalar_5 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u1 + 1 on node 'gt_1'"); gt_1 = _assert_scalar_5 = None
|
||||||
mul: "Sym(u2*u3)" = arg1_1 * arg2_1; arg1_1 = arg2_1 = None
|
mul: "Sym(u2*u3)" = arg1_1 * arg2_1; arg1_1 = arg2_1 = None
|
||||||
mul_1: "Sym(u0*u1)" = _local_scalar_dense * _local_scalar_dense_1
|
mul_1: "Sym(u0*u1)" = _local_scalar_dense * _local_scalar_dense_1
|
||||||
eq: "Sym(Eq(u2*u3, u0*u1))" = mul == mul_1; mul = mul_1 = None
|
eq: "Sym(Eq(u2*u3, u0*u1))" = mul == mul_1; mul = mul_1 = None
|
||||||
_assert_scalar_4 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u2*u3, u0*u1) on node 'eq'"); eq = _assert_scalar_4 = None
|
_assert_scalar_6 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u2*u3, u0*u1) on node 'eq'"); eq = _assert_scalar_6 = None
|
||||||
clone: "f32[u2, u3][Max(1, u3), 1]cpu" = torch.ops.aten.clone.default(arg3_1, memory_format = torch.contiguous_format); arg3_1 = None
|
clone: "f32[u2, u3][Max(1, u3), 1]cpu" = torch.ops.aten.clone.default(arg3_1, memory_format = torch.contiguous_format); arg3_1 = None
|
||||||
view: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.view.default(clone, [_local_scalar_dense, _local_scalar_dense_1]); clone = _local_scalar_dense = _local_scalar_dense_1 = None
|
view: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.view.default(clone, [_local_scalar_dense, _local_scalar_dense_1]); clone = _local_scalar_dense = _local_scalar_dense_1 = None
|
||||||
mul_21: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None
|
mul_21: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None
|
||||||
@ -3659,8 +3657,12 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
|
|||||||
# standard slice
|
# standard slice
|
||||||
def f1(x, xs):
|
def f1(x, xs):
|
||||||
u0, u1 = xs.tolist()
|
u0, u1 = xs.tolist()
|
||||||
torch._check_is_size(u0, max=x.size(0))
|
# in this test we add the torch checks not to avoid DDE but to ensure
|
||||||
torch._check_is_size(u1, max=x.size(0))
|
# that we pick specific path during compilation.
|
||||||
|
torch._check(u0 >= 0)
|
||||||
|
torch._check(u0 <= x.size(0))
|
||||||
|
torch._check(u1 >= 0)
|
||||||
|
torch._check(u1 <= x.size(0))
|
||||||
torch._check(u0 <= u1)
|
torch._check(u0 <= u1)
|
||||||
out = x[u0:u1]
|
out = x[u0:u1]
|
||||||
assert statically_known_true(out.size(0) == (u1 - u0))
|
assert statically_known_true(out.size(0) == (u1 - u0))
|
||||||
@ -3880,9 +3882,6 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
|
|||||||
# hints to to compute strides.
|
# hints to to compute strides.
|
||||||
def func(x, y):
|
def func(x, y):
|
||||||
u0, u1 = y.tolist()
|
u0, u1 = y.tolist()
|
||||||
torch._check_is_size(u0)
|
|
||||||
torch._check_is_size(u1)
|
|
||||||
|
|
||||||
result2 = x.view(u0, u1) * 10
|
result2 = x.view(u0, u1) * 10
|
||||||
return result2
|
return result2
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user