mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-25 08:11:06 +08:00 
			
		
		
		
	[ModelLoading] Use byte encoding for uint8, fp16 etc. instead of int32 (#34343)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/34343 Use byte encoding for uint8, fp16 etc. instead of int32 in TensorProto serialization/deserialization tl;dr - fp16 tensor deserialization 12x faster, serialized size 25% lower - uint8 tensor deserialization 36x faster, serialized size 25% lower Test Plan: ``` ============================================================================ caffe2/caffe2/fb/predictor/ModelLoaderBenchmark.cpprelative time/iter iters/s ============================================================================ BlobProtoInt32DeserializationFloat16 12.37ms 80.82 BlobProtoByteDeserializationFloat16 1125.46% 1.10ms 909.64 ---------------------------------------------------------------------------- BlobProtoInt32DeserializationUInt8 17.57ms 56.92 BlobProtoByteDeserializationUInt8 3629.45% 484.02us 2.07K ============================================================================ ``` Reviewed By: yinghai Differential Revision: D20137451 fbshipit-source-id: 8ed4be2286a6d4c7e134fcb0832f22bc645039a1
This commit is contained in:
		
				
					committed by
					
						 Facebook Github Bot
						Facebook Github Bot
					
				
			
			
				
	
			
			
			
						parent
						
							98afce3c56
						
					
				
				
					commit
					879a90b322
				
			| @ -1,7 +1,7 @@ | ||||
| #include "caffe2/core/blob_serialization.h" | ||||
|  | ||||
| #include <sstream> | ||||
| #include <mutex> | ||||
| #include <sstream> | ||||
|  | ||||
| #include "caffe2/core/blob.h" | ||||
| #include "caffe2/utils/proto_utils.h" | ||||
| @ -21,6 +21,22 @@ C10_DEFINE_bool( | ||||
|     false, | ||||
|     "Serialize FLOAT16 tensors using byte_data field"); | ||||
|  | ||||
| C10_DEFINE_bool( | ||||
|     caffe2_serialize_using_bytes_as_holder, | ||||
|     false, | ||||
|     "Serialize BOOL, UINT8, INT8, UINT16, INT16, INT64, FLOAT16 tensors using byte_data field instead of int32"); | ||||
|  | ||||
| #ifdef _MSC_VER | ||||
| // It's MSVC, so we just have to guess ... and allow an override | ||||
| #ifdef FOLLY_ENDIAN_BE | ||||
| constexpr auto kIsLittleEndian = false; | ||||
| #else | ||||
| constexpr auto kIsLittleEndian = true; | ||||
| #endif | ||||
| #else | ||||
| constexpr auto kIsLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__; | ||||
| #endif | ||||
|  | ||||
| namespace caffe2 { | ||||
| /** | ||||
|  * @brief StringSerializer is the serializer for String. | ||||
| @ -183,6 +199,51 @@ void TensorSerializer::SerializeWithChunkSize( | ||||
| #endif | ||||
| } | ||||
|  | ||||
| static bool EnableByteEncoding( | ||||
|     const TensorProto::DataType& dataType, | ||||
|     const size_t& typeSize) { | ||||
|   // if typeSize == 1, endianness does not matter. Else check for endianness. | ||||
|   bool ret = false; | ||||
|   bool safeForEndianness = (typeSize == 1 || kIsLittleEndian); | ||||
|   if (safeForEndianness) { | ||||
|     ret = FLAGS_caffe2_serialize_using_bytes_as_holder; | ||||
|     // Check if special casing for float is enabled if | ||||
|     // caffe2_serialize_using_bytes_as_holder is not enabled. | ||||
|     if (!ret) { | ||||
|       ret = | ||||
|           (dataType == TensorProto_DataType_FLOAT16 && | ||||
|            FLAGS_caffe2_serialize_fp16_as_bytes); | ||||
|     } | ||||
|   } | ||||
|   return ret; | ||||
| } | ||||
|  | ||||
| template <typename T, typename S = T> | ||||
| static void SerializeUsingBytesOrInt32( | ||||
|     const Tensor& input, | ||||
|     const TensorProto::DataType& dataType, | ||||
|     size_t chunkBegin, | ||||
|     int32_t chunkSize, | ||||
|     BaseContext* context, | ||||
|     TensorProto& proto) { | ||||
|   const auto typeSize = sizeof(T); | ||||
|   if (EnableByteEncoding(dataType, typeSize)) { | ||||
|     const auto bufSize = typeSize * chunkSize; | ||||
|     auto* byteData = | ||||
|         reinterpret_cast<const uint8_t*>(input.template data<S>() + chunkBegin); | ||||
|     unique_ptr<uint8_t[]> buffer(new uint8_t[bufSize]); | ||||
|     context->template CopyToCPU<uint8_t>(bufSize, byteData, buffer.get()); | ||||
|     context->FinishDeviceComputation(); | ||||
|     proto.set_byte_data(buffer.release(), bufSize); | ||||
|   } else { | ||||
|     detail::CopyToProtoWithCast( | ||||
|         chunkSize, | ||||
|         reinterpret_cast<const T*>(input.template data<S>()) + chunkBegin, | ||||
|         proto.mutable_int32_data(), | ||||
|         context); | ||||
|   } | ||||
| } | ||||
|  | ||||
| void TensorSerializer::Serialize( | ||||
|     const Tensor& input, | ||||
|     const string& name, | ||||
| @ -255,39 +316,24 @@ void TensorSerializer::Serialize( | ||||
|       break; | ||||
|     } | ||||
|     case TensorProto_DataType_BOOL: | ||||
|       detail::CopyToProtoWithCast( | ||||
|           chunkSize, | ||||
|           input.template data<bool>() + chunkBegin, | ||||
|           proto.mutable_int32_data(), | ||||
|           uniq_ptr.get()); | ||||
|       SerializeUsingBytesOrInt32<bool>( | ||||
|           input, data_type, chunkBegin, chunkSize, uniq_ptr.get(), proto); | ||||
|       break; | ||||
|     case TensorProto_DataType_UINT8: | ||||
|       detail::CopyToProtoWithCast( | ||||
|           chunkSize, | ||||
|           input.template data<uint8_t>() + chunkBegin, | ||||
|           proto.mutable_int32_data(), | ||||
|           uniq_ptr.get()); | ||||
|       SerializeUsingBytesOrInt32<uint8_t>( | ||||
|           input, data_type, chunkBegin, chunkSize, uniq_ptr.get(), proto); | ||||
|       break; | ||||
|     case TensorProto_DataType_INT8: | ||||
|       detail::CopyToProtoWithCast( | ||||
|           chunkSize, | ||||
|           input.template data<int8_t>() + chunkBegin, | ||||
|           proto.mutable_int32_data(), | ||||
|           uniq_ptr.get()); | ||||
|       SerializeUsingBytesOrInt32<int8_t>( | ||||
|           input, data_type, chunkBegin, chunkSize, uniq_ptr.get(), proto); | ||||
|       break; | ||||
|     case TensorProto_DataType_UINT16: | ||||
|       detail::CopyToProtoWithCast( | ||||
|           chunkSize, | ||||
|           input.template data<uint16_t>() + chunkBegin, | ||||
|           proto.mutable_int32_data(), | ||||
|           uniq_ptr.get()); | ||||
|       SerializeUsingBytesOrInt32<uint16_t>( | ||||
|           input, data_type, chunkBegin, chunkSize, uniq_ptr.get(), proto); | ||||
|       break; | ||||
|     case TensorProto_DataType_INT16: | ||||
|       detail::CopyToProtoWithCast( | ||||
|           chunkSize, | ||||
|           input.template data<int16_t>() + chunkBegin, | ||||
|           proto.mutable_int32_data(), | ||||
|           uniq_ptr.get()); | ||||
|       SerializeUsingBytesOrInt32<int16_t>( | ||||
|           input, data_type, chunkBegin, chunkSize, uniq_ptr.get(), proto); | ||||
|       break; | ||||
|     case TensorProto_DataType_INT64: | ||||
|       detail::CopyToProtoAsIs( | ||||
| @ -296,31 +342,10 @@ void TensorSerializer::Serialize( | ||||
|           proto.mutable_int64_data(), | ||||
|           uniq_ptr.get()); | ||||
|       break; | ||||
|     case TensorProto_DataType_FLOAT16: { | ||||
|       if (FLAGS_caffe2_serialize_fp16_as_bytes) { | ||||
|         const int kValue = 1; | ||||
|         CAFFE_ENFORCE_EQ( | ||||
|             reinterpret_cast<const char*>(&kValue)[0], | ||||
|             1, | ||||
|             "Serialization of FLOAT16 on big endian platform " | ||||
|             "is not written yet."); | ||||
|         unique_ptr<char[]> buffer(new char[2 * chunkSize]); | ||||
|         this->context_->template CopyToCPU<char>( | ||||
|             2 * chunkSize, | ||||
|             reinterpret_cast<const char*>( | ||||
|                 input.template data<at::Half>() + chunkBegin), | ||||
|             buffer.get()); | ||||
|         this->context_->FinishDeviceComputation(); | ||||
|         proto.set_byte_data(buffer.release(), 2 * chunkSize); | ||||
|       } else { | ||||
|         detail::CopyToProtoWithCast( | ||||
|             chunkSize, | ||||
|             reinterpret_cast<const uint16_t*>(input.template data<at::Half>()) + | ||||
|                 chunkBegin, | ||||
|             proto.mutable_int32_data(), | ||||
|             uniq_ptr.get()); | ||||
|       } | ||||
|     } break; | ||||
|     case TensorProto_DataType_FLOAT16: | ||||
|       SerializeUsingBytesOrInt32<uint16_t, at::Half>( | ||||
|           input, data_type, chunkBegin, chunkSize, uniq_ptr.get(), proto); | ||||
|       break; | ||||
|     case TensorProto_DataType_DOUBLE: | ||||
|       detail::CopyToProtoAsIs( | ||||
|           chunkSize, | ||||
| @ -482,6 +507,43 @@ void TensorDeserializer::Deserialize(const BlobProto& blob_proto, Blob* blob) { | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename T, typename D = T> | ||||
| void DeserializeFromBytesOrInt32( | ||||
|     const TensorProto& tensor_proto, | ||||
|     size_t chunkBegin, | ||||
|     int32_t chunkSize, | ||||
|     BaseContext* context, | ||||
|     Tensor* tensor) { | ||||
|   if (tensor_proto.has_byte_data()) { | ||||
|     auto typeSize = sizeof(T); | ||||
|     CAFFE_ENFORCE( | ||||
|         kIsLittleEndian || typeSize == 1, | ||||
|         "Serialization with bytes not supported on big endian platform."); | ||||
|     size_t numElems = tensor_proto.byte_data().size(); | ||||
|     if (tensor_proto.data_type() == TensorProto_DataType_UINT8) { | ||||
|       if (tensor_proto.has_segment()) { | ||||
|         const auto& segment = tensor_proto.segment(); | ||||
|         numElems = segment.end() - segment.begin(); | ||||
|       } | ||||
|     } | ||||
|     CAFFE_ENFORCE_EQ( | ||||
|         typeSize * chunkSize, numElems, "Incorrect proto field size."); | ||||
|     const uint8_t* protoData = | ||||
|         reinterpret_cast<const uint8_t*>(tensor_proto.byte_data().data()); | ||||
|     context->template CopyToCPU<D>( | ||||
|         chunkSize, | ||||
|         reinterpret_cast<const D*>(protoData), | ||||
|         tensor->template mutable_data<D>() + chunkBegin); | ||||
|   } else { | ||||
|     // Backward compatibility with models which used int32_data field | ||||
|     detail::CopyFromProtoWithCast( | ||||
|         chunkSize, | ||||
|         tensor_proto.int32_data(), | ||||
|         reinterpret_cast<T*>(tensor->template mutable_data<D>()) + chunkBegin, | ||||
|         context); | ||||
|   } | ||||
| } | ||||
|  | ||||
| void TensorDeserializer::DeserializeToTensor( | ||||
|     const TensorProto& tensor_proto, | ||||
|     Tensor* tensor) { | ||||
| @ -548,39 +610,24 @@ void TensorDeserializer::DeserializeToTensor( | ||||
|       } | ||||
|       break; | ||||
|     case TensorProto_DataType_BOOL: | ||||
|       detail::CopyFromProtoWithCast( | ||||
|           chunkSize, | ||||
|           tensor_proto.int32_data(), | ||||
|           tensor->template mutable_data<bool>() + chunkBegin, | ||||
|           context); | ||||
|       DeserializeFromBytesOrInt32<bool>( | ||||
|           tensor_proto, chunkBegin, chunkSize, context, tensor); | ||||
|       break; | ||||
|     case TensorProto_DataType_UINT8: | ||||
|       detail::CopyFromProtoWithCast( | ||||
|           chunkSize, | ||||
|           tensor_proto.int32_data(), | ||||
|           tensor->template mutable_data<uint8_t>() + chunkBegin, | ||||
|           context); | ||||
|       DeserializeFromBytesOrInt32<uint8_t>( | ||||
|           tensor_proto, chunkBegin, chunkSize, context, tensor); | ||||
|       break; | ||||
|     case TensorProto_DataType_INT8: | ||||
|       detail::CopyFromProtoWithCast( | ||||
|           chunkSize, | ||||
|           tensor_proto.int32_data(), | ||||
|           tensor->template mutable_data<int8_t>() + chunkBegin, | ||||
|           context); | ||||
|       DeserializeFromBytesOrInt32<int8_t>( | ||||
|           tensor_proto, chunkBegin, chunkSize, context, tensor); | ||||
|       break; | ||||
|     case TensorProto_DataType_UINT16: | ||||
|       detail::CopyFromProtoWithCast( | ||||
|           chunkSize, | ||||
|           tensor_proto.int32_data(), | ||||
|           tensor->template mutable_data<uint16_t>() + chunkBegin, | ||||
|           context); | ||||
|       DeserializeFromBytesOrInt32<uint16_t>( | ||||
|           tensor_proto, chunkBegin, chunkSize, context, tensor); | ||||
|       break; | ||||
|     case TensorProto_DataType_INT16: | ||||
|       detail::CopyFromProtoWithCast( | ||||
|           chunkSize, | ||||
|           tensor_proto.int32_data(), | ||||
|           tensor->template mutable_data<int16_t>() + chunkBegin, | ||||
|           context); | ||||
|       DeserializeFromBytesOrInt32<int16_t>( | ||||
|           tensor_proto, chunkBegin, chunkSize, context, tensor); | ||||
|       break; | ||||
|     case TensorProto_DataType_INT64: | ||||
|       detail::CopyFromProtoAsIs( | ||||
| @ -590,31 +637,8 @@ void TensorDeserializer::DeserializeToTensor( | ||||
|           context); | ||||
|       break; | ||||
|     case TensorProto_DataType_FLOAT16: | ||||
|       if (tensor_proto.has_byte_data()) { | ||||
|         const int kValue = 1; | ||||
|         CAFFE_ENFORCE_EQ( | ||||
|             reinterpret_cast<const char*>(&kValue)[0], | ||||
|             1, | ||||
|             "Serialization of FLOAT16 on big endian platform " | ||||
|             "is not written yet."); | ||||
|         CAFFE_ENFORCE_EQ( | ||||
|             2 * chunkSize, | ||||
|             tensor_proto.byte_data().size(), | ||||
|             "Incorrect proto field size."); | ||||
|         context->template CopyToCPU<at::Half>( | ||||
|             chunkSize, | ||||
|             reinterpret_cast<const at::Half*>(tensor_proto.byte_data().data()), | ||||
|             tensor->template mutable_data<at::Half>() + chunkBegin); | ||||
|       } else { | ||||
|         // Backward compatibility with models which used int32_data field | ||||
|         detail::CopyFromProtoWithCast( | ||||
|             chunkSize, | ||||
|             tensor_proto.int32_data(), | ||||
|             reinterpret_cast<uint16_t*>( | ||||
|                 tensor->template mutable_data<at::Half>()) + | ||||
|                 chunkBegin, | ||||
|             context); | ||||
|       } | ||||
|       DeserializeFromBytesOrInt32<uint16_t, at::Half>( | ||||
|           tensor_proto, chunkBegin, chunkSize, context, tensor); | ||||
|       break; | ||||
|     case TensorProto_DataType_DOUBLE: | ||||
|       detail::CopyFromProtoAsIs( | ||||
| @ -666,13 +690,12 @@ std::string SerializeAsString_EnforceCheck( | ||||
|   if (!error_location) { | ||||
|     CAFFE_ENFORCE(result, "protobuf::SerializeToString failed"); | ||||
|   } else { | ||||
|     CAFFE_ENFORCE(result, | ||||
|         "protobuf::SerializeToString failed for ", error_location); | ||||
|     CAFFE_ENFORCE( | ||||
|         result, "protobuf::SerializeToString failed for ", error_location); | ||||
|   } | ||||
|   return serialize_output; | ||||
| } | ||||
|  | ||||
|  | ||||
| namespace { | ||||
| // Serialize Tensor | ||||
| REGISTER_BLOB_SERIALIZER((TypeMeta::Id<Tensor>()), TensorSerializer); | ||||
| @ -680,5 +703,5 @@ REGISTER_BLOB_DESERIALIZER(TensorCPU, TensorDeserializer); | ||||
| // Serialize std::string | ||||
| REGISTER_BLOB_SERIALIZER((TypeMeta::Id<std::string>()), StringSerializer); | ||||
| REGISTER_BLOB_DESERIALIZER(std::string, StringDeserializer); | ||||
| }  // namespace | ||||
| }  // namespace caffe2 | ||||
| } // namespace | ||||
| } // namespace caffe2 | ||||
|  | ||||
| @ -21,6 +21,7 @@ | ||||
| C10_DEFINE_int64(caffe2_test_big_tensor_size, 100000000, ""); | ||||
| C10_DECLARE_int(caffe2_tensor_chunk_size); | ||||
| C10_DECLARE_bool(caffe2_serialize_fp16_as_bytes); | ||||
| C10_DECLARE_bool(caffe2_serialize_using_bytes_as_holder); | ||||
|  | ||||
| namespace caffe2 { | ||||
| using namespace ::caffe2::db; | ||||
| @ -36,7 +37,7 @@ class BlobTestNonDefaultConstructible { | ||||
|   BlobTestNonDefaultConstructible(int x) : val(x) {} | ||||
|   int32_t val; | ||||
| }; | ||||
| } | ||||
| } // namespace | ||||
|  | ||||
| CAFFE_KNOWN_TYPE(BlobTestFoo); | ||||
| CAFFE_KNOWN_TYPE(BlobTestBar); | ||||
| @ -236,8 +237,10 @@ TEST(TensorNonTypedTest, NonDefaultConstructible) { | ||||
|       EnforceNotMet); | ||||
| } | ||||
|  | ||||
| template <typename T> class TensorCPUTest : public ::testing::Test {}; | ||||
| template <typename T> class TensorCPUDeathTest : public ::testing::Test {}; | ||||
| template <typename T> | ||||
| class TensorCPUTest : public ::testing::Test {}; | ||||
| template <typename T> | ||||
| class TensorCPUDeathTest : public ::testing::Test {}; | ||||
| typedef ::testing::Types<char, int, float> TensorTypes; | ||||
| TYPED_TEST_CASE(TensorCPUTest, TensorTypes); | ||||
| TYPED_TEST_CASE(TensorCPUDeathTest, TensorTypes); | ||||
| @ -359,7 +362,7 @@ TYPED_TEST(TensorCPUTest, TensorShareDataRawPointer) { | ||||
|   dims[0] = 2; | ||||
|   dims[1] = 3; | ||||
|   dims[2] = 5; | ||||
|   std::unique_ptr<TypeParam[]> raw_buffer(new TypeParam[2*3*5]); | ||||
|   std::unique_ptr<TypeParam[]> raw_buffer(new TypeParam[2 * 3 * 5]); | ||||
|   Tensor tensor(dims, CPU); | ||||
|   tensor.ShareExternalPointer(raw_buffer.get()); | ||||
|   EXPECT_EQ(tensor.mutable_data<TypeParam>(), raw_buffer.get()); | ||||
| @ -412,7 +415,6 @@ TYPED_TEST(TensorCPUTest, TensorAliasCanUseDifferentShapes) { | ||||
|   } | ||||
| } | ||||
|  | ||||
|  | ||||
| TYPED_TEST(TensorCPUTest, NoLongerAliassAfterNumelChanges) { | ||||
|   vector<int> dims(3); | ||||
|   dims[0] = 2; | ||||
| @ -461,7 +463,7 @@ TYPED_TEST(TensorCPUTest, KeepOnShrink) { | ||||
|   EXPECT_TRUE(larger_ptr != nullptr); | ||||
|  | ||||
|   // This check can fail when malloc() returns the same recently freed address | ||||
|   //EXPECT_NE(ptr, larger_ptr); | ||||
|   // EXPECT_NE(ptr, larger_ptr); | ||||
|  | ||||
|   // Shrinking - will not reallocate | ||||
|   tensor.Resize(1, 2, 4); | ||||
| @ -497,7 +499,7 @@ TYPED_TEST(TensorCPUTest, MaxKeepOnShrink) { | ||||
|   EXPECT_TRUE(new_ptr != nullptr); | ||||
|  | ||||
|   // This check can fail when malloc() returns the same recently freed address | ||||
|   //EXPECT_NE(ptr, new_ptr); | ||||
|   // EXPECT_NE(ptr, new_ptr); | ||||
|  | ||||
|   // Restore default flags | ||||
|   FLAGS_caffe2_max_keep_on_shrink_memory = LLONG_MAX; | ||||
| @ -971,7 +973,7 @@ class DummyTypeDeserializer : public BlobDeserializerBase { | ||||
|     container->deserialize(proto); | ||||
|   } | ||||
| }; | ||||
| } | ||||
| } // namespace | ||||
|  | ||||
| CAFFE_KNOWN_TYPE(DummyType); | ||||
|  | ||||
| @ -1153,5 +1155,99 @@ TEST(TensorSerialization, MistakenlySerializingDtypeUninitializedTensor) { | ||||
|   EXPECT_EQ(1, new_tensor.dim()); | ||||
| } | ||||
|  | ||||
| static caffe2::BlobProto CreateProtoWithInt32Data( | ||||
|     const caffe2::TensorProto::DataType& dataType, | ||||
|     size_t numEl, | ||||
|     bool useCached = true) { | ||||
|   static std::map<caffe2::TensorProto::DataType, caffe2::BlobProto> protos; | ||||
|   if (useCached && protos.count(dataType)) { | ||||
|     return protos[dataType]; | ||||
|   } | ||||
|   caffe2::BlobProto proto; | ||||
|   proto.set_type("Tensor"); | ||||
|   auto tensor = proto.mutable_tensor(); | ||||
|   tensor->add_dims(numEl); | ||||
|   tensor->add_dims(1); | ||||
|   tensor->set_data_type(dataType); | ||||
|   tensor->set_name("test_feature"); | ||||
|   tensor->mutable_device_detail()->set_device_type(0); | ||||
|   tensor->mutable_segment()->set_begin(0); | ||||
|   tensor->mutable_segment()->set_end(numEl); | ||||
|   for (size_t i = 0; i < numEl; ++i) { | ||||
|     int32_t data = 0; | ||||
|     switch (dataType) { | ||||
|       case caffe2::TensorProto_DataType_INT32: | ||||
|         data = static_cast<int32_t>(rand() % 0xffffffff); | ||||
|         break; | ||||
|       case caffe2::TensorProto_DataType_BOOL: | ||||
|         data = static_cast<uint8_t>(rand() % 0x00000001); | ||||
|         break; | ||||
|       case caffe2::TensorProto_DataType_UINT8: | ||||
|         data = static_cast<uint8_t>(rand() % 0x000000ff); | ||||
|         break; | ||||
|       case caffe2::TensorProto_DataType_INT8: | ||||
|         data = static_cast<int8_t>(rand() % 0x000000ff); | ||||
|         break; | ||||
|       case caffe2::TensorProto_DataType_UINT16: | ||||
|         data = static_cast<uint16_t>(rand() % 0x0000ffff); | ||||
|         break; | ||||
|       case caffe2::TensorProto_DataType_INT16: | ||||
|         data = static_cast<int16_t>(rand() % 0x0000ffff); | ||||
|         break; | ||||
|       case caffe2::TensorProto_DataType_FLOAT16: | ||||
|         data = static_cast<uint16_t>(rand() % 0x0000ffff); | ||||
|         break; | ||||
|       default: | ||||
|         continue; | ||||
|     } | ||||
|     tensor->add_int32_data(data); | ||||
|   } | ||||
|   protos[dataType] = proto; | ||||
|   return proto; | ||||
| } | ||||
|  | ||||
| void TestDataType( | ||||
|     const caffe2::TensorProto::DataType& dataType, | ||||
|     std::string dataTypeName) { | ||||
|   LOG(INFO) << dataTypeName; | ||||
|   FLAGS_caffe2_serialize_using_bytes_as_holder = true; | ||||
|   size_t numEl = 1000; | ||||
|   // Proto with int32 | ||||
|   auto protoInt32 = CreateProtoWithInt32Data(dataType, numEl, false); | ||||
|   caffe2::Blob blobInt32; | ||||
|   DeserializeBlob(protoInt32, &blobInt32); | ||||
|   auto serializedStr = SerializeBlob(blobInt32, protoInt32.name()); | ||||
|   caffe2::BlobProto protoBytes; | ||||
|   // Proto with bytes | ||||
|   protoBytes.ParseFromString(serializedStr); | ||||
|   caffe2::Blob blobBytes; | ||||
|   DeserializeBlob(protoBytes, &blobBytes); | ||||
|   FLAGS_caffe2_serialize_using_bytes_as_holder = false; | ||||
|   // Proto with int32 from proto with bytes | ||||
|   protoBytes.ParseFromString(SerializeBlob(blobBytes, protoBytes.name())); | ||||
|   EXPECT_EQ(numEl, protoInt32.tensor().int32_data_size()); | ||||
|   EXPECT_EQ(numEl, protoBytes.tensor().int32_data_size()); | ||||
|   for (int i = 0; i < numEl; ++i) { | ||||
|     EXPECT_EQ( | ||||
|         protoInt32.tensor().int32_data(i), protoBytes.tensor().int32_data(i)); | ||||
|   } | ||||
| } | ||||
|  | ||||
| TEST(TensorSerialization, TestCorrectness) { | ||||
|   FLAGS_caffe2_serialize_using_bytes_as_holder = true; | ||||
|   TestDataType( | ||||
|       caffe2::TensorProto_DataType_INT32, "TensorProto_DataType_INT32"); | ||||
|   TestDataType(caffe2::TensorProto_DataType_BOOL, "TensorProto_DataType_BOOL"); | ||||
|   TestDataType( | ||||
|       caffe2::TensorProto_DataType_UINT8, "TensorProto_DataType_UINT8"); | ||||
|   TestDataType(caffe2::TensorProto_DataType_INT8, "TensorProto_DataType_INT8"); | ||||
|   TestDataType( | ||||
|       caffe2::TensorProto_DataType_UINT16, "TensorProto_DataType_UINT16"); | ||||
|   TestDataType( | ||||
|       caffe2::TensorProto_DataType_INT16, "TensorProto_DataType_INT16"); | ||||
|   TestDataType( | ||||
|       caffe2::TensorProto_DataType_FLOAT16, "TensorProto_DataType_FLOAT16"); | ||||
| } | ||||
|  | ||||
| } // namespace | ||||
| } // namespace caffe2 | ||||
|  | ||||
		Reference in New Issue
	
	Block a user