mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[BE] Remove HermeticPyObjectTLS and Simplify PythonOpRegistrationTrampoline (#163464)"
This reverts commit 94195a37ae4eae9c486a81b0f67725c8970f74d6. Reverted https://github.com/pytorch/pytorch/pull/163464 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/163464#issuecomment-3353307034))
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
#include <ATen/NamedTensorUtils.h>
|
||||
#include <c10/core/DeviceType.h>
|
||||
#include <c10/core/impl/GPUTrace.h>
|
||||
#include <c10/core/impl/HermeticPyObjectTLS.h>
|
||||
#include <c10/core/impl/PythonDispatcherTLS.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <pybind11/pytypes.h>
|
||||
@ -259,6 +260,10 @@ PyObject* THPVariable_Wrap(const at::TensorBase& var) {
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
if (c10::impl::HermeticPyObjectTLS::get_state()) {
|
||||
return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var);
|
||||
}
|
||||
|
||||
std::optional<PyObject*> mb_obj =
|
||||
var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj();
|
||||
if (mb_obj.has_value()) {
|
||||
@ -374,6 +379,7 @@ static bool THPVariable_tryResurrect(THPVariable* self) {
|
||||
// Flip THPVariable to be non-owning
|
||||
// (near use-after-free miss here: fresh MaybeOwned is created breaking
|
||||
// reference on Tensor in struct BEFORE we overwrite the old one)
|
||||
TORCH_INTERNAL_ASSERT(!c10::impl::HermeticPyObjectTLS::get_state());
|
||||
self->cdata = MaybeOwned<Variable>::borrowed(tensor);
|
||||
|
||||
// NB: At this point, tensor *could* be dead (e.g., some other C++ thread
|
||||
@ -2467,13 +2473,28 @@ static PyObject* THPVariable_NewWithVar(
|
||||
auto v = (THPVariable*)obj;
|
||||
// TODO: named constructor to avoid default initialization
|
||||
new (&v->cdata) MaybeOwned<Variable>();
|
||||
v->cdata = MaybeOwned<Variable>::owned(Variable(_var));
|
||||
const auto& var = THPVariable_Unpack(v);
|
||||
var.unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(obj);
|
||||
if (has_torch_dispatch_if_known.has_value()
|
||||
? *has_torch_dispatch_if_known
|
||||
: check_has_torch_dispatch(obj)) {
|
||||
var.unsafeGetTensorImpl()->set_python_dispatch(true);
|
||||
if (c10::impl::HermeticPyObjectTLS::get_state()) {
|
||||
// Do NOT initialize pyobj field on the tensor, you own the C++
|
||||
v->cdata = MaybeOwned<Variable>::owned(Variable(_var));
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
!check_has_torch_dispatch(obj),
|
||||
"While HermeticPyObject was enabled, we attempted to create a tensor "
|
||||
"subclass with __torch_dispatch__. This violates the invariant that "
|
||||
"operations in HermeticPyObject have equivalent C++ implementations. "
|
||||
"If your operator registered from Python operator registration isn't "
|
||||
"doing anything strange, there may be an internal PyTorch bug involving "
|
||||
"not appropriately disabling TorchDispatchMode before executing "
|
||||
"Python op registration.");
|
||||
} else {
|
||||
// Normal codepath
|
||||
v->cdata = MaybeOwned<Variable>::owned(Variable(_var));
|
||||
const auto& var = THPVariable_Unpack(v);
|
||||
var.unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(obj);
|
||||
if (has_torch_dispatch_if_known.has_value()
|
||||
? *has_torch_dispatch_if_known
|
||||
: check_has_torch_dispatch(obj)) {
|
||||
var.unsafeGetTensorImpl()->set_python_dispatch(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
return obj;
|
||||
|
Reference in New Issue
Block a user