mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[primTorch] Make prims.collapse
a real prim (#91748)
`prims.collapse` is currently just a plain python function wrapping `prims.reshape`. This turns it into a real prim, and also factors out some of the code duplicated with `_collapse_view_aten`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/91748 Approved by: https://github.com/lezcano, https://github.com/ngimel
This commit is contained in:
committed by
PyTorch MergeBot
parent
0d2e91573e
commit
2622adb980
@ -140,6 +140,36 @@ class TestPrims(TestCase):
|
||||
|
||||
self.assertEqual(y, y_np, exact_device=False)
|
||||
|
||||
@dtypes(torch.float32)
|
||||
def test_collapse(self, device, dtype):
|
||||
t = torch.rand(2, 2, 2)
|
||||
dim_ranges = [(0, 1), (0, 2), (1, 3), (0, 3)]
|
||||
expected_shapes = [(2, 2, 2), (4, 2), (2, 4), (8,)]
|
||||
|
||||
for (start, end), shape in zip(dim_ranges, expected_shapes):
|
||||
expect = t.reshape(shape)
|
||||
|
||||
copy = prims.collapse(t, start, end)
|
||||
self.assertEqual(copy, expect)
|
||||
self.assertFalse(copy._is_view())
|
||||
|
||||
view = prims.collapse_view(t, start, end)
|
||||
self.assertEqual(view, expect)
|
||||
self.assertTrue(view._is_view())
|
||||
|
||||
t_discontig = t.transpose(0, 1)
|
||||
with self.assertRaises(ValueError, msg="no such view exists"):
|
||||
view = prims.collapse_view(t_discontig, 0, 2)
|
||||
|
||||
copy = prims.collapse(t_discontig, 0, 2)
|
||||
self.assertEqual(copy, t_discontig.reshape(4, 2))
|
||||
|
||||
error_dims = [(-1, 2), (0, 4), (1, 0)]
|
||||
for start, end in error_dims:
|
||||
for fn in [prims.collapse, prims.collapse_view]:
|
||||
with self.assertRaises(AssertionError):
|
||||
fn(t, start, end)
|
||||
|
||||
@onlyCUDA
|
||||
def test_nvfuser_impl_is_used(self, device):
|
||||
# This test is to ensure that when the nvfuser implementation exists it is used
|
||||
|
Reference in New Issue
Block a user