mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamic shapes] guard_or_false for _reshape_view_helper, utils._infer_size for wildcard dims (#150127)
For reshape/view: removes fast paths for 0 elements, checking dimensions to skip. Modifies the loop accumulating input elements, to raise a UserError if we run out of dimensions, graph breaking for compile and erroring out for export. For infer_size: assumes if user passes us an unbacked, it's probably not -1 Will think about changes in https://docs.google.com/document/d/1WYx6EZwVDXtBnWyrzoecgGWdiK0V3XZKftfpWwQ5i3E/edit?tab=t.0#heading=h.22k54zym11qp in a later PR Pull Request resolved: https://github.com/pytorch/pytorch/pull/150127 Approved by: https://github.com/laithsakka
This commit is contained in:
committed by
PyTorch MergeBot
parent
b37fa20771
commit
54f736155b
@ -924,24 +924,29 @@ def infer_size(shape: ShapeType, numel: int) -> tuple[int, ...]:
|
||||
Infers the size of a dim with size -1, if it exists.
|
||||
Also checks that new shape is compatible with the number of elements.
|
||||
"""
|
||||
from torch.fx.experimental.symbolic_shapes import definitely_true, guard_or_false
|
||||
|
||||
dim = None
|
||||
newsize = 1
|
||||
for i, d in enumerate(shape):
|
||||
if d == -1:
|
||||
if guard_or_false(d == -1):
|
||||
torch._check(dim is None, lambda: "only one dimension can be inferred")
|
||||
dim = i
|
||||
elif d >= 0:
|
||||
newsize *= d
|
||||
else:
|
||||
torch._check(False, lambda: f"invalid shape dimension {d}")
|
||||
torch._check(
|
||||
d >= 0,
|
||||
lambda: (
|
||||
f"invalid shape dimension {d}. If this was symbolic, it was assumed to not be -1."
|
||||
"If this was meant to be inferred, please explicitly pass in -1."
|
||||
),
|
||||
)
|
||||
newsize *= d
|
||||
if dim is None:
|
||||
torch._check(
|
||||
numel == newsize,
|
||||
lambda: f"shape '{list(shape)}' is invalid for input of size {numel}",
|
||||
)
|
||||
else:
|
||||
from torch.fx.experimental.symbolic_shapes import definitely_true
|
||||
|
||||
torch._check(
|
||||
newsize != 0,
|
||||
lambda: (
|
||||
|
Reference in New Issue
Block a user