#ifndef CAFFE2_CORE_BLOB_SERIALIZATION_H_ #define CAFFE2_CORE_BLOB_SERIALIZATION_H_ #include #include #include #include "caffe2/core/blob.h" #include "caffe2/core/blob_serializer_base.h" #include "caffe2/core/tensor.h" #include "caffe2/core/typeid.h" #include "caffe2/core/types.h" CAFFE2_DECLARE_int(caffe2_tensor_chunk_size); namespace caffe2 { // The Blob serialization registry and serializer creator functions. CAFFE_DECLARE_TYPED_REGISTRY( BlobSerializerRegistry, CaffeTypeId, BlobSerializerBase); #define REGISTER_BLOB_SERIALIZER(id, ...) \ CAFFE_REGISTER_TYPED_CLASS(BlobSerializerRegistry, id, __VA_ARGS__) // Creates an operator with the given operator definition. inline unique_ptr CreateSerializer(CaffeTypeId id) { return BlobSerializerRegistry()->Create(id); } /** * @brief TensorSerializer is the serializer for Tensors. * * TensorSerializer takes in a blob that contains a Tensor, and serializes it * into a TensorProto protocol buffer. */ template class TensorSerializer : public BlobSerializerBase { public: TensorSerializer() : context_() {} ~TensorSerializer() {} /** * Serializes a Blob. Note that this blob has to contain Tensor, * otherwise this function produces a fatal error. */ void Serialize( const Blob& blob, const string& name, SerializationAcceptor acceptor) override; void Serialize(const Tensor& tensor, const string& name, TensorProto* proto, size_t chunkBegin, int32_t chunkSize); private: // A utility function to store the device context detauls. void StoreDeviceDetail(const Tensor& input, TensorProto* proto); Context context_; }; /** * @brief BlobDeserializerBase is an abstract class that deserializes a blob * from a BlobProto or a TensorProto. */ class BlobDeserializerBase { public: virtual ~BlobDeserializerBase() {} // Deserializes from a BlobProto object. virtual bool Deserialize(const BlobProto& proto, Blob* blob) = 0; }; CAFFE_DECLARE_REGISTRY(BlobDeserializerRegistry, BlobDeserializerBase); #define REGISTER_BLOB_DESERIALIZER(name, ...) \ CAFFE_REGISTER_CLASS(BlobDeserializerRegistry, name, __VA_ARGS__) // Creates an operator with the given operator definition. inline unique_ptr CreateDeserializer(const string& type) { return BlobDeserializerRegistry()->Create(type); } /** * @brief TensorDeserializer is the deserializer for Tensors. * * The device that the deserialized Tensor will live under is determined by the * device_detail field. If you want to specify the device of the deserialized * tensor, change the TensorProto's corresponding fields before calling * Deserialize. */ template class TensorDeserializer : public BlobDeserializerBase { public: bool Deserialize(const BlobProto& proto, Blob* blob) override; bool Deserialize(const TensorProto& proto, Tensor* tensor); }; //////////////////////////////////////////////////////////////////////////////// // Implementations //////////////////////////////////////////////////////////////////////////////// namespace detail { template inline void CopyToProtoAsIs( const size_t size, const SrcType* src, google::protobuf::RepeatedField* field, Context* context) { static_assert( sizeof(SrcType) == sizeof(DstType), "The source type and dest type cannot be copied as-is. Did " "you mean CopyToProtoWithCast?"); field->Reserve(size); for (int i = 0; i < size; ++i) { field->Add(0); } context->template Copy( size, src, reinterpret_cast(field->mutable_data())); // Make sure that we finish the copy into the protobuf. context->FinishDeviceComputation(); } template inline void CopyToProtoWithCast( const size_t size, const SrcType* src, google::protobuf::RepeatedField* field, Context* context) { // TODO: we are having one unnecessary copy here if the context is already // CPUContext. Remove it if it is performance critical. unique_ptr buffer(new SrcType[size]); context->template Copy( size, src, buffer.get()); context->FinishDeviceComputation(); field->Reserve(size); for (int i = 0; i < size; ++i) { field->Add(static_cast(buffer[i])); } } template inline void CopyFromProtoAsIs( const size_t size, const google::protobuf::RepeatedField& field, DstType* dst, Context* context) { static_assert( sizeof(SrcType) == sizeof(DstType), "The source type and dest type cannot be copied as-is. Did " "you mean CopyFromProtoWithCast?"); CHECK_EQ(size, field.size()) << "Incorrect proto field size."; context->template Copy( size, reinterpret_cast(field.data()), dst); } template inline void CopyFromProtoWithCast( const size_t size, const google::protobuf::RepeatedField& field, DstType* dst, Context* context) { CHECK_EQ(size, field.size()) << "Incorrect proto field size."; // TODO: we are having one unnecessary copy here if the context is already // CPUContext. Remove it if it is performance critical. unique_ptr buffer(new DstType[size]); const SrcType* src = field.data(); for (int i = 0; i < size; ++i) { buffer[i] = static_cast(src[i]); } context->template Copy(size, buffer.get(), dst); } } // namespace detail template void TensorSerializer::Serialize( const Blob& blob, const string& name, BlobSerializerBase::SerializationAcceptor acceptor) { CHECK(blob.IsType>()); const auto& tensor = blob.template Get>(); #ifndef __ANDROID__ std::vector> futures; #endif for (size_t chunkBegin = 0; chunkBegin < tensor.size(); chunkBegin += FLAGS_caffe2_tensor_chunk_size) { auto task = [&](size_t chunkBegin) { BlobProto blob_proto; blob_proto.set_name(name); blob_proto.set_type("Tensor"); TensorProto& proto = *blob_proto.mutable_tensor(); proto.set_name(name); this->Serialize( tensor, name, blob_proto.mutable_tensor(), chunkBegin, FLAGS_caffe2_tensor_chunk_size); acceptor(name, blob_proto.SerializeAsString()); }; #ifndef __ANDROID__ if (tensor.size() > FLAGS_caffe2_tensor_chunk_size) { futures.emplace_back(std::async(std::launch::async, task, chunkBegin)); } else { // Sync mode for small tensors task(chunkBegin); } #else // Since Android does not have std::future, we will always do sync mode // task(chunkBegin); #endif } #ifndef __ANDROID__ for (auto& fut : futures) { fut.get(); } #endif } template void TensorSerializer::Serialize( const Tensor& input, const string& name, TensorProto* proto_ptr, size_t chunkBegin, int32_t chunkSize) { CAFFE_ENFORCE( chunkBegin < input.size(), "Chunk begin is out of tensor: ", chunkBegin, ' ', input.size()); if (chunkBegin + chunkSize > input.size()) { chunkSize = input.size() - chunkBegin; } TensorProto& proto = *proto_ptr; proto.mutable_segment()->set_begin(chunkBegin); proto.mutable_segment()->set_end(chunkBegin + chunkSize); for (int i = 0; i < input.ndim(); ++i) { proto.add_dims(input.dim(i)); } const TensorProto::DataType data_type = TypeMetaToDataType(input.meta()); proto.set_data_type(data_type); // A lot of copypaste is error prone. Should we create a macro for this? switch (data_type) { case TensorProto_DataType_FLOAT: detail::CopyToProtoAsIs( chunkSize, input.template data() + chunkBegin, proto.mutable_float_data(), &this->context_); break; case TensorProto_DataType_INT32: detail::CopyToProtoAsIs( chunkSize, input.template data() + chunkBegin, proto.mutable_int32_data(), &this->context_); break; case TensorProto_DataType_BYTE: LOG(FATAL) << "This should not happen. When serializing, " "BYTE is deprecated and moved to UINT8."; break; case TensorProto_DataType_STRING: { proto.mutable_string_data()->Reserve(chunkSize); const string* content = input.template data(); for (int i = chunkBegin; i < chunkBegin + chunkSize; ++i) { proto.add_string_data(content[i]); } break; } case TensorProto_DataType_BOOL: detail::CopyToProtoWithCast( chunkSize, input.template data() + chunkBegin, proto.mutable_int32_data(), &this->context_); break; case TensorProto_DataType_UINT8: detail::CopyToProtoWithCast( chunkSize, input.template data() + chunkBegin, proto.mutable_int32_data(), &this->context_); break; case TensorProto_DataType_INT8: detail::CopyToProtoWithCast( chunkSize, input.template data() + chunkBegin, proto.mutable_int32_data(), &this->context_); break; case TensorProto_DataType_UINT16: detail::CopyToProtoWithCast( chunkSize, input.template data() + chunkBegin, proto.mutable_int32_data(), &this->context_); break; case TensorProto_DataType_INT16: detail::CopyToProtoWithCast( chunkSize, input.template data() + chunkBegin, proto.mutable_int32_data(), &this->context_); break; case TensorProto_DataType_INT64: detail::CopyToProtoAsIs( chunkSize, input.template data() + chunkBegin, proto.mutable_int64_data(), &this->context_); break; case TensorProto_DataType_FLOAT16: detail::CopyToProtoWithCast( chunkSize, reinterpret_cast(input.template data()) + chunkBegin, proto.mutable_int32_data(), &this->context_); break; case TensorProto_DataType_DOUBLE: detail::CopyToProtoAsIs( chunkSize, input.template data() + chunkBegin, proto.mutable_double_data(), &this->context_); break; case TensorProto_DataType_UNDEFINED: LOG(FATAL) << "TensorSerializer does not have a serialization " "implementation for " << input.meta().name(); break; // Note: we intentially do not provide "default:" so if any new data types // are added, the compiler should warn the user to add the case here. } StoreDeviceDetail(input, &proto); } template bool TensorDeserializer::Deserialize( const BlobProto& blob_proto, Blob* blob) { return Deserialize( blob_proto.tensor(), blob->GetMutable>()); } template bool TensorDeserializer::Deserialize( const TensorProto& proto, Tensor* tensor) { // We create a local context for deserializing. Since Caffe2 contexts are // usually lightweighted, this should not involve too much overhead. Context context(proto.device_detail()); context.SwitchToDevice(); vector dims; for (const TIndex d : proto.dims()) { dims.push_back(d); } tensor->Resize(dims); // Safety check for zero-sized tensors: no copy needed. if (tensor->size() == 0) { return true; } int64_t chunkBegin = 0; auto chunkEnd = tensor->size(); if (proto.has_segment()) { chunkBegin = proto.segment().begin(); chunkEnd = proto.segment().end(); } CAFFE_ENFORCE( 0 <= chunkBegin && chunkBegin < chunkEnd && chunkEnd <= tensor->size(), "Invalid chunk ", chunkBegin, ' ', chunkEnd, " with total tensor size ", tensor->size()); auto chunkSize = chunkEnd - chunkBegin; switch (proto.data_type()) { case TensorProto_DataType_FLOAT: detail::CopyFromProtoAsIs( chunkSize, proto.float_data(), tensor->template mutable_data() + chunkBegin, &context); break; case TensorProto_DataType_INT32: detail::CopyFromProtoAsIs( chunkSize, proto.int32_data(), tensor->template mutable_data() + chunkBegin, &context); break; case TensorProto_DataType_BYTE: // Since BYTE stores the data in a string field instead of a repreated // field we will have it special cased. if (chunkSize != proto.byte_data().size()) { LOG(ERROR) << "Incorrect proto field size."; return false; } context.template Copy( chunkSize, reinterpret_cast(proto.byte_data().data()), tensor->template mutable_data() + chunkBegin); break; case TensorProto_DataType_STRING: // Special handing of string because it is a non-fundamental type. { string* content = tensor->template mutable_data(); for (int i = 0; i < chunkSize; ++i) { content[i + chunkBegin] = proto.string_data(i); } } break; case TensorProto_DataType_BOOL: detail::CopyFromProtoWithCast( chunkSize, proto.int32_data(), tensor->template mutable_data() + chunkBegin, &context); break; case TensorProto_DataType_UINT8: detail::CopyFromProtoWithCast( chunkSize, proto.int32_data(), tensor->template mutable_data() + chunkBegin, &context); break; case TensorProto_DataType_INT8: detail::CopyFromProtoWithCast( chunkSize, proto.int32_data(), tensor->template mutable_data() + chunkBegin, &context); break; case TensorProto_DataType_UINT16: detail::CopyFromProtoWithCast( chunkSize, proto.int32_data(), tensor->template mutable_data() + chunkBegin, &context); break; case TensorProto_DataType_INT16: detail::CopyFromProtoWithCast( chunkSize, proto.int32_data(), tensor->template mutable_data() + chunkBegin, &context); break; case TensorProto_DataType_INT64: detail::CopyFromProtoAsIs( chunkSize, proto.int64_data(), tensor->template mutable_data() + chunkBegin, &context); break; case TensorProto_DataType_FLOAT16: detail::CopyFromProtoWithCast( chunkSize, proto.int32_data(), reinterpret_cast( tensor->template mutable_data()) + chunkBegin, &context); break; case TensorProto_DataType_DOUBLE: detail::CopyFromProtoAsIs( chunkSize, proto.double_data(), tensor->template mutable_data() + chunkBegin, &context); break; case TensorProto_DataType_UNDEFINED: LOG(ERROR) << "Cannot deserialize from a TensorProto UNDEFINED data type."; return false; } context.FinishDeviceComputation(); return true; } } // namespace caffe2 #endif // CAFFE2_CORE_BLOB_SERIALIZATION_H_