mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
disallow empty named dims list to flatten(names, name) (#61953)
Summary: Fixes https://github.com/pytorch/pytorch/issues/61137 by raising an error if an empty tuple is passed in for the names: ``` >>> torch.empty((2, 3), names=['a', 'b']).flatten((), 'abc') RuntimeError: flatten(tensor, dims, out_dim): dims cannot be empty ``` or from the original issue: ``` >>> torch.empty((2, 3)).flatten((), 'abc') RuntimeError: flatten(tensor, dims, out_dim): dims cannot be empty ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/61953 Reviewed By: iramazanli Differential Revision: D30574571 Pulled By: malfet fbshipit-source-id: e606e84458a8dd66e5da6d0eb1a260f37b4ce91b
This commit is contained in:
committed by
Facebook GitHub Bot
parent
c59970db6b
commit
6bb4b5d150
@ -1072,6 +1072,11 @@ class TestNamedTensor(TestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, "must be consecutive in"):
|
||||
tensor.flatten(['H', 'D', 'W'], 'features')
|
||||
|
||||
def test_flatten_nodims(self):
|
||||
tensor = torch.empty((2, 3))
|
||||
with self.assertRaisesRegex(RuntimeError, "cannot be empty"):
|
||||
tensor.flatten((), 'abcd')
|
||||
|
||||
def test_unflatten(self):
|
||||
# test args: tensor, int, namedshape
|
||||
self.assertTrue(torch.equal(
|
||||
|
Reference in New Issue
Block a user