mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Catch overflows in calculating storage byte size
Fixes #73184. In the issue, the output tensor's shape is `[2, 4, 536870912, 536870912]`, which yields a `numel()` slightly below the int64 overflow point. When the storage is created, the computation `numel() * 8` overflows, so a much smaller storage is allocated than required. Pull Request resolved: https://github.com/pytorch/pytorch/pull/73719 Approved by: https://github.com/ezyang, https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
40bf3cfeb7
commit
13a3e5c70c
@ -1795,6 +1795,14 @@ class TestOldViewOps(TestCase):
|
||||
x.resize_as_(y)
|
||||
self.assertEqual(y.shape, x.shape)
|
||||
|
||||
@onlyNativeDeviceTypes
def test_resize_overflow(self, device):
    """resize_ must raise instead of silently overflowing the storage-size math.

    Regression test for pytorch/pytorch#73184: a shape whose numel() is just
    below int64 overflow could make numel() * itemsize wrap around, allocating
    a storage far smaller than required.
    """
    scalar = torch.empty((), dtype=torch.float64)
    # First shape: numel() fits in int64 but numel() * 8 bytes overflows.
    # Second shape: numel() itself overflows int64.
    overflow_cases = [
        ([2, 4, 2**29, 2**29], 'Storage size calculation overflowed'),
        ([8, 8, 2**29, 2**29], 'overflow'),
    ]
    for shape, pattern in overflow_cases:
        with self.assertRaisesRegex(RuntimeError, pattern):
            scalar.resize_(shape)
|
||||
|
||||
def test_view_all_dtypes_and_devices(self, device):
|
||||
for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
|
||||
x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
|
||||
|
Reference in New Issue
Block a user