Make the PyObject preservation scheme thread-safe with free-threaded (nogil) Python.

The general idea is:

* Python Tensor and Storage objects always hold a strong reference to their underlying c10 object.
* c10 objects hold a strong reference to their Python objects if there's at least one other reference to the c10 object.

This is implemented in `intrusive_ptr`:

* The topmost bit (`kHasPyObject`) of the weakref count is now used to indicate whether the `intrusive_ptr_target` has an associated PyObject. So `kHasPyObject` is one bit, the weakref count is now 31 bits, and the strong refcount remains 32 bits.
* When the reference count increases from one to two and `kHasPyObject` is set, we incref the associated Python object to ensure that it's kept alive.
* When the reference count decreases from two to one (i.e., there are no C++ references to the `intrusive_ptr_target` other than the one from the Python object), we decref the associated Python object to break the cycle.

Other benefits:

* We can delete a lot of the copypasta from CPython's internal `subtype_dealloc`.
* This fixes the weakref and GC bugs we had in the previous scheme. Python weakrefs on Tensors and Storages should now just work as expected.

Risks:

* An extra branch on reference count operations for `intrusive_ptr<TensorImpl>`, `intrusive_ptr<StorageImpl>`, and the generic `intrusive_ptr<intrusive_ptr_target>`, even when we're not using Python.
* It's a big change.

Second attempt at #166342
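To make the refcount transitions concrete, here is a minimal C++ sketch of the combined refcount word described above. All names (`kHasPyObject`, `strong_count`, the `incref_pyobj`/`decref_pyobj` hooks) are illustrative stand-ins, not the actual identifiers in `c10/util/intrusive_ptr.h`; the real implementation also handles the weak count, zero transitions, and interpreter dispatch that are omitted here.

// Illustrative sketch of the combined refcount word (not the real
// intrusive_ptr code). Layout, per the description above:
//   bit  63     kHasPyObject (set when a PyObject is attached)
//   bits 62-32  weak refcount (31 bits)
//   bits 31-0   strong refcount (32 bits)
#include <atomic>
#include <cstdint>

constexpr uint64_t kHasPyObject = uint64_t(1) << 63;
constexpr uint64_t kStrongMask = 0xFFFFFFFFull;

inline uint32_t strong_count(uint64_t combined) {
  return static_cast<uint32_t>(combined & kStrongMask);
}

// Hypothetical hooks standing in for the interpreter calls that
// incref/decref the associated Python object.
void incref_pyobj() { /* Py_INCREF via the PyInterpreter */ }
void decref_pyobj() { /* Py_DECREF via the PyInterpreter */ }

// Strong incref: on a 1 -> 2 transition with a PyObject attached, take a
// Python reference so the PyObject is kept alive.
void retain(std::atomic<uint64_t>& combined) {
  uint64_t old = combined.fetch_add(1, std::memory_order_relaxed);
  if (strong_count(old) == 1 && (old & kHasPyObject)) {
    incref_pyobj();
  }
}

// Strong decref: on a 2 -> 1 transition with a PyObject attached, the only
// remaining C++ reference is the one held by the PyObject itself, so drop
// the Python reference to break the reference cycle.
void release(std::atomic<uint64_t>& combined) {
  uint64_t old = combined.fetch_sub(1, std::memory_order_acq_rel);
  if (strong_count(old) == 2 && (old & kHasPyObject)) {
    decref_pyobj();
  }
}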
#pragma once

#include <c10/core/impl/HermeticPyObjectTLS.h>
#include <c10/core/impl/PyInterpreter.h>
#include <c10/core/impl/PyInterpreterHooks.h>
#include <c10/util/python_stub.h>

#include <optional>

#include <atomic>
namespace torch::utils {
class PyObjectPreservation;
}
namespace c10::impl {

struct C10_API PyObjectSlot {
 public:
  PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}
  // Query the PyObject interpreter. This may return null if there is no
  // interpreter.
  PyInterpreter* pyobj_interpreter() const {
    return pyobj_interpreter_.load(std::memory_order_acquire);
  }

  // Like pyobj_interpreter(), but asserts that an interpreter is set.
  PyInterpreter& load_pyobj_interpreter() const {
    auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
    TORCH_INTERNAL_ASSERT(
        interpreter, "cannot access PyObject for Tensor - no interpreter set");
    return *interpreter;
  }
  // The stored PyObject, or nullptr if none is set.
  PyObject* load_pyobj() const {
    return pyobj_.load(std::memory_order_acquire);
  }

  // True if a PyObject is set and its Python reference count (queried
  // through the interpreter) is exactly one.
  bool has_unique_reference() const {
    PyObject* pyobj = load_pyobj();
    return pyobj != nullptr && load_pyobj_interpreter()->refcnt(pyobj) == 1;
  }
  // Clear the stored PyObject and interpreter.
  void clear() {
    pyobj_.store(nullptr, std::memory_order_relaxed);
    pyobj_interpreter_.store(nullptr, std::memory_order_relaxed);
  }
  // Non-thread-safe swap
  void swap(PyObjectSlot& other) noexcept {
    PyInterpreter* tmp_interpreter =
        pyobj_interpreter_.load(std::memory_order_relaxed);
    pyobj_interpreter_.store(
        other.pyobj_interpreter_.load(std::memory_order_relaxed),
        std::memory_order_relaxed);
    other.pyobj_interpreter_.store(tmp_interpreter, std::memory_order_relaxed);

    PyObject* tmp_pyobj = pyobj_.load(std::memory_order_relaxed);
    pyobj_.store(
        other.pyobj_.load(std::memory_order_relaxed),
        std::memory_order_relaxed);
    other.pyobj_.store(tmp_pyobj, std::memory_order_relaxed);
  }
 private:
  // This is now always the global interpreter if the PyObject is set.
  // Maybe we can remove this field some day...
  std::atomic<PyInterpreter*> pyobj_interpreter_;

  // The PyObject representing this Tensor or nullptr. Ownership is managed
  // by intrusive_ptr. By the time the PyObjectSlot is destroyed, this
  // reference is already dead.
  std::atomic<PyObject*> pyobj_;

  friend class torch::utils::PyObjectPreservation;
};

} // namespace c10::impl
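For orientation, here is a hypothetical caller-side sketch of how the slot is read. It assumes `TensorImpl` exposes its slot through a `pyobj_slot()` accessor; treat that accessor, and the function below, as illustrative rather than part of this header.

// Hypothetical usage sketch; pyobj_slot() is assumed, not defined in this
// header.
#include <c10/core/TensorImpl.h>

PyObject* maybe_cached_pyobj(c10::TensorImpl* impl) {
  c10::impl::PyObjectSlot* slot = impl->pyobj_slot();
  // load_pyobj() is an acquire load, pairing with the store performed
  // when the PyObject is first attached (done elsewhere, by
  // torch::utils::PyObjectPreservation, which is a friend of the slot).
  PyObject* obj = slot->load_pyobj();
  return obj;  // nullptr if no Python object is currently bound
}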