mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
489 lines
15 KiB
C++
489 lines
15 KiB
C++
#ifndef CAFFE2_CORE_BLOB_SERIALIZATION_H_
|
|
#define CAFFE2_CORE_BLOB_SERIALIZATION_H_
|
|
|
|
#include <limits>
|
|
#include <future>
|
|
|
|
#include <google/protobuf/repeated_field.h>
|
|
|
|
#include "caffe2/core/blob.h"
|
|
#include "caffe2/core/blob_serializer_base.h"
|
|
#include "caffe2/core/tensor.h"
|
|
#include "caffe2/core/typeid.h"
|
|
#include "caffe2/core/types.h"
|
|
|
|
CAFFE2_DECLARE_int(caffe2_tensor_chunk_size);
|
|
|
|
namespace caffe2 {
|
|
|
|
constexpr auto kTensorBlobType = "Tensor";
|
|
|
|
// The Blob serialization registry and serializer creator functions.
//
// Serializers are keyed by the CaffeTypeId of the type stored in the Blob,
// so each registered type gets its own BlobSerializerBase implementation.
CAFFE_DECLARE_TYPED_REGISTRY(
    BlobSerializerRegistry,
    CaffeTypeId,
    BlobSerializerBase);
// Registers a serializer class for the given CaffeTypeId.
#define REGISTER_BLOB_SERIALIZER(id, ...) \
  CAFFE_REGISTER_TYPED_CLASS(BlobSerializerRegistry, id, __VA_ARGS__)
|
|
// Creates an operator with the given operator definition.
|
|
inline unique_ptr<BlobSerializerBase> CreateSerializer(CaffeTypeId id) {
|
|
return BlobSerializerRegistry()->Create(id);
|
|
}
|
|
|
|
/**
|
|
* @brief TensorSerializer is the serializer for Tensors.
|
|
*
|
|
* TensorSerializer takes in a blob that contains a Tensor, and serializes it
|
|
* into a TensorProto protocol buffer.
|
|
*/
|
|
/**
 * @brief TensorSerializer is the serializer for Tensors.
 *
 * TensorSerializer takes in a blob that contains a Tensor, and serializes it
 * into a TensorProto protocol buffer.
 */
template <class Context>
class TensorSerializer : public BlobSerializerBase {
 public:
  TensorSerializer() : context_() {}
  ~TensorSerializer() {}
  /**
   * Serializes a Blob. Note that this blob has to contain Tensor<Context>,
   * otherwise this function produces a fatal error.
   */
  void Serialize(
      const Blob& blob,
      const string& name,
      SerializationAcceptor acceptor) override;
  // Serializes a single chunk of the tensor, [chunkBegin, chunkBegin +
  // chunkSize), into `proto`, recording the chunk boundaries in the proto's
  // segment field. Used by the Blob overload above to emit one proto per
  // chunk.
  void Serialize(const Tensor<Context>& tensor, const string& name,
      TensorProto* proto, size_t chunkBegin, int32_t chunkSize);

 private:
  // A utility function to store the device context details.
  void StoreDeviceDetail(const Tensor<Context>& input, TensorProto* proto);
  // Context used for device-to-CPU copies during serialization.
  Context context_;
};
|
|
|
|
/**
|
|
* @brief BlobDeserializerBase is an abstract class that deserializes a blob
|
|
* from a BlobProto or a TensorProto.
|
|
*/
|
|
class BlobDeserializerBase {
 public:
  virtual ~BlobDeserializerBase() {}

  // Deserializes from a BlobProto object into the given blob. Implementations
  // should return true on success and false on failure (see e.g. the tensor
  // deserializer below, which returns false on malformed protos).
  virtual bool Deserialize(const BlobProto& proto, Blob* blob) = 0;
};
|
|
|
|
// The Blob deserialization registry, keyed by the string type name that the
// serializer stored via BlobProto's type field (e.g. kTensorBlobType).
CAFFE_DECLARE_REGISTRY(BlobDeserializerRegistry, BlobDeserializerBase);
// Registers a deserializer class under the given string type name.
#define REGISTER_BLOB_DESERIALIZER(name, ...) \
  CAFFE_REGISTER_CLASS(BlobDeserializerRegistry, name, __VA_ARGS__)
|
|
// Creates an operator with the given operator definition.
|
|
inline unique_ptr<BlobDeserializerBase> CreateDeserializer(const string& type) {
|
|
return BlobDeserializerRegistry()->Create(type);
|
|
}
|
|
|
|
/**
|
|
* @brief TensorDeserializer is the deserializer for Tensors.
|
|
*
|
|
* The device that the deserialized Tensor will live under is determined by the
|
|
* device_detail field. If you want to specify the device of the deserialized
|
|
* tensor, change the TensorProto's corresponding fields before calling
|
|
* Deserialize.
|
|
*/
|
|
template <class Context>
class TensorDeserializer : public BlobDeserializerBase {
 public:
  // Deserializes the tensor stored in the BlobProto into the blob's
  // Tensor<Context>.
  bool Deserialize(const BlobProto& proto, Blob* blob) override;
  // Deserializes directly from a TensorProto; the target device comes from
  // the proto's device_detail field (see the class-level comment above).
  bool Deserialize(const TensorProto& proto, Tensor<Context>* tensor);
};
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
// Implementations
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
namespace detail {
|
|
template <typename SrcType, typename DstType, class Context>
|
|
inline void CopyToProtoAsIs(
|
|
const size_t size,
|
|
const SrcType* src,
|
|
google::protobuf::RepeatedField<DstType>* field,
|
|
Context* context) {
|
|
static_assert(
|
|
sizeof(SrcType) == sizeof(DstType),
|
|
"The source type and dest type cannot be copied as-is. Did "
|
|
"you mean CopyToProtoWithCast?");
|
|
field->Reserve(size);
|
|
for (int i = 0; i < size; ++i) {
|
|
field->Add(0);
|
|
}
|
|
context->template Copy<SrcType, Context, CPUContext>(
|
|
size, src, reinterpret_cast<SrcType*>(field->mutable_data()));
|
|
// Make sure that we finish the copy into the protobuf.
|
|
context->FinishDeviceComputation();
|
|
}
|
|
|
|
template <typename SrcType, typename DstType, class Context>
|
|
inline void CopyToProtoWithCast(
|
|
const size_t size,
|
|
const SrcType* src,
|
|
google::protobuf::RepeatedField<DstType>* field,
|
|
Context* context) {
|
|
// TODO: we are having one unnecessary copy here if the context is already
|
|
// CPUContext. Remove it if it is performance critical.
|
|
unique_ptr<SrcType[]> buffer(new SrcType[size]);
|
|
context->template Copy<SrcType, Context, CPUContext>(
|
|
size, src, buffer.get());
|
|
context->FinishDeviceComputation();
|
|
field->Reserve(size);
|
|
for (int i = 0; i < size; ++i) {
|
|
field->Add(static_cast<DstType>(buffer[i]));
|
|
}
|
|
}
|
|
|
|
template <typename SrcType, typename DstType, class Context>
|
|
inline void CopyFromProtoAsIs(
|
|
const size_t size,
|
|
const google::protobuf::RepeatedField<SrcType>& field,
|
|
DstType* dst,
|
|
Context* context) {
|
|
static_assert(
|
|
sizeof(SrcType) == sizeof(DstType),
|
|
"The source type and dest type cannot be copied as-is. Did "
|
|
"you mean CopyFromProtoWithCast?");
|
|
CHECK_EQ(size, field.size()) << "Incorrect proto field size.";
|
|
context->template Copy<DstType, CPUContext, Context>(
|
|
size, reinterpret_cast<const DstType*>(field.data()), dst);
|
|
}
|
|
|
|
template <typename SrcType, typename DstType, class Context>
|
|
inline void CopyFromProtoWithCast(
|
|
const size_t size,
|
|
const google::protobuf::RepeatedField<SrcType>& field,
|
|
DstType* dst,
|
|
Context* context) {
|
|
CHECK_EQ(size, field.size()) << "Incorrect proto field size.";
|
|
// TODO: we are having one unnecessary copy here if the context is already
|
|
// CPUContext. Remove it if it is performance critical.
|
|
unique_ptr<DstType[]> buffer(new DstType[size]);
|
|
const SrcType* src = field.data();
|
|
for (int i = 0; i < size; ++i) {
|
|
buffer[i] = static_cast<DstType>(src[i]);
|
|
}
|
|
context->template Copy<DstType, CPUContext, Context>(size, buffer.get(), dst);
|
|
}
|
|
|
|
} // namespace detail
|
|
|
|
// Serializes the Tensor<Context> held by `blob` by splitting it into chunks
// of FLAGS_caffe2_tensor_chunk_size elements. Each chunk is emitted as its
// own BlobProto (with segment boundaries recorded by the chunk-level
// Serialize overload) and handed to `acceptor` as a serialized string.
// Fatals (via CHECK) if the blob does not hold a Tensor<Context>.
//
// NOTE(review): a zero-sized tensor never enters the chunk loop, so nothing
// is passed to `acceptor` at all — confirm callers handle/expect this.
template <class Context>
void TensorSerializer<Context>::Serialize(
    const Blob& blob,
    const string& name,
    BlobSerializerBase::SerializationAcceptor acceptor) {
  CHECK(blob.IsType<Tensor<Context>>());
  const auto& tensor = blob.template Get<Tensor<Context>>();

#ifndef __ANDROID__
  // Futures for the per-chunk async tasks; joined at the end so that all
  // captured references (tensor, name, acceptor, this) outlive the tasks.
  std::vector<std::future<void>> futures;
#endif

  for (size_t chunkBegin = 0; chunkBegin < tensor.size();
       chunkBegin += FLAGS_caffe2_tensor_chunk_size) {
    // The lambda parameter deliberately shadows the loop variable: each task
    // gets its own chunkBegin by value, while everything else is captured by
    // reference (safe because all futures are joined before returning).
    auto task = [&](size_t chunkBegin) {
      BlobProto blob_proto;
      blob_proto.set_name(name);
      blob_proto.set_type(kTensorBlobType);
      TensorProto& proto = *blob_proto.mutable_tensor();
      proto.set_name(name);
      // Delegates the actual data copy for this chunk to the chunk-level
      // Serialize overload, which also clamps the chunk size at the tail.
      this->Serialize(
          tensor,
          name,
          blob_proto.mutable_tensor(),
          chunkBegin,
          FLAGS_caffe2_tensor_chunk_size);
      acceptor(name, blob_proto.SerializeAsString());
    };
#ifndef __ANDROID__
    // Multi-chunk tensors are serialized with one async task per chunk.
    if (tensor.size() > FLAGS_caffe2_tensor_chunk_size) {
      futures.emplace_back(std::async(std::launch::async, task, chunkBegin));
    } else {
      // Sync mode for small tensors
      task(chunkBegin);
    }
#else
    // Since Android does not have std::future, we will always do sync mode
    //
    task(chunkBegin);
#endif
  }

#ifndef __ANDROID__
  // Wait for all chunk tasks; also propagates any exception thrown in a task.
  for (auto& fut : futures) {
    fut.get();
  }
#endif
}
|
|
|
|
// Serializes the chunk [chunkBegin, chunkBegin + chunkSize) of `input` into
// `proto_ptr`: records the segment boundaries and full tensor dims, then
// copies the chunk's data into the proto field matching the tensor's dtype.
// chunkSize is clamped so the last chunk never reads past the end.
// Enforces chunkBegin < input.size(), so this must not be called for
// zero-sized tensors.
template <class Context>
void TensorSerializer<Context>::Serialize(
    const Tensor<Context>& input, const string& name,
    TensorProto* proto_ptr, size_t chunkBegin, int32_t chunkSize) {
  CAFFE_ENFORCE(
      chunkBegin < input.size(),
      "Chunk begin is out of tensor: ",
      chunkBegin,
      ' ',
      input.size());
  // Clamp the final chunk to the tensor's end.
  if (chunkBegin + chunkSize > input.size()) {
    chunkSize = input.size() - chunkBegin;
  }

  TensorProto& proto = *proto_ptr;
  // The segment tells the deserializer where this chunk lands in the full
  // (flattened) tensor.
  proto.mutable_segment()->set_begin(chunkBegin);
  proto.mutable_segment()->set_end(chunkBegin + chunkSize);

  // Dims always describe the FULL tensor, not just this chunk.
  for (int i = 0; i < input.ndim(); ++i) {
    proto.add_dims(input.dim(i));
  }
  const TensorProto::DataType data_type = TypeMetaToDataType(input.meta());
  proto.set_data_type(data_type);
  // A lot of copypaste is error prone. Should we create a macro for this?
  // Types whose size matches the proto field use CopyToProtoAsIs; narrower
  // types (bool, 8/16-bit ints, float16) are widened into int32_data via
  // CopyToProtoWithCast.
  switch (data_type) {
    case TensorProto_DataType_FLOAT:
      detail::CopyToProtoAsIs(
          chunkSize,
          input.template data<float>() + chunkBegin,
          proto.mutable_float_data(),
          &this->context_);
      break;
    case TensorProto_DataType_INT32:
      detail::CopyToProtoAsIs(
          chunkSize,
          input.template data<int>() + chunkBegin,
          proto.mutable_int32_data(),
          &this->context_);
      break;
    case TensorProto_DataType_BYTE:
      // BYTE exists only for deserializing legacy protos; new data is
      // written as UINT8 instead.
      LOG(FATAL) << "This should not happen. When serializing, "
                    "BYTE is deprecated and moved to UINT8.";
      break;
    case TensorProto_DataType_STRING:
      // Strings are non-fundamental, so they are appended one by one rather
      // than bulk-copied through a Context.
      {
        proto.mutable_string_data()->Reserve(chunkSize);
        const string* content = input.template data<string>();
        for (int i = chunkBegin; i < chunkBegin + chunkSize; ++i) {
          proto.add_string_data(content[i]);
        }
        break;
      }
    case TensorProto_DataType_BOOL:
      detail::CopyToProtoWithCast(
          chunkSize,
          input.template data<bool>() + chunkBegin,
          proto.mutable_int32_data(),
          &this->context_);
      break;
    case TensorProto_DataType_UINT8:
      detail::CopyToProtoWithCast(
          chunkSize,
          input.template data<uint8_t>() + chunkBegin,
          proto.mutable_int32_data(),
          &this->context_);
      break;
    case TensorProto_DataType_INT8:
      detail::CopyToProtoWithCast(
          chunkSize,
          input.template data<int8_t>() + chunkBegin,
          proto.mutable_int32_data(),
          &this->context_);
      break;
    case TensorProto_DataType_UINT16:
      detail::CopyToProtoWithCast(
          chunkSize,
          input.template data<uint16_t>() + chunkBegin,
          proto.mutable_int32_data(),
          &this->context_);
      break;
    case TensorProto_DataType_INT16:
      detail::CopyToProtoWithCast(
          chunkSize,
          input.template data<int16_t>() + chunkBegin,
          proto.mutable_int32_data(),
          &this->context_);
      break;
    case TensorProto_DataType_INT64:
      detail::CopyToProtoAsIs(
          chunkSize,
          input.template data<int64_t>() + chunkBegin,
          proto.mutable_int64_data(),
          &this->context_);
      break;
    case TensorProto_DataType_FLOAT16:
      // float16 has no proto field of its own: the 16-bit patterns are
      // reinterpreted as uint16 and widened into int32_data.
      detail::CopyToProtoWithCast(
          chunkSize,
          reinterpret_cast<const uint16_t*>(input.template data<float16>()) +
              chunkBegin,
          proto.mutable_int32_data(),
          &this->context_);
      break;
    case TensorProto_DataType_DOUBLE:
      detail::CopyToProtoAsIs(
          chunkSize,
          input.template data<double>() + chunkBegin,
          proto.mutable_double_data(),
          &this->context_);
      break;
    case TensorProto_DataType_UNDEFINED:
      LOG(FATAL) << "TensorSerializer does not have a serialization "
                    "implementation for " << input.meta().name();
      break;
    // Note: we intentionally do not provide "default:" so if any new data
    // types are added, the compiler should warn the user to add the case here.
  }
  StoreDeviceDetail(input, &proto);
}
|
|
|
|
template <class Context>
|
|
bool TensorDeserializer<Context>::Deserialize(
|
|
const BlobProto& blob_proto, Blob* blob) {
|
|
return Deserialize(
|
|
blob_proto.tensor(),
|
|
blob->GetMutable<Tensor<Context>>());
|
|
}
|
|
|
|
// Deserializes one TensorProto (possibly a chunk of a larger tensor) into
// `tensor`: resizes the tensor to the proto's dims, then copies the proto's
// data into [chunkBegin, chunkEnd) as given by the proto's segment field
// (or the full tensor if no segment is present). Returns false on malformed
// protos (bad BYTE size, UNDEFINED data type); otherwise true.
template <class Context>
bool TensorDeserializer<Context>::Deserialize(
    const TensorProto& proto, Tensor<Context>* tensor) {
  // We create a local context for deserializing. Since Caffe2 contexts are
  // usually lightweighted, this should not involve too much overhead.
  Context context(proto.device_detail());
  context.SwitchToDevice();
  vector<TIndex> dims;
  for (const TIndex d : proto.dims()) {
    dims.push_back(d);
  }
  // Note: Resize always runs, so even an early-returning zero-sized proto
  // leaves the tensor with the right shape.
  tensor->Resize(dims);

  // Safety check for zero-sized tensors: no copy needed.
  if (tensor->size() == 0) {
    return true;
  }

  // Default to the whole tensor; a segment narrows this to one chunk.
  int64_t chunkBegin = 0;
  auto chunkEnd = tensor->size();
  if (proto.has_segment()) {
    chunkBegin = proto.segment().begin();
    chunkEnd = proto.segment().end();
  }
  CAFFE_ENFORCE(
      0 <= chunkBegin && chunkBegin < chunkEnd && chunkEnd <= tensor->size(),
      "Invalid chunk ",
      chunkBegin,
      ' ',
      chunkEnd,
      " with total tensor size ",
      tensor->size());
  auto chunkSize = chunkEnd - chunkBegin;

  // Mirror of the serializer's switch: as-is copies for same-width types,
  // narrowing casts out of int32_data for bool/8/16-bit/float16 types.
  switch (proto.data_type()) {
    case TensorProto_DataType_FLOAT:
      detail::CopyFromProtoAsIs(
          chunkSize,
          proto.float_data(),
          tensor->template mutable_data<float>() + chunkBegin,
          &context);
      break;
    case TensorProto_DataType_INT32:
      detail::CopyFromProtoAsIs(
          chunkSize,
          proto.int32_data(),
          tensor->template mutable_data<int>() + chunkBegin,
          &context);
      break;
    case TensorProto_DataType_BYTE:
      // Since BYTE stores the data in a string field instead of a repeated
      // field we will have it special cased.
      if (chunkSize != proto.byte_data().size()) {
        LOG(ERROR) << "Incorrect proto field size.";
        return false;
      }
      context.template Copy<uint8_t, Context, CPUContext>(
          chunkSize,
          reinterpret_cast<const uint8_t*>(proto.byte_data().data()),
          tensor->template mutable_data<uint8_t>() + chunkBegin);
      break;
    case TensorProto_DataType_STRING:
      // Special handing of string because it is a non-fundamental type.
      {
        string* content = tensor->template mutable_data<string>();
        for (int i = 0; i < chunkSize; ++i) {
          content[i + chunkBegin] = proto.string_data(i);
        }
      }
      break;
    case TensorProto_DataType_BOOL:
      detail::CopyFromProtoWithCast(
          chunkSize,
          proto.int32_data(),
          tensor->template mutable_data<bool>() + chunkBegin,
          &context);
      break;
    case TensorProto_DataType_UINT8:
      detail::CopyFromProtoWithCast(
          chunkSize,
          proto.int32_data(),
          tensor->template mutable_data<uint8_t>() + chunkBegin,
          &context);
      break;
    case TensorProto_DataType_INT8:
      detail::CopyFromProtoWithCast(
          chunkSize,
          proto.int32_data(),
          tensor->template mutable_data<int8_t>() + chunkBegin,
          &context);
      break;
    case TensorProto_DataType_UINT16:
      detail::CopyFromProtoWithCast(
          chunkSize,
          proto.int32_data(),
          tensor->template mutable_data<uint16_t>() + chunkBegin,
          &context);
      break;
    case TensorProto_DataType_INT16:
      detail::CopyFromProtoWithCast(
          chunkSize,
          proto.int32_data(),
          tensor->template mutable_data<int16_t>() + chunkBegin,
          &context);
      break;
    case TensorProto_DataType_INT64:
      detail::CopyFromProtoAsIs(
          chunkSize,
          proto.int64_data(),
          tensor->template mutable_data<int64_t>() + chunkBegin,
          &context);
      break;
    case TensorProto_DataType_FLOAT16:
      // float16 was serialized as uint16 bit patterns inside int32_data;
      // cast back into the raw 16-bit storage of the float16 buffer.
      detail::CopyFromProtoWithCast(
          chunkSize,
          proto.int32_data(),
          reinterpret_cast<uint16_t*>(
              tensor->template mutable_data<float16>()) +
              chunkBegin,
          &context);
      break;
    case TensorProto_DataType_DOUBLE:
      detail::CopyFromProtoAsIs(
          chunkSize,
          proto.double_data(),
          tensor->template mutable_data<double>() + chunkBegin,
          &context);
      break;
    case TensorProto_DataType_UNDEFINED:
      LOG(ERROR)
          << "Cannot deserialize from a TensorProto UNDEFINED data type.";
      return false;
  }
  // Block until any async device copies issued above have completed.
  context.FinishDeviceComputation();
  return true;
}
|
|
|
|
} // namespace caffe2
|
|
|
|
#endif // CAFFE2_CORE_BLOB_SERIALIZATION_H_
|