Fix endian handling in THPStorage_fromBuffer (#92834)

Fixes #92831

This PR fixes a test failure of `TestTorch.test_from_buffer` on a big-endian machine. The root cause of this failure is that current `THPStorage_fromBuffer` does not perform endian handling correctly on a big-endian.

In `THPStorage_fromBuffer`, the given buffer is stored as machine native-endian. Thus, if the specified byte order (e.g. `big`) is equal to machine native-endian, swapping elements should not be performed. However, in the current implementation, [`decode*BE()`](https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/byte_order.cpp#L72-L109) always swaps elements regardless of machine native-endian (i.e. these methods assume buffer is stored as little-endian).

Thus, this PR uses the following approaches:
- if the specified byte order (e.g. `big`) is equal to machine native-endian, call `decode*LE()` that does not swap elements by passing `torch::utils::THP_LITTLE_ENDIAN` to `THP_decode*Buffer()`.
- if the specified byte order (e.g. `big`) is not equal to machine native-endian, call `decode*BE()` that always swap elements by passing `torch::utils::THP_BIG_ENDIAN` to `THP_decode*Buffer()`.

After applying this PR to the master branch, I confirmed that the test passes on a big-endian machine.

```
% python test/test_torch.py TestTorch.test_from_buffer
/home/ishizaki/PyTorch/master/test/test_torch.py:6367: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  self.assertEqual(torch.ByteStorage.from_buffer(a).tolist(), [1, 2, 3, 4])
...
/home/ishizaki/PyTorch/master/test/test_torch.py:6396: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  self.assertEqual(bytes.tolist(), [1, 2, 3, 4])
.
----------------------------------------------------------------------
Ran 1 test in 0.021s

OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92834
Approved by: https://github.com/ezyang
This commit is contained in:
Kazuaki Ishizaki
2023-01-29 00:55:54 +00:00
committed by PyTorch MergeBot
parent 1e0c57b645
commit cb817d6176
3 changed files with 175 additions and 36 deletions

View File

@ -121,11 +121,11 @@ THPByteOrder THP_nativeByteOrder() {
void THP_decodeInt16Buffer(
int16_t* dst,
const uint8_t* src,
THPByteOrder order,
bool do_byte_swap,
size_t len) {
for (const auto i : c10::irange(len)) {
dst[i] =
(int16_t)(order == THP_BIG_ENDIAN ? decodeUInt16BE(src) : decodeUInt16LE(src));
(int16_t)(do_byte_swap ? decodeUInt16BE(src) : decodeUInt16LE(src));
src += sizeof(int16_t);
}
}
@ -133,11 +133,11 @@ void THP_decodeInt16Buffer(
void THP_decodeInt32Buffer(
int32_t* dst,
const uint8_t* src,
THPByteOrder order,
bool do_byte_swap,
size_t len) {
for (const auto i : c10::irange(len)) {
dst[i] =
(int32_t)(order == THP_BIG_ENDIAN ? decodeUInt32BE(src) : decodeUInt32LE(src));
(int32_t)(do_byte_swap ? decodeUInt32BE(src) : decodeUInt32LE(src));
src += sizeof(int32_t);
}
}
@ -145,11 +145,11 @@ void THP_decodeInt32Buffer(
void THP_decodeInt64Buffer(
int64_t* dst,
const uint8_t* src,
THPByteOrder order,
bool do_byte_swap,
size_t len) {
for (const auto i : c10::irange(len)) {
dst[i] =
(int64_t)(order == THP_BIG_ENDIAN ? decodeUInt64BE(src) : decodeUInt64LE(src));
(int64_t)(do_byte_swap ? decodeUInt64BE(src) : decodeUInt64LE(src));
src += sizeof(int64_t);
}
}
@ -157,7 +157,7 @@ void THP_decodeInt64Buffer(
void THP_decodeHalfBuffer(
c10::Half* dst,
const uint8_t* src,
THPByteOrder order,
bool do_byte_swap,
size_t len) {
for (const auto i : c10::irange(len)) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
@ -165,7 +165,7 @@ void THP_decodeHalfBuffer(
uint16_t x;
c10::Half f;
};
x = (order == THP_BIG_ENDIAN ? decodeUInt16BE(src) : decodeUInt16LE(src));
x = (do_byte_swap ? decodeUInt16BE(src) : decodeUInt16LE(src));
dst[i] = f;
src += sizeof(uint16_t);
}
@ -174,11 +174,10 @@ void THP_decodeHalfBuffer(
void THP_decodeBFloat16Buffer(
at::BFloat16* dst,
const uint8_t* src,
THPByteOrder order,
bool do_byte_swap,
size_t len) {
for (const auto i : c10::irange(len)) {
uint16_t x =
(order == THP_BIG_ENDIAN ? decodeUInt16BE(src) : decodeUInt16LE(src));
uint16_t x = (do_byte_swap ? decodeUInt16BE(src) : decodeUInt16LE(src));
std::memcpy(&dst[i], &x, sizeof(dst[i]));
src += sizeof(uint16_t);
}
@ -187,7 +186,7 @@ void THP_decodeBFloat16Buffer(
void THP_decodeBoolBuffer(
bool* dst,
const uint8_t* src,
THPByteOrder order,
bool do_byte_swap,
size_t len) {
for (const auto i : c10::irange(len)) {
dst[i] = (int)src[i] != 0 ? true : false;
@ -197,7 +196,7 @@ void THP_decodeBoolBuffer(
void THP_decodeFloatBuffer(
float* dst,
const uint8_t* src,
THPByteOrder order,
bool do_byte_swap,
size_t len) {
for (const auto i : c10::irange(len)) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
@ -205,7 +204,7 @@ void THP_decodeFloatBuffer(
uint32_t x;
float f;
};
x = (order == THP_BIG_ENDIAN ? decodeUInt32BE(src) : decodeUInt32LE(src));
x = (do_byte_swap ? decodeUInt32BE(src) : decodeUInt32LE(src));
dst[i] = f;
src += sizeof(float);
}
@ -214,7 +213,7 @@ void THP_decodeFloatBuffer(
void THP_decodeDoubleBuffer(
double* dst,
const uint8_t* src,
THPByteOrder order,
bool do_byte_swap,
size_t len) {
for (const auto i : c10::irange(len)) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
@ -222,7 +221,7 @@ void THP_decodeDoubleBuffer(
uint64_t x;
double d;
};
x = (order == THP_BIG_ENDIAN ? decodeUInt64BE(src) : decodeUInt64LE(src));
x = (do_byte_swap ? decodeUInt64BE(src) : decodeUInt64LE(src));
dst[i] = d;
src += sizeof(double);
}
@ -231,7 +230,7 @@ void THP_decodeDoubleBuffer(
void THP_decodeComplexFloatBuffer(
c10::complex<float>* dst,
const uint8_t* src,
THPByteOrder order,
bool do_byte_swap,
size_t len) {
for (const auto i : c10::irange(len)) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
@ -245,9 +244,9 @@ void THP_decodeComplexFloatBuffer(
float im;
};
x = (order == THP_BIG_ENDIAN ? decodeUInt32BE(src) : decodeUInt32LE(src));
x = (do_byte_swap ? decodeUInt32BE(src) : decodeUInt32LE(src));
src += sizeof(float);
y = (order == THP_BIG_ENDIAN ? decodeUInt32BE(src) : decodeUInt32LE(src));
y = (do_byte_swap ? decodeUInt32BE(src) : decodeUInt32LE(src));
src += sizeof(float);
dst[i] = c10::complex<float>(re, im);
@ -257,7 +256,7 @@ void THP_decodeComplexFloatBuffer(
void THP_decodeComplexDoubleBuffer(
c10::complex<double>* dst,
const uint8_t* src,
THPByteOrder order,
bool do_byte_swap,
size_t len) {
for (const auto i : c10::irange(len)) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
@ -271,15 +270,95 @@ void THP_decodeComplexDoubleBuffer(
double im;
};
x = (order == THP_BIG_ENDIAN ? decodeUInt64BE(src) : decodeUInt64LE(src));
x = (do_byte_swap ? decodeUInt64BE(src) : decodeUInt64LE(src));
src += sizeof(double);
y = (order == THP_BIG_ENDIAN ? decodeUInt64BE(src) : decodeUInt64LE(src));
y = (do_byte_swap ? decodeUInt64BE(src) : decodeUInt64LE(src));
src += sizeof(double);
dst[i] = c10::complex<double>(re, im);
}
}
void THP_decodeInt16Buffer(
int16_t* dst,
const uint8_t* src,
THPByteOrder order,
size_t len) {
THP_decodeInt16Buffer(dst, src, (order == THP_BIG_ENDIAN), len);
}
void THP_decodeInt32Buffer(
int32_t* dst,
const uint8_t* src,
THPByteOrder order,
size_t len) {
THP_decodeInt32Buffer(dst, src, (order == THP_BIG_ENDIAN), len);
}
void THP_decodeInt64Buffer(
int64_t* dst,
const uint8_t* src,
THPByteOrder order,
size_t len) {
THP_decodeInt64Buffer(dst, src, (order == THP_BIG_ENDIAN), len);
}
void THP_decodeHalfBuffer(
c10::Half* dst,
const uint8_t* src,
THPByteOrder order,
size_t len) {
THP_decodeHalfBuffer(dst, src, (order == THP_BIG_ENDIAN), len);
}
void THP_decodeBFloat16Buffer(
at::BFloat16* dst,
const uint8_t* src,
THPByteOrder order,
size_t len) {
THP_decodeBFloat16Buffer(dst, src, (order == THP_BIG_ENDIAN), len);
}
void THP_decodeBoolBuffer(
bool* dst,
const uint8_t* src,
THPByteOrder order,
size_t len) {
THP_decodeBoolBuffer(dst, src, (order == THP_BIG_ENDIAN), len);
}
void THP_decodeFloatBuffer(
float* dst,
const uint8_t* src,
THPByteOrder order,
size_t len) {
THP_decodeFloatBuffer(dst, src, (order == THP_BIG_ENDIAN), len);
}
void THP_decodeDoubleBuffer(
double* dst,
const uint8_t* src,
THPByteOrder order,
size_t len) {
THP_decodeDoubleBuffer(dst, src, (order == THP_BIG_ENDIAN), len);
}
void THP_decodeComplexFloatBuffer(
c10::complex<float>* dst,
const uint8_t* src,
THPByteOrder order,
size_t len) {
THP_decodeComplexFloatBuffer(dst, src, (order == THP_BIG_ENDIAN), len);
}
void THP_decodeComplexDoubleBuffer(
c10::complex<double>* dst,
const uint8_t* src,
THPByteOrder order,
size_t len) {
THP_decodeComplexDoubleBuffer(dst, src, (order == THP_BIG_ENDIAN), len);
}
void THP_encodeInt16Buffer(
uint8_t* dst,
const int16_t* src,