Fix saving and loading pickle files on Big Endian systems (#95881)

This change fixes test/test_cpp_api_parity.py tests on Big Endian systems.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95881
Approved by: https://github.com/malfet
This commit is contained in:
Aleksei Nikiforov
2023-04-05 06:11:31 +00:00
committed by PyTorch MergeBot
parent 1e3abda31a
commit ae0d06b42c
3 changed files with 68 additions and 8 deletions

View File

@ -8,6 +8,7 @@
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/jit/serialization/storage_context.h>
#include <torch/csrc/jit/serialization/unpickler.h>
#include <torch/csrc/utils/byte_order.h>
#include <string>
namespace torch::jit {
@ -210,6 +211,7 @@ IValue Unpickler::parse_ivalue() {
double Unpickler::readFloat() {
AT_ASSERT(sizeof(double) == 8);
double big_endian = read<double>();
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
double little_endian;
@ -221,6 +223,9 @@ double Unpickler::readFloat() {
reinterpret_cast<char*>(&little_endian));
return little_endian;
#else /* __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ */
return big_endian;
#endif /* __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ */
}
void Unpickler::run() {
@ -323,21 +328,21 @@ PickleOpCode Unpickler::readInstruction() {
stack_.emplace_back(int64_t(value));
} break;
case PickleOpCode::BININT2: {
uint16_t value = read<uint16_t>();
uint16_t value = from_le16(read<uint16_t>());
stack_.emplace_back(int64_t(value));
} break;
case PickleOpCode::BININT: {
int32_t value = read<int32_t>();
int32_t value = from_le32(read<int32_t>());
stack_.emplace_back(int64_t(value));
} break;
case PickleOpCode::LONG1: {
// Only read LONG1s with 8 as the length
uint8_t length = read<uint8_t>();
TORCH_CHECK(length == 8, "Expected length to be 8, got ", int(length));
stack_.emplace_back(int64_t(read<int64_t>()));
stack_.emplace_back(int64_t(from_le64(read<int64_t>())));
} break;
case PickleOpCode::BINUNICODE: {
uint32_t length = read<uint32_t>();
uint32_t length = from_le32(read<uint32_t>());
stack_.emplace_back(readBytes(length));
} break;
case PickleOpCode::BINFLOAT: