mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20)
[primTorch] Make prims.collapse a real prim (#91748)
`prims.collapse` is currently just a plain Python function wrapping `prims.reshape`. This turns it into a real prim, and also factors out some of the code duplicated with `_collapse_view_aten`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91748
Approved by: https://github.com/lezcano, https://github.com/ngimel
committed by PyTorch MergeBot
parent 0d2e91573e
commit 2622adb980
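Before the diff, a quick sketch of the behavior the new prim provides (not part of the commit; the import path assumes a build that includes this change): `prims.collapse` merges the half-open span of dimensions [start, end) into one and returns a copy, while the existing `prims.collapse_view` returns a view and raises when no view exists.

import torch
from torch._prims import collapse, collapse_view

t = torch.rand(2, 2, 2)

# collapse merges dims [1, 3) into one and returns a new tensor (a copy)
c = collapse(t, 1, 3)
assert c.shape == (2, 4) and not c._is_view()

# collapse_view returns an actual view when the input's layout permits one
v = collapse_view(t, 1, 3)
assert v.shape == (2, 4) and v._is_view()

# No view of a transposed (discontiguous) span exists: collapse_view
# raises ValueError, while collapse still succeeds by copying
try:
    collapse_view(t.transpose(0, 1), 0, 2)
except ValueError:
    pass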
@@ -140,6 +140,36 @@ class TestPrims(TestCase):
         self.assertEqual(y, y_np, exact_device=False)

+    @dtypes(torch.float32)
+    def test_collapse(self, device, dtype):
+        t = torch.rand(2, 2, 2)
+        dim_ranges = [(0, 1), (0, 2), (1, 3), (0, 3)]
+        expected_shapes = [(2, 2, 2), (4, 2), (2, 4), (8,)]
+
+        for (start, end), shape in zip(dim_ranges, expected_shapes):
+            expect = t.reshape(shape)
+
+            copy = prims.collapse(t, start, end)
+            self.assertEqual(copy, expect)
+            self.assertFalse(copy._is_view())
+
+            view = prims.collapse_view(t, start, end)
+            self.assertEqual(view, expect)
+            self.assertTrue(view._is_view())
+
+        t_discontig = t.transpose(0, 1)
+        with self.assertRaises(ValueError, msg="no such view exists"):
+            view = prims.collapse_view(t_discontig, 0, 2)
+
+        copy = prims.collapse(t_discontig, 0, 2)
+        self.assertEqual(copy, t_discontig.reshape(4, 2))
+
+        error_dims = [(-1, 2), (0, 4), (1, 0)]
+        for start, end in error_dims:
+            for fn in [prims.collapse, prims.collapse_view]:
+                with self.assertRaises(AssertionError):
+                    fn(t, start, end)
+
     @onlyCUDA
     def test_nvfuser_impl_is_used(self, device):
         # This test is to ensure that when the nvfuser implementation exists it is used
@@ -1277,11 +1277,42 @@ broadcast_in_dim = _make_prim(
 )


+def _validate_collapse_args(a: Tensor, start: int, end: int) -> None:
+    # Special-case for zero dimensional tensors
+    ndim = max(1, a.dim())
+    utils.validate_idx(ndim, start)
+    utils.validate_exclusive_idx(ndim, end)
+
+    # Verifies end is strictly greater than start
+    # (Collapse requires a non-empty interval)
+    utils.check(
+        end > start,
+        lambda: f"Attempting to collapse but end, {end}, is less than or equal to start, {start}!",
+        ValueError,
+    )
+
+
+def _collapsed_shape(shape: ShapeType, start: int, end: int) -> Tuple[int, ...]:
+    """
+    Returns the shape of a with dims in [start, end) merged into a single dimension.
+    """
+    # Special-case for zero dimensional tensors
+    shape = (1,) if len(shape) == 0 else tuple(shape)
+
+    dim_length = 1
+    for idx in range(start, end):
+        dim_length = dim_length * shape[idx]
+
+    return shape[0:start] + (dim_length,) + shape[end:]
+
+
 def _collapse_view_helper(
     a: TensorLikeType, start: int, end: int
 ) -> Tuple[Optional[ShapeType], Optional[StrideType]]:
     assert isinstance(a, TensorLike)
+
+    _validate_collapse_args(a, start, end)

     # Special-case for zero dimensional tensors
     if a.ndim == 0:
         shape = (1,)
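As a worked example of the `_collapsed_shape` helper factored out above, here is a standalone copy of its logic (renamed `collapsed_shape`, for illustration only) with the shapes it produces:

from typing import Tuple

def collapsed_shape(shape: Tuple[int, ...], start: int, end: int) -> Tuple[int, ...]:
    # Zero-dimensional inputs are treated as shape (1,)
    shape = (1,) if len(shape) == 0 else tuple(shape)
    # The merged dimension's length is the product of the span's lengths
    dim_length = 1
    for idx in range(start, end):
        dim_length = dim_length * shape[idx]
    return shape[0:start] + (dim_length,) + shape[end:]

assert collapsed_shape((2, 3, 4), 0, 2) == (6, 4)    # merge dims 0 and 1
assert collapsed_shape((2, 3, 4), 1, 3) == (2, 12)   # merge dims 1 and 2
assert collapsed_shape((2, 3, 4), 0, 3) == (24,)     # merge everything
assert collapsed_shape((), 0, 1) == (1,)             # zero-dim special case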
@@ -1290,17 +1321,6 @@ def _collapse_view_helper(
         shape = a.shape  # type: ignore[assignment]
         strides = a.stride()  # type: ignore[assignment]

-    utils.validate_idx(len(shape), start)
-    utils.validate_exclusive_idx(len(shape), end)
-
-    # Verifies end is strictly greater than start
-    # (Collapse requires a non-empty interval)
-    if end <= start:
-        msg = "Attempting to collapse but end, {0}, is less than or equal to start, {1}!".format(
-            end, start
-        )
-        raise ValueError(msg)
-
     if a.ndim == 0 or (end - 1 == start):
         return shape, strides

@@ -1342,25 +1362,12 @@ def _collapse_view_meta(a: TensorLikeType, start: int, end: int) -> TensorLikeTy
         msg = "Attempting to view a collapsed tensor, but no such view exists!"
         raise ValueError(msg)

-    if new_strides is None:
-        return a.view(new_shape)
-    else:
-        return a.as_strided(new_shape, new_strides, a.storage_offset())
+    assert new_strides is not None
+    return a.as_strided(new_shape, new_strides, a.storage_offset())


 def _collapse_view_aten(a: Tensor, start: int, end: int) -> Tensor:
-    # Special-cases zero-dim tensors
-    if a.ndim == 0:
-        shape = (1,)
-    else:
-        shape = a.shape  # type: ignore[assignment]
-
-    dim_length = 1
-    for idx in range(start, end):
-        dim_length = dim_length * shape[idx]
-
-    new_shape = shape[0:start] + (dim_length,) + shape[end:]
-
+    new_shape = _collapsed_shape(a.shape, start, end)
     return a.view(new_shape)

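A quick illustration (not from the diff) of the condition `_collapse_view_helper` guards against: a collapse is only expressible as a view when the span's strides already describe one flat block, which is the same situation in which `Tensor.view` succeeds.

import torch

t = torch.rand(2, 2, 2)

# Contiguous dims collapse to a view: storage is shared, nothing is copied
v = t.view(4, 2)
assert v.data_ptr() == t.data_ptr()

# After a transpose, dims 0 and 1 no longer form one flat block, so no
# view exists; Tensor.view raises, just as collapse_view does above
try:
    t.transpose(0, 1).view(4, 2)
except RuntimeError:
    pass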
@@ -1839,19 +1846,35 @@ as_strided_scatter = _make_prim(
 #
 # Shape operations
 #
-def collapse(a: Tensor, start: int, end: int) -> Tensor:
-    """
-    Wrapper around reshape that collapses a span of dimensions.
-
-    See collapse_view for the corresponding view operation.
-    """
-
-    dim_length = 1
-    for idx in range(start, end):
-        dim_length = dim_length * a.shape[idx]
-
-    new_shape = a.shape[0:start] + (dim_length,) + a.shape[end:]
-
-    return reshape(a, new_shape)
+def _collapse_meta(a: Tensor, start: int, end: int) -> Tensor:
+    # Special-case for zero dimensional tensors
+    _validate_collapse_args(a, start, end)
+    new_shape = _collapsed_shape(a.shape, start, end)
+    return a.new_empty(new_shape)
+
+
+def _collapse_aten(a: Tensor, start: int, end: int) -> Tensor:
+    new_shape = _collapsed_shape(a.shape, start, end)
+    out = a.new_empty(new_shape)
+    with torch.no_grad():
+        out.view_as(a).copy_(a)
+    return out
+
+
+_collapse_doc = """
+Collapse a span of neighboring dimensions into one.
+
+See collapse_view for the corresponding view operation.
+"""
+collapse = _make_prim(
+    schema="collapse(Tensor a, int start, int end) -> Tensor",
+    meta=_collapse_meta,
+    impl_aten=_collapse_aten,
+    return_type=RETURN_TYPE.NEW,
+    doc=_collapse_doc,
+)


 # TODO: review stride logic
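For reference, a standalone re-creation of the `_collapse_aten` strategy above (illustrative only; `collapse_by_copy` is a made-up name): allocate the already-collapsed output, then fill it by copying the input into a view of the output shaped like the input. Because the freshly allocated output is contiguous, that view always exists, even for discontiguous inputs.

import torch

def collapse_by_copy(a: torch.Tensor, start: int, end: int) -> torch.Tensor:
    dim_length = 1
    for idx in range(start, end):
        dim_length = dim_length * a.shape[idx]
    new_shape = a.shape[:start] + (dim_length,) + a.shape[end:]
    out = a.new_empty(new_shape)
    with torch.no_grad():
        out.view_as(a).copy_(a)  # out is contiguous, so this view always exists
    return out

t = torch.rand(2, 3, 4).transpose(0, 1)  # discontiguous input is fine
assert torch.equal(collapse_by_copy(t, 0, 2), t.reshape(6, 4))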