remove gso from collapse_view_helper (#162212)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162212
Approved by: https://github.com/aorenste

Co-authored-by: Aaron Orenstein <aorenste@fb.com>
This commit is contained in:
Laith Sakka
2025-09-09 11:56:04 -07:00
committed by PyTorch MergeBot
parent 0e7ccc09db
commit e4174b1fd7
3 changed files with 68 additions and 36 deletions

View File

@ -142,7 +142,7 @@ class TestPrims(TestCase):
self.assertTrue(view._is_view())
t_discontig = t.transpose(0, 1)
with self.assertRaises(ValueError, msg="no such view exists"):
with self.assertRaises(RuntimeError, msg="Attempting to view a collapsed tensor, but no such view exists!"):
view = prims.collapse_view(t_discontig, 0, 2)
copy = prims.collapse(t_discontig, 0, 1)

View File

@ -1384,12 +1384,22 @@ def _collapsed_shape(shape: ShapeType, start: int, end: int) -> tuple[int, ...]:
return shape[0:start] + (dim_length,) + shape[end + 1 :]
# If the collapse is invalid or cannot be determined (because of unbacked data)
# then `must_be_valid` determines the behavior:
# None: return None, None.
# str: Do a torch._check() to ensure the collapse is valid and if it isn't
# then fail with the provided string.
def _collapse_view_helper(
a: TensorLikeType, start: int, end: int
a: TensorLikeType, start: int, end: int, must_be_valid: Optional[str]
) -> tuple[Optional[ShapeType], Optional[StrideType]]:
assert isinstance(a, TensorLike)
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
from torch.fx.experimental.symbolic_shapes import (
guard_or_false,
guard_or_true,
sym_and,
sym_or,
)
_validate_collapse_args(a, start, end)
@ -1404,52 +1414,69 @@ def _collapse_view_helper(
if a.ndim == 0 or (end == start):
return shape, strides
length = shape[end]
valid_op = True
if guard_or_false(a.numel() != 0):
for idx in range(end - 1, start - 1, -1):
valid_op = sym_and(
valid_op,
sym_or(
shape[idx] == 1,
shape[idx + 1] == 1,
strides[idx] == strides[idx + 1] * shape[idx + 1],
),
) # type: ignore[assignment]
# early exit if we already know its invalid.
if guard_or_false(valid_op is False):
break
# for unbacked this become a runtime assertion.
valid_op = sym_or(valid_op, a.numel() == 0)
if must_be_valid:
torch._check(valid_op, lambda: must_be_valid)
else:
if not guard_or_false(valid_op):
return None, None
# compute stride
stride = strides[end]
for idx in range(end - 1, start - 1, -1):
if guard_size_oblivious(shape[idx] == 0) or guard_size_oblivious(
shape[idx + 1] == 0
):
length = 0
stride = 0
break
if shape[idx] != 1:
# TODO with unbacked we should really exclude when shape[idx] == 1
# something like
# min(stride[end], torch.ite(shape[x]!=1,stride[idx], inf), ...)
stride = min(stride, strides[idx])
if guard_size_oblivious(shape[idx] == 1):
continue
length = length * shape[idx]
if guard_size_oblivious(stride < strides[idx]):
stride = stride
else:
stride = strides[idx]
if (
guard_size_oblivious(a.numel() > 0)
and guard_size_oblivious(shape[idx + 1] != 1)
and not guard_size_oblivious(
strides[idx] == strides[idx + 1] * shape[idx + 1]
)
):
return None, None
# compute length
length = shape[end]
if guard_or_true(length != 0):
for idx in range(end - 1, start - 1, -1):
if guard_or_false(shape[idx] == 0):
length = 0
stride = 0
break
length = length * shape[idx]
else:
stride = 0
new_shape = shape[:start] + (length,) + shape[end + 1 :]
new_strides = strides[:start] + (stride,) + strides[end + 1 :]
# NOTE: when the input has no elements it's restrided as if it were contiguous
if guard_size_oblivious(a.numel() == 0):
# except for unbacked.
if guard_or_false(a.numel() == 0):
new_strides = utils.make_contiguous_strides_for(new_shape)
return new_shape, new_strides
def _collapse_view_meta(a: TensorLikeType, start: int, end: int) -> TensorLikeType:
new_shape, new_strides = _collapse_view_helper(a, start, end)
if new_shape is None:
msg = "Attempting to view a collapsed tensor, but no such view exists!"
raise ValueError(msg)
new_shape, new_strides = _collapse_view_helper(
a, start, end, "Attempting to view a collapsed tensor, but no such view exists!"
)
assert new_strides is not None
assert new_shape is not None
return a.as_strided(new_shape, new_strides, a.storage_offset())

View File

@ -3132,7 +3132,10 @@ def flatten(a: TensorLikeType, start_dim: int = 0, end_dim: int = -1) -> TensorL
# Tries to take a view
# TODO: we could look at directing collapse_view to skip its meta function here (unsafe_collapse_view)
new_shape, _new_strides = prims._collapse_view_helper(a, start_dim, end_dim)
# Unbacked semnatics: if validty of in-place flattening is undecided we copy.
new_shape, _new_strides = prims._collapse_view_helper(
a, start_dim, end_dim, must_be_valid=None
)
if new_shape is not None:
return prims.collapse_view(a, start_dim, end_dim)
@ -3840,7 +3843,9 @@ def _reshape_view_helper_core_alg(
# may return a view of a copy
# Checks if collapse can be a view and short-circuits to copying reshape if it can't
new_shape, _new_strides = prims._collapse_view_helper(a_, idx, end)
new_shape, _new_strides = prims._collapse_view_helper(
a_, idx, end, must_be_valid=None
)
if new_shape is None:
if allow_copy:
return prims.reshape(a, shape)