Catch overflows in calculating storage byte size

Fixes #73184

In the issue the output tensor's shape is `[2, 4, 536870912, 536870912]` which results in a `numel()` slightly below the point of overflow. When the storage is created it does `numel() * 8` which overflows and a much smaller storage is allocated than required.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73719
Approved by: https://github.com/ezyang, https://github.com/malfet
This commit is contained in:
Peter Bell
2022-03-31 16:16:03 +00:00
committed by PyTorch MergeBot
parent 40bf3cfeb7
commit 13a3e5c70c
8 changed files with 210 additions and 51 deletions

View File

@ -1795,6 +1795,14 @@ class TestOldViewOps(TestCase):
x.resize_as_(y)
self.assertEqual(y.shape, x.shape)
@onlyNativeDeviceTypes
def test_resize_overflow(self, device):
x = torch.empty((), dtype=torch.float64)
with self.assertRaisesRegex(RuntimeError, 'Storage size calculation overflowed'):
x.resize_([2, 4, 2**29, 2**29])
with self.assertRaisesRegex(RuntimeError, 'overflow'):
x.resize_([8, 8, 2**29, 2**29])
def test_view_all_dtypes_and_devices(self, device):
for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)