Fix overflow in checkInBoundsForStorage (#147352)
Use `computeStorageNbytes` (which checks for overflows) to include the `storage_offset` contribution in the storage-size computation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147352
Approved by: https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: 6ccbff1450
Commit: e64441915f
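
Why the old check could pass for out-of-bounds views: `storage_offset * itemsize` and the subsequent addition were done in plain `T` arithmetic, so a huge offset wraps around and the wrapped sum can still compare `<=` against the real storage size. The snippet below is an illustrative model of that failure mode and of an overflow-checked alternative; it is not the ATen code, the helper names are made up, and it assumes 64-bit two's-complement wraparound (what happens in practice, even though signed overflow is formally UB in C++).

# Illustrative model of the overflow, not the ATen implementation.
INT64_MAX = 2**63 - 1

def wrap_i64(x):
    # Emulate two's-complement wraparound of a signed 64-bit integer.
    return (x + 2**63) % 2**64 - 2**63

def old_bounds_check(storage_size_bytes, storage_offset, itemsize, new_storage_size_bytes):
    # Old logic: the multiply and the add can wrap, so a wildly
    # out-of-bounds view can still satisfy the <= comparison.
    storage_offset_bytes = wrap_i64(storage_offset * itemsize)
    return wrap_i64(storage_size_bytes + storage_offset_bytes) <= new_storage_size_bytes

def checked_size_plus_offset(storage_size_bytes, storage_offset, itemsize):
    # New logic, modeled loosely: compute size + offset with an explicit
    # overflow check instead of letting the arithmetic wrap.
    total = storage_size_bytes + storage_offset * itemsize
    if total > INT64_MAX:
        raise RuntimeError("Storage size calculation overflowed")
    return total

# A 2x3 float32 tensor owns 24 bytes; viewing one element (4 bytes) at
# storage_offset 2**63 - 1 is far out of bounds, yet the old check passes
# because the math wrapped.
print(old_bounds_check(4, 2**63 - 1, 4, 24))  # True: the buggy check passes
try:
    checked_size_plus_offset(4, 2**63 - 1, 4)
except RuntimeError as e:
    print(e)  # Storage size calculation overflowed
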
@@ -87,14 +87,15 @@ inline void checkInBoundsForStorage(
     const Storage& new_storage) {
   T storage_size_bytes =
       at::detail::computeStorageNbytes(size, stride, data_type.itemsize());
-  T storage_offset_bytes = storage_offset * data_type.itemsize();
   if (storage_size_bytes == 0) {
     // NB: (a tensor with arbitrary 0 dims)'s storage can have any numel.
     return;
   }
+  T storage_size_plus_offset_bytes = at::detail::computeStorageNbytes(
+      size, stride, data_type.itemsize(), storage_offset);
   T new_storage_size_bytes = maybe_convert_symint<T>(new_storage.sym_nbytes());
   TORCH_CHECK(
-      storage_size_bytes + storage_offset_bytes <= new_storage_size_bytes,
+      storage_size_plus_offset_bytes <= new_storage_size_bytes,
       "setStorage: sizes ",
       size,
       ", strides ",
@@ -105,7 +106,7 @@ inline void checkInBoundsForStorage(
       ", and itemsize ",
       data_type.itemsize(),
       " requiring a storage size of ",
-      storage_size_bytes + storage_offset_bytes,
+      storage_size_plus_offset_bytes,
       " are out of bounds for storage of size ",
       new_storage_size_bytes);
 }
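
Design note: rather than adding a separate overflow check for `storage_offset * itemsize`, the fix passes `storage_offset` as the extra argument to `at::detail::computeStorageNbytes`, so the offset is folded into the same overflow-checked computation and a single value, `storage_size_plus_offset_bytes`, feeds both the bounds comparison and the error message.
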
@@ -1995,6 +1995,18 @@ class TestOldViewOps(TestCase):
         with self.assertRaisesRegex(RuntimeError, "Stride calculation overflowed"):
             x.resize_([0, 4, 2305843009213693952])
 
+    @onlyNativeDeviceTypes
+    def test_as_strided_overflow_storage_offset(self, device):
+        t = torch.randn(2, 3, device=device)
+        with self.assertRaisesRegex(
+            RuntimeError, "Storage size calculation overflowed"
+        ):
+            torch.as_strided(t, [1], [1], 2**63 - 1)
+        with self.assertRaisesRegex(
+            RuntimeError, "Storage size calculation overflowed"
+        ):
+            torch.as_strided(t, [1], [1], 2**61 - 1)
+
     def test_view_all_dtypes_and_devices(self, device):
         for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
             x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
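
For reference, the behavior the new test pins down can also be checked interactively. The sketch below is a minimal repro assuming a CPU tensor and that the error text matches the regex used in the test; exact wording may differ across builds.

import torch

t = torch.randn(2, 3)
try:
    # A storage_offset of 2**63 - 1 would previously wrap past the bounds check.
    torch.as_strided(t, [1], [1], 2**63 - 1)
except RuntimeError as e:
    print(e)  # expected to contain "Storage size calculation overflowed"
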