Add the bound check for flatten with out_dim (#120894)

Fixes #120762

The bound is not valid in the example but unchecked.
```
a = torch.tensor([1, 2, 3])
a.flatten(start_dim=0, end_dim=1, out_dim='a')
```

The same is checked for the case

```
a = torch.tensor([1, 2, 3])
a.flatten(start_dim=0, end_dim=1)
```

- Therefore, just apply the same check.

@malfet @janeyx99
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120894
Approved by: https://github.com/malfet, https://github.com/spzala
This commit is contained in:
lancerts
2024-03-02 03:56:52 +00:00
committed by PyTorch MergeBot
parent 06fe6ed82b
commit 2d9efad38f
2 changed files with 19 additions and 0 deletions

View File

@ -3419,6 +3419,10 @@ Tensor flatten(const Tensor& self, int64_t start_dim, int64_t end_dim) {
}
Tensor flatten(const Tensor& self, int64_t start_dim, int64_t end_dim, Dimname out_dim) {
start_dim = maybe_wrap_dim(start_dim, self.dim());
end_dim = maybe_wrap_dim(end_dim, self.dim());
TORCH_CHECK(start_dim <= end_dim, "flatten() has invalid args: start_dim cannot come after end_dim");
auto outnames = self.names().vec();
outnames.erase(outnames.begin() + start_dim, outnames.begin() + end_dim + 1);
outnames.insert(outnames.begin() + start_dim, out_dim);

View File

@ -1079,6 +1079,21 @@ class TestNamedTensor(TestCase):
with self.assertRaisesRegex(RuntimeError, "cannot be empty"):
tensor.flatten((), 'abcd')
def test_flatten_index_error(self):
tensor = torch.randn(1, 2)
with self.assertRaisesRegex(IndexError,
r"Dimension out of range \(expected to be in range of \[-2, 1\], but got 2\)"):
tensor.flatten(0, 2)
with self.assertRaisesRegex(IndexError,
r"Dimension out of range \(expected to be in range of \[-2, 1\], but got 2\)"):
tensor.flatten(0, 2, 'N')
with self.assertRaisesRegex(RuntimeError,
r"flatten\(\) has invalid args: start_dim cannot come after end_dim"):
tensor.flatten(1, 0)
with self.assertRaisesRegex(RuntimeError,
r"flatten\(\) has invalid args: start_dim cannot come after end_dim"):
tensor.flatten(1, 0, 'N')
def test_unflatten(self):
# test args: tensor, int, namedshape
self.assertTrue(torch.equal(