disallow empty named dims list to flatten(names, name) (#61953)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/61137 by raising an error if an empty tuple is passed in for the names:
```
>>> torch.empty((2, 3), names=['a', 'b']).flatten((), 'abc')
RuntimeError: flatten(tensor, dims, out_dim): dims cannot be empty
```

or from the original issue:
```
>>> torch.empty((2, 3)).flatten((), 'abc')
RuntimeError: flatten(tensor, dims, out_dim): dims cannot be empty
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61953

Reviewed By: iramazanli

Differential Revision: D30574571

Pulled By: malfet

fbshipit-source-id: e606e84458a8dd66e5da6d0eb1a260f37b4ce91b
This commit is contained in:
Matti Picus
2021-08-31 18:54:44 -07:00
committed by Facebook GitHub Bot
parent c59970db6b
commit 6bb4b5d150
2 changed files with 7 additions and 0 deletions

View File

@ -1072,6 +1072,11 @@ class TestNamedTensor(TestCase):
with self.assertRaisesRegex(RuntimeError, "must be consecutive in"):
tensor.flatten(['H', 'D', 'W'], 'features')
def test_flatten_nodims(self):
tensor = torch.empty((2, 3))
with self.assertRaisesRegex(RuntimeError, "cannot be empty"):
tensor.flatten((), 'abcd')
def test_unflatten(self):
# test args: tensor, int, namedshape
self.assertTrue(torch.equal(