mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Properly annotated all apis for cpu front. Checked with cmake using cmake -DUSE_ATEN=ON -DUSE_CUDA=OFF -DBUILD_ATEN=ON and resulting libcaffe2.so has about 11k symbols. Pull Request resolved: https://github.com/pytorch/pytorch/pull/10504 Reviewed By: ezyang Differential Revision: D9316491 Pulled By: Yangqing fbshipit-source-id: 215659abf350af7032e9a4b0f28a856babab2454
95 lines
2.9 KiB
C++
95 lines
2.9 KiB
C++
#pragma once
|
|
|
|
#include <string>
|
|
#include <functional>
|
|
|
|
#include "caffe2/core/common.h"
|
|
#include "caffe2/core/registry.h"
|
|
#include "caffe2/proto/caffe2.pb.h"
|
|
|
|
namespace caffe2 {
|
|
|
|
// Forward declaration only: the interfaces in this header take Blob by
// reference or pointer, so the full definition is not needed here.
class Blob;

// Chunk-size sentinel values passed as the chunk_size argument of
// SerializeWithChunkSize. The base implementation ignores chunk_size, so the
// interpretation is up to concrete serializers.
// NOTE(review): by name, kDefaultChunkSize presumably asks for the
// serializer's own default chunking and kNoChunking presumably disables
// chunking — confirm against the serializer implementations.
constexpr int kDefaultChunkSize{-1};
constexpr int kNoChunking{0};
|
|
|
|
/**
|
|
* @brief BlobSerializerBase is an abstract class that serializes a blob to a
|
|
* string.
|
|
*
|
|
* This class exists purely for the purpose of registering type-specific
|
|
* serialization code. If you need to serialize a specific type, you should
|
|
* write your own Serializer class, and then register it using
|
|
* REGISTER_BLOB_SERIALIZER. For a detailed example, see TensorSerializer for
|
|
* details.
|
|
*/
|
|
class BlobSerializerBase {
|
|
public:
|
|
virtual ~BlobSerializerBase() {}
|
|
using SerializationAcceptor =
|
|
std::function<void(const std::string& blobName, const std::string& data)>;
|
|
/**
|
|
* @brief The virtual function that returns a serialized string for the input
|
|
* blob.
|
|
* @param blob
|
|
* the input blob to be serialized.
|
|
* @param name
|
|
* the blob name to be used in the serialization implementation. It is up
|
|
* to the implementation whether this name field is going to be used or
|
|
* not.
|
|
* @param acceptor
|
|
* a lambda which accepts key value pairs to save them to storage.
|
|
* serailizer can use it to save blob in several chunks
|
|
* acceptor should be thread-safe
|
|
*/
|
|
virtual void Serialize(const Blob& blob, const std::string& name,
|
|
SerializationAcceptor acceptor) = 0;
|
|
|
|
virtual void SerializeWithChunkSize(
|
|
const Blob& blob,
|
|
const std::string& name,
|
|
SerializationAcceptor acceptor,
|
|
int /*chunk_size*/) {
|
|
// Base implementation.
|
|
Serialize(blob, name, acceptor);
|
|
}
|
|
};
|
|
|
|
// The Blob serialization registry and serializer creator functions.
|
|
CAFFE_DECLARE_TYPED_REGISTRY(
|
|
BlobSerializerRegistry,
|
|
TypeIdentifier,
|
|
BlobSerializerBase,
|
|
std::unique_ptr);
|
|
#define REGISTER_BLOB_SERIALIZER(id, ...) \
|
|
CAFFE_REGISTER_TYPED_CLASS(BlobSerializerRegistry, id, __VA_ARGS__)
|
|
// Creates an operator with the given operator definition.
|
|
inline unique_ptr<BlobSerializerBase> CreateSerializer(TypeIdentifier id) {
|
|
return BlobSerializerRegistry()->Create(id);
|
|
}
|
|
|
|
|
|
/**
|
|
* @brief BlobDeserializerBase is an abstract class that deserializes a blob
|
|
* from a BlobProto or a TensorProto.
|
|
*/
|
|
class CAFFE2_API BlobDeserializerBase {
|
|
public:
|
|
virtual ~BlobDeserializerBase() {}
|
|
|
|
// Deserializes from a BlobProto object.
|
|
virtual void Deserialize(const BlobProto& proto, Blob* blob) = 0;
|
|
};
|
|
|
|
CAFFE_DECLARE_REGISTRY(BlobDeserializerRegistry, BlobDeserializerBase);
|
|
#define REGISTER_BLOB_DESERIALIZER(name, ...) \
|
|
CAFFE_REGISTER_CLASS(BlobDeserializerRegistry, name, __VA_ARGS__)
|
|
// Creates an operator with the given operator definition.
|
|
inline unique_ptr<BlobDeserializerBase> CreateDeserializer(const string& type) {
|
|
return BlobDeserializerRegistry()->Create(type);
|
|
}
|
|
|
|
|
|
} // namespace caffe2
|