mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-25 08:11:06 +08:00
[ModelLoading] Use byte encoding for uint8, fp16 etc. instead of int32 (#34343)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/34343 Use byte encoding for uint8, fp16 etc. instead of int32 in TensorProto serialization/deserialization tl;dr - fp16 tensor deserialization 12x faster, serialized size 25% lower - uint8 tensor deserialization 36x faster, serialized size 25% lower Test Plan: ``` ============================================================================ caffe2/caffe2/fb/predictor/ModelLoaderBenchmark.cpprelative time/iter iters/s ============================================================================ BlobProtoInt32DeserializationFloat16 12.37ms 80.82 BlobProtoByteDeserializationFloat16 1125.46% 1.10ms 909.64 ---------------------------------------------------------------------------- BlobProtoInt32DeserializationUInt8 17.57ms 56.92 BlobProtoByteDeserializationUInt8 3629.45% 484.02us 2.07K ============================================================================ ``` Reviewed By: yinghai Differential Revision: D20137451 fbshipit-source-id: 8ed4be2286a6d4c7e134fcb0832f22bc645039a1
This commit is contained in:
committed by
Facebook Github Bot
parent
98afce3c56
commit
879a90b322
@ -1,7 +1,7 @@
|
||||
#include "caffe2/core/blob_serialization.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <mutex>
|
||||
#include <sstream>
|
||||
|
||||
#include "caffe2/core/blob.h"
|
||||
#include "caffe2/utils/proto_utils.h"
|
||||
@ -21,6 +21,22 @@ C10_DEFINE_bool(
|
||||
false,
|
||||
"Serialize FLOAT16 tensors using byte_data field");
|
||||
|
||||
C10_DEFINE_bool(
|
||||
caffe2_serialize_using_bytes_as_holder,
|
||||
false,
|
||||
"Serialize BOOL, UINT8, INT8, UINT16, INT16, INT64, FLOAT16 tensors using byte_data field instead of int32");
|
||||
|
||||
#ifdef _MSC_VER
|
||||
// It's MSVC, so we just have to guess ... and allow an override
|
||||
#ifdef FOLLY_ENDIAN_BE
|
||||
constexpr auto kIsLittleEndian = false;
|
||||
#else
|
||||
constexpr auto kIsLittleEndian = true;
|
||||
#endif
|
||||
#else
|
||||
constexpr auto kIsLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
|
||||
#endif
|
||||
|
||||
namespace caffe2 {
|
||||
/**
|
||||
* @brief StringSerializer is the serializer for String.
|
||||
@ -183,6 +199,51 @@ void TensorSerializer::SerializeWithChunkSize(
|
||||
#endif
|
||||
}
|
||||
|
||||
static bool EnableByteEncoding(
|
||||
const TensorProto::DataType& dataType,
|
||||
const size_t& typeSize) {
|
||||
// if typeSize == 1, endianness does not matter. Else check for endianness.
|
||||
bool ret = false;
|
||||
bool safeForEndianness = (typeSize == 1 || kIsLittleEndian);
|
||||
if (safeForEndianness) {
|
||||
ret = FLAGS_caffe2_serialize_using_bytes_as_holder;
|
||||
// Check if special casing for float is enabled if
|
||||
// caffe2_serialize_using_bytes_as_holder is not enabled.
|
||||
if (!ret) {
|
||||
ret =
|
||||
(dataType == TensorProto_DataType_FLOAT16 &&
|
||||
FLAGS_caffe2_serialize_fp16_as_bytes);
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T, typename S = T>
|
||||
static void SerializeUsingBytesOrInt32(
|
||||
const Tensor& input,
|
||||
const TensorProto::DataType& dataType,
|
||||
size_t chunkBegin,
|
||||
int32_t chunkSize,
|
||||
BaseContext* context,
|
||||
TensorProto& proto) {
|
||||
const auto typeSize = sizeof(T);
|
||||
if (EnableByteEncoding(dataType, typeSize)) {
|
||||
const auto bufSize = typeSize * chunkSize;
|
||||
auto* byteData =
|
||||
reinterpret_cast<const uint8_t*>(input.template data<S>() + chunkBegin);
|
||||
unique_ptr<uint8_t[]> buffer(new uint8_t[bufSize]);
|
||||
context->template CopyToCPU<uint8_t>(bufSize, byteData, buffer.get());
|
||||
context->FinishDeviceComputation();
|
||||
proto.set_byte_data(buffer.release(), bufSize);
|
||||
} else {
|
||||
detail::CopyToProtoWithCast(
|
||||
chunkSize,
|
||||
reinterpret_cast<const T*>(input.template data<S>()) + chunkBegin,
|
||||
proto.mutable_int32_data(),
|
||||
context);
|
||||
}
|
||||
}
|
||||
|
||||
void TensorSerializer::Serialize(
|
||||
const Tensor& input,
|
||||
const string& name,
|
||||
@ -255,39 +316,24 @@ void TensorSerializer::Serialize(
|
||||
break;
|
||||
}
|
||||
case TensorProto_DataType_BOOL:
|
||||
detail::CopyToProtoWithCast(
|
||||
chunkSize,
|
||||
input.template data<bool>() + chunkBegin,
|
||||
proto.mutable_int32_data(),
|
||||
uniq_ptr.get());
|
||||
SerializeUsingBytesOrInt32<bool>(
|
||||
input, data_type, chunkBegin, chunkSize, uniq_ptr.get(), proto);
|
||||
break;
|
||||
case TensorProto_DataType_UINT8:
|
||||
detail::CopyToProtoWithCast(
|
||||
chunkSize,
|
||||
input.template data<uint8_t>() + chunkBegin,
|
||||
proto.mutable_int32_data(),
|
||||
uniq_ptr.get());
|
||||
SerializeUsingBytesOrInt32<uint8_t>(
|
||||
input, data_type, chunkBegin, chunkSize, uniq_ptr.get(), proto);
|
||||
break;
|
||||
case TensorProto_DataType_INT8:
|
||||
detail::CopyToProtoWithCast(
|
||||
chunkSize,
|
||||
input.template data<int8_t>() + chunkBegin,
|
||||
proto.mutable_int32_data(),
|
||||
uniq_ptr.get());
|
||||
SerializeUsingBytesOrInt32<int8_t>(
|
||||
input, data_type, chunkBegin, chunkSize, uniq_ptr.get(), proto);
|
||||
break;
|
||||
case TensorProto_DataType_UINT16:
|
||||
detail::CopyToProtoWithCast(
|
||||
chunkSize,
|
||||
input.template data<uint16_t>() + chunkBegin,
|
||||
proto.mutable_int32_data(),
|
||||
uniq_ptr.get());
|
||||
SerializeUsingBytesOrInt32<uint16_t>(
|
||||
input, data_type, chunkBegin, chunkSize, uniq_ptr.get(), proto);
|
||||
break;
|
||||
case TensorProto_DataType_INT16:
|
||||
detail::CopyToProtoWithCast(
|
||||
chunkSize,
|
||||
input.template data<int16_t>() + chunkBegin,
|
||||
proto.mutable_int32_data(),
|
||||
uniq_ptr.get());
|
||||
SerializeUsingBytesOrInt32<int16_t>(
|
||||
input, data_type, chunkBegin, chunkSize, uniq_ptr.get(), proto);
|
||||
break;
|
||||
case TensorProto_DataType_INT64:
|
||||
detail::CopyToProtoAsIs(
|
||||
@ -296,31 +342,10 @@ void TensorSerializer::Serialize(
|
||||
proto.mutable_int64_data(),
|
||||
uniq_ptr.get());
|
||||
break;
|
||||
case TensorProto_DataType_FLOAT16: {
|
||||
if (FLAGS_caffe2_serialize_fp16_as_bytes) {
|
||||
const int kValue = 1;
|
||||
CAFFE_ENFORCE_EQ(
|
||||
reinterpret_cast<const char*>(&kValue)[0],
|
||||
1,
|
||||
"Serialization of FLOAT16 on big endian platform "
|
||||
"is not written yet.");
|
||||
unique_ptr<char[]> buffer(new char[2 * chunkSize]);
|
||||
this->context_->template CopyToCPU<char>(
|
||||
2 * chunkSize,
|
||||
reinterpret_cast<const char*>(
|
||||
input.template data<at::Half>() + chunkBegin),
|
||||
buffer.get());
|
||||
this->context_->FinishDeviceComputation();
|
||||
proto.set_byte_data(buffer.release(), 2 * chunkSize);
|
||||
} else {
|
||||
detail::CopyToProtoWithCast(
|
||||
chunkSize,
|
||||
reinterpret_cast<const uint16_t*>(input.template data<at::Half>()) +
|
||||
chunkBegin,
|
||||
proto.mutable_int32_data(),
|
||||
uniq_ptr.get());
|
||||
}
|
||||
} break;
|
||||
case TensorProto_DataType_FLOAT16:
|
||||
SerializeUsingBytesOrInt32<uint16_t, at::Half>(
|
||||
input, data_type, chunkBegin, chunkSize, uniq_ptr.get(), proto);
|
||||
break;
|
||||
case TensorProto_DataType_DOUBLE:
|
||||
detail::CopyToProtoAsIs(
|
||||
chunkSize,
|
||||
@ -482,6 +507,43 @@ void TensorDeserializer::Deserialize(const BlobProto& blob_proto, Blob* blob) {
|
||||
}
|
||||
}
|
||||
|
||||
// Deserializes one chunk of `tensor_proto` into `tensor`, reading from the
// compact byte_data field when present and falling back to the legacy
// int32_data field otherwise.
//
// T is the wire element type and D the in-memory element type; they differ
// only for FLOAT16 (serialized as uint16_t, stored as at::Half).
template <typename T, typename D = T>
void DeserializeFromBytesOrInt32(
    const TensorProto& tensor_proto,
    size_t chunkBegin,
    int32_t chunkSize,
    BaseContext* context,
    Tensor* tensor) {
  if (tensor_proto.has_byte_data()) {
    auto typeSize = sizeof(T);
    // Multi-byte elements were written in little-endian order; refuse to
    // reinterpret them on a big-endian host (single-byte types are safe).
    CAFFE_ENFORCE(
        kIsLittleEndian || typeSize == 1,
        "Serialization with bytes not supported on big endian platform.");
    size_t numElems = tensor_proto.byte_data().size();
    if (tensor_proto.data_type() == TensorProto_DataType_UINT8) {
      // For UINT8 with a segment, the element count comes from the segment
      // bounds rather than the byte_data length.
      // NOTE(review): only UINT8 gets this special case — presumably for
      // compatibility with an existing producer; confirm before extending.
      if (tensor_proto.has_segment()) {
        const auto& segment = tensor_proto.segment();
        numElems = segment.end() - segment.begin();
      }
    }
    // The byte payload must cover exactly chunkSize elements of T.
    CAFFE_ENFORCE_EQ(
        typeSize * chunkSize, numElems, "Incorrect proto field size.");
    const uint8_t* protoData =
        reinterpret_cast<const uint8_t*>(tensor_proto.byte_data().data());
    // Copy straight into the destination tensor at the chunk offset.
    context->template CopyToCPU<D>(
        chunkSize,
        reinterpret_cast<const D*>(protoData),
        tensor->template mutable_data<D>() + chunkBegin);
  } else {
    // Backward compatibility with models which used int32_data field
    detail::CopyFromProtoWithCast(
        chunkSize,
        tensor_proto.int32_data(),
        reinterpret_cast<T*>(tensor->template mutable_data<D>()) + chunkBegin,
        context);
  }
}
|
||||
|
||||
void TensorDeserializer::DeserializeToTensor(
|
||||
const TensorProto& tensor_proto,
|
||||
Tensor* tensor) {
|
||||
@ -548,39 +610,24 @@ void TensorDeserializer::DeserializeToTensor(
|
||||
}
|
||||
break;
|
||||
case TensorProto_DataType_BOOL:
|
||||
detail::CopyFromProtoWithCast(
|
||||
chunkSize,
|
||||
tensor_proto.int32_data(),
|
||||
tensor->template mutable_data<bool>() + chunkBegin,
|
||||
context);
|
||||
DeserializeFromBytesOrInt32<bool>(
|
||||
tensor_proto, chunkBegin, chunkSize, context, tensor);
|
||||
break;
|
||||
case TensorProto_DataType_UINT8:
|
||||
detail::CopyFromProtoWithCast(
|
||||
chunkSize,
|
||||
tensor_proto.int32_data(),
|
||||
tensor->template mutable_data<uint8_t>() + chunkBegin,
|
||||
context);
|
||||
DeserializeFromBytesOrInt32<uint8_t>(
|
||||
tensor_proto, chunkBegin, chunkSize, context, tensor);
|
||||
break;
|
||||
case TensorProto_DataType_INT8:
|
||||
detail::CopyFromProtoWithCast(
|
||||
chunkSize,
|
||||
tensor_proto.int32_data(),
|
||||
tensor->template mutable_data<int8_t>() + chunkBegin,
|
||||
context);
|
||||
DeserializeFromBytesOrInt32<int8_t>(
|
||||
tensor_proto, chunkBegin, chunkSize, context, tensor);
|
||||
break;
|
||||
case TensorProto_DataType_UINT16:
|
||||
detail::CopyFromProtoWithCast(
|
||||
chunkSize,
|
||||
tensor_proto.int32_data(),
|
||||
tensor->template mutable_data<uint16_t>() + chunkBegin,
|
||||
context);
|
||||
DeserializeFromBytesOrInt32<uint16_t>(
|
||||
tensor_proto, chunkBegin, chunkSize, context, tensor);
|
||||
break;
|
||||
case TensorProto_DataType_INT16:
|
||||
detail::CopyFromProtoWithCast(
|
||||
chunkSize,
|
||||
tensor_proto.int32_data(),
|
||||
tensor->template mutable_data<int16_t>() + chunkBegin,
|
||||
context);
|
||||
DeserializeFromBytesOrInt32<int16_t>(
|
||||
tensor_proto, chunkBegin, chunkSize, context, tensor);
|
||||
break;
|
||||
case TensorProto_DataType_INT64:
|
||||
detail::CopyFromProtoAsIs(
|
||||
@ -590,31 +637,8 @@ void TensorDeserializer::DeserializeToTensor(
|
||||
context);
|
||||
break;
|
||||
case TensorProto_DataType_FLOAT16:
|
||||
if (tensor_proto.has_byte_data()) {
|
||||
const int kValue = 1;
|
||||
CAFFE_ENFORCE_EQ(
|
||||
reinterpret_cast<const char*>(&kValue)[0],
|
||||
1,
|
||||
"Serialization of FLOAT16 on big endian platform "
|
||||
"is not written yet.");
|
||||
CAFFE_ENFORCE_EQ(
|
||||
2 * chunkSize,
|
||||
tensor_proto.byte_data().size(),
|
||||
"Incorrect proto field size.");
|
||||
context->template CopyToCPU<at::Half>(
|
||||
chunkSize,
|
||||
reinterpret_cast<const at::Half*>(tensor_proto.byte_data().data()),
|
||||
tensor->template mutable_data<at::Half>() + chunkBegin);
|
||||
} else {
|
||||
// Backward compatibility with models which used int32_data field
|
||||
detail::CopyFromProtoWithCast(
|
||||
chunkSize,
|
||||
tensor_proto.int32_data(),
|
||||
reinterpret_cast<uint16_t*>(
|
||||
tensor->template mutable_data<at::Half>()) +
|
||||
chunkBegin,
|
||||
context);
|
||||
}
|
||||
DeserializeFromBytesOrInt32<uint16_t, at::Half>(
|
||||
tensor_proto, chunkBegin, chunkSize, context, tensor);
|
||||
break;
|
||||
case TensorProto_DataType_DOUBLE:
|
||||
detail::CopyFromProtoAsIs(
|
||||
@ -666,13 +690,12 @@ std::string SerializeAsString_EnforceCheck(
|
||||
if (!error_location) {
|
||||
CAFFE_ENFORCE(result, "protobuf::SerializeToString failed");
|
||||
} else {
|
||||
CAFFE_ENFORCE(result,
|
||||
"protobuf::SerializeToString failed for ", error_location);
|
||||
CAFFE_ENFORCE(
|
||||
result, "protobuf::SerializeToString failed for ", error_location);
|
||||
}
|
||||
return serialize_output;
|
||||
}
|
||||
|
||||
|
||||
namespace {
|
||||
// Serialize Tensor
|
||||
REGISTER_BLOB_SERIALIZER((TypeMeta::Id<Tensor>()), TensorSerializer);
|
||||
@ -680,5 +703,5 @@ REGISTER_BLOB_DESERIALIZER(TensorCPU, TensorDeserializer);
|
||||
// Serialize std::string
|
||||
REGISTER_BLOB_SERIALIZER((TypeMeta::Id<std::string>()), StringSerializer);
|
||||
REGISTER_BLOB_DESERIALIZER(std::string, StringDeserializer);
|
||||
} // namespace
|
||||
} // namespace caffe2
|
||||
} // namespace
|
||||
} // namespace caffe2
|
||||
|
||||
@ -21,6 +21,7 @@
|
||||
C10_DEFINE_int64(caffe2_test_big_tensor_size, 100000000, "");
|
||||
C10_DECLARE_int(caffe2_tensor_chunk_size);
|
||||
C10_DECLARE_bool(caffe2_serialize_fp16_as_bytes);
|
||||
C10_DECLARE_bool(caffe2_serialize_using_bytes_as_holder);
|
||||
|
||||
namespace caffe2 {
|
||||
using namespace ::caffe2::db;
|
||||
@ -36,7 +37,7 @@ class BlobTestNonDefaultConstructible {
|
||||
BlobTestNonDefaultConstructible(int x) : val(x) {}
|
||||
int32_t val;
|
||||
};
|
||||
}
|
||||
} // namespace
|
||||
|
||||
CAFFE_KNOWN_TYPE(BlobTestFoo);
|
||||
CAFFE_KNOWN_TYPE(BlobTestBar);
|
||||
@ -236,8 +237,10 @@ TEST(TensorNonTypedTest, NonDefaultConstructible) {
|
||||
EnforceNotMet);
|
||||
}
|
||||
|
||||
template <typename T> class TensorCPUTest : public ::testing::Test {};
|
||||
template <typename T> class TensorCPUDeathTest : public ::testing::Test {};
|
||||
template <typename T>
|
||||
class TensorCPUTest : public ::testing::Test {};
|
||||
template <typename T>
|
||||
class TensorCPUDeathTest : public ::testing::Test {};
|
||||
typedef ::testing::Types<char, int, float> TensorTypes;
|
||||
TYPED_TEST_CASE(TensorCPUTest, TensorTypes);
|
||||
TYPED_TEST_CASE(TensorCPUDeathTest, TensorTypes);
|
||||
@ -359,7 +362,7 @@ TYPED_TEST(TensorCPUTest, TensorShareDataRawPointer) {
|
||||
dims[0] = 2;
|
||||
dims[1] = 3;
|
||||
dims[2] = 5;
|
||||
std::unique_ptr<TypeParam[]> raw_buffer(new TypeParam[2*3*5]);
|
||||
std::unique_ptr<TypeParam[]> raw_buffer(new TypeParam[2 * 3 * 5]);
|
||||
Tensor tensor(dims, CPU);
|
||||
tensor.ShareExternalPointer(raw_buffer.get());
|
||||
EXPECT_EQ(tensor.mutable_data<TypeParam>(), raw_buffer.get());
|
||||
@ -412,7 +415,6 @@ TYPED_TEST(TensorCPUTest, TensorAliasCanUseDifferentShapes) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
TYPED_TEST(TensorCPUTest, NoLongerAliassAfterNumelChanges) {
|
||||
vector<int> dims(3);
|
||||
dims[0] = 2;
|
||||
@ -461,7 +463,7 @@ TYPED_TEST(TensorCPUTest, KeepOnShrink) {
|
||||
EXPECT_TRUE(larger_ptr != nullptr);
|
||||
|
||||
// This check can fail when malloc() returns the same recently freed address
|
||||
//EXPECT_NE(ptr, larger_ptr);
|
||||
// EXPECT_NE(ptr, larger_ptr);
|
||||
|
||||
// Shrinking - will not reallocate
|
||||
tensor.Resize(1, 2, 4);
|
||||
@ -497,7 +499,7 @@ TYPED_TEST(TensorCPUTest, MaxKeepOnShrink) {
|
||||
EXPECT_TRUE(new_ptr != nullptr);
|
||||
|
||||
// This check can fail when malloc() returns the same recently freed address
|
||||
//EXPECT_NE(ptr, new_ptr);
|
||||
// EXPECT_NE(ptr, new_ptr);
|
||||
|
||||
// Restore default flags
|
||||
FLAGS_caffe2_max_keep_on_shrink_memory = LLONG_MAX;
|
||||
@ -971,7 +973,7 @@ class DummyTypeDeserializer : public BlobDeserializerBase {
|
||||
container->deserialize(proto);
|
||||
}
|
||||
};
|
||||
}
|
||||
} // namespace
|
||||
|
||||
CAFFE_KNOWN_TYPE(DummyType);
|
||||
|
||||
@ -1153,5 +1155,99 @@ TEST(TensorSerialization, MistakenlySerializingDtypeUninitializedTensor) {
|
||||
EXPECT_EQ(1, new_tensor.dim());
|
||||
}
|
||||
|
||||
// Builds a BlobProto holding a "test_feature" tensor of `numEl` elements of
// `dataType`, populated through the legacy int32_data field with random
// values clamped to the type's range. Results are memoized per data type
// unless `useCached` is false.
//
// NOTE(review): the static cache is not synchronized — acceptable for this
// single-threaded test helper only.
static caffe2::BlobProto CreateProtoWithInt32Data(
    const caffe2::TensorProto::DataType& dataType,
    size_t numEl,
    bool useCached = true) {
  static std::map<caffe2::TensorProto::DataType, caffe2::BlobProto> protos;
  if (useCached && protos.count(dataType)) {
    return protos[dataType];
  }
  caffe2::BlobProto proto;
  proto.set_type("Tensor");
  auto tensor = proto.mutable_tensor();
  tensor->add_dims(numEl);
  tensor->add_dims(1);
  tensor->set_data_type(dataType);
  tensor->set_name("test_feature");
  tensor->mutable_device_detail()->set_device_type(0);
  tensor->mutable_segment()->set_begin(0);
  tensor->mutable_segment()->set_end(numEl);
  for (size_t i = 0; i < numEl; ++i) {
    int32_t data = 0;
    switch (dataType) {
      case caffe2::TensorProto_DataType_INT32:
        data = static_cast<int32_t>(rand() % 0xffffffff);
        break;
      case caffe2::TensorProto_DataType_BOOL:
        // Bug fix: `rand() % 0x00000001` is `x % 1`, which is always 0, so
        // every BOOL element used to be false; `% 2` covers both values.
        data = static_cast<uint8_t>(rand() % 2);
        break;
      case caffe2::TensorProto_DataType_UINT8:
        data = static_cast<uint8_t>(rand() % 0x000000ff);
        break;
      case caffe2::TensorProto_DataType_INT8:
        data = static_cast<int8_t>(rand() % 0x000000ff);
        break;
      case caffe2::TensorProto_DataType_UINT16:
        data = static_cast<uint16_t>(rand() % 0x0000ffff);
        break;
      case caffe2::TensorProto_DataType_INT16:
        data = static_cast<int16_t>(rand() % 0x0000ffff);
        break;
      case caffe2::TensorProto_DataType_FLOAT16:
        // FLOAT16 payloads are opaque 16-bit patterns for round-trip tests.
        data = static_cast<uint16_t>(rand() % 0x0000ffff);
        break;
      default:
        // Unsupported types produce a proto with no int32_data entries.
        continue;
    }
    tensor->add_int32_data(data);
  }
  protos[dataType] = proto;
  return proto;
}
||||
|
||||
void TestDataType(
|
||||
const caffe2::TensorProto::DataType& dataType,
|
||||
std::string dataTypeName) {
|
||||
LOG(INFO) << dataTypeName;
|
||||
FLAGS_caffe2_serialize_using_bytes_as_holder = true;
|
||||
size_t numEl = 1000;
|
||||
// Proto with int32
|
||||
auto protoInt32 = CreateProtoWithInt32Data(dataType, numEl, false);
|
||||
caffe2::Blob blobInt32;
|
||||
DeserializeBlob(protoInt32, &blobInt32);
|
||||
auto serializedStr = SerializeBlob(blobInt32, protoInt32.name());
|
||||
caffe2::BlobProto protoBytes;
|
||||
// Proto with bytes
|
||||
protoBytes.ParseFromString(serializedStr);
|
||||
caffe2::Blob blobBytes;
|
||||
DeserializeBlob(protoBytes, &blobBytes);
|
||||
FLAGS_caffe2_serialize_using_bytes_as_holder = false;
|
||||
// Proto with int32 from proto with bytes
|
||||
protoBytes.ParseFromString(SerializeBlob(blobBytes, protoBytes.name()));
|
||||
EXPECT_EQ(numEl, protoInt32.tensor().int32_data_size());
|
||||
EXPECT_EQ(numEl, protoBytes.tensor().int32_data_size());
|
||||
for (int i = 0; i < numEl; ++i) {
|
||||
EXPECT_EQ(
|
||||
protoInt32.tensor().int32_data(i), protoBytes.tensor().int32_data(i));
|
||||
}
|
||||
}
|
||||
|
||||
// Exercises the bytes-as-holder serialization round trip for every data type
// that supports the byte encoding path.
TEST(TensorSerialization, TestCorrectness) {
  FLAGS_caffe2_serialize_using_bytes_as_holder = true;
  // Table of (enum value, human-readable label) pairs, run in order.
  struct Case {
    caffe2::TensorProto::DataType type;
    const char* name;
  };
  const Case cases[] = {
      {caffe2::TensorProto_DataType_INT32, "TensorProto_DataType_INT32"},
      {caffe2::TensorProto_DataType_BOOL, "TensorProto_DataType_BOOL"},
      {caffe2::TensorProto_DataType_UINT8, "TensorProto_DataType_UINT8"},
      {caffe2::TensorProto_DataType_INT8, "TensorProto_DataType_INT8"},
      {caffe2::TensorProto_DataType_UINT16, "TensorProto_DataType_UINT16"},
      {caffe2::TensorProto_DataType_INT16, "TensorProto_DataType_INT16"},
      {caffe2::TensorProto_DataType_FLOAT16, "TensorProto_DataType_FLOAT16"},
  };
  for (const auto& c : cases) {
    TestDataType(c.type, c.name);
  }
}
|
||||
|
||||
} // namespace
|
||||
} // namespace caffe2
|
||||
|
||||
Reference in New Issue
Block a user