Fix overflow in checkInBoundsForStorage (#147352)

Use `computeStorageNbytes` (which checks for overflows) to include the computation re the storage_offset

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147352
Approved by: https://github.com/albanD
This commit is contained in:
Mikayla Gawarecki
2025-02-24 17:12:53 -08:00
committed by PyTorch MergeBot
parent 6ccbff1450
commit e64441915f
2 changed files with 16 additions and 3 deletions

View File

@ -87,14 +87,15 @@ inline void checkInBoundsForStorage(
const Storage& new_storage) { const Storage& new_storage) {
T storage_size_bytes = T storage_size_bytes =
at::detail::computeStorageNbytes(size, stride, data_type.itemsize()); at::detail::computeStorageNbytes(size, stride, data_type.itemsize());
T storage_offset_bytes = storage_offset * data_type.itemsize();
if (storage_size_bytes == 0) { if (storage_size_bytes == 0) {
// NB: (a tensor with arbitrary 0 dims)'s storage can have any numel. // NB: (a tensor with arbitrary 0 dims)'s storage can have any numel.
return; return;
} }
T storage_size_plus_offset_bytes = at::detail::computeStorageNbytes(
size, stride, data_type.itemsize(), storage_offset);
T new_storage_size_bytes = maybe_convert_symint<T>(new_storage.sym_nbytes()); T new_storage_size_bytes = maybe_convert_symint<T>(new_storage.sym_nbytes());
TORCH_CHECK( TORCH_CHECK(
storage_size_bytes + storage_offset_bytes <= new_storage_size_bytes, storage_size_plus_offset_bytes <= new_storage_size_bytes,
"setStorage: sizes ", "setStorage: sizes ",
size, size,
", strides ", ", strides ",
@ -105,7 +106,7 @@ inline void checkInBoundsForStorage(
", and itemsize ", ", and itemsize ",
data_type.itemsize(), data_type.itemsize(),
" requiring a storage size of ", " requiring a storage size of ",
storage_size_bytes + storage_offset_bytes, storage_size_plus_offset_bytes,
" are out of bounds for storage of size ", " are out of bounds for storage of size ",
new_storage_size_bytes); new_storage_size_bytes);
} }

View File

@ -1995,6 +1995,18 @@ class TestOldViewOps(TestCase):
with self.assertRaisesRegex(RuntimeError, "Stride calculation overflowed"): with self.assertRaisesRegex(RuntimeError, "Stride calculation overflowed"):
x.resize_([0, 4, 2305843009213693952]) x.resize_([0, 4, 2305843009213693952])
@onlyNativeDeviceTypes
def test_as_strided_overflow_storage_offset(self, device):
t = torch.randn(2, 3, device=device)
with self.assertRaisesRegex(
RuntimeError, "Storage size calculation overflowed"
):
torch.as_strided(t, [1], [1], 2**63 - 1)
with self.assertRaisesRegex(
RuntimeError, "Storage size calculation overflowed"
):
torch.as_strided(t, [1], [1], 2**61 - 1)
def test_view_all_dtypes_and_devices(self, device): def test_view_all_dtypes_and_devices(self, device):
for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool): for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)