Compare commits

...

4 Commits

Author SHA1 Message Date
2ee209ebd7 Rework PyObject preservation (v2) (#167564)
Summary:
Make the PyObject preservation scheme thread-safe with free-threaded (nogil) Python. The general idea is:

* Python Tensor and Storage objects always hold a strong reference to their underlying c10 object
* c10 objects hold a strong reference to their Python objects if there's at least one other reference to the c10 object

This is implemented in `intrusive_ptr`:

* The topmost bit (`kHasPyObject`) is taken from the weakref count and is now used to indicate whether the `intrusive_ptr_target` has an associated PyObject. So `kHasPyObject` is one bit, the weakref count is now 31 bits, and the strong refcount remains 32 bits.
* When the reference count increases from one to two and `kHasPyObject` is set, we incref the associated Python object to ensure that it's kept alive.
* When the reference count decreases from two to one (i.e., there are no C++ references to the `intrusive_ptr_target` other than the one held by the Python object), we decref the associated Python object to break the cycle (see the sketch below).
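
For illustration, here is a minimal sketch of these refcount transitions, assuming a simplified stand-in `Target` type with a packed 64-bit refcount word (illustrative only; the real logic lives in `intrusive_ptr::retain_()` and `intrusive_ptr::reset_()` in this PR):

```cpp
// Illustrative sketch only, not the actual intrusive_ptr implementation.
#include <atomic>
#include <cstdint>

constexpr uint64_t kReferenceCountOne = 1;
constexpr uint64_t kHasPyObject = uint64_t(1) << 63;  // top bit of the packed word

struct Target {
  std::atomic<uint64_t> combined_refcount{kReferenceCountOne};

  // Stand-ins for Py_INCREF/Py_DECREF on the associated Python wrapper.
  void incref_pyobject() { /* Py_INCREF(pyobj) in the real code */ }
  void decref_pyobject() { /* Py_DECREF(pyobj) in the real code */ }

  void retain() {
    uint64_t combined =
        combined_refcount.fetch_add(kReferenceCountOne) + kReferenceCountOne;
    uint32_t refcount = static_cast<uint32_t>(combined);  // low 32 bits
    // 1 -> 2: a C++ reference now exists in addition to the PyObject's own
    // reference, so keep the PyObject alive.
    if ((combined & kHasPyObject) && refcount == 2) {
      incref_pyobject();
    }
  }

  void release() {
    uint64_t combined =
        combined_refcount.fetch_sub(kReferenceCountOne) - kReferenceCountOne;
    uint32_t refcount = static_cast<uint32_t>(combined);
    // 2 -> 1: only the PyObject's reference remains, so drop the reference to
    // the PyObject and break the cycle. (Deletion at refcount == 0 omitted.)
    if ((combined & kHasPyObject) && refcount == 1) {
      decref_pyobject();
    }
  }
};
```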

Other benefits:

* We can delete a lot of the copied-and-pasted code from CPython's internal `subtype_dealloc`
* This fixes the weakref and GC bugs we had in the previous scheme. Python weakrefs on Tensors and Storages should just work as expected now.

Risks:

* Extra branch for reference count operations on `intrusive_ptr<TensorImpl>`, `intrusive_ptr<StorageImpl>`, and the generic `intrusive_ptr<intrusive_ptr_target>` even when we're not using Python.
* It's a big change

(Second attempt at https://github.com/pytorch/pytorch/pull/166342)

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben albanD


Differential Revision: D86936370

Pulled By: colesbury
2025-11-13 20:32:01 -08:00
2aba180114 Always track _local_scalar_dense output in tensorify_python_scalars. (#166573)
We need to track all symbols. Previously we skipped `u = item()` (see the repro sketch below) and failed with:
```
 File "/home/lsakka/pytorch10/pytorch/torch/fx/passes/_tensorify_python_scalars.py", line 149, in _sympy_interp
    expr_to_sym_proxy[expr]
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
KeyError: u0
```
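
A minimal repro of the pattern that used to fail, adapted from the `test_tensorify_track_item_symint` test added in this PR (shapes and bounds are illustrative):

```python
# The unbacked symbol (u0) produced by .item() feeds a later op, so
# tensorify_python_scalars must track it instead of skipping it.
import torch
import torch.nn.functional as F

torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(fullgraph=True)
def random_resize(image):
    n = torch.randint(16, 19, (1,)).item()  # u = item()
    torch._check(n > 0)
    return F.interpolate(
        image, size=(n * 14, n * 14), mode="bilinear", align_corners=True
    )

out = random_resize(torch.rand(1, 3, 224, 224))
```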

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166573
Approved by: https://github.com/bobrenjc93
2025-11-14 03:51:43 +00:00
45b2c3d312 [OpenReg][Feat][Docs] Enrich OpenReg device management implementation and add focused documentation (#165897)
## Summary
This PR enriches the OpenReg device management code and adds focused documentation.

## Key Changes
- Introduced device management documentation in `device.md`.
- Updated `OpenRegFunctions.h` and `OpenRegFunctions.cpp` to use `DeviceIndex` and added error handling.
- Implemented `check_device_index` function for validating device indices.
- Enhanced Python bindings in `Module.cpp` for device management.
- Added tests for invalid device index handling in `test_device.py`.
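
A hedged usage sketch of the new validation, mirroring the test added in `test_device.py` (assumes the OpenReg backend is installed and exposes two simulated devices, so index 2 is invalid):

```python
# Out-of-range indices now raise a clear error instead of being passed
# through to the device runtime.
import torch
import torch_openreg  # registers the OpenReg (PrivateUse1) backend; illustrative import

torch.accelerator.set_device_index(1)  # valid: devices 0 and 1 exist
try:
    torch.accelerator.set_device_index(2)  # invalid index
except RuntimeError as e:
    assert "The device index is out of range" in str(e)
```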

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165897
Approved by: https://github.com/fffrog
2025-11-14 03:08:23 +00:00
5b1e112cf9 [Dynamo] Improve-graph-break-skip-logs (#167067)
Fixes #150477

### Summary:

- Added frame information (function name, file, line number) to all graph break/skip messages
- Standardized message format: "torch.compile will skip tracing the frame <name> (<file> line <N>) and fall back to eager. Reason: <reason>"
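
For example (a hedged sketch based on the logging tests added in this PR), forcing a graph break inside a plain context manager makes Dynamo skip the frame and log the new message:

```python
# With TORCH_LOGS="graph_breaks", compiling this logs something like:
#   torch.compile will skip tracing the frame fn (<file> line <N>) and fall back to eager. ...
import torch

class GenericCtxMgr:
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        pass

def fn(x):
    with GenericCtxMgr():
        torch._dynamo.graph_break()  # cannot resume here, so the frame is skipped
    return x + 1

torch.compile(fn, backend="eager")(torch.randn(3))
```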

### Impacts:
module: dynamo

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167067
Approved by: https://github.com/williamwen42
2025-11-14 03:06:37 +00:00
48 changed files with 1237 additions and 1254 deletions

View File

@ -245,6 +245,9 @@ class TORCH_API TensorBase {
size_t weak_use_count() const noexcept {
return impl_.weak_use_count();
}
bool is_uniquely_owned() const noexcept {
return impl_.is_uniquely_owned();
}
std::string toString() const;

View File

@ -10,6 +10,13 @@
...
}
{
ignore_empty_generic_uninitialised_conditional_jump
Memcheck:Cond
fun:_ZN2at6detail13empty_genericEN3c108ArrayRefIlEEPNS1_9AllocatorENS1_14DispatchKeySetENS1_10ScalarTypeESt8optionalINS1_12MemoryFormatEE
...
}
{
Cond_cuda
Memcheck:Cond

View File

@ -44,7 +44,7 @@ struct C10_API SafePyObject {
(*other.pyinterpreter_)->incref(other.data_);
}
if (data_ != nullptr) {
(*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
(*pyinterpreter_)->decref(data_);
}
data_ = other.data_;
pyinterpreter_ = other.pyinterpreter_;
@ -53,7 +53,7 @@ struct C10_API SafePyObject {
~SafePyObject() {
if (data_ != nullptr) {
(*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
(*pyinterpreter_)->decref(data_);
}
}

View File

@ -48,6 +48,30 @@ void warnDeprecatedDataPtr() {
TORCH_CHECK(false, "Cannot access data pointer of Storage that is invalid.");
}
void StorageImpl::incref_pyobject() const {
// Because intrusive_ptr incref uses relaxed memory order, we need to
// do an acquire fence to ensure that the kHasPyObject bit was
// observed before the load of the PyObject* below.
// NB: This is a no-op on x86/x86-64
std::atomic_thread_fence(std::memory_order_acquire);
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->incref(obj);
}
void StorageImpl::decref_pyobject() const {
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->decref(obj);
}
bool StorageImpl::try_incref_pyobject() const {
c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter();
if (C10_UNLIKELY(!interp)) {
return false;
}
return (*interp)->try_incref(pyobj_slot_);
}
void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr) {
// Allowlist verification.
// Only if the devicetype is in the allowlist,

View File

@ -105,6 +105,12 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
data_ptr_.clear();
}
void incref_pyobject() const override final;
void decref_pyobject() const override final;
bool try_incref_pyobject() const override final;
size_t nbytes() const {
// OK to do this instead of maybe_as_int as nbytes is guaranteed positive
TORCH_CHECK(!size_bytes_is_heap_allocated_);
@ -370,4 +376,18 @@ C10_API c10::intrusive_ptr<c10::StorageImpl> make_storage_impl(
bool resizable,
std::optional<at::Device> device_opt);
namespace detail {
#ifndef C10_MOBILE
template <class T>
struct TargetTraits<
T,
std::enable_if_t<
std::is_base_of_v<c10::StorageImpl, std::remove_cv_t<T>>>> {
static constexpr bool can_have_pyobject = true;
};
#endif
} // namespace detail
} // namespace c10

View File

@ -277,7 +277,6 @@ void TensorImpl::release_resources() {
if (storage_) {
storage_ = {};
}
pyobj_slot_.maybe_destroy_pyobj();
}
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
@ -989,6 +988,30 @@ void TensorImpl::empty_tensor_restride_symint(MemoryFormat memory_format) {
}
}
void TensorImpl::incref_pyobject() const {
// Because intrusive_ptr incref uses relaxed memory order, we need to
// do an acquire fence to ensure that the kHasPyObject bit was
// observed before the load of the PyObject* below.
// NB: This is a no-op on x86/x86-64
std::atomic_thread_fence(std::memory_order_acquire);
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->incref(obj);
}
void TensorImpl::decref_pyobject() const {
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->decref(obj);
}
bool TensorImpl::try_incref_pyobject() const {
c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter();
if (C10_UNLIKELY(!interp)) {
return false;
}
return (*interp)->try_incref(pyobj_slot_);
}
namespace impl {
namespace {

View File

@ -2178,6 +2178,12 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return &pyobj_slot_;
}
void incref_pyobject() const override final;
void decref_pyobject() const override final;
bool try_incref_pyobject() const override final;
private:
// See NOTE [std::optional operator usage in CUDA]
// We probably don't want to expose this publicly until
@ -3079,6 +3085,19 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
friend class C10_TensorImpl_Size_Check_Dummy_Class;
};
namespace detail {
#ifndef C10_MOBILE
template <class T>
struct TargetTraits<
T,
std::enable_if_t<std::is_base_of_v<c10::TensorImpl, std::remove_cv_t<T>>>> {
static constexpr bool can_have_pyobject = true;
};
#endif
} // namespace detail
// Note [TensorImpl size constraints]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Changed the size of TensorImpl? If the size went down, good for

View File

@ -11,8 +11,11 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
void incref(PyObject* pyobj) const override {} // do nothing
void decref(PyObject* pyobj, bool has_pyobj_slot) const override {
} // do nothing
void decref(PyObject* pyobj) const override {} // do nothing
bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const override {
return false;
}
#define PANIC(m) \
TORCH_INTERNAL_ASSERT( \
@ -20,6 +23,10 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
"attempted to call " #m \
" on a Tensor with nontrivial PyObject after corresponding interpreter died")
size_t refcnt(PyObject* pyobj) const override {
PANIC(refcnt);
}
c10::intrusive_ptr<TensorImpl> detach(const TensorImpl* self) const override {
PANIC(detach);
}

View File

@ -18,6 +18,9 @@ namespace c10 {
struct IValue;
class OperatorHandle;
struct TensorImpl;
namespace impl {
struct PyObjectSlot;
} // namespace impl
} // namespace c10
namespace torch::jit {
@ -126,9 +129,12 @@ struct C10_API PyInterpreterVTable {
// Run Py_INCREF on a PyObject.
virtual void incref(PyObject* pyobj) const = 0;
// Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call
// See NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
virtual void decref(PyObject* pyobj, bool has_pyobj_slot) const = 0;
// Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call.
virtual void decref(PyObject* pyobj) const = 0;
// Run PyUnstable_TryIncRef on a PyObject if it's not NULL.
virtual bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const = 0;
// Run Py_REFCNT on a PyObject.
virtual size_t refcnt(PyObject* pyobj) const = 0;
// Perform a detach by deferring to the __torch_dispatch__ implementation of
// detach, which will also arrange for the PyObject to get copied in this

View File

@ -1,56 +0,0 @@
#include <c10/core/impl/PyObjectSlot.h>
namespace c10::impl {
PyObjectSlot::PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}
PyObjectSlot::~PyObjectSlot() {
maybe_destroy_pyobj();
}
void PyObjectSlot::maybe_destroy_pyobj() {
if (owns_pyobj()) {
TORCH_INTERNAL_ASSERT(pyobj_interpreter_ != nullptr);
TORCH_INTERNAL_ASSERT(pyobj_ != nullptr);
(*pyobj_interpreter_.load(std::memory_order_acquire))
->decref(_unchecked_untagged_pyobj(), /*has_pyobj_slot*/ true);
// NB: this destructor can only be entered when there are no
// references to this C++ object (obviously), NOR any references
// to the PyObject (if there are references to the PyObject,
// then the PyObject holds an owning reference to the tensor).
// So it is OK to clear pyobj_ here as it is impossible for it to
// be used again (modulo weak reference races)
pyobj_ = nullptr; // for safety
}
}
PyInterpreter* PyObjectSlot::pyobj_interpreter() {
return pyobj_interpreter_.load(std::memory_order_acquire);
}
PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
return reinterpret_cast<PyObject*>(
reinterpret_cast<uintptr_t>(pyobj_) & ~0x1ULL);
}
PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const {
auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
if (interpreter) {
return *interpreter;
}
TORCH_CHECK(false, "cannot access PyObject for Tensor - no interpreter set");
}
bool PyObjectSlot::owns_pyobj() {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
return reinterpret_cast<uintptr_t>(pyobj_) & 1;
}
void PyObjectSlot::set_owns_pyobj(bool b) {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
pyobj_ = reinterpret_cast<PyObject*>(
reinterpret_cast<uintptr_t>(_unchecked_untagged_pyobj()) | b);
}
} // namespace c10::impl

View File

@ -8,117 +8,58 @@
#include <atomic>
namespace torch::utils {
class PyObjectPreservation;
}
namespace c10::impl {
struct C10_API PyObjectSlot {
public:
PyObjectSlot();
~PyObjectSlot();
void maybe_destroy_pyobj();
// Associate the TensorImpl with the specified PyObject, and, if necessary,
// also tag the interpreter.
//
// NB: This lives in a header so that we can inline away the switch on status
//
// NB: THIS FUNCTION CAN RAISE AN EXCEPTION. Make sure to clean up after
// PyObject if necessary!
void init_pyobj(PyObject* pyobj) {
pyobj_interpreter_.store(
getGlobalPyInterpreter(), std::memory_order_relaxed);
pyobj_ = pyobj;
}
PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}
// Query the PyObject interpreter. This may return null if there is no
// interpreter. This is racy!
PyInterpreter* pyobj_interpreter();
PyObject* _unchecked_untagged_pyobj() const;
// Test the interpreter tag. If tagged for the current interpreter, return
// a non-nullopt (but possibly null) PyObject. If (possibly) untagged,
// returns a nullopt. If it is definitely invalid, raises an error.
//
// If `ignore_hermetic_tls` is false and this function is called from a
// hermetic context (ie, `HermeticPyObjectTLS::get_state()` is true), then
// nullopt is returned. If `ignore_hermetic_tls` is true, then the hermetic
// context is ignored, allowing you to check the interpreter tag of a
// nonhermetic PyObject from within a hermetic context. This is necessary
// because there are some cases where the deallocator function of a
// nonhermetic PyObject is called from within a hermetic context, so it must
// be properly treated as a nonhermetic PyObject.
//
// NB: this lives in header so that we can avoid actually creating the
// std::optional
// @todo alban: I'm not too sure what's going on here, we can probably delete
// it but it's worthwhile making sure
std::optional<PyObject*> check_pyobj(bool ignore_hermetic_tls = false) const {
impl::PyInterpreter* interpreter =
pyobj_interpreter_.load(std::memory_order_acquire);
if (interpreter == nullptr) {
return std::nullopt;
}
if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) {
return std::nullopt;
} else {
return _unchecked_untagged_pyobj();
}
// interpreter.
PyInterpreter* pyobj_interpreter() const {
return pyobj_interpreter_.load(std::memory_order_acquire);
}
PyInterpreter& load_pyobj_interpreter() const;
PyInterpreter& load_pyobj_interpreter() const {
auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
TORCH_INTERNAL_ASSERT(
interpreter, "cannot access PyObject for Tensor - no interpreter set");
return *interpreter;
}
bool owns_pyobj();
PyObject* load_pyobj() const {
return pyobj_.load(std::memory_order_acquire);
}
void set_owns_pyobj(bool b);
void store_pyobj(PyObject* obj) {
pyobj_.store(obj, std::memory_order_release);
}
bool has_unique_reference() const {
PyObject* pyobj = load_pyobj();
return pyobj != nullptr && load_pyobj_interpreter()->refcnt(pyobj) == 1;
}
void clear() {
pyobj_.store(nullptr, std::memory_order_relaxed);
pyobj_interpreter_.store(nullptr, std::memory_order_relaxed);
}
private:
// This field contains the interpreter tag for this object. See
// Note [Python interpreter tag] for general context
//
// Note [Memory ordering on Python interpreter tag]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// What memory_order do we need when accessing this atomic? We don't
// need a single total modification order (as provided by
// memory_order_seq_cst) as pyobj_interpreter_ is monotonic: it can only
// transition from -1 to some positive integer and never changes afterwards.
// Because there is only one modification, it trivially already has a total
// modification order (e.g., we don't need fences or locked instructions on
// x86)
//
// In fact, one could make a reasonable argument that relaxed reads are OK,
// due to the presence of external locking (GIL) to ensure that interactions
// with other data structures are still correctly synchronized, so that
// we fall in the "Single-Location Data Structures" case as described in
// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf
// However, on x86, it doesn't matter if I use acquire or relaxed on the load
// as I get the same assembly in both cases. So I just use the more
// conservative acquire (which will impede compiler optimizations but I don't
// care)
// This is now always the global interpreter if the PyObject is set.
// Maybe we can remove this field some day...
std::atomic<PyInterpreter*> pyobj_interpreter_;
// This field contains a reference to a PyObject representing this Tensor.
// If pyobj is nullptr, when we transfer Tensor to Python, we allocate a new
// PyObject for it and set this field. This field does not have to be
// protected by an atomic as it is only allowed to be accessed when you hold
// the GIL, or during destruction of the tensor.
//
// When a PyObject dies, you are obligated to clear this field
// (otherwise, you will try to use-after-free the pyobj); this currently
// occurs in THPVariable_clear in torch/csrc/autograd/python_variable.cpp
//
// NB: Ordinarily, this should not be a strong reference, as if the
// PyObject owns the Tensor, this would create a reference cycle.
// However, sometimes this ownership flips. To track who owns
// who, this has a single pointer tag indicating whether or not the
// C++ object owns the PyObject (the common case, zero, means PyObject
// owns the C++ object); see _unchecked_untagged_pyobj for raw access
// or check_pyobj for checked access. See references to PyObject
// resurrection in torch/csrc/autograd/python_variable.cpp
PyObject* pyobj_;
// The PyObject representing this Tensor or nullptr. Ownership is managed
// by intrusive_ptr. By the time the PyObjectSlot is destroyed, this
// reference is already dead.
std::atomic<PyObject*> pyobj_;
friend class torch::utils::PyObjectPreservation;
};
} // namespace c10::impl

View File

@ -12,6 +12,10 @@ template <typename, typename...>
class class_;
}
namespace torch::utils {
class PyObjectPreservation;
}
namespace c10 {
class intrusive_ptr_target;
namespace raw {
@ -33,6 +37,8 @@ constexpr uint64_t kImpracticallyHugeWeakReferenceCount =
constexpr uint64_t kReferenceCountOne = 1;
constexpr uint64_t kWeakReferenceCountOne = (kReferenceCountOne << 32);
constexpr uint64_t kUniqueRef = (kReferenceCountOne | kWeakReferenceCountOne);
// Indicates whether the object has a PyObject wrapper.
constexpr uint64_t kHasPyObject = (uint64_t(1) << 63);
template <class TTarget>
struct intrusive_target_default_null_type final {
@ -55,7 +61,11 @@ inline uint32_t refcount(uint64_t combined_refcount) {
}
inline uint32_t weakcount(uint64_t combined_refcount) {
return static_cast<uint32_t>(combined_refcount >> 32);
return static_cast<uint32_t>((combined_refcount & ~kHasPyObject) >> 32);
}
inline bool has_pyobject(uint64_t combined_refcount) {
return (combined_refcount & kHasPyObject) != 0;
}
// The only requirement for refcount increment is that it happens-before
@ -66,12 +76,6 @@ inline uint64_t atomic_combined_refcount_increment(
return combined_refcount.fetch_add(inc, std::memory_order_relaxed) + inc;
}
inline uint32_t atomic_refcount_increment(
std::atomic<uint64_t>& combined_refcount) {
return detail::refcount(atomic_combined_refcount_increment(
combined_refcount, kReferenceCountOne));
}
inline uint32_t atomic_weakcount_increment(
std::atomic<uint64_t>& combined_refcount) {
return detail::weakcount(atomic_combined_refcount_increment(
@ -99,6 +103,11 @@ inline uint32_t atomic_weakcount_decrement(
combined_refcount, kWeakReferenceCountOne));
}
template <class T, class = void>
struct TargetTraits {
static constexpr bool can_have_pyobject = false;
};
} // namespace detail
/**
@ -155,6 +164,23 @@ class C10_API intrusive_ptr_target {
// we can atomically operate on both at the same time for performance
// and defined behaviors.
//
// Note [PyObject preservation for Tensor and Storages]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// intrusive_ptr has special support for preserving PyObject wrappers
// for TensorImpl and StorageImpl. The most significant bit (kHasPyObject) of
// the combined_refcount_ is used to indicate whether the object has a
// PyObject wrapper.
//
// - The PyObject, if it exists, holds a strong reference to the
// intrusive_ptr_target.
//
// - When the refcount goes from 1 to 2, we incref the PyObject.
//
// - When the refcount goes from 2 to 1, we decref the PyObject.
//
// In other words, the intrusive_ptr keeps the PyObject alive as long as there
// are other C++ references to the intrusive_ptr_target.
mutable std::atomic<uint64_t> combined_refcount_;
static_assert(sizeof(std::atomic<uint64_t>) == 8);
static_assert(alignof(std::atomic<uint64_t>) == 8);
@ -172,6 +198,8 @@ class C10_API intrusive_ptr_target {
template <typename T>
friend struct ExclusivelyOwnedTensorTraits;
friend class torch::utils::PyObjectPreservation;
protected:
// protected destructor. We never want to destruct intrusive_ptr_target*
// directly.
@ -255,6 +283,16 @@ class C10_API intrusive_ptr_target {
*/
virtual void release_resources() {}
/**
* These two methods are called when the refcount transitions between one
* and two and the object has a PyObject wrapper.
*/
virtual void incref_pyobject() const {}
virtual void decref_pyobject() const {}
virtual bool try_incref_pyobject() const {
return false;
}
uint32_t refcount(std::memory_order order = std::memory_order_relaxed) const {
return detail::refcount(combined_refcount_.load(order));
}
@ -265,6 +303,19 @@ class C10_API intrusive_ptr_target {
}
};
namespace detail {
#ifndef C10_MOBILE
template <>
struct TargetTraits<c10::intrusive_ptr_target> {
// A generic intrusive_ptr<intrusive_ptr_target> may actually be a TensorImpl
// or StorageImpl, so we have to allow for PyObject support.
static constexpr bool can_have_pyobject = true;
};
#endif
} // namespace detail
template <class TTarget, class NullType>
class weak_intrusive_ptr;
@ -314,18 +365,34 @@ class intrusive_ptr final {
void retain_() {
if (target_ != NullType::singleton()) {
uint32_t new_refcount =
detail::atomic_refcount_increment(target_->combined_refcount_);
uint64_t combined = detail::atomic_combined_refcount_increment(
target_->combined_refcount_, detail::kReferenceCountOne);
uint32_t new_refcount = detail::refcount(combined);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
new_refcount != 1,
"intrusive_ptr: Cannot increase refcount after it reached zero.");
if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
// If the refcount transitioned from 1 to 2, we need to incref the
// PyObject. In other words, we need to ensure that the PyObject stays
// alive now that we have a C++ reference to this object in addition to
// the PyObject itself.
if (C10_UNLIKELY(
detail::has_pyobject(combined) &&
detail::refcount(combined) == 2)) {
target_->incref_pyobject();
}
} else {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
!detail::has_pyobject(combined),
"TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set.");
}
}
}
void reset_() noexcept {
if (target_ != NullType::singleton()) {
if (target_->combined_refcount_.load(std::memory_order_acquire) ==
detail::kUniqueRef) {
if (is_uniquely_owned()) {
// Both counts are 1, so there are no weak references and
// we are releasing the last strong reference. No other
// threads can observe the effects of this target_ deletion
@ -337,9 +404,10 @@ class intrusive_ptr final {
auto combined_refcount = detail::atomic_combined_refcount_decrement(
target_->combined_refcount_, detail::kReferenceCountOne);
if (detail::refcount(combined_refcount) == 0) {
bool should_delete =
(combined_refcount == detail::kWeakReferenceCountOne);
uint32_t new_refcount = detail::refcount(combined_refcount);
bool has_pyobject = detail::has_pyobject(combined_refcount);
if (new_refcount == 0) {
bool should_delete = detail::weakcount(combined_refcount) == 1;
// See comment above about weakcount. As long as refcount>0,
// weakcount is one larger than the actual number of weak references.
// So we need to decrement it here.
@ -356,6 +424,18 @@ class intrusive_ptr final {
if (should_delete) {
delete target_;
}
} else if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
// If the refcount transitioned from 2 to 1, we need to decref the
// PyObject. In other words, we don't want to keep the PyObject alive if
// there are no C++ references to this object other than the PyObject
// itself.
if (C10_UNLIKELY(has_pyobject && new_refcount == 1)) {
target_->decref_pyobject();
}
} else {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
!has_pyobject,
"TargetTraits indicates that type cannot have PyObject, but refcount has PyObject bit set.");
}
}
}
@ -522,6 +602,16 @@ class intrusive_ptr final {
return use_count() == 1;
}
/**
* Stronger than unique() in that it must not have any weakrefs as well.
*/
bool is_uniquely_owned() const noexcept {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(target_ != NullType::singleton());
uint64_t combined =
target_->combined_refcount_.load(std::memory_order_acquire);
return (combined & ~detail::kHasPyObject) == detail::kUniqueRef;
}
/**
* Returns an owning (!) pointer to the underlying object and makes the
* intrusive_ptr instance invalid. That means the refcount is not decreased.
@ -932,6 +1022,7 @@ class weak_intrusive_ptr final {
if (target_ == NullType::singleton()) {
return intrusive_ptr<TTarget, NullType>();
} else {
bool increfed = false;
auto combined_refcount =
target_->combined_refcount_.load(std::memory_order_relaxed);
do {
@ -940,12 +1031,31 @@ class weak_intrusive_ptr final {
// Return nullptr.
return intrusive_ptr<TTarget, NullType>();
}
if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
if (detail::has_pyobject(combined_refcount) &&
detail::refcount(combined_refcount) == 1 && !increfed) {
// Object has a python wrapper with no other C++ references.
// We need to incref the Python object before we acquire a
// strong reference to the C++ object to avoid a situation
// where the Python object is deallocated concurrently.
if (!target_->try_incref_pyobject()) {
return intrusive_ptr<TTarget, NullType>();
}
increfed = true;
}
}
} while (!target_->combined_refcount_.compare_exchange_weak(
combined_refcount,
combined_refcount + detail::kReferenceCountOne,
std::memory_order_acquire,
std::memory_order_relaxed));
if constexpr (detail::TargetTraits<TTarget>::can_have_pyobject) {
if (increfed && detail::refcount(combined_refcount) != 1) {
target_->decref_pyobject();
}
}
return intrusive_ptr<TTarget, NullType>(
target_, raw::DontIncreaseRefcount{});
}
@ -1060,7 +1170,18 @@ namespace intrusive_ptr {
// NullType::singleton to this function
inline void incref(intrusive_ptr_target* self) {
if (self) {
detail::atomic_refcount_increment(self->combined_refcount_);
uint64_t combined = detail::atomic_combined_refcount_increment(
self->combined_refcount_, detail::kReferenceCountOne);
#ifndef C10_MOBILE
if (C10_UNLIKELY(
detail::has_pyobject(combined) &&
detail::refcount(combined) == 2)) {
self->incref_pyobject();
}
#else
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!detail::has_pyobject(combined));
#endif
}
}

View File

@ -0,0 +1,113 @@
# Device Management
## Background
Device management handles basic operations like querying how many devices are available and switching between them. Accelerator backends need to wrap their device runtime's APIs and expose them to PyTorch.
The OpenReg implementation ([`OpenRegFunctions.h/cpp`][OpenReg Device Management]) shows how to wrap a third-party runtime. These functions are used throughout the backend - by streams, events, generators, and Python bindings.
## Design
Accelerator vendors need to implement these core functions:
| Function Name | Description | Application Scenarios |
| ------------------------- | ---------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------- |
| `device_count()` | Query the total number of available devices in the system | - Application initialization<br>- Multi-device workload distribution<br>- Validating device indices before use |
| `current_device()` | Get the currently active device for the calling thread | - Debugging and logging<br>- Determining tensor placement<br>- Guard implementations |
| `set_device()` | Change the active device for subsequent operations | - Switching context between devices<br>- Initializing specific device resources<br>- Multi-GPU training loops |
| `exchange_device()` | Atomically swap device and return the previous device | - Implementing device guards<br>- Temporarily switching device context<br>- RAII-based device management |
| `maybe_exchange_device()` | Conditionally exchange device only if the index is valid (-1 OK) | - Safe device switching with optional indices<br>- Guard implementations with nullable device values |
These functions are building blocks for more complex features like streams, events, and memory management. Make sure to validate inputs and handle errors properly.
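As a hedged illustration, a vendor-side header for these core functions might look like the following; the names mirror the OpenReg example, and the real declarations live in `OpenRegFunctions.h`:
```cpp
// Illustrative only: a minimal device-management interface for a custom backend.
#include <c10/core/Device.h>

namespace my_backend {  // hypothetical namespace

// Total number of devices; returns 0 (and warns) if initialization fails.
c10::DeviceIndex device_count() noexcept;

// Device that is currently active for the calling thread.
c10::DeviceIndex current_device();

// Make `device` the active device; validates the index first.
void set_device(c10::DeviceIndex device);

// Swap the active device and return the previously active one.
c10::DeviceIndex exchange_device(c10::DeviceIndex device);

// Like exchange_device(), but -1 is accepted and treated as a no-op.
c10::DeviceIndex maybe_exchange_device(c10::DeviceIndex device);

}  // namespace my_backend
```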
## Implementation
This section shows how to implement device management using `set_device` as an example. The implementation requires:
1. C++ wrappers around the device runtime
2. Python bindings to expose the C++ functions
3. User-friendly Python APIs
### C++ Side
Wrap the device runtime's API and add error handling. The `SetDevice` function shows this pattern:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG SetDevice FUNCTION
:end-before: LITERALINCLUDE END: OPENREG SetDevice FUNCTION
:linenos:
```
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG set_device FUNCTION
:end-before: LITERALINCLUDE END: OPENREG set_device FUNCTION
:linenos:
```
### Binding
Expose the C++ functions to Python using pybind11:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
:language: c++
:start-after: LITERALINCLUDE START: MODULE SET DEVICE HELPER
:end-before: LITERALINCLUDE END: MODULE SET DEVICE HELPER
:linenos:
```
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/csrc/Module.cpp
:language: c++
:start-after: LITERALINCLUDE START: OPENREG MODULE METHODS
:end-before: LITERALINCLUDE END: OPENREG MODULE METHODS
:linenos:
:emphasize-lines: 5
```
### Python Side
Wrap the C++ bindings with user-friendly Python functions:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/openreg/__init__.py
:language: python
:start-after: LITERALINCLUDE START: PYTHON SET DEVICE FUNCTION
:end-before: LITERALINCLUDE END: PYTHON SET DEVICE FUNCTION
:linenos:
```
Here's the complete mapping from C++ to Python:
| C++ Binding Function | C++ Binding API (pybind11) | Python User API | Description |
| -------------------- | ---------------------------------------- | -------------------------------- | -------------------------------------------- |
| `_getDeviceCount` | `torch_openreg._C._get_device_count()` | `torch.openreg.device_count()` | Returns the total number of devices |
| `_getDevice` | `torch_openreg._C._get_device()` | `torch.openreg.current_device()` | Returns the current active device index |
| `_setDevice` | `torch_openreg._C._set_device(idx)` | `torch.openreg.set_device(idx)` | Sets the active device |
| `_exchangeDevice` | `torch_openreg._C._exchange_device(idx)` | N/A (internal use only) | Atomically swaps device and returns previous |
## Guard
Device guards provide automatic device switching with exception safety. They're similar to lock guards in C++ - they switch device on construction and restore it on destruction.
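As a rough sketch of that RAII pattern (illustrative only; the real guards are built on the `DeviceGuardImplInterface` shown below, and `my_backend` refers to the hypothetical functions sketched in the Design section):
```cpp
// Switch to a device on construction, restore the previous device on
// destruction, even if an exception is thrown in between.
#include <c10/core/Device.h>

class SimpleDeviceGuard {
 public:
  explicit SimpleDeviceGuard(c10::DeviceIndex new_index)
      : prev_index_(my_backend::exchange_device(new_index)) {}

  ~SimpleDeviceGuard() {
    my_backend::set_device(prev_index_);  // restore on scope exit
  }

  // Non-copyable: each guard owns exactly one restore action.
  SimpleDeviceGuard(const SimpleDeviceGuard&) = delete;
  SimpleDeviceGuard& operator=(const SimpleDeviceGuard&) = delete;

 private:
  c10::DeviceIndex prev_index_;
};
```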
Implement `DeviceGuardImplInterface` to integrate with PyTorch's guard system:
```{eval-rst}
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegGuard.h
:language: c++
:start-after: LITERALINCLUDE START: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
:end-before: LITERALINCLUDE END: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
:linenos:
```
**What needs to be implemented:**
1. **exchangeDevice()**: Switch to a new device and return the old one (used by guard constructors)
2. **getDevice()**: Get the current device
3. **setDevice()**: Set the active device
4. **Type checking**: Validate that device type matches the backend
This makes the guard available to PyTorch for the `PrivateUse1` device type. Users can then use standard PyTorch device guards with the custom backend.
[OpenReg Device Management]: https://github.com/pytorch/pytorch/blob/main/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/runtime/OpenRegFunctions.cpp "OpenReg Device Management"

View File

@ -42,6 +42,7 @@ Next, we will delve into each chapter of this guide. Each chapter focuses on a k
:glob:
:maxdepth: 1
device
hooks
autoload
operators

View File

@ -4,17 +4,12 @@
#include <c10/util/Exception.h>
void orCheckFail(
const char* func,
const char* file,
uint32_t line,
const char* msg = "");
void orCheckFail(const char* func, const char* file, uint32_t line, const char* msg = "");
#define OPENREG_CHECK(EXPR, ...) \
do { \
const orError_t __err = EXPR; \
if (__err != orSuccess) { \
orCheckFail( \
__func__, __FILE__, static_cast<uint32_t>(__LINE__), ##__VA_ARGS__); \
} \
#define OPENREG_CHECK(EXPR, ...) \
do { \
const orError_t __err = EXPR; \
if (C10_UNLIKELY(__err != orSuccess)) { \
orCheckFail(__func__, __FILE__, static_cast<uint32_t>(__LINE__), ##__VA_ARGS__); \
} \
} while (0)

View File

@ -1,3 +1,4 @@
#include <c10/util/Exception.h>
#include <include/openreg.h>
#include "OpenRegException.h"
@ -9,21 +10,22 @@ orError_t GetDeviceCount(int* dev_count) {
return orGetDeviceCount(dev_count);
}
orError_t GetDevice(c10::DeviceIndex* device) {
orError_t GetDevice(DeviceIndex* device) {
int tmp_device = -1;
auto err = orGetDevice(&tmp_device);
*device = static_cast<c10::DeviceIndex>(tmp_device);
*device = static_cast<DeviceIndex>(tmp_device);
return err;
}
orError_t SetDevice(c10::DeviceIndex device) {
// LITERALINCLUDE START: OPENREG SetDevice FUNCTION
orError_t SetDevice(DeviceIndex device) {
int cur_device = -1;
orGetDevice(&cur_device);
OPENREG_CHECK(orGetDevice(&cur_device));
if (device == cur_device) {
return orSuccess;
}
return orSetDevice(device);
}
// LITERALINCLUDE END: OPENREG SetDevice FUNCTION
int device_count_impl() {
int count = 0;
@ -31,34 +33,37 @@ int device_count_impl() {
return count;
}
OPENREG_EXPORT c10::DeviceIndex device_count() noexcept {
OPENREG_EXPORT DeviceIndex device_count() noexcept {
// initialize number of devices only once
static int count = []() {
try {
auto result = device_count_impl();
TORCH_CHECK(
result <= std::numeric_limits<c10::DeviceIndex>::max(),
result <= std::numeric_limits<DeviceIndex>::max(),
"Too many devices, DeviceIndex overflowed");
return result;
} catch (const c10::Error& ex) {
} catch (const Error& ex) {
// We don't want to fail, but still log the warning
// msg() returns the message without the stack trace
TORCH_WARN("Device initialization: ", ex.msg());
return 0;
}
}();
return static_cast<c10::DeviceIndex>(count);
return static_cast<DeviceIndex>(count);
}
OPENREG_EXPORT c10::DeviceIndex current_device() {
c10::DeviceIndex cur_device = -1;
GetDevice(&cur_device);
OPENREG_EXPORT DeviceIndex current_device() {
DeviceIndex cur_device = -1;
OPENREG_CHECK(GetDevice(&cur_device));
return cur_device;
}
OPENREG_EXPORT void set_device(c10::DeviceIndex device) {
SetDevice(device);
// LITERALINCLUDE START: OPENREG set_device FUNCTION
OPENREG_EXPORT void set_device(DeviceIndex device) {
check_device_index(device);
OPENREG_CHECK(SetDevice(device));
}
// LITERALINCLUDE END: OPENREG set_device FUNCTION
OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) {
int current_device = -1;
@ -71,4 +76,8 @@ OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) {
return current_device;
}
OPENREG_EXPORT DeviceIndex maybe_exchange_device(DeviceIndex to_device) {
check_device_index(to_device);
return ExchangeDevice(to_device);
}
} // namespace c10::openreg

View File

@ -9,10 +9,20 @@
namespace c10::openreg {
OPENREG_EXPORT c10::DeviceIndex device_count() noexcept;
OPENREG_EXPORT c10::DeviceIndex current_device();
OPENREG_EXPORT void set_device(c10::DeviceIndex device);
OPENREG_EXPORT DeviceIndex device_count() noexcept;
OPENREG_EXPORT DeviceIndex current_device();
OPENREG_EXPORT void set_device(DeviceIndex device);
OPENREG_EXPORT DeviceIndex maybe_exchange_device(DeviceIndex to_device);
OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device);
static inline void check_device_index(int64_t device) {
TORCH_CHECK(device >= 0 && device < c10::openreg::device_count(),
"The device index is out of range. It must be in [0, ",
static_cast<int>(c10::openreg::device_count()),
"), but got ",
static_cast<int>(device),
".");
}
} // namespace c10::openreg

View File

@ -2,6 +2,8 @@
namespace c10::openreg {
// LITERALINCLUDE START: OPENREG GUARD REGISTRATION
C10_REGISTER_GUARD_IMPL(PrivateUse1, OpenRegGuardImpl);
// LITERALINCLUDE END: OPENREG GUARD REGISTRATION
} // namespace c10::openreg

View File

@ -11,6 +11,7 @@
namespace c10::openreg {
// LITERALINCLUDE START: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
static constexpr DeviceType static_type = c10::DeviceType::PrivateUse1;
@ -58,6 +59,7 @@ struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
set_device(d.index());
}
// LITERALINCLUDE END: OPENREG DEVICE MGMT GUARD IMPL EXAMPLE
/**
* Set the current device to c10::Device, without checking for errors

View File

@ -27,6 +27,10 @@ class TestDevice(TestCase):
self.assertEqual(torch.accelerator.current_device_index(), 1)
self.assertEqual(torch.accelerator.current_device_index(), device)
def test_invalid_device_index(self):
with self.assertRaisesRegex(RuntimeError, "The device index is out of range"):
torch.accelerator.set_device_index(2)
if __name__ == "__main__":
run_tests()

View File

@ -34,18 +34,21 @@ static PyObject* _getDefaultGenerator(PyObject* self, PyObject* arg) {
}
// LITERALINCLUDE END: OPENREG GET DEFAULT GENERATOR
// LITERALINCLUDE START: MODULE SET DEVICE HELPER
PyObject* _setDevice(PyObject* self, PyObject* arg) {
HANDLE_TH_ERRORS
TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to setDevice");
auto device = THPUtils_unpackLong(arg);
auto device = THPUtils_unpackDeviceIndex(arg);
torch::utils::device_lazy_init(at::kPrivateUse1);
c10::openreg::set_device(static_cast<c10::DeviceIndex>(device));
c10::openreg::set_device(device);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
// LITERALINCLUDE END: MODULE SET DEVICE HELPER
PyObject* _exchangeDevice(PyObject* self, PyObject* arg) {
HANDLE_TH_ERRORS
TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to exchangeDevice");

View File

@ -41,8 +41,13 @@ def current_device():
return torch_openreg._C._get_device()
# LITERALINCLUDE START: PYTHON SET DEVICE FUNCTION
def set_device(device) -> None:
return torch_openreg._C._set_device(device)
if device >= 0:
torch_openreg._C._set_device(device)
# LITERALINCLUDE END: PYTHON SET DEVICE FUNCTION
def init():

View File

@ -952,7 +952,9 @@ User code traceback:
self.assertExpectedInline(
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
"""\
Graph break: skip: from user code at:
Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.
torch.compile will skip tracing the frame fn (test_error_messages.py line N) and fall back to eager.
The graph break occurred in the following user code:
File "test_error_messages.py", line N, in fn
assert x is None
""",
@ -1078,6 +1080,88 @@ from user code:
""",
)
@torch._dynamo.config.patch(verbose=True)
@make_logging_test(graph_breaks=True)
def test_skipped_frame_with_verbose_traceback(self, records):
def fn(x):
with GenericCtxMgr():
torch._dynamo.graph_break()
return x + 1
torch.compile(fn, backend="eager")(torch.randn(3))
self.assertEqual(len(records), 1)
self.assertExpectedInline(
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
"""\
Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.
torch.compile will skip tracing the frame fn (test_error_messages.py line N) and fall back to eager.
The graph break occurred in the following user code:
File "test_error_messages.py", line N, in fn
torch._dynamo.graph_break()
""",
)
self.assertExpectedInline(
munge_exc(records[0].exc_info[1], suppress_suffix=True, skip=0),
"""\
Graph break under GenericContextWrappingVariable
Explanation: Attempted to graph break in an active context manager(s) that doesn't support graph breaking.
Hint: Move the offending context manager(s) to outside the compiled region.
Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.
Developer debug context: Active generic context managers: [GenericContextWrappingVariable(GenericCtxMgr)]
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0066.html
from user code:
File "test_error_messages.py", line N, in fn
torch._dynamo.graph_break()
""",
)
@make_logging_test(graph_breaks=True)
def test_skip_frame_in_loop_message(self, records):
def fn(x):
for i in range(2):
with GenericCtxMgr():
if x.sum() > 0:
x = x + 1
return x
torch.compile(fn, backend="eager")(torch.randn(3))
self.assertEqual(len(records), 1)
self.assertExpectedInline(
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
"""\
Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.
torch.compile will skip tracing the frame fn (test_error_messages.py line N) and fall back to eager.
The graph break occurred in the following user code:
File "test_error_messages.py", line N, in fn
if x.sum() > 0:
""",
)
@make_logging_test(dynamo=logging.DEBUG)
def test_skip_frame_empty_function_message(self, records):
def empty_fn(x):
pass
torch.compile(empty_fn, backend="eager")(torch.randn(3))
skip_messages = [
r
for r in records
if "intentionally decided to skip the frame" in r.getMessage()
]
self.assertEqual(len(skip_messages), 1)
msg = munge_exc(skip_messages[0].getMessage(), suppress_suffix=True, skip=0)
msg = re.sub(r" (\d+)$", r" N", msg, flags=re.MULTILINE)
self.assertExpectedInline(
msg,
"""\
Skipping frame torch.compile intentionally decided to skip the frame empty_fn (test_error_messages.py line N) and fall back to eager.
Reason: no content in function call empty_fn test_error_messages.py N""",
)
@make_logging_test(graph_breaks=True)
def test_nested_compile_user_frames(self, records):
def fn(x):
@ -1624,6 +1708,110 @@ from user code:
)
class NestedGraphBreakLoggingTests(
LoggingTestCase, torch._dynamo.test_case.TestCaseWithNestedGraphBreaks
):
@make_logging_test(graph_breaks=True)
def test_skipped_frame_with_verbose_traceback_nested(self, records):
global f1, f2, f3
class GenericCtxMgr:
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
pass
def f1(x):
with GenericCtxMgr():
torch._dynamo.graph_break()
return x + 1
def f2(x):
return f1(x + 2)
def f3(x):
return f2(x + 3)
torch.compile(f3, backend="eager")(torch.randn(3))
self.assertEqual(len(records), 1)
self.assertExpectedInline(
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
"""\
Graph break in user code at test_error_messages.py:N
Graph Break Reason: Encountered graph break that we cannot resume from. Compiling up to the previous resumable state, then skipping the rest of the function. Graph break encountered:
Graph break under GenericContextWrappingVariable
Explanation: Attempted to graph break in an active context manager(s) that doesn't support graph breaking.
Hint: Move the offending context manager(s) to outside the compiled region.
Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one.
Developer debug context: Active generic context managers: [GenericContextWrappingVariable(GenericCtxMgr)]
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0066.html
User code traceback:
File "test_error_messages.py", line N, in test_skipped_frame_with_verbose_traceback_nested
torch.compile(f3, backend="eager")(torch.randn(3))
File "test_error_messages.py", line N, in f3
return f2(x + 3)
File "test_error_messages.py", line N, in f2
return f1(x + 2)
File "test_error_messages.py", line N, in f1
torch._dynamo.graph_break()
""",
)
@make_logging_test(graph_breaks=True)
def test_skip_frame_in_loop_message_nested(self, records):
global f1, f2, f3
class GenericCtxMgr:
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
pass
def f1(x):
for i in range(2):
with GenericCtxMgr():
if x.sum() > 0:
x = x + 1
return x
def f2(x):
return f1(x + 4)
def f3(x):
return f2(x + 5)
result = torch.compile(f3, backend="eager")(torch.randn(3)) # noqa: F841
self.assertEqual(len(records), 1)
self.assertExpectedInline(
munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0),
"""\
Graph break in user code at test_error_messages.py:N
Graph Break Reason: Encountered graph break that we cannot resume from. Compiling up to the previous resumable state, then skipping the rest of the function. Graph break encountered:
Data-dependent branching
Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
Hint: Use `torch.cond` to express dynamic control flow.
Developer debug context: attempted to jump with TensorVariable()
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0170.html
User code traceback:
File "test_error_messages.py", line N, in test_skip_frame_in_loop_message_nested
result = torch.compile(f3, backend="eager")(torch.randn(3)) # noqa: F841
File "test_error_messages.py", line N, in f3
return f2(x + 5)
File "test_error_messages.py", line N, in f2
return f1(x + 4)
File "test_error_messages.py", line N, in f1
if x.sum() > 0:
""",
)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -14036,6 +14036,44 @@ class DynamoOpPromotionTests(torch._dynamo.test_case.TestCase):
except Exception as e:
self.fail(f"torch.compile failed with error: {e}")
@torch._dynamo.config.patch(capture_scalar_outputs=True)
def test_tensorify_track_item_symint(self):
def _random_resize(image: torch.Tensor):
image_metanet = image
default_patch_size = 14
rand_cnn_resolution = (224, 256)
min_nump = rand_cnn_resolution[0] // default_patch_size
max_nump = rand_cnn_resolution[1] // default_patch_size
new_nump = torch.randint(min_nump, max_nump + 1, (1,)).item()
torch._check(new_nump > 0)
torch._check(new_nump * default_patch_size > 1)
image_metanet = F.interpolate(
image_metanet,
size=(new_nump * default_patch_size, new_nump * default_patch_size),
mode="bilinear",
align_corners=True,
)
img_h_new, img_w_new = image_metanet.shape[2:]
return (img_h_new, img_w_new), image_metanet
_random_resize_compiled = torch.compile(fullgraph=True)(_random_resize)
# Test the function
input_tensor = torch.rand(1, 3, 224, 224)
(h, w), output = _random_resize_compiled(input_tensor)
# Verify output properties
self.assertEqual(output.shape[0], 1)
self.assertEqual(output.shape[1], 3)
self.assertEqual(output.shape[2], h)
self.assertEqual(output.shape[3], w)
self.assertTrue(h % 14 == 0)
self.assertTrue(w % 14 == 0)
self.assertTrue(224 <= h <= 256)
self.assertTrue(224 <= w <= 256)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -10895,6 +10895,34 @@ get_out().sum().backward()
self.assertTrue(gradcheck(func, x, fast_mode=True))
def test_grad_thread_safety(self):
import threading
from concurrent.futures import ThreadPoolExecutor
NUM_ITERS = 10
NUM_THREADS = 4
# Concurrent calls to tensor.untyped_storage()
def access_grad(tensor, barrier):
barrier.wait()
return weakref.ref(tensor.grad)
for i in range(NUM_ITERS):
tensor = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
(tensor**2).sum().backward()
barrier = threading.Barrier(NUM_THREADS)
with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
futures = [
executor.submit(access_grad, tensor, barrier)
for _ in range(NUM_THREADS)
]
# Check that all the grad tensors returned were the same
for future in futures:
self.assertEqual(future.result()(), tensor.grad)
self.assertIsNotNone(tensor.grad)
def index_perm_variable(shape, max_indices):
if not isinstance(shape, tuple):

View File

@ -259,7 +259,8 @@ class TestTorchDeviceType(TestCase):
def test_storage_use_count(self, device):
a = torch.randn(10, device=device)
prev_cf = torch._C._storage_Use_Count(a.untyped_storage()._cdata)
self.assertEqual(prev_cf, 1)
# Two references: 'a' and the wrapper returned by untyped_storage()
self.assertEqual(prev_cf, 2)
b = a.view(2, 5)
self.assertEqual(torch._C._storage_Use_Count(b.untyped_storage()._cdata), prev_cf + 1)
@ -9324,7 +9325,7 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
member_var = object()
err_msg = "Creating a Tensor subclass from a class that does not inherit from Tensor"
with self.assertRaisesRegex(RuntimeError, err_msg):
with self.assertRaisesRegex(TypeError, err_msg):
s0 = t0.as_subclass(BadSubTensor)
# FIXME: Port to a test suite that better fits slicing
@ -10324,20 +10325,21 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
def test_tensor_dead_weak_ref(self):
x = torch.empty(2)
x = torch.ones(2)
w_x = weakref.ref(x)
y = torch.empty(2)
y = torch.ones(2)
y.grad = x
del x
x = w_x()
# Ideally, x would keep the tensor live. But CPython doesn't
# provide enough hooks to do this. So it will go dead and x
# will transmute into an undefined tensor. Not great, but the
# best we can do.
# x should keep the tensor live. This didn't happen in earlier PyTorch
# versions.
del y
self.assertRaises(RuntimeError, lambda: x.sigmoid())
self.assertEqual(2, x.sum())
del x
self.assertIsNone(w_x())
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1993")
def test_storage_dead_weak_ref(self):
@ -10345,16 +10347,9 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
w_x = weakref.ref(x)
y = torch.tensor(x)
del x
x = w_x()
# Ideally, x would keep the storage live. But CPython doesn't
# provide enough hooks to do this. So it will go dead and x
# will transmute into storage with null StorageImpl. Not great, but the
# best we can do.
self.assertIsNotNone(w_x())
del y
self.assertRaisesRegex(RuntimeError, "Got a null Storage", lambda: x[0])
self.assertRaisesRegex(RuntimeError, "Got a null Storage", lambda: x.float())
self.assertIsNone(w_x())
def test_tensor_resurrected_weak_ref(self):
x = torch.empty(2)
@ -10415,6 +10410,31 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
self.assertTrue(called)
def test_storage_thread_safety(self):
import threading
from concurrent.futures import ThreadPoolExecutor
NUM_ITERS = 10
NUM_THREADS = 4
# Concurrent calls to tensor.untyped_storage()
def access_untyped_storage(tensor, barrier):
barrier.wait()
return weakref.ref(tensor.untyped_storage())
for i in range(NUM_ITERS):
tensor = torch.tensor([1.0, 2.0, 3.0])
barrier = threading.Barrier(NUM_THREADS)
with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
futures = [
executor.submit(access_untyped_storage, tensor, barrier)
for _ in range(NUM_THREADS)
]
# Check that all the storages returned were the same
for future in futures:
self.assertEqual(future.result()(), tensor.untyped_storage())
# FIXME: move to test_linalg
@torch.inference_mode()
def test_bmm_multithreaded(self):

View File

@ -1870,7 +1870,7 @@ class ConvertFrame:
raise
soft_fail = isinstance(e, Unsupported)
code = frame.f_code
# This is a soft failure. In the sense, the code path reaches here
# when we do not support graph breaks on bytecodes like LOAD_ATTR,
# BUILD_SET etc. In such case, we can fallback to eager without
@ -1885,7 +1885,13 @@ class ConvertFrame:
user_stack_formatted = "".join(
traceback.format_list(user_stack)
)
user_stack_trace = f"Graph break: skip: from user code at:\n{user_stack_formatted}"
frame_info = exc.format_frame_info(code)
user_stack_trace = (
"Graph break: torch.compile cannot properly resume from this graph break, which results in a skip.\n"
f"torch.compile will skip tracing the frame {frame_info} and fall back to eager.\n"
"The graph break occurred in the following user code:\n"
f"{user_stack_formatted}"
)
torch._logging.trace_structured(
"artifact",
metadata_fn=lambda: {
@ -1897,6 +1903,7 @@ class ConvertFrame:
graph_break_log.debug(
user_stack_trace,
exc_info=True,
stack_info=config.verbose,
)
if not config.suppress_errors and not soft_fail:

View File

@ -794,6 +794,38 @@ def format_error_msg_verbose(
return msg
def format_frame_info(code: types.CodeType) -> str:
return (
f"{getattr(code, 'co_name', '<unknown>')} "
f"({getattr(code, 'co_filename', '<unknown>')} "
f"line {getattr(code, 'co_firstlineno', 0)})"
)
def format_skip_frame_message(code: Optional[types.CodeType], reason: str) -> str:
if code is not None:
frame_info = format_frame_info(code)
return (
f"torch.compile intentionally decided to skip the frame {frame_info} and fall back to eager.\n"
f"Reason: {reason}"
)
else:
return (
f"torch.compile intentionally decided to skip the frame and fall back to eager.\n"
f"Reason: {reason}"
)
def format_loop_skip_frame_message(code: types.CodeType, frame_summary: str) -> str:
frame_info = format_frame_info(code)
return (
"Skipping frame because there is a graph break in a for/while loop\n"
f"torch.compile intentionally decided to skip the frame {frame_info} and fall back to eager.\n"
f"Reason: Skipping frame because there is a graph break in a for/while loop.\n"
f"{frame_summary}"
)
def format_error_msg(
exc: Exception,
code: types.CodeType,

View File

@ -94,6 +94,8 @@ from .exc import (
BackendCompilerFailed,
collapse_resume_frames,
format_graph_break_message,
format_loop_skip_frame_message,
format_skip_frame_message,
get_stack_above_dynamo,
ResumePrologueTracingError,
StepUnsupported,
@ -605,9 +607,9 @@ def generic_jump(
)
# compile a partial subgraph prefix then jump into user code
if self.maybe_has_backedge():
msg = (
"Skipping frame because there is a graph break in a for/while loop\n"
f"{self.frame_summary()}"
msg = format_loop_skip_frame_message(
self.f_code,
"".join(traceback.format_list([self.frame_summary()])),
)
log.info(msg)
raise exc.SkipFrame(msg)
@ -883,9 +885,9 @@ def break_graph_if_unsupported(
)
if self.maybe_has_backedge():
msg = (
"Skipping frame because there is a graph break in a for/while loop\n"
f"{self.frame_summary()}"
msg = format_loop_skip_frame_message(
self.f_code,
"".join(traceback.format_list([self.frame_summary()])),
)
log.info(msg)
raise exc.SkipFrame(msg) from excp
@ -4626,8 +4628,9 @@ class InstructionTranslator(InstructionTranslatorBase):
and not self.error_on_graph_break
and not self.is_tracing_resume_prologue
):
raise exc.SkipFrame("because no content in function call")
raise exc.SkipFrame(
format_skip_frame_message(self.f_code, "no content in function call")
)
self.instruction_pointer = None
_step_logger()(
logging.INFO,

View File

@ -2248,12 +2248,15 @@ def skip_frame_if_in_functorch_mode(val: torch.Tensor) -> None:
try:
val.data_ptr() # will throw for functorch tensors
except RuntimeError as e:
from .exc import SkipFrame
from .exc import format_skip_frame_message, SkipFrame
# This will be GradTrackingTensor/BatchedTensor/etc
functorch_subclass_name = re.sub(r"\(.*", "", repr(val))
raise SkipFrame(
f"torch.compile cannot be run in context: {functorch_subclass_name}"
format_skip_frame_message(
None,
f"torch.compile cannot be run in context: {functorch_subclass_name}",
)
) from e

View File

@ -42,6 +42,7 @@ from torch._guards import Source
from .. import config, graph_break_hints, polyfills, variables
from ..bytecode_transformation import create_call_function, create_rot_n, is_generator
from ..exc import (
format_skip_frame_message,
get_dynamo_observed_exception,
handle_observed_exception,
InfiniteGeneratorError,
@ -1652,8 +1653,13 @@ class SkipFunctionVariable(VariableTracker):
skip_frame_msg = kwargs.get("msg")
if skip_frame_msg:
skip_frame_msg = skip_frame_msg.as_python_constant()
else:
skip_frame_msg = ""
raise SkipFrame(
f"Skip frame due to `torch._dynamo.skip_frame()`. Message: {skip_frame_msg}"
format_skip_frame_message(
tx.f_code,
f"Skip frame due to `torch._dynamo.skip_frame()`. Message: {skip_frame_msg}",
)
)
elif self.value is torch._dynamo.step_unsupported:
raise StepUnsupported
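Usage sketch for the explicit escape hatch handled here (a `msg=` keyword appears to be supported, per the `kwargs.get("msg")` above; treat that as an assumption):

```python
# Sketch: explicitly asking Dynamo to skip the current frame. The message,
# if provided, is folded into the SkipFrame text via format_skip_frame_message.
import torch

@torch.compile
def f(x):
    torch._dynamo.skip_frame()  # a msg="..." keyword also appears to be accepted
    return x + 1

f(torch.randn(2))  # the whole frame runs eagerly; the skip reason names f, its file and line
```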

View File

@ -536,9 +536,14 @@ class StorageWeakRefWrapper:
if self.extra_ref_check is not None and not self.extra_ref_check():
return False
# if extra_ref_check is not None we expect an additional reference
stor_count = torch._C._storage_Use_Count(self.ref.cdata)
return (stor_count - (self.extra_ref_check is not None)) == 0
if self.extra_ref_check is not None:
# if extra_ref_check is not None we expect two additional references:
# - one from the Python storage object
# - one from the cached Tensor
stor_count -= 2
assert stor_count >= 0
return stor_count == 0
def __repr__(self) -> str:
if self.ref is None or self.ref.expired():
@ -1439,7 +1444,15 @@ class CUDAGraphNode:
self_loc = self_ref()
if self_loc is None:
return False
return self_loc.get_output_refcount(i) == 2
refcount = self_loc.get_output_refcount(i)
# pyrefly: ignore
if self_loc.cached_tensor_outputs[i]._use_count() > 1:
# c10::Tensor may also hold one extra reference
assert refcount >= 3
return refcount == 3
else:
assert refcount >= 2
return refcount == 2
check = functools.partial(check_refcount, i=i)
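A condensed, self-contained sketch of the adjusted reference accounting above (counts are illustrative; the helper names are not part of the codebase):

```python
# Condensed sketch of the refcount bookkeeping above. With the reworked
# PyObject preservation, the Python storage object and the cached Tensor each
# contribute a reference, so the liveness checks subtract them explicitly.

def storage_expired(stor_count: int, has_extra_ref_check: bool) -> bool:
    if has_extra_ref_check:
        # one ref from the Python storage object + one from the cached Tensor
        stor_count -= 2
    assert stor_count >= 0
    return stor_count == 0

def output_only_held_by_graph(refcount: int, tensor_use_count: int) -> bool:
    # the c10::Tensor inside the cached output may hold one extra reference
    expected = 3 if tensor_use_count > 1 else 2
    assert refcount >= expected
    return refcount == expected
```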

View File

@ -891,10 +891,14 @@ class TorchLogsFormatter(logging.Formatter):
# exception handling - copied from logging.Formatter.format
s = record.message
if record.exc_info:
from torch._dynamo import config
should_format_exc = config.verbose or artifact_name != "graph_breaks"
# Cache the traceback text to avoid converting it multiple times
# (it's constant anyway)
if not record.exc_text:
record.exc_text = self.formatException(record.exc_info)
if should_format_exc:
if not record.exc_text:
record.exc_text = self.formatException(record.exc_info)
if record.exc_text:
if s[-1:] != "\n":
s = s + "\n"
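Net effect: tracebacks are only rendered for `graph_breaks` records when verbose logging is on, while already-cached `exc_text` is still appended. A condensed sketch (the helper name is illustrative):

```python
# Condensed sketch of the conditional exception rendering above.
import logging
import traceback

def render(record: logging.LogRecord, artifact_name: str, verbose: bool) -> str:
    s = record.getMessage()
    if record.exc_info:
        should_format_exc = verbose or artifact_name != "graph_breaks"
        if should_format_exc and not record.exc_text:
            # cache the traceback text; it is constant for this record
            record.exc_text = "".join(traceback.format_exception(*record.exc_info))
    if record.exc_text:
        if not s.endswith("\n"):
            s += "\n"
        s += record.exc_text
    return s
```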

View File

@ -398,36 +398,27 @@ static PyObject* THPModule_swap_tensor_impl(PyObject* _unused, PyObject* args) {
// weak_use_count() adds 1 if use_count is non-zero
TORCH_CHECK(
a->cdata->weak_use_count() == 1,
a->cdata.weak_use_count() == 1,
"Expected no weakrefs to t1's Tensor object but got ",
a->cdata->weak_use_count() - 1);
a->cdata.weak_use_count() - 1);
TORCH_CHECK(
b->cdata->weak_use_count() == 1,
b->cdata.weak_use_count() == 1,
"Expected no weakrefs to t2's Tensor object but got ",
b->cdata->weak_use_count() - 1);
b->cdata.weak_use_count() - 1);
// NB: Creating local copies of *both* Tensors here ensures that they each
// hold a strong reference to their PyObject. This avoids having to fix up
// reference counts when we swap the PyObject slots below.
at::Tensor tmp_a = a->cdata;
at::Tensor tmp_b = b->cdata;
// Swap the Tensor Impl
c10::MaybeOwned<at::Tensor> tmp = a->cdata;
a->cdata = tmp_b;
b->cdata = tmp_a;
// The TensorImpls contain PyObjectSlots that have a reference to the PyObject
// associated with the TensorImpl. Swap this field as well.
std::optional<PyObject*> mb_obj_a =
a->cdata->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/false);
std::optional<PyObject*> mb_obj_b =
b->cdata->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/false);
TORCH_INTERNAL_ASSERT(
mb_obj_a.has_value() && mb_obj_b.has_value(),
"Both tensors should have PyObjects tagged by the current python interpreter");
TORCH_CHECK(mb_obj_a.value() == a_);
TORCH_CHECK(mb_obj_b.value() == b_);
a->cdata = b->cdata;
b->cdata = tmp;
a->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(a_);
b->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(b_);
// Fix up the PyObjects associated with each TensorImpl
a->cdata.unsafeGetTensorImpl()->pyobj_slot()->store_pyobj(a_);
b->cdata.unsafeGetTensorImpl()->pyobj_slot()->store_pyobj(b_);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
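A usage sketch from the Python side, assuming this check backs `torch.utils.swap_tensors` (treat the exact entry point as an assumption); the weakref checks above mean neither tensor may have outstanding weak references:

```python
# Sketch: swapping two tensors' underlying implementations. Assumes the swap is
# reached via torch.utils.swap_tensors; the C++ checks above reject tensors
# that still have live weakrefs.
import torch

a = torch.zeros(3)
b = torch.ones(3)
torch.utils.swap_tensors(a, b)
print(a)  # tensor([1., 1., 1.])
print(b)  # tensor([0., 0., 0.])
```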

View File

@ -45,7 +45,9 @@ struct ConcretePyInterpreterVTable final
std::string name() const override;
void incref(PyObject* pyobj) const override;
void decref(PyObject* pyobj, bool has_pyobj_slot) const override;
void decref(PyObject* pyobj) const override;
bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const override;
size_t refcnt(PyObject* pyobj) const override;
// TODO: Need to make this work for StorageImpl too. I imagine I'll want to
// operate upon a PyObjectSlot rather than a TensorImpl
@ -235,53 +237,13 @@ py::object torchDispatchFromTensorImpl(
TorchFunctionName::TorchDispatch));
}
// NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
// Before calling PyInterpreter::decref, we must statically know if the
// pyobj has a PyObjectSlot or not.
// - If it has a PyObjectSlot, we need to be careful about PyObject resurrection
// - If it does not have a PyObjectSlot, we can freely decref
// One alternative to this is using PyObject_IsInstance
// to get at this information. However, we don't want to risk an incorrect
// `__instancecheck__` changing the semantics here.
void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot)
const {
void ConcretePyInterpreterVTable::decref(PyObject* pyobj) const {
// Leak the pyobj if not initialized. This can happen if we are running
// exit handlers that are destructing tensors with residual (owned)
// PyObjects stored in them.
if (!Py_IsInitialized())
return;
pybind11::gil_scoped_acquire gil;
// Two possibilities:
// 1. We are decref-ing an object that has a PyObjectSlot, like a Tensor or
// Storage. Then we must be careful about PyObject resurrection (see
// THPVariable_clear).
// 2. We are decref-ing some other Python object. We don't do
// PyObject resurrection on non-Tensors, so we just carry on as usual
if (has_pyobj_slot && Py_REFCNT(pyobj) > 1) {
if (THPVariable_Check(pyobj)) {
// It's still alive! This can happen if a weak ref resurrected
// the PyObject without flipping ownership. At this point it is
// too late to rescue the object, so just stub out the PyObject
// so that it fails on subsequent uses. Don't raise an error here;
// you're probably in a destructor.
TORCH_WARN(
"Deallocating Tensor that still has live PyObject references. "
"This probably happened because you took out a weak reference to "
"Tensor and didn't call _fix_weakref() after dereferencing it. "
"Subsequent accesses to this tensor via the PyObject will now fail.");
(reinterpret_cast<THPVariable*>(pyobj))->cdata =
c10::MaybeOwned<torch::autograd::Variable>();
} else if (THPStorage_Check(pyobj)) {
TORCH_WARN(
"Deallocating UntypedStorage that still has live PyObject references. "
"This probably happened because you took out a weak reference to "
"UntypedStorage and didn't call _fix_weakref() after dereferencing it. "
"Subsequent accesses to this storage via the PyObject will now fail.");
(reinterpret_cast<THPStorage*>(pyobj))->cdata =
c10::MaybeOwned<c10::Storage>();
}
}
Py_DECREF(pyobj);
}
@ -292,6 +254,25 @@ void ConcretePyInterpreterVTable::incref(PyObject* pyobj) const {
Py_INCREF(pyobj);
}
bool ConcretePyInterpreterVTable::try_incref(
const c10::impl::PyObjectSlot& pyobj_slot) const {
if (!Py_IsInitialized())
return false;
pybind11::gil_scoped_acquire gil;
PyObject* pyobj = pyobj_slot.load_pyobj();
if (!pyobj) {
return false;
}
return PyUnstable_TryIncRef(pyobj);
}
size_t ConcretePyInterpreterVTable::refcnt(PyObject* pyobj) const {
if (!Py_IsInitialized() || pyobj == nullptr)
return 0;
pybind11::gil_scoped_acquire gil;
return Py_REFCNT(pyobj);
}
bool isPythonTensor(const at::Tensor& tensor) {
return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python);
}
@ -620,11 +601,7 @@ static void set_tensor_attr_with_capsule(
const c10::TensorImpl* tensor,
py::capsule& capsule,
const char* attr_name) {
std::optional<PyObject*> mb_obj = tensor->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/false);
TORCH_CHECK(
mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
auto obj = mb_obj.value();
PyObject* obj = tensor->pyobj_slot()->load_pyobj();
py::handle(obj).attr(attr_name) = capsule;
}
@ -648,11 +625,7 @@ static c10::ArrayRef<T> get_set_cached_attr(
const c10::TensorImpl* tensor,
const char* base_attr_name,
const py::object& obj) {
std::optional<PyObject*> mb_obj =
tensor->pyobj_slot()->check_pyobj(getPyInterpreter());
TORCH_CHECK(
mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value");
auto tensor_obj = mb_obj.value();
PyObject* tensor_obj = tensor->pyobj_slot()->load_pyobj();
auto buffer_len_attr_name = std::string(base_attr_name) + std::string("_len");
bool is_buffer_allocated = false;

View File

@ -23,6 +23,8 @@
#include <c10/util/intrusive_ptr.h>
#include <fmt/format.h>
using torch::utils::PyObjectPreservation;
template <>
void THPPointer<c10::StorageImpl>::free() {
if (ptr) {
@ -32,238 +34,72 @@ void THPPointer<c10::StorageImpl>::free() {
PyTypeObject* THPStorageClass = nullptr;
PyObject* THPStorage_NewWithStorage(
PyTypeObject* type,
c10::Storage _storage,
bool allow_preexisting_pyobj) {
TORCH_CHECK(
PyType_IsSubtype(type, &THPStorageType),
"Creating a Storage subclass from a class that does not inherit from ",
"Storage is not possible. Make sure your class inherits from Storage.");
auto maybe_pyobj = _storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/false);
if (maybe_pyobj.has_value() && maybe_pyobj.value()) {
TORCH_CHECK(
allow_preexisting_pyobj,
"Creating a new Storage subclass ",
type->tp_name,
" but the raw Storage object is already associated to a python object ",
"of type ",
maybe_pyobj.value()->ob_type->tp_name);
PyObject* obj = *maybe_pyobj;
PyTypeObject* obj_type = Py_TYPE(obj);
TORCH_CHECK(
obj_type == type || PyType_IsSubtype(obj_type, type),
"Creating a new Storage subclass ",
type->tp_name,
" but the raw Storage object is already associated to a python object ",
"of type ",
maybe_pyobj.value()->ob_type->tp_name,
" which is not a subclass of the "
"requested type");
return THPStorage_Wrap(std::move(_storage));
}
// Create a new Python Storage object, but don't set the pyobj slot on the
// c10::Storage object.
static PyObject* THPStorage_New(PyTypeObject* type, c10::Storage _storage) {
PyObject* obj = type->tp_alloc(type, 0);
TORCH_CHECK(obj, "Failed to allocate a ", type->tp_name, " object");
auto s = reinterpret_cast<THPStorage*>(obj);
new (&s->cdata) c10::MaybeOwned<c10::Storage>();
s->cdata = c10::MaybeOwned<c10::Storage>::owned(std::move(_storage));
if (!c10::impl::HermeticPyObjectTLS::get_state()) {
s->is_hermetic = false;
const auto& storage = THPStorage_Unpack(s);
storage.unsafeGetStorageImpl()->pyobj_slot()->init_pyobj(obj);
} else {
s->is_hermetic = true;
}
// Ensure that PyUnstable_TryIncRef calls don't fail spuriously in
// free-threaded Python.
PyUnstable_EnableTryIncRef(obj);
auto s = (THPStorage*)obj;
new (&s->cdata) c10::Storage(std::move(_storage));
return obj;
}
// Wraps the c10::Storage with a storage PyObject
// Create a new Python Storage object for a new c10::Storage, and set the
// pyobj slot. The c10::Storage must not already have a pyobj set.
PyObject* THPStorage_NewWithStorage(PyTypeObject* type, c10::Storage _storage) {
TORCH_CHECK(
type == THPStorageClass || PyType_IsSubtype(type, &THPStorageType),
"Creating a Storage subclass from a class that does not inherit from ",
"Storage is not possible. Make sure your class inherits from Storage.");
TORCH_INTERNAL_ASSERT(_storage.use_count() == 1);
c10::StorageImpl* storage_impl = _storage.unsafeGetStorageImpl();
PyObject* obj = THPStorage_New(type, std::move(_storage));
PyObjectPreservation::init_fresh_nonatomic(
storage_impl, storage_impl->pyobj_slot(), obj);
return obj;
}
// Returns a PyObject wrapper for the c10::Storage object. The existing
// wrapper is returned if it already exists.
PyObject* THPStorage_Wrap(c10::Storage storage) {
c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl();
if (c10::impl::HermeticPyObjectTLS::get_state()) {
return THPStorage_NewWithStorage(THPStorageClass, std::move(storage));
return THPStorage_New(THPStorageClass, std::move(storage));
}
c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl();
c10::impl::PyObjectSlot* pyobj_slot = storage_impl->pyobj_slot();
std::optional<PyObject*> maybe_pyobj = pyobj_slot->check_pyobj(
/*ignore_hermetic_tls=*/false);
if (maybe_pyobj.has_value()) {
auto obj = *maybe_pyobj;
if (obj) {
TORCH_CHECK(
THPStorage_Check(obj),
"Expected a storage type, but got ",
Py_TYPE(obj)->tp_name);
if (pyobj_slot->owns_pyobj()) {
pyobj_slot->set_owns_pyobj(false);
reinterpret_cast<THPStorage*>(obj)->cdata =
c10::MaybeOwned<c10::Storage>::owned(std::move(storage));
return obj;
} else {
Py_INCREF(obj);
return obj;
}
}
PyObject* obj = pyobj_slot->load_pyobj();
if (obj) {
return Py_NewRef(obj);
}
return THPStorage_NewWithStorage(THPStorageClass, std::move(storage));
obj = THPStorage_New(THPStorageClass, std::move(storage));
PyObject* wrapper =
PyObjectPreservation::init_once(storage_impl, pyobj_slot, obj);
if (wrapper != obj) {
// Another thread beat us to it
Py_DECREF(obj);
return Py_NewRef(wrapper);
}
return obj;
}
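At the Python level the wrap path means repeated wrapping of a live StorageImpl hands back the already-created object (with `init_once` resolving races between threads). A small check, assuming those semantics:

```python
# Small check (assuming the preservation semantics above): while the
# StorageImpl is alive and already has a PyObject, re-wrapping returns the
# same Python storage object rather than allocating a new one.
import torch

t = torch.randn(4)
s1 = t.untyped_storage()
s2 = t.untyped_storage()
assert s1 is s2
```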
static bool THPStorage_isPreservable(THPStorage* self) {
if (self->cdata.unsafeIsBorrowed()) {
return false;
}
auto const& storage = THPStorage_Unpack(self);
if (self->is_hermetic) {
return false;
}
if (storage.unsafeGetStorageImpl()->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/true) != reinterpret_cast<PyObject*>(self)) {
return false;
}
if (storage.use_count() <= 1) {
return false;
}
return true;
}
static bool THPStorage_tryPreserve(THPStorage* self) {
if (!THPStorage_isPreservable(self)) {
return false;
}
const auto& storage = THPStorage_Unpack(self);
c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl();
auto maybe_pyobj = storage_impl->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/true);
// NOTE: It is possible to just set the PyObjectSlot here, but the point is
// that we should have already set PyObjectSlot when the storage PyObject
// was created.
TORCH_INTERNAL_ASSERT(
maybe_pyobj.has_value(),
"Trying to preserve a Python storage whose PyObjectSlot does not have a PyObject");
PyObject* pyobj = *maybe_pyobj;
TORCH_CHECK(
THPStorage_Check(pyobj),
"Expected a storage type, but got ",
Py_TYPE(pyobj)->tp_name);
TORCH_INTERNAL_ASSERT(
(void*)pyobj == (void*)self,
"Python storage and the PyObject in the internal PyObjectSlot are not at the same address");
TORCH_INTERNAL_ASSERT(!storage_impl->pyobj_slot()->owns_pyobj());
storage_impl->pyobj_slot()->set_owns_pyobj(true);
// When resurrecting, we MUST use _Py_NewReference and not Py_INCREF to
// ensure the PyObject is in a valid state
_Py_NewReference(reinterpret_cast<PyObject*>(self));
self->cdata = c10::MaybeOwned<c10::Storage>::borrowed(storage);
return true;
}
static void THPStorage_subclass_dealloc(PyObject* self) {
static void THPStorage_dealloc(PyObject* self) {
THPStorage* _self = reinterpret_cast<THPStorage*>(self);
if (THPStorage_tryPreserve(_self)) {
return;
auto pyobj_slot = _self->cdata.unsafeGetStorageImpl()->pyobj_slot();
if (pyobj_slot->load_pyobj() == self) {
TORCH_INTERNAL_ASSERT(_self->cdata.use_count() == 1);
pyobj_slot->clear();
}
// Some subclass of StorageBase could be GC-tracked objects even
// though the base class is not
auto* type = Py_TYPE(self);
if (PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC) != 0) {
PyObject_GC_UnTrack(self);
}
bool has_finalizer = type->tp_finalize || type->tp_del;
if (type->tp_finalize) {
PyObject_GC_Track(self);
if (PyObject_CallFinalizerFromDealloc(self) < 0) {
// The finalizer has resurrected the PyObject and there is a new Python
// reference to it, so we can just stop deallocating. Read about
// resurrection from `__del__` here:
// https://docs.python.org/3/reference/datamodel.html#object.__del__
return;
}
PyObject_GC_UnTrack(self);
}
// base test is unnecessary as THPStorage does not set this
if (type->tp_weaklistoffset) {
PyObject_ClearWeakRefs(self);
}
if (type->tp_del) {
PyObject_GC_Track(self);
type->tp_del(self);
if (Py_REFCNT(self) > 0) {
// Resurrected (see above comment about resurrection from `__del__`)
return;
}
PyObject_GC_UnTrack(self);
}
if (has_finalizer) {
/* New weakrefs could be created during the finalizer call.
If this occurs, clear them out without calling their
finalizers since they might rely on part of the object
being finalized that has already been destroyed. */
if (type->tp_weaklistoffset) {
/* Modeled after GET_WEAKREFS_LISTPTR() */
PyWeakReference** list = reinterpret_cast<PyWeakReference**>(
PyObject_GET_WEAKREFS_LISTPTR(self));
while (*list)
_PyWeakref_ClearRef(*list);
}
}
// Clear slots
{
PyTypeObject* base = type;
while (base != &THPStorageType) {
if (Py_SIZE(base)) {
clear_slots(base, self);
}
base = base->tp_base;
TORCH_INTERNAL_ASSERT(base);
}
}
// Clear __dict__
if (C10_LIKELY(type->tp_dictoffset)) {
PyObject** dictptr = _PyObject_GetDictPtr(self);
if (dictptr != nullptr) {
PyObject* dict = *dictptr;
if (dict != nullptr) {
Py_DECREF(dict);
*dictptr = nullptr;
}
}
}
TORCH_INTERNAL_ASSERT(Py_TYPE(self) == type);
_self->cdata.~MaybeOwned<c10::Storage>();
_self->cdata.~Storage();
Py_TYPE(_self)->tp_free(self);
TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
Py_DECREF(type);
}
static PyObject* THPStorage_pynew(
@ -553,64 +389,13 @@ static PyMappingMethods THPStorage_mappingmethods = {
reinterpret_cast<binaryfunc>(THPStorage_get),
reinterpret_cast<objobjargproc>(THPStorage_set)};
struct THPStorageMeta {
PyHeapTypeObject base;
};
static int THPStorageMetaType_init(
PyObject* cls,
PyObject* args,
PyObject* kwargs);
static PyTypeObject THPStorageMetaType = {
PyVarObject_HEAD_INIT(DEFERRED_ADDRESS(&PyType_Type), 0)
"torch._C._StorageMeta", /* tp_name */
sizeof(THPStorageMeta), /* tp_basicsize */
0, /* tp_itemsize */
nullptr, /* tp_dealloc */
0, /* tp_vectorcall_offset */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
nullptr, /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
nullptr, /* tp_hash */
nullptr, /* tp_call */
nullptr, /* tp_str */
nullptr, /* tp_getattro */
nullptr, /* tp_setattro */
nullptr, /* tp_as_buffer */
// NOLINTNEXTLINE(misc-redundant-expression)
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
nullptr, /* tp_doc */
nullptr, /* tp_traverse */
nullptr, /* tp_clear */
nullptr, /* tp_richcompare */
0, /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
nullptr, /* tp_methods */
nullptr, /* tp_members */
nullptr, /* tp_getset */
DEFERRED_ADDRESS(&PyType_Type), /* tp_base */
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */
nullptr, /* tp_descr_set */
0, /* tp_dictoffset */
THPStorageMetaType_init, /* tp_init */
nullptr, /* tp_alloc */
nullptr, /* tp_new */
};
// TODO: implement equality
PyTypeObject THPStorageType = {
PyVarObject_HEAD_INIT(&THPStorageMetaType, 0)
PyVarObject_HEAD_INIT(DEFERRED_ADDRESS(&PyType_Type), 0)
"torch._C.StorageBase", /* tp_name */
sizeof(THPStorage), /* tp_basicsize */
0, /* tp_itemsize */
nullptr, /* tp_dealloc */
THPStorage_dealloc, /* tp_dealloc */
0, /* tp_vectorcall_offset */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
@ -649,15 +434,6 @@ PyTypeObject THPStorageType = {
THPStorage_pynew, /* tp_new */
};
int THPStorageMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) {
if (PyType_Type.tp_init(cls, args, kwargs) < 0) {
return -1;
}
(reinterpret_cast<PyTypeObject*>(cls))->tp_dealloc =
static_cast<destructor>(THPStorage_subclass_dealloc);
return 0;
}
static PyObject* THPStorage_device(THPStorage* self, void* unused) {
HANDLE_TH_ERRORS
THPStorage_assertNotNull(self);
@ -692,13 +468,6 @@ bool THPStorage_init(PyObject* module) {
THPUtils_addPyMethodDefs(methods, THPStorage_getMethods());
THPUtils_addPyMethodDefs(methods, THPStorage_getSharingMethods());
THPStorageMetaType.tp_base = &PyType_Type;
if (PyType_Ready(&THPStorageMetaType) < 0)
return false;
Py_INCREF(&THPStorageMetaType);
PyModule_AddObject(
module, "_StorageMeta", reinterpret_cast<PyObject*>(&THPStorageMetaType));
THPStorageType.tp_methods = methods.data();
THPStorageType.tp_getset = THPStorage_properties;
if (PyType_Ready(&THPStorageType) < 0)

View File

@ -11,15 +11,13 @@
struct THPStorage {
PyObject_HEAD
c10::MaybeOwned<c10::Storage> cdata;
bool is_hermetic;
c10::Storage cdata;
};
TORCH_PYTHON_API PyObject* THPStorage_Wrap(c10::Storage storage);
TORCH_PYTHON_API PyObject* THPStorage_NewWithStorage(
PyTypeObject* type,
c10::Storage _storage,
bool allow_preexisting_pyobj = false);
c10::Storage _storage);
TORCH_PYTHON_API extern PyTypeObject* THPStorageClass;
inline bool THPStorage_CheckTypeExact(PyTypeObject* tp) {
@ -49,7 +47,7 @@ TORCH_PYTHON_API void THPStorage_assertNotNull(PyObject* obj);
TORCH_PYTHON_API extern PyTypeObject THPStorageType;
inline const c10::Storage& THPStorage_Unpack(THPStorage* storage) {
return *storage->cdata;
return storage->cdata;
}
inline const c10::Storage& THPStorage_Unpack(PyObject* obj) {

View File

@ -529,9 +529,8 @@ static PyObject* THPStorage__setCdata(PyObject* _self, PyObject* new_cdata) {
THPUtils_typename(new_cdata));
c10::StorageImpl* ptr =
static_cast<c10::StorageImpl*>(PyLong_AsVoidPtr(new_cdata));
self->cdata.~MaybeOwned<c10::Storage>();
self->cdata = c10::MaybeOwned<c10::Storage>::owned(
c10::Storage(c10::intrusive_ptr<c10::StorageImpl>::reclaim_copy(ptr)));
self->cdata =
c10::Storage(c10::intrusive_ptr<c10::StorageImpl>::reclaim_copy(ptr));
Py_INCREF(self);
return reinterpret_cast<PyObject*>(self);
END_HANDLE_TH_ERRORS

View File

@ -180,7 +180,9 @@ struct TORCH_API AccumulateGrad : public Node {
if (!GradMode::is_enabled() && !new_grad.is_sparse() &&
!new_grad.is_sparse_csr() &&
!(variable.is_sparse_csr() && new_grad.layout() == at::kStrided) &&
at::caching::adjusted_use_count(new_grad) <= num_expected_refs &&
impl::is_tensor_stealable(
new_grad,
num_expected_refs + at::caching::is_cached_tensor(new_grad)) &&
(new_grad.is_mkldnn() ||
utils::obeys_layout_contract(new_grad, variable))) {
// See Case 1.1: Stealable dense new_grad
@ -193,7 +195,7 @@ struct TORCH_API AccumulateGrad : public Node {
// SparseTensor should be the only one holding a reference to these.
new_grad._indices().use_count() <= 1 &&
new_grad._values().use_count() <= 1 &&
new_grad.use_count() <= num_expected_refs) {
impl::is_tensor_stealable(new_grad, num_expected_refs)) {
// Case 1.2: Stealable sparse new_grad
// No scenario where we expect this to be true currently
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(

View File

@ -86,8 +86,8 @@ bool can_accumulate_inplace(const Variable& v) {
v.is_non_overlapping_and_dense() &&
// and we hold the last reference
at::caching::adjusted_use_count(v) == 1 && v.has_storage() &&
v.storage().use_count() == 1);
impl::is_tensor_stealable(v, 1 + at::caching::is_cached_tensor(v)) &&
v.has_storage() && v.storage().use_count() == 1);
}
} // anonymous namespace

View File

@ -54,6 +54,7 @@
using namespace at;
using namespace torch;
using namespace torch::autograd;
using torch::utils::PyObjectPreservation;
namespace {
class OperatorArgsKwargsView {
@ -321,20 +322,15 @@ PyObject* THPVariableClass = nullptr;
PyObject* ParameterClass = nullptr;
static PyObject* THPVariable_NewWithVar(
PyTypeObject* type,
const at::TensorBase& _var,
bool allow_preexisting_pyobj = false,
std::optional<bool> has_torch_dispatch_if_known = std::nullopt);
// clang-tidy gets confused by static const
static constexpr const char* VOLATILE_WARNING =
"volatile was removed and now has no effect. Use "
"`with torch.no_grad():` instead.";
static void TORCH_CHECK_TENSOR_SUBTYPE(PyObject* cls);
static bool check_has_torch_dispatch(PyObject* obj) {
PyTypeObject* tp = Py_TYPE(obj);
if (THPVariable_CheckTypeExact(tp)) {
if (THPVariable_CheckExact(obj)) {
return false;
}
py::object attr = PyObject_FastGetAttrString(obj, "__torch_dispatch__");
@ -370,152 +366,86 @@ void activateGPUTrace() {
c10::impl::GPUTrace::set_trace(getPyInterpreter());
}
PyObject* THPVariable_Wrap(const at::TensorBase& var) {
static void check_tensor_subclass(PyObject* obj, PyTypeObject* type) {
TORCH_CHECK(
PyObject_TypeCheck(obj, type),
"Creating a new Tensor subclass ",
type->tp_name,
" but the raw Tensor object is already associated to a python object ",
"of type ",
Py_TYPE(obj)->tp_name,
" which is not a subclass of the requested type");
}
// Generic for const Tensor& or Tensor&&
template <typename T>
static PyObject* THPVariable_WrapWithType(
T&& var,
std::optional<PyTypeObject*> desired_type) {
if (!var.defined()) {
Py_RETURN_NONE;
}
if (c10::impl::HermeticPyObjectTLS::get_state()) {
return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var);
}
c10::TensorImpl* tensor_impl = var.unsafeGetTensorImpl();
c10::impl::PyObjectSlot* pyobj_slot = tensor_impl->pyobj_slot();
std::optional<PyObject*> mb_obj =
var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/false);
if (mb_obj.has_value()) {
auto obj = *mb_obj;
if (obj) {
if (var.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) {
// C++ owns the Python object; this implies there weren't any other
// owning references to the Python object. Since we're making the
// object "live" again on Python side, let's flip back the ownership
// (Python owns C++) as it would now be unsound to deallocate the C++
// object if all C++ references go to zero
var.unsafeGetTensorImpl()->pyobj_slot()->set_owns_pyobj(false);
reinterpret_cast<THPVariable*>(obj)->cdata =
MaybeOwned<Variable>::owned(Variable(var));
// NB: incref is not necessary, because we are "stealing" the previous
// ownership from the Variable to return it here for the wrap
return obj;
}
Py_INCREF(obj);
return obj;
PyObject* obj = pyobj_slot->load_pyobj();
if (obj) {
if (desired_type) {
check_tensor_subclass(obj, *desired_type);
}
// TODO: a better invariant is that if we tagged, we MUST have a valid
// PyObject. That's PyObject preservation
// (https://github.com/pytorch/pytorch/pull/56017). Prior to this PR
// being a thing, the PyObject field will get cleared when all references
// to the Python object are removed.
return Py_NewRef(obj);
}
if (C10_LIKELY(var.device().type() != c10::kXLA)) {
return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var);
PyTypeObject* type = reinterpret_cast<PyTypeObject*>(THPVariableClass);
if (desired_type) {
type = *desired_type;
} else if (C10_UNLIKELY(var.device().type() == c10::kXLA)) {
if (auto clazz = getPythonTensorClass(var.device())) {
type = reinterpret_cast<PyTypeObject*>(clazz);
}
}
if (auto clazz = getPythonTensorClass(var.device())) {
return THPVariable_NewWithVar((PyTypeObject*)clazz, var);
obj = type->tp_alloc(type, 0);
TORCH_CHECK(obj, "Failed to allocate a ", type->tp_name, " object");
// Ensure that PyUnstable_TryIncRef calls don't fail spuriously in
// free-threaded Python.
PyUnstable_EnableTryIncRef(obj);
auto v = reinterpret_cast<THPVariable*>(obj);
new (&v->cdata) Tensor(std::forward<T>(var));
if (THPVariable_Unpack(obj).is_uniquely_owned()) {
// We can use a faster non-atomic code path if we have the only reference to
// a fresh Tensor.
PyObjectPreservation::init_fresh_nonatomic(tensor_impl, pyobj_slot, obj);
return obj;
}
return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var);
PyObject* wrapper =
PyObjectPreservation::init_once(tensor_impl, pyobj_slot, obj);
if (wrapper != obj) {
// Another thread beat us to it
Py_DECREF(obj);
if (desired_type) {
check_tensor_subclass(wrapper, *desired_type);
}
return Py_NewRef(wrapper);
}
return obj;
}
static bool isResurrectable(THPVariable* self) {
// We want to divide this check into 2 cases.
// 1. C++ owns PyObject (in this case, self->cdata.unsafeIsBorrowed() is
// true). You might think that in this case, it is impossible for tp_clear to
// be called: surely the C++ reference to the PyObject is keeping it live? And
// you'd be right! In fact, when C++ owns the PyObject, we have an invariant
// that the refcount on the PyObject should be precisely one (because if you
// take out another reference to the PyObject, we're supposed to flip the
// ownership pointer back). In reality, you can violate this invariant
// temporarily with weak references, so we don't test for it in asserts.
// 2. PyObject owns C++ (in this case, self->cdata.unsafeIsBorrowed() is
// false). In this case, tp_clear can get called if the PyObject is referenced
// from a dead cycle, and nowhere else. But if resurrection did not occur,
// then the reference to C++ from the PyObject must be the ONLY reference to
// the C++ object.
if (self->cdata.unsafeIsBorrowed()) {
return false;
}
auto const& tensor = THPVariable_Unpack(self);
if (!tensor.defined() || tensor.use_count() <= 1) {
return false;
}
// Check if this is hermetic. If it is, no resurrection.
if (tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/false) != (PyObject*)self) {
return false;
}
return true;
PyObject* THPVariable_Wrap(at::TensorBase&& var) {
return THPVariable_WrapWithType(std::move(var), std::nullopt);
}
// returns true if successfully rezzed; if so, cancel the
// rest of deallocation
static bool THPVariable_tryResurrect(THPVariable* self) {
const auto& tensor = THPVariable_Unpack(self);
if (!isResurrectable(self)) {
return false;
}
// At this point, we are definitely going to resurrect the tensor. So, the
// tensor better be defined :)
TORCH_INTERNAL_ASSERT(tensor.defined());
// There are other C++ owners of the tensor. Flip ownership
// so that C++ owns this Python object, and cancel deallocation.
TORCH_INTERNAL_ASSERT(
!tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj());
c10::TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
auto maybe_pyobj = tensor_impl->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/false);
TORCH_INTERNAL_ASSERT(
maybe_pyobj.has_value(),
"Trying to preserve a Python tensor whose PyObjectSlot does not have a PyObject");
tensor_impl->pyobj_slot()->set_owns_pyobj(true);
// Resurrect the Python object. This is something CPython does
// internally occasionally, see
// https://github.com/python/cpython/blob/b98eba5bc2ffbe7a0ed49d540ebc4f756ae61985/Objects/object.c#L248-L259
// so we just copy the pattern here. Note that we don't have to worry
// about saving and restoring the refcount (as the quoted code does)
// because we actually DO need to reset the refcount to one here, we
// can't assume that some other code has taken care of it.
// NB: this will overreport _Py_RefTotal but based on inspection of object.c
// there is no way to avoid this
// When resurrecting, we MUST use _Py_NewReference and not Py_INCREF to
// ensure the PyObject is in a valid state
_Py_NewReference((PyObject*)self);
// Flip THPVariable to be non-owning
// (near use-after-free miss here: fresh MaybeOwned is created breaking
// reference on Tensor in struct BEFORE we overwrite the old one)
TORCH_INTERNAL_ASSERT(!c10::impl::HermeticPyObjectTLS::get_state());
self->cdata = MaybeOwned<Variable>::borrowed(tensor);
// NB: At this point, tensor *could* be dead (e.g., some other C++ thread
// decrefed it.) At this point, it is probably waiting on the GIL to
// deallocate the Python object and will kill self, BUT NOT YET.
return true;
PyObject* THPVariable_Wrap(const at::TensorBase& var) {
return THPVariable_WrapWithType(var, std::nullopt);
}
static int THPFake_traverse(THPVariable* self, visitproc visit, void* arg) {
TORCH_INTERNAL_ASSERT(
false, "TensorBase tp_traverse function was not overridden properly");
return 0;
}
static int THPFake_clear(THPVariable* self) {
TORCH_INTERNAL_ASSERT(
false, "TensorBase tp_clear function was not overridden properly");
return 0;
PyObject* THPVariable_Wrap(const at::TensorBase& var, PyTypeObject* type) {
return THPVariable_WrapWithType(var, type);
}
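The same re-use applies to Tensors: once a TensorImpl has a PyObject in its slot, `THPVariable_Wrap` hands it back instead of allocating. A quick Python-level check of that behavior:

```python
# Quick check: accessing .grad twice wraps the same C++ tensor, and the second
# wrap returns the PyObject already stored in the TensorImpl's slot.
import torch

x = torch.randn(3, requires_grad=True)
x.sum().backward()
g1 = x.grad
g2 = x.grad
assert g1 is g2  # same wrapper object on both accesses
```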
static PyObject* THPVariable_pynew(
@ -677,16 +607,16 @@ static PyObject* THPVariable_as_subclass(
ParsedArgs<1> parsed_args{};
auto r = parser.parse(_self, args, kwargs, parsed_args);
PyObject* cls = r.pyobject(0);
TORCH_CHECK_TYPE(
PyType_Check(cls),
"cls must be a type (got ",
Py_TYPE(cls)->tp_name,
")");
TORCH_CHECK_TENSOR_SUBTYPE(cls);
// guard completely turns off torch dispatch modes, doesn't just pop off the
// stack
torch_dispatch_mode::StashTorchDispatchStackGuard td_g;
c10::impl::DisablePythonDispatcher dpd_g;
return THPVariable_NewWithVar((PyTypeObject*)cls, self.alias());
PyObject* obj = THPVariable_WrapWithType(self.alias(), (PyTypeObject*)cls);
if (check_has_torch_dispatch(obj)) {
THPVariable_Unpack(obj).unsafeGetTensorImpl()->set_python_dispatch(true);
}
return obj;
END_HANDLE_TH_ERRORS
}
@ -701,11 +631,7 @@ static PyObject* THPVariable_make_subclass(
ParsedArgs<7> parsed_args{};
auto r = parser.parse(args, kwargs, parsed_args);
PyObject* cls = r.pyobject(0);
TORCH_CHECK_TYPE(
PyType_Check(cls),
"cls must be a type (got ",
Py_TYPE(cls)->tp_name,
")");
TORCH_CHECK_TENSOR_SUBTYPE(cls);
// guard completely turns off torch dispatch modes, doesn't just pop off the
// stack
torch_dispatch_mode::StashTorchDispatchStackGuard td_g;
@ -738,7 +664,11 @@ static PyObject* THPVariable_make_subclass(
data.unsafeGetTensorImpl()->_change_backend_component_keys(r.device(6));
}
return THPVariable_NewWithVar((PyTypeObject*)cls, data);
PyObject* obj = THPVariable_WrapWithType(data, (PyTypeObject*)cls);
if (check_has_torch_dispatch(obj)) {
THPVariable_Unpack(obj).unsafeGetTensorImpl()->set_python_dispatch(true);
}
return obj;
END_HANDLE_TH_ERRORS
}
@ -835,11 +765,7 @@ static PyObject* THPVariable_make_wrapper_subclass(
auto r = parser.parse(args, kwargs, parsed_args);
PyObject* cls = r.pyobject(0);
TORCH_CHECK_TYPE(
PyType_Check(cls),
"cls must be a type (got ",
Py_TYPE(cls)->tp_name,
")");
TORCH_CHECK_TENSOR_SUBTYPE(cls);
// This is an important safety check; without it, the default behavior will be
// to continue on to the underlying CPU/CUDA kernel advertised by the dispatch
@ -877,6 +803,8 @@ static PyObject* THPVariable_make_wrapper_subclass(
/*storage_size=*/r.toSymIntOptional(14),
r.toDispatchKeySetOptional(13));
tensor.unsafeGetTensorImpl()->set_python_dispatch(true);
const auto sizes_strides_policy = r.stringViewOptional(10);
if (sizes_strides_policy.has_value()) {
tensor.unsafeGetTensorImpl()->set_python_custom_sizes_strides(
@ -892,13 +820,7 @@ static PyObject* THPVariable_make_wrapper_subclass(
tensor.unsafeGetTensorImpl()->set_python_custom_layout(true);
}
return THPVariable_NewWithVar(
(PyTypeObject*)cls,
tensor,
// false is the default
/*allow_preexisting_pyobj=*/false,
// we checked __torch_dispatch__ above; avoid checking again.
/*has_torch_dispatch_if_known=*/true);
return THPVariable_WrapWithType(std::move(tensor), (PyTypeObject*)cls);
END_HANDLE_TH_ERRORS
}
@ -1699,11 +1621,7 @@ static PyObject* THPVariable_dtensor_new(
auto r = parser.parse(args, kwargs, parsed_args);
PyObject* cls = r.pyobject(0);
TORCH_CHECK_TYPE(
PyType_Check(cls),
"cls must be a type (got ",
Py_TYPE(cls)->tp_name,
")");
TORCH_CHECK_TENSOR_SUBTYPE(cls);
#ifndef NDEBUG
// This is specifically for making a DTensor, which we know defines
@ -1756,14 +1674,9 @@ static PyObject* THPVariable_dtensor_new(
/*storage_size=*/std::nullopt,
extra_dispatch_keys);
tensor.set_requires_grad(requires_grad);
py::object py_tensor =
py::reinterpret_steal<py::object>(THPVariable_NewWithVar(
(PyTypeObject*)cls,
tensor,
// false is the default
/*allow_preexisting_pyobj=*/false,
// we know DTensor has __torch_dispatch__; avoid checking again.
/*has_torch_dispatch_if_known=*/true));
tensor.unsafeGetTensorImpl()->set_python_dispatch(true);
py::object py_tensor = py::reinterpret_steal<py::object>(
THPVariable_WrapWithType(std::move(tensor), (PyTypeObject*)cls));
py_tensor.attr(dtensor_interned_strings._spec) = spec;
py_tensor.attr(dtensor_interned_strings._local_tensor) = local_tensor;
return py_tensor.release().ptr();
@ -3440,15 +3353,16 @@ static PyTypeObject THPVariableMetaType = {
nullptr, /* tp_new */
};
static void THPVariable_dealloc(PyObject* self);
static int THPVariable_clear(THPVariable* self);
static int THPVariable_traverse(PyObject* self, visitproc visit, void* arg);
static PyTypeObject THPVariableType = {
PyVarObject_HEAD_INIT(&THPVariableMetaType, 0)
"torch._C.TensorBase", /* tp_name */
sizeof(THPVariable), /* tp_basicsize */
0, /* tp_itemsize */
// This is unspecified, because it is illegal to create a THPVariableType
// directly. Subclasses will have their tp_dealloc set appropriately
// by the metaclass
nullptr, /* tp_dealloc */
THPVariable_dealloc, /* tp_dealloc */
0, /* tp_vectorcall_offset */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
@ -3467,9 +3381,8 @@ static PyTypeObject THPVariableType = {
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
Py_TPFLAGS_HAVE_GC, /* tp_flags */
nullptr, /* tp_doc */
// Also set by metaclass
(traverseproc)THPFake_traverse, /* tp_traverse */
(inquiry)THPFake_clear, /* tp_clear */
(traverseproc)THPVariable_traverse, /* tp_traverse */
(inquiry)THPVariable_clear, /* tp_clear */
nullptr, /* tp_richcompare */
0, /* tp_weaklistoffset */
nullptr, /* tp_iter */
@ -3498,345 +3411,68 @@ PyObject* THPVariable_pynew(
type != &THPVariableType,
"Cannot directly construct TensorBase; subclass it and then construct that");
jit::tracer::warn("torch.Tensor", jit::tracer::WARN_CONSTRUCTOR);
auto tensor = torch::utils::base_tensor_ctor(args, kwargs);
// WARNING: tensor is NOT guaranteed to be a fresh tensor; e.g., if it was
// given a raw pointer that will refcount bump
// NB: base_tensor_ctor can call into dispatched ATen functions (e.g.,
// alias(), lift_fresh()) which can return Tensor subclasses. We allow
// these to be passed on directly.
return THPVariable_NewWithVar(
type,
tensor,
/*allow_preexisting_pyobj=*/true);
PyObject* obj = THPVariable_WrapWithType(
torch::utils::base_tensor_ctor(args, kwargs), type);
if (check_has_torch_dispatch(obj)) {
THPVariable_Unpack(obj).unsafeGetTensorImpl()->set_python_dispatch(true);
}
return obj;
END_HANDLE_TH_ERRORS
}
static int THPVariable_subclass_clear(THPVariable* self) {
// Is it OK for an object to still be live after running
// tp_clear? Yes. When Python is breaking reference cycles, it can't assume
// that an object will dealloc after it's cleared. The source code explicitly
// handles this case:
// https://github.com/python/cpython/blob/4e661cd69164318c1f871faa476c68a04092ddc4/Modules/gcmodule.c#L1010-L1025
// Note that we don't need to actually resurrect here. There are 2 cases:
// 1. The PyObject is not part of a reference cycle. In this case, we don't
// need to do anything. The GC will move on to try and break the reference
// cycle on another object, which will eventually trigger tp_dealloc (and thus
// resurrection).
// 2. The PyObject is part of a reference cycle. This case should not actually
// be possible, due to the logic in our tp_traverse
// (THPVariable_subclass_traverse).
// In fact, resurrecting here breaks the invariant that "C++ owns Python only
// when PyObject's refcount would otherwise be 0". Most immediately, as we're
// merely breaking reference cycles here, there can be other references to the
// PyObject. *However*, if other objects in the refcycle resurrect, then we
// will be in a state where the PyObject has multiple Python references, yet
// C++ owns the PyObject.
// See https://github.com/pytorch/pytorch/pull/75933 for more discussion.
if (isResurrectable(self)) {
return 0;
}
static int THPVariable_clear(THPVariable* self) {
// First clear Tensor specific things
Py_CLEAR(self->backward_hooks);
Py_CLEAR(self->post_accumulate_grad_hooks);
const auto& tensor = THPVariable_Unpack(self);
if (tensor.defined()) {
// Two situations to consider:
// PyObject -owns-> Tensor
// unsafeIsBorrowed() is FALSE. We're obligated to look through
// Tensor to break references. Clearing cdata must induce the
// destruction of the C++ Tensor. If there were other references
// to C++ tensor, the Python object would have been resurrected
// by flipping the ownership.
// Tensor -owns-> PyObject
// unsafeIsBorrowed() is TRUE. We're deallocating the PyObject
// because Tensor asked us to (it's already destructing).
if (self->cdata.defined()) {
auto pyobj_slot = self->cdata.unsafeGetTensorImpl()->pyobj_slot();
// Typically the Tensor's pyobj_slot points back to this object. The only
// time that's not the case is if we had a race in THPVariable_Wrap and we
// need to discard the Python object because some other thread beat us to
// setting the pyobj_slot.
if (pyobj_slot->load_pyobj() == (PyObject*)self) {
// A Tensor's Python object should only be destroyed when the Tensor has
// no other references too.
TORCH_INTERNAL_ASSERT(self->cdata.use_count() == 1);
if (!self->cdata.unsafeIsBorrowed() &&
tensor.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/false) == (PyObject*)self) {
// TODO: empirically, on OS X this assert appears to be untrue
// In test_py_tensors_multi_async_call - ProcessGroupRpcTestWithSpawn
// distributed/rpc/test_process_group_agent.py
//
// libc++abi.dylib: terminating with uncaught exception of type
// c10::Error:
// !tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()INTERNAL
// ASSERT FAILED at "../torch/csrc/autograd/python_variable.cpp":171,
// please report a bug to PyTorch. Exception raised from
// THPVariable_subclass_clear at
// ../torch/csrc/autograd/python_variable.cpp:171 (most recent call
// first): frame #0: c10::Error::Error(c10::SourceLocation,
// std::__1::basic_string<char, std::__1::char_traits<char>,
// std::__1::allocator<char> >) + 98 (0x1158a0442 in libc10.dylib) frame
// #1: c10::detail::torchCheckFail(char const*, char const*, unsigned
// int, char const*) + 205 (0x11589ed3d in libc10.dylib) frame #2:
// c10::detail::torchInternalAssertFail(char const*, char const*,
// unsigned int, char const*, c10::detail::CompileTimeEmptyString) + 9
// (0x1141e3f89 in libtorch_python.dylib) frame #3:
// THPVariable_subclass_clear(THPVariable*) + 412 (0x1148a547c in
// libtorch_python.dylib) frame #4:
// THPVariable_subclass_dealloc(_object*) + 453 (0x1148a5035 in
// libtorch_python.dylib) frame #5: (anonymous
// namespace)::concrete_decref_fn(c10::impl::PyInterpreter const*,
// _object*) + 53 (0x1148a5ea5 in libtorch_python.dylib) frame #6:
// c10::TensorImpl::release_resources() + 182 (0x11588c4a6 in
// libc10.dylib) frame #7:
// c10::MaybeOwned<at::Tensor>::operator=(c10::MaybeOwned<at::Tensor>&&)
// + 91 (0x11488c11b in libtorch_python.dylib) frame #8:
// THPVariable_subclass_dealloc(_object*) + 607 (0x1148a50cf in
// libtorch_python.dylib) <omitting python frames> frame #47: start + 1
// (0x7fff6ffc7cc9 in libdyld.dylib) frame #48: 0x0 + 4 (0x4 in ???)
// TORCH_INTERNAL_ASSERT(!tensor.unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj());
if (auto grad_acc =
torch::autograd::impl::try_get_grad_accumulator(tensor)) {
grad_acc->pre_hooks().clear();
grad_acc->tensor_pre_hooks().clear();
grad_acc->retains_grad_hooks().clear();
}
// Clear the pyobj_slot so that a try_incref() call from
// weak_intrusive_ptr::lock() won't see a freed pointer.
pyobj_slot->clear();
}
}
TORCH_INTERNAL_ASSERT(!isResurrectable(self));
{
// MapAllocator can take significant time to release large tensors;
// release the GIL here to avoid impacting main thread perf.
pybind11::gil_scoped_release no_gil;
self->cdata = MaybeOwned<Variable>();
self->cdata = Variable();
}
// Since we override the basic subtype_clear from CPython, we need a crappy
// version here just like for traverse and dealloc
// Clear all slots until we get to the base Tensor class
PyTypeObject* type = Py_TYPE((PyObject*)self);
PyTypeObject* base = type;
while (base != &THPVariableType) {
if (Py_SIZE(base))
clear_slots(base, (PyObject*)self);
base = base->tp_base;
TORCH_INTERNAL_ASSERT(base);
}
// Assume we never have managed dict for Tensors as we don't set the flag on
// the base class
if (C10_LIKELY(type->tp_dictoffset)) {
PyObject** dictptr = _PyObject_GetDictPtr((PyObject*)self);
if (dictptr && *dictptr)
Py_CLEAR(*dictptr);
}
return 0;
}
// NB: this is not the tp_dealloc on THPVariable; instead, its the dealloc
// on subclasses. It's never valid to construct a THPVariable so it's not
// necessary to implement the dealloc for that case
static void THPVariable_subclass_dealloc(PyObject* self) {
if (THPVariable_tryResurrect((THPVariable*)self))
return;
// This is like a crappy version of subtype_dealloc.
// Unfortunately, we cannot directly delegate to
// subtype_dealloc as it will start walking the parent
// chain *starting with* the type of self, which will cause
// us to go back to our custom dealloc.
//
// We have to replicate the subtype_dealloc logic to ensure
// that finalizers are handled correctly
PyTypeObject* type = Py_TYPE(self);
TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
TORCH_INTERNAL_ASSERT(PyType_IS_GC(type), "GC types not implemented");
static void THPVariable_dealloc(PyObject* self) {
PyObject_GC_UnTrack(self);
// TODO: consider using trash can
bool has_finalizer = type->tp_finalize || type->tp_del;
if (type->tp_finalize) {
PyObject_GC_Track(self);
if (PyObject_CallFinalizerFromDealloc(self) < 0) {
/* Resurrected */
return;
}
PyObject_GC_UnTrack(self);
}
// base test is unnecessary as THPVariable does not set this
if (type->tp_weaklistoffset) {
PyObject_ClearWeakRefs(self);
}
if (type->tp_del) {
PyObject_GC_Track(self);
type->tp_del(self);
if (Py_REFCNT(self) > 0) {
/* Resurrected */
return;
}
PyObject_GC_UnTrack(self);
}
if (has_finalizer) {
/* New weakrefs could be created during the finalizer call.
If this occurs, clear them out without calling their
finalizers since they might rely on part of the object
being finalized that has already been destroyed. */
if (type->tp_weaklistoffset) {
/* Modeled after GET_WEAKREFS_LISTPTR() */
PyWeakReference** list =
(PyWeakReference**)PyObject_GET_WEAKREFS_LISTPTR(self);
while (*list)
_PyWeakref_ClearRef(*list);
}
}
// Clear all slots until we get to base class THPVariableType
{
PyTypeObject* base = type;
while (base != &THPVariableType) {
if (Py_SIZE(base)) {
clear_slots(base, self);
}
base = base->tp_base;
TORCH_INTERNAL_ASSERT(base);
}
}
// All Python defined classes have __dict__
if (C10_LIKELY(type->tp_dictoffset)) {
PyObject** dictptr = _PyObject_GetDictPtr(self);
if (dictptr != nullptr) {
PyObject* dict = *dictptr;
if (dict != nullptr) {
Py_DECREF(dict);
*dictptr = nullptr;
}
}
}
// subtype_dealloc allows for this but we don't
TORCH_INTERNAL_ASSERT(Py_TYPE(self) == type);
// Finally clear out the base THPVariable
THPVariable_subclass_clear((THPVariable*)self);
((THPVariable*)self)->cdata.~MaybeOwned<Variable>();
THPVariable_clear((THPVariable*)self);
((THPVariable*)self)->cdata.~Variable();
Py_TYPE(self)->tp_free(self);
// Python defined subclasses should always be on the heap
TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
Py_DECREF(type);
}
// Creates a new Python object for a Variable.
static PyObject* THPVariable_NewWithVar(
PyTypeObject* type,
const at::TensorBase& _var,
bool allow_preexisting_pyobj,
std::optional<bool> has_torch_dispatch_if_known) {
// Make sure that the reinterpret into a THPVariable* will be valid
TORCH_CHECK(
type == &THPVariableType || PyType_IsSubtype(type, &THPVariableType),
"Creating a Tensor subclass from a class ",
"that does not inherit from Tensor is not possible. Make sure your class inherits from Tensor.");
// This function overwrite the Tensor's pyobj field without extra checks
// Make sure it is not set otherwise we would leak memory
auto mb_obj = _var.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
/*ignore_hermetic_tls=*/false);
// Under some circumstances, we may attempt to create a new Python
// object for a variable that already has a Python object. The most common
// situation this can occur is if you have a TorchDispatchMode active that
// is returning a subclass from lift_fresh (which is invoked to
// appropriately "wrap" a constant tensor into whatever ambient modes are
// active.)
//
// In general, it is impossible to handle this case compositionally.
// Suppose you have a user call ATensor([1, 2, 3]) when a mode is active
// that is transforming all ops (including the internal lift_fresh call that
// transforms [1, 2, 3] into a torch.tensor([1., 2., 3.])) to output
// BTensor, where ATensor and BTensor are completely unrelated subclasses
// and there is no way to compose them. There is no way to satisfy the user
// request here: in particular, you can't just try to re-invoke the ATensor
// constructor on the returned BTensor, because (1) this could cause an
// infinite loop--we are already in ATensor.__new__ and (2) there isn't any
// guarantee that ATensor.__new__ supports a single element constructor
// anyway.
//
// However, a more common case is a user just called torch.Tensor([1, 2, 3]),
// and a fake tensor mode is active. Really, all you want is to get back
// a FakeTensor, in the same way torch.tensor([1, 2, 3]) or torch.arange(3)
// would have returned a fake tensor (concretely, the way this happens
// is we create a *real* tensor torch.tensor([1., 2., 3.]), and then it
// turns into a FakeTensor when we call lift_fresh on this real tensor).
// This case is compositional because FakeTensor is a subclass of Tensor, so
// it's valid for us to return it in place of a Tensor. So this is what we
// do.
if (mb_obj.has_value() && mb_obj.value()) {
TORCH_CHECK(
allow_preexisting_pyobj,
"Creating a new Tensor subclass ",
type->tp_name,
" but the raw Tensor object is already associated to a python object ",
"of type ",
mb_obj.value()->ob_type->tp_name);
// Even if we allow pre-existing PyObject, we don't allow completely
// ignoring the requested type. Check that we fulfilled a subtype
// relation here. In the common case the requested type is Tensor and
// this always succeeds.
PyObject* obj = *mb_obj;
// Check if it's OK to just directly return the Python object without
// allocating a new variable. We just check that the existing Python
// object is a subclass of the requested type.
PyTypeObject* obj_type = Py_TYPE(obj);
TORCH_CHECK(
obj_type == type || PyType_IsSubtype(obj_type, type),
"Creating a new Tensor subclass ",
type->tp_name,
" but the raw Tensor object is already associated to a python object ",
"of type ",
mb_obj.value()->ob_type->tp_name,
" which is not a subclass of the "
"requested type");
// We may (in fact, we typically will) need to resurrect this
return THPVariable_Wrap(_var);
}
PyObject* obj = type->tp_alloc(type, 0);
if (obj) {
auto v = (THPVariable*)obj;
// TODO: named constructor to avoid default initialization
new (&v->cdata) MaybeOwned<Variable>();
if (c10::impl::HermeticPyObjectTLS::get_state()) {
// Do NOT initialize pyobj field on the tensor, you own the C++
v->cdata = MaybeOwned<Variable>::owned(Variable(_var));
TORCH_INTERNAL_ASSERT(
!check_has_torch_dispatch(obj),
"While HermeticPyObject was enabled, we attempted to create a tensor "
"subclass with __torch_dispatch__. This violates the invariant that "
"operations in HermeticPyObject have equivalent C++ implementations. "
"If your operator registered from Python operator registration isn't "
"doing anything strange, there may be an internal PyTorch bug involving "
"not appropriately disabling TorchDispatchMode before executing "
"Python op registration.");
} else {
// Normal codepath
v->cdata = MaybeOwned<Variable>::owned(Variable(_var));
const auto& var = THPVariable_Unpack(v);
var.unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(obj);
if (has_torch_dispatch_if_known.has_value()
? *has_torch_dispatch_if_known
: check_has_torch_dispatch(obj)) {
var.unsafeGetTensorImpl()->set_python_dispatch(true);
}
}
}
return obj;
static void TORCH_CHECK_TENSOR_SUBTYPE(PyObject* cls) {
TORCH_CHECK_TYPE(
PyType_Check(cls),
"cls must be a type (got ",
Py_TYPE(cls)->tp_name,
")");
PyTypeObject* type = reinterpret_cast<PyTypeObject*>(cls);
TORCH_CHECK_TYPE(
type == &THPVariableType || cls == THPVariableClass ||
PyType_IsSubtype(type, &THPVariableType),
"Creating a Tensor subclass from a class that does not inherit from "
"Tensor is not possible. Make sure your class inherits from Tensor.");
}
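From Python, the consolidated check rejects non-Tensor classes up front with a `TypeError`; an illustrative failure case (the `NotATensor` class is made up for the example):

```python
# Illustrative: TORCH_CHECK_TENSOR_SUBTYPE surfaces as a TypeError when the
# requested class does not inherit from Tensor.
import torch

class NotATensor:
    pass

t = torch.randn(2)
try:
    torch.Tensor._make_subclass(NotATensor, t)
except TypeError as e:
    print(e)  # "... does not inherit from Tensor ..."
```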
/// NOTE [ PyObject Traversal ]
@ -3855,7 +3491,7 @@ static PyObject* THPVariable_NewWithVar(
/// into account these C++ ownership links.
///
/// The main danger here comes from the fact that, while all python-related code
/// is thread safe wrt the GC execution (thanks to the GIL), other threads might
/// is thread safe wrt the GC execution, other threads might
/// be using our C++ objects arbitrarily which can lead to shared_ptr ref count
/// going up or down in between the different traverse/clear invocations. The
/// one constraint we add here that is not explicitly mentioned in the GC
@ -3885,124 +3521,46 @@ static PyObject* THPVariable_NewWithVar(
/// https://github.com/pytorch/pytorch/issues/7343
///
static int traverse_slots(
PyTypeObject* type,
PyObject* self,
visitproc visit,
void* arg) {
auto n = Py_SIZE(type);
auto mp = type->tp_members;
for (Py_ssize_t i = 0; i < n; i++, mp++) {
if (mp->type == T_OBJECT_EX) {
char* addr = (char*)self + mp->offset;
PyObject* obj = *(PyObject**)addr;
if (obj != nullptr) {
int err = visit(obj, arg);
if (err)
return err;
}
}
}
return 0;
}
static int THPVariable_subclass_traverse(
PyObject* self,
visitproc visit,
void* arg) {
// If the tensor is eligible to be resurrected, don't traverse it; instead
// treat all of its references as a root (as they WOULD be a root since we
// can treat the inbound C++ references as root owners).
//
// This works because unlike conventional GCs, Python's GC operates in two
// phases: first it uses traverse to discover roots, and then it uses traverse
// to do reachability. Bypassing traverse during root discovery forces Python
// to treat self as a root for everything it refers to. For a full
// explanation of the algorithm see
// https://devguide.python.org/garbage_collector/
//
// NB: if we don't hold an owning reference to the underlying Tensor, it is
// possible that the underlying Tensor has already gone dead. In that case,
// it's not safe to access it. But it's also safe to traverse, because if
// the underlying Tensor *is* live, then root discovery will determine that
// self is live, and nothing will get GC'ed anyway (resurrection cannot happen
// if the C++ objects owns the PyObject)
static int THPVariable_traverse(PyObject* self, visitproc visit, void* arg) {
THPVariable* var = reinterpret_cast<THPVariable*>(self);
if (isResurrectable(var)) {
return 0;
}
// Crappy version of subtype_traverse; same deal as
// THPVariable_subclass_dealloc
PyTypeObject* type = Py_TYPE(self);
// Traverse slots until we get to base class THPVariableType
{
PyTypeObject* base = type;
while (base != &THPVariableType) {
if (Py_SIZE(base)) {
int err = traverse_slots(base, self, visit, arg);
if (err)
return err;
}
base = base->tp_base;
TORCH_INTERNAL_ASSERT(base);
}
}
// All Python defined classes have __dict__
if (C10_LIKELY(type->tp_dictoffset)) {
PyObject** dictptr = _PyObject_GetDictPtr(self);
if (dictptr && *dictptr)
Py_VISIT(*dictptr);
}
TORCH_INTERNAL_ASSERT(type->tp_flags & Py_TPFLAGS_HEAPTYPE);
Py_VISIT(type);
// Finally traverse THPVariable special stuff
Py_VISIT(var->backward_hooks);
Py_VISIT(var->post_accumulate_grad_hooks);
if (!var->cdata.unsafeIsBorrowed()) {
const auto& tensor = THPVariable_Unpack(var);
if (tensor.defined()) {
// WARNING: The grad_fn traversal logic is very subtle, if you change
// this, be very careful not to re-introduce this bug:
// https://gist.github.com/zou3519/7ac92b84dd7d206dcc6eae55fee8372c
const auto& tensor = THPVariable_Unpack(var);
if (tensor.defined()) {
// WARNING: The grad_fn traversal logic is very subtle, if you change
// this, be very careful not to re-introduce this bug:
// https://gist.github.com/zou3519/7ac92b84dd7d206dcc6eae55fee8372c
// We ensure that we follow NOTE [ PyObject Traversal ] here by checking
// that this python object is the sole owner of the underlying Tensor and
// that this Tensor is the sole owner of its grad_fn. In this case, the
// only way to get a new reference to the grad_fn is by using this python
// object, which requires the GIL to be accessed. Note that this is only
// valid as long as users don't share non-owning references across
// different threads (which is crazy and should never be done).
auto autograd_meta = torch::autograd::impl::get_autograd_meta(tensor);
if (tensor.use_count() == 1) {
if (autograd_meta) {
// Do NOT call grad_fn() here as that might trigger a recompute
const auto& grad_fn = autograd_meta->grad_fn_;
if (grad_fn && grad_fn.use_count() == 1) {
// All Node can have a pyobj (stored in "pyobj_")
Py_VISIT(grad_fn->pyobj());
// PyNode are special as they also have an "obj" field
if (auto py_node_fn = dynamic_cast<PyNode*>(grad_fn.get())) {
Py_VISIT(py_node_fn->obj);
}
}
}
}
// We ensure that we follow NOTE [ PyObject Traversal ] he by checking
// that this python object is the sole owner of the underlying Tensor and
// that this Tensor is the sole owner of its grad_fn. In this case, the
// only way to get a new reference to the grad_fn is by using this python
// object, which requires the GIL to be accessed. Note that this is only
// valid as long as user don't share non-owning references across
// different threads (which is crazy and should never be done).
auto autograd_meta = torch::autograd::impl::get_autograd_meta(tensor);
if (tensor.use_count() == 1) {
if (autograd_meta) {
for (const auto& hook : torch::autograd::impl::hooks(tensor)) {
if (auto pyhook =
dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
Py_VISIT(pyhook->dict);
// Do NOT call grad_fn() here as that might trigger a recompute
const auto& grad_fn = autograd_meta->grad_fn_;
if (grad_fn && grad_fn.use_count() == 1) {
// All Node can have a pyobj (stored in "pyobj_")
Py_VISIT(grad_fn->pyobj());
// PyNode are special as they also have an "obj" field
if (auto py_node_fn = dynamic_cast<PyNode*>(grad_fn.get())) {
Py_VISIT(py_node_fn->obj);
}
}
}
}
if (autograd_meta) {
for (const auto& hook : torch::autograd::impl::hooks(tensor)) {
if (auto pyhook = dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
Py_VISIT(pyhook->dict);
}
}
}
}
return 0;
}
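For reference, a minimal standalone sketch (not part of this diff) of the tp_traverse contract the function above implements: a traverse function calls `visit` on every PyObject* the instance owns so that Python's cycle collector can discover reference cycles, and Py_VISIT is a no-op for null pointers. The ExampleObject type below is hypothetical.

#include <Python.h>

typedef struct {
  PyObject_HEAD
  PyObject* payload; // owned reference, may be NULL
} ExampleObject;

static int Example_traverse(PyObject* self, visitproc visit, void* arg) {
  ExampleObject* obj = (ExampleObject*)self;
  Py_VISIT(obj->payload);  // Py_VISIT early-returns if the visitor reports an error
  Py_VISIT(Py_TYPE(self)); // heap-allocated types visit their type as well
  return 0;
}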
@@ -4010,17 +3568,6 @@ int THPVariableMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) {
if (PyType_Type.tp_init(cls, args, kwargs) < 0) {
return -1;
}
// It is important for all three of these to be overridden correctly for the
// resurrection checks to properly happen. In particular, an older version
// was not overriding tp_clear here. This led to the default subtype_clear
// running on the Tensor object (as only TensorBase tp_clear was custom),
// clearing the __dict__ field before the TensorBase custom clear was called
// and could properly detect the resurrection.
// See https://github.com/pytorch/pytorch/issues/136358 for the exact behavior
((PyTypeObject*)cls)->tp_dealloc = (destructor)THPVariable_subclass_dealloc;
((PyTypeObject*)cls)->tp_traverse =
(traverseproc)THPVariable_subclass_traverse;
((PyTypeObject*)cls)->tp_clear = (inquiry)THPVariable_subclass_clear;
// Don't do anything for the base Tensor class
if (!THPVariableClass) {

View File

@@ -17,7 +17,7 @@ namespace py = pybind11;
struct THPVariable {
PyObject_HEAD
// Payload
c10::MaybeOwned<at::Tensor> cdata;
at::Tensor cdata;
// Hooks to be run on backwards pass (corresponds to Python attr
// '_backwards_hooks', set by 'register_hook')
PyObject* backward_hooks = nullptr;
@@ -37,7 +37,11 @@ TORCH_PYTHON_API extern PyObject* THPVariableClass;
TORCH_PYTHON_API extern PyObject* ParameterClass;
bool THPVariable_initModule(PyObject* module);
TORCH_PYTHON_API PyObject* THPVariable_Wrap(at::TensorBase&& var);
TORCH_PYTHON_API PyObject* THPVariable_Wrap(const at::TensorBase& var);
TORCH_PYTHON_API PyObject* THPVariable_Wrap(
const at::TensorBase& var,
PyTypeObject* type);
inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) {
// Check that a python object is a `Tensor`, but not a `Tensor` subclass.
@@ -69,7 +73,7 @@ inline bool THPVariable_Check(PyObject* obj) {
}
inline const at::Tensor& THPVariable_Unpack(THPVariable* var) {
return *var->cdata;
return var->cdata;
}
inline const at::Tensor& THPVariable_Unpack(PyObject* obj) {

View File

@@ -65,7 +65,9 @@ inline at::Tensor clone_obey_contract(
.new_empty_strided_symint(
variable.sym_sizes(),
variable.sym_strides(),
variable.options().memory_format(std::nullopt))
variable.options()
.memory_format(std::nullopt)
.dtype(new_grad.dtype()))
.copy_(new_grad));
} else {
// (2)

View File

@@ -70,6 +70,10 @@ inline PyObject* wrap(const at::Tensor& tensor) {
return THPVariable_Wrap(tensor);
}
inline PyObject* wrap(at::Tensor&& tensor) {
return THPVariable_Wrap(std::move(tensor));
}
inline PyObject* wrap(const at::Scalar& scalar) {
return wrap(scalar_to_tensor(scalar));
}

View File

@@ -197,6 +197,22 @@ TORCH_API std::unique_ptr<PostAccumulateGradHook>& post_acc_grad_hooks(
TORCH_API void create_cpp_hook(
const at::TensorBase& /*self*/,
bool is_retains_grad_hooks = false);
inline bool is_tensor_stealable(
const at::Tensor& new_grad,
size_t num_expected_refs = 1) {
size_t use_count = new_grad.use_count();
if (use_count <= num_expected_refs) {
return true;
}
if (use_count >= 2 &&
new_grad.unsafeGetTensorImpl()->pyobj_slot()->has_unique_reference()) {
// The Python wrapper, if it exists, also has a reference to the Tensor.
num_expected_refs++;
}
return use_count <= num_expected_refs;
}
} // namespace impl
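To make the reference-count arithmetic in is_tensor_stealable concrete, here is a minimal standalone sketch (assumed, not taken from this diff); `wrapper_is_only_extra_owner` is a stand-in for the pyobj_slot()->has_unique_reference() query used by the real helper.

#include <cstddef>

inline bool is_stealable_sketch(
    size_t use_count,
    bool wrapper_is_only_extra_owner,
    size_t num_expected_refs = 1) {
  if (use_count <= num_expected_refs) {
    return true; // only the expected owners hold references
  }
  if (use_count >= 2 && wrapper_is_only_extra_owner) {
    // The Python wrapper accounts for exactly one additional reference.
    num_expected_refs++;
  }
  return use_count <= num_expected_refs;
}

// is_stealable_sketch(1, false) -> true  (caller is the sole owner)
// is_stealable_sketch(2, true)  -> true  (the extra reference is the Python wrapper)
// is_stealable_sketch(2, false) -> false (some other C++ owner still exists)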
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -894,7 +910,7 @@ inline Variable make_variable(
bool requires_grad = false,
bool allow_tensor_metadata_change = true) {
if (data.defined()) {
if (data.getIntrusivePtr().use_count() == 1 &&
if (impl::is_tensor_stealable(data) &&
data.getIntrusivePtr()->unique_version()) {
auto data_impl = data.unsafeReleaseIntrusivePtr();
data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);

View File

@@ -1,19 +1,67 @@
#include <torch/csrc/utils/pyobject_preservation.h>
#include <structmember.h>
#include <c10/core/impl/PyObjectSlot.h>
#include <c10/util/intrusive_ptr.h>
void clear_slots(PyTypeObject* type, PyObject* self) {
  Py_ssize_t n = Py_SIZE(type);
  PyMemberDef* mp = type->tp_members;

  for (Py_ssize_t i = 0; i < n; i++, mp++) {
    if (mp->type == T_OBJECT_EX && !(mp->flags & READONLY)) {
      char* addr = (char*)self + mp->offset;
      PyObject* obj = *(PyObject**)addr;
      if (obj != nullptr) {
        *(PyObject**)addr = nullptr;
        Py_DECREF(obj);
      }
    }
  }
}

namespace torch::utils {

using c10::intrusive_ptr_target;
using c10::impl::PyObjectSlot;
void PyObjectPreservation::init_fresh_nonatomic(
intrusive_ptr_target* target,
PyObjectSlot* slot,
PyObject* pyobj) {
TORCH_INTERNAL_ASSERT(slot->load_pyobj() == nullptr);
TORCH_INTERNAL_ASSERT(
target->combined_refcount_.load(std::memory_order_relaxed) ==
c10::detail::kUniqueRef);
slot->pyobj_.store(pyobj, std::memory_order_relaxed);
slot->pyobj_interpreter_.store(
c10::impl::getGlobalPyInterpreter(), std::memory_order_relaxed);
target->combined_refcount_.store(
c10::detail::kHasPyObject | c10::detail::kUniqueRef,
std::memory_order_relaxed);
}
PyObject* PyObjectPreservation::init_once(
intrusive_ptr_target* target,
PyObjectSlot* slot,
PyObject* pyobj) {
PyObject* expected = nullptr;
if (!slot->pyobj_.compare_exchange_strong(
expected, pyobj, std::memory_order_acq_rel)) {
TORCH_INTERNAL_ASSERT(expected != nullptr);
return expected;
}
slot->pyobj_interpreter_.store(
c10::impl::getGlobalPyInterpreter(), std::memory_order_release);
bool increfed = false;
auto combined = target->combined_refcount_.load(std::memory_order_relaxed);
do {
TORCH_INTERNAL_ASSERT(!c10::detail::has_pyobject(combined));
if (c10::detail::refcount(combined) > 1 && !increfed) {
// We need to incref the object to preserve the invariant that
// if refcount > 1, the c10 object holds a reference to the PyObject.
// This must happen before we set the kHasPyObject bit.
Py_INCREF(pyobj);
increfed = true;
}
} while (!target->combined_refcount_.compare_exchange_weak(
combined,
combined | c10::detail::kHasPyObject,
std::memory_order_acq_rel,
std::memory_order_relaxed));
if (increfed && c10::detail::refcount(combined) == 1) {
// Fix up the refcount if we did the incref during a failed compare-exchange attempt
Py_DECREF(pyobj);
}
return pyobj;
}
} // namespace torch::utils
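The loop in init_once above follows a standard lock-free publish pattern: keep retrying a compare-exchange until the flag bit is set, then act on the refcount word that was current when the exchange succeeded (the surrounding Py_INCREF/Py_DECREF bookkeeping balances an incref done during a failed attempt). Below is a minimal standalone sketch of just the bit-publishing part, with a hypothetical bit position and no PyTorch types.

#include <atomic>
#include <cstdint>

constexpr uint64_t kHasPyObjBit = 1ull << 63; // assumed bit position for this sketch

inline uint64_t set_has_pyobj_bit(std::atomic<uint64_t>& combined) {
  uint64_t cur = combined.load(std::memory_order_relaxed);
  while (!combined.compare_exchange_weak(
      cur,
      cur | kHasPyObjBit,
      std::memory_order_acq_rel,
      std::memory_order_relaxed)) {
    // On failure, `cur` is refreshed with the current value; retry until the
    // bit is set atomically.
  }
  return cur; // the word observed at the moment the bit was published
}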

View File

@@ -4,4 +4,28 @@
// This file contains utilities used for handling PyObject preservation
void clear_slots(PyTypeObject* type, PyObject* self);
namespace c10 {
class intrusive_ptr_target;
namespace impl {
struct PyObjectSlot;
} // namespace impl
} // namespace c10
namespace torch::utils {
class PyObjectPreservation {
public:
// Store a PyObject wrapper on a freshly constructed c10 object. The caller must hold
// a unique reference to `target`.
static void init_fresh_nonatomic(
c10::intrusive_ptr_target* target,
c10::impl::PyObjectSlot* slot,
PyObject* pyobj);
static PyObject* init_once(
c10::intrusive_ptr_target* target,
c10::impl::PyObjectSlot* slot,
PyObject* pyobj);
};
} // namespace torch::utils

View File

@@ -207,12 +207,19 @@ def tensorify_python_scalars(
and node.target is torch.ops.aten._local_scalar_dense.default
):
dtype = node.args[0].meta["val"].dtype

assert isinstance(node.args[0], fx.Node), node.args[0]

s = node.meta["val"].node.expr
expr_to_sym_proxy[s] = MetaProxy(
    node, tracer=tracer, fake_mode=fake_mode
)

# only tensorify if the dtype is floating point
if not dtype.is_floating_point:
    continue

expr_to_tensor_proxy[s] = MetaProxy(
    node.args[0], tracer=tracer, fake_mode=fake_mode
)
@@ -220,9 +227,7 @@
expr_to_tensor_proxy[s] = torch.ops.prims.convert_element_type.default(
expr_to_tensor_proxy[s], torch.float64
)
# pyrefly: ignore [bad-argument-type]
elif (sym_expr := _get_sym_val(node)) is not None:
if sym_expr not in expr_to_sym_proxy and not isinstance(