use check_size instead of check_is_size in ops.py (#164668)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164668
Approved by: https://github.com/angelayi
ghstack dependencies: #164664, #164665, #164667
This commit is contained in:
Laith Sakka
2025-10-07 13:13:33 -07:00
committed by PyTorch MergeBot
parent 2b58adc3bd
commit 2035f6b2e6
8 changed files with 30 additions and 21 deletions

View File

@ -4258,7 +4258,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
@torch.compile(fullgraph=True)
def f(x):
y = x.item()
torch._check_is_size(y)
torch._check(y >= 0)
if y >= 0:
return x * 2
else:

View File

@ -11,8 +11,6 @@ class ViewTests(torch._dynamo.test_case.TestCase):
def f(t, _u0):
u0 = t[0].item()
u1 = t[1].item()
torch._check_is_size(u0)
torch._check_is_size(u1)
n = u0 * u1
a = torch.randn(n)
return a.view(-1, _u0)
@ -25,8 +23,6 @@ class ViewTests(torch._dynamo.test_case.TestCase):
def f(t, _n):
u0 = t[0].item()
u1 = t[1].item()
torch._check_is_size(u0)
torch._check_is_size(u1)
a = torch.randn(u0, u1)
return a.view(_n)

View File

@ -6742,6 +6742,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
b = x.item()
torch._check(b >= 0)
torch._check(b < y.shape[0])
return y[0, b]
if is_non_strict_test(self._testMethodName):

View File

@ -1412,7 +1412,6 @@ class TestDeserialize(TestCase):
def forward(self, x):
y = x.nonzero()
z = y.size(0)
torch._check_is_size(z)
torch._check(z == 2)
return y
@ -1423,7 +1422,6 @@ class TestDeserialize(TestCase):
def forward(self, x):
y = x.nonzero()
z = y.size(0)
torch._check_is_size(z)
torch._check(z % 3 == 0)
torch._check(z == 3)
return y
@ -1707,7 +1705,7 @@ def forward(self, x):
class Module(torch.nn.Module):
def forward(self, x, y):
n = x.item()
torch._check_is_size(n)
torch._check(n >= 0)
return y.sum() + torch.ones(n, 5).sum()
f = Module()
@ -2222,7 +2220,8 @@ def forward(self, x):
class Foo(torch.nn.Module):
def forward(self, x, y):
n = x.item()
torch._check_is_size(n, max=y.size(0) - 1)
torch._check(n >= 0)
torch._check(n < y.size(0))
return torch.empty(n), y[n]
ep = torch.export.export(

View File

@ -222,7 +222,7 @@ class TestSourceMatcher(JitTestCase):
class M(torch.nn.Module):
def forward(self, x, y):
b = x.item()
torch._check_is_size(b)
torch._check(b >= 0)
torch._check(b + 1 < y.size(0))
return y[: b + 1]

View File

@ -1604,11 +1604,30 @@ class GraphModule(torch.nn.Module):
self.assertEqual(ref, res)
@torch._dynamo.config.patch(capture_scalar_outputs=True)
def test_unbacked(self):
def test_unbacked1(self):
@nested_compile_region
def gn(x, y):
b = x.item()
torch._check_is_size(b)
return y[:b].clone()
def fn(x, y):
return gn(x, y)
x = torch.tensor(4)
y = torch.randn(8)
ref = fn(x, y)
opt_fn = torch.compile(
fn, backend="eager", fullgraph=True
) # Inductor fails with assertion error when lowering aten.sym_constrain_range_for_size.default
res = opt_fn(x, y)
self.assertEqual(ref, res)
@torch._dynamo.config.patch(capture_scalar_outputs=True)
def test_unbacked2(self):
@nested_compile_region
def gn(x, y):
b = x.item()
torch._check(b >= 0)
torch._check(b < y.shape[0])
return y[:b].clone()

View File

@ -1732,7 +1732,6 @@ class AOTInductorTestsTemplate:
backed = z.size(0)
unbacked = scalar.item()
torch._check_is_size(unbacked)
unbacked_add_expr = backed + unbacked
repeated = x.repeat(unbacked_add_expr, 1)
@ -1771,8 +1770,6 @@ class AOTInductorTestsTemplate:
index_select = torch.index_select(embeddings, 0, index)
u0, u1 = lst.tolist()
torch._check_is_size(u0)
torch._check_is_size(u1)
backed0, backed1 = z.size(0), z.size(1)
repeated0 = y.repeat(backed0 + u0, 1)
@ -1822,9 +1819,6 @@ class AOTInductorTestsTemplate:
class Repro(torch.nn.Module):
def forward(self, values, repeats, mask, embeddings, x, y, z, lst):
u0, u1, u2 = lst.tolist()
torch._check_is_size(u0)
torch._check_is_size(u1)
torch._check_is_size(u2)
backed = z.size(0)
backed1 = z.size(1)

View File

@ -1139,7 +1139,7 @@ def unbind_int(func, *args, **kwargs):
ragged_idx = inp._ragged_idx
def _torch_check(_lengths: list[int], _offsets: Optional[list[int]] = None):
# This torch._check and torch._check_is_size are needed for torch.compile
# This torch._check are needed for torch.compile
# symbolic shapes processing.
# offsets and lengths are symbolic variables during compilation,
# we guarantee the correct offsets/lengths correspondence:
@ -1151,7 +1151,7 @@ def unbind_int(func, *args, **kwargs):
lengths_sum = 0
ragged_dim_size = values.shape[ragged_idx - 1]
for i in range(len(_lengths)):
torch._check_is_size(_lengths[i])
torch._check(_lengths[i] >= 0)
torch._check(_lengths[i] <= ragged_dim_size)
lengths_sum += _lengths[i]
@ -1164,7 +1164,7 @@ def unbind_int(func, *args, **kwargs):
if _offsets is not None:
for i in range(len(_offsets)):
torch._check_is_size(_offsets[i])
torch._check(_offsets[i] >= 0)
torch._check(_offsets[i] <= ragged_dim_size)
if lengths is None: