Fix UB in BFloat16 round_to_nearest_even (#157942)

Type punning using unions is undefined behavior in C++ (you may not access a member of a union that is not the active member). bit_cast is the right way.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157942
Approved by: https://github.com/Skylion007
This commit is contained in:
Scott Wolchok
2025-07-09 10:41:57 -07:00
committed by PyTorch MergeBot
parent a9ac9f2635
commit e3f8141c25
3 changed files with 8 additions and 8 deletions

View File

@ -4,6 +4,7 @@
// 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa.
#include <c10/macros/Macros.h>
#include <c10/util/bit_cast.h>
#include <cmath>
#include <cstdint>
#include <cstring>
@ -67,13 +68,7 @@ inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) {
#endif
return UINT16_C(0x7FC0);
} else {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
union {
uint32_t U32; // NOLINT(facebook-hte-BadMemberName)
float F32; // NOLINT(facebook-hte-BadMemberName)
};
F32 = src;
const uint32_t U32 = c10::bit_cast<uint32_t>(src);
uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
return static_cast<uint16_t>((U32 + rounding_bias) >> 16);
}

View File

@ -3,6 +3,8 @@
#include <cstring>
#include <type_traits>
#include <c10/macros/Macros.h>
#if __has_include(<bit>) && (defined(__cpp_lib_bit_cast) && __cpp_lib_bit_cast >= 201806L)
#include <bit>
#define C10_HAVE_STD_BIT_CAST 1
@ -23,7 +25,7 @@ using std::bit_cast;
// See https://en.cppreference.com/w/cpp/numeric/bit_cast for more
// information as well as the source of our implementations.
template <class To, class From>
std::enable_if_t<
C10_HOST_DEVICE std::enable_if_t<
sizeof(To) == sizeof(From) && std::is_trivially_copyable_v<From> &&
std::is_trivially_copyable_v<To>,
To>

View File

@ -58,6 +58,9 @@ def define_targets(rules):
name = "bit_cast",
hdrs = ["bit_cast.h"],
visibility = ["//:__subpackages__"],
deps = [
"//c10/macros",
],
)
rules.cc_library(