mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Tag PyObject on TensorImpl per torchdeploy interpreter (#57985)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57985 Fixes https://github.com/pytorch/pytorch/issues/57756 This PR introduces a new `pyobj_interpreter_` field on TensorImpl which tracks what Python interpreter (if any) owns the TensorImpl. This makes it illegal to bind a TensorImpl from multiple Python interpreters, and means that we can now directly store PyObject pointer on TensorImpl even in the presence of multiple Python interpreters, as is the case in torchdeploy. This is a necessary step for PyObject preservation, which cannot be easily implemented when there are multiple Python interpreters. Although the PR is not that long, there is a very subtle portion of the implementation devoted to ensuring that the tagging process is thread safe, since multiple threads can concurrently try to tag a PyObject. Check Note [Python interpreter tag] and Note [Memory ordering on Python interpreter tag] for detailed discussion of how this is handled. You will have to check this code carefully in code review; I did not torture test the multithreaded paths in any meaningful way. In a follow up PR, I will pack the interpreter and PyObject fields into single atomic word on 64-bit. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Reviewed By: wconstab Differential Revision: D28390242 Pulled By: ezyang fbshipit-source-id: a6d9b244ee6b9c7209e1ed185e336297848e3017
This commit is contained in:
committed by
Facebook GitHub Bot
parent
fe8e5eb260
commit
773cfae93b
@ -33,6 +33,8 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
#include <structmember.h>
|
||||
#include <cstdint>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
@ -41,6 +43,33 @@ using namespace at;
|
||||
using namespace torch;
|
||||
using namespace torch::autograd;
|
||||
|
||||
namespace {
|
||||
|
||||
std::string concrete_name_fn(const c10::impl::PyInterpreter* self) {
|
||||
std::stringstream ss;
|
||||
ss << self;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
class PyInterpreterHolder {
|
||||
public:
|
||||
PyInterpreterHolder()
|
||||
: impl_(new c10::impl::PyInterpreter(&concrete_name_fn)) {}
|
||||
// NB: intentionally leaks the memory
|
||||
~PyInterpreterHolder() {
|
||||
impl_->disarm();
|
||||
}
|
||||
c10::impl::PyInterpreter* get() const noexcept {
|
||||
return impl_;
|
||||
}
|
||||
|
||||
private:
|
||||
c10::impl::PyInterpreter* impl_;
|
||||
};
|
||||
PyInterpreterHolder self_interpreter;
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||
@ -55,49 +84,24 @@ static const char* VOLATILE_WARNING =
|
||||
"volatile was removed and now has no effect. Use "
|
||||
"`with torch.no_grad():` instead.";
|
||||
|
||||
#ifdef USE_DEPLOY
|
||||
// used only in libtorch_deployinterpreter.so
|
||||
// there are muliple copies of the python interpreter that
|
||||
// can shared Tensors, so rather than use their internal pointer
|
||||
// to a PyObject use a library-local map.
|
||||
static std::unordered_map<void*, PyObject*> impl_to_pyobj;
|
||||
|
||||
void set_pyobj(const Variable& self, PyObject* pyobj) {
|
||||
TORCH_CHECK(self.defined(), "cannot call set_pyobj() on undefined tensor");
|
||||
void* key = self.unsafeGetTensorImpl();
|
||||
if (!pyobj) {
|
||||
impl_to_pyobj.erase(key);
|
||||
return;
|
||||
}
|
||||
impl_to_pyobj[key] = pyobj;
|
||||
}
|
||||
|
||||
PyObject* pyobj(const Variable& self) {
|
||||
TORCH_CHECK(self.defined(), "cannot call pyobj() on undefined tensor");
|
||||
auto it = impl_to_pyobj.find(self.unsafeGetTensorImpl());
|
||||
return it == impl_to_pyobj.end() ? nullptr : it->second;
|
||||
}
|
||||
#else
|
||||
void set_pyobj(const Variable& self, PyObject* pyobj) {
|
||||
TORCH_CHECK(self.defined(), "cannot call set_pyobj() on undefined tensor");
|
||||
self.unsafeGetTensorImpl()->set_pyobj(pyobj);
|
||||
}
|
||||
|
||||
PyObject* pyobj(const Variable& self) {
|
||||
TORCH_CHECK(self.defined(), "cannot call pyobj() on undefined tensor");
|
||||
return self.unsafeGetTensorImpl()->pyobj();
|
||||
}
|
||||
#endif
|
||||
|
||||
// Creates a new Python object for a Variable. The Variable must not already
|
||||
// have a PyObject* associated with it.
|
||||
static PyObject* THPVariable_NewWithVar(PyTypeObject* type, Variable var)
|
||||
{
|
||||
// Creates a new Python object for a Variable. The status parameter
|
||||
// specifies what the interpreter tag status on the object is; for
|
||||
// example, if you ran check_pyobj, the return optional of this object
|
||||
// tells you if the tensor was already tagged or not so you can pass
|
||||
// TAGGED_BY_US or MAYBE_UNINITIALIZED; in other cases, you know where
|
||||
// var came from and can directly assert that it's DEFINITELY_UNINITIALIZED.
|
||||
// It's ALWAYS safe (albeit slower) to call this with MAYBE_UNINITIALIZED.
|
||||
static PyObject* THPVariable_NewWithVar(
|
||||
PyTypeObject* type,
|
||||
Variable var,
|
||||
c10::impl::PyInterpreterStatus status) {
|
||||
PyObject* obj = type->tp_alloc(type, 0);
|
||||
if (obj) {
|
||||
auto v = (THPVariable*) obj;
|
||||
new (&v->cdata) Variable(std::move(var));
|
||||
set_pyobj(v->cdata, obj);
|
||||
// cannot use var as it is moved out of
|
||||
THPVariable_Unpack(v).unsafeGetTensorImpl()->init_pyobj(
|
||||
self_interpreter.get(), obj, status);
|
||||
}
|
||||
return obj;
|
||||
}
|
||||
@ -108,12 +112,26 @@ PyObject * THPVariable_Wrap(Variable var)
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
if (auto obj = pyobj(var)) {
|
||||
Py_INCREF(obj);
|
||||
return obj;
|
||||
c10::optional<PyObject*> mb_obj =
|
||||
var.unsafeGetTensorImpl()->check_pyobj(self_interpreter.get());
|
||||
c10::impl::PyInterpreterStatus status;
|
||||
if (mb_obj.has_value()) {
|
||||
auto obj = *mb_obj;
|
||||
if (obj) {
|
||||
Py_INCREF(obj);
|
||||
return obj;
|
||||
}
|
||||
// TODO: a better invariant is that if we tagged, we MUST have a valid
|
||||
// PyObject. That's PyObject preservation
|
||||
// (https://github.com/pytorch/pytorch/pull/56017). Prior to this PR
|
||||
// being a thing, the PyObject field will get cleared when all references
|
||||
// to the Python object are removed.
|
||||
status = c10::impl::PyInterpreterStatus::TAGGED_BY_US;
|
||||
} else {
|
||||
status = c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED;
|
||||
}
|
||||
|
||||
return THPVariable_NewWithVar((PyTypeObject *)THPVariableClass, std::move(var));
|
||||
return THPVariable_NewWithVar(
|
||||
(PyTypeObject*)THPVariableClass, std::move(var), status);
|
||||
}
|
||||
|
||||
static int THPVariable_traverse(THPVariable *self, visitproc visit, void *arg)
|
||||
@ -165,7 +183,10 @@ static int THPVariable_clear(THPVariable *self)
|
||||
// objects stay live, buster! See
|
||||
// https://github.com/pytorch/pytorch/issues/22884 for an example of
|
||||
// this actually showing up.
|
||||
set_pyobj(self->cdata, nullptr);
|
||||
//
|
||||
// [torchdeploy] Note that we DON'T clear the interpreter field. Once on an
|
||||
// interpreter, always on an interpreter.
|
||||
tensor.unsafeGetTensorImpl()->unchecked_clear_pyobj(self_interpreter.get());
|
||||
}
|
||||
self->cdata.reset();
|
||||
return 0;
|
||||
@ -194,7 +215,10 @@ static PyObject* THPVariable_as_subclass(PyObject* _self, PyObject* args, PyObje
|
||||
if (!PyType_Check(cls)) {
|
||||
throw torch::TypeError("cls must be a type (got %s)", Py_TYPE(cls)->tp_name);
|
||||
}
|
||||
return THPVariable_NewWithVar((PyTypeObject*)cls, self.alias());
|
||||
return THPVariable_NewWithVar(
|
||||
(PyTypeObject*)cls,
|
||||
self.alias(),
|
||||
c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
@ -209,7 +233,8 @@ static PyObject* THPVariable_make_subclass(PyObject* _ignored, PyObject* args, P
|
||||
if (!PyType_Check(cls)) {
|
||||
throw torch::TypeError("cls must be a type (got %s)", Py_TYPE(cls)->tp_name);
|
||||
}
|
||||
auto data = r.tensor(1).detach();
|
||||
auto data =
|
||||
r.tensor(1).detach(); // creates a fresh Tensor (DEFINITELY_UNINITIALIZED)
|
||||
// We set `data`'s `allow_tensor_metadata_change` to true here, because we want to
|
||||
// allow the following use case for backward compatibility:
|
||||
//
|
||||
@ -221,7 +246,10 @@ static PyObject* THPVariable_make_subclass(PyObject* _ignored, PyObject* args, P
|
||||
// ```
|
||||
data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
|
||||
auto var = data.set_requires_grad(r.toBool(2));
|
||||
return THPVariable_NewWithVar((PyTypeObject*)cls, std::move(var));
|
||||
return THPVariable_NewWithVar(
|
||||
(PyTypeObject*)cls,
|
||||
std::move(var),
|
||||
c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
@ -951,11 +979,15 @@ PyObject *THPVariable_pynew(PyTypeObject *type, PyObject *args, PyObject *kwargs
|
||||
TORCH_CHECK(type != &THPVariableType, "Cannot directly construct _TensorBase; subclass it and then construct that");
|
||||
jit::tracer::warn("torch.Tensor", jit::tracer::WARN_CONSTRUCTOR);
|
||||
auto tensor = torch::utils::legacy_tensor_ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs);
|
||||
return THPVariable_NewWithVar(type, std::move(tensor));
|
||||
// WARNING: tensor is NOT guaranteed to be a fresh tensor; e.g., if it was
|
||||
// given a raw pointer that will refcount bump
|
||||
return THPVariable_NewWithVar(
|
||||
type,
|
||||
std::move(tensor),
|
||||
c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED);
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
|
||||
int THPVariableMetaType_init(PyObject *cls, PyObject *args, PyObject *kwargs) {
|
||||
if (PyType_Type.tp_init(cls, args, kwargs) < 0) {
|
||||
return -1;
|
||||
|
Reference in New Issue
Block a user