mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Concat namespaces in torch/csrc/utils/* Pull Request resolved: https://github.com/pytorch/pytorch/pull/128787 Approved by: https://github.com/Skylion007
484 lines
12 KiB
C++
484 lines
12 KiB
C++
#include <c10/util/BFloat16.h>
|
|
#include <c10/util/irange.h>
|
|
#include <torch/csrc/utils/byte_order.h>
|
|
|
|
#include <cstring>
|
|
#include <vector>
|
|
|
|
#if defined(_MSC_VER)
|
|
#include <stdlib.h>
|
|
#endif
|
|
|
|
namespace {
|
|
|
|
static inline void swapBytes16(void* ptr) {
|
|
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
|
uint16_t output;
|
|
memcpy(&output, ptr, sizeof(uint16_t));
|
|
#if defined(_MSC_VER) && !defined(_DEBUG)
|
|
output = _byteswap_ushort(output);
|
|
#elif defined(__llvm__) || defined(__GNUC__) && !defined(__ICC)
|
|
output = __builtin_bswap16(output);
|
|
#else
|
|
uint16_t Hi = output >> 8;
|
|
uint16_t Lo = output << 8;
|
|
output = Hi | Lo;
|
|
#endif
|
|
memcpy(ptr, &output, sizeof(uint16_t));
|
|
}
|
|
|
|
static inline void swapBytes32(void* ptr) {
|
|
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
|
uint32_t output;
|
|
memcpy(&output, ptr, sizeof(uint32_t));
|
|
#if defined(_MSC_VER) && !defined(_DEBUG)
|
|
output = _byteswap_ulong(output);
|
|
#elif defined(__llvm__) || defined(__GNUC__) && !defined(__ICC)
|
|
output = __builtin_bswap32(output);
|
|
#else
|
|
uint32_t Byte0 = output & 0x000000FF;
|
|
uint32_t Byte1 = output & 0x0000FF00;
|
|
uint32_t Byte2 = output & 0x00FF0000;
|
|
uint32_t Byte3 = output & 0xFF000000;
|
|
output = (Byte0 << 24) | (Byte1 << 8) | (Byte2 >> 8) | (Byte3 >> 24);
|
|
#endif
|
|
memcpy(ptr, &output, sizeof(uint32_t));
|
|
}
|
|
|
|
static inline void swapBytes64(void* ptr) {
|
|
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
|
uint64_t output;
|
|
memcpy(&output, ptr, sizeof(uint64_t));
|
|
#if defined(_MSC_VER)
|
|
output = _byteswap_uint64(output);
|
|
#elif defined(__llvm__) || defined(__GNUC__) && !defined(__ICC)
|
|
output = __builtin_bswap64(output);
|
|
#else
|
|
uint64_t Byte0 = output & 0x00000000000000FF;
|
|
uint64_t Byte1 = output & 0x000000000000FF00;
|
|
uint64_t Byte2 = output & 0x0000000000FF0000;
|
|
uint64_t Byte3 = output & 0x00000000FF000000;
|
|
uint64_t Byte4 = output & 0x000000FF00000000;
|
|
uint64_t Byte5 = output & 0x0000FF0000000000;
|
|
uint64_t Byte6 = output & 0x00FF000000000000;
|
|
uint64_t Byte7 = output & 0xFF00000000000000;
|
|
output = (Byte0 << (7 * 8)) | (Byte1 << (5 * 8)) | (Byte2 << (3 * 8)) |
|
|
(Byte3 << (1 * 8)) | (Byte7 >> (7 * 8)) | (Byte6 >> (5 * 8)) |
|
|
(Byte5 >> (3 * 8)) | (Byte4 >> (1 * 8));
|
|
#endif
|
|
memcpy(ptr, &output, sizeof(uint64_t));
|
|
}
|
|
|
|
static inline uint16_t decodeUInt16(const uint8_t* data) {
|
|
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
|
uint16_t output;
|
|
memcpy(&output, data, sizeof(uint16_t));
|
|
return output;
|
|
}
|
|
|
|
static inline uint16_t decodeUInt16ByteSwapped(const uint8_t* data) {
|
|
uint16_t output = decodeUInt16(data);
|
|
swapBytes16(&output);
|
|
return output;
|
|
}
|
|
|
|
static inline uint32_t decodeUInt32(const uint8_t* data) {
|
|
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
|
uint32_t output;
|
|
memcpy(&output, data, sizeof(uint32_t));
|
|
return output;
|
|
}
|
|
|
|
static inline uint32_t decodeUInt32ByteSwapped(const uint8_t* data) {
|
|
uint32_t output = decodeUInt32(data);
|
|
swapBytes32(&output);
|
|
return output;
|
|
}
|
|
|
|
static inline uint64_t decodeUInt64(const uint8_t* data) {
|
|
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
|
uint64_t output;
|
|
memcpy(&output, data, sizeof(uint64_t));
|
|
return output;
|
|
}
|
|
|
|
static inline uint64_t decodeUInt64ByteSwapped(const uint8_t* data) {
|
|
uint64_t output = decodeUInt64(data);
|
|
swapBytes64(&output);
|
|
return output;
|
|
}
|
|
|
|
} // anonymous namespace
|
|
|
|
namespace torch::utils {
|
|
|
|
THPByteOrder THP_nativeByteOrder() {
|
|
uint32_t x = 1;
|
|
return *(uint8_t*)&x ? THP_LITTLE_ENDIAN : THP_BIG_ENDIAN;
|
|
}
|
|
|
|
void THP_decodeInt16Buffer(
|
|
int16_t* dst,
|
|
const uint8_t* src,
|
|
bool do_byte_swap,
|
|
size_t len) {
|
|
for (const auto i : c10::irange(len)) {
|
|
dst[i] = (int16_t)(do_byte_swap ? decodeUInt16ByteSwapped(src)
|
|
: decodeUInt16(src));
|
|
src += sizeof(int16_t);
|
|
}
|
|
}
|
|
|
|
void THP_decodeInt32Buffer(
|
|
int32_t* dst,
|
|
const uint8_t* src,
|
|
bool do_byte_swap,
|
|
size_t len) {
|
|
for (const auto i : c10::irange(len)) {
|
|
dst[i] = (int32_t)(do_byte_swap ? decodeUInt32ByteSwapped(src)
|
|
: decodeUInt32(src));
|
|
src += sizeof(int32_t);
|
|
}
|
|
}
|
|
|
|
void THP_decodeInt64Buffer(
|
|
int64_t* dst,
|
|
const uint8_t* src,
|
|
bool do_byte_swap,
|
|
size_t len) {
|
|
for (const auto i : c10::irange(len)) {
|
|
dst[i] = (int64_t)(do_byte_swap ? decodeUInt64ByteSwapped(src)
|
|
: decodeUInt64(src));
|
|
src += sizeof(int64_t);
|
|
}
|
|
}
|
|
|
|
void THP_decodeHalfBuffer(
|
|
c10::Half* dst,
|
|
const uint8_t* src,
|
|
bool do_byte_swap,
|
|
size_t len) {
|
|
for (const auto i : c10::irange(len)) {
|
|
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
|
union {
|
|
uint16_t x;
|
|
c10::Half f;
|
|
};
|
|
x = (do_byte_swap ? decodeUInt16ByteSwapped(src) : decodeUInt16(src));
|
|
dst[i] = f;
|
|
src += sizeof(uint16_t);
|
|
}
|
|
}
|
|
|
|
void THP_decodeBFloat16Buffer(
|
|
at::BFloat16* dst,
|
|
const uint8_t* src,
|
|
bool do_byte_swap,
|
|
size_t len) {
|
|
for (const auto i : c10::irange(len)) {
|
|
uint16_t x =
|
|
(do_byte_swap ? decodeUInt16ByteSwapped(src) : decodeUInt16(src));
|
|
std::memcpy(&dst[i], &x, sizeof(dst[i]));
|
|
src += sizeof(uint16_t);
|
|
}
|
|
}
|
|
|
|
void THP_decodeBoolBuffer(
|
|
bool* dst,
|
|
const uint8_t* src,
|
|
bool do_byte_swap,
|
|
size_t len) {
|
|
for (const auto i : c10::irange(len)) {
|
|
dst[i] = (int)src[i] != 0 ? true : false;
|
|
}
|
|
}
|
|
|
|
void THP_decodeFloatBuffer(
|
|
float* dst,
|
|
const uint8_t* src,
|
|
bool do_byte_swap,
|
|
size_t len) {
|
|
for (const auto i : c10::irange(len)) {
|
|
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
|
union {
|
|
uint32_t x;
|
|
float f;
|
|
};
|
|
x = (do_byte_swap ? decodeUInt32ByteSwapped(src) : decodeUInt32(src));
|
|
dst[i] = f;
|
|
src += sizeof(float);
|
|
}
|
|
}
|
|
|
|
void THP_decodeDoubleBuffer(
|
|
double* dst,
|
|
const uint8_t* src,
|
|
bool do_byte_swap,
|
|
size_t len) {
|
|
for (const auto i : c10::irange(len)) {
|
|
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
|
union {
|
|
uint64_t x;
|
|
double d;
|
|
};
|
|
x = (do_byte_swap ? decodeUInt64ByteSwapped(src) : decodeUInt64(src));
|
|
dst[i] = d;
|
|
src += sizeof(double);
|
|
}
|
|
}
|
|
|
|
void THP_decodeComplexFloatBuffer(
|
|
c10::complex<float>* dst,
|
|
const uint8_t* src,
|
|
bool do_byte_swap,
|
|
size_t len) {
|
|
for (const auto i : c10::irange(len)) {
|
|
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
|
union {
|
|
uint32_t x;
|
|
float re;
|
|
};
|
|
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
|
union {
|
|
uint32_t y;
|
|
float im;
|
|
};
|
|
|
|
x = (do_byte_swap ? decodeUInt32ByteSwapped(src) : decodeUInt32(src));
|
|
src += sizeof(float);
|
|
y = (do_byte_swap ? decodeUInt32ByteSwapped(src) : decodeUInt32(src));
|
|
src += sizeof(float);
|
|
|
|
dst[i] = c10::complex<float>(re, im);
|
|
}
|
|
}
|
|
|
|
void THP_decodeComplexDoubleBuffer(
|
|
c10::complex<double>* dst,
|
|
const uint8_t* src,
|
|
bool do_byte_swap,
|
|
size_t len) {
|
|
for (const auto i : c10::irange(len)) {
|
|
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
|
union {
|
|
uint64_t x;
|
|
double re;
|
|
};
|
|
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
|
union {
|
|
uint64_t y;
|
|
double im;
|
|
};
|
|
static_assert(sizeof(uint64_t) == sizeof(double));
|
|
|
|
x = (do_byte_swap ? decodeUInt64ByteSwapped(src) : decodeUInt64(src));
|
|
src += sizeof(double);
|
|
y = (do_byte_swap ? decodeUInt64ByteSwapped(src) : decodeUInt64(src));
|
|
src += sizeof(double);
|
|
|
|
dst[i] = c10::complex<double>(re, im);
|
|
}
|
|
}
|
|
|
|
void THP_decodeInt16Buffer(
|
|
int16_t* dst,
|
|
const uint8_t* src,
|
|
THPByteOrder order,
|
|
size_t len) {
|
|
THP_decodeInt16Buffer(dst, src, (order != THP_nativeByteOrder()), len);
|
|
}
|
|
|
|
void THP_decodeInt32Buffer(
|
|
int32_t* dst,
|
|
const uint8_t* src,
|
|
THPByteOrder order,
|
|
size_t len) {
|
|
THP_decodeInt32Buffer(dst, src, (order != THP_nativeByteOrder()), len);
|
|
}
|
|
|
|
void THP_decodeInt64Buffer(
|
|
int64_t* dst,
|
|
const uint8_t* src,
|
|
THPByteOrder order,
|
|
size_t len) {
|
|
THP_decodeInt64Buffer(dst, src, (order != THP_nativeByteOrder()), len);
|
|
}
|
|
|
|
void THP_decodeHalfBuffer(
|
|
c10::Half* dst,
|
|
const uint8_t* src,
|
|
THPByteOrder order,
|
|
size_t len) {
|
|
THP_decodeHalfBuffer(dst, src, (order != THP_nativeByteOrder()), len);
|
|
}
|
|
|
|
void THP_decodeBFloat16Buffer(
|
|
at::BFloat16* dst,
|
|
const uint8_t* src,
|
|
THPByteOrder order,
|
|
size_t len) {
|
|
THP_decodeBFloat16Buffer(dst, src, (order != THP_nativeByteOrder()), len);
|
|
}
|
|
|
|
void THP_decodeBoolBuffer(
|
|
bool* dst,
|
|
const uint8_t* src,
|
|
THPByteOrder order,
|
|
size_t len) {
|
|
THP_decodeBoolBuffer(dst, src, (order != THP_nativeByteOrder()), len);
|
|
}
|
|
|
|
void THP_decodeFloatBuffer(
|
|
float* dst,
|
|
const uint8_t* src,
|
|
THPByteOrder order,
|
|
size_t len) {
|
|
THP_decodeFloatBuffer(dst, src, (order != THP_nativeByteOrder()), len);
|
|
}
|
|
|
|
void THP_decodeDoubleBuffer(
|
|
double* dst,
|
|
const uint8_t* src,
|
|
THPByteOrder order,
|
|
size_t len) {
|
|
THP_decodeDoubleBuffer(dst, src, (order != THP_nativeByteOrder()), len);
|
|
}
|
|
|
|
void THP_decodeComplexFloatBuffer(
|
|
c10::complex<float>* dst,
|
|
const uint8_t* src,
|
|
THPByteOrder order,
|
|
size_t len) {
|
|
THP_decodeComplexFloatBuffer(dst, src, (order != THP_nativeByteOrder()), len);
|
|
}
|
|
|
|
void THP_decodeComplexDoubleBuffer(
|
|
c10::complex<double>* dst,
|
|
const uint8_t* src,
|
|
THPByteOrder order,
|
|
size_t len) {
|
|
THP_decodeComplexDoubleBuffer(
|
|
dst, src, (order != THP_nativeByteOrder()), len);
|
|
}
|
|
|
|
void THP_encodeInt16Buffer(
|
|
uint8_t* dst,
|
|
const int16_t* src,
|
|
THPByteOrder order,
|
|
size_t len) {
|
|
memcpy(dst, src, sizeof(int16_t) * len);
|
|
if (order != THP_nativeByteOrder()) {
|
|
for (const auto i : c10::irange(len)) {
|
|
(void)i;
|
|
swapBytes16(dst);
|
|
dst += sizeof(int16_t);
|
|
}
|
|
}
|
|
}
|
|
|
|
void THP_encodeInt32Buffer(
|
|
uint8_t* dst,
|
|
const int32_t* src,
|
|
THPByteOrder order,
|
|
size_t len) {
|
|
memcpy(dst, src, sizeof(int32_t) * len);
|
|
if (order != THP_nativeByteOrder()) {
|
|
for (const auto i : c10::irange(len)) {
|
|
(void)i;
|
|
swapBytes32(dst);
|
|
dst += sizeof(int32_t);
|
|
}
|
|
}
|
|
}
|
|
|
|
void THP_encodeInt64Buffer(
|
|
uint8_t* dst,
|
|
const int64_t* src,
|
|
THPByteOrder order,
|
|
size_t len) {
|
|
memcpy(dst, src, sizeof(int64_t) * len);
|
|
if (order != THP_nativeByteOrder()) {
|
|
for (const auto i : c10::irange(len)) {
|
|
(void)i;
|
|
swapBytes64(dst);
|
|
dst += sizeof(int64_t);
|
|
}
|
|
}
|
|
}
|
|
|
|
void THP_encodeFloatBuffer(
|
|
uint8_t* dst,
|
|
const float* src,
|
|
THPByteOrder order,
|
|
size_t len) {
|
|
memcpy(dst, src, sizeof(float) * len);
|
|
if (order != THP_nativeByteOrder()) {
|
|
for (const auto i : c10::irange(len)) {
|
|
(void)i;
|
|
swapBytes32(dst);
|
|
dst += sizeof(float);
|
|
}
|
|
}
|
|
}
|
|
|
|
void THP_encodeDoubleBuffer(
|
|
uint8_t* dst,
|
|
const double* src,
|
|
THPByteOrder order,
|
|
size_t len) {
|
|
memcpy(dst, src, sizeof(double) * len);
|
|
if (order != THP_nativeByteOrder()) {
|
|
for (const auto i : c10::irange(len)) {
|
|
(void)i;
|
|
swapBytes64(dst);
|
|
dst += sizeof(double);
|
|
}
|
|
}
|
|
}
|
|
|
|
template <typename T>
|
|
std::vector<T> complex_to_float(const c10::complex<T>* src, size_t len) {
|
|
std::vector<T> new_src;
|
|
new_src.reserve(2 * len);
|
|
for (const auto i : c10::irange(len)) {
|
|
auto elem = src[i];
|
|
new_src.emplace_back(elem.real());
|
|
new_src.emplace_back(elem.imag());
|
|
}
|
|
return new_src;
|
|
}
|
|
|
|
void THP_encodeComplexFloatBuffer(
|
|
uint8_t* dst,
|
|
const c10::complex<float>* src,
|
|
THPByteOrder order,
|
|
size_t len) {
|
|
auto new_src = complex_to_float(src, len);
|
|
memcpy(dst, static_cast<void*>(&new_src), 2 * sizeof(float) * len);
|
|
if (order != THP_nativeByteOrder()) {
|
|
for (const auto i : c10::irange(2 * len)) {
|
|
(void)i; // Suppress unused variable warning
|
|
swapBytes32(dst);
|
|
dst += sizeof(float);
|
|
}
|
|
}
|
|
}
|
|
|
|
void THP_encodeComplexDoubleBuffer(
|
|
uint8_t* dst,
|
|
const c10::complex<double>* src,
|
|
THPByteOrder order,
|
|
size_t len) {
|
|
auto new_src = complex_to_float(src, len);
|
|
memcpy(dst, static_cast<void*>(&new_src), 2 * sizeof(double) * len);
|
|
if (order != THP_nativeByteOrder()) {
|
|
for (const auto i : c10::irange(2 * len)) {
|
|
(void)i; // Suppress unused variable warning
|
|
swapBytes64(dst);
|
|
dst += sizeof(double);
|
|
}
|
|
}
|
|
}
|
|
|
|
} // namespace torch::utils
|