mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
remove gso from collapse_view_helper (#162212)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162212 Approved by: https://github.com/aorenste Co-authored-by: Aaron Orenstein <aorenste@fb.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
0e7ccc09db
commit
e4174b1fd7
@ -142,7 +142,7 @@ class TestPrims(TestCase):
|
||||
self.assertTrue(view._is_view())
|
||||
|
||||
t_discontig = t.transpose(0, 1)
|
||||
with self.assertRaises(ValueError, msg="no such view exists"):
|
||||
with self.assertRaises(RuntimeError, msg="Attempting to view a collapsed tensor, but no such view exists!"):
|
||||
view = prims.collapse_view(t_discontig, 0, 2)
|
||||
|
||||
copy = prims.collapse(t_discontig, 0, 1)
|
||||
|
@ -1384,12 +1384,22 @@ def _collapsed_shape(shape: ShapeType, start: int, end: int) -> tuple[int, ...]:
|
||||
return shape[0:start] + (dim_length,) + shape[end + 1 :]
|
||||
|
||||
|
||||
# If the collapse is invalid or cannot be determined (because of unbacked data)
|
||||
# then `must_be_valid` determines the behavior:
|
||||
# None: return None, None.
|
||||
# str: Do a torch._check() to ensure the collapse is valid and if it isn't
|
||||
# then fail with the provided string.
|
||||
def _collapse_view_helper(
|
||||
a: TensorLikeType, start: int, end: int
|
||||
a: TensorLikeType, start: int, end: int, must_be_valid: Optional[str]
|
||||
) -> tuple[Optional[ShapeType], Optional[StrideType]]:
|
||||
assert isinstance(a, TensorLike)
|
||||
|
||||
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
guard_or_false,
|
||||
guard_or_true,
|
||||
sym_and,
|
||||
sym_or,
|
||||
)
|
||||
|
||||
_validate_collapse_args(a, start, end)
|
||||
|
||||
@ -1404,52 +1414,69 @@ def _collapse_view_helper(
|
||||
if a.ndim == 0 or (end == start):
|
||||
return shape, strides
|
||||
|
||||
length = shape[end]
|
||||
valid_op = True
|
||||
if guard_or_false(a.numel() != 0):
|
||||
for idx in range(end - 1, start - 1, -1):
|
||||
valid_op = sym_and(
|
||||
valid_op,
|
||||
sym_or(
|
||||
shape[idx] == 1,
|
||||
shape[idx + 1] == 1,
|
||||
strides[idx] == strides[idx + 1] * shape[idx + 1],
|
||||
),
|
||||
) # type: ignore[assignment]
|
||||
|
||||
# early exit if we already know its invalid.
|
||||
if guard_or_false(valid_op is False):
|
||||
break
|
||||
|
||||
# for unbacked this become a runtime assertion.
|
||||
valid_op = sym_or(valid_op, a.numel() == 0)
|
||||
|
||||
if must_be_valid:
|
||||
torch._check(valid_op, lambda: must_be_valid)
|
||||
else:
|
||||
if not guard_or_false(valid_op):
|
||||
return None, None
|
||||
|
||||
# compute stride
|
||||
stride = strides[end]
|
||||
for idx in range(end - 1, start - 1, -1):
|
||||
if guard_size_oblivious(shape[idx] == 0) or guard_size_oblivious(
|
||||
shape[idx + 1] == 0
|
||||
):
|
||||
length = 0
|
||||
stride = 0
|
||||
break
|
||||
if shape[idx] != 1:
|
||||
# TODO with unbacked we should really exclude when shape[idx] == 1
|
||||
# something like
|
||||
# min(stride[end], torch.ite(shape[x]!=1,stride[idx], inf), ...)
|
||||
stride = min(stride, strides[idx])
|
||||
|
||||
if guard_size_oblivious(shape[idx] == 1):
|
||||
continue
|
||||
|
||||
length = length * shape[idx]
|
||||
if guard_size_oblivious(stride < strides[idx]):
|
||||
stride = stride
|
||||
else:
|
||||
stride = strides[idx]
|
||||
|
||||
if (
|
||||
guard_size_oblivious(a.numel() > 0)
|
||||
and guard_size_oblivious(shape[idx + 1] != 1)
|
||||
and not guard_size_oblivious(
|
||||
strides[idx] == strides[idx + 1] * shape[idx + 1]
|
||||
)
|
||||
):
|
||||
return None, None
|
||||
# compute length
|
||||
length = shape[end]
|
||||
if guard_or_true(length != 0):
|
||||
for idx in range(end - 1, start - 1, -1):
|
||||
if guard_or_false(shape[idx] == 0):
|
||||
length = 0
|
||||
stride = 0
|
||||
break
|
||||
length = length * shape[idx]
|
||||
else:
|
||||
stride = 0
|
||||
|
||||
new_shape = shape[:start] + (length,) + shape[end + 1 :]
|
||||
new_strides = strides[:start] + (stride,) + strides[end + 1 :]
|
||||
|
||||
# NOTE: when the input has no elements it's restrided as if it were contiguous
|
||||
if guard_size_oblivious(a.numel() == 0):
|
||||
# except for unbacked.
|
||||
if guard_or_false(a.numel() == 0):
|
||||
new_strides = utils.make_contiguous_strides_for(new_shape)
|
||||
|
||||
return new_shape, new_strides
|
||||
|
||||
|
||||
def _collapse_view_meta(a: TensorLikeType, start: int, end: int) -> TensorLikeType:
|
||||
new_shape, new_strides = _collapse_view_helper(a, start, end)
|
||||
|
||||
if new_shape is None:
|
||||
msg = "Attempting to view a collapsed tensor, but no such view exists!"
|
||||
raise ValueError(msg)
|
||||
|
||||
new_shape, new_strides = _collapse_view_helper(
|
||||
a, start, end, "Attempting to view a collapsed tensor, but no such view exists!"
|
||||
)
|
||||
assert new_strides is not None
|
||||
assert new_shape is not None
|
||||
return a.as_strided(new_shape, new_strides, a.storage_offset())
|
||||
|
||||
|
||||
|
@ -3132,7 +3132,10 @@ def flatten(a: TensorLikeType, start_dim: int = 0, end_dim: int = -1) -> TensorL
|
||||
|
||||
# Tries to take a view
|
||||
# TODO: we could look at directing collapse_view to skip its meta function here (unsafe_collapse_view)
|
||||
new_shape, _new_strides = prims._collapse_view_helper(a, start_dim, end_dim)
|
||||
# Unbacked semnatics: if validty of in-place flattening is undecided we copy.
|
||||
new_shape, _new_strides = prims._collapse_view_helper(
|
||||
a, start_dim, end_dim, must_be_valid=None
|
||||
)
|
||||
if new_shape is not None:
|
||||
return prims.collapse_view(a, start_dim, end_dim)
|
||||
|
||||
@ -3840,7 +3843,9 @@ def _reshape_view_helper_core_alg(
|
||||
# may return a view of a copy
|
||||
|
||||
# Checks if collapse can be a view and short-circuits to copying reshape if it can't
|
||||
new_shape, _new_strides = prims._collapse_view_helper(a_, idx, end)
|
||||
new_shape, _new_strides = prims._collapse_view_helper(
|
||||
a_, idx, end, must_be_valid=None
|
||||
)
|
||||
if new_shape is None:
|
||||
if allow_copy:
|
||||
return prims.reshape(a, shape)
|
||||
|
Reference in New Issue
Block a user