mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
use check_size instead of check_is_size in ops.py (#164668)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164668 Approved by: https://github.com/angelayi ghstack dependencies: #164664, #164665, #164667
This commit is contained in:
committed by
PyTorch MergeBot
parent
2b58adc3bd
commit
2035f6b2e6
@ -4258,7 +4258,7 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||||||
@torch.compile(fullgraph=True)
|
@torch.compile(fullgraph=True)
|
||||||
def f(x):
|
def f(x):
|
||||||
y = x.item()
|
y = x.item()
|
||||||
torch._check_is_size(y)
|
torch._check(y >= 0)
|
||||||
if y >= 0:
|
if y >= 0:
|
||||||
return x * 2
|
return x * 2
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -11,8 +11,6 @@ class ViewTests(torch._dynamo.test_case.TestCase):
|
|||||||
def f(t, _u0):
|
def f(t, _u0):
|
||||||
u0 = t[0].item()
|
u0 = t[0].item()
|
||||||
u1 = t[1].item()
|
u1 = t[1].item()
|
||||||
torch._check_is_size(u0)
|
|
||||||
torch._check_is_size(u1)
|
|
||||||
n = u0 * u1
|
n = u0 * u1
|
||||||
a = torch.randn(n)
|
a = torch.randn(n)
|
||||||
return a.view(-1, _u0)
|
return a.view(-1, _u0)
|
||||||
@ -25,8 +23,6 @@ class ViewTests(torch._dynamo.test_case.TestCase):
|
|||||||
def f(t, _n):
|
def f(t, _n):
|
||||||
u0 = t[0].item()
|
u0 = t[0].item()
|
||||||
u1 = t[1].item()
|
u1 = t[1].item()
|
||||||
torch._check_is_size(u0)
|
|
||||||
torch._check_is_size(u1)
|
|
||||||
a = torch.randn(u0, u1)
|
a = torch.randn(u0, u1)
|
||||||
return a.view(_n)
|
return a.view(_n)
|
||||||
|
|
||||||
|
|||||||
@ -6742,6 +6742,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
|||||||
b = x.item()
|
b = x.item()
|
||||||
torch._check(b >= 0)
|
torch._check(b >= 0)
|
||||||
torch._check(b < y.shape[0])
|
torch._check(b < y.shape[0])
|
||||||
|
|
||||||
return y[0, b]
|
return y[0, b]
|
||||||
|
|
||||||
if is_non_strict_test(self._testMethodName):
|
if is_non_strict_test(self._testMethodName):
|
||||||
|
|||||||
@ -1412,7 +1412,6 @@ class TestDeserialize(TestCase):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
y = x.nonzero()
|
y = x.nonzero()
|
||||||
z = y.size(0)
|
z = y.size(0)
|
||||||
torch._check_is_size(z)
|
|
||||||
torch._check(z == 2)
|
torch._check(z == 2)
|
||||||
return y
|
return y
|
||||||
|
|
||||||
@ -1423,7 +1422,6 @@ class TestDeserialize(TestCase):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
y = x.nonzero()
|
y = x.nonzero()
|
||||||
z = y.size(0)
|
z = y.size(0)
|
||||||
torch._check_is_size(z)
|
|
||||||
torch._check(z % 3 == 0)
|
torch._check(z % 3 == 0)
|
||||||
torch._check(z == 3)
|
torch._check(z == 3)
|
||||||
return y
|
return y
|
||||||
@ -1707,7 +1705,7 @@ def forward(self, x):
|
|||||||
class Module(torch.nn.Module):
|
class Module(torch.nn.Module):
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
n = x.item()
|
n = x.item()
|
||||||
torch._check_is_size(n)
|
torch._check(n >= 0)
|
||||||
return y.sum() + torch.ones(n, 5).sum()
|
return y.sum() + torch.ones(n, 5).sum()
|
||||||
|
|
||||||
f = Module()
|
f = Module()
|
||||||
@ -2222,7 +2220,8 @@ def forward(self, x):
|
|||||||
class Foo(torch.nn.Module):
|
class Foo(torch.nn.Module):
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
n = x.item()
|
n = x.item()
|
||||||
torch._check_is_size(n, max=y.size(0) - 1)
|
torch._check(n >= 0)
|
||||||
|
torch._check(n < y.size(0))
|
||||||
return torch.empty(n), y[n]
|
return torch.empty(n), y[n]
|
||||||
|
|
||||||
ep = torch.export.export(
|
ep = torch.export.export(
|
||||||
|
|||||||
@ -222,7 +222,7 @@ class TestSourceMatcher(JitTestCase):
|
|||||||
class M(torch.nn.Module):
|
class M(torch.nn.Module):
|
||||||
def forward(self, x, y):
|
def forward(self, x, y):
|
||||||
b = x.item()
|
b = x.item()
|
||||||
torch._check_is_size(b)
|
torch._check(b >= 0)
|
||||||
torch._check(b + 1 < y.size(0))
|
torch._check(b + 1 < y.size(0))
|
||||||
return y[: b + 1]
|
return y[: b + 1]
|
||||||
|
|
||||||
|
|||||||
@ -1604,11 +1604,30 @@ class GraphModule(torch.nn.Module):
|
|||||||
self.assertEqual(ref, res)
|
self.assertEqual(ref, res)
|
||||||
|
|
||||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||||
def test_unbacked(self):
|
def test_unbacked1(self):
|
||||||
@nested_compile_region
|
@nested_compile_region
|
||||||
def gn(x, y):
|
def gn(x, y):
|
||||||
b = x.item()
|
b = x.item()
|
||||||
torch._check_is_size(b)
|
return y[:b].clone()
|
||||||
|
|
||||||
|
def fn(x, y):
|
||||||
|
return gn(x, y)
|
||||||
|
|
||||||
|
x = torch.tensor(4)
|
||||||
|
y = torch.randn(8)
|
||||||
|
ref = fn(x, y)
|
||||||
|
opt_fn = torch.compile(
|
||||||
|
fn, backend="eager", fullgraph=True
|
||||||
|
) # Inductor fails with assertion error when lowering aten.sym_constrain_range_for_size.default
|
||||||
|
res = opt_fn(x, y)
|
||||||
|
self.assertEqual(ref, res)
|
||||||
|
|
||||||
|
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||||
|
def test_unbacked2(self):
|
||||||
|
@nested_compile_region
|
||||||
|
def gn(x, y):
|
||||||
|
b = x.item()
|
||||||
|
torch._check(b >= 0)
|
||||||
torch._check(b < y.shape[0])
|
torch._check(b < y.shape[0])
|
||||||
return y[:b].clone()
|
return y[:b].clone()
|
||||||
|
|
||||||
|
|||||||
@ -1732,7 +1732,6 @@ class AOTInductorTestsTemplate:
|
|||||||
|
|
||||||
backed = z.size(0)
|
backed = z.size(0)
|
||||||
unbacked = scalar.item()
|
unbacked = scalar.item()
|
||||||
torch._check_is_size(unbacked)
|
|
||||||
|
|
||||||
unbacked_add_expr = backed + unbacked
|
unbacked_add_expr = backed + unbacked
|
||||||
repeated = x.repeat(unbacked_add_expr, 1)
|
repeated = x.repeat(unbacked_add_expr, 1)
|
||||||
@ -1771,8 +1770,6 @@ class AOTInductorTestsTemplate:
|
|||||||
index_select = torch.index_select(embeddings, 0, index)
|
index_select = torch.index_select(embeddings, 0, index)
|
||||||
|
|
||||||
u0, u1 = lst.tolist()
|
u0, u1 = lst.tolist()
|
||||||
torch._check_is_size(u0)
|
|
||||||
torch._check_is_size(u1)
|
|
||||||
backed0, backed1 = z.size(0), z.size(1)
|
backed0, backed1 = z.size(0), z.size(1)
|
||||||
|
|
||||||
repeated0 = y.repeat(backed0 + u0, 1)
|
repeated0 = y.repeat(backed0 + u0, 1)
|
||||||
@ -1822,9 +1819,6 @@ class AOTInductorTestsTemplate:
|
|||||||
class Repro(torch.nn.Module):
|
class Repro(torch.nn.Module):
|
||||||
def forward(self, values, repeats, mask, embeddings, x, y, z, lst):
|
def forward(self, values, repeats, mask, embeddings, x, y, z, lst):
|
||||||
u0, u1, u2 = lst.tolist()
|
u0, u1, u2 = lst.tolist()
|
||||||
torch._check_is_size(u0)
|
|
||||||
torch._check_is_size(u1)
|
|
||||||
torch._check_is_size(u2)
|
|
||||||
backed = z.size(0)
|
backed = z.size(0)
|
||||||
backed1 = z.size(1)
|
backed1 = z.size(1)
|
||||||
|
|
||||||
|
|||||||
@ -1139,7 +1139,7 @@ def unbind_int(func, *args, **kwargs):
|
|||||||
ragged_idx = inp._ragged_idx
|
ragged_idx = inp._ragged_idx
|
||||||
|
|
||||||
def _torch_check(_lengths: list[int], _offsets: Optional[list[int]] = None):
|
def _torch_check(_lengths: list[int], _offsets: Optional[list[int]] = None):
|
||||||
# This torch._check and torch._check_is_size are needed for torch.compile
|
# This torch._check are needed for torch.compile
|
||||||
# symbolic shapes processing.
|
# symbolic shapes processing.
|
||||||
# offsets and lengths are symbolic variables during compilation,
|
# offsets and lengths are symbolic variables during compilation,
|
||||||
# we guarantee the correct offsets/lengths correspondence:
|
# we guarantee the correct offsets/lengths correspondence:
|
||||||
@ -1151,7 +1151,7 @@ def unbind_int(func, *args, **kwargs):
|
|||||||
lengths_sum = 0
|
lengths_sum = 0
|
||||||
ragged_dim_size = values.shape[ragged_idx - 1]
|
ragged_dim_size = values.shape[ragged_idx - 1]
|
||||||
for i in range(len(_lengths)):
|
for i in range(len(_lengths)):
|
||||||
torch._check_is_size(_lengths[i])
|
torch._check(_lengths[i] >= 0)
|
||||||
torch._check(_lengths[i] <= ragged_dim_size)
|
torch._check(_lengths[i] <= ragged_dim_size)
|
||||||
|
|
||||||
lengths_sum += _lengths[i]
|
lengths_sum += _lengths[i]
|
||||||
@ -1164,7 +1164,7 @@ def unbind_int(func, *args, **kwargs):
|
|||||||
|
|
||||||
if _offsets is not None:
|
if _offsets is not None:
|
||||||
for i in range(len(_offsets)):
|
for i in range(len(_offsets)):
|
||||||
torch._check_is_size(_offsets[i])
|
torch._check(_offsets[i] >= 0)
|
||||||
torch._check(_offsets[i] <= ragged_dim_size)
|
torch._check(_offsets[i] <= ragged_dim_size)
|
||||||
|
|
||||||
if lengths is None:
|
if lengths is None:
|
||||||
|
|||||||
Reference in New Issue
Block a user