[primTorch] Make prims.collapse a real prim (#91748)

`prims.collapse` is currently just a plain python function wrapping
`prims.reshape`. This turns it into a real prim, and also factors out some of
the code duplicated with `_collapse_view_aten`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91748
Approved by: https://github.com/lezcano, https://github.com/ngimel
This commit is contained in:
Peter Bell
2023-02-21 17:28:25 +00:00
committed by PyTorch MergeBot
parent 0d2e91573e
commit 2622adb980
2 changed files with 90 additions and 37 deletions

View File

@ -140,6 +140,36 @@ class TestPrims(TestCase):
self.assertEqual(y, y_np, exact_device=False)
@dtypes(torch.float32)
def test_collapse(self, device, dtype):
t = torch.rand(2, 2, 2)
dim_ranges = [(0, 1), (0, 2), (1, 3), (0, 3)]
expected_shapes = [(2, 2, 2), (4, 2), (2, 4), (8,)]
for (start, end), shape in zip(dim_ranges, expected_shapes):
expect = t.reshape(shape)
copy = prims.collapse(t, start, end)
self.assertEqual(copy, expect)
self.assertFalse(copy._is_view())
view = prims.collapse_view(t, start, end)
self.assertEqual(view, expect)
self.assertTrue(view._is_view())
t_discontig = t.transpose(0, 1)
with self.assertRaises(ValueError, msg="no such view exists"):
view = prims.collapse_view(t_discontig, 0, 2)
copy = prims.collapse(t_discontig, 0, 2)
self.assertEqual(copy, t_discontig.reshape(4, 2))
error_dims = [(-1, 2), (0, 4), (1, 0)]
for start, end in error_dims:
for fn in [prims.collapse, prims.collapse_view]:
with self.assertRaises(AssertionError):
fn(t, start, end)
@onlyCUDA
def test_nvfuser_impl_is_used(self, device):
# This test is to ensure that when the nvfuser implementation exists it is used