mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-14 22:25:03 +08:00

Compare commits: 4 commits (ruisi/fix_ ... ciflow/h10)

| Author | SHA1 | Date |
| --- | --- | --- |
|  | 2ee209ebd7 |  |
|  | 2aba180114 |  |
|  | 45b2c3d312 |  |
|  | 5b1e112cf9 |  |
@@ -245,6 +245,9 @@ class TORCH_API TensorBase {
  size_t weak_use_count() const noexcept {
    return impl_.weak_use_count();
  }

  bool is_uniquely_owned() const noexcept {
    return impl_.is_uniquely_owned();
  }

  std::string toString() const;

@@ -10,6 +10,13 @@
   ...
}

{
   ignore_empty_generic_uninitialised_conditional_jump
   Memcheck:Cond
   fun:_ZN2at6detail13empty_genericEN3c108ArrayRefIlEEPNS1_9AllocatorENS1_14DispatchKeySetENS1_10ScalarTypeESt8optionalINS1_12MemoryFormatEE
   ...
}

{
   Cond_cuda
   Memcheck:Cond
@@ -44,7 +44,7 @@ struct C10_API SafePyObject {
      (*other.pyinterpreter_)->incref(other.data_);
    }
    if (data_ != nullptr) {
      (*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
      (*pyinterpreter_)->decref(data_);
    }
    data_ = other.data_;
    pyinterpreter_ = other.pyinterpreter_;
@@ -53,7 +53,7 @@ struct C10_API SafePyObject {

  ~SafePyObject() {
    if (data_ != nullptr) {
      (*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
      (*pyinterpreter_)->decref(data_);
    }
  }

@@ -48,6 +48,30 @@ void warnDeprecatedDataPtr() {
  TORCH_CHECK(false, "Cannot access data pointer of Storage that is invalid.");
}

void StorageImpl::incref_pyobject() const {
  // Because intrusive_ptr incref uses relaxed memory order, we need to
  // do an acquire fence to ensure that the kHasPyObject bit was
  // observed before the load of the PyObject* below.
  // NB: This is a no-op on x86/x86-64
  std::atomic_thread_fence(std::memory_order_acquire);

  PyObject* obj = pyobj_slot_.load_pyobj();
  (*pyobj_slot_.pyobj_interpreter())->incref(obj);
}

void StorageImpl::decref_pyobject() const {
  PyObject* obj = pyobj_slot_.load_pyobj();
  (*pyobj_slot_.pyobj_interpreter())->decref(obj);
}

bool StorageImpl::try_incref_pyobject() const {
  c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter();
  if (C10_UNLIKELY(!interp)) {
    return false;
  }
  return (*interp)->try_incref(pyobj_slot_);
}

void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr) {
  // Allowlist verification.
  // Only if the devicetype is in the allowlist,

@@ -105,6 +105,12 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
    data_ptr_.clear();
  }

  void incref_pyobject() const override final;

  void decref_pyobject() const override final;

  bool try_incref_pyobject() const override final;

  size_t nbytes() const {
    // OK to do this instead of maybe_as_int as nbytes is guaranteed positive
    TORCH_CHECK(!size_bytes_is_heap_allocated_);
@@ -370,4 +376,18 @@ C10_API c10::intrusive_ptr<c10::StorageImpl> make_storage_impl(
    bool resizable,
    std::optional<at::Device> device_opt);

namespace detail {

#ifndef C10_MOBILE
template <class T>
struct TargetTraits<
    T,
    std::enable_if_t<
        std::is_base_of_v<c10::StorageImpl, std::remove_cv_t<T>>>> {
  static constexpr bool can_have_pyobject = true;
};
#endif

} // namespace detail

} // namespace c10

@@ -277,7 +277,6 @@ void TensorImpl::release_resources() {
  if (storage_) {
    storage_ = {};
  }
  pyobj_slot_.maybe_destroy_pyobj();
}

#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
@@ -989,6 +988,30 @@ void TensorImpl::empty_tensor_restride_symint(MemoryFormat memory_format) {
  }
}

void TensorImpl::incref_pyobject() const {
  // Because intrusive_ptr incref uses relaxed memory order, we need to
  // do an acquire fence to ensure that the kHasPyObject bit was
  // observed before the load of the PyObject* below.
  // NB: This is a no-op on x86/x86-64
  std::atomic_thread_fence(std::memory_order_acquire);

  PyObject* obj = pyobj_slot_.load_pyobj();
  (*pyobj_slot_.pyobj_interpreter())->incref(obj);
}

void TensorImpl::decref_pyobject() const {
  PyObject* obj = pyobj_slot_.load_pyobj();
  (*pyobj_slot_.pyobj_interpreter())->decref(obj);
}

bool TensorImpl::try_incref_pyobject() const {
  c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter();
  if (C10_UNLIKELY(!interp)) {
    return false;
  }
  return (*interp)->try_incref(pyobj_slot_);
}

namespace impl {

namespace {

@@ -2178,6 +2178,12 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    return &pyobj_slot_;
  }

  void incref_pyobject() const override final;

  void decref_pyobject() const override final;

  bool try_incref_pyobject() const override final;

 private:
  // See NOTE [std::optional operator usage in CUDA]
  // We probably don't want to expose this publicly until
@@ -3079,6 +3085,19 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
  friend class C10_TensorImpl_Size_Check_Dummy_Class;
};

namespace detail {

#ifndef C10_MOBILE
template <class T>
struct TargetTraits<
    T,
    std::enable_if_t<std::is_base_of_v<c10::TensorImpl, std::remove_cv_t<T>>>> {
  static constexpr bool can_have_pyobject = true;
};
#endif

} // namespace detail

// Note [TensorImpl size constraints]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Changed the size of TensorImpl? If the size went down, good for

@@ -11,8 +11,11 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {

  void incref(PyObject* pyobj) const override {} // do nothing

  void decref(PyObject* pyobj, bool has_pyobj_slot) const override {
  } // do nothing
  void decref(PyObject* pyobj) const override {} // do nothing

  bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const override {
    return false;
  }

#define PANIC(m) \
  TORCH_INTERNAL_ASSERT( \
@@ -20,6 +23,10 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
      "attempted to call " #m \
      " on a Tensor with nontrivial PyObject after corresponding interpreter died")

  size_t refcnt(PyObject* pyobj) const override {
    PANIC(refcnt);
  }

  c10::intrusive_ptr<TensorImpl> detach(const TensorImpl* self) const override {
    PANIC(detach);
  }

@@ -18,6 +18,9 @@ namespace c10 {
struct IValue;
class OperatorHandle;
struct TensorImpl;
namespace impl {
struct PyObjectSlot;
} // namespace impl
} // namespace c10

namespace torch::jit {
@@ -126,9 +129,12 @@ struct C10_API PyInterpreterVTable {

  // Run Py_INCREF on a PyObject.
  virtual void incref(PyObject* pyobj) const = 0;
  // Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call
  // See NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
  virtual void decref(PyObject* pyobj, bool has_pyobj_slot) const = 0;
  // Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call.
  virtual void decref(PyObject* pyobj) const = 0;
  // Run PyUnstable_TryIncRef on a PyObject if it's not NULL.
  virtual bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const = 0;
  // Run Py_REFCNT on a PyObject.
  virtual size_t refcnt(PyObject* pyobj) const = 0;

  // Perform a detach by deferring to the __torch_dispatch__ implementation of
  // detach, which will also arrange for the PyObject to get copied in this

@@ -1,56 +0,0 @@
#include <c10/core/impl/PyObjectSlot.h>

namespace c10::impl {

PyObjectSlot::PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}

PyObjectSlot::~PyObjectSlot() {
  maybe_destroy_pyobj();
}

void PyObjectSlot::maybe_destroy_pyobj() {
  if (owns_pyobj()) {
    TORCH_INTERNAL_ASSERT(pyobj_interpreter_ != nullptr);
    TORCH_INTERNAL_ASSERT(pyobj_ != nullptr);
    (*pyobj_interpreter_.load(std::memory_order_acquire))
        ->decref(_unchecked_untagged_pyobj(), /*has_pyobj_slot*/ true);
    // NB: this destructor can only be entered when there are no
    // references to this C++ object (obviously), NOR any references
    // to the PyObject (if there are references to the PyObject,
    // then the PyObject holds an owning reference to the tensor).
    // So it is OK to clear pyobj_ here as it is impossible for it to
    // be used again (modulo weak reference races)
    pyobj_ = nullptr; // for safety
  }
}

PyInterpreter* PyObjectSlot::pyobj_interpreter() {
  return pyobj_interpreter_.load(std::memory_order_acquire);
}

PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const {
  // NOLINTNEXTLINE(performance-no-int-to-ptr)
  return reinterpret_cast<PyObject*>(
      reinterpret_cast<uintptr_t>(pyobj_) & ~0x1ULL);
}

PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const {
  auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
  if (interpreter) {
    return *interpreter;
  }
  TORCH_CHECK(false, "cannot access PyObject for Tensor - no interpreter set");
}

bool PyObjectSlot::owns_pyobj() {
  // NOLINTNEXTLINE(performance-no-int-to-ptr)
  return reinterpret_cast<uintptr_t>(pyobj_) & 1;
}

void PyObjectSlot::set_owns_pyobj(bool b) {
  // NOLINTNEXTLINE(performance-no-int-to-ptr)
  pyobj_ = reinterpret_cast<PyObject*>(
      reinterpret_cast<uintptr_t>(_unchecked_untagged_pyobj()) | b);
}

} // namespace c10::impl
@ -8,117 +8,58 @@
|
||||
|
||||
#include <atomic>
|
||||
|
||||
namespace torch::utils {
|
||||
class PyObjectPreservation;
|
||||
}
|
||||
|
||||
namespace c10::impl {
|
||||
|
||||
struct C10_API PyObjectSlot {
|
||||
public:
|
||||
PyObjectSlot();
|
||||
|
||||
~PyObjectSlot();
|
||||
|
||||
void maybe_destroy_pyobj();
|
||||
|
||||
// Associate the TensorImpl with the specified PyObject, and, if necessary,
|
||||
// also tag the interpreter.
|
||||
//
|
||||
// NB: This lives in a header so that we can inline away the switch on status
|
||||
//
|
||||
// NB: THIS FUNCTION CAN RAISE AN EXCEPTION. Make sure to clean up after
|
||||
// PyObject if necessary!
|
||||
void init_pyobj(PyObject* pyobj) {
|
||||
pyobj_interpreter_.store(
|
||||
getGlobalPyInterpreter(), std::memory_order_relaxed);
|
||||
pyobj_ = pyobj;
|
||||
}
|
||||
PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}
|
||||
|
||||
// Query the PyObject interpreter. This may return null if there is no
|
||||
// interpreter. This is racy!
|
||||
PyInterpreter* pyobj_interpreter();
|
||||
|
||||
PyObject* _unchecked_untagged_pyobj() const;
|
||||
|
||||
// Test the interpreter tag. If tagged for the current interpreter, return
|
||||
// a non-nullopt (but possibly null) PyObject. If (possibly) untagged,
|
||||
// returns a nullopt. If it is definitely invalid, raises an error.
|
||||
//
|
||||
// If `ignore_hermetic_tls` is false and this function is called from a
|
||||
// hermetic context (ie, `HermeticPyObjectTLS::get_state()` is true), then
|
||||
// nullopt is returned. If `ignore_hermetic_tls` is true, then the hermetic
|
||||
// context is ignored, allowing you to check the interpreter tag of a
|
||||
// nonhermetic PyObject from within a hermetic context. This is necessary
|
||||
// because there are some cases where the deallocator function of a
|
||||
// nonhermetic PyObject is called from within a hermetic context, so it must
|
||||
// be properly treated as a nonhermetic PyObject.
|
||||
//
|
||||
// NB: this lives in header so that we can avoid actually creating the
|
||||
// std::optional
|
||||
|
||||
// @todo alban: I'm not too sure what's going on here, we can probably delete
|
||||
// it but it's worthwhile making sure
|
||||
std::optional<PyObject*> check_pyobj(bool ignore_hermetic_tls = false) const {
|
||||
impl::PyInterpreter* interpreter =
|
||||
pyobj_interpreter_.load(std::memory_order_acquire);
|
||||
if (interpreter == nullptr) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) {
|
||||
return std::nullopt;
|
||||
} else {
|
||||
return _unchecked_untagged_pyobj();
|
||||
}
|
||||
// interpreter.
|
||||
PyInterpreter* pyobj_interpreter() const {
|
||||
return pyobj_interpreter_.load(std::memory_order_acquire);
|
||||
}
|
||||
|
||||
PyInterpreter& load_pyobj_interpreter() const;
|
||||
PyInterpreter& load_pyobj_interpreter() const {
|
||||
auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
interpreter, "cannot access PyObject for Tensor - no interpreter set");
|
||||
return *interpreter;
|
||||
}
|
||||
|
||||
bool owns_pyobj();
|
||||
PyObject* load_pyobj() const {
|
||||
return pyobj_.load(std::memory_order_acquire);
|
||||
}
|
||||
|
||||
void set_owns_pyobj(bool b);
|
||||
void store_pyobj(PyObject* obj) {
|
||||
pyobj_.store(obj, std::memory_order_release);
|
||||
}
|
||||
|
||||
bool has_unique_reference() const {
|
||||
PyObject* pyobj = load_pyobj();
|
||||
return pyobj != nullptr && load_pyobj_interpreter()->refcnt(pyobj) == 1;
|
||||
}
|
||||
|
||||
void clear() {
|
||||
pyobj_.store(nullptr, std::memory_order_relaxed);
|
||||
pyobj_interpreter_.store(nullptr, std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
private:
|
||||
// This field contains the interpreter tag for this object. See
|
||||
// Note [Python interpreter tag] for general context
|
||||
//
|
||||
// Note [Memory ordering on Python interpreter tag]
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
// What memory_order do we need when accessing this atomic? We don't
|
||||
// need a single total modification order (as provided by
|
||||
// memory_order_seq_cst) as pyobj_interpreter_ is monotonic: it can only
|
||||
// transition from -1 to some positive integer and never changes afterwards.
|
||||
// Because there is only one modification, it trivially already has a total
|
||||
// modification order (e.g., we don't need fences or locked instructions on
|
||||
// x86)
|
||||
//
|
||||
// In fact, one could make a reasonable argument that relaxed reads are OK,
|
||||
// due to the presence of external locking (GIL) to ensure that interactions
|
||||
// with other data structures are still correctly synchronized, so that
|
||||
// we fall in the "Single-Location Data Structures" case as described in
|
||||
// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf
|
||||
// However, on x86, it doesn't matter if I use acquire or relaxed on the load
|
||||
// as I get the same assembly in both cases. So I just use the more
|
||||
// conservative acquire (which will impede compiler optimizations but I don't
|
||||
// care)
|
||||
// This is now always the global interpreter if the PyObject is set.
|
||||
// Maybe we can remove this field some day...
|
||||
std::atomic<PyInterpreter*> pyobj_interpreter_;
|
||||
|
||||
// This field contains a reference to a PyObject representing this Tensor.
|
||||
// If pyobj is nullptr, when we transfer Tensor to Python, we allocate a new
|
||||
// PyObject for it and set this field. This field does not have to be
|
||||
// protected by an atomic as it is only allowed to be accessed when you hold
|
||||
// the GIL, or during destruction of the tensor.
|
||||
//
|
||||
// When a PyObject dies, you are obligated to clear this field
|
||||
// (otherwise, you will try to use-after-free the pyobj); this currently
|
||||
// occurs in THPVariable_clear in torch/csrc/autograd/python_variable.cpp
|
||||
//
|
||||
// NB: Ordinarily, this should not be a strong reference, as if the
|
||||
// PyObject owns the Tensor, this would create a reference cycle.
|
||||
// However, sometimes this ownership flips. To track who owns
|
||||
// who, this has a single pointer tag indicating whether or not the
|
||||
// C++ object owns the PyObject (the common case, zero, means PyObject
|
||||
// owns the C++ object); see _unchecked_untagged_pyobj for raw access
|
||||
// or check_pyobj for checked access. See references to PyObject
|
||||
// resurrection in torch/csrc/autograd/python_variable.cpp
|
||||
PyObject* pyobj_;
|
||||
// The PyObject representing this Tensor or nullptr. Ownership is managed
|
||||
// by intrusive_ptr. By the time the PyObjectSlot is destroyed, this
|
||||
// reference is already dead.
|
||||
std::atomic<PyObject*> pyobj_;
|
||||
|
||||
friend class torch::utils::PyObjectPreservation;
|
||||
};
|
||||
|
||||
} // namespace c10::impl
|
||||
|
||||
@ -12,6 +12,10 @@ template <typename, typename...>
|
||||
class class_;
|
||||
}
|
||||
|
||||
namespace torch::utils {
|
||||
class PyObjectPreservation;
|
||||
}
|
||||
|
||||
namespace c10 {
|
||||
class intrusive_ptr_target;
|
||||
namespace raw {
|
||||
@ -33,6 +37,8 @@ constexpr uint64_t kImpracticallyHugeWeakReferenceCount =
|
||||
constexpr uint64_t kReferenceCountOne = 1;
|
||||
constexpr uint64_t kWeakReferenceCountOne = (kReferenceCountOne << 32);
|
||||
constexpr uint64_t kUniqueRef = (kReferenceCountOne | kWeakReferenceCountOne);
|
||||
// Indicates whether the object has a PyObject wrapper.
|
||||
constexpr uint64_t kHasPyObject = (uint64_t(1) << 63);
|
||||
|
||||
template <class TTarget>
|
||||
struct intrusive_target_default_null_type final {
|
||||
@ -55,7 +61,11 @@ inline uint32_t refcount(uint64_t combined_refcount) {
|
||||
}
|
||||
|
||||
inline uint32_t weakcount(uint64_t combined_refcount) {
|
||||
return static_cast<uint32_t>(combined_refcount >> 32);
|
||||
return static_cast<uint32_t>((combined_refcount & ~kHasPyObject) >> 32);
|
||||
}
|
||||
|
||||
inline bool has_pyobject(uint64_t combined_refcount) {
|
||||
return (combined_refcount & kHasPyObject) != 0;
|
||||
}
|
||||
|
||||
// The only requirement for refcount increment is that it happens-before
|
||||
@ -66,12 +76,6 @@ inline uint64_t atomic_combined_refcount_increment(
|
||||
return combined_refcount.fetch_add(inc, std::memory_order_relaxed) + inc;
|
||||
}
|
||||
|
||||
inline uint32_t atomic_refcount_increment(
|
||||
std::atomic<uint64_t>& combined_refcount) {
|
||||
return detail::refcount(atomic_combined_refcount_increment(
|
||||
combined_refcount, kReferenceCountOne));
|
||||
}
|
||||
|
||||
inline uint32_t atomic_weakcount_increment(
|
||||
std::atomic<uint64_t>& combined_refcount) {
|
||||
return detail::weakcount(atomic_combined_refcount_increment(
|
||||
@ -99,6 +103,11 @@ inline uint32_t atomic_weakcount_decrement(
|
||||
combined_refcount, kWeakReferenceCountOne));
|
||||
}
|
||||
|
||||
template <class T, class = void>
|
||||
struct TargetTraits {
|
||||
static constexpr bool can_have_pyobject = false;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/**
|
||||
@ -155,6 +164,23 @@ class C10_API intrusive_ptr_target {
|
||||
// we can atomically operate on both at the same time for performance
|
||||
// and defined behaviors.
|
||||
//
|
||||
// Note [PyObject preservation for Tensor and Storages]
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
// intrusive_ptr has special support for preserving PyObject wrappers
|
||||
// for TensorImpl and StorageImpl. The most significant bit (kHasPyObject) of
|
||||
// the combined_refcount_ is used to indicate whether the object has a
|
||||
// PyObject wrapper.
|
||||
//
|
||||
// - The PyObject, if it exists, holds a strong reference to the
|
||||
// intrusive_ptr_target.
|
||||
//
|
||||
// - When the refcount goes from 1 to 2, we incref the PyObject.
|
||||
//
|
||||
// - When the refcount goes from 2 to 1, we decref the PyObject.
|
||||
//
|
||||
// In other words, the intrusive_ptr keeps the PyObject alive as long as there
|
||||
// are other C++ references to the intrusive_ptr_target.
|
||||
|
||||
mutable std::atomic<uint64_t> combined_refcount_;
|
||||
static_assert(sizeof(std::atomic<uint64_t>) == 8);
|
||||
static_assert(alignof(std::atomic<uint64_t>) == 8);
|
||||
@ -172,6 +198,8 @@ class C10_API intrusive_ptr_target {
|
||||
template <typename T>
|
||||
friend struct ExclusivelyOwnedTensorTraits;
|
||||
|
||||
friend class torch::utils::PyObjectPreservation;
|
||||
|
||||
protected:
|
||||
// protected destructor. We never want to destruct intrusive_ptr_target*
|
||||
// directly.
|
||||
@ -255,6 +283,16 @@ class C10_API intrusive_ptr_target {
|
||||
*/
|
||||
virtual void release_resources() {}
|
||||
|
||||
/**
|
||||
* These two methods are called when the refcount transitions between one
|
||||
* and two and the object has a PyObject wrapper.
|
||||
*/
|
||||
virtual void incref_pyobject() const {}
|
||||
virtual void decref_pyobject() const {}
|
||||
virtual bool try_incref_pyobject() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t refcount(std::memory_order order = std::memory_order_relaxed) const {
|
||||
return detail::refcount(combined_refcount_.load(order));
|
||||
}
|
||||
@ -265,6 +303,19 @@ class C10_API intrusive_ptr_target {
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
#ifndef C10_MOBILE
|
||||
template <>
|
||||
struct TargetTraits<c10::intrusive_ptr_target> {
|
||||
// A generic intrusive_ptr<intrusive_ptr_target> may actually be a TensorImpl
|
||||
// or StorageImpl, so we have to allow for PyObject support.
|
||||
static constexpr bool can_have_pyobject = true;
|
||||
};
|
||||
#endif
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <class TTarget, class NullType>
|
||||
class weak_intrusive_ptr;
|
||||
|
||||
@ -314,18 +365,34 @@ class intrusive_ptr final {
|
||||
|
||||
void retain_() {
|
||||
if (target_ != NullType::singleton()) {
|
||||
uint32_t new_refcount =
|
||||
detail::atomic_refcount_increment(target_->combined_refcount_);
|
||||
uint64_t combined = detail::atomic_combined_refcount_increment(
|
||||
target_->combined_refcount_, detail::kReferenceCountOne);
|
||||
uint32_t new_refcount = detail::refcount(combined);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
||||
new_refcount != 1,
|
||||
"intrusive_ptr: Cannot increase refcount after it reached zero.");
|
||||
|
||||
if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
|
||||
// If the refcount transitioned from 1 to 2, we need to incref the
|
||||
// PyObject. In other words, we need to ensure that the PyObject stays
|
||||
// alive now that we have a C++ reference to this object in addition to
|
||||
// the PyObject itself.
|
||||
if (C10_UNLIKELY(
|
||||
detail::has_pyobject(combined) &&
|
||||
detail::refcount(combined) == 2)) {
|
||||
target_->incref_pyobject();
|
||||
}
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
||||
!detail::has_pyobject(combined),
|
||||
"TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void reset_() noexcept {
|
||||
if (target_ != NullType::singleton()) {
|
||||
if (target_->combined_refcount_.load(std::memory_order_acquire) ==
|
||||
detail::kUniqueRef) {
|
||||
if (is_uniquely_owned()) {
|
||||
// Both counts are 1, so there are no weak references and
|
||||
// we are releasing the last strong reference. No other
|
||||
// threads can observe the effects of this target_ deletion
|
||||
@ -337,9 +404,10 @@ class intrusive_ptr final {
|
||||
|
||||
auto combined_refcount = detail::atomic_combined_refcount_decrement(
|
||||
target_->combined_refcount_, detail::kReferenceCountOne);
|
||||
if (detail::refcount(combined_refcount) == 0) {
|
||||
bool should_delete =
|
||||
(combined_refcount == detail::kWeakReferenceCountOne);
|
||||
uint32_t new_refcount = detail::refcount(combined_refcount);
|
||||
bool has_pyobject = detail::has_pyobject(combined_refcount);
|
||||
if (new_refcount == 0) {
|
||||
bool should_delete = detail::weakcount(combined_refcount) == 1;
|
||||
// See comment above about weakcount. As long as refcount>0,
|
||||
// weakcount is one larger than the actual number of weak references.
|
||||
// So we need to decrement it here.
|
||||
@ -356,6 +424,18 @@ class intrusive_ptr final {
|
||||
if (should_delete) {
|
||||
delete target_;
|
||||
}
|
||||
} else if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
|
||||
// If the refcount transitioned from 2 to 1, we need to decref the
|
||||
// PyObject. In other words, we don't want to keep the PyObject alive if
|
||||
// there are no C++ references to this object other than the PyObject
|
||||
// itself.
|
||||
if (C10_UNLIKELY(has_pyobject && new_refcount == 1)) {
|
||||
target_->decref_pyobject();
|
||||
}
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
||||
!has_pyobject,
|
||||
"TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set.");
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -522,6 +602,16 @@ class intrusive_ptr final {
|
||||
return use_count() == 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* Stronger than unique() in that it must not have any weakrefs as well.
|
||||
*/
|
||||
bool is_uniquely_owned() const noexcept {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(target_ != NullType::singleton());
|
||||
uint64_t combined =
|
||||
target_->combined_refcount_.load(std::memory_order_acquire);
|
||||
return (combined & ~detail::kHasPyObject) == detail::kUniqueRef;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an owning (!) pointer to the underlying object and makes the
|
||||
* intrusive_ptr instance invalid. That means the refcount is not decreased.
|
||||
@ -932,6 +1022,7 @@ class weak_intrusive_ptr final {
|
||||
if (target_ == NullType::singleton()) {
|
||||
return intrusive_ptr<TTarget, NullType>();
|
||||
} else {
|
||||
bool increfed = false;
|
||||
auto combined_refcount =
|
||||
target_->combined_refcount_.load(std::memory_order_relaxed);
|
||||
do {
|
||||
@ -940,12 +1031,31 @@ class weak_intrusive_ptr final {
|
||||
// Return nullptr.
|
||||
return intrusive_ptr<TTarget, NullType>();
|
||||
}
|
||||
if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
|
||||
if (detail::has_pyobject(combined_refcount) &&
|
||||
detail::refcount(combined_refcount) == 1 && !increfed) {
|
||||
// Object has a python wrapper with no other C++ references.
|
||||
// We need to to incref the Python object before we acquire a
|
||||
// strong reference to the C++ object to avoid a situation
|
||||
// where the Python object is deallocated concurrently.
|
||||
if (!target_->try_incref_pyobject()) {
|
||||
return intrusive_ptr<TTarget, NullType>();
|
||||
}
|
||||
increfed = true;
|
||||
}
|
||||
}
|
||||
} while (!target_->combined_refcount_.compare_exchange_weak(
|
||||
combined_refcount,
|
||||
combined_refcount + detail::kReferenceCountOne,
|
||||
std::memory_order_acquire,
|
||||
std::memory_order_relaxed));
|
||||
|
||||
if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
|
||||
if (increfed && detail::refcount(combined_refcount) != 1) {
|
||||
target_->decref_pyobject();
|
||||
}
|
||||
}
|
||||
|
||||
return intrusive_ptr<TTarget, NullType>(
|
||||
target_, raw::DontIncreaseRefcount{});
|
||||
}
|
||||
@ -1060,7 +1170,18 @@ namespace intrusive_ptr {
|
||||
// NullType::singleton to this function
|
||||
inline void incref(intrusive_ptr_target* self) {
|
||||
if (self) {
|
||||
detail::atomic_refcount_increment(self->combined_refcount_);
|
||||
uint64_t combined = detail::atomic_combined_refcount_increment(
|
||||
self->combined_refcount_, detail::kReferenceCountOne);
|
||||
|
||||
#ifndef C10_MOBILE
|
||||
if (C10_UNLIKELY(
|
||||
detail::has_pyobject(combined) &&
|
||||
detail::refcount(combined) == 2)) {
|
||||
self->incref_pyobject();
|
||||
}
|
||||
#else
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!detail::has_pyobject(combined));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
docs/source/accelerator/device.md (new file, 113 lines)
@@ -0,0 +1,113 @@
# Device Management

## Background

Device management handles basic operations like querying how many devices are available and switching between them. Accelerator backends need to wrap their device runtime's APIs and expose them to PyTorch.

The OpenReg implementation ([`OpenRegFunctions.h/cpp`][OpenReg Device Management]) shows how to wrap a third-party runtime. These functions are used throughout the backend - by streams, events, generators, and Python bindings.

## Design

Accelerator vendors need to implement these core functions:

| Function Name | Description | Application Scenarios |
| ------------------------- | ---------------------------------------------------------------- | --------------------- |
| `device_count()` | Query the total number of available devices in the system | - Application initialization<br>- Multi-device workload distribution<br>- Validating device indices before use |
| `current_device()` | Get the currently active device for the calling thread | - Debugging and logging<br>- Determining tensor placement<br>- Guard implementations |
| `set_device()` | Change the active device for subsequent operations | - Switching context between devices<br>- Initializing specific device resources<br>- Multi-GPU training loops |
| `exchange_device()` | Atomically swap device and return the previous device | - Implementing device guards<br>- Temporarily switching device context<br>- RAII-based device management |
| `maybe_exchange_device()` | Conditionally exchange device only if the index is valid (-1 OK) | - Safe device switching with optional indices<br>- Guard implementations with nullable device values |

These functions are building blocks for more complex features like streams, events, and memory management. Make sure to validate inputs and handle errors properly.
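To make the table above concrete, here is a minimal, hypothetical sketch of the five wrappers over an imaginary `myrt` runtime. The names `myrt` and `my_backend`, the fake two-device runtime, and the plain `int`/exception error handling are illustrative assumptions only; the real OpenReg implementation (shown in the next section) uses `c10::DeviceIndex` and `TORCH_CHECK`.

```cpp
// Illustrative sketch only - not the OpenReg code.
#include <stdexcept>

namespace myrt {                    // stand-in for a vendor device runtime
constexpr int kNumDevices = 2;      // pretend the machine has two devices
inline int g_current_device = 0;

inline int getDeviceCount(int* count) { *count = kNumDevices; return 0; }
inline int getDevice(int* device) { *device = g_current_device; return 0; }
inline int setDevice(int device) {
  if (device < 0 || device >= kNumDevices) return 1;  // non-zero = error
  g_current_device = device;
  return 0;
}
} // namespace myrt

namespace my_backend {

int device_count() noexcept {
  int count = 0;
  // Report zero devices instead of throwing during initialization.
  return myrt::getDeviceCount(&count) == 0 ? count : 0;
}

int current_device() {
  int device = -1;
  if (myrt::getDevice(&device) != 0)
    throw std::runtime_error("failed to query current device");
  return device;
}

void set_device(int device) {
  // Validate the index before touching the runtime.
  if (device < 0 || device >= device_count())
    throw std::runtime_error("device index out of range");
  if (myrt::setDevice(device) != 0)
    throw std::runtime_error("failed to set device");
}

// Switch to `device` and hand back the previous index (used by guards).
int exchange_device(int device) {
  int previous = current_device();
  if (device != previous) set_device(device);
  return previous;
}

// -1 means "leave the current device unchanged".
int maybe_exchange_device(int device) {
  return device == -1 ? current_device() : exchange_device(device);
}

} // namespace my_backend
```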
## Implementation

This section shows how to implement device management using `set_device` as an example. The implementation requires:

1. C++ wrappers around the device runtime
2. Python bindings to expose the C++ functions
3. User-friendly Python APIs

### C++ Side

Wrap the device runtime's API and add error handling. The `SetDevice` function shows this pattern:

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: OPENREG SetDevice FUNCTION
   :end-before: LITERALINCLUDE END: OPENREG SetDevice FUNCTION
   :linenos:
```

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: OPENREG set_device FUNCTION
   :end-before: LITERALINCLUDE END: OPENREG set_device FUNCTION
   :linenos:
```

### Binding

Expose the C++ functions to Python using pybind11:

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: MODULE SET DEVICE HELPER
   :end-before: LITERALINCLUDE END: MODULE SET DEVICE HELPER
   :linenos:
```

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
   :language: c++
   :start-after: LITERALINCLUDE START: OPENREG MODULE METHODS
   :end-before: LITERALINCLUDE END: OPENREG MODULE METHODS
   :linenos:
   :emphasize-lines: 5
```

### Python Side

Wrap the C++ bindings with user-friendly Python functions:

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/__init__.py
   :language: python
   :start-after: LITERALINCLUDE START: PYTHON SET DEVICE FUNCTION
   :end-before: LITERALINCLUDE END: PYTHON SET DEVICE FUNCTION
   :linenos:
```

Here's the complete mapping from C++ to Python:

| C++ Binding Function | C++ Binding API (pybind11) | Python User API | Description |
| -------------------- | -------------------------- | --------------- | ----------- |
| `_getDeviceCount` | `torch_openreg._C._get_device_count()` | `torch.openreg.device_count()` | Returns the total number of devices |
| `_getDevice` | `torch_openreg._C._get_device()` | `torch.openreg.current_device()` | Returns the current active device index |
| `_setDevice` | `torch_openreg._C._set_device(idx)` | `torch.openreg.set_device(idx)` | Sets the active device |
| `_exchangeDevice` | `torch_openreg._C._exchange_device(idx)` | N/A (internal use only) | Atomically swaps device and returns previous |
## Guard

Device guards provide automatic device switching with exception safety. They're similar to lock guards in C++: they switch the device on construction and restore it on destruction.

Implement `DeviceGuardImplInterface` to integrate with PyTorch's guard system:

```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h
   :language: c++
   :start-after: LITERALINCLUDE START: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
   :end-before: LITERALINCLUDE END: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
   :linenos:
```

**What needs to be implemented:**

1. **exchangeDevice()**: Switch to a new device and return the old one (used by guard constructors)
2. **getDevice()**: Get the current device
3. **setDevice()**: Set the active device
4. **Type checking**: Validate that the device type matches the backend

This makes the guard available to PyTorch for the `PrivateUse1` device type. Users can then use standard PyTorch device guards with the custom backend.
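As a rough sketch of what that looks like from the C++ side, once the guard implementation is registered for `PrivateUse1`, the standard C10 guard types can be used directly. This example assumes the backend is loaded and that device index 1 exists:

```cpp
#include <c10/core/Device.h>
#include <c10/core/DeviceGuard.h>

void run_on_second_device() {
  // Switches to PrivateUse1 device 1 through the registered guard impl.
  // The previous device is restored when `guard` goes out of scope,
  // even if an exception is thrown.
  c10::DeviceGuard guard(c10::Device(c10::DeviceType::PrivateUse1, 1));
  // ... enqueue work on device 1 here ...
}
```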
[OpenReg Device Management]: https://github.com/pytorch/pytorch/blob/main/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp "OpenReg Device Management"
@@ -42,6 +42,7 @@ Next, we will delve into each chapter of this guide. Each chapter focuses on a k
   :glob:
   :maxdepth: 1

   device
   hooks
   autoload
   operators

@ -4,17 +4,12 @@
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
void orCheckFail(
|
||||
const char* func,
|
||||
const char* file,
|
||||
uint32_t line,
|
||||
const char* msg = "");
|
||||
void orCheckFail(const char* func, const char* file, uint32_t line, const char* msg = "");
|
||||
|
||||
#define OPENREG_CHECK(EXPR, ...) \
|
||||
do { \
|
||||
const orError_t __err = EXPR; \
|
||||
if (__err != orSuccess) { \
|
||||
orCheckFail( \
|
||||
__func__, __FILE__, static_cast<uint32_t>(__LINE__), ##__VA_ARGS__); \
|
||||
} \
|
||||
#define OPENREG_CHECK(EXPR, ...) \
|
||||
do { \
|
||||
const orError_t __err = EXPR; \
|
||||
if (C10_UNLIKELY(__err != orSuccess)) { \
|
||||
orCheckFail(__func__, __FILE__, static_cast<uint32_t>(__LINE__), ##__VA_ARGS__); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
#include <c10/util/Exception.h>
|
||||
#include <include/openreg.h>
|
||||
|
||||
#include "OpenRegException.h"
|
||||
@ -9,21 +10,22 @@ orError_t GetDeviceCount(int* dev_count) {
|
||||
return orGetDeviceCount(dev_count);
|
||||
}
|
||||
|
||||
orError_t GetDevice(c10::DeviceIndex* device) {
|
||||
orError_t GetDevice(DeviceIndex* device) {
|
||||
int tmp_device = -1;
|
||||
auto err = orGetDevice(&tmp_device);
|
||||
*device = static_cast<c10::DeviceIndex>(tmp_device);
|
||||
*device = static_cast<DeviceIndex>(tmp_device);
|
||||
return err;
|
||||
}
|
||||
|
||||
orError_t SetDevice(c10::DeviceIndex device) {
|
||||
// LITERALINCLUDE START: OPENREG SetDevice FUNCTION
|
||||
orError_t SetDevice(DeviceIndex device) {
|
||||
int cur_device = -1;
|
||||
orGetDevice(&cur_device);
|
||||
OPENREG_CHECK(orGetDevice(&cur_device));
|
||||
if (device == cur_device) {
|
||||
return orSuccess;
|
||||
}
|
||||
return orSetDevice(device);
|
||||
}
|
||||
// LITERALINCLUDE END: OPENREG SetDevice FUNCTION
|
||||
|
||||
int device_count_impl() {
|
||||
int count = 0;
|
||||
@ -31,34 +33,37 @@ int device_count_impl() {
|
||||
return count;
|
||||
}
|
||||
|
||||
OPENREG_EXPORT c10::DeviceIndex device_count() noexcept {
|
||||
OPENREG_EXPORT DeviceIndex device_count() noexcept {
|
||||
// initialize number of devices only once
|
||||
static int count = []() {
|
||||
try {
|
||||
auto result = device_count_impl();
|
||||
TORCH_CHECK(
|
||||
result <= std::numeric_limits<c10::DeviceIndex>::max(),
|
||||
result <= std::numeric_limits<DeviceIndex>::max(),
|
||||
"Too many devices, DeviceIndex overflowed");
|
||||
return result;
|
||||
} catch (const c10::Error& ex) {
|
||||
} catch (const Error& ex) {
|
||||
// We don't want to fail, but still log the warning
|
||||
// msg() returns the message without the stack trace
|
||||
TORCH_WARN("Device initialization: ", ex.msg());
|
||||
return 0;
|
||||
}
|
||||
}();
|
||||
return static_cast<c10::DeviceIndex>(count);
|
||||
return static_cast<DeviceIndex>(count);
|
||||
}
|
||||
|
||||
OPENREG_EXPORT c10::DeviceIndex current_device() {
|
||||
c10::DeviceIndex cur_device = -1;
|
||||
GetDevice(&cur_device);
|
||||
OPENREG_EXPORT DeviceIndex current_device() {
|
||||
DeviceIndex cur_device = -1;
|
||||
OPENREG_CHECK(GetDevice(&cur_device));
|
||||
return cur_device;
|
||||
}
|
||||
|
||||
OPENREG_EXPORT void set_device(c10::DeviceIndex device) {
|
||||
SetDevice(device);
|
||||
// LITERALINCLUDE START: OPENREG set_device FUNCTION
|
||||
OPENREG_EXPORT void set_device(DeviceIndex device) {
|
||||
check_device_index(device);
|
||||
OPENREG_CHECK(SetDevice(device));
|
||||
}
|
||||
// LITERALINCLUDE END: OPENREG set_device FUNCTION
|
||||
|
||||
OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) {
|
||||
int current_device = -1;
|
||||
@ -71,4 +76,8 @@ OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) {
|
||||
return current_device;
|
||||
}
|
||||
|
||||
OPENREG_EXPORT DeviceIndex maybe_exchange_device(DeviceIndex to_device) {
|
||||
check_device_index(to_device);
|
||||
return ExchangeDevice(to_device);
|
||||
}
|
||||
} // namespace c10::openreg
|
||||
|
||||
@@ -9,10 +9,20 @@

namespace c10::openreg {

OPENREG_EXPORT c10::DeviceIndex device_count() noexcept;
OPENREG_EXPORT c10::DeviceIndex current_device();
OPENREG_EXPORT void set_device(c10::DeviceIndex device);
OPENREG_EXPORT DeviceIndex device_count() noexcept;
OPENREG_EXPORT DeviceIndex current_device();
OPENREG_EXPORT void set_device(DeviceIndex device);
OPENREG_EXPORT DeviceIndex maybe_exchange_device(DeviceIndex to_device);

OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device);

static inline void check_device_index(int64_t device) {
  TORCH_CHECK(device >= 0 && device < c10::openreg::device_count(),
      "The device index is out of range. It must be in [0, ",
      static_cast<int>(c10::openreg::device_count()),
      "), but got ",
      static_cast<int>(device),
      ".");
}

} // namespace c10::openreg

@@ -2,6 +2,8 @@

namespace c10::openreg {

// LITERALINCLUDE START: OPENREG GUARD REGISTRATION
C10_REGISTER_GUARD_IMPL(PrivateUse1, OpenRegGuardImpl);
// LITERALINCLUDE END: OPENREG GUARD REGISTRATION

} // namespace c10::openreg

@ -11,6 +11,7 @@
|
||||
|
||||
namespace c10::openreg {
|
||||
|
||||
// LITERALINCLUDE START: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
|
||||
struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
||||
static constexpr DeviceType static_type = c10::DeviceType::PrivateUse1;
|
||||
|
||||
@ -58,6 +59,7 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
||||
|
||||
set_device(d.index());
|
||||
}
|
||||
// LITERALINCLUDE END: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
|
||||
|
||||
/**
|
||||
* Set the current device to c10::Device, without checking for errors
|
||||
|
||||
@ -27,6 +27,10 @@ class TestDevice(TestCase):
|
||||
self.assertEqual(torch.accelerator.current_device_index(), 1)
|
||||
self.assertEqual(torch.accelerator.current_device_index(), device)
|
||||
|
||||
def test_invalid_device_index(self):
|
||||
with self.assertRaisesRegex(RuntimeError, "The device index is out of range"):
|
||||
torch.accelerator.set_device_index(2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -34,18 +34,21 @@ static PyObject* _getDefaultGenerator(PyObject* self, PyObject* arg) {
|
||||
}
|
||||
// LITERALINCLUDE END: OPENREG GET DEFAULT GENERATOR
|
||||
|
||||
// LITERALINCLUDE START: MODULE SET DEVICE HELPER
|
||||
|
||||
PyObject* _setDevice(PyObject* self, PyObject* arg) {
|
||||
HANDLE_TH_ERRORS
|
||||
TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to setDevice");
|
||||
auto device = THPUtils_unpackLong(arg);
|
||||
|
||||
auto device = THPUtils_unpackDeviceIndex(arg);
|
||||
torch::utils::device_lazy_init(at::kPrivateUse1);
|
||||
c10::openreg::set_device(static_cast<c10::DeviceIndex>(device));
|
||||
c10::openreg::set_device(device);
|
||||
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
// LITERALINCLUDE END: MODULE SET DEVICE HELPER
|
||||
|
||||
PyObject* _exchangeDevice(PyObject* self, PyObject* arg) {
|
||||
HANDLE_TH_ERRORS
|
||||
TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to exchangeDevice");
|
||||
|
||||
@@ -41,8 +41,13 @@ def current_device():
    return torch_openreg._C._get_device()


# LITERALINCLUDE START: PYTHON SET DEVICE FUNCTION
def set_device(device) -> None:
    return torch_openreg._C._set_device(device)
    if device >= 0:
        torch_openreg._C._set_device(device)


# LITERALINCLUDE END: PYTHON SET DEVICE FUNCTION


def init():

@ -952,7 +952,9 @@ User code traceback:
|
||||
self.assertExpectedInline(
|
||||
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
|
||||
"""\
|
||||
Graph break: skip: from user code at:
|
||||
Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.
|
||||
torch.compile will skip tracing the frame fn (test_error_messages.py line N) and fall back to eager.
|
||||
The graph break occurred in the following user code:
|
||||
File "test_error_messages.py", line N, in fn
|
||||
assert x is None
|
||||
""",
|
||||
@ -1078,6 +1080,88 @@ from user code:
|
||||
""",
|
||||
)
|
||||
|
||||
@torch._dynamo.config.patch(verbose=True)
|
||||
@make_logging_test(graph_breaks=True)
|
||||
def test_skipped_frame_with_verbose_traceback(self, records):
|
||||
def fn(x):
|
||||
with GenericCtxMgr():
|
||||
torch._dynamo.graph_break()
|
||||
return x + 1
|
||||
|
||||
torch.compile(fn, backend="eager")(torch.randn(3))
|
||||
self.assertEqual(len(records), 1)
|
||||
self.assertExpectedInline(
|
||||
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
|
||||
"""\
|
||||
Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.
|
||||
torch.compile will skip tracing the frame fn (test_error_messages.py line N) and fall back to eager.
|
||||
The graph break occurred in the following user code:
|
||||
File "test_error_messages.py", line N, in fn
|
||||
torch._dynamo.graph_break()
|
||||
""",
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
munge_exc(records[0].exc_info[1], suppress_suffix=True, skip=0),
|
||||
"""\
|
||||
Graph break under GenericContextWrappingVariable
|
||||
Explanation: Attempted to graph break in an active context manager(s) that doesn't support graph breaking.
|
||||
Hint: Move the offending context manager(s) to outside the compiled region.
|
||||
Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.
|
||||
|
||||
Developer debug context: Active generic context managers: [GenericContextWrappingVariable(GenericCtxMgr)]
|
||||
|
||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0066.html
|
||||
|
||||
from user code:
|
||||
File "test_error_messages.py", line N, in fn
|
||||
torch._dynamo.graph_break()
|
||||
""",
|
||||
)
|
||||
|
||||
@make_logging_test(graph_breaks=True)
|
||||
def test_skip_frame_in_loop_message(self, records):
|
||||
def fn(x):
|
||||
for i in range(2):
|
||||
with GenericCtxMgr():
|
||||
if x.sum() > 0:
|
||||
x = x + 1
|
||||
return x
|
||||
|
||||
torch.compile(fn, backend="eager")(torch.randn(3))
|
||||
self.assertEqual(len(records), 1)
|
||||
self.assertExpectedInline(
|
||||
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
|
||||
"""\
|
||||
Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.
|
||||
torch.compile will skip tracing the frame fn (test_error_messages.py line N) and fall back to eager.
|
||||
The graph break occurred in the following user code:
|
||||
File "test_error_messages.py", line N, in fn
|
||||
if x.sum() > 0:
|
||||
""",
|
||||
)
|
||||
|
||||
@make_logging_test(dynamo=logging.DEBUG)
|
||||
def test_skip_frame_empty_function_message(self, records):
|
||||
def empty_fn(x):
|
||||
pass
|
||||
|
||||
torch.compile(empty_fn, backend="eager")(torch.randn(3))
|
||||
skip_messages = [
|
||||
r
|
||||
for r in records
|
||||
if "intentionally decided to skip the frame" in r.getMessage()
|
||||
]
|
||||
self.assertEqual(len(skip_messages), 1)
|
||||
msg = munge_exc(skip_messages[0].getMessage(), suppress_suffix=True, skip=0)
|
||||
msg = re.sub(r" (\d+)$", r" N", msg, flags=re.MULTILINE)
|
||||
|
||||
self.assertExpectedInline(
|
||||
msg,
|
||||
"""\
|
||||
Skipping frame torch.compile intentionally decided to skip the frame empty_fn (test_error_messages.py line N) and fall back to eager.
|
||||
Reason: no content in function call empty_fn test_error_messages.py N""",
|
||||
)
|
||||
|
||||
@make_logging_test(graph_breaks=True)
|
||||
def test_nested_compile_user_frames(self, records):
|
||||
def fn(x):
|
||||
@ -1624,6 +1708,110 @@ from user code:
|
||||
)
|
||||
|
||||
|
||||
class NestedGraphBreakLoggingTests(
|
||||
LoggingTestCase, torch._dynamo.test_case.TestCaseWithNestedGraphBreaks
|
||||
):
|
||||
@make_logging_test(graph_breaks=True)
|
||||
def test_skipped_frame_with_verbose_traceback_nested(self, records):
|
||||
global f1, f2, f3
|
||||
|
||||
class GenericCtxMgr:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
pass
|
||||
|
||||
def f1(x):
|
||||
with GenericCtxMgr():
|
||||
torch._dynamo.graph_break()
|
||||
return x + 1
|
||||
|
||||
def f2(x):
|
||||
return f1(x + 2)
|
||||
|
||||
def f3(x):
|
||||
return f2(x + 3)
|
||||
|
||||
torch.compile(f3, backend="eager")(torch.randn(3))
|
||||
self.assertEqual(len(records), 1)
|
||||
self.assertExpectedInline(
|
||||
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
|
||||
"""\
|
||||
Graph break in user code at test_error_messages.py:N
|
||||
Graph Break Reason: Encountered graph break that we cannot resume from. Compiling up to the previous resumable state, then skipping the rest of the function. Graph break encountered:
|
||||
Graph break under GenericContextWrappingVariable
|
||||
Explanation: Attempted to graph break in an active context manager(s) that doesn't support graph breaking.
|
||||
Hint: Move the offending context manager(s) to outside the compiled region.
|
||||
Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.
|
||||
|
||||
Developer debug context: Active generic context managers: [GenericContextWrappingVariable(GenericCtxMgr)]
|
||||
|
||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0066.html
|
||||
User code traceback:
|
||||
File "test_error_messages.py", line N, in test_skipped_frame_with_verbose_traceback_nested
|
||||
torch.compile(f3, backend="eager")(torch.randn(3))
|
||||
File "test_error_messages.py", line N, in f3
|
||||
return f2(x + 3)
|
||||
File "test_error_messages.py", line N, in f2
|
||||
return f1(x + 2)
|
||||
File "test_error_messages.py", line N, in f1
|
||||
torch._dynamo.graph_break()
|
||||
""",
|
||||
)
|
||||
|
||||
@make_logging_test(graph_breaks=True)
|
||||
def test_skip_frame_in_loop_message_nested(self, records):
|
||||
global f1, f2, f3
|
||||
|
||||
class GenericCtxMgr:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
pass
|
||||
|
||||
def f1(x):
|
||||
for i in range(2):
|
||||
with GenericCtxMgr():
|
||||
if x.sum() > 0:
|
||||
x = x + 1
|
||||
return x
|
||||
|
||||
def f2(x):
|
||||
return f1(x + 4)
|
||||
|
||||
def f3(x):
|
||||
return f2(x + 5)
|
||||
|
||||
result = torch.compile(f3, backend="eager")(torch.randn(3)) # noqa: F841
|
||||
self.assertEqual(len(records), 1)
|
||||
self.assertExpectedInline(
|
||||
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
|
||||
"""\
|
||||
Graph break in user code at test_error_messages.py:N
|
||||
Graph Break Reason: Encountered graph break that we cannot resume from. Compiling up to the previous resumable state, then skipping the rest of the function. Graph break encountered:
|
||||
Data-dependent branching
|
||||
Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
|
||||
Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
|
||||
Hint: Use `torch.cond` to express dynamic control flow.
|
||||
|
||||
Developer debug context: attempted to jump with TensorVariable()
|
||||
|
||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0170.html
|
||||
User code traceback:
|
||||
File "test_error_messages.py", line N, in test_skip_frame_in_loop_message_nested
|
||||
result = torch.compile(f3, backend="eager")(torch.randn(3)) # noqa: F841
|
||||
File "test_error_messages.py", line N, in f3
|
||||
return f2(x + 5)
|
||||
File "test_error_messages.py", line N, in f2
|
||||
return f1(x + 4)
|
||||
File "test_error_messages.py", line N, in f1
|
||||
if x.sum() > 0:
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
|
||||
@ -14036,6 +14036,44 @@ class DynamoOpPromotionTests(torch._dynamo.test_case.TestCase):
|
||||
except Exception as e:
|
||||
self.fail(f"torch.compile failed with error: {e}")
|
||||
|
||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||
def test_tensorify_track_item_symint(self):
|
||||
def _random_resize(image: torch.Tensor):
|
||||
image_metanet = image
|
||||
default_patch_size = 14
|
||||
rand_cnn_resolution = (224, 256)
|
||||
min_nump = rand_cnn_resolution[0] // default_patch_size
|
||||
max_nump = rand_cnn_resolution[1] // default_patch_size
|
||||
new_nump = torch.randint(min_nump, max_nump + 1, (1,)).item()
|
||||
torch._check(new_nump > 0)
|
||||
torch._check(new_nump * default_patch_size > 1)
|
||||
|
||||
image_metanet = F.interpolate(
|
||||
image_metanet,
|
||||
size=(new_nump * default_patch_size, new_nump * default_patch_size),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
img_h_new, img_w_new = image_metanet.shape[2:]
|
||||
|
||||
return (img_h_new, img_w_new), image_metanet
|
||||
|
||||
_random_resize_compiled = torch.compile(fullgraph=True)(_random_resize)
|
||||
|
||||
# Test the function
|
||||
input_tensor = torch.rand(1, 3, 224, 224)
|
||||
(h, w), output = _random_resize_compiled(input_tensor)
|
||||
|
||||
# Verify output properties
|
||||
self.assertEqual(output.shape[0], 1)
|
||||
self.assertEqual(output.shape[1], 3)
|
||||
self.assertEqual(output.shape[2], h)
|
||||
self.assertEqual(output.shape[3], w)
|
||||
self.assertTrue(h % 14 == 0)
|
||||
self.assertTrue(w % 14 == 0)
|
||||
self.assertTrue(224 <= h <= 256)
|
||||
self.assertTrue(224 <= w <= 256)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
@ -10895,6 +10895,34 @@ get_out().sum().backward()
|
||||
|
||||
self.assertTrue(gradcheck(func, x, fast_mode=True))
|
||||
|
||||
def test_grad_thread_safety(self):
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
NUM_ITERS = 10
|
||||
NUM_THREADS = 4
|
||||
|
||||
# Concurrent calls to tensor.untyped_storage()
|
||||
def access_grad(tensor, barrier):
|
||||
barrier.wait()
|
||||
return weakref.ref(tensor.grad)
|
||||
|
||||
for i in range(NUM_ITERS):
|
||||
tensor = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
|
||||
(tensor**2).sum().backward()
|
||||
|
||||
barrier = threading.Barrier(NUM_THREADS)
|
||||
with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
|
||||
futures = [
|
||||
executor.submit(access_grad, tensor, barrier)
|
||||
for _ in range(NUM_THREADS)
|
||||
]
|
||||
|
||||
# Check that all the grad tensors returned were the same
|
||||
for future in futures:
|
||||
self.assertEqual(future.result()(), tensor.grad)
|
||||
self.assertIsNotNone(tensor.grad)
|
||||
|
||||
|
||||
def index_perm_variable(shape, max_indices):
|
||||
if not isinstance(shape, tuple):
|
||||
|
||||
@ -259,7 +259,8 @@ class TestTorchDeviceType(TestCase):
|
||||
def test_storage_use_count(self, device):
|
||||
a = torch.randn(10, device=device)
|
||||
prev_cf = torch._C._storage_Use_Count(a.untyped_storage()._cdata)
|
||||
self.assertEqual(prev_cf, 1)
|
||||
# Two references: 'a' and the wrapper returned by untyped_storage()
|
||||
self.assertEqual(prev_cf, 2)
|
||||
b = a.view(2, 5)
|
||||
self.assertEqual(torch._C._storage_Use_Count(b.untyped_storage()._cdata), prev_cf + 1)
|
||||
|
||||
@ -9324,7 +9325,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
|
||||
member_var = object()
|
||||
|
||||
err_msg = "Creating a Tensor subclass from a class that does not inherit from Tensor"
|
||||
with self.assertRaisesRegex(RuntimeError, err_msg):
|
||||
with self.assertRaisesRegex(TypeError, err_msg):
|
||||
s0 = t0.as_subclass(BadSubTensor)
|
||||
|
||||
# FIXME: Port to a test suite that better fits slicing
|
||||
@ -10324,20 +10325,21 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],

@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
def test_tensor_dead_weak_ref(self):
x = torch.empty(2)
x = torch.ones(2)
w_x = weakref.ref(x)
y = torch.empty(2)
y = torch.ones(2)
y.grad = x
del x

x = w_x()
# Ideally, x would keep the tensor live. But CPython doesn't
# provide enough hooks to do this. So it will go dead and x
# will transmute into an undefined tensor. Not great, but the
# best we can do.
# x should keep the tensor live. This didn't happen in earlier PyTorch
# versions.
del y

self.assertRaises(RuntimeError, lambda: x.sigmoid())
self.assertEqual(2, x.sum())

del x
self.assertIsNone(w_x())

@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
def test_storage_dead_weak_ref(self):
@ -10345,16 +10347,9 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
w_x = weakref.ref(x)
y = torch.tensor(x)
del x

x = w_x()
# Ideally, x would keep the storage live. But CPython doesn't
# provide enough hooks to do this. So it will go dead and x
# will transmute into storage with null StorageImpl. Not great, but the
# best we can do.
self.assertIsNotNone(w_x())
del y

self.assertRaisesRegex(RuntimeError, "Got a null Storage", lambda: x[0])
self.assertRaisesRegex(RuntimeError, "Got a null Storage", lambda: x.float())
self.assertIsNone(w_x())

def test_tensor_resurrected_weak_ref(self):
x = torch.empty(2)
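A minimal sketch of the behavior the updated weak-reference tests expect (assuming this change is applied): a weakref to a tensor that is still referenced from C++, here via y.grad, no longer goes dead.

import weakref
import torch

x = torch.ones(2)
w_x = weakref.ref(x)
y = torch.ones(2)
y.grad = x
del x  # only the C++ side (y.grad) still owns the tensor
assert w_x() is not None  # the PyObject is preserved
assert w_x().sum().item() == 2.0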
@ -10415,6 +10410,31 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],

self.assertTrue(called)

def test_storage_thread_safety(self):
import threading
from concurrent.futures import ThreadPoolExecutor

NUM_ITERS = 10
NUM_THREADS = 4

# Concurrent calls to tensor.untyped_storage()
def access_untyped_storage(tensor, barrier):
barrier.wait()
return weakref.ref(tensor.untyped_storage())

for i in range(NUM_ITERS):
tensor = torch.tensor([1.0, 2.0, 3.0])
barrier = threading.Barrier(NUM_THREADS)
with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
futures = [
executor.submit(access_untyped_storage, tensor, barrier)
for _ in range(NUM_THREADS)
]

# Check that all the storages returned were the same
for future in futures:
self.assertEqual(future.result()(), tensor.untyped_storage())

# FIXME: move to test_linalg
@torch.inference_mode()
def test_bmm_multithreaded(self):

@ -1870,7 +1870,7 @@ class ConvertFrame:
raise

soft_fail = isinstance(e, Unsupported)

code = frame.f_code
# This is a soft failure. In the sense, the code path reaches here
# when we do not support graph breaks on bytecodes like LOAD_ATTR,
# BUILD_SET etc. In such case, we can fallback to eager without
@ -1885,7 +1885,13 @@ class ConvertFrame:
user_stack_formatted = "".join(
traceback.format_list(user_stack)
)
user_stack_trace = f"Graph break: skip: from user code at:\n{user_stack_formatted}"
frame_info = exc.format_frame_info(code)
user_stack_trace = (
"Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.\n"
f"torch.compile will skip tracing the frame {frame_info} and fall back to eager.\n"
"The graph break occurred in the following user code:\n"
f"{user_stack_formatted}"
)
torch._logging.trace_structured(
"artifact",
metadata_fn=lambda: {
@ -1897,6 +1903,7 @@ class ConvertFrame:
graph_break_log.debug(
user_stack_trace,
exc_info=True,
stack_info=config.verbose,
)

if not config.suppress_errors and not soft_fail:

@ -794,6 +794,38 @@ def format_error_msg_verbose(
return msg


def format_frame_info(code: types.CodeType) -> str:
return (
f"{getattr(code, 'co_name', '<unknown>')} "
f"({getattr(code, 'co_filename', '<unknown>')} "
f"line {getattr(code, 'co_firstlineno', 0)})"
)


def format_skip_frame_message(code: Optional[types.CodeType], reason: str) -> str:
if code is not None:
frame_info = format_frame_info(code)
return (
f"torch.compile intentionally decided to skip the frame {frame_info} and fall back to eager.\n"
f"Reason: {reason}"
)
else:
return (
f"torch.compile intentionally decided to skip the frame and fall back to eager.\n"
f"Reason: {reason}"
)


def format_loop_skip_frame_message(code: types.CodeType, frame_summary: str) -> str:
frame_info = format_frame_info(code)
return (
"Skipping frame because there is a graph break in a for/while loop\n"
f"torch.compile intentionally decided to skip the frame {frame_info} and fall back to eager.\n"
f"Reason: Skipping frame because there is a graph break in a for/while loop.\n"
f"{frame_summary}"
)


def format_error_msg(
exc: Exception,
code: types.CodeType,

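The helpers above centralize the skip-frame wording. A hedged usage sketch (format_skip_frame_message lives in torch/_dynamo/exc.py once this patch is applied; the signature comes straight from the hunk, and the demo wrapper below is hypothetical):

def describe_skip(fn, reason="no content in function call"):
    # format_skip_frame_message(code, reason) -> human-readable skip message
    from torch._dynamo import exc
    return exc.format_skip_frame_message(fn.__code__, reason)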
@ -94,6 +94,8 @@ from .exc import (
BackendCompilerFailed,
collapse_resume_frames,
format_graph_break_message,
format_loop_skip_frame_message,
format_skip_frame_message,
get_stack_above_dynamo,
ResumePrologueTracingError,
StepUnsupported,
@ -605,9 +607,9 @@ def generic_jump(
)
# compile a partial subgraph prefix then jump into user code
if self.maybe_has_backedge():
msg = (
"Skipping frame because there is a graph break in a for/while loop\n"
f"{self.frame_summary()}"
msg = format_loop_skip_frame_message(
self.f_code,
"".join(traceback.format_list([self.frame_summary()])),
)
log.info(msg)
raise exc.SkipFrame(msg)
@ -883,9 +885,9 @@ def break_graph_if_unsupported(
)

if self.maybe_has_backedge():
msg = (
"Skipping frame because there is a graph break in a for/while loop\n"
f"{self.frame_summary()}"
msg = format_loop_skip_frame_message(
self.f_code,
"".join(traceback.format_list([self.frame_summary()])),
)
log.info(msg)
raise exc.SkipFrame(msg) from excp
@ -4626,8 +4628,9 @@ class InstructionTranslator(InstructionTranslatorBase):
and not self.error_on_graph_break
and not self.is_tracing_resume_prologue
):
raise exc.SkipFrame("because no content in function call")

raise exc.SkipFrame(
format_skip_frame_message(self.f_code, "no content in function call")
)
self.instruction_pointer = None
_step_logger()(
logging.INFO,

@ -2248,12 +2248,15 @@ def skip_frame_if_in_functorch_mode(val: torch.Tensor) -> None:
try:
val.data_ptr()  # will throw for functorch tensors
except RuntimeError as e:
from .exc import SkipFrame
from .exc import format_skip_frame_message, SkipFrame

# This will be GradTrackingTensor/BatchedTensor/etc
functorch_subclass_name = re.sub(r"\(.*", "", repr(val))
raise SkipFrame(
f"torch.compile cannot be run in context: {functorch_subclass_name}"
format_skip_frame_message(
None,
f"torch.compile cannot be run in context: {functorch_subclass_name}",
)
) from e


@ -42,6 +42,7 @@ from torch._guards import Source
from .. import config, graph_break_hints, polyfills, variables
from ..bytecode_transformation import create_call_function, create_rot_n, is_generator
from ..exc import (
format_skip_frame_message,
get_dynamo_observed_exception,
handle_observed_exception,
InfiniteGeneratorError,
@ -1652,8 +1653,13 @@ class SkipFunctionVariable(VariableTracker):
skip_frame_msg = kwargs.get("msg")
if skip_frame_msg:
skip_frame_msg = skip_frame_msg.as_python_constant()
else:
skip_frame_msg = ""
raise SkipFrame(
f"Skip frame due to `torch._dynamo.skip_frame()`. Message: {skip_frame_msg}"
format_skip_frame_message(
tx.f_code,
f"Skip frame due to `torch._dynamo.skip_frame()`. Message: {skip_frame_msg}",
)
)
elif self.value is torch._dynamo.step_unsupported:
raise StepUnsupported

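For context, a hedged sketch of the user-facing call this hunk handles; torch._dynamo.skip_frame() is the function named in the message string above (its eager behavior and exact signature are assumptions, not confirmed by this diff):

import torch
import torch._dynamo

@torch.compile
def f(x):
    torch._dynamo.skip_frame()  # this frame falls back to eager with the formatted skip message
    return x + 1

f(torch.ones(3))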
@ -536,9 +536,14 @@ class StorageWeakRefWrapper:
if self.extra_ref_check is not None and not self.extra_ref_check():
return False

# if extra_ref_check is not None we expect an additional reference
stor_count = torch._C._storage_Use_Count(self.ref.cdata)
return (stor_count - (self.extra_ref_check is not None)) == 0
if self.extra_ref_check is not None:
# if extra_ref_check is not None we expect two additional references:
# - one from the Python storage object
# - one from the cached Tensor
stor_count -= 2
assert stor_count >= 0
return stor_count == 0

def __repr__(self) -> str:
if self.ref is None or self.ref.expired():
@ -1439,7 +1444,15 @@ class CUDAGraphNode:
self_loc = self_ref()
if self_loc is None:
return False
return self_loc.get_output_refcount(i) == 2
refcount = self_loc.get_output_refcount(i)
# pyrefly: ignore
if self_loc.cached_tensor_outputs[i]._use_count() > 1:
# c10::Tensor may also hold one reference count
assert refcount >= 3
return refcount == 3
else:
assert refcount >= 2
return refcount == 2

check = functools.partial(check_refcount, i=i)


@ -891,10 +891,14 @@ class TorchLogsFormatter(logging.Formatter):
# exception handling - copied from logging.Formatter.format
s = record.message
if record.exc_info:
from torch._dynamo import config

should_format_exc = config.verbose or artifact_name != "graph_breaks"
# Cache the traceback text to avoid converting it multiple times
# (it's constant anyway)
if not record.exc_text:
record.exc_text = self.formatException(record.exc_info)
if should_format_exc:
if not record.exc_text:
record.exc_text = self.formatException(record.exc_info)
if record.exc_text:
if s[-1:] != "\n":
s = s + "\n"

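With this hunk, graph-break log records only render the exception traceback when Dynamo's verbose mode is enabled. A small sketch of how a user would opt back in (config.verbose is an existing torch._dynamo flag; combining it with TORCH_LOGS="graph_breaks" is the usual way to surface this artifact):

import torch._dynamo.config as dynamo_config

dynamo_config.verbose = True  # include full tracebacks for graph_breaks again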
@ -398,36 +398,27 @@ static PyObject* THPModule_swap_tensor_impl(PyObject* _unused, PyObject* args) {

// weak_use_count() adds 1 if use_count is non-zero
TORCH_CHECK(
a->cdata->weak_use_count() == 1,
a->cdata.weak_use_count() == 1,
"Expected no weakrefs to t1's Tensor object but got ",
a->cdata->weak_use_count() - 1);
a->cdata.weak_use_count() - 1);
TORCH_CHECK(
b->cdata->weak_use_count() == 1,
b->cdata.weak_use_count() == 1,
"Expected no weakrefs to t2's Tensor object but got ",
b->cdata->weak_use_count() - 1);
b->cdata.weak_use_count() - 1);

// NB: Creating local copies of *both* Tensors here ensures that they each
// hold a strong reference to their PyObject. This avoids having to fix up
// reference counts when we swap the PyObject slots below.
at::Tensor tmp_a = a->cdata;
at::Tensor tmp_b = b->cdata;

// Swap the Tensor Impl
c10::MaybeOwned<at::Tensor> tmp = a->cdata;
a->cdata = tmp_b;
b->cdata = tmp_a;

// The TensorImpls contain PyObjectSlots that have a reference to the PyObject
// associated with the TensorImpl. Swap this field as well.
std::optional<PyObject*> mb_obj_a =
a->cdata->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/false);
std::optional<PyObject*> mb_obj_b =
b->cdata->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/false);
TORCH_INTERNAL_ASSERT(
mb_obj_a.has_value() && mb_obj_b.has_value(),
"Both tensors should have PyObjects tagged by the current python interpreter");
TORCH_CHECK(mb_obj_a.value() == a_);
TORCH_CHECK(mb_obj_b.value() == b_);

a->cdata = b->cdata;
b->cdata = tmp;

a->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(a_);
b->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(b_);
// Fix up the PyObjects associated with each TensorImpl
a->cdata.unsafeGetTensorImpl()->pyobj_slot()->store_pyobj(a_);
b->cdata.unsafeGetTensorImpl()->pyobj_slot()->store_pyobj(b_);

Py_RETURN_NONE;
END_HANDLE_TH_ERRORS

@ -45,7 +45,9 @@ struct ConcretePyInterpreterVTable final
std::string name() const override;

void incref(PyObject* pyobj) const override;
void decref(PyObject* pyobj, bool has_pyobj_slot) const override;
void decref(PyObject* pyobj) const override;
bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const override;
size_t refcnt(PyObject* pyobj) const override;

// TODO: Need to make this work for StorageImpl too. I imagine I'll want to
// operate upon a PyObjectSlot rather than a TensorImpl
@ -235,53 +237,13 @@ py::object torchDispatchFromTensorImpl(
TorchFunctionName::TorchDispatch));
}

// NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
// Before calling PyInterpreter::decref, we must statically know if the
// pyobj has a PyObjectSlot or not.
// - If it has a PyObjectSlot, we need to be careful about PyObject resurrection
// - If it does not have a PyObjectSlot, we can freely decref
// One alternative to this is using PyObject_IsInstance
// to get at this information. However, we don't want to risk an incorrect
// `__instancecheck__` changing the semantics here.
void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot)
const {
void ConcretePyInterpreterVTable::decref(PyObject* pyobj) const {
// Leak the pyobj if not initialized. This can happen if we are running
// exit handlers that are destructing tensors with residual (owned)
// PyObjects stored in them.
if (!Py_IsInitialized())
return;

pybind11::gil_scoped_acquire gil;
// Two possibilities:
// 1. We are decref-ing an object that has a PyObjectSlot, like a Tensor or
// Storage. Then we must be careful about PyObject resurrection (see
// THPVariable_clear).
// 2. We are decref-ing some other Python object. We don't do
// PyObject resurrection on non-Tensors, so we just carry on as usual
if (has_pyobj_slot && Py_REFCNT(pyobj) > 1) {
if (THPVariable_Check(pyobj)) {
// It's still alive! This can happen if a weak ref resurrected
// the PyObject without flipping ownership. At this point it is
// too late to rescue the object, so just stub out the PyObject
// so that it fails on subsequent uses. Don't raise an error here;
// you're probably in a destructor.
TORCH_WARN(
"Deallocating Tensor that still has live PyObject references. "
"This probably happened because you took out a weak reference to "
"Tensor and didn't call _fix_weakref() after dereferencing it. "
"Subsequent accesses to this tensor via the PyObject will now fail.");
(reinterpret_cast<THPVariable*>(pyobj))->cdata =
c10::MaybeOwned<torch::autograd::Variable>();
} else if (THPStorage_Check(pyobj)) {
TORCH_WARN(
"Deallocating UntypedStorage that still has live PyObject references. "
"This probably happened because you took out a weak reference to "
"UntypedStorage and didn't call _fix_weakref() after dereferencing it. "
"Subsequent accesses to this storage via the PyObject will now fail.");
(reinterpret_cast<THPStorage*>(pyobj))->cdata =
c10::MaybeOwned<c10::Storage>();
}
}
Py_DECREF(pyobj);
}

@ -292,6 +254,25 @@ void ConcretePyInterpreterVTable::incref(PyObject* pyobj) const {
Py_INCREF(pyobj);
}

bool ConcretePyInterpreterVTable::try_incref(
const c10::impl::PyObjectSlot& pyobj_slot) const {
if (!Py_IsInitialized())
return false;
pybind11::gil_scoped_acquire gil;
PyObject* pyobj = pyobj_slot.load_pyobj();
if (!pyobj) {
return false;
}
return PyUnstable_TryIncRef(pyobj);
}

size_t ConcretePyInterpreterVTable::refcnt(PyObject* pyobj) const {
if (!Py_IsInitialized() || pyobj == nullptr)
return 0;
pybind11::gil_scoped_acquire gil;
return Py_REFCNT(pyobj);
}

bool isPythonTensor(const at::Tensor& tensor) {
return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python);
}
@ -620,11 +601,7 @@ static void set_tensor_attr_with_capsule(
const c10::TensorImpl* tensor,
py::capsule& capsule,
const char* attr_name) {
std::optional<PyObject*> mb_obj = tensor->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/false);
TORCH_CHECK(
mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
auto obj = mb_obj.value();
PyObject* obj = tensor->pyobj_slot()->load_pyobj();
py::handle(obj).attr(attr_name) = capsule;
}

@ -648,11 +625,7 @@ static c10::ArrayRef<T> get_set_cached_attr(
const c10::TensorImpl* tensor,
const char* base_attr_name,
const py::object& obj) {
std::optional<PyObject*> mb_obj =
tensor->pyobj_slot()->check_pyobj(getPyInterpreter());
TORCH_CHECK(
mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
auto tensor_obj = mb_obj.value();
PyObject* tensor_obj = tensor->pyobj_slot()->load_pyobj();
auto buffer_len_attr_name = std::string(base_attr_name) + std::string("_len");

bool is_buffer_allocated = false;

@ -23,6 +23,8 @@
#include <c10/util/intrusive_ptr.h>
#include <fmt/format.h>

using torch::utils::PyObjectPreservation;

template <>
void THPPointer<c10::StorageImpl>::free() {
if (ptr) {
@ -32,238 +34,72 @@ void THPPointer<c10::StorageImpl>::free() {

PyTypeObject* THPStorageClass = nullptr;

PyObject* THPStorage_NewWithStorage(
PyTypeObject* type,
c10::Storage _storage,
bool allow_preexisting_pyobj) {
TORCH_CHECK(
PyType_IsSubtype(type, &THPStorageType),
"Creating a Storage subclass from a class that does not inherit from ",
"Storage is not possible. Make sure your class inherits from Storage.");

auto maybe_pyobj = _storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/false);
if (maybe_pyobj.has_value() && maybe_pyobj.value()) {
TORCH_CHECK(
allow_preexisting_pyobj,
"Creating a new Storage subclass ",
type->tp_name,
" but the raw Storage object is already associated to a python object ",
"of type ",
maybe_pyobj.value()->ob_type->tp_name);
PyObject* obj = *maybe_pyobj;
PyTypeObject* obj_type = Py_TYPE(obj);
TORCH_CHECK(
obj_type == type || PyType_IsSubtype(obj_type, type),
"Creating a new Storage subclass ",
type->tp_name,
" but the raw Storage object is already associated to a python object ",
"of type ",
maybe_pyobj.value()->ob_type->tp_name,
" which is not a subclass of the "
"requested type");
return THPStorage_Wrap(std::move(_storage));
}

// Create a new Python Storage object, but don't set the pyobj slot on the
// c10::Storage object.
static PyObject* THPStorage_New(PyTypeObject* type, c10::Storage _storage) {
PyObject* obj = type->tp_alloc(type, 0);
TORCH_CHECK(obj, "Failed to allocate a ", type->tp_name, " object");

auto s = reinterpret_cast<THPStorage*>(obj);

new (&s->cdata) c10::MaybeOwned<c10::Storage>();

s->cdata = c10::MaybeOwned<c10::Storage>::owned(std::move(_storage));

if (!c10::impl::HermeticPyObjectTLS::get_state()) {
s->is_hermetic = false;
const auto& storage = THPStorage_Unpack(s);
storage.unsafeGetStorageImpl()->pyobj_slot()->init_pyobj(obj);
} else {
s->is_hermetic = true;
}
// Ensure that PyUnstable_TryIncref calls don't fail spuriously in
// free-threaded Python.
PyUnstable_EnableTryIncRef(obj);

auto s = (THPStorage*)obj;
new (&s->cdata) c10::Storage(std::move(_storage));
return obj;
}

// Wraps the c10::Storage with a storage PyObject
// Create a new Python Storage object for a new c10::Storage, and set the
// pyobj slot. The c10::Storage must not already have a pyobj set.
PyObject* THPStorage_NewWithStorage(PyTypeObject* type, c10::Storage _storage) {
TORCH_CHECK(
type == THPStorageClass || PyType_IsSubtype(type, &THPStorageType),
"Creating a Storage subclass from a class that does not inherit from ",
"Storage is not possible. Make sure your class inherits from Storage.");
TORCH_INTERNAL_ASSERT(_storage.use_count() == 1);

c10::StorageImpl* storage_impl = _storage.unsafeGetStorageImpl();
PyObject* obj = THPStorage_New(type, std::move(_storage));
PyObjectPreservation::init_fresh_nonatomic(
storage_impl, storage_impl->pyobj_slot(), obj);
return obj;
}

// Returns a PyObject wrapper for the c10::Storage object. The existing
// wrapper is returned if it already exists.
PyObject* THPStorage_Wrap(c10::Storage storage) {
c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl();
if (c10::impl::HermeticPyObjectTLS::get_state()) {
return THPStorage_NewWithStorage(THPStorageClass, std::move(storage));
return THPStorage_New(THPStorageClass, std::move(storage));
}

c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl();
c10::impl::PyObjectSlot* pyobj_slot = storage_impl->pyobj_slot();

std::optional<PyObject*> maybe_pyobj = pyobj_slot->check_pyobj(
/*ignore_hermetic_tls=*/false);
if (maybe_pyobj.has_value()) {
auto obj = *maybe_pyobj;
if (obj) {
TORCH_CHECK(
THPStorage_Check(obj),
"Expected a storage type, but got ",
Py_TYPE(obj)->tp_name);

if (pyobj_slot->owns_pyobj()) {
pyobj_slot->set_owns_pyobj(false);
reinterpret_cast<THPStorage*>(obj)->cdata =
c10::MaybeOwned<c10::Storage>::owned(std::move(storage));
return obj;
} else {
Py_INCREF(obj);
return obj;
}
}
PyObject* obj = pyobj_slot->load_pyobj();
if (obj) {
return Py_NewRef(obj);
}
return THPStorage_NewWithStorage(THPStorageClass, std::move(storage));

obj = THPStorage_New(THPStorageClass, std::move(storage));
PyObject* wrapper =
PyObjectPreservation::init_once(storage_impl, pyobj_slot, obj);
if (wrapper != obj) {
// Another thread beat us to it
Py_DECREF(obj);
return Py_NewRef(wrapper);
}
return obj;
}

static bool THPStorage_isPreservable(THPStorage* self) {
if (self->cdata.unsafeIsBorrowed()) {
return false;
}
auto const& storage = THPStorage_Unpack(self);

if (self->is_hermetic) {
return false;
}

if (storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/true) != reinterpret_cast<PyObject*>(self)) {
return false;
}
if (storage.use_count() <= 1) {
return false;
}
return true;
}

static bool THPStorage_tryPreserve(THPStorage* self) {
if (!THPStorage_isPreservable(self)) {
return false;
}

const auto& storage = THPStorage_Unpack(self);
c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl();

auto maybe_pyobj = storage_impl->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/true);
// NOTE: It is possible to just set the PyObjectSlot here, but the point is
// that we should have already set PyObjectSlot when the storage PyObject
// was created.
TORCH_INTERNAL_ASSERT(
maybe_pyobj.has_value(),
"Trying to preserve a Python storage whose PyObjectSlot does not have a PyObject");

PyObject* pyobj = *maybe_pyobj;

TORCH_CHECK(
THPStorage_Check(pyobj),
"Expected a storage type, but got ",
Py_TYPE(pyobj)->tp_name);

TORCH_INTERNAL_ASSERT(
(void*)pyobj == (void*)self,
"Python storage and the PyObject in the internal PyObjectSlot are not at the same address");

TORCH_INTERNAL_ASSERT(!storage_impl->pyobj_slot()->owns_pyobj());

storage_impl->pyobj_slot()->set_owns_pyobj(true);
// When resurrecting, we MUST use _Py_NewReference and not Py_INCREF to
// ensure the PyObject is in a valid state
_Py_NewReference(reinterpret_cast<PyObject*>(self));

self->cdata = c10::MaybeOwned<c10::Storage>::borrowed(storage);
return true;
}

static void THPStorage_subclass_dealloc(PyObject* self) {
static void THPStorage_dealloc(PyObject* self) {
THPStorage* _self = reinterpret_cast<THPStorage*>(self);

if (THPStorage_tryPreserve(_self)) {
return;
auto pyobj_slot = _self->cdata.unsafeGetStorageImpl()->pyobj_slot();
if (pyobj_slot->load_pyobj() == self) {
TORCH_INTERNAL_ASSERT(_self->cdata.use_count() == 1);
pyobj_slot->clear();
}

// Some subclass of StorageBase could be GC-tracked objects even
// though the base class is not
auto* type = Py_TYPE(self);
if (PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC) != 0) {
PyObject_GC_UnTrack(self);
}

bool has_finalizer = type->tp_finalize || type->tp_del;

if (type->tp_finalize) {
PyObject_GC_Track(self);
if (PyObject_CallFinalizerFromDealloc(self) < 0) {
// The finalizer has resurrected the PyObject and there is a new Python
// reference to it, so we can just stop deallocating. Read about
// resurrection from `__del__` here:
// https://docs.python.org/3/reference/datamodel.html#object.__del__
return;
}
PyObject_GC_UnTrack(self);
}

// base test is unnecessary as THPStorage does not set this
if (type->tp_weaklistoffset) {
PyObject_ClearWeakRefs(self);
}

if (type->tp_del) {
PyObject_GC_Track(self);
type->tp_del(self);
if (Py_REFCNT(self) > 0) {
// Resurrected (see above comment about resurrection from `__del__`)
return;
}
PyObject_GC_UnTrack(self);
}

if (has_finalizer) {
/* New weakrefs could be created during the finalizer call.
If this occurs, clear them out without calling their
finalizers since they might rely on part of the object
being finalized that has already been destroyed. */
if (type->tp_weaklistoffset) {
/* Modeled after GET_WEAKREFS_LISTPTR() */
PyWeakReference** list = reinterpret_cast<PyWeakReference**>(
PyObject_GET_WEAKREFS_LISTPTR(self));
while (*list)
_PyWeakref_ClearRef(*list);
}
}

// Clear slots
{
PyTypeObject* base = type;
while (base != &THPStorageType) {
if (Py_SIZE(base)) {
clear_slots(base, self);
}
base = base->tp_base;
TORCH_INTERNAL_ASSERT(base);
}
}

// Clear __dict__
if (C10_LIKELY(type->tp_dictoffset)) {
PyObject** dictptr = _PyObject_GetDictPtr(self);
if (dictptr != nullptr) {
PyObject* dict = *dictptr;
if (dict != nullptr) {
Py_DECREF(dict);
*dictptr = nullptr;
}
}
}

TORCH_INTERNAL_ASSERT(Py_TYPE(self) == type);

_self->cdata.~MaybeOwned<c10::Storage>();
_self->cdata.~Storage();
Py_TYPE(_self)->tp_free(self);

TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
Py_DECREF(type);
}

static PyObject* THPStorage_pynew(
@ -553,64 +389,13 @@ static PyMappingMethods THPStorage_mappingmethods = {
reinterpret_cast<binaryfunc>(THPStorage_get),
reinterpret_cast<objobjargproc>(THPStorage_set)};

struct THPStorageMeta {
PyHeapTypeObject base;
};

static int THPStorageMetaType_init(
PyObject* cls,
PyObject* args,
PyObject* kwargs);

static PyTypeObject THPStorageMetaType = {
PyVarObject_HEAD_INIT(DEFERRED_ADDRESS(&PyType_Type), 0)
"torch._C._StorageMeta", /* tp_name */
sizeof(THPStorageMeta), /* tp_basicsize */
0, /* tp_itemsize */
nullptr, /* tp_dealloc */
0, /* tp_vectorcall_offset */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
nullptr, /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
nullptr, /* tp_hash */
nullptr, /* tp_call */
nullptr, /* tp_str */
nullptr, /* tp_getattro */
nullptr, /* tp_setattro */
nullptr, /* tp_as_buffer */
// NOLINTNEXTLINE(misc-redundant-expression)
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
nullptr, /* tp_doc */
nullptr, /* tp_traverse */
nullptr, /* tp_clear */
nullptr, /* tp_richcompare */
0, /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
nullptr, /* tp_methods */
nullptr, /* tp_members */
nullptr, /* tp_getset */
DEFERRED_ADDRESS(&PyType_Type), /* tp_base */
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */
nullptr, /* tp_descr_set */
0, /* tp_dictoffset */
THPStorageMetaType_init, /* tp_init */
nullptr, /* tp_alloc */
nullptr, /* tp_new */
};

// TODO: implement equality
PyTypeObject THPStorageType = {
PyVarObject_HEAD_INIT(&THPStorageMetaType, 0)
PyVarObject_HEAD_INIT(DEFERRED_ADDRESS(&PyType_Type), 0)
"torch._C.StorageBase", /* tp_name */
sizeof(THPStorage), /* tp_basicsize */
0, /* tp_itemsize */
nullptr, /* tp_dealloc */
THPStorage_dealloc, /* tp_dealloc */
0, /* tp_vectorcall_offset */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
@ -649,15 +434,6 @@ PyTypeObject THPStorageType = {
THPStorage_pynew, /* tp_new */
};

int THPStorageMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) {
if (PyType_Type.tp_init(cls, args, kwargs) < 0) {
return -1;
}
(reinterpret_cast<PyTypeObject*>(cls))->tp_dealloc =
static_cast<destructor>(THPStorage_subclass_dealloc);
return 0;
}

static PyObject* THPStorage_device(THPStorage* self, void* unused) {
HANDLE_TH_ERRORS
THPStorage_assertNotNull(self);
@ -692,13 +468,6 @@ bool THPStorage_init(PyObject* module) {
THPUtils_addPyMethodDefs(methods, THPStorage_getMethods());
THPUtils_addPyMethodDefs(methods, THPStorage_getSharingMethods());

THPStorageMetaType.tp_base = &PyType_Type;
if (PyType_Ready(&THPStorageMetaType) < 0)
return false;
Py_INCREF(&THPStorageMetaType);
PyModule_AddObject(
module, "_StorageMeta", reinterpret_cast<PyObject*>(&THPStorageMetaType));

THPStorageType.tp_methods = methods.data();
THPStorageType.tp_getset = THPStorage_properties;
if (PyType_Ready(&THPStorageType) < 0)

@ -11,15 +11,13 @@

struct THPStorage {
PyObject_HEAD
c10::MaybeOwned<c10::Storage> cdata;
bool is_hermetic;
c10::Storage cdata;
};

TORCH_PYTHON_API PyObject* THPStorage_Wrap(c10::Storage storage);
TORCH_PYTHON_API PyObject* THPStorage_NewWithStorage(
PyTypeObject* type,
c10::Storage _storage,
bool allow_preexisting_pyobj = false);
c10::Storage _storage);
TORCH_PYTHON_API extern PyTypeObject* THPStorageClass;

inline bool THPStorage_CheckTypeExact(PyTypeObject* tp) {
@ -49,7 +47,7 @@ TORCH_PYTHON_API void THPStorage_assertNotNull(PyObject* obj);
TORCH_PYTHON_API extern PyTypeObject THPStorageType;

inline const c10::Storage& THPStorage_Unpack(THPStorage* storage) {
return *storage->cdata;
return storage->cdata;
}

inline const c10::Storage& THPStorage_Unpack(PyObject* obj) {

@ -529,9 +529,8 @@ static PyObject* THPStorage__setCdata(PyObject* _self, PyObject* new_cdata) {
THPUtils_typename(new_cdata));
c10::StorageImpl* ptr =
static_cast<c10::StorageImpl*>(PyLong_AsVoidPtr(new_cdata));
self->cdata.~MaybeOwned<c10::Storage>();
self->cdata = c10::MaybeOwned<c10::Storage>::owned(
c10::Storage(c10::intrusive_ptr<c10::StorageImpl>::reclaim_copy(ptr)));
self->cdata =
c10::Storage(c10::intrusive_ptr<c10::StorageImpl>::reclaim_copy(ptr));
Py_INCREF(self);
return reinterpret_cast<PyObject*>(self);
END_HANDLE_TH_ERRORS

@ -180,7 +180,9 @@ struct TORCH_API AccumulateGrad : public Node {
if (!GradMode::is_enabled() && !new_grad.is_sparse() &&
!new_grad.is_sparse_csr() &&
!(variable.is_sparse_csr() && new_grad.layout() == at::kStrided) &&
at::caching::adjusted_use_count(new_grad) <= num_expected_refs &&
impl::is_tensor_stealable(
new_grad,
num_expected_refs + at::caching::is_cached_tensor(new_grad)) &&
(new_grad.is_mkldnn() ||
utils::obeys_layout_contract(new_grad, variable))) {
// See Case 1.1: Stealable dense new_grad
@ -193,7 +195,7 @@ struct TORCH_API AccumulateGrad : public Node {
// SparseTensor should be the only one holding a reference to these.
new_grad._indices().use_count() <= 1 &&
new_grad._values().use_count() <= 1 &&
new_grad.use_count() <= num_expected_refs) {
impl::is_tensor_stealable(new_grad, num_expected_refs)) {
// Case 1.2: Stealable sparse new_grad
// No scenario where we expect this to be true currently
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(

@ -86,8 +86,8 @@ bool can_accumulate_inplace(const Variable& v) {
v.is_non_overlapping_and_dense() &&

// and we hold the last reference
at::caching::adjusted_use_count(v) == 1 && v.has_storage() &&
v.storage().use_count() == 1);
impl::is_tensor_stealable(v, 1 + at::caching::is_cached_tensor(v)) &&
v.has_storage() && v.storage().use_count() == 1);
}
} // anonymous namespace


@ -54,6 +54,7 @@
using namespace at;
using namespace torch;
using namespace torch::autograd;
using torch::utils::PyObjectPreservation;

namespace {
class OperatorArgsKwargsView {
@ -321,20 +322,15 @@ PyObject* THPVariableClass = nullptr;

PyObject* ParameterClass = nullptr;

static PyObject* THPVariable_NewWithVar(
PyTypeObject* type,
const at::TensorBase& _var,
bool allow_preexisting_pyobj = false,
std::optional<bool> has_torch_dispatch_if_known = std::nullopt);

// clang-tidy gets confused by static const
static constexpr const char* VOLATILE_WARNING =
"volatile was removed and now has no effect. Use "
"`with torch.no_grad():` instead.";

static void TORCH_CHECK_TENSOR_SUBTYPE(PyObject* cls);

static bool check_has_torch_dispatch(PyObject* obj) {
PyTypeObject* tp = Py_TYPE(obj);
if (THPVariable_CheckTypeExact(tp)) {
if (THPVariable_CheckExact(obj)) {
return false;
}
py::object attr = PyObject_FastGetAttrString(obj, "__torch_dispatch__");
@ -370,152 +366,86 @@ void activateGPUTrace() {
c10::impl::GPUTrace::set_trace(getPyInterpreter());
}

PyObject* THPVariable_Wrap(const at::TensorBase& var) {
static void check_tensor_subclass(PyObject* obj, PyTypeObject* type) {
TORCH_CHECK(
PyObject_TypeCheck(obj, type),
"Creating a new Tensor subclass ",
type->tp_name,
" but the raw Tensor object is already associated to a python object ",
"of type ",
Py_TYPE(obj)->tp_name,
" which is not a subclass of the requested type");
}

// Generic for const Tensor& or Tensor&&
template <typename T>
static PyObject* THPVariable_WrapWithType(
T&& var,
std::optional<PyTypeObject*> desired_type) {
if (!var.defined()) {
Py_RETURN_NONE;
}

if (c10::impl::HermeticPyObjectTLS::get_state()) {
return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var);
}
c10::TensorImpl* tensor_impl = var.unsafeGetTensorImpl();
c10::impl::PyObjectSlot* pyobj_slot = tensor_impl->pyobj_slot();

std::optional<PyObject*> mb_obj =
var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/false);
if (mb_obj.has_value()) {
auto obj = *mb_obj;
if (obj) {
if (var.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) {
// C++ owns the Python object; this implies there weren't any other
// owning references to the Python object. Since we're making the
// object "live" again on Python side, let's flip back the ownership
// (Python owns C++) as it would now be unsound to deallocate the C++
// object if all C++ references go to zero
var.unsafeGetTensorImpl()->pyobj_slot()->set_owns_pyobj(false);
reinterpret_cast<THPVariable*>(obj)->cdata =
MaybeOwned<Variable>::owned(Variable(var));
// NB: incref is not necessary, because we are "stealing" the previous
// ownership from the Variable to return it here for the wrap
return obj;
}
Py_INCREF(obj);
return obj;
PyObject* obj = pyobj_slot->load_pyobj();
if (obj) {
if (desired_type) {
check_tensor_subclass(obj, *desired_type);
}
// TODO: a better invariant is that if we tagged, we MUST have a valid
// PyObject. That's PyObject preservation
// (https://github.com/pytorch/pytorch/pull/56017). Prior to this PR
// being a thing, the PyObject field will get cleared when all references
// to the Python object are removed.
return Py_NewRef(obj);
}

if (C10_LIKELY(var.device().type() != c10::kXLA)) {
return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var);
PyTypeObject* type = reinterpret_cast<PyTypeObject*>(THPVariableClass);
if (desired_type) {
type = *desired_type;
} else if (C10_UNLIKELY(var.device().type() == c10::kXLA)) {
if (auto clazz = getPythonTensorClass(var.device())) {
type = reinterpret_cast<PyTypeObject*>(clazz);
}
}

if (auto clazz = getPythonTensorClass(var.device())) {
return THPVariable_NewWithVar((PyTypeObject*)clazz, var);
obj = type->tp_alloc(type, 0);
TORCH_CHECK(obj, "Failed to allocate a ", type->tp_name, " object");

// Ensure that PyUnstable_TryIncref calls don't fail spuriously in
// free-threaded Python.
PyUnstable_EnableTryIncRef(obj);

auto v = reinterpret_cast<THPVariable*>(obj);
new (&v->cdata) Tensor(std::forward<T>(var));

if (THPVariable_Unpack(obj).is_uniquely_owned()) {
// We can use a faster non-atomic code path if we have the only reference to
// a fresh Tensor.
PyObjectPreservation::init_fresh_nonatomic(tensor_impl, pyobj_slot, obj);
return obj;
}

return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var);
PyObject* wrapper =
PyObjectPreservation::init_once(tensor_impl, pyobj_slot, obj);
if (wrapper != obj) {
// Another thread beat us to it
Py_DECREF(obj);
if (desired_type) {
check_tensor_subclass(wrapper, *desired_type);
}
return Py_NewRef(wrapper);
}
return obj;
}

static bool isResurrectable(THPVariable* self) {
// We want to divide this check into 2 cases.

// 1. C++ owns PyObject (in this case, self->cdata.unsafeIsBorrowed() is
// true). You might think that in this case, it is impossible for tp_clear to
// be called: surely the C++ reference to the PyObject is keeping it live? And
// you'd be right! In fact, when C++ owns the PyObject, we have an invariant
// that the refcount on the PyObject should be precisely one (because if you
// take out another reference to the PyObject, we're supposed to flip the
// ownership pointer back). In reality, you can violate this invariant
// temporarily with weak references, so we don't test for it in asserts.

// 2. PyObject owns C++ (in this case, self->cdata.unsafeIsBorrowed() is
// false). In this case, tp_clear can get called if the PyObject is referenced
// from a dead cycle, and nowhere else. But if resurrection did not occur,
// then the reference to C++ from the PyObject must be the ONLY reference to
// the C++ object.
if (self->cdata.unsafeIsBorrowed()) {
return false;
}
auto const& tensor = THPVariable_Unpack(self);
if (!tensor.defined() || tensor.use_count() <= 1) {
return false;
}
// Check if this is hermetic. If it is, no resurrection.
if (tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/false) != (PyObject*)self) {
return false;
}
return true;
PyObject* THPVariable_Wrap(at::TensorBase&& var) {
return THPVariable_WrapWithType(std::move(var), std::nullopt);
}

// returns true if successfully rezzed; if so, cancel the
// rest of deallocation
static bool THPVariable_tryResurrect(THPVariable* self) {
const auto& tensor = THPVariable_Unpack(self);

if (!isResurrectable(self)) {
return false;
}

// At this point, we are definitely going to resurrect the tensor. So, the
// tensor better be defined :)
TORCH_INTERNAL_ASSERT(tensor.defined());

// There are other C++ owners of the tensor. Flip ownership
// so that C++ owns this Python object, and cancel deallocation.
TORCH_INTERNAL_ASSERT(
!tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj());

c10::TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
auto maybe_pyobj = tensor_impl->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/false);

TORCH_INTERNAL_ASSERT(
maybe_pyobj.has_value(),
"Trying to preserve a Python tensor whose PyObjectSlot does not have a PyObject");

tensor_impl->pyobj_slot()->set_owns_pyobj(true);

// Resurrect the Python object. This is something CPython does
// internally occasionally, see
// https://github.com/python/cpython/blob/b98eba5bc2ffbe7a0ed49d540ebc4f756ae61985/Objects/object.c#L248-L259
// so we just copy the pattern here. Note that we don't have to worry
// about saving and restoring the refcount (as the quoted code does)
// because we actually DO need to reset the refcount to one here, we
// can't assume that some other code has taken care of it.
// NB: this will overreport _Py_RefTotal but based on inspection of object.c
// there is no way to avoid this

// When resurrecting, we MUST use _Py_NewReference and not Py_INCREF to
// ensure the PyObject is in a valid state
_Py_NewReference((PyObject*)self);

// Flip THPVariable to be non-owning
// (near use-after-free miss here: fresh MaybeOwned is created breaking
// reference on Tensor in struct BEFORE we overwrite the old one)
TORCH_INTERNAL_ASSERT(!c10::impl::HermeticPyObjectTLS::get_state());
self->cdata = MaybeOwned<Variable>::borrowed(tensor);

// NB: At this point, tensor *could* be dead (e.g., some other C++ thread
// decrefed it.) At this point, it is probably waiting on the GIL to
// deallocate the Python object and will kill self, BUT NOT YET.

return true;
PyObject* THPVariable_Wrap(const at::TensorBase& var) {
return THPVariable_WrapWithType(var, std::nullopt);
}

static int THPFake_traverse(THPVariable* self, visitproc visit, void* arg) {
TORCH_INTERNAL_ASSERT(
false, "TensorBase tp_traverse function was not overridden properly");
return 0;
}

static int THPFake_clear(THPVariable* self) {
TORCH_INTERNAL_ASSERT(
false, "TensorBase tp_clear function was not overridden properly");
return 0;
PyObject* THPVariable_Wrap(const at::TensorBase& var, PyTypeObject* type) {
return THPVariable_WrapWithType(var, type);
}

static PyObject* THPVariable_pynew(
@ -677,16 +607,16 @@ static PyObject* THPVariable_as_subclass(
ParsedArgs<1> parsed_args{};
auto r = parser.parse(_self, args, kwargs, parsed_args);
PyObject* cls = r.pyobject(0);
TORCH_CHECK_TYPE(
PyType_Check(cls),
"cls must be a type (got ",
Py_TYPE(cls)->tp_name,
")");
TORCH_CHECK_TENSOR_SUBTYPE(cls);
// guard completely turns off torch dispatch modes, doesn't just pop off the
// stack
torch_dispatch_mode::StashTorchDispatchStackGuard td_g;
c10::impl::DisablePythonDispatcher dpd_g;
return THPVariable_NewWithVar((PyTypeObject*)cls, self.alias());
PyObject* obj = THPVariable_WrapWithType(self.alias(), (PyTypeObject*)cls);
if (check_has_torch_dispatch(obj)) {
THPVariable_Unpack(obj).unsafeGetTensorImpl()->set_python_dispatch(true);
}
return obj;
END_HANDLE_TH_ERRORS
}

@ -701,11 +631,7 @@ static PyObject* THPVariable_make_subclass(
ParsedArgs<7> parsed_args{};
auto r = parser.parse(args, kwargs, parsed_args);
PyObject* cls = r.pyobject(0);
TORCH_CHECK_TYPE(
PyType_Check(cls),
"cls must be a type (got ",
Py_TYPE(cls)->tp_name,
")");
TORCH_CHECK_TENSOR_SUBTYPE(cls);
// guard completely turns off torch dispatch modes, doesn't just pop off the
// stack
torch_dispatch_mode::StashTorchDispatchStackGuard td_g;
@ -738,7 +664,11 @@ static PyObject* THPVariable_make_subclass(
data.unsafeGetTensorImpl()->_change_backend_component_keys(r.device(6));
}

return THPVariable_NewWithVar((PyTypeObject*)cls, data);
PyObject* obj = THPVariable_WrapWithType(data, (PyTypeObject*)cls);
if (check_has_torch_dispatch(obj)) {
THPVariable_Unpack(obj).unsafeGetTensorImpl()->set_python_dispatch(true);
}
return obj;
END_HANDLE_TH_ERRORS
}

@ -835,11 +765,7 @@ static PyObject* THPVariable_make_wrapper_subclass(
auto r = parser.parse(args, kwargs, parsed_args);
PyObject* cls = r.pyobject(0);

TORCH_CHECK_TYPE(
PyType_Check(cls),
"cls must be a type (got ",
Py_TYPE(cls)->tp_name,
")");
TORCH_CHECK_TENSOR_SUBTYPE(cls);

// This is an important safety check; without it, the default behavior will be
// to continue on to the underlying CPU/CUDA kernel advertised by the dispatch
@ -877,6 +803,8 @@ static PyObject* THPVariable_make_wrapper_subclass(
/*storage_size=*/r.toSymIntOptional(14),
r.toDispatchKeySetOptional(13));

tensor.unsafeGetTensorImpl()->set_python_dispatch(true);

const auto sizes_strides_policy = r.stringViewOptional(10);
if (sizes_strides_policy.has_value()) {
tensor.unsafeGetTensorImpl()->set_python_custom_sizes_strides(
@ -892,13 +820,7 @@ static PyObject* THPVariable_make_wrapper_subclass(
tensor.unsafeGetTensorImpl()->set_python_custom_layout(true);
}

return THPVariable_NewWithVar(
(PyTypeObject*)cls,
tensor,
// false is the default
/*allow_preexisting_pyobj=*/false,
// we checked __torch_dispatch__ above; avoid checking again.
/*has_torch_dispatch_if_known=*/true);
return THPVariable_WrapWithType(std::move(tensor), (PyTypeObject*)cls);
END_HANDLE_TH_ERRORS
}

@ -1699,11 +1621,7 @@ static PyObject* THPVariable_dtensor_new(
auto r = parser.parse(args, kwargs, parsed_args);
PyObject* cls = r.pyobject(0);

TORCH_CHECK_TYPE(
PyType_Check(cls),
"cls must be a type (got ",
Py_TYPE(cls)->tp_name,
")");
TORCH_CHECK_TENSOR_SUBTYPE(cls);

#ifndef NDEBUG
// This is specifically for making a DTensor, which we know defines
@ -1756,14 +1674,9 @@ static PyObject* THPVariable_dtensor_new(
/*storage_size=*/std::nullopt,
extra_dispatch_keys);
tensor.set_requires_grad(requires_grad);
py::object py_tensor =
py::reinterpret_steal<py::object>(THPVariable_NewWithVar(
(PyTypeObject*)cls,
tensor,
// false is the default
/*allow_preexisting_pyobj=*/false,
// we know DTensor has __torch_dispatch__; avoid checking again.
/*has_torch_dispatch_if_known=*/true));
tensor.unsafeGetTensorImpl()->set_python_dispatch(true);
py::object py_tensor = py::reinterpret_steal<py::object>(
THPVariable_WrapWithType(std::move(tensor), (PyTypeObject*)cls));
py_tensor.attr(dtensor_interned_strings._spec) = spec;
py_tensor.attr(dtensor_interned_strings._local_tensor) = local_tensor;
return py_tensor.release().ptr();
|
||||
nullptr, /* tp_new */
|
||||
};
|
||||
|
||||
static void THPVariable_dealloc(PyObject* self);
|
||||
static int THPVariable_clear(THPVariable* self);
|
||||
static int THPVariable_traverse(PyObject* self, visitproc visit, void* arg);
|
||||
|
||||
static PyTypeObject THPVariableType = {
|
||||
PyVarObject_HEAD_INIT(&THPVariableMetaType, 0)
|
||||
"torch._C.TensorBase", /* tp_name */
|
||||
sizeof(THPVariable), /* tp_basicsize */
|
||||
0, /* tp_itemsize */
|
||||
// This is unspecified, because it is illegal to create a THPVariableType
|
||||
// directly. Subclasses will have their tp_dealloc set appropriately
|
||||
// by the metaclass
|
||||
nullptr, /* tp_dealloc */
|
||||
THPVariable_dealloc, /* tp_dealloc */
|
||||
0, /* tp_vectorcall_offset */
|
||||
nullptr, /* tp_getattr */
|
||||
nullptr, /* tp_setattr */
|
||||
@ -3467,9 +3381,8 @@ static PyTypeObject THPVariableType = {
|
||||
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
|
||||
Py_TPFLAGS_HAVE_GC, /* tp_flags */
|
||||
nullptr, /* tp_doc */
|
||||
// Also set by metaclass
|
||||
(traverseproc)THPFake_traverse, /* tp_traverse */
|
||||
(inquiry)THPFake_clear, /* tp_clear */
|
||||
(traverseproc)THPVariable_traverse, /* tp_traverse */
|
||||
(inquiry)THPVariable_clear, /* tp_clear */
|
||||
nullptr, /* tp_richcompare */
|
||||
0, /* tp_weaklistoffset */
|
||||
nullptr, /* tp_iter */
|
||||
@ -3498,345 +3411,68 @@ PyObject* THPVariable_pynew(
type != &THPVariableType,
"Cannot directly construct TensorBase; subclass it and then construct that");
jit::tracer::warn("torch.Tensor", jit::tracer::WARN_CONSTRUCTOR);
auto tensor = torch::utils::base_tensor_ctor(args, kwargs);
// WARNING: tensor is NOT guaranteed to be a fresh tensor; e.g., if it was
// given a raw pointer that will refcount bump
// NB: base_tensor_ctor can call into dispatched ATen functions (e.g.,
// alias(), lift_fresh()) which can return Tensor subclasses. We allow
// these to be passed on directly.
return THPVariable_NewWithVar(
type,
tensor,
/*allow_preexisting_pyobj=*/true);
PyObject* obj = THPVariable_WrapWithType(
torch::utils::base_tensor_ctor(args, kwargs), type);
if (check_has_torch_dispatch(obj)) {
THPVariable_Unpack(obj).unsafeGetTensorImpl()->set_python_dispatch(true);
}
return obj;
END_HANDLE_TH_ERRORS
}

static int THPVariable_subclass_clear(THPVariable* self) {
// Is it OK for an object to still be live after running
// tp_clear? Yes. When Python is breaking reference cycles, it can't assume
// that an object will dealloc after it's cleared. The source code explicitly
// handles this case:
// https://github.com/python/cpython/blob/4e661cd69164318c1f871faa476c68a04092ddc4/Modules/gcmodule.c#L1010-L1025

// Note that we don't need to actually resurrect here. There are 2 cases:
// 1. The PyObject is not part of a reference cycle. In this case, we don't
// need to do anything. The GC will move on to try and break the reference
// cycle on another object, which will eventually trigger tp_dealloc (and thus
// resurrection).

// 2. The PyObject is part of a reference cycle. This case should not actually
// be possible, due to the logic in our tp_traverse
// (THPVariable_subclass_traverse).

// In fact, resurrecting here breaks the invariant that "C++ owns Python only
// when PyObject's refcount would otherwise be 0". Most immediately, as we're
// merely breaking reference cycles here, there can be other references to the
// PyObject. *However*, if other objects in the refcycle resurrect, then we
// will be in a state where the PyObject has multiple Python references, yet
// C++ owns the PyObject.

// See https://github.com/pytorch/pytorch/pull/75933 for more discussion.
if (isResurrectable(self)) {
return 0;
}

static int THPVariable_clear(THPVariable* self) {
// First clear Tensor specific things

Py_CLEAR(self->backward_hooks);
Py_CLEAR(self->post_accumulate_grad_hooks);
const auto& tensor = THPVariable_Unpack(self);
if (tensor.defined()) {
// Two situations to consider:
// PyObject -owns-> Tensor
// unsafeIsBorrowed() is FALSE. We're obligated to look through
// Tensor to break references. Clearing cdata must induce the
// destruction of the C++ Tensor. If there were other references
// to C++ tensor, the Python object would have been resurrected
// by flipping the ownership.
// Tensor -owns-> PyObject
// unsafeIsBorrowed() is TRUE. We're deallocating the PyObject
// because Tensor asked us to (it's already destructing).
if (self->cdata.defined()) {
auto pyobj_slot = self->cdata.unsafeGetTensorImpl()->pyobj_slot();
// Typically the Tensor's pyobj_slot points back to this object. The only
// time that's not the case is if we had a race in THPVariable_Wrap and we
// need to discard the Python object because some other thread beat us to
// setting the pyobj_slot.
if (pyobj_slot->load_pyobj() == (PyObject*)self) {
// A Tensor's Python object should only be destroyed when the Tensor has
// no other references too.
TORCH_INTERNAL_ASSERT(self->cdata.use_count() == 1);

if (!self->cdata.unsafeIsBorrowed() &&
tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/false) == (PyObject*)self) {
// TODO: empirically, on OS X this assert appears to be untrue
// In test_py_tensors_multi_async_call - ProcessGroupRpcTestWithSpawn
// distributed/rpc/test_process_group_agent.py
//
// libc++abi.dylib: terminating with uncaught exception of type
// c10::Error:
// !tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()INTERNAL
// ASSERT FAILED at "../torch/csrc/autograd/python_variable.cpp":171,
// please report a bug to PyTorch. Exception raised from
// THPVariable_subclass_clear at
// ../torch/csrc/autograd/python_variable.cpp:171 (most recent call
// first): frame #0: c10::Error::Error(c10::SourceLocation,
// std::__1::basic_string<char, std::__1::char_traits<char>,
// std::__1::allocator<char> >) + 98 (0x1158a0442 in libc10.dylib) frame
// #1: c10::detail::torchCheckFail(char const*, char const*, unsigned
// int, char const*) + 205 (0x11589ed3d in libc10.dylib) frame #2:
// c10::detail::torchInternalAssertFail(char const*, char const*,
// unsigned int, char const*, c10::detail::CompileTimeEmptyString) + 9
// (0x1141e3f89 in libtorch_python.dylib) frame #3:
// THPVariable_subclass_clear(THPVariable*) + 412 (0x1148a547c in
// libtorch_python.dylib) frame #4:
// THPVariable_subclass_dealloc(_object*) + 453 (0x1148a5035 in
// libtorch_python.dylib) frame #5: (anonymous
// namespace)::concrete_decref_fn(c10::impl::PyInterpreter const*,
// _object*) + 53 (0x1148a5ea5 in libtorch_python.dylib) frame #6:
// c10::TensorImpl::release_resources() + 182 (0x11588c4a6 in
// libc10.dylib) frame #7:
// c10::MaybeOwned<at::Tensor>::operator=(c10::MaybeOwned<at::Tensor>&&)
// + 91 (0x11488c11b in libtorch_python.dylib) frame #8:
// THPVariable_subclass_dealloc(_object*) + 607 (0x1148a50cf in
// libtorch_python.dylib) <omitting python frames> frame #47: start + 1
// (0x7fff6ffc7cc9 in libdyld.dylib) frame #48: 0x0 + 4 (0x4 in ???)
// TORCH_INTERNAL_ASSERT(!tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj());
if (auto grad_acc =
|
||||
torch::autograd::impl::try_get_grad_accumulator(tensor)) {
|
||||
grad_acc->pre_hooks().clear();
|
||||
grad_acc->tensor_pre_hooks().clear();
|
||||
grad_acc->retains_grad_hooks().clear();
|
||||
}
|
||||
// Clear the pyobj_slot so that a try_incref() call from
|
||||
// weak_intrusive_ptr::lock() won't see a freed pointer.
|
||||
pyobj_slot->clear();
|
||||
}
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT(!isResurrectable(self));
|
||||
{
|
||||
// MapAllocator can take significant time to release large tensors;
|
||||
// release the GIL here to avoid impacting main thread perf.
|
||||
pybind11::gil_scoped_release no_gil;
|
||||
self->cdata = MaybeOwned<Variable>();
|
||||
self->cdata = Variable();
|
||||
}
|
||||
// Since we override the basic subtype_clear from CPython, we need a crappy
|
||||
// version here just like for traverse and dealloc
|
||||
|
||||
// Clear all slots until we get to the base Tensor class
|
||||
PyTypeObject* type = Py_TYPE((PyObject*)self);
|
||||
PyTypeObject* base = type;
|
||||
while (base != &THPVariableType) {
|
||||
if (Py_SIZE(base))
|
||||
clear_slots(base, (PyObject*)self);
|
||||
base = base->tp_base;
|
||||
TORCH_INTERNAL_ASSERT(base);
|
||||
}
|
||||
|
||||
// Assume we never have managed dict for Tensors as we don't set the flag on
|
||||
// the base class
|
||||
if (C10_LIKELY(type->tp_dictoffset)) {
|
||||
PyObject** dictptr = _PyObject_GetDictPtr((PyObject*)self);
|
||||
if (dictptr && *dictptr)
|
||||
Py_CLEAR(*dictptr);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
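
To make the clear path above easier to follow, here is the general tp_clear shape it implements, reduced to a hypothetical wrapper around a refcounted C++ object (MyWrapper and MyPayload are made-up names; this is a sketch, not the actual PyTorch code): break the owned Python references with Py_CLEAR, drop the C++ ownership, and return 0 without freeing anything, since deallocation remains tp_dealloc's job.

#include <Python.h>
#include <memory>

// Hypothetical types used only for this sketch.
struct MyPayload { /* refcounted C++ state */ };

struct MyWrapper {
  PyObject_HEAD
  PyObject* callback;                 // owned Python reference, may sit in a ref cycle
  std::shared_ptr<MyPayload> payload; // owned C++ reference (placement-new'd in tp_new)
};

static int MyWrapper_clear(PyObject* self) {
  auto* w = reinterpret_cast<MyWrapper*>(self);
  // Break the Python-visible edges first so the cycle collector can proceed.
  Py_CLEAR(w->callback);
  // Then drop the C++ reference; if this was the last owner, the C++ object
  // (and anything it transitively owns) is destroyed here.
  w->payload.reset();
  return 0;
}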

// NB: this is not the tp_dealloc on THPVariable; instead, it's the dealloc
// on subclasses. It's never valid to construct a THPVariable so it's not
// necessary to implement the dealloc for that case
static void THPVariable_subclass_dealloc(PyObject* self) {
if (THPVariable_tryResurrect((THPVariable*)self))
return;

// This is like a crappy version of subtype_dealloc.
// Unfortunately, we cannot directly delegate to
// subtype_dealloc as it will start walking the parent
// chain *starting with* the type of self, which will cause
// us to go back to our custom dealloc.
//
// We have to replicate the subtype_dealloc logic to ensure
// that finalizers are handled correctly
PyTypeObject* type = Py_TYPE(self);
TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
TORCH_INTERNAL_ASSERT(PyType_IS_GC(type), "GC types not implemented");

static void THPVariable_dealloc(PyObject* self) {
PyObject_GC_UnTrack(self);
// TODO: consider using trash can

bool has_finalizer = type->tp_finalize || type->tp_del;

if (type->tp_finalize) {
PyObject_GC_Track(self);
if (PyObject_CallFinalizerFromDealloc(self) < 0) {
/* Resurrected */
return;
}
PyObject_GC_UnTrack(self);
}

// base test is unnecessary as THPVariable does not set this
if (type->tp_weaklistoffset) {
PyObject_ClearWeakRefs(self);
}

if (type->tp_del) {
PyObject_GC_Track(self);
type->tp_del(self);
if (Py_REFCNT(self) > 0) {
/* Resurrected */
return;
}
PyObject_GC_UnTrack(self);
}

if (has_finalizer) {
/* New weakrefs could be created during the finalizer call.
If this occurs, clear them out without calling their
finalizers since they might rely on part of the object
being finalized that has already been destroyed. */
if (type->tp_weaklistoffset) {
/* Modeled after GET_WEAKREFS_LISTPTR() */
PyWeakReference** list =
(PyWeakReference**)PyObject_GET_WEAKREFS_LISTPTR(self);
while (*list)
_PyWeakref_ClearRef(*list);
}
}

// Clear all slots until we get to base class THPVariableType
{
PyTypeObject* base = type;
while (base != &THPVariableType) {
if (Py_SIZE(base)) {
clear_slots(base, self);
}
base = base->tp_base;
TORCH_INTERNAL_ASSERT(base);
}
}

// All Python defined classes have __dict__
if (C10_LIKELY(type->tp_dictoffset)) {
PyObject** dictptr = _PyObject_GetDictPtr(self);
if (dictptr != nullptr) {
PyObject* dict = *dictptr;
if (dict != nullptr) {
Py_DECREF(dict);
*dictptr = nullptr;
}
}
}

// subtype_dealloc allows for this but we don't
TORCH_INTERNAL_ASSERT(Py_TYPE(self) == type);

// Finally clear out the base THPVariable
THPVariable_subclass_clear((THPVariable*)self);
((THPVariable*)self)->cdata.~MaybeOwned<Variable>();
THPVariable_clear((THPVariable*)self);
((THPVariable*)self)->cdata.~Variable();
Py_TYPE(self)->tp_free(self);

// Python defined subclasses should always be on the heap
TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
Py_DECREF(type);
}
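
For orientation, the dealloc sequence being replicated above follows the same order CPython's subtype_dealloc uses. A minimal sketch for a hypothetical GC-enabled heap type with one owned member (illustrative only, not the PyTorch implementation):

#include <Python.h>

struct MyWrapper {
  PyObject_HEAD
  PyObject* callback; // owned Python reference
};

static void MyWrapper_dealloc(PyObject* self) {
  PyTypeObject* type = Py_TYPE(self);
  PyObject_GC_UnTrack(self);

  // Run __del__-style finalizers; they may resurrect the object.
  if (type->tp_finalize) {
    PyObject_GC_Track(self);
    if (PyObject_CallFinalizerFromDealloc(self) < 0) {
      return; // resurrected
    }
    PyObject_GC_UnTrack(self);
  }

  // Weak references must be cleared before the memory goes away.
  if (type->tp_weaklistoffset) {
    PyObject_ClearWeakRefs(self);
  }

  Py_CLEAR(reinterpret_cast<MyWrapper*>(self)->callback);

  // For heap types, each instance holds a strong reference to its type.
  type->tp_free(self);
  Py_DECREF(type);
}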

// Creates a new Python object for a Variable.
static PyObject* THPVariable_NewWithVar(
PyTypeObject* type,
const at::TensorBase& _var,
bool allow_preexisting_pyobj,
std::optional<bool> has_torch_dispatch_if_known) {
// Make sure that the reinterpret into a THPVariable* will be valid
TORCH_CHECK(
type == &THPVariableType || PyType_IsSubtype(type, &THPVariableType),
"Creating a Tensor subclass from a class ",
"that does not inherit from Tensor is not possible. Make sure your class inherits from Tensor.");

// This function overwrites the Tensor's pyobj field without extra checks
// Make sure it is not set otherwise we would leak memory
auto mb_obj = _var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/false);

// Under some circumstances, we may attempt to create a new Python
// object for a variable that already has a Python object. The most common
// situation this can occur is if you have a TorchDispatchMode active that
// is returning a subclass from lift_fresh (which is invoked to
// appropriately "wrap" a constant tensor into whatever ambient modes are
// active.)
//
// In general, it is impossible to handle this case compositionally.
// Suppose you have a user call ATensor([1, 2, 3]) when a mode is active
// that is transforming all ops (including the internal lift_fresh call that
// transforms [1, 2, 3] into a torch.tensor([1., 2., 3.])) to output
// BTensor, where ATensor and BTensor are completely unrelated subclasses
// and there is no way to compose them. There is no way to satisfy the user
// request here: in particular, you can't just try to re-invoke the ATensor
// constructor on the returned BTensor, because (1) this could cause an
// infinite loop--we are already in ATensor.__new__ and (2) there isn't any
// guarantee that ATensor.__new__ supports a single element constructor
// anyway.
//
// However, a more common case is a user just called torch.Tensor([1, 2, 3]),
// and a fake tensor mode is active. Really, all you want is to get back
// a FakeTensor, in the same way torch.tensor([1, 2, 3]) or torch.arange(3)
// would have returned a fake tensor (concretely, the way this happens
// is we create a *real* tensor torch.tensor([1., 2., 3.]), and then it
// turns into a FakeTensor when we call lift_fresh on this real tensor).
// This case is compositional because FakeTensor is a subclass of Tensor, so
// it's valid for us to return it in place of a Tensor. So this is what we
// do.

if (mb_obj.has_value() && mb_obj.value()) {
TORCH_CHECK(
allow_preexisting_pyobj,
"Creating a new Tensor subclass ",
type->tp_name,
" but the raw Tensor object is already associated to a python object ",
"of type ",
mb_obj.value()->ob_type->tp_name);
// Even if we allow pre-existing PyObject, we don't allow completely
// ignoring the requested type. Check that we fulfilled a subtype
// relation here. In the common case the requested type is Tensor and
// this always succeeds.
PyObject* obj = *mb_obj;
// Check if it's OK to just directly return the Python object without
// allocating a new variable. We just check that the existing Python
// object is a subclass of the requested type.
PyTypeObject* obj_type = Py_TYPE(obj);
TORCH_CHECK(
obj_type == type || PyType_IsSubtype(obj_type, type),
"Creating a new Tensor subclass ",
type->tp_name,
" but the raw Tensor object is already associated to a python object ",
"of type ",
mb_obj.value()->ob_type->tp_name,
" which is not a subclass of the "
"requested type");
// We may (in fact, we typically will) need to resurrect this
return THPVariable_Wrap(_var);
}

PyObject* obj = type->tp_alloc(type, 0);
if (obj) {
auto v = (THPVariable*)obj;
// TODO: named constructor to avoid default initialization
new (&v->cdata) MaybeOwned<Variable>();
if (c10::impl::HermeticPyObjectTLS::get_state()) {
// Do NOT initialize pyobj field on the tensor, you own the C++
v->cdata = MaybeOwned<Variable>::owned(Variable(_var));
TORCH_INTERNAL_ASSERT(
!check_has_torch_dispatch(obj),
"While HermeticPyObject was enabled, we attempted to create a tensor "
"subclass with __torch_dispatch__. This violates the invariant that "
"operations in HermeticPyObject have equivalent C++ implementations. "
"If your operator registered from Python operator registration isn't "
"doing anything strange, there may be an internal PyTorch bug involving "
"not appropriately disabling TorchDispatchMode before executing "
"Python op registration.");
} else {
// Normal codepath
v->cdata = MaybeOwned<Variable>::owned(Variable(_var));
const auto& var = THPVariable_Unpack(v);
var.unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(obj);
if (has_torch_dispatch_if_known.has_value()
? *has_torch_dispatch_if_known
: check_has_torch_dispatch(obj)) {
var.unsafeGetTensorImpl()->set_python_dispatch(true);
}
}
}
return obj;
static void TORCH_CHECK_TENSOR_SUBTYPE(PyObject* cls) {
TORCH_CHECK_TYPE(
PyType_Check(cls),
"cls must be a type (got ",
Py_TYPE(cls)->tp_name,
")");
PyTypeObject* type = reinterpret_cast<PyTypeObject*>(cls);
TORCH_CHECK_TYPE(
type == &THPVariableType || cls == THPVariableClass ||
PyType_IsSubtype(type, &THPVariableType),
"Creating a Tensor subclass from a class that does not inherit from "
"Tensor is not possible. Make sure your class inherits from Tensor.");
}

/// NOTE [ PyObject Traversal ]
@ -3855,7 +3491,7 @@ static PyObject* THPVariable_NewWithVar(
/// into account these C++ ownership links.
///
/// The main danger here comes from the fact that, while all python-related code
/// is thread safe wrt the GC execution (thanks to the GIL), other threads might
/// is thread safe wrt the GC execution, other threads might
/// be using our C++ objects arbitrarily which can lead to shared_ptr ref count
/// going up or down in between the different traverse/clear invocations. The
/// one constraint we add here that is not explicitly mentioned in the GC
@ -3885,124 +3521,46 @@ static PyObject* THPVariable_NewWithVar(
/// https://github.com/pytorch/pytorch/issues/7343
///

static int traverse_slots(
PyTypeObject* type,
PyObject* self,
visitproc visit,
void* arg) {
auto n = Py_SIZE(type);
auto mp = type->tp_members;
for (Py_ssize_t i = 0; i < n; i++, mp++) {
if (mp->type == T_OBJECT_EX) {
char* addr = (char*)self + mp->offset;
PyObject* obj = *(PyObject**)addr;
if (obj != nullptr) {
int err = visit(obj, arg);
if (err)
return err;
}
}
}
return 0;
}
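
traverse_slots above walks the type's member table. For context, this is the kind of table it iterates: for a heap type, Py_SIZE(type) holds the number of entries, and only T_OBJECT_EX entries store a PyObject* that has to be reported to the GC. The ExampleObject type below is hypothetical and only illustrates the layout:

#include <Python.h>
#include <structmember.h>

struct ExampleObject {
  PyObject_HEAD
  PyObject* first;   // T_OBJECT_EX slot: visited/cleared by the GC helpers
  PyObject* second;  // likewise
  int counter;       // T_INT slot: skipped, it holds no PyObject*
};

static PyMemberDef example_members[] = {
    {"first", T_OBJECT_EX, offsetof(ExampleObject, first), 0, nullptr},
    {"second", T_OBJECT_EX, offsetof(ExampleObject, second), 0, nullptr},
    {"counter", T_INT, offsetof(ExampleObject, counter), 0, nullptr},
    {nullptr, 0, 0, 0, nullptr}, // sentinel
};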

static int THPVariable_subclass_traverse(
PyObject* self,
visitproc visit,
void* arg) {
// If the tensor is eligible to be resurrected, don't traverse it; instead
// treat all of its references as a root (as they WOULD be a root since we
// can treat the inbound C++ references as root owners).
//
// This works because unlike conventional GCs, Python's GC operates in two
// phases: first it uses traverse to discover roots, and then it uses traverse
// to do reachability. Bypassing traverse during root discovery forces Python
// to treat self as a root for everything it refers to. For a full
// explanation of the algorithm see
// https://devguide.python.org/garbage_collector/
//
// NB: if we don't hold an owning reference to the underlying Tensor, it is
// possible that the underlying Tensor has already gone dead. In that case,
// it's not safe to access it. But it's also safe to traverse, because if
// the underlying Tensor *is* live, then root discovery will determine that
// self is live, and nothing will get GC'ed anyway (resurrection cannot happen
// if the C++ object owns the PyObject)
static int THPVariable_traverse(PyObject* self, visitproc visit, void* arg) {
THPVariable* var = reinterpret_cast<THPVariable*>(self);
if (isResurrectable(var)) {
return 0;
}

// Crappy version of subtype_traverse; same deal as
// THPVariable_subclass_dealloc

PyTypeObject* type = Py_TYPE(self);
// Traverse slots until we get to base class THPVariableType
{
PyTypeObject* base = type;
while (base != &THPVariableType) {
if (Py_SIZE(base)) {
int err = traverse_slots(base, self, visit, arg);
if (err)
return err;
}
base = base->tp_base;
TORCH_INTERNAL_ASSERT(base);
}
}

// All Python defined classes have __dict__
if (C10_LIKELY(type->tp_dictoffset)) {
PyObject** dictptr = _PyObject_GetDictPtr(self);
if (dictptr && *dictptr)
Py_VISIT(*dictptr);
}

TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
Py_VISIT(type);

// Finally traverse THPVariable special stuff
Py_VISIT(var->backward_hooks);
Py_VISIT(var->post_accumulate_grad_hooks);
if (!var->cdata.unsafeIsBorrowed()) {
const auto& tensor = THPVariable_Unpack(var);
if (tensor.defined()) {
// WARNING: The grad_fn traversal logic is very subtle, if you change
// this, be very careful not to re-introduce this bug:
// https://gist.github.com/zou3519/7ac92b84dd7d206dcc6eae55fee8372c
const auto& tensor = THPVariable_Unpack(var);
if (tensor.defined()) {
// WARNING: The grad_fn traversal logic is very subtle, if you change
// this, be very careful not to re-introduce this bug:
// https://gist.github.com/zou3519/7ac92b84dd7d206dcc6eae55fee8372c

// We ensure that we follow NOTE [ PyObject Traversal ] here by checking
// that this python object is the sole owner of the underlying Tensor and
// that this Tensor is the sole owner of its grad_fn. In this case, the
// only way to get a new reference to the grad_fn is by using this python
// object, which requires the GIL to be accessed. Note that this is only
// valid as long as users don't share non-owning references across
// different threads (which is crazy and should never be done).
auto autograd_meta = torch::autograd::impl::get_autograd_meta(tensor);
if (tensor.use_count() == 1) {
if (autograd_meta) {
// Do NOT call grad_fn() here as that might trigger a recompute
const auto& grad_fn = autograd_meta->grad_fn_;
if (grad_fn && grad_fn.use_count() == 1) {
// All Nodes can have a pyobj (stored in "pyobj_")
Py_VISIT(grad_fn->pyobj());
// PyNode are special as they also have an "obj" field
if (auto py_node_fn = dynamic_cast<PyNode*>(grad_fn.get())) {
Py_VISIT(py_node_fn->obj);
}
}
}
}
// We ensure that we follow NOTE [ PyObject Traversal ] here by checking
// that this python object is the sole owner of the underlying Tensor and
// that this Tensor is the sole owner of its grad_fn. In this case, the
// only way to get a new reference to the grad_fn is by using this python
// object, which requires the GIL to be accessed. Note that this is only
// valid as long as users don't share non-owning references across
// different threads (which is crazy and should never be done).
auto autograd_meta = torch::autograd::impl::get_autograd_meta(tensor);
if (tensor.use_count() == 1) {
if (autograd_meta) {
for (const auto& hook : torch::autograd::impl::hooks(tensor)) {
if (auto pyhook =
dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
Py_VISIT(pyhook->dict);
// Do NOT call grad_fn() here as that might trigger a recompute
const auto& grad_fn = autograd_meta->grad_fn_;
if (grad_fn && grad_fn.use_count() == 1) {
// All Nodes can have a pyobj (stored in "pyobj_")
Py_VISIT(grad_fn->pyobj());
// PyNode are special as they also have an "obj" field
if (auto py_node_fn = dynamic_cast<PyNode*>(grad_fn.get())) {
Py_VISIT(py_node_fn->obj);
}
}
}
}
if (autograd_meta) {
for (const auto& hook : torch::autograd::impl::hooks(tensor)) {
if (auto pyhook = dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
Py_VISIT(pyhook->dict);
}
}
}
}

return 0;
}
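
Distilled from the function above, the traversal contract for a GC-tracked wrapper is: report every owned PyObject* edge with Py_VISIT, including the instance __dict__ and, for heap types, the type itself. A reduced sketch with a hypothetical ExampleWrapper (not the actual PyTorch code):

#include <Python.h>

struct ExampleWrapper {
  PyObject_HEAD
  PyObject* payload_ref; // owned Python reference
};

static int ExampleWrapper_traverse(PyObject* self, visitproc visit, void* arg) {
  // Owned member edges.
  Py_VISIT(reinterpret_cast<ExampleWrapper*>(self)->payload_ref);
  // Instance __dict__, if the type has one.
  if (Py_TYPE(self)->tp_dictoffset) {
    PyObject** dictptr = _PyObject_GetDictPtr(self);
    if (dictptr && *dictptr) {
      Py_VISIT(*dictptr);
    }
  }
  // Instances of heap types keep their type alive.
  Py_VISIT(Py_TYPE(self));
  return 0;
}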

@ -4010,17 +3568,6 @@ int THPVariableMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) {
if (PyType_Type.tp_init(cls, args, kwargs) < 0) {
return -1;
}
// It is important for all three of these to be overridden correctly for the
// resurrection checks to properly happen. In particular, an older version
// was not overriding tp_clear here. This led to the default subtype_clear
// running on the Tensor object (as only TensorBase tp_clear was custom),
// clearing the __dict__ field, before the TensorBase custom clear was called
// and could properly detect the resurrection.
// See https://github.com/pytorch/pytorch/issues/136358 for the exact behavior
((PyTypeObject*)cls)->tp_dealloc = (destructor)THPVariable_subclass_dealloc;
((PyTypeObject*)cls)->tp_traverse =
(traverseproc)THPVariable_subclass_traverse;
((PyTypeObject*)cls)->tp_clear = (inquiry)THPVariable_subclass_clear;

// Don't do anything for the base Tensor class
if (!THPVariableClass) {

@ -17,7 +17,7 @@ namespace py = pybind11;
struct THPVariable {
PyObject_HEAD
// Payload
c10::MaybeOwned<at::Tensor> cdata;
at::Tensor cdata;
// Hooks to be run on backwards pass (corresponds to Python attr
// '_backwards_hooks', set by 'register_hook')
PyObject* backward_hooks = nullptr;
@ -37,7 +37,11 @@ TORCH_PYTHON_API extern PyObject* THPVariableClass;
TORCH_PYTHON_API extern PyObject* ParameterClass;

bool THPVariable_initModule(PyObject* module);
TORCH_PYTHON_API PyObject* THPVariable_Wrap(at::TensorBase&& var);
TORCH_PYTHON_API PyObject* THPVariable_Wrap(const at::TensorBase& var);
TORCH_PYTHON_API PyObject* THPVariable_Wrap(
const at::TensorBase& var,
PyTypeObject* type);

inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) {
// Check that a python object is a `Tensor`, but not a `Tensor` subclass.
@ -69,7 +73,7 @@ inline bool THPVariable_Check(PyObject* obj) {
}

inline const at::Tensor& THPVariable_Unpack(THPVariable* var) {
return *var->cdata;
return var->cdata;
}

inline const at::Tensor& THPVariable_Unpack(PyObject* obj) {

@ -65,7 +65,9 @@ inline at::Tensor clone_obey_contract(
.new_empty_strided_symint(
variable.sym_sizes(),
variable.sym_strides(),
variable.options().memory_format(std::nullopt))
variable.options()
.memory_format(std::nullopt)
.dtype(new_grad.dtype()))
.copy_(new_grad));
} else {
// (2)

@ -70,6 +70,10 @@ inline PyObject* wrap(const at::Tensor& tensor) {
return THPVariable_Wrap(tensor);
}

inline PyObject* wrap(at::Tensor&& tensor) {
return THPVariable_Wrap(std::move(tensor));
}

inline PyObject* wrap(const at::Scalar& scalar) {
return wrap(scalar_to_tensor(scalar));
}

@ -197,6 +197,22 @@ TORCH_API std::unique_ptr<PostAccumulateGradHook>& post_acc_grad_hooks(
TORCH_API void create_cpp_hook(
const at::TensorBase& /*self*/,
bool is_retains_grad_hooks = false);

inline bool is_tensor_stealable(
const at::Tensor& new_grad,
size_t num_expected_refs = 1) {
size_t use_count = new_grad.use_count();
if (use_count <= num_expected_refs) {
return true;
}
if (use_count >= 2 &&
new_grad.unsafeGetTensorImpl()->pyobj_slot()->has_unique_reference()) {
// The Python wrapper, if it exists, also has a reference to the Tensor.
num_expected_refs++;
}
return use_count <= num_expected_refs;
}
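
A hypothetical caller sketch for the helper above (assuming we are inside torch::autograd::impl so is_tensor_stealable is visible; the function name is illustrative, not part of the PyTorch API): the predicate decides whether a freshly produced gradient can be reused directly or has to be cloned because other owners may still observe it.

// Illustrative only; take_or_clone is a made-up name for this sketch.
at::Tensor take_or_clone(const at::Tensor& new_grad) {
  if (is_tensor_stealable(new_grad, /*num_expected_refs=*/1)) {
    // Only we (and possibly the Python wrapper the helper already accounted
    // for) own new_grad, so it is safe to reuse it without a copy.
    return new_grad;
  }
  // Some other owner may still observe new_grad; hand back a private copy.
  return new_grad.clone();
}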

} // namespace impl

//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -894,7 +910,7 @@ inline Variable make_variable(
bool requires_grad = false,
bool allow_tensor_metadata_change = true) {
if (data.defined()) {
if (data.getIntrusivePtr().use_count() == 1 &&
if (impl::is_tensor_stealable(data) &&
data.getIntrusivePtr()->unique_version()) {
auto data_impl = data.unsafeReleaseIntrusivePtr();
data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);

@ -1,19 +1,67 @@
#include <torch/csrc/utils/pyobject_preservation.h>

#include <structmember.h>
#include <c10/core/impl/PyObjectSlot.h>
#include <c10/util/intrusive_ptr.h>

void clear_slots(PyTypeObject* type, PyObject* self) {
Py_ssize_t n = Py_SIZE(type);
PyMemberDef* mp = type->tp_members;
namespace torch::utils {

for (Py_ssize_t i = 0; i < n; i++, mp++) {
if (mp->type == T_OBJECT_EX && !(mp->flags & READONLY)) {
char* addr = (char*)self + mp->offset;
PyObject* obj = *(PyObject**)addr;
if (obj != nullptr) {
*(PyObject**)addr = nullptr;
Py_DECREF(obj);
}
}
}
using c10::intrusive_ptr_target;
using c10::impl::PyObjectSlot;

void PyObjectPreservation::init_fresh_nonatomic(
intrusive_ptr_target* target,
PyObjectSlot* slot,
PyObject* pyobj) {
TORCH_INTERNAL_ASSERT(slot->load_pyobj() == nullptr);
TORCH_INTERNAL_ASSERT(
target->combined_refcount_.load(std::memory_order_relaxed) ==
c10::detail::kUniqueRef);

slot->pyobj_.store(pyobj, std::memory_order_relaxed);
slot->pyobj_interpreter_.store(
c10::impl::getGlobalPyInterpreter(), std::memory_order_relaxed);
target->combined_refcount_.store(
c10::detail::kHasPyObject | c10::detail::kUniqueRef,
std::memory_order_relaxed);
}

PyObject* PyObjectPreservation::init_once(
intrusive_ptr_target* target,
PyObjectSlot* slot,
PyObject* pyobj) {
PyObject* expected = nullptr;
if (!slot->pyobj_.compare_exchange_strong(
expected, pyobj, std::memory_order_acq_rel)) {
TORCH_INTERNAL_ASSERT(expected != nullptr);
return expected;
}

slot->pyobj_interpreter_.store(
c10::impl::getGlobalPyInterpreter(), std::memory_order_release);

bool increfed = false;
auto combined = target->combined_refcount_.load(std::memory_order_relaxed);
do {
TORCH_INTERNAL_ASSERT(!c10::detail::has_pyobject(combined));
if (c10::detail::refcount(combined) > 1 && !increfed) {
// We need to incref the object to preserve the invariant that
// if refcount > 1, the c10 object holds a reference to the PyObject.
// This must happen before we set the kHasPyObject bit.
Py_INCREF(pyobj);
increfed = true;
}
} while (!target->combined_refcount_.compare_exchange_weak(
combined,
combined | c10::detail::kHasPyObject,
std::memory_order_acq_rel,
std::memory_order_relaxed));

if (increfed && c10::detail::refcount(combined) == 1) {
// Fix up the refcount if we did the incref in a failed compare-exchange
Py_DECREF(pyobj);
}

return pyobj;
}

} // namespace torch::utils
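
The concurrency idiom in init_once, reduced to a plain std::atomic for illustration (hypothetical names; the real code additionally folds a kHasPyObject bit and an extra incref into the combined refcount): the first thread to install a wrapper wins the compare-exchange, and every later thread adopts the winner's pointer and discards its own.

#include <atomic>

struct Slot {
  std::atomic<void*> wrapper{nullptr};
};

// Returns whichever wrapper ends up published in the slot.
void* publish_once(Slot& slot, void* candidate) {
  void* expected = nullptr;
  if (slot.wrapper.compare_exchange_strong(
          expected, candidate, std::memory_order_acq_rel)) {
    return candidate; // we won the race; candidate is now visible to readers
  }
  return expected; // another thread won; the caller should discard candidate
}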

@ -4,4 +4,28 @@

// This file contains utilities used for handling PyObject preservation

void clear_slots(PyTypeObject* type, PyObject* self);
namespace c10 {
class intrusive_ptr_target;
namespace impl {
struct PyObjectSlot;
} // namespace impl
} // namespace c10

namespace torch::utils {

class PyObjectPreservation {
public:
// Store a PyObject wrapper on a fresh c10 wrapper. The caller must hold
// a unique reference to `target`.
static void init_fresh_nonatomic(
c10::intrusive_ptr_target* target,
c10::impl::PyObjectSlot* slot,
PyObject* pyobj);

static PyObject* init_once(
c10::intrusive_ptr_target* target,
c10::impl::PyObjectSlot* slot,
PyObject* pyobj);
};

} // namespace torch::utils

@ -207,12 +207,19 @@ def tensorify_python_scalars(
and node.target is torch.ops.aten._local_scalar_dense.default
):
dtype = node.args[0].meta["val"].dtype
if not dtype.is_floating_point:
continue

assert isinstance(node.args[0], fx.Node), node.args[0]

s = node.meta["val"].node.expr

expr_to_sym_proxy[s] = MetaProxy(
node, tracer=tracer, fake_mode=fake_mode
)

# only tensorify if the dtype is floating point
if not dtype.is_floating_point:
continue

expr_to_tensor_proxy[s] = MetaProxy(
node.args[0], tracer=tracer, fake_mode=fake_mode
)
@ -220,9 +227,7 @@ def tensorify_python_scalars(
expr_to_tensor_proxy[s] = torch.ops.prims.convert_element_type.default(
expr_to_tensor_proxy[s], torch.float64
)
expr_to_sym_proxy[s] = MetaProxy(
node, tracer=tracer, fake_mode=fake_mode
)

# pyrefly: ignore [bad-argument-type]
elif (sym_expr := _get_sym_val(node)) is not None:
if sym_expr not in expr_to_sym_proxy and not isinstance(