mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add size param check of unfold (#139965)
Fixes #76617 Changes: - Add check of input `size` value, give user friendly hint message - fix `FIXME: move to shape ops test suite` in test file Before ```python import torch x = torch.arange(1., 8) x.unfold(0, -1, 1) Traceback (most recent call last): File "/home/zong/code/unfold.py", line 12, in <module> x.unfold(0, -1, 1) RuntimeError: Storage size calculation overflowed with sizes=[9, -1] and strides=[1, 1] ``` After ```python import torch x = torch.arange(1., 8) x.unfold(0, -1, 1) Traceback (most recent call last): File "/home/zong/code/pytorch/../unfold.py", line 12, in <module> x.unfold(0, -1, 1) RuntimeError: size is -1 but must be >= 0 ``` Test Result: ```bash pytest test/test_shape_ops.py ```  ```bash $ lintrunner ```  Pull Request resolved: https://github.com/pytorch/pytorch/pull/139965 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
f89b2b9630
commit
5ef33e40b3
@ -805,6 +805,31 @@ class TestShapeOps(TestCase):
|
||||
self.assertEqual(x.sparse_dim(), 0)
|
||||
self.assertEqual(x.dense_dim(), len(shape))
|
||||
|
||||
def test_unfold_all_devices_and_dtypes(self, device):
|
||||
for dt in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16):
|
||||
if dt == torch.bool:
|
||||
x = torch.empty((0, 1, 3, 0), dtype=dt, device=device)
|
||||
self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape)
|
||||
else:
|
||||
x = torch.empty((0, 1, 3, 0), dtype=dt, device=device)
|
||||
self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape)
|
||||
|
||||
def test_unfold_scalars(self, device):
|
||||
x = torch.tensor(0.5, device=device)
|
||||
# unfold on a 0-dimensional tensor should always return a 1-d dimensional
|
||||
# tensor of shape [size] (i.e., the second parameter to unfold)
|
||||
|
||||
self.assertEqual(torch.empty(0, device=device), x.unfold(0, 0, 1))
|
||||
self.assertEqual(torch.empty(0, device=device), x.unfold(0, 0, 2))
|
||||
self.assertEqual(torch.tensor([0.5], device=device), x.unfold(0, 1, 1))
|
||||
|
||||
def test_unfold_errors(self, device):
|
||||
x = torch.arange(1.0, 8, device=device)
|
||||
with self.assertRaisesRegex(RuntimeError, "size is -1 but must be >= 0"):
|
||||
x.unfold(0, -1, 1)
|
||||
with self.assertRaisesRegex(RuntimeError, "step is -1 but must be > 0"):
|
||||
x.unfold(0, 1, -1)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestShapeOps, globals())
|
||||
|
||||
|
Reference in New Issue
Block a user