IValue can store Blob (#11414)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11414

caffe2::Blob can be stored in an IValue. This is a precondition for caffe2 to switch from Blob to IValue.

Reviewed By: ezyang

Differential Revision: D9731326

fbshipit-source-id: 462a39d2d9ab6f85b99b1670848c6976a3de417c
This commit is contained in:
Sebastian Messmer
2018-09-26 01:02:27 -07:00
committed by Facebook Github Bot
parent b7ebc00979
commit 65cbb8226b
4 changed files with 36 additions and 7 deletions

View File

@ -6,6 +6,7 @@
#include <typeinfo>
#include <vector>
#include <ATen/core/intrusive_ptr.h>
#include <ATen/core/typeid.h>
#include <c10/macros/Macros.h>
@ -20,7 +21,7 @@ class Tensor;
* properly when the blob is deallocated or re-allocated with a new type. A blob
* could contain anything, although the most common case is to contain a Tensor.
*/
class CAFFE2_API Blob final {
class CAFFE2_API Blob final : public c10::intrusive_ptr_target {
public:
using DestroyCall = void(void*);
@ -232,4 +233,8 @@ inline void swap(Blob& lhs, Blob& rhs) {
lhs.swap(rhs);
}
inline std::ostream& operator<<(std::ostream& out, const Blob& v) {
return out << "Blob[" << v.TypeName() << "]";
}
} // namespace caffe2

View File

@ -1,8 +1,10 @@
#include <ATen/core/ivalue.h>
#include <ATen/core/Formatting.h>
#define TORCH_FORALL_TAGS(_) \
_(None) _(Tensor) _(Double) _(Int) _(Tuple) _(IntList) _(DoubleList) _(String) _(TensorList)
#define TORCH_FORALL_TAGS(_) \
_(None) \
_(Tensor) _(Double) _(Int) _(Tuple) _(IntList) _(DoubleList) _(String) \
_(TensorList) _(Blob)
namespace torch { namespace jit {

View File

@ -4,6 +4,7 @@
#include <ATen/core/Tensor.h>
#include <ATen/core/TensorImpl.h>
#include <ATen/core/UndefinedTensorImpl.h>
#include <ATen/core/blob.h>
#include <ATen/core/intrusive_ptr.h>
#include <type_traits>
@ -64,8 +65,10 @@ using DoubleList = ConstantList<double>;
// to mark whether that type is a subtype of c10::intrusive_ptr_target and needs
// retain/release calls.
#define TORCH_FORALL_TAGS(_) \
_(None) _(Tensor) _(Double) _(Int) _(Tuple) _(IntList) _(DoubleList) _(String) _(TensorList)
#define TORCH_FORALL_TAGS(_) \
_(None) \
_(Tensor) _(Double) _(Int) _(Tuple) _(IntList) _(DoubleList) _(String) \
_(TensorList) _(Blob)
struct CAFFE2_API IValue final {
IValue()
@ -125,6 +128,25 @@ struct CAFFE2_API IValue final {
return at::Tensor(toIntrusivePtr<at::TensorImpl, at::UndefinedTensorImpl>());
}
IValue(caffe2::Blob blob) : tag(Tag::Blob), is_intrusive_ptr(true) {
// TODO (after Tensor merge) If we pass in a Blob holding a Tensor, extract
// and
// store it as a Tensor instead.
payload.as_intrusive_ptr =
c10::make_intrusive<caffe2::Blob>(std::move(blob)).release();
}
bool isBlob() const {
return Tag::Blob == tag;
}
caffe2::Blob& toBlob() & {
AT_ASSERT(isBlob());
return *static_cast<caffe2::Blob*>(payload.as_intrusive_ptr);
}
const caffe2::Blob& toBlob() const& {
AT_ASSERT(isBlob());
return *static_cast<caffe2::Blob*>(payload.as_intrusive_ptr);
}
// Tuple
IValue(c10::intrusive_ptr<Tuple> v);
bool isTuple() const { return Tag::Tuple == tag; }

View File

@ -47,7 +47,7 @@ int main(int argc, char** argv) {
LOG(INFO)
<< "Is the blob type float? "
<< myblob.IsType<float>();
const int& myint_const = myblob.Get<int>();
LOG(INFO)
<< "The value of the int number stored in the blob is: "
@ -80,7 +80,7 @@ int main(int argc, char** argv) {
std::string* pvec = new std::string();
myblob.Reset(pvec); // no need to release pvec, myblob takes ownership.
LOG(INFO) << "Is the blob now of type string? "
<< myblob.IsType<std::string>();