mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add the bound check for flatten with out_dim (#120894)
Fixes #120762 The bound is not valid in the example but unchecked. ``` a = torch.tensor([1, 2, 3]) a.flatten(start_dim=0, end_dim=1, out_dim='a') ``` The same is checked for the case ``` a = torch.tensor([1, 2, 3]) a.flatten(start_dim=0, end_dim=1) ``` - Therefore, just apply the same check. @malfet @janeyx99 Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/120894 Approved by: https://github.com/malfet, https://github.com/spzala
This commit is contained in:
committed by
PyTorch MergeBot
parent
06fe6ed82b
commit
2d9efad38f
@ -3419,6 +3419,10 @@ Tensor flatten(const Tensor& self, int64_t start_dim, int64_t end_dim) {
|
||||
}
|
||||
|
||||
Tensor flatten(const Tensor& self, int64_t start_dim, int64_t end_dim, Dimname out_dim) {
|
||||
start_dim = maybe_wrap_dim(start_dim, self.dim());
|
||||
end_dim = maybe_wrap_dim(end_dim, self.dim());
|
||||
TORCH_CHECK(start_dim <= end_dim, "flatten() has invalid args: start_dim cannot come after end_dim");
|
||||
|
||||
auto outnames = self.names().vec();
|
||||
outnames.erase(outnames.begin() + start_dim, outnames.begin() + end_dim + 1);
|
||||
outnames.insert(outnames.begin() + start_dim, out_dim);
|
||||
|
@ -1079,6 +1079,21 @@ class TestNamedTensor(TestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, "cannot be empty"):
|
||||
tensor.flatten((), 'abcd')
|
||||
|
||||
def test_flatten_index_error(self):
|
||||
tensor = torch.randn(1, 2)
|
||||
with self.assertRaisesRegex(IndexError,
|
||||
r"Dimension out of range \(expected to be in range of \[-2, 1\], but got 2\)"):
|
||||
tensor.flatten(0, 2)
|
||||
with self.assertRaisesRegex(IndexError,
|
||||
r"Dimension out of range \(expected to be in range of \[-2, 1\], but got 2\)"):
|
||||
tensor.flatten(0, 2, 'N')
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
r"flatten\(\) has invalid args: start_dim cannot come after end_dim"):
|
||||
tensor.flatten(1, 0)
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
r"flatten\(\) has invalid args: start_dim cannot come after end_dim"):
|
||||
tensor.flatten(1, 0, 'N')
|
||||
|
||||
def test_unflatten(self):
|
||||
# test args: tensor, int, namedshape
|
||||
self.assertTrue(torch.equal(
|
||||
|
Reference in New Issue
Block a user