mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
IValue can store Blob (#11414)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/11414 caffe2::Blob can be stored in an IValue. This is a precondition for caffe2 to switch from Blob to IValue. Reviewed By: ezyang Differential Revision: D9731326 fbshipit-source-id: 462a39d2d9ab6f85b99b1670848c6976a3de417c
This commit is contained in:
committed by
Facebook Github Bot
parent
b7ebc00979
commit
65cbb8226b
@ -6,6 +6,7 @@
|
|||||||
#include <typeinfo>
|
#include <typeinfo>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include <ATen/core/intrusive_ptr.h>
|
||||||
#include <ATen/core/typeid.h>
|
#include <ATen/core/typeid.h>
|
||||||
#include <c10/macros/Macros.h>
|
#include <c10/macros/Macros.h>
|
||||||
|
|
||||||
@ -20,7 +21,7 @@ class Tensor;
|
|||||||
* properly when the blob is deallocated or re-allocated with a new type. A blob
|
* properly when the blob is deallocated or re-allocated with a new type. A blob
|
||||||
* could contain anything, although the most common case is to contain a Tensor.
|
* could contain anything, although the most common case is to contain a Tensor.
|
||||||
*/
|
*/
|
||||||
class CAFFE2_API Blob final {
|
class CAFFE2_API Blob final : public c10::intrusive_ptr_target {
|
||||||
public:
|
public:
|
||||||
using DestroyCall = void(void*);
|
using DestroyCall = void(void*);
|
||||||
|
|
||||||
@ -232,4 +233,8 @@ inline void swap(Blob& lhs, Blob& rhs) {
|
|||||||
lhs.swap(rhs);
|
lhs.swap(rhs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline std::ostream& operator<<(std::ostream& out, const Blob& v) {
|
||||||
|
return out << "Blob[" << v.TypeName() << "]";
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace caffe2
|
} // namespace caffe2
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
#include <ATen/core/ivalue.h>
|
#include <ATen/core/ivalue.h>
|
||||||
#include <ATen/core/Formatting.h>
|
#include <ATen/core/Formatting.h>
|
||||||
|
|
||||||
#define TORCH_FORALL_TAGS(_) \
|
#define TORCH_FORALL_TAGS(_) \
|
||||||
_(None) _(Tensor) _(Double) _(Int) _(Tuple) _(IntList) _(DoubleList) _(String) _(TensorList)
|
_(None) \
|
||||||
|
_(Tensor) _(Double) _(Int) _(Tuple) _(IntList) _(DoubleList) _(String) \
|
||||||
|
_(TensorList) _(Blob)
|
||||||
|
|
||||||
namespace torch { namespace jit {
|
namespace torch { namespace jit {
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
#include <ATen/core/Tensor.h>
|
#include <ATen/core/Tensor.h>
|
||||||
#include <ATen/core/TensorImpl.h>
|
#include <ATen/core/TensorImpl.h>
|
||||||
#include <ATen/core/UndefinedTensorImpl.h>
|
#include <ATen/core/UndefinedTensorImpl.h>
|
||||||
|
#include <ATen/core/blob.h>
|
||||||
#include <ATen/core/intrusive_ptr.h>
|
#include <ATen/core/intrusive_ptr.h>
|
||||||
|
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
@ -64,8 +65,10 @@ using DoubleList = ConstantList<double>;
|
|||||||
// to mark whether that type is a subtype of c10::intrusive_ptr_target and needs
|
// to mark whether that type is a subtype of c10::intrusive_ptr_target and needs
|
||||||
// retain/release calls.
|
// retain/release calls.
|
||||||
|
|
||||||
#define TORCH_FORALL_TAGS(_) \
|
#define TORCH_FORALL_TAGS(_) \
|
||||||
_(None) _(Tensor) _(Double) _(Int) _(Tuple) _(IntList) _(DoubleList) _(String) _(TensorList)
|
_(None) \
|
||||||
|
_(Tensor) _(Double) _(Int) _(Tuple) _(IntList) _(DoubleList) _(String) \
|
||||||
|
_(TensorList) _(Blob)
|
||||||
|
|
||||||
struct CAFFE2_API IValue final {
|
struct CAFFE2_API IValue final {
|
||||||
IValue()
|
IValue()
|
||||||
@ -125,6 +128,25 @@ struct CAFFE2_API IValue final {
|
|||||||
return at::Tensor(toIntrusivePtr<at::TensorImpl, at::UndefinedTensorImpl>());
|
return at::Tensor(toIntrusivePtr<at::TensorImpl, at::UndefinedTensorImpl>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
IValue(caffe2::Blob blob) : tag(Tag::Blob), is_intrusive_ptr(true) {
|
||||||
|
// TODO (after Tensor merge) If we pass in a Blob holding a Tensor, extract
|
||||||
|
// and
|
||||||
|
// store it as a Tensor instead.
|
||||||
|
payload.as_intrusive_ptr =
|
||||||
|
c10::make_intrusive<caffe2::Blob>(std::move(blob)).release();
|
||||||
|
}
|
||||||
|
bool isBlob() const {
|
||||||
|
return Tag::Blob == tag;
|
||||||
|
}
|
||||||
|
caffe2::Blob& toBlob() & {
|
||||||
|
AT_ASSERT(isBlob());
|
||||||
|
return *static_cast<caffe2::Blob*>(payload.as_intrusive_ptr);
|
||||||
|
}
|
||||||
|
const caffe2::Blob& toBlob() const& {
|
||||||
|
AT_ASSERT(isBlob());
|
||||||
|
return *static_cast<caffe2::Blob*>(payload.as_intrusive_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
// Tuple
|
// Tuple
|
||||||
IValue(c10::intrusive_ptr<Tuple> v);
|
IValue(c10::intrusive_ptr<Tuple> v);
|
||||||
bool isTuple() const { return Tag::Tuple == tag; }
|
bool isTuple() const { return Tag::Tuple == tag; }
|
||||||
|
@ -47,7 +47,7 @@ int main(int argc, char** argv) {
|
|||||||
LOG(INFO)
|
LOG(INFO)
|
||||||
<< "Is the blob type float? "
|
<< "Is the blob type float? "
|
||||||
<< myblob.IsType<float>();
|
<< myblob.IsType<float>();
|
||||||
|
|
||||||
const int& myint_const = myblob.Get<int>();
|
const int& myint_const = myblob.Get<int>();
|
||||||
LOG(INFO)
|
LOG(INFO)
|
||||||
<< "The value of the int number stored in the blob is: "
|
<< "The value of the int number stored in the blob is: "
|
||||||
@ -80,7 +80,7 @@ int main(int argc, char** argv) {
|
|||||||
|
|
||||||
std::string* pvec = new std::string();
|
std::string* pvec = new std::string();
|
||||||
myblob.Reset(pvec); // no need to release pvec, myblob takes ownership.
|
myblob.Reset(pvec); // no need to release pvec, myblob takes ownership.
|
||||||
|
|
||||||
LOG(INFO) << "Is the blob now of type string? "
|
LOG(INFO) << "Is the blob now of type string? "
|
||||||
<< myblob.IsType<std::string>();
|
<< myblob.IsType<std::string>();
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user