Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
use check_size instead of check_is_size in ops.py (#164668)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164668
Approved by: https://github.com/angelayi
ghstack dependencies: #164664, #164665, #164667
Committed by: PyTorch MergeBot
Parent: 2b58adc3bd
Commit: 2035f6b2e6
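Background for the change (editor's sketch, not part of the commit): torch._check_is_size(u) marks an unbacked value produced by .item() as size-like, while plain torch._check(...) records an explicit boolean fact for the symbolic-shapes machinery. The minimal example below, modeled on the test_unbacked2 hunk further down, shows the torch._check style under torch.compile; the tensor values and function name are illustrative only.

    import torch

    torch._dynamo.config.capture_scalar_outputs = True  # let Dynamo trace .item()

    @torch.compile(backend="eager", fullgraph=True)
    def f(x, y):
        b = x.item()                   # b is an unbacked SymInt during tracing
        torch._check(b >= 0)           # facts the compiler needs to slice y[:b]
        torch._check(b < y.shape[0])
        return y[:b].clone()

    print(f(torch.tensor(4), torch.randn(8)))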
@@ -4258,7 +4258,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
@torch.compile(fullgraph=True)
def f(x):
    y = x.item()
    torch._check_is_size(y)
    torch._check(y >= 0)
    if y >= 0:
        return x * 2
    else:
@@ -11,8 +11,6 @@ class ViewTests(torch._dynamo.test_case.TestCase):
def f(t, _u0):
    u0 = t[0].item()
    u1 = t[1].item()
    torch._check_is_size(u0)
    torch._check_is_size(u1)
    n = u0 * u1
    a = torch.randn(n)
    return a.view(-1, _u0)
@@ -25,8 +23,6 @@ class ViewTests(torch._dynamo.test_case.TestCase):
def f(t, _n):
    u0 = t[0].item()
    u1 = t[1].item()
    torch._check_is_size(u0)
    torch._check_is_size(u1)
    a = torch.randn(u0, u1)
    return a.view(_n)
@@ -6742,6 +6742,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
b = x.item()
torch._check(b >= 0)
torch._check(b < y.shape[0])

return y[0, b]

if is_non_strict_test(self._testMethodName):
@@ -1412,7 +1412,6 @@ class TestDeserialize(TestCase):
def forward(self, x):
    y = x.nonzero()
    z = y.size(0)
    torch._check_is_size(z)
    torch._check(z == 2)
    return y
@@ -1423,7 +1422,6 @@ class TestDeserialize(TestCase):
def forward(self, x):
    y = x.nonzero()
    z = y.size(0)
    torch._check_is_size(z)
    torch._check(z % 3 == 0)
    torch._check(z == 3)
    return y
@@ -1707,7 +1705,7 @@ def forward(self, x):
class Module(torch.nn.Module):
    def forward(self, x, y):
        n = x.item()
        torch._check_is_size(n)
        torch._check(n >= 0)
        return y.sum() + torch.ones(n, 5).sum()

f = Module()
@@ -2222,7 +2220,8 @@ def forward(self, x):
class Foo(torch.nn.Module):
    def forward(self, x, y):
        n = x.item()
        torch._check_is_size(n, max=y.size(0) - 1)
        torch._check(n >= 0)
        torch._check(n < y.size(0))
        return torch.empty(n), y[n]

ep = torch.export.export(
@@ -222,7 +222,7 @@ class TestSourceMatcher(JitTestCase):
class M(torch.nn.Module):
    def forward(self, x, y):
        b = x.item()
        torch._check_is_size(b)
        torch._check(b >= 0)
        torch._check(b + 1 < y.size(0))
        return y[: b + 1]
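The export-related hunks above follow the same pattern. A small standalone sketch (editor's illustration; the module name and example inputs are made up) of how explicit torch._check bounds let torch.export handle a data-dependent index:

    import torch

    class M(torch.nn.Module):
        def forward(self, x, y):
            b = x.item()                      # data-dependent scalar
            torch._check(b >= 0)              # lower bound for the index
            torch._check(b + 1 < y.size(0))   # upper bound, mirroring the test above
            return y[: b + 1]

    ep = torch.export.export(M(), (torch.tensor(3), torch.randn(8)))
    print(ep.graph)  # inspect the runtime asserts recorded for b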
@@ -1604,11 +1604,30 @@ class GraphModule(torch.nn.Module):
self.assertEqual(ref, res)

@torch._dynamo.config.patch(capture_scalar_outputs=True)
def test_unbacked(self):
def test_unbacked1(self):
    @nested_compile_region
    def gn(x, y):
        b = x.item()
        torch._check_is_size(b)
        return y[:b].clone()

    def fn(x, y):
        return gn(x, y)

    x = torch.tensor(4)
    y = torch.randn(8)
    ref = fn(x, y)
    opt_fn = torch.compile(
        fn, backend="eager", fullgraph=True
    )  # Inductor fails with assertion error when lowering aten.sym_constrain_range_for_size.default
    res = opt_fn(x, y)
    self.assertEqual(ref, res)

@torch._dynamo.config.patch(capture_scalar_outputs=True)
def test_unbacked2(self):
    @nested_compile_region
    def gn(x, y):
        b = x.item()
        torch._check(b >= 0)
        torch._check(b < y.shape[0])
        return y[:b].clone()
@@ -1732,7 +1732,6 @@ class AOTInductorTestsTemplate:

backed = z.size(0)
unbacked = scalar.item()
torch._check_is_size(unbacked)

unbacked_add_expr = backed + unbacked
repeated = x.repeat(unbacked_add_expr, 1)
@@ -1771,8 +1770,6 @@ class AOTInductorTestsTemplate:
index_select = torch.index_select(embeddings, 0, index)

u0, u1 = lst.tolist()
torch._check_is_size(u0)
torch._check_is_size(u1)
backed0, backed1 = z.size(0), z.size(1)

repeated0 = y.repeat(backed0 + u0, 1)
@@ -1822,9 +1819,6 @@ class AOTInductorTestsTemplate:
class Repro(torch.nn.Module):
    def forward(self, values, repeats, mask, embeddings, x, y, z, lst):
        u0, u1, u2 = lst.tolist()
        torch._check_is_size(u0)
        torch._check_is_size(u1)
        torch._check_is_size(u2)
        backed = z.size(0)
        backed1 = z.size(1)
@@ -1139,7 +1139,7 @@ def unbind_int(func, *args, **kwargs):
ragged_idx = inp._ragged_idx

def _torch_check(_lengths: list[int], _offsets: Optional[list[int]] = None):
    # This torch._check and torch._check_is_size are needed for torch.compile
    # This torch._check are needed for torch.compile
    # symbolic shapes processing.
    # offsets and lengths are symbolic variables during compilation,
    # we guarantee the correct offsets/lengths correspondence:
@@ -1151,7 +1151,7 @@ def unbind_int(func, *args, **kwargs):
lengths_sum = 0
ragged_dim_size = values.shape[ragged_idx - 1]
for i in range(len(_lengths)):
    torch._check_is_size(_lengths[i])
    torch._check(_lengths[i] >= 0)
    torch._check(_lengths[i] <= ragged_dim_size)

    lengths_sum += _lengths[i]
@@ -1164,7 +1164,7 @@ def unbind_int(func, *args, **kwargs):

if _offsets is not None:
    for i in range(len(_offsets)):
        torch._check_is_size(_offsets[i])
        torch._check(_offsets[i] >= 0)
        torch._check(_offsets[i] <= ragged_dim_size)

if lengths is None:
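The ops.py hunks above guard every unbacked length and offset before it is used. A compressed sketch of that pattern outside of nested tensors (editor's illustration; split_by_lengths, its inputs, and the exact set of checks are assumptions, and whether it compiles cleanly depends on the PyTorch version):

    import torch

    torch._dynamo.config.capture_scalar_outputs = True

    @torch.compile(backend="eager", fullgraph=True)
    def split_by_lengths(values, lengths_t):
        lengths = lengths_t.tolist()          # each entry is an unbacked SymInt
        total = values.shape[0]
        running = 0
        for n in lengths:
            torch._check(n >= 0)              # same per-length bounds ops.py asserts
            torch._check(n <= total)
            running += n
        torch._check(running == total)        # the lengths must tile the split dim
        return torch.split(values, lengths)

    print(split_by_lengths(torch.randn(10), torch.tensor([3, 7])))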