pytorch/c10/core/impl/PyObjectSlot.h
Sam Gross b0b5069e21 Rework PyObject preservation v2
Make the PyObject preservation scheme thread-safe with free-threaded (nogil) Python. The general idea is:

* Python Tensor and Storage objects always hold a strong reference to their underlying c10 object
* c10 objects hold a strong reference to their Python objects if there's at least one other reference to the c10 object

This is implemented in `intrusive_ptr`:

* The topmost bit of the weakref count (`kHasPyObject`) now indicates whether the `intrusive_ptr_target` has an associated PyObject. So `kHasPyObject` is one bit, the weakref count is now 31 bits, and the strong refcount remains 32 bits.
* When the strong reference count increases from one to two and `kHasPyObject` is set, we incref the associated Python object to ensure that it's kept alive.
* When the strong reference count decreases from two to one (i.e., there are no C++ references to the `intrusive_ptr_target` other than the one from the Python object), we decref the associated Python object to break the cycle.
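
The two transitions above can be sketched as follows. This is a simplified, single-threaded model and not the actual `intrusive_ptr` code: `FakePyObject`, `Target`, `attach`, and the single 64-bit `combined` word (strong count in the low 32 bits, weak count above it) are assumptions made for illustration. Only `kHasPyObject` and the bit widths come from the description.

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for a Python object; only its refcount is modeled.
struct FakePyObject {
  int refcnt = 1;
};

// Assumed layout: low 32 bits = strong count, next 31 bits = weak count,
// top bit = kHasPyObject.
constexpr uint64_t kHasPyObject = uint64_t{1} << 63;
constexpr uint64_t kStrongMask = 0xFFFFFFFFull;

// Hypothetical stand-in for intrusive_ptr_target.
struct Target {
  std::atomic<uint64_t> combined{0};
  FakePyObject* pyobj = nullptr;

  static uint32_t strong(uint64_t c) {
    return static_cast<uint32_t>(c & kStrongMask);
  }
  static bool has_pyobj(uint64_t c) {
    return (c & kHasPyObject) != 0;
  }

  // Record the associated Python object and set the flag bit.
  void attach(FakePyObject* obj) {
    pyobj = obj;
    combined.fetch_or(kHasPyObject, std::memory_order_release);
  }

  void incref() {
    uint64_t after = combined.fetch_add(1, std::memory_order_relaxed) + 1;
    // 1 -> 2: a C++ owner appeared, so keep the Python object alive.
    if (has_pyobj(after) && strong(after) == 2) {
      ++pyobj->refcnt;
    }
  }

  void decref() {
    uint64_t after = combined.fetch_sub(1, std::memory_order_acq_rel) - 1;
    // 2 -> 1: only the Python object still owns us, so break the cycle.
    if (has_pyobj(after) && strong(after) == 1) {
      --pyobj->refcnt;
    }
  }
};
```

The sketch leaves out destruction when the count reaches zero and any handling of concurrent transitions, both of which a real implementation has to address.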

Other benefits:

* We can delete a lot of the copypasta from Python's internal `subtype_dealloc`
* This fixes the weakref and GC bugs we had in the previous scheme. Python weakrefs on Tensors and Storages should just work as expected now.

Risks:

* Extra branch for reference count operations on `intrusive_ptr<TensorImpl>`, `intrusive_ptr<StorageImpl>`, and the generic `intrusive_ptr<intrusive_ptr_target>` even when we're not using Python.
* It's a big change

Second attempt at #166342
2025-11-11 14:49:47 +00:00

#pragma once

#include <c10/core/impl/HermeticPyObjectTLS.h>
#include <c10/core/impl/PyInterpreter.h>
#include <c10/core/impl/PyInterpreterHooks.h>
#include <c10/util/python_stub.h>
#include <optional>

#include <atomic>

namespace torch::utils {
class PyObjectPreservation;
}

namespace c10::impl {

struct C10_API PyObjectSlot {
 public:
  PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}

  // Query the PyObject interpreter. This may return null if there is no
  // interpreter.
  PyInterpreter* pyobj_interpreter() const {
    return pyobj_interpreter_.load(std::memory_order_acquire);
  }

  PyInterpreter& load_pyobj_interpreter() const {
    auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
    TORCH_INTERNAL_ASSERT(
        interpreter, "cannot access PyObject for Tensor - no interpreter set");
    return *interpreter;
  }

  PyObject* load_pyobj() const {
    return pyobj_.load(std::memory_order_acquire);
  }

  bool has_unique_reference() const {
    PyObject* pyobj = load_pyobj();
    return pyobj != nullptr && load_pyobj_interpreter()->refcnt(pyobj) == 1;
  }

  void clear() {
    pyobj_.store(nullptr, std::memory_order_relaxed);
    pyobj_interpreter_.store(nullptr, std::memory_order_relaxed);
  }

  // Non thread-safe swap
  void swap(PyObjectSlot& other) noexcept {
    PyInterpreter* tmp_interpreter =
        pyobj_interpreter_.load(std::memory_order_relaxed);
    pyobj_interpreter_.store(
        other.pyobj_interpreter_.load(std::memory_order_relaxed),
        std::memory_order_relaxed);
    other.pyobj_interpreter_.store(tmp_interpreter, std::memory_order_relaxed);

    PyObject* tmp_pyobj = pyobj_.load(std::memory_order_relaxed);
    pyobj_.store(
        other.pyobj_.load(std::memory_order_relaxed),
        std::memory_order_relaxed);
    other.pyobj_.store(tmp_pyobj, std::memory_order_relaxed);
  }

 private:
  // This is now always the global interpreter if the PyObject is set.
  // Maybe we can remove this field some day...
  std::atomic<PyInterpreter*> pyobj_interpreter_;

  // The PyObject representing this Tensor or nullptr. Ownership is managed
  // by intrusive_ptr. By the time the PyObjectSlot is destroyed, this
  // reference is already dead.
  std::atomic<PyObject*> pyobj_;

  friend class torch::utils::PyObjectPreservation;
};

} // namespace c10::impl
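
For readers who want to try the slot's load, clear, and swap pattern without building PyTorch, here is a minimal stand-alone model. `MiniSlot`, `FakeInterpreter`, and `FakePyObject` are hypothetical stand-ins and not part of c10. The sketch only mirrors the acquire loads and relaxed stores used in the header and, like the original `swap`, is not safe to call concurrently.

```cpp
#include <atomic>
#include <cassert>

// Hypothetical stand-ins; the real PyInterpreter and PyObject are not modeled.
struct FakeInterpreter {};
struct FakePyObject {};

struct MiniSlot {
  std::atomic<FakeInterpreter*> interpreter{nullptr};
  std::atomic<FakePyObject*> pyobj{nullptr};

  // Acquire load, so a reader that sees the pointer also sees prior writes
  // made by the thread that published it with a release store.
  FakePyObject* load_pyobj() const {
    return pyobj.load(std::memory_order_acquire);
  }

  // Reset both fields; the slot does not own the object, so nothing is freed.
  void clear() {
    pyobj.store(nullptr, std::memory_order_relaxed);
    interpreter.store(nullptr, std::memory_order_relaxed);
  }

  // Field-by-field exchange. Each load and store is atomic on its own, but
  // the swap as a whole is not, so callers must provide external exclusion.
  void swap(MiniSlot& other) noexcept {
    FakeInterpreter* tmp_interpreter =
        interpreter.load(std::memory_order_relaxed);
    interpreter.store(
        other.interpreter.load(std::memory_order_relaxed),
        std::memory_order_relaxed);
    other.interpreter.store(tmp_interpreter, std::memory_order_relaxed);

    FakePyObject* tmp_pyobj = pyobj.load(std::memory_order_relaxed);
    pyobj.store(
        other.pyobj.load(std::memory_order_relaxed),
        std::memory_order_relaxed);
    other.pyobj.store(tmp_pyobj, std::memory_order_relaxed);
  }
};
```

After `a.swap(b)` on a populated `a` and an empty `b`, `b` holds both pointers and `a` holds none.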