mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix overflow in checkInBoundsForStorage (#147352)
Use `computeStorageNbytes` (which checks for overflows) to include the computation re the storage_offset Pull Request resolved: https://github.com/pytorch/pytorch/pull/147352 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
6ccbff1450
commit
e64441915f
@ -87,14 +87,15 @@ inline void checkInBoundsForStorage(
|
||||
const Storage& new_storage) {
|
||||
T storage_size_bytes =
|
||||
at::detail::computeStorageNbytes(size, stride, data_type.itemsize());
|
||||
T storage_offset_bytes = storage_offset * data_type.itemsize();
|
||||
if (storage_size_bytes == 0) {
|
||||
// NB: (a tensor with arbitrary 0 dims)'s storage can have any numel.
|
||||
return;
|
||||
}
|
||||
T storage_size_plus_offset_bytes = at::detail::computeStorageNbytes(
|
||||
size, stride, data_type.itemsize(), storage_offset);
|
||||
T new_storage_size_bytes = maybe_convert_symint<T>(new_storage.sym_nbytes());
|
||||
TORCH_CHECK(
|
||||
storage_size_bytes + storage_offset_bytes <= new_storage_size_bytes,
|
||||
storage_size_plus_offset_bytes <= new_storage_size_bytes,
|
||||
"setStorage: sizes ",
|
||||
size,
|
||||
", strides ",
|
||||
@ -105,7 +106,7 @@ inline void checkInBoundsForStorage(
|
||||
", and itemsize ",
|
||||
data_type.itemsize(),
|
||||
" requiring a storage size of ",
|
||||
storage_size_bytes + storage_offset_bytes,
|
||||
storage_size_plus_offset_bytes,
|
||||
" are out of bounds for storage of size ",
|
||||
new_storage_size_bytes);
|
||||
}
|
||||
|
@ -1995,6 +1995,18 @@ class TestOldViewOps(TestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, "Stride calculation overflowed"):
|
||||
x.resize_([0, 4, 2305843009213693952])
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
def test_as_strided_overflow_storage_offset(self, device):
|
||||
t = torch.randn(2, 3, device=device)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Storage size calculation overflowed"
|
||||
):
|
||||
torch.as_strided(t, [1], [1], 2**63 - 1)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Storage size calculation overflowed"
|
||||
):
|
||||
torch.as_strided(t, [1], [1], 2**61 - 1)
|
||||
|
||||
def test_view_all_dtypes_and_devices(self, device):
|
||||
for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
|
||||
x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
|
||||
|
Reference in New Issue
Block a user