mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Overload mul_overflows
for size_t
(#155736)
Partially fixes https://github.com/pytorch/executorch/pull/11537. We want to extend `mul_overflows` to support `size_t` in ExecuTorch. The current workflow in ET checks that the `c10` mirrors exactly as in PT, so the tests are failing. See comment: https://github.com/pytorch/executorch/pull/11537#issuecomment-2963821312 Pull Request resolved: https://github.com/pytorch/pytorch/pull/155736 Approved by: https://github.com/swolchok
This commit is contained in:
committed by
PyTorch MergeBot
parent
42b48ee672
commit
3dda80e990
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
// GCC has __builtin_mul_overflow from before it supported __has_builtin
|
||||
@ -31,28 +32,36 @@ C10_ALWAYS_INLINE bool add_overflows(uint64_t a, uint64_t b, uint64_t* out) {
|
||||
#endif
|
||||
}
|
||||
|
||||
C10_ALWAYS_INLINE bool mul_overflows(uint64_t a, uint64_t b, uint64_t* out) {
|
||||
template <typename T>
|
||||
C10_ALWAYS_INLINE bool mul_overflows(T a, T b, T* out) {
|
||||
#if C10_HAS_BUILTIN_OVERFLOW()
|
||||
return __builtin_mul_overflow(a, b, out);
|
||||
#else
|
||||
*out = a * b;
|
||||
// This test isn't exact, but avoids doing integer division
|
||||
return (
|
||||
(c10::llvm::countLeadingZeros(a) + c10::llvm::countLeadingZeros(b)) < 64);
|
||||
static_assert(
|
||||
std::is_integral_v<T>, "mul_overflows only supports integral types");
|
||||
|
||||
if constexpr (std::is_signed_v<T>) {
|
||||
// For signed types, use the division-based check
|
||||
volatile T tmp = a * b;
|
||||
*out = tmp;
|
||||
if (a == 0 || b == 0) {
|
||||
return false;
|
||||
}
|
||||
return !(a == tmp / b);
|
||||
} else {
|
||||
// For unsigned types, use leading zeros approach
|
||||
// This test isn't exact, but avoids doing integer division
|
||||
*out = a * b;
|
||||
constexpr int bits = sizeof(T) * 8;
|
||||
return (
|
||||
(c10::llvm::countLeadingZeros(a) + c10::llvm::countLeadingZeros(b)) <
|
||||
bits);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
C10_ALWAYS_INLINE bool mul_overflows(int64_t a, int64_t b, int64_t* out) {
|
||||
#if C10_HAS_BUILTIN_OVERFLOW()
|
||||
return __builtin_mul_overflow(a, b, out);
|
||||
#else
|
||||
volatile int64_t tmp = a * b;
|
||||
*out = tmp;
|
||||
if (a == 0 || b == 0) {
|
||||
return false;
|
||||
}
|
||||
return !(a == tmp / b);
|
||||
#endif
|
||||
C10_ALWAYS_INLINE bool mul_overflows(uint64_t a, uint64_t b, uint64_t* out) {
|
||||
return mul_overflows<uint64_t>(a, b, out);
|
||||
}
|
||||
|
||||
template <typename It>
|
||||
|
Reference in New Issue
Block a user