[dynamic shapes] guard_or_false for _reshape_view_helper, utils._infer_size for wildcard dims (#150127)

For reshape/view: removes the fast paths for 0 elements and for checking which dimensions to skip. Modifies the loop that accumulates input elements so it raises a UserError if we run out of dimensions, graph-breaking for compile and erroring out for export.
For infer_size: assumes that if the user passes us an unbacked symbol, it is probably not -1.
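The wildcard-inference behavior described above can be sketched in plain Python. This is a simplified, illustrative stand-in for `utils.infer_size` (no symbolic shapes, `assert` in place of `torch._check`); `infer_size_sketch` is a hypothetical name, not the actual implementation:

```python
def infer_size_sketch(shape, numel):
    # Resolve a single -1 wildcard dim so that the product of the
    # dims equals numel; error on multiple wildcards or a mismatch.
    dim = None
    newsize = 1
    for i, d in enumerate(shape):
        if d == -1:
            assert dim is None, "only one dimension can be inferred"
            dim = i
        else:
            assert d >= 0, f"invalid shape dimension {d}"
            newsize *= d
    if dim is None:
        assert numel == newsize, (
            f"shape '{list(shape)}' is invalid for input of size {numel}"
        )
        return tuple(shape)
    # Infer the wildcard: remaining elements divided by the known dims.
    assert newsize != 0 and numel % newsize == 0
    out = list(shape)
    out[dim] = numel // newsize
    return tuple(out)

print(infer_size_sketch([-1, 4], 12))  # (3, 4)
```

With an unbacked symbolic `d`, the new code cannot decide `d == -1`, so it assumes the dim is not a wildcard and falls through to the `d >= 0` check instead.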

Will think about changes in https://docs.google.com/document/d/1WYx6EZwVDXtBnWyrzoecgGWdiK0V3XZKftfpWwQ5i3E/edit?tab=t.0#heading=h.22k54zym11qp in a later PR

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150127
Approved by: https://github.com/laithsakka
Author: Pian Pawakapan
Date: 2025-04-23 05:42:27 +00:00
Committed by: PyTorch MergeBot
Parent: b37fa20771
Commit: 54f736155b
4 changed files with 81 additions and 78 deletions


@@ -924,24 +924,29 @@ def infer_size(shape: ShapeType, numel: int) -> tuple[int, ...]:
     Infers the size of a dim with size -1, if it exists.
     Also checks that new shape is compatible with the number of elements.
     """
+    from torch.fx.experimental.symbolic_shapes import definitely_true, guard_or_false
     dim = None
     newsize = 1
     for i, d in enumerate(shape):
-        if d == -1:
+        if guard_or_false(d == -1):
             torch._check(dim is None, lambda: "only one dimension can be inferred")
             dim = i
-        elif d >= 0:
-            newsize *= d
         else:
-            torch._check(False, lambda: f"invalid shape dimension {d}")
+            torch._check(
+                d >= 0,
+                lambda: (
+                    f"invalid shape dimension {d}. If this was symbolic, it was assumed to not be -1."
+                    "If this was meant to be inferred, please explicitly pass in -1."
+                ),
+            )
+            newsize *= d
     if dim is None:
         torch._check(
             numel == newsize,
             lambda: f"shape '{list(shape)}' is invalid for input of size {numel}",
         )
     else:
-        from torch.fx.experimental.symbolic_shapes import definitely_true
         torch._check(
             newsize != 0,
             lambda: (
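The key change above is replacing the plain `d == -1` comparison with `guard_or_false(d == -1)`. A toy model of that semantics (the `Unbacked` class and this `guard_or_false` are illustrative stand-ins, not the real `torch.fx.experimental.symbolic_shapes` implementation, which operates on SymPy expressions):

```python
class Unbacked:
    """Toy stand-in for an unbacked SymInt: comparisons are undecidable."""

    def __eq__(self, other):
        return None  # "don't know" -- no hint value to fall back on


def guard_or_false(cond):
    # Toy model of guard_or_false: if the condition can be decided,
    # return its truth value; otherwise assume False rather than
    # installing a guard (i.e. specializing) on the symbol.
    return bool(cond) if cond is not None else False


print(guard_or_false(5 == -1))            # False: decidable, condition false
print(guard_or_false(-1 == -1))           # True: decidable, condition true
print(guard_or_false(Unbacked() == -1))   # False: undecidable, assumed not -1
```

This is why an unbacked dim is assumed not to be -1: the comparison cannot be decided without guarding, so the code takes the non-wildcard branch, and the improved error message tells users to pass -1 explicitly if inference was intended.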