mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix prims unbind if given dimension size is 0 (#100122)
Fixes #99832 Pull Request resolved: https://github.com/pytorch/pytorch/pull/100122 Approved by: https://github.com/ngimel
This commit is contained in:
committed by
PyTorch MergeBot
parent
2989d6c93d
commit
4c6f7cbc86
@ -1179,6 +1179,14 @@ class TestRefs(TestCase):
|
||||
self.assertEqual(actual.stride(), expect.stride())
|
||||
self.assertTrue(actual.is_contiguous())
|
||||
|
||||
def test_unbind(self):
|
||||
# If unbind returns empty tuple, it breaks some assumptions in some backward tests in test_ops.py.
|
||||
# So can't put this test into common_methods_invocations.py.
|
||||
a = torch.rand([3, 0, 4])
|
||||
actual = refs.unbind(a, 1)
|
||||
expect = torch.unbind(a, 1)
|
||||
self.assertEqual(actual, expect)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestRefs, globals())
|
||||
|
||||
|
@ -3450,9 +3450,12 @@ def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType:
|
||||
lambda: "Dimension specified as 0 but tensor has no dimensions",
|
||||
IndexError,
|
||||
)
|
||||
return tuple(
|
||||
torch.squeeze(s, dim) for s in torch.tensor_split(t, t.shape[dim], dim)
|
||||
)
|
||||
if t.shape[dim] == 0:
|
||||
return tuple()
|
||||
else:
|
||||
return tuple(
|
||||
torch.squeeze(s, dim) for s in torch.tensor_split(t, t.shape[dim], dim)
|
||||
)
|
||||
|
||||
|
||||
@out_wrapper()
|
||||
|
Reference in New Issue
Block a user