[dynamic shapes] guard_or_false for _reshape_view_helper, utils._infer_size for wildcard dims (#150127)

For reshape/view: removes fast paths for 0 elements, checking dimensions to skip. Modifies the loop accumulating input elements, to raise a UserError if we run out of dimensions, graph breaking for compile and erroring out for export.
For infer_size: assumes if user passes us an unbacked, it's probably not -1

Will think about changes in https://docs.google.com/document/d/1WYx6EZwVDXtBnWyrzoecgGWdiK0V3XZKftfpWwQ5i3E/edit?tab=t.0#heading=h.22k54zym11qp in a later PR

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150127
Approved by: https://github.com/laithsakka
This commit is contained in:
Pian Pawakapan
2025-04-23 05:42:27 +00:00
committed by PyTorch MergeBot
parent b37fa20771
commit 54f736155b
4 changed files with 81 additions and 78 deletions

View File

@ -323,7 +323,7 @@ class TestDraftExport(TestCase):
self.assertEqual(
report.failures[0].failure_type, FailureType.DATA_DEPENDENT_ERROR
)
self.assertEqual(report.failures[0].data["expr"], "Eq(2*u1, 10)")
self.assertEqual(report.failures[0].data["expr"], "Eq(9380*u1, 0)")
def test_dedup_data_dependent_failure(self):
class M(torch.nn.Module):
@ -480,6 +480,7 @@ class TestDraftExport(TestCase):
return torch.nn.functional.linear(masked, weight, bias)
x = torch.zeros(10)
x[0] += 1
inp = (torch.randn(10, 8, 7), x, torch.randn(25, 7), torch.randn(25))
draft_ep = draft_export(M(), inp)
ep = export(M(), inp)

View File

@ -4301,7 +4301,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
def forward(self, t):
items = [t[i].item() for i in range(t.numel())]
r = torch.randn([items[0], items[1]])
# Could not guard on data-dependent expression Eq(u2, -1)
# Could not guard on data-dependent expression Ne(Mod(u1, u2), 0)
return r.view(items[0], items[2])
M = M_v0
@ -4310,9 +4310,10 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
"The following call raised this error(.*\n)+"
f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
"To fix the error, insert one of the following checks before this call.*:\n"
f".*{re.escape('torch._check(items[2] == (-1))')}.*\n"
f".*{re.escape('torch._check(items[2] != (-1))')}(.*\n)+"
f".*{re.escape('(These suggested fixes were derived by replacing `u2` with items[2] in Eq(u2, -1) and its negation.)')}",
f".*{re.escape('torch._check((items[1] % items[2]) != 0)')}.*\n"
f".*{re.escape('torch._check((items[1] % items[2]) == 0)')}(.*\n)+"
f".*{re.escape('(These suggested fixes were derived by replacing `u1` with items[1]')}"
f".*{re.escape('or r.shape[1], `u2` with items[2] in Ne(Mod(u1, u2), 0) and its negation.')}",
):
export(N(), (t,), strict=strict)
@ -4320,59 +4321,12 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
def forward(self, t):
items = [t[i].item() for i in range(t.numel())]
r = torch.randn([items[0], items[1]])
# Could not guard on data-dependent expression Eq(u2, -1)
torch._check(items[2] != -1)
# Could not guard on data-dependent expression u2 >= 0
# TODO(pianpwk): this isn't the suggested fixes.
# fix issue with % being interpreted as PythonMod instead of Mod
torch._check(items[1] == items[2])
return r.view(items[0], items[2])
M = M_v1
with self.assertRaisesRegex(
error_type,
"The following call raised this error(.*\n)+"
f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
"To fix the error, insert one of the following checks before this call.*:\n"
f".*{re.escape('You can add either: torch._check_is_size(u2) or torch._check(u2>=0) Note: torch._check_is_size(u2) could prevent data dependent errors that happen in a guard_size_oblivious(..) context by opting into guard_size_oblivious reasoning. See documentation on guard_size_oblivious for more details: https://pytorch.org/docs/stable/generated/torch.fx.experimental.symbolic_shapes.guard_size_oblivious.html')}.*\n"
f".*{re.escape('torch._check(items[2] < 0)')}(.*\n)+"
f".*{re.escape('(These suggested fixes were derived by replacing `u2` with items[2] in u2 >= 0 and its negation.)')}",
):
export(N(), (t,), strict=strict)
class M_v2(torch.nn.Module):
def forward(self, t):
items = [t[i].item() for i in range(t.numel())]
r = torch.randn([items[0], items[1]])
# Could not guard on data-dependent expression Eq(u2, -1)
torch._check(items[2] != -1)
# Could not guard on data-dependent expression u2 >= 0
torch._check(items[2] >= 0)
# Could not guard on data-dependent expression Eq(u1, u2)
return r.view(items[0], items[2])
M = M_v2
with self.assertRaisesRegex(
error_type,
"The following call raised this error(.*\n)+"
f".*{re.escape('return r.view(items[0], items[2])')}(.*\n)+"
"To fix the error, insert one of the following checks before this call.*:\n"
f".*{re.escape('torch._check(items[2] == items[1])')}.*\n"
f".*{re.escape('torch._check(items[2] != items[1])')}(.*\n)+"
f".*{re.escape('(These suggested fixes were derived by replacing `u1` with items[1] or r.shape[1], `u2` with items[2] in Eq(u2, u1) and its negation.)')}",
):
export(N(), (t,), strict=strict)
class M_v3(torch.nn.Module):
def forward(self, t):
items = [t[i].item() for i in range(t.numel())]
r = torch.randn([items[0], items[1]])
# Could not guard on data-dependent expression Eq(u2, -1)
torch._check(items[2] != -1)
# Could not guard on data-dependent expression u2 >= 0
torch._check(items[2] >= 0)
# Could not guard on data-dependent expression Eq(u1, u2)
torch._check(items[2] == r.shape[1])
return r.view(items[0], items[2])
M = M_v3
export(N(), (t,), strict=strict)
def test_suggested_fixes_for_data_dependent_errors_puzzlers(self):
@ -4484,6 +4438,29 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
fixes=[], # nothing to fix!
)
def test_simple_unbacked_view(self):
class Foo(torch.nn.Module):
def forward(self, x):
u0 = x.item()
y = torch.empty(5, u0)
return y.view(u0, 5) # [5, u0] -> [u0, 5]
ep = export(Foo(), (torch.tensor([9]),))
self.assertEqual(ep.module()(torch.tensor([8])).size(0), 8)
self.assertEqual(ep.module()(torch.tensor([5])).size(0), 5)
class Foov2(torch.nn.Module):
def forward(self, xs):
xsl = xs.tolist()
a, b = xsl
x = torch.zeros(a)
return x.reshape(b)
xs = torch.tensor([4, 4])
ep = export(Foov2(), (xs,))
self.assertEqual(ep.module()(xs).size(0), 4)
self.assertEqual(ep.module()(torch.tensor([5, 5])).size(0), 5)
def test_no_suggested_fixes_for_data_dependent_errors(self):
# suggested fixes for data-dependent errors only work in non-strict mode
strict = False
@ -7549,22 +7526,19 @@ def forward(self, b_a_buffer, x):
len([node for node in gm.graph.nodes if node.op == "placeholder"]), 2
)
def test_check_is_size_error(self):
def test_no_check_is_size_error(self):
class Module(torch.nn.Module):
def forward(self, x):
a = x.item()
# We cannot automatically infer a is a size here because view
# accepts -1
return torch.randn(24).view(a, 4)
f = Module()
if is_non_strict_test(self._testMethodName):
error = torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode
else:
error = torch._dynamo.exc.UserError
error_msg = r"Could not guard on data-dependent expression"
with self.assertRaisesRegex(error, error_msg):
_ = export(f, (torch.tensor(6),))
ep = export(f, (torch.tensor(6),))
ep.module()(torch.tensor(6))
with self.assertRaisesRegex(
RuntimeError, r"Runtime assertion failed for .* u.* 6"
):
ep.module()(torch.tensor(5))
def test_is_non_negative_check_function(self):
import sympy as sp
@ -13487,7 +13461,7 @@ def forward(self, x):
node.target == torch.ops.aten._assert_scalar.default
for node in ep.graph.nodes
].count(True)
self.assertEqual(num_asserts, 1)
self.assertEqual(num_asserts, 2)
with self.assertRaises(RuntimeError):
ep.module()(torch.randn(4, 2))

View File

@ -924,24 +924,29 @@ def infer_size(shape: ShapeType, numel: int) -> tuple[int, ...]:
Infers the size of a dim with size -1, if it exists.
Also checks that new shape is compatible with the number of elements.
"""
from torch.fx.experimental.symbolic_shapes import definitely_true, guard_or_false
dim = None
newsize = 1
for i, d in enumerate(shape):
if d == -1:
if guard_or_false(d == -1):
torch._check(dim is None, lambda: "only one dimension can be inferred")
dim = i
elif d >= 0:
newsize *= d
else:
torch._check(False, lambda: f"invalid shape dimension {d}")
torch._check(
d >= 0,
lambda: (
f"invalid shape dimension {d}. If this was symbolic, it was assumed to not be -1."
"If this was meant to be inferred, please explicitly pass in -1."
),
)
newsize *= d
if dim is None:
torch._check(
numel == newsize,
lambda: f"shape '{list(shape)}' is invalid for input of size {numel}",
)
else:
from torch.fx.experimental.symbolic_shapes import definitely_true
torch._check(
newsize != 0,
lambda: (

View File

@ -3717,7 +3717,8 @@ def repeat(a: Tensor, *repeat_shape) -> Tensor:
def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious, sym_eq
from torch._dynamo.exc import UserError, UserErrorType
from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true
# Creates a valid shape
shape = utils.extract_shape_from_varargs(shape, validate=False)
@ -3726,7 +3727,7 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL
shape = utils.infer_size(shape, a.numel())
# Special-cases tensors with no elements
if guard_size_oblivious(a.numel() == 0):
if guard_or_false(a.numel() == 0):
return as_strided(a, shape, utils.make_contiguous_strides_for(shape))
# Special-cases reshaping zero dim tensors
@ -3762,6 +3763,12 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL
return torch.as_strided(a, [dim0, dim1], [dim1, 1])
# Handles general case: a 1+D tensor reshaped into a distinct 1+D shape
shape_numel = reduce(operator.mul, shape, 1)
torch._check(
a.numel() == shape_numel,
f"Could not reshape a tensor with shape {a.shape} as a tensor with shape {shape}!",
)
deferred: list[Callable[[], bool]] = []
# NOTE [Reshape Algorithm]
# This algorithm works by attempting to greedily construct the desired dimensions in
@ -3794,16 +3801,30 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL
continue
# Skips dimensions that are already the correct length
if guard_size_oblivious(length == a_.shape[idx]):
if guard_or_false(length == a_.shape[idx]):
idx = idx + 1
continue
# Gathers enough original dimensions such that this new dimension can be created
# Note that this accumulation will terminate because we've verified a and the shape
# specify the same number of elements above
def maybe_throw_dde():
# NOTE: if you've hit a data-dependent error here, it's because in trying to accumulate input
# tensor dimensions to match the target shape (length), we've hit data-dependent errors testing
# divisibility (accum % length != 0), and have deferred raising them, in the hope that we'd
# figure out a valid reshape later in the loop.
# But we failed, either by running out of dimensions, or we couldn't figure out the strides,
# and we've decided to re-raise to either graph break out, or provide the exact guard so the user
# can torch._check() to avoid this.
for f in deferred:
f()
accum = a_.shape[idx]
end = idx
while guard_size_oblivious(accum % length != 0):
while guard_or_true(accum % length != 0):
deferred.append(lambda: bool(accum % length != 0))
if end == a_.ndim - 1:
maybe_throw_dde()
end = end + 1
accum = accum * a_.shape[end]
if end != idx:
@ -3817,13 +3838,15 @@ def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorL
if allow_copy:
return prims.reshape(a, shape)
maybe_throw_dde()
msg = f"Cannot view a tensor with shape {a.shape} and strides {a.stride()} as a tensor with shape {shape}!"
raise ValueError(msg)
a_ = flatten(a_, idx, end)
# Splits the (possibly flattened) dimension to create the desired dim length
if guard_size_oblivious(accum != length):
# Splits the (possibly flattened) dimension to create the desired dim length.
# guard_or_true is safe due to the tail unsqueeze routine.
if guard_or_true(accum != length):
a_ = prims.split_dim(a_, idx, length)
idx = idx + 1