Fixed segfault when trying to permute empty tensor (#116335)

Fixes #116325.

Fixed unchecked access to first element of `dims` when permuting an empty tensor. Added test to prevent regressions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116335
Approved by: https://github.com/Skylion007
This commit is contained in:
Tobias Ringwald
2023-12-23 23:14:28 +00:00
committed by PyTorch MergeBot
parent 015bd0e0a1
commit 3a4fe835cc
2 changed files with 6 additions and 2 deletions

View File

@ -952,6 +952,11 @@ class TestSparse(TestSparseBase):
s.permute(dims=(1, 0))
with self.assertRaisesRegex(RuntimeError, "duplicate dims"):
s.permute(dims=(1, 1, 1))
# Calling permute on a sparse tensor with an empty tuple used to segfault,
# see https://github.com/pytorch/pytorch/issues/116325
x = torch.rand((), device=device, dtype=dtype).to_sparse()
x.permute(())
self.assertEqual(len(x.values()), 1)
def test_shape(sparse_dims, nnz, with_size):
ndim = len(with_size)