mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Fix UB in BFloat16 round_to_nearest_even (#157942)
Type punning using unions is undefined behavior in C++ (you may not access a member of a union that is not the active member). bit_cast is the right way. Pull Request resolved: https://github.com/pytorch/pytorch/pull/157942 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
a9ac9f2635
commit
e3f8141c25
@ -4,6 +4,7 @@
|
||||
// 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa.
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/bit_cast.h>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
@ -67,13 +68,7 @@ inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) {
|
||||
#endif
|
||||
return UINT16_C(0x7FC0);
|
||||
} else {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
||||
union {
|
||||
uint32_t U32; // NOLINT(facebook-hte-BadMemberName)
|
||||
float F32; // NOLINT(facebook-hte-BadMemberName)
|
||||
};
|
||||
|
||||
F32 = src;
|
||||
const uint32_t U32 = c10::bit_cast<uint32_t>(src);
|
||||
uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
|
||||
return static_cast<uint16_t>((U32 + rounding_bias) >> 16);
|
||||
}
|
||||
|
@ -3,6 +3,8 @@
|
||||
#include <cstring>
|
||||
#include <type_traits>
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
#if __has_include(<bit>) && (defined(__cpp_lib_bit_cast) && __cpp_lib_bit_cast >= 201806L)
|
||||
#include <bit>
|
||||
#define C10_HAVE_STD_BIT_CAST 1
|
||||
@ -23,7 +25,7 @@ using std::bit_cast;
|
||||
// See https://en.cppreference.com/w/cpp/numeric/bit_cast for more
|
||||
// information as well as the source of our implementations.
|
||||
template <class To, class From>
|
||||
std::enable_if_t<
|
||||
C10_HOST_DEVICE std::enable_if_t<
|
||||
sizeof(To) == sizeof(From) && std::is_trivially_copyable_v<From> &&
|
||||
std::is_trivially_copyable_v<To>,
|
||||
To>
|
||||
|
@ -58,6 +58,9 @@ def define_targets(rules):
|
||||
name = "bit_cast",
|
||||
hdrs = ["bit_cast.h"],
|
||||
visibility = ["//:__subpackages__"],
|
||||
deps = [
|
||||
"//c10/macros",
|
||||
],
|
||||
)
|
||||
|
||||
rules.cc_library(
|
||||
|
Reference in New Issue
Block a user