Catch overflows in calculating storage byte size

Fixes #73184

In the issue, the output tensor's shape is `[2, 4, 536870912, 536870912]`, which gives a `numel()` just below the `int64_t` overflow threshold. When the storage is created, the byte size is computed as `numel() * 8`, which overflows, so a much smaller storage than required is allocated.
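
For concreteness, the arithmetic behind the bug (a standalone sketch, not code from this PR): the element count is 2 · 4 · 2^29 · 2^29 = 2^61, which still fits in `int64_t`, but multiplying by the 8-byte `float64` item size gives 2^64, which wraps to 0 in 64-bit arithmetic:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Shape from the issue: [2, 4, 2^29, 2^29], dtype float64 (8 bytes per element).
  uint64_t numel = 2ull * 4 * 536870912 * 536870912;  // 2^61, fits in int64_t
  uint64_t nbytes = numel * 8;                         // 2^64 wraps to 0
  std::printf("numel = %llu, nbytes = %llu\n",
              (unsigned long long)numel, (unsigned long long)nbytes);
  // Prints: numel = 2305843009213693952, nbytes = 0
  return 0;
}
```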

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73719
Approved by: https://github.com/ezyang, https://github.com/malfet
Authored by Peter Bell on 2022-03-31 16:16:03 +00:00; committed by PyTorch MergeBot
parent 40bf3cfeb7
commit 13a3e5c70c
8 changed files with 210 additions and 51 deletions


@@ -2,31 +2,93 @@
#include <ATen/EmptyTensor.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <c10/core/CPUAllocator.h>
#include <c10/util/safe_numerics.h>
#include <limits>
namespace at {
namespace detail {
static c10::Allocator* GetCPUAllocatorMaybePinned(bool pin_memory) {
namespace {
c10::Allocator* GetCPUAllocatorMaybePinned(bool pin_memory) {
if (pin_memory) {
return at::detail::getCUDAHooks().getPinnedMemoryAllocator();
}
return c10::GetCPUAllocator();
}
constexpr uint64_t storage_max() {
// int64_t and size_t are used somewhat inconsistently throughout ATen.
// To be safe, storage size calculations must fit in both types.
constexpr auto int64_max = static_cast<uint64_t>(
std::numeric_limits<int64_t>::max());
constexpr auto size_max = static_cast<uint64_t>(
std::numeric_limits<size_t>::max());
return std::min(int64_max, size_max);
}
} // namespace (anonymous)
size_t computeStorageNbytesContiguous(
IntArrayRef sizes,
size_t itemsize_bytes,
size_t storage_offset
) {
// Ignore overflow checks on mobile
#ifndef C10_MOBILE
uint64_t size = 1;
bool overflowed = c10::safe_multiplies_u64(sizes, &size);
overflowed |= c10::add_overflows(size, storage_offset, &size);
overflowed |= c10::mul_overflows(size, itemsize_bytes, &size);
overflowed |= size > storage_max();
TORCH_CHECK(!overflowed,
"Storage size calculation overflowed with sizes=", sizes);
return static_cast<size_t>(size);
#else
const auto numel = c10::multiply_integers(sizes);
return itemsize_bytes * (storage_offset + numel);
#endif
}
size_t computeStorageNbytes(
IntArrayRef sizes,
IntArrayRef strides,
size_t itemsize_bytes) {
size_t itemsize_bytes,
size_t storage_offset
) {
// Ignore overflow checks on mobile
#ifndef C10_MOBILE
// size of the underlying storage is 1 bigger than the offset
// of the last element according to stride
size_t size = 1;
uint64_t size = storage_offset + 1;
bool overflowed = false;
for (const auto i : c10::irange(sizes.size())) {
if (sizes[i] == 0) {
return 0;
}
uint64_t strided_size;
overflowed |= c10::mul_overflows(strides[i], sizes[i] - 1, &strided_size);
overflowed |= c10::add_overflows(size, strided_size, &size);
}
overflowed |= c10::mul_overflows(size, itemsize_bytes, &size);
overflowed |= size > storage_max();
TORCH_CHECK(!overflowed,
"Storage size calculation overflowed with sizes=",
sizes, " and strides=", strides);
return static_cast<size_t>(size);
#else
// size of the underlying storage is 1 bigger than the offset
// of the last element according to stride
uint64_t size = 1;
for (const auto i : c10::irange(sizes.size())) {
if (sizes[i] == 0) {
return 0;
}
size += strides[i] * (sizes[i] - 1);
}
return size * itemsize_bytes;
return itemsize_bytes * (storage_offset + size);
#endif
}
TensorBase empty_generic(
@@ -37,9 +99,8 @@ TensorBase empty_generic(
c10::optional<c10::MemoryFormat> memory_format_opt) {
at::detail::check_size_nonnegative(size);
int64_t nelements = c10::multiply_integers(size);
caffe2::TypeMeta dtype = scalarTypeToTypeMeta(scalar_type);
int64_t size_bytes = nelements * dtype.itemsize();
size_t size_bytes = computeStorageNbytesContiguous(size, dtype.itemsize());
auto storage_impl = c10::make_intrusive<StorageImpl>(
c10::StorageImpl::use_byte_size_t(),
size_bytes,
@@ -73,7 +134,7 @@ TensorBase empty_strided_generic(
at::detail::check_size_nonnegative(size);
caffe2::TypeMeta dtype = scalarTypeToTypeMeta(scalar_type);
int64_t size_bytes = computeStorageNbytes(size, stride, dtype.itemsize());
size_t size_bytes = computeStorageNbytes(size, stride, dtype.itemsize());
auto storage_impl = c10::make_intrusive<StorageImpl>(
c10::StorageImpl::use_byte_size_t(),
size_bytes,
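
Since the diff above interleaves removed and added lines, here is the new strided calculation consolidated into a standalone sketch (illustrative only, not the ATen code itself; it uses the GCC/Clang `__builtin_*_overflow` intrinsics that `c10::mul_overflows` and `c10::add_overflows` wrap on non-MSVC builds). The storage size in elements is `storage_offset + 1 + sum((sizes[i] - 1) * strides[i])` when every size is nonzero, and 0 otherwise; the total is then multiplied by the item size and checked against the int64_t/size_t maximum:

```cpp
#include <cstddef>
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <vector>

// Hypothetical standalone helper mirroring the checked computeStorageNbytes above.
// Assumes non-negative strides, as enforced by setStrided in this PR.
size_t checked_storage_nbytes(const std::vector<int64_t>& sizes,
                              const std::vector<int64_t>& strides,
                              size_t itemsize, size_t storage_offset = 0) {
  uint64_t size = storage_offset + 1;
  bool overflowed = false;
  for (size_t i = 0; i < sizes.size(); ++i) {
    if (sizes[i] == 0) return 0;  // an empty tensor needs no storage
    uint64_t strided_size;
    overflowed |= __builtin_mul_overflow(static_cast<uint64_t>(strides[i]),
                                         static_cast<uint64_t>(sizes[i] - 1),
                                         &strided_size);
    overflowed |= __builtin_add_overflow(size, strided_size, &size);
  }
  overflowed |= __builtin_mul_overflow(size, static_cast<uint64_t>(itemsize), &size);
  overflowed |= size > static_cast<uint64_t>(std::numeric_limits<int64_t>::max());
  if (overflowed) throw std::runtime_error("Storage size calculation overflowed");
  return static_cast<size_t>(size);
}
```

For `sizes = {8, 8}`, `strides = {2^61, 1}` and an 8-byte item size, the final multiply by the item size overflows and the check fires, matching the `empty_strided` case in the new tests.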


@@ -10,8 +10,11 @@ inline void check_size_nonnegative(IntArrayRef size) {
}
}
TORCH_API size_t computeStorageNbytesContiguous(
IntArrayRef sizes, size_t itemsize, size_t storage_offset=0);
TORCH_API size_t computeStorageNbytes(
IntArrayRef sizes, IntArrayRef strides, size_t itemsize);
IntArrayRef sizes, IntArrayRef strides,
size_t itemsize, size_t storage_offset=0);
TORCH_API TensorBase empty_generic(
IntArrayRef size,


@@ -2,6 +2,7 @@
#include <ATen/core/Tensor.h>
#include <ATen/native/ResizeCommon.h>
#include <ATen/EmptyTensor.h>
#include <ATen/TensorUtils.h>
#include <c10/core/CPUAllocator.h>
@@ -30,22 +31,16 @@ TORCH_API bool resize_output_check(const Tensor& output, IntArrayRef shape);
TORCH_API void resize_bytes_cpu(StorageImpl* storage, size_t size_bytes);
static inline void maybe_resize_storage_cpu(TensorImpl* self, uint64_t new_size) {
static inline void maybe_resize_storage_cpu(TensorImpl* self, size_t new_size_bytes) {
// It does not make sense to try to resize a storage
// to hold 0 elements, and this can break
// if storage_offset is positive but
// new_size is 0, so just bail in that case
// (same comment is in cuda/Resize.h)
if (new_size == 0) {
if (self->numel() == 0) {
return;
}
const auto new_size_bytes_i =
(new_size + self->storage_offset()) * self->dtype().itemsize();
TORCH_CHECK(!overflows<size_t>(new_size_bytes_i), "Requested storage size (",
new_size_bytes_i, ") cannot be represented as a size_t");
const auto new_size_bytes = static_cast<size_t>(new_size_bytes_i);
const Storage& storage = self->unsafe_storage();
if (!storage) {
auto new_storage = c10::make_intrusive<StorageImpl>(
@@ -68,15 +63,19 @@ inline TensorImpl* resize_impl_cpu_(
return self;
}
int64_t storage_size = 1;
const auto itemsize = self->dtype().itemsize();
const auto storage_offset = self->storage_offset();
size_t storage_size = 1;
if (stride) {
self->set_sizes_and_strides(size, *stride);
// NB: storage size can be different from numel.
storage_size = storage_size_for(size, *stride);
storage_size = at::detail::computeStorageNbytes(
size, *stride, itemsize, storage_offset);
} else {
self->set_sizes_contiguous(size);
storage_size = self->numel();
storage_size = at::detail::computeStorageNbytesContiguous(
size, itemsize, storage_offset);
}
if (resize_storage) {
maybe_resize_storage_cpu(self, storage_size);
}
@@ -158,6 +157,12 @@ inline void setStrided(
IntArrayRef stride,
int64_t storage_offset) {
TORCH_CHECK(size.size() == stride.size(), "mismatch in length of strides and shape");
for (auto val : stride) {
TORCH_CHECK(val >= 0,
"as_strided: Negative strides are not supported at the moment, "
"got strides: ", stride);
}
auto* self_ = self.unsafeGetTensorImpl();
checkInBoundsForStorage(
size, stride, storage_offset, self_->dtype(), self_->storage());
@@ -170,11 +175,6 @@ inline void setStrided(
if (self_->sizes() == size && self_->strides() == stride) {
return;
}
for (auto val : stride) {
TORCH_CHECK(val >= 0,
"as_strided: Negative strides are not supported at the moment, "
"got strides: ", stride);
}
self_->set_sizes_and_strides(size, stride);
}


@@ -1,6 +1,7 @@
#pragma once
#include <ATen/ATen.h>
#include <ATen/EmptyTensor.h>
#include <ATen/native/ResizeCommon.h>
#include <c10/cuda/CUDAGuard.h>
@@ -9,19 +10,15 @@ namespace at { namespace native {
TORCH_CUDA_CPP_API void resize_bytes_cuda(StorageImpl* storage, size_t size_bytes);
static inline void maybe_resize_storage_cuda(TensorImpl* self, uint64_t new_size) {
static inline void maybe_resize_storage_cuda(TensorImpl* self, size_t new_size_bytes) {
// It does not make sense to try to resize a storage
// to hold 0 elements, and this can break
// if storage_offset is positive but
// new_size is 0, so just bail in that case
// (same comment is in Resize.h)
if (new_size == 0) {
if (self->numel() == 0) {
return;
}
auto new_size_bytes_i = (new_size + self->storage_offset()) * self->dtype().itemsize();
TORCH_CHECK(!overflows<size_t>(new_size_bytes_i), "Requested storage size (",
new_size_bytes_i, ") cannot be represented as a size_t");
const auto new_size_bytes = static_cast<size_t>(new_size_bytes_i);
const Storage &storage = self->unsafe_storage();
TORCH_CHECK(storage, "Tensor: invalid null storage");
@@ -45,14 +42,17 @@ inline TensorImpl* resize_impl_cuda_(
guard.set_index(self->storage().device().index());
}
int64_t storage_size = 1;
const auto itemsize = self->dtype().itemsize();
const auto storage_offset = self->storage_offset();
size_t storage_size = 1;
if (stride) {
self->set_sizes_and_strides(size, *stride);
// NB: storage size can be different from numel.
storage_size = storage_size_for(size, *stride);
storage_size = at::detail::computeStorageNbytes(
size, *stride, itemsize, storage_offset);
} else {
self->set_sizes_contiguous(size);
storage_size = self->numel();
storage_size = at::detail::computeStorageNbytesContiguous(
size, itemsize, storage_offset);
}
maybe_resize_storage_cuda(self, storage_size);


@@ -16,9 +16,11 @@
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>
#include <c10/util/python_stub.h>
#include <c10/util/safe_numerics.h>
#include <algorithm>
#include <atomic>
#include <limits>
#include <memory>
#include <numeric>
@@ -2266,11 +2268,12 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
* Compute the number of elements based on the sizes of a tensor.
*/
int64_t compute_numel() const {
int64_t n = 1;
for (auto s : sizes()) {
n *= s;
}
return n;
#if C10_HAS_BUILTIN_OVERFLOW() && !defined(C10_MOBILE)
// Use overflow checks if supported by the compiler
return safe_compute_numel();
#else
return c10::multiply_integers(sizes());
#endif
}
/**
@@ -2279,14 +2282,15 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
* using a sparse layout has multiple dimensions with large sizes.
*/
int64_t safe_compute_numel() const {
int64_t n = 1;
for (auto s : sizes()) {
TORCH_CHECK(
s == 0 || n <= std::numeric_limits<int64_t>::max() / s,
"numel: integer multiplication overflow");
n *= s;
}
return n;
uint64_t n = 1;
bool overflows = c10::safe_multiplies_u64(sizes(), &n);
constexpr auto numel_max = std::min(
static_cast<uint64_t>(std::numeric_limits<int64_t>::max()),
static_cast<uint64_t>(std::numeric_limits<size_t>::max()));
overflows |= (n > numel_max);
TORCH_CHECK(!overflows, "numel: integer multiplication overflow");
return static_cast<int64_t>(n);
}
/**

c10/util/safe_numerics.h (new file, 74 lines added)

@@ -0,0 +1,74 @@
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/ArrayRef.h>
#include <iterator>
#include <numeric>
#include <type_traits>
// GCC has __builtin_mul_overflow from before it supported __has_builtin
#ifdef _MSC_VER
#define C10_HAS_BUILTIN_OVERFLOW() (0)
#include <c10/util/llvmMathExtras.h>
#include <intrin.h>
#else
#define C10_HAS_BUILTIN_OVERFLOW() (1)
#endif
namespace c10 {
C10_ALWAYS_INLINE bool add_overflows(uint64_t a, uint64_t b, uint64_t* out) {
#if C10_HAS_BUILTIN_OVERFLOW()
return __builtin_add_overflow(a, b, out);
#else
unsigned long long tmp;
auto carry = _addcarry_u64(0, a, b, &tmp);
*out = tmp;
return carry;
#endif
}
C10_ALWAYS_INLINE bool mul_overflows(uint64_t a, uint64_t b, uint64_t* out) {
#if C10_HAS_BUILTIN_OVERFLOW()
return __builtin_mul_overflow(a, b, out);
#else
*out = a * b;
// This test isn't exact, but avoids doing integer division
return (
(c10::llvm::countLeadingZeros(a) + c10::llvm::countLeadingZeros(b)) < 64);
#endif
}
template <typename It>
bool safe_multiplies_u64(It first, It last, uint64_t* out) {
#if C10_HAS_BUILTIN_OVERFLOW()
uint64_t prod = 1;
bool overflow = false;
for (; first != last; ++first) {
overflow |= c10::mul_overflows(prod, *first, &prod);
}
*out = prod;
return overflow;
#else
uint64_t prod = 1;
uint64_t prod_log2 = 0;
bool is_zero = false;
for (; first != last; ++first) {
auto x = static_cast<uint64_t>(*first);
prod *= x;
// log2(0) isn't valid, so need to track it specially
is_zero |= (x == 0);
prod_log2 += c10::llvm::Log2_64_Ceil(x);
}
*out = prod;
// This test isn't exact, but avoids doing integer division
return !is_zero && (prod_log2 >= 64);
#endif
}
template <typename Container>
bool safe_multiplies_u64(const Container& c, uint64_t* out) {
return safe_multiplies_u64(c.begin(), c.end(), out);
}
} // namespace c10
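
A quick usage sketch of the new helpers (illustrative; assumes a GCC/Clang build where `C10_HAS_BUILTIN_OVERFLOW()` is 1 and the c10 headers are on the include path):

```cpp
#include <c10/util/safe_numerics.h>

#include <array>
#include <cstdio>

int main() {
  uint64_t out = 0;

  // 2^61 elements times an 8-byte item size overflows 64 bits.
  bool overflowed = c10::mul_overflows(uint64_t{1} << 61, 8, &out);
  std::printf("byte-size overflowed: %d\n", overflowed);  // 1

  // Product of the sizes from the issue: 2 * 4 * 2^29 * 2^29 = 2^61, no overflow.
  std::array<uint64_t, 4> sizes = {2, 4, uint64_t{1} << 29, uint64_t{1} << 29};
  overflowed = c10::safe_multiplies_u64(sizes, &out);
  std::printf("numel overflowed: %d, numel = %llu\n",
              overflowed, (unsigned long long)out);       // 0, 2305843009213693952
  return 0;
}
```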


@@ -2665,6 +2665,15 @@ class TestTensorCreation(TestCase):
y = torch.empty(tuple(size_ones_instead_of_zeros), device=device)
self.assertEqual(x.stride(), y.stride())
@onlyNativeDeviceTypes
def test_empty_overflow(self, device):
with self.assertRaisesRegex(RuntimeError, 'Storage size calculation overflowed'):
torch.empty([2, 4, 2**29, 2**29], dtype=torch.float64)
with self.assertRaisesRegex(RuntimeError, 'Storage size calculation overflowed'):
torch.empty([8, 8, 2**29, 2**29], dtype=torch.float64)
with self.assertRaisesRegex(RuntimeError, 'Storage size calculation overflowed'):
torch.empty_strided([8, 8], [2**61, 1], dtype=torch.float64)
def test_eye(self, device):
for dtype in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16):
if dtype == torch.bfloat16:


@@ -1795,6 +1795,14 @@ class TestOldViewOps(TestCase):
x.resize_as_(y)
self.assertEqual(y.shape, x.shape)
@onlyNativeDeviceTypes
def test_resize_overflow(self, device):
x = torch.empty((), dtype=torch.float64)
with self.assertRaisesRegex(RuntimeError, 'Storage size calculation overflowed'):
x.resize_([2, 4, 2**29, 2**29])
with self.assertRaisesRegex(RuntimeError, 'overflow'):
x.resize_([8, 8, 2**29, 2**29])
def test_view_all_dtypes_and_devices(self, device):
for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)