mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
PyTorch/Caffe2 tensor interop in Python (#17190)
Summary: Because of two separate python extensions with different pybind instances I have to go through void* conversion. Since it's hidden from user, it's fine. New APIs added on C2 side: - workspace.FetchTorch('blob') - workspace.Workspace.current.blobs['blob'].to_torch() - workspace.FeedBlob('blob', pytorch_tensor) Works on CPU an GPU. The only glitches are with resizing because of variable/tensor split. But data sharing works properly. Pull Request resolved: https://github.com/pytorch/pytorch/pull/17190 Reviewed By: ezyang Differential Revision: D14163882 Pulled By: dzhulgakov fbshipit-source-id: d18e5b8fcae026f393c842a1149e972515732de2
This commit is contained in:
committed by
Facebook Github Bot
parent
244d330980
commit
dec116e96f
@ -19,12 +19,14 @@
|
||||
#include <torch/csrc/tensor/python_tensor.h>
|
||||
#include <torch/csrc/utils/auto_gil.h>
|
||||
#include <torch/csrc/utils/cuda_lazy_init.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
#include <torch/csrc/utils/python_strings.h>
|
||||
#include <torch/csrc/utils/python_arg_parser.h>
|
||||
#include <torch/csrc/utils/tensor_new.h>
|
||||
#include <torch/csrc/jit/tracer.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
#include <structmember.h>
|
||||
#include <memory>
|
||||
@ -35,6 +37,8 @@ using namespace at;
|
||||
using namespace torch;
|
||||
using namespace torch::autograd;
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
PyObject *THPVariableClass = nullptr;
|
||||
|
||||
static const char* VOLATILE_WARNING =
|
||||
@ -489,6 +493,25 @@ namespace torch { namespace autograd {
|
||||
extern PyMethodDef variable_methods[];
|
||||
extern void initTorchFunctions(PyObject *module);
|
||||
|
||||
void initTensorImplConversion(PyObject* module) {
|
||||
auto m = py::handle(module).cast<py::module>();
|
||||
m.def("_wrap_tensor_impl", [](void* ptr) {
|
||||
auto p = c10::intrusive_ptr<c10::TensorImpl, at::UndefinedTensorImpl>::
|
||||
unsafe_reclaim_from_nonowning(static_cast<c10::TensorImpl*>(ptr));
|
||||
AT_CHECK(p.defined(), "Can't wrap undefined tensor");
|
||||
AT_CHECK(!p->is_variable(), "Can wrap only non-variable tensor");
|
||||
auto tensor = at::Tensor::wrap_tensor_impl(std::move(p));
|
||||
return py::cast(torch::autograd::Variable(
|
||||
torch::autograd::make_variable(std::move(tensor), false)));
|
||||
});
|
||||
// set on the module level to avoid mixing pybind and plain CPython extensions
|
||||
m.def("_tensor_impl_raw_handle", [](torch::autograd::Variable* t) -> void* {
|
||||
auto p = t->data().getIntrusivePtr();
|
||||
// We return a raw non-owning pointer here, we rely on surrounding
|
||||
// code to keep the original tensor alive
|
||||
return p.get();
|
||||
});
|
||||
}
|
||||
}}
|
||||
|
||||
bool THPVariable_initModule(PyObject *module)
|
||||
@ -502,5 +525,6 @@ bool THPVariable_initModule(PyObject *module)
|
||||
Py_INCREF(&THPVariableType);
|
||||
PyModule_AddObject(module, "_TensorBase", (PyObject *)&THPVariableType);
|
||||
torch::autograd::initTorchFunctions(module);
|
||||
torch::autograd::initTensorImplConversion(module);
|
||||
return true;
|
||||
}
|
||||
|
Reference in New Issue
Block a user