mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[dynamic shapes] guard_or_false for _reshape_view_helper, utils._infer_size for wildcard dims (#150127)
For reshape/view: removes fast paths for 0 elements, checking dimensions to skip. Modifies the loop accumulating input elements, to raise a UserError if we run out of dimensions, graph breaking for compile and erroring out for export. For infer_size: assumes if user passes us an unbacked, it's probably not -1 Will think about changes in https://docs.google.com/document/d/1WYx6EZwVDXtBnWyrzoecgGWdiK0V3XZKftfpWwQ5i3E/edit?tab=t.0#heading=h.22k54zym11qp in a later PR Pull Request resolved: https://github.com/pytorch/pytorch/pull/150127 Approved by: https://github.com/laithsakka
This commit is contained in:
committed by
PyTorch MergeBot
parent
b37fa20771
commit
54f736155b
@ -323,7 +323,7 @@ class TestDraftExport(TestCase):
|
||||
self.assertEqual(
|
||||
report.failures[0].failure_type, FailureType.DATA_DEPENDENT_ERROR
|
||||
)
|
||||
self.assertEqual(report.failures[0].data["expr"], "Eq(2*u1, 10)")
|
||||
self.assertEqual(report.failures[0].data["expr"], "Eq(9380*u1, 0)")
|
||||
|
||||
def test_dedup_data_dependent_failure(self):
|
||||
class M(torch.nn.Module):
|
||||
@ -480,6 +480,7 @@ class TestDraftExport(TestCase):
|
||||
return torch.nn.functional.linear(masked, weight, bias)
|
||||
|
||||
x = torch.zeros(10)
|
||||
x[0] += 1
|
||||
inp = (torch.randn(10, 8, 7), x, torch.randn(25, 7), torch.randn(25))
|
||||
draft_ep = draft_export(M(), inp)
|
||||
ep = export(M(), inp)
|
||||
|
@ -4301,7 +4301,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
def forward(self, t):
|
||||
items = [t[i].item() for i in range(t.numel())]
|
||||
r = torch.randn([items[0], items[1]])
|
||||
# Could not guard on data-dependent expression Eq(u2, -1)
|
||||
# Could not guard on data-dependent expression Ne(Mod(u1, u2), 0)
|
||||
return r.view(items[0], items[2])
|
||||
|
||||
M = M_v0
|
||||
@ -4310,9 +4310,10 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
"The following call raised this error(.*\n)+"
|
||||
f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
|
||||
"To fix the error, insert one of the following checks before this call.*:\n"
|
||||
f".*{re.escape('torch._check(items[2] == (-1))')}.*\n"
|
||||
f".*{re.escape('torch._check(items[2] != (-1))')}(.*\n)+"
|
||||
f".*{re.escape('(These suggested fixes were derived by replacing `u2` with items[2] in Eq(u2, -1) and its negation.)')}",
|
||||
f".*{re.escape('torch._check((items[1] % items[2]) != 0)')}.*\n"
|
||||
f".*{re.escape('torch._check((items[1] % items[2]) == 0)')}(.*\n)+"
|
||||
f".*{re.escape('(These suggested fixes were derived by replacing `u1` with items[1]')}"
|
||||
f".*{re.escape('or r.shape[1], `u2` with items[2] in Ne(Mod(u1, u2), 0) and its negation.')}",
|
||||
):
|
||||
export(N(), (t,), strict=strict)
|
||||
|
||||
@ -4320,59 +4321,12 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
def forward(self, t):
|
||||
items = [t[i].item() for i in range(t.numel())]
|
||||
r = torch.randn([items[0], items[1]])
|
||||
# Could not guard on data-dependent expression Eq(u2, -1)
|
||||
torch._check(items[2] != -1)
|
||||
# Could not guard on data-dependent expression u2 >= 0
|
||||
# TODO(pianpwk): this isn't the suggested fixes.
|
||||
# fix issue with % being interpreted as PythonMod instead of Mod
|
||||
torch._check(items[1] == items[2])
|
||||
return r.view(items[0], items[2])
|
||||
|
||||
M = M_v1
|
||||
with self.assertRaisesRegex(
|
||||
error_type,
|
||||
"The following call raised this error(.*\n)+"
|
||||
f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
|
||||
"To fix the error, insert one of the following checks before this call.*:\n"
|
||||
f".*{re.escape('You can add either: torch._check_is_size(u2) or torch._check(u2>=0) Note: torch._check_is_size(u2) could prevent data dependent errors that happen in a guard_size_oblivious(..) context by opting into guard_size_oblivious reasoning. See documentation on guard_size_oblivious for more details: https://pytorch.org/docs/stable/generated/torch.fx.experimental.symbolic_shapes.guard_size_oblivious.html')}.*\n"
|
||||
f".*{re.escape('torch._check(items[2] < 0)')}(.*\n)+"
|
||||
f".*{re.escape('(These suggested fixes were derived by replacing `u2` with items[2] in u2 >= 0 and its negation.)')}",
|
||||
):
|
||||
export(N(), (t,), strict=strict)
|
||||
|
||||
class M_v2(torch.nn.Module):
|
||||
def forward(self, t):
|
||||
items = [t[i].item() for i in range(t.numel())]
|
||||
r = torch.randn([items[0], items[1]])
|
||||
# Could not guard on data-dependent expression Eq(u2, -1)
|
||||
torch._check(items[2] != -1)
|
||||
# Could not guard on data-dependent expression u2 >= 0
|
||||
torch._check(items[2] >= 0)
|
||||
# Could not guard on data-dependent expression Eq(u1, u2)
|
||||
return r.view(items[0], items[2])
|
||||
|
||||
M = M_v2
|
||||
with self.assertRaisesRegex(
|
||||
error_type,
|
||||
"The following call raised this error(.*\n)+"
|
||||
f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
|
||||
"To fix the error, insert one of the following checks before this call.*:\n"
|
||||
f".*{re.escape('torch._check(items[2] == items[1])')}.*\n"
|
||||
f".*{re.escape('torch._check(items[2] != items[1])')}(.*\n)+"
|
||||
f".*{re.escape('(These suggested fixes were derived by replacing `u1` with items[1] or r.shape[1], `u2` with items[2] in Eq(u2, u1) and its negation.)')}",
|
||||
):
|
||||
export(N(), (t,), strict=strict)
|
||||
|
||||
class M_v3(torch.nn.Module):
|
||||
def forward(self, t):
|
||||
items = [t[i].item() for i in range(t.numel())]
|
||||
r = torch.randn([items[0], items[1]])
|
||||
# Could not guard on data-dependent expression Eq(u2, -1)
|
||||
torch._check(items[2] != -1)
|
||||
# Could not guard on data-dependent expression u2 >= 0
|
||||
torch._check(items[2] >= 0)
|
||||
# Could not guard on data-dependent expression Eq(u1, u2)
|
||||
torch._check(items[2] == r.shape[1])
|
||||
return r.view(items[0], items[2])
|
||||
|
||||
M = M_v3
|
||||
export(N(), (t,), strict=strict)
|
||||
|
||||
def test_suggested_fixes_for_data_dependent_errors_puzzlers(self):
|
||||
@ -4484,6 +4438,29 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
fixes=[], # nothing to fix!
|
||||
)
|
||||
|
||||
def test_simple_unbacked_view(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
u0 = x.item()
|
||||
y = torch.empty(5, u0)
|
||||
return y.view(u0, 5) # [5, u0] -> [u0, 5]
|
||||
|
||||
ep = export(Foo(), (torch.tensor([9]),))
|
||||
self.assertEqual(ep.module()(torch.tensor([8])).size(0), 8)
|
||||
self.assertEqual(ep.module()(torch.tensor([5])).size(0), 5)
|
||||
|
||||
class Foov2(torch.nn.Module):
|
||||
def forward(self, xs):
|
||||
xsl = xs.tolist()
|
||||
a, b = xsl
|
||||
x = torch.zeros(a)
|
||||
return x.reshape(b)
|
||||
|
||||
xs = torch.tensor([4, 4])
|
||||
ep = export(Foov2(), (xs,))
|
||||
self.assertEqual(ep.module()(xs).size(0), 4)
|
||||
self.assertEqual(ep.module()(torch.tensor([5, 5])).size(0), 5)
|
||||
|
||||
def test_no_suggested_fixes_for_data_dependent_errors(self):
|
||||
# suggested fixes for data-dependent errors only work in non-strict mode
|
||||
strict = False
|
||||
@ -7549,22 +7526,19 @@ def forward(self, b_a_buffer, x):
|
||||
len([node for node in gm.graph.nodes if node.op == "placeholder"]), 2
|
||||
)
|
||||
|
||||
def test_check_is_size_error(self):
|
||||
def test_no_check_is_size_error(self):
|
||||
class Module(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
a = x.item()
|
||||
# We cannot automatically infer a is a size here because view
|
||||
# accepts -1
|
||||
return torch.randn(24).view(a, 4)
|
||||
|
||||
f = Module()
|
||||
if is_non_strict_test(self._testMethodName):
|
||||
error = torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode
|
||||
else:
|
||||
error = torch._dynamo.exc.UserError
|
||||
error_msg = r"Could not guard on data-dependent expression"
|
||||
with self.assertRaisesRegex(error, error_msg):
|
||||
_ = export(f, (torch.tensor(6),))
|
||||
ep = export(f, (torch.tensor(6),))
|
||||
ep.module()(torch.tensor(6))
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, r"Runtime assertion failed for .* u.* 6"
|
||||
):
|
||||
ep.module()(torch.tensor(5))
|
||||
|
||||
def test_is_non_negative_check_function(self):
|
||||
import sympy as sp
|
||||
@ -13487,7 +13461,7 @@ def forward(self, x):
|
||||
node.target == torch.ops.aten._assert_scalar.default
|
||||
for node in ep.graph.nodes
|
||||
].count(True)
|
||||
self.assertEqual(num_asserts, 1)
|
||||
self.assertEqual(num_asserts, 2)
|
||||
with self.assertRaises(RuntimeError):
|
||||
ep.module()(torch.randn(4, 2))
|
||||
|
||||
|
@ -924,24 +924,29 @@ def infer_size(shape: ShapeType, numel: int) -> tuple[int, ...]:
|
||||
Infers the size of a dim with size -1, if it exists.
|
||||
Also checks that new shape is compatible with the number of elements.
|
||||
"""
|
||||
from torch.fx.experimental.symbolic_shapes import definitely_true, guard_or_false
|
||||
|
||||
dim = None
|
||||
newsize = 1
|
||||
for i, d in enumerate(shape):
|
||||
if d == -1:
|
||||
if guard_or_false(d == -1):
|
||||
torch._check(dim is None, lambda: "only one dimension can be inferred")
|
||||
dim = i
|
||||
elif d >= 0:
|
||||
newsize *= d
|
||||
else:
|
||||
torch._check(False, lambda: f"invalid shape dimension {d}")
|
||||
torch._check(
|
||||
d >= 0,
|
||||
lambda: (
|
||||
f"invalid shape dimension {d}. If this was symbolic, it was assumed to not be -1."
|
||||
"If this was meant to be inferred, please explicitly pass in -1."
|
||||
),
|
||||
)
|
||||
newsize *= d
|
||||
if dim is None:
|
||||
torch._check(
|
||||
numel == newsize,
|
||||
lambda: f"shape '{list(shape)}' is invalid for input of size {numel}",
|
||||
)
|
||||
else:
|
||||
from torch.fx.experimental.symbolic_shapes import definitely_true
|
||||
|
||||
torch._check(
|
||||
newsize != 0,
|
||||
lambda: (
|
||||
|
@ -3717,7 +3717,8 @@ def repeat(a: Tensor, *repeat_shape) -> Tensor:
|
||||
|
||||
|
||||
def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
|
||||
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious, sym_eq
|
||||
from torch._dynamo.exc import UserError, UserErrorType
|
||||
from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true
|
||||
|
||||
# Creates a valid shape
|
||||
shape = utils.extract_shape_from_varargs(shape, validate=False)
|
||||
@ -3726,7 +3727,7 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL
|
||||
shape = utils.infer_size(shape, a.numel())
|
||||
|
||||
# Special-cases tensors with no elements
|
||||
if guard_size_oblivious(a.numel() == 0):
|
||||
if guard_or_false(a.numel() == 0):
|
||||
return as_strided(a, shape, utils.make_contiguous_strides_for(shape))
|
||||
|
||||
# Special-cases reshaping zero dim tensors
|
||||
@ -3762,6 +3763,12 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL
|
||||
return torch.as_strided(a, [dim0, dim1], [dim1, 1])
|
||||
|
||||
# Handles general case: a 1+D tensor reshaped into a distinct 1+D shape
|
||||
shape_numel = reduce(operator.mul, shape, 1)
|
||||
torch._check(
|
||||
a.numel() == shape_numel,
|
||||
f"Could not reshape a tensor with shape {a.shape} as a tensor with shape {shape}!",
|
||||
)
|
||||
deferred: list[Callable[[], bool]] = []
|
||||
|
||||
# NOTE [Reshape Algorithm]
|
||||
# This algorithm works by attempting to greedily construct the desired dimensions in
|
||||
@ -3794,16 +3801,30 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL
|
||||
continue
|
||||
|
||||
# Skips dimensions that are already the correct length
|
||||
if guard_size_oblivious(length == a_.shape[idx]):
|
||||
if guard_or_false(length == a_.shape[idx]):
|
||||
idx = idx + 1
|
||||
continue
|
||||
|
||||
# Gathers enough original dimensions such that this new dimension can be created
|
||||
# Note that this accumulation will terminate because we've verified a and the shape
|
||||
# specify the same number of elements above
|
||||
def maybe_throw_dde():
|
||||
# NOTE: if you've hit a data-dependent error here, it's because in trying to accumulate input
|
||||
# tensor dimensions to match the target shape (length), we've hit data-dependent errors testing
|
||||
# divisibility (accum % length != 0), and have deferred raising them, in the hope that we'd
|
||||
# figure out a valid reshape later in the loop.
|
||||
# But we failed, either by running out of dimensions, or we couldn't figure out the strides,
|
||||
# and we've decided to re-raise to either graph break out, or provide the exact guard so the user
|
||||
# can torch._check() to avoid this.
|
||||
for f in deferred:
|
||||
f()
|
||||
|
||||
accum = a_.shape[idx]
|
||||
end = idx
|
||||
while guard_size_oblivious(accum % length != 0):
|
||||
while guard_or_true(accum % length != 0):
|
||||
deferred.append(lambda: bool(accum % length != 0))
|
||||
if end == a_.ndim - 1:
|
||||
maybe_throw_dde()
|
||||
end = end + 1
|
||||
accum = accum * a_.shape[end]
|
||||
if end != idx:
|
||||
@ -3817,13 +3838,15 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL
|
||||
if allow_copy:
|
||||
return prims.reshape(a, shape)
|
||||
|
||||
maybe_throw_dde()
|
||||
msg = f"Cannot view a tensor with shape {a.shape} and strides {a.stride()} as a tensor with shape {shape}!"
|
||||
raise ValueError(msg)
|
||||
|
||||
a_ = flatten(a_, idx, end)
|
||||
|
||||
# Splits the (possibly flattened) dimension to create the desired dim length
|
||||
if guard_size_oblivious(accum != length):
|
||||
# Splits the (possibly flattened) dimension to create the desired dim length.
|
||||
# guard_or_true is safe due to the tail unsqueeze routine.
|
||||
if guard_or_true(accum != length):
|
||||
a_ = prims.split_dim(a_, idx, length)
|
||||
|
||||
idx = idx + 1
|
||||
|
Reference in New Issue
Block a user