[primTorch] Make prims.collapse a real prim (#91748)

`prims.collapse` is currently just a plain python function wrapping
`prims.reshape`. This turns it into a real prim, and also factors out some of
the code duplicated with `_collapse_view_aten`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91748
Approved by: https://github.com/lezcano, https://github.com/ngimel
This commit is contained in:
Peter Bell
2023-02-21 17:28:25 +00:00
committed by PyTorch MergeBot
parent 0d2e91573e
commit 2622adb980
2 changed files with 90 additions and 37 deletions

View File

@ -140,6 +140,36 @@ class TestPrims(TestCase):
self.assertEqual(y, y_np, exact_device=False) self.assertEqual(y, y_np, exact_device=False)
@dtypes(torch.float32)
def test_collapse(self, device, dtype):
t = torch.rand(2, 2, 2)
dim_ranges = [(0, 1), (0, 2), (1, 3), (0, 3)]
expected_shapes = [(2, 2, 2), (4, 2), (2, 4), (8,)]
for (start, end), shape in zip(dim_ranges, expected_shapes):
expect = t.reshape(shape)
copy = prims.collapse(t, start, end)
self.assertEqual(copy, expect)
self.assertFalse(copy._is_view())
view = prims.collapse_view(t, start, end)
self.assertEqual(view, expect)
self.assertTrue(view._is_view())
t_discontig = t.transpose(0, 1)
with self.assertRaises(ValueError, msg="no such view exists"):
view = prims.collapse_view(t_discontig, 0, 2)
copy = prims.collapse(t_discontig, 0, 2)
self.assertEqual(copy, t_discontig.reshape(4, 2))
error_dims = [(-1, 2), (0, 4), (1, 0)]
for start, end in error_dims:
for fn in [prims.collapse, prims.collapse_view]:
with self.assertRaises(AssertionError):
fn(t, start, end)
@onlyCUDA @onlyCUDA
def test_nvfuser_impl_is_used(self, device): def test_nvfuser_impl_is_used(self, device):
# This test is to ensure that when the nvfuser implementation exists it is used # This test is to ensure that when the nvfuser implementation exists it is used

View File

@ -1277,11 +1277,42 @@ broadcast_in_dim = _make_prim(
) )
def _validate_collapse_args(a: Tensor, start: int, end: int) -> None:
# Special-case for zero dimensional tensors
ndim = max(1, a.dim())
utils.validate_idx(ndim, start)
utils.validate_exclusive_idx(ndim, end)
# Verifies end is strictly greater than start
# (Collapse requires a non-empty interval)
utils.check(
end > start,
lambda: f"Attempting to collapse but end, {end}, is less than or equal to start, {start}!",
ValueError,
)
def _collapsed_shape(shape: ShapeType, start: int, end: int) -> Tuple[int, ...]:
"""
Returns the shape of a with dims in [start, end) merged into a single dimension.
"""
# Special-case for zero dimensional tensors
shape = (1,) if len(shape) == 0 else tuple(shape)
dim_length = 1
for idx in range(start, end):
dim_length = dim_length * shape[idx]
return shape[0:start] + (dim_length,) + shape[end:]
def _collapse_view_helper( def _collapse_view_helper(
a: TensorLikeType, start: int, end: int a: TensorLikeType, start: int, end: int
) -> Tuple[Optional[ShapeType], Optional[StrideType]]: ) -> Tuple[Optional[ShapeType], Optional[StrideType]]:
assert isinstance(a, TensorLike) assert isinstance(a, TensorLike)
_validate_collapse_args(a, start, end)
# Special-case for zero dimensional tensors # Special-case for zero dimensional tensors
if a.ndim == 0: if a.ndim == 0:
shape = (1,) shape = (1,)
@ -1290,17 +1321,6 @@ def _collapse_view_helper(
shape = a.shape # type: ignore[assignment] shape = a.shape # type: ignore[assignment]
strides = a.stride() # type: ignore[assignment] strides = a.stride() # type: ignore[assignment]
utils.validate_idx(len(shape), start)
utils.validate_exclusive_idx(len(shape), end)
# Verifies end is strictly greater than start
# (Collapse requires a non-empty interval)
if end <= start:
msg = "Attempting to collapse but end, {0}, is less than or equal to start, {1}!".format(
end, start
)
raise ValueError(msg)
if a.ndim == 0 or (end - 1 == start): if a.ndim == 0 or (end - 1 == start):
return shape, strides return shape, strides
@ -1342,25 +1362,12 @@ def _collapse_view_meta(a: TensorLikeType, start: int, end: int) -> TensorLikeTy
msg = "Attempting to view a collapsed tensor, but no such view exists!" msg = "Attempting to view a collapsed tensor, but no such view exists!"
raise ValueError(msg) raise ValueError(msg)
if new_strides is None: assert new_strides is not None
return a.view(new_shape) return a.as_strided(new_shape, new_strides, a.storage_offset())
else:
return a.as_strided(new_shape, new_strides, a.storage_offset())
def _collapse_view_aten(a: Tensor, start: int, end: int) -> Tensor: def _collapse_view_aten(a: Tensor, start: int, end: int) -> Tensor:
# Special-cases zero-dim tensors new_shape = _collapsed_shape(a.shape, start, end)
if a.ndim == 0:
shape = (1,)
else:
shape = a.shape # type: ignore[assignment]
dim_length = 1
for idx in range(start, end):
dim_length = dim_length * shape[idx]
new_shape = shape[0:start] + (dim_length,) + shape[end:]
return a.view(new_shape) return a.view(new_shape)
@ -1839,19 +1846,35 @@ as_strided_scatter = _make_prim(
# #
# Shape operations # Shape operations
# #
def collapse(a: Tensor, start: int, end: int) -> Tensor:
"""
Wrapper around reshape that collapses a span of dimensions.
See collapse_view for the corresponding view operation.
"""
dim_length = 1 def _collapse_meta(a: Tensor, start: int, end: int) -> Tensor:
for idx in range(start, end): # Special-case for zero dimensional tensors
dim_length = dim_length * a.shape[idx] _validate_collapse_args(a, start, end)
new_shape = _collapsed_shape(a.shape, start, end)
return a.new_empty(new_shape)
new_shape = a.shape[0:start] + (dim_length,) + a.shape[end:]
return reshape(a, new_shape) def _collapse_aten(a: Tensor, start: int, end: int) -> Tensor:
new_shape = _collapsed_shape(a.shape, start, end)
out = a.new_empty(new_shape)
with torch.no_grad():
out.view_as(a).copy_(a)
return out
_collapse_doc = """
Collapse a span of neighboring dimensions into one.
See collapse_view for the corresponding view operation.
"""
collapse = _make_prim(
schema="collapse(Tensor a, int start, int end) -> Tensor",
meta=_collapse_meta,
impl_aten=_collapse_aten,
return_type=RETURN_TYPE.NEW,
doc=_collapse_doc,
)
# TODO: review stride logic # TODO: review stride logic