Add size param check of unfold (#139965)

Fixes #76617

Changes:

- Add a check of the input `size` value and give a user-friendly hint message
- Fix the `FIXME: move to shape ops test suite` note in the test file

Before
```python
import torch
x = torch.arange(1., 8)
x.unfold(0, -1, 1)

Traceback (most recent call last):
  File "/home/zong/code/unfold.py", line 12, in <module>
    x.unfold(0, -1, 1)
RuntimeError: Storage size calculation overflowed with sizes=[9, -1] and strides=[1, 1]

```

After
```python
import torch
x = torch.arange(1., 8)
x.unfold(0, -1, 1)

Traceback (most recent call last):
  File "/home/zong/code/pytorch/../unfold.py", line 12, in <module>
    x.unfold(0, -1, 1)
RuntimeError: size is -1 but must be >= 0
```

Test Result:
```bash
pytest test/test_shape_ops.py
```

![image](https://github.com/user-attachments/assets/d7bcef62-04e6-4187-9c8f-bc5220ff6c33)

```bash
$ lintrunner
```

![image](https://github.com/user-attachments/assets/6b48d095-5c8a-4e75-9957-dc22d39a73bb)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139965
Approved by: https://github.com/ezyang
This commit is contained in:
zeshengzong
2024-11-09 17:12:50 +00:00
committed by PyTorch MergeBot
parent f89b2b9630
commit 5ef33e40b3
3 changed files with 26 additions and 21 deletions

View File

@ -3792,6 +3792,7 @@ Tensor unfold(const Tensor& self, int64_t d, int64_t size, int64_t step) {
auto sizes = self.sizes().vec();
auto strides = self.strides().vec();
int64_t max_size = self.dim() == 0 ? 1 : sizes[d];
TORCH_CHECK(size >= 0, "size is ", size, " but must be >= 0");
TORCH_CHECK(size <= max_size, "maximum size for tensor at dimension ", d,
" is ", max_size, " but size is ", size);
TORCH_CHECK(step > 0, "step is ", step, " but must be > 0");

View File

@ -805,6 +805,31 @@ class TestShapeOps(TestCase):
self.assertEqual(x.sparse_dim(), 0)
self.assertEqual(x.dense_dim(), len(shape))
def test_unfold_all_devices_and_dtypes(self, device):
for dt in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16):
if dt == torch.bool:
x = torch.empty((0, 1, 3, 0), dtype=dt, device=device)
self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape)
else:
x = torch.empty((0, 1, 3, 0), dtype=dt, device=device)
self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape)
def test_unfold_scalars(self, device):
x = torch.tensor(0.5, device=device)
# unfold on a 0-dimensional tensor should always return a 1-d dimensional
# tensor of shape [size] (i.e., the second parameter to unfold)
self.assertEqual(torch.empty(0, device=device), x.unfold(0, 0, 1))
self.assertEqual(torch.empty(0, device=device), x.unfold(0, 0, 2))
self.assertEqual(torch.tensor([0.5], device=device), x.unfold(0, 1, 1))
def test_unfold_errors(self, device):
x = torch.arange(1.0, 8, device=device)
with self.assertRaisesRegex(RuntimeError, "size is -1 but must be >= 0"):
x.unfold(0, -1, 1)
with self.assertRaisesRegex(RuntimeError, "step is -1 but must be > 0"):
x.unfold(0, 1, -1)
# Generate per-device variants (e.g. TestShapeOpsCPU / TestShapeOpsCUDA) of
# every test in TestShapeOps and register them in this module's namespace.
instantiate_device_type_tests(TestShapeOps, globals())

View File

@ -3139,27 +3139,6 @@ else:
x[1] = True
self.assertEqual(x, torch.tensor([False, True], dtype=torch.bool, device=device))
# FIXME: move to shape ops test suite
# Pre-move copy being deleted by this commit (relocated to test_shape_ops.py).
# Checks unfold's output shape on an empty (0-sized) tensor for every dtype.
def test_unfold_all_devices_and_dtypes(self, device):
# NOTE(review): both branches below are byte-identical — the dtype split is
# dead duplication carried over from an older version of this test.
for dt in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16):
if dt == torch.bool:
x = torch.empty((0, 1, 3, 0), dtype=dt, device=device)
self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape)
else:
x = torch.empty((0, 1, 3, 0), dtype=dt, device=device)
self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape)
# FIXME: move to shape ops test suite
# Pre-move copy being deleted by this commit (relocated to test_shape_ops.py).
def test_unfold_scalars(self, device):
x = torch.tensor(0.5, device=device)
# unfold on a 0-dimensional tensor should always return a 1-d dimensional
# tensor of shape [size] (i.e., the second parameter to unfold)
self.assertEqual(torch.empty(0, device=device), x.unfold(0, 0, 1))
self.assertEqual(torch.empty(0, device=device), x.unfold(0, 0, 2))
self.assertEqual(torch.tensor([0.5], device=device), x.unfold(0, 1, 1))
# FIXME: move to data movement test suite
def test_copy_all_dtypes_and_devices(self, device):
from copy import copy