Fix prims unbind if given dimension size is 0 (#100122)

Fixes #99832

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100122
Approved by: https://github.com/ngimel
This commit is contained in:
Yanbo Liang
2023-04-26 23:40:21 +00:00
committed by PyTorch MergeBot
parent 2989d6c93d
commit 4c6f7cbc86
2 changed files with 14 additions and 3 deletions

View File

@ -1179,6 +1179,14 @@ class TestRefs(TestCase):
self.assertEqual(actual.stride(), expect.stride())
self.assertTrue(actual.is_contiguous())
def test_unbind(self):
# If unbind returns empty tuple, it breaks some assumptions in some backward tests in test_ops.py.
# So can't put this test into common_methods_invocations.py.
a = torch.rand([3, 0, 4])
actual = refs.unbind(a, 1)
expect = torch.unbind(a, 1)
self.assertEqual(actual, expect)
instantiate_device_type_tests(TestRefs, globals())

View File

@ -3450,9 +3450,12 @@ def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType:
lambda: "Dimension specified as 0 but tensor has no dimensions",
IndexError,
)
return tuple(
torch.squeeze(s, dim) for s in torch.tensor_split(t, t.shape[dim], dim)
)
if t.shape[dim] == 0:
return tuple()
else:
return tuple(
torch.squeeze(s, dim) for s in torch.tensor_split(t, t.shape[dim], dim)
)
@out_wrapper()