From cc5d74c366e8e106097489d05ec062ddcd279d38 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 30 Sep 2025 18:20:17 +0000 Subject: [PATCH] Revert "[BE] Remove HermeticPyObjectTLS and Simplify PythonOpRegistrationTrampoline (#163464)" This reverts commit 94195a37ae4eae9c486a81b0f67725c8970f74d6. Reverted https://github.com/pytorch/pytorch/pull/163464 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/163464#issuecomment-3353307034)) --- .../core/PythonOpRegistrationTrampoline.cpp | 24 ++++--- .../core/PythonOpRegistrationTrampoline.h | 10 ++- c10/core/impl/HermeticPyObjectTLS.cpp | 21 +++++++ c10/core/impl/HermeticPyObjectTLS.h | 62 +++++++++++++++++++ c10/core/impl/PyObjectSlot.h | 20 +++++- cmake/prioritized_text.txt | 1 + torch/csrc/PyInterpreter.cpp | 13 ++-- torch/csrc/Storage.cpp | 13 +++- torch/csrc/autograd/python_variable.cpp | 35 ++++++++--- torch/csrc/utils/python_dispatch.cpp | 34 ++++++++++ 10 files changed, 202 insertions(+), 31 deletions(-) create mode 100644 c10/core/impl/HermeticPyObjectTLS.cpp create mode 100644 c10/core/impl/HermeticPyObjectTLS.h diff --git a/aten/src/ATen/core/PythonOpRegistrationTrampoline.cpp b/aten/src/ATen/core/PythonOpRegistrationTrampoline.cpp index f50b2507d914..219d774de3a5 100644 --- a/aten/src/ATen/core/PythonOpRegistrationTrampoline.cpp +++ b/aten/src/ATen/core/PythonOpRegistrationTrampoline.cpp @@ -1,22 +1,32 @@ #include -#include -// TODO: delete this namespace at::impl { -c10::impl::PyInterpreter* PythonOpRegistrationTrampoline::interpreter_ = nullptr; +// The strategy is that all python interpreters attempt to register themselves +// as the main interpreter, but only one wins. Only that interpreter is +// allowed to interact with the C++ dispatcher. Furthermore, when we execute +// logic on that interpreter, we do so hermetically, never setting pyobj field +// on Tensor. + +std::atomic + PythonOpRegistrationTrampoline::interpreter_{nullptr}; c10::impl::PyInterpreter* PythonOpRegistrationTrampoline::getInterpreter() { - return c10::impl::getGlobalPyInterpreter(); + return PythonOpRegistrationTrampoline::interpreter_.load(); } bool PythonOpRegistrationTrampoline::registerInterpreter( c10::impl::PyInterpreter* interp) { - if (interpreter_ != nullptr) { + c10::impl::PyInterpreter* expected = nullptr; + interpreter_.compare_exchange_strong(expected, interp); + if (expected != nullptr) { + // This is the second (or later) Python interpreter, which means we need + // non-trivial hermetic PyObject TLS + c10::impl::HermeticPyObjectTLS::init_state(); return false; + } else { + return true; } - interpreter_ = interp; - return true; } } // namespace at::impl diff --git a/aten/src/ATen/core/PythonOpRegistrationTrampoline.h b/aten/src/ATen/core/PythonOpRegistrationTrampoline.h index 062dbebc3ceb..bec323c7d25b 100644 --- a/aten/src/ATen/core/PythonOpRegistrationTrampoline.h +++ b/aten/src/ATen/core/PythonOpRegistrationTrampoline.h @@ -2,21 +2,19 @@ #include -// TODO: We can get rid of this +// TODO: this can probably live in c10 namespace at::impl { -// Manages the single Python interpreter instance for PyTorch. class TORCH_API PythonOpRegistrationTrampoline final { - static c10::impl::PyInterpreter* interpreter_; + static std::atomic interpreter_; public: - // Register the Python interpreter. Returns true on first registration, - // false if an interpreter was already registered. + // Returns true if you successfully registered yourself (that means + // you are in the hot seat for doing the operator registrations!) static bool registerInterpreter(c10::impl::PyInterpreter*); - // Returns the registered interpreter via the global PyInterpreter hooks. // Returns nullptr if no interpreter has been registered yet. static c10::impl::PyInterpreter* getInterpreter(); }; diff --git a/c10/core/impl/HermeticPyObjectTLS.cpp b/c10/core/impl/HermeticPyObjectTLS.cpp new file mode 100644 index 000000000000..856c63a93a92 --- /dev/null +++ b/c10/core/impl/HermeticPyObjectTLS.cpp @@ -0,0 +1,21 @@ +#include + +namespace c10::impl { + +thread_local static std::atomic hermeticPyObjectState{false}; + +std::atomic HermeticPyObjectTLS::haveState_{false}; + +void HermeticPyObjectTLS::set_state(bool state) { + hermeticPyObjectState = state; +} + +bool HermeticPyObjectTLS::get_tls_state() { + return hermeticPyObjectState; +} + +void HermeticPyObjectTLS::init_state() { + haveState_ = true; +} + +} // namespace c10::impl diff --git a/c10/core/impl/HermeticPyObjectTLS.h b/c10/core/impl/HermeticPyObjectTLS.h new file mode 100644 index 000000000000..a973a5d2cef8 --- /dev/null +++ b/c10/core/impl/HermeticPyObjectTLS.h @@ -0,0 +1,62 @@ +#pragma once + +#include +#include + +namespace c10::impl { + +// This TLS controls whether or not we permanently associate PyObject +// with Tensor the first time it is allocated. When hermetic PyObject +// TLS is enabled (state is true), we DO NOT save PyObjects to Tensor, +// meaning you get a distinct PyObject whenever you execute the code in +// question. +struct C10_API HermeticPyObjectTLS { + static void set_state(bool state); + static bool get_state() { + // Hypothetical fastpath if torchdeploy/multipy // codespell:ignore multipy + // isn't used. Per + // https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf + // this qualifies relaxed access because it is a single-location data + // structure (only the boolean here). + // + // Forgetting about data races for a moment, is there a logical race? + // + // - Boolean only ever transitions from false to true. So the + // critical situation is when one interpreter is already running + // when a second interpreter switches haveState from false to true. + // + // - The first interpreter is indifferent whether or not it sees + // hasState true/false; obviously false works (this is what the + // interpreter was previously using; more directly, the interpreter + // calls into itself as the handler, so being hermetic is not + // required), and true simply means serviced python operator calls will + // be hermetic; in these cases it is expected to be functionally + // equivalent. + // + // - The second interpreter MUST see hasState true (as its requests will + // be forwarded to the first interpreter), but it is assumed that there + // is a synchronization between the interpreter initialization, and + // when we actually perform operations, so it is guaranteed to see + // hasState true. + // + // QED. + // + // This fastpath is currently disabled so that we can more easily test that + // hermetic mode works correctly even on stock build of PyTorch. + if (false && !haveState_.load(std::memory_order_relaxed)) + return false; + return get_tls_state(); + } + // Call this from the multipy/torchdeploy // codespell:ignore multipy + // top level + static void init_state(); + + private: + // This only flipped once from false to true during + // torchdeploy/multipy initialization, // codespell:ignore multipy + // and never again. + static std::atomic haveState_; + static bool get_tls_state(); +}; + +} // namespace c10::impl diff --git a/c10/core/impl/PyObjectSlot.h b/c10/core/impl/PyObjectSlot.h index 8824d33eb224..e7d78f8360c3 100644 --- a/c10/core/impl/PyObjectSlot.h +++ b/c10/core/impl/PyObjectSlot.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -41,15 +42,32 @@ struct C10_API PyObjectSlot { PyObject* _unchecked_untagged_pyobj() const; - // Test the interpreter / PyObj as they may be null + // Test the interpreter tag. If tagged for the current interpreter, return + // a non-nullopt (but possibly null) PyObject. If (possibly) untagged, + // returns a nullopt. If it is definitely invalid, raises an error. + // + // If `ignore_hermetic_tls` is false and this function is called from a + // hermetic context (ie, `HermeticPyObjectTLS::get_state()` is true), then + // nullopt is returned. If `ignore_hermetic_tls` is true, then the hermetic + // context is ignored, allowing you to check the interpreter tag of a + // nonhermetic PyObject from within a hermetic context. This is necessary + // because there are some cases where the deallocator function of a + // nonhermetic PyObject is called from within a hermetic context, so it must + // be properly treated as a nonhermetic PyObject. + // // NB: this lives in header so that we can avoid actually creating the // std::optional + // @todo alban: I'm not too sure what's going on here, we can probably delete + // it but it's worthwhile making sure std::optional check_pyobj() const { impl::PyInterpreter* interpreter = getGlobalPyInterpreter(); if (interpreter == nullptr || pyobj_ == nullptr) { return std::nullopt; } + if (c10::impl::HermeticPyObjectTLS::get_state()) { + return std::nullopt; + } return _unchecked_untagged_pyobj(); } diff --git a/cmake/prioritized_text.txt b/cmake/prioritized_text.txt index f7d41b0bec3e..e5e36f34f98d 100644 --- a/cmake/prioritized_text.txt +++ b/cmake/prioritized_text.txt @@ -153,6 +153,7 @@ _ZN3c104impl12PyObjectSlot10owns_pyobjEv _ZN3c104impl12PyObjectSlot19maybe_destroy_pyobjEv _ZN3c104impl12PyObjectSlotC1Ev _ZN3c104impl12PyObjectSlotD2Ev +_ZN3c104impl19HermeticPyObjectTLS13get_tls_stateEv _ZN3c104impl20TorchDispatchModeTLS13any_modes_setEb _ZN3c104impl23ExcludeDispatchKeyGuardC1ENS_14DispatchKeySetE _ZN3c104impl23ExcludeDispatchKeyGuardD2Ev diff --git a/torch/csrc/PyInterpreter.cpp b/torch/csrc/PyInterpreter.cpp index b863fd44c152..993f8b8216a6 100644 --- a/torch/csrc/PyInterpreter.cpp +++ b/torch/csrc/PyInterpreter.cpp @@ -157,10 +157,10 @@ class PyInterpreterHolder { public: PyInterpreterHolder() : impl_(new c10::impl::PyInterpreter( - ConcretePyInterpreterVTable::instance())) { - // Register the single interpreter - at::impl::PythonOpRegistrationTrampoline::registerInterpreter(impl_); - } + ConcretePyInterpreterVTable::instance())), + is_main_interpreter_( + at::impl::PythonOpRegistrationTrampoline::registerInterpreter( + impl_)) {} PyInterpreterHolder(const PyInterpreterHolder&) = delete; PyInterpreterHolder(PyInterpreterHolder&&) = delete; PyInterpreterHolder& operator=(const PyInterpreterHolder&) = delete; @@ -174,14 +174,13 @@ class PyInterpreterHolder { c10::impl::PyInterpreter* get() const noexcept { return impl_; } - // In single-interpreter mode, this is always true - // TODO: delete this bool is_main_interpreter() const noexcept { - return true; + return is_main_interpreter_; } private: c10::impl::PyInterpreter* impl_; + bool is_main_interpreter_; }; py::object torchDispatchFromTensorImpl( diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp index 2dac6151a798..001b02a3d7dd 100644 --- a/torch/csrc/Storage.cpp +++ b/torch/csrc/Storage.cpp @@ -74,9 +74,13 @@ PyObject* THPStorage_NewWithStorage( s->cdata = c10::MaybeOwned::owned(std::move(_storage)); - s->is_hermetic = false; - const auto& storage = THPStorage_Unpack(s); - storage.unsafeGetStorageImpl()->pyobj_slot()->init_pyobj(obj); + if (!c10::impl::HermeticPyObjectTLS::get_state()) { + s->is_hermetic = false; + const auto& storage = THPStorage_Unpack(s); + storage.unsafeGetStorageImpl()->pyobj_slot()->init_pyobj(obj); + } else { + s->is_hermetic = true; + } return obj; } @@ -84,6 +88,9 @@ PyObject* THPStorage_NewWithStorage( // Wraps the c10::Storage with a storage PyObject PyObject* THPStorage_Wrap(c10::Storage storage) { c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl(); + if (c10::impl::HermeticPyObjectTLS::get_state()) { + return THPStorage_NewWithStorage(THPStorageClass, std::move(storage)); + } c10::impl::PyObjectSlot* pyobj_slot = storage_impl->pyobj_slot(); std::optional maybe_pyobj = pyobj_slot->check_pyobj(); diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 5d88d71dc04c..0efef58d85f2 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include @@ -259,6 +260,10 @@ PyObject* THPVariable_Wrap(const at::TensorBase& var) { Py_RETURN_NONE; } + if (c10::impl::HermeticPyObjectTLS::get_state()) { + return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var); + } + std::optional mb_obj = var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(); if (mb_obj.has_value()) { @@ -374,6 +379,7 @@ static bool THPVariable_tryResurrect(THPVariable* self) { // Flip THPVariable to be non-owning // (near use-after-free miss here: fresh MaybeOwned is created breaking // reference on Tensor in struct BEFORE we overwrite the old one) + TORCH_INTERNAL_ASSERT(!c10::impl::HermeticPyObjectTLS::get_state()); self->cdata = MaybeOwned::borrowed(tensor); // NB: At this point, tensor *could* be dead (e.g., some other C++ thread @@ -2467,13 +2473,28 @@ static PyObject* THPVariable_NewWithVar( auto v = (THPVariable*)obj; // TODO: named constructor to avoid default initialization new (&v->cdata) MaybeOwned(); - v->cdata = MaybeOwned::owned(Variable(_var)); - const auto& var = THPVariable_Unpack(v); - var.unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(obj); - if (has_torch_dispatch_if_known.has_value() - ? *has_torch_dispatch_if_known - : check_has_torch_dispatch(obj)) { - var.unsafeGetTensorImpl()->set_python_dispatch(true); + if (c10::impl::HermeticPyObjectTLS::get_state()) { + // Do NOT initialize pyobj field on the tensor, you own the C++ + v->cdata = MaybeOwned::owned(Variable(_var)); + TORCH_INTERNAL_ASSERT( + !check_has_torch_dispatch(obj), + "While HermeticPyObject was enabled, we attempted to create a tensor " + "subclass with __torch_dispatch__. This violates the invariant that " + "operations in HermeticPyObject have equivalent C++ implementations. " + "If your operator registered from Python operator registration isn't " + "doing anything strange, there may be an internal PyTorch bug involving " + "not appropriately disabling TorchDispatchMode before executing " + "Python op registration."); + } else { + // Normal codepath + v->cdata = MaybeOwned::owned(Variable(_var)); + const auto& var = THPVariable_Unpack(v); + var.unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(obj); + if (has_torch_dispatch_if_known.has_value() + ? *has_torch_dispatch_if_known + : check_has_torch_dispatch(obj)) { + var.unsafeGetTensorImpl()->set_python_dispatch(true); + } } } return obj; diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index 23ed9cb1b22a..ed1084d443ec 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -84,6 +84,40 @@ inline static torch::CppFunction dispatch_str(const char* key, Func&& raw_f) { } } +struct EnableHermeticPyObject { + EnableHermeticPyObject() + : old_(c10::impl::HermeticPyObjectTLS::get_state()), + old_excluded_python_( + c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Python)), + old_python_( + c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Python)), + old_python_snapshot_(c10::impl::tls_is_dispatch_key_included( + at::DispatchKey::PythonTLSSnapshot)) { + c10::impl::HermeticPyObjectTLS::set_state(true); + c10::impl::tls_set_dispatch_key_excluded(at::DispatchKey::Python, true); + c10::impl::tls_set_dispatch_key_included(at::DispatchKey::Python, false); + c10::impl::tls_set_dispatch_key_included( + at::DispatchKey::PythonTLSSnapshot, false); + } + ~EnableHermeticPyObject() { + c10::impl::HermeticPyObjectTLS::set_state(old_); + c10::impl::tls_set_dispatch_key_excluded( + at::DispatchKey::Python, old_excluded_python_); + c10::impl::tls_set_dispatch_key_included( + at::DispatchKey::Python, old_python_); + c10::impl::tls_set_dispatch_key_included( + at::DispatchKey::PythonTLSSnapshot, old_python_snapshot_); + } + EnableHermeticPyObject(const EnableHermeticPyObject&) = delete; + EnableHermeticPyObject(EnableHermeticPyObject&&) = delete; + EnableHermeticPyObject& operator=(const EnableHermeticPyObject&) = delete; + EnableHermeticPyObject& operator=(EnableHermeticPyObject&&) = delete; + bool old_; + bool old_excluded_python_; + bool old_python_; + bool old_python_snapshot_; +}; + class PythonKernelHolder : public c10::OperatorKernel { c10::SafePyObject func_; c10::DispatchKey dispatch_key_;